[
  {
    "path": ".claude/skills/add-jit-kernel/SKILL.md",
    "content": "---\nname: add-jit-kernel\ndescription: Step-by-step tutorial for adding a new lightweight JIT CUDA kernel to sglang's jit_kernel module\n---\n\n# Tutorial: Adding a New JIT Kernel to SGLang\n\nThis tutorial walks through adding a simple element-wise scale operation as a JIT kernel. We'll implement `scale(x, factor) = x * factor` to demonstrate the complete workflow.\n\n## Goal\n\nAdd a new operation that scales each element of a tensor by a scalar factor:\n\n- Input: tensor `x` (CUDA) and scalar `factor` (float, passed at runtime)\n- Output: `x * factor` (element-wise), written to a newly allocated tensor unless a pre-allocated `out` is passed\n- Supported dtypes: **FP16 (`torch.float16`), BF16 (`torch.bfloat16`), FP32 (`torch.float32`)**\n\n## When to use JIT vs AOT (`sgl-kernel`)\n\n- **JIT (`jit_kernel`)**: prefer this first for kernels that do **not** depend on CUTLASS or another large C++ project. It is the default choice for lightweight kernels that benefit from rapid iteration; they are compiled on first use.\n- **AOT (`sgl-kernel`)**: prefer this when the kernel **does** depend on CUTLASS or another large C++ project, or when it should live in `sgl-kernel/` and participate in the wheel build / torch op registration flow.\n- **Exception**: kernels that depend on `flashinfer`, or on CUTLASS that is already provided through `flashinfer`, can still be implemented as `jit_kernel`.\n\n---\n\n## Common Abstractions in `python/sglang/jit_kernel/include/sgl_kernel/`\n\n**Always prefer these abstractions over raw CUDA primitives.** They provide safety, readability, and consistency with the rest of the codebase.\n\n**Important include rule:** for every `#include <sgl_kernel/...>` line, add a short trailing comment explaining why that header is included (for example `// For TensorMatcher, SymbolicSize, SymbolicDevice`). 
This matches the current JIT kernel style and keeps include usage self-documenting.\n\n### `utils.h` — Host-side utilities\n\n```cpp\n#include <sgl_kernel/utils.h>\n```\n\n- **`host::RuntimeCheck(cond, args...)`** — Assert a condition at runtime; throws `PanicError` with file/line info on failure. Prefer this over bare `assert`.\n- **`host::Panic(args...)`** — Unconditionally throw a `PanicError` with a descriptive message.\n- **`host::div_ceil(a, b)`** — Integer ceiling division `(a + b - 1) / b`.\n- **`host::irange(n)`** / **`host::irange(start, end)`** — Range views for cleaner loops.\n- **`host::pointer::offset(ptr, offsets...)`** — Byte-safe pointer arithmetic on `void*`. Use this instead of raw casts.\n\n### `utils.cuh` — Device-side utilities + `LaunchKernel`\n\n```cpp\n#include <sgl_kernel/utils.cuh>\n```\n\n- **Type aliases**: `fp16_t`, `bf16_t`, `fp32_t`, `fp8_e4m3_t`, `fp8_e5m2_t` and their packed variants `fp16x2_t`, `bf16x2_t`, `fp32x2_t`, etc.\n- **`SGL_DEVICE`** — Expands to `__forceinline__ __device__`. Use on all device functions.\n- **`device::kWarpThreads`** — Constant `32`.\n- **`device::load_as<T>(ptr, offset)`** / **`device::store_as<T>(ptr, val, offset)`** — Type-safe loads/stores from `void*`.\n- **`device::pointer::offset(ptr, offsets...)`** — Pointer arithmetic on device.\n- **`host::LaunchKernel(grid, block, device_or_stream [, smem])`** — RAII kernel launcher that:\n  - Resolves the CUDA stream from a `DLDevice` via TVM-FFI automatically.\n  - Checks the CUDA error with file/line info after launch via `operator()(kernel, args...)`.\n  - Supports `.enable_pdl(bool)` for PDL (Programmatic Dependent Launch, SM90+).\n- **`host::RuntimeDeviceCheck(cudaError_t)`** — Check a CUDA error; throw on failure.\n\n### `tensor.h` — Tensor validation (`TensorMatcher`, Symbolic types)\n\n```cpp\n#include <sgl_kernel/tensor.h>\n```\n\nThis is the **primary validation API** for all kernel launchers. 
Use it to validate every `tvm::ffi::TensorView` argument.\n\n- **`host::SymbolicSize{\"name\"}`** — A named symbolic dimension. Call `.set_value(n)` to pin it, `.unwrap()` to extract after verification.\n- **`host::SymbolicDType`** — Symbolic dtype. Use `.set_options<Ts...>()` to restrict allowed types.\n- **`host::SymbolicDevice`** — Symbolic device. Use `.set_options<kDLCUDA>()` to restrict to CUDA.\n- **`host::TensorMatcher({dims...})`** — Fluent builder for tensor validation:\n  - `.with_dtype<T>()` — require a specific C++ type (e.g. `fp16_t`)\n  - `.with_dtype<T1, T2, ...>()` — allow a set of types\n  - `.with_device<kDLCUDA>(device_sym)` — require CUDA and bind the checked device to a `SymbolicDevice`\n  - `.with_strides({strides...})` — validate strides (omit to require contiguous)\n  - `.verify(tensor_view)` — execute the check; throws `PanicError` with full context on failure; **chainable** (`verify(a).verify(b)` to check multiple tensors with the same shape)\n\n**Typical pattern:**\n```cpp\nauto N = SymbolicSize{\"num_elements\"};\nauto device = SymbolicDevice{};\ndevice.set_options<kDLCUDA>();\nTensorMatcher({N})  //\n    .with_dtype<fp16_t>()\n    .with_device<kDLCUDA>(device)\n    .verify(dst)\n    .verify(src);  // same shape, dtype, device as dst\nconst size_t n = N.unwrap();\nconst DLDevice dev = device.unwrap();\n```\n\n### `type.cuh` — `dtype_trait<T>` and `packed_t<T>`\n\n```cpp\n#include <sgl_kernel/type.cuh>\n```\n\n- **`dtype_trait<T>`** — Static trait struct for each scalar type. Provides:\n  - `dtype_trait<T>::from(value)` — convert from another type (e.g. `fp32_t` → `fp16_t`)\n  - `dtype_trait<T>::abs/sqrt/rsqrt/exp/sin/cos(x)` — type-dispatched unary math (primarily for `fp32_t`)\n  - `dtype_trait<T>::max/min(x, y)` — type-dispatched binary math (primarily for `fp32_t`)\n- **`packed_t<T>`** — Two-element packed alias: `packed_t<fp16_t>` = `fp16x2_t`, `packed_t<bf16_t>` = `bf16x2_t`, `packed_t<fp32_t>` = `fp32x2_t`. 
Use for vectorized loads/stores.\n- **`device::cast<To, From>(value)`** — Type-safe cast using `dtype_trait`, e.g. `cast<fp32x2_t, fp16x2_t>(v)`.\n\n### `vec.cuh` — Vectorized memory access (`AlignedVector`)\n\n```cpp\n#include <sgl_kernel/vec.cuh>\n```\n\n- **`device::AlignedVector<T, N>`** — Aligned storage for N elements of type T. N must be a power of two and `sizeof(T) * N <= 32` bytes. Enables vectorized loads/stores for bandwidth efficiency. The API allows vectors of up to 256 bits, but 128-bit is the portable default; 256-bit vectorization is typically only viable on `SM100+` and should be gated by an architecture check when needed.\n  - `.load(ptr, offset)` — vectorized load of the `offset`-th vector from `ptr` (`offset` counts whole `N`-element vectors, not elements)\n  - `.store(ptr, offset)` — vectorized store to the `offset`-th vector of `ptr` (same units)\n  - `.fill(value)` — fill all lanes\n  - `operator[](i)` — element access\n\n### `tile.cuh` — `tile::Memory` (strided memory access pattern)\n\n```cpp\n#include <sgl_kernel/tile.cuh>\n```\n\n- `tile::Memory<T>` is fundamentally a **1D cooperative accessor** over a contiguous region.\n- **`device::tile::Memory<T>::cta(blockDim.x)`** — Creates a tile accessor where each thread handles `tid = threadIdx.x` with stride `tsize` (for `cta(blockDim.x)`, this is `blockDim.x`). 
Common for loops over a 1D array.\n- **`.load(ptr, offset)`** — loads `ptr[tid + offset * tsize]`\n- **`.store(ptr, val, offset)`** — stores to `ptr[tid + offset * tsize]`\n- **`.in_bound(n, offset)`** — boundary check\n\nFor a **2D tile**, either flatten `(row, col)` into a linear tile index first, or compute the address manually with `ptr[row * stride + col]` using your thread/block coordinates.\n\n### `math.cuh` — Device math (`device::math::`)\n\n```cpp\n#include <sgl_kernel/math.cuh>\n```\n\n- `device::math::max/min<T>(a, b)` — type-dispatched binary math via `dtype_trait`\n- `device::math::abs/sqrt/rsqrt/exp/sin/cos<T>(x)` — type-dispatched unary math via `dtype_trait`\n\n### `warp.cuh` — Warp-level primitives\n\n```cpp\n#include <sgl_kernel/warp.cuh>\n```\n\n- `device::warp::reduce_sum<T>(value)` — warp-level sum reduction via `__shfl_xor_sync`\n- `device::warp::reduce_max<T>(value)` — warp-level max reduction\n\n### `cta.cuh` — CTA-level primitives\n\n```cpp\n#include <sgl_kernel/cta.cuh>\n```\n\n- `device::cta::reduce_max<T>(value, smem, min_value)` — CTA-wide max using shared memory + warp reduction. 
The caller must call `__syncthreads()` afterwards before reading the result from `smem[0]`.\n\n### `atomic.cuh` — Atomic operations\n\n```cpp\n#include <sgl_kernel/atomic.cuh>\n```\n\n- `device::atomic::max(float* addr, float value)` — float atomic max (handles negative values correctly via bit tricks).\n\n### `runtime.cuh` — Occupancy and device info\n\n```cpp\n#include <sgl_kernel/runtime.cuh>\n```\n\n- `host::runtime::get_blocks_per_sm(kernel, block_dim)` — max active blocks per SM (occupancy)\n- `host::runtime::get_sm_count(device_id)` — number of SMs on the device\n- `host::runtime::get_cc_major(device_id)` — compute capability major version\n\n**Persistent kernel pattern** (cap blocks to SM count × occupancy):\n```cpp\nstatic const uint32_t max_occ = runtime::get_blocks_per_sm(kernel, kBlockSize);\nstatic const uint32_t num_sm  = runtime::get_sm_count(device.unwrap().device_id);\nconst auto num_blocks = std::min(num_sm * max_occ, div_ceil(n, kBlockSize));\nLaunchKernel(num_blocks, kBlockSize, device.unwrap())(kernel, params);\n```\n\n---\n\n## Step 0 (optional): Generate a `.clangd` config for better IDE support\n\n```bash\npython -m sglang.jit_kernel\n```\n\n---\n\n## Step 1: Implement the CUDA kernel in `jit_kernel/csrc/`\n\nCreate `python/sglang/jit_kernel/csrc/elementwise/scale.cuh`.\n\nThe implementation below builds on the project abstractions described above:\n\n```cpp\n#include <sgl_kernel/tensor.h>   // For TensorMatcher, SymbolicSize, SymbolicDevice\n#include <sgl_kernel/type.cuh>   // For dtype_trait, fp16_t, bf16_t, fp32_t\n#include <sgl_kernel/utils.h>    // For RuntimeCheck, div_ceil\n#include <sgl_kernel/utils.cuh>  // For LaunchKernel, SGL_DEVICE\n#include <sgl_kernel/vec.cuh>    // For AlignedVector\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/container/tensor.h>\n\nnamespace {\n\n// ----------------------------------------------------------------\n// Kernel: element-wise scale using vectorized 128-bit loads/stores\n// T       = fp16_t | bf16_t 
| fp32_t\n// kVecN   = number of elements per vector load (e.g. 8 for fp16)\n// factor  = runtime scale factor\n// ----------------------------------------------------------------\ntemplate <typename T, int kVecN>\n__global__ void scale_kernel(T* __restrict__ dst,\n                              const T* __restrict__ src,\n                              float factor,\n                              uint32_t n_total) {\n  using vec_t = device::AlignedVector<T, kVecN>;\n  const uint32_t n_vecs = n_total / kVecN;\n\n  // --- vectorised body ---\n  const uint32_t vec_stride = blockDim.x * gridDim.x;\n  for (uint32_t vi = blockIdx.x * blockDim.x + threadIdx.x;\n       vi < n_vecs;\n       vi += vec_stride) {\n    vec_t v;\n    v.load(src, vi);\n#pragma unroll\n    for (int i = 0; i < kVecN; ++i) {\n      v[i] = static_cast<T>(static_cast<float>(v[i]) * factor);\n    }\n    v.store(dst, vi);\n  }\n\n  // --- scalar tail ---\n  const uint32_t base = n_vecs * kVecN;\n  const uint32_t scalar_stride = blockDim.x * gridDim.x;\n  for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;\n       base + i < n_total;\n       i += scalar_stride) {\n    dst[base + i] = static_cast<T>(static_cast<float>(src[base + i]) * factor);\n  }\n}\n\n// ----------------------------------------------------------------\n// Launcher: validates tensors, selects vector width, launches kernel\n// ----------------------------------------------------------------\ntemplate <typename T>\nvoid scale(tvm::ffi::TensorView dst, tvm::ffi::TensorView src, float factor) {\n  using namespace host;\n\n  // 1. 
Validate input tensors with TensorMatcher\n  SymbolicSize N = {\"num_elements\"};\n  SymbolicDevice device_;\n  device_.set_options<kDLCUDA>();\n\n  TensorMatcher({N})  //\n      .with_dtype<T>()\n      .with_device<kDLCUDA>(device_)\n      .verify(dst)\n      .verify(src);  // same shape / dtype / device as dst\n\n  const auto num_elements = N.unwrap();\n  const DLDevice device = device_.unwrap();\n\n  RuntimeCheck(num_elements > 0, \"scale: num_elements must be > 0, got \", num_elements);\n  // The kernel indexes with uint32_t, so reject sizes that do not fit.\n  RuntimeCheck(num_elements <= 0xFFFFFFFFu, \"scale: num_elements must fit in uint32_t, got \", num_elements);\n  const uint32_t n = static_cast<uint32_t>(num_elements);\n\n  // 2. Choose vector width for 128-bit loads (16 bytes)\n  //    fp16/bf16: 8 elements × 2 bytes = 16 bytes\n  //    fp32:      4 elements × 4 bytes = 16 bytes\n  constexpr int kVecN = 16 / sizeof(T);\n  const uint32_t n_work_items = div_ceil(n, static_cast<uint32_t>(kVecN));\n\n  // 3. Launch\n  constexpr uint32_t kBlockSize = 256;\n  const uint32_t grid = div_ceil(n_work_items, kBlockSize);\n\n  LaunchKernel(grid, kBlockSize, device)(\n      scale_kernel<T, kVecN>,\n      static_cast<T*>(dst.data_ptr()),\n      static_cast<const T*>(src.data_ptr()),\n      factor,\n      n);\n}\n\n}  // namespace\n```\n\n**Key points:**\n\n- Include headers from `sgl_kernel/` — **not** raw CUDA headers for anything already covered\n- Add a short trailing `// For ...` explanation to every `#include <sgl_kernel/...>` line\n- Use `TensorMatcher` for all tensor validation; never manually check shape/dtype/device\n- Use `AlignedVector` for vectorized 128-bit loads/stores — a significant bandwidth win for memory-bound kernels\n- Use `LaunchKernel` — it resolves the stream and checks errors automatically\n- Use `RuntimeCheck` for runtime assertions with useful error messages\n- Prefer passing runtime scalars like `factor` directly unless compile-time specialization is genuinely required\n- `fp16_t` / `bf16_t` / `fp32_t` are the project's type aliases (from `utils.cuh`)\n- For conversions beyond the simple `float` round trip used here, prefer `device::cast<To, From>` or `dtype_trait<T>::from(val)`\n- Use `device::math::` functions for device math instead of bare `__` 
intrinsics\n\n---\n\n## Step 2: Add the Python wrapper in `jit_kernel/`\n\nCreate `python/sglang/jit_kernel/scale.py`:\n\n```python\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_scale_module(dtype: torch.dtype) -> Module:\n    \"\"\"Compile and cache the JIT scale module for a given dtype.\"\"\"\n    args = make_cpp_args(dtype)\n    return load_jit(\n        \"scale\",\n        *args,\n        cuda_files=[\"elementwise/scale.cuh\"],\n        cuda_wrappers=[(\"scale\", f\"scale<{args}>\")],\n    )\n\n\ndef scale(src: torch.Tensor, factor: float, out: torch.Tensor | None = None) -> torch.Tensor:\n    \"\"\"\n    Element-wise scale: out = src * factor.\n\n    Supported dtypes: torch.float16, torch.bfloat16, torch.float32.\n\n    Parameters\n    ----------\n    src    : contiguous CUDA tensor (FP16 / BF16 / FP32)\n    factor : scale factor\n    out    : optional pre-allocated contiguous output tensor (same shape/dtype/device as src)\n\n    Returns\n    -------\n    The output tensor holding src * factor (``out`` itself when provided).\n    \"\"\"\n    if not src.is_cuda:\n        raise RuntimeError(\"src must be a CUDA tensor\")\n    if src.dtype not in (torch.float16, torch.bfloat16, torch.float32):\n        raise RuntimeError(\n            f\"Unsupported dtype {src.dtype}. 
Supported: float16, bfloat16, float32\"\n        )\n    if not src.is_contiguous():\n        raise RuntimeError(\"src must be contiguous\")\n    if out is None:\n        out = torch.empty_like(src)\n    else:\n        if out.shape != src.shape:\n            raise RuntimeError(\"out shape must match src\")\n        if out.dtype != src.dtype:\n            raise RuntimeError(\"out dtype must match src\")\n        if out.device != src.device:\n            raise RuntimeError(\"out device must match src\")\n        if not out.is_contiguous():\n            raise RuntimeError(\"out must be contiguous\")\n\n    # Keep the Python wrapper thin, but still enforce the basic preconditions\n    # above, which the current JIT/FFI path does not reject safely on its own.\n    module = _jit_scale_module(src.dtype)\n    # The C++ launcher matches 1D tensors (TensorMatcher({N})), so pass flat\n    # views; they share storage with out/src, so out is written in place.\n    module.scale(out.view(-1), src.view(-1), factor)\n    return out\n```\n\n**Key points:**\n\n- Use `cache_once` — **not** `functools.lru_cache` (incompatible with `torch.compile`)\n- `load_jit` first arg(s) form the unique build marker; same marker = same cached binary\n- Only include compile-time specialization knobs in the build marker; runtime values like `factor` should stay runtime unless the kernel truly needs templating\n- `cuda_wrappers`: `(export_name, cpp_symbol)` — `export_name` is the name called from Python; `cpp_symbol` is the host launcher instantiation (e.g. `scale<fp16_t>`)\n- `make_cpp_args(dtype, ...)` converts each `torch.dtype` to the matching C++ type alias (see the table below)\n- Keep Python launchers thin, but still validate the basic invariants (`is_cuda`, supported dtype, contiguity, `out` metadata). 
In the current JIT/FFI path, invalid tensors are not always rejected safely before launch.\n\n| `torch.dtype`      | C++ type   |\n|--------------------|------------|\n| `torch.float16`    | `fp16_t`   |\n| `torch.bfloat16`   | `bf16_t`   |\n| `torch.float32`    | `fp32_t`   |\n\n---\n\n## Step 3 (optional): Tune JIT build flags\n\nPass extra compiler flags through `load_jit` inside `_jit_scale_module` (note that `--use_fast_math` trades floating-point accuracy for speed):\n\n```python\nreturn load_jit(\n    \"scale\",\n    *args,\n    cuda_files=[\"elementwise/scale.cuh\"],\n    cuda_wrappers=[(\"scale\", f\"scale<{args}>\")],\n    extra_cuda_cflags=[\"-O3\", \"--use_fast_math\"],\n)\n```\n\nIf your kernel requires SM90+, raise a clear Python error before calling `load_jit`:\n\n```python\nif torch.cuda.get_device_capability()[0] < 9:\n    raise RuntimeError(\"This kernel requires SM90 (Hopper) or later\")\n```\n\n---\n\n## Step 4: Write tests (required)\n\nCreate `python/sglang/jit_kernel/tests/test_scale.py`:\n\n```python\nimport pytest\nimport torch\nfrom sglang.jit_kernel.scale import scale\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float32])\n@pytest.mark.parametrize(\"size\", [1, 127, 128, 1024, 4097])  # cover tail remainder\n@pytest.mark.parametrize(\"factor\", [0.5, 1.0, 2.0, 3.0])\ndef test_scale_correctness(dtype, size, factor):\n    src = torch.randn(size, dtype=dtype, device=\"cuda\")\n    out = scale(src, factor)\n    expected = src * factor\n\n    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-2, 1e-2)\n    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float32])\ndef test_scale_out_param(dtype):\n    src = torch.randn(1024, dtype=dtype, device=\"cuda\")\n    out = torch.empty_like(src)\n    result = scale(src, 2.0, out=out)\n    assert result is out\n    torch.testing.assert_close(out, src * 2.0, rtol=1e-2, atol=1e-2)\n\n\ndef test_scale_cpu_error():\n    src = torch.randn(128, dtype=torch.float16)  # CPU tensor\n    with pytest.raises(RuntimeError, 
match=\"CUDA\"):\n        scale(src, 2.0)\n\n\ndef test_scale_unsupported_dtype():\n    src = torch.randint(0, 10, (128,), dtype=torch.int32, device=\"cuda\")\n    with pytest.raises(RuntimeError, match=\"dtype\"):\n        scale(src, 2.0)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n```\n\n---\n\n## Step 5: Add a benchmark (required)\n\nCreate `python/sglang/jit_kernel/benchmark/bench_scale.py`:\n\n```python\nimport itertools\n\nimport torch\nimport triton\nimport triton.testing\n\nfrom sglang.jit_kernel.benchmark.utils import (\n    DEFAULT_DEVICE,\n    DEFAULT_DTYPE,\n    get_benchmark_range,\n    run_benchmark,\n)\nfrom sglang.jit_kernel.scale import scale as jit_scale\n\n\nSIZE_LIST = get_benchmark_range(\n    full_range=[2**n for n in range(10, 20)],  # 1K … 512K elements\n    ci_range=[4096, 65536],\n)\n\nconfigs = list(itertools.product(SIZE_LIST))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"size\"],\n        x_vals=configs,\n        line_arg=\"provider\",\n        line_vals=[\"jit\", \"torch\"],\n        line_names=[\"SGL JIT Kernel\", \"PyTorch\"],\n        styles=[(\"blue\", \"-\"), (\"red\", \"--\")],\n        ylabel=\"us\",\n        plot_name=\"scale-performance\",\n        args={},\n    )\n)\ndef benchmark(size: int, provider: str):\n    src = torch.randn(size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE)\n    factor = 2.0\n\n    if provider == \"jit\":\n        fn = lambda: jit_scale(src, factor)\n    else:\n        fn = lambda: src * factor\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    benchmark.run(print_data=True)\n```\n\nRun:\n\n```bash\npython python/sglang/jit_kernel/benchmark/bench_scale.py\n```\n\n---\n\n## Troubleshooting\n\n- **JIT compilation fails**: ensure the `.cuh` file is under `python/sglang/jit_kernel/csrc/`; reduce template argument combinations\n- **CUDA crash / illegal memory access**: `CUDA_LAUNCH_BLOCKING=1`; `compute-sanitizer 
--tool memcheck python ...`\n- **Unstable benchmark results**: time kernels through `run_benchmark`, which uses CUDA-graph-based timing by default and so excludes most CPU launch overhead\n\n---\n\n## References\n\n- `docs/developer_guide/development_jit_kernel_guide.md`\n- `python/sglang/jit_kernel/utils.py` — `cache_once`, `load_jit`, `make_cpp_args`\n- `python/sglang/jit_kernel/include/sgl_kernel/tensor.h` — `TensorMatcher`, `SymbolicSize/DType/Device`\n- `python/sglang/jit_kernel/include/sgl_kernel/utils.cuh` — type aliases, `LaunchKernel`, `SGL_DEVICE`\n- `python/sglang/jit_kernel/include/sgl_kernel/vec.cuh` — `AlignedVector`\n- `python/sglang/jit_kernel/include/sgl_kernel/tile.cuh` — `tile::Memory`\n- `python/sglang/jit_kernel/include/sgl_kernel/type.cuh` — `dtype_trait`, `packed_t`, `device::cast`\n- `python/sglang/jit_kernel/include/sgl_kernel/math.cuh` — `device::math::`\n- `python/sglang/jit_kernel/include/sgl_kernel/warp.cuh` — `warp::reduce_sum/max`\n- `python/sglang/jit_kernel/include/sgl_kernel/cta.cuh` — `cta::reduce_max`\n- `python/sglang/jit_kernel/include/sgl_kernel/atomic.cuh` — `atomic::max`\n- `python/sglang/jit_kernel/include/sgl_kernel/runtime.cuh` — occupancy / SM count helpers\n- `python/sglang/jit_kernel/csrc/add_constant.cuh` — minimal runnable reference\n- `python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh` — real example using `TensorMatcher` + `LaunchKernel` + `tile::Memory`\n- `python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh` — real example using `runtime::get_blocks_per_sm` + persistent kernel pattern\n- `python/sglang/jit_kernel/benchmark/utils.py` — benchmark helpers\n\n## Summary of Files Created\n\n```\npython/sglang/jit_kernel/csrc/elementwise/scale.cuh   # NEW: CUDA kernel\npython/sglang/jit_kernel/scale.py                     # NEW: Python wrapper\npython/sglang/jit_kernel/tests/test_scale.py          # NEW: Tests\npython/sglang/jit_kernel/benchmark/bench_scale.py     # NEW: Benchmark\n```\n"
  },
  {
    "path": ".claude/skills/add-sgl-kernel/SKILL.md",
    "content": "---\nname: add-sgl-kernel\ndescription: Step-by-step tutorial for adding a heavyweight AOT CUDA/C++ kernel to sgl-kernel (including tests & benchmarks)\n---\n\n# Tutorial: Adding a New Kernel to `sgl-kernel` (AOT / Heavyweight)\n\nThis tutorial walks through adding a simple element-wise scale operation as an AOT kernel. We'll implement `scale(x, factor) = x * factor` to demonstrate the complete workflow.\n\n## Goal\n\nAdd a new operation that scales each element of a tensor by a scalar factor:\n\n- Input: tensor `x` (CUDA) and scalar `factor` (float)\n- Output: `x * factor` (element-wise), written into a pre-allocated `out`, or into a new tensor allocated by the Python wrapper when `out` is omitted\n- Supported dtypes: **FP16 (`torch.float16`), BF16 (`torch.bfloat16`), FP32 (`torch.float32`)**\n  - Dispatched via the `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` macro (defined in `sgl-kernel/include/utils.h`)\n\n## Rules of thumb (must follow)\n\n1. **Prefer `python/sglang/jit_kernel` first** when the kernel does **not** depend on CUTLASS or another large C++ project. This is the default path for lightweight kernels that benefit from rapid iteration.\n2. **Prefer `sgl-kernel`** when the kernel **does** depend on CUTLASS or another large C++ project, or when it should be part of the AOT wheel / torch op registration flow.\n3. 
**Exception**: if the dependency is `flashinfer`, or CUTLASS that is already provided through `flashinfer`, the kernel can still be implemented as `jit_kernel`.\n\nIn addition, every new kernel must ship with:\n\n- **Tests** (pytest)\n- **A benchmark script** (triton.testing)\n\n---\n\n## Repository integration map\n\nYou will typically touch these files/areas:\n\n- Implementation: `sgl-kernel/csrc/elementwise/scale.cu` (pick the right subdirectory)\n- Public declarations: `sgl-kernel/include/sgl_kernel_ops.h`\n- Torch extension registration: `sgl-kernel/csrc/common_extension.cc`\n- Build: `sgl-kernel/CMakeLists.txt` (`set(SOURCES ...)`)\n- Python API: `sgl-kernel/python/sgl_kernel/` and `sgl-kernel/python/sgl_kernel/__init__.py`\n- Tests: `sgl-kernel/tests/test_scale.py`\n- Benchmarks: `sgl-kernel/benchmark/bench_scale.py`\n\n---\n\n## Step 1: Implement the kernel in `csrc/`\n\nPick the right subdirectory:\n\n- `csrc/elementwise/` — for element-wise ops (our example)\n- `csrc/gemm/`, `csrc/attention/`, `csrc/moe/` — for other categories\n\nCreate `sgl-kernel/csrc/elementwise/scale.cu`:\n\n```cpp\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <torch/all.h>\n\n#include \"utils.h\"  // DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16\n\n// scale_kernel: out[i] = input[i] * factor\n// Supports float, half (__half), __nv_bfloat16 via template T\ntemplate <typename T>\n__global__ void scale_kernel(T* __restrict__ out,\n                              const T* __restrict__ input,\n                              float factor,\n                              int64_t n) {\n  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;\n  if (idx < n) {\n    out[idx] = static_cast<T>(static_cast<float>(input[idx]) * factor);\n  }\n}\n\nvoid scale(at::Tensor& out, const at::Tensor& input, double factor) {\n  TORCH_CHECK(input.is_cuda(),       \"input must be a CUDA tensor\");\n  TORCH_CHECK(input.is_contiguous(), \"input must be 
contiguous\");\n  TORCH_CHECK(out.is_cuda(),         \"out must be a CUDA tensor\");\n  TORCH_CHECK(out.is_contiguous(),   \"out must be contiguous\");\n  TORCH_CHECK(out.device() == input.device(), \"out and input must be on the same device\");\n  TORCH_CHECK(out.sizes() == input.sizes(),  \"out and input must have the same shape\");\n  TORCH_CHECK(out.scalar_type() == input.scalar_type(),\n              \"out and input must have the same dtype\");\n\n  const int64_t n = input.numel();\n  if (n == 0) {\n    return;  // launching with zero blocks is an invalid configuration\n  }\n  const int threads = 256;\n  const int blocks  = (n + threads - 1) / threads;\n\n  // Set the device guard first so the current stream belongs to input's device.\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));\n  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  // Dispatches over float, float16, bfloat16\n  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {\n    scale_kernel<c_type><<<blocks, threads, 0, stream>>>(\n        static_cast<c_type*>(out.data_ptr()),\n        static_cast<const c_type*>(input.data_ptr()),\n        static_cast<float>(factor),\n        n);\n    cudaError_t status = cudaGetLastError();\n    TORCH_CHECK(status == cudaSuccess,\n                \"scale_kernel launch failed: \", cudaGetErrorString(status));\n    return true;\n  });\n}\n```\n\n**Key points:**\n\n- Use `at::Tensor` (PyTorch tensors), `TORCH_CHECK` for validation, and `at::cuda::getCurrentCUDAStream()` for the stream (after the device guard is in place)\n- Keep Python wrappers thin; do shape/dtype/device validation in C++ right around the launch path\n- `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` covers `float`, `half` (FP16), `__nv_bfloat16` (BF16)\n- Add device error checking after every kernel launch\n- If a kernel only works on certain architectures, enforce that with `TORCH_CHECK` and skip logic in tests\n\n---\n\n## Step 2: Add a C++ declaration in `include/sgl_kernel_ops.h`\n\nEdit `sgl-kernel/include/sgl_kernel_ops.h`, add to the elementwise section:\n\n```cpp\nvoid scale(at::Tensor& out, const at::Tensor& input, double factor);\n```\n\n---\n\n## Step 3: Register the op in `csrc/common_extension.cc`\n\nEdit 
`sgl-kernel/csrc/common_extension.cc`, inside `TORCH_LIBRARY_FRAGMENT(sgl_kernel, m)`:\n\n```cpp\n// From csrc/elementwise\nm.def(\"scale(Tensor! out, Tensor input, float factor) -> ()\");\nm.impl(\"scale\", torch::kCUDA, &scale);\n```\n\n**Key points:**\n\n- `Tensor!` marks `out` as a mutable (written-to) argument; the op itself returns nothing (`-> ()`)\n- The schema is important for `torch.compile` and for consistent call signatures\n- Write scalar types in the schema with PyTorch names: schema `float` is a 64-bit floating-point scalar, so the matching C++ launcher parameter must be `double` (a C++ `float` parameter is rejected by `torch::Library` registration)\n\n---\n\n## Step 4: Add the new source file to `CMakeLists.txt`\n\nEdit `sgl-kernel/CMakeLists.txt`, add to `set(SOURCES ...)`:\n\n```cmake\ncsrc/elementwise/scale.cu\n```\n\n**Key points:**\n\n- Keep the list **alphabetically sorted** (the file explicitly requires this)\n- If the kernel has arch constraints, reflect that in tests/benchmarks via skip logic\n\n---\n\n## Step 5: Expose a Python API under `sgl-kernel/python/sgl_kernel/`\n\nPrefer following the existing module organization first. 
For elementwise kernels, the usual pattern is:\n\n- implement the Python wrapper in `sgl-kernel/python/sgl_kernel/elementwise.py`\n- then re-export it from `sgl-kernel/python/sgl_kernel/__init__.py`\n\nFor example, in `sgl-kernel/python/sgl_kernel/elementwise.py`, add:\n\n```python\nfrom typing import Optional\n\nimport torch\n\n\ndef scale(\n    input: torch.Tensor,\n    factor: float,\n    out: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"\n    Element-wise scale: out = input * factor.\n\n    Supported dtypes: torch.float16, torch.bfloat16, torch.float32.\n\n    Parameters\n    ----------\n    input  : CUDA input tensor\n    factor : scale factor (float)\n    out    : optional pre-allocated CUDA output tensor (same shape/dtype as input)\n\n    Returns\n    -------\n    out, holding input * factor.\n    \"\"\"\n    if out is None:\n        out = torch.empty_like(input)\n    torch.ops.sgl_kernel.scale.default(out, input, factor)\n    return out\n```\n\nThen re-export it from `sgl-kernel/python/sgl_kernel/__init__.py` following the existing import style used by other kernels.\n\n---\n\n## Step 6: Write tests (required)\n\nCreate `sgl-kernel/tests/test_scale.py`:\n\n```python\nimport pytest\nimport torch\n\nimport sgl_kernel\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float32])\n@pytest.mark.parametrize(\"size\", [128, 1024, 4096, 65536])\n@pytest.mark.parametrize(\"factor\", [0.5, 1.0, 2.0])\ndef test_scale_correctness(dtype, size, factor):\n    input = torch.randn(size, dtype=dtype, device=\"cuda\")\n    out   = torch.empty_like(input)\n\n    result = sgl_kernel.scale(input, factor, out=out)\n    assert result is out\n\n    expected = input * factor\n    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-2, 1e-2)\n    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)\n\n\ndef test_scale_shape_mismatch():\n    input = torch.randn(128, dtype=torch.float16, device=\"cuda\")\n    out   = torch.empty(256, dtype=torch.float16, device=\"cuda\")\n    with pytest.raises(RuntimeError, match=\"same 
shape\"):\n        sgl_kernel.scale(input, 2.0, out=out)\n\n\ndef test_scale_cpu_input():\n    input = torch.randn(128, dtype=torch.float16)  # CPU\n    out   = torch.empty_like(input)\n    with pytest.raises(RuntimeError, match=\"CUDA\"):\n        sgl_kernel.scale(input, 2.0, out=out)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-q\"])\n```\n\n---\n\n## Step 7: Add a benchmark (required)\n\nCreate `sgl-kernel/benchmark/bench_scale.py`:\n\n```python\nimport itertools\nimport os\n\nimport torch\nimport triton\nimport triton.testing\n\nimport sgl_kernel\n\nIS_CI = (\n    os.getenv(\"CI\", \"false\").lower() == \"true\"\n    or os.getenv(\"GITHUB_ACTIONS\", \"false\").lower() == \"true\"\n)\n\ndtypes  = [torch.float16] if IS_CI else [torch.float16, torch.bfloat16, torch.float32]\nsizes   = [4096] if IS_CI else [2**n for n in range(10, 20)]  # 1K … 512K\n\nconfigs = list(itertools.product(dtypes, sizes))\n\n\ndef torch_scale(input: torch.Tensor, factor: float) -> torch.Tensor:\n    return input * factor\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"dtype\", \"size\"],\n        x_vals=configs,\n        line_arg=\"provider\",\n        line_vals=[\"sglang\", \"torch\"],\n        line_names=[\"SGL Kernel\", \"PyTorch\"],\n        styles=[(\"green\", \"-\"), (\"red\", \"--\")],\n        ylabel=\"µs (median)\",\n        plot_name=\"scale-performance\",\n        args={},\n    )\n)\ndef benchmark(dtype, size, provider):\n    input  = torch.randn(size, dtype=dtype, device=\"cuda\")\n    out    = torch.empty_like(input)\n    factor = 2.0\n\n    if provider == \"sglang\":\n        fn = lambda: sgl_kernel.scale(input, factor, out=out)\n    else:\n        fn = lambda: torch_scale(input, factor)\n\n    # quantiles=[0.5, 0.2, 0.8] -> median, 20th and 80th percentile times (ms)\n    ms, lo_ms, hi_ms = triton.testing.do_bench_cudagraph(\n        fn, quantiles=[0.5, 0.2, 0.8]\n    )\n    # perf_report expects (value, lower bound, upper bound); convert ms -> µs\n    return 1000 * ms, 1000 * lo_ms, 1000 * hi_ms\n\n\nif __name__ == \"__main__\":\n    
benchmark.run(print_data=True)\n```\n\n---\n\n## Step 8: Build\n\nBuild:\n\n```bash\ncd sgl-kernel\nmake build -j16\n```\n\nIf you need to limit host resource usage:\n\n```bash\ncd sgl-kernel\nmake build -j1 MAX_JOBS=2 CMAKE_ARGS=\"-DSGL_KERNEL_COMPILE_THREADS=1\"\n```\n\n---\n\n## Step 9: Validate\n\nAfter building successfully, run the test and benchmark:\n\n```bash\npytest sgl-kernel/tests/test_scale.py -q\npython sgl-kernel/benchmark/bench_scale.py\n```\n\n---\n\n## Troubleshooting\n\n- **Async CUDA errors**: `CUDA_LAUNCH_BLOCKING=1`\n- **Memory errors**: `compute-sanitizer --tool memcheck python ...`\n- **Build is too slow / OOM**: reduce `MAX_JOBS` and `SGL_KERNEL_COMPILE_THREADS`\n- **Binary bloat**: use `sgl-kernel/analyze_whl_kernel_sizes.py`\n- **CMake sources list**: if your `.cu` file is missing from `SOURCES`, the symbol will be undefined at link time\n\n---\n\n## References\n\n- `sgl-kernel/README.md`\n- `sgl-kernel/include/sgl_kernel_ops.h`\n- `sgl-kernel/csrc/common_extension.cc`\n- `sgl-kernel/CMakeLists.txt`\n- `sgl-kernel/include/utils.h` — `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` macro and friends\n- `sgl-kernel/csrc/elementwise/activation.cu` — reference for the FP16/BF16/FP32 dispatch pattern\n\n## Summary of Files Created/Modified\n\n```\nsgl-kernel/csrc/elementwise/scale.cu          # NEW: CUDA kernel + launcher\nsgl-kernel/include/sgl_kernel_ops.h           # MODIFIED: C++ declaration\nsgl-kernel/csrc/common_extension.cc           # MODIFIED: schema + dispatch registration\nsgl-kernel/CMakeLists.txt                     # MODIFIED: add source file (alphabetical)\nsgl-kernel/python/sgl_kernel/elementwise.py   # MODIFIED: Python wrapper\nsgl-kernel/python/sgl_kernel/__init__.py      # MODIFIED: re-export Python API\nsgl-kernel/tests/test_scale.py                # NEW: tests\nsgl-kernel/benchmark/bench_scale.py           # NEW: benchmark\n```\n"
  },
  {
    "path": ".claude/skills/sglang-bisect-ci-regression/SKILL.md",
    "content": "# SGLang Bisect CI Regression\n\nInvestigate a consistently failing CI test to find the root cause - whether it's a code regression from a specific PR, a hardware/runner-specific issue, or an environment change. Optionally reproduce the failure on a remote GPU server.\n\n## Slash Command\n\n`/sglang-bisect-ci-regression <test_name_or_ci_url> [ssh_target] [docker_container]`\n\n## When to Use This Skill\n\n- A CI test is failing consistently on main (scheduled runs)\n- You need to find which PR introduced a regression\n- You suspect a runner-specific or GPU-specific issue\n- You want to reproduce a CI failure on a remote server\n\n## Arguments\n\n- **First argument (required)**: Test file name (e.g. `test_lora_tp.py`) or a GitHub Actions job URL\n- **Second argument (optional)**: SSH target for remote reproduction (e.g. `user@host`)\n- **Third argument (optional)**: Docker container name on the SSH target (e.g. `sglang_dev`)\n\nIf SSH target and docker container are not provided, the skill will only perform the CI log analysis and bisection, without remote reproduction. **Ask the user** for these if reproduction is needed and they weren't provided.\n\n## Background: Scheduled CI Runs\n\nSGLang uses the `pr-test.yml` workflow with **scheduled runs** (cron-triggered) to periodically test the `main` branch. These runs are the primary data source for detecting regressions:\n\n- **Workflow**: `pr-test.yml` with `event: schedule`\n- **Branch**: `main`\n- **Dashboard**: https://github.com/sgl-project/sglang/actions/workflows/pr-test.yml?query=event%3Aschedule\n- **Frequency**: Runs multiple times daily, each pinned to the HEAD of `main` at trigger time\n- **Purpose**: Catches regressions that slip through PR-level CI (e.g., interaction bugs between merged PRs, hardware-specific issues)\n\nAlways use these scheduled runs (not PR-triggered runs) when bisecting regressions on `main`. 
The `--event schedule` filter in `gh run list` ensures you only see these periodic main-branch runs.\n\n## Workflow\n\n### Phase 1: Extract the Failure Signature\n\n1. **Get the failing test details from CI logs.** If given a URL, fetch logs directly. If given a test name, find recent scheduled runs of `pr-test.yml` on `main` that failed:\n\n```bash\n# List recent scheduled runs targeting main (the primary source of truth for regressions)\n# These are cron-triggered runs visible at:\n# https://github.com/sgl-project/sglang/actions/workflows/pr-test.yml?query=event%3Aschedule\ngh run list --repo sgl-project/sglang --workflow=\"pr-test.yml\" --event schedule --branch main --limit 20 --json databaseId,conclusion,createdAt,headSha\n\n# Find the job containing the test\ngh run view {RUN_ID} --repo sgl-project/sglang --json jobs --jq '.jobs[] | select(.conclusion == \"failure\") | {name, conclusion, databaseId}'\n\n# Get the failure details\ngh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E -B 5 -A 30 \"AssertionError|FAIL|Error|{TEST_NAME}\"\n```\n\n2. **Record the failure signature:**\n   - Exact error message and assertion\n   - Affected test method name\n   - Model/config involved\n   - Numeric values (e.g., tolerance diffs, scores)\n   - Whether the failure is deterministic (same values across runs)\n\n### Phase 2: Temporal Bisection\n\n3. 
**Find the boundary between passing and failing runs.** Walk through the scheduled run history (from the `pr-test.yml` schedule runs on `main`) to identify:\n   - Last known PASSING run (sha + date)\n   - First known FAILING run (sha + date)\n\n```bash\n# For each scheduled run, check the specific partition/job status\ngh run view {RUN_ID} --repo sgl-project/sglang --json jobs --jq '.jobs[] | select(.name == \"{JOB_NAME}\") | {conclusion, databaseId}'\n\n# Verify a specific test passed or failed in a run\ngh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E \"{TEST_NAME}|PASSED|FAILED|logprobs mismatch\" | head -10\n```\n\n4. **List commits between the boundary:**\n\n```bash\ngit log --oneline {LAST_PASS_SHA}..{FIRST_FAIL_SHA}\n```\n\n5. **Filter for relevant commits** that touch files related to the failing test (model layers, kernels, test utilities, etc.):\n\n```bash\ngit log --oneline {LAST_PASS_SHA}..{FIRST_FAIL_SHA} -- {relevant_paths}\n```\n\n### Phase 3: Runner/Hardware Analysis\n\n6. **Check if the failure is runner-specific.** Extract the runner identity from each failing and passing run:\n\n```bash\n# Get runner name and machine\ngh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E \"Runner name|Machine name\" | head -5\n\n# Get GPU/driver info\ngh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -i -E \"NVIDIA-SMI|Driver Version|CUDA Version\" | head -5\n\n# Get package versions\ngh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E \"sgl.kernel.*==|flashinfer.*==\" | head -5\n```\n\n7. **Correlate runners with pass/fail outcomes.** Build a table:\n\n| Run ID | Date | Runner | GPU Type | Driver | Result |\n|--------|------|--------|----------|--------|--------|\n\nIf all failures map to a specific runner type/GPU and all passes map to another, the issue is **hardware-specific**, not a code regression.\n\n### Phase 4: Code Analysis\n\n8. 
**If a code regression is suspected** (failures not runner-specific), examine the candidate commits:\n   - Read the changed files\n   - Understand how the changes could affect the failing test\n   - Look for prefill-vs-decode differences, TP-specific paths, kernel changes\n\n9. **If a hardware issue is suspected**, analyze:\n   - Kernel compatibility (CUDA compute capability)\n   - Driver version differences\n   - All-reduce / NCCL behavior differences\n   - CUDA graph capture differences across GPU architectures\n\n### Phase 5: Remote Reproduction (Optional)\n\nOnly if SSH target and docker container were provided.\n\n10. **Verify the remote environment:**\n\n```bash\nssh {SSH_TARGET} \"docker exec {CONTAINER} nvidia-smi --query-gpu=name,driver_version --format=csv\"\nssh {SSH_TARGET} \"docker exec {CONTAINER} pip show sgl-kernel sglang flashinfer-python 2>&1 | grep -E 'Name:|Version:'\"\n```\n\n11. **Ensure latest code is installed.** If the container is stale, update:\n\n```bash\n# Try fetching latest main\nssh {SSH_TARGET} \"docker exec {CONTAINER} bash -c 'cd /path/to/sglang && git fetch origin main && git checkout origin/main'\"\n# Or download and install from tarball if git auth fails\nssh {SSH_TARGET} \"docker exec {CONTAINER} bash -c 'cd /tmp && curl -L https://github.com/sgl-project/sglang/archive/refs/heads/main.tar.gz | tar xz && cd sglang-main && pip install -e \\\"python[all]\\\"'\"\n# Reinstall (after git fetch)\nssh {SSH_TARGET} \"docker exec {CONTAINER} bash -c 'cd /path/to/sglang && pip install -e \\\"python[all]\\\"'\"\n# Install test dependencies if needed\nssh {SSH_TARGET} \"docker exec {CONTAINER} pip install peft rouge-score\"\n```\n\n12. **Create a minimal reproduction script** that:\n    - Uses `if __name__ == '__main__'` with `mp.set_start_method(\"spawn\")`\n    - Runs the specific failing test configuration\n    - Prints key metrics (diffs, scores, outputs)\n    - Exits with code 1 on failure\n\n13. 
**Copy and run the reproduction script:**\n\n```bash\nscp /tmp/repro_script.py {SSH_TARGET}:/tmp/\nssh {SSH_TARGET} \"docker cp /tmp/repro_script.py {CONTAINER}:/tmp/\"\nssh {SSH_TARGET} \"docker exec -e CUDA_VISIBLE_DEVICES=0,1 {CONTAINER} python3 /tmp/repro_script.py\"\n```\n\n14. **Run control experiments** to isolate the variable:\n    - If suspecting TP issue: run with TP=1 as control\n    - If suspecting GPU issue: compare same code on different GPU\n    - If suspecting a specific commit: test before/after that commit\n\n### Phase 6: Report\n\n15. **Produce a structured report:**\n\n```markdown\n## CI Regression Bisection Report\n\n### Failure Signature\n- **Test**: {test_file}::{test_method}\n- **Error**: {exact error message}\n- **Key metrics**: {numeric values}\n- **Deterministic**: Yes/No\n\n### Root Cause Classification\nOne of:\n- **Code Regression**: PR #{number} introduced the bug\n- **Hardware-Specific**: Fails on {GPU_TYPE}, passes on others\n- **Environment Change**: New runner/driver/package version\n- **Pre-existing Flakiness**: Intermittent, not a new regression\n\n### Evidence\n| Condition | Result |\n|-----------|--------|\n| {condition1} | PASS/FAIL |\n| {condition2} | PASS/FAIL |\n\n### Timeline\n- {date}: Last known pass ({sha}, {runner})\n- {date}: First known fail ({sha}, {runner})\n- {date}: Confirmed reproduction on {server}\n\n### Recommended Fix\n- **Short-term**: {workaround}\n- **Long-term**: {proper fix}\n```\n\n## Key Patterns to Recognize\n\n| Pattern | Diagnosis |\n|---------|-----------|\n| Same SHA passes on runner A, fails on runner B | Hardware/runner-specific |\n| All runners fail after commit X | Code regression from commit X |\n| Intermittent - same runner sometimes passes/fails | Flaky test or race condition |\n| Prefill OK but decode fails | TP/all-reduce issue in decode path |\n| Works with TP=1, fails with TP>1 | Tensor parallelism bug |\n| Exact same numeric diff every time | Deterministic bug, not flakiness |\n\n## 
Important Notes\n\n- **Always check runner identity** before concluding it's a code regression. Many \"consistent\" failures are actually runner-specific.\n- **Test partition assignments change over time** as tests are added/removed. A test may move between partitions, landing on different runner types.\n- **H200 runners** use `/root/actions-runner/` path and machine names like `gpu-h200-worker-*`. Non-H200 runners use `/public_sglang_ci/runner-*` paths.\n- When running remote reproduction, use `run_in_background` for long-running tests and check output with `TaskOutput`.\n- Container environments may be stale - always verify package versions match CI before drawing conclusions.\n"
  },
  {
    "path": ".claude/skills/write-sglang-test/SKILL.md",
    "content": "---\nname: write-sglang-test\ndescription: Guide for writing SGLang CI/UT tests following project conventions. Covers CustomTestCase, CI registration, server fixtures, model selection, and test placement. Use when creating new tests, adding CI test cases, writing unit tests, or when the user asks to add tests for SGLang features.\n---\n\n# Writing SGLang CI / UT Tests\n\n## Core Rules\n\n1. **Always use `CustomTestCase`** — never raw `unittest.TestCase`\n2. **Place tests in `test/registered/<category>/`** — only use `test/manual/` for debugging / non-CI tests\n3. **Reuse server fixtures** — inherit from `DefaultServerBase` or write `setUpClass`/`tearDownClass` with `popen_launch_server`\n4. **Smallest model for model-agnostic functionality** — use `DEFAULT_SMALL_MODEL_NAME_FOR_TEST` (Llama-3.2-1B-Instruct) for basic features that don't depend on model size\n5. **8B for general performance** — use `DEFAULT_MODEL_NAME_FOR_TEST` (Llama-3.1-8B-Instruct, single-node) for performance tests that don't involve spec / DP / parallelism\n6. **Bigger features → discuss case by case** — spec, DP attention, tensor/pipeline parallelism etc. 
may need multi-GPU suites and specific models\n\n---\n\n## Test File Template\n\n### Functional correctness test (small model)\n\n```python\nimport unittest\n\nimport requests\n\nfrom sglang.srt.utils import kill_process_tree\nfrom sglang.test.ci.ci_register import register_cuda_ci\nfrom sglang.test.test_utils import (\n    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,\n    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,\n    DEFAULT_URL_FOR_TEST,\n    CustomTestCase,\n    popen_launch_server,\n)\n\nregister_cuda_ci(est_time=60, suite=\"stage-b-test-small-1-gpu\")\n\n\nclass TestMyFeature(CustomTestCase):\n    @classmethod\n    def setUpClass(cls):\n        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST\n        cls.base_url = DEFAULT_URL_FOR_TEST\n        cls.process = popen_launch_server(\n            cls.model,\n            cls.base_url,\n            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,\n            other_args=[\"--arg1\", \"value1\"],  # feature-specific args\n        )\n\n    @classmethod\n    def tearDownClass(cls):\n        kill_process_tree(cls.process.pid)\n\n    def test_basic_functionality(self):\n        response = requests.post(\n            self.base_url + \"/generate\",\n            json={\"text\": \"Hello\", \"sampling_params\": {\"max_new_tokens\": 32}},\n        )\n        self.assertEqual(response.status_code, 200)\n\n\nif __name__ == \"__main__\":\n    unittest.main(verbosity=3)\n```\n\n### General performance test (8B model, single node, no spec/DP/parallelism)\n\n```python\nimport time\nimport unittest\n\nimport requests\n\nfrom sglang.srt.utils import kill_process_tree\nfrom sglang.test.ci.ci_register import register_cuda_ci\nfrom sglang.test.test_utils import (\n    DEFAULT_MODEL_NAME_FOR_TEST,\n    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,\n    DEFAULT_URL_FOR_TEST,\n    CustomTestCase,\n    popen_launch_server,\n)\n\nregister_cuda_ci(est_time=300, suite=\"stage-b-test-large-1-gpu\")\n\n\nclass TestMyFeaturePerf(CustomTestCase):\n    @classmethod\n    def 
setUpClass(cls):\n        cls.model = DEFAULT_MODEL_NAME_FOR_TEST\n        cls.base_url = DEFAULT_URL_FOR_TEST\n        cls.process = popen_launch_server(\n            cls.model,\n            cls.base_url,\n            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,\n        )\n\n    @classmethod\n    def tearDownClass(cls):\n        kill_process_tree(cls.process.pid)\n\n    def test_latency(self):\n        start = time.perf_counter()\n        response = requests.post(\n            self.base_url + \"/generate\",\n            json={\"text\": \"Hello\", \"sampling_params\": {\"max_new_tokens\": 128}},\n        )\n        elapsed = time.perf_counter() - start\n        self.assertEqual(response.status_code, 200)\n        self.assertLess(elapsed, 5.0, \"Latency exceeded threshold\")\n\n\nif __name__ == \"__main__\":\n    unittest.main(verbosity=3)\n```\n\n---\n\n## Server Fixture Reuse\n\nFor tests that only need a standard server, inherit from `DefaultServerBase` and override class attributes:\n\n```python\nfrom sglang.test.server_fixtures.default_fixture import DefaultServerBase\nfrom sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST\n\n\nclass TestMyFeature(DefaultServerBase):\n    model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST\n    other_args = [\"--enable-my-feature\"]\n\n    def test_something(self):\n        ...\n```\n\nAvailable fixtures in `python/sglang/test/server_fixtures/`:\n\n| Fixture | Use case |\n|---------|----------|\n| `DefaultServerBase` | Standard single-server tests |\n| `EagleServerBase` | EAGLE speculative decoding |\n| `PDDisaggregationServerBase` | Disaggregated prefill/decode |\n| `MMMUServerBase` | Multimodal VLM tests |\n\n---\n\n## CI Registration\n\nEvery test file in `test/registered/` **must** call a registration function at module level:\n\n```python\nfrom sglang.test.ci.ci_register import register_cuda_ci, register_amd_ci\n\nregister_cuda_ci(est_time=60, suite=\"stage-b-test-small-1-gpu\")\nregister_amd_ci(est_time=60, suite=\"stage-b-test-small-1-gpu-amd\")  # optional\n```\n\nParameters:\n- `est_time`: estimated runtime in seconds (used for CI partitioning)\n- `suite`: which CI suite to run in (see below)\n- `nightly=True`: for nightly-only tests (default `False` = per-commit)\n- `disabled=\"reason\"`: temporarily disable with explanation\n\n### Suite selection guide\n\n**Default cases (1 GPU):**\n\n| Scenario | Model | Suite |\n|----------|-------|-------|\n| Model-agnostic basic functionality | 1B (smallest) | `stage-b-test-small-1-gpu` |\n| General performance (no spec/DP/parallelism) | 8B | `stage-b-test-large-1-gpu` |\n\n**Bigger features (case by case):**\n\n| Scenario | Suite |\n|----------|-------|\n| 2 GPU (e.g. TP=2) | `stage-b-test-large-2-gpu` |\n| 4 GPU (H100) | `stage-c-test-4-gpu-h100` |\n| 8 GPU (H200) | `stage-c-test-8-gpu-h200` |\n| Nightly, 1 GPU | `nightly-1-gpu` |\n| Nightly, 8 GPU | `nightly-8-gpu` |\n\nFor spec, DP attention, parallelism, disaggregation, etc., discuss with the team to determine the appropriate suite and GPU configuration.\n\n---\n\n## Model Constants\n\nAll defined in `python/sglang/test/test_utils.py`:\n\n| Constant | Model | When to use |\n|----------|-------|-------------|\n| `DEFAULT_SMALL_MODEL_NAME_FOR_TEST` | Llama-3.2-1B-Instruct | Model-agnostic basic functionality |\n| `DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE` | Llama-3.2-1B | Base (non-instruct) model tests |\n| `DEFAULT_MODEL_NAME_FOR_TEST` | Llama-3.1-8B-Instruct | General performance (single node) |\n| `DEFAULT_MOE_MODEL_NAME_FOR_TEST` | Mixtral-8x7B-Instruct | MoE-specific tests |\n| `DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST` | — | Embedding tests |\n| `DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST` | — | Vision-language tests |\n\n---\n\n## Test Placement\n\n```\ntest/\n├── registered/          # CI tests (auto-discovered by run_suite.py)\n│   ├── sampling/        # test_penalty.py, test_sampling_params.py ...\n│   ├── sessions/        # test_session_control.py ...\n│   ├── openai_server/   # basic/, features/, validation/ ...\n│   ├── spec/            # eagle/, utils/ ...\n│   ├── models/          # model-specific accuracy tests\n│   ├── perf/            # performance benchmarks\n│   └── <category>/      # create new category if needed\n├── manual/              # Non-CI: debugging, one-off, manual verification\n└── run_suite.py         # CI runner (scans registered/ only)\n```\n\n**Decision rule**: if the test should run in CI → `registered/`. If it's for local debugging or requires special hardware not in CI → `manual/`.\n\n---\n\n## Key Utilities\n\n```python\nfrom sglang.test.test_utils import (\n    CustomTestCase,              # base class with retry logic\n    popen_launch_server,         # launch server subprocess\n    DEFAULT_URL_FOR_TEST,        # auto-configured base URL\n    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,  # 600s default\n    run_bench_serving,           # benchmark helper (launch + bench)\n)\nfrom sglang.srt.utils import kill_process_tree  # cleanup server\n```\n\n---\n\n## Checklist\n\nBefore submitting a test:\n\n- [ ] Inherits from `CustomTestCase` (not `unittest.TestCase`)\n- [ ] Has `register_*_ci(...)` call at module level\n- [ ] Placed in `test/registered/<category>/`\n- [ ] Model selection: smallest for model-agnostic features, 8B for general perf, case-by-case for other complex features\n- [ ] `setUpClass` launches server, `tearDownClass` kills it (or a server fixture such as `DefaultServerBase` handles both)\n- [ ] Has `if __name__ == \"__main__\": unittest.main(verbosity=3)`\n- [ ] `est_time` is reasonable (measure locally)\n"
  },
  {
    "path": ".codespellrc",
    "content": "[codespell]\nignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS\nskip = *.json,*.jsonl,*.patch,*.txt\n"
  },
  {
    "path": ".coveragerc",
    "content": "[run]\nsource = python/sglang/srt\nomit =\n    */test/*\n    */__pycache__/*\n\n[report]\nshow_missing = true\nexclude_lines =\n    pragma: no cover\n    if __name__ == .__main__.:\n    raise NotImplementedError\n    if TYPE_CHECKING\n\n[html]\ndirectory = htmlcov\n"
  },
  {
    "path": ".devcontainer/Dockerfile",
    "content": "FROM lmsysorg/sglang:dev\n\n# Create non-root user with specified UID and GID\n# NOTE: Replace with your own UID and GID. This is a workaround from https://github.com/microsoft/vscode-remote-release/issues/49#issuecomment-489060908.\nARG HOST_UID=1003\nARG HOST_GID=1003\nRUN groupadd -g $HOST_GID devuser && \\\n    useradd -m -u $HOST_UID -g $HOST_GID -s /bin/zsh devuser\n\n# Give devuser sudo access\nRUN apt-get update && apt-get install -y sudo && \\\n    echo \"devuser ALL=(ALL) NOPASSWD:ALL\" > /etc/sudoers.d/devuser && \\\n    rm -rf /var/lib/apt/lists/* && \\\n    apt-get clean\n\n# Set up oh-my-zsh for devuser\nRUN cp -r /root/.oh-my-zsh /home/devuser/.oh-my-zsh && \\\n    cp /root/.zshrc /home/devuser/.zshrc && \\\n    cp /root/.vimrc /home/devuser/.vimrc && \\\n    cp /root/.tmux.conf /home/devuser/.tmux.conf && \\\n    sed -i 's|/root/.oh-my-zsh|/home/devuser/.oh-my-zsh|g' /home/devuser/.zshrc && \\\n    chown -R devuser:devuser /home/devuser/\n\n# Set workspace directory and ownership\nWORKDIR /sgl-workspace/sglang\nRUN chown -R devuser:devuser /sgl-workspace\n\n# Switch to devuser\nUSER devuser\n\n# Install uv\nRUN curl -LsSf https://astral.sh/uv/install.sh | sh\n\n# Install rust\nRUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y\n"
  },
  {
    "path": ".devcontainer/devcontainer.json",
    "content": "{\n    \"name\": \"sglang\",\n    \"build\": {\n        \"dockerfile\": \"Dockerfile\"\n    },\n    \"remoteUser\": \"devuser\",\n    \"customizations\": {\n        \"vscode\": {\n            \"extensions\": [\n                // Python development\n                \"ms-python.python\",\n                \"charliermarsh.ruff\",\n                // Rust development\n                \"rust-lang.rust-analyzer\",\n                \"tamasfe.even-better-toml\"\n            ]\n        }\n    },\n    \"forwardPorts\": [],\n    \"runArgs\": [\n        \"--gpus\",\n        \"all\"\n    ],\n    // The two lines below ensure that your local changes in the sglang\n    // repo are automatically synced to the sglang pip package installed\n    // in the dev docker container. You can remove / comment out these\n    // two lines if you prefer to sync code changes manually.\n    \"workspaceMount\": \"source=${localWorkspaceFolder},target=/sgl-workspace/sglang,type=bind\",\n    \"workspaceFolder\": \"/sgl-workspace/sglang\"\n}\n"
  },
  {
    "path": ".github/CI_PERMISSIONS.json",
    "content": "{\n    \"1pikachu\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"Alcanderian\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"AniZpZ\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"BBuf\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"BHZ-BER\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"ByronHsu\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"CaoE\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"CatherineSue\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"Chen-0210\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 
60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"ClawSeven\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"ConnorLi96\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"DarkSharpness\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"Edwardf0t1\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"FlamingoPg\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"FrankLeeeee\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"Fridge003\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"HaiShaw\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"HanHan009527\": {\n        
\"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"HandH1998\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"Hanrui-Wang\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"HydraQYH\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"JeremieMelo\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"Johnsonms\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"JustinTong0323\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"Kangyan-Zhou\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"LorrinWWW\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": 
\"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"MingxuZh\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"Oasis-Git\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"Prozac614\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"Qiaolin-Yu\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"Qihang-Zhang\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"Ratish1\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"RubiaCx\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"ShangmingCai\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"Shunkangz\": {\n        \"can_tag_run_ci_label\": true,\n        
\"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"SimonCqk\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"TianQiLin666666\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"Ubospica\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"Valentine233\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"Xia-Weiwen\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"XiaotongJiang\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"XucSh\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"YAMY1234\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        
\"can_rerun_stage\": true\n    },\n    \"Ying1123\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"ZailiWang\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"ZhengWG\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"ZhengdQin\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"acelyc111\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"adarshxs\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"airMeng\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"alisonshao\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"alphabetc1\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        
\"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"amysaq2023\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"attack204\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"ayrnb\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"azhurkevich\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"b8zhong\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"blzheng\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"byjiang1996\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"cctry\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"ch-wan\": {\n        
\"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"chunyuan-w\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"cicirori\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"cyb70289\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"dongjiyingdjy\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"dougyster\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"elfiegg\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"fy1214\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"fzyzcjy\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top 
contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"gaopengff\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"gongwei-130\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"gongy\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"guapisolo\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"guoyuhong\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"hanming-lu\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"harrisonlimh\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"harvenstar\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"hebiao064\": {\n        \"can_tag_run_ci_label\": true,\n        
\"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"hlu1\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"hnyls2002\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"huaiyuzh\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"huangtingwei9988\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"hubertlu-tw\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"hyhieu\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"hzh0425\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"iforgetmyname\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": 
true\n    },\n    \"ishandhanani\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"ispobock\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"jason-fxz\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"jasperjiaguo\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"jhinpan\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"jianan-gu\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"jinleic\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"jinmingyi1998\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"kaixih\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        
\"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"kevin85421\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"key4ng\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"kkHuang-amd\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"kpham-sgl\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"kssteven418\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"kushanam\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"lanking520\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"lifuhuang\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"liusy58\": {\n        
\"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"liz-badada\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"merrymercy\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"mickqian\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"mingfeima\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"minleminzui\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"mmangkad\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"narutolhy\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"netanel-haber\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": 
\"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"nvcastet\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"ocss884\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"pansicheng\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"pavanimajety\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"pdasgup\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"ping1jing2\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"pranavm-nvidia\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"pyc96\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"qingquansong\": {\n        \"can_tag_run_ci_label\": true,\n        
\"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"qywu\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"rainj-me\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"ravi03071991\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"rkooo567\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"saienduri\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"samuellees\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"scottjlee\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"sglang-bot\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": 
true\n    },\n    \"sglang-npu-bot\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"shaharmor98\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"shanyu-sys\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"shuaills\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"sleepcoo\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"slin1237\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"stmatengss\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"strgrb\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"sufeng-buaa\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        
\"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"sundar24295s\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"sunjiweiswift\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"sunxxuns\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"thecodingwizard\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"timmy-feng\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"trevor-m\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"vincentzed\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"wenscarl\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    
\"whybeyoung\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"wisclmy0611\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"xiezhq-hermann\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"xutizhou\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"xyjixyjixyji\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"yanbing-j\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"yangsijia-serena\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"yeahdongcn\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"yhyang201\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        
\"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"yilian49\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"yinghai\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"yingluosanqian\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"yizhang2077\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"ykcombat\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"ynwang007\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"yuan-luo\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"yundai424\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"yushengsu-thu\": 
{\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"yyihuang\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"yzh119\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"zhaochenyang20\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"zhijian-liu\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"zhuzilin\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"zhyncs\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"zminglei\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\",\n        \"can_rerun_stage\": true\n    },\n    \"zyksir\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": 
\"custom override\",\n        \"can_rerun_stage\": true\n    },\n    \"zyzshishui\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"custom override\",\n        \"can_rerun_stage\": true\n    }\n}\n"
  },
  {
    "path": ".github/CODEOWNERS",
    "content": ".github @merrymercy @Fridge003 @ispobock @Kangyan-Zhou @bingxche\n/docker @Fridge003 @ispobock @HaiShaw @ishandhanani @yctseng0211\n/docker/npu.Dockerfile @ping1jing2 @iforgetmyname\n/python/pyproject.toml @merrymercy @Fridge003 @ispobock\n/python/sglang/jit_kernel @DarkSharpness @BBuf @celve @HydraQYH @yuan-luo\n/python/sglang/jit_kernel/diffusion @yingluosanqian @BBuf @mickqian\n/python/sglang/multimodal_gen @mickqian @yhyang201 @ping1jing2\n/python/sglang/multimodal_gen/runtime/cache @DefTruth\n/python/sglang/multimodal_gen/runtime/layers @mickqian @yhyang201 @BBuf @yingluosanqian @ping1jing2\n/python/sglang/multimodal_gen/runtime/models/dits @mickqian @yhyang201 @BBuf @yingluosanqian @ping1jing2\n/python/sglang/srt/batch_invariant_ops @Fridge003 @hebiao064\n/python/sglang/srt/constrained @hnyls2002 @DarkSharpness\n/python/sglang/srt/compilation @hebiao064\n/python/sglang/srt/disaggregation @ByronHsu @hnyls2002 @ShangmingCai\n/python/sglang/srt/disaggregation/ascend @ping1jing2 @iforgetmyname\n/python/sglang/srt/distributed @yizhang2077 @merrymercy @ch-wan\n/python/sglang/srt/distributed/device_communicators/mooncake_transfer_engine.py @ShangmingCai @stmatengss\n/python/sglang/srt/dllm @ClawSeven @btw616\n/python/sglang/srt/entrypoints @ispobock @CatherineSue @slin1237 @merrymercy @JustinTong0323\n/python/sglang/srt/entrypoints/grpc_server.py @CatherineSue @slin1237\n/python/sglang/srt/eplb @fzyzcjy @ch-wan\n/python/sglang/srt/function_call @CatherineSue @JustinTong0323\n/python/sglang/srt/grpc @CatherineSue @slin1237\n/python/sglang/srt/hardware_backend/npu @ping1jing2 @iforgetmyname\n/python/sglang/srt/hardware_backend/npu/quantization @OrangeRedeng @TamirBaydasov @iforgetmyname\n/python/sglang/srt/layers @merrymercy @Ying1123 @Fridge003 @ispobock @HaiShaw @ch-wan @BBuf @Edwardf0t1\n/python/sglang/srt/layers/attention @merrymercy @Fridge003 @ispobock @Qiaolin-Yu @hebiao064 @HaiShaw\n/python/sglang/srt/layers/attention/fla @yizhang2077 
@hebiao064\n/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @yizhang2077 @hebiao064 @hanming-lu\n/python/sglang/srt/layers/attention/mamba @yizhang2077 @hebiao064\n/python/sglang/srt/layers/attention/nsa @1am9trash @hubertlu-tw @kkHuang-amd @HaiShaw @Fridge003 @hlu1 @rainj-me\n/python/sglang/srt/layers/attention/vision.py @mickqian @yuan-luo @yhyang201\n/python/sglang/srt/layers/quantization @ch-wan @BBuf @Edwardf0t1 @FlamingoPg @AniZpZ @HaiShaw @b8zhong\n/python/sglang/srt/layers/quantization/quark @kkHuang-amd @yichiche @hubertlu-tw @1am9trash @BowenBao\n/python/sglang/srt/lora @Ying1123 @Fridge003 @lifuhuang @yushengsu-thu\n/python/sglang/srt/managers @merrymercy @Ying1123 @hnyls2002 @xiezhq-hermann\n/python/sglang/srt/managers/scheduler_pp_mixin.py @ShangmingCai @XucSh\n/python/sglang/srt/mem_cache @merrymercy @Ying1123 @hnyls2002 @xiezhq-hermann @hanming-lu @yizhang2077 @hzh0425 @ispobock\n/python/sglang/srt/model_executor @merrymercy @Ying1123 @hnyls2002 @Fridge003 @ispobock\n/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @hebiao064\n/python/sglang/srt/models/deepseek_common @Fridge003 @ispobock @fzyzcjy @ch-wan\n/python/sglang/srt/models/deepseek_v2.py @fzyzcjy @zhyncs @ispobock @ch-wan @merrymercy @Fridge003\n/python/sglang/srt/multimodal @mickqian @JustinTong0323 @yhyang201 @yuan-luo\n/python/sglang/srt/observability @merrymercy @fzyzcjy @sufeng-buaa\n/python/sglang/srt/ray @Qiaolin-Yu @xyuzh\n/python/sglang/srt/speculative @Ying1123 @merrymercy @hnyls2002\n/sgl-kernel @ispobock @BBuf @yizhang2077 @merrymercy @FlamingoPg @HaiShaw\n/sgl-model-gateway @slin1237 @CatherineSue\n/sgl-model-gateway/benches @slin1237\n/sgl-model-gateway/bindings/python @CatherineSue @key4ng @slin1237\n/sgl-model-gateway/e2e_test @CatherineSue @key4ng\n/sgl-model-gateway/src/config @slin1237\n/sgl-model-gateway/src/core @slin1237\n/sgl-model-gateway/src/data_connector @key4ng\n/sgl-model-gateway/src/grpc_client @CatherineSue 
@slin1237\n/sgl-model-gateway/src/mcp @key4ng @slin1237\n/sgl-model-gateway/src/policies @slin1237 @ByronHsu\n/sgl-model-gateway/src/proto @CatherineSue @slin1237\n/sgl-model-gateway/src/protocols @CatherineSue @key4ng\n/sgl-model-gateway/src/reasoning_parser @CatherineSue\n/sgl-model-gateway/src/routers @CatherineSue @key4ng @slin1237\n/sgl-model-gateway/src/tokenizer @slin1237 @CatherineSue\n/sgl-model-gateway/src/tool_parser @slin1237 @CatherineSue\n/sgl-model-gateway/src/wasm @slin1237\n/sgl-model-gateway/examples/wasm @slin1237\n/test/srt/ascend @ping1jing2 @iforgetmyname\n/test/srt/test_modelopt* @Edwardf0t1\n"
  },
  {
    "path": ".github/FOLDER_README.md",
    "content": "# Maintenance Tools\n\nThis folder contains tools and workflows for automating maintenance tasks.\n\n## CI Permissions\n\n`CI_PERMISSIONS.json` defines the CI permissions granted to each user.\nMaintainers can directly edit the file to add entries with `\"reason\": \"custom override\"`.\nMaintainers can also run `update_ci_permission.py` to update it with some auto rules (e.g., top contributors in the last 90 days get full permissions).\n\n## Others\n- `MAINTAINER.md` defines the code maintenance model.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/1-bug-report.yml",
    "content": "name: 🐞 Bug report\ndescription: Report a bug to help us reproduce and fix it.\ntitle: \"[Bug] \"\nlabels: ['Bug']\n\nbody:\n- type: checkboxes\n  attributes:\n    label: Checklist\n    options:\n      - label: I searched related issues but found no solution.\n      - label: The bug persists in the latest version.\n      - label: Issues without environment info and a minimal reproducible demo are hard to resolve and may receive no feedback.\n      - label: If this is not a bug report but a general question, please start a discussion at https://github.com/sgl-project/sglang/discussions. Otherwise, it will be closed.\n      - label: Please use English. Otherwise, it will be closed.\n- type: textarea\n  attributes:\n    label: Describe the bug\n    description: A clear, concise description of the bug.\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Reproduction\n    description: Command/script run and model used.\n    placeholder: Paste the command here.\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Environment\n    description: Run `python3 -m sglang.check_env` and paste output here. Issues without this will be closed.\n    placeholder: Paste environment output here.\n  validations:\n    required: true\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/2-feature-request.yml",
    "content": "name: 🚀 Feature request\ndescription: Suggest an idea for this project\ntitle: \"[Feature] \"\n\nbody:\n- type: checkboxes\n  attributes:\n    label: Checklist\n    options:\n      - label: If this is not a feature request but a general question, please start a discussion at https://github.com/sgl-project/sglang/discussions. Otherwise, it will be closed.\n      - label: Please use English. Otherwise, it will be closed.\n- type: textarea\n  attributes:\n    label: Motivation\n    description: |\n      Clearly and concisely describe the feature's motivation.\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Related resources\n    description: |\n      Provide official releases or third-party implementations if available.\n"
  },
  {
    "path": ".github/MAINTAINER.md",
    "content": "# SGLang Code Maintenance Model\nThis document describes the code maintenance model for the SGLang project.\nSince SGLang is a large project involving multiple organizations and hardware platforms, we designed this model with the following goals:\n- Ensure a responsive and smooth review process.\n- Allow for fast iteration, so maintainers can sometimes bypass flaky CI tests for important PRs.\n\n## Role Descriptions\nThere are four roles in this maintenance model. Some are custom roles, while others are predefined by GitHub.\n\n- **Merge Oncall**: The person who drives the PR merge process. They have strong area-specific expertise and uphold a high bar for code quality.\n  - Permission: Merge PRs. Bypass branch protection rules if needed.\n  - Responsibility: Shepherd the merge of PRs assigned to their area. Revert or hotfix any issues related to their merge (especially if they bypass).\n- **Codeowner**: The person who protects critical code. Without a bypass, each PR needs at least one Codeowner approval for each modified file protected by [CODEOWNERS](./CODEOWNERS). Please note that this role is not an honor but a significant responsibility because PRs cannot be merged without your approval (except when bypassed by a Merge Oncall).\n  - Permission: Approve PRs, allowing them to be merged without a bypass.\n  - Responsibility: Review PRs in a timely manner.\n- **Write**: A person with write permission to the SGLang repo.\n  - Permission: Merge PRs if they have passed required tests and been approved by Codeowners. 
This role cannot bypass branch protection rules.\n  - Responsibility: Review and merge PRs in a timely manner.\n- **CI Oncall**: A person who manages CI runners for specific hardware platforms.\n  - Permission: Add CI runners.\n  - Responsibility: Keep the CI runners up and running.\n\n__Note__: Difference between Merge Oncall and Codeowner\n- The Merge Oncall is an active role held by someone who actively tries to help merge PRs and can bypass CI if needed.\n- The Codeowner is a passive protection role provided by GitHub; it prevents accidental changes to critical code.\n- The list of Merge Oncalls is attached below. The list of Codeowners is in the [CODEOWNERS](./CODEOWNERS) file.\n\n__Note__: The permissions to trigger CI tests are defined separately according to these [rules](https://docs.sglang.io/developer_guide/contribution_guide.html#how-to-trigger-ci-tests).\n\n\n## Pull Request Merge Process\n1. The author submits a pull request (PR) and fills out the PR checklist.\n2. A bot assigns this PR to a Merge Oncall and @-mentions them. At the same time, GitHub will automatically request reviews from Codeowners.\n3. Someone tags the PR with a `run-ci` label ([help](https://docs.sglang.io/developer_guide/contribution_guide.html#how-to-trigger-ci-tests)). Then the author can trigger CI by pushing new commits.\n4. The Merge Oncall coordinates the review (e.g., asking people to review) and approves the PR; the Codeowners also approve the PR. If the assigned Merge Oncall is not responsive, the author can ping other related Merge Oncalls and Reviewers in the list below.\n5. The code can now be merged:\n   - **Ideal case:** For each modified file, one Codeowner has approved the PR. The PR has also passed the required CI tests. 
Then, anyone with write permission can merge the PR.\n   - **Exception:** In cases where it is difficult to meet all requirements (due to flaky CI or slow responses), a Merge Oncall can bypass branch protection to merge the PR.\n\nIf you run into any issues during the merge, you can discuss them in the [Slack channels](https://slack.sglang.io/): #dev, #pull-request, and #ci-cd-build-release.\n\n## The List of Merge Oncalls and Reviewers\nThe format is @github-username (Slack username).\n\nTODO: fill in the list.\n\nWe currently have many Merge Oncalls, mainly because the CI is flaky and the CODEOWNERS file is too coarse-grained.\nIn the future, we hope the CI becomes stable enough that bypasses are rarely needed. After that, most Merge Oncalls can be converted back to Write and Codeowner roles.\n\nThis list is based on the current situation. If you or someone you know is qualified and would like to take on more responsibility, please ping @Lianmin Zheng and @Ying Sheng in the Slack channel. They will start a nomination and internal review process.\n\n## The List of CI Oncalls\nThe format is @github-username (Slack username).\n\n### NVIDIA GPUs\n@merrymercy (Lianmin Zheng), @Kangyan-Zhou (Kangyan Zhou), @ch-wan (Cheng Wan), @HanHan009527 (hanhan), @ishandhanani (Ishan Dhanani), @key4ng (Keyang Ru), @slin1237 (Simo Lin), @ShangmingCai (Shangming Cai)\n\n### AMD GPUs\n@saienduri (Sai Enduri), @HaiShaw (Henry HAI)\n\n### Intel CPU and XPU\n@mingfeima (Mingfei Ma), @DiweiSun (Diwei Sun)\n\n### Ascend NPUs\n@iforgetmyname (Even Zhou)\n\nThis list is based on the current situation. If you or someone you know would like to donate machines for CI, the donors can serve as the CI Oncalls for those machines. Please ping @Lianmin Zheng and @Ying Sheng in the Slack channel. They will start a nomination and internal review process.\n"
  },
  {
    "path": ".github/actions/upload-cuda-coredumps/action.yml",
    "content": "name: Upload CUDA Coredumps\ndescription: Upload CUDA coredump files as artifacts and clean up the directory.\n\ninputs:\n  artifact-suffix:\n    description: Suffix appended to the artifact name (e.g. matrix partition id)\n    required: false\n    default: \"\"\n  retention-days:\n    description: Number of days to retain the artifact\n    required: false\n    default: \"7\"\n\nruns:\n  using: composite\n  steps:\n    - name: Upload CUDA coredumps\n      uses: actions/upload-artifact@v4\n      with:\n        name: cuda-coredumps-${{ github.job }}${{ inputs.artifact-suffix && format('-{0}', inputs.artifact-suffix) }}\n        path: ${{ env.SGLANG_CUDA_COREDUMP_DIR || '/tmp/sglang_cuda_coredumps' }}/\n        retention-days: ${{ inputs.retention-days }}\n        if-no-files-found: ignore\n\n    - name: Cleanup CUDA coredumps\n      shell: bash\n      run: rm -rf \"${{ env.SGLANG_CUDA_COREDUMP_DIR || '/tmp/sglang_cuda_coredumps' }}\"\n"
  },
  {
    "path": ".github/actions/wait-for-jobs/action.yml",
    "content": "name: Wait for Jobs\ndescription: Poll and wait for specified jobs in the current workflow run to complete\n\ninputs:\n  stage-name:\n    description: 'Human-readable stage name for log messages (e.g. \"stage-a\")'\n    required: true\n  jobs:\n    description: |\n      JSON array of job specs to wait for. Each element is either:\n        - a string: exact job name (e.g. \"stage-a-test-1\")\n        - an object { \"prefix\": \"...\", \"expected_count\": N }: for matrix jobs\n    required: true\n  max-wait-minutes:\n    description: 'Maximum time to wait before timing out'\n    required: false\n    default: '240'\n  poll-interval-seconds:\n    description: 'Seconds between polling attempts'\n    required: false\n    default: '120'\n  github-token:\n    description: 'GitHub token for API calls'\n    required: false\n    default: ${{ github.token }}\n\noutputs:\n  result:\n    description: 'Overall result: success, failure, or timeout'\n    value: ${{ steps.wait.outputs.result }}\n\nruns:\n  using: composite\n  steps:\n    - name: Wait for jobs to complete\n      id: wait\n      uses: actions/github-script@v7\n      env:\n        INPUT_STAGE_NAME: ${{ inputs.stage-name }}\n        INPUT_JOBS: ${{ inputs.jobs }}\n        INPUT_MAX_WAIT_MINUTES: ${{ inputs.max-wait-minutes }}\n        INPUT_POLL_INTERVAL_SECONDS: ${{ inputs.poll-interval-seconds }}\n      with:\n        github-token: ${{ inputs.github-token }}\n        script: |\n          const stageName = process.env.INPUT_STAGE_NAME;\n          const jobSpecs = JSON.parse(process.env.INPUT_JOBS);\n          const maxWaitMinutes = parseInt(process.env.INPUT_MAX_WAIT_MINUTES);\n          const pollIntervalSeconds = parseInt(process.env.INPUT_POLL_INTERVAL_SECONDS);\n          const maxAttempts = (maxWaitMinutes * 60) / pollIntervalSeconds;\n\n          // Normalize job specs into a uniform format\n          const normalizedSpecs = jobSpecs.map(spec => {\n            if (typeof spec === 'string') {\n     
         return { prefix: spec, expected_count: 1, exact: true };\n            }\n            return { ...spec, exact: false };\n          });\n\n          const totalExpectedJobs = normalizedSpecs.reduce((sum, s) => sum + s.expected_count, 0);\n\n          // Match job name: exact match or prefix + \" (\" for matrix jobs\n          const matchesSpec = (jobName, spec) => {\n            if (spec.exact) {\n              return jobName === spec.prefix;\n            }\n            return jobName === spec.prefix || jobName.startsWith(spec.prefix + ' (');\n          };\n\n          for (let attempt = 0; attempt < maxAttempts; attempt++) {\n            const jobs = await github.paginate(github.rest.actions.listJobsForWorkflowRun, {\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              run_id: context.runId,\n              per_page: 100,\n            });\n\n            let allCompleted = true;\n            let failedJobs = [];\n            let completedCount = 0;\n            let totalCount = 0;\n\n            for (const spec of normalizedSpecs) {\n              const matchingJobs = jobs.filter(job => matchesSpec(job.name, spec));\n\n              for (const job of matchingJobs) {\n                totalCount++;\n                console.log(`${job.name}: status=${job.status}, conclusion=${job.conclusion}`);\n\n                if (job.status === 'completed') {\n                  completedCount++;\n                  if (job.conclusion !== 'success' && job.conclusion !== 'skipped') {\n                    failedJobs.push(job.name);\n                  }\n                } else {\n                  allCompleted = false;\n                }\n              }\n\n              if (matchingJobs.length < spec.expected_count) {\n                console.log(`${spec.prefix}: found ${matchingJobs.length}/${spec.expected_count} jobs (waiting for more)`);\n                allCompleted = false;\n              }\n            }\n\n            
console.log(`[${stageName}] Progress: ${completedCount}/${totalCount} jobs completed (expected ${totalExpectedJobs})`);\n\n            // Fail fast if any jobs failed\n            if (failedJobs.length > 0) {\n              core.setOutput('result', 'failure');\n              core.setFailed(`${stageName} jobs failed: ${failedJobs.join(', ')}`);\n              return;\n            }\n\n            if (allCompleted && totalCount >= totalExpectedJobs) {\n              core.setOutput('result', 'success');\n              return;\n            }\n\n            console.log(`Waiting ${pollIntervalSeconds}s... (attempt ${attempt + 1}/${maxAttempts})`);\n            await new Promise(resolve => setTimeout(resolve, pollIntervalSeconds * 1000));\n          }\n\n          core.setFailed(`Timeout waiting for ${stageName} jobs`);\n          core.setOutput('result', 'timeout');\n"
  },
  {
    "path": ".github/labeler.yml",
    "content": "# Configuration for the GitHub Labeler action\n# Automatically adds labels to PRs based on the files changed\n\n# Router specific (Rust code in sgl-model-gateway)\nmodel-gateway:\n  - changed-files:\n    - any-glob-to-any-file: 'sgl-model-gateway/**/*'\n\n# Kernel specific\nsgl-kernel:\n  - changed-files:\n    - any-glob-to-any-file: 'sgl-kernel/**/*'\n\n# JIT kernel specific\njit-kernel:\n  - changed-files:\n    - any-glob-to-any-file: 'python/sglang/jit_kernel/**/*'\n\n# Documentation\ndocumentation:\n  - changed-files:\n    - any-glob-to-any-file:\n      - '**/*.md'\n      - 'docs/**/*'\n      - 'README*'\n\n# Dependencies\ndependencies:\n  - changed-files:\n    - any-glob-to-any-file:\n      - '**/requirements*.txt'\n      - '**/Cargo.toml'\n      - '**/Cargo.lock'\n      - '**/pyproject*.toml'\n      - '**/setup.py'\n      - '**/poetry.lock'\n      - '**/package.json'\n      - '**/package-lock.json'\n\n# Multi-modal\nMulti-modal:\n  - changed-files:\n    - any-glob-to-any-file:\n      - '**/*multimodal*'\n      - '**/*vision*'\n      - '**/*vlm*'\n\n# Diffusion\ndiffusion:\n  - changed-files:\n    - any-glob-to-any-file: 'python/sglang/multimodal_gen/**/*'\n\n# LoRA\nlora:\n  - changed-files:\n    - any-glob-to-any-file:\n      - '**/*lora*'\n\n# Quantization\nquant:\n  - changed-files:\n    - any-glob-to-any-file:\n      - '**/*quant*'\n      - '**/*quantization*'\n\n# Speculative decoding\nspeculative-decoding:\n  - changed-files:\n    - any-glob-to-any-file:\n      - '**/*speculative*'\n\n# AMD specific\namd:\n  - changed-files:\n    - any-glob-to-any-file:\n      - '**/*amd*'\n      - '**/*rocm*'\n\n# NPU specific\nnpu:\n  - changed-files:\n    - any-glob-to-any-file:\n      - '**/*npu*'\n      - '**/*ascend*'\n\n# Blackwell\nblackwell:\n  - changed-files:\n    - any-glob-to-any-file:\n      - '**/*nvfp4*'\n      - 'sgl-kernel/csrc/attention/cutlass_sm100_mla/**/*'\n      - 'python/sglang/srt/layers/attention/trtllm_mla_backend.py'\n      - 
'python/sglang/srt/layers/attention/trtllm_mha_backend.py'\n\n# DeepSeek specific\ndeepseek:\n  - changed-files:\n    - any-glob-to-any-file:\n      - '**/*deepseek*'\n\n# HiCache\nhicache:\n  - changed-files:\n    - any-glob-to-any-file:\n      - '**/*hicache*'\n\n# Deterministic\ndeterministic:\n  - changed-files:\n    - any-glob-to-any-file: 'python/sglang/srt/batch_invariant_ops/**/*'\n\n# Piecewise CUDA Graph\npiecewise-cuda-graph:\n  - changed-files:\n    - any-glob-to-any-file: 'python/sglang/srt/compilation/**/*'\n\n# Moore Threads specific\nmthreads:\n  - changed-files:\n    - any-glob-to-any-file:\n      - '**/*mthreads*'\n      - '**/*musa*'\n"
  },
  {
    "path": ".github/pull_request_template.md",
    "content": "<!-- Thank you for your contribution! Please follow these guidelines to enhance your pull request. If anything is unclear, submit your PR and reach out to maintainers for assistance. Join our Slack community at https://slack.sglang.io to discuss further. -->\n\n## Motivation\n\n<!-- Describe the purpose and goals of this pull request. -->\n\n## Modifications\n\n<!-- Detail the changes made in this pull request. -->\n\n## Accuracy Tests\n\n<!-- If this pull request affects model outputs (e.g., changes to the kernel or model forward code), provide accuracy test results. -->\n\n## Benchmarking and Profiling\n\n<!-- If this pull request impacts inference speed, provide benchmarking and profiling results. -->\n\n## Checklist\n\n- [ ] Format your code according to the [Format code with pre-commit](https://docs.sglang.io/developer_guide/contribution_guide.html#format-code-with-pre-commit).\n- [ ] Add unit tests according to the [Run and add unit tests](https://docs.sglang.io/developer_guide/contribution_guide.html#run-and-add-unit-tests).\n- [ ] Update documentation according to [Write documentations](https://docs.sglang.io/developer_guide/contribution_guide.html#write-documentations).\n- [ ] Provide accuracy and speed benchmark results according to [Test the accuracy](https://docs.sglang.io/developer_guide/contribution_guide.html#test-the-accuracy) and [Benchmark the speed](https://docs.sglang.io/developer_guide/contribution_guide.html#benchmark-the-speed).\n- [ ] Follow the SGLang code style [guidance](https://docs.sglang.io/developer_guide/contribution_guide.html#code-style-guidance).\n\n## Review Process\n\n1. Ping Merge Oncalls to start the PR flow. See the [PR Merge Process](https://github.com/sgl-project/sglang/blob/main/.github/MAINTAINER.md#pull-request-merge-process).\n2. Get approvals from [CODEOWNERS](https://github.com/sgl-project/sglang/blob/main/.github/CODEOWNERS) and other reviewers.\n3. 
Trigger CI tests with [comments](https://docs.sglang.io/developer_guide/contribution_guide.html#how-to-trigger-ci-tests) or contact authorized users to do so.\n   - `/tag-run-ci-label`, `/rerun-failed-ci`, `/tag-and-rerun-ci`\n4. After green CI and required approvals, ask Merge Oncalls to merge.\n"
  },
  {
    "path": ".github/update_ci_permission.py",
    "content": "\"\"\"\nUpdate the CI permissions configuration file.\n\nThis script updates the `CI_PERMISSIONS.json` file, which defines the CI permissions granted to each user.\n\nThe format of `CI_PERMISSIONS.json` is as follows:\n\n{\n    \"username1\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 0,\n        \"reason\": \"top contributor\"\n    },\n    \"username2\": {\n        \"can_tag_run_ci_label\": true,\n        \"can_rerun_failed_ci\": true,\n        \"cooldown_interval_minutes\": 60,\n        \"reason\": \"custom override\"\n    }\n}\n\nPermissions are assigned according to the following rules:\n\n1. Add the top 50 contributors from the last 90 days with full permissions, no cooldown, and the reason \"top contributor\".\n2. Load all users from the existing `CI_PERMISSIONS.json` file and update their entries as follows:\n   - If a user is already covered by rule 1, skip that user.\n   - If the old reason of a user is \"top contributor\" but they are not in the current top contributors list, change their configuration to:\n       {\n           \"can_tag_run_ci_label\": true,\n           \"can_rerun_failed_ci\": true,\n           \"cooldown_interval_minutes\": 60,\n           \"reason\": \"custom override\"\n       }\n    - For all other cases, preserve the original configuration unchanged.\n3. 
All other users receive no permissions and a 120-minute cooldown (they are omitted from the file).\n\nUsage:\n    export GH_TOKEN=\"your_github_token\"\n    python3 update_ci_permission.py\n\n    # Sort-only mode (no network calls, no GH_TOKEN required)\n    python3 update_ci_permission.py --sort-only\n\"\"\"\n\nimport argparse\nimport json\nimport os\nfrom collections import Counter\nfrom datetime import datetime, timedelta, timezone\n\ntry:\n    import requests\nexcept ImportError:\n    requests = None  # Only needed for non-sort-only runs\n\n# Configuration\nREPO_OWNER = \"sgl-project\"\nREPO_NAME = \"sglang\"\nFILE_NAME = os.path.join(os.path.dirname(__file__), \"CI_PERMISSIONS.json\")\nHEADERS = {}\n\n\ndef github_api_get(endpoint, params=None):\n    \"\"\"Helper to make paginated GitHub API requests.\"\"\"\n    if requests is None:\n        raise RuntimeError(\n            \"The requests package is required. Install it or use --sort-only.\"\n        )\n    if not HEADERS:\n        raise RuntimeError(\n            \"GitHub headers not initialized. 
Set GH_TOKEN or use --sort-only.\"\n        )\n\n    results = []\n    url = f\"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/{endpoint}\"\n\n    while url:\n        response = requests.get(url, headers=HEADERS, params=params)\n        if response.status_code != 200:\n            print(f\"Error fetching {url}: {response.status_code} {response.text}\")\n            # If we fail to fetch, strictly return what we have or empty to avoid crashing logic\n            break\n\n        data = response.json()\n        if isinstance(data, list):\n            results.extend(data)\n        else:\n            return data  # Non-list response (not paginated usually)\n\n        # Handle pagination\n        url = None\n        if \"link\" in response.headers:\n            links = response.headers[\"link\"].split(\", \")\n            for link in links:\n                if 'rel=\"next\"' in link:\n                    url = link[link.find(\"<\") + 1 : link.find(\">\")]\n                    params = None  # Params are included in the next link\n                    break\n    return results\n\n\ndef get_write_access_users():\n    \"\"\"Fetches users with push (write) or admin access.\"\"\"\n    print(\"Fetching collaborators with write access...\")\n    # Note: This endpoint usually requires admin rights on the token.\n    collaborators = github_api_get(\"collaborators\", params={\"per_page\": 100})\n\n    writers = set()\n    for col in collaborators:\n        perms = col.get(\"permissions\", {})\n        # Check for admin, maintain, or push rights\n        if perms.get(\"admin\") or perms.get(\"maintain\") or perms.get(\"push\"):\n            writers.add(col[\"login\"])\n\n    print(f\"Found {len(writers)} users with write access.\")\n    return writers\n\n\ndef get_top_contributors(days=90, limit=50):\n    \"\"\"Fetches top contributors based on commit count in the last N days.\"\"\"\n    print(f\"Fetching commits from the last {days} days...\")\n    since_date = 
(datetime.now(timezone.utc) - timedelta(days=days)).isoformat()\n\n    # Fetch commits\n    commits = github_api_get(\"commits\", params={\"since\": since_date, \"per_page\": 100})\n\n    author_counts = Counter()\n    for commit in commits:\n        # commit['author'] contains the GitHub user object (can be None if not linked)\n        if commit.get(\"author\") and \"login\" in commit[\"author\"]:\n            author_counts[commit[\"author\"][\"login\"]] += 1\n\n    top_users = [user for user, _ in author_counts.most_common(limit)]\n    print(f\"Found {len(top_users)} active contributors in the last {days} days.\")\n    return set(top_users)\n\n\ndef load_existing_permissions():\n    if os.path.exists(FILE_NAME):\n        try:\n            with open(FILE_NAME, \"r\") as f:\n                return json.load(f)\n        except json.JSONDecodeError:\n            print(f\"Warning: {FILE_NAME} is invalid JSON. Starting fresh.\")\n    return {}\n\n\ndef sort_permissions_file():\n    \"\"\"Sort the existing CI permissions file alphabetically and exit.\"\"\"\n    if not os.path.exists(FILE_NAME):\n        print(f\"{FILE_NAME} not found. Nothing to sort.\")\n        return\n\n    old_permissions = load_existing_permissions()\n    sorted_permissions = dict(sorted(old_permissions.items()))\n\n    with open(FILE_NAME, \"w\") as f:\n        json.dump(sorted_permissions, f, indent=4)\n        f.write(\"\\n\")\n\n    print(f\"Sorted {FILE_NAME}. 
Total users: {len(sorted_permissions)}\")\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Update or sort CI permissions.\")\n    parser.add_argument(\n        \"--sort-only\",\n        action=\"store_true\",\n        help=\"Only sort CI_PERMISSIONS.json alphabetically without fetching data.\",\n    )\n    args = parser.parse_args()\n\n    if args.sort_only:\n        sort_permissions_file()\n        return\n\n    gh_token = os.getenv(\"GH_TOKEN\")\n    if not gh_token:\n        raise ValueError(\"Error: GH_TOKEN environment variable is not set.\")\n\n    global HEADERS\n    HEADERS = {\n        \"Authorization\": f\"Bearer {gh_token}\",\n        \"Accept\": \"application/vnd.github+json\",\n        \"X-GitHub-Api-Version\": \"2022-11-28\",\n    }\n\n    # Gather Data\n    try:\n        write_access_users = get_write_access_users()\n    except Exception as e:\n        print(f\"Warning: Could not fetch collaborators (check token scope). Error: {e}\")\n        write_access_users = set()\n\n    top_contributors = get_top_contributors(days=90, limit=50)\n    old_permissions = load_existing_permissions()\n\n    new_permissions = {}\n\n    # Rule 1: Add Top 50 Contributors\n    for user in top_contributors:\n        new_permissions[user] = {\n            \"can_tag_run_ci_label\": True,\n            \"can_rerun_failed_ci\": True,\n            \"cooldown_interval_minutes\": 0,\n            \"reason\": \"top contributor\",\n        }\n\n    # Rule 2: Process Existing Users (Merge Logic)\n    for user, config in old_permissions.items():\n        if user in new_permissions:\n            # Already handled by Rule 1 or 2\n            continue\n\n        old_reason = config.get(\"reason\", \"\")\n\n        # If they fell off the top contributor list\n        if old_reason in [\"top contributor\"]:\n            new_permissions[user] = {\n                \"can_tag_run_ci_label\": True,\n                \"can_rerun_failed_ci\": True,\n                
\"cooldown_interval_minutes\": 60,\n                \"reason\": \"custom override\",\n            }\n        else:\n            # Preserve custom overrides\n            new_permissions[user] = config\n\n    # Save and Sort\n    # Sorting keys for cleaner diffs\n    sorted_permissions = dict(sorted(new_permissions.items()))\n\n    with open(FILE_NAME, \"w\") as f:\n        json.dump(sorted_permissions, f, indent=4)\n        f.write(\"\\n\")  # Add trailing newline\n\n    print(f\"Successfully updated {FILE_NAME}. Total users: {len(sorted_permissions)}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": ".github/workflows/amd-aiter-scout.yml",
    "content": "name: AMD AITER Scout\n\non:\n  schedule:\n    - cron: '0 20 * * 1'   # Monday 20:00 UTC\n    - cron: '0 20 * * 4'   # Thursday 20:00 UTC\n  workflow_dispatch:\n    inputs:\n      aiter_ref:\n        description: 'AITER git ref (branch, tag, or SHA). Default: main (latest commit)'\n        required: false\n        type: string\n        default: 'main'\n      job_filter:\n        description: 'Comma-separated workflows to run: nightly-amd, nightly-amd-rocm720, pr-test-amd, pr-test-amd-rocm720. Default: all'\n        required: false\n        type: string\n        default: 'all'\n      continue_on_error:\n        description: 'Continue running other workflows even if one fails'\n        required: false\n        type: boolean\n        default: true\n\nconcurrency:\n  group: amd-aiter-scout-${{ github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  resolve-aiter:\n    runs-on: ubuntu-latest\n    outputs:\n      aiter_sha: ${{ steps.resolve.outputs.sha }}\n      run_nightly_amd: ${{ steps.parse.outputs.run_nightly_amd }}\n      run_nightly_amd_rocm720: ${{ steps.parse.outputs.run_nightly_amd_rocm720 }}\n      run_pr_test_amd: ${{ steps.parse.outputs.run_pr_test_amd }}\n      run_pr_test_amd_rocm720: ${{ steps.parse.outputs.run_pr_test_amd_rocm720 }}\n    steps:\n      - name: Resolve AITER commit\n        id: resolve\n        run: |\n          REF=\"${{ inputs.aiter_ref || 'main' }}\"\n          echo \"Resolving AITER ref: ${REF}\"\n\n          SHA=$(git ls-remote https://github.com/ROCm/aiter.git \"refs/heads/${REF}\" | head -1 | cut -f1)\n          if [ -z \"$SHA\" ]; then\n            SHA=$(git ls-remote https://github.com/ROCm/aiter.git \"refs/tags/${REF}\" | head -1 | cut -f1)\n          fi\n          if [ -z \"$SHA\" ]; then\n            SHA=$(git ls-remote https://github.com/ROCm/aiter.git \"${REF}\" | head -1 | cut -f1)\n          fi\n          if [ -z \"$SHA\" ]; then\n            SHA=\"${REF}\"\n          fi\n\n          echo \"sha=${SHA}\" >> 
$GITHUB_OUTPUT\n          echo \"### AITER Ref Resolution\" >> $GITHUB_STEP_SUMMARY\n          echo \"- **Requested ref:** \\`${REF}\\`\" >> $GITHUB_STEP_SUMMARY\n          echo \"- **Resolved SHA:** \\`${SHA}\\`\" >> $GITHUB_STEP_SUMMARY\n          echo \"- **AITER commit:** https://github.com/ROCm/aiter/commit/${SHA}\" >> $GITHUB_STEP_SUMMARY\n\n      - name: Parse job filter\n        id: parse\n        run: |\n          FILTER=\"${{ inputs.job_filter || 'all' }}\"\n          echo \"Job filter: ${FILTER}\"\n\n          if [[ \"$FILTER\" == \"all\" ]]; then\n            echo \"run_nightly_amd=true\" >> $GITHUB_OUTPUT\n            echo \"run_nightly_amd_rocm720=true\" >> $GITHUB_OUTPUT\n            echo \"run_pr_test_amd=true\" >> $GITHUB_OUTPUT\n            echo \"run_pr_test_amd_rocm720=true\" >> $GITHUB_OUTPUT\n          else\n            # Wrap with commas for exact substring matching (avoids \"nightly-amd\" matching \"nightly-amd-rocm720\")\n            PADDED=\",${FILTER// /},\"\n            echo \"run_nightly_amd=$(echo \"$PADDED\" | grep -q ',nightly-amd,' && echo true || echo false)\" >> $GITHUB_OUTPUT\n            echo \"run_nightly_amd_rocm720=$(echo \"$PADDED\" | grep -q ',nightly-amd-rocm720,' && echo true || echo false)\" >> $GITHUB_OUTPUT\n            echo \"run_pr_test_amd=$(echo \"$PADDED\" | grep -q ',pr-test-amd,' && echo true || echo false)\" >> $GITHUB_OUTPUT\n            echo \"run_pr_test_amd_rocm720=$(echo \"$PADDED\" | grep -q ',pr-test-amd-rocm720,' && echo true || echo false)\" >> $GITHUB_OUTPUT\n          fi\n\n          echo \"### Job Filter\" >> $GITHUB_STEP_SUMMARY\n          echo \"- **Filter:** \\`${FILTER}\\`\" >> $GITHUB_STEP_SUMMARY\n\n  call-nightly-amd:\n    if: needs.resolve-aiter.outputs.run_nightly_amd == 'true'\n    needs: resolve-aiter\n    uses: ./.github/workflows/nightly-test-amd.yml\n    secrets: inherit\n    with:\n      ref: ${{ github.sha }}\n      aiter_ref: ${{ needs.resolve-aiter.outputs.aiter_sha }}\n      
job_filter: 'all'\n      continue_on_error: ${{ github.event_name != 'workflow_dispatch' || inputs.continue_on_error }}\n\n  call-nightly-amd-rocm720:\n    if: needs.resolve-aiter.outputs.run_nightly_amd_rocm720 == 'true'\n    needs: resolve-aiter\n    uses: ./.github/workflows/nightly-test-amd-rocm720.yml\n    secrets: inherit\n    with:\n      ref: ${{ github.sha }}\n      aiter_ref: ${{ needs.resolve-aiter.outputs.aiter_sha }}\n      job_filter: 'all'\n      continue_on_error: ${{ github.event_name != 'workflow_dispatch' || inputs.continue_on_error }}\n\n  call-pr-test-amd:\n    if: needs.resolve-aiter.outputs.run_pr_test_amd == 'true'\n    needs: resolve-aiter\n    uses: ./.github/workflows/pr-test-amd.yml\n    secrets: inherit\n    with:\n      run_all_tests: true\n      aiter_ref: ${{ needs.resolve-aiter.outputs.aiter_sha }}\n      continue_on_error: ${{ github.event_name != 'workflow_dispatch' || inputs.continue_on_error }}\n\n  call-pr-test-amd-rocm720:\n    if: needs.resolve-aiter.outputs.run_pr_test_amd_rocm720 == 'true'\n    needs: resolve-aiter\n    uses: ./.github/workflows/pr-test-amd-rocm720.yml\n    secrets: inherit\n    with:\n      run_all_tests: true\n      aiter_ref: ${{ needs.resolve-aiter.outputs.aiter_sha }}\n      continue_on_error: ${{ github.event_name != 'workflow_dispatch' || inputs.continue_on_error }}\n\n  check-all-jobs:\n    if: always()\n    needs:\n      - resolve-aiter\n      - call-nightly-amd\n      - call-nightly-amd-rocm720\n      - call-pr-test-amd\n      - call-pr-test-amd-rocm720\n    runs-on: ubuntu-latest\n    steps:\n      - name: Summary\n        run: |\n          echo \"## AMD AITER Scout Results\" >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"- **AITER SHA:** \\`${{ needs.resolve-aiter.outputs.aiter_sha }}\\`\" >> $GITHUB_STEP_SUMMARY\n          echo \"- **AITER commit:** https://github.com/ROCm/aiter/commit/${{ needs.resolve-aiter.outputs.aiter_sha }}\" >> $GITHUB_STEP_SUMMARY\n
          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"| Workflow | Result |\" >> $GITHUB_STEP_SUMMARY\n          echo \"|----------|--------|\" >> $GITHUB_STEP_SUMMARY\n          echo \"| Nightly AMD (AITER Latest) | \\`${{ needs.call-nightly-amd.result }}\\` |\" >> $GITHUB_STEP_SUMMARY\n          echo \"| Nightly AMD ROCm 7.2 | \\`${{ needs.call-nightly-amd-rocm720.result }}\\` |\" >> $GITHUB_STEP_SUMMARY\n          echo \"| PR Test AMD (AITER Latest) | \\`${{ needs.call-pr-test-amd.result }}\\` |\" >> $GITHUB_STEP_SUMMARY\n          echo \"| PR Test AMD ROCm 7.2 | \\`${{ needs.call-pr-test-amd-rocm720.result }}\\` |\" >> $GITHUB_STEP_SUMMARY\n\n      - name: Check if any job failed\n        run: |\n          if [[ \"${{ contains(needs.*.result, 'failure') }}\" == \"true\" ]]; then\n            echo \"One or more workflows failed\"\n            exit 1\n          fi\n          if [[ \"${{ contains(needs.*.result, 'cancelled') }}\" == \"true\" ]]; then\n            echo \"One or more workflows were cancelled\"\n            exit 1\n          fi\n          echo \"All workflows passed\"\n"
  },
  {
    "path": ".github/workflows/amd-ci-job-monitor.yml",
    "content": "name: AMD CI Job Monitor\n\non:\n  schedule:\n    - cron: '0 0 * * *'  # Daily at midnight UTC\n  pull_request:\n    paths:\n      - '.github/workflows/amd-ci-job-monitor.yml'\n      - 'scripts/ci/utils/query_job_status.py'\n  workflow_dispatch:\n    inputs:\n      hours:\n        description: 'Time window in hours'\n        required: false\n        default: '24'\n        type: string\n      job_filter:\n        description: 'Job name filter (leave empty for all AMD jobs)'\n        required: false\n        type: string\n\njobs:\n  # Single job filter mode\n  custom-report:\n    name: Custom Job Report\n    if: ${{ inputs.job_filter }}\n    runs-on: ubuntu-latest\n    env:\n      GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install dependencies\n        run: pip install tabulate\n\n      - name: Generate Custom Job Report\n        timeout-minutes: 30\n        run: |\n          python scripts/ci/utils/query_job_status.py \\\n            --repo ${{ github.repository }} \\\n            --job \"${{ inputs.job_filter }}\" \\\n            --workflow \"pr-test-amd.yml\" \\\n            --hours ${{ inputs.hours || '24' }} \\\n            --summary\n\n  # Parse workflow files to get job names dynamically\n  parse-workflows:\n    name: Parse Workflow Jobs\n    if: ${{ !inputs.job_filter }}\n    runs-on: ubuntu-latest\n    outputs:\n      pr_jobs: ${{ steps.parse.outputs.pr_jobs }}\n      nightly_jobs: ${{ steps.parse.outputs.nightly_jobs }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Parse workflow files\n        id: parse\n        run: |\n          # Parse pr-test-amd.yml and extract job names (exclude utility jobs)\n          # Excluded: call-gate, check-changes, pr-test-amd-finish, cancel, check-all-jobs\n   
       pr_jobs=$(yq -r '.jobs | keys | .[]' .github/workflows/pr-test-amd.yml | \\\n            grep -v -E '^(call-gate|check-changes|pr-test-amd-finish|cancel|check-all-jobs)$' | \\\n            jq -R -s -c 'split(\"\\n\") | map(select(length > 0))')\n          echo \"pr_jobs=$pr_jobs\" >> $GITHUB_OUTPUT\n          echo \"PR jobs: $pr_jobs\"\n\n          # Parse nightly-test-amd.yml and extract job names (exclude utility jobs)\n          # Excluded: check-all-jobs\n          nightly_jobs=$(yq -r '.jobs | keys | .[]' .github/workflows/nightly-test-amd.yml | \\\n            grep -v -E '^(check-all-jobs)$' | \\\n            jq -R -s -c 'split(\"\\n\") | map(select(length > 0))')\n          echo \"nightly_jobs=$nightly_jobs\" >> $GITHUB_OUTPUT\n          echo \"Nightly jobs: $nightly_jobs\"\n\n  # PR CI reports using dynamic matrix\n  pr-ci-reports:\n    name: PR - ${{ matrix.job_name }}\n    needs: parse-workflows\n    if: ${{ !inputs.job_filter }}\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        job_name: ${{ fromJson(needs.parse-workflows.outputs.pr_jobs) }}\n    env:\n      GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install dependencies\n        run: pip install tabulate\n\n      - name: Generate Report\n        timeout-minutes: 15\n        run: |\n          python scripts/ci/utils/query_job_status.py \\\n            --repo ${{ github.repository }} \\\n            --job \"${{ matrix.job_name }}\" \\\n            --workflow \"pr-test-amd.yml\" \\\n            --hours ${{ inputs.hours || '24' }} \\\n            --summary\n\n  # Nightly AMD test reports using dynamic matrix\n  nightly-reports:\n    name: Nightly - ${{ matrix.job_name }}\n    needs: parse-workflows\n    if: ${{ !inputs.job_filter }}\n    runs-on: 
ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        job_name: ${{ fromJson(needs.parse-workflows.outputs.nightly_jobs) }}\n    env:\n      GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install dependencies\n        run: pip install tabulate\n\n      - name: Generate Nightly Report\n        timeout-minutes: 15\n        run: |\n          python scripts/ci/utils/query_job_status.py \\\n            --repo ${{ github.repository }} \\\n            --job \"${{ matrix.job_name }}\" \\\n            --workflow \"nightly-test-amd.yml\" \\\n            --hours ${{ inputs.hours || '24' }} \\\n            --summary\n"
  },
  {
    "path": ".github/workflows/auto-tune.yml",
    "content": "name: Auto tune\n\non:\n  workflow_dispatch:\n\njobs:\n  lint:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n"
  },
  {
    "path": ".github/workflows/bot-bump-flashinfer-version.yml",
    "content": "name: Bot Bump Flashinfer Version\n\non:\n  workflow_dispatch:\n    inputs:\n      new_version:\n        description: 'New flashinfer version (e.g., 0.6.4)'\n        required: true\n        type: string\n\npermissions:\n  contents: write\n  pull-requests: write\n\njobs:\n  bump-flashinfer-version:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          token: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install Python dependencies\n        run: |\n            pip install tomli\n\n      - name: Configure Git and branch\n        run: |\n          git config user.name \"sglang-bot\"\n          git config user.email \"sglang-bot@users.noreply.github.com\"\n          RANDOM_SUFFIX=$(echo $RANDOM | md5sum | head -c 4)\n          BRANCH_NAME=\"bot/bump-flashinfer-version-${{ github.event.inputs.new_version }}-${RANDOM_SUFFIX}\"\n          git checkout -b \"$BRANCH_NAME\"\n          echo \"BRANCH_NAME=$BRANCH_NAME\" >> $GITHUB_ENV\n\n      - name: Run flashinfer version bump script\n        run: |\n          python scripts/release/bump_flashinfer_version.py \"${{ github.event.inputs.new_version }}\"\n\n      - name: Commit and create PR\n        env:\n          GH_TOKEN: ${{ secrets.GH_PAT_FOR_PULL_REQUEST }}\n        run: |\n          bash scripts/release/commit_and_pr.sh \"flashinfer\" \"${{ github.event.inputs.new_version }}\" \"$BRANCH_NAME\"\n"
  },
  {
    "path": ".github/workflows/bot-bump-kernel-version-to-sglang.yml",
    "content": "name: Bot Bump Kernel Version to SGLang\n\non:\n  workflow_dispatch:\n\npermissions:\n  contents: write\n  pull-requests: write\n\njobs:\n  bump-kernel-version-to-sglang:\n    runs-on: ubuntu-latest\n    outputs:\n      branch_name: ${{ steps.set_output.outputs.branch_name }}\n      needs_sync: ${{ steps.check_sync.outputs.needs_sync }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          token: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install Python dependencies\n        run: |\n          pip install tomli\n\n      - name: Check if sync is needed\n        id: check_sync\n        run: |\n          python scripts/release/check_kernel_version_to_sglang.py\n\n      - name: Configure Git and branch\n        if: steps.check_sync.outputs.needs_sync == 'true'\n        id: set_output\n        run: |\n          git config user.name \"sglang-bot\"\n          git config user.email \"sglang-bot@users.noreply.github.com\"\n          RANDOM_SUFFIX=$(echo $RANDOM | md5sum | head -c 4)\n          KERNEL_VERSION=\"${{ steps.check_sync.outputs.kernel_version }}\"\n          BRANCH_NAME=\"bot/bump-kernel-version-to-sglang-${KERNEL_VERSION}-${RANDOM_SUFFIX}\"\n          git checkout -b \"$BRANCH_NAME\"\n          echo \"BRANCH_NAME=$BRANCH_NAME\" >> $GITHUB_ENV\n          echo \"KERNEL_VERSION=$KERNEL_VERSION\" >> $GITHUB_ENV\n          echo \"branch_name=$BRANCH_NAME\" >> $GITHUB_OUTPUT\n\n      - name: Run kernel version bump script\n        if: steps.check_sync.outputs.needs_sync == 'true'\n        run: |\n          python scripts/release/bump_kernel_version_to_sglang.py\n\n      - name: Commit and create PR\n        if: steps.check_sync.outputs.needs_sync == 'true'\n        env:\n          GH_TOKEN: ${{ secrets.GH_PAT_FOR_PULL_REQUEST }}\n        run: |\n          bash 
scripts/release/commit_and_pr_kernel_to_sglang.sh \"$KERNEL_VERSION\" \"$BRANCH_NAME\"\n\n  run-nightly-tests-nvidia:\n    needs: bump-kernel-version-to-sglang\n    if: needs.bump-kernel-version-to-sglang.outputs.needs_sync == 'true'\n    uses: ./.github/workflows/nightly-test-nvidia.yml\n    with:\n      ref: ${{ needs.bump-kernel-version-to-sglang.outputs.branch_name }}\n    secrets: inherit\n\n  run-nightly-tests-amd:\n    needs: bump-kernel-version-to-sglang\n    if: needs.bump-kernel-version-to-sglang.outputs.needs_sync == 'true'\n    uses: ./.github/workflows/nightly-test-amd.yml\n    with:\n      ref: ${{ needs.bump-kernel-version-to-sglang.outputs.branch_name }}\n    secrets: inherit\n\n  run-nightly-tests-npu:\n    needs: bump-kernel-version-to-sglang\n    if: needs.bump-kernel-version-to-sglang.outputs.needs_sync == 'true'\n    uses: ./.github/workflows/nightly-test-npu.yml\n    with:\n      ref: ${{ needs.bump-kernel-version-to-sglang.outputs.branch_name }}\n    secrets: inherit\n\n  run-pr-tests-xeon:\n    needs: bump-kernel-version-to-sglang\n    if: needs.bump-kernel-version-to-sglang.outputs.needs_sync == 'true'\n    uses: ./.github/workflows/pr-test-xeon.yml\n    with:\n      ref: ${{ needs.bump-kernel-version-to-sglang.outputs.branch_name }}\n    secrets: inherit\n\n  run-pr-tests-xpu:\n    needs: bump-kernel-version-to-sglang\n    if: needs.bump-kernel-version-to-sglang.outputs.needs_sync == 'true'\n    uses: ./.github/workflows/pr-test-xpu.yml\n    with:\n      ref: ${{ needs.bump-kernel-version-to-sglang.outputs.branch_name }}\n    secrets: inherit\n"
  },
  {
    "path": ".github/workflows/bot-bump-kernel-version.yml",
    "content": "name: Bot Bump Kernel Version\n\non:\n  workflow_dispatch:\n    inputs:\n      new_version:\n        description: 'New sgl-kernel version (e.g., 0.3.12)'\n        required: true\n        type: string\n\npermissions:\n  contents: write\n  pull-requests: write\n\njobs:\n  bump-kernel-version:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          token: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install Python dependencies\n        run: |\n          pip install tomli\n\n      - name: Configure Git and branch\n        run: |\n          git config user.name \"sglang-bot\"\n          git config user.email \"sglang-bot@users.noreply.github.com\"\n          RANDOM_SUFFIX=$(echo $RANDOM | md5sum | head -c 4)\n          BRANCH_NAME=\"bot/bump-kernel-version-${{ github.event.inputs.new_version }}-${RANDOM_SUFFIX}\"\n          git checkout -b \"$BRANCH_NAME\"\n          echo \"BRANCH_NAME=$BRANCH_NAME\" >> $GITHUB_ENV\n\n      - name: Run kernel version bump script\n        run: |\n          python scripts/release/bump_kernel_version.py \"${{ github.event.inputs.new_version }}\"\n\n      - name: Commit and create PR\n        env:\n          GH_TOKEN: ${{ secrets.GH_PAT_FOR_PULL_REQUEST }}\n        run: |\n          bash scripts/release/commit_and_pr.sh \"sgl-kernel\" \"${{ github.event.inputs.new_version }}\" \"$BRANCH_NAME\"\n"
  },
  {
    "path": ".github/workflows/bot-bump-sglang-version.yml",
    "content": "name: Bot Bump SGLang Version\n\non:\n  workflow_dispatch:\n    inputs:\n      new_version:\n        description: 'New SGLang version (e.g., 0.5.3 or 0.5.3rc0)'\n        required: true\n        type: string\n\npermissions:\n  contents: write\n  pull-requests: write\n\njobs:\n  bump-sglang-version:\n    runs-on: ubuntu-latest\n    outputs:\n      branch_name: ${{ steps.set_output.outputs.branch_name }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          token: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install Python dependencies\n        run: |\n          pip install tomli\n\n      - name: Configure Git and branch\n        id: set_output\n        run: |\n          git config user.name \"sglang-bot\"\n          git config user.email \"sglang-bot@users.noreply.github.com\"\n          RANDOM_SUFFIX=$(echo $RANDOM | md5sum | head -c 4)\n          BRANCH_NAME=\"bot/bump-sglang-version-${{ github.event.inputs.new_version }}-${RANDOM_SUFFIX}\"\n          git checkout -b \"$BRANCH_NAME\"\n          echo \"BRANCH_NAME=$BRANCH_NAME\" >> $GITHUB_ENV\n          echo \"branch_name=$BRANCH_NAME\" >> $GITHUB_OUTPUT\n\n      - name: Run SGLang version bump script\n        run: |\n          python scripts/release/bump_sglang_version.py \"${{ github.event.inputs.new_version }}\"\n\n      - name: Commit and create PR\n        env:\n          GH_TOKEN: ${{ secrets.GH_PAT_FOR_PULL_REQUEST }}\n        run: |\n          bash scripts/release/commit_and_pr.sh \"SGLang\" \"${{ github.event.inputs.new_version }}\" \"$BRANCH_NAME\"\n\n  run-nightly-tests-nvidia:\n    needs: bump-sglang-version\n    uses: ./.github/workflows/nightly-test-nvidia.yml\n    with:\n      ref: ${{ needs.bump-sglang-version.outputs.branch_name }}\n    secrets: inherit\n\n  run-nightly-tests-amd:\n    needs: bump-sglang-version\n    
uses: ./.github/workflows/nightly-test-amd.yml\n    with:\n      ref: ${{ needs.bump-sglang-version.outputs.branch_name }}\n    secrets: inherit\n\n  run-nightly-tests-npu:\n    needs: bump-sglang-version\n    uses: ./.github/workflows/nightly-test-npu.yml\n    with:\n      ref: ${{ needs.bump-sglang-version.outputs.branch_name }}\n    secrets: inherit\n\n  run-pr-tests-xeon:\n    needs: bump-sglang-version\n    uses: ./.github/workflows/pr-test-xeon.yml\n    with:\n      ref: ${{ needs.bump-sglang-version.outputs.branch_name }}\n    secrets: inherit\n\n  run-pr-tests-xpu:\n    needs: bump-sglang-version\n    uses: ./.github/workflows/pr-test-xpu.yml\n    with:\n      ref: ${{ needs.bump-sglang-version.outputs.branch_name }}\n    secrets: inherit\n"
  },
  {
    "path": ".github/workflows/bot-cherry-pick.yml",
    "content": "name: Bot Cherry Pick to Release Branch\n\non:\n  workflow_dispatch:\n    inputs:\n      commit_sha:\n        description: 'Commit SHA to cherry-pick (full or short hash)'\n        required: true\n        type: string\n      target_branch:\n        description: 'Target release branch (e.g., release/v0.5.7)'\n        required: true\n        type: string\n      create_pr:\n        description: 'Create a PR instead of pushing directly'\n        required: false\n        type: boolean\n        default: true\n\npermissions:\n  contents: write\n  pull-requests: write\n\nconcurrency:\n  group: cherry-pick-${{ github.event.inputs.target_branch }}\n  cancel-in-progress: false\n\njobs:\n  cherry-pick:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: ubuntu-latest\n    environment: 'prod'\n    steps:\n      - name: Validate inputs\n        env:\n          TARGET_BRANCH: ${{ github.event.inputs.target_branch }}\n        run: |\n          if [[ ! \"$TARGET_BRANCH\" =~ ^release/v[0-9]+\\.[0-9]+(\\.[0-9]+)?$ ]]; then\n            echo \"::error::Target branch must match pattern 'release/vX.Y' or 'release/vX.Y.Z' (e.g., release/v0.5.7)\"\n            exit 1\n          fi\n\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0\n          token: ${{ secrets.GH_PAT_FOR_PULL_REQUEST }}\n\n      - name: Configure Git\n        run: |\n          git config user.name \"sglang-bot\"\n          git config user.email \"sglang-bot@users.noreply.github.com\"\n\n      - name: Validate target branch exists\n        env:\n          TARGET_BRANCH: ${{ github.event.inputs.target_branch }}\n        run: |\n          git fetch origin\n          if ! 
git ls-remote --exit-code --heads origin \"$TARGET_BRANCH\" > /dev/null 2>&1; then\n            echo \"::error::Target branch '$TARGET_BRANCH' does not exist on remote\"\n            exit 1\n          fi\n\n      - name: Get commit info\n        id: commit_info\n        env:\n          COMMIT_SHA_INPUT: ${{ github.event.inputs.commit_sha }}\n        run: |\n          # Verify commit exists\n          if ! git cat-file -t \"$COMMIT_SHA_INPUT\" > /dev/null 2>&1; then\n            echo \"::error::Commit SHA '$COMMIT_SHA_INPUT' does not exist\"\n            exit 1\n          fi\n\n          # Get full SHA if short hash provided\n          FULL_SHA=$(git rev-parse \"$COMMIT_SHA_INPUT\")\n          COMMIT_TITLE=$(git log -1 --format=\"%s\" \"$FULL_SHA\")\n          SHORT_SHA=$(git rev-parse --short \"$FULL_SHA\")\n          echo \"full_sha=$FULL_SHA\" >> $GITHUB_OUTPUT\n          echo \"short_sha=$SHORT_SHA\" >> $GITHUB_OUTPUT\n          # Use delimiter for multiline-safe output\n          {\n            echo \"commit_title<<EOF\"\n            echo \"$COMMIT_TITLE\"\n            echo \"EOF\"\n          } >> $GITHUB_OUTPUT\n          echo \"Cherry-picking commit: $SHORT_SHA - $COMMIT_TITLE\"\n\n      - name: Cherry-pick commit\n        id: cherry_pick\n        env:\n          TARGET_BRANCH: ${{ github.event.inputs.target_branch }}\n          FULL_SHA: ${{ steps.commit_info.outputs.full_sha }}\n          SHORT_SHA: ${{ steps.commit_info.outputs.short_sha }}\n          CREATE_PR: ${{ github.event.inputs.create_pr }}\n        run: |\n          if [[ \"$CREATE_PR\" == \"true\" ]]; then\n            # Create a new branch for the PR\n            RANDOM_SUFFIX=$(head -c 4 /dev/urandom | xxd -p)\n            NEW_BRANCH=\"cherry-pick/${SHORT_SHA}-to-${TARGET_BRANCH#release/}-${RANDOM_SUFFIX}\"\n            git checkout -b \"$NEW_BRANCH\" \"origin/$TARGET_BRANCH\"\n            echo \"new_branch=$NEW_BRANCH\" >> $GITHUB_OUTPUT\n          else\n            # Checkout target branch 
directly\n            git checkout \"$TARGET_BRANCH\"\n          fi\n\n          # Attempt cherry-pick\n          if git cherry-pick \"$FULL_SHA\"; then\n            echo \"cherry_pick_success=true\" >> $GITHUB_OUTPUT\n          else\n            echo \"::error::Cherry-pick failed due to conflicts. Please resolve manually.\"\n            git cherry-pick --abort || true\n            echo \"cherry_pick_success=false\" >> $GITHUB_OUTPUT\n            exit 1\n          fi\n\n      - name: Push changes\n        if: steps.cherry_pick.outputs.cherry_pick_success == 'true'\n        env:\n          CREATE_PR: ${{ github.event.inputs.create_pr }}\n          TARGET_BRANCH: ${{ github.event.inputs.target_branch }}\n          NEW_BRANCH: ${{ steps.cherry_pick.outputs.new_branch }}\n        run: |\n          if [[ \"$CREATE_PR\" == \"true\" ]]; then\n            git push origin \"$NEW_BRANCH\"\n          else\n            git push origin \"$TARGET_BRANCH\"\n          fi\n\n      - name: Create Pull Request\n        if: steps.cherry_pick.outputs.cherry_pick_success == 'true' && github.event.inputs.create_pr == 'true'\n        env:\n          GH_TOKEN: ${{ secrets.GH_PAT_FOR_PULL_REQUEST }}\n          TARGET_BRANCH: ${{ github.event.inputs.target_branch }}\n          SHORT_SHA: ${{ steps.commit_info.outputs.short_sha }}\n          COMMIT_TITLE: ${{ steps.commit_info.outputs.commit_title }}\n          FULL_SHA: ${{ steps.commit_info.outputs.full_sha }}\n          NEW_BRANCH: ${{ steps.cherry_pick.outputs.new_branch }}\n        run: |\n          PR_TITLE=\"[Cherry-pick] ${COMMIT_TITLE} to ${TARGET_BRANCH}\"\n\n          gh pr create \\\n            --title \"$PR_TITLE\" \\\n            --base \"$TARGET_BRANCH\" \\\n            --head \"$NEW_BRANCH\" \\\n            --label \"cherry-pick\" \\\n            --body-file - <<EOF\n          Cherry-pick of commit ${FULL_SHA} to \\`${TARGET_BRANCH}\\`\n\n          **Original commit:** ${FULL_SHA}\n          **Original title:** 
${COMMIT_TITLE}\n\n          ---\n          *This PR was automatically created by the cherry-pick workflow.*\n          EOF\n\n      - name: Summary\n        if: always()\n        env:\n          FULL_SHA: ${{ steps.commit_info.outputs.full_sha }}\n          COMMIT_TITLE: ${{ steps.commit_info.outputs.commit_title }}\n          TARGET_BRANCH: ${{ github.event.inputs.target_branch }}\n          CHERRY_PICK_SUCCESS: ${{ steps.cherry_pick.outputs.cherry_pick_success }}\n          CREATE_PR: ${{ github.event.inputs.create_pr }}\n          NEW_BRANCH: ${{ steps.cherry_pick.outputs.new_branch }}\n          ACTOR: ${{ github.actor }}\n        run: |\n          echo \"## Cherry-Pick Summary\" >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"- **Triggered by:** @${ACTOR}\" >> $GITHUB_STEP_SUMMARY\n          echo \"- **Commit:** ${FULL_SHA}\" >> $GITHUB_STEP_SUMMARY\n          echo \"- **Title:** ${COMMIT_TITLE}\" >> $GITHUB_STEP_SUMMARY\n          echo \"- **Target Branch:** ${TARGET_BRANCH}\" >> $GITHUB_STEP_SUMMARY\n          if [[ \"$CHERRY_PICK_SUCCESS\" == \"true\" ]]; then\n            echo \"- **Status:** ✅ Success\" >> $GITHUB_STEP_SUMMARY\n          else\n            echo \"- **Status:** ❌ Failed\" >> $GITHUB_STEP_SUMMARY\n          fi\n          if [[ \"$CREATE_PR\" == \"true\" && \"$CHERRY_PICK_SUCCESS\" == \"true\" ]]; then\n            echo \"- **PR Branch:** ${NEW_BRANCH}\" >> $GITHUB_STEP_SUMMARY\n          fi\n"
  },
  {
    "path": ".github/workflows/cancel-pr-workflow-on-merge.yml",
    "content": "name: Cancel PR Workflows on Merge\n\non:\n  pull_request_target:\n    types:\n      - closed\n\npermissions:\n  actions: write\n\njobs:\n  cancel:\n    if: github.event.pull_request.merged == true\n    runs-on: ubuntu-latest\n    steps:\n      - name: Cancel Previous Runs\n        uses: styfle/cancel-workflow-action@0.12.1\n        with:\n          workflow_id: all\n          access_token: ${{ secrets.GITHUB_TOKEN }}\n          ignore_sha: true\n          pr_number: ${{ github.event.pull_request.number }}\n"
  },
  {
    "path": ".github/workflows/cancel-unfinished-pr-tests.yml",
    "content": "name: Cancel Unfinished PR Runs\n\non:\n  workflow_dispatch:\n    inputs:\n      workflows:\n        description: 'Space-separated list of workflow filenames to cancel'\n        required: true\n        type: string\n        default: 'pr-test.yml'\n\npermissions:\n  actions: write   # Needed to cancel runs\n  contents: read   # Needed to read repo info\n  pull-requests: read  # needed for gh pr view (labels)\n\njobs:\n  cancel-unfinished-pr-runs:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Install GitHub CLI\n        run: sudo apt-get install -y gh jq\n\n      - name: Cancel unfinished PR-associated runs (skip high-priority PRs)\n        env:\n          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          REPO: ${{ github.repository }}\n          WORKFLOWS: ${{ github.event.inputs.workflows || 'pr-test.yml' }}\n        shell: bash\n        run: |\n          set -euo pipefail\n\n          # Read the space-separated string from the input into a bash array\n          read -r -a WORKFLOW_FILES <<< \"${WORKFLOWS}\"\n\n          echo \"Targeting ${#WORKFLOW_FILES[@]} workflow(s): ${WORKFLOWS}\"\n          echo \"\"\n\n          for workflow_file in \"${WORKFLOW_FILES[@]}\"; do\n            echo \"=========================================\"\n            echo \"Workflow: $workflow_file\"\n            echo \"=========================================\"\n\n            # Get all unfinished runs\n            all_runs=$(gh run list \\\n              --repo \"$REPO\" \\\n              --workflow \"$workflow_file\" \\\n              --json databaseId,status,event,url,createdAt \\\n              --limit 1000 \\\n            | jq -c '.[] | select(.status==\"queued\" or .status==\"waiting\" or .status==\"in_progress\")')\n\n            if [ -z \"$all_runs\" ]; then\n              echo \"✅ No unfinished runs found\"\n              echo \"\"\n              continue\n            fi\n\n            # Count runs by event type\n            total_runs=$(echo 
\"$all_runs\" | wc -l)\n            pr_runs=$(echo \"$all_runs\" | jq -s '[.[] | select(.event==\"pull_request\")] | length')\n            other_runs=$(echo \"$all_runs\" | jq -s '[.[] | select(.event!=\"pull_request\")] | length')\n\n            echo \"📊 Summary: $total_runs unfinished runs ($pr_runs PR-related, $other_runs other)\"\n            echo \"\"\n\n            # Process non-PR runs first\n            if [ \"$other_runs\" -gt 0 ]; then\n              echo \"--- Non-PR Runs ---\"\n              echo \"$all_runs\" | jq -c 'select(.event!=\"pull_request\")' | while read -r run; do\n                run_url=$(echo \"$run\" | jq -r '.url')\n                run_event=$(echo \"$run\" | jq -r '.event')\n                run_status=$(echo \"$run\" | jq -r '.status')\n                echo \"  • $run_event ($run_status): $run_url\"\n              done\n              echo \"\"\n            fi\n\n            # Process PR runs\n            if [ \"$pr_runs\" -gt 0 ]; then\n              echo \"--- PR Runs (checking for cancellation) ---\"\n              echo \"$all_runs\" | jq -c 'select(.event==\"pull_request\")' | while read -r run; do\n                run_id=$(echo \"$run\" | jq -r '.databaseId')\n                run_url=$(echo \"$run\" | jq -r '.url')\n                run_status=$(echo \"$run\" | jq -r '.status')\n\n                echo \"\"\n                echo \"Run ($run_status): $run_url\"\n\n                # Fetch full run details to get head repository and branch info\n                run_details=$(gh api -H \"Accept: application/vnd.github+json\" \\\n                  \"repos/$REPO/actions/runs/$run_id\" 2>/dev/null || true)\n\n                if [ -z \"$run_details\" ]; then\n                  echo \"  ⚠️  Could not fetch run details, skipping\"\n                  continue\n                fi\n\n                # Get head owner and branch (works for both fork and non-fork PRs)\n                head_owner=$(echo \"$run_details\" | jq -r 
'.head_repository.owner.login // empty')\n                head_branch=$(echo \"$run_details\" | jq -r '.head_branch // empty')\n\n                if [ -z \"$head_owner\" ] || [ -z \"$head_branch\" ]; then\n                  echo \"  ⚠️  Missing head info, skipping\"\n                  continue\n                fi\n\n                echo \"  Branch: ${head_owner}:${head_branch}\"\n\n                # Find PR by searching with head=owner:branch\n                pr_number=$(gh api -H \"Accept: application/vnd.github+json\" \\\n                  \"repos/$REPO/pulls?state=open&head=${head_owner}:${head_branch}\" \\\n                  --jq '.[0].number // empty' 2>/dev/null || true)\n\n                if [ -z \"$pr_number\" ]; then\n                  echo \"  ⚠️  No open PR found, skipping\"\n                  continue\n                fi\n\n                pr_url=\"https://github.com/$REPO/pull/$pr_number\"\n                echo \"  PR: $pr_url\"\n\n                # Check for high priority label\n                labels=$(gh pr view \"$pr_number\" --repo \"$REPO\" --json labels \\\n                  | jq -r '.labels[].name' 2>/dev/null || true)\n\n                if echo \"$labels\" | grep -Fxq \"high priority\"; then\n                  echo \"  🛑 Skipping (high priority label)\"\n                  continue\n                fi\n\n                echo \"  🚫 Cancelling...\"\n                gh run cancel \"$run_id\" --repo \"$REPO\" || echo \"  ⚠️  Cancellation failed\"\n              done\n            fi\n\n            echo \"\"\n          done\n\n          echo \"=========================================\"\n          echo \"✅ Processing complete\"\n          echo \"=========================================\"\n"
  },
  {
    "path": ".github/workflows/ci-coverage-overview.yml",
    "content": "name: CI Coverage Overview\n\non:\n  schedule:\n    - cron: '0 6 * * *'  # Daily at 6 AM UTC\n  pull_request:\n    paths:\n      - '.github/workflows/ci-coverage-overview.yml'\n      - 'scripts/ci/utils/ci_coverage_report.py'\n      - 'test/registered/**'\n  workflow_dispatch:\n    inputs:\n      output_format:\n        description: 'Output format'\n        required: false\n        default: 'markdown'\n        type: choice\n        options:\n          - markdown\n          - json\n\njobs:\n  summary:\n    name: Summary\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Generate Summary Report\n        run: |\n          python scripts/ci/utils/ci_coverage_report.py --section summary\n\n  by-folder:\n    name: Tests by Folder\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Generate Tests by Folder Report\n        run: |\n          python scripts/ci/utils/ci_coverage_report.py --section by-folder\n\n  by-suite:\n    name: Tests by Suite\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Generate Tests by Suite Report\n        run: |\n          python scripts/ci/utils/ci_coverage_report.py --section by-suite\n\n  unit-test-coverage:\n    name: Unit Test Code Coverage\n    if: github.event_name != 'pull_request'\n    runs-on: 1-gpu-runner\n    timeout-minutes: 30\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install dependencies\n        
timeout-minutes: 10\n        run: |\n          pip install -e \"python/[test]\"\n\n      - name: Run unit tests with coverage\n        timeout-minutes: 10\n        run: |\n          pytest test/registered/unit/ \\\n            --cov --cov-config=.coveragerc \\\n            --cov-report=term-missing:skip-covered \\\n            --continue-on-collection-errors \\\n            -v | tee coverage_output.txt\n\n      - name: Write coverage to summary\n        if: always()\n        run: |\n          echo \"## Unit Test Code Coverage\" >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"**Commit:** \\`${GITHUB_SHA::8}\\` | **Branch:** \\`${GITHUB_REF_NAME}\\`\" >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n\n          # Test result line (e.g., \"== 42 passed, 1 failed in 23.5s ==\")\n          echo '```' >> $GITHUB_STEP_SUMMARY\n          grep -E '^=+.*passed' coverage_output.txt >> $GITHUB_STEP_SUMMARY || true\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          # Coverage total\n          grep -E '^TOTAL ' coverage_output.txt >> $GITHUB_STEP_SUMMARY || true\n          echo '```' >> $GITHUB_STEP_SUMMARY\n\n          # Partially covered core modules (1-49%) — most actionable for contributors\n          # Only show modules with testable logic; skip configs, models, layers, etc.\n          LOW_COV=$(awk '/^python\\/.*%/ {\n            for (i=1; i<=NF; i++) {\n              if ($i ~ /^[0-9]+%$/) {\n                pct = $i + 0\n                if (pct >= 1 && pct < 50) printf \"%-80s %5s  %s\\n\", $1, $(i-2), $i\n                break\n              }\n            }\n          }' coverage_output.txt \\\n            | grep -E '/(mem_cache|managers|sampling|parser|observability|function_call|entrypoints|speculative|multimodal|utils)/' \\\n            | head -40 || true)\n          if [ -n \"$LOW_COV\" ]; then\n            echo \"\" >> $GITHUB_STEP_SUMMARY\n            echo \"<details><summary>Core modules with 
coverage below 50% — good candidates for more unit tests</summary>\" >> $GITHUB_STEP_SUMMARY\n            echo \"\" >> $GITHUB_STEP_SUMMARY\n            echo '```' >> $GITHUB_STEP_SUMMARY\n            echo \"$LOW_COV\" >> $GITHUB_STEP_SUMMARY\n            echo '```' >> $GITHUB_STEP_SUMMARY\n            echo \"</details>\" >> $GITHUB_STEP_SUMMARY\n          fi\n\n  json-export:\n    name: JSON Export\n    runs-on: ubuntu-latest\n    if: inputs.output_format == 'json'\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Generate JSON Report\n        run: |\n          python scripts/ci/utils/ci_coverage_report.py --output-format json > ci_coverage.json\n\n      - name: Upload JSON artifact\n        uses: actions/upload-artifact@v4\n        with:\n          name: ci-coverage-report\n          path: ci_coverage.json\n"
  },
  {
    "path": ".github/workflows/ci-failure-monitor.yml",
    "content": "name: CI Failure Monitor\n\non:\n  schedule:\n    - cron: '0 */12 * * *' # Every 12 hours\n  workflow_dispatch:\n\nconcurrency:\n  group: ci-failure-monitor-${{ github.ref }}\n  cancel-in-progress: true\n\npermissions:\n  contents: read\n  actions: read\n\njobs:\n  failure-analysis:\n    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.14'\n\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          pip install requests slack_sdk\n\n      - name: Run Failure Analysis\n        env:\n          GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }}\n          GH_PAT_FOR_RUNNER_ADMIN: ${{ secrets.GH_PAT_FOR_RUNNER_ADMIN }}\n          PYTHONUNBUFFERED: 1\n          PYTHONIOENCODING: utf-8\n        run: |\n          cd scripts/ci_monitor\n          python ci_failures_analysis.py \\\n            --token $GITHUB_TOKEN \\\n            --limit 100 \\\n            --output ci_failure_analysis_$(date +%Y%m%d_%H%M%S).json\n\n      - name: Upload Analysis Results\n        uses: actions/upload-artifact@v4\n        with:\n          name: ci-failure-analysis-${{ github.run_number }}\n          path: |\n            scripts/ci_monitor/ci_failure_analysis_*.json\n          retention-days: 7\n\n      - name: Send Slack Notification\n        if: always()\n        env:\n          SGLANG_DIFFUSION_SLACK_TOKEN: ${{ secrets.SGLANG_DIFFUSION_SLACK_TOKEN }}\n        run: |\n          cd scripts/ci_monitor\n          LATEST_REPORT=$(ls -t ci_failure_analysis_*.json | head -1)\n\n          if [ ! -f \"$LATEST_REPORT\" ]; then\n            echo \"No report found, so skipping Slack notification\"\n            exit 0\n          fi\n\n          if [ -n \"$SGLANG_DIFFUSION_SLACK_TOKEN\" ]; then\n            python3 post_ci_failures_to_slack.py --report-file \"$LATEST_REPORT\"\n          else\n            echo \"SGLANG_DIFFUSION_SLACK_TOKEN not configured, skipping notification\"\n          fi\n"
  },
  {
    "path": ".github/workflows/close-inactive-issues.yml",
    "content": "name: Close Inactive Issues\n\non:\n  schedule:\n    - cron: '0 0 * * *'\n  workflow_dispatch:\n\npermissions:\n  issues: write\n  contents: read\n\njobs:\n  close-inactive-issues:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: ubuntu-latest\n    steps:\n      - name: Check and close inactive issues\n        uses: actions/github-script@v6\n        with:\n          github-token: ${{secrets.GITHUB_TOKEN}}\n          script: |\n            const sixtyDaysAgo = new Date(Date.now() - 60 * 24 * 60 * 60 * 1000);\n\n            const [owner, repo] = process.env.GITHUB_REPOSITORY.split('/');\n            console.log(`Owner: ${owner}, Repo: ${repo}`);\n\n            async function fetchIssues(page = 1) {\n              console.log(`Fetching issues for ${owner}/${repo}, page ${page}`);\n              return await github.rest.issues.listForRepo({\n                owner,\n                repo,\n                state: 'open',\n                sort: 'updated',\n                direction: 'asc',\n                per_page: 100,\n                page: page\n              });\n            }\n\n            async function processIssues() {\n              console.log('Starting to process issues');\n              console.log(`Repository: ${owner}/${repo}`);\n\n              let page = 1;\n              let hasMoreIssues = true;\n              while (hasMoreIssues) {\n                try {\n                  const issues = await fetchIssues(page);\n                  console.log(`Fetched ${issues.data.length} issues on page ${page}`);\n\n                  if (issues.data.length === 0) {\n                    hasMoreIssues = false;\n                    break;\n                  }\n\n                  for (const issue of issues.data) {\n                    // Skip if the issue has 'good first issue' label\n                    if (issue.labels.some(label => label.name === 'good first issue')) {\n                      console.log(`Skipping issue 
#${issue.number} as it's marked as 'good first issue'`);\n                      continue;\n                    }\n                    if (new Date(issue.updated_at) < sixtyDaysAgo) {\n                      try {\n                        await github.rest.issues.update({\n                          owner,\n                          repo,\n                          issue_number: issue.number,\n                          state: 'closed',\n                          labels: [...issue.labels.map(l => l.name), 'inactive']\n                        });\n                        await github.rest.issues.createComment({\n                          owner,\n                          repo,\n                          issue_number: issue.number,\n                          body: 'This issue has been automatically closed due to inactivity. Please feel free to reopen it if needed.'\n                        });\n                        console.log(`Closed issue #${issue.number} due to inactivity.`);\n                      } catch (error) {\n                        console.error(`Failed to close issue #${issue.number}: ${error.message}`);\n                      }\n                    } else {\n                      console.log(`Issue #${issue.number} is still active. Stopping processing.`);\n                      hasMoreIssues = false;\n                      break;\n                    }\n                  }\n                  page += 1;\n                } catch (error) {\n                  console.error(`Error fetching issues on page ${page}: ${error.message}`);\n                  hasMoreIssues = false;\n                }\n              }\n              console.log('Finished processing issues');\n            }\n\n            await processIssues();\n"
  },
  {
    "path": ".github/workflows/diffusion-ci-gt-gen.yml",
    "content": "name: Diffusion CI Ground Truth Generation\n\non:\n  workflow_dispatch:\n    inputs:\n      ref:\n        description: 'Git ref to checkout'\n        required: false\n        default: ''\n        type: string\n      case_ids:\n        description: 'Specific case IDs to run (space-separated, optional)'\n        required: false\n        default: ''\n        type: string\n\nconcurrency:\n  group: diffusion-ci-gt-gen-${{ github.ref }}\n  cancel-in-progress: true\n\npermissions:\n  contents: write\n  actions: read\n\njobs:\n  multimodal-diffusion-gen-1gpu:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: 1-gpu-runner\n    strategy:\n      matrix:\n        part: [0, 1]\n    timeout-minutes: 150\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/cuda/ci_install_dependency.sh diffusion\n\n      - name: Generate outputs\n        run: |\n          cd python\n          python -m sglang.multimodal_gen.test.scripts.gen_diffusion_ci_outputs \\\n            --suite 1-gpu \\\n            --partition-id ${{ matrix.part }} \\\n            --total-partitions 2 \\\n            --out-dir ./diffusion-ci-outputs \\\n            --continue-on-error \\\n            ${{ inputs.case_ids != '' && format('--case-ids {0}', inputs.case_ids) || '' }}\n\n      - name: Upload artifact\n        uses: actions/upload-artifact@v4\n        with:\n          name: diffusion-gen-1gpu-part${{ matrix.part }}\n          path: python/diffusion-ci-outputs\n          retention-days: 7\n\n  multimodal-diffusion-gen-2gpu:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: 2-gpu-runner\n    strategy:\n      matrix:\n        part: [0, 1]\n    timeout-minutes: 150\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: 
Install dependencies\n        run: bash scripts/ci/cuda/ci_install_dependency.sh diffusion\n\n      - name: Generate outputs\n        run: |\n          cd python\n          python -m sglang.multimodal_gen.test.scripts.gen_diffusion_ci_outputs \\\n            --suite 2-gpu \\\n            --partition-id ${{ matrix.part }} \\\n            --total-partitions 2 \\\n            --out-dir ./diffusion-ci-outputs \\\n            --continue-on-error \\\n            ${{ inputs.case_ids != '' && format('--case-ids {0}', inputs.case_ids) || '' }}\n\n      - name: Upload artifact\n        uses: actions/upload-artifact@v4\n        with:\n          name: diffusion-gen-2gpu-part${{ matrix.part }}\n          path: python/diffusion-ci-outputs\n          retention-days: 7\n\n  diffusion-ci-push:\n    needs: [multimodal-diffusion-gen-1gpu, multimodal-diffusion-gen-2gpu]\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Download artifacts\n        uses: actions/download-artifact@v4\n        with:\n          pattern: diffusion-gen-*\n          path: combined\n          merge-multiple: true\n\n      - name: Collect image files\n        run: |\n          mkdir -p gt_images\n          find combined \\( -name \"*.png\" -o -name \"*.jpg\" -o -name \"*.jpeg\" -o -name \"*.webp\" \\) -type f -exec cp -f {} gt_images/ \\;\n\n      - name: Publish GT images to sglang-bot/sglang-ci-data\n        env:\n          GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }}\n        run: python scripts/ci/utils/publish_diffusion_gt.py --source-dir gt_images\n"
  },
  {
    "path": ".github/workflows/execute-notebook.yml",
    "content": "name: Execute Notebooks\n\non:\n  pull_request:\n    branches: [ main ]\n    types: [opened, synchronize, reopened, labeled]\n    paths:\n      - \"python/sglang/**\"\n      - \"docs/**\"\n      - \"!python/sglang/**/*.md\"\n      - \"!docs/**/*.md\"\n  workflow_dispatch:\n\n\nconcurrency:\n  group: execute-notebook-${{ github.ref }}\n  cancel-in-progress: true\n\nenv:\n  SGLANG_IS_IN_CI: true\n\njobs:\n  call-gate:\n    # Align with PR Test: fail fast if PR doesn't have run-ci label.\n    # This makes /tag-and-rerun-ci work by rerunning this failed workflow.\n    uses: ./.github/workflows/pr-gate.yml\n    secrets: inherit\n\n  run-all-notebooks:\n    needs: [call-gate]\n    runs-on: 1-gpu-runner\n    if: github.event_name != 'pull_request' || needs.call-gate.result == 'success'\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n          pip install -r docs/requirements.txt\n          apt-get update && apt-get install -y pandoc parallel retry\n          ln -sf \"$(which python3)\" /usr/bin/python\n\n      - name: Setup Jupyter Kernel\n        run: |\n          python -m ipykernel install --user --name python3 --display-name \"Python 3\"\n\n      - name: Execute notebooks\n        timeout-minutes: 40\n        run: |\n          cd docs\n          make clean\n          make compile\n\n\n  notebook-finish:\n    needs: [\n      call-gate,\n      run-all-notebooks\n    ]\n    runs-on: ubuntu-latest\n    if: always() && needs.run-all-notebooks.result != 'skipped'\n    steps:\n      - name: Check all dependent job statuses\n        run: |\n          results=(${{ join(needs.*.result, ' ') }})\n          for result in \"${results[@]}\"; do\n            if [ \"$result\" = \"failure\" ] || [ \"$result\" = \"cancelled\" ]; then\n              echo \"Job failed with result: $result\"\n              exit 1\n            fi\n      
    done\n          echo \"All jobs completed successfully\"\n          exit 0\n"
  },
  {
    "path": ".github/workflows/labeler.yml",
    "content": "name: Auto Label PRs\n\non:\n  pull_request_target:\n    types: [opened, synchronize, reopened]\n\npermissions:\n  contents: read\n  pull-requests: write\n\njobs:\n  label:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Auto-label by file changes\n        uses: actions/labeler@v5\n        with:\n          repo-token: \"${{ secrets.GITHUB_TOKEN }}\"\n          configuration-path: .github/labeler.yml\n          sync-labels: false\n"
  },
  {
    "path": ".github/workflows/lint.yml",
    "content": "name: Lint\n\non:\n  push:\n    branches: [main]\n  pull_request:\n    branches: [main]\n\njobs:\n  lint:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v4\n        with:\n          python-version: \"3.12\"\n\n      - name: Install pre-commit hook\n        run: |\n          python -m pip install pre-commit\n          pre-commit install\n\n      - name: Run pre-commit checks\n        run: SKIP=no-commit-to-branch pre-commit run --all-files --show-diff-on-failure\n\n      - name: Run sgl-kernel clang-format checks\n        uses: DoozyX/clang-format-lint-action@v0.20\n        with:\n          source: sgl-kernel\n          extensions: h,c,cpp,hpp,cu,cuh,cc\n          clangFormatVersion: 20\n          style: file\n"
  },
  {
    "path": ".github/workflows/list-active-pr-runs.yml.yml",
    "content": "name: List Active Runs\n\non:\n  workflow_dispatch:\n    inputs:\n      workflows:\n        description: 'Space-separated list of workflow filenames to check'\n        required: false\n        type: string\n        default: 'pr-test.yml'\n\npermissions:\n  actions: read\n  contents: read\n  pull-requests: read\n\njobs:\n  list-active-runs:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Install GitHub CLI\n        run: sudo apt-get install -y gh jq\n\n      - name: List active runs grouped by PR\n        env:\n          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          REPO: ${{ github.repository }}\n          WORKFLOWS: ${{ github.event.inputs.workflows || 'pr-test.yml' }}\n        shell: bash\n        run: |\n          set -euo pipefail\n\n          echo \"=========================================\"\n          echo \"🔍 Active Workflow Runs Report\"\n          echo \"=========================================\"\n          echo \"\"\n\n          # Get all workflows or specific ones\n          read -r -a workflow_files <<< \"${WORKFLOWS}\"\n          echo \"📋 Checking specified workflows: ${WORKFLOWS}\"\n\n          echo \"\"\n\n          # Create a temporary file to store PR data\n          pr_data_file=$(mktemp)\n\n          # Process each workflow\n          for workflow_file in ${workflow_files[@]}; do\n            echo \"Scanning workflow: $workflow_file\"\n\n            # Get all active runs (queued, waiting, in_progress)\n            active_runs=$(gh run list \\\n              --repo \"$REPO\" \\\n              --workflow \"$workflow_file\" \\\n              --json databaseId,status,event,headBranch,createdAt,updatedAt,headSha,number,attempt \\\n              --limit 500 \\\n              | jq -c '.[] | select(.status==\"queued\" or .status==\"waiting\" or .status==\"in_progress\")')\n\n            if [ -z \"$active_runs\" ]; then\n              continue\n            fi\n\n            # Process each run\n            echo \"$active_runs\" 
| while read -r run; do\n              run_id=$(echo \"$run\" | jq -r '.databaseId')\n              run_status=$(echo \"$run\" | jq -r '.status')\n              run_event=$(echo \"$run\" | jq -r '.event')\n              created_at=$(echo \"$run\" | jq -r '.createdAt')\n              head_sha=$(echo \"$run\" | jq -r '.headSha')\n              run_number=$(echo \"$run\" | jq -r '.number')\n              run_attempt=$(echo \"$run\" | jq -r '.attempt // 1')\n\n              # Get detailed run information including jobs\n              run_details=$(gh api \"repos/$REPO/actions/runs/$run_id\" 2>/dev/null || true)\n\n              if [ -z \"$run_details\" ]; then\n                continue\n              fi\n\n              head_owner=$(echo \"$run_details\" | jq -r '.head_repository.owner.login // empty')\n              head_branch=$(echo \"$run_details\" | jq -r '.head_branch // empty')\n\n              if [ -z \"$head_owner\" ] || [ -z \"$head_branch\" ]; then\n                continue\n              fi\n\n              # Find PR number (may be empty for non-PR runs)\n              pr_number=$(gh api \"repos/$REPO/pulls?state=open&head=${head_owner}:${head_branch}\" \\\n                --jq '.[0].number // empty' 2>/dev/null || true)\n\n              if [ -z \"$pr_number\" ]; then\n                pr_number=\"NO_PR\"\n              fi\n\n              # Get jobs for this run (with pagination to avoid missing jobs)\n              jobs=$(gh api \"repos/$REPO/actions/runs/$run_id/jobs\" --paginate --jq '.jobs[]' | jq -s '.')\n\n              running_jobs=$(echo \"$jobs\" | jq '[.[] | select(.status==\"in_progress\")] | length')\n              queued_jobs=$(echo \"$jobs\" | jq '[.[] | select(.status==\"queued\" or .status==\"waiting\")] | length')\n\n              # Get runner info for running jobs\n              runners=$(echo \"$jobs\" | jq -r '.[] | select(.status==\"in_progress\") | .runner_name // \"N/A\"' | paste -sd \",\" -)\n\n              # Calculate queue time\n  
            current_time=$(date -u +%s)\n              created_time=$(date -u -d \"$created_at\" +%s 2>/dev/null || echo \"$current_time\")\n              queue_time=$((current_time - created_time))\n              queue_minutes=$((queue_time / 60))\n\n              # Store data in temporary file (unified format with event and branch)\n              echo \"$pr_number|$workflow_file|$run_id|$run_status|$running_jobs|$queued_jobs|$runners|$queue_minutes|$created_at|$head_sha|$run_attempt|$run_event|$head_branch\" >> \"$pr_data_file\"\n            done\n          done\n\n          echo \"\"\n          echo \"=========================================\"\n          echo \"📊 Active Runs Summary\"\n          echo \"=========================================\"\n          echo \"\"\n\n          if [ ! -s \"$pr_data_file\" ]; then\n            echo \"✅ No active runs found\"\n            rm -f \"$pr_data_file\"\n            exit 0\n          fi\n\n          # Get unique PR numbers (exclude NO_PR entries)\n          pr_numbers=$(cut -d'|' -f1 < \"$pr_data_file\" | grep -v '^NO_PR$' | sort -u || true)\n\n          # Separate high priority and normal PRs\n          high_priority_prs=()\n          normal_prs=()\n\n          for pr_num in $pr_numbers; do\n            labels=$(gh pr view \"$pr_num\" --repo \"$REPO\" --json labels \\\n              | jq -r '.labels[].name' 2>/dev/null || true)\n\n            if echo \"$labels\" | grep -Fxq \"high priority\"; then\n              high_priority_prs+=($pr_num)\n            else\n              normal_prs+=($pr_num)\n            fi\n          done\n\n          # Combine: high priority first, then normal\n          sorted_pr_numbers=(\"${high_priority_prs[@]}\" \"${normal_prs[@]}\")\n\n          pr_count=0\n          total_running=0\n          total_queued=0\n\n          for pr_num in \"${sorted_pr_numbers[@]}\"; do\n            pr_count=$((pr_count + 1))\n\n            # Get PR details\n            pr_info=$(gh pr view \"$pr_num\" --repo 
\"$REPO\" --json title,author,labels,url 2>/dev/null || true)\n\n            if [ -z \"$pr_info\" ]; then\n              continue\n            fi\n\n            pr_title=$(echo \"$pr_info\" | jq -r '.title')\n            pr_author=$(echo \"$pr_info\" | jq -r '.author.login')\n            pr_url=$(echo \"$pr_info\" | jq -r '.url')\n            pr_labels=$(echo \"$pr_info\" | jq -r '.labels[].name' | paste -sd \", \" -)\n\n            if [ -z \"$pr_labels\" ]; then\n              pr_labels=\"(no labels)\"\n            fi\n\n            # Add priority indicator\n            priority_indicator=\"\"\n            if echo \"$pr_labels\" | grep -q \"high priority\"; then\n              priority_indicator=\"🔴 [HIGH PRIORITY] \"\n            fi\n\n            echo \"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\"\n            echo \"🔗 ${priority_indicator}PR #$pr_num: $pr_title\"\n            echo \"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\"\n            echo \"👤 Author: $pr_author\"\n            echo \"🏷️  Labels: $pr_labels\"\n            echo \"🔗 URL: $pr_url\"\n            echo \"\"\n\n            # Get all runs for this PR\n            pr_runs=$(grep \"^$pr_num|\" \"$pr_data_file\")\n\n            pr_running_total=0\n            pr_queued_total=0\n\n            echo \"$pr_runs\" | while read -r line; do\n              workflow=$(echo \"$line\" | cut -d'|' -f2)\n              run_id=$(echo \"$line\" | cut -d'|' -f3)\n              status=$(echo \"$line\" | cut -d'|' -f4)\n              running=$(echo \"$line\" | cut -d'|' -f5)\n              queued=$(echo \"$line\" | cut -d'|' -f6)\n              runners=$(echo \"$line\" | cut -d'|' -f7)\n              queue_min=$(echo \"$line\" | cut -d'|' -f8)\n              created=$(echo \"$line\" | cut -d'|' -f9)\n              attempt=$(echo \"$line\" | cut -d'|' -f11)\n\n              pr_running_total=$((pr_running_total + running))\n              pr_queued_total=$((pr_queued_total + queued))\n\n              
run_url=\"https://github.com/$REPO/actions/runs/$run_id\"\n\n              # Calculate retry count for this specific run\n              retry_count=$((attempt - 1))\n\n              # Show retry indicator\n              retry_indicator=\"\"\n              if [ \"$retry_count\" -gt 0 ]; then\n                retry_indicator=\" 🔄 Retry #$retry_count\"\n              fi\n\n              echo \"  📦 Workflow: $workflow (Run #$run_id)$retry_indicator\"\n              echo \"     Status: $status\"\n              echo \"     🟢 Running jobs: $running\"\n              echo \"     🟡 Queued jobs: $queued\"\n\n              if [ \"$running\" -gt 0 ] && [ \"$runners\" != \"\" ]; then\n                echo \"     🖥️  Runners: $runners\"\n              fi\n\n              if [ \"$queue_min\" -gt 0 ]; then\n                echo \"     ⏱️  Queue time: ${queue_min} minutes\"\n              fi\n\n              echo \"     🔗 Run URL: $run_url\"\n              echo \"\"\n            done\n\n            # Summary for this PR\n            pr_running_total=$(grep \"^$pr_num|\" \"$pr_data_file\" | cut -d'|' -f5 | awk '{sum+=$1} END {print sum+0}')\n            pr_queued_total=$(grep \"^$pr_num|\" \"$pr_data_file\" | cut -d'|' -f6 | awk '{sum+=$1} END {print sum+0}')\n\n            total_running=$((total_running + pr_running_total))\n            total_queued=$((total_queued + pr_queued_total))\n\n            echo \"  📊 PR Total: $pr_running_total running, $pr_queued_total queued\"\n            echo \"\"\n          done\n\n          # --- Non-PR Runs Section ---\n          non_pr_runs=$(grep '^NO_PR|' \"$pr_data_file\" 2>/dev/null || true)\n          non_pr_running=0\n          non_pr_queued=0\n\n          if [ -n \"$non_pr_runs\" ]; then\n            echo \"=========================================\"\n            echo \"📦 Non-PR Runs (manual / scheduled / other)\"\n            echo \"=========================================\"\n            echo \"\"\n\n            echo \"$non_pr_runs\" | 
while read -r line; do\n              workflow=$(echo \"$line\" | cut -d'|' -f2)\n              run_id=$(echo \"$line\" | cut -d'|' -f3)\n              status=$(echo \"$line\" | cut -d'|' -f4)\n              running=$(echo \"$line\" | cut -d'|' -f5)\n              queued=$(echo \"$line\" | cut -d'|' -f6)\n              runners=$(echo \"$line\" | cut -d'|' -f7)\n              queue_min=$(echo \"$line\" | cut -d'|' -f8)\n              created=$(echo \"$line\" | cut -d'|' -f9)\n              attempt=$(echo \"$line\" | cut -d'|' -f11)\n              event=$(echo \"$line\" | cut -d'|' -f12)\n              branch=$(echo \"$line\" | cut -d'|' -f13)\n\n              run_url=\"https://github.com/$REPO/actions/runs/$run_id\"\n\n              retry_count=$((attempt - 1))\n              retry_indicator=\"\"\n              if [ \"$retry_count\" -gt 0 ]; then\n                retry_indicator=\" 🔄 Retry #$retry_count\"\n              fi\n\n              echo \"  📦 Workflow: $workflow (Run #$run_id)$retry_indicator\"\n              echo \"     Event: $event\"\n              echo \"     Branch: $branch\"\n              echo \"     Status: $status\"\n              echo \"     🟢 Running jobs: $running\"\n              echo \"     🟡 Queued jobs: $queued\"\n\n              if [ \"$running\" -gt 0 ] && [ \"$runners\" != \"\" ]; then\n                echo \"     🖥️  Runners: $runners\"\n              fi\n\n              if [ \"$queue_min\" -gt 0 ]; then\n                echo \"     ⏱️  Queue time: ${queue_min} minutes\"\n              fi\n\n              echo \"     🔗 Run URL: $run_url\"\n              echo \"\"\n            done\n\n            non_pr_running=$(echo \"$non_pr_runs\" | cut -d'|' -f5 | awk '{sum+=$1} END {print sum+0}')\n            non_pr_queued=$(echo \"$non_pr_runs\" | cut -d'|' -f6 | awk '{sum+=$1} END {print sum+0}')\n            non_pr_count=$(echo \"$non_pr_runs\" | wc -l | tr -d ' ')\n\n            total_running=$((total_running + non_pr_running))\n            
total_queued=$((total_queued + non_pr_queued))\n\n            echo \"  📊 Non-PR Total: $non_pr_running running, $non_pr_queued queued\"\n            echo \"\"\n          fi\n\n          # Overall summary\n          echo \"=========================================\"\n          echo \"📈 Overall Summary\"\n          echo \"=========================================\"\n          echo \"Total PRs with active runs: $pr_count\"\n          echo \"Total non-PR active runs: ${non_pr_count:-0}\"\n          echo \"Total running jobs: $total_running\"\n          echo \"Total queued jobs: $total_queued\"\n          echo \"=========================================\"\n\n          # Cleanup\n          rm -f \"$pr_data_file\"\n"
  },
  {
    "path": ".github/workflows/nightly-release-gateway.yml",
    "content": "# Nightly release workflow for SGLang Model Gateway\n\nname: Nightly Release SGLang Model Gateway to PyPI\n\non:\n  schedule:\n    # Run at 2 AM UTC every day\n    - cron: '0 2 * * *'\n  workflow_dispatch:  # Allow manual trigger\n\njobs:\n  build:\n    name: build on ${{ matrix.platform || matrix.os }} (${{ matrix.target }} - ${{ matrix.manylinux || 'auto' }})\n    runs-on: ${{ matrix.os }}-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        os: [ubuntu, macos, windows]\n        target: [x86_64, aarch64]\n        manylinux: [auto]\n        include:\n          - os: ubuntu\n            platform: linux\n          - os: windows\n            ls: dir\n            target: x86_64\n            python-architecture: x64\n            interpreter: 3.9 3.10 3.11 3.12 3.13\n          - os: macos\n            target: aarch64\n            interpreter: 3.9 3.10 3.11 3.12 3.13\n          - os: ubuntu\n            platform: linux\n            target: aarch64\n          # musllinux\n          - os: ubuntu\n            platform: linux\n            target: x86_64\n            manylinux: musllinux_1_1\n          - os: ubuntu\n            platform: linux\n            target: aarch64\n            manylinux: musllinux_1_1\n        exclude:\n          - os: windows\n            target: aarch64\n\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          path: sglang-repo\n\n      - name: Move sgl-model-gateway folder to root and delete sglang-repo\n        run: |\n          mv sglang-repo/sgl-model-gateway/* .\n          rm -rf sglang-repo\n          ls -alt\n        shell: bash\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.13\"\n          architecture: ${{ matrix.python-architecture || 'x64' }}\n\n      - name: Modify version for nightly release\n        run: |\n          # Get current version from pyproject.toml\n          CURRENT_VERSION=$(python -c \"import tomllib; 
print(tomllib.load(open('bindings/python/pyproject.toml', 'rb'))['project']['version'])\" 2>/dev/null || python -c \"import tomli; print(tomli.load(open('bindings/python/pyproject.toml', 'rb'))['project']['version'])\")\n          # Create nightly version with date: e.g., 0.2.1.dev20250128\n          NIGHTLY_VERSION=\"${CURRENT_VERSION}.dev$(date +%Y%m%d)\"\n          echo \"Nightly version: $NIGHTLY_VERSION\"\n\n          # Update pyproject.toml with nightly version (temporary, not committed)\n          sed -i.bak \"s/version = \\\"${CURRENT_VERSION}\\\"/version = \\\"${NIGHTLY_VERSION}\\\"/\" bindings/python/pyproject.toml\n\n          # Verify the change\n          cat bindings/python/pyproject.toml | grep \"^version\"\n        shell: bash\n\n      - name: Install twine and tomli\n        run: pip install -U twine tomli\n\n      - name: Install protoc (macOS)\n        if: matrix.os == 'macos'\n        run: brew install protobuf\n\n      - name: Install protoc (Windows)\n        if: matrix.os == 'windows'\n        run: choco install protoc -y\n\n      - name: Build wheels\n        uses: PyO3/maturin-action@v1\n        with:\n          working-directory: bindings/python\n          target: ${{ matrix.target }}\n          manylinux: ${{ matrix.manylinux || 'auto' }}\n          args: --release --out dist --features vendored-openssl --interpreter ${{ matrix.interpreter || '3.9 3.10 3.11 3.12 3.13 3.14' }}\n          rust-toolchain: stable\n          docker-options: -e CI -e CC_aarch64_unknown_linux_gnu=aarch64-linux-gnu-gcc -e CXX_aarch64_unknown_linux_gnu=aarch64-linux-gnu-g++\n          before-script-linux: |\n            # Install build dependencies (perl/make for vendored OpenSSL, protoc for gRPC)\n            if command -v yum &> /dev/null; then\n              yum update -y && yum install -y wget unzip gcc gcc-c++ perl-core make\n              # Install cross-compilation toolchain for aarch64 if needed\n              if [ \"${{ matrix.target }}\" = \"aarch64\" ]; 
then\n                yum install -y gcc-aarch64-linux-gnu gcc-c++-aarch64-linux-gnu || true\n              fi\n            elif command -v apt-get &> /dev/null; then\n              apt-get update && apt-get install -y wget unzip gcc g++ perl make\n              # Install cross-compilation toolchain for aarch64 if needed\n              if [ \"${{ matrix.target }}\" = \"aarch64\" ]; then\n                apt-get install -y gcc-aarch64-linux-gnu g++-aarch64-linux-gnu || true\n              fi\n            fi\n            (cd /tmp && \\\n             wget https://github.com/protocolbuffers/protobuf/releases/download/v32.0/protoc-32.0-linux-x86_64.zip && \\\n             unzip protoc-32.0-linux-x86_64.zip -d /usr/local && \\\n             rm protoc-32.0-linux-x86_64.zip)\n            protoc --version\n\n      - name: List built packages\n        run: ${{ matrix.ls || 'ls -lh' }} bindings/python/dist/\n\n      - name: Check packages\n        run: twine check --strict bindings/python/dist/*\n\n      - uses: actions/upload-artifact@v4\n        with:\n          name: packages-${{ matrix.os }}-${{ matrix.target }}-${{ matrix.manylinux || 'auto' }}\n          path: bindings/python/dist/\n\n  build-sdist:\n    name: Build SDist\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          path: sglang-repo\n\n      - name: Move sgl-model-gateway folder to root and delete sglang-repo\n        run: |\n          mv sglang-repo/sgl-model-gateway/* .\n          rm -rf sglang-repo\n          ls -alt\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.13\"\n\n      - name: Modify version for nightly release\n        run: |\n          # Get current version from pyproject.toml\n          CURRENT_VERSION=$(python -c \"import tomllib; print(tomllib.load(open('bindings/python/pyproject.toml', 'rb'))['project']['version'])\" 2>/dev/null || python -c \"import tomli; 
print(tomli.load(open('bindings/python/pyproject.toml', 'rb'))['project']['version'])\")\n          # Create nightly version with date: e.g., 0.2.1.dev20250128\n          NIGHTLY_VERSION=\"${CURRENT_VERSION}.dev$(date +%Y%m%d)\"\n          echo \"Nightly version: $NIGHTLY_VERSION\"\n\n          # Update pyproject.toml with nightly version (temporary, not committed)\n          sed -i \"s/version = \\\"${CURRENT_VERSION}\\\"/version = \\\"${NIGHTLY_VERSION}\\\"/\" bindings/python/pyproject.toml\n\n          # Verify the change\n          cat bindings/python/pyproject.toml | grep \"^version\"\n\n      - name: Build SDist\n        uses: PyO3/maturin-action@v1\n        with:\n          working-directory: bindings/python\n          command: sdist\n          args: --out dist\n          rust-toolchain: stable\n\n      - uses: actions/upload-artifact@v4\n        with:\n          name: sdist\n          path: bindings/python/dist/*.tar.gz\n\n  upload:\n    name: Upload to TestPyPI\n    if: github.repository == 'sgl-project/sglang'  # Ensure this job only runs for the sgl-project/sglang repository\n    needs: [build, build-sdist]\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/download-artifact@v4\n        with:\n          path: dist\n          merge-multiple: true\n\n      - name: Upload to TestPyPI\n        env:\n          TWINE_USERNAME: __token__\n          TWINE_PASSWORD: ${{ secrets.TEST_PYPI_TOKEN_ROUTER }}\n        run: |\n          pip install twine\n          twine upload --repository testpypi dist/* --verbose\n"
  },
  {
    "path": ".github/workflows/nightly-test-amd-rocm720.yml",
    "content": "name: Nightly Test (AMD ROCm 7.2)\n\non:\n  schedule:\n    - cron: '30 17 * * *'\n  push:\n    branches:\n      - main\n    paths:\n      - \"python/sglang/version.py\"\n  workflow_dispatch:\n    inputs:\n      aiter_ref:\n        description: 'Override AITER commit (optional, leave empty to use Dockerfile default)'\n        required: false\n        type: string\n        default: ''\n      continue_on_error:\n        description: 'Continue on error (do not fail the workflow on test failures)'\n        required: false\n        type: boolean\n        default: true\n      job_select:\n        description: 'Select a job to run from dropdown (choose \"all\" to run all jobs)'\n        required: false\n        type: choice\n        default: 'all'\n        options:\n          - 'all'\n          - nightly-test-1-gpu-unit-rocm720\n          - nightly-accuracy-2-gpu-rocm720\n          - nightly-accuracy-2-gpu-vlm-rocm720\n          - nightly-perf-2-gpu-text-rocm720\n          - nightly-perf-2-gpu-vlm-rocm720\n          - nightly-accuracy-8-gpu-rocm720\n          - nightly-8-gpu-grok1-int4-rocm720\n          - nightly-8-gpu-grok2-rocm720\n          - nightly-8-gpu-deepseek-v31-rocm720\n          - nightly-8-gpu-deepseek-v32-rocm720\n          - nightly-8-gpu-deepseek-v32-mtp-rocm720\n          - nightly-8-gpu-deepseek-v3-kv-fp8-rocm720\n          - nightly-8-gpu-kimi-k25-rocm720\n          - nightly-8-gpu-qwen3-235b-rocm720\n          - nightly-8-gpu-qwen35-rocm720\n          - nightly-8-gpu-glm5-rocm720\n          - nightly-8-gpu-minimax-m25-rocm720\n          - nightly-1-gpu-zimage-turbo-rocm720\n          - nightly-test-1-gpu-mi35x-rocm720\n          - nightly-accuracy-8-gpu-mi35x-rocm720\n          - nightly-8-gpu-mi35x-grok1-int4-rocm720\n          - nightly-8-gpu-mi35x-grok2-rocm720\n          - nightly-8-gpu-mi35x-deepseek-r1-mxfp4-rocm720\n          - nightly-8-gpu-mi35x-deepseek-r1-mxfp4-kv-fp8-rocm720\n          - 
nightly-8-gpu-mi35x-deepseek-r1-mxfp4-ar-fusion-rocm720\n          - nightly-accuracy-8-gpu-mi35x-deepseek-v32-rocm720\n          - nightly-accuracy-8-gpu-mi35x-deepseek-v32-mtp-rocm720\n          - nightly-perf-8-gpu-mi35x-deepseek-v32-basic-rocm720\n          - nightly-perf-8-gpu-mi35x-deepseek-v32-mtp-rocm720\n          - nightly-8-gpu-mi35x-kimi-k25-rocm720\n          - nightly-8-gpu-mi35x-qwen3-235b-mxfp4-rocm720\n          - nightly-8-gpu-mi35x-qwen35-rocm720\n          - nightly-8-gpu-mi35x-glm5-rocm720\n          - nightly-8-gpu-mi35x-minimax-m25-rocm720\n      job_filter:\n        description: 'Or type comma-separated job names (overrides dropdown if non-empty)'\n        required: false\n        type: string\n        default: ''\n  workflow_call:\n    inputs:\n      ref:\n        description: 'Git ref (branch, tag, or SHA) to test. If not provided, uses the default branch.'\n        required: false\n        type: string\n        default: ''\n      aiter_ref:\n        description: 'Override AITER commit (optional, leave empty to use Dockerfile default)'\n        required: false\n        type: string\n        default: ''\n      job_filter:\n        description: 'Select which job to run (leave empty or \"all\" to run all jobs)'\n        required: false\n        type: string\n        default: 'all'\n      continue_on_error:\n        description: 'Continue on error (do not fail the workflow on test failures)'\n        required: false\n        type: boolean\n        default: true\n\nenv:\n  AITER_COMMIT_OVERRIDE: ${{ inputs.aiter_ref }}\n\nconcurrency:\n  # When called via workflow_call with ref set, use a unique group per caller run to avoid\n  # collisions with direct schedule/push triggers. 
We use inputs.ref (not github.event_name)\n  # to detect this, because github.event_name inherits from the caller in workflow_call.\n  group: nightly-test-amd-rocm720-${{ inputs.ref && format('caller-{0}', github.run_id) || github.ref }}\n  cancel-in-progress: ${{ !inputs.ref && github.event_name != 'workflow_call' }}\n\njobs:\n  # ============================================== MI30x ROCm 7.2 Unit Tests ==============================================\n  # 1-GPU Unit Tests - LoRA, debug utils, scheduler, etc. (MI30x ROCm 7.2)\n  nightly-test-1-gpu-unit-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-test-1-gpu-unit-rocm720,'))\n    runs-on: linux-mi325-1gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Nightly Unit Test ROCm 7.2 (1-GPU)\n        timeout-minutes: 90\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-1-gpu --nightly --timeout-per-file 900 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # ============================================== MI30x ROCm 7.2 
Accuracy Tests ==============================================\n  # 2-GPU Accuracy Tests - GSM8K eval (MI30x ROCm 7.2)\n  nightly-accuracy-2-gpu-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-accuracy-2-gpu-rocm720,'))\n    runs-on: linux-mi325-2gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Nightly Test ROCm 7.2 (2-GPU)\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 2-GPU VLM Accuracy Tests - Vision-Language Models MMMU evaluation (ROCm 7.2)\n  nightly-accuracy-2-gpu-vlm-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-accuracy-2-gpu-vlm-rocm720,'))\n    runs-on: linux-mi325-2gpu-sglang\n    
steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Nightly Accuracy Test ROCm 7.2 (2-GPU VLM MMMU)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-2-gpu-vlm --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 2-GPU Text Models Performance Tests (ROCm 7.2)\n  nightly-perf-2-gpu-text-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-perf-2-gpu-text-rocm720,'))\n    runs-on: linux-mi325-2gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash 
scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Performance Test ROCm 7.2 (2-GPU Text Models)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e SGLANG_USE_AITER=1 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-perf-text-2-gpu --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 2-GPU VLM Performance Tests (ROCm 7.2)\n  nightly-perf-2-gpu-vlm-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-perf-2-gpu-vlm-rocm720,'))\n    runs-on: linux-mi325-2gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Performance Test ROCm 7.2 (2-GPU VLM Models)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e SGLANG_USE_AITER=1 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite 
nightly-amd-perf-vlm-2-gpu --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU Accuracy Tests - GPT-OSS, Grok1-FP8 (ROCm 7.2)\n  nightly-accuracy-8-gpu-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-accuracy-8-gpu-rocm720,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n\n      - name: Accuracy Test ROCm 7.2 (8-GPU GPT-OSS)\n        timeout-minutes: 180\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-gpt-oss --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Accuracy Test ROCm 7.2 (8-GPU Grok1-FP8)\n        timeout-minutes: 60\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e 
GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-grok1-fp8 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # ============================================== MI30x ROCm 7.2 Combined Accuracy + Performance Tests ==============================================\n  # 8-GPU Grok1-INT4 (Accuracy + Performance) ROCm 7.2\n  nightly-8-gpu-grok1-int4-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-grok1-int4-rocm720,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n\n      - name: Accuracy Test ROCm 7.2 (8-GPU Grok1-INT4)\n        timeout-minutes: 60\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-grok1-int4 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} 
|| TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test ROCm 7.2 (8-GPU Grok1-INT4)\n        timeout-minutes: 60\n        continue-on-error: true\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-grok1-int4 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU Grok2 (Accuracy + Performance) ROCm 7.2\n  nightly-8-gpu-grok2-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-grok2-rocm720,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n\n      - name: Accuracy Test ROCm 7.2 (8-GPU Grok2)\n        timeout-minutes: 60\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e 
RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-grok2 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test ROCm 7.2 (8-GPU Grok2)\n        timeout-minutes: 60\n        continue-on-error: true\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-grok2 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU DeepSeek-V3.1 (Accuracy + Performance) ROCm 7.2\n  nightly-8-gpu-deepseek-v31-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-deepseek-v31-rocm720,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash 
scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n\n      - name: Accuracy Test ROCm 7.2 (8-GPU DeepSeek-V3.1)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e SGLANG_USE_AITER=1 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-deepseek-v31 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test ROCm 7.2 (8-GPU DeepSeek-V3.1)\n        timeout-minutes: 300\n        continue-on-error: true\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e SGLANG_USE_ROCM700A=1 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-deepseek-v31 --nightly --timeout-per-file 18000 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU DeepSeek-V3.2 (Basic Accuracy + Perf) ROCm 7.2\n  nightly-8-gpu-deepseek-v32-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-deepseek-v32-rocm720,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ 
inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n\n      - name: Accuracy Test ROCm 7.2 (8-GPU DeepSeek-V3.2 Basic)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-deepseek-v32 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test ROCm 7.2 (8-GPU DeepSeek-V3.2 Basic)\n        timeout-minutes: 150\n        continue-on-error: true\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-deepseek-v32-basic --nightly --timeout-per-file 5400 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU DeepSeek-V3.2 MTP (MTP Accuracy + Perf) ROCm 7.2\n  nightly-8-gpu-deepseek-v32-mtp-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' 
|| contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-deepseek-v32-mtp-rocm720,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n\n      - name: Accuracy Test ROCm 7.2 (8-GPU DeepSeek-V3.2 MTP)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-deepseek-v32-mtp --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test ROCm 7.2 (8-GPU DeepSeek-V3.2 MTP)\n        timeout-minutes: 180\n        continue-on-error: true\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-deepseek-v32-mtp --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU DeepSeek-V3 KV FP8 
(Basic + MTP with --kv-cache-dtype fp8_e4m3) ROCm 7.2\n  nightly-8-gpu-deepseek-v3-kv-fp8-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-deepseek-v3-kv-fp8-rocm720,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n\n      - name: DeepSeek-V3 KV FP8 Test ROCm 7.2 (8-GPU Basic + MTP)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-deepseek-v3-kv-fp8 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU Kimi-K2.5 (Accuracy) ROCm 7.2\n  nightly-8-gpu-kimi-k25-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-kimi-k25-rocm720,'))\n    runs-on: linux-mi325-8gpu-sglang\n  
  steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n\n      - name: Accuracy Test ROCm 7.2 (8-GPU Kimi-K2.5)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-kimi-k25 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU Qwen3-235B (Accuracy + Performance) ROCm 7.2\n  nightly-8-gpu-qwen3-235b-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-qwen3-235b-rocm720,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash 
scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n\n      - name: Accuracy Test + Performance Test ROCm 7.2 (8-GPU Qwen3)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-8-gpu-qwen3-235b --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU Qwen 3.5 (Accuracy) ROCm 7.2\n  nightly-8-gpu-qwen35-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-qwen35-rocm720,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-aiter-build --skip-test-time-deps\n          bash scripts/ci/amd/amd_ci_exec.sh pip install mistral-common \"lm-eval[api]\"\n\n      - name: Accuracy Test ROCm 7.2 (8-GPU Qwen 3.5)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e 
GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-qwen35 --nightly --timeout-per-file 3600 --continue-on-error || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU GLM-5 (Accuracy) ROCm 7.2\n  nightly-8-gpu-glm5-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-glm5-rocm720,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          bash scripts/ci/amd/amd_ci_exec.sh pip install git+https://github.com/huggingface/transformers.git@96f807a33b75\n\n      - name: Accuracy Test ROCm 7.2 (8-GPU GLM-5 NSA)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-glm5 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  
# 8-GPU MiniMax-M2.5 (Accuracy) ROCm 7.2\n  nightly-8-gpu-minimax-m25-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-minimax-m25-rocm720,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n\n      - name: Accuracy Test ROCm 7.2 (8-GPU MiniMax-M2.5)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e SGLANG_USE_AITER=1 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-minimax-m25 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # ============================================== MI30x ROCm 7.2 Diffusion Tests ==============================================\n  # 1-GPU Z-Image-Turbo (Diffusion T2I) ROCm 7.2\n  nightly-1-gpu-zimage-turbo-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' 
|| contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-1-gpu-zimage-turbo-rocm720,'))\n    runs-on: linux-mi325-1gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Z-Image-Turbo Diffusion Test ROCm 7.2 (1-GPU)\n        timeout-minutes: 45\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            -e SGLANG_DIFFUSION_ARTIFACT_DIR=\"/sglang-checkout/diffusion-artifacts\" \\\n            pytest test/registered/amd/test_zimage_turbo.py -v -s ${{ inputs.continue_on_error && '|| true' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Upload generated images\n        if: always()\n        uses: actions/upload-artifact@v4\n        with:\n          name: zimage-turbo-outputs-rocm720\n          path: diffusion-artifacts/\n          if-no-files-found: ignore\n          retention-days: 30\n\n  # ============================================== MI35x ROCm 7.2 Tests ==============================================\n  # MI35x 1-GPU ROCm 7.2 tests\n  nightly-test-1-gpu-mi35x-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-test-1-gpu-mi35x-rocm720,'))\n    runs-on: 
linux-mi35x-gpu-1\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Nightly Test MI35x ROCm 7.2 (1-GPU)\n        timeout-minutes: 90\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-1-gpu-mi35x --nightly --timeout-per-file 900 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU Accuracy Tests - GPT-OSS (ROCm 7.2)\n  nightly-accuracy-8-gpu-mi35x-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-accuracy-8-gpu-mi35x-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh 
--skip-test-time-deps\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x ROCm 7.2 (8-GPU GPT-OSS)\n        timeout-minutes: 180\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-mi35x --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU Grok1-INT4 (Accuracy + Performance) ROCm 7.2\n  nightly-8-gpu-mi35x-grok1-int4-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-grok1-int4-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x ROCm 7.2 (8-GPU Grok1-INT4)\n        timeout-minutes: 60\n        run: |\n          > github_summary.md  # Clear summary file\n     
     bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-mi35x-grok1-int4 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test MI35x ROCm 7.2 (8-GPU Grok1-INT4)\n        timeout-minutes: 60\n        continue-on-error: true\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-mi35x-grok1-int4 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU Grok2 (Accuracy + Performance) ROCm 7.2\n  nightly-8-gpu-mi35x-grok2-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-grok2-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: 
${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x ROCm 7.2 (8-GPU Grok2)\n        timeout-minutes: 60\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-mi35x-grok2 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test MI35x ROCm 7.2 (8-GPU Grok2)\n        timeout-minutes: 60\n        continue-on-error: true\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-mi35x-grok2 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-R1-MXFP4 (Accuracy + Performance) ROCm 7.2\n  nightly-8-gpu-mi35x-deepseek-r1-mxfp4-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || 
contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-deepseek-r1-mxfp4-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x ROCm 7.2 (8-GPU DeepSeek-R1-MXFP4)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-mi35x-deepseek-r1-mxfp4 --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test MI35x ROCm 7.2 (8-GPU DeepSeek-R1-MXFP4)\n        timeout-minutes: 300\n        continue-on-error: true\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 registered/amd/perf/mi35x/test_deepseek_r1_mxfp4_perf_mi35x.py || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n       
   exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-R1-MXFP4 KV FP8 (Accuracy + Performance) ROCm 7.2\n  nightly-8-gpu-mi35x-deepseek-r1-mxfp4-kv-fp8-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-deepseek-r1-mxfp4-kv-fp8-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x ROCm 7.2 (8-GPU DeepSeek-R1-MXFP4 KV FP8)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-mi35x-deepseek-r1-mxfp4-kv-fp8 --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test MI35x ROCm 7.2 (8-GPU DeepSeek-R1-MXFP4 KV FP8)\n        timeout-minutes: 300\n        continue-on-error: true\n        run: |\n      
    > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 registered/amd/perf/mi35x/test_deepseek_r1_mxfp4_kv_fp8_perf_mi35x.py || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-R1-MXFP4 AllReduce Fusion (Accuracy + Performance) ROCm 7.2\n  nightly-8-gpu-mi35x-deepseek-r1-mxfp4-ar-fusion-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-deepseek-r1-mxfp4-ar-fusion-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x ROCm 7.2 (8-GPU DeepSeek-R1-MXFP4 AllReduce Fusion)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite 
nightly-amd-8-gpu-mi35x-deepseek-r1-mxfp4-ar-fusion --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test MI35x ROCm 7.2 (8-GPU DeepSeek-R1-MXFP4 AllReduce Fusion)\n        timeout-minutes: 300\n        continue-on-error: true\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 registered/amd/perf/mi35x/test_deepseek_r1_mxfp4_ar_fusion_perf_mi35x.py || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-V3.2 Accuracy Test (ROCm 7.2)\n  nightly-accuracy-8-gpu-mi35x-deepseek-v32-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-accuracy-8-gpu-mi35x-deepseek-v32-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install 
tabulate\n\n      - name: Accuracy Test MI35x ROCm 7.2 (8-GPU DeepSeek-V3.2)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-mi35x-deepseek-v32 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-V3.2 TP+MTP Accuracy Test (ROCm 7.2)\n  nightly-accuracy-8-gpu-mi35x-deepseek-v32-mtp-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-accuracy-8-gpu-mi35x-deepseek-v32-mtp-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x ROCm 7.2 (8-GPU DeepSeek-V3.2 TP+MTP)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh 
-w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-mi35x-deepseek-v32-mtp --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-V3.2 Performance Test (Basic) ROCm 7.2\n  nightly-perf-8-gpu-mi35x-deepseek-v32-basic-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-perf-8-gpu-mi35x-deepseek-v32-basic-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Performance Test MI35x ROCm 7.2 (8-GPU DeepSeek-V3.2 Basic)\n        timeout-minutes: 150\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-mi35x-deepseek-v32-basic --nightly 
--timeout-per-file 5400 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU Kimi-K2.5 (Accuracy) ROCm 7.2\n  nightly-8-gpu-mi35x-kimi-k25-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-kimi-k25-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x ROCm 7.2 (8-GPU Kimi-K2.5)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-mi35x-kimi-k25 --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU Qwen3-235B-MXFP4 (Accuracy + 
Performance) ROCm 7.2\n  nightly-8-gpu-mi35x-qwen3-235b-mxfp4-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-qwen3-235b-mxfp4-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test + Performance Test MI35x ROCm 7.2 (8-GPU Qwen3-235B-MXFP4)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-8-gpu-mi35x-qwen3-235b-mxfp4 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU Qwen 3.5 (Accuracy) ROCm 7.2\n  nightly-8-gpu-mi35x-qwen35-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) 
== 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-qwen35-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-aiter-build --skip-test-time-deps\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n          bash scripts/ci/amd/amd_ci_exec.sh pip install mistral-common \"lm-eval[api]\"\n\n      - name: Accuracy Test MI35x ROCm 7.2 (8-GPU Qwen 3.5)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-mi35x-qwen35 --nightly --timeout-per-file 3600 --continue-on-error || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  nightly-8-gpu-mi35x-glm5-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-glm5-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: 
|\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n          bash scripts/ci/amd/amd_ci_exec.sh pip install git+https://github.com/huggingface/transformers.git@96f807a33b75\n\n      - name: Accuracy Test MI35x ROCm 7.2 (8-GPU GLM-5 NSA)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-mi35x-glm5 --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU MiniMax-M2.5 (Accuracy) ROCm 7.2\n  nightly-8-gpu-mi35x-minimax-m25-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-minimax-m25-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: 
${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x ROCm 7.2 (8-GPU MiniMax-M2.5)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e SGLANG_USE_AITER=1 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-mi35x-minimax-m25 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-V3.2 Performance Test (MTP) ROCm 7.2\n  nightly-perf-8-gpu-mi35x-deepseek-v32-mtp-rocm720:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-perf-8-gpu-mi35x-deepseek-v32-mtp-rocm720,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker (ROCm 7.2)\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh --skip-test-time-deps\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash 
scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Performance Test MI35x ROCm 7.2 (8-GPU DeepSeek-V3.2 MTP)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-mi35x-deepseek-v32-mtp --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  check-all-jobs:\n    if: always() && (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' || github.event_name == 'workflow_dispatch')\n    needs:\n      # MI30x ROCm 7.2 Unit Tests\n      - nightly-test-1-gpu-unit-rocm720\n      # MI30x ROCm 7.2 Accuracy Tests\n      - nightly-accuracy-2-gpu-rocm720\n      - nightly-accuracy-2-gpu-vlm-rocm720\n      # MI30x ROCm 7.2 Performance Tests\n      - nightly-perf-2-gpu-text-rocm720\n      - nightly-perf-2-gpu-vlm-rocm720\n      - nightly-accuracy-8-gpu-rocm720\n      # MI30x ROCm 7.2 Combined Accuracy + Performance Tests\n      - nightly-8-gpu-grok1-int4-rocm720\n      - nightly-8-gpu-grok2-rocm720\n      - nightly-8-gpu-deepseek-v31-rocm720\n      - nightly-8-gpu-deepseek-v32-rocm720\n      - nightly-8-gpu-deepseek-v32-mtp-rocm720\n      - nightly-8-gpu-deepseek-v3-kv-fp8-rocm720\n      - nightly-8-gpu-kimi-k25-rocm720\n      - nightly-8-gpu-qwen3-235b-rocm720\n      - nightly-8-gpu-qwen35-rocm720\n      - nightly-8-gpu-glm5-rocm720\n      - nightly-8-gpu-minimax-m25-rocm720\n      # MI30x ROCm 7.2 Diffusion Tests\n      - nightly-1-gpu-zimage-turbo-rocm720\n      # MI35x ROCm 7.2 jobs\n      - nightly-test-1-gpu-mi35x-rocm720\n      - nightly-accuracy-8-gpu-mi35x-rocm720\n      - 
nightly-8-gpu-mi35x-grok1-int4-rocm720\n      - nightly-8-gpu-mi35x-grok2-rocm720\n      - nightly-8-gpu-mi35x-deepseek-r1-mxfp4-rocm720\n      - nightly-8-gpu-mi35x-deepseek-r1-mxfp4-kv-fp8-rocm720\n      - nightly-8-gpu-mi35x-deepseek-r1-mxfp4-ar-fusion-rocm720\n      - nightly-accuracy-8-gpu-mi35x-deepseek-v32-rocm720\n      - nightly-accuracy-8-gpu-mi35x-deepseek-v32-mtp-rocm720\n      - nightly-perf-8-gpu-mi35x-deepseek-v32-basic-rocm720\n      - nightly-perf-8-gpu-mi35x-deepseek-v32-mtp-rocm720\n      - nightly-8-gpu-mi35x-kimi-k25-rocm720\n      - nightly-8-gpu-mi35x-qwen3-235b-mxfp4-rocm720\n      - nightly-8-gpu-mi35x-qwen35-rocm720\n      - nightly-8-gpu-mi35x-glm5-rocm720\n      - nightly-8-gpu-mi35x-minimax-m25-rocm720\n    runs-on: ubuntu-latest\n    steps:\n      - name: Check if any job failed\n        run: |\n          if [[ \"${{ contains(needs.*.result, 'failure') }}\" == \"true\" ]]; then\n            echo \"One or more ROCm 7.2 nightly test jobs failed\"\n            exit 1\n          fi\n          if [[ \"${{ contains(needs.*.result, 'cancelled') }}\" == \"true\" ]]; then\n            echo \"One or more ROCm 7.2 nightly test jobs were cancelled\"\n            exit 1\n          fi\n          echo \"All ROCm 7.2 nightly test jobs passed\"\n"
  },
  {
    "path": ".github/workflows/nightly-test-amd.yml",
    "content": "name: Nightly Test (AMD)\n\non:\n  schedule:\n    - cron: '30 17 * * *'\n  push:\n    branches:\n      - main\n    paths:\n      - \"python/sglang/version.py\"\n  workflow_dispatch:\n    inputs:\n      aiter_ref:\n        description: 'Override AITER commit (optional, leave empty to use Dockerfile default)'\n        required: false\n        type: string\n        default: ''\n      continue_on_error:\n        description: 'Continue on error (do not fail the workflow on test failures)'\n        required: false\n        type: boolean\n        default: true\n      job_select:\n        description: 'Select a job to run from dropdown (choose \"all\" to run all jobs)'\n        required: false\n        type: choice\n        default: 'all'\n        options:\n          - 'all'\n          - nightly-test-1-gpu-unit\n          - nightly-accuracy-2-gpu\n          - nightly-accuracy-2-gpu-vlm\n          - nightly-perf-2-gpu-text\n          - nightly-perf-2-gpu-vlm\n          - nightly-accuracy-8-gpu\n          - nightly-8-gpu-grok1-int4\n          - nightly-8-gpu-grok2\n          - nightly-8-gpu-deepseek-v31\n          - nightly-8-gpu-deepseek-v32\n          - nightly-8-gpu-deepseek-v32-mtp\n          - nightly-8-gpu-deepseek-v3-kv-fp8\n          - nightly-8-gpu-kimi-k25\n          - nightly-8-gpu-qwen3-235b\n          - nightly-8-gpu-qwen35\n          - nightly-8-gpu-glm5\n          - nightly-8-gpu-minimax-m25\n          - nightly-1-gpu-zimage-turbo\n          - nightly-test-1-gpu-mi35x\n          - nightly-accuracy-8-gpu-mi35x\n          - nightly-8-gpu-mi35x-grok1-int4\n          - nightly-8-gpu-mi35x-grok2\n          - nightly-8-gpu-mi35x-deepseek-r1-mxfp4\n          - nightly-8-gpu-mi35x-deepseek-r1-mxfp4-kv-fp8\n          - nightly-8-gpu-mi35x-deepseek-r1-mxfp4-ar-fusion\n          - nightly-accuracy-8-gpu-mi35x-deepseek-v32\n          - nightly-accuracy-8-gpu-mi35x-deepseek-v32-mtp\n          - nightly-perf-8-gpu-mi35x-deepseek-v32-basic\n          - 
nightly-perf-8-gpu-mi35x-deepseek-v32-mtp\n          - nightly-8-gpu-mi35x-kimi-k25\n          - nightly-8-gpu-mi35x-qwen3-235b-mxfp4\n          - nightly-8-gpu-mi35x-qwen35\n          - nightly-8-gpu-mi35x-glm5\n          - nightly-8-gpu-mi35x-minimax-m25\n      job_filter:\n        description: 'Or type comma-separated job names (overrides dropdown if non-empty)'\n        required: false\n        type: string\n        default: ''\n  workflow_call:\n    inputs:\n      ref:\n        description: 'Git ref (branch, tag, or SHA) to test. If not provided, uses the default branch.'\n        required: false\n        type: string\n        default: ''\n      aiter_ref:\n        description: 'Override AITER commit (optional, leave empty to use Dockerfile default)'\n        required: false\n        type: string\n        default: ''\n      job_filter:\n        description: 'Select which job to run (leave empty or \"all\" to run all jobs)'\n        required: false\n        type: string\n        default: 'all'\n      continue_on_error:\n        description: 'Continue on error (do not fail the workflow on test failures)'\n        required: false\n        type: boolean\n        default: true\n\nenv:\n  AITER_COMMIT_OVERRIDE: ${{ inputs.aiter_ref }}\n\nconcurrency:\n  # When called via workflow_call with ref set, use a unique group per caller run to avoid\n  # collisions with direct schedule/push triggers. We use inputs.ref (not github.event_name)\n  # to detect this, because github.event_name inherits from the caller in workflow_call.\n  group: nightly-test-amd-${{ inputs.ref && format('caller-{0}', github.run_id) || github.ref }}\n  cancel-in-progress: ${{ !inputs.ref && github.event_name != 'workflow_call' }}\n\njobs:\n  # ============================================== MI30x Unit Tests ==============================================\n  # 1-GPU Unit Tests - LoRA, debug utils, scheduler, etc. 
(MI30x only)\n  nightly-test-1-gpu-unit:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-test-1-gpu-unit,'))\n    runs-on: linux-mi325-1gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Nightly Unit Test (1-GPU)\n        timeout-minutes: 90\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-1-gpu --nightly --timeout-per-file 900 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # ============================================== MI30x Accuracy Tests ==============================================\n  # 2-GPU Accuracy Tests - GSM8K eval (MI30x only)\n  nightly-accuracy-2-gpu:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-accuracy-2-gpu,'))\n    runs-on: linux-mi325-2gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ 
inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Nightly Test (2-GPU)\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 2-GPU VLM Accuracy Tests - Vision-Language Models MMMU evaluation\n  nightly-accuracy-2-gpu-vlm:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-accuracy-2-gpu-vlm,'))\n    runs-on: linux-mi325-2gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Nightly Accuracy Test (2-GPU VLM MMMU)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n    
        -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-2-gpu-vlm --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 2-GPU Text Models Performance Tests\n  nightly-perf-2-gpu-text:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-perf-2-gpu-text,'))\n    runs-on: linux-mi325-2gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Performance Test (2-GPU Text Models)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e SGLANG_USE_AITER=1 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-perf-text-2-gpu --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 2-GPU VLM Performance Tests\n  nightly-perf-2-gpu-vlm:\n    if: (github.repository == 'sgl-project/sglang' || 
github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-perf-2-gpu-vlm,'))\n    runs-on: linux-mi325-2gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Performance Test (2-GPU VLM Models)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e SGLANG_USE_AITER=1 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-perf-vlm-2-gpu --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU Accuracy Tests - GPT-OSS, Grok1-FP8 (accuracy only)\n  nightly-accuracy-8-gpu:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-accuracy-8-gpu,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch 
github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Accuracy Test (8-GPU GPT-OSS)\n        timeout-minutes: 180\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-gpt-oss --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Accuracy Test (8-GPU Grok1-FP8)\n        timeout-minutes: 60\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-grok1-fp8 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # ============================================== MI30x Combined Accuracy + Performance Tests ==============================================\n  # 8-GPU Grok1-INT4 (Accuracy + Performance combined)\n  nightly-8-gpu-grok1-int4:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-grok1-int4,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: 
Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Accuracy Test (8-GPU Grok1-INT4)\n        timeout-minutes: 60\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-grok1-int4 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test (8-GPU Grok1-INT4)\n        timeout-minutes: 60\n        continue-on-error: true  # Perf test failure doesn't fail the job if accuracy passed\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-grok1-int4 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU Grok2 (Accuracy + Performance combined)\n  nightly-8-gpu-grok2:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && 
(!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-grok2,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Accuracy Test (8-GPU Grok2)\n        timeout-minutes: 60\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-grok2 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test (8-GPU Grok2)\n        timeout-minutes: 60\n        continue-on-error: true  # Perf test failure doesn't fail the job if accuracy passed\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-grok2 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> 
$GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU DeepSeek-V3.1 (Accuracy + Performance combined)\n  nightly-8-gpu-deepseek-v31:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-deepseek-v31,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Accuracy Test (8-GPU DeepSeek-V3.1)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e SGLANG_USE_AITER=1 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-deepseek-v31 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test (8-GPU DeepSeek-V3.1)\n        timeout-minutes: 300\n        continue-on-error: true  # Perf test failure doesn't fail the job if accuracy passed\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e SGLANG_USE_ROCM700A=1 \\\n            -e 
GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-deepseek-v31 --nightly --timeout-per-file 18000 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU DeepSeek-V3.2 (Basic Accuracy + Perf)\n  nightly-8-gpu-deepseek-v32:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-deepseek-v32,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Accuracy Test (8-GPU DeepSeek-V3.2 Basic)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-deepseek-v32 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test (8-GPU DeepSeek-V3.2 Basic)\n        timeout-minutes: 150\n        continue-on-error: true  # Perf 
test failure doesn't fail the job if accuracy passed\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-deepseek-v32-basic --nightly --timeout-per-file 5400 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU DeepSeek-V3.2 MTP (MTP Accuracy + Perf)\n  nightly-8-gpu-deepseek-v32-mtp:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-deepseek-v32-mtp,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Accuracy Test (8-GPU DeepSeek-V3.2 MTP)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-deepseek-v32-mtp --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo 
\"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test (8-GPU DeepSeek-V3.2 MTP)\n        timeout-minutes: 180\n        continue-on-error: true  # Perf test failure doesn't fail the job if accuracy passed\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-deepseek-v32-mtp --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU DeepSeek-V3 KV FP8 (Basic + MTP with --kv-cache-dtype fp8_e4m3)\n  nightly-8-gpu-deepseek-v3-kv-fp8:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-deepseek-v3-kv-fp8,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: DeepSeek-V3 KV FP8 Test (8-GPU Basic + MTP)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e 
GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-deepseek-v3-kv-fp8 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU Kimi-K2.5 (Accuracy)\n  nightly-8-gpu-kimi-k25:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-kimi-k25,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Accuracy Test (8-GPU Kimi-K2.5)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-kimi-k25 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  nightly-8-gpu-qwen3-235b:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) 
|| (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-qwen3-235b,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Accuracy Test + Performance Test (8-GPU Qwen3)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-8-gpu-qwen3-235b --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU Qwen 3.5 (Accuracy)\n  nightly-8-gpu-qwen35:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-qwen35,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - 
name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          bash scripts/ci/amd/amd_ci_exec.sh pip install mistral-common \"lm-eval[api]\"\n\n      - name: Accuracy Test (8-GPU Qwen 3.5)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-qwen35 --nightly --timeout-per-file 3600 || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  nightly-8-gpu-glm5:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-glm5,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          bash scripts/ci/amd/amd_ci_exec.sh pip install git+https://github.com/huggingface/transformers.git@96f807a33b75\n\n      - name: Accuracy Test (8-GPU GLM-5 NSA)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 
run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-glm5 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # 8-GPU MiniMax-M2.5 (Accuracy)\n  nightly-8-gpu-minimax-m25:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-minimax-m25,'))\n    runs-on: linux-mi325-8gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Accuracy Test (8-GPU MiniMax-M2.5)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e SGLANG_USE_AITER=1 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-minimax-m25 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # ============================================== MI30x Diffusion Tests ==============================================\n  # 1-GPU Z-Image-Turbo (Diffusion T2I)\n  nightly-1-gpu-zimage-turbo:\n    if: 
(github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-1-gpu-zimage-turbo,'))\n    runs-on: linux-mi325-1gpu-sglang\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Z-Image-Turbo Diffusion Test (1-GPU)\n        timeout-minutes: 45\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            -e SGLANG_DIFFUSION_ARTIFACT_DIR=\"/sglang-checkout/diffusion-artifacts\" \\\n            pytest test/registered/amd/test_zimage_turbo.py -v -s ${{ inputs.continue_on_error && '|| true' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Upload generated images\n        if: always()\n        uses: actions/upload-artifact@v4\n        with:\n          name: zimage-turbo-outputs\n          path: diffusion-artifacts/\n          if-no-files-found: ignore\n          retention-days: 30\n\n  # ============================================== MI35x Tests ==============================================\n  # MI35x 1-GPU tests - platform-agnostic tests that may work on CDNA4 (gfx950)\n  nightly-test-1-gpu-mi35x:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter 
|| inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-test-1-gpu-mi35x,'))\n    runs-on: linux-mi35x-gpu-1\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Nightly Test MI35x (1-GPU)\n        timeout-minutes: 90\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-1-gpu-mi35x --nightly --timeout-per-file 900 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU Accuracy Tests - GPT-OSS (accuracy only)\n  nightly-accuracy-8-gpu-mi35x:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-accuracy-8-gpu-mi35x,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash 
scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x (8-GPU GPT-OSS)\n        timeout-minutes: 180\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-mi35x --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU Grok1-INT4 (Accuracy + Performance combined)\n  nightly-8-gpu-mi35x-grok1-int4:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-grok1-int4,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x (8-GPU 
Grok1-INT4)\n        timeout-minutes: 90\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-mi35x-grok1-int4 --nightly --timeout-per-file 5400 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test MI35x (8-GPU Grok1-INT4)\n        timeout-minutes: 60\n        continue-on-error: true  # Perf test failure doesn't fail the job if accuracy passed\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-mi35x-grok1-int4 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU Grok2 (Accuracy + Performance combined)\n  nightly-8-gpu-mi35x-grok2:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-grok2,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch 
github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x (8-GPU Grok2)\n        timeout-minutes: 60\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-mi35x-grok2 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test MI35x (8-GPU Grok2)\n        timeout-minutes: 60\n        continue-on-error: true  # Perf test failure doesn't fail the job if accuracy passed\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e RCCL_MSCCL_ENABLE=0 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-mi35x-grok2 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-R1-MXFP4 (Accuracy + Performance combined)\n  nightly-8-gpu-mi35x-deepseek-r1-mxfp4:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 
'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-deepseek-r1-mxfp4,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x (8-GPU DeepSeek-R1-MXFP4)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-mi35x-deepseek-r1-mxfp4 --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test MI35x (8-GPU DeepSeek-R1-MXFP4)\n        timeout-minutes: 300\n        continue-on-error: true  # Perf test failure doesn't fail the job if accuracy passed\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 registered/amd/perf/mi35x/test_deepseek_r1_mxfp4_perf_mi35x.py || 
TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-R1-MXFP4 KV FP8 (Accuracy + Performance combined)\n  nightly-8-gpu-mi35x-deepseek-r1-mxfp4-kv-fp8:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-deepseek-r1-mxfp4-kv-fp8,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x (8-GPU DeepSeek-R1-MXFP4 KV FP8)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-mi35x-deepseek-r1-mxfp4-kv-fp8 --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test MI35x (8-GPU DeepSeek-R1-MXFP4 KV FP8)\n        timeout-minutes: 300\n        continue-on-error: true  # Perf 
test failure doesn't fail the job if accuracy passed\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 registered/amd/perf/mi35x/test_deepseek_r1_mxfp4_kv_fp8_perf_mi35x.py || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-R1-MXFP4 AllReduce Fusion (Accuracy + Performance combined)\n  nightly-8-gpu-mi35x-deepseek-r1-mxfp4-ar-fusion:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-deepseek-r1-mxfp4-ar-fusion,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x (8-GPU DeepSeek-R1-MXFP4 AllReduce Fusion)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite 
nightly-amd-8-gpu-mi35x-deepseek-r1-mxfp4-ar-fusion --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n      - name: Performance Test MI35x (8-GPU DeepSeek-R1-MXFP4 AllReduce Fusion)\n        timeout-minutes: 300\n        continue-on-error: true  # Perf test failure doesn't fail the job if accuracy passed\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 registered/amd/perf/mi35x/test_deepseek_r1_mxfp4_ar_fusion_perf_mi35x.py || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-V3.2 Accuracy Test\n  nightly-accuracy-8-gpu-mi35x-deepseek-v32:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-accuracy-8-gpu-mi35x-deepseek-v32,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy 
Test MI35x (8-GPU DeepSeek-V3.2)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-mi35x-deepseek-v32 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-V3.2 TP+MTP Accuracy Test\n  nightly-accuracy-8-gpu-mi35x-deepseek-v32-mtp:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-accuracy-8-gpu-mi35x-deepseek-v32-mtp,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x (8-GPU DeepSeek-V3.2 TP+MTP)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 
run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-mi35x-deepseek-v32-mtp --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-V3.2 Performance Test (Basic)\n  nightly-perf-8-gpu-mi35x-deepseek-v32-basic:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-perf-8-gpu-mi35x-deepseek-v32-basic,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Performance Test MI35x (8-GPU DeepSeek-V3.2 Basic)\n        timeout-minutes: 150\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-mi35x-deepseek-v32-basic --nightly --timeout-per-file 5400 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit 
${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU Kimi-K2.5 (Accuracy)\n  nightly-8-gpu-mi35x-kimi-k25:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-kimi-k25,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x (8-GPU Kimi-K2.5)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-mi35x-kimi-k25 --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU Qwen3-235B-MXFP4 (Accuracy + Performance)\n  nightly-8-gpu-mi35x-qwen3-235b-mxfp4:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || 
inputs.job_select), ',nightly-8-gpu-mi35x-qwen3-235b-mxfp4,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test + Performance Test MI35x (8-GPU Qwen3-235B-MXFP4)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-8-gpu-mi35x-qwen3-235b-mxfp4 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU Qwen 3.5 (Accuracy)\n  nightly-8-gpu-mi35x-qwen35:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-qwen35,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash 
scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n          bash scripts/ci/amd/amd_ci_exec.sh pip install mistral-common \"lm-eval[api]\"\n\n      - name: Accuracy Test MI35x (8-GPU Qwen 3.5)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-accuracy-8-gpu-mi35x-qwen35 --nightly --timeout-per-file 3600 || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  nightly-8-gpu-mi35x-glm5:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-glm5,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n          bash scripts/ci/amd/amd_ci_exec.sh pip install 
git+https://github.com/huggingface/transformers.git@96f807a33b75\n\n      - name: Accuracy Test MI35x (8-GPU GLM-5 NSA)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-mi35x-glm5 --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU MiniMax-M2.5 (Accuracy)\n  nightly-8-gpu-mi35x-minimax-m25:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-minimax-m25,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Accuracy Test MI35x (8-GPU MiniMax-M2.5)\n        timeout-minutes: 120\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e SGLANG_USE_AITER=1 \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py 
--hw amd --suite nightly-amd-8-gpu-mi35x-minimax-m25 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  # MI35x 8-GPU DeepSeek-V3.2 Performance Test (MTP)\n  nightly-perf-8-gpu-mi35x-deepseek-v32-mtp:\n    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-perf-8-gpu-mi35x-deepseek-v32-mtp,'))\n    runs-on: linux-mi35x-gpu-8\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Setup docker\n        run: |\n          touch github_summary.md\n          bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n          # Install tabulate for run_suite.py (missing in MI35x container)\n          bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate\n\n      - name: Performance Test MI35x (8-GPU DeepSeek-V3.2 MTP)\n        timeout-minutes: 180\n        run: |\n          > github_summary.md  # Clear summary file\n          bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \\\n            -e GITHUB_STEP_SUMMARY=\"/sglang-checkout/github_summary.md\" \\\n            python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-mi35x-deepseek-v32-mtp --nightly --timeout-per-file 7200 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$?\n          echo \"$(<github_summary.md )\" >> $GITHUB_STEP_SUMMARY || true\n          exit ${TEST_EXIT_CODE:-0}\n\n  check-all-jobs:\n    if: 
always() && (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' || github.event_name == 'workflow_dispatch')\n    needs:\n      # MI30x Unit Tests\n      - nightly-test-1-gpu-unit\n      # MI30x Accuracy Tests\n      - nightly-accuracy-2-gpu\n      - nightly-accuracy-2-gpu-vlm\n      - nightly-accuracy-8-gpu\n      # MI30x Performance Tests - excluded from check (perf failures don't block CI)\n      # - nightly-perf-2-gpu-text\n      # - nightly-perf-2-gpu-vlm\n      # MI30x Combined Accuracy + Performance Tests\n      - nightly-8-gpu-grok1-int4\n      - nightly-8-gpu-grok2\n      - nightly-8-gpu-deepseek-v31\n      - nightly-8-gpu-deepseek-v32\n      - nightly-8-gpu-deepseek-v32-mtp\n      - nightly-8-gpu-deepseek-v3-kv-fp8\n      - nightly-8-gpu-kimi-k25\n      - nightly-8-gpu-qwen3-235b\n      - nightly-8-gpu-qwen35\n      - nightly-8-gpu-glm5\n      - nightly-8-gpu-minimax-m25\n      # MI30x Diffusion Tests\n      - nightly-1-gpu-zimage-turbo\n      # MI35x jobs\n      - nightly-test-1-gpu-mi35x\n      - nightly-accuracy-8-gpu-mi35x\n      - nightly-8-gpu-mi35x-grok1-int4\n      - nightly-8-gpu-mi35x-grok2\n      - nightly-8-gpu-mi35x-deepseek-r1-mxfp4\n      - nightly-8-gpu-mi35x-deepseek-r1-mxfp4-kv-fp8\n      - nightly-8-gpu-mi35x-deepseek-r1-mxfp4-ar-fusion\n      - nightly-accuracy-8-gpu-mi35x-deepseek-v32\n      - nightly-accuracy-8-gpu-mi35x-deepseek-v32-mtp\n      - nightly-8-gpu-mi35x-kimi-k25\n      - nightly-8-gpu-mi35x-qwen3-235b-mxfp4\n      - nightly-8-gpu-mi35x-qwen35\n      - nightly-8-gpu-mi35x-glm5\n      - nightly-8-gpu-mi35x-minimax-m25\n      # MI35x perf jobs excluded from check - perf failures don't block CI\n      # - nightly-perf-8-gpu-mi35x-deepseek-v32-basic\n      # - nightly-perf-8-gpu-mi35x-deepseek-v32-mtp\n    runs-on: ubuntu-latest\n    steps:\n      - name: Check if any job failed\n        run: |\n          if [[ \"${{ contains(needs.*.result, 'failure') }}\" == \"true\" ]]; then\n            echo 
\"One or more nightly test jobs failed\"\n            exit 1\n          fi\n          if [[ \"${{ contains(needs.*.result, 'cancelled') }}\" == \"true\" ]]; then\n            echo \"One or more nightly test jobs were cancelled\"\n            exit 1\n          fi\n          echo \"All nightly test jobs passed\"\n"
  },
  {
    "path": ".github/workflows/nightly-test-intel.yml",
    "content": "name: Nightly Test (Intel)\n\non:\n  schedule:\n    - cron: '0 0 * * *'\n  push:\n    branches:\n      - main\n    paths:\n      - \"python/sglang/version.py\"\n  workflow_dispatch:\n  workflow_call:\n    inputs:\n      ref:\n        description: \"Branch, tag or SHA to checkout\"\n        required: false\n        type: string\n        default: \"\"\n\nconcurrency:\n  group: nightly-test-intel-${{ inputs.ref || github.ref }}\n  cancel-in-progress: ${{ github.event_name != 'workflow_call' }}\n\njobs:\n  # Placeholder for Intel GPU tests\n  # Add Intel-specific nightly test workflows here when available\n\n  placeholder:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: ubuntu-latest\n    steps:\n      - name: Placeholder\n        run: echo \"Intel nightly tests will be added here\"\n"
  },
  {
    "path": ".github/workflows/nightly-test-npu.yml",
    "content": "name: Nightly Test (NPU)\n\non:\n  schedule:\n    - cron: '0 17 * * *'  # Execute at 1:00 a.m. Beijing Time every day\n  pull_request:\n    branches:\n      - main\n    paths:\n      - \".github/workflows/nightly-test-npu.yml\"\n  workflow_dispatch:\n  workflow_call:\n    inputs:\n      ref:\n        description: 'Git ref (branch, tag, or SHA) to test. If not provided, uses the default branch.'\n        required: false\n        type: string\n        default: ''\n      job_filter:\n        description: 'Select which job to run (leave empty or \"all\" to run all jobs)'\n        required: false\n        type: string\n        default: 'all'\n\nconcurrency:\n  group: nightly-test-npu-${{ inputs.ref || github.ref }}\n  cancel-in-progress: ${{ github.event_name != 'workflow_call' }}\n\njobs:\n  nightly-1-npu-a3:\n    if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }}\n    runs-on: linux-aarch64-a3-2\n    strategy:\n      fail-fast: false\n      matrix:\n        part: [0, 1]\n    container:\n      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          # speed up by using infra cache services\n          CACHING_URL=\"cache-service.nginx-pypi-cache.svc.cluster.local\"\n          sed -Ei \"s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g\" /etc/apt/sources.list\n          pip config set global.index-url http://${CACHING_URL}/pypi/simple\n          pip config set global.extra-index-url \"https://pypi.tuna.tsinghua.edu.cn/simple\"\n          pip config set global.trusted-host \"${CACHING_URL} pypi.tuna.tsinghua.edu.cn\"\n          bash scripts/ci/npu/npu_ci_install_dependency.sh a3\n          # copy required file from our daily cache\n          cp 
~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp\n          # copy download through proxy\n          curl -o /tmp/test.jsonl -L https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\n\n      - name: Print Log Information\n        run: |\n          bash scripts/ci/npu/npu_log_print.sh\n\n      - name: Run test\n        timeout-minutes: 240\n        env:\n          SGLANG_USE_MODELSCOPE: true\n          SGLANG_IS_IN_CI: true\n          HF_ENDPOINT: https://hf-mirror.com\n          TORCH_EXTENSIONS_DIR: /tmp/torch_extensions\n          PYTORCH_NPU_ALLOC_CONF: \"expandable_segments:True\"\n          STREAMS_PER_DEVICE: 32\n        run: |\n          pip install sglang_router\n          hf download lmms-lab/MMMU --repo-type dataset\n          pip install sentence_transformers torchaudio==2.8.0\n          pip install protobuf==6.31.1 zss pre-commit \"wandb>=0.16.0\" tenacity==8.3.0 loguru openpyxl latex2sympy2 zstandard transformers-stream-generator tqdm-multiprocess pycocoevalcap\n          pip install yt-dlp sentencepiece==0.1.99 nltk av ftfy sqlitedict==2.1.0 \"sacrebleu>=1.5.0\" pytablewriter black==24.1.0 isort==5.13.2 \"peft>=0.2.0\" \"accelerate>=0.29.1\"\n          pip install jsonlines httpx==0.25.0 \"evaluate>=0.4.0\" datasets==2.16.1 numexpr xgrammar==0.1.25 numpy==1.26.4 dotenv\n          git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git\n          cd ./lmms-eval\n          nohup pip install . 
> lmmslog.txt 2>&1 &\n          sleep 120\n          export PYTHONPATH=$PYTHONPATH:$(pwd)\n          cd ../\n          cd test\n          python3 run_suite.py --hw npu --suite nightly-1-npu-a3 --nightly --continue-on-error --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2\n\n  nightly-2-npu-a3:\n    if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }}\n    runs-on: linux-aarch64-a3-2\n    strategy:\n      fail-fast: false\n      matrix:\n        part: [0]\n    container:\n      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          # speed up by using infra cache services\n          CACHING_URL=\"cache-service.nginx-pypi-cache.svc.cluster.local\"\n          sed -Ei \"s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g\" /etc/apt/sources.list\n          pip config set global.index-url http://${CACHING_URL}/pypi/simple\n          pip config set global.extra-index-url \"https://pypi.tuna.tsinghua.edu.cn/simple\"\n          pip config set global.trusted-host \"${CACHING_URL} pypi.tuna.tsinghua.edu.cn\"\n          bash scripts/ci/npu/npu_ci_install_dependency.sh a3\n          # copy required file from our daily cache\n          cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp\n          # copy download through proxy\n          curl -o /tmp/test.jsonl -L https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\n\n      - name: Print Log Information\n        run: |\n          bash scripts/ci/npu/npu_log_print.sh\n      - name: Run test\n        timeout-minutes: 240\n        env:\n          SGLANG_USE_MODELSCOPE: true\n          SGLANG_IS_IN_CI: 
true\n          HF_ENDPOINT: https://hf-mirror.com\n          TORCH_EXTENSIONS_DIR: /tmp/torch_extensions\n          PYTORCH_NPU_ALLOC_CONF: \"expandable_segments:True\"\n          STREAMS_PER_DEVICE: 32\n        run: |\n          pip install sglang_router\n          hf download lmms-lab/MMMU --repo-type dataset\n          pip install sentence_transformers torchaudio==2.8.0\n          pip install protobuf==6.31.1 zss pre-commit \"wandb>=0.16.0\" tenacity==8.3.0 loguru openpyxl latex2sympy2 zstandard transformers-stream-generator tqdm-multiprocess pycocoevalcap\n          pip install yt-dlp sentencepiece==0.1.99 nltk av ftfy sqlitedict==2.1.0 \"sacrebleu>=1.5.0\" pytablewriter black==24.1.0 isort==5.13.2 \"peft>=0.2.0\" \"accelerate>=0.29.1\"\n          pip install jsonlines httpx==0.25.0 \"evaluate>=0.4.0\" datasets==2.16.1 numexpr xgrammar==0.1.25 numpy==1.26.4 dotenv\n          git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git\n          cd ./lmms-eval\n          nohup pip install . 
> lmmslog.txt 2>&1 &\n          sleep 120\n          export PYTHONPATH=$PYTHONPATH:$(pwd)\n          cd ../\n          cd test\n          python3 run_suite.py --hw npu --suite nightly-2-npu-a3 --nightly --continue-on-error --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 1\n\n  nightly-4-npu-a3:\n    if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }}\n    runs-on: linux-aarch64-a3-4\n    strategy:\n      fail-fast: false\n      matrix:\n        part: [0]\n    container:\n      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          # speed up by using infra cache services\n          CACHING_URL=\"cache-service.nginx-pypi-cache.svc.cluster.local\"\n          sed -Ei \"s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g\" /etc/apt/sources.list\n          pip config set global.extra-index-url \"https://pypi.tuna.tsinghua.edu.cn/simple\"\n          pip config set global.trusted-host \"${CACHING_URL} pypi.tuna.tsinghua.edu.cn\"\n          bash scripts/ci/npu/npu_ci_install_dependency.sh a3\n          # copy required file from our daily cache\n          cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp\n          # copy download through proxy\n          curl -o /tmp/test.jsonl -L https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\n\n      - name: Print Log Information\n        run: |\n          bash scripts/ci/npu/npu_log_print.sh\n\n      - name: Run test\n        timeout-minutes: 240\n        env:\n          SGLANG_USE_MODELSCOPE: true\n          SGLANG_IS_IN_CI: true\n          HF_ENDPOINT: https://hf-mirror.com\n          
TORCH_EXTENSIONS_DIR: /tmp/torch_extensions\n          PYTORCH_NPU_ALLOC_CONF: \"expandable_segments:True\"\n          STREAMS_PER_DEVICE: 32\n        run: |\n          pip install sglang_router\n          hf download lmms-lab/MMMU --repo-type dataset\n          pip install sentence_transformers torchaudio==2.8.0\n          pip install protobuf==6.31.1 zss pre-commit \"wandb>=0.16.0\" tenacity==8.3.0 loguru openpyxl latex2sympy2 zstandard transformers-stream-generator tqdm-multiprocess pycocoevalcap\n          pip install yt-dlp sentencepiece==0.1.99 nltk av ftfy sqlitedict==2.1.0 \"sacrebleu>=1.5.0\" pytablewriter black==24.1.0 isort==5.13.2 \"peft>=0.2.0\" \"accelerate>=0.29.1\"\n          pip install jsonlines httpx==0.25.0 \"evaluate>=0.4.0\" datasets==2.16.1 numexpr xgrammar==0.1.25 numpy==1.26.4 dotenv\n          git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git\n          cd ./lmms-eval\n          nohup pip install . > lmmslog.txt 2>&1 &\n          sleep 120\n          export PYTHONPATH=$PYTHONPATH:$(pwd)\n          cd ../\n          cd test\n          python3 run_suite.py --hw npu --suite nightly-4-npu-a3 --nightly --continue-on-error --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 1\n\n  nightly-8-npu-a3:\n    if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }}\n    runs-on: linux-aarch64-a3-8\n    strategy:\n      fail-fast: false\n      matrix:\n        part: [0]\n    container:\n      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          # speed up by using infra cache services\n          CACHING_URL=\"cache-service.nginx-pypi-cache.svc.cluster.local\"\n          sed -Ei
 \"s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g\" /etc/apt/sources.list\n          pip config set global.index-url http://${CACHING_URL}/pypi/simple\n          pip config set global.extra-index-url \"https://pypi.tuna.tsinghua.edu.cn/simple\"\n          pip config set global.trusted-host \"${CACHING_URL} pypi.tuna.tsinghua.edu.cn\"\n          bash scripts/ci/npu/npu_ci_install_dependency.sh a3\n          # copy required file from our daily cache\n          cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp\n          # copy download through proxy\n          curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\n\n      - name: Print Log Information\n        run: |\n          bash scripts/ci/npu/npu_log_print.sh\n\n      - name: Run test\n        timeout-minutes: 240\n        env:\n          SGLANG_USE_MODELSCOPE: true\n          SGLANG_IS_IN_CI: true\n          HF_ENDPOINT: https://hf-mirror.com\n          TORCH_EXTENSIONS_DIR: /tmp/torch_extensions\n          PYTORCH_NPU_ALLOC_CONF: \"expandable_segments:True\"\n          STREAMS_PER_DEVICE: 32\n        run: |\n          pip install sglang_router\n          hf download lmms-lab/MMMU --repo-type dataset\n          pip install sentence_transformers torchaudio==2.8.0\n          pip install protobuf==6.31.1 zss pre-commit \"wandb>=0.16.0\" tenacity==8.3.0 loguru openpyxl latex2sympy2 zstandard transformers-stream-generator tqdm-multiprocess pycocoevalcap\n          pip install yt-dlp sentencepiece==0.1.99 nltk av ftfy sqlitedict==2.1.0 \"sacrebleu>=1.5.0\" pytablewriter black==24.1.0 isort==5.13.2 \"peft>=0.2.0\" \"accelerate>=0.29.1\"\n          pip install jsonlines httpx==0.25.0 \"evaluate>=0.4.0\" datasets==2.16.1 numexpr xgrammar==0.1.25 numpy==1.26.4 dotenv\n          git clone --branch v0.3.3 --depth 1 
https://github.com/EvolvingLMMs-Lab/lmms-eval.git\n
          cd ./lmms-eval\n          nohup pip install . > lmmslog.txt 2>&1 &\n          sleep 120\n          export PYTHONPATH=$PYTHONPATH:$(pwd)\n          cd ../\n          cd test\n          python3 run_suite.py --hw npu --suite nightly-8-npu-a3 --nightly --continue-on-error --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 1\n\n  nightly-16-npu-a3:\n    if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }}\n    runs-on: linux-aarch64-a3-16\n    strategy:\n      fail-fast: false\n      matrix:\n        part: [0, 1]\n    container:\n      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          # speed up by using infra cache services\n          CACHING_URL=\"cache-service.nginx-pypi-cache.svc.cluster.local\"\n          sed -Ei \"s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g\" /etc/apt/sources.list\n          pip config set global.index-url http://${CACHING_URL}/pypi/simple\n          pip config set global.extra-index-url \"https://pypi.tuna.tsinghua.edu.cn/simple\"\n          pip config set global.trusted-host \"${CACHING_URL} pypi.tuna.tsinghua.edu.cn\"\n          bash scripts/ci/npu/npu_ci_install_dependency.sh a3\n          # copy required file from our daily cache\n          cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp\n          # copy download through proxy\n          curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\n\n      - name: Print Log Information\n        run: |\n          bash scripts/ci/npu/npu_log_print.sh\n\n      - name: Run test\n
        timeout-minutes: 240\n        env:\n          SGLANG_USE_MODELSCOPE: true\n          SGLANG_IS_IN_CI: true\n          HF_ENDPOINT: https://hf-mirror.com\n          TORCH_EXTENSIONS_DIR: /tmp/torch_extensions\n          PYTORCH_NPU_ALLOC_CONF: \"expandable_segments:True\"\n          STREAMS_PER_DEVICE: 32\n        run: |\n          pip install sglang_router\n          hf download lmms-lab/MMMU --repo-type dataset\n          pip install sentence_transformers torchaudio==2.8.0\n          pip install protobuf==6.31.1 zss pre-commit \"wandb>=0.16.0\" tenacity==8.3.0 loguru openpyxl latex2sympy2 zstandard transformers-stream-generator tqdm-multiprocess pycocoevalcap\n          pip install yt-dlp sentencepiece==0.1.99 nltk av ftfy sqlitedict==2.1.0 \"sacrebleu>=1.5.0\" pytablewriter black==24.1.0 isort==5.13.2 \"peft>=0.2.0\" \"accelerate>=0.29.1\"\n          pip install jsonlines httpx==0.25.0 \"evaluate>=0.4.0\" datasets==2.16.1 numexpr xgrammar==0.1.25 numpy==1.26.4 dotenv\n          git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git\n          cd ./lmms-eval\n          nohup pip install . 
> lmmslog.txt 2>&1 &\n          sleep 120\n          export PYTHONPATH=$PYTHONPATH:$(pwd)\n          cd ../\n          cd test\n          python3 run_suite.py --hw npu --suite nightly-16-npu-a3 --nightly --continue-on-error --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2\n\n  check-all-jobs:\n    if: github.repository == 'sgl-project/sglang' && always()\n    needs:\n      - nightly-1-npu-a3\n      - nightly-2-npu-a3\n      - nightly-4-npu-a3\n      - nightly-8-npu-a3\n      - nightly-16-npu-a3\n    runs-on: ubuntu-latest\n    container:\n      image: docker.m.daocloud.io/ubuntu:22.04\n    steps:\n      - name: Check if any job failed\n        run: |\n          if [[ \"${{ contains(needs.*.result, 'failure') }}\" == \"true\" ]]; then\n            echo \"One or more nightly test jobs failed\"\n            exit 1\n          fi\n          if [[ \"${{ contains(needs.*.result, 'cancelled') }}\" == \"true\" ]]; then\n            echo \"One or more nightly test jobs were cancelled\"\n            exit 1\n          fi\n          echo \"All nightly test jobs passed\"\n"
  },
  {
    "path": ".github/workflows/nightly-test-nvidia.yml",
    "content": "name: Nightly Test (Nvidia)\n\non:\n  schedule:\n    - cron: '0 0 * * *'\n  workflow_dispatch:\n    inputs:\n      job_filter:\n        description: 'Select which job to run (leave empty or \"all\" to run all jobs)'\n        required: false\n        type: choice\n        default: 'all'\n        options:\n          - 'all'\n          - 'nightly-test-general-1-gpu-runner'\n          - 'nightly-test-general-4-gpu-h100'\n          - 'nightly-test-general-8-gpu-h200'\n          - 'nightly-test-general-8-gpu-h20'\n          - 'nightly-test-general-8-gpu-b200'\n          - 'nightly-test-text-accuracy-2-gpu-runner'\n          - 'nightly-test-text-perf-2-gpu-runner'\n          - 'nightly-test-vlm-accuracy-2-gpu-runner'\n          - 'nightly-test-vlm-perf-2-gpu-runner'\n          - 'nightly-test-multimodal-server-1-gpu'\n          - 'nightly-test-multimodal-server-2-gpu'\n          - 'nightly-test-perf-4-gpu-b200'\n          - 'nightly-test-perf-8-gpu-b200'\n  workflow_call:\n    inputs:\n      ref:\n        description: 'Git ref (branch, tag, or SHA) to test. 
If not provided, uses the default branch.'\n        required: false\n        type: string\n        default: ''\n      job_filter:\n        description: 'Select which job to run (leave empty or \"all\" to run all jobs)'\n        required: false\n        type: string\n        default: 'all'\n\nconcurrency:\n  group: nightly-test-nvidia-${{ inputs.ref || github.ref }}\n  cancel-in-progress: ${{ github.event_name != 'workflow_call' }}\n\nenv:\n  SGLANG_IS_IN_CI: true\n  SGLANG_CUDA_COREDUMP: \"1\"\n  HF_HUB_DOWNLOAD_TIMEOUT: 300\n  HF_HUB_ETAG_TIMEOUT: 300\n\njobs:\n  # General tests - 1 GPU\n  nightly-test-general-1-gpu-runner:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-general-1-gpu-runner')\n    runs-on: 1-gpu-runner\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 60\n        run: |\n          cd test\n          python3 run_suite.py --hw cuda --suite nightly-1-gpu --nightly --continue-on-error\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n\n  # General tests - 4 GPU H100\n  nightly-test-general-4-gpu-h100:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-general-4-gpu-h100')\n    runs-on: 4-gpu-h100\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          cd test\n          python3 run_suite.py --hw cuda --suite 
nightly-4-gpu --nightly --continue-on-error\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n\n  # General tests - 8 GPU H200\n  nightly-test-general-8-gpu-h200:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-general-8-gpu-h200')\n    runs-on: 8-gpu-h200\n    strategy:\n      fail-fast: false\n      matrix:\n        partition: [0, 1, 2, 3]\n    env:\n      RUNNER_LABELS: 8-gpu-h200\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run common 8-GPU model tests\n        if: always()\n        timeout-minutes: 300\n        env:\n          TRACE_BASE_URL: https://raw.githubusercontent.com/sglang-bot/sglang-ci-data/main/traces/${{ github.run_id }}\n          PERFETTO_RELAY_URL: ${{ vars.PERFETTO_RELAY_URL }}\n          GPU_CONFIG: \"8-gpu-h200\"\n          IS_H200: \"1\"\n        run: |\n          cd test\n          python3 run_suite.py --hw cuda --suite nightly-8-gpu-common --nightly --timeout-per-file=18000 --continue-on-error --auto-partition-id=${{ matrix.partition }} --auto-partition-size=4\n\n      - name: Publish traces to storage repo\n        if: always()\n        continue-on-error: true\n        env:\n          GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }}\n          GITHUB_RUN_ID: ${{ github.run_id }}\n          GITHUB_RUN_NUMBER: ${{ github.run_number }}\n        run: |\n          TRACE_ARGS=\"\"\n          for dir in test/performance_profiles_*/; do\n            [ -d \"$dir\" ] && TRACE_ARGS=\"$TRACE_ARGS --traces-dir $dir\"\n          done\n          if [ -n \"$TRACE_ARGS\" ]; then\n            python3 scripts/ci/utils/publish_traces.py $TRACE_ARGS\n            find test/performance_profiles_*/ -name 
'*.json.gz' -delete\n          else\n            echo \"No trace directories found, skipping publish\"\n          fi\n\n      - name: Run test\n        timeout-minutes: 30\n        env:\n          GPU_CONFIG: \"8-gpu-h200\"\n        run: |\n          cd test\n          python3 run_suite.py --hw cuda --suite nightly-8-gpu-h200 --nightly --continue-on-error\n\n      - name: Collect performance metrics\n        if: always()\n        run: |\n          python3 scripts/ci/utils/save_metrics.py \\\n            --gpu-config 8-gpu-h200 \\\n            --partition ${{ matrix.partition }} \\\n            --run-id ${{ github.run_id }} \\\n            --output test/metrics-8gpu-h200-partition-${{ matrix.partition }}.json \\\n            --search-dir test/performance_profiles_8_gpu \\\n            --search-dir test\n\n      - name: Upload partition metrics\n        if: always()\n        uses: actions/upload-artifact@v4\n        with:\n          name: metrics-8gpu-h200-partition-${{ matrix.partition }}\n          path: test/metrics-8gpu-h200-partition-${{ matrix.partition }}.json\n          retention-days: 5\n          if-no-files-found: ignore\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n        with:\n          artifact-suffix: ${{ matrix.partition }}\n\n  # General tests - 8 GPU H20\n  nightly-test-general-8-gpu-h20:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-general-8-gpu-h20')\n    runs-on: 8-gpu-h20\n    env:\n      SGLANG_CI_RDMA_ALL_DEVICES: \"mlx5_1,mlx5_2,mlx5_3,mlx5_4\"\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 30\n        env:\n          GPU_CONFIG: \"8-gpu-h20\"\n        run: 
|\n          cd test\n          python3 run_suite.py --hw cuda --suite nightly-8-gpu-h20 --nightly --continue-on-error\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n\n  # General tests - 8 GPU B200\n  nightly-test-general-8-gpu-b200:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-general-8-gpu-b200')\n    runs-on: 8-gpu-b200\n    strategy:\n      fail-fast: false\n      matrix:\n        partition: [0, 1, 2, 3]\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          IS_BLACKWELL=1 bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run common 8-GPU model tests\n        if: always()\n        timeout-minutes: 300\n        env:\n          TRACE_BASE_URL: https://raw.githubusercontent.com/sglang-bot/sglang-ci-data/main/traces/${{ github.run_id }}\n          PERFETTO_RELAY_URL: ${{ vars.PERFETTO_RELAY_URL }}\n          GPU_CONFIG: \"8-gpu-b200\"\n        run: |\n          cd test\n          IS_BLACKWELL=1 python3 run_suite.py --hw cuda --suite nightly-8-gpu-common --nightly --timeout-per-file=12000 --continue-on-error --auto-partition-id=${{ matrix.partition }} --auto-partition-size=4\n\n      - name: Publish traces to storage repo\n        if: always()\n        continue-on-error: true\n        env:\n          GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }}\n          GITHUB_RUN_ID: ${{ github.run_id }}\n          GITHUB_RUN_NUMBER: ${{ github.run_number }}\n        run: |\n          TRACE_ARGS=\"\"\n          for dir in test/performance_profiles_*/; do\n            [ -d \"$dir\" ] && TRACE_ARGS=\"$TRACE_ARGS --traces-dir $dir\"\n          done\n          if [ -n \"$TRACE_ARGS\" ]; then\n            python3 scripts/ci/utils/publish_traces.py $TRACE_ARGS\n            find 
test/performance_profiles_*/ -name '*.json.gz' -delete\n          else\n            echo \"No trace directories found, skipping publish\"\n          fi\n\n      - name: Collect performance metrics\n        if: always()\n        run: |\n          python3 scripts/ci/utils/save_metrics.py \\\n            --gpu-config 8-gpu-b200 \\\n            --partition ${{ matrix.partition }} \\\n            --run-id ${{ github.run_id }} \\\n            --output test/metrics-8gpu-b200-partition-${{ matrix.partition }}.json \\\n            --search-dir test/performance_profiles_8_gpu \\\n            --search-dir test\n\n      - name: Upload partition metrics\n        if: always()\n        uses: actions/upload-artifact@v4\n        with:\n          name: metrics-8gpu-b200-partition-${{ matrix.partition }}\n          path: test/metrics-8gpu-b200-partition-${{ matrix.partition }}.json\n          retention-days: 5\n          if-no-files-found: ignore\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n        with:\n          artifact-suffix: ${{ matrix.partition }}\n\n  # Text model accuracy tests\n  nightly-test-text-accuracy-2-gpu-runner:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-text-accuracy-2-gpu-runner')\n    runs-on: 2-gpu-runner\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run eval test for text models\n        timeout-minutes: 120\n        run: |\n          cd test\n          python3 run_suite.py --hw cuda --suite nightly-eval-text-2-gpu --nightly --continue-on-error --timeout-per-file 4500\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n\n  # Text model performance tests\n  
nightly-test-text-perf-2-gpu-runner:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-text-perf-2-gpu-runner')\n    runs-on: 2-gpu-runner\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run performance test for text models\n        timeout-minutes: 180\n        env:\n          TRACE_BASE_URL: https://raw.githubusercontent.com/sglang-bot/sglang-ci-data/main/traces/${{ github.run_id }}\n          PERFETTO_RELAY_URL: ${{ vars.PERFETTO_RELAY_URL }}\n          GPU_CONFIG: \"2-gpu-runner\"\n        run: |\n          cd test\n          rm -rf performance_profiles_text_models/\n          python3 run_suite.py --hw cuda --suite nightly-perf-text-2-gpu --nightly --continue-on-error --timeout-per-file 3600\n\n      - name: Publish traces to storage repo\n        env:\n          GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }}\n          GITHUB_RUN_ID: ${{ github.run_id }}\n          GITHUB_RUN_NUMBER: ${{ github.run_number }}\n        run: |\n          python3 scripts/ci/utils/publish_traces.py --traces-dir test/performance_profiles_text_models\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n\n  # VLM accuracy tests\n  nightly-test-vlm-accuracy-2-gpu-runner:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-vlm-accuracy-2-gpu-runner')\n    runs-on: 2-gpu-runner\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run eval 
test for VLM models (fixed MMMU-100)\n        timeout-minutes: 240\n        run: |\n          cd test\n          python3 run_suite.py --hw cuda --suite nightly-eval-vlm-2-gpu --nightly --continue-on-error --timeout-per-file 9000\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n\n  # VLM performance tests\n  nightly-test-vlm-perf-2-gpu-runner:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-vlm-perf-2-gpu-runner')\n    runs-on: 2-gpu-runner\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run perf test for VLM models (MMMU)\n        timeout-minutes: 240\n        env:\n          TRACE_BASE_URL: https://raw.githubusercontent.com/sglang-bot/sglang-ci-data/main/traces/${{ github.run_id }}\n          PERFETTO_RELAY_URL: ${{ vars.PERFETTO_RELAY_URL }}\n          GPU_CONFIG: \"2-gpu-runner\"\n        run: |\n          cd test\n          rm -rf performance_profiles_vlms/\n          python3 run_suite.py --hw cuda --suite nightly-perf-vlm-2-gpu --nightly --continue-on-error --timeout-per-file 3600\n\n      - name: Publish traces to storage repo\n        env:\n          GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }}\n          GITHUB_RUN_ID: ${{ github.run_id }}\n          GITHUB_RUN_NUMBER: ${{ github.run_number }}\n        run: |\n          python3 scripts/ci/utils/publish_traces.py --traces-dir test/performance_profiles_vlms\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n\n  # diffusion performance tests\n  nightly-test-multimodal-server-1-gpu:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 
'nightly-test-multimodal-server-1-gpu')\n    runs-on: 1-gpu-runner\n    strategy:\n      fail-fast: false\n      max-parallel: 5\n      matrix:\n        part: [0, 1]\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh diffusion\n          pip install slack_sdk\n\n      - name: Run diffusion server tests\n        env:\n          SGLANG_DIFFUSION_SLACK_TOKEN: ${{ secrets.SGLANG_DIFFUSION_SLACK_TOKEN }}\n          GITHUB_RUN_ID: ${{ github.run_id }}\n          GPU_CONFIG: \"1-gpu-runner\"\n\n        timeout-minutes: 60\n        run: |\n          cd python\n          python3 sglang/multimodal_gen/test/run_suite.py \\\n            --suite 1-gpu \\\n            --partition-id ${{ matrix.part }} \\\n            --total-partitions 2\n\n      - name: Collect diffusion performance metrics\n        if: always()\n        run: |\n          python3 scripts/ci/utils/save_diffusion_metrics.py \\\n            --gpu-config 1-gpu-runner \\\n            --run-id ${{ github.run_id }} \\\n            --output python/diffusion-metrics-1gpu-partition-${{ matrix.part }}.json \\\n            --results-json python/diffusion-results.json\n\n      - name: Upload diffusion metrics\n        if: always()\n        uses: actions/upload-artifact@v4\n        with:\n          name: diffusion-metrics-1gpu-partition-${{ matrix.part }}\n          path: python/diffusion-metrics-1gpu-partition-${{ matrix.part }}.json\n          retention-days: 90\n          if-no-files-found: ignore\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n        with:\n          artifact-suffix: ${{ matrix.part }}\n\n  nightly-test-multimodal-server-2-gpu:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 
'nightly-test-multimodal-server-2-gpu')\n    runs-on: 2-gpu-runner\n    strategy:\n      fail-fast: false\n      max-parallel: 5\n      matrix:\n        part: [0, 1]\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh diffusion\n          pip install slack_sdk\n\n      - name: Run diffusion server tests\n        env:\n          SGLANG_DIFFUSION_SLACK_TOKEN: ${{ secrets.SGLANG_DIFFUSION_SLACK_TOKEN }}\n          GITHUB_RUN_ID: ${{ github.run_id }}\n          GPU_CONFIG: \"2-gpu-runner\"\n\n        timeout-minutes: 60\n        run: |\n          cd python\n          python3 sglang/multimodal_gen/test/run_suite.py \\\n            --suite 2-gpu \\\n            --partition-id ${{ matrix.part }} \\\n            --total-partitions 2\n\n      - name: Collect diffusion performance metrics\n        if: always()\n        run: |\n          python3 scripts/ci/utils/save_diffusion_metrics.py \\\n            --gpu-config 2-gpu-runner \\\n            --run-id ${{ github.run_id }} \\\n            --output python/diffusion-metrics-2gpu-partition-${{ matrix.part }}.json \\\n            --results-json python/diffusion-results.json\n\n      - name: Upload diffusion metrics\n        if: always()\n        uses: actions/upload-artifact@v4\n        with:\n          name: diffusion-metrics-2gpu-partition-${{ matrix.part }}\n          path: python/diffusion-metrics-2gpu-partition-${{ matrix.part }}.json\n          retention-days: 90\n          if-no-files-found: ignore\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n        with:\n          artifact-suffix: ${{ matrix.part }}\n\n  # B200 Performance tests - 4 GPU\n  nightly-test-perf-4-gpu-b200:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || 
inputs.job_filter == 'nightly-test-perf-4-gpu-b200')\n    runs-on: 4-gpu-b200\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          IS_BLACKWELL=1 bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 300\n        run: |\n          cd test\n          python3 run_suite.py --hw cuda --suite nightly-4-gpu-b200 --nightly --continue-on-error --timeout-per-file 12000\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n\n  # Specialized B200 tests - 8 GPU, for specific backends and configs\n  nightly-test-specialized-8-gpu-b200:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-specialized-8-gpu-b200' || inputs.job_filter == 'nightly-test-perf-8-gpu-b200')\n    runs-on: 8-gpu-b200\n    env:\n      RUNNER_LABELS: 8-gpu-b200\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Install dependencies\n        run: |\n          IS_BLACKWELL=1 bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 120\n        env:\n          GPU_CONFIG: \"8-gpu-b200\"\n        run: |\n          cd test\n          python3 run_suite.py --hw cuda --suite nightly-8-gpu-b200 --nightly --continue-on-error --timeout-per-file 2400\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n\n  # Consolidate performance metrics from all jobs\n  consolidate-metrics:\n    if: github.repository == 'sgl-project/sglang' && always()\n    needs:\n      - nightly-test-general-8-gpu-h200\n      - nightly-test-general-8-gpu-b200\n      - nightly-test-multimodal-server-1-gpu\n      - nightly-test-multimodal-server-2-gpu\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Download all partition metrics\n        uses: actions/download-artifact@v4\n        with:\n          pattern: \"*metrics-*\"\n          path: metrics/\n          merge-multiple: true\n\n      - name: List downloaded metrics\n        run: |\n          echo \"Downloaded metrics files:\"\n          find metrics/ -name \"*.json\" -type f 2>/dev/null || echo \"No metrics files found\"\n\n      - name: Merge metrics\n        run: |\n          python3 scripts/ci/utils/merge_metrics.py \\\n            --input-dir metrics/ \\\n            --output consolidated-metrics-${{ github.run_id }}.json \\\n            --run-id ${{ github.run_id }} \\\n            --commit-sha ${{ github.sha }} \\\n            --branch ${{ github.ref_name }}\n\n      - name: Upload consolidated metrics\n        uses: actions/upload-artifact@v4\n        with:\n          name: consolidated-metrics-${{ github.run_id }}\n          path: consolidated-metrics-${{ github.run_id }}.json\n          retention-days: 90\n          if-no-files-found: warn\n\n  # Final check job\n  check-all-jobs:\n    if: github.repository == 'sgl-project/sglang' && always()\n    needs:\n      - nightly-test-general-1-gpu-runner\n      - nightly-test-general-4-gpu-h100\n      - nightly-test-general-8-gpu-h200\n      - nightly-test-general-8-gpu-h20\n      - nightly-test-general-8-gpu-b200\n      - nightly-test-text-accuracy-2-gpu-runner\n      - nightly-test-text-perf-2-gpu-runner\n      - nightly-test-vlm-accuracy-2-gpu-runner\n      - nightly-test-vlm-perf-2-gpu-runner\n      - nightly-test-multimodal-server-1-gpu\n      - nightly-test-multimodal-server-2-gpu\n      - nightly-test-perf-4-gpu-b200\n      - nightly-test-specialized-8-gpu-b200\n      - consolidate-metrics\n    runs-on: ubuntu-latest\n    steps:\n      - name: Check if any job failed\n        run: |\n          if [[ \"${{ contains(needs.*.result, 'failure') 
}}\" == \"true\" ]]; then\n            echo \"One or more nightly test jobs failed\"\n            exit 1\n          fi\n          if [[ \"${{ contains(needs.*.result, 'cancelled') }}\" == \"true\" ]]; then\n            echo \"One or more nightly test jobs were cancelled\"\n            exit 1\n          fi\n          echo \"All nightly test jobs passed\"\n"
  },
  {
    "path": ".github/workflows/open-pr-copy-from-oss.yml",
    "content": "name: Open A PR to Copy Code From OSS\n\non:\n  workflow_dispatch:\n  # schedule:\n  #   - cron: '0 10 * * *'\n\npermissions:\n  contents: write\n\njobs:\n  copy:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          ref: 'main'\n\n      - name: Install GitHub CLI (if not present)\n        run: |\n          bash scripts/code_sync/install_github_cli.sh\n\n      - name: Copy from OSS code\n        env:\n          GH_TOKEN: ${{ secrets.GH_PAT_FOR_OPEN_PR_TO_PRIVATE }}\n        run: |\n          python3 scripts/code_sync/copy_from_oss.py\n"
  },
  {
    "path": ".github/workflows/open-pr-copy-to-oss.yml",
    "content": "name: Open A PR to Copy Diff To OSS\n\non:\n  workflow_dispatch:\n    inputs:\n      commit_sha:\n        description: 'The commit SHA to copy. Defaults to LAST to copy the latest commit.'\n        required: false\n        default: 'LAST'\n\npermissions:\n  contents: write\n\njobs:\n  copy:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0\n\n      - name: Install GitHub CLI (if not present)\n        run: |\n          bash scripts/code_sync/install_github_cli.sh\n\n      - name: Copy to OSS code\n        env:\n          GH_TOKEN: ${{ secrets.GH_PAT_FOR_OPEN_PR_TO_OSS }}\n        run: |\n          python3 scripts/code_sync/copy_to_oss.py --commit ${{ github.event.inputs.commit_sha }}\n"
  },
  {
    "path": ".github/workflows/patch-docker-dev.yml",
    "content": "name: Patch Docker Image\n\non:\n  workflow_dispatch:\n    inputs:\n      pr_numbers:\n        description: \"Comma-separated PR numbers to apply (e.g. 18962,19010)\"\n        required: false\n        default: \"\"\n      image_tag:\n        description: \"Base image tag to patch (e.g. dev-x86, dev-x86-cu13)\"\n        required: true\n\nconcurrency:\n  group: patch-docker-${{ inputs.image_tag }}\n  cancel-in-progress: true\n\njobs:\n  patch:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: x64-docker-build-node\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Pull base image and extract commit\n        run: |\n          IMAGE=\"lmsysorg/sglang:${{ inputs.image_tag }}\"\n          docker pull \"${IMAGE}\"\n          if BASE_SHA=$(docker run --rm \"${IMAGE}\" git -C /sgl-workspace/sglang rev-parse HEAD 2>/dev/null); then\n            echo \"Image built from commit: ${BASE_SHA}\"\n          else\n            BASE_SHA=\"\"\n            echo \"::warning::Image has no .git directory — cannot extract base commit\"\n          fi\n          echo \"BASE_SHA=${BASE_SHA}\" >> \"$GITHUB_ENV\"\n\n      - name: Generate patches\n        run: |\n          git config --global --add safe.directory \"$GITHUB_WORKSPACE\"\n          git fetch origin main\n          mkdir -p /tmp/patch-ctx\n\n          if [ -n \"${{ inputs.pr_numbers }}\" ]; then\n            IFS=',' read -ra PRS <<< \"${{ inputs.pr_numbers }}\"\n            for pr in \"${PRS[@]}\"; do\n              pr=$(echo \"${pr}\" | xargs)\n              echo \"Fetching PR #${pr}\"\n              git fetch origin \"pull/${pr}/head:pr-${pr}\"\n              MERGE_BASE=$(git merge-base origin/main \"pr-${pr}\")\n       
       echo \"  PR #${pr}: merge-base=${MERGE_BASE}\"\n              git diff \"${MERGE_BASE}..pr-${pr}\" > \"/tmp/patch-ctx/${pr}.patch\"\n              echo \"  PR #${pr}: $(wc -l < /tmp/patch-ctx/${pr}.patch) lines\"\n            done\n          elif [ -n \"${BASE_SHA}\" ]; then\n            echo \"Generating diff: image ${BASE_SHA} → latest main\"\n            git fetch origin \"${BASE_SHA}\"\n            git diff \"${BASE_SHA}..origin/main\" > /tmp/patch-ctx/main.patch\n            echo \"  main: $(wc -l < /tmp/patch-ctx/main.patch) lines\"\n          else\n            echo \"::error::No PR numbers specified and image has no .git — cannot generate diff against main\"\n            exit 1\n          fi\n\n          TOTAL=$(cat /tmp/patch-ctx/*.patch | wc -l)\n          if [ \"${TOTAL}\" -eq 0 ]; then\n            echo \"::warning::All patches are empty — image is already up to date\"\n            echo \"SKIP_BUILD=true\" >> \"$GITHUB_ENV\"\n          fi\n\n      - name: Build patched image\n        if: env.SKIP_BUILD != 'true'\n        run: |\n          IMAGE=\"lmsysorg/sglang:${{ inputs.image_tag }}\"\n\n          cat <<'DOCKERFILE' > /tmp/patch-ctx/Dockerfile\n          ARG BASE_IMAGE\n          FROM ${BASE_IMAGE}\n          COPY *.patch /tmp/patches/\n          RUN cd /sgl-workspace/sglang \\\n              && for p in /tmp/patches/*.patch; do \\\n                   if [ ! 
-s \"${p}\" ]; then \\\n                     echo \"Skipping ${p} (empty)\"; \\\n                   else \\\n                     echo \"Applying ${p}...\" \\\n                     && patch -p1 --fuzz=2 --no-backup-if-mismatch -f < \"${p}\" \\\n                     || { echo \"ERROR: Failed to apply ${p}\"; exit 1; }; \\\n                   fi; \\\n                 done \\\n              && rm -rf /tmp/patches\n          DOCKERFILE\n\n          docker build \\\n            --no-cache \\\n            --build-arg BASE_IMAGE=\"${IMAGE}\" \\\n            -t \"${IMAGE}\" \\\n            /tmp/patch-ctx/\n\n      - name: Push patched image\n        if: env.SKIP_BUILD != 'true'\n        run: |\n          IMAGE=\"lmsysorg/sglang:${{ inputs.image_tag }}\"\n          docker push \"${IMAGE}\"\n\n          echo \"### Patched \\`${IMAGE}\\`\" >> \"$GITHUB_STEP_SUMMARY\"\n          echo \"- **Base commit:** \\`${BASE_SHA:-unknown (no .git)}\\`\" >> \"$GITHUB_STEP_SUMMARY\"\n          echo \"- **Source:** ${{ inputs.pr_numbers && format('PRs: {0}', inputs.pr_numbers) || 'latest main' }}\" >> \"$GITHUB_STEP_SUMMARY\"\n"
  },
  {
    "path": ".github/workflows/pr-benchmark-rust.yml",
    "content": "name: PR Benchmark (SMG Components)\n\non:\n  push:\n    branches: [ main ]\n    paths:\n      - \"sgl-model-gateway/**\"\n  pull_request:\n    branches: [ main ]\n    paths:\n      - \"sgl-model-gateway/**\"\n  workflow_dispatch:\n\nconcurrency:\n  group: pr-benchmark-rust-${{ github.ref }}\n  cancel-in-progress: true\n\nenv:\n  RUSTC_WRAPPER: sccache\n  SCCACHE_GHA_ENABLED: \"true\"\n\npermissions:\n  contents: read\n  pull-requests: write\n  issues: write\n\njobs:\n  benchmark-compile-check:\n    name: Benchmark Compilation Check\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_gateway_dependencies.sh\n\n      - name: Configure sccache\n        uses: mozilla-actions/sccache-action@v0.0.9\n        with:\n          version: \"v0.12.0\"\n          disable_annotations: true\n\n      - name: Rust cache\n        uses: Swatinem/rust-cache@v2\n        with:\n          workspaces: sgl-model-gateway\n          shared-key: \"rust-cache\"\n          save-if: true\n          cache-all-crates: true\n          cache-on-failure: true\n\n      - name: Check benchmarks compile\n        run: |\n          source \"$HOME/.cargo/env\"\n          cd sgl-model-gateway/\n          cargo check --benches\n\n      - name: Show sccache stats\n        if: always()\n        run: sccache --show-stats\n\n  benchmark:\n    name: Benchmark - ${{ matrix.name }}\n    if: |\n      github.repository == 'sgl-project/sglang' &&\n      (github.event_name == 'push' ||\n       github.event_name == 'workflow_dispatch' ||\n       (contains(github.event.pull_request.labels.*.name, 'router-benchmark') &&\n        contains(github.event.pull_request.labels.*.name, 'run-ci')))\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - name: Request Processing\n            bench_name: request_processing\n            
bench_args: \"benchmark_summary --exact\"\n            runner: ubuntu-latest\n            sccache_version: \"v0.12.0\"\n            artifact_name: request-processing-results\n            artifact_path: criterion/benchmark_summary/\n          - name: Manual Policy\n            bench_name: manual_policy_benchmark\n            bench_args: \"\"\n            runner: ubuntu-latest\n            sccache_version: \"v0.12.0\"\n            artifact_name: manual-policy-results\n            artifact_path: criterion/manual_policy*/\n    runs-on: ${{ matrix.runner }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 100\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_gateway_dependencies.sh\n\n      - name: Configure sccache\n        uses: mozilla-actions/sccache-action@v0.0.9\n        with:\n          version: ${{ matrix.sccache_version }}\n          disable_annotations: true\n\n      - name: Rust cache\n        uses: Swatinem/rust-cache@v2\n        with:\n          workspaces: sgl-model-gateway\n          shared-key: \"rust-cache\"\n          cache-all-crates: true\n          cache-on-failure: true\n          save-if: true\n\n      - name: Run benchmark\n        timeout-minutes: 30\n        run: |\n          source \"$HOME/.cargo/env\"\n          cd sgl-model-gateway/\n          if command -v sccache &> /dev/null; then\n            echo \"Testing sccache availability...\"\n            export RUSTC_WRAPPER=sccache\n            export SCCACHE_GHA_ENABLED=\"true\"\n            if sccache --start-server 2>/dev/null && sccache --show-stats 2>/dev/null; then\n              echo \"sccache is working, using it for compilation\"\n            else\n              echo \"sccache failed to start, falling back to regular cargo\"\n              unset RUSTC_WRAPPER\n              unset SCCACHE_GHA_ENABLED\n            fi\n          else\n            echo \"sccache not available, 
using regular cargo\"\n          fi\n          cargo bench --bench ${{ matrix.bench_name }} -- ${{ matrix.bench_args }} 2>&1 | tee benchmark_output.txt\n\n      - name: Upload benchmark results\n        if: always()\n        uses: actions/upload-artifact@v4\n        with:\n          name: ${{ matrix.artifact_name }}-${{ github.sha }}\n          path: |\n            sgl-model-gateway/target/${{ matrix.artifact_path }}\n            sgl-model-gateway/benchmark_output.txt\n          retention-days: 30\n\n      - name: Show sccache stats\n        if: always()\n        run: sccache --show-stats\n\n  benchmark-summary:\n    name: Benchmark Summary\n    needs: [benchmark]\n    if: always() && (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request')\n    runs-on: ubuntu-latest\n    steps:\n      - name: Download all benchmark results\n        uses: actions/download-artifact@v4\n        with:\n          pattern: '*-results-${{ github.sha }}'\n          path: benchmark-results\n\n      - name: Generate summary\n        run: |\n          generate_section() {\n            local title=\"$1\" dir_name=\"$2\" lines=\"${3:-100}\"\n            local dir=\"benchmark-results/${dir_name}-${{ github.sha }}\"\n            echo \"### $title\" >> summary.md\n            if [ -d \"$dir\" ]; then\n              echo \"✅ **Completed**\" >> summary.md\n              if [ -f \"$dir/benchmark_output.txt\" ]; then\n                echo -e \"\\n<details>\\n<summary>View Results</summary>\\n\\n\\`\\`\\`\" >> summary.md\n                tail -\"$lines\" \"$dir/benchmark_output.txt\" >> summary.md\n                echo -e \"\\`\\`\\`\\n</details>\" >> summary.md\n              fi\n            else\n              echo \"❌ Failed or skipped\" >> summary.md\n            fi\n            echo \"\" >> summary.md\n          }\n\n          echo \"## 🚀 Benchmark Results Summary\" > summary.md\n          echo \"\" >> summary.md\n\n          generate_section \"Request Processing\" 
\"request-processing-results\" 60\n          generate_section \"Manual Policy (Sticky Sessions)\" \"manual-policy-results\" 100\n\n          echo -e \"---\\n_Generated at $(date -u '+%Y-%m-%d %H:%M:%S UTC')_\" >> summary.md\n\n          cat summary.md\n          cat summary.md >> $GITHUB_STEP_SUMMARY\n\n      - name: Upload summary\n        uses: actions/upload-artifact@v4\n        with:\n          name: benchmark-summary-${{ github.sha }}\n          path: summary.md\n          retention-days: 30\n"
  },
  {
    "path": ".github/workflows/pr-gate.yml",
    "content": "on:\n  workflow_call:\n    inputs:\n      require-run-ci:\n        description: \"Whether the PR must have the run-ci label\"\n        type: boolean\n        default: true\n      cool-down-minutes:\n        description: \"Cooldown period in minutes for low-permission users; 0 disables rate limiting\"\n        type: number\n        default: 120\n\njobs:\n  pr-gate:\n    # 1. for commits on main: no gating needed\n    # 2. for workflow_dispatch: this can only be triggered by users with write access\n    runs-on: ubuntu-latest\n    steps:\n      - name: Fetch latest PR info\n        if: github.event_name == 'pull_request'\n        id: pr\n        uses: actions/github-script@v7\n        with:\n          github-token: ${{ secrets.GITHUB_TOKEN }}\n          script: |\n            const pr = await github.rest.pulls.get({\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              pull_number: context.issue.number\n            });\n            core.setOutput(\"labels\", JSON.stringify(pr.data.labels.map(l => l.name)));\n            core.setOutput(\"draft\", pr.data.draft);\n            core.setOutput(\"user\", pr.data.user.login);\n\n      - name: Log PR info\n        if: github.event_name == 'pull_request'\n        run: |\n          echo \"===== PR Info =====\"\n          echo \"PR Event: ${{ github.event_name }}\"\n          echo \"PR Labels: ${{ steps.pr.outputs.labels }}\"\n          echo \"PR Draft: ${{ steps.pr.outputs.draft }}\"\n          echo \"PR User: ${{ steps.pr.outputs.user }}\"\n          echo \"Require run-ci: ${{ inputs.require-run-ci }}\"\n          echo \"Cool down minutes: ${{ inputs.cool-down-minutes }}\"\n          echo \"===================\"\n\n      - name: Block draft PR\n        if: github.event_name == 'pull_request' && fromJson(steps.pr.outputs.draft)\n        run: |\n          echo \"PR is draft. 
Blocking CI.\"\n          exit 1\n\n      - name: Require run-ci label (optional)\n        if:  github.event_name == 'pull_request' && inputs.require-run-ci == true\n        run: |\n          labels='${{ steps.pr.outputs.labels }}'\n          if [[ \"${{ contains(fromJson(steps.pr.outputs.labels), 'run-ci') }}\" == \"false\" ]]; then\n            echo \"Missing required label 'run-ci'. See https://docs.sglang.io/developer_guide/contribution_guide.html#how-to-trigger-ci-tests for more details.\"\n            exit 1\n          fi\n\n      - name: Enforce rate limit for low-permission actors (optional)\n        if: github.event_name == 'pull_request' && inputs.cool-down-minutes > 0\n        uses: actions/github-script@v7\n        with:\n          github-token: ${{ secrets.GITHUB_TOKEN }}\n          script: |\n            const DEFAULT_MINUTES = Number(\"${{ inputs.cool-down-minutes }}\");\n            const owner = context.repo.owner;\n            const repo = context.repo.repo;\n            const eventName = context.eventName;\n            const curRun = await github.rest.actions.getWorkflowRun({\n              owner, repo, run_id: context.runId\n            });\n            let triggeringActor = curRun.data.triggering_actor?.login || context.actor;\n            if (triggeringActor === \"github-actions[bot]\") {\n              triggeringActor = `${{ steps.pr.outputs.user }}`;\n              core.info(\n                `triggering_actor is github-actions[bot]; substituting PR author '${triggeringActor}'.`\n              );\n            }\n\n            async function hasHighPermission(username) {\n              try {\n                const { data } = await github.rest.repos.getCollaboratorPermissionLevel({ owner, repo, username });\n                const perm = data.permission || 'none';\n                return perm === 'write' || perm === 'maintain' || perm === 'admin';\n              } catch (e) {\n                if (e.status === 404 || e.status === 403) return 
false;\n                throw e;\n              }\n            }\n\n            if (await hasHighPermission(triggeringActor)) {\n              core.info(`Triggering user '${triggeringActor}' has high permission. No rate limit applied.`);\n              return;\n            }\n\n            let effectiveCooldownMinutes = DEFAULT_MINUTES;\n            let perUserCooldownMinutes = null;\n\n            try {\n              const contentResp = await github.rest.repos.getContent({\n                owner,\n                repo,\n                path: \".github/CI_PERMISSIONS.json\",\n                ref: \"main\",\n              });\n\n              if (!Array.isArray(contentResp.data) && contentResp.data && \"content\" in contentResp.data) {\n                const raw = Buffer.from(\n                  contentResp.data.content,\n                  contentResp.data.encoding || \"base64\"\n                ).toString();\n                const ciPermissions = JSON.parse(raw);\n\n                const userPerm = ciPermissions[triggeringActor];\n                if (userPerm && typeof userPerm.cooldown_interval_minutes === \"number\") {\n                  perUserCooldownMinutes = userPerm.cooldown_interval_minutes;\n                  core.info(\n                    `Per-user cooldown for '${triggeringActor}' from CI_PERMISSIONS.json: ${perUserCooldownMinutes} minutes.`\n                  );\n                } else {\n                  core.info(`No per-user cooldown found for '${triggeringActor}' in CI_PERMISSIONS.json.`);\n                }\n              } else {\n                core.info(\"CI_PERMISSIONS.json content response is not a file; skipping per-user cooldown.\");\n              }\n            } catch (e) {\n              core.info(`CI_PERMISSIONS.json not found or unreadable: ${e.message}. 
Using default rate limit only.`);\n            }\n\n            if (perUserCooldownMinutes !== null) {\n              effectiveCooldownMinutes = Math.min(effectiveCooldownMinutes, perUserCooldownMinutes);\n            }\n\n            if (effectiveCooldownMinutes <= 0) {\n              core.info(\n                `Effective cooldown for '${triggeringActor}' is 0 minutes; no rate limit enforced for this user.`\n              );\n              return;\n            }\n\n            const cutoff = new Date(Date.now() - effectiveCooldownMinutes * 60 * 1000);\n            core.info(\n              `Checking for workflow runs since ${cutoff.toISOString()} (last ${effectiveCooldownMinutes} minutes) for event '${eventName}'.`\n            );\n\n            const { data } = await github.rest.actions.listWorkflowRuns({\n              owner,\n              repo,\n              workflow_id: 'pr-test.yml',\n              event: eventName,\n              per_page: 100,\n            });\n\n            const runs = data.workflow_runs || [];\n\n            // Rate Limiting Logic:\n            // We only count workflow runs that actually consumed CI resources (i.e., passed the gate).\n            // A run \"passes the gate\" if any jobs beyond the gate jobs (check-changes, pr-gate, call-gate)\n            // actually executed (not skipped/cancelled). This prevents scenarios where:\n            // - User has PR A with missing 'run-ci' label (fails at gate)\n            // - User opens PR B with 'run-ci' label\n            // - PR B should be able to run even though PR A triggered a run recently\n\n            // Helper function to check if a run passed the gate (i.e., actually consumed CI resources)\n            async function didRunPassGate(run) {\n              try {\n                // Note: Fetching up to 100 jobs (API maximum). 
If a workflow has >100 jobs,\n                // we may miss some, but this is unlikely in practice.\n                const { data: jobsData } = await github.rest.actions.listJobsForWorkflowRun({\n                  owner, repo, run_id: run.id, per_page: 100\n                });\n                const jobs = jobsData.jobs || [];\n\n                // If no jobs exist yet, the run hasn't started consuming resources\n                if (jobs.length === 0) {\n                  core.info(`Run ${run.id} has no jobs yet; not counting against rate limit.`);\n                  return false;\n                }\n\n                // Gate jobs that don't consume significant CI resources\n                const gateJobs = ['check-changes', 'pr-gate', 'call-gate', 'pr-test-finish'];\n                const jobsBeyondGate = jobs.filter(j => !gateJobs.some(g => j.name === g || j.name.startsWith(g + ' ')));\n\n                // A job \"ran\" if it reached a terminal conclusion state that indicates actual execution\n                const ranStates = ['success', 'failure', 'timed_out', 'action_required'];\n                const hasJobsThatRan = jobsBeyondGate.some(j => j.conclusion && ranStates.includes(j.conclusion));\n                return hasJobsThatRan;\n              } catch (e) {\n                core.warning(`Could not check jobs for run ${run.id}: ${e.message}`);\n\n                // If it's a rate limit error, count it conservatively to prevent abuse\n                if (e.status === 429) {\n                  core.warning(`Hit rate limit checking run ${run.id}; counting it to be safe.`);\n                  return true;\n                }\n\n                // For cancelled/skipped runs, they likely didn't consume resources\n                if (run.conclusion === 'cancelled' || run.conclusion === 'skipped') {\n                  return false;\n                }\n\n                // Default to counting it to prevent abuse\n                return true;\n              }\n       
     }\n\n            // Limit the number of runs we'll check in detail to avoid API rate limits\n            const MAX_RUNS_TO_CHECK = 5;\n            let runsChecked = 0;\n            let runsSkippedAtGate = 0;\n            let recentFound = null;\n\n            for (const run of runs) {\n              if (String(run.id) === String(context.runId)) continue;\n              if (new Date(run.created_at) < cutoff) continue;\n              const isUserRun = (run.actor?.login === triggeringActor) || (run.triggering_actor?.login === triggeringActor);\n              if (!isUserRun) continue;\n\n              runsChecked++;\n              core.info(`Checking run ${run.id} (created: ${run.created_at}, conclusion: ${run.conclusion})`);\n\n              // Safety limit: if we've checked too many runs, assume the next one passed to be conservative\n              if (runsChecked > MAX_RUNS_TO_CHECK) {\n                core.warning(`Checked ${MAX_RUNS_TO_CHECK} runs; assuming this one passed gate to avoid API limits.`);\n                recentFound = run;\n                break;\n              }\n\n              // Only count runs that actually passed the gate and consumed CI resources\n              if (await didRunPassGate(run)) {\n                recentFound = run;\n                core.info(`Found recent run ${run.id} that passed gate.`);\n                break;\n              } else {\n                runsSkippedAtGate++;\n                core.info(`Run ${run.id} failed at gate; not counting against rate limit.`);\n              }\n            }\n\n            core.info(`Rate limit check summary: checked ${runsChecked} runs, ${runsSkippedAtGate} failed at gate.`);\n\n            if (recentFound) {\n              core.setFailed(\n                `User '${triggeringActor}' already triggered '${context.workflow}' via '${eventName}' at ${recentFound.created_at}. 
` +\n                `Please wait ${effectiveCooldownMinutes} minutes before triggering again.`\n              );\n            } else {\n              core.info(\n                `No recent runs detected for '${triggeringActor}' within the last ${effectiveCooldownMinutes} minutes; proceeding.`\n              );\n            }\n"
  },
  {
    "path": ".github/workflows/pr-test-amd-rocm720.yml",
    "content": "name: PR Test ROCm 7.2 (AMD)\n# Dynamic run-name for /rerun-stage commands to enable URL lookup\n# Format: \"[stage-name] sha\" for fork PRs, \"[stage-name]\" for non-fork, default for normal runs\nrun-name: ${{ (inputs.target_stage || inputs.target_stage_select) && (inputs.pr_head_sha && format('[{0}] {1}', inputs.target_stage || inputs.target_stage_select, inputs.pr_head_sha) || format('[{0}]', inputs.target_stage || inputs.target_stage_select)) || '' }}\n\non:\n  schedule:\n    - cron: '30 17 * * *'\n  # push:\n  #   branches: [ main ]\n  #   paths:\n  #     - \"python/**\"\n  #     - \"scripts/ci/**\"\n  #     - \"test/**\"\n  #     - \"sgl-kernel/**\"\n  #     - \".github/workflows/pr-test-amd-rocm720.yml\"\n  #     - \"docker/rocm.Dockerfile\"\n  # pull_request:\n  #   branches: [ main ]\n  #   paths:\n  #     - \"python/**\"\n  #     - \"scripts/ci/**\"\n  #     - \"test/**\"\n  #     - \"sgl-kernel/**\"\n  #     - \".github/workflows/pr-test-amd-rocm720.yml\"\n  #     - \"docker/rocm.Dockerfile\"\n  workflow_dispatch:\n    inputs:\n      target_stage_select:\n        description: \"Select a stage to run from dropdown (leave empty for auto-detect)\"\n        required: false\n        type: choice\n        default: ''\n        options:\n          - ''\n          - sgl-kernel-unit-test-amd\n          - sgl-kernel-unit-test-2-gpu-amd\n          - stage-a-test-1-amd\n          - jit-kernel-unit-test-amd\n          - stage-b-test-small-1-gpu-amd\n          - stage-b-test-small-1-gpu-amd-nondeterministic\n          - stage-b-test-small-1-gpu-amd-mi35x\n          - stage-b-test-large-1-gpu-amd\n          - stage-b-test-large-2-gpu-amd\n          - multimodal-gen-test-1-gpu-amd\n          - multimodal-gen-test-2-gpu-amd\n          - stage-c-test-large-8-gpu-amd\n          - stage-c-test-large-8-gpu-amd-mi35x\n          - stage-b-test-large-8-gpu-disaggregation-amd\n      target_stage:\n        description: \"Or type comma-separated stage names 
(overrides dropdown if non-empty)\"\n        required: false\n        type: string\n        default: \"\"\n      pr_head_sha:\n        description: \"PR head SHA to checkout (for /rerun-stage on fork PRs)\"\n        required: false\n        type: string\n        default: \"\"\n      aiter_ref:\n        description: 'Override AITER commit (optional, leave empty to use Dockerfile default)'\n        required: false\n        type: string\n        default: ''\n      continue_on_error:\n        description: 'Continue on error (do not fail the workflow on test failures)'\n        required: false\n        type: boolean\n        default: true\n  workflow_call:\n    inputs:\n      ref:\n        description: 'Git ref (branch, tag, or SHA) to test. If not provided, uses the default branch.'\n        required: false\n        type: string\n        default: ''\n      run_all_tests:\n        description: \"Run all tests (for releasing or testing purpose)\"\n        required: false\n        type: boolean\n        default: false\n      aiter_ref:\n        description: 'Override AITER commit (optional, leave empty to use Dockerfile default)'\n        required: false\n        type: string\n        default: ''\n      continue_on_error:\n        description: 'Continue on error (do not fail the workflow on test failures)'\n        required: false\n        type: boolean\n        default: true\n\nenv:\n  AITER_COMMIT_OVERRIDE: ${{ inputs.aiter_ref }}\n\nconcurrency:\n  # When called via workflow_call with run_all_tests=true, use a unique group per run to\n  # avoid collisions with direct schedule/workflow_dispatch triggers. 
We use run_all_tests\n  # (not github.event_name) to detect this, because github.event_name inherits from the caller.\n  group: pr-test-amd-rocm720-${{ inputs.run_all_tests && format('full-{0}', github.run_id) || inputs.pr_head_sha || inputs.ref || github.ref }}\n  cancel-in-progress: ${{ !inputs.run_all_tests && github.event_name != 'workflow_call' }}\n\njobs:\n  call-gate:\n    uses: ./.github/workflows/pr-gate.yml\n    secrets: inherit\n  check-changes:\n    needs: [call-gate]\n    runs-on: ubuntu-latest\n    outputs:\n      main_package: ${{ steps.filter.outputs.main_package || steps.run-mode.outputs.run_all_tests }}\n      sgl_kernel: ${{ steps.filter.outputs.sgl_kernel || steps.run-mode.outputs.run_all_tests }}\n      jit_kernel: ${{ steps.filter.outputs.jit_kernel || steps.run-mode.outputs.run_all_tests }}\n      multimodal_gen: ${{ steps.filter.outputs.multimodal_gen || steps.run-mode.outputs.run_all_tests }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Determine run mode\n        id: run-mode\n        run: |\n          # Run all tests when the run_all_tests input is true (e.g. a full run requested via workflow_call)\n          # Note: github.event_name is inherited from the caller, so we check inputs.run_all_tests instead of the event name\n          if [[ \"${{ inputs.run_all_tests }}\" == \"true\" ]]; then\n            echo \"run_all_tests=true\" >> $GITHUB_OUTPUT\n            echo \"Run mode: ALL TESTS (run_all_tests=${{ inputs.run_all_tests }})\"\n          else\n            echo \"run_all_tests=false\" >> $GITHUB_OUTPUT\n            echo \"Run mode: FILTERED (triggered by ${{ github.event_name }})\"\n          fi\n\n      - name: Detect file changes\n        id: filter\n        uses: dorny/paths-filter@v3\n        if: steps.run-mode.outputs.run_all_tests != 'true'\n        with:\n          filters: |\n            main_package:\n              - 
\"python/sglang/!(multimodal_gen)/**\"\n              - \"python/pyproject_rocm.toml\"\n              - \"python/pyproject_other.toml\"\n              - \"scripts/ci/amd/*\"\n              - \"scripts/ci/utils/*\"\n              - \"test/**\"\n              - \".github/workflows/pr-test-amd-rocm720.yml\"\n            sgl_kernel:\n              - \"sgl-kernel/**\"\n              - \".github/workflows/pr-test-amd-rocm720.yml\"\n            jit_kernel:\n              - \"python/sglang/jit_kernel/**\"\n              - \".github/workflows/pr-test-amd-rocm720.yml\"\n            multimodal_gen:\n              - \"python/sglang/multimodal_gen/**\"\n              - \"python/sglang/cli/**\"\n              - \"python/sglang/jit_kernel/diffusion/**\"\n              - \"python/pyproject_rocm.toml\"\n              - \"python/pyproject_other.toml\"\n\n  # =============================================== sgl-kernel ====================================================\n  sgl-kernel-unit-test-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',sgl-kernel-unit-test-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          needs.check-changes.outputs.sgl_kernel == 'true'\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash 
scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Run test\n        timeout-minutes: 14\n        run: |\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_align.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_topk_softmax.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests/speculative ci_sglang python3 -m pytest test_eagle_utils.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_apply_token_bitmask_inplace.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_activation.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_topk.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_kvcacheio.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_topk_sigmoid.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_torch_defaults_reset.py\n\n  sgl-kernel-unit-test-2-gpu-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',sgl-kernel-unit-test-2-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          needs.check-changes.outputs.sgl_kernel == 'true'\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-2gpu-sglang]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh 
--rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Run test\n        timeout-minutes: 20\n        run: |\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_amd_deterministic_custom_allreduce.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_amd_nccl_allreduce_determinism.py\n\n  # =============================================== primary ====================================================\n\n  stage-a-test-1-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-a-test-1-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Run test\n        timeout-minutes: 10\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-a-test-1-amd 
${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  jit-kernel-unit-test-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',jit-kernel-unit-test-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          needs.check-changes.outputs.jit_kernel == 'true'\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Run JIT kernel unit tests\n        timeout-minutes: 10\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout\" python3 -m pytest -q python/sglang/jit_kernel/tests/test_store_cache.py\n\n  stage-b-test-small-1-gpu-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-b-test-small-1-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n        part: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]\n  
  runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-b-test-small-1-gpu-amd --auto-partition-id ${{ matrix.part }} --auto-partition-size 14 --timeout-per-file 1800 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  stage-b-test-small-1-gpu-amd-nondeterministic:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-b-test-small-1-gpu-amd-nondeterministic,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: 
${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-b-test-small-1-gpu-amd-nondeterministic --timeout-per-file 1800 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  stage-b-test-small-1-gpu-amd-mi35x:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-b-test-small-1-gpu-amd-mi35x,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi35x-gpu-1]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-b-test-small-1-gpu-amd-mi35x ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  stage-b-test-large-1-gpu-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        
(contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-b-test-large-1-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n        part: [0, 1]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-b-test-large-1-gpu-amd --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 --timeout-per-file 1800 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  stage-b-test-large-2-gpu-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-b-test-large-2-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: 
[linux-mi325-2gpu-sglang]\n        part: [0, 1]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-b-test-large-2-gpu-amd --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 --timeout-per-file 1800 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  multimodal-gen-test-1-gpu-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',multimodal-gen-test-1-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      max-parallel: 1  # Run one at a time to avoid eviction from resource exhaustion during AITER kernel JIT\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n        part: [0, 1, 2, 3]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n     
 - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh diffusion\n          docker exec ci_sglang pip install amdsmi\n\n      - name: Setup kernel caches\n        run: |\n          # Use the persistent /sgl-data directory (mounted from /home/runner/sgl-data)\n          # This directory persists across container restarts on the self-hosted runner\n          docker exec ci_sglang mkdir -p /sgl-data/aiter-kernels /sgl-data/miopen-cache /sgl-data/hf-cache/hub\n\n          # Clear pre-built AITER kernels from Docker image to avoid segfaults\n          # The image may have stale/incompatible kernels at /sgl-workspace/aiter/aiter/jit/\n          echo \"Clearing pre-built AITER kernels from Docker image...\"\n          docker exec ci_sglang rm -rf /sgl-workspace/aiter/aiter/jit/*.so 2>/dev/null || true\n          docker exec ci_sglang rm -rf /sgl-data/aiter-kernels/*.so 2>/dev/null || true\n          echo \"AITER kernels cleared - will be rebuilt on first use\"\n\n          # Create persistent cache marker if /sgl-data is a real mount (not ephemeral)\n          # This tells the test cleanup code to NOT delete downloaded models\n          if docker exec ci_sglang test -d /sgl-data && docker exec ci_sglang mountpoint -q /sgl-data 2>/dev/null; then\n            docker exec ci_sglang touch /sgl-data/hf-cache/.persistent_cache\n            echo \"Created .persistent_cache marker - HF cache will persist\"\n          else\n            echo \"WARNING: /sgl-data is not a mount point - 
models will be cleaned up after each test\"\n          fi\n\n          # Check MIOpen cache (VAE convolution kernels)\n          miopen_files=$(docker exec ci_sglang find /sgl-data/miopen-cache -name \"*.udb\" 2>/dev/null | wc -l || echo \"0\")\n          echo \"Found ${miopen_files} MIOpen cache files\"\n\n      - name: Diagnose HF cache and system resources\n        run: |\n          echo \"=== System Memory Status ===\"\n          free -h\n          echo \"\"\n          echo \"=== Disk Space ===\"\n          df -h /home/runner/sgl-data 2>/dev/null || df -h\n          echo \"\"\n          echo \"=== HF Cache Directory Structure ===\"\n          docker exec ci_sglang ls -la /sgl-data/hf-cache/ 2>/dev/null || echo \"HF cache dir not found\"\n          docker exec ci_sglang ls -la /sgl-data/hf-cache/hub/ 2>/dev/null || echo \"HF hub cache not found\"\n          echo \"\"\n          echo \"=== Checking for cached diffusion models (1-GPU tests) ===\"\n          # Models used in 1-GPU tests: Wan2.1-T2V-1.3B, HunyuanVideo, Qwen-Image, FLUX.1, FLUX.2\n          for model in \"Wan-AI--Wan2.1-T2V-1.3B-Diffusers\" \"tencent--HunyuanVideo\" \"Qwen--Qwen-Image\" \"black-forest-labs--FLUX.1-dev\" \"black-forest-labs--FLUX.2-dev\"; do\n            cache_path=\"/sgl-data/hf-cache/hub/models--${model}\"\n            if docker exec ci_sglang test -d \"$cache_path\"; then\n              size=$(docker exec ci_sglang du -sh \"$cache_path\" 2>/dev/null | cut -f1)\n              echo \"✓ CACHED: $model ($size)\"\n            else\n              echo \"✗ NOT CACHED: $model\"\n            fi\n          done\n          echo \"\"\n          echo \"=== GPU Memory Status ===\"\n          docker exec ci_sglang rocm-smi --showmeminfo vram 2>/dev/null || echo \"rocm-smi not available\"\n\n      - name: Run diffusion server tests (1-GPU)\n        timeout-minutes: 60\n        run: |\n          # AMD CI: All 1-GPU tests except FLUX.2 (FLUX.1 covers same code path)\n          # Tests: T2V, T2I, 
I2V, LoRA\n          #\n          # HF download env vars:\n          # - HF_HUB_ENABLE_HF_TRANSFER=1: Use faster hf_transfer for downloads (if available)\n          # - HF_HUB_DISABLE_SYMLINKS_WARNING=1: Suppress symlink warnings\n          docker exec \\\n            -e SGLANG_E2E_TOLERANCE=0.3 \\\n            -e SGLANG_STAGE_TIME_TOLERANCE=0.2 \\\n            -e SGLANG_NON_DENOISE_STAGE_TIME_TOLERANCE=0.6 \\\n            -e SGLANG_DENOISE_STEP_TOLERANCE=0.6 \\\n            -e SGLANG_DENOISE_AGG_TOLERANCE=0.3 \\\n            -e SGLANG_TEST_NUM_INFERENCE_STEPS=5 \\\n            -e AITER_JIT_DIR=/sgl-data/aiter-kernels \\\n            -e MIOPEN_USER_DB_PATH=/sgl-data/miopen-cache \\\n            -e HF_HUB_ENABLE_HF_TRANSFER=1 \\\n            -e HF_HUB_DISABLE_SYMLINKS_WARNING=1 \\\n            -w /sglang-checkout/python \\\n            ci_sglang python3 sglang/multimodal_gen/test/run_suite.py \\\n              --suite 1-gpu \\\n              --partition-id ${{ matrix.part }} \\\n              --total-partitions 4 \\\n              -k \"not flux_2\"\n\n          # Post-test diagnostics\n          echo \"=== Post-test System Memory Status ===\"\n          free -h\n\n  multimodal-gen-test-2-gpu-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',multimodal-gen-test-2-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      max-parallel: 1  # Run one at a time to avoid eviction from resource exhaustion during AITER kernel JIT\n      matrix:\n        runner: [linux-mi325-2gpu-sglang]\n        part: [0, 1]  # 2 partitions: 9 tests ÷ 2 = ~4-5 tests each\n    runs-on: ${{matrix.runner}}\n    steps:\n      - 
name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh diffusion\n          docker exec ci_sglang pip install amdsmi\n\n      - name: Setup kernel caches\n        run: |\n          # Use the persistent /sgl-data directory (mounted from /home/runner/sgl-data)\n          docker exec ci_sglang mkdir -p /sgl-data/aiter-kernels /sgl-data/miopen-cache /sgl-data/hf-cache/hub\n\n          # Clear pre-built AITER kernels from Docker image to avoid segfaults\n          # The image may have stale/incompatible kernels at /sgl-workspace/aiter/aiter/jit/\n          echo \"Clearing pre-built AITER kernels from Docker image...\"\n          docker exec ci_sglang rm -rf /sgl-workspace/aiter/aiter/jit/*.so 2>/dev/null || true\n          docker exec ci_sglang rm -rf /sgl-data/aiter-kernels/*.so 2>/dev/null || true\n          echo \"AITER kernels cleared - will be rebuilt on first use\"\n\n          # Create persistent cache marker if /sgl-data is a real mount (not ephemeral)\n          # This tells the test cleanup code to NOT delete downloaded models\n          if docker exec ci_sglang test -d /sgl-data && docker exec ci_sglang mountpoint -q /sgl-data 2>/dev/null; then\n            docker exec ci_sglang touch /sgl-data/hf-cache/.persistent_cache\n       
     echo \"Created .persistent_cache marker - HF cache will persist\"\n          else\n            echo \"WARNING: /sgl-data is not a mount point - models will be cleaned up after each test\"\n          fi\n\n          # Check MIOpen cache (VAE convolution kernels)\n          miopen_files=$(docker exec ci_sglang find /sgl-data/miopen-cache -name \"*.udb\" 2>/dev/null | wc -l || echo \"0\")\n          echo \"Found ${miopen_files} MIOpen cache files\"\n\n      - name: Diagnose HF cache and system resources\n        run: |\n          echo \"=== System Memory Status ===\"\n          free -h\n          echo \"\"\n          echo \"=== Disk Space ===\"\n          df -h /home/runner/sgl-data 2>/dev/null || df -h\n          echo \"\"\n          echo \"=== HF Cache Directory Structure ===\"\n          docker exec ci_sglang ls -la /sgl-data/hf-cache/ 2>/dev/null || echo \"HF cache dir not found\"\n          docker exec ci_sglang ls -la /sgl-data/hf-cache/hub/ 2>/dev/null || echo \"HF hub cache not found\"\n          echo \"\"\n          echo \"=== Checking for cached diffusion models (2-GPU tests) ===\"\n          # Models used in 2-GPU tests: Wan2.2-T2V-A14B, Wan2.1-T2V-14B, Qwen-Image, FLUX.1\n          for model in \"Wan-AI--Wan2.2-T2V-A14B-Diffusers\" \"Wan-AI--Wan2.1-T2V-14B-Diffusers\" \"Qwen--Qwen-Image\" \"black-forest-labs--FLUX.1-dev\"; do\n            cache_path=\"/sgl-data/hf-cache/hub/models--${model}\"\n            if docker exec ci_sglang test -d \"$cache_path\"; then\n              size=$(docker exec ci_sglang du -sh \"$cache_path\" 2>/dev/null | cut -f1)\n              echo \"✓ CACHED: $model ($size)\"\n            else\n              echo \"✗ NOT CACHED: $model\"\n            fi\n          done\n          echo \"\"\n          echo \"=== GPU Memory Status ===\"\n          docker exec ci_sglang rocm-smi --showmeminfo vram 2>/dev/null || echo \"rocm-smi not available\"\n\n      - name: Run diffusion server tests (2-GPU)\n        timeout-minutes: 80\n        
run: |\n          # AMD CI: All 2-GPU tests including LoRA\n          # Tests: T2V, T2I, I2V, LoRA\n          #\n          # HF download env vars:\n          # - HF_HUB_ENABLE_HF_TRANSFER=1: Use faster hf_transfer for downloads (if available)\n          # - HF_HUB_DISABLE_SYMLINKS_WARNING=1: Suppress symlink warnings\n          docker exec \\\n            -e SGLANG_E2E_TOLERANCE=0.3 \\\n            -e SGLANG_STAGE_TIME_TOLERANCE=0.2 \\\n            -e SGLANG_NON_DENOISE_STAGE_TIME_TOLERANCE=0.6 \\\n            -e SGLANG_DENOISE_STEP_TOLERANCE=0.6 \\\n            -e SGLANG_DENOISE_AGG_TOLERANCE=0.3 \\\n            -e SGLANG_TEST_NUM_INFERENCE_STEPS=5 \\\n            -e AITER_JIT_DIR=/sgl-data/aiter-kernels \\\n            -e MIOPEN_USER_DB_PATH=/sgl-data/miopen-cache \\\n            -e HF_HUB_ENABLE_HF_TRANSFER=1 \\\n            -e HF_HUB_DISABLE_SYMLINKS_WARNING=1 \\\n            -w /sglang-checkout/python \\\n            ci_sglang python3 sglang/multimodal_gen/test/run_suite.py \\\n              --suite 2-gpu \\\n              --partition-id ${{ matrix.part }} \\\n              --total-partitions 2\n\n          # Post-test diagnostics\n          echo \"=== Post-test System Memory Status ===\"\n          free -h\n\n\n  stage-c-test-large-8-gpu-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-c-test-large-8-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    env:\n      RUNNER_LABELS: linux-mi325-8gpu-sglang\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-8gpu-sglang]\n        part: [0, 1, 2]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: 
actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Test RCCL multi-GPU communication\n        timeout-minutes: 5\n        run: |\n          echo \"Testing RCCL multi-GPU communication with debug info...\"\n          docker exec ci_sglang bash -c \"cd /sglang-checkout && NCCL_DEBUG=INFO RCCL_DEBUG=INFO torchrun --nproc_per_node=8 scripts/ci/amd/test_rccl_multi_gpu.py\"\n\n      - name: Run test\n        timeout-minutes: 60\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-c-test-large-8-gpu-amd --auto-partition-id ${{ matrix.part }} --auto-partition-size 3 --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  stage-c-test-large-8-gpu-amd-mi35x:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-c-test-large-8-gpu-amd-mi35x,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi35x-gpu-8]\n        part: [0, 1]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - 
name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n      - name: Run test\n        timeout-minutes: 60\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-c-test-large-8-gpu-amd-mi35x --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  # =============================================== Disaggregation ====================================================\n  stage-b-test-large-8-gpu-35x-disaggregation-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-b-test-large-8-gpu-disaggregation-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi35x-gpu-8.fabric]\n\n    runs-on: ${{matrix.runner}}\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Check Host RDMA Environment\n        id: rdma_detect\n        run: |\n          set +e\n          echo \"=== Checking Host RDMA Environment ===\"\n\n          echo \"\"\n          echo \"=== 1. 
Ionic driver library check ===\"\n          ls -l /usr/lib/x86_64-linux-gnu/libibverbs/libionic* 2>/dev/null || echo \"libionic not found in standard path\"\n\n          echo \"\"\n          echo \"=== 2. Infiniband devices ===\"\n          ls -la /dev/infiniband/ 2>/dev/null || echo \"/dev/infiniband not found\"\n          ls -la /sys/class/infiniband/ 2>/dev/null || echo \"/sys/class/infiniband not found\"\n\n          echo \"\"\n          echo \"=== 3. ibv_devinfo ===\"\n          which ibv_devinfo 2>/dev/null && ibv_devinfo 2>&1 || echo \"ibv_devinfo not available\"\n\n          echo \"\"\n          echo \"=== 4. Kernel modules ===\"\n          lsmod 2>/dev/null | grep -E \"ib_|rdma|ionic\" || echo \"No RDMA kernel modules loaded\"\n\n          echo \"\"\n          echo \"=== 5. Detect RDMA Devices for test environment ===\"\n          if [ -d \"/sys/class/infiniband\" ]; then\n            RDMA_DEVS=$(ls /sys/class/infiniband | paste -sd \",\" -)\n            echo \"Detected RDMA Devices: $RDMA_DEVS\"\n            echo \"SGLANG_TEST_RDMA_DEVICE=$RDMA_DEVS\" >> $GITHUB_ENV\n          else\n            echo \"No RDMA devices found in /sys/class/infiniband\"\n            echo \"SGLANG_TEST_RDMA_DEVICE=\" >> $GITHUB_ENV\n          fi\n\n          echo \"\"\n          echo \"=== Host RDMA Check Complete ===\"\n\n      - name: Start Special Container\n        run: bash scripts/ci/amd/amd_ci_start_container_disagg.sh --rocm-version rocm720\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Verify RDMA in Container\n        run: |\n          docker exec -u root ci_sglang bash -c '\n            echo \"=== Container RDMA Verification ===\"\n            echo \"Device nodes:\"\n            ls -la /dev/infiniband/\n            echo \"\"\n            echo \"Provider libraries:\"\n            ls /usr/lib/x86_64-linux-gnu/libibverbs/ | grep -E 
\"ionic|mlx\" || echo \"No Ionic/Mellanox providers\"\n            echo \"\"\n            echo \"HCA devices:\"\n            HCA_COUNT=$(ibv_devinfo -list 2>&1 | grep -oE \"^[0-9]+ HCAs? found\" | grep -oE \"^[0-9]+\" || echo \"0\")\n            ibv_devinfo -list\n            if [ \"$HCA_COUNT\" -gt 0 ]; then\n              echo \"\"\n              echo \"=== SUCCESS: RDMA setup complete. Found $HCA_COUNT HCA(s) ===\"\n            else\n              echo \"\"\n              echo \"=== WARNING: No HCAs detected. RDMA tests may fail ===\"\n            fi\n          '\n\n      - name: Run Aiter Op Test (RMSNorm)\n        timeout-minutes: 10\n        run: |\n          echo \"Running pre-check: test_rmsnorm2d.py\"\n          docker exec \\\n            -e MAX_JOBS=192 \\\n            ci_sglang \\\n            python /sgl-workspace/aiter/op_tests/test_rmsnorm2d.py\n\n      - name: Run test_disaggregation\n        timeout-minutes: 60\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh \\\n            -e SGLANG_TEST_RDMA_DEVICE=\"${{ env.SGLANG_TEST_RDMA_DEVICE }}\" \\\n            -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-b-test-large-8-gpu-35x-disaggregation-amd --timeout-per-file 1800 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  pr-test-amd-finish:\n    needs:\n      [\n        call-gate,\n        check-changes,\n\n        sgl-kernel-unit-test-amd,\n        sgl-kernel-unit-test-2-gpu-amd,\n        multimodal-gen-test-1-gpu-amd,\n        multimodal-gen-test-2-gpu-amd,\n\n        stage-a-test-1-amd,\n        jit-kernel-unit-test-amd,\n        stage-b-test-small-1-gpu-amd,\n        stage-b-test-small-1-gpu-amd-nondeterministic,\n        stage-b-test-small-1-gpu-amd-mi35x,\n        stage-b-test-large-1-gpu-amd,\n        stage-b-test-large-2-gpu-amd,\n        stage-b-test-large-8-gpu-35x-disaggregation-amd,\n        stage-c-test-large-8-gpu-amd,\n        stage-c-test-large-8-gpu-amd-mi35x,\n      ]\n    if: 
always()\n    runs-on: ubuntu-latest\n    steps:\n      - name: Check all dependent job statuses\n        run: |\n          # Convert the 'needs' context to a JSON string\n          json_needs='${{ toJson(needs) }}'\n\n          # Get a list of all job names from the JSON keys\n          job_names=$(echo \"$json_needs\" | jq -r 'keys_unsorted[]')\n\n          for job in $job_names; do\n            # For each job, extract its result\n            result=$(echo \"$json_needs\" | jq -r --arg j \"$job\" '.[$j].result')\n\n            # Print the job name and its result\n            echo \"$job: $result\"\n\n            # Check for failure or cancellation and exit if found\n            if [[ \"$result\" == \"failure\" || \"$result\" == \"cancelled\" ]]; then\n              echo \"Job $job did not succeed (result: $result).\"\n              exit 1\n            fi\n          done\n\n          # If the loop completes, no job failed or was cancelled (skipped jobs are allowed)\n          echo \"All jobs completed successfully\"\n          exit 0\n"
  },
  {
    "path": ".github/workflows/pr-test-amd.yml",
    "content": "name: PR Test (AMD)\n# Dynamic run-name for /rerun-stage commands to enable URL lookup\n# Format: \"[stage-name] sha\" for fork PRs, \"[stage-name]\" for non-fork, default for normal runs\nrun-name: ${{ (inputs.target_stage || inputs.target_stage_select) && (inputs.pr_head_sha && format('[{0}] {1}', inputs.target_stage || inputs.target_stage_select, inputs.pr_head_sha) || format('[{0}]', inputs.target_stage || inputs.target_stage_select)) || '' }}\n\non:\n  push:\n    branches: [ main ]\n    paths:\n      - \"python/**\"\n      - \"scripts/ci/**\"\n      - \"test/**\"\n      - \"sgl-kernel/**\"\n      - \".github/workflows/pr-test-amd.yml\"\n      - \"docker/rocm.Dockerfile\"\n  pull_request:\n    branches: [ main ]\n    paths:\n      - \"python/**\"\n      - \"scripts/ci/**\"\n      - \"test/**\"\n      - \"sgl-kernel/**\"\n      - \".github/workflows/pr-test-amd.yml\"\n      - \"docker/rocm.Dockerfile\"\n  workflow_dispatch:\n    inputs:\n      target_stage_select:\n        description: \"Select a stage to run from dropdown (leave empty for auto-detect)\"\n        required: false\n        type: choice\n        default: ''\n        options:\n          - ''\n          - sgl-kernel-unit-test-amd\n          - sgl-kernel-unit-test-2-gpu-amd\n          - stage-a-test-1-amd\n          - jit-kernel-unit-test-amd\n          - stage-b-test-small-1-gpu-amd\n          - stage-b-test-small-1-gpu-amd-nondeterministic\n          - stage-b-test-small-1-gpu-amd-mi35x\n          - stage-b-test-large-1-gpu-amd\n          - stage-b-test-large-2-gpu-amd\n          - multimodal-gen-test-1-gpu-amd\n          - multimodal-gen-test-2-gpu-amd\n          - stage-c-test-large-8-gpu-amd\n          - stage-c-test-large-8-gpu-amd-mi35x\n          - stage-b-test-large-8-gpu-disaggregation-amd\n      target_stage:\n        description: \"Or type comma-separated stage names (overrides dropdown if non-empty)\"\n        required: false\n        type: string\n        default: \"\"\n  
    pr_head_sha:\n        description: \"PR head SHA to checkout (for /rerun-stage on fork PRs)\"\n        required: false\n        type: string\n        default: \"\"\n      aiter_ref:\n        description: 'Override AITER commit (optional, leave empty to use Dockerfile default)'\n        required: false\n        type: string\n        default: ''\n      continue_on_error:\n        description: 'Continue on error (do not fail the workflow on test failures)'\n        required: false\n        type: boolean\n        default: false\n  workflow_call:\n    inputs:\n      ref:\n        description: 'Git ref (branch, tag, or SHA) to test. If not provided, uses the default branch.'\n        required: false\n        type: string\n        default: ''\n      run_all_tests:\n        description: \"Run all tests (for releasing or testing purpose)\"\n        required: false\n        type: boolean\n        default: false\n      aiter_ref:\n        description: 'Override AITER commit (optional, leave empty to use Dockerfile default)'\n        required: false\n        type: string\n        default: ''\n      continue_on_error:\n        description: 'Continue on error (do not fail the workflow on test failures)'\n        required: false\n        type: boolean\n        default: false\n\nenv:\n  AITER_COMMIT_OVERRIDE: ${{ inputs.aiter_ref }}\n\nconcurrency:\n  # When called via workflow_call with run_all_tests=true, use a unique group per run to\n  # avoid collisions with direct push/PR triggers. 
We use run_all_tests (not github.event_name)\n  # to detect this, because github.event_name inherits from the caller in workflow_call.\n  group: pr-test-amd-${{ inputs.run_all_tests && format('full-{0}', github.run_id) || inputs.pr_head_sha || inputs.ref || github.ref }}\n  cancel-in-progress: ${{ !inputs.run_all_tests && github.event_name != 'workflow_call' }}\n\njobs:\n  call-gate:\n    uses: ./.github/workflows/pr-gate.yml\n    secrets: inherit\n  check-changes:\n    needs: [call-gate]\n    runs-on: ubuntu-latest\n    outputs:\n      main_package: ${{ steps.filter.outputs.main_package || steps.run-mode.outputs.run_all_tests }}\n      sgl_kernel: ${{ steps.filter.outputs.sgl_kernel || steps.run-mode.outputs.run_all_tests }}\n      jit_kernel: ${{ steps.filter.outputs.jit_kernel || steps.run-mode.outputs.run_all_tests }}\n      multimodal_gen: ${{ steps.filter.outputs.multimodal_gen || steps.run-mode.outputs.run_all_tests }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Determine run mode\n        id: run-mode\n        run: |\n          # Run all tests for workflow_call (when ref input is provided)\n          # Note: github.event_name is inherited from caller, so we detect workflow_call by checking inputs.ref\n          if [[ \"${{ inputs.run_all_tests }}\" == \"true\" ]]; then\n            echo \"run_all_tests=true\" >> $GITHUB_OUTPUT\n            echo \"Run mode: ALL TESTS (run_all_tests=${{ inputs.run_all_tests }})\"\n          else\n            echo \"run_all_tests=false\" >> $GITHUB_OUTPUT\n            echo \"Run mode: FILTERED (triggered by ${{ github.event_name }})\"\n          fi\n\n      - name: Detect file changes\n        id: filter\n        uses: dorny/paths-filter@v3\n        if: steps.run-mode.outputs.run_all_tests != 'true'\n        with:\n          filters: |\n            main_package:\n              - 
\"python/sglang/!(multimodal_gen)/**\"\n              - \"python/pyproject_rocm.toml\"\n              - \"python/pyproject_other.toml\"\n              - \"scripts/ci/amd/*\"\n              - \"scripts/ci/utils/*\"\n              - \"test/**\"\n              - \".github/workflows/pr-test-amd.yml\"\n            sgl_kernel:\n              - \"sgl-kernel/**\"\n              - \".github/workflows/pr-test-amd.yml\"\n            jit_kernel:\n              - \"python/sglang/jit_kernel/**\"\n              - \".github/workflows/pr-test-amd.yml\"\n            multimodal_gen:\n              - \"python/sglang/multimodal_gen/**\"\n              - \"python/sglang/cli/**\"\n              - \"python/sglang/jit_kernel/diffusion/**\"\n              - \"python/pyproject_rocm.toml\"\n              - \"python/pyproject_other.toml\"\n\n  # =============================================== sgl-kernel ====================================================\n  sgl-kernel-unit-test-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',sgl-kernel-unit-test-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          needs.check-changes.outputs.sgl_kernel == 'true'\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Run 
test\n        timeout-minutes: 14\n        run: |\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_align.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_topk_softmax.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests/speculative ci_sglang python3 -m pytest test_eagle_utils.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_apply_token_bitmask_inplace.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_activation.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_topk.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_kvcacheio.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_topk_sigmoid.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_torch_defaults_reset.py\n\n  sgl-kernel-unit-test-2-gpu-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',sgl-kernel-unit-test-2-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          needs.check-changes.outputs.sgl_kernel == 'true'\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-2gpu-sglang]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      
- name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 20\n        run: |\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_amd_deterministic_custom_allreduce.py\n          docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_amd_nccl_allreduce_determinism.py\n\n  # =============================================== primary ====================================================\n\n  stage-a-test-1-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-a-test-1-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 10\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-a-test-1-amd ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  jit-kernel-unit-test-amd:\n    needs: 
[check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',jit-kernel-unit-test-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          needs.check-changes.outputs.jit_kernel == 'true'\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Run JIT kernel unit tests\n        timeout-minutes: 10\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout\" python3 -m pytest -q python/sglang/jit_kernel/tests/test_store_cache.py\n\n  stage-b-test-small-1-gpu-amd:\n    needs: [check-changes, stage-a-test-1-amd]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-b-test-small-1-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n        part: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: 
actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-b-test-small-1-gpu-amd --auto-partition-id ${{ matrix.part }} --auto-partition-size 14 --timeout-per-file 1800 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  stage-b-test-small-1-gpu-amd-nondeterministic:\n    needs: [check-changes, stage-a-test-1-amd]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-b-test-small-1-gpu-amd-nondeterministic,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash 
scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-b-test-small-1-gpu-amd-nondeterministic --timeout-per-file 1800 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  stage-b-test-small-1-gpu-amd-mi35x:\n    needs: [check-changes, stage-a-test-1-amd]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-b-test-small-1-gpu-amd-mi35x,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi35x-gpu-1]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-b-test-small-1-gpu-amd-mi35x ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  stage-b-test-large-1-gpu-amd:\n    needs: [check-changes, stage-a-test-1-amd]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || 
inputs.target_stage_select), ',stage-b-test-large-1-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n        part: [0, 1]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-b-test-large-1-gpu-amd --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 --timeout-per-file 1800 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  stage-b-test-large-2-gpu-amd:\n    needs: [check-changes, stage-a-test-1-amd]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-b-test-large-2-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-2gpu-sglang]\n        part: [0, 1]\n    runs-on: 
${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-b-test-large-2-gpu-amd --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 --timeout-per-file 1800 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  multimodal-gen-test-1-gpu-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',multimodal-gen-test-1-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          needs.check-changes.outputs.multimodal_gen == 'true'\n        )\n      )\n    strategy:\n      fail-fast: false\n      max-parallel: 1  # Run one at a time to avoid eviction from resource exhaustion during AITER kernel JIT\n      matrix:\n        runner: [linux-mi325-1gpu-sglang]\n        part: [0, 1, 2, 3]  # 4 partitions: 11 tests ÷ 4 = ~2-3 tests each\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: 
actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh diffusion\n\n      - name: Setup kernel caches\n        run: |\n          # Use the persistent /sgl-data directory (mounted from /home/runner/sgl-data)\n          # This directory persists across container restarts on the self-hosted runner\n          docker exec ci_sglang mkdir -p /sgl-data/aiter-kernels /sgl-data/miopen-cache /sgl-data/hf-cache/hub\n\n          # Clear pre-built AITER kernels from Docker image to avoid segfaults\n          # The image may have stale/incompatible kernels at /sgl-workspace/aiter/aiter/jit/\n          echo \"Clearing pre-built AITER kernels from Docker image...\"\n          docker exec ci_sglang rm -rf /sgl-workspace/aiter/aiter/jit/*.so 2>/dev/null || true\n          docker exec ci_sglang rm -rf /sgl-data/aiter-kernels/*.so 2>/dev/null || true\n          echo \"AITER kernels cleared - will be rebuilt on first use\"\n\n          # Create persistent cache marker if /sgl-data is a real mount (not ephemeral)\n          # This tells the test cleanup code to NOT delete downloaded models\n          if docker exec ci_sglang test -d /sgl-data && docker exec ci_sglang mountpoint -q /sgl-data 2>/dev/null; then\n            docker exec ci_sglang touch /sgl-data/hf-cache/.persistent_cache\n            echo \"Created .persistent_cache marker - HF cache will persist\"\n          else\n            echo \"WARNING: /sgl-data is not a mount point - models will be cleaned up after each test\"\n          fi\n\n          # Check MIOpen cache (VAE convolution kernels)\n          miopen_files=$(docker exec ci_sglang find 
/sgl-data/miopen-cache -name \"*.udb\" 2>/dev/null | wc -l || echo \"0\")\n          echo \"Found ${miopen_files} MIOpen cache files\"\n\n      - name: Diagnose HF cache and system resources\n        run: |\n          echo \"=== System Memory Status ===\"\n          free -h\n          echo \"\"\n          echo \"=== Disk Space ===\"\n          df -h /home/runner/sgl-data 2>/dev/null || df -h\n          echo \"\"\n          echo \"=== HF Cache Directory Structure ===\"\n          docker exec ci_sglang ls -la /sgl-data/hf-cache/ 2>/dev/null || echo \"HF cache dir not found\"\n          docker exec ci_sglang ls -la /sgl-data/hf-cache/hub/ 2>/dev/null || echo \"HF hub cache not found\"\n          echo \"\"\n          echo \"=== Checking for cached diffusion models (1-GPU tests) ===\"\n          # Models used in 1-GPU tests: Wan2.1-T2V-1.3B, HunyuanVideo, Qwen-Image, FLUX.1, FLUX.2\n          for model in \"Wan-AI--Wan2.1-T2V-1.3B-Diffusers\" \"tencent--HunyuanVideo\" \"Qwen--Qwen-Image\" \"black-forest-labs--FLUX.1-dev\" \"black-forest-labs--FLUX.2-dev\"; do\n            cache_path=\"/sgl-data/hf-cache/hub/models--${model}\"\n            if docker exec ci_sglang test -d \"$cache_path\"; then\n              size=$(docker exec ci_sglang du -sh \"$cache_path\" 2>/dev/null | cut -f1)\n              echo \"✓ CACHED: $model ($size)\"\n            else\n              echo \"✗ NOT CACHED: $model\"\n            fi\n          done\n          echo \"\"\n          echo \"=== GPU Memory Status ===\"\n          docker exec ci_sglang rocm-smi --showmeminfo vram 2>/dev/null || echo \"rocm-smi not available\"\n\n      - name: Run diffusion server tests (1-GPU)\n        timeout-minutes: 90\n        run: |\n          # AMD CI: All 1-GPU tests except FLUX.2 (FLUX.1 covers same code path)\n          # Tests: T2V, T2I, I2V, LoRA\n          #\n          # HF download env vars:\n          # - HF_HUB_ENABLE_HF_TRANSFER=1: Use faster hf_transfer for downloads (if available)\n          # - 
HF_HUB_DISABLE_SYMLINKS_WARNING=1: Suppress symlink warnings\n          docker exec \\\n            -e SGLANG_E2E_TOLERANCE=0.3 \\\n            -e SGLANG_STAGE_TIME_TOLERANCE=0.2 \\\n            -e SGLANG_NON_DENOISE_STAGE_TIME_TOLERANCE=0.6 \\\n            -e SGLANG_DENOISE_STEP_TOLERANCE=0.6 \\\n            -e SGLANG_DENOISE_AGG_TOLERANCE=0.3 \\\n            -e SGLANG_TEST_NUM_INFERENCE_STEPS=5 \\\n            -e AITER_JIT_DIR=/sgl-data/aiter-kernels \\\n            -e MIOPEN_USER_DB_PATH=/sgl-data/miopen-cache \\\n            -e HF_HUB_ENABLE_HF_TRANSFER=1 \\\n            -e HF_HUB_DISABLE_SYMLINKS_WARNING=1 \\\n            -w /sglang-checkout/python \\\n            ci_sglang python3 sglang/multimodal_gen/test/run_suite.py \\\n              --suite 1-gpu \\\n              --partition-id ${{ matrix.part }} \\\n              --total-partitions 4 \\\n              -k \"not flux_2\"\n\n          # Post-test diagnostics\n          echo \"=== Post-test System Memory Status ===\"\n          free -h\n\n  multimodal-gen-test-2-gpu-amd:\n    needs: [check-changes]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',multimodal-gen-test-2-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          needs.check-changes.outputs.multimodal_gen == 'true'\n        )\n      )\n    strategy:\n      fail-fast: false\n      max-parallel: 1  # Run one at a time to avoid eviction from resource exhaustion during AITER kernel JIT\n      matrix:\n        runner: [linux-mi325-2gpu-sglang]\n        part: [0, 1]  # 2 partitions: 9 tests ÷ 2 = ~4-5 tests each\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Download 
artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/amd/amd_ci_install_dependency.sh diffusion\n\n      - name: Setup kernel caches\n        run: |\n          # Use the persistent /sgl-data directory (mounted from /home/runner/sgl-data)\n          docker exec ci_sglang mkdir -p /sgl-data/aiter-kernels /sgl-data/miopen-cache /sgl-data/hf-cache/hub\n\n          # Clear pre-built AITER kernels from Docker image to avoid segfaults\n          # The image may have stale/incompatible kernels at /sgl-workspace/aiter/aiter/jit/\n          echo \"Clearing pre-built AITER kernels from Docker image...\"\n          docker exec ci_sglang rm -rf /sgl-workspace/aiter/aiter/jit/*.so 2>/dev/null || true\n          docker exec ci_sglang rm -rf /sgl-data/aiter-kernels/*.so 2>/dev/null || true\n          echo \"AITER kernels cleared - will be rebuilt on first use\"\n\n          # Create persistent cache marker if /sgl-data is a real mount (not ephemeral)\n          # This tells the test cleanup code to NOT delete downloaded models\n          if docker exec ci_sglang test -d /sgl-data && docker exec ci_sglang mountpoint -q /sgl-data 2>/dev/null; then\n            docker exec ci_sglang touch /sgl-data/hf-cache/.persistent_cache\n            echo \"Created .persistent_cache marker - HF cache will persist\"\n          else\n            echo \"WARNING: /sgl-data is not a mount point - models will be cleaned up after each test\"\n          fi\n\n          # Check MIOpen cache (VAE convolution kernels)\n          miopen_files=$(docker exec ci_sglang find 
/sgl-data/miopen-cache -name \"*.udb\" 2>/dev/null | wc -l || echo \"0\")\n          echo \"Found ${miopen_files} MIOpen cache files\"\n\n      - name: Diagnose HF cache and system resources\n        run: |\n          echo \"=== System Memory Status ===\"\n          free -h\n          echo \"\"\n          echo \"=== Disk Space ===\"\n          df -h /home/runner/sgl-data 2>/dev/null || df -h\n          echo \"\"\n          echo \"=== HF Cache Directory Structure ===\"\n          docker exec ci_sglang ls -la /sgl-data/hf-cache/ 2>/dev/null || echo \"HF cache dir not found\"\n          docker exec ci_sglang ls -la /sgl-data/hf-cache/hub/ 2>/dev/null || echo \"HF hub cache not found\"\n          echo \"\"\n          echo \"=== Checking for cached diffusion models (2-GPU tests) ===\"\n          # Models used in 2-GPU tests: Wan2.2-T2V-A14B, Wan2.1-T2V-14B, Qwen-Image, FLUX.1\n          for model in \"Wan-AI--Wan2.2-T2V-A14B-Diffusers\" \"Wan-AI--Wan2.1-T2V-14B-Diffusers\" \"Qwen--Qwen-Image\" \"black-forest-labs--FLUX.1-dev\"; do\n            cache_path=\"/sgl-data/hf-cache/hub/models--${model}\"\n            if docker exec ci_sglang test -d \"$cache_path\"; then\n              size=$(docker exec ci_sglang du -sh \"$cache_path\" 2>/dev/null | cut -f1)\n              echo \"✓ CACHED: $model ($size)\"\n            else\n              echo \"✗ NOT CACHED: $model\"\n            fi\n          done\n          echo \"\"\n          echo \"=== GPU Memory Status ===\"\n          docker exec ci_sglang rocm-smi --showmeminfo vram 2>/dev/null || echo \"rocm-smi not available\"\n\n      - name: Run diffusion server tests (2-GPU)\n        timeout-minutes: 80\n        run: |\n          # AMD CI: All 2-GPU tests including LoRA\n          # Tests: T2V, T2I, I2V, LoRA\n          #\n          # HF download env vars:\n          # - HF_HUB_ENABLE_HF_TRANSFER=1: Use faster hf_transfer for downloads (if available)\n          # - HF_HUB_DISABLE_SYMLINKS_WARNING=1: Suppress symlink warnings\n   
       docker exec \\\n            -e SGLANG_E2E_TOLERANCE=0.3 \\\n            -e SGLANG_STAGE_TIME_TOLERANCE=0.2 \\\n            -e SGLANG_NON_DENOISE_STAGE_TIME_TOLERANCE=0.6 \\\n            -e SGLANG_DENOISE_STEP_TOLERANCE=0.6 \\\n            -e SGLANG_DENOISE_AGG_TOLERANCE=0.3 \\\n            -e SGLANG_TEST_NUM_INFERENCE_STEPS=5 \\\n            -e AITER_JIT_DIR=/sgl-data/aiter-kernels \\\n            -e MIOPEN_USER_DB_PATH=/sgl-data/miopen-cache \\\n            -e HF_HUB_ENABLE_HF_TRANSFER=1 \\\n            -e HF_HUB_DISABLE_SYMLINKS_WARNING=1 \\\n            -w /sglang-checkout/python \\\n            ci_sglang python3 sglang/multimodal_gen/test/run_suite.py \\\n              --suite 2-gpu \\\n              --partition-id ${{ matrix.part }} \\\n              --total-partitions 2\n\n          # Post-test diagnostics\n          echo \"=== Post-test System Memory Status ===\"\n          free -h\n\n\n  stage-c-test-large-8-gpu-amd:\n    needs: [check-changes, call-gate, stage-b-test-small-1-gpu-amd, stage-b-test-large-2-gpu-amd]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-c-test-large-8-gpu-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    env:\n      RUNNER_LABELS: linux-mi325-8gpu-sglang\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi325-8gpu-sglang]\n        part: [0, 1, 2]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash 
scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Test RCCL multi-GPU communication\n        timeout-minutes: 5\n        run: |\n          echo \"Testing RCCL multi-GPU communication with debug info...\"\n          docker exec ci_sglang bash -c \"cd /sglang-checkout && NCCL_DEBUG=INFO RCCL_DEBUG=INFO torchrun --nproc_per_node=8 scripts/ci/amd/test_rccl_multi_gpu.py\"\n\n      - name: Run test\n        timeout-minutes: 60\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-c-test-large-8-gpu-amd --auto-partition-id ${{ matrix.part }} --auto-partition-size 3 --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  stage-c-test-large-8-gpu-amd-mi35x:\n    needs: [check-changes, call-gate, stage-b-test-small-1-gpu-amd, stage-b-test-large-2-gpu-amd]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-c-test-large-8-gpu-amd-mi35x,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi35x-gpu-8]\n        part: [0, 1]\n    runs-on: ${{matrix.runner}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Start CI container\n        run: bash scripts/ci/amd/amd_ci_start_container.sh\n        env:\n          
GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 60\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-c-test-large-8-gpu-amd-mi35x --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  # =============================================== Disaggregation ====================================================\n  stage-b-test-large-8-gpu-35x-disaggregation-amd:\n    needs: [check-changes, stage-a-test-1-amd]\n    if: |\n      always() &&\n      (\n        (contains(format(',{0},', inputs.target_stage || inputs.target_stage_select), ',stage-b-test-large-8-gpu-disaggregation-amd,')) ||\n        (\n          !(inputs.target_stage || inputs.target_stage_select) &&\n          (!failure() && !cancelled()) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    strategy:\n      fail-fast: false\n      matrix:\n        runner: [linux-mi35x-gpu-8.fabric]\n\n    runs-on: ${{matrix.runner}}\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.ref || github.sha }}\n\n      - name: Ensure VRAM is clear\n        run: bash scripts/ensure_vram_clear.sh rocm\n\n      - name: Check Host RDMA Environment\n        id: rdma_detect\n        run: |\n          set +e\n          echo \"=== Checking Host RDMA Environment ===\"\n\n          echo \"\"\n          echo \"=== 1. Ionic driver library check ===\"\n          ls -l /usr/lib/x86_64-linux-gnu/libibverbs/libionic* 2>/dev/null || echo \"libionic not found in standard path\"\n\n          echo \"\"\n          echo \"=== 2. 
Infiniband devices ===\"\n          ls -la /dev/infiniband/ 2>/dev/null || echo \"/dev/infiniband not found\"\n          ls -la /sys/class/infiniband/ 2>/dev/null || echo \"/sys/class/infiniband not found\"\n\n          echo \"\"\n          echo \"=== 3. ibv_devinfo ===\"\n          which ibv_devinfo 2>/dev/null && ibv_devinfo 2>&1 || echo \"ibv_devinfo not available\"\n\n          echo \"\"\n          echo \"=== 4. Kernel modules ===\"\n          lsmod 2>/dev/null | grep -E \"ib_|rdma|ionic\" || echo \"No RDMA kernel modules loaded\"\n\n          echo \"\"\n          echo \"=== 5. Detect RDMA Devices for test environment ===\"\n          if [ -d \"/sys/class/infiniband\" ]; then\n            RDMA_DEVS=$(ls /sys/class/infiniband | paste -sd \",\" -)\n            echo \"Detected RDMA Devices: $RDMA_DEVS\"\n            echo \"SGLANG_TEST_RDMA_DEVICE=$RDMA_DEVS\" >> $GITHUB_ENV\n          else\n            echo \"No RDMA devices found in /sys/class/infiniband\"\n            echo \"SGLANG_TEST_RDMA_DEVICE=\" >> $GITHUB_ENV\n          fi\n\n          echo \"\"\n          echo \"=== Host RDMA Check Complete ===\"\n\n      - name: Start Special Container\n        run: bash scripts/ci/amd/amd_ci_start_container_disagg.sh\n        env:\n          GITHUB_WORKSPACE: ${{ github.workspace }}\n\n      - name: Install dependencies\n        run: bash scripts/ci/amd/amd_ci_install_dependency.sh\n\n      - name: Verify RDMA in Container\n        run: |\n          docker exec -u root ci_sglang bash -c '\n            echo \"=== Container RDMA Verification ===\"\n            echo \"Device nodes:\"\n            ls -la /dev/infiniband/\n            echo \"\"\n            echo \"Provider libraries:\"\n            ls /usr/lib/x86_64-linux-gnu/libibverbs/ | grep -E \"ionic|mlx\" || echo \"No Ionic/Mellanox providers\"\n            echo \"\"\n            echo \"HCA devices:\"\n            HCA_COUNT=$(ibv_devinfo -list 2>&1 | grep -oE \"^[0-9]+ HCAs? 
found\" | grep -oE \"^[0-9]+\" || echo \"0\")\n            ibv_devinfo -list\n            if [ \"$HCA_COUNT\" -gt 0 ]; then\n              echo \"\"\n              echo \"=== SUCCESS: RDMA setup complete. Found $HCA_COUNT HCA(s) ===\"\n            else\n              echo \"\"\n              echo \"=== WARNING: No HCAs detected. RDMA tests may fail ===\"\n            fi\n          '\n\n      - name: Run Aiter Op Test (RMSNorm)\n        timeout-minutes: 10\n        run: |\n          echo \"Running pre-check: test_rmsnorm2d.py\"\n          docker exec \\\n            -e MAX_JOBS=192 \\\n            ci_sglang \\\n            python /sgl-workspace/aiter/op_tests/test_rmsnorm2d.py\n\n      - name: Run test_disaggregation\n        timeout-minutes: 60\n        run: |\n          bash scripts/ci/amd/amd_ci_exec.sh \\\n            -e SGLANG_TEST_RDMA_DEVICE=\"${{ env.SGLANG_TEST_RDMA_DEVICE }}\" \\\n            -w \"/sglang-checkout/test\" python3 run_suite.py --hw amd --suite stage-b-test-large-8-gpu-35x-disaggregation-amd --timeout-per-file 1800 ${{ inputs.continue_on_error && '--continue-on-error' || '' }}\n\n  pr-test-amd-finish:\n    needs:\n      [\n        call-gate,\n        check-changes,\n\n        sgl-kernel-unit-test-amd,\n        sgl-kernel-unit-test-2-gpu-amd,\n        multimodal-gen-test-1-gpu-amd,\n        multimodal-gen-test-2-gpu-amd,\n\n        stage-a-test-1-amd,\n        jit-kernel-unit-test-amd,\n        stage-b-test-small-1-gpu-amd,\n        stage-b-test-small-1-gpu-amd-nondeterministic,\n        stage-b-test-small-1-gpu-amd-mi35x,\n        stage-b-test-large-1-gpu-amd,\n        stage-b-test-large-2-gpu-amd,\n        stage-b-test-large-8-gpu-35x-disaggregation-amd,\n        stage-c-test-large-8-gpu-amd,\n        stage-c-test-large-8-gpu-amd-mi35x,\n      ]\n    if: always()\n    runs-on: ubuntu-latest\n    steps:\n      - name: Check all dependent job statuses\n        run: |\n          # Convert the 'needs' context to a JSON string\n          
json_needs='${{ toJson(needs) }}'\n\n          # Get a list of all job names from the JSON keys\n          job_names=$(echo \"$json_needs\" | jq -r 'keys_unsorted[]')\n\n          for job in $job_names; do\n            # For each job, extract its result\n            result=$(echo \"$json_needs\" | jq -r --arg j \"$job\" '.[$j].result')\n\n            # Print the job name and its result\n            echo \"$job: $result\"\n\n            # Check for failure or cancellation and exit if found\n            if [[ \"$result\" == \"failure\" || \"$result\" == \"cancelled\" ]]; then\n              echo \"The above jobs failed.\"\n              exit 1\n            fi\n          done\n\n          # If the loop completes, all jobs were successful\n          echo \"All jobs completed successfully\"\n          exit 0\n"
  },
  {
    "path": ".github/workflows/pr-test-npu.yml",
    "content": "name: PR Test (NPU)\r\n\r\non:\r\n  push:\r\n    branches: [ main ]\r\n  pull_request:\r\n    branches: [ main ]\r\n  workflow_dispatch:\r\n  workflow_call:\r\n    inputs:\r\n      ref:\r\n        description: 'Git ref (branch, tag, or SHA) to test. If not provided, uses the default branch.'\r\n        required: false\r\n        type: string\r\n        default: ''\r\n      run_all_tests:\r\n        description: \"Run all tests (for releasing or testing purpose)\"\r\n        required: false\r\n        type: boolean\r\n        default: false\r\n\r\nconcurrency:\r\n  group: pr-test-npu-${{ inputs.ref || github.ref }}\r\n  cancel-in-progress: ${{ github.event_name != 'workflow_call' }}\r\n\r\njobs:\r\n  # ==================== Check Changes ==================== #\r\n  check-changes:\r\n    runs-on: ubuntu-latest\r\n    outputs:\r\n      changes_exist: ${{ steps.filter.outputs.main_package == 'true' || steps.filter.outputs.multimodal_gen == 'true' || steps.run-mode.outputs.run_all_tests == 'true'}}\r\n      main_package: ${{ steps.filter.outputs.main_package == 'true' || steps.run-mode.outputs.run_all_tests == 'true' }}\r\n      multimodal_gen: ${{ steps.filter.outputs.multimodal_gen == 'true' || steps.run-mode.outputs.run_all_tests == 'true' }}\r\n    steps:\r\n      - name: Checkout code\r\n        uses: actions/checkout@v4\r\n        with:\r\n          ref: ${{ inputs.ref || github.ref }}\r\n\r\n      - name: Determine run mode\r\n        id: run-mode\r\n        run: |\r\n          # Run all tests when the run_all_tests input is true (e.g. for releases via workflow_call)\r\n          # Note: github.event_name is inherited from the caller, so it cannot identify workflow_call; rely on the explicit input instead\r\n          if [[ \"${{ inputs.run_all_tests }}\" == \"true\" ]]; then\r\n            echo \"run_all_tests=true\" >> $GITHUB_OUTPUT\r\n            echo \"Run mode: ALL TESTS (run_all_tests=${{ inputs.run_all_tests }})\"\r\n          else\r\n            echo \"run_all_tests=false\" >> 
$GITHUB_OUTPUT\r\n            echo \"Run mode: FILTERED (triggered by ${{ github.event_name }})\"\r\n          fi\r\n\r\n      - name: Detect file changes\r\n        id: filter\r\n        uses: dorny/paths-filter@v3\r\n        if: steps.run-mode.outputs.run_all_tests != 'true'\r\n        with:\r\n          filters: |\r\n            main_package:\r\n              - \"python/sglang/!(multimodal_gen)/**\"\r\n              - \"python/pyproject_npu.toml\"\r\n              - \"scripts/ci/npu/npu_ci_install_dependency.sh\"\r\n              - \"test/srt/ascend/**\"\r\n              - \".github/workflows/pr-test-npu.yml\"\r\n            multimodal_gen:\r\n              - \"python/sglang/multimodal_gen/**\"\r\n              - \"python/sglang/srt/**\"\r\n              - \"python/pyproject_npu.toml\"\r\n              - \"scripts/ci/npu/npu_ci_install_dependency.sh\"\r\n              - \".github/workflows/pr-test-npu.yml\"\r\n\r\n  # ==================== PR Gate ==================== #\r\n  pr-gate:\r\n    needs: check-changes\r\n    if: needs.check-changes.outputs.changes_exist == 'true'\r\n    uses: ./.github/workflows/pr-gate.yml\r\n    secrets: inherit\r\n\r\n  per-commit-1-npu-a2:\r\n    needs: [check-changes, pr-gate]\r\n    if: needs.check-changes.outputs.main_package == 'true'\r\n    runs-on: linux-aarch64-a2-1\r\n    strategy:\r\n      fail-fast: false\r\n      matrix:\r\n        part: [ 0, 1 ]\r\n    container:\r\n      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-910b-ubuntu22.04-py3.11\r\n    steps:\r\n      - name: Checkout code\r\n        uses: actions/checkout@v4\r\n        with:\r\n          ref: ${{ inputs.ref || github.ref }}\r\n\r\n      - name: Install dependencies\r\n        env:\r\n          TORCH_CACHE_URL: \"http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu\"\r\n          PYPI_CACHE_URL: \"http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple\"\r\n          GITHUB_PROXY_URL: 
\"https://gh-proxy.test.osinfra.cn/\"\r\n        run: |\r\n          # speed up by using infra cache services\r\n          CACHING_URL=\"cache-service.nginx-pypi-cache.svc.cluster.local\"\r\n          sed -Ei \"s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g\" /etc/apt/sources.list\r\n          pip config set global.index-url http://${CACHING_URL}/pypi/simple\r\n          pip config set global.trusted-host \"${CACHING_URL}\"\r\n\r\n          bash scripts/ci/npu/npu_ci_install_dependency.sh 910b\r\n          # copy required file from our daily cache\r\n          cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp\r\n          # copy download through proxy\r\n          curl -o /tmp/test.jsonl -L https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\r\n\r\n      - name: Run registered test\r\n        timeout-minutes: 240\r\n        env:\r\n          SGLANG_USE_MODELSCOPE: true\r\n          SGLANG_IS_IN_CI: true\r\n          HF_ENDPOINT: https://hf-mirror.com\r\n          TORCH_EXTENSIONS_DIR: /tmp/torch_extensions\r\n          PYTORCH_NPU_ALLOC_CONF: \"expandable_segments:True\"\r\n          STREAMS_PER_DEVICE: 32\r\n        run: |\r\n          cd test\r\n          python3 run_suite.py --hw npu --suite per-commit-1-npu-a2 --continue-on-error --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2\r\n\r\n      - name: Run test\r\n        timeout-minutes: 60\r\n        env:\r\n          SGLANG_USE_MODELSCOPE: true\r\n          SGLANG_IS_IN_CI: true\r\n          HF_ENDPOINT: https://hf-mirror.com\r\n          TORCH_EXTENSIONS_DIR: /tmp/torch_extensions\r\n          PYTORCH_NPU_ALLOC_CONF: \"expandable_segments:True\"\r\n          STREAMS_PER_DEVICE: 32\r\n        run: |\r\n          cd test/srt\r\n          python3 run_suite.py --suite per-commit-1-npu-a2 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2\r\n\r\n  
per-commit-2-npu-a2:\r\n    needs: [check-changes, pr-gate]\r\n    if: needs.check-changes.outputs.main_package == 'true'\r\n    runs-on: linux-aarch64-a2-2\r\n    strategy:\r\n      fail-fast: true\r\n      matrix:\r\n        part: [0, 1]\r\n    container:\r\n      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-910b-ubuntu22.04-py3.11\r\n    steps:\r\n      - name: Checkout code\r\n        uses: actions/checkout@v4\r\n        with:\r\n          ref: ${{ inputs.ref || github.ref }}\r\n\r\n      - name: Install dependencies\r\n        env:\r\n          TORCH_CACHE_URL: \"http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu\"\r\n          PYPI_CACHE_URL: \"http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple\"\r\n          GITHUB_PROXY_URL: \"https://gh-proxy.test.osinfra.cn/\"\r\n        run: |\r\n          # speed up by using infra cache services\r\n          CACHING_URL=\"cache-service.nginx-pypi-cache.svc.cluster.local\"\r\n          sed -Ei \"s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g\" /etc/apt/sources.list\r\n          pip config set global.index-url http://${CACHING_URL}/pypi/simple\r\n          pip config set global.trusted-host \"${CACHING_URL}\"\r\n\r\n          bash scripts/ci/npu/npu_ci_install_dependency.sh 910b\r\n          # copy required file from our daily cache\r\n          cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp\r\n          # copy download through proxy\r\n          curl -o /tmp/test.jsonl -L https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\r\n\r\n      - name: Run test\r\n        timeout-minutes: 60\r\n        env:\r\n          SGLANG_USE_MODELSCOPE: true\r\n          SGLANG_IS_IN_CI: true\r\n          HF_ENDPOINT: https://hf-mirror.com\r\n          TORCH_EXTENSIONS_DIR: /tmp/torch_extensions\r\n          PYTORCH_NPU_ALLOC_CONF: 
\"expandable_segments:True\"\r\n          STREAMS_PER_DEVICE: 32\r\n        run: |\r\n          cd test/srt\r\n          python3 run_suite.py --suite per-commit-2-npu-a2 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2\r\n\r\n  per-commit-4-npu-a3:\r\n    needs: [check-changes, pr-gate]\r\n    if: needs.check-changes.outputs.main_package == 'true'\r\n    runs-on: linux-aarch64-a3-4\r\n    container:\r\n      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11\r\n    steps:\r\n      - name: Checkout code\r\n        uses: actions/checkout@v4\r\n        with:\r\n          ref: ${{ inputs.ref || github.ref }}\r\n\r\n      - name: Install dependencies\r\n        env:\r\n          TORCH_CACHE_URL: \"http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu\"\r\n          PYPI_CACHE_URL: \"http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple\"\r\n          GITHUB_PROXY_URL: \"https://gh-proxy.test.osinfra.cn/\"\r\n        run: |\r\n          # speed up by using infra cache services\r\n          CACHING_URL=\"cache-service.nginx-pypi-cache.svc.cluster.local\"\r\n          sed -Ei \"s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g\" /etc/apt/sources.list\r\n          pip config set global.index-url http://${CACHING_URL}/pypi/simple\r\n          pip config set global.trusted-host \"${CACHING_URL}\"\r\n\r\n          bash scripts/ci/npu/npu_ci_install_dependency.sh a3\r\n          # copy required file from our daily cache\r\n          cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp\r\n          # copy download through proxy\r\n          curl -o /tmp/test.jsonl -L https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\r\n\r\n      - name: Run test\r\n        timeout-minutes: 60\r\n        env:\r\n          SGLANG_USE_MODELSCOPE: true\r\n          SGLANG_IS_IN_CI: true\r\n          
HF_ENDPOINT: https://hf-mirror.com\r\n          TORCH_EXTENSIONS_DIR: /tmp/torch_extensions\r\n          PYTORCH_NPU_ALLOC_CONF: \"expandable_segments:True\"\r\n          STREAMS_PER_DEVICE: 32\r\n        run: |\r\n          cd test/srt\r\n          python3 run_suite.py --suite per-commit-4-npu-a3 --timeout-per-file 3600\r\n\r\n  per-commit-16-npu-a3:\r\n    needs: [check-changes, pr-gate]\r\n    if: needs.check-changes.outputs.main_package == 'true'\r\n    runs-on: linux-aarch64-a3-16\r\n    container:\r\n      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11\r\n    steps:\r\n      - name: Checkout code\r\n        uses: actions/checkout@v4\r\n        with:\r\n          ref: ${{ inputs.ref || github.ref }}\r\n\r\n      - name: Install dependencies\r\n        env:\r\n          TORCH_CACHE_URL: \"http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu\"\r\n          PYPI_CACHE_URL: \"http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple\"\r\n          GITHUB_PROXY_URL: \"https://gh-proxy.test.osinfra.cn/\"\r\n        run: |\r\n          # speed up by using infra cache services\r\n          CACHING_URL=\"cache-service.nginx-pypi-cache.svc.cluster.local\"\r\n          sed -Ei \"s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g\" /etc/apt/sources.list\r\n          pip config set global.index-url http://${CACHING_URL}/pypi/simple\r\n          pip config set global.trusted-host \"${CACHING_URL}\"\r\n\r\n          bash scripts/ci/npu/npu_ci_install_dependency.sh a3\r\n          # copy required file from our daily cache\r\n          cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp\r\n          # copy download through proxy\r\n          curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\r\n\r\n      - name: Run test\r\n        
timeout-minutes: 60\r\n        env:\r\n          SGLANG_USE_MODELSCOPE: true\r\n          SGLANG_IS_IN_CI: true\r\n          HF_ENDPOINT: https://hf-mirror.com\r\n          TORCH_EXTENSIONS_DIR: /tmp/torch_extensions\r\n          PYTORCH_NPU_ALLOC_CONF: \"expandable_segments:True\"\r\n          STREAMS_PER_DEVICE: 32\r\n        run: |\r\n          cd test/srt\r\n          python3 run_suite.py --suite per-commit-16-npu-a3 --timeout-per-file 3600\r\n\r\n  multimodal-gen-test-1-npu-a3:\r\n    needs: [check-changes, pr-gate]\r\n    if: needs.check-changes.outputs.multimodal_gen == 'true'\r\n    runs-on: linux-aarch64-a3-2\r\n    container:\r\n      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-a3-ubuntu22.04-py3.11\r\n    steps:\r\n      - name: Checkout code\r\n        uses: actions/checkout@v4\r\n\r\n      - name: Install dependencies\r\n        env:\r\n          TORCH_CACHE_URL: \"http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu\"\r\n          PYPI_CACHE_URL: \"http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple\"\r\n          GITHUB_PROXY_URL: \"https://gh-proxy.test.osinfra.cn/\"\r\n        run: |\r\n          # speed up by using infra cache services\r\n          CACHING_URL=\"cache-service.nginx-pypi-cache.svc.cluster.local\"\r\n          sed -Ei \"s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g\" /etc/apt/sources.list\r\n          pip config set global.index-url http://${CACHING_URL}/pypi/simple\r\n          pip config set global.trusted-host \"${CACHING_URL}\"\r\n\r\n          bash scripts/ci/npu/npu_ci_install_dependency.sh a3 diffusion\r\n          # copy required file from our daily cache\r\n          cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp\r\n          # copy download through proxy\r\n          curl -o /tmp/test.jsonl -L 
https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\r\n\r\n      - name: Run test\r\n        timeout-minutes: 60\r\n        env:\r\n          SGLANG_USE_MODELSCOPE: true\r\n          SGLANG_IS_IN_CI: true\r\n          HF_ENDPOINT: https://hf-mirror.com\r\n          TORCH_EXTENSIONS_DIR: /tmp/torch_extensions\r\n          PYTORCH_NPU_ALLOC_CONF: \"expandable_segments:True\"\r\n          STREAMS_PER_DEVICE: 32\r\n        run: |\r\n          export PATH=\"/usr/local/Ascend/8.3.RC1/compiler/bishengir/bin:${PATH}\"\r\n          cd python\r\n          python3 sglang/multimodal_gen/test/run_suite.py --suite 1-npu\r\n\r\n  multimodal-gen-test-2-npu-a3:\r\n    needs: [check-changes, pr-gate]\r\n    if: needs.check-changes.outputs.multimodal_gen == 'true'\r\n    runs-on: linux-aarch64-a3-16\r\n    container:\r\n      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-a3-ubuntu22.04-py3.11\r\n    steps:\r\n      - name: Checkout code\r\n        uses: actions/checkout@v4\r\n\r\n      - name: Install dependencies\r\n        env:\r\n          TORCH_CACHE_URL: \"http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu\"\r\n          PYPI_CACHE_URL: \"http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple\"\r\n          GITHUB_PROXY_URL: \"https://gh-proxy.test.osinfra.cn/\"\r\n        run: |\r\n          # speed up by using infra cache services\r\n          CACHING_URL=\"cache-service.nginx-pypi-cache.svc.cluster.local\"\r\n          sed -Ei \"s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g\" /etc/apt/sources.list\r\n          pip config set global.index-url http://${CACHING_URL}/pypi/simple\r\n          pip config set global.trusted-host \"${CACHING_URL}\"\r\n\r\n          bash scripts/ci/npu/npu_ci_install_dependency.sh a3 diffusion\r\n          # copy required file from our daily cache\r\n          cp 
~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp\r\n          # copy download through proxy\r\n          curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\r\n\r\n      - name: Run test\r\n        timeout-minutes: 60\r\n        env:\r\n          SGLANG_USE_MODELSCOPE: true\r\n          SGLANG_IS_IN_CI: true\r\n          HF_ENDPOINT: https://hf-mirror.com\r\n          TORCH_EXTENSIONS_DIR: /tmp/torch_extensions\r\n          PYTORCH_NPU_ALLOC_CONF: \"expandable_segments:True\"\r\n          STREAMS_PER_DEVICE: 32\r\n        run: |\r\n          export PATH=\"/usr/local/Ascend/8.3.RC1/compiler/bishengir/bin:${PATH}\"\r\n          cd python\r\n          python3 sglang/multimodal_gen/test/run_suite.py --suite 2-npu\r\n\r\n  multimodal-gen-test-8-npu-a3:\r\n    needs: [check-changes, pr-gate]\r\n    if: needs.check-changes.outputs.multimodal_gen == 'true'\r\n    runs-on: linux-aarch64-a3-16\r\n    container:\r\n      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11\r\n    steps:\r\n      - name: Checkout code\r\n        uses: actions/checkout@v4\r\n\r\n      - name: Install dependencies\r\n        run: |\r\n          # speed up by using infra cache services\r\n          CACHING_URL=\"cache-service.nginx-pypi-cache.svc.cluster.local\"\r\n          sed -Ei \"s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g\" /etc/apt/sources.list\r\n          pip config set global.index-url http://${CACHING_URL}/pypi/simple\r\n          pip config set global.extra-index-url \"https://pypi.tuna.tsinghua.edu.cn/simple\"\r\n          pip config set global.trusted-host \"${CACHING_URL} pypi.tuna.tsinghua.edu.cn\"\r\n\r\n          bash scripts/ci/npu/npu_ci_install_dependency.sh a3 diffusion\r\n          # copy required file from our daily cache\r\n          cp 
~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp\r\n          # copy download through proxy\r\n          curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\r\n\r\n      - name: Run test\r\n        timeout-minutes: 60\r\n        env:\r\n          SGLANG_USE_MODELSCOPE: true\r\n          SGLANG_IS_IN_CI: true\r\n          HF_ENDPOINT: https://hf-mirror.com\r\n          TORCH_EXTENSIONS_DIR: /tmp/torch_extensions\r\n          PYTORCH_NPU_ALLOC_CONF: \"expandable_segments:True\"\r\n          STREAMS_PER_DEVICE: 32\r\n        run: |\r\n          cd python\r\n          python3 sglang/multimodal_gen/test/run_suite.py --suite 8-npu\r\n"
  },
  {
    "path": ".github/workflows/pr-test-rust.yml",
    "content": "name: PR Test (SMG)\n\non:\n  push:\n    branches: [ main ]\n    paths:\n      - \"sgl-model-gateway/**\"\n  pull_request:\n    branches: [ main ]\n    types: [opened, synchronize, reopened, labeled]\n    paths:\n      - \"sgl-model-gateway/**\"\n  workflow_dispatch:\n\nconcurrency:\n  group: gateway-tests-${{ github.ref }}\n  cancel-in-progress: true\n\nenv:\n  RUSTC_WRAPPER: sccache\n  SCCACHE_GHA_ENABLED: \"true\"\n\njobs:\n  build-wheel:\n    if: |\n      github.event_name != 'pull_request' ||\n      (github.event.action != 'labeled' && contains(github.event.pull_request.labels.*.name, 'run-ci')) ||\n      (github.event.action == 'labeled' && github.event.label.name == 'run-ci')\n    runs-on: 4-gpu-a10\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install rust dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_gateway_dependencies.sh\n\n      - name: Configure sccache\n        uses: mozilla-actions/sccache-action@v0.0.9\n        with:\n          version: \"v0.12.0\"\n          disable_annotations: true\n\n      - name: Rust cache\n        uses: Swatinem/rust-cache@v2\n        with:\n          workspaces: sgl-model-gateway\n          shared-key: \"rust-cache\"\n          cache-all-crates: true\n          cache-on-failure: true\n          save-if: true\n\n      - name: Build python binding\n        run: |\n          source \"$HOME/.cargo/env\"\n          export RUSTC_WRAPPER=sccache\n          cd sgl-model-gateway/bindings/python\n          python3 -m pip install --upgrade pip maturin\n          maturin build --profile ci --features vendored-openssl --out dist\n\n      - name: List built wheel\n        run: ls -lh sgl-model-gateway/bindings/python/dist/\n\n      - name: Upload wheel artifact\n        uses: actions/upload-artifact@v4\n        with:\n          name: smg-wheel\n          path: sgl-model-gateway/bindings/python/dist/*.whl\n          retention-days: 1\n\n      - name: 
Test wheel install\n        run: |\n          pip install sgl-model-gateway/bindings/python/dist/*.whl\n          python3 -c \"import sglang_router; print('Python package: OK')\"\n          python3 -c \"from sglang_router.sglang_router_rs import Router; print('Rust extension: OK')\"\n          python3 -m sglang_router.launch_router --help > /dev/null && echo \"Entry point: OK\"\n\n  python-unit-tests:\n    needs: build-wheel\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          path: sglang-repo\n\n      - name: Move sgl-model-gateway folder to root\n        run: |\n          mv sglang-repo/sgl-model-gateway/* .\n          rm -rf sglang-repo\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.13\"\n\n      - name: Download wheel artifact\n        uses: actions/download-artifact@v4\n        with:\n          name: smg-wheel\n          path: dist/\n\n      - name: Install wheel\n        run: pip install dist/*.whl\n\n      - name: Run Python unit tests\n        run: |\n          cd bindings/python\n          python3 -m pip install pytest pytest-cov pytest-xdist\n          pytest -q tests --cov=sglang_router --cov-config=.coveragerc --cov-report=term-missing --cov-fail-under=80\n\n  unit-tests:\n    if: |\n      github.event_name != 'pull_request' ||\n      (github.event.action != 'labeled' && contains(github.event.pull_request.labels.*.name, 'run-ci')) ||\n      (github.event.action == 'labeled' && github.event.label.name == 'run-ci')\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_gateway_dependencies.sh\n\n      - name: Configure sccache\n        uses: mozilla-actions/sccache-action@v0.0.9\n        with:\n          version: \"v0.12.0\"\n          disable_annotations: true\n\n      - name: Rust cache\n       
 uses: Swatinem/rust-cache@v2\n        with:\n          workspaces: sgl-model-gateway\n          shared-key: \"rust-cache\"\n          cache-all-crates: true\n          cache-on-failure: true\n          save-if: true\n\n      - name: Run lint\n        run: |\n          source \"$HOME/.cargo/env\"\n          cd sgl-model-gateway/\n          rustup component add clippy\n          cargo clippy --all-targets --all-features -- -D warnings\n\n      - name: Run fmt\n        run: |\n          source \"$HOME/.cargo/env\"\n          cd sgl-model-gateway/\n          rustup component add --toolchain nightly-x86_64-unknown-linux-gnu rustfmt\n          rustup toolchain install nightly --profile minimal\n          cargo +nightly fmt -- --check\n\n      - name: Generate vision golden fixtures\n        run: |\n          pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu\n\n          pip install transformers pillow numpy scipy\n          pip install transformers pillow numpy\n          cd sgl-model-gateway/\n          python scripts/generate_vision_golden.py\n\n      - name: Run Rust tests\n        timeout-minutes: 20\n        run: |\n          source \"$HOME/.cargo/env\"\n          cd sgl-model-gateway/\n          cargo test\n\n      - name: Show sccache stats\n        if: always()\n        run: sccache --show-stats\n\n  gateway-e2e:\n    name: ${{ matrix.name }}\n    needs: build-wheel\n    if: |\n      github.event_name != 'pull_request' ||\n      (github.event.action != 'labeled' && contains(github.event.pull_request.labels.*.name, 'run-ci')) ||\n      (github.event.action == 'labeled' && github.event.label.name == 'run-ci')\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - name: benchmarks\n            timeout: 32\n            test_dirs: \"e2e_test/benchmarks\"\n            extra_deps: \"genai-bench==0.0.3\"\n            env_vars: \"\"\n            reruns: \"\"\n            upload_benchmarks: true\n            
parallel_opts: \"\"  # No parallel for benchmarks (performance measurement)\n          - name: responses\n            timeout: 45\n            test_dirs: \"e2e_test/responses\"\n            extra_deps: \"\"\n            env_vars: \"SHOW_WORKER_LOGS=0 SHOW_ROUTER_LOGS=1\"\n            reruns: \"--reruns 2 --reruns-delay 5\"\n            setup_oracle: true\n            setup_brave: true\n            parallel_opts: \"\"  # Cloud backend tests not compatible with parallel execution\n          - name: e2e\n            timeout: 45\n            test_dirs: \"e2e_test/router e2e_test/embeddings\"\n            extra_deps: \"pytest-parallel py\"  # py is required for pytest-parallel with newer pytest\n            env_vars: \"SHOW_WORKER_LOGS=0 SHOW_ROUTER_LOGS=1\"\n            reruns: \"--reruns 2 --reruns-delay 5\"\n            parallel_opts: \"--workers 1 --tests-per-worker 4\"  # Thread-based parallelism\n          - name: chat-completions\n            timeout: 45\n            test_dirs: \"e2e_test/chat_completions\"\n            extra_deps: \"\"\n            env_vars: \"SHOW_WORKER_LOGS=0 SHOW_ROUTER_LOGS=1\"\n            reruns: \"--reruns 2 --reruns-delay 5\"\n            parallel_opts: \"\"\n    runs-on: 4-gpu-a10\n    timeout-minutes: ${{ matrix.timeout }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install SGLang dependencies\n        run: |\n          sudo --preserve-env=PATH bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Setup Oracle Instant Client\n        if: matrix.setup_oracle\n        run: |\n          sudo apt-get install -y unzip\n          INSTANT_CLIENT_DIR=\"/home/ubuntu/instant-client\"\n          INSTANT_CLIENT_ZIP=\"instantclient-basic-linux.x64-23.9.0.25.07.zip\"\n\n          if [ ! 
-d \"$INSTANT_CLIENT_DIR/instantclient_23_9\" ]; then\n            echo \"Downloading Oracle Instant Client...\"\n            mkdir -p \"$INSTANT_CLIENT_DIR\"\n            cd \"$INSTANT_CLIENT_DIR\"\n            wget https://download.oracle.com/otn_software/linux/instantclient/2390000/$INSTANT_CLIENT_ZIP\n            unzip $INSTANT_CLIENT_ZIP\n            rm $INSTANT_CLIENT_ZIP\n          else\n            echo \"Oracle Instant Client already exists, skipping download\"\n          fi\n\n          echo \"LD_LIBRARY_PATH=/home/ubuntu/instant-client/instantclient_23_9:\\$LD_LIBRARY_PATH\" >> $GITHUB_ENV\n\n      - name: Start Oracle Database\n        if: matrix.setup_oracle\n        run: |\n          docker run -d -p 1521:1521 -e ORACLE_PASSWORD=oracle --name oracle-db gvenzl/oracle-xe:21-slim\n          echo \"Starting Oracle DB...\"\n\n          # Export Oracle connection environment variables\n          echo \"ATP_USER=system\" >> $GITHUB_ENV\n          echo \"ATP_PASSWORD=oracle\" >> $GITHUB_ENV\n          echo \"ATP_DSN=localhost:1521/XEPDB1\" >> $GITHUB_ENV\n\n      - name: Start Brave MCP Server\n        if: matrix.setup_brave\n        run: |\n          docker run -d --rm \\\n            -p 8001:8080 \\\n            -e BRAVE_API_KEY \\\n            --name brave-search-server \\\n            shoofio/brave-search-mcp-sse:1.0.10\n          echo \"Starting Brave MCP Server...\"\n          sleep 2\n          curl -f --max-time 1 http://localhost:8001/sse > /dev/null 2>&1 && echo \"Brave MCP Server is healthy!\" || echo \"Brave MCP Server responded\"\n\n      - name: Download wheel artifact\n        uses: actions/download-artifact@v4\n        with:\n          name: smg-wheel\n          path: wheel/\n\n      - name: Install wheel\n        run: |\n          pip uninstall -y sglang-router || true\n          pip install wheel/*.whl\n\n      - name: Install e2e test dependencies\n        run: |\n          python3 -m pip install pytest pytest-rerunfailures httpx openai 
grpcio grpcio-health-checking numpy\n          if [ -n \"${{ matrix.extra_deps }}\" ]; then\n            python3 -m pip --no-cache-dir install --upgrade ${{ matrix.extra_deps }}\n          fi\n\n      - name: Run E2E tests\n        run: |\n          bash scripts/killall_sglang.sh all\n          cd sgl-model-gateway\n          ${{ matrix.env_vars }} ROUTER_LOCAL_MODEL_PATH=\"/home/ubuntu/models\" pytest ${{ matrix.reruns }} ${{ matrix.parallel_opts }} ${{ matrix.test_dirs }} -s -vv -o log_cli=true --log-cli-level=INFO\n\n      - name: Upload benchmark results\n        if: matrix.upload_benchmarks && success()\n        uses: actions/upload-artifact@v4\n        with:\n          name: genai-bench-results-all-policies\n          path: sgl-model-gateway/benchmark_**/\n\n      - name: Cleanup Brave MCP Server\n        if: always() && matrix.setup_brave\n        run: |\n          docker stop brave-search-server || true\n          docker rm brave-search-server || true\n\n      - name: Cleanup Oracle Database\n        if: always() && matrix.setup_oracle\n        run: |\n          docker stop oracle-db || true\n          docker rm oracle-db || true\n\n  docker-build-test:\n    if: |\n      github.event_name != 'pull_request' ||\n      (github.event.action != 'labeled' && contains(github.event.pull_request.labels.*.name, 'run-ci')) ||\n      (github.event.action == 'labeled' && github.event.label.name == 'run-ci')\n    runs-on: ubuntu-24.04\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Build Docker image (no push)\n        uses: docker/build-push-action@v5\n        with:\n          context: .\n          file: docker/gateway.Dockerfile\n          push: false\n          tags: sgl-model-gateway:test\n          cache-from: type=gha\n          cache-to: type=gha,mode=max\n\n  finish:\n    needs: [build-wheel, python-unit-tests, unit-tests, 
gateway-e2e, docker-build-test]\n    runs-on: ubuntu-latest\n    steps:\n      - name: Finish\n        run: echo \"This is an empty step to ensure that all jobs are completed.\"\n\n  summarize-benchmarks:\n    needs: gateway-e2e\n    runs-on: ubuntu-latest\n    if: success()\n\n    steps:\n    - name: Checkout code\n      uses: actions/checkout@v4\n\n    - name: Download benchmark results\n      uses: actions/download-artifact@v4\n      with:\n        name: genai-bench-results-all-policies\n\n    - name: Create benchmark summary\n      run: python3 sgl-model-gateway/e2e_test/benchmarks/summarize.py .\n"
  },
  {
    "path": ".github/workflows/pr-test-xeon.yml",
    "content": "name: PR Test (Xeon)\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n  workflow_dispatch:\n  workflow_call:\n    inputs:\n      ref:\n        description: 'Git ref (branch, tag, or SHA) to test. If not provided, uses the default branch.'\n        required: false\n        type: string\n        default: ''\n      run_all_tests:\n        description: \"Run all tests (for releasing or testing purpose)\"\n        required: false\n        type: boolean\n        default: false\n\nconcurrency:\n  group: pr-test-xeon-${{ inputs.ref || github.ref }}\n  cancel-in-progress: false\n\njobs:\n  # ==================== Check Changes ==================== #\n  check-changes:\n    runs-on: ubuntu-latest\n    outputs:\n      main_package: ${{ steps.filter.outputs.main_package || steps.run-mode.outputs.run_all_tests}}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Determine run mode\n        id: run-mode\n        run: |\n          # Run all tests for workflow_call (when ref input is provided)\n          # Note: github.event_name is inherited from caller, so we detect workflow_call by checking inputs.ref\n          if [[ \"${{ inputs.run_all_tests }}\" == \"true\" ]]; then\n            echo \"run_all_tests=true\" >> $GITHUB_OUTPUT\n            echo \"Run mode: ALL TESTS (run_all_tests=${{ inputs.run_all_tests }})\"\n          else\n            echo \"run_all_tests=false\" >> $GITHUB_OUTPUT\n            echo \"Run mode: FILTERED (triggered by ${{ github.event_name }})\"\n          fi\n\n      - name: Detect file changes\n        id: filter\n        uses: dorny/paths-filter@v3\n        if: steps.run-mode.outputs.run_all_tests != 'true'\n        with:\n          filters: |\n            main_package:\n              - \"python/sglang/!(multimodal_gen)/**\"\n              - \"python/pyproject_cpu.toml\"\n              - \"test/**\"\n     
         - \"sgl-kernel/**\"\n              - \".github/workflows/pr-test-xeon.yml\"\n              - \"docker/xeon.Dockerfile\"\n\n  # ==================== PR Gate ==================== #\n  pr-gate:\n    needs: check-changes\n    if: needs.check-changes.outputs.main_package == 'true'\n    uses: ./.github/workflows/pr-gate.yml\n    secrets: inherit\n\n  build-test:\n    needs: [check-changes, pr-gate]\n    if: needs.check-changes.outputs.main_package == 'true'\n    runs-on: xeon-gnr\n    env:\n      HF_HOME: /home/sdp/.cache/huggingface\n    strategy:\n      matrix:\n        build_type: ['all']\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Build and Push\n        run: |\n          version=$(cat python/sglang/version.py | cut -d'\"' -f2)\n          tag=v${version}-xeon\n          PR_REPO=${{ github.event.pull_request.head.repo.clone_url }}\n          PR_HEAD_REF=${{ github.head_ref }}\n\n          docker build \\\n            ${PR_REPO:+--build-arg SGLANG_REPO=$PR_REPO} \\\n            ${PR_HEAD_REF:+--build-arg VER_SGLANG=$PR_HEAD_REF} \\\n            . 
-f docker/xeon.Dockerfile -t sglang_xeon --no-cache\n\n      - name: Run container\n        run: |\n          docker run -dt \\\n            -v ${{ github.workspace }}:/sglang-checkout/ --ipc=host \\\n            -v ${HF_HOME}:/root/.cache/huggingface \\\n            --name ci_sglang_xeon \\\n            sglang_xeon\n\n      - name: Check AMX support\n        id: check_amx\n        timeout-minutes: 5\n        run: |\n          docker exec -w /sglang-checkout/ ci_sglang_xeon \\\n            bash -c \"source /opt/.venv/bin/activate && python3 -c 'import torch; import sgl_kernel; assert torch._C._cpu._is_amx_tile_supported(); assert hasattr(torch.ops.sgl_kernel, \\\"convert_weight_packed\\\"); '\"\n\n      - name: Run unit tests\n        timeout-minutes: 36\n        run: |\n          docker exec -w /sglang-checkout/ ci_sglang_xeon \\\n            bash -c \"source /opt/.venv/bin/activate && cd ./test/srt && python3 run_suite.py --suite per-commit-cpu --timeout-per-file 1500\"\n\n      - name: Change permission\n        timeout-minutes: 2\n        run: |\n          docker exec -u root ci_sglang_xeon bash -c \"\n            rm -rf /tmp/ci-home &&\n            chown -R $(id -u):$(id -g) /sglang-checkout/ 2>/dev/null || true\n          \"\n\n      - name: Cleanup container\n        if: always()\n        run: |\n          docker rm -f ci_sglang_xeon || true\n"
  },
  {
    "path": ".github/workflows/pr-test-xpu.yml",
    "content": "name: PR Test (XPU)\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n  workflow_dispatch:\n  workflow_call:\n    inputs:\n      ref:\n        description: 'Git ref (branch, tag, or SHA) to test. If not provided, uses the default branch.'\n        required: false\n        type: string\n        default: ''\n      run_all_tests:\n        description: \"Run all tests (for releasing or testing purpose)\"\n        required: false\n        type: boolean\n        default: false\n\nconcurrency:\n  group: pr-test-xpu-${{ inputs.ref || github.ref }}\n  cancel-in-progress: ${{ github.event_name != 'workflow_call' }}\n\njobs:\n  # ==================== Check Changes ==================== #\n  check-changes:\n    runs-on: ubuntu-latest\n    outputs:\n      main_package: ${{ steps.filter.outputs.main_package || steps.run-mode.outputs.run_all_tests }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Determine run mode\n        id: run-mode\n        run: |\n          # Run all tests when the run_all_tests input is true (e.g. 
when invoked via workflow_call)\n          # Note: github.event_name is inherited from the caller, so workflow_call cannot be detected from it; inputs.run_all_tests is checked instead\n          if [[ \"${{ inputs.run_all_tests }}\" == \"true\" ]]; then\n            echo \"run_all_tests=true\" >> $GITHUB_OUTPUT\n            echo \"Run mode: ALL TESTS (run_all_tests=${{ inputs.run_all_tests }})\"\n          else\n            echo \"run_all_tests=false\" >> $GITHUB_OUTPUT\n            echo \"Run mode: FILTERED (triggered by ${{ github.event_name }})\"\n          fi\n\n      - name: Detect file changes\n        id: filter\n        uses: dorny/paths-filter@v3\n        if: steps.run-mode.outputs.run_all_tests != 'true'\n        with:\n          filters: |\n            main_package:\n              - \"python/sglang/!(multimodal_gen)/**\"\n              - \"python/pyproject_xpu.toml\"\n              - \"test/**\"\n              - \"sgl-kernel/**\"\n              - \".github/workflows/pr-test-xpu.yml\"\n              - \"docker/xpu.Dockerfile\"\n\n  # ==================== PR Gate ==================== #\n  pr-gate:\n    needs: check-changes\n    if: needs.check-changes.outputs.main_package == 'true'\n    uses: ./.github/workflows/pr-gate.yml\n    secrets: inherit\n\n  build-and-test:\n    needs: [check-changes, pr-gate]\n    if: needs.check-changes.outputs.main_package == 'true'\n    runs-on: intel-bmg\n    env:\n      HF_HOME: /home/sdp/.cache/huggingface\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0\n          ref: ${{ inputs.ref || github.ref }}\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Build Docker image\n        run: |\n          PR_REPO=${{ github.event.pull_request.head.repo.clone_url }}\n          PR_HEAD_REF=${{ github.head_ref }}\n          docker build \\\n            ${PR_REPO:+--build-arg SG_LANG_REPO=$PR_REPO} \\\n            ${PR_HEAD_REF:+--build-arg SG_LANG_BRANCH=$PR_HEAD_REF} \\\n   
         --no-cache --progress=plain -f docker/xpu.Dockerfile -t xpu_sglang_main:bmg .\n\n      - name: Run container\n        id: start_container\n        run: |\n          container_id=$(docker run -dt \\\n            --group-add 992 \\\n            --group-add $(getent group video | cut -d: -f3) \\\n            -v ${HF_HOME}:/root/.cache/huggingface \\\n            --device /dev/dri \\\n            -e HF_TOKEN=\"$(cat ~/huggingface_token.txt)\" \\\n            xpu_sglang_main:bmg)\n          echo \"Started container: $container_id\"\n          echo \"container_id=$container_id\" >> \"$GITHUB_OUTPUT\"\n\n      - name: Install Dependency\n        timeout-minutes: 20\n        run: |\n          cid=\"${{ steps.start_container.outputs.container_id }}\"\n          docker exec \"$cid\" /home/sdp/miniforge3/envs/py3.10/bin/python3 -m pip install --upgrade pip\n          docker exec \"$cid\" /home/sdp/miniforge3/envs/py3.10/bin/python3 -m pip install pytest expecttest ray huggingface_hub\n          docker exec \"$cid\" /home/sdp/miniforge3/envs/py3.10/bin/python3 -m pip uninstall -y flashinfer-python\n          docker exec \"$cid\" /bin/bash -c '/home/sdp/miniforge3/envs/py3.10/bin/hf auth login --token ${HF_TOKEN} '\n\n      - name: Run E2E Bfloat16 tests\n        timeout-minutes: 20\n        run: |\n          cid=\"${{ steps.start_container.outputs.container_id }}\"\n          docker exec \"$cid\" bash -c \"source /home/sdp/miniforge3/bin/activate && conda activate py3.10 && cd /home/sdp/sglang/test/srt && python3 run_suite.py --suite per-commit-xpu\"\n\n      - name: Cleanup container\n        if: always()\n        run: |\n          cid=\"${{ steps.start_container.outputs.container_id }}\"\n          docker rm -f \"$cid\" || true\n\n  finish:\n    if: always()\n    needs: [build-and-test, pr-gate]\n    runs-on: ubuntu-latest\n    steps:\n      - name: Check job status\n        run: |\n          result=\"${{ needs.build-and-test.result }}\"\n          if [ \"$result\" != \"success\" ] && [ \"$result\" != 
\"skipped\" ]; then\n            echo \"Job failed with result: $result\"\n            exit 1\n          fi\n          echo \"All jobs completed successfully (result: $result)\"\n          exit 0\n"
  },
  {
    "path": ".github/workflows/pr-test.yml",
    "content": "name: PR Test\n# Dynamic run-name for /rerun-stage commands to enable URL lookup\n# Format: \"[stage-name] sha\" for fork PRs, \"[stage-name]\" for non-fork, default for normal runs\nrun-name: ${{ inputs.target_stage && (inputs.pr_head_sha && format('[{0}] {1}', inputs.target_stage, inputs.pr_head_sha) || format('[{0}]', inputs.target_stage)) || '' }}\n\non:\n  schedule:\n    - cron: '0 */12 * * *'  # Run every 12 hours\n  pull_request:\n    branches: [main]\n  workflow_dispatch:\n    inputs:\n      target_stage:\n        description: \"Specific stage to run (optional, for quick testing)\"\n        required: false\n        type: string\n        default: \"\"\n      force_continue_on_error:\n        description: \"Force continue-on-error (test scheduled CI behavior)\"\n        required: false\n        type: boolean\n        default: false\n      pr_head_sha:\n        description: \"PR head SHA to checkout (for /rerun-stage on fork PRs)\"\n        required: false\n        type: string\n        default: \"\"\n      test_parallel_dispatch:\n        description: \"Test parallel dispatch behavior (simulates scheduled run)\"\n        required: false\n        type: boolean\n        default: false\n  workflow_call:\n    inputs:\n      git_ref:\n        description: 'Git ref (branch, tag, or SHA) to test. 
If not provided, uses the default branch.'\n        required: false\n        type: string\n        default: ''\n      run_all_tests:\n        description: \"Run all tests (for releasing or testing purpose)\"\n        required: false\n        type: boolean\n        default: false\n\nconcurrency:\n  # Concurrency group structure: pr-test-{event}-{branch}-{pr_sha}-{stage}\n  # - event_name prevents scheduled runs from colliding with fork PRs whose branch is named 'main'\n  #   (without it, both resolve the branch segment to 'main' and block each other)\n  # - github.head_ref (pull_request) or github.ref_name (workflow_dispatch) normalizes to branch name\n  # - pr_head_sha isolates /rerun-stage from main branch runs\n  # - target_stage allows parallel stage dispatches to run independently\n  group: pr-test-${{ github.event_name }}-${{ github.head_ref || github.ref_name || 'default' }}-${{ inputs.pr_head_sha || 'current' }}-${{ inputs.target_stage || inputs.git_ref || 'all' }}\n  cancel-in-progress: ${{ github.event_name != 'workflow_call' }}\n\nenv:\n  SGLANG_IS_IN_CI: true\n  SGLANG_CUDA_COREDUMP: \"1\"\n  SGLANG_JIT_DEEPGEMM_FAST_WARMUP: true\n\npermissions:\n  actions: write\n  contents: read\n  pull-requests: read\n\njobs:\n  # =============================================== check changes ====================================================\n  check-changes:\n    runs-on: ubuntu-latest\n    outputs:\n      # Use API-based detection for target_stage mode (filter-api), otherwise use dorny/paths-filter (filter)\n      main_package: ${{ steps.filter-api.outputs.main_package || steps.filter.outputs.main_package || steps.run-mode.outputs.run_all_tests }}\n      # sgl_kernel is forced to false when target_stage is set, since sgl-kernel-build-wheels won't run\n      # This prevents CUSTOM_BUILD_SGL_KERNEL=true when the wheel artifacts aren't available\n      # Note: If PR has kernel changes AND target_stage is set, the validate-target-stage step will fail\n      
sgl_kernel: ${{ !inputs.target_stage && (steps.filter-api.outputs.sgl_kernel || steps.filter.outputs.sgl_kernel) }}\n      # Raw sgl_kernel value before target_stage override (used for validation)\n      sgl_kernel_raw: ${{ steps.filter-api.outputs.sgl_kernel || steps.filter.outputs.sgl_kernel }}\n      jit_kernel: ${{ steps.filter-api.outputs.jit_kernel || steps.filter.outputs.jit_kernel || steps.run-mode.outputs.run_all_tests }}\n      multimodal_gen: ${{ steps.filter-api.outputs.multimodal_gen || steps.filter.outputs.multimodal_gen || steps.run-mode.outputs.run_all_tests }}\n      max_parallel: ${{ steps.set-parallel.outputs.max_parallel }}\n      b200_runner: ${{ steps.set-runner.outputs.b200_runner }}\n      enable_retry: ${{ steps.set-retry.outputs.enable_retry }}\n      continue_on_error: ${{ steps.set-continue-on-error.outputs.continue_on_error }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Show test partition assignments\n        continue-on-error: true\n        run: python3 test/show_partitions.py\n\n      - name: Determine run mode\n        id: run-mode\n        run: |\n          # Run all tests for scheduled runs and workflow_call (when ref input is provided)\n          # Note: github.event_name is inherited from caller, so we detect workflow_call by checking inputs.git_ref\n          if [[ \"${{ github.event_name }}\" == \"schedule\" || \"${{ inputs.run_all_tests }}\" == \"true\" ]]; then\n            echo \"run_all_tests=true\" >> $GITHUB_OUTPUT\n            echo \"Run mode: ALL TESTS (schedule=${{ github.event_name == 'schedule' }}, run_all_tests=${{ inputs.run_all_tests }})\"\n          else\n            echo \"run_all_tests=false\" >> $GITHUB_OUTPUT\n            echo \"Run mode: FILTERED (triggered by ${{ github.event_name }})\"\n          fi\n\n      - name: Detect file changes\n        id: filter\n        uses: 
dorny/paths-filter@v3\n        # Only use paths-filter for pull_request events (where it works correctly)\n        # For workflow_dispatch with target_stage, we use GitHub API in the next step\n        if: steps.run-mode.outputs.run_all_tests != 'true' && !inputs.target_stage\n        with:\n          filters: |\n            main_package:\n              - \".github/workflows/pr-test.yml\"\n              - \"python/pyproject.toml\"\n              - \"python/sglang/!(multimodal_gen)/**\"\n              - \"scripts/ci/cuda/*\"\n              - \"scripts/ci/utils/*\"\n              - \"test/**\"\n            multimodal_gen:\n              - \".github/workflows/pr-test.yml\"\n              - \"python/pyproject.toml\"\n              - \"python/sglang/multimodal_gen/**\"\n              - \"python/sglang/jit_kernel/**\"\n              - \"python/sglang/cli/**\"\n            jit_kernel:\n              - \".github/workflows/pr-test.yml\"\n              - \"python/pyproject.toml\"\n              - \"python/sglang/jit_kernel/**\"\n            sgl_kernel:\n              - \"sgl-kernel/**\"\n\n      # For /rerun-stage (workflow_dispatch with target_stage), dorny/paths-filter doesn't work\n      # correctly because it falls back to \"last commit\" detection which breaks for merge commits.\n      # Instead, we use the GitHub API to compare the PR commit against main.\n      - name: Detect file changes via API (for target_stage)\n        id: filter-api\n        if: inputs.target_stage && inputs.pr_head_sha\n        env:\n          GH_TOKEN: ${{ github.token }}\n        run: |\n          echo \"Detecting file changes via GitHub API for target_stage mode...\"\n          echo \"PR head SHA: ${{ inputs.pr_head_sha }}\"\n\n          # Get the list of changed files by comparing PR commit against main\n          # This correctly handles merge commits by looking at the actual PR diff\n          CHANGED_FILES=$(gh api \"repos/${{ github.repository }}/compare/main...${{ inputs.pr_head_sha 
}}\" \\\n            --jq '[.files[].filename] | .[]' 2>/dev/null || echo \"\")\n\n          if [ -z \"$CHANGED_FILES\" ]; then\n            echo \"Warning: Could not fetch changed files from API, assuming no changes\"\n            echo \"sgl_kernel=false\" >> $GITHUB_OUTPUT\n            echo \"main_package=false\" >> $GITHUB_OUTPUT\n            echo \"jit_kernel=false\" >> $GITHUB_OUTPUT\n            echo \"multimodal_gen=false\" >> $GITHUB_OUTPUT\n            exit 0\n          fi\n\n          echo \"Changed files:\"\n          echo \"$CHANGED_FILES\" | head -20\n          echo \"...\"\n\n          # Check for sgl-kernel changes\n          if echo \"$CHANGED_FILES\" | grep -q \"^sgl-kernel/\"; then\n            echo \"sgl_kernel=true\" >> $GITHUB_OUTPUT\n            echo \"Detected sgl-kernel changes\"\n          else\n            echo \"sgl_kernel=false\" >> $GITHUB_OUTPUT\n          fi\n\n          # Check for main_package changes (excluding multimodal_gen)\n          # Note: Need to filter out multimodal_gen before checking, not pipe grep -q output\n          MAIN_PKG_FILES=$(echo \"$CHANGED_FILES\" | grep -E \"^(python/sglang/|python/pyproject\\.toml|scripts/ci/cuda/|scripts/ci/utils/|test/|\\.github/workflows/pr-test\\.yml)\" | grep -v \"^python/sglang/multimodal_gen/\" || true)\n          if [ -n \"$MAIN_PKG_FILES\" ]; then\n            echo \"main_package=true\" >> $GITHUB_OUTPUT\n            echo \"Detected main_package changes\"\n          else\n            echo \"main_package=false\" >> $GITHUB_OUTPUT\n          fi\n\n          # Check for jit_kernel changes\n          if echo \"$CHANGED_FILES\" | grep -qE \"^(python/sglang/jit_kernel/|python/pyproject\\.toml|\\.github/workflows/pr-test\\.yml)\"; then\n            echo \"jit_kernel=true\" >> $GITHUB_OUTPUT\n            echo \"Detected jit_kernel changes\"\n          else\n            echo \"jit_kernel=false\" >> $GITHUB_OUTPUT\n          fi\n\n          # Check for multimodal_gen changes\n          if 
echo \"$CHANGED_FILES\" | grep -qE \"^(python/sglang/multimodal_gen/|python/sglang/cli/|python/pyproject\\.toml|\\.github/workflows/pr-test\\.yml)\"; then\n            echo \"multimodal_gen=true\" >> $GITHUB_OUTPUT\n            echo \"Detected multimodal_gen changes\"\n          else\n            echo \"multimodal_gen=false\" >> $GITHUB_OUTPUT\n          fi\n\n      - name: Set max-parallel based on run type\n        id: set-parallel\n        env:\n          GH_TOKEN: ${{ github.token }}\n        run: |\n          # Scheduled runs and high-priority PRs get full parallelism\n          if [[ \"${{ github.event_name }}\" == \"schedule\" ]]; then\n            echo \"max_parallel=14\" >> $GITHUB_OUTPUT\n            echo \"Scheduled run detected, setting max_parallel to 14\"\n          elif [[ \"${{ github.event_name }}\" == \"pull_request\" && \"${{ contains(github.event.pull_request.labels.*.name, 'high priority') }}\" == \"true\" ]]; then\n            echo \"max_parallel=14\" >> $GITHUB_OUTPUT\n            echo \"High priority PR detected, setting max_parallel to 14\"\n          elif [[ -n \"${{ inputs.target_stage }}\" ]]; then\n            # /rerun-stage (workflow_dispatch): query PR labels via GitHub API\n            # Try SHA lookup first (fork PRs), fallback to branch name (non-fork PRs)\n            LABELS=\"\"\n            PR_HEAD_SHA=\"${{ inputs.pr_head_sha }}\"\n            if [[ -n \"$PR_HEAD_SHA\" ]]; then\n              LABELS=$(gh api \"repos/${{ github.repository }}/commits/${PR_HEAD_SHA}/pulls\" \\\n                --jq '.[0].labels[].name' 2>/dev/null || true)\n            fi\n            if [[ -z \"$LABELS\" ]]; then\n              LABELS=$(gh pr list --head \"${{ github.ref_name }}\" --repo \"${{ github.repository }}\" \\\n                --json labels --jq '.[0].labels[].name' 2>/dev/null || true)\n            fi\n            echo \"PR labels: ${LABELS:-\"(none)\"}\"\n            if echo \"$LABELS\" | grep -Fxq \"high priority\"; then\n             
 echo \"max_parallel=14\" >> $GITHUB_OUTPUT\n              echo \"High priority PR detected via API (/rerun-stage), setting max_parallel to 14\"\n            else\n              echo \"max_parallel=3\" >> $GITHUB_OUTPUT\n              echo \"Using default max_parallel of 3 (/rerun-stage, no high priority label)\"\n            fi\n          else\n            echo \"max_parallel=3\" >> $GITHUB_OUTPUT\n            echo \"Using default max_parallel of 3\"\n          fi\n\n      - name: Set B200 runner tag\n        id: set-runner\n        run: |\n          # Use kernel-build runner only when sgl_kernel changes are detected AND we're not in target_stage mode\n          # (target_stage skips wheel builds, so we can't use custom kernels)\n          # Use API-based detection (filter-api) for target_stage mode, otherwise use dorny/paths-filter (filter)\n          sgl_kernel=\"${{ steps.filter-api.outputs.sgl_kernel || steps.filter.outputs.sgl_kernel }}\"\n          target_stage=\"${{ inputs.target_stage }}\"\n          if [[ \"$sgl_kernel\" == \"true\" && -z \"$target_stage\" ]]; then\n            echo \"b200_runner=4-gpu-b200-kernel\" >> $GITHUB_OUTPUT\n          else\n            echo \"b200_runner=4-gpu-b200\" >> $GITHUB_OUTPUT\n          fi\n\n      - name: Enable retry for CI\n        id: set-retry\n        run: |\n          echo \"enable_retry=true\" >> $GITHUB_OUTPUT\n          echo \"Retry logic enabled for CI\"\n\n      - name: Set continue-on-error for full test runs\n        id: set-continue-on-error\n        run: |\n          if [[ \"${{ steps.run-mode.outputs.run_all_tests }}\" == \"true\" || \"${{ inputs.force_continue_on_error }}\" == \"true\" ]]; then\n            echo \"continue_on_error=true\" >> $GITHUB_OUTPUT\n            echo \"Full test run or force flag detected, enabling continue-on-error to run all tests\"\n          else\n            echo \"continue_on_error=false\" >> $GITHUB_OUTPUT\n            echo \"Filtered run, continue-on-error disabled\"\n   
       fi\n\n      - name: Validate target_stage with kernel changes\n        # Use API-based detection (filter-api) for target_stage mode, otherwise use dorny/paths-filter (filter)\n        if: inputs.target_stage && (steps.filter-api.outputs.sgl_kernel == 'true' || steps.filter.outputs.sgl_kernel == 'true')\n        run: |\n          echo \"::error::Cannot use /rerun-stage when PR has sgl-kernel changes.\"\n          echo \"::error::The sgl-kernel-build-wheels job is skipped in target_stage mode, but this PR modifies sgl-kernel/ files.\"\n          echo \"::error::Please use /tag-and-rerun-ci to run the full workflow including kernel builds.\"\n          echo \"\"\n          echo \"ERROR: Cannot use /rerun-stage when PR has sgl-kernel changes.\"\n          echo \"\"\n          echo \"This PR modifies files in sgl-kernel/, which requires building custom kernel wheels.\"\n          echo \"The /rerun-stage command skips the wheel build job, so the test would run against\"\n          echo \"the wrong (PyPI) version of sgl-kernel instead of your changes.\"\n          echo \"\"\n          echo \"To properly test your kernel changes, use one of these commands instead:\"\n          echo \"  /tag-and-rerun-ci           - Re-run the full workflow including kernel builds\"\n          echo \"  /rerun-ci                   - Re-run the full workflow\"\n          echo \"\"\n          exit 1\n\n      - name: Show filter results in summary (table)\n        run: |\n          {\n            echo \"## Change Detection\"\n            echo \"\"\n            echo \"| Component         | Changed |\"\n            echo \"|-------------------|---------|\"\n            echo \"| main_package      | ${{ steps.filter-api.outputs.main_package || steps.filter.outputs.main_package || steps.run-mode.outputs.run_all_tests }} |\"\n            echo \"| sgl_kernel (raw)  | ${{ steps.filter-api.outputs.sgl_kernel || steps.filter.outputs.sgl_kernel }} |\"\n            echo \"| sgl_kernel (used) | ${{ 
!inputs.target_stage && (steps.filter-api.outputs.sgl_kernel || steps.filter.outputs.sgl_kernel) }} |\"\n            echo \"| jit_kernel        | ${{ steps.filter-api.outputs.jit_kernel || steps.filter.outputs.jit_kernel || steps.run-mode.outputs.run_all_tests }} |\"\n            echo \"| multimodal_gen    | ${{ steps.filter-api.outputs.multimodal_gen || steps.filter.outputs.multimodal_gen || steps.run-mode.outputs.run_all_tests }} |\"\n            echo \"| target_stage      | ${{ inputs.target_stage || '(none)' }} |\"\n            echo \"| detection_method  | ${{ inputs.target_stage && 'GitHub API' || 'dorny/paths-filter' }} |\"\n            echo \"| max_parallel      | ${{ steps.set-parallel.outputs.max_parallel }} |\"\n            echo \"| b200_runner       | ${{ steps.set-runner.outputs.b200_runner }} |\"\n            echo \"| enable_retry      | ${{ steps.set-retry.outputs.enable_retry }} |\"\n            echo \"| continue_on_error | ${{ steps.set-continue-on-error.outputs.continue_on_error }} |\"\n          } >> $GITHUB_STEP_SUMMARY\n\n  # =============================================== Wait Jobs for Sequential PR Execution ====================================================\n  # These jobs poll GitHub API to wait for previous stages to complete.\n  # For PR runs: wait jobs run and enforce sequential execution via polling.\n  # For scheduled runs: wait jobs are skipped, enabling parallel execution for easier retry.\n\n  wait-for-stage-a:\n    needs: [check-changes, call-gate]\n    # Only run for PRs (not scheduled) and when not targeting a specific stage\n    # Skip if call-gate failed (stage-a jobs will be skipped, nothing to wait for)\n    # !cancelled() ensures this job respects workflow cancellation from concurrency group\n    if: |\n      always() &&\n      !cancelled() &&\n      github.event_name == 'pull_request' &&\n      !inputs.target_stage &&\n      inputs.test_parallel_dispatch != true &&\n      (needs.check-changes.outputs.main_package == 'true' 
|| needs.check-changes.outputs.sgl_kernel == 'true') &&\n      (needs.call-gate.result == 'success' || needs.call-gate.result == 'skipped')\n    runs-on: ubuntu-latest\n    outputs:\n      stage_a_result: ${{ steps.wait.outputs.result }}\n    steps:\n      - uses: actions/checkout@v4\n      - uses: ./.github/actions/wait-for-jobs\n        id: wait\n        with:\n          stage-name: stage-a\n          jobs: '[\"stage-a-test-1\", \"stage-a-cpu-only\"]'\n          max-wait-minutes: '240'\n\n  wait-for-stage-b:\n    needs: [check-changes, call-gate, wait-for-stage-a]\n    # Only run for PRs (not scheduled) and when not targeting a specific stage\n    # Skip if call-gate failed (stage-b jobs will be skipped, nothing to wait for)\n    if: |\n      always() &&\n      !cancelled() &&\n      github.event_name == 'pull_request' &&\n      !inputs.target_stage &&\n      inputs.test_parallel_dispatch != true &&\n      (needs.check-changes.outputs.main_package == 'true' || needs.check-changes.outputs.sgl_kernel == 'true') &&\n      (needs.wait-for-stage-a.result == 'success' || needs.wait-for-stage-a.result == 'skipped') &&\n      (needs.call-gate.result == 'success' || needs.call-gate.result == 'skipped')\n    runs-on: ubuntu-latest\n    outputs:\n      stage_b_result: ${{ steps.wait.outputs.result }}\n    steps:\n      - uses: actions/checkout@v4\n      - uses: ./.github/actions/wait-for-jobs\n        id: wait\n        with:\n          stage-name: stage-b\n          jobs: |\n            [\n              {\"prefix\": \"stage-b-test-small-1-gpu\", \"expected_count\": 8},\n              {\"prefix\": \"stage-b-test-large-1-gpu\", \"expected_count\": 14},\n              {\"prefix\": \"stage-b-test-large-2-gpu\", \"expected_count\": 4},\n              {\"prefix\": \"stage-b-test-4-gpu-b200\", \"expected_count\": 1}\n            ]\n          max-wait-minutes: '480'\n\n  # =============================================== PR Gate ====================================================\n 
 call-gate:\n    needs: check-changes\n    # Skip for scheduled runs (they run all tests) and when target_stage is specified\n    if: |\n      github.event_name != 'schedule' &&\n      inputs.test_parallel_dispatch != true &&\n      !inputs.target_stage &&\n      (\n        needs.check-changes.outputs.main_package == 'true' ||\n        needs.check-changes.outputs.sgl_kernel == 'true' ||\n        needs.check-changes.outputs.jit_kernel == 'true' ||\n        needs.check-changes.outputs.multimodal_gen == 'true'\n      )\n    uses: ./.github/workflows/pr-gate.yml\n    secrets: inherit\n\n  # =============================================== sgl-kernel ====================================================\n\n  sgl-kernel-build-wheels:\n    needs: [check-changes, call-gate]\n    # Skip for scheduled runs (they run stages independently) and when target_stage is set\n    if: github.event_name != 'schedule' && inputs.test_parallel_dispatch != true && !inputs.target_stage && needs.check-changes.outputs.sgl_kernel == 'true'\n    runs-on: x64-kernel-build-node\n    timeout-minutes: 240\n    strategy:\n      matrix:\n        include:\n          - python-version: \"3.10\"\n            cuda-version: \"12.9\"\n          # Add back when CUDA 13.0 is supported on CI\n          # - python-version: \"3.10\"\n          #   cuda-version: \"13.0\"\n    name: Build Wheel\n    steps:\n      - name: Cleanup\n        run: |\n          sudo rm -rf $GITHUB_WORKSPACE/* || true\n\n      - uses: actions/checkout@v4\n        with:\n          submodules: \"recursive\"\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Build wheel for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }}\n        run: |\n          cd sgl-kernel\n          ./build.sh \"${{ matrix.python-version 
}}\" \"${{ matrix.cuda-version }}\"\n        env:\n          USE_CCACHE: 1\n\n      - name: Verify wheel artifacts\n        run: |\n          ls -alh sgl-kernel/dist\n          ls -alh sgl-kernel/dist/*.whl\n\n      - name: Upload artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }}\n          path: sgl-kernel/dist/*\n          if-no-files-found: error\n\n  sgl-kernel-build-wheels-arm:\n    needs: [check-changes, call-gate]\n    # Skip for scheduled runs (they run stages independently) and when target_stage is set\n    if: github.event_name != 'schedule' && inputs.test_parallel_dispatch != true && !inputs.target_stage && needs.check-changes.outputs.sgl_kernel == 'true'\n    runs-on: arm-kernel-build-node\n    timeout-minutes: 240\n    strategy:\n      matrix:\n        include:\n          - python-version: \"3.10\"\n            cuda-version: \"12.9\"\n    name: Build Wheel Arm\n    steps:\n      - name: Cleanup\n        run: |\n          if [ -d \"$GITHUB_WORKSPACE\" ]; then\n            sudo rm -rf \"$GITHUB_WORKSPACE\"/* || true\n          else\n            echo \"$GITHUB_WORKSPACE does not exist, nothing to clean\"\n          fi\n\n      - uses: actions/checkout@v4\n        with:\n          submodules: \"recursive\"\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Build wheel for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }}\n        run: |\n          cd sgl-kernel\n          ./build.sh \"${{ matrix.python-version }}\" \"${{ matrix.cuda-version }}\"\n        env:\n          USE_CCACHE: 1\n\n      - name: Verify wheel artifacts\n        run: |\n          ls -alh sgl-kernel/dist\n          ls -alh sgl-kernel/dist/*.whl\n\n      - 
name: Upload artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }}-aarch64\n          path: sgl-kernel/dist/*\n          if-no-files-found: error\n\n  sgl-kernel-unit-test:\n    needs: [check-changes, call-gate, sgl-kernel-build-wheels]\n    # Skip for scheduled runs and when target_stage is set\n    if: |\n      github.event_name != 'schedule' &&\n      inputs.test_parallel_dispatch != true &&\n      !inputs.target_stage &&\n      needs.check-changes.outputs.sgl_kernel == 'true'\n    runs-on: 1-gpu-runner\n    timeout-minutes: 240\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Cleanup\n        run: |\n          ls -alh sgl-kernel/dist || true\n          rm -rf sgl-kernel/dist/* || true\n\n      - name: Download artifacts\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh diffusion\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          cd sgl-kernel\n          pytest tests/\n\n  sgl-kernel-mla-test:\n    needs: [check-changes, call-gate, sgl-kernel-build-wheels]\n    # Skip for scheduled runs and when target_stage is set\n    if: |\n      github.event_name != 'schedule' &&\n      inputs.test_parallel_dispatch != true &&\n      !inputs.target_stage &&\n      needs.check-changes.outputs.sgl_kernel == 'true'\n    runs-on: 1-gpu-runner\n    timeout-minutes: 240\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: 
Cleanup\n        run: |\n          ls -alh sgl-kernel/dist || true\n          rm -rf sgl-kernel/dist/* || true\n\n      - name: Download artifacts\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          cd test/registered/mla\n          python3 test_mla_deepseek_v3.py\n\n  sgl-kernel-benchmark-test:\n    needs: [check-changes, call-gate, sgl-kernel-build-wheels]\n    # Skip for scheduled runs and when target_stage is set\n    if: |\n      github.event_name != 'schedule' &&\n      inputs.test_parallel_dispatch != true &&\n      !inputs.target_stage &&\n      needs.check-changes.outputs.sgl_kernel == 'true'\n    runs-on: 1-gpu-runner\n    timeout-minutes: 240\n    env:\n      CI: true\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Cleanup\n        run: |\n          ls -alh sgl-kernel/dist || true\n          rm -rf sgl-kernel/dist/* || true\n\n      - name: Download artifacts\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run benchmark tests\n        timeout-minutes: 45\n        run: |\n          cd sgl-kernel/benchmark\n          echo \"Running sgl-kernel benchmark tests in CI mode...\"\n\n          echo \"CI environment 
variable: $CI\"\n          echo \"GITHUB_ACTIONS environment variable: $GITHUB_ACTIONS\"\n\n          for bench_file in bench_*.py; do\n            echo \"Testing $bench_file...\"\n            timeout 60 python3 \"$bench_file\" || echo \"Warning: $bench_file timed out or failed, continuing...\"\n            echo \"Completed $bench_file\"\n            echo \"---\"\n          done\n\n          echo \"All benchmark tests completed!\"\n\n  sgl-kernel-b200-test:\n    needs: [check-changes, sgl-kernel-build-wheels]\n    # Skip for scheduled runs and when target_stage is set\n    if: |\n      github.event_name != 'schedule' &&\n      inputs.test_parallel_dispatch != true &&\n      !inputs.target_stage &&\n      needs.check-changes.outputs.sgl_kernel == 'true'\n    runs-on: ${{ needs.check-changes.outputs.b200_runner }}\n    timeout-minutes: 240\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Cleanup\n        run: |\n          ls -alh sgl-kernel/dist || true\n          rm -rf sgl-kernel/dist/* || true\n\n      - name: Download artifacts\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} IS_BLACKWELL=1 bash scripts/ci/cuda/ci_install_dependency.sh diffusion\n\n      - name: Run sgl-kernel unit tests on B200\n        timeout-minutes: 30\n        run: |\n          cd sgl-kernel\n          pytest tests/\n\n  # Adding a single CUDA13 smoke test to verify that the kernel builds and runs\n  # TODO: Add back this test when it can pass on CI\n  # cuda13-kernel-smoke-test:\n  #   needs: [check-changes, sgl-kernel-build-wheels]\n  #   if: needs.check-changes.outputs.sgl_kernel == 'true'\n  #   runs-on: 
x64-cu13-kernel-tests\n  #   steps:\n  #     - uses: actions/checkout@v4\n\n  #     - name: Cleanup\n  #       run: |\n  #         ls -alh sgl-kernel/dist || true\n  #         rm -rf sgl-kernel/dist/* || true\n\n  #     - name: Download CUDA 13.0 artifacts\n  #       uses: actions/download-artifact@v4\n  #       with:\n  #         path: sgl-kernel/dist/\n  #         merge-multiple: true\n  #         pattern: wheel-python3.10-cuda13.0\n\n  #     - name: Install dependencies\n  #       run: |\n  #         CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh\n\n  #     - name: Run kernel unit tests\n  #       timeout-minutes: 30\n  #       run: |\n  #         cd sgl-kernel\n  #         pytest tests/\n\n  # =============================================== jit-kernel ====================================================\n\n  jit-kernel-unit-test:\n    needs: [check-changes, call-gate]\n    # Skip for scheduled runs and when target_stage is set\n    if: |\n      github.event_name != 'schedule' &&\n      inputs.test_parallel_dispatch != true &&\n      !inputs.target_stage &&\n      needs.check-changes.outputs.jit_kernel == 'true'\n    runs-on: 1-gpu-runner\n    timeout-minutes: 240\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          cd python/sglang/jit_kernel\n          pytest tests/\n\n  jit-kernel-unit-test-nightly:\n    needs: [check-changes]\n    if: |\n      github.event_name == 'schedule' &&\n      needs.check-changes.outputs.jit_kernel == 'true'\n    runs-on: 1-gpu-runner\n    timeout-minutes: 240\n    env:\n      SGLANG_JIT_KERNEL_RUN_FULL_TESTS: \"1\"\n    steps:\n      - uses: actions/checkout@v4\n      
  with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run full nightly test\n        timeout-minutes: 60\n        run: |\n          cd python/sglang/jit_kernel\n          pytest tests/\n\n  jit-kernel-benchmark-test:\n    needs: [check-changes, call-gate]\n    # Skip for scheduled runs and when target_stage is set\n    if: |\n      github.event_name != 'schedule' &&\n      inputs.test_parallel_dispatch != true &&\n      !inputs.target_stage &&\n      needs.check-changes.outputs.jit_kernel == 'true'\n    runs-on: 1-gpu-runner\n    timeout-minutes: 240\n    env:\n      CI: true\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run benchmark tests\n        timeout-minutes: 45\n        run: |\n          cd python/sglang/jit_kernel/benchmark\n          echo \"Running jit-kernel benchmark tests in CI mode...\"\n\n          failures=()\n\n          for bench_file in bench_*.py; do\n            echo \"Testing $bench_file...\"\n            if ! 
timeout 120 python3 \"$bench_file\"; then\n              failures+=(\"$bench_file\")\n            fi\n            echo \"Completed $bench_file\"\n            echo \"---\"\n          done\n\n          if [ ${#failures[@]} -ne 0 ]; then\n            echo \"The following benchmark tests failed: ${failures[*]}\"\n            exit 1\n          fi\n\n          echo \"All jit-kernel benchmark tests completed successfully!\"\n\n  # =============================================== primary ====================================================\n\n  stage-a-test-1:\n    needs: [check-changes, call-gate, sgl-kernel-build-wheels]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'stage-a-test-1') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    runs-on: 1-gpu-runner\n    timeout-minutes: 240\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 10\n        run: |\n          cd test/\n          CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            
CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          python3 run_suite.py --hw cuda --suite stage-a-test-1 $CONTINUE_ON_ERROR_FLAG\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n\n  stage-a-cpu-only:\n    needs: [check-changes, call-gate]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'stage-a-cpu-only') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          (needs.check-changes.outputs.main_package == 'true')\n        )\n      )\n    runs-on: ubuntu-latest\n    timeout-minutes: 240\n    steps:\n      - name: Free disk space\n        run: |\n          sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc\n          df -h\n\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          pip install -e \"python/[dev]\"\n\n      - name: Run test\n        timeout-minutes: 10\n        run: |\n          cd test/\n          CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          python3 run_suite.py --hw cpu --suite stage-a-cpu-only $CONTINUE_ON_ERROR_FLAG\n\n  # Runs on 5090 (32GB, SM120)\n  stage-b-test-small-1-gpu:\n    needs: [check-changes, call-gate, wait-for-stage-a, sgl-kernel-build-wheels]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'stage-b-test-small-1-gpu') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || 
inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    runs-on: 1-gpu-5090\n    timeout-minutes: 240\n    env:\n      IS_BLACKWELL: \"1\"\n    strategy:\n      fail-fast: false\n      max-parallel: 8\n      matrix:\n        partition: [0, 1, 2, 3, 4, 5, 6, 7]\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          source /etc/profile.d/sglang-ci.sh\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh\n          git clone https://github.com/merrymercy/human-eval.git\n          cd human-eval\n          pip install -e . 
--no-build-isolation\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          source /etc/profile.d/sglang-ci.sh\n          cd test/\n          CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          python3 run_suite.py --hw cuda --suite stage-b-test-small-1-gpu --auto-partition-id ${{ matrix.partition }} --auto-partition-size 8 $CONTINUE_ON_ERROR_FLAG\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n        with:\n          artifact-suffix: ${{ matrix.partition }}\n\n  # Runs on H100 (80GB, SM90) - tests that don't pass on 5090 (FA3, FP8, high VRAM, etc.)\n  stage-b-test-large-1-gpu:\n    needs: [check-changes, call-gate, wait-for-stage-a, sgl-kernel-build-wheels]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'stage-b-test-large-1-gpu') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    runs-on: 1-gpu-runner\n    timeout-minutes: 240\n    strategy:\n      fail-fast: false\n      max-parallel: ${{ fromJson(needs.check-changes.outputs.max_parallel) }}\n      matrix:\n        partition: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: 
Install dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          cd test/\n          CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          python3 run_suite.py --hw cuda --suite stage-b-test-large-1-gpu --auto-partition-id ${{ matrix.partition }} --auto-partition-size 14 --timeout-per-file 1800 $CONTINUE_ON_ERROR_FLAG\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n        with:\n          artifact-suffix: ${{ matrix.partition }}\n\n  stage-b-test-large-2-gpu:\n    needs: [check-changes, call-gate, wait-for-stage-a, sgl-kernel-build-wheels]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'stage-b-test-large-2-gpu') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    runs-on: 2-gpu-runner\n    timeout-minutes: 240\n    strategy:\n      fail-fast: false\n      matrix:\n        partition: [0, 1, 2, 3]\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n     
   run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh\n          git clone https://github.com/merrymercy/human-eval.git\n          cd human-eval\n          pip install -e . --no-build-isolation\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          cd test/\n          CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          python3 run_suite.py --hw cuda --suite stage-b-test-large-2-gpu --auto-partition-id ${{ matrix.partition }} --auto-partition-size 4 $CONTINUE_ON_ERROR_FLAG\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n        with:\n          artifact-suffix: ${{ matrix.partition }}\n\n  stage-b-test-4-gpu-b200:\n    needs: [check-changes, call-gate, wait-for-stage-a, sgl-kernel-build-wheels]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'stage-b-test-4-gpu-b200') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    runs-on: ${{ needs.check-changes.outputs.b200_runner }}\n    timeout-minutes: 240\n    strategy:\n      fail-fast: false\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v6\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install 
dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} IS_BLACKWELL=1 bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          cd test\n          CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          IS_BLACKWELL=1 python3 run_suite.py --hw cuda --suite stage-b-test-4-gpu-b200 $CONTINUE_ON_ERROR_FLAG\n\n      - name: Run FA4 jit_kernel tests (SM100+)\n        timeout-minutes: 10\n        run: |\n          IS_BLACKWELL=1 python3 -m pytest -q python/sglang/jit_kernel/tests/test_flash_attention_4.py\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n\n  multimodal-gen-test-1-gpu:\n    needs: [check-changes, call-gate, sgl-kernel-build-wheels]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'multimodal-gen-test-1-gpu') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          needs.check-changes.outputs.multimodal_gen == 'true'\n        )\n      )\n    runs-on: 1-gpu-runner\n    timeout-minutes: 240\n    strategy:\n      fail-fast: false\n      matrix:\n        part: [0, 1]\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          
CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh diffusion\n      - name: Run diffusion server tests\n        timeout-minutes: 240\n        env:\n          RUNAI_STREAMER_MEMORY_LIMIT: 0\n        run: |\n          cd python\n          CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          python3 sglang/multimodal_gen/test/run_suite.py \\\n            --suite 1-gpu \\\n            --partition-id ${{ matrix.part }} \\\n            --total-partitions 2 \\\n            $CONTINUE_ON_ERROR_FLAG\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n        with:\n          artifact-suffix: ${{ matrix.part }}\n\n  multimodal-gen-test-2-gpu:\n    needs: [check-changes, call-gate, sgl-kernel-build-wheels]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'multimodal-gen-test-2-gpu') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          needs.check-changes.outputs.multimodal_gen == 'true'\n        )\n      )\n    runs-on: 2-gpu-runner\n    timeout-minutes: 240\n    strategy:\n      fail-fast: false\n      matrix:\n        part: [0, 1]\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          
CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh diffusion\n\n      - name: Run diffusion server tests\n        timeout-minutes: 240\n        env:\n          RUNAI_STREAMER_MEMORY_LIMIT: 0\n        run: |\n          cd python\n          CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          python3 sglang/multimodal_gen/test/run_suite.py \\\n            --suite 2-gpu \\\n            --partition-id ${{ matrix.part }} \\\n            --total-partitions 2 \\\n            $CONTINUE_ON_ERROR_FLAG\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n        with:\n          artifact-suffix: ${{ matrix.part }}\n\n  multimodal-gen-unit-test:\n    needs: [check-changes, call-gate, sgl-kernel-build-wheels]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'multimodal-gen-unit-test') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          needs.check-changes.outputs.multimodal_gen == 'true'\n        )\n      )\n    runs-on: 1-gpu-runner\n    timeout-minutes: 120\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash 
scripts/ci/cuda/ci_install_dependency.sh diffusion\n\n      - name: Run diffusion unit tests\n        timeout-minutes: 60\n        run: |\n          cd python\n          python3 sglang/multimodal_gen/test/run_suite.py --suite unit\n\n  stage-c-test-4-gpu-h100:\n    needs: [check-changes, call-gate, wait-for-stage-b]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'stage-c-test-4-gpu-h100') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    runs-on: 4-gpu-h100\n    timeout-minutes: 240\n    strategy:\n      fail-fast: false\n      matrix:\n        part: [0, 1, 2]\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          cd test\n          CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          python3 run_suite.py --hw cuda --suite stage-c-test-4-gpu-h100 --auto-partition-id ${{ matrix.part }} --auto-partition-size 3 $CONTINUE_ON_ERROR_FLAG\n\n      - uses: 
./.github/actions/upload-cuda-coredumps\n        if: always()\n        with:\n          artifact-suffix: ${{ matrix.part }}\n\n  stage-c-test-8-gpu-h200:\n    needs: [check-changes, call-gate, wait-for-stage-b]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'stage-c-test-8-gpu-h200') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    runs-on: 8-gpu-h200\n    timeout-minutes: 240\n    strategy:\n      fail-fast: false\n      matrix:\n        part: [0, 1, 2, 3]\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Warmup DeepGEMM JIT Compilation\n        timeout-minutes: 25\n        run: |\n          python3 scripts/ci/cuda/warmup_deep_gemm.py \\\n            deepseek-ai/DeepSeek-V3-0324:8 \\\n            deepseek-ai/DeepSeek-V3.2-Exp:8\n\n      - name: Warmup Server CUDA Graphs\n        timeout-minutes: 25\n        run: |\n          python3 scripts/ci/cuda/warmup_server.py \\\n            deepseek-ai/DeepSeek-V3-0324:8 \\\n            inclusionAI/Ring-2.5-1T:8\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          cd test\n          
CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          python3 run_suite.py --hw cuda --suite stage-c-test-8-gpu-h200 --auto-partition-id ${{ matrix.part }} --auto-partition-size 4 $CONTINUE_ON_ERROR_FLAG\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n        with:\n          artifact-suffix: ${{ matrix.part }}\n\n  stage-c-test-8-gpu-h20:\n    needs: [check-changes, call-gate, wait-for-stage-b]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'stage-c-test-8-gpu-h20') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    runs-on: 8-gpu-h20\n    timeout-minutes: 240\n    env:\n      SGLANG_CI_RDMA_ALL_DEVICES: \"mlx5_1,mlx5_2,mlx5_3,mlx5_4\"\n    strategy:\n      fail-fast: false\n      matrix:\n        part: [0, 1]\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_deepep.sh\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          cd test\n          CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ 
needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          python3 run_suite.py --hw cuda --suite stage-c-test-8-gpu-h20 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 $CONTINUE_ON_ERROR_FLAG\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n        with:\n          artifact-suffix: ${{ matrix.part }}\n\n  stage-c-test-deepep-4-gpu:\n    needs: [check-changes, call-gate, wait-for-stage-b]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'stage-c-test-deepep-4-gpu') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    runs-on: 4-gpu-h100\n    timeout-minutes: 240\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_deepep.sh\n\n      - name: Warmup DeepGEMM JIT Compilation\n        timeout-minutes: 25\n        run: |\n          python3 scripts/ci/cuda/warmup_deep_gemm.py \\\n            lmsys/sglang-ci-dsv3-test:4\n\n      - name: Warmup Server CUDA Graphs\n        timeout-minutes: 25\n        run: |\n          python3 scripts/ci/cuda/warmup_server.py \\\n            
lmsys/sglang-ci-dsv3-test:4\n\n      - name: Run test\n        timeout-minutes: 30\n        run: |\n          cd test\n          CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          python3 run_suite.py --hw cuda --suite stage-c-test-deepep-4-gpu $CONTINUE_ON_ERROR_FLAG\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n\n  stage-c-test-deepep-8-gpu-h200:\n    needs: [check-changes, call-gate, wait-for-stage-b]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'stage-c-test-deepep-8-gpu-h200') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    runs-on: 8-gpu-h200\n    timeout-minutes: 240\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_deepep.sh\n\n      - name: Warmup DeepGEMM JIT Compilation\n        timeout-minutes: 25\n        run: |\n          python3 scripts/ci/cuda/warmup_deep_gemm.py \\\n            deepseek-ai/DeepSeek-V3-0324:8 \\\n            deepseek-ai/DeepSeek-V3.2-Exp:8\n\n      - name: Warmup Server CUDA Graphs\n      
  timeout-minutes: 25\n        run: |\n          python3 scripts/ci/cuda/warmup_server.py \\\n            deepseek-ai/DeepSeek-V3-0324:8\n\n      - name: Run test\n        timeout-minutes: 45\n        run: |\n          cd test\n          CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          python3 run_suite.py --hw cuda --suite stage-c-test-deepep-8-gpu-h200 $CONTINUE_ON_ERROR_FLAG\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n\n  stage-c-test-4-gpu-b200:\n    needs: [check-changes, call-gate, wait-for-stage-b]\n    if: |\n      always() &&\n      (\n        (inputs.target_stage == 'stage-c-test-4-gpu-b200') ||\n        (\n          !inputs.target_stage &&\n          ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n          ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n        )\n      )\n    runs-on: ${{ needs.check-changes.outputs.b200_runner }}\n    timeout-minutes: 240\n    strategy:\n      fail-fast: false\n      matrix:\n        part: [0, 1, 2, 3]\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n\n      - name: Download artifacts\n        if: needs.check-changes.outputs.sgl_kernel == 'true'\n        uses: actions/download-artifact@v6\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-python3.10-cuda12.9\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} IS_BLACKWELL=1 bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run test\n        timeout-minutes: 
30\n        run: |\n          cd test\n          CONTINUE_ON_ERROR_FLAG=\"\"\n          if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n            CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n          fi\n          IS_BLACKWELL=1 python3 run_suite.py --hw cuda --suite stage-c-test-4-gpu-b200 --auto-partition-id ${{ matrix.part }} --auto-partition-size 4 --timeout-per-file 1800 $CONTINUE_ON_ERROR_FLAG\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n        with:\n          artifact-suffix: ${{ matrix.part }}\n\n  # NOTE: GB200 stage temporarily disabled — no company-owned GB200 runner available yet.\n  # Re-enable when a 4-gpu-gb200 runner is provisioned.\n  # stage-c-test-4-gpu-gb200:\n  #   needs: [check-changes, call-gate, wait-for-stage-b, sgl-kernel-build-wheels-arm]\n  #   if: |\n  #     always() &&\n  #     (\n  #       (inputs.target_stage == 'stage-c-test-4-gpu-gb200') ||\n  #       (\n  #         !inputs.target_stage &&\n  #         ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) &&\n  #         ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))\n  #       )\n  #     )\n  #   runs-on: 4-gpu-gb200\n  #   timeout-minutes: 240\n  #   strategy:\n  #     fail-fast: false\n  #   steps:\n  #     - name: Checkout code\n  #       uses: actions/checkout@v4\n  #       with:\n  #         ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}\n  #\n  #     - name: Download artifacts\n  #       if: needs.check-changes.outputs.sgl_kernel == 'true'\n  #       uses: actions/download-artifact@v4\n  #       with:\n  #         path: sgl-kernel/dist/\n  #         merge-multiple: true\n  #         pattern: wheel-python3.10-cuda12.9-aarch64\n  #\n  #     - name: Install dependencies\n  #       timeout-minutes: 20\n  #       run: |\n  #         
CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} IS_BLACKWELL=1 GRACE_BLACKWELL=1 bash scripts/ci/cuda/ci_install_deepep.sh\n  #\n  #     - name: Run test\n  #       timeout-minutes: 45\n  #       run: |\n  #         cd test\n  #         CONTINUE_ON_ERROR_FLAG=\"\"\n  #         if [[ \"${{ needs.check-changes.outputs.continue_on_error }}\" == \"true\" ]]; then\n  #           CONTINUE_ON_ERROR_FLAG=\"--continue-on-error\"\n  #         fi\n  #         python3 run_suite.py --hw cuda --suite stage-c-test-4-gpu-gb200 --timeout-per-file 3600 $CONTINUE_ON_ERROR_FLAG\n  #\n  #     - uses: ./.github/actions/upload-cuda-coredumps\n  #       if: always()\n\n  pr-test-finish:\n    needs:\n      [\n        call-gate,\n        check-changes,\n\n        sgl-kernel-build-wheels,\n        sgl-kernel-unit-test,\n        sgl-kernel-mla-test,\n        sgl-kernel-benchmark-test,\n        sgl-kernel-b200-test,\n\n        wait-for-stage-a,\n        wait-for-stage-b,\n\n        jit-kernel-unit-test,\n        jit-kernel-unit-test-nightly,\n        jit-kernel-benchmark-test,\n\n        multimodal-gen-unit-test,\n        multimodal-gen-test-1-gpu,\n        multimodal-gen-test-2-gpu,\n\n        stage-a-test-1,\n        stage-a-cpu-only,\n        stage-b-test-small-1-gpu,\n        stage-b-test-large-1-gpu,\n        stage-b-test-large-2-gpu,\n        stage-b-test-4-gpu-b200,\n        stage-c-test-4-gpu-h100,\n        stage-c-test-8-gpu-h20,\n        stage-c-test-8-gpu-h200,\n        stage-c-test-deepep-4-gpu,\n        stage-c-test-deepep-8-gpu-h200,\n        stage-c-test-4-gpu-b200,\n        # stage-c-test-4-gpu-gb200,  # Temporarily disabled — no GB200 runner\n      ]\n    if: always()\n    runs-on: ubuntu-latest\n    steps:\n      - name: Check all dependent job statuses\n        run: |\n          # Convert the 'needs' context to a JSON string\n          json_needs='${{ toJson(needs) }}'\n\n          # Get a list of all job names from the JSON keys\n          
job_names=$(echo \"$json_needs\" | jq -r 'keys_unsorted[]')\n\n          for job in $job_names; do\n            # For each job, extract its result\n            result=$(echo \"$json_needs\" | jq -r --arg j \"$job\" '.[$j].result')\n\n            # Print the job name and its result\n            echo \"$job: $result\"\n\n            # Check for failure or cancellation and exit if found\n            if [[ \"$result\" == \"failure\" || \"$result\" == \"cancelled\" ]]; then\n              echo \"The above jobs failed.\"\n              exit 1\n            fi\n          done\n          # If the loop completes, all jobs were successful\n          echo \"All jobs completed successfully\"\n          exit 0\n"
  },
  {
    "path": ".github/workflows/release-branch-cut.yml",
    "content": "name: Release Branch Cut\n\non:\n  workflow_dispatch:\n    inputs:\n      branch_name:\n        description: 'Branch name to create (e.g., release/v0.5.7)'\n        required: true\n        type: string\n      commit_sha:\n        description: 'Commit SHA from main to cut the release branch from (defaults to latest main)'\n        required: false\n        type: string\n        default: ''\n\npermissions:\n  actions: write\n  contents: write\n  pull-requests: read\n\njobs:\n  cut-release-branch:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: ubuntu-latest\n    environment: 'prod'\n    outputs:\n      branch_name: ${{ steps.set_output.outputs.branch_name }}\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          ref: main\n          fetch-depth: 0\n          token: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Validate branch name\n        run: |\n          BRANCH_NAME=\"${{ github.event.inputs.branch_name }}\"\n\n          if [ -z \"$BRANCH_NAME\" ]; then\n            echo \"::error::Branch name is required\"\n            exit 1\n          fi\n\n          # Validate branch name format (should start with release/)\n          if [[ ! \"$BRANCH_NAME\" =~ ^release/ ]]; then\n            echo \"::warning::Branch name '$BRANCH_NAME' does not follow convention 'release/vX.Y.Z'\"\n          fi\n\n          echo \"Branch name: $BRANCH_NAME\"\n\n      - name: Validate commit SHA\n        id: validate\n        run: |\n          COMMIT_SHA=\"${{ github.event.inputs.commit_sha }}\"\n\n          # If no commit SHA provided, use latest main\n          if [ -z \"$COMMIT_SHA\" ]; then\n            COMMIT_SHA=$(git rev-parse HEAD)\n            echo \"No commit SHA provided, using latest main: $COMMIT_SHA\"\n          fi\n\n          # Verify the commit exists and is on main\n          if ! 
git cat-file -t \"$COMMIT_SHA\" > /dev/null 2>&1; then\n            echo \"::error::Commit SHA '$COMMIT_SHA' does not exist\"\n            exit 1\n          fi\n\n          # Check if commit is an ancestor of main (i.e., is on main branch)\n          if ! git merge-base --is-ancestor \"$COMMIT_SHA\" main; then\n            echo \"::error::Commit SHA '$COMMIT_SHA' is not on the main branch\"\n            exit 1\n          fi\n\n          echo \"COMMIT_SHA=$COMMIT_SHA\" >> $GITHUB_OUTPUT\n          echo \"Validated commit SHA: $COMMIT_SHA\"\n\n      - name: Check if branch already exists\n        run: |\n          BRANCH_NAME=\"${{ github.event.inputs.branch_name }}\"\n\n          if git ls-remote --heads origin \"$BRANCH_NAME\" | grep -q \"$BRANCH_NAME\"; then\n            echo \"::error::Branch '$BRANCH_NAME' already exists\"\n            exit 1\n          fi\n\n          echo \"Branch '$BRANCH_NAME' does not exist, proceeding with creation\"\n\n      - name: Create release branch\n        id: set_output\n        run: |\n          COMMIT_SHA=\"${{ steps.validate.outputs.COMMIT_SHA }}\"\n          BRANCH_NAME=\"${{ github.event.inputs.branch_name }}\"\n\n          git config user.name \"sglang-bot\"\n          git config user.email \"sglang-bot@users.noreply.github.com\"\n\n          # Create branch from the specified commit\n          git checkout -b \"$BRANCH_NAME\" \"$COMMIT_SHA\"\n\n          echo \"branch_name=$BRANCH_NAME\" >> $GITHUB_OUTPUT\n          echo \"Successfully created branch '$BRANCH_NAME' from commit '$COMMIT_SHA'\"\n\n      - name: Update version references in documentation\n        run: |\n          BRANCH_NAME=\"${{ github.event.inputs.branch_name }}\"\n          # Extract version from branch name (e.g., release/v0.5.8 -> v0.5.8)\n          VERSION=$(echo \"$BRANCH_NAME\" | sed 's/release\\///')\n\n          # Update git clone version references in docs (matches vX.Y.Z and vX.Y.Z.postN)\n          sed -i \"s/git clone -b v[0-9]\\+\\.[0-9]\\+\\.[0-9]\\+\\(\\.post[0-9]\\+\\)\\?/git 
clone -b $VERSION/\" docs/get_started/install.md\n          sed -i \"s/git clone -b v[0-9]\\+\\.[0-9]\\+\\.[0-9]\\+\\(\\.post[0-9]\\+\\)\\?/git clone -b $VERSION/\" docs/platforms/amd_gpu.md\n\n          # Check if any changes were made\n          if git diff --quiet; then\n            echo \"No version references needed updating\"\n          else\n            git add docs/get_started/install.md docs/platforms/amd_gpu.md\n            git commit -m \"docs: update version references to $VERSION\"\n            echo \"Updated version references to $VERSION\"\n          fi\n\n      - name: Push release branch\n        run: |\n          BRANCH_NAME=\"${{ steps.set_output.outputs.branch_name }}\"\n          git push origin \"$BRANCH_NAME\"\n          echo \"Successfully pushed branch '$BRANCH_NAME'\"\n\n      - name: Summary\n        run: |\n          COMMIT_SHA=\"${{ steps.validate.outputs.COMMIT_SHA }}\"\n          BRANCH_NAME=\"${{ github.event.inputs.branch_name }}\"\n\n          echo \"## Release Branch Cut Summary\" >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"| Property | Value |\" >> $GITHUB_STEP_SUMMARY\n          echo \"|----------|-------|\" >> $GITHUB_STEP_SUMMARY\n          echo \"| Branch | \\`$BRANCH_NAME\\` |\" >> $GITHUB_STEP_SUMMARY\n          echo \"| Commit | \\`$COMMIT_SHA\\` |\" >> $GITHUB_STEP_SUMMARY\n          echo \"| Triggered by | @${{ github.actor }} |\" >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"### Next Steps\" >> $GITHUB_STEP_SUMMARY\n          echo \"1. Tests are automatically triggered on the release branch\" >> $GITHUB_STEP_SUMMARY\n          echo \"2. Apply any hotfixes if needed\" >> $GITHUB_STEP_SUMMARY\n          echo \"3. 
Create a tag to trigger release: \\`gh workflow run release-tag.yml -f version=X.Y.Z -f ref=$BRANCH_NAME\\`\" >> $GITHUB_STEP_SUMMARY\n\n  run-pr-tests-nvidia:\n    needs: cut-release-branch\n    uses: ./.github/workflows/pr-test.yml\n    with:\n      git_ref: ${{ needs.cut-release-branch.outputs.branch_name }}\n      run_all_tests: true\n    secrets: inherit\n\n  run-pr-tests-amd:\n    needs: cut-release-branch\n    uses: ./.github/workflows/pr-test-amd.yml\n    with:\n      ref: ${{ needs.cut-release-branch.outputs.branch_name }}\n      run_all_tests: true\n    secrets: inherit\n\n  run-pr-test-npu:\n    needs: cut-release-branch\n    uses: ./.github/workflows/pr-test-npu.yml\n    with:\n      ref: ${{ needs.cut-release-branch.outputs.branch_name }}\n      run_all_tests: true\n    secrets: inherit\n\n  run-pr-tests-xeon:\n    needs: cut-release-branch\n    uses: ./.github/workflows/pr-test-xeon.yml\n    with:\n      ref: ${{ needs.cut-release-branch.outputs.branch_name }}\n      run_all_tests: true\n    secrets: inherit\n\n  run-pr-tests-xpu:\n    needs: cut-release-branch\n    uses: ./.github/workflows/pr-test-xpu.yml\n    with:\n      ref: ${{ needs.cut-release-branch.outputs.branch_name }}\n      run_all_tests: true\n    secrets: inherit\n\n  run-nightly-tests-nvidia:\n    needs: cut-release-branch\n    uses: ./.github/workflows/nightly-test-nvidia.yml\n    with:\n      ref: ${{ needs.cut-release-branch.outputs.branch_name }}\n    secrets: inherit\n\n  run-nightly-tests-amd:\n    needs: cut-release-branch\n    uses: ./.github/workflows/nightly-test-amd.yml\n    with:\n      ref: ${{ needs.cut-release-branch.outputs.branch_name }}\n    secrets: inherit\n\n  run-nightly-tests-npu:\n    needs: cut-release-branch\n    uses: ./.github/workflows/nightly-test-npu.yml\n    with:\n      ref: ${{ needs.cut-release-branch.outputs.branch_name }}\n    secrets: inherit\n\n  run-nightly-tests-intel:\n    needs: cut-release-branch\n    uses: 
./.github/workflows/nightly-test-intel.yml\n    with:\n      ref: ${{ needs.cut-release-branch.outputs.branch_name }}\n    secrets: inherit\n"
  },
  {
    "path": ".github/workflows/release-docker-amd-nightly.yml",
    "content": "name: Release Docker Images Nightly (AMD)\non:\n  workflow_dispatch:\n  schedule:\n    - cron: '0 12 * * *'\n\nconcurrency:\n  # A PR number if a pull request and otherwise the commit hash. This cancels\n  # queued and in-progress runs for the same PR (presubmit) or commit\n  # (postsubmit). The workflow name is prepended to avoid conflicts between\n  # different workflows.\n  group: ${{ github.workflow }}-${{ github.event.number || github.sha }}\n  cancel-in-progress: true\n\njobs:\n  publish:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: amd-docker-scale\n    environment: 'prod'\n    strategy:\n      fail-fast: false\n      matrix:\n        gpu_arch: ['gfx942', 'gfx950']\n        build_type: ['all']\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0  # Required for git describe to find tags\n\n      - name: \"Set Date\"\n        run: |\n          echo \"DATE=$(date +%Y%m%d)\" >> $GITHUB_ENV\n\n      - name: Get version from latest tag\n        id: version\n        run: |\n          # Get the latest version tag sorted by version number (e.g., v0.5.7 -> 0.5.7)\n          VERSION=$(git tag -l 'v[0-9]*' --sort=-v:refname | head -1 | sed 's/^v//')\n\n          if [ -z \"$VERSION\" ]; then\n            echo \"::error::Could not determine version from git tags\"\n            exit 1\n          fi\n\n          # Get short commit hash of current HEAD\n          COMMIT_HASH=$(git rev-parse --short HEAD)\n\n          # Compose pretend version for setuptools_scm: e.g., 0.5.8.dev20260129+g1a2b3c4\n          PRETEND_VERSION=\"${VERSION}.dev${{ env.DATE }}+g${COMMIT_HASH}\"\n\n          echo \"version=${VERSION}\" >> $GITHUB_OUTPUT\n          echo \"pretend_version=${PRETEND_VERSION}\" >> $GITHUB_OUTPUT\n          echo \"Detected version: ${VERSION}\"\n          echo \"Pretend version for pip: ${PRETEND_VERSION}\"\n\n      - name: Login to Docker Hub (AMD)\n        uses: 
docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_AMD_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_AMD_TOKEN }}\n\n      - name: Build and Push to rocm/sgl-dev\n        run: |\n          version=${{ steps.version.outputs.version }}\n          pretend_version=${{ steps.version.outputs.pretend_version }}\n          echo \"Version: ${version}\"\n          echo \"Pretend version: ${pretend_version}\"\n\n          if [ \"${{ matrix.gpu_arch }}\" = \"gfx942\" ]; then\n            rocm_tag=\"rocm700-mi30x\"\n          elif [ \"${{ matrix.gpu_arch }}\" = \"gfx950\" ]; then\n            rocm_tag=\"rocm700-mi35x\"\n          else\n            echo \"Unsupported gfx arch\"\n            exit 1\n          fi\n\n          tag=v${version}-${rocm_tag}\n          echo \"IMAGE_TAG=${tag}-${{ env.DATE }}\" >> $GITHUB_ENV\n\n          docker build . -f docker/rocm.Dockerfile --build-arg SGL_BRANCH=${{ github.ref_name }} --build-arg BUILD_TYPE=${{ matrix.build_type }} --build-arg GPU_ARCH=${{ matrix.gpu_arch }} --build-arg ENABLE_MORI=1 --build-arg NIC_BACKEND=ainic --build-arg SETUPTOOLS_SCM_PRETEND_VERSION=${pretend_version} -t rocm/sgl-dev:${tag}-${{ env.DATE }} --no-cache\n          docker push rocm/sgl-dev:${tag}-${{ env.DATE }}\n\n      - name: Login to Docker Hub (lmsys)\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Push to lmsysorg/sglang-rocm\n        run: |\n          docker tag rocm/sgl-dev:${{ env.IMAGE_TAG }} lmsysorg/sglang-rocm:${{ env.IMAGE_TAG }}\n          docker push lmsysorg/sglang-rocm:${{ env.IMAGE_TAG }}\n\n  # Temporarily disable docker cache seeding until performant storage is in place\n  cache:\n    if: false\n    # if: always() && github.repository == 'sgl-project/sglang'\n    runs-on: linux-mi300-gpu-1\n    environment: 'prod'\n    needs: publish\n    strategy:\n      fail-fast: false\n    
  matrix:\n        gpu_arch: ['gfx942']\n        build_type: ['all']\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0  # Required for git describe to find tags\n\n      - name: \"Set Date\"\n        run: |\n          echo \"DATE=$(date +%Y%m%d)\" >> $GITHUB_ENV\n\n      - name: Get version from latest tag\n        id: version\n        run: |\n          # Get the latest version tag sorted by version number (e.g., v0.5.7 -> 0.5.7)\n          VERSION=$(git tag -l 'v[0-9]*' --sort=-v:refname | head -1 | sed 's/^v//')\n\n          if [ -z \"$VERSION\" ]; then\n            echo \"::error::Could not determine version from git tags\"\n            exit 1\n          fi\n\n          echo \"version=${VERSION}\" >> $GITHUB_OUTPUT\n          echo \"Detected version: ${VERSION}\"\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_AMD_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_AMD_TOKEN }}\n\n      - name: Pull and Save Docker Image to Cache\n        run: |\n          set -euxo pipefail\n\n          version=${{ steps.version.outputs.version }}\n          echo \"Version: ${version}\"\n\n          if [ \"${{ matrix.gpu_arch }}\" = \"gfx942\" ]; then\n            rocm_tag=\"rocm700-mi30x\"\n          else\n            echo \"Unsupported gfx arch\"\n            exit 1\n          fi\n\n          tag=v${version}-${rocm_tag}\n\n          if [ \"${{ matrix.build_type }}\" = \"all\" ]; then\n            tag_suffix=\"\"\n          else\n            echo \"Unsupported build type\"\n            exit 1\n          fi\n\n          image=\"rocm/sgl-dev:${tag}-${{ env.DATE }}${tag_suffix}\"\n\n          # Determine target cache file name based on ROCm variant\n          if [[ \"${rocm_tag}\" == rocm700* ]]; then\n            final_path=\"/home/runner/sgl-data/docker/image-700.tar\"\n          else\n            echo \"Unexpected ROCm 
tag: ${rocm_tag}\"\n            exit 1\n          fi\n\n          tmp_path=\"${final_path}.tmp\"\n\n          echo \"Pulling image: ${image}\"\n          docker pull \"${image}\"\n\n          echo \"Saving to temp file: ${tmp_path}\"\n          docker save \"${image}\" -o \"${tmp_path}\"\n\n          echo \"Moving to final path: ${final_path}\"\n          mv -f \"${tmp_path}\" \"${final_path}\"\n\n          echo \"Cache populated successfully at ${final_path}\"\n"
  },
  {
    "path": ".github/workflows/release-docker-amd-rocm720-nightly.yml",
    "content": "name: Release Docker Images ROCm 7.2.0 Nightly Preview (AMD)\non:\n  workflow_dispatch:\n  schedule:\n    - cron: '0 12 * * *'\n\nconcurrency:\n  # A PR number if a pull request and otherwise the commit hash. This cancels\n  # queued and in-progress runs for the same PR (presubmit) or commit\n  # (postsubmit). The workflow name is prepended to avoid conflicts between\n  # different workflows.\n  group: ${{ github.workflow }}-${{ github.event.number || github.sha }}\n  cancel-in-progress: True\n\njobs:\n  publish:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: amd-docker-scale\n    environment: 'prod'\n    strategy:\n      fail-fast: false\n      matrix:\n        gpu_arch: ['gfx942-rocm720', 'gfx950-rocm720']\n        build_type: ['all']\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0  # Required for git describe to find tags\n\n      - name: \"Set Date\"\n        run: |\n          echo \"DATE=$(date +%Y%m%d)\" >> $GITHUB_ENV\n\n      - name: Get version from latest tag\n        id: version\n        run: |\n          # Get the latest version tag sorted by version number (e.g., v0.5.7 -> 0.5.7)\n          VERSION=$(git tag -l 'v[0-9]*' --sort=-v:refname | head -1 | sed 's/^v//')\n\n          if [ -z \"$VERSION\" ]; then\n            echo \"::error::Could not determine version from git tags\"\n            exit 1\n          fi\n\n          # Get short commit hash of current HEAD\n          COMMIT_HASH=$(git rev-parse --short HEAD)\n\n          # Compose pretend version for setuptools_scm: e.g., 0.5.8.post1.dev20260211+g1a2b3c4\n          PRETEND_VERSION=\"${VERSION}.dev${{ env.DATE }}+g${COMMIT_HASH}\"\n\n          echo \"version=${VERSION}\" >> $GITHUB_OUTPUT\n          echo \"pretend_version=${PRETEND_VERSION}\" >> $GITHUB_OUTPUT\n          echo \"Detected version: ${VERSION}\"\n          echo \"Pretend version for pip: ${PRETEND_VERSION}\"\n\n      - name: 
Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_AMD_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_AMD_TOKEN }}\n\n      - name: Build and Push to rocm/sgl-dev\n        run: |\n          version=${{ steps.version.outputs.version }}\n          pretend_version=${{ steps.version.outputs.pretend_version }}\n          echo \"Version: ${version}\"\n          echo \"Pretend version: ${pretend_version}\"\n\n          if [ \"${{ matrix.gpu_arch }}\" = \"gfx942-rocm720\" ]; then\n            rocm_tag=\"rocm720-mi30x\"\n          elif [ \"${{ matrix.gpu_arch }}\" = \"gfx950-rocm720\" ]; then\n            rocm_tag=\"rocm720-mi35x\"\n          else\n            echo \"Unsupported gfx arch\"\n            exit 1\n          fi\n\n          tag=v${version}-${rocm_tag}\n          echo \"IMAGE_TAG=${tag}-${{ env.DATE }}\" >> $GITHUB_ENV\n\n          docker build . -f docker/rocm.Dockerfile --build-arg SGL_BRANCH=${{ github.ref_name }} --build-arg BUILD_TYPE=${{ matrix.build_type }} --build-arg GPU_ARCH=${{ matrix.gpu_arch }} --build-arg ENABLE_MORI=1 --build-arg NIC_BACKEND=ainic --build-arg SETUPTOOLS_SCM_PRETEND_VERSION=${pretend_version} -t rocm/sgl-dev:${tag}-${{ env.DATE }} --no-cache\n          docker push rocm/sgl-dev:${tag}-${{ env.DATE }}\n\n      - name: Login to Docker Hub (lmsys)\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Push to lmsysorg/sglang-rocm\n        run: |\n          docker tag rocm/sgl-dev:${{ env.IMAGE_TAG }} lmsysorg/sglang-rocm:${{ env.IMAGE_TAG }}\n          docker push lmsysorg/sglang-rocm:${{ env.IMAGE_TAG }}\n"
  },
  {
    "path": ".github/workflows/release-docker-amd.yml",
    "content": "name: Release Docker Images (AMD)\non:\n  push:\n    tags:\n      - 'v[0-9]+.*'\n  workflow_dispatch:\n    inputs:\n      version:\n        description: 'Version to build (without v prefix, e.g., 0.5.7)'\n        required: true\n\njobs:\n  publish:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: amd-docker-scale\n    environment: 'prod'\n    strategy:\n      matrix:\n        rocm_version: ['rocm700', 'rocm720']\n        gpu_arch: ['gfx942', 'gfx950']\n        build_type: ['all']\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Get version from tag\n        id: version\n        run: |\n          if [ \"${{ github.event_name }}\" = \"workflow_dispatch\" ]; then\n            VERSION=\"${{ github.event.inputs.version }}\"\n          else\n            # Extract version from tag (e.g., v0.5.7 -> 0.5.7)\n            VERSION=\"${GITHUB_REF_NAME#v}\"\n          fi\n\n          # Validate version format\n          if [ -z \"$VERSION\" ]; then\n            echo \"::error::Version is empty\"\n            exit 1\n          fi\n          if ! 
echo \"$VERSION\" | grep -qE '^[0-9]+\\.[0-9]+\\.[0-9]+'; then\n            echo \"::error::Invalid version format: $VERSION (expected: X.Y.Z)\"\n            exit 1\n          fi\n\n          echo \"version=${VERSION}\" >> $GITHUB_OUTPUT\n\n      - name: Build and Push\n        run: |\n          version=${{ steps.version.outputs.version }}\n          echo \"Version: ${version}\"\n\n          gpu_arch_suffix=\"\"\n          if [ \"${{ matrix.rocm_version }}\" = \"rocm700\" ]; then\n            if [ \"${{ matrix.gpu_arch }}\" = \"gfx942\" ]; then\n              rocm_tag=\"rocm700-mi30x\"\n            elif [ \"${{ matrix.gpu_arch }}\" = \"gfx950\" ]; then\n              rocm_tag=\"rocm700-mi35x\"\n            else\n              echo \"Unsupported gfx arch\"\n              exit 1\n            fi\n          elif [ \"${{ matrix.rocm_version }}\" = \"rocm720\" ]; then\n            gpu_arch_suffix=\"-${{ matrix.rocm_version }}\"\n            if [ \"${{ matrix.gpu_arch }}\" = \"gfx942\" ]; then\n              rocm_tag=\"rocm720-mi30x\"\n            elif [ \"${{ matrix.gpu_arch }}\" = \"gfx950\" ]; then\n              rocm_tag=\"rocm720-mi35x\"\n            else\n              echo \"Unsupported gfx arch\"\n              exit 1\n            fi\n          else\n            echo \"Unsupported rocm version\"\n            exit 1\n          fi\n\n          tag=v${version}-${rocm_tag}\n\n          # rocm.Dockerfile expects SGL_BRANCH with 'v' prefix for git tag checkout\n          docker build . -f docker/rocm.Dockerfile --build-arg BUILD_TYPE=${{ matrix.build_type }} --build-arg GPU_ARCH=${{ matrix.gpu_arch }}${gpu_arch_suffix} --build-arg SGL_BRANCH=v${version} --build-arg ENABLE_MORI=1 --build-arg NIC_BACKEND=ainic -t lmsysorg/sglang:${tag} --no-cache\n          docker push lmsysorg/sglang:${tag}\n"
  },
  {
    "path": ".github/workflows/release-docker-cu13-framework.yml",
    "content": "name: Release CUDA 13 Framework Docker Images (Temporary)\n\n# Temporary workflow to build only versioned cu13 framework images\n# Can be deleted after use\n\non:\n  workflow_dispatch:\n    inputs:\n      version:\n        description: \"Version to build (without v prefix, e.g., 0.5.8)\"\n        required: true\njobs:\n  publish-x86:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: x64-docker-build-node\n    steps:\n      - name: Delete huge unnecessary tools folder\n        run: rm -rf /opt/hostedtoolcache\n\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          tool-cache: false\n          docker-images: false\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Validate version\n        id: version\n        run: |\n          VERSION=\"${{ github.event.inputs.version }}\"\n          if [ -z \"$VERSION\" ]; then\n            echo \"::error::Version is empty\"\n            exit 1\n          fi\n          if ! 
echo \"$VERSION\" | grep -qE '^[0-9]+\\.[0-9]+\\.[0-9]+'; then\n            echo \"::error::Invalid version format: $VERSION (expected: X.Y.Z)\"\n            exit 1\n          fi\n          echo \"version=${VERSION}\" >> $GITHUB_OUTPUT\n\n      - name: Build and Push AMD64 Framework (CUDA 13)\n        run: |\n          version=${{ steps.version.outputs.version }}\n\n          docker buildx build \\\n            --target framework \\\n            --platform linux/amd64 \\\n            --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \\\n            -f docker/Dockerfile \\\n            --build-arg CUDA_VERSION=13.0.1 \\\n            --build-arg BUILD_TYPE=all \\\n            --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \\\n            --build-arg GRACE_BLACKWELL=0 \\\n            --build-arg SGL_VERSION=${version} \\\n            --metadata-file /tmp/metadata.json \\\n            --no-cache \\\n            .\n\n          DIGEST=$(python3 -c \"import json; print(json.load(open('/tmp/metadata.json'))['containerimage.digest'])\")\n          echo \"Pushed digest: ${DIGEST}\"\n          echo \"${DIGEST}\" > /tmp/digest-cu130-amd64-framework.txt\n\n      - name: Upload digest\n        uses: actions/upload-artifact@v4\n        with:\n          name: digest-cu130-amd64\n          path: /tmp/digest-cu130-amd64-framework.txt\n          retention-days: 1\n\n  publish-arm64:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: arm-docker-build-node\n    steps:\n      - name: Delete huge unnecessary tools folder\n        run: rm -rf /opt/hostedtoolcache\n\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Validate 
version\n        id: version\n        run: |\n          VERSION=\"${{ github.event.inputs.version }}\"\n          if [ -z \"$VERSION\" ]; then\n            echo \"::error::Version is empty\"\n            exit 1\n          fi\n          if ! echo \"$VERSION\" | grep -qE '^[0-9]+\\.[0-9]+\\.[0-9]+'; then\n            echo \"::error::Invalid version format: $VERSION (expected: X.Y.Z)\"\n            exit 1\n          fi\n          echo \"version=${VERSION}\" >> $GITHUB_OUTPUT\n\n      - name: Build and Push ARM64 Framework (CUDA 13)\n        run: |\n          version=${{ steps.version.outputs.version }}\n\n          docker buildx build \\\n            --target framework \\\n            --platform linux/arm64 \\\n            --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \\\n            -f docker/Dockerfile \\\n            --build-arg CUDA_VERSION=13.0.1 \\\n            --build-arg BUILD_TYPE=all \\\n            --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \\\n            --build-arg GRACE_BLACKWELL=1 \\\n            --build-arg SGL_VERSION=${version} \\\n            --metadata-file /tmp/metadata.json \\\n            --no-cache \\\n            .\n\n          DIGEST=$(python3 -c \"import json; print(json.load(open('/tmp/metadata.json'))['containerimage.digest'])\")\n          echo \"Pushed digest: ${DIGEST}\"\n          echo \"${DIGEST}\" > /tmp/digest-cu130-arm64-framework.txt\n\n      - name: Upload digest\n        uses: actions/upload-artifact@v4\n        with:\n          name: digest-cu130-arm64\n          path: /tmp/digest-cu130-arm64-framework.txt\n          retention-days: 1\n\n  create-manifest:\n    runs-on: ubuntu-22.04\n    needs: [publish-x86, publish-arm64]\n    if: github.repository == 'sgl-project/sglang'\n    steps:\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ 
secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Download amd64 digest\n        uses: actions/download-artifact@v4\n        with:\n          name: digest-cu130-amd64\n          path: /tmp/digests/amd64\n\n      - name: Download arm64 digest\n        uses: actions/download-artifact@v4\n        with:\n          name: digest-cu130-arm64\n          path: /tmp/digests/arm64\n\n      - name: Create multi-arch manifest\n        run: |\n          version=${{ github.event.inputs.version }}\n          AMD64_DIGEST=$(cat /tmp/digests/amd64/digest-cu130-amd64-framework.txt)\n          ARM64_DIGEST=$(cat /tmp/digests/arm64/digest-cu130-arm64-framework.txt)\n\n          # Create versioned CUDA 13 framework manifest\n          docker buildx imagetools create \\\n            -t lmsysorg/sglang:v${version}-cu130 \\\n            lmsysorg/sglang@${AMD64_DIGEST} \\\n            lmsysorg/sglang@${ARM64_DIGEST}\n\n          # Create latest CUDA 13 framework manifest\n          docker buildx imagetools create \\\n            -t lmsysorg/sglang:latest-cu130 \\\n            lmsysorg/sglang@${AMD64_DIGEST} \\\n            lmsysorg/sglang@${ARM64_DIGEST}\n"
  },
  {
    "path": ".github/workflows/release-docker-dev.yml",
    "content": "name: Build and Push Development Docker Images\n\non:\n  workflow_dispatch:\n    inputs:\n      pr_number:\n        description: \"PR number to build from (leave empty to use current branch)\"\n        required: false\n        default: \"\"\n      tag:\n        description: \"Custom tag suffix (overrides pr_number in tag). E.g. 'my-test' → dev-my-test, dev-cu13-my-test, etc.\"\n        required: false\n        default: \"\"\n  schedule:\n    - cron: \"0 0 * * *\"\n\nconcurrency:\n  group: release-docker-dev-${{ inputs.tag || inputs.pr_number || 'nightly' }}\n  cancel-in-progress: true\n\njobs:\n  build-dev:\n    if: ${{ github.repository == 'sgl-project/sglang' }}\n    runs-on: ${{ matrix.runner }}\n    strategy:\n      matrix:\n        include:\n          - runner: x64-docker-build-node\n            platform: linux/amd64\n            build_type: all\n            grace_blackwell: 0\n            arch_tag: x86\n            version: 12.9.1\n          - runner: arm-docker-build-node\n            platform: linux/arm64\n            build_type: all\n            grace_blackwell: 1\n            arch_tag: arm64\n            version: 12.9.1\n          - runner: x64-docker-build-node\n            platform: linux/amd64\n            build_type: all\n            grace_blackwell: 0\n            arch_tag: x86-cu13\n            version: 13.0.1\n          - runner: arm-docker-build-node\n            platform: linux/arm64\n            build_type: all\n            grace_blackwell: 1\n            arch_tag: arm64-cu13\n            version: 13.0.1\n    steps:\n      - name: Delete huge unnecessary tools folder\n        run: rm -rf /opt/hostedtoolcache\n\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_number && format('refs/pull/{0}/head', inputs.pr_number) || github.ref }}\n\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          tool-cache: true\n          
docker-images: true\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: true\n\n      - name: Prune Docker to reclaim disk space\n        run: |\n          docker buildx prune --filter \"until=72h\" -f\n          docker system prune -af --filter \"until=72h\"\n          docker volume prune -af\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Build and Push Dev Image\n        run: |\n          # Nightly (schedule) installs latest release; manual dispatch builds from checked-out source\n          if [ \"${{ github.event_name }}\" = \"schedule\" ]; then\n            SOURCE_ARG=\"--build-arg USE_LATEST_SGLANG=1\"\n          else\n            SOURCE_ARG=\"--build-arg BRANCH_TYPE=local\"\n          fi\n\n          docker buildx build \\\n            --platform ${{ matrix.platform }} \\\n            --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \\\n            --target framework \\\n            -f docker/Dockerfile \\\n            --build-arg CUDA_VERSION=${{ matrix.version }} \\\n            --build-arg BUILD_TYPE=${{ matrix.build_type }} \\\n            --build-arg CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) \\\n            --build-arg GRACE_BLACKWELL=${{ matrix.grace_blackwell }} \\\n            ${SOURCE_ARG} \\\n            --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \\\n            --metadata-file /tmp/metadata.json \\\n            --no-cache \\\n            .\n\n          DIGEST=$(python3 -c \"import json; print(json.load(open('/tmp/metadata.json'))['containerimage.digest'])\")\n          echo \"Pushed digest: ${DIGEST}\"\n          echo \"${DIGEST}\" > /tmp/digest.txt\n\n      - name: Upload digest\n        
uses: actions/upload-artifact@v4\n        with:\n          name: digest-${{ matrix.arch_tag }}\n          path: /tmp/digest.txt\n          retention-days: 1\n\n  create-manifests:\n    runs-on: ubuntu-22.04\n    needs: [build-dev]\n    if: ${{ github.repository == 'sgl-project/sglang' }}\n    strategy:\n      matrix:\n        variant:\n          - base: dev\n            x86: x86\n            arm64: arm64\n          - base: dev-cu13\n            x86: x86-cu13\n            arm64: arm64-cu13\n    steps:\n      - uses: docker/setup-buildx-action@v3\n\n      - uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Download x86 digest\n        uses: actions/download-artifact@v4\n        with:\n          name: digest-${{ matrix.variant.x86 }}\n          path: /tmp/digests/x86\n\n      - name: Download arm64 digest\n        uses: actions/download-artifact@v4\n        with:\n          name: digest-${{ matrix.variant.arm64 }}\n          path: /tmp/digests/arm64\n\n      - name: Create multi-arch manifest\n        run: |\n          X86_DIGEST=$(cat /tmp/digests/x86/digest.txt)\n          ARM64_DIGEST=$(cat /tmp/digests/arm64/digest.txt)\n\n          SUFFIX=\"\"\n          if [ -n \"${{ inputs.tag }}\" ]; then\n            SUFFIX=\"-${{ inputs.tag }}\"\n          elif [ -n \"${{ inputs.pr_number }}\" ]; then\n            SUFFIX=\"-pr-${{ inputs.pr_number }}\"\n          fi\n\n          TAG=\"${{ matrix.variant.base }}${SUFFIX}\"\n\n          # For nightly (no suffix), also stamp a dated tag\n          EXTRA_TAG=\"\"\n          if [ -z \"${SUFFIX}\" ]; then\n            SHORT_SHA=\"${{ github.sha }}\"\n            EXTRA_TAG=\"-t lmsysorg/sglang:nightly-${TAG}-$(date +%Y%m%d)-${SHORT_SHA:0:8}\"\n          fi\n\n          docker buildx imagetools create \\\n            -t lmsysorg/sglang:${TAG} \\\n            ${EXTRA_TAG} \\\n            lmsysorg/sglang@${X86_DIGEST} 
\\\n            lmsysorg/sglang@${ARM64_DIGEST}\n\n          echo \"✓ Published lmsysorg/sglang:${TAG}\"\n\n      - name: Cleanup Old Nightly Builds\n        if: ${{ !inputs.tag && !inputs.pr_number }}\n        run: |\n          TOKEN=$(curl -s -H \"Content-Type: application/json\" \\\n            -X POST -d '{\"username\": \"${{ secrets.DOCKERHUB_USERNAME }}\", \"password\": \"${{ secrets.DOCKERHUB_TOKEN }}\"}' \\\n            https://hub.docker.com/v2/users/login/ | jq -r .token)\n\n          TAGS_RESPONSE=$(curl -s -H \"Authorization: JWT $TOKEN\" \\\n            \"https://hub.docker.com/v2/repositories/lmsysorg/sglang/tags/?page_size=100\")\n\n          TAGS=$(echo \"$TAGS_RESPONSE\" | jq -r \\\n            '.results[] | select(.name | test(\"^nightly-${{ matrix.variant.base }}-[0-9]\")) | \"\\(.last_updated)|\\(.name)\"' \\\n            | sort -r | cut -d'|' -f2)\n\n          TAG_COUNT=$(echo \"$TAGS\" | wc -l)\n          if [ \"$TAG_COUNT\" -gt 14 ]; then\n            echo \"Found $TAG_COUNT nightly builds, keeping only the 14 most recent\"\n            TAGS_TO_DELETE=$(echo \"$TAGS\" | tail -n +15)\n            for tag in $TAGS_TO_DELETE; do\n              echo \"Deleting tag: $tag\"\n              curl -X DELETE -H \"Authorization: JWT $TOKEN\" \\\n                \"https://hub.docker.com/v2/repositories/lmsysorg/sglang/tags/$tag/\"\n            done\n          else\n            echo \"Only $TAG_COUNT nightly builds found, no cleanup needed\"\n          fi\n"
  },
  {
    "path": ".github/workflows/release-docker-gateway.yml",
    "content": "name: Release SGLang Model Gateway Docker Image\non:\n  push:\n    branches:\n      - main\n    paths:\n      - sgl-model-gateway/bindings/python/pyproject.toml\n  workflow_dispatch:\n\njobs:\n  publish:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: ubuntu-24.04\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up QEMU\n        uses: docker/setup-qemu-action@v3\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Build and Push\n        run: |\n          version=$(cat sgl-model-gateway/bindings/python/src/sglang_router/version.py | cut -d'\"' -f2)\n          tag=v${version}\n\n          docker buildx build . -f docker/gateway.Dockerfile \\\n            --platform linux/amd64,linux/arm64 \\\n            -t lmsysorg/sgl-model-gateway:${tag} \\\n            -t lmsysorg/sgl-model-gateway:latest \\\n            --push\n"
  },
  {
    "path": ".github/workflows/release-docker-npu-nightly.yml",
    "content": "name: Release Docker Images Nightly (NPU)\non:\n  pull_request:\n    branches:\n      - 'main'\n    paths:\n      - '.github/workflows/release-docker-npu-nightly.yml'\n      - 'docker/npu.Dockerfile'\n  workflow_dispatch:\n  schedule:\n    - cron: \"0 0 * * *\"\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.sha }}\n  cancel-in-progress: true\n\njobs:\n  build:\n    runs-on: ubuntu-22.04-arm\n    strategy:\n      matrix:\n        cann_version: [\"8.5.0\"]\n        device_type: [\"910b\", \"a3\"]\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Free up disk space\n        uses: jlumbroso/free-disk-space@54081f138730dfa15788a46383842cd2f914a1be # v1.3.1\n        with:\n          tool-cache: true\n          docker-images: false\n\n      - name: Setup Docker buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Docker meta\n        id: meta\n        uses: docker/metadata-action@v5\n        with:\n          images: |\n            lmsysorg/sglang\n          # push with schedule event\n          # push with workflow_dispatch event\n          tags: |\n            type=ref,event=pr\n            type=ref,event=branch\n            type=schedule,pattern=main\n          flavor: |\n            latest=false\n            suffix=-cann${{ matrix.cann_version }}-${{ matrix.device_type }},onlatest=true\n      # Login against a Docker registry except on PR\n      # https://github.com/docker/login-action\n      - name: Log into docker hub\n        uses: docker/login-action@v3\n        if: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      # Enable Docker multi-architecture build environment\n      # Emulate non-native architectures\n      - name: Set up QEMU\n        uses: docker/setup-qemu-action@v3\n      # Required for building 
and pushing multi-arch Docker images\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      # Build and push Docker image with Buildx (don't push on PR)\n      # https://github.com/docker/build-push-action\n      - name: Build and push Docker image\n        id: build-and-push\n        uses: docker/build-push-action@v6\n        with:\n          context: docker\n          file: docker/npu.Dockerfile\n          platforms: linux/arm64,linux/amd64\n          labels: ${{ steps.meta.outputs.labels }}\n          tags: ${{ steps.meta.outputs.tags }}\n          push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}\n          provenance: false\n          build-args: |\n            SGLANG_KERNEL_NPU_TAG=2026.03.10.rc1\n            CANN_VERSION=${{ matrix.cann_version }}\n            DEVICE_TYPE=${{ matrix.device_type }}\n"
  },
  {
    "path": ".github/workflows/release-docker-npu.yml",
    "content": "name: Release Docker Images (NPU)\non:\n  push:\n    tags:\n      - 'v[0-9]+.*'\n  workflow_dispatch:\n    inputs:\n      version:\n        description: 'Version to build (without v prefix, e.g., 0.5.7)'\n        required: true\n\njobs:\n  build:\n    runs-on: ubuntu-22.04-arm\n    strategy:\n      matrix:\n        cann_version: [\"8.5.0\"]\n        device_type: [\"910b\", \"a3\"]\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Free up disk space\n        uses: jlumbroso/free-disk-space@54081f138730dfa15788a46383842cd2f914a1be # v1.3.1\n        with:\n          tool-cache: true\n          docker-images: false\n\n        # push with tag\n      - name: Docker meta\n        id: meta\n        uses: docker/metadata-action@v5\n        with:\n          images: |\n            lmsysorg/sglang\n          tags: |\n            type=ref,event=pr\n          flavor: |\n            latest=false\n\n      # Login against a Docker registry except on PR\n      # https://github.com/docker/login-action\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        if: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Get version from tag\n        id: version\n        run: |\n          if [ \"${{ github.event_name }}\" = \"workflow_dispatch\" ]; then\n            VERSION=\"${{ github.event.inputs.version }}\"\n          else\n            # Extract version from tag (e.g., v0.5.7 -> 0.5.7)\n            VERSION=\"${GITHUB_REF_NAME#v}\"\n          fi\n          # Validate version format\n          if [ -z \"$VERSION\" ]; then\n            echo \"::error::Version is empty\"\n            exit 1\n          fi\n          if ! 
echo \"$VERSION\" | grep -qE '^[0-9]+\\.[0-9]+\\.[0-9]+'; then\n            echo \"::error::Invalid version format: $VERSION (expected: X.Y.Z)\"\n            exit 1\n          fi\n          echo \"version=v${VERSION}\" >> $GITHUB_OUTPUT\n          echo \"TAG=lmsysorg/sglang:v${VERSION}-cann${{ matrix.cann_version }}-${{ matrix.device_type }}\" >> $GITHUB_OUTPUT\n      # Enable Docker multi-architecture build environment\n      # Emulate non-native architectures\n      - name: Set up QEMU\n        uses: docker/setup-qemu-action@v3\n      # Required for building and pushing multi-arch Docker images\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Build and push Docker image\n        id: build-and-push\n        uses: docker/build-push-action@v6\n        with:\n          context: docker\n          file: docker/npu.Dockerfile\n          platforms: linux/arm64,linux/amd64\n          labels: ${{ steps.meta.outputs.labels }}\n          tags: ${{ steps.meta.outputs.tags || steps.version.outputs.TAG }}\n          push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}\n          provenance: false\n          build-args: |\n            SGLANG_KERNEL_NPU_TAG=2026.03.10.rc1\n            CANN_VERSION=${{ matrix.cann_version }}\n            DEVICE_TYPE=${{ matrix.device_type }}\n            SGLANG_TAG=${{ steps.version.outputs.version }}\n"
  },
  {
    "path": ".github/workflows/release-docker-xeon.yml",
    "content": "name: Release Docker Xeon Images\non:\n  push:\n    tags:\n      - 'v[0-9]+.*'\n  workflow_dispatch:\n    inputs:\n      version:\n        description: 'Version to build (without v prefix, e.g., 0.5.7)'\n        required: true\n\njobs:\n  publish:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: ubuntu-24.04\n    environment: 'prod'\n    strategy:\n      matrix:\n        build_type: ['all']\n    steps:\n\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Get version from tag\n        id: version\n        run: |\n          if [ \"${{ github.event_name }}\" = \"workflow_dispatch\" ]; then\n            VERSION=\"${{ github.event.inputs.version }}\"\n          else\n            # Extract version from tag (e.g., v0.5.7 -> 0.5.7)\n            VERSION=\"${GITHUB_REF_NAME#v}\"\n          fi\n\n          # Validate version format\n          if [ -z \"$VERSION\" ]; then\n            echo \"::error::Version is empty\"\n            exit 1\n          fi\n          if ! echo \"$VERSION\" | grep -qE '^[0-9]+\\.[0-9]+\\.[0-9]+'; then\n            echo \"::error::Invalid version format: $VERSION (expected: X.Y.Z)\"\n            exit 1\n          fi\n\n          echo \"version=${VERSION}\" >> $GITHUB_OUTPUT\n\n      - name: Build and Push\n        run: |\n          version=${{ steps.version.outputs.version }}\n          tag=v${version}-xeon\n\n          docker build . -f docker/xeon.Dockerfile \\\n            --build-arg VER_SGLANG=v${version} \\\n            -t lmsysorg/sglang:${tag} \\\n            --no-cache\n          docker push lmsysorg/sglang:${tag}\n"
  },
  {
    "path": ".github/workflows/release-docker.yml",
    "content": "name: Release Docker Images\n#\n# This workflow builds and publishes both framework and runtime Docker images:\n#\n# Framework images (full development environment):\n#   - lmsysorg/sglang:v{version}, lmsysorg/sglang:latest\n#   - lmsysorg/sglang:v{version}-cu130, lmsysorg/sglang:latest-cu130\n#\n# Runtime images (production-optimized, ~50% smaller):\n#   - lmsysorg/sglang:v{version}-runtime, lmsysorg/sglang:latest-runtime\n#   - lmsysorg/sglang:v{version}-cu130-runtime, lmsysorg/sglang:latest-cu130-runtime\n#\non:\n  push:\n    tags:\n      - \"v[0-9]+.*\"\n  workflow_dispatch:\n    inputs:\n      version:\n        description: \"Version to build (without v prefix, e.g., 0.5.7)\"\n        required: true\n\njobs:\n  publish-x86:\n    if: github.repository == 'sgl-project/sglang'\n    environment: \"prod\"\n    strategy:\n      matrix:\n        variant:\n          - cuda_version: \"12.9.1\"\n            build_type: \"all\"\n            grace_blackwell: 0\n    runs-on: x64-docker-build-node\n    steps:\n      - name: Delete huge unnecessary tools folder\n        run: rm -rf /opt/hostedtoolcache\n\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          tool-cache: false\n          docker-images: false\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Get version from tag\n        id: version\n        run: |\n          if [ \"${{ github.event_name }}\" = \"workflow_dispatch\" ]; then\n            VERSION=\"${{ github.event.inputs.version }}\"\n          else\n 
           # Extract version from tag (e.g., v0.5.7 -> 0.5.7)\n            VERSION=\"${GITHUB_REF_NAME#v}\"\n          fi\n\n          # Validate version format\n          if [ -z \"$VERSION\" ]; then\n            echo \"::error::Version is empty\"\n            exit 1\n          fi\n          if ! echo \"$VERSION\" | grep -qE '^[0-9]+\\.[0-9]+\\.[0-9]+'; then\n            echo \"::error::Invalid version format: $VERSION (expected: X.Y.Z)\"\n            exit 1\n          fi\n\n          echo \"version=${VERSION}\" >> $GITHUB_OUTPUT\n\n      - name: Build AMD64 Framework\n        run: |\n          version=${{ steps.version.outputs.version }}\n\n          docker buildx build \\\n            --target framework \\\n            --platform linux/amd64 \\\n            --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \\\n            -f docker/Dockerfile \\\n            --build-arg CUDA_VERSION=${{ matrix.variant.cuda_version }} \\\n            --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \\\n            --build-arg GRACE_BLACKWELL=${{ matrix.variant.grace_blackwell }} \\\n            --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \\\n            --build-arg SGL_VERSION=${version} \\\n            --metadata-file /tmp/metadata-cu129-framework.json \\\n            --no-cache \\\n            .\n\n          DIGEST=$(python3 -c \"import json; print(json.load(open('/tmp/metadata-cu129-framework.json'))['containerimage.digest'])\")\n          echo \"Pushed digest: ${DIGEST}\"\n          echo \"${DIGEST}\" > /tmp/digest-cu129-amd64-framework.txt\n\n      - name: Build and Push AMD64 Runtime\n        run: |\n          version=${{ steps.version.outputs.version }}\n\n          docker buildx build \\\n            --target runtime \\\n            --platform linux/amd64 \\\n            --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \\\n            -f docker/Dockerfile \\\n            --build-arg 
CUDA_VERSION=${{ matrix.variant.cuda_version }} \\\n            --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \\\n            --build-arg GRACE_BLACKWELL=${{ matrix.variant.grace_blackwell }} \\\n            --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \\\n            --build-arg SGL_VERSION=${version} \\\n            --metadata-file /tmp/metadata-cu129-runtime.json \\\n            --no-cache \\\n            .\n\n          DIGEST=$(python3 -c \"import json; print(json.load(open('/tmp/metadata-cu129-runtime.json'))['containerimage.digest'])\")\n          echo \"Pushed digest: ${DIGEST}\"\n          echo \"${DIGEST}\" > /tmp/digest-cu129-amd64-runtime.txt\n\n      - name: Build and Push AMD64 Framework (CUDA 13)\n        run: |\n          version=${{ steps.version.outputs.version }}\n\n          docker buildx build \\\n            --target framework \\\n            --platform linux/amd64 \\\n            --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \\\n            -f docker/Dockerfile \\\n            --build-arg CUDA_VERSION=13.0.1 \\\n            --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \\\n            --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \\\n            --build-arg GRACE_BLACKWELL=0 \\\n            --build-arg SGL_VERSION=${version} \\\n            --metadata-file /tmp/metadata-cu130-framework.json \\\n            --no-cache \\\n            .\n\n          DIGEST=$(python3 -c \"import json; print(json.load(open('/tmp/metadata-cu130-framework.json'))['containerimage.digest'])\")\n          echo \"Pushed digest: ${DIGEST}\"\n          echo \"${DIGEST}\" > /tmp/digest-cu130-amd64-framework.txt\n\n      - name: Build and Push AMD64 Runtime (CUDA 13)\n        run: |\n          version=${{ steps.version.outputs.version }}\n\n          docker buildx build \\\n            --target runtime \\\n            --platform linux/amd64 \\\n            --output 
type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \\\n            -f docker/Dockerfile \\\n            --build-arg CUDA_VERSION=13.0.1 \\\n            --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \\\n            --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \\\n            --build-arg GRACE_BLACKWELL=0 \\\n            --build-arg SGL_VERSION=${version} \\\n            --metadata-file /tmp/metadata-cu130-runtime.json \\\n            --no-cache \\\n            .\n\n          DIGEST=$(python3 -c \"import json; print(json.load(open('/tmp/metadata-cu130-runtime.json'))['containerimage.digest'])\")\n          echo \"Pushed digest: ${DIGEST}\"\n          echo \"${DIGEST}\" > /tmp/digest-cu130-amd64-runtime.txt\n\n      - name: Upload digests\n        uses: actions/upload-artifact@v4\n        with:\n          name: digests-amd64\n          path: /tmp/digest-*.txt\n          retention-days: 1\n\n  publish-arm64:\n    if: github.repository == 'sgl-project/sglang'\n    environment: \"prod\"\n    strategy:\n      matrix:\n        variant:\n          - cuda_version: \"12.9.1\"\n            build_type: \"all\"\n            grace_blackwell: 1\n    runs-on: arm-docker-build-node\n    steps:\n      - name: Delete huge unnecessary tools folder\n        run: rm -rf /opt/hostedtoolcache\n\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Get version from tag\n        id: version\n        run: |\n          if [ \"${{ github.event_name }}\" = \"workflow_dispatch\" ]; then\n            VERSION=\"${{ github.event.inputs.version }}\"\n          else\n            # Extract version from tag (e.g., v0.5.7 -> 0.5.7)\n            
VERSION=\"${GITHUB_REF_NAME#v}\"\n          fi\n\n          # Validate version format\n          if [ -z \"$VERSION\" ]; then\n            echo \"::error::Version is empty\"\n            exit 1\n          fi\n          if ! echo \"$VERSION\" | grep -qE '^[0-9]+\\.[0-9]+\\.[0-9]+'; then\n            echo \"::error::Invalid version format: $VERSION (expected: X.Y.Z)\"\n            exit 1\n          fi\n\n          echo \"version=${VERSION}\" >> $GITHUB_OUTPUT\n\n      - name: Build ARM64 Framework\n        run: |\n          version=${{ steps.version.outputs.version }}\n\n          docker buildx build \\\n            --target framework \\\n            --platform linux/arm64 \\\n            --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \\\n            -f docker/Dockerfile \\\n            --build-arg CUDA_VERSION=${{ matrix.variant.cuda_version }} \\\n            --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \\\n            --build-arg GRACE_BLACKWELL=${{ matrix.variant.grace_blackwell }} \\\n            --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \\\n            --build-arg SGL_VERSION=${version} \\\n            --metadata-file /tmp/metadata-cu129-framework.json \\\n            --no-cache \\\n            .\n\n          DIGEST=$(python3 -c \"import json; print(json.load(open('/tmp/metadata-cu129-framework.json'))['containerimage.digest'])\")\n          echo \"Pushed digest: ${DIGEST}\"\n          echo \"${DIGEST}\" > /tmp/digest-cu129-arm64-framework.txt\n\n      - name: Build and Push ARM64 Runtime\n        run: |\n          version=${{ steps.version.outputs.version }}\n\n          docker buildx build \\\n            --target runtime \\\n            --platform linux/arm64 \\\n            --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \\\n            -f docker/Dockerfile \\\n            --build-arg CUDA_VERSION=${{ matrix.variant.cuda_version }} \\\n            --build-arg 
BUILD_TYPE=${{ matrix.variant.build_type }} \\\n            --build-arg GRACE_BLACKWELL=${{ matrix.variant.grace_blackwell }} \\\n            --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \\\n            --build-arg SGL_VERSION=${version} \\\n            --metadata-file /tmp/metadata-cu129-runtime.json \\\n            --no-cache \\\n            .\n\n          DIGEST=$(python3 -c \"import json; print(json.load(open('/tmp/metadata-cu129-runtime.json'))['containerimage.digest'])\")\n          echo \"Pushed digest: ${DIGEST}\"\n          echo \"${DIGEST}\" > /tmp/digest-cu129-arm64-runtime.txt\n\n      - name: Build and Push ARM64 Framework (CUDA 13)\n        run: |\n          version=${{ steps.version.outputs.version }}\n\n          docker buildx build \\\n            --target framework \\\n            --platform linux/arm64 \\\n            --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \\\n            -f docker/Dockerfile \\\n            --build-arg CUDA_VERSION=13.0.1 \\\n            --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \\\n            --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \\\n            --build-arg GRACE_BLACKWELL=1 \\\n            --build-arg SGL_VERSION=${version} \\\n            --metadata-file /tmp/metadata-cu130-framework.json \\\n            --no-cache \\\n            .\n\n          DIGEST=$(python3 -c \"import json; print(json.load(open('/tmp/metadata-cu130-framework.json'))['containerimage.digest'])\")\n          echo \"Pushed digest: ${DIGEST}\"\n          echo \"${DIGEST}\" > /tmp/digest-cu130-arm64-framework.txt\n\n      - name: Build and Push ARM64 Runtime (CUDA 13)\n        run: |\n          version=${{ steps.version.outputs.version }}\n\n          docker buildx build \\\n            --target runtime \\\n            --platform linux/arm64 \\\n            --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \\\n            -f docker/Dockerfile \\\n    
        --build-arg CUDA_VERSION=13.0.1 \\\n            --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \\\n            --build-arg GRACE_BLACKWELL=1 \\\n            --build-arg SGL_VERSION=${version} \\\n            --metadata-file /tmp/metadata-cu130-runtime.json \\\n            --no-cache \\\n            .\n\n          DIGEST=$(python3 -c \"import json; print(json.load(open('/tmp/metadata-cu130-runtime.json'))['containerimage.digest'])\")\n          echo \"Pushed digest: ${DIGEST}\"\n          echo \"${DIGEST}\" > /tmp/digest-cu130-arm64-runtime.txt\n\n      - name: Upload digests\n        uses: actions/upload-artifact@v4\n        with:\n          name: digests-arm64\n          path: /tmp/digest-*.txt\n          retention-days: 1\n\n  create-manifests:\n    runs-on: ubuntu-22.04\n    needs: [publish-x86, publish-arm64]\n    if: github.repository == 'sgl-project/sglang'\n    environment: \"prod\"\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Get version from tag\n        id: version\n        run: |\n          if [ \"${{ github.event_name }}\" = \"workflow_dispatch\" ]; then\n            VERSION=\"${{ github.event.inputs.version }}\"\n          else\n            # Extract version from tag (e.g., v0.5.7 -> 0.5.7)\n            VERSION=\"${GITHUB_REF_NAME#v}\"\n          fi\n\n          # Validate version format\n          if [ -z \"$VERSION\" ]; then\n            echo \"::error::Version is empty\"\n            exit 1\n          fi\n          if ! 
echo \"$VERSION\" | grep -qE '^[0-9]+\\.[0-9]+\\.[0-9]+'; then\n            echo \"::error::Invalid version format: $VERSION (expected: X.Y.Z)\"\n            exit 1\n          fi\n\n          echo \"version=${VERSION}\" >> $GITHUB_OUTPUT\n\n      - name: Download amd64 digests\n        uses: actions/download-artifact@v4\n        with:\n          name: digests-amd64\n          path: /tmp/digests/amd64\n\n      - name: Download arm64 digests\n        uses: actions/download-artifact@v4\n        with:\n          name: digests-arm64\n          path: /tmp/digests/arm64\n\n      - name: Create multi-arch manifests\n        run: |\n          version=${{ steps.version.outputs.version }}\n\n          # Load all digests\n          CU129_AMD64_FW=$(cat /tmp/digests/amd64/digest-cu129-amd64-framework.txt)\n          CU129_AMD64_RT=$(cat /tmp/digests/amd64/digest-cu129-amd64-runtime.txt)\n          CU130_AMD64_FW=$(cat /tmp/digests/amd64/digest-cu130-amd64-framework.txt)\n          CU130_AMD64_RT=$(cat /tmp/digests/amd64/digest-cu130-amd64-runtime.txt)\n          CU129_ARM64_FW=$(cat /tmp/digests/arm64/digest-cu129-arm64-framework.txt)\n          CU129_ARM64_RT=$(cat /tmp/digests/arm64/digest-cu129-arm64-runtime.txt)\n          CU130_ARM64_FW=$(cat /tmp/digests/arm64/digest-cu130-arm64-framework.txt)\n          CU130_ARM64_RT=$(cat /tmp/digests/arm64/digest-cu130-arm64-runtime.txt)\n\n          # Create versioned framework manifest (default)\n          docker buildx imagetools create \\\n            -t lmsysorg/sglang:v${version} \\\n            lmsysorg/sglang@${CU129_AMD64_FW} \\\n            lmsysorg/sglang@${CU129_ARM64_FW}\n\n          # Create latest framework manifest (default)\n          docker buildx imagetools create \\\n            -t lmsysorg/sglang:latest \\\n            lmsysorg/sglang@${CU129_AMD64_FW} \\\n            lmsysorg/sglang@${CU129_ARM64_FW}\n\n          # Create versioned runtime manifest\n          docker buildx imagetools create \\\n            -t 
lmsysorg/sglang:v${version}-runtime \\\n            lmsysorg/sglang@${CU129_AMD64_RT} \\\n            lmsysorg/sglang@${CU129_ARM64_RT}\n\n          # Create latest runtime manifest\n          docker buildx imagetools create \\\n            -t lmsysorg/sglang:latest-runtime \\\n            lmsysorg/sglang@${CU129_AMD64_RT} \\\n            lmsysorg/sglang@${CU129_ARM64_RT}\n\n          # Create versioned CUDA 13 framework manifest\n          docker buildx imagetools create \\\n            -t lmsysorg/sglang:v${version}-cu130 \\\n            lmsysorg/sglang@${CU130_AMD64_FW} \\\n            lmsysorg/sglang@${CU130_ARM64_FW}\n\n          # Create latest CUDA 13 framework manifest\n          docker buildx imagetools create \\\n            -t lmsysorg/sglang:latest-cu130 \\\n            lmsysorg/sglang@${CU130_AMD64_FW} \\\n            lmsysorg/sglang@${CU130_ARM64_FW}\n\n          # Create versioned CUDA 13 runtime manifest\n          docker buildx imagetools create \\\n            -t lmsysorg/sglang:v${version}-cu130-runtime \\\n            lmsysorg/sglang@${CU130_AMD64_RT} \\\n            lmsysorg/sglang@${CU130_ARM64_RT}\n\n          # Create latest CUDA 13 runtime manifest\n          docker buildx imagetools create \\\n            -t lmsysorg/sglang:latest-cu130-runtime \\\n            lmsysorg/sglang@${CU130_AMD64_RT} \\\n            lmsysorg/sglang@${CU130_ARM64_RT}\n"
  },
  {
    "path": ".github/workflows/release-docs.yml",
    "content": "name: Release Documentation\n\non:\n  release:\n    types: [published]\n  push:\n    branches:\n      - main\n    paths:\n      - \"docs/**\"\n      - \"python/sglang/version.py\"\n      - \"python/sglang/**\"\n  workflow_dispatch:\n\nconcurrency:\n  group: release-docs-${{ github.ref }}\n  cancel-in-progress: true\n\nenv:\n  SGLANG_IS_IN_CI: true\n\njobs:\n  execute-and-deploy:\n    runs-on: 1-gpu-runner\n    if: github.repository == 'sgl-project/sglang'\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Fetch full git history for release index\n        if: github.event_name == 'release'\n        run: |\n          git fetch --prune --unshallow || git fetch --prune --depth=0\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n          pip install -r docs/requirements.txt\n          apt-get update && apt-get install -y pandoc parallel retry\n          ln -sf \"$(which python3)\" /usr/bin/python\n\n      - name: Setup Jupyter Kernel\n        run: |\n          python -m ipykernel install --user --name python3 --display-name \"Python 3\"\n\n      - name: Execute notebooks\n        timeout-minutes: 40\n        run: |\n          cd docs\n          make clean\n          make compile\n\n      - name: Push HTML to sgl-project.github.io\n        timeout-minutes: 30\n        env:\n          GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_DOCUMENTATION }}\n        run: |\n          cd docs\n          make html\n          make markdown\n          python3 wrap_run_llm.py\n\n          if [[ \"${{ github.event_name }}\" == \"release\" ]]; then\n            python3 release_lookup/generate_index.py --output release_lookup/release_index.json\n\n            # Copy release lookup tool for official docs on published releases.\n            mkdir -p _build/html/release_lookup\n            cp release_lookup/index.html _build/html/release_lookup/\n            cp 
release_lookup/release_index.json _build/html/release_lookup/\n          fi\n\n          cd _build/html\n\n          git clone https://$GITHUB_TOKEN@github.com/sgl-project/sgl-project.github.io.git ../sgl-project.github.io --depth 1\n          if [[ \"${{ github.event_name }}\" == \"release\" ]]; then\n            find ../sgl-project.github.io/ -mindepth 1 -not -path \"../sgl-project.github.io/.git*\" -not -name CNAME -not -name \".jekyll\" -not -name \".nojekyll\" -delete\n          else\n            find ../sgl-project.github.io/ -mindepth 1 -not -path \"../sgl-project.github.io/.git*\" -not -path \"../sgl-project.github.io/release_lookup*\" -not -name CNAME -not -name \".jekyll\" -not -name \".nojekyll\" -delete\n          fi\n          cp -r * ../sgl-project.github.io\n          cp ../../README.md ../sgl-project.github.io/README.md\n          cd ../sgl-project.github.io\n          git config user.name \"sglang-bot\"\n          git config user.email \"sglangbot@gmail.com\"\n          git add .\n          git commit -m \"Update $(date +'%Y-%m-%d %H:%M:%S')\"\n          git push https://$GITHUB_TOKEN@github.com/sgl-project/sgl-project.github.io.git main\n          cd ..\n          rm -rf sgl-project.github.io\n"
  },
  {
    "path": ".github/workflows/release-pypi-gateway.yml",
    "content": "name: Release SGLang Model Gateway to PyPI\n\non:\n  push:\n    branches:\n      - main\n    paths:\n      - sgl-model-gateway/bindings/python/pyproject.toml\n  workflow_dispatch:\n\njobs:\n  build:\n    name: build on ${{ matrix.platform || matrix.os }} (${{ matrix.target }} - ${{ matrix.manylinux || 'auto' }})\n    runs-on: ${{ matrix.os }}-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        os: [ubuntu, macos, windows]\n        target: [x86_64, aarch64]\n        manylinux: [auto]\n        include:\n          - os: ubuntu\n            platform: linux\n          - os: windows\n            ls: dir\n            target: x86_64\n            python-architecture: x64\n            interpreter: 3.9 3.10 3.11 3.12 3.13\n          - os: macos\n            target: aarch64\n            interpreter: 3.9 3.10 3.11 3.12 3.13\n          - os: ubuntu\n            platform: linux\n            target: aarch64\n          # musllinux\n          - os: ubuntu\n            platform: linux\n            target: x86_64\n            manylinux: musllinux_1_1\n          - os: ubuntu\n            platform: linux\n            target: aarch64\n            manylinux: musllinux_1_1\n        exclude:\n          - os: windows\n            target: aarch64\n\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          path: sglang-repo\n\n      - name: Move sgl-model-gateway folder to root and delete sglang-repo\n        run: |\n          mv sglang-repo/sgl-model-gateway/* .\n          rm -rf sglang-repo\n          ls -alt\n        shell: bash\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.13\"\n          architecture: ${{ matrix.python-architecture || 'x64' }}\n\n      - name: Install twine\n        run: pip install -U twine\n\n      - name: Install protoc (macOS)\n        if: matrix.os == 'macos'\n        run: brew install protobuf\n\n      - name: Install protoc (Windows)\n        if: 
matrix.os == 'windows'\n        run: choco install protoc -y\n\n      - name: Build wheels\n        uses: PyO3/maturin-action@v1\n        with:\n          working-directory: bindings/python\n          target: ${{ matrix.target }}\n          manylinux: ${{ matrix.manylinux || 'auto' }}\n          args: --release --out dist --features vendored-openssl --interpreter ${{ matrix.interpreter || '3.9 3.10 3.11 3.12 3.13 3.14' }}\n          rust-toolchain: stable\n          docker-options: -e CI -e CC_aarch64_unknown_linux_gnu=aarch64-linux-gnu-gcc -e CXX_aarch64_unknown_linux_gnu=aarch64-linux-gnu-g++\n          before-script-linux: |\n            # Install build dependencies (perl/make for vendored OpenSSL, protoc for gRPC)\n            if command -v yum &> /dev/null; then\n              yum update -y && yum install -y wget unzip gcc gcc-c++ perl-core make\n              # Install cross-compilation toolchain for aarch64 if needed\n              if [ \"${{ matrix.target }}\" = \"aarch64\" ]; then\n                yum install -y gcc-aarch64-linux-gnu gcc-c++-aarch64-linux-gnu || true\n              fi\n            elif command -v apt-get &> /dev/null; then\n              apt-get update && apt-get install -y wget unzip gcc g++ perl make\n              # Install cross-compilation toolchain for aarch64 if needed\n              if [ \"${{ matrix.target }}\" = \"aarch64\" ]; then\n                apt-get install -y gcc-aarch64-linux-gnu g++-aarch64-linux-gnu || true\n              fi\n            fi\n            (cd /tmp && \\\n             wget https://github.com/protocolbuffers/protobuf/releases/download/v32.0/protoc-32.0-linux-x86_64.zip && \\\n             unzip protoc-32.0-linux-x86_64.zip -d /usr/local && \\\n             rm protoc-32.0-linux-x86_64.zip)\n            protoc --version\n\n      - name: List built packages\n        run: ${{ matrix.ls || 'ls -lh' }} bindings/python/dist/\n\n      - name: Check packages\n        run: twine check --strict 
bindings/python/dist/*\n\n      - uses: actions/upload-artifact@v4\n        with:\n          name: packages-${{ matrix.os }}-${{ matrix.target }}-${{ matrix.manylinux || 'auto' }}\n          path: bindings/python/dist/\n\n  build-sdist:\n    name: Build SDist\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          path: sglang-repo\n\n      - name: Move sgl-model-gateway folder to root and delete sglang-repo\n        run: |\n          mv sglang-repo/sgl-model-gateway/* .\n          rm -rf sglang-repo\n          ls -alt\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.13\"\n\n      - name: Build SDist\n        uses: PyO3/maturin-action@v1\n        with:\n          working-directory: bindings/python\n          command: sdist\n          args: --out dist\n          rust-toolchain: stable\n\n      - uses: actions/upload-artifact@v4\n        with:\n          name: sdist\n          path: bindings/python/dist/*.tar.gz\n\n  upload:\n    name: Upload to PyPI\n    if: github.repository == 'sgl-project/sglang'  # Ensure this job only runs for the sgl-project/sglang repository\n    needs: [build, build-sdist]\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/download-artifact@v4\n        with:\n          path: dist\n          merge-multiple: true\n\n      - name: Upload to PyPI\n        env:\n          TWINE_USERNAME: __token__\n          TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN_ROUTER }}\n        run: |\n          pip install twine\n          twine upload dist/* --verbose\n"
  },
  {
    "path": ".github/workflows/release-pypi-nightly.yml",
    "content": "name: Release PyPI Nightly Wheels\n\non:\n  # Run daily at 2 AM UTC\n  schedule:\n    - cron: '0 2 * * *'\n  # Triggered by nightly Docker workflow to use same commit\n  repository_dispatch:\n    types: [nightly-release]\n  # Manual trigger for testing\n  workflow_dispatch:\n    inputs:\n      commit_sha:\n        description: 'Specific commit SHA to build (leave empty for latest)'\n        required: false\n        type: string\n      cuda_version:\n        description: 'CUDA version (e.g., 129 or 130)'\n        required: false\n        default: '129'\n        type: string\n\nconcurrency:\n  group: release-pypi-nightly-${{ github.ref }}\n  cancel-in-progress: true\n\njobs:\n  build-nightly-wheel:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: ubuntu-latest\n    outputs:\n      nightly_version: ${{ steps.build.outputs.nightly_version }}\n      commit_hash: ${{ steps.build.outputs.commit_hash }}\n      build_date: ${{ steps.build.outputs.build_date }}\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          # Use commit from: 1) Docker workflow, 2) manual input, 3) latest main\n          ref: ${{ github.event.client_payload.commit_sha || inputs.commit_sha || github.sha }}\n          fetch-depth: 0  # Need full history for setuptools-scm\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.10\"\n\n      - name: Install build dependencies\n        run: |\n          pip install build wheel setuptools setuptools-scm\n\n      - name: Build wheel\n        id: build\n        run: |\n          cd python\n          cp ../README.md ../LICENSE .\n\n          # Parse git describe output to get latest tag\n          # Use same command as pyproject.toml to ensure version consistency\n          DESC=$(git tag --list --sort=-version:refname 'v*.*.*' | head -1 | xargs git describe --tags --long 2>/dev/null || echo 'v0.0.0-0-g0000000')\n          TAG=$(echo \"$DESC\" | cut 
-d- -f1)\n          HASH=\"g$(git rev-parse --short HEAD)\"\n          BUILD_DATE=$(date -u +%Y%m%d)\n\n          # Increment patch version for nightlies (e.g., v0.5.8 -> 0.5.9)\n          VERSION=${TAG#v}  # Remove 'v' prefix\n          MAJOR=$(echo \"$VERSION\" | cut -d. -f1)\n          MINOR=$(echo \"$VERSION\" | cut -d. -f2)\n          PATCH=$(echo \"$VERSION\" | cut -d. -f3)\n          NEXT_PATCH=$((PATCH + 1))\n          NEXT_VERSION=\"${MAJOR}.${MINOR}.${NEXT_PATCH}\"\n\n          # Use date-based dev number for correct chronological sorting\n          # e.g., 0.5.9.dev20260215+g4cf4f0859 > 0.5.9.dev20260214+g45a4697d4\n          FORCE_VERSION=\"${NEXT_VERSION}.dev${BUILD_DATE}+${HASH}\"\n          echo \"Forcing nightly version to: $FORCE_VERSION\"\n          export SETUPTOOLS_SCM_PRETEND_VERSION=\"$FORCE_VERSION\"\n\n          # Build wheel\n          python3 -m build --wheel\n\n          # Extract version from built wheel filename\n          WHEEL_FILE=$(ls dist/*.whl)\n          NIGHTLY_VERSION=$(echo \"$WHEEL_FILE\" | sed 's/.*sglang-\\(.*\\)-py3.*/\\1/')\n\n          # Get commit info\n          COMMIT_HASH=$(git rev-parse --short HEAD)\n          BUILD_DATE=$(date -u +%Y-%m-%d)\n\n          echo \"Built wheel: $WHEEL_FILE\"\n          echo \"Nightly version: ${NIGHTLY_VERSION}\"\n          echo \"Commit: ${COMMIT_HASH}\"\n          echo \"Build date: ${BUILD_DATE}\"\n\n          echo \"nightly_version=${NIGHTLY_VERSION}\" >> $GITHUB_OUTPUT\n          echo \"commit_hash=${COMMIT_HASH}\" >> $GITHUB_OUTPUT\n          echo \"build_date=${BUILD_DATE}\" >> $GITHUB_OUTPUT\n\n      - name: Upload wheel artifact\n        uses: actions/upload-artifact@v4\n        with:\n          name: nightly-wheel\n          path: python/dist/*.whl\n          retention-days: 7\n\n  release-nightly:\n    needs: build-nightly-wheel\n    runs-on: ubuntu-latest\n    environment: 'prod'\n    steps:\n      - uses: actions/checkout@v4\n\n      - name: Download wheel artifact\n        uses: actions/download-artifact@v4\n        with:\n          name: nightly-wheel\n          path: dist/\n\n      - name: List downloaded wheels\n        run: |\n          echo \"Downloaded wheel:\"\n          ls -lh dist/\n\n      - name: Create GitHub Release for nightly wheel\n        uses: softprops/action-gh-release@v2\n        with:\n          tag_name: nightly-${{ needs.build-nightly-wheel.outputs.build_date }}-${{ needs.build-nightly-wheel.outputs.commit_hash }}\n          name: Nightly Build ${{ needs.build-nightly-wheel.outputs.build_date }} (${{ needs.build-nightly-wheel.outputs.commit_hash }})\n          repository: sgl-project/whl\n          token: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n          prerelease: true\n          body: |\n            Nightly build from commit ${{ needs.build-nightly-wheel.outputs.commit_hash }}\n            Build date: ${{ needs.build-nightly-wheel.outputs.build_date }}\n            Version: ${{ needs.build-nightly-wheel.outputs.nightly_version }}\n          files: |\n            dist/*.whl\n\n      - name: Clone wheel index repository\n        run: |\n          git clone https://oauth2:${WHL_TOKEN}@github.com/sgl-project/whl.git sgl-whl\n          cd sgl-whl\n          git config --local user.name \"sglang-bot\"\n          git config --local user.email \"sglangbot@gmail.com\"\n        env:\n          WHL_TOKEN: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.10\"\n\n      - name: Update wheel index\n        run: |\n          python3 scripts/update_nightly_whl_index.py \\\n            --commit-hash ${{ needs.build-nightly-wheel.outputs.commit_hash }} \\\n            --nightly-version ${{ needs.build-nightly-wheel.outputs.nightly_version }} \\\n            --cuda-version ${{ inputs.cuda_version || '129' }} \\\n            --build-date ${{ needs.build-nightly-wheel.outputs.build_date }}\n\n      - name: Push wheel index\n        run: |\n          cd sgl-whl\n          git add -A\n          git diff --staged --quiet || git commit -m \"Update nightly wheel index for commit ${{ needs.build-nightly-wheel.outputs.commit_hash }}\"\n          git push\n"
  },
  {
    "path": ".github/workflows/release-pypi-pr.yml",
    "content": "name: Release PyPI PR Wheels\n\non:\n  workflow_dispatch:\n    inputs:\n      pr_number:\n        description: 'PR number to build wheel for (works with both internal and fork PRs)'\n        required: true\n        type: string\n\nconcurrency:\n  group: build-pr-wheel-${{ github.event.inputs.pr_number }}\n  cancel-in-progress: true\n\njobs:\n  build-pr-wheel:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: ubuntu-latest\n    outputs:\n      wheel_version: ${{ steps.gen_version.outputs.wheel_version }}\n      commit_hash: ${{ steps.gen_version.outputs.commit_hash }}\n      build_date: ${{ steps.gen_version.outputs.build_date }}\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          ref: refs/pull/${{ inputs.pr_number }}/head\n          fetch-depth: 0  # Need full history for version generation\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.10\"\n\n      - name: Generate PR wheel version\n        id: gen_version\n        run: |\n          # Get base version from the latest v*.*.* git tag directly\n          # Note: We cannot use setuptools_scm here because the [tool.setuptools_scm]\n          # config (with custom git_describe_command) lives in python/pyproject.toml,\n          # not at the repo root. 
Without that config, setuptools_scm falls back to\n          # default git describe which finds gateway-* tags instead of v*.*.* release tags.\n          LATEST_TAG=$(git tag --list --sort=-version:refname 'v*.*.*' | head -1)\n          BASE_VERSION=${LATEST_TAG#v}\n          echo \"Latest release tag: ${LATEST_TAG}\"\n\n          # Get commit info\n          COMMIT_HASH=$(git rev-parse --short HEAD)\n          COMMIT_COUNT=$(git rev-list --count HEAD)\n\n          # Get current date in YYYY-MM-DD format\n          BUILD_DATE=$(date -u +%Y-%m-%d)\n\n          # Always use pr-{number} format for suffix\n          SUFFIX=\"pr-${{ inputs.pr_number }}\"\n\n          # Generate PR wheel version following PEP 440\n          # Format: {base_version}.dev{commit_count}+pr-{number}.g{commit_hash}\n          WHEEL_VERSION=\"${BASE_VERSION}.dev${COMMIT_COUNT}+${SUFFIX}.g${COMMIT_HASH}\"\n\n          echo \"Base version: ${BASE_VERSION}\"\n          echo \"PR wheel version: ${WHEEL_VERSION}\"\n          echo \"Commit: ${COMMIT_HASH}\"\n          echo \"Build date: ${BUILD_DATE}\"\n\n          echo \"wheel_version=${WHEEL_VERSION}\" >> $GITHUB_OUTPUT\n          echo \"commit_hash=${COMMIT_HASH}\" >> $GITHUB_OUTPUT\n          echo \"base_version=${BASE_VERSION}\" >> $GITHUB_OUTPUT\n          echo \"build_date=${BUILD_DATE}\" >> $GITHUB_OUTPUT\n\n      - name: Update pyproject.toml with PR wheel version\n        run: |\n          cd python\n          WHEEL_VERSION=\"${{ steps.gen_version.outputs.wheel_version }}\"\n\n          # Update pyproject.toml to use static version instead of dynamic\n          # Remove 'version' from dynamic list and add static version\n          sed -i 's/dynamic = \\[\"version\"\\]/dynamic = []/' pyproject.toml\n          sed -i \"/^name = \\\"sglang\\\"/a version = \\\"${WHEEL_VERSION}\\\"\" pyproject.toml\n\n          # Verify update\n          echo \"Updated version in pyproject.toml:\"\n          grep \"^version\" pyproject.toml\n          grep 
\"^dynamic\" pyproject.toml\n\n      - name: Install build dependencies\n        run: |\n          cd python\n          pip install build wheel setuptools\n\n      - name: Build wheel\n        run: |\n          cd python\n          cp ../README.md ../LICENSE .\n          python3 -m build --wheel\n\n          # List built wheels\n          echo \"Built wheel:\"\n          ls -lh dist/\n\n      - name: Upload wheel artifact\n        uses: actions/upload-artifact@v4\n        with:\n          name: pr-wheel-${{ inputs.pr_number }}\n          path: python/dist/*.whl\n          retention-days: 30\n\n  release-pr-wheel:\n    needs: build-pr-wheel\n    runs-on: ubuntu-latest\n    environment: 'prod'\n    steps:\n      - uses: actions/checkout@v4\n\n      - name: Download wheel artifact\n        uses: actions/download-artifact@v4\n        with:\n          name: pr-wheel-${{ inputs.pr_number }}\n          path: dist/\n\n      - name: List downloaded wheels\n        run: |\n          echo \"Downloaded wheel:\"\n          ls -lh dist/\n\n      - name: Create GitHub Release for PR wheel\n        uses: softprops/action-gh-release@v2\n        with:\n          tag_name: pr-${{ inputs.pr_number }}-${{ needs.build-pr-wheel.outputs.build_date }}-${{ needs.build-pr-wheel.outputs.commit_hash }}\n          name: \"PR #${{ inputs.pr_number }} Build (${{ needs.build-pr-wheel.outputs.commit_hash }})\"\n          repository: sgl-project/whl\n          token: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n          prerelease: true\n          body: |\n            PR wheel build from PR #${{ inputs.pr_number }}\n            Commit: ${{ needs.build-pr-wheel.outputs.commit_hash }}\n            Build date: ${{ needs.build-pr-wheel.outputs.build_date }}\n            Version: ${{ needs.build-pr-wheel.outputs.wheel_version }}\n\n            **Installation via index (pip):**\n            ```bash\n            pip install sglang==${{ needs.build-pr-wheel.outputs.wheel_version }} --index-url https://sgl-project.github.io/whl/pr/\n            ```\n\n            **Installation via index (uv):**\n            ```bash\n            uv pip install sglang==${{ needs.build-pr-wheel.outputs.wheel_version }} --index-url https://sgl-project.github.io/whl/pr/ --extra-index-url https://pypi.org/simple --index-strategy unsafe-best-match\n            ```\n\n            **Direct installation:**\n            ```bash\n            pip install https://github.com/sgl-project/whl/releases/download/pr-${{ inputs.pr_number }}-${{ needs.build-pr-wheel.outputs.build_date }}-${{ needs.build-pr-wheel.outputs.commit_hash }}/sglang-${{ needs.build-pr-wheel.outputs.wheel_version }}-py3-none-any.whl\n            ```\n          files: |\n            dist/*.whl\n\n      - name: Clone wheel index repository\n        run: |\n          git clone https://oauth2:${WHL_TOKEN}@github.com/sgl-project/whl.git sgl-whl\n          cd sgl-whl\n          git config --local user.name \"sglang-bot\"\n          git config --local user.email \"sglangbot@gmail.com\"\n        env:\n          WHL_TOKEN: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.10\"\n\n      - name: Update wheel index\n        run: |\n          python3 scripts/update_pr_whl_index.py \\\n            --pr-number ${{ inputs.pr_number }} \\\n            --commit-hash ${{ needs.build-pr-wheel.outputs.commit_hash }} \\\n            --wheel-version ${{ needs.build-pr-wheel.outputs.wheel_version }} \\\n            --build-date ${{ needs.build-pr-wheel.outputs.build_date }}\n\n      - name: Push wheel index\n        run: |\n          cd sgl-whl\n          git add -A\n          git diff --staged --quiet || git commit -m \"Update PR wheel index for PR #${{ inputs.pr_number }} (commit ${{ needs.build-pr-wheel.outputs.commit_hash }})\"\n          git push\n"
  },
  {
    "path": ".github/workflows/release-pypi.yml",
    "content": "name: Release PyPI\non:\n  push:\n    tags:\n      - 'v[0-9]+.*'\n  workflow_dispatch:\n\njobs:\n  publish:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: ubuntu-latest\n    environment: \"prod\"\n    steps:\n      - name: Set up Python\n        uses: actions/setup-python@v4\n        with:\n          python-version: \"3.10\"\n\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0  # Required for setuptools-scm to determine version from tags\n\n      - name: Upload to pypi\n        run: |\n          cd python\n          cp ../README.md ../LICENSE .\n          pip install build wheel setuptools setuptools-scm\n          python3 -m build\n          pip install twine\n          python3 -m twine upload dist/* -u __token__ -p ${{ secrets.PYPI_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/release-tag.yml",
    "content": "name: Release Tag\n# Creates a git tag to trigger release workflows (PyPI, Docker)\n# Use this after testing on a release branch is complete\non:\n  workflow_dispatch:\n    inputs:\n      version:\n        description: 'Version to tag (without v prefix, e.g., 0.5.7)'\n        required: true\n        type: string\n      ref:\n        description: 'Branch or commit to tag (e.g., release/v0.5.7, main, or commit SHA)'\n        required: false\n        default: 'main'\n        type: string\n\npermissions:\n  contents: write\n\njobs:\n  create-tag:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: ubuntu-latest\n    environment: 'prod'\n    steps:\n      - name: Validate version format\n        run: |\n          VERSION=\"${{ github.event.inputs.version }}\"\n          if [ -z \"$VERSION\" ]; then\n            echo \"::error::Version is required\"\n            exit 1\n          fi\n          if ! echo \"$VERSION\" | grep -qE '^[0-9]+\\.[0-9]+\\.[0-9]+'; then\n            echo \"::error::Invalid version format: $VERSION (expected: X.Y.Z or X.Y.Z.postN)\"\n            exit 1\n          fi\n          echo \"Version validated: v$VERSION\"\n\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ github.event.inputs.ref }}\n          fetch-depth: 0\n          token: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Check if tag already exists\n        run: |\n          TAG=\"v${{ github.event.inputs.version }}\"\n          if git rev-parse \"$TAG\" >/dev/null 2>&1; then\n            echo \"::error::Tag $TAG already exists\"\n            exit 1\n          fi\n          echo \"Tag $TAG does not exist, proceeding...\"\n\n      - name: Create and push tag\n        run: |\n          TAG=\"v${{ github.event.inputs.version }}\"\n          REF=\"${{ github.event.inputs.ref }}\"\n\n          git config user.name \"sglang-bot\"\n          git config user.email \"sglang-bot@users.noreply.github.com\"\n\n          
echo \"Creating tag $TAG on ref $REF (commit: $(git rev-parse HEAD))\"\n          git tag -a \"$TAG\" -m \"Release $TAG\"\n          git push origin \"$TAG\"\n\n          echo \"::notice::Successfully created and pushed tag $TAG\"\n          echo \"This will trigger the release workflows (PyPI, Docker)\"\n"
  },
  {
    "path": ".github/workflows/release-whl-kernel.yml",
    "content": "name: Release SGLang Kernels\n\non:\n  push:\n    branches:\n      - main\n    paths:\n      - sgl-kernel/python/sgl_kernel/version.py\n  workflow_dispatch:\n    inputs:\n      target:\n        type: choice\n        description: 'Build target'\n        required: false\n        default: 'all'\n        options:\n          - 'all'\n          - 'cu129'\n          - 'cu130'\n          - 'rocm700'\n          - 'rocm720'\n          - 'musa43'\n      tag_name:\n        type: string\n        required: false\n      pr_number:\n        description: \"PR number to build from (e.g. 12345)\"\n        type: string\n        required: false\n\nconcurrency:\n  group: release-sglang-kernels-${{ github.ref }}\n  cancel-in-progress: true\n\njobs:\n  build-cu129-matrix:\n    if: |\n      github.repository == 'sgl-project/sglang' &&\n      (github.event.inputs.target == 'all' || github.event.inputs.target == 'cu129')\n    strategy:\n      matrix:\n        python-version: [\"3.10\"]\n        cuda-version: [\"12.9\"]\n        arch: [x86_64, aarch64]\n        include:\n          - arch: x86_64\n            runner: x64-kernel-build-node\n          - arch: aarch64\n            runner: arm-kernel-build-node\n    runs-on: ${{ matrix.runner }}\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          submodules: \"recursive\"\n          ref: ${{ inputs.pr_number && format('refs/pull/{0}/head', inputs.pr_number) || '' }}\n\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Build wheels\n        run: |\n          cd sgl-kernel\n          chmod +x ./build.sh\n          ./build.sh \"${{ matrix.python-version }}\" \"${{ matrix.cuda-version }}\" ${{ matrix.arch == 'aarch64' && 'aarch64' || '' }}\n        env:\n          BUILD_JOBS: 64\n          NVCC_THREADS: 8\n\n      - name: Upload to PyPI\n        working-directory: sgl-kernel\n        
run: |\n          pip install twine\n          python3 -m twine upload --skip-existing dist/* -u __token__ -p ${{ secrets.PYPI_TOKEN_SGLANG_KERNEL }}\n\n      - name: Upload artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }}${{ matrix.arch == 'aarch64' && '-aarch64' || '' }}\n          path: sgl-kernel/dist/*\n\n  release-cu129:\n    needs: build-cu129-matrix\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_number && format('refs/pull/{0}/head', inputs.pr_number) || '' }}\n\n      - name: Download artifacts\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-*\n\n      - name: Set tag name\n        id: set_tag_name\n        run: |\n          if [ -z \"${{ inputs.tag_name }}\" ]; then\n            TAG_NAME=\"v$(cat sgl-kernel/python/sgl_kernel/version.py | cut -d'\"' -f2)\"\n            echo \"tag_name=$TAG_NAME\" >> $GITHUB_OUTPUT\n          else\n            echo \"tag_name=${{ inputs.tag_name }}\" >> $GITHUB_OUTPUT\n          fi\n\n      - name: Release\n        uses: softprops/action-gh-release@v2\n        with:\n          tag_name: ${{ steps.set_tag_name.outputs.tag_name }}\n          repository: sgl-project/whl\n          token: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n          files: |\n            sgl-kernel/dist/*\n\n      - name: Clone wheel index\n        run: git clone https://oauth2:${WHL_TOKEN}@github.com/sgl-project/whl.git sgl-whl\n        env:\n          WHL_TOKEN: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n\n      - name: Update wheel index\n        run: python3 scripts/update_kernel_whl_index.py --cuda 129\n\n      - name: Push wheel index\n        run: |\n          cd sgl-whl\n          git config --local user.name \"sglang-bot\"\n          git config --local user.email 
\"sglangbot@gmail.com\"\n          git add -A\n          git commit -m \"update whl index\"\n          git push\n\n  # for now we do not release CUDA 13.0 wheels to pypi\n  build-cu130-matrix:\n    if: |\n      github.repository == 'sgl-project/sglang' &&\n      (github.event.inputs.target == 'all' || github.event.inputs.target == 'cu130')\n    strategy:\n      matrix:\n        python-version: [\"3.10\"]\n        cuda-version: [\"13.0\"]\n        arch: [x86_64, aarch64]\n        include:\n          - arch: x86_64\n            runner: x64-kernel-build-node\n          - arch: aarch64\n            runner: arm-kernel-build-node\n    runs-on: ${{ matrix.runner }}\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          submodules: \"recursive\"\n          ref: ${{ inputs.pr_number && format('refs/pull/{0}/head', inputs.pr_number) || '' }}\n\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Build wheels\n        run: |\n          cd sgl-kernel\n          chmod +x ./build.sh\n          ./build.sh \"${{ matrix.python-version }}\" \"${{ matrix.cuda-version }}\" ${{ matrix.arch == 'aarch64' && 'aarch64' || '' }}\n        env:\n          BUILD_JOBS: 64\n          NVCC_THREADS: 8\n\n      - name: Upload artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }}${{ matrix.arch == 'aarch64' && '-aarch64' || '' }}\n          path: sgl-kernel/dist/*\n\n  release-cu130:\n    needs: build-cu130-matrix\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_number && format('refs/pull/{0}/head', inputs.pr_number) || '' }}\n\n      - name: Download artifacts\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: 
true\n          pattern: wheel-*\n\n      - name: Set tag name\n        id: set_tag_name\n        run: |\n          if [ -z \"${{ inputs.tag_name }}\" ]; then\n            TAG_NAME=\"v$(cat sgl-kernel/python/sgl_kernel/version.py | cut -d'\"' -f2)\"\n            echo \"tag_name=$TAG_NAME\" >> $GITHUB_OUTPUT\n          else\n            echo \"tag_name=${{ inputs.tag_name }}\" >> $GITHUB_OUTPUT\n          fi\n\n      - name: Release\n        uses: softprops/action-gh-release@v2\n        with:\n          tag_name: ${{ steps.set_tag_name.outputs.tag_name }}\n          repository: sgl-project/whl\n          token: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n          files: |\n            sgl-kernel/dist/*\n\n      - name: Clone wheel index\n        run: git clone https://oauth2:${WHL_TOKEN}@github.com/sgl-project/whl.git sgl-whl\n        env:\n          WHL_TOKEN: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n\n      - name: Update wheel index\n        run: python3 scripts/update_kernel_whl_index.py --cuda 130\n\n      - name: Push wheel index\n        run: |\n          cd sgl-whl\n          git config --local user.name \"sglang-bot\"\n          git config --local user.email \"sglangbot@gmail.com\"\n          git add -A\n          git commit -m \"update whl index\"\n          git push\n\n  build-rocm-matrix:\n    if: |\n      github.repository == 'sgl-project/sglang' &&\n      (github.event.inputs.target == 'all' || github.event.inputs.target == 'rocm700' || github.event.inputs.target == 'rocm720')\n    runs-on: amd-docker-scale\n    strategy:\n      matrix:\n        python-version: [\"3.10\"]\n        rocm-version: [\"700\", \"720\"]\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          submodules: \"recursive\"\n          ref: ${{ inputs.pr_number && format('refs/pull/{0}/head', inputs.pr_number) || '' }}\n\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ 
matrix.python-version }}\n\n      - name: Build wheels\n        run: |\n          cp 3rdparty/amd/wheel/sgl-kernel/* sgl-kernel/\n          cd sgl-kernel\n          chmod +x ./build_rocm.sh\n          ./build_rocm.sh \"${{ matrix.rocm-version }}\"\n\n      - name: Upload artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          name: wheel-python${{ matrix.python-version }}-rocm${{ matrix.rocm-version }}\n          path: sgl-kernel/dist/*\n\n  release-rocm700:\n    needs: build-rocm-matrix\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_number && format('refs/pull/{0}/head', inputs.pr_number) || '' }}\n\n      - name: Download artifacts\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-*-rocm700\n\n      - name: Set tag name\n        id: set_tag_name\n        run: |\n          if [ -z \"${{ inputs.tag_name }}\" ]; then\n            TAG_NAME=\"v$(cat sgl-kernel/python/sgl_kernel/version.py | cut -d'\"' -f2)\"\n            echo \"tag_name=$TAG_NAME\" >> $GITHUB_OUTPUT\n          else\n            echo \"tag_name=${{ inputs.tag_name }}\" >> $GITHUB_OUTPUT\n          fi\n\n      - name: Release\n        uses: softprops/action-gh-release@v2\n        with:\n          tag_name: ${{ steps.set_tag_name.outputs.tag_name }}\n          repository: sgl-project/whl\n          token: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n          files: |\n            sgl-kernel/dist/*\n\n      - name: Clone wheel index\n        run: git clone https://oauth2:${WHL_TOKEN}@github.com/sgl-project/whl.git sgl-whl\n        env:\n          WHL_TOKEN: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n\n      - name: Update wheel index\n        run: python3 scripts/update_kernel_whl_index.py --rocm 700\n\n      - name: Push wheel index\n        run: |\n          cd sgl-whl\n          git config --local user.name 
\"sglang-bot\"\n          git config --local user.email \"sglangbot@gmail.com\"\n          git add -A\n          git commit -m \"update whl index\"\n          git push\n\n  release-rocm720:\n    needs: build-rocm-matrix\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_number && format('refs/pull/{0}/head', inputs.pr_number) || '' }}\n\n      - name: Download artifacts\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-*-rocm720\n\n      - name: Set tag name\n        id: set_tag_name\n        run: |\n          if [ -z \"${{ inputs.tag_name }}\" ]; then\n            TAG_NAME=\"v$(cat sgl-kernel/python/sgl_kernel/version.py | cut -d'\"' -f2)\"\n            echo \"tag_name=$TAG_NAME\" >> $GITHUB_OUTPUT\n          else\n            echo \"tag_name=${{ inputs.tag_name }}\" >> $GITHUB_OUTPUT\n          fi\n\n      - name: Release\n        uses: softprops/action-gh-release@v2\n        with:\n          tag_name: ${{ steps.set_tag_name.outputs.tag_name }}\n          repository: sgl-project/whl\n          token: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n          files: |\n            sgl-kernel/dist/*\n\n      - name: Clone wheel index\n        run: git clone https://oauth2:${WHL_TOKEN}@github.com/sgl-project/whl.git sgl-whl\n        env:\n          WHL_TOKEN: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n\n      - name: Update wheel index\n        run: python3 scripts/update_kernel_whl_index.py --rocm 720\n\n      - name: Push wheel index\n        run: |\n          cd sgl-whl\n          git config --local user.name \"sglang-bot\"\n          git config --local user.email \"sglangbot@gmail.com\"\n          git add -A\n          git commit -m \"update whl index\"\n          git push\n\n  build-musa43:\n    if: |\n      github.repository == 'sgl-project/sglang' &&\n      (github.event.inputs.target == 'all' || 
github.event.inputs.target == 'musa43')\n    runs-on: kernel-build-node-musa\n    strategy:\n      matrix:\n        python-version: [\"3.10\"]\n        musa-version: [\"43\"]\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          submodules: \"recursive\"\n\n      - name: Build wheels\n        run: |\n          cd sgl-kernel\n          mv pyproject_musa.toml pyproject.toml\n          python setup_musa.py sdist bdist_wheel\n\n      - name: Rename MUSA wheels\n        run: |\n          bash scripts/ci/musa/rename_wheels_musa.sh ${{ matrix.musa-version }} sgl-kernel/dist\n\n      - name: Upload artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          name: wheel-python${{ matrix.python-version }}-musa${{ matrix.musa-version }}\n          path: sgl-kernel/dist/*\n\n  release-musa43:\n    needs: build-musa43\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n\n      - name: Download artifacts\n        uses: actions/download-artifact@v4\n        with:\n          path: sgl-kernel/dist/\n          merge-multiple: true\n          pattern: wheel-*\n\n      - name: Set tag name\n        id: set_tag_name\n        run: |\n          if [ -z \"${{ inputs.tag_name }}\" ]; then\n            TAG_NAME=\"v$(cat sgl-kernel/python/sgl_kernel/version.py | cut -d'\"' -f2)\"\n            echo \"tag_name=$TAG_NAME\" >> $GITHUB_OUTPUT\n          else\n            echo \"tag_name=${{ inputs.tag_name }}\" >> $GITHUB_OUTPUT\n          fi\n\n      - name: Release\n        uses: softprops/action-gh-release@v2\n        with:\n          tag_name: ${{ steps.set_tag_name.outputs.tag_name }}\n          repository: sgl-project/whl\n          token: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n          files: |\n            sgl-kernel/dist/*\n\n      - name: Clone wheel index\n        run: git clone https://oauth2:${WHL_TOKEN}@github.com/sgl-project/whl.git sgl-whl\n        env:\n          WHL_TOKEN: ${{ secrets.GH_PAT_FOR_WHL_RELEASE }}\n\n     
 - name: Update wheel index\n        run: python3 scripts/update_kernel_whl_index.py --musa 43\n\n      - name: Push wheel index\n        run: |\n          cd sgl-whl\n          git config --local user.name \"sglang-bot\"\n          git config --local user.email \"sglangbot@gmail.com\"\n          git add -A\n          git commit -m \"update whl index\"\n          git push\n"
  },
  {
    "path": ".github/workflows/rerun-ut.yml",
    "content": "name: Rerun UT\nrun-name: ${{ inputs.pr_head_sha && format('[rerun-ut] {0}', inputs.pr_head_sha) || '[rerun-ut]' }}\n\non:\n  workflow_dispatch:\n    inputs:\n      test_command:\n        description: \"Test command to run (e.g. 'registered/core/test_srt_endpoint.py TestSRTEndpoint.test_simple_decode')\"\n        required: true\n        type: string\n      runner_label:\n        description: \"Runner label\"\n        required: true\n        type: choice\n        options:\n          - 1-gpu-runner\n          - 1-gpu-5090\n          - 2-gpu-runner\n          - 4-gpu-h100\n          - 4-gpu-a10\n          - 4-gpu-b200\n          - 8-gpu-h200\n          - 8-gpu-h20\n          - 8-gpu-b200\n      pr_head_sha:\n        description: \"PR head SHA to checkout (for /rerun-ut on fork PRs)\"\n        required: false\n        type: string\n        default: \"\"\n      use_deepep:\n        description: \"Use ci_install_deepep.sh instead of ci_install_dependency.sh\"\n        required: false\n        type: string\n        default: \"false\"\n\nenv:\n  SGLANG_IS_IN_CI: true\n  SGLANG_CUDA_COREDUMP: \"1\"\n  SGLANG_JIT_DEEPGEMM_FAST_WARMUP: true\n\npermissions:\n  actions: write\n  contents: read\n\njobs:\n  rerun-ut-cuda:\n    runs-on: ${{ inputs.runner_label }}\n    timeout-minutes: 120\n    env:\n      RUNNER_LABELS: ${{ inputs.runner_label }}\n      IS_BLACKWELL: ${{ (inputs.runner_label == '1-gpu-5090' || contains(inputs.runner_label, 'b200')) && '1' || '0' }}\n      SGLANG_CI_RDMA_ALL_DEVICES: ${{ inputs.runner_label == '8-gpu-h20' && 'mlx5_1,mlx5_2,mlx5_3,mlx5_4' || '' }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.pr_head_sha || github.sha }}\n\n      - name: Install dependencies\n        timeout-minutes: 20\n        run: |\n          if [[ \"${{ inputs.runner_label }}\" == \"1-gpu-5090\" ]]; then\n            source /etc/profile.d/sglang-ci.sh\n          fi\n          if [[ \"${{ 
inputs.use_deepep }}\" == \"true\" ]]; then\n            bash scripts/ci/cuda/ci_install_deepep.sh\n          else\n            bash scripts/ci/cuda/ci_install_dependency.sh\n          fi\n\n      - name: Run test\n        timeout-minutes: 60\n        run: |\n          if [[ \"${{ inputs.runner_label }}\" == \"1-gpu-5090\" ]]; then\n            source /etc/profile.d/sglang-ci.sh\n          fi\n          cd test/\n          python3 ${{ inputs.test_command }}\n\n      - uses: ./.github/actions/upload-cuda-coredumps\n        if: always()\n"
  },
  {
    "path": ".github/workflows/retag-docker.yml",
    "content": "name: Retag Docker Image\n\non:\n  workflow_dispatch:\n    inputs:\n      source_tag:\n        description: \"Existing image tag (e.g., v0.4.7-cu129-amd64)\"\n        required: true\n      target_tag:\n        description: \"New tag to apply (e.g., latest)\"\n        required: true\n\njobs:\n  retag:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: ubuntu-22.04\n    environment: \"prod\"\n    steps:\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Retag image\n        run: |\n          echo \"Retagging lmsysorg/sglang:${{ inputs.source_tag }} -> lmsysorg/sglang:${{ inputs.target_tag }}\"\n          docker buildx imagetools create \\\n            -t lmsysorg/sglang:${{ inputs.target_tag }} \\\n            lmsysorg/sglang:${{ inputs.source_tag }}\n"
  },
  {
    "path": ".github/workflows/runner-utilization.yml",
    "content": "name: Runner Utilization Report\n\non:\n  schedule:\n    - cron: '0 8 * * *'  # Daily at 8 AM UTC\n  pull_request:\n    paths:\n      - '.github/workflows/runner-utilization.yml'\n      - 'scripts/ci/utils/runner_utilization_report.py'\n  workflow_dispatch:\n    inputs:\n      hours:\n        description: 'Time window in hours'\n        required: false\n        default: '24'\n        type: string\n      filter:\n        description: 'Filter runner labels (e.g., 5090, h200)'\n        required: false\n        type: string\n\njobs:\n  report:\n    name: Generate Report\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Generate Utilization Report\n        timeout-minutes: 30\n        env:\n          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n        run: |\n          python scripts/ci/utils/runner_utilization_report.py \\\n            --repo ${{ github.repository }} \\\n            --hours ${{ inputs.hours || '24' }} \\\n            ${{ inputs.filter && format('--filter {0}', inputs.filter) || '' }}\n"
  },
  {
    "path": ".github/workflows/slash-command-handler.yml",
    "content": "name: Slash Command Handler\n\non:\n  issue_comment:\n    types: [created, edited]\n\npermissions:\n  contents: read\n  pull-requests: write # Required to add labels and reactions\n  actions: write       # Required to rerun workflows\n  issues: write        # Required for comment reactions in some contexts\n\njobs:\n  slash_command:\n    # Only run if it is a PR and the comment contains a recognized command\n    # Use contains() since startsWith() can't handle leading whitespace/newlines\n    if: >\n      github.event.issue.pull_request &&\n      (contains(github.event.comment.body, '/tag-run-ci-label') ||\n       contains(github.event.comment.body, '/rerun-failed-ci') ||\n       contains(github.event.comment.body, '/tag-and-rerun-ci') ||\n       contains(github.event.comment.body, '/rerun-stage') ||\n       contains(github.event.comment.body, '/rerun-ut'))\n    runs-on: ubuntu-latest\n\n    steps:\n      # SECURITY: This workflow runs on issue_comment trigger with elevated permissions\n      # (pull-requests: write, actions: write). For non-fork PRs, we can safely checkout\n      # the PR branch to allow testing changes to this handler. 
For fork PRs, we MUST\n      # stay on main to prevent untrusted code execution with these elevated permissions.\n      - name: Get PR details\n        id: pr\n        shell: bash\n        env:\n          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n        run: |\n          PR_DATA=$(gh pr view ${{ github.event.issue.number }} --repo ${{ github.repository }} --json headRefName,headRepositoryOwner) || {\n            echo \"::error::Failed to fetch PR data\"\n            exit 1\n          }\n          # Use 'empty' filter to handle null/missing values (e.g., deleted forks)\n          HEAD_OWNER=$(echo \"$PR_DATA\" | jq -r '.headRepositoryOwner.login // empty')\n          REPO_OWNER=\"${{ github.repository_owner }}\"\n          # Treat missing/null owner as fork for security (fail-safe)\n          if [[ -z \"$HEAD_OWNER\" || \"$HEAD_OWNER\" != \"$REPO_OWNER\" ]]; then\n            IS_FORK=\"true\"\n          else\n            IS_FORK=\"false\"\n          fi\n          echo \"is_fork=$IS_FORK\" >> $GITHUB_OUTPUT\n          echo \"ref=$(echo \"$PR_DATA\" | jq -r '.headRefName')\" >> $GITHUB_OUTPUT\n          echo \"PR owner: $HEAD_OWNER, Repo owner: $REPO_OWNER, Is fork: $IS_FORK\"\n\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          # For non-fork PRs, checkout PR branch to allow testing handler changes\n          # For fork PRs, stay on main for security (don't run untrusted code with elevated permissions)\n          ref: ${{ steps.pr.outputs.is_fork == 'false' && steps.pr.outputs.ref || '' }}\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.10'\n\n      - name: Install dependencies\n        run: |\n          pip install PyGithub\n\n      - name: Handle Slash Command\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          REPO_FULL_NAME: ${{ github.repository }}\n          PR_NUMBER: ${{ github.event.issue.number }}\n          COMMENT_ID: ${{ 
github.event.comment.id }}\n          COMMENT_BODY: ${{ github.event.comment.body }}\n          USER_LOGIN: ${{ github.event.comment.user.login }}\n        run: |\n          python scripts/ci/utils/slash_command_handler.py\n"
  },
  {
    "path": ".github/workflows/stress-test.yml",
    "content": "name: Stress Test\n\non:\n  workflow_dispatch:\n    inputs:\n      num_prompts:\n        description: 'Number of prompts per model'\n        required: true\n        default: '50000'\n        type: string\n      duration_minutes:\n        description: 'Timeout per model in minutes'\n        required: true\n        default: '45'\n        type: string\n\njobs:\n  stress-test:\n    if: github.repository == 'sgl-project/sglang'\n    runs-on: 8-gpu-h200\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run stress tests\n        timeout-minutes: 210\n        env:\n          NUM_PROMPTS: ${{ inputs.num_prompts }}\n          DURATION_MINUTES: ${{ inputs.duration_minutes }}\n        run: |\n          cd test\n          python3 run_suite.py --hw cuda --suite stress\n\n      - name: Upload results\n        if: always()\n        uses: actions/upload-artifact@v4\n        with:\n          name: stress-test-results\n          path: |\n            stress_test_*.jsonl\n"
  },
  {
    "path": ".github/workflows/weekly-test-nvidia.yml",
    "content": "name: Weekly Test (Nvidia)\n\non:\n  schedule:\n    - cron: '0 0 * * 0'  # Run every Sunday at midnight UTC\n  workflow_dispatch:\n    inputs:\n      job_filter:\n        description: 'Select which job to run (leave empty or \"all\" to run all jobs)'\n        required: false\n        type: choice\n        default: 'all'\n        options:\n          - 'all'\n          - 'weekly-test-8-gpu-h200'\n\nconcurrency:\n  group: weekly-test-nvidia-${{ github.ref }}\n  cancel-in-progress: true\n\nenv:\n  SGLANG_IS_IN_CI: true\n  HF_HUB_DOWNLOAD_TIMEOUT: 300\n  HF_HUB_ETAG_TIMEOUT: 300\n\njobs:\n  # Weekly tests - 8 GPU H200\n  weekly-test-8-gpu-h200:\n    if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'weekly-test-8-gpu-h200')\n    runs-on: 8-gpu-h200\n    timeout-minutes: 120\n    env:\n      RUNNER_LABELS: 8-gpu-h200\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install dependencies\n        run: |\n          bash scripts/ci/cuda/ci_install_dependency.sh\n\n      - name: Run weekly 8-GPU H200 tests\n        timeout-minutes: 120\n        env:\n          GPU_CONFIG: \"8-gpu-h200\"\n          IS_H200: \"1\"\n        run: |\n          cd test\n          python3 run_suite.py --hw cuda --suite weekly-8-gpu-h200 --nightly --continue-on-error --timeout-per-file 7200\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\n**/build/\n**/develop-eggs/\n**/dist/\n**/downloads/\n**/eggs/\n.eggs/\n**/lib/\n**/lib64/\n**/parts/\n**/sdist/\n**/var/\n**/wheels/\n**/share/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n\n# Tokenizer cache for tests\n.tokenizer_cache/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  
For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n.idea/\n\n# MacOS\n.DS_Store\n\n# Vim\n*.swp\n\n# Documentation\ndocs/_build\n\n# SGL\nbenchmark/mmlu/data\nbenchmark/mmlu/data.tar\nbenchmark/llava_bench/images\nbenchmark/llava_bench/mme_pack\n*.jsonl\ntmp*.txt\n\n# Torch Compile logs\ntl_out/\n\n# Plots\n*.png\n*.pdf\n\n# personal\nwork_dirs/\n*.csv\n\n!logo.png\n\n# Prerequisites\n*.d\n\n# Compiled Object files\n*.slo\n*.lo\n*.o\n*.obj\n\n# Precompiled Headers\n*.gch\n*.pch\n\n# Compiled Dynamic libraries\n*.so\n*.dylib\n*.dll\n\n# Fortran module files\n*.mod\n*.smod\n\n# Compiled Static libraries\n*.lai\n*.la\n*.a\n*.lib\n\n# Executables\n*.exe\n*.out\n*.app\n*.iml\n\n# VSCode\n.vscode\n\n# Autoenv\n.env.leave\n\n# Rust lib\nCargo.lock\n\n# Generated vision test fixtures (regenerate with: python scripts/generate_vision_golden.py)\nsgl-model-gateway/tests/fixtures/golden/\n\n# Other repos\nlmms-eval\n\n**/.serena/\nctags/\noutputs/\ninputs/\n\n# Eval Cache\n.longbench_cache/\n\n# CUDA kernel development, profiling and debugging\n.clangd\n*.nsys-rep\n*.ncu-rep\n*.nvcudmp\n\n# setuptools-scm generated version file\npython/sglang/_version.py\n\n# MUSA section\n# Source files generated by torchada\nsgl-kernel/csrc_musa/\nsgl-kernel/include_musa/\nsgl-kernel/csrc/**/*_musa/\n\n# MUSA core dump files\n*.mudmp\n\n# Others\n# diffusion 3D outputs\n*.glb\n*.ply\n*.npz\n"
  },
  {
    "path": ".isort.cfg",
    "content": "[settings]\nprofile=black\nknown_first_party=sglang\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "default_stages: [pre-commit, pre-push, manual]\nexclude: ^(python/sglang/multimodal_gen/csrc|python/sglang/jit_kernel/flash_attention/cute)\n\nrepos:\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v6.0.0\n    hooks:\n      - id: check-symlinks\n      - id: destroyed-symlinks\n      - id: trailing-whitespace\n      - id: end-of-file-fixer\n      - id: check-yaml\n        args: [--allow-multiple-documents]\n      - id: check-toml\n      - id: check-ast\n      - id: check-added-large-files\n      - id: check-merge-conflict\n      - id: check-shebang-scripts-are-executable\n      - id: detect-private-key\n        exclude: ^sgl-model-gateway/tests/.*_test\\.rs$\n      - id: debug-statements\n      - id: no-commit-to-branch\n  - repo: https://github.com/PyCQA/isort\n    rev: 7.0.0\n    hooks:\n      - id: isort\n        exclude: '^python/sglang/srt/grpc/.*_pb2\\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\\.py$|^python/sglang/srt/grpc/.*_pb2\\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\\.pyi$'\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.15.1\n    hooks:\n      - id: ruff\n        args:\n          - --select=F401,F821\n          - --fix\n        files: ^(benchmark/|docs/|examples/|python/sglang/|sgl-model-gateway/py_*|test/)\n        exclude: |\n          (?x)^(\n          .*/__init__\\.py$|\n          .*\\.ipynb$|\n          python/sglang/srt/grpc/.*_pb2\\.py$|\n          python/sglang/srt/grpc/.*_pb2_grpc\\.py$|\n          python/sglang/srt/grpc/.*_pb2\\.pyi$|\n          python/sglang/srt/grpc/.*_pb2_grpc\\.pyi$|\n          )$\n  - repo: https://github.com/psf/black\n    rev: 26.1.0\n    hooks:\n      - id: black-jupyter\n        exclude: '^python/sglang/srt/grpc/.*_pb2\\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\\.py$|^python/sglang/srt/grpc/.*_pb2\\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\\.pyi$'\n  - repo: https://github.com/codespell-project/codespell\n    rev: v2.4.1\n    hooks:\n      - id: codespell\n        
args: ['--config', '.codespellrc']\n  - repo: https://github.com/pre-commit/mirrors-clang-format\n    rev: v20.1.7\n    hooks:\n    - id: clang-format\n      types_or: [c++, cuda]\n      args: [--style=file, --verbose]\n  - repo: https://github.com/kynan/nbstripout\n    rev: 0.9.0\n    hooks:\n      - id: nbstripout\n        args:\n          - '--keep-output'\n          - '--extra-keys=metadata.kernelspec metadata.language_info.version'\n  - repo: local\n    hooks:\n      - id: check-chinese-characters\n        name: check chinese characters in multimodal_gen\n        entry: >-\n          python3 -c 'import sys, re; p=re.compile(r\"[\\u4e00-\\u9fff]\"); ec=0; [ ([(print(f\"{f}:{i+1}: {l.strip()}\") or (ec:=1)) for i,l in enumerate(open(f, \"r\", encoding=\"utf-8\", errors=\"ignore\")) if p.search(l)]) for f in sys.argv[1:] ]; sys.exit(ec)'\n        language: system\n        files: ^python/sglang/multimodal_gen/.*\n        exclude: ^(python/sglang/multimodal_gen/configs/sample|python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows|python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages)(/|$)\n        types_or: [python, markdown, json, text]\n      - id: sort-ci-permissions\n        name: sort CI_PERMISSIONS.json\n        entry: python3 .github/update_ci_permission.py --sort-only\n        language: system\n        files: ^\\.github/CI_PERMISSIONS\\.json$\n        pass_filenames: false\n"
  },
  {
    "path": "3rdparty/amd/profiling/PROFILING.md",
    "content": "## Profiling the SGLang Inference System with AMD GPUs\nThis AppNote describes the profiling techniques, code changes, and run steps for SGLang on systems with AMD Instinct GPUs. The same procedure may also work on Nvidia GPUs.\nDetailed examples and steps are provided so that the results are easy to reproduce and can be used to localize performance problems for optimization.\nTwo primary methods are covered:\n- [RPD](https://github.com/ROCm/rocmProfileData.git)\n- [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)\n\n### Profiling the SGLang Inference System with the RPD Profiler\nRPD is a low-overhead, cross-platform profiler, so the same RPD code changes work for profiling on both ROCm/AMD GPUs and CUDA/Nvidia GPUs. To run RPD profiling on the SGLang repository, use the scripts and patch files included in this directory and follow the steps below:\n1. Install RPD using install_rpd.sh, which applies rpd.patch during installation. Both files are in this directory.\n\ninstall_rpd.sh\n\n```bash\n# download and install RPD\napt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev\n\n# install rpd module\ngit clone https://github.com/ROCmSoftwarePlatform/rocmProfileData\ncd rocmProfileData\ngit checkout 976899e9c6dbc6dd2bccf770818e4e44125590ac\ngit apply rpd.patch\nmake && make install\ncd rocpd_python && python setup.py install && cd ..\ncd rpd_tracer && make clean;make install && python setup.py install && cd ..\n```\n\nrpd.patch\n\n```bash\ndiff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile\nindex e9d9feb..b2e9e1a 100644\n--- a/rpd_tracer/Makefile\n+++ b/rpd_tracer/Makefile\n@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))\n         $(info Building with roctracer)\n         RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64\n         RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa\n-        RPD_SRCS += 
RoctracerDataSource.cpp RocmSmiDataSource.cpp\n+        RPD_SRCS += RoctracerDataSource.cpp\n         RPD_INCLUDES += -D__HIP_PLATFORM_AMD__\n endif\n```\n2. Add loadTracer.sh file included in this directory to /sglang/python/sglang.\n\nloadTracer.sh\n\n```bash\n#!/bin/bash\n################################################################################\n# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in\n# all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n# THE SOFTWARE.\n################################################################################\nOUTPUT_FILE=\"trace.rpd\"\n\nif [ \"$1\" = \"-o\" ] ; then\n  OUTPUT_FILE=$2\n  shift\n  shift\nfi\n\nif [ -e ${OUTPUT_FILE} ] ; then\n  rm ${OUTPUT_FILE}\nfi\n\npython3 -m rocpd.schema --create ${OUTPUT_FILE}\nif [ $? != 0 ] ; then\n  echo \"Error: Could not create rpd file. 
Please run 'python setup.py install' from the rocpd_python dir\"\n  exit\nfi\n\nexport RPDT_FILENAME=${OUTPUT_FILE}\nexport RPDT_AUTOSTART=0\nLD_PRELOAD=librocm-smi_64:librpd_tracer.so \"$@\"\n```\n3. Apply the patch provided in this directory with \"git apply rpd_profile_server_enable.patch\" if the main goal is to collect GPU kernel information along with limited CPU activity information.\n\n#### Common Notes 1\nAlthough the example uses TP=8, the patch file intentionally logs RPD profiling on only 2 ranks (i.e., tp_rank=0/1) for profiling and visualization convenience, because even Perfetto's streaming mode can load at most an 8 GB JSON file. Logging 2 ranks still lets us check for issues across ranks (e.g., load imbalance or NCCL issues), while allowing a relatively longer time window to be captured before the JSON file generated from the RPD file reaches 8 GB.\n\nrpd_profile_server_enable.patch\n\n```bash\ndiff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py\nindex 62d1ff9..9021c01 100644\n--- a/python/sglang/srt/managers/scheduler.py\n+++ b/python/sglang/srt/managers/scheduler.py\n@@ -71,6 +71,8 @@ from sglang.srt.utils import (\n     suppress_other_loggers,\n )\n from sglang.utils import get_exception_traceback\n+from rpdTracerControl import rpdTracerControl\n+rpdTracerControl.skipCreate()\n\n logger = logging.getLogger(__name__)\n\n@@ -245,6 +247,7 @@ class Scheduler:\n                 ],\n                 with_stack=True,\n             )\n+            self.rpd = rpdTracerControl()\n\n     @torch.inference_mode()\n     def event_loop(self):\n@@ -1027,15 +1030,24 @@ class Scheduler:\n     def start_profile(self) -> None:\n         if self.profiler is None:\n             raise RuntimeError(\"Profiler is not enabled.\")\n-        self.profiler.start()\n+        #self.profiler.start() #block pytorch profiler for rpd profiler enabling\n+        if 
self.tp_rank == 0 or self.tp_rank == 1:\n+            self.rpd.start()\n+            self.rpd.rangePush(\"\", \"rpd profile range\", \"\")\n+            logger.info(\"rpd is enabled\")\n\n     def stop_profile(self) -> None:\n         if self.profiler is None:\n             raise RuntimeError(\"Profiler is not enabled.\")\n-        self.profiler.stop()\n-        self.profiler.export_chrome_trace(\n-            self.torch_profiler_trace_dir + \"/\" + str(time.time()) + \".trace.json.gz\"\n-        )\n+        #self.profiler.stop()\n+        #self.profiler.export_chrome_trace(\n+        #    self.torch_profiler_trace_dir + \"/\" + str(time.time()) + \".trace.json.gz\"\n+        #)\n+        if self.tp_rank ==0 or self.tp_rank ==1:\n+            self.rpd.rangePop()\n+            self.rpd.stop()\n+            self.rpd.flush()\n+            logger.info(\"rpd is done\")\n         logger.info(\"Profiler is done\")\n```\n\n#### Advanced Debugging with RPD Profiler\nSometimes we want to use the RPD profiler to capture more CPU and Python activity in order to debug challenging issues (e.g., the root cause of load imbalance across GPU processes, or of bubbles). 
Only in such cases do we need to apply the patch with \"git apply rpd_profile_server_enable_wCPU_activities.patch\", which modifies 3 files.\n\nrpd_profile_server_enable_wCPU_activities.patch\n\n```bash\ndiff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py\nindex 62d1ff9..2edb427 100644\n--- a/python/sglang/srt/managers/scheduler.py\n+++ b/python/sglang/srt/managers/scheduler.py\n@@ -71,6 +71,8 @@ from sglang.srt.utils import (\n     suppress_other_loggers,\n )\n from sglang.utils import get_exception_traceback\n+from rpdTracerControl import rpdTracerControl\n+rpdTracerControl.skipCreate()\n\n logger = logging.getLogger(__name__)\n\n@@ -245,6 +247,7 @@ class Scheduler:\n                 ],\n                 with_stack=True,\n             )\n+            self.rpd = rpdTracerControl()\n\n     @torch.inference_mode()\n     def event_loop(self):\n@@ -1027,15 +1030,26 @@ class Scheduler:\n     def start_profile(self) -> None:\n         if self.profiler is None:\n             raise RuntimeError(\"Profiler is not enabled.\")\n-        self.profiler.start()\n+        #self.profiler.start()\n+        logger.info(\"torch profiler is disabled\")\n+        if self.tp_rank == 0 or self.tp_rank == 1:\n+            self.rpd.setPythonTrace(True)\n+            self.rpd.start()\n+            self.rpd.rangePush(\"\", \"scheduler\", \"\")\n+        logger.info(\"rpd is enabled inside scheduler profiling\")\n\n     def stop_profile(self) -> None:\n         if self.profiler is None:\n             raise RuntimeError(\"Profiler is not enabled.\")\n-        self.profiler.stop()\n-        self.profiler.export_chrome_trace(\n-            self.torch_profiler_trace_dir + \"/\" + str(time.time()) + \".trace.json.gz\"\n-        )\n+        #self.profiler.stop()\n+        #self.profiler.export_chrome_trace(\n+        #    self.torch_profiler_trace_dir + \"/\" + str(time.time()) + \".trace.json.gz\"\n+        #)\n+        if 
self.tp_rank ==0 or self.tp_rank 
==1:\n+            self.rpd.rangePop()\n+            self.rpd.stop()\n+            self.rpd.flush()\n+            logger.info(\"rpd is done inside scheduler\")\n         logger.info(\"Profiler is done\")\n\n\ndiff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py\nindex 2621ccd..181df85 100644\n--- a/python/sglang/srt/managers/tokenizer_manager.py\n+++ b/python/sglang/srt/managers/tokenizer_manager.py\n@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams\n from sglang.srt.server_args import PortArgs, ServerArgs\n from sglang.srt.utils import is_generation_model, is_multimodal_model\n\n+from rpdTracerControl import rpdTracerControl\n+rpdTracerControl.skipCreate()\n+\n+\n asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())\n\n logger = logging.getLogger(__name__)\n@@ -514,10 +518,20 @@ class TokenizerManager:\n         self.send_to_scheduler.send_pyobj(req)\n\n     def start_profile(self):\n+        rpd = rpdTracerControl()\n+        rpd.setPythonTrace(True)\n+        rpd.start()\n+        rpd.rangePush(\"\", \"tokenizer_manager\", \"\")\n+        logger.info(\"tokenizer_manager rpd profiling started!\")\n         req = ProfileReq.START_PROFILE\n         self.send_to_scheduler.send_pyobj(req)\n\n     def stop_profile(self):\n+        rpd = rpdTracerControl()\n+        rpd.rangePop()\n+        rpd.stop()\n+        rpd.flush()\n+        logger.info(\"rpd profiling is done inside tokenizer_manager!\")\n         req = ProfileReq.STOP_PROFILE\n         self.send_to_scheduler.send_pyobj(req)\n\ndiff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py\nindex 7111c93..2bd722c 100644\n--- a/python/sglang/srt/server.py\n+++ b/python/sglang/srt/server.py\n@@ -30,6 +30,8 @@ import threading\n import time\n from http import HTTPStatus\n from typing import Dict, List, Optional, Union\n+from rpdTracerControl import rpdTracerControl\n+rpdTracerControl.skipCreate()\n\n # Fix a bug 
of Python threading\n setattr(threading, \"_register_atexit\", lambda *args, **kwargs: None)\n@@ -152,6 +154,11 @@ async def flush_cache():\n @app.post(\"/start_profile\")\n async def start_profile():\n     \"\"\"Start profiling.\"\"\"\n+    rpd = rpdTracerControl()\n+    rpd.setPythonTrace(True)\n+    rpd.start()\n+    rpd.rangePush(\"\", \"server rpd profile range\", \"\")\n+    logger.info(\"rpd profiling started in server.py!\")\n     tokenizer_manager.start_profile()\n     return Response(\n         content=\"Start profiling.\\n\",\n@@ -164,6 +171,11 @@ async def start_profile():\n async def stop_profile():\n     \"\"\"Stop profiling.\"\"\"\n     tokenizer_manager.stop_profile()\n+    rpd = rpdTracerControl()\n+    rpd.rangePop()\n+    rpd.stop()\n+    rpd.flush()\n+    logger.info(\"rpd profiling is done in server.py!\")\n     return Response(\n         content=\"Stop profiling. This will take some time.\\n\",\n         status_code=200,\n```\n\n4. As an example for grok1 profiling, we create a dummy_grok1 directory with config.json (see content below) inside this directory and copy this directory to the right path for \"--model-path\" if you want to use the example server.sh file provided.\n```bash\ncat ../dummy_grok1/config.json\n{\n  \"architectures\": [\n    \"Grok1ModelForCausalLM\"\n  ],\n  \"embedding_multiplier_scale\": 78.38367176906169,\n  \"output_multiplier_scale\": 0.5773502691896257,\n  \"vocab_size\": 131072,\n  \"hidden_size\": 6144,\n  \"intermediate_size\": 32768,\n  \"max_position_embeddings\": 8192,\n  \"num_experts_per_tok\": 2,\n  \"num_local_experts\": 8,\n  \"num_attention_heads\": 48,\n  \"num_hidden_layers\": 64,\n  \"num_key_value_heads\": 8,\n  \"head_dim\": 128,\n  \"rms_norm_eps\": 1e-05,\n  \"rope_theta\": 10000.0,\n  \"model_type\": \"mixtral\",\n  \"torch_dtype\": \"bfloat16\"\n}\n```\n5. 
Launch the server with the RPD-enabled script ./server.sh in one terminal inside the docker container.\n\n#### Common Notes 2\n- Remember to change --model-path to the correct path.\n- loadTracer.sh is needed to conduct RPD profiling.\n- SGLANG_TORCH_PROFILER_DIR is used by the default torch profiler.\n- Do not use loadTracer.sh if you are using the torch profiler; simply use python3 -m sglang.launch_server.\n\n\nserver.sh\n\n```bash\n#!/bin/bash\n\n# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/\nexport SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/\n\n# Get the current timestamp\nTIMESTAMP=$(date +\"%Y%m%d_%H%M%S\")\n\n# Define the log file with a timestamp\nLOGFILE=\"sglang_server_log_$TIMESTAMP.json\"\n\n# Run the Python command and save the output to the log file\nloadTracer.sh python3 -m sglang.launch_server \\\n    --model-path /sgl-workspace/sglang/dummy_grok1 \\\n    --tokenizer-path Xenova/grok-1-tokenizer \\\n    --load-format dummy \\\n    --quantization fp8 \\\n    --tp 8 \\\n    --port 30000 \\\n    --disable-radix-cache 2>&1 | tee \"$LOGFILE\"\n```\n6. Open another terminal in the same docker container and run the RPD-enabled ./client.sh after you see the \"The server is fired up and is ready to roll!\" message in the server-side terminal.\n\n#### Common Notes 3\n- Use curl http://localhost:30000/start_profile and curl http://localhost:30000/stop_profile to control the start and end of profiling. 
Check sglang/python/sglang/srt/managers/scheduler.py for more details.\n- Do not use the RPD profiler together with the PyTorch profiler, to avoid interference.\n- The rocmProfileData/tools/rpd2tracing.py script is used to generate a JSON file from the RPD file.\n\nclient.sh\n\n```bash\n#!/bin/bash\n\n# Start profiling via API\ncurl http://localhost:30000/start_profile -H \"Content-Type: application/json\"\n\n# Benchmark serving using sglang with random dataset and tokenizer\n# Define the log file with a timestamp\nTIMESTAMP=$(date +%Y%m%d_%H%M%S)\nLOGFILE=\"sglang_client_log_$TIMESTAMP.json\"\n\n# Run the benchmark with specified parameters and save logs\npython3 -m sglang.bench_serving \\\n    --backend sglang \\\n    --tokenizer Xenova/grok-1-tokenizer \\\n    --dataset-name random \\\n    --random-input 1024 \\\n    --random-output 1024 \\\n    --num-prompts 120 \\\n    --request-rate 8 \\\n    --output-file online.jsonl 2>&1 | tee \"$LOGFILE\"\n\n# Stop profiling via API\ncurl http://localhost:30000/stop_profile -H \"Content-Type: application/json\"\n\n# Convert tracing file to csv & json\nsqlite3 trace.rpd \".mode csv\" \".header on\" \".output trace.csv\" \"select * from top;\" \".output stdout\"\npython3 ./rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json\n```\n7. Follow the [Perfetto docs](https://perfetto.dev/docs/visualization/large-traces) to visualize large JSON files. Try to adjust parameters so that the trace.json file size is less than 9GB.\n\n### Profiling SGLang Infer System with PyTorch Profiler\n\nFollow these steps:\n\n1. Apply the patch torch_profiler.patch. 
Note that you can modify \"if self.tp_rank == 0\" in the patch to record more ranks during profiling.\n\ntorch_profiler.patch\n```bash\ndiff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py\nindex 62d1ff9..6ecd78c 100644\n--- a/python/sglang/srt/managers/scheduler.py\n+++ b/python/sglang/srt/managers/scheduler.py\n@@ -240,7 +240,6 @@ class Scheduler:\n             )\n             self.profiler = torch.profiler.profile(\n                 activities=[\n-                    torch.profiler.ProfilerActivity.CPU,\n                     torch.profiler.ProfilerActivity.CUDA,\n                 ],\n                 with_stack=True,\n@@ -1033,9 +1032,11 @@ class Scheduler:\n         if self.profiler is None:\n             raise RuntimeError(\"Profiler is not enabled.\")\n         self.profiler.stop()\n-        self.profiler.export_chrome_trace(\n-            self.torch_profiler_trace_dir + \"/\" + str(time.time()) + \".trace.json.gz\"\n-        )\n+        if self.tp_rank == 0:\n+            with open(f\"stats_repro_{int(time.time())}.txt\", \"w\") as f:\n+                print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by=\"cuda_time_total\", row_limit=-1), file=f)\n+                print(\"Profiling stats done.\")\n+\n         logger.info(\"Profiler is done\")\n```\n\n2. Create the model path directory and copy it to the path passed to \"--model-path\" if you want to use the provided server.sh file.\n\n3. Modify the included server.sh by removing \"loadTracer.sh\" before the python command, then launch ./server.sh in one terminal inside the docker container.\n\n4. Similar to step 6 in the RPD profiling section, but remove the last 2 lines of client.sh, which convert the RPD file into CSV and JSON files. Run the modified client.sh for PyTorch profiling.\n\n-------\n- [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)\n"
  },
  {
    "path": "3rdparty/amd/profiling/client.sh",
    "content": "#!/bin/bash\n\n# Start profiling via API\ncurl http://localhost:30000/start_profile -H \"Content-Type: application/json\"\n\n# Benchmark serving using sglang with random dataset and tokenizer\n# Define the log file with a timestamp\nTIMESTAMP=$(date +%Y%m%d_%H%M%S)\nLOGFILE=\"sglang_client_log_$TIMESTAMP.json\"\n\n# Run the benchmark with specified parameters and save logs\npython3 -m sglang.bench_serving \\\n    --backend sglang \\\n    --tokenizer Xenova/grok-1-tokenizer \\\n    --dataset-name random \\\n    --random-input 1024\\\n    --random-output 1024 \\\n    --num-prompts 240 \\\n    --request-rate 8 \\\n    --output-file online.jsonl 2>&1 | tee \"$LOGFILE\"\n\n# Stop profiling via API\ncurl http://localhost:30000/stop_profile -H \"Content-Type: application/json\"\n\n# Convert tracing file to csv & json\nsqlite3 trace.rpd \".mode csv\" \".header on\" \".output trace.csv\" \"select * from top;\" \".output stdout\"\npython3 /sgl-workspace/rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json\n"
  },
  {
    "path": "3rdparty/amd/profiling/install_rpd.sh",
    "content": "# download and install RPD\napt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev\n\n# install rpd module\ngit clone https://github.com/ROCmSoftwarePlatform/rocmProfileData\ncd rocmProfileData\ngit apply rpd.patch\nmake && make install\ncd rocpd_python && python setup.py install && cd ..\ncd rpd_tracer && make clean;make install && python setup.py install && cd ..\n"
  },
  {
    "path": "3rdparty/amd/profiling/loadTracer.sh",
    "content": "#!/bin/bash\n################################################################################\n# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in\n# all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n# THE SOFTWARE.\n################################################################################\nOUTPUT_FILE=\"trace.rpd\"\n\nif [ \"$1\" = \"-o\" ] ; then\n  OUTPUT_FILE=$2\n  shift\n  shift\nfi\n\nif [ -e ${OUTPUT_FILE} ] ; then\n  rm ${OUTPUT_FILE}\nfi\n\npython3 -m rocpd.schema --create ${OUTPUT_FILE}\nif [ $? != 0 ] ; then\n  echo \"Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir\"\n  exit\nfi\n\nexport RPDT_FILENAME=${OUTPUT_FILE}\nexport RPDT_AUTOSTART=0\nLD_PRELOAD=librocm-smi_64:librpd_tracer.so \"$@\"\n"
  },
  {
    "path": "3rdparty/amd/profiling/rpd.patch",
    "content": "diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile\nindex e9d9feb..b2e9e1a 100644\n--- a/rpd_tracer/Makefile\n+++ b/rpd_tracer/Makefile\n@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))\n         $(info Building with roctracer)\n         RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64\n         RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa\n-        RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp\n+        RPD_SRCS += RoctracerDataSource.cpp\n         RPD_INCLUDES += -D__HIP_PLATFORM_AMD__\n endif\n"
  },
  {
    "path": "3rdparty/amd/profiling/rpd_profile_server_enable.patch",
    "content": "diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py\nindex 62d1ff9..9021c01 100644\n--- a/python/sglang/srt/managers/scheduler.py\n+++ b/python/sglang/srt/managers/scheduler.py\n@@ -71,6 +71,8 @@ from sglang.srt.utils import (\n     suppress_other_loggers,\n )\n from sglang.utils import get_exception_traceback\n+from rpdTracerControl import rpdTracerControl\n+rpdTracerControl.skipCreate()\n\n logger = logging.getLogger(__name__)\n\n@@ -245,6 +247,7 @@ class Scheduler:\n                 ],\n                 with_stack=True,\n             )\n+            self.rpd = rpdTracerControl()\n\n     @torch.inference_mode()\n     def event_loop(self):\n@@ -1027,15 +1030,24 @@ class Scheduler:\n     def start_profile(self) -> None:\n         if self.profiler is None:\n             raise RuntimeError(\"Profiler is not enabled.\")\n-        self.profiler.start()\n+        #self.profiler.start() #block pytorch profiler for rpd profiler enabling\n+        if self.tp_rank == 0 or self.tp_rank == 1:\n+            self.rpd.start()\n+            self.rpd.rangePush(\"\", \"rpd profile range\", \"\")\n+            logger.info(\"rpd is enabled\")\n\n     def stop_profile(self) -> None:\n         if self.profiler is None:\n             raise RuntimeError(\"Profiler is not enabled.\")\n-        self.profiler.stop()\n-        self.profiler.export_chrome_trace(\n-            self.torch_profiler_trace_dir + \"/\" + str(time.time()) + \".trace.json.gz\"\n-        )\n+        #self.profiler.stop()\n+        #self.profiler.export_chrome_trace(\n+        #    self.torch_profiler_trace_dir + \"/\" + str(time.time()) + \".trace.json.gz\"\n+        #)\n+        if self.tp_rank ==0 or self.tp_rank ==1:\n+            self.rpd.rangePop()\n+            self.rpd.stop()\n+            self.rpd.flush()\n+            logger.info(\"rpd is done\")\n         logger.info(\"Profiler is done\")\n"
  },
  {
    "path": "3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch",
    "content": "diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py\nindex 62d1ff9..2edb427 100644\n--- a/python/sglang/srt/managers/scheduler.py\n+++ b/python/sglang/srt/managers/scheduler.py\n@@ -71,6 +71,8 @@ from sglang.srt.utils import (\n     suppress_other_loggers,\n )\n from sglang.utils import get_exception_traceback\n+from rpdTracerControl import rpdTracerControl\n+rpdTracerControl.skipCreate()\n\n logger = logging.getLogger(__name__)\n\n@@ -245,6 +247,7 @@ class Scheduler:\n                 ],\n                 with_stack=True,\n             )\n+            self.rpd = rpdTracerControl()\n\n     @torch.inference_mode()\n     def event_loop(self):\n@@ -1027,15 +1030,26 @@ class Scheduler:\n     def start_profile(self) -> None:\n         if self.profiler is None:\n             raise RuntimeError(\"Profiler is not enabled.\")\n-        self.profiler.start()\n+        #self.profiler.start()\n+        logger.info(\"torch profiler is disabled\")\n+        if self.tp_rank == 0 or self.tp_rank == 1:\n+            self.rpd.setPythonTrace(True)\n+            self.rpd.start()\n+            self.rpd.rangePush(\"\", \"scheduler\", \"\")\n+        logger.info(\"rpd is enabled inside scheduler profiling\")\n\n     def stop_profile(self) -> None:\n         if self.profiler is None:\n             raise RuntimeError(\"Profiler is not enabled.\")\n-        self.profiler.stop()\n-        self.profiler.export_chrome_trace(\n-            self.torch_profiler_trace_dir + \"/\" + str(time.time()) + \".trace.json.gz\"\n-        )\n+        #self.profiler.stop()\n+        #self.profiler.export_chrome_trace(\n+        #    self.torch_profiler_trace_dir + \"/\" + str(time.time()) + \".trace.json.gz\"\n+        #)\n+        if self.tp_rank ==0 or self.tp_rank ==1:\n+            self.rpd.rangePop()\n+            self.rpd.stop()\n+            self.rpd.flush()\n+            logger.info(\"rpd is done inside scheduler\")\n         
logger.info(\"Profiler is done\")\n\n\ndiff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py\nindex 2621ccd..181df85 100644\n--- a/python/sglang/srt/managers/tokenizer_manager.py\n+++ b/python/sglang/srt/managers/tokenizer_manager.py\n@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams\n from sglang.srt.server_args import PortArgs, ServerArgs\n from sglang.srt.utils import is_generation_model, is_multimodal_model\n\n+from rpdTracerControl import rpdTracerControl\n+rpdTracerControl.skipCreate()\n+\n+\n asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())\n\n logger = logging.getLogger(__name__)\n@@ -514,10 +518,20 @@ class TokenizerManager:\n         self.send_to_scheduler.send_pyobj(req)\n\n     def start_profile(self):\n+        rpd = rpdTracerControl()\n+        rpd.setPythonTrace(True)\n+        rpd.start()\n+        rpd.rangePush(\"\", \"tokenizer_manager\", \"\")\n+        logger.info(\"tokenizer_manager rpd profiling started!\")\n         req = ProfileReq.START_PROFILE\n         self.send_to_scheduler.send_pyobj(req)\n\n     def stop_profile(self):\n+        rpd = rpdTracerControl()\n+        rpd.rangePop()\n+        rpd.stop()\n+        rpd.flush()\n+        logger.info(\"rpd profiling is done inside tokenizer_manager!\")\n         req = ProfileReq.STOP_PROFILE\n         self.send_to_scheduler.send_pyobj(req)\n\ndiff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py\nindex 7111c93..2bd722c 100644\n--- a/python/sglang/srt/server.py\n+++ b/python/sglang/srt/server.py\n@@ -30,6 +30,8 @@ import threading\n import time\n from http import HTTPStatus\n from typing import Dict, List, Optional, Union\n+from rpdTracerControl import rpdTracerControl\n+rpdTracerControl.skipCreate()\n\n # Fix a bug of Python threading\n setattr(threading, \"_register_atexit\", lambda *args, **kwargs: None)\n@@ -152,6 +154,11 @@ async def flush_cache():\n 
@app.post(\"/start_profile\")\n async def start_profile():\n     \"\"\"Start profiling.\"\"\"\n+    rpd = rpdTracerControl()\n+    rpd.setPythonTrace(True)\n+    rpd.start()\n+    rpd.rangePush(\"\", \"server rpd profile range\", \"\")\n+    logger.info(\"rpd profiling started in server.py!\")\n     tokenizer_manager.start_profile()\n     return Response(\n         content=\"Start profiling.\\n\",\n@@ -164,6 +171,11 @@ async def start_profile():\n async def stop_profile():\n     \"\"\"Stop profiling.\"\"\"\n     tokenizer_manager.stop_profile()\n+    rpd = rpdTracerControl()\n+    rpd.rangePop()\n+    rpd.stop()\n+    rpd.flush()\n+    logger.info(\"rpd profiling is done in server.py!\")\n     return Response(\n         content=\"Stop profiling. This will take some time.\\n\",\n         status_code=200,\n"
  },
  {
    "path": "3rdparty/amd/profiling/server.sh",
    "content": "#!/bin/bash\n\n# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/\nexport SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/\n\n# Get the current timestamp\nTIMESTAMP=$(date +\"%Y%m%d_%H%M%S\")\n\n# Define the log file with a timestamp\nLOGFILE=\"sglang_server_log_$TIMESTAMP.json\"\n\n# Run the Python command and save the output to the log file\nloadTracer.sh python3 -m sglang.launch_server \\\n    --model-path /sgl-workspace/sglang/dummy_grok1 \\\n    --tokenizer-path Xenova/grok-1-tokenizer \\\n    --load-format dummy \\\n    --quantization fp8 \\\n    --tp 8 \\\n    --port 30000 \\\n    --disable-radix-cache 2>&1 | tee \"$LOGFILE\"\n"
  },
  {
    "path": "3rdparty/amd/profiling/torch_profiler.patch",
    "content": "diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py\r\nindex 62d1ff9..6ecd78c 100644\r\n--- a/python/sglang/srt/managers/scheduler.py\r\n+++ b/python/sglang/srt/managers/scheduler.py\r\n@@ -240,7 +240,6 @@ class Scheduler:\r\n             )\r\n             self.profiler = torch.profiler.profile(\r\n                 activities=[\r\n-                    torch.profiler.ProfilerActivity.CPU,\r\n                     torch.profiler.ProfilerActivity.CUDA,\r\n                 ],\r\n                 with_stack=True,\r\n@@ -1033,9 +1032,11 @@ class Scheduler:\r\n         if self.profiler is None:\r\n             raise RuntimeError(\"Profiler is not enabled.\")\r\n         self.profiler.stop()\r\n-        self.profiler.export_chrome_trace(\r\n-            self.torch_profiler_trace_dir + \"/\" + str(time.time()) + \".trace.json.gz\"\r\n-        )\r\n+        if self.tp_rank == 0:\r\n+            with open(f\"stats_repro_{int(time.time())}.txt\", \"w\") as f:\r\n+                print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by=\"cuda_time_total\", row_limit=-1), file=f)\r\n+                print(\"Profiling stats done.\")\r\n+\r\n         logger.info(\"Profiler is done\")\n"
  },
  {
    "path": "3rdparty/amd/tuning/TUNING.md",
    "content": "## Tuning SGLang Infer System with AMD GPUs\nThis AppNote describes the SGLang performance tuning technical, code harness and running steps for systems with AMD Instinct GPUs.\nHarness code, examples and steps are provided in detail, to facilitate easy reproduce & use to tune performance towards workloads.\nThree primary runtime areas are covered:\n\n## 1. Triton Kernels\nTo maximize Triton kernel efficiency, several strategies can be employed:\n\n### Key Environment Variables:\n- **num_stages**: Adjusts the number of pipeline stages to optimize kernel efficiency based on the specific type of operations (e.g., General Matrix Multiplication - GEMM).\n- **waves_per_eu**: Controls the usage of Vector General Purpose Registers (VGPR) to enhance occupancy, thereby improving latency or throughput.\n- **BLOCK_M, BLOCK_N, BLOCK_K**: Tunable tile sizes that assist in balancing memory transfer and computational efficiency.\n- **matrix_instr_nonkdim**: Optimizes the usage of Matrix-Fused Multiply-Add (MFMA) instructions for specific kernel types, such as Flash Attention.\n- **OPTIMIZE_EPILOGUE**: An environment variable that can be set to `1` to enhance performance by eliminating the `convert_layout` operation in the kernel's epilogue.\n```python\n@triton.autotune(configs=[\n        triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),\n        triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),\n        triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),\n        triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),\n        triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),\n        triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),\n        triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),\n        triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),\n        triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),\n    ], key=['BLOCK_N', 'NUM_TOKEN_BLKS'], 
use_cuda_graph=True)\n@triton.jit\ndef _triton_kernel_function():\n    ...\n```\n## 2. Torch Tunable Operations\n**TunableOp** is a PyTorch feature for operations, such as GEMMs, that can be implemented by more than one library or technique: it benchmarks the available implementations at runtime and records the fastest one for each input configuration. This makes it particularly useful for improving kernel performance by experimenting with different implementations.\n\n### Key Environment Variables:\n1. **PYTORCH_TUNABLEOP_ENABLED**:\n   - Default: `0`\n   - Set to `1` to enable TunableOp.\n\n2. **PYTORCH_TUNABLEOP_TUNING**:\n   - Default: `1`\n   - When tuning is enabled (and PYTORCH_TUNABLEOP_ENABLED is set), if a tuned entry is not found, TunableOp runs the tuning step and records the entry.\n   - Set to `0` to disable tuning; previously recorded entries are still used.\n\n3. **PYTORCH_TUNABLEOP_VERBOSE**:\n   - Default: `0`\n   - Set to `1` to enable verbose output for TunableOp.\n\n### Usage Example:\nTo enable TunableOp and tuning, and optionally enable verbose mode, run the following commands in your terminal:\n\n```bash\n#Tuning\nPYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=1 your_script.sh\n\n#Inference with tuning op\nPYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_script.sh\n\n#Print out the log\nPYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 PYTORCH_TUNABLEOP_VERBOSE=1 your_script.sh\n\n```\n## 3. Torch Compilation\n\n\nThe following are suggestions for optimizing matrix multiplication (GEMM) and convolution (conv) operations in PyTorch using Inductor, a part of the PyTorch compilation framework. The goal is to leverage Triton to achieve better performance.\n\nTo tune Triton kernels with GEMM and convolution ops (conv), use the `torch.compile` function with the max-autotune mode. 
This benchmarks a predefined list of Triton configurations and selects the fastest one for each shape.\n\n### Key Configurations:\n1. **Max Autotune**:\n   - Set `torch._inductor.config.max_autotune = True` or `TORCHINDUCTOR_MAX_AUTOTUNE=1`.\n\n2. 
**Fine-Grained Control**:\n   - Enable GEMM tuning: `torch._inductor.config.max_autotune_gemm = True`.\n   - Enable tuning for pointwise/reduction ops: `torch._inductor.config.max_autotune_pointwise = True`.\n\n3. **Backend Selection**:\n   - Set `torch._inductor.config.max_autotune_gemm_backends` to TRITON to limit the GEMM backends to Triton for better performance.\n\n4. **Freezing for Inference**:\n   - Use `torch._inductor.config.freezing=True` to enable constant folding optimizations.\n\n5. **Debugging**:\n   - Set `TORCH_COMPILE_DEBUG=1` to extract Triton kernels generated by Inductor.\n\n### Example Code Block:\n```bash\n#Gemm Tuning\nTORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 your_script.sh\n\n#Specify your backend to TRITON for Gemm Tuning\nTORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON your_script.sh\n\n#Inference with large improvement on AMD GPU\nTORCHINDUCTOR_FREEZING=1 your_script.sh\n```\n## 4. Fused MOE kernel\nTo maximize MoE kernel efficiency, use the script below to find the best launch configuration.\n\n### Key parameters:\n- **--model**: the MoE model to tune; the script automatically derives d_model, model_intermediate_size, and num_layers from its config.\n- **--tp-size**: the tensor-parallel size of the full model run to simulate, so that the dimension sizes are split by TP correctly.\n- **--batch**: the M dimension of the MoE kernel; for the prefill MoE kernel the value is batch*input_len, and for the decode MoE kernel it is batch.\n- **--dtype**: the computation type.\n\n```bash\n#Tuning\n# For example, for the case \"python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8\", which uses batch size 32, input length 1024, and output length 8, from the MoE kernel's \"--batch\" point of view the prefill batch 
is 32*1024 = 
32768, and the decode batch is 32*1 = 32 (only one output token is generated per decode step).\n# So we can tune the decode MoE kernel with this command:\npython benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch \"32\"\n# and tune the prefill MoE kernel with this command:\npython benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch \"32768\"\n```\n\n## Reference\n\nFor more detailed information on tuning SGLang performance with AMD GPUs, please refer to the following link:\n\n[ROCm Documentation: Triton Kernel Performance Optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#triton-kernel-performance-optimization)\n"
  },
  {
    "path": "3rdparty/amd/tuning/benchmark_moe_rocm.py",
    "content": "import argparse\nimport json\nimport os\nimport sys\n\nimport torch\nimport torch.nn.functional as F\nimport triton\nimport triton.language as tl\nfrom tqdm import tqdm\nfrom transformers import AutoConfig\n\nfrom sglang.srt.layers.moe.fused_moe_triton.fused_moe import (\n    fused_moe,\n    get_config_file_name,\n)\n\npadding_size = 128 if bool(int(os.getenv(\"SGLANG_MOE_PADDING\", \"0\"))) else 0\n\n\ndef main(model, tp_size, dtype: str, batches):\n    method = fused_moe\n\n    for bs in batches:\n        run_grid(int(bs), model=model, method=method, tp_size=tp_size, dtype=dtype)\n\n\ndef prune_configs(M, N, K, configs):\n    pruned_configs = []\n    elemBytes_a = 1  # [DV Note] Hard-coded for float16 (2 bytes)\n    elemBytes_b = 1  # [DV Note] Hard-coded for float16 (2 bytes)\n\n    mfma = 16 if M < 32 or N < 32 else 32\n\n    # TODO (zhanglx): figure out the boundary between large and small gemms\n    large_gemm = False\n    if M >= 2048 and N >= 2048:\n        large_gemm = True\n\n    for config in configs:\n        BLOCK_SIZE_M = config.get(\"BLOCK_SIZE_M\")\n        BLOCK_SIZE_N = config.get(\"BLOCK_SIZE_N\")\n        BLOCK_SIZE_K = config.get(\"BLOCK_SIZE_K\")\n        num_warps = config.get(\"num_warps\")\n        matrix_instr_nonkdim = config.get(\"matrix_instr_nonkdim\")\n        # kpack = config.get(\"kpack\")\n        if matrix_instr_nonkdim > mfma:\n            continue\n        if mfma == 4 and BLOCK_SIZE_K < 64:\n            continue\n        # some layouts could not work properly in case\n        # the number of elements per thread is less than 1\n        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:\n            continue\n        SPLIT_K = 1  # config.get(\"SPLIT_K\")\n        GROUP_M = config.get(\"GROUP_SIZE_M\")\n        if matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N:\n            continue\n        if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:\n            continue\n        if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:\n            continue\n        # Skip BLOCK_SIZE that is too large compared to M/N\n        # unless BLOCK_SIZE is already small enough\n        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:\n            continue\n        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:\n            continue\n        # skip large split_k when not necessary\n        if SPLIT_K != 1 and not need_split_k(M, N, K):\n            continue\n        # skip split_k that leads to EVEN_K = false\n        leap = SPLIT_K * BLOCK_SIZE_K\n        modv = K % leap\n        if modv != 0:\n            continue\n        # skip large GROUP_M\n        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:\n            continue\n        # out of shared memory resource\n        # TODO (zhanglx): This does not consider the LDS usage in the epilogue\n        LDS = (\n            BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a\n            + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b\n        )\n        if LDS > 65536:\n            continue\n        # Skip small block sizes and num_warps for large gemm\n        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64\n        if large_gemm:\n            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:\n                continue\n            if BLOCK_SIZE_K < 64:\n                continue\n            if num_warps < 4:\n                continue\n\n        pruned_configs.append(config)\n\n    return pruned_configs\n\n\ndef union_of_list_of_dicts(l1, l2):\n    result = []\n    temp_list = l1.copy()\n    temp_list.extend(l2)\n    for myDict in temp_list:\n        if myDict not in result:\n            result.append(myDict)\n\n    return result\n\n\ndef run_grid(bs, model, method, tp_size, dtype: str):\n\n    config = AutoConfig.from_pretrained(model)\n\n    top_k = config.num_experts_per_tok\n    d_model = config.hidden_size\n    model_intermediate_size = config.intermediate_size\n    num_layers = 
config.num_hidden_layers\n    hidden_states_dtype = config.torch_dtype\n\n    if config.num_experts_per_tok:\n        if config.architectures[0] == \"Grok1ModelForCausalLM\":\n            num_total_experts = config.num_experts\n        else:\n            num_total_experts = config.num_local_experts\n    else:\n        raise ValueError(f\"Unsupported Mixtral model {model}\")\n\n    # tp_size = 2\n    num_warmup_calls = 10\n    num_calls = 30\n\n    num_warmup_trials = 1\n    num_trials = 1\n\n    full_configs = []\n\n    block_m_range = [16, 32, 64, 128, 256]\n    block_n_range = [16, 32, 64, 128, 256]\n    block_k_range = [32, 64, 128, 256]  # MUST >= 32\n    num_warps_range = [1, 2, 4, 8]\n    group_m_range = [1, 4, 8, 16, 32]\n    # For now we see better perf with num_stages=0 for all gemm configs we care\n    # But keep this explicit so that we do not forget we may need to set it to\n    # other values in the future\n    num_stage_range = [2]\n    waves_per_eu_range = [0, 1, 2, 4, 8]\n    # Remove 32 because of triton compiling error\n    matrix_instr_nonkdim_range = [16]\n    kpack_range = [1, 2]\n\n    for block_size_m in block_m_range:\n        for block_size_n in block_n_range:\n            for block_size_k in block_k_range:\n                for group_size_m in group_m_range:\n                    for num_warps in num_warps_range:\n                        for num_stages in num_stage_range:\n                            for waves_per_eu in waves_per_eu_range:\n                                for matrix_instr_nonkdim in matrix_instr_nonkdim_range:\n                                    for kpack in kpack_range:\n                                        full_configs.append(\n                                            {\n                                                \"BLOCK_SIZE_M\": block_size_m,\n                                                \"BLOCK_SIZE_N\": block_size_n,\n                                                \"BLOCK_SIZE_K\": block_size_k,\n       
                                         \"GROUP_SIZE_M\": group_size_m,\n                                                \"num_warps\": num_warps,\n                                                \"num_stages\": num_stages,\n                                                \"waves_per_eu\": waves_per_eu,\n                                                \"matrix_instr_nonkdim\": matrix_instr_nonkdim,\n                                                \"kpack\": kpack,\n                                            }\n                                        )\n\n    M1 = bs * 2\n    N1 = model_intermediate_size * 2 // tp_size\n    K1 = d_model\n    prune_configs_1 = prune_configs(M1, N1, K1, full_configs)\n\n    M2 = bs * 2\n    N2 = d_model\n    K2 = model_intermediate_size // tp_size\n    prune_configs_2 = prune_configs(M2, N2, K2, full_configs)\n\n    configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)\n\n    print(f\"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \\\n            {len(prune_configs_2)=} | {len(configs)=}\")\n\n    best_config = None\n    best_time_us = 1e20\n\n    print(f\"{tp_size=} {bs=}\")\n\n    for config in tqdm(configs):\n        # warmup\n        try:\n            print(config)\n            for _ in range(num_warmup_trials):\n                run_timing(\n                    num_calls=num_warmup_calls,\n                    bs=bs,\n                    d_model=d_model,\n                    num_total_experts=num_total_experts,\n                    top_k=top_k,\n                    tp_size=tp_size,\n                    model_intermediate_size=model_intermediate_size,\n                    method=method,\n                    config=config,\n                    dtype=dtype,\n                    hidden_states_dtype=hidden_states_dtype,\n                )\n        except triton.runtime.autotuner.OutOfResources:\n            continue\n\n        # trial\n        for _ in range(num_trials):\n            kernel_dur_ms = 
run_timing(\n                num_calls=num_calls,\n                bs=bs,\n                d_model=d_model,\n                num_total_experts=num_total_experts,\n                top_k=top_k,\n                tp_size=tp_size,\n                model_intermediate_size=model_intermediate_size,\n                method=method,\n                config=config,\n                dtype=dtype,\n                hidden_states_dtype=hidden_states_dtype,\n            )\n\n            kernel_dur_us = 1000 * kernel_dur_ms\n            model_dur_ms = kernel_dur_ms * num_layers\n\n            if kernel_dur_us < best_time_us:\n                best_config = config\n                best_time_us = kernel_dur_us\n\n                tqdm.write(\n                    f\"{kernel_dur_us=:.1f} {model_dur_ms=:.1f}\"\n                    f\" {bs=} {tp_size=} {top_k=} {num_total_experts=} \"\n                    f\"{d_model=} {model_intermediate_size=} {num_layers=}\"\n                )\n\n    print(\"best_time_us\", best_time_us)\n    print(\"best_config\", best_config)\n\n    # holds Dict[str, Dict[str, int]]\n    filename = get_config_file_name(\n        num_total_experts,\n        model_intermediate_size // tp_size,\n        \"float8\" if dtype == \"float8\" else None,\n    )\n    print(f\"writing config to file {filename}\")\n    existing_content = {}\n    if os.path.exists(filename):\n        with open(filename, \"r\") as f:\n            existing_content = json.load(f)\n    existing_content[str(bs)] = best_config\n    with open(filename, \"w\") as f:\n        json.dump(existing_content, f, indent=4)\n        f.write(\"\\n\")\n\n\ndef run_timing(\n    num_calls: int,\n    bs: int,\n    d_model: int,\n    num_total_experts: int,\n    top_k: int,\n    tp_size: int,\n    model_intermediate_size: int,\n    method,\n    config,\n    dtype: str,\n    hidden_states_dtype,\n) -> float:\n    shard_intermediate_size = model_intermediate_size // tp_size\n\n    hidden_states = torch.rand(\n        (bs, 
d_model),\n        device=\"cuda:0\",\n        dtype=hidden_states_dtype,\n    )\n\n    w1 = torch.rand(\n        (num_total_experts, 2 * shard_intermediate_size, d_model + padding_size),\n        device=hidden_states.device,\n        dtype=hidden_states.dtype,\n    )\n\n    w2 = torch.rand(\n        (num_total_experts, d_model, shard_intermediate_size + padding_size),\n        device=hidden_states.device,\n        dtype=hidden_states.dtype,\n    )\n\n    w1_scale = None\n    w2_scale = None\n    a1_scale = None\n    a2_scale = None\n\n    if dtype == \"float8\":\n        w1 = w1.to(torch.float8_e4m3fnuz)\n        w2 = w2.to(torch.float8_e4m3fnuz)\n        w1_scale = torch.ones(\n            num_total_experts, device=hidden_states.device, dtype=torch.float32\n        )\n        w2_scale = torch.ones(\n            num_total_experts, device=hidden_states.device, dtype=torch.float32\n        )\n        a1_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)\n        a2_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)\n\n    gating_output = F.softmax(\n        torch.rand(\n            (num_calls, bs, num_total_experts),\n            device=hidden_states.device,\n            dtype=torch.float32,\n        ),\n        dim=-1,\n    )\n\n    ##################################\n\n    start_event = torch.cuda.Event(enable_timing=True)\n    end_event = torch.cuda.Event(enable_timing=True)\n\n    start_event.record()\n    for i in range(num_calls):\n        hidden_states = method(\n            hidden_states=hidden_states,\n            w1=w1,\n            w2=w2,\n            w1_scale=w1_scale,\n            w2_scale=w2_scale,\n            a1_scale=a1_scale,\n            a2_scale=a2_scale,\n            gating_output=gating_output[0],\n            topk=top_k,\n            renormalize=True,\n            inplace=True,\n            override_config=config,\n            use_fp8=dtype == \"float8\",\n        )\n\n    end_event.record()\n    
end_event.synchronize()\n\n    dur_ms = start_event.elapsed_time(end_event) / num_calls\n    return dur_ms\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        prog=\"benchmark_mixtral_moe\",\n        description=\"Benchmark and tune the fused_moe kernel\",\n    )\n    parser.add_argument(\n        \"--dtype\",\n        type=str,\n        default=\"auto\",\n        choices=[\"float8\", \"float16\", \"bfloat16\"],\n        help=\"Data type used for fused_moe kernel computations\",\n    )\n    parser.add_argument(\"--model\", type=str, default=\"hpcai-tech/grok-1\")\n\n    parser.add_argument(\"--tp-size\", type=int, default=2, help=\"Tensor parallel size\")\n    parser.add_argument(\n        \"-b\",\n        \"--batches\",\n        type=str,\n        required=True,\n        help=\"Comma-separated list of batch sizes\",\n    )\n\n    args = parser.parse_args()\n\n    batches = args.batches.split(\",\")\n\n    sys.exit(main(args.model, args.tp_size, args.dtype, batches))\n"
  },
  {
    "path": "3rdparty/amd/wheel/README.md",
    "content": "# sglang-kernel (formerly sgl-kernel)\n\nBuilding and releasing `sglang-kernel` as a wheel is part of the release workflow. Check [release-whl-kernel.yml](https://github.com/sgl-project/sglang/blob/main/.github/workflows/release-whl-kernel.yml) for details.\n\n# sglang\n\n`3rdparty/amd/wheel/sglang/pyproject.toml` is the AMD-specific pyproject for building the `amd-sglang` wheel. It extends `python/pyproject_other.toml` with two ROCm-version extras (`rocm700`, `rocm720`) that pin the matching torch/triton/torchaudio/torchvision/`sglang-kernel` wheels, and renames the package to `amd-sglang`.\n\n## Building the sglang wheel\n\n```\n$ git clone https://github.com/sgl-project/sglang.git && cd sglang\n$ cp 3rdparty/amd/wheel/sglang/pyproject.toml python/pyproject.toml\n$ cd python && python -m build\n```\n\n## Installation\n\n### v0.5.9\n\nROCm 7.0.0:\n```\npip uninstall sglang-kernel sglang amd-sglang\npip install \"amd-sglang[all-hip,rocm700]\" -i https://pypi.amd.com/rocm-7.0.0/simple --extra-index-url https://pypi.org/simple\n```\n\nROCm 7.2.0:\n```\npip uninstall sglang-kernel sglang amd-sglang\npip install \"amd-sglang[all-hip,rocm720]\" -i https://pypi.amd.com/rocm-7.2.0/simple --extra-index-url https://pypi.org/simple\n```\n\nNote: You must resolve the two dependencies below, AITER and triton. The others are optional, depending on your application.\n\n## Manual Dependency Resolution\n\n### Resolving AITER\n\n[AITER](https://github.com/ROCm/aiter) is a fundamental dependency. Packaging it as a wheel is still in progress.\nUntil it can be pinned reliably, install it manually (typically by following the [ROCm docker recipe](https://github.com/sgl-project/sglang/blob/main/docker/rocm.Dockerfile#L106)).\n\n### Resolving triton\n\nTo avoid known issues in the triton 3.5.1 installed by default, we recommend upgrading triton after installation. In a ROCm 7.0.0 environment:\n```\npip install triton==3.6.0\n```\nIn a ROCm 7.2.0 environment:\n```\npip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.6.0%2Brocm7.2.0.gitba5c1517-cp310-cp310-linux_x86_64.whl\n```\n\n#### `torch._inductor.exc.InductorError: AttributeError: 'KernelMetadata' object has no attribute 'cluster_dims'`\n\nAfter upgrading, you may hit this error during inference when PyTorch Inductor interacts with Triton metadata.\n\nA pragmatic workaround is to guard the metadata access in Inductor's Triton heuristics so it only reads `cluster_dims` when the attribute exists:\n\n```diff\n--- a/opt/venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py\n+++ b/opt/venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py\n@@ -1759,6 +1759,8 @@\n                 else (\n                     (binary.metadata.num_ctas, *binary.metadata.cluster_dims)\n                     if hasattr(binary, \"metadata\")\n+                    and hasattr(binary.metadata, \"num_ctas\")\n+                    and hasattr(binary.metadata, \"cluster_dims\")\n                     else ()\n                 )\n             ),\n```\n\n### Resolving Dependencies for Distributed Inference\n\n#### sgl-model-gateway\n\nInstall sgl-model-gateway as follows:\n\n```\n$ apt install openssl libssl-dev protobuf\n$ export PATH=\"/$HOME/.cargo/bin:${PATH}\" \\\n  && curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \\\n  && rustc --version && cargo --version # Prepare for a rust toolchain\n$ python3 -m pip install --no-cache-dir setuptools-rust \\\n  && cd /sgl-workspace/sglang/sgl-model-gateway/bindings/python \\\n  && cargo build --release \\\n  && python3 -m pip install --no-cache-dir . \\\n  && rm -rf /root/.cache # Build and install sgl-model-gateway\n```\n\n#### [Mori](https://github.com/sgl-project/sglang/blob/main/docker/rocm.Dockerfile#L381)\n\n### Resolving Dependencies for DeepSeek-V3.2\n\n#### [TileLang](https://github.com/sgl-project/sglang/blob/main/docker/rocm.Dockerfile#L216)\n\n#### [FHT (fast-hadamard-transform)](https://github.com/sgl-project/sglang/blob/main/docker/rocm.Dockerfile#L300)\n"
  },
  {
    "path": "3rdparty/amd/wheel/sgl-kernel/CMakeLists_rocm.txt",
    "content": "cmake_minimum_required(VERSION 3.24 FATAL_ERROR)\nproject(sgl_kernel LANGUAGES CXX)\n\n# Cmake\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_CXX_STANDARD_REQUIRED ON)\nset(CMAKE_POSITION_INDEPENDENT_CODE ON)\nset(CMAKE_SHARED_LIBRARY_PREFIX \"\")\n\nset(CMAKE_COLOR_DIAGNOSTICS ON)\nset(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL \"ON\")\n\n# Python / Torch\nfind_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)\n\nexecute_process(\n  COMMAND ${Python_EXECUTABLE} -c \"import torch; print(torch.utils.cmake_prefix_path)\"\n  OUTPUT_VARIABLE TORCH_PY_PREFIX\n  OUTPUT_STRIP_TRAILING_WHITESPACE\n)\n\nset(Torch_DIR \"${TORCH_PY_PREFIX}/Torch\")\nlist(APPEND CMAKE_PREFIX_PATH \"${TORCH_PY_PREFIX}/Torch\")\nfind_package(Torch REQUIRED)\n\nexecute_process(\n  COMMAND ${Python_EXECUTABLE} -c \"import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))\"\n  OUTPUT_VARIABLE TORCH_CXX11_ABI\n  OUTPUT_STRIP_TRAILING_WHITESPACE\n)\nif(TORCH_CXX11_ABI STREQUAL \"0\")\n  add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)\nelse()\n  add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)\nendif()\n\n# ROCm/HIP\nenable_language(HIP)\nlist(APPEND CMAKE_PREFIX_PATH \"/opt/rocm/lib/cmake/hip-lang\")\nfind_package(hip REQUIRED CONFIG)\n\n# Determine AMDGPU target from environment variable or default to gfx942\nset(AMDGPU_TARGET_ENV \"$ENV{AMDGPU_TARGET}\")\n\nif(AMDGPU_TARGET_ENV)\n  # Use environment variable if specified\n  set(AMDGPU_TARGETS \"${AMDGPU_TARGET_ENV}\")\n  message(STATUS \"Using AMDGPU_TARGET from environment: ${AMDGPU_TARGETS}\")\nelse()\n  # Default to gfx942 only\n  set(AMDGPU_TARGETS \"gfx942\")\n  message(STATUS \"AMDGPU_TARGET not set, defaulting to gfx942\")\nendif()\n\n# Set HIP architectures\nset(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})\n\n# FP8 macro selection\n# Always define HIP_FP8_TYPE_FNUZ=1 (for gfx942 and host compilation)\n# Additionally define HIP_FP8_TYPE_E4M3=1 when building for gfx950\n# The existing utils.h 
logic will pick the right one based on architecture\nset(SGL_FP8_MACROS \"-DHIP_FP8_TYPE_FNUZ=1\")\n\nif(AMDGPU_TARGETS MATCHES \"gfx950\")\n  list(APPEND SGL_FP8_MACROS \"-DHIP_FP8_TYPE_E4M3=1\")\n  message(STATUS \"Multi-arch build: Enabling both HIP_FP8_TYPE_FNUZ (gfx942) and HIP_FP8_TYPE_E4M3 (gfx950)\")\nelseif(AMDGPU_TARGETS MATCHES \"gfx942\")\n  message(STATUS \"Single-arch build: Enabling HIP_FP8_TYPE_FNUZ for gfx942\")\nelse()\n  message(FATAL_ERROR \"Unsupported AMDGPU_TARGET '${AMDGPU_TARGETS}'. Expected 'gfx942' or 'gfx950' or both.\")\nendif()\n\n# TopK dynamic smem bytes\n# Dynamic shared-memory budget for the TopK kernels.\n# - gfx942 (MI300/MI325): LDS is typically 64KB per workgroup -> keep dynamic smem <= ~48KB\n#   (leaves room for static shared allocations in the kernel).\n# - gfx95x (MI350): LDS is larger (e.g. 160KB per CU) -> allow the original 128KB dynamic smem.\nif(AMDGPU_TARGET_ONE STREQUAL \"gfx942\")\n  math(EXPR SGL_TOPK_DYNAMIC_SMEM_BYTES \"48 * 1024\")\nelse()\n  math(EXPR SGL_TOPK_DYNAMIC_SMEM_BYTES \"32 * 1024 * 4\")\nendif()\n\nset(SGL_TOPK_MACROS \"-DSGL_TOPK_DYNAMIC_SMEM_BYTES=${SGL_TOPK_DYNAMIC_SMEM_BYTES}\")\n\n# Paths / includes\nset(PROJ_ROOT ${CMAKE_CURRENT_LIST_DIR})\nset(SGL_INCLUDE_DIRS\n  ${PROJ_ROOT}/include\n  ${PROJ_ROOT}/include/impl\n  ${PROJ_ROOT}/csrc\n  ${TORCH_INCLUDE_DIRS}\n)\n\n# Platform-specific library directory\nset(PLAT_LIB_DIR \"/usr/lib/x86_64-linux-gnu\")\nlink_directories(${PLAT_LIB_DIR})\n\n# 
Sources\nset(SOURCES\n${PROJ_ROOT}/csrc/allreduce/custom_all_reduce.hip\n${PROJ_ROOT}/csrc/allreduce/deterministic_all_reduce.hip\n${PROJ_ROOT}/csrc/allreduce/quick_all_reduce.hip\n${PROJ_ROOT}/csrc/common_extension_rocm.cc\n${PROJ_ROOT}/csrc/elementwise/activation.hip\n${PROJ_ROOT}/csrc/elementwise/pos_enc.hip\n${PROJ_ROOT}/csrc/elementwise/topk.hip\n${PROJ_ROOT}/csrc/grammar/apply_token_bitmask_inplace_hip.hip\n${PROJ_ROOT}/csrc/kvcacheio/transfer.hip\n${PROJ_ROOT}/csrc/memory/weak_ref_tensor.cpp\n${PROJ_ROOT}/csrc/moe/moe_align_kernel.hip\n${PROJ_ROOT}/csrc/moe/moe_topk_softmax_kernels.hip\n${PROJ_ROOT}/csrc/moe/moe_topk_sigmoid_kernels.hip\n${PROJ_ROOT}/csrc/speculative/eagle_utils.hip\n)\nset_source_files_properties(\n  ${SOURCES}\n  PROPERTIES\n    LANGUAGE HIP\n)\n\n# Compile / Link flags\nadd_compile_options($<$<COMPILE_LANGUAGE:CXX>:-O3>)\n\nset(SGL_HIP_FLAGS\n  -DNDEBUG\n  -DOPERATOR_NAMESPACE=sgl_kernel\n  -O3\n  -std=c++17\n  -DENABLE_BF16\n  -DENABLE_FP8\n  ${SGL_FP8_MACROS}\n  -Wno-pass-failed\n  -Wundefined-internal\n  ${SGL_TOPK_MACROS}\n)\n\n# Python extension\nPython_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})\ntarget_include_directories(common_ops PRIVATE ${SGL_INCLUDE_DIRS})\n\n# Apply per-language flags\ntarget_compile_options(common_ops PRIVATE\n  $<$<COMPILE_LANGUAGE:HIP>:${SGL_HIP_FLAGS}>\n)\n\ntarget_link_libraries(common_ops PRIVATE\n  ${TORCH_LIBRARIES}\n  hip::device\n  hip::host\n  hiprtc\n  amdhip64\n)\n\ntarget_link_options(common_ops PRIVATE\n  \"SHELL:-Wl,-rpath,'\\$ORIGIN/../../torch/lib'\"\n)\n\ninstall(TARGETS common_ops\n  LIBRARY DESTINATION sgl_kernel\n)\n"
  },
  {
    "path": "3rdparty/amd/wheel/sgl-kernel/build_rocm.sh",
    "content": "#!/bin/bash\nset -euo pipefail\n\nROCM_VERSION=${1:-}\n\nif [[ \"${ROCM_VERSION}\" == \"700\" ]]; then\n  IMAGE=\"lmsysorg/sglang:v0.5.8.post1-rocm700-mi35x\"\nelif [[ \"${ROCM_VERSION}\" == \"720\" ]]; then\n  IMAGE=\"rocm/pytorch:rocm7.2_ubuntu22.04_py3.10_pytorch_release_2.9.1\"\nelse\n  echo \"ERROR: Unsupported ROCM_VERSION='${ROCM_VERSION}'. Only '700' and '720' are supported.\" >&2\n  exit 1\nfi\n\nPYTHON_ROOT_PATH=\"/opt/venv/bin\"\nAMDGPU_TARGET=\"gfx942;gfx950\"\n\n# Pull and run the latest image\necho \"Pulling Docker image: ${IMAGE}\"\ndocker pull \"${IMAGE}\"\n\ndocker run --rm \\\n  -v $(pwd):/sgl-kernel \\\n  -e AMDGPU_TARGET=\"${AMDGPU_TARGET}\" \\\n  -e PYTORCH_ROCM_ARCH=\"${AMDGPU_TARGET}\" \\\n  ${IMAGE} \\\n  bash -c \"\n  # Install torch, triton, and friends, depending on the ROCm version\n  if [[ \"${ROCM_VERSION}\" == \"700\" ]]; then\n    ${PYTHON_ROOT_PATH}/pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torch-2.9.1.dev20251204%2Brocm7.0.2.lw.git351ff442-cp310-cp310-linux_x86_64.whl https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/triton-3.5.1%2Brocm7.0.2.gita272dfa8-cp310-cp310-linux_x86_64.whl https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torchaudio-2.9.0%2Brocm7.0.2.gite3c6ee2b-cp310-cp310-linux_x86_64.whl https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torchvision-0.24.0%2Brocm7.0.2.gitb919bd0c-cp310-cp310-linux_x86_64.whl\n  elif [[ \"${ROCM_VERSION}\" == \"720\" ]]; then\n    ${PYTHON_ROOT_PATH}/pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp310-cp310-linux_x86_64.whl https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.5.1%2Brocm7.2.0.gita272dfa8-cp310-cp310-linux_x86_64.whl https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchaudio-2.9.0%2Brocm7.2.0.gite3c6ee2b-cp310-cp310-linux_x86_64.whl 
https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.24.0%2Brocm7.2.0.gitb919bd0c-cp310-cp310-linux_x86_64.whl\n  fi\n  # Install CMake (version >= 3.26) - Robust Installation\n  export CMAKE_VERSION_MAJOR=3.31\n  export CMAKE_VERSION_MINOR=1\n  echo \\\"Downloading CMake from: https://cmake.org/files/v\\${CMAKE_VERSION_MAJOR}/cmake-\\${CMAKE_VERSION_MAJOR}.\\${CMAKE_VERSION_MINOR}-linux-x86_64.tar.gz\\\"\n  wget https://cmake.org/files/v\\${CMAKE_VERSION_MAJOR}/cmake-\\${CMAKE_VERSION_MAJOR}.\\${CMAKE_VERSION_MINOR}-linux-x86_64.tar.gz\n  tar -xzf cmake-\\${CMAKE_VERSION_MAJOR}.\\${CMAKE_VERSION_MINOR}-linux-x86_64.tar.gz\n  mv cmake-\\${CMAKE_VERSION_MAJOR}.\\${CMAKE_VERSION_MINOR}-linux-x86_64 /opt/cmake\n  export PATH=/opt/cmake/bin:\\$PATH\n\n  ${PYTHON_ROOT_PATH}/pip install --no-cache-dir ninja setuptools wheel numpy uv scikit-build-core && \\\n\n  cd /sgl-kernel && \\\n  rm -rf CMakeLists.txt && mv CMakeLists_rocm.txt CMakeLists.txt && \\\n  ${PYTHON_ROOT_PATH}/python rocm_hipify.py && \\\n  ${PYTHON_ROOT_PATH}/python -m uv build --wheel -Cbuild-dir=build . --color=always --no-build-isolation && \\\n  ./rename_wheels_rocm.sh\n\"\n"
  },
  {
    "path": "3rdparty/amd/wheel/sgl-kernel/rename_wheels_rocm.sh",
    "content": "#!/usr/bin/env bash\nset -ex\n\nWHEEL_DIR=\"dist\"\n\nwheel_files=($WHEEL_DIR/*.whl)\nfor wheel in \"${wheel_files[@]}\"; do\n    intermediate_wheel=\"${wheel/linux/manylinux2014}\"\n    [[ \"$intermediate_wheel\" == *\"+rocm\"* ]] && continue\n\n    # Extract the current python version from the wheel name\n    if [[ $intermediate_wheel =~ -cp([0-9]+)- ]]; then\n        cp_version=\"${BASH_REMATCH[1]}\"\n    else\n        echo \"Could not extract Python version from wheel name: $intermediate_wheel\"\n        continue\n    fi\n\n    # Detect ROCm version and add appropriate suffix\n    ver_abrv=$(realpath /opt/rocm-* | sed -e 's/.*-//' -e 's/\\.//g')\n    new_wheel=${intermediate_wheel/-cp${cp_version}/+rocm${ver_abrv}-cp${cp_version}}\n\n    if [[ \"$wheel\" != \"$new_wheel\" ]]; then\n        echo \"Renaming $wheel to $new_wheel\"\n        mv -- \"$wheel\" \"$new_wheel\"\n    fi\ndone\necho \"Wheel renaming completed.\"\n"
  },
  {
    "path": "3rdparty/amd/wheel/sgl-kernel/rocm_hipify.py",
    "content": "from pathlib import Path\n\nimport torch\nfrom torch.utils.cpp_extension import CUDAExtension\n\nroot = Path(__file__).parent.resolve()\n\ninclude_dirs = [\n    root / \"include\",\n    root / \"include\" / \"impl\",\n    root / \"csrc\",\n]\n\nsources = [\n    \"csrc/allreduce/custom_all_reduce.hip\",\n    \"csrc/allreduce/deterministic_all_reduce.hip\",\n    \"csrc/allreduce/quick_all_reduce.cu\",\n    \"csrc/common_extension_rocm.cc\",\n    \"csrc/elementwise/activation.cu\",\n    \"csrc/elementwise/pos_enc.cu\",\n    \"csrc/elementwise/topk.cu\",\n    \"csrc/grammar/apply_token_bitmask_inplace_cuda.cu\",\n    \"csrc/kvcacheio/transfer.cu\",\n    \"csrc/memory/weak_ref_tensor.cpp\",\n    \"csrc/moe/moe_align_kernel.cu\",\n    \"csrc/moe/moe_topk_softmax_kernels.cu\",\n    \"csrc/moe/moe_topk_sigmoid_kernels.cu\",\n    \"csrc/speculative/eagle_utils.cu\",\n]\n\nlibraries = [\"hiprtc\", \"amdhip64\", \"c10\", \"torch\", \"torch_python\"]\n\next_modules = [\n    CUDAExtension(\n        name=\"sgl_kernel.common_ops\",\n        sources=sources,\n        include_dirs=include_dirs,\n        libraries=libraries,\n        py_limited_api=False,\n    ),\n]\n"
  },
  {
    "path": "3rdparty/amd/wheel/sglang/pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0\", \"setuptools-scm>=8.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"amd-sglang\"\ndynamic = [\"version\"]\ndescription = \"SGLang is a fast serving framework for large language models and vision language models.\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nlicense = { file = \"LICENSE\" }\nclassifiers = [\n  \"Programming Language :: Python :: 3\",\n  \"License :: OSI Approved :: Apache Software License\",\n]\ndependencies = [\"aiohttp\", \"requests\", \"tqdm\", \"numpy\", \"IPython\", \"setproctitle\"]\n\n[project.optional-dependencies]\nruntime_common = [\n  \"IPython\",\n  \"aiohttp\",\n  \"anthropic>=0.20.0\",\n  \"blobfile==3.0.0\",\n  \"build\",\n  \"compressed-tensors\",\n  \"decord2\",\n  \"datasets\",\n  \"einops\",\n  \"fastapi\",\n  \"gguf\",\n  \"hf_transfer\",\n  \"huggingface_hub\",\n  \"interegular\",\n  \"llguidance>=0.7.11,<0.8.0\",\n  \"modelscope\",\n  \"msgspec\",\n  \"ninja\",\n  \"numpy\",\n  \"openai-harmony==0.0.4\",\n  \"openai==2.6.1\",\n  \"orjson\",\n  \"outlines==0.1.11\",\n  \"packaging\",\n  \"partial_json_parser\",\n  \"pillow\",\n  \"prometheus-client>=0.20.0\",\n  \"psutil\",\n  \"py-spy\",\n  \"pybase64\",\n  \"pydantic\",\n  \"python-multipart\",\n  \"pyzmq>=25.1.2\",\n  \"requests\",\n  \"scipy\",\n  \"sentencepiece\",\n  \"setproctitle\",\n  \"soundfile==0.13.1\",\n  \"tiktoken\",\n  \"timm==1.0.16\",\n  \"torchao==0.9.0\",\n  \"tqdm\",\n  \"transformers==4.57.1\",\n  \"uvicorn\",\n  \"uvloop\",\n  \"xgrammar==0.1.27\",\n  \"smg-grpc-servicer>=0.5.0\",\n]\n\n# ROCm specific packages (https://repo.radeon.com/rocm/manylinux/)\n# Existing practice for daily rocm700 docker images relies on 700-rc\n# versions of software that are not publicly available. Here we pin packages\n# from rocm702 as the closest match to the daily rocm700 images.\nrocm700 = [\n  \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torch-2.9.1.dev20251204%2Brocm7.0.2.lw.git351ff442-cp310-cp310-linux_x86_64.whl\",\n  \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/triton-3.5.1%2Brocm7.0.2.gita272dfa8-cp310-cp310-linux_x86_64.whl\",\n  \"torchaudio @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torchaudio-2.9.0%2Brocm7.0.2.gite3c6ee2b-cp310-cp310-linux_x86_64.whl\",\n  \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torchvision-0.24.0%2Brocm7.0.2.gitb919bd0c-cp310-cp310-linux_x86_64.whl\",\n  \"mooncake-transfer-engine-non-cuda==0.3.8.post1\",\n  \"sglang-kernel @ https://github.com/sgl-project/whl/releases/download/v0.4.0/sglang_kernel-0.4.0+rocm700-cp310-abi3-manylinux2014_x86_64.whl\",\n]\n\nrocm720 = [\n  \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp310-cp310-linux_x86_64.whl\",\n  \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.5.1%2Brocm7.2.0.gita272dfa8-cp310-cp310-linux_x86_64.whl\",\n  \"torchaudio @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchaudio-2.9.0%2Brocm7.2.0.gite3c6ee2b-cp310-cp310-linux_x86_64.whl\",\n  \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.24.0%2Brocm7.2.0.gitb919bd0c-cp310-cp310-linux_x86_64.whl\",\n  \"mooncake-transfer-engine-non-cuda==0.3.8.post1\",\n  \"sglang-kernel @ https://github.com/sgl-project/whl/releases/download/v0.4.0/sglang_kernel-0.4.0+rocm720-cp310-abi3-manylinux2014_x86_64.whl\",\n]\n\n# HIP (Heterogeneous-computing Interface for Portability) for AMD\n# Install with one of:\n#   pip install \"amd-sglang[srt_hip,rocm700]\"\n#   pip install \"amd-sglang[srt_hip,rocm720]\"\nsrt_hip = [\n  \"amd-sglang[runtime_common]\",\n  \"petit_kernel==0.0.2\",\n  \"wave-lang==3.8.2\",\n]\n\ndiffusion_hip = [\n  \"PyYAML==6.0.1\",\n  
\"cloudpickle\",\n  \"diffusers==0.37.0\",\n  \"imageio==2.36.0\",\n  \"imageio-ffmpeg==0.5.1\",\n  \"moviepy>=2.0.0\",\n  \"opencv-python-headless==4.10.0.84\",\n  \"remote-pdb\",\n  \"st_attn==0.0.7\",\n  \"vsa==0.0.4\",\n  \"runai_model_streamer>=0.15.5\",\n  \"cache-dit==1.1.8\",\n  \"addict\",\n]\n\n# For Intel Gaudi(device : hpu) follow the installation guide\n# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html\nsrt_hpu = [\"sglang[runtime_common]\"]\n\n# https://docs.sglang.io/platforms/mthreads_gpu.md\nsrt_musa = [\n  \"sglang[runtime_common]\",\n  \"torch\",\n  \"torch_musa\",\n  \"torchada>=0.1.25\",\n  \"mthreads-ml-py\",\n  \"numpy<2.0\",\n]\n\ndiffusion_musa = [\n  \"PyYAML==6.0.1\",\n  \"cloudpickle\",\n  \"diffusers==0.37.0\",\n  \"imageio==2.36.0\",\n  \"imageio-ffmpeg==0.5.1\",\n  \"moviepy>=2.0.0\",\n  \"opencv-python-headless==4.10.0.84\",\n  \"remote-pdb\",\n  \"st_attn==0.0.7\",\n  \"vsa==0.0.4\",\n  \"runai_model_streamer>=0.15.5\",\n  \"cache-dit==1.1.8\",\n  \"addict\",\n]\n\ntracing = [\n  \"opentelemetry-sdk\",\n  \"opentelemetry-api\",\n  \"opentelemetry-exporter-otlp\",\n  \"opentelemetry-exporter-otlp-proto-grpc\",\n]\n\ntest = [\n  \"accelerate\",\n  \"expecttest\",\n  \"gguf\",\n  \"jsonlines\",\n  \"matplotlib\",\n  \"pandas\",\n  \"peft\",\n  \"pytest\",\n  \"sentence_transformers\",\n  \"tabulate\",\n]\n\nall_hip = [\"amd-sglang[srt_hip]\", \"amd-sglang[diffusion_hip]\"]\nall_hpu = [\"sglang[srt_hpu]\"]\nall_musa = [\"sglang[srt_musa]\", \"sglang[diffusion_musa]\"]\n\ndev_hip = [\"amd-sglang[all_hip]\", \"amd-sglang[test]\"]\ndev_hpu = [\"sglang[all_hpu]\", \"sglang[test]\"]\ndev_musa = [\"sglang[all_musa]\", \"sglang[test]\"]\n\n[project.urls]\n\"Homepage\" = \"https://github.com/sgl-project/sglang\"\n\"Bug Tracker\" = \"https://github.com/sgl-project/sglang/issues\"\n\n[project.scripts]\nsglang = \"sglang.cli.main:main\"\n\n[tool.setuptools.package-data]\n\"sglang\" = [\n  \"srt/**/*\",\n  
\"jit_kernel/**/*\",\n]\n\n[tool.setuptools.packages.find]\nexclude = [\n  \"assets*\",\n  \"benchmark*\",\n  \"docs*\",\n  \"dist*\",\n  \"playground*\",\n  \"scripts*\",\n  \"tests*\",\n]\n\n[tool.wheel]\nexclude = [\n  \"assets*\",\n  \"benchmark*\",\n  \"docs*\",\n  \"dist*\",\n  \"playground*\",\n  \"scripts*\",\n  \"tests*\",\n]\n\n[tool.setuptools_scm]\nroot = \"..\"\nversion_file = \"sglang/_version.py\"\ngit_describe_command = [\"git\", \"describe\", \"--tags\", \"--long\", \"--match\", \"v*\"]\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2023-2024 SGLang Team\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\" id=\"sglangtop\">\n<img src=\"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png\" alt=\"logo\" width=\"400\" margin=\"10px\"></img>\n\n[![PyPI](https://img.shields.io/pypi/v/sglang)](https://pypi.org/project/sglang)\n![PyPI - Downloads](https://static.pepy.tech/badge/sglang?period=month)\n[![license](https://img.shields.io/github/license/sgl-project/sglang.svg)](https://github.com/sgl-project/sglang/tree/main/LICENSE)\n[![issue resolution](https://img.shields.io/github/issues-closed-raw/sgl-project/sglang)](https://github.com/sgl-project/sglang/issues)\n[![open issues](https://img.shields.io/github/issues-raw/sgl-project/sglang)](https://github.com/sgl-project/sglang/issues)\n[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/sgl-project/sglang)\n\n</div>\n\n--------------------------------------------------------------------------------\n\n<p align=\"center\">\n<a href=\"https://lmsys.org/blog/\"><b>Blog</b></a> |\n<a href=\"https://docs.sglang.io/\"><b>Documentation</b></a> |\n<a href=\"https://roadmap.sglang.io/\"><b>Roadmap</b></a> |\n<a href=\"https://slack.sglang.io/\"><b>Join Slack</b></a> |\n<a href=\"https://meet.sglang.io/\"><b>Weekly Dev Meeting</b></a> |\n<a href=\"https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides\"><b>Slides</b></a>\n</p>\n\n## News\n- [2026/02] 🔥 Unlocking 25x Inference Performance with SGLang on NVIDIA GB300 NVL72 ([blog](https://lmsys.org/blog/2026-02-20-gb300-inferencex/)).\n- [2026/01] 🔥 SGLang Diffusion accelerates video and image generation ([blog](https://lmsys.org/blog/2026-01-16-sglang-diffusion/)).\n- [2025/12] SGLang provides day-0 support for the latest open models ([MiMo-V2-Flash](https://lmsys.org/blog/2025-12-16-mimo-v2-flash/), [Nemotron 3 Nano](https://lmsys.org/blog/2025-12-15-run-nvidia-nemotron-3-nano/), [Mistral Large 3](https://github.com/sgl-project/sglang/pull/14213), [LLaDA 2.0 Diffusion 
LLM](https://lmsys.org/blog/2025-12-19-diffusion-llm/), [MiniMax M2](https://lmsys.org/blog/2025-11-04-miminmax-m2/)).\n- [2025/10] 🔥 SGLang now runs natively on TPU with the SGLang-Jax backend ([blog](https://lmsys.org/blog/2025-10-29-sglang-jax/)).\n- [2025/09] Deploying DeepSeek on GB200 NVL72 with PD and Large Scale EP (Part II): 3.8x Prefill, 4.8x Decode Throughput ([blog](https://lmsys.org/blog/2025-09-25-gb200-part-2/)).\n- [2025/09] SGLang Day 0 Support for DeepSeek-V3.2 with Sparse Attention ([blog](https://lmsys.org/blog/2025-09-29-deepseek-V32/)).\n- [2025/08] SGLang x AMD SF Meetup on 8/22: Hands-on GPU workshop, tech talks by AMD/xAI/SGLang, and networking ([Roadmap](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_sglang_roadmap.pdf), [Large-scale EP](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_sglang_ep.pdf), [Highlights](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_highlights.pdf), [AITER/MoRI](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_aiter_mori.pdf), [Wave](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_wave.pdf)).\n\n<details>\n<summary>More</summary>\n\n- [2025/11] SGLang Diffusion accelerates video and image generation ([blog](https://lmsys.org/blog/2025-11-07-sglang-diffusion/)).\n- [2025/10] PyTorch Conference 2025 SGLang Talk ([slide](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/sglang_pytorch_2025.pdf)).\n- [2025/10] SGLang x Nvidia SF Meetup on 10/2 ([recap](https://x.com/lmsysorg/status/1975339501934510231)).\n- [2025/08] SGLang provides day-0 support for OpenAI gpt-oss model ([instructions](https://github.com/sgl-project/sglang/issues/8833))\n- [2025/06] SGLang, the high-performance serving infrastructure powering trillions of tokens daily, has been awarded the third batch of the Open Source AI Grant by a16z ([a16z 
blog](https://a16z.com/advancing-open-source-ai-through-benchmarks-and-bold-experimentation/)).\n- [2025/05] Deploying DeepSeek with PD Disaggregation and Large-scale Expert Parallelism on 96 H100 GPUs ([blog](https://lmsys.org/blog/2025-05-05-large-scale-ep/)).\n- [2025/06] Deploying DeepSeek on GB200 NVL72 with PD and Large Scale EP (Part I): 2.7x Higher Decoding Throughput ([blog](https://lmsys.org/blog/2025-06-16-gb200-part-1/)).\n- [2025/03] Supercharge DeepSeek-R1 Inference on AMD Instinct MI300X ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1-Part2/README.html))\n- [2025/03] SGLang Joins PyTorch Ecosystem: Efficient LLM Serving Engine ([PyTorch blog](https://pytorch.org/blog/sglang-joins-pytorch/))\n- [2025/02] Unlock DeepSeek-R1 Inference Performance on AMD Instinct™ MI300X GPU ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1_Perf/README.html))\n- [2025/01] SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html), [10+ other companies](https://x.com/lmsysorg/status/1887262321636221412))\n- [2024/12] v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).\n- [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).\n- [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).\n- [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. 
TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).\n- [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).\n- [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).\n- [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).\n\n</details>\n\n## About\nSGLang is a high-performance serving framework for large language models and multimodal models.\nIt is designed to deliver low-latency and high-throughput inference across a wide range of setups, from a single GPU to large distributed clusters.\nIts core features include:\n\n- **Fast Runtime**: Provides efficient serving with RadixAttention for prefix caching, a zero-overhead CPU scheduler, prefill-decode disaggregation, speculative decoding, continuous batching, paged attention, tensor/pipeline/expert/data parallelism, structured outputs, chunked prefill, quantization (FP4/FP8/INT4/AWQ/GPTQ), and multi-LoRA batching.\n- **Broad Model Support**: Supports a wide range of language models (Llama, Qwen, DeepSeek, Kimi, GLM, GPT, Gemma, Mistral, etc.), embedding models (e5-mistral, gte, mcdse), reward models (Skywork), and diffusion models (WAN, Qwen-Image), with easy extensibility for adding new models. 
Compatible with most Hugging Face models and OpenAI APIs.\n- **Extensive Hardware Support**: Runs on NVIDIA GPUs (GB200/B300/H100/A100/Spark), AMD GPUs (MI355/MI300), Intel Xeon CPUs, Google TPUs, Ascend NPUs, and more.\n- **Active Community**: SGLang is open-source and supported by a vibrant community with widespread industry adoption, powering over 400,000 GPUs worldwide.\n- **RL & Post-Training Backbone**: SGLang is a proven rollout backend used for training many frontier models, with native RL integrations and adoption by well-known post-training frameworks such as [**AReaL**](https://github.com/inclusionAI/AReaL), [**Miles**](https://github.com/radixark/miles), [**slime**](https://github.com/THUDM/slime), [**Tunix**](https://github.com/google/tunix), [**verl**](https://github.com/volcengine/verl) and more.\n\n## Getting Started\n- [Install SGLang](https://docs.sglang.io/get_started/install.html)\n- [Quick Start](https://docs.sglang.io/basic_usage/send_request.html)\n- [Backend Tutorial](https://docs.sglang.io/basic_usage/openai_api_completions.html)\n- [Frontend Tutorial](https://docs.sglang.io/references/frontend/frontend_tutorial.html)\n- [Contribution Guide](https://docs.sglang.io/developer_guide/contribution_guide.html)\n\n## Benchmark and Performance\nLearn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/), [Large-scale expert parallelism](https://lmsys.org/blog/2025-05-05-large-scale-ep/), [GB200 rack-scale parallelism](https://lmsys.org/blog/2025-09-25-gb200-part-2/).\n\n## Adoption and Sponsorship\nSGLang has been deployed at large scale, generating trillions of tokens in production each day. 
It is trusted and adopted by a wide range of leading enterprises and institutions, including xAI, AMD, NVIDIA, Intel, LinkedIn, Cursor, Oracle Cloud, Google Cloud, Microsoft Azure, AWS, Atlas Cloud, Voltage Park, Nebius, DataCrunch, Novita, InnoMatrix, MIT, UCLA, the University of Washington, Stanford, UC Berkeley, Tsinghua University, Jam & Tea Studios, Baseten, and other major technology organizations across North America and Asia.\nAs an open-source LLM inference engine, SGLang has become the de facto industry standard, with deployments running on over 400,000 GPUs worldwide.\nSGLang is currently hosted under the non-profit open-source organization [LMSYS](https://lmsys.org/about/).\n\n<img src=\"https://raw.githubusercontent.com/sgl-project/sgl-learning-materials/refs/heads/main/slides/adoption.png\" alt=\"logo\" width=\"800\" margin=\"10px\"></img>\n\n## Contact Us\nFor enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at sglang@lmsys.org\n\n## Acknowledgment\nWe learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql).\n"
  },
  {
    "path": "benchmark/asr/README.md",
    "content": "# ASR Benchmark\n\nThis benchmark evaluates the performance and accuracy (Word Error Rate, WER) of Automatic Speech Recognition (ASR) models served via SGLang.\n\n## Supported Models\n\n- `openai/whisper-large-v3`\n- `openai/whisper-large-v3-turbo`\n\n## Setup\n\nInstall the required dependencies:\n\n```bash\napt install ffmpeg\npip install librosa soundfile datasets evaluate jiwer transformers openai torchcodec torch\n```\n\n## Running the Benchmark\n\n### 1. Start SGLang Server\n\nLaunch the SGLang server with a Whisper model:\n\n```bash\npython -m sglang.launch_server --model-path openai/whisper-large-v3 --port 30000\n```\n\n### 2. Run the Benchmark Script\n\nBasic usage (using the chat completions API):\n\n```bash\npython bench_sglang.py --base-url http://localhost:30000 --model openai/whisper-large-v3 --n-examples 10\n```\n\nUsing the OpenAI-compatible transcription API:\n\n```bash\npython bench_sglang.py \\\n    --base-url http://localhost:30000 \\\n    --model openai/whisper-large-v3 \\\n    --api-type transcription \\\n    --language English \\\n    --n-examples 10\n```\n\nRun with streaming and show real-time output:\n\n```bash\npython bench_sglang.py \\\n    --base-url http://localhost:30000 \\\n    --model openai/whisper-large-v3 \\\n    --api-type transcription \\\n    --stream \\\n    --show-predictions \\\n    --concurrency 1\n```\n\nRun with higher concurrency and save results:\n\n```bash\npython bench_sglang.py \\\n    --base-url http://localhost:30000 \\\n    --model openai/whisper-large-v3 \\\n    --concurrency 8 \\\n    --n-examples 100 \\\n    --output results.json \\\n    --show-predictions\n```\n\n## Arguments\n\n| Argument | Description | Default |\n|----------|-------------|---------|\n| `--base-url` | SGLang server URL | `http://localhost:30000` |\n| `--model` | Model name on the server | `openai/whisper-large-v3` |\n| `--dataset` | Hugging Face dataset for evaluation | `D4nt3/esb-datasets-earnings22-validation-tiny-filtered` 
|\n| `--split` | Dataset split to use | `validation` |\n| `--concurrency` | Number of concurrent requests | `4` |\n| `--n-examples` | Number of examples to process (`-1` for all) | `-1` |\n| `--output` | Path to save results as JSON | `None` |\n| `--show-predictions` | Display sample predictions | `False` |\n| `--print-n` | Number of samples to display | `5` |\n| `--api-type` | API to use: `chat` (chat completions) or `transcription` (audio transcriptions) | `chat` |\n| `--language` | Language for transcription API (e.g., `English`, `en`) | `None` |\n| `--stream` | Enable streaming mode for transcription API | `False` |\n\n## Metrics\n\nThe benchmark outputs:\n\n| Metric | Description |\n|--------|-------------|\n| **Total Requests** | Number of successful ASR requests processed |\n| **WER** | Word Error Rate (lower is better), computed using the `evaluate` library |\n| **Average Latency** | Mean time per request (seconds) |\n| **Median Latency** | 50th percentile latency (seconds) |\n| **95th Latency** | 95th percentile latency (seconds) |\n| **Throughput** | Requests processed per second |\n| **Token Throughput** | Output tokens per second |\n\n## Example Output\n\n```bash\npython bench_sglang.py --api-type transcription --concurrency 128 --model openai/whisper-large-v3 --show-predictions\n\nLoading dataset: D4nt3/esb-datasets-earnings22-validation-tiny-filtered...\nUsing API type: transcription\nRepo card metadata block was not found. Setting CardData to empty.\nWARNING:huggingface_hub.repocard:Repo card metadata block was not found. 
Setting CardData to empty.\nPerforming warmup...\nProcessing 511 samples...\n------------------------------\nResults for openai/whisper-large-v3:\nTotal Requests: 511\nWER: 12.7690\nAverage Latency: 1.3602s\nMedian Latency: 1.2090s\n95th Latency: 2.9986s\nThroughput: 19.02 req/s\nToken Throughput: 354.19 tok/s\nTotal Test Time: 26.8726s\n------------------------------\n\n==================== Sample Predictions ====================\nSample 1:\n  REF: on the use of taxonomy i you know i think it is it is early days for us to to make any clear indications to the market about the proportion that would fall under that requirement\n  PRED: on the eu taxonomy i think it is early days for us to make any clear indications to the market about the proportion that would fall under that requirement\n----------------------------------------\nSample 2:\n  REF: so within fiscal year 2021 say 120 a 100 depending on what the micro will do and next year it is not necessarily payable in q one is we will look at what the cash flows for 2022 look like\n  PRED: so within fiscal year 2021 say $120000 $100000 depending on what the macro will do and next year it is not necessarily payable in q one is we will look at what the cash flows for 2022 look like\n----------------------------------------\nSample 3:\n  REF: we talked about 4.7 gigawatts\n  PRED: we talked about 4.7 gigawatts\n----------------------------------------\nSample 4:\n  REF: and you know depending on that working capital build we will we will see what that yields\n  PRED: and depending on that working capital build we will see what that yields what\n----------------------------------------\nSample 5:\n  REF: so on on sinopec what we have agreed with sinopec way back then is that free cash flows after paying all capexs are distributed out 30 70%\n  PRED: so on sinopec what we have agreed with sinopec way back then is that free cash flows after paying all capexes are distributed out 30% 
70%\n----------------------------------------\n============================================================\n```\n\n## Notes\n\n- Audio samples longer than 30 seconds are automatically filtered out (Whisper limitation)\n- The benchmark performs a warmup request before measuring performance\n- Results are normalized using the model's tokenizer when available\n- When using `--stream` with `--show-predictions`, use `--concurrency 1` for clean sequential output\n- The `--language` option accepts both full names (e.g., `English`) and ISO 639-1 codes (e.g., `en`)\n\n## Troubleshooting\n\n**Server connection refused**\n- Ensure the SGLang server is running and accessible at the specified `--base-url`\n- Check that the port is not blocked by a firewall\n\n**Out of memory errors**\n- Reduce `--concurrency` to lower GPU memory usage\n- Use a smaller Whisper model variant\n"
  },
  {
    "path": "benchmark/asr/bench_sglang.py",
    "content": "import argparse\nimport asyncio\nimport base64\nimport io\nimport json\nimport time\nfrom statistics import mean, median\n\nimport httpx\nimport librosa\nimport numpy as np\nimport soundfile\nfrom datasets import load_dataset\nfrom evaluate import load\nfrom openai import AsyncOpenAI, OpenAI\nfrom transformers import AutoTokenizer\n\n\ndef to_bytes(y, sr):\n    buffer = io.BytesIO()\n    soundfile.write(buffer, y, sr, format=\"WAV\")\n    buffer.seek(0)\n    return buffer\n\n\nasync def run_asr_chat(client, model_name, y, sr):\n    \"\"\"Use chat completions API with audio_url for ASR.\"\"\"\n    with to_bytes(y, sr) as f:\n        audio_bytes = f.read()\n        audio_base64 = base64.b64encode(audio_bytes).decode(\"utf-8\")\n\n    start_time = time.perf_counter()\n    response = await client.chat.completions.create(\n        model=model_name,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"audio_url\",\n                        \"audio_url\": {\"url\": f\"data:audio/wav;base64,{audio_base64}\"},\n                    }\n                ],\n            }\n        ],\n        temperature=0.0,\n    )\n    end_time = time.perf_counter()\n\n    asr_text = response.choices[0].message.content\n    latency = end_time - start_time\n    return latency, asr_text\n\n\ndef run_asr_transcription_sync(client, model_name, y, sr, language=None):\n    \"\"\"Use audio transcriptions API for ASR (sync version).\"\"\"\n    audio_buffer = to_bytes(y, sr)\n    audio_buffer.name = \"audio.wav\"  # OpenAI client needs a name attribute\n\n    start_time = time.perf_counter()\n    kwargs = {\n        \"model\": model_name,\n        \"file\": audio_buffer,\n    }\n    if language:\n        kwargs[\"language\"] = language\n\n    transcription = client.audio.transcriptions.create(**kwargs)\n    end_time = time.perf_counter()\n\n    latency = end_time - start_time\n  
  return latency, transcription.text\n\n\ndef run_asr_transcription_stream_sync(\n    base_url, model_name, y, sr, language=None, show_stream=False\n):\n    \"\"\"Use audio transcriptions API with streaming for ASR.\"\"\"\n    audio_buffer = to_bytes(y, sr)\n    audio_bytes = audio_buffer.read()\n\n    data = {\n        \"model\": model_name,\n        \"response_format\": \"json\",\n        \"stream\": \"true\",\n    }\n    if language:\n        data[\"language\"] = language\n\n    start_time = time.perf_counter()\n    text_chunks = []\n\n    if show_stream:\n        print(\"[STREAM] \", end=\"\", flush=True)\n\n    with httpx.stream(\n        \"POST\",\n        f\"{base_url}/v1/audio/transcriptions\",\n        data=data,\n        files={\"file\": (\"audio.wav\", audio_bytes, \"audio/wav\")},\n        timeout=60.0,\n    ) as response:\n        for line in response.iter_lines():\n            if line.startswith(\"data: \") and not line.startswith(\"data: [DONE]\"):\n                try:\n                    chunk = json.loads(line[6:])\n                    if \"choices\" in chunk and chunk[\"choices\"]:\n                        delta = chunk[\"choices\"][0].get(\"delta\", {})\n                        content = delta.get(\"content\", \"\")\n                        if content:\n                            text_chunks.append(content)\n                            if show_stream:\n                                print(content, end=\"\", flush=True)\n                except json.JSONDecodeError:\n                    pass\n\n    if show_stream:\n        print()  # newline after stream\n\n    end_time = time.perf_counter()\n    latency = end_time - start_time\n    return latency, \"\".join(text_chunks)\n\n\nasync def run_asr_transcription(\n    client,\n    model_name,\n    y,\n    sr,\n    language=None,\n    stream=False,\n    base_url=None,\n    show_stream=False,\n):\n    \"\"\"Async wrapper for transcription API (runs sync call in executor).\"\"\"\n    loop = 
asyncio.get_event_loop()\n    if stream:\n        return await loop.run_in_executor(\n            None,\n            run_asr_transcription_stream_sync,\n            base_url,\n            model_name,\n            y,\n            sr,\n            language,\n            show_stream,\n        )\n    return await loop.run_in_executor(\n        None, run_asr_transcription_sync, client, model_name, y, sr, language\n    )\n\n\nasync def bound_asr(\n    sem,\n    client,\n    model_name,\n    tokenizer,\n    audio,\n    reference,\n    api_type=\"chat\",\n    language=None,\n    stream=False,\n    base_url=None,\n    show_stream=False,\n):\n    async with sem:\n        try:\n            if api_type == \"transcription\":\n                latency, text = await run_asr_transcription(\n                    client,\n                    model_name,\n                    *audio,\n                    language=language,\n                    stream=stream,\n                    base_url=base_url,\n                    show_stream=show_stream,\n                )\n            else:\n                latency, text = await run_asr_chat(client, model_name, *audio)\n\n            # Calculate tokens for throughput metrics\n            num_output_tokens = len(tokenizer(text, add_special_tokens=False).input_ids)\n\n            # Normalize for WER evaluation\n            # Whisper tokenizer has a normalize method\n            if hasattr(tokenizer, \"normalize\"):\n                out = tokenizer.normalize(text)\n                ref = tokenizer.normalize(reference)\n            else:\n                out = text.lower().strip()\n                ref = reference.lower().strip()\n\n            return latency, num_output_tokens, out, ref\n        except Exception as e:\n            print(f\"Error during ASR: {e}\")\n            return None\n\n\nasync def process_dataset(\n    model_name,\n    client,\n    data,\n    concurrent_request,\n    api_type=\"chat\",\n    language=None,\n    stream=False,\n    
base_url=None,\n    show_predictions=False,\n):\n    sem = asyncio.Semaphore(concurrent_request)\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    # Warmup\n    print(\"Performing warmup...\")\n    audio_warmup, sr_warmup = (\n        data[0][\"audio\"][\"array\"],\n        data[0][\"audio\"][\"sampling_rate\"],\n    )\n    await bound_asr(\n        sem,\n        client,\n        model_name,\n        tokenizer,\n        (audio_warmup, sr_warmup),\n        \"\",\n        api_type=api_type,\n        language=language,\n        stream=stream,\n        base_url=base_url,\n        show_stream=False,  # Don't show stream during warmup\n    )\n\n    tasks = []\n    print(f\"Processing {len(data)} samples...\")\n    for sample in data:\n        audio, sr = sample[\"audio\"][\"array\"], sample[\"audio\"][\"sampling_rate\"]\n        tasks.append(\n            asyncio.create_task(\n                bound_asr(\n                    sem,\n                    client,\n                    model_name,\n                    tokenizer,\n                    (audio, sr),\n                    sample[\"text\"],\n                    api_type=api_type,\n                    language=language,\n                    stream=stream,\n                    base_url=base_url,\n                    show_stream=show_predictions and stream,\n                )\n            )\n        )\n\n    results = await asyncio.gather(*tasks)\n    return [r for r in results if r is not None]\n\n\ndef run_evaluation(args):\n    # Use sync client for transcription API, async for chat API\n    if args.api_type == \"transcription\":\n        client = OpenAI(base_url=f\"{args.base_url}/v1\", api_key=\"None\")\n    else:\n        client = AsyncOpenAI(base_url=f\"{args.base_url}/v1\", api_key=\"None\")\n\n    print(f\"Loading dataset: {args.dataset}...\")\n    print(f\"Using API type: {args.api_type}\" + (f\" (streaming)\" if args.stream else \"\"))\n    dataset = load_dataset(args.dataset, 
split=args.split)\n\n    # Filter by duration if needed (Whisper max is 30s)\n    def add_duration(sample):\n        y, sr = sample[\"audio\"][\"array\"], sample[\"audio\"][\"sampling_rate\"]\n        sample[\"duration_ms\"] = librosa.get_duration(y=y, sr=sr) * 1000\n        return sample\n\n    if \"duration_ms\" not in dataset.column_names:\n        dataset = dataset.map(add_duration)\n\n    dataset = dataset.filter(lambda x: x[\"duration_ms\"] < 30000)\n\n    if args.n_examples > 0:\n        dataset = dataset.select(range(min(args.n_examples, len(dataset))))\n\n    start = time.perf_counter()\n    results = asyncio.run(\n        process_dataset(\n            args.model,\n            client,\n            dataset,\n            args.concurrency,\n            api_type=args.api_type,\n            language=args.language,\n            stream=args.stream,\n            base_url=args.base_url,\n            show_predictions=args.show_predictions,\n        )\n    )\n    total_test_time = time.perf_counter() - start\n\n    if not results:\n        print(\"No successful results to evaluate.\")\n        return\n\n    # Metrics\n    latencies = [res[0] for res in results]\n    total_tokens = sum([res[1] for res in results])\n    predictions = [res[2] for res in results]\n    references = [res[3] for res in results]\n\n    wer_metric = load(\"wer\")\n    wer_score = 100 * wer_metric.compute(references=references, predictions=predictions)\n\n    print(\"-\" * 30)\n    print(f\"Results for {args.model}:\")\n    print(f\"Total Requests: {len(results)}\")\n    print(f\"WER: {wer_score:.4f}\")\n    print(f\"Average Latency: {mean(latencies):.4f}s\")\n    print(f\"Median Latency: {median(latencies):.4f}s\")\n    print(f\"95th Latency: {np.percentile(latencies, 95):.4f}s\")\n    print(f\"Throughput: {len(results) / total_test_time:.2f} req/s\")\n    print(f\"Token Throughput: {total_tokens / total_test_time:.2f} tok/s\")\n    print(f\"Total Test Time: {total_test_time:.4f}s\")\n    
print(\"-\" * 30)\n\n    if args.output:\n        with open(args.output, \"w\") as f:\n            json.dump(\n                {\n                    \"model\": args.model,\n                    \"dataset\": args.dataset,\n                    \"wer\": wer_score,\n                    \"avg_latency\": mean(latencies),\n                    \"throughput\": len(results) / total_test_time,\n                    \"token_throughput\": total_tokens / total_test_time,\n                },\n                f,\n                indent=2,\n            )\n\n    if args.show_predictions:\n        print(\"\\n\" + \"=\" * 20 + \" Sample Predictions \" + \"=\" * 20)\n        num_to_show = min(args.print_n, len(results))\n        for i in range(num_to_show):\n            print(f\"Sample {i+1}:\")\n            print(f\"  REF: {references[i]}\")\n            print(f\"  PRED: {predictions[i]}\")\n            print(\"-\" * 40)\n        print(\"=\" * 60)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Benchmark SGLang ASR performance.\")\n    parser.add_argument(\n        \"--base-url\", default=\"http://localhost:30000\", help=\"SGLang server base URL\"\n    )\n    parser.add_argument(\n        \"--model\", default=\"openai/whisper-large-v3\", help=\"Model name on the server\"\n    )\n    parser.add_argument(\n        \"--dataset\",\n        default=\"D4nt3/esb-datasets-earnings22-validation-tiny-filtered\",\n        help=\"HF dataset repo\",\n    )\n    parser.add_argument(\"--split\", default=\"validation\", help=\"Dataset split\")\n    parser.add_argument(\n        \"--concurrency\", type=int, default=4, help=\"Number of concurrent requests\"\n    )\n    parser.add_argument(\n        \"--n-examples\",\n        \"-n\",\n        type=int,\n        default=-1,\n        help=\"Number of examples to test (-1 for all)\",\n    )\n    parser.add_argument(\"--output\", help=\"Path to save results in JSON\")\n    parser.add_argument(\n        \"--show-predictions\",\n        action=\"store_true\",\n        help=\"Print sample predictions and references\",\n    )\n    parser.add_argument(\n        \"--print-n\", type=int, default=5, help=\"Number of sample predictions to print\"\n    )\n    parser.add_argument(\n        \"--api-type\",\n        choices=[\"chat\", \"transcription\"],\n        default=\"chat\",\n        help=\"API type to use: 'chat' for chat completions with audio_url, 'transcription' for audio.transcriptions API\",\n    )\n    parser.add_argument(\n        \"--language\",\n        default=None,\n        help=\"Language code for transcription API (e.g., 'en')\",\n    )\n    parser.add_argument(\n        \"--stream\",\n        action=\"store_true\",\n        help=\"Use streaming mode for transcription API\",\n    )\n    args = parser.parse_args()\n\n    run_evaluation(args)\n"
  },
  {
    "path": "benchmark/bench_attention_sink/bench_attention_sink_triton.py",
    "content": "import argparse\n\nimport torch\nimport triton\n\nfrom sglang.srt.layers.attention.triton_ops.decode_attention import (\n    decode_attention_fwd_grouped,\n)\nfrom sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd\n\n# gpt oss\nhead_num = 64\nhead_dim = 64\nhead_kv_num = 8\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"S\"],  # sequence length on x-axis\n        x_vals=[128, 256, 512, 1024, 2048, 4096],\n        x_log=True,\n        line_arg=\"B\",  # batch size as different lines\n        line_vals=[1, 8, 32, 128],\n        line_names=[\"B=1\", \"B=8\", \"B=32\", \"B=128\"],\n        styles=[\n            (\"blue\", \"-\"),\n            (\"green\", \"-\"),\n            (\"red\", \"-\"),\n            (\"cyan\", \"-\"),\n        ],\n        ylabel=\"TFLOPS\",\n        plot_name=\"attention-sink-triton-decode\",\n        args={},\n    )\n)\ndef benchmark_decode(B, S, H_Q, H_KV, D):\n    D_V = D\n    dtype = torch.bfloat16\n    seq_len = S\n    total_tokens = B * seq_len\n    device = torch.device(\"cuda\")\n    sm_scale = 1.0 / (D**0.5)\n    max_kv_splits = 8\n    num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device=\"cuda\")\n\n    # q represents the new token being generated, one per batch\n    q = torch.randn(B, H_Q, D, dtype=dtype, device=\"cuda\")\n\n    # k_buffer and v_buffer represent all previous tokens\n    k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=\"cuda\")\n    v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=\"cuda\")\n\n    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=\"cuda\")\n\n    b_seq_len = torch.full((B,), seq_len, device=\"cuda\")\n\n    kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=\"cuda\")\n    kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len, dim=0)\n    kv_indices = torch.arange(total_tokens, device=\"cuda\")\n\n    attn_logits1 = torch.empty(\n        (B, H_Q, max_kv_splits, D_V),\n        
dtype=torch.float32,\n        device=\"cuda\",\n    )\n    attn_lse1 = torch.empty(\n        (B, H_Q, max_kv_splits, D_V),\n        dtype=torch.float32,\n        device=\"cuda\",\n    )\n    sink = torch.randn(H_Q, device=device, dtype=torch.float32)\n\n    # warmup\n    for _ in range(5):\n        decode_attention_fwd_grouped(\n            q,\n            k_buffer,\n            v_buffer,\n            o,\n            kv_indptr,\n            kv_indices,\n            attn_logits1,\n            attn_lse1,\n            num_kv_splits,\n            max_kv_splits,\n            sm_scale,\n            logit_cap=0.0,\n            sinks=sink,\n        )\n\n    # benchmark\n    run_step = 500\n    start_event = torch.cuda.Event(enable_timing=True)\n    end_event = torch.cuda.Event(enable_timing=True)\n    start_event.record()\n    for _ in range(run_step):\n        decode_attention_fwd_grouped(\n            q,\n            k_buffer,\n            v_buffer,\n            o,\n            kv_indptr,\n            kv_indices,\n            attn_logits1,\n            attn_lse1,\n            num_kv_splits,\n            max_kv_splits,\n            sm_scale,\n            logit_cap=0.0,\n            sinks=sink,\n        )\n    end_event.record()\n    end_event.synchronize()\n    torch.cuda.synchronize()\n    ms = start_event.elapsed_time(end_event) / run_step\n    tflops = lambda ms: (2 * B * S * H_Q * D) * 1e-9 / ms  # must be causal\n    return tflops(ms)\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"S\"],  # sequence length on x-axis\n        x_vals=[128, 256, 512, 1024, 2048, 4096],\n        x_log=True,\n        line_arg=\"B\",  # batch size as different lines\n        line_vals=[1, 8, 32, 128],\n        line_names=[\"B=1\", \"B=8\", \"B=32\", \"B=128\"],\n        styles=[\n            (\"blue\", \"-\"),\n            (\"green\", \"-\"),\n            (\"red\", \"-\"),\n            (\"cyan\", \"-\"),\n        ],\n        ylabel=\"TFLOPS\",\n        
plot_name=\"attention-sink-triton-extend\",\n        args={},\n    )\n)\ndef benchmark_extend(B, S, H_Q, H_KV, D):\n    # S here represents N_CTX from the test\n    dtype = torch.bfloat16\n    device = \"cuda\"\n\n    # Split S into prefix and extend lengths\n    prefill_len = S // 2  # Similar to test's N_CTX // 2\n    extend_len = S // 4  # Make extend length smaller than prefix\n\n    # Calculate total tokens and extend tokens\n    total_extend_tokens = B * extend_len\n    total_prefix_tokens = B * prefill_len\n\n    # Create query, key, value tensors for extension\n    q_extend = torch.randn(total_extend_tokens, H_Q, D, dtype=dtype, device=device)\n    k_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)\n    v_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)\n    o_extend = torch.empty_like(q_extend)\n\n    # Create key-value buffers for prefix\n    k_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)\n    v_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)\n\n    # Create index pointers\n    qo_indptr = torch.arange(0, (B + 1) * extend_len, extend_len, device=device).to(\n        torch.int32\n    )\n    kv_indptr = torch.arange(0, (B + 1) * prefill_len, prefill_len, device=device).to(\n        torch.int32\n    )\n    kv_indices = torch.arange(0, total_prefix_tokens, device=device).to(torch.int32)\n\n    sm_scale = 1.0 / (D**0.5)\n    # sliding_window = 128  # From GPT-OSS config, skip for now\n    sliding_window = -1\n\n    sink = torch.randn(H_Q, device=device, dtype=torch.float32)\n\n    # warmup\n    for _ in range(5):\n        extend_attention_fwd(\n            q_extend,\n            k_extend,\n            v_extend,\n            o_extend,\n            k_buffer,\n            v_buffer,\n            qo_indptr,\n            kv_indptr,\n            kv_indices,\n            custom_mask=None,\n            is_causal=True,\n            
mask_indptr=None,\n            max_len_extend=extend_len,\n            sm_scale=sm_scale,\n            sliding_window_size=sliding_window,\n            sinks=sink,\n        )\n\n    # benchmark\n    run_step = 500\n    start_event = torch.cuda.Event(enable_timing=True)\n    end_event = torch.cuda.Event(enable_timing=True)\n    start_event.record()\n    for _ in range(run_step):\n        extend_attention_fwd(\n            q_extend,\n            k_extend,\n            v_extend,\n            o_extend,\n            k_buffer,\n            v_buffer,\n            qo_indptr,\n            kv_indptr,\n            kv_indices,\n            custom_mask=None,\n            is_causal=True,\n            mask_indptr=None,\n            max_len_extend=extend_len,\n            sm_scale=sm_scale,\n            sliding_window_size=sliding_window,\n            sinks=sink,\n        )\n    end_event.record()\n    end_event.synchronize()\n    torch.cuda.synchronize()\n    ms = start_event.elapsed_time(end_event) / run_step\n\n    # FLOPS calculation: each attention operation requires 2 multiplications per element\n    total_flops = 2 * total_extend_tokens * H_Q * (prefill_len + extend_len / 2) * D\n    tflops = lambda ms: total_flops * 1e-12 / (ms * 1e-3)  # convert to TFLOPS\n    return tflops(ms)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--bench\", type=str, default=\"all\", help=\"all, extend, decode\")\n    args = parser.parse_args()\n\n    kwargs = {\n        \"H_Q\": head_num,\n        \"H_KV\": head_kv_num,\n        \"D\": head_dim,\n    }\n\n    if args.bench in [\"all\", \"decode\"]:\n        benchmark_decode.run(print_data=True, show_plots=False, **kwargs)\n\n    if args.bench in [\"all\", \"extend\"]:\n        benchmark_extend.run(print_data=True, show_plots=False, **kwargs)\n\n    print(\"Benchmark finished!\")\n"
  },
  {
    "path": "benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py",
    "content": "# Benchmark with lots of common prefixes. Used to benchmark prefix caching performance.\n#\n# Launch a server:\n# python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --log-level-http warning\n\nimport random\nimport string\nimport time\n\nfrom tqdm import tqdm\nfrom transformers import AutoTokenizer\n\nimport sglang as sgl\nfrom sglang import set_default_backend\nfrom sglang.lang.backend.runtime_endpoint import RuntimeEndpoint\n\n\ndef generate_random_string(token_length: int) -> str:\n    random_string = \"\".join(\n        random.choices(string.ascii_letters + string.digits, k=token_length * 100)\n    )\n    tokenized_output = tokenizer.encode(random_string, add_special_tokens=False)[\n        :token_length\n    ]\n\n    if len(tokenized_output) < token_length:\n        tokenized_output = tokenized_output + [tokenizer.pad_token_id] * (\n            token_length - len(tokenized_output)\n        )\n\n    decoded_string = tokenizer.decode(tokenized_output, skip_special_tokens=False)\n    return decoded_string\n\n\ndef generate_unique_prefix(base_text, index):\n    return str(index) + base_text[len(str(index)) :]\n\n\n@sgl.function\ndef text_qa(s, question, gen_len):\n    s += \"Q: \" + question + \"\\n\"\n    s += \"A:\" + sgl.gen(\"answer\", stop=\"\\n\", temperature=0, max_tokens=gen_len)\n\n\ndef prepare_prompts(num_prefix, num_samples_per_prefix, prefix_length, suffix_length):\n    base_prefix = generate_random_string(prefix_length)\n\n    tot_input_len = 0\n    all_prompts = []\n    for i in tqdm(range(num_prefix), desc=\"prepare prompts\"):\n        unique_prefix = generate_unique_prefix(base_prefix, i)\n        prompt_list = []\n        for j in range(num_samples_per_prefix):\n            suffix = generate_random_string(suffix_length)\n            prompt = unique_prefix + suffix\n            prompt_list.append(prompt)\n            tot_input_len += len(tokenizer.encode(prompt))\n        
all_prompts.append(prompt_list)\n    return all_prompts, tot_input_len\n\n\ndef test_batch_by_batch(all_prompts, gen_len):\n    backend.flush_cache()\n\n    tot_time = 0\n    for i in range(len(all_prompts)):\n        tic = time.perf_counter()\n        text_qa.run_batch(\n            list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))),\n        )\n        tot_time += time.perf_counter() - tic\n\n    return tot_time\n\n\ndef test_batch_by_batch_with_hint(all_prompts, gen_len):\n    backend.flush_cache()\n\n    tot_time = 0\n    for i in range(len(all_prompts)):\n        tic = time.perf_counter()\n        # Send a hint to cache the prefix\n        text_qa.run_batch(list(zip(all_prompts[i][:1], [gen_len])))\n        # Send the batch\n        text_qa.run_batch(list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))))\n\n        tot_time += time.perf_counter() - tic\n\n    return tot_time\n\n\ndef test_send_all(all_prompts, gen_len):\n    backend.flush_cache()\n\n    all_prompts = [x for prompt_list in all_prompts for x in prompt_list]\n\n    tic = time.perf_counter()\n    text_qa.run_batch(\n        list(zip(all_prompts, [gen_len] * len(all_prompts))),\n    )\n    tot_time = time.perf_counter() - tic\n\n    return tot_time\n\n\nif __name__ == \"__main__\":\n    tokenizer = AutoTokenizer.from_pretrained(\"hf-internal-testing/llama-tokenizer\")\n    backend = RuntimeEndpoint(\"http://127.0.0.1:30000\")\n    set_default_backend(backend)\n\n    random.seed(0)\n    num_prefix = 10\n    num_samples_per_prefix = 32\n    prefix_length = 1024\n    suffix_length = 128\n    gen_len = 1\n    all_prompts, tot_input_len = prepare_prompts(\n        num_prefix, num_samples_per_prefix, prefix_length, suffix_length\n    )\n\n    print(f\"Total input token length: {tot_input_len}\\n\")\n\n    cost = test_batch_by_batch(all_prompts, gen_len)\n    print(f\"Latency of test_batch_by_batch          : {cost:.4f} s\\n\")\n\n    cost = test_batch_by_batch_with_hint(all_prompts, 
gen_len)\n    print(f\"Latency of test_batch_by_batch_with_hint: {cost:.4f} s\\n\")\n\n    cost = test_send_all(all_prompts, gen_len)\n    print(f\"Latency of test_send_all                : {cost:.4f} s\\n\")\n"
  },
  {
    "path": "benchmark/bench_linear_attention/bench_gdn_decode.py",
    "content": "\"\"\"\nBenchmark & Correctness: GDN Packed Decode vs Baseline Decode.\n\nCompares:\n  - Baseline: split(mixed_qkv) → view → fused_sigmoid_gating_delta_rule_update\n  - Packed:   fused_recurrent_gated_delta_rule_packed_decode (single kernel)\n\nThe packed path eliminates:\n  - torch.split() + .view() tensor materialization\n  - Separate gating kernel launches\n  - Intermediate tensor allocations\n\nReports correctness (output & state matching) and performance (ms, speedup).\n\nUsage:\n    python bench_gdn_decode.py                        # default sweep\n    python bench_gdn_decode.py --mode bench           # benchmark only\n    python bench_gdn_decode.py --mode correctness     # correctness only\n    python bench_gdn_decode.py --preset qwen3.5-35b   # Qwen3.5-35B-A3B config\n\"\"\"\n\nimport argparse\nimport os\nimport sys\nimport time\n\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"python\"))\n\nimport torch\nimport triton\n\nfrom sglang.srt.layers.attention.fla.fused_recurrent import (\n    fused_recurrent_gated_delta_rule_packed_decode,\n)\nfrom sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (\n    fused_sigmoid_gating_delta_rule_update,\n)\n\n# ---------------------------------------------------------------------------\n# Input factory\n# ---------------------------------------------------------------------------\n\n\ndef make_inputs(\n    B: int,\n    H: int,\n    HV: int,\n    K: int,\n    V: int,\n    pool_size: int,\n    device: str,\n    dtype: torch.dtype,\n    seed: int = 42,\n):\n    \"\"\"Create all input tensors for a single benchmark / correctness run.\"\"\"\n    torch.manual_seed(seed)\n\n    qkv_dim = 2 * H * K + HV * V\n    mixed_qkv = torch.randn(B, qkv_dim, device=device, dtype=dtype)\n    a = torch.randn(B, HV, device=device, dtype=dtype)\n    b = torch.randn(B, HV, device=device, dtype=dtype)\n    A_log = torch.randn(HV, device=device, dtype=dtype)\n    dt_bias = torch.randn(HV, 
device=device, dtype=dtype)\n\n    ssm_states = torch.randn(pool_size, HV, V, K, device=device, dtype=dtype) * 0.1\n    cache_indices = torch.arange(B, device=device, dtype=torch.int32)\n\n    cu_seqlens = torch.arange(B + 1, device=device, dtype=torch.long)\n\n    return dict(\n        B=B,\n        H=H,\n        HV=HV,\n        K=K,\n        V=V,\n        qkv_dim=qkv_dim,\n        pool_size=pool_size,\n        mixed_qkv=mixed_qkv,\n        a=a,\n        b=b,\n        A_log=A_log,\n        dt_bias=dt_bias,\n        ssm_states=ssm_states,\n        cache_indices=cache_indices,\n        cu_seqlens=cu_seqlens,\n    )\n\n\n# ---------------------------------------------------------------------------\n# Runner wrappers\n# ---------------------------------------------------------------------------\n\n\ndef run_baseline(inp):\n    \"\"\"Baseline path: split → view → fused_sigmoid_gating_delta_rule_update.\n\n    This mirrors the FULL original decode path in GDNAttnBackend.forward_decode,\n    including the split, view, and kernel call.\n    \"\"\"\n    B, H, HV, K, V = inp[\"B\"], inp[\"H\"], inp[\"HV\"], inp[\"K\"], inp[\"V\"]\n    mixed_qkv = inp[\"mixed_qkv\"]\n    ssm_states = inp[\"ssm_states\"].clone()\n\n    # Step 1: split (same as forward_decode)\n    q_flat, k_flat, v_flat = torch.split(mixed_qkv, [H * K, H * K, HV * V], dim=-1)\n\n    # Step 2: view + reshape (same as forward_decode)\n    q = q_flat.view(1, B, H, K)\n    k = k_flat.view(1, B, H, K)\n    v = v_flat.view(1, B, HV, V)\n\n    # Step 3: fused gating + recurrent update\n    o = fused_sigmoid_gating_delta_rule_update(\n        A_log=inp[\"A_log\"],\n        dt_bias=inp[\"dt_bias\"],\n        q=q,\n        k=k,\n        v=v,\n        a=inp[\"a\"],\n        b=inp[\"b\"],\n        initial_state_source=ssm_states,\n        initial_state_indices=inp[\"cache_indices\"],\n        cu_seqlens=inp[\"cu_seqlens\"],\n        use_qk_l2norm_in_kernel=True,\n        softplus_beta=1.0,\n        
softplus_threshold=20.0,\n    )\n\n    return o, ssm_states\n\n\ndef run_packed(inp):\n    \"\"\"Packed path: single fused kernel directly on mixed_qkv.\"\"\"\n    B, HV, K, V = inp[\"B\"], inp[\"HV\"], inp[\"K\"], inp[\"V\"]\n    ssm_states = inp[\"ssm_states\"].clone()\n    out = inp[\"mixed_qkv\"].new_empty(B, 1, HV, V)\n\n    fused_recurrent_gated_delta_rule_packed_decode(\n        mixed_qkv=inp[\"mixed_qkv\"],\n        a=inp[\"a\"],\n        b=inp[\"b\"],\n        A_log=inp[\"A_log\"],\n        dt_bias=inp[\"dt_bias\"],\n        scale=inp[\"K\"] ** -0.5,\n        initial_state=ssm_states,\n        out=out,\n        ssm_state_indices=inp[\"cache_indices\"],\n        use_qk_l2norm_in_kernel=True,\n    )\n\n    # Convert [B, 1, HV, V] → [1, B, HV, V] to match baseline layout\n    return out.transpose(0, 1), ssm_states\n\n\n# ---------------------------------------------------------------------------\n# Correctness check\n# ---------------------------------------------------------------------------\n\n\ndef check_correctness(B, H, HV, K, V, pool_size, device, dtype, seed=42):\n    \"\"\"Run correctness check for a single config. 
Returns True if PASS.\"\"\"\n    tag = f\"B={B:>4} H={H:>2} HV={HV:>2} K={K:>3} V={V:>3} pool={pool_size:>4}\"\n\n    inp = make_inputs(B, H, HV, K, V, pool_size, device, dtype, seed=seed)\n\n    o_baseline, state_baseline = run_baseline(inp)\n    o_packed, state_packed = run_packed(inp)\n\n    # Output comparison\n    atol = 2e-2 if dtype != torch.float32 else 1e-4\n    rtol = 1e-2 if dtype != torch.float32 else 1e-4\n\n    try:\n        torch.testing.assert_close(o_packed, o_baseline, atol=atol, rtol=rtol)\n        output_ok = True\n    except AssertionError as e:\n        output_ok = False\n        out_diff = (o_packed - o_baseline).abs().max().item()\n\n    # State comparison (only for slots that were updated)\n    indices = inp[\"cache_indices\"]\n    try:\n        torch.testing.assert_close(\n            state_packed[indices], state_baseline[indices], atol=atol, rtol=rtol\n        )\n        state_ok = True\n    except AssertionError:\n        state_ok = False\n        st_diff = (state_packed[indices] - state_baseline[indices]).abs().max().item()\n\n    passed = output_ok and state_ok\n\n    if passed:\n        print(f\"  [PASS] {tag}\")\n    else:\n        details = []\n        if not output_ok:\n            details.append(f\"output max_diff={out_diff:.6f}\")\n        if not state_ok:\n            details.append(f\"state max_diff={st_diff:.6f}\")\n        print(f\"  [FAIL] {tag}  ({', '.join(details)})\")\n\n    return passed\n\n\n# ---------------------------------------------------------------------------\n# Benchmark\n# ---------------------------------------------------------------------------\n\n\ndef bench_shape(B, H, HV, K, V, pool_size, device, dtype):\n    \"\"\"Benchmark baseline vs packed for a single config.\"\"\"\n    inp = make_inputs(B, H, HV, K, V, pool_size, device, dtype)\n\n    # ── Baseline: full path including split + view ──\n    def fn_baseline():\n        q_flat, k_flat, v_flat = torch.split(\n            inp[\"mixed_qkv\"], [H * K, H 
* K, HV * V], dim=-1\n        )\n        q = q_flat.view(1, B, H, K)\n        k = k_flat.view(1, B, H, K)\n        v = v_flat.view(1, B, HV, V)\n        fused_sigmoid_gating_delta_rule_update(\n            A_log=inp[\"A_log\"],\n            dt_bias=inp[\"dt_bias\"],\n            q=q,\n            k=k,\n            v=v,\n            a=inp[\"a\"],\n            b=inp[\"b\"],\n            initial_state_source=inp[\"ssm_states\"],\n            initial_state_indices=inp[\"cache_indices\"],\n            cu_seqlens=inp[\"cu_seqlens\"],\n            use_qk_l2norm_in_kernel=True,\n            softplus_beta=1.0,\n            softplus_threshold=20.0,\n        )\n\n    # ── Packed: single kernel ──\n    out_buf = inp[\"mixed_qkv\"].new_empty(B, 1, HV, V)\n\n    def fn_packed():\n        fused_recurrent_gated_delta_rule_packed_decode(\n            mixed_qkv=inp[\"mixed_qkv\"],\n            a=inp[\"a\"],\n            b=inp[\"b\"],\n            A_log=inp[\"A_log\"],\n            dt_bias=inp[\"dt_bias\"],\n            scale=K**-0.5,\n            initial_state=inp[\"ssm_states\"],\n            out=out_buf,\n            ssm_state_indices=inp[\"cache_indices\"],\n            use_qk_l2norm_in_kernel=True,\n        )\n\n    # Warmup\n    for _ in range(10):\n        fn_baseline()\n        fn_packed()\n    torch.cuda.synchronize()\n\n    quantiles = [0.5, 0.2, 0.8]\n\n    try:\n        ms_baseline, ms_base_lo, ms_base_hi = triton.testing.do_bench(\n            fn_baseline, quantiles=quantiles, warmup=50, rep=200\n        )\n        ms_packed, ms_pack_lo, ms_pack_hi = triton.testing.do_bench(\n            fn_packed, quantiles=quantiles, warmup=50, rep=200\n        )\n    except Exception:\n        # Fallback to manual timing\n        torch.cuda.synchronize()\n        N = 200\n        start = time.perf_counter()\n        for _ in range(N):\n            fn_baseline()\n        torch.cuda.synchronize()\n        ms_baseline = (time.perf_counter() - start) / N * 1000\n\n        start = 
time.perf_counter()\n        for _ in range(N):\n            fn_packed()\n        torch.cuda.synchronize()\n        ms_packed = (time.perf_counter() - start) / N * 1000\n\n    speedup = ms_baseline / ms_packed if ms_packed > 0 else float(\"inf\")\n    saved_us = (ms_baseline - ms_packed) * 1000\n\n    print(\n        f\"  {B:>5}  {H:>3}  {HV:>3}  {K:>3}  {V:>3} | \"\n        f\"{ms_baseline * 1000:>10.1f} | \"\n        f\"{ms_packed * 1000:>10.1f} | \"\n        f\"{speedup:>7.2f}x | \"\n        f\"{saved_us:>+9.1f}\"\n    )\n\n\n# ---------------------------------------------------------------------------\n# Main\n# ---------------------------------------------------------------------------\n\n\ndef run_correctness(device, dtype):\n    print(\"=\" * 70)\n    print(\"Correctness: Baseline GDN Decode vs Packed GDN Decode\")\n    print(\"=\" * 70)\n\n    shapes = [\n        # (B,   H,  HV,  K,   V,   pool_size)\n        # --- Qwen3.5-35B-A3B style (TP=2: H=8, HV=16) ---\n        (1, 8, 16, 128, 128, 32),\n        (4, 8, 16, 128, 128, 32),\n        (16, 8, 16, 128, 128, 64),\n        (32, 8, 16, 128, 128, 128),\n        (64, 8, 16, 128, 128, 128),\n        (128, 8, 16, 128, 128, 256),\n        (256, 8, 16, 128, 128, 512),\n        # --- Qwen3.5-35B-A3B style (TP=1: H=16, HV=32) ---\n        (1, 16, 32, 128, 128, 32),\n        (32, 16, 32, 128, 128, 128),\n        (64, 16, 32, 128, 128, 128),\n        # --- Qwen3-Next-80B-A3B style ---\n        (32, 16, 16, 128, 128, 128),\n        (64, 16, 16, 128, 128, 128),\n        # --- With PAD_SLOT_ID ---\n        (32, 8, 16, 128, 128, 128),  # some indices may be padded\n        # --- Edge cases ---\n        (1, 8, 16, 128, 128, 32),\n        (2, 8, 16, 128, 128, 32),\n    ]\n\n    all_pass = True\n    for B, H, HV, K, V, pool_size in shapes:\n        if not check_correctness(B, H, HV, K, V, pool_size, device, dtype):\n            all_pass = False\n\n    # PAD_SLOT_ID test\n    print(\"\\n  PAD_SLOT_ID test (indices with 
-1):\")\n    inp = make_inputs(32, 8, 16, 128, 128, 128, device, dtype)\n    o_baseline, st_baseline = run_baseline(inp)\n    o_packed, st_packed = run_packed(inp)\n\n    try:\n        torch.testing.assert_close(o_packed, o_baseline, atol=2e-2, rtol=1e-2)\n        print(\"  [PASS] PAD_SLOT_ID=-1 handling\")\n    except AssertionError:\n        print(\"  [FAIL] PAD_SLOT_ID=-1 handling\")\n        all_pass = False\n\n    print()\n    if all_pass:\n        print(\"ALL PASSED.\")\n    else:\n        print(\"SOME FAILED.\")\n    return all_pass\n\n\ndef run_benchmark(device, dtype, args):\n    print()\n    print(\"=\" * 85)\n    print(\"Benchmark: Baseline GDN Decode vs Packed GDN Decode\")\n    print(\"=\" * 85)\n\n    K = args.head_size_k\n    V = args.head_size_v\n    pool_size = args.pool_size\n\n    if args.preset == \"qwen3.5-35b\":\n        # Qwen3.5-35B-A3B: H_qk=16, H_v=32, K=128, V=128\n        # After TP=2: H=8, HV=16\n        bench_configs = [\n            # (B,   H,  HV) — TP=2 config\n            (1, 8, 16),\n            (2, 8, 16),\n            (4, 8, 16),\n            (8, 8, 16),\n            (16, 8, 16),\n            (32, 8, 16),\n            (64, 8, 16),\n            (128, 8, 16),\n            (256, 8, 16),\n            (512, 8, 16),\n            # TP=1 config (full heads)\n            (1, 16, 32),\n            (8, 16, 32),\n            (32, 16, 32),\n            (64, 16, 32),\n            (128, 16, 32),\n            (256, 16, 32),\n        ]\n    elif args.preset == \"qwen3-next-80b\":\n        bench_configs = [\n            # Qwen3-Next-80B-A3B: all same H=HV=16 after TP\n            (1, 16, 16),\n            (8, 16, 16),\n            (32, 16, 16),\n            (64, 16, 16),\n            (128, 16, 16),\n            (256, 16, 16),\n        ]\n    else:\n        bench_configs = []\n        for B in args.batch_sizes:\n            for H in args.num_q_heads:\n                for HV in args.num_v_heads:\n                    bench_configs.append((B, H, 
HV))\n\n    print(f\"  Config: K={K}, V={V}, pool_size={pool_size}, dtype={dtype}\")\n    print(\n        f\"  {'B':>5}  {'H':>3}  {'HV':>3}  {'K':>3}  {'V':>3} | \"\n        f\"{'base (μs)':>10} | \"\n        f\"{'packed (μs)':>10} | \"\n        f\"{'speedup':>8} | \"\n        f\"{'saved (μs)':>10}\"\n    )\n    print(\"  \" + \"-\" * 75)\n\n    for B, H, HV in bench_configs:\n        actual_pool = max(pool_size, B + 16)\n        bench_shape(B, H, HV, K, V, actual_pool, device, dtype)\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Benchmark & Correctness: GDN Packed Decode vs Baseline\"\n    )\n    parser.add_argument(\n        \"--mode\",\n        choices=[\"all\", \"correctness\", \"bench\"],\n        default=\"all\",\n        help=\"Run mode (default: all)\",\n    )\n    parser.add_argument(\n        \"--preset\",\n        choices=[\"qwen3.5-35b\", \"qwen3-next-80b\", \"custom\"],\n        default=\"qwen3.5-35b\",\n        help=\"Preset config (default: qwen3.5-35b)\",\n    )\n    parser.add_argument(\n        \"--dtype\",\n        choices=[\"float16\", \"bfloat16\", \"float32\"],\n        default=\"bfloat16\",\n    )\n    parser.add_argument(\"--head-size-k\", type=int, default=128)\n    parser.add_argument(\"--head-size-v\", type=int, default=128)\n    parser.add_argument(\"--pool-size\", type=int, default=512)\n    parser.add_argument(\n        \"--batch-sizes\",\n        type=int,\n        nargs=\"+\",\n        default=[1, 4, 8, 16, 32, 64, 128, 256, 512],\n    )\n    parser.add_argument(\n        \"--num-q-heads\",\n        type=int,\n        nargs=\"+\",\n        default=[8, 16],\n    )\n    parser.add_argument(\n        \"--num-v-heads\",\n        type=int,\n        nargs=\"+\",\n        default=[16, 32],\n    )\n    args = parser.parse_args()\n\n    device = \"cuda\"\n    dtype = getattr(torch, args.dtype)\n\n    cap = torch.cuda.get_device_capability()\n    dev_name = torch.cuda.get_device_name()\n    print(f\"Device: 
{dev_name}  (SM {cap[0]}{cap[1]})\")\n\n    if args.mode in (\"all\", \"correctness\"):\n        all_pass = run_correctness(device, dtype)\n        if not all_pass and args.mode == \"all\":\n            print(\"\\nSkipping benchmark due to correctness failures.\")\n            return 1\n\n    if args.mode in (\"all\", \"bench\"):\n        run_benchmark(device, dtype, args)\n\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "benchmark/bench_linear_attention/bench_gdn_prefill.py",
    "content": "\"\"\"\nBenchmark & Correctness: Triton GDN vs FlashInfer GDN (prefill).\n\nCompares:\n  - Triton:     sglang's chunk_gated_delta_rule (K-contiguous pool, pool-indexed)\n  - FlashInfer: flashinfer's chunk_gated_delta_rule (gather/scatter, 3D tensors)\n\nThe two kernels have different APIs:\n  - Triton:     q/k/v=[1,T,H,D], g=logsigmoid, beta=sigmoid, has initial_state_indices\n  - FlashInfer: q/k/v=[T,H,D],   g=alpha(float32), beta=float32, no indices (gathered state)\n\nReports correctness (output matching and K-contiguous pool layout) and performance\n(ms, TFLOPS, TB/s).\n\nUsage:\n    python bench_gdn_prefill.py                              # default sweep\n    python bench_gdn_prefill.py --mode bench                 # benchmark only\n    python bench_gdn_prefill.py --mode correctness           # correctness only\n    python bench_gdn_prefill.py --preset qwen3-next          # Qwen3-Next config\n\"\"\"\n\nimport argparse\nimport os\nimport sys\n\n# Make <repo>/python importable (this file is two levels below the repo root)\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"..\", \"python\"))\n\nimport torch\nfrom flashinfer.gdn_prefill import (\n    chunk_gated_delta_rule as flashinfer_chunk_gated_delta_rule,\n)\n\nfrom sglang.srt.layers.attention.fla.chunk import (\n    chunk_gated_delta_rule as triton_chunk_gated_delta_rule,\n)\nfrom sglang.srt.layers.attention.fla.l2norm import l2norm_fwd\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef make_k_contiguous(t: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Given a V-contiguous tensor [..., K, V], return a K-contiguous view of the\n    same logical shape [..., K, V] (physically [..., V, K], K-last).\n    \"\"\"\n    return t.transpose(-2, -1).contiguous().transpose(-2, -1)\n\n\ndef gdn_flops(\n    total_seq_len: int,\n    num_heads: int,\n    head_size_k: int,\n    head_size_v: int,\n) -> int:\n    \"\"\"\n    FLOPs for GDN prefill (delta rule).\n\n    Per 
token per head:\n      1. k @ v^T (outer product):  2 * K * V\n      2. q @ state (output):       2 * K * V\n    \"\"\"\n    outer_product_flops = 2 * total_seq_len * num_heads * head_size_k * head_size_v\n    output_flops = 2 * total_seq_len * num_heads * head_size_k * head_size_v\n    return outer_product_flops + output_flops\n\n\ndef gdn_bytes(\n    total_seq_len: int,\n    num_q_heads: int,\n    num_v_heads: int,\n    head_size_k: int,\n    head_size_v: int,\n    num_seqs: int,\n    dtype: torch.dtype,\n) -> int:\n    \"\"\"Memory bytes accessed (inputs + outputs + state).\"\"\"\n    num_o_heads = max(num_q_heads, num_v_heads)\n    elem = dtype.itemsize\n\n    # q and k share the q/k head count; v has num_v_heads\n    q_bytes = total_seq_len * num_q_heads * head_size_k * elem\n    k_bytes = total_seq_len * num_q_heads * head_size_k * elem\n    v_bytes = total_seq_len * num_v_heads * head_size_v * elem\n    o_bytes = total_seq_len * num_o_heads * head_size_v * elem\n\n    # state (float32): read + write\n    state_bytes = 2 * num_seqs * num_o_heads * head_size_k * head_size_v * 4\n\n    # g, beta (float32)\n    g_bytes = total_seq_len * num_o_heads * 4\n    beta_bytes = total_seq_len * num_o_heads * 4\n\n    return q_bytes + k_bytes + v_bytes + o_bytes + state_bytes + g_bytes + beta_bytes\n\n\n# ---------------------------------------------------------------------------\n# Input factory\n# ---------------------------------------------------------------------------\n\n\ndef make_inputs(\n    B: int,\n    T_per_seq: int,\n    H: int,\n    K: int,\n    V: int,\n    pool_size: int,\n    device: str,\n    dtype: torch.dtype,\n    sequential_indices: bool = False,\n    seed: int = 42,\n):\n    \"\"\"Create all input tensors for a single benchmark / correctness run.\n\n    Returns a dict of Triton-format tensors plus metadata; run_flashinfer derives\n    the FlashInfer-format tensors from them.\n    \"\"\"\n    T = B * T_per_seq\n    torch.manual_seed(seed)\n\n    if sequential_indices:\n        cache_indices = torch.arange(B, dtype=torch.int32, device=device)\n    else:\n        
perm = torch.randperm(pool_size, device=device)[:B]\n        cache_indices = perm.to(torch.int32)\n\n    pool_init = torch.randn(pool_size, H, K, V, dtype=dtype, device=device) * 0.1\n\n    cu_seqlens = torch.arange(\n        0, (B + 1) * T_per_seq, T_per_seq, dtype=torch.long, device=device\n    )\n\n    # Triton format: [1, T, H, D]\n    q = torch.randn(1, T, H, K, dtype=dtype, device=device)\n    k = torch.randn(1, T, H, K, dtype=dtype, device=device)\n    v = torch.randn(1, T, H, V, dtype=dtype, device=device)\n\n    # g (logsigmoid) and beta (sigmoid) in Triton format: [1, T, H]\n    g_raw = torch.randn(1, T, H, dtype=dtype, device=device)\n    g_triton = torch.nn.functional.logsigmoid(g_raw)  # logsigmoid for Triton\n    beta_triton = torch.sigmoid(torch.randn(1, T, H, dtype=dtype, device=device))\n\n    return dict(\n        B=B,\n        T=T,\n        T_per_seq=T_per_seq,\n        H=H,\n        K=K,\n        V=V,\n        pool_size=pool_size,\n        cache_indices=cache_indices,\n        pool_init=pool_init,\n        cu_seqlens=cu_seqlens,\n        q=q,\n        k=k,\n        v=v,\n        g_triton=g_triton,\n        beta_triton=beta_triton,\n    )\n\n\n# ---------------------------------------------------------------------------\n# Runner wrappers\n# ---------------------------------------------------------------------------\n\n\ndef run_triton(inp):\n    \"\"\"Triton path: K-contiguous pool, pool-indexed, [1,T,H,D] tensors.\"\"\"\n    pool = make_k_contiguous(inp[\"pool_init\"].clone())\n\n    o, _, h = triton_chunk_gated_delta_rule(\n        q=inp[\"q\"],\n        k=inp[\"k\"],\n        v=inp[\"v\"],\n        g=inp[\"g_triton\"],\n        beta=inp[\"beta_triton\"],\n        initial_state=pool,\n        initial_state_indices=inp[\"cache_indices\"],\n        cu_seqlens=inp[\"cu_seqlens\"],\n        head_first=False,\n        use_qk_l2norm_in_kernel=True,\n    )\n    return o, pool, h\n\n\ndef run_flashinfer(inp):\n    \"\"\"FlashInfer path: matches sglang 
FlashInferGDNKernel.extend() exactly.\n\n    Key differences from Triton path:\n      - q, k are L2-normalized BEFORE calling the kernel\n      - use_qk_l2norm_in_kernel=False (kernel skips internal normalization)\n      - Tensors are [T, H, D] (no batch dim)\n      - g is alpha = exp(logsigmoid(...)) = sigmoid(...), float32\n      - beta is float32\n      - initial_state is gathered from pool (no pool-index support)\n      - Uses keyword arguments (matching sglang production code)\n\n    NOTE: FlashInfer GDN requires K == V (square head_size).\n    \"\"\"\n    K = inp[\"K\"]\n    V = inp[\"V\"]\n    assert K == V, f\"FlashInfer GDN requires K == V, got K={K}, V={V}\"\n\n    pool = make_k_contiguous(inp[\"pool_init\"].clone())\n    cache_indices = inp[\"cache_indices\"]\n\n    # Gather states from K-contiguous pool -> K-contiguous float32\n    # In production, ssm_states is already float32 so .float() is no-op.\n    # Here pool_init is bf16, so .float() loses K-contiguous layout.\n    gathered = pool[cache_indices]\n    initial_state = make_k_contiguous(gathered.float().contiguous())\n\n    q_fi = l2norm_fwd(inp[\"q\"][0].contiguous())\n    k_fi = l2norm_fwd(inp[\"k\"][0].contiguous())\n    v_fi = inp[\"v\"][0].contiguous()\n\n    # g -> alpha (exp of logsigmoid = sigmoid), float32\n    alpha_fi = torch.exp(inp[\"g_triton\"][0].to(torch.float32))\n    # beta -> float32\n    beta_fi = inp[\"beta_triton\"][0].to(torch.float32)\n\n    cu_seqlens_fi = inp[\"cu_seqlens\"].to(torch.int64)\n\n    # Call FlashInfer with keyword args (matching sglang production code)\n    # use_qk_l2norm_in_kernel=False because we pre-normalized above\n    o_fi, state_fi = flashinfer_chunk_gated_delta_rule(\n        q=q_fi,\n        k=k_fi,\n        v=v_fi,\n        g=alpha_fi,\n        beta=beta_fi,\n        scale=None,\n        initial_state=initial_state,\n        output_final_state=True,\n        cu_seqlens=cu_seqlens_fi,\n        use_qk_l2norm_in_kernel=False,\n    )\n\n    # Scatter 
updated states back to K-contiguous pool\n    pool[cache_indices] = state_fi.to(pool.dtype)\n\n    # Reshape output: [T, H, D] -> [1, T, H, D] to match Triton\n    o_out = o_fi.unsqueeze(0)\n\n    return o_out, pool, state_fi\n\n\n# ---------------------------------------------------------------------------\n# Correctness check\n# ---------------------------------------------------------------------------\n\n\ndef check_shape(\n    B,\n    T_per_seq,\n    H,\n    K,\n    V,\n    pool_size,\n    device,\n    dtype,\n    sequential_indices=False,\n    seed=42,\n):\n    \"\"\"Run correctness check for a single shape config. Returns True if PASS\n    (skipped configs also return True).\n\n    Pass/fail is based on the OUTPUT comparison (atol=5e-2, rtol=1e-2) and a\n    K-contiguous stride check on both pools. Final pool states are not compared:\n    state divergence over many tokens is expected due to different chunk sizes\n    and accumulation order.\n    \"\"\"\n    tag = (\n        f\"B={B:>3} T/seq={T_per_seq:>4} H={H:>2} K={K:>3} V={V:>3} pool={pool_size:>4}\"\n    )\n    idx_tag = \" (seq)\" if sequential_indices else \"\"\n\n    # FlashInfer GDN requires K == V (square head_size)\n    if K != V:\n        print(f\"  [SKIP] {tag}{idx_tag}  (FlashInfer requires K==V)\")\n        return True\n\n    # FlashInfer GDN CUTLASS kernels are only compiled for head_size=128.\n    # Running with other sizes causes illegal memory access that poisons\n    # the CUDA context (unrecoverable), so we must skip upfront.\n    FLASHINFER_SUPPORTED_HEAD_SIZES = {128}\n    if K not in FLASHINFER_SUPPORTED_HEAD_SIZES:\n        print(\n            f\"  [SKIP] {tag}{idx_tag}  (FlashInfer only supports head_size={FLASHINFER_SUPPORTED_HEAD_SIZES})\"\n        )\n        return True\n\n    inp = make_inputs(\n        B,\n        T_per_seq,\n        H,\n        K,\n        V,\n        pool_size,\n        device,\n        dtype,\n        sequential_indices=sequential_indices,\n        seed=seed,\n    )\n\n    o_triton, pool_triton, h_triton = run_triton(inp)\n\n    # FlashInfer may 
not support all head_size values (e.g., only 128).\n    # CUDA errors from unsupported configs are often asynchronous, so we\n    # must synchronize inside the try block to catch them here.\n    try:\n        o_fi, pool_fi, _ = run_flashinfer(inp)\n        torch.cuda.synchronize()\n    except Exception as e:\n        # Catch RuntimeError, torch.AcceleratorError, etc.\n        # Best-effort: synchronize again so subsequent tests can proceed; sticky\n        # errors (e.g. illegal memory access) leave the CUDA context unusable regardless.\n        try:\n            torch.cuda.synchronize()\n        except Exception:\n            pass\n        print(f\"  [SKIP] {tag}{idx_tag}  (FlashInfer error: {e})\")\n        return True\n\n    # --- Output comparison ---\n    # Tolerances cover bf16 prefill with L2norm + chunked accumulation\n    try:\n        torch.testing.assert_close(o_triton, o_fi, atol=5e-2, rtol=1e-2)\n        output_ok = True\n    except AssertionError:\n        output_ok = False\n\n    # --- Stride check ---\n    def strides_ok(pool):\n        s = pool.stride()\n        return s[-2] == 1 and s[-1] == K\n\n    strides_triton = strides_ok(pool_triton)\n    strides_fi = strides_ok(pool_fi)\n\n    passed = output_ok and strides_triton and strides_fi\n\n    # Build detail string\n    details = []\n    if not output_ok:\n        details.append(\"output mismatch\")\n    if not strides_triton:\n        details.append(\"triton strides bad\")\n    if not strides_fi:\n        details.append(\"flashinfer strides bad\")\n\n    status = \"PASS\" if passed else \"FAIL\"\n    detail_str = f\"  [{', '.join(details)}]\" if details else \"\"\n    print(f\"  [{status}] {tag}{idx_tag}{detail_str}\")\n    return passed\n\n\n# ---------------------------------------------------------------------------\n# Benchmark\n# ---------------------------------------------------------------------------\n\n\ndef bench_shape(B, H, T_per_seq, K, V, pool_size, device, dtype):\n    \"\"\"Benchmark Triton vs FlashInfer for a single config. 
Requires K == V.\"\"\"\n    import triton.testing\n\n    assert K == V, f\"FlashInfer GDN requires K == V, got K={K}, V={V}\"\n\n    T = B * T_per_seq\n    inp = make_inputs(B, T_per_seq, H, K, V, pool_size, device, dtype)\n\n    # -- Shared read-only tensors --\n    q, k_t, v = inp[\"q\"], inp[\"k\"], inp[\"v\"]\n    g_triton, beta_triton = inp[\"g_triton\"], inp[\"beta_triton\"]\n    cu_seqlens = inp[\"cu_seqlens\"]\n    cache_indices = inp[\"cache_indices\"]\n    pool_v = inp[\"pool_init\"]\n\n    def fn_triton():\n        pool = make_k_contiguous(pool_v.clone())\n        triton_chunk_gated_delta_rule(\n            q=q,\n            k=k_t,\n            v=v,\n            g=g_triton,\n            beta=beta_triton,\n            initial_state=pool,\n            initial_state_indices=cache_indices,\n            cu_seqlens=cu_seqlens,\n            head_first=False,\n            use_qk_l2norm_in_kernel=True,\n        )\n\n    def fn_flashinfer():\n        # -- Convert to FlashInfer format INSIDE the timed function, so the timing\n        #    includes the same pre-processing as run_flashinfer (l2norm, casts, gather) --\n        # Pre-normalize q and k (matching sglang production: l2norm_fwd)\n        q_fi = l2norm_fwd(q[0].contiguous())\n        k_fi = l2norm_fwd(k_t[0].contiguous())\n        v_fi = v[0].contiguous()\n        alpha_fi = torch.exp(g_triton[0].to(torch.float32))\n        beta_fi = beta_triton[0].to(torch.float32)\n        cu_seqlens_fi = cu_seqlens.to(torch.int64)\n        pool = make_k_contiguous(pool_v.clone())\n        gathered = pool[cache_indices]\n        initial_state = make_k_contiguous(gathered.float().contiguous())\n        flashinfer_chunk_gated_delta_rule(\n            q=q_fi,\n            k=k_fi,\n            v=v_fi,\n     
       g=alpha_fi,\n            beta=beta_fi,\n            scale=None,\n            initial_state=initial_state,\n            output_final_state=True,\n            cu_seqlens=cu_seqlens_fi,\n            use_qk_l2norm_in_kernel=False,\n        )\n\n    quantiles = [0.5, 0.2, 0.8]\n\n    # Warmup\n    fn_triton()\n    fn_flashinfer()\n    torch.cuda.synchronize()\n\n    ms_triton, _, _ = triton.testing.do_bench_cudagraph(fn_triton, quantiles=quantiles)\n    ms_fi, _, _ = triton.testing.do_bench_cudagraph(fn_flashinfer, quantiles=quantiles)\n\n    # Metrics\n    num_o_heads = H\n    flops = gdn_flops(T, num_o_heads, K, V)\n    mem_bytes = gdn_bytes(T, H, H, K, V, B, dtype)\n\n    tflops_triton = flops / ms_triton / 1e9\n    tflops_fi = flops / ms_fi / 1e9\n    tb_s_triton = mem_bytes / ms_triton / 1e9\n    tb_s_fi = mem_bytes / ms_fi / 1e9\n\n    speedup = ms_triton / ms_fi if ms_fi > 0 else float(\"inf\")\n\n    print(\n        f\"  {B:>5}  {H:>3}  {T_per_seq:>6}  {T:>7} | \"\n        f\"{ms_triton:>8.3f}  {tflops_triton:>7.2f}  {tb_s_triton:>7.2f} | \"\n        f\"{ms_fi:>8.3f}  {tflops_fi:>7.2f}  {tb_s_fi:>7.2f} | \"\n        f\"{speedup:>7.2f}x\"\n    )\n\n\n# ---------------------------------------------------------------------------\n# Main\n# ---------------------------------------------------------------------------\n\n\ndef run_correctness(device, dtype):\n    print(\"=\" * 78)\n    print(\"Correctness sweep: Triton vs FlashInfer\")\n    print(\"=\" * 78)\n\n    shapes = [\n        # (B, T_per_seq, H,  K,   V,   pool_size)\n        # --- baseline (Qwen3-Next style) ---\n        (4, 64, 16, 128, 128, 32),\n        (4, 256, 16, 128, 128, 32),\n        # --- different batch sizes ---\n        (1, 128, 16, 128, 128, 32),\n        (8, 128, 16, 128, 128, 64),\n        (16, 64, 16, 128, 128, 128),\n        (32, 32, 16, 128, 128, 256),\n        # --- different head counts ---\n        (4, 128, 4, 128, 128, 32),\n        (4, 128, 8, 128, 128, 32),\n        (4, 128, 
16, 64, 64, 32),\n        (4, 128, 32, 128, 128, 32),\n        (4, 128, 64, 128, 128, 32),\n        # --- short sequences ---\n        (4, 1, 16, 128, 128, 32),\n        (4, 7, 16, 128, 128, 32),\n        (4, 16, 16, 128, 128, 32),\n        # --- large pool (sparse access) ---\n        (4, 128, 16, 128, 128, 512),\n        # --- combined stress ---\n        (32, 128, 32, 128, 128, 256),\n    ]\n\n    shapes_seq = [\n        (8, 128, 16, 128, 128, 8),\n        (4, 128, 32, 128, 128, 4),\n        (4, 128, 64, 128, 128, 4),\n        (32, 128, 32, 128, 128, 32),\n    ]\n\n    all_pass = True\n    for B, T_per_seq, H, K, V, pool_size in shapes:\n        if not check_shape(B, T_per_seq, H, K, V, pool_size, device, dtype):\n            all_pass = False\n\n    print()\n    print(\"Sequential-index variants:\")\n    for B, T_per_seq, H, K, V, pool_size in shapes_seq:\n        if not check_shape(\n            B,\n            T_per_seq,\n            H,\n            K,\n            V,\n            pool_size,\n            device,\n            dtype,\n            sequential_indices=True,\n        ):\n            all_pass = False\n\n    print()\n    if all_pass:\n        print(\"ALL PASSED.\")\n    else:\n        print(\"SOME FAILED.\")\n    return all_pass\n\n\ndef run_benchmark(device, dtype, args):\n    print()\n    print(\"=\" * 105)\n    print(\"Benchmark: Triton GDN vs FlashInfer GDN  (do_bench_cudagraph)\")\n    print(\"=\" * 105)\n\n    K = args.head_size_k\n    V = args.head_size_v\n    pool_size = args.pool_size\n\n    if args.preset == \"qwen3-next\":\n        bench_configs = [\n            # (B,   H, T_per_seq)\n            (4, 16, 256),\n            (4, 32, 256),\n            (16, 16, 256),\n            (16, 32, 256),\n            (32, 16, 256),\n            (32, 32, 256),\n            (64, 16, 256),\n            (64, 32, 256),\n            (128, 16, 256),\n            (128, 32, 256),\n            # longer sequences\n            (4, 16, 1024),\n            (4, 32, 
1024),\n            (32, 16, 1024),\n            (32, 32, 1024),\n        ]\n    else:\n        bench_configs = []\n        for B in args.batch_sizes:\n            for H in args.num_heads:\n                for T_per_seq in args.seq_lens:\n                    bench_configs.append((B, H, T_per_seq))\n\n    print(f\"  Config: K={K}, V={V}, pool_size={pool_size}, dtype={dtype}\")\n    print(\n        f\"  {'B':>5}  {'H':>3}  {'T/seq':>6}  {'T_tot':>7} | \"\n        f\"{'tri(ms)':>8}  {'TFLOPS':>7}  {'TB/s':>7} | \"\n        f\"{'fi(ms)':>8}  {'TFLOPS':>7}  {'TB/s':>7} | \"\n        f\"{'speedup':>8}\"\n    )\n    print(\"  \" + \"-\" * 98)\n\n    for B, H, T_per_seq in bench_configs:\n        actual_pool = max(pool_size, B)\n        bench_shape(B, H, T_per_seq, K, V, actual_pool, device, dtype)\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Benchmark & Correctness: Triton GDN vs FlashInfer GDN\"\n    )\n    parser.add_argument(\n        \"--mode\",\n        choices=[\"all\", \"correctness\", \"bench\"],\n        default=\"all\",\n        help=\"Run mode (default: all)\",\n    )\n    parser.add_argument(\n        \"--preset\",\n        choices=[\"qwen3-next\", \"custom\"],\n        default=\"qwen3-next\",\n        help=\"Preset config (default: qwen3-next)\",\n    )\n    parser.add_argument(\n        \"--dtype\",\n        choices=[\"float16\", \"bfloat16\"],\n        default=\"bfloat16\",\n    )\n    parser.add_argument(\"--head-size-k\", type=int, default=128)\n    parser.add_argument(\"--head-size-v\", type=int, default=128)\n    parser.add_argument(\"--pool-size\", type=int, default=256)\n    parser.add_argument(\n        \"--batch-sizes\",\n        type=int,\n        nargs=\"+\",\n        default=[4, 16, 32, 64, 128],\n    )\n    parser.add_argument(\n        \"--num-heads\",\n        type=int,\n        nargs=\"+\",\n        default=[16, 32],\n    )\n    parser.add_argument(\n        \"--seq-lens\",\n        type=int,\n        
nargs=\"+\",\n        default=[128, 256, 512, 1024],\n    )\n    args = parser.parse_args()\n\n    if args.preset == \"qwen3-next\":\n        args.head_size_k = 128\n        args.head_size_v = 128\n\n    device = \"cuda\"\n    dtype = getattr(torch, args.dtype)\n\n    # Report device name and SM version\n    cap = torch.cuda.get_device_capability()\n    dev_name = torch.cuda.get_device_name()\n    print(f\"Device: {dev_name}  (SM {cap[0]}{cap[1]})\")\n\n    if args.mode in (\"all\", \"correctness\"):\n        all_pass = run_correctness(device, dtype)\n        if not all_pass and args.mode == \"all\":\n            print(\"\\nSkipping benchmark due to correctness failures.\")\n            return 1\n\n    if args.mode in (\"all\", \"bench\"):\n        run_benchmark(device, dtype, args)\n\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "benchmark/bench_rope/benchmark_rope_index.py",
    "content": "# This script benchmarks MRotaryEmbedding.get_rope_index_glm4v (GLM4V mrope index builder).\n# It generates synthetic multimodal input_ids + attention_mask (+ optional image/video grids)\n# and times both the multimodal branch and the text-only fallback branch (no grids).\n#\n# == Usage Examples ==\n#\n# python3 benchmark_rope_index.py --device cuda --num-tokens 1024 2048 --benchmark-iter 200\n\nimport argparse\nimport math\nimport time\nfrom dataclasses import dataclass, field\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom sglang.srt.layers.rotary_embedding import MRotaryEmbedding\n\n\n# -----------------------------\n# Minimal config objects\n# -----------------------------\n@dataclass\nclass DummyVisionConfig:\n    spatial_merge_size: int = 2\n\n\n@dataclass\nclass DummyHFConfig:\n    image_token_id: int = 32000\n    video_start_token_id: int = 32001\n    video_end_token_id: int = 32002\n    vision_config: DummyVisionConfig = field(\n        default_factory=lambda: DummyVisionConfig(spatial_merge_size=2)\n    )\n\n\n# -----------------------------\n# Helpers\n# -----------------------------\ndef calculate_stats(times: list[float]) -> dict[str, float]:\n    \"\"\"Calculate statistics from a list of times.\"\"\"\n    times_array = np.array(times, dtype=np.float64)\n    return {\n        \"mean\": float(np.mean(times_array)),\n        \"median\": float(np.median(times_array)),\n        \"p99\": float(np.percentile(times_array, 99)),\n        \"min\": float(np.min(times_array)),\n        \"max\": float(np.max(times_array)),\n    }\n\n\ndef _sync(device: torch.device):\n    if device.type == \"cuda\":\n        torch.cuda.synchronize()\n\n\ndef _approx_hw(patches: int, merge: int) -> tuple[int, int]:\n    # want (h/merge)*(w/merge) ~= patches\n    gh = int(math.sqrt(max(1, patches)))\n    gw = max(1, patches // max(1, gh))\n    return gh * merge, gw * merge\n\n\ndef generate_test_data(\n    num_tokens: int,\n    batch_size: int,\n    hf_config: DummyHFConfig,\n    dtype: torch.dtype,\n    
device: torch.device,\n    pad_ratio: float,\n    num_images_per_sample: int,\n    image_patch_tokens: int,\n    num_videos_per_sample: int,\n    video_patch_tokens: int,\n    seed: int,\n):\n    \"\"\"\n    Generate synthetic (input_ids, attention_mask, image_grid_thw, video_grid_thw).\n\n    NOTE:\n      - image_grid_thw / video_grid_thw are global lists across the entire batch in encounter order,\n        matching the function's image_index/video_index behavior.\n      - image patches are represented by repeated image_token_id.\n      - video patches are represented by image_token_id wrapped with start/end tokens.\n    \"\"\"\n    torch.manual_seed(seed)\n\n    forbidden = {\n        0,\n        hf_config.image_token_id,\n        hf_config.video_start_token_id,\n        hf_config.video_end_token_id,\n    }\n    vocab_size = 50000\n\n    def rand_text(n: int) -> torch.Tensor:\n        # generate random ids not in forbidden\n        out = torch.randint(1, vocab_size, (n,), device=device, dtype=torch.long)\n        # fix forbidden by +1 until ok (cheap, deterministic enough for benchmark data)\n        for bad in forbidden:\n            out = torch.where(out == bad, out + 1, out)\n        return out\n\n    image_grids: list[list[int]] = []\n    video_grids: list[list[int]] = []\n\n    input_ids = torch.zeros((batch_size, num_tokens), device=device, dtype=torch.long)\n    attention_mask = torch.zeros(\n        (batch_size, num_tokens), device=device, dtype=torch.long\n    )\n\n    eff_len = int(round(num_tokens * (1.0 - pad_ratio)))\n    eff_len = max(1, min(num_tokens, eff_len))\n\n    min_needed = 1\n    min_needed += num_images_per_sample * image_patch_tokens\n    min_needed += num_videos_per_sample * (2 + video_patch_tokens)\n    if eff_len < min_needed:\n        num_images_per_sample = 0\n        num_videos_per_sample = 0\n\n    for b in range(batch_size):\n        blocks: list[torch.Tensor] = []\n\n        reserved = (\n            num_images_per_sample * 
image_patch_tokens\n            + num_videos_per_sample * (2 + video_patch_tokens)\n        )\n        reserved = min(reserved, max(0, eff_len - 1))\n        text_budget = max(1, eff_len - reserved)\n\n        n_text_chunks = num_images_per_sample + num_videos_per_sample + 1\n        base = text_budget // n_text_chunks\n        rem = text_budget % n_text_chunks\n        text_chunks = [base + (1 if i < rem else 0) for i in range(n_text_chunks)]\n\n        tci = 0\n        for _ in range(num_images_per_sample):\n            blocks.append(rand_text(text_chunks[tci]))\n            tci += 1\n            blocks.append(\n                torch.full(\n                    (image_patch_tokens,),\n                    hf_config.image_token_id,\n                    device=device,\n                    dtype=torch.long,\n                )\n            )\n\n            h, w = _approx_hw(\n                image_patch_tokens, hf_config.vision_config.spatial_merge_size\n            )\n            image_grids.append([1, h, w])\n\n        for _ in range(num_videos_per_sample):\n            blocks.append(rand_text(text_chunks[tci]))\n            tci += 1\n            blocks.append(\n                torch.tensor(\n                    [hf_config.video_start_token_id], device=device, dtype=torch.long\n                )\n            )\n            blocks.append(\n                torch.full(\n                    (video_patch_tokens,),\n                    hf_config.image_token_id,\n                    device=device,\n                    dtype=torch.long,\n                )\n            )\n            blocks.append(\n                torch.tensor(\n                    [hf_config.video_end_token_id], device=device, dtype=torch.long\n                )\n            )\n\n            h, w = _approx_hw(\n                video_patch_tokens, hf_config.vision_config.spatial_merge_size\n            )\n            # first field = group count used by code; set to 1\n            video_grids.append([1, h, 
w])\n\n        blocks.append(rand_text(text_chunks[tci]))\n\n        tokens = torch.cat(blocks, dim=0)[:eff_len]\n        pad = torch.zeros(\n            (num_tokens - tokens.numel(),), device=device, dtype=torch.long\n        )\n        ids = torch.cat([tokens, pad], dim=0)\n\n        mask = torch.cat(\n            [\n                torch.ones((tokens.numel(),), device=device, dtype=torch.long),\n                torch.zeros(\n                    (num_tokens - tokens.numel(),), device=device, dtype=torch.long\n                ),\n            ],\n            dim=0,\n        )\n\n        input_ids[b] = ids\n        attention_mask[b] = mask\n\n    image_grid_thw = (\n        torch.tensor(image_grids, device=device, dtype=torch.long)\n        if len(image_grids)\n        else None\n    )\n    video_grid_thw = (\n        torch.tensor(video_grids, device=device, dtype=torch.long)\n        if len(video_grids)\n        else None\n    )\n    return (\n        input_ids.to(dtype=torch.long),\n        attention_mask.to(dtype=torch.long),\n        image_grid_thw,\n        video_grid_thw,\n    )\n\n\ndef benchmark_rope_index(\n    model_name: str,\n    tp_size: int,\n    num_tokens: int,\n    batch_size: int,\n    pad_ratio: float,\n    spatial_merge_size: int,\n    num_images: int,\n    image_patch_tokens: int,\n    num_videos: int,\n    video_patch_tokens: int,\n    dtype: torch.dtype,\n    seed: int,\n    warmup_iter: int,\n    benchmark_iter: int,\n    device: torch.device,\n):\n    torch.manual_seed(seed)\n    hf_config = DummyHFConfig(\n        image_token_id=32000,\n        video_start_token_id=32001,\n        video_end_token_id=32002,\n        vision_config=DummyVisionConfig(spatial_merge_size=spatial_merge_size),\n    )\n\n    print(80 * \"=\")\n    print(\n        f\"Evaluating: {model_name} tp_size={tp_size} \"\n        f\"num_tokens={num_tokens} batch={batch_size} pad_ratio={pad_ratio} \"\n        f\"images/sample={num_images} 
image_patch_tokens={image_patch_tokens} \"\n        f\"videos/sample={num_videos} video_patch_tokens={video_patch_tokens} \"\n        f\"dtype={dtype} device={device}\"\n    )\n\n    input_ids, attention_mask, image_grid_thw, video_grid_thw = generate_test_data(\n        num_tokens=num_tokens,\n        batch_size=batch_size,\n        hf_config=hf_config,\n        dtype=dtype,\n        device=device,\n        pad_ratio=pad_ratio,\n        num_images_per_sample=num_images,\n        image_patch_tokens=image_patch_tokens,\n        num_videos_per_sample=num_videos,\n        video_patch_tokens=video_patch_tokens,\n        seed=seed,\n    )\n\n    # Smoke test\n    has_mm = (image_grid_thw is not None) or (video_grid_thw is not None)\n    if has_mm:\n        pos, delta = MRotaryEmbedding.get_rope_index_glm4v(\n            input_ids=input_ids,\n            hf_config=hf_config,\n            image_grid_thw=image_grid_thw,\n            video_grid_thw=video_grid_thw,\n            attention_mask=attention_mask,\n        )\n        assert pos.shape == (3, batch_size, num_tokens)\n        assert delta.shape == (batch_size, 1)\n\n    # Warm up\n    for _ in range(warmup_iter):\n        if has_mm:\n            MRotaryEmbedding.get_rope_index_glm4v(\n                input_ids=input_ids,\n                hf_config=hf_config,\n                image_grid_thw=image_grid_thw,\n                video_grid_thw=video_grid_thw,\n                attention_mask=attention_mask,\n            )\n        MRotaryEmbedding.get_rope_index_glm4v(\n            input_ids=input_ids,\n            hf_config=hf_config,\n            image_grid_thw=None,\n            video_grid_thw=None,\n            attention_mask=attention_mask,\n        )\n\n    _sync(device)\n\n    # Time multimodal branch\n    multimodal_times = []\n    for _ in range(benchmark_iter):\n        _sync(device)\n        start = time.time()\n        MRotaryEmbedding.get_rope_index_glm4v(\n            input_ids=input_ids,\n            
hf_config=hf_config,\n            image_grid_thw=image_grid_thw,\n            video_grid_thw=video_grid_thw,\n            attention_mask=attention_mask,\n        )\n        _sync(device)\n        multimodal_times.append(time.time() - start)\n\n    # Time fallback branch\n    fallback_times = []\n    for _ in range(benchmark_iter):\n        _sync(device)\n        start = time.time()\n        MRotaryEmbedding.get_rope_index_glm4v(\n            input_ids=input_ids,\n            hf_config=hf_config,\n            image_grid_thw=None,\n            video_grid_thw=None,\n            attention_mask=attention_mask,\n        )\n        _sync(device)\n        fallback_times.append(time.time() - start)\n\n    multimodal_stats = calculate_stats(multimodal_times)\n    fallback_stats = calculate_stats(fallback_times)\n\n    print(f\"\\nPerformance for config (B={batch_size}, T={num_tokens}):\")\n    print(\n        f\"Multimodal: mean={multimodal_stats['mean']:.8f}s, \"\n        f\"median={multimodal_stats['median']:.8f}s, \"\n        f\"p99={multimodal_stats['p99']:.8f}s\"\n    )\n    print(\n        f\"Fallback:   mean={fallback_stats['mean']:.8f}s, \"\n        f\"median={fallback_stats['median']:.8f}s, \"\n        f\"p99={fallback_stats['p99']:.8f}s\"\n    )\n\n    if has_mm:\n        speedup = (\n            multimodal_stats[\"mean\"] / fallback_stats[\"mean\"]\n            if fallback_stats[\"mean\"] > 0\n            else float(\"inf\")\n        )\n        print(f\"Fallback Speedup over Multimodal: {speedup:.8f}x\")\n    else:\n        speedup = float(\"nan\")\n        print(\n            \"[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.\"\n        )\n\n    return multimodal_stats, fallback_stats, speedup\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Benchmark GLM4V get_rope_index_glm4v.\"\n    )\n    
parser.add_argument(\"--model-name\", type=str, default=\"GLM4V\")\n    parser.add_argument(\"--tp-size\", type=int, default=1)\n    parser.add_argument(\n        \"--device\", type=str, default=\"cuda\" if torch.cuda.is_available() else \"cpu\"\n    )\n    parser.add_argument(\"--warmup-iter\", type=int, default=10)\n    parser.add_argument(\"--benchmark-iter\", type=int, default=100)\n    parser.add_argument(\"--dtype\", type=str, choices=[\"int64\"], default=\"int64\")\n    parser.add_argument(\"--seed\", type=int, default=0)\n\n    # token length sweep\n    parser.add_argument(\"--num-tokens\", type=int, nargs=\"+\", required=False)\n\n    # data shape knobs\n    parser.add_argument(\"--batch-size\", type=int, default=1)\n    parser.add_argument(\"--pad-ratio\", type=float, default=0.0)\n    parser.add_argument(\"--spatial-merge-size\", type=int, default=2)\n    parser.add_argument(\"--num-images\", type=int, default=1)\n    parser.add_argument(\"--image-patch-tokens\", type=int, default=256)\n    parser.add_argument(\"--num-videos\", type=int, default=1)\n    parser.add_argument(\"--video-patch-tokens\", type=int, default=256)\n\n    # output\n    parser.add_argument(\"--out-dir\", type=str, default=\".\")\n    args = parser.parse_args()\n    print(args)\n\n    device = torch.device(args.device)\n\n    if args.num_tokens is None:\n        num_tokens_list = [2**i for i in range(0, 18)]\n    else:\n        num_tokens_list = args.num_tokens\n\n    rows: list[dict[str, Any]] = []\n\n    for num_tokens in num_tokens_list:\n        multimodal_stats, fallback_stats, speedup = benchmark_rope_index(\n            model_name=args.model_name,\n            tp_size=args.tp_size,\n            num_tokens=num_tokens,\n            batch_size=args.batch_size,\n            pad_ratio=args.pad_ratio,\n            spatial_merge_size=args.spatial_merge_size,\n            num_images=args.num_images,\n            image_patch_tokens=args.image_patch_tokens,\n            
num_videos=args.num_videos,\n            video_patch_tokens=args.video_patch_tokens,\n            dtype=getattr(torch, args.dtype),\n            seed=args.seed,\n            warmup_iter=args.warmup_iter,\n            benchmark_iter=args.benchmark_iter,\n            device=device,\n        )\n"
  },
  {
    "path": "benchmark/benchmark_batch/benchmark_batch.py",
    "content": "import concurrent.futures\nimport os\nimport random\nimport time\nfrom concurrent.futures import ProcessPoolExecutor\nfrom statistics import mean\n\nimport requests\nfrom tqdm import tqdm\nfrom transformers import AutoTokenizer\n\nfrom sglang.lang.backend.runtime_endpoint import RuntimeEndpoint\n\n###############################################################################\n# CONFIG\n###############################################################################\nENDPOINT_URL = \"http://127.0.0.1:30000\"\nTOKENIZER_DIR = \"/models/meta-llama/Llama-3.2-3B\"\n\n# Benchmark configurations\nNUM_REQUESTS = 10  # Total number of requests (each with BATCH_SIZE prompts)\nNUM_TOKENS = 32000  # Tokens per prompt\nBATCH_SIZE = 8  # Number of prompts per request\nGEN_TOKENS = 0  # Tokens to generate per prompt\n\n\n###############################################################################\n# REQUEST GENERATION (in parallel)\n###############################################################################\ndef generate_random_prompt(index, tokenizer_dir, num_tokens):\n    \"\"\"Generate a single random prompt with specified token count.\"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)\n    vocab_size = tokenizer.vocab_size\n\n    def generate_random_text(num_toks):\n        random_token_ids = [random.randint(0, vocab_size - 1) for _ in range(num_toks)]\n        return tokenizer.decode(random_token_ids, clean_up_tokenization_spaces=True)\n\n    random_text = generate_random_text(num_tokens)\n    return f\"Prompt {index}: {random_text}\"\n\n\ndef prepare_all_prompts(num_requests, batch_size, num_tokens, tokenizer_dir):\n    \"\"\"Generate prompts for all requests in parallel.\"\"\"\n    total_prompts = num_requests * batch_size\n    all_prompts = [None] * total_prompts\n    max_workers = min(os.cpu_count() or 1, total_prompts)\n\n    with ProcessPoolExecutor(max_workers=max_workers) as executor:\n        futures = [\n            
executor.submit(generate_random_prompt, i, tokenizer_dir, num_tokens)\n            for i in range(total_prompts)\n        ]\n        for future in tqdm(\n            concurrent.futures.as_completed(futures),\n            total=total_prompts,\n            desc=\"Generating prompts\",\n        ):\n            index = futures.index(future)\n            all_prompts[index] = future.result()\n\n    batched_prompts = [\n        all_prompts[i * batch_size : (i + 1) * batch_size] for i in range(num_requests)\n    ]\n\n    print(\n        f\"Generated {total_prompts} prompts with {num_tokens} tokens each, grouped into {num_requests} requests of {batch_size} prompts.\\n\"\n    )\n    return batched_prompts\n\n\n###############################################################################\n# HTTP CALLS\n###############################################################################\ndef send_batch_request(endpoint, prompts, gen_tokens, request_id):\n    \"\"\"Send a batch of prompts to the /generate endpoint synchronously.\"\"\"\n    sampling_params = {\n        \"max_new_tokens\": gen_tokens,\n        \"temperature\": 0.7,\n        \"stop\": \"\\n\",\n    }\n    data = {\"text\": prompts, \"sampling_params\": sampling_params}\n\n    start_time = time.perf_counter()\n    try:\n        response = requests.post(\n            endpoint.base_url + \"/generate\", json=data, timeout=3600\n        )\n        if response.status_code != 200:\n            error = response.json()\n            raise RuntimeError(f\"Request {request_id} failed: {error}\")\n        result = response.json()\n        elapsed_time = (time.perf_counter() - start_time) * 1000  # Convert to ms\n        avg_per_prompt = elapsed_time / len(prompts) if prompts else 0\n        return request_id, elapsed_time, avg_per_prompt, True, len(prompts)\n    except Exception as e:\n        print(f\"[Request] Error for request {request_id}: {e}\")\n        return request_id, 0, 0, False, len(prompts)\n\n\ndef 
run_benchmark(endpoint, batched_prompts, batch_size, gen_tokens):\n    \"\"\"Run the benchmark sequentially.\"\"\"\n    results = []\n    num_requests = len(batched_prompts)\n\n    # Record start time for total latency\n    benchmark_start_time = time.perf_counter()\n\n    for i, batch_prompts in enumerate(batched_prompts):\n        request_id = i + 1\n        assert (\n            len(batch_prompts) == batch_size\n        ), f\"Request {request_id} should have {batch_size} prompts, got {len(batch_prompts)}\"\n\n        print(\n            f\"[Request] Sending request {request_id}/{num_requests} with {len(batch_prompts)} prompts at {int(time.time()*1000)}\"\n        )\n        result = send_batch_request(endpoint, batch_prompts, gen_tokens, request_id)\n        results.append(result)\n\n    # Calculate total latency\n    total_latency = (time.perf_counter() - benchmark_start_time) * 1000  # Convert to ms\n\n    return results, total_latency\n\n\n###############################################################################\n# RESULTS\n###############################################################################\ndef process_results(results, total_latency, num_requests):\n    \"\"\"Process and display benchmark results.\"\"\"\n    total_time = 0\n    successful_requests = 0\n    failed_requests = 0\n    request_latencies = []\n    per_prompt_latencies = []\n    total_prompts = 0\n\n    for request_id, elapsed_time, avg_per_prompt, success, batch_size in results:\n        if success:\n            successful_requests += 1\n            total_prompts += batch_size\n            request_latencies.append(elapsed_time)\n            per_prompt_latencies.append(avg_per_prompt)\n            total_time += elapsed_time / 1000  # Convert to seconds\n        else:\n            failed_requests += 1\n\n    avg_request_latency = mean(request_latencies) if request_latencies else 0\n    avg_per_prompt_latency = mean(per_prompt_latencies) if per_prompt_latencies else 0\n    
throughput = total_prompts / total_time if total_time > 0 else 0\n\n    print(\"\\nBenchmark Summary:\")\n    print(f\"  Total requests sent:         {len(results)}\")\n    print(f\"  Total prompts sent:          {total_prompts}\")\n    print(f\"  Successful requests:         {successful_requests}\")\n    print(f\"  Failed requests:             {failed_requests}\")\n    print(f\"  Total latency (all requests): {total_latency:.2f} ms\")\n    print(f\"  Avg per request latency:     {avg_request_latency:.2f} ms\")\n    print(f\"  Avg per prompt latency:      {avg_per_prompt_latency:.2f} ms\")\n    print(f\"  Throughput:                  {throughput:.2f} prompts/second\\n\")\n\n\n###############################################################################\n# MAIN\n###############################################################################\ndef main():\n    # Initialize endpoint\n    endpoint = RuntimeEndpoint(ENDPOINT_URL)\n\n    # Generate prompts\n    batched_prompts = prepare_all_prompts(\n        NUM_REQUESTS, BATCH_SIZE, NUM_TOKENS, TOKENIZER_DIR\n    )\n\n    # Flush cache before benchmark\n    # endpoint.flush_cache()\n\n    # Run benchmark\n    print(\n        f\"Starting benchmark: NUM_TOKENS={NUM_TOKENS}, BATCH_SIZE={BATCH_SIZE}, NUM_REQUESTS={NUM_REQUESTS}\\n\"\n    )\n    results, total_latency = run_benchmark(\n        endpoint, batched_prompts, BATCH_SIZE, GEN_TOKENS\n    )\n\n    # Process and display results\n    process_results(results, total_latency, NUM_REQUESTS)\n\n\nif __name__ == \"__main__\":\n    random.seed(0)\n    main()\n"
  },
  {
    "path": "benchmark/benchmark_batch/benchmark_tokenizer.py",
    "content": "import argparse\nimport random\nimport time\nfrom statistics import mean\n\nfrom transformers import AutoTokenizer\n\nfrom sglang.srt.utils.patch_tokenizer import patch_tokenizer\n\n\ndef main():\n    args = parse_args()\n\n    print(\"Tokenizer Benchmark: Sequential vs Batch Processing\")\n    print(\"-\" * 60)\n    print(f\"Tokenizer: {args.tokenizer}\")\n    print(f\"Functions: {', '.join(args.function)}\")\n    print(f\"Tokens per prompt: {args.num_tokens}\")\n    print(f\"Number of runs per batch size: {args.num_runs}\")\n    print(f\"Batch mode: {', '.join(args.batch_mode)}\")\n    print(\"-\" * 60)\n\n    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)\n    tokenizer = patch_tokenizer(tokenizer)\n    max_batch_size = max(args.batch_sizes)\n\n    token_ids = generate_random_token_ids(\n        num_prompts=max_batch_size, num_tokens=args.num_tokens, tokenizer=tokenizer\n    )\n\n    if \"encode\" in args.function:\n        prompts = [\n            tokenizer.decode(ids, clean_up_tokenization_spaces=True)\n            for ids in token_ids\n        ]\n        run_benchmark(\n            name=\"encode\",\n            data=prompts,\n            sequential_fn=lambda batch: [tokenizer.encode(p) for p in batch],\n            batch_fn=lambda batch: tokenizer(batch),\n            batch_sizes=args.batch_sizes,\n            num_runs=args.num_runs,\n            batch_mode=args.batch_mode,\n        )\n\n    if \"decode\" in args.function:\n        # mimic DetokenizerManager's usual case\n        decode_kwargs = dict(\n            skip_special_tokens=True,\n            spaces_between_special_tokens=True,\n        )\n        run_benchmark(\n            name=\"decode\",\n            data=token_ids,\n            sequential_fn=lambda batch: [\n                tokenizer.decode(ids, **decode_kwargs) for ids in batch\n            ],\n            batch_fn=lambda batch: tokenizer.batch_decode(batch, **decode_kwargs),\n            
batch_sizes=args.batch_sizes,\n            num_runs=args.num_runs,\n            batch_mode=args.batch_mode,\n        )\n\n\ndef run_benchmark(\n    *, name, data, sequential_fn, batch_fn, batch_sizes, num_runs, batch_mode\n):\n    print(\"\\n\" + \"=\" * 60)\n    print(f\"{name.upper()} BENCHMARK\")\n    print(\"=\" * 60)\n\n    results = [\n        benchmark(\n            data=data,\n            batch_size=bs,\n            sequential_fn=sequential_fn,\n            batch_fn=batch_fn,\n            num_runs=num_runs,\n            batch_mode=batch_mode,\n        )\n        for bs in batch_sizes\n    ]\n    print_results(results=results, func_name=name, batch_mode=batch_mode)\n\n\ndef benchmark(*, data, batch_size, sequential_fn, batch_fn, num_runs, batch_mode):\n    batch_data = data[:batch_size]\n    run_single = \"single\" in batch_mode\n    run_batch = \"batch\" in batch_mode\n\n    out = {\"batch_size\": batch_size}\n\n    if run_single:\n        sequential_times = measure_times(\n            fn=lambda: sequential_fn(batch_data), num_runs=num_runs\n        )\n        out |= {\n            \"avg_sequential_ms\": mean(sequential_times),\n            \"sequential_runs\": sequential_times,\n        }\n\n    if run_batch:\n        batch_times = measure_times(fn=lambda: batch_fn(batch_data), num_runs=num_runs)\n        out |= {\n            \"avg_batch_ms\": mean(batch_times),\n            \"batch_runs\": batch_times,\n        }\n\n    if run_single and run_batch:\n        out[\"speedup_factor\"] = (\n            out[\"avg_sequential_ms\"] / out[\"avg_batch_ms\"]\n            if out[\"avg_batch_ms\"] > 0\n            else 0\n        )\n\n    return out\n\n\ndef print_results(*, results, func_name, batch_mode):\n    run_single = \"single\" in batch_mode\n    run_batch = \"batch\" in batch_mode\n\n    for r in results:\n        print(f\"\\nBatch size: {r['batch_size']}\")\n        if run_single:\n            print_runs(\n                label=f\"Sequential 
{func_name}\",\n                runs=r[\"sequential_runs\"],\n                avg=r[\"avg_sequential_ms\"],\n            )\n        if run_batch:\n            print_runs(\n                label=f\"Batch {func_name}\", runs=r[\"batch_runs\"], avg=r[\"avg_batch_ms\"]\n            )\n        if run_single and run_batch:\n            print(f\"  Speedup factor: {r['speedup_factor']:.2f}x\")\n\n    print(\"\\n\" + \"=\" * 60)\n    print(f\"SUMMARY: {func_name.upper()}\")\n    print(\"=\" * 60)\n\n    headers = [\"Batch Size\"]\n    if run_single:\n        headers.append(\"Sequential (ms)\")\n    if run_batch:\n        headers.append(\"Batch (ms)\")\n    if run_single and run_batch:\n        headers.append(\"Speedup\")\n    print(\"\".join(f\"{h:<18}\" for h in headers))\n    print(\"-\" * (18 * len(headers)))\n\n    for r in results:\n        row = [f\"{r['batch_size']}\"]\n        if run_single:\n            row.append(f\"{r['avg_sequential_ms']:.2f} ms\")\n        if run_batch:\n            row.append(f\"{r['avg_batch_ms']:.2f} ms\")\n        if run_single and run_batch:\n            row.append(f\"{r['speedup_factor']:.2f}x\")\n        print(\"\".join(f\"{v:<18}\" for v in row))\n\n\ndef print_runs(*, label, runs, avg):\n    print(f\"  {label}:\")\n    for i, t in enumerate(runs):\n        print(f\"    Run {i+1}: {t:.2f} ms\")\n    print(f\"    Average: {avg:.2f} ms\")\n\n\ndef measure_times(*, fn, num_runs):\n    times = []\n    for _ in range(num_runs):\n        start = time.perf_counter()\n        fn()\n        times.append((time.perf_counter() - start) * 1000)\n    return times\n\n\ndef generate_random_token_ids(*, num_prompts, num_tokens, tokenizer):\n    vocab_size = tokenizer.vocab_size\n    print(f\"Generating {num_prompts} random sequences with {num_tokens} tokens each...\")\n    return [\n        [random.randint(0, vocab_size - 1) for _ in range(num_tokens)]\n        for _ in range(num_prompts)\n    ]\n\n\ndef parse_args():\n    parser = 
argparse.ArgumentParser(\n        description=\"Tokenizer Benchmark: Sequential vs Batch Processing\"\n    )\n    parser.add_argument(\n        \"--tokenizer\",\n        type=str,\n        required=True,\n        help=\"Tokenizer name or path (e.g. nvidia/Kimi-K2-Thinking-NVFP4)\",\n    )\n    parser.add_argument(\n        \"--function\",\n        type=str,\n        nargs=\"+\",\n        choices=[\"encode\", \"decode\"],\n        default=[\"encode\", \"decode\"],\n        help=\"Functions to benchmark (default: encode decode)\",\n    )\n    parser.add_argument(\n        \"--num-tokens\",\n        type=int,\n        default=20000,\n        help=\"Number of tokens per prompt (default: 20000)\",\n    )\n    parser.add_argument(\n        \"--batch-sizes\",\n        type=int,\n        nargs=\"+\",\n        default=[1, 2, 4, 8],\n        help=\"Batch sizes to test (default: 1 2 4 8)\",\n    )\n    parser.add_argument(\n        \"--batch-mode\",\n        nargs=\"+\",\n        choices=[\"single\", \"batch\"],\n        default=[\"single\", \"batch\"],\n        help=\"Benchmark modes to run (default: single batch)\",\n    )\n    parser.add_argument(\n        \"--num-runs\",\n        type=int,\n        default=5,\n        help=\"Number of runs per batch size (default: 5)\",\n    )\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    random.seed(0)\n    main()\n"
  },
  {
    "path": "benchmark/benchmark_vllm_060/README.md",
    "content": "## How to reproduce the benchmark results for SGLang v0.3.0 compared to vLLM v0.6.0\n\nIn short, with multi-step scheduling enabled, in the online scenarios we benchmarked, vLLM's Median TTFT is **3 times** that of SGLang, and its Median ITL is **10 times** that of SGLang (lower is better for both). vLLM's multi-step optimization did not improve throughput while also keeping Median TTFT and ITL low. Also, in the maximum-throughput benchmark, if vLLM uses its default configuration instead of separately setting `gpu_memory_utilization` to 0.95, its maximum throughput is **lower** than that of SGLang.\n\n## Online benchmark results\n\n### Llama 3.1 8B Instruct 1 x A100 80G\n\n| RPS  | Num Prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL |\n|------|-------------|--------|--------------------|-------------|-------------|------------|\n| 4    | 1200        | SGLang | 1564.17            | **31.98**   | 13.17       | **11.93**  |\n| 4    | 1200        | vLLM   | 1691.97            | **100.48**  | 14.14       | **129.32** |\n| 8    | 2400        | SGLang | 2175.02            | **35.68**   | 17.85       | **14.41**  |\n| 8    | 2400        | vLLM   | 2137.16            | **120.39**  | 17.09       | **158.63** |\n\n### Llama 3.1 70B Instruct 4 x H100 80G\n\n| RPS  | Num Prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL |\n|------|-------------|--------|--------------------|-------------|-------------|------------|\n| 4    | 1200        | SGLang | 3005.24            | **53.94**   | 25.03       | **21.67**  |\n| 4    | 1200        | vLLM   | 2915.60            | **179.15**  | 23.58       | **231.23** |\n| 8    | 2400        | SGLang | 4064.98            | **58.11**   | 33.07       | **24.45**  |\n| 8    | 2400        | vLLM   | 3752.38            | **207.12**  | 29.15       | **275.32** |\n\n## Offline benchmark results\n\n### Llama 3.1 8B Instruct 1 x A100 80G\n\n| RPS  | Num Prompts | Engine | Request 
throughput | Output token throughput |\n|------|-------------|--------|--------------------|-------------------------|\n| inf  | 5000        | SGLang | 22.03              | **4281.51**             |\n| inf  | 5000        | vLLM   | 21.27              | **4132.37**             |\n\n### Llama 3.1 70B Instruct 4 x H100 80G\n\n| RPS  | Num Prompts | Engine | Request throughput | Output token throughput |\n|------|-------------|--------|--------------------|-------------------------|\n| inf  | 5000        | SGLang | 19.84              | **3856.01**             |\n| inf  | 5000        | vLLM   | 19.04              | **3700.64**             |\n\n## Installation\n\n```bash\n# install sglang v0.3.0\npip install --upgrade pip\npip install \"sglang[all]\"==0.3.0\npip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/\n\n# install vllm v0.6.0\npip install vllm==0.6.0\n```\n\n## Notes\n\nWe followed the reproduction method in https://github.com/vllm-project/vllm/issues/8176, and added the `--num-scheduler-steps 10` parameter when starting the vLLM server. 
By default, vLLM's `gpu_memory_utilization` is 0.9 at both TP 1 and TP 4, while SGLang's `mem_frac` is 0.88 at TP 1 and 0.85 at TP 4, so we manually set SGLang's `mem_frac` to 0.88 at TP 4.\n\n## Online benchmarks\n\n```bash\n# Llama 3.1 8B Instruct on 1 x A100\npython -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache\npython -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096\n\n# Llama 3.1 70B Instruct on 4 x H100\npython -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 4\npython -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096\n\n# bench serving\npython3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 1200 --request-rate 4\npython3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 2400 --request-rate 8\npython3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 1200 --request-rate 4\npython3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 2400 --request-rate 8\n```\n\n## Offline benchmarks\n\n```bash\n# Llama 3.1 8B Instruct on 1 x A100\npython -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache\npython -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096\n\n# Llama 3.1 70B Instruct on 4 x H100\npython -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 4 --mem-frac 0.88\npython -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096\n\n# bench 
serving\npython3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 5000\npython3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 5000\n```\n"
  },
  {
    "path": "benchmark/blog_v0_2/405b_sglang.sh",
    "content": "# Create dummy weights:\n# 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and create `config.json` and tokenizer under this folder.\n# 2. Get `config.json` from ./config.md\n# 3. Download the tokenizer\n#   wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json\n#   wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json\n\n# Launch sglang\n# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quantization fp8 --disable-radix --mem-frac 0.87\n\n# offline\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > sglang_log12\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > sglang_log13\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > sglang_log14\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > sglang_log15\npython3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompt 2000 > sglang_log21\n\n# online\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > sglang_log31\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > sglang_log32\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 > sglang_log33\npython3 -m 
sglang.bench_serving --backend sglang --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > sglang_log34\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > sglang_log35\n"
  },
  {
    "path": "benchmark/blog_v0_2/405b_trt.sh",
    "content": "# Launch trtllm\n# https://github.com/sgl-project/tensorrt-demo\n\n# offline\npython3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log11\npython3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log12\npython3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log13\npython3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log14\npython3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log15\npython3 ../../python/sglang/bench_serving.py --backend trt --dataset-name sharegpt --num-prompt 2000 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log21\n\n# online\npython3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log31\npython3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log32\npython3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log33\npython3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 
1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log34\npython3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log35\n"
  },
  {
    "path": "benchmark/blog_v0_2/405b_vllm.sh",
    "content": "# Create dummy weights:\n# 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and create `config.json` and tokenizer under this folder.\n# 2. Get `config.json` from ./config.md\n# 3. Download the tokenizer\n#   wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json\n#   wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json\n\n# Launch vllm\n# python3 -m vllm.entrypoints.openai.api_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --disable-log-requests --tensor-parallel-size 8 --max-model-len 10000\n\n# offline\npython3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > vllm_log11\npython3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > vllm_log12\npython3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > vllm_log13\npython3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > vllm_log14\npython3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > vllm_log15\npython3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name sharegpt --num-prompt 2000 > vllm_log21\n\n# online\npython3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > vllm_log31\npython3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > vllm_log32\npython3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 
1200 --request-rate 4 --random-input 1024 --random-output 1024 > vllm_log33\npython3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > vllm_log34\npython3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > vllm_log35\n"
  },
  {
    "path": "benchmark/blog_v0_2/README.md",
    "content": "# How to reproduce the benchmark results of SGLang\n\n## Prerequisite\n\n### Install SGLang\n\n```bash\ngit clone https://github.com/sgl-project/sglang.git\ncd sglang\ngit checkout v0.2.7\n\npip install --upgrade pip\npip install -e \"python[all]\"\n\npip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/\n```\n\n### Set up ulimit and HF_TOKEN\n\n```bash\nulimit -n 65535\n# Replace hf_token with a valid token that has access to the Llama 3 models.\nexport HF_TOKEN=hf_token\n```\n\n### Launch the server\n\n```bash\n# Meta-Llama-3.1-8B-Instruct\npython -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache\n\n# Meta-Llama-3.1-70B-Instruct\npython -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 8\n\n# Meta-Llama-3-70B-Instruct-FP8\npython -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-radix-cache --tp 8\n```\n\n## Benchmark\n\n### Hardware Requirements\n\n- 8B models: Single NVIDIA A100 80GB GPU\n- 70B models: 8 x NVIDIA A100 80GB GPUs with Tensor Parallelism (TP) 8\n- 70B FP8 models: 8 x NVIDIA H100 GPUs with Tensor Parallelism (TP) 8\n\nPlease ensure you have the appropriate hardware before running the benchmarks.\n\n#### Offline benchmark\n\n```bash\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline.jsonl\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline.jsonl\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline.jsonl\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 
--output-file offline.jsonl\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline.jsonl\npython3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 3000 --output-file offline.jsonl\ncat offline.jsonl | cut -d':' -f12 | cut -d',' -f1\n```\n\n#### Online benchmark\n\n```bash\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online.jsonl\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online.jsonl\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online.jsonl\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online.jsonl\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online.jsonl\ncat online.jsonl | cut -d':' -f9 | cut -d',' -f1\n```\n\n## Other\n\nWe tried vLLM 0.5.3.post1, but it often crashed under high load, and our partial benchmarking showed similar or worse performance than vLLM 0.5.2, so we use the older version, vLLM 0.5.2.\n\nTo prepare TensorRT LLM, refer to https://github.com/sgl-project/tensorrt-demo. Specifically, we used a batch size of 512, a max input length of 8192, and a max number of tokens of 8192. 
The instance count for preprocessing and postprocessing in Triton Server is 16.\n\n```bash\n# vLLM\npip install vllm==0.5.2\npip install jsonschema==4.21.1\n\n# Meta-Llama-3-8B-Instruct\npython -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B-Instruct --disable-log-requests\n\n# meta-llama/Meta-Llama-3-70B-Instruct\npython -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B-Instruct --disable-log-requests --tensor-parallel-size 8\n\n# neuralmagic/Meta-Llama-3-70B-Instruct-FP8\npython -m vllm.entrypoints.openai.api_server --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-log-requests --tensor-parallel-size 8\n```\n\n```bash\nwget https://raw.githubusercontent.com/sgl-project/sglang/main/python/sglang/bench_serving.py\n```\n\n```bash\n# vLLM Offline\n\npython3 bench_serving.py --backend vllm --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_vllm.jsonl\npython3 bench_serving.py --backend vllm --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_vllm.jsonl\npython3 bench_serving.py --backend vllm --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_vllm.jsonl\npython3 bench_serving.py --backend vllm --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_vllm.jsonl\npython3 bench_serving.py --backend vllm --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_vllm.jsonl\npython3 bench_serving.py --backend vllm --dataset-name sharegpt --num-prompts 3000 --output-file offline_vllm.jsonl\ncat offline_vllm.jsonl | cut -d':' -f12 | cut -d',' -f1\n```\n\n```bash\n# vLLM Online\n\npython3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_vllm.jsonl\npython3 bench_serving.py --backend vllm --dataset-name 
random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_vllm.jsonl\npython3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_vllm.jsonl\npython3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_vllm.jsonl\npython3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_vllm.jsonl\ncat online_vllm.jsonl | cut -d':' -f9 | cut -d',' -f1\n```\n\n```bash\n# TensorRT LLM Offline 8B\n\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_8b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_8b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_8b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_trt_8b.jsonl\npython3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_trt_8b.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_8b.jsonl\ncat offline_trt_8b.jsonl | cut -d':' -f12 | cut -d',' -f1\n```\n\n```bash\n# TensorRT LLM Online 
8B\n\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_8b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_8b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_8b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_8b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_8b.jsonl\ncat online_trt_8b.jsonl | cut -d':' -f9 | cut -d',' -f1\n```\n\n```bash\n# TensorRT LLM Offline 70B\n\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_70b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_70b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_70b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file 
offline_trt_70b.jsonl\npython3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_trt_70b.jsonl --model meta-llama/Meta-Llama-3-70B-Instruct\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_70b.jsonl\ncat offline_trt_70b.jsonl | cut -d':' -f12 | cut -d',' -f1\n```\n\n```bash\n# TensorRT LLM Online 70B\n\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_70b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_70b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_70b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_70b.jsonl\npython3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_70b.jsonl\ncat online_trt_70b.jsonl | cut -d':' -f9 | cut -d',' -f1\n```\n"
  },
  {
    "path": "benchmark/blog_v0_2/config.md",
    "content": "### used for TensorRT LLM\n\n```\n{\n    \"architecture\": \"LlamaForCausalLM\",\n    \"dtype\": \"float16\",\n    \"logits_dtype\": \"float32\",\n    \"vocab_size\": 128256,\n    \"max_position_embeddings\": 8192,\n    \"hidden_size\": 16384,\n    \"num_hidden_layers\": 126,\n    \"num_attention_heads\": 128,\n    \"num_key_value_heads\": 16,\n    \"head_size\": 128,\n    \"qk_layernorm\": false,\n    \"hidden_act\": \"silu\",\n    \"intermediate_size\": 53248,\n    \"norm_epsilon\": 1e-05,\n    \"position_embedding_type\": \"rope_gpt_neox\",\n    \"use_parallel_embedding\": false,\n    \"embedding_sharding_dim\": 0,\n    \"share_embedding_table\": false,\n    \"mapping\": {\n        \"world_size\": 8,\n        \"tp_size\": 8,\n        \"pp_size\": 1,\n        \"gpus_per_node\": 8\n    },\n    \"quantization\": {\n        \"quant_algo\": \"FP8\",\n        \"kv_cache_quant_algo\": null,\n        \"group_size\": 128,\n        \"smoothquant_val\": null,\n        \"has_zero_point\": false,\n        \"pre_quant_scale\": false,\n        \"exclude_modules\": [\n            \"lm_head\"\n        ]\n    },\n    \"kv_dtype\": \"float16\",\n    \"rotary_scaling\": null,\n    \"residual_mlp\": false,\n    \"moe_normalization_mode\": null,\n    \"rotary_base\": 500000.0,\n    \"moe_num_experts\": 0,\n    \"moe_top_k\": 0,\n    \"moe_tp_mode\": 2,\n    \"attn_bias\": false,\n    \"disable_weight_only_quant_plugin\": false,\n    \"mlp_bias\": false\n}\n```\n\n### used for vLLM and SGLang\n\n```\n{\n  \"_name_or_path\": \"dummy_fp8\",\n  \"architectures\": [\n    \"LlamaForCausalLM\"\n  ],\n  \"attention_bias\": false,\n  \"attention_dropout\": 0.0,\n  \"bos_token_id\": 128000,\n  \"eos_token_id\": 128009,\n  \"hidden_act\": \"silu\",\n  \"hidden_size\": 16384,\n  \"initializer_range\": 0.02,\n  \"intermediate_size\": 53248,\n  \"mlp_bias\": false,\n  \"model_type\": \"llama\",\n  \"num_attention_heads\": 128,\n  \"num_hidden_layers\": 126,\n  
\"num_key_value_heads\": 8,\n  \"pretraining_tp\": 1,\n  \"quantization_config\": {\n    \"activation_scheme\": \"static\",\n    \"ignored_layers\": [\n      \"lm_head\"\n    ],\n    \"quant_method\": \"fp8\"\n  },\n  \"rope_scaling\": {\n    \"factor\": 8.0,\n    \"low_freq_factor\": 1.0,\n    \"high_freq_factor\": 4.0,\n    \"original_max_position_embeddings\": 8192,\n    \"rope_type\": \"llama3\"\n  },\n  \"max_position_embeddings\": 131072,\n  \"rms_norm_eps\": 1e-05,\n  \"rope_scaling\": null,\n  \"rope_theta\": 500000.0,\n  \"tie_word_embeddings\": false,\n  \"torch_dtype\": \"bfloat16\",\n  \"transformers_version\": \"4.41.1\",\n  \"use_cache\": true,\n  \"vocab_size\": 128256\n}\n```\n"
  },
  {
    "path": "benchmark/boolq/README.md",
    "content": "## Download data\n```\ngit clone https://hf-mirror.com/datasets/google/boolq\n```\n\n## Convert parquet to json\n```\nbash parquet_to_json.sh\n```\n## Run benchmark\n\n### Benchmark sglang\n```\npython -m sglang.launch_server --model-path ramblingpolymath/Qwen3-32B-W8A8 --port 30000\n```\n\n```\npython3 bench_sglang.py\n```\n"
  },
  {
    "path": "benchmark/boolq/bench_sglang.py",
    "content": "import argparse\nimport json\nimport time\n\nimport numpy as np\n\nfrom sglang.api import set_default_backend\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import read_jsonl\n\n\ndef get_example(lines, i, answer):\n    # Separate the question from the passage so the two are not fused together\n    prompt = (\n        \"Question: \"\n        + lines[i][\"question\"]\n        + \"\\nPassage: \"\n        + lines[i][\"passage\"]\n        + \"\\nAnswer:\"\n    )\n    if answer:\n        prompt += str(lines[i][\"answer\"])\n    return prompt\n\n\ndef few_shot_examples(lines, k):\n    prompts = \"\"\n    for i in range(k):\n        prompts += get_example(lines, i, True) + \"\\n\\n\"\n    return prompts\n\n\ndef main(args):\n    # Select backend\n    set_default_backend(select_sglang_backend(args))\n\n    # Read data\n    train_data_path = args.train_data_path\n    test_data_path = args.test_data_path\n    lines_train = list(read_jsonl(train_data_path))\n    lines_test = list(read_jsonl(test_data_path))\n\n    # Construct prompts\n    num_questions = args.num_questions\n    num_shots = args.num_shots\n    few_shots = few_shot_examples(lines_train, num_shots)\n\n    questions = []\n    answer = []\n    for i in range(len(lines_test[:num_questions])):\n        questions.append(get_example(lines_test, i, False))\n        answer.append(str(lines_test[i][\"answer\"]))\n    arguments = [{\"question\": q} for q in questions]\n\n    #####################################\n    ######### SGL Program Begin #########\n    #####################################\n\n    import sglang as sgl\n\n    @sgl.function\n    def few_shot_boolq(s, question):\n        s += few_shots + question\n        s += sgl.gen(\"answer\", max_tokens=5, stop=[\"\\n\"])\n\n    #####################################\n    ########## SGL Program End ##########\n    #####################################\n\n    # Run requests\n    tic = time.perf_counter()\n    states = few_shot_boolq.run_batch(\n        arguments,\n        temperature=0,\n        
num_threads=args.parallel,\n        progress_bar=True,\n    )\n    latency = time.perf_counter() - tic\n\n    preds = []\n    for i in range(len(states)):\n        preds.append(states[i][\"answer\"])\n\n    # Compute accuracy\n    acc = np.mean(np.array(preds) == np.array(answer))\n\n    # Compute speed\n    num_output_tokens = sum(\n        s.get_meta_info(\"answer\")[\"completion_tokens\"] for s in states\n    )\n    output_throughput = num_output_tokens / latency\n\n    # Print results\n    print(f\"Accuracy: {acc:.3f}\")\n    print(f\"Latency: {latency:.3f} s\")\n    print(f\"Output throughput: {output_throughput:.3f} token/s\")\n\n    # Write results\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"boolq\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"accuracy\": round(acc, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--num-shots\", type=int, default=5)\n    parser.add_argument(\n        \"--train-data-path\", type=str, default=\"./boolq/data/train-00000-of-00001.json\"\n    )\n    parser.add_argument(\n        \"--test-data-path\",\n        type=str,\n        default=\"./boolq/data/validation-00000-of-00001.json\",\n    )\n    parser.add_argument(\"--num-questions\", type=int, default=200)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/boolq/convert_parquet_to_json.py",
    "content": "import sys\n\nimport pyarrow.parquet as pq\n\n\ndef convert_parquet_to_json(input_file, output_file):\n    # Read the Parquet file into an Arrow table\n    table = pq.read_table(input_file)\n\n    # Convert the Arrow table to a pandas DataFrame\n    df = table.to_pandas()\n\n    # Serialize the DataFrame as JSON Lines (one record per line)\n    json_data = df.to_json(orient=\"records\", lines=True)\n\n    # Write the JSON Lines data to the output file\n    with open(output_file, \"w\") as f:\n        f.write(json_data)\n\n\nif __name__ == \"__main__\":\n    if len(sys.argv) != 3:\n        print(\n            \"Usage: python convert_parquet_to_json.py <input_file> <output_file>\",\n            file=sys.stderr,\n        )\n        sys.exit(1)\n\n    input_file = sys.argv[1]\n    output_file = sys.argv[2]\n\n    convert_parquet_to_json(input_file, output_file)\n"
  },
  {
    "path": "benchmark/boolq/parquet_to_json.sh",
    "content": "#!/bin/bash\n\n# Input and output directories\ninput_dir=\"./boolq/data\"\noutput_dir=\"./boolq/data\"\n\n# Parquet files to convert\nfiles=(\n        \"train-00000-of-00001.parquet\"\n        \"validation-00000-of-00001.parquet\"\n)\n\n# Convert each file listed above to JSON Lines with the Python script\nfor file in \"${files[@]}\"; do\n    input_file=\"${input_dir}/${file}\"\n    output_file=\"${output_dir}/${file%.parquet}.json\"\n\n    echo \"Converting ${input_file} to ${output_file} ...\"\n    python3 convert_parquet_to_json.py \"${input_file}\" \"${output_file}\"\n\n    if [ $? -eq 0 ]; then\n        echo \"Conversion successful: ${output_file}\"\n    else\n        echo \"Conversion failed: ${input_file}\"\n    fi\ndone\n"
  },
  {
    "path": "benchmark/ceval/README.md",
    "content": "## Download data\n```\ngit lfs clone https://huggingface.co/datasets/ceval/ceval-exam\n```\n\n## Run benchmark\n\n### Benchmark sglang\n```\npython -m sglang.launch_server --model-path ramblingpolymath/Qwen3-32B-W8A8 --port 30000\n```\n\n```\npython3 bench_sglang.py\n```\n"
  },
  {
    "path": "benchmark/ceval/bench_sglang.py",
    "content": "import argparse\nimport json\nimport os\nimport random\nimport re\nimport time\n\nimport numpy as np\nfrom datasets import load_dataset\n\nfrom sglang.lang.api import set_default_backend\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\n\nchoices = [\"A\", \"B\", \"C\", \"D\"]\n\n\ndef get_one_example(line, include_answer):\n    res = line[\"question\"]\n    res += f\"\\nA. {line['A']}\"\n    res += f\"\\nB. {line['B']}\"\n    res += f\"\\nC. {line['C']}\"\n    res += f\"\\nD. {line['D']}\"\n\n    if include_answer:\n        res += f\"\\nAnswer: {line['answer']} \\n\\n\"\n    return res\n\n\ndef get_few_shot_examples(lines):\n    res = \"\"\n    for line in lines:\n        res += get_one_example(line, True) + \"\\n\\n\"\n    return res\n\n\ndef get_answer_value(response):\n    pattern = r\"(Answer:|answer:|答案是|答案是:|正确答案是:|答案:|Assistant:)\\s*([A-D])(?![\\w])\"\n    match = re.search(pattern, response)\n\n    if match:\n        return match.group(2)\n\n    return random.choice(choices)\n\n\ndef main(args):\n    # Read data && Construct prompts\n    arguments = []\n    labels = []\n    examples = \"examples:\\n\"\n    data_path = args.data_path\n    for subject in os.listdir(data_path):\n        subject_path = os.path.join(data_path, subject)\n        if os.path.isdir(subject_path) and subject != \".git\":\n            dataset = load_dataset(data_path, name=subject)\n            dev_lines_temp = dataset[\"dev\"]\n            val_lines_temp = dataset[\"val\"]\n            few_shot_examples = get_few_shot_examples(dev_lines_temp)\n            examples += f\"{few_shot_examples}\"\n            for val_line in val_lines_temp:\n                arguments.append(\n                    {\n                        \"examples\": few_shot_examples,\n                        \"question\": get_one_example(val_line, False),\n                    }\n                )\n                
labels.append(val_line[\"answer\"])\n\n    #####################################\n    ######### SGL Program Begin #########\n    #####################################\n\n    import sglang as sgl\n\n    @sgl.function\n    def few_shot_ceval(s, examples, question):\n        s += examples + question + sgl.gen(\"Answer\")\n\n    #####################################\n    ########## SGL Program End ##########\n    #####################################\n\n    num_questions = args.num_questions if args.num_questions else len(arguments)\n\n    # Select backend\n    set_default_backend(select_sglang_backend(args))\n\n    # Run requests\n    tic = time.perf_counter()\n    states = few_shot_ceval.run_batch(\n        arguments[:num_questions],\n        temperature=0,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    latency = time.perf_counter() - tic\n\n    preds = [get_answer_value(states[i][\"Answer\"]) for i in range(num_questions)]\n\n    # Compute accuracy\n    acc = np.mean(np.array(preds) == np.array(labels[:num_questions]))\n\n    # Compute speed\n    num_output_tokens = sum(\n        s.get_meta_info(\"Answer\")[\"completion_tokens\"] for s in states\n    )\n    output_throughput = num_output_tokens / latency\n\n    # Print results\n    print(f\"Accuracy: {acc:.3f}\")\n    print(f\"Latency: {latency:.3f} s\")\n    print(f\"Output throughput: {output_throughput:.3f} token/s\")\n\n    # Write results\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"ceval\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"accuracy\": round(acc, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    
parser.add_argument(\"--data-path\", type=str, default=\"ceval/ceval-exam\")\n    parser.add_argument(\"--num-questions\", type=int, default=None)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/deepseek_v3/README.md",
    "content": "# DeepSeek V3.1/V3/R1 Support\n\nThe SGLang and DeepSeek teams collaborated to get DeepSeek V3 FP8 running on NVIDIA and AMD GPUs **from day one**. SGLang also supports [MLA optimization](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [DP attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models), making SGLang one of the best open-source LLM engines for running DeepSeek models. SGLang is the inference engine recommended by the official [DeepSeek team](https://github.com/deepseek-ai/DeepSeek-V3/tree/main?tab=readme-ov-file#62-inference-with-sglang-recommended).\n\nSpecial thanks to Meituan's Search & Recommend Platform Team and Baseten's Model Performance Team for implementing the model, and DataCrunch for providing GPU resources.\n\nFor optimizations made on the DeepSeek series models regarding SGLang, please refer to [DeepSeek V3/V3.1/R1 Model Optimizations in SGLang](https://docs.sglang.io/basic_usage/deepseek_v3.html#optimizations).\n\n## Installation & Launch\n\nIf you encounter errors when starting the server, ensure the weights have finished downloading. It's recommended to download them beforehand or restart multiple times until all weights are downloaded.\n\n### Using Docker (Recommended)\n\n```bash\n# Pull latest image\n# https://hub.docker.com/r/lmsysorg/sglang/tags\ndocker pull lmsysorg/sglang:latest\n\n# Launch\ndocker run --gpus all --shm-size 32g -p 30000:30000 -v ~/.cache/huggingface:/root/.cache/huggingface --ipc=host --network=host --privileged lmsysorg/sglang:latest \\\n    python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code --port 30000\n```\n\nIf you are using RDMA, please note that:\n\n1. `--network host` and `--privileged` are required by RDMA. If you don't need RDMA, you can remove them.\n2. 
You may need to set `NCCL_IB_GID_INDEX` if you are using RoCE, for example: `export NCCL_IB_GID_INDEX=3`.\n\nAdd [performance optimization options](#performance-optimization-options) as needed.\n\n### Using pip\n\n```bash\n# Installation\npip install sglang\n\n# Launch\npython3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code\n```\n\nAdd [performance optimization options](#performance-optimization-options) as needed.\n\n<a id=\"option_args\"></a>\n\n### Performance Optimization Options\n\n[MLA optimizations](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) are enabled by default. The following optional optimizations can be enabled as needed.\n\n- [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models): For high-QPS scenarios, add the `--enable-dp-attention` argument to boost throughput.\n- [Torch.compile Optimization](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#torchcompile-latency-optimizations): Add the `--enable-torch-compile` argument to enable it. Compilation makes server startup take longer. The maximum batch size for torch.compile optimization can be controlled with `--torch-compile-max-bs`. It's recommended to set it between `1` and `8`. 
(e.g., `--torch-compile-max-bs 8`)\n\n### Usage: Chat with DeepSeek\n\n#### DeepSeek V3/R1\n\n```python3\nimport openai\nclient = openai.Client(\n    base_url=\"http://127.0.0.1:30000/v1\", api_key=\"EMPTY\")\n\n# Chat completion\nresponse = client.chat.completions.create(\n    model=\"default\",\n    messages=[\n        {\"role\": \"system\", \"content\": \"You are a helpful AI assistant\"},\n        {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n    ],\n    temperature=0,\n    max_tokens=64,\n)\nprint(response)\n```\n\n#### DeepSeek V3.1\nOn top of the basic usage similar to the DeepSeek V3/R1 example, DeepSeek V3.1 supports a request-level thinking/non-thinking toggle. Simply switch the `\"thinking\"` field in `extra_body={\"chat_template_kwargs\": {\"thinking\": True}}` to enable/disable the thinking mode.\n\n##### Non Thinking\n```python3\nimport openai\nclient = openai.Client(\n    base_url=\"http://127.0.0.1:30000/v1\", api_key=\"EMPTY\")\n\n# Chat completion\nresponse = client.chat.completions.create(\n    model=\"default\",\n    messages=[\n        {\"role\": \"system\", \"content\": \"You are a helpful AI assistant\"},\n        {\"role\": \"user\", \"content\": \"Answer the following with the second letter of the correct answer only: What is the capital of France?\"},\n    ],\n    temperature=0,\n    max_tokens=1024,\n    extra_body = {\"chat_template_kwargs\": {\"thinking\": False}}\n)\nprint(response.choices[0].message.content)\n```\nAnswer:\n```\nh\n```\n* The correct response should be 'A', as the correct answer to the question is 'Paris'.\n##### Thinking\n```python3\nimport openai\nclient = openai.Client(\n    base_url=\"http://127.0.0.1:30000/v1\", api_key=\"EMPTY\")\n\n# Chat completion\nresponse = client.chat.completions.create(\n    model=\"default\",\n    messages=[\n        {\"role\": \"system\", \"content\": \"You are a helpful AI assistant\"},\n        {\"role\": \"user\", \"content\": \"Answer the following 
with the second letter of the correct answer only: What is the capital of France?\"},\n    ],\n    temperature=0,\n    max_tokens=1024,\n    extra_body = {\"chat_template_kwargs\": {\"thinking\": True}}\n)\nprint(response)\n```\nAnswer:\n```\nFirst, the question is: \"What is the capital of France?\" I know that the capital of France is Paris.\n\nThe user says: \"Answer the following with the second letter of the correct answer only.\" So, I need to provide only the second letter of the correct answer.\n\nThe correct answer is \"Paris\". Now, I need to find the second letter of \"Paris\".\n\nLet's spell it out: P-A-R-I-S.\n\n- First letter: P\n\n- Second letter: A\n\n- Third letter: R\n\n- Fourth letter: I\n\n- Fifth letter: S\n\nSo, the second letter is \"A\".\n\nI should only output the second letter, which is \"A\". No additional text or explanation, just the letter.\n\nThe user emphasized \"the second letter of the correct answer only\", so my response should be just \"A\".\n\nFinally, I need to make sure that this is the correct answer. Yes, Paris is indeed the capital of France.</think>A\n```\n* The response contains a thinking trace that ends with `</think>`, and the model derived the correct answer from it.\n\n### Example: Serving with two H20\\*8 nodes\n\nFor example, suppose there are two H20 nodes, each with 8 GPUs. The first node's IP is `10.0.0.1`, and the second node's IP is `10.0.0.2`. Please **use the first node's IP** for both commands.\n\nIf the command fails, try setting the `GLOO_SOCKET_IFNAME` environment variable. For more information, see [Common Environment Variables](https://pytorch.org/docs/stable/distributed.html#common-environment-variables).\n\nIf the nodes use NVIDIA InfiniBand and startup hangs, consider setting `export NCCL_IB_GID_INDEX=3`. 
For more information, see [this issue comment](https://github.com/sgl-project/sglang/issues/3516#issuecomment-2668493307).\n\n```bash\n# node 1\npython3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code\n\n# node 2\npython3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code\n```\n\nWith two H100 nodes, the usage is similar to the H20 example above.\n\n> **Note that the launch command here does not enable Data Parallelism Attention or `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).\n\n### Example: Serving with one B200 node\n\nThis example uses one B200 node with either 4 GPUs (FP4 only) or 8 GPUs (FP4 or FP8). Both FP4 and FP8 models are supported for DeepSeek R1, but the flags for optimal performance differ slightly between them.\n\n#### FP4\n\nIf using 4 GPUs:\n\n```bash\npython3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-FP4-V2 --host 0.0.0.0 --port 8000 --tensor-parallel-size=4 --cuda-graph-max-bs 256 --max-running-requests 256 --mem-fraction-static 0.85 --ep-size 4 --scheduler-recv-interval 30 --enable-symm-mem --stream-interval 10\n```\n\nIf using 8 GPUs:\n\n```bash\npython3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-FP4-V2 --host 0.0.0.0 --port 8000 --tensor-parallel-size=8 --cuda-graph-max-bs 256 --max-running-requests 256 --mem-fraction-static 0.85 --ep-size 8 --scheduler-recv-interval 30 --enable-symm-mem --stream-interval 10\n```\n\n#### FP8\n\n```bash\nSGLANG_ENABLE_JIT_DEEPGEMM=false python3 -m sglang.launch_server --model-path=deepseek-ai/DeepSeek-R1-0528 --host=0.0.0.0 --port=8000 --tensor-parallel-size=8 --cuda-graph-max-bs 128 --max-running-requests 128 --mem-fraction-static 0.82 --kv-cache-dtype fp8_e4m3 --chunked-prefill-size 32768 
--max-prefill-tokens 32768 --scheduler-recv-interval 30 --stream-interval 30 --fp8-gemm-backend flashinfer_trtllm\n```\n\n### Example: Serving with two H200\\*8 nodes and Docker\n\nThere are two H200 nodes, each with 8 GPUs. The first node's IP is `192.168.114.10`, and the second node's IP is `192.168.114.11`. The server is exposed to other Docker containers with `--host 0.0.0.0` and `--port 40000`, and inter-node communication is set up with `--dist-init-addr 192.168.114.10:20000`.\nA single H200 node with 8 GPUs can run DeepSeek V3; the dual-node setup here only demonstrates multi-node usage.\n\n```bash\n# node 1\ndocker run --gpus all \\\n    --shm-size 32g \\\n    --network=host \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --name sglang_multinode1 \\\n    -it \\\n    --rm \\\n    --env \"HF_TOKEN=$HF_TOKEN\" \\\n    --ipc=host \\\n    lmsysorg/sglang:latest \\\n    python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 0 --trust-remote-code --host 0.0.0.0 --port 40000\n```\n\n```bash\n# node 2\ndocker run --gpus all \\\n    --shm-size 32g \\\n    --network=host \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --name sglang_multinode2 \\\n    -it \\\n    --rm \\\n    --env \"HF_TOKEN=$HF_TOKEN\" \\\n    --ipc=host \\\n    lmsysorg/sglang:latest \\\n    python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 1 --trust-remote-code --host 0.0.0.0 --port 40000\n```\n\nTo verify the deployment, run a benchmark from a client Docker container:\n\n```bash\ndocker run --gpus all \\\n    --shm-size 32g \\\n    --network=host \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --name sglang_multinode_client \\\n    -it \\\n    --rm \\\n    --env \"HF_TOKEN=$HF_TOKEN\" \\\n    --ipc=host \\\n    lmsysorg/sglang:latest \\\n    python3 -m sglang.bench_serving 
--backend sglang --dataset-name random --random-input 1 --random-output 512 --random-range-ratio 1 --num-prompts 1 --host 0.0.0.0 --port 40000 --output-file \"deepseekv3_multinode.jsonl\"\n```\n\n> **Note that the launch command here does not enable Data Parallelism Attention or `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).\n\n### Example: Serving with four A100\\*8 nodes\n\nTo serve DeepSeek-V3 with A100 GPUs, we first need to convert the [FP8 model checkpoints](https://huggingface.co/deepseek-ai/DeepSeek-V3) to BF16 using the [`fp8_cast_bf16.py` script](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py).\n\nSince the BF16 model is over 1.3 TB, we need four A100 nodes, each with 8 80GB GPUs. Assuming the first node's IP is `10.0.0.1` and the converted model path is `/path/to/DeepSeek-V3-BF16`, launch the server with the following commands:\n\n```bash\n# node 1\npython3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 0 --trust-remote-code --host 0.0.0.0 --port 30000\n\n# node 2\npython3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 1 --trust-remote-code\n\n# node 3\npython3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 2 --trust-remote-code\n\n# node 4\npython3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 3 --trust-remote-code\n```\n\n> **Note that the launch command here does not enable Data Parallelism Attention or `torch.compile` Optimization**. 
For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).\n\nThen benchmark accuracy and latency against the first node's exposed port with the following example commands:\n\n```bash\n# bench accuracy\npython3 benchmark/gsm8k/bench_sglang.py --num-questions 1319 --host 10.0.0.1 --port 30000\n\n# bench latency\npython3 -m sglang.bench_one_batch_server --model None --base-url http://10.0.0.1:30000 --batch-size 1 --input-len 128 --output-len 128\n```\n\n### Example: Serving with 8 A100/A800 with AWQ Quantization\n\n**Recommended Usage**\n\nAdd the `--quantization moe_wna16` flag to enable the MoE WNA16 kernel for better performance. For example:\n\n```bash\npython3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization moe_wna16\n```\n\nAlternatively, you can use `--quantization awq_marlin` as follows:\n\n```bash\npython3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization awq_marlin --dtype float16\n```\n\nNote that `awq_marlin` currently supports only `float16`, which may lead to some precision loss.\n\n### Example: Serving with 16 A100/A800 with int8 Quantization\n\nBoth block-wise and per-channel quantization methods are available, and the quantized weights have already been uploaded to Hugging Face. 
The corresponding models are:\n\n- [meituan/DeepSeek-R1-Block-INT8](https://huggingface.co/meituan/DeepSeek-R1-Block-INT8)\n- [meituan/DeepSeek-R1-Channel-INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8)\n\nAssuming the master node IP is `MASTER_IP`, the checkpoint path is `/path/to/DeepSeek-R1-INT8`, and the port is 5000, launch the server with the following commands:\n\n```bash\n#master\npython3 -m sglang.launch_server \\\n\t--model meituan/DeepSeek-R1-Block-INT8 --tp 16 --dist-init-addr \\\n\tMASTER_IP:5000 --nnodes 2 --node-rank 0 --trust-remote-code --enable-torch-compile --torch-compile-max-bs 8\n#cluster\npython3 -m sglang.launch_server \\\n\t--model meituan/DeepSeek-R1-Block-INT8 --tp 16 --dist-init-addr \\\n\tMASTER_IP:5000 --nnodes 2 --node-rank 1 --trust-remote-code --enable-torch-compile --torch-compile-max-bs 8\n```\n\n> **Note that the launch command here enables `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).\n\nThen on the **master node**, supposing the ShareGPT data is located at `/path/to/ShareGPT_V3_unfiltered_cleaned_split.json`, you can run the following commands to benchmark the launched server:\n\n```bash\n# bench accuracy\npython3 benchmark/gsm8k/bench_sglang.py --num-questions 1319\n\n# bench serving\npython3 -m sglang.bench_serving --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json --dataset-name random --random-input 128 --random-output 128 --num-prompts 1000 --request-rate 128 --random-range-ratio 1.0\n```\n\n> **Note: using `--parallel 200` can accelerate accuracy benchmarking**.\n\n### Example: Serving with 32 L40S with int8 Quantization\n\nThis example runs the per-channel quantized model:\n\n- [meituan/DeepSeek-R1-Channel-INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8)\n\nAssuming the master node IP is `MASTER_IP`, the checkpoint path is `/path/to/DeepSeek-R1-Channel-INT8`, and the port is 5000, launch the server with the 
following commands:\n\n```bash\n#master\npython3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \\\n\t--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 0 --trust-remote-code \\\n\t--enable-torch-compile --torch-compile-max-bs 32\n#cluster\npython3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \\\n\t--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 1 --trust-remote-code \\\n\t--enable-torch-compile --torch-compile-max-bs 32\npython3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \\\n\t--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 2 --trust-remote-code \\\n\t--enable-torch-compile --torch-compile-max-bs 32\npython3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \\\n\t--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 3 --trust-remote-code \\\n\t--enable-torch-compile --torch-compile-max-bs 32\n```\n\nThe benchmarking method is the same as described in the previous [16 x A100](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-16-a100a800-with-int8-quantization) example.\n\n### Example: Serving on any cloud or Kubernetes with SkyPilot\n\nSkyPilot helps find the cheapest available GPUs across clouds or existing Kubernetes clusters and launch distributed serving with a single command. 
See details [here](https://github.com/skypilot-org/skypilot/tree/master/llm/deepseek-r1).\n\nTo serve on multiple nodes:\n\n```bash\ngit clone https://github.com/skypilot-org/skypilot.git\ncd skypilot\n# Serve on 2 H100/H200x8 nodes\nsky launch -c r1 llm/deepseek-r1/deepseek-r1-671B.yaml --retry-until-up\n# Serve on 4 A100x8 nodes\nsky launch -c r1 llm/deepseek-r1/deepseek-r1-671B-A100.yaml --retry-until-up\n```\n\n#### Troubleshooting\n\nIf you encounter the following error with an FP16/BF16 checkpoint:\n\n```\nValueError: Weight output_partition_size = 576 is not divisible by weight quantization block_n = 128.\n```\n\nedit your `config.json` and remove the `quantization_config` block. For example:\n\n```json\n\"quantization_config\": {\n    \"activation_scheme\": \"dynamic\",\n    \"fmt\": \"e4m3\",\n    \"quant_method\": \"fp8\",\n    \"weight_block_size\": [128, 128]\n},\n```\n\nRemoving this block typically resolves the error. For more details, see the discussion in [sgl-project/sglang#3491](https://github.com/sgl-project/sglang/issues/3491#issuecomment-2650779851).\n\n### Example: Serving with 4 H200 with w4fp8 Quantization\n\nThere are mixed-precision quantization methods where MoE layers are computed using W4(int)A(FP)8 quantization while the dense layers remain in FP8 precision. Users can run these models efficiently on 4xH200 GPUs (or potentially 8xH100 GPUs), as the pre-quantized weights are already available on Hugging Face. 
Here's an example:\n\n```bash\npython -m sglang.launch_server --model novita/Deepseek-V3-0324-W4AFP8 --mem-fraction-static 0.85 --disable-shared-experts-fusion --tp-size 4\n```\n\nOther variants of pre-quantized DeepSeek models are also available:\n\n- [novita/Deepseek-V3.1-W4AFP8](https://huggingface.co/novita/Deepseek-V3.1-W4AFP8)\n- [novita/Deepseek-R1-0528-W4AFP8](https://huggingface.co/novita/Deepseek-R1-0528-W4AFP8)\n- [novita/Deepseek-R1-W4AFP8](https://huggingface.co/novita/Deepseek-R1-W4AFP8)\n- [novita/Deepseek-V3-0324-W4AFP8](https://huggingface.co/novita/Deepseek-V3-0324-W4AFP8)\n\n\n## DeepSeek V3 Optimization Plan\n\nhttps://github.com/sgl-project/sglang/issues/2591\n"
  },
  {
    "path": "benchmark/dspy/README.md",
    "content": "## Install\n\n```\npip3 install dspy-ai\n```\n\nTurn off the cache by editing https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/dsp/modules/cache_utils.py#L10:\n\n```\ncache_turn_on = False\n```\n\nAlternatively, set the environment variable:\n\n```\nexport DSP_CACHEBOOL=false\n```\n\n## Benchmark SGLang\n```\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\n```\npython3 bench_dspy_intro.py --backend sglang\n```\n\n## Benchmark TGI\n```\ndocker run --name tgi --rm -ti --gpus all --network host \\\n  -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \\\n  ghcr.io/huggingface/text-generation-inference:1.3.0 \\\n  --model-id /Llama-2-7b-chat-hf --num-shard 1  --trust-remote-code \\\n  --max-input-length 2048 --max-total-tokens 4096 \\\n  --port 24000\n```\n\n```\npython3 bench_dspy_intro.py --backend tgi\n```\n\n## Benchmark vLLM\n```\npython3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests  --port 21000\n```\n\n```\npython3 bench_dspy_intro.py --backend vllm\n```\n"
  },
  {
    "path": "benchmark/dspy/bench_dspy_intro.py",
    "content": "\"\"\"\nAdapted from\nhttps://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9\n\"\"\"\n\nimport argparse\n\nimport dspy\nfrom dspy.datasets import HotPotQA\n\n\nclass BasicQA(dspy.Signature):\n    \"\"\"Answer questions with short factoid answers.\"\"\"\n\n    question = dspy.InputField()\n    answer = dspy.OutputField(desc=\"often between 1 and 5 words\")\n\n\nclass GenerateAnswer(dspy.Signature):\n    \"\"\"Answer questions with short factoid answers.\"\"\"\n\n    context = dspy.InputField(desc=\"may contain relevant facts\")\n    question = dspy.InputField()\n    answer = dspy.OutputField(desc=\"often between 1 and 5 words\")\n\n\nclass RAG(dspy.Module):\n    def __init__(self, num_passages=3):\n        super().__init__()\n\n        self.retrieve = dspy.Retrieve(k=num_passages)\n        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)\n\n    def forward(self, question):\n        context = self.retrieve(question).passages\n        prediction = self.generate_answer(context=context, question=question)\n        return dspy.Prediction(context=context, answer=prediction.answer)\n\n\ndef main(args):\n    # lm = dspy.OpenAI(model='gpt-3.5-turbo')\n    if args.backend == \"tgi\":\n        lm = dspy.HFClientTGI(\n            model=\"meta-llama/Llama-2-7b-chat-hf\",\n            port=args.port,\n            url=\"http://localhost\",\n        )\n    elif args.backend == \"sglang\":\n        lm = dspy.HFClientSGLang(\n            model=\"meta-llama/Llama-2-7b-chat-hf\",\n            port=args.port,\n            url=\"http://localhost\",\n        )\n    elif args.backend == \"vllm\":\n        lm = dspy.HFClientVLLM(\n            model=\"meta-llama/Llama-2-7b-chat-hf\",\n            port=args.port,\n            url=\"http://localhost\",\n        )\n    else:\n        raise ValueError(f\"Invalid backend: {args.backend}\")\n\n    colbertv2_wiki17_abstracts = dspy.ColBERTv2(\n        
url=\"http://20.102.90.50:2017/wiki17_abstracts\"\n    )\n    dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts)\n\n    # Load the dataset.\n    dataset = HotPotQA(\n        train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size, test_size=0\n    )\n\n    # Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.\n    trainset = [x.with_inputs(\"question\") for x in dataset.train]\n    devset = [x.with_inputs(\"question\") for x in dataset.dev]\n\n    print(len(trainset), len(devset))\n\n    train_example = trainset[0]\n    print(f\"Question: {train_example.question}\")\n    print(f\"Answer: {train_example.answer}\")\n\n    dev_example = devset[18]\n    print(f\"Question: {dev_example.question}\")\n    print(f\"Answer: {dev_example.answer}\")\n    print(f\"Relevant Wikipedia Titles: {dev_example.gold_titles}\")\n\n    print(\n        f\"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}\"\n    )\n    print(\n        f\"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}\"\n    )\n\n    # Define the predictor.\n    generate_answer = dspy.Predict(BasicQA)\n\n    # Call the predictor on a particular input.\n    pred = generate_answer(question=dev_example.question)\n\n    # Print the input and the prediction.\n    print(f\"Question: {dev_example.question}\")\n    print(f\"Predicted Answer: {pred.answer}\")\n\n    lm.inspect_history(n=1)\n\n    # Define the predictor. Notice we're just changing the class. 
The signature BasicQA is unchanged.\n    generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)\n\n    # Call the predictor on the same input.\n    pred = generate_answer_with_chain_of_thought(question=dev_example.question)\n\n    # Print the input, the chain of thought, and the prediction.\n    print(f\"Question: {dev_example.question}\")\n    print(f\"Thought: {pred.rationale.split('.', 1)[1].strip()}\")\n    print(f\"Predicted Answer: {pred.answer}\")\n\n    retrieve = dspy.Retrieve(k=3)\n    topK_passages = retrieve(dev_example.question).passages\n\n    print(\n        f\"Top {retrieve.k} passages for question: {dev_example.question} \\n\",\n        \"-\" * 30,\n        \"\\n\",\n    )\n\n    for idx, passage in enumerate(topK_passages):\n        print(f\"{idx+1}]\", passage, \"\\n\")\n\n    retrieve(\"When was the first FIFA World Cup held?\").passages[0]\n\n    from dspy.teleprompt import BootstrapFewShot\n\n    # Validation logic: check that the predicted answer is correct.\n    # Also check that the retrieved context does actually contain that answer.\n    def validate_context_and_answer(example, pred, trace=None):\n        answer_EM = dspy.evaluate.answer_exact_match(example, pred)\n        answer_PM = dspy.evaluate.answer_passage_match(example, pred)\n        return answer_EM and answer_PM\n\n    # Set up a basic teleprompter, which will compile our RAG program.\n    teleprompter = BootstrapFewShot(metric=validate_context_and_answer)\n\n    # Compile!\n    compiled_rag = teleprompter.compile(RAG(), trainset=trainset)\n\n    # Ask any question you like to this simple RAG program.\n    my_question = \"What castle did David Gregory inherit?\"\n\n    # Get the prediction. 
This contains `pred.context` and `pred.answer`.\n    pred = compiled_rag(my_question)\n\n    # Print the contexts and the answer.\n    print(f\"Question: {my_question}\")\n    print(f\"Predicted Answer: {pred.answer}\")\n    print(f\"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}\")\n\n    from dspy.evaluate.evaluate import Evaluate\n\n    # Set up the `evaluate_on_hotpotqa` function. We'll use this many times below.\n    evaluate_on_hotpotqa = Evaluate(\n        devset=devset,\n        num_threads=args.num_threads,\n        display_progress=True,\n        display_table=5,\n    )\n\n    # Evaluate the `compiled_rag` program with the `answer_exact_match` metric.\n    metric = dspy.evaluate.answer_exact_match\n    evaluate_on_hotpotqa(compiled_rag, metric=metric)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--port\", type=int)\n    parser.add_argument(\"--num-threads\", type=int, default=32)\n    parser.add_argument(\"--dev-size\", type=int, default=150)\n    parser.add_argument(\n        \"--backend\", type=str, choices=[\"sglang\", \"tgi\", \"vllm\"], default=\"sglang\"\n    )\n    args = parser.parse_args()\n\n    if args.port is None:\n        default_port = {\n            \"vllm\": 21000,\n            \"lightllm\": 22000,\n            \"tgi\": 24000,\n            \"sglang\": 30000,\n        }\n        args.port = default_port.get(args.backend, None)\n\n    main(args)\n"
  },
  {
    "path": "benchmark/fla/benchmark_layernorm_gated.py",
    "content": "from typing import Optional\n\nimport numpy as np\nimport torch\n\n# Import the function to benchmark\nfrom sglang.srt.layers.attention.fla.layernorm_gated import (\n    _layer_norm_fwd as layer_norm_fwd,\n)\nfrom sglang.srt.layers.attention.fla.layernorm_gated import (\n    rms_norm_ref,\n)\n\n\ndef benchmark_layer_norm_fwd(\n    M: int = 65536,\n    N: int = 128,\n    eps: float = 1e-6,\n    has_z: bool = True,\n    has_bias: bool = False,\n    group_size: Optional[int] = None,\n    norm_before_gate: bool = True,\n    is_rms_norm: bool = True,\n    dtype: torch.dtype = torch.float16,\n    warmup_iters: int = 10,\n    benchmark_iters: int = 100,\n    device: str = \"cuda\",\n    verbose: bool = True,\n):\n    \"\"\"\n    Benchmark layer_norm_fwd with specified parameters.\n\n    Args:\n        M: Number of rows (batch size)\n        N: Number of columns (hidden dimension)\n        eps: Epsilon for numerical stability\n        has_z: Whether to use gating tensor z\n        has_bias: Whether to use bias\n        group_size: Group size for group normalization (None = full dimension)\n        norm_before_gate: Whether to normalize before gating\n        is_rms_norm: Whether to use RMS normalization (vs LayerNorm)\n        dtype: Data type for tensors\n        warmup_iters: Number of warmup iterations\n        benchmark_iters: Number of benchmark iterations\n        device: Device to run on\n    \"\"\"\n    if verbose:\n        print(\"=\" * 80)\n        print(\"LayerNorm Forward Pass Benchmark\")\n        print(\"=\" * 80)\n        print(f\"\\nConfiguration:\")\n        print(f\"  x.shape: torch.Size([{M}, {N}])\")\n        print(f\"  weight.shape: torch.Size([{N}])\")\n        print(f\"  bias: {'torch.Size([{}])'.format(N) if has_bias else None}\")\n        print(f\"  eps: {eps}\")\n        print(f\"  z: {'torch.Size([{}, {}])'.format(M, N) if has_z else None}\")\n        print(f\"  out: None\")\n        print(f\"  group_size: {group_size}\")\n        
print(f\"  norm_before_gate: {norm_before_gate}\")\n        print(f\"  is_rms_norm: {is_rms_norm}\")\n        print(f\"  dtype: {dtype}\")\n        print(f\"  device: {device}\")\n        print()\n\n    # Create input tensors\n    torch.manual_seed(42)\n    x = torch.randn(M, N, dtype=dtype, device=device)\n    weight = torch.randn(N, dtype=dtype, device=device)\n    bias = torch.randn(N, dtype=dtype, device=device) if has_bias else None\n    z = torch.randn(M, N, dtype=dtype, device=device) if has_z else None\n\n    # Ensure contiguous memory layout\n    x = x.contiguous()\n    weight = weight.contiguous()\n    if bias is not None:\n        bias = bias.contiguous()\n    if z is not None:\n        z = z.contiguous()\n\n    if verbose:\n        print(\"Warming up...\")\n    # Warmup\n    for _ in range(warmup_iters):\n        out, mean, rstd = layer_norm_fwd(\n            x=x,\n            weight=weight,\n            bias=bias,\n            eps=eps,\n            z=z,\n            out=None,\n            group_size=group_size,\n            norm_before_gate=norm_before_gate,\n            is_rms_norm=is_rms_norm,\n        )\n        torch.cuda.synchronize()\n\n    if verbose:\n        print(f\"Capturing CUDA graph...\")\n\n    # Capture the kernel execution in a CUDA graph\n    runs_per_measurement = 100\n\n    # Create output tensor for graph capture\n    out_graph = torch.empty_like(x)\n    mean_graph = (\n        torch.empty((x.shape[0],), dtype=torch.float32, device=x.device)\n        if not is_rms_norm\n        else None\n    )\n    rstd_graph = torch.empty((x.shape[0],), dtype=torch.float32, device=x.device)\n\n    # Capture the graph\n    graph = torch.cuda.CUDAGraph()\n    with torch.cuda.graph(graph):\n        for _ in range(runs_per_measurement):\n            out, mean, rstd = layer_norm_fwd(\n                x=x,\n                weight=weight,\n                bias=bias,\n                eps=eps,\n                z=z,\n                out=out_graph,\n        
        group_size=group_size,\n                norm_before_gate=norm_before_gate,\n                is_rms_norm=is_rms_norm,\n            )\n\n    if verbose:\n        print(\n            f\"Running benchmark with {benchmark_iters} iterations using CUDA graph...\"\n        )\n\n    # Benchmark by replaying the graph\n    times = []\n    for i in range(benchmark_iters):\n        start_event = torch.cuda.Event(enable_timing=True)\n        end_event = torch.cuda.Event(enable_timing=True)\n\n        start_event.record()\n        graph.replay()\n        end_event.record()\n        torch.cuda.synchronize()\n\n        # elapsed_time returns milliseconds; divide by runs_per_measurement\n        elapsed_ms = start_event.elapsed_time(end_event)\n        times.append(\n            elapsed_ms / 1000.0 / runs_per_measurement\n        )  # Convert to seconds per run\n\n    # Compute statistics\n    times = np.array(times) * 1_000_000  # Convert to microseconds\n    mean_time = np.mean(times)\n    std_time = np.std(times)\n    min_time = np.min(times)\n    max_time = np.max(times)\n    median_time = np.median(times)\n    p95_time = np.percentile(times, 95)\n    p99_time = np.percentile(times, 99)\n\n    # Calculate throughput\n    num_elements = M * N\n    throughput_gelements_per_sec = (num_elements / mean_time) * 1_000_000 / 1e9\n\n    # Calculate memory bandwidth\n    # Read: x, weight, z (if has_z)\n    # Write: out, rstd, mean (if not rms_norm)\n    bytes_per_element = x.element_size()  # 2 for fp16/bf16, 4 for fp32\n    read_bytes = (M * N + N) * bytes_per_element  # x + weight\n    if has_z:\n        read_bytes += M * N * bytes_per_element  # z\n    write_bytes = M * N * bytes_per_element  # out\n    write_bytes += M * 4  # rstd (float32)\n    if not is_rms_norm:\n        write_bytes += M * 4  # mean (float32)\n\n    total_bytes = read_bytes + write_bytes\n    bandwidth_gb_per_sec = (total_bytes / mean_time) * 1_000_000 / 1e9\n\n    if verbose:\n        print(\"\\n\" 
+ \"=\" * 80)\n        print(\"Benchmark Results\")\n        print(\"=\" * 80)\n        print(f\"\\nTiming Statistics (microseconds):\")\n        print(f\"  Mean:     {mean_time:.2f} us\")\n        print(f\"  Std Dev:  {std_time:.2f} us\")\n        print(f\"  Min:      {min_time:.2f} us\")\n        print(f\"  Max:      {max_time:.2f} us\")\n        print(f\"  Median:   {median_time:.2f} us\")\n        print(f\"  P95:      {p95_time:.2f} us\")\n        print(f\"  P99:      {p99_time:.2f} us\")\n\n        print(f\"\\nThroughput:\")\n        print(f\"  {throughput_gelements_per_sec:.2f} GElements/sec\")\n        print(f\"  {bandwidth_gb_per_sec:.2f} GB/sec\")\n\n        print(f\"\\nMemory Usage:\")\n        print(f\"  Input size: {read_bytes / 1e6:.2f} MB\")\n        print(f\"  Output size: {write_bytes / 1e6:.2f} MB\")\n        print(f\"  Total: {total_bytes / 1e6:.2f} MB\")\n\n    # Verify correctness against reference implementation\n    if verbose:\n        print(\"\\nVerifying correctness...\")\n    out_triton, mean_triton, rstd_triton = layer_norm_fwd(\n        x=x,\n        weight=weight,\n        bias=bias,\n        eps=eps,\n        z=z,\n        out=None,\n        group_size=group_size,\n        norm_before_gate=norm_before_gate,\n        is_rms_norm=is_rms_norm,\n    )\n\n    # Compute reference output\n    out_ref = rms_norm_ref(\n        x=x,\n        weight=weight,\n        bias=bias,\n        z=z,\n        eps=eps,\n        group_size=group_size,\n        norm_before_gate=norm_before_gate,\n        upcast=True,\n    )\n\n    # Compare outputs\n    max_diff = torch.max(torch.abs(out_triton - out_ref)).item()\n    mean_diff = torch.mean(torch.abs(out_triton - out_ref)).item()\n    rel_diff = torch.mean(\n        torch.abs(out_triton - out_ref) / (torch.abs(out_ref) + 1e-5)\n    ).item()\n\n    if verbose:\n        print(f\"\\nCorrectness Check (vs Reference Implementation):\")\n        print(f\"  Max absolute difference: {max_diff:.6e}\")\n        
print(f\"  Mean absolute difference: {mean_diff:.6e}\")\n        print(f\"  Mean relative difference: {rel_diff:.6e}\")\n\n        if max_diff < 1e-2:\n            print(\"  ✓ PASS: Results match reference implementation\")\n        else:\n            print(\"  ✗ FAIL: Results do not match reference implementation\")\n\n        print(\"\\n\" + \"=\" * 80)\n\n    return {\n        \"mean_time_us\": mean_time,\n        \"std_time_us\": std_time,\n        \"min_time_us\": min_time,\n        \"max_time_us\": max_time,\n        \"median_time_us\": median_time,\n        \"p95_time_us\": p95_time,\n        \"p99_time_us\": p99_time,\n        \"throughput_gelements_per_sec\": throughput_gelements_per_sec,\n        \"bandwidth_gb_per_sec\": bandwidth_gb_per_sec,\n        \"max_diff\": max_diff,\n        \"mean_diff\": mean_diff,\n        \"rel_diff\": rel_diff,\n    }\n\n\ndef main():\n    \"\"\"Run the benchmark with the specified configuration.\"\"\"\n    # Configuration from user\n    config = {\n        \"M\": 65536,\n        \"N\": 128,\n        \"eps\": 1e-6,\n        \"has_z\": True,\n        \"has_bias\": False,\n        \"group_size\": None,\n        \"norm_before_gate\": True,\n        \"is_rms_norm\": True,\n        \"dtype\": torch.float16,\n        \"warmup_iters\": 10,\n        \"benchmark_iters\": 100,\n        \"device\": \"cuda\",\n    }\n\n    if not torch.cuda.is_available():\n        print(\"CUDA is not available. 
This benchmark requires a CUDA-enabled GPU.\")\n        return\n\n    results = benchmark_layer_norm_fwd(**config)\n\n    # Collect all results\n    all_results = []\n    # Test with different batch sizes\n    print(\"\\nRunning benchmarks for varying batch sizes...\")\n    for M in [256, 512, 1024, 4096, 16384, 65536, 2**17, 2**18]:\n        config_var = config.copy()\n        config_var[\"M\"] = M\n        config_var[\"warmup_iters\"] = 5\n        config_var[\"benchmark_iters\"] = 50\n        config_var[\"verbose\"] = False\n        result = benchmark_layer_norm_fwd(**config_var)\n        all_results.append({\"M\": M, \"N\": config_var[\"N\"], **result})\n        print(f\"  M={M:>5}: {result['mean_time_us']:>7.2f} us\")\n\n    # Print summary table\n    print(\"\\n\\n\")\n    print(\"=\" * 30)\n    print(\"SUMMARY TABLE - Varying Batch Size (M) with N=128\")\n    print(\"=\" * 30)\n    print(f\"{'M':>8} | {'Median (us)':>12}\")\n    print(\"-\" * 30)\n    for r in all_results:\n        print(f\"{r['M']:>8} | {r['median_time_us']:>12.2f}\")\n    print(\"=\" * 30)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/generative_agents/README.md",
    "content": "## Download the dataset\n\n```\nwget -O agent_calls.jsonl \"https://drive.google.com/uc?export=download&id=19qLpD45e9JGTKF2cUjJJegwzSUEZEKht\"\n```\n\n## Run benchmark\n\nRun this benchmark serially (using `--parallel 1`) to preserve any potential dependencies between requests.\n\n### Benchmark sglang\n```\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\n```\npython3 bench_sglang.py --num-events 1000 --parallel 1\n```\n\n### Benchmark vllm\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000\n```\n\n```\npython3 bench_other.py --num-events 1000 --backend vllm --parallel 1\n```\n\n### Benchmark guidance\n```\npython3 bench_other.py --num-events 1000 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n\n### Benchmark lmql\n\n```\npython3 bench_other.py --num-events 1000 --backend lmql --parallel 1\n```\n"
  },
  {
    "path": "benchmark/generative_agents/agent_functions.py",
    "content": "import sglang as sgl\n\n# here are the top five agent functions contributing ~70% LLM calls\n# reference: https://github.com/joonspk-research/generative_agents/\n\n\n@sgl.function\ndef poignancy_event(s, persona_name, persona_iss, event):\n    s += \"Here is a brief description of \" + persona_name + \".\\n\"\n    s += persona_iss + \"\\n\"\n    s += \"On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for\"\n    s += persona_name + \".\\n\\n\"\n    s += \"Event: \" + event\n    s += \"Rate (return a number between 1 to 10):\"\n    s += sgl.gen(name=\"Rate\", max_tokens=2)\n\n\ndef poignancy_event_prompt(persona_name, persona_iss, event):\n    # return prompt and max_tokens\n    s = \"\"\n    s += \"Here is a brief description of \" + persona_name + \".\\n\"\n    s += persona_iss + \"\\n\"\n    s += \"On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for\"\n    s += persona_name + \".\\n\\n\"\n    s += \"Event: \" + event\n    s += \"Rate (return a number between 1 to 10):\"\n    return {\"prompt\": s, \"max_tokens\": 2, \"stop\": None}\n\n\n@sgl.function\ndef generate_event_triple(s, persona_name, action):\n    s += \"\"\"Task: Turn the input into (subject, predicate, object).\nInput: Sam Johnson is eating breakfast.\nOutput: (Dolores Murphy, eat, breakfast)\n---\nInput: Joon Park is brewing coffee.\nOutput: (Joon Park, brew, coffee)\n---\nInput: Jane Cook is sleeping.\nOutput: (Jane Cook, is, sleep)\n---\nInput: Michael Bernstein is writing email on a computer.\nOutput: (Michael Bernstein, write, email)\n---\nInput: Percy Liang is teaching students in a classroom.\nOutput: (Percy Liang, teach, students)\n---\nInput: Merrie Morris is running on a 
treadmill.\nOutput: (Merrie Morris, run, treadmill)\n---\"\"\"\n    s += persona_name + \"is\" + action + \".\\n\"\n    s += \"(\" + persona_name + \",\"\n    s += sgl.gen(name=\"Triple\", max_tokens=20, stop=\")\")\n\n\ndef generate_event_triple_prompt(persona_name, action):\n    s = \"\"\n    s += \"\"\"Task: Turn the input into (subject, predicate, object).\nInput: Sam Johnson is eating breakfast.\nOutput: (Dolores Murphy, eat, breakfast)\n---\nInput: Joon Park is brewing coffee.\nOutput: (Joon Park, brew, coffee)\n---\nInput: Jane Cook is sleeping.\nOutput: (Jane Cook, is, sleep)\n---\nInput: Michael Bernstein is writing email on a computer.\nOutput: (Michael Bernstein, write, email)\n---\nInput: Percy Liang is teaching students in a classroom.\nOutput: (Percy Liang, teach, students)\n---\nInput: Merrie Morris is running on a treadmill.\nOutput: (Merrie Morris, run, treadmill)\n---\"\"\"\n    s += persona_name + \"is\" + action + \".\\n\"\n    s += \"(\" + persona_name + \",\"\n    return {\"prompt\": s, \"max_tokens\": 20, \"stop\": \")\"}\n\n\n@sgl.function\ndef generate_pronunciatio(s, action):\n    s += \"Convert an action description to an emoji (important: use two or less emojis).\\n\"\n    s += \"Action description: \" + action + \".\\n\"\n    s += \"Emoji:\" + sgl.gen(name=\"Emoji\", max_tokens=6)\n\n\ndef generate_pronunciatio_prompt(action):\n    s = \"\"\n    s += \"Convert an action description to an emoji (important: use two or less emojis).\\n\"\n    s += \"Action description: \" + action + \".\\n\"\n    s += \"Emoji:\"\n    return {\"prompt\": s, \"max_tokens\": 6, \"stop\": None}\n\n\n@sgl.function\ndef action_location_sector(\n    s,\n    persona_name,\n    living_sector,\n    living_sector_areas,\n    current_sector,\n    current_sector_areas,\n    daily_plan,\n    sector_options,\n    current_action,\n    next_action,\n):\n    s += \"\"\"Task -- choose an appropriate area  from the area options for a task at hand.\nSam Kim lives in {Sam Kim's 
house} that has Sam Kim's room, bathroom, kitchen.\nSam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.\nArea options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.\n* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.\n* Must be one of the \"Area options,\" verbatim.\nFor taking a walk, Sam Kim should go to the following area: {Johnson Park}\n---\nJane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.\nJane Anderson is currently in {Oak Hill College} that has a classroom, library\nArea options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.\n* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.\n* Must be one of the \"Area options,\" verbatim.\nFor eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}\n---\"\"\"\n    s += (\n        persona_name\n        + \" lives in \"\n        + living_sector\n        + \" that has \"\n        + living_sector_areas\n        + \".\\n\"\n    )\n    s += (\n        persona_name\n        + \" is currently in \"\n        + current_sector\n        + \" that has \"\n        + current_sector_areas\n        + \".\\n\"\n    )\n    s += daily_plan + \".\\n\"\n    s += \"Area options: \" + sector_options + \".\\n\"\n    s += \"\"\"* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.\n* Must be one of the \"Area options,\" verbatim.\\n\"\"\"\n    s += (\n        persona_name\n        + \" is \"\n        + current_action\n        + \". 
For \"\n        + next_action\n        + \", \"\n        + persona_name\n        + \" should go to the following area: {\"\n    )\n    s += sgl.gen(name=\"Location\", max_tokens=10, stop=\"}\")\n\n\ndef action_location_sector_prompt(\n    persona_name,\n    living_sector,\n    living_sector_areas,\n    current_sector,\n    current_sector_areas,\n    daily_plan,\n    sector_options,\n    current_action,\n    next_action,\n):\n    s = \"\"\n    s += \"\"\"Task -- choose an appropriate area  from the area options for a task at hand.\nSam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.\nSam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.\nArea options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.\n* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.\n* Must be one of the \"Area options,\" verbatim.\nFor taking a walk, Sam Kim should go to the following area: {Johnson Park}\n---\nJane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.\nJane Anderson is currently in {Oak Hill College} that has a classroom, library\nArea options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.\n* Stay in the current area if the activity can be done there. 
Only go out if the activity needs to take place in another place.\n* Must be one of the \"Area options,\" verbatim.\nFor eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}\n---\"\"\"\n    s += (\n        persona_name\n        + \" lives in \"\n        + living_sector\n        + \" that has \"\n        + living_sector_areas\n        + \".\\n\"\n    )\n    s += (\n        persona_name\n        + \" is currently in \"\n        + current_sector\n        + \" that has \"\n        + current_sector_areas\n        + \".\\n\"\n    )\n    s += daily_plan + \".\\n\"\n    s += \"Area options: \" + sector_options + \".\\n\"\n    s += \"\"\"* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.\n* Must be one of the \"Area options,\" verbatim.\\n\"\"\"\n    s += (\n        persona_name\n        + \" is \"\n        + current_action\n        + \". For \"\n        + next_action\n        + \", \"\n        + persona_name\n        + \" should go to the following area: {\"\n    )\n    return {\"prompt\": s, \"max_tokens\": 10, \"stop\": \"}\"}\n\n\n@sgl.function\ndef action_location_object(\n    s, persona_name, target_sector, target_sector_areas, current_action, next_action\n):\n    s += \"\"\"\nJane Anderson is in kitchen in Jane Anderson's house.\nJane Anderson is going to Jane Anderson's house that has the following areas: {kitchen,  bedroom, bathroom}\nStay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.\nFor cooking, Jane Anderson should go to the following area in Jane Anderson's house:\nAnswer: {kitchen}\n---\nTom Watson is in common room in Tom Watson's apartment.\nTom Watson is going to Hobbs Cafe that has the following areas: {cafe}\nStay in the current area if the activity can be done there. 
Never go into other people's rooms unless necessary.\nFor getting coffee, Tom Watson should go to the following area in Hobbs Cafe:\nAnswer: {cafe}\n---\"\"\"\n    s += (\n        persona_name\n        + \" is going to \"\n        + target_sector\n        + \" that has the following areas: {\"\n        + target_sector_areas\n        + \"}\\n\"\n    )\n    s += \"\"\"* Stay in the current area if the activity can be done there.\n* NEVER go into other people's rooms unless necessary.\"\"\"\n    s += (\n        persona_name\n        + \" is \"\n        + current_action\n        + \". For \"\n        + next_action\n        + \", \"\n        + persona_name\n        + \"should go to the following area in \"\n        + target_sector\n    )\n    s += \" (MUST pick one of {\" + target_sector_areas + \"}):\\n\"\n    s += \"Answer: {\" + sgl.gen(name=\"Area\", max_tokens=5, stop=\"}\")\n\n\ndef action_location_object_prompt(\n    persona_name, target_sector, target_sector_areas, current_action, next_action\n):\n    s = \"\"\n    s += \"\"\"\nJane Anderson is in kitchen in Jane Anderson's house.\nJane Anderson is going to Jane Anderson's house that has the following areas: {kitchen,  bedroom, bathroom}\nStay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.\nFor cooking, Jane Anderson should go to the following area in Jane Anderson's house:\nAnswer: {kitchen}\n---\nTom Watson is in common room in Tom Watson's apartment.\nTom Watson is going to Hobbs Cafe that has the following areas: {cafe}\nStay in the current area if the activity can be done there. 
Never go into other people's rooms unless necessary.\nFor getting coffee, Tom Watson should go to the following area in Hobbs Cafe:\nAnswer: {cafe}\n---\"\"\"\n    s += (\n        persona_name\n        + \" is going to \"\n        + target_sector\n        + \" that has the following areas: {\"\n        + target_sector_areas\n        + \"}\\n\"\n    )\n    s += \"\"\"* Stay in the current area if the activity can be done there.\n* NEVER go into other people's rooms unless necessary.\"\"\"\n    s += (\n        persona_name\n        + \" is \"\n        + current_action\n        + \". For \"\n        + next_action\n        + \", \"\n        + persona_name\n        + \"should go to the following area in \"\n        + target_sector\n    )\n    s += \" (MUST pick one of {\" + target_sector_areas + \"}):\\n\"\n    s += \"Answer: {\"\n    return {\"prompt\": s, \"max_tokens\": 5, \"stop\": \"}\"}\n"
  },
  {
    "path": "benchmark/generative_agents/bench_other.py",
    "content": "import argparse\nimport json\nimport time\n\nfrom agent_functions import (\n    action_location_object_prompt,\n    action_location_sector_prompt,\n    generate_event_triple_prompt,\n    generate_pronunciatio_prompt,\n    poignancy_event_prompt,\n)\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import dump_state_text, read_jsonl\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)[: args.num_events]\n    mapping = {\n        \"poignancy_event\": poignancy_event_prompt,\n        \"generate_event_triple\": generate_event_triple_prompt,\n        \"generate_pronunciatio\": generate_pronunciatio_prompt,\n        \"action_location_sector\": action_location_sector_prompt,\n        \"action_location_object\": action_location_object_prompt,\n    }\n\n    arguments = [mapping[k](**v) for l in lines for k, v in l.items()]\n    states = []\n\n    # Select backend\n    call_generate = get_call_generate(args)\n\n    def get_one_answer(arg):\n        answer = call_generate(**arg, temperature=0)\n        states.append(answer)\n\n    async def get_one_answer_async(arg):\n        answer = await call_generate(**arg, temperature=0)\n        states.append(answer)\n\n    tic = time.perf_counter()\n    # Agent calls always run sequentially to preserve dependencies between them\n    if args.backend != \"lmql\":\n        for arg in tqdm(arguments):\n            get_one_answer(arg)\n    else:\n        import asyncio\n\n        # Create the loop explicitly; asyncio.get_event_loop() is deprecated when no loop is running\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n        for arg in tqdm(arguments):\n            loop.run_until_complete(get_one_answer_async(arg))\n    latency = time.perf_counter() - tic\n\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"Generative Agents\",\n            \"backend\": args.backend,\n          
  \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            # to pack weighted functions as a single agent\n            \"num_requests\": len(arguments) / len(mapping),\n            \"other\": {\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"agent_calls.jsonl\")\n    parser.add_argument(\"--num-events\", type=int, default=10)\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/generative_agents/bench_sglang.py",
    "content": "import argparse\nimport json\nimport time\n\nfrom agent_functions import (\n    action_location_object,\n    action_location_sector,\n    generate_event_triple,\n    generate_pronunciatio,\n    poignancy_event,\n)\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text, read_jsonl\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)[: args.num_events]\n    mapping = {\n        \"poignancy_event\": poignancy_event,\n        \"generate_event_triple\": generate_event_triple,\n        \"generate_pronunciatio\": generate_pronunciatio,\n        \"action_location_sector\": action_location_sector,\n        \"action_location_object\": action_location_object,\n    }\n    arguments = [{mapping[k]: v for k, v in l.items()} for l in lines]\n\n    # Select backend\n    backend = select_sglang_backend(args)\n    sgl.set_default_backend(backend)\n\n    states = []\n    # Run requests\n    tic = time.perf_counter()\n    for a in arguments:\n        # only a single key in the dict\n        for func, arg in a.items():\n            result = func.run(**arg)\n        result.sync()\n        states.append(result)\n    latency = time.perf_counter() - tic\n\n    # Print latency\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"Generative Agents\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            # to pack weighted functions as a single agent\n            \"num_requests\": len(arguments) / len(mapping),\n            \"other\": {\n                \"num_events\": args.num_events,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + 
\"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"agent_calls.jsonl\")\n    parser.add_argument(\"--num-events\", type=int, default=10)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/gpt_oss/README.md",
    "content": "# How to reproduce the result of GPT-OSS with SGLang\n\n### Install the latest SGLang\n\n```bash\ngit clone https://github.com/sgl-project/sglang.git\ncd sglang\ngit checkout v0.5.1.post3\n\npip install --upgrade pip\npip install -e \"python[all]\"\n```\n\n### Reproduce the benchmark throughput result (Batch Size 1)\n\nLaunch Command\n\n```bash\n# MXFP4 120B on H100\npython3 -m sglang.launch_server --model openai/gpt-oss-120b --tp 8 --attention-backend triton\n\n# BF16 120B on H100\npython3 -m sglang.launch_server --model lmsys/gpt-oss-120b-bf16 --tp 8 --attention-backend triton\n\n# MXFP4 120B on B200\npython3 -m sglang.launch_server --model openai/gpt-oss-120b --tp 4\n\n# BF16 120B on B200\npython3 -m sglang.launch_server --model lmsys/gpt-oss-120b-bf16 --tp 4\n```\n\nBenchmark Command\n\n```bash\n\n# MXFP4 120B on H100\npython3 -m sglang.bench_one_batch_server --model openai/gpt-oss-120b --base-url http://localhost:30000 --batch-size 1 --input-len 1024 --output-len 512 --show-report\n```\n\n### Reproduce the benchmark throughput result (Batch Size 32)\n\nLaunch Command\n\n```bash\n# MXFP4 120B on H100\npython3 -m sglang.launch_server --model openai/gpt-oss-120b --tp 8\n\n# BF16 120B on H100\npython3 -m sglang.launch_server --model lmsys/gpt-oss-120b-bf16 --tp 8\n\n# MXFP4 120B on B200\npython3 -m sglang.launch_server --model openai/gpt-oss-120b --tp 4\n\n# BF16 120B on B200\npython3 -m sglang.launch_server --model lmsys/gpt-oss-120b-bf16 --tp 4\n```\n\nBenchmark Command\n\n```bash\npython3 -m sglang.bench_one_batch_server --model openai/gpt-oss-120b --base-url http://localhost:30000 --batch-size 32 --input-len 1024 8192 --output-len 512 --show-report\n```\n\n### Reproduce the evaluation result\n\nInstall gpt-oss\n\n```bash\ngit clone https://github.com/openai/gpt-oss.git\ncd gpt-oss\npip install -e .\n```\n\nEvaluation Command\n\n```bash\nDATASET=gpqa\nBASE_URL=YOUR_BASE_URL\nOPENAI_API_KEY=dummy python -m gpt_oss.evals \\\n    --base-url 
${BASE_URL}/v1 \\\n    --model dummy \\\n    --reasoning-effort low,medium,high \\\n    --eval $DATASET \\\n    --n-threads 1000\n```\n\n### Reproduce the benchmark result of acceptance length\n> Note: On B200, if top k is 1, set `--attention-backend trtllm_mha`\n```bash\ngit clone https://github.com/sgl-project/SpecForge.git\ncd SpecForge/benchmarks\nconfig_list=(\n    \"1,0,0,0\"\n    \"1,3,1,4\"\n    \"1,5,4,8\"\n)\npython3 bench_model_speedup.py \\\n    --model-path openai/gpt-oss-120b \\\n    --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 \\\n    --port 20001 \\\n    --trust-remote-code \\\n    --mem-fraction-static 0.8 \\\n    --tp-size 4 \\\n    --attention-backend fa3 \\\n    --config-list \"${config_list[@]}\" \\\n    --benchmark-list mtbench:80 gsm8k:200 humaneval:200 math500:200 \\\n    --output lmsys_gpt-oss-120b_Eagle3_result.jsonl\n\npython3 bench_model_speedup.py \\\n    --model-path openai/gpt-oss-120b \\\n    --speculative-draft-model-path nvidia/gpt-oss-120b-Eagle3 \\\n    --port 20001 \\\n    --trust-remote-code \\\n    --mem-fraction-static 0.8 \\\n    --tp-size 4 \\\n    --attention-backend fa3 \\\n    --config-list \"${config_list[@]}\" \\\n    --benchmark-list mtbench:80 gsm8k:200 humaneval:200 math500:200 \\\n    --output nv_gpt-oss-120b_Eagle3_result.jsonl\n```\n\n### Reproduce the result of speculative decoding speedup\n\nLaunch Command\n\n```bash\n# On Hopper:\n# - Tree decoding (topk > 1) and chain decoding (topk = 1) are supported on both FA3 and Triton backends.\npython3 -m sglang.launch_server --model openai/gpt-oss-120b --speculative-algorithm EAGLE3 --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --tp 4\npython3 -m sglang.launch_server --model openai/gpt-oss-120b --speculative-algorithm EAGLE3 --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 5 --speculative-eagle-topk 4 
--speculative-num-draft-tokens 8 --tp 4\n\n# On Blackwell:\n# - Chain decoding (topk = 1) is supported on the TRTLLM-MHA backend. Tree decoding (topk > 1) is in progress; stay tuned!\n# - Both tree decoding (topk > 1) and chain decoding (topk = 1) are supported on the Triton backend.\npython3 -m sglang.launch_server --model openai/gpt-oss-120b --speculative-algorithm EAGLE3 --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --tp 4\npython3 -m sglang.launch_server --model openai/gpt-oss-120b --speculative-algorithm EAGLE3 --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 5 --speculative-eagle-topk 4 --speculative-num-draft-tokens 8 --attention-backend triton --tp 4\n```\n\nBenchmark Command\n\n```bash\nconfig_list=(\n    \"1,0,0,0\"\n    \"1,3,1,4\"\n    \"1,5,4,8\"\n)\npython3 bench_model_speedup.py \\\n    --model-path openai/gpt-oss-120b \\\n    --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 \\\n    --port 20001 \\\n    --trust-remote-code \\\n    --mem-fraction-static 0.8 \\\n    --tp-size 4 \\\n    --attention-backend fa3 \\\n    --config-list \"${config_list[@]}\" \\\n    --benchmark-list gsm8k:200 humaneval:200 math500:200 \\\n    --output lmsys_gpt-oss-120b_Eagle3_result.jsonl\n```\n\nWe observed the best speedups with the following settings:\n\n- **1.39x** speedup with the `--speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4` setting.\n- **1.52x** speedup with the `--speculative-num-steps 5 --speculative-eagle-topk 4 --speculative-num-draft-tokens 8` setting.\n"
  },
  {
    "path": "benchmark/gsm8k/README.md",
    "content": "## Run benchmark\n\n### Using GSM8K Platinum\n\nGSM8K Platinum is a revised version of the GSM8K test set with corrected labels and ambiguous questions removed. It can give more stable results than the original GSM8K test set. It is a drop-in replacement, enabled by adding the `--platinum` flag:\n\n```\npython3 bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum\n```\n\nFor more information, see: https://huggingface.co/datasets/madrylab/gsm8k-platinum\n\n### Benchmark sglang\n```\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\n```\npython3 bench_sglang.py --num-questions 200\n```\n\n\n### Benchmark vllm\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000\n```\n\n```\npython3 bench_other.py --num-questions 200 --backend vllm\n```\n\n\n### Benchmark lightllm\n```\n# A10G\npython -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000\n```\n\n```\npython3 bench_other.py --num-questions 200 --backend lightllm\n```\n\n\n### Benchmark guidance\n```\npython3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n\n\n### Benchmark lmql\n```\nCUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000\n```\n\n```\npython3 bench_other.py --num-questions 100 --backend lmql --parallel 2\n```\n"
  },
  {
    "path": "benchmark/gsm8k/bench_other.py",
    "content": "import argparse\nimport ast\nimport asyncio\nimport json\nimport re\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\n\nimport numpy as np\nfrom datasets import load_dataset\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import download_and_cache_file, dump_state_text, read_jsonl\n\nINVALID = -9999999\n\n\ndef get_one_example(lines, i, include_answer):\n    ret = \"Question: \" + lines[i][\"question\"] + \"\\nAnswer:\"\n    if include_answer:\n        ret += \" \" + lines[i][\"answer\"]\n    return ret\n\n\ndef get_few_shot_examples(lines, k):\n    ret = \"\"\n    for i in range(k):\n        ret += get_one_example(lines, i, True) + \"\\n\\n\"\n    return ret\n\n\ndef get_answer_value(answer_str):\n    answer_str = answer_str.replace(\",\", \"\")\n    numbers = re.findall(r\"\\d+\", answer_str)\n    if len(numbers) < 1:\n        return INVALID\n    try:\n        return ast.literal_eval(numbers[-1])\n    except SyntaxError:\n        return INVALID\n\n\ndef main(args):\n    # Select backend\n    call_generate = get_call_generate(args)\n\n    # Read data\n    if args.platinum:\n        print(\"Loading GSM8K Platinum dataset from HuggingFace...\")\n        dataset = load_dataset(\"madrylab/gsm8k-platinum\", \"main\", split=\"test\")\n        lines = [\n            {\"question\": item[\"question\"], \"answer\": item[\"answer\"]} for item in dataset\n        ]\n    else:\n        url = \"https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\"\n        filename = download_and_cache_file(url)\n        lines = list(read_jsonl(filename))\n\n    # Construct prompts\n    num_questions = args.num_questions\n    num_shots = args.num_shots\n    few_shot_examples = get_few_shot_examples(lines, num_shots)\n\n    questions = []\n    labels = []\n    for i in range(len(lines[:num_questions])):\n        
questions.append(get_one_example(lines, i, False))\n        labels.append(get_answer_value(lines[i][\"answer\"]))\n    assert all(l != INVALID for l in labels)\n\n    states = [None] * len(labels)\n\n    # Run requests\n    if args.backend != \"lmql\":\n        # Use thread pool\n        def get_one_answer(i):\n            answer = call_generate(\n                prompt=few_shot_examples + questions[i],\n                temperature=0,\n                max_tokens=256,\n                stop=[\"Question\", \"Assistant:\", \"<|separator|>\"],\n            )\n            states[i] = answer\n\n        tic = time.perf_counter()\n        if args.parallel == 1:\n            for i in tqdm(range(len(questions))):\n                get_one_answer(i)\n        else:\n            with ThreadPoolExecutor(args.parallel) as executor:\n                list(\n                    tqdm(\n                        executor.map(get_one_answer, list(range(len(questions)))),\n                        total=len(questions),\n                    )\n                )\n\n    else:\n        # Use asyncio\n        async def batched_call(batch_size):\n            for i in range(0, len(questions), batch_size):\n                tasks = []\n                for q in questions[i : i + batch_size]:\n                    tasks.append(\n                        call_generate(\n                            few_shot_examples + q,\n                            temperature=0,\n                            max_tokens=256,\n                            stop=\"Question\",\n                        )\n                    )\n                rets = await asyncio.gather(*tasks)\n                for j in range(len(rets)):\n                    states[i + j] = rets[j]\n\n        tic = time.perf_counter()\n        asyncio.run(batched_call(batch_size=args.parallel))\n    latency = time.perf_counter() - tic\n\n    preds = []\n    for i in range(len(states)):\n        preds.append(get_answer_value(states[i]))\n\n    # Compute 
accuracy\n    acc = np.mean(np.array(preds) == np.array(labels))\n    invalid = np.mean(np.array(preds) == INVALID)\n\n    # Print results\n    print(f\"Accuracy: {acc:.3f}\")\n    print(f\"Invalid: {invalid:.3f}\")\n    print(f\"Latency: {latency:.3f} s\")\n\n    # Dump results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"gsm8k-platinum\" if args.platinum else \"gsm8k\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"accuracy\": round(acc, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--num-shots\", type=int, default=5)\n    parser.add_argument(\"--data-path\", type=str, default=\"test.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=200)\n    parser.add_argument(\n        \"--platinum\",\n        action=\"store_true\",\n        help=\"Use GSM8K Platinum dataset (drop-in replacement with corrected labels)\",\n    )\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/gsm8k/bench_sglang.py",
    "content": "import argparse\nimport ast\nimport json\nimport os\nimport re\nimport time\n\nimport numpy as np\nfrom datasets import load_dataset\n\nfrom sglang.lang.api import set_default_backend\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    dump_bench_raw_result,\n    select_sglang_backend,\n)\nfrom sglang.utils import download_and_cache_file, dump_state_text, read_jsonl\n\nINVALID = -9999999\n\n\ndef get_one_example(lines, i, include_answer):\n    ret = \"Question: \" + lines[i][\"question\"] + \"\\nAnswer:\"\n    if include_answer:\n        ret += \" \" + lines[i][\"answer\"]\n    return ret\n\n\ndef get_few_shot_examples(lines, k):\n    ret = \"\"\n    for i in range(k):\n        ret += get_one_example(lines, i, True) + \"\\n\\n\"\n    return ret\n\n\ndef get_answer_value(answer_str):\n    answer_str = answer_str.replace(\",\", \"\")\n    numbers = re.findall(r\"\\d+\", answer_str)\n    if len(numbers) < 1:\n        return INVALID\n    try:\n        return ast.literal_eval(numbers[-1])\n    except SyntaxError:\n        return INVALID\n\n\ndef main(args):\n    # Select backend\n    set_default_backend(select_sglang_backend(args))\n\n    # Load tokenizer if enable_thinking is set\n    tokenizer = None\n    if args.enable_thinking:\n        from transformers import AutoTokenizer\n\n        assert (\n            args.tokenizer_path is not None\n        ), \"--tokenizer-path is required when --enable-thinking is set\"\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.tokenizer_path, trust_remote_code=True\n        )\n\n    # Read data\n    if args.platinum:\n        print(\"Loading GSM8K Platinum dataset from HuggingFace...\")\n        dataset = load_dataset(\"madrylab/gsm8k-platinum\", \"main\", split=\"test\")\n        lines = [\n            {\"question\": item[\"question\"], \"answer\": item[\"answer\"]} for item in dataset\n        ]\n    else:\n        data_path = args.data_path\n        url = 
\"https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\"\n        if not os.path.isfile(data_path):\n            data_path = download_and_cache_file(url)\n        lines = list(read_jsonl(data_path))\n\n    # Construct prompts\n    num_questions = args.num_questions\n    num_shots = args.num_shots\n    few_shot_examples = get_few_shot_examples(lines, num_shots)\n\n    questions = []\n    labels = []\n    for i in range(len(lines[:num_questions])):\n        raw_question = few_shot_examples + get_one_example(lines, i, False)\n        if tokenizer is not None:\n            messages = [{\"role\": \"user\", \"content\": raw_question}]\n            raw_question = tokenizer.apply_chat_template(\n                messages,\n                tokenize=False,\n                add_generation_prompt=True,\n                enable_thinking=True,\n            )\n        questions.append(raw_question)\n        labels.append(get_answer_value(lines[i][\"answer\"]))\n    assert all(l != INVALID for l in labels)\n    arguments = [{\"question\": q} for q in questions]\n\n    #####################################\n    ######### SGL Program Begin #########\n    #####################################\n\n    import sglang as sgl\n\n    @sgl.function\n    def few_shot_gsm8k(s, question):\n        s += question\n        s += sgl.gen(\n            \"answer\",\n            max_tokens=args.max_new_tokens,\n            stop=[\"Question\", \"Assistant:\", \"<|separator|>\"],\n        )\n\n    #####################################\n    ########## SGL Program End ##########\n    #####################################\n\n    # Run requests\n    tic = time.perf_counter()\n    states = few_shot_gsm8k.run_batch(\n        arguments,\n        temperature=args.temperature,\n        top_p=args.top_p,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    latency = time.perf_counter() - tic\n\n    preds = []\n    for i in range(len(states)):\n    
    preds.append(get_answer_value(states[i][\"answer\"]))\n\n    # Compute accuracy\n    acc = np.mean(np.array(preds) == np.array(labels))\n    invalid = np.mean(np.array(preds) == INVALID)\n\n    # Compute speed\n    num_output_tokens = sum(\n        s.get_meta_info(\"answer\")[\"completion_tokens\"] for s in states\n    )\n    output_throughput = num_output_tokens / latency\n\n    # Print results\n    print(f\"Accuracy: {acc:.3f}\")\n    print(f\"Invalid: {invalid:.3f}\")\n    print(f\"Latency: {latency:.3f} s\")\n    print(f\"Output throughput: {output_throughput:.3f} token/s\")\n\n    # Dump results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n    dump_bench_raw_result(\n        path=args.raw_result_file,\n        states=states,\n        preds=preds,\n        labels=labels,\n    )\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"gsm8k-platinum\" if args.platinum else \"gsm8k\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"accuracy\": round(acc, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--num-shots\", type=int, default=5)\n    parser.add_argument(\"--data-path\", type=str, default=\"test.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=200)\n    parser.add_argument(\"--max-new-tokens\", type=int, default=512)\n    parser.add_argument(\"--temperature\", type=float, default=0.0)\n    parser.add_argument(\"--top-p\", type=float, default=1.0)\n    parser.add_argument(\n        \"--enable-thinking\",\n        action=\"store_true\",\n        help=\"Enable thinking mode by wrapping 
prompts with chat template\",\n    )\n    parser.add_argument(\n        \"--tokenizer-path\",\n        type=str,\n        default=None,\n        help=\"Path to tokenizer (required when --enable-thinking is set)\",\n    )\n    parser.add_argument(\n        \"--platinum\",\n        action=\"store_true\",\n        help=\"Use GSM8K Platinum dataset (drop-in replacement with corrected labels)\",\n    )\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/hellaswag/README.md",
    "content": "## Run benchmark\n\n### Benchmark sglang\n```\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\n```\npython3 bench_sglang.py --num-questions 200\n```\n\n\n### Benchmark vllm\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000\n```\n\n```\npython3 bench_other.py --num-questions 200 --backend vllm\n```\n\n\n### Benchmark lightllm\n```\n# A10G\npython -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000\n```\n\n```\npython3 bench_other.py --num-questions 200 --backend lightllm\n```\n\n\n### Benchmark guidance\n```\nCUDA_VISIBLE_DEVICES=0,1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n\n\n### Benchmark lmql\n```\nlmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000\n```\n\n```\npython3 bench_other.py --num-questions 200 --backend lmql --port 23000 --parallel 1\n```\n"
  },
  {
    "path": "benchmark/hellaswag/bench_other.py",
    "content": "import argparse\nimport asyncio\nimport json\nimport os\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\n\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_select\nfrom sglang.utils import download_and_cache_file, read_jsonl\n\n\ndef get_one_example(lines, i, include_answer):\n    ret = lines[i][\"activity_label\"] + \": \" + lines[i][\"ctx\"] + \" \"\n    if include_answer:\n        ret += lines[i][\"endings\"][lines[i][\"label\"]]\n    return ret\n\n\ndef get_few_shot_examples(lines, k):\n    ret = \"\"\n    for i in range(k):\n        ret += get_one_example(lines, i, True) + \"\\n\\n\"\n    return ret\n\n\ndef main(args):\n    # Select backend\n    call_select = get_call_select(args)\n\n    # Read data (use the local file if it exists, otherwise download it)\n    data_path = args.data_path\n    url = \"https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl\"\n    if not os.path.isfile(data_path):\n        data_path = download_and_cache_file(url)\n    lines = list(read_jsonl(data_path))\n\n    # Construct prompts\n    num_questions = args.num_questions\n    num_shots = args.num_shots\n    few_shot_examples = get_few_shot_examples(lines, num_shots)\n\n    questions = []\n    choices = []\n    labels = []\n    for i in range(len(lines[:num_questions])):\n        questions.append(get_one_example(lines, i, False))\n        choices.append(lines[i][\"endings\"])\n        labels.append(lines[i][\"label\"])\n\n    preds = [None] * len(labels)\n\n    # Run requests\n    if args.backend != \"lmql\":\n        # Use thread pool\n        def get_one_answer(i):\n            preds[i] = call_select(\n                context=few_shot_examples + questions[i], choices=choices[i]\n            )\n\n        tic = time.perf_counter()\n        if args.parallel == 1:\n            for i in tqdm(range(len(questions))):\n                get_one_answer(i)\n        else:\n            with ThreadPoolExecutor(args.parallel) as executor:\n                list(\n                    tqdm(\n                    
    executor.map(get_one_answer, list(range(len(questions)))),\n                        total=len(questions),\n                    )\n                )\n    else:\n        # Use asyncio\n        async def batched_call(batch_size):\n            for i in range(0, len(questions), batch_size):\n                tasks = []\n                for q, c in zip(\n                    questions[i : i + batch_size], choices[i : i + batch_size]\n                ):\n                    tasks.append(call_select(context=few_shot_examples + q, choices=c))\n                rets = await asyncio.gather(*tasks)\n                for j in range(len(rets)):\n                    preds[i + j] = rets[j]\n\n        tic = time.perf_counter()\n        asyncio.run(batched_call(batch_size=args.parallel))\n\n    latency = time.perf_counter() - tic\n\n    # Compute accuracy\n    acc = np.mean(np.array(preds) == np.array(labels))\n    print(f\"Latency: {latency:.3f}\")\n    print(f\"Accuracy: {acc:.3f}\")\n\n    # Write results\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"hellaswag\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"accuracy\": round(acc, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--num-shots\", type=int, default=20)\n    parser.add_argument(\"--data-path\", type=str, default=\"hellaswag_val.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=200)\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/hellaswag/bench_sglang.py",
    "content": "import argparse\nimport json\nimport os\nimport time\n\nimport numpy as np\n\nfrom sglang.lang.api import set_default_backend\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import download_and_cache_file, read_jsonl\n\n\ndef get_one_example(lines, i, include_answer):\n    ret = lines[i][\"activity_label\"] + \": \" + lines[i][\"ctx\"] + \" \"\n    if include_answer:\n        ret += lines[i][\"endings\"][lines[i][\"label\"]]\n    return ret\n\n\ndef get_few_shot_examples(lines, k):\n    ret = \"\"\n    for i in range(k):\n        ret += get_one_example(lines, i, True) + \"\\n\\n\"\n    return ret\n\n\ndef main(args):\n    # Select backend\n    set_default_backend(select_sglang_backend(args))\n\n    # Read data\n    data_path = args.data_path\n    url = \"https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl\"\n    if not os.path.isfile(data_path):\n        data_path = download_and_cache_file(url)\n    lines = list(read_jsonl(data_path))\n\n    # Construct prompts\n    num_questions = args.num_questions\n    num_shots = args.num_shots\n    few_shot_examples = get_few_shot_examples(lines, num_shots)\n\n    questions = []\n    choices = []\n    labels = []\n    for i in range(len(lines[:num_questions])):\n        questions.append(get_one_example(lines, i, False))\n        choices.append(lines[i][\"endings\"])\n        labels.append(lines[i][\"label\"])\n    arguments = [{\"question\": q, \"choices\": c} for q, c in zip(questions, choices)]\n\n    #####################################\n    ######### SGL Program Begin #########\n    #####################################\n\n    import sglang as sgl\n\n    @sgl.function\n    def few_shot_hellaswag(s, question, choices):\n        s += few_shot_examples + question\n        s += sgl.select(\"answer\", choices=choices)\n\n    #####################################\n    ########## SGL Program End 
##########\n    #####################################\n\n    # Run requests\n    tic = time.perf_counter()\n    rets = few_shot_hellaswag.run_batch(\n        arguments,\n        temperature=0,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    preds = [choices[i].index(rets[i][\"answer\"]) for i in range(len(rets))]\n    latency = time.perf_counter() - tic\n\n    # Compute accuracy\n    acc = np.mean(np.array(preds) == np.array(labels))\n    print(f\"Latency: {latency:.3f}\")\n    print(f\"Accuracy: {acc:.3f}\")\n\n    # Write results\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"hellaswag\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"accuracy\": round(acc, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--num-shots\", type=int, default=20)\n    parser.add_argument(\"--data-path\", type=str, default=\"hellaswag_val.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=200)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/hf3fs/bench.sh",
    "content": "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib\npython3 benchmark/hf3fs/bench_client.py\n\nexport LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib\nSGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \\\npython3 benchmark/hf3fs/bench_storage.py\n\nexport LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib\nexport SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json\necho '{\"file_path_prefix\": \"/data/hf3fs-test-0\", \"file_size\": 1099511627776, \"numjobs\": 16, \"entries\": 8}' > \\\n${SGLANG_HICACHE_HF3FS_CONFIG_PATH}\npython3 benchmark/hf3fs/bench_zerocopy.py\n\n####################################################################################################\n\nrm -rf nohup.out && \\\nnohup python3 -m sglang.launch_server \\\n    --model-path /code/models/Qwen3-32B/ \\\n    --host 0.0.0.0 --port 33301 \\\n    --page-size 64 \\\n    --enable-hierarchical-cache \\\n    --hicache-ratio 2 --hicache-size 0 \\\n    --hicache-write-policy write_through \\\n    --hicache-storage-backend hf3fs &\n\nrm -rf bench_multiturn.out && \\\nnohup python3 benchmark/hicache/bench_multiturn.py \\\n    --model-path /code/models/Qwen3-32B \\\n    --dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \\\n    --port 33301 \\\n    --request-length 2048 --num-clients 512 --num-rounds 3 --max-parallel 8 \\\n    > bench_multiturn.out &\n\n####################################################################################################\n\nrm -rf nohup.out && \\\nnohup python3 -m sglang.launch_server \\\n    --model-path /code/models/DeepSeek-R1/ \\\n    --tp 16 --nnodes 2 --node-rank 0 \\\n    --dist-init-addr 10.74.249.153:5000 \\\n    --host 0.0.0.0 --port 33301 \\\n   
 --page-size 64 \\\n    --enable-hierarchical-cache \\\n    --hicache-ratio 2 --hicache-size 60 \\\n    --hicache-write-policy write_through \\\n    --hicache-storage-backend hf3fs &\n\nrm -rf bench_multiturn.out && \\\nnohup python3 benchmark/hicache/bench_multiturn.py \\\n    --model-path /code/models/Qwen3-32B \\\n    --dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \\\n    --port 33301 \\\n    --request-length 2048 --num-clients 1024 --num-rounds 3 --max-parallel 8 \\\n    > bench_multiturn.out &\n\n####################################################################################################\n\nps aux | grep \"sglang.launch_server\" | grep -v grep | awk '{print $2}' | xargs kill -9\nps aux | grep \"bench_multiturn.py\" | grep -v grep | awk '{print $2}' | xargs kill -9\n"
  },
  {
    "path": "benchmark/hf3fs/bench_client.py",
    "content": "import concurrent.futures\nimport logging\nimport random\nimport time\nfrom typing import List\n\nimport torch\nfrom tqdm import tqdm\n\nfrom sglang.srt.mem_cache.storage.hf3fs.hf3fs_usrbio_client import Hf3fsUsrBioClient\n\n\ndef print_stats(x: List[int]):\n    x = sorted(x)\n    lenx = len(x)\n    print(\n        f\"mean = {sum(x)/len(x):.2f}, \"\n        f\"min = {min(x):.2f}, \"\n        f\"p25 = {x[int(lenx*0.25)]:.2f}, \"\n        f\"p50 = {x[int(lenx*0.5)]:.2f}, \"\n        f\"p75 = {x[int(lenx*0.75)]:.2f}, \"\n        f\"max = {max(x):.2f}\"\n    )\n\n\ndef test():\n    # /path/to/hf3fs\n    file_path = \"/data/bench.bin\"\n    file_size = 1 << 40\n    bytes_per_page = 16 << 20\n    entries = 32\n    file_ops = Hf3fsUsrBioClient(file_path, file_size, bytes_per_page, entries, 5)\n\n    print(\"test batch_read / batch_write\")\n    num_pages = 128\n    dtype = torch.bfloat16\n    numel = bytes_per_page // dtype.itemsize\n    offsets = list(range(file_size // bytes_per_page))\n    random.shuffle(offsets)\n    offsets = offsets[:num_pages]\n    offsets = [i * bytes_per_page for i in offsets]\n    tensor_writes = [\n        torch.randn(numel, dtype=dtype)\n        for _ in tqdm(range(num_pages), desc=\"prepare tensor\")\n    ]\n    for i in tqdm(range(0, num_pages, file_ops.entries), desc=\"batch_write\"):\n        results = file_ops.batch_write(\n            offsets[i : i + file_ops.entries], tensor_writes[i : i + file_ops.entries]\n        )\n        assert all([result == numel * dtype.itemsize for result in results])\n    tensor_reads = [\n        torch.empty(numel, dtype=dtype)\n        for _ in tqdm(range(num_pages), desc=\"prepare tensor\")\n    ]\n    for i in tqdm(range(0, num_pages, file_ops.entries), desc=\"batch_read\"):\n        results = file_ops.batch_read(\n            offsets[i : i + file_ops.entries], tensor_reads[i : i + file_ops.entries]\n        )\n        assert all([result == numel * dtype.itemsize for result in results])\n  
  assert all([torch.allclose(r, w) for r, w in zip(tensor_reads, tensor_writes)])\n\n    file_ops.close()\n    print(\"test done\")\n\n\ndef bench():\n    file_path = \"/data/bench.bin\"\n    file_size = 1 << 40\n    bytes_per_page = 16 << 20\n    entries = 8\n    numjobs = 16\n\n    dtype = torch.bfloat16\n    numel = bytes_per_page // dtype.itemsize\n\n    file_ops = [\n        Hf3fsUsrBioClient(file_path, file_size, bytes_per_page, entries, 5)\n        for _ in range(numjobs)\n    ]\n\n    num_page = entries\n\n    offsets = list(range(file_size // bytes_per_page))\n    tensors_write = [torch.randn(numel, dtype=dtype)] * num_page\n    tensors_read = [torch.empty(numel, dtype=dtype)] * num_page\n    random.shuffle(offsets)\n\n    warmup = 50\n    iteration = 100\n\n    executor = concurrent.futures.ThreadPoolExecutor(max_workers=numjobs)\n\n    w_bw = []\n    w_size = num_page * numjobs * bytes_per_page / (1 << 30)\n    for i in tqdm(range(warmup + iteration), desc=\"Benchmarking write (GB/s)\"):\n        _offsets = [\n            [\n                offset * bytes_per_page\n                for offset in offsets[\n                    (i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page\n                ]\n            ]\n            for j in range(numjobs)\n        ]\n        tik = time.perf_counter()\n        futures = [\n            executor.submit(file_ops[j].batch_write, offset, tensors_write)\n            for j, offset in enumerate(_offsets)\n        ]\n        results = [future.result() for future in futures]\n        tok = time.perf_counter()\n        if i < warmup:\n            continue\n        w_bw.append(w_size / (tok - tik))\n        results = [\n            _result == bytes_per_page for result in results for _result in result\n        ]\n        assert all(results)\n    print_stats(w_bw)\n\n    r_bw = []\n    r_size = w_size\n    for i in tqdm(range(warmup + iteration), desc=\"Benchmarking read (GB/s)\"):\n        _offsets = [\n            
[\n                offset * bytes_per_page\n                for offset in offsets[\n                    (i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page\n                ]\n            ]\n            for j in range(numjobs)\n        ]\n        tik = time.perf_counter()\n        futures = [\n            executor.submit(file_ops[j].batch_read, offset, tensors_read)\n            for j, offset in enumerate(_offsets)\n        ]\n        results = [future.result() for future in futures]\n        tok = time.perf_counter()\n        if i < warmup:\n            continue\n        r_bw.append(r_size / (tok - tik))\n        results = [\n            _result == bytes_per_page for result in results for _result in result\n        ]\n        assert all(results)\n    print_stats(r_bw)\n\n    executor.shutdown(wait=True)\n    for _file_ops in file_ops:\n        _file_ops.close()\n    print(\"bench done\")\n\n\ndef main():\n    logging.basicConfig(level=logging.INFO)\n    test()\n    bench()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/hf3fs/bench_storage.py",
    "content": "import json\nimport logging\nimport os\nimport random\nimport time\nfrom typing import List\n\nimport torch\nfrom tqdm import tqdm\n\nfrom sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (\n    Hf3fsLocalMetadataClient,\n)\nfrom sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS\n\n\ndef print_stats(x: List[int]):\n    x = sorted(x)\n    lenx = len(x)\n    print(\n        f\"mean = {sum(x)/len(x):.2f}, \"\n        f\"min = {min(x):.2f}, \"\n        f\"p25 = {x[int(lenx*0.25)]:.2f}, \"\n        f\"p50 = {x[int(lenx*0.5)]:.2f}, \"\n        f\"p75 = {x[int(lenx*0.75)]:.2f}, \"\n        f\"max = {max(x):.2f}\"\n    )\n\n\ndef test():\n    # Qwen3-32B\n    layer_num = 64\n    head_num, head_dim = 8, 128\n    kv_lora_rank, qk_rope_head_dim = 0, 0\n    store_dtype = torch.bfloat16\n    tokens_per_page = 64\n\n    file_path_prefix = \"/data/test\"\n    file_size = 128 << 20\n    numjobs = 16\n    bytes_per_page = 16 << 20\n    entries = 2\n    dtype = store_dtype\n\n    config_path = os.getenv(HiCacheHF3FS.default_env_var)\n    assert config_path\n    try:\n        with open(config_path, \"w\") as f:\n            json.dump(\n                {\n                    \"file_path_prefix\": file_path_prefix,\n                    \"file_size\": file_size,\n                    \"numjobs\": numjobs,\n                    \"entries\": entries,\n                },\n                f,\n            )\n    except Exception as e:\n        raise RuntimeError(f\"Failed to dump config to {config_path}: {str(e)}\")\n    hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype)\n\n    numel = 2 * tokens_per_page * layer_num * head_num * head_dim\n    assert numel * dtype.itemsize == bytes_per_page\n\n    num_pages = 10\n    tensors = {}\n    for i in range(num_pages):\n        k = f\"key_{i}\"\n        v = torch.randn((numel,)).to(dtype=dtype)\n        ok = hicache_hf3fs.set(k, v)\n        if i < (file_size // bytes_per_page):\n    
        assert ok, f\"Failed to insert {k}\"\n        else:\n            assert not ok\n        tensors[k] = v\n    assert hicache_hf3fs.get(\"key_8\") is None\n    assert hicache_hf3fs.get(\"key_9\") is None\n\n    start = 0\n    for i in range(start, start + hicache_hf3fs.num_pages):\n        k = f\"key_{i}\"\n        assert hicache_hf3fs.exists(k)\n        out = hicache_hf3fs.get(k)\n        assert out is not None\n        v = tensors[k]\n        assert torch.allclose(v, out, atol=1e-3), f\"Tensor mismatch for {k}\"\n\n    assert not hicache_hf3fs.exists(\"not_exists\")\n\n    hicache_hf3fs.delete(\"key_7\")\n    v2 = torch.randn((numel,)).to(dtype=dtype)\n    assert hicache_hf3fs.set(\"key_new\", v2)\n    assert torch.allclose(hicache_hf3fs.get(\"key_new\"), v2, atol=1e-3)\n\n    hicache_hf3fs.clear()\n    assert (\n        len(hicache_hf3fs.metadata_client.rank_metadata.free_pages)\n        == hicache_hf3fs.metadata_client.rank_metadata.num_pages\n    )\n\n    # batch\n    num_pages = 10\n    tensors = {}\n    keys = []\n    values = []\n    for i in range(num_pages):\n        k = f\"key_{i}\"\n        keys.append(k)\n        v = torch.randn((numel,)).to(dtype=dtype)\n        values.append(v)\n\n    ok = hicache_hf3fs.batch_set(keys, values)\n    assert not ok\n    assert hicache_hf3fs.get(\"key_8\") is None\n    assert hicache_hf3fs.get(\"key_9\") is None\n\n    results = hicache_hf3fs.batch_get(keys[: hicache_hf3fs.num_pages])\n    for result, key, value in zip(\n        results, keys[: hicache_hf3fs.num_pages], values[: hicache_hf3fs.num_pages]\n    ):\n        assert torch.allclose(value, result, atol=1e-3), f\"Tensor mismatch for {key}\"\n\n    hicache_hf3fs.close()\n    os.remove(hicache_hf3fs.file_path)\n\n    print(\"All test cases passed.\")\n\n\ndef bench():\n    # Qwen3-32B\n    layer_num = 64\n    head_num, head_dim = 8, 128\n    kv_lora_rank, qk_rope_head_dim = 0, 0\n    store_dtype = torch.bfloat16\n    tokens_per_page = 64\n\n    file_path = 
\"/data/test.bin\"\n    file_size = 1 << 40\n    numjobs = 16\n    bytes_per_page = 16 << 20\n    entries = 8\n    dtype = store_dtype\n    hicache_hf3fs = HiCacheHF3FS(\n        rank=0,\n        file_path=file_path,\n        file_size=file_size,\n        numjobs=numjobs,\n        bytes_per_page=bytes_per_page,\n        entries=entries,\n        dtype=dtype,\n        metadata_client=Hf3fsLocalMetadataClient(),\n    )\n\n    numel = 2 * tokens_per_page * layer_num * head_num * head_dim\n    assert numel * dtype.itemsize == bytes_per_page\n\n    num_page = 128\n    values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))]\n\n    warmup = 50\n    iteration = 100\n\n    w_bw = []\n    w_size = num_page * bytes_per_page / (1 << 30)\n    for i in tqdm(range(warmup + iteration), desc=\"Benchmarking write (GB/s)\"):\n        keys = [f\"{j}\" for j in range(i * num_page, (i + 1) * num_page)]\n        tik = time.perf_counter()\n        ok = hicache_hf3fs.batch_set(keys, values)\n        tok = time.perf_counter()\n        if i < warmup:\n            continue\n        w_bw.append(w_size / (tok - tik))\n        assert ok\n    print_stats(w_bw)\n\n    r_bw = []\n    r_size = num_page * bytes_per_page / (1 << 30)\n    for i in tqdm(range(warmup + iteration), desc=\"Benchmarking read (GB/s)\"):\n        keys = random.sample(\n            list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),\n            num_page,\n        )\n        tik = time.perf_counter()\n        results = hicache_hf3fs.batch_get(keys)\n        tok = time.perf_counter()\n        if i < warmup:\n            continue\n        r_bw.append(r_size / (tok - tik))\n        assert all([r is not None for r in results])\n    print_stats(r_bw)\n\n    hicache_hf3fs.close()\n\n\ndef allclose():\n    # Qwen3-32B\n    layer_num = 64\n    head_num, head_dim = 8, 128\n    kv_lora_rank, qk_rope_head_dim = 0, 0\n    store_dtype = torch.bfloat16\n    tokens_per_page = 64\n\n    file_path = 
\"/data/test.bin\"\n    file_size = 1 << 40\n    numjobs = 16\n    bytes_per_page = 16 << 20\n    entries = 8\n    dtype = store_dtype\n    hicache_hf3fs = HiCacheHF3FS(\n        rank=0,\n        file_path=file_path,\n        file_size=file_size,\n        numjobs=numjobs,\n        bytes_per_page=bytes_per_page,\n        entries=entries,\n        dtype=dtype,\n        metadata_client=Hf3fsLocalMetadataClient(),\n    )\n\n    numel = 2 * tokens_per_page * layer_num * head_num * head_dim\n    assert numel * dtype.itemsize == bytes_per_page\n\n    num_page = 128\n    values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))]\n\n    iteration = 100\n\n    for i in tqdm(range(iteration), desc=\"Benchmarking write (GB/s)\"):\n        keys = [f\"{j}\" for j in range(i * num_page, (i + 1) * num_page)]\n        ok = hicache_hf3fs.batch_set(keys, values)\n        assert ok\n\n    read_keys, read_results = [], []\n    for i in tqdm(range(iteration), desc=\"Benchmarking read (GB/s)\"):\n        keys = random.sample(\n            list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),\n            num_page,\n        )\n        results = hicache_hf3fs.batch_get(keys)\n        read_keys.extend(keys)\n        read_results.extend(results)\n        assert all([r is not None for r in results])\n\n    for key, result in tqdm(zip(read_keys, read_results)):\n        assert torch.allclose(values[int(key) % num_page], result, atol=1e-3)\n\n    hicache_hf3fs.close()\n\n\ndef main():\n    logging.basicConfig(level=logging.INFO)\n    test()\n    bench()\n    allclose()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/hf3fs/bench_zerocopy.py",
    "content": "import threading\nimport time\n\nimport torch\nfrom tqdm import tqdm\n\nfrom sglang.srt.distributed import (\n    get_world_group,\n    init_distributed_environment,\n    initialize_model_parallel,\n)\nfrom sglang.srt.managers.cache_controller import (\n    HiCacheController,\n    PrefetchOperation,\n    StorageOperation,\n)\nfrom sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator\nfrom sglang.srt.mem_cache.memory_pool import MHATokenToKVPool\nfrom sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost\n\ninit_distributed_environment(\n    world_size=1,\n    rank=0,\n    distributed_init_method=\"tcp://127.0.0.1:23456\",\n    local_rank=0,\n    backend=\"gloo\",\n)\n\ninitialize_model_parallel(\n    tensor_model_parallel_size=1,\n    pipeline_model_parallel_size=1,\n)\n\ngroup = get_world_group().cpu_group\n\nmax_total_num_tokens = 524288\npage_size = 64\nkv_cache_dtype = torch.bfloat16\nlayer_num = 64\nhead_num, head_dim = 8, 128\ndevice = \"cuda\"\nhicache_ratio = 2\nhicache_size = 0\nhicache_mem_layout = \"page_first\"\n# hicache_mem_layout = \"layer_first\"\nhicache_write_policy = \"write_through\"\nhicache_io_backend = \"kernel\"\nhicache_storage_backend = \"hf3fs\"\nprefetch_threshold = 256\n\nop_size = 1024\nop_num = 16\n\ntoken_to_kv_pool = MHATokenToKVPool(\n    max_total_num_tokens,\n    page_size=page_size,\n    dtype=kv_cache_dtype,\n    head_num=head_num,\n    head_dim=head_dim,\n    layer_num=layer_num,\n    device=device,\n    enable_memory_saver=True,\n)\n\ntoken_to_kv_pool_allocator = TokenToKVPoolAllocator(\n    max_total_num_tokens,\n    dtype=kv_cache_dtype,\n    device=device,\n    kvcache=token_to_kv_pool,\n    need_sort=False,\n)\n\nkv_cache = token_to_kv_pool_allocator.get_kvcache()\ntoken_to_kv_pool_host = MHATokenToKVPoolHost(\n    kv_cache,\n    hicache_ratio,\n    hicache_size,\n    page_size,\n    hicache_mem_layout,\n)\n\nload_cache_event = threading.Event()\ncache_controller = HiCacheController(\n  
  token_to_kv_pool_allocator,\n    token_to_kv_pool_host,\n    page_size,\n    group,\n    load_cache_event=load_cache_event,\n    write_policy=hicache_write_policy,\n    io_backend=hicache_io_backend,\n    storage_backend=hicache_storage_backend,\n    prefetch_threshold=prefetch_threshold,\n)\n\noperations = [\n    StorageOperation(\n        torch.tensor(list(range(i, i + op_size))),\n        list(range(i, i + op_size)),\n        hash_value=[f\"{j}\" for j in range(i, i + op_size, page_size)],\n    )\n    for i in tqdm(range(0, op_num * op_size, op_size))\n]\n\ntik = time.monotonic()\nif hicache_mem_layout == \"page_first\":\n    for operation in operations:\n        cache_controller.zerocopy_page_backup(operation, batch_size=128)\nelif hicache_mem_layout == \"layer_first\":\n    for operation in operations:\n        cache_controller.generic_page_backup(operation, batch_size=128)\ntok = time.monotonic()\nprint(f\"{tok-tik:.6f} s\")\n\noperations = [\n    PrefetchOperation(\n        f\"{i}\",\n        torch.tensor(list(range(i, i + op_size))),\n        list(range(i, i + op_size)),\n        f\"{i}\",\n    )\n    for i in tqdm(range(0, op_num * op_size, op_size))\n]\n\nfor operation in operations:\n    operation.hash_value = [\n        f\"{j}\"\n        for j in range(\n            int(operation.last_hash), int(operation.last_hash) + op_size, page_size\n        )\n    ]\n\ntik = time.monotonic()\nif hicache_mem_layout == \"page_first\":\n    for operation in operations:\n        cache_controller.zerocopy_page_transfer(operation, batch_size=128)\nelif hicache_mem_layout == \"layer_first\":\n    for operation in operations:\n        cache_controller.generic_page_transfer(operation, batch_size=128)\ntok = time.monotonic()\nprint(f\"{tok-tik:.6f} s\")\n"
  },
  {
    "path": "benchmark/hicache/README.md",
    "content": "## Run synthetic multi-turn benchmark\n\n```\n# SGLang server with radix cache disabled\npython -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --disable-radix-cache\n\n# SGLang server with radix cache on and first-come-first-serve policy\npython -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --schedule-policy fcfs\n\n# The default SGLang server with radix cache on and long-prefix-match policy\npython -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000\n\n# SGLang server with hierarchical radix cache enabled\npython -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --enable-hierarchical-cache\n\n```\n\n```\npython bench_multiturn.py --model-path Qwen/Qwen2.5-14B-Instruct\n```\n\nNote: The performance gain of hierarchical caching depends on the ratio of reusable tokens to GPU memory capacity. The more tokens to be reused, the larger the model, and the more constrained the GPU memory size, the greater the benefit one can expect from hierarchical caching.\n\n\n# Benchmark with more datasets\n## Download Dataset\n```bash\n./download.sh {sharegpt|ultragpt|loogle|nextqa|all}\n```\nThis script will automatically download the required dataset to the current working directory\n\n## Multiturn Benchmark\n### Supported Datasets\n- sharegpt\n- ultrachat\n- loogle\n### Example Usage:\n```bash\npython3 bench_serving.py --model mistralai/Mistral-7B-Instruct-v0.3 --backend sglang \\\n--dataset-path longdep_qa.json --dataset-name loogle --request-rate 10 --num-prompts 10  \\\n--port 8001 --enable-multiturn --disable-shuffle\n```\nThis uses `mistralai/Mistral-7B-Instruct-v0.3` model with `sglang` as backend. The dataset\nis `longdep_qa.json`. We send `10 conversations` with `10 req/s` to port 8001. We enable\nmultiturn chat without shuffling the order of conversations (i.e. 
following the original\norder in the dataset file).\n\n### Note:\nRequests from multiple conversations are sent in round-robin fashion.\nFor example, if we have 3 conversations A, B, C with `[2, 3, 4]` rounds respectively,\nmultiturn chat will send the requests to the backend in the following order: `[A1, B1, C1, A2, B2, C2, B3, C3, C4]`.\nThis affects cache reuse: the reuse distance is largest\nunder this request pattern, so a prefix-aware local scheduler in the backend can\nyield the most benefit over a FIFO scheduler.\n\n## Shared Prefix Benchmark\n### Supported Datasets\n- loogle\n### Example Usage:\n```bash\npython3 bench_serving.py --model mistralai/Mistral-7B-Instruct-v0.3 --backend sglang \\\n--dataset-path longdep_qa.json --dataset-name loogle --request-rate 10 --num-prompts 10  \\\n--port 8001 --enable-shared-prefix --disable-shuffle\n```\n### Note:\nThe shared prefix benchmark sends all questions for the same prompt together. 
For example,\nif we have 3 shared prefixes A, B, C with [2, 3, 4] questions respectively,\nthe shared prefix benchmark will send the requests to the\nbackend in the following order: `[A+Q1, A+Q2, B+Q1, B+Q2, B+Q3, C+Q1, C+Q2, C+Q3, C+Q4]`.\n\n\n## Multi Modality Benchmark (WIP)\n### Supported Datasets:\n- nextqa\n### Example Usage:\n```bash\n# Server:\npython3 -m sglang.launch_server --model-path lmms-lab/LLaVA-NeXT-Video-7B  --tp 2 --dp 1 --port 8001 \\\n--host 0.0.0.0 --mem-fraction-static 0.9 --tokenizer-path llava-hf/llava-1.5-7b-hf \\\n--json-model-override-args \"{\\\"architectures\\\": [\\\"LlavaVidForCausalLM\\\"], \\\"model_type\\\":\\\"llava\\\", \\\"mm_spatial_pool_stride\\\":2}\"\n\n# Client:\npython3 bench_serving.py --model lmms-lab/LLaVA-NeXT-Video-7B --backend sglang  --dataset-path \\\nNExTVideo  --dataset-name nextqa --request-rate 10 --num-prompts 1 --disable-shuffle --port 8001 \\\n--enable-multiturn --max-frames 16 --tokenizer llava-hf/llava-1.5-7b-hf --fixed-output-len 2048\n```\nNote: for the server, `--tokenizer-path` and the architecture override passed via `--json-model-override-args` are required.\n\n## Supported Backend\n- sglang (oai)\n- vllm (oai)\n- lmdeploy (oai)\n"
  },
  {
    "path": "benchmark/hicache/bench_long_context.py",
    "content": "import json\nimport queue\nimport time\n\nimport requests\nfrom bench_multiturn import (\n    ReadyQueue,\n    WorkloadGenerator,\n    gen_payload,\n    log_to_jsonl_file,\n    parse_args,\n)\nfrom tqdm.asyncio import tqdm\n\nfrom sglang.benchmark.utils import get_tokenizer\n\n\nclass ContextWorkloadGenerator(WorkloadGenerator):\n    def __init__(self, args):\n        # Construct the base URL for requests\n        self.baseurl = f\"http://{args.host}:{args.port}/\"\n        self.url = self.baseurl + \"generate\"\n\n        self.tokenizer = get_tokenizer(args.model_path)\n        self.distribution = args.distribution\n        self.request_rate = args.request_rate\n        self.start_time = None\n        self.finished_time = None\n\n        self.sent_requests = 0\n        self.completed_requests = 0\n\n        self.dataset = json.load(open(args.dataset_path))\n        num_requests = min(args.num_clients, len(self.dataset[\"queries\"]))\n\n        init_requests = []\n        for i in range(num_requests):\n            context_id = self.dataset[\"queries\"][i][\"context\"]\n            # Tokenize the context + question to get input_ids\n            prompt_text = (\n                self.dataset[\"contexts\"][context_id]\n                + self.dataset[\"queries\"][i][\"question\"]\n            )\n            input_ids = self.tokenizer.encode(prompt_text)\n            output_len = len(\n                self.tokenizer(self.dataset[\"queries\"][i][\"reference_answer\"])[\n                    \"input_ids\"\n                ]\n            )\n            init_requests.append((i, gen_payload(input_ids, output_len)))\n        self.ready_queue = ReadyQueue(init_requests=init_requests)\n\n        self.response_queue = queue.Queue()\n        self.pbar = tqdm(total=num_requests)\n        self.performance_metrics = {\n            \"ttft\": [],\n            \"latency\": [],\n            \"itl\": [],\n            \"prompt_len\": [],\n            \"cached_tokens\": [],\n 
           \"generated_len\": [],\n        }\n\n        self.max_parallel = args.max_parallel\n        self.logfile = args.log_file\n        self.enable_round_barrier = False\n\n    def response_handler(self):\n        while True:\n            try:\n                client_id, response = self.response_queue.get(\n                    timeout=10\n                )  # Block until response is available\n                if not response.success:\n                    raise ValueError(f\"Request failed with error: {response.error}\")\n                self.performance_metrics[\"ttft\"].append(response.ttft)\n                self.performance_metrics[\"itl\"].extend(response.itl)\n                self.performance_metrics[\"latency\"].append(response.latency)\n                self.performance_metrics[\"prompt_len\"].append(response.prompt_len)\n                self.performance_metrics[\"cached_tokens\"].append(response.cached_tokens)\n                self.performance_metrics[\"generated_len\"].append(response.generated_len)\n                self.completed_requests += 1\n\n            except queue.Empty:\n                if self.pbar.n == self.pbar.total:\n                    break\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    args.num_rounds = 1\n    args.max_parallel = 24\n    flush_cache_url = f\"http://{args.host}:{args.port}/flush_cache\"\n\n    for request_rate in [24, 16, 12, 8, 4, 2, 1]:\n        args.request_rate = request_rate\n        requests.post(flush_cache_url)\n        time.sleep(1)\n        performance_data = ContextWorkloadGenerator(args).run()\n        log_to_jsonl_file(performance_data, args.log_file, args.tag)\n"
  },
  {
    "path": "benchmark/hicache/bench_mix.py",
    "content": "import argparse\nimport asyncio\nimport json\nimport logging\nimport os\nimport queue\nimport random\nimport threading\nimport time\nfrom dataclasses import dataclass\nfrom functools import wraps\n\nimport aiohttp\n\nfrom sglang.bench_serving import RequestFuncOutput\nfrom sglang.benchmark.datasets.random import sample_random_requests\nfrom sglang.benchmark.utils import get_tokenizer, remove_prefix\n\n# Set up logger\nlogger = logging.getLogger(__name__)\n\n# Set up JSONL file for debug logging\ndebug_log_file = None\n# Create a lock for thread-safe debug log writing\ndebug_log_lock = threading.Lock()\n\n\ndef write_debug_log(data):\n    global debug_log_file\n\n    \"\"\"Write debug information to a JSONL file\"\"\"\n    if debug_log_file is None:\n        return\n\n    # Acquire lock for thread-safe writing\n    with debug_log_lock:\n        # Write as JSONL (JSON Line format)\n        debug_log_file.write(json.dumps(data) + \"\\n\")\n        debug_log_file.flush()\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Script to benchmark concurrent requests to a server.\"\n    )\n    parser.add_argument(\n        \"--model-path\",\n        type=str,\n        default=\"/data/models/Qwen3-0.6B\",\n        help=\"model path compatible with Hugging Face Transformers\",\n    )\n    parser.add_argument(\n        \"--dataset-path\",\n        type=str,\n        default=\"/data/models/ShareGPT_V3_unfiltered_cleaned_split/ShareGPT_V3_unfiltered_cleaned_split.json\",\n        help=\"local dataset to sample tokens from\",\n    )\n    parser.add_argument(\n        \"--host\",\n        type=str,\n        default=\"localhost\",\n        help=\"Server hostname or IP (default: localhost)\",\n    )\n    parser.add_argument(\n        \"--port\",\n        type=int,\n        default=30000,\n        help=\"Server port (default: 30000)\",\n    )\n    parser.add_argument(\n        \"--duration\",\n        type=int,\n        default=600,\n   
     help=\"Duration to run the benchmark in seconds (default: 600 seconds)\",\n    )\n    parser.add_argument(\n        \"--log-level\",\n        type=str,\n        default=\"info\",\n        choices=[\"debug\", \"info\"],\n        help=\"Set the logging level (default: info)\",\n    )\n    parser.add_argument(\n        \"--debug-log-file\",\n        type=str,\n        default=\"debug.log.jsonl\",\n        help=\"File to write debug logs in JSONL format\",\n    )\n    return parser.parse_args()\n\n\ndef load_config():\n    config_path = os.getenv(\"CONFIG_PATH\")\n    if not config_path:\n        raise ValueError(\"Environment variable 'CONFIG_PATH' is not set.\")\n\n    with open(config_path, \"r\") as f:\n        config = json.load(f)\n\n    required_keys = [\n        \"num_rounds\",\n        \"num_clients\",\n        \"round_ratios\",\n        \"mean_new_tokens_per_round\",\n        \"mean_return_tokens_per_round\",\n        \"mean_inter_round_interval\",\n    ]\n\n    for key in required_keys:\n        if key not in config:\n            raise KeyError(f\"Missing required configuration key: {key}\")\n\n    num_rounds = config[\"num_rounds\"]\n    assert len(config[\"round_ratios\"]) == num_rounds\n    assert len(config[\"mean_new_tokens_per_round\"]) == num_rounds\n    assert len(config[\"mean_return_tokens_per_round\"]) == num_rounds\n    assert len(config[\"mean_inter_round_interval\"]) == num_rounds\n\n    print(config)\n\n    return config\n\n\n@dataclass\nclass UserData:\n    user_id: int\n    current_round: int\n    total_rounds: int\n    prompt: str\n    return_tokens: int\n    start: float  # perf_counter() timestamp at which the next round may be sent\n\n\ndef synchronized():\n    def _decorator(func):\n        @wraps(func)\n        def wrapper(self, *args, **kwargs):\n            with self.lock:\n                return func(self, *args, **kwargs)\n\n        return wrapper\n\n    return _decorator\n\n\nclass UserGenerator:\n    def 
__init__(self, config, model_path, dataset_path):\n        self.tokenizer_path = 
model_path\n        self.tokenizer = get_tokenizer(self.tokenizer_path)\n        self.dataset_path = dataset_path\n\n        self.user_id = 0\n        self.lock = threading.Lock()\n\n        self.num_rounds = config[\"num_rounds\"]\n\n        self.cumulative_ratios = [\n            sum(config[\"round_ratios\"][: i + 1])\n            for i in range(len(config[\"round_ratios\"]))\n        ]\n        self.mean_new_tokens_per_round = config[\"mean_new_tokens_per_round\"]\n        self.mean_return_tokens_per_round = config[\"mean_return_tokens_per_round\"]\n        self.mean_inter_round_interval = config[\"mean_inter_round_interval\"]\n\n        self.sigma = 100\n        self.range_ratio = 0.8\n        assert self.range_ratio <= 1\n\n        self.candidate_inputs = [\n            [\n                r\n                for r in sample_random_requests(\n                    input_len=(\n                        self.mean_new_tokens_per_round[i] * (2 - self.range_ratio)\n                    ),\n                    output_len=(\n                        self.mean_return_tokens_per_round[i] * (2 - self.range_ratio)\n                    ),\n                    num_prompts=config[\"num_clients\"],\n                    range_ratio=self.range_ratio / (2 - self.range_ratio),\n                    tokenizer=self.tokenizer,\n                    dataset_path=self.dataset_path,\n                    random_sample=False,\n                )\n            ]\n            for i in range(self.num_rounds)\n        ]\n\n        self.multiturn_queue = []\n\n        self.user_stats = [0 for _ in range(self.num_rounds)]\n        self.input_stats = [[0, 0] for _ in range(self.num_rounds)]\n        self.output_stats = [[0, 0] for _ in range(self.num_rounds)]\n\n    def gen(self):\n        user_id = self.user_id\n        self.user_id += 1\n\n        rand_ratio = random.randint(0, self.cumulative_ratios[-1])\n        i = len(self.cumulative_ratios)\n        for idx, cumulative_ratio in 
enumerate(self.cumulative_ratios):\n            if rand_ratio >= cumulative_ratio:\n                continue\n            else:\n                i = idx + 1\n                break\n        total_rounds = i\n        current_round = 0\n\n        candidate_input = random.sample(self.candidate_inputs[current_round], 1)[0]\n        self.input_stats[0][0] += candidate_input.prompt_len\n        self.input_stats[0][1] += 1\n        prompt = f\"{user_id} \" + candidate_input.prompt\n        return_tokens = int(\n            random.gauss(self.mean_return_tokens_per_round[current_round], self.sigma)\n        )\n        if return_tokens <= 0:\n            return_tokens = self.mean_return_tokens_per_round[current_round]\n        start = 0\n\n        user_data = UserData(\n            user_id, current_round, total_rounds, prompt, return_tokens, start\n        )\n\n        self.user_stats[total_rounds - 1] += 1\n\n        return user_data\n\n    @synchronized()\n    def push(self, user_data, generated_text, len_itl):\n        self.output_stats[user_data.current_round][0] += len_itl + 1\n        self.output_stats[user_data.current_round][1] += 1\n        user_data.current_round += 1\n        if user_data.current_round >= user_data.total_rounds:\n            return\n\n        candidate_input = random.sample(\n            self.candidate_inputs[user_data.current_round], 1\n        )[0]\n        self.input_stats[user_data.current_round][0] += candidate_input.prompt_len\n        self.input_stats[user_data.current_round][1] += 1\n        user_data.prompt += generated_text + candidate_input.prompt\n        user_data.return_tokens = int(\n            random.gauss(\n                self.mean_return_tokens_per_round[user_data.current_round], self.sigma\n            )\n        )\n        if user_data.return_tokens <= 0:\n            user_data.return_tokens = self.mean_return_tokens_per_round[\n                user_data.current_round\n            ]\n        interval = random.gauss(\n          
  self.mean_inter_round_interval[user_data.current_round], self.sigma\n        )\n        if interval <= 0:\n            interval = self.mean_inter_round_interval[user_data.current_round]\n        user_data.start = time.perf_counter() + interval\n\n        if len(self.multiturn_queue) == 0:\n            self.multiturn_queue.append(user_data)\n        else:\n            # Keep the queue sorted by start time; default to appending at the end\n            i = len(self.multiturn_queue)\n            for idx, d in enumerate(self.multiturn_queue):\n                if user_data.start < d.start:\n                    i = idx\n                    break\n            self.multiturn_queue.insert(i, user_data)\n\n    @synchronized()\n    def pop(self):\n        if (\n            len(self.multiturn_queue)\n            and time.perf_counter() > self.multiturn_queue[0].start\n        ):\n            return self.multiturn_queue.pop(0)\n        return self.gen()\n\n\ndef gen_payload(prompt, output_len):\n    payload = {\n        \"text\": prompt,\n        \"sampling_params\": {\n            \"temperature\": 0.0,\n            \"max_new_tokens\": output_len,\n            \"ignore_eos\": True,\n        },\n        \"stream\": True,\n        \"stream_options\": {\"include_usage\": True},\n        \"lora_path\": \"\",\n        \"return_logprob\": False,\n        \"logprob_start_len\": -1,\n    }\n    return payload\n\n\nAIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)\n\n\nasync def async_request_sglang_generate(\n    user_data,\n    url,\n    atomic_counter,\n):\n    \"\"\"\n    Sends a streaming request to the server. 
Gathers text token-by-token.\n    \"\"\"\n    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:\n        headers = {}\n        generated_text = \"\"\n        ttft = 0.0\n        st = time.perf_counter()\n        most_recent_timestamp = st\n        output = RequestFuncOutput()\n        payload = gen_payload(user_data.prompt, user_data.return_tokens)\n        write_debug_log({\"timestamp\": st, \"user_data\": user_data.__dict__})\n\n        try:\n            async with session.post(url=url, json=payload, headers=headers) as response:\n                if response.status == 200:\n                    prompt_tokens = 0\n                    cached_tokens = 0\n                    async for chunk_bytes in response.content:\n                        chunk_bytes = chunk_bytes.strip()\n                        if not chunk_bytes:\n                            continue\n\n                        chunk = remove_prefix(chunk_bytes.decode(\"utf-8\"), \"data: \")\n                        latency = time.perf_counter() - st\n                        if chunk == \"[DONE]\":\n                            pass\n                        else:\n                            data = json.loads(chunk)\n\n                            if data.get(\"text\"):\n                                timestamp = time.perf_counter()\n                                # First token\n                                if ttft == 0.0:\n                                    ttft = time.perf_counter() - st\n                                    output.ttft = ttft\n                                    prompt_tokens = (data.get(\"meta_info\") or {}).get(\n                                        \"prompt_tokens\", 0\n                                    )\n                                    cached_tokens = (data.get(\"meta_info\") or {}).get(\n                                        \"cached_tokens\", 0\n                                    )\n\n                                # Decoding phase\n                   
             else:\n                                    output.itl.append(timestamp - most_recent_timestamp)\n\n                                most_recent_timestamp = timestamp\n                                generated_text = data[\"text\"]\n\n                    output.generated_text = generated_text\n                    output.success = True\n                    output.latency = latency\n                    output.prompt_len = prompt_tokens\n                    output.cached_tokens = cached_tokens\n                else:\n                    output.error = response.reason or \"\"\n                    output.success = False\n        except Exception as e:\n            output.success = False\n            output.error = str(e)\n            print(f\"Request failed: {e}\")\n\n    atomic_counter.increment(1)\n    return output\n\n\nclass AtomicCounter:\n    def __init__(self, initial_value=0):\n        self._value = initial_value\n        self.lock = threading.Lock()\n\n    @synchronized()\n    def increment(self, amount=1):\n        self._value += amount\n\n    @synchronized()\n    def get(self):\n        return self._value\n\n\nclass WorkloadGenerator:\n    def __init__(self, args):\n        config = load_config()\n        user_generator = UserGenerator(\n            config,\n            args.model_path,\n            args.dataset_path,\n        )\n\n        self.url = f\"http://{args.host}:{args.port}/generate\"\n\n        self.tokenizer = user_generator.tokenizer\n        self.start_time = None\n        self.finished_time = None\n        self.duration = args.duration\n        self.done = False\n\n        self.sent_requests = 0\n        self.completed_requests = 0\n\n        self.user_generator = user_generator\n        self.response_queue = queue.Queue()\n        self.performance_metrics = {\n            \"ttft\": [],\n            \"latency\": [],\n            \"prompt_len\": [],\n            \"cached_tokens\": [],\n        }\n        self.max_parallel = 
config[\"num_clients\"]\n\n        self.atomic_counter = AtomicCounter()\n\n    async def handle_request(self, user_data):\n        try:\n            response = await async_request_sglang_generate(\n                user_data, self.url, self.atomic_counter\n            )\n            self.response_queue.put((user_data, response))\n        except Exception as e:\n            print(f\"Request failed: {e}\")\n            self.completed_requests += 1\n\n    def request_sender(self):\n        async def request_loop():\n            while True:\n                if self.sent_requests - self.completed_requests < self.max_parallel:\n                    new_request = self.user_generator.pop()\n                    if new_request:\n                        asyncio.create_task(self.handle_request(new_request))\n                        self.sent_requests += 1\n                else:\n                    await asyncio.sleep(0.05)\n                    continue\n\n                if time.perf_counter() - self.start_time > self.duration:\n                    self.done = True\n                    break\n\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n        loop.run_until_complete(request_loop())\n        loop.close()\n\n    def response_handler(self):\n        while True:\n            try:\n                user_data, response = self.response_queue.get(timeout=10)\n                logger.info(\n                    f\"{((time.perf_counter()-self.start_time)/self.duration*100):.2f}%\"\n                )\n                if not response.success:\n                    raise ValueError(f\"Request failed with error: {response.error}\")\n\n                self.user_generator.push(\n                    user_data, response.generated_text, len(response.itl)\n                )\n                self.performance_metrics[\"ttft\"].append(response.ttft)\n                self.performance_metrics[\"latency\"].append(response.latency)\n                
self.performance_metrics[\"prompt_len\"].append(response.prompt_len)\n                self.performance_metrics[\"cached_tokens\"].append(response.cached_tokens)\n                self.completed_requests += 1\n                self.finished_time = time.perf_counter()\n\n            except queue.Empty:\n                if self.done:\n                    break\n            except ValueError as e:\n                print(f\"Error processing response for client {user_data}: {e}\")\n                continue\n\n    def run(self):\n        request_thread = threading.Thread(target=self.request_sender, daemon=True)\n        response_thread = threading.Thread(target=self.response_handler, daemon=True)\n\n        self.start_time = time.perf_counter()\n        request_thread.start()\n        response_thread.start()\n\n        request_thread.join()\n        response_thread.join()\n\n        performance_data = {\n            \"summary\": {\n                \"total_requests\": len(self.performance_metrics[\"ttft\"]),\n                \"average_ttft\": sum(self.performance_metrics[\"ttft\"])\n                / len(self.performance_metrics[\"ttft\"]),\n                \"p90_ttft\": sorted(self.performance_metrics[\"ttft\"])[\n                    int(0.9 * len(self.performance_metrics[\"ttft\"]))\n                ],\n                \"median_ttft\": sorted(self.performance_metrics[\"ttft\"])[\n                    len(self.performance_metrics[\"ttft\"]) // 2\n                ],\n                \"average_latency\": sum(self.performance_metrics[\"latency\"])\n                / len(self.performance_metrics[\"latency\"]),\n                \"p90_latency\": sorted(self.performance_metrics[\"latency\"])[\n                    int(0.9 * len(self.performance_metrics[\"latency\"]))\n                ],\n                \"median_latency\": sorted(self.performance_metrics[\"latency\"])[\n                    len(self.performance_metrics[\"latency\"]) // 2\n                ],\n                
\"throughput\": self.atomic_counter.get()\n                / (self.finished_time - self.start_time),\n                \"cache_hit_rate\": (\n                    0\n                    if sum(self.performance_metrics[\"prompt_len\"]) == 0\n                    else sum(self.performance_metrics[\"cached_tokens\"])\n                    / sum(self.performance_metrics[\"prompt_len\"])\n                ),\n            },\n        }\n        print(\"All requests completed\")\n        print(\"Performance metrics summary:\")\n        print(f\"  Total requests: {performance_data['summary']['total_requests']}\")\n        print(f\"  Average TTFT: {performance_data['summary']['average_ttft']:.2f}\")\n        print(f\"  P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}\")\n        print(f\"  Median TTFT: {performance_data['summary']['median_ttft']:.2f}\")\n        print(\n            f\"  Average latency: {performance_data['summary']['average_latency']:.2f}\"\n        )\n        print(f\"  P90 latency: {performance_data['summary']['p90_latency']:.2f}\")\n        print(f\"  Median latency: {performance_data['summary']['median_latency']:.2f}\")\n        print(\n            f\"  Throughput: {performance_data['summary']['throughput']:.2f} requests per second\"\n        )\n        print(f\"  Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}\")\n\n        user_stats = self.user_generator.user_stats\n        input_stats = self.user_generator.input_stats\n        output_stats = self.user_generator.output_stats\n        print(f\"round_ratios: {user_stats}\")\n        print(\n            f\"mean_new_tokens_per_round: {[int(a/b) if b > 0 else 0 for a, b in input_stats]}\"\n        )\n        print(\n            f\"mean_return_tokens_per_round: {[int(a/b) if b > 0 else 0 for a, b in output_stats]}\"\n        )\n        return performance_data\n\n\ndef main():\n    global debug_log_file\n\n    args = parse_args()\n    if args.log_level == \"debug\":\n        
logging.basicConfig(level=logging.DEBUG)\n        logger.info(\"use log_level debug\")\n        # Initialize debug log file\n        debug_log_file = open(args.debug_log_file, \"w\")\n    else:\n        logging.basicConfig(level=logging.INFO)\n        logger.info(\"use log_level info\")\n    performance_data = WorkloadGenerator(args).run()\n\n    # Close debug log file if it was opened\n    if debug_log_file:\n        debug_log_file.close()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/hicache/bench_mix.sh",
    "content": "#!/bin/bash\n\nexport LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib\nrm -rf nohup.out && \\\nnohup python3 -m sglang.launch_server \\\n    --attention-backend triton \\\n    --model-path /code/models/Qwen3-32B/ \\\n    --log-level info \\\n    --tp 4 --mem-frac 0.25 \\\n    --host 0.0.0.0 --port 33301 \\\n    --enable-metrics --enable-cache-report \\\n    --page-size 64 \\\n    --enable-hierarchical-cache \\\n    --hicache-ratio 2.5 --hicache-size 0 \\\n    --hicache-io-backend kernel \\\n    --hicache-mem-layout layer_first \\\n    --hicache-write-policy write_through \\\n    &\n\n##################################################\n\nexport CONFIG_PATH=/tmp/bench_mix_config.json\n\n# num_clients: Maximum number of concurrent client requests to be simulated\n# round_ratios: Distribution of requests across rounds. Given sum(round_ratios) total requests,\n#               round_ratios[i] denotes the number of requests that will execute for (i+1) rounds\necho '{\n  \"num_rounds\": 10,\n  \"num_clients\": 60,\n  \"round_ratios\": [50, 25, 15, 15, 10, 10, 9, 8, 7, 6],\n  \"mean_new_tokens_per_round\": [1000, 400, 350, 300, 280, 260, 240, 220, 210, 200],\n  \"mean_return_tokens_per_round\": [100, 100, 100, 100, 100, 100, 100, 100, 100, 100],\n  \"mean_inter_round_interval\": [30, 30, 30, 30, 30, 30, 30, 30, 30, 30]\n}' > ${CONFIG_PATH}\n\nrm -rf bench_mix.out && \\\nnohup python3 /sgl-workspace/sglang/benchmark/hicache/bench_mix.py \\\n    --model-path /code/models/Qwen3-32B/ \\\n    --dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \\\n    --port 33301 \\\n    --duration 600 \\\n> bench_mix.out &\n"
  },
  {
    "path": "benchmark/hicache/bench_multiturn.py",
    "content": "import argparse\nimport asyncio\nimport json\nimport queue\nimport random\nimport threading\nimport time\nfrom datetime import datetime\n\nimport numpy as np\nimport requests\nfrom tqdm.asyncio import tqdm\n\nfrom sglang.bench_serving import RequestFuncOutput\nfrom sglang.benchmark.datasets.random import sample_random_requests\nfrom sglang.benchmark.utils import get_tokenizer\nfrom sglang.test.kits.cache_hit_kit import (\n    async_request_openai_chat_completions,\n    async_request_sglang_generate,\n    gen_payload,\n    gen_payload_openai,\n)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Script to benchmark concurrent requests to a server.\"\n    )\n    parser.add_argument(\n        \"--num-clients\",\n        type=int,\n        default=256,\n        help=\"Number of concurrent clients\",\n    )\n    parser.add_argument(\n        \"--max-parallel\",\n        type=int,\n        default=128,\n        help=\"Maximum number of parallel requests\",\n    )\n    parser.add_argument(\n        \"--request-length\",\n        type=int,\n        default=512,\n        help=\"Length of each new request\",\n    )\n    parser.add_argument(\n        \"--output-length\",\n        type=int,\n        default=64,\n        help=\"Length of each output\",\n    )\n    parser.add_argument(\n        \"--num-rounds\",\n        type=int,\n        default=5,\n        help=\"Number of rounds per client\",\n    )\n    parser.add_argument(\n        \"--distribution\",\n        type=str,\n        default=\"poisson\",\n        choices=[\"poisson\", \"uniform\"],\n        help=\"Distribution type for request intervals (poisson or uniform)\",\n    )\n    parser.add_argument(\n        \"--request-rate\",\n        type=float,\n        default=1.0,\n        help=\"Average number of requests per second\",\n    )\n    parser.add_argument(\n        \"--host\",\n        type=str,\n        default=\"localhost\",\n        help=\"Server hostname or IP 
(default: localhost)\",\n    )\n    parser.add_argument(\n        \"--port\",\n        type=int,\n        default=30000,\n        help=\"Server port (default: 30000)\",\n    )\n    parser.add_argument(\n        \"--model-path\",\n        type=str,\n        default=\"meta-llama/Llama-3.1-8B-Instruct\",\n        help=\"model path compatible with Hugging Face Transformers\",\n    )\n    parser.add_argument(\n        \"--dataset-path\",\n        type=str,\n        default=\"\",\n        help=\"local dataset to sample tokens from\",\n    )\n    parser.add_argument(\n        \"--log-file\",\n        type=str,\n        default=\"performance_metrics.jsonl\",\n        help=\"File to log performance metrics\",\n    )\n    parser.add_argument(\n        \"--disable-auto-run\",\n        action=\"store_true\",\n        help=\"If set, disable automatically testing with a range of request rates.\",\n    )\n    parser.add_argument(\n        \"--disable-random-sample\",\n        action=\"store_true\",\n        help=\"If set, disable random sampling of requests from the ShareGPT dataset.\",\n    )\n    parser.add_argument(\n        \"--enable-round-barrier\",\n        action=\"store_true\",\n        help=\"If set, only send i-th turn requests after all (i-1)-th turn requests finished.\",\n    )\n    parser.add_argument(\n        \"--sub-question-input-length\",\n        type=int,\n        default=0,\n        help=\"Length of the sub question input for each request, if set 0 use request_length\",\n    )\n    parser.add_argument(\n        \"--ready-queue-policy\",\n        type=str,\n        default=\"random\",\n        help=\"Policy for popping requests from the ready queue (random or fifo)\",\n    )\n    parser.add_argument(\n        \"--tag\",\n        type=str,\n        default=\"\",\n        help=\"Tag of a certain run in the log file\",\n    )\n    parser.add_argument(\n        \"--min-rounds\",\n        type=int,\n        default=0,\n        help=\"Min rounds per client (0 = use 
--num-rounds)\",\n    )\n    parser.add_argument(\n        \"--max-rounds\",\n        type=int,\n        default=0,\n        help=\"Max rounds per client (0 = use --num-rounds)\",\n    )\n    parser.add_argument(\n        \"--range-ratio\",\n        type=float,\n        default=1.0,\n        help=\"Length variation ratio for prompts and outputs (1.0 = no variation, 0.5 = 50%% variation)\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=1, help=\"The random seed.\")\n    parser.add_argument(\n        \"--lora-path\",\n        type=str,\n        default=\"\",\n        help=\"String of LoRA path. Currently we only support benchmarking on a single LoRA adaptor.\",\n    )\n    parser.add_argument(\n        \"--api-format\",\n        type=str,\n        default=\"sglang\",\n        choices=[\"sglang\", \"openai\"],\n        help=\"API format to use: 'sglang' for native /generate endpoint, \"\n        \"'openai' for OpenAI-compatible /v1/chat/completions endpoint.\",\n    )\n    return parser.parse_args()\n\n\ndef log_to_jsonl_file(data, file_path=\"performance_metrics.jsonl\", tag=\"\"):\n    \"\"\"Append the data with a timestamp and tag to the specified JSONL file.\"\"\"\n    timestamped_data = {\"timestamp\": datetime.now().isoformat(), \"tag\": tag, **data}\n    try:\n        with open(file_path, \"a\") as file:\n            file.write(\n                json.dumps(timestamped_data) + \"\\n\"\n            )  # Write as a single line in JSONL format\n    except IOError as e:\n        print(f\"Error writing to JSONL file: {e}\")\n\n\nclass ReadyQueue:\n    \"\"\"\n    Thread-safe queue that can pop requests in different orders based on given policy.\n    \"\"\"\n\n    def __init__(self, init_requests=None, policy=\"random\"):\n        self.lock = threading.Lock()\n        self.requests = init_requests or []\n        self.policy = policy\n\n    def append(self, item):\n        with self.lock:\n            self.requests.append(item)\n\n    def pop(self):\n   
     with self.lock:\n            if not self.requests:\n                return None\n            if self.policy == \"random\":\n                index = random.randrange(len(self.requests))\n                return self.requests.pop(index)\n            elif self.policy == \"fifo\":\n                return self.requests.pop(0)\n            else:\n                # todo, varying thinking time of clients\n                raise ValueError(f\"{self.policy} not implemented\")\n\n\nclass WorkloadGenerator:\n    def __init__(self, args):\n        self.api_format = args.api_format\n        self.model_path = args.model_path\n\n        # Construct the base URL and select request/payload functions\n        if self.api_format == \"openai\":\n            self.url = f\"http://{args.host}:{args.port}/v1/chat/completions\"\n            self.request_func = async_request_openai_chat_completions\n        else:\n            self.url = f\"http://{args.host}:{args.port}/generate\"\n            self.request_func = async_request_sglang_generate\n\n        self.tokenizer = get_tokenizer(args.model_path)\n        self.distribution = args.distribution\n        self.request_rate = args.request_rate\n        self.start_time = None\n        self.finished_time = None\n        self.lora_path = args.lora_path\n\n        self.sent_requests = 0\n        self.completed_requests = 0\n\n        # Resolve per-client round counts\n        min_rounds = args.min_rounds\n        max_rounds = args.max_rounds\n        if min_rounds == 0 and max_rounds == 0:\n            # Backward compat: all clients use --num-rounds\n            min_rounds = args.num_rounds\n            max_rounds = args.num_rounds\n        elif min_rounds == 0:\n            min_rounds = max_rounds\n        elif max_rounds == 0:\n            max_rounds = min_rounds\n        if min_rounds < 1:\n            raise ValueError(f\"--min-rounds must be >= 1, got {min_rounds}\")\n        if min_rounds > max_rounds:\n            raise ValueError(\n     
           f\"--min-rounds ({min_rounds}) must be <= --max-rounds ({max_rounds})\"\n            )\n\n        self.min_rounds = min_rounds\n        self.max_rounds = max_rounds\n\n        if min_rounds == max_rounds:\n            # All clients have the same round count; skip randint to preserve random state\n            self.client_total_rounds = [min_rounds] * args.num_clients\n        else:\n            self.client_total_rounds = [\n                random.randint(min_rounds, max_rounds) for _ in range(args.num_clients)\n            ]\n\n        # clients_per_round[r] = number of clients participating in round r\n        self.clients_per_round = [\n            sum(1 for t in self.client_total_rounds if t > r) for r in range(max_rounds)\n        ]\n        self.total_requests = sum(self.client_total_rounds)\n\n        range_ratio = args.range_ratio\n\n        # Use return_text=False to get token ids instead of text\n        first_round_samples = sample_random_requests(\n            input_len=args.request_length,\n            output_len=args.output_length,\n            num_prompts=args.num_clients,\n            range_ratio=range_ratio,\n            tokenizer=self.tokenizer,\n            dataset_path=args.dataset_path,\n            random_sample=not args.disable_random_sample,\n            return_text=False,\n        )\n        # Store per-sample output_len for first round\n        first_round_output_lens = [row.output_len for row in first_round_samples]\n        # r.prompt is now List[int] when return_text=False\n        self.candidate_inputs = [list(i.prompt) for i in first_round_samples]\n\n        if args.sub_question_input_length != 0:\n            sub_question_input_length = args.sub_question_input_length\n        else:\n            sub_question_input_length = args.request_length\n\n        num_sub_questions = sum(max(t - 1, 0) for t in self.client_total_rounds)\n\n        self.sub_question_inputs = sample_random_requests(\n            
input_len=sub_question_input_length,\n            output_len=args.output_length,\n            num_prompts=max(num_sub_questions, 1),\n            range_ratio=range_ratio,\n            tokenizer=self.tokenizer,\n            dataset_path=args.dataset_path,\n            random_sample=not args.disable_random_sample,\n            return_text=False,\n        )\n\n        if self.api_format == \"openai\":\n            # OpenAI mode: history is a messages list for /v1/chat/completions\n            initial_messages = {\n                i: [\n                    {\n                        \"role\": \"user\",\n                        \"content\": self.tokenizer.decode(self.candidate_inputs[i]),\n                    }\n                ]\n                for i in range(args.num_clients)\n            }\n            init_requests = [\n                (\n                    i,\n                    gen_payload_openai(\n                        initial_messages[i],\n                        first_round_output_lens[i],\n                        self.model_path,\n                    ),\n                )\n                for i in range(args.num_clients)\n            ]\n            self.client_records = {\n                i: {\n                    \"round\": 0,\n                    \"history\": initial_messages[i],\n                    \"total_rounds\": self.client_total_rounds[i],\n                }\n                for i in range(args.num_clients)\n            }\n        else:\n            # SGLang mode: history is List[int] (token ids)\n            init_requests = [\n                (\n                    i,\n                    gen_payload(\n                        self.candidate_inputs[i],\n                        first_round_output_lens[i],\n                        args.lora_path,\n                    ),\n                )\n                for i in range(args.num_clients)\n            ]\n            self.client_records = {\n                i: {\n                    \"round\": 0,\n   
                 \"history\": list(self.candidate_inputs[i]),\n                    \"total_rounds\": self.client_total_rounds[i],\n                }\n                for i in range(args.num_clients)\n            }\n        self.ready_queue = ReadyQueue(\n            init_requests=init_requests, policy=args.ready_queue_policy\n        )\n        self.candidate_inputs = self.candidate_inputs[args.num_clients :]\n\n        self.response_queue = queue.Queue()\n        self.pbar = tqdm(total=self.total_requests)\n        self.performance_metrics = {\n            \"ttft\": [],\n            \"itl\": [],\n            \"latency\": [],\n            \"prompt_len\": [],\n            \"cached_tokens\": [],\n            \"generated_len\": [],\n        }\n        self.enable_round_barrier = args.enable_round_barrier\n        if self.enable_round_barrier:\n            # Add round-specific metrics while preserving the original structure\n            for i in range(self.max_rounds):\n                self.performance_metrics[f\"round_{i}\"] = {\n                    \"ttft\": [],\n                    \"latency\": [],\n                    \"prompt_len\": [],\n                    \"cached_tokens\": [],\n                    \"generated_len\": [],\n                }\n        self.num_clients = args.num_clients\n\n        self.num_rounds = self.max_rounds\n        self.max_parallel = args.max_parallel\n        self.output_length = args.output_length\n\n    async def handle_request(self, item):\n        client_id, payload = item\n        try:\n            response = await self.request_func(payload, self.url, self.pbar)\n            if self.pbar.n == self.pbar.total:\n                self.finished_time = time.perf_counter()\n            self.response_queue.put((client_id, response))\n        except Exception as e:\n            print(f\"Request failed for client {client_id}: {e}\")\n            failed_response = RequestFuncOutput()\n            failed_response.success = False\n            
failed_response.error = str(e)\n            self.response_queue.put((client_id, failed_response))\n\n    def request_sender(self):\n        async def request_loop():\n            while True:\n                if self.sent_requests - self.completed_requests < self.max_parallel:\n                    new_request = self.ready_queue.pop()\n                    if new_request:\n                        asyncio.create_task(self.handle_request(new_request))\n                        self.sent_requests += 1\n                else:\n                    await asyncio.sleep(0.05)\n                    continue\n\n                if self.pbar.n == self.pbar.total:\n                    break\n\n                # Calculate the wait time before the next request (Poisson or uniform arrivals)\n                if self.distribution == \"poisson\":\n                    sleep_time = random.expovariate(self.request_rate)\n                elif self.distribution == \"uniform\":\n                    avg_interval = (\n                        1.0 / self.request_rate if self.request_rate > 0 else 1.0\n                    )\n                    sleep_time = random.uniform(0, 2 * avg_interval)\n                else:\n                    raise ValueError(\"Invalid distribution type\")\n                await asyncio.sleep(sleep_time)  # Wait before sending the next request\n\n        # Create and run the event loop for asynchronous requests\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n        loop.run_until_complete(request_loop())\n        loop.close()\n\n    def response_handler(self):\n        next_round_reqs = []\n        current_barrier_round = 0\n        barrier_round_completed = 0\n        while True:\n            try:\n                client_id, response = self.response_queue.get(\n                    timeout=10\n                )  # Block until response is available\n                if not response.success:\n                    print(f\"Request failed for client {client_id}: {response.error}\")\n  
                  self.completed_requests += 1\n                    continue\n                # Extend history with response\n                if self.api_format == \"openai\":\n                    if response.generated_text:\n                        self.client_records[client_id][\"history\"].append(\n                            {\"role\": \"assistant\", \"content\": response.generated_text}\n                        )\n                else:\n                    self.client_records[client_id][\"history\"].extend(\n                        response.output_ids\n                    )\n                current_round = self.client_records[client_id][\"round\"]\n                self.client_records[client_id][\"round\"] += 1\n                self.performance_metrics[\"ttft\"].append(response.ttft)\n                self.performance_metrics[\"itl\"].extend(response.itl)\n                self.performance_metrics[\"latency\"].append(response.latency)\n                self.performance_metrics[\"prompt_len\"].append(response.prompt_len)\n                self.performance_metrics[\"cached_tokens\"].append(response.cached_tokens)\n                self.performance_metrics[\"generated_len\"].append(response.generated_len)\n                if self.enable_round_barrier:\n                    self.performance_metrics[f\"round_{current_round}\"][\"ttft\"].append(\n                        response.ttft\n                    )\n                    self.performance_metrics[f\"round_{current_round}\"][\n                        \"latency\"\n                    ].append(response.latency)\n                    self.performance_metrics[f\"round_{current_round}\"][\n                        \"prompt_len\"\n                    ].append(response.prompt_len)\n                    self.performance_metrics[f\"round_{current_round}\"][\n                        \"cached_tokens\"\n                    ].append(response.cached_tokens)\n                    self.performance_metrics[f\"round_{current_round}\"][\n    
                    \"generated_len\"\n                    ].append(response.generated_len)\n                self.completed_requests += 1\n\n                client_total = self.client_records[client_id][\"total_rounds\"]\n                if self.client_records[client_id][\"round\"] < client_total:\n                    sub_q = self.sub_question_inputs.pop()\n                    if self.api_format == \"openai\":\n                        # Append sub-question as a new user message\n                        sub_q_text = self.tokenizer.decode(list(sub_q.prompt))\n                        self.client_records[client_id][\"history\"].append(\n                            {\"role\": \"user\", \"content\": sub_q_text}\n                        )\n                        new_req = (\n                            client_id,\n                            gen_payload_openai(\n                                self.client_records[client_id][\"history\"],\n                                sub_q.output_len,\n                                self.model_path,\n                            ),\n                        )\n                    else:\n                        # Append sub-question token ids to client's history\n                        sub_q_ids = list(sub_q.prompt)\n                        self.client_records[client_id][\"history\"].extend(sub_q_ids)\n                        new_req = (\n                            client_id,\n                            gen_payload(\n                                self.client_records[client_id][\"history\"],\n                                sub_q.output_len,\n                                self.lora_path,\n                            ),\n                        )\n                    if self.enable_round_barrier:\n                        next_round_reqs.append(new_req)\n                    else:\n                        self.ready_queue.append(new_req)\n\n                # Barrier logic: release next round when all clients for\n                # 
current barrier round have completed\n                if (\n                    self.enable_round_barrier\n                    and current_barrier_round < self.max_rounds\n                ):\n                    barrier_round_completed += 1\n                    expected = self.clients_per_round[current_barrier_round]\n                    if barrier_round_completed == expected:\n                        print(\n                            f\"\\n  Barrier: round {current_barrier_round} complete \"\n                            f\"({expected} clients), releasing {len(next_round_reqs)} \"\n                            f\"requests for round {current_barrier_round + 1}\"\n                        )\n                        for req in next_round_reqs:\n                            self.ready_queue.append(req)\n                        next_round_reqs = []\n                        current_barrier_round += 1\n                        barrier_round_completed = 0\n            except queue.Empty:\n                if self.pbar.n == self.pbar.total:\n                    break\n            except ValueError as e:\n                print(f\"Error processing response for client {client_id}: {e}\")\n                continue\n\n    def run(self):\n        request_thread = threading.Thread(target=self.request_sender, daemon=True)\n        response_thread = threading.Thread(target=self.response_handler, daemon=True)\n\n        self.start_time = time.perf_counter()\n        request_thread.start()\n        response_thread.start()\n\n        request_thread.join()\n        response_thread.join()\n        self.pbar.close()\n\n        duration = self.finished_time - self.start_time\n        sorted_ttft = sorted(self.performance_metrics[\"ttft\"])\n        sorted_latency = sorted(self.performance_metrics[\"latency\"])\n        sorted_itl = sorted(self.performance_metrics[\"itl\"])\n        sorted_prompt_len = sorted(self.performance_metrics[\"prompt_len\"])\n        sorted_output_len = 
sorted(self.performance_metrics[\"generated_len\"])\n\n        def percentile(sorted_vals, q):\n            if not sorted_vals:\n                return 0.0\n            idx = int(q * len(sorted_vals))\n            if idx >= len(sorted_vals):\n                idx = len(sorted_vals) - 1\n            return sorted_vals[idx]\n\n        def max_or_zero(sorted_vals):\n            return sorted_vals[-1] if sorted_vals else 0.0\n\n        performance_data = {\n            \"summary\": {\n                \"total_requests\": len(self.performance_metrics[\"ttft\"]),\n                \"request_rate\": self.request_rate,\n                \"average_prompt_len\": (\n                    sum(self.performance_metrics[\"prompt_len\"])\n                    / len(self.performance_metrics[\"prompt_len\"])\n                    if self.performance_metrics[\"prompt_len\"]\n                    else 0.0\n                ),\n                \"average_output_len\": (\n                    sum(self.performance_metrics[\"generated_len\"])\n                    / len(self.performance_metrics[\"generated_len\"])\n                    if self.performance_metrics[\"generated_len\"]\n                    else 0.0\n                ),\n                \"p90_prompt_len\": percentile(sorted_prompt_len, 0.9),\n                \"p99_prompt_len\": percentile(sorted_prompt_len, 0.99),\n                \"p90_output_len\": percentile(sorted_output_len, 0.9),\n                \"p99_output_len\": percentile(sorted_output_len, 0.99),\n                \"average_ttft\": sum(self.performance_metrics[\"ttft\"])\n                / len(self.performance_metrics[\"ttft\"]),\n                \"p90_ttft\": percentile(sorted_ttft, 0.9),\n                \"p99_ttft\": percentile(sorted_ttft, 0.99),\n                \"median_ttft\": percentile(sorted_ttft, 0.5),\n                \"max_ttft\": max_or_zero(sorted_ttft),\n                \"average_itl\": (\n                    sum(self.performance_metrics[\"itl\"])\n                 
   / len(self.performance_metrics[\"itl\"])\n                    if self.performance_metrics[\"itl\"]\n                    else 0.0\n                ),\n                \"p90_itl\": percentile(sorted_itl, 0.9),\n                \"p99_itl\": percentile(sorted_itl, 0.99),\n                \"median_itl\": percentile(sorted_itl, 0.5),\n                \"max_itl\": max_or_zero(sorted_itl),\n                \"average_latency\": sum(self.performance_metrics[\"latency\"])\n                / len(self.performance_metrics[\"latency\"]),\n                \"p90_latency\": percentile(sorted_latency, 0.9),\n                \"p99_latency\": percentile(sorted_latency, 0.99),\n                \"median_latency\": percentile(sorted_latency, 0.5),\n                \"max_latency\": max_or_zero(sorted_latency),\n                \"input_token_throughput\": sum(self.performance_metrics[\"prompt_len\"])\n                / duration,\n                \"output_token_throughput\": sum(\n                    self.performance_metrics[\"generated_len\"]\n                )\n                / duration,\n                \"throughput\": self.pbar.total / duration,\n                \"cache_hit_rate\": (\n                    0\n                    if sum(self.performance_metrics[\"prompt_len\"]) == 0\n                    else sum(self.performance_metrics[\"cached_tokens\"])\n                    / sum(self.performance_metrics[\"prompt_len\"])\n                ),\n            },\n        }\n        if self.enable_round_barrier:\n            performance_data[\"round\"] = {}\n            for round_num in range(self.num_rounds):\n                round_key = f\"round_{round_num}\"\n                round_metrics = self.performance_metrics[round_key]\n                performance_data[\"round\"][round_key] = {\n                    \"average_ttft\": (\n                        sum(round_metrics[\"ttft\"]) / len(round_metrics[\"ttft\"])\n                        if round_metrics[\"ttft\"]\n                        
else 0\n                    ),\n                    \"cache_hit_rate\": (\n                        0\n                        if sum(round_metrics[\"prompt_len\"]) == 0\n                        else sum(round_metrics[\"cached_tokens\"])\n                        / sum(round_metrics[\"prompt_len\"])\n                    ),\n                    \"request_count\": len(round_metrics[\"ttft\"]),\n                }\n        print(\"All requests completed\")\n        print(\"Performance metrics summary:\")\n        print(\n            f\"  Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second\"\n        )\n        print(\n            f\"  Average Prompt Length: {performance_data['summary']['average_prompt_len']:.2f} tokens\"\n        )\n        print(\n            f\"  Average Output Length: {performance_data['summary']['average_output_len']:.2f} tokens\"\n        )\n        print(\n            f\"  P90 Prompt Length: {performance_data['summary']['p90_prompt_len']:.0f} tokens\"\n        )\n        print(\n            f\"  P99 Prompt Length: {performance_data['summary']['p99_prompt_len']:.0f} tokens\"\n        )\n        print(\n            f\"  P90 Output Length: {performance_data['summary']['p90_output_len']:.0f} tokens\"\n        )\n        print(\n            f\"  P99 Output Length: {performance_data['summary']['p99_output_len']:.0f} tokens\"\n        )\n        print(f\"  Average TTFT: {performance_data['summary']['average_ttft']:.2f}\")\n        print(f\"  P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}\")\n        print(f\"  P99 TTFT: {performance_data['summary']['p99_ttft']:.2f}\")\n        print(f\"  Median TTFT: {performance_data['summary']['median_ttft']:.2f}\")\n        print(f\"  Max TTFT: {performance_data['summary']['max_ttft']:.2f}\")\n        print(f\"  Average ITL: {performance_data['summary']['average_itl']:.4f}\")\n        print(f\"  P90 ITL: 
{performance_data['summary']['p90_itl']:.4f}\")\n        print(f\"  P99 ITL: {performance_data['summary']['p99_itl']:.4f}\")\n        print(f\"  Median ITL: {performance_data['summary']['median_itl']:.4f}\")\n        print(f\"  Max ITL: {performance_data['summary']['max_itl']:.4f}\")\n        print(\n            f\"  Average latency: {performance_data['summary']['average_latency']:.2f}\"\n        )\n        print(f\"  P90 latency: {performance_data['summary']['p90_latency']:.2f}\")\n        print(f\"  P99 latency: {performance_data['summary']['p99_latency']:.2f}\")\n        print(f\"  Median latency: {performance_data['summary']['median_latency']:.2f}\")\n        print(f\"  Max latency: {performance_data['summary']['max_latency']:.2f}\")\n        print(\n            f\"  Input token throughput: {performance_data['summary']['input_token_throughput']:.2f} tokens per second\"\n        )\n        print(\n            f\"  Output token throughput: {performance_data['summary']['output_token_throughput']:.2f} tokens per second\"\n        )\n        print(\n            f\"  Request Throughput: {performance_data['summary']['throughput']:.2f} requests per second\"\n        )\n        print(f\"  Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}\")\n\n        if self.enable_round_barrier:\n            # Print round-based summary\n            print(\"Per-round metrics:\")\n            if \"round\" in performance_data:\n                for round_num in range(self.num_rounds):\n                    round_key = f\"round_{round_num}\"\n                    if round_key in performance_data[\"round\"]:\n                        round_data = performance_data[\"round\"][round_key]\n                        avg_ttft = round_data[\"average_ttft\"]\n                        cache_hit_rate = round_data[\"cache_hit_rate\"]\n                        request_count = round_data[\"request_count\"]\n                        clients_in_round = self.clients_per_round[round_num]\n          
              print(\n                            f\"  Round {round_num}: Average TTFT = {avg_ttft:.2f}s, \"\n                            f\"Cache Hit Rate = {cache_hit_rate:.6f} \"\n                            f\"({request_count} requests, \"\n                            f\"{clients_in_round} clients)\"\n                        )\n                    else:\n                        print(f\"  Round {round_num}: No requests completed\")\n\n        return performance_data\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    flush_cache_url = f\"http://{args.host}:{args.port}/flush_cache\"\n\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n\n    if args.disable_auto_run:\n        print(\"Running with specified request rate...\")\n        request_rates = [args.request_rate]\n    else:\n        print(\"Auto-running with different request rates...\")\n        request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]\n\n    for rate in request_rates:\n        args.request_rate = rate\n        requests.post(flush_cache_url)\n        time.sleep(1)\n        performance_data = WorkloadGenerator(args).run()\n        log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)\n"
  },
  {
    "path": "benchmark/hicache/bench_serving.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py\n# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py\n\n\"\"\"\nBenchmark online serving with dynamic requests.\n\nUsage:\npython3 -m sglang.bench_serving --backend sglang --num-prompt 10\n\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi\n\"\"\"\n\nimport argparse\nimport asyncio\nimport json\nimport os\nimport random\nimport sys\nimport time\nimport traceback\nimport warnings\nfrom argparse import ArgumentParser\nfrom dataclasses import dataclass, field\nfrom datetime import datetime\nfrom typing import Any, AsyncGenerator, Dict, List, Optional, Tuple\n\nimport aiohttp\nimport numpy as np\nimport requests\nfrom data_processing import MsgContent, SampleOutput, get_dataset\nfrom tqdm.asyncio import tqdm\nfrom transformers import PreTrainedTokenizerBase\n\nfrom sglang.benchmark.utils import get_tokenizer, remove_prefix, set_ulimit\n\nAIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)\n\nglobal args\n\n\n@dataclass\nclass RequestFuncInput:\n    prompts: List[Tuple[MsgContent, int, int]]\n    api_url: str\n    model: str\n    lora_name: str\n    extra_request_body: Dict[str, Any]\n\n    # For multiturn chat, store the context\n    prev_messages: List = field(default_factory=list)\n    finished_prompts: int = 0\n\n\n@dataclass\nclass RequestFuncOutput:\n    generated_text: List[str] = field(default_factory=list)\n    prompt_len: List[int] = field(default_factory=list)\n    output_len: List[int] = field(default_factory=list)\n    latency: List[float] = 
field(default_factory=list)\n    ttft: List[float] = field(default_factory=list)\n    itl: List[float] = field(default_factory=list)  # List of inter-token latencies\n\n    success: bool = False\n    error: str = \"\"\n\n\n# set ignore_eos True by default\nasync def async_request_openai_completions(\n    request_func_input: RequestFuncInput,\n    queue: asyncio.Queue,\n    tokenizer: PreTrainedTokenizerBase,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    api_url = request_func_input.api_url\n    assert api_url.endswith(\n        \"completions\"\n    ), \"OpenAI Completions API URL must end with 'completions'.\"\n\n    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:\n        payload = {\n            \"model\": request_func_input.model,\n            \"temperature\": 0.0,\n            \"best_of\": 1,\n            \"stream\": not args.disable_stream,\n            \"stream_options\": {\"include_usage\": True},\n            \"ignore_eos\": not args.disable_ignore_eos,\n            **request_func_input.extra_request_body,\n        }\n        headers = {\n            \"Content-Type\": \"application/json\",\n            \"Authorization\": f\"Bearer {os.environ.get('OPENAI_API_KEY')}\",\n        }\n\n        output = RequestFuncOutput()\n\n        prompt_idx = request_func_input.finished_prompts\n        messages = request_func_input.prev_messages\n        prompt, input_len, max_tokens = request_func_input.prompts[prompt_idx]\n        prompt_len = sum(\n            prompt[1] + prompt[2]  # input_len + output_len\n            for prompt in request_func_input.prompts[:prompt_idx]\n        )\n        prompt_len += input_len\n\n        # Messages\n        messages.append(\n            {\n                \"role\": \"user\",\n                \"content\": prompt,\n            }\n        )\n        payload[\"messages\"] = messages\n        payload[\"max_tokens\"] = max_tokens\n\n        # output.prompt_len = request_func_input.prompt_len\n      
  # print(payload)\n\n        generated_text = \"\"\n        ttft = 0.0\n        st = time.perf_counter()\n        most_recent_timestamp = st\n        try:\n            async with session.post(\n                url=api_url, json=payload, headers=headers\n            ) as response:\n                if response.status == 200:\n                    actual_prompt_len = prompt_len - 1\n                    actual_output_len = 0\n                    async for chunk_bytes in response.content:\n                        chunk_bytes = chunk_bytes.strip()\n                        if not chunk_bytes:\n                            continue\n\n                        chunk = remove_prefix(chunk_bytes.decode(\"utf-8\"), \"data: \")\n                        latency = time.perf_counter() - st\n                        if chunk == \"[DONE]\":\n                            pass\n                        else:\n                            data = json.loads(chunk)\n                            timestamp = time.perf_counter()\n                            # NOTE: Some completion API might have a last\n                            # usage summary response without a token so we\n                            # want to check a token was generated\n                            if data[\"usage\"] is not None and len(data[\"usage\"]) > 0:\n                                actual_prompt_len = data[\"usage\"][\"prompt_tokens\"]\n                                actual_output_len = data[\"usage\"][\"completion_tokens\"]\n                                continue\n                            delta = data[\"choices\"][0][\"delta\"]\n\n                            if delta.get(\"content\", None):\n                                # First token\n                                if ttft == 0.0:\n                                    ttft = time.perf_counter() - st\n                                    output.ttft.append(ttft)\n\n                                # Decoding phase\n                                else:\n      
                              output.itl.append(timestamp - most_recent_timestamp)\n\n                                generated_text += delta[\"content\"]\n                            most_recent_timestamp = timestamp\n\n                    output.prompt_len.append(actual_prompt_len)  # truncate <s>\n                    output.output_len.append(actual_output_len)\n                    output.generated_text.append(generated_text)\n                    output.success = True\n                    output.latency.append(latency)\n\n                    # Prepare for the new request\n                    request_func_input.prompts[prompt_idx] = (\n                        prompt,\n                        input_len,\n                        actual_output_len,  # changes from max_tokens to output_len\n                    )\n                    prompt_idx += 1\n                    messages.append(\n                        {\n                            \"role\": \"assistant\",\n                            \"content\": generated_text,\n                        }\n                    )\n\n                    # Move the new request to the end of the queue\n                    if prompt_idx < len(request_func_input.prompts):\n                        request_func_input.finished_prompts = prompt_idx\n                        request_func_input.prev_messages = messages\n                        await queue.put(request_func_input)\n                else:\n                    output.error = response.reason or \"\"\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            output.error = \"\".join(traceback.format_exception(*exc_info))\n\n    if pbar:\n        pbar.update(1)\n    return output\n\n\nasync def async_request_profile(api_url: str) -> RequestFuncOutput:\n    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:\n        output = RequestFuncOutput()\n        try:\n     
       async with session.post(url=api_url) as response:\n                if response.status == 200:\n                    output.success = True\n                else:\n                    output.error = response.reason or \"\"\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            output.error = \"\".join(traceback.format_exception(*exc_info))\n\n    return output\n\n\nASYNC_REQUEST_FUNCS = {\n    \"sglang\": async_request_openai_completions,\n    \"vllm\": async_request_openai_completions,\n    \"lmdeploy\": async_request_openai_completions,\n}\n\n\n@dataclass\nclass BenchmarkMetrics:\n    completed: int\n    total_input: int\n    total_output: int\n    total_output_retokenized: int\n    request_throughput: float\n    input_throughput: float\n    output_throughput: float\n    output_throughput_retokenized: float\n    total_throughput: float\n    total_throughput_retokenized: float\n    mean_ttft_ms: float\n    median_ttft_ms: float\n    std_ttft_ms: float\n    p90_ttft_ms: float\n    p99_ttft_ms: float\n    mean_tpot_ms: float\n    median_tpot_ms: float\n    std_tpot_ms: float\n    p90_tpot_ms: float\n    p99_tpot_ms: float\n    mean_itl_ms: float\n    median_itl_ms: float\n    std_itl_ms: float\n    p90_itl_ms: float\n    p99_itl_ms: float\n    mean_e2e_latency_ms: float\n    median_e2e_latency_ms: float\n    std_e2e_latency_ms: float\n    p99_e2e_latency_ms: float\n    concurrency: float\n\n\nasync def get_requests(\n    input_requests_queue: asyncio.Queue,\n    request_rate: float,\n    num_actual_requests: int,\n) -> AsyncGenerator[RequestFuncInput, None]:\n    for _ in range(num_actual_requests):\n        try:\n            request = await asyncio.wait_for(\n                input_requests_queue.get(), timeout=300\n            )  # Wait for 5 minutes then abort\n        except Exception as e:\n            print(f\"exception: {e}\")\n            break\n\n      
  yield request\n\n        if request_rate == float(\"inf\"):\n            continue\n\n        interval = np.random.exponential(1.0 / request_rate)\n        await asyncio.sleep(interval)\n\n\ndef calculate_metrics(\n    outputs: List[RequestFuncOutput],\n    dur_s: float,\n    tokenizer: PreTrainedTokenizerBase,\n    backend: str,\n) -> Tuple[BenchmarkMetrics, List[int]]:\n    output_lens: List[int] = []\n    retokenized_output_lens: List[int] = []\n    total_input = 0\n    completed = 0\n    itls: List[float] = []\n    tpots: List[float] = []\n    ttfts: List[float] = []\n    e2e_latencies: List[float] = []\n    output_success = 0\n    for i in range(len(outputs)):\n        if outputs[i].success:\n            output_success += 1\n            assert len(outputs[i].generated_text) == len(outputs[i].latency)\n            assert len(outputs[i].generated_text) == len(outputs[i].ttft)\n            for j in range(len(outputs[i].generated_text)):\n                output_len = outputs[i].output_len[j]\n                output_lens.append(output_len)\n                retokenized_output_len = len(\n                    tokenizer.encode(\n                        outputs[i].generated_text[j], add_special_tokens=False\n                    )\n                )\n                retokenized_output_lens.append(retokenized_output_len)\n                total_input += outputs[i].prompt_len[j]\n                if output_len > 1:\n                    tpots.append(\n                        (outputs[i].latency[j] - outputs[i].ttft[j]) / (output_len - 1)\n                    )\n\n                completed += 1\n            itls += outputs[i].itl\n            ttfts += outputs[i].ttft\n            e2e_latencies += outputs[i].latency\n\n        else:\n            output_lens.append(0)\n            retokenized_output_lens.append(0)\n\n    if completed == 0:\n        warnings.warn(\n            \"All requests failed. 
This is likely due to a misconfiguration \"\n            \"on the benchmark arguments.\",\n            stacklevel=2,\n        )\n    metrics = BenchmarkMetrics(\n        completed=completed,\n        total_input=total_input,\n        total_output=sum(output_lens),\n        total_output_retokenized=sum(retokenized_output_lens),\n        request_throughput=completed / dur_s,\n        input_throughput=total_input / dur_s,\n        output_throughput=sum(output_lens) / dur_s,\n        output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,\n        total_throughput=(total_input + sum(output_lens)) / dur_s,\n        total_throughput_retokenized=(total_input + sum(retokenized_output_lens))\n        / dur_s,\n        mean_ttft_ms=np.mean(ttfts or 0)\n        * 1000,  # ttfts is empty if streaming is not supported by backend\n        median_ttft_ms=np.median(ttfts or 0) * 1000,\n        std_ttft_ms=np.std(ttfts or 0) * 1000,\n        p90_ttft_ms=np.percentile(ttfts or 0, 90) * 1000,\n        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,\n        mean_tpot_ms=np.mean(tpots or 0) * 1000,\n        median_tpot_ms=np.median(tpots or 0) * 1000,\n        std_tpot_ms=np.std(tpots or 0) * 1000,\n        p90_tpot_ms=np.percentile(tpots or 0, 90) * 1000,\n        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,\n        mean_itl_ms=np.mean(itls or 0) * 1000,\n        median_itl_ms=np.median(itls or 0) * 1000,\n        std_itl_ms=np.std(itls or 0) * 1000,\n        p90_itl_ms=np.percentile(itls or 0, 90) * 1000,\n        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,\n        mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,\n        median_e2e_latency_ms=np.median(e2e_latencies) * 1000,\n        std_e2e_latency_ms=np.std(e2e_latencies) * 1000,\n        p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,\n        concurrency=np.sum(e2e_latencies) / dur_s,\n    )\n\n    return metrics, output_lens\n\n\nasync def benchmark(\n    backend: str,\n    api_url: 
str,\n    base_url: str,\n    model_id: str,\n    tokenizer: PreTrainedTokenizerBase,\n    input_requests: SampleOutput,\n    request_rate: float,\n    max_concurrency: Optional[int],\n    disable_tqdm: bool,\n    lora_name: str,\n    extra_request_body: Dict[str, Any],\n    profile: bool,\n    enable_shared_prefix: bool,\n):\n    if backend in ASYNC_REQUEST_FUNCS:\n        request_func = ASYNC_REQUEST_FUNCS[backend]\n    else:\n        raise ValueError(f\"Unknown backend: {backend}\")\n\n    # Limit concurrency\n    # From https://github.com/vllm-project/vllm/pull/9390\n    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None\n\n    async def limited_request_func(request_func_input, queue, tokenizer, pbar):\n        if semaphore is None:\n            return await request_func(\n                request_func_input=request_func_input,\n                queue=queue,\n                tokenizer=tokenizer,\n                pbar=pbar,\n            )\n        async with semaphore:\n            return await request_func(\n                request_func_input=request_func_input,\n                queue=queue,\n                tokenizer=tokenizer,\n                pbar=pbar,\n            )\n\n    num_actual_requests = sum(len(r) for r in input_requests)\n    print(f\"Num of shared prefixes or conversations: {len(input_requests)}\")\n    print(f\"Num of total requests: {num_actual_requests}\")\n\n    # flatten the requests for shared prefix\n    if enable_shared_prefix:\n        input_requests = [[r] for requests in input_requests for r in requests]\n    inputs_requests_queue = asyncio.Queue(maxsize=len(input_requests))\n    print(\"Starting initial single prompt test run...\")\n    # NOTE: Just use the first request of the first conversation for warmup\n    test_input = RequestFuncInput(\n        model=model_id,\n        prompts=input_requests[0][:1],\n        api_url=api_url,\n        lora_name=lora_name,\n        extra_request_body=extra_request_body,\n   
 )\n    test_output = await request_func(\n        request_func_input=test_input, queue=inputs_requests_queue, tokenizer=tokenizer\n    )\n    if not test_output.success:\n        raise ValueError(\n            \"Initial test run failed - Please make sure benchmark arguments \"\n            f\"are correctly specified. Error: {test_output.error}\"\n        )\n    else:\n        print(\"Initial test run completed. Starting main benchmark run...\")\n\n    # Check the states\n    assert inputs_requests_queue.empty()\n\n    # Flush cache\n    if \"sglang\" in backend:\n        requests.post(base_url + \"/flush_cache\")\n\n    time.sleep(1.0)\n\n    # Start profiler\n    if profile:\n        print(\"Starting profiler...\")\n        profile_output = await async_request_profile(\n            api_url=base_url + \"/start_profile\"\n        )\n        if profile_output.success:\n            print(\"Profiler started\")\n\n    for request in input_requests:\n        request_func_input = RequestFuncInput(\n            model=model_id,\n            prompts=request,\n            api_url=api_url,\n            lora_name=lora_name,\n            extra_request_body=extra_request_body,\n        )\n        inputs_requests_queue.put_nowait(request_func_input)\n    if (\n        not args.enable_multiturn\n        and not args.enable_shared_prefix\n        and not args.dataset_name == \"generated-shared-prefix\"\n    ):\n        assert len(input_requests) == num_actual_requests\n\n    pbar = None if disable_tqdm else tqdm(total=num_actual_requests)\n\n    benchmark_start_time = time.perf_counter()\n    tasks: List[asyncio.Task] = []\n    async for request in get_requests(\n        inputs_requests_queue, request_rate, num_actual_requests\n    ):\n        tasks.append(\n            asyncio.create_task(\n                limited_request_func(\n                    request_func_input=request,\n                    queue=inputs_requests_queue,\n                    tokenizer=tokenizer,\n              
      pbar=pbar,\n                )\n            )\n        )\n    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)\n\n    # Stop profiler\n    if profile:\n        print(\"Stopping profiler...\")\n        profile_output = await async_request_profile(api_url=base_url + \"/stop_profile\")\n        if profile_output.success:\n            print(\"Profiler stopped\")\n\n    if pbar is not None:\n        pbar.close()\n\n    # Compute metrics and print results\n    benchmark_duration = time.perf_counter() - benchmark_start_time\n    metrics, output_lens = calculate_metrics(\n        outputs=outputs,\n        dur_s=benchmark_duration,\n        tokenizer=tokenizer,\n        backend=backend,\n    )\n\n    print(\"\\n{s:{c}^{n}}\".format(s=\" Serving Benchmark Result \", n=50, c=\"=\"))\n    print(\"{:<40} {:<10}\".format(\"Backend:\", backend))\n    print(\"{:<40} {:<10}\".format(\"Traffic request rate:\", request_rate))\n    print(\n        \"{:<40} {:<10}\".format(\n            \"Max request concurrency:\",\n            max_concurrency if max_concurrency else \"not set\",\n        )\n    )\n    print(\"{:<40} {:<10}\".format(\"Successful requests:\", metrics.completed))\n    print(\"{:<40} {:<10.2f}\".format(\"Benchmark duration (s):\", benchmark_duration))\n    print(\"{:<40} {:<10}\".format(\"Total input tokens:\", metrics.total_input))\n    print(\"{:<40} {:<10}\".format(\"Total generated tokens:\", metrics.total_output))\n    print(\n        \"{:<40} {:<10}\".format(\n            \"Total generated tokens (retokenized):\", metrics.total_output_retokenized\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Request throughput (req/s):\", metrics.request_throughput\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Input token throughput (tok/s):\", metrics.input_throughput\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Output token throughput (tok/s):\", 
metrics.output_throughput\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Total token throughput (tok/s):\", metrics.total_throughput\n        )\n    )\n    print(\"{:<40} {:<10.2f}\".format(\"Concurrency:\", metrics.concurrency))\n    print(\"{s:{c}^{n}}\".format(s=\"End-to-End Latency\", n=50, c=\"-\"))\n    print(\n        \"{:<40} {:<10.2f}\".format(\"Mean E2E Latency (ms):\", metrics.mean_e2e_latency_ms)\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Median E2E Latency (ms):\", metrics.median_e2e_latency_ms\n        )\n    )\n    print(\"{s:{c}^{n}}\".format(s=\"Time to First Token\", n=50, c=\"-\"))\n    print(\"{:<40} {:<10.2f}\".format(\"Mean TTFT (ms):\", metrics.mean_ttft_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"Median TTFT (ms):\", metrics.median_ttft_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"P90 TTFT (ms):\", metrics.p90_ttft_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"P99 TTFT (ms):\", metrics.p99_ttft_ms))\n    print(\n        \"{s:{c}^{n}}\".format(s=\"Time per Output Token (excl. 
1st token)\", n=50, c=\"-\")\n    )\n    print(\"{:<40} {:<10.2f}\".format(\"Mean TPOT (ms):\", metrics.mean_tpot_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"Median TPOT (ms):\", metrics.median_tpot_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"P90 TPOT (ms):\", metrics.p90_tpot_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"P99 TPOT (ms):\", metrics.p99_tpot_ms))\n    print(\"{s:{c}^{n}}\".format(s=\"Inter-token Latency\", n=50, c=\"-\"))\n    print(\"{:<40} {:<10.2f}\".format(\"Mean ITL (ms):\", metrics.mean_itl_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"Median ITL (ms):\", metrics.median_itl_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"P90 ITL (ms):\", metrics.p90_itl_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"P99 ITL (ms):\", metrics.p99_itl_ms))\n    print(\"=\" * 50)\n\n    if (\n        metrics.median_ttft_ms is not None\n        and metrics.mean_itl_ms is not None\n        and metrics.output_throughput is not None\n    ):\n        result = {\n            # Arguments\n            \"backend\": args.backend,\n            \"dataset_name\": args.dataset_name,\n            \"request_rate\": request_rate,\n            \"max_concurrency\": max_concurrency,\n            \"fixed_output_len\": args.fixed_output_len,\n            \"random_input_len\": args.random_input_len,\n            \"random_output_len\": args.random_output_len,\n            \"random_range_ratio\": args.random_range_ratio,\n            # Results\n            \"duration\": benchmark_duration,\n            \"completed\": metrics.completed,\n            \"total_input_tokens\": metrics.total_input,\n            \"total_output_tokens\": metrics.total_output,\n            \"total_output_tokens_retokenized\": metrics.total_output_retokenized,\n            \"request_throughput\": metrics.request_throughput,\n            \"input_throughput\": metrics.input_throughput,\n            \"output_throughput\": metrics.output_throughput,\n            \"mean_e2e_latency_ms\": metrics.mean_e2e_latency_ms,\n   
         \"median_e2e_latency_ms\": metrics.median_e2e_latency_ms,\n            \"std_e2e_latency_ms\": metrics.std_e2e_latency_ms,\n            \"p99_e2e_latency_ms\": metrics.p99_e2e_latency_ms,\n            \"mean_ttft_ms\": metrics.mean_ttft_ms,\n            \"median_ttft_ms\": metrics.median_ttft_ms,\n            \"std_ttft_ms\": metrics.std_ttft_ms,\n            \"p99_ttft_ms\": metrics.p99_ttft_ms,\n            \"mean_tpot_ms\": metrics.mean_tpot_ms,\n            \"median_tpot_ms\": metrics.median_tpot_ms,\n            \"std_tpot_ms\": metrics.std_tpot_ms,\n            \"p99_tpot_ms\": metrics.p99_tpot_ms,\n            \"mean_itl_ms\": metrics.mean_itl_ms,\n            \"median_itl_ms\": metrics.median_itl_ms,\n            \"std_itl_ms\": metrics.std_itl_ms,\n            \"p99_itl_ms\": metrics.p99_itl_ms,\n            \"concurrency\": metrics.concurrency,\n            \"input_throughput\": metrics.input_throughput,\n            \"output_throughput\": metrics.output_throughput,\n            \"fixed_output_len\": args.fixed_output_len,\n            \"random_input_len\": args.random_input_len,\n            \"random_output_len\": args.random_output_len,\n            \"random_range_ratio\": args.random_range_ratio,\n            \"duration\": benchmark_duration,\n            \"completed\": metrics.completed,\n        }\n    else:\n        print(f\"Error running benchmark for request rate: {request_rate}\")\n        print(\"-\" * 30)\n\n    # Determine output file name\n    if args.output_file:\n        output_file_name = args.output_file\n    else:\n        now = datetime.now().strftime(\"%m%d\")\n        if args.dataset_name == \"random\":\n            output_file_name = f\"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl\"\n        else:\n            output_file_name = (\n                f\"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl\"\n            )\n\n    # Append results to a JSONL file\n    
with open(output_file_name, \"a\") as file:\n        file.write(json.dumps(result) + \"\\n\")\n\n    result = {\n        \"duration\": benchmark_duration,\n        \"completed\": metrics.completed,\n        \"total_input_tokens\": metrics.total_input,\n        \"total_output_tokens\": metrics.total_output,\n        \"total_output_tokens_retokenized\": metrics.total_output_retokenized,\n        \"request_throughput\": metrics.request_throughput,\n        \"input_throughput\": metrics.input_throughput,\n        \"output_throughput\": metrics.output_throughput,\n        \"mean_ttft_ms\": metrics.mean_ttft_ms,\n        \"median_ttft_ms\": metrics.median_ttft_ms,\n        \"std_ttft_ms\": metrics.std_ttft_ms,\n        \"p90_ttft_ms\": metrics.p90_ttft_ms,\n        \"p99_ttft_ms\": metrics.p99_ttft_ms,\n        \"mean_tpot_ms\": metrics.mean_tpot_ms,\n        \"median_tpot_ms\": metrics.median_tpot_ms,\n        \"std_tpot_ms\": metrics.std_tpot_ms,\n        \"p90_tpot_ms\": metrics.p90_tpot_ms,\n        \"p99_tpot_ms\": metrics.p99_tpot_ms,\n        \"mean_itl_ms\": metrics.mean_itl_ms,\n        \"median_itl_ms\": metrics.median_itl_ms,\n        \"std_itl_ms\": metrics.std_itl_ms,\n        \"p90_itl_ms\": metrics.p90_itl_ms,\n        \"p99_itl_ms\": metrics.p99_itl_ms,\n        \"input_lens\": [output.prompt_len for output in outputs],\n        \"output_lens\": output_lens,\n        \"ttfts\": [output.ttft for output in outputs],\n        \"itls\": [output.itl for output in outputs],\n        \"generated_texts\": [output.generated_text for output in outputs],\n        \"errors\": [output.error for output in outputs],\n        \"mean_e2e_latency_ms\": metrics.mean_e2e_latency_ms,\n        \"median_e2e_latency_ms\": metrics.median_e2e_latency_ms,\n    }\n    return result\n\n\ndef run_benchmark(args_: argparse.Namespace):\n    global args\n    args = args_\n\n    # Set default value for max_concurrency if not present\n    if not hasattr(args, \"max_concurrency\"):\n        
args.max_concurrency = None\n\n    # Set global environments\n    set_ulimit()\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n\n    extra_request_body = {}\n    if args.extra_request_body:\n        extra_request_body = json.loads(args.extra_request_body)\n\n    # Set url\n    if args.port is None:\n        args.port = {\n            \"sglang\": 30000,\n            \"lmdeploy\": 23333,\n            \"vllm\": 8000,\n        }.get(args.backend, 30000)\n\n    model_url = (\n        f\"{args.base_url}/v1/models\"\n        if args.base_url\n        else f\"http://{args.host}:{args.port}/v1/models\"\n    )\n\n    if args.backend in [\"sglang\", \"vllm\", \"lmdeploy\"]:\n        api_url = (\n            f\"{args.base_url}/v1/chat/completions\"\n            if args.base_url\n            else f\"http://{args.host}:{args.port}/v1/chat/completions\"\n        )\n    base_url = (\n        f\"http://{args.host}:{args.port}\" if args.base_url is None else args.base_url\n    )\n\n    # Get model name\n    if args.model is None:\n        if args.backend == \"truss\":\n            print(\n                \"Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct\"\n            )\n            sys.exit(1)\n        try:\n            response = requests.get(model_url)\n            model_list = response.json().get(\"data\", [])\n            args.model = model_list[0][\"id\"] if model_list else None\n        except Exception as e:\n            print(f\"Failed to fetch model from {model_url}. Error: {e}\")\n            print(\n                \"Please specify the correct host and port using `--host` and `--port`.\"\n            )\n            sys.exit(1)\n\n    if args.model is None:\n        print(\"No model specified or found. 
Please provide a model using `--model`.\")\n        sys.exit(1)\n\n    # Dataset compatibility check\n    if args.enable_multiturn:\n        # TODO: Support multiturn for random\n        if args.dataset_name not in [\"sharegpt\", \"ultrachat\", \"loogle\", \"nextqa\"]:\n            print(\n                \"Multiturn conversation is only supported for sharegpt, ultrachat, loogle, and nextqa datasets.\"\n            )\n            sys.exit(1)\n\n    if args.enable_shared_prefix:\n        if args.dataset_name not in [\"loogle\", \"nextqa\"]:\n            print(\"Shared prefix is only supported for loogle and nextqa datasets.\")\n            sys.exit(1)\n\n    print(f\"{args}\\n\")\n\n    # Read dataset\n    backend = args.backend\n    model_id = args.model\n    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model\n\n    tokenizer = get_tokenizer(tokenizer_id)\n\n    input_requests = get_dataset(args, tokenizer)\n\n    return asyncio.run(\n        benchmark(\n            backend=backend,\n            api_url=api_url,\n            base_url=base_url,\n            model_id=model_id,\n            tokenizer=tokenizer,\n            input_requests=input_requests,\n            request_rate=args.request_rate,\n            max_concurrency=args.max_concurrency,\n            disable_tqdm=args.disable_tqdm,\n            lora_name=args.lora_name,\n            extra_request_body=extra_request_body,\n            profile=args.profile,\n            enable_shared_prefix=args.enable_shared_prefix,\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser(description=\"Benchmark the online serving throughput.\")\n    parser.add_argument(\n        \"--backend\",\n        type=str,\n        choices=list(ASYNC_REQUEST_FUNCS.keys()),\n        default=\"sglang\",\n        help=\"Must specify a backend, depending on the LLM Inference Engine.\",\n    )\n    parser.add_argument(\n        \"--base-url\",\n        type=str,\n        default=None,\n      
  help=\"Server or API base url if not using http host and port.\",\n    )\n    parser.add_argument(\n        \"--host\", type=str, default=\"0.0.0.0\", help=\"Default host is 0.0.0.0.\"\n    )\n    parser.add_argument(\n        \"--port\",\n        type=int,\n        help=\"If not set, a backend-specific default port is used (sglang: 30000, vllm: 8000, lmdeploy: 23333).\",\n    )\n    parser.add_argument(\n        \"--dataset-name\",\n        type=str,\n        default=\"sharegpt\",\n        choices=[\n            \"sharegpt\",\n            \"random\",\n            \"generated-shared-prefix\",\n            \"ultrachat\",\n            \"loogle\",\n            \"nextqa\",\n        ],\n        help=\"Name of the dataset to benchmark on.\",\n    )\n    parser.add_argument(\n        \"--dataset-path\", type=str, default=\"\", help=\"Path to the dataset.\"\n    )\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        help=\"Name or path of the model. If not set, the model is queried from the server's /v1/models endpoint.\",\n    )\n    parser.add_argument(\n        \"--tokenizer\",\n        type=str,\n        help=\"Name or path of the tokenizer. If not set, the --model value is used.\",\n    )\n    parser.add_argument(\n        \"--chat-template\",\n        type=str,\n        help=\"The built-in chat template name or the path of the chat template file. This is only used for the OpenAI-compatible API server.\",\n    )\n    parser.add_argument(\n        \"--num-prompts\",\n        type=int,\n        default=1000,\n        help=\"Number of prompts to process. Default is 1000.\",\n    )\n    parser.add_argument(\n        \"--fixed-output-len\",\n        type=int,\n        default=None,\n        help=\"Output length for each request. 
Overrides the output length from the dataset.\",\n    )\n    parser.add_argument(\n        \"--sharegpt-context-len\",\n        type=int,\n        default=None,\n        help=\"The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.\",\n    )\n    parser.add_argument(\n        \"--random-input-len\",\n        type=int,\n        default=1024,\n        help=\"Number of input tokens per request, used only for random dataset.\",\n    )\n    parser.add_argument(\n        \"--random-output-len\",\n        default=1024,\n        type=int,\n        help=\"Number of output tokens per request, used only for random dataset.\",\n    )\n    parser.add_argument(\n        \"--random-range-ratio\",\n        type=float,\n        default=0.0,\n        help=\"Range of sampled ratio of input/output length, \"\n        \"used only for random dataset.\",\n    )\n    parser.add_argument(\n        \"--request-rate\",\n        type=float,\n        default=float(\"inf\"),\n        help=\"Number of requests per second. If this is inf, then all the requests are sent at time 0. \"\n        \"Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.\",\n    )\n    parser.add_argument(\n        \"--max-concurrency\",\n        type=int,\n        default=None,\n        help=\"Maximum number of concurrent requests. This can be used \"\n        \"to help simulate an environment where a higher level component \"\n        \"is enforcing a maximum number of concurrent requests. While the \"\n        \"--request-rate argument controls the rate at which requests are \"\n        \"initiated, this argument will control how many are actually allowed \"\n        \"to execute at a time. 
This means that when used in combination, the \"\n        \"actual request rate may be lower than specified with --request-rate, \"\n        \"if the server is not processing requests fast enough to keep up.\",\n    )\n    parser.add_argument(\n        \"--multi\",\n        action=\"store_true\",\n        help=\"Use a request rate range rather than a single value.\",\n    )\n    parser.add_argument(\n        \"--request-rate-range\",\n        type=str,\n        default=\"2,34,2\",\n        help=\"Range of request rates in the format start,stop,step. Default is 2,34,2. A comma-separated list of request rates is also accepted, as long as it does not contain exactly three values.\",\n    )\n    parser.add_argument(\"--output-file\", type=str, help=\"Output JSONL file name.\")\n    parser.add_argument(\n        \"--enable-multiturn\",\n        action=\"store_true\",\n        help=\"Enable multiturn chat for online serving benchmarking. \"\n        \"This option is effective on the following datasets: \"\n        \"sharegpt, ultrachat, loogle, nextqa\",\n    )\n    parser.add_argument(\n        \"--enable-shared-prefix\",\n        action=\"store_true\",\n        help=\"Enable shared prefix for online serving benchmarking. \"\n        \"This option is effective on the following datasets: \"\n        \"loogle, nextqa\",\n    )\n\n    parser.add_argument(\n        \"--disable-shuffle\",\n        action=\"store_true\",\n        help=\"Disable shuffling datasets. 
This is useful for generating stable output \"\n        \"in benchmarking\",\n    )\n    parser.add_argument(\n        \"--disable-tqdm\",\n        action=\"store_true\",\n        help=\"Disable the tqdm progress bar.\",\n    )\n    parser.add_argument(\n        \"--disable-stream\",\n        action=\"store_true\",\n        help=\"Disable streaming mode.\",\n    )\n    parser.add_argument(\n        \"--return-logprob\",\n        action=\"store_true\",\n        help=\"Return logprob.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=1, help=\"The random seed.\")\n    parser.add_argument(\n        \"--disable-ignore-eos\",\n        action=\"store_true\",\n        help=\"Disable ignoring EOS.\",\n    )\n    parser.add_argument(\n        \"--extra-request-body\",\n        metavar='{\"key1\": \"value1\", \"key2\": \"value2\"}',\n        type=str,\n        help=\"Append given JSON object to the request payload. You can use this to specify \"\n        \"additional generate params like sampling params.\",\n    )\n    parser.add_argument(\n        \"--apply-chat-template\",\n        action=\"store_true\",\n        help=\"Apply chat template\",\n    )\n    parser.add_argument(\n        \"--profile\",\n        action=\"store_true\",\n        help=\"Use Torch Profiler. 
The endpoint must be launched with \"\n        \"SGLANG_TORCH_PROFILER_DIR to enable profiler.\",\n    )\n    parser.add_argument(\n        \"--lora-name\",\n        type=str,\n        default=None,\n        help=\"The name of LoRA adapter\",\n    )\n\n    group = parser.add_argument_group(\"generated-shared-prefix dataset arguments\")\n    group.add_argument(\n        \"--gsp-num-groups\",\n        type=int,\n        default=64,\n        help=\"Number of system prompt groups for generated-shared-prefix dataset\",\n    )\n    group.add_argument(\n        \"--gsp-prompts-per-group\",\n        type=int,\n        default=16,\n        help=\"Number of prompts per system prompt group for generated-shared-prefix dataset\",\n    )\n    group.add_argument(\n        \"--gsp-system-prompt-len\",\n        type=int,\n        default=2048,\n        help=\"Target length in tokens for system prompts in generated-shared-prefix dataset\",\n    )\n    group.add_argument(\n        \"--gsp-question-len\",\n        type=int,\n        default=128,\n        help=\"Target length in tokens for questions in generated-shared-prefix dataset\",\n    )\n    group.add_argument(\n        \"--gsp-output-len\",\n        type=int,\n        default=256,\n        help=\"Target length in tokens for outputs in generated-shared-prefix dataset\",\n    )\n    # videos specific\n    parser.add_argument(\n        \"--max-frames\",\n        type=int,\n        default=sys.maxsize,\n        help=\"The maximum number of frames to extract from each video. \"\n        \"This option is specific to the nextqa dataset (video benchmark). \",\n    )\n    args = parser.parse_args()\n\n    if args.enable_multiturn and args.enable_shared_prefix:\n        parser.error(\n            \"--enable-multiturn and --enable-shared-prefix cannot be set at the same time.\"\n        )\n\n    run_benchmark(args)\n"
  },
  {
    "path": "benchmark/hicache/data_processing.py",
    "content": "import json\nimport os\nimport pickle\nimport random\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nfrom nextqa import NExTQALoader\n\n# from nextqa.video import , VideoPrompt\nfrom tqdm.asyncio import tqdm\nfrom transformers import PreTrainedTokenizerBase\n\nfrom sglang.benchmark.datasets.common import (\n    SHAREGPT_FILENAME,\n    SHAREGPT_REPO_ID,\n    gen_prompt,\n)\nfrom sglang.benchmark.datasets.generated_shared_prefix import get_gen_prefix_cache_path\nfrom sglang.benchmark.utils import download_and_cache_hf_file\nfrom sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path\nfrom sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart\nfrom sglang.utils import encode_video_base64\n\n# type of content fields, can be only prompts or with images/videos\nMsgContent = Union[str, List[ChatCompletionMessageContentPart]]\n\n# A list of all the conversations. Each conversation is a list of\n# tuples. 
If multiturn is not enabled, the list has length 1,\n# containing only the first Q&A pair.\n# For the shared-prefix workloads (synthetic, loogle, nextqa), it\n# is a list of conversations sharing the same prefix (synthetic,\n# doc, video)\nSampleOutput = List[List[Tuple[MsgContent, int, int]]]\n\n\ndef common_filter_chat(\n    num_requests: int,\n    new_dataset: List,\n    tokenizer: PreTrainedTokenizerBase,\n    min_prompt_len: Optional[int],\n    min_output_len: Optional[int],\n    max_prompt_len: Optional[int],\n    max_output_len: Optional[int],\n    fixed_output_len: Optional[int],\n) -> SampleOutput:\n    # Filter out sequences that are too long or too short\n    filtered_dataset: SampleOutput = []\n    l = 0\n    input_tokens = 0\n    output_tokens = 0\n    while l < num_requests:\n        for i in range(len(new_dataset)):\n            if l == num_requests:\n                break\n            processed = []\n            for j in new_dataset[i]:\n                # Tokenize the prompts and completions.\n                prompt = j[0]\n                prompt_token_ids = tokenizer.encode(prompt)\n                prompt_len = len(prompt_token_ids)\n\n                completion = j[1]\n                completion_token_ids = tokenizer.encode(completion)\n                output_len = (\n                    len(completion_token_ids)\n                    if fixed_output_len is None\n                    else fixed_output_len\n                )\n                if (\n                    (min_prompt_len is not None and prompt_len < min_prompt_len)\n                    or (min_output_len is not None and output_len < min_output_len)\n                    or (max_prompt_len is not None and prompt_len > max_prompt_len)\n                    or (max_output_len is not None and output_len > max_output_len)\n                ):\n                    # Prune sequences that are too short or too long.\n                    
continue\n                input_tokens += prompt_len\n                output_tokens += output_len\n                processed.append((prompt, prompt_len, output_len))\n            if len(processed) != 0:\n                filtered_dataset.append(processed)\n                l += 1\n\n    print(f\"#Input tokens: {input_tokens}\")\n    print(f\"#Output tokens: {output_tokens}\")\n    return filtered_dataset\n\n\ndef sample_sharegpt_requests(\n    dataset_path: str,\n    num_requests: int,\n    tokenizer: PreTrainedTokenizerBase,\n    disable_shuffle: bool = False,\n    enable_multiturn: bool = True,\n    fixed_output_len: Optional[int] = None,\n) -> SampleOutput:\n    if fixed_output_len is not None and fixed_output_len < 4:\n        raise ValueError(\"output_len too small\")\n\n    # Download sharegpt if necessary\n    if not os.path.isfile(dataset_path):\n        dataset_path = download_and_cache_hf_file(\n            repo_id=SHAREGPT_REPO_ID,\n            filename=SHAREGPT_FILENAME,\n        )\n\n    # Load the dataset.\n    with open(dataset_path) as f:\n        dataset = json.load(f)\n    # Filter out the conversations with less than 2 turns.\n    dataset = [data for data in dataset if len(data[\"conversations\"]) >= 2]\n\n    # Keep one conversation in one list\n    new_dataset = []\n    for data in dataset:\n        if len(data[\"conversations\"]) % 2 != 0:\n            continue\n        if data[\"conversations\"][0][\"from\"] != \"human\":\n            continue\n        chat = []\n        total_len = 2\n        if enable_multiturn:\n            total_len = len(data[\"conversations\"])\n        for i in range(0, total_len, 2):\n            # One user One Assistant\n            chat.append(\n                (\n                    data[\"conversations\"][i][\"value\"],\n                    data[\"conversations\"][i + 1][\"value\"],\n                )\n            )\n        new_dataset.append(chat)\n\n    if not disable_shuffle:\n        # Shuffle the dataset.\n    
    random.shuffle(new_dataset)\n\n    # Filter out sequences that are too long or too short\n    filtered_dataset: SampleOutput = common_filter_chat(\n        num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len\n    )\n    return filtered_dataset\n\n\ndef sample_ultrachat_requests(\n    dataset_path: str,\n    num_requests: int,\n    tokenizer: PreTrainedTokenizerBase,\n    disable_shuffle: bool = False,\n    enable_multiturn: bool = True,\n    fixed_output_len: Optional[int] = None,\n) -> SampleOutput:\n    if fixed_output_len is not None and fixed_output_len < 4:\n        raise ValueError(\"output_len too small\")\n\n    # Load the dataset\n    dataset = []\n    with open(dataset_path) as f:\n        while True:\n            line = f.readline()\n            if not line:\n                break\n            dataset.append(json.loads(line))\n\n    # Filter out the conversations with less than 2 turns.\n    dataset = [data for data in dataset if len(data[\"data\"]) >= 2]\n\n    # Keep one conversation in one list\n    new_dataset = []\n    for data in dataset:\n        if len(data[\"data\"]) % 2 != 0:\n            continue\n        chat = []\n        total_len = 2\n        if enable_multiturn:\n            total_len = len(data[\"data\"])\n        for i in range(0, total_len, 2):\n            # One user One Assistant\n            chat.append((data[\"data\"][i], data[\"data\"][i + 1]))\n        new_dataset.append(chat)\n\n    # Shuffle the dataset.\n    if not disable_shuffle:\n        random.shuffle(new_dataset)\n\n    # Filter out sequences that are too long or too short\n    filtered_dataset: SampleOutput = common_filter_chat(\n        num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len\n    )\n    return filtered_dataset\n\n\ndef sample_loogle_requests(\n    dataset_path: str,\n    num_requests: int,\n    tokenizer: PreTrainedTokenizerBase,\n    disable_shuffle: bool = False,\n    enable_multiturn: bool = True,\n    
enable_shared_prefix: bool = False,\n    fixed_output_len: Optional[int] = None,\n) -> SampleOutput:\n    if fixed_output_len is not None and fixed_output_len < 4:\n        raise ValueError(\"output_len too small\")\n\n    # Load the dataset\n    dataset = []\n    with open(dataset_path) as f:\n        while True:\n            line = f.readline()\n            if not line:\n                break\n            dataset.append(json.loads(line))\n\n    # Keep one conversation in one list\n    new_dataset = []\n    # TODO: Add shared prefix support for loogle\n    # NOTE: Now we preprocess it only for chat\n    for data in dataset:\n        chat = []\n        if (\n            \"qa_pairs\" not in data\n            or data[\"qa_pairs\"] == \"none\"\n            or len(data[\"qa_pairs\"]) == 0\n        ):\n            # If there are no QA pairs (summarization task),\n            # add a summarization question and use the first\n            # 1024 characters of the input as the reference answer\n            chat.append(\n                (\n                    \"Input: \"\n                    + data[\"input\"]\n                    + \" Question: \"\n                    + \"Please summarize the input\",\n                    data[\"input\"][:1024],\n                )\n            )\n            new_dataset.append(chat)\n        else:\n            qa_pairs = eval(data[\"qa_pairs\"])\n            for i, qa in enumerate(qa_pairs):\n                if i == 0 or enable_shared_prefix:\n                    # Prepend the input to the first Q (or to every Q in shared-prefix mode)\n                    chat.append(\n                        (\"Input: \" + data[\"input\"] + \" Question: \" + qa[\"Q\"], qa[\"A\"])\n                    )\n                elif enable_multiturn:\n                    chat.append((qa[\"Q\"], qa[\"A\"]))\n\n            new_dataset.append(chat)\n\n    # Shuffle the dataset.\n    if not disable_shuffle:\n        random.shuffle(new_dataset)\n\n    # Filter out sequences that 
are too long or too short\n    filtered_dataset: SampleOutput = 
common_filter_chat(\n        num_requests, new_dataset, tokenizer, 4, None, None, None, fixed_output_len\n    )\n    return filtered_dataset\n\n\ndef sample_nextqa_requests(\n    dataset_path: str,\n    num_requests: int,\n    tokenizer: PreTrainedTokenizerBase,\n    max_frames: int,  # Specific for video\n    model_path: str,\n    disable_shuffle: bool = False,\n    enable_multiturn: bool = True,  # No multiturn support for now\n    backend: str = \"sglang-oai\",\n    chat_template_name: Optional[str] = None,\n    fixed_output_len: Optional[int] = None,\n) -> SampleOutput:\n    \"\"\"\n    Example of messages:\n    message = {\n        \"role\": \"user\",\n        \"content\": [\n            {\"type\": \"image_url\", \"image_url\": {\"url\": base64_data}},\n            {\"type\": \"text\", \"text\": video.prompt},\n        ],\n    }\n    \"\"\"\n\n    if fixed_output_len is None:\n        fixed_output_len = 4096\n\n    # TODO: Check for multiturn\n    dataset = NExTQALoader(video_dir=dataset_path, max_frames=max_frames)\n    new_dataset = []\n    for v in dataset:\n        new_dataset.append(v)\n\n    if not disable_shuffle:\n        random.shuffle(new_dataset)\n\n    # TODO: prompt len can get from server side\n    filtered_dataset = []\n    l = 0\n    while l < num_requests:\n        for i in range(len(new_dataset)):\n            if l == num_requests:\n                break\n\n            video = new_dataset[i]\n\n            # text prompt\n            prompt = video.prompt\n\n            # NOTE: Chat Template is a must for video benchmark because we have to\n            # add special image token for later expansion\n            if backend == \"sglang\" or backend == \"sglang-native\":\n                if \"chat_template\" in tokenizer.init_kwargs:\n                    chat_template = get_chat_template(tokenizer.get_chat_template())\n                elif chat_template_name is not None:\n                    chat_template = get_chat_template(chat_template_name)\n  
              else:\n                    chat_template = get_chat_template_by_model_path(model_path)\n                prompt = chat_template.image_token + prompt\n\n            prompt_token_ids = tokenizer(prompt).input_ids\n            prompt_len = len(prompt_token_ids)\n            output_len = fixed_output_len  # max output len, not real output len\n\n            # video input\n            base64_data = encode_video_base64(video.path, video.num_frames)\n\n            # NOTE: This will be replaced by the expanded length from the server\n            prompt_len += video.num_frames\n\n            # add to content\n            content = [\n                {\"type\": \"image_url\", \"image_url\": {\"url\": base64_data}},\n                {\"type\": \"text\", \"text\": prompt},\n            ]\n\n            filtered_dataset.append([(content, prompt_len, output_len)])\n            l += 1\n    return filtered_dataset\n\n\ndef sample_random_requests(\n    input_len: int,\n    output_len: int,\n    num_prompts: int,\n    range_ratio: float,\n    tokenizer: PreTrainedTokenizerBase,\n    dataset_path: str,\n    disable_shuffle: bool = False,\n) -> SampleOutput:\n\n    input_lens = np.random.randint(\n        max(int(input_len * range_ratio), 1),\n        input_len + 1,\n        size=num_prompts,\n    )\n    output_lens = np.random.randint(\n        int(output_len * range_ratio),\n        output_len + 1,\n        size=num_prompts,\n    )\n\n    if True:\n        # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens\n\n        # Download sharegpt if necessary\n        if not os.path.isfile(dataset_path):\n            dataset_path = download_and_cache_hf_file(\n                repo_id=SHAREGPT_REPO_ID,\n                filename=SHAREGPT_FILENAME,\n            )\n\n        # Load the dataset.\n        with open(dataset_path) as f:\n            dataset = json.load(f)\n        # Filter out the conversations with less than 2 turns.\n        dataset = 
[data for data in dataset if len(data[\"conversations\"]) >= 2]\n        # Only keep the first two turns of each conversation.\n        dataset = [\n            (data[\"conversations\"][0][\"value\"], data[\"conversations\"][1][\"value\"])\n            for data in dataset\n        ]\n\n        if not disable_shuffle:\n            # Shuffle the dataset.\n            random.shuffle(dataset)\n\n        # Filter out sequences that are too long or too short\n        input_requests: SampleOutput = []\n        for data in dataset:\n            i = len(input_requests)\n            if i == num_prompts:\n                break\n\n            # Tokenize the prompts and completions.\n            prompt = data[0]\n            prompt_token_ids = tokenizer.encode(prompt)\n            prompt_len = len(prompt_token_ids)\n\n            # Skip empty prompt\n            if prompt_len == 0:\n                continue\n\n            if prompt_len > input_lens[i]:\n                input_ids = prompt_token_ids[: input_lens[i]]\n            else:\n                ratio = (input_lens[i] + prompt_len - 1) // prompt_len\n                input_ids = (prompt_token_ids * ratio)[: input_lens[i]]\n            prompt = tokenizer.decode(input_ids)\n            input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))])\n    else:\n        # Sample token ids from random integers. 
This can cause some NaN issues.\n        offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)\n        input_requests = []\n        for i in range(num_prompts):\n            prompt = tokenizer.decode(\n                [\n                    (offsets[i] + i + j) % tokenizer.vocab_size\n                    for j in range(input_lens[i])\n                ]\n            )\n            input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))])\n\n    print(f\"#Input tokens: {np.sum(input_lens)}\")\n    print(f\"#Output tokens: {np.sum(output_lens)}\")\n    return input_requests\n\n\ndef sample_generated_shared_prefix_requests(\n    num_groups: int,\n    prompts_per_group: int,\n    system_prompt_len: int,\n    question_len: int,\n    output_len: int,\n    tokenizer: PreTrainedTokenizerBase,\n    args,\n    disable_shuffle: bool = False,\n) -> SampleOutput:\n    \"\"\"Generate benchmark requests with shared system prompts using random tokens and caching.\"\"\"\n    cache_path = get_gen_prefix_cache_path(\n        args.seed,\n        num_groups,\n        prompts_per_group,\n        system_prompt_len,\n        question_len,\n        output_len,\n        tokenizer,\n    )\n\n    # Try to load from cache first\n    if cache_path.exists():\n        print(f\"\\nLoading cached generated input data from {cache_path}\")\n        with open(cache_path, \"rb\") as f:\n            return pickle.load(f)\n\n    print(\"\\nGenerating new input data...\")\n\n    # Generate system prompts for each group\n    system_prompts = []\n    for _ in range(num_groups):\n        system_prompt = gen_prompt(tokenizer, system_prompt_len)\n        system_prompts.append(system_prompt)\n\n    # Generate questions\n    questions = []\n    for _ in range(num_groups * prompts_per_group):\n        question = gen_prompt(tokenizer, question_len)\n        questions.append(question)\n\n    # Combine system prompts with questions\n    input_requests = []\n    total_input_tokens = 
0\n    total_output_tokens = 0\n\n    for group_idx in tqdm(range(num_groups), desc=\"Generating system prompt\"):\n        system_prompt = system_prompts[group_idx]\n        input_requests.append([])\n        for prompt_idx in tqdm(\n            range(prompts_per_group), desc=\"Generating questions\", leave=False\n        ):\n            question = questions[group_idx * prompts_per_group + prompt_idx]\n            full_prompt = f\"{system_prompt}\\n\\n{question}\"\n            prompt_len = len(tokenizer.encode(full_prompt))\n            input_requests[-1].append((full_prompt, prompt_len, output_len))\n            total_input_tokens += prompt_len\n            total_output_tokens += output_len\n\n    if not disable_shuffle:\n        # Shuffle questions\n        random.shuffle(input_requests)\n\n    # Print statistics\n    print(f\"\\nGenerated shared prefix dataset statistics:\")\n    print(f\"Number of groups: {num_groups}\")\n    print(f\"Prompts per group: {prompts_per_group}\")\n    print(f\"Total prompts: {len(input_requests) * prompts_per_group}\")\n    print(f\"Total input tokens: {total_input_tokens}\")\n    print(f\"Total output tokens: {total_output_tokens}\")\n    print(\n        f\"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens\"\n    )\n    print(\n        f\"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\\n\"\n    )\n\n    # Save to cache\n    cache_path.parent.mkdir(parents=True, exist_ok=True)\n    print(f\"Caching generated input data to {cache_path}\")\n    with open(cache_path, \"wb\") as f:\n        pickle.dump(input_requests, f)\n\n    return input_requests\n\n\ndef get_dataset(args, tokenizer):\n    if args.dataset_name == \"sharegpt\":\n        input_requests = sample_sharegpt_requests(\n            dataset_path=args.dataset_path,\n            num_requests=args.num_prompts,\n            tokenizer=tokenizer,\n       
     disable_shuffle=args.disable_shuffle,\n            enable_multiturn=args.enable_multiturn,\n            fixed_output_len=args.fixed_output_len,\n        )\n    elif args.dataset_name == \"ultrachat\":\n        input_requests = sample_ultrachat_requests(\n            dataset_path=args.dataset_path,\n            num_requests=args.num_prompts,\n            tokenizer=tokenizer,\n            disable_shuffle=args.disable_shuffle,\n            enable_multiturn=args.enable_multiturn,\n            fixed_output_len=args.fixed_output_len,\n        )\n    elif args.dataset_name == \"loogle\":\n        input_requests = sample_loogle_requests(\n            dataset_path=args.dataset_path,\n            num_requests=args.num_prompts,\n            tokenizer=tokenizer,\n            disable_shuffle=args.disable_shuffle,\n            enable_multiturn=args.enable_multiturn,\n            enable_shared_prefix=args.enable_shared_prefix,\n            fixed_output_len=args.fixed_output_len,\n        )\n    elif args.dataset_name == \"nextqa\":\n        input_requests = sample_nextqa_requests(\n            dataset_path=args.dataset_path,\n            num_requests=args.num_prompts,\n            tokenizer=tokenizer,\n            max_frames=args.max_frames,\n            model_path=args.model,\n            disable_shuffle=args.disable_shuffle,\n            enable_multiturn=args.enable_multiturn,\n            backend=args.backend,\n            chat_template_name=args.chat_template,\n            fixed_output_len=args.fixed_output_len,\n        )\n    elif args.dataset_name == \"random\":\n        input_requests = sample_random_requests(\n            input_len=args.random_input_len,\n            output_len=args.random_output_len,\n            num_prompts=args.num_prompts,\n            range_ratio=args.random_range_ratio,\n            tokenizer=tokenizer,\n            dataset_path=args.dataset_path,\n        )\n    elif args.dataset_name == \"generated-shared-prefix\":\n        input_requests = 
sample_generated_shared_prefix_requests(\n            num_groups=args.gsp_num_groups,\n            prompts_per_group=args.gsp_prompts_per_group,\n            system_prompt_len=args.gsp_system_prompt_len,\n            question_len=args.gsp_question_len,\n            output_len=args.gsp_output_len,\n            args=args,\n            tokenizer=tokenizer,\n        )\n    else:\n        raise ValueError(f\"Unknown dataset: {args.dataset_name}\")\n    return input_requests\n"
  },
  {
    "path": "benchmark/hicache/download.sh",
    "content": "#!/usr/bin/bash\n\n# The usage function\nusage() {\n    echo \"Usage: $0 {sharegpt|ultragpt|loogle|nextqa|all}\"\n    exit 1\n}\n\n# The download function\ndownload() {\n    case \"$1\" in\n        sharegpt)\n            echo $1\n            wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json\n            ;;\n        ultragpt)\n            echo $1\n            # Questions about the world\n            wget https://cloud.tsinghua.edu.cn/seafhttp/files/be1d7b87-22ca-449e-a6a7-c61d1ea7e010/ultrachat_release_230407.json\n            # Writing and Creation\n            wget https://cloud.tsinghua.edu.cn/seafhttp/files/61742d2a-25e2-4d08-b2b9-15f47ae50ace/ultrachat_material_release_230417.json\n            wget https://cloud.tsinghua.edu.cn/seafhttp/files/f71f6aa6-d346-4b16-85b7-8502efa3d608/ultrachat_material_release_230412.json\n            # External materials\n            wget https://cloud.tsinghua.edu.cn/seafhttp/files/42d22e28-e899-4975-a70f-5eda163e265d/ultrachat_existent_material_release_230420.json.gz\n            gunzip ultrachat_existent_material_release_230420.json.gz\n            ;;\n        loogle)\n            echo $1\n            git lfs install\n            git clone git@hf.co:datasets/bigainlco/LooGLE\n            unzip LooGLE/data.zip\n            ;;\n        nextqa)\n            echo $1\n            git lfs install\n            git clone https://huggingface.co/datasets/lmms-lab/NExTQA\n            unzip NExTQA/videos.zip\n            ;;\n        *)\n            usage\n            exit 1\n            ;;\n    esac\n}\n\n# Arg check\nif [ \"$#\" -ne 1 ]; then\n    usage\nfi\n\n# Invoke\n\ncase \"$1\" in\n    sharegpt|ultragpt|loogle|nextqa)\n        download \"$1\"\n        ;;\n    all)\n        download sharegpt\n        download ultragpt\n        download loogle\n        download nextqa\n        ;;\n    *)\n        usage\n        ;;\nesac\n"
  },
  {
    "path": "benchmark/hicache/nextqa.py",
    "content": "import os\nimport sys\nfrom typing import List\n\nimport av\nfrom datasets import load_dataset\n\n\ndef find_video_files(video_dir) -> List[str]:\n    if os.path.isfile(video_dir):\n        return [video_dir]\n\n    video_files = []\n    for root, dirs, files in os.walk(video_dir):\n        for file in files:\n            if file.endswith((\".mp4\", \".avi\", \".mov\")):\n                video_files.append(os.path.join(root, file))\n            # if file is dir\n            elif os.path.isdir(file):\n                video_files.extend(find_video_files(file))\n    return video_files\n\n\ndef video_frames(video_path, max_frames) -> int:\n    container = av.open(video_path)\n    total_frames = container.streams.video[0].frames\n    return min(total_frames, max_frames)\n\n\nclass Video:\n    def __init__(self, video_path, num_frames):\n        self.path = video_path\n        self.num_frames = num_frames\n\n    def __str__(self):\n        return f\"Video({self.path}, {self.num_frames})\"\n\n    def __iter__(self):\n        return iter((self.path, self.num_frames))\n\n\nclass VideoPrompt(Video):\n    def __init__(self, video_path, num_frames, prompt):\n        super().__init__(video_path, num_frames)\n        self.prompt = prompt\n\n    def __str__(self):\n        return f\"VideoPrompt({self.path}, {self.num_frames}, {self.prompt})\"\n\n    def __iter__(self):\n        return iter((self.path, self.num_frames, self.prompt))\n\n\nclass VideoLoader:\n    pass\n\n\nclass VideoFileLoader(VideoLoader):\n    \"\"\"\n    Load all the videos in a directory\n    \"\"\"\n\n    def __init__(self, video_dir, batch_size=1, max_frames=sys.maxsize):\n        super().__init__()\n        self.video_dir = video_dir\n        self.video_files = find_video_files(video_dir)\n        self.batch_size = batch_size\n        self.max_frames = max_frames\n        print(f\"batch_size: {batch_size}, max_frames: {max_frames}\")\n\n    def __iter__(self):  # (file, number of frames)\n    
    if self.batch_size == 1:\n            for video_file in self.video_files:\n                yield Video(video_file, video_frames(video_file, self.max_frames))\n        else:\n            batch = []\n            for video_file in self.video_files:\n                video = Video(video_file, video_frames(video_file, self.max_frames))\n                batch.append(video)\n                if len(batch) == self.batch_size:\n                    yield batch\n                    batch = []\n\n\nclass NExTQALoader(VideoLoader):\n    \"\"\"\n    Load videos and prompts from the NExT-QA dataset\n    dset: train, test or validation\n    \"\"\"\n\n    def __init__(\n        self, video_dir, batch_size=1, max_frames=sys.maxsize, dset=\"test\", task=\"OE\"\n    ):\n        \"\"\"\n        task: 'MC' (multiple choice) or 'OE' (open-ended)\n        \"\"\"\n        super().__init__()\n        self.task = task\n        print(f\"Loading the {dset} data of {task} from lmms-lab/NExTQA\")\n        self.ds = load_dataset(\"lmms-lab/NExTQA\", task)\n        self.ds = self.ds[dset]\n\n        # self.n = ds.num_rows\n        self.video_dir = video_dir\n        self.video_files = find_video_files(video_dir)\n        self.video_to_path = dict()\n        for video_file in self.video_files:\n            video_id = video_file.split(\"/\")[-1].split(\".\")[0]\n            self.video_to_path[video_id] = video_file\n\n        self.batch_size = batch_size\n        self.max_frames = max_frames\n\n    def get_video_prompt(self, entry, max_frames) -> VideoPrompt:\n        # Get video\n        video_id = entry[\"video\"]\n        video_path = self.video_to_path[video_id]\n        assert os.path.exists(video_path), f\"Video not found: {video_path}\"\n        num_frames = min(entry[\"frame_count\"], max_frames)\n        video = Video(video_path, num_frames)\n        prompt = entry[\"question\"] + \"?\"\n        if self.task == \"MC\":  # add choices\n            prompt += f' a0: {entry[\"a0\"]}, a1: {entry[\"a1\"]}, a2: 
{entry[\"a2\"]}, a3: 
{entry[\"a3\"]}'\n        return VideoPrompt(video_path, num_frames, prompt)\n\n    def __iter__(self):\n        if self.batch_size == 1:\n            for entry in self.ds:\n                yield self.get_video_prompt(entry, self.max_frames)\n        else:\n            batch = []\n            for entry in self.ds:\n                video = self.get_video_prompt(entry, self.max_frames)\n                batch.append(video)\n                if len(batch) == self.batch_size:\n                    yield batch\n                    batch = []\n\n\n# main\nif __name__ == \"__main__\":\n    video_dir = \"./videos\"\n    # video_loader = VideoFileLoader(video_dir, batch_size=16)\n    # for batch in video_loader:\n    #     print(f\"Number of videos in batch: {len(batch)}\")\n    #     for video_file, num_frames in batch:\n    #         print(f\"Video: {video_file} number of frames: {num_frames}\")\n\n    video_loader = NExTQALoader(video_dir, batch_size=16, dset=\"test\", task=\"OE\")\n    for batch in video_loader:\n        print(f\"Number of videos in batch: {len(batch)}\")\n        for video_file, num_frames, prompt in batch:\n            print(\n                f\"Video: {video_file} number of frames: {num_frames}, prompt: {prompt}\"\n            )\n        # break\n        # for video_file, prompt in batch:\n        #     print(f\"Video: {video_file} prompt: {prompt}\")\n        #     break\n"
  },
  {
    "path": "benchmark/hicache/perf.py",
    "content": "from __future__ import annotations\n\nfrom typing import Any, Callable, NamedTuple\n\nimport torch\n\n\ndef jit_hicache_impl(\n    k_cache_dst: torch.Tensor,\n    v_cache_dst: torch.Tensor,\n    indices_dst: torch.Tensor,\n    k_cache_src: torch.Tensor,\n    v_cache_src: torch.Tensor,\n    indices_src: torch.Tensor,\n    item_bytes: int,\n    block_quota: int,\n) -> None:\n    from sglang.jit_kernel.hicache import transfer_hicache_one_layer\n\n    _ = item_bytes\n\n    transfer_hicache_one_layer(\n        k_cache_dst=k_cache_dst,\n        v_cache_dst=v_cache_dst,\n        indices_dst=indices_dst,\n        k_cache_src=k_cache_src,\n        v_cache_src=v_cache_src,\n        indices_src=indices_src,\n        block_quota=block_quota,\n    )\n\n\ndef ref_hicache_impl(\n    k_cache_dst: torch.Tensor,\n    v_cache_dst: torch.Tensor,\n    indices_dst: torch.Tensor,\n    k_cache_src: torch.Tensor,\n    v_cache_src: torch.Tensor,\n    indices_src: torch.Tensor,\n    item_bytes: int,\n    block_quota: int,\n) -> None:\n    from sgl_kernel import transfer_kv_per_layer\n\n    transfer_kv_per_layer(\n        src_k=k_cache_src,\n        src_v=v_cache_src,\n        dst_k=k_cache_dst,\n        dst_v=v_cache_dst,\n        src_indices=indices_src,\n        dst_indices=indices_dst,\n        item_size=item_bytes,\n        block_quota=block_quota,\n    )\n\n\nclass HicacheBenchArgs(NamedTuple):\n    cache_item_size: int\n    dtype: torch.dtype\n    block_quota: int\n\n\ndef perf(f: Callable[[], Any], loop: int = 100) -> float:\n    tic = torch.cuda.Event(enable_timing=True)\n    toc = torch.cuda.Event(enable_timing=True)\n    torch.cuda.synchronize()\n    # warm up\n    f()\n    torch.cuda._sleep(10**8)\n    tic.record()\n    for _ in range(loop):\n        f()\n    toc.record()\n    toc.synchronize()\n    return tic.elapsed_time(toc) / loop\n\n\n@torch.inference_mode()\ndef test_hicache_kernel(args: HicacheBenchArgs) -> None:\n    CACHE_ITEM_SIZE, DTYPE, BLOCK_QUOTA = 
args\n\n    CUDA_CACHE_SIZE = 1024 * 1024\n    HOST_CACHE_SIZE = CUDA_CACHE_SIZE * 2\n\n    cuda_cache = torch.randn(\n        (2, CUDA_CACHE_SIZE, CACHE_ITEM_SIZE),\n        dtype=DTYPE,\n        device=\"cuda\",\n    )\n    host_cache = torch.empty(\n        (2, HOST_CACHE_SIZE, CACHE_ITEM_SIZE),\n        dtype=DTYPE,\n        device=\"cpu\",\n        pin_memory=True,\n    )\n\n    ITEM_BYTES = cuda_cache.element_size() * CACHE_ITEM_SIZE\n\n    def _gen_indices(size: int, bs: int) -> torch.Tensor:\n        assert bs <= size\n        result = (\n            (torch.randperm(size, dtype=torch.int64, device=\"cuda\")[:bs]).sort().values\n        )\n        if not (torch.all(result >= 0) and torch.all(result < size)):\n            where = (result < 0) | (result >= size)\n            place = where.nonzero(as_tuple=False)\n            print(\"Invalid indices at positions:\", place)\n            print(\"Invalid indices values:\", result[place])\n            raise ValueError(\"Generated invalid indices\")\n        return result\n\n    def _calc_tput(dur: float) -> float:\n        return (MEM / (1024**3)) / (dur / 1000)  # GB/s\n\n    def _gain_str(aot_dur: float, jit_dur: float) -> str:\n        gain = 100 * (aot_dur / jit_dur - 1)\n        if gain >= 0:\n            return f\"+{gain:>6.2f}%\"\n        else:\n            return f\"-{-gain:>6.2f}%\"\n\n    print(f\"{CACHE_ITEM_SIZE = }, {DTYPE = }, {BLOCK_QUOTA = }\")\n\n    def _fast_test_correctness(bs: int):\n        src_indices = _gen_indices(CUDA_CACHE_SIZE, bs)\n        dst_indices = _gen_indices(HOST_CACHE_SIZE, bs)\n        host_cache_cuda = torch.randn_like(host_cache, device=\"cuda\")\n        host_cache.copy_(host_cache_cuda, non_blocking=True)\n\n        # copy from cuda to host\n        jit_hicache_impl(\n            k_cache_dst=host_cache[0],\n            v_cache_dst=host_cache[1],\n            indices_dst=dst_indices,\n            k_cache_src=cuda_cache[0],\n            v_cache_src=cuda_cache[1],\n           
 indices_src=src_indices,\n            item_bytes=ITEM_BYTES,\n            block_quota=BLOCK_QUOTA,\n        )\n        dst_indices = dst_indices.cpu()\n        assert torch.all(\n            host_cache[0][dst_indices].cuda() == cuda_cache[0][src_indices]\n        )\n\n    BS_RANGE = [2**n for n in range(8, 18)]\n    for bs in BS_RANGE:\n        _fast_test_correctness(bs)\n\n    print(\"Correctness passed! Start HiCache kernel performance test...\")\n    print(\"=\" * 70)\n\n    for bs in BS_RANGE:\n        indices_dst = _gen_indices(CUDA_CACHE_SIZE, bs)\n        indices_src = _gen_indices(HOST_CACHE_SIZE, bs)\n        MEM = 2 * bs * ITEM_BYTES\n\n        def _run_kernel_h2d(impl):\n            return impl(\n                k_cache_dst=cuda_cache[0],\n                v_cache_dst=cuda_cache[1],\n                indices_dst=indices_dst,\n                k_cache_src=host_cache[0],\n                v_cache_src=host_cache[1],\n                indices_src=indices_src,\n                item_bytes=ITEM_BYTES,\n                block_quota=BLOCK_QUOTA,\n            )\n\n        our_h2d_dur = perf(lambda: _run_kernel_h2d(jit_hicache_impl))\n        ref_h2d_dur = perf(lambda: _run_kernel_h2d(ref_hicache_impl))\n        print(\n            f\"{bs = :6d}, H->D\",\n            f\"| aot {_calc_tput(ref_h2d_dur):<6.2f} GB/s\",\n            f\"| jit {_calc_tput(our_h2d_dur):<6.2f} GB/s\",\n            f\"| {_gain_str(ref_h2d_dur, our_h2d_dur)}\",\n        )\n\n    print(\"=\" * 70)\n\n    for bs in BS_RANGE:\n        indices_dst = _gen_indices(HOST_CACHE_SIZE, bs)\n        indices_src = _gen_indices(CUDA_CACHE_SIZE, bs)\n        MEM = 2 * bs * ITEM_BYTES\n\n        def _run_kernel_d2h(impl):\n            return impl(\n                k_cache_dst=host_cache[0],\n                v_cache_dst=host_cache[1],\n                indices_dst=indices_dst,\n                k_cache_src=cuda_cache[0],\n                v_cache_src=cuda_cache[1],\n                indices_src=indices_src,\n          
      item_bytes=ITEM_BYTES,\n                block_quota=BLOCK_QUOTA,\n            )\n\n        our_d2h_dur = perf(lambda: _run_kernel_d2h(jit_hicache_impl))\n        ref_d2h_dur = perf(lambda: _run_kernel_d2h(ref_hicache_impl))\n        print(\n            f\"{bs = :6d}, D->H\",\n            f\"| aot {_calc_tput(ref_d2h_dur):<6.2f} GB/s\",\n            f\"| jit {_calc_tput(our_d2h_dur):<6.2f} GB/s\",\n            f\"| {_gain_str(ref_d2h_dur, our_d2h_dur)}\",\n        )\n\n    print(\"=\" * 70)\n\n\ndef main() -> None:\n    torch.cuda.set_device(0)\n    stream = torch.cuda.Stream()\n    torch.cuda.set_stream(stream)\n\n    tic = torch.cuda.Event(enable_timing=True)\n    toc = torch.cuda.Event(enable_timing=True)\n\n    BUF_SIZE = 1024 * 1024 * 1024\n    cuda_mem = torch.empty(BUF_SIZE, dtype=torch.uint8, device=\"cuda\")\n    host_mem = torch.empty(BUF_SIZE, dtype=torch.uint8, device=\"cpu\", pin_memory=True)\n\n    # test peak bandwidth\n    tic.record()\n    cuda_mem.copy_(host_mem, non_blocking=True)\n    toc.record()\n    toc.synchronize()\n    dur = tic.elapsed_time(toc)\n    print(f\"Peak H->D Bandwidth: {(BUF_SIZE / (1024**3)) / (dur / 1000):.2f} GB/s\")\n\n    tic.record()\n    host_mem.copy_(cuda_mem, non_blocking=True)\n    toc.record()\n    toc.synchronize()\n    dur = tic.elapsed_time(toc)\n    print(f\"Peak D->H Bandwidth: {(BUF_SIZE / (1024**3)) / (dur / 1000):.2f} GB/s\")\n\n    for block_quota in [1, 2, 3, 4]:\n        for cache_item_size in [128, 256, 512, 1024]:\n            args = HicacheBenchArgs(\n                cache_item_size=cache_item_size,\n                dtype=torch.float16,\n                block_quota=block_quota,\n            )\n            test_hicache_kernel(args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/json_decode_regex/README.md",
    "content": "## Run benchmark\n\n### Build dataset\n```\npip install wikipedia\npython3 build_dataset.py\n```\n\n### Dependencies\n\n```\nllama_cpp_python          0.2.19\nguidance                  0.1.10\nvllm                      0.2.5\noutlines                  0.0.22\n```\n\n### Benchmark sglang\n\nRun Llama-7B\n\n```\npython3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\nRun Mixtral-8x7B\n\n```\npython3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8\n```\n\nBenchmark\n\n```\npython3 bench_sglang.py --num-questions 10\n```\n\n\n### Benchmark Outlines + vLLM\n\nRun Llama-7B\n\n```\npython3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf  --disable-log-requests --port 21000\n```\n\nBenchmark\n\n```\npython3 bench_other.py --backend outlines --num-questions 10\n```\n\n\n### Benchmark guidance\n\nRun Llama-7B and benchmark\n\n```\npython3 bench_other.py --backend guidance --num-questions 10 --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n"
  },
  {
    "path": "benchmark/json_decode_regex/bench_other.py",
    "content": "import argparse\nimport json\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\nfrom functools import partial\n\nfrom tqdm import tqdm\n\nfrom sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import dump_state_text, read_jsonl\n\nREGEX_LIST = r\"\\[(\" + REGEX_STR + \", )*\" + REGEX_STR + r\"\\]\"\n\n\n# fmt: off\ndef json_decode(document, generate):\n    s = \"Please extract the information of a city from the following wikipedia page.\\n\"\n    s += \"Page begin.\\n\" + document + \"Page end.\\n\"\n    s += \"Here is the name, country, and symbol of the city in JSON format.\\n\"\n    s += \"{\\n\"\n    s += '  \"name\": '\n    s += generate(s, max_tokens=8, regex=REGEX_STR + \",\") + \"\\n\"\n    s += '  \"country\": '\n    s += generate(s, max_tokens=8, regex=REGEX_STR + \",\") + \"\\n\"\n    s += '  \"latitude\": '\n    s += generate(s, max_tokens=8, regex=REGEX_FLOAT + \",\") + \"\\n\"\n    s += '  \"population\": '\n    s += generate(s, max_tokens=8, regex=REGEX_INT + \",\") + \"\\n\"\n    s += '  \"top 3 landmarks\": '\n    s += generate(s, max_tokens=24, regex=REGEX_LIST) + \"\\n\"\n    s += \"}\\n\"\n\n    return s\n# fmt: on\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)\n    arguments = []\n    for i in range(len(lines[: args.num_questions])):\n        arguments.append(\n            {\n                \"document\": lines[i][\"document\"],\n            }\n        )\n    states = [None] * len(arguments)\n\n    # Select backend\n    call_generate = partial(get_call_generate(args), temperature=0)\n\n    # Run requests\n    def get_one_answer(i):\n        states[i] = json_decode(generate=call_generate, **arguments[i])\n\n    tic = time.perf_counter()\n    if args.parallel == 1:\n        for i in tqdm(range(len(arguments))):\n            get_one_answer(i)\n    else:\n        with 
ThreadPoolExecutor(args.parallel) as executor:\n            rets = list(\n                tqdm(\n                    executor.map(get_one_answer, list(range(len(arguments)))),\n                    total=len(arguments),\n                )\n            )\n            for _ in rets:\n                pass\n\n    latency = time.perf_counter() - tic\n\n    # Compute accuracy\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"json_decode_regex\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"questions.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=20)\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/json_decode_regex/bench_sglang.py",
    "content": "import argparse\nimport json\nimport time\n\nimport sglang as sgl\nfrom sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text, read_jsonl\n\nREGEX_LIST = r\"\\[(\" + REGEX_STR + \", )*\" + REGEX_STR + r\"\\]\"\n\n# fmt: off\n@sgl.function\ndef json_warm_up(s):\n    s += \"The information about Hogwarts is in the following JSON format.\\n\"\n    with s.var_scope(\"json_output\"):\n        s += \"{\\n\"\n        s += '  \"name\": ' + sgl.gen(\"name\", max_tokens=8, regex=REGEX_STR + \",\") + \"\\n\"\n        s += '  \"country\": ' + sgl.gen(\"country\", max_tokens=8, regex=REGEX_STR + \",\") + \"\\n\"\n        s += '  \"latitude\": ' + sgl.gen(\"latitude\", max_tokens=8, regex=REGEX_FLOAT + \",\") + \"\\n\"\n        s += '  \"population\": ' + sgl.gen(\"population\", max_tokens=8, regex=REGEX_INT + \",\") + \"\\n\"\n        s += '  \"top 3 landmarks\": ' + sgl.gen( \"landmarks\", max_tokens=24, regex=REGEX_LIST) + \"\\n\"\n        s += \"}\\n\"\n    print(f'The warmp up json result is:\\n{s[\"json_output\"]}')\n# fmt: on\n\n# fmt: off\n@sgl.function\ndef json_decode(s, document):\n    s += \"Please extract the information of a city from the following wikipedia page.\\n\"\n    s += \"Page begin.\\n\" + document + \"Page end.\\n\"\n    s += \"Here is the name, country, and symbol of the city in JSON format.\\n\"\n    with s.var_scope(\"json_output\"):\n        s += \"{\\n\"\n        s += '  \"name\": ' + sgl.gen(\"name\", max_tokens=8, regex=REGEX_STR + \",\") + \"\\n\"\n        s += '  \"country\": ' + sgl.gen(\"country\", max_tokens=8, regex=REGEX_STR + \",\") + \"\\n\"\n        s += '  \"latitude\": ' + sgl.gen(\"latitude\", max_tokens=8, regex=REGEX_FLOAT + \",\") + \"\\n\"\n        s += '  \"population\": ' + sgl.gen(\"population\", max_tokens=8, regex=REGEX_INT + \",\") + \"\\n\"\n        s += '  
\"top 3 landmarks\": ' + sgl.gen( \"landmarks\", max_tokens=24, regex=REGEX_LIST) + \"\\n\"\n        s += \"}\\n\"\n# fmt: on\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)\n    lines = list(lines)\n    arguments = []\n    for i in range(len(lines[: args.num_questions])):\n        arguments.append(\n            {\n                \"document\": lines[i][\"document\"],\n            }\n        )\n\n    # Select backend\n    backend = select_sglang_backend(args)\n    sgl.set_default_backend(backend)\n\n    # Warm up\n    json_warm_up.run().sync()\n\n    # Run requests\n    tic = time.perf_counter()\n    states = json_decode.run_batch(\n        arguments, temperature=0, num_threads=args.parallel, progress_bar=True\n    )\n    latency = time.perf_counter() - tic\n\n    # Compute accuracy\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(f\"tmp_{args.backend}_json_results.txt\", \"w\") as fout:\n        for state in states:\n            fout.write(state[\"json_output\"] + \"\\n\")\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"json_decode_regex\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"questions.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=20)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/json_decode_regex/build_dataset.py",
    "content": "import json\n\nimport transformers\nimport wikipedia\n\nmodel_path = \"meta-llama/Llama-2-7b-chat-hf\"\nt = transformers.AutoTokenizer.from_pretrained(model_path)\ncity_names = [\n    \"los angeles\",\n    \"london\",\n    \"tokyo\",\n    \"beijing\",\n    \"singapore\",\n    \"paris\",\n    \"dubai\",\n    \"sydney\",\n    \"moscow\",\n    \"rome\",\n    \"toronto\",\n    \"rio de janeiro\",\n    \"istanbul\",\n    \"berlin\",\n    \"auckland\",\n    \"buenos aires\",\n    \"mexico city\",\n    \"mumbai\",\n    \"seoul\",\n    \"bangkok\",\n    \"cairo\",\n    \"athens\",\n    \"jerusalem\",\n]\n\n\ndef get_content(city_name):\n    content = str(wikipedia.page(city_name).content)\n    content = content.replace(\"\\n\\n\", \"\\n\")\n\n    tokens = t.encode(content)\n\n    expected_tokens = 3000\n    truncate_len = int((expected_tokens / len(tokens)) * len(content))\n    truncate_content = content[:truncate_len]\n    truncate_tokens = t.encode(truncate_content)\n\n    # Count token\n    print(\n        f\"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}\"\n    )\n\n    return truncate_content\n\n\nif __name__ == \"__main__\":\n    with open(\"questions.jsonl\", \"w\") as fout:\n        for city_name in city_names:\n            truncate_content = get_content(city_name)\n            fout.write(json.dumps({\"document\": truncate_content}) + \"\\n\")\n"
  },
  {
    "path": "benchmark/json_jump_forward/README.md",
    "content": "## Run benchmark\n\n### Dependencies\n\n```\nllama_cpp_python          0.2.38\nguidance                  0.1.10\nvllm                      0.2.7\noutlines                  0.0.25\n```\n\n### Build dataset\n\nWhen benchmarking long document information retrieval, run the following command to build the dataset:\n\n```bash\npip install wikipedia\npython3 build_dataset.py\n```\n\n### Benchmark sglang\n\nRun Llama-7B\n\n```bash\npython3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\nBenchmark Character Generation\n\n```bash\npython3 bench_sglang.py --mode character\n```\n\nBenchmark City Information Retrieval\n\n```bash\npython3 bench_sglang.py --mode city\n```\n\n\n### Benchmark Outlines + vLLM\n\nRun Llama-7B\n\n```bash\npython3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf  --disable-log-requests --port 21000\n```\n\nBenchmark Character Generation\n\n```bash\npython3 bench_other.py --mode character --backend outlines\n```\n\nBenchmark City Information Retrieval\n\n```bash\npython3 bench_other.py --mode city --backend outlines\n```\n\n### Benchmark guidance\n\nRun Llama-7B and benchmark character generation\n\n```bash\npython3 bench_other.py --mode character --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n\nRun Llama-7B and benchmark city information retrieval\n\n```bash\npython3 bench_other.py --mode city --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n\n### Benchmark lmql\n\nRun Llama-7B and benchmark character generation\n\n```bash\npython3 bench_other.py --mode character --backend lmql --parallel 1\n```\n\nRun Llama-7B and benchmark city information retrieval\n\n```bash\npython3 bench_other.py --mode city --backend lmql --parallel 1\n```\n"
  },
  {
    "path": "benchmark/json_jump_forward/bench_other.py",
    "content": "import argparse\nimport json\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\nfrom functools import partial\n\nimport guidance\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import dump_state_text, read_jsonl\n\n# there are some FSM bugs with json regex converted from pydantic model\n# here use a string regex instead\n# regex_string = build_regex_from_object(HarryPoterRole)\ncharacter_regex = (\n    r\"\"\"\\{\\n\"\"\"\n    + r\"\"\"    \"name\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"    \"house\": \"(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)\",\\n\"\"\"\n    + r\"\"\"    \"blood status\": \"(Pure-blood|Half-blood|Muggle-born)\",\\n\"\"\"\n    + r\"\"\"    \"occupation\": \"(student|teacher|auror|ministry of magic|death eater|order of the phoenix)\",\\n\"\"\"\n    + r\"\"\"    \"wand\": \\{\\n\"\"\"\n    + r\"\"\"        \"wood\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"        \"core\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"        \"length\": [0-9]{1,2}\\.[0-9]{0,2}\\n\"\"\"\n    + r\"\"\"    \\},\\n\"\"\"\n    + r\"\"\"    \"alive\": \"(Alive|Deceased)\",\\n\"\"\"\n    + r\"\"\"    \"patronus\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"    \"bogart\": \"[\\w\\d\\s]{1,16}\"\\n\"\"\"\n    + r\"\"\"\\}\"\"\"\n)\n\ncity_regex = (\n    r\"\"\"\\{\\n\"\"\"\n    + r\"\"\"  \"name\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"  \"country\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"  \"latitude\": [-+]?[0-9]*\\.?[0-9]{0,2},\\n\"\"\"\n    + r\"\"\"  \"population\": [-+]?[0-9]{1,9},\\n\"\"\"\n    + r\"\"\"  \"top 3 landmarks\": \\[\"[\\w\\d\\s]{1,16}\", \"[\\w\\d\\s]{1,16}\", \"[\\w\\d\\s]{1,16}\"\\]\\n\"\"\"\n    + r\"\"\"\\}\"\"\"\n)\n\n# fmt: off\ndef character_gen(name, generate):\n    s = name + \" is a character in Harry Potter. 
Please fill in the following information about this character.\\n\"\n    s += generate(s, max_tokens=256, regex=character_regex)\n    return s\n# fmt: on\n\n# fmt: off\ndef city_gen(document, generate):\n    s = \"Please extract the information of a city from the following wikipedia page.\\n\"\n    s += \"Page begin.\\n\" + document + \"Page end.\\n\"\n    s += \"Here is the name, country, and symbol of the city in JSON format.\\n\"\n    s += generate(s, max_tokens=256, regex=city_regex)\n    return s\n# fmt: on\n\n\n@guidance\ndef character_maker(lm, name):\n    regex_str_no_quote = r\"[\\w\\d\\s]+\"\n    regex_float = r\"[0-9]+\\.[0-9]+\"\n    lm += f\"\"\"\\\n    {name} is a character in Harry Potter. Please fill in the following information about this character.\n    {{\n        \"name\": \"{guidance.gen(\"name\", max_tokens=16, regex=regex_str_no_quote)}\",\n        \"house\": \"{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}\",\n        \"blood status\": \"{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}\",\n        \"occupation\": \"{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}\",\n        \"wand\": {{\n            \"wood\": \"{guidance.gen(\"wood\", max_tokens=16, regex=regex_str_no_quote)}\",\n            \"core\": \"{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}\",\n            \"length\": {guidance.gen('length', max_tokens=10, regex=regex_float)}\n        }},\n        \"alive\": \"{guidance.select(options=['Alive', 'Deceased'], name='alive')}\",\n        \"patronus\": \"{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}\",\n        \"bogart\": \"{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}\"\n    }}\n    \"\"\"\n\n    return lm\n\n\nasync def call_generate_lmql(\n    prompt, temperature, max_tokens, regex, max_len=4096, 
model=None, **kwargs\n):\n    assert model is not None\n    import lmql\n\n    @lmql.query(model=model)\n    async def program(question, max_tokens, regex):\n        '''lmql\n        \"\"\"{question}[ANSWER]\"\"\" where len(TOKENS(ANSWER)) < max_tokens and REGEX(ANSWER, regex)\n        return ANSWER\n        '''\n\n    return await program(\n        question=prompt,\n        temperature=temperature,\n        max_tokens=max_tokens,\n        max_len=max_len,\n        regex=regex,\n        **kwargs,\n    )\n\n\n@guidance\ndef city_maker(lm, document):\n    regex_str_no_quote = r\"[\\w\\d\\s]+\"\n    regex_float = r\"[0-9]+\\.[0-9]+\"\n    lm += f\"\"\"\\\n    Please extract the information of a city from the following wikipedia page.\n    Page begin.\n    {document}\n    Page end.\n    Here is the name, country, and symbol of the city in JSON format.\n    {{\n        \"name\": \"{guidance.gen(\"name\", max_tokens=16, regex=regex_str_no_quote)}\",\n        \"country\": \"{guidance.gen(\"country\", max_tokens=16, regex=regex_str_no_quote)}\",\n        \"latitude\": {guidance.gen(\"latitude\", max_tokens=10, regex=regex_float)},\n        \"population\": {guidance.gen(\"population\", max_tokens=10, regex=r\"[0-9]+\")},\n        \"top 3 landmarks\": [\n            \"{guidance.gen(\"landmark1\", max_tokens=16, regex=regex_str_no_quote)}\", \"{guidance.gen(\"landmark2\", max_tokens=16, regex=regex_str_no_quote)}\", \"{guidance.gen(\"landmark3\", max_tokens=16, regex=regex_str_no_quote)}\"\n        ]\n    }}\n    \"\"\"\n\n    return lm\n\n\ndef bench_character(args):\n    arguments = []\n    with open(args.data_path, \"r\") as f:\n        for line in f:\n            arguments.append({\"name\": line.strip()})\n    arguments = arguments[: args.num_jsons]\n\n    states = [None] * len(arguments)\n\n    # Select backend\n    if args.backend == \"outlines\":\n        call_generate = partial(get_call_generate(args), temperature=0)\n\n        def get_one_answer(i):\n            
states[i] = character_gen(**arguments[i], generate=call_generate)\n\n    elif args.backend == \"guidance\":\n        model = guidance.models.LlamaCpp(\n            args.model_path,\n            n_gpu_layers=-1,\n            n_ctx=args.n_ctx,\n        )\n\n        def get_one_answer(i):\n            lm = model + character_maker(**arguments[i])\n            states[i] = lm\n\n    elif args.backend == \"lmql\":\n        import asyncio\n\n        import lmql\n\n        model = lmql.model(args.model_path, endpoint=f\"{args.host}:{args.port}\")\n        call_generate = partial(\n            call_generate_lmql,\n            model=model,\n            max_tokens=256,\n            regex=character_regex,\n        )\n\n        async def get_one_answer_async(i):\n            states[i] = await call_generate(prompt=arguments[i][\"name\"], temperature=0)\n\n    else:\n        raise ValueError(f\"Invalid backend: {args.backend}\")\n\n    tic = time.perf_counter()\n\n    if args.backend != \"lmql\":\n        if args.parallel == 1:\n            for i in tqdm(range(len(arguments))):\n                get_one_answer(i)\n        else:\n            with ThreadPoolExecutor(args.parallel) as executor:\n                rets = list(\n                    tqdm(\n                        executor.map(get_one_answer, list(range(len(arguments)))),\n                        total=len(arguments),\n                    )\n                )\n                for _ in rets:\n                    pass\n    else:\n        batches = []\n        for i in range(0, len(arguments), args.parallel):\n            batches.append(list(range(i, min(i + args.parallel, len(arguments)))))\n        loop = asyncio.get_event_loop()\n\n        for bt in tqdm(batches):\n            loop.run_until_complete(\n                asyncio.gather(*[get_one_answer_async(i) for i in bt])\n            )\n\n    latency = time.perf_counter() - tic\n\n    return states, latency\n\n\ndef bench_city_doc(args):\n    arguments = []\n    for line 
in read_jsonl(args.data_path):\n        arguments.append({\"document\": line[\"document\"]})\n    arguments = arguments[: args.num_jsons]\n\n    states = [None] * len(arguments)\n\n    # Select backend\n    if args.backend == \"outlines\":\n        call_generate = partial(get_call_generate(args), temperature=0)\n\n        def get_one_answer(i):\n            states[i] = city_gen(**arguments[i], generate=call_generate)\n\n    elif args.backend == \"guidance\":\n        model = guidance.models.LlamaCpp(\n            args.model_path,\n            n_gpu_layers=-1,\n            n_ctx=args.n_ctx,\n        )\n\n        def get_one_answer(i):\n            lm = model + city_maker(**arguments[i])\n            states[i] = lm\n\n    else:\n        raise ValueError(f\"Invalid backend: {args.backend}\")\n\n    tic = time.perf_counter()\n    if args.parallel == 1:\n        for i in tqdm(range(len(arguments))):\n            get_one_answer(i)\n    else:\n        with ThreadPoolExecutor(args.parallel) as executor:\n            rets = executor.map(get_one_answer, list(range(len(arguments))))\n            for _ in rets:\n                pass\n\n    latency = time.perf_counter() - tic\n\n    return states, latency\n\n\ndef main(args):\n    if args.mode == \"character\":\n        args.data_path = \"dataset.txt\"\n        states, latency = bench_character(args)\n    elif args.mode == \"city\":\n        args.data_path = \"questions.jsonl\"\n        states, latency = bench_city_doc(args)\n\n    # Compute accuracy\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}_{args.mode}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"json_jump_forward\",\n            \"backend\": args.backend,\n            \"latency\": round(latency, 3),\n            \"num_jsons\": args.num_jsons,\n            \"mode\": args.mode,\n            \"parallel\": args.parallel,\n        }\n        
fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str)\n    parser.add_argument(\"--num-jsons\", type=int, default=50)\n    parser.add_argument(\n        \"--mode\", type=str, default=\"character\", choices=[\"character\", \"city\"]\n    )\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/json_jump_forward/bench_sglang.py",
    "content": "import argparse\nimport json\nimport time\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text, read_jsonl\n\n# there are some FSM bugs with json regex converted from pydantic model\n# here use a string regex instead\n# regex_string = build_regex_from_object(HarryPoterRole)\ncharacter_regex = (\n    r\"\"\"\\{\\n\"\"\"\n    + r\"\"\"    \"name\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"    \"house\": \"(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)\",\\n\"\"\"\n    + r\"\"\"    \"blood status\": \"(Pure-blood|Half-blood|Muggle-born)\",\\n\"\"\"\n    + r\"\"\"    \"occupation\": \"(student|teacher|auror|ministry of magic|death eater|order of the phoenix)\",\\n\"\"\"\n    + r\"\"\"    \"wand\": \\{\\n\"\"\"\n    + r\"\"\"        \"wood\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"        \"core\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"        \"length\": [0-9]{1,2}\\.[0-9]{0,2}\\n\"\"\"\n    + r\"\"\"    \\},\\n\"\"\"\n    + r\"\"\"    \"alive\": \"(Alive|Deceased)\",\\n\"\"\"\n    + r\"\"\"    \"patronus\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"    \"bogart\": \"[\\w\\d\\s]{1,16}\"\\n\"\"\"\n    + r\"\"\"\\}\"\"\"\n)\n\ncity_regex = (\n    r\"\"\"\\{\\n\"\"\"\n    + r\"\"\"  \"name\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"  \"country\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"  \"latitude\": [-+]?[0-9]*\\.?[0-9]{0,2},\\n\"\"\"\n    + r\"\"\"  \"population\": [-+]?[0-9]{1,9},\\n\"\"\"\n    + r\"\"\"  \"top 3 landmarks\": \\[\"[\\w\\d\\s]{1,16}\", \"[\\w\\d\\s]{1,16}\", \"[\\w\\d\\s]{1,16}\"\\]\\n\"\"\"\n    + r\"\"\"\\}\"\"\"\n)\n\n# fmt: off\n@sgl.function\ndef character_gen(s, name):\n    s += name + \" is a character in Harry Potter. 
Please fill in the following information about this character.\\n\"\n    s += sgl.gen(\"json_output\", max_tokens=256, regex=character_regex)\n# fmt: on\n\n# fmt: off\n@sgl.function\ndef city_gen(s, document):\n    s += \"Please extract the information of a city from the following wikipedia page.\\n\"\n    s += \"Page begin.\\n\" + document + \"Page end.\\n\"\n    s += \"Here is the name, country, and symbol of the city in JSON format.\\n\"\n    s += sgl.gen(\"json_output\",max_tokens=256, regex=city_regex)\n# fmt: on\n\n\ndef bench_city_doc(args):\n    arguments = []\n    for line in read_jsonl(args.data_path):\n        arguments.append({\"document\": line[\"document\"]})\n    arguments = arguments[: args.num_jsons]\n\n    # Select backend\n    backend = select_sglang_backend(args)\n    sgl.set_default_backend(backend)\n\n    # Run requests\n    tic = time.perf_counter()\n    states = city_gen.run_batch(\n        arguments,\n        temperature=0,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    latency = time.perf_counter() - tic\n\n    return states, latency\n\n\ndef bench_character(args):\n    arguments = []\n    with open(args.data_path, \"r\") as f:\n        for line in f:\n            arguments.append({\"name\": line.strip()})\n    arguments = arguments[: args.num_jsons]\n\n    # Select backend\n    backend = select_sglang_backend(args)\n    sgl.set_default_backend(backend)\n\n    # Run requests\n    tic = time.perf_counter()\n    states = character_gen.run_batch(\n        arguments,\n        temperature=0,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    latency = time.perf_counter() - tic\n\n    return states, latency\n\n\ndef main(args):\n    if args.mode == \"character\":\n        args.data_path = \"dataset.txt\"\n        states, latency = bench_character(args)\n    elif args.mode == \"city\":\n        args.data_path = \"questions.jsonl\"\n        states, latency = bench_city_doc(args)\n\n    # 
Compute accuracy\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}_{args.mode}.txt\", states)\n    with open(f\"{args.backend}_{args.mode}.json\", \"w\") as fout:\n        for state in states:\n            fout.write(state[\"json_output\"] + \"\\n\")\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"json_jump_forward\",\n            \"backend\": args.backend,\n            \"latency\": round(latency, 3),\n            \"num_jsons\": args.num_jsons,\n            \"mode\": args.mode,\n            \"parallel\": args.parallel,\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str)\n    parser.add_argument(\"--num-jsons\", type=int, default=50)\n    parser.add_argument(\n        \"--mode\", type=str, default=\"character\", choices=[\"character\", \"city\"]\n    )\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/json_jump_forward/build_dataset.py",
    "content": "import json\n\nimport transformers\nimport wikipedia\n\nmodel_path = \"meta-llama/Llama-2-7b-chat-hf\"\nt = transformers.AutoTokenizer.from_pretrained(model_path)\ncity_names = [\n    \"los angeles\",\n    \"london\",\n    \"tokyo\",\n    \"beijing\",\n    \"singapore\",\n    \"paris\",\n    \"dubai\",\n    \"sydney\",\n    \"moscow\",\n    \"rome\",\n    \"toronto\",\n    \"rio de janeiro\",\n    \"istanbul\",\n    \"berlin\",\n    \"auckland\",\n    \"buenos aires\",\n    \"mexico city\",\n    \"mumbai\",\n    \"seoul\",\n    \"bangkok\",\n    \"cairo\",\n    \"athens\",\n    \"jerusalem\",\n]\n\n\ndef get_content(city_name):\n    content = str(wikipedia.page(city_name).content)\n    content = content.replace(\"\\n\\n\", \"\\n\")\n\n    tokens = t.encode(content)\n\n    expected_tokens = 3000\n    truncate_len = int((expected_tokens / len(tokens)) * len(content))\n    truncate_content = content[:truncate_len]\n    truncate_tokens = t.encode(truncate_content)\n\n    # Count token\n    print(\n        f\"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}\"\n    )\n\n    return truncate_content\n\n\nif __name__ == \"__main__\":\n    with open(\"questions.jsonl\", \"w\") as fout:\n        for city_name in city_names:\n            truncate_content = get_content(city_name)\n            fout.write(json.dumps({\"document\": truncate_content}) + \"\\n\")\n"
  },
  {
    "path": "benchmark/json_jump_forward/dataset.txt",
    "content": "Harry Potter\nHermione Granger\nRon Weasley\nAlbus Dumbledore\nSeverus Snape\nRubeus Hagrid\nDraco Malfoy\nGinny Weasley\nFred Weasley\nGeorge Weasley\nPercy Weasley\nSirius Black\nRemus Lupin\nNeville Longbottom\nLuna Lovegood\nCedric Diggory\nCho Chang\nLord Voldemort\nMinerva McGonagall\nFilius Flitwick\nDolores Umbridge\nBellatrix Lestrange\nLucius Malfoy\nMolly Weasley\nArthur Weasley\nNymphadora Tonks\nDobby\nMoaning Myrtle\nPeter Pettigrew\nAlastor 'Mad-Eye' Moody\nHorace Slughorn\nVernon Dursley\nPetunia Dursley\nDudley Dursley\nArgus Filch\nSybill Trelawney\nGilderoy Lockhart\nFleur Delacour\nViktor Krum\nBill Weasley\nOliver Wood\nCornelius Fudge\nBarty Crouch Sr.\nBarty Crouch Jr.\nKingsley Shacklebolt\nQuirinus Quirrell\nNearly Headless Nick\nAunt Marge\nGriphook\nLudo Bagman\n"
  },
  {
    "path": "benchmark/json_schema/README.md",
    "content": "## Run benchmark\n\n### Benchmark sglang\n\nRun Llama-3.1-8B-Instruct\n\n```bash\npython3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000\n```\n\nBenchmark (uses the `NousResearch/json-mode-eval` dataset by default)\n\n```bash\npython3 bench_sglang.py\n```\n"
  },
  {
    "path": "benchmark/json_schema/bench_sglang.py",
    "content": "import argparse\nimport json\nimport time\nfrom typing import List, Tuple\n\nimport jsonschema\nfrom datasets import load_dataset\n\nimport sglang as sgl\nfrom sglang.global_config import global_config\nfrom sglang.srt.utils.hf_transformers_utils import get_tokenizer\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text\n\n\n@sgl.function\ndef schema_gen(s, message: Tuple[str, str], json_schema: str):\n    system, user = message\n    s += sgl.system(system)\n    s += sgl.user(user)\n    s += sgl.assistant(\n        sgl.gen(\"json_output\", temperature=0, max_tokens=256, json_schema=json_schema)\n    )\n\n\ndef contains_formats(schema, formats: List[str]):\n    if isinstance(schema, dict):\n        if schema.get(\"format\", None) in formats:\n            return True\n        for value in schema.values():\n            if contains_formats(value, formats):\n                return True\n    elif isinstance(schema, list):\n        for item in schema:\n            if contains_formats(item, formats):\n                return True\n    return False\n\n\ndef convert_dataset(path: str):\n    raw_dataset = load_dataset(path)\n    dataset = []\n    for data in raw_dataset[\"train\"]:\n        messages = data[\"prompt\"]\n        schema = data[\"schema\"]\n        obj = json.loads(schema)\n\n        # skip some corrupted examples\n        if obj.get(\"type\", None) is None:\n            continue\n\n        # skip schema with format \"email\"\n        # which is not supported by outlines for now\n        if contains_formats(obj, [\"email\"]):\n            continue\n\n        system = messages[0]\n        user = messages[1]\n        assert system[\"role\"] == \"system\", \"invalid role\"\n        assert user[\"role\"] == \"user\", \"invalid role\"\n        assert len(messages) == 2, \"invalid message length\"\n        message = json.dumps(system[\"content\"]), 
json.dumps(user[\"content\"])\n        dataset.append(\n            {\n                \"message\": message,\n                \"json_schema\": schema,\n            }\n        )\n\n    return dataset\n\n\ndef bench_schema(args):\n    arguments = convert_dataset(args.data_path)\n\n    if args.num_jsons < 0 or args.num_jsons > len(arguments):\n        args.num_jsons = len(arguments)\n    arguments = arguments[: args.num_jsons]\n\n    # Select backend\n    backend = select_sglang_backend(args)\n    sgl.set_default_backend(backend)\n\n    # Run requests\n    tic = time.perf_counter()\n    states = schema_gen.run_batch(\n        arguments,\n        temperature=0,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    latency = time.perf_counter() - tic\n\n    # Check if the outputs are valid\n    indexes = []\n    for i, state in enumerate(states):\n        try:\n            schema = json.loads(arguments[i][\"json_schema\"])\n            obj = json.loads(state[\"json_output\"])\n            assert jsonschema.validate(obj, schema) is None\n        except Exception as e:\n            print(e)\n            indexes.append(i)\n\n    return states, latency\n\n\ndef main(args):\n    states, latency = bench_schema(args)\n\n    # Compute accuracy\n    tokenizer = get_tokenizer(\n        global_config.default_backend.get_server_info()[\"tokenizer_path\"]\n    )\n    output_jsons = [state[\"json_output\"] for state in states]\n    num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)\n    print(f\"Latency: {latency:.3f}\")\n    print(f\"Output throughput: {num_output_tokens / latency:.3f} token/s\")\n    print(f\"#output tokens: {num_output_tokens}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n    with open(f\"{args.backend}.jsonl\", \"w\") as fout:\n        for state in states:\n            fout.write(state[\"json_output\"] + \"\\n\")\n\n    with open(args.result_file, \"a\") as fout:\n        
value = {\n            \"task\": \"json_schema\",\n            \"backend\": args.backend,\n            \"latency\": round(latency, 3),\n            \"num_jsons\": args.num_jsons,\n            \"parallel\": args.parallel,\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"NousResearch/json-mode-eval\")\n    parser.add_argument(\"--num-jsons\", type=int, default=-1)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/kernels/all_reduce/benchmark_aiter.py",
    "content": "\"\"\"\nBenchmark SGLang vs Aiter custom all-reduce across message sizes.\nUsage:\n    torchrun --nproc_per_node=2 benchmark_aiter.py\n    torchrun --nproc_per_node=4 benchmark_aiter.py\n    torchrun --nproc_per_node=8 benchmark_aiter.py\n\"\"\"\n\nimport argparse\nimport os\nimport sys\nimport time\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport torch.distributed as dist\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Benchmark SGLang vs Aiter custom all-reduce across message sizes.\"\n    )\n    parser.add_argument(\n        \"--backend\",\n        type=str,\n        default=\"gloo\",\n        help=\"Process group backend for the custom-AR control path (must NOT be nccl).\",\n    )\n    parser.add_argument(\n        \"--warmup\",\n        type=int,\n        default=5,\n        help=\"Warmup iterations per size per implementation.\",\n    )\n    parser.add_argument(\n        \"--iters-small\",\n        type=int,\n        default=50,\n        help=\"Benchmark iterations for sizes <= 1MB.\",\n    )\n    parser.add_argument(\n        \"--iters-large\",\n        type=int,\n        default=20,\n        help=\"Benchmark iterations for sizes > 1MB.\",\n    )\n    parser.add_argument(\n        \"--verbose\",\n        action=\"store_true\",\n        help=\"Print per-iteration timings on rank 0 for debugging.\",\n    )\n    return parser.parse_args()\n\n\ndef get_env_rank_world() -> Tuple[int, int, int]:\n    rank = int(os.environ.get(\"RANK\", \"0\"))\n    world_size = int(os.environ.get(\"WORLD_SIZE\", \"1\"))\n    local_rank = int(os.environ.get(\"LOCAL_RANK\", str(rank)))\n    return rank, world_size, local_rank\n\n\ndef init_dist(backend: str):\n    rank, world_size, _ = get_env_rank_world()\n    if not dist.is_initialized():\n        dist.init_process_group(\n            backend=backend,\n            init_method=\"env://\",\n            rank=rank,\n            world_size=world_size,\n        
)\n\n\ndef get_device(local_rank: int) -> torch.device:\n    torch.cuda.set_device(local_rank)\n    return torch.device(f\"cuda:{local_rank}\")\n\n\ndef human_size(num_bytes: int) -> str:\n    units = [(\"B\", 1), (\"K\", 1024), (\"M\", 1024 * 1024), (\"G\", 1024 * 1024 * 1024)]\n    for suf, base in reversed(units):\n        if num_bytes % base == 0 and num_bytes >= base:\n            val = num_bytes // base\n            return f\"{val}{suf}\"\n    return f\"{num_bytes}B\"\n\n\ndef get_message_sizes() -> List[int]:\n    return [\n        32 * 1024,\n        64 * 1024,\n        128 * 1024,\n        256 * 1024,\n        512 * 1024,\n        1 * 1024 * 1024,\n        2 * 1024 * 1024,\n        4 * 1024 * 1024,\n        8 * 1024 * 1024,\n        16 * 1024 * 1024,\n        32 * 1024 * 1024,\n        64 * 1024 * 1024,\n    ]\n\n\n@torch.inference_mode()\ndef run_once(comm, inp: torch.Tensor) -> Optional[torch.Tensor]:\n    if hasattr(comm, \"all_reduce_unreg\"):\n        return comm.all_reduce_unreg(inp)\n    if hasattr(comm, \"custom_all_reduce\"):\n        return comm.custom_all_reduce(inp)\n    raise RuntimeError(\"No known all-reduce method found on the communicator.\")\n\n\n@torch.inference_mode()\ndef bench_impl(\n    name: str,\n    comm,\n    sizes: List[int],\n    device: torch.device,\n    warmup: int,\n    iters_small: int,\n    iters_large: int,\n    verbose: bool,\n    pg: Optional[dist.ProcessGroup] = None,\n) -> List[Tuple[int, Optional[float]]]:\n    rank = dist.get_rank()\n    world_size = dist.get_world_size()\n    results: List[Tuple[int, Optional[float]]] = []\n\n    for size_bytes in sizes:\n        elems = size_bytes // 2  # float16: 2 bytes per element\n        inp = torch.empty(elems, dtype=torch.float16, device=device)\n        inp.uniform_(0, 1)\n\n        disabled = False\n        dist.barrier(group=pg)\n        for _ in range(warmup):\n            torch.cuda.synchronize()\n            out = run_once(comm, inp)\n            
torch.cuda.synchronize()\n            if out is None:\n                disabled = True\n                break\n        dist.barrier(group=pg)\n\n        if disabled:\n            if rank == 0:\n                print(\n                    f\"[{name}] {human_size(size_bytes)}: custom AR disabled (skipped)\"\n                )\n            results.append((size_bytes, None))\n            continue\n\n        num_iters = iters_small if size_bytes <= (1 * 1024 * 1024) else iters_large\n\n        times_ms: List[float] = []\n        for it in range(num_iters):\n            dist.barrier(group=pg)\n            torch.cuda.synchronize()\n            t0 = time.perf_counter()\n            out = run_once(comm, inp)\n            torch.cuda.synchronize()\n            t1 = time.perf_counter()\n            dist.barrier(group=pg)\n\n            if out is None:\n                disabled = True\n                break\n\n            dt_ms = (t1 - t0) * 1000.0\n            times_ms.append(dt_ms)\n\n            if verbose and rank == 0:\n                print(\n                    f\"[{name}] size={human_size(size_bytes)} iter={it} time={dt_ms:.3f} ms\"\n                )\n\n        if disabled or not times_ms:\n            if rank == 0:\n                print(\n                    f\"[{name}] {human_size(size_bytes)}: custom AR disabled (no timings)\"\n                )\n            results.append((size_bytes, None))\n            continue\n\n        avg_ms_local = sum(times_ms) / len(times_ms)\n        avg_tensor = torch.tensor([avg_ms_local], dtype=torch.float64, device=device)\n        gather_list = [torch.zeros_like(avg_tensor) for _ in range(world_size)]\n        dist.all_gather(gather_list, avg_tensor, group=pg)\n        if rank == 0:\n            avg_ms = float(torch.stack(gather_list).mean().item())\n            print(\n                f\"[{name}] {human_size(size_bytes)}: {avg_ms:.3f} ms (avg across ranks)\"\n            )\n            results.append((size_bytes, avg_ms))\n        
else:\n            results.append((size_bytes, None))\n\n    return results\n\n\ndef main():\n    args = parse_args()\n    rank, world_size, local_rank = get_env_rank_world()\n\n    if world_size not in (2, 4, 6, 8):\n        print(\n            f\"[rank {rank}] WARNING: world_size={world_size} not in supported set (2,4,6,8). \"\n            \"Custom AR may disable itself.\",\n            file=sys.stderr,\n        )\n\n    init_dist(args.backend)\n    device = get_device(local_rank)\n\n    # Import after dist init; some libs query torch dist state on import\n    sgl_comm = None\n    aiter_comm = None\n    HAVE_SGLANG = False\n    HAVE_AITER = False\n\n    try:\n        from sglang.srt.distributed.device_communicators.custom_all_reduce import (\n            CustomAllreduce as SGLCustomAllreduce,\n        )\n\n        HAVE_SGLANG = True\n    except Exception as e:\n        if rank == 0:\n            print(f\"SGLang CustomAllreduce import failed: {e}\", file=sys.stderr)\n\n    try:\n        from aiter.dist.device_communicators.custom_all_reduce import (\n            CustomAllreduce as AiterCustomAllreduce,\n        )\n\n        HAVE_AITER = True\n    except Exception as e:\n        if rank == 0:\n            print(f\"Aiter CustomAllreduce import failed: {e}\", file=sys.stderr)\n\n    if rank == 0:\n        print(f\"Initialized PG backend={args.backend} world_size={world_size}\")\n        print(f\"Device: {device.type}:{device.index}\")\n        print(f\"SGLang available: {HAVE_SGLANG}, Aiter available: {HAVE_AITER}\")\n\n    pg = dist.group.WORLD\n    sizes = get_message_sizes()\n    max_size = max(sizes) if sizes else (64 * 1024 * 1024)\n\n    if HAVE_SGLANG:\n        try:\n            sgl_comm = SGLCustomAllreduce(group=pg, device=device, max_size=max_size)\n        except Exception as e:\n            if rank == 0:\n                print(\n                    f\"Failed to construct SGLang CustomAllreduce: {e}\", file=sys.stderr\n                )\n            
sgl_comm = None\n\n    if HAVE_AITER:\n        try:\n            aiter_comm = AiterCustomAllreduce(\n                group=pg, device=device, max_size=max_size\n            )\n        except Exception as e:\n            if rank == 0:\n                print(\n                    f\"Failed to construct Aiter CustomAllreduce: {e}\", file=sys.stderr\n                )\n            aiter_comm = None\n\n    sgl_results: List[Tuple[int, Optional[float]]] = []\n    aiter_results: List[Tuple[int, Optional[float]]] = []\n\n    if sgl_comm is not None:\n        sgl_results = bench_impl(\n            name=\"SGLang\",\n            comm=sgl_comm,\n            sizes=sizes,\n            device=device,\n            warmup=args.warmup,\n            iters_small=args.iters_small,\n            iters_large=args.iters_large,\n            verbose=args.verbose,\n            pg=pg,\n        )\n\n    if aiter_comm is not None:\n        aiter_results = bench_impl(\n            name=\"Aiter\",\n            comm=aiter_comm,\n            sizes=sizes,\n            device=device,\n            warmup=args.warmup,\n            iters_small=args.iters_small,\n            iters_large=args.iters_large,\n            verbose=args.verbose,\n            pg=pg,\n        )\n\n    for comm in (sgl_comm, aiter_comm):\n        if comm is not None and hasattr(comm, \"close\"):\n            try:\n                comm.close()\n            except Exception:\n                pass\n\n    if dist.get_rank() == 0:\n        print(\"\\nResults (avg ms across ranks; None = disabled/unavailable):\")\n        header = f\"{'Size':>8}  {'SGLang(ms)':>12}  {'Aiter(ms)':>11}\"\n        print(header)\n        print(\"-\" * len(header))\n\n        sgl_map = {s: v for s, v in sgl_results if v is not None}\n        aiter_map = {s: v for s, v in aiter_results if v is not None}\n\n        for s in sizes:\n            sgl_ms = sgl_map.get(s, None)\n            aiter_ms = aiter_map.get(s, None)\n            print(\n                
f\"{human_size(s):>8}  {('%.3f' % sgl_ms) if sgl_ms is not None else 'None':>12}  \"\n                f\"{('%.3f' % aiter_ms) if aiter_ms is not None else 'None':>11}\"\n            )\n\n    dist.barrier()\n    dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/kernels/all_reduce/benchmark_all_reduce.py",
    "content": "\"\"\"\nBenchmark SGLang custom all-reduce vs Torch symm-mem all-reduce across message sizes.\nUsage:\n    torchrun --nproc_per_node=2 benchmark_all_reduce.py\n    torchrun --nproc_per_node=4 benchmark_all_reduce.py\n    torchrun --nproc_per_node=8 benchmark_all_reduce.py\n\"\"\"\n\nimport argparse\nimport os\nimport sys\nimport time\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport torch.distributed as dist\n\nfrom sglang.srt.distributed.parallel_state import (\n    destroy_distributed_environment,\n    destroy_model_parallel,\n    init_distributed_environment,\n    initialize_model_parallel,\n)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Benchmark SGLang custom all-reduce vs Torch symm-mem all-reduce across message sizes.\"\n    )\n    parser.add_argument(\n        \"--backend\",\n        type=str,\n        default=\"gloo\",\n        help=\"Process group backend for the custom-AR control path (must NOT be nccl).\",\n    )\n    parser.add_argument(\n        \"--warmup\",\n        type=int,\n        default=5,\n        help=\"Warmup iterations per size per implementation.\",\n    )\n    parser.add_argument(\n        \"--iters-small\",\n        type=int,\n        default=50,\n        help=\"Benchmark iterations for sizes <= 1MB.\",\n    )\n    parser.add_argument(\n        \"--iters-large\",\n        type=int,\n        default=20,\n        help=\"Benchmark iterations for sizes > 1MB.\",\n    )\n    parser.add_argument(\n        \"--verbose\",\n        action=\"store_true\",\n        help=\"Print per-iteration timings on rank 0 for debugging.\",\n    )\n    return parser.parse_args()\n\n\ndef get_env_rank_world() -> Tuple[int, int, int]:\n    rank = int(os.environ.get(\"RANK\", \"0\"))\n    world_size = int(os.environ.get(\"WORLD_SIZE\", \"1\"))\n    local_rank = int(os.environ.get(\"LOCAL_RANK\", str(rank)))\n    return rank, world_size, local_rank\n\n\ndef init_dist(backend: str):\n    rank, 
world_size, _ = get_env_rank_world()\n    if not dist.is_initialized():\n        dist.init_process_group(\n            backend=backend,\n            init_method=\"env://\",\n            rank=rank,\n            world_size=world_size,\n        )\n\n    device = torch.device(f\"cuda:{rank}\")\n    torch.cuda.set_device(device)\n    distributed_init_method = f\"tcp://localhost:23456\"\n    init_distributed_environment(\n        world_size=world_size,\n        rank=rank,\n        distributed_init_method=distributed_init_method,\n        local_rank=rank,\n    )\n    initialize_model_parallel(tensor_model_parallel_size=world_size)\n    return dist.group.WORLD\n\n\ndef get_device(local_rank: int) -> torch.device:\n    torch.cuda.set_device(local_rank)\n    return torch.device(f\"cuda:{local_rank}\")\n\n\ndef human_size(num_bytes: int) -> str:\n    units = [(\"B\", 1), (\"K\", 1024), (\"M\", 1024 * 1024), (\"G\", 1024 * 1024 * 1024)]\n    for suf, base in reversed(units):\n        if num_bytes % base == 0 and num_bytes >= base:\n            val = num_bytes // base\n            return f\"{val}{suf}\"\n    return f\"{num_bytes}B\"\n\n\ndef get_message_sizes() -> List[int]:\n    return [\n        32 * 1024,\n        64 * 1024,\n        128 * 1024,\n        256 * 1024,\n        512 * 1024,\n        1 * 1024 * 1024,\n        2 * 1024 * 1024,\n        4 * 1024 * 1024,\n        8 * 1024 * 1024,\n        16 * 1024 * 1024,\n        32 * 1024 * 1024,\n        64 * 1024 * 1024,\n    ]\n\n\n@torch.inference_mode()\ndef run_once(comm, inp: torch.Tensor) -> Optional[torch.Tensor]:\n    if hasattr(comm, \"custom_all_reduce\"):\n        return comm.custom_all_reduce(inp)\n    if hasattr(comm, \"all_reduce\"):\n        return comm.all_reduce(inp)\n    raise RuntimeError(\"No known all-reduce method found on the communicator.\")\n\n\n@torch.inference_mode()\ndef bench_impl(\n    name: str,\n    comm,\n    sizes: List[int],\n    device: torch.device,\n    warmup: int,\n    iters_small: int,\n 
   iters_large: int,\n    verbose: bool,\n    pg: Optional[dist.ProcessGroup] = None,\n) -> List[Tuple[int, Optional[float]]]:\n    rank = dist.get_rank()\n    world_size = dist.get_world_size()\n    results: List[Tuple[int, Optional[float]]] = []\n\n    for size_bytes in sizes:\n        elems = size_bytes // 2  # float16: 2 bytes per element\n        inp = torch.empty(elems, dtype=torch.float16, device=device)\n        inp.uniform_(0, 1)\n\n        disabled = False\n        dist.barrier(group=pg)\n        for _ in range(warmup):\n            torch.cuda.synchronize()\n            out = run_once(comm, inp)\n            torch.cuda.synchronize()\n            if out is None:\n                disabled = True\n                break\n        dist.barrier(group=pg)\n\n        if disabled:\n            if rank == 0:\n                print(\n                    f\"[{name}] {human_size(size_bytes)}: custom AR disabled (skipped)\"\n                )\n            results.append((size_bytes, None))\n            continue\n\n        num_iters = iters_small if size_bytes <= (1 * 1024 * 1024) else iters_large\n\n        times_ms: List[float] = []\n        for it in range(num_iters):\n            dist.barrier(group=pg)\n            torch.cuda.synchronize()\n            t0 = time.perf_counter()\n            out = run_once(comm, inp)\n            torch.cuda.synchronize()\n            t1 = time.perf_counter()\n            dist.barrier(group=pg)\n\n            if out is None:\n                disabled = True\n                break\n\n            dt_ms = (t1 - t0) * 1000.0\n            times_ms.append(dt_ms)\n\n            if verbose and rank == 0:\n                print(\n                    f\"[{name}] size={human_size(size_bytes)} iter={it} time={dt_ms:.3f} ms\"\n                )\n\n        if disabled or not times_ms:\n            if rank == 0:\n                print(\n                    f\"[{name}] {human_size(size_bytes)}: custom AR disabled (no timings)\"\n                )\n     
       results.append((size_bytes, None))\n            continue\n\n        avg_ms_local = sum(times_ms) / len(times_ms)\n        avg_tensor = torch.tensor([avg_ms_local], dtype=torch.float64, device=device)\n        gather_list = [torch.zeros_like(avg_tensor) for _ in range(world_size)]\n        dist.all_gather(gather_list, avg_tensor, group=pg)\n        if rank == 0:\n            avg_ms = float(torch.stack(gather_list).mean().item())\n            print(\n                f\"[{name}] {human_size(size_bytes)}: {avg_ms:.3f} ms (avg across ranks)\"\n            )\n            results.append((size_bytes, avg_ms))\n        else:\n            results.append((size_bytes, None))\n\n    return results\n\n\ndef main():\n    args = parse_args()\n    rank, world_size, local_rank = get_env_rank_world()\n\n    if world_size not in (2, 4, 6, 8):\n        print(\n            f\"[rank {rank}] WARNING: world_size={world_size} not in supported set (2,4,6,8). \"\n            \"Custom AR may disable itself.\",\n            file=sys.stderr,\n        )\n\n    group = init_dist(args.backend)\n    device = get_device(local_rank)\n\n    # Import after dist init; some libs query torch dist state on import\n    torch_symm_mem_comm = None\n    HAVE_SGLANG_CUSTOM = False\n    HAVE_TORCH_SYMM_MEM = False\n\n    try:\n        from sglang.srt.distributed.device_communicators.custom_all_reduce import (\n            CustomAllreduce as SGLCustomAllreduce,\n        )\n\n        HAVE_SGLANG_CUSTOM = True\n    except Exception as e:\n        if rank == 0:\n            print(f\"SGLang CustomAllreduce import failed: {e}\", file=sys.stderr)\n\n    try:\n        from sglang.srt.distributed.device_communicators.torch_symm_mem import (\n            TorchSymmMemCommunicator as TorchSymmMemAllreduce,\n        )\n\n        HAVE_TORCH_SYMM_MEM = True\n    except Exception as e:\n        if rank == 0:\n            print(f\"TorchSymmMemAllreduce import failed: {e}\", file=sys.stderr)\n\n    if rank == 0:\n        
print(f\"Initialized PG backend={args.backend} world_size={world_size}\")\n        print(f\"Device: {device.type}:{device.index}\")\n        print(\n            f\"SGLang Custom available: {HAVE_SGLANG_CUSTOM}, Torch Symm-Mem available: {HAVE_TORCH_SYMM_MEM}\"\n        )\n\n    sizes = get_message_sizes()\n    max_size = max(sizes) if sizes else (128 * 1024 * 1024)\n\n    sgl_custom_comm = None\n    if HAVE_SGLANG_CUSTOM:\n        try:\n            sgl_custom_comm = SGLCustomAllreduce(\n                group=group, device=device, max_size=max_size\n            )\n        except Exception as e:\n            if rank == 0:\n                print(\n                    f\"Failed to construct SGLangCustomAllreduce: {e}\", file=sys.stderr\n                )\n            sgl_custom_comm = None\n\n    if HAVE_TORCH_SYMM_MEM:\n        try:\n            torch_symm_mem_comm = TorchSymmMemAllreduce(group=group, device=device)\n        except Exception as e:\n            if rank == 0:\n                print(\n                    f\"Failed to construct TorchSymmMemAllreduce: {e}\", file=sys.stderr\n                )\n            torch_symm_mem_comm = None\n\n    sgl_custom_results: List[Tuple[int, Optional[float]]] = []\n    symm_mem_results: List[Tuple[int, Optional[float]]] = []\n\n    if sgl_custom_comm is not None:\n        sgl_custom_results = bench_impl(\n            name=\"SGLangCustom\",\n            comm=sgl_custom_comm,\n            sizes=sizes,\n            device=device,\n            warmup=args.warmup,\n            iters_small=args.iters_small,\n            iters_large=args.iters_large,\n            verbose=args.verbose,\n            pg=group,\n        )\n\n    if torch_symm_mem_comm is not None:\n        symm_mem_results = bench_impl(\n            name=\"TorchSymmMem\",\n            comm=torch_symm_mem_comm,\n            sizes=sizes,\n            device=device,\n            warmup=args.warmup,\n            iters_small=args.iters_small,\n            iters_large=args.iters_large,\n            
verbose=args.verbose,\n            pg=group,\n        )\n\n    for comm in (sgl_custom_comm, torch_symm_mem_comm):\n        if comm is not None and hasattr(comm, \"close\"):\n            try:\n                comm.close()\n            except Exception:\n                pass\n\n    if dist.get_rank() == 0:\n        print(\n            f\"\\nResults (avg ms across {world_size} ranks; None = disabled/unavailable):\"\n        )\n        header = f\"{'Size':>8}  {'CustomAR(ms)':>12}  {'TorchSymmMem(ms)':>16}\"\n        print(header)\n        print(\"-\" * len(header))\n\n        sgl_custom_map = {s: v for s, v in sgl_custom_results if v is not None}\n        symm_mem_map = {s: v for s, v in symm_mem_results if v is not None}\n\n        for s in sizes:\n            sgl_ms = sgl_custom_map.get(s, None)\n            symm_mem_ms = symm_mem_map.get(s, None)\n            print(\n                f\"{human_size(s):>8}  {('%.3f' % sgl_ms) if sgl_ms is not None else 'None':>12}  \"\n                f\"{('%.3f' % symm_mem_ms) if symm_mem_ms is not None else 'None':>16}\"\n            )\n    torch.distributed.barrier(group=group)\n    destroy_model_parallel()\n    destroy_distributed_environment()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/kernels/all_reduce/benchmark_fused_ar_rms_amd.py",
    "content": "\"\"\"\nBenchmark fused allreduce+rmsnorm on AMD with correctness checks.\n\nThis script targets the same fused op used by SGLang:\n`tensor_model_parallel_fused_allreduce_rmsnorm`.\n\nIt reports:\n- eager mode latency (prefill-like)\n- graph mode latency (decode-like)\n- fused availability (whether fused path returns non-None)\n- correctness (fused output matches split allreduce + rmsnorm reference)\n\nUsage example:\n  torchrun --nproc_per_node=8 \\\n    benchmark/kernels/all_reduce/benchmark_fused_ar_rms_amd.py \\\n    --dtype bfloat16 \\\n    --prefill-shapes 2048x8192,8192x8192 \\\n    --decode-shapes 1x8192,4x8192,16x8192 \\\n    --warmup 10 --iters 30 --repeats 5\n\"\"\"\n\nimport argparse\nimport csv\nimport os\nimport statistics\nfrom typing import Dict, List, Optional, Sequence, Tuple\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\n\nfrom sglang.srt.distributed.communication_op import (\n    tensor_model_parallel_all_reduce,\n    tensor_model_parallel_fused_allreduce_rmsnorm,\n)\nfrom sglang.srt.distributed.parallel_state import (\n    destroy_distributed_environment,\n    destroy_model_parallel,\n    graph_capture,\n    init_distributed_environment,\n    initialize_model_parallel,\n    set_custom_all_reduce,\n)\n\nShape = Tuple[int, int]\n\n\ndef parse_shapes(raw: str) -> List[Shape]:\n    shapes: List[Shape] = []\n    for item in [x.strip() for x in raw.split(\",\") if x.strip()]:\n        if \"x\" not in item:\n            raise ValueError(f\"Invalid shape '{item}', expected MxN format.\")\n        m_str, n_str = item.split(\"x\", 1)\n        m = int(m_str)\n        n = int(n_str)\n        if m <= 0 or n <= 0:\n            raise ValueError(f\"Invalid shape '{item}', both dims must be positive.\")\n        shapes.append((m, n))\n    if not shapes:\n        raise ValueError(\"Empty shape list is not allowed.\")\n    return shapes\n\n\ndef dtype_from_name(name: str) -> torch.dtype:\n    mapping = {\n      
  \"float16\": torch.float16,\n        \"fp16\": torch.float16,\n        \"bfloat16\": torch.bfloat16,\n        \"bf16\": torch.bfloat16,\n    }\n    if name not in mapping:\n        raise ValueError(f\"Unsupported dtype: {name}\")\n    return mapping[name]\n\n\ndef check_close(\n    a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype\n) -> Tuple[bool, str]:\n    if dtype == torch.bfloat16:\n        rtol, atol = 2e-2, 1.25e-1\n    else:\n        rtol, atol = 1e-2, 2e-2\n    try:\n        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)\n        return True, \"PASS\"\n    except AssertionError:\n        max_diff = torch.max(torch.abs(a - b)).item()\n        mean_diff = torch.mean(torch.abs(a - b)).item()\n        return False, f\"FAIL(max={max_diff:.6f},mean={mean_diff:.6f})\"\n\n\ndef _measure_us(\n    fn,\n    warmup: int,\n    iters: int,\n    repeats: int,\n    device: torch.device,\n) -> Tuple[float, Dict[str, float]]:\n    for _ in range(warmup):\n        fn()\n    torch.cuda.synchronize()\n\n    start_event = torch.cuda.Event(enable_timing=True)\n    end_event = torch.cuda.Event(enable_timing=True)\n    samples_us: List[float] = []\n\n    for _ in range(max(1, repeats)):\n        _barrier(device)\n        torch.cuda.synchronize()\n        start_event.record()\n        for _ in range(iters):\n            fn()\n        end_event.record()\n        end_event.synchronize()\n        samples_us.append(start_event.elapsed_time(end_event) * 1000.0 / iters)\n\n    sorted_samples = sorted(samples_us)\n    p50 = float(statistics.median(sorted_samples))\n    p95 = float(sorted_samples[int((len(sorted_samples) - 1) * 0.95)])\n    return p50, {\n        \"p50_us\": p50,\n        \"p95_us\": p95,\n        \"min_us\": float(sorted_samples[0]),\n        \"max_us\": float(sorted_samples[-1]),\n    }\n\n\ndef _barrier(device: torch.device):\n    try:\n        dist.barrier(device_ids=[device.index])\n    except TypeError:\n        dist.barrier()\n\n\ndef 
_mean_across_ranks(value: float, device: torch.device) -> float:\n    t = torch.tensor([value], dtype=torch.float64, device=device)\n    dist.all_reduce(t, op=dist.ReduceOp.SUM)\n    t /= dist.get_world_size()\n    return float(t.item())\n\n\ndef _all_true_across_ranks(value: bool, device: torch.device) -> bool:\n    t = torch.tensor([1 if value else 0], dtype=torch.int32, device=device)\n    dist.all_reduce(t, op=dist.ReduceOp.MIN)\n    return bool(int(t.item()))\n\n\ndef _make_inputs(\n    shape: Shape,\n    dtype: torch.dtype,\n    seed: int,\n    residual_mode: str,\n    rank: int,\n    device: torch.device,\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    m, n = shape\n    torch.manual_seed(seed + rank * 17)\n    x = torch.randn((m, n), dtype=torch.float32, device=device).to(dtype)\n    if residual_mode == \"self\":\n        residual = x.clone()\n    elif residual_mode == \"random\":\n        residual = torch.randn((m, n), dtype=torch.float32, device=device).to(dtype)\n    elif residual_mode == \"zero\":\n        residual = torch.zeros((m, n), dtype=dtype, device=device)\n    else:\n        raise ValueError(f\"Unknown residual_mode: {residual_mode}\")\n    weight = torch.randn((n,), dtype=torch.float32, device=device).to(dtype)\n    return x, residual, weight\n\n\ndef _split_reference(\n    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    ar_out = tensor_model_parallel_all_reduce(x.clone())\n    residual_out = ar_out + residual\n    out = F.rms_norm(\n        input=residual_out,\n        normalized_shape=(residual_out.shape[-1],),\n        weight=weight,\n        eps=eps,\n    )\n    return out, residual_out\n\n\ndef bench_eager(\n    x: torch.Tensor,\n    residual: torch.Tensor,\n    weight: torch.Tensor,\n    eps: float,\n    warmup: int,\n    iters: int,\n    repeats: int,\n) -> Dict[str, object]:\n    split_fn = lambda: _split_reference(x, residual, weight, eps)\n    
split_us, split_stats = _measure_us(split_fn, warmup, iters, repeats, x.device)\n\n    fused_probe = tensor_model_parallel_fused_allreduce_rmsnorm(\n        x.clone(), residual.clone(), weight, eps\n    )\n    fused_available = fused_probe is not None\n\n    fused_us: Optional[float] = None\n    fused_stats: Optional[Dict[str, float]] = None\n    if fused_available:\n        fused_fn = lambda: tensor_model_parallel_fused_allreduce_rmsnorm(\n            x, residual, weight, eps\n        )\n        fused_us, fused_stats = _measure_us(fused_fn, warmup, iters, repeats, x.device)\n\n    ref_out, ref_residual = _split_reference(x, residual, weight, eps)\n    if fused_available:\n        fused_out, fused_residual = tensor_model_parallel_fused_allreduce_rmsnorm(\n            x.clone(), residual.clone(), weight, eps\n        )\n        out_ok, out_detail = check_close(fused_out, ref_out, x.dtype)\n        res_ok, res_detail = check_close(fused_residual, ref_residual, x.dtype)\n        correctness_ok = out_ok and res_ok\n        correctness_detail = f\"out={out_detail}, residual={res_detail}\"\n    else:\n        correctness_ok = True\n        correctness_detail = \"SKIP(fused_unavailable)\"\n\n    return {\n        \"split_us\": split_us,\n        \"split_stats\": split_stats,\n        \"fused_available\": fused_available,\n        \"fused_us\": fused_us,\n        \"fused_stats\": fused_stats,\n        \"correctness_ok\": correctness_ok,\n        \"correctness_detail\": correctness_detail,\n    }\n\n\ndef bench_graph(\n    x: torch.Tensor,\n    residual: torch.Tensor,\n    weight: torch.Tensor,\n    eps: float,\n    warmup: int,\n    iters: int,\n    repeats: int,\n) -> Dict[str, object]:\n    split_x = x.clone()\n    split_res = residual.clone()\n    split_graph_out: Optional[torch.Tensor] = None\n\n    with graph_capture() as gc:\n        split_graph = torch.cuda.CUDAGraph()\n        with torch.cuda.graph(split_graph, stream=gc.stream):\n            split_graph_out, _ = 
_split_reference(split_x, split_res, weight, eps)\n\n    def split_replay():\n        split_graph.replay()\n\n    split_us, split_stats = _measure_us(split_replay, warmup, iters, repeats, x.device)\n\n    fused_probe = tensor_model_parallel_fused_allreduce_rmsnorm(\n        x.clone(), residual.clone(), weight, eps\n    )\n    fused_available = fused_probe is not None\n\n    fused_us: Optional[float] = None\n    fused_stats: Optional[Dict[str, float]] = None\n    fused_graph_out: Optional[torch.Tensor] = None\n    fused_graph_residual: Optional[torch.Tensor] = None\n\n    if fused_available:\n        fused_x = x.clone()\n        fused_res = residual.clone()\n        with graph_capture() as gc:\n            fused_graph = torch.cuda.CUDAGraph()\n            with torch.cuda.graph(fused_graph, stream=gc.stream):\n                fused_graph_out, fused_graph_residual = (\n                    tensor_model_parallel_fused_allreduce_rmsnorm(\n                        fused_x, fused_res, weight, eps\n                    )\n                )\n\n        def fused_replay():\n            fused_graph.replay()\n\n        fused_us, fused_stats = _measure_us(\n            fused_replay, warmup, iters, repeats, x.device\n        )\n\n    ref_out, ref_residual = _split_reference(x, residual, weight, eps)\n    if (\n        fused_available\n        and fused_graph_out is not None\n        and fused_graph_residual is not None\n    ):\n        fused_graph.replay()\n        torch.cuda.synchronize()\n        out_ok, out_detail = check_close(fused_graph_out, ref_out, x.dtype)\n        res_ok, res_detail = check_close(fused_graph_residual, ref_residual, x.dtype)\n        correctness_ok = out_ok and res_ok\n        correctness_detail = f\"out={out_detail}, residual={res_detail}\"\n    else:\n        correctness_ok = True\n        correctness_detail = \"SKIP(fused_unavailable)\"\n\n    return {\n        \"split_us\": split_us,\n        \"split_stats\": split_stats,\n        \"fused_available\": 
fused_available,\n        \"fused_us\": fused_us,\n        \"fused_stats\": fused_stats,\n        \"correctness_ok\": correctness_ok,\n        \"correctness_detail\": correctness_detail,\n    }\n\n\ndef _shape_bytes(shape: Shape, dtype: torch.dtype) -> int:\n    m, n = shape\n    return m * n * torch.tensor([], dtype=dtype).element_size()\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Benchmark fused allreduce+rmsnorm (prefill eager + decode graph).\"\n    )\n    parser.add_argument(\n        \"--dtype\",\n        type=str,\n        default=\"bf16\",\n        choices=[\"fp16\", \"bf16\", \"float16\", \"bfloat16\"],\n    )\n    parser.add_argument(\"--eps\", type=float, default=1e-6)\n    parser.add_argument(\"--seed\", type=int, default=1234)\n    parser.add_argument(\n        \"--residual-mode\",\n        type=str,\n        default=\"self\",\n        choices=[\"self\", \"random\", \"zero\"],\n        help=\"Use residual=x (self) to match aiter test behavior by default.\",\n    )\n    parser.add_argument(\n        \"--prefill-shapes\",\n        type=str,\n        default=\"2048x8192,8192x8192,16384x8192\",\n        help=\"Comma-separated MxN shapes for eager mode.\",\n    )\n    parser.add_argument(\n        \"--decode-shapes\",\n        type=str,\n        default=\"1x8192,2x8192,4x8192,8x8192,16x8192\",\n        help=\"Comma-separated MxN shapes for graph mode.\",\n    )\n    parser.add_argument(\"--warmup\", type=int, default=10)\n    parser.add_argument(\"--iters\", type=int, default=30)\n    parser.add_argument(\"--repeats\", type=int, default=5)\n    parser.add_argument(\n        \"--mode\",\n        type=str,\n        default=\"both\",\n        choices=[\"eager\", \"graph\", \"both\"],\n    )\n    parser.add_argument(\n        \"--csv-out\",\n        type=str,\n        default=None,\n        help=\"Optional output CSV path (written on rank 0 only).\",\n    )\n    return parser.parse_args()\n\n\ndef main():\n    args = 
parse_args()\n    dtype = dtype_from_name(args.dtype)\n    rank = int(os.environ.get(\"RANK\", \"0\"))\n    world_size = int(os.environ.get(\"WORLD_SIZE\", \"1\"))\n    local_rank = int(os.environ.get(\"LOCAL_RANK\", str(rank)))\n    torch.cuda.set_device(local_rank % torch.cuda.device_count())\n    device = torch.device(f\"cuda:{local_rank % torch.cuda.device_count()}\")\n\n    set_custom_all_reduce(True)\n    init_distributed_environment(\n        world_size=world_size,\n        rank=rank,\n        local_rank=local_rank,\n        distributed_init_method=\"env://\",\n        backend=\"nccl\",\n    )\n    initialize_model_parallel(tensor_model_parallel_size=world_size)\n\n    prefill_shapes = parse_shapes(args.prefill_shapes)\n    decode_shapes = parse_shapes(args.decode_shapes)\n\n    if rank == 0:\n        print(\n            \"Config: \"\n            f\"world_size={world_size}, dtype={dtype}, residual_mode={args.residual_mode}, \"\n            f\"warmup={args.warmup}, iters={args.iters}, repeats={args.repeats}\"\n        )\n\n    run_modes: Sequence[str]\n    if args.mode == \"both\":\n        run_modes = (\"eager\", \"graph\")\n    else:\n        run_modes = (args.mode,)\n    csv_rows: List[Dict[str, object]] = []\n\n    for mode in run_modes:\n        shapes = prefill_shapes if mode == \"eager\" else decode_shapes\n        if rank == 0:\n            phase_name = \"prefill(eager)\" if mode == \"eager\" else \"decode(graph)\"\n            print(\"\\n\" + \"=\" * 120)\n            print(f\"Mode: {phase_name}\")\n            print(\n                \"| Shape | Input bytes/rank | Split p50 (us) | Fused p50 (us) | Speedup | Fused available | Correctness |\"\n            )\n            print(\n                \"|:------|-----------------:|---------------:|---------------:|--------:|:----------------|:------------|\"\n            )\n\n        for shape in shapes:\n            x, residual, weight = _make_inputs(\n                shape=shape,\n                
dtype=dtype,\n                seed=args.seed,\n                residual_mode=args.residual_mode,\n                rank=rank,\n                device=device,\n            )\n\n            if mode == \"eager\":\n                metrics = bench_eager(\n                    x=x,\n                    residual=residual,\n                    weight=weight,\n                    eps=args.eps,\n                    warmup=args.warmup,\n                    iters=args.iters,\n                    repeats=args.repeats,\n                )\n            else:\n                metrics = bench_graph(\n                    x=x,\n                    residual=residual,\n                    weight=weight,\n                    eps=args.eps,\n                    warmup=args.warmup,\n                    iters=args.iters,\n                    repeats=args.repeats,\n                )\n\n            split_us = _mean_across_ranks(float(metrics[\"split_us\"]), device)\n            fused_available = _all_true_across_ranks(\n                bool(metrics[\"fused_available\"]), device\n            )\n            correctness_ok = _all_true_across_ranks(\n                bool(metrics[\"correctness_ok\"]), device\n            )\n\n            fused_us: Optional[float] = None\n            if fused_available and metrics[\"fused_us\"] is not None:\n                fused_us = _mean_across_ranks(float(metrics[\"fused_us\"]), device)\n\n            if rank == 0:\n                m, n = shape\n                shape_str = f\"{m}x{n}\"\n                bytes_per_rank = _shape_bytes(shape, dtype)\n                if fused_us is not None and fused_us > 0:\n                    speedup = split_us / fused_us\n                    speedup_str = f\"{speedup:.3f}x\"\n                    fused_str = f\"{fused_us:.1f}\"\n                else:\n                    speedup_str = \"N/A\"\n                    fused_str = \"N/A\"\n                correctness_text = (\n                    \"PASS\" if correctness_ok else 
str(metrics[\"correctness_detail\"])\n                )\n                print(\n                    f\"| {shape_str} | {bytes_per_rank} | {split_us:.1f} | {fused_str} | \"\n                    f\"{speedup_str} | {str(fused_available)} | {correctness_text} |\"\n                )\n                csv_rows.append(\n                    {\n                        \"mode\": mode,\n                        \"shape\": shape_str,\n                        \"m\": m,\n                        \"n\": n,\n                        \"bytes_per_rank\": bytes_per_rank,\n                        \"split_p50_us\": split_us,\n                        \"fused_p50_us\": fused_us if fused_us is not None else \"\",\n                        \"speedup_split_over_fused\": (\n                            split_us / fused_us\n                            if fused_us is not None and fused_us > 0\n                            else \"\"\n                        ),\n                        \"fused_available\": fused_available,\n                        \"correctness_ok\": correctness_ok,\n                        \"correctness_detail\": correctness_text,\n                        \"dtype\": str(dtype),\n                        \"world_size\": world_size,\n                        \"residual_mode\": args.residual_mode,\n                        \"warmup\": args.warmup,\n                        \"iters\": args.iters,\n                        \"repeats\": args.repeats,\n                    }\n                )\n\n    if rank == 0 and args.csv_out:\n        os.makedirs(os.path.dirname(args.csv_out) or \".\", exist_ok=True)\n        fieldnames = [\n            \"mode\",\n            \"shape\",\n            \"m\",\n            \"n\",\n            \"bytes_per_rank\",\n            \"split_p50_us\",\n            \"fused_p50_us\",\n            \"speedup_split_over_fused\",\n            \"fused_available\",\n            \"correctness_ok\",\n            \"correctness_detail\",\n            \"dtype\",\n            
\"world_size\",\n            \"residual_mode\",\n            \"warmup\",\n            \"iters\",\n            \"repeats\",\n        ]\n        with open(args.csv_out, \"w\", newline=\"\", encoding=\"utf-8\") as f:\n            writer = csv.DictWriter(f, fieldnames=fieldnames)\n            writer.writeheader()\n            writer.writerows(csv_rows)\n        print(f\"\\nSaved CSV to: {args.csv_out}\")\n\n    _barrier(device)\n    destroy_model_parallel()\n    destroy_distributed_environment()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/kernels/all_reduce/benchmark_mscclpp.py",
    "content": "\"\"\"For Now, MSCCL is only supported on TP16 and TP8 case\n\nexport WORLD_SIZE=1\nexport RANK=0\nexport MASTER_ADDR=127.0.0.1\nexport MASTER_PORT=12345\n\ntorchrun --nproc_per_node gpu \\\n--nnodes $WORLD_SIZE \\\n--node_rank $RANK \\\n--master_addr $MASTER_ADDR \\\n--master_port $MASTER_PORT benchmark/kernels/all_reduce/benchmark_mscclpp.py\n\"\"\"\n\nimport os\nfrom contextlib import nullcontext\nfrom typing import List\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup\n\nfrom sglang.srt.distributed import init_distributed_environment\nfrom sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator\nfrom sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator\nfrom sglang.srt.distributed.parallel_state import (\n    get_tensor_model_parallel_group,\n    graph_capture,\n    initialize_model_parallel,\n    set_mscclpp_all_reduce,\n)\n\n\ndef torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:\n    dist.all_reduce(torch_input, group=group)\n    return torch_input\n\n\ndef msccl_allreduce(\n    msccl_input: torch.Tensor, msccl_comm: PyMscclppCommunicator\n) -> torch.Tensor:\n    return msccl_comm.all_reduce(msccl_input)\n\n\ndef pynccl_allreduce(\n    msccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator\n) -> torch.Tensor:\n    pynccl_comm.all_reduce(msccl_input)\n    return msccl_input\n\n\ndef _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):\n    graph_input = inp_randn.clone()\n    with graph_capture() as graph_capture_context:\n        graph = torch.cuda.CUDAGraph()\n        with torch.cuda.graph(graph, stream=graph_capture_context.stream):\n            for _ in range(graph_loop):\n                graph_out = func(graph_input)\n\n    graph.replay()\n    func_output = graph_out.clone()\n\n    for _ in range(warmup_loop):\n        graph.replay()\n    torch.cuda.synchronize()\n\n    start_event = 
torch.cuda.Event(enable_timing=True)\n    end_event = torch.cuda.Event(enable_timing=True)\n\n    latencies: List[float] = []\n    for _ in range(test_loop):\n        torch.cuda.synchronize()\n        dist.barrier()\n        start_event.record()\n        graph.replay()\n        end_event.record()\n        end_event.synchronize()\n        latencies.append(start_event.elapsed_time(end_event))\n    func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000\n    graph.reset()\n    return func_output, func_cost_us\n\n\ndef _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):\n    eager_input = inp_randn.clone()\n    eager_output = func(eager_input)\n    func_output = eager_output.clone()\n\n    for _ in range(warmup_loop):\n        func(eager_input)\n    torch.cuda.synchronize()\n\n    start_event = torch.cuda.Event(enable_timing=True)\n    end_event = torch.cuda.Event(enable_timing=True)\n    torch.cuda.synchronize()\n    start_event.record()\n    for _ in range(test_loop):\n        func(eager_input)\n    end_event.record()\n    torch.cuda.synchronize()\n    func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000\n\n    return func_output, func_cost_us\n\n\ndef get_torch_prof_ctx(do_prof: bool):\n    ctx = (\n        torch.profiler.profile(\n            activities=[\n                torch.profiler.ProfilerActivity.CPU,\n                torch.profiler.ProfilerActivity.CUDA,\n            ],\n            record_shapes=True,\n            with_stack=True,\n        )\n        if do_prof\n        else nullcontext()\n    )\n    return ctx\n\n\ndef human_readable_size(size, decimal_places=1):\n    for unit in [\"B\", \"KiB\", \"MiB\", \"GiB\", \"TiB\", \"PiB\"]:\n        if size < 1024.0 or unit == \"PiB\":\n            break\n        size /= 1024.0\n    return f\"{size:.{decimal_places}f} {unit}\"\n\n\ntry:\n    from tabulate import tabulate\nexcept ImportError:\n    print(\"tabulate not installed, skipping table printing\")\n    tabulate = 
None\n\n\ndef print_markdown_table(data):\n    if tabulate is not None:\n        print(tabulate(data, headers=\"keys\", tablefmt=\"github\"))\n        return\n    headers = data[0].keys()\n    header_row = \"| \" + \" | \".join(headers) + \" |\"\n    separator = \"| \" + \" | \".join([\"---\"] * len(headers)) + \" |\"\n    rows = []\n    for item in data:\n        row = \"| \" + \" | \".join(str(item[key]) for key in headers) + \" |\"\n        rows.append(row)\n    markdown_table = \"\\n\".join([header_row, separator] + rows)\n    print(markdown_table)\n\n\nif __name__ == \"__main__\":\n    import logging\n\n    logging.basicConfig(\n        level=logging.INFO,\n        format=\"%(asctime)s - %(levelname)s - %(message)s\",\n        datefmt=\"%Y-%m-%d %H:%M:%S\",\n        force=True,\n    )\n    if not dist.is_initialized():\n        dist.init_process_group(backend=\"nccl\")\n    world, world_size = dist.group.WORLD, dist.get_world_size()\n    rank = dist.get_rank()\n    torch.cuda.set_device(rank % 8)\n    device = torch.cuda.current_device()\n    set_mscclpp_all_reduce(True)\n    init_distributed_environment(\n        world_size=world_size,\n        rank=rank,\n        local_rank=rank % 8,\n    )\n    initialize_model_parallel(tensor_model_parallel_size=world_size)\n    group = get_tensor_model_parallel_group().device_group\n    cpu_group = get_tensor_model_parallel_group().cpu_group\n    pynccl_comm = get_tensor_model_parallel_group().pynccl_comm\n    pymscclpp_comm = get_tensor_model_parallel_group().pymscclpp_comm\n    dist.barrier()\n    profile = False\n    dtype = torch.bfloat16\n    ctx = get_torch_prof_ctx(profile)\n    result = []\n\n    with ctx:\n        for i in range(10, 20):\n            sz = 2**i\n            if sz * dtype.itemsize > 2**20:\n                break\n            inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)\n\n            memory = torch.empty_like(inp_randn)\n            memory_out = torch.empty_like(memory)\n     
       torch_eager_output, torch_eager_time = _bench_eager_time(\n                lambda inp: torch_allreduce(inp, group), inp_randn\n            )\n            msccl_eager_output, msccl_eager_time = _bench_eager_time(\n                lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn\n            )\n            msccl_graph_output, msccl_graph_time = _bench_graph_time(\n                lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn\n            )\n            # since pynccl is inplace op, this return result is not correct if graph loop > 1\n            _, pynccl_graph_time = _bench_graph_time(\n                lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn\n            )\n            torch.testing.assert_close(torch_eager_output, msccl_graph_output)\n            torch.testing.assert_close(torch_eager_output, msccl_eager_output)\n            result.append(\n                {\n                    \"msg_size\": human_readable_size(inp_randn.nbytes),\n                    \"torch eager time\": torch_eager_time,\n                    \"msccl eager time\": msccl_eager_time,\n                    \"msccl graph time\": msccl_graph_time,\n                    \"pynccl graph time\": pynccl_graph_time,\n                }\n            )\n            if rank == 0:\n                print(f\"sz={sz}, dtype={dtype}: correctness check PASS!\")\n    if rank == 0:\n        print_markdown_table(result)\n    if profile:\n        prof_dir = f\"prof/msccl\"\n        os.makedirs(prof_dir, exist_ok=True)\n        ctx.export_chrome_trace(f\"{prof_dir}/trace_rank{dist.get_rank()}.json.gz\")\n"
  },
  {
    "path": "benchmark/kernels/all_reduce/benchmark_torch_symm_mem.py",
    "content": "\"\"\"For Now, TORCH_SYMM_MEM is only supported on following limited tp case\n\nSM90: {\n    2: 64 * MiB,  # 64 MB\n    4: 64 * MiB,  # 64 MB\n    6: 128 * MiB,  # 128 MB\n    8: 128 * MiB,  # 128 MB\n},\nSM100: {\n    2: 64 * MiB,  # 64 MB\n    4: 64 * MiB,  # 64 MB\n    6: 128 * MiB,  # 128 MB\n    8: 128 * MiB,  # 128 MB\n}\n\nexport WORLD_SIZE=8\nexport RANK=0\nexport MASTER_ADDR=127.0.0.1\nexport MASTER_PORT=12345\n\ntorchrun --nproc_per_node gpu \\\n--nnodes $WORLD_SIZE \\\n--node_rank $RANK \\\n--master_addr $MASTER_ADDR \\\n--master_port $MASTER_PORT ./benchmark/kernels/all_reduce/benchmark_torch_symm_mem.py\n\"\"\"\n\nimport os\nfrom contextlib import nullcontext\nfrom typing import List\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup\n\nfrom sglang.srt.distributed import init_distributed_environment\nfrom sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator\nfrom sglang.srt.distributed.device_communicators.torch_symm_mem import (\n    TorchSymmMemCommunicator,\n)\nfrom sglang.srt.distributed.parallel_state import (\n    get_tensor_model_parallel_group,\n    graph_capture,\n    initialize_model_parallel,\n    set_torch_symm_mem_all_reduce,\n)\n\n# CI environment detection\nIS_CI = (\n    os.getenv(\"CI\", \"false\").lower() == \"true\"\n    or os.getenv(\"GITHUB_ACTIONS\", \"false\").lower() == \"true\"\n)\n\n\ndef torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:\n    dist.all_reduce(torch_input, group=group)\n    return torch_input\n\n\ndef torch_symm_mem_allreduce(\n    torch_symm_mem_input: torch.Tensor, torch_symm_mem_comm: TorchSymmMemCommunicator\n) -> torch.Tensor:\n    return torch_symm_mem_comm.all_reduce(torch_symm_mem_input)\n\n\ndef pynccl_allreduce(\n    pynccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator\n) -> torch.Tensor:\n    pynccl_comm.all_reduce(pynccl_input)\n    return pynccl_input\n\n\ndef 
_bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):\n    graph_input = inp_randn.clone()\n    with graph_capture() as graph_capture_context:\n        graph = torch.cuda.CUDAGraph()\n        with torch.cuda.graph(graph, stream=graph_capture_context.stream):\n            for _ in range(graph_loop):\n                graph_out = func(graph_input)\n\n    graph.replay()\n    func_output = graph_out.clone()\n\n    for _ in range(warmup_loop):\n        graph.replay()\n    torch.cuda.synchronize()\n\n    start_event = torch.cuda.Event(enable_timing=True)\n    end_event = torch.cuda.Event(enable_timing=True)\n\n    latencies: List[float] = []\n    for _ in range(test_loop):\n        torch.cuda.synchronize()\n        dist.barrier()\n        start_event.record()\n        graph.replay()\n        end_event.record()\n        end_event.synchronize()\n        latencies.append(start_event.elapsed_time(end_event))\n    func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000\n    graph.reset()\n    return func_output, func_cost_us\n\n\ndef _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):\n    eager_input = inp_randn.clone()\n    eager_output = func(eager_input)\n    func_output = eager_output.clone()\n\n    for _ in range(warmup_loop):\n        func(eager_input)\n    torch.cuda.synchronize()\n\n    start_event = torch.cuda.Event(enable_timing=True)\n    end_event = torch.cuda.Event(enable_timing=True)\n    torch.cuda.synchronize()\n    start_event.record()\n    for _ in range(test_loop):\n        func(eager_input)\n    end_event.record()\n    torch.cuda.synchronize()\n    func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000\n\n    return func_output, func_cost_us\n\n\ndef get_torch_prof_ctx(do_prof: bool):\n    ctx = (\n        torch.profiler.profile(\n            activities=[\n                torch.profiler.ProfilerActivity.CPU,\n                torch.profiler.ProfilerActivity.CUDA,\n            ],\n            
record_shapes=True,\n            with_stack=True,\n        )\n        if do_prof\n        else nullcontext()\n    )\n    return ctx\n\n\ndef human_readable_size(size, decimal_places=1):\n    for unit in [\"B\", \"KiB\", \"MiB\", \"GiB\", \"TiB\", \"PiB\"]:\n        if size < 1024.0 or unit == \"PiB\":\n            break\n        size /= 1024.0\n    return f\"{size:.{decimal_places}f} {unit}\"\n\n\ntry:\n    from tabulate import tabulate\nexcept ImportError:\n    print(\"tabulate not installed, skipping table printing\")\n    tabulate = None\n\n\ndef print_markdown_table(data):\n    if tabulate is not None:\n        print(tabulate(data, headers=\"keys\", tablefmt=\"github\"))\n        return\n    headers = data[0].keys()\n    header_row = \"| \" + \" | \".join(headers) + \" |\"\n    separator = \"| \" + \" | \".join([\"---\"] * len(headers)) + \" |\"\n    rows = []\n    for item in data:\n        row = \"| \" + \" | \".join(str(item[key]) for key in headers) + \" |\"\n        rows.append(row)\n    markdown_table = \"\\n\".join([header_row, separator] + rows)\n    print(markdown_table)\n\n\nif __name__ == \"__main__\":\n    import logging\n\n    logging.basicConfig(\n        level=logging.INFO,\n        format=\"%(asctime)s - %(levelname)s - %(message)s\",\n        datefmt=\"%Y-%m-%d %H:%M:%S\",\n        force=True,\n    )\n    if not dist.is_initialized():\n        dist.init_process_group(backend=\"nccl\")\n    world, world_size = dist.group.WORLD, dist.get_world_size()\n    rank = dist.get_rank()\n    torch.cuda.set_device(rank % 8)\n    device = torch.cuda.current_device()\n    set_torch_symm_mem_all_reduce(True)\n    init_distributed_environment(\n        world_size=world_size,\n        rank=rank,\n        local_rank=rank % 8,\n    )\n    initialize_model_parallel(tensor_model_parallel_size=world_size)\n    group = get_tensor_model_parallel_group().device_group\n    cpu_group = get_tensor_model_parallel_group().cpu_group\n    pynccl_comm = 
get_tensor_model_parallel_group().pynccl_comm\n    torch_symm_mem_comm = get_tensor_model_parallel_group().torch_symm_mem_comm\n    dist.barrier()\n    profile = False\n    dtype = torch.bfloat16\n    ctx = get_torch_prof_ctx(profile)\n    result = []\n\n    with ctx:\n        if IS_CI:\n            i_range = range(10, 11)\n        else:\n            i_range = range(10, 20)\n        for i in i_range:\n            sz = 2**i\n            if sz * dtype.itemsize > 2**24:\n                break\n            inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)\n\n            memory = torch.empty_like(inp_randn)\n            memory_out = torch.empty_like(memory)\n            torch_eager_output, torch_eager_time = _bench_eager_time(\n                lambda inp: torch_allreduce(inp, group), inp_randn\n            )\n            symm_mem_eager_output, symm_mem_eager_time = _bench_eager_time(\n                lambda inp: torch_symm_mem_allreduce(inp, torch_symm_mem_comm),\n                inp_randn,\n            )\n            symm_mem_graph_output, symm_mem_graph_time = _bench_graph_time(\n                lambda inp: torch_symm_mem_allreduce(inp, torch_symm_mem_comm),\n                inp_randn,\n            )\n            # since pynccl is inplace op, this return result is not correct if graph loop > 1\n            _, pynccl_graph_time = _bench_graph_time(\n                lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn\n            )\n            torch.testing.assert_close(torch_eager_output, symm_mem_graph_output)\n            torch.testing.assert_close(torch_eager_output, symm_mem_eager_output)\n            result.append(\n                {\n                    \"msg_size\": human_readable_size(inp_randn.nbytes),\n                    \"torch eager time\": torch_eager_time,\n                    \"symm mem eager time\": symm_mem_eager_time,\n                    \"symm mem graph time\": symm_mem_graph_time,\n                    \"pynccl graph 
time\": pynccl_graph_time,\n                }\n            )\n            if rank == 0:\n                print(f\"sz={sz}, dtype={dtype}: correctness check PASS!\")\n    if rank == 0:\n        print_markdown_table(result)\n    if profile:\n        prof_dir = f\"prof/torch_symm_mem\"\n        os.makedirs(prof_dir, exist_ok=True)\n        ctx.export_chrome_trace(f\"{prof_dir}/trace_rank{dist.get_rank()}.json.gz\")\n"
  },
  {
    "path": "benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py",
    "content": "import itertools\nimport math\n\nimport cudnn\nimport torch\nimport torch.utils.benchmark as benchmark\nfrom flashinfer import BatchDecodeWithPagedKVCacheWrapper\n\nfrom sglang.srt.layers.attention.flashinfer_backend import should_use_tensor_core\nfrom sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd\n\n\ndef benchmark_forward(\n    fn,\n    *inputs,\n    repeats=10,\n    amp=False,\n    amp_dtype=torch.float16,\n    **kwinputs,\n):\n    def amp_wrapper(*inputs, **kwinputs):\n        with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n            fn(*inputs, **kwinputs)\n\n    t = benchmark.Timer(\n        stmt=\"fn_amp(*inputs, **kwinputs)\",\n        globals={\"fn_amp\": amp_wrapper, \"inputs\": inputs, \"kwinputs\": kwinputs},\n        num_threads=torch.get_num_threads(),\n    )\n    m = t.timeit(repeats)\n    return t, m\n\n\ndef time_fwd(func, *args, **kwargs):\n    time_f = benchmark_forward(func, *args, **kwargs)\n    return time_f[1].mean * 1e6\n\n\ndef decode_attention_sglang(\n    q,\n    kv_data,\n    batch_size,\n    kv_len,\n    head_num_q,\n    head_num_kv,\n    head_dim,\n    num_kv_splits,\n    warmup=10,\n):\n\n    k_buffer = kv_data[0].view(-1, head_num_kv, head_dim)\n    v_buffer = kv_data[1].view(-1, head_num_kv, head_dim)\n    o = torch.empty_like(q)\n    total_tokens = batch_size * kv_len\n    req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)\n    b_req_idx = torch.arange(0, batch_size).to(0).int()\n    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device=\"cuda\")\n    max_len_in_batch = kv_len\n    sm_scale = 1.0 / (head_dim**0.5)\n\n    attn_logits = torch.empty(\n        (batch_size, head_num_q, num_kv_splits, head_dim + 1),\n        dtype=torch.float32,\n        device=\"cuda\",\n    )\n\n    for _ in range(warmup):\n        decode_attention_fwd(\n            q,\n            k_buffer,\n            v_buffer,\n         
   o,\n            req_to_token,\n            b_req_idx,\n            b_seq_len,\n            attn_logits,\n            num_kv_splits,\n            sm_scale,\n        )\n\n    f = time_fwd(\n        decode_attention_fwd,\n        q,\n        k_buffer,\n        v_buffer,\n        o,\n        req_to_token,\n        b_req_idx,\n        b_seq_len,\n        attn_logits,\n        num_kv_splits,\n        sm_scale,\n    )\n\n    return f, o\n\n\ndef decode_attention_flashinfer(dtype, head_num_q, head_num_kv):\n    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=\"cuda\")\n    use_tensor_cores = should_use_tensor_core(\n        kv_cache_dtype=dtype,\n        num_attention_heads=head_num_q,\n        num_kv_heads=head_num_kv,\n    )\n    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(\n        workspace_buffer, \"NHD\", use_tensor_cores=use_tensor_cores\n    )\n\n    class FlashinferAttention(torch.autograd.Function):\n        @staticmethod\n        def forward(\n            ctx,\n            q,\n            kv_data,\n            batch_size,\n            kv_len,\n            head_num_q,\n            head_num_kv,\n            head_dim,\n            dtype,\n            warmup=10,\n        ):\n            total_tokens = batch_size * kv_len\n            kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len\n            kv_indices = torch.arange(0, total_tokens).to(0).int()\n            kv_last_page_len = torch.full(\n                (batch_size,), 1, dtype=torch.int32, device=\"cuda\"\n            )\n\n            flashinfer_decode_wrapper.end_forward()\n            flashinfer_decode_wrapper.begin_forward(\n                kv_indptr,\n                kv_indices,\n                kv_last_page_len,\n                head_num_q,\n                head_num_kv,\n                head_dim,\n                1,\n                pos_encoding_mode=\"NONE\",\n                data_type=dtype,\n            )\n\n            for _ in 
range(warmup):\n                o = flashinfer_decode_wrapper.forward(\n                    q.contiguous().view(-1, head_num_q, head_dim), kv_data\n                )\n\n            f = time_fwd(\n                flashinfer_decode_wrapper.forward,\n                q.contiguous().view(-1, head_num_q, head_dim),\n                kv_data,\n            )\n\n            return f, o\n\n    return FlashinferAttention\n\n\ndef convert_to_cudnn_type(torch_type):\n    if torch_type == torch.float16:\n        return cudnn.data_type.HALF\n    elif torch_type == torch.bfloat16:\n        return cudnn.data_type.BFLOAT16\n    elif torch_type == torch.float32:\n        return cudnn.data_type.FLOAT\n    elif torch_type == torch.int32:\n        return cudnn.data_type.INT32\n    elif torch_type == torch.int64:\n        return cudnn.data_type.INT64\n    else:\n        raise ValueError(\"Unsupported tensor data type.\")\n\n\ndef decode_attention_cudnn(\n    q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10\n):\n    # Prepare data: continuous q,k,v\n    dims_q = (batch_size, head_num_q, 1, head_dim)\n    strides_q = (head_num_q * head_dim, head_dim, head_num_q * head_dim, 1)\n    q_gpu = q.as_strided(dims_q, strides_q)\n    o_gpu = (\n        torch.empty(batch_size * head_num_q * head_dim)\n        .half()\n        .cuda()\n        .as_strided(dims_q, strides_q)\n    )\n\n    dims_kv = (batch_size, head_num_kv, kv_len, head_dim)\n    strides_kv = (\n        kv_len * head_num_kv * head_dim,\n        head_dim,\n        head_num_kv * head_dim,\n        1,\n    )\n    k_gpu = kv_data[0].as_strided(dims_kv, strides_kv)\n    v_gpu = kv_data[1].as_strided(dims_kv, strides_kv)\n\n    seq_len_q_gpu = torch.full((batch_size, 1, 1, 1), 1, device=\"cuda\")\n    seq_len_kv_gpu = torch.full((batch_size, 1, 1, 1), kv_len, device=\"cuda\")\n    attn_scale = 1.0 / (head_dim**0.5)\n\n    # Prepare data: paged k,v\n    block_size = 1\n    blocks_per_batch = math.ceil(kv_len 
/ block_size)\n    # [num_blocks, head_num_kv, block_size, head_dim], num_blocks = batch_size * blocks_per_batch\n    container_k_gpu = torch.cat(k_gpu.chunk(blocks_per_batch, dim=2), dim=0)\n    container_v_gpu = torch.cat(v_gpu.chunk(blocks_per_batch, dim=2), dim=0)\n    page_table_k_gpu = (\n        torch.linspace(\n            0,\n            batch_size * blocks_per_batch - 1,\n            batch_size * blocks_per_batch,\n            device=\"cuda\",\n            dtype=torch.int32,\n        )\n        .reshape(blocks_per_batch, 1, batch_size, 1)\n        .transpose(0, 2)\n    )\n    page_table_v_gpu = page_table_k_gpu.clone()\n\n    graph = cudnn.pygraph(\n        io_data_type=convert_to_cudnn_type(dtype),\n        intermediate_data_type=cudnn.data_type.FLOAT,\n        compute_data_type=cudnn.data_type.FLOAT,\n    )\n\n    q = graph.tensor_like(q_gpu)\n    container_k = graph.tensor_like(container_k_gpu)\n    container_v = graph.tensor_like(container_v_gpu)\n    page_table_k = graph.tensor_like(page_table_k_gpu)\n    page_table_v = graph.tensor_like(page_table_v_gpu)\n\n    seq_len_q = graph.tensor_like(seq_len_q_gpu)\n    seq_len_kv = graph.tensor_like(seq_len_kv_gpu)\n\n    o, _ = graph.sdpa(\n        name=\"sdpa\",\n        q=q,\n        k=container_k,  # Container K: non contiguous container with K blocks\n        v=container_v,  # Container V: non contiguous container with V blocks\n        is_inference=True,\n        attn_scale=attn_scale,\n        use_causal_mask=False,\n        use_padding_mask=True,\n        seq_len_q=seq_len_q,\n        seq_len_kv=seq_len_kv,\n        paged_attention_k_table=page_table_k,  # Page Table K: Tensor containing offsets to the container with K blocks\n        paged_attention_v_table=page_table_v,  # Page Table V: Tensor containing offsets to the container with V blocks\n        paged_attention_max_seq_len_kv=kv_len,  # The maximum sequence length for K caches (this is optional, but recommended)\n    )\n\n    
o.set_output(True).set_dim(dims_q).set_stride(strides_q)\n\n    graph.validate()\n    graph.build_operation_graph()\n    graph.create_execution_plans([cudnn.heur_mode.A])\n    graph.check_support()\n    graph.build_plans()\n\n    workspace = torch.empty(\n        graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8\n    )\n\n    variant_pack = {\n        q: q_gpu,\n        container_k: container_k_gpu,\n        container_v: container_v_gpu,\n        page_table_k: page_table_k_gpu,\n        page_table_v: page_table_v_gpu,\n        seq_len_q: seq_len_q_gpu,\n        seq_len_kv: seq_len_kv_gpu,\n        o: o_gpu,\n    }\n\n    for _ in range(warmup):\n        graph.execute(variant_pack, workspace)\n\n    f = time_fwd(\n        graph.execute,\n        variant_pack,\n        workspace,\n    )\n\n    return f, o_gpu.squeeze(dim=2)\n\n\ndef calculate_diff():\n\n    dtype = torch.float16\n    batch_size = 64\n    kv_len = 4096\n    head_num_q = 64\n    head_num_kv = 8\n    head_dim = 128\n\n    q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device=\"cuda\")\n    kv_data = (\n        torch.randn(\n            batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device=\"cuda\"\n        ),\n        torch.randn(\n            batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device=\"cuda\"\n        ),\n    )\n\n    _, output_sglang = decode_attention_sglang(\n        q,\n        kv_data,\n        batch_size,\n        kv_len,\n        head_num_q,\n        head_num_kv,\n        head_dim,\n        num_kv_splits=8,\n    )\n\n    attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply\n    _, output_flashinfer = attn_flashinfer(\n        q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype\n    )\n\n    _, output_cudnn = decode_attention_cudnn(\n        q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype\n    )\n\n    print(f\"SGLang output={output_sglang}\")\n    
print(f\"FlashInfer output={output_flashinfer}\")\n    print(f\"cuDNN output={output_cudnn}\")\n    if torch.allclose(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2):\n        print(\"✅ SGLang[Triton] and FlashInfer match\")\n    else:\n        print(\"❌ SGLang[Triton] and FlashInfer differ\")\n\n    if torch.allclose(output_sglang, output_cudnn, atol=1e-2, rtol=1e-2):\n        print(\"✅ SGLang[Triton] and cuDNN match\")\n    else:\n        print(\"❌ SGLang[Triton] and cuDNN differ\")\n\n\nif __name__ == \"__main__\":\n    calculate_diff()\n\n    head_dim = 128\n    dtype = torch.float16\n    batch_size_range = [2**i for i in range(0, 8, 2)]\n    kv_len_range = [2**i for i in range(6, 13, 1)]\n    configs = list(itertools.product(batch_size_range, kv_len_range))\n\n    for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:\n        attn_flashinfer = decode_attention_flashinfer(\n            dtype, head_num_q, head_num_kv\n        ).apply\n        for batch_size, kv_len in configs:\n            q = torch.randn(\n                batch_size, head_num_q, head_dim, dtype=dtype, device=\"cuda\"\n            )\n            kv_data = (\n                torch.randn(\n                    batch_size * kv_len,\n                    head_num_kv,\n                    head_dim,\n                    dtype=dtype,\n                    device=\"cuda\",\n                ),\n                torch.randn(\n                    batch_size * kv_len,\n                    head_num_kv,\n                    head_dim,\n                    dtype=dtype,\n                    device=\"cuda\",\n                ),\n            )\n            us_cudnn, output_cudnn = decode_attention_cudnn(\n                q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype\n            )\n            us_sglang, output_sglang = decode_attention_sglang(\n                q,\n                kv_data,\n                batch_size,\n                kv_len,\n                head_num_q,\n     
           head_num_kv,\n                head_dim,\n                num_kv_splits=8,\n            )\n            us_flashinfer, _ = attn_flashinfer(\n                q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype\n            )\n            print(\n                head_num_q,\n                \"  \",\n                head_num_kv,\n                \"  \",\n                batch_size,\n                \"  \",\n                kv_len,\n                \"  \",\n                us_cudnn,\n                \"  \",\n                us_sglang,\n                \"  \",\n                us_flashinfer,\n            )\n"
  },
  {
    "path": "benchmark/kernels/deepep/deepep_utils.py",
    "content": "# ADAPTED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/utils.py\n\nimport os\nimport sys\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\n\ndef init_dist(local_rank: int, num_local_ranks: int, args):\n    ip = args.master_addr\n    port = args.master_port\n    num_nodes = args.nnodes\n    node_rank = args.node_rank\n    assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8\n\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=f\"tcp://{ip}:{port}\",\n        world_size=num_nodes * num_local_ranks,\n        rank=node_rank * num_local_ranks + local_rank,\n    )\n    torch.set_default_dtype(torch.bfloat16)\n    torch.set_default_device(\"cuda\")\n    torch.cuda.set_device(local_rank)\n\n    return (\n        dist.get_rank(),\n        dist.get_world_size(),\n        dist.new_group(list(range(num_local_ranks * num_nodes))),\n    )\n\n\ndef calc_diff(x: torch.Tensor, y: torch.Tensor):\n    x, y = x.double() + 1, y.double() + 1\n    denominator = (x * x + y * y).sum()\n    sim = 2 * (x * y).sum() / denominator\n    return (1 - sim).item()\n\n\ndef per_token_cast_to_fp8(x: torch.Tensor):\n    assert x.dim() == 2 and x.size(1) % 128 == 0\n    m, n = x.shape\n    x_view = x.view(m, -1, 128)\n    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)\n    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(\n        m, n\n    ), (x_amax / 448.0).view(m, -1)\n\n\ndef per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):\n    x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)\n    x_scales = x_scales.view(x_fp8.size(0), -1, 1)\n    return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)\n\n\ndef inplace_unique(x: torch.Tensor, num_slots: int):\n    assert x.dim() == 2\n    mask = x < 0\n    x_padded = x.masked_fill(mask, num_slots)\n    bin_count = torch.zeros((x.size(0), num_slots + 1), 
dtype=x.dtype, device=x.device)\n    bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))\n    bin_count = bin_count[:, :num_slots]\n    sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)\n    sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)\n    sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values\n    x[:, :].fill_(-1)\n    valid_len = min(num_slots, x.size(1))\n    x[:, :valid_len] = sorted_bin_idx[:, :valid_len]\n\n\ndef create_grouped_scores(\n    scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int\n):\n    num_tokens, num_experts = scores.shape\n    scores = scores.view(num_tokens, num_groups, -1)\n    mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)\n    mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)\n    return (scores * mask).view(num_tokens, num_experts)\n\n\ndef bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):\n    # Flush L2 cache with 256 MB data\n    torch.cuda.synchronize()\n    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda\")\n\n    # Warmup\n    for _ in range(num_warmups):\n        fn()\n\n    # Flush L2\n    cache.zero_()\n\n    # Testing\n    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]\n    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]\n    for i in range(num_tests):\n        # Record\n        start_events[i].record()\n        fn()\n        end_events[i].record()\n        if post_fn is not None:\n            post_fn()\n    torch.cuda.synchronize()\n\n    times = np.array(\n        [s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)]\n    )[1:]\n    return np.average(times), np.min(times), np.max(times)\n\n\nclass empty_suppress:\n    def __enter__(self):\n        return self\n\n    def __exit__(self, *_):\n        pass\n\n\nclass suppress_stdout_stderr:\n    def __enter__(self):\n   
     self.outnull_file = open(os.devnull, \"w\")\n        self.errnull_file = open(os.devnull, \"w\")\n\n        self.old_stdout_fileno_undup = sys.stdout.fileno()\n        self.old_stderr_fileno_undup = sys.stderr.fileno()\n\n        self.old_stdout_fileno = os.dup(sys.stdout.fileno())\n        self.old_stderr_fileno = os.dup(sys.stderr.fileno())\n\n        self.old_stdout = sys.stdout\n        self.old_stderr = sys.stderr\n\n        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)\n        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)\n\n        sys.stdout = self.outnull_file\n        sys.stderr = self.errnull_file\n        return self\n\n    def __exit__(self, *_):\n        sys.stdout = self.old_stdout\n        sys.stderr = self.old_stderr\n\n        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)\n        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)\n\n        os.close(self.old_stdout_fileno)\n        os.close(self.old_stderr_fileno)\n\n        self.outnull_file.close()\n        self.errnull_file.close()\n\n\ndef bench_kineto(\n    fn,\n    kernel_names,\n    num_tests: int = 30,\n    suppress_kineto_output: bool = False,\n    trace_path: Optional[str] = None,\n    barrier_comm_profiling: bool = False,\n):\n    # Profile\n    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress\n    with suppress():\n        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)\n        with torch.profiler.profile(\n            activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule\n        ) as prof:\n            for i in range(2):\n                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead\n                if barrier_comm_profiling:\n                    lhs = torch.randn((8192, 8192), dtype=torch.float, device=\"cuda\")\n                    rhs = torch.randn((8192, 8192), dtype=torch.float, 
device=\"cuda\")\n                    lhs @ rhs\n                    dist.all_reduce(torch.ones(1, dtype=torch.float, device=\"cuda\"))\n                for _ in range(num_tests):\n                    fn()\n                prof.step()\n\n    # Parse the profiling table\n    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)\n    is_tupled = isinstance(kernel_names, tuple)\n    prof_lines = (\n        prof.key_averages()\n        .table(sort_by=\"cuda_time_total\", max_name_column_width=100)\n        .split(\"\\n\")\n    )\n    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names\n    assert all([isinstance(name, str) for name in kernel_names])\n    for name in kernel_names:\n        assert (\n            sum([name in line for line in prof_lines]) == 1\n        ), f\"Kernel {name} must appear exactly once in the profiling table\"\n\n    # Save Chrome traces\n    if trace_path is not None:\n        prof.export_chrome_trace(trace_path)\n\n    # Return average kernel times\n    units = {\"ms\": 1e3, \"us\": 1e6}\n    kernel_times = []\n    for name in kernel_names:\n        for line in prof_lines:\n            if name in line:\n                time_str = line.split()[-2]\n                for unit, scale in units.items():\n                    if unit in time_str:\n                        kernel_times.append(float(time_str.replace(unit, \"\")) / scale)\n                        break\n                break\n    return tuple(kernel_times) if is_tupled else kernel_times[0]\n\n\ndef hash_tensor(t: torch.Tensor):\n    return t.view(torch.int64).sum().item()\n"
  },
  {
    "path": "benchmark/kernels/deepep/tuning_deepep.py",
    "content": "# MODIFIED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py\n\n\"\"\"\nExample usage:\npython tuning_deepep.py --nnodes 4 --node-rank $MY_NODE_RANK --master-addr 1.2.3.4\nThen check `deepep_tuned.json`\n\"\"\"\n\nimport argparse\nimport json\nimport time\nfrom copy import deepcopy\nfrom pathlib import Path\n\n# noinspection PyUnresolvedReferences\nimport deep_ep\nimport torch\nimport torch.distributed as dist\nfrom deepep_utils import (\n    bench,\n    calc_diff,\n    create_grouped_scores,\n    init_dist,\n    inplace_unique,\n    per_token_cast_back,\n    per_token_cast_to_fp8,\n)\n\n\ndef test_main(\n    num_sms: int,\n    local_rank: int,\n    num_local_ranks: int,\n    num_ranks: int,\n    num_nodes: int,\n    rank: int,\n    buffer: deep_ep.Buffer,\n    group: dist.ProcessGroup,\n    args,\n):\n    # Settings\n    num_tokens, hidden, num_topk_groups, num_topk, num_experts = (\n        args.num_tokens,\n        args.hidden,\n        min(num_nodes, 4),\n        args.num_topk,\n        (args.num_experts // num_ranks) * num_ranks,\n    )\n    assert num_experts % num_ranks == 0 and num_local_ranks == 8\n    if local_rank == 0:\n        print(\n            f\"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}\",\n            flush=True,\n        )\n\n    # Random data\n    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device=\"cuda\") * rank\n    x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device=\"cuda\")\n    x_e4m3 = per_token_cast_to_fp8(x)\n    scores = (\n        torch.randn((num_tokens, num_experts), dtype=torch.float32, device=\"cuda\").abs()\n        + 1\n    )\n    group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)\n    group_idx = torch.topk(\n        group_scores, k=num_topk_groups, dim=-1, sorted=False\n    ).indices\n    masked_scores = create_grouped_scores(scores, group_idx, num_nodes)\n    
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[\n        1\n    ]\n    topk_weights = (\n        torch.ones((num_tokens, num_topk), dtype=torch.float32, device=\"cuda\") * rank\n    )\n    topk_weights_pure_rand = torch.randn(\n        (num_tokens, num_topk), dtype=torch.float32, device=\"cuda\"\n    )\n    rank_idx = topk_idx // (num_experts // num_ranks)\n    rank_idx.masked_fill_(topk_idx == -1, -1)\n    inplace_unique(rank_idx, num_ranks)\n    rdma_rank_idx = rank_idx // num_local_ranks\n    rdma_rank_idx.masked_fill_(rank_idx == -1, -1)\n    inplace_unique(rdma_rank_idx, num_nodes)\n\n    # RDMA dispatch counts\n    rdma_idx = topk_idx // (num_experts // num_nodes)\n    rdma_idx.masked_fill_(topk_idx == -1, -1)\n    inplace_unique(rdma_idx, num_nodes)\n    num_rdma_token_sent = rdma_idx.ne(-1).sum().item()\n\n    # Expert meta\n    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device=\"cuda\")\n    for i in range(num_experts):\n        num_tokens_per_expert[i] = (topk_idx == i).sum()\n    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()\n    dist.all_reduce(gbl_num_tokens_per_expert, group=group)\n\n    # Rank layout meta\n    num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device=\"cuda\")\n    num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device=\"cuda\")\n    token_idx_in_rank = torch.full(\n        (num_ranks, num_tokens), -1, dtype=torch.long, device=\"cuda\"\n    )\n    for i in range(num_ranks):\n        num_tokens_per_rank[i] = (rank_idx == i).sum()\n        token_sel = (rank_idx == i).max(dim=-1)[0]\n        count = token_sel.sum().item()\n        tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]\n        tokens[:count] = torch.sort(tokens[:count])[0]\n        token_idx_in_rank[i][tokens[:count]] = torch.arange(\n            count, dtype=torch.long, device=\"cuda\"\n        )\n    for i in range(num_nodes):\n        
num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()\n    token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)\n    is_token_in_rank = token_idx_in_rank >= 0\n    gbl_num_tokens_per_rank = num_tokens_per_rank.clone()\n    dist.all_reduce(gbl_num_tokens_per_rank, group=group)\n\n    (\n        ref_num_tokens_per_rank,\n        ref_num_tokens_per_rdma_rank,\n        ref_num_tokens_per_expert,\n        ref_is_token_in_rank,\n        _,\n    ) = buffer.get_dispatch_layout(topk_idx, num_experts)\n    assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)\n    assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)\n    assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)\n    assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)\n    t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]\n    if local_rank == 0:\n        print(f\"[layout] Kernel performance: {t * 1000:.3f} ms\", flush=True)\n        print(\"\", flush=True)\n    group.barrier()\n    time.sleep(1)\n\n    # Config\n    rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)\n    config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)\n\n    # Test dispatch\n    # noinspection PyShadowingNames\n    def check_data(check_x, recv_gbl_rank_prefix_sum):\n        assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))\n        check_start = 0\n        for i in range(num_ranks):\n            check_end = recv_gbl_rank_prefix_sum[i].item()\n            assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0\n            check_start = check_end\n\n    for previous_mode in (False, True):\n        for async_mode in (False, True):\n            for current_x in (x_pure_rand, x, x_e4m3):\n                for with_topk in (False, True):\n                    if local_rank == 0:\n                        print(\n                            f'[testing] Running 
with {\"FP8\" if isinstance(current_x, tuple) else \"BF16\"}, {\"with\" if with_topk else \"without\"} top-k (async={async_mode}, previous={previous_mode}) ...',\n                            flush=True,\n                            end=\"\",\n                        )\n                    dispatch_args = {\n                        \"x\": current_x,\n                        \"num_tokens_per_rank\": num_tokens_per_rank,\n                        \"num_tokens_per_rdma_rank\": num_tokens_per_rdma_rank,\n                        \"is_token_in_rank\": is_token_in_rank,\n                        \"num_tokens_per_expert\": num_tokens_per_expert,\n                        \"config\": config,\n                        \"async_finish\": async_mode,\n                    }\n                    if with_topk:\n                        dispatch_args.update(\n                            {\n                                \"topk_idx\": topk_idx,\n                                \"topk_weights\": (\n                                    topk_weights_pure_rand\n                                    if current_x is x_pure_rand\n                                    else topk_weights\n                                ),\n                            }\n                        )\n                    if previous_mode:\n                        dispatch_args.update({\"previous_event\": buffer.capture()})\n                    (\n                        recv_x,\n                        recv_topk_idx,\n                        recv_topk_weights,\n                        recv_num_tokens_per_expert_list,\n                        handle,\n                        event,\n                    ) = buffer.dispatch(**dispatch_args)\n                    event.current_stream_wait() if async_mode else ()\n                    recv_x = (\n                        per_token_cast_back(*recv_x)\n                        if isinstance(recv_x, tuple)\n                        else recv_x\n                    )\n\n                 
   # Checks\n                    recv_gbl_rank_prefix_sum = handle[-4]\n                    assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(\n                        0\n                    ), f\"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}\"\n                    assert (\n                        gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()\n                        == recv_num_tokens_per_expert_list\n                    )\n                    if current_x is not x_pure_rand:\n                        check_data(recv_x, recv_gbl_rank_prefix_sum)\n                    if with_topk:\n                        # Check `topk_idx`\n                        assert (\n                            recv_topk_idx.eq(-1)\n                            | (\n                                (recv_topk_idx >= 0)\n                                & (recv_topk_idx < (num_experts // num_ranks))\n                            )\n                        ).sum().item() == recv_topk_idx.numel()\n                        for i, count in enumerate(recv_num_tokens_per_expert_list):\n                            assert recv_topk_idx.eq(i).sum().item() == count\n\n                        # Check `topk_weights`\n                        if current_x is not x_pure_rand:\n                            recv_topk_weights[recv_topk_idx.eq(-1)] = (\n                                recv_topk_weights.amax(dim=1, keepdim=True).expand_as(\n                                    recv_topk_weights\n                                )[recv_topk_idx.eq(-1)]\n                            )\n                            check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)\n\n                    # Test cached dispatch (only supported without top-k data)\n                    if not with_topk:\n                        dispatch_args = {\n                            \"x\": current_x,\n                            \"handle\": handle,\n                            \"config\": config,\n          
            
      \"async_finish\": async_mode,\n                        }\n                        if previous_mode:\n                            dispatch_args.update({\"previous_event\": buffer.capture()})\n                        recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)\n                        event.current_stream_wait() if async_mode else ()\n                        recv_x = (\n                            per_token_cast_back(*recv_x)\n                            if isinstance(recv_x, tuple)\n                            else recv_x\n                        )\n                        if current_x is not x_pure_rand:\n                            check_data(recv_x, recv_gbl_rank_prefix_sum)\n\n                    # Test combine\n                    combine_args = {\n                        \"x\": recv_x,\n                        \"handle\": handle,\n                        \"config\": config,\n                        \"async_finish\": async_mode,\n                    }\n                    if with_topk:\n                        combine_args.update({\"topk_weights\": recv_topk_weights})\n                    if previous_mode:\n                        combine_args.update({\"previous_event\": buffer.capture()})\n                    combined_x, combined_topk_weights, event = buffer.combine(\n                        **combine_args\n                    )\n                    event.current_stream_wait() if async_mode else ()\n                    check_x = combined_x.float() / is_token_in_rank.sum(\n                        dim=1\n                    ).unsqueeze(1)\n                    ref_x = x_pure_rand if current_x is x_pure_rand else x\n                    assert calc_diff(check_x, ref_x) < 5e-6\n                    if with_topk:\n                        check_topk_weights = (\n                            combined_topk_weights\n                            if (current_x is x_pure_rand)\n                            else (\n                                
combined_topk_weights\n                                / is_token_in_rank.sum(dim=1).unsqueeze(1)\n                            )\n                        )\n                        ref_topk_weights = (\n                            topk_weights_pure_rand\n                            if current_x is x_pure_rand\n                            else topk_weights\n                        )\n                        assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9\n\n                    # For later tuning\n                    dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2\n                    dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2\n                    combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes\n                    combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes\n\n                    if local_rank == 0:\n                        print(\" passed\", flush=True)\n    if local_rank == 0:\n        print(\"\", flush=True)\n\n    output_data = {}\n\n    # Tune dispatch performance\n    best_dispatch_results = None\n    fp8_factor = (1 + 4 / 128) / 2\n    for current_x in (x_e4m3, x):\n        best_time, best_results = 1e10, None\n        rdma_send_bytes = (\n            (dispatch_bf16_rdma_send_bytes * fp8_factor)\n            if isinstance(current_x, tuple)\n            else dispatch_bf16_rdma_send_bytes\n        )\n        nvl_recv_bytes = (\n            (dispatch_bf16_nvl_recv_bytes * fp8_factor)\n            if isinstance(current_x, tuple)\n            else dispatch_bf16_nvl_recv_bytes\n        )\n        for nvl_chunk_size in range(4, 33, 4):\n            for rdma_chunk_size in range(4, 33, 4):\n                config_kwargs = {\n                    \"num_sms\": num_sms,\n                    \"num_max_nvl_chunked_send_tokens\": nvl_chunk_size,\n                    \"num_max_nvl_chunked_recv_tokens\": nvl_buffer_size,\n                    \"num_max_rdma_chunked_send_tokens\": rdma_chunk_size,\n                
    \"num_max_rdma_chunked_recv_tokens\": rdma_buffer_size,\n                }\n                config = deep_ep.Config(**config_kwargs)\n                tune_args = {\"x\": current_x, \"handle\": handle, \"config\": config}\n                t = bench(lambda: buffer.dispatch(**tune_args))[0]\n                if t < best_time:\n                    best_time, best_results = t, (\n                        num_sms,\n                        nvl_chunk_size,\n                        rdma_chunk_size,\n                        config_kwargs,\n                    )\n                if local_rank == 0:\n                    print(\n                        f\"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) \",\n                        flush=True,\n                    )\n        if local_rank == 0:\n            print(\n                f'[tuning] Best dispatch ({\"FP8\" if isinstance(current_x, tuple) else \"BF16\"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',\n                flush=True,\n            )\n            print(\"\", flush=True)\n            is_fp8 = isinstance(current_x, tuple)\n            if is_fp8:\n                output_data[\"normal_dispatch\"] = deepcopy(best_results[3])\n\n        if isinstance(current_x, tuple):\n            # Gather every rank's best FP8 config and adopt the one from rank 0\n            best_dispatch_results = torch.tensor(\n                [best_results[0], best_results[1], best_results[2]],\n                dtype=torch.int32,\n                device=\"cuda\",\n            )\n            all_best_fp8_results_list = [\n                torch.zeros_like(best_dispatch_results)\n                for _ in range(torch.distributed.get_world_size())\n            ]\n            
dist.all_gather(\n                
all_best_fp8_results_list, best_dispatch_results, group=group\n            )\n            best_dispatch_results = all_best_fp8_results_list[0].tolist()\n    dispatch_config = deep_ep.Config(\n        best_dispatch_results[0],\n        best_dispatch_results[1],\n        nvl_buffer_size,\n        best_dispatch_results[2],\n        rdma_buffer_size,\n    )\n\n    dispatch_args = {\n        \"x\": x,\n        \"num_tokens_per_rank\": num_tokens_per_rank,\n        \"num_tokens_per_rdma_rank\": num_tokens_per_rdma_rank,\n        \"is_token_in_rank\": is_token_in_rank,\n        \"num_tokens_per_expert\": num_tokens_per_expert,\n        \"config\": dispatch_config if dispatch_config is not None else config,\n    }\n    recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)\n\n    # Tune combine performance\n    best_time, best_results = 1e10, None\n    for nvl_chunk_size in range(1, 8, 1):\n        for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4):\n            config_kwargs = {\n                \"num_sms\": num_sms,\n                \"num_max_nvl_chunked_send_tokens\": nvl_chunk_size,\n                \"num_max_nvl_chunked_recv_tokens\": nvl_buffer_size,\n                \"num_max_rdma_chunked_send_tokens\": rdma_chunk_size,\n                \"num_max_rdma_chunked_recv_tokens\": rdma_buffer_size,\n            }\n            config = deep_ep.Config(**config_kwargs)\n            tune_args = {\"x\": recv_x, \"handle\": handle, \"config\": config}\n            t = bench(lambda: buffer.combine(**tune_args))[0]\n            if local_rank == 0:\n                print(\n                    f\"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) \",\n                    flush=True,\n                )\n                if t < best_time:\n                    best_time, best_results = t, (\n                        
num_sms,\n                        nvl_chunk_size,\n                        rdma_chunk_size,\n                        config_kwargs,\n                    )\n\n    if local_rank == 0:\n        print(\n            f\"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)\",\n            flush=True,\n        )\n        print(\"\", flush=True)\n        output_data[\"normal_combine\"] = deepcopy(best_results[3])\n\n    if rank == 0 and local_rank == 0:\n        _write_output(args, output_data)\n\n\ndef _write_output(args, output_data):\n    text = json.dumps(output_data, indent=4)\n    output_path = args.output_path\n    print(f\"Write to {output_path} with {text}\")\n    Path(output_path).write_text(text)\n\n\n# noinspection PyUnboundLocalVariable\ndef test_loop(local_rank: int, num_local_ranks: int, args):\n    num_nodes = args.nnodes\n    rank, num_ranks, group = init_dist(local_rank, num_local_ranks, args)\n\n    num_sms = args.num_sms\n    num_qps_per_rank = num_sms // 2\n\n    buffer = deep_ep.Buffer(\n        group,\n        int(1e9),\n        int(1e9),\n        low_latency_mode=False,\n        num_qps_per_rank=num_qps_per_rank,\n    )\n    assert num_local_ranks == 8 and num_ranks > 8\n    torch.manual_seed(rank)\n\n    for i in (num_sms,):\n        test_main(\n            i,\n            local_rank,\n            num_local_ranks,\n            num_ranks,\n            num_nodes,\n            rank,\n            buffer,\n            group,\n            args,\n        )\n        if local_rank == 0:\n            print(\"\", flush=True)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--num-sms\", type=int, default=24)\n    parser.add_argument(\"--num-tokens\", type=int, default=4096)\n    parser.add_argument(\"--hidden\", type=int, 
default=7168)\n    parser.add_argument(\"--num-topk\", type=int, default=8)\n    parser.add_argument(\"--num-experts\", type=int, default=256)\n    parser.add_argument(\"--output-path\", type=str, default=\"deepep_tuned.json\")\n    parser.add_argument(\"--nnodes\", type=int, default=1)\n    parser.add_argument(\"--node-rank\", type=int, default=0)\n    parser.add_argument(\"--master-addr\", type=str, default=\"127.0.0.1\")\n    parser.add_argument(\"--master-port\", type=int, default=8361)\n    args = parser.parse_args()\n    print(f\"Start system with {args=}\")\n\n    num_processes = 8\n    torch.multiprocessing.spawn(\n        test_loop, args=(num_processes, args), nprocs=num_processes\n    )\n"
  },
  {
    "path": "benchmark/kernels/deepseek/README.md",
    "content": "## DeepSeek kernels benchmark\n\n### Prerequisites\n- You should install [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM) from source before running `benchmark_deepgemm_fp8_gemm.py` and `benchmark_deepgemm_fp8_group_gemm.py`.\n\n### Benchmark\n- `benchmark_deepgemm_fp8_gemm.py`\n    ```bash\n    python benchmark_deepgemm_fp8_gemm.py --run_correctness --tp_size 1\n    ```\n\n- `benchmark_deepgemm_fp8_group_gemm.py`\n    ```bash\n    python benchmark_deepgemm_fp8_group_gemm.py --run_correctness --tp_size 1\n    ```\n\n- You can use the `--run_correctness` flag to verify the correctness of all kernel results.\n- You can use the `--tp_size` parameter to benchmark all FP8 W8A8 block-wise matrix multiplications involved in DeepSeek V3/R1 under the given tensor parallelism (TP) setting. This benchmark compares DeepSeek's open-source [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM) implementation with the SGLang and vLLM Triton implementations.\n"
  },
  {
    "path": "benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py",
    "content": "from typing import Tuple\n\nimport deep_gemm\nimport tilelang\nimport tilelang.language as T\nimport torch\nimport triton\nfrom deep_gemm import ceil_div\nfrom deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor\nfrom vllm.model_executor.layers.quantization.utils.fp8_utils import (\n    w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,\n)\n\nfrom sglang.benchmark.bench_utils import run_bench\nfrom sglang.srt.layers.quantization.fp8_kernel import (\n    w8a8_block_fp8_matmul_deepgemm as w8a8_block_fp8_matmul,\n)\n\n\n# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1\ndef tl_gemm(\n    M,\n    N,\n    K,\n    in_dtype,\n    out_dtype,\n    accum_dtype,\n):\n    assert in_dtype in [\n        \"e4m3_float8\",\n    ], \"Currently only e4m3_float8 is supported\"\n    assert out_dtype in [\n        \"bfloat16\",\n        \"float16\",\n    ], \"Currently only bfloat16 and float16 are supported\"\n\n    TILE_SIZE = (128, 128, 128)\n    block_M = TILE_SIZE[0]\n    block_N = TILE_SIZE[1]\n    block_K = TILE_SIZE[2]\n\n    A_shape = (M, K)\n    Scales_A_shape = (M, T.ceildiv(K, block_K))\n    B_shape = (N, K)\n    Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))\n    A_shared_shape = (block_M, block_K)\n    B_shared_shape = (block_N, block_K)\n    C_shared_shape = (block_M, block_N)\n\n    @T.prim_func\n    def main(\n        A: T.Buffer(A_shape, in_dtype),\n        scales_a: T.Buffer(Scales_A_shape, \"float32\"),\n        B: T.Buffer(B_shape, in_dtype),\n        scales_b: T.Buffer(Scales_B_shape, \"float32\"),\n        C: T.Buffer((M, N), out_dtype),\n    ):\n        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (\n            bx,\n            by,\n        ):\n\n            A_shared = T.alloc_shared(A_shared_shape, in_dtype)\n            B_shared = T.alloc_shared(B_shared_shape, in_dtype)\n         
   C_shared = T.alloc_shared(C_shared_shape, out_dtype)\n            Scale_C_shared = T.alloc_shared((block_M), \"float32\")\n            C_local = T.alloc_fragment(C_shared_shape, accum_dtype)\n            C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)\n\n            # Improve L2 Cache\n            T.use_swizzle(panel_size=10)\n\n            T.clear(C_local)\n            T.clear(C_local_accum)\n            K_iters = T.ceildiv(K, block_K)\n            for k in T.Pipelined(K_iters, num_stages=4):\n                # Load A into shared memory\n                T.copy(A[by * block_M, k * block_K], A_shared)\n                # Load B into shared memory\n                T.copy(B[bx * block_N, k * block_K], B_shared)\n                # Load scale into shared memory\n                Scale_B = scales_b[bx, k]\n                for i in T.Parallel(block_M):\n                    Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B\n\n                T.gemm(A_shared, B_shared, C_local, transpose_B=True)\n                # Promote to enable 2xAcc\n                for i, j in T.Parallel(block_M, block_N):\n                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]\n                T.clear(C_local)\n            # TMA store\n            T.copy(C_local_accum, C_shared)\n            T.copy(C_shared, C[by * block_M, bx * block_N])\n\n    return main\n\n\ndef per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n    assert x.dim() == 2 and x.size(1) % 128 == 0\n    m, n = x.shape\n    x_view = x.view(m, -1, 128)\n    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)\n    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(\n        m, n\n    ), (x_amax / 448.0).view(m, -1)\n\n\ndef per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n    assert x.dim() == 2\n    m, n = x.shape\n    x_padded = torch.zeros(\n        (ceil_div(m, 128) * 128, ceil_div(n, 128) * 
128), dtype=x.dtype, device=x.device\n    )\n    x_padded[:m, :n] = x\n    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)\n    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)\n    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)\n    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(\n        x_view.size(0), x_view.size(2)\n    )\n\n\ndef fp8_gemm_deepgemm(\n    x_fp8: torch.Tensor,\n    x_scale: torch.Tensor,\n    y_fp8: torch.Tensor,\n    y_scale: torch.Tensor,\n    m: int,\n    n: int,\n    k: int,\n):\n    \"\"\"DeepGEMM implementation of FP8 GEMM\"\"\"\n    out = torch.empty((m, n), device=\"cuda\", dtype=torch.bfloat16)\n\n    # Run DeepGEMM kernel\n    deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)\n    return out\n\n\ndef fp8_gemm_sglang(\n    x_fp8: torch.Tensor,\n    x_scale: torch.Tensor,\n    y_fp8: torch.Tensor,\n    y_scale: torch.Tensor,\n    m: int,\n    n: int,\n    k: int,\n):\n    \"\"\"SGLang implementation of FP8 GEMM\"\"\"\n    block_size = [128, 128]  # Matches the block size in per_block_cast_to_fp8\n\n    # Run SGLang kernel\n    out = w8a8_block_fp8_matmul(\n        x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16\n    )\n    return out\n\n\ndef fp8_gemm_vllm(\n    x_fp8: torch.Tensor,\n    x_scale: torch.Tensor,\n    y_fp8: torch.Tensor,\n    y_scale: torch.Tensor,\n    m: int,\n    n: int,\n    k: int,\n):\n    \"\"\"vLLM implementation of FP8 GEMM\"\"\"\n    block_size = [128, 128]  # Matches the block size in per_block_cast_to_fp8\n\n    # Run vLLM kernel\n    out = vllm_w8a8_block_fp8_matmul(\n        x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16\n    )\n    return out\n\n\ndef calculate_diff(m: int, n: int, k: int):\n    x = torch.randn((m, k), device=\"cuda\", dtype=torch.bfloat16)\n    y = torch.randn((n, k), device=\"cuda\", dtype=torch.bfloat16)\n\n    x_fp8, x_scale = per_token_cast_to_fp8(x.clone())\n    
y_fp8, y_scale = per_block_cast_to_fp8(y.clone())\n    x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())\n\n    out_deepgemm = fp8_gemm_deepgemm(\n        x_fp8.clone(),\n        x_scale_col_major.clone(),\n        y_fp8.clone(),\n        y_scale.clone(),\n        m,\n        n,\n        k,\n    )\n    out_sglang = fp8_gemm_sglang(\n        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k\n    )\n\n    tilelang_func = tl_gemm(m, n, k, \"e4m3_float8\", \"bfloat16\", \"float32\")\n    tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])\n    out_tilelang = tilelang_kernel(\n        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone()\n    )\n\n    diff_sglang_deepgemm = torch.abs(out_deepgemm - out_sglang).mean().item()\n    diff_tilelang_deepgemm = torch.abs(out_deepgemm - out_tilelang).mean().item()\n    diff_tilelang_sglang = torch.abs(out_tilelang - out_sglang).mean().item()\n\n    print(f\"Shape m={m}, n={n}, k={k}:\")\n    print(f\"DeepGEMM output: {out_deepgemm[0, 0:5]}\")\n    print(f\"SGLang output: {out_sglang[0, 0:5]}\")\n    print(f\"TileLang output: {out_tilelang[0, 0:5]}\")\n    print(f\"Mean absolute difference (SGLang-DeepGEMM): {diff_sglang_deepgemm}\")\n    print(f\"Mean absolute difference (TileLang-DeepGEMM): {diff_tilelang_deepgemm}\")\n    print(f\"Mean absolute difference (TileLang-SGLang): {diff_tilelang_sglang}\")\n\n    sglang_deepgemm_match = torch.allclose(\n        out_deepgemm, out_sglang, atol=1e-2, rtol=1e-2\n    )\n    tilelang_deepgemm_match = torch.allclose(\n        out_deepgemm, out_tilelang, atol=1e-2, rtol=1e-2\n    )\n    tilelang_sglang_match = torch.allclose(\n        out_tilelang, out_sglang, atol=1e-2, rtol=1e-2\n    )\n\n    if sglang_deepgemm_match and tilelang_deepgemm_match and tilelang_sglang_match:\n        print(\"✅ All implementations match\\n\")\n    else:\n        print(\"❌ Some implementations differ:\")\n        print(f\"  - SGLang vs DeepGEMM: {'✅' 
if sglang_deepgemm_match else '❌'}\")\n        print(f\"  - TileLang vs DeepGEMM: {'✅' if tilelang_deepgemm_match else '❌'}\")\n        print(f\"  - TileLang vs SGLang: {'✅' if tilelang_sglang_match else '❌'}\\n\")\n\n\ndef get_weight_shapes(tp_size):\n    # cannot TP\n    total = [\n        (512 + 64, 7168),\n        ((128 + 64) * 128, 7168),\n        (128 * (128 + 128), 512),\n        (7168, 16384),\n        (7168, 18432),\n    ]\n    # N can TP\n    n_tp = [\n        (18432 * 2, 7168),\n        ((128 + 64) * 128, 7168),\n        (128 * (128 + 128), 512),\n        (24576, 1536),\n        (4096, 7168),\n    ]\n    # K can TP\n    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]\n\n    weight_shapes = []\n    for t in total:\n        weight_shapes.append(t)\n    for n_t in n_tp:\n        new_t = (n_t[0] // tp_size, n_t[1])\n        weight_shapes.append(new_t)\n    for k_t in k_tp:\n        new_t = (k_t[0], k_t[1] // tp_size)\n        weight_shapes.append(new_t)\n\n    return weight_shapes\n\n\ndef create_benchmark_configs(tp_size):\n    configs = []\n    weight_shapes = get_weight_shapes(tp_size)\n    batch_sizes = [8, 16, 32, 64, 128, 256, 1024, 2048, 4096]\n\n    for n, k in weight_shapes:\n        for m in batch_sizes:\n            configs.append((m, n, k, tp_size))\n\n    return configs\n\n\ndef get_benchmark(tp_size):\n    all_configs = create_benchmark_configs(tp_size)\n\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"m\", \"n\", \"k\", \"tp_size\"],\n            x_vals=[list(config) for config in all_configs],\n            line_arg=\"provider\",\n            line_vals=[\"deepgemm\", \"sglang\", \"tilelang\"],\n            line_names=[\"DeepGEMM\", \"SGLang\", \"TileLang\"],\n            styles=[(\"blue\", \"-\"), (\"red\", \"-\"), (\"green\", \"-\")],\n            ylabel=\"us\",\n            plot_name=f\"fp8-gemm-performance-comparison-tp{tp_size}\",\n            args={},\n        )\n    )\n    def benchmark(m, 
n, k, tp_size, provider):\n        print(f\"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}\")\n        x = torch.randn((m, k), device=\"cuda\", dtype=torch.bfloat16)\n        y = torch.randn((n, k), device=\"cuda\", dtype=torch.bfloat16)\n\n        # Preprocess data before benchmarking\n        x_fp8, x_scale = per_token_cast_to_fp8(x)\n        y_fp8, y_scale = per_block_cast_to_fp8(y)\n        x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())\n\n        quantiles = (0.5, 0.2, 0.8)\n\n        if provider == \"deepgemm\":\n            ms, min_ms, max_ms = run_bench(\n                lambda: fp8_gemm_deepgemm(\n                    x_fp8.clone(),\n                    x_scale_col_major.clone(),\n                    y_fp8.clone(),\n                    y_scale.clone(),\n                    m,\n                    n,\n                    k,\n                ),\n                quantiles=quantiles,\n            )\n        elif provider == \"sglang\":\n            ms, min_ms, max_ms = run_bench(\n                lambda: fp8_gemm_sglang(\n                    x_fp8.clone(),\n                    x_scale.clone(),\n                    y_fp8.clone(),\n                    y_scale.clone(),\n                    m,\n                    n,\n                    k,\n                ),\n                quantiles=quantiles,\n            )\n        else:  # tilelang\n            tilelang_func = tl_gemm(m, n, k, \"e4m3_float8\", \"bfloat16\", \"float32\")\n            tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])\n            ms, min_ms, max_ms = run_bench(\n                lambda: tilelang_kernel(\n                    x_fp8.clone(),\n                    x_scale.clone(),\n                    y_fp8.clone(),\n                    y_scale.clone(),\n                ),\n                quantiles=quantiles,\n            )\n\n        # Calculate TFLOPS\n        flops = 2 * m * n * k  # multiply-adds\n        tflops = flops / (ms * 1e-3) / 
1e12\n\n        # Print shape-specific results with TFLOPS\n        print(f\"Time: {ms*1000:.2f} us, TFLOPS: {tflops:.2f}\")\n        return ms * 1000, max_ms * 1000, min_ms * 1000  # convert ms to us\n\n    return benchmark\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--save_path\",\n        type=str,\n        default=\"./configs/benchmark_ops/fp8_gemm/\",\n        help=\"Path to save fp8 gemm benchmark results\",\n    )\n    parser.add_argument(\n        \"--run_correctness\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to run correctness tests\",\n    )\n    parser.add_argument(\n        \"--tp_size\",\n        type=int,\n        default=1,\n        help=\"Tensor parallelism size to benchmark (default: 1)\",\n    )\n    args = parser.parse_args()\n\n    # Set random seed for reproducibility\n    torch.manual_seed(0)\n    torch.cuda.manual_seed(0)\n\n    # Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148\n    torch.backends.cuda.matmul.allow_tf32 = True\n    torch.backends.cudnn.allow_tf32 = True\n\n    # Run correctness tests on a few examples\n    if args.run_correctness:\n        print(\"Running correctness tests...\")\n        calculate_diff(64, 512, 7168)  # Small test\n        calculate_diff(64, 7168, 16384)  # Medium test\n        calculate_diff(64, 18432, 7168)  # Large test\n\n    # Get the benchmark function with the specified tp_size\n    benchmark = get_benchmark(args.tp_size)\n\n    print(f\"Running performance benchmark for TP size = {args.tp_size}...\")\n    benchmark.run(print_data=True, save_path=args.save_path)\n"
  },
  {
    "path": "benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm_blackwell.py",
    "content": "import argparse\nfrom typing import Tuple\n\nimport torch\nimport triton\nfrom deep_gemm import ceil_div\nfrom flashinfer.gemm import gemm_fp8_nt_groupwise\n\nfrom sglang.benchmark.bench_utils import run_bench\nfrom sglang.srt.layers.quantization.fp8_kernel import (\n    sglang_per_token_group_quant_fp8,\n    w8a8_block_fp8_matmul_deepgemm,\n)\nfrom sglang.srt.layers.quantization.fp8_utils import requant_weight_ue8m0\n\nBLOCK_SIZE = 128\n\n\ndef per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n    assert x.dim() == 2\n    assert BLOCK_SIZE == 128\n    m, n = x.shape\n    x_padded = torch.zeros(\n        (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device\n    )\n    x_padded[:m, :n] = x\n    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)\n    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)\n    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)\n    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(\n        x_view.size(0), x_view.size(2)\n    )\n\n\ndef get_weight_shapes(tp_size):\n    # cannot TP\n    total = [\n        (512 + 64, 7168),\n        ((128 + 64) * 128, 7168),\n        (128 * (128 + 128), 512),\n        (7168, 16384),\n        (7168, 18432),\n    ]\n    # N can TP\n    n_tp = [\n        (18432 * 2, 7168),\n        ((128 + 64) * 128, 7168),\n        (128 * (128 + 128), 512),\n        (24576, 1536),\n        (4096, 7168),\n    ]\n    # K can TP\n    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]\n\n    weight_shapes = []\n    for t in total:\n        weight_shapes.append(t)\n    for n_t in n_tp:\n        new_t = (n_t[0] // tp_size, n_t[1])\n        weight_shapes.append(new_t)\n    for k_t in k_tp:\n        new_t = (k_t[0], k_t[1] // tp_size)\n        weight_shapes.append(new_t)\n\n    return weight_shapes\n\n\ndef create_benchmark_configs(tp_size):\n    configs = []\n    weight_shapes = 
get_weight_shapes(tp_size)\n    batch_sizes = [8, 16, 32, 64, 128, 256, 1024, 2048, 4096]\n\n    for n, k in weight_shapes:\n        for m in batch_sizes:\n            configs.append((m, n, k, tp_size))\n\n    return configs\n\n\ndef fp8_gemm_flashinfer(\n    x_fp8: torch.Tensor,\n    x_scale: torch.Tensor,\n    y_fp8: torch.Tensor,\n    y_scale: torch.Tensor,\n):\n    \"\"\"Flashinfer implementation of FP8 GEMM\"\"\"\n    output = gemm_fp8_nt_groupwise(\n        x_fp8,\n        y_fp8,\n        x_scale,\n        y_scale,\n        out_dtype=torch.bfloat16,\n        backend=\"trtllm\",\n    )\n    return output\n\n\ndef fp8_gemm_deepgemm_blackwell(\n    x_fp8: torch.Tensor,\n    x_scale: torch.Tensor,\n    y_fp8: torch.Tensor,\n    y_scale: torch.Tensor,\n):\n    \"\"\"DeepGEMM implementation of FP8 GEMM\"\"\"\n    block_size = [BLOCK_SIZE, BLOCK_SIZE]\n    output = w8a8_block_fp8_matmul_deepgemm(\n        x_fp8, y_fp8, x_scale, y_scale, block_size, output_dtype=torch.bfloat16\n    )\n    return output\n\n\ndef check_accuracy(a, b, atol, rtol, percent):\n    \"\"\"Unified accuracy checking function with detailed error reporting.\"\"\"\n    if not torch.isfinite(a).all():\n        print(\"Non-finite values in reference output\")\n        return False\n    if not torch.isfinite(b).all():\n        print(\"Non-finite values in actual output\")\n        return False\n    assert a.shape == b.shape, f\"Shape mismatch: {a.shape} vs {b.shape}\"\n\n    close = torch.isclose(a, b, atol=atol, rtol=rtol)\n    match_ratio = close.float().mean()\n    if match_ratio >= percent:\n        return True\n\n    mismatch_percent = 1.0 - match_ratio.item()\n    if mismatch_percent > 1 - percent:\n        print(\n            f\"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} \"\n            f\"(threshold: {1 - percent:.4f})\"\n        )\n        return False\n\n\ndef calculate_diff(m: int, n: int, k: int):\n    x = torch.randn((m, k), device=\"cuda\", dtype=torch.bfloat16)\n    
y = torch.randn((n, k), device=\"cuda\", dtype=torch.bfloat16)\n\n    y_fp8, y_scale = per_block_cast_to_fp8(y)\n    x_fp8, x_scale = sglang_per_token_group_quant_fp8(\n        x, BLOCK_SIZE, column_major_scales=True\n    )\n    out_flashinfer = fp8_gemm_flashinfer(\n        x_fp8,\n        x_scale,\n        y_fp8,\n        y_scale,\n    )\n\n    dg_x_fp8, dg_x_scale = sglang_per_token_group_quant_fp8(\n        x,\n        BLOCK_SIZE,\n        column_major_scales=True,\n        scale_tma_aligned=True,\n        scale_ue8m0=True,\n    )\n    # We can directly quantize y here, but to mimic the behavior of the actual\n    # implementations, we requant it here.\n    dg_y_fp8, dg_y_scale = requant_weight_ue8m0(\n        y_fp8, y_scale, [BLOCK_SIZE, BLOCK_SIZE]\n    )\n    out_deepgemm = fp8_gemm_deepgemm_blackwell(\n        dg_x_fp8, dg_x_scale, dg_y_fp8, dg_y_scale\n    )\n\n    print(f\"Shape m={m}, n={n}, k={k}:\")\n    print(f\"Flashinfer output: {out_flashinfer[0, 0:5]}\")\n    print(f\"DeepGEMM output: {out_deepgemm[0, 0:5]}\")\n\n    flashinfer_deepgemm_match = check_accuracy(\n        out_flashinfer, out_deepgemm, 0.1, 0.6, 0.95\n    )\n    print(\"Correctness check:\")\n    print(f\"  - Flashinfer vs DeepGEMM: {'✅' if flashinfer_deepgemm_match else '❌'}\")\n\n\ndef _benchmark(m, n, k, tp_size, provider):\n    print(f\"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}\")\n    x = torch.randn((m, k), device=\"cuda\", dtype=torch.bfloat16)\n    y = torch.randn((n, k), device=\"cuda\", dtype=torch.bfloat16)\n\n    # Preprocess data before benchmarking\n    y_fp8, y_scale = per_block_cast_to_fp8(y)\n    x_fp8, x_scale = sglang_per_token_group_quant_fp8(\n        x, BLOCK_SIZE, column_major_scales=True\n    )\n    dg_x_fp8, dg_x_scale = sglang_per_token_group_quant_fp8(\n        x,\n        BLOCK_SIZE,\n        column_major_scales=True,\n        scale_tma_aligned=True,\n        scale_ue8m0=True,\n    )\n    dg_y_fp8, dg_y_scale = requant_weight_ue8m0(\n  
      y_fp8, y_scale, [BLOCK_SIZE, BLOCK_SIZE]\n    )\n\n    quantiles = (0.5, 0.2, 0.8)\n\n    if provider == \"deepgemm\":\n        ms, min_ms, max_ms = run_bench(\n            lambda: fp8_gemm_deepgemm_blackwell(\n                dg_x_fp8,\n                dg_x_scale,\n                dg_y_fp8,\n                dg_y_scale,\n            ),\n            quantiles=quantiles,\n        )\n    elif provider == \"flashinfer\":\n        ms, min_ms, max_ms = run_bench(\n            lambda: fp8_gemm_flashinfer(\n                x_fp8,\n                x_scale,\n                y_fp8,\n                y_scale,\n            ),\n            quantiles=quantiles,\n        )\n\n    # Calculate TFLOPS\n    flops = 2 * m * n * k  # multiply-adds\n    tflops = flops / (ms * 1e-3) / 1e12\n\n    # Print shape-specific results with TFLOPS\n    print(f\"Time: {ms*1000:.2f} us, TFLOPS: {tflops:.2f}\")\n    return ms, max_ms, min_ms\n\n\ndef get_benchmark_plot_friendly(tp_size):\n    all_configs = create_benchmark_configs(tp_size)\n    x_vals = list(range(len(all_configs)))\n\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"cfg_id\"],\n            x_vals=x_vals,\n            line_arg=\"provider\",\n            line_vals=[\"deepgemm\", \"flashinfer\"],\n            line_names=[\"DeepGEMM\", \"Flashinfer\"],\n            styles=[(\"blue\", \"-\"), (\"red\", \"-\")],\n            ylabel=\"us\",\n            plot_name=f\"fp8-gemm-performance-comparison-tp{tp_size}\",\n            args={},\n        )\n    )\n    def benchmark(cfg_id, provider):\n        m, n, k, tp_size = all_configs[cfg_id]\n        ms, min_ms, max_ms = _benchmark(m, n, k, tp_size, provider)\n        return ms * 1000, max_ms * 1000, min_ms * 1000  # convert to ms\n\n    return benchmark\n\n\ndef get_benchmark(tp_size):\n    all_configs = create_benchmark_configs(tp_size)\n\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"m\", \"n\", 
\"k\", \"tp_size\"],\n            x_vals=[list(config) for config in all_configs],\n            line_arg=\"provider\",\n            line_vals=[\"deepgemm\", \"flashinfer\"],\n            line_names=[\"DeepGEMM\", \"Flashinfer\"],\n            styles=[(\"blue\", \"-\"), (\"red\", \"-\")],\n            ylabel=\"us\",\n            plot_name=f\"fp8-gemm-performance-comparison-tp{tp_size}\",\n            args={},\n        )\n    )\n    def benchmark(m, n, k, tp_size, provider):\n        ms, min_ms, max_ms = _benchmark(m, n, k, tp_size, provider)\n        return ms * 1000, max_ms * 1000, min_ms * 1000  # convert to ms\n\n    return benchmark\n\n\nif __name__ == \"__main__\":\n    if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] != 10:\n        print(\"Skipping benchmark because the device is not supported\")\n        exit(0)\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--save-path\",\n        type=str,\n        default=\"./configs/benchmark_ops/fp8_gemm/\",\n        help=\"Path to save fp8 gemm benchmark results\",\n    )\n    parser.add_argument(\n        \"--run-correctness\",\n        action=\"store_true\",\n        default=True,\n        help=\"Whether to run correctness test\",\n    )\n    parser.add_argument(\n        \"--tp-size\",\n        type=int,\n        default=1,\n        help=\"Tensor parallelism size to benchmark (default: 1)\",\n    )\n    parser.add_argument(\n        \"--plot-friendly\",\n        action=\"store_true\",\n        default=False,\n        help=\"Plot x axis as the config index instead of the m\",\n    )\n    args = parser.parse_args()\n\n    # Set random seed for reproducibility\n    torch.manual_seed(0)\n    torch.cuda.manual_seed(0)\n\n    # Run correctness tests on a few examples\n    if args.run_correctness:\n        print(\"Running correctness tests...\")\n        calculate_diff(64, 512, 7168)  # Small test\n        calculate_diff(64, 7168, 16384)  # Medium test\n        
calculate_diff(64, 18432, 7168)  # Large test\n\n    # Get the benchmark function with the specified tp_size\n    benchmark = (\n        get_benchmark_plot_friendly(args.tp_size)\n        if args.plot_friendly\n        else get_benchmark(args.tp_size)\n    )\n\n    print(f\"Running performance benchmark for TP size = {args.tp_size}...\")\n    benchmark.run(print_data=True, save_path=args.save_path)\n"
  },
  {
    "path": "benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py",
    "content": "from typing import Tuple\n\nimport deep_gemm\nimport torch\nimport triton\nimport triton.language as tl\nfrom deep_gemm import calc_diff\nfrom deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor\n\n# Import shared functionality from the regular GEMM benchmark\nfrom sglang.benchmark.bench_utils import run_bench\nfrom sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (\n    per_block_cast_to_fp8,\n    per_token_cast_to_fp8,\n)\n\n\ndef construct_grouped_and_flat_fp8(\n    x: torch.Tensor, y: torch.Tensor, num_groups: int, is_masked: bool\n) -> Tuple[\n    Tuple[torch.Tensor, torch.Tensor],  # grouped x_fp8\n    Tuple[torch.Tensor, torch.Tensor],  # grouped y_fp8\n    Tuple[torch.Tensor, torch.Tensor],  # flat x_fp8\n    Tuple[torch.Tensor, torch.Tensor],  # flat y_fp8\n    torch.Tensor,  # output\n    torch.Tensor,  # reference output\n]:\n    # Verify input shapes\n    m, k = x.shape\n    n, k_y = y.shape\n    assert k == k_y, f\"Incompatible shapes: x({m}, {k}), y({n}, {k_y})\"\n    assert m % num_groups == 0, f\"m({m}) must be divisible by num_groups({num_groups})\"\n    assert m % 4 == 0, f\"TMA alignment error: {m}\"\n\n    # Reshape inputs for grouped processing\n    m_per_group = m // num_groups\n    x_grouped = x.view(num_groups, m_per_group, k)\n    y_grouped = y.unsqueeze(0).expand(num_groups, n, k)\n\n    # Initialize output tensors\n    out = torch.empty((num_groups, m_per_group, n), device=\"cuda\", dtype=torch.bfloat16)\n    ref_out = torch.einsum(\"gmk,gnk->gmn\", x_grouped, y_grouped)\n\n    # Quantize grouped tensors\n    x_fp8_grouped = (\n        torch.empty_like(x_grouped, dtype=torch.float8_e4m3fn),\n        torch.empty(\n            (num_groups, m_per_group, k // 128), device=\"cuda\", dtype=torch.float\n        ),\n    )\n    y_fp8_grouped = (\n        torch.empty_like(y_grouped, dtype=torch.float8_e4m3fn),\n        torch.empty(\n            (num_groups, (n + 127) // 128, k // 128), device=\"cuda\", 
dtype=torch.float\n        ),\n    )\n    for i in range(num_groups):\n        x_fp8_grouped[0][i], x_fp8_grouped[1][i] = per_token_cast_to_fp8(x_grouped[i])\n        y_fp8_grouped[0][i], y_fp8_grouped[1][i] = per_block_cast_to_fp8(y_grouped[i])\n\n    # Quantize flat tensors\n    x_fp8_flat = per_token_cast_to_fp8(x)\n    y_fp8_flat = per_block_cast_to_fp8(y)\n\n    # For non-masked input, merge the group and M dims in output\n    if not is_masked:\n        x_fp8_grouped = (\n            x_fp8_grouped[0].view(-1, k),\n            per_token_cast_to_fp8(x_grouped.view(-1, k))[1],\n        )\n        out, ref_out = out.view(-1, n), ref_out.view(-1, n)\n\n    # Transpose earlier for testing\n    x_fp8_grouped = (\n        x_fp8_grouped[0],\n        get_mn_major_tma_aligned_tensor(x_fp8_grouped[1]),\n    )\n    x_fp8_flat = (x_fp8_flat[0], get_mn_major_tma_aligned_tensor(x_fp8_flat[1]))\n\n    return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out\n\n\n# Since we don't have a group gemm kernel in SGLang/vLLM, we implemented a\n# custom kernel based on the Triton tutorial.\n# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html\n@triton.jit\ndef fp8_gemm_group_triton_kernel(\n    # Pointers to matrices\n    a_ptr,\n    b_ptr,\n    c_ptr,\n    # Pointers to scaling factors\n    a_scale_ptr,\n    b_scale_ptr,\n    # Matrix dimensions\n    M,\n    N,\n    K,\n    # The stride variables represent how much to increase the ptr by when moving by 1\n    # element in a particular dimension.\n    stride_am,\n    stride_ak,\n    stride_bk,\n    stride_bn,\n    stride_cm,\n    stride_cn,\n    # Strides for scaling factors\n    stride_a_scale_m,\n    stride_a_scale_k,\n    stride_b_scale_n,\n    stride_b_scale_k,\n    # Meta-parameters\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    \"\"\"Kernel for computing the matmul C = A x B with FP8 
inputs and scaling factors.\n    A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n\n    Note: Block sizes must be multiples of 32 for optimal TMA performance.\n    \"\"\"\n    # Map program ids to the block of C it should compute\n    pid_group = tl.program_id(axis=0)  # Group ID\n    pid_n = tl.program_id(axis=1)  # N dimension ID\n\n    # Compute the M block ID within this group\n    group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)\n    pid_m_within_group = tl.program_id(axis=2) % group_size_m\n    pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group\n\n    # Create pointers for the first blocks of A and B\n    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n    # Initialize accumulator\n    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n    # Main loop\n    for k_block in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n        k_offset = k_block * BLOCK_SIZE_K\n\n        # Load the next block of A and B, with masks\n        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k_offset, other=0.0)\n        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_offset, other=0.0)\n\n        # Calculate indices for scaling factors for this K block\n        a_scale_ptrs = a_scale_ptr + (\n            offs_am * stride_a_scale_m + k_block * stride_a_scale_k\n        )\n        b_scale_ptrs = b_scale_ptr + (\n            pid_n * stride_b_scale_n + k_block * stride_b_scale_k\n        )\n\n        # Perform matrix multiplication in FP8\n        res = tl.dot(a, b)\n\n        # Load scaling factors for the current block\n        a_scale = tl.load(a_scale_ptrs)[:, None]  # [BLOCK_SIZE_M, 1]\n        b_scale = tl.load(b_scale_ptrs)\n\n        # 
Apply scaling factors to the accumulated result\n        accumulator += res * a_scale * b_scale\n\n        # Advance pointers\n        a_ptrs += BLOCK_SIZE_K * stride_ak\n        b_ptrs += BLOCK_SIZE_K * stride_bk\n\n    # Convert to bfloat16 for output\n    c = accumulator.to(tl.bfloat16)\n\n    # Write back the result\n    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n    tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):\n    \"\"\"\n    Perform matrix multiplication with FP8 inputs and proper scaling.\n\n    Args:\n        a_tuple: Tuple of (quantized_tensor, scale_factors) for input A\n        b_tuple: Tuple of (quantized_tensor, scale_factors) for input B\n        c: Output tensor in BF16 format\n        num_groups: Number of groups for grouped GEMM\n\n    Returns:\n        Result tensor in BF16 format\n    \"\"\"\n    # Unpack the tuples\n    a, a_scale = a_tuple\n    b, b_scale = b_tuple\n\n    M, K = a.shape\n    _, N = b.shape\n\n    # Configure block sizes - must be multiples of 32 for TMA alignment\n    BLOCK_SIZE_M = 128\n    BLOCK_SIZE_N = 128\n    BLOCK_SIZE_K = 128\n\n    # Calculate grid dimensions\n    num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)\n    num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)\n    num_groups_grid = triton.cdiv(num_pid_m, num_groups)\n\n    # 3D grid launch - (group, n_blocks, m_blocks_per_group)\n    grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))\n\n    fp8_gemm_group_triton_kernel[grid](\n        a,\n        b,\n        c,\n        a_scale,\n        b_scale,\n        M,\n        N,\n        K,\n        a.stride(0),\n        a.stride(1),\n        b.stride(0),\n        b.stride(1),\n        c.stride(0),\n        c.stride(1),\n        a_scale.stride(0),\n 
       1,  # Stride in the K dimension may be 1\n        b_scale.stride(0),\n        1 if b_scale.dim() > 1 else 0,\n        BLOCK_SIZE_M=BLOCK_SIZE_M,\n        BLOCK_SIZE_N=BLOCK_SIZE_N,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        GROUP_SIZE_M=num_groups,\n    )\n\n    return c\n\n\ndef fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):\n    deep_gemm.m_grouped_fp8_gemm_nt_contiguous(\n        x_fp8_grouped,\n        y_fp8_grouped,\n        out,\n        m_indices,\n    )\n    return out\n\n\ndef calculate_diff(m: int, n: int, k: int, num_groups: int):\n    print(f\"Shape (m={m}, n={n}, k={k}\")\n    x = torch.randn((m, k), device=\"cuda\", dtype=torch.bfloat16)\n    y = torch.randn((n, k), device=\"cuda\", dtype=torch.bfloat16)\n    x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (\n        construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)\n    )\n    m_per_group = m // num_groups\n    out_deepgemm = out.clone()\n    m_indices = torch.arange(0, num_groups, device=\"cuda\", dtype=torch.int)\n    m_indices = (\n        m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1)\n    )\n\n    fp8_gemm_group_deepgemm(\n        x_fp8_grouped,\n        y_fp8_grouped,\n        out_deepgemm,\n        m_indices,\n    )\n    torch.cuda.synchronize()\n\n    # Prepare inputs for Triton\n    a, a_scale = x_fp8_flat\n    b, b_scale = y_fp8_flat\n    b = b.T.contiguous()\n    # Ensure scales are in the right format and contiguous\n    a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()\n    M, _ = a.shape\n    _, N = b.shape\n    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)\n    out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)\n    torch.cuda.synchronize()\n\n    diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()\n    diff_torch_triton = torch.abs(out_torch - out_triton).mean().item()\n    diff_deepgemm_triton = 
torch.abs(out_deepgemm - out_triton).mean().item()\n\n    print(f\"Shape m={m}, n={n}, k={k}:\")\n    print(f\"Torch output: {out_torch[0, 0:5]}\")\n    print(f\"DeepGEMM output: {out_deepgemm[0, 0:5]}\")\n    print(f\"Triton output: {out_triton[0, 0:5]}\")\n    print(f\"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}\")\n    print(f\"Mean absolute difference (Torch-Triton): {diff_torch_triton}\")\n    print(f\"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}\")\n\n    deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch)\n    triton_torch_diff = calc_diff(out_triton, out_torch)\n    deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton)\n\n    DIFF_THRESHOLD = 0.001\n    all_match = (\n        deepgemm_torch_diff < DIFF_THRESHOLD\n        and triton_torch_diff < DIFF_THRESHOLD\n        and deepgemm_triton_diff < DIFF_THRESHOLD\n    )\n    if all_match:\n        print(\"✅ All implementations match\\n\")\n    else:\n        print(\"❌ Some implementations differ:\")\n        print(\n            f\"  - Torch vs DeepGEMM: {'✅' if deepgemm_torch_diff < DIFF_THRESHOLD else '❌'}\\n\"\n            f\"  - Torch vs Triton: {'✅' if triton_torch_diff < DIFF_THRESHOLD else '❌'}\\n\"\n            f\"  - DeepGEMM vs Triton: {'✅' if deepgemm_triton_diff < DIFF_THRESHOLD else '❌'}\"\n        )\n\n\ndef get_weight_shapes(tp_size):\n    # Shapes that are not sharded by TP\n    total = [\n        (512 + 64, 7168),\n        ((128 + 64) * 128, 7168),\n        (128 * (128 + 128), 512),\n        (7168, 16384),\n        (7168, 18432),\n    ]\n    # Shapes whose N dimension is sharded by TP\n    n_tp = [\n        (18432 * 2, 7168),\n        ((128 + 64) * 128, 7168),\n        (128 * (128 + 128), 512),\n        (24576, 1536),\n        (4096, 7168),\n    ]\n    # Shapes whose K dimension is sharded by TP\n    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]\n\n    weight_shapes = []\n    for t in total:\n        weight_shapes.append(t)\n    for n_t in n_tp:\n        new_t = (n_t[0] // tp_size, n_t[1])\n        weight_shapes.append(new_t)\n    for k_t in k_tp:\n        new_t = (k_t[0], k_t[1] // tp_size)\n        weight_shapes.append(new_t)\n\n    return weight_shapes\n\n\ndef create_benchmark_configs(tp_size):\n    configs = []\n    weight_shapes = get_weight_shapes(tp_size)\n    batch_sizes = [2048, 4096]\n    group_sizes = [4, 8]\n    for n, k in weight_shapes:\n        for m in batch_sizes:\n            for num_groups in group_sizes:\n                configs.append((m, n, k, num_groups, tp_size))\n\n    return configs\n\n\ndef get_benchmark(tp_size):\n    all_configs = create_benchmark_configs(tp_size)\n\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"m\", \"n\", \"k\", \"num_groups\", \"tp_size\"],\n            x_vals=[config for config in all_configs],\n            line_arg=\"provider\",\n            line_vals=[\"deepgemm\", \"triton\"],\n            line_names=[\"DeepGEMM\", \"Triton\"],\n            styles=[(\"blue\", \"-\"), (\"red\", \"-\")],\n            ylabel=\"ms\",\n            plot_name=f\"fp8-group-gemm-performance-comparison-tp{tp_size}\",\n            args={},\n        )\n    )\n    def benchmark(m, n, k, num_groups, tp_size, provider):\n        print(\n            f\"Shape (m={m}, n={n}, k={k}, tp={tp_size}, num_groups={num_groups}), Provider: {provider}\"\n        )\n        x = torch.randn((m, k), device=\"cuda\", dtype=torch.bfloat16)\n        y = torch.randn((n, k), device=\"cuda\", dtype=torch.bfloat16)\n        x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (\n            construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)\n        )\n        m_per_group = m // num_groups\n        m_indices = torch.arange(0, num_groups, device=\"cuda\", dtype=torch.int)\n        m_indices = (\n            m_indices.unsqueeze(-1)\n            .expand(num_groups, m_per_group)\n            .contiguous()\n            .view(-1)\n        )\n\n        quantiles = (0.5, 0.2, 0.8)\n\n        if provider == \"deepgemm\":\n            ms, min_ms, max_ms = run_bench(\n                lambda: fp8_gemm_group_deepgemm(\n                    x_fp8_grouped,\n                    y_fp8_grouped,\n                    out,\n                    m_indices,\n                ),\n                quantiles=quantiles,\n            )\n        elif provider == \"triton\":\n            # Prepare inputs for Triton\n            # Done outside the lambda so that, as with DeepGEMM, only the kernel call is timed\n            a, a_scale = x_fp8_flat\n            b, b_scale = y_fp8_flat\n            b = b.T.contiguous()\n            # Ensure scales are in the right format and contiguous\n            a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()\n            M, _ = a.shape\n            _, N = b.shape\n            c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)\n            ms, min_ms, max_ms = run_bench(\n                lambda: fp8_gemm_group_triton(\n                    (a, a_scale),\n                    (b, b_scale),\n                    c,\n                    num_groups,\n                ),\n                quantiles=quantiles,\n            )\n\n        # Calculate TFLOPS\n        flops = 2 * m * n * k  # one multiply and one add per MAC\n        tflops = flops / (ms * 1e-3) / 1e12\n\n        print(f\"Time: {ms:.3f} ms, TFLOPS: {tflops:.2f}\")\n        # Timings are in ms (matching ylabel); perf_report expects (y, y_min, y_max)\n        return ms, min_ms, max_ms\n\n    return benchmark\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--save_path\",\n        type=str,\n        default=\"./configs/benchmark_ops/fp8_group_gemm/\",\n        help=\"Path to save DeepGEMM FP8 grouped GEMM benchmark results\",\n    )\n    parser.add_argument(\n        \"--run_correctness\",\n        action=\"store_true\",\n        help=\"Run correctness tests before benchmarking\",\n    )\n    parser.add_argument(\n        \"--tp_size\",\n        type=int,\n        default=1,\n        
help=\"Tensor parallelism size to benchmark (default: 1)\",\n    )\n    args = parser.parse_args()\n\n    # Set random seed for reproducibility\n    torch.manual_seed(0)\n    torch.cuda.manual_seed(0)\n\n    # Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148\n    torch.backends.cuda.matmul.allow_tf32 = True\n    torch.backends.cudnn.allow_tf32 = True\n\n    # Run correctness tests on a few examples\n    if args.run_correctness:\n        print(\"Running correctness tests...\")\n        calculate_diff(8192, 7168, 4096, 4)\n        calculate_diff(8192, 2048, 7168, 4)\n        calculate_diff(4096, 7168, 4096, 8)\n        calculate_diff(4096, 2048, 7168, 8)\n        calculate_diff(4096, 576, 7168, 8)\n\n    # Get the benchmark function with the specified tp_size\n    benchmark = get_benchmark(args.tp_size)\n\n    print(f\"Running performance benchmark for TP size = {args.tp_size}...\")\n    benchmark.run(print_data=True, save_path=args.save_path)\n"
  },
  {
    "path": "benchmark/kernels/elementwise/benchmark_concat_mla.py",
    "content": "import torch\nimport triton\nimport triton.language as tl\nfrom sgl_kernel import concat_mla_k as concat_mla_k_cuda\n\nfrom sglang.benchmark.bench_utils import run_bench\n\nDEVICE = triton.runtime.driver.active.get_active_torch_device()\n\nnum_local_heads = 128\nqk_nope_head_dim = 128\nqk_rope_head_dim = 64\n\n\ndef create_data(num_tokens):\n    k_nope_container = torch.randn(\n        (num_tokens, num_local_heads, qk_nope_head_dim + 128),\n        dtype=torch.bfloat16,\n        device=\"cuda\",\n    )\n    k_nope = k_nope_container[:, :, :qk_nope_head_dim]\n\n    k_rope_container = torch.randn(\n        (num_tokens, 1, 128 + qk_rope_head_dim), dtype=torch.bfloat16, device=\"cuda\"\n    )\n    k_rope = k_rope_container[:, :, -qk_rope_head_dim:]\n\n    k = torch.empty(\n        (num_tokens, num_local_heads, qk_nope_head_dim + qk_rope_head_dim),\n        dtype=torch.bfloat16,\n        device=\"cuda\",\n    )\n    return dict(k=k, k_nope=k_nope, k_rope=k_rope)\n\n\ndef fn_torch(k, k_nope, k_rope):\n    k[..., :qk_nope_head_dim] = k_nope\n    k[..., qk_nope_head_dim:] = k_rope\n\n\ndef fn_hack_non_strided(k, k_nope, k_rope):\n    # Bandwidth proxy: writes as many elements as the real concat but ignores the\n    # per-head layout, so the resulting k is not a valid concatenation.\n    k_flatten_view = k.flatten()\n    k_flatten_view[: k_nope.numel()] = k_nope.flatten()\n\n    k2 = k_flatten_view[k_nope.numel() :].view(k_rope.numel(), -1)\n    # Copy in place; a plain `k2 = ...` would only rebind the local name.\n    k2[:] = k_rope.flatten()[:, None]\n\n\n@torch.compile(dynamic=True)\ndef fn_torch_compiled(k, k_nope, k_rope):\n    return fn_torch(k, k_nope, k_rope)\n\n\ndef fn_cuda(k, k_nope, k_rope):\n    concat_mla_k_cuda(k, k_nope, k_rope)\n\n\n@triton.jit\ndef fn_triton_kernel(\n    k_ptr,\n    k_nope_ptr,\n    k_rope_ptr,\n    num_tokens,\n    QK_NOPE_HEAD_DIM: tl.constexpr,\n    QK_ROPE_HEAD_DIM: tl.constexpr,\n    NUM_LOCAL_HEADS: tl.constexpr,\n    K_NOPE_STRIDE_0: tl.constexpr,\n    K_NOPE_STRIDE_1: tl.constexpr,\n    K_STRIDE_0: tl.constexpr,\n    K_STRIDE_1: tl.constexpr,\n    K_ROPE_STRIDE_0: tl.constexpr,\n    BLOCK_ROWS: tl.constexpr,\n):\n    pid = 
tl.program_id(axis=0)\n\n    token_id = pid * BLOCK_ROWS + tl.arange(0, BLOCK_ROWS)\n    token_mask = token_id < num_tokens\n\n    head_id = tl.arange(0, NUM_LOCAL_HEADS)\n\n    # nope\n    nope_sub_id = tl.arange(0, QK_NOPE_HEAD_DIM)\n    offs_nope = (\n        token_id[:, None, None] * K_NOPE_STRIDE_0\n        + head_id[None, :, None] * K_NOPE_STRIDE_1\n        + nope_sub_id[None, None, :]\n    )\n    offs_k = (\n        token_id[:, None, None] * K_STRIDE_0\n        + head_id[None, :, None] * K_STRIDE_1\n        + nope_sub_id[None, None, :]\n    )\n    vals_nope = tl.load(k_nope_ptr + offs_nope, mask=token_mask[:, None, None])\n    tl.store(k_ptr + offs_k, vals_nope, mask=token_mask[:, None, None])\n\n    # rope\n    rope_sub_id = tl.arange(0, QK_ROPE_HEAD_DIM)\n    offs_rope = token_id[:, None, None] * K_ROPE_STRIDE_0 + rope_sub_id[None, None, :]\n    offs_k = (\n        token_id[:, None, None] * K_STRIDE_0\n        + head_id[None, :, None] * K_STRIDE_1\n        + rope_sub_id[None, None, :]\n        + QK_NOPE_HEAD_DIM\n    )\n    vals_rope = tl.load(k_rope_ptr + offs_rope, mask=token_mask[:, None, None])\n    tl.store(k_ptr + offs_k, vals_rope, mask=token_mask[:, None, None])\n\n\ndef fn_triton(k, k_nope, k_rope):\n    assert k.device == DEVICE and k_nope.device == DEVICE and k_rope.device == DEVICE\n    num_tokens, _, _ = k.shape\n    grid = lambda meta: (triton.cdiv(num_tokens, meta[\"BLOCK_ROWS\"]),)\n    fn_triton_kernel[grid](\n        k,\n        k_nope,\n        k_rope,\n        num_tokens,\n        QK_NOPE_HEAD_DIM=qk_nope_head_dim,\n        QK_ROPE_HEAD_DIM=qk_rope_head_dim,\n        NUM_LOCAL_HEADS=num_local_heads,\n        K_NOPE_STRIDE_0=k_nope.stride(0),\n        K_NOPE_STRIDE_1=k_nope.stride(1),\n        K_STRIDE_0=k.stride(0),\n        K_STRIDE_1=k.stride(1),\n        K_ROPE_STRIDE_0=k_rope.stride(0),\n        BLOCK_ROWS=16,\n    )\n\n\ndef execute_and_get_output(f, data):\n    data[\"k\"].zero_()\n    f(**data)\n    assert 
data[\"k\"].sum().item() != 0\n    return data[\"k\"].clone()\n\n\ntorch.manual_seed(0)\ndata = create_data(num_tokens=32768)\noutput_ref = execute_and_get_output(fn_torch, data)\noutput_exp = execute_and_get_output(fn_cuda, data)\n# print(output_ref)\n# print(output_exp)\nif not torch.all(output_ref == output_exp):\n    abs_delta = torch.abs(output_ref - output_exp)\n    raise AssertionError(\n        f\"{output_ref=} {output_exp=} \"\n        f\"{abs_delta=} \"\n        f\"{torch.argwhere(abs_delta != 0.0)=} \"\n    )\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"num_tokens\"],  # Argument names to use as an x-axis for the plot.\n        x_vals=[\n            2048,\n            4096,\n            8192,\n            16384,\n            32768,\n        ],  # Different possible values for `x_names`.\n        x_log=False,  # Use a linear x axis.\n        line_arg=\"provider\",  # Argument name whose value corresponds to a different line in the plot.\n        line_vals=[\n            \"torch\",\n            \"torch_compiled\",\n            \"triton\",\n            \"hack_non_strided\",\n            \"cuda\",\n        ],  # Possible values for `line_arg`.\n        line_names=[\n            \"torch\",\n            \"torch_compiled\",\n            \"triton\",\n            \"hack_non_strided\",\n            \"cuda\",\n        ],  # Legend labels for the lines.\n        plot_name=\"concat-mla-k-performance\",  # Name for the plot; also used as the file name when saving the plot.\n        args={},  # Values for function arguments not in `x_names` and `line_arg`.\n    )\n)\ndef benchmark(num_tokens, provider):\n    data = create_data(num_tokens=num_tokens)\n    quantiles = (0.5, 0.2, 0.8)\n    fn = {\n        \"torch\": fn_torch,\n        \"torch_compiled\": fn_torch_compiled,\n        \"triton\": fn_triton,\n        \"hack_non_strided\": fn_hack_non_strided,\n        \"cuda\": fn_cuda,\n    }[provider]\n    ms, min_ms, max_ms = run_bench(lambda: fn(**data), quantiles=quantiles)\n    return ms, min_ms, max_ms\n\n\ntorch.cuda.cudart().cudaProfilerStart()\nbenchmark.run(print_data=True, show_plots=True)\ntorch.cuda.cudart().cudaProfilerStop()\n"
  },
  {
    "path": "benchmark/kernels/flashinfer_allreduce_fusion/README.md",
    "content": "# FlashInfer Fused AllReduce + RMSNorm Benchmark\n\nThis benchmark script is adapted from the [original implementation](https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py) by the vLLM community. It compares the performance of the FlashInfer fused operator used in SGLang (`trtllm_allreduce_fusion`: AllReduce + Residual Add + RMSNorm + optional quantization) with the conventional implementation (standard `tensor_model_parallel_all_reduce` followed by separate RMSNorm/quantization). Specifically, the script times two implementation paths: 1) standard AllReduce and RMSNorm executed separately; 2) FlashInfer's fused operator, which combines AllReduce, Residual Add, RMSNorm, and optional quantization.\n\nThis benchmark helps tune the IPC workspace size of the `flashinfer_allreduce_residual_rmsnorm` operator in SGLang and prepares for adopting FP8/FP4 quantized fused operators.\n\nScript path: `benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py`\n\n## Feature Overview\n\n- Compare average execution time (ms) and calculate speedup ratios for the following paths:\n  - standard_allreduce_rmsnorm (Standard AllReduce + RMSNorm)\n  - flashinfer_fused_allreduce_rmsnorm (Fused AllReduce + RMSNorm), including oneshot and twoshot modes\n  - Optionally, FP8/FP4 quantized fused paths compared with the corresponding standard paths\n- Use CUDA Graph capture and batch replay to reduce measurement noise\n- Automatically select the faster \"standard baseline\" (native or compiled version) as the denominator for speedup calculation\n- Optionally export results in Markdown format\n\n## Runtime Environment and Prerequisites\n\n- At least 2 GPUs; the script must be launched as multiple processes with `torchrun` (NCCL backend)\n- SGLang installed (or built from source) together with sgl-kernel and its custom operators\n\n## Quick Start (Command Examples)\n\nThe following examples use world_size=2. Adjust `--nproc_per_node` and the other parameters to match your machine:\n\n- Regular paths only (no quantization):\n```\ntorchrun --nproc_per_node=2 \\\nbenchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \\\n--no-quant --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100\n```\n\n- FP8 quantization paths only:\n```\ntorchrun --nproc_per_node=2 \\\nbenchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \\\n--quant-fp8 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100\n```\n\n- FP4 quantization paths only:\n```\ntorchrun --nproc_per_node=2 \\\nbenchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \\\n--quant-fp4 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100\n```\n\n- Larger hidden dimensions:\n```\ntorchrun --nproc_per_node=2 \\\nbenchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \\\n--no-quant --hidden-dim 4096 --seq-lens 512 1024 2048 4096 --trials 100\n```\n\n## Parameter Description\n\n- `--seq-lens`: List of sequence lengths to test (default: 128 512 1024 2048)\n- `--hidden-dim`: Hidden dimension (default: 8192)\n- `--dtypes`: List of data types, `float16|bfloat16|float32` (default: bfloat16)\n- `--no-residual`: Only test the \"no residual\" case (by default, both with and without residual are tested)\n- Mutually exclusive quantization options:\n  - `--no-quant`: Test without quantization only\n  - `--quant-fp8`: Test FP8 quantization only\n  - `--quant-fp4`: Test FP4 quantization only\n  - `--quant-all`: Test all modes (default)\n- FlashInfer related:\n  - `--disable-oneshot`: Disable oneshot mode (by default, oneshot is enabled and twoshot is tested as well)\n- Runtime configuration:\n  - `--warmup`: Number of warmup iterations before graph capture and before graph replay (default: 5)\n  - `--trials`: Number of benchmark iterations (default: 20; internally, each `graph.replay()` replays the operation multiple times in a batch)\n  - `--output-file`: Save results to a Markdown file (written by rank 0 only)\n\n## Output Example\n\nEach configuration prints a table with the average execution time and the speedup relative to the baseline (the faster standard implementation). For example:\n```\n================================================================================\nResults: seq_len=1024, hidden_dim=1024\ndtype=torch.bfloat16, residual=yes, quant_mode=none\n================================================================================\nOperation                                          Time (ms)    Speedup\n--------------------------------------------------------------------------------\nstandard_allreduce_rmsnorm                         0.024        0.98x\nstandard_allreduce_rmsnorm_native_compiled         0.023        baseline\nflashinfer_fused_allreduce_rmsnorm_oneshot         0.011        2.19x\nflashinfer_fused_allreduce_rmsnorm_twoshot         0.041        0.57x\n```\n\nIf `--output-file` is specified, the results for all configurations are summarized as Markdown tables in that file.\n\n## Important Notes and Recommendations\n\n- Distributed: The script initializes the distributed environment from the environment variables set by `torchrun` and binds tensors and communication groups to the current rank's device.\n- World size: `WORLD_SIZE > 1` is required to benchmark communication operators; otherwise, the script reports an error.\n- FlashInfer:\n  - If FlashInfer is not installed or its interfaces are missing, the script runs only the standard paths and logs a warning.\n  - The fused operator supports two modes, oneshot and twoshot; oneshot is enabled by default, and twoshot is always tested as well.\n- FP8/FP4:\n  - FP8 uses SGLang's FP8 utilities and dtype; the underlying format (e.g. `e4m3` or `e4m3fnuz`) is selected per platform.\n  - FP4 uses `scaled_fp4_quant` from `sglang.jit_kernel.nvfp4` and requires platform support.\n- CUDA Graph:\n  - Uses SGLang's `graph_capture()` to put communication into a capture-ready state, then captures the kernels with `torch.cuda.graph` to reduce measurement jitter.\n"
  },
  {
    "path": "benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py",
    "content": "# Modified from https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py\n\n\"\"\"\nBenchmark for FlashInfer fused collective operations vs standard operations.\n\nThis benchmark compares:\n1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant)\n2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations\n\nUsage with torchrun:\n    torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --no-quant --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100\n    torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp8 --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100\n    torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp4 --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100\n\n    torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --no-quant --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100\n    torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp8 --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100\n    torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp4 --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100\n\"\"\"\n\nimport argparse\nimport contextlib\nimport itertools\nimport logging\nimport os\nimport time\nfrom typing import Optional\n\nimport torch  # type: ignore\nimport torch.distributed as dist  # type: ignore\n\nfrom sglang.srt.distributed import get_tp_group, tensor_model_parallel_all_reduce\nfrom sglang.srt.distributed.parallel_state import (\n    cleanup_dist_env_and_memory,\n    graph_capture,\n    
init_distributed_environment,\n    initialize_model_parallel,\n)\nfrom sglang.srt.layers.layernorm import RMSNorm  # noqa\nfrom sglang.srt.layers.quantization.fp8_kernel import fp8_dtype as SGLANG_FP8_DTYPE\nfrom sglang.srt.layers.quantization.fp8_kernel import static_quant_fp8\n\ntry:\n    from sgl_kernel import fused_add_rmsnorm as SGL_FUSED_ADD_RMS_NORM\n    from sgl_kernel import rmsnorm as SGL_RMS_NORM\n\n    from sglang.jit_kernel.nvfp4 import scaled_fp4_quant as SGL_SCALED_FP4_QUANT\nexcept Exception:  # pragma: no cover - fallback on non-supported platforms\n    SGL_FUSED_ADD_RMS_NORM = None\n    SGL_RMS_NORM = None\n    SGL_SCALED_FP4_QUANT = None\n\nFP8_DTYPE = SGLANG_FP8_DTYPE\n\nlogger = logging.getLogger(__name__)\n\n# Try to import FlashInfer\ntry:\n    import flashinfer.comm as flashinfer_comm  # type: ignore\n\n    if not hasattr(flashinfer_comm, \"trtllm_allreduce_fusion\"):\n        flashinfer_comm = None\n        logger.warning(\n            \"FlashInfer comm module found but missing trtllm_allreduce_fusion\"\n        )\nexcept ImportError:\n    flashinfer_comm = None\n    logger.warning(\"FlashInfer not found, only benchmarking standard operations\")\n\n# Constants\nMiB = 1024 * 1024\n\n# FlashInfer max sizes per world size\n# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes\n# use --disable-oneshot to disable oneshot mode for very large input sizes\n_FI_MAX_SIZES = {\n    2: 64 * MiB,  # 64MB\n    4: 64 * MiB,  # 64MB\n    8: 64 * MiB,  # 64MB\n}\n\n# Global workspace tensor for FlashInfer\n_FI_WORKSPACE_TENSOR = None\n\n\ndef setup_flashinfer_workspace(\n    world_size: int,\n    rank: int,\n    hidden_dim: int,\n    max_token_num: int,\n    use_fp32_lamport: bool = False,\n):\n    \"\"\"Setup FlashInfer workspace for fused allreduce operations.\"\"\"\n    global _FI_WORKSPACE_TENSOR\n\n    if flashinfer_comm is None:\n        return None, None\n\n    if world_size not in _FI_MAX_SIZES:\n        logger.warning(\"FlashInfer not 
supported for world size %s\", world_size)\n        return None, None\n\n    try:\n        # Create IPC workspace\n        ipc_handles, workspace_tensor = (\n            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(\n                tp_rank=rank,\n                tp_size=world_size,\n                max_token_num=max_token_num,\n                hidden_dim=hidden_dim,\n                group=get_tp_group().device_group,\n                use_fp32_lamport=use_fp32_lamport,\n            )\n        )\n\n        _FI_WORKSPACE_TENSOR = workspace_tensor\n        return ipc_handles, workspace_tensor\n    except Exception as e:\n        logger.error(\"Failed to setup FlashInfer workspace: %s\", e)\n        return None, None\n\n\ndef cleanup_flashinfer_workspace(ipc_handles):\n    \"\"\"Cleanup FlashInfer workspace.\"\"\"\n    if flashinfer_comm is None or ipc_handles is None:\n        return\n\n    try:\n        group = get_tp_group().device_group\n        flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group)\n    except Exception as e:\n        logger.error(\"Failed to cleanup FlashInfer workspace: %s\", e)\n\n\nclass FlashInferFusedAllReduceParams:\n    \"\"\"Parameters for FlashInfer fused allreduce operations.\"\"\"\n\n    def __init__(\n        self,\n        rank: int,\n        world_size: int,\n        use_fp32_lamport: bool = False,\n        max_token_num: int = 1024,\n    ):\n        self.rank = rank\n        self.world_size = world_size\n        self.use_fp32_lamport = use_fp32_lamport\n        self.trigger_completion_at_end = True\n        self.launch_with_pdl = True\n        self.fp32_acc = True\n        self.max_token_num = max_token_num\n\n    def get_trtllm_fused_allreduce_kwargs(self):\n        return {\n            \"world_rank\": self.rank,\n            \"world_size\": self.world_size,\n            \"launch_with_pdl\": self.launch_with_pdl,\n            \"trigger_completion_at_end\": 
self.trigger_completion_at_end,\n            \"fp32_acc\": self.fp32_acc,\n        }\n\n\ndef flashinfer_fused_allreduce_rmsnorm(\n    input_tensor: torch.Tensor,\n    residual: Optional[torch.Tensor],\n    rms_gamma: torch.Tensor,\n    rms_eps: float,\n    allreduce_params: \"FlashInferFusedAllReduceParams\",\n    use_oneshot: bool,\n    norm_out: Optional[torch.Tensor] = None,\n):\n    \"\"\"FlashInfer fused allreduce + rmsnorm operation.\"\"\"\n    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:\n        raise RuntimeError(\"FlashInfer not available or workspace not initialized\")\n\n    if norm_out is None:\n        norm_out = input_tensor\n        residual_out = residual\n    else:\n        residual_out = input_tensor\n\n    flashinfer_comm.trtllm_allreduce_fusion(\n        allreduce_in=input_tensor,\n        token_num=input_tensor.shape[0],\n        residual_in=residual,\n        residual_out=residual_out,\n        norm_out=norm_out,\n        rms_gamma=rms_gamma,\n        rms_eps=rms_eps,\n        hidden_dim=input_tensor.shape[-1],\n        workspace_ptrs=_FI_WORKSPACE_TENSOR,\n        pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,\n        allreduce_out=None,\n        quant_out=None,\n        scale_out=None,\n        layout_code=None,\n        scale_factor=None,\n        use_oneshot=use_oneshot,\n        **allreduce_params.get_trtllm_fused_allreduce_kwargs(),\n    )\n\n\ndef flashinfer_fused_allreduce_rmsnorm_fp8_quant(\n    input_tensor: torch.Tensor,\n    residual: Optional[torch.Tensor],\n    rms_gamma: torch.Tensor,\n    rms_eps: float,\n    scale_factor: torch.Tensor,\n    allreduce_params: FlashInferFusedAllReduceParams,\n    use_oneshot: bool = True,\n    norm_out: Optional[torch.Tensor] = None,\n    quant_out: Optional[torch.Tensor] = None,\n):\n    \"\"\"FlashInfer fused allreduce + rmsnorm + FP8 quantization.\"\"\"\n    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:\n        raise 
RuntimeError(\"FlashInfer not available or workspace not initialized\")\n\n    if norm_out is None:\n        norm_out = input_tensor\n        residual_out = residual\n    else:\n        residual_out = input_tensor\n\n    flashinfer_comm.trtllm_allreduce_fusion(\n        allreduce_in=input_tensor,\n        token_num=input_tensor.shape[0],\n        residual_in=residual,\n        residual_out=residual_out,\n        norm_out=norm_out,\n        rms_gamma=rms_gamma,\n        rms_eps=rms_eps,\n        hidden_dim=input_tensor.shape[-1],\n        workspace_ptrs=_FI_WORKSPACE_TENSOR,\n        pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,\n        allreduce_out=None,\n        quant_out=quant_out,\n        scale_out=None,\n        layout_code=None,\n        scale_factor=scale_factor,\n        use_oneshot=use_oneshot,\n        **allreduce_params.get_trtllm_fused_allreduce_kwargs(),\n    )\n\n\ndef flashinfer_fused_allreduce_rmsnorm_fp4_quant(\n    input_tensor: torch.Tensor,\n    residual: Optional[torch.Tensor],\n    rms_gamma: torch.Tensor,\n    rms_eps: float,\n    input_global_scale: torch.Tensor,\n    allreduce_params: FlashInferFusedAllReduceParams,\n    quant_out: torch.Tensor,\n    use_oneshot: bool,\n    output_scale: torch.Tensor,\n    norm_out: Optional[torch.Tensor] = None,\n):\n    \"\"\"FlashInfer fused allreduce + rmsnorm + FP4 quantization.\"\"\"\n    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:\n        raise RuntimeError(\"FlashInfer not available or workspace not initialized\")\n\n    if norm_out is None:\n        norm_out = input_tensor\n        residual_out = residual\n    else:\n        residual_out = input_tensor\n\n    flashinfer_comm.trtllm_allreduce_fusion(\n        allreduce_in=input_tensor,\n        token_num=input_tensor.shape[0],\n        residual_in=residual,\n        residual_out=residual_out,\n        norm_out=norm_out,\n        rms_gamma=rms_gamma,\n        rms_eps=rms_eps,\n        
hidden_dim=input_tensor.shape[-1],\n        workspace_ptrs=_FI_WORKSPACE_TENSOR,\n        pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,\n        allreduce_out=None,\n        quant_out=quant_out,\n        scale_out=output_scale,\n        layout_code=None,\n        scale_factor=input_global_scale,\n        use_oneshot=use_oneshot,\n        **allreduce_params.get_trtllm_fused_allreduce_kwargs(),\n    )\n\n\ndef standard_allreduce_rmsnorm(\n    input_tensor: torch.Tensor,\n    residual: Optional[torch.Tensor],\n    rms_gamma: torch.Tensor,\n    rms_eps: float,\n    norm_out: Optional[torch.Tensor] = None,\n):\n    \"\"\"Standard allreduce + rmsnorm operations.\"\"\"\n    # All-reduce first\n    allreduce_out = tensor_model_parallel_all_reduce(input_tensor)\n    # Then RMS norm\n    if residual is not None:\n        # Fused add + RMS norm (in-place on allreduce_out)\n        if SGL_FUSED_ADD_RMS_NORM is not None:\n            SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps)\n        else:\n            rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)\n            rms.weight.data = rms_gamma\n            rms.forward_native(allreduce_out, residual)\n    else:\n        # Just RMS norm\n        if SGL_RMS_NORM is not None:\n            _ = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps)\n        else:\n            rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)\n            rms.weight.data = rms_gamma\n            _ = rms.forward_native(allreduce_out)\n\n\ndef standard_allreduce_rmsnorm_fp8_quant(\n    input_tensor: torch.Tensor,\n    residual: Optional[torch.Tensor],\n    rms_gamma: torch.Tensor,\n    rms_eps: float,\n    scale_factor: torch.Tensor,\n    norm_out: Optional[torch.Tensor] = None,\n    quant_out: Optional[torch.Tensor] = None,\n):\n    \"\"\"Standard allreduce + rmsnorm + FP8 quantization.\"\"\"\n    # All-reduce first\n    allreduce_out = tensor_model_parallel_all_reduce(input_tensor)\n\n    # Then 
RMS norm + static FP8 quantization\n    if residual is not None:\n        if SGL_FUSED_ADD_RMS_NORM is not None:\n            SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps)\n            quant_out, _ = static_quant_fp8(\n                allreduce_out, scale_factor, repeat_scale=False\n            )\n        else:\n            rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)\n            rms.weight.data = rms_gamma\n            normed, _ = rms.forward_native(allreduce_out, residual)\n            quant_out, _ = static_quant_fp8(normed, scale_factor, repeat_scale=False)\n        return quant_out, residual\n    else:\n        if SGL_RMS_NORM is not None:\n            normed = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps)\n        else:\n            rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)\n            rms.weight.data = rms_gamma\n            normed = rms.forward_native(allreduce_out)\n        quant_out, _ = static_quant_fp8(normed, scale_factor, repeat_scale=False)\n        return quant_out\n\n\ndef standard_allreduce_rmsnorm_fp4_quant(\n    input_tensor: torch.Tensor,\n    residual: Optional[torch.Tensor],\n    rms_gamma: torch.Tensor,\n    rms_eps: float,\n    input_global_scale: torch.Tensor,\n    quant_out: torch.Tensor,\n    output_scale: torch.Tensor,\n    norm_out: Optional[torch.Tensor] = None,\n):\n    \"\"\"Standard allreduce + rmsnorm + FP4 quantization.\"\"\"\n\n    # All-reduce first\n    allreduce_out = tensor_model_parallel_all_reduce(input_tensor)\n\n    # Then RMS norm\n    if residual is not None:\n        if SGL_FUSED_ADD_RMS_NORM is not None:\n            SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps)\n            quant_input = allreduce_out\n        else:\n            rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)\n            rms.weight.data = rms_gamma\n            quant_input, _ = rms.forward_native(allreduce_out, residual)\n        residual_out = residual\n    else:\n        if 
SGL_RMS_NORM is not None:\n            quant_input = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps)\n        else:\n            rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)\n            rms.weight.data = rms_gamma\n            quant_input = rms.forward_native(allreduce_out)\n        residual_out = allreduce_out\n\n    # Finally FP4 quantization\n    if SGL_SCALED_FP4_QUANT is None:\n        raise RuntimeError(\"scaled_fp4_quant is not available on this platform\")\n    quant_res, output_scale_res = SGL_SCALED_FP4_QUANT(quant_input, input_global_scale)\n    if residual is not None:\n        return quant_res, residual_out, output_scale_res\n    else:\n        return quant_res, quant_input\n\n\ndef standard_allreduce_rmsnorm_native(\n    input_tensor: torch.Tensor,\n    residual: Optional[torch.Tensor],\n    rmsnorm_layer: RMSNorm,\n    norm_out: Optional[torch.Tensor] = None,\n):\n    \"\"\"Standard allreduce + rmsnorm operations using native RMSNorm forward.\"\"\"\n    # All-reduce first\n    allreduce_out = tensor_model_parallel_all_reduce(input_tensor)\n    # Apply native RMSNorm\n    if residual is not None:\n        result = rmsnorm_layer.forward_native(allreduce_out, residual)\n        return result  # Returns (norm_out, residual_out)\n    else:\n        result = rmsnorm_layer.forward_native(allreduce_out)\n        return result  # Returns norm_out\n\n\ndef standard_allreduce_rmsnorm_fp8_quant_native(\n    input_tensor: torch.Tensor,\n    residual: Optional[torch.Tensor],\n    rmsnorm_layer: RMSNorm,\n    scale_factor: torch.Tensor,\n    norm_out: Optional[torch.Tensor] = None,\n    quant_out: Optional[torch.Tensor] = None,\n):\n    \"\"\"Standard allreduce + rmsnorm + FP8 quantization using native implementations.\"\"\"\n    # All-reduce first\n    allreduce_out = tensor_model_parallel_all_reduce(input_tensor)\n\n    # Apply native RMSNorm\n    if residual is not None:\n        norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, 
residual)\n    else:\n        norm_out = rmsnorm_layer.forward_native(allreduce_out)\n        residual_out = allreduce_out\n\n    # Apply native FP8 quantization\n    quant_out, _ = static_quant_fp8(norm_out, scale_factor, repeat_scale=False)\n\n    if residual is not None:\n        return quant_out, residual_out\n    else:\n        return quant_out\n\n\ndef standard_allreduce_rmsnorm_fp4_quant_native(\n    input_tensor: torch.Tensor,\n    residual: Optional[torch.Tensor],\n    rmsnorm_layer: RMSNorm,\n    input_global_scale: torch.Tensor,\n    quant_out: torch.Tensor,\n    output_scale: torch.Tensor,\n    norm_out: Optional[torch.Tensor] = None,\n):\n    \"\"\"Standard allreduce + rmsnorm + FP4 quantization using native RMSNorm.\"\"\"\n    # All-reduce first\n    allreduce_out = tensor_model_parallel_all_reduce(input_tensor)\n\n    # Apply native RMSNorm\n    if residual is not None:\n        norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual)\n        quant_input = norm_out\n    else:\n        norm_out = rmsnorm_layer.forward_native(allreduce_out)\n        quant_input = norm_out\n        residual_out = allreduce_out\n\n    # Apply FP4 quantization (still using fused CUDA op as there's no native FP4)\n    if SGL_SCALED_FP4_QUANT is None:\n        raise RuntimeError(\"scaled_fp4_quant is not available on this platform\")\n    quant_res, output_scale_res = SGL_SCALED_FP4_QUANT(quant_input, input_global_scale)\n\n    if residual is not None:\n        return quant_res, residual_out, output_scale_res\n    else:\n        return quant_res, norm_out\n\n\n# Compiled versions of native functions\n@torch.compile\ndef standard_allreduce_rmsnorm_native_compiled(\n    input_tensor: torch.Tensor,\n    residual: Optional[torch.Tensor],\n    rmsnorm_layer: RMSNorm,\n    norm_out: Optional[torch.Tensor] = None,\n):\n    \"\"\"Compiled version of standard allreduce + rmsnorm.\"\"\"\n    return standard_allreduce_rmsnorm_native(\n        input_tensor, 
residual, rmsnorm_layer, norm_out\n    )\n\n\n@torch.compile\ndef standard_allreduce_rmsnorm_fp8_quant_native_compiled(\n    input_tensor: torch.Tensor,\n    residual: Optional[torch.Tensor],\n    rmsnorm_layer: RMSNorm,\n    scale_factor: torch.Tensor,\n    norm_out: Optional[torch.Tensor] = None,\n    quant_out: Optional[torch.Tensor] = None,\n):\n    \"\"\"Compiled version of standard allreduce + rmsnorm + FP8 quantization.\"\"\"\n    return standard_allreduce_rmsnorm_fp8_quant_native(\n        input_tensor,\n        residual,\n        rmsnorm_layer,\n        scale_factor,\n        norm_out,\n        quant_out,\n    )\n\n\n@torch.compile\ndef standard_allreduce_rmsnorm_fp4_quant_native_compiled(\n    input_tensor: torch.Tensor,\n    residual: Optional[torch.Tensor],\n    rmsnorm_layer: RMSNorm,\n    input_global_scale: torch.Tensor,\n    quant_out: torch.Tensor,\n    output_scale: torch.Tensor,\n    norm_out: Optional[torch.Tensor] = None,\n):\n    \"\"\"Compiled version of standard allreduce + rmsnorm + FP4 quantization.\"\"\"\n    return standard_allreduce_rmsnorm_fp4_quant_native(\n        input_tensor,\n        residual,\n        rmsnorm_layer,\n        input_global_scale,\n        quant_out,\n        output_scale,\n        norm_out,\n    )\n\n\ndef create_test_tensors(\n    seq_len: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True\n):\n    \"\"\"Create test tensors for benchmarking.\"\"\"\n    input_tensor = torch.randn(seq_len, hidden_dim, dtype=dtype)\n    residual = (\n        torch.randn_like(input_tensor)\n        if use_residual\n        else torch.zeros_like(input_tensor)\n    )\n    rms_gamma = torch.ones(hidden_dim, dtype=dtype)\n    norm_out = None if use_residual else torch.empty_like(input_tensor)\n\n    # Quantization scales\n    scale_fp8 = torch.tensor(1.0, dtype=torch.float32)\n    scale_fp4 = torch.tensor(1.0, dtype=torch.float32)\n    quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE)\n    # Pre-allocate FP4 
output tensors (to avoid allocation overhead in benchmarks)\n    fp4_quant_out = torch.empty((seq_len, hidden_dim // 2), dtype=torch.uint8)\n    fp4_output_scale = torch.empty((128, 4), dtype=torch.int32)\n\n    return (\n        input_tensor,\n        norm_out,\n        residual,\n        rms_gamma,\n        scale_fp8,\n        quant_out_fp8,\n        scale_fp4,\n        fp4_quant_out,\n        fp4_output_scale,\n    )\n\n\ndef benchmark_operation(\n    operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs\n):\n    \"\"\"Benchmark a single operation using CUDA graphs.\"\"\"\n    # Warmup before graph capture\n    for _ in range(warmup):\n        operation_func(*args, **kwargs)\n    torch.cuda.synchronize()\n\n    # Create CUDA graph\n    graph = torch.cuda.CUDAGraph()\n    num_op_per_cudagraph = 10\n\n    # Use sglang's graph_capture to make tensor_model_parallel_all_reduce graph-safe\n    with graph_capture() as graph_capture_context:\n        with torch.cuda.graph(graph, stream=graph_capture_context.stream):\n            for _ in range(num_op_per_cudagraph):\n                operation_func(*args, **kwargs)\n\n    # Graph warmup\n    torch.cuda.synchronize()\n    for _ in range(warmup):\n        graph.replay()\n\n    # Benchmark with CUDA graph; each replay executes num_op_per_cudagraph ops\n    num_replays = max(1, trials // num_op_per_cudagraph)\n    torch.cuda.synchronize()\n    start_time = time.perf_counter()\n\n    for _ in range(num_replays):\n        graph.replay()\n\n    torch.cuda.synchronize()\n    end_time = time.perf_counter()\n\n    # Normalize by the ops actually executed (trials may not be a multiple of 10)\n    num_ops = num_replays * num_op_per_cudagraph\n    avg_time_ms = ((end_time - start_time) / num_ops) * 1000\n    return avg_time_ms\n\n\ndef run_benchmarks(\n    seq_len: int,\n    hidden_dim: int,\n    dtype: torch.dtype,\n    use_residual: bool,\n    allreduce_params: Optional[FlashInferFusedAllReduceParams],\n    quant_mode: str = \"all\",\n    disable_oneshot: bool = False,\n):\n    \"\"\"Run all benchmarks for the given configuration.\n\n    Args:\n        quant_mode: \"none\", \"fp8_only\", 
\"fp4_only\", or \"all\"\n    \"\"\"\n    (\n        input_tensor,\n        norm_out,\n        residual,\n        rms_gamma,\n        scale_fp8,\n        quant_out_fp8,\n        scale_fp4,\n        fp4_quant_out,\n        fp4_output_scale,\n    ) = create_test_tensors(seq_len, hidden_dim, dtype, use_residual)\n\n    rms_eps = 1e-6\n    results = {}\n\n    # Create RMSNorm once for native benchmarks\n    rmsnorm_layer = RMSNorm(hidden_dim, eps=rms_eps)\n    rmsnorm_layer.weight.data = rms_gamma\n\n    if quant_mode in [\"all\", \"none\"]:\n        # Standard AllReduce + RMSNorm\n        try:\n            time_ms = benchmark_operation(\n                standard_allreduce_rmsnorm,\n                input_tensor,\n                norm_out=norm_out,\n                residual=residual,\n                rms_gamma=rms_gamma,\n                rms_eps=rms_eps,\n            )\n            results[\"standard_allreduce_rmsnorm\"] = time_ms\n        except Exception as e:\n            logger.error(\"Standard AllReduce+RMSNorm failed: %s\", e)\n            results[\"standard_allreduce_rmsnorm\"] = float(\"inf\")\n\n        # Standard AllReduce + RMSNorm Native Compiled\n        try:\n            time_ms = benchmark_operation(\n                standard_allreduce_rmsnorm_native_compiled,\n                input_tensor,\n                residual=residual,\n                rmsnorm_layer=rmsnorm_layer,\n                norm_out=norm_out,\n            )\n            results[\"standard_allreduce_rmsnorm_native_compiled\"] = time_ms\n        except Exception as e:\n            logger.error(\"Standard AllReduce+RMSNorm Native Compiled failed: %s\", e)\n            results[\"standard_allreduce_rmsnorm_native_compiled\"] = float(\"inf\")\n\n        # FlashInfer Fused AllReduce + RMSNorm Oneshot\n        if flashinfer_comm is not None and allreduce_params is not None:\n            try:\n                if not disable_oneshot:\n                    time_ms = benchmark_operation(\n                
        flashinfer_fused_allreduce_rmsnorm,\n                        input_tensor,\n                        residual=residual,\n                        norm_out=norm_out,\n                        rms_gamma=rms_gamma,\n                        rms_eps=rms_eps,\n                        allreduce_params=allreduce_params,\n                        use_oneshot=True,\n                    )\n                    results[\"flashinfer_fused_allreduce_rmsnorm_oneshot\"] = time_ms\n            except Exception as e:\n                logger.error(\"FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s\", e)\n                results[\"flashinfer_fused_allreduce_rmsnorm_oneshot\"] = float(\"inf\")\n\n            # FlashInfer Fused AllReduce + RMSNorm Two-shot\n            try:\n                time_ms = benchmark_operation(\n                    flashinfer_fused_allreduce_rmsnorm,\n                    input_tensor,\n                    residual=residual,\n                    norm_out=norm_out,\n                    rms_gamma=rms_gamma,\n                    rms_eps=rms_eps,\n                    allreduce_params=allreduce_params,\n                    use_oneshot=False,\n                )\n                results[\"flashinfer_fused_allreduce_rmsnorm_twoshot\"] = time_ms\n            except Exception as e:\n                logger.error(\n                    \"FlashInfer Fused AllReduce+RMSNorm Two-shot failed: %s\", e\n                )\n                results[\"flashinfer_fused_allreduce_rmsnorm_twoshot\"] = float(\"inf\")\n\n    if quant_mode in [\"all\", \"fp8_only\"]:\n        # Standard AllReduce + RMSNorm + FP8 Quant\n        try:\n            time_ms = benchmark_operation(\n                standard_allreduce_rmsnorm_fp8_quant,\n                input_tensor,\n                norm_out=norm_out,\n                residual=residual,\n                rms_gamma=rms_gamma,\n                rms_eps=rms_eps,\n                scale_factor=scale_fp8,\n                
quant_out=quant_out_fp8,\n            )\n            results[\"standard_allreduce_rmsnorm_fp8_quant\"] = time_ms\n        except Exception as e:\n            logger.error(\"Standard AllReduce+RMSNorm+FP8 failed: %s\", e)\n            results[\"standard_allreduce_rmsnorm_fp8_quant\"] = float(\"inf\")\n\n        # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled\n        try:\n            time_ms = benchmark_operation(\n                standard_allreduce_rmsnorm_fp8_quant_native_compiled,\n                input_tensor,\n                residual=residual,\n                rmsnorm_layer=rmsnorm_layer,\n                # quant_fp8_layer removed in sglang version; static_quant_fp8 is used within the function\n                scale_factor=scale_fp8,\n                norm_out=norm_out,\n                quant_out=quant_out_fp8,\n            )\n            results[\"standard_allreduce_rmsnorm_fp8_quant_native_compiled\"] = time_ms\n        except Exception as e:\n            logger.error(\"Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s\", e)\n            results[\"standard_allreduce_rmsnorm_fp8_quant_native_compiled\"] = float(\n                \"inf\"\n            )\n\n        # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot\n        if flashinfer_comm is not None and allreduce_params is not None:\n            try:\n                if not disable_oneshot:\n                    time_ms = benchmark_operation(\n                        flashinfer_fused_allreduce_rmsnorm_fp8_quant,\n                        input_tensor,\n                        norm_out=norm_out,\n                        residual=residual,\n                        rms_gamma=rms_gamma,\n                        rms_eps=rms_eps,\n                        scale_factor=scale_fp8,\n                        quant_out=quant_out_fp8,\n                        allreduce_params=allreduce_params,\n                        use_oneshot=True,\n                    )\n                    
results[\"flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot\"] = (\n                        time_ms\n                    )\n            except Exception as e:\n                logger.error(\n                    \"FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s\",\n                    e,\n                )\n                results[\"flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot\"] = float(\n                    \"inf\"\n                )\n            # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Two-shot\n            try:\n                time_ms = benchmark_operation(\n                    flashinfer_fused_allreduce_rmsnorm_fp8_quant,\n                    input_tensor,\n                    norm_out=norm_out,\n                    residual=residual,\n                    rms_gamma=rms_gamma,\n                    rms_eps=rms_eps,\n                    scale_factor=scale_fp8,\n                    quant_out=quant_out_fp8,\n                    allreduce_params=allreduce_params,\n                    use_oneshot=False,\n                )\n                results[\"flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot\"] = (\n                    time_ms\n                )\n            except Exception as e:\n                logger.error(\n                    \"FlashInfer Fused AllReduce+RMSNorm+FP8 Two-shot failed: %s\",\n                    e,\n                )\n                results[\"flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot\"] = float(\n                    \"inf\"\n                )\n\n    if quant_mode in [\"all\", \"fp4_only\"]:\n        # Standard AllReduce + RMSNorm + FP4 Quant\n        try:\n            time_ms = benchmark_operation(\n                standard_allreduce_rmsnorm_fp4_quant,\n                input_tensor,\n                norm_out=norm_out,\n                residual=residual,\n                rms_gamma=rms_gamma,\n                rms_eps=rms_eps,\n                input_global_scale=scale_fp4,\n                
quant_out=fp4_quant_out,\n                output_scale=fp4_output_scale,\n            )\n            results[\"standard_allreduce_rmsnorm_fp4_quant\"] = time_ms\n        except Exception as e:\n            logger.error(\"Standard AllReduce+RMSNorm+FP4 failed: %s\", e)\n            results[\"standard_allreduce_rmsnorm_fp4_quant\"] = float(\"inf\")\n\n        # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled\n        try:\n            time_ms = benchmark_operation(\n                standard_allreduce_rmsnorm_fp4_quant_native_compiled,\n                input_tensor,\n                residual=residual,\n                rmsnorm_layer=rmsnorm_layer,\n                input_global_scale=scale_fp4,\n                quant_out=fp4_quant_out,\n                output_scale=fp4_output_scale,\n                norm_out=norm_out,\n            )\n            results[\"standard_allreduce_rmsnorm_fp4_quant_native_compiled\"] = time_ms\n        except Exception as e:\n            logger.error(\"Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s\", e)\n            results[\"standard_allreduce_rmsnorm_fp4_quant_native_compiled\"] = float(\n                \"inf\"\n            )\n\n        # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot\n        if flashinfer_comm is not None and allreduce_params is not None:\n            try:\n                if not disable_oneshot:\n                    time_ms = benchmark_operation(\n                        flashinfer_fused_allreduce_rmsnorm_fp4_quant,\n                        input_tensor,\n                        residual=residual,\n                        norm_out=norm_out,\n                        rms_gamma=rms_gamma,\n                        rms_eps=rms_eps,\n                        input_global_scale=scale_fp4,\n                        allreduce_params=allreduce_params,\n                        quant_out=fp4_quant_out,\n                        output_scale=fp4_output_scale,\n                        use_oneshot=True,\n    
                )\n                    results[\"flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot\"] = (\n                        time_ms\n                    )\n            except Exception as e:\n                logger.error(\n                    \"FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s\",\n                    e,\n                )\n                results[\"flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot\"] = float(\n                    \"inf\"\n                )\n\n        # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot\n        if flashinfer_comm is not None and allreduce_params is not None:\n            try:\n                time_ms = benchmark_operation(\n                    flashinfer_fused_allreduce_rmsnorm_fp4_quant,\n                    input_tensor,\n                    residual=residual,\n                    norm_out=norm_out,\n                    rms_gamma=rms_gamma,\n                    rms_eps=rms_eps,\n                    input_global_scale=scale_fp4,\n                    allreduce_params=allreduce_params,\n                    quant_out=fp4_quant_out,\n                    output_scale=fp4_output_scale,\n                    use_oneshot=False,\n                )\n                results[\"flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot\"] = (\n                    time_ms\n                )\n            except Exception as e:\n                logger.error(\n                    \"FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s\",\n                    e,\n                )\n                results[\"flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot\"] = float(\n                    \"inf\"\n                )\n\n    return results\n\n\ndef prepare_results_with_speedups(results_dict):\n    \"\"\"Prepare results with speedup calculations based on dynamic baseline selection.\"\"\"\n    prepared_results = []\n\n    # Determine the fastest baseline for each operation type\n    def 
get_fastest_baseline(op_name, results_dict):\n        \"\"\"Get the fastest baseline between standard and native_compiled versions.\"\"\"\n        if \"fp8_quant\" in op_name:\n            candidates = [\n                \"standard_allreduce_rmsnorm_fp8_quant\",\n                \"standard_allreduce_rmsnorm_fp8_quant_native_compiled\",\n            ]\n        elif \"fp4_quant\" in op_name:\n            candidates = [\n                \"standard_allreduce_rmsnorm_fp4_quant\",\n                \"standard_allreduce_rmsnorm_fp4_quant_native_compiled\",\n            ]\n        else:\n            candidates = [\n                \"standard_allreduce_rmsnorm\",\n                \"standard_allreduce_rmsnorm_native_compiled\",\n            ]\n\n        # Find the fastest among available candidates\n        fastest_time = float(\"inf\")\n        fastest_baseline = None\n\n        for candidate in candidates:\n            if (\n                candidate in results_dict\n                and results_dict[candidate] != float(\"inf\")\n                and results_dict[candidate] < fastest_time\n            ):\n                fastest_time = results_dict[candidate]\n                fastest_baseline = candidate\n\n        return fastest_baseline\n\n    # Create dynamic baseline mapping\n    dynamic_baseline_mapping = {}\n    for op_name in results_dict:\n        if op_name.startswith(\"flashinfer_\") or (\n            op_name.startswith(\"standard_\")\n            and not op_name.endswith(\"_native_compiled\")\n        ):\n            dynamic_baseline_mapping[op_name] = get_fastest_baseline(\n                op_name, results_dict\n            )\n\n    for op_name, time_ms in results_dict.items():\n        if time_ms == float(\"inf\"):\n            speedup_str = \"FAILED\"\n            time_str = \"FAILED\"\n        else:\n            time_str = f\"{time_ms:.3f}\"\n            # Find the appropriate baseline for this operation\n            baseline_op = 
dynamic_baseline_mapping.get(op_name)\n            if baseline_op and baseline_op in results_dict:\n                baseline_time = results_dict[baseline_op]\n                if baseline_time != float(\"inf\") and baseline_time > 0:\n                    speedup = baseline_time / time_ms\n                    speedup_str = f\"{speedup:.2f}x\"\n                else:\n                    speedup_str = \"N/A\"\n            else:\n                # For baseline operations, determine if this is the fastest baseline\n                if op_name.endswith(\"_native_compiled\") or (\n                    op_name.startswith(\"standard_\")\n                    and not op_name.endswith(\"_native_compiled\")\n                ):\n                    fastest_baseline = get_fastest_baseline(op_name, results_dict)\n                    if fastest_baseline == op_name:\n                        speedup_str = \"baseline\"\n                    else:\n                        if fastest_baseline and fastest_baseline in results_dict:\n                            baseline_time = results_dict[fastest_baseline]\n                            if baseline_time != float(\"inf\") and baseline_time > 0:\n                                speedup = baseline_time / time_ms\n                                speedup_str = f\"{speedup:.2f}x\"\n                            else:\n                                speedup_str = \"N/A\"\n                        else:\n                            speedup_str = \"N/A\"\n                else:\n                    speedup_str = \"N/A\"\n\n        prepared_results.append(\n            {\n                \"operation\": op_name,\n                \"time_ms\": time_ms,\n                \"time_str\": time_str,\n                \"speedup_str\": speedup_str,\n            }\n        )\n\n    return prepared_results\n\n\ndef print_results(results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode):\n    \"\"\"Print benchmark results in a formatted table.\"\"\"\n    
print(f\"\\n{'=' * 80}\")\n    print(f\"Results: seq_len={seq_len}, hidden_dim={hidden_dim}\")\n    print(\n        f\"dtype={dtype}, residual={'yes' if use_residual else 'no'}, \"\n        f\"quant_mode={quant_mode}\"\n    )\n    print(f\"{'=' * 80}\")\n    print(f\"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}\")\n    print(f\"{'-' * 80}\")\n\n    # Prepare results with speedup calculations\n    prepared_results = prepare_results_with_speedups(results_dict)\n\n    for result in prepared_results:\n        if result[\"time_ms\"] == float(\"inf\"):\n            time_display = result[\"time_str\"]\n        else:\n            time_display = f\"{result['time_ms']:.3f}\"\n\n        print(\n            f\"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}\"\n        )\n\n\ndef format_results_markdown(\n    all_results: list[dict], world_size: int, args: argparse.Namespace\n) -> str:\n    \"\"\"Format all benchmark results as markdown.\"\"\"\n    markdown = f\"\"\"# FlashInfer Fused Collective Operations Benchmark Results\n\n**World Size:** {world_size}\n**Hidden Dimension:** {args.hidden_dim}\n**Warmup Iterations:** {args.warmup}\n**Benchmark Trials:** {args.trials}\n**Quantization Mode:** {all_results[0][\"quant_mode\"] if all_results else \"N/A\"}\n\n---\n\n\"\"\"\n\n    for result in all_results:\n        seq_len = result[\"seq_len\"]\n        dtype = result[\"dtype\"]\n        use_residual = result[\"use_residual\"]\n        results_dict = result[\"results\"]\n\n        residual_str = \"with residual\" if use_residual else \"no residual\"\n\n        markdown += f\"\"\"\n## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str}\n\n| Operation | Time (ms) | Speedup |\n|-----------|-----------|---------|\n\"\"\"\n\n        # Prepare results with speedup calculations\n        prepared_results = prepare_results_with_speedups(results_dict)\n\n        for result in prepared_results:\n            # Format operation name for better 
readability\n            formatted_op_name = result[\"operation\"].replace(\"_\", \" \").title()\n            markdown += f\"| {formatted_op_name} | {result['time_str']} |\"\n            markdown += f\" {result['speedup_str']} |\\n\"\n\n        markdown += \"\\n\"\n\n    return markdown\n\n\ndef save_results_to_file(\n    all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int\n):\n    \"\"\"Save benchmark results to markdown file (only on rank 0).\"\"\"\n    if rank != 0:\n        return\n\n    if not all_results:\n        logger.warning(\"No results to save\")\n        return\n\n    # Fall back to the timestamped default advertised in the --output-file help text\n    output_path = args.output_file or (\n        f\"benchmark_results_{time.strftime('%Y%m%d_%H%M%S')}.md\"\n    )\n\n    try:\n        markdown_content = format_results_markdown(all_results, world_size, args)\n\n        with open(output_path, \"w\") as f:\n            f.write(markdown_content)\n\n    except Exception as e:\n        logger.error(\"Failed to save results to file: %s\", e)\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Benchmark fused collective operations\"\n    )\n    parser.add_argument(\n        \"--seq-lens\",\n        type=int,\n        nargs=\"+\",\n        default=[128, 512, 1024, 2048],\n        help=\"Sequence lengths to test\",\n    )\n    parser.add_argument(\n        \"--hidden-dim\", type=int, default=8192, help=\"Hidden dimension size\"\n    )\n    parser.add_argument(\n        \"--dtypes\",\n        type=str,\n        nargs=\"+\",\n        default=[\"bfloat16\"],\n        choices=[\"float16\", \"bfloat16\", \"float32\"],\n        help=\"Data types to test\",\n    )\n    parser.add_argument(\n        \"--no-residual\",\n        action=\"store_true\",\n        help=\"Skip residual connection tests\",\n    )\n\n    # Quantization mode options (mutually exclusive with --no-quant)\n    quant_group = parser.add_mutually_exclusive_group()\n    quant_group.add_argument(\n        \"--no-quant\", action=\"store_true\", help=\"Skip all quantization tests\"\n    )\n    quant_group.add_argument(\n 
       \"--quant-fp8\", action=\"store_true\", help=\"Only run FP8 quantization tests\"\n    )\n    quant_group.add_argument(\n        \"--quant-fp4\", action=\"store_true\", help=\"Only run FP4 quantization tests\"\n    )\n    quant_group.add_argument(\n        \"--quant-all\",\n        action=\"store_true\",\n        help=\"Run all quantization tests (default)\",\n    )\n\n    parser.add_argument(\n        \"--disable-oneshot\",\n        action=\"store_true\",\n        help=\"Disable oneshot mode for FlashInfer operations\",\n    )\n    parser.add_argument(\n        \"--warmup\", type=int, default=5, help=\"Number of warmup iterations\"\n    )\n    parser.add_argument(\n        \"--trials\", type=int, default=20, help=\"Number of benchmark trials\"\n    )\n    parser.add_argument(\n        \"--output-file\",\n        type=str,\n        help=\"Markdown output file path (results are not saved if omitted)\",\n    )\n\n    args = parser.parse_args()\n\n    # Check if running with torchrun (required for collective operations)\n    if \"RANK\" not in os.environ or \"WORLD_SIZE\" not in os.environ:\n        raise RuntimeError(\n            \"Must run with torchrun for distributed benchmarking. \"\n            \"Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py\"\n        )\n\n    # Initialize distributed environment\n    rank = int(os.environ[\"RANK\"])\n    world_size = int(os.environ[\"WORLD_SIZE\"])\n\n    device = torch.device(f\"cuda:{rank}\")\n    torch.cuda.set_device(device)\n    torch.set_default_device(device)\n\n    init_distributed_environment(\n        world_size=world_size,\n        rank=rank,\n        local_rank=rank,\n        backend=\"nccl\",\n    )\n    initialize_model_parallel(tensor_model_parallel_size=world_size)\n\n    # Validate world size (must be > 1 for collective operations)\n    if world_size <= 1:\n        raise ValueError(\n            \"World size must be > 1 for collective operations benchmarking. \"\n            f\"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1.\"\n        )\n\n    # Determine quantization mode\n    if args.no_quant:\n        quant_mode = \"none\"\n    elif args.quant_fp8:\n        quant_mode = \"fp8_only\"\n    elif args.quant_fp4:\n        quant_mode = \"fp4_only\"\n    else:  # args.quant_all or default\n        quant_mode = \"all\"\n\n    if rank == 0:\n        logger.info(\"Running benchmark with world_size=%s, rank=%s\", world_size, rank)\n        logger.info(\"Quantization mode: %s\", quant_mode)\n        if flashinfer_comm is not None:\n            oneshot_status = \"enabled\" if not args.disable_oneshot else \"disabled\"\n            logger.info(\n                \"FlashInfer available - will benchmark fused operations (oneshot: %s)\",\n                oneshot_status,\n            )\n        else:\n            logger.info(\n                \"FlashInfer not available - only benchmarking standard operations\"\n            )\n\n    # Convert dtype strings to torch dtypes\n    dtype_map = {\n        \"float16\": torch.float16,\n        \"bfloat16\": torch.bfloat16,\n        \"float32\": torch.float32,\n    }\n    dtypes = [dtype_map[dt] for dt in args.dtypes]\n\n    # Test configurations\n    residual_options = [False] if args.no_residual else [True, False]\n\n    configs = list(itertools.product(args.seq_lens, dtypes, residual_options))\n\n    # Setup FlashInfer workspace if available\n    ipc_handles = None\n    allreduce_params = None\n\n    if flashinfer_comm is not None:\n        # Derive max_num_token from the per-world-size limit and args.hidden_dim\n        max_num_token = _FI_MAX_SIZES.get(world_size) // (\n            args.hidden_dim * world_size * 2\n        )\n\n        ipc_handles, workspace_tensor = setup_flashinfer_workspace(\n            world_size, rank, args.hidden_dim, max_num_token\n        )\n\n        if workspace_tensor is not None:\n            allreduce_params = FlashInferFusedAllReduceParams(\n                rank=rank,\n                world_size=world_size,\n                max_token_num=max_num_token,\n            )\n\n    # Collect all results for markdown export\n    all_results = []\n\n    try:\n        # Run benchmarks\n        for seq_len, dtype, use_residual in configs:\n            if rank == 0:\n                logger.info(\n                    \"\\nTesting:  seq_len=%s, hidden_dim=%s, dtype=%s, residual=%s\",\n                    seq_len,\n                    args.hidden_dim,\n                    dtype,\n                    use_residual,\n                )\n\n            results = run_benchmarks(\n                seq_len,\n                args.hidden_dim,\n                dtype,\n                use_residual,\n                allreduce_params,\n                quant_mode=quant_mode,\n                disable_oneshot=args.disable_oneshot,\n            )\n\n            # Store results for markdown export\n            if rank == 0:\n                all_results.append(\n                    {\n                        \"seq_len\": seq_len,\n                        \"hidden_dim\": args.hidden_dim,\n                        
\"dtype\": str(dtype).replace(\"torch.\", \"\"),\n                        \"use_residual\": use_residual,\n                        \"quant_mode\": quant_mode,\n                        \"results\": results,\n                    }\n                )\n\n                print_results(\n                    results,\n                    seq_len,\n                    args.hidden_dim,\n                    dtype,\n                    use_residual,\n                    quant_mode,\n                )\n\n        # Save results to markdown file\n        if args.output_file and rank == 0:\n            save_results_to_file(all_results, world_size, args, rank)\n\n    finally:\n        # Cleanup\n        if ipc_handles is not None:\n            cleanup_flashinfer_workspace(ipc_handles)\n\n        with contextlib.suppress(Exception):\n            dist.barrier()\n        cleanup_dist_env_and_memory(shutdown_ray=False)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/kernels/fused_moe_triton/README.md",
    "content": "## Tuning Triton MoE Kernels\n\nThis directory contains benchmarking tools for MoE (Mixture of Experts) kernels.\n\n### Overview\n\nThe tuning tools support both **Tensor Parallelism (TP)** and **Expert Parallelism (EP)** modes:\n\n- **TP Mode**: Traditional tensor parallelism where intermediate layers are sharded across GPUs\n- **EP Mode**: Expert parallelism where experts are distributed across GPUs. Can be combined with TP mode (e.g., `--tp-size 8 --ep-size 2`)\n- **MLLM Support**: Multi-modal Large Language Models with text encoders (e.g., Llama4, Qwen3VL)\n\n### Tuning Tools\n\n#### 1. `tuning_fused_moe_triton.py`\nA unified tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with support for EP mode and various model architectures.\n\n#### 2. `tuning_fused_moe_triton_sep.py`\nA specialized tool for separate kernel tuning, optimizing the first and second MoE kernels independently with TMA (Tensor Memory Accelerator) support.\n\n### Usage Examples\n\n#### Basic TP Mode Tuning\n```bash\n# Tune Mixtral-8x7B with default TP settings\npython benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \\\n    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \\\n    --tune\n\n# Tune Qwen2-57B with FP8 and TP=4\npython benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \\\n    --model Qwen/Qwen2-57B-A14B-Instruct \\\n    --tp-size 4 \\\n    --dtype fp8_w8a8 \\\n    --tune\n\n# Tune DeepSeek-V3 with FP8 and TP=8\npython benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \\\n    --model deepseek-ai/DeepSeek-V3-0324 \\\n    --tp-size 8 \\\n    --dtype fp8_w8a8 \\\n    --tune\n```\n\n#### EP Mode Tuning (Expert Parallelism)\n**Note**: EP mode can be used alone or combined with TP mode. 
When using both, ensure `tp_size` is divisible by `ep_size`.\n\n```bash\n# Tune Mixtral-8x7B with EP=2 only\npython benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \\\n    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \\\n    --tp-size 2 \\\n    --ep-size 2 \\\n    --tune\n\n# Tune Qwen2-57B with TP=8 and EP=4 (combined mode)\npython benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \\\n    --model Qwen/Qwen2-57B-A14B-Instruct \\\n    --tp-size 8 \\\n    --ep-size 4 \\\n    --dtype fp8_w8a8 \\\n    --tune\n```\n\n#### MLLM Model Tuning (Multi-modal)\n```bash\npython benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \\\n    --model Qwen/Qwen3-VL-30B-A3B-Instruct \\\n    --tp-size 2 \\\n    --tune\n```\n\n#### Separate Kernel Tuning with `tuning_fused_moe_triton_sep.py`\n\nThis tool requires pre-generated `topk_ids` files and supports both TP and EP modes.\n\nTo generate these files, edit the model file (such as `srt/models/deepseek_v2.py`) in the installed Python site package and add logic that saves `topk_ids`:\n\n```python\n# import get_tensor_model_parallel_rank\n# DeepseekV2MoE::forward_normal\nif hidden_states.shape[0] >= 4096 and get_tensor_model_parallel_rank() == 0:\n    topk_ids_dir = xxxx\n    if not hasattr(self, \"save_idx\"):\n        self.save_idx = 0\n    if self.save_idx <= 1:\n        torch.save(topk_output.topk_ids, f\"{topk_ids_dir}/topk_ids_layer{self.layer_id}_idx{self.save_idx}.pt\")\n    self.save_idx += 1\n```\n\nThen launch the SGLang server and send requests using `benchmark/kernels/fused_moe_triton/tuning_client.py`:\n```bash\npython benchmark/kernels/fused_moe_triton/tuning_client.py --port 8000\n```\n\nFinally, run the tuner with `--topk-ids-dir` pointing to the saved files:\n\n```bash\n# TP Mode: Tune separate kernels with TP=4\npython benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \\\n    --model Qwen/Qwen2-57B-A14B-Instruct \\\n    --tp-size 4 \\\n    --topk-ids-dir /path/to/topk_ids \\\n    --tune\n\n# EP Mode: Tune separate kernels with TP=4 and EP=2\npython benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \\\n    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \\\n    --tp-size 4 \\\n    --ep-size 2 \\\n    --topk-ids-dir /path/to/topk_ids \\\n    --tune\n\n# TP + EP: Tune DeepSeek-V3 with separate kernels, TP=8 and EP=4\npython benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \\\n    --model deepseek-ai/DeepSeek-V3-0324 \\\n    --tp-size 8 \\\n    --ep-size 4 \\\n    --dtype fp8_w8a8 \\\n    --topk-ids-dir /path/to/topk_ids \\\n    --tune\n\n# Benchmark a specific config without tuning\npython benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \\\n    --model deepseek-ai/DeepSeek-V3-0324 \\\n    --tp-size 4 \\\n    --batch-size 1024 \\\n    --dtype fp8_w8a8 \\\n    --configs 128 256 128 16 8 4 \\\n    --topk-ids-dir /path/to/topk_ids\n```\n\n#### Advanced Options\n```bash\n# Channel-wise quantization\npython benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \\\n    --model meituan/DeepSeek-R1-Channel-INT8 \\\n    --tp-size 16 \\\n    --dtype int8_w8a8 \\\n    --per-channel-quant \\\n    --tune\n\n# Specific batch size tuning\npython benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \\\n    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \\\n    --batch-size 2048 \\\n    --tune\n```\n\n### Configuration Files\n\nAfter tuning, configuration files such as the following are generated:\n- **Standard tuning**: `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`\n- **Separate kernel tuning**: Two files, one each for the up and down kernels, with TMA optimization flags\n\nMove these files to the `sglang/srt/layers/moe/fused_moe_triton/configs/triton_version/` directory to use them in SGLang.\n\n### Supported Models\n\n- **Mixtral**: mistralai/Mixtral-8x7B-Instruct-v0.1, mixtral-8x22b\n- **Qwen**: Qwen2-57B, Qwen3-235B, Qwen3VL (MLLM)\n- **DeepSeek**: DeepSeek-V2, DeepSeek-V3, DeepSeek-R1\n- **Llama**: Llama4-Vision (MLLM)\n- **DBRX**: databricks/dbrx-instruct\n- **Jamba**: ai21labs/AI21-Jamba\n- **Grok**: xai-org/grok-1\n- **GLM**: THUDM/glm-4-9b-chat\n- **Bailing**: Custom MoE models\n\n### Parameters Reference\n\n- `--model`: Hugging Face model name or local path\n- `--tp-size`: Tensor parallelism size (default: 2)\n- `--ep-size`: Expert parallelism size (default: 1); can be combined with TP mode, in which case `tp_size` must be divisible by `ep_size`\n- `--dtype`: Data type (`auto`, `fp8_w8a8`, `int8_w8a16`, `int8_w8a8`)\n- `--batch-size`: Specific batch size for tuning (optional)\n- `--tune`: Enable tuning mode\n- `--per-channel-quant`: Enable per-channel quantization\n- `--disable-shared-experts-fusion`: Disable shared expert fusion for some models\n- `--topk-ids-dir`: Directory containing pre-generated topk_ids (for the sep tool only)\n- `--configs`: Manual config specification [BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, warps, stages]\n\n### Performance Comparison Tool\n\n- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of the vLLM and SGLang fused MoE kernel implementations. It supports various model architectures and data types.\n\nExample usage:\n```bash\n# Compare with default settings (Mixtral model)\npython benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py\n\n# Compare with FP8 mode for Qwen2-57B\npython benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \\\n    --model Qwen/Qwen2-57B-A14B-Instruct \\\n    --use-fp8-w8a8\n\n# Compare with custom TP size\npython benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \\\n    --model deepseek-ai/DeepSeek-V3-0324 \\\n    --tp-size 8\n```\n\nThe benchmark results are saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).\n\n- `benchmark_torch_compile_fused_moe.py`: A tool for comparing a `torch.compile`-based fused MoE implementation against the original fused MoE kernel.\n\nUsage is similar to `benchmark_vllm_vs_sglang_fused_moe_triton.py`. Note that the `torch.compile` path does not support the `fp8_w8a8` and `int8_w8a8` variants of `fused_moe_kernel`. Both tools now support EP mode with the `--ep-size` parameter.\n"
  },
  {
    "path": "benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py",
    "content": "# python3 benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8\nimport argparse\n\nimport torch\nimport triton\nfrom common_utils import get_model_config\n\nfrom sglang.benchmark.bench_utils import run_bench\nfrom sglang.srt.distributed.parallel_state import (\n    destroy_distributed_environment,\n    destroy_model_parallel,\n    init_distributed_environment,\n    initialize_model_parallel,\n)\nfrom sglang.srt.layers.moe.fused_moe_triton.fused_moe import (\n    fused_moe as fused_moe_sglang,\n)\nfrom sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (\n    triton_kernel_moe_forward,\n)\nfrom sglang.srt.layers.moe.moe_runner import MoeRunnerConfig\nfrom sglang.srt.layers.moe.topk import (\n    TopK,\n    TopKConfig,\n    TopKOutputFormat,\n    select_experts,\n)\nfrom sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler\n\n\ndef fused_moe_triton_api(\n    x,\n    w1,\n    w2,\n    input_gating,\n    topk,\n):\n    topk_op = TopK(\n        top_k=topk,\n        renormalize=False,\n        use_grouped_topk=False,\n        output_format=TopKOutputFormat.TRITON_KERNEL,\n    )\n    triton_topk_output = topk_op.forward_cuda(\n        hidden_states=x,\n        router_logits=input_gating,\n    )\n\n    moe_runner_config = MoeRunnerConfig(\n        inplace=False,\n    )\n    return triton_kernel_moe_forward(\n        x,\n        w1,\n        w2,\n        triton_topk_output,\n        moe_runner_config,\n    )\n\n\ndef fused_moe_sglang_api(\n    x,\n    w1,\n    w2,\n    input_gating,\n    topk,\n    use_fp8_w8a8=False,\n    w1_scale=None,\n    w2_scale=None,\n    a1_scale=None,\n    a2_scale=None,\n    block_shape=None,\n):\n    topk_output = select_experts(\n        hidden_states=x,\n        router_logits=input_gating,\n        topk_config=TopKConfig(top_k=topk, renormalize=False),\n    )\n    return fused_moe_sglang(\n        x,\n        w1,\n        w2,\n        
topk_output,\n        use_fp8_w8a8=use_fp8_w8a8,\n        w1_scale=w1_scale,\n        w2_scale=w2_scale,\n        a1_scale=a1_scale,\n        a2_scale=a2_scale,\n        block_shape=block_shape,\n    )\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"batch_size\"],\n        x_vals=list([128, 256, 512, 1024, 2048, 4096, 8192]),\n        line_arg=\"provider\",\n        line_vals=[\n            \"sglang_fused_moe_triton_v340\",\n            \"sglang_fused_moe_triton\",\n        ],\n        line_names=[\n            \"sglang_fused_moe_triton_v340\",\n            \"sglang_fused_moe_triton\",\n        ],\n        styles=[\n            (\"blue\", \"-\"),\n            (\"green\", \"-\"),\n        ],\n        ylabel=\"Time (ms)\",\n        plot_name=\"fused-moe-performance\",\n        args={},\n    )\n)\ndef benchmark(\n    batch_size,\n    provider,\n    model_config,\n    use_fp8_w8a8=False,\n    use_cuda_graph: bool = False,\n):\n    print(f\"benchmark {provider} with batch_size={batch_size}\")\n    torch.set_default_device(\"cuda\")\n    torch.cuda.manual_seed_all(0)\n\n    num_tokens = batch_size\n    num_experts = model_config[\"num_experts\"]\n    hidden_size = model_config[\"hidden_size\"]\n    shard_intermediate_size = model_config[\"shard_intermediate_size\"]\n    topk = model_config[\"topk\"]\n    dtype = model_config[\"dtype\"]\n    block_shape = model_config[\"block_shape\"]\n\n    x = torch.randn(num_tokens, hidden_size, dtype=dtype)\n\n    w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)\n    w2 = torch.randn(\n        num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype\n    )\n\n    w1_tri = w1.clone()\n    w2_tri = w2.clone()\n    w1_tri = w1_tri.transpose(-2, -1).contiguous()\n    w2_tri = w2_tri.transpose(-2, -1).contiguous()\n\n    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)\n\n    if provider == \"sglang_fused_moe_triton_v340\":\n        
api_func = fused_moe_triton_api\n        api_kwargs = {\n            \"x\": x,\n            \"w1\": w1_tri,\n            \"w2\": w2_tri,\n            \"input_gating\": input_gating,\n            \"topk\": topk,\n        }\n    else:\n        api_func = fused_moe_sglang_api\n        api_kwargs = {\n            \"x\": x,\n            \"w1\": w1,\n            \"w2\": w2,\n            \"input_gating\": input_gating,\n            \"topk\": topk,\n            \"use_fp8_w8a8\": use_fp8_w8a8,\n            \"block_shape\": block_shape,\n        }\n\n    # Warmup\n    for _ in range(10):\n        _ = api_func(**api_kwargs)\n    torch.cuda.synchronize()\n\n    if use_cuda_graph:\n        stream = torch.cuda.Stream()\n        graph = torch.cuda.CUDAGraph()\n        with torch.cuda.graph(graph, stream=stream):\n            api_func(**api_kwargs)\n        torch.cuda.synchronize()\n\n        bench_lambda = lambda: graph.replay()\n    else:\n        bench_lambda = lambda: api_func(**api_kwargs)\n\n    quantiles = (0.5, 0.2, 0.8)\n    ms, min_ms, max_ms = run_bench(bench_lambda, quantiles=quantiles)\n    return ms, min_ms, max_ms\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--model\", type=str, default=\"mistralai/Mixtral-8x7B-Instruct-v0.1\"\n    )\n    parser.add_argument(\"--tp-size\", \"--tp\", type=int, default=2)\n    parser.add_argument(\"--ep-size\", \"--ep\", type=int, default=1)\n    parser.add_argument(\"--use-fp8-w8a8\", action=\"store_true\")\n    parser.add_argument(\n        \"--use-cuda-graph\", action=\"store_true\", help=\"Enable CUDA Graph capture/replay\"\n    )\n    parser.add_argument(\n        \"--save-path\",\n        type=str,\n        default=\"./configs/benchmark_ops/sglang_fused_moe/\",\n    )\n    parser.add_argument(\"--trust-remote-code\", action=\"store_true\")\n    args = parser.parse_args()\n\n    # Initialize global server args (required by SGLang MoE kernels)\n    server_args = 
ServerArgs(model_path=args.model)\n    set_global_server_args_for_scheduler(server_args)\n\n    try:\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(\n                backend=\"nccl\" if torch.cuda.is_available() else \"gloo\",\n                init_method=\"tcp://127.0.0.1:23456\",\n                world_size=1,\n                rank=0,\n            )\n\n        init_distributed_environment(\n            world_size=1,\n            rank=0,\n            distributed_init_method=\"tcp://127.0.0.1:23456\",\n            local_rank=0,\n            backend=\"nccl\" if torch.cuda.is_available() else \"gloo\",\n        )\n\n        initialize_model_parallel(\n            tensor_model_parallel_size=1,\n            expert_model_parallel_size=1,\n        )\n\n        model_config = get_model_config(args.model, args.tp_size, args.ep_size)\n        benchmark.run(\n            show_plots=True,\n            print_data=True,\n            save_path=args.save_path,\n            model_config=model_config,\n            use_fp8_w8a8=args.use_fp8_w8a8,\n            use_cuda_graph=args.use_cuda_graph,\n        )\n    finally:\n        destroy_model_parallel()\n        destroy_distributed_environment()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py",
    "content": "# python3 benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8\nimport argparse\n\nimport torch\nimport triton\nfrom torch.nn import functional as F\nfrom transformers import AutoConfig\n\nfrom sglang.benchmark.bench_utils import run_bench\nfrom sglang.srt.layers.moe.fused_moe_triton.fused_moe import (\n    fused_moe as fused_moe_triton,\n)\nfrom sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config\n\n\ndef get_model_config(model_name: str, tp_size: int):\n    \"\"\"Get model configuration parameters\"\"\"\n    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n\n    if config.architectures[0] == \"DbrxForCausalLM\":\n        E = config.ffn_config.moe_num_experts\n        topk = config.ffn_config.moe_top_k\n        intermediate_size = config.ffn_config.ffn_hidden_size\n        shard_intermediate_size = 2 * intermediate_size // tp_size\n    elif config.architectures[0] == \"JambaForCausalLM\":\n        E = config.num_experts\n        topk = config.num_experts_per_tok\n        intermediate_size = config.intermediate_size\n        shard_intermediate_size = 2 * intermediate_size // tp_size\n    elif config.architectures[0] == \"Qwen2MoeForCausalLM\":\n        E = config.num_experts\n        topk = config.num_experts_per_tok\n        intermediate_size = config.moe_intermediate_size\n        shard_intermediate_size = 2 * intermediate_size // tp_size\n    elif config.architectures[0] == \"Qwen3MoeForCausalLM\":\n        E = config.num_experts\n        topk = config.num_experts_per_tok\n        intermediate_size = config.moe_intermediate_size\n        shard_intermediate_size = 2 * intermediate_size // tp_size\n    elif config.architectures[0] in [\"DeepseekV2ForCausalLM\", \"DeepseekV3ForCausalLM\"]:\n        E = config.n_routed_experts\n        topk = config.num_experts_per_tok\n        intermediate_size = config.moe_intermediate_size\n        shard_intermediate_size = 2 * intermediate_size // tp_size\n    elif config.architectures[0] == \"Llama4ForConditionalGeneration\":\n        E = config.text_config.num_local_experts\n        topk = config.text_config.num_experts_per_tok\n        intermediate_size = config.text_config.intermediate_size\n        shard_intermediate_size = 2 * intermediate_size // tp_size\n    elif config.architectures[0] in [\n        \"Grok1ForCausalLM\",\n        \"Grok1ImgGen\",\n        \"Grok1AForCausalLM\",\n    ]:\n        E = config.num_local_experts\n        topk = config.num_experts_per_tok\n        intermediate_size = config.moe_intermediate_size\n        shard_intermediate_size = 2 * intermediate_size // tp_size\n    else:\n        # Default: Mixtral\n        E = config.num_local_experts\n        topk = config.num_experts_per_tok\n        intermediate_size = config.intermediate_size\n        shard_intermediate_size = 2 * intermediate_size // tp_size\n\n    shape_configs = {\n        \"num_experts\": E,\n        \"topk\": topk,\n        \"hidden_size\": config.hidden_size,\n        \"shard_intermediate_size\": shard_intermediate_size,\n        \"dtype\": config.torch_dtype,\n    }\n    print(f\"{shape_configs=}\")\n    return shape_configs\n\n\ndef fused_topk_native(\n    hidden_states: torch.Tensor,\n    gating_output: torch.Tensor,\n    topk: int,\n    renormalize: bool,\n):\n    assert hidden_states.shape[0] == gating_output.shape[0], \"Number of tokens mismatch\"\n    # Softmax over experts, then keep the top-k weights and their expert ids\n    topk_weights = F.softmax(gating_output.float(), dim=-1)\n    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)\n    if renormalize:\n        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)\n    return topk_weights, 
topk_ids\n\n\n@torch.compile(dynamic=False)\ndef fused_moe_torch(\n    x,\n    w1,\n    w2,\n    input_gating,\n    topk,\n    use_fp8_w8a8=False,\n    w1_scale=None,\n    w2_scale=None,\n    a1_scale=None,\n    a2_scale=None,\n) -> torch.Tensor:\n    assert not use_fp8_w8a8, \"Fp8_w8a8 fused_moe is not supported for torch compile\"\n\n    topk_weights, topk_ids = fused_topk_native(\n        hidden_states=x,\n        gating_output=input_gating,\n        topk=topk,\n        renormalize=True,\n    )\n    w13_weights = w1[topk_ids]\n    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)\n    w2_weights = w2[topk_ids]\n    x1 = torch.einsum(\"ti,taoi -> tao\", x, w1_weights)\n    x1 = F.silu(x1)\n    x3 = torch.einsum(\"ti, taoi -> tao\", x, w3_weights)\n    expert_outs = torch.einsum(\"tao, taio -> tai\", (x1 * x3), w2_weights)\n    return torch.einsum(\"tai,ta -> ti\", expert_outs, topk_weights.to(expert_outs.dtype))\n\n\ndef fused_moe_torch_compile(\n    x,\n    w1,\n    w2,\n    input_gating,\n    topk,\n    use_fp8_w8a8=False,\n    w1_scale=None,\n    w2_scale=None,\n    a1_scale=None,\n    a2_scale=None,\n):\n    return fused_moe_torch(\n        x,\n        w1,\n        w2,\n        input_gating,\n        topk,\n        use_fp8_w8a8=use_fp8_w8a8,\n        w1_scale=w1_scale,\n        w2_scale=w2_scale,\n        a1_scale=a1_scale,\n        a2_scale=a2_scale,\n    )\n\n\ndef fused_moe_sglang_api(\n    x,\n    w1,\n    w2,\n    input_gating,\n    topk,\n    use_fp8_w8a8=False,\n    w1_scale=None,\n    w2_scale=None,\n    a1_scale=None,\n    a2_scale=None,\n):\n    return fused_moe_triton(\n        x,\n        w1,\n        w2,\n        input_gating,\n        topk,\n        renormalize=True,\n        inplace=True,\n        use_fp8_w8a8=use_fp8_w8a8,\n        w1_scale=w1_scale,\n        w2_scale=w2_scale,\n        a1_scale=a1_scale,\n        a2_scale=a2_scale,\n    )\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        
x_names=[\"batch_size\"],\n        x_vals=list(range(1, 5)),\n        line_arg=\"provider\",\n        line_vals=[\n            \"fused_moe_triton\",\n            \"fused_moe_torch_compile\",\n        ],\n        line_names=[\n            \"fused_moe_triton\",\n            \"fused_moe_torch_compile\",\n        ],\n        styles=[\n            (\"blue\", \"-\"),\n            (\"green\", \"-\"),\n        ],\n        ylabel=\"Time (ms)\",\n        plot_name=\"fused-moe-performance\",\n        args={},\n    )\n)\ndef benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):\n    print(f\"benchmark {provider} with batch_size={batch_size}\")\n    torch.set_default_device(\"cuda\")\n    torch.cuda.manual_seed_all(0)\n    set_torch_compile_config()\n\n    num_tokens = batch_size\n    num_experts = model_config[\"num_experts\"]\n    hidden_size = model_config[\"hidden_size\"]\n    shard_intermediate_size = model_config[\"shard_intermediate_size\"]\n    topk = model_config[\"topk\"]\n    dtype = model_config[\"dtype\"]\n\n    x = torch.randn(num_tokens, hidden_size, dtype=dtype)\n\n    if use_fp8_w8a8:\n        init_dtype = dtype\n        w1 = torch.randn(\n            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype\n        )\n        w2 = torch.randn(\n            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype\n        )\n        w1 = w1.to(torch.float8_e4m3fn)\n        w2 = w2.to(torch.float8_e4m3fn)\n        w1_scale = torch.randn(num_experts, dtype=torch.float32)\n        w2_scale = torch.randn(num_experts, dtype=torch.float32)\n        a1_scale = torch.randn(1, dtype=torch.float32)\n        a2_scale = torch.randn(1, dtype=torch.float32)\n    else:\n        w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)\n        w2 = torch.randn(\n            num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype\n        )\n        w1_scale = w2_scale = a1_scale = a2_scale = None\n\n   
 input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)\n\n    # Warmup\n    api_func = (\n        fused_moe_torch_compile\n        if provider == \"fused_moe_torch_compile\"\n        else fused_moe_sglang_api\n    )\n    for _ in range(10):\n        y = api_func(\n            x,\n            w1,\n            w2,\n            input_gating,\n            topk,\n            use_fp8_w8a8=use_fp8_w8a8,\n            w1_scale=w1_scale,\n            w2_scale=w2_scale,\n            a1_scale=a1_scale,\n            a2_scale=a2_scale,\n        )\n    torch.cuda.synchronize()\n\n    quantiles = (0.5, 0.2, 0.8)\n    ms, min_ms, max_ms = run_bench(\n        lambda: api_func(\n            x,\n            w1,\n            w2,\n            input_gating,\n            topk,\n            use_fp8_w8a8=use_fp8_w8a8,\n            w1_scale=w1_scale,\n            w2_scale=w2_scale,\n            a1_scale=a1_scale,\n            a2_scale=a2_scale,\n        )[0],\n        quantiles=quantiles,\n    )\n    return ms, min_ms, max_ms\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--model\", type=str, default=\"mistralai/Mixtral-8x7B-Instruct-v0.1\"\n    )\n    parser.add_argument(\"--tp-size\", type=int, default=2)\n    parser.add_argument(\"--use-fp8-w8a8\", action=\"store_true\")\n    parser.add_argument(\n        \"--save-path\",\n        type=str,\n        default=\"./configs/benchmark_ops/fused_moe_torch_compile/\",\n    )\n    args = parser.parse_args()\n\n    model_config = get_model_config(args.model, args.tp_size)\n    benchmark.run(\n        show_plots=True,\n        print_data=True,\n        save_path=args.save_path,\n        model_config=model_config,\n        use_fp8_w8a8=args.use_fp8_w8a8,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py",
    "content": "# python3 benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8\nimport argparse\n\nimport torch\nimport triton\nfrom common_utils import get_model_config\nfrom vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm\n\nfrom sglang.benchmark.bench_utils import run_bench\nfrom sglang.srt.distributed.parallel_state import (\n    destroy_distributed_environment,\n    destroy_model_parallel,\n    init_distributed_environment,\n    initialize_model_parallel,\n)\nfrom sglang.srt.layers.moe.fused_moe_triton.fused_moe import (\n    fused_moe as fused_moe_sglang,\n)\n\n\ndef fused_moe_vllm_api(\n    x,\n    w1,\n    w2,\n    input_gating,\n    topk,\n    use_fp8_w8a8=False,\n    w1_scale=None,\n    w2_scale=None,\n    a1_scale=None,\n    a2_scale=None,\n    block_shape=None,\n):\n    if block_shape is not None:\n        return fused_moe_vllm(\n            x,\n            w1,\n            w2,\n            input_gating,\n            topk,\n            renormalize=True,\n            inplace=True,\n            use_fp8_w8a8=use_fp8_w8a8,\n            w1_scale=w1_scale,\n            w2_scale=w2_scale,\n            a1_scale=a1_scale,\n            a2_scale=a2_scale,\n            block_shape=block_shape,\n        )\n    else:\n        return fused_moe_vllm(\n            x,\n            w1,\n            w2,\n            input_gating,\n            topk,\n            renormalize=True,\n            inplace=True,\n            use_fp8_w8a8=use_fp8_w8a8,\n            w1_scale=w1_scale,\n            w2_scale=w2_scale,\n            a1_scale=a1_scale,\n            a2_scale=a2_scale,\n        )\n\n\ndef fused_moe_sglang_api(\n    x,\n    w1,\n    w2,\n    input_gating,\n    topk,\n    use_fp8_w8a8=False,\n    w1_scale=None,\n    w2_scale=None,\n    a1_scale=None,\n    a2_scale=None,\n    block_shape=None,\n):\n    return fused_moe_sglang(\n        x,\n        w1,\n        
w2,\n        input_gating,\n        topk,\n        renormalize=True,\n        inplace=True,\n        use_fp8_w8a8=use_fp8_w8a8,\n        w1_scale=w1_scale,\n        w2_scale=w2_scale,\n        a1_scale=a1_scale,\n        a2_scale=a2_scale,\n        block_shape=block_shape,\n    )\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"batch_size\"],\n        x_vals=list(range(1, 513)),\n        line_arg=\"provider\",\n        line_vals=[\n            \"vllm_fused_moe_triton\",\n            \"sglang_fused_moe_triton\",\n        ],\n        line_names=[\n            \"vllm_fused_moe_triton\",\n            \"sglang_fused_moe_triton\",\n        ],\n        styles=[\n            (\"blue\", \"-\"),\n            (\"green\", \"-\"),\n        ],\n        ylabel=\"Time (ms)\",\n        plot_name=\"fused-moe-performance\",\n        args={},\n    )\n)\ndef benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):\n    print(f\"benchmark {provider} with batch_size={batch_size}\")\n    torch.set_default_device(\"cuda\")\n    torch.cuda.manual_seed_all(0)\n\n    num_tokens = batch_size\n    num_experts = model_config[\"num_experts\"]\n    hidden_size = model_config[\"hidden_size\"]\n    shard_intermediate_size = model_config[\"shard_intermediate_size\"]\n    topk = model_config[\"topk\"]\n    dtype = model_config[\"dtype\"]\n    block_shape = model_config[\"block_shape\"]\n\n    x = torch.randn(num_tokens, hidden_size, dtype=dtype)\n    w1_scale = w2_scale = a1_scale = a2_scale = None\n\n    if use_fp8_w8a8:\n        init_dtype = dtype\n        w1 = torch.randn(\n            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype\n        )\n        w2 = torch.randn(\n            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype\n        )\n        w1 = w1.to(torch.float8_e4m3fn)\n        w2 = w2.to(torch.float8_e4m3fn)\n\n        if block_shape is None:\n            w1_scale = torch.randn(num_experts, 
dtype=torch.float32)\n            w2_scale = torch.randn(num_experts, dtype=torch.float32)\n            a1_scale = torch.randn(1, dtype=torch.float32)\n            a2_scale = torch.randn(1, dtype=torch.float32)\n        else:\n            block_n, block_k = block_shape[0], block_shape[1]\n            n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n\n            n_tiles_w2 = (hidden_size + block_n - 1) // block_n\n            k_tiles_w1 = (hidden_size + block_k - 1) // block_k\n            k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k\n            w1_scale = torch.rand(\n                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32\n            )\n            w2_scale = torch.rand(\n                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32\n            )\n    else:\n        w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)\n        w2 = torch.randn(\n            num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype\n        )\n\n    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)\n\n    # Warmup\n    api_func = (\n        fused_moe_vllm_api\n        if provider == \"vllm_fused_moe_triton\"\n        else fused_moe_sglang_api\n    )\n    for _ in range(10):\n        y = api_func(\n            x,\n            w1,\n            w2,\n            input_gating,\n            topk,\n            use_fp8_w8a8=use_fp8_w8a8,\n            w1_scale=w1_scale,\n            w2_scale=w2_scale,\n            a1_scale=a1_scale,\n            a2_scale=a2_scale,\n            block_shape=block_shape,\n        )\n    torch.cuda.synchronize()\n\n    quantiles = (0.5, 0.2, 0.8)\n    ms, min_ms, max_ms = run_bench(\n        lambda: api_func(\n            x,\n            w1,\n            w2,\n            input_gating,\n            topk,\n            use_fp8_w8a8=use_fp8_w8a8,\n            w1_scale=w1_scale,\n            w2_scale=w2_scale,\n            
a1_scale=a1_scale,\n            a2_scale=a2_scale,\n            block_shape=block_shape,\n        )[0],\n        quantiles=quantiles,\n    )\n    return ms, min_ms, max_ms\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--model\", type=str, default=\"mistralai/Mixtral-8x7B-Instruct-v0.1\"\n    )\n    parser.add_argument(\"--tp-size\", \"--tp\", type=int, default=2)\n    parser.add_argument(\"--ep-size\", \"--ep\", type=int, default=1)\n    parser.add_argument(\"--use-fp8-w8a8\", action=\"store_true\")\n    parser.add_argument(\n        \"--save-path\",\n        type=str,\n        default=\"./configs/benchmark_ops/vllm_sglang_fused_moe/\",\n    )\n    args = parser.parse_args()\n\n    try:\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(\n                backend=\"nccl\" if torch.cuda.is_available() else \"gloo\",\n                init_method=\"tcp://127.0.0.1:23456\",\n                world_size=1,\n                rank=0,\n            )\n\n        init_distributed_environment(\n            world_size=1,\n            rank=0,\n            distributed_init_method=\"tcp://127.0.0.1:23456\",\n            local_rank=0,\n            backend=\"nccl\" if torch.cuda.is_available() else \"gloo\",\n        )\n\n        initialize_model_parallel(\n            tensor_model_parallel_size=1,\n            pipeline_model_parallel_size=1,\n        )\n\n        shape_configs = get_model_config(args.model, args.tp_size, args.ep_size)\n        benchmark.run(\n            show_plots=True,\n            print_data=True,\n            save_path=args.save_path,\n            model_config=shape_configs,\n            use_fp8_w8a8=args.use_fp8_w8a8,\n        )\n    finally:\n        destroy_model_parallel()\n        destroy_distributed_environment()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/kernels/fused_moe_triton/common_utils.py",
    "content": "import json\nfrom typing import Dict, List, TypedDict\n\nimport torch\n\nfrom sglang.srt.layers.moe.fused_moe_triton.fused_moe import get_config_dtype_str\nfrom sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (\n    get_config_file_name,\n)\nfrom sglang.srt.utils import is_hip\nfrom sglang.srt.utils.hf_transformers_utils import get_config\n\n\nclass BenchmarkConfig(TypedDict):\n    BLOCK_SIZE_M: int\n    BLOCK_SIZE_N: int\n    BLOCK_SIZE_K: int\n    GROUP_SIZE_M: int\n    num_warps: int\n    num_stages: int\n\n\ndef calculate_shard_intermediate_size(\n    intermediate_size: int, tp_size: int, ep_size: int = 1\n) -> int:\n    assert tp_size % ep_size == 0\n    moe_tp_size = tp_size // ep_size\n    assert intermediate_size % moe_tp_size == 0\n    return 2 * intermediate_size // moe_tp_size\n\n\ndef get_model_config(\n    model_name: str,\n    tp_size: int,\n    ep_size: int = 1,\n    disable_shared_experts_fusion: bool = False,\n    topk_ids_dir: str = None,\n) -> Dict:\n    config = get_config(model_name, trust_remote_code=True)\n\n    # Replace config with text_config for encoder-decoder models after getting block_shape and architecture\n    if hasattr(config, \"text_config\"):\n        config = config.get_text_config()\n\n    block_shape = None\n    if (\n        hasattr(config, \"quantization_config\")\n        and \"weight_block_size\" in config.quantization_config\n    ):\n        block_shape = config.quantization_config[\"weight_block_size\"]\n        assert len(block_shape) == 2\n\n    if (\n        hasattr(config, \"quantization_config\")\n        and \"config_groups\" in config.quantization_config\n    ):\n        config_groups = config.quantization_config[\"config_groups\"]\n        # Get group_size from the first group's weights config\n        first_group = next(iter(config_groups.values()), {})\n        weights_config = first_group.get(\"weights\", {})\n        group_size = weights_config.get(\"group_size\")\n       
 block_shape = [0, group_size]\n        assert len(block_shape) == 2\n\n    architecture = config.architectures[0]\n\n    hidden_size = config.hidden_size\n    if architecture == \"DbrxForCausalLM\":\n        E = config.ffn_config.moe_num_experts // ep_size\n        topk = config.ffn_config.moe_top_k\n        intermediate_size = config.ffn_config.ffn_hidden_size\n    elif architecture == \"JambaForCausalLM\":\n        E = config.num_experts // ep_size\n        topk = config.num_experts_per_tok\n        intermediate_size = config.intermediate_size\n    elif architecture in [\n        \"Qwen2MoeForCausalLM\",\n        \"Qwen3MoeForCausalLM\",\n        \"Qwen3NextForCausalLM\",\n        \"Qwen3VLMoeForConditionalGeneration\",\n        \"Qwen3_5MoeForConditionalGeneration\",\n    ]:\n        E = config.num_experts // ep_size\n        topk = config.num_experts_per_tok\n        intermediate_size = config.moe_intermediate_size\n    elif architecture in [\n        \"DeepseekV2ForCausalLM\",\n        \"DeepseekV3ForCausalLM\",\n        \"DeepseekV32ForCausalLM\",\n        \"Glm4MoeForCausalLM\",\n        \"GlmMoeDsaForCausalLM\",\n        \"MistralLarge3ForCausalLM\",\n    ]:\n        E = (config.n_routed_experts // ep_size) + (\n            0\n            if disable_shared_experts_fusion\n            or architecture\n            not in [\n                \"DeepseekV3ForCausalLM\",\n                \"DeepseekV32ForCausalLM\",\n                \"Glm4MoeForCausalLM\",\n                \"GlmMoeDsaForCausalLM\",\n                \"MistralLarge3ForCausalLM\",\n            ]\n            else 1\n        )\n        topk = config.num_experts_per_tok + (\n            0 if disable_shared_experts_fusion or topk_ids_dir is None else 1\n        )\n        intermediate_size = config.moe_intermediate_size\n    elif architecture == \"Llama4ForConditionalGeneration\":\n        E = config.num_local_experts // ep_size + (\n            0 if disable_shared_experts_fusion else 1\n        )\n     
   topk = config.num_experts_per_tok + (\n            0 if disable_shared_experts_fusion or topk_ids_dir is None else 1\n        )\n        intermediate_size = config.intermediate_size\n    elif architecture in [\n        \"Grok1ForCausalLM\",\n        \"Grok1ImgGen\",\n        \"Grok1AForCausalLM\",\n    ]:\n        E = config.num_local_experts // ep_size\n        topk = config.num_experts_per_tok\n        intermediate_size = config.moe_intermediate_size\n    elif architecture in [\n        \"BailingMoEForCausalLM\",\n        \"BailingMoeForCausalLM\",\n        \"BailingMoeV2ForCausalLM\",\n    ]:\n        E = config.num_experts // ep_size\n        topk = config.num_experts_per_tok\n        intermediate_size = config.moe_intermediate_size\n    elif architecture == \"NemotronHForCausalLM\":\n        E = config.n_routed_experts // ep_size\n        topk = config.num_experts_per_tok\n        intermediate_size = config.moe_intermediate_size\n        hidden_size = getattr(config, \"moe_latent_size\", None) or hidden_size\n    else:\n        # Default: Mixtral\n        E = config.num_local_experts // ep_size\n        topk = config.num_experts_per_tok\n        intermediate_size = config.intermediate_size\n\n    shard_intermediate_size = calculate_shard_intermediate_size(\n        intermediate_size, tp_size, ep_size\n    )\n\n    return {\n        \"num_experts\": E,\n        \"topk\": topk,\n        \"hidden_size\": hidden_size,\n        \"shard_intermediate_size\": shard_intermediate_size,\n        \"dtype\": config.torch_dtype,\n        \"block_shape\": block_shape,\n        \"architecture\": architecture,\n    }\n\n\ndef get_rocm_configs_compute_bound() -> List[Dict[str, int]]:\n    configs: List[BenchmarkConfig] = []\n    waves_per_eu_range = 0\n    for num_stages in [2]:\n        for block_m in [32, 64, 128, 256]:\n            for block_k in [32, 64, 128, 256]:\n                for block_n in [16, 32, 64, 128, 256]:\n                    for num_warps in [1, 2, 4, 
8]:\n                        for group_size in [1, 4, 8, 16, 32]:\n                            configs.append(\n                                {\n                                    \"BLOCK_SIZE_M\": block_m,\n                                    \"BLOCK_SIZE_N\": block_n,\n                                    \"BLOCK_SIZE_K\": block_k,\n                                    \"GROUP_SIZE_M\": group_size,\n                                    \"num_warps\": num_warps,\n                                    \"num_stages\": num_stages,\n                                    \"waves_per_eu\": waves_per_eu_range,\n                                }\n                            )\n    return configs\n\n\ndef get_configs_compute_bound() -> List[Dict[str, int]]:\n    configs: List[BenchmarkConfig] = []\n    if is_hip():\n        configs = get_rocm_configs_compute_bound()\n    else:\n        for num_stages in [2, 3, 4, 5]:\n            for block_m in [16, 32, 64, 128, 256]:\n                for block_k in [64, 128, 256]:\n                    for block_n in [32, 64, 128, 256]:\n                        for num_warps in [4, 8]:\n                            for group_size in [1, 16, 32, 64]:\n                                configs.append(\n                                    {\n                                        \"BLOCK_SIZE_M\": block_m,\n                                        \"BLOCK_SIZE_N\": block_n,\n                                        \"BLOCK_SIZE_K\": block_k,\n                                        \"GROUP_SIZE_M\": group_size,\n                                        \"num_warps\": num_warps,\n                                        \"num_stages\": num_stages,\n                                    }\n                                )\n    return configs\n\n\ndef sort_config(config: BenchmarkConfig) -> BenchmarkConfig:\n    return {\n        \"BLOCK_SIZE_M\": config[\"BLOCK_SIZE_M\"],\n        \"BLOCK_SIZE_N\": config[\"BLOCK_SIZE_N\"],\n        \"BLOCK_SIZE_K\": 
config[\"BLOCK_SIZE_K\"],\n        \"GROUP_SIZE_M\": config[\"GROUP_SIZE_M\"],\n        \"num_warps\": config[\"num_warps\"],\n        \"num_stages\": config[\"num_stages\"],\n        **(\n            {\"waves_per_eu\": config[\"waves_per_eu\"]} if \"waves_per_eu\" in config else {}\n        ),\n        **({\"USE_TMA\": config[\"USE_TMA\"]} if \"USE_TMA\" in config else {}),\n    }\n\n\ndef save_configs(\n    configs: Dict[int, BenchmarkConfig],\n    filename: str,\n) -> None:\n    print(f\"Writing best config to {filename}...\")\n    with open(filename, \"w\") as f:\n        json.dump(configs, f, indent=4)\n        f.write(\"\\n\")\n\n\ndef get_config_filename(\n    num_experts: int,\n    shard_intermediate_size: int,\n    hidden_size: int,\n    topk: int,\n    dtype: torch.dtype,\n    use_fp8_w8a8: bool,\n    use_int8_w8a8: bool,\n    use_int8_w8a16: bool,\n    use_int4_w4a16: bool,\n    per_channel_quant: bool,\n    block_shape: List[int],\n) -> str:\n    dtype_str = get_config_dtype_str(\n        dtype,\n        use_int8_w8a16=use_int8_w8a16,\n        use_fp8_w8a8=use_fp8_w8a8,\n        use_int8_w8a8=use_int8_w8a8,\n        use_int4_w4a16=use_int4_w4a16,\n    )\n\n    # NOTE(woosuk): The current naming convention uses w2.shape[2], which\n    # is the intermediate size after silu_and_mul.\n    N = shard_intermediate_size // 2\n    if use_int4_w4a16:\n        N = N // 2\n\n    filename = get_config_file_name(\n        num_experts,\n        N,\n        dtype_str,\n        block_shape,\n        per_channel_quant,\n    )\n\n    return filename\n\n\ndef get_default_batch_sizes() -> List[int]:\n    return [\n        1,\n        2,\n        4,\n        8,\n        16,\n        24,\n        32,\n        48,\n        64,\n        96,\n        128,\n        256,\n        512,\n        1024,\n        1536,\n        2048,\n        3072,\n        4096,\n    ]\n"
  },
  {
    "path": "benchmark/kernels/fused_moe_triton/tuning_client.py",
    "content": "import argparse\nimport os\nimport time\n\nimport openai\n\n\"\"\"\n# Edit the code file srt/models/deepseek_v2.py in the Python site package and add the logic for saving topk_ids:\n# import get_tensor_model_parallel_rank\n# DeepseekV2MoE::forward_normal\nif hidden_states.shape[0] >= 4096 and get_tensor_model_parallel_rank() == 0:\n    topk_ids_dir = xxxx\n    if not hasattr(self, \"save_idx\"):\n        self.save_idx = 0\n    if self.save_idx <= 1:\n        torch.save(topk_output.topk_ids, f\"{topk_ids_dir}/topk_ids_layer{self.layer_id}_idx{self.save_idx}.pt\")\n    self.save_idx += 1\n\"\"\"\n\n\ndef read_long_prompt():\n    import json\n\n    current_dir = os.path.dirname(os.path.abspath(__file__))\n    with open(f\"{current_dir}/tuning_text.json\", \"r\") as fp:\n        text = fp.read()\n    rst = json.loads(text)\n    return rst[\"prompt\"]\n\n\ndef openai_stream_test(model, ip, port):\n    client = openai.Client(base_url=f\"http://{ip}:{port}/v1\", api_key=\"None\")\n    qst = read_long_prompt()\n\n    messages = [\n        {\"role\": \"user\", \"content\": qst},\n    ]\n    msg2 = dict(\n        model=model,\n        messages=messages,\n        temperature=0.6,\n        top_p=0.75,\n        max_tokens=100,\n    )\n    response = client.chat.completions.create(**msg2, stream=True)\n    time_start = time.time()\n    time_cost = []\n    for chunk in response:\n        time_end = time.time()\n        # if chunk.choices[0].delta.content:\n        #    print(chunk.choices[0].delta.content, end=\"\", flush=True)\n        time_cost.append(time_end - time_start)\n        time_start = time.time()\n\n    ttft = time_cost[0] + time_cost[1]\n    tpot = sum(time_cost[2:]) / len(time_cost[2:])\n    print(f\"\\nTTFT {ttft}, TPOT {tpot}\")\n    return ttft, tpot\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, default=\"auto\")\n    parser.add_argument(\n        \"--ip\",\n        
type=str,\n        default=\"127.0.0.1\",\n    )\n    parser.add_argument(\"--port\", type=int, default=8188)\n    args = parser.parse_args()\n    openai_stream_test(args.model, args.ip, args.port)\n"
  },
  {
    "path": "benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py\nimport argparse\nimport time\nfrom contextlib import nullcontext\nfrom datetime import datetime\nfrom typing import Any, Dict, List, Tuple\n\nimport ray\nimport torch\nimport triton\nfrom common_utils import (\n    BenchmarkConfig,\n    get_config_filename,\n    get_configs_compute_bound,\n    get_default_batch_sizes,\n    get_model_config,\n    save_configs,\n    sort_config,\n)\nfrom ray.experimental.tqdm_ray import tqdm\n\nfrom sglang.srt.layers.moe.fused_moe_triton import override_config\nfrom sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe\nfrom sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (\n    get_config_dtype_str,\n    get_default_config,\n    get_moe_configs,\n)\nfrom sglang.srt.layers.moe.moe_runner import MoeRunnerConfig\nfrom sglang.srt.layers.moe.topk import TopKConfig, select_experts\nfrom sglang.srt.server_args import (\n    ServerArgs,\n    set_global_server_args_for_scheduler,\n)\nfrom sglang.srt.utils import is_hip\n\n_is_hip = is_hip()\n\n\ndef benchmark_config(\n    config: BenchmarkConfig,\n    num_tokens: int,\n    num_experts: int,\n    shard_intermediate_size: int,\n    hidden_size: int,\n    topk: int,\n    dtype: torch.dtype,\n    use_fp8_w8a8: bool,\n    use_int8_w8a8: bool,\n    use_int8_w8a16: bool,\n    use_int4_w4a16: bool,\n    per_channel_quant: bool,\n    block_shape: List[int] = None,\n    num_iters: int = 100,\n) -> float:\n    init_dtype = torch.float16 if use_fp8_w8a8 else dtype\n    x = torch.randn(num_tokens, hidden_size, dtype=dtype)\n    if use_int8_w8a16 or use_int8_w8a8:\n        w1 = torch.randint(\n            -127,\n            127,\n            (\n                num_experts,\n                shard_intermediate_size,\n                hidden_size,\n            ),\n            dtype=torch.int8,\n        )\n        w2 = torch.randint(\n            -127,\n           
 127,\n            (\n                num_experts,\n                hidden_size,\n                shard_intermediate_size // 2,\n            ),\n            dtype=torch.int8,\n        )\n    elif use_int4_w4a16:\n        w1 = torch.randint(\n            0,\n            255,\n            (\n                num_experts,\n                shard_intermediate_size,\n                hidden_size // 2,\n            ),\n            dtype=torch.uint8,\n        )\n        w2 = torch.randint(\n            0,\n            255,\n            (\n                num_experts,\n                hidden_size,\n                shard_intermediate_size // 4,\n            ),\n            dtype=torch.uint8,\n        )\n    else:\n        w1 = torch.randn(\n            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype\n        )\n        w2 = torch.randn(\n            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype\n        )\n    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)\n\n    w1_scale = None\n    w2_scale = None\n    a1_scale = None\n    a2_scale = None\n    if use_int8_w8a16:\n        w1_scale = torch.randn(\n            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32\n        )\n        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)\n    if use_int4_w4a16:\n        block_n = 1 if (block_shape[0] == 0) else block_shape[0]\n        block_k = block_shape[1]\n        n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n\n        n_tiles_w2 = (hidden_size + block_n - 1) // block_n\n        k_tiles_w1 = (hidden_size + block_k - 1) // block_k\n        k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k\n        w1_scale = torch.randn(\n            (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.bfloat16\n        )\n        w2_scale = torch.randn(\n            (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.bfloat16\n        )\n    if 
use_fp8_w8a8 or use_int8_w8a8:\n        if use_int8_w8a8 and block_shape is None:\n            w1_scale = torch.randn(\n                num_experts, shard_intermediate_size, dtype=torch.float32\n            )\n            w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)\n        elif block_shape is None:\n            w1_scale = torch.randn(num_experts, dtype=torch.float32)\n            w2_scale = torch.randn(num_experts, dtype=torch.float32)\n            a1_scale = torch.randn(1, dtype=torch.float32)\n            a2_scale = torch.randn(1, dtype=torch.float32)\n        else:\n            block_n, block_k = block_shape[0], block_shape[1]\n            n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n\n            n_tiles_w2 = (hidden_size + block_n - 1) // block_n\n            k_tiles_w1 = (hidden_size + block_k - 1) // block_k\n            k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k\n            w1_scale = torch.rand(\n                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32\n            )\n            w2_scale = torch.rand(\n                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32\n            )\n\n    if use_fp8_w8a8:\n        w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)\n        w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)\n\n    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)\n    topk_config = TopKConfig(\n        top_k=topk,\n        renormalize=True,\n    )\n    topk_output = select_experts(x, input_gating, topk_config)\n\n    def prepare(i: int):\n        input_gating = gating_output[i]\n        new_topk_output = select_experts(x, input_gating, topk_config)\n        topk_output.topk_weights.copy_(new_topk_output.topk_weights)\n        topk_output.topk_ids.copy_(new_topk_output.topk_ids)\n        topk_output.router_logits.copy_(new_topk_output.router_logits)\n\n    def run():\n        
moe_runner_config = MoeRunnerConfig(\n            inplace=True,\n        )\n\n        with override_config(config):\n            fused_moe(\n                x,\n                w1,\n                w2,\n                topk_output,\n                moe_runner_config=moe_runner_config,\n                use_fp8_w8a8=use_fp8_w8a8,\n                use_int8_w8a8=use_int8_w8a8,\n                use_int8_w8a16=use_int8_w8a16,\n                use_int4_w4a16=use_int4_w4a16,\n                w1_scale=w1_scale,\n                w2_scale=w2_scale,\n                a1_scale=a1_scale,\n                a2_scale=a2_scale,\n                per_channel_quant=per_channel_quant,\n                block_shape=block_shape,\n            )\n\n    # JIT compilation & warmup\n    run()\n    torch.cuda.synchronize()\n\n    # Capture 10 invocations with CUDA graph\n    graph = torch.cuda.CUDAGraph()\n    with torch.cuda.graph(graph):\n        for _ in range(10):\n            run()\n    torch.cuda.synchronize()\n\n    # Warmup\n    for _ in range(5):\n        graph.replay()\n    torch.cuda.synchronize()\n\n    # Flush L2 cache with 256 MB data\n    cache_flush = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda\")\n    cache_flush.zero_()\n\n    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]\n    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]\n\n    for i in range(num_iters):\n        prepare(i)\n        start_events[i].record()\n        graph.replay()\n        end_events[i].record()\n    torch.cuda.synchronize()\n\n    latencies: List[float] = []\n    for i in range(num_iters):\n        latencies.append(start_events[i].elapsed_time(end_events[i]))\n    avg = sum(latencies) / (num_iters * 10) * 1000  # us\n    graph.reset()\n    return avg\n\n\n@ray.remote(num_gpus=1)\nclass BenchmarkWorker:\n\n    def __init__(self, seed: int, server_args: ServerArgs) -> None:\n        torch.set_default_device(\"cuda\")\n        
torch.cuda.manual_seed_all(0)\n        self.seed = seed\n        # Get the device ID to allocate tensors and kernels\n        # on the respective GPU.\n        self.device_id = int(ray.get_gpu_ids()[0])\n        set_global_server_args_for_scheduler(server_args)\n\n    def benchmark(\n        self,\n        num_tokens: int,\n        num_experts: int,\n        shard_intermediate_size: int,\n        hidden_size: int,\n        topk: int,\n        dtype: torch.dtype,\n        use_fp8_w8a8: bool,\n        use_int8_w8a8: bool,\n        use_int8_w8a16: bool,\n        use_int4_w4a16: bool,\n        per_channel_quant: bool,\n        block_shape: List[int],\n    ) -> Tuple[Dict[str, int], float]:\n        torch.cuda.manual_seed_all(0)\n        dtype_str = get_config_dtype_str(\n            dtype,\n            use_int8_w8a16=use_int8_w8a16,\n            use_fp8_w8a8=use_fp8_w8a8,\n            use_int4_w4a16=use_int4_w4a16,\n        )\n        # NOTE(woosuk): The current naming convention uses w2.shape[2], which\n        # is the intermediate size after silu_and_mul.\n        block_n = block_shape[0] if block_shape else 0\n        block_k = block_shape[1] if block_shape else 0\n        N = shard_intermediate_size // 2\n        if use_int4_w4a16:\n            N = N // 2\n        op_config = get_moe_configs(\n            num_experts,\n            N,\n            dtype_str,\n            block_n,\n            block_k,\n            per_channel_quant,\n        )\n        if op_config is None:\n            config = get_default_config(\n                num_tokens,\n                num_experts,\n                shard_intermediate_size,\n                hidden_size,\n                topk,\n                dtype_str,\n                False,\n                block_shape,\n            )\n        else:\n            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]\n        with torch.cuda.device(self.device_id) if is_hip() else nullcontext():\n            
kernel_time = benchmark_config(\n                config,\n                num_tokens,\n                num_experts,\n                shard_intermediate_size,\n                hidden_size,\n                topk,\n                dtype,\n                use_fp8_w8a8,\n                use_int8_w8a8,\n                use_int8_w8a16,\n                use_int4_w4a16,\n                per_channel_quant,\n                block_shape,\n            )\n        return config, kernel_time\n\n    def tune(\n        self,\n        num_tokens: int,\n        num_experts: int,\n        shard_intermediate_size: int,\n        hidden_size: int,\n        topk: int,\n        dtype: torch.dtype,\n        use_fp8_w8a8: bool,\n        use_int8_w8a8: bool,\n        use_int8_w8a16: bool,\n        use_int4_w4a16: bool,\n        per_channel_quant: bool,\n        block_shape: List[int],\n        search_space: List[Dict[str, int]],\n    ) -> Dict[str, int]:\n        best_config = None\n        best_time = float(\"inf\")\n        with torch.cuda.device(self.device_id) if is_hip() else nullcontext():\n            for config in tqdm(search_space):\n                try:\n                    kernel_time = benchmark_config(\n                        config,\n                        num_tokens,\n                        num_experts,\n                        shard_intermediate_size,\n                        hidden_size,\n                        topk,\n                        dtype,\n                        use_fp8_w8a8,\n                        use_int8_w8a8,\n                        use_int8_w8a16,\n                        use_int4_w4a16,\n                        per_channel_quant,\n                        block_shape,\n                        num_iters=10,\n                    )\n                except (triton.runtime.autotuner.OutOfResources, RuntimeError):\n                    # Some configurations may be invalid and fail to compile.\n                    continue\n\n                if kernel_time < 
best_time:\n                    best_time = kernel_time\n                    best_config = config\n        now = datetime.now()\n        print(f\"[{now.ctime()}] Completed tuning for batch_size={num_tokens}\")\n        assert best_config is not None\n        return best_config\n\n\ndef main(args: argparse.Namespace):\n    server_args = ServerArgs(\n        model_path=args.model, tp_size=args.tp_size, ep_size=args.ep_size\n    )\n\n    model_config = get_model_config(\n        args.model, args.tp_size, args.ep_size, args.disable_shared_experts_fusion\n    )\n\n    E = model_config[\"num_experts\"]\n    topk = model_config[\"topk\"]\n    hidden_size = model_config[\"hidden_size\"]\n    shard_intermediate_size = model_config[\"shard_intermediate_size\"]\n    dtype = model_config[\"dtype\"]\n    block_shape = model_config[\"block_shape\"]\n\n    use_fp8_w8a8 = args.dtype == \"fp8_w8a8\"\n    use_int8_w8a8 = args.dtype == \"int8_w8a8\"\n    use_int8_w8a16 = args.dtype == \"int8_w8a16\"\n    use_int4_w4a16 = args.dtype == \"int4_w4a16\"\n    per_channel_quant = args.per_channel_quant\n\n    if args.batch_size is None:\n        batch_sizes = get_default_batch_sizes()\n    else:\n        batch_sizes = [args.batch_size]\n\n    ray.init()\n    num_gpus = int(ray.available_resources()[\"GPU\"])\n    workers = [BenchmarkWorker.remote(args.seed, server_args) for _ in range(num_gpus)]\n\n    def _distribute(method: str, inputs: List[Any]) -> List[Any]:\n        outputs = []\n        worker_idx = 0\n        for input_args in inputs:\n            worker = workers[worker_idx]\n            worker_method = getattr(worker, method)\n            output = worker_method.remote(*input_args)\n            outputs.append(output)\n            worker_idx = (worker_idx + 1) % num_gpus\n        return ray.get(outputs)\n\n    if args.tune:\n        search_space = get_configs_compute_bound()\n        if block_shape is not None:\n            block_n, block_k = block_shape[0], block_shape[1]\n         
   search_space = [\n                config\n                for config in search_space\n                if block_k % config[\"BLOCK_SIZE_K\"] == 0\n            ]\n\n        filename = get_config_filename(\n            E,\n            shard_intermediate_size,\n            hidden_size,\n            topk,\n            dtype,\n            use_fp8_w8a8,\n            use_int8_w8a8,\n            use_int8_w8a16,\n            use_int4_w4a16,\n            per_channel_quant,\n            block_shape,\n        )\n        print(\n            f\"Start tuning over {len(search_space)} configurations to create {filename}...\"\n        )\n\n        start = time.perf_counter()\n        configs = _distribute(\n            \"tune\",\n            [\n                (\n                    batch_size,\n                    E,\n                    shard_intermediate_size,\n                    hidden_size,\n                    topk,\n                    dtype,\n                    use_fp8_w8a8,\n                    use_int8_w8a8,\n                    use_int8_w8a16,\n                    use_int4_w4a16,\n                    per_channel_quant,\n                    block_shape,\n                    search_space,\n                )\n                for batch_size in batch_sizes\n            ],\n        )\n        best_configs = {\n            M: sort_config(config) for M, config in zip(batch_sizes, configs)\n        }\n        save_configs(\n            best_configs,\n            filename,\n        )\n        end = time.perf_counter()\n        print(f\"Tuning took {end - start:.2f} seconds\")\n    else:\n        outputs = _distribute(\n            \"benchmark\",\n            [\n                (\n                    batch_size,\n                    E,\n                    shard_intermediate_size,\n                    hidden_size,\n                    topk,\n                    dtype,\n                    use_fp8_w8a8,\n                    use_int8_w8a8,\n                    use_int8_w8a16,\n    
                use_int4_w4a16,\n                    per_channel_quant,\n                    block_shape,\n                )\n                for batch_size in batch_sizes\n            ],\n        )\n\n        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):\n            print(f\"Batch size: {batch_size}, config: {config}\")\n            print(f\"Kernel time: {kernel_time:.2f} us\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--model\", type=str, default=\"mistralai/Mixtral-8x7B-Instruct-v0.1\"\n    )\n    parser.add_argument(\"--tp-size\", \"--tp\", type=int, default=2)\n    parser.add_argument(\"--ep-size\", \"--ep\", type=int, default=1)\n    parser.add_argument(\n        \"--dtype\",\n        type=str,\n        choices=[\"auto\", \"fp8_w8a8\", \"int8_w8a16\", \"int8_w8a8\", \"int4_w4a16\"],\n        default=\"auto\",\n    )\n    parser.add_argument(\n        \"--per-channel-quant\",\n        action=\"store_true\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--batch-size\", type=int, required=False)\n    parser.add_argument(\"--tune\", action=\"store_true\")\n    parser.add_argument(\"--disable-shared-experts-fusion\", action=\"store_true\")\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py\nimport argparse\nimport dataclasses\nimport json\nimport os\nimport time\nfrom contextlib import nullcontext\nfrom datetime import datetime\nfrom typing import Any, Dict, List, Tuple\n\nimport ray\nimport torch\nimport triton\nimport triton.language as tl\nfrom common_utils import (\n    BenchmarkConfig,\n    get_config_filename,\n    get_configs_compute_bound,\n    get_default_batch_sizes,\n    get_model_config,\n    sort_config,\n)\nfrom ray.experimental.tqdm_ray import tqdm\n\nfrom sglang.srt.layers.moe.fused_moe_triton.fused_moe import (\n    get_config_dtype_str,\n    invoke_fused_moe_kernel,\n    moe_align_block_size,\n)\nfrom sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (\n    get_config_file_name,\n)\nfrom sglang.srt.layers.moe.moe_runner import MoeRunnerConfig\nfrom sglang.srt.layers.moe.topk import TopKConfig, select_experts\nfrom sglang.srt.server_args import (\n    ServerArgs,\n    set_global_server_args_for_scheduler,\n)\nfrom sglang.srt.utils import is_hip\n\n_is_hip = is_hip()\n\n\n@dataclasses.dataclass\nclass MoeInputs:\n    topk_ids: torch.Tensor\n    sorted_token_ids: torch.Tensor\n    expert_ids: torch.Tensor\n    num_tokens_post_padded: torch.Tensor\n\n\nclass KernelWrapper:\n    def __init__(self, moe_inputs, use_cuda_graph=True, inner_iter=10, **kwargs):\n        self.func = invoke_fused_moe_kernel\n        self.use_cuda_graph = use_cuda_graph\n        self.moe_inputs = moe_inputs\n        self.inner_iter = inner_iter\n        self.kwargs = kwargs\n        if use_cuda_graph:\n            self.graph = self.cuda_graph_wrapper()\n        else:\n            self.graph = None\n\n    def cuda_graph_wrapper(self):\n        moe_input = self.moe_inputs[0]\n        self.func(\n            **self.kwargs,\n            topk_ids=moe_input.topk_ids,\n            sorted_token_ids=moe_input.sorted_token_ids,\n            
expert_ids=moe_input.expert_ids,\n            num_tokens_post_padded=moe_input.num_tokens_post_padded,\n        )\n        torch.cuda.synchronize()\n\n        # Capture 10 invocations with CUDA graph\n        graph = torch.cuda.CUDAGraph()\n        with torch.cuda.graph(graph):\n            for k in range(self.inner_iter):\n                moe_input = self.moe_inputs[k]\n                self.func(\n                    **self.kwargs,\n                    topk_ids=moe_input.topk_ids,\n                    sorted_token_ids=moe_input.sorted_token_ids,\n                    expert_ids=moe_input.expert_ids,\n                    num_tokens_post_padded=moe_input.num_tokens_post_padded,\n                )\n        torch.cuda.synchronize()\n\n        # Warmup\n        for _ in range(5):\n            graph.replay()\n        torch.cuda.synchronize()\n        return graph\n\n    def forward_cost(self, try_cnt=2):\n        time_cost = float(\"inf\")\n        for _ in range(try_cnt):\n            start_event = torch.cuda.Event(enable_timing=True)\n            end_event = torch.cuda.Event(enable_timing=True)\n            start_event.record()\n            if self.use_cuda_graph:\n                self.graph.replay()\n            else:\n                for k in range(self.inner_iter):\n                    moe_input = self.moe_inputs[k]\n                    self.func(\n                        **self.kwargs,\n                        topk_ids=moe_input.topk_ids,\n                        sorted_token_ids=moe_input.sorted_token_ids,\n                        expert_ids=moe_input.expert_ids,\n                        num_tokens_post_padded=moe_input.num_tokens_post_padded,\n                    )\n            end_event.record()\n            torch.cuda.synchronize()\n            time_cost = min(time_cost, start_event.elapsed_time(end_event))\n        return time_cost\n\n\ndef load_topk_ids(topk_ids_dir, i: int):\n    num_layers = 61\n    dense_layers = 3\n    moe_layers = num_layers - 
dense_layers\n    return torch.load(\n        f\"{topk_ids_dir}/topk_ids_layer{i % moe_layers + dense_layers}_idx{i // moe_layers}.pt\"\n    )\n\n\ndef benchmark_config(\n    config: BenchmarkConfig,\n    num_tokens: int,\n    num_experts: int,\n    shard_intermediate_size: int,\n    hidden_size: int,\n    topk: int,\n    dtype: torch.dtype,\n    use_fp8_w8a8: bool,\n    use_int8_w8a8: bool,\n    use_int8_w8a16: bool,\n    use_int4_w4a16: bool,\n    topk_ids_list,\n    block_shape: List[int] = None,\n    ep_size: int = 1,\n    num_iters: int = 100,\n) -> Tuple[float, float, float, float]:\n    ncu_enable = os.getenv(\"NCU_ENABLE\", \"0\") == \"1\"\n    if ncu_enable:\n        num_iters = 1\n    init_dtype = torch.float16 if use_fp8_w8a8 else dtype\n    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)\n    if use_int8_w8a16 or use_int8_w8a8:\n        w1 = torch.randint(\n            -127,\n            127,\n            (\n                num_experts,\n                shard_intermediate_size,\n                hidden_size,\n            ),\n            dtype=torch.int8,\n        )\n        w2 = torch.randint(\n            -127,\n            127,\n            (\n                num_experts,\n                hidden_size,\n                shard_intermediate_size // 2,\n            ),\n            dtype=torch.int8,\n        )\n    elif use_int4_w4a16:\n        w1 = torch.randint(\n            0,\n            255,\n            (\n                num_experts,\n                shard_intermediate_size,\n                hidden_size // 2,\n            ),\n            dtype=torch.uint8,\n        )\n        w2 = torch.randint(\n            0,\n            255,\n            (\n                num_experts,\n                hidden_size,\n                shard_intermediate_size // 4,\n            ),\n            dtype=torch.uint8,\n        )\n    else:\n        w1 = torch.randn(\n            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype\n        )\n        w2 = 
torch.randn(\n            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype\n        )\n\n    w1_scale = None\n    w2_scale = None\n    a1_scale = None\n    a2_scale = None\n    if use_int8_w8a16:\n        w1_scale = torch.randn(\n            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32\n        )\n        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)\n    if use_int4_w4a16:\n        block_n = 1 if (block_shape[0] == 0) else block_shape[0]\n        block_k = block_shape[1]\n        n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n\n        n_tiles_w2 = (hidden_size + block_n - 1) // block_n\n        k_tiles_w1 = (hidden_size + block_k - 1) // block_k\n        k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k\n        w1_scale = torch.randn(\n            (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.bfloat16\n        )\n        w2_scale = torch.randn(\n            (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.bfloat16\n        )\n    if use_fp8_w8a8 or use_int8_w8a8:\n        if use_int8_w8a8 and block_shape is None:\n            w1_scale = torch.randn(\n                num_experts, shard_intermediate_size, dtype=torch.float32\n            )\n            w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)\n        elif block_shape is None:\n            w1_scale = torch.randn(num_experts, dtype=torch.float32)\n            w2_scale = torch.randn(num_experts, dtype=torch.float32)\n            a1_scale = torch.randn(1, dtype=torch.float32)\n            a2_scale = torch.randn(1, dtype=torch.float32)\n        else:\n            block_n, block_k = block_shape[0], block_shape[1]\n            n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n\n            n_tiles_w2 = (hidden_size + block_n - 1) // block_n\n            k_tiles_w1 = (hidden_size + block_k - 1) // block_k\n            k_tiles_w2 = (shard_intermediate_size // 2 + block_k 
- 1) // block_k\n            w1_scale = torch.rand(\n                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32\n            )\n            w2_scale = torch.rand(\n                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32\n            )\n\n    if use_fp8_w8a8:\n        w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)\n        w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)\n\n    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)\n    topk_config = TopKConfig(\n        top_k=topk,\n        renormalize=True,\n    )\n    topk_output_ = select_experts(hidden_states, input_gating, topk_config)\n    sorted_token_ids_, expert_ids_, num_tokens_post_padded_ = moe_align_block_size(\n        topk_output_.topk_ids, config[\"BLOCK_SIZE_M\"], num_experts\n    )\n    inner_iter = 10 if not ncu_enable else 1\n    moe_inputs = [\n        MoeInputs(\n            topk_output_.topk_ids.clone(),\n            sorted_token_ids_.clone(),\n            expert_ids_.clone(),\n            num_tokens_post_padded_.clone(),\n        )\n        for _ in range(inner_iter)\n    ]\n    M = hidden_states.shape[0]\n    E, N, _ = w1.shape\n\n    padded_tokens = min(M * topk, E + 1) * (\n        config[\"BLOCK_SIZE_M\"] - 1\n    )  # if moe_use_tma else 0\n    total_tokens = M * topk + padded_tokens\n    cache = torch.empty(\n        total_tokens * max(N, w2.shape[1]),\n        device=hidden_states.device,\n        dtype=hidden_states.dtype,\n    )\n    intermediate_cache1 = cache[: total_tokens * N].view(\n        (total_tokens, N),\n    )\n    intermediate_cache2 = torch.empty(\n        (total_tokens, N // 2),\n        device=hidden_states.device,\n        dtype=hidden_states.dtype,\n    )\n    intermediate_cache3 = cache[: M * topk * w2.shape[1]].view(\n        (M, topk, w2.shape[1]),\n    )\n\n    def prepare(i: int, inner_iter):  # update inputs according to topk_ids\n        for k in 
range(inner_iter):\n            topk_ids = topk_ids_list[i * inner_iter + k]\n            # With EP, saved topk_ids are global expert indices; remap to local.\n            if ep_size > 1:\n                topk_ids = (topk_ids // ep_size).to(\n                    device=moe_inputs[k].topk_ids.device,\n                    dtype=moe_inputs[k].topk_ids.dtype,\n                )\n            tokens, _topk = moe_inputs[k].topk_ids.shape\n            moe_inputs[k].topk_ids.copy_(topk_ids[:tokens, :_topk])\n            sorted_token_ids_, expert_ids_, num_tokens_post_padded_ = (\n                moe_align_block_size(\n                    moe_inputs[k].topk_ids, config[\"BLOCK_SIZE_M\"], num_experts\n                )\n            )\n            moe_inputs[k].sorted_token_ids.copy_(sorted_token_ids_)\n            moe_inputs[k].expert_ids.copy_(expert_ids_)\n            moe_inputs[k].num_tokens_post_padded.copy_(num_tokens_post_padded_)\n\n    def get_kernel_wrapper(moe_use_tma, inner_iter, use_cuda_graph):\n        compute_type = (\n            tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16\n        )\n        moe_runner_config = MoeRunnerConfig(\n            inplace=True,\n        )\n        apply_router_weight_on_input = moe_runner_config.apply_router_weight_on_input\n        kernel0 = KernelWrapper(\n            A=hidden_states,\n            B=w1,\n            bias=None,\n            C=intermediate_cache1,\n            A_scale=a1_scale,\n            B_scale=w1_scale,\n            B_zp=None,\n            topk_weights=topk_output_.topk_weights,\n            moe_inputs=moe_inputs,\n            mul_routed_weight=apply_router_weight_on_input,\n            top_k=topk,\n            config=config,\n            compute_type=compute_type,\n            use_fp8_w8a8=use_fp8_w8a8,\n            use_int8_w8a8=use_int8_w8a8,\n            use_int8_w8a16=use_int8_w8a16,\n            use_int4_w4a16=use_int4_w4a16,\n            per_channel_quant=False,\n            
block_shape=block_shape,\n            b_use_tma=moe_use_tma,\n            c_sorted=moe_use_tma,\n            filter_expert=False,\n            use_cuda_graph=use_cuda_graph,\n            inner_iter=inner_iter,\n        )\n        kernel1 = KernelWrapper(\n            A=intermediate_cache2,\n            B=w2,\n            bias=None,\n            C=intermediate_cache3,\n            A_scale=a2_scale,\n            B_scale=w2_scale,\n            B_zp=None,\n            topk_weights=topk_output_.topk_weights,\n            moe_inputs=moe_inputs,\n            mul_routed_weight=not apply_router_weight_on_input,\n            top_k=1,\n            config=config,\n            compute_type=compute_type,\n            use_fp8_w8a8=use_fp8_w8a8,\n            use_int8_w8a8=use_int8_w8a8,\n            use_int8_w8a16=use_int8_w8a16,\n            use_int4_w4a16=use_int4_w4a16,\n            per_channel_quant=False,\n            block_shape=block_shape,\n            a_use_tma=moe_use_tma,\n            b_use_tma=moe_use_tma,\n            filter_expert=False,\n            use_cuda_graph=use_cuda_graph,\n            inner_iter=inner_iter,\n        )\n        return kernel0, kernel1\n\n    use_cuda_graph = True if not ncu_enable else False\n\n    kernel0, kernel1 = get_kernel_wrapper(False, inner_iter, use_cuda_graph)\n    kernel_tma0, kernel_tma1 = get_kernel_wrapper(True, inner_iter, use_cuda_graph)\n\n    # JIT compilation & warmup\n    if not ncu_enable:\n        kernel0.forward_cost()\n        kernel1.forward_cost()\n        kernel_tma0.forward_cost()\n        kernel_tma1.forward_cost()\n\n    ts0 = []\n    ts1 = []\n    ts_tma0 = []\n    ts_tma1 = []\n\n    for i in range(num_iters // inner_iter):\n        prepare(i, inner_iter)\n        ts0.append(kernel0.forward_cost())\n        ts1.append(kernel1.forward_cost())\n        ts_tma0.append(kernel_tma0.forward_cost())\n        ts_tma1.append(kernel_tma1.forward_cost())\n    torch.cuda.synchronize()\n\n    avg = sum(ts0) / (num_iters) * 
1000  # us\n    avg1 = sum(ts1) / (num_iters) * 1000  # us\n    avg_tma = sum(ts_tma0) / (num_iters) * 1000  # us\n    avg1_tma = sum(ts_tma1) / (num_iters) * 1000  # us\n\n    return avg, avg_tma, avg1, avg1_tma\n\n\nclass BestConfigTrace:\n    def __init__(self, name, down_moe=False):\n        self.name = name\n        self.down_moe = down_moe\n        self.best_costs_m = {}  # block_m: best_cost\n\n    def update(self, config, time_cost_all):\n        block_m = config[\"BLOCK_SIZE_M\"]\n        if not self.down_moe:\n            time_cost = time_cost_all[0]\n        else:\n            time_cost = min(time_cost_all[2], time_cost_all[3])\n        if (\n            block_m not in self.best_costs_m\n            or time_cost < self.best_costs_m[block_m][1]\n        ):\n            self.best_costs_m[block_m] = config, time_cost, time_cost_all\n\n    def time_cost(self, block_m):\n        if block_m not in self.best_costs_m:\n            return float(\"inf\")\n        time_cost = self.best_costs_m[block_m][1]\n        return time_cost\n\n    def config_dict(self, block_m):\n        if block_m not in self.best_costs_m:\n            return {}\n        config, _, time_cost_all = self.best_costs_m[block_m]\n        if not self.down_moe:\n            return config\n        else:\n            return {\n                **config,\n                \"USE_TMA\": time_cost_all[2] > time_cost_all[3],\n            }\n\n\nclass BenchmarkWorker:\n\n    def __init__(self, seed: int, server_args: ServerArgs) -> None:\n        torch.set_default_device(\"cuda\")\n        torch.cuda.manual_seed_all(0)\n        self.seed = seed\n        # Get the device ID to allocate tensors and kernels\n        # on the respective GPU.\n        self.device_id = 0  # int(ray.get_gpu_ids()[0])\n        set_global_server_args_for_scheduler(server_args)\n\n    def benchmark(\n        self,\n        num_tokens: int,\n        num_experts: int,\n        shard_intermediate_size: int,\n        hidden_size: int,\n  
      topk: int,\n        dtype: torch.dtype,\n        use_fp8_w8a8: bool,\n        use_int8_w8a8: bool,\n        use_int8_w8a16: bool,\n        use_int4_w4a16: bool,\n        block_shape: List[int],\n        cfg: Dict[str, int],\n        topk_ids_dir: str,\n        ep_size: int = 1,\n    ) -> Tuple[Dict[str, int], Tuple[float, float, float, float]]:\n        torch.cuda.manual_seed_all(0)\n        topk_ids_list = [load_topk_ids(topk_ids_dir, i) for i in range(100)]\n        with torch.cuda.device(self.device_id) if is_hip() else nullcontext():\n            kernel_time = benchmark_config(\n                cfg,\n                num_tokens,\n                num_experts,\n                shard_intermediate_size,\n                hidden_size,\n                topk,\n                dtype,\n                use_fp8_w8a8,\n                use_int8_w8a8,\n                use_int8_w8a16,\n                use_int4_w4a16,\n                topk_ids_list,\n                block_shape,\n                ep_size=ep_size,\n            )\n        return cfg, kernel_time\n\n    def tune(\n        self,\n        num_tokens: int,\n        num_experts: int,\n        shard_intermediate_size: int,\n        hidden_size: int,\n        topk: int,\n        dtype: torch.dtype,\n        use_fp8_w8a8: bool,\n        use_int8_w8a8: bool,\n        use_int8_w8a16: bool,\n        use_int4_w4a16: bool,\n        block_shape: List[int],\n        search_space: List[Dict[str, int]],\n        topk_ids_dir: str,\n        ep_size: int = 1,\n    ) -> Tuple[Dict[str, int], Dict[str, int], float, float]:\n        trace0 = BestConfigTrace(\"kernel0\", down_moe=False)\n        trace1 = BestConfigTrace(\"kernel1\", down_moe=True)\n        topk_ids_list = [load_topk_ids(topk_ids_dir, i) for i in range(100)]\n\n        with torch.cuda.device(self.device_id) if is_hip() else nullcontext():\n            for config in tqdm(search_space):\n                try:\n                    kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma = benchmark_config(\n                        config,\n          
              num_tokens,\n                        num_experts,\n                        shard_intermediate_size,\n                        hidden_size,\n                        topk,\n                        dtype,\n                        use_fp8_w8a8,\n                        use_int8_w8a8,\n                        use_int8_w8a16,\n                        use_int4_w4a16,\n                        topk_ids_list,\n                        block_shape,\n                        ep_size=ep_size,\n                        num_iters=100,\n                    )\n                except triton.runtime.autotuner.OutOfResources:\n                    # Some configurations may be invalid and fail to compile.\n                    continue\n                trace0.update(\n                    config,\n                    (kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma),\n                )\n                trace1.update(\n                    config,\n                    (kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma),\n                )\n\n        now = datetime.now()\n        print(f\"[{now.ctime()}] Completed tuning for batch_size={num_tokens}\")\n        best_block_m = 16\n        for block_m in (32, 64, 128, 256):\n            if trace0.time_cost(block_m) + trace1.time_cost(block_m) < trace0.time_cost(\n                best_block_m\n            ) + trace1.time_cost(best_block_m):\n                best_block_m = block_m\n\n        return (\n            trace0.config_dict(best_block_m),\n            trace1.config_dict(best_block_m),\n            trace0.time_cost(best_block_m),\n            trace1.time_cost(best_block_m),\n        )\n\n    def cmp_configs(\n        self,\n        num_tokens: List[int],\n        num_experts: int,\n        shard_intermediate_size: int,\n        hidden_size: int,\n        topk: int,\n        dtype: torch.dtype,\n        use_fp8_w8a8: bool,\n        use_int8_w8a8: bool,\n        use_int8_w8a16: bool,\n        use_int4_w4a16: bool,\n        block_shape: List[int],\n 
       cmp_config_files: List[str],\n        topk_ids_dir: str,\n        ep_size: int = 1,\n    ):\n        # compare performance of different configs\n        cmp_configs = []\n        for file in cmp_config_files:\n            with open(file) as f:\n                cmp_configs.append({int(key): val for key, val in json.load(f).items()})\n        for i, file in enumerate(cmp_config_files):\n            print(f\"config {i}: {file}\")\n\n        topk_ids_list = [load_topk_ids(topk_ids_dir, i) for i in range(100)]\n        torch.cuda.manual_seed_all(0)\n        with torch.cuda.device(self.device_id) if is_hip() else nullcontext():\n            for bs in num_tokens:\n                kernel_times = []\n                cfgs = []\n                for configs in cmp_configs:\n                    cfg_org = configs[min(configs.keys(), key=lambda x: abs(x - bs))]\n                    cfgs.append(cfg_org)\n                    cfg = cfg_org.copy()\n                    cfg.pop(\"USE_TMA\", None)\n                    kernel_time = benchmark_config(\n                        cfg,\n                        bs,\n                        num_experts,\n                        shard_intermediate_size,\n                        hidden_size,\n                        topk,\n                        dtype,\n                        use_fp8_w8a8,\n                        use_int8_w8a8,\n                        use_int8_w8a16,\n                        use_int4_w4a16,\n                        topk_ids_list,\n                        block_shape,\n                        ep_size=ep_size,\n                    )\n                    kernel_times.append(kernel_time)\n                print(f\"batch_size={bs}:\")\n                for i, cfg in enumerate(cfgs):\n                    print(f\"  config {i} {cfg}: {kernel_times[i]}\")\n\n\ndef save_configs_sep(\n    configs: Dict[int, BenchmarkConfig],\n    num_experts: int,\n    shard_intermediate_size: int,\n    hidden_size: int,\n    topk: int,\n    
dtype: torch.dtype,\n    use_fp8_w8a8: bool,\n    use_int8_w8a8: bool,\n    use_int8_w8a16: bool,\n    use_int4_w4a16: bool,\n    block_shape: List[int],\n    down_moe: bool = False,\n) -> None:\n    dtype_str = get_config_dtype_str(\n        dtype,\n        use_int8_w8a16=use_int8_w8a16,\n        use_fp8_w8a8=use_fp8_w8a8,\n        use_int8_w8a8=use_int8_w8a8,\n        use_int4_w4a16=use_int4_w4a16,\n    )\n\n    # NOTE(woosuk): The current naming convention uses w2.shape[2], which\n    # is the intermediate size after silu_and_mul.\n    filename = get_config_file_name(\n        num_experts,\n        shard_intermediate_size // 2,\n        dtype_str,\n        block_shape,\n        down_moe=down_moe,\n    )\n\n    print(f\"Writing best config to {filename}...\")\n    with open(filename, \"w\") as f:\n        json.dump(configs, f, indent=4)\n        f.write(\"\\n\")\n\n\ndef main(args: argparse.Namespace):\n    print(args)\n\n    server_args = ServerArgs(\n        model_path=args.model, tp_size=args.tp_size, ep_size=args.ep_size\n    )\n\n    model_config = get_model_config(\n        args.model,\n        args.tp_size,\n        args.ep_size,\n        args.disable_shared_experts_fusion,\n        args.topk_ids_dir,\n    )\n\n    E = model_config[\"num_experts\"]\n    topk = model_config[\"topk\"]\n    hidden_size = model_config[\"hidden_size\"]\n    shard_intermediate_size = model_config[\"shard_intermediate_size\"]\n    dtype = model_config[\"dtype\"]\n    block_shape = model_config[\"block_shape\"]\n\n    use_fp8_w8a8 = args.dtype == \"fp8_w8a8\"\n    use_int8_w8a8 = args.dtype == \"int8_w8a8\"\n    use_int8_w8a16 = args.dtype == \"int8_w8a16\"\n    use_int4_w4a16 = args.dtype == \"int4_w4a16\"\n\n    topk_ids_dir = args.topk_ids_dir\n    if args.batch_size is None:\n        batch_sizes = get_default_batch_sizes()\n        batch_sizes.reverse()\n    else:\n        batch_sizes = [args.batch_size]\n\n    if args.cmp_configs is not None:\n        worker = 
BenchmarkWorker(args.seed, server_args)\n        worker.cmp_configs(\n            batch_sizes,\n            E,\n            shard_intermediate_size,\n            hidden_size,\n            topk,\n            dtype,\n            use_fp8_w8a8,\n            use_int8_w8a8,\n            use_int8_w8a16,\n            use_int4_w4a16,\n            block_shape,\n            args.cmp_configs,\n            topk_ids_dir,\n            args.ep_size,\n        )\n        return\n\n    if len(batch_sizes) == 1:\n        worker = BenchmarkWorker(args.seed, server_args)\n        if args.tune:\n            search_space = get_configs_compute_bound()\n            worker.tune(\n                batch_sizes[0],\n                E,\n                shard_intermediate_size,\n                hidden_size,\n                topk,\n                dtype,\n                use_fp8_w8a8,\n                use_int8_w8a8,\n                use_int8_w8a16,\n                use_int4_w4a16,\n                block_shape,\n                search_space,\n                topk_ids_dir,\n                args.ep_size,\n            )\n        else:\n            cfg = {\n                \"BLOCK_SIZE_M\": args.configs[0],\n                \"BLOCK_SIZE_N\": args.configs[1],\n                \"BLOCK_SIZE_K\": args.configs[2],\n                \"GROUP_SIZE_M\": args.configs[3],\n                \"num_warps\": args.configs[4],\n                \"num_stages\": args.configs[5],\n            }\n\n            _, (t0, t0_tma, t1, t1_tma) = worker.benchmark(\n                args.batch_size,\n                E,\n                shard_intermediate_size,\n                hidden_size,\n                topk,\n                dtype,\n                use_fp8_w8a8,\n                use_int8_w8a8,\n                use_int8_w8a16,\n                use_int4_w4a16,\n                block_shape,\n                cfg,\n                topk_ids_dir,\n                args.ep_size,\n            )\n            print(f\"{t0=}, {t0_tma=}, {t1=}, 
{t1_tma=}\")\n        return\n\n    assert args.tune\n\n    ray.init()\n    num_gpus = int(ray.available_resources()[\"GPU\"])\n    workers = [\n        ray.remote(num_gpus=1)(BenchmarkWorker).remote(args.seed, server_args)\n        for _ in range(num_gpus)\n    ]\n\n    def _distribute(method: str, inputs: List[Any]) -> List[Any]:\n        outputs = []\n        worker_idx = 0\n        for input_args in inputs:\n            worker = workers[worker_idx]\n            worker_method = getattr(worker, method)\n            output = worker_method.remote(*input_args)\n            outputs.append(output)\n            worker_idx = (worker_idx + 1) % num_gpus\n        return ray.get(outputs)\n\n    search_space = get_configs_compute_bound()\n    if block_shape is not None:\n        block_n, block_k = block_shape[0], block_shape[1]\n        search_space = [\n            config for config in search_space if block_k % config[\"BLOCK_SIZE_K\"] == 0\n        ]\n    filename = get_config_filename(\n        E,\n        shard_intermediate_size,\n        hidden_size,\n        topk,\n        dtype,\n        use_fp8_w8a8,\n        use_int8_w8a8,\n        use_int8_w8a16,\n        use_int4_w4a16,\n        False,\n        block_shape,\n    )\n    print(\n        f\"Start tuning over {len(search_space)} configurations to create {filename}...\"\n    )\n\n    start = time.perf_counter()\n    configs = _distribute(\n        \"tune\",\n        [\n            (\n                batch_size,\n                E,\n                shard_intermediate_size,\n                hidden_size,\n                topk,\n                dtype,\n                use_fp8_w8a8,\n                use_int8_w8a8,\n                use_int8_w8a16,\n                use_int4_w4a16,\n                block_shape,\n                search_space,\n                topk_ids_dir,\n                args.ep_size,\n            )\n            for batch_size in batch_sizes\n        ],\n    )\n    print(f\"{configs=}\", flush=True)\n    
cur_time = time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime())\n    with open(f\"tuning_result_{cur_time}.txt\", \"w\") as f:\n        print(configs, file=f)\n    batch_sizes.reverse()\n    configs0 = [config[0] for config in configs]\n    configs1 = [config[1] for config in configs]\n    configs0.reverse()\n    configs1.reverse()\n    best_configs0 = {M: sort_config(config) for M, config in zip(batch_sizes, configs0)}\n    save_configs_sep(\n        best_configs0,\n        E,\n        shard_intermediate_size,\n        hidden_size,\n        topk,\n        dtype,\n        use_fp8_w8a8,\n        use_int8_w8a8,\n        use_int8_w8a16,\n        use_int4_w4a16,\n        block_shape,\n    )\n\n    best_configs1 = {M: sort_config(config) for M, config in zip(batch_sizes, configs1)}\n    save_configs_sep(\n        best_configs1,\n        E,\n        shard_intermediate_size,\n        hidden_size,\n        topk,\n        dtype,\n        use_fp8_w8a8,\n        use_int8_w8a8,\n        use_int8_w8a16,\n        use_int4_w4a16,\n        block_shape,\n        down_moe=True,\n    )\n    end = time.perf_counter()\n    print(f\"Tuning took {end - start:.2f} seconds\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--model\", type=str, default=\"mistralai/Mixtral-8x7B-Instruct-v0.1\"\n    )\n    parser.add_argument(\"--tp-size\", \"--tp\", type=int, default=2)\n    parser.add_argument(\"--ep-size\", \"--ep\", type=int, default=1)\n    parser.add_argument(\n        \"--dtype\",\n        type=str,\n        choices=[\"auto\", \"fp8_w8a8\", \"int8_w8a16\", \"int8_w8a8\", \"int4_w4a16\"],\n        default=\"auto\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--batch-size\", type=int, required=False)\n    parser.add_argument(\"--tune\", action=\"store_true\")\n    parser.add_argument(\"--disable-shared-experts-fusion\", action=\"store_true\")\n    
parser.add_argument(\"--configs\", type=int, nargs=\"+\", required=False)\n    parser.add_argument(\"--topk-ids-dir\", type=str, required=True)\n    parser.add_argument(\"--cmp-configs\", type=str, nargs=\"+\", required=False)\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "benchmark/kernels/fused_moe_triton/tuning_text.json",
    "content": "{\"prompt\": \"Here are the relevant Wikipedia articles:\\nThe president of the United States (POTUS) is the head of state and head of government of the United States of America. The president directs the executive branch of the federal government and is the commander-in-chief of the United States Armed Forces.\\nThe power of the presidency has grown substantially since the first president, George Washington, took office in 1789. While presidential power has ebbed and flowed over time, the presidency has played an increasingly significant role in American political life since the beginning of the 20th century, carrying over into the 21st century with notable expansions during the presidencies of Franklin D. Roosevelt and George W. Bush. In modern times, the president is one of the world's most powerful political figures and the leader of the world's only remaining superpower. As the leader of the nation with the largest economy by nominal GDP, the president possesses significant domestic and international hard and soft power. For much of the 20th century, especially during the Cold War, the U.S. president was often called \\\"the leader of the free world\\\".\\nArticle II of the Constitution establishes the executive branch of the federal government and vests executive power in the president. The power includes the execution and enforcement of federal law and the responsibility to appoint federal executive, diplomatic, regulatory, and judicial officers.  Based on constitutional provisions empowering the president to appoint and receive ambassadors and conclude treaties with foreign powers, and on subsequent laws enacted by Congress, the modern presidency has primary responsibility for conducting U.S. foreign policy. The role includes responsibility for directing the world's most expensive military, which has the second-largest nuclear arsenal.\\nThe president also plays a leading role in federal legislation and domestic policymaking. 
As part of the system of separation of powers, Article I, Section 7 of the Constitution gives the president the power to sign or veto federal legislation. Since modern presidents are typically viewed as leaders of their political parties, major policymaking is significantly shaped by the outcome of presidential elections, with presidents taking an active role in promoting their policy priorities to members of Congress who are often electorally dependent on the president. In recent decades, presidents have also made increasing use of executive orders, agency regulations, and judicial appointments to shape domestic policy.\\nThe president is elected indirectly through the Electoral College to a four-year term, along with the vice president. Under the Twenty-second Amendment, ratified in 1951, no person who has been elected to two presidential terms may be elected to a third. In addition, nine vice presidents have become president by virtue of a president's intra-term death or resignation. In all, 45 individuals have served 46 presidencies spanning 58 four-year terms. Joe Biden is the 46th and current president, having assumed office on January 20, 2021.\\n\\nHistory and development\\nOrigins\\nDuring the American Revolutionary War, the Thirteen Colonies, represented by the Second Continental Congress in Philadelphia, declared themselves to be independent sovereign states and no longer under British rule. The affirmation was made in the Declaration of Independence, which was written predominantly by Thomas Jefferson and adopted unanimously on July 4, 1776, by the Second Continental Congress. Recognizing the necessity of closely coordinating their efforts against the British, the Continental Congress simultaneously began the process of drafting a constitution that would bind the states together. There were long debates on a number of issues, including representation and voting, and the exact powers to be given the central government. 
Congress finished work on the Articles of Confederation to establish a perpetual union between the states in November 1777 and sent it to the states for ratification.\\nUnder the Articles, which took effect on March 1, 1781, the Congress of the Confederation was a central political authority without any legislative power. It could make its own resolutions, determinations, and regulations, but not any laws, and could not impose any taxes or enforce local commercial regulations upon its citizens. This institutional design reflected how Americans believed the deposed British system of Crown and Parliament ought to have functioned with respect to the royal dominion: a superintending body for matters that concerned the entire empire. The states were out from under any monarchy and assigned some formerly royal prerogatives (e.g., making war, receiving ambassadors, etc.) to Congress; the remaining prerogatives were lodged within their own respective state governments. The members of Congress elected a president of the United States in Congress Assembled to preside over its deliberation as a neutral discussion moderator. Unrelated to and quite dissimilar from the later office of president of the United States, it was a largely ceremonial position without much influence.\\nIn 1783, the Treaty of Paris secured independence for each of the former colonies. With peace at hand, the states each turned toward their own internal affairs. By 1786, Americans found their continental borders besieged and weak and their respective economies in crises as neighboring states agitated trade rivalries with one another. They witnessed their hard currency pouring into foreign markets to pay for imports, their Mediterranean commerce preyed upon by North African pirates, and their foreign-financed Revolutionary War debts unpaid and accruing interest. Civil and political unrest loomed.  
Events such as the Newburgh Conspiracy and Shays' Rebellion demonstrated that the Articles of Confederation were not working.\\nFollowing the successful resolution of commercial and fishing disputes between Virginia and Maryland at the Mount Vernon Conference in 1785, Virginia called for a trade conference between all the states, set for September 1786 in Annapolis, Maryland, with an aim toward resolving further-reaching interstate commercial antagonisms. When the convention failed for lack of attendance due to suspicions among most of the other states, Alexander Hamilton of New York led the Annapolis delegates in a call for a convention to offer revisions to the Articles, to be held the next spring in Philadelphia. Prospects for the next convention appeared bleak until James Madison and Edmund Randolph succeeded in securing George Washington's attendance to Philadelphia as a delegate for Virginia.\\nWhen the Constitutional Convention convened in May 1787, the 12 state delegations in attendance (Rhode Island did not send delegates) brought with them an accumulated experience over a diverse set of institutional arrangements between legislative and executive branches from within their respective state governments. Most states maintained a weak executive without veto or appointment powers, elected annually by the legislature to a single term only, sharing power with an executive council, and countered by a strong legislature. New York offered the greatest exception, having a strong, unitary governor with veto and appointment power elected to a three-year term, and eligible for reelection to an indefinite number of terms thereafter. It was through the closed-door negotiations at Philadelphia that the presidency framed in the U.S. Constitution emerged.\\n\\n1789–1933\\nAs the nation's first president, George Washington established many norms that would come to define the office. 
His decision to retire after two terms helped address fears that the nation would devolve into monarchy, and established a precedent that would not be broken until 1940 and would eventually be made permanent by the Twenty-Second Amendment. By the end of his presidency, political parties had developed, with John Adams defeating Thomas Jefferson in 1796, the first truly contested presidential election.  After Jefferson defeated Adams in 1800, he and his fellow Virginians James Madison and James Monroe would each serve two terms, eventually dominating the nation's politics during the Era of Good Feelings until Adams' son John Quincy Adams won election in 1824 after the Democratic-Republican Party split.\\nThe election of Andrew Jackson in 1828 was a significant milestone, as Jackson was not part of the Virginia and Massachusetts elite that had held the presidency for its first 40 years. Jacksonian democracy sought to strengthen the presidency at the expense of Congress, while broadening public participation as the nation rapidly expanded westward. However, his successor, Martin Van Buren, became unpopular after the Panic of 1837, and the death of William Henry Harrison and subsequent poor relations between John Tyler and Congress led to further weakening of the office. Including Van Buren, in the 24 years between 1837 and 1861, six presidential terms would be filled by eight different men, with none serving two terms. The Senate played an important role during this period, with the Great Triumvirate of Henry Clay, Daniel Webster, and John C. Calhoun playing key roles in shaping national policy in the 1830s and 1840s until debates over slavery began pulling the nation apart in the 1850s.\\nAbraham Lincoln's leadership during the Civil War has led historians to regard him as one of the nation's greatest presidents. 
The circumstances of the war and Republican domination of Congress made the office very powerful, and Lincoln's re-election in 1864 was the first time a president had been re-elected since Jackson in 1832. After Lincoln's assassination, his successor Andrew Johnson lost all political support and was nearly removed from office, with Congress remaining powerful during the two-term presidency of Civil War general Ulysses S. Grant. After the end of Reconstruction, Grover Cleveland would eventually become the first Democratic president elected since before the war, running in three consecutive elections (1884, 1888, 1892) and winning twice. In 1900, William McKinley became the first incumbent to win re-election since Grant in 1872.\\nAfter McKinley's assassination by Leon Czolgosz in 1901, Theodore Roosevelt became a dominant figure in American politics. Historians believe Roosevelt permanently changed the political system by strengthening the presidency, with some key accomplishments including breaking up trusts, conservationism, labor reforms, making personal character as important as the issues, and hand-picking his successor, William Howard Taft. The following decade, Woodrow Wilson led the nation to victory during World War I, although Wilson's proposal for the League of Nations was rejected by the Senate.  Warren Harding, while popular in office, would see his legacy tarnished by scandals, especially Teapot Dome, and Herbert Hoover quickly became very unpopular after failing to alleviate the Great Depression.\\n\\nImperial presidency\\nThe ascendancy of Franklin D. Roosevelt in 1933 led further toward what historians now describe as the Imperial presidency. 
Backed by enormous Democratic majorities in Congress and public support for major change, Roosevelt's New Deal dramatically increased the size and scope of the federal government, including more executive agencies. The traditionally small presidential staff was greatly expanded, with the Executive Office of the President being created in 1939, none of whom require Senate confirmation. Roosevelt's unprecedented re-election to a third and fourth term, the victory of the United States in World War II, and the nation's growing economy all helped establish the office as a position of global leadership. His successors, Harry Truman and Dwight D. Eisenhower, each served two terms as the Cold War led the presidency to be viewed as the \\\"leader of the free world\\\", while John F. Kennedy was a youthful and popular leader who benefited from the rise of television in the 1960s.\\nAfter Lyndon B. Johnson lost popular support due to the Vietnam War and Richard Nixon's presidency collapsed in the Watergate scandal, Congress enacted a series of reforms intended to reassert itself. These included the War Powers Resolution, enacted over Nixon's veto in 1973, and the Congressional Budget and Impoundment Control Act of 1974 that sought to strengthen congressional fiscal powers. By 1976, Gerald Ford conceded that \\\"the historic pendulum\\\" had swung toward Congress, raising the possibility of a \\\"disruptive\\\" erosion of his ability to govern. Ford failed to win election to a full term and his successor, Jimmy Carter, failed to win re-election.  Ronald Reagan, who had been an actor before beginning his political career, used his talent as a communicator to help reshape the American agenda away from New Deal policies toward more conservative ideology.\\nWith the Cold War ending and the United States becoming the world's undisputed leading power, Bill Clinton, George W. Bush, and Barack Obama each served two terms as president. 
Meanwhile, Congress and the nation gradually became more politically polarized, especially following the 1994 mid-term elections that saw Republicans control the House for the first time in 40 years, and the rise of routine filibusters in the Senate in recent decades. Recent presidents have thus increasingly focused on executive orders, agency regulations, and judicial appointments to implement major policies, at the expense of legislation and congressional power. Presidential elections in the 21st century have reflected this continuing polarization, with no candidate except Obama in 2008 winning by more than five percent of the popular vote and two, George W. Bush and Donald Trump, winning in the Electoral College while losing the popular vote.\\n\\nCritics of presidency's evolution\\nThe nation's Founding Fathers expected the Congress, which was the first branch of government described in the Constitution, to be the dominant branch of government; however, they did not expect a strong executive department. However, presidential power has shifted over time, which has resulted in claims that the modern presidency has become too powerful, unchecked, unbalanced, and \\\"monarchist\\\" in nature. In 2008 professor Dana D. Nelson expressed belief that presidents over the previous thirty years worked towards \\\"undivided presidential control of the executive branch and its agencies\\\". She criticized proponents of the unitary executive theory for expanding \\\"the many existing uncheckable executive powers—such as executive orders, decrees, memorandums, proclamations, national security directives and legislative signing statements—that already allow presidents to enact a good deal of foreign and domestic policy without aid, interference or consent from Congress\\\". 
Bill Wilson, board member of Americans for Limited Government, opined that the expanded presidency was \\\"the greatest threat ever to individual freedom and democratic rule\\\".\\n\\nLegislative powers\\nArticle I, Section 1 of the Constitution vests all lawmaking power in Congress's hands, and Article I, Section 6, Clause 2 prevents the president (and all other executive branch officers) from simultaneously being a member of Congress. Nevertheless, the modern presidency exerts significant power over legislation, both due to constitutional provisions and historical developments over time.\\n\\nSigning and vetoing bills\\nThe president's most significant legislative power derives from the Presentment Clause, which gives the president the power to veto any bill passed by Congress. While Congress can override a presidential veto, it requires a two-thirds vote of both houses, which is usually very difficult to achieve except for widely supported bipartisan legislation. The framers of the Constitution feared that Congress would seek to increase its power and enable a \\\"tyranny of the majority\\\", so giving the indirectly elected president a veto was viewed as an important check on the legislative power. While George Washington believed the veto should only be used in cases where a bill was unconstitutional, it is now routinely used in cases where presidents have policy disagreements with a bill.  
The veto – or threat of a veto – has thus evolved to make the modern presidency a central part of the American legislative process.\\nSpecifically, under the Presentment Clause, once a bill has been presented by Congress, the president has three options:\\n\\nSign the legislation within ten days, excluding Sundays, the bill becomes law.\\nVeto the legislation within the above timeframe and return it to the house of Congress from which it originated, expressing any objections, the bill does not become law, unless both houses of Congress vote to override the veto by a two-thirds vote.\\nTake no action on the legislation within the above timeframe—the bill becomes law, as if the president had signed it, unless Congress is adjourned at the time, in which case it does not become law, which is known as a pocket veto.\\nIn 1996, Congress attempted to enhance the president's veto power with the Line Item Veto Act. The legislation empowered the president to sign any spending bill into law while simultaneously striking certain spending items within the bill, particularly any new spending, any amount of discretionary spending, or any new limited tax benefit. Congress could then repass that particular item. If the president then vetoed the new legislation, Congress could override the veto by its ordinary means, a two-thirds vote in both houses. In Clinton v. City of New York, 524 U.S. 417 (1998), the U.S. Supreme Court ruled such a legislative alteration of the veto power to be unconstitutional.\\n\\nSetting the agenda\\nFor most of American history, candidates for president have sought election on the basis of a promised legislative agenda. Article II, Section 3, Clause 2 requires the president to recommend such measures to Congress which the president deems \\\"necessary and expedient\\\". 
This is done through the constitutionally-based State of the Union address, which usually outlines the president's legislative proposals for the coming year, and through other formal and informal communications with Congress.\\nThe president can be involved in crafting legislation by suggesting, requesting, or even insisting that Congress enact laws that the president believes are needed. Additionally, the president can attempt to shape legislation during the legislative process by exerting influence on individual members of Congress. Presidents possess this power because the Constitution is silent about who can write legislation, but the power is limited because only members of Congress can introduce legislation.\\nThe president or other officials of the executive branch may draft legislation and then ask senators or representatives to introduce these drafts into Congress. Additionally, the president may attempt to have Congress alter proposed legislation by threatening to veto that legislation unless requested changes are made.\\n\\nPromulgating regulations\\nMany laws enacted by Congress do not address every possible detail, and either explicitly or implicitly delegate powers of implementation to an appropriate federal agency. As the head of the executive branch, presidents control a vast array of agencies that can issue regulations with little oversight from Congress.\\nIn the 20th century, critics charged that too many legislative and budgetary powers that should have belonged to Congress had slid into the hands of presidents. One critic charged that presidents could appoint a \\\"virtual army of 'czars'—each wholly unaccountable to Congress yet tasked with spearheading major policy efforts for the White House\\\". Presidents have been criticized for making signing statements when signing congressional legislation about how they understand a bill or plan to execute it. This practice has been criticized by the American Bar Association as unconstitutional. 
Conservative commentator George Will wrote of an \\\"increasingly swollen executive branch\\\" and \\\"the eclipse of Congress\\\".\\n\\nConvening and adjourning Congress\\nTo allow the government to act quickly in case of a major domestic or international crisis arising when Congress is not in session, the president is empowered by Article II, Section 3 of the Constitution to call a special session of one or both houses of Congress. Since John Adams first did so in 1797, the president has called the full Congress to convene for a special session on 27 occasions. Harry S. Truman was the most recent to do so in July 1948, known as the Turnip Day Session. In addition, prior to ratification of the Twentieth Amendment in 1933, which brought forward the date on which Congress convenes from December to January, newly inaugurated presidents would routinely call the Senate to meet to confirm nominations or ratify treaties. In practice, the power has fallen into disuse in the modern era as Congress now formally remains in session year-round, convening pro forma sessions every three days even when ostensibly in recess. Correspondingly, the president is authorized to adjourn Congress if the House and Senate cannot agree on the time of adjournment; no president has ever had to exercise this power.\\n\\nExecutive powers\\nThe president is head of the executive branch of the federal government and is constitutionally obligated to \\\"take care that the laws be faithfully executed\\\". The executive branch has over four million employees, including the military.\\n\\nAdministrative powers\\nPresidents make political appointments. An incoming president may make up to 4,000 upon taking office, 1200 of which must be confirmed by the U.S. Senate. 
Ambassadors, members of the Cabinet, and various officers, are among the positions filled by presidential appointment with Senate confirmation.\\nThe power of a president to fire executive officials has long been a contentious political issue. Generally, a president may remove executive officials at will. However, Congress can curtail and constrain a president's authority to fire commissioners of independent regulatory agencies and certain inferior executive officers by statute.\\nTo manage the growing federal bureaucracy, presidents have gradually surrounded themselves with many layers of staff, who were eventually organized into the Executive Office of the President of the United States. Within the Executive Office, the president's innermost layer of aides, and their assistants, are located in the White House Office.\\nThe president also possesses the power to manage operations of the federal government by issuing various types of directives, such as presidential proclamation and executive orders. When the president is lawfully exercising one of the constitutionally conferred presidential responsibilities, the scope of this power is broad. Even so, these directives are subject to judicial review by U.S. federal courts, which can find them to be unconstitutional. Congress can overturn an executive order through legislation.\\n\\nForeign affairs\\nArticle II, Section 3, Clause 4 requires the president to \\\"receive Ambassadors.\\\" This clause, known as the Reception Clause, has been interpreted to imply that the president possesses broad power over matters of foreign policy, and to provide support for the president's exclusive authority to grant recognition to a foreign government. The Constitution also empowers the president to appoint United States ambassadors, and to propose and chiefly negotiate agreements between the United States and other countries. Such agreements, upon receiving the advice and consent of the U.S. 
Senate (by a two-thirds majority vote), become binding with the force of federal law.\\nWhile foreign affairs has always been a significant element of presidential responsibilities, advances in technology since the Constitution's adoption have increased presidential power. Where formerly ambassadors were vested with significant power to independently negotiate on behalf of the United States, presidents now routinely meet directly with leaders of foreign countries.\\n\\nCommander-in-chief\\nOne of the most important of executive powers is the president's role as commander-in-chief of the United States Armed Forces. The power to declare war is constitutionally vested in Congress, but the president has ultimate responsibility for the direction and disposition of the military. The exact degree of authority that the Constitution grants to the president as commander-in-chief has been the subject of much debate throughout history, with Congress at various times granting the president wide authority and at others attempting to restrict that authority. The framers of the Constitution took care to limit the president's powers regarding the military; Alexander Hamilton explained this in Federalist No. 69:The President is to be commander-in-chief of the army and navy of the United States. ... It would amount to nothing more than the supreme command and direction of the military and naval forces ... while that [the power] of the British king extends to the DECLARING of war and to the RAISING and REGULATING of fleets and armies, all [of] which ... would appertain to the legislature. [Emphasis in the original.]\\nIn the modern era, pursuant to the War Powers Resolution, Congress must authorize any troop deployments longer than 60 days, although that process relies on triggering mechanisms that have never been employed, rendering it ineffectual. Additionally, Congress provides a check to presidential military power through its control over military spending and regulation. 
Presidents have historically initiated the process for going to war, but critics have charged that there have been several conflicts in which presidents did not get official declarations, including Theodore Roosevelt's military move into Panama in 1903, the Korean War, the Vietnam War, and the invasions of Grenada in 1983 and Panama in 1989.\\nThe amount of military detail handled personally by the president in wartime has varied greatly. George Washington, the first U.S. president, firmly established military subordination under civilian authority. In 1794, Washington used his constitutional powers to assemble 12,000 militia to quell the Whiskey Rebellion, a conflict in Western Pennsylvania involving armed farmers and distillers who refused to pay an excise tax on spirits. According to historian Joseph Ellis, this was the \\\"first and only time a sitting American president led troops in the field\\\", though James Madison briefly took control of artillery units in defense of Washington, D.C., during the War of 1812. Abraham Lincoln was deeply involved in overall strategy and in day-to-day operations during the American Civil War, 1861–1865; historians have given Lincoln high praise for his strategic sense and his ability to select and encourage commanders such as Ulysses S. Grant.\\nThe present-day operational command of the Armed Forces is delegated to the Department of Defense and is normally exercised through the secretary of defense. The chairman of the Joint Chiefs of Staff and the Combatant Commands assist with the operation as outlined in the presidentially approved Unified Command Plan (UCP).\\n\\nJuridical powers and privileges\\nThe president has the power to nominate federal judges, including members of the United States courts of appeals and the Supreme Court of the United States. However, these nominations require Senate confirmation before they may take office. 
Securing Senate approval can provide a major obstacle for presidents who wish to orient the federal judiciary toward a particular ideological stance. When nominating judges to U.S. district courts, presidents often respect the long-standing tradition of senatorial courtesy. Presidents may also grant pardons and reprieves. Gerald Ford pardoned Richard Nixon a month after taking office. Presidents often grant pardons shortly before leaving office, like when Bill Clinton pardoned Patty Hearst on his last day in office; this is often controversial.\\nTwo doctrines concerning executive power have developed that enable the president to exercise executive power with a degree of autonomy. The first is executive privilege, which allows the president to withhold from disclosure any communications made directly to the president in the performance of executive duties. George Washington first claimed the privilege when Congress requested to see Chief Justice John Jay's notes from an unpopular treaty negotiation with Great Britain. While not enshrined in the Constitution or any other law, Washington's action created the precedent for the privilege. When Nixon tried to use executive privilege as a reason for not turning over subpoenaed evidence to Congress during the Watergate scandal, the Supreme Court ruled in United States v. Nixon, 418 U.S. 683 (1974), that executive privilege did not apply in cases where a president was attempting to avoid criminal prosecution. When Bill Clinton attempted to use executive privilege regarding the Lewinsky scandal, the Supreme Court ruled in Clinton v. Jones, 520 U.S. 681 (1997), that the privilege also could not be used in civil suits. These cases established the legal precedent that executive privilege is valid, although the exact extent of the privilege has yet to be clearly defined. 
Additionally, federal courts have allowed this privilege to radiate outward and protect other executive branch employees but have weakened that protection for those executive branch communications that do not involve the president.\\nThe state secrets privilege allows the president and the executive branch to withhold information or documents from discovery in legal proceedings if such release would harm national security. Precedent for the privilege arose early in the 19th century when Thomas Jefferson refused to release military documents in the treason trial of Aaron Burr and again in Totten v. United States 92 U.S. 105 (1876), when the Supreme Court dismissed a case brought by a former Union spy. However, the privilege was not formally recognized by the U.S. Supreme Court until United States v. Reynolds 345 U.S. 1 (1953), where it was held to be a common law evidentiary privilege. Before the September 11 attacks, use of the privilege had been rare, but increasing in frequency. Since 2001, the government has asserted the privilege in more cases and at earlier stages of the litigation, thus in some instances causing dismissal of the suits before reaching the merits of the claims, as in the Ninth Circuit's ruling in Mohamed v. Jeppesen Dataplan, Inc. Critics of the privilege claim its use has become a tool for the government to cover up illegal or embarrassing government actions.\\nThe degree to which the president personally has absolute immunity from court cases is contested and has been the subject of several Supreme Court decisions. Nixon v. Fitzgerald (1982) dismissed a civil lawsuit against by-then former president Richard Nixon based on his official actions. Clinton v. Jones (1997) decided that a president has no immunity against civil suits for actions taken before becoming president and ruled that a sexual harassment suit could proceed without delay, even against a sitting president. 
The 2019 Mueller report on Russian interference in the 2016 presidential election detailed evidence of possible obstruction of justice, but investigators declined to refer Donald Trump for prosecution based on a United States Department of Justice policy against indicting an incumbent president. The report noted that impeachment by Congress was available as a remedy. As of October 2019, a case was pending in the federal courts regarding access to personal tax returns in a criminal case brought against Donald Trump by the New York County District Attorney alleging violations of New York state law.\\n\\nLeadership roles\\nHead of state\\nAs head of state, the president represents the United States government to its own people and represents the nation to the rest of the world. For example, during a state visit by a foreign head of state, the president typically hosts a State Arrival Ceremony held on the South Lawn, a custom begun by John F. Kennedy in 1961. This is followed later in the evening by a state dinner hosted by the president in the State Dining Room.\\n\\nAs a national leader, the president also fulfills many less formal ceremonial duties. For example, William Howard Taft started the tradition of throwing out the ceremonial first pitch in 1910 at Griffith Stadium, Washington, D.C., on the Washington Senators' Opening Day. Every president since Taft, except for Jimmy Carter, has thrown out at least one ceremonial first ball or pitch for Opening Day, the All-Star Game, or the World Series, usually with much fanfare. Every president since Theodore Roosevelt has served as honorary president of the Boy Scouts of America.\\nOther presidential traditions are associated with American holidays. In 1878, Rutherford B. Hayes began the first White House egg rolling for local children. Beginning in 1947, during the Harry S. 
Truman administration, every Thanksgiving the president is presented with a live domestic turkey during the annual National Thanksgiving Turkey Presentation held at the White House. Since 1989, when the custom of \\\"pardoning\\\" the turkey was formalized by George H. W. Bush, the turkey has been taken to a farm where it will live out the rest of its natural life.\\nPresidential traditions also involve the president's role as head of government. Since James Buchanan, many outgoing presidents have traditionally given advice to their successors during the presidential transition. Ronald Reagan and his successors have also left a private message on the desk of the Oval Office on Inauguration Day for the incoming president.\\nThe modern presidency casts the president as one of the nation's premier celebrities. Some argue that images of the presidency tend to be manipulated by administration public relations officials as well as by presidents themselves. One critic described the presidency as \\\"propagandized leadership\\\" which has a \\\"mesmerizing power surrounding the office\\\". Administration public relations managers staged carefully crafted photo-ops of smiling presidents with smiling crowds for television cameras. One critic wrote that the image of John F. Kennedy was carefully framed \\\"in rich detail\\\" which \\\"drew on the power of myth\\\" regarding the incident of PT 109, and that Kennedy understood how to use images to further his presidential ambitions. As a result, some political commentators have opined that American voters have unrealistic expectations of presidents: voters expect a president to \\\"drive the economy, vanquish enemies, lead the free world, comfort tornado victims, heal the national soul and protect borrowers from hidden credit-card fees\\\".\\n\\nHead of party\\nThe president is typically considered to be the head of their political party. 
Since the entire House of Representatives and at least one-third of the Senate are elected simultaneously with the president, candidates from a political party inevitably have their electoral success intertwined with the performance of the party's presidential candidate. The coattail effect, or lack thereof, often affects a party's candidates at state and local levels of government as well. However, there are often tensions between a president and others in the party, and presidents who lose significant support from their party's caucus in Congress are generally viewed as weaker and less effective.\\n\\nGlobal leader\\nWith the rise of the United States as a superpower in the 20th century, and the United States having the world's largest economy into the 21st century, the president is typically viewed as a global leader, and at times the world's most powerful political figure. The position of the United States as the leading member of NATO, and the country's strong relationships with other wealthy or democratic nations like those comprising the European Union, have led to the president being called the \\\"leader of the free world\\\".\\n\\nSelection process\\nEligibility\\nArticle II, Section 1, Clause 5 of the Constitution sets three qualifications for holding the presidency. 
To serve as president, one must:\\n\\nbe a natural-born citizen of the United States;\\nbe at least 35 years old;\\nbe a resident in the United States for at least 14 years.\\nA person who meets the above qualifications would, however, still be disqualified from holding the office of president under any of the following conditions:\\n\\nUnder Article I, Section 3, Clause 7, having been impeached, convicted and disqualified from holding further public office, although there is some legal debate as to whether the disqualification clause also includes the presidential office: the only previous persons disqualified under this clause were three federal judges.\\nUnder Section 3 of the Fourteenth Amendment, no person who swore an oath to support the Constitution, and later rebelled against the United States, is eligible to hold any office. However, this disqualification can be lifted by a two-thirds vote of each house of Congress. There is, again, some debate as to whether the clause as written allows disqualification from the presidential position, or whether it would first require litigation outside of Congress, although there is precedent for use of this amendment outside of the original intended purpose of excluding Confederates from public office after the Civil War.\\nUnder the Twenty-second Amendment, no person can be elected president more than twice. The amendment also specifies that if any eligible person serves as president or acting president for more than two years of a term for which some other eligible person was elected president, the former can only be elected president once.\\n\\nCampaigns and nomination\\nThe modern presidential campaign begins before the primary elections, which the two major political parties use to clear the field of candidates before their national nominating conventions, where the most successful candidate is made the party's presidential nominee. 
Typically, the party's presidential candidate chooses a vice presidential nominee, and this choice is rubber-stamped by the convention. The most common previous profession of presidents is lawyer.\\nNominees participate in nationally televised debates, and while the debates are usually restricted to the Democratic and Republican nominees, third party candidates may be invited, such as Ross Perot in the 1992 debates. Nominees campaign across the country to explain their views, convince voters and solicit contributions. Much of the modern electoral process is concerned with winning swing states through frequent visits and mass media advertising drives.\\n\\nElection\\nThe president is elected indirectly by the voters of each state and the District of Columbia through the Electoral College, a body of electors formed every four years for the sole purpose of electing the president and vice president to concurrent four-year terms. As prescribed by Article II, Section 1, Clause 2, each state is entitled to a number of electors equal to the size of its total delegation in both houses of Congress. Additionally, the Twenty-third Amendment provides that the District of Columbia is entitled to the number it would have if it were a state, but in no case more than that of the least populous state. Currently, all states and the District of Columbia select their electors based on a popular election. In all but two states, the party whose presidential–vice presidential ticket receives a plurality of popular votes in the state has its entire slate of elector nominees chosen as the state's electors. Maine and Nebraska deviate from this winner-take-all practice, awarding two electors to the statewide winner and one to the winner in each congressional district.\\nOn the first Monday after the second Wednesday in December, about six weeks after the election, the electors convene in their respective state capitals (and in Washington, D.C.) 
to vote for president and, on a separate ballot, for vice president. They typically vote for the candidates of the party that nominated them. While there is no constitutional mandate or federal law requiring them to do so, the District of Columbia and 32 states have laws requiring that their electors vote for the candidates to whom they are pledged. The constitutionality of these laws was upheld in Chiafalo v. Washington (2020). Following the vote, each state sends a certified record of its electoral votes to Congress. The votes of the electors are opened and counted during a joint session of Congress, held in the first week of January. If a candidate has received an absolute majority of electoral votes for president (currently 270 of 538), that person is declared the winner. Otherwise, the House of Representatives must meet to elect a president using a contingent election procedure in which representatives, voting by state delegation, with each state casting a single vote, choose among the top three electoral vote-getters for president. To win the presidency, a candidate must receive the votes of an absolute majority of states (currently 26 of 50).\\nThere have been two contingent presidential elections in the nation's history. A 73–73 electoral vote tie between Thomas Jefferson and fellow Democratic-Republican Aaron Burr in the election of 1800 necessitated the first. It was conducted under the original procedure established by Article II, Section 1, Clause 3 of the Constitution, which stipulated that if two or three persons received a majority vote and an equal vote, the House of Representatives would choose one of them for president; the runner-up would become vice president. On February 17, 1801, Jefferson was elected president on the 36th ballot, and Burr was elected vice president. Afterward, the system was overhauled through the Twelfth Amendment in time to be used in the 1804 election. 
A quarter-century later, the choice for president again devolved to the House when no candidate won an absolute majority of electoral votes (131 of 261) in the election of 1824. Under the Twelfth Amendment, the House was required to choose a president from among the top three electoral vote recipients: Andrew Jackson, John Quincy Adams, and William H. Crawford. Held February 9, 1825, this second and most recent contingent election resulted in John Quincy Adams being elected president on the first ballot.\\n\\nInauguration\\nPursuant to the Twentieth Amendment, the four-year term of office for both the president and the vice president begins at noon on January 20, in the year following the preceding presidential election. The first presidential and vice presidential terms to begin on this date, known as Inauguration Day, were the second terms of President Franklin D. Roosevelt and Vice President John Nance Garner in 1937. Previously, Inauguration Day was on March 4. As a result of the date change, the first term (1933–37) of both men had been shortened by 43 days.\\nBefore executing the powers of the office, a president is required to recite the presidential Oath of Office, found in Article II, Section 1, Clause 8 of the Constitution. This is the only component in the inauguration ceremony mandated by the Constitution:\\n\\nI do solemnly swear (or affirm) that I will faithfully execute the Office of President of the United States, and will to the best of my ability, preserve, protect, and defend the Constitution of the United States.\\nPresidents have traditionally placed one hand upon a Bible while taking the oath, and have added \\\"So help me God\\\" to the end of the oath. 
Although the oath may be administered by any person authorized by law to administer oaths, presidents are traditionally sworn in by the chief justice of the United States.\\n\\nIncumbency\\nTerm limit\\nWhen the first president, George Washington, announced in his Farewell Address that he was not running for a third term, he established a \\\"two terms then out\\\" precedent. Precedent became tradition after Thomas Jefferson publicly embraced the principle a decade later during his second term, as did his two immediate successors, James Madison and James Monroe. In spite of the strong two-term tradition, Ulysses S. Grant sought the nomination at the 1880 Republican National Convention for a non-consecutive third term but was unsuccessful.\\nIn 1940, after leading the nation through the Great Depression, and while focused on supporting allied nations at war with the Axis powers, Franklin Roosevelt was elected to a third term, breaking the long-standing precedent. Four years later, with the U.S. engaged in World War II, he was re-elected despite his declining physical health; he died 82 days into his fourth term, on April 12, 1945.\\nIn response to the unprecedented length of Roosevelt's presidency, the Twenty-second Amendment was adopted in 1951. The amendment bars anyone from being elected president more than twice, or once if that person served more than two years (24 months) of another president's four-year term. Harry S. Truman, the president at the time it was submitted to the states by Congress, was exempted from its limitations. Without the exemption, he would not have been eligible to run for a second full term in 1952 (which he briefly sought), as he had served nearly all of Franklin Roosevelt's unexpired 1945–1949 term and had been elected to a full four-year term beginning in 1949. Since becoming operative in 1951, the amendment has been applicable to six twice-elected presidents: Dwight D. 
Eisenhower, Richard Nixon, Ronald Reagan, Bill Clinton, George W. Bush, and Barack Obama.\\n\\nVacancies and succession\\nUnder Section 1 of the Twenty-fifth Amendment, ratified in 1967, the vice president becomes president upon the removal from office, death, or resignation of the president. Death in office has occurred a number of times, resignation has occurred only once, and removal from office has never occurred.\\nBefore the ratification of the Twenty-fifth Amendment (which clarified the matter of succession), Article II, Section 1, Clause 6 stated only that the vice president assumes the \\\"powers and duties\\\" of the presidency in the event of a president's removal, death, resignation, or inability. Under this clause, there was ambiguity about whether the vice president would actually become president in the event of a vacancy, or simply act as president, potentially resulting in a special election. Upon the death of President William Henry Harrison in 1841, Vice President John Tyler declared that he had succeeded to the office itself, refusing to accept any papers addressed to the \\\"Acting President\\\", and Congress ultimately accepted his claim.\\nIn the event of a double vacancy, Article II, Section 1, Clause 6 also authorizes Congress to declare who shall become acting president in the \\\"Case of Removal, Death, Resignation or Inability, both of the president and vice president\\\". The Presidential Succession Act of 1947 (codified as 3 U.S.C. § 19) provides that if both the president and vice president have left office or are both otherwise unavailable to serve during their terms of office, the presidential line of succession follows this order: the speaker of the House; then, if necessary, the president pro tempore of the Senate; and then, if necessary, the eligible heads of federal executive departments who form the president's cabinet. 
The Cabinet currently has 15 members, of whom the secretary of state is first in line; the other Cabinet secretaries follow in the order in which their department (or the department of which their department is the successor) was created. Individuals who are constitutionally ineligible to be elected to the presidency are also disqualified from assuming the powers and duties of the presidency through succession. No statutory successor has yet been called upon to act as president.\\n\\nDeclarations of inability\\nUnder the Twenty-fifth Amendment, the president may temporarily transfer the presidential powers and duties to the vice president, who then becomes acting president, by transmitting to the speaker of the House and the president pro tempore of the Senate a statement that he or she is unable to discharge his or her duties. The president resumes his or her powers upon transmitting a second declaration stating that he or she is again able. The mechanism has been used by Ronald Reagan (once), George W. Bush (twice), and Joe Biden (once), each in anticipation of a medical procedure.\\nThe Twenty-fifth Amendment also provides that the vice president, together with a majority of certain members of the Cabinet, may transfer the presidential powers and duties to the vice president by transmitting a written declaration, to the speaker of the House and the president pro tempore of the Senate, to the effect that the president is unable to discharge his or her powers and duties. If the president then declares that no such inability exists, he or she resumes the presidential powers unless the vice president and Cabinet make a second declaration of presidential inability, in which case Congress decides the question.\\n\\nRemoval\\nArticle II, Section 4 of the Constitution allows for the removal of high federal officials, including the president, from office for \\\"treason, bribery, or other high crimes and misdemeanors\\\". 
Article I, Section 2, Clause 5 authorizes the House of Representatives to serve as a \\\"grand jury\\\" with the power to impeach such officials by a majority vote. Article I, Section 3, Clause 6 authorizes the Senate to serve as a court with the power to remove impeached officials from office by a two-thirds vote to convict.\\nThree presidents have been impeached by the House of Representatives: Andrew Johnson in 1868, Bill Clinton in 1998, and Donald Trump in 2019 and 2021; none have been convicted by the Senate. Additionally, the House Judiciary Committee conducted an impeachment inquiry against Richard Nixon in 1973–74 and reported three articles of impeachment to the House of Representatives for final action; however, he resigned from office before the House voted on them.\\n\\nCircumvention of authority\\nControversial measures have sometimes been taken short of removal to deal with perceived recklessness on the part of the president, or with a long-term disability. In some cases, staff have intentionally failed to deliver messages to or from the president, typically to avoid executing certain orders or prompting the president to write them. This has ranged from Richard Nixon's Chief of Staff not transmitting orders to the Cabinet due to the president's heavy drinking, to staff removing memos from Donald Trump's desk. Decades before the Twenty-fifth Amendment, in 1919, President Woodrow Wilson had a stroke that left him partly incapacitated. First lady Edith Wilson kept this condition a secret from the public for a while, and controversially became the sole gatekeeper for access to the president (aside from his doctor), assisting him with paperwork and deciding which information was \\\"important\\\" enough to share with him.\\n\\nCompensation\\nSince 2001, the president's annual salary has been $400,000, along with a $50,000 expense allowance, a $100,000 nontaxable travel account, and a $19,000 entertainment account. 
The president's salary is set by Congress, and under Article II, Section 1, Clause 7 of the Constitution, any increase or reduction in presidential salary cannot take effect before the next presidential term of office.\\n\\nResidence\\nThe Executive Residence of the White House in Washington, D.C., is the official residence of the president. The site was selected by George Washington, and the cornerstone was laid in 1792. Every president since John Adams (in 1800) has lived there. At various times in U.S. history, it has been known as the \\\"President's Palace\\\", the \\\"President's House\\\", and the \\\"Executive Mansion\\\". Theodore Roosevelt officially gave the White House its current name in 1901. The federal government pays for state dinners and other official functions, but the president pays for personal, family, and guest dry cleaning and food.\\nCamp David, officially titled Naval Support Facility Thurmont, a mountain-based military camp in Frederick County, Maryland, is the president's country residence. A place of solitude and tranquility, the site has been used extensively to host foreign dignitaries since the 1940s.\\nThe President's Guest House, located next to the Eisenhower Executive Office Building at the White House Complex and Lafayette Park, serves as the president's official guest house and as a secondary residence for the president if needed. Four interconnected 19th-century houses (Blair House, Lee House, and 700 and 704 Jackson Place), with a combined floor space exceeding 70,000 square feet (6,500 m²), comprise the property.\\n\\nTravel\\nThe primary means of long-distance air travel for the president is one of two identical Boeing VC-25 aircraft, which are extensively modified Boeing 747 airliners and are referred to as Air Force One while the president is on board (although any U.S. Air Force aircraft the president is aboard is designated as \\\"Air Force One\\\" for the duration of the flight). 
In-country trips are typically handled with just one of the two planes, while overseas trips are handled with both, one primary and one backup. The president also has access to smaller Air Force aircraft, most notably the Boeing C-32, which are used when the president must travel to airports that cannot support a jumbo jet. Any civilian aircraft the president is aboard is designated Executive One for the flight.\\nFor short-distance air travel, the president has access to a fleet of U.S. Marine Corps helicopters of varying models, designated Marine One when the president is aboard any particular one in the fleet. Flights typically involve as many as five helicopters flying together and frequently swapping positions so as to disguise from any would-be threats which helicopter the president is actually aboard.\\nFor ground travel, the president uses the presidential state car, which is an armored limousine designed to look like a Cadillac sedan but built on a truck chassis. The U.S. Secret Service operates and maintains the fleet of several limousines. The president also has access to two armored motorcoaches, which are primarily used for touring trips.\\n\\nProtection\\nThe U.S. Secret Service is charged with protecting the president and the first family. As part of their protection, presidents, first ladies, their children and other immediate family members, and other prominent persons and locations are assigned Secret Service codenames. The use of such names was originally for security purposes and dates to a time when sensitive electronic communications were not routinely encrypted; today, the names simply serve for purposes of brevity, clarity, and tradition.\\n\\nPost-presidency\\nActivities\\nSome former presidents have had significant careers after leaving office. 
Prominent examples include William Howard Taft's tenure as chief justice of the United States and Herbert Hoover's work on government reorganization after World War II. Grover Cleveland, whose bid for reelection failed in 1888, was elected president again four years later in 1892. Two former presidents served in Congress after leaving the White House: John Quincy Adams was elected to the House of Representatives, serving there for 17 years, and Andrew Johnson returned to the Senate in 1875, though he died soon after. Some ex-presidents were very active, especially in international affairs, most notably Theodore Roosevelt, Herbert Hoover, Richard Nixon, and Jimmy Carter.\\nPresidents may use their predecessors as emissaries to deliver private messages to other nations or as official representatives of the United States to state funerals and other important foreign events. Richard Nixon made multiple foreign trips to countries including China and Russia and was lauded as an elder statesman. Jimmy Carter has become a global human rights campaigner, international arbiter, and election monitor, as well as a recipient of the Nobel Peace Prize. Bill Clinton has also worked as an informal ambassador, most recently in the negotiations that led to the release of two American journalists, Laura Ling and Euna Lee, from North Korea. During his presidency, George W. Bush called on former Presidents George H. W. Bush and Bill Clinton to assist with humanitarian efforts after the 2004 Indian Ocean earthquake and tsunami. President Obama followed suit by asking Presidents Clinton and George W. Bush to lead efforts to aid Haiti after an earthquake devastated that country in 2010.\\nClinton has been politically active since his presidential term ended, working with his wife Hillary on her 2008 and 2016 presidential bids and with President Obama on his 2012 reelection campaign. 
Obama has also been politically active since his presidential term ended, having worked with his former vice president Joe Biden on Biden's 2020 election campaign. Trump has continued to make appearances in the media and at conferences and rallies since leaving office in 2021. He is currently running for a non-consecutive second term in the upcoming 2024 presidential election.\\n\\nPension and other benefits\\nThe Former Presidents Act (FPA), enacted in 1958, grants lifetime benefits to former presidents and their widows, including a monthly pension, medical care in military facilities, health insurance, and Secret Service protection; also provided is funding for a certain number of staff and for office expenses. The act has been amended several times to provide increases in presidential pensions and in the allowances for office staff. The FPA excludes any president who was removed from office by impeachment.\\nAccording to a 2008 report by the Congressional Research Service:\\n\\nChief executives leaving office prior to 1958 often entered retirement pursuing various occupations and received no federal assistance. When industrialist Andrew Carnegie announced a plan in 1912 to offer $25,000 annual pensions to former Presidents, many Members of Congress deemed it inappropriate that such a pension would be provided by a private corporation executive. That same year, legislation was first introduced to create presidential pensions, but it was not enacted. In 1955, such legislation was considered by Congress because of former President Harry S. Truman's financial limitations in hiring an office staff.\\nThe pension has increased numerous times with congressional approval. Retired presidents receive a pension based on the salary of the current administration's cabinet secretaries, which was $199,700 per year in 2012. Former presidents who served in Congress may also collect congressional pensions. 
The act also provides former presidents with travel funds and franking privileges.\\nPrior to 1997, all former presidents, their spouses, and their children until age 16 were protected by the Secret Service until the president's death. In 1997, Congress passed legislation limiting Secret Service protection to no more than 10 years from the date a president leaves office. On January 10, 2013, President Obama signed legislation reinstating lifetime Secret Service protection for himself, George W. Bush, and all subsequent presidents. A first spouse who remarries is no longer eligible for Secret Service protection.\\n\\nPresidential libraries\\nEvery president since Herbert Hoover has created a repository known as a presidential library for preserving and making available his papers, records, and other documents and materials. Completed libraries are deeded to and maintained by the National Archives and Records Administration (NARA); the initial funding for building and equipping each library must come from private, non-federal sources. There are currently thirteen presidential libraries in the NARA system. There are also presidential libraries maintained by state governments, private foundations, and universities, including:\\n\\nThe Abraham Lincoln Presidential Library and Museum, which is run by the State of Illinois;\\nThe George W. Bush Presidential Library and Museum, which is run by Southern Methodist University;\\nThe George H. W. Bush Presidential Library and Museum, which is run by Texas A&M University; and\\nThe Lyndon Baines Johnson Presidential Library and Museum, which is run by the University of Texas at Austin.\\nSeveral former presidents have overseen the building and opening of their own presidential libraries. Some even made arrangements for their own burial at the site. Several presidential libraries contain the graves of the presidents they document:\\n\\nThe Harry S. 
Truman Presidential Library and Museum in Independence, Missouri;\\nThe Dwight D. Eisenhower Presidential Library, Museum and Boyhood Home in Abilene, Kansas;\\nThe Richard Nixon Presidential Library and Museum in Yorba Linda, California; and\\nThe Ronald Reagan Presidential Library and Museum in Simi Valley, California.\\nThese gravesites are open to the general public.\\n\\nPolitical affiliation\\nPolitical parties have dominated American politics for most of the nation's history. Though the Founding Fathers generally spurned political parties as divisive and disruptive, and their rise had not been anticipated when the U.S. Constitution was drafted in 1787, organized political parties developed in the U.S. in the mid-1790s nonetheless. They evolved from political factions, which began to appear almost immediately after the Federal government came into existence. Those who supported the Washington administration were referred to as \\\"pro-administration\\\" and would eventually form the Federalist Party, while those in opposition largely joined the emerging Democratic-Republican Party.\\nGreatly concerned about the very real capacity of political parties to destroy the fragile unity holding the nation together, Washington remained unaffiliated with any political faction or party throughout his eight-year presidency. He was, and remains, the only U.S. president never to be affiliated with a political party. Since Washington, every U.S. 
president has been affiliated with a political party at the time of assuming office.\\nThe number of presidents per political party by their affiliation at the time they were first sworn into office (alphabetical, by last name) are:\\n\\nTimeline of presidents\\nThe following timeline depicts the progression of the presidents and their political affiliation at the time of assuming office.\\n\\nJames Buchanan Jr. (bew-KAN-ən; April 23, 1791 – June 1, 1868) was the 15th president of the United States, serving from 1857 to 1861. Buchanan also served as secretary of state from 1845 to 1849 and represented Pennsylvania in both houses of the U.S. Congress. He was an advocate for states' rights, particularly regarding slavery, and minimized the role of the federal government preceding the Civil War.\\nBuchanan was a lawyer in Pennsylvania and won his first election to the state's House of Representatives as a Federalist. He was elected to the U.S. House of Representatives in 1820 and retained that post for five terms, aligning with Andrew Jackson's Democratic Party. Buchanan served as Jackson's minister to Russia in 1832. He was elected a U.S. senator from Pennsylvania in 1834 and held that seat for 11 years. He was appointed to serve as President James K. Polk's secretary of state in 1845, and eight years later was named as President Franklin Pierce's minister to the United Kingdom.\\nBeginning in 1844, Buchanan became a regular contender for the Democratic Party's presidential nomination. He won the nomination and the 1856 presidential election. As president, Buchanan intervened to assure the Supreme Court's majority ruling in the pro-slavery decision in the Dred Scott case. 
He acceded to Southern attempts to engineer Kansas' entry into the Union as a slave state under the Lecompton Constitution, angering not only Republicans but also Northern Democrats. Buchanan honored his pledge to serve only one term and supported Breckinridge's unsuccessful candidacy in the 1860 presidential election. Amid his grudge against Stephen Douglas, he failed to reconcile the fractured Democratic Party, which helped lead to the election of Republican and former Congressman Abraham Lincoln.\\nBuchanan's leadership during his lame-duck period, before the American Civil War, has been widely criticized. He simultaneously angered the North by not stopping secession and the South by not yielding to its demands. He supported the Corwin Amendment in an effort to reconcile the country. He made an unsuccessful attempt to reinforce Fort Sumter, but otherwise refrained from preparing the military. In his personal life, Buchanan never married and was the only U.S. president to remain a lifelong bachelor, leading some historians and authors to question his sexual orientation. His failure to forestall the Civil War has been described as incompetence, and he spent his last years defending his reputation. Historians and scholars rank Buchanan among the worst presidents in American history.\\n\\nEarly life\\nChildhood and education\\nJames Buchanan Jr. was born into a Scottish-Irish family on April 23, 1791, in a log cabin on a farm called Stony Batter, near Cove Gap, Peters Township, in the Allegheny Mountains of southern Pennsylvania. He was the last president born in the 18th century and, until the election of Joe Biden in 2020, the only one born in Pennsylvania. Buchanan was the second of eleven children, with six sisters and four brothers, and the eldest son of James Buchanan Sr. (1761–1821) and his wife Elizabeth Speer (1767–1833).
James Buchanan Sr. was an Ulster-Scot from just outside Ramelton, a small town in the north-east of County Donegal in the north-west of Ulster, the northern province of Ireland, who emigrated to the newly formed United States in 1783, having sailed from Derry. He belonged to the Clan Buchanan, whose members had emigrated in large numbers from the Scottish Highlands to Ulster in the north of Ireland during the Plantation of Ulster in the seventeenth century and, later, largely because of poverty and persecution by the Crown due to their Presbyterian faith, had further emigrated in large numbers from Ulster to America from the early eighteenth century onwards. Shortly after Buchanan's birth, the family relocated to a farm near Mercersburg, Pennsylvania, and later settled in the town in 1794. His father became the area's wealthiest resident, working as a merchant, farmer, and real estate investor. Buchanan attributed his early education primarily to his mother, whereas his father had a greater influence on his character. His mother had discussed politics with him when he was a child and had an interest in poetry, quoting John Milton and William Shakespeare to Buchanan.\\nBuchanan attended the Old Stone Academy in Mercersburg and then Dickinson College in Carlisle, Pennsylvania. In 1808, he was nearly expelled for disorderly conduct; he and his fellow students had attracted negative attention for drinking in local taverns, disturbing the peace at night and committing acts of vandalism, but he pleaded for a second chance and ultimately graduated with honors in 1809. Later that year, he moved to the state capital at Lancaster to train as a lawyer for two and a half years with the well-known James Hopkins.
Following the fashion of the time, Buchanan studied the United States Code and the Constitution of the United States as well as legal authorities such as William Blackstone during his education.\\n\\nEarly law practice and Pennsylvania House of Representatives\\nIn 1812, Buchanan passed the bar exam and, after being admitted to the bar, remained in Lancaster even when Harrisburg became the new capital of Pennsylvania. Buchanan quickly established himself as a prominent legal representative in the city. His income rose rapidly after he established his practice, and by 1821 he was earning over $11,000 per year (equivalent to $250,000 in 2023). At this time, Buchanan became a Freemason and served as the Worshipful Master of Masonic Lodge No. 43 in Lancaster and as a District Deputy Grand Master of the Grand Lodge of Pennsylvania.\\nBuchanan also served as chairman of the Lancaster chapter of the Federalist Party. Like his father, he supported the party's political program, which called for federal funding of building projects, import duties, and the re-establishment of a central bank after the First Bank of the United States' charter expired in 1811. He became a strong critic of Democratic-Republican President James Madison during the War of 1812. Although he did not himself serve in a militia during the War of 1812, during the British occupation he joined a group of young men who stole horses for the United States Army in the Baltimore area. He was the last president involved in the War of 1812.\\nIn 1814, he was elected to the Pennsylvania House of Representatives as a Federalist, becoming its youngest member, and held this seat until 1816. Since the sessions of the Pennsylvania General Assembly lasted only three months, Buchanan continued practicing law profitably by charging higher fees, and his service helped him acquire more clients.
In 1815, Buchanan defended District Judge Walter Franklin against charges of judicial misconduct in an impeachment trial before the Pennsylvania Senate. Impeachments were more common at the time because the line between abuse of office and an erroneous legal decision depended on the ruling parties' preferences and the popularity of the judge's decision. Buchanan persuaded the senators that only judicial crimes and clear violations of the law justified impeachment.\\n\\nCongressional career\\nU.S. House of Representatives\\nIn the congressional elections of 1820, Buchanan ran for a seat in the House of Representatives. Shortly after his election victory, his father died in a carriage accident. As a young Representative, Buchanan was one of the most prominent leaders of the \\\"Amalgamator party\\\" faction of Pennsylvanian politics, so named because it was made up of both Democratic-Republicans and former Federalists as the country transitioned from the First Party System to the Era of Good Feelings. During this era, the Democratic-Republicans became the most influential party. Buchanan's Federalist convictions were weak, and he switched parties after opposing a nativist Federalist bill. During the 1824 presidential election, Buchanan initially supported Henry Clay, but switched to Andrew Jackson (with Clay as a second choice) when it became clear that the Pennsylvanian public overwhelmingly preferred Jackson. After Jackson lost the 1824 election, Buchanan joined Jackson's faction, but Jackson held Buchanan in contempt because he misinterpreted Buchanan's efforts to mediate between the Clay and Jackson camps.\\nIn Washington, Buchanan became an avid defender of states' rights and was close to many southern Congressmen, viewing some New England Congressmen as dangerous radicals. Buchanan's close proximity to his constituency allowed him to establish a Democratic coalition in Pennsylvania, consisting of former Federalist farmers, Philadelphia artisans, and Ulster-Scots-Americans.
In the 1828 presidential election, he secured Pennsylvania, while the \\\"Jacksonian Democrats\\\", an independent party after splitting from the National Republican Party, won an easy victory in the parallel congressional election.\\n\\nBuchanan gained the most attention during an impeachment trial in which he acted as prosecutor against federal district judge James H. Peck; however, the Senate rejected Buchanan's plea and acquitted Peck by a majority vote. He was appointed to the Agriculture Committee in his first year, and he eventually became chairman of the Judiciary Committee. In 1831, Buchanan declined a nomination for the 22nd United States Congress from his constituency consisting of Dauphin, Lebanon, and Lancaster counties. He still had political ambitions, and some Pennsylvania Democrats put him forward as a candidate for the vice presidency in the 1832 election.\\n\\nMinister to Russia\\nAfter Jackson was re-elected in 1832, he offered Buchanan the position of United States Ambassador to Russia. Buchanan was reluctant to leave the country, as a posting to distant St. Petersburg amounted to a kind of political exile, which was precisely Jackson's intention; Jackson considered Buchanan an \\\"incompetent busybody\\\" and untrustworthy. Buchanan nevertheless agreed. His work focused on concluding a trade and shipping treaty with Russia. While Buchanan succeeded with the trade agreement, negotiating an agreement on free merchant shipping with Foreign Minister Karl Nesselrode proved difficult. He had denounced Tsar Nicholas I as a despot only a year earlier, during his tenure in Congress; many Americans had reacted negatively to Russia's suppression of the 1830 Polish uprising.\\n\\nU.S. Senator\\nBuchanan returned home and lost the election in the State Legislature for a full six-year term in the 23rd Congress, but was appointed by the Pennsylvania state legislature to succeed William Wilkins in the U.S. Senate. Wilkins, in turn, replaced Buchanan as the ambassador to Russia.
The Jacksonian Buchanan, who was re-elected in 1836 and 1842, opposed the re-chartering of the Second Bank of the United States and sought to expunge a congressional censure of Jackson stemming from the Bank War. Buchanan served in the Senate until March 1845 and was twice confirmed in office. To unite Pennsylvania Democrats at the State Convention, he was chosen as their candidate for the National Convention. Buchanan maintained a strict adherence to the Pennsylvania State Legislature's guidelines and sometimes voted against positions in Congress which he promoted in his own speeches, despite open ambitions for the White House.\\nBuchanan was known for his commitment to states' rights and the Manifest Destiny ideology. He rejected President Martin Van Buren's offer to become United States Attorney General and chaired prestigious Senate committees such as the Committee on the Judiciary and the Committee on Foreign Relations. Buchanan was one of only a few senators to vote against the Webster–Ashburton Treaty for its \\\"surrender\\\" of lands to the United Kingdom, as he demanded the entire Aroostook River Valley for the United States. In the Oregon Boundary Dispute, Buchanan adopted the maximum demand of 54°40′ as the northern border and spoke out in favor of annexing the Republic of Texas. During the contentious 1838 Pennsylvania gubernatorial election, Buchanan chose to support the Democratic challenger, David Rittenhouse Porter, who was elected by fewer than 5,500 votes as Pennsylvania's first governor under the state's revised Constitution of 1838.\\nBuchanan also opposed a gag rule sponsored by John C. Calhoun that would have suppressed anti-slavery petitions. He joined the majority in blocking the rule, with most senators of the belief that it would have the reverse effect of strengthening the abolitionists. 
He said, \\\"We have just as little right to interfere with slavery in the South, as we have to touch the right of petition.\\\" Buchanan thought that the issue of slavery was the domain of the states, and he faulted abolitionists for exciting passions over the issue. In the lead-up to the 1844 Democratic National Convention, Buchanan positioned himself as a potential alternative to former President Martin Van Buren, but the nomination went to James K. Polk, who won the election.\\n\\nDiplomatic career\\nSecretary of State\\nBuchanan was offered the position of Secretary of State in the Polk administration or, alternatively, a seat on the Supreme Court, both to reward him for his support in the election campaign and to eliminate him as a rival within the party. He accepted the State Department post and served for the duration of Polk's single term in office. During his tenure, the United States recorded its largest territorial gain in history through the Oregon Treaty and the Treaty of Guadalupe Hidalgo, which included territory that is now Texas, California, Nevada, New Mexico, Arizona, Utah, and Colorado. In negotiations with Britain over Oregon, Buchanan initially favored the 49th parallel as the boundary of Oregon Territory, while Polk called for a more northerly boundary line. When Northern Democrats rallied around the popular slogan Fifty-Four Forty or Fight (\\\"54°40′ or war\\\") in the 1844 election campaign, Buchanan adopted this position, but later followed Polk's direction, leading to the Oregon Compromise of 1846, which established the 49th parallel as the boundary in the Pacific Northwest.\\nWith regard to Mexico, Buchanan held the dubious view that the Mexican attack on American troops on the other side of the Rio Grande in April 1846 constituted a border violation and a legitimate reason for war.
During the Mexican-American War, Buchanan initially advised against claiming territory south of the Rio Grande, fearing war with Britain and France. However, as the war came to an end, Buchanan changed his mind and argued for annexing further territory, contending that Mexico was to blame for the war and that the compensation negotiated for the American losses was too low. Buchanan sought the nomination at the 1848 Democratic National Convention, as Polk had promised to serve only one term, but he won the support of only the Pennsylvania and Virginia delegations, so Senator Lewis Cass of Michigan was nominated.\\n\\nCivilian life and 1852 presidential election\\nWith the 1848 election of Whig Zachary Taylor, Buchanan returned to private life. Buchanan was getting on in years and still dressed in the old-fashioned style of his youth, which earned him the nickname \\\"Old Public Functionary\\\" from the press. Slavery opponents in the North mocked him as a relic of prehistoric man because of his moral values. He bought the Wheatland estate on the outskirts of Lancaster and entertained various visitors while monitoring political events. During this period, Buchanan became the center of a family network consisting of 22 nieces, nephews and their descendants, seven of whom were orphans. He found public service jobs for some through patronage, and for those in his favor, he took on the role of surrogate father. He formed the strongest emotional bond with his niece Harriet Lane, who later served as First Lady in Buchanan's White House.\\nIn 1852, he was named president of the Board of Trustees of Franklin and Marshall College in Lancaster, and he served in this capacity until 1866. Buchanan did not completely leave politics. He intended to publish a collection of speeches and an autobiography, but his political comeback was thwarted by the 1852 presidential election.
Buchanan traveled to Washington to discuss Pennsylvania Democratic Party politics, which were divided into two camps led by Simon Cameron and George Dallas. He quietly campaigned for the 1852 Democratic presidential nomination. In light of the Compromise of 1850, which had led to the admission of California into the Union as a free state and to a stricter Fugitive Slave Act, Buchanan now rejected the Missouri Compromise and welcomed Congress's rejection of the Wilmot Proviso, which would have prohibited slavery in all territories gained in the Mexican-American War. Buchanan criticized abolitionism as a fanatical attitude and believed that slavery should be decided by state legislatures, not Congress. He disliked abolitionist Northerners due to his party affiliation, and became known as a \\\"doughface\\\" due to his sympathy toward the South. Buchanan emerged as a promising candidate for the Democratic presidential nomination, alongside Lewis Cass, Stephen Douglas, and William L. Marcy; however, the Pennsylvania convention did not vote unanimously in his favor, with over 30 delegates protesting against him. At the 1852 Democratic National Convention, he won the support of many southern delegates but failed to win the two-thirds support needed for the presidential nomination, which went to Franklin Pierce. Buchanan declined to serve as the vice presidential nominee, and the convention instead nominated his close friend, William R. King.\\n\\nMinister to the United Kingdom\\nPierce won the election in 1852, and six months later, Buchanan accepted the position of United States Minister to the United Kingdom, a post that represented a step backward in his career and that he had twice rejected before. Buchanan sailed for England in the summer of 1853, and he remained abroad for the next three years.
In 1850, the United States and Great Britain signed the Clayton–Bulwer Treaty, which committed both countries to joint control of any future canal that would connect the Atlantic and Pacific Oceans through Central America. Buchanan met repeatedly with Lord Clarendon, the British foreign minister, in hopes of pressuring the British to withdraw from Central America. He was able to reduce British influence in Honduras and Nicaragua while also making the British government more aware of American interests in the region. He also focused on the potential annexation of Cuba, which had long interested him.\\nAt Pierce's prompting, Buchanan met in Ostend, Belgium, with U.S. Ambassador to Spain Pierre Soulé and U.S. Ambassador to France John Mason to work out a plan for the acquisition of Cuba. The result was a draft memorandum, called the Ostend Manifesto, which proposed the purchase of Cuba from Spain, then in the midst of revolution and near bankruptcy. The document declared the island \\\"as necessary to the North American republic as any of its present ... family of states\\\". Against Buchanan's recommendation, the final draft of the manifesto suggested that \\\"wresting it from Spain\\\", if Spain refused to sell, would be justified \\\"by every law, human and Divine\\\". The manifesto was met with a divided response and was never acted upon. It weakened the Pierce administration and reduced support for Manifest Destiny. In 1855, as Buchanan's desire to return home grew, Pierce asked him to remain at his post in London in light of the relocation of a British fleet to the Caribbean.\\n\\nElection of 1856\\nBuchanan's service abroad allowed him to conveniently avoid the debate over the Kansas–Nebraska Act then roiling the country in the slavery dispute. While he did not overtly seek the presidency, he assented to the movement on his behalf. While still in England, he campaigned by praising John Joseph Hughes, the Archbishop of New York, to a Catholic archbishop.
The latter campaigned for Buchanan among high-ranking Catholics as soon as he heard about it. When Buchanan arrived home at the end of April 1856, he led on the first ballot, supported by powerful Senators John Slidell, Jesse Bright, and Thomas F. Bayard, who presented Buchanan as an experienced leader appealing to the North and South. The 1856 Democratic National Convention met in June 1856, producing a platform that reflected Buchanan's views, including support for the Fugitive Slave Law, which required the return of escaped slaves. The platform also called for an end to anti-slavery agitation and U.S. \\\"ascendancy in the Gulf of Mexico\\\". President Pierce hoped for re-nomination, while Senator Stephen A. Douglas also loomed as a strong candidate. Buchanan won the nomination on the seventeenth ballot, after Douglas withdrew. He was joined on the ticket by John C. Breckinridge of Kentucky in order to maintain regional balance; the choice also placated supporters of Pierce and Douglas, who were allies of Breckinridge.\\nBuchanan faced two opponents in the general election: former Whig President Millard Fillmore ran as the candidate of the anti-Catholic, anti-immigrant American Party (or \\\"Know-Nothing\\\"), while John C. Frémont ran as the Republican nominee. The contrast between Buchanan and Frémont was particularly stark, with opposing caricaturists drawing the Democratic candidate as a fussy old man in drag. Buchanan did not actively campaign, but he wrote letters and pledged to uphold the Democratic platform. In the election, he carried every slave state except Maryland, as well as five free states, including his home state of Pennsylvania. He won 45 percent of the popular vote and decisively won the electoral vote, taking 174 of 296 votes. His election made him the first president from Pennsylvania.
In a combative victory speech, Buchanan denounced Republicans, calling them a \\\"dangerous\\\" and \\\"geographical\\\" party that had unfairly attacked the South. He also declared, \\\"the object of my administration will be to destroy sectional party, North or South, and to restore harmony to the Union under a national and conservative government.\\\" He set about this initially by feigning a sectional balance in his cabinet appointments.\\n\\nPresidency (1857–1861)\\nInauguration\\nBuchanan was inaugurated on March 4, 1857, taking the oath of office from Chief Justice Roger B. Taney. In his lengthy inaugural address, Buchanan committed himself to serving only one term, as his predecessor had done. He abhorred the growing divisions over slavery and its status in the territories, saying that Congress should play no role in determining the status of slavery in the states or territories. He proposed a solution based on the Kansas–Nebraska Act, which held that the principle of popular sovereignty was decisive and that Congress had no say in the matter. Buchanan recommended that a federal slave code be enacted to protect the rights of slaveowners in federal territories. He alluded to a then-pending Supreme Court case, Dred Scott v. Sandford, which he said would permanently settle the issue of slavery. Dred Scott was a slave who was temporarily taken from a slave state to a free territory by his owner, John Emerson. After Scott returned to the slave state, he filed a petition for his freedom based on his time in the free territory.\\n\\nAssociate Justice Robert C. Grier leaked the decision in the \\\"Dred Scott\\\" case to Buchanan ahead of time. In his inaugural address, Buchanan declared that the issue of slavery in the territories would be \\\"speedily and finally settled\\\" by the Supreme Court. According to historian Paul Finkelman: Buchanan already knew what the Court was going to decide.
In a major breach of Court etiquette, Justice Grier, who, like Buchanan, was from Pennsylvania, had kept the President-elect fully informed about the progress of the case and the internal debates within the Court. When Buchanan urged the nation to support the decision, he already knew what Taney would say. Republican suspicions of impropriety turned out to be fully justified.\\nHistorians agree that the court decision was a major disaster because it dramatically inflamed tensions, leading to the Civil War. In 2022, historian David W. Blight argued that the year 1857 was \\\"the great pivot on the road to disunion...largely because of the Dred Scott case, which stoked the fear, distrust and conspiratorial hatred already common in both the North and the South to new levels of intensity.\\\"\\n\\nPersonnel\\nCabinet and administration\\nAs his inauguration approached, Buchanan sought to establish an obedient, harmonious cabinet to avoid the infighting that had plagued Andrew Jackson's administration. The cabinet's composition had to reflect proportional representation among the party's factions and the country's regions. Buchanan first worked on this task in Wheatland until he traveled to the capital in January 1857. There, like many other guests at the National Hotel, he contracted severe dysentery, from which he did not fully recover until several months later. Dozens of those who fell ill died, including Buchanan's nephew and private secretary Eskridge Lane.\\nThe cabinet selection was disastrous: four of its Southern members were large-scale slaveholders who later sided with the Confederate States of America. Secretary of the Treasury Howell Cobb was considered the greatest political talent in the Cabinet, while the three department heads from the northern states were all considered to be doughfaces. Buchanan's objective was to dominate the cabinet, and he chose men who would agree with his views.
Buchanan had a troubled relationship with his vice president from the beginning: when Breckinridge made his inaugural visit, Buchanan did not receive him but referred him to his niece, the First Lady, a slight that Breckinridge saw as disrespectful and never forgave. When filling cabinet posts, Buchanan left out the influential Stephen A. Douglas, who had made Buchanan's nomination possible by withdrawing at the National Convention the previous year. Concentrating on foreign policy, he appointed the aging Lewis Cass as Secretary of State. Buchanan's appointment of Southerners and their allies alienated many in the North, and his failure to appoint any followers of Douglas divided the party. Outside the cabinet, he left in place many of Pierce's appointments but removed a disproportionate number of Northerners who had ties to his Democratic rivals Pierce or Douglas.\\n\\nJudicial appointments\\nBuchanan appointed one Justice, Nathan Clifford, to the Supreme Court of the United States. He appointed seven other federal judges to United States district courts. He also appointed two judges to the United States Court of Claims.\\n\\nIntervention in the Dred Scott case\\nThe case of Dred Scott v. Sandford, to which Buchanan referred in his inaugural address, dated back to 1846. Scott sued for his freedom in Missouri, claiming that he had lived in service to his owner in Illinois and Wisconsin Territory. The case reached the Supreme Court and gained national attention by 1856. Buchanan consulted with Justice John Catron in January 1857, inquiring about the outcome of the case and suggesting that a broader decision, beyond the specifics of the case, would be more prudent.
Buchanan hoped that a broad decision protecting slavery in the territories could lay the issue to rest, allowing him to focus on other issues.\\nCatron replied on February 10, saying that the Supreme Court's Southern majority would decide against Scott, but would likely have to publish the decision on narrow grounds unless Buchanan could convince his fellow Pennsylvanian, Justice Robert Cooper Grier, to join the majority of the court. Buchanan then wrote to Grier and prevailed upon him, providing the majority leverage to issue a broad-ranging decision sufficient to render the Missouri Compromise of 1820 unconstitutional.\\nTwo days after Buchanan was sworn in as president, Chief Justice Taney delivered the Dred Scott decision, which denied the petitioner's request to be set free from slavery. The ruling broadly asserted that Congress had no constitutional power to exclude slavery in the territories. According to this decision, slaves were forever the property of their owners without rights and no African American could ever be a full citizen of the United States, even if they had full civil rights in a state. Buchanan's letters were not made public at the time, but he was seen conversing quietly with the Chief Justice during his inauguration. When the decision was issued, Republicans began spreading the word that Taney had informed Buchanan of the impending outcome. Rather than destroying the Republican platform as Buchanan had hoped, the decision infuriated Northerners, who condemned it.\\n\\nPanic of 1857\\nThe Panic of 1857 began in the summer of that year, when the New York branch of Ohio Life Insurance and Trust Company announced its insolvency. The crisis spread rapidly, and by the fall, 1,400 state banks and 5,000 businesses had gone bankrupt. Unemployment and hunger became common in northern cities, but the agricultural south was more resilient. 
Buchanan agreed with the southerners who attributed the economic collapse to over-speculation.\\nBuchanan acted in accordance with the principles of Jacksonian Democracy, which restricted the issuance of paper money, and froze federal funds for public works projects, causing resentment among some of the population due to his refusal to implement an economic stimulus program. While the government was \\\"without the power to extend relief\\\", it would continue to pay its debts in specie, and while it would not curtail public works, none would be added. In hopes of reducing paper money supplies and inflation, he urged the states to restrict the banks to a credit level of $3 to $1 of specie and discouraged the use of federal or state bonds as security for bank note issues. The economy recovered within several years, though many Americans suffered as a result of the panic. Buchanan had hoped to reduce the deficit, but by the time he left office the federal budget had grown by 15%.\\n\\nUtah War\\nIn the spring of 1857, the Latter-day Saints and their leader Brigham Young had been challenging federal representatives in Utah Territory, and there was harassment of and violence against non-Mormons. Young harassed federal officers and discouraged outsiders from settling in the Salt Lake City area. In September 1857, the Utah Territorial Militia, associated with the Latter-day Saints, perpetrated the Mountain Meadows massacre, in which Young's militia attacked a wagon train and killed 125 settlers. Buchanan was offended by Young's militarism and polygamy. In response to reports of violence against non-Mormons, Buchanan authorized a military expedition into Utah Territory in late March 1857 to replace Young as governor. The force consisted of 2,500 men, including Alfred Cumming and his staff, and was commanded by General William S. Harney.
Complicating matters, Young's notice of his replacement was not delivered because the Pierce administration had annulled the Utah mail contract, and Young portrayed the approaching forces as an unauthorized overthrow.\\nBuchanan's personnel decision incited resistance from the Mormons around Young, as Harney was known for his volatility and brutality. In August 1857, Albert S. Johnston replaced him for organizational reasons. Young reacted to the military action by mustering a two-week expedition, destroying wagon trains, oxen, and other Army property. Buchanan then dispatched Thomas L. Kane as a private agent to negotiate peace. The mission was successful: a peaceful agreement to replace Governor Young with Cumming was reached, and the Utah War ended. The President granted amnesty to inhabitants affirming loyalty to the government, and kept the federal troops at a peaceable distance for the balance of his administration.\\nBuchanan did not comment on the conflict again until his State of the Union Address in December 1857, leaving open the question of whether there was a rebellion in Utah. One of Buchanan's last official acts in March 1861 was to reduce the size of Utah Territory in favor of Nevada, Colorado, and Nebraska. While the Latter-day Saints had frequently defied federal authority, some historians consider Buchanan's action an inappropriate response to uncorroborated reports.\\n\\nTransatlantic telegraph cable\\nBuchanan was the first recipient of an official telegram transmitted across the Atlantic. Following the dispatch of test and configuration telegrams, on August 16, 1858, Queen Victoria sent a 98-word message to Buchanan at his summer residence in the Bedford Springs Hotel in Pennsylvania, expressing hope that the newly laid cable would prove \\\"an additional link between the nations whose friendship is founded on their common interest and reciprocal esteem\\\".
Queen Victoria's message took 16 hours to send.\\nBuchanan responded: \\\"It is a triumph more glorious, because far more useful to mankind, than was ever won by conqueror on the field of battle. May the Atlantic telegraph, under the blessing of Heaven, prove to be a bond of perpetual peace and friendship between the kindred nations, and an instrument destined by Divine Providence to diffuse religion, civilization, liberty, and law throughout the world.\\\"\\n\\nBleeding Kansas and constitutional dispute\\nThe Kansas–Nebraska Act of 1854 created the Kansas Territory and allowed the settlers there to decide whether to allow slavery. This resulted in violence between \\\"Free-Soil\\\" (antislavery) and pro-slavery settlers, which developed into the \\\"Bleeding Kansas\\\" period. The antislavery settlers, with the help of Northern abolitionists, organized their own territorial government in Topeka. The more numerous proslavery settlers, many from the neighboring slave state of Missouri, established a government in Lecompton, giving the Territory two different governments for a time, with two distinct constitutions, each claiming legitimacy. The admission of Kansas as a state required that a constitution be submitted to Congress with the approval of a majority of its residents. Under President Pierce, a series of violent confrontations erupted over who had the right to vote in Kansas. The situation drew national attention, and some in Georgia and Mississippi advocated secession should Kansas be admitted as a free state. Buchanan chose to endorse the pro-slavery Lecompton government.\\nBuchanan appointed Robert J. Walker to replace John W. Geary as Territorial Governor, and conflicting referendums followed from Topeka and Lecompton, marred by election fraud. In October 1857, the Lecompton government framed the pro-slavery Lecompton Constitution, which provided for a referendum limited solely to the slavery question.
However, the vote against slavery, as provided by the Lecompton Convention, would still permit existing slaves, and all their issue, to be enslaved, so there was no referendum that permitted the majority anti-slavery residents to prohibit slavery in Kansas. As a result, anti-slavery residents boycotted the referendum, since it did not provide a meaningful choice.\\nDespite the protests of Walker and two former Kansas governors, Buchanan decided to accept the Lecompton Constitution. In a December 1857 meeting with Stephen A. Douglas, the chairman of the Senate Committee on Territories, Buchanan demanded that all Democrats support the administration's position of admitting Kansas under the Lecompton Constitution. On February 2, 1858, he transmitted the Lecompton Constitution to Congress. He also transmitted a message that attacked the \\\"revolutionary government\\\" in Topeka, conflating it with the Mormons in Utah. Buchanan made every effort to secure congressional approval, offering favors, patronage appointments, and even cash for votes. The Lecompton Constitution won the approval of the Senate in March, but a combination of Know-Nothings, Republicans, and Northern Democrats defeated the bill in the House.\\nBuchanan never forgave Douglas, as the Northern Democrats' rejection was the deciding factor in the House's decision, and he removed all Douglas supporters from his patronage in Illinois and Washington, D.C., installing pro-administration Democrats, including postmasters. Rather than accepting defeat, Buchanan backed the 1858 English Bill, which offered Kansas immediate statehood and vast public lands in exchange for accepting the Lecompton Constitution. In August 1858, Kansans by referendum strongly rejected the Lecompton Constitution. 
The territory received an abolitionist constitution, which was bitterly opposed in Congress by representatives and senators from the southern states until Kansas was admitted to the Union in January 1861.\\nThe dispute over Kansas became the battlefront for control of the Democratic Party. On one side were Buchanan, the majority of Southern Democrats, and the \\\"doughfaces\\\". On the other side were Douglas and the majority of northern Democrats, as well as a few Southerners. Douglas's faction continued to support the doctrine of popular sovereignty, while Buchanan insisted that Democrats respect the Dred Scott decision and its repudiation of federal interference with slavery in the territories.\\n\\n1858 mid-term elections\\nDouglas's Senate term was coming to an end in 1859, with the Illinois legislature, elected in 1858, determining whether Douglas would win re-election. The Senate seat was the primary issue of the legislative election, marked by the famous debates between Douglas and his Republican opponent for the seat, Abraham Lincoln. Buchanan, working through federal patronage appointees in Illinois, ran candidates for the legislature in competition with both the Republicans and the Douglas Democrats. This could easily have thrown the election to the Republicans, and showed the depth of Buchanan's animosity toward Douglas. In the end, Douglas Democrats won the legislative election and Douglas was re-elected to the Senate. In that year's elections, Douglas forces took control throughout the North, except in Buchanan's home state of Pennsylvania. Buchanan's support was otherwise reduced to a narrow base of southerners.\\nThe division between northern and southern Democrats allowed the Republicans to win a plurality of the House in the 1858 elections, and allowed them to block most of Buchanan's agenda. Buchanan, in turn, added to the hostility with his veto of six substantial pieces of Republican legislation. 
Among these measures were the Homestead Act, which would have given 160 acres of public land to settlers who remained on the land for five years, and the Morrill Act, which would have granted public lands to establish land-grant colleges. Buchanan argued that these acts were unconstitutional. In the western and northwestern United States, where the Homestead Act was very popular, even many Democrats condemned the president's policies, while many Americans who considered education an important asset resented Buchanan's veto of the bill for agricultural colleges.\\n\\nForeign policy\\nBuchanan took office with an ambitious foreign policy, designed to establish U.S. hegemony over Central America at the expense of Great Britain. Buchanan sought to revitalize Manifest Destiny and to enforce the Monroe Doctrine, which had been under attack from the Spanish, French, and especially the British in the 1850s. He hoped to renegotiate the Clayton–Bulwer Treaty, which he thought limited U.S. influence in the region, in order to counter European imperialism in the Western Hemisphere. He also sought to establish American protectorates over the Mexican states of Chihuahua and Sonora to secure American citizens and investments, and most importantly, he hoped to achieve his long-term goal of acquiring Cuba. However, Buchanan's ambitions in Cuba and Mexico were largely blocked by the House of Representatives. After long negotiations with the British, he convinced them to cede the Bay Islands to Honduras and the Mosquito Coast to Nicaragua.\\nIn 1858, Buchanan ordered the Paraguay expedition to punish Paraguay for firing on the USS Water Witch, sending 2,500 marines and 19 warships there. This costly expedition took months to reach Asunción but ultimately secured a Paraguayan apology and payment of an indemnity. 
The chiefs of Raiatea and Tahaa in the South Pacific, refusing to accept the rule of King Tamatoa V, unsuccessfully petitioned the United States to accept the islands under a protectorate in June 1858. Buchanan also considered buying Alaska from the Russian Empire, as whaling in the waters there had become of great economic importance to the United States. Buchanan fueled this interest by spreading a rumor to the Russian ambassador Eduard de Stoeckl in December 1857 that a large number of Mormons intended to emigrate to Russian Alaska. In the winter of 1859, an initial purchase offer of $5,000,000 (equivalent to $169,560,000 in 2023) was made. Although the project ultimately failed due to the reservations of Foreign Minister Alexander Gorchakov, the talks formed the basis for the later negotiations to purchase Alaska.\\nBuchanan sought trade agreements with the Qing Dynasty and Japan. In China, his envoy William Bradford Reed succeeded in having the United States included as a party to the Treaty of Tianjin. In May 1860, Buchanan received a Japanese delegation consisting of several princes who carried the Harris Treaty negotiated by Townsend Harris for mutual ratification. Buchanan was offered a herd of elephants by King Rama IV of Siam, though the letter arrived after Buchanan's departure from office, and Buchanan's successor Abraham Lincoln declined the offer, stating that the U.S. had an unsuitable climate. Other presidential pets included a pair of bald eagles and a Newfoundland dog.\\n\\nCovode Committee\\nIn March 1860, the House impaneled the Covode Committee to investigate the Buchanan administration's patronage system for alleged impeachable offenses, such as bribery and extortion of representatives. Buchanan supporters accused the committee, consisting of three Republicans and two Democrats, of being blatantly partisan, and claimed its chairman, Republican Rep. 
John Covode, was acting on a personal grudge stemming from a disputed land grant designed to benefit Covode's railroad company. The Democratic committee members, as well as Democratic witnesses, were enthusiastic in their condemnation of Buchanan.\\nThe committee was unable to establish grounds for impeaching Buchanan; however, the majority report issued on June 17 alleged corruption and abuse of power among members of his cabinet. The committee gathered evidence that in the spring of 1858 Buchanan had tried, through intermediaries, to bribe members of Congress to support the pro-slavery Lecompton Constitution of Kansas, and had threatened their relatives with losing their posts if they did not vote in its favor. Witnesses also testified that the federal government used public funds to strengthen the intra-party faction of Douglas's opponents in Illinois. The Democrats pointed out that evidence was scarce, but did not refute the allegations; one of the Democratic members, Rep. James Robinson, stated that he agreed with the Republicans, though he did not sign the majority report.\\nThe public was shocked by the extent of the bribery, which affected all levels and agencies of government. Buchanan claimed to have \\\"passed triumphantly through this ordeal\\\" with complete vindication. Republican operatives distributed thousands of copies of the Covode Committee report throughout the nation as campaign material in that year's presidential election.\\n\\nElection of 1860\\nAs he had promised in his inaugural address, Buchanan did not seek re-election. He went so far as to tell his ultimate successor, \\\"If you are as happy in entering the White House as I shall feel on returning to Wheatland, you are a happy man.\\\"\\nAt the 1860 Democratic National Convention in Charleston, the party split over the issue of slavery in the territories, damaging Buchanan's reputation, since he was seen as the person chiefly responsible for the issue. 
Though Douglas led on every ballot, he was unable to win the required two-thirds majority. The convention adjourned after 53 ballots and reconvened in Baltimore in June. After Douglas finally won the nomination, several Southerners refused to accept the outcome and nominated Vice President Breckinridge as their own candidate. Douglas and Breckinridge agreed on most issues except the protection of slavery. Buchanan, nursing a grudge against Douglas, failed to reconcile the party and tepidly supported Breckinridge. With the splintering of the Democratic Party, Republican nominee Abraham Lincoln won a four-way election that also included John Bell of the Constitutional Union Party. Lincoln's support in the North was enough to give him an Electoral College majority. Buchanan would be the last Democrat to win a presidential election until Grover Cleveland in 1884.\\nAs early as October, the army's Commanding General, Winfield Scott, an opponent of Buchanan, warned him that Lincoln's election would likely cause at least seven states to secede from the Union. He recommended that large numbers of federal troops and artillery be deployed to those states to protect federal property, although he also warned that few reinforcements were available. Since 1857, Congress had failed to heed calls for a stronger militia and had allowed the army to fall into a deplorable condition. Buchanan distrusted Scott and ignored his recommendations. After Lincoln's election, Buchanan directed Secretary of War John B. Floyd to reinforce southern forts with such provisions, arms, and men as were available; however, Floyd persuaded him to revoke the order.\\n\\nSecession\\nWith Lincoln's victory, talk of secession and disunion reached a boiling point, putting the burden on Buchanan to address it in his final annual message to Congress on December 3. 
In his message, which was anticipated by both factions, Buchanan denied the right of states to secede but maintained that the federal government was without power to prevent them. He placed the blame for the crisis solely on \\\"intemperate interference of the Northern people with the question of slavery in the Southern States,\\\" and suggested that if they did not \\\"repeal their unconstitutional and obnoxious enactments ... the injured States, after having first used all peaceful and constitutional means to obtain redress, would be justified in revolutionary resistance to the Government of the Union.\\\" Buchanan's only suggestion to solve the crisis was \\\"an explanatory amendment\\\" affirming the constitutionality of slavery in the states, the fugitive slave laws, and popular sovereignty in the territories. His address was sharply criticized both by the North, for its refusal to stop secession, and by the South, for denying its right to secede. Five days after the address was delivered, Treasury Secretary Howell Cobb resigned, as his views had become irreconcilable with the President's. Even as the formation of the Confederacy by the secessionist states became increasingly apparent in the winter of 1860, the president continued to surround himself with Southerners and ignore the Republicans.\\n\\nSouth Carolina, long the most radical Southern state, seceded from the Union on December 20, 1860. However, Unionist sentiment remained strong among many in the South, and Buchanan sought to appeal to the Southern moderates who might prevent secession in other states. He met with South Carolinian commissioners in an attempt to resolve the situation at Fort Sumter, which remained under federal control despite its location in Charleston, South Carolina. Buchanan saw Congress, not himself, as responsible for finding a solution to the secession crisis. 
As a compromise for the southern states, Buchanan envisioned the adoption of amendments to the United States Constitution that would guarantee the right to slavery in the southern states and territories and strengthen the right of slave owners to reclaim escaped slaves as property in the northern states.\\nHe refused to dismiss Interior Secretary Jacob Thompson after the latter was chosen as Mississippi's agent to discuss secession, and he refused to fire Secretary of War John B. Floyd despite an embezzlement scandal. Floyd ended up resigning, but not before sending numerous firearms to Southern states, where they eventually fell into the hands of the Confederacy. Despite Floyd's resignation, Buchanan continued to seek the advice of counselors from the Deep South, including Jefferson Davis and William Henry Trescot. Buchanan's friend Rose O'Neal Greenhow took advantage of her proximity to the president and spied for the Confederacy, which had established a sophisticated network for gathering information from its eventual opponent even before its formal creation.\\nEfforts were made in vain by Sen. John J. Crittenden, Rep. Thomas Corwin, and former president John Tyler to negotiate a compromise to stop secession, with Buchanan's support. Failed attempts were also made by a group of governors meeting in New York. Buchanan secretly asked President-elect Lincoln to call for a national referendum on the issue of slavery, but Lincoln declined. In December 1860, when the second session of the 36th Congress was convened, the Committee of Thirty-Three was established by the House of Representatives to prevent further states from seceding. It proposed the Corwin Amendment, which would bar Congress from interfering with slavery in the states. 
Despite opposition from Republicans, it passed both houses of Congress and was sent to the states for ratification, but it was never ratified by the requisite number of states.\\nIn spite of the efforts of Buchanan and others, six more slave states had seceded by the end of January 1861. Buchanan replaced the departed Southern cabinet members with John Adams Dix, Edwin M. Stanton, and Joseph Holt, all of whom were committed to preserving the Union. When Buchanan considered surrendering Fort Sumter, the new cabinet members threatened to resign, and Buchanan relented. On January 5, Buchanan decided to reinforce Fort Sumter, sending the Star of the West with 250 men and supplies. However, he failed to ask Major Robert Anderson to provide covering fire for the ship, and it was forced to return north without delivering troops or supplies. Buchanan chose not to respond to this act of war and instead sought a compromise to avoid secession. On March 3 he received a message from Anderson that supplies were running low, but the response fell to Lincoln, who succeeded to the presidency the next day.\\n\\nStates admitted to the Union\\nThree new states were admitted to the Union while Buchanan was in office:\\n\\nMinnesota – May 11, 1858\\nOregon – February 14, 1859\\nKansas – January 29, 1861\\n\\nFinal years and death (1861–1868)\\nAfter leaving office, Buchanan retired to private life in Wheatland, where he spent most of his time in his study, reading books and writing letters. The Civil War erupted within two months of Buchanan's retirement. He supported the Union and the war effort, writing to former colleagues that \\\"the assault upon Sumter was the commencement of war by the Confederate states, and no alternative was left but to prosecute it with vigor on our part.\\\" Buchanan supported Lincoln's introduction of universal conscription in the northern states, but opposed his Emancipation Proclamation. 
Although he recognized constitutional violations in some of the president's executive orders, he never criticized them in public. He also wrote a letter to his fellow Pennsylvania Democrats in Harrisburg, urging them and all young men to enlist in the Union army and \\\"join the many thousands of brave & patriotic volunteers who are already in the field.\\\"\\nBuchanan was dedicated to defending his actions prior to the Civil War, which was referred to by some as \\\"Buchanan's War\\\". He received hate mail and threatening letters daily, and stores in Lancaster displayed Buchanan's likeness with the eyes inked red, a noose drawn around his neck, and the word \\\"TRAITOR\\\" written across his forehead. The Senate proposed a resolution of condemnation that ultimately failed, and newspapers accused him of colluding with the Confederacy. His former cabinet members, five of whom had been given jobs in the Lincoln administration, refused to defend Buchanan publicly.\\nBuchanan became distraught over the vitriolic attacks leveled against him and fell sick and depressed. In October 1862, he defended himself in an exchange of letters with Winfield Scott, published in the National Intelligencer. He soon began writing his fullest public defense, in the form of his memoir Mr. Buchanan's Administration on the Eve of the Rebellion, which was published in 1866, one year after the Civil War ended. Buchanan attributed secession to the \\\"malign influence\\\" of Republicans and the abolitionist movement. He discussed his foreign policy successes and expressed satisfaction with his decisions, even during the secession crisis. He blamed Robert Anderson, Winfield Scott, and Congress for the unresolved issue. Two years after the publication of the memoir, Buchanan caught a cold in May 1868, which quickly worsened due to his advanced age. He died on June 1, 1868, of respiratory failure at the age of 77 at his home at Wheatland. 
He was interred in Woodward Hill Cemetery in Lancaster.\\n\\nPolitical views\\nBuchanan was often considered by anti-slavery northerners a \\\"doughface\\\", a northerner with pro-southern principles. Buchanan's sympathies for the Southern states went beyond political expediency on his path to the White House. He identified with cultural and social values that he found reflected in the honor code and lifestyle of the planter class and with which he increasingly came into contact in his retirement community beginning in 1834. Shortly after his election, he said that the \\\"great object\\\" of his administration was \\\"to arrest, if possible, the agitation of the Slavery question in the North and to destroy sectional parties\\\". Although Buchanan was personally opposed to slavery, he believed that the abolitionists were preventing a solution to the slavery problem. He stated, \\\"Before [the abolitionists] commenced this agitation, a very large and growing party existed in several of the slave states in favor of the gradual abolition of slavery; and now not a voice is heard there in support of such a measure. The abolitionists have postponed the emancipation of the slaves in three or four states for at least half a century.\\\" He was willing to give the typical slaveholder the benefit of the doubt regarding his intentions. In his third annual message to Congress, the president claimed that the slaves were \\\"treated with kindness and humanity. ... Both the philanthropy and the self-interest of the master have combined to produce this humane result.\\\"\\n\\nBuchanan thought restraint was the essence of good self-government. He believed the Constitution comprised \\\"... restraints, imposed not by arbitrary authority, but by the people upon themselves and their representatives. ... In an enlarged view, the people's interests may seem identical, but to the eye of local and sectional prejudice, they always appear to be conflicting ... 
and the jealousies that will perpetually arise can be repressed only by the mutual forbearance which pervades the constitution.\\\" Regarding slavery and the Constitution, he stated: \\\"Although in Pennsylvania we are all opposed to slavery in the abstract, we can never violate the constitutional compact we have with our sister states. Their rights will be held sacred by us. Under the constitution it is their own question; and there let it remain.\\\"\\nOne of the prominent issues of the day was tariffs. Buchanan was ambivalent about both free trade and prohibitive tariffs, since either would benefit one section of the country to the detriment of the other. As a senator from Pennsylvania, he said: \\\"I am viewed as the strongest advocate of protection in other states, whilst I am denounced as its enemy in Pennsylvania.\\\"\\nBuchanan was also torn between his desire to expand the country for the general welfare of the nation and his desire to guarantee the rights of the people settling particular areas. On territorial expansion, he said, \\\"What, sir? Prevent the people from crossing the Rocky Mountains? You might just as well command the Niagara not to flow. We must fulfill our destiny.\\\" On the spread of slavery that unconditional expansion would bring, he stated: \\\"I feel a strong repugnance by any act of mine to extend the present limits of the Union over a new slave-holding territory.\\\" For instance, he hoped the acquisition of Texas would \\\"be the means of limiting, not enlarging, the dominion of slavery.\\\"\\n\\nPersonal life\\nBuchanan suffered from esotropia. In addition, one eye was short-sighted and the other far-sighted. To compensate for this, he bent his head forward and tilted it to one side during social interactions. This led to ridicule, which Henry Clay, among others, used ruthlessly during a congressional debate.\\nIn 1818, Buchanan met Anne Caroline Coleman at a grand ball in Lancaster, and the two began courting. 
Anne was the daughter of the wealthy iron manufacturer Robert Coleman; Robert, like Buchanan's father, was from County Donegal in Ulster. Anne was also the sister-in-law of Philadelphia judge Joseph Hemphill, one of Buchanan's colleagues. By 1819, the two were engaged, but they spent little time together. Buchanan was busy with his law firm and political projects during the Panic of 1819, which took him away from Coleman for weeks at a time. Rumors abounded, as some suggested that he was involved with other (unidentified) women. Letters from Coleman revealed that she was aware of several rumors, and she accused him of being interested only in her money. She broke off the engagement, and soon afterward, on December 9, 1819, inexplicably died of \\\"hysterical convulsions\\\" resulting from an overdose of laudanum, at the age of 23. It was never established whether the drug was taken by instruction, by accident, or by intent. Buchanan wrote to her father for permission to attend the funeral, which was refused. At the time of her funeral, he said, \\\"I feel happiness has fled from me forever.\\\" Afterwards, Buchanan claimed that he remained unmarried out of devotion to his only love, who had died young.\\n\\nIn 1833 and the 1840s, he spoke of plans to marry, but these came to nothing and may have been motivated merely by his ambitions for a seat in the federal Senate or the White House. In the latter case, the prospective bride was 19-year-old Anna Payne, the niece of former First Lady Dolley Madison. During his presidency, an orphaned niece, Harriet Lane, whom he had adopted, served as official White House hostess. There was an unfounded rumor that he had an affair with President Polk's widow, Sarah Childress Polk.\\nBuchanan had a close relationship with William Rufus King, which became a popular target of gossip. King was an Alabama politician who briefly served as vice president under Franklin Pierce. 
Buchanan and King lived together in a Washington boardinghouse and attended social functions together from 1834 until 1844. Such a living arrangement was then common, though Buchanan once referred to the relationship as a \\\"communion\\\". Andrew Jackson mockingly called them \\\"Miss Nancy\\\" and \\\"Aunt Fancy\\\", the former being a 19th-century euphemism for an effeminate man. Buchanan's Postmaster General, Aaron V. Brown, also referred to King as \\\"Aunt Fancy\\\", as well as Buchanan's \\\"better half\\\", and \\\"wife\\\". King died of tuberculosis shortly after Pierce's inauguration, four years before Buchanan became president. Buchanan described him as \\\"among the best, the purest and most consistent public men I have known\\\". Biographer Baker opines that both men's nieces may have destroyed correspondence between the two men. However, she believes that their surviving letters illustrate only \\\"the affection of a special friendship\\\".\\nBuchanan's lifelong bachelorhood after Anne Coleman's death has drawn interest and speculation. Some conjecture that Anne's death merely served to deflect questions about Buchanan's sexuality and bachelorhood. One of his biographers, Jean Baker, suggests that Buchanan was celibate, if not asexual. Several writers have surmised that he was homosexual, including James W. Loewen, Robert P. Watson, and Shelley Ross. Loewen indicated that Buchanan, late in life, wrote a letter acknowledging that he might marry a woman who could accept his \\\"lack of ardent or romantic affection\\\".\\n\\nLegacy\\nHistorical reputation\\nThough Buchanan predicted that \\\"history will vindicate my memory,\\\" historians have criticized Buchanan for his unwillingness or inability to act in the face of secession. Historical rankings of presidents of the United States without exception place Buchanan among the least successful presidents. 
When scholars are surveyed, he ranks at or near the bottom in terms of vision/agenda-setting, domestic leadership, foreign policy leadership, moral authority, and positive historical significance of his legacy. In every survey of American scholars and political scientists taken between 1948 and 1982, Buchanan ranked among the worst presidents of the United States, alongside Harding, Fillmore, and Nixon.\\nIn 1962, during the civil rights movement, Buchanan biographer Philip S. Klein focused on the challenges Buchanan faced:\\n\\nBuchanan assumed leadership ... when an unprecedented wave of angry passion was sweeping over the nation. That he held the hostile sections in check during these revolutionary times was in itself a remarkable achievement. His weaknesses in the stormy years of his presidency were magnified by enraged partisans of the North and South. His many talents, which in a quieter era might have gained for him a place among the great presidents, were quickly overshadowed by the cataclysmic events of civil war and by the towering Abraham Lincoln.\\nBiographer Jean Baker is less charitable to Buchanan, saying in 2004:\\n\\nAmericans have conveniently misled themselves about the presidency of James Buchanan, preferring to classify him as indecisive and inactive ... In fact Buchanan's failing during the crisis over the Union was not inactivity, but rather his partiality for the South, a favoritism that bordered on disloyalty in an officer pledged to defend all the United States. He was that most dangerous of chief executives, a stubborn, mistaken ideologue whose principles held no room for compromise. His experience in government had only rendered him too self-confident to consider other views. 
In his betrayal of the national trust, Buchanan came closer to committing treason than any other president in American history.\\nOther historians, such as Robert May, have argued that his politics were \\\"anything but pro-slavery\\\"; Michael Birkner's works about Buchanan, by contrast, take a very negative view. According to Lori Cox Han, he ranks among scholars \\\"as either the worst president in [American] history or as part of a lowest ranking failure category\\\".\\n\\nMemorials\\nA bronze and granite memorial near the southeast corner of Washington, D.C.'s Meridian Hill Park was designed by architect William Gorden Beecher and sculpted by Maryland artist Hans Schuler. It was commissioned in 1916 but not approved by the U.S. Congress until 1918, and not completed and unveiled until June 26, 1930. The memorial features a statue of Buchanan, bookended by male and female classical figures representing law and diplomacy, with engraved text reading: \\\"The incorruptible statesman whose walk was upon the mountain ranges of the law,\\\" a quote from a member of Buchanan's cabinet, Jeremiah S. Black.\\n\\nAn earlier monument was constructed in 1907–1908 and dedicated in 1911, on the site of Buchanan's birthplace in Stony Batter, Pennsylvania. Part of the original 18.5-acre (75,000 m²) memorial site is a 250-ton pyramid structure that stands on the site of the original cabin where Buchanan was born. The monument was designed to show the original weathered surface of the native rubble and mortar.\\nThree counties are named in his honor, in Iowa, Missouri, and Virginia. Another in Texas was christened in 1858 but renamed Stephens County in 1861, after the newly elected vice president of the Confederate States of America, Alexander Stephens. The city of Buchanan, Michigan, was also named after him. 
Several other communities are named after him: the unincorporated community of Buchanan, Indiana; the city of Buchanan, Georgia; the town of Buchanan, Wisconsin; and the townships of Buchanan, Michigan, and Buchanan, Missouri.\\nJames Buchanan High School is a small, rural high school located on the outskirts of his childhood hometown, Mercersburg, Pennsylvania.\\n\\nPopular culture depictions\\nBuchanan and his legacy are central to the film Raising Buchanan (2019), in which he is portrayed by René Auberjonois.\\n\\nSee also\\nHistorical rankings of presidents of the United States\\nList of presidents of the United States\\nList of presidents of the United States by previous experience\\nPresidents of the United States on U.S. postage stamps\\nList of federal political sex scandals in the United States\\n\\nExternal links\\n\\nUnited States Congress. \\\"James Buchanan (id: B001005)\\\". Biographical Directory of the United States Congress.\\nJames Buchanan: A Resource Guide from the Library of Congress\\nThe James Buchanan papers, spanning the entirety of his legal, political and diplomatic career, are available for research use at the Historical Society of Pennsylvania.\\nUniversity of Virginia article: Buchanan biography\\nWheatland\\nJames Buchanan at Tulane University\\nEssay on James Buchanan and his presidency from the Miller Center of Public Affairs\\nBuchanan's Birthplace State Park, Franklin County, Pennsylvania\\n\\\"Life Portrait of James Buchanan\\\", from C-SPAN's American Presidents: Life Portraits, June 21, 1999\\nPrimary sources\\n\\nWorks by James Buchanan at Project Gutenberg\\nWorks by James Buchanan at LibriVox (public domain audiobooks)\\nWorks by or about James Buchanan at the Internet Archive\\nJames Buchanan Ill with Dysentery Before Inauguration: Original Letters, Shapell Manuscript Foundation\\nMr. Buchanan's Administration on the Eve of the Rebellion. 
President Buchanan's memoirs.\\nInaugural Address\\nFourth Annual Message to Congress, December 3, 1860\\n\\nHarriet Rebecca Lane Johnston (May 9, 1830 – July 3, 1903) acted as first lady of the United States during the administration of her uncle, lifelong bachelor president James Buchanan, from 1857 to 1861. She has been described as the first of the modern first ladies: a notably charming and diplomatic hostess whose dress styles were copied and who promoted deserving causes. In her will, she left funds for a new school on the grounds of Washington National Cathedral. Several ships have been named in her honor, including the cutter USCGC Harriet Lane, still in service.\\n\\nStatus\\nLane is the only person to have served as First Lady to a bachelor president, Buchanan being the only U.S. president never to have married. She is among 11 women who have served as First Lady without being married to the president; most of the others were relatives of widowed presidents.\\n\\nEarly life\\nHarriet Lane's family was from Franklin County, Pennsylvania. She was the youngest child of Elliott Tole Lane, a merchant, and Jane Ann Buchanan Lane. She lost her mother when she was nine; when her father's death two years later made her an orphan, she requested that her favorite uncle, James Buchanan, be appointed as her legal guardian. Buchanan, an unmarried Democratic senator from Pennsylvania, indulged his niece and her sister, enrolling them in boarding schools in Charles Town, Virginia (and later, for two years, at the Georgetown Visitation Monastery in the Georgetown section of Washington, D.C.). By this time, Buchanan was Secretary of State, and, as he had promised, he introduced her to fashionable and political circles.\\nIn 1854, she joined him in London, where he was minister to the Court of St. James's.
Queen Victoria gave \\\"dear Miss Lane\\\" the rank of ambassador's wife; admiring suitors gave her the fame of a beauty. In appearance \\\"Hal\\\" Lane was of medium height, with masses of light, almost golden-colored hair. She had eyes that were described as \\\"violet colored\\\".\\n\\nActing First Lady of the United States\\nThe capital welcomed its new \\\"Democratic Queen\\\" to the White House in 1857. Harriet was a popular hostess during the four years of the Buchanan presidency. Women copied her hair and clothing styles (especially when she lowered the neckline on her inaugural gown by 2.5 inches), parents named their daughters for her, and a popular song (\\\"Listen to the Mockingbird\\\") was dedicated to her. While in the White House, she used her position to promote social causes, such as improving the living conditions of Native Americans in reservations. She also made a point of inviting artists and musicians to White House functions. For both her popularity and her advocacy work, she has been described as the first of the modern first ladies, and her popularity at the time is compared to that of Jacqueline Kennedy in the 1960s. The presidential yacht was named for her—the first of several ships to be named after her, one of which remains in service.\\n\\nAs sectional tensions increased, she worked out seating arrangements for her weekly formal dinner parties with special care, to give dignitaries their proper precedence and still keep political foes apart. Her tact did not falter, but her task became impossible—as did her uncle's. Seven states had seceded by the time Buchanan retired from office and returned with his niece to his spacious country home, Wheatland, near Lancaster, Pennsylvania.\\nIn the 1982 Siena College Research Institute survey asking historians to assess American first ladies, Lane and several other \\\"acting\\\" first ladies were included. 
The first ladies survey, which has been conducted periodically since, ranks first ladies according to a cumulative score on the independent criteria of their background, value to the country, intelligence, courage, accomplishments, integrity, leadership, being their own women, public image, and value to the president. In the 1982 survey, out of 42 first ladies and acting first ladies, Lane was assessed as the 29th most highly regarded among historians. Acting first ladies such as Lane have been excluded from subsequent iterations of this survey.\\n\\nRomance and marriage\\nDuring her time in England, Sir Fitzroy Kelly, then Prime Minister Palmerston's attorney general, proposed marriage to her; Queen Victoria was strongly in favor of this match, as it would keep Lane in England.\\nLane considered the advantages of a number of bachelors. Her uncle cautioned Lane against \\\"rushing precipitately into matrimonial connections\\\" as his ward found her potential suitors \\\"pleasant but dreadfully troublesome\\\". Lane eventually married Baltimore banker Henry Elliott Johnston at the age of 36. They had two sons: James Buchanan Johnston (1866–1881) and Henry Elliot Johnston (1869–1882), but within the 18 years from 1867 to 1885, her uncle, her husband, and her children all died.\\n\\nLater life and death\\nHarriet wrote her will in 1895 and lived another eight years, during which the country's general prosperity greatly increased the value of her estate. She added a codicil in 1899 directing that a school building be constructed on the grounds of the Washington National Cathedral property and asked that it be called the Lane-Johnston Building \\\"to the end that the family names of my husband and myself may be associated with the bequest made in loving memory of our sons.\\\" A codicil of 1903 increased her gift by one third but said that only half the total was to be spent on the building. 
The remainder was \\\"specially to provide for the free maintenance, education and training of choirboys, primarily those in service of the Cathedral.\\\" This bequest founded the prestigious boys' school that today is called St. Albans School, which opened in October 1909. \\nAt Harriet Lane Johnston's funeral, services were conducted by Bishop Satterlee and Canon DeVries of the Washington National Cathedral. She was buried in Green Mount Cemetery, Baltimore, Maryland, her grave marked with a Celtic cross like the Peace Cross on the cathedral close. In 1905, guests were invited to see the cornerstone of the first St. Albans School building, laid for what the invitation referred to as \\\"The Lane Johnston Choir School for Boys of the Washington Cathedral\\\".\\n\\nLegacy\\nLane left bequests in her will that established a children's hospital and a boys' school, and she donated her collection of artwork to the Smithsonian. Several Navy and Coast Guard ships have been named in her honor.\\nHer birthplace, the Lane House, was listed on the National Register of Historic Places in 1972.\\n\\nHospital and school\\nShe dedicated $400,000 (equivalent to $13,600,000 in 2023) to establish the Harriet Lane Home for Invalid Children at the Johns Hopkins Hospital in Baltimore, Maryland as a memorial to two sons who had died in childhood. In October 1912 the Harriet Lane Home officially opened. It was the first children's clinic in the United States that was associated with a medical school. Eventually treating over 60,000 children a year, the Harriet Lane Home became a pioneer treatment, teaching, and research clinic.\\nFrom 1930 to 1963 Helen Taussig, who helped to develop the blue baby operation, headed the pediatric cardiac clinic. Child psychiatrist Leo Kanner did studies of autistic children. Lawson Wilkins established an endocrine clinic that developed procedures used universally to treat children with certain glandular disorders, including dwarfism. John E. 
Bordley and William G. Hardy broke new ground in detecting hearing impairments in very young children. It became a renowned pediatric facility; the Harriet Lane Outpatient Clinics serve thousands of children today, and the widely used manual for pediatric house officers, The Harriet Lane Handbook, bears her name.\\nThe Harriet Lane Outpatient Clinics continue to operate in countries throughout the world.\\nThe Harriet Lane Handbook series of pediatric manuals continues in print and online, with multiple titles. The original title (subtitled A Manual for Pediatric House Officers) is in its 22nd edition, published by Mosby.\\n\\nArt collection\\nShe assembled an art collection of European works, which she left to the U.S. government. The Smithsonian Institution called her the \\\"First Lady of the National Collection of Fine Arts\\\" after her collection was accepted into public ownership.\\n\\nNamesake ships\\nThe United States Coast Guard has had three cutters named in her honor. The first was the USRC Harriet Lane, commissioned into the United States Revenue Cutter Service (predecessor of the USCG) in 1857. This cutter was transferred to the United States Navy in 1861 because of the American Civil War.\\nThe second cutter named for Harriet Lane was the 125-foot USCGC Harriet Lane (WSC-141), commissioned in 1926 and decommissioned in 1946.\\nThe third cutter named for Harriet Lane is the USCGC Harriet Lane (WMEC-903). The cutter was commissioned in May 1984 and, as of 2021, is still in active service.\\n\\nFootnotes\\nReferences\\nFurther reading\\nBalcerski, Thomas J. \\\"Harriet Rebecca Lane Johnston.\\\" In A Companion to First Ladies (2016): 197–213.\\nRosenberger, Homer Tope. \\\"To What Extent Did Harriet Lane Influence the Public Policies of James Buchanan?\\\" Lancaster County Historical Society, 1970.\\nUpdike, John (1974). Buchanan Dying (play). (Ms.
Johnston is a character in Updike's fictional play about President Buchanan.)\\n\\nExternal links\\nWorks by or about Harriet Lane at the Internet Archive\\n\\\"Harriet Lane\\\". First Ladies: Influence & Image. firstladies.org. CNN.\\n\\nSince the office was established in 1789, 45 persons have served as president of the United States. Of these, eight have died in office: four were assassinated, and four died of natural causes. In each of these instances, the vice president succeeded to the presidency. This practice is now governed by Section One of the Twenty-fifth Amendment to the United States Constitution, ratified in 1967, which declares that \\\"the Vice President shall become President\\\" if the president is removed from office, dies, or resigns. The initial authorization for this practice was provided by Article II, Section 1, Clause 6, of the U.S. Constitution.\\nThe first incumbent U.S. president to die was William Henry Harrison, on April 4, 1841, only one month after Inauguration Day. He died from complications of what at the time was believed to be pneumonia. The second American president to die in office, Zachary Taylor, died on July 9, 1850, from acute gastroenteritis. Abraham Lincoln was the first U.S. president to be killed while in office. He was shot by John Wilkes Booth on the night of April 14, 1865, and died the following morning. Sixteen years later, on July 2, 1881, James A. Garfield was shot by Charles J. Guiteau, surviving for over two months before dying on September 19, 1881.\\nOn September 14, 1901, William McKinley died, eight days after being shot by Leon Czolgosz. Next, Warren G. Harding suffered a heart attack and died on August 2, 1923. On April 12, 1945, Franklin D. Roosevelt (who had just begun his fourth term in office) collapsed and died as a result of a cerebral hemorrhage. The most recent U.S. president to die in office was John F.
Kennedy, who was shot by Lee Harvey Oswald on November 22, 1963, in Dallas, Texas.\\n\\n1841: William Henry Harrison\\nOn March 26, 1841, William Henry Harrison became ill with a cold after being caught in a torrential downpour without cover. His symptoms grew progressively worse over the next two days, after which a team of doctors was called in to treat him. After diagnosing right lower lobe pneumonia, they placed heated suction cups on his bare torso and administered a series of bloodlettings, supposedly to draw out the disease. When those procedures failed to bring about improvement, the doctors treated him with ipecac, castor oil, calomel, and finally a boiled mixture of crude petroleum and Virginia snakeroot. All this only weakened Harrison further.\\nInitially, no official announcement was made concerning Harrison's illness, and the longer he remained out of public view, the more public speculation and concern grew. By the end of the month, large crowds were gathering outside the White House, holding vigil while awaiting any news about the president's condition. On the evening of April 4, 1841, nine days after becoming ill and exactly one month after taking the oath of office, Harrison died at age 68. His last words were spoken to his attending doctor, though they are assumed to have been directed at Vice President John Tyler:\\n\\nSir, I wish you to understand the true principles of the government. I wish them carried out. I ask nothing more.\\nA 30-day period of mourning commenced following the president's death. Various public ceremonies, modeled after European royal funeral practices, were held. An invitation-only funeral service was also held on April 7 in the East Room of the White House, after which Harrison's coffin was brought to Congressional Cemetery in Washington, D.C., where it was placed in a temporary receiving vault.\\nThat June, Harrison's body was transported by train and river barge to North Bend, Ohio.
Then, on July 7, 1841, the nation's 9th president was buried in a family tomb at the summit of Mt. Nebo, overlooking the Ohio River, now the William Henry Harrison Tomb State Memorial.\\nHarrison's death sparked a brief constitutional crisis regarding succession to the presidency, as the U.S. Constitution was unclear as to whether Vice President John Tyler should assume the office of president or merely execute the duties of the vacant office. Tyler claimed a constitutional mandate to carry out the full powers and duties of the presidency and took the presidential oath of office, setting an important precedent for an orderly transfer of presidential power when a president leaves office mid-term.\\nCoincidentally, all but one of the presidents who later died in office had, like Harrison, won a presidential election in a year ending in zero (1840 through 1960). This pattern of tragedies came to be known as the Curse of Tippecanoe, or the Curse of Tecumseh, after the Shawnee leader against whom Harrison fought in the 1811 Battle of Tippecanoe. The pattern, also sometimes referred to as the Zero Factor legend, was broken by Ronald Reagan, who survived an assassination attempt in 1981 (69 days after taking office) and lived to complete two full terms.\\n\\n1850: Zachary Taylor\\nZachary Taylor was known to have consumed copious amounts of ice water, cold milk, green apples, and cherries on July 4, 1850, after attending holiday celebrations and the laying of the cornerstone of the Washington Monument. That same evening, he became severely ill with an unknown digestive ailment. Doctors used popular treatments of the time. On the morning of July 9, the president asked his wife, Margaret, not to grieve, saying:\\n\\nI have always done my duty, I am ready to die. My only regret is for the friends I leave behind me.\\nTaylor died late that evening, five days after becoming ill, at age 65.
Contemporary reports listed the cause of death as \\\"bilious diarrhea or a bilious cholera.\\\" He was succeeded by Vice President Millard Fillmore.\\nTaylor's funeral took place on July 13 and, like Harrison's nine years earlier, was held in the East Room of the White House. Afterward, an estimated 100,000 people gathered along the funeral route to Congressional Cemetery, where his coffin was placed temporarily in the Public Vault; that October, it was transported to Louisville, Kentucky. On November 1, 1850, Taylor was buried in his family's burial ground on the Taylor estate, Springfield, which became the Zachary Taylor National Cemetery.\\nAlmost immediately after his death, rumors began to circulate that Taylor had been poisoned by pro-slavery Southerners, and various conspiracy theories persisted into the late 20th century. The cause of Taylor's death was definitively established in 1991, when his remains were exhumed and an autopsy was conducted by Kentucky's chief medical examiner. Subsequent neutron activation analysis conducted at Oak Ridge National Laboratory revealed no evidence of poisoning, as arsenic levels were too low. The analysis concluded that Taylor had contracted cholera morbus (acute gastroenteritis); Washington had open sewers, and his food or drink may have been contaminated.\\n\\n1865: Abraham Lincoln\\nThe assassination of Abraham Lincoln took place on Good Friday, April 14, 1865, as the Civil War was drawing to a close. He died the following morning at the age of 56. The assassination occurred five days after General Robert E. Lee and the Army of Northern Virginia surrendered to General Ulysses S. Grant and the Army of the Potomac following the Battle of Appomattox Court House. Lincoln was the first American president to be killed by an assassin. (The first U.S.
president to be confronted by a would-be assassin was Andrew Jackson 30 years earlier, in January 1835.)\\nThe assassination of President Lincoln was planned and carried out by the well-known stage actor John Wilkes Booth, a Confederate sympathizer who vehemently denounced Lincoln and strongly opposed the abolition of slavery in the United States. Booth and a group of co-conspirators originally plotted to kidnap Lincoln but later planned to kill him, Vice President Andrew Johnson, and Secretary of State William H. Seward in a bid to help the Confederacy's cause. Johnson's would-be assassin, George Atzerodt, did not carry out his part of the plan, and Lewis Powell only managed to wound Seward; Johnson succeeded Lincoln as president.\\nLincoln was shot once in the back of his head while watching the play Our American Cousin with his wife, Mary Todd Lincoln, at Ford's Theatre in Washington, D.C., on the night of April 14, 1865. An army surgeon who happened to be at Ford's, Dr. Charles Leale, assessed Lincoln's wound as mortal. The unconscious president was then carried across the street from the theater to the Petersen House, where he remained in a coma for eight hours before dying the following morning.\\nOn April 26, 1865, less than two weeks into the manhunt for Lincoln's killers, Booth and David Herold were caught in a tobacco barn in Port Conway, Virginia. Herold surrendered, but Booth was shot and killed by Boston Corbett, a Union corporal.\\nA three-week series of official functions was held following the president's death. He lay in state in the East Room of the White House, which was opened to the public on April 18. A funeral service was held the next day, and then the coffin was transported in a procession down Pennsylvania Avenue to the United States Capitol, where a ceremonial burial service was held in the rotunda. After lying in state at the Capitol, Lincoln's remains were transported by train to Springfield, Illinois, for burial.
He was interred on May 4, 1865, at Oak Ridge Cemetery in Springfield – the Lincoln Tomb State Historic Site since 1895.\\n\\n1881: James A. Garfield\\nThe assassination of James A. Garfield happened in Washington, D.C., on July 2, 1881. Garfield was shot by Charles J. Guiteau at 9:30 a.m., less than four months into his term as the nation's 20th president. He died 11 weeks later on September 19, 1881, at the age of 49. Vice President Chester A. Arthur succeeded him as president. Garfield was scheduled to leave Washington on July 2, 1881, for his summer vacation. On that day, Guiteau lay in wait for the president at the Baltimore and Potomac Railroad station, on the southwest corner of present-day Sixth Street and Constitution Avenue NW, Washington, D.C.\\nPresident Garfield came to the Sixth Street Station on his way to his alma mater, Williams College, where he was scheduled to deliver a speech. Garfield was accompanied by two of his sons, James and Harry, and Secretary of State James G. Blaine. Secretary of War Robert Todd Lincoln waited at the station to see the president off. Garfield had no bodyguard or security detail; with the exception of Abraham Lincoln during the Civil War, early U.S. presidents never used any guards.\\nAs President Garfield entered the waiting room of the station, Guiteau stepped forward and pulled the trigger from behind at point-blank range. \\\"My God, what is that?!\\\" Garfield cried out, flinging up his arms. Guiteau fired again and Garfield collapsed. One bullet grazed Garfield's shoulder; the other hit him in the back, passing the first lumbar vertebra but missing the spinal cord before coming to rest behind his pancreas.\\nGarfield, conscious but in shock, was carried to an upstairs floor of the train station. Lincoln sent for D.C. 
Bliss, a prominent Washington physician who soon arrived and examined Garfield's wounds several times, probing with his fingers and metal probes for the bullet that remained lodged in the president's body. Two additional doctors were summoned, and they also probed the entry wound. Eventually, there were about twenty people in the room, including at least ten physicians. As Garfield was being cared for, Lincoln, thinking back to the death of his father, said, \\\"How many hours of sorrow I have passed in this town.\\\"\\nGarfield was carried back to the White House. Although doctors told him that he would not survive the night, the president remained conscious and alert. The next morning, his vital signs were good, and doctors began to hope for recovery. A long vigil began, with Garfield's doctors issuing regular bulletins that the American public followed closely throughout the summer of 1881. His condition fluctuated. Fevers came and went. Garfield struggled to keep down solid food and spent most of the summer eating little, and then only liquids.\\nGarfield had been a regular visitor to the shore town of Long Branch, New Jersey, one of the nation's premier summer vacation spots until World War I. In early September, it was decided to bring him to Elberon, a quiet beach town just south of Long Branch, in hopes that the beach air would help him recover. When they heard that the president was being brought to their town, local citizens built more than half a mile of tracks in less than 24 hours, enabling Garfield to be brought directly to the door of the oceanfront Franklyn cottage rather than being moved by carriage from the local Elberon train station. However, Garfield died 12 days later. A granite marker on Garfield Road identifies the former site of the cottage, which was demolished in 1950. Throughout the eleven-week drama, anxious Americans across the country were kept informed of developments by the news media.
The publisher of Frank Leslie's Illustrated Newspaper, Miriam Leslie, was especially quick to publish fully illustrated accounts of key moments, from Garfield's shooting to the embalming of his body.\\nChester Arthur was at his home in New York City on the night of September 19, when word came that Garfield had died. On first getting the news, Arthur said, \\\"I hope—my God, I do hope it is a mistake.\\\" But confirmation by telegram came soon after. Arthur took the presidential oath of office, administered by a New York Supreme Court judge, and then left for Long Branch to pay his respects before traveling on to Washington. Garfield's body was taken to Washington, where it lay in state for two days in the Capitol Rotunda before being taken to Cleveland, where the funeral was held on September 26.\\nWhen the tracks that had been hastily built to the Franklyn cottage were later torn up, actor Oliver Byron bought the wooden ties and had local carpenter William Presley build them into a small tea house in commemoration of the president. The red and white (originally red, white, and blue) \\\"Garfield Tea House\\\" still survives, resting a couple of blocks away from the site of the cottage on the grounds of the Long Branch Historical Museum, a former Episcopal church. The church is nicknamed \\\"The Church of the Presidents,\\\" as it had been attended by, in addition to Garfield, presidents Chester A. Arthur, Ulysses S. Grant, Benjamin Harrison, Rutherford Hayes, William McKinley, and Woodrow Wilson, during their own visits to Long Branch.\\n\\n1901: William McKinley\\nWilliam McKinley was assassinated on September 6, 1901, inside the Temple of Music on the grounds of the Pan-American Exposition in Buffalo, New York. McKinley was shaking hands with the public when Leon Czolgosz, a Polish-American anarchist, shot him.
The 58-year-old president died eight days later on September 14 from gangrene caused by the bullet wounds.\\nMcKinley had been elected for a second term in 1900. He enjoyed meeting the public, and was reluctant to accept the security available to his office. The secretary to the president, George B. Cortelyou, feared an assassination attempt would take place during a visit to the Temple of Music, and twice took it off the schedule. McKinley restored it each time.\\nCzolgosz had lost his job during the economic Panic of 1893 and turned to anarchism, a political philosophy whose adherents had previously killed foreign leaders. Regarding McKinley as a symbol of oppression, Czolgosz felt it was his duty as an anarchist to kill him. Unable to get near McKinley during the earlier part of the presidential visit, Czolgosz shot McKinley twice as the President reached to shake his hand in the reception line at the temple. One bullet grazed McKinley; the other entered his abdomen and was never found.\\nMcKinley initially appeared to be recovering, but took a turn for the worse on September 13 as his wounds became gangrenous, and died early the next morning; Vice President Theodore Roosevelt succeeded him. Roosevelt was hiking near the top of Mt. Marcy, in New York's Adirondack region, when a runner located him to convey the news. After McKinley's murder, for which Czolgosz was put to death in the electric chair, the United States Congress passed legislation to officially charge the Secret Service with the responsibility for protecting the president.\\n\\n1923: Warren G. Harding\\nWarren G. Harding died from a sudden heart attack in his hotel suite while visiting San Francisco on the evening of August 2, 1923, at the age of 57. His death quickly led to theories that he had been poisoned or committed suicide. 
Rumors of poisoning were fueled, in part, by a book called The Strange Death of President Harding by private detective and former Ohio Gang member Gaston Means, who suggested that First Lady Florence Harding had poisoned her husband after learning of his infidelity. Mrs. Harding's refusal to allow an autopsy on President Harding only added to the speculation. According to the physicians attending Harding, however, the symptoms in the days prior to his death all pointed to congestive heart failure. Harding's biographer, Samuel H. Adams, concluded that \\\"Warren G. Harding died a natural death which, in any case, could not have been long postponed.\\\"\\nImmediately after President Harding's death, Mrs. Harding returned to Washington, D.C., and briefly stayed in the White House with the new president, Calvin Coolidge, and the new first lady. For a month, former first lady Harding gathered and burned President Harding's correspondence and documents, both official and unofficial. Upon her return to Marion, Ohio, Mrs. Harding hired a number of secretaries to collect and burn President Harding's personal papers. According to Mrs. Harding, she took these actions to protect her husband's legacy. The remaining papers were held and kept from public view by the Harding Memorial Association in Marion.
As Allen Drury later said, \\\"so ended an era, and so began another.\\\" After Roosevelt's death, an editorial in The New York Times declared, \\\"Men will thank God on their knees a hundred years from now that Franklin D. Roosevelt was in the White House.\\\"\\nIn his later years at the White House, when Roosevelt was increasingly overworked, his daughter Anna Roosevelt Boettiger had moved in to provide her father companionship and support. Anna had also arranged for her father to meet with his former mistress, the then widowed Lucy Mercer Rutherfurd. A close friend of both Roosevelt and Mercer who was present, Elizabeth Shoumatoff, rushed Mercer away to avoid negative publicity and implications of infidelity. When Eleanor heard about her husband's death, she was also faced with the news that Anna had been arranging these meetings with Mercer and that Mercer had been with Franklin when he died.\\nOn the morning of April 13, Roosevelt's body was placed in a flag-draped coffin and loaded onto the presidential train. After a White House funeral on April 14, Roosevelt was transported back to Hyde Park by train, guarded by four servicemen, one each from the Army, Navy, Marines, and Coast Guard. As was his wish, Roosevelt was buried in the Rose Garden of the Springwood estate, the Roosevelt family home in Hyde Park on April 15. Eleanor died in November 1962 and was buried next to him.\\nRoosevelt's death was met with shock and grief across the U.S. and around the world. His declining health had not been known to the general public. Roosevelt had been president for more than 12 years, longer than any other person, and had led the country through some of its greatest crises to the impending defeat of Nazi Germany and within sight of the defeat of Japan as well.\\nLess than a month after his death, on May 8, the war in Europe ended. President Harry S. Truman dedicated Victory in Europe Day and its celebrations to Roosevelt's memory, and kept the flags across the U.S. 
at half-staff for the remainder of the 30-day mourning period. In doing so, Truman said that his only wish was \\\"that Franklin D. Roosevelt had lived to witness this day.\\\"\\n\\n1963: John F. Kennedy\\nThe most recent U.S. president to die in office is John F. Kennedy, who was assassinated on November 22, 1963, in Dallas, Texas. He was fatally shot by Lee Harvey Oswald, who fired three shots from a sixth-floor window of the Texas School Book Depository at 12:30 p.m. as the presidential motorcade passed through Dealey Plaza. Riding in the vehicle with the president were First Lady Jackie Kennedy, Texas governor John Connally, and Connally's wife, Nellie; Governor Connally was also seriously wounded in the attack. The motorcade rushed to Parkland Memorial Hospital, where Kennedy was pronounced dead about 30 minutes later, at the age of 46. Connally recovered from his injuries.\\nVice President Lyndon B. Johnson, who was a few cars behind the president in the motorcade, became U.S. president upon Kennedy's death. He took the presidential oath of office on board Air Force One as it sat on the runway at Dallas Love Field. Oswald was arrested by the Dallas Police Department that afternoon and was charged under Texas state law with the murder of Kennedy, as well as that of Dallas policeman J. D. Tippit, who had been fatally shot a short time after the assassination. Two days later, on November 24, 1963, as live television cameras were covering his transfer from the city jail to the county jail, Oswald was fatally shot in the basement of Dallas Police Headquarters by Dallas nightclub operator Jack Ruby. Ruby was convicted of Oswald's murder, but the conviction was later overturned on appeal, and Ruby died in prison in 1967 while awaiting a new trial.\\nIn 1964, after a 10-month investigation into the assassination, the Warren Commission concluded that President Kennedy was assassinated by Lee Harvey Oswald and that Oswald had acted entirely alone.
It also concluded that Jack Ruby acted alone when he killed Oswald in police custody. Nonetheless, speculation over \\\"what really happened\\\" on November 22, 1963, in Dallas captured the public imagination during the decades that followed. Polls conducted from 1966 to 2004 found that as many as 80 percent of Americans had suspected that there was a criminal conspiracy or cover-up. Numerous books, films, television specials and websites have examined the assassination in minute detail, and many conspiracy theories have been advanced. Parties as varied as the FBI, the CIA, the Mafia, the Cuban and the Soviet governments, along with Kennedy's successor, Lyndon Johnson, have been identified as suspects. In an article published prior to the 50th anniversary of Kennedy's assassination, author Vincent Bugliosi estimated that a total of 42 groups, 82 assassins, and 214 people had been accused in conspiracy theories challenging the \\\"lone gunman\\\" theory.\\n\\nSee also\\nList of United States presidential assassination attempts\\nCurse of Tippecanoe\\n\\nNotes\\nReferences\\nBibliography\\nBauer, K. Jack (1985). Zachary Taylor: Soldier, Planter, Statesman of the Old Southwest. Louisiana State University Press. ISBN 0-8071-1237-2.\\nCleaves, Freeman (1939). Old Tippecanoe: William Henry Harrison and His Time. New York, NY: C. Scribner's Sons.\\nLeech, Margaret (1959). In the Days of McKinley. New York: Harper and Brothers. pp. 594–600. OCLC 456809.\\nMcCullough, David (1992). Truman. Simon & Schuster. ISBN 0-671-86920-5.\\nMillard, Candice (2011). Destiny of the Republic. Doubleday. ISBN 978-0-385-53500-7.\\nMiller, Scott (2011). The President and the Assassin. New York: Random House. pp. 56–60. ISBN 978-1-4000-6752-7.\\nPeskin, Allan (1978). Garfield. Kent State University Press. ISBN 0-87338-210-2.\\nVowell, Sarah (2005). Assassination Vacation. Simon & Schuster.
ISBN 0-7432-6003-1.\\n\\nExternal links\\nThe Mortal Presidency (Shapell Manuscript Foundation)\\n\\nJames Abram Garfield (November 19, 1831 – September 19, 1881) was the 20th president of the United States, serving from March 1881 until his assassination in September that year. A preacher, lawyer, and Civil War general, Garfield served nine terms in the United States House of Representatives and is the only sitting member of the House to be elected president. Before his candidacy for the presidency, he had been elected to the U.S. Senate by the Ohio General Assembly—a position he declined when he became president-elect.\\nGarfield was born into poverty in a log cabin and grew up in northeastern Ohio. After graduating from Williams College, he studied law and became an attorney. He was a preacher in the Stone–Campbell Movement and president of the Western Reserve Eclectic Institute, affiliated with the Disciples. Garfield was elected as a Republican member of the Ohio State Senate in 1859, serving until 1861. He opposed Confederate secession, was a major general in the Union Army during the American Civil War, and fought in the battles of Middle Creek, Shiloh, and Chickamauga. He was elected to Congress in 1862 to represent Ohio's 19th district. Throughout his congressional service, he firmly supported the gold standard and gained a reputation as a skilled orator. He initially agreed with Radical Republican views on Reconstruction but later favored a Moderate Republican–aligned approach to civil rights enforcement for freedmen. Garfield's aptitude for mathematics extended to his own proof of the Pythagorean theorem, which he published in 1876.\\nAt the 1880 Republican National Convention, delegates chose Garfield, who had not sought the White House, as a compromise presidential nominee on the 36th ballot.
In the 1880 presidential election, he conducted a low-key front porch campaign and narrowly defeated the Democratic nominee, Winfield Scott Hancock. Garfield's accomplishments as president included his assertion of presidential authority against senatorial courtesy in executive appointments, a purge of corruption in the Post Office, and his appointment of a Supreme Court justice. He advocated for agricultural technology, an educated electorate, and civil rights for African Americans. He also proposed substantial civil service reforms, which were passed by Congress in 1883 as the Pendleton Civil Service Reform Act and signed into law by his successor, Chester A. Arthur.\\nGarfield, a member of the intraparty \\\"Half-Breed\\\" faction, used the powers of the presidency to defy the powerful \\\"Stalwart\\\" Senator Roscoe Conkling of New York. He did this by appointing Blaine faction leader William H. Robertson to the lucrative post of Collector of the Port of New York. The ensuing political battle resulted in Robertson's confirmation and the resignations of Conkling and Thomas C. Platt from the Senate.\\nOn July 2, 1881, Charles J. Guiteau, a disappointed and delusional office seeker, shot Garfield at the Baltimore and Potomac Railroad Station in Washington. The wound was not immediately fatal, but an infection caused by his doctors' unsanitary methods in treating the wound killed Garfield on September 19. Due to his brief tenure in office, historians tend to rank Garfield as a below-average president, though he has earned praise for anti-corruption and pro-civil rights stances.\\n\\nChildhood and early life\\nJames Abram Garfield was born the youngest of five children on November 19, 1831, in a log cabin in Orange Township, now Moreland Hills, Ohio. Garfield's ancestor Edward Garfield migrated from Hillmorton, Warwickshire, England, to Massachusetts around 1630.
James's father Abram was born in Worcester, New York, and came to Ohio to woo his childhood sweetheart, Mehitabel Ballou, only to find her married. He instead wed her sister Eliza, who was born in New Hampshire. James was named after an earlier son of Eliza and Abram who had died in infancy.\\nIn early 1833, Abram and Eliza Garfield joined a Stone-Campbell church, a decision that influenced their youngest son's life. Abram died later that year, and James was raised in poverty in a household led by his strong-willed mother. He was her favorite child and the two remained close for the rest of his life. Eliza remarried in 1842, but soon left her second husband, Warren (or Alfred) Belden, and a scandalous divorce was awarded in 1850. James took his mother's side in the matter and noted Belden's 1880 death with satisfaction in his diary. Garfield also enjoyed his mother's stories about his ancestry, especially those about his Welsh great-great-grandfathers and an ancestor who served as a knight of Caerphilly Castle.\\nPoor and fatherless, Garfield was mocked by his peers and became sensitive to slights throughout his life; he sought escape through voracious reading. He left home at age 16 in 1847 and was rejected for work on the only ship in port in Cleveland. Garfield instead found work on a canal boat, managing the mules that pulled it. Horatio Alger later used this labor to good effect when he wrote Garfield's campaign biography in 1880.\\nAfter six weeks, illness forced Garfield to return home, and during his recuperation, his mother and a local school official secured his promise to forgo canal work for a year of school. In 1848, he began at Geauga Seminary, in nearby Chester Township, Geauga County, Ohio. Garfield later said of his childhood, \\\"I lament that I was born to poverty, and in this chaos of childhood, seventeen years passed before I caught any inspiration ... 
a precious 17 years when a boy with a father and some wealth might have become fixed in manly ways.\\\"\\n\\nEducation, marriage and early career\\nGarfield attended Geauga Seminary from 1848 to 1850 and learned academic subjects for which he had not previously had time. He excelled as a student and was especially interested in languages and elocution. He began to appreciate the power a speaker had over an audience, writing that the speaker's platform \\\"creates some excitement. I love agitation and investigation and glory in defending unpopular truth against popular error.\\\" Geauga was coeducational, and Garfield was attracted to one of his classmates, Lucretia Rudolph, whom he later married. To support himself at Geauga, he worked as a carpenter's assistant and teacher. The need to go from town to town to find work as a teacher aggravated Garfield, and he developed a dislike of what he called \\\"place-seeking\\\", which became, he said, \\\"the law of my life.\\\" In later years, he astounded his friends by disregarding positions that could have been his with little politicking. Garfield had attended church more to please his mother than to worship God, but in his late teens he underwent a religious awakening. He attended many camp meetings, which led to his being born again on March 4, 1850, when he was baptized into Christ by being submerged in the icy waters of the Chagrin River.\\nAfter he left Geauga, Garfield worked for a year at various jobs, including teaching jobs. Finding that some New Englanders worked their way through college, Garfield determined to do the same and sought a school that could prepare him for the entrance examinations. From 1851 to 1854, he attended the Western Reserve Eclectic Institute (later named Hiram College) in Hiram, Ohio, a school founded by and still affiliated with the Christian Church (Disciples of Christ). 
While there, he was most interested in the study of Greek and Latin but was inclined to learn about and discuss any new thing he encountered. Having secured a position as janitor on entry, he obtained a teaching position while he was still a student there. Lucretia Rudolph also enrolled at the Institute and Garfield wooed her while teaching her Greek. He developed a regular preaching circuit at neighboring churches and, in some cases, earned one gold dollar per service. By 1854, Garfield had learned all the Institute could teach him and was a full-time teacher. Garfield then enrolled at Williams College in Williamstown, Massachusetts, as a third-year student; he received credit for two years' study at the Institute after passing a cursory examination. Garfield was also impressed with the college president, Mark Hopkins, who had responded warmly to Garfield's letter inquiring about admission. He said of Hopkins, \\\"The ideal college is Mark Hopkins on one end of a log with a student on the other.\\\" Hopkins later said of Garfield in his student days, \\\"There was a large general capacity applicable to any subject. There was no pretense of genius, or alternation of spasmodic effort, but a satisfactory accomplishment in all directions.\\\" After his first term, Garfield was hired to teach penmanship to the students of nearby Pownal, Vermont, a post Chester A. Arthur had previously held.\\n\\nGarfield graduated Phi Beta Kappa from Williams in August 1856, was named salutatorian, and spoke at the commencement. His biographer Ira Rutkow writes that Garfield's years at Williams gave him the opportunity to know and respect those of different social backgrounds, and that, despite his origin as an unsophisticated Westerner, socially conscious New Englanders liked and respected him.
\\\"In short,\\\" Rutkow writes, \\\"Garfield had an extensive and positive first experience with the world outside the Western Reserve of Ohio.\\\"\\nUpon his return to Ohio, Garfield found that his degree from a prestigious Eastern college made him a man of distinction. He returned to Hiram to teach at the Institute and in 1857 was made its principal, though he did not see education as a field that would realize his full potential. The abolitionist atmosphere at Williams had enlightened him politically, after which he began to consider politics as a career. He campaigned for Republican presidential candidate John C. Frémont in 1856. In 1858, he married Lucretia, and they had seven children, five of whom survived infancy. Soon after the wedding, he registered to read law at the office of attorney Albert Gallatin Riddle in Cleveland, though he did his studying in Hiram. He was admitted to the bar in 1861.\\nLocal Republican leaders invited Garfield to enter politics upon the death of Cyrus Prentiss, the presumptive nominee for the local state senate seat. He was nominated at the party convention on the sixth ballot and was elected, serving from 1860 to 1861. Garfield's major effort in the state senate was an unsuccessful bill providing for Ohio's first geological survey to measure its mineral resources.\\n\\nCivil War\\nAfter Abraham Lincoln's election as president, several Southern states announced their secession from the Union to form a new government, the Confederate States of America. Garfield read military texts while anxiously awaiting the outbreak of war, which he regarded as a holy crusade against the Slave Power. In April 1861, the rebels bombarded Fort Sumter, one of the South's last federal outposts, beginning the Civil War.
Although he had no military training, Garfield knew his place was in the Union Army.\\nAt Governor William Dennison's request, Garfield deferred his military ambitions to remain in the legislature, where he helped appropriate the funds to raise and equip Ohio's volunteer regiments. When the legislature adjourned, Garfield spent the spring and early summer on a speaking tour of northeastern Ohio, encouraging enlistment in the new regiments. Following a trip to Illinois to purchase muskets, Garfield returned to Ohio and, in August 1861, received a commission as a colonel in the 42nd Ohio Infantry regiment. The 42nd Ohio existed only on paper, so Garfield's first task was to fill its ranks. He did so quickly, recruiting many of his neighbors and former students. The regiment traveled to Camp Chase, outside Columbus, Ohio, to complete training. In December, Garfield was ordered to bring the 42nd to Kentucky, where they joined the Army of the Ohio under Brigadier General Don Carlos Buell.\\n\\nBuell's command\\nBuell quickly assigned Garfield the task of driving Confederate forces out of eastern Kentucky, giving him the 18th Brigade for the campaign, which, besides his own 42nd, included the 40th Ohio Infantry, two Kentucky infantry regiments and two cavalry units. They departed Catlettsburg, Kentucky, in mid-December, advancing through the valley of the Big Sandy River. The march was uneventful until Union forces reached Paintsville, Kentucky, on January 6, 1862, where Garfield's cavalry engaged the rebels at Jenny's Creek. Confederate troops under Brigadier General Humphrey Marshall held the town in numbers roughly equal to Garfield's own, but Garfield positioned his troops so as to deceive Marshall into believing the rebels were outnumbered. Marshall ordered his troops to withdraw to the forks of Middle Creek, on the road to Virginia, and Garfield ordered his troops to take up the pursuit.
They attacked the rebel positions on January 9, 1862, in the Battle of Middle Creek, the only pitched battle Garfield commanded personally. At the fighting's end, the Confederates withdrew from the field and Garfield sent his troops to Prestonsburg to reprovision.\\n\\nIn recognition of his success, Garfield was promoted to brigadier general. After Marshall's retreat, Garfield's command was the sole remaining Union force in eastern Kentucky and he announced that any men who had fought for the Confederacy would be granted amnesty if they returned to their homes, lived peaceably, and remained loyal to the Union. The proclamation was surprisingly lenient, as Garfield now believed the war was a crusade for eradication of slavery. Following a brief skirmish at Pound Gap, the last rebel units in the area were outflanked and retreated to Virginia.\\nGarfield's promotion gave him command of the 20th Brigade of the Army of the Ohio, which received orders to join Major General Ulysses S. Grant's forces as they advanced on Corinth, Mississippi, in early 1862. Before the 20th Brigade arrived, however, Confederate forces under General Albert Sidney Johnston surprised Grant's men in their camps, driving them back. Garfield's troops received word of the battle and advanced quickly, joining the rest of the army on the second day to drive the Confederates back across the field and into retreat. The action, later known as the Battle of Shiloh, was the bloodiest of the war to date; Garfield was exposed to fire for much of the day, but emerged uninjured. Major General Henry W. Halleck, Grant's superior, took charge of the combined armies and advanced ponderously toward Corinth; when they arrived, the Confederates had fled.\\nThat summer, Garfield suffered from jaundice and significant weight loss. He was forced to return home, where his wife nursed him back to health. 
While he was home, Garfield's friends worked to gain him the Republican nomination for Congress, but he refused to campaign with the delegates. He returned to military duty that autumn and went to Washington to await his next assignment. During this period of idleness, a rumor of an extramarital affair caused friction in the Garfields' marriage until Lucretia eventually chose to overlook it. Garfield repeatedly received tentative assignments that were quickly withdrawn, to his frustration. In the meantime, he served on the court-martial of Fitz John Porter for his tardiness at the Second Battle of Bull Run. He was convinced of Porter's guilt and voted with his fellow generals to convict Porter. The trial lasted almost two months, from November 1862 to January 1863, and, by its end, Garfield had procured an assignment as chief of staff to Major General William S. Rosecrans.\\n\\nChief of staff for Rosecrans\\nGenerals' chiefs of staff were usually more junior officers, but Garfield's influence with Rosecrans was greater than usual, with duties extending beyond communication of orders to actual management of his Army of the Cumberland. Rosecrans had a voracious appetite for conversation, especially when unable to sleep; in Garfield, he found \\\"the first well read person in the Army\\\" and the ideal candidate for discussions that ran deep into the night. They discussed everything, especially religion, and the two became close despite Garfield's being 12 years his junior. Rosecrans, who had converted from Methodism to Roman Catholicism, softened Garfield's view of his faith.\\nGarfield recommended that Rosecrans replace wing commanders Alexander McCook and Thomas Crittenden, as he believed they were ineffective, but Rosecrans ignored the suggestion. With Rosecrans, Garfield devised the Tullahoma Campaign to pursue and trap Confederate General Braxton Bragg in Tullahoma. 
After initial Union success, Bragg retreated toward Chattanooga, where Rosecrans stalled and requested more troops and supplies. Garfield argued for an immediate advance, in line with demands from Halleck and Lincoln. After a council of war and lengthy deliberations, Rosecrans agreed to attack.\\nAt the ensuing Battle of Chickamauga on September 19 and 20, 1863, confusion among the wing commanders over Rosecrans's orders created a gap in the lines, resulting in a rout of the right flank. Rosecrans concluded that the battle was lost and fell back on Chattanooga to establish a defensive line. Garfield, however, thought part of the army had held and, with Rosecrans's approval, headed across Missionary Ridge to survey the scene. Garfield's hunch was correct. Consequently, his ride became legendary and Rosecrans's error reignited criticism about the latter's leadership. While Rosecrans's army had avoided disaster, they were stranded in Chattanooga, surrounded by Bragg's army. Garfield sent a telegram to Secretary of War Edwin M. Stanton alerting Washington to the need for reinforcements to avoid annihilation. Lincoln and Halleck responded to the request for reinforcements by sending 20,000 troops to Garfield by rail within nine days. In the meantime, Grant was promoted to command of the western armies and quickly replaced Rosecrans with George H. Thomas. Garfield was ordered to report to Washington, where he was promoted to major general. According to historian Jean Edward Smith, Grant and Garfield had a \\\"guarded relationship\\\" since Grant promoted Thomas, rather than Garfield, to command of the Army of the Cumberland after Rosecrans's dismissal.\\n\\nCongressional career\\nElection in 1862; Civil War years\\nWhile he served in the Army in early 1862, friends of Garfield approached him about running for Congress from Ohio's newly redrawn and heavily Republican 19th district. 
He worried that he and other state-appointed generals would receive obscure assignments, and running for Congress would allow him to resume his political career. That the new Congress would not hold its first regular session until December 1863 allowed him to continue his war service for a time. Home on medical leave, he refused to campaign for the nomination, leaving that to political managers who secured it at the local convention in September 1862 on the eighth ballot. In the October general election, he defeated D.B. Woods by a two-to-one margin for a seat in the 38th Congress.\\nDays before his Congressional term began, Garfield lost his eldest daughter, three-year-old Eliza, and became anxious and conflicted, saying his \\\"desolation of heart\\\" might require his return to \\\"the wild life of the army.\\\" He also assumed that the war would end before his joining the House, but it had not, and he felt strongly that he belonged in the field, rather than in Congress. He also thought he could expect a favorable command, so he decided to see President Lincoln. During their meeting, Lincoln recommended he take his House seat, as there was an excess of generals and a shortage of administration congressmen, especially those with knowledge of military affairs. Garfield accepted this recommendation and resigned his military commission to do so.\\nGarfield met and befriended Treasury Secretary Salmon P. Chase, who saw Garfield as a younger version of himself. The two agreed politically and both were part of the Radical wing of the Republican Party. Once he took his seat in December 1863, Garfield was frustrated at Lincoln's reluctance to press the South hard. Many radicals, led in the House by Pennsylvania's Thaddeus Stevens, wanted rebel-owned lands confiscated, but Lincoln threatened to veto any bill that proposed to do so on a widespread basis. 
In debate on the House floor, Garfield supported such legislation and, discussing England's Glorious Revolution, hinted that Lincoln might be thrown out of office for resisting it. Garfield had supported Lincoln's Emancipation Proclamation and marveled at the \\\"strange phenomenon in the world's history, when a second-rate Illinois lawyer is the instrument to utter words which shall form an epoch memorable in all future ages.\\\"\\nGarfield not only favored the abolition of slavery, but also believed the leaders of the rebellion had forfeited their constitutional rights. He supported the confiscation of Southern plantations and even exile or execution of rebellion leaders as a means to ensure a permanent end to slavery. Garfield felt Congress had an obligation \\\"to determine what legislation is necessary to secure equal justice to all loyal persons, without regard to color.\\\" He was more supportive of Lincoln when the president took action against slavery.\\nGarfield showed leadership early in his congressional career; he was initially the only Republican vote to terminate the use of bounties in military recruiting. Some financially able recruits had used the bounty system to buy their way out of service (called commutation), which Garfield considered reprehensible. He gave a speech pointing out the flaws in the existing conscription law: 300,000 recruits had been called upon to enlist, but barely 10,000 had done so, with the remainder claiming exemption, providing money, or recruiting a substitute. Lincoln appeared before the Military Affairs committee on which Garfield served, demanding a more effective bill; even if it cost him reelection, Lincoln was confident he could win the war before his term expired.
After many false starts, Garfield, with Lincoln's support, procured the passage of a conscription bill that excluded commutation.\\nUnder Chase's influence, Garfield became a staunch proponent of a dollar backed by a gold standard, and strongly opposed the \\\"greenback\\\". He accepted, with strong reluctance, the necessity of suspending payment in gold or silver during the Civil War. He voted with the Radical Republicans in passing the Wade–Davis Bill, designed to give Congress more authority over Reconstruction, but Lincoln defeated it with a pocket veto.\\nGarfield did not consider Lincoln very worthy of reelection, but there seemed to be no viable alternative. \\\"He will probably be the man, though I think we could do better\\\", he said. Garfield attended the party convention and promoted Rosecrans as Lincoln's running mate, but delegates chose Military Governor of Tennessee Andrew Johnson. Lincoln was reelected, as was Garfield. By then, Chase had left the Cabinet and been appointed Chief Justice, and his relations with Garfield became more distant.\\nGarfield took up the practice of law in 1865 to improve his personal finances. His efforts took him to Wall Street where, the day after Lincoln's assassination, a riotous crowd drew him into an impromptu speech to calm their passions: \\\"Fellow citizens! Clouds and darkness are round about Him! His pavilion is dark waters and thick clouds of the skies! Justice and judgment are the establishment of His throne! Mercy and truth shall go before His face! Fellow citizens! God reigns, and the Government at Washington still lives!\\\" The speech, with no mention or praise of Lincoln, was, according to Garfield biographer Robert G.
Caldwell, \\\"quite as significant for what it did not contain as for what it did.\\\" In the following years, Garfield had more praise for Lincoln; a year after Lincoln's death, Garfield said, \\\"Greatest among all these developments were the character and fame of Abraham Lincoln,\\\" and in 1878 he called Lincoln \\\"one of the few great rulers whose wisdom increased with his power\\\".\\nWhen in Washington, Garfield attended Vermont Avenue Christian Church, which later became National City Christian Church, a building constructed and funded by the Disciples.\\n\\nReconstruction\\nIn 1864, the U.S. Senate passed the 13th Amendment, which abolished slavery throughout the Union. The amendment did not pass the House by the required two-thirds majority until January 31, 1865, when it was sent to the states for ratification. The amendment opened other issues concerning African American civil rights. Garfield asked, \\\"[What] is freedom? Is it the bare privilege of not being chained?...If this is all, then freedom is a bitter mockery, a cruel delusion.\\\"\\nGarfield supported black suffrage as firmly as he supported abolition. President Johnson sought the rapid restoration of the Southern states during the months between his accession and the meeting of Congress in December 1865; Garfield hesitantly supported this policy as an experiment. Johnson, an old friend, sought Garfield's backing and their conversations led Garfield to assume Johnson's differences with Congress were not large. When Congress assembled in December (to Johnson's chagrin, without the elected representatives of the Southern states, who were excluded), Garfield urged conciliation on his colleagues, although he feared that Johnson, a former Democrat, might join other Democrats to gain political control. Garfield foresaw conflict even before February 1866, when Johnson vetoed a bill to extend the life of the Freedmen's Bureau, charged with aiding the former slaves.
By April, Garfield had concluded that Johnson was either \\\"crazy or drunk with opium.\\\"\\n\\nThe conflict between Congress and President Johnson was the major issue of the 1866 campaign, with Johnson taking to the campaign trail in a Swing Around the Circle and Garfield facing opposition within the Republican party in his home district. With the South still disenfranchised and Northern public opinion behind the Republicans, they gained a two-thirds majority in both houses of Congress. Garfield, having overcome his challengers at the district nominating convention, won reelection easily.\\nGarfield initially opposed the proposed impeachment of Johnson when Congress convened in December 1866, but supported legislation to limit Johnson's powers, such as the Tenure of Office Act, which restricted Johnson's ability to remove presidential appointees. Distracted by committee duties, Garfield spoke about these bills rarely, but was a loyal Republican vote against Johnson.\\nOn January 7, 1867, Garfield voted in support of the resolution that launched the first impeachment inquiry against Johnson (run by the House Committee on the Judiciary). On December 7, 1867, he voted against the unsuccessful resolution to impeach Johnson that the House Committee on the Judiciary had sent the full House. On January 27, 1868, he voted to pass the resolution that authorized the second impeachment inquiry against Johnson (run by the House Select Committee on Reconstruction). Due to a court case, he was absent on February 24, 1868, when the House impeached Johnson, but gave a speech aligning himself with Thaddeus Stevens and others who sought Johnson's removal shortly thereafter. Garfield was present on March 2 and 3, 1868, when the House voted on specific articles of impeachment, and voted in support of all 11 articles. During the March 2 debate on the articles, Garfield argued that what he characterized as Johnson's attempts to render Ulysses S.
Grant, William Tecumseh Sherman, and William H. Emory personal tools of his demonstrated Johnson's intent to disregard the law and override the Constitution, suggesting that Johnson's trial perhaps could be expedited to last only a day in order to hasten his removal. When Johnson was acquitted in his trial before the Senate, Garfield was shocked and blamed the outcome on the trial's presiding officer, Chief Justice Chase, his onetime mentor.\\nBy the time Grant succeeded Johnson in 1869, Garfield had moved away from the remaining radicals (Stevens, their leader, had died in 1868). By this time, many in the Republican Party wanted to remove the \\\"Negro question\\\" from national affairs. Garfield hailed the ratification of the 15th Amendment in 1870 as a triumph and favored Georgia's readmission to the Union as a matter of right, not politics. An influential Republican, Garfield said, \\\"[The] Fifteenth Amendment confers on the African race the care of its own destiny. It places their fortunes in their own hands.\\\" In 1871, Congress took up the Ku Klux Klan Act, which was designed to combat attacks on African Americans' suffrage rights. Garfield opposed the act, saying, \\\"I have never been more perplexed by a piece of legislation.\\\" He was torn between his indignation at the Klan, whom he called \\\"terrorists\\\", and his concern for the power given the president to enforce the act through suspension of habeas corpus.\\n\\nTariffs and finance\\nThroughout his political career, Garfield favored the gold standard and decried attempts to increase the money supply through the issuance of paper money not backed by gold, and later, through the free and unlimited coinage of silver. In 1865, he was put on the House Ways and Means Committee, a long-awaited opportunity to focus on financial and economic issues.
He reprised his opposition to the greenback, saying, \\\"Any party which commits itself to paper money will go down amid the general disaster, covered with the curses of a ruined people.\\\" In 1868 Garfield gave a two-hour speech on currency in the House, which was widely applauded as his best oratory to that point; in it, he advocated a gradual resumption of specie payments, that is, the government paying out silver and gold, rather than paper money that could not be redeemed.\\nTariffs had been raised to high levels during the Civil War. Afterward, Garfield, who made a close study of financial affairs, advocated moving toward free trade, though the standard Republican position was a protective tariff that would allow American industries to grow. This break with his party likely cost him his place on the Ways and Means Committee in 1867, and though Republicans held the majority in the House until 1875, Garfield remained off that committee. Garfield came to chair the powerful House Appropriations Committee, but it was Ways and Means, with its influence over fiscal policy, that he really wanted to lead. One reason he was denied a place on Ways and Means was the opposition of the influential Republican editor Horace Greeley.\\n\\nStarting in January 1870, Garfield, then chairman of the House Banking Committee, led an investigation into the Black Friday Gold Panic scandal. In 1869, during Grant's first term in office, two New York conspirators, Jay Gould and James Fisk, launched a scheme to corner the gold market. The conspiracy was broken on Friday, September 24, 1869, when Grant and Treasury Secretary George Boutwell released gold into the market, causing widespread financial panic. During the investigation, rumors spread that Grant's family might have been involved. In order not to force Grant's wife to testify, Garfield had a private meeting with Grant at the White House. 
When Garfield showed Grant testimony about him and his family, Grant thanked Garfield but refused to read it or give a response. Grant personally resented Garfield for investigating Black Friday and for questioning whether his wife Julia had been involved in the scandal.\\nGarfield's investigation and final majority report, released on September 12, 1870, were thorough but found no indictable offenses and exonerated Grant and Julia of wrongdoing. Garfield thought the scandal was enabled by the greenbacks that financed the speculation. Garfield was not at all enthused about President Grant's reelection in 1872—until Greeley, who emerged as the candidate of the Democrats and Liberal Republicans, became the only serious alternative. Garfield said, \\\"I would say Grant was not fit to be nominated and Greeley is not fit to be elected.\\\" Both Grant and Garfield were overwhelmingly reelected.\\n\\nCrédit Mobilier scandal; salary grab\\nThe Crédit Mobilier of America scandal involved corruption in the financing of the Union Pacific Railroad, part of the transcontinental railroad which was completed in 1869. Union Pacific officers and directors secretly purchased control of the Crédit Mobilier of America company, then contracted with it to undertake construction of the railroad. The railroad paid the company's grossly inflated invoices with federal funds appropriated to subsidize the project, and the company was allowed to purchase Union Pacific securities at par value, well below the market rate. Crédit Mobilier showed large profits and stock gains, and distributed substantial dividends. The high expenses meant Congress was called upon to appropriate more funds. One of the railroad officials who controlled Crédit Mobilier was also a congressman, Oakes Ames of Massachusetts.
He offered some of his colleagues the opportunity to buy Crédit Mobilier stock at par value, well below what it sold for on the market, and the railroad got its additional appropriations.\\n\\nThe story broke in July 1872, in the middle of the presidential campaign. Among those named were Vice President Schuyler Colfax, Massachusetts Senator Henry Wilson (the Republican candidate for vice president), Speaker James G. Blaine of Maine, and Garfield. Greeley had little luck taking advantage of the scandal. When Congress reconvened after the election, Blaine, seeking to clear his name, demanded a House investigation. Evidence before the special committee exonerated Blaine. Garfield had said in September 1872 that Ames had offered him stock but he had repeatedly refused it. Testifying before the committee in January, Ames said he had offered Garfield ten shares of stock at par value, but that Garfield had never taken them or paid for them, though a year passed, from 1867 to 1868, before Garfield had finally refused. Appearing before the committee on January 14, 1873, Garfield confirmed much of this. Ames testified several weeks later that Garfield agreed to take the stock on credit, and that it was paid for by the company's huge dividends. The two men differed over $300 that Garfield received and later paid back, with Garfield deeming it a loan and Ames a dividend.\\nGarfield's biographers have been unwilling to exonerate him in the scandal. Allan Peskin writes, \\\"Did Garfield lie? Not exactly. Did he tell the truth? Not completely. Was he corrupted? Not really. 
Even Garfield's enemies never claimed that his involvement in the affair influenced his behavior.\\\" Rutkow writes, \\\"Garfield's real offense was that he knowingly denied to the House investigating committee that he had agreed to accept the stock and that he had also received a dividend of $329.\\\" Caldwell suggests Garfield \\\"told the truth [before the committee, but] certainly failed to tell the whole truth, clearly evading an answer to certain vital questions and thus giving the impression of worse faults than those of which he was guilty.\\\" That Crédit Mobilier was a corrupt organization had been a badly kept secret, even mentioned on the floor of Congress, and editor Sam Bowles wrote at the time that Garfield, in his positions on committees dealing with finance, \\\"had no more right to be ignorant in a matter of such grave importance as this, than the sentinel has to snore on his post.\\\"\\nAnother issue that caused Garfield trouble in his 1874 reelection bid was the so-called \\\"Salary Grab\\\" of 1873, which increased the compensation for members of Congress by 50%, retroactive to 1871. As chairman of the Appropriations Committee, Garfield was responsible for shepherding the appropriations bill through the House; during the debate in February 1873, Massachusetts Representative Benjamin Butler offered the increase as an amendment, and despite Garfield's opposition, it passed the House and eventually became law. The law was very popular in the House, as almost half the members were lame ducks, but the public was outraged, and many of Garfield's constituents blamed him, though he personally refused to accept the increase. 
In a bad year for Republicans, who lost control of the House for the first time since the Civil War, Garfield had his closest congressional election, winning with only 57% of the vote.\\n\\nFloor leader; Hayes administration\\nThe Democratic takeover of the House of Representatives in 1875 meant the loss of Garfield's chairmanship of the Appropriations Committee, though the Democrats did put him on the Ways and Means Committee. With many of his leadership rivals defeated in the 1874 Democratic landslide, and Blaine elected to the Senate, Garfield was seen as the Republican floor leader, and the likely Speaker, should the party regain control of the chamber.\\nGarfield thought that giving land grants to expanding railroads was an unjust practice. He also opposed monopolistic practices by corporations, as well as the power sought by workers' unions. He supported the proposed establishment of the United States civil service as a means of ridding officials of the annoyance of aggressive office seekers. He especially wished to eliminate the practice of forcing government workers, in exchange for their positions, to kick back a percentage of their wages as political contributions.\\nAs the 1876 presidential election approached, Garfield was loyal to the candidacy of Senator Blaine, and fought for the former Speaker's nomination at the 1876 Republican National Convention in Cincinnati. When it became clear, after six ballots, that Blaine could not prevail, the convention nominated Ohio Governor Rutherford B. Hayes. Although Garfield had supported Blaine, he had kept good relations with Hayes, and wholeheartedly supported the governor. Garfield had hoped to retire from politics after his term expired to devote himself full-time to the practice of law, but to help his party, he sought re-election, and won it easily that October.
Any celebration was short-lived, as Garfield's youngest son, Neddie, fell ill with whooping cough shortly after the congressional election, and soon died.\\n\\nWhen Hayes appeared to have lost the presidential election the following month to Democrat Samuel Tilden, the Republicans launched efforts to reverse the results in South Carolina, Louisiana, and Florida, where they held the governorship. If Hayes won all three states, he would take the election by a single electoral vote. Grant asked Garfield to serve as a \\\"neutral observer\\\" of the recount in Louisiana. The observers soon recommended to the state electoral commissions that Hayes be declared the winner—Garfield recommended the entire vote of West Feliciana Parish, which had given Tilden a sizable majority, be thrown out. The Republican governors of the three states certified that Hayes had won their states, to the outrage of Democrats, who had the state legislatures submit rival returns, and threatened to prevent the counting of the electoral vote—under the Constitution, Congress is the final arbiter of the election. Congress then established an Electoral Commission, consisting of eight Republicans and seven Democrats, to determine the winner. Despite his objection to the Commission, Garfield was appointed to it. He felt Congress should count the vote and proclaim Hayes victorious. Hayes emerged the victor by a party line vote of 8–7. In exchange for recognizing Hayes as president, Southern Democrats secured the removal of federal troops from the South, ending Reconstruction.\\nAlthough an Ohio Senate seat would be vacated by the resignation of John Sherman to become Treasury Secretary, Hayes needed Garfield's expertise to protect him from the agenda of a hostile Congress, and asked him not to seek it. Garfield agreed. As Hayes's key legislator in the House, he gained considerable prestige and respect for his role there. 
When Congress debated the Bland–Allison Act, which required the government to purchase large quantities of silver and strike it into legal tender dollar coins, Garfield opposed it as a deviation from the gold standard; it was enacted over Hayes's veto in February 1878.\\nIn 1876, Garfield purchased the property in Mentor that reporters later dubbed Lawnfield, where he conducted the first successful front porch campaign for the presidency. Hayes suggested that Garfield run for governor in 1879, seeing that as a road likely to take Garfield to the White House. Garfield preferred to seek election as a U.S. senator. Secretary Sherman was mentioned as a rival for the seat, but he had presidential ambitions (for which he sought Garfield's support), and other candidates fell by the wayside. The General Assembly elected Garfield to the Senate in January 1880, though his term was not scheduled to commence until March 4, 1881.\\n\\nLegal career and other activities\\nIn 1865, Garfield became a partner in the law firm of a fellow Disciple of Christ, Jeremiah Black. They had much in common, except politics: Black was an avid Democrat, having served in the cabinet of President James Buchanan. The next year, Black was retained by some pro-Confederate northern civilians who had been found guilty of treason in a military court and sentenced to death. Black saw an opportunity to strike a blow against military courts and the Republicans. He had heard Garfield's speeches on military matters and knew not only of his oratorical skill but also of his opposition to the expansive powers of military commissions. Black assigned the case to Garfield one week before arguments were to be made before the U.S. Supreme Court. When Black warned him of the political peril, Garfield responded, \\\"It don't make any difference.
I believe in English liberty and English law.\\\" In this landmark case, Ex parte Milligan, Garfield successfully argued that civilians could not be tried before military tribunals, despite a declaration of martial law, as long as civil courts were still operating. In his first court appearance, Garfield delivered an oral argument lasting over two hours, and though his wealthy clients refused to pay him, he established himself as a preeminent lawyer.\\nDuring Grant's first term, Garfield was discontented with public service and in 1872 again pursued opportunities in the law. But he declined a partnership offer from a Cleveland law firm when told his prospective partner was of \\\"intemperate and licentious\\\" reputation. In 1873, after Chase's death, Garfield appealed to Grant to appoint Justice Noah H. Swayne Chief Justice, but Grant appointed Morrison R. Waite.\\n\\nIn 1871, Garfield traveled to Montana Territory to negotiate the removal of the Bitterroot Salish tribe to the Flathead Indian Reservation. Having been told that the people would happily move, Garfield expected an easy task. Instead, he found the Salish determined to stay in their Bitterroot Valley homeland. His attempts to coerce Chief Charlo to sign the agreement nearly brought about a military clash. In the end, he convinced two subchiefs to sign and move to the reservation with a few of the Salish people. Garfield never convinced Charlo to sign, although the official treaty document voted on by Congress bore his forged mark.\\nIn 1876, Garfield developed a trapezoid proof of the Pythagorean theorem, which was published in the New England Journal of Education. Mathematics historian William Dunham wrote that Garfield's trapezoid work was \\\"really a very clever proof.\\\" According to the Journal, Garfield arrived at the proof \\\"in mathematical amusements and discussions with other members of congress.\\\"\\nAfter his conversion experience in 1850, Garfield made religious inquiry a high priority.
He read widely and moved beyond the confines of his early experience as a member of the Disciples of Christ. His new, broader perspective was rooted in his devotion to freedom of inquiry and his study of history. The intensity of Garfield's religious thought was also influenced by his experience in combat and his interaction with voters.\\n\\nPresidential election of 1880\\nRepublican nomination\\nHaving just been elected to the Senate with John Sherman's support, Garfield was committed to Sherman for the 1880 Republican presidential nomination. Before the convention began, however, a few Republicans, including Wharton Barker of Philadelphia, thought Garfield the best choice for the nomination. Garfield denied any interest in the position, but the attention was enough to make Sherman suspicious of his lieutenant's ambitions. Besides Sherman, the early favorites for the nomination were Blaine and former President Grant; several other candidates attracted delegates as well.\\nThe Republican Party at the time was split into two factions: the \\\"Stalwarts\\\", who supported the existing federal government patronage system, and the \\\"Half-Breeds\\\", who wanted civil service reform. As the convention began, New York Senator Roscoe Conkling, floor leader for the Stalwarts, who supported former President Ulysses S. Grant, proposed that the delegates pledge to back the eventual nominee in the general election. When three West Virginia delegates declined to be so bound, Conkling sought to expel them from the convention. Garfield rose to defend the men, giving a passionate speech upholding their right to reserve judgment. The crowd turned against Conkling, and he withdrew the motion.
The performance delighted Garfield's boosters, who were then convinced he was the only one who could attract a majority of the delegates' votes.\\nAfter speeches in favor of the other front-runners, Garfield rose to place Sherman's name in nomination; his speech was well-received, but the delegates mustered little excitement for Sherman as the next president. The first ballot showed Grant leading with 304 votes to Blaine's 284, and Sherman's 93 votes placed him in a distant third. Subsequent ballots demonstrated a deadlock between Grant and Blaine, with neither having the 379 votes needed for nomination. Jeremiah McLain Rusk, a member of the Wisconsin delegation, and Benjamin Harrison, an Indiana delegate, sought to break the deadlock by shifting a few of the anti-Grant votes to a dark horse candidate—Garfield. Garfield gained 50 votes on the 35th ballot, and a stampede began. Garfield protested to the Ohio delegation that he did not seek the nomination and would not betray Sherman, but they overruled his objections and cast their ballots for him. In the next round of voting, nearly all the Sherman and Blaine delegates shifted their support to Garfield, giving him 399 votes, and the Republican nomination. Most of the Grant forces backed the former president to the end, creating a disgruntled Stalwart minority in the party. To obtain that faction's support for the ticket, Chester A. Arthur, a former New York customs collector and member of Conkling's political machine, was chosen as the vice presidential nominee.\\n\\nCampaign against Hancock\\nEven with a Stalwart on the ticket, animosity between the Republican factions carried over from the convention, so Garfield traveled to New York to meet with party leaders. After convincing the Stalwart crowd to put aside their differences and unite for the coming campaign, Garfield returned to Ohio, leaving the active campaigning to others, as was traditional at the time. 
Meanwhile, the Democrats settled on their nominee, Major General Winfield Scott Hancock of Pennsylvania, a career military officer. Hancock and the Democrats expected to carry the Solid South, while much of the North was considered safe territory for Garfield and the Republicans; most of the campaign focused on a few close states, including New York and Indiana.\\nPractical differences between the candidates were few, but Republicans began the campaign with the familiar theme of waving the bloody shirt. They reminded Northern voters the Democratic Party was responsible for secession and four years of civil war, and Democrats would reverse the gains of that war, dishonor Union veterans, and pay Confederate veterans pensions out of the federal treasury. Fifteen years had passed since the end of the war, and with Union generals at the head of both tickets, the bloody shirt was of diminishing value in exciting the voters. With a few months to go before the election, the Republicans switched tactics to emphasize the tariff. Seizing on the Democratic platform's call for a \\\"tariff for revenue only\\\", Republicans told Northern workers a Hancock presidency would weaken the tariff protection that kept them in good jobs. Hancock made the situation worse when, attempting to strike a moderate stance, he said, \\\"The tariff question is a local question.\\\" The Republican ploy proved effective in uniting the North behind Garfield. Ultimately, of the more than 9.2 million popular votes cast, fewer than 2,000 separated the two candidates. But in the Electoral College, Garfield had an easy victory over Hancock, 214 to 155. The election made Garfield the only sitting member of the House ever to be elected to the presidency.\\n\\nPresidency (1881)\\nCabinet and inauguration\\nBefore his inauguration, Garfield was occupied with assembling a cabinet that might engender peace between the party's Conkling and Blaine factions. 
Blaine's delegates had provided much of the support for Garfield's nomination, so the Maine senator received the place of honor as Secretary of State. Blaine was not only the president's closest advisor, but he was also obsessed with knowing all that took place in the White House, and allegedly posted spies there in his absence. Garfield nominated William Windom of Minnesota as Secretary of the Treasury, William H. Hunt of Louisiana as Secretary of the Navy, Robert Todd Lincoln as Secretary of War, and Samuel J. Kirkwood of Iowa as Secretary of the Interior. New York was represented by Thomas Lemuel James as Postmaster General. Garfield appointed Pennsylvania's Wayne MacVeagh, an adversary of Blaine's, as Attorney General. Blaine tried to sabotage the appointment by convincing Garfield to name an opponent of MacVeagh, William E. Chandler, as Solicitor General under MacVeagh. Only Chandler's rejection by the Senate forestalled MacVeagh's resignation over the matter.\\nBecause Garfield was distracted by cabinet maneuvering, his inaugural address was a \\\"compendium of platitudes\\\" and fell below expectations. At one high point, however, Garfield emphasized the civil rights of African Americans, saying \\\"Freedom can never yield its fullness of blessings so long as the law or its administration places the smallest obstacle in the pathway of any virtuous citizen.\\\" The speech went on to discuss the gold standard and the need for education, and closed after an unexpected denunciation of Mormon polygamy. The crowd applauded, but the speech, according to Peskin, \\\"however sincerely intended, betrayed its hasty composition by the flatness of its tone and the conventionality of its subject matter.\\\"\\nGarfield's appointment of James infuriated Conkling, a factional opponent of the Postmaster General, who demanded a compensatory appointment for his faction, such as the position of Secretary of the Treasury. The resulting squabble occupied much of Garfield's brief presidency.
The feud with Conkling reached a climax when the president, at Blaine's instigation, nominated Conkling's enemy, Judge William H. Robertson, to be Collector of the Port of New York. This was one of the prize patronage positions below cabinet level and was then held by Edwin A. Merritt. Conkling raised the time-honored principle of senatorial courtesy in an attempt to defeat the nomination, to no avail. Garfield, who believed the practice was corrupt, would not back down and threatened to withdraw all nominations unless Robertson was confirmed, intending to \\\"settle the question whether the president is registering clerk of the Senate or the Executive of the United States.\\\" Ultimately, Conkling and his New York colleague, Senator Thomas C. Platt, resigned their Senate seats to seek vindication but found only further humiliation when the New York legislature elected others in their places. Robertson was confirmed as Collector and Garfield's victory was clear. To Blaine's chagrin, the victorious Garfield returned to his goal of balancing the interests of party factions and nominated a number of Conkling's Stalwart friends to offices.\\nWith his cabinet complete, Garfield had to contend with myriad office seekers. He exclaimed, \\\"My God! What is there in this place that a man should ever get into it.\\\" Garfield's family happily settled into the White House, but he found presidential duties exasperating.\\n\\nRefinance of national debt\\nGarfield ordered the Secretary of the Treasury William Windom to refund (refinance) the national debt by calling in outstanding U.S. bonds paying 6% interest. Holders would have the option of accepting cash or new bonds at 3%, closer to the interest rates of the time. Taxpayers were saved an estimated $10 million. 
By comparison, federal expenditures in 1881 were below $261 million (~$7.09 billion in 2023).\\n\\nSupreme Court nomination\\nIn 1880, President Hayes had nominated Stanley Matthews to the Supreme Court but the Senate declined to act on the nomination. In March 1881, Garfield re-nominated Matthews to the Court and the Senate confirmed Matthews by a vote of 24–23. According to The New York Times, \\\"opposition to Matthews's Supreme Court appointment ... stemmed from his prosecution in 1859 of a newspaper editor who had assisted two runaway slaves.\\\" Because Matthews was \\\"a professed abolitionist at the time, the matter was later framed as political expediency triumphing over moral principle.\\\" Matthews served on the Court until his death in 1889.\\n\\nReforms\\nGrant and Hayes had both advocated civil service reform, and by 1881 such reform associations had organized with renewed energy across the nation. Garfield sympathized with them, believing the spoils system damaged the presidency and often eclipsed more important concerns. Some reformers became disappointed when Garfield promoted limited tenure only to minor office seekers and gave appointments to his old friends.\\nCorruption in the post office also cried out for reform. In April 1880, there had been a congressional investigation of corruption in the Post Office Department, where profiteering rings allegedly stole millions of dollars, securing bogus mail contracts on star routes. After obtaining contracts with the lowest bid, costs to run the mail routes would be escalated and profits would be divided among ring members. Shortly after taking office, Garfield received word of postal corruption by an alleged star route ringleader, Assistant Postmaster General Thomas J. Brady. Garfield demanded Brady's resignation and ordered prosecutions that ended in trials for conspiracy. When told that his party, including his campaign manager, Stephen W. 
Dorsey, was involved, Garfield directed that the corruption in the Post Office be rooted out \\\"to the bone\\\", regardless of where it might lead. Brady resigned and was indicted for conspiracy, though juries in 1882 and 1883 found him not guilty.\\n\\nCivil rights and education\\nGarfield believed the key to improving the state of African American civil rights was government-supported education. During Reconstruction, freedmen had gained citizenship and suffrage, which enabled them to participate in government, but Garfield believed their rights were being eroded by Southern white resistance and illiteracy, and he was concerned that blacks would become America's permanent \\\"peasantry\\\". He proposed a \\\"universal\\\" education system funded by the federal government. In February 1866, Garfield, then a congressman from Ohio, and Ohio School Commissioner Emerson Edward White drafted a bill for the National Department of Education. They believed that through the use of statistics they could push the US Congress to establish a federal agency for school reform. But by the time of Garfield's presidency, Congress and the northern white public had lost interest in African-American rights, and Congress did not pass federal funding for universal education during his term. Garfield also worked to appoint several African Americans to prominent positions: Frederick Douglass, recorder of deeds in Washington; Robert Elliot, special agent to the Treasury; John M. Langston, Haitian minister; and Blanche K. Bruce, register to the Treasury. Garfield believed Southern support for the Republican Party could be gained by appealing to \\\"commercial and industrial\\\" interests rather than race issues and began to reverse Hayes's policy of conciliating Southern Democrats. He appointed William H. Hunt, a Republican from Louisiana, as Secretary of the Navy.
To break the hold of the resurgent Democratic Party in the Solid South, Garfield took patronage advice from Virginia Senator William Mahone of the biracial independent Readjuster Party, hoping to add the independents' strength to the Republicans' there.\\n\\nForeign policy and naval reform\\nGarfield had little foreign policy experience, so he leaned heavily on Blaine. They agreed on the need to promote freer trade, especially within the Western Hemisphere. Garfield and Blaine believed increasing trade with Latin America would be the best way to keep the United Kingdom of Great Britain and Ireland from dominating the region. By encouraging exports, they also believed they could increase American prosperity. Garfield authorized Blaine to call for a Pan-American conference in 1882 to mediate disputes among the Latin American nations and to serve as a forum for talks on increasing trade.\\nAt the same time, they hoped to negotiate a peace in the War of the Pacific then being fought by Bolivia, Chile, and Peru. Blaine favored a resolution that would result in Peru yielding no territory, but Chile by 1881 had occupied the Peruvian capital of Lima, and rejected any settlement that restored the previous status quo.\\nGarfield sought to expand American influence in other areas, calling for renegotiation of the Clayton–Bulwer Treaty to allow the United States to construct a canal through Panama without British involvement and attempting to reduce British influence in the strategically located Kingdom of Hawaii. Garfield's and Blaine's plans for the United States' involvement in the world stretched even beyond the Western Hemisphere, as the administration sought commercial treaties with Korea and Madagascar. Garfield also considered enhancing U.S. military strength abroad, asking Navy Secretary Hunt to investigate the navy's condition with an eye toward expansion and modernization. In the end, these ambitious plans came to nothing after Garfield was assassinated.
Nine countries had accepted invitations to the Pan-American conference, but the invitations were withdrawn in April 1882 after Blaine resigned from the cabinet and Arthur, Garfield's successor, cancelled the conference. Naval reform continued under Arthur, on a more modest scale than Garfield and Hunt had envisioned, ultimately resulting in the construction of the Squadron of Evolution.\\n\\nAssassination\\nGuiteau and shooting\\nCharles J. Guiteau had tried various professions, and in 1880 he determined to gain federal office by supporting what he expected would be the winning Republican ticket. He composed a speech, \\\"Garfield vs. Hancock\\\", and had it printed by the Republican National Committee. One means of persuading the voters in that era was through orators expounding on the candidate's merits, but with the Republicans seeking more famous men, Guiteau received few opportunities to speak. On one occasion, according to Kenneth D. Ackerman, Guiteau was unable to finish his speech due to nerves. Guiteau, who considered himself a Stalwart, deemed his contribution to Garfield's victory sufficient to justify his appointment to the position of consul in Paris, even though he spoke neither French nor any other foreign language. One medical expert has since described Guiteau as possibly a narcissistic schizophrenic; neuroscientist Kent Kiehl assessed him as a clinical psychopath.\\n\\nOne of Garfield's more wearying duties was seeing office-seekers, and he saw Guiteau at least once. White House officials suggested to Guiteau that he approach Blaine, as the consulship was within the Department of State. Blaine also saw the public regularly, and Guiteau became a regular at these sessions. Blaine, who had no intention of giving Guiteau a position he was unqualified for and had not earned, simply said the deadlock in the Senate over Robertson's nomination made it impossible to consider the Paris consulship, which required Senate confirmation.
Once the New York senators had resigned, and Robertson had been confirmed as Collector, Guiteau pressed his claim, and Blaine told him he would not receive the position.\\nGuiteau came to believe he had lost the position because he was a Stalwart. He decided the only way to end the Republican Party's internecine warfare was for Garfield to die—though he had nothing personal against the president. Arthur's succession would restore peace, he felt, and lead to rewards for fellow Stalwarts, including Guiteau.\\nThe assassination of Abraham Lincoln was deemed a fluke due to the Civil War, and Garfield, like most people, saw no reason the president should be guarded; his movements and plans were often printed in the newspapers. Guiteau knew Garfield would leave Washington for a cooler climate on July 2, 1881, and made plans to kill him before then. He purchased a gun he thought would look good in a museum, and followed Garfield several times, but each time his plans were frustrated, or he lost his nerve. His opportunities dwindled to one—Garfield's departure by train for New Jersey on the morning of July 2.\\nGuiteau concealed himself by the ladies' waiting room at the Sixth Street Station of the Baltimore and Potomac Railroad, from where Garfield was scheduled to depart. Most of Garfield's cabinet planned to accompany him at least part of the way. Blaine, who was to remain in Washington, came to the station to see him off. The two men were deep in conversation and did not notice Guiteau before he took out his revolver and shot Garfield twice, once in the back and once in the arm. Guiteau attempted to leave the station but was quickly captured. As Blaine recognized him, Guiteau was led away, and said, \\\"I did it. I will go to jail for it. 
I am a Stalwart and Arthur will be President.\\\" Word of his aim to benefit the Stalwarts spread along with the news of the shooting, provoking rage against that faction.\\n\\nTreatment and death\\nGarfield was struck by two shots: one glanced off his arm while the other pierced his back, shattering a rib and embedding itself in his abdomen. \\\"My God, what is this?\\\" he exclaimed. Among those at the station was Robert Todd Lincoln, who was deeply upset, thinking back to when his father Abraham Lincoln was assassinated 16 years earlier. Garfield was taken on a mattress upstairs to a private office, where several doctors examined him. At his request, Garfield was taken back to the White House, and his wife, then in New Jersey, was sent for. Blaine sent word to Vice President Arthur in New York City, who received threats against his life because of his known animosity toward Garfield and because of Guiteau's statements.\\nAlthough Joseph Lister's pioneering work in antisepsis was known to American doctors, few of them had confidence in it, and none of his advocates were among Garfield's treating physicians. The physician who took charge at the depot and then at the White House was Doctor Willard Bliss. A noted physician and surgeon, Bliss was an old friend of Garfield, and about a dozen doctors, led by Bliss, were soon probing the wound with unsterilized fingers and instruments. Garfield was given morphine for the pain, and asked Bliss to tell him frankly what his chances were; Bliss put them at one in a hundred. Garfield replied, \\\"Well, Doctor, we'll take that chance.\\\"\\nOver the next few days, Garfield made some improvement, as the nation followed the news from the capital and prayed. Although he never stood again, he was able to sit up and write several times, and his recovery was viewed so positively that a steamer was fitted out as a seagoing hospital to aid with his convalescence. He was nourished on oatmeal porridge (which he detested) and milk from a cow on the White House lawn.
When told that Indian chief Sitting Bull, a prisoner of the army, was starving, Garfield said, \\\"Let him starve...\\\" initially, but a few moments later said, \\\"No, send him my oatmeal.\\\"\\n\\nX-ray imaging, which could have assisted physicians in precisely locating the bullet in Garfield's body, would not be invented for another 14 years. Alexander Graham Bell tried to locate the bullet with a primitive metal detector, but was unsuccessful, though the device had been effective when tested on others. But Bliss limited its use on Garfield, ensuring he remained in charge. Because Bliss insisted the bullet rested someplace it did not, the detector could not locate it. Bell shortly returned after adjusting his device, which emitted an unusual tone in the area where Bliss believed the bullet was lodged. Bliss took this as confirmation that the bullet was where he declared it to be. Bliss recorded the test as a success, saying it was: now unanimously agreed that the location of the ball has been ascertained with reasonable certainty, and that it lies, as heretofore stated, in the front wall of the abdomen, immediately over the groin, about five inches [130 mm] below and to the right of the navel.\\nOne means of keeping Garfield comfortable in Washington's summer heat was one of the first successful air conditioning units: air propelled by fans over ice and then dried reduced the temperature in the sickroom by 20 °F (11 °C). Engineers from the navy, and other scientists, worked together to develop it, though there were problems to solve, such as excessive noise and increased humidity.\\nOn July 23, Garfield took a turn for the worse when his temperature increased to 104 °F (40 °C); doctors, concerned by an abscess at the wound, inserted a drainage tube. This initially helped, and the bedridden Garfield held a brief cabinet meeting on July 29; members were under orders from Bliss to discuss nothing that might excite Garfield. 
Doctors probed the abscess, hoping to find the bullet; they likely made the infections worse. Garfield performed only one official act in August, signing an extradition paper. By the end of the month, he was much feebler than he had been, and his weight had decreased from 210 pounds (95 kg) to 130 pounds (59 kg).\\nGarfield had long been anxious to escape hot, unhealthy Washington, and in early September the doctors agreed to move him to Elberon, part of Long Branch, New Jersey, where his wife had recovered earlier in the summer. He left the White House for the last time on September 5, traveling in a specially cushioned railway car; a spur line to the Francklyn Cottage, a seaside mansion given over to his use, was built in a night by volunteers. After arriving in Elberon the next day, Garfield was moved from the train car to a bedroom where he could see the ocean as officials and reporters maintained what became (after an initial rally) a death watch. Garfield's personal secretary, Joe Stanley Brown, wrote forty years later, \\\"to this day I cannot hear the sound of the low slow roll of the Atlantic on the shore, the sound which filled my ears as I walked from my cottage to his bedside, without recalling again that ghastly tragedy.\\\"\\n\\nOn September 18, Garfield asked Colonel A.F. Rockwell, a friend, if he would have a place in history. Rockwell assured him he would and told Garfield he had much work still before him. But his response was, \\\"No, my work is done.\\\" The following day, Garfield, then suffering also from pneumonia and hypertension, marveled that he could not pick up a glass despite feeling well and went to sleep without discomfort. He awoke that evening around 10:15 p.m. complaining of great pain in his chest to his chief of staff General David Swaim, who was watching him, as he placed his hand over his heart. The president then requested a drink of water from Swaim. 
After finishing his glass, Garfield said, \\\"Oh Swaim, this terrible pain—press your hand on it.\\\" As Swaim put his hand on Garfield's chest, Garfield's hands went up reflexively. Clutching his heart, he exclaimed, \\\"Oh, Swaim, can't you stop this? Oh, oh, Swaim!\\\" Those were Garfield's last words. Swaim ordered another attendant to send for Bliss, who found Garfield unconscious. Despite efforts to revive him, Garfield never awoke, and he was pronounced dead at about 10:30 p.m. Learning from a reporter of Garfield's death the following day, Chester A. Arthur took the presidential oath of office administered by New York Supreme Court Justice John R. Brady.\\nAccording to some historians and medical experts, Garfield might have survived his wounds had the doctors attending him had at their disposal today's medical research, knowledge, techniques, and equipment. Standard medical practice at the time dictated that priority be given to locating the path of the bullet. Several of his doctors inserted their unsterilized fingers into the wound to probe for the bullet, a common practice in the 1880s. Historians agree that massive infection was a significant factor in Garfield's demise. Biographer Peskin said medical malpractice did not contribute to Garfield's death; the inevitable infection and blood poisoning that would ensue from a deep bullet wound resulted in damage to multiple organs and spinal fragmentation. Rutkow, a professor of surgery at the University of Medicine and Dentistry of New Jersey, has argued that starvation also played a role. Rutkow suggests \\\"Garfield had such a nonlethal wound. In today's world, he would have gone home in a matter of two or three days.\\\" The conventional narrative regarding Garfield's post-shooting medical condition was challenged by Theodore Pappas and Shahrzad Joharifard in a 2013 article in The American Journal of Surgery. 
They argued that Garfield died from a late rupture of a splenic artery pseudoaneurysm, which developed secondary to the path of the bullet adjacent to the splenic artery. They also argued that his sepsis was actually caused by post-traumatic acute acalculous cholecystitis. Based on the autopsy report, the authors speculate that his gallbladder subsequently ruptured, leading to the development of a large bile-containing abscess adjacent to the gallbladder. Pappas and Joharifard say this caused the septic decline in Garfield's condition that was visible starting from July 23, 1881. Pappas and Joharifard also state that they don't believe that Garfield's doctors could have saved him even if they had been aware of his cholecystitis, since the first successful cholecystectomy (surgical removal of the gallbladder) was performed a year after Garfield's death.\\nGuiteau was indicted on October 14, 1881, for the murder of the president. During his trial, Guiteau declared that he was not responsible for Garfield's death, admitting to the shooting but not the killing. In his defense, Guiteau wrote: \\\"General Garfield died from malpractice. According to his own physicians, he was not fatally shot. The doctors who mistreated him ought to bear the odium of his death, and not his assailant. They ought to be indicted for murdering James A. Garfield, and not me.\\\" After a chaotic trial in which Guiteau often interrupted and argued, and in which his counsel used the insanity defense, the jury found him guilty on January 25, 1882, and he was sentenced to death by hanging. Guiteau may have had neurosyphilis, a disease that causes physiological mental impairment. He was executed on June 30, 1882.\\n\\nFuneral, memorials and commemorations\\nGarfield's funeral train left Long Branch on the same special track that had brought him there, traveling over tracks blanketed with flowers and past houses adorned with flags. 
His body was transported to the Capitol and then continued on to Cleveland for burial. Shocked by his death, Marine Band leader John Philip Sousa composed the march \\\"In Memoriam\\\", which was played when Garfield's body was received in Washington, D.C. More than 70,000 citizens, some waiting over three hours, passed by Garfield's coffin as his body lay in state from September 21 to 23, 1881, at the United States Capitol rotunda; on September 25, in Cleveland, Garfield's casket was paraded down Euclid Avenue from Wilson Avenue to Public Square, with those in attendance including former presidents Grant and Hayes, and Generals William Sherman, Sheridan and Hancock. More than 150,000—a number equal to the city's population—likewise paid their respects, and Sousa's march was again played. Garfield's body was temporarily interred in the Schofield family vault in Cleveland's Lake View Cemetery until his permanent memorial was built.\\nMemorials to Garfield were erected across the country. On April 10, 1882, seven months after Garfield's death, the U.S. Post Office Department issued a postage stamp in his honor. In 1884, sculptor Frank Happersberger completed a monument on the grounds of the San Francisco Conservatory of Flowers. In 1887, the James A. Garfield Monument was dedicated in Washington. Another monument, in Philadelphia's Fairmount Park, was erected in 1896. In Victoria, Australia, Cannibal Creek was renamed Garfield in his honor.\\n\\nOn May 19, 1890, Garfield's body was permanently interred, with great solemnity and fanfare, in a mausoleum in Lake View Cemetery. Attending the dedication ceremonies were former President Hayes, President Benjamin Harrison, and future president William McKinley. Garfield's Treasury Secretary, William Windom, also attended. Harrison said Garfield was always a \\\"student and instructor\\\" and that his life works and death would \\\"continue to be instructive and inspiring incidents in American history\\\". 
Three panels on the monument display Garfield as a teacher, Union major general, and orator; another shows him taking the presidential oath, and a fifth shows his body lying in state at the Capitol rotunda in Washington, D.C.\\nGarfield's murder by a deranged office-seeker awakened public awareness of the need for civil service reform legislation. Senator George H. Pendleton, a Democrat from Ohio, launched a reform effort that resulted in the Pendleton Act in January 1883. This act reversed the \\\"spoils system\\\" where office seekers paid up or gave political service to obtain or keep federally appointed positions. Under the act, appointments were awarded on merit and competitive examination. To ensure the reform was implemented, Congress and Arthur established and funded the Civil Service Commission. The Pendleton Act, however, covered only 10% of federal government workers. For Arthur, previously known for having been a \\\"veteran spoilsman\\\", civil service reform became his most noteworthy achievement.\\nA marble statue of Garfield by Charles Niehaus was added to the National Statuary Hall Collection in the Capitol in Washington D.C., a gift from the State of Ohio in 1886.\\nGarfield is honored with a life-size bronze sculpture inside the Cuyahoga County Soldiers' and Sailors' Monument in Cleveland, Ohio.\\nOn March 2, 2019, the National Park Service erected exhibit panels in Washington to mark the site of his assassination.\\n\\nLegacy and historical view\\nFor a few years after his assassination, Garfield's life story was seen as an exemplar of the American success story—that even the poorest boy might someday become President of the United States. 
Peskin wrote: \\\"In mourning Garfield, Americans were not only honoring a president; they were paying tribute to a man whose life story embodied their own most cherished aspirations.\\\" As the rivalry between Stalwarts and Half-Breeds faded from the scene in the late 1880s and after, so too did memories of Garfield. In the 1890s, Americans became disillusioned with politicians, and looked elsewhere for inspiration, focusing on industrialists, labor leaders, scientists, and others as their heroes. Increasingly, Garfield's short time as president was forgotten.\\n\\nThe 20th century saw no revival for Garfield. Thomas Wolfe deemed the presidents of the Gilded Age, including Garfield, \\\"lost Americans\\\" whose \\\"gravely vacant and bewhiskered faces mixed, melted, swam together\\\". The politicians of the Gilded Age faded from the public eye, their luster eclipsed by those who had influenced America outside of political office during that time; the robber barons, the inventors, those who had sought social reform, and others who had lived as America rapidly changed. Current events and more recent figures occupied America's attention. According to Ackerman, \\\"the busy Twentieth Century has made Garfield's era seem remote and irrelevant, its leaders ridiculed for their very obscurity.\\\"\\nGarfield's biographers, and those who have studied his presidency, tend to think well of him, and that his presidency saw a promising start before its untimely end. Historian Justus D. Doenecke, while deeming Garfield a bit of an enigma, chronicles his achievements: \\\"by winning a victory over the Stalwarts, he enhanced both the power and prestige of his office. As a man, he was intelligent, sensitive, and alert, and his knowledge of how government worked was unmatched.\\\" Doenecke criticizes Garfield's dismissal of Merritt in Robertson's favor, and wonders if the president was truly in command of the situation even after the latter's confirmation. 
In 1931, Caldwell wrote: \\\"If Garfield lives in history, it will be partly on account of the charm of his personality—but also because in life and in death, he struck the first shrewd blows against a dangerous system of boss rule which seemed for a time about to engulf the politics of the nation. Perhaps if he had lived he could have done no more.\\\" Rutkow writes that \\\"James Abram Garfield's presidency is reduced to a tantalizing 'what if.'\\\"\\nIn 2002, historian Bernard A. Weisberger said, \\\"[Garfield] was, to some extent, a perfect moderate. He read widely (and unobtrusively) without its visibly affecting his Christianity, his Republicanism, or his general laissez-faire orthodoxy. He was not so much a scholar in politics as a politic scholar.\\\" Peskin believes Garfield deserves more credit for his political career than he has received: \\\"True, his accomplishments were neither bold nor heroic, but his was not an age that called for heroism. His stormy presidency was brief, and in some respects, unfortunate, but he did leave the office stronger than he found it. As a public man he had a hand in almost every issue of national importance for almost two decades, while as a party leader he, along with Blaine, forged the Republican Party into the instrument that would lead the United States into the twentieth century.\\\"\\n\\nNotes\\nReferences\\nWorks cited\\nFurther reading\\nFuller, Corydon E. (2022) [1887]. Reminiscences of James A. Garfield. Hansebooks. ISBN 978-3-34807-944-0.\\nGoodyear, C. W. (2023). President Garfield: From Radical to Unifier. New York, New York: Simon & Schuster.\\nGraff Henry F., ed. The Presidents: A Reference History (3rd ed. 2002) online\\nHammond, William A.; Ashhurst, Jr., John; Sims, J. Marion; Hodgen, John T. (December 1881). \\\"The Surgical Treatment of President Garfield\\\". The North American Review. 133 (301): 578–610. JSTOR 25101018.\\nHoudek, John Thomas. \\\"James A. Garfield and Rutherford B. 
Hayes: A Study in State and National Politics\\\" (PhD dissertation, Michigan State University; ProQuest Dissertations Publishing, 1970. 7111871).\\nMenke, Richard. \\\"Media in America, 1881: Garfield, Guiteau, Bell, Whitman.\\\" Critical Inquiry 31.3 (2005): 638–664.\\nMillard, Candice (2012). Destiny of the Republic: A Tale of Madness, Medicine and the Murder of a President. New York, New York: Anchor Books. ISBN 978-0-7679-2971-4.\\nNorth, Ira Lutts. \\\"A rhetorical criticism of the speaking of James Abram Garfield, 1876-1880\\\" (PhD dissertation, Louisiana State University; ProQuest Dissertations Publishing, 1953. DP69446).\\nRushford, Jerry Bryant. \\\"Political Disciple: The Relationship Between James A. Garfield And The Disciples Of Christ\\\" (PhD dissertation, University of California, Santa Barbara; ProQuest Dissertations Publishing, 1977. 7807029).\\nSkidmore, Max J. \\\"James A. Garfield and Chester A. Arthur.\\\" in Maligned Presidents: The Late 19th Century (Palgrave Macmillan, New York, 2014) pp. 63–79.\\nSutton, Thomas C. \\\"James A. Garfield.\\\" in The Presidents and the Constitution (Volume One. New York University Press, 2020) pp. 266–275.\\nUhler, Kevin A. \\\"The demise of patronage: Garfield, the midterm election, and the passage of the Pendleton Civil Service Act\\\" (PhD. Diss. The Florida State University, 2011) online.\\nVermilya, Daniel J. James Garfield and the Civil War: For Ohio and the Union (Arcadia Publishing, 2015).\\n\\nExternal links\\n\\nGarfield, James Abram, (1831–1881) Congressional Biography\\nJames Garfield: A Resource Guide from the Library of Congress\\nJames A. Garfield at the Database of Classical Scholars\\nBrief essays on James A. Garfield and his administration from the Miller Center of Public Affairs (http://millercenter.org/president/garfield)\\n\\\"Life Portrait of James Garfield\\\", from C-SPAN's American Presidents: Life Portraits, July 26, 1999\\nWorks by or about James A. 
Garfield at the Internet Archive\\nWorks by James A. Garfield at LibriVox (public domain audiobooks) \\nNotable alumni of Delta Upsilon fraternity, including Garfield\\nJames A. Garfield Personal Manuscripts\\nJames A. Garfield Collection at Williams College Chapin Library\\nJames A. Garfield Collection at Williams College Archives and Special Collections\\nOfficial medical bulletins relating to the health of U.S. President James Garfield from the U.S. National Library of Medicine. Contains medical bulletins issued by attending physicians D. Hayes Agnew, J.K. Barnes, D. W. Bliss, Frank H. Hamilton, Robert Reyburn, and J.J. Woodward between July 6 – September 19, 1881.\\n\\nBased on all the information, answer the query. \\n\\nQuery: If my future wife has the same first name as the 15th first lady of the United States' mother and her surname is the same as the second assassinated president's mother's maiden name, what is my future wife's name? \\n\\n\"}\n"
  },
  {
    "path": "benchmark/kernels/quantization/README.md",
    "content": "# W8A8 Block-wise Quantization Kernel Tuning\n\nAuto-tune Triton FP8/INT8 block-wise quantization kernels for optimal performance.\n\n## When to Use Triton FP8 Block-wise Quantization Kernel vs DeepGEMM\n\n**Use Triton FP8 Block-wise Quantization Kernel when:**\n- Output dtype is NOT `bfloat16` (e.g., `float16`, `float32`)\n- DeepGEMM is disabled (environment variable `SGLANG_ENABLE_JIT_DEEPGEMM=0`)\n- Running on GPUs with compute capability < SM90 (DeepGEMM requires SM90+)\n- You need cross-platform compatibility (Triton works on both NVIDIA and AMD GPUs)\n\n**Use DeepGEMM when:**\n- Output dtype is `bfloat16` AND DeepGEMM is enabled\n- Running on NVIDIA GPUs with compute capability >= SM90 (e.g., H100, H200)\n- You need maximum performance for production workloads (DeepGEMM is highly optimized for the Hopper architecture)\n\n**Note:** DeepGEMM requires CUDA compute capability >= 9.0 (SM90+). It is specifically optimized for NVIDIA Hopper GPUs (H100/H200).\n\nSGLang's kernel selection logic automatically chooses DeepGEMM when these conditions are met (see the `w8a8_block_fp8_matmul` function in `fp8_kernel.py`) and otherwise falls back to the Triton implementation.\n\n## Quick Start\n\n**Default (DeepSeek-V3):**\n```bash\npython benchmark/kernels/quantization/tuning_block_wise_kernel.py --tp-size 8\n```\n\n**Custom Model (specify N and K):**\n```bash\npython benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 25600\n```\n\n## Parameters\n\n- `--N`, `--K`: Weight matrix dimensions (N=output_dim, K=input_dim). 
If not specified, the DeepSeek-V3 weight shapes for the given `--tp-size` are tuned\n- `--tp-size`: Tensor parallelism size for DeepSeek-V3 (default: 8)\n- `--input-type`: `fp8` or `int8` (default: fp8)\n- `--block-n`, `--block-k`: Block quantization granularity (default: 128)\n- `--batch-size`: Tune only this single batch size (optional)\n\n## How to Calculate N and K\n\nFor a linear layer `y = xW^T` where `x` is (M, K) and `W` is (N, K):\n- **N**: Output features (weight matrix output dimension)\n- **K**: Input features (weight matrix input dimension)\n\n**Example: Qwen3-VL-32B** (hidden_size=5120, intermediate_size=25600, num_heads=64, num_kv_heads=8, head_dim=128) with TP=1:\n```bash\n# QKV projection: Q(64 * 128 = 8192) + K(8 * 128 = 1024) + V(1024) = 10240\npython benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 10240 --K 5120\n\n# MLP gate+up (SwiGLU): 2 * intermediate_size = 51200\npython benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 51200 --K 5120\n\n# MLP down projection: K = intermediate_size = 25600\npython benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 25600\n\n# O projection: K = num_heads * head_dim = 8192\npython benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 8192\n```\n\nWith TP=8, N is divided by 8 for the column-parallel QKV and gate+up projections, and K is divided by 8 for the row-parallel down and O projections:\n\n```bash\n# QKV projection: 10240 / 8 = 1280\npython benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 1280 --K 5120\n\n# MLP gate+up (SwiGLU): 51200 / 8 = 6400\npython benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 6400 --K 5120\n\n# MLP down projection: K = 25600 / 8 = 3200\npython benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 3200\n\n# O projection: K = 8192 / 8 = 1024\npython benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 1024\n```\n\n## Output\n\nGenerates JSON config files saved to `python/sglang/srt/layers/quantization/configs/`:\n```\nN={N},K={K},device_name={DEVICE},dtype=fp8_w8a8,block_shape=[128,128].json\n```\n\nConfig maps batch size to optimal kernel parameters:\n```json\n{\n    
\"1\": {\"BLOCK_SIZE_M\": 16, \"BLOCK_SIZE_N\": 64, \"BLOCK_SIZE_K\": 128, ...},\n    \"2048\": {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 128, ...}\n}\n```\n"
  },
  {
    "path": "benchmark/kernels/quantization/bench_fp4_quant.py",
    "content": "import argparse\nimport itertools\n\nimport torch\nimport triton\nfrom flashinfer import (\n    scaled_fp4_grouped_quantize,\n    silu_and_mul_scaled_nvfp4_experts_quantize,\n)\nfrom sgl_kernel.elementwise import silu_and_mul\n\nfrom sglang.benchmark.bench_utils import run_bench\nfrom sglang.srt.layers import deep_gemm_wrapper\nfrom sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd\n\n\ndef _test_accuracy_once(E, M, K, input_dtype, device):\n    x = torch.randn(E, M, K, device=device, dtype=input_dtype)\n    glb_scales = torch.ones((E,), dtype=torch.float32, device=device)\n    masks = torch.full((E,), M, dtype=torch.int32, device=device)\n    out, blk_scales = silu_and_mul_scaled_nvfp4_experts_quantize(x, masks, glb_scales)\n    out1, blk_scales1 = scaled_fp4_grouped_quantize(\n        silu_and_mul(x),\n        masks,\n        glb_scales,\n    )\n\n    torch.testing.assert_close(out, out1)\n    torch.testing.assert_close(blk_scales, blk_scales1)\n    print(f\"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK\")\n\n\nNUM_RANKS = 48\nM_PER_RANKs = [128, 256, 512, 1024]\nMs = [M_PER_RANK * NUM_RANKS for M_PER_RANK in M_PER_RANKs]\nKs = [2048, 4096, 7168]\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"M\", \"K\"],\n        x_vals=list(itertools.product(Ms, Ks)),\n        x_log=False,\n        line_arg=\"provider\",\n        line_vals=[\"triton_fp8\", \"cuda_unfused_fp4\", \"cuda_fused_fp4\"],\n        line_names=[\"triton_fp8\", \"cuda_unfused_fp4\", \"cuda_fused_fp4\"],\n        styles=[(\"blue\", \"-\"), (\"orange\", \"-\"), (\"green\", \"-\")],\n        ylabel=\"ms\",\n        plot_name=\"fp4 quant\",\n        args={},\n    )\n)\ndef benchmark(M, K, provider):\n    E = 6\n    device = \"cuda\"\n    x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16)\n    glb_scales = torch.ones((E,), dtype=torch.float32, device=device)\n    masks = torch.randint(1, 4096, (E,), 
dtype=torch.int32, device=device)\n    fp8_out = torch.empty(\n        (\n            x.shape[0],\n            x.shape[1],\n            x.shape[2] // 2,\n        ),\n        device=x.device,\n        dtype=torch.float8_e4m3fn,\n    )\n    scale_block_size = 128\n    fp8_scales = torch.empty(\n        (\n            x.shape[0],\n            x.shape[1],\n            x.shape[2] // 2 // scale_block_size,\n        ),\n        device=x.device,\n        dtype=torch.float32,\n    )\n\n    quantiles = (0.5, 0.2, 0.8)\n    if provider == \"triton_fp8\":\n        ms, min_ms, max_ms = run_bench(\n            lambda: silu_and_mul_masked_post_quant_fwd(\n                x,\n                fp8_out,\n                fp8_scales,\n                scale_block_size,\n                masks,\n                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,\n            ),\n            quantiles=quantiles,\n        )\n    if provider == \"cuda_unfused_fp4\":\n        ms, min_ms, max_ms = run_bench(\n            lambda: scaled_fp4_grouped_quantize(\n                silu_and_mul(x),\n                masks,\n                glb_scales,\n            ),\n            quantiles=quantiles,\n        )\n    if provider == \"cuda_fused_fp4\":\n        ms, min_ms, max_ms = run_bench(\n            lambda: silu_and_mul_scaled_nvfp4_experts_quantize(\n                x,\n                masks,\n                glb_scales,\n            ),\n            quantiles=quantiles,\n        )\n\n    return ms, min_ms, max_ms\n\n\ndef test_accuracy():\n    E = 6\n    N_RANKS = 48\n    Ms = [128, 256, 512, 1024]\n    Ks = [2048, 4096, 7168]\n    input_dtype = torch.bfloat16\n    for M in Ms:\n        for K in Ks:\n            _test_accuracy_once(E, N_RANKS * M, K, input_dtype, \"cuda\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--save_path\",\n        type=str,\n        default=\"./bench_fp4_quant_res\",\n        help=\"Path to save fp4 quant 
benchmark results\",\n    )\n    args = parser.parse_args()\n\n    test_accuracy()\n\n    benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)\n"
  },
  {
    "path": "benchmark/kernels/quantization/bench_int8_quant.py",
    "content": "import argparse\n\nimport torch\nimport triton\nfrom vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant\n\nfrom sglang.benchmark.bench_utils import run_bench\nfrom sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8\n\n\n@torch.compile(backend=\"inductor\")\ndef torch_int8_quant(x):\n    int8_max = torch.iinfo(torch.int8).max\n\n    abs_max = x.abs().max(dim=-1, keepdim=True).values\n    scales = abs_max.to(torch.float32) / float(int8_max)\n\n    q_x = (x / scales).round().to(torch.int8)\n\n    return q_x, scales\n\n\ndef _test_accuracy_once(M, K, input_dtype, device):\n    x = torch.randn(M, K, dtype=input_dtype, device=device) * 5000\n    out, scales, _ = vllm_scaled_int8_quant(x, symmetric=True)\n    out1, scales1 = per_token_quant_int8(x)\n    out2, scales2 = torch_int8_quant(x)\n    torch.testing.assert_close(out, out2, atol=1, rtol=0)\n    torch.testing.assert_close(out, out1, atol=1, rtol=0)\n    torch.testing.assert_close(scales, scales2)\n    torch.testing.assert_close(scales1, scales2)\n    print(f\"M: {M}, K: {K}, type: {input_dtype} OK\")\n\n\ndef test_accuracy():\n    Ms = [1, 13, 128, 1024, 2048, 4096]\n    Ks = [512, 1024, 2048, 8192]\n    input_dtypes = [torch.float16, torch.bfloat16]\n    for M in Ms:\n        for K in Ks:\n            for input_dtype in input_dtypes:\n                _test_accuracy_once(M, K, input_dtype, \"cuda\")\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"batch_size\"],\n        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],\n        x_log=False,\n        line_arg=\"provider\",\n        line_vals=[\"vllm op\", \"triton\", \"torch.compile\"],\n        line_names=[\"vllm op\", \"triton\", \"torch.compile\"],\n        styles=[(\"blue\", \"-\"), (\"orange\", \"-\"), (\"red\", \"-\")],\n        ylabel=\"ms\",\n        plot_name=\"int8 per token quant\",\n        args={},\n    )\n)\ndef benchmark(batch_size, provider):\n    M, K = 
batch_size, 16384\n    x = torch.randn(M, K, dtype=torch.float16, device=\"cuda\") * 1000\n\n    quantiles = (0.5, 0.2, 0.8)\n    if provider == \"vllm op\":\n        ms, min_ms, max_ms = run_bench(\n            lambda: vllm_scaled_int8_quant(x, symmetric=True),\n            quantiles=quantiles,\n        )\n    if provider == \"triton\":\n        ms, min_ms, max_ms = run_bench(\n            lambda: per_token_quant_int8(x),\n            quantiles=quantiles,\n        )\n    if provider == \"torch.compile\":\n        ms, min_ms, max_ms = run_bench(\n            lambda: torch_int8_quant(x),\n            quantiles=quantiles,\n        )\n\n    return ms, min_ms, max_ms\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--save_path\",\n        type=str,\n        default=\"./bench_int8_quant_res\",\n        help=\"Path to save int8 quant benchmark results\",\n    )\n    args = parser.parse_args()\n\n    test_accuracy()\n\n    benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)\n"
  },
  {
    "path": "benchmark/kernels/quantization/tuning_block_wise_kernel.py",
    "content": "# Copyright 2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\nimport argparse\nimport json\nimport multiprocessing as mp\nimport os\nimport time\nfrom datetime import datetime\nfrom typing import Any, Dict, List\n\nimport torch\nimport triton\nfrom tqdm import tqdm\n\nmp.set_start_method(\"spawn\", force=True)\n\nfrom sglang.srt.layers.quantization.fp8_kernel import (\n    _w8a8_block_fp8_matmul,\n    _w8a8_block_fp8_matmul_unrolledx4,\n)\nfrom sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul\nfrom sglang.srt.utils import get_device_core_count, get_device_name, is_hip\n\n_is_hip = is_hip()\n\nDTYPE_MAP = {\n    \"float32\": torch.float32,\n    \"float16\": torch.float16,\n    \"half\": torch.half,\n    \"bfloat16\": torch.bfloat16,\n}\n\n\ndef w8a8_block_matmul(\n    A: torch.Tensor,\n    B: torch.Tensor,\n    As: torch.Tensor,\n    Bs: torch.Tensor,\n    block_size: List[int],\n    config: Dict[str, Any],\n    output_dtype: torch.dtype = torch.float16,\n) -> torch.Tensor:\n    \"\"\"This function performs matrix multiplication with block-wise quantization.\n\n    It takes two input tensors `A` and `B` with scales `As` and `Bs`.\n    The output is returned in the specified `output_dtype`.\n\n    Args:\n        A: The input tensor, e.g., activation.\n        B: The input tensor, e.g., weight.\n        As: The per-token-group 
quantization scale for `A`.\n        Bs: The per-block quantization scale for `B`.\n        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].\n        output_dtype: The dtype of the returned tensor.\n\n    Returns:\n        torch.Tensor: The result of matmul.\n    \"\"\"\n    assert len(block_size) == 2\n    block_n, block_k = block_size[0], block_size[1]\n\n    assert A.shape[-1] == B.shape[-1]\n    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()\n    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]\n    M = A.numel() // A.shape[-1]\n\n    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2\n    N, K = B.shape\n    assert triton.cdiv(N, block_n) == Bs.shape[0]\n    assert triton.cdiv(K, block_k) == Bs.shape[1]\n\n    C_shape = A.shape[:-1] + (N,)\n    C = A.new_empty(C_shape, dtype=output_dtype)\n\n    needs_masking = bool(K % config[\"BLOCK_SIZE_K\"] != 0)\n\n    def grid(META):\n        return (\n            triton.cdiv(M, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n        )\n\n    # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.\n    # Empirical testing shows the sweet spot lies when it's less than the # of\n    # compute units available on the device.\n    num_workgroups = triton.cdiv(M, config[\"BLOCK_SIZE_M\"]) * triton.cdiv(\n        N, config[\"BLOCK_SIZE_N\"]\n    )\n\n    if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:\n        kernel = (\n            _w8a8_block_fp8_matmul_unrolledx4\n            if (_is_hip and num_workgroups <= get_device_core_count())\n            else _w8a8_block_fp8_matmul\n        )\n    else:\n        kernel = _w8a8_block_int8_matmul\n\n    kernel[grid](\n        A,\n        B,\n        C,\n        As,\n        Bs,\n        M,\n        N,\n        K,\n        block_n,\n        block_k,\n        A.stride(-2),\n        A.stride(-1),\n        B.stride(1),\n        B.stride(0),\n        
C.stride(-2),\n        C.stride(-1),\n        As.stride(-2),\n        As.stride(-1),\n        Bs.stride(1),\n        Bs.stride(0),\n        **config,\n        needs_masking=needs_masking,\n    )\n\n    return C\n\n\ndef get_rocm_configs_compute_bound():\n    configs = []\n    waves_per_eu_range = 0\n    for num_stages in [2]:\n        for block_m in [32, 64, 128, 256]:\n            for block_k in [32, 64, 128, 256]:\n                for block_n in [16, 32, 64, 128, 256]:\n                    for num_warps in [4, 8]:\n                        for group_size in [1, 4, 8, 16, 32]:\n                            configs.append(\n                                {\n                                    \"BLOCK_SIZE_M\": block_m,\n                                    \"BLOCK_SIZE_N\": block_n,\n                                    \"BLOCK_SIZE_K\": block_k,\n                                    \"GROUP_SIZE_M\": group_size,\n                                    \"num_warps\": num_warps,\n                                    \"num_stages\": num_stages,\n                                    \"waves_per_eu\": waves_per_eu_range,\n                                }\n                            )\n    return configs\n\n\ndef get_configs_compute_bound():\n    configs = []\n    if _is_hip:\n        configs = get_rocm_configs_compute_bound()\n    else:\n        for num_stages in [2, 3, 4, 5]:\n            for block_m in [16, 32, 64, 128, 256]:\n                for block_k in [64, 128]:\n                    for block_n in [32, 64, 128, 256]:\n                        for num_warps in [4, 8]:\n                            for group_size in [1, 16, 32, 64]:\n                                configs.append(\n                                    {\n                                        \"BLOCK_SIZE_M\": block_m,\n                                        \"BLOCK_SIZE_N\": block_n,\n                                        \"BLOCK_SIZE_K\": block_k,\n                                        
\"GROUP_SIZE_M\": group_size,\n                                        \"num_warps\": num_warps,\n                                        \"num_stages\": num_stages,\n                                    }\n                                )\n    return configs\n\n\ndef get_weight_shapes(tp_size):\n    # NOTE(HandH1998): These weight shapes only work for DeepSeek-V3. Modify them if you tune for a different model.\n    # cannot TP\n    total = [\n        (512 + 64, 7168),\n        ((128 + 64) * 128, 7168),\n        (128 * (128 + 128), 512),\n        (7168, 16384),\n        (7168, 18432),\n    ]\n    # N can TP\n    n_tp = [\n        (18432 * 2, 7168),\n        ((128 + 64) * 128, 7168),\n        (128 * (128 + 128), 512),\n        (24576, 1536),\n        (4096, 7168),\n    ]\n    # K can TP\n    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]\n\n    weight_shapes = []\n    for t in total:\n        weight_shapes.append(t)\n    for n_t in n_tp:\n        new_t = (n_t[0] // tp_size, n_t[1])\n        weight_shapes.append(new_t)\n    for k_t in k_tp:\n        new_t = (k_t[0], k_t[1] // tp_size)\n        weight_shapes.append(new_t)\n    return weight_shapes\n\n\ndef benchmark_config(\n    A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10\n):\n    def run():\n        w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)\n\n    torch.cuda.synchronize()\n    # JIT compilation & warmup\n    for _ in range(5):\n        run()\n    torch.cuda.synchronize()\n\n    start_event = torch.cuda.Event(enable_timing=True)\n    end_event = torch.cuda.Event(enable_timing=True)\n\n    latencies: List[float] = []\n    for _ in range(num_iters):\n        torch.cuda.synchronize()\n        start_event.record()\n        run()\n        end_event.record()\n        end_event.synchronize()\n        latencies.append(start_event.elapsed_time(end_event))\n    # Each latency covers a single run(), so average over num_iters (ms -> us).\n    avg = sum(latencies) / num_iters * 1000  # us\n    return avg\n\n\ndef tune(M, N, K, block_size, 
out_dtype, search_space, input_type):\n    factor_for_scale = 1e-2\n\n    if input_type == \"fp8\":\n        fp8_info = torch.finfo(\n            torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn\n        )\n        fp8_max, fp8_min = fp8_info.max, fp8_info.min\n\n        A_fp32 = (\n            (torch.rand(M, K, dtype=torch.float32, device=\"cuda\") - 0.5) * 2 * fp8_max\n        )\n        A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(\n            torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn\n        )\n\n        B_fp32 = (\n            (torch.rand(N, K, dtype=torch.float32, device=\"cuda\") - 0.5) * 2 * fp8_max\n        )\n        B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(\n            torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn\n        )\n    else:\n        int8_info = torch.iinfo(torch.int8)\n        int8_max, int8_min = int8_info.max, int8_info.min\n\n        A_fp32 = (\n            (torch.rand(M, K, dtype=torch.float32, device=\"cuda\") - 0.5) * 2 * int8_max\n        )\n        A = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)\n\n        B_fp32 = (\n            (torch.rand(N, K, dtype=torch.float32, device=\"cuda\") - 0.5) * 2 * int8_max\n        )\n        B = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)\n\n    block_n, block_k = block_size[0], block_size[1]\n    n_tiles = (N + block_n - 1) // block_n\n    k_tiles = (K + block_k - 1) // block_k\n\n    As = torch.rand(M, k_tiles, dtype=torch.float32, device=\"cuda\") * factor_for_scale\n    Bs = (\n        torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=\"cuda\")\n        * factor_for_scale\n    )\n\n    best_config = None\n    best_time = float(\"inf\")\n    for config in tqdm(search_space):\n        try:\n            kernel_time = benchmark_config(\n                A,\n                B,\n                As,\n                Bs,\n                block_size,\n                config,\n                out_dtype,\n                
num_iters=10,\n            )\n        except triton.runtime.autotuner.OutOfResources:\n            # Some configurations may be invalid and fail to compile.\n            continue\n\n        if kernel_time < best_time:\n            best_time = kernel_time\n            best_config = config\n    now = datetime.now()\n    print(f\"[{now.ctime()}] Completed tuning for batch_size={M}\")\n    assert best_config is not None\n    return best_config\n\n\ndef save_configs(\n    N,\n    K,\n    block_n,\n    block_k,\n    configs,\n    save_path,\n    input_type=\"fp8\",\n    lock=None,\n) -> None:\n    os.makedirs(save_path, exist_ok=True)\n    device_name = get_device_name().replace(\" \", \"_\")\n    json_file_name = f\"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,block_shape=[{block_n}, {block_k}].json\"\n\n    config_file_path = os.path.join(save_path, json_file_name)\n    print(f\"Writing best config to {config_file_path}...\")\n\n    if lock is not None:\n        lock.acquire()\n    try:\n        existing_configs = {}\n        if os.path.exists(config_file_path):\n            with open(config_file_path, \"r\") as f:\n                existing_configs = json.load(f)\n            existing_configs = {int(k): v for k, v in existing_configs.items()}\n\n        existing_configs.update(configs)\n\n        with open(config_file_path, \"w\") as f:\n            json.dump(existing_configs, f, indent=4)\n            f.write(\"\\n\")\n    finally:\n        if lock is not None:\n            lock.release()\n\n\ndef get_available_gpu_count():\n    \"\"\"Get the number of available GPUs.\"\"\"\n    return torch.cuda.device_count()\n\n\ndef tune_on_gpu(args_dict):\n    \"\"\"Run tuning on a specific GPU.\"\"\"\n    gpu_id = args_dict[\"gpu_id\"]\n    batch_sizes = args_dict[\"batch_sizes\"]\n    weight_shapes = args_dict[\"weight_shapes\"]\n    args = args_dict[\"args\"]\n    lock = args_dict[\"lock\"]\n\n    torch.cuda.set_device(gpu_id)\n    print(f\"Starting tuning on 
GPU {gpu_id} with batch sizes {batch_sizes}\")\n\n    block_n = args.block_n\n    block_k = args.block_k\n    out_dtype = DTYPE_MAP[args.out_dtype]\n    save_path = args.save_path\n    input_type = args.input_type\n\n    search_space = get_configs_compute_bound()\n    search_space = [\n        config for config in search_space if block_k % config[\"BLOCK_SIZE_K\"] == 0\n    ]\n\n    start = time.perf_counter()\n    results = {}\n    for shape in tqdm(weight_shapes, desc=f\"GPU {gpu_id} - Shapes\"):\n        N, K = shape[0], shape[1]\n        print(f\"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`\")\n        benchmark_results = [\n            tune(\n                batch_size,\n                N,\n                K,\n                [block_n, block_k],\n                out_dtype,\n                search_space,\n                input_type,\n            )\n            for batch_size in tqdm(batch_sizes, desc=f\"GPU {gpu_id} - Batch sizes\")\n        ]\n        best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}\n        save_configs(N, K, block_n, block_k, best_configs, save_path, input_type, lock)\n\n    end = time.perf_counter()\n    print(f\"Tuning on GPU {gpu_id} took {end - start:.2f} seconds\")\n\n\ndef distribute_batch_sizes(batch_sizes, num_gpus):\n    \"\"\"Distribute batch sizes across available GPUs.\"\"\"\n    batches_per_gpu = []\n    for i in range(num_gpus):\n        start_idx = i * len(batch_sizes) // num_gpus\n        end_idx = (i + 1) * len(batch_sizes) // num_gpus\n        batches_per_gpu.append(batch_sizes[start_idx:end_idx])\n    return batches_per_gpu\n\n\ndef main(args):\n    print(args)\n\n    num_gpus = get_available_gpu_count()\n    if num_gpus == 0:\n        raise RuntimeError(\"No GPU available for tuning\")\n    print(f\"Found {num_gpus} GPUs for parallel tuning\")\n\n    torch.cuda.init()\n\n    if args.batch_size is None:\n        batch_sizes = [\n            1,\n            2,\n            4,\n     
       8,\n            16,\n            24,\n            32,\n            48,\n            64,\n            96,\n            128,\n            256,\n            512,\n            1024,\n            1536,\n            2048,\n            3072,\n            4096,\n        ]\n    else:\n        batch_sizes = [args.batch_size]\n        num_gpus = 1  # If only one batch size, use only one GPU\n\n    # Support manual N and K specification\n    if args.N is not None and args.K is not None:\n        weight_shapes = [(args.N, args.K)]\n        print(f\"Using manually specified weight shape: N={args.N}, K={args.K}\")\n    else:\n        weight_shapes = get_weight_shapes(args.tp_size)\n        print(f\"Using predefined weight shapes for TP size {args.tp_size}\")\n\n    batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)\n\n    ctx = mp.get_context(\"spawn\")\n    manager = ctx.Manager()\n    lock = manager.Lock()\n\n    process_args = []\n    for gpu_id in range(num_gpus):\n        process_args.append(\n            {\n                \"gpu_id\": gpu_id,\n                \"batch_sizes\": batches_per_gpu[gpu_id],\n                \"weight_shapes\": weight_shapes,  # Each GPU processes all weight shapes\n                \"args\": args,\n                \"lock\": lock,\n            }\n        )\n\n    with ctx.Pool(num_gpus) as pool:\n        pool.map(tune_on_gpu, process_args)\n\n    print(\"Multi-GPU tuning completed\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\n        \"--tp-size\",\n        \"-tp\",\n        type=int,\n        default=8,\n        help=\"Tensor parallelism size (ignored if --N and --K are specified)\",\n    )\n    parser.add_argument(\n        \"--N\",\n        type=int,\n        default=None,\n        help=\"Output dimension of weight matrix (number of columns)\",\n    )\n    parser.add_argument(\n        \"--K\",\n        type=int,\n        default=None,\n        help=\"Input dimension 
of weight matrix (number of rows)\",\n    )\n    parser.add_argument(\n        \"--input-type\", type=str, choices=[\"fp8\", \"int8\"], default=\"fp8\"\n    )\n    parser.add_argument(\n        \"--out-dtype\",\n        type=str,\n        choices=[\"float32\", \"float16\", \"bfloat16\", \"half\"],\n        default=\"float16\",\n    )\n    parser.add_argument(\"--block-n\", type=int, default=128)\n    parser.add_argument(\"--block-k\", type=int, default=128)\n    parser.add_argument(\"--batch-size\", type=int, required=False)\n    parser.add_argument(\n        \"--save-path\", type=str, default=\"python/sglang/srt/layers/quantization/configs\"\n    )\n    args = parser.parse_args()\n\n    # Validate arguments\n    if (args.N is None) != (args.K is None):\n        parser.error(\"--N and --K must be specified together or not at all\")\n\n    main(args)\n"
  },
  {
    "path": "benchmark/kernels/scheduler_batch/benchmark_get_last_loc_triton.py",
    "content": "import os\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.benchmark.bench_utils import run_bench\n\n\n@torch.compile(dynamic=True)\ndef get_last_loc_torch(\n    req_to_token: torch.Tensor,\n    req_pool_indices_tensor: torch.Tensor,\n    prefix_lens_tensor: torch.Tensor,\n) -> torch.Tensor:\n    return torch.where(\n        prefix_lens_tensor > 0,\n        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],\n        torch.full_like(prefix_lens_tensor, -1),\n    )\n\n\n@triton.jit\ndef get_last_loc_kernel(\n    req_to_token,\n    req_pool_indices_tensor,\n    prefix_lens_tensor,\n    result,\n    num_tokens,\n    req_to_token_stride,\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid = tl.program_id(0)\n    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE\n    mask = offset < num_tokens\n\n    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)\n    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)\n\n    token_mask = prefix_lens > 0\n    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)\n    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)\n\n    tl.store(result + offset, tokens, mask=mask)\n\n\ndef get_last_loc_triton(\n    req_to_token: torch.Tensor,\n    req_pool_indices_tensor: torch.Tensor,\n    prefix_lens_tensor: torch.Tensor,\n) -> torch.Tensor:\n    BLOCK_SIZE = 256\n    num_tokens = prefix_lens_tensor.shape[0]\n    result = torch.empty_like(prefix_lens_tensor)\n    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)\n\n    get_last_loc_kernel[grid](\n        req_to_token,\n        req_pool_indices_tensor,\n        prefix_lens_tensor,\n        result,\n        num_tokens,\n        req_to_token.stride(0),\n        BLOCK_SIZE,\n    )\n    return result\n\n\ndef test_get_last_loc():\n    max_batch = 4097\n    max_context_len = 6148\n    batch_size = 20\n\n    # Initialize input tensors\n    req_to_token = torch.zeros(\n 
       (max_batch, max_context_len), dtype=torch.int32, device=\"cuda\"\n    )\n    req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device=\"cuda\")\n    pre_lens = torch.randint(\n        -max_context_len // 2,\n        max_context_len,\n        (batch_size,),\n        dtype=torch.int64,\n        device=\"cuda\",\n    )\n\n    last_loc_res = get_last_loc_triton(req_to_token, req_pool_indices, pre_lens)\n    last_loc_ref = get_last_loc_torch(req_to_token, req_pool_indices, pre_lens)\n\n    # Compare results\n    torch.testing.assert_close(last_loc_res, last_loc_ref)\n\n\ndef get_benchmark():\n    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]\n\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"batch_size\"],\n            x_vals=batch_sizes,\n            line_arg=\"provider\",\n            line_vals=[\"reference\", \"triton\"],\n            line_names=[\"PyTorch\", \"Triton\"],\n            styles=[(\"blue\", \"-\"), (\"green\", \"-\")],\n            ylabel=\"us\",\n            plot_name=\"get-last-loc-performance\",\n            args={},\n        )\n    )\n    def benchmark(batch_size, provider):\n        max_batch = 2048\n        max_context_len = 16384\n\n        req_to_token = torch.zeros(\n            (max_batch, max_context_len), dtype=torch.int32, device=\"cuda\"\n        )\n        req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device=\"cuda\")\n        pre_lens = torch.randint(\n            -max_context_len // 2,\n            max_context_len,\n            (batch_size,),\n            dtype=torch.int64,\n            device=\"cuda\",\n        )\n\n        quantiles = [0.5, 0.2, 0.8]\n\n        if provider == \"reference\":\n            ms, min_ms, max_ms = run_bench(\n                lambda: get_last_loc_torch(req_to_token, req_pool_indices, pre_lens),\n                quantiles=tuple(quantiles),\n            )\n        elif provider == \"triton\":\n            ms, min_ms, 
max_ms = run_bench(\n                lambda: get_last_loc_triton(req_to_token, req_pool_indices, pre_lens),\n                quantiles=tuple(quantiles),\n            )\n\n        return 1000 * ms, 1000 * max_ms, 1000 * min_ms\n\n    return benchmark\n\n\ndef run_benchmark(save_path: str = \"./configs/benchmark_ops/get_last_loc/\"):\n    \"\"\"Run benchmark and save results\"\"\"\n\n    # Ensure save path exists\n    os.makedirs(save_path, exist_ok=True)\n\n    # Run correctness test\n    test_get_last_loc()\n    print(\"Correctness test passed!\")\n\n    # Run performance test\n    benchmark = get_benchmark()\n    benchmark.run(print_data=True, save_path=save_path)\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--save_path\",\n        type=str,\n        default=\"./configs/benchmark_ops/get_last_loc/\",\n        help=\"Path to save benchmark results\",\n    )\n    args = parser.parse_args()\n\n    run_benchmark(args.save_path)\n"
  },
  {
    "path": "benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py",
    "content": "import itertools\nimport os\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.benchmark.bench_utils import run_bench\n\n\n@triton.jit\ndef write_req_to_token_pool_triton(\n    req_to_token_ptr,  # [max_batch, max_context_len]\n    req_pool_indices,\n    pre_lens,\n    seq_lens,\n    extend_lens,\n    out_cache_loc,\n    req_to_token_ptr_stride: tl.constexpr,\n):\n    BLOCK_SIZE: tl.constexpr = 512\n    pid = tl.program_id(0)\n\n    req_pool_index = tl.load(req_pool_indices + pid)\n    pre_len = tl.load(pre_lens + pid)\n    seq_len = tl.load(seq_lens + pid)\n\n    # TODO: optimize this?\n    cumsum_start = 0\n    for i in range(pid):\n        cumsum_start += tl.load(extend_lens + i)\n\n    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)\n    for i in range(num_loop):\n        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE\n        mask = offset < (seq_len - pre_len)\n        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)\n        tl.store(\n            req_to_token_ptr\n            + req_pool_index * req_to_token_ptr_stride\n            + offset\n            + pre_len,\n            value,\n            mask=mask,\n        )\n\n\n@triton.jit\ndef write_req_to_token_pool_triton_optimize(\n    req_to_token_ptr,  # [max_batch, max_context_len]\n    req_pool_indices,\n    pre_lens,\n    seq_lens,\n    extend_lens,\n    out_cache_loc,\n    req_to_token_ptr_stride: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid_batch = tl.program_id(0)\n    pid_token = tl.program_id(1)\n\n    req_pool_index = tl.load(req_pool_indices + pid_batch)\n    pre_len = tl.load(pre_lens + pid_batch)\n    seq_len = tl.load(seq_lens + pid_batch)\n    extend_len = seq_len - pre_len\n\n    cumsum_start = 0\n    for i in range(pid_batch):\n        cumsum_start += tl.load(extend_lens + i)\n\n    token_start = pid_token * BLOCK_SIZE\n\n    offset = tl.arange(0, BLOCK_SIZE)\n    actual_offset = token_start + offset\n    mask = 
actual_offset < extend_len\n\n    src_ptr = out_cache_loc + cumsum_start + actual_offset\n    src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE)\n    value = tl.load(src_ptr, mask=mask)\n    dst_ptr = (\n        req_to_token_ptr\n        + req_pool_index * req_to_token_ptr_stride\n        + actual_offset\n        + pre_len\n    )\n    dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE)\n\n    tl.store(dst_ptr, value, mask=mask)\n\n\ndef write_req_to_token_pool_reference(\n    req_to_token: torch.Tensor,\n    req_pool_indices: torch.Tensor,\n    pre_lens: torch.Tensor,\n    seq_lens: torch.Tensor,\n    extend_lens: torch.Tensor,\n    out_cache_loc: torch.Tensor,\n) -> None:\n    \"\"\"Reference implementation using PyTorch\"\"\"\n    for i in range(len(req_pool_indices)):\n        req_pool_idx = req_pool_indices[i].item()\n        pre_len = pre_lens[i].item()\n        seq_len = seq_lens[i].item()\n        extend_len = extend_lens[i].item()\n\n        cumsum_start = sum(extend_lens[:i].tolist())\n\n        # Copy values from out_cache_loc to req_to_token\n        req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[\n            cumsum_start : cumsum_start + extend_len\n        ]\n\n\ndef test_write_req_to_token_pool():\n    max_batch = 4097\n    max_context_len = 6148\n    batch_size = 1\n    extend_len = 14\n\n    # Initialize input tensors\n    req_to_token = torch.zeros(\n        (max_batch, max_context_len), dtype=torch.int32, device=\"cuda\"\n    )\n    req_pool_indices = torch.tensor([42], dtype=torch.int32, device=\"cuda\")\n    pre_lens = torch.tensor([8], dtype=torch.int32, device=\"cuda\")\n    seq_lens = torch.tensor([22], dtype=torch.int32, device=\"cuda\")\n    extend_lens = torch.tensor([extend_len], dtype=torch.int32, device=\"cuda\")\n    out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device=\"cuda\")\n\n    # Create copies for reference implementation\n    req_to_token_ref = 
req_to_token.clone()\n    req_to_token_opt = req_to_token.clone()\n\n    # Run original triton kernel\n    write_req_to_token_pool_triton[(batch_size,)](\n        req_to_token,\n        req_pool_indices,\n        pre_lens,\n        seq_lens,\n        extend_lens,\n        out_cache_loc,\n        max_context_len,\n    )\n\n    # Run optimized triton kernel\n    def grid(batch_size, extend_len):\n        num_token_blocks = triton.cdiv(extend_len, 512)\n        return (batch_size, num_token_blocks)\n\n    write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)](\n        req_to_token_opt,\n        req_pool_indices,\n        pre_lens,\n        seq_lens,\n        extend_lens,\n        out_cache_loc,\n        max_context_len,\n        BLOCK_SIZE=512,\n    )\n\n    # Run reference implementation\n    write_req_to_token_pool_reference(\n        req_to_token_ref,\n        req_pool_indices,\n        pre_lens,\n        seq_lens,\n        extend_lens,\n        out_cache_loc,\n    )\n\n    # Compare results\n    torch.testing.assert_close(req_to_token, req_to_token_ref)\n    torch.testing.assert_close(req_to_token_opt, req_to_token_ref)\n\n    # Test case 2: batch size > 1\n    batch_size = 3\n    extend_lens_list = [14, 20, 30]\n    total_extend_len = sum(extend_lens_list)\n\n    req_to_token = torch.zeros(\n        (max_batch, max_context_len), dtype=torch.int32, device=\"cuda\"\n    )\n    req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device=\"cuda\")\n    pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device=\"cuda\")\n    seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device=\"cuda\")\n    extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device=\"cuda\")\n    out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device=\"cuda\")\n\n    req_to_token_ref = req_to_token.clone()\n    req_to_token_opt = req_to_token.clone()\n\n    # Run original triton kernel\n    
write_req_to_token_pool_triton[(batch_size,)](\n        req_to_token,\n        req_pool_indices,\n        pre_lens,\n        seq_lens,\n        extend_lens,\n        out_cache_loc,\n        max_context_len,\n    )\n\n    # Run optimized triton kernel\n    max_extend_len = max(extend_lens_list)\n    write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)](\n        req_to_token_opt,\n        req_pool_indices,\n        pre_lens,\n        seq_lens,\n        extend_lens,\n        out_cache_loc,\n        max_context_len,\n        BLOCK_SIZE=512,\n    )\n\n    # Run reference implementation\n    write_req_to_token_pool_reference(\n        req_to_token_ref,\n        req_pool_indices,\n        pre_lens,\n        seq_lens,\n        extend_lens,\n        out_cache_loc,\n    )\n\n    # Compare results\n    torch.testing.assert_close(req_to_token, req_to_token_ref)\n    torch.testing.assert_close(req_to_token_opt, req_to_token_ref)\n\n\ndef get_benchmark():\n    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]\n    extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]\n    configs = list(itertools.product(batch_sizes, extend_lens))\n\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"batch_size\", \"extend_len\"],\n            x_vals=configs,\n            line_arg=\"provider\",\n            line_vals=[\"reference\", \"triton\", \"triton_optimize\"],\n            line_names=[\"PyTorch\", \"Triton\", \"Triton Optimized\"],\n            styles=[(\"blue\", \"-\"), (\"green\", \"-\"), (\"red\", \"-\")],\n            ylabel=\"us\",\n            plot_name=\"write-req-to-token-pool-performance\",\n            args={},\n        )\n    )\n    def benchmark(batch_size, extend_len, provider):\n        max_batch = 256\n        max_context_len = 16384\n\n        extend_lens_list = [extend_len] * batch_size\n        total_extend_len = sum(extend_lens_list)\n\n        req_to_token = torch.zeros(\n            (max_batch, 
max_context_len), dtype=torch.int32, device=\"cuda\"\n        )\n        req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device=\"cuda\")\n        pre_lens = torch.ones(batch_size, dtype=torch.int32, device=\"cuda\") * 8\n        seq_lens = pre_lens + extend_len\n        extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device=\"cuda\")\n        out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device=\"cuda\")\n\n        quantiles = [0.5, 0.2, 0.8]\n\n        if provider == \"reference\":\n            ms, min_ms, max_ms = run_bench(\n                lambda: write_req_to_token_pool_reference(\n                    req_to_token.clone(),\n                    req_pool_indices,\n                    pre_lens,\n                    seq_lens,\n                    extend_lens,\n                    out_cache_loc,\n                ),\n                quantiles=tuple(quantiles),\n            )\n        elif provider == \"triton\":\n            ms, min_ms, max_ms = run_bench(\n                lambda: write_req_to_token_pool_triton[(batch_size,)](\n                    req_to_token.clone(),\n                    req_pool_indices,\n                    pre_lens,\n                    seq_lens,\n                    extend_lens,\n                    out_cache_loc,\n                    max_context_len,\n                ),\n                quantiles=tuple(quantiles),\n            )\n        else:\n\n            def run_optimized():\n                block_size = 128 if extend_len <= 1024 else 512\n                grid_config = (batch_size, triton.cdiv(extend_len, block_size))\n                write_req_to_token_pool_triton_optimize[grid_config](\n                    req_to_token.clone(),\n                    req_pool_indices,\n                    pre_lens,\n                    seq_lens,\n                    extend_lens,\n                    out_cache_loc,\n                    max_context_len,\n                    BLOCK_SIZE=block_size,\n         
       )\n\n            ms, min_ms, max_ms = run_bench(run_optimized, quantiles=tuple(quantiles))\n\n        return 1000 * ms, 1000 * max_ms, 1000 * min_ms\n\n    return benchmark\n\n\ndef run_benchmark(save_path: str = \"./configs/benchmark_ops/write_req_to_token_pool/\"):\n    \"\"\"Run benchmark and save results\"\"\"\n\n    # Ensure save path exists\n    os.makedirs(save_path, exist_ok=True)\n\n    # Run correctness test\n    test_write_req_to_token_pool()\n    print(\"Correctness test passed!\")\n\n    # Run performance test\n    benchmark = get_benchmark()\n    benchmark.run(print_data=True, save_path=save_path)\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--save_path\",\n        type=str,\n        default=\"./configs/benchmark_ops/write_req_to_token_pool/\",\n        help=\"Path to save benchmark results\",\n    )\n    args = parser.parse_args()\n\n    run_benchmark(args.save_path)\n"
  },
  {
    "path": "benchmark/kernels/sliding_window_attention_triton/bench_triton_swa_kernel.py",
    "content": "import itertools\n\nimport torch\nimport torch.nn.functional as F\nimport triton.testing as tt\n\nfrom sglang.benchmark.bench_utils import run_bench\nfrom sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd\n\n\ndef extend_attention_fwd_torch(\n    q: torch.Tensor,  # [extend_tokens, H_Q, D]\n    k: torch.Tensor,  # [extend_tokens, H_KV, D]\n    v: torch.Tensor,  # [extend_tokens, H_KV, D]\n    o: torch.Tensor,  # [extend_tokens, H_Q, D]\n    k_cache: torch.Tensor,  # [total_tokens, H_KV, D]\n    v_cache: torch.Tensor,  # [total_tokens, H_KV, D]\n    qo_indptr: torch.Tensor,  # [B+1]\n    kv_indptr: torch.Tensor,  # [B+1]\n    kv_indices: torch.Tensor,  # [prefix_tokens]\n    sliding_window_size: int,\n):\n    B = qo_indptr.size(0) - 1\n    _, H_Q, D = q.shape\n    _, H_KV, _ = k.shape\n\n    group_size = H_Q // H_KV\n    scale = 1.0 / D**0.5\n\n    for i in range(B):\n        q_start = int(qo_indptr[i].item())\n        q_end = int(qo_indptr[i + 1].item())\n        kv_start = int(kv_indptr[i].item())\n        kv_end = int(kv_indptr[i + 1].item())\n\n        prefix_indices = kv_indices[kv_start:kv_end]\n        k_prefix = k_cache[prefix_indices]  # [prefix_len, H_KV, D]\n        v_prefix = v_cache[prefix_indices]  # [prefix_len, H_KV, D]\n\n        k_extend = k[q_start:q_end]  # [extend_len, H_KV, D]\n        v_extend = v[q_start:q_end]  # [extend_len, H_KV, D]\n        q_extend = q[q_start:q_end]  # [extend_len, H_Q,  D]\n\n        k_full = torch.cat([k_prefix, k_extend], dim=0)  # [total_len, H_KV, D]\n        v_full = torch.cat([v_prefix, v_extend], dim=0)  # [total_len, H_KV, D]\n\n        if group_size != 1:\n            k_full_hq = k_full.repeat_interleave(\n                group_size, dim=1\n            )  # [total_len, H_Q, D]\n            v_full_hq = v_full.repeat_interleave(\n                group_size, dim=1\n            )  # [total_len, H_Q, D]\n        else:\n            k_full_hq = k_full\n            
v_full_hq = v_full\n\n        prefix_len = k_prefix.size(0)\n        extend_len = k_extend.size(0)\n        total_len = prefix_len + extend_len\n\n        # causal\n        pos_keys = torch.arange(total_len, device=q.device)\n        t = prefix_len + torch.arange(extend_len, device=q.device)  # [extend_len]\n        causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)\n\n        # sliding window\n        if sliding_window_size is not None and sliding_window_size > 0:\n            start = (t - (sliding_window_size)).clamp_min(0)  # [extend_len]\n        else:\n            start = torch.zeros_like(t)\n        window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)\n\n        final_mask = causal_mask & window_mask\n\n        attn_scores = (\n            torch.einsum(\"qhd,khd->qhk\", q_extend, k_full_hq) * scale\n        )  # [extend_len, H_Q, total_len]\n        attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float(\"-inf\"))\n\n        attn_weights = F.softmax(attn_scores, dim=-1)\n        o[q_start:q_end] = torch.einsum(\"qhk,khd->qhd\", attn_weights, v_full_hq)\n\n\ndef _build_batch(\n    B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=torch.bfloat16, device=\"cuda\"\n):\n    b_seq_len_prefix = torch.randint(\n        1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device\n    )\n    b_seq_len_extend = torch.randint(\n        1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device\n    )\n    b_seq_len = b_seq_len_prefix + b_seq_len_extend\n\n    b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)\n    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)\n    b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device)\n    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)\n\n    kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)\n    kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)\n\n    kv_indices = torch.zeros(\n        (int(b_seq_len_prefix.sum().item()),), 
dtype=torch.int32, device=device\n    )\n    for i in range(B):\n        s = kv_indptr[i].item()\n        e = kv_indptr[i + 1].item()\n        kv_indices[s:e] = torch.arange(\n            b_start_loc[i],\n            b_start_loc[i] + b_seq_len_prefix[i],\n            dtype=torch.int32,\n            device=device,\n        )\n\n    total_token_num = int(torch.sum(b_seq_len).item())\n    extend_token_num = int(torch.sum(b_seq_len_extend).item())\n\n    k_buffer = torch.empty(\n        (total_token_num, H_KV, D), dtype=dtype, device=device\n    ).normal_(mean=0.1, std=0.2)\n    v_buffer = torch.empty(\n        (total_token_num, H_KV, D), dtype=dtype, device=device\n    ).normal_(mean=0.1, std=0.2)\n\n    k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)\n    v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)\n    q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)\n\n    for i in range(B):\n        extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]\n        extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]\n        extend_start = b_start_loc_extend[i]\n        extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]\n\n        k_extend[extend_start:extend_end] = k_buffer[\n            extend_start_in_buffer:extend_end_in_buffer\n        ]\n        v_extend[extend_start:extend_end] = v_buffer[\n            extend_start_in_buffer:extend_end_in_buffer\n        ]\n        q_extend[extend_start:extend_end] = torch.empty(\n            (int(b_seq_len_extend[i].item()), H_Q, D), dtype=dtype, device=device\n        ).normal_(mean=0.1, std=0.2)\n\n    o_extend_triton = torch.empty(\n        (extend_token_num, H_Q, D), dtype=dtype, device=device\n    )\n    o_extend_torch = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)\n\n    b_seq_len_extend = b_seq_len - b_seq_len_prefix\n    max_len_extend = int(torch.max(b_seq_len_extend, 0)[0].item())\n    qo_indptr = 
torch.zeros((B + 1,), dtype=torch.int32, device=device)\n    qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)\n\n    inputs = dict(\n        q_extend=q_extend,\n        k_extend=k_extend,\n        v_extend=v_extend,\n        k_buffer=k_buffer,\n        v_buffer=v_buffer,\n        o_extend_triton=o_extend_triton,\n        o_extend_torch=o_extend_torch,\n        qo_indptr=qo_indptr,\n        kv_indptr=kv_indptr,\n        kv_indices=kv_indices,\n        max_len_extend=max_len_extend,\n        WINDOW_SIZE=WINDOW_SIZE,\n    )\n    meta = dict(\n        B=B, N_CTX=N_CTX, H_Q=H_Q, H_KV=H_KV, D=D, extend_token_num=extend_token_num\n    )\n    return inputs, meta\n\n\ndef _run_triton(inputs):\n    extend_attention_fwd(\n        inputs[\"q_extend\"],\n        inputs[\"k_extend\"],\n        inputs[\"v_extend\"],\n        inputs[\"o_extend_triton\"],\n        inputs[\"k_buffer\"],\n        inputs[\"v_buffer\"],\n        inputs[\"qo_indptr\"],\n        inputs[\"kv_indptr\"],\n        inputs[\"kv_indices\"],\n        custom_mask=None,\n        is_causal=True,\n        mask_indptr=None,\n        max_len_extend=inputs[\"max_len_extend\"],\n        sliding_window_size=inputs[\"WINDOW_SIZE\"],\n    )\n\n\ndef _run_torch_ref(inputs):\n    extend_attention_fwd_torch(\n        inputs[\"q_extend\"],\n        inputs[\"k_extend\"],\n        inputs[\"v_extend\"],\n        inputs[\"o_extend_torch\"],\n        inputs[\"k_buffer\"],\n        inputs[\"v_buffer\"],\n        inputs[\"qo_indptr\"],\n        inputs[\"kv_indptr\"],\n        inputs[\"kv_indices\"],\n        inputs[\"WINDOW_SIZE\"],\n    )\n\n\nN_CTXS = [1024, 2048, 4096, 8192]\nWINDOW_SIZES = [-1, 127, 256, 512]\n\nCONFIGS = list(itertools.product(N_CTXS, WINDOW_SIZES))\n\nPROVIDERS = [\"torch\", \"triton\"]\n\n\n@tt.perf_report(\n    tt.Benchmark(\n        x_names=[\"N_CTX\", \"WINDOW_SIZE\"],\n        x_vals=CONFIGS,\n        line_arg=\"provider\",\n        line_vals=PROVIDERS,\n        line_names=PROVIDERS,\n      
  ylabel=\"Runtime (ms)\",\n        plot_name=\"extend_attention_triton_vs_torch\",\n        args={\n            \"B\": 32,\n            \"H_Q\": 64,\n            \"H_KV\": 8,\n            \"D\": 128,\n            \"dtype\": \"bf16\",\n            \"device\": \"cuda\",\n            \"check_correctness\": False,\n            \"warmup\": 25,\n            \"rep\": 100,\n        },\n    )\n)\ndef bench(\n    N_CTX,\n    provider,\n    B,\n    H_Q,\n    H_KV,\n    D,\n    dtype,\n    device,\n    WINDOW_SIZE,\n    check_correctness,\n    warmup,\n    rep,\n):\n    torch.manual_seed(0)\n    torch.cuda.manual_seed(0)\n    dtype_map = {\"bf16\": torch.bfloat16, \"fp16\": torch.float16, \"fp32\": torch.float32}\n    dt = dtype_map[dtype]\n\n    inputs, _ = _build_batch(\n        B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=dt, device=device\n    )\n\n    if check_correctness and provider == \"triton\":\n        _run_triton(inputs)\n        _run_torch_ref(inputs)\n        torch.cuda.synchronize()\n        if not torch.allclose(\n            inputs[\"o_extend_triton\"], inputs[\"o_extend_torch\"], rtol=1e-3, atol=1e-3\n        ):\n            raise AssertionError(\"Mismatch between triton and torch reference.\")\n\n    if provider == \"triton\":\n        ms = run_bench(\n            lambda: _run_triton(inputs),\n            quantiles=None,\n            warmup_ms=warmup,\n            rep_ms=rep,\n        )[0]\n    elif provider == \"torch\":\n        ms = run_bench(\n            lambda: _run_torch_ref(inputs),\n            quantiles=None,\n            warmup_ms=warmup,\n            rep_ms=rep,\n        )[0]\n    else:\n        raise ValueError(provider)\n\n    return ms\n\n\nif __name__ == \"__main__\":\n    bench.run(print_data=True, show_plots=False)\n"
  },
  {
    "path": "benchmark/line_retrieval/README.md",
    "content": "## Download data\n\n```\nwget https://raw.githubusercontent.com/merrymercy/merrymercy.github.io/master/files/random_words.json\npython3 gen_data.py --number 1000\n```\n\n## Run benchmark\n\n### Benchmark sglang\n```\npython3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-hf --port 30000\n```\n\n```\npython3 bench_sglang.py --src-index 600 --num-q 50 --parallel 1\n```\n\n\n###\n\n```\n# original\nAccuracy: 0.940, latency: 332.83 s\n\n# parallel encoding (no_adjust, offset = 1000)\nAccuracy: 0.760, latency: 238.46 s\n\n# parallel encoding (no_adjust, offset = 3000)\nAccuracy: 0.760, latency: 238.46 s\n\n# parallel encoding (no_adjust, offset = 0)\nAccuracy: 0.520, latency: 238.46 s\n\n# parallel encoding (adjust_cache)\nAccuracy: 0.460, latency: 257.66 s\n```\n"
  },
  {
    "path": "benchmark/line_retrieval/bench_sglang.py",
    "content": "import argparse\nimport json\nimport re\nimport time\n\nimport numpy as np\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text\n\n\n@sgl.function\ndef line_retrieval(s, prefix, suffix, body_0, body_1, body_2, body_3):\n    s += prefix + \"\\n\"\n\n    contexts = [body_0, body_1, body_2, body_3]\n    position_ids_offset = [i * 1000 for i in range(len(contexts))]\n    forks = s.fork(len(contexts), position_ids_offset)\n    forks += lambda i: contexts[i] + \"\\n\"\n    forks.join(mode=\"concate_and_append\")\n\n    s += \"\\n\" + suffix\n    s += sgl.gen(\"answer\", max_tokens=16)\n\n\ndef eval_model(args, line_obj, num_hoops, src_indices, dst_percents):\n    arguments = []\n    labels = []\n    sum_src_indices = []\n    sum_dst_indices = []\n\n    for i in range(len(src_indices)):\n        for j in range(len(dst_percents)):\n            src_index = src_indices[i]\n            dst_percent = dst_percents[j]\n\n            query_indices = line_obj[\"group_by_num_hoops\"][str(num_hoops)]\n            query_indices = [\n                q\n                for q in query_indices\n                if all(l <= src_index for l in line_obj[\"links\"][q]) and q < src_index\n            ]\n            dst_index = query_indices[\n                min(int(len(query_indices) * dst_percent), len(query_indices) - 1)\n            ]\n            label = line_obj[\"values\"][dst_index]\n\n            body = line_obj[\"lines\"][: src_index + 1]\n            suffix = line_obj[\"suffix\"].replace(\"???\", line_obj[\"indices\"][dst_index])\n            body_part_len = len(body) // 4\n\n            arguments.append(\n                {\n                    \"prefix\": line_obj[\"prefix\"],\n                    \"body_0\": \"\\n\".join(body[:body_part_len]),\n                    \"body_1\": \"\\n\".join(body[body_part_len : 2 * body_part_len]),\n                
    \"body_2\": \"\\n\".join(body[2 * body_part_len : 3 * body_part_len]),\n                    \"body_3\": \"\\n\".join(body[3 * body_part_len :]),\n                    \"suffix\": suffix,\n                }\n            )\n            labels.append(label)\n            sum_src_indices.append(src_index)\n            sum_dst_indices.append(dst_index)\n\n    # Select backend\n    backend = select_sglang_backend(args)\n\n    tic = time.perf_counter()\n    states = line_retrieval.run_batch(\n        arguments,\n        temperature=0,\n        backend=backend,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    latency = time.perf_counter() - tic\n\n    corrects = []\n    for i in range(len(arguments)):\n        output = states[i][\"answer\"]\n        prompt_len = states[i].get_meta_info(\"answer\").get(\"prompt_length\", -1)\n        label = labels[i]\n\n        # Try all numbers\n        findall = re.findall(\"\\d+\", output)\n        if not findall:\n            response_number = output\n        else:\n            for response_number in findall:\n                if response_number == label:\n                    break\n\n        correct = response_number == label\n        corrects.append(correct)\n\n        # Log results\n        summary = (\n            f\"Line index: {sum_src_indices[i]} -> {sum_dst_indices[i]}, \"\n            f\"Prompt len: {prompt_len}, \"\n            f\"Correct: {correct}, \"\n            f\"Label: {label}, Predicted: {response_number}, \"\n        )\n        print(summary)\n\n    accuracy = np.mean(corrects)\n    print(f\"Accuracy: {accuracy:.3f}, latency: {latency:.2f} s\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"line_retrieval\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": 
len(arguments),\n            \"other\": {\n                \"num_questions\": len(arguments),\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\ndef main(args):\n    line_obj = json.load(open(args.data_path, \"r\"))\n\n    num_hoops = args.num_hoops\n    for src_index in args.src_index:\n        src_indices = [src_index]\n        num_queries = args.num_queries_per_src\n        dst_percents = [i * (1 / (num_queries)) for i in range(num_queries)]\n        eval_model(args, line_obj, num_hoops, src_indices, dst_percents)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"lines_1000_0.0.json\")\n    parser.add_argument(\"--src-index\", type=int, nargs=\"+\", default=[100])\n    parser.add_argument(\"--num-queries-per-src\", type=int, default=10)\n    parser.add_argument(\"--num-hoops\", type=int, default=1)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/line_retrieval/gen_data.py",
    "content": "\"\"\"\nGenerate line data for line retrieval task.\n\nUsage:\npython3 gen_data.py --number 1000\n\"\"\"\n\nimport argparse\nimport json\nfrom collections import defaultdict\n\nimport numpy as np\nfrom tqdm import tqdm\n\n\ndef generate_lines(random_words, num_lines, redirect_ratio):\n    prefix = \"Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask.\"\n    suffix = \"The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resolving the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is\"\n\n    # Raw lines\n    visited_indices = set([None])\n    visited_values = set([None])\n\n    lines = []\n    redirects = []\n    indices = []\n    values = []\n    for i in tqdm(range(num_lines)):\n        line_index = None\n        while line_index in visited_indices:\n            line_index = \"-\".join(np.random.choice(random_words, size=(2,)))\n        visited_indices.add(line_index)\n\n        line_value = np.random.randint(low=0, high=999999)\n        line_value = f\"{line_value:06}\"\n\n        line = f\"Line {line_index}: The REGISTER_CONTENT is {line_value}.\"\n        lines.append(line)\n        redirects.append(None)\n        indices.append(line_index)\n        values.append(line_value)\n\n    # Add redirect\n    if redirect_ratio > 0:\n        num_redirect_lines = int(len(lines) * redirect_ratio)\n        redirect_indices = np.random.choice(\n            np.arange(len(lines)), size=(num_redirect_lines,), replace=False\n        )\n        for i in redirect_indices:\n            target_idx = np.random.choice(min(i * 2 + 100, num_lines))\n            lines[i] = (\n                f\"Line {indices[i]}: The 
REGISTER_CONTENT is the same as Line {indices[target_idx]}.\"\n            )\n            redirects[i] = target_idx\n\n    # Build links and find sources\n    links = [[] for _ in range(num_lines)]\n    contains_ring = set()\n    for i in range(num_lines):\n        if redirects[i] is None:\n            continue\n\n        tmp_link = []\n        cur = i\n        visited = set()\n        while redirects[cur] is not None:\n            visited.add(cur)\n            tmp_link.append(redirects[cur])\n            cur = redirects[cur]\n\n            if cur in visited:\n                contains_ring.add(i)\n                tmp_link = None\n                break\n        values[i] = values[cur]\n        links[i] = tmp_link\n\n    # Group by num_links\n    group_by_num_hoops = defaultdict(list)\n    for i in range(num_lines):\n        if i in contains_ring:\n            continue\n        group_by_num_hoops[len(links[i]) + 1].append(i)\n\n    keys = sorted(list(group_by_num_hoops.keys()))\n    for num_links in keys:\n        print(f\"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}\")\n\n    # Append few-shot examples\n    hoop1_candidates = list(group_by_num_hoops[1])\n    hoop1_candidate_keys = {c: max([c] + links[c]) for c in hoop1_candidates}\n    hoop1_candidates.sort(key=lambda c: hoop1_candidate_keys[c])\n    hoop2_candidates = list(group_by_num_hoops[2])\n    hoop2_candidate_keys = {c: max([c] + links[c]) for c in hoop2_candidates}\n    hoop2_candidates.sort(key=lambda c: hoop2_candidate_keys[c])\n\n    i = hoop1_candidates[5]\n    suffix = suffix.replace(\"__idx0__\", indices[i]).replace(\"__val0__\", values[i])\n    if len(hoop2_candidates):\n        i = hoop2_candidates[0]\n        suffix = suffix.replace(\"__idx1__\", indices[i]).replace(\"__val1__\", values[i])\n        i = hoop2_candidates[1]\n        suffix = suffix.replace(\"__idx2__\", indices[i]).replace(\"__val2__\", values[i])\n    else:\n        i = hoop1_candidates[1]\n        suffix = 
suffix.replace(\"__idx1__\", indices[i]).replace(\"__val1__\", values[i])\n        i = hoop1_candidates[10]\n        suffix = suffix.replace(\"__idx2__\", indices[i]).replace(\"__val2__\", values[i])\n\n    obj = {\n        \"prefix\": prefix,\n        \"suffix\": suffix,\n        \"lines\": lines,\n        \"indices\": indices,\n        \"values\": values,\n        \"links\": links,\n        \"group_by_num_hoops\": group_by_num_hoops,\n        \"contains_ring\": sorted(list(contains_ring)),\n    }\n    return obj\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--number\", type=int)\n    parser.add_argument(\"--redirect-ratio\", type=float, default=0.0)\n    args = parser.parse_args()\n\n    num_lines = args.number\n\n    random_words_filename = \"random_words.json\"\n    random_words = json.load(open(random_words_filename, \"r\"))\n\n    np.random.seed(42)\n    obj = generate_lines(random_words, num_lines, args.redirect_ratio)\n\n    fout = f\"lines_{num_lines}_{args.redirect_ratio:.1f}.json\"\n    with open(fout, \"w\") as fout:\n        json.dump(obj, fout, indent=2)\n"
  },
  {
    "path": "benchmark/llava_bench/README.md",
    "content": "## Download benchmark images\n\n```\npython3 download_images.py\n```\n\nimage benchmark source: https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild\n\n### Other Dependency\n```\npip3 install \"sglang[all]\"\npip3 install \"torch>=2.1.2\" \"transformers>=4.36\" pillow\n```\n\n## Run benchmark\n\n### Benchmark sglang\nLaunch a server\n```\npython3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000\n```\n\nRun benchmark\n```\n# Run with local models\npython3 bench_sglang.py --num-questions 60\n\n# Run with OpenAI models\npython3 bench_sglang.py --num-questions 60 --backend gpt-4-vision-preview\n```\n\n### Bench LLaVA original code\n```\ngit clone git@github.com:haotian-liu/LLaVA.git\ncd LLaVA\ngit reset --hard 9a26bd1435b4ac42c282757f2c16d34226575e96\npip3 install -e .\n\ncd ~/sglang/benchmark/llava_bench\nCUDA_VISIBLE_DEVICES=0 bash bench_hf_llava_bench.sh\n```\n\n\n### Benchmark llama.cpp\n\n```\n# Install\nCMAKE_ARGS=\"-DLLAMA_CUBLAS=on\" pip install llama-cpp-python\npip install sse_starlette starlette_context pydantic_settings\n\n# Download weights\nmkdir -p ~/model_weights/llava-v1.5-7b/\nwget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/ggml-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf\nwget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/mmproj-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf\n```\n\n```\npython3 -m llama_cpp.server --model ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf --clip_model_path ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf --chat_format llava-1-5 --port 23000\n\nOPENAI_BASE_URL=http://localhost:23000/v1 python3 bench_sglang.py --backend gpt-4-vision-preview --num-q 1\n```\n"
  },
  {
    "path": "benchmark/llava_bench/bench_hf_llava_bench.sh",
    "content": "#!/bin/bash\n\npython -m llava.eval.model_vqa \\\n    --model-path liuhaotian/llava-v1.5-7b \\\n    --question-file ./questions.jsonl \\\n    --image-folder ./images \\\n    --answers-file ./answers_hf.jsonl \\\n    --temperature 0 \\\n    --conv-mode vicuna_v1\n"
  },
  {
    "path": "benchmark/llava_bench/bench_hf_mme.sh",
    "content": "#!/bin/bash\n\npython -m llava.eval.model_vqa_loader \\\n    --model-path liuhaotian/llava-v1.5-7b \\\n    --question-file ./mme_pack/llava_mme_bench_replace.jsonl \\\n    --image-folder ./mme_pack/MME_Benchmark_release_version \\\n    --answers-file ./answers_hf_mme.jsonl \\\n    --temperature 0 \\\n    --conv-mode vicuna_v1\n"
  },
  {
    "path": "benchmark/llava_bench/bench_sglang.py",
    "content": "import argparse\nimport json\nimport os\nimport time\n\nimport tqdm\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text, read_jsonl\n\n\n@sgl.function\ndef image_qa(s, image_file, question):\n    s += sgl.user(sgl.image(image_file) + question)\n    s += sgl.assistant(sgl.gen(\"answer\", max_tokens=args.max_tokens))\n\n\ndef main(args):\n    lines = list(read_jsonl(args.question_file))[: args.num_questions]\n    arguments = [\n        {\n            \"image_file\": os.path.abspath(args.image_folder + \"/\" + l[\"image\"]),\n            \"question\": l[\"text\"],\n        }\n        for l in lines\n    ]\n    # arguments = [\n    #    {\"image_file\":\n    #        Image.open(os.path.abspath(args.image_folder + \"/\" + l[\"image\"])),\n    #      \"question\": l[\"text\"]} for l in lines\n    # ]\n\n    states = [None] * len(lines)\n\n    # Select backend\n    backend = select_sglang_backend(args)\n    sgl.set_default_backend(backend)\n\n    # Run requests\n    tic = time.perf_counter()\n    if args.parallel == 1:\n        for i in tqdm.tqdm(range(len(lines))):\n            image_file = arguments[i][\"image_file\"]\n            question = arguments[i][\"question\"]\n            ret = image_qa.run(image_file=image_file, question=question, temperature=0)\n            states[i] = ret\n    else:\n        states = image_qa.run_batch(\n            arguments, temperature=0, num_threads=args.parallel, progress_bar=True\n        )\n    latency = time.perf_counter() - tic\n\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    print(f\"Write output to {args.answer_file}\")\n    with open(args.answer_file, \"w\") as fout:\n        for i in range(len(lines)):\n            value = {\n                \"question_id\": lines[i][\"question_id\"],\n                
\"prompt\": lines[i][\"text\"],\n                \"text\": states[i][\"answer\"].strip(),\n                \"model_id\": backend.model_info[\"model_path\"],\n                \"answer_id\": i,\n                \"metadata\": {},\n            }\n            fout.write(json.dumps(value) + \"\\n\")\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"llava_bench\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": len(lines),\n            \"parallel\": args.parallel,\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--question-file\", type=str, default=\"questions.jsonl\")\n    parser.add_argument(\"--answer-file\", type=str, default=\"answers.jsonl\")\n    parser.add_argument(\"--image-folder\", type=str, default=\"./images\")\n    parser.add_argument(\"--temperature\", type=float, default=0.0)\n    parser.add_argument(\"--num-questions\", type=int, default=None)\n    parser.add_argument(\"--max-tokens\", type=int, default=768)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/llava_bench/bench_sglang_mme.sh",
    "content": "MME_FOLDER=./mme_pack\npython3 bench_sglang.py --num-questions 5000 --question-file $MME_FOLDER/llava_mme_bench_replace.jsonl --answer-file answer_mme.jsonl --image-folder $MME_FOLDER/MME_Benchmark_release_version --max-tokens 4\n"
  },
  {
    "path": "benchmark/llava_bench/download_images.py",
    "content": "import os\n\n# Create the 'images' directory if it doesn't exist\nif not os.path.exists(\"images\"):\n    os.makedirs(\"images\")\n\n# Base URL\nbase_url = \"https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/\"\n\n# Loop through image numbers\nfor i in range(1, 25):\n    # Format the image number with leading zeros\n    image_number = str(i).zfill(3)\n    image_url = base_url + image_number + \".jpg\"\n    image_path = \"images/\" + image_number + \".jpg\"\n\n    # Download the image using wget\n    os.system(f\"wget -O {image_path} {image_url}\")\n\nprint(\"Download complete.\")\n"
  },
  {
    "path": "benchmark/llm_judge/README.md",
    "content": "## Run benchmark\n\n### Benchmark sglang\n```\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\n```\npython3 bench_sglang.py --num-questions 25 --parallel 8\npython3 bench_sglang.py --num-questions 16 --parallel 1\n```\n\n\n### Benchmark vllm\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000\n```\n\n```\npython3 bench_other.py --backend vllm --num-questions 25\n```\n\n\n### Benchmark guidance\n```\npython3 bench_other.py --backend guidance --num-questions 25 --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n\n### Benchmark lmql\n\n```\npython3 bench_other.py --backend lmql --num-questions 25 --parallel 1\n```\n"
  },
  {
    "path": "benchmark/llm_judge/bench_other.py",
    "content": "import argparse\nimport json\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\nfrom functools import partial\n\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import dump_state_text, read_jsonl\n\nsystem_prompt = \"Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency.\"\n\ndimension_prompts = [\n    \"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.\",\n    \"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.\",\n    \"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.\",\n    \"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.\",\n    \"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.\",\n    \"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. 
The citation should adhere to a specific format as required by the academic institution or specified by the professor.\",\n]\n\n\ndef multi_dimension_judge(article, generate):\n    s = system_prompt\n    s += \"\\n```\\n\" + article + \"\\n```\\n\\n\"\n\n    judges = []\n    for i in range(len(dimension_prompts)):\n        comp = generate(\n            s\n            + \"USER: Please judge the quality based on the following metric. \"\n            + dimension_prompts[i]\n            + \" Please provide a single-paragraph judgement. \"\n            + \"Focus on the provided metric and do not say other things. \"\n            'End your judgement paragraph with the word \"END\"\\nJUDGE:',\n            max_tokens=256,\n            stop=\"END\",\n        )\n        judges.append(comp)\n\n    s += \"I will judge the quality based on the following metrics.\\n\"\n    for i in range(len(dimension_prompts)):\n        s += dimension_prompts[i].split(\":\")[0] + \": \" + judges[i].strip() + \"\\n\"\n\n    s += \"In summary, on a scale of 1 to 10, I would give the article a score of\"\n    s += generate(s, max_tokens=2, stop=None)\n\n    return s\n\n\nasync def multi_dimension_judge_async(article, generate):\n    s = system_prompt\n    s += \"\\n```\\n\" + article + \"\\n```\\n\\n\"\n\n    judges = []\n    for i in range(len(dimension_prompts)):\n        comp = await generate(\n            s\n            + \"USER: Please judge the quality based on the following metric. \"\n            + dimension_prompts[i]\n            + \" Please provide a single-paragraph judgement. \"\n            + \"Focus on the provided metric and do not say other things. 
\"\n            'End your judgement paragraph with the word \"END\"\\nJUDGE:',\n            max_tokens=256,\n            stop=\"END\",\n        )\n        judges.append(comp)\n\n    s += \"I will judge the quality based on the following metrics.\\n\"\n    for i in range(len(dimension_prompts)):\n        s += dimension_prompts[i].split(\":\")[0] + \": \" + judges[i].strip() + \"\\n\"\n\n    s += \"In summary, on a scale of 1 to 10, I would give the article a score of\"\n    s += await generate(s, max_tokens=2, stop=None)\n\n    return s\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)[: args.num_questions]\n    states = [None] * len(lines)\n\n    # Select backend\n    call_generate = partial(get_call_generate(args), temperature=0)\n\n    # Run requests\n    tic = time.perf_counter()\n\n    if args.backend != \"lmql\":\n\n        def get_one_answer(i):\n            states[i] = multi_dimension_judge(lines[i], call_generate)\n\n        if args.parallel == 1:\n            for i in tqdm(range(len(lines))):\n                get_one_answer(i)\n        else:\n            with ThreadPoolExecutor(args.parallel) as executor:\n                list(\n                    tqdm(\n                        executor.map(get_one_answer, list(range(len(lines)))),\n                        total=len(lines),\n                    )\n                )\n\n    else:\n        import asyncio\n\n        async def get_one_answer_async(i):\n            states[i] = await multi_dimension_judge_async(lines[i], call_generate)\n\n        batches = []\n        for i in range(0, len(lines), args.parallel):\n            batches.append(list(range(i, min(i + args.parallel, len(lines)))))\n\n        loop = asyncio.get_event_loop()\n        for bt in tqdm(batches):\n            loop.run_until_complete(\n                asyncio.gather(*[get_one_answer_async(i) for i in bt])\n            )\n\n    latency = time.perf_counter() - tic\n\n    # Compute accuracy\n    print(f\"Latency: {latency:.3f}\")\n\n  
  # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"llm_judge\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"articles.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=20)\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/llm_judge/bench_sglang.py",
    "content": "import argparse\nimport json\nimport time\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text, read_jsonl\n\nsystem_prompt = \"Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency.\"\n\ndimension_prompts = [\n    \"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.\",\n    \"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.\",\n    \"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.\",\n    \"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.\",\n    \"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.\",\n    \"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. 
The citation should adhere to a specific format as required by the academic institution or specified by the professor.\",\n]\n\n\n@sgl.function\ndef multi_dimension_judge(s, article):\n    s += system_prompt\n    s += \"\\n```\\n\" + article + \"\\n```\\n\\n\"\n\n    forks = s.fork(len(dimension_prompts))\n    for i in range(len(dimension_prompts)):\n        forks[i] += (\n            \"USER: Please judge the quality based on the following metric. \"\n            + dimension_prompts[i]\n            + \" Please provide a single-paragraph judgement. \"\n            + \"Focus on the provided metric and do not say other things. \"\n            'End your judgement paragraph with the word \"END\"\\nJUDGE:'\n        )\n        forks[i] += sgl.gen(\"judgement\", max_tokens=256, stop=\"END\")\n    forks.join()\n\n    s += \"I will judge the quality based on the following metrics.\\n\"\n    for i in range(len(dimension_prompts)):\n        s += (\n            dimension_prompts[i].split(\":\")[0]\n            + \": \"\n            + forks[i][\"judgement\"].strip()\n            + \"\\n\"\n        )\n\n    s += \"In summary, on a scale of 1 to 10, I would give the article a score of\"\n    s += sgl.gen(\"score\", max_tokens=2)\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)[: args.num_questions]\n    arguments = [{\"article\": l} for l in lines]\n\n    # Select backend\n    backend = select_sglang_backend(args)\n\n    # Run requests\n    tic = time.perf_counter()\n    states = multi_dimension_judge.run_batch(\n        arguments,\n        temperature=0,\n        backend=backend,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    latency = time.perf_counter() - tic\n\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"llm_judge\",\n            \"backend\": args.backend,\n     
       \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"articles.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=20)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/long_json_decode/README.md",
    "content": "## Run benchmark\n\n### Benchmark sglang\n```\npython3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000\n```\n\n```\npython3 bench_sglang.py --num-questions 5 --parallel 1\n```\n\n\n### Benchmark vllm\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf  --disable-log-requests --port 21000 --gpu 0.97\n```\n\n```\npython3 bench_other.py --backend vllm --num-questions 5\n```\n\n\n### Benchmark guidance\n```\npython3 bench_other.py --backend guidance --num-questions 5 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf\n```\n\n\n### Build dataset\n```\npip install wikipedia\npython3 build_dataset.py\n```\n"
  },
  {
    "path": "benchmark/long_json_decode/bench_other.py",
    "content": "import argparse\nimport json\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\nfrom functools import partial\n\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import dump_state_text, read_jsonl\n\n\ndef json_decode(document, generate):\n    s = \"Please extract the information of a city from the following wikipedia page.\\n\"\n    s += \"Page begin.\\n\" + document + \"Page end.\\n\"\n    s += \"Here is the name, country, and symbol of the city in JSON format.\\n\"\n    s += \"{\\n\"\n    s += '  \"name\": \"'\n    s += generate(s, max_tokens=8, stop='\"') + '\",\\n'\n    s += '  \"country\": \"'\n    s += generate(s, max_tokens=8, stop='\"') + '\",\\n'\n    s += '  \"air port code\": \"'\n    s += generate(s, max_tokens=8, stop='\"') + '\",\\n'\n    s += '  \"top 3 landmarks\": \"'\n    s += generate(s, max_tokens=24, stop='\"') + '\",\\n'\n    s += \"}\\n\"\n    return s\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)\n    arguments = []\n    for i in range(len(lines[: args.num_questions])):\n        arguments.append(\n            {\n                \"document\": lines[i][\"document\"],\n            }\n        )\n    states = [None] * len(arguments)\n\n    # Select backend\n    call_generate = partial(get_call_generate(args), temperature=0)\n\n    # Run requests\n    def get_one_answer(i):\n        states[i] = json_decode(generate=call_generate, **arguments[i])\n\n    tic = time.perf_counter()\n    if args.parallel == 1:\n        for i in tqdm(range(len(arguments))):\n            get_one_answer(i)\n    else:\n        with ThreadPoolExecutor(args.parallel) as executor:\n            list(\n                tqdm(\n                    executor.map(get_one_answer, list(range(len(arguments)))),\n                    total=len(arguments),\n                )\n            )\n\n    latency = time.perf_counter() - tic\n\n    # Compute accuracy\n    
print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"long_json_decode\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"questions.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=100)\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/long_json_decode/bench_sglang.py",
    "content": "import argparse\nimport json\nimport time\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text, read_jsonl\n\n\n@sgl.function\ndef json_decode(s, document):\n    s += \"Please extract the information of a city from the following wikipedia page.\\n\"\n    s += \"Page begin.\\n\" + document + \"Page end.\\n\"\n    s += \"Here is the name, country, and symbol of the city in JSON format.\\n\"\n    s += \"{\\n\"\n    s += '  \"name\": \"' + sgl.gen(\"name\", max_tokens=8, stop='\"') + '\",\\n'\n    s += '  \"country\": \"' + sgl.gen(\"country\", max_tokens=8, stop='\"') + '\",\\n'\n    s += (\n        '  \"air port code\": \"'\n        + sgl.gen(\"air port code\", max_tokens=8, stop='\"')\n        + '\",\\n'\n    )\n    s += (\n        '  \"top 3 landmarks\": \"'\n        + sgl.gen(\"landmarks\", max_tokens=24, stop='\"')\n        + '\",\\n'\n    )\n    s += \"}\\n\"\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)\n    arguments = []\n    for i in range(len(lines[: args.num_questions])):\n        arguments.append(\n            {\n                \"document\": lines[i][\"document\"],\n            }\n        )\n\n    # Select backend\n    backend = select_sglang_backend(args)\n    sgl.set_default_backend(backend)\n\n    # Run requests\n    tic = time.perf_counter()\n    states = json_decode.run_batch(\n        arguments, temperature=0, num_threads=args.parallel, progress_bar=True\n    )\n    latency = time.perf_counter() - tic\n\n    # Compute accuracy\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"long_json_decode\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            
\"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"questions.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=10)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/long_json_decode/build_dataset.py",
    "content": "import json\n\nimport transformers\nimport wikipedia\n\nname = \"meta-llama/Llama-2-7b-chat-hf\"\nt = transformers.AutoTokenizer.from_pretrained(name)\ncity_names = [\"los angeles\", \"london\", \"tokyo\", \"beijing\", \"singapore\"]\n\n\nfor city_name in city_names:\n    content = str(wikipedia.page(city_name).content)\n    content = content.replace(\"\\n\\n\", \"\\n\")\n\n    tokens = t.encode(content)\n\n    truncate_len = int((10000 / len(tokens)) * len(content))\n    truncate_content = content[:truncate_len]\n    truncate_tokens = t.encode(truncate_content)\n\n    # Count tokens\n    print(\n        f\"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}\"\n    )\n\n    with open(\"questions.jsonl\", \"a\") as fout:\n        fout.write(json.dumps({\"document\": truncate_content}) + \"\\n\")\n"
  },
  {
    "path": "benchmark/lora/launch_server.py",
    "content": "import argparse\nimport os\n\nNUM_LORAS = 4\nLORA_PATH = {\n    \"base\": \"meta-llama/Llama-2-7b-hf\",\n    \"lora\": \"winddude/wizardLM-LlaMA-LoRA-7B\",\n}\n\n\ndef launch_server(args):\n    base_path = LORA_PATH[\"base\"]\n    lora_path = LORA_PATH[\"lora\"]\n\n    if args.base_only:\n        cmd = f\"python3 -m sglang.launch_server --model {base_path} \"\n    else:\n        cmd = f\"python3 -m sglang.launch_server --model {base_path} --lora-paths \"\n        for i in range(NUM_LORAS):\n            lora_name = f\"lora{i}\"\n            cmd += f\"{lora_name}={lora_path} \"\n    cmd += f\"--disable-radix \"\n    cmd += f\"--max-loras-per-batch {args.max_loras_per_batch} \"\n    cmd += f\"--max-running-requests {args.max_running_requests} \"\n    cmd += f\"--lora-backend {args.lora_backend} \"\n    cmd += f\"--tp-size {args.tp_size} \"\n    # Keep a trailing space after each flag so that several enabled flags do not run together\n    if args.disable_custom_all_reduce:\n        cmd += \"--disable-custom-all-reduce \"\n    if args.enable_mscclpp:\n        cmd += \"--enable-mscclpp \"\n    if args.enable_torch_symm_mem:\n        cmd += \"--enable-torch-symm-mem \"\n    print(cmd)\n    os.system(cmd)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--base-only\",\n        action=\"store_true\",\n    )\n    parser.add_argument(\n        \"--max-loras-per-batch\",\n        type=int,\n        default=8,\n    )\n    parser.add_argument(\n        \"--max-running-requests\",\n        type=int,\n        default=8,\n    )\n    parser.add_argument(\n        \"--lora-backend\",\n        type=str,\n        default=\"csgmv\",\n    )\n    parser.add_argument(\n        \"--tp-size\",\n        type=int,\n        default=1,\n        help=\"Tensor parallel size for distributed inference\",\n    )\n    # disable_custom_all_reduce\n    parser.add_argument(\n        \"--disable-custom-all-reduce\",\n        action=\"store_true\",\n        help=\"Disable custom all reduce when device does not support p2p communication\",\n    )\n    parser.add_argument(\n        \"--enable-mscclpp\",\n        action=\"store_true\",\n        help=\"Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.\",\n    )\n    parser.add_argument(\n        \"--enable-torch-symm-mem\",\n        action=\"store_true\",\n        help=\"Enable using torch symm mem for all-reduce kernel and fall back to NCCL.\",\n    )\n    args = parser.parse_args()\n\n    launch_server(args)\n"
  },
  {
    "path": "benchmark/lora/lora_bench.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\nimport argparse\nimport asyncio\nimport json\nimport random\nimport resource\nimport sys\nimport time\nimport traceback\nfrom argparse import ArgumentParser\nfrom datetime import datetime\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport numpy as np\nfrom launch_server import LORA_PATH, NUM_LORAS\nfrom tqdm.asyncio import tqdm\nfrom transformers import PreTrainedTokenizerBase\n\nfrom sglang.bench_serving import (\n    RequestFuncInput,\n    RequestFuncOutput,\n    _create_bench_client_session,\n    calculate_metrics,\n    get_request,\n)\nfrom sglang.benchmark.datasets.random import sample_random_requests\nfrom sglang.benchmark.utils import get_tokenizer, remove_prefix\n\nglobal args\n\n\n# set ignore_eos True by default\nasync def async_request_openai_completions(\n    request_func_input: RequestFuncInput,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    api_url = request_func_input.api_url\n    # assert api_url.endswith(\n    #     \"completions\"\n    # ), \"OpenAI Completions API URL must end with 'completions'.\"\n\n    prompt = request_func_input.prompt\n\n    async with _create_bench_client_session() as session:\n        # payload = {\n        #     \"model\": request_func_input.model,\n        #     \"prompt\": prompt,\n        #     
\"temperature\": 0.0,\n        #     \"best_of\": 1,\n        #     \"max_tokens\": request_func_input.output_len,\n        #     \"stream\": not args.disable_stream,\n        #     \"ignore_eos\": not args.disable_ignore_eos,\n        #     **request_func_input.extra_request_body,\n        # }\n        # headers = {\"Authorization\": f\"Bearer {os.environ.get('OPENAI_API_KEY')}\"}\n        if args.base_only:\n            payload = {\n                \"text\": prompt,\n                \"sampling_params\": {\"max_new_tokens\": request_func_input.output_len},\n            }\n        else:\n            payload = {\n                \"text\": prompt,\n                \"sampling_params\": {\"max_new_tokens\": request_func_input.output_len},\n                \"lora_path\": f\"lora{random.randint(0, NUM_LORAS - 1)}\",\n            }\n        headers = {\"Authorization\": \"\"}\n\n        output = RequestFuncOutput()\n        output.prompt_len = request_func_input.prompt_len\n\n        generated_text = \"\"\n        ttft = 0.0\n        st = time.perf_counter()\n        most_recent_timestamp = st\n        try:\n            async with session.post(\n                url=api_url, json=payload, headers=headers\n            ) as response:\n                if response.status == 200:\n                    async for chunk_bytes in response.content:\n                        chunk_bytes = chunk_bytes.strip()\n                        if not chunk_bytes:\n                            continue\n\n                        chunk = remove_prefix(chunk_bytes.decode(\"utf-8\"), \"data: \")\n                        latency = time.perf_counter() - st\n                        if chunk == \"[DONE]\":\n                            pass\n                        else:\n                            data = json.loads(chunk)\n\n                            # NOTE: Some completion API might have a last\n                            # usage summary response without a token so we\n                            # 
want to check a token was generated\n                            if data[\"text\"]:\n                                # if data[\"choices\"][0][\"text\"]:\n                                timestamp = time.perf_counter()\n                                # First token\n                                if ttft == 0.0:\n                                    ttft = time.perf_counter() - st\n                                    output.ttft = ttft\n\n                                # Decoding phase\n                                else:\n                                    output.itl.append(timestamp - most_recent_timestamp)\n\n                                most_recent_timestamp = timestamp\n                                # generated_text += data[\"choices\"][0][\"text\"]\n                                generated_text += data[\"text\"]\n\n                    output.generated_text = generated_text\n                    output.success = True\n                    output.latency = latency\n                    output.output_len = request_func_input.output_len\n                else:\n                    output.error = response.reason or \"\"\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            output.error = \"\".join(traceback.format_exception(*exc_info))\n\n    if pbar:\n        pbar.update(1)\n    return output\n\n\nASYNC_REQUEST_FUNCS = {\n    \"sglang\": async_request_openai_completions,\n}\n\n\nasync def benchmark(\n    backend: str,\n    api_url: str,\n    model_id: str,\n    tokenizer: PreTrainedTokenizerBase,\n    input_requests: List[Tuple[str, int, int]],\n    request_rate: float,\n    disable_tqdm: bool,\n    extra_request_body: Dict[str, Any],\n):\n    if backend in ASYNC_REQUEST_FUNCS:\n        request_func = ASYNC_REQUEST_FUNCS[backend]\n    else:\n        raise ValueError(f\"Unknown backend: {backend}\")\n\n    print(\"Starting initial single prompt test 
run...\")\n    test_request = input_requests[0]\n    test_input = RequestFuncInput(\n        model=model_id,\n        prompt=test_request.prompt,\n        api_url=api_url,\n        prompt_len=test_request.prompt_len,\n        output_len=test_request.output_len,\n        lora_name=\"dummy\",  # the lora_name argument will not be used\n        image_data=None,\n        extra_request_body=extra_request_body,\n    )\n    test_output = await request_func(request_func_input=test_input)\n    if not test_output.success:\n        raise ValueError(\n            \"Initial test run failed - Please make sure benchmark arguments \"\n            f\"are correctly specified. Error: {test_output.error}\"\n        )\n    else:\n        print(\"Initial test run completed. Starting main benchmark run...\")\n\n    pbar = None if disable_tqdm else tqdm(total=len(input_requests))\n\n    benchmark_start_time = time.perf_counter()\n    tasks: List[asyncio.Task] = []\n    async for request in get_request(input_requests, request_rate):\n        request_func_input = RequestFuncInput(\n            model=model_id,\n            prompt=request.prompt,\n            api_url=api_url,\n            prompt_len=request.prompt_len,\n            output_len=request.output_len,\n            lora_name=\"dummy\",\n            image_data=None,\n            extra_request_body=extra_request_body,\n        )\n        tasks.append(\n            asyncio.create_task(\n                request_func(request_func_input=request_func_input, pbar=pbar)\n            )\n        )\n    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)\n\n    if pbar is not None:\n        pbar.close()\n\n    benchmark_duration = time.perf_counter() - benchmark_start_time\n\n    metrics, output_lens = calculate_metrics(\n        input_requests=input_requests,\n        outputs=outputs,\n        dur_s=benchmark_duration,\n        tokenizer=tokenizer,\n        backend=backend,\n    )\n\n    print(\"\\n{s:{c}^{n}}\".format(s=\" Serving 
Benchmark Result \", n=50, c=\"=\"))\n    print(\"{:<40} {:<10}\".format(\"Backend:\", backend))\n    print(\"{:<40} {:<10}\".format(\"Traffic request rate:\", request_rate))\n    print(\"{:<40} {:<10}\".format(\"Successful requests:\", metrics.completed))\n    print(\"{:<40} {:<10.2f}\".format(\"Benchmark duration (s):\", benchmark_duration))\n    print(\"{:<40} {:<10}\".format(\"Total input tokens:\", metrics.total_input))\n    print(\"{:<40} {:<10}\".format(\"Total generated tokens:\", metrics.total_output))\n    print(\n        \"{:<40} {:<10}\".format(\n            \"Total generated tokens (retokenized):\", metrics.total_output_retokenized\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Request throughput (req/s):\", metrics.request_throughput\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Input token throughput (tok/s):\", metrics.input_throughput\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Output token throughput (tok/s):\", metrics.output_throughput\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\"Total throughput (tok/s):\", metrics.total_throughput)\n    )\n    print(\"{s:{c}^{n}}\".format(s=\"End-to-End Latency\", n=50, c=\"-\"))\n    print(\n        \"{:<40} {:<10.2f}\".format(\"Mean E2E Latency (ms):\", metrics.mean_e2e_latency_ms)\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Median E2E Latency (ms):\", metrics.median_e2e_latency_ms\n        )\n    )\n    print(\"{s:{c}^{n}}\".format(s=\"Time to First Token\", n=50, c=\"-\"))\n    print(\"{:<40} {:<10.2f}\".format(\"Mean TTFT (ms):\", metrics.mean_ttft_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"Median TTFT (ms):\", metrics.median_ttft_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"P99 TTFT (ms):\", metrics.p99_ttft_ms))\n    print(\n        \"{s:{c}^{n}}\".format(s=\"Time per Output Token (excl. 
1st token)\", n=50, c=\"-\")\n    )\n    print(\"{:<40} {:<10.2f}\".format(\"Mean TPOT (ms):\", metrics.mean_tpot_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"Median TPOT (ms):\", metrics.median_tpot_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"P99 TPOT (ms):\", metrics.p99_tpot_ms))\n    print(\"{s:{c}^{n}}\".format(s=\"Inter-token Latency\", n=50, c=\"-\"))\n    print(\"{:<40} {:<10.2f}\".format(\"Mean ITL (ms):\", metrics.mean_itl_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"Median ITL (ms):\", metrics.median_itl_ms))\n    print(\"{:<40} {:<10.2f}\".format(\"P99 ITL (ms):\", metrics.p99_itl_ms))\n    print(\"=\" * 50)\n\n    if (\n        metrics.median_ttft_ms is not None\n        and metrics.mean_itl_ms is not None\n        and metrics.output_throughput is not None\n    ):\n        result = {\n            \"backend\": args.backend,\n            \"request_rate\": request_rate,\n            \"total_input_tokens\": metrics.total_input,\n            \"total_output_tokens\": metrics.total_output,\n            \"total_output_tokens_retokenized\": metrics.total_output_retokenized,\n            \"mean_e2e_latency_ms\": metrics.mean_e2e_latency_ms,\n            \"median_e2e_latency_ms\": metrics.median_e2e_latency_ms,\n            \"median_ttft_ms\": metrics.median_ttft_ms,\n            \"median_itl_ms\": metrics.median_itl_ms,\n            \"output_throughput\": metrics.output_throughput,\n            \"random_input_len\": args.random_input_len,\n            \"random_output_len\": args.random_output_len,\n            \"random_range_ratio\": args.random_range_ratio,\n            \"duration\": benchmark_duration,\n            \"completed\": metrics.completed,\n        }\n    else:\n        print(f\"Error running benchmark for request rate: {request_rate}\")\n        print(\"-\" * 30)\n\n    # Determine output file name\n    if args.output_file:\n        output_file_name = args.output_file\n    else:\n        now = datetime.now().strftime(\"%m%d\")\n        
output_file_name = f\"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl\"\n\n    # Append results to a JSONL file\n    with open(output_file_name, \"a\") as file:\n        file.write(json.dumps(result) + \"\\n\")\n\n    result = {\n        \"duration\": benchmark_duration,\n        \"completed\": metrics.completed,\n        \"total_input_tokens\": metrics.total_input,\n        \"total_output_tokens\": metrics.total_output,\n        \"total_output_tokens_retokenized\": metrics.total_output_retokenized,\n        \"request_throughput\": metrics.request_throughput,\n        \"input_throughput\": metrics.input_throughput,\n        \"output_throughput\": metrics.output_throughput,\n        \"mean_ttft_ms\": metrics.mean_ttft_ms,\n        \"median_ttft_ms\": metrics.median_ttft_ms,\n        \"std_ttft_ms\": metrics.std_ttft_ms,\n        \"p99_ttft_ms\": metrics.p99_ttft_ms,\n        \"mean_tpot_ms\": metrics.mean_tpot_ms,\n        \"median_tpot_ms\": metrics.median_tpot_ms,\n        \"std_tpot_ms\": metrics.std_tpot_ms,\n        \"p99_tpot_ms\": metrics.p99_tpot_ms,\n        \"mean_itl_ms\": metrics.mean_itl_ms,\n        \"median_itl_ms\": metrics.median_itl_ms,\n        \"std_itl_ms\": metrics.std_itl_ms,\n        \"p99_itl_ms\": metrics.p99_itl_ms,\n        \"input_lens\": [output.prompt_len for output in outputs],\n        \"output_lens\": output_lens,\n        \"ttfts\": [output.ttft for output in outputs],\n        \"itls\": [output.itl for output in outputs],\n        \"generated_texts\": [output.generated_text for output in outputs],\n        \"errors\": [output.error for output in outputs],\n        \"mean_e2e_latency_ms\": metrics.mean_e2e_latency_ms,\n        \"median_e2e_latency_ms\": metrics.median_e2e_latency_ms,\n    }\n    return result\n\n\ndef run_benchmark(args_: argparse.Namespace):\n    global args\n    args = args_\n\n    # Set global environments\n    set_ulimit()\n    random.seed(args.seed)\n    
np.random.seed(args.seed)\n\n    # Set url\n    if args.port is None:\n        args.port = {\n            \"sglang\": 30000,\n        }.get(args.backend, 30000)\n\n    # api_url = (\n    #     f\"{args.base_url}/v1/completions\"\n    #     if args.base_url\n    #     else f\"http://{args.host}:{args.port}/v1/completions\"\n    # )\n    api_url = (\n        f\"{args.base_url}/generate\"\n        if args.base_url\n        else f\"http://{args.host}:{args.port}/generate\"\n    )\n\n    print(f\"{args}\\n\")\n\n    # Read dataset\n    backend = args.backend\n    model_id = args.model = LORA_PATH[\"base\"]\n    tokenizer_id = args.model\n\n    tokenizer = get_tokenizer(tokenizer_id)\n\n    input_requests = sample_random_requests(\n        input_len=args.random_input_len,\n        output_len=args.random_output_len,\n        num_prompts=args.num_prompts,\n        range_ratio=args.random_range_ratio,\n        tokenizer=tokenizer,\n        dataset_path=\"\",\n    )\n\n    return asyncio.run(\n        benchmark(\n            backend=backend,\n            api_url=api_url,\n            model_id=model_id,\n            tokenizer=tokenizer,\n            input_requests=input_requests,\n            request_rate=args.request_rate,\n            disable_tqdm=False,\n            extra_request_body={},\n        )\n    )\n\n\ndef set_ulimit(target_soft_limit=65535):\n    resource_type = resource.RLIMIT_NOFILE\n    current_soft, current_hard = resource.getrlimit(resource_type)\n\n    if current_soft < target_soft_limit:\n        try:\n            resource.setrlimit(resource_type, (target_soft_limit, current_hard))\n        except ValueError as e:\n            print(f\"Failed to set RLIMIT_NOFILE: {e}\")\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser(description=\"Benchmark the online lora serving throughput.\")\n    parser.add_argument(\n        \"--backend\",\n        type=str,\n        choices=list(ASYNC_REQUEST_FUNCS.keys()),\n        default=\"sglang\",\n        help=\"Backend (LLM inference engine) to benchmark.\",\n    )\n    parser.add_argument(\n        \"--base-url\",\n        type=str,\n        default=None,\n        help=\"Server or API base url if not using http host and port.\",\n    )\n    parser.add_argument(\n        \"--host\", type=str, default=\"0.0.0.0\", help=\"Default host is 0.0.0.0.\"\n    )\n    parser.add_argument(\n        \"--port\",\n        type=int,\n        help=\"If not set, the default port of the selected backend is used.\",\n    )\n    parser.add_argument(\n        \"--num-prompts\",\n        type=int,\n        default=50,\n        help=\"Number of prompts to process. Default is 50.\",\n    )\n    parser.add_argument(\n        \"--random-input-len\",\n        type=int,\n        default=1024,\n        help=\"Number of input tokens per request, used only for random dataset.\",\n    )\n    parser.add_argument(\n        \"--random-output-len\",\n        type=int,\n        default=128,\n        help=\"Number of output tokens per request, used only for random dataset.\",\n    )\n    parser.add_argument(\n        \"--random-range-ratio\",\n        type=float,\n        default=0.0,\n        help=\"Range of sampled ratio of input/output length, \"\n        \"used only for random dataset.\",\n    )\n    parser.add_argument(\n        \"--request-rate\",\n        type=float,\n        default=float(\"inf\"),\n        help=\"Number of requests per second. If this is inf, then all the requests are sent at time 0. \"\n        \"Otherwise, we use a Poisson process to synthesize the request arrival times. Default is inf.\",\n    )\n    parser.add_argument(\n        \"--base-only\",\n        action=\"store_true\",\n        help=\"Send requests to the base model only, without selecting a LoRA adapter.\",\n    )\n    parser.add_argument(\"--output-file\", type=str, help=\"Output JSONL file name.\")\n    parser.add_argument(\"--seed\", type=int, default=1, help=\"The random seed.\")\n    args = parser.parse_args()\n    run_benchmark(args)\n"
  },
  {
    "path": "benchmark/mmlu/README.md",
    "content": "## Download data\n```\nbash download_data.sh\n```\n\n## Run benchmark\n\n### Benchmark sglang\n```\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\n```\npython3 bench_sglang.py --nsub 10\n```\n\n```\n# OpenAI models\npython3 bench_sglang.py --backend gpt-3.5-turbo --parallel 8\n```\n\n### Benchmark vllm\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000\n```\n\n```\npython3 bench_other.py --nsub 10 --backend vllm\n```\n\n\n### Benchmark lightllm\n```\n# A10G\npython -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000\n\n# V100\npython -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 4500 --port 22000\n```\n\n```\npython3 bench_other.py --nsub 10 --backend lightllm\n```\n\n\n### Benchmark guidance\n```\npython3 bench_other.py --nsub 10 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n\n\n### Benchmark lmql\n```\nCUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000\n```\n\n```\npython3 bench_other.py --nsub 10 --backend lmql --parallel 2\n```\n"
  },
  {
    "path": "benchmark/mmlu/bench_other.py",
    "content": "import argparse\nimport asyncio\nimport json\nimport os\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\n\nimport numpy as np\nimport pandas as pd\nimport tiktoken\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\n\nchoices = [\"A\", \"B\", \"C\", \"D\"]\n\ntokenizer = tiktoken.encoding_for_model(\"gpt-3.5-turbo\")\n\n\ndef format_subject(subject):\n    l = subject.split(\"_\")\n    s = \"\"\n    for entry in l:\n        s += \" \" + entry\n    return s\n\n\ndef format_example(df, idx, include_answer=True):\n    prompt = df.iloc[idx, 0]\n    k = df.shape[1] - 2\n    for j in range(k):\n        prompt += \"\\n{}. {}\".format(choices[j], df.iloc[idx, j + 1])\n    prompt += \"\\nAnswer:\"\n    if include_answer:\n        prompt += \" {}\\n\\n\".format(df.iloc[idx, k + 1])\n    return prompt\n\n\ndef gen_prompt(train_df, subject, k=-1):\n    prompt = \"The following are multiple choice questions (with answers) about{}.\\n\\n\".format(\n        format_subject(subject)\n    )\n    if k == -1:\n        k = train_df.shape[0]\n    for i in range(k):\n        prompt += format_example(train_df, i)\n    return prompt\n\n\ndef evaluate(args, subject, dev_df, test_df, call_generate):\n    prompts = []\n    labels = []\n\n    # Construct prompts\n    k = args.ntrain\n    train_prompt = gen_prompt(dev_df, subject, k)\n    while len(tokenizer.encode(train_prompt)) > 1536:\n        k -= 1\n        train_prompt = gen_prompt(dev_df, subject, k)\n\n    for i in range(test_df.shape[0]):\n        prompt_end = format_example(test_df, i, include_answer=False)\n        prompt = train_prompt + prompt_end\n        prompts.append(prompt)\n\n        label = test_df.iloc[i, test_df.shape[1] - 1]\n        labels.append(label)\n\n    preds = [None] * len(prompts)\n    max_tokens = 1\n\n    # Run requests\n    if args.backend != \"lmql\":\n        # Use thread pool\n        def get_one_answer(i):\n        
    pred = call_generate(prompts[i], temperature=0, max_tokens=max_tokens)\n            preds[i] = pred.strip()[0]\n\n        tic = time.perf_counter()\n        if args.parallel == 1:\n            for i in range(len(prompts)):\n                get_one_answer(i)\n        else:\n            with ThreadPoolExecutor(args.parallel) as executor:\n                executor.map(get_one_answer, list(range(len(prompts))))\n    else:\n        # Use asyncio\n        async def batched_call(batch_size):\n            for i in range(0, len(prompts), batch_size):\n                tasks = []\n                for p in prompts[i : i + batch_size]:\n                    tasks.append(call_generate(p, temperature=0, max_tokens=max_tokens))\n                rets = await asyncio.gather(*tasks)\n                for j in range(len(rets)):\n                    preds[i + j] = rets[j].strip()[0]\n\n        tic = time.perf_counter()\n        asyncio.run(batched_call(batch_size=args.parallel))\n    latency = time.perf_counter() - tic\n\n    # Compute accuracy\n    cors = [pred == label for pred, label in zip(preds, labels)]\n    acc = np.mean(cors)\n    cors = np.array(cors)\n\n    print(\n        \"Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}\".format(\n            acc, latency, len(prompts), subject\n        )\n    )\n\n    return cors, acc, latency\n\n\ndef main(args):\n    subjects = sorted(\n        [\n            f.split(\"_test.csv\")[0]\n            for f in os.listdir(os.path.join(args.data_dir, \"test\"))\n            if \"_test.csv\" in f\n        ]\n    )\n\n    all_cors = []\n    all_latencies = []\n    num_requests = 0\n\n    # Select backend\n    call_generate = get_call_generate(args)\n\n    for subject in tqdm(subjects[: args.nsub]):\n        dev_df = pd.read_csv(\n            os.path.join(args.data_dir, \"dev\", subject + \"_dev.csv\"), header=None\n        )[: args.ntrain]\n        test_df = pd.read_csv(\n            os.path.join(args.data_dir, \"test\", subject + 
\"_test.csv\"), header=None\n        )\n\n        cors, acc, latency = evaluate(args, subject, dev_df, test_df, call_generate)\n        all_cors.append(cors)\n        all_latencies.append(latency)\n        num_requests += len(test_df)\n\n    total_latency = np.sum(all_latencies)\n    print(\"Total latency: {:.3f}\".format(total_latency))\n\n    weighted_acc = np.mean(np.concatenate(all_cors))\n    print(\"Average accuracy: {:.3f}\".format(weighted_acc))\n\n    # Write results\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"mmlu\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(total_latency, 3),\n            \"accuracy\": round(weighted_acc, 3),\n            \"num_requests\": num_requests,\n            \"other\": {\n                \"nsub\": args.nsub,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--ntrain\", type=int, default=5)\n    parser.add_argument(\"--data_dir\", type=str, default=\"data\")\n    parser.add_argument(\"--nsub\", type=int, default=60)\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/mmlu/bench_sglang.py",
    "content": "import argparse\nimport json\nimport os\nimport subprocess\nimport tarfile\nimport time\n\nimport numpy as np\nimport pandas as pd\nimport tiktoken\n\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    dump_bench_raw_result,\n    select_sglang_backend,\n)\n\nSCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))\n\nchoices = [\"A\", \"B\", \"C\", \"D\"]\n\ntokenizer = tiktoken.encoding_for_model(\"gpt-3.5-turbo\")\n\n\ndef format_subject(subject):\n    l = subject.split(\"_\")\n    s = \"\"\n    for entry in l:\n        s += \" \" + entry\n    return s\n\n\ndef format_example(df, idx, include_answer=True):\n    prompt = df.iloc[idx, 0]\n    k = df.shape[1] - 2\n    for j in range(k):\n        prompt += \"\\n{}. {}\".format(choices[j], df.iloc[idx, j + 1])\n    prompt += \"\\nAnswer:\"\n    if include_answer:\n        prompt += \" {}\\n\\n\".format(df.iloc[idx, k + 1])\n    return prompt\n\n\ndef gen_prompt(train_df, subject, k=-1):\n    prompt = \"The following are multiple choice questions (with answers) about{}.\\n\\n\".format(\n        format_subject(subject)\n    )\n    if k == -1:\n        k = train_df.shape[0]\n    for i in range(k):\n        prompt += format_example(train_df, i)\n    return prompt\n\n\ndef download_data(data_dir):\n    \"\"\"Download and extract MMLU data if it doesn't exist.\"\"\"\n    if os.path.isdir(os.path.join(data_dir, \"test\")):\n        return\n    print(f\"Data not found at {data_dir}. 
Downloading...\")\n    os.makedirs(data_dir, exist_ok=True)\n    tar_path = os.path.join(data_dir, \"data.tar\")\n    subprocess.check_call(\n        [\"wget\", \"-O\", tar_path, \"https://people.eecs.berkeley.edu/~hendrycks/data.tar\"]\n    )\n    with tarfile.open(tar_path) as tar:\n        tar.extractall(path=data_dir, filter=\"data\")\n    # The tarball extracts into a \"data/\" subdirectory; move contents up if needed\n    nested = os.path.join(data_dir, \"data\")\n    if os.path.isdir(nested):\n        for item in os.listdir(nested):\n            os.rename(os.path.join(nested, item), os.path.join(data_dir, item))\n        os.rmdir(nested)\n    os.remove(tar_path)\n    print(\"Download complete.\")\n\n\ndef main(args):\n    subjects = sorted(\n        [\n            f.split(\"_test.csv\")[0]\n            for f in os.listdir(os.path.join(args.data_dir, \"test\"))\n            if \"_test.csv\" in f\n        ]\n    )\n\n    # Build prompts\n    arguments = []\n    labels = []\n    num_questions = []\n\n    for subject in subjects[: args.nsub]:\n        dev_df = pd.read_csv(\n            os.path.join(args.data_dir, \"dev\", subject + \"_dev.csv\"), header=None\n        )[: args.ntrain]\n        test_df = pd.read_csv(\n            os.path.join(args.data_dir, \"test\", subject + \"_test.csv\"), header=None\n        )\n        num_questions.append(test_df.shape[0])\n\n        k = args.ntrain\n        few_shot_examples = gen_prompt(dev_df, subject, k)\n        while len(tokenizer.encode(few_shot_examples)) > 1536:\n            k -= 1\n            few_shot_examples = gen_prompt(dev_df, subject, k)\n\n        for i in range(test_df.shape[0]):\n            prompt_end = format_example(test_df, i, include_answer=False)\n\n            arguments.append(\n                {\n                    \"examples\": few_shot_examples,\n                    \"question\": prompt_end,\n                }\n            )\n\n            label = test_df.iloc[i, test_df.shape[1] - 1]\n          
  labels.append(label)\n\n    #####################################\n    ######### SGL Program Begin #########\n    #####################################\n\n    import sglang as sgl\n\n    if args.backend.startswith(\"gpt-\"):\n\n        @sgl.function\n        def few_shot_mmlu(s, examples, question):\n            s += sgl.user(examples + question)\n            s += sgl.assistant(sgl.gen(\"answer\"))\n\n    else:\n\n        @sgl.function\n        def few_shot_mmlu(s, examples, question):\n            s += examples + question + sgl.gen(\"answer\")\n\n    #####################################\n    ########## SGL Program End ##########\n    #####################################\n\n    # Select backend\n    backend = select_sglang_backend(args)\n\n    # Run\n    tic = time.perf_counter()\n    states = few_shot_mmlu.run_batch(\n        arguments,\n        temperature=0,\n        max_new_tokens=1,\n        backend=backend,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    preds = [\n        s[\"answer\"].strip()[0] if len(s[\"answer\"].strip()) > 0 else \"\" for s in states\n    ]\n    latency = time.perf_counter() - tic\n\n    # Compute accuracy\n    cors = [pred == label for pred, label in zip(preds, labels)]\n\n    pt = 0\n    for subject, num_qs in zip(subjects[: args.nsub], num_questions):\n        print(\n            f\"subject: {subject}, #q:{num_qs}, acc: {np.mean(cors[pt: pt + num_qs]):.3f}\"\n        )\n        pt += num_qs\n    assert pt == len(cors)\n    weighted_acc = np.mean(cors)\n\n    dump_bench_raw_result(\n        path=args.raw_result_file,\n        states=states,\n        preds=preds,\n        labels=labels,\n    )\n\n    # Print results\n    print(\"Total latency: {:.3f}\".format(latency))\n    print(\"Average accuracy: {:.3f}\".format(weighted_acc))\n\n    # Write results\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"mmlu\",\n            \"backend\": args.backend,\n            
\"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"accuracy\": round(weighted_acc, 3),\n            \"num_requests\": len(arguments),\n            \"other\": {\n                \"nsub\": args.nsub,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--ntrain\", \"-k\", type=int, default=5)\n    parser.add_argument(\n        \"--data_dir\", \"-d\", type=str, default=os.path.join(SCRIPT_DIR, \"data\")\n    )\n    parser.add_argument(\"--save_dir\", \"-s\", type=str, default=\"results\")\n    parser.add_argument(\"--nsub\", type=int, default=60)\n    args = add_common_sglang_args_and_parse(parser)\n    download_data(args.data_dir)\n    main(args)\n"
  },
  {
    "path": "benchmark/mmlu/download_data.sh",
    "content": "wget https://people.eecs.berkeley.edu/~hendrycks/data.tar\ntar xf data.tar\n"
  },
  {
    "path": "benchmark/mmmu/README.md",
    "content": "## Run evaluation\n\n### Evaluate sglang\n\nHost the VLM:\n\n```\npython -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --port 30000\n```\n\nIt's recommended to reduce memory usage by appending an option such as `--mem-fraction-static 0.6` to the command above.\n\nBenchmark:\n\n```\npython benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16\n```\n\nYou can adjust `--concurrency` to control the number of concurrent OpenAI API calls.\n\nYou can use `--lora-path` to specify the LoRA adapter to apply during benchmarking. E.g.,\n```\n# Launch server with LoRA enabled\npython -m sglang.launch_server --model-path microsoft/Phi-4-multimodal-instruct --port 30000 --trust-remote-code --disable-radix-cache --lora-paths vision=<LoRA path>\n\n# Apply the LoRA adapter during inference\npython benchmark/mmmu/bench_sglang.py --concurrency 8 --lora-path vision\n```\n\nYou can use `--response-answer-regex` to specify how to extract the answer from the response string. E.g.,\n```\npython3 -m sglang.launch_server --model-path zai-org/GLM-4.1V-9B-Thinking --reasoning-parser glm45\n\npython3 bench_sglang.py --response-answer-regex \"<\\|begin_of_box\\|>(.*)<\\|end_of_box\\|>\" --concurrency 64\n```\n\nYou can use `--extra-request-body` to specify additional OpenAI request parameters. E.g.,\n```\npython3 bench_sglang.py --extra-request-body '{\"max_new_tokens\": 128, \"temperature\": 0.01}'\n```\n\n### Evaluate HF\n\n```\npython benchmark/mmmu/bench_hf.py --model-path Qwen/Qwen2-VL-7B-Instruct\n```\n\n## Profiling MMMU\n\nIf you run this benchmark with the `--profile` option, follow the standard instructions in the [dedicated profiling doc](../../docs/developer_guide/benchmark_and_profiling.md). We recommend using `--concurrency 1` for consistency, which makes profiling and debugging easier.\n"
  },
  {
    "path": "benchmark/mmmu/bench_hf.py",
    "content": "import argparse\n\nimport PIL\nimport torch\nfrom data_utils import save_json\nfrom eval_utils import (\n    EvalArgs,\n    eval_result,\n    get_sampling_params,\n    prepare_samples,\n    process_result,\n)\nfrom tqdm import tqdm\nfrom transformers import AutoModel, AutoProcessor, GenerationConfig\n\n\n@torch.no_grad()\ndef eval_mmmu(args):\n    eval_args = EvalArgs.from_cli_args(args)\n\n    sampling_params = get_sampling_params(eval_args)\n    generation_config = GenerationConfig(\n        max_new_tokens=sampling_params[\"max_new_tokens\"],\n        do_sample=False,\n    )\n\n    try:\n        from transformers import AutoModelForImageTextToText\n\n        model = AutoModelForImageTextToText.from_pretrained(\n            args.model_path,\n            torch_dtype=\"auto\",\n            trust_remote_code=True,\n        )\n    except Exception as first_exception:\n        try:\n            # check if the model is belongs to internvl\n            if \"InternVL\" in args.model_path:\n                from transformers import AutoTokenizer\n\n                from sglang.srt.multimodal.internvl_utils import image_to_pixel_values\n\n                tokenizer = AutoTokenizer.from_pretrained(args.model_path)\n                model = AutoModel.from_pretrained(\n                    args.model_path,\n                    torch_dtype=\"auto\",\n                    trust_remote_code=True,\n                )\n                generation_config_internvl = dict(\n                    max_new_tokens=sampling_params[\"max_new_tokens\"], do_sample=False\n                )\n\n            else:\n                model = AutoModel.from_pretrained(\n                    args.model_path,\n                    torch_dtype=\"auto\",\n                    trust_remote_code=True,\n                    init_tts=False,\n                )\n        except Exception as second_exception:\n            raise RuntimeError(\n                f\"Failed to load model: First attempt failed with 
{first_exception}, \"\n                f\"second attempt failed with {second_exception}\"\n            ) from second_exception\n\n    model = model.eval().cuda()\n\n    processor = AutoProcessor.from_pretrained(\n        args.model_path, torch_dtype=\"auto\", device_map=\"auto\", trust_remote_code=True\n    )\n\n    samples = prepare_samples(eval_args)\n    out_samples = dict()\n\n    answer_dict = {}\n    for sample in tqdm(samples):\n        prompt = sample[\"final_input_prompt\"]\n        image = sample[\"image\"]\n        prefix = prompt.split(\"<\")[0]\n        suffix = prompt.split(\">\")[1]\n        assert image is not None\n\n        if \"InternVL\" in args.model_path:\n            image = PIL.Image.open(sample[\"image_path\"]).convert(\"RGB\")\n            pixel_values = image_to_pixel_values(\n                image, input_size=448, max_num=12, use_thumbnail=True\n            )\n            pixel_values = pixel_values.to(device=\"cuda\", dtype=torch.bfloat16)\n            contents = \"\"\n            if prefix:\n                contents += prefix\n            contents += \"<image>\\n\"\n            if suffix:\n                contents += suffix\n            response = model.chat(\n                tokenizer, pixel_values, contents, generation_config_internvl\n            )\n            print(f\"response: {response}\")\n            process_result(response, sample, answer_dict, out_samples)\n            continue\n\n        contents = []\n        if prefix:\n            contents += [{\"type\": \"text\", \"text\": prefix}]\n        contents += [\n            {\n                \"type\": \"image\",\n                \"image\": sample[\"image_path\"],\n            }\n        ]\n        if suffix:\n            contents += [{\"type\": \"text\", \"text\": suffix}]\n        messages = [{\"role\": \"user\", \"content\": contents}]\n        try:\n            model_inputs = processor.tokenizer.apply_chat_template(\n                messages,\n                
tokenize=True,\n                return_dict=True,\n                add_generation_prompt=True,\n                return_tensors=\"pt\",\n            ).to(model.device)\n            input_len = model_inputs[\"input_ids\"].shape[-1]\n            generation = model.generate(\n                **model_inputs, generation_config=generation_config\n            )\n            generation = generation[0][input_len:]\n            response = processor.decode(generation, skip_special_tokens=True)\n        except Exception:\n            # Fall back to the model's own chat() interface if the chat-template path fails\n            contents = []\n            if prefix:\n                contents += [prefix]\n            image = PIL.Image.open(sample[\"image_path\"])\n            contents += [image]\n            if suffix:\n                contents += [suffix]\n            messages = [{\"role\": \"user\", \"content\": contents}]\n            response = model.chat(\n                msgs=messages,\n                tokenizer=processor.tokenizer,\n                sampling=False,\n                max_new_tokens=sampling_params[\"max_new_tokens\"],\n                use_tts_template=False,\n                generate_audio=False,\n                temperature=0.0,\n            )\n        print(f\"response: {response}\")\n        process_result(response, sample, answer_dict, out_samples)\n\n    args.output_path = f\"{args.model_path}_answer_hf.json\"\n    save_json(args.output_path, out_samples)\n    eval_result(\n        model_answer_path=args.output_path,\n        answer_dict=answer_dict,\n        eval_output_path=f\"{args.model_path}_val_hf.json\",\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--model-path\",\n        type=str,\n        help=\"The path of the model weights. 
This can be a local folder or a Hugging Face repo ID.\",\n        required=True,\n    )\n    EvalArgs.add_cli_args(parser)\n    args = parser.parse_args()\n\n    eval_mmmu(args)\n"
  },
  {
    "path": "benchmark/mmmu/bench_sglang.py",
    "content": "\"\"\"\nBenchmark the sglang-hosted VLM on MMMU\n\nUsage:\n    Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --port 30000\n\n    Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16\n\nThe eval output will be logged.\n\"\"\"\n\nimport argparse\nimport asyncio\nimport base64\nimport mimetypes\nimport re\nimport sys\nimport time\nimport traceback\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Any, List, Optional, Tuple\n\nimport aiohttp\nimport openai\nfrom data_utils import save_json\nfrom eval_utils import (\n    EvalArgs,\n    eval_result,\n    get_sampling_params,\n    prepare_samples,\n    process_result,\n)\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_sglang_args_and_parse\n\nAIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)\n\n\n@dataclass\nclass RequestFuncOutput:\n    generated_text: List[str] = field(default_factory=list)\n    prompt_len: List[int] = field(default_factory=list)\n    output_len: List[int] = field(default_factory=list)\n    latency: List[float] = field(default_factory=list)\n    ttft: List[float] = field(default_factory=list)\n    itl: List[float] = field(default_factory=list)  # List of inter-token latencies\n\n    success: bool = False\n    error: str = \"\"\n\n\nasync def async_request_profile(api_url: str) -> RequestFuncOutput:\n    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:\n        output = RequestFuncOutput()\n        try:\n            async with session.post(url=api_url) as response:\n                if response.status == 200:\n                    output.success = True\n                else:\n                    output.error = response.reason or \"\"\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            output.error = 
\"\".join(traceback.format_exception(*exc_info))\n\n    return output\n\n\ndef _get_prefix_suffix(prompt: str) -> Tuple[str, str]:\n    \"\"\"Split the prompt into prefix and suffix.\"\"\"\n    prefix = prompt.split(\"<\")[0]\n    suffix = prompt.split(\">\", 1)[1]\n    return prefix, suffix\n\n\nasync def process_sample(\n    client: Any,\n    sample: dict,\n    sampling_params: dict,\n    model: str,\n    reasoning_effort: Optional[str] = None,\n    lora_path: Optional[str] = None,\n) -> Tuple[dict, str]:\n    \"\"\"Send a single sample to the LLM and return (sample, response).\"\"\"\n    prompt = sample[\"final_input_prompt\"]\n    prefix, suffix = _get_prefix_suffix(prompt)\n    image = sample[\"image\"]\n    assert image is not None\n    image_path = sample[\"image_path\"]\n    if image_path and not image_path.startswith((\"http://\", \"https://\", \"data:\")):\n        p = Path(image_path)\n        mime = mimetypes.guess_type(str(p))[0] or \"image/png\"\n        with open(p, \"rb\") as f:\n            b64 = base64.b64encode(f.read()).decode()\n        image_url = f\"data:{mime};base64,{b64}\"\n    else:\n        image_url = image_path\n    extra_body = {\"lora_path\": lora_path} if lora_path else None\n    payload = {\n        \"model\": model,\n        \"messages\": [\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"text\", \"text\": prefix},\n                    {\"type\": \"image_url\", \"image_url\": {\"url\": image_url}},\n                    {\"type\": \"text\", \"text\": suffix},\n                ],\n            }\n        ],\n        \"extra_body\": extra_body,\n        **sampling_params,\n    }\n    if reasoning_effort:\n        payload[\"reasoning_effort\"] = reasoning_effort\n    response = await client.chat.completions.create(**payload)\n    msg = response.choices[0].message\n    content = msg.content\n    if content is None:\n        content = getattr(msg, \"reasoning_content\", 
None)\n    return sample, content\n\n\nasync def process_sample_with_semaphore(\n    semaphore: asyncio.Semaphore,\n    client: Any,\n    sample: dict,\n    sampling_params: dict,\n    model: str,\n    reasoning_effort: Optional[str] = None,\n    lora_path: Optional[str] = None,\n) -> Tuple[dict, str]:\n    \"\"\"Wrap process_sample with a semaphore for concurrency control.\"\"\"\n    async with semaphore:\n        return await process_sample(\n            client, sample, sampling_params, model, reasoning_effort, lora_path\n        )\n\n\nasync def eval_mmmu(args) -> None:\n    \"\"\"Main evaluation loop with concurrency control.\"\"\"\n    eval_args = EvalArgs.from_cli_args(args)\n    sampling_params = get_sampling_params(eval_args)\n    samples = prepare_samples(eval_args)\n    model = args.model\n    reasoning_effort = eval_args.reasoning_effort\n    lora_path = eval_args.lora_path\n    answer_dict = {}\n    out_samples = {}\n    client = openai.AsyncOpenAI(\n        api_key=\"sk\",\n        base_url=f\"http://127.0.0.1:{args.port}/v1\",\n        timeout=20 * 60 * 60,\n    )\n    start = time.perf_counter()\n    base_url = f\"http://127.0.0.1:{args.port}\"\n\n    if args.profile:\n        print(\"Starting profiler...\")\n        profile_output = await async_request_profile(\n            api_url=f\"{base_url}/start_profile\"\n        )\n        if profile_output.success:\n            print(\"Profiler started\")\n\n        samples = samples[: args.profile_number]\n\n    if args.concurrency == 1:\n        # For concurrency == 1, run in sequential mode to ensure consistent order\n        # this is mainly for profiling\n        for sample in tqdm(samples):\n            _, response = await process_sample(\n                client, sample, sampling_params, model, reasoning_effort, lora_path\n            )\n            sample[\"original_response\"] = response\n            answer = (\n                re.search(args.response_answer_regex, response)\n                if 
response is not None\n                else None\n            )\n            process_result(\n                answer.group(1).strip() if answer else response,\n                sample,\n                answer_dict,\n                out_samples,\n            )\n    else:\n        semaphore = asyncio.Semaphore(args.concurrency)\n        tasks = [\n            process_sample_with_semaphore(\n                semaphore,\n                client,\n                sample,\n                sampling_params,\n                model,\n                reasoning_effort,\n                lora_path,\n            )\n            for sample in samples\n        ]\n\n        for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):\n            sample, response = await coro\n            sample[\"original_response\"] = response\n            answer = (\n                re.search(args.response_answer_regex, response)\n                if response is not None\n                else None\n            )\n            process_result(\n                answer.group(1).strip() if answer else response,\n                sample,\n                answer_dict,\n                out_samples,\n            )\n\n    if args.profile:\n        print(\"Stopping profiler...\")\n        profile_output = await async_request_profile(api_url=f\"{base_url}/stop_profile\")\n        if profile_output.success:\n            print(\"Profiler stopped\")\n\n    print(f\"Benchmark time: {time.perf_counter() - start}\")\n    args.output_path = \"./answer_sglang.json\"\n    save_json(args.output_path, out_samples)\n    eval_result(\n        model_answer_path=args.output_path,\n        answer_dict=answer_dict,\n        eval_output_path=\"./val_sglang.json\",\n    )\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=\"default\",\n        help=\"Model name to use in API requests.\",\n    )\n    EvalArgs.add_cli_args(parser)\n    
args = add_common_sglang_args_and_parse(parser)\n    return args\n\n\ndef main():\n    args = parse_args()\n    asyncio.run(eval_mmmu(args))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/mmmu/data_utils.py",
    "content": "\"\"\"Utils for data load, save, and process (e.g., prompt construction)\"\"\"\n\nimport json\nimport os\nimport re\n\nimport yaml\n\nDOMAIN_CAT2SUB_CAT = {\n    \"Art and Design\": [\"Art\", \"Art_Theory\", \"Design\", \"Music\"],\n    \"Business\": [\"Accounting\", \"Economics\", \"Finance\", \"Manage\", \"Marketing\"],\n    \"Science\": [\n        \"Biology\",\n        \"Chemistry\",\n        \"Geography\",\n        \"Math\",\n        \"Physics\",\n    ],\n    \"Health and Medicine\": [\n        \"Basic_Medical_Science\",\n        \"Clinical_Medicine\",\n        \"Diagnostics_and_Laboratory_Medicine\",\n        \"Pharmacy\",\n        \"Public_Health\",\n    ],\n    \"Humanities and Social Science\": [\n        \"History\",\n        \"Literature\",\n        \"Sociology\",\n        \"Psychology\",\n    ],\n    \"Tech and Engineering\": [\n        \"Agriculture\",\n        \"Architecture_and_Engineering\",\n        \"Computer_Science\",\n        \"Electronics\",\n        \"Energy_and_Power\",\n        \"Materials\",\n        \"Mechanical_Engineering\",\n    ],\n}\n\n\nCAT_SHORT2LONG = {\n    \"acc\": \"Accounting\",\n    \"agri\": \"Agriculture\",\n    \"arch\": \"Architecture_and_Engineering\",\n    \"art\": \"Art\",\n    \"art_theory\": \"Art_Theory\",\n    \"bas_med\": \"Basic_Medical_Science\",\n    \"bio\": \"Biology\",\n    \"chem\": \"Chemistry\",\n    \"cli_med\": \"Clinical_Medicine\",\n    \"cs\": \"Computer_Science\",\n    \"design\": \"Design\",\n    \"diag_med\": \"Diagnostics_and_Laboratory_Medicine\",\n    \"econ\": \"Economics\",\n    \"elec\": \"Electronics\",\n    \"ep\": \"Energy_and_Power\",\n    \"fin\": \"Finance\",\n    \"geo\": \"Geography\",\n    \"his\": \"History\",\n    \"liter\": \"Literature\",\n    \"manage\": \"Manage\",\n    \"mark\": \"Marketing\",\n    \"mate\": \"Materials\",\n    \"math\": \"Math\",\n    \"mech\": \"Mechanical_Engineering\",\n    \"music\": \"Music\",\n    \"phar\": \"Pharmacy\",\n    \"phys\": 
\"Physics\",\n    \"psy\": \"Psychology\",\n    \"pub_health\": \"Public_Health\",\n    \"socio\": \"Sociology\",\n}\n\n\ndef get_multi_choice_info(options):\n    \"\"\"\n    Given the list of options for multiple choice question\n    Return the index2ans and all_choices\n    \"\"\"\n\n    start_chr = \"A\"\n    all_choices = []\n    index2ans = {}\n    for i, option in enumerate(options):\n        index2ans[chr(ord(start_chr) + i)] = option\n        all_choices.append(chr(ord(start_chr) + i))\n\n    return index2ans, all_choices\n\n\ndef load_yaml(file_path):\n    with open(file_path, \"r\") as stream:\n        try:\n            yaml_dict = yaml.safe_load(stream)\n        except yaml.YAMLError as exc:\n            print(exc)\n\n    return yaml_dict\n\n\ndef parse_img_path(text):\n    matches = re.findall(\"<img='(.*?)'>\", text)\n    return matches\n\n\ndef process_single_sample(data):\n    question = data[\"question\"]\n    o_imgs_paths = []\n    for option in data[\"options\"]:\n        current_o_imgs_paths = parse_img_path(option)\n        for img_path in current_o_imgs_paths:\n            o_imgs_paths.append(img_path)\n\n    if len(o_imgs_paths) > 1:  # multiple images in options, used for random selection\n        return {\n            \"id\": data[\"id\"],\n            \"question\": question,\n            \"options\": data[\"options\"],\n            \"answer\": data[\"answer\"],\n            \"image\": None,\n            \"question_type\": data[\"question_type\"],\n        }\n    else:\n        return {\n            \"id\": data[\"id\"],\n            \"question\": question,\n            \"options\": data[\"options\"],\n            \"answer\": data[\"answer\"],\n            \"image\": data[\"image_1\"],\n            \"question_type\": data[\"question_type\"],\n        }\n\n\n# DATA SAVING\ndef save_json(filename, ds):\n    print(f\"answers saved to: {filename}\")\n    os.makedirs(os.path.dirname(filename), exist_ok=True)\n    with open(filename, \"w\") as 
f:\n        json.dump(ds, f, indent=4)\n\n\ndef save_jsonl(filename, data):\n    \"\"\"\n    Save a dictionary of data to a JSON Lines file with the filename as key and caption as value.\n\n    Args:\n        filename (str): The path to the file where the data should be saved.\n        data (dict): The dictionary containing the data to save where key is the image path and value is the caption.\n    \"\"\"\n    with open(filename, \"w\", encoding=\"utf-8\") as f:\n        for img_path, caption in data.items():\n            # Extract the base filename without the extension\n            base_filename = os.path.basename(img_path)\n            # Create a JSON object with the filename as the key and caption as the value\n            json_record = json.dumps({base_filename: caption}, ensure_ascii=False)\n            # Write the JSON object to the file, one per line\n            f.write(json_record + \"\\n\")\n\n\ndef save_args(args, path_dir):\n    argsDict = args.__dict__\n    with open(path_dir + \"setting.txt\", \"w\") as f:\n        f.writelines(\"------------------ start ------------------\" + \"\\n\")\n        for eachArg, value in argsDict.items():\n            f.writelines(eachArg + \" : \" + str(value) + \"\\n\")\n        f.writelines(\"------------------- end -------------------\")\n\n\n# DATA PROCESSING\ndef construct_prompt(sample, config):\n    question = sample[\"question\"]\n    options = eval(sample[\"options\"])\n    example = \"\"\n    if sample[\"question_type\"] == \"multiple-choice\":\n        start_chr = \"A\"\n        prediction_range = []\n        index2ans = {}\n        for option in options:\n            prediction_range.append(start_chr)\n            example += f\"({start_chr}) {option}\\n\"\n            index2ans[start_chr] = option\n            start_chr = chr(ord(start_chr) + 1)\n        empty_prompt_sample_structure = config[\"multi_choice_example_format\"]\n        empty_prompt = empty_prompt_sample_structure.format(question, example)\n     
   res_dict = {}\n        res_dict[\"index2ans\"] = index2ans\n        res_dict[\"correct_choice\"] = sample[\"answer\"]\n        res_dict[\"all_choices\"] = prediction_range\n        res_dict[\"empty_prompt\"] = empty_prompt\n        if config[\"task_instructions\"]:\n            res_dict[\"final_input_prompt\"] = (\n                config[\"task_instructions\"].strip() + \"\\n\\n\" + empty_prompt\n            )\n        else:\n            res_dict[\"final_input_prompt\"] = empty_prompt\n\n        res_dict[\"gt_content\"] = options[ord(sample[\"answer\"].upper()) - ord(\"A\")]\n    else:\n        empty_prompt_sample_structure = config[\"short_ans_example_format\"]\n        empty_prompt = empty_prompt_sample_structure.format(question)\n        res_dict = {}\n        res_dict[\"empty_prompt\"] = empty_prompt\n        if config[\"task_instructions\"]:\n            res_dict[\"final_input_prompt\"] = (\n                config[\"task_instructions\"].strip() + \"\\n\\n\" + empty_prompt\n            )\n        else:\n            res_dict[\"final_input_prompt\"] = empty_prompt\n        res_dict[\"gt_content\"] = sample[\"answer\"]\n\n    res_dict.update(sample)\n    return res_dict\n"
  },
  {
    "path": "benchmark/mmmu/eval_utils.py",
    "content": "\"\"\"Response Parsing and Evaluation for various models\"\"\"\n\nimport argparse\nimport dataclasses\nimport json\nimport os\nimport pprint\nimport random\nimport re\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\nfrom typing import Dict, Optional\n\nimport numpy as np\nimport torch\nfrom data_utils import (\n    CAT_SHORT2LONG,\n    DOMAIN_CAT2SUB_CAT,\n    construct_prompt,\n    load_yaml,\n    process_single_sample,\n    save_json,\n)\nfrom datasets import concatenate_datasets, load_dataset\nfrom tqdm import tqdm\n\n\n@dataclasses.dataclass\nclass EvalArgs:\n    seed: int = 42\n    split: str = \"validation\"\n    image_pixels_limit: int = -1\n    result_filename: str = f\"./val_sglang.json\"\n    prompt_format_file: str = \"prompt_format.yaml\"\n    dataset_path: str = \"MMMU/MMMU\"\n    extra_request_body: Optional[str] = None\n    profile: bool = False\n    profile_number: int = 5\n    concurrency: int = 1\n    max_new_tokens: Optional[int] = None\n    temperature: Optional[float] = None\n    response_answer_regex: str = \"(.*)\"\n    lora_path: Optional[str] = None\n    reasoning_effort: Optional[str] = None\n\n    @staticmethod\n    def add_cli_args(parser: argparse.ArgumentParser):\n        parser.add_argument(\n            \"--result-filename\",\n            type=str,\n            default=EvalArgs.result_filename,\n            help=\"The filename to save the evaluation results.\",\n        )\n        parser.add_argument(\n            \"--image-pixels-limit\",\n            type=int,\n            default=EvalArgs.image_pixels_limit,\n            help=\"The maximum number of pixels allowed for an image. 
If an image exceeds this limit, it will be skipped during evaluation.\",\n        )\n        parser.add_argument(\n            \"--dataset-path\",\n            type=str,\n            default=EvalArgs.dataset_path,\n            help=\"path to the dataset\",\n        )\n        parser.add_argument(\"--seed\", type=int, default=1, help=\"The random seed.\")\n        parser.add_argument(\n            \"--prompt-format-file\",\n            type=str,\n            help=\"The path to the prompt format of mmmu. If not set, the default prompt_format.yaml next to this script is used.\",\n        )\n        parser.add_argument(\n            \"--split\",\n            type=str,\n            default=EvalArgs.split,\n            help='Split of the dataset to use for evaluation. Default is \"validation\".',\n        )\n        parser.add_argument(\n            \"--extra-request-body\",\n            metavar='{\"key1\": \"value1\", \"key2\": \"value2\"}',\n            type=str,\n            default=EvalArgs.extra_request_body,\n            help=\"Append given JSON object to the request payload. You can use this to specify \"\n            \"additional generate params like sampling params.\",\n        )\n        parser.add_argument(\n            \"--profile\", action=\"store_true\", help=\"enable mmmu profile\"\n        )\n        parser.add_argument(\n            \"--profile-number\",\n            type=int,\n            default=EvalArgs.profile_number,\n            help=\"Number of samples to profile. If not set, will profile all samples.\",\n        )\n        parser.add_argument(\n            \"--concurrency\",\n            type=int,\n            default=EvalArgs.concurrency,\n            help=\"Number of concurrent requests to make during evaluation. 
Default is 1, which means no concurrency.\",\n        )\n        parser.add_argument(\n            \"--max-new-tokens\",\n            type=int,\n            default=EvalArgs.max_new_tokens,\n            help=\"Maximum number of new tokens to generate per sample.\",\n        )\n        parser.add_argument(\n            \"--temperature\",\n            type=float,\n            default=EvalArgs.temperature,\n            help=\"Sampling temperature for generation.\",\n        )\n        parser.add_argument(\n            \"--response-answer-regex\",\n            type=str,\n            default=EvalArgs.response_answer_regex,\n            help=\"Specific regex to capture the answer from the response, string\",\n        )\n        parser.add_argument(\n            \"--lora-path\",\n            type=str,\n            default=EvalArgs.lora_path,\n            help=\"Specify the LoRA path to use for evaluation. If specified, the value will be specified in the body of every request as `lora-path`.\",\n        )\n        parser.add_argument(\n            \"--reasoning-effort\",\n            type=str,\n            default=EvalArgs.reasoning_effort,\n            choices=[\"none\", \"high\"],\n            help=\"Reasoning effort for the model (none or high).\",\n        )\n\n    @classmethod\n    def from_cli_args(cls, args: argparse.Namespace):\n        attrs = [attr.name for attr in dataclasses.fields(cls)]\n        return cls(**{attr: getattr(args, attr) for attr in attrs})\n\n\ndef set_seed(seed_value):\n    \"\"\"\n    Set the seed for PyTorch (both CPU and CUDA), Python, and NumPy for reproducible results.\n\n    :param seed_value: An integer value to be used as the seed.\n    \"\"\"\n    torch.manual_seed(seed_value)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed(seed_value)\n        torch.cuda.manual_seed_all(seed_value)  # For multi-GPU setups\n    random.seed(seed_value)\n    np.random.seed(seed_value)\n    torch.backends.cudnn.deterministic = True\n    
torch.backends.cudnn.benchmark = False\n\n\ndef prepare_samples(eval_args: EvalArgs):\n    print(\"Preparing samples...\")\n    # Build prompts\n    set_seed(eval_args.seed)\n\n    prompt_format_file = (\n        eval_args.prompt_format_file\n        if eval_args.prompt_format_file is not None\n        else os.path.join(os.path.dirname(__file__), \"prompt_format.yaml\")\n    )\n    # load config and process to one value\n    eval_args.config = load_yaml(prompt_format_file)\n    for key, value in eval_args.config.items():\n        if key != \"eval_params\" and type(value) == list:\n            assert len(value) == 1, \"key {} has more than one value\".format(key)\n            eval_args.config[key] = value[0]\n\n    # run for each subject in parallel\n    sub_dataset_list = []\n    subjects = list(CAT_SHORT2LONG.values())  # Get a fixed list of subjects\n\n    print(f\"Loading datasets for {len(subjects)} subjects...\")\n    with ThreadPoolExecutor() as executor:\n        # Submit all load_dataset tasks\n        future_to_subject = {\n            executor.submit(\n                load_dataset, eval_args.dataset_path, subject, split=eval_args.split\n            ): subject\n            for subject in subjects\n        }\n\n        # Collect results as they complete\n        results = {}\n        for future in tqdm(\n            as_completed(future_to_subject),\n            total=len(subjects),\n            desc=\"Loading datasets\",\n        ):\n            subject = future_to_subject[future]\n            try:\n                results[subject] = future.result()\n            except Exception as exc:\n                print(f\"{subject} generated an exception: {exc}\")\n\n    # Ensure datasets are added in the original order for consistency\n    for subject in subjects:\n        if subject in results:\n            sub_dataset_list.append(results[subject])\n        else:\n            # Handle cases where a dataset failed to load (optional, depends on desired behavior)\n    
        print(f\"Warning: Dataset for subject '{subject}' could not be loaded.\")\n\n    # merge all dataset\n    dataset = concatenate_datasets(sub_dataset_list)\n\n    # Prepare images in parallel\n    images_path = os.path.expanduser(\"~/.cache/mmmu/images\")\n    os.makedirs(images_path, exist_ok=True)\n    print(f\"Saving images to: {images_path}\")\n\n    samples = []\n    skip_count = 0\n\n    def process_sample(i, sample):\n        sample = process_single_sample(sample)\n        sample = construct_prompt(sample, eval_args.config)\n        image = sample[\"image\"]\n        width, height = image.size\n        if 0 < eval_args.image_pixels_limit <= width * height:\n            return None, True\n        # Use a unique identifier for the image path to avoid potential collisions if indices reset\n        image_path = f\"{images_path}/image_{sample['id']}.png\"\n        if not os.path.exists(image_path):\n            image.save(image_path)\n        sample[\"image_path\"] = image_path\n        return sample, False\n\n    print(\"Processing samples...\")\n    with ThreadPoolExecutor() as executor:\n        # Pass the sample itself to process_sample, index is less reliable now\n        futures = [\n            executor.submit(\n                process_sample, i, sample\n            )  # Keep index i for tqdm maybe? Or remove it. 
Let's keep it for now.\n            for i, sample in enumerate(dataset)\n        ]\n        for future in tqdm(\n            as_completed(futures), total=len(dataset), desc=\"Processing samples\"\n        ):\n            sample, skipped = future.result()\n            if skipped:\n                skip_count += 1\n            elif sample:\n                samples.append(sample)\n\n    samples.sort(key=lambda x: x[\"final_input_prompt\"])\n\n    print(\n        f\"Skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset\"\n    )\n    print(\"Samples have been prepared\")\n    return samples\n\n\ndef get_sampling_params(eval_args):\n    extra_request_body = {}\n    if eval_args.extra_request_body:\n        extra_request_body = json.loads(eval_args.extra_request_body)\n    sampling_params = {\n        **extra_request_body,\n    }\n\n    if eval_args.max_new_tokens is not None and eval_args.max_new_tokens > 0:\n        sampling_params.update({\"max_completion_tokens\": eval_args.max_new_tokens})\n\n    if eval_args.temperature is not None:\n        sampling_params.update({\"temperature\": eval_args.temperature})\n\n    return sampling_params\n\n\n# ----------- Process Multi-choice -------------\ndef parse_multi_choice_response(response, all_choices, index2ans):\n    \"\"\"\n    Parse the prediction from the generated response.\n    Return the predicted index e.g., A, B, C, D.\n    \"\"\"\n    for char in [\",\", \".\", \"!\", \"?\", \";\", \":\", \"'\"]:\n        response = response.strip(char)\n    response = \" \" + response + \" \"  # add space to avoid partial match\n\n    index_ans = True\n    ans_with_brack = False\n    candidates = []\n    for choice in all_choices:  # e.g., (A) (B) (C) (D)\n        if f\"({choice})\" in response:\n            candidates.append(choice)\n            ans_with_brack = True\n\n    if len(candidates) == 0:\n        for choice in all_choices:  # e.g., A B C D\n            if f\" 
{choice} \" in response:\n                candidates.append(choice)\n\n    # if no candidates were found above and the response has more than 5 tokens, try to match the option content\n    if len(candidates) == 0 and len(response.split()) > 5:\n        for index, ans in index2ans.items():\n            if ans.lower() in response.lower():\n                candidates.append(index)\n                index_ans = False  # it's content ans.\n\n    if len(candidates) == 0:  # still no answer found; choose one at random.\n        pred_index = random.choice(all_choices)\n    elif len(candidates) > 1:\n        start_indexes = []\n        if index_ans:\n            if ans_with_brack:\n                for can in candidates:\n                    index = response.rfind(f\"({can})\")\n                    start_indexes.append(index)  # -1 will be ignored anyway\n                # start_indexes = [generated_response.index(f'({can})') for can in candidates]\n            else:\n                for can in candidates:\n                    index = response.rfind(f\" {can} \")\n                    start_indexes.append(index)\n        else:\n            for can in candidates:\n                index = response.lower().rfind(index2ans[can].lower())\n                start_indexes.append(index)\n        # get the last one\n        pred_index = candidates[np.argmax(start_indexes)]\n    else:  # if only one candidate, use it.\n        pred_index = candidates[0]\n\n    return pred_index\n\n\n# ----------- Process Open -------------\ndef check_is_number(string):\n    \"\"\"\n    Check if the given string is a number.\n    \"\"\"\n    try:\n        float(string.replace(\",\", \"\"))\n        return True\n    except ValueError:\n        # check if there's comma inside\n        return False\n\n\ndef normalize_str(string):\n    \"\"\"\n    Normalize the string to lower case, or convert it to a float if possible.\n    \"\"\"\n    # check if characters in the string\n\n    # if number, numerize it.\n    
string = string.strip()\n\n    is_number = check_is_number(string)\n\n    if is_number:\n        string = string.replace(\",\", \"\")\n        string = float(string)\n        # round to 2 decimal places\n        string = round(string, 2)\n        return [string]\n    else:  # it's likely to be a string\n        # lower it\n        string = string.lower()\n        if len(string) == 1:\n            return [\" \" + string, string + \" \"]  # avoid trivial matches\n        return [string]\n\n\ndef extract_numbers(string):\n    \"\"\"\n    Extract all forms of numbers from a string with regex.\n    \"\"\"\n    # Pattern for numbers with commas\n    pattern_commas = r\"-?\\b\\d{1,3}(?:,\\d{3})+\\b\"\n    # Pattern for scientific notation\n    pattern_scientific = r\"-?\\d+(?:\\.\\d+)?[eE][+-]?\\d+\"\n    # Pattern for simple numbers without commas\n    pattern_simple = r\"-?(?:\\d+\\.\\d+|\\.\\d+|\\d+\\b)(?![eE][+-]?\\d+)(?![,\\d])\"\n\n    # Extract numbers with commas\n    numbers_with_commas = re.findall(pattern_commas, string)\n    # Extract numbers in scientific notation\n    numbers_scientific = re.findall(pattern_scientific, string)\n    # Extract simple numbers without commas\n    numbers_simple = re.findall(pattern_simple, string)\n\n    # Combine all extracted numbers\n    all_numbers = numbers_with_commas + numbers_scientific + numbers_simple\n    return all_numbers\n\n\ndef parse_open_response(response):\n    \"\"\"\n    Parse the prediction from the generated response.\n    Return a list of predicted strings or numbers.\n    \"\"\"\n\n    # content = content.strip(\"\\n\").strip(\".\").strip(\" \")\n    def get_key_subresponses(response):\n        key_responses = []\n        response = response.strip().strip(\".\").lower()\n        sub_responses = re.split(r\"\\.\\s(?=[A-Z])|\\n\", response)\n        indicators_of_keys = [\n            \"could be \",\n            \"so \",\n            \"is \",\n            \"thus \",\n            \"therefore \",\n            \"final 
\",\n            \"answer \",\n            \"result \",\n        ]\n        key_responses = []\n        for index, resp in enumerate(sub_responses):\n            # if last one, accept it's an equation (the entire response can be just one sentence with equation)\n            if index == len(sub_responses) - 1:\n                indicators_of_keys.extend([\"=\"])\n            shortest_key_response = None  # the shortest response that may contain the answer (tail part of the response)\n            for indicator in indicators_of_keys:\n                if indicator in resp:\n                    if not shortest_key_response:\n                        shortest_key_response = resp.split(indicator)[-1].strip()\n                    else:\n                        if len(resp.split(indicator)[-1].strip()) < len(\n                            shortest_key_response\n                        ):\n                            shortest_key_response = resp.split(indicator)[-1].strip()\n                    # key_responses.append(resp.split(indicator)[1].strip())\n\n            if shortest_key_response:\n                # and it's not trivial\n                if shortest_key_response.strip() not in [\n                    \":\",\n                    \",\",\n                    \".\",\n                    \"!\",\n                    \"?\",\n                    \";\",\n                    \":\",\n                    \"'\",\n                ]:\n                    key_responses.append(shortest_key_response)\n        if len(key_responses) == 0:  # did not found any\n            return [response]\n        return key_responses\n\n    # pdb.set_trace()\n    key_responses = get_key_subresponses(response)\n\n    pred_list = key_responses.copy()  # keep the original string response\n    for resp in key_responses:\n        pred_list.extend(extract_numbers(resp))\n\n    tmp_pred_list = []\n    for i in range(len(pred_list)):\n        tmp_pred_list.extend(normalize_str(pred_list[i]))\n    pred_list = 
tmp_pred_list\n\n    # remove duplicates\n    pred_list = list(set(pred_list))\n\n    return pred_list\n\n\n# ----------- Evaluation -------------\n\n\ndef eval_multi_choice(gold_i, pred_i):\n    \"\"\"\n    Evaluate a multiple choice instance.\n    \"\"\"\n    correct = False\n    # for case like Answer: A, Answer is A, answer is A, answer: A\n    for _exp in [\"Answer:\", \"Answer is \", \"answer is \", \"answer: \"]:\n        if _exp in pred_i:\n            pred_i = pred_i.split(_exp)[1].strip()\n            break\n    # for case like (A), (B), (C), (D) ......\n    if \"(\" in pred_i and \")\" in pred_i:\n        try:\n            pred_i = re.search(r\"\\(([A-Z])\\)\", pred_i).group(1)\n        except:\n            print(f\"Error to extract answer from: {pred_i}\")\n            pass\n    # only they are exactly the same, we consider it as correct\n    if isinstance(gold_i, list):\n        for answer in gold_i:\n            if answer == pred_i:\n                correct = True\n                break\n    else:  # gold_i is a string\n        if gold_i == pred_i:\n            correct = True\n    return correct\n\n\ndef eval_open(gold_i, pred_i):\n    \"\"\"\n    Evaluate an open question instance\n    \"\"\"\n    correct = False\n    if isinstance(gold_i, list):\n        # use float to avoid trivial matches\n        norm_answers = []\n        for answer in gold_i:\n            norm_answers.extend(normalize_str(answer))\n    else:\n        norm_answers = normalize_str(gold_i)\n    for pred in pred_i:  # pred is already normalized in parse response phase\n        if isinstance(pred, str):  # if it's a string, then find if ans in the pred_i\n            for norm_ans in norm_answers:\n                # only see if the string answer in the string pred\n                if isinstance(norm_ans, str) and norm_ans in pred:\n                    if not correct:\n                        correct = True\n                    break\n        else:  # it's a float number\n            
if pred in norm_answers:\n                if not correct:\n                    correct = True\n                break\n    return correct\n\n\n# ----------- Batch Evaluation -------------\ndef evaluate(samples):\n    \"\"\"\n    Batch evaluation for multiple choice and open questions.\n    \"\"\"\n    pred_correct = 0\n    judge_dict = dict()\n    for sample in samples:\n        gold_i = sample[\"answer\"]\n        pred_i = sample[\"parsed_pred\"]\n        if sample[\"question_type\"] == \"multiple-choice\":\n            correct = eval_multi_choice(gold_i, pred_i)\n        else:  # open question\n            correct = eval_open(gold_i, pred_i)\n\n        if correct:\n            judge_dict[sample[\"id\"]] = \"Correct\"\n            pred_correct += 1\n        else:\n            # print(f\"Wrong! expected {gold_i}, answered with {pred_i}\")\n            judge_dict[sample[\"id\"]] = \"Wrong\"\n\n    if len(samples) == 0:\n        # keep the (judge_dict, metrics) shape so callers can always unpack two values\n        return judge_dict, {\"acc\": 0}\n    return judge_dict, {\"acc\": pred_correct / len(samples)}\n\n\n# ----------- Calculate Accuracy -------------\ndef calculate_ins_level_acc(results: Dict):\n    \"\"\"Calculate the instruction level accuracy for given Subject results\"\"\"\n    acc = 0\n    ins_num = 0\n    for cat_results in results.values():\n        acc += cat_results[\"acc\"] * cat_results[\"num_example\"]\n        ins_num += cat_results[\"num_example\"]\n    if ins_num == 0:\n        return 0\n    return acc / ins_num\n\n\ndef process_result(response, sample, answer_dict, out_samples):\n    if response is None:\n        return\n    if sample[\"question_type\"] == \"multiple-choice\":\n        pred_ans = parse_multi_choice_response(\n            response, sample[\"all_choices\"], sample[\"index2ans\"]\n        )\n    else:  # open question\n        pred_ans = response\n\n    out_samples[sample[\"id\"]] = {\n        \"pred_ans\": pred_ans,\n        \"original_response\": sample[\"original_response\"],\n        \"ground_truth\": sample[\"answer\"],\n  
      \"question_type\": sample[\"question_type\"],\n    }\n\n    # set ground truth answer\n    answer_dict[sample[\"id\"]] = {\n        \"question_type\": sample[\"question_type\"],\n        \"ground_truth\": sample[\"answer\"],\n    }\n\n\ndef eval_result(model_answer_path, answer_dict, eval_output_path=None):\n    if eval_output_path is None:\n        eval_output_path = model_answer_path\n    print(\"Evaluating...\")\n    output_dict = json.load(open(model_answer_path))\n    # answer_dict = json.load(open(answer_path))\n\n    # group by category\n    output_dict_w_cat = {}\n    for data_id, parsed_pred in output_dict.items():\n        if isinstance(parsed_pred, str):\n            parsed_pred = parsed_pred\n        elif isinstance(parsed_pred, dict):\n            parsed_pred = parsed_pred[\"pred_ans\"]\n        else:\n            raise ValueError(f\"Unknown type of parsed_pred: {type(parsed_pred)}\")\n        category = \"_\".join(data_id.split(\"_\")[1:-1])\n        if category not in output_dict_w_cat:\n            output_dict_w_cat.update({category: {}})\n        output_dict_w_cat[category].update({data_id: parsed_pred})\n\n    # group by category\n    answer_dict_w_cat = {}\n    for data_id, parsed_pred in answer_dict.items():\n        category = \"_\".join(data_id.split(\"_\")[1:-1])\n        if category not in answer_dict_w_cat:\n            answer_dict_w_cat.update({category: {}})\n        answer_dict_w_cat[category].update({data_id: parsed_pred})\n\n    evaluation_result = {}\n\n    for category in CAT_SHORT2LONG.values():\n        # print(\"Evaluating: {}\".format(category))\n        # get cat_outputs and cat_answers\n        try:\n            cat_outputs = output_dict_w_cat[category]\n            cat_answers = answer_dict_w_cat[category]\n        except KeyError:\n            # print(\"Skipping {} for not found\".format(category))\n            continue\n\n        exampels_to_eval = []\n        for data_id, parsed_pred in cat_outputs.items():\n          
  question_type = cat_answers[data_id][\"question_type\"]\n            if question_type != \"multiple-choice\":\n                parsed_pred = parse_open_response(\n                    parsed_pred\n                )  # mainly for type consistency (make it number, etc.)\n            else:\n                parsed_pred = parsed_pred\n\n            exampels_to_eval.append(\n                {\n                    \"id\": data_id,\n                    \"question_type\": question_type,\n                    \"answer\": cat_answers[data_id][\"ground_truth\"],\n                    \"parsed_pred\": parsed_pred,\n                }\n            )\n\n        judge_dict, metric_dict = evaluate(exampels_to_eval)\n        metric_dict.update({\"num_example\": len(exampels_to_eval)})\n        for key, value in judge_dict.items():\n            output_dict[key][\"judge\"] = value\n\n        evaluation_result[category] = metric_dict\n\n    save_json(model_answer_path, output_dict)\n    printable_results = {}\n    # pdb.set_trace()\n    # add domain Subject\n    for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():\n        in_domain_cat_results = {}\n        for cat_name in in_domain_cats:  # use the order in DOMAIN_CAT2SUB_CAT\n            if cat_name in evaluation_result.keys():\n                in_domain_cat_results[cat_name] = evaluation_result[cat_name]\n            else:\n                pass\n        in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)\n        in_domain_data_num = sum(\n            [\n                cat_results[\"num_example\"]\n                for cat_results in in_domain_cat_results.values()\n            ]\n        )\n        printable_results[\"Overall-\" + domain] = {\n            \"num\": int(in_domain_data_num),\n            \"acc\": round(in_domain_ins_acc, 3),\n        }\n        # add sub category\n        for cat_name, cat_results in in_domain_cat_results.items():\n            printable_results[cat_name] = {\n                \"num\": 
int(cat_results[\"num_example\"]),\n                \"acc\": round(cat_results[\"acc\"], 3),\n            }\n\n    # table.append([\"-----------------------------\", \"-----\", \"----\"])\n    all_ins_acc = calculate_ins_level_acc(evaluation_result)\n    overall_acc = round(all_ins_acc, 3)\n    printable_results[\"Overall\"] = {\n        \"num\": sum(\n            [cat_results[\"num_example\"] for cat_results in evaluation_result.values()]\n        ),\n        \"acc\": overall_acc,\n    }\n    pprint.pprint(printable_results)\n    out = eval_output_path\n    with open(out, \"w\", encoding=\"utf-8\") as outfile:\n        json.dump(printable_results, outfile)\n        print(f\"eval out saved to {out}\")\n\n    print(f\"Overall accuracy: {overall_acc}\")\n"
  },
  {
    "path": "benchmark/mmmu/prompt_format.yaml",
    "content": "task_instructions:\n- \"\"\nmulti_choice_example_format:\n- \"{}\n\n{}\n\nAnswer with the option's letter from the given choices directly.\"\n\nshort_ans_example_format:\n- \"{}\n\nAnswer the question using a single word or phrase.\"\ntemperature:\n- 0\n"
  },
  {
    "path": "benchmark/mtbench/README.md",
    "content": "## Download Dataset\n\n```sh\nwget -O question.jsonl https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl\n```\n\n## Run benchmark\n\n### Benchmark sglang\n```\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\n```\npython3 bench_sglang.py --num-questions 80\n```\n\n### Benchmark sglang EAGLE\n```\npython3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algo EAGLE \\\n    --speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --dtype float16 --port 30000\n```\n\n```\npython3 bench_sglang_eagle.py --num-questions 80 --parallel 1\n```\n\n\n### Benchmark vllm\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000\n```\n\n```\npython3 bench_other.py --num-questions 80 --backend vllm\n```\n\n\n### Benchmark lightllm\n```\n# A10G\npython -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000\n```\n\n```\npython3 bench_other.py --num-questions 80 --backend lightllm\n```\n"
  },
  {
    "path": "benchmark/mtbench/bench_other.py",
    "content": "import argparse\nimport json\nimport os\nimport time\nimport uuid\nfrom concurrent.futures import ThreadPoolExecutor\n\nfrom fastchat.model import get_conversation_template\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import download_and_cache_file\n\n\ndef load_questions(filename):\n    questions = []\n    with open(filename, \"r\") as fin:\n        for line in fin:\n            obj = json.loads(line)\n            questions.append(obj)\n    return questions\n\n\ndef write_answers(filename, model_id, questions, answers):\n    with open(os.path.expanduser(filename), \"w\") as fout:\n        for i in range(len(answers)):\n            ans_json = {\n                \"question_id\": questions[i][\"question_id\"],\n                \"answer_id\": uuid.uuid4().hex,\n                \"model_id\": model_id,\n                \"choices\": {\n                    \"index\": 0,\n                    \"turns\": [answers[i][0], answers[i][1]],\n                },\n                \"tstamp\": time.time(),\n            }\n            fout.write(json.dumps(ans_json) + \"\\n\")\n\n\ndef main(args):\n    # Download question file if not exist\n    question_file = args.question_file\n    url = \"https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl\"\n    if not os.path.isfile(question_file):\n        question_file = download_and_cache_file(url)\n\n    questions = load_questions(question_file)\n    questions = (questions * 10)[: args.num_questions]\n    max_tokens = 256\n    model_id = \"llama-2-chat\"\n\n    conv_main = get_conversation_template(model_id)\n\n    # Select backend\n    call_generate = get_call_generate(args)\n\n    answers = [None] * len(questions)\n\n    def get_answer(i):\n        conv = conv_main.copy()\n        cur_answers = []\n        for j in range(2):\n            q = questions[i][\"turns\"][j]\n            
conv.append_message(conv.roles[0], q)\n            conv.append_message(conv.roles[1], None)\n\n            prompt = conv.get_prompt()\n            output = call_generate(prompt, temperature=0, max_tokens=max_tokens).strip()\n\n            cur_answers.append(output)\n            conv.update_last_message(output)\n\n        answers[i] = cur_answers\n\n    # Run requests\n    tic = time.perf_counter()\n    if args.parallel == 1:\n        for i in tqdm(range(len(questions))):\n            get_answer(i)\n    else:\n        with ThreadPoolExecutor(args.parallel) as executor:\n            list(\n                tqdm(\n                    executor.map(get_answer, list(range(len(questions)))),\n                    total=len(questions),\n                )\n            )\n\n    latency = time.perf_counter() - tic\n\n    print(f\"#questions: {len(questions)}, Latency: {latency:.2f}\")\n\n    # Write results\n    answer_file = args.answer_file or f\"tmp_output_{args.backend}.txt\"\n    write_answers(answer_file, model_id, questions, answers)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"mtbench\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--question-file\", type=str, default=\"question.jsonl\")\n    parser.add_argument(\"--answer-file\", type=str, default=None)\n    parser.add_argument(\"--num-questions\", type=int, default=80)\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/mtbench/bench_sglang.py",
    "content": "import argparse\nimport json\nimport os\nimport time\nimport uuid\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import download_and_cache_file\n\n\ndef load_questions(filename):\n    questions = []\n    with open(filename, \"r\") as fin:\n        for line in fin:\n            obj = json.loads(line)\n            questions.append(obj)\n    return questions\n\n\ndef write_answers(filename, model_id, questions, answers):\n    with open(os.path.expanduser(filename), \"w\") as fout:\n        for i in range(len(answers)):\n            ans_json = {\n                \"question_id\": questions[i][\"question_id\"],\n                \"answer_id\": uuid.uuid4().hex,\n                \"model_id\": model_id,\n                \"choices\": {\n                    \"index\": 0,\n                    \"turns\": [answers[i][0], answers[i][1]],\n                },\n                \"tstamp\": time.time(),\n            }\n            fout.write(json.dumps(ans_json) + \"\\n\")\n\n\n@sgl.function\ndef answer_mt_bench(s, question_1, question_2):\n    s += sgl.system()\n    s += sgl.user(question_1)\n    s += sgl.assistant(sgl.gen(\"answer_1\"))\n    s += sgl.user(question_2)\n    s += sgl.assistant(sgl.gen(\"answer_2\"))\n\n\ndef main(args):\n    # Download the question file if it does not exist\n    question_file = args.question_file\n    url = \"https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl\"\n    if not os.path.isfile(question_file):\n        question_file = download_and_cache_file(url)\n\n    # Construct prompts\n    questions = load_questions(question_file)[: args.num_questions]\n    arguments = [\n        {\"question_1\": q[\"turns\"][0], \"question_2\": q[\"turns\"][1]} for q in questions\n    ]\n\n    # Select backend\n    backend = select_sglang_backend(args)\n    sgl.set_default_backend(backend)\n\n    # Run requests\n    tic = time.perf_counter()\n    rets = answer_mt_bench.run_batch(\n        arguments,\n        temperature=0,\n        max_new_tokens=256,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    answers = [[s[\"answer_1\"], s[\"answer_2\"]] for s in rets]\n    latency = time.perf_counter() - tic\n\n    print(f\"#questions: {len(questions)}, Latency: {latency:.2f}\")\n\n    # Write results\n    model_id = backend.model_info[\"model_path\"]\n    answer_file = args.answer_file or f\"tmp_output_{args.backend}.txt\"\n    write_answers(answer_file, model_id, questions, answers)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"mtbench\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--question-file\", type=str, default=\"question.jsonl\")\n    parser.add_argument(\"--answer-file\", type=str, default=None)\n    parser.add_argument(\"--num-questions\", type=int, default=80)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/mtbench/bench_sglang_eagle.py",
    "content": "\"\"\"\nAdapted from https://github.com/chromecast56/sglang/blob/6f145d2eadb93a116134f703358ce76f15381045/benchmark/mtbench/bench_sglang.py\n\nBenchmark SGLang EAGLE/EAGLE3 Speculative Decoding\n\nUsage:\npython3 benchmark/mtbench/bench_sglang_eagle.py --num-questions 80 --parallel 1\n\"\"\"\n\nimport argparse\nimport json\nimport os\nimport time\nimport uuid\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import download_and_cache_file\n\n\ndef load_questions(filename):\n    questions = []\n    with open(filename, \"r\") as fin:\n        for line in fin:\n            obj = json.loads(line)\n            questions.append(obj)\n    return questions\n\n\ndef write_answers(filename, model_id, questions, answers):\n    with open(os.path.expanduser(filename), \"w\") as fout:\n        for i in range(len(answers)):\n            ans_json = {\n                \"question_id\": questions[i][\"question_id\"],\n                \"answer_id\": uuid.uuid4().hex,\n                \"model_id\": model_id,\n                \"choices\": {\n                    \"index\": 0,\n                    \"turns\": [answers[i][0], answers[i][1]],\n                },\n                \"tstamp\": time.time(),\n            }\n            fout.write(json.dumps(ans_json) + \"\\n\")\n\n\n@sgl.function\ndef answer_mt_bench(s, question_1, question_2):\n    s += sgl.system(\n        \"You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. 
If you don't know the answer to a question, please don't share false information.\"\n    )\n    s += sgl.user(question_1)\n    s += sgl.assistant(sgl.gen(\"answer_1\"))\n    s += sgl.user(question_2)\n    s += sgl.assistant(sgl.gen(\"answer_2\"))\n\n\ndef main(args):\n    # Download the question file if it does not exist\n    question_file = args.question_file\n    url = \"https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl\"\n    if not os.path.isfile(question_file):\n        question_file = download_and_cache_file(url)\n\n    # Construct prompts\n    questions = load_questions(question_file)[: args.num_questions]\n    arguments = [\n        {\"question_1\": q[\"turns\"][0], \"question_2\": q[\"turns\"][1]} for q in questions\n    ]\n\n    # Select backend\n    backend = select_sglang_backend(args)\n    sgl.set_default_backend(backend)\n\n    # Run requests\n    tic = time.perf_counter()\n    rets = answer_mt_bench.run_batch(\n        arguments,\n        temperature=0,\n        max_new_tokens=2048,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    answers = [[s[\"answer_1\"], s[\"answer_2\"]] for s in rets]\n\n    latency = time.perf_counter() - tic\n    num_output_tokens = sum(\n        s.get_meta_info(\"answer_1\")[\"completion_tokens\"]\n        + s.get_meta_info(\"answer_2\")[\"completion_tokens\"]\n        for s in rets\n    )\n\n    # NOTE: acceptance length is just completion_tokens / spec_verify_ct\n    # {'id': '3bb9c5ead109488d8ed5ee9cbecaec29', 'finish_reason': {'type': 'length', 'length': 256}, 'prompt_tokens': 37, 'spec_verify_ct': 101, 'completion_tokens': 256, 'cached_tokens': 0}\n\n    output_throughput = num_output_tokens / latency\n\n    has_verify = \"spec_verify_ct\" in rets[0].get_meta_info(\"answer_1\")\n    if has_verify:\n        num_verify_tokens = sum(\n            s.get_meta_info(\"answer_1\")[\"spec_verify_ct\"]\n            + 
s.get_meta_info(\"answer_2\")[\"spec_verify_ct\"]\n            for s in rets\n        )\n\n        accept_length = num_output_tokens / num_verify_tokens\n    else:\n        accept_length = 1.0\n\n    print(\n        f\"#questions: {len(questions)}, Throughput: {output_throughput:.2f} token/s, Acceptance length: {accept_length:.2f}\"\n    )\n\n    # Write results\n    model_id = backend.model_info[\"model_path\"]\n    answer_file = args.answer_file or f\"tmp_output_{args.backend}.txt\"\n    write_answers(answer_file, model_id, questions, answers)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"mtbench\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"throughput\": round(output_throughput, 3),\n            \"accept_length\": round(accept_length, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--question-file\", type=str, default=\"question.jsonl\")\n    parser.add_argument(\"--answer-file\", type=str, default=None)\n    parser.add_argument(\"--num-questions\", type=int, default=80)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/multi_chain_reasoning/README.md",
    "content": "## Download data\n```\nwget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\n```\n\n## Run benchmark\n\n### Benchmark sglang\n```\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000  --schedule-conservativeness 1.3\n```\n\n```\npython3 bench_sglang.py --num-questions 64\npython3 bench_sglang.py --num-questions 32 --parallel 1\n```\n\n\n### Benchmark vllm\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000\n```\n\n```\npython3 bench_other.py --num-questions 64 --backend vllm\n```\n\n\n### Benchmark lightllm\n```\n# A10G\npython -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000\n```\n\n```\npython3 bench_other.py --num-questions 64 --backend lightllm\n```\n\n\n### Benchmark guidance\n```\npython3 bench_other.py --num-questions 8 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n\n### Benchmark lmql\n\n```\npython3 bench_other.py --num-questions 64 --backend lmql --parallel 1\n```\n"
  },
  {
    "path": "benchmark/multi_chain_reasoning/bench_other.py",
    "content": "import argparse\nimport ast\nimport asyncio\nimport json\nimport re\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\n\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import dump_state_text, read_jsonl\n\nINVALID = -9999999\n\n\ndef get_answer_value(answer_str):\n    answer_str = answer_str.replace(\",\", \"\")\n    numbers = re.findall(r\"\\d+\", answer_str)\n    if len(numbers) < 1:\n        return INVALID\n    try:\n        return ast.literal_eval(numbers[-1])\n    except SyntaxError:\n        return INVALID\n\n\nprompt_lib = [\n    \"Let us think step by step.\",\n    \"Approach this methodically. Let's dissect the problem into smaller, more manageable parts.\",\n    \"It's important to proceed step by step, ensuring accuracy at each stage.\",\n    \"Take a deep breath and break this down.\",\n    \"A little bit of arithmetic and a logical approach will help us quickly arrive at the solution to this problem.\",\n    \"I am extremely good at math.\",\n]\n\n\ndef multi_chain_gsm8k(question, num_chains, call_generate):\n    s = \"Question: \" + question + \"\\n\"\n    # s += call_generate(s + \"Answer: \" + prompt_lib[0], max_tokens=256,\n    #     stop=\"Question\", temperature=0)\n    # return s\n\n    comps = []\n    for i in range(num_chains):\n        comps.append(\n            call_generate(\n                s + \"Answer: \" + prompt_lib[i % num_chains],\n                max_tokens=256,\n                temperature=0.3,\n                stop=\"Question\",\n            )\n        )\n\n    s += \"Answer: To answer this question, here are some possible solutions. 
\"\n    s += \"After considering all of them, I will do a majority vote.\\n\\n\"\n    for i in range(num_chains):\n        s += f\"Solution {i+1}: \" + comps[i].strip() + \"\\n\\n\"\n    s += \"\\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is \"\n    s += call_generate(s, max_tokens=16, temperature=0, stop=None)\n    return s\n\n\nasync def multi_chain_gsm8k_async(question, num_chains, call_generate):\n    s = \"Question: \" + question + \"\\n\"\n    # s += call_generate(s + \"Answer: \" + prompt_lib[0], max_tokens=256,\n    #     stop=\"Question\", temperature=0)\n    # return s\n\n    comps = []\n    for i in range(num_chains):\n        comps.append(\n            await call_generate(\n                s + \"Answer: \" + prompt_lib[i % num_chains],\n                max_tokens=256,\n                temperature=0.3,\n                stop=\"Question\",\n            )\n        )\n\n    s += \"Answer: To answer this question, here are some possible solutions. 
\"\n    s += \"After considering all of them, I will do a majority vote.\\n\\n\"\n    for i in range(num_chains):\n        s += f\"Solution {i+1}: \" + comps[i].strip() + \"\\n\\n\"\n    s += \"\\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is \"\n    s += await call_generate(s, max_tokens=16, temperature=0, stop=None)\n    return s\n\n\ndef main(args):\n    lines = list(read_jsonl(args.data_path))\n\n    # Construct prompts\n    k = args.num_shot\n\n    questions = []\n    labels = []\n    for i in range(len(lines[: args.num_questions])):\n        questions.append(lines[i][\"question\"])\n        labels.append(get_answer_value(lines[i][\"answer\"]))\n    assert all(l != INVALID for l in labels)\n\n    states = [None] * len(labels)\n\n    # Select backend\n    call_generate = get_call_generate(args)\n\n    # Run requests\n    if args.backend != \"lmql\":\n        # Use thread pool\n        def get_one_answer(i):\n            answer = multi_chain_gsm8k(questions[i], args.num_chains, call_generate)\n            states[i] = answer\n\n        tic = time.perf_counter()\n        if args.parallel == 1:\n            for i in tqdm(range(len(questions))):\n                get_one_answer(i)\n        else:\n            with ThreadPoolExecutor(args.parallel) as executor:\n                list(\n                    tqdm(\n                        executor.map(get_one_answer, list(range(len(questions)))),\n                        total=len(questions),\n                    )\n                )\n\n    else:\n        # Use asyncio\n        async def get_one_answer_asyncio(i):\n            answer = await multi_chain_gsm8k_async(\n                questions[i], args.num_chains, call_generate\n            )\n            states[i] = answer\n\n        tic = time.perf_counter()\n        loop = asyncio.get_event_loop()\n        batches = [\n            list(range(i, min(i + args.parallel, len(questions))))\n            for 
i in range(0, len(questions), args.parallel)\n        ]\n        for bt in tqdm(batches):\n            tasks = [get_one_answer_asyncio(k) for k in bt]\n            loop.run_until_complete(asyncio.gather(*tasks))\n\n    latency = time.perf_counter() - tic\n\n    preds = []\n    for i in range(len(states)):\n        preds.append(get_answer_value(states[i]))\n\n    # Compute accuracy\n    acc = np.mean(np.array(preds) == np.array(labels))\n    invalid = np.mean(np.array(preds) == INVALID)\n    print(f\"Latency: {latency:.3f}\")\n    print(f\"Invalid: {invalid:.3f}\")\n    print(f\"Accuracy: {acc:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"multi_chain_gsm8k\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"accuracy\": round(acc, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--num-shot\", type=int, default=0)\n    parser.add_argument(\"--num-chains\", type=int, default=5)\n    parser.add_argument(\"--data-path\", type=str, default=\"test.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=50)\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/multi_chain_reasoning/bench_sglang.py",
    "content": "import argparse\nimport ast\nimport json\nimport re\nimport time\n\nimport numpy as np\n\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text, read_jsonl\n\nINVALID = -9999999\n\n\ndef get_answer_value(answer_str):\n    answer_str = answer_str.replace(\",\", \"\")\n    numbers = re.findall(r\"\\d+\", answer_str)\n    if len(numbers) < 1:\n        return INVALID\n    try:\n        return ast.literal_eval(numbers[-1])\n    except SyntaxError:\n        return INVALID\n\n\nprompt_lib = [\n    \"Let us think step by step.\",\n    \"Approach this methodically. Let's dissect the problem into smaller, more manageable parts.\",\n    \"It's important to proceed step by step, ensuring accuracy at each stage.\",\n    \"Take a deep breath and break this down.\",\n    \"A little bit of arithmetic and a logical approach will help us quickly arrive at the solution to this problem.\",\n    \"I am extremely good at math.\",\n]\n\n\ndef main(args):\n    lines = list(read_jsonl(args.data_path))\n\n    # Construct prompts\n    # k = args.num_shot\n    # few_shot_examples = get_few_shot_examples(lines, k)\n\n    questions = []\n    labels = []\n    for i in range(len(lines[: args.num_questions])):\n        questions.append(lines[i][\"question\"])\n        labels.append(get_answer_value(lines[i][\"answer\"]))\n    assert all(l != INVALID for l in labels)\n    arguments = [{\"question\": q} for q in questions]\n\n    num_chains = args.num_chains\n\n    #####################################\n    ######### SGL Program Begin #########\n    #####################################\n\n    import sglang as sgl\n\n    @sgl.function\n    def multi_chain_gsm8k(s, question):\n        s += \"Question: \" + question + \"\\n\"\n        # s += \"Answer: \" + prompt_lib[0] + sgl.gen(\"answer\", max_tokens=256, stop=\"Question\",\n        #    temperature=0)\n        # return\n\n        forks = 
s.fork(num_chains)\n        for i in range(num_chains):\n            forks[i] += (\n                \"Answer: \"\n                + prompt_lib[i % num_chains]\n                + sgl.gen(\"chain\", max_tokens=256, temperature=0.3, stop=\"Question\")\n            )\n        forks.join()\n\n        s += \"Answer: To answer this question, here are some possible solutions. \"\n        s += \"After considering all of them, I will do a majority vote.\\n\\n\"\n        for i in range(num_chains):\n            s += f\"Solution {i+1}: \" + forks[i][\"chain\"].strip() + \"\\n\\n\"\n        s += \"\\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is \"\n        s += sgl.gen(\"answer\", max_tokens=16)\n\n    #####################################\n    ########## SGL Program End ##########\n    #####################################\n\n    # Select backend\n    backend = select_sglang_backend(args)\n\n    # Run requests\n    tic = time.perf_counter()\n    states = multi_chain_gsm8k.run_batch(\n        arguments,\n        temperature=0,\n        backend=backend,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    latency = time.perf_counter() - tic\n\n    preds = []\n    for i in range(len(states)):\n        preds.append(get_answer_value(states[i][\"answer\"]))\n\n    # Compute accuracy\n    acc = np.mean(np.array(preds) == np.array(labels))\n    invalid = np.mean(np.array(preds) == INVALID)\n    print(f\"Latency: {latency:.3f}\")\n    print(f\"Invalid: {invalid:.3f}\")\n    print(f\"Accuracy: {acc:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"multi_chain_gsm8k\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"accuracy\": round(acc, 3),\n            \"num_requests\": 
args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--num-shot\", type=int, default=0)\n    parser.add_argument(\"--num-chains\", type=int, default=5)\n    parser.add_argument(\"--data-path\", type=str, default=\"test.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=50)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/multi_document_qa/README.md",
    "content": "## Run benchmark\n\n### Benchmark sglang\n```\npython3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000\n```\n\n```\npython3 bench_sglang.py --num-questions 10 --parallel 1\n```\n\n\n### Benchmark vllm\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu 0.97\n```\n\n```\npython3 bench_other.py --backend vllm --num-questions 64\n```\n\n\n### Benchmark guidance\n```\npython3 bench_other.py --backend guidance --num-questions 32 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf\n```\n\n\n### Build dataset\n\n`build_dataset.py` reads the Llama 2 paper as plain text from `llama2.txt` and writes `questions.jsonl`. After installing PyPDF2, create `llama2.txt` from `llama2.pdf` with the Python snippet below, then run `build_dataset.py`:\n\n```\npip install PyPDF2\npython3 build_dataset.py\n```\n\n```python\nimport PyPDF2\n\nwith open('llama2.pdf', 'rb') as file:\n    reader = PyPDF2.PdfReader(file)\n    text = ''\n    for page_num in range(len(reader.pages)):\n        text += reader.pages[page_num].extract_text()\n    with open('llama2.txt', 'w', encoding='utf-8') as text_file:\n        text_file.write(text)\n```\n"
  },
  {
    "path": "benchmark/multi_document_qa/bench_other.py",
    "content": "import argparse\nimport json\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\nfrom functools import partial\n\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import dump_state_text, read_jsonl\n\nUSER_PREFIX = \"[INST] \"\nUSER_SUFFIX = \" [/INST]\"\nASSISTANT_PREFIX = \"\"\nASSISTANT_SUFFIX = \" </s><s>\"\n\n\ndef multi_document_qa(docs, question, generate):\n    s = USER_PREFIX\n    s += \"Please answer a question according to given documents.\\n\"\n    s += \"Question:\" + question + \"Documents begin.\\n\"\n\n    s += \"\".join(docs)\n\n    s += \"\\nDocuments end.\"\n    s += (\n        \"\\n\\nBased on the above documents, please answer this question:\\n\"\n        + question\n        + \"\\nAnswer in three words or fewer.\"\n    )\n    s += USER_SUFFIX\n    s += ASSISTANT_PREFIX\n    answer = generate(s, max_tokens=16, stop=None)\n    return answer\n\n\ndef main(args):\n    lines = list(read_jsonl(args.data_path))\n    l = lines[0]\n    arguments = []\n    labels = []\n\n    num_docs = 10\n    if args.backend == \"guidance\":\n        num_docs = 7  # due to OOM\n\n    for i in range(len(l[\"questions\"][: args.num_questions])):\n        arguments.append(\n            {\n                \"docs\": l[\"documents\"][:num_docs],\n                \"question\": l[\"questions\"][i],\n            }\n        )\n        labels.append(l[\"answers\"][i])\n    states = [None] * len(arguments)\n\n    # Select backend\n    call_generate = partial(get_call_generate(args), temperature=0)\n\n    # Run requests\n    def get_one_answer(i):\n        states[i] = multi_document_qa(generate=call_generate, **arguments[i])\n\n    tic = time.perf_counter()\n    if args.parallel == 1:\n        for i in tqdm(range(len(labels))):\n            get_one_answer(i)\n    else:\n        with ThreadPoolExecutor(args.parallel) as executor:\n            list(\n                tqdm(\n                    executor.map(get_one_answer, list(range(len(labels)))),\n                    total=len(labels),\n                )\n            )\n\n    latency = time.perf_counter() - tic\n\n    # Compute accuracy\n    print(states)\n    correct = 0\n    for s, label in zip(states, labels):\n        answer = s.lower()\n        if all(x in answer for x in label.lower().split(\" \")):\n            correct += 1\n    accuracy = correct / len(labels)\n    print(f\"Accuracy: {accuracy:.3f}\")\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"multi_document_qa\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_questions,\n            \"accuracy\": accuracy,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"questions.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=100)\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/multi_document_qa/bench_sglang.py",
    "content": "import argparse\nimport json\nimport time\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text, read_jsonl\n\n\n@sgl.function\ndef multi_document_qa(s, docs, question):\n    s += sgl.user_begin()\n    s += \"Please answer a question according to given documents.\\n\"\n    s += \"Question:\" + question + \"Documents begin.\\n\"\n\n    forks = s.fork(len(docs))\n    forks += lambda i: docs[i]\n    forks.join(\"concate_and_append\")\n\n    s += \"\\nDocuments end.\"\n    s += (\n        \"\\n\\nBased on the above documents, please answer this question:\\n\"\n        + question\n        + \"\\nAnswer in three words or fewer.\"\n    )\n    s += sgl.user_end()\n    s += sgl.assistant(sgl.gen(\"answer\", max_tokens=16))\n\n\ndef main(args):\n    lines = list(read_jsonl(args.data_path))\n    l = lines[0]\n    arguments = []\n    labels = []\n    for i in range(len(l[\"questions\"][: args.num_questions])):\n        arguments.append(\n            {\n                \"docs\": l[\"documents\"][:10],\n                \"question\": l[\"questions\"][i],\n            }\n        )\n        labels.append(l[\"answers\"][i])\n\n    # Select backend\n    backend = select_sglang_backend(args)\n    sgl.set_default_backend(backend)\n\n    # Run requests\n    tic = time.perf_counter()\n    states = multi_document_qa.run_batch(\n        arguments, temperature=0, num_threads=args.parallel, progress_bar=True\n    )\n    latency = time.perf_counter() - tic\n\n    # Compute accuracy\n    print([s[\"answer\"] for s in states])\n    correct = 0\n    for s, label in zip(states, labels):\n        answer = s[\"answer\"].lower()\n        if all(x in answer for x in label.lower().split(\" \")):\n            correct += 1\n    accuracy = correct / len(labels)\n    print(f\"Accuracy: {accuracy:.3f}\")\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"multi_document_qa\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_questions,\n            \"accuracy\": accuracy,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"questions.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=100)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/multi_document_qa/build_dataset.py",
    "content": "import json\n\nimport transformers\n\ncontent = \"\\n\".join(\n    open(\"llama2.txt\", \"r\", encoding=\"utf-8\", errors=\"ignore\").readlines()\n)\ncontent = content.replace(\"\\n\\n\", \"\\n\")\n\n# Count token\nname = \"meta-llama/Llama-2-7b-chat-hf\"\nt = transformers.AutoTokenizer.from_pretrained(name)\nprint(f\"num tokens: {len(t.encode(content))}\")\n\n# Segment\nSEP = \"\\n\\n\"\nparts = content.split(SEP)\nprint(f\"num segments: {len(parts)}\")\n\nsegment_len = 1100\n\nsegments = []\ntmp = []\ntmp_len = 0\nfor i in range(len(parts)):\n    tmp.append(parts[i])\n    tmp_len += len(t.encode(parts[i]))\n\n    if tmp_len > segment_len:\n        segments.append(SEP.join(tmp))\n        tmp = []\n        tmp_len = 0\n\nfor i, s in enumerate(segments):\n    print(i, len(t.encode(segments[i])))\n\n# Dump\nwith open(\"questions.jsonl\", \"w\") as fout:\n    fout.write(\n        json.dumps(\n            {\n                \"documents\": segments[:30],\n                \"questions\": [\n                    \"What is the name of the fine-tuned LLMs?\",\n                    \"Which figure shows the helpfulness human evaluation results for Llama 2-Chat?\",\n                    \"What is the number of parameters in the largest Llama 2 model?\",\n                    \"What is the batch size of fine-tuning?\",\n                    \"Where can we find the details of potential data contamination?\",\n                    \"What is the full name of MPT?\",\n                    \"What is the power consumption of RSC in Watt?\",\n                    \"How many tokens of data do they train on?\",\n                    \"Which model's release is delayed due to a lack of time to sufficiently red team?\",\n                    \"Which activation function is used in Llama?\",\n                ],\n                \"answers\": [\n                    \"Llama 2 Chat\",\n                    \"1\",\n                    \"70 B\",\n                    \"64\",\n                   
 \"A 6\",\n                    \"MosaicML\",\n                    \"400\",\n                    \"2 trillion\",\n                    \"34 B\",\n                    \"SwiGLU\",\n                ],\n            }\n        )\n        + \"\\n\"\n    )\n"
  },
  {
    "path": "benchmark/multi_turn_chat/README.md",
    "content": "### Benchmark sglang\n\nRun Llama-7B\n\n```\npython3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\nRun Mixtral-8x7B. If you hit a CUDA out-of-memory error, try reducing `--mem-fraction-static`.\n\n```\npython3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8\n```\n\nBenchmark (short output)\n\n```\npython3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf\n```\n\nBenchmark (long output)\n\n```\npython3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf --long\n```\n\n### Benchmark vLLM\n\nRun Llama-7B\n\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000\n```\n\nRun Mixtral-8x7B\n\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model mistralai/Mixtral-8x7B-Instruct-v0.1 --disable-log-requests --port 21000 --tensor-parallel-size 8\n```\n\nBenchmark (short output)\n\n```\npython3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm\n```\n\nBenchmark (long output)\n\n```\npython3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm --long\n```\n\n### Benchmark guidance\n\nBenchmark Llama-7B (short output)\n\n```\npython3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n\nBenchmark Llama-7B (long output)\n\n```\npython3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf --long\n```\n"
  },
  {
    "path": "benchmark/multi_turn_chat/bench_other.py",
    "content": "import json\nimport time\nfrom argparse import ArgumentParser\nfrom concurrent.futures import ThreadPoolExecutor\nfrom functools import partial\n\nfrom data_gen import gen_arguments\nfrom tqdm import tqdm\nfrom vllm.transformers_utils.tokenizer import get_tokenizer\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import dump_state_text\n\n\ndef multi_turns(generate, qas):\n    s = \"\"\n    for qa in qas:\n        s += qa[\"prompt\"]\n        s += generate(s, max_tokens=qa[\"new_tokens\"])\n\n    return s\n\n\ndef main(args):\n    print(args)\n\n    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)\n\n    multi_qas = gen_arguments(args, tokenizer)\n\n    states = [None] * args.num_qa\n\n    call_generate = partial(get_call_generate(args), temperature=0)\n\n    def get_one_answer(i):\n        states[i] = multi_turns(generate=call_generate, **multi_qas[i])\n\n    tic = time.perf_counter()\n    if args.parallel == 1:\n        for i in tqdm(range(len(multi_qas))):\n            get_one_answer(i)\n    else:\n        with ThreadPoolExecutor(args.parallel) as executor:\n            rets = list(\n                tqdm(\n                    executor.map(get_one_answer, list(range(len(multi_qas)))),\n                    total=len(multi_qas),\n                )\n            )\n            for _ in rets:\n                pass\n\n    latency = time.perf_counter() - tic\n\n    # Report latency\n    print(f\"Latency: {latency:.3f}\")\n\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"multi_turn_chat\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_qa,\n            \"num_turns\": args.turns,\n            \"other\": {\n                \"parallel\": args.parallel,\n                \"output_mode\": \"long\" if args.long else \"short\",\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument(\"--turns\", type=int, default=4)\n    parser.add_argument(\"--num-qa\", type=int, default=20)\n    parser.add_argument(\"--min-len-q\", type=int, default=256)\n    parser.add_argument(\"--max-len-q\", type=int, default=512)\n    parser.add_argument(\"--min-len-a\", type=int, default=4)\n    parser.add_argument(\"--max-len-a\", type=int, default=8)\n    parser.add_argument(\"--tokenizer\", type=str, required=True)\n    parser.add_argument(\"--trust-remote-code\", action=\"store_true\")\n    parser.add_argument(\"--long\", action=\"store_true\")\n    args = add_common_other_args_and_parse(parser)\n\n    if args.long:\n        args.min_len_a = 256\n        args.max_len_a = 512\n        args.num_qa = 20\n    main(args)\n"
  },
  {
    "path": "benchmark/multi_turn_chat/bench_sglang.py",
    "content": "import json\nimport time\nfrom argparse import ArgumentParser\n\nfrom data_gen import gen_arguments\nfrom vllm.transformers_utils.tokenizer import get_tokenizer\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text\n\n\n@sgl.function\ndef multi_turns(s, qas):\n    for qa in qas:\n        s += qa[\"prompt\"]\n        s += sgl.gen(max_tokens=qa[\"new_tokens\"], ignore_eos=True)\n\n\ndef main(args):\n    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)\n\n    multi_qas = gen_arguments(args, tokenizer)\n\n    backend = select_sglang_backend(args)\n\n    tic = time.perf_counter()\n    states = multi_turns.run_batch(\n        multi_qas,\n        temperature=0,\n        backend=backend,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    latency = time.perf_counter() - tic\n\n    print(f\"Latency: {latency:.3f}\")\n\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"multi_turn_chat\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_qa,\n            \"num_turns\": args.turns,\n            \"other\": {\n                \"parallel\": args.parallel,\n                \"output_mode\": \"long\" if args.long else \"short\",\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument(\"--turns\", type=int, default=4)\n    parser.add_argument(\"--num-qa\", type=int, default=20)\n    parser.add_argument(\"--min-len-q\", type=int, default=256)\n    parser.add_argument(\"--max-len-q\", type=int, default=512)\n    parser.add_argument(\"--min-len-a\", type=int, default=4)\n    
parser.add_argument(\"--max-len-a\", type=int, default=8)\n    parser.add_argument(\"--tokenizer\", type=str, required=True)\n    parser.add_argument(\"--trust-remote-code\", action=\"store_true\")\n    parser.add_argument(\"--long\", action=\"store_true\")\n    args = add_common_sglang_args_and_parse(parser)\n\n    if args.long:\n        args.min_len_a = 256\n        args.max_len_a = 512\n        args.num_qa = 20\n\n    print(args)\n    main(args)\n"
  },
  {
    "path": "benchmark/multi_turn_chat/data_gen.py",
    "content": "import random\nimport string\n\nrandom.seed(42)\n\n\ndef gen_prompt(tokenizer, token_num):\n    cha_set = string.ascii_letters + string.digits\n    ret = \"\".join(random.choices(cha_set, k=token_num))\n    while len(tokenizer(ret).input_ids) < token_num:\n        ret += random.choice(cha_set)\n    return ret\n\n\ndef gen_arguments(args, tokenizer):\n    multi_qas = [{\"qas\": []} for _ in range(args.num_qa)]\n    for i in range(args.num_qa):\n        qas = multi_qas[i][\"qas\"]\n        for _ in range(args.turns):\n            prompt_len = random.randint(args.min_len_q, args.max_len_q)\n            new_tokens = random.randint(args.min_len_a, args.max_len_a)\n            qas.append(\n                {\n                    \"prompt\": gen_prompt(tokenizer, prompt_len),\n                    \"new_tokens\": new_tokens,\n                }\n            )\n\n    return multi_qas\n"
  },
  {
    "path": "benchmark/multi_turn_chat/long_prompt_multi_turn.py",
    "content": "import json\nimport random\nimport time\nfrom argparse import ArgumentParser\nfrom pathlib import Path\n\nfrom tqdm import tqdm\n\nimport sglang as sgl\nfrom sglang.srt.utils.hf_transformers_utils import get_tokenizer\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text\n\n\ndef gen_prompt(tokenizer, token_num):\n    all_available_tokens = list(tokenizer.get_vocab().values())\n    selected_tokens = random.choices(all_available_tokens, k=token_num)\n    ret = tokenizer.decode(selected_tokens)\n    return ret\n\n\ndef get_cache_path(args):\n    # Create cache directory under ~/.cache/sglang\n    cache_dir = Path.home() / \".cache\" / \"sglang\"\n\n    # Create a unique cache filename based on the arguments that affect generation\n    cache_key = f\"qa_{args.num_qa}_{args.turns}_{args.system_prompt_len}_{args.len_q}_{args.len_a}_{args.tokenizer.replace('/', '_')}.json\"\n    return cache_dir / cache_key\n\n\ndef gen_arguments(args, tokenizer):\n    cache_path = get_cache_path(args)\n\n    # Try to load from cache first\n    if cache_path.exists():\n        print(f\"Loading cached arguments from {cache_path}\")\n        with open(cache_path, \"r\") as f:\n            return json.load(f)\n\n    print(\"Generating new arguments...\")\n    # First progress bar for system prompts\n    multi_qas = []\n    for _ in tqdm(range(args.num_qa), desc=\"Generating system prompts\"):\n        multi_qas.append(\n            {\"system_prompt\": gen_prompt(tokenizer, args.system_prompt_len), \"qas\": []}\n        )\n\n    # Nested progress bars for QA pairs\n    for i in tqdm(range(args.num_qa), desc=\"Generating QA pairs\"):\n        qas = multi_qas[i][\"qas\"]\n        for j in range(args.turns):\n            qas.append(\n                {\n                    \"prompt\": gen_prompt(tokenizer, args.len_q),\n                    \"new_tokens\": args.len_a,\n                
}\n            )\n\n    # Save to cache\n    cache_path.parent.mkdir(parents=True, exist_ok=True)\n    with open(cache_path, \"w\") as f:\n        json.dump(multi_qas, f)\n    print(f\"Cached arguments saved to {cache_path}\")\n\n    return multi_qas\n\n\n@sgl.function\ndef multi_turns(s, system_prompt, qas):\n    s += system_prompt\n\n    for i, qa in enumerate(qas):\n        s += qa[\"prompt\"]\n        s += sgl.gen(max_tokens=qa[\"new_tokens\"], ignore_eos=True)\n\n\ndef main(args):\n    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)\n\n    multi_qas = gen_arguments(args, tokenizer)\n\n    backend = select_sglang_backend(args)\n\n    tic = time.perf_counter()\n    states = multi_turns.run_batch(\n        multi_qas,\n        temperature=0,\n        backend=backend,\n        num_threads=\"auto\",\n        progress_bar=True,\n    )\n    latency = time.perf_counter() - tic\n\n    print(f\"Latency: {latency:.3f}\")\n\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"multi_turn_system_prompt_chat\",\n            \"backend\": args.backend,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_qa,\n            \"num_turns\": args.turns,\n            \"other\": {\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument(\"--turns\", type=int, default=8)\n    parser.add_argument(\"--num-qa\", type=int, default=128)\n    parser.add_argument(\"--system-prompt-len\", type=int, default=2048)\n    parser.add_argument(\"--len-q\", type=int, default=32)\n    parser.add_argument(\"--len-a\", type=int, default=128)\n    parser.add_argument(\n        \"--tokenizer\", type=str, default=\"meta-llama/Meta-Llama-3-8B-Instruct\"\n    )\n    
parser.add_argument(\"--trust-remote-code\", action=\"store_true\")\n    args = add_common_sglang_args_and_parse(parser)\n\n    print(args)\n    main(args)\n"
  },
  {
    "path": "benchmark/prefill_only/bench_embeddings.py",
    "content": "\"\"\"\nSGLang Embeddings Benchmark Script\n\nThis script benchmarks SGLang's /v1/embeddings API performance using HTTP requests.\n\nFeatures:\n- HTTP-only implementation\n- Uses /v1/embeddings API endpoint directly\n- Configurable RPS, duration, and batch sizes\n- Progress tracking and detailed metrics\n- Poisson and constant request distributions\n\nUsage:\n- Update configuration variables at the top of the file\n- Ensure SGLang server is running on the configured HTTP_URL\n- Run: python bench_embeddings.py\n\"\"\"\n\nimport asyncio\nimport logging\nfrom typing import Optional\n\nfrom transformers import AutoTokenizer\nfrom util import (\n    BenchmarkConfig,\n    generate_text_with_token_count,\n    run_benchmark_main,\n    run_generic_benchmark,\n)\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.INFO, format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\"\n)\nlogger = logging.getLogger(__name__)\n\n###############################################################################\n# CONFIG\n###############################################################################\n# Create benchmark configuration\nconfig = BenchmarkConfig()\nconfig.rps_values = [500]\nconfig.duration_secs_values = [60]\nconfig.num_unique_requests = 100\nconfig.distribution = \"POISSON\"\nconfig.profile = False\nconfig.freeze_gc = True  # Enable GC freeze functionality\n# Profiler output directory - by default uses present working directory (pwd)\n# Uncomment and customize the line below to override the default location:\n# config.profiler_dir = \"/sglang-oss-trace\"\n\n# HTTP Configuration\nHTTP_URL = \"http://localhost:30000/v1/embeddings\"\n\n# Embeddings API Config\nEMBEDDINGS_MODEL_PATH = \"Qwen/Qwen3-Embedding-0.6B\"\nBATCH_SIZE = [1]  # Number of items per request (batch size)\n\n# Configurable input token length\nEMBEDDINGS_INPUT_TOKENS = 500  # Default token length\nMATRYOSHKA_DIMENSIONS: Optional[int] = (\n    None  # Set to None to disable 
matryoshka embeddings\n)\n\n# Load tokenizer once for embeddings text generation\nprint(\"Loading tokenizer for embeddings input generation...\")\nembeddings_tokenizer = AutoTokenizer.from_pretrained(EMBEDDINGS_MODEL_PATH)\n\n# Generate input text with the specified token length using pre-loaded tokenizer\nEMBEDDINGS_INPUT_TEXT = generate_text_with_token_count(\n    EMBEDDINGS_MODEL_PATH,\n    EMBEDDINGS_INPUT_TOKENS,\n    config.special_replicated_token,\n    tokenizer=embeddings_tokenizer,\n)\n\n\n###############################################################################\n# REQUEST GENERATION (in parallel)\n###############################################################################\ndef build_embeddings_request(index: int, item_count: int) -> tuple:\n    \"\"\"Build a single embeddings request.\"\"\"\n    try:\n        # For embeddings, input can be a string or list of strings\n        if item_count == 1:\n            input_data = EMBEDDINGS_INPUT_TEXT\n        else:\n            input_data = [EMBEDDINGS_INPUT_TEXT for _ in range(item_count)]\n        req = {\n            \"input\": input_data,\n            \"model\": EMBEDDINGS_MODEL_PATH,\n            \"dimensions\": MATRYOSHKA_DIMENSIONS,\n        }\n        return (index, req)\n    except Exception as e:\n        logger.error(f\"Error building request {index}: {e}\")\n        return (index, None)\n\n\ndef validate_embeddings_response(response_data: dict) -> bool:\n    \"\"\"Validate embeddings API response.\"\"\"\n    if \"data\" not in response_data:\n        return False\n    # Only check the embedding width when matryoshka dimensions are requested\n    if MATRYOSHKA_DIMENSIONS:\n        return len(response_data[\"data\"][0][\"embedding\"]) == MATRYOSHKA_DIMENSIONS\n    return True\n\n\ndef build_warmup_embeddings_request() -> dict:\n    \"\"\"Build a warmup request for the embeddings API.\"\"\"\n    return {\n        \"input\": EMBEDDINGS_INPUT_TEXT,\n        \"model\": EMBEDDINGS_MODEL_PATH,\n        \"dimensions\": MATRYOSHKA_DIMENSIONS,\n    
}\n\n\n###############################################################################\n# MAIN\n###############################################################################\nasync def run_benchmark(rps, duration_secs, item_count):\n    \"\"\"Run a single embeddings benchmark with the given RPS value.\"\"\"\n    return await run_generic_benchmark(\n        rps=rps,\n        duration_secs=duration_secs,\n        item_count=item_count,\n        config=config,\n        http_url=HTTP_URL,\n        build_request_func=build_embeddings_request,\n        response_validator=validate_embeddings_response,\n        api_name=\"EMBEDDINGS\",\n        request_description=\"embeddings requests\",\n    )\n\n\nasync def main():\n    additional_info = {\n        \"Input text length\": f\"{EMBEDDINGS_INPUT_TOKENS} tokens\",\n        \"Input text preview\": (\n            EMBEDDINGS_INPUT_TEXT[:100] + \"...\"\n            if len(EMBEDDINGS_INPUT_TEXT) > 100\n            else EMBEDDINGS_INPUT_TEXT\n        ),\n    }\n\n    await run_benchmark_main(\n        config,\n        run_benchmark,\n        \"EMBEDDINGS\",\n        HTTP_URL,\n        BATCH_SIZE,\n        additional_info,\n        build_warmup_embeddings_request,\n    )\n\n\nif __name__ == \"__main__\":\n    asyncio.run(main())\n"
  },
  {
    "path": "benchmark/prefill_only/bench_score.py",
    "content": "\"\"\"\nSGLang Scoring Benchmark Script\n\nThis script benchmarks SGLang's scoring API performance using HTTP requests.\n\nCurrent Features:\n- HTTP-only implementation (open source compatible)\n- Uses /v1/score API endpoint directly\n- Single item scoring with batching support\n- Configurable RPS, duration, and batch sizes\n- Progress tracking and detailed metrics\n- Poisson and constant request distributions\n\nUsage:\n- Update configuration variables at the top of the file\n- Ensure SGLang server is running on the configured HTTP_URL\n- Run: python bench_score.py\n- Each request will contain ITEM_COUNT_VALUES items for batch scoring\n\n\"\"\"\n\nimport asyncio\n\nfrom transformers import AutoTokenizer\nfrom util import (\n    BenchmarkConfig,\n    generate_text_with_token_count,\n    run_benchmark_main,\n    run_generic_benchmark,\n)\n\n###############################################################################\n# CONFIG\n###############################################################################\n# Create benchmark configuration\nconfig = BenchmarkConfig()\nconfig.rps_values = [160]\nconfig.duration_secs_values = [60]\nconfig.num_unique_requests = 100\nconfig.distribution = \"POISSON\"\nconfig.profile = False\nconfig.freeze_gc = True  # Enable GC freeze functionality\n# Profiler output directory - by default uses present working directory (pwd)\n# Uncomment and customize the line below to override the default location:\n# config.profiler_dir = \"/sglang-oss-trace\"\n\n# HTTP Configuration\nHTTP_URL = \"http://localhost:30000/v1/score\"  # Use score API directly\n\n# Score API Config\n# ITEM_COUNT_VALUES determines number of items per score request (batch size)\nSCORE_QUERY_TOKENS = 120\nSCORE_ITEM_TOKENS = 180\nSCORE_MODEL_PATH = \"Qwen/Qwen3-0.6B\"\nSCORE_LABEL_TOKEN_IDS = [9454, 2753]  # Yes/No token IDs\nITEM_COUNT_VALUES = [10]  # Number of items per request\n\n# Special token to replicate for precise token 
counting\nSPECIAL_REPLICATED_TOKEN = \"<|im_start|>\"\n\n\n###############################################################################\n# REQUEST GENERATION (in parallel)\n###############################################################################\ndef create_score_request_builder():\n    \"\"\"Create a score request builder function with shared tokenizer.\"\"\"\n    # Load tokenizer once here to verify special token and get precise counts\n    print(\"Loading tokenizer...\")\n    tokenizer = AutoTokenizer.from_pretrained(SCORE_MODEL_PATH)\n\n    # Verify that our special token produces exactly 1 token\n    special_token_count = len(\n        tokenizer.encode(config.special_replicated_token, add_special_tokens=False)\n    )\n    print(\n        f\"Special token '{config.special_replicated_token}' produces \"\n        f\"{special_token_count} token(s)\"\n    )\n\n    def generate_text_with_token_count_local(num_toks):\n        \"\"\"Generate text with precise token count using replicated token.\"\"\"\n        return generate_text_with_token_count(\n            SCORE_MODEL_PATH,\n            num_toks,\n            config.special_replicated_token,\n            tokenizer=tokenizer,\n        )\n\n    def build_score_request(index: int, item_count: int) -> tuple:\n        \"\"\"Build a single score request.\"\"\"\n        try:\n            # Generate query and items for score API\n            query = generate_text_with_token_count_local(SCORE_QUERY_TOKENS)\n            items = [\n                generate_text_with_token_count_local(SCORE_ITEM_TOKENS)\n                for _ in range(item_count)\n            ]\n\n            # Return as dict for score API format\n            score_data = {\n                \"query\": query,\n                \"items\": items,\n                \"label_token_ids\": SCORE_LABEL_TOKEN_IDS,\n                \"model\": SCORE_MODEL_PATH,\n            }\n            return (index, score_data)\n\n        except Exception as e:\n            
print(f\"Error building request {index}: {e}\")\n            return (index, None)\n\n    return build_score_request\n\n\ndef validate_score_response(response_data: dict) -> bool:\n    \"\"\"Validate score API response.\"\"\"\n    return \"scores\" in response_data or \"logprobs\" in response_data\n\n\ndef build_warmup_score_request() -> dict:\n    \"\"\"Build a warmup request for the score API.\"\"\"\n    # Load tokenizer once for warmup generation\n    tokenizer = AutoTokenizer.from_pretrained(SCORE_MODEL_PATH)\n\n    warmup_query = generate_text_with_token_count(\n        SCORE_MODEL_PATH,\n        SCORE_QUERY_TOKENS,\n        config.special_replicated_token,\n        tokenizer=tokenizer,\n    )\n    warmup_items = [\n        generate_text_with_token_count(\n            SCORE_MODEL_PATH,\n            SCORE_ITEM_TOKENS,\n            config.special_replicated_token,\n            tokenizer=tokenizer,\n        )\n        for _ in range(3)\n    ]\n\n    return {\n        \"query\": warmup_query,\n        \"items\": warmup_items,\n        \"label_token_ids\": SCORE_LABEL_TOKEN_IDS,\n        \"model\": SCORE_MODEL_PATH,\n        # Add missing parameters for consistency with the original warmup\n        \"apply_softmax\": True,\n        \"item_first\": False,\n    }\n\n\n###############################################################################\n# MAIN\n###############################################################################\nasync def run_benchmark(rps, duration_secs, item_count):\n    \"\"\"Run a single benchmark with the given RPS value.\"\"\"\n    # Create the request builder function with shared tokenizer\n    build_request_func = create_score_request_builder()\n\n    return await run_generic_benchmark(\n        rps=rps,\n        duration_secs=duration_secs,\n        item_count=item_count,\n        config=config,\n        http_url=HTTP_URL,\n        build_request_func=build_request_func,\n        response_validator=validate_score_response,\n        
api_name=\"SINGLE_ITEM_SCORING\",\n        request_description=\"score requests\",\n    )\n\n\nasync def main():\n    \"\"\"Main function that runs benchmarks for all RPS values.\"\"\"\n    additional_info = {\n        \"Query tokens per request\": SCORE_QUERY_TOKENS,\n        \"Item tokens per item\": SCORE_ITEM_TOKENS,\n    }\n\n    await run_benchmark_main(\n        config,\n        run_benchmark,\n        \"SINGLE_ITEM_SCORING\",\n        HTTP_URL,\n        ITEM_COUNT_VALUES,\n        additional_info,\n        build_warmup_score_request,\n    )\n\n\nif __name__ == \"__main__\":\n    asyncio.run(main())\n"
  },
  {
    "path": "benchmark/prefill_only/util.py",
    "content": "\"\"\"\nCommon utilities for SGLang benchmark scripts.\n\nThis module contains shared code for benchmarking different SGLang APIs\nincluding scoring, embeddings, and other endpoints.\n\"\"\"\n\nimport asyncio\nimport concurrent.futures\nimport json\nimport os\nimport random\nfrom statistics import mean\nfrom typing import Any, Callable, Dict, List, Optional, Tuple\n\nimport aiohttp\nimport numpy as np\nfrom tqdm import tqdm\nfrom transformers import AutoTokenizer\n\n\nclass BenchmarkConfig:\n    \"\"\"Configuration for benchmark parameters.\"\"\"\n\n    def __init__(self):\n        # Common benchmark settings\n        self.server_type = \"HTTP\"\n        self.rps_values = [70]\n        self.duration_secs_values = [60]\n        self.num_unique_requests = 100\n        self.distribution = \"POISSON\"  # Options: \"CONSTANT\", \"POISSON\"\n        self.profile = False\n\n        # Garbage Collection Control\n        self.freeze_gc = True  # Enable/disable garbage collection freezing\n\n        # Profiler configuration\n        self.profiler_dir = (\n            os.getcwd()\n        )  # Default profiler output directory (current working directory)\n\n        # Special token for text generation\n        self.special_replicated_token = \"<|im_start|>\"\n\n\ndef generate_text_with_token_count(\n    model_path: str,\n    num_tokens: int,\n    special_token: str = \"<|im_start|>\",\n    tokenizer: Optional[Any] = None,\n) -> str:\n    \"\"\"\n    Generate text with precise token count using a replicated token.\n\n    Args:\n        model_path: Path to the model for tokenizer\n        num_tokens: Target number of tokens\n        special_token: Token to replicate\n        tokenizer: Optional pre-loaded tokenizer to avoid repeated loading\n\n    Returns:\n        Generated text with approximately the target token count\n    \"\"\"\n    if tokenizer is None:\n        tokenizer = AutoTokenizer.from_pretrained(model_path)\n\n    # Verify token count\n    
special_token_count = len(tokenizer.encode(special_token, add_special_tokens=False))\n\n    if special_token_count == 1:\n        # Simple case: token maps to exactly 1 token\n        return special_token * num_tokens\n    else:\n        print(f\"Special token '{special_token}' produces {special_token_count} tokens\")\n        # Handle case where special token produces multiple tokens\n        repetitions = (num_tokens + special_token_count - 1) // special_token_count\n        text = special_token * repetitions\n\n        # Verify we got the expected token count\n        actual_tokens = len(tokenizer.encode(text, add_special_tokens=False))\n        if actual_tokens < num_tokens:\n            print(f\"Warning: Generated {actual_tokens} tokens, expected {num_tokens}\")\n\n        return text\n\n\ndef setup_profiler(config: BenchmarkConfig, benchmark_name: str) -> None:\n    \"\"\"\n    Set up profiler environment if profiling is enabled.\n\n    Args:\n        config: Benchmark configuration\n        benchmark_name: Name of the benchmark (used in directory path)\n    \"\"\"\n    if config.profile:\n        # Create benchmark-specific subdirectory\n        profiler_path = os.path.join(\n            config.profiler_dir, benchmark_name.lower().replace(\"_\", \"-\")\n        )\n        os.environ[\"SGLANG_TORCH_PROFILER_DIR\"] = profiler_path\n        print(f\"Profiler enabled. 
Output directory: {profiler_path}\")\n    else:\n        print(\"Profiler disabled\")\n\n\ndef prepare_all_requests_parallel(\n    num_requests: int,\n    item_count: int,\n    build_request_func: Callable[[int, int], Tuple[int, Any]],\n    config: BenchmarkConfig,\n    description: str = \"requests\",\n) -> List[Any]:\n    \"\"\"\n    Generic function to generate unique requests in parallel, then reuse them.\n\n    Args:\n        num_requests: Total number of requests needed\n        item_count: Number of items per request (batch size)\n        build_request_func: Function that takes (index, item_count) and returns (index, request_data)\n        config: Benchmark configuration\n        description: Description for progress bars\n\n    Returns:\n        List of request data objects\n    \"\"\"\n\n    def build_request_wrapper(index):\n        \"\"\"Wrapper to call the provided build_request_func.\"\"\"\n        try:\n            return build_request_func(index, item_count)\n        except Exception as e:\n            print(f\"Error building request {index}: {e}\")\n            return (index, None)\n\n    # Generate only the unique requests\n    unique_requests = [None] * config.num_unique_requests\n    max_workers = min(8, os.cpu_count() or 1)  # Limit to 8 threads max\n\n    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n        futures = []\n        for i in tqdm(\n            range(config.num_unique_requests),\n            desc=f\"Submitting {description} generation tasks\",\n        ):\n            future = executor.submit(build_request_wrapper, i)\n            futures.append(future)\n\n        # Collect results as they complete\n        for f in tqdm(\n            concurrent.futures.as_completed(futures),\n            desc=f\"Building unique {description}\",\n            total=config.num_unique_requests,\n        ):\n            try:\n                index, req_data = f.result()\n                if req_data is not None:\n     
               unique_requests[index] = req_data\n                else:\n                    print(f\"Failed to build request {index}\")\n            except Exception as e:\n                print(f\"Error processing request result: {e}\")\n\n    # Check if we have any valid requests\n    valid_requests = [req for req in unique_requests if req is not None]\n    if not valid_requests:\n        raise RuntimeError(\"Failed to generate any valid requests\")\n\n    print(\n        f\"Successfully generated {len(valid_requests)} out of \"\n        f\"{config.num_unique_requests} unique {description}\"\n    )\n\n    # Create the full request list by cycling through unique requests\n    print(\n        f\"Reusing {len(valid_requests)} unique {description} to create \"\n        f\"{num_requests} total requests...\"\n    )\n    all_requests = []\n    for i in tqdm(range(num_requests), desc=f\"Reusing {description}\"):\n        unique_index = i % len(valid_requests)\n        all_requests.append(valid_requests[unique_index])\n\n    print(f\"All {description} prepared.\\n\")\n    return all_requests\n\n\nasync def sleep_with_distribution(distribution: str, rps: float) -> None:\n    \"\"\"\n    Sleep according to the specified distribution pattern.\n\n    Args:\n        distribution: \"CONSTANT\" or \"POISSON\"\n        rps: Requests per second rate\n    \"\"\"\n    if distribution == \"CONSTANT\":\n        interval = 1 / rps\n        await asyncio.sleep(interval)\n    elif distribution == \"POISSON\":\n        # For Poisson process, inter-arrival times follow exponential distribution\n        interval = random.expovariate(rps)\n        await asyncio.sleep(interval)\n    else:\n        raise ValueError(\n            f\"Unknown distribution: {distribution}. 
Use 'CONSTANT' or 'POISSON'.\"\n        )\n\n\ndef build_http_request_json(request_data: Any) -> str:\n    \"\"\"\n    Generic function to build HTTP request JSON.\n\n    Args:\n        request_data: The data to serialize to JSON\n\n    Returns:\n        JSON string representation of the request data\n    \"\"\"\n    return json.dumps(request_data)\n\n\nasync def make_http_call(\n    session: aiohttp.ClientSession,\n    request_data: Any,\n    request_id: int,\n    results_queue: asyncio.Queue,\n    http_url: str,\n    response_validator: Callable[[Dict[str, Any]], bool],\n    api_name: str = \"API\",\n) -> None:\n    \"\"\"\n    Generic HTTP call function for API requests.\n\n    Args:\n        session: aiohttp client session\n        request_data: Data to send in the request\n        request_id: Unique identifier for this request\n        results_queue: Queue to put results\n        http_url: URL to send the request to\n        response_validator: Function to validate the response JSON\n        api_name: Name of the API for error messages\n    \"\"\"\n    try:\n        start_time = asyncio.get_running_loop().time()\n\n        request_json = build_http_request_json(request_data)\n        headers = {\"Content-Type\": \"application/json\"}\n\n        async with session.post(http_url, data=request_json, headers=headers) as resp:\n            resp_text = await resp.text()\n\n            if resp.status != 200:\n                print(\n                    f\"[HTTP] {api_name} Request {request_id} failed with status \"\n                    f\"{resp.status}: {resp_text}\"\n                )\n                completion_time = asyncio.get_running_loop().time()\n                await results_queue.put((request_id, 0, False, completion_time))\n                return\n\n            # Parse and validate response\n            try:\n                response_data = json.loads(resp_text)\n                success = response_validator(response_data)\n                if not success:\n 
                   print(\n                        f\"[HTTP] {api_name} Request {request_id} failed response validation\"\n                    )\n            except json.JSONDecodeError:\n                print(\n                    f\"[HTTP] {api_name} Request {request_id} failed to parse JSON response\"\n                )\n                success = False\n\n        completion_time = asyncio.get_running_loop().time()\n        elapsed_time = (completion_time - start_time) * 1000\n        await results_queue.put((request_id, elapsed_time, success, completion_time))\n\n    except Exception as e:\n        print(f\"[HTTP] {api_name} Error for request {request_id}: {e}\")\n        completion_time = asyncio.get_running_loop().time()\n        await results_queue.put((request_id, 0, False, completion_time))\n\n\nasync def send_profile_request(\n    profile_text: str, http_url: str, session: Optional[aiohttp.ClientSession] = None\n) -> None:\n    \"\"\"\n    Send a profile request (START_PROFILE or STOP_PROFILE) and wait for completion.\n\n    Args:\n        profile_text: \"START_PROFILE\" or \"STOP_PROFILE\"\n        http_url: Base HTTP URL (will derive profile endpoints from this)\n        session: Optional aiohttp session to use\n    \"\"\"\n    try:\n        if session:\n            print(f\"Sending {profile_text} request via HTTP...\")\n\n            # Determine the correct endpoint\n            if \"/v1/\" in http_url:\n                base_url = http_url.rsplit(\"/v1/\", 1)[0]  # Remove /v1/xxx\n            else:\n                base_url = http_url.rsplit(\"/\", 1)[0]  # Remove last path component\n\n            if profile_text == \"START_PROFILE\":\n                endpoint_url = f\"{base_url}/start_profile\"\n            elif profile_text == \"STOP_PROFILE\":\n                endpoint_url = f\"{base_url}/stop_profile\"\n            else:\n                print(f\"Unknown profile request: {profile_text}\")\n                return\n\n            headers = 
{\"Content-Type\": \"application/json\"}\n\n            async with session.post(endpoint_url, headers=headers) as resp:\n                resp_text = await resp.text()\n                if resp.status == 200:\n                    print(f\"{profile_text} request completed\")\n                else:\n                    print(\n                        f\"{profile_text} request failed with status \"\n                        f\"{resp.status}: {resp_text}\"\n                    )\n        else:\n            print(f\"Cannot send {profile_text} request - missing session\")\n\n    except Exception as e:\n        print(f\"Error sending {profile_text} request: {e}\")\n\n\nasync def call_freeze_gc_http(session: aiohttp.ClientSession, http_url: str) -> None:\n    \"\"\"\n    Call the /freeze_gc HTTP endpoint.\n\n    Args:\n        session: aiohttp client session\n        http_url: Base HTTP URL to derive the freeze_gc endpoint from\n    \"\"\"\n    try:\n        # Derive freeze_gc endpoint from the API URL\n        if \"/v1/\" in http_url:\n            freeze_gc_url = http_url.rsplit(\"/v1/\", 1)[0] + \"/freeze_gc\"\n        else:\n            freeze_gc_url = http_url.rsplit(\"/\", 1)[0] + \"/freeze_gc\"\n\n        print(f\"Calling freeze_gc endpoint: {freeze_gc_url}\")\n\n        async with session.post(freeze_gc_url) as resp:\n            if resp.status == 200:\n                print(\"freeze_gc called successfully\")\n            else:\n                resp_text = await resp.text()\n                print(f\"freeze_gc failed with status {resp.status}: {resp_text}\")\n\n    except Exception as e:\n        print(f\"Failed to call freeze_gc: {e}\")\n\n\nasync def send_warmup_requests(\n    session: aiohttp.ClientSession,\n    http_url: str,\n    build_warmup_request_func: Callable[[], Any],\n    num_warmup: int = 3,\n) -> None:\n    \"\"\"\n    Send warmup requests to HTTP server.\n\n    Args:\n        session: aiohttp client session\n        http_url: URL to send warmup requests 
to\n        build_warmup_request_func: Function that returns a warmup request object\n        num_warmup: Number of warmup requests to send\n    \"\"\"\n    print(f\"Sending {num_warmup} HTTP warmup requests...\")\n\n    for i in range(num_warmup):\n        try:\n            warmup_data = build_warmup_request_func()\n            request_json = build_http_request_json(warmup_data)\n            headers = {\"Content-Type\": \"application/json\"}\n\n            async with session.post(\n                http_url, data=request_json, headers=headers\n            ) as resp:\n                if resp.status == 200:\n                    print(f\"Warmup request {i+1}/{num_warmup} completed successfully\")\n                else:\n                    print(\n                        f\"Warmup request {i+1}/{num_warmup} failed with status {resp.status}\"\n                    )\n\n        except Exception as e:\n            print(f\"Warmup request {i+1}/{num_warmup} failed with error: {e}\")\n\n    print(\"HTTP warmup requests completed\")\n\n\nasync def perform_global_warmup_and_freeze(\n    config: BenchmarkConfig,\n    http_url: str,\n    build_warmup_request_func: Callable[[], Any],\n) -> None:\n    \"\"\"\n    Perform warmup and optionally GC freeze operations once before all benchmark runs.\n\n    Args:\n        config: Benchmark configuration\n        http_url: URL for API requests\n        build_warmup_request_func: Function that returns a warmup request object\n    \"\"\"\n    print(\"=\" * 80)\n    print(f\"PERFORMING GLOBAL WARMUP{' AND GC FREEZE' if config.freeze_gc else ''}\")\n    print(\"=\" * 80)\n\n    print(f\"Performing HTTP warmup{' and GC freeze' if config.freeze_gc else ''}...\")\n    async with aiohttp.ClientSession() as session:\n        await send_warmup_requests(session, http_url, build_warmup_request_func)\n        if config.freeze_gc:\n            await call_freeze_gc_http(session, http_url)\n        print(\n            f\"HTTP warmup{' and GC freeze' if 
config.freeze_gc else ''} completed successfully.\"\n        )\n\n    print(\n        f\"Global warmup{' and GC freeze' if config.freeze_gc else ''} operations completed.\"\n    )\n    print(\"=\" * 80)\n\n\nasync def process_results(\n    results_queue: asyncio.Queue,\n    num_requests: int,\n    send_duration: float,\n    total_duration: float,\n    rps: int,\n    duration_secs: int,\n    item_count: int,\n    test_start_time: float,\n    config: BenchmarkConfig,\n    http_mode: str = \"UNKNOWN\",\n) -> List[Dict[str, Any]]:\n    \"\"\"\n    Process benchmark results and group them by minute intervals.\n\n    Args:\n        results_queue: Queue containing result tuples\n        num_requests: Total number of requests sent\n        send_duration: Time taken to send all requests\n        total_duration: Total time for all requests to complete\n        rps: Target requests per second\n        duration_secs: Test duration in seconds\n        item_count: Number of items per request\n        test_start_time: Start time of the test\n        config: Benchmark configuration\n        http_mode: Description of the HTTP mode/API being tested\n\n    Returns:\n        List of dictionaries containing minute-by-minute results\n    \"\"\"\n    all_results = []\n\n    # Collect all results\n    for _ in range(num_requests):\n        result = await results_queue.get()\n        request_id, elapsed_time, success, completion_time = result\n        all_results.append(\n            {\n                \"request_id\": request_id,\n                \"elapsed_time\": elapsed_time,\n                \"success\": success,\n                \"completion_time\": completion_time,\n            }\n        )\n\n    # Group results by minute intervals\n    minute_results = []\n    num_minutes = int(duration_secs // 60) + (1 if duration_secs % 60 > 0 else 0)\n\n    for minute in range(num_minutes):\n        minute_start = test_start_time + (minute * 60)\n        minute_end = test_start_time + ((minute + 
1) * 60)\n\n        # Filter results that completed in this minute\n        minute_data = [\n            r for r in all_results if minute_start <= r[\"completion_time\"] < minute_end\n        ]\n\n        response_times = [r[\"elapsed_time\"] for r in minute_data if r[\"success\"]]\n        successful_requests = len([r for r in minute_data if r[\"success\"]])\n        failed_requests = len([r for r in minute_data if not r[\"success\"]])\n\n        avg_response_time = mean(response_times) if response_times else 0\n\n        # Calculate percentiles using numpy\n        if response_times:\n            p50 = np.percentile(response_times, 50)\n            p90 = np.percentile(response_times, 90)\n            p99 = np.percentile(response_times, 99)\n        else:\n            p50 = p90 = p99 = 0\n\n        minute_result = {\n            \"test_duration_secs\": duration_secs,\n            \"minute_interval\": minute + 1,\n            \"target_rps\": rps,\n            \"item_count\": item_count,\n            \"server_type\": config.server_type,\n            \"distribution\": config.distribution,\n            \"unique_requests\": config.num_unique_requests,\n            \"total_requests\": len(minute_data),\n            \"successful_requests\": successful_requests,\n            \"failed_requests\": failed_requests,\n            \"send_duration_secs\": send_duration,\n            \"total_duration_secs\": total_duration,\n            \"avg_response_time_ms\": avg_response_time,\n            \"p50_response_time_ms\": p50,\n            \"p90_response_time_ms\": p90,\n            \"p99_response_time_ms\": p99,\n        }\n\n        minute_results.append(minute_result)\n\n        print(\n            f\"\\nMinute {minute + 1} Summary for RPS {rps}, \"\n            f\"Duration {duration_secs}s, Item Count {item_count}:\"\n        )\n        print(f\"  Requests completed in minute: {len(minute_data)}\")\n        print(f\"  Successful requests:   {successful_requests}\")\n        
print(f\"  Failed requests:       {failed_requests}\")\n        print(f\"  Average response time: {avg_response_time:.2f} ms\")\n        print(f\"  P50 response time:     {p50:.2f} ms\")\n        print(f\"  P90 response time:     {p90:.2f} ms\")\n        print(f\"  P99 response time:     {p99:.2f} ms\")\n\n    # Print overall summary\n    all_response_times = [r[\"elapsed_time\"] for r in all_results if r[\"success\"]]\n    total_successful = len([r for r in all_results if r[\"success\"]])\n    total_failed = len([r for r in all_results if not r[\"success\"]])\n\n    overall_avg = mean(all_response_times) if all_response_times else 0\n    if all_response_times:\n        overall_p50 = np.percentile(all_response_times, 50)\n        overall_p90 = np.percentile(all_response_times, 90)\n        overall_p99 = np.percentile(all_response_times, 99)\n    else:\n        overall_p50 = overall_p90 = overall_p99 = 0\n\n    print(\n        f\"\\nOverall Summary for RPS {rps}, Duration {duration_secs}s, \"\n        f\"Item Count {item_count}:\"\n    )\n    print(f\"  Test duration:         {duration_secs} seconds\")\n    print(f\"  Server type:           {config.server_type}\")\n    print(f\"  HTTP mode:             {http_mode}\")\n    print(f\"  Target RPS:            {rps}\")\n    print(f\"  Achieved RPS:          {len(all_results) / total_duration:.2f}\")\n    print(f\"  Item count:            {item_count}\")\n    print(f\"  Distribution:          {config.distribution}\")\n    print(f\"  Unique requests generated: {config.num_unique_requests}\")\n    print(f\"  Total requests sent:   {num_requests}\")\n    print(f\"  Successful requests:   {total_successful}\")\n    print(f\"  Failed requests:       {total_failed}\")\n    print(f\"  Time to send all requests: {send_duration:.2f} seconds\")\n    print(f\"  Time for all requests to complete: {total_duration:.2f} seconds\")\n    print(f\"  Average response time: {overall_avg:.2f} ms\")\n    print(f\"  P50 response time:     
{overall_p50:.2f} ms\")\n    print(f\"  P90 response time:     {overall_p90:.2f} ms\")\n    print(f\"  P99 response time:     {overall_p99:.2f} ms\\n\")\n\n    return minute_results\n\n\ndef print_csv_results(all_results: List[Dict[str, Any]]) -> None:\n    \"\"\"\n    Print benchmark results in CSV format.\n\n    Args:\n        all_results: List of result dictionaries from process_results\n    \"\"\"\n    print(\"\\n\" + \"=\" * 80)\n    print(\"FINAL CSV RESULTS:\")\n    print(\"=\" * 80)\n\n    # CSV Header\n    headers = [\n        \"test_duration_secs\",\n        \"minute_interval\",\n        \"target_rps\",\n        \"item_count\",\n        \"server_type\",\n        \"distribution\",\n        \"unique_requests\",\n        \"total_requests\",\n        \"successful_requests\",\n        \"failed_requests\",\n        \"send_duration_secs\",\n        \"total_duration_secs\",\n        \"avg_response_time_ms\",\n        \"p50_response_time_ms\",\n        \"p90_response_time_ms\",\n        \"p99_response_time_ms\",\n    ]\n    print(\",\".join(headers))\n\n    # CSV Data\n    for result in all_results:\n        row = [\n            result[\"test_duration_secs\"],\n            result[\"minute_interval\"],\n            result[\"target_rps\"],\n            result[\"item_count\"],\n            result[\"server_type\"],\n            result[\"distribution\"],\n            result[\"unique_requests\"],\n            result[\"total_requests\"],\n            result[\"successful_requests\"],\n            result[\"failed_requests\"],\n            f\"{result['send_duration_secs']:.2f}\",\n            f\"{result['total_duration_secs']:.2f}\",\n            f\"{result['avg_response_time_ms']:.2f}\",\n            f\"{result['p50_response_time_ms']:.2f}\",\n            f\"{result['p90_response_time_ms']:.2f}\",\n            f\"{result['p99_response_time_ms']:.2f}\",\n        ]\n        print(\",\".join(map(str, row)))\n\n\nasync def run_benchmark_main(\n    config: BenchmarkConfig,\n    
run_single_benchmark_func,\n    benchmark_name: str,\n    http_url: str,\n    item_count_values: List[int],\n    additional_info: Optional[Dict[str, Any]] = None,\n    build_warmup_request_func: Optional[Callable[[], Any]] = None,\n) -> None:\n    \"\"\"\n    Main benchmark orchestration function.\n\n    Args:\n        config: Benchmark configuration\n        run_single_benchmark_func: Async function to run a single benchmark\n        benchmark_name: Name of the benchmark (e.g., \"SCORING\", \"EMBEDDINGS\")\n        http_url: URL of the API endpoint\n        item_count_values: List of item counts to test\n        additional_info: Additional information to print in the header\n        build_warmup_request_func: Optional function to build warmup requests\n    \"\"\"\n    total_combinations = (\n        len(config.duration_secs_values)\n        * len(config.rps_values)\n        * len(item_count_values)\n    )\n\n    print(\n        f\"Running benchmarks for {len(config.duration_secs_values)} duration \"\n        f\"values, {len(config.rps_values)} RPS values, and \"\n        f\"{len(item_count_values)} item count values = \"\n        f\"{total_combinations} total combinations\"\n    )\n    print(f\"Server Type: {config.server_type}\")\n    print(f\"HTTP Mode: {benchmark_name}\")\n    print(f\"API URL: {http_url}\")\n\n    if additional_info:\n        for key, value in additional_info.items():\n            print(f\"{key}: {value}\")\n\n    print(f\"Items per request (batch size): {item_count_values}\")\n    print(f\"Profiling Enabled: {config.profile}\")\n    print(f\"Duration values: {config.duration_secs_values}\")\n    print(f\"RPS values: {config.rps_values}\")\n    print(f\"Item count values: {item_count_values}\")\n    print(\"=\" * 80)\n\n    # Set up profiler environment\n    setup_profiler(config, benchmark_name)\n\n    # Perform global warmup and GC freeze operations if warmup function is provided\n    if build_warmup_request_func is not None:\n        await 
perform_global_warmup_and_freeze(\n            config, http_url, build_warmup_request_func\n        )\n\n    all_results = []\n\n    for duration_secs in config.duration_secs_values:\n        for rps in config.rps_values:\n            for item_count in item_count_values:\n                result = await run_single_benchmark_func(rps, duration_secs, item_count)\n                all_results.extend(result)  # Extend with minute results\n\n    print_csv_results(all_results)\n\n\nasync def run_generic_benchmark(\n    rps: int,\n    duration_secs: int,\n    item_count: int,\n    config: BenchmarkConfig,\n    http_url: str,\n    build_request_func: Callable[[int, int], Tuple[int, Any]],\n    response_validator: Callable[[Dict[str, Any]], bool],\n    api_name: str,\n    request_description: str = \"requests\",\n) -> List[Dict[str, Any]]:\n    \"\"\"\n    Generic benchmark runner that can be used for different APIs.\n\n    Args:\n        rps: Requests per second\n        duration_secs: Duration of the test in seconds\n        item_count: Number of items per request (batch size)\n        config: Benchmark configuration\n        http_url: URL of the API endpoint\n        build_request_func: Function to build individual requests\n        response_validator: Function to validate API responses\n        api_name: Name of the API for logging\n        request_description: Description for progress bars\n\n    Returns:\n        List of dictionaries containing minute-by-minute results\n    \"\"\"\n    num_requests = int(rps * duration_secs)\n    print(\n        f\"Starting benchmark with RPS={rps}, Duration={duration_secs}s, \"\n        f\"Item Count={item_count}, num_requests={num_requests}\"\n    )\n    print(f\"Server Type: {config.server_type}\")\n    print(f\"HTTP Mode: {api_name}\")\n    print(f\"Profiling Enabled: {config.profile}\")\n\n    # Build requests in parallel (unmeasured)\n    all_requests = prepare_all_requests_parallel(\n        num_requests, item_count, 
build_request_func, config, request_description\n    )\n\n    results_queue = asyncio.Queue()\n    tasks = []\n\n    # Track timing for sending requests\n    send_start_time = asyncio.get_running_loop().time()\n\n    # HTTP implementation\n    async with aiohttp.ClientSession(\n        timeout=aiohttp.ClientTimeout(total=300)\n    ) as session:\n\n        # Send START_PROFILE if profiling is enabled\n        if config.profile:\n            await send_profile_request(\"START_PROFILE\", http_url, session=session)\n\n        # Add progress bar for sending requests\n        with tqdm(\n            total=len(all_requests),\n            desc=f\"Sending HTTP {request_description} at {rps} RPS\",\n            unit=\"req\",\n        ) as pbar:\n            for i, request_data in enumerate(all_requests):\n                request_id = i + 1\n                tasks.append(\n                    asyncio.create_task(\n                        make_http_call(\n                            session,\n                            request_data,\n                            request_id,\n                            results_queue,\n                            http_url,\n                            response_validator,\n                            api_name,\n                        )\n                    )\n                )\n\n                # Update progress bar\n                pbar.update(1)\n\n                # Throttle based on distribution\n                if i < len(all_requests) - 1:\n                    await sleep_with_distribution(config.distribution, rps)\n\n        send_end_time = asyncio.get_running_loop().time()\n        send_duration = send_end_time - send_start_time\n\n        # Wait for all requests to complete with progress tracking\n        print(f\"Waiting for {len(tasks)} HTTP {request_description} to complete...\")\n        with tqdm(\n            total=len(tasks), desc=f\"Completing HTTP {request_description}\", unit=\"req\"\n        ) as completion_pbar:\n            
completed_tasks = []\n            for task in asyncio.as_completed(tasks):\n                await task\n                completed_tasks.append(task)\n                completion_pbar.update(1)\n\n        # Send STOP_PROFILE if profiling is enabled\n        if config.profile:\n            await send_profile_request(\"STOP_PROFILE\", http_url, session=session)\n\n    completion_end_time = asyncio.get_running_loop().time()\n    total_duration = completion_end_time - send_start_time\n\n    return await process_results(\n        results_queue,\n        num_requests,\n        send_duration,\n        total_duration,\n        rps,\n        duration_secs,\n        item_count,\n        send_start_time,\n        config,\n        api_name,\n    )\n"
  },
  {
    "path": "benchmark/react/README.md",
    "content": "## Run benchmark\n\nNOTE: This is an implementation for replaying a given trace for throughput/latency benchmark purposes. It is not an actual ReAct agent implementation.\n\n### Benchmark sglang\n```\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\n```\npython3 bench_sglang.py --num-questions 100\n```\n\n\n### Benchmark vllm\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000\n```\n\n```\npython3 bench_other.py --num-questions 100 --backend vllm\n```\n\n\n### Benchmark guidance\n```\npython3 bench_other.py --num-questions 100 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n\n### Benchmark lmql\n\n```\npython3 bench_other.py --num-questions 100 --backend lmql --parallel 1\n```\n"
  },
  {
    "path": "benchmark/react/bench_other.py",
    "content": "import argparse\nimport json\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\n\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import dump_state_text, read_jsonl\n\n\ndef get_prompt(question):\n    prompt = (\n        \"\"\"Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:\n(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.\n(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage.\n(3) Finish[answer], which returns the answer and finishes the task.\nHere are some examples.\nQuestion: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?\nThought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area.\nAction 1: Search[Colorado orogeny]\nObservation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas.\nThought 2: It does not mention the eastern sector. So I need to look up eastern sector.\nAction 2: Lookup[eastern sector]\nObservation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny.\nThought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range.\nAction 3: Search[High Plains]\nObservation 3: High Plains refers to one of two distinct land regions:\nThought 4: I need to instead search High Plains (United States).\nAction 4: Search[High Plains (United States)]\nObservation 4: The High Plains are a subregion of the Great Plains. 
From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3]\nThought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft.\nAction 5: Finish[1,800 to 7,000 ft]\nQuestion: Musician and satirist Allie Goertz wrote a song about the \"The Simpsons\" character Milhouse, who Matt Groening named after who?\nThought 1: The question simplifies to \"The Simpsons\" character Milhouse is named after who. I only need to search Milhouse and find who it is named after.\nAction 1: Search[Milhouse]\nObservation 1: Milhouse Mussolini Van Houten is a recurring character in the Fox animated television series The Simpsons voiced by Pamela Hayden and created by Matt Groening.\nThought 2: The paragraph does not tell who Milhouse is named after, maybe I can look up \"named after\".\nAction 2: Lookup[named after]\nObservation 2: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous.\nThought 3: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon.\nAction 3: Finish[Richard Nixon]\nQuestion: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture?\nThought 1: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups.\nAction 1: Search[Adam Clayton Powell]\nObservation 1: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. 
State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito'].\nThought 2: To find the documentary, I can search Adam Clayton Powell (film).\nAction 2: Search[Adam Clayton Powell (film)]\nObservation 2: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg.\nThe film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience.\nThought 3: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture.\nAction 3: Finish[The Saimaa Gesture]\nQuestion: What profession does Nicholas Ray and Elia Kazan have in common?\nThought 1: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common.\nAction 1: Search[Nicholas Ray]\nObservation 1: Nicholas Ray (born Raymond Nicholas Kienzle Jr., August 7, 1911 – June 16, 1979) was an American film director, screenwriter, and actor best known for the 1955 film Rebel Without a Cause.\nThought 2: Professions of Nicholas Ray are director, screenwriter, and actor. I need to search Elia Kazan next and find his professions.\nAction 2: Search[Elia Kazan]\nObservation 2: Elia Kazan was an American film and theatre director, producer, screenwriter and actor.\nThought 3: Professions of Elia Kazan are director, producer, screenwriter, and actor. 
So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor.\nAction 3: Finish[director, screenwriter, actor]\nQuestion: Which magazine was started first Arthur's Magazine or First for Women?\nThought 1: I need to search Arthur's Magazine and First for Women, and find which was started first.\nAction 1: Search[Arthur's Magazine]\nObservation 1: Arthur's Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century.\nThought 2: Arthur's Magazine was started in 1844. I need to search First for Women next.\nAction 2: Search[First for Women]\nObservation 2: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989.\nThought 3: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first.\nAction 3: Finish[Arthur's Magazine]\nQuestion: Were Pavel Urysohn and Leonid Levin known for the same type of work?\nThought 1: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same.\nAction 1: Search[Pavel Urysohn]\nObservation 1: Pavel Samuilovich Urysohn (February 3, 1898 – August 17, 1924) was a Soviet mathematician who is best known for his contributions in dimension theory.\nThought 2: Pavel Urysohn is a mathematician. I need to search Leonid Levin next and find its type of work.\nAction 2: Search[Leonid Levin]\nObservation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.\nThought 3: Leonid Levin is a mathematician and computer scientist. 
So Pavel Urysohn and Leonid Levin have the same type of work.\nAction 3: Finish[yes]\n\"\"\"\n        + question\n    )\n    return prompt\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)[: args.num_questions]\n    arguments = [{\"question\": k, \"triplets\": v} for l in lines for k, v in l.items()]\n\n    states = []\n\n    # Select backend\n    call_generate = get_call_generate(args)\n\n    def run_single_agent(argument):\n        question = argument[\"question\"]\n        triplets = argument[\"triplets\"]\n        prompt = get_prompt(question)\n        for i in range(1, len(triplets) + 2):\n            prompt += \"Thought \" + str(i) + \":\"\n            states.append(prompt)\n            answer = call_generate(\n                prompt, max_tokens=200, temperature=0, stop=\"Observation\"\n            )\n            if i > len(triplets):\n                break\n            prompt += (\n                triplets[i - 1][\"thought\"]\n                + \"\\nAction \"\n                + str(i)\n                + \":\"\n                + triplets[i - 1][\"action\"]\n                + \"\\nObservation \"\n                + str(i)\n                + \":\"\n                + triplets[i - 1][\"observation\"]\n                + \"\\n\"\n            )\n\n            states.append(answer)\n\n    async def run_single_agent_async(argument):\n        question = argument[\"question\"]\n        triplets = argument[\"triplets\"]\n        prompt = get_prompt(question)\n        for i in range(1, len(triplets) + 2):\n            prompt += \"Thought \" + str(i) + \":\"\n            states.append(prompt)\n            answer = await call_generate(\n                prompt, max_tokens=200, temperature=0, stop=\"Observation\", max_len=4096\n            )\n            if i > len(triplets):\n                break\n            prompt += (\n                triplets[i - 1][\"thought\"]\n                + \"\\nAction \"\n                + str(i)\n                + \":\"\n              
  + triplets[i - 1][\"action\"]\n                + \"\\nObservation \"\n                + str(i)\n                + \":\"\n                + triplets[i - 1][\"observation\"]\n                + \"\\n\"\n            )\n\n            states.append(answer)\n\n    tic = time.perf_counter()\n\n    if args.backend != \"lmql\":\n        if args.parallel == 1:\n            for arg in tqdm(arguments):\n                run_single_agent(arg)\n        else:\n            with ThreadPoolExecutor(args.parallel) as executor:\n                list(\n                    tqdm(\n                        executor.map(run_single_agent, arguments), total=len(arguments)\n                    )\n                )\n\n    else:\n        import asyncio\n\n        loop = asyncio.get_event_loop()\n        batches = [\n            [] for _ in range((len(arguments) + args.parallel - 1) // args.parallel)\n        ]\n        for i, arg in enumerate(arguments):\n            batches[i // args.parallel].append(arg)\n        for bt in tqdm(batches):\n            tasks = [run_single_agent_async(arg) for arg in bt]\n            loop.run_until_complete(asyncio.gather(*tasks))\n\n    latency = time.perf_counter() - tic\n\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"ReAct Agents\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": len(arguments),\n            \"other\": {\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"hotpotqa_100.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=10)\n    args = 
add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/react/bench_sglang.py",
    "content": "import argparse\nimport json\nimport time\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text, read_jsonl\n\n\n@sgl.function\ndef webthink(s, question, triplets):\n    s += (\n        \"\"\"Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:\n(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.\n(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage.\n(3) Finish[answer], which returns the answer and finishes the task.\nHere are some examples.\nQuestion: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?\nThought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area.\nAction 1: Search[Colorado orogeny]\nObservation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas.\nThought 2: It does not mention the eastern sector. So I need to look up eastern sector.\nAction 2: Lookup[eastern sector]\nObservation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny.\nThought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range.\nAction 3: Search[High Plains]\nObservation 3: High Plains refers to one of two distinct land regions:\nThought 4: I need to instead search High Plains (United States).\nAction 4: Search[High Plains (United States)]\nObservation 4: The High Plains are a subregion of the Great Plains. 
From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3]\nThought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft.\nAction 5: Finish[1,800 to 7,000 ft]\nQuestion: Musician and satirist Allie Goertz wrote a song about the \"The Simpsons\" character Milhouse, who Matt Groening named after who?\nThought 1: The question simplifies to \"The Simpsons\" character Milhouse is named after who. I only need to search Milhouse and find who it is named after.\nAction 1: Search[Milhouse]\nObservation 1: Milhouse Mussolini Van Houten is a recurring character in the Fox animated television series The Simpsons voiced by Pamela Hayden and created by Matt Groening.\nThought 2: The paragraph does not tell who Milhouse is named after, maybe I can look up \"named after\".\nAction 2: Lookup[named after]\nObservation 2: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous.\nThought 3: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon.\nAction 3: Finish[Richard Nixon]\nQuestion: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture?\nThought 1: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups.\nAction 1: Search[Adam Clayton Powell]\nObservation 1: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. 
State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito'].\nThought 2: To find the documentary, I can search Adam Clayton Powell (film).\nAction 2: Search[Adam Clayton Powell (film)]\nObservation 2: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg.\nThe film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience.\nThought 3: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture.\nAction 3: Finish[The Saimaa Gesture]\nQuestion: What profession does Nicholas Ray and Elia Kazan have in common?\nThought 1: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common.\nAction 1: Search[Nicholas Ray]\nObservation 1: Nicholas Ray (born Raymond Nicholas Kienzle Jr., August 7, 1911 – June 16, 1979) was an American film director, screenwriter, and actor best known for the 1955 film Rebel Without a Cause.\nThought 2: Professions of Nicholas Ray are director, screenwriter, and actor. I need to search Elia Kazan next and find his professions.\nAction 2: Search[Elia Kazan]\nObservation 2: Elia Kazan was an American film and theatre director, producer, screenwriter and actor.\nThought 3: Professions of Elia Kazan are director, producer, screenwriter, and actor. 
So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor.\nAction 3: Finish[director, screenwriter, actor]\nQuestion: Which magazine was started first Arthur's Magazine or First for Women?\nThought 1: I need to search Arthur's Magazine and First for Women, and find which was started first.\nAction 1: Search[Arthur's Magazine]\nObservation 1: Arthur's Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century.\nThought 2: Arthur's Magazine was started in 1844. I need to search First for Women next.\nAction 2: Search[First for Women]\nObservation 2: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989.\nThought 3: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first.\nAction 3: Finish[Arthur's Magazine]\nQuestion: Were Pavel Urysohn and Leonid Levin known for the same type of work?\nThought 1: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same.\nAction 1: Search[Pavel Urysohn]\nObservation 1: Pavel Samuilovich Urysohn (February 3, 1898 – August 17, 1924) was a Soviet mathematician who is best known for his contributions in dimension theory.\nThought 2: Pavel Urysohn is a mathematician. I need to search Leonid Levin next and find its type of work.\nAction 2: Search[Leonid Levin]\nObservation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.\nThought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.\nAction 3: Finish[yes]\n\"\"\"\n        + question\n    )\n    for i in range(1, len(triplets) + 2):\n        s += \"Thought \" + str(i) + \":\"\n        # NOTE: This is an implementation for replaying a given trace for benchmark purposes.
It is not an actual ReAct agent implementation.\n        ss = s.fork(1)\n        ss[0] += sgl.gen(name=\"thought_action\", max_tokens=200, stop=\"Observation\")\n        ss.join()\n        # To verify output correctness, collect this value:\n        # print(ss[0][\"thought_action\"])\n        if i > len(triplets):\n            break\n        s += (\n            triplets[i - 1][\"thought\"]\n            + \"\\nAction \"\n            + str(i)\n            + \":\"\n            + triplets[i - 1][\"action\"]\n            + \"\\nObservation \"\n            + str(i)\n            + \":\"\n            + triplets[i - 1][\"observation\"]\n            + \"\\n\"\n        )\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)[: args.num_questions]\n    arguments = [{\"question\": k, \"triplets\": v} for l in lines for k, v in l.items()]\n\n    # Select backend\n    backend = select_sglang_backend(args)\n    sgl.set_default_backend(backend)\n\n    tic = time.perf_counter()\n    states = webthink.run_batch(\n        arguments,\n        temperature=0,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    latency = time.perf_counter() - tic\n\n    # Report latency\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"ReAct Agents\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": len(arguments),\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"hotpotqa_100.jsonl\")\n    
parser.add_argument(\"--num-questions\", type=int, default=10)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/reasoning_benchmark/README.md",
    "content": "# Run benchmark\n\nThis benchmark is primarily intended for reasoning models such as `DeepSeek-R1` and its distilled models, for example `DeepSeek-R1-Distill-Qwen-1.5B`. Install the ANTLR runtime, which `parse_latex` requires for the symbolic equality check:\n\n```bash\npip install antlr4-python3-runtime\n```\n\n## Benchmark sglang\n\n1. Launch the Server\n```bash\npython3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port 30000\n```\n\nNote that, depending on the GPU, this benchmark can take quite some time. To employ data parallelism, use:\n\n```bash\npython3 -m sglang_router.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port 30000 --dp-size 4\n```\n\n\n2. Benchmarking\n\nWe use the [suggested](https://github.com/deepseek-ai/DeepSeek-R1) generation parameters `temperature=0.6`, `top_p=0.95`, and `max_new_tokens=32768`. The `--num-tries` argument evaluates the model multiple times on the same question. For AIME 2024, we use the `64` tries suggested in the repository. For LIMO, we use `8` tries because of the size of the dataset.\n\nBy default, the benchmark evaluates on the LIMO dataset.\n\n```bash\npython3 bench_sglang.py --parallel 256 --num-tries 8 --port 30000\n```\n\nEvaluate on the AIME 2024 dataset.\n\n```bash\npython3 bench_sglang.py --parallel 256 --port 30000 --data-path Maxwell-Jia/AIME_2024 --question-key Problem --answer-key Answer --num-tries 64\n```\n\nEvaluate on the [AIME 2025 I dataset](https://huggingface.co/datasets/opencompass/AIME2025). For benchmark results, see [here](https://matharena.ai/).\n\n```bash\npython3 bench_sglang.py --parallel 256 --port 30000 --data-path opencompass/AIME2025 --question-key question --answer-key answer --num-tries 64\n```\n\n## Results\n\n### Evaluation Results\n\n| Dataset     | Num Tries | Accuracy | Reference | Standard Error |\n|-------------|-----------|----------|-----------|----------------|\n| LIMO        | 8         | 47.7%    | ?         | ?              |\n| AIME 2024   | 64        | 33.2%    | 28.9%     | 3.4%           |\n| AIME 2025 I | 64        | 29.9%    | 25.0%     | ?              |\n\n### Statistical Analysis Results\n\nSet up the SGLang engine for statistical analysis. For efficiency, we use `--dp-size 8` for data parallelism:\n```bash\npython3 -m sglang_router.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port 30000 --dp-size 8\n```\n**Experiment 1**:\nWe fixed the number of attempts (`num_tries`) and conducted multiple runs to assess the consistency of the model's performance. All recorded accuracies lie within ± one standard error of the mean. This suggests that **our metric serves as an effective upper bound for the deviation of reported accuracy**.\n\nTo collect the accuracies, run the following command 30 times:\n```bash\npython3 bench_sglang.py --parallel 64 --port 30000 --data-path Maxwell-Jia/AIME_2024 --question-key Problem --answer-key Answer --num-tries 64\n```\n\n![acc_hist](figure/Acc_histplot.png)\n\n\n**Experiment 2**: We explored the relationship between the number of attempts (`num_tries`) and the standard error (SE) by varying `num_tries` across a range (e.g., 8, 16, 32, ..., 256) and performing a single run for each value. As the number of attempts increases, the standard error decreases, leading to **greater stability in answer accuracy**.\n\nTo reveal the relationship, run the command 6 times, adjusting `--num-tries` for each run:\n```bash\npython3 bench_sglang.py --parallel 64 --port 30000 --data-path Maxwell-Jia/AIME_2024 --question-key Problem --answer-key Answer --num-tries <num_tries>\n```\n![SE_num_tries](figure/SE_numtries.png)\n"
  },
  {
    "path": "benchmark/reasoning_benchmark/answer_extraction.py",
    "content": "# Adapted from https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/data_processing/answer_extraction.py\n\nimport re\n\nimport regex\n\n\ndef _fix_fracs(string):\n    substrs = string.split(\"\\\\frac\")\n    new_str = substrs[0]\n    if len(substrs) > 1:\n        substrs = substrs[1:]\n        for substr in substrs:\n            new_str += \"\\\\frac\"\n            if len(substr) > 0 and substr[0] == \"{\":\n                new_str += substr\n            else:\n                try:\n                    assert len(substr) >= 2\n                except:\n                    return string\n                a = substr[0]\n                b = substr[1]\n                if b != \"{\":\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}{\" + b + \"}\" + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}{\" + b + \"}\"\n                else:\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}\" + b + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}\" + b\n    string = new_str\n    return string\n\n\ndef _fix_a_slash_b(string):\n    if len(string.split(\"/\")) != 2:\n        return string\n    a = string.split(\"/\")[0]\n    b = string.split(\"/\")[1]\n    try:\n        if \"sqrt\" not in a:\n            a = int(a)\n        if \"sqrt\" not in b:\n            b = int(b)\n        assert string == \"{}/{}\".format(a, b)\n        new_string = \"\\\\frac{\" + str(a) + \"}{\" + str(b) + \"}\"\n        return new_string\n    except:\n        return string\n\n\ndef _fix_sqrt(string):\n    _string = re.sub(r\"\\\\sqrt(-?[0-9.a-zA-Z]+)\", r\"\\\\sqrt{\\1}\", string)\n    _string = re.sub(r\"\\\\sqrt\\s+(\\w+)$\", r\"\\\\sqrt{\\1}\", _string)\n    return _string\n\n\ndef _fix_tan(string):\n    
_string = re.sub(r\"\\\\tan(-?[0-9.a-zA-Z]+)\", r\"\\\\tan{\\1}\", string)\n    _string = re.sub(r\"\\\\tan\\s+(\\w+)$\", r\"\\\\tan{\\1}\", _string)\n    return _string\n\n\ndef strip_string(string):\n    string = str(string).strip()\n    # linebreaks\n    string = string.replace(\"\\n\", \"\")\n\n    # right \".\"\n    string = string.rstrip(\".\")\n\n    # remove inverse spaces\n    string = string.replace(\"\\\\!\", \"\")\n    # string = string.replace(\"\\\\ \", \"\")\n\n    # replace \\\\ with \\\n    # string = string.replace(\"\\\\\\\\\", \"\\\\\")\n    # string = string.replace(\"\\\\\\\\\", \"\\\\\")\n\n    if string.startswith(\"\\\\text{\") and string.endswith(\"}\"):\n        string = string.split(\"{\", 1)[1][:-1]\n\n    # replace tfrac and dfrac with frac\n    string = string.replace(\"tfrac\", \"frac\")\n    string = string.replace(\"dfrac\", \"frac\")\n    string = string.replace(\"cfrac\", \"frac\")\n\n    # remove \\left and \\right\n    string = string.replace(\"\\\\left\", \"\")\n    string = string.replace(\"\\\\right\", \"\")\n\n    # Remove unit: miles, dollars if after is not none\n    _string = re.sub(r\"\\\\text{.*?}$\", \"\", string).strip()\n    if _string != \"\" and _string != string:\n        # print(\"Warning: unit not removed: '{}' -> '{}'\".format(string, _string))\n        string = _string\n\n    # Remove circ (degrees)\n    string = string.replace(\"^{\\\\circ}\", \"\").strip()\n    string = string.replace(\"^\\\\circ\", \"\").strip()\n\n    string = regex.sub(r\"\\{(c|m)?m\\}(\\^(2|3))?\", \"\", string).strip()\n    string = regex.sub(r\"p\\.m\\.$\", \"\", string).strip()\n    string = regex.sub(r\"(\\d)\\s*t$\", r\"\\1\", string).strip()\n\n    # remove dollar signs\n    string = string.replace(\"\\\\$\", \"\")\n    string = string.replace(\"$\", \"\")\n\n    # string = string.replace(\"\\\\text\", \"\")\n    string = string.replace(\"x\\\\in\", \"\")\n\n    # remove percentage\n    string = string.replace(\"\\\\%\", \"%\")\n  
  string = string.replace(\"\\%\", \"%\")\n    # string = string.replace(\"%\", \"\")\n\n    # \" 0.\" equivalent to \" .\" and \"{0.\" equivalent to \"{.\" Alternatively, add \"0\" if \".\" is the start of the string\n    string = string.replace(\" .\", \" 0.\")\n    string = string.replace(\"{.\", \"{0.\")\n\n    # cdot\n    string = string.replace(\"\\\\cdot\", \"\")\n\n    # inf\n    string = string.replace(\"infinity\", \"\\\\infty\")\n    if \"\\\\infty\" not in string:\n        string = string.replace(\"inf\", \"\\\\infty\")\n    string = string.replace(\"+\\\\inity\", \"\\\\infty\")\n\n    # and\n    # string = string.replace(\"and\", \"\")\n    string = string.replace(\"\\\\mathbf\", \"\")\n    string = string.replace(\"\\\\mathrm\", \"\")\n\n    # use regex to remove \\mbox{...}\n    string = re.sub(r\"\\\\mbox{.*?}\", \"\", string)\n\n    # quote\n    string.replace(\"'\", \"\")\n    string.replace('\"', \"\")\n\n    # i, j\n    if \"j\" in string and \"i\" not in string:\n        string = string.replace(\"j\", \"i\")\n\n    # replace a.000b where b is not number or b is end, with ab, use regex\n    string = re.sub(r\"(\\d+)\\.0+([^\\d])\", r\"\\1\\2\", string)\n    string = re.sub(r\"(\\d+)\\.0+$\", r\"\\1\", string)\n\n    # if empty, return empty string\n    if len(string) == 0:\n        return string\n    if string[0] == \".\":\n        string = \"0\" + string\n\n    # to consider: get rid of e.g. \"k = \" or \"q = \" at beginning\n    # if len(string.split(\"=\")) == 2:\n    #     if len(string.split(\"=\")[0]) <= 2:\n    #         string = string.split(\"=\")[1]\n\n    string = _fix_sqrt(string)\n    string = _fix_tan(string)\n    string = string.replace(\" \", \"\")\n\n    # \\frac1b or \\frac12 --> \\frac{1}{b} and \\frac{1}{2}, etc. Even works with \\frac1{72} (but not \\frac{72}1). 
Also does a/b --> \\\\frac{a}{b}\n    string = _fix_fracs(string)\n\n    # NOTE: X/Y changed to \\frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y\n    string = _fix_a_slash_b(string)\n\n    string = regex.sub(r\"(\\\\|,|\\.)+$\", \"\", string)\n\n    return string\n\n\ndef extract_boxed_answers(text):\n    answers = []\n    for piece in text.split(\"boxed{\")[1:]:\n        n = 0\n        for i in range(len(piece)):\n            if piece[i] == \"{\":\n                n += 1\n            elif piece[i] == \"}\":\n                n -= 1\n                if n < 0:\n                    if i + 1 < len(piece) and piece[i + 1] == \"%\":\n                        answers.append(piece[: i + 1])\n                    else:\n                        answers.append(piece[:i])\n                    break\n    return answers\n\n\ndef extract_program_output(pred_str):\n    \"\"\"\n    extract output between the last ```output\\n...\\n```\n    \"\"\"\n    if \"```output\" not in pred_str:\n        return \"\"\n    if \"```output\" in pred_str:\n        pred_str = pred_str.split(\"```output\")[-1]\n    if \"```\" in pred_str:\n        pred_str = pred_str.split(\"```\")[0]\n    output = pred_str.strip()\n    return output\n\n\ndef extract_answer(pred_str, exhaust=False):\n    pred = []\n    if \"final answer is $\" in pred_str and \"$. I hope\" in pred_str:\n        tmp = pred_str.split(\"final answer is $\", 1)[1]\n        pred = [tmp.split(\"$. 
I hope\", 1)[0].strip()]\n    elif \"boxed\" in pred_str:\n        pred = extract_boxed_answers(pred_str)\n    elif \"he answer is\" in pred_str:\n        pred = [pred_str.split(\"he answer is\")[-1].strip()]\n    else:\n        program_output = extract_program_output(pred_str)\n        if program_output != \"\":\n            # fall back to program\n            pred.append(program_output)\n        else:  # use the last number\n            pattern = r\"-?\\d*\\.?\\d+\"\n            answers = re.findall(pattern, pred_str.replace(\",\", \"\"))\n            if len(answers) >= 1:\n                last_ans = answers[-1]\n            else:\n                last_ans = \"\"\n            if last_ans:\n                pred.append(last_ans)\n\n    # multiple line\n    _pred = []\n    for each_ans in pred:\n        each_ans = each_ans.strip().split(\"\\n\")[0]\n        each_ans = each_ans.lstrip(\":\")\n        each_ans = each_ans.rstrip(\".\")\n        each_ans = each_ans.rstrip(\"/\")\n        each_ans = strip_string(each_ans)\n        _pred.append(each_ans)\n    if exhaust:\n        return _pred\n    else:\n        return _pred[-1] if _pred else \"\"\n\n\ndef extract_math_answer(question, reasoning, task):\n    answer = []\n    for ans in extract_answer(reasoning, exhaust=True):\n        if \"separated by commas\" in question and all(ch not in ans for ch in \"()[]\"):\n            answer.extend([a.strip() for a in ans.split(\",\")])\n        elif regex.search(r\"\\\\text\\{\\s*and\\s*\\}\", ans):\n            answer.extend(\n                [\n                    a.strip()\n                    for a in regex.sub(r\"\\\\text\\{\\s*and\\s*\\}\", \"[SEP]\", ans).split(\n                        \"[SEP]\"\n                    )\n                ]\n            )\n        else:\n            answer.append(ans.strip())\n    return answer\n"
  },
  {
    "path": "benchmark/reasoning_benchmark/bench_sglang.py",
    "content": "import argparse\nimport json\nimport time\n\nimport answer_extraction\nimport eval_utils\nimport numpy as np\nfrom datasets import load_dataset\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text\n\n\n@sgl.function\ndef reasoning_gen(s, question: str):\n    s += sgl.user(\n        question\n        + \"\\nPlease reason step by step, and put your final answer within \\\\boxed{}.\"\n    )\n    s += sgl.assistant(\n        sgl.gen(\n            \"answer\",\n        )\n    )\n\n\ndef convert_dataset(path: str, question_key: str, answer_key: str, num_tries: int):\n    raw_dataset = load_dataset(path)\n    questions = []\n    answers = []\n    for data in raw_dataset[\"train\"]:\n        question = data[question_key]\n        answer = data[answer_key]\n        for _ in range(num_tries):\n            questions.append({\"question\": question})\n            answers.append({\"answer\": answer})\n    return questions, answers\n\n\ndef main(args):\n    # Select backend\n    sgl.set_default_backend(select_sglang_backend(args))\n\n    # Get dataset\n    questions, answers = convert_dataset(\n        args.data_path, args.question_key, args.answer_key, args.num_tries\n    )\n\n    # Run requests\n    tic = time.perf_counter()\n    states = reasoning_gen.run_batch(\n        questions,\n        num_threads=args.parallel,\n        progress_bar=True,\n        temperature=0.6,\n        max_new_tokens=32768,\n        top_p=0.95,\n    )\n    latency = time.perf_counter() - tic\n\n    # Extract results and record outcomes in a list.\n    outcomes = []\n    for i, state in enumerate(states):\n        try:\n            pred_answer = answer_extraction.extract_math_answer(\n                questions[i][\"question\"], state[\"answer\"], \"limo\"\n            )\n            gt_answer = str(answers[i][\"answer\"])\n            pred_answer = (\n                
pred_answer[-1] if isinstance(pred_answer, list) else pred_answer\n            )\n            is_correct = 1 if eval_utils.math_equal(pred_answer, gt_answer) else 0\n        except Exception as e:\n            print(f\"Error extracting answer: {e}\")\n            is_correct = 0\n\n        outcomes.append(is_correct)\n\n    # Calculate overall accuracy using numpy\n    overall_accuracy = np.mean(outcomes)\n    print(f\"Overall Accuracy: {overall_accuracy}\")\n\n    # Calculate mean standard error over questions if num_tries >= 2\n    if args.num_tries > 1:\n        outcomes_np = np.array(outcomes).reshape(-1, args.num_tries)\n        # Using sample standard deviation with ddof=1\n        std_per_question = np.std(outcomes_np, axis=1, ddof=1)\n        # Compute the standard error for each question: std / sqrt(num_tries)\n        se_per_question = std_per_question / np.sqrt(args.num_tries)\n        mean_se = se_per_question.mean()\n        print(f\"Mean Standard Error of Accuracy across questions: {mean_se}\")\n    else:\n        mean_se = None\n        print(\"Not enough samples per question to compute standard error.\")\n\n    # Calculate output throughput\n    num_output_tokens = sum(\n        s.get_meta_info(\"answer\")[\"completion_tokens\"] for s in states\n    )\n    output_throughput = num_output_tokens / latency\n    print(f\"Output throughput: {output_throughput} token/s\")\n\n    # Dump results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    # Write results\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"limo\",\n            \"backend\": args.backend,\n            \"latency\": round(latency, 3),\n            \"overall_accuracy\": round(overall_accuracy, 3),\n            \"mean_se_accuracy\": round(mean_se, 3) if mean_se is not None else None,\n            \"num_requests\": len(questions),\n            \"other\": {\n                \"num_questions\": len(questions),\n                
\"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"GAIR/LIMO\")\n    parser.add_argument(\"--question-key\", type=str, default=\"question\")\n    parser.add_argument(\"--answer-key\", type=str, default=\"answer\")\n    parser.add_argument(\"--num-tries\", type=int, default=1)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/reasoning_benchmark/eval_utils.py",
    "content": "# Adapted from https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py\n\nfrom math import isclose\n\nimport regex\nfrom sympy import N, simplify\nfrom sympy.parsing.latex import parse_latex\nfrom sympy.parsing.sympy_parser import parse_expr\n\n\ndef parse_digits(num):\n    # format: 234.23 || 23%\n    num = regex.sub(\",\", \"\", str(num))\n    try:\n        return float(num)\n    except:\n        if num.endswith(\"%\"):\n            num = num[:-1]\n            if num.endswith(\"\\\\\"):\n                num = num[:-1]\n            try:\n                return float(num) / 100\n            except:\n                pass\n    return None\n\n\ndef is_digit(num):\n    # paired with parse_digits\n    return parse_digits(num) is not None\n\n\ndef symbolic_equal(a, b):\n    def _parse(s):\n        for f in [parse_latex, parse_expr]:\n            try:\n                return f(s)\n            except:\n                pass\n        return s\n\n    a = _parse(a)\n    b = _parse(b)\n\n    try:\n        if simplify(a - b) == 0:\n            return True\n    except:\n        pass\n\n    try:\n        if isclose(N(a), N(b), abs_tol=1e-3):\n            return True\n    except:\n        pass\n    return False\n\n\ndef math_equal(prediction, reference, include_percentage=True, is_close=True):\n    \"\"\"\n    Exact match of math if and only if:\n    1. numerical equal: both can convert to float and are equal\n    2. symbolic equal: both can convert to sympy expression and are equal\n    \"\"\"\n    if str(prediction) == str(reference):\n        return True\n\n    try:  # 1. 
numerical equal\n        if is_digit(prediction) and is_digit(reference):\n            prediction = parse_digits(prediction)\n            reference = parse_digits(reference)\n            # number questions\n            if include_percentage:\n                gt_result = [reference / 100, reference, reference * 100]\n            else:\n                gt_result = [reference]\n            for item in gt_result:\n                try:\n                    if is_close:\n                        if isclose(item, prediction, abs_tol=1e-3):\n                            return True\n                    else:\n                        if item == prediction:\n                            return True\n                except Exception:\n                    continue\n            return False\n    except:\n        pass\n\n    if not prediction and prediction not in [0, False]:\n        return False\n\n    # 2. symbolic equal\n    reference = str(reference).strip()\n    prediction = str(prediction).strip()\n\n    if (\n        regex.match(r\"(\\(|\\[).+(\\)|\\])\", prediction) is not None\n        and regex.match(r\"(\\(|\\[).+(\\)|\\])\", reference) is not None\n    ):\n        pred_parts = prediction[1:-1].split(\",\")\n        ref_parts = reference[1:-1].split(\",\")\n        if len(pred_parts) == len(ref_parts):\n            if all(\n                [\n                    math_equal(\n                        pred_parts[i], ref_parts[i], include_percentage, is_close\n                    )\n                    for i in range(len(pred_parts))\n                ]\n            ):\n                return True\n\n    # Add back matrix comparison\n    if (\n        (\n            prediction.startswith(\"\\\\begin{pmatrix}\")\n            or prediction.startswith(\"\\\\begin{bmatrix}\")\n        )\n        and (\n            prediction.endswith(\"\\\\end{pmatrix}\")\n            or prediction.endswith(\"\\\\end{bmatrix}\")\n        )\n        and (\n            
reference.startswith(\"\\\\begin{pmatrix}\")\n            or reference.startswith(\"\\\\begin{bmatrix}\")\n        )\n        and (\n            reference.endswith(\"\\\\end{pmatrix}\") or reference.endswith(\"\\\\end{bmatrix}\")\n        )\n    ):\n        pred_lines = [\n            line.strip()\n            for line in prediction[\n                len(\"\\\\begin{pmatrix}\") : -len(\"\\\\end{pmatrix}\")\n            ].split(\"\\\\\\\\\")\n            if line.strip()\n        ]\n        ref_lines = [\n            line.strip()\n            for line in reference[\n                len(\"\\\\begin{pmatrix}\") : -len(\"\\\\end{pmatrix}\")\n            ].split(\"\\\\\\\\\")\n            if line.strip()\n        ]\n        matched = True\n        if len(pred_lines) == len(ref_lines):\n            for pred_line, ref_line in zip(pred_lines, ref_lines):\n                pred_parts = pred_line.split(\"&\")\n                ref_parts = ref_line.split(\"&\")\n                if len(pred_parts) == len(ref_parts):\n                    if not all(\n                        [\n                            math_equal(\n                                pred_parts[i],\n                                ref_parts[i],\n                                include_percentage,\n                                is_close,\n                            )\n                            for i in range(len(pred_parts))\n                        ]\n                    ):\n                        matched = False\n                        break\n                else:\n                    matched = False\n                if not matched:\n                    break\n        else:\n            matched = False\n        if matched:\n            return True\n\n    # Add back equation comparison\n    if prediction.count(\"=\") == 1 and reference.count(\"=\") == 1:\n        pred = prediction.split(\"=\")\n        pred = f\"{pred[0].strip()} - ({pred[1].strip()})\"\n        ref = reference.split(\"=\")\n        ref = 
f\"{ref[0].strip()} - ({ref[1].strip()})\"\n        if symbolic_equal(pred, ref) or symbolic_equal(f\"-({pred})\", ref):\n            return True\n    elif (\n        prediction.count(\"=\") == 1\n        and len(prediction.split(\"=\")[0].strip()) <= 2\n        and \"=\" not in reference\n    ):\n        if math_equal(\n            prediction.split(\"=\")[1], reference, include_percentage, is_close\n        ):\n            return True\n    elif (\n        reference.count(\"=\") == 1\n        and len(reference.split(\"=\")[0].strip()) <= 2\n        and \"=\" not in prediction\n    ):\n        if math_equal(\n            prediction, reference.split(\"=\")[1], include_percentage, is_close\n        ):\n            return True\n\n    # symbolic equal with sympy\n    if symbolic_equal(prediction, reference):\n        return True\n\n    return False\n"
  },
  {
    "path": "benchmark/tip_suggestion/.gitignore",
    "content": "!topic.jsonl\n"
  },
  {
    "path": "benchmark/tip_suggestion/README.md",
    "content": "## Run benchmark\n\n### Benchmark sglang\n```\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\n```\npython3 bench_sglang.py --num-questions 64\npython3 bench_sglang.py --num-questions 32 --parallel 1\n```\n\n\n### Benchmark vllm\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000\n```\n\n```\npython3 bench_other.py --backend vllm --num-questions 64\n```\n\n\n### Benchmark guidance\n```\npython3 bench_other.py --backend guidance --num-questions 32 --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n\n### Benchmark lmql\n\n```\npython3 bench_other.py --backend lmql --num-questions 32 --parallel 1\n```\n"
  },
  {
    "path": "benchmark/tip_suggestion/bench_other.py",
    "content": "import argparse\nimport json\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\nfrom functools import partial\n\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import dump_state_text, read_jsonl\n\nnumber = 5\n\n\ndef expand_tip(topic, tip, generate):\n    s = \"\"\"Please expand a tip for a topic into a detailed paragraph.\n\nTopic: staying healthy\nTip: Regular Exercise\nParagraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement.\n\nTopic: building a campfire\nTip: Choose the Right Location\nParagraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches.\n\nTopic: writing a blog post\nTip: structure your content effectively\nParagraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. 
Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.\n\nTopic: \"\"\" + topic + \"\\nTip: \" + tip + \"\\nParagraph:\"\n    return generate(s, max_tokens=128, stop=[\"\\n\\n\"])\n\n\ndef suggest_tips(topic, generate):\n    s = \"Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\\n\"\n    s += \"USER: Give some tips for \" + topic + \".\\n\"\n    s += (\n        \"ASSISTANT: Okay. Here are \"\n        + str(number)\n        + \" concise tips, each under 8 words:\\n\"\n    )\n\n    tips = []\n    for i in range(1, 1 + number):\n        s += f\"{i}.\"\n        tip = generate(s, max_tokens=24, stop=[\".\", \"\\n\"])\n        s += tip + \".\\n\"\n        tips.append(tip)\n\n    paragraphs = [expand_tip(topic, tip, generate=generate) for tip in tips]\n\n    for i in range(1, 1 + number):\n        s += f\"Tip {i}:\" + paragraphs[i - 1] + \"\\n\"\n    return s\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)[: args.num_questions]\n    states = [None] * len(lines)\n\n    # Select backend\n    call_generate = partial(get_call_generate(args), temperature=0)\n\n    # Run requests\n    tic = time.perf_counter()\n    if args.backend != \"lmql\":\n\n        def get_one_answer(i):\n            states[i] = suggest_tips(lines[i][\"topic\"], call_generate)\n\n        if args.parallel == 1:\n            for i in tqdm(range(len(lines))):\n                get_one_answer(i)\n        else:\n            with ThreadPoolExecutor(args.parallel) as executor:\n                list(\n                    tqdm(\n                        executor.map(get_one_answer, list(range(len(lines)))),\n                        total=len(lines),\n                    )\n                )\n\n    else:\n        import asyncio\n\n        from lmql_funcs import suggest_tips_async\n\n        async def get_one_answer_async(i):\n            states[i] = await 
suggest_tips_async(lines[i][\"topic\"], call_generate)\n\n        batches = []\n        for i in range(0, len(lines), args.parallel):\n            batches.append(list(range(i, min(i + args.parallel, len(lines)))))\n        loop = asyncio.get_event_loop()\n        for batch in tqdm(batches):\n            loop.run_until_complete(\n                asyncio.gather(*[get_one_answer_async(i) for i in batch])\n            )\n    latency = time.perf_counter() - tic\n\n    # Report latency (this task has no accuracy metric)\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"tip_suggestion\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"topic.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=100)\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/tip_suggestion/bench_sglang.py",
    "content": "import argparse\nimport json\nimport time\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text, read_jsonl\n\nnumber = 5\n\n\n@sgl.function\ndef expand_tip(s, topic, tip):\n    s += \"\"\"Please expand a tip for a topic into a detailed paragraph.\n\nTopic: staying healthy\nTip: Regular Exercise\nParagraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement.\n\nTopic: building a campfire\nTip: Choose the Right Location\nParagraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches.\n\nTopic: writing a blog post\nTip: structure your content effectively\nParagraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. 
Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.\n\nTopic: \"\"\" + topic + \"\\nTip: \" + tip + \"\\nParagraph:\"\n    s += sgl.gen(\"paragraph\", max_tokens=128, stop=[\"\\n\\n\"], temperature=0)\n\n\n@sgl.function\ndef suggest_tips(s, topic):\n    s += \"Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\\n\"\n    s += \"USER: Give some tips for \" + topic + \".\\n\"\n    s += (\n        \"ASSISTANT: Okay. Here are \"\n        + str(number)\n        + \" concise tips, each under 8 words:\\n\"\n    )\n\n    paragraphs = []\n    for i in range(1, 1 + number):\n        s += f\"{i}.\" + sgl.gen(f\"tip_{i}\", max_tokens=24, stop=[\".\", \"\\n\"]) + \".\\n\"\n        paragraphs.append(expand_tip(topic=topic, tip=s[f\"tip_{i}\"]))\n\n    for i in range(1, 1 + number):\n        s += f\"Tip {i}:\" + paragraphs[i - 1][\"paragraph\"] + \"\\n\"\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)[: args.num_questions]\n    arguments = [{\"topic\": l[\"topic\"]} for l in lines]\n\n    # Select backend\n    sgl.set_default_backend(select_sglang_backend(args))\n\n    # Run requests\n    tic = time.perf_counter()\n    states = suggest_tips.run_batch(\n        arguments, temperature=0, num_threads=args.parallel, progress_bar=True\n    )\n    latency = time.perf_counter() - tic\n\n    # Compute accuracy\n    print(f\"Latency: {latency:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", states)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"tip_suggestion\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n          
  },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"topic.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=100)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/tip_suggestion/lmql_funcs.py",
    "content": "number = 5\n\n\nasync def expand_tip_async(topic, tip, generate):\n    s = \"\"\"Please expand a tip for a topic into a detailed paragraph.\n\nTopic: staying healthy\nTip: Regular Exercise\nParagraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement.\n\nTopic: building a campfire\nTip: Choose the Right Location\nParagraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches.\n\nTopic: writing a blog post\nTip: structure your content effectively\nParagraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.\n\nTopic: \"\"\" + topic + \"\\nTip: \" + tip + \"\\nParagraph:\"\n    return await generate(s, max_tokens=128, stop=\"\\n\\n\")\n\n\nasync def suggest_tips_async(topic, generate):\n    s = \"Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\\n\"\n    s += \"USER: Give some tips for \" + topic + \".\\n\"\n    s += (\n        \"ASSISTANT: Okay. 
Here are \"\n        + str(number)\n        + \" concise tips, each under 8 words:\\n\"\n    )\n\n    tips = []\n    for i in range(1, 1 + number):\n        s += f\"{i}.\"\n        # NOTE: stop differs here because lmql does not support a list of stop tokens\n        tip = await generate(s, max_tokens=24, stop=\".\\n\")\n        s += tip + \".\\n\"\n        tips.append(tip)\n\n    paragraphs = [await expand_tip_async(topic, tip, generate=generate) for tip in tips]\n\n    for i in range(1, 1 + number):\n        s += f\"Tip {i}:\" + paragraphs[i - 1] + \"\\n\"\n    return s\n"
  },
  {
    "path": "benchmark/tip_suggestion/topic.jsonl",
    "content": "{\"topic\": \"organizing a successful charity event\", \"number\": 6}\n{\"topic\": \"improving personal credit scores\", \"number\": 7}\n{\"topic\": \"staying motivated during job searches\", \"number\": 5}\n{\"topic\": \"maintaining a work-life balance\", \"number\": 9}\n{\"topic\": \"reducing carbon footprint at home\", \"number\": 8}\n{\"topic\": \"starting a book club\", \"number\": 5}\n{\"topic\": \"learning to play a musical instrument\", \"number\": 7}\n{\"topic\": \"getting into freelance writing\", \"number\": 6}\n{\"topic\": \"beginner yoga poses\", \"number\": 8}\n{\"topic\": \"preparing for graduate school exams\", \"number\": 5}\n{\"topic\": \"exploring minimalist living\", \"number\": 9}\n{\"topic\": \"effective grocery shopping\", \"number\": 7}\n{\"topic\": \"winter camping\", \"number\": 5}\n{\"topic\": \"starting a podcast on a budget\", \"number\": 8}\n{\"topic\": \"creating a capsule wardrobe\", \"number\": 6}\n{\"topic\": \"improving your writing skills\", \"number\": 7}\n{\"topic\": \"learning a new software quickly\", \"number\": 9}\n{\"topic\": \"reducing anxiety before public speaking\", \"number\": 5}\n{\"topic\": \"planning a solo travel adventure\", \"number\": 8}\n{\"topic\": \"beginner skateboarders\", \"number\": 6}\n{\"topic\": \"studying abroad\", \"number\": 7}\n{\"topic\": \"planting a vegetable garden\", \"number\": 5}\n{\"topic\": \"adopting a shelter pet\", \"number\": 9}\n{\"topic\": \"learning to cook ethnic cuisines\", \"number\": 8}\n{\"topic\": \"effective conflict resolution\", \"number\": 5}\n{\"topic\": \"starting a vlog\", \"number\": 7}\n{\"topic\": \"keeping a daily journal\", \"number\": 6}\n{\"topic\": \"improving sleep hygiene\", \"number\": 8}\n{\"topic\": \"beginner mountain climbers\", \"number\": 5}\n{\"topic\": \"creating a mobile app\", \"number\": 9}\n{\"topic\": \"maintaining a saltwater aquarium\", \"number\": 7}\n{\"topic\": \"preparing for a baby's arrival\", \"number\": 6}\n{\"topic\": 
\"writing a fantasy novel\", \"number\": 5}\n{\"topic\": \"effective team leadership\", \"number\": 8}\n{\"topic\": \"making a documentary film\", \"number\": 9}\n{\"topic\": \"learning about historical events\", \"number\": 7}\n{\"topic\": \"baking gluten-free treats\", \"number\": 6}\n{\"topic\": \"improving mental arithmetic skills\", \"number\": 5}\n{\"topic\": \"building a treehouse\", \"number\": 8}\n{\"topic\": \"getting started with watercolor painting\", \"number\": 9}\n{\"topic\": \"creating a YouTube tutorial series\", \"number\": 7}\n{\"topic\": \"landscape photography\", \"number\": 5}\n{\"topic\": \"navigating cultural differences\", \"number\": 6}\n{\"topic\": \"preparing for a marathon\", \"number\": 8}\n{\"topic\": \"building an online business\", \"number\": 9}\n{\"topic\": \"learning to dance at home\", \"number\": 5}\n{\"topic\": \"self-publishing a book\", \"number\": 7}\n{\"topic\": \"starting an urban farm\", \"number\": 6}\n{\"topic\": \"improving your memory\", \"number\": 8}\n{\"topic\": \"creating a personal brand online\", \"number\": 9}\n"
  },
  {
    "path": "benchmark/tree_of_thought_deep/README.md",
    "content": "## Download data\n```\nwget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\n```\n\n## Run benchmark\n\nNOTE: This is an implementation for throughput/latency benchmark purposes. The prompts are not tuned to achieve good accuracy on the GSM-8K tasks.\n\n### Benchmark sglang\n```\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\n```\npython3 bench_sglang.py --num-questions 32\npython3 bench_sglang.py --num-questions 16 --parallel 1\n```\n\n\n### Benchmark vllm\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000\n```\n\n```\npython3 bench_other.py --num-questions 32 --backend vllm\n```\n\n\n### Benchmark lightllm\n```\n# A10G\npython -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000\n```\n\n```\npython3 bench_other.py --num-questions 32 --backend lightllm\n```\n\n\n### Benchmark guidance\n```\npython3 bench_other.py --num-questions 8 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n\n### Benchmark lmql\n\n```\npython3 bench_other.py --num-questions 8 --backend lmql --parallel 1\n```\n"
  },
  {
    "path": "benchmark/tree_of_thought_deep/bench_other.py",
    "content": "import argparse\nimport ast\nimport json\nimport re\nimport time\nfrom collections import Counter\nfrom concurrent.futures import ThreadPoolExecutor\n\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import dump_state_text, read_jsonl\n\nINVALID = -9999999\n\n\ndef get_answer_value(answer_str):\n    answer_str = answer_str.replace(\",\", \"\")\n    numbers = re.findall(r\"\\d+\", answer_str)\n    if len(numbers) < 1:\n        return INVALID\n    try:\n        return ast.literal_eval(numbers[-1])\n    except SyntaxError:\n        return INVALID\n\n\ndef most_frequent_number(numbers):\n    if not numbers:\n        return None\n\n    frequency = Counter(numbers)\n    most_frequent = max(frequency, key=frequency.get)\n    return most_frequent\n\n\nUSER_PREFIX = \"[INST] \"\nUSER_SUFFIX = \" [/INST]\"\nASSISTANT_PREFIX = \"\"\nASSISTANT_SUFFIX = \" </s><s>\"\n\n# Use a low temp to make the results more deterministic and the comparison more fair.\ntemp = 0.001\n\n\ndef propose_plan(s, question, num_branches, call_generate):\n    s += (\n        USER_PREFIX\n        + \"\"\"Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: \"\"\"\n        + question\n        + USER_SUFFIX\n    )\n\n    s += ASSISTANT_PREFIX\n    comps = call_generate(\n        s, max_tokens=256, temperature=temp, stop=None, n=num_branches\n    )\n    return [s + comp + ASSISTANT_SUFFIX for comp in comps]\n\n\ndef execute_plan(s, num_branches, call_generate):\n    s += (\n        USER_PREFIX\n        + \"\"\"The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. 
Give me the final answer. Make your response short.\"\"\"\n        + USER_SUFFIX\n    )\n    s += ASSISTANT_PREFIX\n    comps = call_generate(\n        s, max_tokens=256, temperature=temp, stop=None, n=num_branches\n    )\n    return [s + comp + ASSISTANT_SUFFIX for comp in comps]\n\n\ndef reflect_solution(s, num_branches, call_generate):\n    s += (\n        USER_PREFIX\n        + \"\"\"Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.\"\"\"\n        + USER_SUFFIX\n    )\n    s += ASSISTANT_PREFIX\n    comps = call_generate(\n        s, max_tokens=256, temperature=temp, stop=None, n=num_branches\n    )\n    return [s + comp + ASSISTANT_SUFFIX for comp in comps]\n\n\ndef get_final_answer(s, num_branches, call_generate):\n    s += (\n        USER_PREFIX\n        + \"\"\"Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.\"\"\"\n        + USER_SUFFIX\n    )\n    s += ASSISTANT_PREFIX\n    comps = call_generate(\n        s, max_tokens=256, temperature=temp, stop=None, n=num_branches\n    )\n    return [s + comp + ASSISTANT_SUFFIX for comp in comps]\n\n\ndef tree_search(question, num_branches, call_generate):\n    plan_forks = propose_plan(\"\", question, num_branches, call_generate)\n\n    sol_states = []\n    for plan in plan_forks:\n        forks = execute_plan(plan, num_branches, call_generate)\n        sol_states.extend(forks)\n\n    ref_states = []\n    for sol in sol_states:\n        forks = reflect_solution(sol, num_branches, call_generate)\n        ref_states.extend(forks)\n\n    solutions = []\n    for sol in ref_states:\n        ans = get_final_answer(sol, num_branches, call_generate)\n        solutions.append(ans)\n\n    return solutions\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)\n\n    # Construct prompts\n    num_branches = 2\n    questions = []\n    labels = []\n    for i in range(len(lines[: 
args.num_questions])):\n        questions.append(lines[i][\"question\"])\n        labels.append(get_answer_value(lines[i][\"answer\"]))\n    assert all(l != INVALID for l in labels)\n    arguments = [{\"question\": q, \"num_branches\": num_branches} for q in questions]\n\n    # Select backend\n    call_generate = get_call_generate(args)\n\n    # Run requests\n    states = [None] * len(questions)\n\n    tic = time.perf_counter()\n    if args.backend != \"lmql\":\n\n        def get_one_answer(i):\n            states[i] = tree_search(**arguments[i], call_generate=call_generate)\n\n        if args.parallel == 1:\n            for i in tqdm(range(len(questions))):\n                get_one_answer(i)\n        else:\n            with ThreadPoolExecutor(args.parallel) as executor:\n                list(\n                    tqdm(\n                        executor.map(get_one_answer, list(range(len(questions)))),\n                        total=len(questions),\n                    )\n                )\n\n    else:\n        import asyncio\n\n        from lmql_funcs import tree_search_async\n\n        async def get_one_answer_async(i):\n            states[i] = await tree_search_async(\n                **arguments[i], call_generate=call_generate\n            )\n\n        batches = [\n            [] for _ in range((len(questions) + args.parallel - 1) // args.parallel)\n        ]\n        for i in range(len(questions)):\n            batches[i // args.parallel].append(i)\n\n        loop = asyncio.get_event_loop()\n        for bt in tqdm(batches):\n            tasks = [get_one_answer_async(k) for k in bt]\n            loop.run_until_complete(asyncio.gather(*tasks))\n\n    latency = time.perf_counter() - tic\n\n    answers_text = []\n    for s in states:\n        answers_text.append([x for xs in s for x in xs])\n\n    preds = []\n    for i in range(len(states)):\n        answers = [get_answer_value(v) for v in answers_text[i]]\n        preds.append(most_frequent_number(answers))\n\n   
 # Compute accuracy\n    acc = np.mean(np.array(preds) == np.array(labels))\n    invalid = np.mean(np.array(preds) == INVALID)\n    print(f\"Latency: {latency:.3f}\")\n    print(f\"Invalid: {invalid:.3f}\")\n    print(f\"Accuracy: {acc:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", answers_text)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"tree_of_thought_gsm8k\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"accuracy\": round(acc, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"test.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=200)\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/tree_of_thought_deep/bench_sglang.py",
    "content": "import argparse\nimport ast\nimport json\nimport re\nimport time\nfrom collections import Counter\n\nimport numpy as np\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text, read_jsonl\n\nINVALID = -9999999\n\n\ndef get_answer_value(answer_str):\n    answer_str = answer_str.replace(\",\", \"\")\n    numbers = re.findall(r\"\\d+\", answer_str)\n    if len(numbers) < 1:\n        return INVALID\n    try:\n        return ast.literal_eval(numbers[-1])\n    except SyntaxError:\n        return INVALID\n\n\ndef most_frequent_number(numbers):\n    if not numbers:\n        return None\n\n    frequency = Counter(numbers)\n    most_frequent = max(frequency, key=frequency.get)\n    return most_frequent\n\n\n# Use a low temp to make the results more deterministic and the comparison more fair.\ntemp = 0.001\n\n\ndef propose_plan(s, question, num_branches):\n    s += sgl.user(\n        \"\"\"Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: \"\"\"\n        + question\n    )\n    forks = s.fork(num_branches)\n    forks += sgl.assistant(sgl.gen(\"plan\", max_tokens=256, temperature=temp))\n    return forks\n\n\ndef execute_plan(s, num_branches):\n    s += sgl.user(\n        \"\"\"The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.\"\"\"\n    )\n    forks = s.fork(num_branches)\n    forks += sgl.assistant(sgl.gen(\"answer\", max_tokens=256, temperature=temp))\n    return forks\n\n\ndef reflect_solution(s, num_branches):\n    s += sgl.user(\n        \"\"\"Okay. 
Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.\"\"\"\n    )\n    forks = s.fork(num_branches)\n    forks += sgl.assistant(sgl.gen(\"score\", max_tokens=256, temperature=temp))\n    return forks\n\n\ndef get_final_answer(s, num_branches):\n    s += sgl.user(\n        \"\"\"Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.\"\"\"\n    )\n    forks = s.fork(num_branches)\n    forks += sgl.assistant(sgl.gen(\"final_answer\", max_tokens=256, temperature=temp))\n    return forks\n\n\n@sgl.function\ndef tree_search(s, question, num_branches):\n    plan_forks = propose_plan(s, question, num_branches)\n\n    sol_states = []\n    for plan in plan_forks:\n        forks = execute_plan(plan, num_branches)\n        sol_states.extend(forks)\n\n    ref_states = []\n    for sol in sol_states:\n        forks = reflect_solution(sol, num_branches)\n        ref_states.extend(forks)\n\n    solutions = []\n    for sol in ref_states:\n        forks = get_final_answer(sol, num_branches)\n        solutions.append(forks)\n    solutions = [[s.text() for s in forks] for forks in solutions]\n\n    return solutions\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)\n    lines = list(lines)\n\n    # Construct prompts\n    num_branches = 2\n    questions = []\n    labels = []\n    for i in range(len(lines[: args.num_questions])):\n        questions.append(lines[i][\"question\"])\n        labels.append(get_answer_value(lines[i][\"answer\"]))\n    assert all(l != INVALID for l in labels)\n    arguments = [{\"question\": q, \"num_branches\": num_branches} for q in questions]\n\n    # Select backend\n    backend = select_sglang_backend(args)\n\n    # Run requests\n    tic = time.perf_counter()\n    states = tree_search.run_batch(\n        arguments,\n        temperature=0,\n        backend=backend,\n        num_threads=args.parallel,\n        progress_bar=True,\n 
   )\n    latency = time.perf_counter() - tic\n    answers_text = []\n    for s in states:\n        answers_text.append([x for xs in s.ret_value for x in xs])\n\n    preds = []\n    for i in range(len(states)):\n        answers = [get_answer_value(v) for v in answers_text[i]]\n        preds.append(most_frequent_number(answers))\n\n    # Compute accuracy\n    acc = np.mean(np.array(preds) == np.array(labels))\n    invalid = np.mean(np.array(preds) == INVALID)\n    print(f\"Latency: {latency:.3f}\")\n    print(f\"Invalid: {invalid:.3f}\")\n    print(f\"Accuracy: {acc:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", answers_text)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"tree_of_thought_gsm8k\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"accuracy\": round(acc, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"test.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=200)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/tree_of_thought_deep/lmql_funcs.py",
    "content": "from bench_other import (\n    ASSISTANT_PREFIX,\n    ASSISTANT_SUFFIX,\n    USER_PREFIX,\n    USER_SUFFIX,\n    temp,\n)\n\n\nasync def propose_plan_async(s, question, num_branches, call_generate):\n    s += (\n        USER_PREFIX\n        + \"\"\"Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: \"\"\"\n        + question\n        + USER_SUFFIX\n    )\n\n    s += ASSISTANT_PREFIX\n    comps = await call_generate(\n        s, max_tokens=256, temperature=temp, stop=None, n=num_branches\n    )\n    return [s + comp + ASSISTANT_SUFFIX for comp in comps]\n\n\nasync def execute_plan_async(s, num_branches, call_generate):\n    s += (\n        USER_PREFIX\n        + \"\"\"The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.\"\"\"\n        + USER_SUFFIX\n    )\n    s += ASSISTANT_PREFIX\n    comps = await call_generate(\n        s, max_tokens=256, temperature=temp, stop=None, n=num_branches\n    )\n    return [s + comp + ASSISTANT_SUFFIX for comp in comps]\n\n\nasync def reflect_solution_async(s, num_branches, call_generate):\n    s += (\n        USER_PREFIX\n        + \"\"\"Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. 
Please do rigorous check of the correctness.\"\"\"\n        + USER_SUFFIX\n    )\n    s += ASSISTANT_PREFIX\n    comps = await call_generate(\n        s, max_tokens=256, temperature=temp, stop=None, n=num_branches\n    )\n    return [s + comp + ASSISTANT_SUFFIX for comp in comps]\n\n\nasync def get_final_answer_async(s, num_branches, call_generate):\n    s += (\n        USER_PREFIX\n        + \"\"\"Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.\"\"\"\n        + USER_SUFFIX\n    )\n    s += ASSISTANT_PREFIX\n    comps = await call_generate(\n        s, max_tokens=256, temperature=temp, stop=None, n=num_branches\n    )\n    return [s + comp + ASSISTANT_SUFFIX for comp in comps]\n\n\nasync def tree_search_async(question, num_branches, call_generate):\n    plan_forks = await propose_plan_async(\"\", question, num_branches, call_generate)\n\n    sol_states = []\n    for plan in plan_forks:\n        forks = await execute_plan_async(plan, num_branches, call_generate)\n        sol_states.extend(forks)\n\n    ref_states = []\n    for sol in sol_states:\n        forks = await reflect_solution_async(sol, num_branches, call_generate)\n        ref_states.extend(forks)\n\n    solutions = []\n    for sol in ref_states:\n        ans = await get_final_answer_async(sol, num_branches, call_generate)\n        solutions.append(ans)\n\n    return solutions\n"
  },
  {
    "path": "benchmark/tree_of_thought_v0/README.md",
    "content": "## Download data\n```\nwget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl\n```\n\n## Run benchmark\n\n### Benchmark sglang\n```\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\n```\n\n```\npython3 bench_sglang.py --num-questions 32 --parallel 16\npython3 bench_sglang.py --num-questions 10 --parallel 1\n```\n\n\n### Benchmark vllm\n```\npython3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000\n```\n\n```\npython3 bench_other.py --num-questions 32 --backend vllm\n```\n\n\n### Benchmark lightllm\n```\n# A10G\npython -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000\n```\n\n```\npython3 bench_other.py --num-questions 32 --backend lightllm\n```\n\n\n### Benchmark guidance\n```\npython3 bench_other.py --num-questions 32 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf\n```\n"
  },
  {
    "path": "benchmark/tree_of_thought_v0/bench_other.py",
    "content": "import argparse\nimport ast\nimport json\nimport re\nimport time\nfrom collections import Counter\nfrom concurrent.futures import ThreadPoolExecutor\n\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate\nfrom sglang.utils import dump_state_text, read_jsonl\n\nINVALID = -9999999\n\n\ndef get_answer_value(answer_str):\n    answer_str = answer_str.replace(\",\", \"\")\n    numbers = re.findall(r\"\\d+\", answer_str)\n    if len(numbers) < 1:\n        return INVALID\n    try:\n        return ast.literal_eval(numbers[-1])\n    except SyntaxError:\n        return INVALID\n\n\ndef most_frequent_number(numbers):\n    if not numbers:\n        return None\n\n    frequency = Counter(numbers)\n    most_frequent = max(frequency, key=frequency.get)\n    return most_frequent\n\n\nUSER_PREFIX = \"[INST] \"\nUSER_SUFFIX = \" [/INST]\"\nASSISTANT_PREFIX = \"\"\nASSISTANT_SUFFIX = \" </s><s>\"\n\n# Use a low temp to make the results more deterministic and the comparison more fair.\ntemp = 0.3\n\n\ndef propose_plan(s, question, num_branches, call_generate):\n    s += (\n        USER_PREFIX\n        + \"\"\"Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: \"\"\"\n        + question\n        + USER_SUFFIX\n    )\n\n    s += ASSISTANT_PREFIX\n    comps = call_generate(\n        s, max_tokens=256, temperature=temp, stop=None, n=num_branches\n    )\n    return [s + comp + ASSISTANT_SUFFIX for comp in comps]\n\n\ndef execute_plan(s, num_branches, call_generate):\n    s += (\n        USER_PREFIX\n        + \"\"\"The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. 
Give me the final answer. Make your response short.\"\"\"\n        + USER_SUFFIX\n    )\n    s += ASSISTANT_PREFIX\n    comps = call_generate(\n        s, max_tokens=256, temperature=temp, stop=None, n=num_branches\n    )\n    return [s + comp + ASSISTANT_SUFFIX for comp in comps]\n\n\ndef reflect_solution(s, num_branches, call_generate):\n    s += (\n        USER_PREFIX\n        + \"\"\"Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.\"\"\"\n        + USER_SUFFIX\n    )\n    s += ASSISTANT_PREFIX\n    comps = call_generate(\n        s, max_tokens=256, temperature=temp, stop=None, n=num_branches\n    )\n    return [s + comp + ASSISTANT_SUFFIX for comp in comps]\n\n\ndef tree_search(question, num_branches, call_generate):\n    s = \"\"\n    solutions = []\n\n    plan_forks = propose_plan(s, question, num_branches, call_generate)\n    for plan in plan_forks:\n        sol_forks = execute_plan(plan, num_branches, call_generate)\n        for sol in sol_forks:\n            score_forks = reflect_solution(sol, num_branches, call_generate)\n        solutions.append(sol_forks)\n\n    return solutions\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)\n\n    # Construct prompts\n    num_branches = 3\n    questions = []\n    labels = []\n    for i in range(len(lines[: args.num_questions])):\n        questions.append(lines[i][\"question\"])\n        labels.append(get_answer_value(lines[i][\"answer\"]))\n    assert all(l != INVALID for l in labels)\n    arguments = [{\"question\": q, \"num_branches\": num_branches} for q in questions]\n\n    # Select backend\n    call_generate = get_call_generate(args)\n\n    # Run requests\n    states = [None] * len(questions)\n\n    def get_one_answer(i):\n        states[i] = tree_search(**arguments[i], call_generate=call_generate)\n\n    tic = time.perf_counter()\n    if args.parallel == 1:\n        for i in tqdm(range(len(questions))):\n            
get_one_answer(i)\n    else:\n        with ThreadPoolExecutor(args.parallel) as executor:\n            list(\n                tqdm(\n                    executor.map(get_one_answer, list(range(len(questions)))),\n                    total=len(questions),\n                )\n            )\n\n    latency = time.perf_counter() - tic\n\n    answers_text = []\n    for s in states:\n        answers_text.append([x for xs in s for x in xs])\n\n    preds = []\n    for i in range(len(states)):\n        answers = [get_answer_value(v) for v in answers_text[i]]\n        preds.append(most_frequent_number(answers))\n\n    # Compute accuracy\n    acc = np.mean(np.array(preds) == np.array(labels))\n    invalid = np.mean(np.array(preds) == INVALID)\n    print(f\"Latency: {latency:.3f}\")\n    print(f\"Invalid: {invalid:.3f}\")\n    print(f\"Accuracy: {acc:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", answers_text)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"tree_of_thought_gsm8k\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"accuracy\": round(acc, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"test.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=200)\n    args = add_common_other_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "benchmark/tree_of_thought_v0/bench_sglang.py",
    "content": "import argparse\nimport ast\nimport json\nimport re\nimport time\nfrom collections import Counter\n\nimport numpy as np\n\nimport sglang as sgl\nfrom sglang.test.test_utils import (\n    add_common_sglang_args_and_parse,\n    select_sglang_backend,\n)\nfrom sglang.utils import dump_state_text, read_jsonl\n\nINVALID = -9999999\n\n\ndef get_answer_value(answer_str):\n    answer_str = answer_str.replace(\",\", \"\")\n    numbers = re.findall(r\"\\d+\", answer_str)\n    if len(numbers) < 1:\n        return INVALID\n    try:\n        return ast.literal_eval(numbers[-1])\n    except SyntaxError:\n        return INVALID\n\n\ndef most_frequent_number(numbers):\n    if not numbers:\n        return None\n\n    frequency = Counter(numbers)\n    most_frequent = max(frequency, key=frequency.get)\n    return most_frequent\n\n\n# Use a low temp to make the results more deterministic and the comparison more fair.\ntemp = 0.3\n\n\ndef propose_plan(s, question, num_branches):\n    s += sgl.user(\n        \"\"\"Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: \"\"\"\n        + question\n    )\n    forks = s.fork(num_branches)\n    forks += sgl.assistant(sgl.gen(\"plan\", max_tokens=256, temperature=temp))\n    return forks\n\n\ndef execute_plan(s, num_branches):\n    s += sgl.user(\n        \"\"\"The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.\"\"\"\n    )\n    forks = s.fork(num_branches)\n    forks += sgl.assistant(sgl.gen(\"answer\", max_tokens=256, temperature=temp))\n    return forks\n\n\ndef reflect_solution(s, num_branches):\n    s += sgl.user(\n        \"\"\"Okay. 
Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.\"\"\"\n    )\n    forks = s.fork(num_branches)\n    forks += sgl.assistant(sgl.gen(\"score\", max_tokens=256, temperature=temp))\n    return forks\n\n\n@sgl.function\ndef tree_search(s, question, num_branches):\n    forks_to_join = []\n\n    plan_forks = propose_plan(s, question, num_branches)\n    forks_to_join.append(plan_forks)\n\n    sol_states = []\n    for plan in plan_forks:\n        forks = execute_plan(plan, num_branches)\n        forks_to_join.append(forks)\n        sol_states.extend(forks)\n\n    for sol in sol_states:\n        forks = reflect_solution(sol, num_branches)\n        forks_to_join.append(forks)\n\n    for f in reversed(forks_to_join):\n        f.join()\n\n\ndef main(args):\n    lines = read_jsonl(args.data_path)\n\n    # Construct prompts\n    num_branches = 3\n    questions = []\n    labels = []\n    for i in range(len(lines[: args.num_questions])):\n        questions.append(lines[i][\"question\"])\n        labels.append(get_answer_value(lines[i][\"answer\"]))\n    assert all(l != INVALID for l in labels)\n    arguments = [{\"question\": q, \"num_branches\": num_branches} for q in questions]\n\n    # Select backend\n    backend = select_sglang_backend(args)\n\n    # Run requests\n    tic = time.perf_counter()\n    states = tree_search.run_batch(\n        arguments,\n        temperature=0,\n        backend=backend,\n        num_threads=args.parallel,\n        progress_bar=True,\n    )\n    latency = time.perf_counter() - tic\n    answers_text = []\n    for s in states:\n        answers_text.append([x for xs in s[\"answer\"] for x in xs])\n\n    preds = []\n    for i in range(len(states)):\n        answers = [get_answer_value(v) for v in answers_text[i]]\n        preds.append(most_frequent_number(answers))\n\n    # Compute accuracy\n    acc = np.mean(np.array(preds) == np.array(labels))\n    invalid = 
np.mean(np.array(preds) == INVALID)\n    print(f\"Latency: {latency:.3f}\")\n    print(f\"Invalid: {invalid:.3f}\")\n    print(f\"Accuracy: {acc:.3f}\")\n\n    # Write results\n    dump_state_text(f\"tmp_output_{args.backend}.txt\", answers_text)\n\n    with open(args.result_file, \"a\") as fout:\n        value = {\n            \"task\": \"tree_of_thought_gsm8k\",\n            \"backend\": args.backend,\n            \"num_gpus\": 1,\n            \"latency\": round(latency, 3),\n            \"accuracy\": round(acc, 3),\n            \"num_requests\": args.num_questions,\n            \"other\": {\n                \"num_questions\": args.num_questions,\n                \"parallel\": args.parallel,\n            },\n        }\n        fout.write(json.dumps(value) + \"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-path\", type=str, default=\"test.jsonl\")\n    parser.add_argument(\"--num-questions\", type=int, default=200)\n    args = add_common_sglang_args_and_parse(parser)\n    main(args)\n"
  },
  {
    "path": "docker/Dockerfile",
    "content": "ARG CUDA_VERSION=12.9.1\nFROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu24.04 AS base\n\nARG TARGETARCH\nARG BUILD_TYPE=all\nARG BRANCH_TYPE=remote\nARG GRACE_BLACKWELL=0\nARG HOPPER_SBO=0\n\nARG GRACE_BLACKWELL_DEEPEP_BRANCH=gb200_blog_part_2\nARG HOPPER_SBO_DEEPEP_COMMIT=9f2fc4b3182a51044ae7ecb6610f7c9c3258c4d6\nARG DEEPEP_COMMIT=9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee\nARG BUILD_AND_DOWNLOAD_PARALLEL=8\nARG SGL_KERNEL_VERSION=0.4.0\nARG SGL_VERSION\nARG USE_LATEST_SGLANG=0\nARG GDRCOPY_VERSION=2.5.1\nARG PIP_DEFAULT_INDEX\nARG UBUNTU_MIRROR\nARG GITHUB_ARTIFACTORY=github.com\nARG INSTALL_FLASHINFER_JIT_CACHE=0\nARG FLASHINFER_VERSION=0.6.6\nARG MOONCAKE_VERSION=0.3.9\n# If other Mooncake build options are needed, add them to MOONCAKE_COMPILE_ARG\nARG MOONCAKE_COMPILE_ARG=\"-DUSE_HTTP=ON -DUSE_MNNVL=ON -DUSE_CUDA=ON -DWITH_EP=ON\"\n\nENV DEBIAN_FRONTEND=noninteractive \\\n    CUDA_HOME=/usr/local/cuda \\\n    GDRCOPY_HOME=/usr/src/gdrdrv-${GDRCOPY_VERSION}/ \\\n    FLASHINFER_VERSION=${FLASHINFER_VERSION}\n\n# Add GKE default lib and bin locations\nENV PATH=\"${PATH}:/usr/local/nvidia/bin\" \\\n    LD_LIBRARY_PATH=\"${LD_LIBRARY_PATH}:/usr/local/nvidia/lib:/usr/local/nvidia/lib64\"\n\n# Replace Ubuntu sources if specified\nRUN if [ -n \"$UBUNTU_MIRROR\" ]; then \\\n    sed -i \"s|http://.*archive.ubuntu.com|$UBUNTU_MIRROR|g\" /etc/apt/sources.list && \\\n    sed -i \"s|http://.*security.ubuntu.com|$UBUNTU_MIRROR|g\" /etc/apt/sources.list; \\\nfi\n\n# Python setup (combined with apt update to reduce layers)\nRUN --mount=type=cache,target=/var/cache/apt,id=base-apt \\\n    apt update && apt install -y --no-install-recommends wget software-properties-common \\\n    && add-apt-repository ppa:deadsnakes/ppa -y \\\n    && apt install -y --no-install-recommends python3.12-full python3.12-dev python3.10-venv \\\n    && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 \\\n    && update-alternatives --install /usr/bin/python3 python3 
/usr/bin/python3.12 2 \\\n    && update-alternatives --set python3 /usr/bin/python3.12 \\\n    && wget -q https://bootstrap.pypa.io/get-pip.py \\\n    && python3 get-pip.py --break-system-packages \\\n    && rm get-pip.py \\\n    # Allow pip to install packages globally (PEP 668 workaround for Ubuntu 24.04)\n    && python3 -m pip config set global.break-system-packages true \\\n    # Fix for apt-add-repository\n    && cd /usr/lib/python3/dist-packages/ \\\n    && ln -s apt_pkg.cpython-310-*-linux-gnu.so apt_pkg.so\n\n# Install system dependencies (organized by category for better caching)\nRUN --mount=type=cache,target=/var/cache/apt,id=base-apt \\\n    apt-get update && apt-get install -y --no-install-recommends \\\n    # Core system utilities\n    ca-certificates \\\n    software-properties-common \\\n    netcat-openbsd \\\n    kmod \\\n    unzip \\\n    openssh-server \\\n    curl \\\n    wget \\\n    lsof \\\n    locales \\\n    # Build essentials (needed for framework stage)\n    build-essential \\\n    cmake \\\n    perl \\\n    patchelf \\\n    ccache \\\n    git-lfs \\\n    # MPI and NUMA\n    libopenmpi-dev \\\n    libnuma1 \\\n    libnuma-dev \\\n    numactl \\\n    # transformers multimodal VLM\n    ffmpeg \\\n    # InfiniBand/RDMA\n    libibverbs-dev \\\n    libibverbs1 \\\n    libibumad3 \\\n    librdmacm1 \\\n    libnl-3-200 \\\n    libnl-route-3-200 \\\n    libnl-route-3-dev \\\n    libnl-3-dev \\\n    ibverbs-providers \\\n    infiniband-diags \\\n    perftest \\\n    # Development libraries\n    libgoogle-glog-dev \\\n    libgtest-dev \\\n    libjsoncpp-dev \\\n    libunwind-dev \\\n    libboost-all-dev \\\n    libssl-dev \\\n    libgrpc-dev \\\n    libgrpc++-dev \\\n    libprotobuf-dev \\\n    protobuf-compiler \\\n    protobuf-compiler-grpc \\\n    pybind11-dev \\\n    libhiredis-dev \\\n    libcurl4-openssl-dev \\\n    libczmq4 \\\n    libczmq-dev \\\n    libfabric-dev \\\n    # Package building tools\n    devscripts \\\n    debhelper \\\n    
fakeroot \\\n    dkms \\\n    check \\\n    libsubunit0 \\\n    libsubunit-dev \\\n    && ln -sf /usr/bin/python3.12 /usr/bin/python \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && apt-get clean\n\n# Replace the pip global index URL if specified\nRUN if [ -n \"${PIP_DEFAULT_INDEX}\" ]; then \\\n    python3 -m pip config set global.index-url ${PIP_DEFAULT_INDEX}; \\\nfi\n\n# GDRCopy installation\nRUN mkdir -p /tmp/gdrcopy && cd /tmp \\\n    && curl --retry 3 --retry-delay 2 -fsSL -o v${GDRCOPY_VERSION}.tar.gz \\\n        https://${GITHUB_ARTIFACTORY}/NVIDIA/gdrcopy/archive/refs/tags/v${GDRCOPY_VERSION}.tar.gz \\\n    && tar -xzf v${GDRCOPY_VERSION}.tar.gz && rm v${GDRCOPY_VERSION}.tar.gz \\\n    && cd gdrcopy-${GDRCOPY_VERSION}/packages \\\n    && CUDA=/usr/local/cuda ./build-deb-packages.sh \\\n    && dpkg -i gdrdrv-dkms_*.deb libgdrapi_*.deb gdrcopy-tests_*.deb gdrcopy_*.deb \\\n    && cd / && rm -rf /tmp/gdrcopy\n\n# Fix DeepEP IBGDA symlink\nRUN ln -sf /usr/lib/$(uname -m)-linux-gnu/libmlx5.so.1 /usr/lib/$(uname -m)-linux-gnu/libmlx5.so\n\n# Set up locale\nRUN locale-gen en_US.UTF-8\nENV LANG=en_US.UTF-8 \\\n    LANGUAGE=en_US:en \\\n    LC_ALL=en_US.UTF-8\n\n########################################################\n########## Framework Development Image ################\n########################################################\n\n# Copy local source if building from local\nFROM scratch AS local_src\nCOPY . 
/src\n\nFROM base AS framework\n\nARG BRANCH_TYPE\nARG BUILD_TYPE\nARG CUDA_VERSION\nARG BUILD_AND_DOWNLOAD_PARALLEL\nARG SGL_KERNEL_VERSION\nARG SGL_VERSION\nARG USE_LATEST_SGLANG\nARG INSTALL_FLASHINFER_JIT_CACHE\nARG FLASHINFER_VERSION\nARG GRACE_BLACKWELL\nARG GRACE_BLACKWELL_DEEPEP_BRANCH\nARG DEEPEP_COMMIT\nARG TRITON_LANG_COMMIT\nARG GITHUB_ARTIFACTORY\n\nWORKDIR /sgl-workspace\n\n# Install SGLang\nCOPY --from=local_src /src /tmp/local_src\nRUN if [ \"$BRANCH_TYPE\" = \"local\" ]; then \\\n        cp -r /tmp/local_src /sgl-workspace/sglang; \\\n    elif [ \"$USE_LATEST_SGLANG\" = \"1\" ]; then \\\n        git clone --depth=1 https://github.com/sgl-project/sglang.git /sgl-workspace/sglang; \\\n    elif [ -z \"$SGL_VERSION\" ]; then \\\n        echo \"ERROR: SGL_VERSION must be set when USE_LATEST_SGLANG=0 and BRANCH_TYPE!=local\" && exit 1; \\\n    else \\\n        git clone --depth=1 --branch v${SGL_VERSION} https://github.com/sgl-project/sglang.git /sgl-workspace/sglang; \\\n    fi \\\n    && rm -rf /tmp/local_src\n\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    python3 -m pip install --upgrade pip setuptools wheel html5lib six \\\n    && cd sglang \\\n    && case \"$CUDA_VERSION\" in \\\n        12.6.1) CUINDEX=126 ;; \\\n        12.8.1) CUINDEX=128 ;; \\\n        12.9.1) CUINDEX=129 ;; \\\n        13.0.1) CUINDEX=130 ;; \\\n        *) echo \"Unsupported CUDA version: $CUDA_VERSION\" && exit 1 ;; \\\n    esac \\\n    && if [ \"$CUDA_VERSION\" = \"12.6.1\" ]; then \\\n        python3 -m pip install https://${GITHUB_ARTIFACTORY}/sgl-project/whl/releases/download/v${SGL_KERNEL_VERSION}/sglang_kernel-${SGL_KERNEL_VERSION}+cu124-cp310-abi3-manylinux2014_$(uname -m).whl --force-reinstall --no-deps \\\n    ; \\\n    elif [ \"$CUDA_VERSION\" = \"12.8.1\" ] || [ \"$CUDA_VERSION\" = \"12.9.1\" ]; then \\\n        python3 -m pip install sglang-kernel==${SGL_KERNEL_VERSION} \\\n    ; \\\n    elif [ \"$CUDA_VERSION\" = \"13.0.1\" ]; then \\\n        python3 -m 
pip install https://github.com/sgl-project/whl/releases/download/v${SGL_KERNEL_VERSION}/sglang_kernel-${SGL_KERNEL_VERSION}+cu130-cp310-abi3-manylinux2014_$(uname -m).whl --force-reinstall --no-deps \\\n    ; \\\n    else \\\n        echo \"Unsupported CUDA version: $CUDA_VERSION\" && exit 1 \\\n    ; \\\n    fi \\\n    && python3 -m pip install -e \"python[${BUILD_TYPE}]\" --extra-index-url https://download.pytorch.org/whl/cu${CUINDEX} \\\n    && if [ \"$INSTALL_FLASHINFER_JIT_CACHE\" = \"1\" ]; then \\\n        python3 -m pip install flashinfer-jit-cache==${FLASHINFER_VERSION} --index-url https://flashinfer.ai/whl/cu${CUINDEX} ; \\\n    fi \\\n    && FLASHINFER_CUBIN_DOWNLOAD_THREADS=${BUILD_AND_DOWNLOAD_PARALLEL} FLASHINFER_LOGGING_LEVEL=warning python3 -m flashinfer --download-cubin\n\n# DeepEP\n# We use Tom's DeepEP fork for GB200 for now; the 1fd57b0276311d035d16176bb0076426166e52f3 commit is https://github.com/fzyzcjy/DeepEP/tree/gb200_blog_part_2\n# TODO: move from Tom's branch to DeepEP hybrid-ep branch\n# We use the nvshmem version that ships with torch 2.9.1\n# CU12 uses 3.3.20 and CU13 uses 3.3.24\nRUN set -eux; \\\n    if [ \"$GRACE_BLACKWELL\" = \"1\" ]; then \\\n      git clone https://github.com/fzyzcjy/DeepEP.git && \\\n      cd DeepEP && \\\n      git checkout ${GRACE_BLACKWELL_DEEPEP_BRANCH} && \\\n      sed -i 's/#define NUM_CPU_TIMEOUT_SECS 100/#define NUM_CPU_TIMEOUT_SECS 1000/' csrc/kernels/configs.cuh && \\\n      sed -i 's/#define NUM_TIMEOUT_CYCLES 200000000000ull/#define NUM_TIMEOUT_CYCLES 2000000000000ull/' csrc/kernels/configs.cuh && \\\n      cd .. 
; \\\n    elif [ \"$HOPPER_SBO\" = \"1\" ]; then \\\n      git clone https://github.com/deepseek-ai/DeepEP.git -b antgroup-opt && \\\n      cd DeepEP && \\\n      git checkout ${HOPPER_SBO_DEEPEP_COMMIT} && \\\n      sed -i 's/#define NUM_CPU_TIMEOUT_SECS 100/#define NUM_CPU_TIMEOUT_SECS 1000/' csrc/kernels/configs.cuh && \\\n      sed -i 's/#define NUM_TIMEOUT_CYCLES 200000000000ull/#define NUM_TIMEOUT_CYCLES 2000000000000ull/' csrc/kernels/configs.cuh && \\\n      cd .. ; \\\n    else \\\n        curl --retry 3 --retry-delay 2 -fsSL -o ${DEEPEP_COMMIT}.zip \\\n            https://${GITHUB_ARTIFACTORY}/deepseek-ai/DeepEP/archive/${DEEPEP_COMMIT}.zip && \\\n        unzip -q ${DEEPEP_COMMIT}.zip && rm ${DEEPEP_COMMIT}.zip && mv DeepEP-${DEEPEP_COMMIT} DeepEP && cd DeepEP && \\\n        sed -i 's/#define NUM_CPU_TIMEOUT_SECS 100/#define NUM_CPU_TIMEOUT_SECS 1000/' csrc/kernels/configs.cuh && \\\n        sed -i 's/#define NUM_TIMEOUT_CYCLES 200000000000ull/#define NUM_TIMEOUT_CYCLES 2000000000000ull/' csrc/kernels/configs.cuh && \\\n        cd .. 
; \\\n    fi\n\n# Install DeepEP\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    cd /sgl-workspace/DeepEP && \\\n    case \"$CUDA_VERSION\" in \\\n        12.6.1) \\\n            CHOSEN_TORCH_CUDA_ARCH_LIST='9.0' \\\n            ;; \\\n        12.8.1) \\\n            # FIXED: 12.8.1 does NOT support Blackwell 10.3 \\\n            CHOSEN_TORCH_CUDA_ARCH_LIST='9.0;10.0' \\\n            ;; \\\n        12.9.1|13.0.1) \\\n            # 12.9.1+ properly supports Blackwell 10.3 \\\n            CHOSEN_TORCH_CUDA_ARCH_LIST='9.0;10.0;10.3' \\\n            ;; \\\n        *) \\\n            echo \"Unsupported CUDA version: $CUDA_VERSION\" && exit 1 \\\n            ;; \\\n    esac && \\\n    if [ \"${CUDA_VERSION%%.*}\" = \"13\" ]; then \\\n        sed -i \"/^    include_dirs = \\['csrc\\/'\\]/a\\    include_dirs.append('${CUDA_HOME}/include/cccl')\" setup.py; \\\n    fi && \\\n    TORCH_CUDA_ARCH_LIST=\"${CHOSEN_TORCH_CUDA_ARCH_LIST}\" MAX_JOBS=${BUILD_AND_DOWNLOAD_PARALLEL} pip install --no-build-isolation .\n\n# Install Mooncake\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    CUDA_MAJOR=\"${CUDA_VERSION%%.*}\" && \\\n    if [ \"$CUDA_MAJOR\" -ge 13 ]; then \\\n        echo \"CUDA >= 13, installing mooncake-transfer-engine from source code\"; \\\n        git clone --branch v${MOONCAKE_VERSION} --depth 1 https://github.com/kvcache-ai/Mooncake.git && \\\n        cd Mooncake && \\\n        bash dependencies.sh && \\\n        mkdir -p build && \\\n        cd build && \\\n        cmake .. 
${MOONCAKE_COMPILE_ARG} && \\\n        make -j$(nproc) && \\\n        make install; \\\n    else \\\n        echo \"CUDA < 13, installing mooncake-transfer-engine from pip\"; \\\n        python3 -m pip install mooncake-transfer-engine==${MOONCAKE_VERSION}; \\\n    fi\n# Install essential Python packages\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    python3 -m pip install \\\n    datamodel_code_generator \\\n    pre-commit \\\n    pytest \\\n    black \\\n    isort \\\n    icdiff \\\n    uv \\\n    wheel \\\n    scikit-build-core \\\n    nixl \\\n    py-spy \\\n    cubloaty \\\n    google-cloud-storage\n\n# Build and install sgl-model-gateway (install Rust, build, then remove to save space)\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    curl --proto '=https' --tlsv1.2 --retry 3 --retry-delay 2 -sSf https://sh.rustup.rs | sh -s -- -y \\\n    && export PATH=\"/root/.cargo/bin:${PATH}\" \\\n    && rustc --version && cargo --version \\\n    && python3 -m pip install maturin \\\n    && cd /sgl-workspace/sglang/sgl-model-gateway/bindings/python \\\n    && ulimit -n 65536 && maturin build --release --features vendored-openssl --out dist \\\n    && python3 -m pip install --force-reinstall dist/*.whl \\\n    && cd /sgl-workspace/sglang/sgl-model-gateway \\\n    && cargo build --release --bin sglang-router --features vendored-openssl \\\n    && cp target/release/sglang-router /usr/local/bin/sglang-router \\\n    && rm -rf /root/.cargo /root/.rustup target dist ~/.cargo \\\n    && sed -i '/\\.cargo\\/env/d' /root/.profile /root/.bashrc 2>/dev/null || true\n\nRUN --mount=type=cache,target=/root/.cache/pip \\\n   python3 -m pip install \"nvidia-cutlass-dsl>=4.4.1\" \"nvidia-cutlass-dsl-libs-base>=4.4.1\" --force-reinstall --no-deps;\n\n# Patching packages for CUDA 12/13 compatibility\n# TODO: Remove when torch version covers these packages\nRUN --mount=type=cache,target=/root/.cache/pip if [ \"${CUDA_VERSION%%.*}\" = \"12\" ]; then \\\n    python3 -m pip 
install nvidia-nccl-cu12==2.28.3 --force-reinstall --no-deps ; \\\n    python3 -m pip install nvidia-cudnn-cu12==9.16.0.29 --force-reinstall --no-deps ; \\\n    python3 -m pip install cuda-python==12.9 ; \\\nelif [ \"${CUDA_VERSION%%.*}\" = \"13\" ]; then \\\n    python3 -m pip install nvidia-nccl-cu13==2.28.3 --force-reinstall --no-deps ; \\\n    python3 -m pip install nvidia-cudnn-cu13==9.16.0.29 --force-reinstall --no-deps ; \\\n    python3 -m pip install nvidia-cublas==13.1.0.3 --force-reinstall --no-deps ; \\\n    python3 -m pip install nixl-cu13 --no-deps ; \\\n    python3 -m pip install cuda-python==13.2.0 ; \\\nfi\n\n# Install development tools\nRUN --mount=type=cache,target=/var/cache/apt,id=framework-apt \\\n    apt-get update && apt-get install -y --no-install-recommends \\\n    gdb \\\n    ninja-build \\\n    vim \\\n    tmux \\\n    htop \\\n    zsh \\\n    tree \\\n    silversearcher-ag \\\n    cloc \\\n    pkg-config \\\n    bear \\\n    less \\\n    rdma-core \\\n    openssh-server \\\n    gnuplot \\\n    infiniband-diags \\\n    perftest \\\n    ibverbs-providers \\\n    libibumad3 \\\n    libibverbs1 \\\n    libnl-3-200 \\\n    libnl-route-3-200 \\\n    librdmacm1 \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && apt-get clean\n\n# Install NVIDIA development tools\nRUN --mount=type=cache,target=/var/cache/apt,id=framework-apt \\\n    apt update -y \\\n    && apt install -y --no-install-recommends gnupg \\\n    && echo \"deb http://developer.download.nvidia.com/devtools/repos/ubuntu2004/$(if [ \"$(uname -m)\" = \"aarch64\" ]; then echo \"arm64\"; else echo \"amd64\"; fi) /\" | tee /etc/apt/sources.list.d/nvidia-devtools.list \\\n    && apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/$(if [ \"$(uname -m)\" = \"aarch64\" ]; then echo \"arm64\"; else echo \"x86_64\"; fi)/7fa2af80.pub \\\n    && apt update -y \\\n    && apt install -y --no-install-recommends nsight-systems-cli \\\n    && rm -rf 
/var/lib/apt/lists/*\n\n# Install minimal Python dev packages\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    python3 -m pip install --break-system-packages \\\n    pytest \\\n    black \\\n    isort \\\n    icdiff \\\n    scikit-build-core \\\n    uv \\\n    pre-commit \\\n    pandas \\\n    matplotlib \\\n    tabulate \\\n    termplotlib\n\n# diff-so-fancy\nRUN curl --retry 3 --retry-delay 2 -LSso /usr/local/bin/diff-so-fancy \\\n        https://${GITHUB_ARTIFACTORY}/so-fancy/diff-so-fancy/releases/download/v1.4.4/diff-so-fancy \\\n    && chmod +x /usr/local/bin/diff-so-fancy\n\n# clang-format\nRUN curl --retry 3 --retry-delay 2 -LSso /usr/local/bin/clang-format \\\n        https://${GITHUB_ARTIFACTORY}/muttleyxd/clang-tools-static-binaries/releases/download/master-32d3ac78/clang-format-16_linux-amd64 \\\n    && chmod +x /usr/local/bin/clang-format\n\n# clangd\nRUN curl --retry 3 --retry-delay 2 -fsSL -o clangd.zip \\\n        https://${GITHUB_ARTIFACTORY}/clangd/clangd/releases/download/18.1.3/clangd-linux-18.1.3.zip \\\n    && unzip -q clangd.zip \\\n    && cp -r clangd_18.1.3/bin/* /usr/local/bin/ \\\n    && cp -r clangd_18.1.3/lib/* /usr/local/lib/ \\\n    && rm -rf clangd_18.1.3 clangd.zip\n\n# CMake\nRUN CMAKE_VERSION=3.31.1 \\\n    && ARCH=$(uname -m) \\\n    && CMAKE_INSTALLER=\"cmake-${CMAKE_VERSION}-linux-${ARCH}\" \\\n    && curl --retry 3 --retry-delay 2 -fsSL -o \"${CMAKE_INSTALLER}.tar.gz\" \\\n        \"https://${GITHUB_ARTIFACTORY}/Kitware/CMake/releases/download/v${CMAKE_VERSION}/${CMAKE_INSTALLER}.tar.gz\" \\\n    && tar -xzf \"${CMAKE_INSTALLER}.tar.gz\" \\\n    && cp -r \"${CMAKE_INSTALLER}/bin/\"* /usr/local/bin/ \\\n    && cp -r \"${CMAKE_INSTALLER}/share/\"* /usr/local/share/ \\\n    && rm -rf \"${CMAKE_INSTALLER}\" \"${CMAKE_INSTALLER}.tar.gz\"\n\n# Install just\nRUN curl --proto '=https' --tlsv1.2 --retry 3 --retry-delay 2 -sSf https://just.systems/install.sh | \\\n    sed \"s|https://github.com|https://${GITHUB_ARTIFACTORY}|g\" | 
\\\n    bash -s -- --tag 1.42.4 --to /usr/local/bin\n\n# Add yank script\nCOPY --chown=root:root --chmod=755 docker/configs/yank /usr/local/bin/yank\n\n# Install oh-my-zsh and plugins\nRUN sh -c \"$(curl --retry 3 --retry-delay 2 -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)\" \"\" --unattended \\\n    && git clone --depth 1 https://github.com/zsh-users/zsh-autosuggestions ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-autosuggestions \\\n    && git clone --depth 1 https://github.com/zsh-users/zsh-syntax-highlighting.git ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-syntax-highlighting\n\n# These configs are optional; users can override them by mounting their own files\nCOPY docker/configs/opt/.vimrc /opt/sglang/.vimrc\nCOPY docker/configs/opt/.tmux.conf /opt/sglang/.tmux.conf\nCOPY docker/configs/opt/.gitconfig /opt/sglang/.gitconfig\n\n# Configure development environment\nCOPY docker/configs/.zshrc /root/.zshrc\n\n# Fix Triton to use system ptxas for Blackwell (sm_103a) support (CUDA 13+ only)\nRUN if [ \"${CUDA_VERSION%%.*}\" = \"13\" ] && [ -d /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin ]; then \\\n        rm -f /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas && \\\n        ln -s /usr/local/cuda/bin/ptxas /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas; \\\n    fi\n\nRUN python3 -m pip install --upgrade \"urllib3>=2.6.3\"\n\n# Set workspace directory\nWORKDIR /sgl-workspace/sglang\n\n########################################################\n########## Runtime Image ##############################\n########################################################\n#\n# PURPOSE: Production runtime environment with JIT support\n#\n# This stage creates a production-ready image containing:\n# - Pre-compiled SGLang and DeepEP components\n# - Full CUDA toolchain for JIT compilation (DeepGEMM, Triton, FlashInfer)\n# - Optimized for inference workloads and deployment\n# - 
Smaller than framework (no dev tools like vim, tmux, nsight, etc.)\n#\n# Use this stage when you need:\n# - Production deployment of SGLang\n# - JIT compilation support for FP8/microscaling kernels\n# - Ready-to-run inference server environment\n#\n# Note: Uses devel base for complete NVCC toolchain required by DeepGEMM JIT\nFROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu24.04 AS runtime\n\nARG CUDA_VERSION\nARG TARGETARCH\nARG GDRCOPY_VERSION=2.5.1\n\nENV DEBIAN_FRONTEND=noninteractive \\\n    CUDA_HOME=/usr/local/cuda \\\n    GDRCOPY_HOME=/usr/src/gdrdrv-${GDRCOPY_VERSION}/\n\n# Add GKE default lib and bin locations + CUDA compiler paths for FlashInfer JIT\nENV PATH=\"${PATH}:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/cuda/nvvm/bin\" \\\n    LD_LIBRARY_PATH=\"${LD_LIBRARY_PATH}:/usr/local/nvidia/lib:/usr/local/nvidia/lib64\"\n\n# Install runtime dependencies (devel base provides gcc/g++/build tools)\nRUN --mount=type=cache,target=/var/cache/apt,id=runtime-apt \\\n    apt-get update && apt-get install -y --no-install-recommends \\\n    # Python runtime\n    software-properties-common \\\n    && add-apt-repository ppa:deadsnakes/ppa -y \\\n    && apt-get update && apt-get install -y --no-install-recommends --allow-change-held-packages \\\n    python3.12-full \\\n    python3.12-dev \\\n    wget \\\n    # Core system utilities\n    ca-certificates \\\n    netcat-openbsd \\\n    curl \\\n    git \\\n    # Runtime libraries\n    libopenmpi3 \\\n    libnuma1 \\\n    libibverbs1 \\\n    libibumad3 \\\n    librdmacm1 \\\n    libnl-3-200 \\\n    libnl-route-3-200 \\\n    ibverbs-providers \\\n    libgoogle-glog0v6t64 \\\n    libunwind8 \\\n    libboost-system1.83.0 \\\n    libboost-thread1.83.0 \\\n    libboost-filesystem1.83.0 \\\n    libgrpc++1.51t64 \\\n    libprotobuf32t64 \\\n    libhiredis1.1.0 \\\n    libcurl4 \\\n    libczmq4 \\\n    libfabric1 \\\n    libssl3 \\\n    # RDMA runtime\n    rdma-core \\\n    infiniband-diags \\\n    perftest \\\n    # 
Build tools for JIT compilation\n    ninja-build \\\n    # NCCL packages needed for pynccl_allocator JIT compilation (-lnccl)\n    libnccl2 \\\n    libnccl-dev \\\n    # GPG key verification\n    gnupg2 \\\n    && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 2 \\\n    && update-alternatives --set python3 /usr/bin/python3.12 \\\n    && ln -sf /usr/bin/python3.12 /usr/bin/python \\\n    && wget -q https://bootstrap.pypa.io/get-pip.py \\\n    && python3 get-pip.py --break-system-packages \\\n    && rm get-pip.py \\\n    # Allow pip to install packages globally (PEP 668 workaround for Ubuntu 24.04)\n    && python3 -m pip config set global.break-system-packages true \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && apt-get clean\n\n# Set up locale\nRUN apt-get update && apt-get install -y --no-install-recommends locales \\\n    && locale-gen en_US.UTF-8 \\\n    && rm -rf /var/lib/apt/lists/*\n\nENV LANG=en_US.UTF-8 \\\n    LANGUAGE=en_US:en \\\n    LC_ALL=en_US.UTF-8\n\n# Copy Python site-packages from framework (contains all built packages)\nCOPY --from=framework /usr/local/lib/python3.12/dist-packages /usr/local/lib/python3.12/dist-packages\n\n# Copy SGLang workspace\nCOPY --from=framework /sgl-workspace /sgl-workspace\n\n# Fix Triton to use system ptxas for Blackwell (sm_103a) support (CUDA 13+ only)\nRUN if [ \"${CUDA_VERSION%%.*}\" = \"13\" ] && [ -d /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin ]; then \\\n        rm -f /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas && \\\n        ln -s /usr/local/cuda/bin/ptxas /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas; \\\n    fi\n\n# Copy GDRCopy runtime libraries (but not the build artifacts)\nCOPY --from=framework /usr/lib/libgdrapi.so* /usr/lib/\nCOPY --from=framework /usr/bin/gdrcopy_* /usr/bin/\nCOPY --from=framework /usr/src/gdrdrv-2.5.1 /usr/src/gdrdrv-2.5.1\n\n# Fix DeepEP IBGDA symlink in runtime\nRUN ln -sf 
/usr/lib/$(uname -m)-linux-gnu/libmlx5.so.1 /usr/lib/$(uname -m)-linux-gnu/libmlx5.so\n\nWORKDIR /sgl-workspace/sglang\n\n# Default command\nCMD [\"/bin/bash\"]\n"
  },
  {
    "path": "docker/compose.yaml",
    "content": "services:\n  sglang:\n    image: lmsysorg/sglang:latest\n    container_name: sglang\n    volumes:\n      - ${HOME}/.cache/huggingface:/root/.cache/huggingface\n      # If you use ModelScope, mount this directory as well\n      # - ${HOME}/.cache/modelscope:/root/.cache/modelscope\n    restart: always\n    network_mode: host # required by RDMA\n    privileged: true # required by RDMA\n    # Alternatively, publish only port 30000\n    # ports:\n    #   - 30000:30000\n    environment:\n      - HF_TOKEN=<secret>\n      # If you use ModelScope to download models, set this environment variable\n      # - SGLANG_USE_MODELSCOPE=true\n    entrypoint: python3 -m sglang.launch_server\n    command: --model-path meta-llama/Llama-3.1-8B-Instruct\n      --host 0.0.0.0\n      --port 30000\n    ulimits:\n      memlock: -1\n      stack: 67108864\n    ipc: host\n    healthcheck:\n      test: [\"CMD-SHELL\", \"curl -f http://localhost:30000/health || exit 1\"]\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: 1\n              capabilities: [gpu]\n"
  },
  {
    "path": "docker/configs/.zshrc",
    "content": "export ZSH=\"/root/.oh-my-zsh\"\n\n# Theme\nZSH_THEME=\"robbyrussell\"\n\n# Plugins\nplugins=(\n    git\n    z\n    zsh-autosuggestions\n    zsh-syntax-highlighting\n)\n\nsource $ZSH/oh-my-zsh.sh\n\n# Aliases\nalias ll='ls -alF'\nalias la='ls -A'\nalias l='ls -CF'\nalias vi='vim'\n\n# Enhanced history\nHISTSIZE=10000\nSAVEHIST=10000\nsetopt HIST_IGNORE_ALL_DUPS\nsetopt HIST_FIND_NO_DUPS\nsetopt INC_APPEND_HISTORY\n"
  },
  {
    "path": "docker/configs/opt/.gitconfig",
    "content": "[core]\n\teditor = vim\n\twhitespace = fix,-indent-with-non-tab,trailing-space,cr-at-eol\n\tpager = diff-so-fancy | less --tabs=4 -RFX\n\n[color]\n\tui = true\n\n[color \"diff-highlight\"]\n\toldNormal = red bold\n\toldHighlight = red bold 52\n\tnewNormal = green bold\n\tnewHighlight = green bold 22\n\n[color \"diff\"]\n\tmeta = 11\n\tfrag = magenta bold\n\tcommit = yellow bold\n\told = red bold\n\tnew = green bold\n\twhitespace = red reverse\n\n[alias]\n\tlg = log --color --graph --pretty=format:'%Cred%h%Creset - %s %Cgreen(%cr) %C(bold blue)<%an>%Creset%C(auto)%d%Creset' --abbrev-commit --\n\n[http]\n\tsslVerify = false\n\n[pull]\n\trebase = true\n"
  },
  {
    "path": "docker/configs/opt/.tmux.conf",
    "content": "# Pane border styling\nset -g pane-border-style fg='#742727',bg=black\nset -g pane-active-border-style fg=red,bg=black\n\n# Status bar styling\nset -g status-style bg='#0C8A92',fg=black\n\n# Change prefix key to backtick\nset-option -g prefix `\nunbind C-b\nbind-key ` send-prefix\n\n# Split panes using - and = with current path\nunbind '\"'\nbind - splitw -v -c '#{pane_current_path}'\nunbind '%'\nbind = splitw -h -c '#{pane_current_path}'\n\n# Vi mode settings\nbind-key -T copy-mode-vi Y send-keys -X copy-pipe 'yank > #{pane_tty}'\nset-window-option -g mode-keys vi\n\n# Other settings\nset-option -g escape-time 0\nset-option -g base-index 1\nset-window-option -g mouse on\nset -g history-limit 100000\n"
  },
  {
    "path": "docker/configs/opt/.vimrc",
    "content": "function! Yank(text) abort\n  let escape = system('yank', a:text)\n  if v:shell_error\n    echoerr escape\n  else\n    call writefile([escape], '/dev/tty', 'b')\n  endif\nendfunction\n\nnoremap <silent> <Leader>y y:<C-U>call Yank(@0)<CR>\n\n\" automatically run yank(1) whenever yanking in Vim\nfunction! CopyYank() abort\n  call Yank(join(v:event.regcontents, \"\\n\"))\nendfunction\n\nautocmd TextYankPost * call CopyYank()\n\n\" Basic settings\nset number\nsyntax on\nset mouse=a\nfiletype indent on\n\n\" Indentation\nset autoindent nosmartindent\nset smarttab\nset expandtab\nset shiftwidth=4\nset softtabstop=4\n\n\" Visual guides\nset colorcolumn=120\nhighlight ColorColumn ctermbg=5\n\n\" Status line\nset laststatus=2\nset statusline=%<%f\\ %h%m%r%=%{\\\"[\\\".(&fenc==\\\"\\\"?&enc:&fenc).((exists(\\\"+bomb\\\")\\ &&\\ &bomb)?\\\",B\\\":\\\"\\\").\\\"]\\ \\\"}%k\\ %-14.(%l,%c%V%)\\ %P\n\n\" Backspace behavior\nset backspace=2\n\n\" Encoding\nset encoding=utf-8\nset fileencoding=utf-8\n"
  },
  {
    "path": "docker/configs/yank",
    "content": "#!/bin/bash\nput() {\n  esc=$1\n  test -n \"$TMUX\" -o -z \"${TERM##screen*}\" && esc=\"\\033Ptmux;\\033$esc\\033\\\\\"\n  printf \"$esc\"\n}\nput \"\\033]52;c;!\\a\"\nbuf=$( cat \"$@\" )\nlen=$( printf %s \"$buf\" | wc -c ) max=74994\ntest $len -gt $max && echo \"$0: input is $(( len - max )) bytes too long\" >&2\nput \"\\033]52;c;$( printf %s \"$buf\" | head -c $max | base64 | tr -d '\\r\\n' )\\a\"\ntest -n \"$TMUX\" && tmux set-buffer \"$buf\" ||:\n"
  },
  {
    "path": "docker/diffusion.Dockerfile",
    "content": "FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04\n\nENV DEBIAN_FRONTEND=noninteractive\n\nSHELL [\"/bin/bash\", \"-c\"]\n\nWORKDIR /sgl-workspace/sglang\n\nRUN apt-get update && apt-get install -y --no-install-recommends \\\n    wget \\\n    git \\\n    ca-certificates \\\n    openssh-server \\\n    zsh \\\n    vim \\\n    curl \\\n    gcc-11 \\\n    g++-11 \\\n    clang-11 \\\n    libnuma1 libnuma-dev \\\n    && rm -rf /var/lib/apt/lists/*\n\n# Install oh-my-zsh and plugins\nRUN sh -c \"$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)\" \"\" --unattended \\\n    && git clone https://github.com/zsh-users/zsh-autosuggestions ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-autosuggestions \\\n    && git clone https://github.com/zsh-users/zsh-syntax-highlighting.git ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-syntax-highlighting\n\n\n# Set up C++20 compilers for ThunderKittens\nRUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 100 --slave /usr/bin/g++ g++ /usr/bin/g++-11\n\n# Set CUDA environment variables\nENV CUDA_HOME=/usr/local/cuda-12.8\nENV PATH=${CUDA_HOME}/bin:${PATH}\nENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:$LD_LIBRARY_PATH\n\n# Install uv and source its environment\nRUN curl -LsSf https://astral.sh/uv/install.sh | sh && \\\n    echo 'source $HOME/.local/bin/env' >> /root/.zshrc\n\n# Copy just the pyproject.toml first to leverage Docker cache\nCOPY python/pyproject.toml python/\n\n# Create a dummy README to satisfy the installation\nRUN mkdir -p python && echo \"# Placeholder\" > python/README.md\n\n# Create and activate virtual environment with specific Python version and seed\nRUN source $HOME/.local/bin/env && \\\n    uv venv --python 3.12 --seed /opt/venv && \\\n    source /opt/venv/bin/activate && \\\n    uv pip install nvitop && \\\n    uv pip install --no-cache-dir --upgrade pip && \\\n    uv pip install --no-cache-dir --prerelease=allow ./python[diffusion]\n\nCOPY . 
.\n\n# Drop any GitHub extraheader from the git config and set up shell startup files to activate the venv\nRUN source $HOME/.local/bin/env && \\\n    source /opt/venv/bin/activate && \\\n    git config --unset-all http.https://github.com/.extraheader || true && \\\n    echo 'source /opt/venv/bin/activate' >> /root/.zshrc && \\\n    echo 'if [ -n \"$ZSH_VERSION\" ] && [ -f ~/.zshrc ]; then . ~/.zshrc; elif [ -f ~/.bashrc ]; then . ~/.bashrc; fi' > /root/.profile\n\n# Set PATH to include venv bin\nENV PATH=/opt/venv/bin:$PATH\n\n# Configure zsh\nCOPY --chown=root:root <<-\"EOF\" /root/.zshrc\nexport ZSH=\"/root/.oh-my-zsh\"\n\nsource $HOME/.local/bin/env\nsource /opt/venv/bin/activate\n\n## Theme\nZSH_THEME=\"robbyrussell\"\n\n## Plugins\nplugins=(\n    git\n    z\n    zsh-autosuggestions\n    zsh-syntax-highlighting\n)\n\nsource $ZSH/oh-my-zsh.sh\n\n## Aliases\nalias ll='ls -alF'\nalias la='ls -A'\nalias l='ls -CF'\nalias vi='vim'\n\n## Enhanced history\nHISTSIZE=10000\nSAVEHIST=10000\nsetopt HIST_IGNORE_ALL_DUPS\nsetopt HIST_FIND_NO_DUPS\nsetopt INC_APPEND_HISTORY\nEOF\n\n\nEXPOSE 22\n\nCMD [\"/bin/zsh\"]\n"
  },
  {
    "path": "docker/gateway.Dockerfile",
    "content": "######################## BASE IMAGE ##########################\nFROM ubuntu:24.04 AS base\n\nARG PYTHON_VERSION=3.12\n\n# set the environment variables\nENV PATH=\"/root/.local/bin:${PATH}\"\nENV DEBIAN_FRONTEND=noninteractive\n\n# uv environment variables\nENV UV_HTTP_TIMEOUT=500\nENV VIRTUAL_ENV=\"/opt/venv\"\nENV UV_PYTHON_INSTALL_DIR=/opt/uv/python\nENV UV_LINK_MODE=\"copy\"\nENV PATH=\"$VIRTUAL_ENV/bin:$PATH\"\n\n\n# install dependencies\nRUN apt update -y \\\n    && apt install -y curl \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && apt clean\n\n# install uv\nRUN curl -LsSf https://astral.sh/uv/install.sh | sh\n\n# install python\nRUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}\n\nFROM scratch AS local_src\nCOPY . /src\n\n######################### BUILD IMAGE #########################\nFROM base AS build-image\n\n# set the environment variables\nENV PATH=\"/root/.cargo/bin:${PATH}\"\n\n# install dependencies\nRUN apt update -y \\\n    && apt install -y git build-essential libssl-dev pkg-config protobuf-compiler \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && apt clean\n\n# install rustup from rustup.rs\nRUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \\\n    && rustc --version && cargo --version && protoc --version\n\n# copy source code\nCOPY --from=local_src /src /opt/sglang\n\n# working directory\nWORKDIR /opt/sglang/sgl-model-gateway\n\n# install maturin and build the wheel with vendored OpenSSL\nRUN uv pip install maturin \\\n    && cargo clean \\\n    && rm -rf bindings/python/dist/ \\\n    && cd bindings/python \\\n    && ulimit -n 65536 && maturin build --release --features vendored-openssl --out dist \\\n    && rm -rf /root/.cache\n\n######################### ROUTER IMAGE #########################\nFROM base AS router-image\n\n# Copy the built package from the build image\nCOPY --from=build-image /opt/sglang/sgl-model-gateway/bindings/python/dist/*.whl dist/\n\n# Build the package and 
install\nRUN uv pip install --force-reinstall dist/*.whl\n\n# Clean up unnecessary files to reduce the image size\nRUN rm -rf /root/.cache dist/ \\\n    && apt purge -y --auto-remove curl\n\n# Set the entrypoint to the main command\nENTRYPOINT [\"python3\", \"-m\", \"sglang_router.launch_router\"]\n"
  },
  {
    "path": "docker/k8s-sglang-distributed-sts.yaml",
    "content": "# Two-node SGLang example\n\napiVersion: apps/v1\nkind: StatefulSet\nmetadata:\n  name: distributed-sglang\nspec:\n  replicas: 2   # number of nodes/pods running distributed SGLang\n  selector:\n    matchLabels:\n      app: distributed-sglang\n  serviceName: \"\"\n  template:\n    metadata:\n      labels:\n        app: distributed-sglang\n    spec:\n      containers:\n      - name: sglang-container\n        image: docker.io/lmsysorg/sglang:latest\n        imagePullPolicy: Always # image may be replaced by official CI versioned image\n        command:\n        - /bin/bash\n        - -c\n        # Modify the SGLang serving arguments below as needed.\n        # NOTE: --expert-parallel-size only applies to MoE models such as DeepSeek-R1\n        args:\n        - |\n          python3 -m sglang.launch_server \\\n          --model /llm-folder \\\n          --dist-init-addr sglang-master-pod:5000 \\\n          --tensor-parallel-size 16 \\\n          --nnodes 2 \\\n          --node-rank $POD_INDEX \\\n          --trust-remote-code \\\n          --host 0.0.0.0 \\\n          --port 8000 \\\n          --enable-metrics \\\n          --expert-parallel-size 16\n        env:\n        - name: POD_INDEX     # reflects the node-rank\n          valueFrom:\n            fieldRef:\n              apiVersion: v1\n              fieldPath: metadata.labels['apps.kubernetes.io/pod-index']\n        - name: NCCL_DEBUG\n          value: INFO\n        resources:\n          limits:\n            nvidia.com/gpu: \"8\"\n          requests:\n        volumeMounts:\n        - mountPath: /dev/shm\n          name: dshm\n        - mountPath: /llm-folder\n          name: llm\n        securityContext:\n          privileged: true   # required for RDMA/InfiniBand devices; used together with hostNetwork: true\n      hostNetwork: true\n      volumes:\n      - emptyDir:\n          medium: Memory\n          sizeLimit: 10Gi\n        name: dshm\n      - hostPath:\n          path: /llm-folder # replace with PVC or hostPath with your model weights\n          type: DirectoryOrCreate\n        name: llm\n      #- persistentVolumeClaim:\n      #    claimName: llm-pvc\n      #  name: llm\n---\napiVersion: v1\nkind: Service\nmetadata:\n  name: sglang-master-pod\nspec:\n  type: ClusterIP\n  selector:\n    app: distributed-sglang\n    apps.kubernetes.io/pod-index: \"0\"\n  ports:\n  - name: dist-port\n    port: 5000\n    targetPort: 5000\n---\n# Service exposing the serving endpoint of the rank-0 pod\napiVersion: v1\nkind: Service\nmetadata:\n  name: sglang-serving-on-master\nspec:\n  type: NodePort\n  selector:\n    app: distributed-sglang\n    apps.kubernetes.io/pod-index: \"0\"\n  ports:\n  - name: serving\n    port: 8000\n    targetPort: 8000\n  - name: metrics\n    port: 8080\n    targetPort: 8080\n"
  },
  {
    "path": "docker/k8s-sglang-service.yaml",
    "content": "apiVersion: v1\nkind: PersistentVolumeClaim\nmetadata:\n  name: llama-31-8b-sglang\nspec:\n  accessModes:\n    - ReadWriteMany\n  resources:\n    requests:\n      storage: 30Gi\n  storageClassName: default # change this to your preferred storage class\n  volumeMode: Filesystem\n---\napiVersion: node.k8s.io/v1\nkind: RuntimeClass\nmetadata:\n  name: nvidia\nhandler: nvidia\n---\napiVersion: apps/v1\nkind: Deployment\nmetadata:\n  name: meta-llama-31-8b-instruct-sglang\nspec:\n  replicas: 1\n  strategy:\n    type: Recreate\n  selector:\n    matchLabels:\n      app: meta-llama-31-8b-instruct-sglang\n  template:\n    metadata:\n      labels:\n        app: meta-llama-31-8b-instruct-sglang\n        model: meta-llama-31-8b-instruct\n        engine: sglang\n    spec:\n      restartPolicy: Always\n      runtimeClassName: nvidia\n      containers:\n        - name: meta-llama-31-8b-instruct-sglang\n          image: docker.io/lmsysorg/sglang:latest\n          imagePullPolicy: Always # IfNotPresent or Never\n          ports:\n            - containerPort: 30000\n          command: [\"python3\", \"-m\", \"sglang.launch_server\"]\n          args:\n            [\n              \"--model-path\",\n              \"meta-llama/Llama-3.1-8B-Instruct\",\n              \"--host\",\n              \"0.0.0.0\",\n              \"--port\",\n              \"30000\",\n            ]\n          env:\n            - name: HF_TOKEN\n              value: <secret>\n          resources:\n            limits:\n              nvidia.com/gpu: 1\n              cpu: 8\n              memory: 40Gi\n            requests:\n              cpu: 2\n              memory: 16Gi\n              nvidia.com/gpu: 1\n          volumeMounts:\n            - name: shm\n              mountPath: /dev/shm\n            - name: hf-cache\n              mountPath: /root/.cache/huggingface\n            - name: localtime\n              mountPath: /etc/localtime\n              readOnly: true\n          livenessProbe:\n       
     httpGet:\n              path: /health\n              port: 30000\n            initialDelaySeconds: 120\n            periodSeconds: 15\n            timeoutSeconds: 10\n            failureThreshold: 3\n          readinessProbe:\n            httpGet:\n              path: /health_generate\n              port: 30000\n            initialDelaySeconds: 120\n            periodSeconds: 15\n            timeoutSeconds: 10\n            failureThreshold: 3\n            successThreshold: 1\n      volumes:\n        - name: shm\n          emptyDir:\n            medium: Memory\n            sizeLimit: 10Gi\n        - name: hf-cache\n          persistentVolumeClaim:\n            claimName: llama-31-8b-sglang\n        - name: localtime\n          hostPath:\n            path: /etc/localtime\n            type: File\n---\napiVersion: v1\nkind: Service\nmetadata:\n  name: meta-llama-31-8b-instruct-sglang\nspec:\n  selector:\n    app: meta-llama-31-8b-instruct-sglang\n  ports:\n    - protocol: TCP\n      port: 80 # port on host\n      targetPort: 30000 # port in container\n  type: LoadBalancer # change to ClusterIP if needed\n"
  },
  {
    "path": "docker/npu.Dockerfile",
    "content": "ARG CANN_VERSION=8.5.0\nARG DEVICE_TYPE=a3\nARG OS=ubuntu22.04\nARG PYTHON_VERSION=py3.11\n\nFROM quay.io/ascend/cann:$CANN_VERSION-$DEVICE_TYPE-$OS-$PYTHON_VERSION\n\n# Update pip & apt sources\nARG TARGETARCH\nARG CANN_VERSION\nARG DEVICE_TYPE\nARG PIP_INDEX_URL=\"https://pypi.org/simple/\"\nARG APTMIRROR=\"\"\nARG PYTORCH_VERSION=\"2.8.0\"\nARG TORCHVISION_VERSION=\"0.23.0\"\nARG PTA_URL_ARM64=\"https://gitcode.com/Ascend/pytorch/releases/download/v7.3.0-pytorch2.8.0/torch_npu-2.8.0.post2-cp311-cp311-manylinux_2_28_aarch64.whl\"\nARG PTA_URL_AMD64=\"https://gitcode.com/Ascend/pytorch/releases/download/v7.3.0-pytorch2.8.0/torch_npu-2.8.0.post2-cp311-cp311-manylinux_2_28_x86_64.whl\"\nARG SGLANG_TAG=main\nARG ASCEND_CANN_PATH=/usr/local/Ascend/ascend-toolkit\nARG SGLANG_KERNEL_NPU_TAG=main\n\nARG PIP_INSTALL=\"python3 -m pip install --no-cache-dir\"\nARG DEVICE_TYPE\n\nRUN if [ \"$TARGETARCH\" = \"amd64\" ]; then \\\n      echo \"Using x86_64 dependencies\"; \\\n      echo \"PTA_URL=$PTA_URL_AMD64\" >> /etc/environment_new; \\\n    elif [ \"$TARGETARCH\" = \"arm64\" ]; then \\\n      echo \"Using aarch64 dependencies\"; \\\n      echo \"PTA_URL=$PTA_URL_ARM64\" >> /etc/environment_new; \\\n    else \\\n      echo \"Unsupported TARGETARCH: $TARGETARCH\"; exit 1; \\\n    fi\n\nWORKDIR /workspace\n\n# Define environments\nENV DEBIAN_FRONTEND=noninteractive\n\nRUN pip config set global.index-url $PIP_INDEX_URL\nRUN if [ -n \"$APTMIRROR\" ];then sed -i \"s|.*.ubuntu.com|$APTMIRROR|g\" /etc/apt/sources.list ;fi\n\n# Install development tools and utilities\nRUN apt-get update -y && apt upgrade -y && apt-get install -y \\\n    unzip \\\n    build-essential \\\n    cmake \\\n    vim \\\n    wget \\\n    curl \\\n    net-tools \\\n    zlib1g-dev \\\n    lld \\\n    clang \\\n    locales \\\n    ccache \\\n    openssl \\\n    libssl-dev \\\n    pkg-config \\\n    ca-certificates \\\n    && rm -rf /var/cache/apt/* \\\n    && rm -rf /var/lib/apt/lists/* \\\n    
&& update-ca-certificates \\\n    && locale-gen en_US.UTF-8\n\nENV LANG=en_US.UTF-8\nENV LANGUAGE=en_US:en\nENV LC_ALL=en_US.UTF-8\n\n\n### Install MemFabric\nRUN ${PIP_INSTALL} memfabric-hybrid==1.0.5\n### Install SGLang Model Gateway\nRUN ${PIP_INSTALL} sglang-router\n\n\n### Install PyTorch and PTA\nRUN . /etc/environment_new && \\\n    (${PIP_INSTALL} torch==${PYTORCH_VERSION} torchvision==${TORCHVISION_VERSION} --index-url https://download.pytorch.org/whl/cpu) \\\n    && (${PIP_INSTALL} ${PTA_URL})\n\n\n## Install triton-ascend\nRUN (${PIP_INSTALL} pybind11 triton-ascend)\n\n# Install SGLang\nRUN git clone https://github.com/sgl-project/sglang --branch $SGLANG_TAG && \\\n    (cd sglang/python && rm -rf pyproject.toml && mv pyproject_npu.toml pyproject.toml && ${PIP_INSTALL} -v .[all_npu]) && \\\n    rm -rf sglang\n\n# Install Deep-ep\n# pin wheel to 0.45.1 ref: https://github.com/pypa/wheel/issues/662\nRUN ${PIP_INSTALL} wheel==0.45.1 pybind11 pyyaml decorator scipy attrs psutil \\\n    && mkdir sgl-kernel-npu \\\n    && cd sgl-kernel-npu \\\n    && wget https://github.com/sgl-project/sgl-kernel-npu/releases/download/${SGLANG_KERNEL_NPU_TAG}/sgl-kernel-npu-${SGLANG_KERNEL_NPU_TAG}-torch2.8.0-py311-cann${CANN_VERSION}-${DEVICE_TYPE}-$(arch).zip \\\n    && unzip sgl-kernel-npu-${SGLANG_KERNEL_NPU_TAG}-torch2.8.0-py311-cann${CANN_VERSION}-${DEVICE_TYPE}-$(arch).zip \\\n    && ${PIP_INSTALL} deep_ep*.whl sgl_kernel_npu*.whl \\\n    && cd .. && rm -rf sgl-kernel-npu \\\n    && cd \"$(python3 -m pip show deep-ep | awk '/^Location:/ {print $2}')\" && ln -sf deep_ep/deep_ep_cpp*.so\n\nCMD [\"/bin/bash\"]\n"
  },
  {
    "path": "docker/rocm.Dockerfile",
    "content": "# Usage (to build SGLang ROCm docker image):\n#   docker build --build-arg SGL_BRANCH=v0.5.9 --build-arg GPU_ARCH=gfx942 -t v0.5.9-rocm700-mi30x -f rocm.Dockerfile .\n#   docker build --build-arg SGL_BRANCH=v0.5.9 --build-arg GPU_ARCH=gfx942-rocm720 -t v0.5.9-rocm720-mi30x -f rocm.Dockerfile .\n#   docker build --build-arg SGL_BRANCH=v0.5.9 --build-arg GPU_ARCH=gfx950 -t v0.5.9-rocm700-mi35x -f rocm.Dockerfile .\n#   docker build --build-arg SGL_BRANCH=v0.5.9 --build-arg GPU_ARCH=gfx950-rocm720 -t v0.5.9-rocm720-mi35x -f rocm.Dockerfile .\n\n# Usage (to build SGLang ROCm + Mori docker image):\n#   docker build --build-arg SGL_BRANCH=v0.5.9 --build-arg GPU_ARCH=gfx942 --build-arg ENABLE_MORI=1 --build-arg NIC_BACKEND=ainic -t v0.5.9-rocm700-mi30x -f rocm.Dockerfile .\n#   docker build --build-arg SGL_BRANCH=v0.5.9 --build-arg GPU_ARCH=gfx942-rocm720 --build-arg ENABLE_MORI=1 --build-arg NIC_BACKEND=ainic -t v0.5.9-rocm720-mi30x -f rocm.Dockerfile .\n#   docker build --build-arg SGL_BRANCH=v0.5.9 --build-arg GPU_ARCH=gfx950 --build-arg ENABLE_MORI=1 --build-arg NIC_BACKEND=ainic -t v0.5.9-rocm700-mi35x -f rocm.Dockerfile .\n#   docker build --build-arg SGL_BRANCH=v0.5.9 --build-arg GPU_ARCH=gfx950-rocm720 --build-arg ENABLE_MORI=1 --build-arg NIC_BACKEND=ainic -t v0.5.9-rocm720-mi35x -f rocm.Dockerfile .\n\n# Default base images\nARG BASE_IMAGE_942=\"rocm/sgl-dev:rocm7-vllm-20250904\"\nARG BASE_IMAGE_942_ROCM720=\"rocm/pytorch:rocm7.2_ubuntu22.04_py3.10_pytorch_release_2.9.1\"\nARG BASE_IMAGE_950=\"rocm/sgl-dev:rocm7-vllm-20250904\"\nARG BASE_IMAGE_950_ROCM720=\"rocm/pytorch:rocm7.2_ubuntu22.04_py3.10_pytorch_release_2.9.1\"\n\n# This is necessary for scope purpose\nARG GPU_ARCH=gfx950\n\n# ===============================\n# Base image 942 with rocm700 and args\nFROM $BASE_IMAGE_942 AS gfx942\nENV BUILD_VLLM=\"0\"\nENV BUILD_TRITON=\"0\"\nENV BUILD_LLVM=\"0\"\nENV BUILD_AITER_ALL=\"1\"\nENV BUILD_MOONCAKE=\"1\"\nENV AITER_COMMIT=\"v0.1.11.post1\"\n\n# 
===============================\n# Base image 942 with rocm720 and args\nFROM $BASE_IMAGE_942_ROCM720 AS gfx942-rocm720\nENV BUILD_VLLM=\"0\"\nENV BUILD_TRITON=\"1\"\nENV BUILD_LLVM=\"0\"\nENV BUILD_AITER_ALL=\"1\"\nENV BUILD_MOONCAKE=\"1\"\nENV AITER_COMMIT=\"v0.1.11.post1\"\n\n# ===============================\n# Base image 950 and args\nFROM $BASE_IMAGE_950 AS gfx950\nENV BUILD_VLLM=\"0\"\nENV BUILD_TRITON=\"0\"\nENV BUILD_LLVM=\"0\"\nENV BUILD_AITER_ALL=\"1\"\nENV BUILD_MOONCAKE=\"1\"\nENV AITER_COMMIT=\"v0.1.11.post1\"\n\n# ===============================\n# Base image 950 with rocm720 and args\nFROM $BASE_IMAGE_950_ROCM720 AS gfx950-rocm720\nENV BUILD_VLLM=\"0\"\nENV BUILD_TRITON=\"1\"\nENV BUILD_LLVM=\"0\"\nENV BUILD_AITER_ALL=\"1\"\nENV BUILD_MOONCAKE=\"1\"\nENV AITER_COMMIT=\"v0.1.11.post1\"\n\n# ===============================\n# Chosen arch and args\nFROM ${GPU_ARCH}\n\n# This is necessary for scope purpose, again\nARG GPU_ARCH=gfx950\nENV GPU_ARCH_LIST=${GPU_ARCH%-*}\nENV PYTORCH_ROCM_ARCH=gfx942;gfx950\n\nARG SGL_REPO=\"https://github.com/sgl-project/sglang.git\"\nARG SGL_DEFAULT=\"main\"\nARG SGL_BRANCH=${SGL_DEFAULT}\n\n# Version override for setuptools_scm (used in nightly builds)\nARG SETUPTOOLS_SCM_PRETEND_VERSION=\"\"\n\nARG TRITON_REPO=\"https://github.com/triton-lang/triton.git\"\nARG TRITON_COMMIT=\"42270451990532c67e69d753fbd026f28fcc4840\"\n\nARG AITER_REPO=\"https://github.com/ROCm/aiter.git\"\n\nARG LLVM_REPO=\"https://github.com/jrbyrnes/llvm-project.git\"\nARG LLVM_BRANCH=\"MainOpSelV2\"\nARG LLVM_COMMIT=\"6520ace8227ffe2728148d5f3b9872a870b0a560\"\n\nARG MOONCAKE_REPO=\"https://github.com/kvcache-ai/Mooncake.git\"\nARG MOONCAKE_COMMIT=\"b6a841dc78c707ec655a563453277d969fb8f38d\"\n\nARG TILELANG_REPO=\"https://github.com/tile-ai/tilelang.git\"\nARG TILELANG_COMMIT=\"ebf4a7cb8881432165ae8760e99d209d905c704a\"\n\nARG FHT_REPO=\"https://github.com/jeffdaily/fast-hadamard-transform.git\"\nARG FHT_BRANCH=\"rocm\"\nARG 
FHT_COMMIT=\"46efb7d776d38638fc39f3c803eaee3dd7016bd1\"\n\nARG ENABLE_MORI=0\nARG NIC_BACKEND=none\n\nARG MORI_REPO=\"https://github.com/ROCm/mori.git\"\nARG MORI_COMMIT=\"2f88d06aba75400262ca5c1ca5986cf1fdf4cd82\"\n\n# AMD AINIC apt repo settings\nARG AINIC_VERSION=1.117.5\nARG UBUNTU_CODENAME=jammy\nUSER root\n\n# Fix hipDeviceGetName returning empty string in ROCm 7.0 docker images.\n# The ROCm 7.0 base image is missing libdrm-amdgpu-common which provides the\n# amdgpu.ids device-ID-to-marketing-name mapping file.\n# ROCm 7.2 base images already ship these packages, so this step is skipped.\n# See https://github.com/ROCm/ROCm/issues/5992\nRUN set -eux; \\\n    case \"${GPU_ARCH}\" in \\\n      *rocm720*) \\\n        echo \"ROCm 7.2 (GPU_ARCH=${GPU_ARCH}): libdrm-amdgpu packages already present, skipping\"; \\\n        ;; \\\n      *) \\\n        echo \"ROCm 7.0 (GPU_ARCH=${GPU_ARCH}): installing libdrm-amdgpu packages\"; \\\n        curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key \\\n          | gpg --dearmor -o /etc/apt/keyrings/amdgpu-graphics.gpg \\\n        && echo 'deb [arch=amd64,i386 signed-by=/etc/apt/keyrings/amdgpu-graphics.gpg] https://repo.radeon.com/graphics/7.0/ubuntu jammy main' \\\n          > /etc/apt/sources.list.d/amdgpu-graphics.list \\\n        && apt-get update \\\n        && apt-get install -y --no-install-recommends \\\n             libdrm-amdgpu-common \\\n             libdrm-amdgpu-amdgpu1 \\\n             libdrm2-amdgpu \\\n        && rm -rf /var/lib/apt/lists/* \\\n        && cp /opt/amdgpu/share/libdrm/amdgpu.ids /usr/share/libdrm/amdgpu.ids; \\\n        ;; \\\n    esac\n\n\n# Install some basic utilities\nRUN python -m pip install --upgrade pip && pip install setuptools_scm\nRUN apt-get purge -y sccache; python -m pip uninstall -y sccache; rm -f \"$(which sccache)\"\n\n# Install AMD SMI Python package from ROCm distribution.\n# The ROCm 7.2 base image (rocm/pytorch) does not pre-install this package.\nRUN set -eux; \\\n    case 
\"${GPU_ARCH}\" in \\\n      *rocm720*) \\\n        echo \"ROCm 7.2 flavor detected from GPU_ARCH=${GPU_ARCH}\"; \\\n        cd /opt/rocm/share/amd_smi \\\n        && python3 -m pip install --no-cache-dir . \\\n        ;; \\\n      *) \\\n        echo \"Not rocm720 (GPU_ARCH=${GPU_ARCH}), skip amdsmi installation\"; \\\n        ;; \\\n    esac\n\nWORKDIR /sgl-workspace\n\n# -----------------------\n# llvm\nRUN if [ \"$BUILD_LLVM\" = \"1\" ]; then \\\n     export HIP_CLANG_PATH=\"/sgl-workspace/llvm-project/build/bin/\" \\\n     && git clone --single-branch ${LLVM_REPO} -b ${LLVM_BRANCH} \\\n     && cd llvm-project \\\n     && git checkout ${LLVM_COMMIT} \\\n     && mkdir build \\\n     && cd build \\\n     && cmake -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD=\"AMDGPU;X86\" -DLLVM_ENABLE_PROJECTS=\"clang;lld;\" -DLLVM_ENABLE_RUNTIMES=\"compiler-rt\" ../llvm \\\n     && make -j$(nproc); \\\n    fi\n\n# -----------------------\n# AITER\n# Unset setuptools_scm override so AITER gets its own version (AITER_COMMIT), not SGLang's\n# (SETUPTOOLS_SCM_PRETEND_VERSION is set later for SGLang nightly builds and would otherwise\n# leak into AITER's version when AITER uses setuptools_scm)\nENV SETUPTOOLS_SCM_PRETEND_VERSION=\nRUN pip uninstall -y aiter \\\n && pip install flydsl==0.0.1.dev95158637 \\\n && pip install psutil pybind11 # Required by AITER setup.py\nRUN git clone ${AITER_REPO} \\\n && cd aiter \\\n && git checkout ${AITER_COMMIT} \\\n && git submodule update --init --recursive\n\n# Hot patches for AITER in v0.1.10.post3\n# This is for ROCm 7.2 only, because of the image rebase from vllm\n# to rocm/pytorch.\nRUN set -eux; \\\n    case \"${GPU_ARCH}\" in \\\n      *rocm720*) \\\n        echo \"ROCm 7.2 flavor detected from GPU_ARCH=${GPU_ARCH}\"; \\\n        cd aiter \\\n        && sed -i '459 s/if.*:/if False:/' aiter/ops/triton/attention/pa_mqa_logits.py; \\\n        ;; \\\n      *) \\\n        echo \"Not rocm720 (GPU_ARCH=${GPU_ARCH}), 
skip patch\"; \\\n        ;; \\\n    esac\n# [WA] from kk-huang\n# add sed -i '/c1 = torch.empty((M, D, S1 + S3) for aiter triton gemm config issue\n# the corresponding pr is https://github.com/ROCm/aiter/pull/2173\n# it will be removed when server launched issue is fixed by aiter\nRUN cd aiter \\\n     && echo \"[AITER] GPU_ARCH=${GPU_ARCH}\" \\\n     && sed -i '/c1 = torch.empty((M, D, S1 + S3), dtype=dtype, device=x.device)/i\\    config = dict(config)' aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py \\\n     && if [ \"$BUILD_AITER_ALL\" = \"1\" ] && [ \"$BUILD_LLVM\" = \"1\" ]; then \\\n          sh -c \"HIP_CLANG_PATH=/sgl-workspace/llvm-project/build/bin/ PREBUILD_KERNELS=1 GPU_ARCHS=$GPU_ARCH_LIST python setup.py build_ext --inplace\" \\\n          && sh -c \"HIP_CLANG_PATH=/sgl-workspace/llvm-project/build/bin/ GPU_ARCHS=$GPU_ARCH_LIST pip install -e .\"; \\\n        elif [ \"$BUILD_AITER_ALL\" = \"1\" ]; then \\\n          sh -c \"PREBUILD_KERNELS=1 GPU_ARCHS=$GPU_ARCH_LIST python setup.py build_ext --inplace\" \\\n          && sh -c \"GPU_ARCHS=$GPU_ARCH_LIST pip install -e .\"; \\\n        else \\\n          sh -c \"GPU_ARCHS=$GPU_ARCH_LIST pip install -e .\"; \\\n        fi \\\n      && echo \"export PYTHONPATH=/sgl-workspace/aiter:\\${PYTHONPATH}\" >> /etc/bash.bashrc\n\n# -----------------------\n# Build Mooncake\nENV PATH=$PATH:/usr/local/go/bin\n\nRUN if [ \"$BUILD_MOONCAKE\" = \"1\" ]; then \\\n     apt update && apt install -y zip unzip wget && \\\n     apt install -y gcc make libtool autoconf  librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils perftest ethtool  libibverbs-dev rdma-core && \\\n     apt install -y openssh-server openmpi-bin openmpi-common libopenmpi-dev && \\\n     git clone ${MOONCAKE_REPO} && \\\n     cd Mooncake && \\\n     git checkout ${MOONCAKE_COMMIT} && \\\n     git submodule update --init --recursive && \\\n     bash dependencies.sh -y && \\\n     rm -rf /usr/local/go && \\\n     wget 
https://go.dev/dl/go1.22.2.linux-amd64.tar.gz && \\\n     tar -C /usr/local -xzf go1.22.2.linux-amd64.tar.gz && \\\n     rm go1.22.2.linux-amd64.tar.gz && \\\n     mkdir -p build && \\\n     cd build && \\\n     cmake .. -DUSE_HIP=ON -DUSE_ETCD=ON && \\\n     make -j \"$(nproc)\" && make install; \\\n    fi\n\n# -----------------------\n# Build SGLang\nARG BUILD_TYPE=all\n\n# Set version for setuptools_scm if provided (for nightly builds). Only pass in the SGLang\n# pip install RUN so it does not affect AITER, sgl-model-gateway, TileLang, FHT, MORI, etc.\nARG SETUPTOOLS_SCM_PRETEND_VERSION\n\nRUN pip install IPython \\\n    && pip install orjson \\\n    && pip install python-multipart \\\n    && pip install torchao==0.9.0 \\\n    && pip install pybind11\n\nRUN pip uninstall -y sgl_kernel sglang\nRUN git clone ${SGL_REPO} \\\n    && cd sglang \\\n    && if [ \"${SGL_BRANCH}\" = ${SGL_DEFAULT} ]; then \\\n         echo \"Using ${SGL_DEFAULT}, default branch.\"; \\\n         git checkout ${SGL_DEFAULT}; \\\n       else \\\n         echo \"Using ${SGL_BRANCH} branch.\"; \\\n         git checkout ${SGL_BRANCH}; \\\n       fi \\\n    && cd sgl-kernel \\\n    && rm -f pyproject.toml \\\n    && mv pyproject_rocm.toml pyproject.toml \\\n    && AMDGPU_TARGET=$GPU_ARCH_LIST python setup_rocm.py install \\\n    && cd .. \\\n    && rm -rf python/pyproject.toml && mv python/pyproject_other.toml python/pyproject.toml \\\n    && if [ \"$BUILD_TYPE\" = \"srt\" ]; then \\\n         export SETUPTOOLS_SCM_PRETEND_VERSION=\"${SETUPTOOLS_SCM_PRETEND_VERSION}\" && python -m pip --no-cache-dir install -e \"python[srt_hip,diffusion_hip]\"; \\\n       else \\\n         export SETUPTOOLS_SCM_PRETEND_VERSION=\"${SETUPTOOLS_SCM_PRETEND_VERSION}\" && python -m pip --no-cache-dir install -e \"python[all_hip]\"; \\\n       fi\n\nRUN python -m pip cache purge\n\n# Copy config files to support MI300X in virtualized environments (MI300X_VF).  
Symlinks will not be created in image build.\nRUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \\\n         /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \\\n         -type f -name '*MI300X*' | xargs -I {} sh -c 'vf_config=$(echo \"$1\" | sed \"s/MI300X/MI300X_VF/\"); cp \"$1\" \"$vf_config\"' -- {}\n\n# Install Rust toolchain for sgl-model-gateway\nENV PATH=\"/root/.cargo/bin:${PATH}\"\nRUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \\\n    && rustc --version && cargo --version\nENV CARGO_BUILD_JOBS=4\n\n# Build and install sgl-model-gateway\nRUN python3 -m pip install --no-cache-dir maturin \\\n    && cd /sgl-workspace/sglang/sgl-model-gateway/bindings/python \\\n    && ulimit -n 65536 && maturin build --release --features vendored-openssl --out dist \\\n    && python3 -m pip install --force-reinstall dist/*.whl \\\n    && rm -rf /root/.cache\n\n# -----------------------\n# TileLang\nENV DEBIAN_FRONTEND=noninteractive\nENV LIBGL_ALWAYS_INDIRECT=1\nRUN echo \"LC_ALL=en_US.UTF-8\" >> /etc/environment\n\nRUN /bin/bash -lc 'set -euo pipefail; \\\n  echo \"[TileLang] Building TileLang for ${GPU_ARCH}\"; \\\n  # System dependencies (NO llvm-dev to avoid llvm-config-16 shadowing)\n  apt-get update && apt-get install -y --no-install-recommends \\\n      build-essential git wget curl ca-certificates gnupg \\\n      libgtest-dev libgmock-dev \\\n      libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev \\\n      python3 python3-dev python3-setuptools python3-pip python3-apt \\\n      gcc libtinfo-dev zlib1g-dev libedit-dev libxml2-dev vim \\\n      cmake ninja-build pkg-config libstdc++6 software-properties-common \\\n  && rm -rf /var/lib/apt/lists/*; \\\n  \\\n  # Prefer the container venv\n  VENV_PY=\"/opt/venv/bin/python\"; \\\n  VENV_PIP=\"/opt/venv/bin/pip\"; \\\n  if [ ! -x \"$VENV_PY\" ]; then VENV_PY=\"python3\"; fi; \\\n  if [ ! 
-x \"$VENV_PIP\" ]; then VENV_PIP=\"pip3\"; fi; \\\n  \\\n  # Build GoogleTest static libs (Ubuntu package ships sources only)\n  cmake -S /usr/src/googletest -B /tmp/build-gtest -DBUILD_GTEST=ON -DBUILD_GMOCK=ON -DCMAKE_BUILD_TYPE=Release && \\\n  cmake --build /tmp/build-gtest -j\"$(nproc)\" && \\\n  cp -v /tmp/build-gtest/lib/*.a /usr/lib/x86_64-linux-gnu/ && \\\n  rm -rf /tmp/build-gtest; \\\n  \\\n  # Keep setuptools < 80 (compat with base image)\n  \"$VENV_PIP\" install --upgrade \"setuptools>=77.0.3,<80\" wheel cmake ninja scikit-build-core && \\\n  \"$VENV_PIP\" cache purge || true; \\\n  \\\n  # Locate ROCm llvm-config; fallback to installing LLVM 18 if missing\n  LLVM_CONFIG_PATH=\"\"; \\\n  for p in /opt/rocm/llvm/bin/llvm-config /opt/rocm/llvm-*/bin/llvm-config /opt/rocm-*/llvm*/bin/llvm-config; do \\\n    if [ -x \"$p\" ]; then LLVM_CONFIG_PATH=\"$p\"; break; fi; \\\n  done; \\\n  if [ -z \"$LLVM_CONFIG_PATH\" ]; then \\\n    echo \"[TileLang] ROCm llvm-config not found; installing LLVM 18...\"; \\\n    curl -fsSL https://apt.llvm.org/llvm-snapshot.gpg.key | gpg --dearmor -o /etc/apt/keyrings/llvm.gpg; \\\n    echo \"deb [signed-by=/etc/apt/keyrings/llvm.gpg] http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main\" > /etc/apt/sources.list.d/llvm.list; \\\n    apt-get update; \\\n    apt-get install -y --no-install-recommends llvm-18; \\\n    rm -rf /var/lib/apt/lists/*; \\\n    LLVM_CONFIG_PATH=\"$(command -v llvm-config-18)\"; \\\n    if [ -z \"$LLVM_CONFIG_PATH\" ]; then echo \"ERROR: llvm-config-18 not found after install\"; exit 1; fi; \\\n  fi; \\\n  echo \"[TileLang] Using LLVM_CONFIG at: $LLVM_CONFIG_PATH\"; \\\n  export PATH=\"$(dirname \"$LLVM_CONFIG_PATH\"):/usr/local/bin:${PATH}\"; \\\n  export LLVM_CONFIG=\"$LLVM_CONFIG_PATH\"; \\\n  \\\n  # Optional shim for tools that expect llvm-config-16\n  mkdir -p /usr/local/bin && \\\n  printf \"#!/usr/bin/env bash\\nexec \\\"%s\\\" \\\"\\$@\\\"\\n\" \"$LLVM_CONFIG_PATH\" > 
/usr/local/bin/llvm-config-16 && \\\n  chmod +x /usr/local/bin/llvm-config-16; \\\n  \\\n  # TVM Python bits need Cython + z3 before configure.\n  # Pin z3-solver==4.15.4.0: 4.15.4.0 has a manylinux wheel; 4.15.5.0 has no wheel and builds from source (fails: C++20 <format> needs GCC 14+, image has GCC 11).\n  \"$VENV_PIP\" install --no-cache-dir \"cython>=0.29.36,<3.0\" \"apache-tvm-ffi @ git+https://github.com/apache/tvm-ffi.git@37d0485b2058885bf4e7a486f7d7b2174a8ac1ce\" \"z3-solver==4.15.4.0\"; \\\n  \\\n  # Clone + pin TileLang (bundled TVM), then build\n  git clone --recursive \"${TILELANG_REPO}\" /opt/tilelang && \\\n  cd /opt/tilelang && \\\n  git fetch --depth=1 origin \"${TILELANG_COMMIT}\" || true && \\\n  git checkout -f \"${TILELANG_COMMIT}\" && \\\n  git submodule update --init --recursive && \\\n  export CMAKE_ARGS=\"-DUSE_CUDA=OFF -DUSE_ROCM=ON -DROCM_PATH=/opt/rocm -DLLVM_CONFIG=${LLVM_CONFIG} -DSKBUILD_SABI_VERSION= ${CMAKE_ARGS:-}\" && \\\n  \"$VENV_PIP\" install -e . -v --no-build-isolation --no-deps; \\\n  if [ -f pyproject.toml ]; then sed -i \"/^[[:space:]]*\\\"torch/d\" pyproject.toml || true; fi; \\\n  \"$VENV_PIP\" cache purge || true; \\\n  \"$VENV_PY\" -c \"import tilelang; print(tilelang.__version__)\"'\n\n# -----------------------\n# Hadamard-transform (HIP build)\nRUN /bin/bash -lc 'set -euo pipefail; \\\n    git clone --branch \"${FHT_BRANCH}\" \"${FHT_REPO}\" fast-hadamard-transform; \\\n    cd fast-hadamard-transform; \\\n    git checkout -f \"${FHT_COMMIT}\"; \\\n    python setup.py install'\n\n# -----------------------\n# Python tools\nRUN python3 -m pip install --no-cache-dir \\\n    py-spy \\\n    pre-commit \\\n    tabulate\n\n# -----------------------\n# MORI (optional)\nRUN /bin/bash -lc 'set -euo pipefail; \\\n  if [ \"${ENABLE_MORI}\" != \"1\" ]; then \\\n    echo \"[MORI] Skipping (ENABLE_MORI=${ENABLE_MORI})\"; \\\n    exit 0; \\\n  fi; \\\n  echo \"[MORI] Enabling MORI (NIC_BACKEND=${NIC_BACKEND})\"; \\\n  \\\n  # Base 
deps for MORI build\n  apt-get update && apt-get install -y --no-install-recommends \\\n      build-essential \\\n      g++ \\\n      jq \\\n      libopenmpi-dev \\\n      libpci-dev \\\n      initramfs-tools \\\n  && rm -rf /var/lib/apt/lists/*; \\\n  \\\n  # NIC backend deps\n  case \"${NIC_BACKEND}\" in \\\n    # default: mlx5\n    none) \\\n      export USE_IONIC=\"OFF\"; \\\n      export USE_BNXT=\"OFF\"; \\\n      ;; \\\n    # AMD NIC\n    ainic) \\\n      export USE_IONIC=\"ON\"; \\\n      export USE_BNXT=\"OFF\"; \\\n      apt-get update && apt-get install -y --no-install-recommends ca-certificates curl gnupg apt-transport-https && \\\n      rm -rf /var/lib/apt/lists/* && mkdir -p /etc/apt/keyrings; \\\n      curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor > /etc/apt/keyrings/amdainic.gpg; \\\n      echo \"deb [arch=amd64 signed-by=/etc/apt/keyrings/amdainic.gpg] https://repo.radeon.com/amdainic/pensando/ubuntu/${AINIC_VERSION} ${UBUNTU_CODENAME} main\" \\\n        > /etc/apt/sources.list.d/amdainic.list; \\\n      apt-get update && apt-get install -y --no-install-recommends \\\n          libionic-dev \\\n          ionic-common \\\n      ; \\\n      rm -rf /var/lib/apt/lists/*; \\\n      ;; \\\n    # TODO: Add Broadcom bnxt packages/repos here later.\n    # bnxt) \\\n    #   export USE_IONIC=\"OFF\"; \\\n    #   export USE_BNXT=\"ON\"; \\\n    #   echo \"[MORI] NIC_BACKEND=bnxt: USE_BNXT=ON. Add Broadcom bnxt packages/repos here later.\"; \\\n    #   ;; \\\n    *) \\\n      echo \"ERROR: unknown NIC_BACKEND=${NIC_BACKEND}. 
Use one of: none, ainic\"; \\\n      exit 2; \\\n      ;; \\\n  esac; \\\n  \\\n  # Build/install MORI\n  export MORI_GPU_ARCHS=\"${GPU_ARCH_LIST}\"; \\\n  echo \"[MORI] MORI_GPU_ARCHS=${MORI_GPU_ARCHS} USE_IONIC=${USE_IONIC} USE_BNXT=${USE_BNXT}\"; \\\n  rm -rf /sgl-workspace/mori; \\\n  git clone \"${MORI_REPO}\" /sgl-workspace/mori; \\\n  cd /sgl-workspace/mori; \\\n  git checkout \"${MORI_COMMIT}\"; \\\n  git submodule update --init --recursive; \\\n  python3 setup.py develop; \\\n  python3 -c \"import os, torch; print(os.path.join(os.path.dirname(torch.__file__), \\\"lib\\\"))\" > /etc/ld.so.conf.d/torch.conf; \\\n  ldconfig; \\\n  echo \"export PYTHONPATH=/sgl-workspace/mori:\\${PYTHONPATH}\" >> /etc/bash.bashrc; \\\n  echo \"[MORI] Done.\"'\n\n# -----------------------\n# Hot patch: torch-ROCm\n# The artifact hardcoded the supported triton version to be 3.5.1.\n# Rewrite the restriction directly.\nARG TORCH_ROCM_FILE=\"torch-2.9.1+rocm7.2.0.lw.git7e1940d4-cp310-cp310-linux_x86_64.whl\"\nRUN mkdir /tmp/whl && cd /tmp/whl \\\n     && export TORCH_ROCM_FILE=\"${TORCH_ROCM_FILE}\" \\\n     && cat > hack.py <<\"PY\"\nimport zipfile, csv, os, re\nfrom pathlib import Path\n\nfname = os.environ[\"TORCH_ROCM_FILE\"]\nin_whl  = Path(\"/\")   / fname\nout_whl = Path(\"/tmp\")/ fname\nwork = Path(\"/tmp/whl\")\n\n# 1) Extract\nwith zipfile.ZipFile(in_whl, \"r\") as z:\n    z.extractall(work)\n\n# 2) Locate dist-info and patch METADATA (edit this logic to match your exact line)\ndist_info = next(work.glob(\"*.dist-info\"))\nmeta = dist_info / \"METADATA\"\ntxt = meta.read_text(encoding=\"utf-8\")\n\n# Example: replace one exact requirement form.\n# Adjust the string to match what you actually see.\npat = r\"^Requires-Dist:\\s*triton==3.5.1[^\\s]*;\"\ntxt2, n = re.subn(pat, r\"triton>=3.5.1;\", txt, flags=re.MULTILINE)\nif txt2 == txt:\n    raise SystemExit(\"Did not find expected Requires-Dist line to replace in METADATA\")\nmeta.write_text(txt2, encoding=\"utf-8\")\n\n# 
3) Hacky step: blank hash/size columns in RECORD\nrecord = dist_info / \"RECORD\"\nrows = []\nwith record.open(newline=\"\", encoding=\"utf-8\") as f:\n    for r in csv.reader(f):\n        if not r:\n            continue\n        # keep filename, blank out hash and size\n        rows.append([r[0], \"\", \"\"])\nwith record.open(\"w\", newline=\"\", encoding=\"utf-8\") as f:\n    csv.writer(f).writerows(rows)\n\n# 4) Re-zip as a wheel\nwith zipfile.ZipFile(out_whl, \"w\", compression=zipfile.ZIP_DEFLATED) as z:\n    for p in work.rglob(\"*\"):\n        if p.is_file():\n            z.write(p, p.relative_to(work).as_posix())\n\nprint(\"Wrote\", out_whl)\nPY\n\nRUN cd /tmp/whl \\\n    && case \"${GPU_ARCH}\" in \\\n      *rocm720*) \\\n        echo \"ROCm 7.2 flavor detected from GPU_ARCH=${GPU_ARCH}\"; \\\n        python hack.py \\\n        && python3 -m pip install --force-reinstall --no-deps /tmp/${TORCH_ROCM_FILE} \\\n        && rm -fr /tmp/whl /tmp/${TORCH_ROCM_FILE} \\\n        ;; \\\n      *) \\\n        echo \"Not rocm720 (GPU_ARCH=${GPU_ARCH}), skip patch\"; \\\n        ;; \\\n    esac\n\n\n# -----------------------\n# Hot patch: Triton\n# For ROCm 7.2, this custom build breaks pip dependency management,\n# so future `pip install` will break the ROCm stack.\n# A workaround for this is to reinstall the default triton\n# wheel with the `rocm/pytorch` image in the root directory.\nRUN if [ \"$BUILD_TRITON\" = \"1\" ]; then \\\n        pip uninstall -y triton \\\n     && apt install -y cmake \\\n     && git clone ${TRITON_REPO} triton-custom \\\n     && cd triton-custom \\\n     && git checkout ${TRITON_COMMIT} \\\n     && pip install -r python/requirements.txt \\\n     && pip install -e .; \\\n    fi\n\n# -----------------------\n# Performance environment variables.\n\n# Skip cuDNN compatibility check - not applicable for ROCm (uses MIOpen instead)\nENV SGLANG_DISABLE_CUDNN_CHECK=1\nENV HIP_FORCE_DEV_KERNARG=1\nENV HSA_NO_SCRATCH_RECLAIM=1\nENV 
SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1\nENV SGLANG_INT4_WEIGHT=0\nENV SGLANG_MOE_PADDING=1\nENV SGLANG_ROCM_DISABLE_LINEARQUANT=0\nENV SGLANG_ROCM_FUSED_DECODE_MLA=1\nENV SGLANG_SET_CPU_AFFINITY=1\nENV SGLANG_USE_AITER=1\nENV SGLANG_USE_ROCM700A=1\n\nENV NCCL_MIN_NCHANNELS=112\nENV ROCM_QUICK_REDUCE_QUANTIZATION=INT8\nENV TORCHINDUCTOR_MAX_AUTOTUNE=1\nENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1\n\nCMD [\"/bin/bash\"]\n"
  },
  {
    "path": "docker/sagemaker.Dockerfile",
    "content": "FROM lmsysorg/sglang:latest\n\nCOPY serve /usr/bin/serve\nRUN chmod 777 /usr/bin/serve\n\nENTRYPOINT [ \"/usr/bin/serve\" ]\n"
  },
  {
    "path": "docker/serve",
    "content": "#!/bin/bash\necho \"Starting server\"\n\nPREFIX=\"SM_SGLANG_\"\nARG_PREFIX=\"--\"\n\nARGS=()\n\nwhile IFS='=' read -r key value; do\n    arg_name=$(echo \"${key#\"${PREFIX}\"}\" | tr '[:upper:]' '[:lower:]' | tr '_' '-')\n\n    ARGS+=(\"${ARG_PREFIX}${arg_name}\")\n    if [ -n \"$value\" ]; then\n        ARGS+=(\"$value\")\n    fi\ndone < <(env | grep \"^${PREFIX}\")\n\n# Add default port only if not already set\nif ! [[ \" ${ARGS[@]} \" =~ \" --port \" ]]; then\n    ARGS+=(--port \"${SM_SGLANG_PORT:-8080}\")\nfi\n\n# Add default host only if not already set\nif ! [[ \" ${ARGS[@]} \" =~ \" --host \" ]]; then\n    ARGS+=(--host \"${SM_SGLANG_HOST:-0.0.0.0}\")\nfi\n\n# Add default model-path only if not already set\nif ! [[ \" ${ARGS[@]} \" =~ \" --model-path \" ]]; then\n    ARGS+=(--model-path \"${SM_SGLANG_MODEL_PATH:-/opt/ml/model}\")\nfi\n\necho \"Running command: exec python3 -m sglang.launch_server ${ARGS[@]}\"\nexec python3 -m sglang.launch_server \"${ARGS[@]}\"\n"
  },
  {
    "path": "docker/xeon.Dockerfile",
    "content": "FROM ubuntu:24.04\nSHELL [\"/bin/bash\", \"-c\"]\n\nARG SGLANG_REPO=https://github.com/sgl-project/sglang.git\nARG VER_SGLANG=main\n\nRUN apt-get update && \\\n    apt-get full-upgrade -y && \\\n    DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \\\n    ca-certificates \\\n    git \\\n    curl \\\n    wget \\\n    vim \\\n    gcc \\\n    g++ \\\n    make \\\n    libsqlite3-dev \\\n    google-perftools \\\n    libtbb-dev \\\n    libnuma-dev \\\n    numactl\n\nWORKDIR /opt\n\nRUN curl -LsSf https://astral.sh/uv/install.sh | sh && \\\n    source $HOME/.local/bin/env && \\\n    uv venv --python 3.12\n\nRUN echo -e '[[index]]\\nname = \"torch\"\\nurl = \"https://download.pytorch.org/whl/cpu\"\\n\\n[[index]]\\nname = \"torchvision\"\\nurl = \"https://download.pytorch.org/whl/cpu\"\\n\\n[[index]]\\nname = \"torchaudio\"\\nurl = \"https://download.pytorch.org/whl/cpu\"\\n\\n[[index]]\\nname = \"triton\"\\nurl = \"https://download.pytorch.org/whl/cpu\"' > .venv/uv.toml\n\nENV UV_CONFIG_FILE=/opt/.venv/uv.toml\n\nWORKDIR /sgl-workspace\nRUN source $HOME/.local/bin/env && \\\n    source /opt/.venv/bin/activate && \\\n    git clone ${SGLANG_REPO} sglang && \\\n    cd sglang && \\\n    git checkout ${VER_SGLANG} && \\\n    cd python && \\\n    cp pyproject_cpu.toml pyproject.toml && \\\n    uv pip install . && \\\n    cd ../sgl-kernel && \\\n    cp pyproject_cpu.toml pyproject.toml && \\\n    uv pip install .\n\nENV SGLANG_USE_CPU_ENGINE=1\nENV LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4:/usr/lib/x86_64-linux-gnu/libtbbmalloc.so:/opt/.venv/lib/libiomp5.so\nRUN echo 'source /opt/.venv/bin/activate' >> /root/.bashrc\n\nWORKDIR /sgl-workspace/sglang\n"
  },
  {
    "path": "docker/xpu.Dockerfile",
    "content": "# If the device is Battlemage, we need to set UBUNTU_VERSION to 24.10\n\n# Usage: docker build --build-arg UBUNTU_VERSION=24.04 --build-arg PYTHON_VERSION=3.10 -t sglang:xpu_kernel -f  xpu.Dockerfile --no-cache .\n\n# Use Intel deep learning essentials base image with Ubuntu 24.04\nFROM intel/deep-learning-essentials:2025.3.2-0-devel-ubuntu24.04\n\n# Avoid interactive prompts during package install\nENV DEBIAN_FRONTEND=noninteractive\n\n# Define build arguments\nARG PYTHON_VERSION=3.10\n\nARG SG_LANG_REPO=https://github.com/sgl-project/sglang.git\nARG SG_LANG_BRANCH=main\n\nARG SG_LANG_KERNEL_REPO=https://github.com/sgl-project/sgl-kernel-xpu.git\nARG SG_LANG_KERNEL_BRANCH=main\n\nRUN useradd -m -d /home/sdp -s /bin/bash sdp && \\\n    chown -R sdp:sdp /home/sdp\n\n# Switch to non-root user 'sdp'\nUSER sdp\n\n# Set HOME and WORKDIR to user's home directory\nENV HOME=/home/sdp\nWORKDIR /home/sdp\n\nRUN curl -fsSL -v -o miniforge.sh -O https://github.com/conda-forge/miniforge/releases/download/25.1.1-0/Miniforge3-Linux-x86_64.sh && \\\n    bash miniforge.sh -b -p ./miniforge3 && \\\n    rm miniforge.sh && \\\n    # Initialize conda environment and install pip\n    . ./miniforge3/bin/activate && \\\n    conda create -y -n py${PYTHON_VERSION} python=${PYTHON_VERSION} && \\\n    conda activate py${PYTHON_VERSION} && \\\n    conda install pip && \\\n    # Append environment activation to .bashrc for interactive shells\n    echo \". /home/sdp/miniforge3/bin/activate; conda activate py${PYTHON_VERSION}; . /opt/intel/oneapi/setvars.sh; cd /home/sdp\" >> /home/sdp/.bashrc\n\nUSER root\nRUN apt-get update && apt install -y intel-ocloc\n\n# Switch back to user sdp\nUSER sdp\n\nRUN --mount=type=secret,id=github_token \\\n    cd /home/sdp && \\\n    . 
/home/sdp/miniforge3/bin/activate && \\\n    conda activate py${PYTHON_VERSION} && \\\n    pip3 install torch==2.10.0+xpu torchao torchvision torchaudio triton-xpu==3.6.0 --index-url https://download.pytorch.org/whl/xpu\n\nRUN --mount=type=secret,id=github_token \\\n    cd /home/sdp && \\\n    . /home/sdp/miniforge3/bin/activate && \\\n    conda activate py${PYTHON_VERSION} && \\\n    echo \"Cloning ${SG_LANG_BRANCH} from ${SG_LANG_REPO}\" && \\\n    git clone --branch ${SG_LANG_BRANCH} --single-branch ${SG_LANG_REPO} && \\\n    cd sglang && cd python && \\\n    cp pyproject_xpu.toml pyproject.toml && \\\n    pip install . && \\\n    pip install xgrammar --no-deps && \\\n    pip install msgspec blake3 py-cpuinfo compressed_tensors gguf partial_json_parser einops tabulate --root-user-action=ignore && \\\n    conda install libsqlite=3.48.0 -y && \\\n    # Add environment setup commands to .bashrc again (in case it was overwritten)\n    echo \". /home/sdp/miniforge3/bin/activate; conda activate py${PYTHON_VERSION}; cd /home/sdp\" >> /home/sdp/.bashrc\n\n# Use bash as default shell with initialization from .bashrc\nSHELL [\"bash\", \"-c\"]\n\n# Start an interactive bash shell with all environment set up\nUSER sdp\nCMD [\"bash\", \"-c\", \"source /home/sdp/.bashrc && exec bash\"]\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal Makefile for Sphinx documentation\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSPHINXAUTOBUILD ?= sphinx-autobuild\nSOURCEDIR     = .\nBUILDDIR      = _build\nPORT          ?= 8003\n\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\t@echo \"\"\n\t@echo \"Additional targets:\"\n\t@echo \"  serve       to build and serve documentation with auto-build and live reload\"\n\n# Compile Notebook files and record execution time\ncompile:\n\t@set -e; \\\n\techo \"Starting Notebook compilation...\"; \\\n\tmkdir -p logs; \\\n\techo \"Notebook execution timings:\" > logs/timing.log; \\\n\tSTART_TOTAL=$$(date +%s); \\\n\tfind $(SOURCEDIR) -path \"*/_build/*\" -prune -o -name \"*.ipynb\" -print0 | \\\n\t\tparallel -0 -j3 --halt soon,fail=1 ' \\\n\t\tNB_NAME=$$(basename {}); \\\n\t\tSTART_TIME=$$(date +%s); \\\n\t\tretry --delay=0 --times=2 -- \\\n\t\t\tjupyter nbconvert --to notebook --execute --inplace \"{}\" \\\n\t\t\t--ExecutePreprocessor.timeout=600 \\\n\t\t\t--ExecutePreprocessor.kernel_name=python3; \\\n\t\tRET_CODE=$$?; \\\n\t\tEND_TIME=$$(date +%s); \\\n\t\tELAPSED_TIME=$$((END_TIME - START_TIME)); \\\n\t\techo \"$${NB_NAME}: $${ELAPSED_TIME}s\" >> logs/timing.log; \\\n\t\texit $$RET_CODE' || exit 1; \\\n\tEND_TOTAL=$$(date +%s); \\\n\tTOTAL_ELAPSED=$$((END_TOTAL - START_TOTAL)); \\\n\techo \"---------------------------------\" >> logs/timing.log; \\\n\techo \"Total execution time: $${TOTAL_ELAPSED}s\" >> logs/timing.log; \\\n\techo \"All Notebook execution timings:\" && cat logs/timing.log\n\n# Convert Notebook files to Markdown artifacts (no execution)\nmarkdown:\n\t@set -e; \\\n\techo \"Exporting docs to Markdown...\"; \\\n\tmkdir -p \"$(BUILDDIR)/html/markdown\"; \\\n\t\\\n\t# 1) Copy .md and .rst files as-is; additionally convert .rst -> .md \\\n\tfind $(SOURCEDIR) -path \"*/_build/*\" -prune -o \\( -name \"*.md\" -o -name \"*.rst\" \\) -print0 | \\\n\t\tparallel -0 -j3 --halt soon,fail=1 ' 
\\\n\t\tSRC=\"{}\"; \\\n\t\tREL_DIR=$$(dirname \"$$SRC\"); \\\n\t\tOUT_DIR=\"$(BUILDDIR)/html/markdown/$$REL_DIR\"; \\\n\t\tmkdir -p \"$$OUT_DIR\"; \\\n\t\tcp -f \"$$SRC\" \"$$OUT_DIR/\"; \\\n\t\tcase \"$$SRC\" in \\\n\t\t  *.rst) \\\n\t\t\tBASE=$$(basename \"$$SRC\" .rst); \\\n\t\t\tpandoc -f rst -t gfm \"$$SRC\" -o \"$$OUT_DIR/$$BASE.md\" ;; \\\n\t\tesac \\\n\t\t' || exit 1; \\\n\t\\\n\t# 2) Convert .ipynb -> .md \\\n\tfind $(SOURCEDIR) -path \"*/_build/*\" -prune -o -name \"*.ipynb\" -print0 | \\\n\t\tparallel -0 -j3 --halt soon,fail=1 ' \\\n\t\tNB_SRC=\"{}\"; \\\n\t\tREL_DIR=$$(dirname \"$$NB_SRC\"); \\\n\t\tNB_NAME=$$(basename \"$$NB_SRC\"); \\\n\t\tNB_BASE=$${NB_NAME%.ipynb}; \\\n\t\tOUT_DIR=\"$(BUILDDIR)/html/markdown/$$REL_DIR\"; \\\n\t\tmkdir -p \"$$OUT_DIR\"; \\\n\t\tjupyter nbconvert --to markdown \"$$NB_SRC\" \\\n\t\t\t--output \"$$NB_BASE.md\" \\\n\t\t\t--output-dir \"$$OUT_DIR\" \\\n\t\t\t>/dev/null; \\\n\t\t' || exit 1; \\\n\t\\\n\techo \"Markdown artifacts written to: $(BUILDDIR)/html/markdown\"\n\n\n\n# Serve documentation with auto-build and live reload\nserve:\n\t@echo \"Starting auto-build server at http://0.0.0.0:$(PORT)\"\n\t@$(SPHINXAUTOBUILD) \"$(SOURCEDIR)\" \"$(BUILDDIR)/html\" \\\n\t\t--host 0.0.0.0 \\\n\t\t--port $(PORT) \\\n\t\t--watch $(SOURCEDIR) \\\n\t\t--re-ignore \".*\\.(ipynb_checkpoints|pyc|pyo|pyd|git)\"\n\n.PHONY: help Makefile compile markdown clean serve\n\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\nclean:\n\tfind . -name \"*.ipynb\" -exec nbstripout {} \\;\n\trm -rf $(BUILDDIR)\n\trm -rf logs\n"
  },
  {
    "path": "docs/README.md",
    "content": "# SGLang Documentation\n\nThis is the documentation website for the SGLang project (https://github.com/sgl-project/sglang).\n\nWe recommend that new contributors start by writing documentation, which helps them quickly understand the SGLang codebase.\nMost documentation files are located under the `docs/` folder.\n\n## Docs Workflow\n\n### Install Dependencies\n\n**Linux:**\n```bash\napt-get update && apt-get install -y pandoc parallel retry\npip install -r requirements.txt\n```\n\n**macOS:**\n```bash\nbrew install pandoc parallel retry\npip install -r requirements.txt\n```\n\n### Update Documentation\n\nUpdate your Jupyter notebooks in the appropriate subdirectories under `docs/`. If you add new files, remember to update `index.rst` (or the relevant `.rst` files) accordingly.\n\n- **`pre-commit run --all-files`** manually runs all configured checks, applying fixes where possible. If it fails the first time, re-run it to ensure lint errors are fully resolved. Make sure your code passes all checks **before** creating a Pull Request.\n\n```bash\n# 1) Compile all Jupyter notebooks\nmake compile  # This step can take a long time (10+ mins). You may skip it if you are sure your added files are correct.\nmake html\n\n# 2) Compile and preview documentation locally with auto-build\n# This will automatically rebuild docs when files change\n# Open your browser at the displayed port to view the docs\nbash serve.sh\n\n# 2a) Alternative ways to serve documentation\n# Directly use make serve\nmake serve\n# With custom port\nPORT=8080 make serve\n\n# 3) Clean notebook outputs\n# nbstripout removes notebook outputs so your PR stays clean\npip install nbstripout\nfind . 
-name '*.ipynb' -exec nbstripout {} \\;\n\n# 4) Pre-commit checks and create a PR\n# After these checks pass, push your changes and open a PR on your branch\npre-commit run --all-files\n```\n\n## Documentation Style Guidelines\n\n- For common features, we prefer **Jupyter Notebooks** over Markdown so that all examples can be executed and validated by our docs CI pipeline. For complex features (e.g., distributed serving), Markdown is preferred.\n- Keep notebook execution time in mind when writing interactive Jupyter notebooks. Every notebook is executed on each commit to ensure it still runs, so follow these tips to keep documentation compilation time low:\n  - Use small models (e.g., `qwen/qwen2.5-0.5b-instruct`) in most cases to reduce server launch time.\n  - Reuse an already launched server as much as possible to avoid repeated server launches.\n- Do not use absolute links (e.g., `https://docs.sglang.io/get_started/install.html`). Always prefer relative links (e.g., `../get_started/install.md`).\n- Follow the existing examples for common patterns such as launching a server and sending a query.\n\n## Documentation Build, Deployment, and CI\n\nThe SGLang documentation pipeline is based on **Sphinx** and supports rendering Jupyter notebooks (`.ipynb`) into HTML/Markdown for web display. 
Detailed logic can be found in the [Makefile](./Makefile).\n\n### Notebook Execution (`make compile`)\n\nThe `make compile` target is responsible for executing notebooks before rendering:\n\n* Finds all `.ipynb` files under `docs/` (excluding `_build/`)\n* Executes notebooks in parallel using GNU Parallel, with a relatively small `--mem-fraction-static`\n* Wraps execution with `retry` to reduce flaky failures\n* Executes notebooks via `jupyter nbconvert --execute --inplace`\n* Records execution timing in `logs/timing.log`\n\nThis step ensures that, for each commit on the main branch, notebooks contain up-to-date outputs before rendering.\n\n### Web Rendering (`make html`)\n\nAfter compilation, Sphinx builds the website:\n\n* Reads Markdown, reStructuredText, and Jupyter notebooks\n* Renders them into HTML pages\n* Outputs the website into:\n\n```\ndocs/_build/html/\n```\n\nThis directory is the source for online documentation hosting.\n\n### Markdown Export (`make markdown`)\n\nTo support downstream consumers, we add a **new Makefile target**:\n\n```bash\nmake markdown\n```\n\nThis target:\n\n* Does **not modify** `make compile`\n* Copies `.md` and `.rst` files as-is, and additionally converts `.rst` files to Markdown with `pandoc`\n* Scans all `.ipynb` files (excluding `_build/`)\n* Converts notebooks directly to Markdown using `jupyter nbconvert --to markdown`\n* Writes Markdown artifacts into the existing build directory:\n\n```\ndocs/_build/html/markdown/<relative-path>.md\n```\n\nExample:\n\n```\ndocs/advanced_features/lora.ipynb\n→ docs/_build/html/markdown/advanced_features/lora.md\n```\n\n### CI Execution\n\nIn our [CI](https://github.com/sgl-project/sglang/blob/main/.github/workflows/release-docs.yml), the documentation pipeline first executes all notebooks and then renders HTML and Markdown:\n\n```bash\nmake compile    # execute notebooks (ensure outputs are up to date)\nmake html       # build website as usual\nmake markdown   # export markdown artifacts into _build/html/markdown\n```\n\nThen, the compiled results are force-pushed to 
[sgl-project.io](https://github.com/sgl-project/sgl-project.github.io) for rendering. In other words, sgl-project.io is push-only: all changes to the SGLang docs should be made directly in the SGLang main repo and then pushed to sgl-project.io.\n"
  },
  {
    "path": "docs/_static/css/custom_log.css",
    "content": ".output_area {\n    color: #615656;\n}\n\ntable.autosummary td {\n    width: 50%;\n}\n\nimg.align-center {\n    display: block;\n    margin-left: auto;\n    margin-right: auto;\n}\n\n.output_area.stderr {\n    color: #d3d3d3 !important;\n}\n\n.output_area.stdout {\n    color: #d3d3d3 !important;\n}\n\ndiv.output_area.stderr {\n    color: #d3d3d3 !important;\n}\n\ndiv.output_area.stdout {\n    color: #d3d3d3 !important;\n}\n"
  },
  {
    "path": "docs/_static/css/readthedocs.css",
    "content": "table.autosummary td {\n  width: 50%\n}\n\nimg.align-center {\n  display: block;\n  margin-left: auto;\n  margin-right: auto;\n}\n"
  },
  {
    "path": "docs/advanced_features/attention_backend.md",
    "content": "# Attention Backend\n\nSGLang supports a large variety of attention backends. Each of them has different pros and cons.\nYou can test them according to your needs.\n\n```{important}\nSelecting an optimal attention backend is crucial for maximizing your performance. Different backends excel in various scenarios, so choose based on your model, hardware, and use case. Not all backends are supported on all platforms and model architectures.\n\nIf you don't specify `--attention-backend`, SGLang makes a best effort to automatically select the most performant backend based on your hardware and model architecture.\n```\n\n## Support Matrix\n\nThe support matrix is split into two parts: MHA (standard attention) and MLA (multi-head latent attention). For an explanation of the key differences between MHA and MLA, please see the [SGLang documentation on DeepSeek MLA](../basic_usage/deepseek_v3.md#multi-head-latent-attention-mla-throughput-optimizations) and the original [DeepSeek MLA paper](https://arxiv.org/pdf/2405.04434).\n\n### MHA Backends\n\n| **Backend**                     | **Page Size > 1 (native)** | **FP8 KV Cache** | **FP4 KV Cache** | **Spec topk=1** | **Spec topk>1** | **Sliding Window** | **MultiModal** |\n|---------------------------------|-----------------------------|------------------|-----------------|-----------------|-----------------|--------------------|----------------|\n| **FlashInfer**                  | ✅                          | ✅               | ❌              | ✅              | ✅              | ✅                 | ❌             |\n| **FA3 (FlashAttention 3)**      | ✅                          | ✅               | ❌              | ✅              | ✅              | ✅                 | ✅             |\n| **FA4 (FlashAttention 4)**      | 128                         | ❌               | ✅              | ❌              | ❌              | ❌                 | ✅             |\n| **Triton**                      | ❌                         
 | ✅               | ✅              | ✅              | ✅              | ✅                 | ✅             |\n| **Torch Native (SDPA)**         | ❌                          | ✅               | ✅              | ❌              | ❌              | ❌                 | ✅             |\n| **FlexAttention (PyTorch)**     | ❌                          | ❌               | ✅              | ❌              | ❌              | ❌                 | ❌             |\n| **TRTLLM MHA**                  | 16, 32 or 64                | ✅               | ✅              | ✅              | ❌              | ✅                 | ❌             |\n| **Dual Chunk FlashAttention**   | ✅                          | ❌               | ❌              | ❌              | ❌              | ❌                 | ❌             |\n| **AITER (ROCm)**                | ✅                          | ✅               | ❌              | ✅              | ✅              | ✅                 | ✅             |\n| **Wave (ROCm)**                 | ✅                          | ❌               | ❌              | ❌              | ❌              | ❌                 | ❌             |\n| **Ascend (NPU)**                | ✅                          | ❌               | ❌              | ✅              | ❌              | ✅                 | ✅             |\n| **Intel XPU**                   | ✅                          | ❌               | ❌              | ❌              | ❌              | ✅                 | ❌             |\n| **Intel AMX (CPU)**             | ❌                          | ❌               | ❌              | ❌              | ❌              | ❌                 | ❌             |\n\n### MLA Backends\n\n| **Backend**                | **Native Page Sizes**     | **FP8 KV Cache** | **FP4 KV Cache** | **Chunked Prefix Cache** | **Spec topk=1** | **Spec topk>1** |\n|----------------------------|---------------------------|------------------|------------------|--------------------------|-----------------|-----------------|\n| 
**FlashInfer MLA**         | 1                         | ❌               | ✅               | ✅                       | ✅              | ❌              |\n| **FlashMLA**               | 64                        | ✅               | ✅               | ✅                       | ✅              | ❌              |\n| **Cutlass MLA**            | 128                       | ✅               | ✅               | ✅                       | ✅              | ❌              |\n| **TRTLLM MLA (Blackwell)** | 32 or 64                  | ✅               | ✅               | ✅                       | ✅              | ❌              |\n| **FA3 (FlashAttention 3)** | n/a                       | ❌               | ❌               | ✅                       | ✅              | ⚠️ (page_size=1 only) |\n| **Triton**                 | n/a                       | ❌               | ❌               | ❌                       | ✅              | ⚠️ (page_size=1 only) |\n| **FA4**                    | 1                         | ❌               | ✅               | ✅                       | ❌              | ❌              |\n| **Ascend MLA (NPU)**       | 128                       | ❌               | ❌               | ❌                       | ❌              | ❌              |\n\n```{note}\nMultimodal attention is selected by `--mm-attention-backend`. The \"MultiModal\" column indicates whether a corresponding multimodal implementation exists for that backend family.\n```\n\n```{note}\n- FlashAttention 4 supports both prefill and decode on SM90 (Hopper) and SM100 (Blackwell). FA4 MLA supports `page_size = 1`; FA4 MHA requires `page_size = 128`. On SM100, this is auto-enforced by the server; on SM90, users must set `--page-size 128` manually.\n- NSA is specifically designed for [DeepSeek V3.2 DSA](https://lmsys.org/blog/2025-09-29-deepseek-V32/). 
See the [DSA Attention Backend (NSA)](#dsa-attention-backend-nsa) section and [DeepSeek V3.2 deployment guide](../basic_usage/deepseek_v32.md) for details.\n```\n\n```{warning}\n**FA4 on Hopper (SM90):** FA4 decode speed decreases as sequence length grows due to lack of SplitKV support. At batch=1 compared to FA3 on H100: ~-10% at 2K tokens, ~-18% at 4K, ~-31% at 8K, ~-49% at 16K. Larger batch sizes reduce the gap (e.g., batch=8: ~-2% at 2K, ~-8% at 4K). Blackwell (SM100) is not affected.\n```\n\n```{note}\nFor the KV4 FA4 scenario, FA4 requires using a different --decode-attention-backend to run. Except for trtllm_mha being incompatible with FA4, all other decode backends behave as shown in the table.\n```\n\n```{tip}\nSpeculative decoding topk: `topk` is the number of draft tokens sampled per step from the draft model. `topk = 1` follows classic EAGLE; `topk > 1` explores multiple branches and requires backend support in both draft and verification paths.\n```\n\n```{note}\n**Speculative Decoding V2 (Spec V2):** Spec V2 uses overlap scheduling (`SGLANG_ENABLE_SPEC_V2=True`) that benefits various attention backends. Requires `--speculative-eagle-topk 1` and currently applies to EAGLE and EAGLE3.\n\n**Verified backends:** TRTLLM MLA, TRTLLM MHA, FA3, Ascend (NPU), Triton.\n\n**Limited support:** FlashInfer can run under Spec V2, but its plan stream (used for split-KV optimization) introduces a synchronization point that limits overlap benefits.\n```\n\n```{tip}\nPage size controls how many tokens are grouped into a KV cache block. For the prefix cache to take effect, the number of tokens must fill at least one complete page. For example, if your prompt is only 32 tokens and `page_size = 64`, it won't fill a complete page and cannot be matched in the prefix cache (pages cannot be padded). With 65 tokens and `page_size = 64`, only the first page of 64 tokens will be cached and matched; the remaining 1 token is discarded. 
Use `page_size = 1` for maximum prefix reuse (token-level matching). Note that higher page sizes generally improve attention kernel performance, so prefer `page_size > 1` when prefix cache reuse is not critical.\n```\n\nMany backends that do not natively operate on pages can emulate `page_size > 1` at the wrapper layer by expanding page tables to per-token indices. The \"Page Size > 1 (native)\" column indicates true in-kernel paging. Some backends require fixed native page sizes and cannot be reduced/emulated differently: TRTLLM MHA (16/32/64), TRTLLM MLA (32/64), FlashMLA (64), Cutlass MLA (128), Ascend (128).\n\nMLA page-size constraints:\n- FlashInfer MLA: page_size = 1.\n- FlashMLA: page_size = 64.\n- Cutlass MLA: page_size = 128.\n- TRTLLM MLA: page_size ∈ {32, 64}.\n\n### GDN Attention Backends\n\nGDN (Gated Delta Network) is a linear attention mechanism with O(n) complexity, used in hybrid models that alternate GDN linear attention layers with standard full attention layers. GDN is **not** selected via `--attention-backend`; it is automatically activated when the model architecture requires it (e.g., Qwen 3.5, Qwen 3 Next, Jet Nemotron, Jet VLM).\n\nThe GDN linear attention layers have their own kernel backends, selected via `--linear-attn-backend` (default: `triton`). 
You can override the kernel per phase with `--linear-attn-decode-backend` and `--linear-attn-prefill-backend`.\n\n| **Backend**              | **Decode** | **Prefill / Extend** | **Spec Decoding (Target Verify)** |\n|--------------------------|------------|----------------------|-----------------------------------|\n| **Triton (CUDA)**        | ✅         | ✅                   | ✅                                |\n| **Triton (AMD/ROCm)**    | ✅         | ✅                   | ✅                                |\n| **Triton (NPU)**         | ✅         | ✅                   | ❌                                |\n| **Triton (CPU)**         | ✅         | ✅                   | ❌                                |\n| **CuTe DSL (CUDA only)**| ✅         | ❌                   | ❌                                |\n\n```{important}\nGDN models are hybrid: the full-attention layers still require a standard `--attention-backend`. Platform constraints for the full-attention backend on hybrid GDN models:\n- **Blackwell (e.g., B200)**: `triton`, `trtllm_mha`, or `fa4` only.\n- **NPU (Ascend)**: `ascend` only.\n- **AMD (ROCm)**: `triton` recommended.\n- **Other CUDA (Hopper, Ampere, etc.)**: auto-selection works; no special constraints.\n```\n\n### DSA Attention Backend (NSA)\n\nDSA (Deepseek Sparse Attention) is a native sparse attention mechanism used by [DeepSeek V3.2](https://lmsys.org/blog/2025-09-29-deepseek-V32/). It is activated automatically when the model architecture requires it and is selected via `--attention-backend nsa`.\n\nInternally, the NSA backend dispatches to different sub-backends for prefill and decode phases. 
You can override these with `--nsa-prefill-backend` and `--nsa-decode-backend`:\n\n| **Sub-backend**       | **Prefill** | **Decode** | **Notes**                                     |\n|-----------------------|-------------|------------|-----------------------------------------------|\n| **flashmla_sparse**   | ✅          | ✅         | Default prefill on Hopper and Blackwell (bf16) |\n| **flashmla_kv**       | ✅          | ✅         | Default decode for FP8 on Blackwell with DP   |\n| **flashmla_auto**     | ✅          | ❌         | Auto-selects flashmla_sparse or flashmla_kv based on kv_cache_dtype |\n| **fa3**               | ✅          | ✅         | Default decode on Hopper (bf16)               |\n| **trtllm**            | ✅          | ✅         | Default decode on Blackwell (bf16); default for both on Blackwell without DP |\n| **tilelang**          | ✅          | ✅         | Default on AMD (ROCm)                         |\n| **aiter**             | ✅          | ✅         | AMD-specific kernel library (requires aiter package) |\n\nFor deployment examples, see the [DeepSeek V3.2 deployment guide](../basic_usage/deepseek_v32.md).\n\n### Hybrid attention (different backends for prefill vs decode) (Experimental)\n\n```{warning}\nHybrid attention is an experimental feature.\n```\n\nYou can mix-and-match attention backends for prefill and decode. This is useful when one backend excels at prefill and another excels at decode. For the implementation details, please see `python/sglang/srt/layers/attention/hybrid_attn_backend.py`.\n\n```bash\n# Example: Prefill with FA4, Decode with TRTLLM MLA (Blackwell)\npython3 -m sglang.launch_server \\\n  --model-path nvidia/DeepSeek-R1-FP4 \\\n  --tp 8 \\\n  --attention-backend trtllm_mla \\\n  --moe-runner-backend flashinfer_trtllm \\\n  --quantization modelopt_fp4 \\\n  --prefill-attention-backend fa4\n```\n\n#### Speculative decoding with hybrid attention\n\nHybrid attention also works with speculative decoding. 
The backend used for draft decoding and target verification depends on `--speculative-attention-mode`:\n\n- `--speculative-attention-mode decode` (recommended): draft/verify use the decode backend.\n- `--speculative-attention-mode prefill` (default): draft/verify use the prefill backend.\n\nConstraints when combining hybrid attention with speculative decoding:\n\n- If any attention backend is `trtllm_mha`, speculative decoding supports only `--speculative-eagle-topk 1`.\n- For paged MHA backends with `--page-size > 1` and `--speculative-eagle-topk > 1`, only `flashinfer` is supported.\n- CUDA Graph: the decode backend is always captured; the prefill backend is captured only when `--speculative-attention-mode prefill`.\n\n\n```{tip}\nIf you set only one of `--prefill-attention-backend` or `--decode-attention-backend`, the unspecified phase inherits `--attention-backend`.\nIf both are specified and differ, SGLang automatically enables a hybrid wrapper to dispatch to the chosen backend per phase.\n```\n\n## Attention Backend Selection Guide (CUDA)\n\nIf the `--attention-backend` argument is not specified, SGLang automatically selects the best backend based on the hardware (CUDA) and model architecture.\n\n### Automatic Selection Logic\n\n**1. MHA Models (e.g., Llama, Qwen)**\n- **Hopper (e.g., H100, H200)**: Defaults to `fa3` if using CUDA 12.3+ and the model configuration is supported.\n- **Blackwell (e.g., B200)**: Defaults to `trtllm_mha`, unless using speculative decoding with `topk > 1`.\n- **Other Architectures (Ampere, Ada, etc.)**: Defaults to `flashinfer` if available; otherwise falls back to `triton`.\n\n**2. 
MLA Models (e.g., DeepSeek V3)**\n- **Hopper**: Defaults to `fa3` (requires CUDA 12.3+).\n- **Blackwell**: Defaults to `flashinfer`; `trtllm_mla` is auto-selected for DeepSeek V3 models specifically.\n- **Other Architectures**: Defaults to `triton`.\n\n\n## User Guide\n\n### Launch Command for Different Attention Backends\n\n- FlashInfer (Default for Non-Hopper Machines, e.g., A100, A40)\n```bash\npython3 -m sglang.launch_server \\\n  --model meta-llama/Meta-Llama-3.1-8B-Instruct \\\n  --attention-backend flashinfer\npython3 -m sglang.launch_server \\\n  --tp 8 \\\n  --model deepseek-ai/DeepSeek-V3 \\\n  --attention-backend flashinfer \\\n  --trust-remote-code\n```\n\n- FlashAttention 3 (Default for Hopper Machines, e.g., H100, H200, H20)\n```bash\npython3 -m sglang.launch_server \\\n  --model meta-llama/Meta-Llama-3.1-8B-Instruct \\\n  --attention-backend fa3\npython3 -m sglang.launch_server \\\n  --tp 8 \\\n  --model deepseek-ai/DeepSeek-V3 \\\n  --trust-remote-code \\\n  --attention-backend fa3\n```\n\n- Triton\n```bash\npython3 -m sglang.launch_server \\\n  --model meta-llama/Meta-Llama-3.1-8B-Instruct \\\n  --attention-backend triton\npython3 -m sglang.launch_server \\\n  --tp 8 \\\n  --model deepseek-ai/DeepSeek-V3 \\\n  --attention-backend triton \\\n  --trust-remote-code\n```\n\n- FlashMLA\n```bash\npython3 -m sglang.launch_server \\\n  --tp 8 \\\n  --model deepseek-ai/DeepSeek-R1 \\\n  --attention-backend flashmla \\\n  --trust-remote-code\npython3 -m sglang.launch_server \\\n  --tp 8 \\\n  --model deepseek-ai/DeepSeek-R1 \\\n  --attention-backend flashmla \\\n  --kv-cache-dtype fp8_e4m3 \\\n  --trust-remote-code\n```\n\n- TRTLLM MLA (Optimized for Blackwell Architecture, e.g., B200)\n```bash\npython3 -m sglang.launch_server \\\n  --tp 8 \\\n  --model deepseek-ai/DeepSeek-R1 \\\n  --attention-backend trtllm_mla \\\n  --trust-remote-code\n```\n\n- TRTLLM MLA with FP8 KV Cache (Higher concurrency, lower memory footprint)\n```bash\npython3 -m 
sglang.launch_server \\\n  --tp 8 \\\n  --model deepseek-ai/DeepSeek-R1 \\\n  --attention-backend trtllm_mla \\\n  --kv-cache-dtype fp8_e4m3 \\\n  --trust-remote-code\n```\n\n- TRTLLM MHA (Optimized for Blackwell Architecture, e.g., B200)\n```bash\npython3 -m sglang.launch_server \\\n  --tp 4 \\\n  --model Qwen/Qwen3.5-35B-A3B-FP8 \\\n  --attention-backend trtllm_mha \\\n  --trust-remote-code\n```\n\n- TRTLLM MHA (XQA backend) (Optimized for SM90 and SM120, e.g., H20, H200, 5090)\n  Note that the TRTLLM XQA backend only works well with page size 64.\n```bash\npython3 -m sglang.launch_server \\\n  --tp 4 \\\n  --model Qwen/Qwen3.5-35B-A3B-FP8 \\\n  --decode-attention-backend trtllm_mha \\\n  --trust-remote-code\n```\n\n- FlashAttention 4 (MHA & MLA)\n```bash\n# FA4 for both prefill and decode on SM90/SM100\npython3 -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 \\\n  --attention-backend fa4 \\\n  --page-size 128 \\\n  --trust-remote-code\n\npython3 -m sglang.launch_server \\\n  --tp 8 \\\n  --model deepseek-ai/DeepSeek-R1 \\\n  --prefill-attention-backend fa4 \\\n  --trust-remote-code\n```\n\n- Cutlass MLA\n```bash\npython3 -m sglang.launch_server \\\n  --tp 8 \\\n  --model deepseek-ai/DeepSeek-R1 \\\n  --attention-backend cutlass_mla \\\n  --trust-remote-code\n```\n\n- Ascend\n```bash\npython3 -m sglang.launch_server \\\n  --model meta-llama/Meta-Llama-3.1-8B-Instruct \\\n  --attention-backend ascend\n```\n\n- Intel XPU\n```bash\npython3 -m sglang.launch_server \\\n  --model meta-llama/Meta-Llama-3.1-8B-Instruct \\\n  --attention-backend intel_xpu\n```\n\n- Wave\n```bash\npython3 -m sglang.launch_server \\\n  --model meta-llama/Meta-Llama-3.1-8B-Instruct \\\n  --attention-backend wave\n```\n\n- FlexAttention\n```bash\npython3 -m sglang.launch_server \\\n  --model meta-llama/Meta-Llama-3.1-8B-Instruct \\\n  --attention-backend flex_attention\n```\n\n- Dual Chunk FlashAttention\n```bash\npython3 -m sglang.launch_server \\\n  --model Qwen/Qwen2.5-14B-Instruct-1M \\\n  --attention-backend dual_chunk_flash_attn\n```\n\n
- Torch Native\n```bash\npython3 -m sglang.launch_server \\\n  --model meta-llama/Meta-Llama-3.1-8B-Instruct \\\n  --attention-backend torch_native\n```\n\n## Steps to add a new attention backend\nTo add a new attention backend, use the existing backends\n(`python/sglang/srt/layers/attention/triton_backend.py`, `python/sglang/srt/layers/attention/flashattention_backend.py`)\nas references and follow the steps below.\n\n```{note}\nLinear attention kernel backends (GDN, KDA) follow a different pattern. They implement `LinearAttnKernelBase` in `python/sglang/srt/layers/attention/linear/kernels/` and are dispatched by `GDNKernelDispatcher` / `KDAKernelDispatcher` rather than registered via `@register_attention_backend`.\n```\n\n1. Run without CUDA graph. Implement the two forward functions and the metadata initialization function:\n- forward_extend\n  - Used for prefill, prefill with KV cache, and target verification\n  - Called once per layer\n- forward_decode\n  - Used for normal decode and draft decode\n  - Called once per layer\n- init_forward_metadata\n  - Initializes the class and the common metadata shared by all layers\n  - Calls the plan function for optimizations like split_kv\n  - Called once per forward pass\n2. Run with CUDA graph. This has two phases (capture and replay), and you need to implement three functions:\n- init_cuda_graph_state\n  - Called once during the backend's lifetime\n  - Creates all common shared buffers\n- init_forward_metadata_capture_cuda_graph\n  - Called before capturing a CUDA graph\n  - Similar to init_forward_metadata, but writes the metadata to pre-defined buffers\n- init_forward_metadata_replay_cuda_graph\n  - Called before replaying a CUDA graph\n  - This function is on the critical path and needs to be fast\n"
  },
  {
    "path": "docs/advanced_features/checkpoint_engine.md",
    "content": "# Checkpoint Engine Integration\n\nThe SGLang checkpoint engine integration provides an efficient way to load model weights using a distributed checkpoint loading system. This feature significantly reduces model loading time, especially for large models and multi-node setups, by parallelizing the weight loading process across multiple processes and nodes.\n\n## Overview\n\nThe checkpoint engine integration allows SGLang to:\n- Load model weights in parallel using multiple processes\n- Distribute weight loading across multiple nodes to increase effective disk bandwidth\n- Overlap weight loading with other initialization tasks like CUDA graph capture\n- Support both single-node and multi-node deployments\n\n## Installation\n\nFirst, install the checkpoint engine package:\n\n```bash\npip install 'checkpoint-engine[p2p]'\n```\n\n## Architecture\n\nThe system consists of two main components:\n\n1. **SGLang Server**: Runs with `--wait-for-initial-weights` flag to wait for weights before becoming ready\n2. 
**Checkpoint Engine Workers**: Separate processes (managed by torchrun) that load and distribute model weights\n\nThe checkpoint engine uses a parameter server architecture with support for:\n- **Broadcast mode**: Weights are broadcast from loading processes to inference processes\n- **P2P mode**: Direct peer-to-peer weight transfer between processes\n- **All mode**: Combination of both broadcast and P2P methods\n\n## Usage Examples\n\n### Single Node Setup\n\n**Terminal 1 - Launch SGLang Server:**\n```bash\npython -m sglang.launch_server \\\n    --model-path Qwen/Qwen3-8B \\\n    --tp 8 \\\n    --load-format dummy \\\n    --wait-for-initial-weights\n```\n\n**Terminal 2 - Run Checkpoint Engine:**\n\nUsing sglang entrypoint:\n```bash\npython -m sglang.srt.checkpoint_engine.update \\\n    --update-method broadcast \\\n    --checkpoint-path /path/to/Qwen/Qwen3-8B/ \\\n    --inference-parallel-size 8\n```\n\nUsing torchrun directly:\n```bash\ntorchrun --nproc-per-node 8 \\\n    examples/checkpoint_engine/update.py \\\n    --update-method broadcast \\\n    --checkpoint-path /path/to/Qwen/Qwen3-8B/ \\\n    --inference-parallel-size 8\n```\n\n### Multi-Node Setup (2 Nodes)\n\n**Node 0:**\n\nLaunch SGLang server:\n```bash\npython -m sglang.launch_server \\\n    --model-path Qwen/Qwen3-8B \\\n    --tp 8 \\\n    --load-format dummy \\\n    --wait-for-initial-weights \\\n    --host [IP]\n```\n\nRun checkpoint engine:\n\nUsing sglang entrypoint (recommended):\n```bash\npython -m sglang.srt.checkpoint_engine.update \\\n    --update-method broadcast \\\n    --checkpoint-path /path/to/Qwen/Qwen3-8B/ \\\n    --inference-parallel-size 8\n```\n\nUsing torchrun directly:\n```bash\ntorchrun --nproc-per-node 8 \\\n    --nnodes 2 \\\n    --node-rank 0 \\\n    --master-addr [IP] \\\n    --master-port 29500 \\\n    examples/checkpoint_engine/update.py \\\n    --update-method broadcast \\\n    --checkpoint-path /path/to/Qwen/Qwen3-8B/ \\\n    --inference-parallel-size 8\n```\n\n**Node 
1:**\n\nLaunch SGLang server:\n```bash\npython -m sglang.launch_server \\\n    --model-path Qwen/Qwen3-8B \\\n    --tp 8 \\\n    --load-format dummy \\\n    --wait-for-initial-weights \\\n    --host [IP]\n```\n\nRun checkpoint engine:\n\nUsing sglang entrypoint (recommended):\n```bash\npython -m sglang.srt.checkpoint_engine.update \\\n    --update-method broadcast \\\n    --checkpoint-path /path/to/Qwen/Qwen3-8B/ \\\n    --inference-parallel-size 8\n```\n\nUsing torchrun directly:\n```bash\ntorchrun --nproc-per-node 8 \\\n    --nnodes 2 \\\n    --node-rank 1 \\\n    --master-addr [IP] \\\n    --master-port 29500 \\\n    examples/checkpoint_engine/update.py \\\n    --update-method broadcast \\\n    --checkpoint-path /path/to/Qwen/Qwen3-8B/ \\\n    --inference-parallel-size 8\n```\n\n### Multi-Node Setup with Tensor Parallelism (TP=16)\n\n**Node 0:**\n\nLaunch SGLang server:\n```bash\npython -m sglang.launch_server \\\n    --model-path Qwen/Qwen3-8B \\\n    --tp 8 \\\n    --load-format dummy \\\n    --wait-for-initial-weights \\\n    --host [IP] \\\n    --dist-init-addr [IP]:9120 \\\n    --nnodes 2 \\\n    --node-rank 0\n```\n\nRun checkpoint engine:\n\nUsing sglang entrypoint (recommended):\n```bash\npython -m sglang.srt.checkpoint_engine.update \\\n    --update-method broadcast \\\n    --checkpoint-path /path/to/Qwen/Qwen3-8B/ \\\n    --inference-parallel-size 16\n```\n\nUsing torchrun directly:\n```bash\ntorchrun --nproc-per-node 8 \\\n    --nnodes 2 \\\n    --node-rank 0 \\\n    --master-addr [IP] \\\n    --master-port 29500 \\\n    examples/checkpoint_engine/update.py \\\n    --update-method broadcast \\\n    --checkpoint-path /path/to/Qwen/Qwen3-8B/ \\\n    --inference-parallel-size 16\n```\n\n**Node 1:**\n\nLaunch SGLang server:\n```bash\npython -m sglang.launch_server \\\n    --model-path Qwen/Qwen3-8B \\\n    --tp 8 \\\n    --load-format dummy \\\n    --wait-for-initial-weights \\\n    --host [IP] \\\n    --dist-init-addr [IP]:9120 \\\n    --nnodes 2 \\\n    
--node-rank 1\n```\n\nRun checkpoint engine:\n\nUsing sglang entrypoint (recommended):\n```bash\npython -m sglang.srt.checkpoint_engine.update \\\n    --update-method broadcast \\\n    --checkpoint-path /path/to/Qwen/Qwen3-8B/ \\\n    --inference-parallel-size 16\n```\n\nUsing torchrun directly:\n```bash\ntorchrun --nproc-per-node 8 \\\n    --nnodes 2 \\\n    --node-rank 1 \\\n    --master-addr [IP] \\\n    --master-port 29500 \\\n    examples/checkpoint_engine/update.py \\\n    --update-method broadcast \\\n    --checkpoint-path /path/to/Qwen/Qwen3-8B/ \\\n    --inference-parallel-size 16\n```\n\n## Configuration Options\n\n### SGLang Server Options\n\n- `--load-format dummy`: Use dummy format for initial loading (allows overlapping with other tasks)\n- `--wait-for-initial-weights`: Wait for checkpoint engine to provide weights before becoming ready\n- `--host`: Host address for multi-node setups\n- `--dist-init-addr`: Distributed initialization address for tensor parallelism\n\n### Checkpoint Engine Options\n\n- `--update-method`: Weight update method (`broadcast`, `p2p`, or `all`)\n- `--checkpoint-path`: Path to model checkpoint directory\n- `--inference-parallel-size`: Number of inference parallel processes\n- `--endpoint`: SGLang server endpoint (default: `http://localhost:19730`)\n- `--checkpoint-name`: Name for the checkpoint (default: `my-checkpoint-iter-0`)\n- `--save-metas-file`: File to save checkpoint metadata\n- `--load-metas-file`: File to load checkpoint metadata from\n- `--uds`: Unix domain socket path for communication\n- `--weight-version`: Version identifier for weights\n\n## Performance Benefits\n\nThe checkpoint engine provides significant time savings in two main aspects:\n\n1. **Multi-node Loading**: Each node only loads a portion of weights from disk, effectively increasing disk bandwidth. More participating nodes provide greater acceleration. 
Preliminary tests show a 20-second speedup when loading DeepSeek-R1 on H20-3e with two nodes.\n\n2. **Single Process Optimization**: Using the dummy load format allows disk-to-CPU transfer to overlap with CUDA graph capture and other initialization tasks, providing additional time savings.\n\n## Troubleshooting\n\n- Ensure the checkpoint engine package is installed: `pip install 'checkpoint-engine[p2p]'`\n- Verify network connectivity between nodes in multi-node setups\n- Check that the checkpoint path contains valid model files\n- Monitor logs for connection errors between the SGLang server and the checkpoint engine\n- Use the `--sleep-time` parameter to add delays if needed for debugging\n\n## References\n\n- [Checkpoint Engine Repository](https://github.com/MoonshotAI/checkpoint-engine)\n"
  },
  {
    "path": "docs/advanced_features/cuda_graph_for_multi_modal_encoder.md",
    "content": "# CUDA Graph for Multi-Modal Encoder in SGLang\n\n## Motivation\n\nIn multimodal reasoning services, the visual encoder (ViT / Vision Transformer) typically has a few characteristic traits:\n\n- **Many layers, fragmented operators**: Each layer includes LN, QKV projections, attention, MLP, residual connections, etc., resulting in extremely frequent kernel launches.\n- **Server-side “small batch / low latency” is common**: The batch size is very small (sometimes effectively 1 after the batch is “flattened”), so kernel launch overhead accounts for a large portion of end-to-end latency.\n- **Input token count (number of patches) varies frequently**: Different image/video resolutions and different batch compositions lead to different sequence lengths S, and this is precisely the biggest obstacle for CUDA Graph (unstable shapes).\n\nCUDA Graph targets exactly this launch overhead: it captures a long sequence of GPU kernels with fixed shapes and fixed memory addresses into a graph; later, for the same shapes, it can replay the graph directly, dramatically reducing launch overhead and making GPU scheduling more compact.\n\nThis motivated us to enable CUDA Graph for the ViT to improve its performance.\n\n## Design and Restrictions\n\nThe CUDA Graph-enabled ViT logic is built on ViTCudaGraphRunner. This runner captures the \"blocks + merger + deepstack merger (optional)\" part of a vision transformer into a CUDA graph and replays it for identical shapes. See the following design considerations and restrictions for more details.\n\n### Dynamic inputs to fit static constraints of CUDA Graph\n\nVariable sequence length S is very common in ViT, while CUDA Graph requires fixed shapes. The solution is to build a graph cache keyed by S (e.g., graph_key = S). The first time a new S is seen, a graph is captured for it; afterwards, that graph is replayed.\n\n
If there are many distinct S values, VRAM usage grows, because each captured graph holds its own graph-private memory pool.\n\n### Stable addresses\n\nEverything \"parameter-like\" becomes a static buffer:\n\n- block_input / block_ws / block_output\n- cu_full_len / cu_window_len and their kk variants\n- sin_cos_ws\n\nThis satisfies the underlying CUDA Graph requirement: during replay, tensors cannot be swapped; only their contents can be modified.\n\n### Attention backend arguments\nAttention backend arguments are fixed inside the graph:\n\n- TritonAttn expects [cu_seqlens, cu_seqlens_kk, max_len]\n- FA3 expects [cu_seqlens, max_len]\n\nmax_len is frozen as an int constant.\ncu_seqlens is cached into a dict during create_graph(), and its contents are not updated during subsequent replays.\n\nFor a given graph_key = S, not only must the input shape match, but the segmentation pattern in cu_seqlens (and window seqlens) must also be identical. Otherwise, attention will segment the sequence incorrectly.\n\n
### Rotary buffer management\nThe feature reallocates a larger sin_cos_ws when seq_len increases.\nmax_content_len determines the maximum size of the allocated rotary buffer.\n\n## Command Example\nYou can enable CUDA Graph for ViT by setting the environment variable `SGLANG_VIT_ENABLE_CUDA_GRAPH=1`, for example:\n```bash\nSGLANG_VIT_ENABLE_CUDA_GRAPH=1 \\\npython3 -m sglang.launch_server \\\n  --model Qwen/Qwen3-VL-8B-Instruct\n```\nYou can also run CUDA Graph for ViT together with the Piecewise CUDA Graph feature by setting the environment variable `SGLANG_VIT_ENABLE_CUDA_GRAPH=1` and passing `--enable-piecewise-cuda-graph`, for example:\n```bash\nSGLANG_VIT_ENABLE_CUDA_GRAPH=1 \\\npython3 -m sglang.launch_server \\\n  --model Qwen/Qwen3-VL-8B-Instruct \\\n  --piecewise-cuda-graph-max-tokens 4096 \\\n  --enable-piecewise-cuda-graph \\\n  --piecewise-cuda-graph-compiler eager\n```\n\n## Known supported models\n- Qwen2.5-VL (https://github.com/sgl-project/sglang/pull/14422)\n- Qwen3-VL (https://github.com/sgl-project/sglang/pull/15320)\n"
  },
  {
    "path": "docs/advanced_features/deterministic_inference.md",
    "content": "# Deterministic Inference\n\n## Why Deterministic Inference Matters\n\nDeterministic inference ensures consistent LLM outputs across runs, which is critical for:\n- **Reinforcement Learning**: Ensures consistent logprobs across runs, reducing stochastic noise and making RL training more stable, reproducible, and debuggable.\n- **Testing & Debugging**: Enables reproducible validation\n- **Production**: Improves reliability and user experience\n\nEven with `temperature=0`, standard LLM inference can produce different outputs due to dynamic batching and varying reduction orders in GPU kernels.\n\n## The Root Cause of Non-Determinism\n\nThe main source is **varying batch sizes**. Different batch sizes cause GPU kernels to split reduction operations differently, leading to different addition orders. Due to floating-point non-associativity (`(a + b) + c ≠ a + (b + c)`), this produces different results even for identical inputs.\n\n\n## SGLang's Solution\n\nBuilding on [Thinking Machines Lab's batch-invariant operators](https://github.com/thinking-machines-lab/batch_invariant_ops), SGLang achieves fully deterministic inference while maintaining compatibility with chunked prefill, CUDA graphs, radix cache, and non-greedy sampling. 
The development roadmap for deterministic inference features can be found in this [issue](https://github.com/sgl-project/sglang/issues/10278).\n\n### Supported Backends\n\nDeterministic inference is only supported with the following three attention backends: **FlashInfer**, **FlashAttention 3 (FA3)**, and **Triton**.\n\nThe following table shows feature compatibility for deterministic inference across different attention backends:\n\n| Attention Backend | CUDA Graph | Chunked Prefill | Radix Cache | Non-greedy Sampling (Temp > 0) |\n|-------------------|------------|-----------------|-------------|---------------------|\n| **FlashInfer** | ✅ Yes | ✅ Yes | ❌ No | ✅ Yes |\n| **FlashAttention 3 (FA3)** | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |\n| **Triton** | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |\n\n## Usage\n\n### Basic Usage\n\nEnable deterministic inference by adding the `--enable-deterministic-inference` flag:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path Qwen/Qwen3-8B \\\n    --attention-backend fa3 \\\n    --enable-deterministic-inference\n```\n\n### Server Arguments\n\n| Argument | Type/Default | Description |\n|----------|--------------|-------------|\n| `--enable-deterministic-inference` | flag; default: disabled | Enable deterministic inference with batch-invariant operations |\n| `--attention-backend` | string; default: fa3 | Choose attention backend (flashinfer, fa3, or triton) |\n\n### Example Configurations\n\n#### Qwen3-8B\n```bash\npython3 -m sglang.launch_server \\\n    --model-path Qwen/Qwen3-8B \\\n    --attention-backend flashinfer \\\n    --enable-deterministic-inference\n```\n\n#### Llama Models\n```bash\npython3 -m sglang.launch_server \\\n    --model-path meta-llama/Llama-3.1-8B-Instruct \\\n    --attention-backend fa3 \\\n    --enable-deterministic-inference\n```\n\n#### Qwen3-30B-A3B (MoE Model)\n```bash\npython3 -m sglang.launch_server \\\n    --model-path Qwen/Qwen3-30B-A3B \\\n    --attention-backend fa3 \\\n    
--enable-deterministic-inference\n```\n\n### Deterministic Inference with Non-Greedy Sampling (Temperature > 0)\n\nSGLang supports deterministic inference even with non-greedy sampling by using sampling seeds. This is particularly useful for reinforcement learning scenarios like GRPO (Group Relative Policy Optimization) where you need multiple diverse but reproducible responses.\n\n#### Default Behavior\n\nBy default, SGLang uses a sampling seed of `42` for reproducible sampling:\n\n```python\nimport requests\n\nresponse = requests.post(\n    \"http://localhost:30000/generate\",\n    json={\n        \"text\": \"Tell me a joke\",\n        \"sampling_params\": {\n            \"temperature\": 0.8,  # Non-greedy sampling\n            \"max_new_tokens\": 128,\n        },\n    },\n)\nprint(response.json())\n# This will always produce the same response across runs\n```\n\n#### Generating Multiple Reproducible Responses\n\nTo sample different responses from the same prompt while maintaining reproducibility (e.g., for GRPO training), provide different sampling seeds in your requests:\n\n```python\nimport requests\n\n# Prepare a list of sampling seeds for different responses\nsampling_seeds = [42, 43, 44, 45, 46]\n\nresponses = []\nfor seed in sampling_seeds:\n    response = requests.post(\n        \"http://localhost:30000/generate\",\n        json={\n            \"text\": \"Tell me a joke\",\n            \"sampling_params\": {\n                \"temperature\": 0.8,\n                \"max_new_tokens\": 128,\n                \"sampling_seed\": seed,  # Specify sampling seed\n            },\n        },\n    )\n    responses.append(response.json())\n\n# Each seed will produce a different but reproducible response\n# Using the same seed will always produce the same response\n```\n\nThis approach ensures that:\n- Different seeds produce diverse responses\n- The same seed always produces the same response across different runs\n- Results are reproducible for debugging and 
evaluation\n\n\n## Verification\n\nRun deterministic tests to verify consistent outputs:\n\n```bash\n# Single test: same prompt, varying batch sizes\npython3 -m sglang.test.test_deterministic --test-mode single --n-trials 50\n\n# Prefix test: prompts with different prefix lengths\npython3 -m sglang.test.test_deterministic --test-mode prefix --n-trials 50\n\n# Radix Cache Consistency mode: test radix cache determinism (cached vs uncached prefill)\npython3 -m sglang.test.test_deterministic --test-mode radix_cache\n```\n\nExpected result: All tests should show `Unique samples: 1` (perfectly deterministic).\n"
  },
  {
    "path": "docs/advanced_features/dp_dpa_smg_guide.md",
    "content": "# DP, DPA and SGLang DP Router\n\nThis guide explains the difference between Data Parallelism (DP) and Data Parallelism Attention (DPA), how to enable each mode correctly, and how to use the SGLang Model Gateway (SMG) for production-grade DP deployments.\n\n## Data Parallelism (DP)\n\n**Data Parallelism (DP)** is a widely used parallelism strategy that replicates the entire model across multiple GPU sets and processes different batches of requests in parallel. Each GPU set handles independent requests. With the routing algorithms provided by the SGLang Model Gateway (introduced later in this guide), the throughput of your serving system can scale nearly linearly with the number of replicas.\n\n### Key characteristics\n\n- Each replica has a full copy of the model\n- Requests are distributed across replicas\n- No inter-replica communication during a request's inference (for simple DP)\n\n## Data Parallelism Attention (DPA)\n\n**Data Parallelism Attention (DPA)**, also known as DP Attention, is an advanced parallelism strategy. While DPA provides the most significant benefits for **Multi-Head Latent Attention (MLA)** models (such as DeepSeek, MiniMax, Kimi-K2), it also supports **standard attention models** such as Qwen.\n\n### The Problem with Tensor Parallelism for MLA Models\n\nThe most common parallelism strategy for inference is **Tensor Parallelism (TP)**. However, TP might not be the most efficient strategy for certain models. For example, DeepSeek models use MLA and have only **one KV head**. 
If we use tensor parallelism on 8 GPUs, that single KV head cannot be split across GPUs and must be replicated on each of them, which leads to:\n\n- **Duplicated KV cache** across all GPUs\n- **Wasted memory** that limits batch size\n- **Reduced throughput** due to memory constraints\n\n### How DPA Works\n\nDPA addresses these limitations by applying **data parallelism specifically to the attention component**.\n\n<table>\n<tr>\n<td width=\"50%\">\n<img src=\"../_static/image/dpa.png\" alt=\"DPA + EP Architecture\" width=\"100%\">\n</td>\n<td width=\"50%\" valign=\"top\">\n\n**Each DP replica:**\n\n- Processes different batches independently (can be in different forward modes: prefill, decode, or idle)\n- Maintains its own KV cache (no duplication)\n- Enables significantly larger batch sizes due to memory savings\n\n**Communication patterns in DPA + EP:**\n\n- **All2All (Dispatch)**: Routes tokens to expert sub-groups based on gating decisions\n- **All2All (Combine)**: Gathers computed results from experts back to original token positions\n\n</td>\n</tr>\n</table>\n\n### Key benefits of DPA\n\n1. **Significantly reduced KV cache memory**: Each DP replica only stores KV cache for its own batches\n2. **Larger batch sizes**: Memory savings enable larger batch sizes\n3. **Improved decoding throughput**: Significant throughput gains for MLA-based models\n4. **Independent forward modes**: Each DP replica can be in different forward modes (prefill, decode, or idle) and handles its assigned batches independently during attention computation\n\n### DPA with Expert Parallelism for MoE\n\nFor MoE models like DeepSeek, DPA is **often** paired with Expert Parallelism (EP) for best throughput at scale. 
However, **DPA does not require EP**: you can enable DPA without EP if your deployment does not need expert sharding.\n\nWhen paired with DPA, EP lets you:\n\n- Distribute 256+ expert weights across GPUs (they cannot fit on a single GPU)\n- Enable efficient all-to-all token routing via DeepEP\n- Scale to large clusters (up to 5x throughput improvement over vanilla TP)\n\n### Recommended setup for DeepSeek\n\n```bash\npython -m sglang.launch_server \\\n    --model-path deepseek-ai/DeepSeek-V3 \\\n    --tp 8 \\\n    --dp-size 8 \\\n    --ep 8 \\\n    --enable-dp-attention \\\n    --moe-a2a-backend deepep \\\n    --moe-runner-backend deep_gemm\n```\n\n> **Note**: `--dp-size` must be explicitly set when using `--enable-dp-attention`. If `dp_size` is 1 (default), DPA will be disabled.\n\nFor detailed EP configuration (DeepEP, Two-Batch Overlap, EPLB), see [Expert Parallelism](expert_parallelism.md).\n\n### Target Models\n\nDPA supports the following model architectures:\n\n- **MLA (Multi-Head Latent Attention) models** - where DPA provides the most significant benefits:\n  - DeepSeek family (DeepSeek-V2, DeepSeek-V3, DeepSeek-R1)\n  - MiniMax models\n  - Kimi-K2\n  - Other models using MLA architecture\n\n- **Standard attention models** - also supported:\n  - Qwen models (see [PR #6121](https://github.com/sgl-project/sglang/pull/6121))\n\nFor models with standard GQA, such as Llama, standard DP or TP is typically recommended.\n\nTo enable DPA, add `--enable-dp-attention` to your server launch command, together with a `--dp-size` greater than 1.\n\n### Activation Logic\n\nDPA is enabled explicitly via server arguments (CLI or config). You must set both `--dp-size` and `--enable-dp-attention`:\n\n```bash\npython -m sglang.launch_server \\\n    --model-path deepseek-ai/DeepSeek-V3 \\\n    --tp 8 \\\n    --dp-size 8 \\\n    --enable-dp-attention\n```\n\n**Important**: `--dp-size` must be greater than 1 for DPA to work. When `dp_size == 1` (default), `--enable-dp-attention` is automatically disabled. 
The constraint `tp_size % dp_size == 0` must also be satisfied.\n\n### Standard DP for MLA models\n\nMLA models also support standard DP. To use it, launch each replica of the MLA model independently (each replica may itself run with DPA enabled), then launch an SMG and connect all the replicas to it. SMG is described in detail below.\n\n## Modern Data Parallelism SGLang Model Gateway (SMG)\n\n### Native DP Mode\n\nNative DP (built-in Data Parallelism) in SGLang creates multiple worker processes within a single SGLang instance, managed by `DataParallelController` and configured with `--dp-size`.\n\n```bash\n# Native DP mode\npython -m sglang.launch_server \\\n    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n    --dp-size 4\n```\n\n**Limitations:**\n\n- Built-in in-process load balancing only (e.g., `round_robin`, `total_requests`, `total_tokens`)\n- No cache-aware routing\n- Limited observability and metrics\n- No fault tolerance or circuit breakers\n- Not suitable for production workloads\n\n⚠️ Native DP is **not recommended**. It remains in use mainly by some older RL frameworks. For all other use cases, use the SGLang Model Gateway (SMG) for data parallelism.\n\n### SMG-Based DP (Recommended)\n\nDevelopment of SGLang Model Gateway (SMG), formerly named the SGLang DP Router, started in September 2024; it was built in Rust as a production-ready DP routing system. It began as a DP router, and its scope has since expanded to coordinating RL, PD disaggregation, and other scenarios. This doc only discusses SMG's use for DP routing. 
For other uses, see the [SGLang Model Gateway Documentation](sgl_model_gateway.md).\n\n> SMG is written in Rust rather than Python to achieve production-level routing performance and keep routing overhead to a minimum.\n\n**We strongly recommend using the SGLang Model Gateway (SMG) for production-grade Data Parallelism.** SMG provides significant advantages over native DP mode.\n\n```bash\n# SMG-based DP mode (Recommended)\npython -m sglang_router.launch_server \\\n    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n    --dp-size 4\n```\n\n⚠️ Note that **SMG and Native DP share the same launch parameter, `--dp-size`**, but their entrypoints differ: Native DP uses `python -m sglang.launch_server`, while SMG uses `python -m sglang_router.launch_server`.\n\n**Advantages of SMG-Based DP:**\n\n| Feature | Native DP | SMG-Based DP |\n|---------|-----------|--------------|\n| **Load Balancing** | Built-in in-process methods | Advanced policies (cache-aware, power-of-two, etc.) 
|\n| **Cache Awareness** | ❌ No | ✅ Yes - significantly higher cache hit rate |\n| **Throughput** | Baseline | Significant improvement |\n| **Multi-Node Support** | Limited | ✅ Full support |\n| **Worker Health Monitoring** | Basic | ✅ Circuit breakers, health checks |\n| **Reliability** | Basic | ✅ Retries, rate limiting, queuing |\n| **Observability** | Basic metrics | ✅ 40+ Prometheus metrics, OpenTelemetry |\n| **Hot Worker Add/Remove** | ❌ No | ✅ Yes |\n\n### SMG's Performance\n\nThe cache-aware routing policy in SMG significantly improves performance for workloads with shared prefixes:\n\n| Metric | Without Cache-Aware | With Cache-Aware SMG |\n|--------|---------------------|----------------------|\n| Throughput (token/s) | 82,665 | 158,596 (+92%) |\n| Cache Hit Rate | 20% | 75% (+275%) |\n\n*Benchmark from [SGLang v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/), workload with multiple long prefix groups, 8x A100 80GB GPUs, dp-size=8*\n\n### When to Use Each\n\n**Use Native DP when:**\n\n- Almost never; prefer SMG-based DP\n- As learning material for understanding DP routing\n\n**Use SMG-Based DP when:**\n\n- You need DP in any form\n- Production deployments\n- Multi-node distributed setups\n- Workloads with shared prefixes (high cache reuse potential)\n- You need high availability and reliability features\n- You require detailed observability and metrics\n- You want a highly efficient RL rollout system\n\nNote that for RL rollout systems, **there are four key reasons why SMG-Based DP is far better than native DP routing**. 
Details can be found at [Load Balancing Router in RL](./sglang_for_rl.md#load-balancing-router).\n\n### Quick Start For SMG\n\n**Installation**\n\n```bash\npip install sglang-router\n# or\npip install \"sglang[all]\"\n```\n\n**Option A: Co-launch Workers and SMG (Simplest)**\n\nThis is the easiest way to get started - SMG and workers are launched together:\n\n```bash\npython -m sglang_router.launch_server \\\n    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n    --dp-size 4 \\\n    --host 0.0.0.0 \\\n    --port 30000\n```\n\n**Option B: Separate Launch (Multi-Node)**\n\nFor distributed deployments across multiple machines:\n\n1. Launch workers on each node\n\n```bash\n# Node 1\npython -m sglang.launch_server \\\n    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n    --port 8000\n\n# Node 2\npython -m sglang.launch_server \\\n    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n    --port 8000\n```\n\n2. Launch SMG pointing to workers\n\n```bash\npython -m sglang_router.launch_router \\\n    --worker-urls http://node1:8000 http://node2:8000 \\\n    --policy cache_aware \\\n    --host 0.0.0.0 \\\n    --port 30000\n```\n\n**Option C: Dynamic Worker Registration**\n\nFor elastic deployments where workers can be added/removed dynamically:\n\n```bash\n# Launch SMG first\npython -m sglang_router.launch_router \\\n    --policy cache_aware \\\n    --host 0.0.0.0 \\\n    --port 30000\n\n# Register workers dynamically\ncurl -X POST http://localhost:30000/workers \\\n    -H \"Content-Type: application/json\" \\\n    -d '{\"url\": \"http://worker1:8000\"}'\n\ncurl -X POST http://localhost:30000/workers \\\n    -H \"Content-Type: application/json\" \\\n    -d '{\"url\": \"http://worker2:8000\"}'\n```\n\n### Load Balancing Policies\n\nSMG supports multiple load balancing policies:\n\n| Policy | Description | Best For |\n|--------|-------------|----------|\n| `cache_aware` | Combines cache locality with load balancing | **Recommended for most workloads** 
|\n| `round_robin` | Cycles through workers in order | Simple, predictable distribution |\n| `random` | Random worker selection | Baseline, testing |\n| `power_of_two` | Samples two workers, picks lighter one | Low latency requirements |\n\n**Cache-Aware Policy (Default, Recommended)**\n\nThe cache-aware policy provides the best performance for most workloads:\n\n```bash\npython -m sglang_router.launch_router \\\n    --worker-urls http://worker1:8000 http://worker2:8000 \\\n    --policy cache_aware \\\n    --cache-threshold 0.5 \\\n    --balance-abs-threshold 32 \\\n    --balance-rel-threshold 1.5 \\\n    --eviction-interval-secs 120 \\\n    --max-tree-size 67108864\n```\n\n**How it works:**\n\n1. Maintains an approximate radix tree for each worker based on request history\n2. Routes requests to workers with the highest prefix match (cache hit)\n3. Falls back to shortest-queue routing when load is imbalanced\n4. Automatically evicts old entries to prevent memory overflow\n\n### Best Practices\n\n1. **Start with `cache_aware` policy** - It provides the best balance between cache locality and load distribution for most workloads\n2. **Use SMG for production** - Prefer `sglang_router.launch_server` over `sglang.launch_server` for better reliability and observability\n3. 
**Enable health checks** - Configure `--router-health-check-interval-secs` to detect and remove unhealthy workers automatically\n\n**Recommended command with best practices applied:**\n\n```bash\npython -m sglang_router.launch_server \\\n    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n    --dp-size 4 \\\n    --router-policy cache_aware \\\n    --router-health-check-interval-secs 30 \\\n    --router-prometheus-port 10001 \\\n    --host 0.0.0.0 \\\n    --port 30000\n```\n\nFor advanced configuration (circuit breakers, retries, Prometheus metrics, K8s integration), see [SGLang Model Gateway Documentation](sgl_model_gateway.md).\n\n### Verifying Traffic Distribution\n\nAfter launching SMG, verify that traffic is being distributed correctly:\n\n**1. Check worker status:**\n\n```bash\ncurl http://localhost:30000/workers\n```\n\n**2. Check load distribution:**\n\n```bash\ncurl http://localhost:30000/get_loads\n```\n\n**3. Monitor metrics (if Prometheus is enabled):**\n\n```bash\n# Key metrics to check\nsmg_router_requests_total{model=\"...\"}\nsmg_worker_requests_active{worker=\"...\"}\nsglang_cache_hit_rate{source=\"...\"}\n```\n\nFor detailed metrics and monitoring setup, see [SGLang Model Gateway Documentation](sgl_model_gateway.md).\n\n## Reference\n\n| Strategy | Use Case | Key Benefit |\n|----------|----------|-------------|\n| **Native DP** (`--dp-size`) | Not recommended | Easy to understand; not Rust-based |\n| **SMG-Based DP** | **Production (recommended)** | Cache-aware routing, high availability |\n| **DPA** (`--dp-size N --enable-dp-attention`) | DeepSeek/MLA models | Eliminates KV cache duplication, improved throughput |\n| **DPA + EP** | DeepSeek MoE models | Significant throughput improvement vs vanilla TP |\n\n**Recommended production setup for DeepSeek:**\n1. Enable **DPA** for attention layers (`--dp-size 8 --enable-dp-attention`)\n2. Enable **EP** for MoE layers (`--ep 8 --moe-a2a-backend deepep`)\n3. 
Use **SMG** with **cache_aware** policy\n\n**Related documentation:**\n- [Expert Parallelism](expert_parallelism.md) - DeepEP, Two-Batch Overlap, EPLB\n- [SGLang Model Gateway Documentation](sgl_model_gateway.md) - SMG configuration & troubleshooting\n- [Large-Scale EP Blog](https://lmsys.org/blog/2025-05-05-large-scale-ep/) - 96 GPU deployment guide\n"
  },
  {
    "path": "docs/advanced_features/dp_for_multi_modal_encoder.md",
    "content": "# DP for Multi-Modal Encoder in SGLang\n\nA typical VLM architecture involves two main components: a multi-modal encoder and a text decoder.\n\nMost VLMs use a Vision Transformer (ViT) as their multi-modal encoder; it is responsible for processing visual data, extracting features (objects, colors, textures, etc.), and transforming them into a format that can be understood by the model.\n\nThe text decoder is an LLM. It processes textual data and generates output based on the encoded visual features.\n\nHowever, since the ViT is very small compared to the language decoder,\nthere is relatively little gain from TP. On the other hand, TP incurs significant communication\noverhead because an all-reduce is performed after every layer.\n\nPlacing the ViT in data parallel while keeping the LLM in tensor parallel consistently lowers TTFT and boosts end-to-end throughput. In this hybrid layout, the vision front-end becomes parallel and lightweight, while scarce interconnect bandwidth and collective ops are reserved for the LLM.\n\nData parallelism replicates the entire model across multiple GPU sets and processes different batches of requests in parallel.\n\n## Command Example\n\nYou can enable batch-level DP for the encoder by setting `--mm-enable-dp-encoder`, for example:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path Qwen/Qwen2.5-VL-7B-Instruct \\\n    --tp 2 \\\n    --mm-enable-dp-encoder\n```\n\n## Known supported models\n\n- Qwen2.5-VL (<https://github.com/sgl-project/sglang/pull/13126>)\n- Qwen3-VL (<https://github.com/sgl-project/sglang/pull/13724>)\n- InternVL (<https://github.com/sgl-project/sglang/pull/13925>)\n- GLM-4.5V & GLM-4.6V (<https://github.com/sgl-project/sglang/pull/14097>)\n"
  },
  {
    "path": "docs/advanced_features/epd_disaggregation.md",
    "content": "# EPD Disaggregation\n\n## Why and What is EPD Disaggregation?\n\nIn modern Vision-Language Model (VLM) inference, request execution naturally decomposes into three distinct stages: Encoder, Prefill, and Decode.\nThe Encoder stage performs vision preprocessing and ViT-based image encoding, which is highly compute-intensive but only required during request initialization. The Prefill stage processes the full multimodal input sequence to initialize the language model’s Key-Value (KV) cache, while the Decode stage is dominated by memory bandwidth and KV cache access for autoregressive token generation.\n\nExisting deployments typically colocate these stages within a unified execution engine, or at best apply Prefill–Decode (PD) disaggregation. However, such designs still tightly couple vision encoding with language prefill, leading to inefficient resource utilization, limited scalability for image-heavy workloads, and suboptimal scheduling under load.\n\nTo address these challenges, we introduce Encoder–Prefill–Decode (EPD) Disaggregation in SGLang. EPD further separates vision encoding from language processing, enabling independent horizontal scaling of encoder servers, improved load balancing for multimodal requests, and seamless integration with existing PD disaggregation to form a fully decoupled three-tier inference architecture.\n\n### Usage\n\nYou can launch a language-only model using `--language-only`, or an encoder-only model using `--encoder-only`.\nWhen launching a language-only model, you must additionally specify the encoder service endpoints via `--encoder-urls`.\n\nWe support multiple encoder transfer backends, including `zmq_to_scheduler`, `zmq_to_tokenizer`, and `mooncake` (the default is `zmq_to_scheduler`). 
The backend can be selected using `--encoder-transfer-backend`.\n\n### Encoder transfer with Mooncake\n\n`--encoder-transfer-backend mooncake` controls **how encoder outputs are transferred** between encoder and language/prefill services. It is an encoder transfer option and can be used independently of the global multimodal embedding cache.\n\nExample:\n\n```bash\n# encoder\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --encoder-only \\\n  --encoder-transfer-backend mooncake \\\n  --port 30000\n\n# language-only server\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --language-only \\\n  --encoder-urls http://127.0.0.1:30000 \\\n  --encoder-transfer-backend mooncake \\\n  --port 30002\n```\n\n### Global multimodal embedding cache with Mooncake\n\nSGLang also supports a Mooncake-backed **global multimodal embedding cache** for EPD workloads. When enabled on encoder servers, repeated image inputs can reuse previously computed ViT embeddings across instances instead of running the vision encoder again.\n\nThis feature is useful when:\n\n- the deployment serves repeated or overlapping image inputs,\n- encoder compute is the bottleneck, and\n- Mooncake is already available in the cluster.\n\nAt a high level, the encoder checks whether the image embedding already exists in Mooncake. Cache hits are prefetched from the global store, while misses are encoded normally and inserted into the cache in the background.\n\nTo enable it:\n\n- install and configure Mooncake in the same way as other SGLang Mooncake integrations,\n- add `--enable-mm-global-cache` on the encoder server.\n\n`--enable-mm-global-cache` controls **whether multimodal embeddings are looked up and stored in the global Mooncake cache**. 
It is separate from `--encoder-transfer-backend`, which only controls encoder output transport.\n\nFor Mooncake deployment and configuration details, see [HiCache best practices](hicache_best_practices.md#deployment-with-mooncake) and the [Mooncake backend README](../../python/sglang/srt/mem_cache/storage/mooncake_store/README.md).\n\nExample:\n\n```bash\n# Shared Mooncake configuration\nexport MOONCAKE_TE_META_DATA_SERVER=\"http://127.0.0.1:8080/metadata\"\nexport MOONCAKE_MASTER=\"127.0.0.1:50051\"\nexport MOONCAKE_PROTOCOL=\"rdma\"\nexport MOONCAKE_GLOBAL_SEGMENT_SIZE=\"4gb\"\n\n# encoder with global multimodal cache enabled\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --encoder-only \\\n  --enable-mm-global-cache \\\n  --port 30000\n\n# language-only server\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --language-only \\\n  --encoder-urls http://127.0.0.1:30000 \\\n  --port 30002\n```\n\nNotes:\n\n- This cache is for **multimodal encoder embeddings**, not the language model KV cache.\n- The feature currently uses Mooncake as the shared backing store.\n- It can be enabled regardless of which `--encoder-transfer-backend` you use.\n- It is most relevant for EPD or encoder-disaggregated VLM deployments where the same images are likely to appear across requests or instances.\n\n#### Qwen VL\n\n- EP Disaggregation\n\n```bash\n# encoder 0\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --encoder-only \\\n  --encoder-transfer-backend zmq_to_scheduler \\\n  --port 30000\n# encoder 1\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --encoder-only \\\n  --encoder-transfer-backend zmq_to_scheduler \\\n  --port 30001\n# language-only server\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --language-only \\\n  --encoder-urls http://127.0.0.1:30000 http://127.0.0.1:30001 \\\n  
--encoder-transfer-backend zmq_to_scheduler \\\n  --port 30002\n```\n\n- EPD Disaggregation\n\n```bash\n# encoder 0\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --encoder-only \\\n  --encoder-transfer-backend zmq_to_scheduler \\\n  --port 30000\n# encoder 1\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --encoder-only \\\n  --encoder-transfer-backend zmq_to_scheduler \\\n  --port 30001\n# prefill 0\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --disaggregation-mode prefill \\\n  --language-only \\\n  --encoder-urls http://127.0.0.1:30000 http://127.0.0.1:30001 \\\n  --encoder-transfer-backend zmq_to_scheduler \\\n  --port 30002\n# decode 0\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --disaggregation-mode decode \\\n  --port 30003\n# router\npython -m sglang_router.launch_router \\\n  --pd-disaggregation \\\n  --prefill http://$PREFILL_HOST:30002 \\\n  --decode http://$DECODE_HOST:30003 \\\n  --port 8000\n\n```\n\n#### gRPC Encoder (EPD)\n\nYou can run the encoder as a gRPC server while keeping prefill/decode as HTTP.\nWhen using gRPC encoders, set `SGLANG_ENCODER_MM_RECEIVER_MODE=grpc` for the\nprefill process so it uses the gRPC receiver.\n\n```bash\n# gRPC encoder\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --encoder-only \\\n  --grpc-mode \\\n  --encoder-transfer-backend zmq_to_scheduler \\\n  --port 30000\n\n# prefill (HTTP) - tell it to use gRPC receiver\nSGLANG_ENCODER_MM_RECEIVER_MODE=grpc \\\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --disaggregation-mode prefill \\\n  --language-only \\\n  --encoder-urls grpc://127.0.0.1:30000 \\\n  --encoder-transfer-backend zmq_to_scheduler \\\n  --port 30002\n\n# decode (HTTP)\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-8B-Instruct \\\n  --disaggregation-mode decode \\\n  
--port 30003\n\n# router\npython -m sglang_router.launch_router \\\n  --pd-disaggregation \\\n  --prefill http://$PREFILL_HOST:30002 \\\n  --decode http://$DECODE_HOST:30003 \\\n  --port 8000\n```\n"
  },
  {
    "path": "docs/advanced_features/expert_parallelism.md",
    "content": "# Expert Parallelism\n\nExpert Parallelism (EP) in SGLang distributes expert weights across multiple devices in Mixture-of-Experts (MoE) models, addressing memory bottlenecks and enabling efficient scaling for high-performance inference. It is particularly vital for serving large-scale MoE models where tokens are dynamically routed to specialized experts across GPUs. By leveraging optimized all-to-all communication and grouped matrix multiplications (GEMMs), EP reduces latency, boosts throughput, and minimizes idle GPU time. SGLang's EP offers strong extensibility through its modular framework, allowing seamless integration of custom kernels, backends, and optimizations without refactoring core logic, supporting diverse hardware and quantization schemes.\n\n## Supported Backends and Selection Guidance\n\nSGLang's EP integrates diverse, highly efficient backends for different use cases, allowing fine-grained control over performance trade-offs. Users specify backends via command-line flags:\n- `--moe-a2a-backend`: Selects the backend for all-to-all communication.\n- `--moe-runner-backend`: Selects the backend for MoE computation.\n\n### Backends for All-to-All Communication\n\n| Backend      | Description                                                                 | Use Cases                          |\n|--------------|-----------------------------------------------------------------------------|------------------------------------|\n| **`none` (default)** | Disables all-to-all for EP. Uses All-Reduce or All-Gather for token dispatch. | Hybrid EP and TP setups.           |\n| `deepep`     | DeepEP, a communication library for efficient token shuffling in MoE models. | Large-scale EP deployments.        |\n| `mooncake`   | An extension of DeepEP for elastic inference, leveraging RDMA for high-performance data transfers. | Elastic EP serving. 
|\n| `nixl`       | [NIXL-EP](https://github.com/ai-dynamo/nixl/tree/main/examples/device/ep), an elastic EP communication library built on NVIDIA's [NIXL](https://github.com/ai-dynamo/nixl) framework with native RDMA and NVLink support. | Elastic EP serving with fault tolerance and dynamic scaling. |\n| `mori` | MORI-EP, AMD's native all-to-all communication implementation optimized for ROCm. | AMD GPU deployments. |\n| `flashinfer` | FlashInfer implementation of all-to-all. | Large-scale EP deployments. |\n| `ascend_fuseep` | Ascend NPU native fused all-to-all communication. | Ascend NPU deployments. |\n\nDeepEP and Mooncake backends support two modes for token dispatch: `normal` mode (optimized for prefill workloads with high throughput) and `low_latency` mode (optimized for decode workloads with low latency and CUDA Graph compatibility). The MORI backend currently supports only `normal` mode. NIXL-EP currently operates in low-latency mode with CUDA Graph support. We recommend setting `--deepep-mode auto` to enable automatic dispatch mode switching at runtime. Setting `--deepep-mode normal` or `--deepep-mode low_latency` is useful for debugging or development.\n\nCurrently, DeepEP, Mooncake, NIXL-EP, `ascend_fuseep`, and MORI support only the case where `ep_size == tp_size`. For hybrid EP and TP (i.e., `ep_size < tp_size`), only the `none` backend (All-Reduce or All-Gather-based dispatching) is supported.\n\n### Backends for MoE Computation\n\n| Backend                  | Description                                                                 | Use Cases                          |\n|--------------------------|-----------------------------------------------------------------------------|------------------------------------|\n| **`auto` (default)**     | Automatically selects the optimal backend based on model architecture, hardware (e.g., NVIDIA architecture like Ampere, Hopper, Blackwell), quantization scheme (e.g., FP8, FP4), and runtime conditions. 
| General-purpose deployments; ensures compatibility and performance without user intervention. |\n| `triton`                 | Triton-based implementation for grouped GEMMs. For higher performance, we highly recommend creating [tuned configurations](https://github.com/sgl-project/sglang/blob/main/benchmark/kernels/fused_moe_triton/README.md). | Custom kernel development or scenarios requiring high extensibility with Torch compilation support. |\n| `deep_gemm`              | DeepGEMM backend optimized for MoE matrix multiplications, supporting contiguous layouts for prefill and masked layouts for decode; often JIT-compiled for performance. | Large-scale EP deployments with FP8 block-wise quantization. |\n| `cutlass`                | CUTLASS-based backend for efficient GEMMs. | NVIDIA architectures with CUTLASS support. |\n| `flashinfer_trtllm`      | FlashInfer integrated with TensorRT-LLM for accelerated MoE computations, supporting FP4 communication operators and high-performance GEMMs. | Blackwell with TRT-LLM. |\n| `flashinfer_trtllm_routed` | FlashInfer integrated with TensorRT-LLM for accelerated routed MoE computations, consuming SGLang-computed top-k expert assignments and weights. | Blackwell with TRT-LLM. |\n| `flashinfer_cutlass`     | FlashInfer combined with CUTLASS for high-performance grouped GEMMs in MoE layers, handling FP4/FP8 quantization efficiently. | Blackwell with FP4/FP8 models. |\n| `flashinfer_mxfp4`       | FlashInfer variant optimized for MXFP4 (microscaling FP4) quantization in MoE runners, focusing on memory-efficient low-precision inference. | Low-precision models with MXFP4. |\n| `flashinfer_cutedsl`     | FlashInfer with MoE kernels written in the CuTe DSL for flexible and efficient kernel generation, integrated with ModelOpt FP4 quantization. | Low-precision models with NVFP4. 
|\n\n### Examples\n\nLaunch with DeepEP and DeepGEMM for DeepSeek-V3:\n\n```bash\npython -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --moe-a2a-backend deepep --moe-runner-backend deep_gemm --tp 8 --ep 8\n```\n\n## Extensible EP Framework\n\nSGLang's EP framework provides modular abstractions for easy integration of custom kernels, backends, and optimizations. It decouples the MoE forward pass into stages (dispatch → pre-permute → core runner → post-permute → combine), enabling seamless extensions without refactoring core logic.\n\n### Framework Overview\n\nThe framework centers on `FusedMoE` as the unified entry point for a single, extensible structure. Key components include:\n- **Dispatcher**: Manages dispatch/combine for backends like DeepEP (implements `BaseDispatcher` subclasses).\n- **MoeRunner**: Orchestrates grouped-GEMM execution via `MoeRunnerCore` implementations (e.g., `TritonRunnerCore`).\n- **PermuteMethodPool**: Auto-registers layout conversions (e.g., pre/post-permute via `register_pre_permute` and `register_post_permute` for dynamic modes, or `register_fused_func` for static, torch.compile-compatible fused operations).\n- **TopK Router**: Backend-agnostic expert selection.\n\nThis design supports multiple backends via `--moe-a2a-backend` and `--moe-runner-backend`, with quantization integrated through a standardized `apply()` method. 
The computation flow ensures modularity:\n\n```\n[input_hidden_states]\n          |\n          v\n     TopK.forward -> select_experts / triton_kernels.routing / bypass\n          |\n          v\n     [TopKOutput]\n          |\n          v\n   FusedMoE.forward -> Dispatcher.dispatch -> DeepEP / bypass\n          |                     |\n          |                     v\n          |              [DispatchOutput]\n          |                     |\n          |                     v\n          |             quant_method.apply -> MoeRunner.forward\n          |                     |              |\n          |                     |              v\n          |                     | pre-permute + grouped_gemm + post-permute\n          |                     |              |\n          |                     |--------------\n          |                     v\n          |               [CombineInput]\n          |                     |\n          |                     v\n          |            Dispatcher.combine -> DeepEP / bypass\n          |                     |\n          |---------------------\n          v\n[final_hidden_states]\n```\n\nFor details, see the [MoE Refactor Roadmap](https://github.com/sgl-project/sglang/issues/8715).\n\n### Implementing New Backends\n\nTo add a new backend:\n1. For a new all-to-all dispatcher, implement a `BaseDispatcher` subclass with `dispatch` and `combine` methods.\n2. For a new MoE runner backend, define a `MoeRunnerCore` subclass for core operations (e.g., grouped GEMMs).\n3. Define new input/output formats for the dispatcher or model runner (e.g., `RunnerInput`, `RunnerOutput`).\n4. 
Register permute/unpermute methods to ensure compatibility:\n   - **Fused Mode** (static, torch.compile-compatible): Use `register_fused_func` for end-to-end operations.\n   - **Permute Mode** (dynamic): Use `register_pre_permute` and `register_post_permute` for flexible layouts.\n\nSee the [MoE Refactor Implementation PR](https://github.com/sgl-project/sglang/pull/9269) for the full changes, including type hints and config expansions.\n\n### Examples\n\nFor an example implementation, see [moe_runner/triton.py](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/moe_runner/triton.py), which demonstrates Triton-based grouped GEMMs with registered fused and permutation functions.\n\n## Computation and Communication Overlap\n\nSGLang's EP uses overlap techniques to hide communication latency behind computation, maximizing GPU utilization in MoE layers.\n\n### Two-Batch Overlap (TBO)\n\nTBO splits requests into micro-batches and interleaves attention computation with dispatch/combine operations. Yield points in the execution graph pause the current micro-batch so that its communication overlaps with the other micro-batch's computation, increasing overall throughput without peak memory spikes:\n\n```python\noperations = [\n    self._forward_attn,\n    YieldOperation(),  # Overlap with dispatch of prior micro-batch\n    self._forward_dispatch,\n    self._forward_mlp,\n    YieldOperation(),  # Overlap with combine\n    self._forward_combine,\n]\n```\n\nSpecify `--enable-two-batch-overlap` to enable TBO, which can improve throughput by up to 2x. For details, see the [Large-Scale EP Blog](https://lmsys.org/blog/2025-05-05-large-scale-ep/#two-batch-overlap).\n\n### Single-Batch Overlap (SBO)\n\nSGLang provides a dispatcher-hook system for Single-Batch Overlap (SBO), which overlaps operations within a single batch (such as shared-expert computation with communication) while decentralizing logic to improve modularity. 
These hooks execute before and after the `dispatch` and `combine` operations without modifying core MoE modules. This design simplifies interfaces, reduces coupling, and improves extensibility. For implementation details and an example of overlapping shared experts with DeepEP's combine operation, refer to [PR #13327](https://github.com/sgl-project/sglang/pull/13327). Set `--enable-single-batch-overlap` to enable this feature.\n\n## Workload Balancer\n\nSGLang integrates the [Expert Parallelism Load Balancer (EPLB)](https://github.com/deepseek-ai/EPLB) from DeepSeek to address routing imbalances in MoE models. By analyzing expert activation statistics, EPLB computes an optimized expert arrangement, placing or replicating experts to minimize variance in GPU utilization, reduce idle cycles, and improve scalability.\n\nTo enable EPLB, use the `--enable-eplb` flag. For optimal performance, use larger batch sizes to stabilize activation statistics and configure periodic rebalancing (e.g., every 1000 requests) to adapt to evolving workloads. Simulations show significant improvements in load balancedness (the ratio of mean to max computation time), which correlates strongly with throughput gains.\n\nFor more details, refer to the [EPLB Section in the Large-Scale EP Blog](https://lmsys.org/blog/2025-05-05-large-scale-ep/#expert-parallelism-load-balancer) and the [EPLB Repository](https://github.com/deepseek-ai/eplb).\n\n## EP with Speculative Decoding\n\nWhen using speculative decoding with MTP on MoE architectures, use the `--speculative-moe-runner-backend` and `--speculative-moe-a2a-backend` arguments to customize the MoE layer behavior of the draft model. They default to the target model’s settings, but can be set separately when the target and draft models use different precisions.\n\nFor models like `nvidia/DeepSeek-R1-0528-NVFP4-v2`, the target model uses NVFP4 precision while the draft model uses BF16. 
To use the `flashinfer_trtllm` kernel for the target model's MoE layers while falling back to the Triton fused MoE kernel for the draft model's MoE layers, set the arguments as follows:\n```bash\n...\n--moe-runner-backend flashinfer_trtllm \\\n--speculative-moe-runner-backend triton \\\n...\n```\n\n## Ascend NPU Guidance\n\n### Guidance on SGLang Configuration on Ascend NPU\n- `--moe-a2a-backend` supports only the `deepep` and `ascend_fuseep` backends:\n  - `deepep`: Works the same way as described above.\n  - `ascend_fuseep`: Provides a large fused operator that integrates all operations between dispatch and combine to accelerate MoE computation. It is used only for the decode stage in PD disaggregation mode.\n- `--moe-runner-backend` does not need to be configured.\n- `--deepep-mode`:\n  - In PD mixed mode, set `--deepep-mode auto`.\n  - In PD disaggregation mode, set `--deepep-mode normal` on prefill instances and `--deepep-mode low_latency` on decode instances.\n\n### DeepEP Ascend Introduction\n\nDeepEP Ascend is the adapted version of the DeepEP communication library for Huawei Ascend NPUs, designed specifically for Expert Parallelism (EP) in Mixture-of-Experts (MoE) models.\nIt supports the ant-moving function, which splits the sequence into rounds for streaming batch transmission, to reduce the buffer size occupied during collective communication in the prefill stage, especially for long sequences.\n\nThe ant-moving function can be enabled for both the dispatch and combine phases via the following environment variables:\n- `DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS`: Enables the ant-moving function in the dispatch stage. Indicates the number of tokens transmitted per round on each rank, default 8192.\n- `DEEPEP_NORMAL_LONG_SEQ_ROUND`: Enables the ant-moving function in the dispatch stage. 
Indicates the number of rounds transmitted on each rank, default 1.\n- `DEEPEP_NORMAL_COMBINE_ENABLE_LONG_SEQ`: Enables the ant-moving function in the combine stage, default 0 (disabled).\n\n`DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS * DEEPEP_NORMAL_LONG_SEQ_ROUND` corresponds to the input sequence length. When the input sequence length exceeds 8192, it is recommended to enable the ant-moving function in both the dispatch and combine phases.\n\nThe environment variable `HCCL_BUFFSIZE` configures the buffer size (in MB) that is actually allocated. It should satisfy:\n```text\n# Ant-moving function enabled\nHCCL_BUFFSIZE >= 2 * (102MB + 4MB + DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS * (hidden_size + hidden_size + hidden_size) * topk) + PADDING_BUFFSIZE\n\n# Ant-moving function disabled\nHCCL_BUFFSIZE >= 2 * (102MB + 4MB + TOTAL_SEQ_LEN * (hidden_size + hidden_size) * topk) + PADDING_BUFFSIZE\n```\nwhere:\n- `hidden_size`: the hidden size in the model config.\n- `topk`: the number of selected routed experts.\n- `TOTAL_SEQ_LEN`: the input sequence length.\n- `PADDING_BUFFSIZE`: a value of 20 or greater is recommended.\n"
  },
  {
    "path": "docs/advanced_features/forward_hooks.md",
    "content": "## Model Hooks\n\nSGLang supports attaching PyTorch forward hooks to specific submodules in the loaded model, configured entirely via `server_args` JSON.\n\nThis is useful for:\n\n* Logging intermediate activations\n* Debugging model internals\n* Exporting hidden states to external tooling\n\nHooks are attached once during `ModelRunner.initialize` and run on every forward pass.\n\n---\n\n### Configuration overview\n\nHooks are configured via a `ServerArgs` field:\n\n```python\nclass ServerArgs:\n    ...\n    # For forward hooks\n    forward_hooks: Optional[List[dict[str, Any]]] = None\n```\n\nIn JSON form, a minimal configuration looks like:\n\n```jsonc\n{\n  \"forward_hooks\": [\n    {\n      \"name\": \"outer_linear_hooks\",\n      \"target_modules\": [\"outer.0\", \"outer.1\"],\n      \"hook_factory\": \"my_project.hooks:dummy_hook_factory\",\n      \"config\": {\n        \"tag\": \"outer-layer\"\n      }\n    }\n  ]\n}\n```\n\n#### Top-level fields\n\n* `forward_hooks` (optional list of objects)\n  Each element is a hook spec describing:\n\n  * Which modules to target\n  * Which Python factory to call\n  * What configuration to pass into that factory\n\n---\n\n### Hook spec schema\n\nEach entry in `forward_hooks` is a JSON object with the following shape:\n\n```jsonc\n{\n  \"name\": \"optional-descriptive-name\",\n  \"target_modules\": [\"pattern1\", \"pattern2\", \"...\"],\n  \"hook_factory\": \"module.submodule:factory_name\",\n  \"config\": {\n    \"...\": \"arbitrary JSON\"\n  }\n}\n```\n\n#### `name` (optional)\n\n* Human-readable name for logging.\n* Used only in log messages such as:\n\n  ```text\n  Registered forward hook 'outer_linear_hooks' on outer.0\n  ```\n\n#### `target_modules` (required)\n\n* List of **module name patterns** used to match entries in `model.named_modules()`.\n* Patterns are matched using `fnmatch.fnmatch`, so:\n\n  * `\"outer.0\"` matches exactly `\"outer.0\"`.\n  * `\"outer.*\"` matches `\"outer.0\"`, 
`\"outer.1\"`, `\"outer.inner\"`, etc.\n  * `\"outer.inner.*\"` matches all descendants of `outer.inner` (in `fnmatch`, `*` also matches `.`).\n\n> If no modules match the given patterns, hook registration does **not** fail.\n> Instead, SGLang logs a warning and continues:\n>\n> ```text\n> No modules matched hook spec 'name' patterns=['...']\n> ```\n\n#### `hook_factory` (required)\n\n* String path to the Python factory function that creates the hook.\n* Supported formats:\n\n  * `\"package.module:factory_name\"`\n  * `\"package.module.submodule.factory_name\"`\n\nThe path is resolved via:\n\n```python\nimport importlib\nfrom typing import Callable, Optional\n\n\ndef resolve_callable(path: Optional[str]) -> Optional[Callable]:\n    if path is None:\n        return None\n\n    if \":\" in path:\n        module_name, fn_name = path.split(\":\", 1)\n    else:\n        parts = path.split(\".\")\n        if len(parts) < 2:\n            raise ValueError(\n                f\"Invalid hook callable path '{path}'. \"\n                \"Expected 'module.submodule:factory' or 'module.submodule.factory'.\"\n            )\n        *mod_parts, fn_name = parts\n        module_name = \".\".join(mod_parts)\n\n    module = importlib.import_module(module_name)\n    try:\n        return getattr(module, fn_name)\n    except AttributeError as e:\n        raise AttributeError(\n            f\"Module '{module_name}' has no attribute '{fn_name}' \"\n            f\"(from hook path '{path}')\"\n        ) from e\n```\n\n**Failure modes**:\n\n* If the path is malformed (not enough dots and no `:`), a `ValueError` is raised at startup.\n* If the module imports but the attribute is missing, an `AttributeError` is raised with a clear error message.\n* If the hook factory returns `None`, a warning is logged and no hook is registered for that spec (initialization continues).\n\nThe first two cause initialization to fail fast with a descriptive error; the last one is non-fatal.\n\n#### `config` (optional)\n\n* Arbitrary JSON object.\n* Passed directly to the hook factory as a Python `dict`.\n* This 
lets you parameterize hook behavior from config (e.g., tags, log levels, or sampling rates).\n\n---\n\n### Hook lifecycle and behavior\n\nHooks are registered in `ModelRunner.initialize()`:\n\n```python\nif server_args.forward_hooks:\n    register_forward_hooks(self.model, server_args.forward_hooks)\n```\n\nThe actual registration logic is implemented by `register_forward_hooks`:\n\n```python\nimport fnmatch\nimport logging\nfrom typing import Any, List\n\nimport torch.nn as nn\n\nlogger = logging.getLogger(__name__)\n\n\ndef register_forward_hooks(model: nn.Module, hook_specs: List[dict[str, Any]]) -> None:\n    \"\"\"\n    hook_specs is a list of dicts from server_args.forward_hooks.\n    Attaches forward hooks to the matching modules.\n    \"\"\"\n    name_to_module = dict(model.named_modules())\n\n    for spec in hook_specs:\n        spec_name = spec.get(\"name\", \"\")\n        target_patterns = spec.get(\"target_modules\", [])\n        if not target_patterns:\n            logger.warning(\n                f\"Hook spec '{spec_name}' has no 'target_modules', skipping\"\n            )\n            continue\n\n        hook_factory_path = spec.get(\"hook_factory\")\n        if not hook_factory_path:\n            logger.warning(\n                f\"Hook spec '{spec_name}' has no 'hook_factory', skipping\"\n            )\n            continue\n\n        config = spec.get(\"config\") or {}\n        hook_factory = resolve_callable(hook_factory_path)\n\n        hook = hook_factory(config) if hook_factory else None\n        if hook is None:\n            logger.warning(\n                f\"Hook factory '{hook_factory_path}' for spec '{spec_name}' \"\n                \"returned None, not registering any hook\"\n            )\n            continue\n\n        # Resolve patterns like \"model.layers.*.mlp\"\n        matched = []\n        for name, module in name_to_module.items():\n            if any(fnmatch.fnmatch(name, pattern) for pattern in target_patterns):\n                matched.append((name, module))\n\n        if not matched:\n            logger.warning(\n                f\"No modules matched 
hook spec '{spec_name}' \"\n                f\"patterns={target_patterns}\"\n            )\n            continue\n\n        for module_name, module in matched:\n            if hook:\n                _ = module.register_forward_hook(hook)\n                logger.info(\n                    f\"Registered forward hook '{spec_name}' \"\n                    f\"on {module_name}\"\n                )\n```\n\nKey points:\n\n* Hooks are **forward hooks only** (via `module.register_forward_hook`).\n* They are attached once at initialization.\n* Hook handles are currently not stored on `ModelRunner` (they cannot be removed later via this API).\n* Failure to match any modules is non-fatal; a warning is logged instead.\n* If a hook factory returns `None`, a warning is logged and that spec is skipped.\n\n---\n\n### Writing a hook factory\n\nA hook factory is a regular Python function:\n\n* Takes a `config: dict` (from JSON)\n* Returns a forward hook function with signature `(module, inputs, output)`\n\nExample:\n\n```python\nHOOK_CALLS = []\n\ndef dummy_hook_factory(config):\n    \"\"\"Factory that returns a forward hook capturing a tag from config.\"\"\"\n    tag = config.get(\"tag\", \"default\")\n\n    def hook(module, inputs, output):\n        HOOK_CALLS.append(\n            {\n                \"module_type\": type(module).__name__,\n                \"tag\": tag,\n                \"shape\": tuple(output.shape),\n            }\n        )\n        # Returning None would also leave the module output unchanged;\n        # any other non-None return value replaces the module's output.\n        return output\n\n    return hook\n```\n\nIn JSON:\n\n```jsonc\n{\n  \"forward_hooks\": [\n    {\n      \"name\": \"capture_outer\",\n      \"target_modules\": [\"outer.0\", \"outer.1\"],\n      \"hook_factory\": \"my_project.hooks:dummy_hook_factory\",\n      \"config\": {\n        \"tag\": \"outer\"\n      }\n    }\n  ]\n}\n```\n\nThis will:\n\n* Resolve `my_project.hooks:dummy_hook_factory` to a Python callable.\n* Call it with `config = {\"tag\": \"outer\"}`.\n* Use the 
returned hook for all modules matching `outer.0` and `outer.1`.\n* Append metadata about each call to `HOOK_CALLS`.\n\n---\n\n### Summary\n\n* Define `forward_hooks` as a list of specs in `ServerArgs` to turn on the feature.\n\n* Each spec:\n\n  * selects modules via `target_modules` (glob patterns over `model.named_modules()`),\n  * points to a hook factory via `hook_factory`,\n  * passes arbitrary `config` into that factory.\n\n* Hook factories are resolved via `resolve_callable`, which supports `module:factory` and `module.submodule.factory`.\n\n* Hooks are standard PyTorch forward hooks, attached once at startup and invoked on every forward pass.\n\n* Misconfiguration is either:\n\n  * **fatal and explicit** (bad path / missing attribute), or\n  * **non-fatal with clear warnings** (no targets matched, or factory returned `None`).\n"
  },
  {
    "path": "docs/advanced_features/hicache.rst",
    "content": "Hierarchical KV Caching (HiCache)\n=================================\n\n.. toctree::\n   :maxdepth: 1\n\n   hicache_best_practices.md\n   hicache_design.md\n   hicache_storage_runtime_attach_detach.md\n"
  },
  {
    "path": "docs/advanced_features/hicache_best_practices.md",
    "content": "# SGLang HiCache Best Practices\n\n## Why HiCache Matters\n\nSGLang HiCache extends the traditional RadixAttention with a three-tier hierarchical KV caching system that dramatically improves performance for long-context and multi-turn conversation scenarios. By intelligently managing KV caches across GPU memory, host memory, and external storage backends, HiCache addresses the fundamental capacity bottleneck that limits cache hit rates in conventional systems.\n\n## Configuration Guidelines\n\n## Core HiCache Parameters\n\n```bash\n# Essential HiCache flags\n--page-size 64                        # Page size for cache management\n--enable-hierarchical-cache           # Enable HiCache\n--hicache-ratio 2                     # Host KV cache size as a ratio of the GPU KV cache size (2x)\n--hicache-size 100                    # Host memory size in GB; overrides --hicache-ratio\n--hicache-io-backend kernel           # I/O backend for moving data between CPU and GPU\n--hicache-write-policy write_through  # Cache write policy from GPU to CPU\n--hicache-storage-backend             # Optional storage backend (e.g., hf3fs, mooncake)\n```\n\nNotes:\n\n- Besides configuring `--hicache-storage-backend` at startup, SGLang also supports **runtime attach/detach** of the HiCache storage backend (no restart required) via HTTP admin endpoints. 
See [Runtime Attach/Detach HiCache Storage Backend](hicache_storage_runtime_attach_detach.md).\n\n## Key Configurations with Storage Backends Enabled\n\n### Memory Layout Optimization\n\n```bash\n# Page-first: Optimized for I/O efficiency with zero-copy (recommended with kernel backend)\n--hicache-mem-layout page_first\n# Page-first-direct: Optimized for direct I/O operations (Compatible with fa3 and same zero-copy performance as page_first)\n--hicache-mem-layout page_first_direct\n# Layer-first\n--hicache-mem-layout layer_first\n```\n**Layout Compatibility:**\n- `page_first`: Only compatible with `kernel` I/O backend, automatically switches to `layer_first` with `direct` backend\n- `page_first_direct`: Specifically designed for `direct` I/O backend with optimized memory organization\n\n### Heterogeneous TP Support (GQA/MHA models)\n\nHiCache storage supports cross-cluster KV reuse when different deployments use different TP sizes (for example, `tp=4` and `tp=8`) and share the same storage backend namespace.\n\nUse `tp_lcm_size` in `--hicache-storage-backend-extra-config`:\n\n```bash\n# Example: heterogeneous TP = {4, 8}, so lcm = 8\n--hicache-storage-backend-extra-config '{\"tp_lcm_size\": 8}'\n```\n\nGuidelines:\n\n- Set `tp_lcm_size` to the least common multiple (LCM) of all TP sizes that will share the same HiCache storage.\n- For MHA models with Mooncake and `page_head` layout, HiCache will split head shards based on `tp_lcm_size` to make keys reusable across heterogeneous TP deployments.\n- If all clusters use the same TP size, this option is not needed.\n\n### Prefetch Policies\n\n```bash\n# Best-effort: Terminate prefetch when needed\n--hicache-storage-prefetch-policy best_effort\n# Wait-complete: Ensure complete prefetch, higher cache reuse\n--hicache-storage-prefetch-policy wait_complete\n# Timeout: Balance between completion and best-effort\n--hicache-storage-prefetch-policy timeout\n```\n\n### Integration with PD Disaggregation\n\nHiCache works 
seamlessly with PD Disaggregation. You can choose between two configurations:\n\n1. **Prefill-only HiCache**: Enable HiCache only on Prefill nodes, allowing KV cache sharing among Prefill instances.\n2. **Full HiCache with async offloading**: Enable HiCache on Prefill nodes and async KV cache offloading on Decode nodes, allowing Prefill nodes to reuse KV caches from Decode nodes in multi-turn dialogue scenarios.\n\n```bash\n# Prefill node with HiCache enabled for cross-prefill sharing (ideal for system-prompt scenarios)\npython3 -m sglang.launch_server \\\n  --model-path /xxx/DeepSeek-R1/ \\\n  --tp 8 \\\n  --host 0.0.0.0 \\\n  --port 10000 \\\n  --enable-metrics \\\n  --enable-cache-report \\\n  --mem-fraction-static 0.85 \\\n  --page-size 64 \\\n  --enable-hierarchical-cache \\\n  --hicache-ratio 2 \\\n  --hicache-size 0 \\\n  --hicache-mem-layout page_first_direct \\\n  --hicache-io-backend direct \\\n  --hicache-write-policy write_through \\\n  --hicache-storage-backend hf3fs \\\n  --hicache-storage-prefetch-policy wait_complete \\\n  --disaggregation-ib-device mlx5_0 \\\n  --disaggregation-mode prefill \\\n  --disaggregation-transfer-backend mooncake\n\n# Decode node with async offloading enabled for KV cache reuse by Prefill (ideal for multi-turn conversations).\n# --disaggregation-decode-enable-offload-kvcache enables async KV cache offloading on the decode node.\npython3 -m sglang.launch_server \\\n  --model-path /xxx/DeepSeek-R1/ \\\n  --tp 8 \\\n  --host 0.0.0.0 \\\n  --port 10000 \\\n  --enable-metrics \\\n  --enable-cache-report \\\n  --page-size 64 \\\n  --hicache-ratio 2 \\\n  --hicache-size 0 \\\n  --hicache-mem-layout page_first_direct \\\n  --hicache-io-backend direct \\\n  --hicache-write-policy write_through \\\n  --hicache-storage-backend hf3fs \\\n  --hicache-storage-prefetch-policy wait_complete \\\n  --disaggregation-decode-enable-offload-kvcache \\\n  --disaggregation-ib-device mlx5_0 \\\n  --disaggregation-mode decode \\\n  --disaggregation-transfer-backend mooncake\n```\n\n### Deployment with 
HF3FS\n\nHere is an example of deploying DeepSeek-R1 with HiCache-HF3FS. For more details, see the [HF3FS Documentation](../../python/sglang/srt/mem_cache/storage/hf3fs/docs/README.md).\n\n```bash\npython3 -m sglang.launch_server \\\n  --model-path /xxx/DeepSeek-R1/ \\\n  --log-level info \\\n  --tp 8 \\\n  --host 0.0.0.0 \\\n  --port 10000 \\\n  --enable-metrics \\\n  --enable-cache-report \\\n  --page-size 64 \\\n  --mem-fraction-static 0.85 \\\n  --enable-hierarchical-cache \\\n  --hicache-ratio 2 \\\n  --hicache-size 0 \\\n  --hicache-mem-layout page_first_direct \\\n  --hicache-io-backend direct \\\n  --hicache-write-policy write_through \\\n  --hicache-storage-backend hf3fs \\\n  --hicache-storage-prefetch-policy wait_complete\n```\n\n### Deployment with Mooncake\n\nHere is an example of deploying Qwen3-235B-A22B-Instruct-2507 with Mooncake. For more details, see the [Mooncake Documentation](../../python/sglang/srt/mem_cache/storage/mooncake_store/README.md).\n\n```bash\n# Set Mooncake environment variables\nexport MOONCAKE_TE_META_DATA_SERVER=\"http://127.0.0.1:8080/metadata\"\nexport MOONCAKE_GLOBAL_SEGMENT_SIZE=816043786240\nexport MOONCAKE_PROTOCOL=\"rdma\"\nexport MOONCAKE_DEVICE=\"$DEVICE_LIST\"\nexport MOONCAKE_MASTER=127.0.0.1:50051\n\n# Launch SGLang server with Mooncake backend\npython3 -m sglang.launch_server \\\n  --model-path $MODEL_PATH \\\n  --tp 8 \\\n  --page-size 64 \\\n  --enable-hierarchical-cache \\\n  --hicache-ratio 2 \\\n  --hicache-mem-layout page_first_direct \\\n  --hicache-io-backend direct \\\n  --hicache-storage-backend mooncake \\\n  --hicache-write-policy write_through \\\n  --hicache-storage-prefetch-policy timeout\n```\n\n## Custom Storage Backend Integration\n\nTo integrate a new storage backend:\n\n1. **Implement three core methods:**\n   - `get(key)`: Retrieve value by key\n   - `exists(key)`: Check key existence\n   - `set(key, value)`: Store key-value pair\n\n2. 
**Register your backend:** Add your storage backend to the HiCache [BackendFactory](../../python/sglang/srt/mem_cache/storage/backend_factory.py#L188).\n\nThe HiCache controller handles all scheduling and synchronization automatically.\n\n### Dynamic Backend Loading\n\nAlternatively, you can use dynamic loading to avoid hard-coding your backend in the repository:\n\n```bash\npython3 -m sglang.launch_server \\\n  --model-path your-model \\\n  --enable-hierarchical-cache \\\n  --hicache-storage-backend dynamic \\\n  --hicache-storage-backend-extra-config '{\"backend_name\":\"custom_backend_name\", \"module_path\": \"your_module_path\", \"class_name\": \"YourHiCacheClassName\"}'\n```\n\n**Configuration Parameters:**\n- `--hicache-storage-backend`: Set to `dynamic`\n- `--hicache-storage-backend-extra-config`: JSON configuration with:\n  - `backend_name`: Custom backend identifier\n  - `module_path`: Python module path to your implementation\n  - `class_name`: Your HiCache implementation class name\n  - `interface_v1`: `0` (disable) or `1` (enable); controls whether the `batch_get_v1` and `batch_set_v1` methods are used\n\n## Community and Support\n\n- **GitHub Issues**: Report bugs and feature requests\n- **Slack Channel**: Join community discussions in #sgl-kv-cache-store\n- **Documentation**: Refer to storage backend-specific guides\n\n---\n\n*This document will be continuously updated based on community feedback and new features. Contributions and suggestions are welcome!*\n"
  },
  {
    "path": "docs/advanced_features/hicache_design.md",
    "content": "# HiCache System Design and Optimization\n\nThis document provides a comprehensive overview of SGLang HiCache, covering its system architecture, workflow and key components. It also details configuration parameters, optimization techniques, and integration with various L3 storage backends, serving as a complete reference for users and developers to understand and tune HiCache for efficient LLM inference.\n\n## Why and What is HiCache?\n\nIn large language model inference, the prefill phase is often time-consuming: input sequences need to be first converted into Key-Value cache (KV cache) for subsequent decoding. When multiple requests share the same prefix, the KV cache for that prefix is identical. By caching and reusing these shared KV caches, redundant computation can be avoided. To address this, SGLang introduced RadixAttention, which leverages idle GPU memory to cache and reuse prefix KV caches, and **HiCache**, which extends this idea to host memory and distributed storage.\n\nInspired by the classic three-level cache design of modern CPUs, HiCache organizes GPU memory as L1, host memory as L2, and distributed storage as L3. This hierarchy enables HiCache to fully exploit the \"idle\" storage space of GPUs and CPUs, while integrating distributed cache systems such as Mooncake, 3FS, NIXL, and AIBrix KVCache for global KV cache storage and scheduling. As a result, HiCache significantly expands KV cache capacity while maintaining strong read performance—especially in workloads such as multi-QA and long-context inference, where KV cache reuse is frequent. For detailed benchmark results, see [this blog](https://lmsys.org/blog/2025-09-10-sglang-hicache/).\n\n\n## System Design\n\n### Overall Architecture\n\nIn many modern CPU architectures, the small but fast L1 and L2 caches are private to each core, enabling rapid access to the hottest data, while the larger L3 cache is shared across all cores to significantly reduce redundancy within the cache. 
Similarly, in HiCache, the L1 and L2 KV caches are private to each inference instance, whereas the L3 KV cache is shared among all inference instances within the cluster.\n\n### HiRadixTree: Metadata Organization in HiCache\n\nFor KV cache data organization, HiCache builds upon the RadixTree structure introduced in RadixAttention and proposes HiRadixTree. In RadixAttention, each node of the RadixTree corresponds to the KV cache of a consecutive span of tokens in GPU memory. A path from the root to a leaf node represents the prefix of a request, and shared prefixes across multiple requests can reuse the same nodes, thereby avoiding redundant storage.\n\nHiRadixTree extends this idea: each node corresponds to the KV cache of a span of consecutive tokens and records where that KV cache is stored—whether in local GPU memory, CPU memory, L3 storage, or multiple of these tiers. If stored locally, HiRadixTree maintains precise metadata, including the exact storage address. However, to reduce overhead, HiRadixTree does not store or continuously synchronize metadata for L3 KV cache. Instead, when accessing L3 data, it queries the backend in real time to retrieve the necessary metadata, such as whether the data exists and on which server and location it resides.\n\n### Overall Workflow\n\nThe workflow of HiCache mainly involves three key operations: **local match**, **prefetch** and **write-back**. When the system receives a new request, it first searches the local L1 and L2 caches for matching KV caches. For parts not found locally, it attempts to prefetch from L3. After prefetching, all required KV caches are loaded into the GPU for computation. 
Once the prefill computation is complete, the system considers storing the newly generated data in L2 or L3.\n\n![HiCache Workflow](https://lmsys.org/images/blog/hicache/hicache_overview.png)\n\n### Local Match\n\nLocal matching is the first step in HiCache's workflow, where incoming request tokens are matched against the HiRadixTree to locate cached KV data in local memory tiers (L1 GPU memory and L2 host memory).\n\nThe matching algorithm traverses the HiRadixTree from the root node, following child nodes that match the token sequence prefix. At each node, the incoming token sequence is compared with the node’s stored token sequence. When `page_size > 1`, matching is performed at page granularity to optimize memory access patterns. If a match terminates within a node’s stored sequence, the node is automatically split to create an exact boundary, improving the efficiency of future matches.\n\nThe algorithm returns a contiguous prefix of the request, with the first part residing in L1 and the latter part in L2.\n\nSince the process only requires traversing the local HiRadixTree and does not involve any actual data copying, local matching is extremely fast.\n\n### Prefetch from L3\n\nData prefetching is one of HiCache’s core optimization techniques, designed to proactively load KV caches from L3 storage into local L2 memory, thereby reducing access latency during subsequent operations.\n\n**Prefetch Trigger Conditions**:\nAfter local matching, for the parts not found in L1 or L2, the system queries L3 to retrieve metadata for the next contiguous run of matching KV caches. 
If the number of tokens hit in L3 exceeds a threshold (default: 256 tokens, configurable), a prefetch operation is triggered.\n\n**Prefetch Strategies**: HiCache provides three prefetch termination strategies for different scenarios:\n- **best_effort**: Terminates as soon as the GPU can execute the prefill computation, with no waiting time; suitable for extremely latency-sensitive scenarios.\n- **wait_complete**: Waits for all prefetch operations to complete; suitable for scenarios requiring high cache hit rates.\n- **timeout**: Terminates after a specified time or when prefetching completes, whichever comes first, balancing latency and cache hit rate.\n\nAfter prefetching stops, the data already fetched is used together with the local data for the prefill computation.\n\nFor the **timeout** strategy, HiCache provides two configuration parameters for fine-grained control over the prefetch timeout:\n\n* `prefetch_timeout_base`: the base timeout, representing overhead unrelated to the number of tokens (e.g., scheduling and synchronization).\n* `prefetch_timeout_per_ki_token`: the incremental timeout per 1024 tokens.\n\nThe timeout is computed as:\n\n```\ntimeout = prefetch_timeout_base + prefetch_timeout_per_ki_token * num_token_to_fetch / 1024\n```\n\n### Data Write-back\n\nThe write-back mechanism is responsible for moving frequently accessed KV caches from L1 to L2 and L3, enabling larger and longer-term storage as well as cache sharing across instances.\n\n**Configurable Write-back Policies**: HiCache supports three write-back strategies:\n\n* **write_through**: Every access is immediately written back to the next level. When bandwidth is sufficient, this strategy provides the strongest caching benefit.\n* **write_through_selective**: Data is written back only after the access frequency exceeds a threshold. 
This strategy backs up only hot data, reducing I/O overhead.\n* **write_back**: Data is written back to the next level only when it is evicted from the upper level. This strategy alleviates storage pressure and is suitable for scenarios where storage capacity is limited but memory utilization must be maximized.\n\n**Cross-instance Sharing**: When data is written back from L2 to L3, only data not already present in L3 is transferred. KV caches stored in L3 can then be shared across all SGLang instances in the cluster (depending on the L3 backend implementation), significantly improving cache hit rates within the same memory budget.\n\n### Multi-Rank Synchronization\n\nDuring multi-GPU parallel computation, such as tensor parallelism (TP), HiCache must ensure consistent states across different ranks. Therefore, critical computation steps require the use of `all_reduce` for state synchronization.\n\nFor example, during prefetching, `all_reduce(op=min)` is used to ensure that all ranks obtain the same number of L3 hits, preventing inconsistent judgments about whether the prefetch threshold has been reached. Similarly, after prefetching completes or terminates, `all_reduce(op=min)` is again required to guarantee consensus among ranks on the prefix length of the successfully retrieved KV cache.\n\n### Data Transfer Optimization\n\n**Zero-Copy Data Transfers**: Both prefetching and write-back involve substantial data movement. Minimizing the number of data copies can significantly improve system performance. HiCache supports passing memory addresses and sizes directly when transferring data from L2 memory to an L3 backend.\n\n**“Batch-Oriented” Data Organization**: The granularity of data reads and writes has a major impact on performance. To address this, HiCache L3 stores and transfers KV cache data at the granularity of **pages** and supports different data layouts beyond the existing `layer first` scheme, including `page first` and `page first direct`. 
Under the `page first` and `page first direct` layouts, all KV cache data belonging to the same page is placed in contiguous memory, allowing it to be passed as a single object to L3 using zero-copy transfers.\n\n![HiCache L2 MEM layout](https://lmsys.org/images/blog/hicache/hicache_layout.png)\n\nHowever, because GPU KV computation is naturally performed layer by layer, the GPU inherently operates in a `layer first` layout. When transferring `page first` data from L2 to the GPU, data must be transferred at the granularity of one token per layer. The `page first direct` layout mitigates this issue by grouping together all tokens of a given layer within a page, allowing transfers from L2 to GPU to be aggregated at the page-layer level.\n\n**CPU-to-GPU Transfer Optimizations**: In HiCache, moving data from CPU memory to GPU is as performance-critical as prefetching data from L3 to L2. HiCache employs several optimizations for this process:\n\n* **Compute-Transfer Overlap**: During the prefill phase, HiCache overlaps CPU-to-GPU transfers with computation layer by layer, loading the KV cache of layer N+1 while computing layer N. This effectively hides data transfer latency.\n* **GPU-assisted I/O Kernels**: In addition to `cudaMemcpyAsync`, HiCache implements a set of GPU-assisted I/O kernels specifically optimized for KV cache transfers between CPU and GPU. Compared to the baseline approach, these kernels achieve up to 3x higher transfer speed.\n\n**Write-back Optimization for MLA**: For MHA (Multi-Head Attention) models with tensor parallelism, each rank holds `1/tp_size` of a token’s KV data. In contrast, for MLA (Multi-head Latent Attention) models, all ranks hold the complete and identical KV data for each token. 
HiCache includes a dedicated optimization for MLA: only one rank initiates the write-back operation, ensuring that data is not redundantly stored across ranks.\n\n### Integration with PD-Disaggregation Deployment Mode\n\nSGLang supports a PD (Prefill-Decode) disaggregation deployment mode through the Mooncake TransferEngine (for details, see [this doc](https://docs.sglang.io/advanced_features/pd_disaggregation.html)). In the PD-disaggregation deployment mode, HiCache can be enabled on both prefill and decode nodes to optimize prefill performance. If enabled on decode nodes, the decode output is also written back to L3.\n\n### Unified Interfaces and Rich L3 Storage Backends\n\nHiCache encapsulates all read, write, and query operations on L3 backends within the abstract base class `HiCacheStorage`, exposing a set of simple and consistent interfaces. This design supports a wide range of L3 storage backends and allows users to select the one that best fits their specific use cases.\n\n- **Mooncake**: Mooncake is a high-performance caching system for LLM inference that leverages RDMA and multi-NIC resources to enable zero-copy, ultra-fast data transfers. Try Mooncake [here](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/mem_cache/storage/mooncake_store).\n\n- **DeepSeek 3FS (HF3FS)**: HF3FS is a Kubernetes-native distributed storage solution with operator-based deployment. Try HF3FS [here](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/mem_cache/storage/hf3fs).\n\n- **NIXL**: NIXL provides a unified API for accessing various storage plugins, including but not limited to DeepSeek's 3FS, GPU Direct Storage (GDS), and Amazon S3-compatible object storage. Try NIXL [here](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/mem_cache/storage/nixl).\n\n- **AIBrix KVCache**: AIBrix KVCache is a production-ready KVCache Offloading Framework, which enables efficient memory tiering and low-overhead cross-engine reuse. 
Try AIBrix KVCache [here](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/mem_cache/storage/aibrix_kvcache).\n\n- **HiCacheFile**: A simple file-based storage backend for demonstration purposes.\n\nIn addition, **LMCache**, an efficient KV cache layer for enterprise-scale LLM inference, provides an alternative solution to HiCache. Try LMCache [here](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/mem_cache/storage/lmcache).\n\n## Related Parameters\n\n- **`--enable-hierarchical-cache`**: Enable hierarchical cache functionality. This is required to use HiCache.\n\n- **`--hicache-ratio HICACHE_RATIO`**: The ratio of the host KV cache memory pool size to the device pool size. For example, a value of 2 means the host memory pool is twice as large as the device memory pool. The value of this parameter must be greater than 1, as the current implementation requires the host memory allocated for the KV cache to be larger than the device memory allocated for the KV cache.\n\n- **`--hicache-size HICACHE_SIZE`**: The size of the host KV cache memory pool in gigabytes. If set, this parameter overrides `--hicache-ratio`. For example, `--hicache-size 30` allocates 30GB (1GB = 1e9 bytes) for the host memory pool **for each rank**. With 8 ranks, the total host memory pool size is 240GB. As with `--hicache-ratio`, the value of this parameter must be larger than the size of device memory allocated for the KV cache.\n\n**Note**: `--hicache-ratio` and `--hicache-size` are two critical parameters. In general, a larger HiCache size leads to a higher cache hit rate, which improves prefill performance. However, the relationship between cache size and hit rate is not linear. Once most reusable KV data, especially hot tokens, is already cached, further increasing the size may yield only marginal performance gains. 
Users can set these parameters based on their workload characteristics and performance requirements.\n\n- **`--page-size PAGE_SIZE`**: The number of tokens per page. This parameter determines the granularity of KV cache storage and retrieval. Larger page sizes reduce metadata overhead and improve I/O efficiency for storage backends, but may lower the cache hit rate when only part of a page matches the stored KV cache. For workloads with long common prefixes, larger pages can improve performance, while workloads with more diverse prefixes may benefit from smaller pages. See [Data Transfer Optimization](#data-transfer-optimization) for how page granularity affects I/O performance.\n\n- **`--hicache-storage-prefetch-policy {best_effort,wait_complete,timeout}`**: Controls when prefetching from storage should stop. See [Prefetch from L3](#prefetch-from-l3) for details.\n  - `best_effort`: Prefetch as much as possible without blocking\n  - `wait_complete`: Wait for prefetch to complete before proceeding\n  - `timeout`: Stop when prefetching completes or the timeout elapses (recommended for production environments, as an appropriate timeout helps the system meet its SLOs)\n\n- **`--hicache-write-policy {write_back,write_through,write_through_selective}`**: Controls how data is written from faster to slower memory tiers. See [Data Write-back](#data-write-back) for details.\n  - `write_through`: Immediately writes data to all tiers (strongest caching benefits)\n  - `write_through_selective`: Uses hit-count tracking to back up only frequently accessed data\n  - `write_back`: Writes data back to slower tiers only when eviction is needed (reduces I/O load)\n\n- **`--hicache-io-backend {direct,kernel}`**: Choose the I/O backend for KV cache transfer between CPU and GPU. 
See [Data Transfer Optimization](#data-transfer-optimization) for details.\n  - `direct`: Standard CUDA memory copy operations\n  - `kernel`: GPU-assisted I/O kernels (recommended for better performance)\n\n- **`--hicache-mem-layout {layer_first,page_first,page_first_direct}`**: Memory layout for the host memory pool. See [Data Transfer Optimization](#data-transfer-optimization) for details.\n  - `layer_first`: Compatible with GPU computation kernels (default for GPU memory)\n  - `page_first`: Optimized for I/O efficiency\n  - `page_first_direct`: Groups all tokens of a given layer within a page, allowing transfers from L2 to GPU to be aggregated at the page-layer level\n\n- **`--hicache-storage-backend {file,mooncake,hf3fs,nixl,aibrix,dynamic}`**: Choose the storage backend for the L3 tier. Built-in backends: file, mooncake, hf3fs, nixl, aibrix. For the `dynamic` backend, use `--hicache-storage-backend-extra-config` to specify `backend_name` (custom name), `module_path` (Python module path), and `class_name` (backend class name). See [Unified Interfaces and Rich L3 Storage Backends](#unified-interfaces-and-rich-l3-storage-backends) for available backends.\n\n- **`--enable-lmcache`**: Use LMCache as an alternative hierarchical cache solution.\n\n- **`--hicache-storage-backend-extra-config HICACHE_STORAGE_BACKEND_EXTRA_CONFIG`**: Extra configuration for the storage backend, given as either\n  - a JSON string, e.g., `--hicache-storage-backend-extra-config '{\"prefetch_threshold\":512, \"prefetch_timeout_base\": 0.5, \"prefetch_timeout_per_ki_token\": 0.25}'`, or\n  - a TOML, JSON, or YAML file, with `@` prepended to the file name to distinguish it from a JSON string, e.g., `--hicache-storage-backend-extra-config \"@config.toml\"`, where `config.toml` is the config file. A config file is useful when the configuration consists of many or complex key-value pairs (for instance, a config file is preferred for the NIXL backend because its configuration can be complex).\n"
  },
  {
    "path": "docs/advanced_features/hicache_storage_runtime_attach_detach.md",
    "content": "# Runtime Attach/Detach HiCache Storage Backend (No Restart)\n\nThis document explains how to **dynamically attach/detach the HiCache L3 storage backend at runtime** (e.g., `mooncake` / `hf3fs` / `nixl` / `file` / `aibrix` / `eic`) while **SGLang is already running and serving traffic**, without restarting the process.\n\nFor safety and consistency, the current implementation **strictly requires** these operations to happen only when the service is **idle**:\n\n- **No running requests**\n- **No waiting/queued requests**\n\nIf the idle condition is not met, the API will fail fast (HTTP 400) and **will not modify** the current service state.\n\n---\n\n## 1. Background and implementation overview\n\n### 1.1 Architecture / control path\n\nThe control path is:\n\n1. **HTTP Server** (`python/sglang/srt/entrypoints/http_server.py`)\n   - Exposes `PUT /hicache/storage-backend`, `DELETE /hicache/storage-backend`, `GET /hicache/storage-backend`\n2. **TokenizerManager** (`python/sglang/srt/managers/tokenizer_communicator_mixin.py`)\n   - Sends the request to the Scheduler via `_Communicator`\n3. **Scheduler** (`python/sglang/srt/managers/scheduler.py`)\n   - Performs a **strict idle check**\n   - Calls `tree_cache.attach_storage_backend(...)` / `detach_storage_backend(...)`\n4. **HiRadixCache** (`python/sglang/srt/mem_cache/hiradix_cache.py`)\n   - Parses `hicache_storage_backend_extra_config_json` (supports both backend config and prefetch knobs)\n   - Calls `cache_controller.attach_storage_backend(...)` / `detach_storage_backend(...)`\n5. **HiCacheController** (`python/sglang/srt/managers/cache_controller.py`)\n   - Creates/destroys the storage backend instance (via `StorageBackendFactory`)\n   - Starts/stops backend background threads at runtime (prefetch/backup)\n\n---\n\n## 2. 
Idle-state requirement (strict)\n\nThe Scheduler uses `is_fully_idle()` which checks:\n\n- No running batches (including chunked prefill, overlap, pipeline-parallel, and disaggregation paths)\n- No waiting requests in any queue (waiting, grammar, disagg bootstrap/prealloc/transfer/inflight)\n- No DLLM staging requests\n\nIf the condition is not met, attach/detach returns an error like:\n\n- `Reject attach: scheduler is not idle. #queue-req=... #running-req=...`\n\n> Tip: before switching, drain upstream traffic and wait for the server to become idle, then call attach/detach.\n\n### 2.1 DP (data parallel) semantics\n\nWhen `dp_size > 1`, the tokenizer dispatches the request to **all DP scheduler instances** and aggregates their responses:\n\n- The final `success` is **true only if all DP ranks return success**\n- The final `message` concatenates messages from all DP ranks\n\nThis is intended to prevent “silent partial success”, but it also means you may see:\n\n- Overall **failure** even though **some ranks already succeeded**\n\nCurrently there is **no automatic partial rollback** across DP ranks (see TODO in code). Operationally:\n\n- Prefer to keep backend config identical across ranks\n- If attach fails, immediately call detach (best-effort/idempotent), fix config, then retry attach\n\n---\n\n## 3. 
How to use (HTTP Admin API)\n\nThe examples below assume your SGLang HTTP server is at `http://127.0.0.1:30000`.\n\n### 3.1 Query current storage backend status\n\n```bash\ncurl -s http://127.0.0.1:30000/hicache/storage-backend\n```\n\nExample response:\n\n```json\n{\n  \"hicache_storage_backend\": \"mooncake\",\n  \"hicache_storage_backend_extra_config\": \"{\\\"master_server_address\\\":\\\"127.0.0.1:50051\\\", ...}\"\n}\n```\n\n### 3.2 Attach (enable) a storage backend\n\nMinimal form, specifying only the backend name:\n\n```bash\ncurl -s -X PUT http://127.0.0.1:30000/hicache/storage-backend \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n    \"hicache_storage_backend\": \"mooncake\"\n  }'\n```\n\nWith backend-specific extra configuration and a prefetch policy:\n\n```bash\ncurl -s -X PUT http://127.0.0.1:30000/hicache/storage-backend \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n    \"hicache_storage_backend\": \"mooncake\",\n    \"hicache_storage_backend_extra_config_json\": \"{\\\"master_server_address\\\":\\\"127.0.0.1:50051\\\",\\\"protocol\\\":\\\"tcp\\\",\\\"global_segment_size\\\":\\\"4gb\\\",\\\"prefetch_threshold\\\":256}\",\n    \"hicache_storage_prefetch_policy\": \"timeout\"\n  }'\n```\n\nNotes:\n\n- `hicache_storage_backend_extra_config_json` can include both:\n  - **Backend configuration** (e.g., Mooncake master/metadata/protocol)\n  - **Prefetch configuration** (`prefetch_threshold`, `prefetch_timeout_base`, `prefetch_timeout_per_ki_token`, `hicache_storage_pass_prefix_keys`)\n\n### 3.3 Detach (disable) the storage backend\n\n```bash\ncurl -s -X DELETE http://127.0.0.1:30000/hicache/storage-backend\n```\n\nNotes:\n\n- Detach only makes SGLang **stop using** the L3 storage backend and stops prefetch/backup threads\n- It **does not automatically delete** data stored in Mooncake/HF3FS (or other remote backends)\n\n---\n\n## 4. 
Behavior and caveats\n\n- **No restart required**: attach/detach switches in-process at runtime\n- **Must be idle**: otherwise the request is rejected to avoid consistency issues\n- **Host KV layout constraints still apply**: for example, Mooncake still requires layouts like `page_first/page_first_direct/page_head`; if the server's HiCache host-memory layout does not satisfy the backend requirements, attach will fail with an error\n- **Observability**:\n  - After attach, `server_args.hicache_storage_backend*` is updated on both the tokenizer and scheduler sides\n  - If metrics are enabled, attach will create a storage metrics collector in `HiRadixCache` on demand\n"
  },
  {
    "path": "docs/advanced_features/hyperparameter_tuning.md",
    "content": "# Hyperparameter Tuning\n\n## Achieving high throughput for offline batch inference\n\nAchieving a large batch size is the most important thing for attaining high throughput in offline batch inference.\nWhen the server is running at full load in a steady state, look for the following in the log:\n\n```Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, cuda graph: True, gen throughput (token/s): 4594.01, #queue-req: 317```\n\n### Adjust the request submission speed to control `#queue-req`\n\n`#queue-req` indicates the number of requests in the queue.\nIf you frequently see `#queue-req: 0`, it suggests that your client code is submitting requests too slowly.\nA healthy range for `#queue-req` is `100 - 2000`.\nHowever, avoid making `#queue-req` too large, as this will increase the scheduling overhead on the server.\n\n### Achieve a high `token usage`\n\n`token usage` indicates the KV cache memory utilization of the server. `token usage > 0.9` means good utilization.\n\nIf you frequently see `token usage < 0.9` and `#queue-req > 0`, it means the server is too conservative about taking in new requests. You can decrease `--schedule-conservativeness` to a value like 0.3.\nThe case of a server being too conservative can happen when users send many requests with a large `max_new_tokens` but the requests stop very early due to EOS or stop strings.\n\nOn the other hand, if you see `token usage` very high and you frequently see warnings like\n`KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_token_ratio: 0.9998 -> 1.0000`, you can increase `--schedule-conservativeness` to a value like 1.3.\nIf you see `KV cache pool is full. 
Retract requests.` occasionally but not frequently (about once per minute), it is okay.\n\n### Tune `--mem-fraction-static` to increase KV cache pool capacity\n\nSGLang allocates memory as follows:\n\nTotal memory usage = model weights + KV cache pool + CUDA graph buffers + activations\n\nThe `--mem-fraction-static` parameter determines how much memory is allocated to the first two components:\n\nmem_fraction_static = (model weights + KV cache pool) / GPU memory capacity\n\nTo support higher concurrency, you should maximize the KV cache pool capacity by setting `--mem-fraction-static` as high as possible while still reserving enough memory for activations and CUDA graph buffers.\n\nSGLang uses simple heuristics to set the default value of `--mem-fraction-static`, but you can optimize it for your use cases.\nAs a rule of thumb, reserving 5–8 GB of memory for activations is typically sufficient. You can check this by inspecting the logs just before the server is ready.\nLook for log entries like this:\n\n```\n[2025-08-11 17:17:03] max_total_num_tokens=665690, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=4096, context_len=65536, available_gpu_mem=13.50 GB\n```\n\nCheck the `available_gpu_mem` value.\n- If it is between 5–8 GB, the setting is good.\n- If it is too high (e.g., 10 - 20 GB), increase `--mem-fraction-static` to allocate more memory to the KV cache.\n- If it is too low, you risk out-of-memory (OOM) errors later, so decrease `--mem-fraction-static`.\n\nAnother straightforward approach is to increase `--mem-fraction-static` in increments of 0.01 until you encounter OOM errors for your workloads.\n\n### Avoid out-of-memory errors by tuning `--chunked-prefill-size`, `--mem-fraction-static`, and `--max-running-requests`\n\nIf you encounter out-of-memory (OOM) errors, you can adjust the following parameters:\n\n- If OOM occurs during prefill, try reducing `--chunked-prefill-size` to `4096` or `2048`. 
This saves memory but slows down the prefill speed for long prompts.\n- If OOM occurs during decoding, try lowering `--max-running-requests`.\n- You can also reduce `--mem-fraction-static` to a smaller value, such as 0.8 or 0.7. This reduces the size of the KV cache memory pool and helps prevent OOM errors during both prefill and decoding. However, it limits maximum concurrency and reduces peak throughput.\n\n### Tune `--cuda-graph-max-bs`\n\nBy default, CUDA graph is enabled only for small batch sizes (e.g., less than 160 or 256).\nHowever, for some models, especially at large tensor parallelism sizes, CUDA graph can be useful for batch sizes up to 512 or 768.\nTherefore, it may be beneficial to increase `--cuda-graph-max-bs` to a larger value.\nNote that CUDA graph consumes more memory, so you may need to reduce `--mem-fraction-static` at the same time.\n\n### Tune `--dp-size` and `--tp-size`\n\nData parallelism is better for throughput: when there is enough GPU memory, always favor data parallelism. For better data parallelism than the `dp_size` parameter provides, refer to [SGLang Model Gateway (formerly Router)](../advanced_features/sgl_model_gateway.md).\n\n### Try other options\n\n- `torch.compile` accelerates small models on small batch sizes. You can enable it with `--enable-torch-compile`.\n- Try other quantization methods (e.g., FP8 quantization with `--quantization fp8`).\n- Try other parallelism strategies (e.g., [expert parallelism](https://lmsys.org/blog/2025-05-05-large-scale-ep/)) or DP attention for DeepSeek models (with `--enable-dp-attention --dp-size 8`).\n- If the workload has many shared prefixes, try `--schedule-policy lpm`. Here, `lpm` stands for longest prefix match. It reorders requests to encourage more cache hits but introduces more scheduling overhead.\n"
  },
  {
    "path": "docs/advanced_features/lora.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# LoRA Serving\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"SGLang enables the use of [LoRA adapters](https://arxiv.org/abs/2106.09685) with a base model. By incorporating techniques from [S-LoRA](https://arxiv.org/pdf/2311.03285) and [Punica](https://arxiv.org/pdf/2310.18547), SGLang can efficiently support multiple LoRA adapters for different sequences within a single batch of inputs.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Arguments for LoRA Serving\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The following server arguments are relevant for multi-LoRA serving:\\n\",\n    \"\\n\",\n    \"* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\\n\",\n    \"\\n\",\n    \"* `enable_lora_overlap_loading`: Enable asynchronous LoRA weight loading in order to overlap H2D transfers with GPU compute. This should be enabled if you find that your LoRA workloads are bottlenecked by adapter weight loading, for example when frequently loading large LoRA adapters.\\n\",\n    \"\\n\",\n    \"* `lora_paths`: The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {\\\"lora_name\\\":str,\\\"lora_path\\\":str,\\\"pinned\\\":bool}.\\n\",\n    \"\\n\",\n    \"* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. 
Defaults to 8.\\n\",\n    \"\\n\",\n    \"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max_loras_per_batch`.\\n\",\n    \"\\n\",\n    \"* `lora_eviction_policy`: LoRA adapter eviction policy when the GPU memory pool is full. `lru`: Least Recently Used (default, better cache efficiency). `fifo`: First-In-First-Out.\\n\",\n    \"\\n\",\n    \"* `lora_backend`: The backend for running GEMM kernels for LoRA modules. Currently supported are the Triton LoRA backend (`triton`) and the Chunked SGMV backend (`csgmv`). Faster backends built on CUTLASS or CUDA kernels are planned.\\n\",\n    \"\\n\",\n    \"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\\n\",\n    \"\\n\",\n    \"* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters.\\n\",\n    \"\\n\",\n    \"* `max_lora_chunk_size`: Maximum chunk size for the ChunkedSGMV LoRA backend. Only used when `lora_backend` is `csgmv`. A larger value may improve performance; tune this value for your hardware and workload as needed. 
Defaults to 16.\\n\",\n    \"\\n\",\n    \"* `tp_size`: SGLang supports LoRA serving together with tensor parallelism. `tp_size` controls the number of GPUs used for tensor parallelism. More details on the tensor sharding strategy can be found in the [S-LoRA](https://arxiv.org/pdf/2311.03285) paper.\\n\",\n    \"\\n\",\n    \"On the client side, provide a list of strings as the input batch, plus a list of adapter names, one for each input sequence.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Usage\\n\",\n    \"\\n\",\n    \"### Serving Single Adaptor\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**Note:** SGLang supports LoRA adapters through two APIs:\\n\",\n    \"\\n\",\n    \"1. **OpenAI-Compatible API** (`/v1/chat/completions`, `/v1/completions`): Use the `model:adapter-name` syntax. See [OpenAI API with LoRA](../basic_usage/openai_api_completions.ipynb#Using-LoRA-Adapters) for examples.\\n\",\n    \"\\n\",\n    \"2. 
**Native API** (`/generate`): Pass `lora_path` in the request body (shown below).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import json\\n\",\n    \"import requests\\n\",\n    \"\\n\",\n    \"from sglang.test.doc_patch import launch_server_cmd\\n\",\n    \"from sglang.utils import wait_for_server, terminate_process\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"server_process, port = launch_server_cmd(\\n\",\n    \"    # Here we set max-loras-per-batch to 2: one slot for adaptor and another one for base model\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\\\\n\",\n    \"    --enable-lora \\\\\\n\",\n    \"    --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\\\\n\",\n    \"    --max-loras-per-batch 2 \\\\\\n\",\n    \"    --log-level warning \\\\\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"url = f\\\"http://127.0.0.1:{port}\\\"\\n\",\n    \"json_data = {\\n\",\n    \"    \\\"text\\\": [\\n\",\n    \"        \\\"List 3 countries and their capitals.\\\",\\n\",\n    \"        \\\"List 3 countries and their capitals.\\\",\\n\",\n    \"    ],\\n\",\n    \"    \\\"sampling_params\\\": {\\\"max_new_tokens\\\": 32, \\\"temperature\\\": 0},\\n\",\n    \"    # The first input uses lora0, and the second input uses the base model\\n\",\n    \"    \\\"lora_path\\\": [\\\"lora0\\\", None],\\n\",\n    \"}\\n\",\n    \"response = requests.post(\\n\",\n    \"    url + \\\"/generate\\\",\\n\",\n    \"    json=json_data,\\n\",\n    
\")\\n\",\n    \"print(f\\\"Output 0: {response.json()[0]['text']}\\\")\\n\",\n    \"print(f\\\"Output 1: {response.json()[1]['text']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Serving Multiple Adaptors\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"server_process, port = launch_server_cmd(\\\"\\\"\\\"\\n\",\n    \"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\\\\n\",\n    \"    --enable-lora \\\\\\n\",\n    \"    --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\\\\n\",\n    \"    lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_SFT_lora_4_alpha_16_humaneval_raw_json \\\\\\n\",\n    \"    --max-loras-per-batch 2 \\\\\\n\",\n    \"    --log-level warning \\\\\\n\",\n    \"\\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"url = f\\\"http://127.0.0.1:{port}\\\"\\n\",\n    \"json_data = {\\n\",\n    \"    \\\"text\\\": [\\n\",\n    \"        \\\"List 3 countries and their capitals.\\\",\\n\",\n    \"        \\\"List 3 countries and their capitals.\\\",\\n\",\n    \"    ],\\n\",\n    \"    \\\"sampling_params\\\": {\\\"max_new_tokens\\\": 32, \\\"temperature\\\": 0},\\n\",\n    \"    # The first input uses lora0, and the second input uses lora1\\n\",\n    \"    \\\"lora_path\\\": [\\\"lora0\\\", \\\"lora1\\\"],\\n\",\n    \"}\\n\",\n    \"response = requests.post(\\n\",\n    \"    url + \\\"/generate\\\",\\n\",\n    \"    json=json_data,\\n\",\n    \")\\n\",\n    
\"print(f\\\"Output 0: {response.json()[0]['text']}\\\")\\n\",\n    \"print(f\\\"Output 1: {response.json()[1]['text']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Dynamic LoRA loading\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Instead of specifying all adapters at server startup via `--lora-paths`, you can also load and unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` APIs.\\n\",\n    \"\\n\",\n    \"When using dynamic LoRA loading, it's recommended to explicitly specify both `--max-lora-rank` and `--lora-target-modules` at startup. For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. However, in that case, you must ensure that all dynamically loaded adapters have the same shape (rank and target modules) as those in the initial `--lora-paths`, or are strictly \\\"smaller\\\".\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"lora0 = \\\"Nutanix/Meta-Llama-3.1-8B-Instruct_SFT_lora_4_alpha_16_humaneval_raw_json\\\"  # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\\n\",\n    \"lora1 = \\\"algoprog/fact-generation-llama-3.1-8b-instruct-lora\\\"  # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\\n\",\n    \"lora0_new = \\\"philschmid/code-llama-3-1-8b-text-to-sql-lora\\\"  # rank - 256, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# The `--lora-target-modules` param below is technically not needed, as the server will infer it 
from lora0 which already has all the target modules specified.\\n\",\n    \"# We are adding it here just to demonstrate usage.\\n\",\n    \"server_process, port = launch_server_cmd(\\\"\\\"\\\"\\n\",\n    \"    python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\\\\n\",\n    \"    --enable-lora \\\\\\n\",\n    \"    --cuda-graph-max-bs 2 \\\\\\n\",\n    \"    --max-loras-per-batch 2 \\\\\\n\",\n    \"    --max-lora-rank 256\\n\",\n    \"    --lora-target-modules all\\n\",\n    \"    --log-level warning\\n\",\n    \"    \\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"url = f\\\"http://127.0.0.1:{port}\\\"\\n\",\n    \"wait_for_server(url, process=server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Load adapter lora0\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response = requests.post(\\n\",\n    \"    url + \\\"/load_lora_adapter\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"lora_name\\\": \\\"lora0\\\",\\n\",\n    \"        \\\"lora_path\\\": lora0,\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"if response.status_code == 200:\\n\",\n    \"    print(\\\"LoRA adapter loaded successfully.\\\", response.json())\\n\",\n    \"else:\\n\",\n    \"    print(\\\"Failed to load LoRA adapter.\\\", response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Load adapter lora1:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response = requests.post(\\n\",\n    \"    url + \\\"/load_lora_adapter\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"lora_name\\\": \\\"lora1\\\",\\n\",\n    \"        \\\"lora_path\\\": lora1,\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"if response.status_code == 
200:\\n\",\n    \"    print(\\\"LoRA adapter loaded successfully.\\\", response.json())\\n\",\n    \"else:\\n\",\n    \"    print(\\\"Failed to load LoRA adapter.\\\", response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check inference output:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"url = f\\\"http://127.0.0.1:{port}\\\"\\n\",\n    \"json_data = {\\n\",\n    \"    \\\"text\\\": [\\n\",\n    \"        \\\"List 3 countries and their capitals.\\\",\\n\",\n    \"        \\\"List 3 countries and their capitals.\\\",\\n\",\n    \"    ],\\n\",\n    \"    \\\"sampling_params\\\": {\\\"max_new_tokens\\\": 32, \\\"temperature\\\": 0},\\n\",\n    \"    # The first input uses lora0, and the second input uses lora1\\n\",\n    \"    \\\"lora_path\\\": [\\\"lora0\\\", \\\"lora1\\\"],\\n\",\n    \"}\\n\",\n    \"response = requests.post(\\n\",\n    \"    url + \\\"/generate\\\",\\n\",\n    \"    json=json_data,\\n\",\n    \")\\n\",\n    \"print(f\\\"Output from lora0: \\\\n{response.json()[0]['text']}\\\\n\\\")\\n\",\n    \"print(f\\\"Output from lora1: \\\\n{response.json()[1]['text']}\\\\n\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Unload lora0 and replace it with a different adapter:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response = requests.post(\\n\",\n    \"    url + \\\"/unload_lora_adapter\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"lora_name\\\": \\\"lora0\\\",\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"response = requests.post(\\n\",\n    \"    url + \\\"/load_lora_adapter\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"lora_name\\\": \\\"lora0\\\",\\n\",\n    \"        \\\"lora_path\\\": 
lora0_new,\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"if response.status_code == 200:\\n\",\n    \"    print(\\\"LoRA adapter loaded successfully.\\\", response.json())\\n\",\n    \"else:\\n\",\n    \"    print(\\\"Failed to load LoRA adapter.\\\", response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check output again:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"url = f\\\"http://127.0.0.1:{port}\\\"\\n\",\n    \"json_data = {\\n\",\n    \"    \\\"text\\\": [\\n\",\n    \"        \\\"List 3 countries and their capitals.\\\",\\n\",\n    \"        \\\"List 3 countries and their capitals.\\\",\\n\",\n    \"    ],\\n\",\n    \"    \\\"sampling_params\\\": {\\\"max_new_tokens\\\": 32, \\\"temperature\\\": 0},\\n\",\n    \"    # The first input uses lora0, and the second input uses lora1\\n\",\n    \"    \\\"lora_path\\\": [\\\"lora0\\\", \\\"lora1\\\"],\\n\",\n    \"}\\n\",\n    \"response = requests.post(\\n\",\n    \"    url + \\\"/generate\\\",\\n\",\n    \"    json=json_data,\\n\",\n    \")\\n\",\n    \"print(f\\\"Output from lora0 (updated): \\\\n{response.json()[0]['text']}\\\\n\\\")\\n\",\n    \"print(f\\\"Output from lora1: \\\\n{response.json()[1]['text']}\\\\n\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### OpenAI-compatible API usage\\n\",\n    \"\\n\",\n    \"You can use LoRA adapters via the OpenAI-compatible APIs by specifying the adapter in the `model` field using the `base-model:adapter-name` syntax (for example, `qwen/qwen2.5-0.5b-instruct:adapter_a`).
For more details and examples, see the “Using LoRA Adapters” section in the OpenAI API documentation: [openai_api_completions.ipynb](../basic_usage/openai_api_completions.ipynb).\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### LoRA GPU Pinning\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Another advanced option is to mark adapters as `pinned` when loading them. A pinned adapter is assigned to one of the available GPU pool slots (as configured by `--max-loras-per-batch`) and is not evicted from GPU memory at runtime; it stays resident until it is explicitly unloaded.\\n\",\n    \"\\n\",\n    \"This can improve performance when the same adapter is used frequently across requests, because it avoids repeated memory transfers and reinitialization overhead. However, since GPU pool slots are limited, pinning adapters reduces the system's flexibility to load other adapters on demand. If too many adapters are pinned, performance may degrade, or in the most extreme case (`number of pinned adapters == max-loras-per-batch`), all unpinned requests will stall. Therefore, SGLang currently limits the number of pinned adapters to at most `max-loras-per-batch - 1` to prevent starvation.\\n\",\n    \"\\n\",\n    \"In the example below, we start a server with `lora0` loaded as pinned, and `lora1` and `lora2` loaded as regular (unpinned) adapters.
Note that we intentionally specify `lora1` and `lora2` in two different formats (a JSON object and `name=path`) to demonstrate that both are supported.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"server_process, port = launch_server_cmd(\\\"\\\"\\\"\\n\",\n    \"    python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\\\\n\",\n    \"    --enable-lora \\\\\\n\",\n    \"    --cuda-graph-max-bs 8 \\\\\\n\",\n    \"    --max-loras-per-batch 3 \\\\\\n\",\n    \"    --max-lora-rank 256 \\\\\\n\",\n    \"    --lora-target-modules all \\\\\\n\",\n    \"    --lora-paths \\\\\\n\",\n    \"        {\\\"lora_name\\\":\\\"lora0\\\",\\\"lora_path\\\":\\\"Nutanix/Meta-Llama-3.1-8B-Instruct_SFT_lora_4_alpha_16_humaneval_raw_json\\\",\\\"pinned\\\":true} \\\\\\n\",\n    \"        {\\\"lora_name\\\":\\\"lora1\\\",\\\"lora_path\\\":\\\"algoprog/fact-generation-llama-3.1-8b-instruct-lora\\\"} \\\\\\n\",\n    \"        lora2=philschmid/code-llama-3-1-8b-text-to-sql-lora\\n\",\n    \"    --log-level warning\\n\",\n    \"    \\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"url = f\\\"http://127.0.0.1:{port}\\\"\\n\",\n    \"wait_for_server(url, process=server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"You can also mark an adapter as pinned when loading it dynamically.
In the example below, we reload `lora1` as a pinned adapter:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response = requests.post(\\n\",\n    \"    url + \\\"/unload_lora_adapter\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"lora_name\\\": \\\"lora1\\\",\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"response = requests.post(\\n\",\n    \"    url + \\\"/load_lora_adapter\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"lora_name\\\": \\\"lora1\\\",\\n\",\n    \"        \\\"lora_path\\\": \\\"algoprog/fact-generation-llama-3.1-8b-instruct-lora\\\",\\n\",\n    \"        \\\"pinned\\\": True,  # Pin the adapter to GPU\\n\",\n    \"    },\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Verify that the results are as expected:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"url = f\\\"http://127.0.0.1:{port}\\\"\\n\",\n    \"json_data = {\\n\",\n    \"    \\\"text\\\": [\\n\",\n    \"        \\\"List 3 countries and their capitals.\\\",\\n\",\n    \"        \\\"List 3 countries and their capitals.\\\",\\n\",\n    \"        \\\"List 3 countries and their capitals.\\\",\\n\",\n    \"    ],\\n\",\n    \"    \\\"sampling_params\\\": {\\\"max_new_tokens\\\": 32, \\\"temperature\\\": 0},\\n\",\n    \"    # The three inputs use lora0, lora1, and lora2, respectively\\n\",\n    \"    \\\"lora_path\\\": [\\\"lora0\\\", \\\"lora1\\\", \\\"lora2\\\"],\\n\",\n    \"}\\n\",\n    \"response = requests.post(\\n\",\n    \"    url + \\\"/generate\\\",\\n\",\n    \"    json=json_data,\\n\",\n    \")\\n\",\n    \"print(f\\\"Output from lora0 (pinned): \\\\n{response.json()[0]['text']}\\\\n\\\")\\n\",\n    \"print(f\\\"Output from lora1 (pinned): \\\\n{response.json()[1]['text']}\\\\n\\\")\\n\",\n    
\"print(f\\\"Output from lora2 (not pinned): \\\\n{response.json()[2]['text']}\\\\n\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Choosing LoRA Backend\\n\",\n    \"\\n\",\n    \"SGLang supports two LoRA backends, selected with the `--lora-backend` argument:\\n\",\n    \"\\n\",\n    \"- `triton`: Basic Triton-based backend.\\n\",\n    \"- `csgmv`: Default chunked SGMV backend, optimized for high-concurrency scenarios.\\n\",\n    \"\\n\",\n    \"The `csgmv` backend was recently introduced to improve performance, especially in high-concurrency scenarios. Our benchmarks show that it achieves 20% to 80% latency improvements over the basic Triton backend.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"server_process, port = launch_server_cmd(\\\"\\\"\\\"\\n\",\n    \"    python3 -m sglang.launch_server \\\\\\n\",\n    \"    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\\\\n\",\n    \"    --enable-lora \\\\\\n\",\n    \"    --lora-backend csgmv \\\\\\n\",\n    \"    --max-loras-per-batch 16 \\\\\\n\",\n    \"    --lora-paths lora1=path/to/lora1 lora2=path/to/lora2\\n\",\n    \"    \\\"\\\"\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## LoRA Overlap Loading\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"With the `--enable-lora-overlap-loading` server argument, the SGLang engine can overlap the loading of LoRA weights
with prefill and decode compute, effectively hiding the data movement for LoRA weights behind GPU computation. Our benchmarks show that under adversarial conditions, enabling this feature can reduce median TTFT by ~35% (see the [LoRA overlap loading PR](https://github.com/sgl-project/sglang/pull/15512) for detailed benchmarks).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"lora0 = \\\"Nutanix/Meta-Llama-3.1-8B-Instruct_SFT_lora_4_alpha_16_humaneval_raw_json\\\"\\n\",\n    \"lora1 = \\\"algoprog/fact-generation-llama-3.1-8b-instruct-lora\\\"\\n\",\n    \"lora2 = \\\"philschmid/code-llama-3-1-8b-text-to-sql-lora\\\"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"server_process, port = launch_server_cmd(\\\"\\\"\\\"\\n\",\n    \"    python3 -m sglang.launch_server \\\\\\n\",\n    \"    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\\\\n\",\n    \"    --enable-lora \\\\\\n\",\n    \"    --enable-lora-overlap-loading \\\\\\n\",\n    \"    --lora-paths lora0=Nutanix/Meta-Llama-3.1-8B-Instruct_SFT_lora_4_alpha_16_humaneval_raw_json \\\\\\n\",\n    \"    lora1=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\\\\n\",\n    \"    lora2=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\\\\n\",\n    \"    --max-lora-rank 256 \\\\\\n\",\n    \"    --max-loras-per-batch 2 \\\\\\n\",\n    \"    --max-loaded-loras 4\\n\",\n    \"    \\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"url = f\\\"http://127.0.0.1:{port}\\\"\\n\",\n    \"wait_for_server(url, process=server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"json_data = {\\n\",\n    \"    \\\"text\\\": [\\n\",\n    \"        \\\"Write a very long fairy-tale.\\\",\\n\",\n    \"        \\\"List 3 countries and their capitals.\\\",\\n\",\n    \"        \\\"List 3 countries and their capitals.\\\",\\n\",\n    
\"    ],\\n\",\n    \"    \\\"sampling_params\\\": [\\n\",\n    \"        {\\\"max_new_tokens\\\": 1024, \\\"temperature\\\": 0},\\n\",\n    \"        {\\\"max_new_tokens\\\": 64, \\\"temperature\\\": 0},\\n\",\n    \"        {\\\"max_new_tokens\\\": 64, \\\"temperature\\\": 0},\\n\",\n    \"    ],\\n\",\n    \"    \\\"lora_path\\\": [\\\"lora0\\\", \\\"lora1\\\", \\\"lora2\\\"],\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"# lora0 and lora1 will be loaded into the memory pool first, and because max_loras_per_batch = 2, lora2's request will remain in the queue.\\n\",\n    \"# lora1's request will likely finish first, and once it does, lora2 will be loaded. With --enable-lora-overlap-loading, this loading will\\n\",\n    \"# occur asynchronously and thus decoding for lora0's request won't be blocked.\\n\",\n    \"response = requests.post(\\n\",\n    \"    url + \\\"/generate\\\",\\n\",\n    \"    json=json_data,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"for i in range(3):\\n\",\n    \"    print(f\\\"Output from lora{i}: \\\\n{response.json()[i]['text']}\\\\n\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Limitations of LoRA Overlap Loading\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"However, LoRA overlap loading is not free and comes with two important caveats:\\n\",\n    \"\\n\",\n    \"1. **Pinned CPU memory requirement**:\\n\",\n    \"   Asynchronous H2D memory copies require LoRA weights to be pinned in CPU memory, which is a finite system resource. To mitigate excessive pinned-memory usage, SGLang currently restricts `max_loaded_loras` to be at most 2× `max_loras_per_batch` when LoRA overlap loading is enabled.\\n\",\n    \"\\n\",\n    \"2. 
**Reduced multi-adapter prefill batching**:\\n\",\n    \"   With overlap loading, adapters become available on the GPU at different times because each adapter is loaded asynchronously. This can reduce the scheduler’s ability to form multi-adapter prefill batches, since only requests whose adapters are currently loaded can be grouped together. As a result, requests for different adapters may be scheduled in separate (or smaller) prefill batches, which can increase TTFT when adapter load time is small compared to prefill compute time. This is why LoRA overlap loading is disabled by default: enable it only after you have determined that LoRA weight loading is a bottleneck (e.g., high adapter churn, heavy adapter weights, or PCIe-bottlenecked workloads).\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Example When Overlap Loading Results in Higher Latency\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"For example, suppose we have four LoRA adapters: `lora0`, `lora1`, `lora2`, and `lora3`. Loading any adapter takes 2ms, while the prefill step for that adapter's requests takes 20ms.\\n\",\n    \"\\n\",\n    \"1. **Baseline**:\\n\",\n    \"  The engine loads all four adapters synchronously, then runs one combined prefill batch, for a total time of ≈ `2 * 4 + 20 = 28ms`.\\n\",\n    \"\\n\",\n    \"2. **With LoRA overlap loading enabled**:\\n\",\n    \"  The engine begins loading `lora0` and, once it is ready, schedules a prefill batch containing only `lora0` while `lora1` loads in the background. Then it schedules `lora1`’s prefill while `lora2` loads, and so on.
In the worst case, where prefill cannot be batched across adapters, the total time is ≈ `2 + 4 * 20 = 82ms`.\\n\",\n    \"\\n\",\n    \"In this scenario, overlap loading hides the adapter-load overhead, but the loss of multi-adapter prefill batching dominates and leads to higher TTFT.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Future Works\\n\",\n    \"\\n\",\n    \"The development roadmap for LoRA-related features can be found in this [issue](https://github.com/sgl-project/sglang/issues/2929). Other features, including the embedding layer, unified paging, and a CUTLASS backend, are still under development.\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "docs/advanced_features/observability.md",
    "content": "# Observability\n\n## Production Metrics\nSGLang exposes production metrics via Prometheus. You can enable them by adding `--enable-metrics` when launching the server.\nYou can then query them with:\n```\ncurl http://localhost:30000/metrics\n```\n\nSee [Production Metrics](../references/production_metrics.md) and [Production Request Tracing](../references/production_request_trace.md) for more details.\n\n## Logging\n\nBy default, SGLang does not log any request contents. You can log them by using `--log-requests`.\nYou can control the verbosity by using `--log-request-level`.\nSee [Logging](server_arguments.md#logging) for more details.\n\n## Request Dump and Replay\n\nYou can dump all requests and replay them later for benchmarking or other purposes.\n\nTo start dumping, send the following configuration request to a running server:\n```\npython3 -m sglang.srt.managers.configure_logging --url http://localhost:30000 --dump-requests-folder /tmp/sglang_request_dump --dump-requests-threshold 100\n```\nThe server will then write the requests to a pickle file every 100 requests.\n\nTo replay the request dump, use `scripts/playground/replay_request_dump.py`.\n\n## Crash Dump and Replay\nSometimes the server might crash, and you may want to debug the cause of the crash.\nSGLang supports crash dumping, which dumps all requests received in the 5 minutes before the crash so that you can replay them and debug the cause later.\n\nTo enable crash dumping, use `--crash-dump-folder /tmp/crash_dump`.\nTo replay the crash dump, use `scripts/playground/replay_request_dump.py`.\n"
  },
  {
    "path": "docs/advanced_features/pd_disaggregation.md",
    "content": "# PD Disaggregation\n\n## Why and What is PD Disaggregation?\n\nLarge Language Model (LLM) inference comprises two distinct phases: **Prefill** and **Decode**. The Prefill phase is computation-intensive, processing the entire input sequence, while the Decode phase is memory-intensive, managing the Key-Value (KV) cache for token generation. Traditionally, these phases are handled within a unified engine, where combined scheduling of prefill and decode batches introduces inefficiencies. To address these challenges, we introduce **Prefill and Decoding (PD) Disaggregation** in SGLang.\n\n### Issues with Unified Scheduling\n\nThe conventional unified engine, which processes prefill and decode batches together, results in two significant problems:\n\n1. **Prefill Interruption**: Incoming prefill batches frequently interrupt ongoing decode batches, causing substantial delays in token generation.\n2. **DP Attention Imbalance**: In data-parallel (DP) attention, one DP worker may process a prefill batch while another handles a decode batch simultaneously, leading to increased decode latency.\n\nPD Disaggregation resolves these issues by separating the two stages, enabling optimizations tailored to each.\n\nFor design details, please refer to the [design document](https://docs.google.com/document/d/1rQXJwKd5b9b1aOzLh98mnyMhBMhlxXA5ATZTHoQrwvc/edit?tab=t.0).\n\nCurrently, we support Mooncake, NIXL, and Ascend as transfer engines.\n\n## Profiling in PD Disaggregation Mode\n\nWhen you need to profile prefill or decode workers in PD disaggregation mode, please refer to the [Profile In PD Disaggregation Mode](https://docs.sglang.io/developer_guide/benchmark_and_profiling.html#profile-in-pd-disaggregation-mode) section in the Benchmark and Profiling guide.
Due to torch profiler limitations, prefill and decode workers must be profiled separately using dedicated command-line options.\n\n## Router Integration\n\nFor deploying PD disaggregation at scale with load balancing and fault tolerance, SGLang provides a router. The router can distribute requests between prefill and decode instances using various routing policies. For detailed information on setting up routing with PD disaggregation, including configuration options and deployment patterns, see the [SGLang Model Gateway (former Router)](../advanced_features/sgl_model_gateway.md#prefill-decode-disaggregation).\n\n\n## Mooncake\n### Requirements\n\n```bash\nuv pip install mooncake-transfer-engine\n```\n\n### Usage\n\n### Llama Single Node\n\n```bash\npython -m sglang.launch_server \\\n  --model-path meta-llama/Llama-3.1-8B-Instruct \\\n  --disaggregation-mode prefill \\\n  --port 30000 \\\n  --disaggregation-ib-device mlx5_roce0\npython -m sglang.launch_server \\\n  --model-path meta-llama/Llama-3.1-8B-Instruct \\\n  --disaggregation-mode decode \\\n  --port 30001 \\\n  --base-gpu-id 1 \\\n  --disaggregation-ib-device mlx5_roce0\npython -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000\n```\n\n### DeepSeek Multi-Node\n\n```bash\n# prefill 0\npython -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3-0324 \\\n  --disaggregation-ib-device ${device_name} \\\n  --disaggregation-mode prefill \\\n  --host ${local_ip} \\\n  --port 30000 \\\n  --trust-remote-code \\\n  --dist-init-addr ${prefill_master_ip}:5000 \\\n  --nnodes 2 \\\n  --node-rank 0 \\\n  --tp-size 16 \\\n  --dp-size 8 \\\n  --enable-dp-attention \\\n  --moe-a2a-backend deepep \\\n  --mem-fraction-static 0.8\n# prefill 1\npython -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3-0324 \\\n  --disaggregation-ib-device ${device_name} \\\n  --disaggregation-mode prefill \\\n  --host 
${local_ip} \\\n  --port 30000 \\\n  --trust-remote-code \\\n  --dist-init-addr ${prefill_master_ip}:5000 \\\n  --nnodes 2 \\\n  --node-rank 1 \\\n  --tp-size 16 \\\n  --dp-size 8 \\\n  --enable-dp-attention \\\n  --moe-a2a-backend deepep \\\n  --mem-fraction-static 0.8\n# decode 0\npython -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3-0324 \\\n  --disaggregation-ib-device ${device_name} \\\n  --disaggregation-mode decode \\\n  --host ${local_ip} \\\n  --port 30001 \\\n  --trust-remote-code \\\n  --dist-init-addr ${decode_master_ip}:5000 \\\n  --nnodes 2 \\\n  --node-rank 0 \\\n  --tp-size 16 \\\n  --dp-size 8 \\\n  --enable-dp-attention \\\n  --moe-a2a-backend deepep \\\n  --mem-fraction-static 0.8 \\\n  --max-running-requests 128\n# decode 1\npython -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3-0324 \\\n  --disaggregation-ib-device ${device_name} \\\n  --disaggregation-mode decode \\\n  --host ${local_ip} \\\n  --port 30001 \\\n  --trust-remote-code \\\n  --dist-init-addr ${decode_master_ip}:5000 \\\n  --nnodes 2 \\\n  --node-rank 1 \\\n  --tp-size 16 \\\n  --dp-size 8 \\\n  --enable-dp-attention \\\n  --moe-a2a-backend deepep \\\n  --mem-fraction-static 0.8 \\\n  --max-running-requests 128\n```\n### Advanced Configuration\n\nPD Disaggregation with Mooncake supports the following environment variables for fine-grained control over system behavior.\n\n#### NVLink Transport Configuration\nTo enable NVLink transport for KV cache transfers with the mooncake backend (recommended for NVL72 deployments), set the following environment variables. Note that auxiliary data transfer will still use TCP as a temporary workaround.\n\n```bash\nexport SGLANG_MOONCAKE_CUSTOM_MEM_POOL=NVLINK\nexport MC_FORCE_MNNVL=True\n```\n\nThe `SGLANG_MOONCAKE_CUSTOM_MEM_POOL` environment variable enables the custom memory pool. 
Supported values are `NVLINK` (or `True`), `BAREX`, and `INTRA_NODE_NVLINK`.\n\n#### Prefill Server Configuration\n| Variable | Description | Default |\n|:--------:|:-----------:|:--------:|\n| **`SGLANG_DISAGGREGATION_THREAD_POOL_SIZE`** | Controls the total number of worker threads for KVCache transfer operations per TP rank | A dynamic value computed as `int(0.75 * os.cpu_count()) // 8` and clamped to the range 4 to 12 to ensure efficiency and prevent thread race conditions |\n| **`SGLANG_DISAGGREGATION_QUEUE_SIZE`** | Sets the number of parallel transfer queues. KVCache transfer requests from multiple decode instances are sharded into these queues so that they can share the threads and the transfer bandwidth concurrently. If set to `1`, requests are transferred one by one in first-come, first-served (FCFS) order | `4` |\n| **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `300` |\n| **`SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL`** | Interval (seconds) between cleanups of bootstrap entries | `120` |\n\nIf a greater mean TTFT is acceptable, you can `export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600` (10 minutes) to relax the timeout condition.\nNote that this setting causes prefill instances to take longer to clean up the affected memory resources when a running decode node loses its connection.\n\n#### Decode Server Configuration\n| Variable | Description | Default |\n|:--------:|:-----------:|:--------:|\n| **`SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL`** | Interval (seconds) between health checks to prefill bootstrap servers | `5.0` |\n| **`SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE`** | Consecutive heartbeat failures before a prefill server is marked offline | `2` |\n| **`SGLANG_DISAGGREGATION_WAITING_TIMEOUT`** | Timeout (seconds) for receiving KV Cache after request initialization | `300` |\n\nIf a greater mean TTFT is
acceptable, you can `export SGLANG_DISAGGREGATION_WAITING_TIMEOUT=600` (10 minutes) to relax the timeout condition.\n\n\n## NIXL\n### Requirements\n\nInstall via pip:\n\n```bash\npip install nixl\n```\n\nAlternatively, build from source, which may be required if UCX is already installed:\n\n```bash\ngit clone https://github.com/ai-dynamo/nixl.git\ncd nixl\npip install . --config-settings=setup-args=\"-Ducx_path=/path/to/ucx\"\n```\n\n\n### Usage\n\n### Llama Single Node\n\n```bash\npython -m sglang.launch_server \\\n  --model-path meta-llama/Llama-3.1-8B-Instruct \\\n  --disaggregation-mode prefill \\\n  --port 30000 \\\n  --disaggregation-transfer-backend nixl\npython -m sglang.launch_server \\\n  --model-path meta-llama/Llama-3.1-8B-Instruct \\\n  --disaggregation-mode decode \\\n  --port 30001 \\\n  --base-gpu-id 1 \\\n  --disaggregation-transfer-backend nixl\npython -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000\n```\n\n### DeepSeek Multi-Node\n\n```bash\n# prefill 0\npython -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3-0324 \\\n  --disaggregation-transfer-backend nixl \\\n  --disaggregation-mode prefill \\\n  --host ${local_ip} \\\n  --port 30000 \\\n  --trust-remote-code \\\n  --dist-init-addr ${prefill_master_ip}:5000 \\\n  --nnodes 2 \\\n  --node-rank 0 \\\n  --tp-size 16 \\\n  --dp-size 8 \\\n  --enable-dp-attention \\\n  --moe-a2a-backend deepep \\\n  --mem-fraction-static 0.8\n# prefill 1\npython -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3-0324 \\\n  --disaggregation-transfer-backend nixl \\\n  --disaggregation-mode prefill \\\n  --host ${local_ip} \\\n  --port 30000 \\\n  --trust-remote-code \\\n  --dist-init-addr ${prefill_master_ip}:5000 \\\n  --nnodes 2 \\\n  --node-rank 1 \\\n  --tp-size 16 \\\n  --dp-size 8 \\\n  --enable-dp-attention \\\n  --moe-a2a-backend deepep \\\n  --mem-fraction-static 0.8\n# decode 0\npython -m
sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3-0324 \\\n  --disaggregation-transfer-backend nixl \\\n  --disaggregation-mode decode \\\n  --host ${local_ip} \\\n  --port 30001 \\\n  --trust-remote-code \\\n  --dist-init-addr ${decode_master_ip}:5000 \\\n  --nnodes 2 \\\n  --node-rank 0 \\\n  --tp-size 16 \\\n  --dp-size 8 \\\n  --enable-dp-attention \\\n  --moe-a2a-backend deepep \\\n  --mem-fraction-static 0.8 \\\n  --max-running-requests 128\n# decode 1\npython -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3-0324 \\\n  --disaggregation-transfer-backend nixl \\\n  --disaggregation-mode decode \\\n  --host ${local_ip} \\\n  --port 30001 \\\n  --trust-remote-code \\\n  --dist-init-addr ${decode_master_ip}:5000 \\\n  --nnodes 2 \\\n  --node-rank 1 \\\n  --tp-size 16 \\\n  --dp-size 8 \\\n  --enable-dp-attention \\\n  --moe-a2a-backend deepep \\\n  --mem-fraction-static 0.8 \\\n  --max-running-requests 128\n```\n\n### Advanced Configuration\n\n#### NIXL Backend Selection\n\nBy default, NIXL uses the **UCX** backend for KV cache transfers. 
You can select a different NIXL plugin backend depending on your infrastructure using the environment variable `SGLANG_DISAGGREGATION_NIXL_BACKEND`.\n\nExample: `export SGLANG_DISAGGREGATION_NIXL_BACKEND=LIBFABRIC`\n\n**Available backends:** UCX (default), LIBFABRIC, or any installed NIXL plugin.\n\nExample usage:\n```bash\nexport SGLANG_DISAGGREGATION_NIXL_BACKEND=LIBFABRIC\npython -m sglang.launch_server \\\n  --model-path meta-llama/Llama-3.1-8B-Instruct \\\n  --disaggregation-mode prefill \\\n  --disaggregation-transfer-backend nixl \\\n  --port 30000\n```\n\n## ASCEND\n\n### Usage\n\nTo use the Ascend backend, install [memfabric_hybrid](https://gitcode.com/Ascend/memfabric_hybrid) and set `ASCEND_MF_STORE_URL`:\n\n```bash\npip install memfabric-hybrid==1.0.0\nexport ASCEND_MF_STORE_URL=\"tcp://xxx.xx.xxx.xxx:xxxx\"\n```\n\nTo use the Mooncake backend for Ascend transfers, set the following variable; more details can be found in the Mooncake section.\n\n```bash\nexport ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true\n```\n\n`ASCEND_NPU_PHY_ID` needs to be set in the container environment:\n\n```bash\nexport ASCEND_NPU_PHY_ID=xxx\n```\n\n\n### Llama Single Node\n\n```bash\npython -m sglang.launch_server \\\n  --model-path meta-llama/Llama-3.1-8B-Instruct \\\n  --disaggregation-mode prefill \\\n  --port 30000 \\\n  --disaggregation-transfer-backend ascend\npython -m sglang.launch_server \\\n  --model-path meta-llama/Llama-3.1-8B-Instruct \\\n  --disaggregation-mode decode \\\n  --port 30001 \\\n  --base-gpu-id 1 \\\n  --disaggregation-transfer-backend ascend\npython -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000\n```\n\n### DeepSeek Multi-Node\n\n```bash\n# prefill 0\npython -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3-0324 \\\n  --disaggregation-transfer-backend ascend \\\n  --disaggregation-mode prefill \\\n  --host ${local_ip} \\\n  --port 30000 \\\n  --trust-remote-code \\\n  --dist-init-addr ${prefill_master_ip}:5000 \\\n  --nnodes 1 \\
 --node-rank 0 \\\n  --tp-size 16\n# decode 0\npython -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3-0324 \\\n  --disaggregation-transfer-backend ascend \\\n  --disaggregation-mode decode \\\n  --host ${local_ip} \\\n  --port 30001 \\\n  --trust-remote-code \\\n  --dist-init-addr ${decode_master_ip}:5000 \\\n  --nnodes 1 \\\n  --node-rank 0 \\\n  --tp-size 16\n```\n"
  },
  {
    "path": "docs/advanced_features/piecewise_cuda_graph.md",
    "content": "# Piecewise CUDA Graph\n\n## Motivation\n\nStandard CUDA graphs capture the entire model forward pass as a single graph. This works well for decode (fixed batch size), but not for extend/prefill where the number of tokens varies across iterations.\n\nPiecewise CUDA Graph (PCG) solves this by splitting the model's computation graph into pieces (roughly one per layer) at \"split points\" (e.g., MoE dispatch ops). Each piece is captured as a separate CUDA graph for a set of pre-defined token lengths. At runtime, the input is padded to the nearest captured size, and each piece is replayed. This eliminates kernel launch overhead for prefill/extend while still supporting dynamic shapes.\n\nRecently we **enabled PCG by default**, which means that the old `--enable-piecewise-cuda-graph` flag is deprecated. Use `--disable-piecewise-cuda-graph` to turn it off.\n\n## Usage\n\nPCG is enabled by default for supported configurations. No extra flags needed:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path meta-llama/Llama-3.1-8B-Instruct\n```\n\n### Disable PCG\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path meta-llama/Llama-3.1-8B-Instruct \\\n    --disable-piecewise-cuda-graph\n```\n\n### Custom capture sizes\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path meta-llama/Llama-3.1-8B-Instruct \\\n    --piecewise-cuda-graph-max-tokens 2048\n```\n\n### Server Args\n\n| Argument | Default | Description |\n|---|---|---|\n| `--disable-piecewise-cuda-graph` | `False` | Disable PCG for extend/prefill. |\n| `--enforce-piecewise-cuda-graph` | `False` | Force-enable PCG, skipping all auto-disable conditions. For testing only. |\n| `--piecewise-cuda-graph-max-tokens` | `None` (auto) | Maximum token count to capture. Defaults to `chunked_prefill_size` (non-MLA) or `2048` (MLA). |\n| `--piecewise-cuda-graph-tokens` | `None` (auto) | Explicit list of token lengths to capture. Auto-generated if not set. 
|\n| `--piecewise-cuda-graph-compiler` | `\"eager\"` | Compiler backend for the captured subgraphs. Choices: `eager`, `inductor`. |\n| ~~`--enable-piecewise-cuda-graph`~~ | — | **Deprecated.** PCG is now enabled by default. Use `--enforce-piecewise-cuda-graph` to skip auto-disable conditions. |\n\n## Bug Report\n\nPCG is enabled by default but is still in an experimental stage. Since PCG relies on `torch.compile` to trace the model's forward pass, most bugs are introduced by torch compile tracing failures (e.g., untraceable ops, dynamic control flow, or graph breaks). If you encounter any issues related to PCG, please disable it by adding `--disable-piecewise-cuda-graph` to your launch command and report the bug at [GitHub Issues](https://github.com/sgl-project/sglang/issues/new/choose). We greatly appreciate your help in improving this feature.\n\n### For Users\n\nIf you see an error message like the following during server startup, it is a PCG bug:\n\n```\nPiecewise CUDA Graph is enabled by default as an experimental feature.\nTo work around this error, add --disable-piecewise-cuda-graph to your launch command.\nPlease report this issue at https://github.com/sgl-project/sglang/issues/new/choose\n```\n\nTo work around it, add `--disable-piecewise-cuda-graph` to your launch command. When filing a bug report, please include:\n1. The full error traceback\n2. Model name and quantization method\n3. Launch command with all arguments\n4. GPU type and driver version\n\n### For Developers\n\nSince PCG relies on `torch.compile` to trace the model's forward pass, newly developed CUDA kernels (both JIT kernels and sgl-kernels) are typically not compatible with `torch.compile` out of the box. The tracing will fail on untraceable operations such as JIT compilation, file I/O, or dynamic module loading inside the kernel.\n\nTo make a kernel compatible with PCG, you need to register it as a custom op using `register_custom_op` from `sglang.srt.utils.custom_op`. 
This wraps the kernel as an opaque node in the compiled graph so that `torch.compile` will not trace inside it.\n\n**Example usage (JIT kernel):**\n\n```python\nimport torch\nfrom sglang.srt.utils.custom_op import register_custom_op\n\n# In-place operator (no return value)\n@register_custom_op(mutates_args=[\"output_q\", \"output_s\"])\ndef per_token_group_quant_8bit(\n    input: torch.Tensor,\n    output_q: torch.Tensor,\n    output_s: torch.Tensor,\n) -> None:\n    # kernel implementation ...\n    ...\n```\n\n**Example usage (operator with output):**\n\n```python\n# out_shape indicates which argument has the same shape as the output\n@register_custom_op(mutates_args=[\"x\"], out_shape=0)\ndef add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n    return x.add_(y)\n```\n\nFor wrapping external library functions (e.g., FlashInfer kernels), use `register_custom_op_from_extern` instead. See `python/sglang/srt/utils/custom_op.py` for full API documentation.\n\n## How it works\n### Torch compile backend\n\nPCG uses `torch.compile` with a custom backend (`SGLangBackend`) to split and compile the model's forward pass. The flow is:\n\n```\nmodel.forward wrapper\n→ torch.compile(..., backend=SGLangBackend)\n→ FX graph\n→ split_graph() at registered split ops\n→ split_gm (top-level graph that chains the pieces)\n→ replace capturable submodules with CUDAPiecewiseBackend\n→ runtime dispatch: eager split ops + per-piece capture/replay\n```\n\n- **Install**: `install_torch_compiled()` replaces `model.forward` with a wrapper function. When `is_in_piecewise_cuda_graph()` returns True, the wrapper dispatches to the compiled callable; otherwise it falls back to the original forward. The first invocation through this path triggers Dynamo tracing and graph compilation — CUDA graph replay only happens after the capture phase completes.\n\n- **Split**: When `torch.compile` traces the model, `SGLangBackend` receives the FX graph and calls `split_graph()`.
Ops listed in `CompilationConfig.split_ops` are treated as split points, so the graph is cut at each one. These split-op submodules are left to run eagerly at runtime, while the surrounding submodules are compiled and wrapped by `CUDAPiecewiseBackend`. The result is a top-level \"stitching graph\" (`split_gm`) with children such as `submod_0`, `submod_1`, … interleaving capturable subgraphs and eager split-op submodules.\n\n- **Replace**: `PiecewiseCompileInterpreter` iterates over each capturable submodule in `split_gm`, compiles it for general (dynamic) shapes, and replaces it in-place with a `CUDAPiecewiseBackend` instance. Split-op submodules (e.g., attention, all-reduce) are left as-is and run eagerly at runtime.\n\n- **Dispatch**: At runtime, calling `split_gm` executes the stitching graph, which calls each submodule in order. Split-op submodules run eagerly. Each `CUDAPiecewiseBackend` submodule goes through three phases:\n  - **Compile warmup** — runs the general-shape compiled path.\n  - **Capture** — for each capture size, runs one warmup pass then records a CUDA graph.\n  - **Steady-state replay** — replays the captured CUDA graph for each forward pass.\n\n### Piecewise cuda graph runner\n\n`PiecewiseCudaGraphRunner` orchestrates the full lifecycle through three phases:\n\n- **Compile** — Warms up JIT kernels with a dummy forward pass, then wraps the model with `torch.compile`, triggering Dynamo tracing to split the FX graph and create `CUDAPiecewiseBackend` instances for each subgraph piece.\n\n- **Capture** — Iterates over capture sizes in reverse order (largest first). 
For each size, runs the forward pass twice (one warmup, one CUDA graph capture).\n\n- **Replay** — At runtime, finds the smallest captured size >= actual token count via binary search, copies inputs into static buffers with zero-padding, replays the captured CUDA graphs, and slices outputs back to the actual token count.\n\n### Memory optimization\n\nThe memory cost of PCG comes from two parts: **torch memory allocator** and **non-torch memory**.\n\nThe torch memory allocator overhead is trivial thanks to several optimizations: a global shared memory pool is reused across all CUDA graph runners and capture sizes, capture is done in reverse order (large to small) so smaller graphs reuse memory allocated by larger ones, and output tensors of the last subgraph are stored as weak references to maximize memory reuse.\n\nThe main memory overhead comes from non-torch memory — the CUDA graph objects themselves require GPU memory to store the recorded kernel launch parameters and internal state. This overhead scales with the number of captured sizes, which is why `piecewise_cuda_graph_max_tokens` is capped conservatively by default.\n\n### Shape configuration\nPiecewise CUDA graph pre-captures graphs for a set of token counts. At runtime, the actual token count is rounded up to the nearest captured size (via binary search), and the corresponding graph is replayed. If the token count exceeds the largest captured size, the runtime falls back to the normal (non-graph) forward path.\n\nThe default capture schedule is auto-generated with increasing granularity:\n\n| Token range | Step size |\n|-------------|-----------|\n| 4 – 32      | 4         |\n| 48 – 256    | 16        |\n| 288 – 512   | 32        |\n| 576 – 1024  | 64        |\n| 1280 – 4096 | 256       |\n| 4096+       | 512       |\n\nFor the auto-generated schedule, sizes are capped at `--piecewise-cuda-graph-max-tokens`. The default cap is `chunked_prefill_size` for non-MLA models and `2048` for MLA backend models. 
If `--max-total-tokens` is set, the cap is further limited to not exceed it. Additionally, Llama-2 models are auto-capped at 4096 tokens as a temporary workaround.\n\n## Compatibility\n\nPCG is auto-disabled in the following scenarios. We are actively working on expanding compatibility — support for many of these will be coming soon.\n\n- Disabled model architectures (e.g., `DeepseekV32ForCausalLM`)\n- Speculative decoding\n- DP attention\n- Pipeline parallelism (`pp_size > 1`)\n- Non-CUDA hardware (AMD ROCm, Ascend NPU)\n- MoE A2A backend\n- LoRA\n- Multimodal / VLM models\n- DLLM (diffusion LLM)\n- Deterministic inference\n- PD disaggregation\n- Expert distribution recorder / EPLB\n\nUse `--enforce-piecewise-cuda-graph` to skip all auto-disable checks (for testing/debugging only).\n\n## Code Reference\n\n| File | Description |\n|---|---|\n| `python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py` | Main runner: init, capture, replay |\n| `python/sglang/srt/compilation/compile.py` | `install_torch_compiled` trampoline |\n| `python/sglang/srt/compilation/backend.py` | `SGLangBackend`, graph splitting, piecewise compilation |\n| `python/sglang/srt/compilation/cuda_piecewise_backend.py` | Per-subgraph CUDA graph capture/replay |\n| `python/sglang/srt/compilation/piecewise_context_manager.py` | Global context flags and `ForwardContext` |\n| `python/sglang/srt/compilation/compilation_config.py` | Capture sizes, split ops, compiler config |\n| `python/sglang/srt/utils/custom_op.py` | `register_custom_op` for torch.compile compatibility |\n| `python/sglang/srt/server_args.py` | Server arguments and auto-disable logic |\n"
  },
  {
    "path": "docs/advanced_features/pipeline_parallelism.md",
    "content": "# Pipeline Parallelism for Long Context\n\n## Why Pipeline Parallelism?\n\nAs Large Language Models (LLMs) scale toward trillion-parameter architectures and \"infinite\" context windows, the underlying serving infrastructure must evolve toward more granular, cross-node parallelization strategies. While KV cache techniques effectively mitigate redundant computation, they cannot circumvent the prohibitive Time to First Token (TTFT) inherent in ultra-long sequences with extremely large initial Input Token Length (ITL). Although Tensor Parallelism (TP) remains the conventional approach for intra-node scaling, it frequently encounters communication bottlenecks during multi-node deployments. On the other hand, pipeline parallelism only requires cross-node communication at the boundaries of each pipeline stage, which can achieve better computation-communication overlap compared to a large TP. Therefore, it is also a promising parallelization strategy for improving throughput.\n\nDetailed analysis can be found in this [blog](https://lmsys.org/blog/2026-01-15-chunked-pipeline/).\n\n## Implementation Refactoring based on Async Communication\nWith Dynamic Chunked Prefill, pipeline parallelism has the potential to reduce the TTFT of long-context inputs. For each request, its input tokens can be partitioned into multiple chunks, each no longer than the chunked prefill size. Different chunks of the same request can be processed simultaneously by different nodes, thus parallelizing the processing and reducing TTFT. SGLang has supported Pipeline Parallelism (#5724) for some time and made it compatible with the PD Disaggregation feature (#8846), but the implementation was not perfect and had significant room for performance improvements.\n\nTo eliminate this performance hazard, SGLang implements a Micro-batching Event Loop with non-blocking asynchronous peer-to-peer (P2P) communication to overlap GPU computation with CPU metadata processing and PP communication. 
This ensures that while one micro-batch is being computed on the GPU, the next one is already being prepared and moved into position, keeping the pipeline as saturated as possible. This approach was first proposed in #7979 and has been redesigned and included in #11852.\n\nThe key mechanisms of the implementation include:\n\n* **Decoupled Sync/Async Logic in the Event Loop:** The scheduler uses `async_send` in `_pp_send_pyobj_to_next_stage`. Instead of waiting for a transfer to complete, it returns a `P2PWork` handle. The actual synchronization (`P2PWork.work.wait()`) is deferred until `_pp_commit_comm_work` is called, allowing the CPU to perform other work (such as scheduling the next batch or processing metadata) while data is in flight.\n* **Multi-Stream Execution:** In addition to the main `default_stream`, which serves as the synchronization stream, SGLang utilizes dedicated `forward_stream` and `copy_stream` to execute forward pass GPU computation and Device-to-Host (D2H) memory transfers separately for better overlapping. While `_pp_launch_batch` is executing the current micro-batch on the GPU for the current stage, the CPU processes the previous micro-batch's results using `_pp_process_batch_result`.\n\n## Guidance about Dynamic Chunking\n\n### Why Dynamic Chunking\nChunked prefill with a fixed size can cause bubbles in the pipeline, especially when the PP size is large. The main reason is that, because of the Transformer structure, the model's running time is non-uniform even when every chunk has the same size. The larger the prefix sequence length, the longer the running time of the chunk.
These bubbles then propagate to the next stage and significantly degrade the scaling efficiency at larger PP sizes.\n\nTo address this issue, SGLang introduces a dynamic chunking mechanism that predicts the optimal size for the next chunk such that it satisfies this condition:\n\nRuntime(L + Next Chunk Size) - Runtime(L) = Runtime(Initial Chunk Size)\n\nwhere ***L*** denotes the Prefix Sequence Length. By profiling a series of requests with different ITLs, we model the cumulative runtime as a quadratic function of sequence length. Using this model, we solve for the optimal next chunk size for any given prefix length ***L***. Since the computation complexity of the Attention mechanism scales with ***L***, the next chunk size is progressively reduced as ***L*** grows, so that chunk execution times stay aligned across pipeline stages.\n\nBased on this method, the scheduler can predict and dynamically reduce the chunk size during runtime to minimize the bubbles caused by stage misalignment. Note that the scheduler does not use the raw predicted value: to facilitate efficient KVCache memory management and match hardware execution efficiency, the value is aligned downward to the nearest multiple of max(`--page-size`, 64).\n\n\n### Chunked Prefill Size and Smoothing Factor\n\nWhen `--enable-dynamic-chunking` is enabled, each chunk size of a sequence is determined dynamically by the quadratic model, which predicts the next chunk size from the estimated runtime of the initial chunk length. In this case, `--chunked-prefill-size` sets the initial chunk size.
When switching to the dynamic chunking mode, the initial chunk size (`--chunked-prefill-size`) should be set larger than the original fixed chunked prefill size, so that a sequence is not split into too many chunks.\n\n**`SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR`** is an environment variable that controls the smoothing factor for the dynamic chunking algorithm, defaulting to 0.75. It determines how much the chunk size can change during the prefill phase. A larger value means a more aggressive chunk size change, which may lead to better performance but also to greater chunk size changes (the chunk size at the end may become very small, which could lead to performance degradation) and more total chunks. When it is set to 1, the chunk size is adjusted strictly according to the quadratic model described above. A smaller value means a more conservative chunk size change, which may lead to smaller chunk size changes and fewer total chunks. When it is set to 0, the chunk size is not adjusted dynamically, which is identical to traditional chunked prefill with a fixed chunk size.\n\nDue to the variation in hardware, models, and target workloads, a static configuration is seldom optimal across all scenarios. Consequently, achieving peak performance requires some hyperparameter tuning when switching to the dynamic chunking mode.\n\n**Tuning Guidance for Dynamic Chunked Prefill**\n\n* **Step 1 \\- Iterate to find the optimal fixed chunked prefill size for the targeted PP size**: Different PP sizes for targeted ITL may have different optimal chunked prefill sizes. Therefore, users should iterate to obtain the baseline according to the available resources for scaling.\n* **Step 2 \\- Initial Chunk Size Selection for Dynamic Chunking**: Set the initial size to 2× or 3× the optimal fixed chunked prefill size. This reduces the total number of chunks and prevents \"tail chunks\" from underutilizing hardware.
To maintain efficiency for extremely large Input Token Lengths (ITL), the dynamic predictor automatically ensures subsequent chunks are at least 1/4 of this initial size. In addition, it is recommended to use a larger initial chunk size (e.g., 4× the optimal fixed chunked prefill size) for such cases as well.\n* **Step 3 \\- Smooth Factor Adjustment**: This factor controls how strictly the chunk size adjusts the prediction given by the quadratic performance fitting model.\n  * 1.0: Follows the model strictly.\n  * **0.6 – 0.85 (Recommended)**: Typical range for the best balance between dynamic scaling and hardware stability. Through experiments, we find that a range between 0.6 and 0.85 typically yields the best performance for dynamic chunking.\n  * 0: Disables dynamic adjustment, reverting to traditional fixed-size chunking.\n* **Another small optimization tip:** Put the larger partition in the higher PP rank when the layers are not evenly divisible across ranks. It can increase the GPU utilization when a larger PP rank is waiting for the previous stage’s result, hence reducing the bubbles on higher PP ranks. If we take DeepSeek-V3.1 as an example, `SGLANG_PP_LAYER_PARTITION=15,15,15,16` usually performs better than `16,15,15,15`.\n\n## Best Practice for Long Context\n\n### Tuning the Chunked Prefill Size\nOptimizing the chunked prefill size is crucial for balancing pipeline efficiency and resource utilization. The ideal size depends on factors including model architecture, hardware configuration, and typical input lengths. We recommend starting with a small chunk size, such as 4K, and gradually increasing it until you find the optimal size for your specific use case (Different targeted ITL and PP Sizes may have different optimal chunked prefill sizes. Therefore, users should iterate to obtain the baseline according to the available resources for scaling). 
Alternatively, you can analyze the hardware capacity and determine the optimal chunk size based on the roofline model.\n\n### Enable Dynamic Chunking and Adjust Smoothing Factor for Ultra-long ITL\nSGLang also offers a dynamic chunking solution that could further improve performance. This feature is currently an experimental feature that requires a certain amount of tuning experimentation and may not be suitable for all workloads. In addition, fine-tuning the smoothing factor can help optimize performance for specific workloads and model characteristics.\n\n### Case Study on NVIDIA H20\n\nWhen evaluating pipeline parallelism with fixed chunked prefill sizes from 2K to 16K, experiment results show that a 4K chunk size delivered optimal prefill TTFT performance for the DeepSeek-V3.1, and a 6K chunk size delivered optimal prefill TTFT performance for the Qwen3-235B-A22B-FP8.\n\nWhen enabling dynamic chunking, we first scale the optimal fixed chunked prefill size by a factor of 3 as the initial chunk size. Through experimentation, we found that a multiplier of 2-3 provides an appropriate balance—avoiding excessive initial pipeline bubbles while ensuring that subsequent chunks don't become too small as context length increases. 
With the default dynamic chunking smoothing factor of 0.75, we performed parameter tuning and determined that a value of 0.65 works optimally with the 12K initial chunk size for the DeepSeek-V3.1, while a value of 0.8 works optimally with the 18K initial chunk size for the Qwen3-235B-A22B-FP8.\n\n#### DeepSeek-V3.1 with 128K Input Token Length\n```bash\n# prefill node 0 (fixed chunked prefill size)\npython3 -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3.1 --trust-remote-code \\\n  --nnodes 4 --node-rank 0 --tp 8 --pp-size 4 \\\n  --port 30000 --dist-init-addr <MASTER_NODE_IP> \\\n  --disable-radix-cache --mem-fraction-static 0.8  \\\n  --attention-backend fa3 --host 0.0.0.0 --watchdog-timeout 3600 \\\n  --max-running-requests 128 --chunked-prefill-size 4096\n```\n\n```bash\n# prefill node 0 (with dynamic chunking)\nexport SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR=0.65\npython3 -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3.1 --trust-remote-code \\\n  --nnodes 4 --node-rank 0 --tp 8 --pp-size 4 \\\n  --port 30000 --dist-init-addr <MASTER_NODE_IP> \\\n  --disable-radix-cache --mem-fraction-static 0.8  \\\n  --attention-backend fa3 --host 0.0.0.0 --watchdog-timeout 3600 \\\n  --max-running-requests 128 --chunked-prefill-size 12288 --enable-dynamic-chunking\n```\n\n#### Qwen3-235B-A22B-FP8 with 128K Input Token Length\n```bash\n# prefill node 0 (fixed chunked prefill size)\npython3 -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-235B-A22B-FP8 --trust-remote-code \\\n  --nnodes 4 --node-rank 0 --tp 4 --pp-size 8 \\\n  --port 30000 --dist-init-addr <MASTER_NODE_IP> \\\n  --disable-radix-cache --mem-fraction-static 0.8  \\\n  --attention-backend fa3 --host 0.0.0.0 --watchdog-timeout 3600 \\\n  --max-running-requests 128 --chunked-prefill-size 6144\n```\n\n```bash\n# prefill node 0 (with dynamic chunking)\nexport SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR=0.8\npython3 -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-235B-A22B-FP8 
--trust-remote-code \\\n  --nnodes 4 --node-rank 0 --tp 4 --pp-size 8 \\\n  --port 30000 --dist-init-addr <MASTER_NODE_IP> \\\n  --disable-radix-cache --mem-fraction-static 0.8  \\\n  --attention-backend fa3 --host 0.0.0.0 --watchdog-timeout 3600 \\\n  --max-running-requests 128 --chunked-prefill-size 18432 --enable-dynamic-chunking\n```\n\nNote: `--disable-radix-cache` is enabled only for reproducible benchmarking purposes. It is not recommended to use it in production.\n\n## Best Practice for Pipeline Parallelism with PD Disaggregation\nTo be added. Stay tuned for the latest updates on Pipeline Parallelism with PD Disaggregation.\n"
  },
  {
    "path": "docs/advanced_features/quantization.md",
    "content": "# Quantization\n\nSGLang supports various quantization methods, including offline quantization and online dynamic quantization.\n\nOffline quantization loads pre-quantized model weights directly during inference. This is required for quantization methods\nsuch as GPTQ and AWQ, which collect and pre-compute various statistics from the original weights using the calibration dataset.\n\nOnline quantization dynamically computes scaling parameters—such as the maximum/minimum values of model weights—during runtime.\nLike NVIDIA FP8 training's [delayed scaling](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8) mechanism, online quantization calculates the appropriate scaling factors\non-the-fly to convert high-precision weights into a lower-precision format.\n\n**Note: For better performance, usability and convenience, offline quantization is recommended over online quantization.**\n\nIf you use a pre-quantized model, do not add `--quantization` to enable online quantization at the same time.\nFor popular pre-quantized models, please visit [Unsloth](https://huggingface.co/unsloth), [NVIDIA ModelOpt](https://huggingface.co/collections/nvidia/inference-optimized-checkpoints-with-model-optimizer)\nor [NeuralMagic](https://huggingface.co/collections/neuralmagic) collections on HF for some\npopular quality validated quantized models. 
Quantized models must be validated via benchmarks post-quantization\nto guard against abnormal quantization loss regressions.\n\n## Platform Compatibility\n\nThe following table summarizes quantization method support across NVIDIA and AMD GPUs.\n\n| Method | NVIDIA GPUs | AMD GPUs (MI300X/MI325X/MI350X) | Notes |\n|--------|:-----------:|:-------------------------------:|-------|\n| `fp8` | Yes | Yes | Aiter or Triton backend on AMD |\n| `mxfp4` | Yes | Yes | Requires CDNA3/CDNA4 with MXFP support; uses Aiter |\n| `blockwise_int8` | Yes | Yes | Triton-based, works on both platforms |\n| `w8a8_int8` | Yes | Yes | |\n| `w8a8_fp8` | Yes | Yes | Aiter or Triton FP8 on AMD |\n| `awq` | Yes | Yes | Uses Triton dequantize on AMD (vs. optimized CUDA kernels on NVIDIA) |\n| `gptq` | Yes | Yes | Uses Triton or vLLM kernels on AMD |\n| `compressed-tensors` | Yes | Yes | Aiter paths for FP8/MoE on AMD |\n| `quark` | Yes | Yes | AMD Quark quantization; Aiter GEMM paths on AMD |\n| `auto-round` | Yes | Yes | Platform-agnostic (Intel auto-round) |\n| `quark_int4fp8_moe` | No | Yes | AMD-only; online INT4-to-FP8 MoE quantization (CDNA3/CDNA4) |\n| `awq_marlin` | Yes | No | Marlin kernels are CUDA-only |\n| `gptq_marlin` | Yes | No | Marlin kernels are CUDA-only |\n| `gguf` | Yes | No | CUDA-only kernels in sgl-kernel |\n| `modelopt` / `modelopt_fp8` | Yes (Hopper/SM90+) | No | [NVIDIA ModelOpt](https://github.com/NVIDIA/Model-Optimizer); requires NVIDIA hardware |\n| `modelopt_fp4` | Yes (Blackwell/SM100+) | No | [NVIDIA ModelOpt](https://github.com/NVIDIA/Model-Optimizer); native FP4 on Blackwell (B200, GB200) |\n| `petit_nvfp4` | No | Yes (MI250/MI300X/MI325X) | Enables NVFP4 on ROCm via [Petit](https://github.com/causalflow-ai/petit-kernel); use `modelopt_fp4` on NVIDIA Blackwell. Auto-selected when loading NVFP4 models on AMD. 
See [LMSYS blog](https://lmsys.org/blog/2025-09-21-petit-amdgpu/) and [AMD ROCm blog](https://rocm.blogs.amd.com/artificial-intelligence/fp4-mixed-precision/README.html). |\n| `bitsandbytes` | Yes | Experimental | Depends on bitsandbytes ROCm support |\n| `torchao` (`int4wo`, etc.) | Yes | Partial | `int4wo` not supported on AMD; other methods may work |\n\nOn AMD, several of these methods use [Aiter](https://github.com/ROCm/aiter) for acceleration -- set `SGLANG_USE_AITER=1` where noted. See [AMD GPU setup](../platforms/amd_gpu.md) for installation and configuration details.\n\n## GEMM Backends for FP4/FP8 Quantization\n\n:::{note}\nBackend selection is supported only for **blockwise FP8** and **NVFP4** GEMM. When running FP8 or FP4 quantized models, you can select the GEMM backend via `--fp8-gemm-backend` and `--fp4-gemm-backend`.\n:::\n\n### `--fp8-gemm-backend` (Blockwise FP8 GEMM)\n\n| Backend | Hardware | Description |\n|---------|----------|-------------|\n| `auto` | All | Auto-selects based on hardware |\n| `deep_gemm` | SM90, SM100 | JIT-compiled; enabled when DeepGEMM is installed |\n| `flashinfer_trtllm` | SM100 | FlashInfer TensorRT-LLM backend; optimal for low-latency |\n| `flashinfer_cutlass` | SM100/120 | FlashInfer CUTLASS groupwise FP8 GEMM |\n| `flashinfer_deepgemm` | SM90 | Uses swapAB optimization for small M dimensions in decoding |\n| `cutlass` | SM90, SM100/120 | sgl-kernel CUTLASS |\n| `triton` | All | Fallback; widely compatible |\n| `aiter` | ROCm | AMD AITER backend |\n\n**`auto` selection order:** 1) DeepGEMM (SM90/SM100, installed); 2) FlashInfer TRTLLM (SM100, FlashInfer available); 3) CUTLASS (SM90/SM100/120); 4) AITER (AMD); 5) Triton. 
**Exception:** SM120 always resolves to Triton.\n\n### `--fp4-gemm-backend` (NVFP4 GEMM)\n\n| Backend | Hardware | Description |\n|---------|----------|-------------|\n| `auto` | SM100/120 | Auto-selects: `flashinfer_cudnn` on SM120; `flashinfer_cutlass` on SM100 |\n| `flashinfer_cutlass` | SM100/120 | FlashInfer CUTLASS backend |\n| `flashinfer_cudnn` | SM100/120 (CUDA 13+, cuDNN 9.15+) | FlashInfer cuDNN backend; used on SM120 for performance |\n| `flashinfer_trtllm` | SM100 | FlashInfer TensorRT-LLM backend |\n\nWhen FlashInfer is unavailable for NVFP4, sgl-kernel CUTLASS is used as an automatic fallback.\n\n## Offline Quantization\n\nTo load already quantized models, simply load the model weights and config. **Again, if the model has been quantized offline,\nthere's no need to add `--quantization` argument when starting the engine. The quantization method will be parsed from the\ndownloaded Hugging Face config. For example, DeepSeek V3/R1 models are already in FP8, so do not add redundant parameters.**\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 \\\n    --port 30000 --host 0.0.0.0\n```\n\nTake note, if your model is **per-channel quantized (INT8 or FP8) with per-token dynamic quantization activation**, you can opt to include `--quantization w8a8_int8` or `--quantization w8a8_fp8` to invoke the corresponding CUTLASS int8_kernel or fp8_kernel in sgl-kernel. This action will ignore the Hugging Face config's quantization settings. 
For instance, with `neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic`, if you launch with `--quantization w8a8_fp8`, the system will use SGLang's `W8A8Fp8Config` to invoke the sgl-kernel, rather than the `CompressedTensorsConfig` for vLLM kernels.\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic \\\n    --quantization w8a8_fp8 \\\n    --port 30000 --host 0.0.0.0\n```\n\n### Examples of Offline Model Quantization\n\n#### Using [Unsloth](https://docs.unsloth.ai/basics/inference-and-deployment/sglang-guide)\n\nWe strongly recommend using Unsloth to quantize and load models. Please refer to [SGLang Deployment & Inference Guide with Unsloth](https://docs.unsloth.ai/basics/inference-and-deployment/sglang-guide).\n\n#### Using [auto-round](https://github.com/intel/auto-round)\n\n```bash\n# Install\npip install auto-round\n```\n\n- LLM quantization\n\n```py\n# for LLM\nfrom auto_round import AutoRound\nmodel_id = \"meta-llama/Llama-3.2-1B-Instruct\"\nquant_path = \"Llama-3.2-1B-Instruct-autoround-4bit\"\n# Scheme examples: \"W2A16\", \"W3A16\", \"W4A16\", \"W8A16\", \"NVFP4\", \"MXFP4\" (no real kernels), \"GGUF:Q4_K_M\", etc.\nscheme = \"W4A16\"\nformat = \"auto_round\"\nautoround = AutoRound(model_id, scheme=scheme)\nautoround.quantize_and_save(quant_path, format=format) # quantize and save\n\n```\n\n- VLM quantization\n\n```py\n# for VLMs\nfrom auto_round import AutoRoundMLLM\nmodel_name = \"Qwen/Qwen2-VL-2B-Instruct\"\nquant_path = \"Qwen2-VL-2B-Instruct-autoround-4bit\"\nscheme = \"W4A16\"\nformat = \"auto_round\"\nautoround = AutoRoundMLLM(model_name, scheme=scheme)  # pass scheme by keyword, as in the LLM example\nautoround.quantize_and_save(quant_path, format=format) # quantize and save\n\n```\n\n- Command Line Usage (Gaudi/CPU/Intel GPU/CUDA)\n\n```bash\nauto-round \\\n    --model meta-llama/Llama-3.2-1B-Instruct \\\n    --bits 4 \\\n    --group_size 128 \\\n    --format \"auto_round\" \\\n    --output_dir ./tmp_autoround\n```\n\n- Known 
issues\n\nSeveral limitations currently affect offline quantized model loading in SGLang. These issues may be resolved in future SGLang releases. If you experience any problems, consider using Hugging Face Transformers as an alternative.\n\n1. Mixed-bit Quantization Limitations\n\n    Mixed-bit quantization is not fully supported. Due to vLLM's layer fusion (e.g., QKV fusion), applying different bit-widths to components within the same fused layer can lead to compatibility issues.\n\n\n2. Limited Support for Quantized MoE Models\n\n    Quantized MoE models may encounter inference issues due to kernel limitations (e.g., lack of support for mlp.gate layer quantization). To avoid such errors, skip quantizing these layers.\n\n\n3. Limited Support for Quantized VLMs\n    <details>\n        <summary>VLM failure cases</summary>\n\n    Qwen2.5-VL-7B\n\n    auto_round:auto_gptq format: Accuracy is close to zero.\n\n    GPTQ format: Fails with:\n    ```\n    The output size is not aligned with the quantized weight shape\n    ```\n    auto_round:auto_awq and AWQ format: These work as expected.\n    </details>\n\n#### Using [GPTQModel](https://github.com/ModelCloud/GPTQModel)\n\n```bash\n# install\npip install gptqmodel --no-build-isolation -v\n```\n\n```py\nfrom datasets import load_dataset\nfrom gptqmodel import GPTQModel, QuantizeConfig\n\nmodel_id = \"meta-llama/Llama-3.2-1B-Instruct\"\nquant_path = \"Llama-3.2-1B-Instruct-gptqmodel-4bit\"\n\n# Use 1024 text samples from C4 as calibration data\ncalibration_dataset = load_dataset(\n    \"allenai/c4\",\n    data_files=\"en/c4-train.00001-of-01024.json.gz\",\n    split=\"train\",\n).select(range(1024))[\"text\"]\n\nquant_config = QuantizeConfig(bits=4, group_size=128) # quantization config\nmodel = GPTQModel.load(model_id, quant_config) # load model\n\nmodel.quantize(calibration_dataset, batch_size=2) # quantize\nmodel.save(quant_path) # save model\n```\n\n#### Using [LLM Compressor](https://github.com/vllm-project/llm-compressor/)\n\n```bash\n# install\npip 
install llmcompressor\n```\n\nThe following example quantizes `meta-llama/Meta-Llama-3-8B-Instruct` to FP8 offline.\n\n```python\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom llmcompressor import oneshot\nfrom llmcompressor.modifiers.quantization import QuantizationModifier\n\n# Step 1: Load the original model.\nMODEL_ID = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n\nmodel = AutoModelForCausalLM.from_pretrained(\n    MODEL_ID, device_map=\"auto\", torch_dtype=\"auto\"\n)\ntokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n\n# Step 2: Perform offline quantization.\n# Step 2.1: Configure simple PTQ: FP8 weights with dynamic FP8 activations; keep lm_head unquantized.\nrecipe = QuantizationModifier(\n    targets=\"Linear\", scheme=\"FP8_DYNAMIC\", ignore=[\"lm_head\"]\n)\n\n# Step 2.2: Apply the quantization algorithm.\noneshot(model=model, recipe=recipe)\n\n# Step 3: Save the model.\nSAVE_DIR = MODEL_ID.split(\"/\")[1] + \"-FP8-Dynamic\"\nmodel.save_pretrained(SAVE_DIR)\ntokenizer.save_pretrained(SAVE_DIR)\n```\n\nYou can then serve the quantized model directly with SGLang:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path $PWD/Meta-Llama-3-8B-Instruct-FP8-Dynamic \\\n    --port 30000 --host 0.0.0.0\n```\n\n#### Using [NVIDIA ModelOpt](https://github.com/NVIDIA/Model-Optimizer)\n\nNVIDIA Model Optimizer (ModelOpt) provides advanced quantization techniques optimized for NVIDIA hardware.\n\n**Offline vs. Online Quantization:**\n\nSGLang supports two modes for ModelOpt.\n\n* **Offline Quantization (pre-quantized):**\n    * **Usage:** Download a pre-quantized model from Hugging Face or run `hf_ptq.py` once to create a new quantized checkpoint. 
Then load this quantized checkpoint.\n    * **Pros:** Fast server startup, quantization can be validated before deployment, efficient resource usage.\n    * **Cons:** Requires an extra preparation step.\n\n* **Online Quantization (quant and serve):**\n    * **Usage:** Load a standard BF16/FP16 model and add a flag. The engine applies quantization *on startup*.\n    * **Pros:** Convenient (no new checkpoint needed).\n    * **Cons:** **High startup time**, increases VRAM usage during initialization (risk of OOM).\n\nThe following sections guide you through using the Offline path: loading pre-quantized models or creating your own checkpoints.\n\n##### Using Pre-Quantized Checkpoints\n\nIf a model is already quantized (e.g., from Hugging Face), you can load it directly.\n\n* **FP8 Models:**\n    Use `--quantization modelopt_fp8`.\n    ```bash\n    python3 -m sglang.launch_server \\\n        --model-path nvidia/Llama-3.1-8B-Instruct-FP8 \\\n        --quantization modelopt_fp8 \\\n        --port 30000\n    ```\n\n* **FP4 Models:**\n    Use `--quantization modelopt_fp4`.\n    ```bash\n    python3 -m sglang.launch_server \\\n        --model-path nvidia/Llama-3.3-70B-Instruct-NVFP4 \\\n        --quantization modelopt_fp4 \\\n        --port 30000\n    ```\n\n##### Creating Your Own Quantized Checkpoints\n\nIf a pre-quantized checkpoint is not available for your model, you can create one using NVIDIA Model Optimizer's `hf_ptq.py` script.\n\n**Why quantize?**\n- Reduce VRAM usage\n- Higher throughput and lower latency\n- More flexible deployment (on smaller GPUs)\n\n**What can be quantized?**\n- The entire model\n- MLP layers only\n- KV cache\n\n**Key options in `hf_ptq.py`:**\n\n`--qformat`: Quantization formats `fp8`, `nvfp4`, `nvfp4_mlp_only`\n\n`--kv_cache_qformat`: KV cache quantization format (default: `fp8`)\n\n**Note:** The default `kv_cache_qformat` may not be optimal for all use cases. 
Consider setting this explicitly.\n\n**Hardware requirements:** Hopper or newer GPUs are recommended. Insufficient GPU memory may cause weight offloading, resulting in extremely long quantization times.\n\nFor detailed usage and supported model architectures, see [NVIDIA Model Optimizer LLM PTQ](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq).\n\nSGLang includes a streamlined workflow for quantizing models with ModelOpt and automatically exporting them for deployment.\n\n##### Installation\n\nFirst, install ModelOpt:\n\n```bash\npip install nvidia-modelopt\n```\n\n##### Quantization and Export Workflow\n\nSGLang provides an example script that demonstrates the complete ModelOpt quantization and export workflow. Run from the SGLang repository root (see [modelopt_quantize_and_export.py](https://github.com/sgl-project/sglang/blob/main/examples/usage/modelopt_quantize_and_export.py)):\n\n```bash\n# Quantize and export a model using ModelOpt FP8 quantization\npython examples/usage/modelopt_quantize_and_export.py quantize \\\n    --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \\\n    --export-dir ./quantized_tinyllama_fp8 \\\n    --quantization-method modelopt_fp8\n\n# For FP4 quantization (requires Blackwell GPU)\npython examples/usage/modelopt_quantize_and_export.py quantize \\\n    --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \\\n    --export-dir ./quantized_tinyllama_fp4 \\\n    --quantization-method modelopt_fp4\n```\n\n##### Available Quantization Methods\n\n- `modelopt_fp8`: FP8 quantization with optimal performance on NVIDIA Hopper and Blackwell GPUs\n- `modelopt_fp4`: FP4 quantization with optimal performance on NVIDIA Blackwell GPUs\n\n##### Python API Usage\n\nYou can also use ModelOpt quantization programmatically:\n\n```python\nfrom sglang.srt.configs.device_config import DeviceConfig\nfrom sglang.srt.configs.load_config import LoadConfig\nfrom sglang.srt.configs.model_config import ModelConfig\nfrom 
sglang.srt.model_loader.loader import get_model_loader\n\n# Configure model with ModelOpt quantization and export\nmodel_config = ModelConfig(\n    model_path=\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n    quantization=\"modelopt_fp8\",  # or \"modelopt_fp4\"\n    trust_remote_code=True,\n)\n\nload_config = LoadConfig(\n    modelopt_export_path=\"./exported_model\",\n    modelopt_checkpoint_save_path=\"./checkpoint.pth\",  # optional, fake quantized checkpoint\n)\ndevice_config = DeviceConfig(device=\"cuda\")\n\n# Load and quantize the model (export happens automatically)\nmodel_loader = get_model_loader(load_config, model_config)\nquantized_model = model_loader.load_model(\n    model_config=model_config,\n    device_config=device_config,\n)\n```\n\n##### Deploying Quantized Models\n\nAfter quantization and export, you can deploy the model with SGLang:\n\n```bash\n# Deploy the exported quantized model\npython -m sglang.launch_server \\\n    --model-path ./quantized_tinyllama_fp8 \\\n    --quantization modelopt \\\n    --port 30000 --host 0.0.0.0\n```\n\nOr using the Python API (use the same path as `modelopt_export_path` from the quantize step):\n\n```python\nimport sglang as sgl\n\ndef main():\n    # Deploy exported ModelOpt quantized model\n    # Path must match modelopt_export_path from quantize step (e.g., ./exported_model)\n    llm = sgl.Engine(\n        model_path=\"./exported_model\",\n        quantization=\"modelopt\",\n    )\n\n    # Run inference\n    prompts = [\n        \"Hello, how are you?\",\n        \"What is the capital of France?\",\n    ]\n    sampling_params = {\n        \"temperature\": 0.8,\n        \"top_p\": 0.95,\n        \"max_new_tokens\": 100,\n    }\n\n    outputs = llm.generate(prompts, sampling_params)\n\n    for i, output in enumerate(outputs):\n        print(f\"Prompt: {prompts[i]}\")\n        print(f\"Output: {output['text']}\")\n\nif __name__ == \"__main__\":\n    main()\n\n```\n\n##### Advanced Features\n\n**Checkpoint 
Management**: Save and restore fake quantized checkpoints for reuse:\n\n```bash\n# Save the fake quantized checkpoint during quantization\npython examples/usage/modelopt_quantize_and_export.py quantize \\\n    --model-path meta-llama/Llama-3.2-1B-Instruct \\\n    --export-dir ./quantized_model \\\n    --quantization-method modelopt_fp8 \\\n    --checkpoint-save-path ./my_checkpoint.pth\n\n# Reuse the checkpoint in future quantization runs to skip calibration\n```\n\n**Export-only Workflow**: If you have a pre-existing fake quantized ModelOpt checkpoint, you can export it directly. See [LoadConfig](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/configs/load_config.py) for the full API:\n\n```python\nfrom sglang.srt.configs.device_config import DeviceConfig\nfrom sglang.srt.configs.load_config import LoadConfig\nfrom sglang.srt.configs.model_config import ModelConfig\nfrom sglang.srt.model_loader.loader import get_model_loader\n\nmodel_config = ModelConfig(\n    model_path=\"meta-llama/Llama-3.2-1B-Instruct\",\n    quantization=\"modelopt_fp8\",\n    trust_remote_code=True,\n)\n\nload_config = LoadConfig(\n    modelopt_checkpoint_restore_path=\"./my_checkpoint.pth\",\n    modelopt_export_path=\"./exported_model\",\n)\n\n# Load and export the model (DeviceConfig defaults to device=\"cuda\")\nmodel_loader = get_model_loader(load_config, model_config)\nmodel_loader.load_model(model_config=model_config, device_config=DeviceConfig())\n```\n\n##### Benefits of ModelOpt\n\n- **Hardware Optimization**: Specifically optimized for NVIDIA GPU architectures\n- **Advanced Quantization**: Supports FP8 and FP4 (NVFP4) quantization\n- **Seamless Integration**: Automatic export to Hugging Face format for easy deployment\n- **Calibration-based**: Uses calibration datasets for optimal quantization quality\n- **Production Ready**: Enterprise-grade quantization with NVIDIA support\n\n## Online Quantization\n\nTo enable online quantization, you 
can simply specify `--quantization` on the command line. For example, you can launch the server with the following command to enable FP8 quantization for the model `meta-llama/Meta-Llama-3.1-8B-Instruct`:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n    --quantization fp8 \\\n    --port 30000 --host 0.0.0.0\n```\n\nOur team is working on supporting more online quantization methods, including but not limited to `[\"awq\", \"gptq\", \"marlin\", \"gptq_marlin\", \"awq_marlin\", \"bitsandbytes\", \"gguf\"]`.\n\n### torchao online quantization method\n\nSGLang also supports quantization methods based on [torchao](https://github.com/pytorch/ao). To use them, specify `--torchao-config` on the command line. For example, to enable `int4wo-128` for the model `meta-llama/Meta-Llama-3.1-8B-Instruct`, launch the server with the following command:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n    --torchao-config int4wo-128 \\\n    --port 30000 --host 0.0.0.0\n```\n\nSGLang supports the following quantization methods based on torchao: `[\"int8dq\", \"int8wo\", \"fp8wo\", \"fp8dq-per_tensor\", \"fp8dq-per_row\", \"int4wo-32\", \"int4wo-64\", \"int4wo-128\", \"int4wo-256\"]`.\n\nNote: According to [this issue](https://github.com/sgl-project/sglang/issues/2219#issuecomment-2561890230), the `int8dq` method currently has bugs when used together with CUDA graph capture, so we suggest disabling CUDA graph capture when using `int8dq`. 
To do so, add `--disable-cuda-graph` when launching the server:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n    --torchao-config int8dq \\\n    --disable-cuda-graph \\\n    --port 30000 --host 0.0.0.0\n```\n\n### `quark_int4fp8_moe` online quantization method\n\nSGLang running on AMD GPUs (CDNA3 or CDNA4 architecture) supports the quantization method `--quantization quark_int4fp8_moe`. It replaces the high-precision (bfloat16, float16, or float32) weights of [MoE layers](https://github.com/sgl-project/sglang/blob/v0.4.8/python/sglang/srt/layers/moe/fused_moe_triton/layer.py#L271) with weights dynamically quantized to int4. During inference, these int4 weights are upcast to float8 so that computation runs in float8 precision, with activations dynamically quantized to float8 on the fly.\n\nOther layers (e.g., projections in the attention layers) have their weights quantized online directly to float8.\n\n## Reference\n\n- [GPTQModel](https://github.com/ModelCloud/GPTQModel)\n- [LLM Compressor](https://github.com/vllm-project/llm-compressor/)\n- [NVIDIA Model Optimizer (ModelOpt)](https://github.com/NVIDIA/Model-Optimizer)\n- [NVIDIA Model Optimizer LLM PTQ](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq)\n- [Petit: NVFP4 on ROCm](https://github.com/causalflow-ai/petit-kernel) ([LMSYS blog](https://lmsys.org/blog/2025-09-21-petit-amdgpu/), [AMD ROCm blog](https://rocm.blogs.amd.com/artificial-intelligence/fp4-mixed-precision/README.html))\n- [Torchao: PyTorch Architecture Optimization](https://github.com/pytorch/ao)\n- [vLLM Quantization](https://docs.vllm.ai/en/latest/quantization/)\n- [auto-round](https://github.com/intel/auto-round)\n"
  },
  {
    "path": "docs/advanced_features/quantized_kv_cache.md",
    "content": "# Quantized KV Cache\n\nQuantized KV cache reduces the memory footprint of key-value cache storage by using lower-precision data types (FP8 or FP4) instead of the default model precision in BF16. During autoregressive generation, LLMs cache previously computed key-value pairs to avoid redundant calculations. The KV cache typically consumes a significant portion of GPU memory, especially for long sequences.\n\nQuantized KV cache is a memory optimization technique that primarily benefits throughput by allowing more tokens to be cached, but may introduce minimal accuracy degradation depending on the quantization format used.\n\n```{warning}\n**Performance Warning**: When quantized KV cache must be dequantized before use in attention operations, performance can be extremely slow if dequantization is not fused with the attention kernel. Always verify that your chosen attention backend supports quantized KV cache. Backends without fused support may experience significant throughput degradation, potentially negating the memory benefits.\n\n**Backend Support**: Not all attention backends support quantized KV cache. 
Refer to [Attention Backend](attention_backend.md) for which backends support it.\n```\n\n## Supported Formats\n\nSGLang supports the following quantized KV cache formats:\n\n### FP8 Format\n\n[OCP (Open Compute Project)](https://www.opencompute.org) specifies two common 8-bit floating-point formats:\n\n- **E5M2** (5 exponent bits, 2 mantissa bits): Larger dynamic range (±57344.0), lower precision\n- **E4M3** (4 exponent bits, 3 mantissa bits): Higher precision, smaller dynamic range (±448.0)\n\n### FP4 Format\n\n```{warning}\nFP4 quantization is currently experimental.\n```\n\n[OCP (Open Compute Project)](https://www.opencompute.org) specifies MXFP4 (Microscaling FP4), a 4-bit floating-point format:\n\n- **E2M1** (1 sign bit, 2 exponent bits, 1 mantissa bit): Uses block-based microscaling, where tensors are divided into blocks of consecutive elements and each block shares a single 8-bit power-of-two (exponent-only) scaling factor. While OCP specifies blocks of 32 elements, SGLang's current implementation uses blocks of 16 elements for KV cache quantization.\n\n## Usage\n\n### Enabling Quantized KV Cache\n\nTo enable quantized KV cache, use the `--kv-cache-dtype` argument when launching the server:\n\n```bash\n# Enable FP8 E5M2 KV cache\npython3 -m sglang.launch_server \\\n    --model-path deepseek-ai/DeepSeek-R1-0528 \\\n    --kv-cache-dtype fp8_e5m2\n\n# Enable FP8 E4M3 KV cache\npython3 -m sglang.launch_server \\\n    --model-path deepseek-ai/DeepSeek-R1-0528 \\\n    --kv-cache-dtype fp8_e4m3\n\n# Enable FP4 E2M1 KV cache\npython3 -m sglang.launch_server \\\n    --model-path nvidia/DeepSeek-R1-0528-NVFP4 \\\n    --kv-cache-dtype fp4_e2m1\n```\n\n### Scaling Factors\n\nFP8 quantization requires scaling factors to properly quantize and dequantize the KV cache.\n\n```{note}\nCurrently, only per-tensor (scalar) scaling factors are supported.\n```\n\nScaling factors can be:\n\n- **Loaded from checkpoints**: Pre-quantized models (e.g., ModelOpt) may include `k_scale` and 
`v_scale` parameters that are automatically loaded\n- **Provided via JSON**: Supply scaling factors via `--quantization-param-path`.\n\nThe JSON file should follow this format:\n\n```json\n{\n  \"kv_cache\": {\n    \"dtype\": \"float8_e4m3fn\",\n    \"scaling_factor\": {\n      \"0\": {\n        \"0\": 1.0,\n        \"1\": 1.0\n      }\n    }\n  }\n}\n```\n\nThe outer keys in `scaling_factor` are tensor-parallel ranks, and the inner keys are layer indices.\n\n```{warning}\nIf scaling factors are neither provided nor found in the checkpoint, they default to 1.0, which may cause accuracy issues.\n```\n\n```{tip}\n**FP4 (MXFP4)**: Unlike FP8, FP4 quantization handles scaling factors automatically on the fly during quantization and dequantization. No pre-quantized models or external scaling factor files are required; the block-based scaling factors are computed dynamically as needed.\n```\n\n## Performance Considerations\n\n### Memory Savings\n\nQuantized KV cache provides significant memory savings:\n\n- **BF16 → FP8**: Supports approximately 2× more tokens than BF16\n- **BF16 → FP4**: Supports approximately 3.56× more tokens than BF16 (accounting for scaling factor overhead)\n\n```{note}\nFP4 quantization requires additional memory for its block-based scaling factors: each block of 16 elements shares one 8-bit scale, so each element costs 4 + 8/16 = 4.5 bits. This reduces the effective memory savings compared to the raw bit-width reduction: FP4 supports approximately 8 / 4.5 ≈ 1.78× more tokens than FP8 and 16 / 4.5 ≈ 3.56× more tokens than BF16. FP8 uses per-tensor (scalar) scaling factors, so its overhead is negligible and it supports approximately 2× more tokens than BF16.\n```\n\nThis enables longer context lengths or more concurrent requests within the same memory budget.\n\n### Accuracy Impact\n\n#### FP8 Accuracy\n\nFP8 E4M3 quantization typically introduces minimal accuracy degradation. 
The impact depends on model architecture, sequence length, and quantization format (generally, E4M3 has better accuracy than E5M2).\n\n#### FP4 Accuracy\n\nFP4 (MXFP4) quantization provides significant memory savings with varying accuracy impact depending on model size and dataset complexity. Preliminary accuracy test results from [PR #10078](https://github.com/sgl-project/sglang/pull/10078) (MLA) and [PR #12612](https://github.com/sgl-project/sglang/pull/12612) (MHA) show:\n\n**Large Models (e.g., Qwen3-235B-A22B, DeepSeek-R1-0528)**\n\nOn large-scale models, FP4 maintains accuracy close to FP8/BF16, especially on simpler datasets:\n\n| Model | Dataset | KV16 (BF16) | KV8 (FP8 E4M3) | KV4 (FP4 E2M1) |\n|-------|---------|-------------|----------------|----------------|\n| Qwen3-235B-A22B | gsm8k | 0.9168 | 0.9181 | 0.9186 |\n| Qwen3-235B-A22B | aime25 | 0.7733 | 0.7333 | 0.6000 |\n| Qwen3-235B-A22B | gpqa_diamond | 0.7010 | 0.6899 | 0.6778 |\n| DeepSeek-R1-0528 | gsm8k | 0.9157 | 0.9154 | 0.9124 |\n| DeepSeek-R1-0528 | aime25 | 0.5067 | 0.4934 | 0.4000 |\n| DeepSeek-R1-0528 | gpqa_diamond | 0.7707 | 0.7697 | 0.7273 |\n\n**Smaller Models (e.g., GPT-OSS-120B)**\n\nOn smaller models, FP4 shows more pronounced accuracy drops, particularly on challenging datasets:\n\n| Model | Dataset | KV16 (BF16) | KV8 (FP8 E4M3) | KV4 (FP4 E2M1) |\n|-------|---------|-------------|----------------|----------------|\n| GPT-OSS-120B | gsm8k | 0.9161 | 0.9163 | 0.9152 |\n| GPT-OSS-120B | aime25 | 0.7533 | 0.7667 | 0.3533 |\n| GPT-OSS-120B | gpqa_diamond | 0.5081 | 0.5434 | 0.3202 |\n\n**Key Observations:**\n\n- **Simple datasets (e.g., gsm8k)**: FP4 maintains accuracy close to FP8/BF16 across model sizes\n- **Model size matters**: Large models (200B+ parameters) generally tolerate FP4 quantization better than smaller models\n- **Context length**: Accuracy degradation may be more pronounced in long-context scenarios, because quantization error can accumulate over long sequences.\n\n```{tip}\nEvaluate FP4 
accuracy on your specific model and workload. Large models on simpler tasks typically show minimal degradation, while smaller models or complex reasoning tasks may require FP8 or BF16 for acceptable accuracy.\n```\n\n## Best Practices\n\n- **Use pre-quantized models**: Prefer models quantized offline with scaling factors included in the checkpoint.\n- **Choose the right format**: Use `fp8_e4m3` for better accuracy (recommended), `fp8_e5m2` for larger dynamic range, or `fp4_e2m1` for maximum memory savings (experimental)\n- **Check backend compatibility**: Verify that your chosen attention backend supports quantized KV cache\n\n```{seealso}\n- [Quantization](quantization.md)\n- [Attention Backend](attention_backend.md)\n- [Server Arguments](server_arguments.md)\n```\n"
  },
  {
    "path": "docs/advanced_features/rfork.md",
    "content": "# R-Fork\n\nR-Fork (Tensor Remote Fork) is a novel weight loading methodology that leverages efficient inter-node GPU-to-GPU data transfer path to load tensors from a running SGLang instance to a new instance with zero-copy. It can significantly optimize the SGLang instance boot-up time by reducing model weights loading from several minutes to mere seconds.\n\nTo learn more details about R-Fork, please check **<a href=https://lmsys.org/blog/2025-12-10-rfork/> R-Fork blog </a>**\n\n## Usage\n\n| Argument     | Usage                                      |\n|--------------|--------------------------------------------|\n| load-format  | set to `remote_instance` to enable R-Fork. |\n| remote-instance-weight-loader-backend | `nccl`, `transfer_engine`, or `modelexpress`. Default is `nccl`. |\n| remote-instance-weight-loader-seed-instance-ip | IP address of the seed instance who will provide the model weight. Used by `nccl` and `transfer_engine` backends. |\n| remote-instance-weight-loader-seed-instance-service-port | the port that the seed instance's HTTP server is listening on. Used by `nccl` and `transfer_engine` backends. |\n| remote-instance-weight-loader-send-weights-group-ports | the list of available ports on the seed instance that will be used to build NCCL communication groups between seed and client instance. Only needed by `nccl` backend. |\n| remote-instance-weight-loader-start-seed-via-transfer-engine | set to start seed service that supports TransferEngine as backend. Needed for seed instances when using `transfer_engine` as backend. |\n| modelexpress-config | JSON config for `modelexpress` backend. Keys: `\"url\"` (required, gRPC host:port of ModelExpress server), `\"model_name\"` (optional, defaults to `--model-path`), `\"source\"` (optional bool, `true` for seed mode). 
|\n\n### NCCL as backend\n\nseed instance:\n```shell\npython -m sglang.launch_server [args]\n```\n\nclient instance:\n```shell\npython -m sglang.launch_server [args] \\\n  --load-format remote_instance \\\n  --remote-instance-weight-loader-seed-instance-ip [seed_instance_ip] \\\n  --remote-instance-weight-loader-seed-instance-service-port [seed_instance_service_port] \\\n  --remote-instance-weight-loader-send-weights-group-ports [send_weights_nccl_group_ports_list] \\\n  --remote-instance-weight-loader-backend nccl\n```\n\n### TransferEngine as backend\n\nseed instance:\n```shell\npython -m sglang.launch_server [args] \\\n  --remote-instance-weight-loader-start-seed-via-transfer-engine\n```\n\nclient instance:\n```shell\npython -m sglang.launch_server [args] \\\n  --load-format remote_instance \\\n  --remote-instance-weight-loader-seed-instance-ip [seed_instance_ip] \\\n  --remote-instance-weight-loader-seed-instance-service-port [seed_instance_service_port] \\\n  --remote-instance-weight-loader-backend transfer_engine\n```\n\n### ModelExpress as backend\n\n[ModelExpress](https://github.com/ai-dynamo/modelexpress) is a coordination service that manages P2P weight transfer metadata. It removes the need for direct seed IP/port configuration by providing a centralized registry that seeds publish to and clients discover from. Under the hood, it uses TransferEngine (Mooncake) for the actual RDMA data transfer.\n\nA running ModelExpress server is required. 
See the [ModelExpress documentation](https://github.com/ai-dynamo/modelexpress) for setup instructions.\n\nseed instance:\n```shell\npython -m sglang.launch_server [args] \\\n  --modelexpress-config '{\"url\": \"[modelexpress_grpc_host:port]\", \"model_name\": \"[model_name]\", \"source\": true}'\n```\n\nclient instance:\n```shell\npython -m sglang.launch_server [args] \\\n  --load-format remote_instance \\\n  --remote-instance-weight-loader-backend modelexpress \\\n  --modelexpress-config '{\"url\": \"[modelexpress_grpc_host:port]\", \"model_name\": \"[model_name]\"}'\n```\n\nThe seed publishes its TransferEngine session ID and tensor layout to ModelExpress. The client queries ModelExpress to discover the seed, then pulls weights directly via RDMA. This enables dynamic seed discovery without hardcoding IPs, and supports multiple models through a single ModelExpress instance.\n"
  },
  {
    "path": "docs/advanced_features/separate_reasoning.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Reasoning Parser\\n\",\n    \"\\n\",\n    \"SGLang supports parsing reasoning content out from \\\"normal\\\" content for reasoning models such as [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).\\n\",\n    \"\\n\",\n    \"## Supported Models & Parsers\\n\",\n    \"\\n\",\n    \"| Model  |  Reasoning tags      | Parser | Notes |\\n\",\n    \"|---------|-----------------------------|------------------|-------|\\n\",\n    \"| [DeepSeek‑R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `<think>` … `</think>` | `deepseek-r1` | Supports all variants (R1, R1-0528, R1-Distill) |\\n\",\n    \"| [DeepSeek‑V3 series](https://huggingface.co/deepseek-ai/DeepSeek-V3.1) | `<think>` … `</think>` | `deepseek-v3` | Including [DeepSeek‑V3.2](https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp). Supports `thinking` parameter |\\n\",\n    \"| [Standard Qwen3 models](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `<think>` … `</think>` | `qwen3` | Supports `enable_thinking` parameter |\\n\",\n    \"| [Qwen3-Thinking models](https://huggingface.co/Qwen/Qwen3-235B-A22B-Thinking-2507) | `<think>` … `</think>` | `qwen3` or `qwen3-thinking` | Always generates thinking content |\\n\",\n    \"| [Kimi K2 Thinking](https://huggingface.co/moonshotai/Kimi-K2-Thinking) | `◁think▷` … `◁/think▷` | `kimi_k2` | Uses special thinking delimiters. Also requires `--tool-call-parser kimi_k2` for tool use. 
|\\n\",\n    \"| [GPT OSS](https://huggingface.co/openai/gpt-oss-120b) | `<\\\\|channel\\\\|>analysis<\\\\|message\\\\|>` … `<\\\\|end\\\\|>` | `gpt-oss` | N/A |\\n\",\n    \"### Model-Specific Behaviors\\n\",\n    \"\\n\",\n    \"**DeepSeek-R1 Family:**\\n\",\n    \"- DeepSeek-R1: No `<think>` start tag, jumps directly to thinking content\\n\",\n    \"- DeepSeek-R1-0528: Generates both `<think>` start and `</think>` end tags\\n\",\n    \"- Both are handled by the same `deepseek-r1` parser\\n\",\n    \"\\n\",\n    \"**DeepSeek-V3 Family:**\\n\",\n    \"- DeepSeek-V3.1/V3.2: Hybrid model supporting both thinking and non-thinking modes, use the `deepseek-v3` parser and `thinking` parameter (NOTE: not `enable_thinking`)\\n\",\n    \"\\n\",\n    \"**Qwen3 Family:**\\n\",\n    \"- Standard Qwen3 (e.g., Qwen3-2507): Use `qwen3` parser, supports `enable_thinking` in chat templates\\n\",\n    \"- Qwen3-Thinking (e.g., Qwen3-235B-A22B-Thinking-2507): Use `qwen3` or `qwen3-thinking` parser, always thinks\\n\",\n    \"\\n\",\n    \"**Kimi K2:**\\n\",\n    \"- Kimi K2 Thinking: Uses special `◁think▷` and `◁/think▷` tags. 
For agentic tool use, also specify `--tool-call-parser kimi_k2`.\\n\",\n    \"\\n\",\n    \"**GPT OSS:**\\n\",\n    \"- GPT OSS: Uses special `<|channel|>analysis<|message|>` and `<|end|>` tags\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Usage\\n\",\n    \"\\n\",\n    \"### Launching the Server\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Specify the `--reasoning-parser` option.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import requests\\n\",\n    \"from openai import OpenAI\\n\",\n    \"from sglang.test.doc_patch import launch_server_cmd\\n\",\n    \"from sglang.utils import wait_for_server, print_highlight, terminate_process\\n\",\n    \"\\n\",\n    \"server_process, port = launch_server_cmd(\\n\",\n    \"    \\\"python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --host 0.0.0.0 --reasoning-parser deepseek-r1 --log-level warning\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Note that `--reasoning-parser` defines the parser used to interpret responses.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### OpenAI Compatible API\\n\",\n    \"\\n\",\n    \"Using the OpenAI compatible API, the contract follows the [DeepSeek API design](https://api-docs.deepseek.com/guides/reasoning_model) established with the release of DeepSeek-R1:\\n\",\n    \"\\n\",\n    \"- `reasoning_content`: The content of the CoT.\\n\",\n    \"- `content`: The content of the final answer.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"# Initialize OpenAI-like client\\n\",\n    \"client = OpenAI(api_key=\\\"None\\\", base_url=f\\\"http://0.0.0.0:{port}/v1\\\")\\n\",\n    \"model_name = client.models.list().data[0].id\\n\",\n    \"\\n\",\n    \"messages = [\\n\",\n    \"    {\\n\",\n    \"        \\\"role\\\": \\\"user\\\",\\n\",\n    \"        \\\"content\\\": \\\"What is 1+3?\\\",\\n\",\n    \"    }\\n\",\n    \"]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Non-Streaming Request\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response_non_stream = client.chat.completions.create(\\n\",\n    \"    model=model_name,\\n\",\n    \"    messages=messages,\\n\",\n    \"    temperature=0.6,\\n\",\n    \"    top_p=0.95,\\n\",\n    \"    stream=False,  # Non-streaming\\n\",\n    \"    extra_body={\\\"separate_reasoning\\\": True},\\n\",\n    \")\\n\",\n    \"print_highlight(\\\"==== Reasoning ====\\\")\\n\",\n    \"print_highlight(response_non_stream.choices[0].message.reasoning_content)\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"==== Text ====\\\")\\n\",\n    \"print_highlight(response_non_stream.choices[0].message.content)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Streaming Request\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response_stream = client.chat.completions.create(\\n\",\n    \"    model=model_name,\\n\",\n    \"    messages=messages,\\n\",\n    \"    temperature=0.6,\\n\",\n    \"    top_p=0.95,\\n\",\n    \"    stream=True,  # Streaming\\n\",\n    \"    extra_body={\\\"separate_reasoning\\\": True},\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"reasoning_content = \\\"\\\"\\n\",\n    \"content = \\\"\\\"\\n\",\n    \"for chunk in response_stream:\\n\",\n    \"    if 
chunk.choices[0].delta.content:\\n\",\n    \"        content += chunk.choices[0].delta.content\\n\",\n    \"    if chunk.choices[0].delta.reasoning_content:\\n\",\n    \"        reasoning_content += chunk.choices[0].delta.reasoning_content\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"==== Reasoning ====\\\")\\n\",\n    \"print_highlight(reasoning_content)\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"==== Text ====\\\")\\n\",\n    \"print_highlight(content)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Optionally, set `stream_reasoning` to `False` to buffer the reasoning content until the last reasoning chunk (or the first chunk after the reasoning content).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response_stream = client.chat.completions.create(\\n\",\n    \"    model=model_name,\\n\",\n    \"    messages=messages,\\n\",\n    \"    temperature=0.6,\\n\",\n    \"    top_p=0.95,\\n\",\n    \"    stream=True,  # Streaming\\n\",\n    \"    extra_body={\\\"separate_reasoning\\\": True, \\\"stream_reasoning\\\": False},\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"reasoning_content = \\\"\\\"\\n\",\n    \"content = \\\"\\\"\\n\",\n    \"for chunk in response_stream:\\n\",\n    \"    if chunk.choices[0].delta.content:\\n\",\n    \"        content += chunk.choices[0].delta.content\\n\",\n    \"    if chunk.choices[0].delta.reasoning_content:\\n\",\n    \"        reasoning_content += chunk.choices[0].delta.reasoning_content\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"==== Reasoning ====\\\")\\n\",\n    \"print_highlight(reasoning_content)\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"==== Text ====\\\")\\n\",\n    \"print_highlight(content)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Reasoning separation is enabled by default when `--reasoning-parser` is specified.
\\n\",\n    \"**To disable it, set the `separate_reasoning` option to `False` in the request.**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response_non_stream = client.chat.completions.create(\\n\",\n    \"    model=model_name,\\n\",\n    \"    messages=messages,\\n\",\n    \"    temperature=0.6,\\n\",\n    \"    top_p=0.95,\\n\",\n    \"    stream=False,  # Non-streaming\\n\",\n    \"    extra_body={\\\"separate_reasoning\\\": False},\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"==== Original Output ====\\\")\\n\",\n    \"print_highlight(response_non_stream.choices[0].message.content)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### SGLang Native API\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\\\")\\n\",\n    \"input = tokenizer.apply_chat_template(\\n\",\n    \"    messages, tokenize=False, add_generation_prompt=True, return_dict=False\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"gen_url = f\\\"http://localhost:{port}/generate\\\"\\n\",\n    \"gen_data = {\\n\",\n    \"    \\\"text\\\": input,\\n\",\n    \"    \\\"sampling_params\\\": {\\n\",\n    \"        \\\"skip_special_tokens\\\": False,\\n\",\n    \"        \\\"max_new_tokens\\\": 1024,\\n\",\n    \"        \\\"temperature\\\": 0.6,\\n\",\n    \"        \\\"top_p\\\": 0.95,\\n\",\n    \"    },\\n\",\n    \"}\\n\",\n    \"gen_response = requests.post(gen_url, json=gen_data).json()[\\\"text\\\"]\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"==== Original Output ====\\\")\\n\",\n    \"print_highlight(gen_response)\\n\",\n    \"\\n\",\n    \"parse_url = 
f\\\"http://localhost:{port}/separate_reasoning\\\"\\n\",\n    \"separate_reasoning_data = {\\n\",\n    \"    \\\"text\\\": gen_response,\\n\",\n    \"    \\\"reasoning_parser\\\": \\\"deepseek-r1\\\",\\n\",\n    \"}\\n\",\n    \"separate_reasoning_response_json = requests.post(\\n\",\n    \"    parse_url, json=separate_reasoning_data\\n\",\n    \").json()\\n\",\n    \"print_highlight(\\\"==== Reasoning ====\\\")\\n\",\n    \"print_highlight(separate_reasoning_response_json[\\\"reasoning_text\\\"])\\n\",\n    \"print_highlight(\\\"==== Text ====\\\")\\n\",\n    \"print_highlight(separate_reasoning_response_json[\\\"text\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Offline Engine API\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sglang as sgl\\n\",\n    \"from sglang.srt.parser.reasoning_parser import ReasoningParser\\n\",\n    \"from sglang.utils import print_highlight\\n\",\n    \"from transformers import AutoTokenizer\\n\",\n    \"\\n\",\n    \"llm = sgl.Engine(model_path=\\\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\\\")\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\\\")\\n\",\n    \"input = tokenizer.apply_chat_template(\\n\",\n    \"    messages, tokenize=False, add_generation_prompt=True, return_dict=False\\n\",\n    \")\\n\",\n    \"sampling_params = {\\n\",\n    \"    \\\"max_new_tokens\\\": 1024,\\n\",\n    \"    \\\"skip_special_tokens\\\": False,\\n\",\n    \"    \\\"temperature\\\": 0.6,\\n\",\n    \"    \\\"top_p\\\": 0.95,\\n\",\n    \"}\\n\",\n    \"result = llm.generate(prompt=input, sampling_params=sampling_params)\\n\",\n    \"\\n\",\n    \"generated_text = 
result[\\\"text\\\"]  # Assume there is only 
one prompt\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"==== Original Output ====\\\")\\n\",\n    \"print_highlight(generated_text)\\n\",\n    \"\\n\",\n    \"parser = ReasoningParser(\\\"deepseek-r1\\\")\\n\",\n    \"reasoning_text, text = parser.parse_non_stream(generated_text)\\n\",\n    \"print_highlight(\\\"==== Reasoning ====\\\")\\n\",\n    \"print_highlight(reasoning_text)\\n\",\n    \"print_highlight(\\\"==== Text ====\\\")\\n\",\n    \"print_highlight(text)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"llm.shutdown()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Supporting New Reasoning Model Schemas\\n\",\n    \"\\n\",\n    \"For future reasoning models, you can implement the reasoning parser as a subclass of `BaseReasoningFormatDetector` in `python/sglang/srt/parser/reasoning_parser.py` and specify the reasoning parser for the new reasoning model schema accordingly.\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "docs/advanced_features/server_arguments.md",
    "content": "# Server Arguments\n\nThis page provides a list of server arguments used in the command line to configure the behavior\nand performance of the language model server during deployment. These arguments enable users to\ncustomize key aspects of the server, including model selection, parallelism policies,\nmemory management, and optimization techniques.\nYou can list all arguments with `python3 -m sglang.launch_server --help`.\n\n## Common launch commands\n\n- To use a configuration file, create a YAML file with your server arguments and specify it with `--config`. CLI arguments will override config file values.\n\n  ```bash\n  # Create config.yaml\n  cat > config.yaml << EOF\n  model-path: meta-llama/Meta-Llama-3-8B-Instruct\n  host: 0.0.0.0\n  port: 30000\n  tensor-parallel-size: 2\n  enable-metrics: true\n  log-requests: true\n  EOF\n\n  # Launch server with config file\n  python -m sglang.launch_server --config config.yaml\n  ```\n\n- To enable multi-GPU tensor parallelism, add `--tp 2`. If it reports the error \"peer access is not supported between these two devices\", add `--enable-p2p-check` to the server launch command.\n\n  ```bash\n  python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 2\n  ```\n\n- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total. We recommend [SGLang Model Gateway (formerly Router)](../advanced_features/sgl_model_gateway.md) for data parallelism.\n\n  ```bash\n  python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dp 2 --tp 2\n  ```\n\n- If you see out-of-memory errors during serving, try reducing the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. 
The default value is `0.9`.\n\n  ```bash\n  python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --mem-fraction-static 0.7\n  ```\n\n- See [hyperparameter tuning](hyperparameter_tuning.md) for guidance on tuning hyperparameters for better performance.\n- For Docker and Kubernetes runs, you need to set up shared memory, which is used for communication between processes. See `--shm-size` for Docker and the `/dev/shm` size update for Kubernetes manifests.\n- If you see out-of-memory errors during prefill for long prompts, try setting a smaller chunked prefill size.\n\n  ```bash\n  python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096\n  ```\n- To enable FP8 weight quantization, add `--quantization fp8` on an FP16 checkpoint, or directly load an FP8 checkpoint without specifying any arguments.\n- To enable FP8 KV cache quantization, add `--kv-cache-dtype fp8_e4m3` or `--kv-cache-dtype fp8_e5m2`.\n- To enable deterministic inference and batch-invariant operations, add `--enable-deterministic-inference`. More details can be found in the [deterministic inference document](../advanced_features/deterministic_inference.md).\n- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](../references/custom_chat_template.md). If the tokenizer has multiple named templates (e.g., 'default', 'tool_use'), you can select one using `--hf-chat-template-name tool_use`.\n- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs each and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port; you can then use the following commands. If you encounter a deadlock, try adding `--disable-cuda-graph`.
\n\n  ```bash\n  # Node 0\n  python -m sglang.launch_server \\\n    --model-path meta-llama/Meta-Llama-3-8B-Instruct \\\n    --tp 4 \\\n    --dist-init-addr sgl-dev-0:50000 \\\n    --nnodes 2 \\\n    --node-rank 0\n\n  # Node 1\n  python -m sglang.launch_server \\\n    --model-path meta-llama/Meta-Llama-3-8B-Instruct \\\n    --tp 4 \\\n    --dist-init-addr sgl-dev-0:50000 \\\n    --nnodes 2 \\\n    --node-rank 1\n  ```\n\n- (Note: This feature is no longer maintained and may cause errors.) To enable `torch.compile` acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. By default, the cache path is `/tmp/torchinductor_root`; you can customize it using the environment variable `TORCHINDUCTOR_CACHE_DIR`. For more details, please refer to the [PyTorch official documentation](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) and [Enabling cache for torch.compile](https://docs.sglang.io/references/torch_compile_cache.html).\n\nPlease consult the documentation below and [server_args.py](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py) to learn more about the arguments you may provide when launching a server.\n\n## Model and tokenizer\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--model-path`<br>`--model` | The path of the model weights. This can be a local folder or a Hugging Face repo ID. | `None` | Type: str |\n| `--tokenizer-path` | The path of the tokenizer. | `None` | Type: str |\n| `--tokenizer-mode` | Tokenizer mode. 'auto' will use the fast tokenizer if available, and 'slow' will always use the slow tokenizer. | `auto` | `auto`, `slow` |\n| `--tokenizer-worker-num` | The number of workers for the tokenizer manager. | `1` | Type: int |\n| `--skip-tokenizer-init` | If set, skip tokenizer initialization and pass input_ids in the generate request. | `False` | bool flag (set to enable) |\n| `--load-format` | The format of the model weights to load. 
\"auto\" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available. \"pt\" will load the weights in the pytorch bin format. 
\"safetensors\" will load the weights in the safetensors format. \"npcache\" will load the weights in pytorch format and store a numpy cache to speed up the loading. \"dummy\" will initialize the weights with random values, which is mainly for profiling. \"gguf\" will load the weights in the gguf format. \"bitsandbytes\" will load the weights using bitsandbytes quantization. \"layered\" loads weights layer by layer so that one can quantize a layer before loading another to make the peak memory envelope smaller. \"flash_rl\" will load the weights in flash_rl format. \"fastsafetensors\" and \"private\" are also supported. | `auto` | `auto`, `pt`, `safetensors`, `npcache`, `dummy`, `sharded_state`, `gguf`, `bitsandbytes`, `layered`, `flash_rl`, `remote`, `remote_instance`, `fastsafetensors`, `private` |\n| `--model-loader-extra-config` | Extra config for the model loader. This will be passed to the model loader corresponding to the chosen load_format. | `{}` | Type: str |\n| `--trust-remote-code` | Whether or not to allow for custom models defined on the Hub in their own modeling files. | `False` | bool flag (set to enable) |\n| `--context-length` | The model's maximum context length. Defaults to None (will use the value from the model's config.json instead). | `None` | Type: int |\n| `--is-embedding` | Whether to use a CausalLM as an embedding model. | `False` | bool flag (set to enable) |\n| `--enable-multimodal` | Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen. | `None` | bool flag (set to enable) |\n| `--revision` | The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. | `None` | Type: str |\n| `--model-impl` | Which implementation of the model to use. * \"auto\" will try to use the SGLang implementation if it exists and fall back to the Transformers implementation if no SGLang implementation is available. 
* \"sglang\" will use the SGLang model implementation. * \"transformers\" will use the Transformers model implementation. | `auto` | Type: str |\n\n## HTTP server\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--host` | The host of the HTTP server. | `127.0.0.1` | Type: str |\n| `--port` | The port of the HTTP server. | `30000` | Type: int |\n| `--fastapi-root-path` | The root path to use when the app is behind a path-based routing proxy. | `\"\"` | Type: str |\n| `--grpc-mode` | If set, use a gRPC server instead of an HTTP server. | `False` | bool flag (set to enable) |\n| `--skip-server-warmup` | If set, skip warmup. | `False` | bool flag (set to enable) |\n| `--warmups` | Specify custom warmup functions (csv) to run before the server starts, e.g., --warmups=warmup_name1,warmup_name2 will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests. | `None` | Type: str |\n| `--nccl-port` | The port for NCCL distributed environment setup. Defaults to a random port. | `None` | Type: int |\n| `--checkpoint-engine-wait-weights-before-ready` | If set, the server will wait for initial weights to be loaded via checkpoint-engine or other update methods before serving inference requests. | `False` | bool flag (set to enable) |\n\n## Quantization and data type\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--dtype` | Data type for model weights and activations. * \"auto\" will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. * \"half\" for FP16. Recommended for AWQ quantization. * \"float16\" is the same as \"half\". * \"bfloat16\" for a balance between precision and range. * \"float\" is shorthand for FP32 precision. * \"float32\" for FP32 precision. | `auto` | `auto`, `half`, `float16`, `bfloat16`, `float`, `float32` |\n| `--quantization` | The quantization method. 
| `None` | `awq`, `fp8`, `gptq`, `marlin`, `gptq_marlin`, `awq_marlin`, `bitsandbytes`, `gguf`, `modelopt`, `modelopt_fp8`, `modelopt_fp4`, `petit_nvfp4`, `w8a8_int8`, `w8a8_fp8`, `moe_wna16`, `qoq`, `w4afp8`, `mxfp4`, `mxfp8`, `auto-round`, `compressed-tensors`, `modelslim`, `quark_int4fp8_moe` |\n| `--quantization-param-path` | Path to the JSON file containing the KV cache scaling factors. This should generally be supplied when the KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. | `None` | Type: Optional[str] |\n| `--kv-cache-dtype` | Data type for KV cache storage. \"auto\" will use the model data type. \"bf16\" or \"bfloat16\" for BF16 KV cache. \"fp8_e5m2\" and \"fp8_e4m3\" are supported for CUDA 11.8+. \"fp4_e2m1\" (only mxfp4) is supported for CUDA 12.8+ and PyTorch 2.8.0+. | `auto` | `auto`, `fp8_e5m2`, `fp8_e4m3`, `bf16`, `bfloat16`, `fp4_e2m1` |\n| `--enable-fp32-lm-head` | If set, the LM head outputs (logits) are in FP32. | `False` | bool flag (set to enable) |\n| `--modelopt-quant` | The ModelOpt quantization configuration. Supported values: 'fp8', 'int4_awq', 'w4a8_awq', 'nvfp4', 'nvfp4_awq'. This requires the NVIDIA Model Optimizer library to be installed: pip install nvidia-modelopt | `None` | Type: str |\n| `--modelopt-checkpoint-restore-path` | Path to restore a previously saved ModelOpt quantized checkpoint. If provided, the quantization process will be skipped and the model will be loaded from this checkpoint. | `None` | Type: str |\n| `--modelopt-checkpoint-save-path` | Path to save the ModelOpt quantized checkpoint after quantization. This allows reusing the quantized model in future runs. | `None` | Type: str |\n| `--modelopt-export-path` | Path to export the quantized model in HuggingFace format after ModelOpt quantization. The exported model can then be used directly with SGLang for inference. If not provided, the model will not be exported. 
| `None` | Type: str |\n| `--quantize-and-serve` | Quantize the model with ModelOpt and immediately serve it without exporting. This is useful for development and prototyping. For production, it's recommended to use separate quantization and deployment steps. | `False` | bool flag (set to enable) |\n| `--rl-quant-profile` | Path to the FlashRL quantization profile. Required when using --load-format flash_rl. | `None` | Type: str |\n\n## Memory and scheduling\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--mem-fraction-static` | The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors. | `None` | Type: float |\n| `--max-running-requests` | The maximum number of running requests. | `None` | Type: int |\n| `--max-queued-requests` | The maximum number of queued requests. This option is ignored when using disaggregation-mode. | `None` | Type: int |\n| `--max-total-tokens` | The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes. | `None` | Type: int |\n| `--chunked-prefill-size` | The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill. | `None` | Type: int |\n| `--prefill-max-requests` | The maximum number of requests in a prefill batch. If not specified, there is no limit. | `None` | Type: int |\n| `--enable-dynamic-chunking` | Enable dynamic chunk size adjustment for pipeline parallelism. When enabled, chunk sizes are dynamically calculated based on fitted function to maintain consistent execution time across chunks. | `False` | bool flag (set to enable) |\n| `--max-prefill-tokens` | The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length. 
| `16384` | Type: int |\n| `--schedule-policy` | The scheduling policy of the requests. | `fcfs` | `lpm`, `random`, `fcfs`, `dfs-weight`, `lof`, `priority`, `routing-key` |\n| `--enable-priority-scheduling` | Enable priority scheduling. Requests with higher priority integer values will be scheduled first by default. | `False` | bool flag (set to enable) |\n| `--abort-on-priority-when-disabled` | If set, abort requests that specify a priority when priority scheduling is disabled. | `False` | bool flag (set to enable) |\n| `--schedule-low-priority-values-first` | If specified with --enable-priority-scheduling, the scheduler will schedule requests with lower priority integer values first. | `False` | bool flag (set to enable) |\n| `--priority-scheduling-preemption-threshold` | Minimum difference in priorities for an incoming request to have to preempt running request(s). | `10` | Type: int |\n| `--schedule-conservativeness` | How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently. | `1.0` | Type: float |\n| `--page-size` | The number of tokens in a page. | `1` | Type: int |\n| `--swa-full-tokens-ratio` | The ratio of SWA layer KV tokens / full layer KV tokens, regardless of the number of swa:full layers. It should be between 0 and 1. E.g. 0.5 means if each swa layer has 50 tokens, then each full layer has 100 tokens. | `0.8` | Type: float |\n| `--disable-hybrid-swa-memory` | Disable the hybrid SWA memory. | `False` | bool flag (set to enable) |\n| `--radix-eviction-policy` | The eviction policy of radix trees. 'lru' stands for Least Recently Used, 'lfu' stands for Least Frequently Used. | `lru` | `lru`, `lfu` |\n| `--enable-prefill-delayer` | Enable prefill delayer for DP attention to reduce idle time. | `False` | bool flag (set to enable) |\n| `--prefill-delayer-max-delay-passes` | Maximum forward passes to delay prefill. 
| `30` | Type: int |\n| `--prefill-delayer-token-usage-low-watermark` | Token usage low watermark for prefill delayer. | `None` | Type: float |\n| `--prefill-delayer-forward-passes-buckets` | Custom buckets for prefill delayer forward passes histogram. 0 and max_delay_passes-1 will be auto-added. | `None` | List[float] |\n| `--prefill-delayer-wait-seconds-buckets` | Custom buckets for prefill delayer wait seconds histogram. 0 will be auto-added. | `None` | List[float] |\n\n## Runtime options\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--device` | The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified. | `None` | Type: str |\n| `--tensor-parallel-size`<br>`--tp-size` | The tensor parallelism size. | `1` | Type: int |\n| `--pipeline-parallel-size`<br>`--pp-size` | The pipeline parallelism size. | `1` | Type: int |\n| `--attention-context-parallel-size`<br>`--attn-cp-size`| The attention context parallelism size. | `1` | Type: int|\n| `--moe-data-parallel-size`<br>`--moe-dp-size`| The moe data parallelism size. | `1` | Type: int|\n| `--pp-max-micro-batch-size` | The maximum micro batch size in pipeline parallelism. | `None` | Type: int |\n| `--pp-async-batch-depth` | The async batch depth of pipeline parallelism. | `0` | Type: int |\n| `--stream-interval` | The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher | `1` | Type: int |\n| `--incremental-streaming-output` | Whether to output as a sequence of disjoint segments. | `False` | bool flag (set to enable) |\n| `--random-seed` | The random seed. | `None` | Type: int |\n| `--constrained-json-whitespace-pattern` | (outlines and llguidance backends only) Regex pattern for syntactic whitespaces allowed in JSON constrained output. 
For example, to allow the model to generate consecutive whitespaces, set the pattern to [\\n\\t ]* | `None` | Type: str |\n| `--constrained-json-disable-any-whitespace` | (xgrammar and llguidance backends only) Enforce compact representation in JSON constrained output. | `False` | bool flag (set to enable) |\n| `--watchdog-timeout` | Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging. | `300` | Type: float |\n| `--soft-watchdog-timeout` | Set soft watchdog timeout in seconds. If a forward batch takes longer than this, the server will dump information for debugging. | `None` | Type: float |\n| `--dist-timeout` | Set timeout for torch.distributed initialization. | `None` | Type: int |\n| `--download-dir` | Model download directory for huggingface. | `None` | Type: str |\n| `--model-checksum` | Model file integrity verification. If provided without value, uses model-path as HF repo ID. Otherwise, provide checksums JSON file path or HuggingFace repo ID. | `None` | Type: str |\n| `--base-gpu-id` | The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine. | `0` | Type: int |\n| `--gpu-id-step` | The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,... | `1` | Type: int |\n| `--sleep-on-idle` | Reduce CPU usage when sglang is idle. | `False` | bool flag (set to enable) |\n| `--custom-sigquit-handler` | Register a custom sigquit handler so you can do additional cleanup after the server is shutdown. This is only available for Engine, not for CLI. | `None` | Type: str |\n\n## Logging\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--log-level` | The logging level of all loggers. | `info` | Type: str |\n| `--log-level-http` | The logging level of HTTP server. If not set, reuse --log-level by default. 
| `None` | Type: str |\n| `--log-requests` | Log metadata, inputs, and outputs of all requests. The verbosity is set by --log-requests-level. | `False` | bool flag (set to enable) |\n| `--log-requests-level` | 0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output. | `2` | `0`, `1`, `2`, `3` |\n| `--log-requests-format` | Format for request logging: 'text' (human-readable) or 'json' (structured). | `text` | `text`, `json` |\n| `--log-requests-target` | Target(s) for request logging: 'stdout' and/or directory path(s) for file output. Can specify multiple targets, e.g., '--log-requests-target stdout /my/path'. | `None` | List[str] |\n| `--uvicorn-access-log-exclude-prefixes` | Exclude uvicorn access logs whose request path starts with any of these prefixes. Defaults to empty (disabled). | `[]` | List[str] |\n| `--crash-dump-folder` | Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled. | `None` | Type: str |\n| `--show-time-cost` | Show time cost of custom marks. | `False` | bool flag (set to enable) |\n| `--enable-metrics` | Enable Prometheus metrics logging. | `False` | bool flag (set to enable) |\n| `--enable-metrics-for-all-schedulers` | Set this when you want schedulers on all TP ranks (not just TP 0) to record request metrics separately. This is especially useful when dp_attention is enabled, as otherwise all metrics appear to come from TP 0. | `False` | bool flag (set to enable) |\n| `--tokenizer-metrics-custom-labels-header` | Specify the HTTP header for passing custom labels for tokenizer metrics. | `x-custom-labels` | Type: str |\n| `--tokenizer-metrics-allowed-custom-labels` | The custom labels allowed for tokenizer metrics. 
The labels are specified as a dict in the '--tokenizer-metrics-custom-labels-header' header of HTTP requests, e.g., {'label1': 'value1', 'label2': 'value2'} is allowed if '--tokenizer-metrics-allowed-custom-labels label1 label2' is set. | `None` | List[str] |\n| `--bucket-time-to-first-token` | The buckets of time to first token, specified as a list of floats. | `None` | List[float] |\n| `--bucket-inter-token-latency` | The buckets of inter-token latency, specified as a list of floats. | `None` | List[float] |\n| `--bucket-e2e-request-latency` | The buckets of end-to-end request latency, specified as a list of floats. | `None` | List[float] |\n| `--collect-tokens-histogram` | Collect prompt/generation tokens histogram. | `False` | bool flag (set to enable) |\n| `--prompt-tokens-buckets` | The buckets rule for the prompt tokens histogram. Supports 3 rule types: 'default' uses predefined buckets; 'tse <middle> <base> <count>' generates two-sided exponentially distributed buckets (e.g., 'tse 1000 2 8' generates buckets [984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]); 'custom <value1> <value2> ...' uses custom bucket values (e.g., 'custom 10 50 100 500'). | `None` | List[str] |\n| `--generation-tokens-buckets` | The buckets rule for the generation tokens histogram. Supports 3 rule types: 'default' uses predefined buckets; 'tse <middle> <base> <count>' generates two-sided exponentially distributed buckets (e.g., 'tse 1000 2 8' generates buckets [984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]); 'custom <value1> <value2> ...' uses custom bucket values (e.g., 'custom 10 50 100 500'). | `None` | List[str] |\n| `--gc-warning-threshold-secs` | The threshold for long GC warning. If a GC takes longer than this, a warning will be logged. Set to 0 to disable. | `0.0` | Type: float |\n| `--decode-log-interval` | The log interval of decode batch. 
| `40` | Type: int |\n| `--enable-request-time-stats-logging` | Enable per-request time stats logging. | `False` | bool flag (set to enable) |\n| `--kv-events-config` | Config in JSON format for NVIDIA Dynamo KV event publishing. Publishing will be enabled if this flag is used. | `None` | Type: str |\n| `--enable-trace` | Enable OpenTelemetry tracing. | `False` | bool flag (set to enable) |\n| `--otlp-traces-endpoint` | The OpenTelemetry collector endpoint, used when --enable-trace is set. Format: <ip>:<port>. | `localhost:4317` | Type: str |\n\n## RequestMetricsExporter configuration\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--export-metrics-to-file` | Export performance metrics for each request to a local file (e.g., for forwarding to external systems). | `False` | bool flag (set to enable) |\n| `--export-metrics-to-file-dir` | Directory path for writing performance metrics files (required when --export-metrics-to-file is enabled). | `None` | Type: str |\n\n## API related\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--api-key` | Set the API key of the server. It is also used by the OpenAI-compatible API server. | `None` | Type: str |\n| `--admin-api-key` | Set the **admin API key** for administrative/control endpoints (e.g., weights update, cache flush, `/get_server_info`). Endpoints marked as admin-only require `Authorization: Bearer <admin_api_key>` when this is set. | `None` | Type: str |\n| `--served-model-name` | Override the model name returned by the v1/models endpoint in the OpenAI API server. | `None` | Type: str |\n| `--weight-version` | Version identifier for the model weights. Defaults to 'default' if not specified. | `default` | Type: str |\n| `--chat-template` | The builtin chat template name or the path of the chat template file. This is only used for the OpenAI-compatible API server.
| `None` | Type: str |\n| `--hf-chat-template-name` | When the HuggingFace tokenizer has multiple chat templates (e.g., 'default', 'tool_use', 'rag'), specify which named template to use. If not set, the first available template is used. | `None` | Type: str |\n| `--completion-template` | The builtin completion template name or the path of the completion template file. This is only used for the OpenAI-compatible API server, and currently only for code completion. | `None` | Type: str |\n| `--file-storage-path` | The path of the file storage in the backend. | `sglang_storage` | Type: str |\n| `--enable-cache-report` | Return the number of cached tokens in usage.prompt_tokens_details for each OpenAI request. | `False` | bool flag (set to enable) |\n| `--reasoning-parser` | Specify the parser for reasoning models. Supported parsers: [deepseek-r1, deepseek-v3, glm45, gpt-oss, kimi, qwen3, qwen3-thinking, step3]. | `None` | `deepseek-r1`, `deepseek-v3`, `glm45`, `gpt-oss`, `kimi`, `qwen3`, `qwen3-thinking`, `step3` |\n| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Supported parsers: [deepseekv3, deepseekv31, glm, glm45, glm47, gpt-oss, kimi_k2, llama3, mistral, pythonic, qwen, qwen25, qwen3_coder, step3, gigachat3]. | `None` | `deepseekv3`, `deepseekv31`, `glm`, `glm45`, `glm47`, `gpt-oss`, `kimi_k2`, `llama3`, `mistral`, `pythonic`, `qwen`, `qwen25`, `qwen3_coder`, `step3`, `gigachat3` |\n| `--tool-server` | Either 'demo' or a comma-separated list of tool server URLs to use for the model. If not specified, no tool server will be used. | `None` | Type: str |\n| `--sampling-defaults` | Where to get default sampling parameters. 'openai' uses SGLang/OpenAI defaults (temperature=1.0, top_p=1.0, etc.). 'model' uses the model's generation_config.json to get the recommended sampling parameters if available. Default is 'model'.
| `model` | `openai`, `model` |\n\n## Data parallelism\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--data-parallel-size`<br>`--dp-size` | The data parallelism size. | `1` | Type: int |\n| `--load-balance-method` | The load balancing strategy for data parallelism. The `total_tokens` algorithm can only be used when DP attention is applied. This algorithm performs load balancing based on the real-time token load of the DP workers. | `auto` | `auto`, `round_robin`, `follow_bootstrap_room`, `total_requests`, `total_tokens` |\n\n## Multi-node distributed serving\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--dist-init-addr`<br>`--nccl-init-addr` | The host address for initializing the distributed backend (e.g., `192.168.0.2:25000`). | `None` | Type: str |\n| `--nnodes` | The number of nodes. | `1` | Type: int |\n| `--node-rank` | The node rank. | `0` | Type: int |\n\n## Model override args\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--json-model-override-args` | A dictionary in JSON string format used to override default model configurations. | `{}` | Type: str |\n| `--preferred-sampling-params` | JSON-formatted sampling settings that will be returned in /get_model_info. | `None` | Type: str |\n\n## LoRA\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to `True` if `--lora-paths` is provided for backward compatibility. | `False` | bool flag (set to enable) |\n| `--enable-lora-overlap-loading` | Enable asynchronous LoRA weight loading in order to overlap H2D transfers with GPU compute. This should be enabled if you find that your LoRA workloads are bottlenecked by adapter weight loading, for example when frequently loading large LoRA adapters. | `False` | bool flag (set to enable) |\n| `--max-lora-rank` | The maximum LoRA rank that should be supported.
If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | `None` | Type: int |\n| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. You can also set it to `all` to enable LoRA for all supported modules; note this may introduce minor performance overhead. | `None` | `q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`, `qkv_proj`, `gate_up_proj`, `all` |\n| `--lora-paths` | The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: `<PATH>` \\| `<NAME>=<PATH>` \\| JSON with schema `{\"lora_name\": str, \"lora_path\": str, \"pinned\": bool}`. | `None` | Type: List[str] / JSON objects |\n| `--max-loras-per-batch` | Maximum number of adapters for a running batch, including base-only requests. | `8` | Type: int |\n| `--max-loaded-loras` | If specified, limits the maximum number of LoRA adapters loaded in CPU memory at a time. Must be ≥ `--max-loras-per-batch`. | `None` | Type: int |\n| `--lora-eviction-policy` | LoRA adapter eviction policy when the GPU memory pool is full. | `lru` | `lru`, `fifo` |\n| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | `csgmv` | `triton`, `csgmv`, `ascend`, `torch_native` |\n| `--max-lora-chunk-size` | Maximum chunk size for the ChunkedSGMV LoRA backend. Only used when `--lora-backend` is `csgmv`. Larger values may improve performance. | `16` | `16`, `32`, `64`, `128` |\n\n## Kernel Backends (Attention, Sampling, Grammar, GEMM)\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--attention-backend` | Choose the kernels for attention layers. 
| `None` | `triton`, `torch_native`, `flex_attention`, `nsa`, `cutlass_mla`, `fa3`, `fa4`, `flashinfer`, `flashmla`, `trtllm_mla`, `trtllm_mha`, `dual_chunk_flash_attn`, `aiter`, `wave`, `intel_amx`, `ascend` |\n| `--prefill-attention-backend` | Choose the kernels for prefill attention layers (takes priority over --attention-backend). | `None` | `triton`, `torch_native`, `flex_attention`, `nsa`, `cutlass_mla`, `fa3`, `fa4`, `flashinfer`, `flashmla`, `trtllm_mla`, `trtllm_mha`, `dual_chunk_flash_attn`, `aiter`, `wave`, `intel_amx`, `ascend` |\n| `--decode-attention-backend` | Choose the kernels for decode attention layers (takes priority over --attention-backend). | `None` | `triton`, `torch_native`, `flex_attention`, `nsa`, `cutlass_mla`, `fa3`, `fa4`, `flashinfer`, `flashmla`, `trtllm_mla`, `trtllm_mha`, `dual_chunk_flash_attn`, `aiter`, `wave`, `intel_amx`, `ascend` |\n| `--sampling-backend` | Choose the kernels for sampling layers. | `None` | `flashinfer`, `pytorch`, `ascend` |\n| `--grammar-backend` | Choose the backend for grammar-guided decoding. | `None` | `xgrammar`, `outlines`, `llguidance`, `none` |\n| `--mm-attention-backend` | Set the multimodal attention backend. | `None` | `sdpa`, `fa3`, `fa4`, `triton_attn`, `ascend_attn`, `aiter_attn` |\n| `--nsa-prefill-backend` | Choose the NSA backend for the prefill stage (overrides `--attention-backend` when running DeepSeek NSA-style attention). | `flashmla_sparse` | `flashmla_sparse`, `flashmla_kv`, `flashmla_auto`, `fa3`, `tilelang`, `aiter`, `trtllm` |\n| `--nsa-decode-backend` | Choose the NSA backend for the decode stage when running DeepSeek NSA-style attention. Overrides `--attention-backend` for decoding. | `fa3` | `flashmla_sparse`, `flashmla_kv`, `fa3`, `tilelang`, `aiter`, `trtllm` |\n| `--fp8-gemm-backend` | Choose the runner backend for blockwise FP8 GEMM operations.
Options: 'auto' (default, auto-selects based on hardware), 'deep_gemm' (JIT-compiled; enabled by default on NVIDIA Hopper (SM90) and Blackwell (SM100) when DeepGEMM is installed), 'flashinfer_trtllm' (FlashInfer TRTLLM backend; SM100/SM103 only), 'flashinfer_cutlass' (FlashInfer CUTLASS backend, SM120 only), 'flashinfer_deepgemm' (Hopper SM90 only, uses swapAB optimization for small M dimensions in decoding), 'cutlass' (optimal for Hopper/Blackwell GPUs and high throughput), 'triton' (fallback, widely compatible), 'aiter' (ROCm only). **NOTE**: This replaces the deprecated environment variables SGLANG_ENABLE_FLASHINFER_FP8_GEMM and SGLANG_SUPPORT_CUTLASS_BLOCK_FP8. | `auto` | `auto`, `deep_gemm`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_deepgemm`, `cutlass`, `triton`, `aiter` |\n| `--fp4-gemm-backend` | Choose the runner backend for NVFP4 GEMM operations. Options: 'flashinfer_cutlass' (default), 'auto' (auto-selects between flashinfer_cudnn/flashinfer_cutlass based on CUDA/cuDNN version), 'flashinfer_cudnn' (FlashInfer cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), 'flashinfer_trtllm' (FlashInfer TensorRT-LLM backend, requires different weight preparation with shuffling). All backends are from FlashInfer; when FlashInfer is unavailable, sgl-kernel CUTLASS is used as an automatic fallback. **NOTE**: This replaces the deprecated environment variable SGLANG_FLASHINFER_FP4_GEMM_BACKEND. | `flashinfer_cutlass` | `auto`, `flashinfer_cudnn`, `flashinfer_cutlass`, `flashinfer_trtllm` |\n| `--disable-flashinfer-autotune` | FlashInfer autotuning is enabled by default. Set this flag to disable it. | `False` | bool flag (set to enable) |\n\n## Speculative decoding\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--speculative-algorithm` | The speculative decoding algorithm.
| `None` | `EAGLE`, `EAGLE3`, `NEXTN`, `STANDALONE`, `NGRAM` |\n| `--speculative-draft-model-path`<br>`--speculative-draft-model` | The path of the draft model weights. This can be a local folder or a Hugging Face repo ID. | `None` | Type: str |\n| `--speculative-draft-model-revision` | The specific draft model version to use. It can be a branch name, a tag name, or a commit ID. If unspecified, the default version is used. | `None` | Type: str |\n| `--speculative-draft-load-format` | The format of the draft model weights to load. If not specified, the same format as --load-format is used. Use 'dummy' to initialize draft model weights with random values for profiling. | `None` | Same as --load-format options |\n| `--speculative-num-steps` | The number of steps sampled from the draft model in speculative decoding. | `None` | Type: int |\n| `--speculative-eagle-topk` | The number of tokens sampled from the draft model in EAGLE-2 at each step. | `None` | Type: int |\n| `--speculative-num-draft-tokens` | The number of tokens sampled from the draft model in speculative decoding. | `None` | Type: int |\n| `--speculative-accept-threshold-single` | Accept a draft token if its probability in the target model is greater than this threshold. | `1.0` | Type: float |\n| `--speculative-accept-threshold-acc` | The acceptance probability of a draft token is raised from its target probability p to min(1, p / threshold_acc). | `1.0` | Type: float |\n| `--speculative-token-map` | The path of the draft model's small vocab table. | `None` | Type: str |\n| `--speculative-attention-mode` | Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'. | `prefill` | `prefill`, `decode` |\n| `--speculative-draft-attention-backend` | Attention backend for speculative decoding drafting.
| `None` | Same as attention backend options |\n| `--speculative-moe-runner-backend` | MoE runner backend for EAGLE speculative decoding; see --moe-runner-backend for options. Defaults to the MoE runner backend if unset. | `None` | Same as --moe-runner-backend options |\n| `--speculative-moe-a2a-backend` | MoE A2A backend for EAGLE speculative decoding; see --moe-a2a-backend for options. Defaults to the MoE A2A backend if unset. | `None` | Same as --moe-a2a-backend options |\n| `--speculative-draft-model-quantization` | The quantization method for the speculative draft model. | `None` | Same as --quantization options |\n\n## Ngram speculative decoding\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--speculative-ngram-min-match-window-size` | The minimum window size for pattern matching in ngram speculative decoding. | `1` | Type: int |\n| `--speculative-ngram-max-match-window-size` | The maximum window size for pattern matching in ngram speculative decoding. | `12` | Type: int |\n| `--speculative-ngram-min-bfs-breadth` | The minimum breadth for BFS (Breadth-First Search) in ngram speculative decoding. | `1` | Type: int |\n| `--speculative-ngram-max-bfs-breadth` | The maximum breadth for BFS (Breadth-First Search) in ngram speculative decoding. | `10` | Type: int |\n| `--speculative-ngram-match-type` | The match type for the cache tree. | `BFS` | `BFS`, `PROB` |\n| `--speculative-ngram-branch-length` | The branch length for ngram speculative decoding. | `18` | Type: int |\n| `--speculative-ngram-capacity` | The cache capacity for ngram speculative decoding. | `10000000` | Type: int |\n\n## Multi-layer Eagle speculative decoding\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--enable-multi-layer-eagle` | Enable multi-layer Eagle speculative decoding.
| `False` | bool flag (set to enable) |\n\n## MoE\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--expert-parallel-size`<br>`--ep-size`<br>`--ep` | The expert parallelism size. | `1` | Type: int |\n| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `mori`, `nixl`, `ascend_fuseep` |\n| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_trtllm_routed`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` |\n| `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of FlashInfer MXFP4 MoE. | `default` | `default`, `bf16` |\n| `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |\n| `--enable-aiter-allreduce-fusion` | Enable aiter allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |\n| `--deepep-mode` | Select the mode when DeepEP MoE is enabled: `normal`, `low_latency`, or `auto`. Default is `auto`, which means `low_latency` for decode batches and `normal` for prefill batches. | `auto` | `normal`, `low_latency`, `auto` |\n| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallelism. | `0` | Type: int |\n| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in expert parallelism. | `None` | Type: str |\n| `--init-expert-location` | Initial location of EP experts. | `trivial` | Type: str |\n| `--enable-eplb` | Enable the EPLB algorithm. | `False` | bool flag (set to enable) |\n| `--eplb-algorithm` | The EPLB algorithm to use. | `auto` | Type: str |\n| `--eplb-rebalance-num-iterations` | Number of iterations to automatically trigger an EPLB rebalance. | `1000` | Type: int |\n| `--eplb-rebalance-layers-per-chunk` | Number of layers to rebalance per forward pass.
| `None` | Type: int |\n| `--eplb-min-rebalancing-utilization-threshold` | Minimum average GPU utilization threshold to trigger EPLB rebalancing. Must be in the range [0.0, 1.0]. | `1.0` | Type: float |\n| `--expert-distribution-recorder-mode` | Mode of the expert distribution recorder. | `None` | Type: str |\n| `--expert-distribution-recorder-buffer-size` | Circular buffer size of the expert distribution recorder. Set to -1 for an infinite buffer. | `None` | Type: int |\n| `--enable-expert-distribution-metrics` | Enable logging of metrics for expert balancedness. | `False` | bool flag (set to enable) |\n| `--deepep-config` | Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path. | `None` | Type: str |\n| `--moe-dense-tp-size` | TP size for MoE dense MLP layers. This flag is useful when, with a large TP size, errors occur because weights in MLP layers have a dimension smaller than the minimum dimension supported by GEMM. | `None` | Type: int |\n| `--elastic-ep-backend` | Specify the collective communication backend for elastic EP. Currently supports 'mooncake'. | `none` | `none`, `mooncake` |\n| `--enable-elastic-expert-backup` | Enable the elastic EP backend feature that backs up expert weights in DRAM. Currently supports 'mooncake'. | `False` | bool flag (set to enable) |\n| `--mooncake-ib-device` | The InfiniBand devices for Mooncake backend transfer; accepts multiple comma-separated devices (e.g., --mooncake-ib-device mlx5_0,mlx5_1). Default is None, which triggers automatic device detection when the Mooncake backend is enabled. | `None` | Type: str |\n\n## Mamba Cache\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--max-mamba-cache-size` | The maximum size of the mamba cache. | `None` | Type: int |\n| `--mamba-ssm-dtype` | The data type of the SSM states in the mamba cache.
| `float32` | `float32`, `bfloat16`, `float16` |\n| `--mamba-full-memory-ratio` | The ratio of mamba state memory to full KV cache memory. | `0.9` | Type: float |\n| `--mamba-scheduler-strategy` | The strategy to use for the mamba scheduler. `auto` currently defaults to `no_buffer`. 1. `no_buffer` does not support the overlap scheduler because it does not allocate extra mamba state buffers. Branching point caching support is feasible but not implemented. 2. `extra_buffer` supports the overlap scheduler by allocating extra mamba state buffers to track mamba state for caching (mamba state usage per running req becomes `2x` for non-spec; `1+(1/(2+speculative_num_draft_tokens))x` for spec dec (e.g., about 1.17x if speculative_num_draft_tokens==4)). 2a. `extra_buffer` is strictly better for non-KV-cache-bound cases; for KV-cache-bound cases, the tradeoff depends on whether enabling overlap outweighs the reduced maximum number of running requests. 2b. Mamba caching at radix cache branching points is strictly better than non-branch caching but requires kernel support (currently only the FLA backend); currently only `extra_buffer` supports branching. | `auto` | `auto`, `no_buffer`, `extra_buffer` |\n| `--mamba-track-interval` | The interval (in tokens) to track the mamba state during decode. Only used when `--mamba-scheduler-strategy` is `extra_buffer`. Must be divisible by page_size if set, and must be >= speculative_num_draft_tokens when using speculative decoding. | `256` | Type: int |\n\n## Hierarchical cache\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--enable-hierarchical-cache` | Enable hierarchical cache. | `False` | bool flag (set to enable) |\n| `--hicache-ratio` | The ratio of the size of the host KV cache memory pool to the size of the device pool. | `2.0` | Type: float |\n| `--hicache-size` | The size of the host KV cache memory pool in gigabytes, which overrides hicache_ratio if set. | `0` | Type: int |\n| `--hicache-write-policy` | The write policy of the hierarchical cache.
| `write_through` | `write_back`, `write_through`, `write_through_selective` |\n| `--hicache-io-backend` | The IO backend for KV cache transfer between CPU and GPU | `kernel` | `direct`, `kernel`, `kernel_ascend` |\n| `--hicache-mem-layout` | The layout of host memory pool for hierarchical cache. | `layer_first` | `layer_first`, `page_first`, `page_first_direct`, `page_first_kv_split`, `page_head` |\n| `--hicache-storage-backend` | The storage backend for hierarchical KV cache. Built-in backends: file, mooncake, hf3fs, nixl, aibrix. For dynamic backend, use --hicache-storage-backend-extra-config to specify: backend_name (custom name), module_path (Python module path), class_name (backend class name). | `None` | `file`, `mooncake`, `hf3fs`, `nixl`, `aibrix`, `dynamic`, `eic` |\n| `--hicache-storage-prefetch-policy` | Control when prefetching from the storage backend should stop. | `best_effort` | `best_effort`, `wait_complete`, `timeout` |\n| `--hicache-storage-backend-extra-config` | A dictionary in JSON string format, or a string starting with a `@` followed by a config file in JSON/YAML/TOML format, containing extra configuration for the storage backend. | `None` | Type: str |\n\n## Hierarchical sparse attention\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--hierarchical-sparse-attention-extra-config` | A dictionary in JSON string format for hierarchical sparse attention configuration. Required fields: `algorithm` (str), `backend` (str). All other fields are algorithm-specific and passed to the algorithm constructor. 
| `None` | Type: str |\n\n## LMCache\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--enable-lmcache` | Use LMCache as an alternative hierarchical cache solution. | `False` | bool flag (set to enable) |\n\n## Ktransformers\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--kt-weight-path` | [ktransformers parameter] The path of the quantized expert weights for the AMX kernel. Must be a local folder. | `None` | Type: str |\n| `--kt-method` | [ktransformers parameter] Quantization formats for CPU execution. | `AMXINT4` | Type: str |\n| `--kt-cpuinfer` | [ktransformers parameter] The number of CPUInfer threads. | `None` | Type: int |\n| `--kt-threadpool-count` | [ktransformers parameter] One-to-one with the number of NUMA nodes (one thread pool per NUMA node). | `2` | Type: int |\n| `--kt-num-gpu-experts` | [ktransformers parameter] The number of GPU experts. | `None` | Type: int |\n| `--kt-max-deferred-experts-per-token` | [ktransformers parameter] Maximum number of experts deferred to CPU per token. All MoE layers except the final one use this value; the final layer always uses 0. | `None` | Type: int |\n\n## Diffusion LLM\n\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--dllm-algorithm` | The diffusion LLM algorithm, such as LowConfidence. | `None` | Type: str |\n| `--dllm-algorithm-config` | The diffusion LLM algorithm configuration. Must be a YAML file.
| `None` | Type: str |\n\n## Double Sparsity\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--enable-double-sparsity` | Enable double sparsity attention | `False` | bool flag (set to enable) |\n| `--ds-channel-config-path` | The path of the double sparsity channel config | `None` | Type: str |\n| `--ds-heavy-channel-num` | The number of heavy channels in double sparsity attention | `32` | Type: int |\n| `--ds-heavy-token-num` | The number of heavy tokens in double sparsity attention | `256` | Type: int |\n| `--ds-heavy-channel-type` | The type of heavy channels in double sparsity attention | `qk` | Type: str |\n| `--ds-sparse-decode-threshold` | The minimum decode sequence length required before the double-sparsity backend switches from the dense fallback to the sparse decode kernel. | `4096` | Type: int |\n\n## Offloading\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--cpu-offload-gb` | How many GBs of RAM to reserve for CPU offloading. | `0` | Type: int |\n| `--offload-group-size` | Number of layers per group in offloading. | `-1` | Type: int |\n| `--offload-num-in-group` | Number of layers to be offloaded within a group. | `1` | Type: int |\n| `--offload-prefetch-step` | Steps to prefetch in offloading. | `1` | Type: int |\n| `--offload-mode` | Mode of offloading. | `cpu` | Type: str |\n\n## Args for multi-item scoring\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--multi-item-scoring-delimiter` | Delimiter token ID for multi-item scoring. Used to combine Query and Items into a single sequence: Query<delimiter>Item1<delimiter>Item2<delimiter>... This enables efficient batch processing of multiple items against a single query. | `None` | Type: int |\n\n## Optimization/debug options\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--disable-radix-cache` | Disable RadixAttention for prefix caching. 
| `False` | bool flag (set to enable) |\n| `--cuda-graph-max-bs` | Set the maximum batch size for CUDA graph. It will extend the CUDA graph capture batch size to this value. | `None` | Type: int |\n| `--cuda-graph-bs` | Set the list of batch sizes for CUDA graph. | `None` | List[int] |\n| `--disable-cuda-graph` | Disable CUDA graph. | `False` | bool flag (set to enable) |\n| `--disable-cuda-graph-padding` | Disable CUDA graph when padding is needed. Still uses CUDA graph when padding is not needed. | `False` | bool flag (set to enable) |\n| `--enable-profile-cuda-graph` | Enable profiling of CUDA graph capture. | `False` | bool flag (set to enable) |\n| `--enable-cudagraph-gc` | Enable garbage collection during CUDA graph capture. If disabled (default), GC is frozen during capture to speed up the process. | `False` | bool flag (set to enable) |\n| `--enable-layerwise-nvtx-marker` | Enable layerwise NVTX profiling annotations for the model. This adds NVTX markers to every layer for detailed per-layer performance analysis with Nsight Systems. | `False` | bool flag (set to enable) |\n| `--enable-nccl-nvls` | Enable NCCL NVLS for prefill-heavy requests when available. | `False` | bool flag (set to enable) |\n| `--enable-symm-mem` | Enable NCCL symmetric memory for fast collectives. | `False` | bool flag (set to enable) |\n| `--disable-flashinfer-cutlass-moe-fp4-allgather` | Disable quantization before all-gather for FlashInfer CUTLASS MoE. | `False` | bool flag (set to enable) |\n| `--enable-tokenizer-batch-encode` | Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds. | `False` | bool flag (set to enable) |\n| `--disable-tokenizer-batch-decode` | Disable batch decoding when decoding multiple completions.
| `False` | bool flag (set to enable) |\n| `--disable-outlines-disk-cache` | Disable the Outlines disk cache to avoid possible crashes related to the file system or high concurrency. | `False` | bool flag (set to enable) |\n| `--disable-custom-all-reduce` | Disable the custom all-reduce kernel and fall back to NCCL. | `False` | bool flag (set to enable) |\n| `--enable-mscclpp` | Use mscclpp for small messages in the all-reduce kernel, falling back to NCCL otherwise. | `False` | bool flag (set to enable) |\n| `--enable-torch-symm-mem` | Use torch symmetric memory for the all-reduce kernel, falling back to NCCL otherwise. Only supports CUDA devices with SM90 and above. SM90 supports world sizes 4, 6, and 8. SM10 supports world sizes 6 and 8. | `False` | bool flag (set to enable) |\n| `--disable-overlap-schedule` | Disable the overlap scheduler, which overlaps the CPU scheduler with the GPU model worker. | `False` | bool flag (set to enable) |\n| `--enable-mixed-chunk` | Enable mixing prefill and decode in a batch when using chunked prefill. | `False` | bool flag (set to enable) |\n| `--enable-dp-attention` | Enable data parallelism for attention and tensor parallelism for FFN. The DP size should be equal to the TP size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported. | `False` | bool flag (set to enable) |\n| `--enable-dp-lm-head` | Enable vocabulary parallelism across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | `False` | bool flag (set to enable) |\n| `--enable-two-batch-overlap` | Enable overlapping of two micro-batches. | `False` | bool flag (set to enable) |\n| `--enable-single-batch-overlap` | Let computation and communication overlap within one micro-batch. | `False` | bool flag (set to enable) |\n| `--tbo-token-distribution-threshold` | The threshold of token distribution between two batches in micro-batch overlap, which determines whether to use two-batch overlap or two-chunk overlap. Set to 0 to disable two-chunk overlap.
| `0.48` | Type: float |\n| `--enable-torch-compile` | Optimize the model with torch.compile. Experimental feature. | `False` | bool flag (set to enable) |\n| `--enable-torch-compile-debug-mode` | Enable debug mode for torch compile. | `False` | bool flag (set to enable) |\n| `--disable-piecewise-cuda-graph` | Disable piecewise CUDA graph for extend/prefill. PCG is enabled by default. | `False` | bool flag (set to disable) |\n| `--enforce-piecewise-cuda-graph` | Enforce piecewise CUDA graph, skipping all auto-disable conditions. For testing only. | `False` | bool flag (set to enable) |\n| `--piecewise-cuda-graph-tokens` | Set the list of tokens when using piecewise CUDA graph. | `None` | Type: JSON list |\n| `--piecewise-cuda-graph-compiler` | Set the compiler for piecewise CUDA graph. Choices are: eager, inductor. | `eager` | `eager`, `inductor` |\n| `--torch-compile-max-bs` | Set the maximum batch size when using torch compile. | `32` | Type: int |\n| `--piecewise-cuda-graph-max-tokens` | Set the maximum number of tokens when using piecewise CUDA graph. | `4096` | Type: int |\n| `--torchao-config` | Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | `` | Type: str |\n| `--enable-nan-detection` | Enable NaN detection for debugging purposes. | `False` | bool flag (set to enable) |\n| `--enable-p2p-check` | Enable the P2P check for GPU access; otherwise, P2P access is allowed by default. | `False` | bool flag (set to enable) |\n| `--triton-attention-reduce-in-fp32` | Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. This only affects Triton attention kernels. | `False` | bool flag (set to enable) |\n| `--triton-attention-num-kv-splits` | The number of KV splits in the flash decoding Triton kernel. Larger values are better for longer contexts. The default value is 8.
| `8` | Type: int |\n| `--triton-attention-split-tile-size` | The size of the split KV tile in the flash decoding Triton kernel. Used for deterministic inference. | `None` | Type: int |\n| `--num-continuous-decode-steps` | Run multiple continuous decoding steps to reduce scheduling overhead. This can potentially increase throughput but may also increase time-to-first-token latency. The default value is 1, meaning only one decoding step is run at a time. | `1` | Type: int |\n| `--delete-ckpt-after-loading` | Delete the model checkpoint after loading the model. | `False` | bool flag (set to enable) |\n| `--enable-memory-saver` | Allow saving memory using release_memory_occupation and resume_memory_occupation. | `False` | bool flag (set to enable) |\n| `--enable-weights-cpu-backup` | Save model weights to CPU memory during release_weights_occupation and resume_weights_occupation. | `False` | bool flag (set to enable) |\n| `--enable-draft-weights-cpu-backup` | Save draft model weights to CPU memory during release_weights_occupation and resume_weights_occupation. | `False` | bool flag (set to enable) |\n| `--allow-auto-truncate` | Allow automatically truncating requests that exceed the maximum input length instead of returning an error. | `False` | bool flag (set to enable) |\n| `--enable-custom-logit-processor` | Enable users to pass custom logit processors to the server (disabled by default for security). | `False` | bool flag (set to enable) |\n| `--flashinfer-mla-disable-ragged` | Do not use the ragged prefill wrapper when running FlashInfer MLA. | `False` | bool flag (set to enable) |\n| `--disable-shared-experts-fusion` | Disable shared experts fusion optimization for DeepSeek V3/R1. | `False` | bool flag (set to enable) |\n| `--disable-chunked-prefix-cache` | Disable the chunked prefix cache feature for DeepSeek, which should save overhead for short sequences.
| `False` | bool flag (set to enable) |\n| `--disable-fast-image-processor` | Use the base image processor instead of the fast image processor. | `False` | bool flag (set to enable) |\n| `--keep-mm-feature-on-device` | Keep multimodal feature tensors on device after processing to save the D2H copy. | `False` | bool flag (set to enable) |\n| `--enable-return-hidden-states` | Enable returning hidden states with responses. | `False` | bool flag (set to enable) |\n| `--enable-return-routed-experts` | Enable returning the routed experts of each layer with responses. | `False` | bool flag (set to enable) |\n| `--scheduler-recv-interval` | The interval for polling requests in the scheduler. Set it to >1 to reduce polling overhead. | `1` | Type: int |\n| `--numa-node` | Set the NUMA node for the subprocesses. The i-th element corresponds to the i-th subprocess. | `None` | List[int] |\n| `--enable-deterministic-inference` | Enable deterministic inference mode with batch-invariant ops. | `False` | bool flag (set to enable) |\n| `--rl-on-policy-target` | The training system that SGLang needs to match for true on-policy. | `None` | `fsdp` |\n| `--enable-attn-tp-input-scattered` | Allow input of attention to be scattered when only using tensor parallelism, to reduce the computational load of operations such as qkv latent. | `False` | bool flag (set to enable) |\n| `--enable-nsa-prefill-context-parallel` | Enable context parallelism used in the long sequence prefill phase of DeepSeek v3.2. | `False` | bool flag (set to enable) |\n| `--nsa-prefill-cp-mode` | Token splitting mode for the prefill phase of DeepSeek v3.2 under context parallelism. Optional values: `round-robin-split` (default), `in-seq-split`. `round-robin-split` distributes tokens across ranks based on `token_idx % cp_size`. It supports multi-batch prefill, fused MoE, and FP8 KV cache. | `in-seq-split` | `in-seq-split`, `round-robin-split` |\n| `--enable-fused-qk-norm-rope` | Enable fused qk normalization and rope rotary embedding. 
| `False` | bool flag (set to enable) |\n| `--enable-precise-embedding-interpolation` | Enable corner alignment when resizing the embedding grid, giving more accurate (but slower) evaluation of interpolated embedding values. | `False` | bool flag (set to enable) |\n\n## Dynamic batch tokenizer\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--enable-dynamic-batch-tokenizer` | Enable async dynamic batch tokenizer for improved performance when multiple requests arrive concurrently. | `False` | bool flag (set to enable) |\n| `--dynamic-batch-tokenizer-batch-size` | [Only used if --enable-dynamic-batch-tokenizer is set] Maximum batch size for dynamic batch tokenizer. | `32` | Type: int |\n| `--dynamic-batch-tokenizer-batch-timeout` | [Only used if --enable-dynamic-batch-tokenizer is set] Timeout in seconds for batching tokenization requests. | `0.002` | Type: float |\n\n## Debug tensor dumps\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--debug-tensor-dump-output-folder` | The output folder for dumping tensors. | `None` | Type: str |\n| `--debug-tensor-dump-layers` | The layer IDs to dump. Dump all layers if not specified. | `None` | Type: JSON list |\n| `--debug-tensor-dump-input-file` | The input filename for dumping tensors. | `None` | Type: str |\n| `--debug-tensor-dump-inject` | Inject the outputs from jax as the input of every layer. | `False` | Type: str |\n\n## PD disaggregation\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--disaggregation-mode` | Only used for PD disaggregation. \"prefill\" for prefill-only server, and \"decode\" for decode-only server. If not specified, it is not PD disaggregated. | `null` | `null`, `prefill`, `decode` |\n| `--disaggregation-transfer-backend` | The backend for disaggregation transfer. Default is mooncake. 
| `mooncake` | `mooncake`, `nixl`, `ascend`, `fake` |\n| `--disaggregation-bootstrap-port` | Bootstrap server port on the prefill server. Default is 8998. | `8998` | Type: int |\n| `--disaggregation-ib-device` | The InfiniBand devices for disaggregation transfer. Accepts a single device (e.g., --disaggregation-ib-device mlx5_0) or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). Default is None, which triggers automatic device detection when the mooncake backend is enabled. | `None` | Type: str |\n| `--disaggregation-decode-enable-offload-kvcache` | Enable async KV cache offloading on decode server (PD mode). | `False` | bool flag (set to enable) |\n| `--num-reserved-decode-tokens` | Number of decode tokens that will have memory reserved when adding a new request to the running batch. | `512` | Type: int |\n| `--disaggregation-decode-polling-interval` | The interval for polling requests in the decode server. Set it to >1 to reduce polling overhead. | `1` | Type: int |\n\n## Encode prefill disaggregation\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--encoder-only` | For MLLM with an encoder, launch an encoder-only server. | `False` | bool flag (set to enable) |\n| `--language-only` | For VLM, load weights for the language model only. | `False` | bool flag (set to enable) |\n| `--encoder-transfer-backend` | The backend for encoder disaggregation transfer. Default is zmq_to_scheduler. | `zmq_to_scheduler` | `zmq_to_scheduler`, `zmq_to_tokenizer`, `mooncake` |\n| `--encoder-urls` | List of encoder server URLs. | `[]` | Type: JSON list |\n\n## Custom weight loader\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--custom-weight-loader` | The custom data loader used to update the model. Should be set to a valid import path, such as my_package.weight_load_func | `None` | List[str] |\n| `--weight-loader-disable-mmap` | Disable mmap while loading weights with safetensors. 
| `False` | bool flag (set to enable) |\n| `--remote-instance-weight-loader-seed-instance-ip` | The ip of the seed instance for loading weights from remote instance. | `None` | Type: str |\n| `--remote-instance-weight-loader-seed-instance-service-port` | The service port of the seed instance for loading weights from remote instance. | `None` | Type: int |\n| `--remote-instance-weight-loader-send-weights-group-ports` | The communication group ports for loading weights from remote instance. | `None` | Type: JSON list |\n| `--remote-instance-weight-loader-backend` | The backend for loading weights from remote instance. Can be 'transfer_engine' or 'nccl'. Default is 'nccl'. | `nccl` | `transfer_engine`, `nccl` |\n| `--remote-instance-weight-loader-start-seed-via-transfer-engine` | Start seed server via transfer engine backend for remote instance weight loader. | `False` | bool flag (set to enable) |\n\n## For PD-Multiplexing\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--enable-pdmux` | Enable PD-Multiplexing, PD running on greenctx stream. | `False` | bool flag (set to enable) |\n| `--pdmux-config-path` | The path of the PD-Multiplexing config file. | `None` | Type: str |\n| `--sm-group-num` | Number of sm partition groups. | `8` | Type: int |\n\n## Configuration file support\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--config` | Read CLI options from a config file. Must be a YAML file with configuration options. | `None` | Type: str |\n\n## For Multi-Modal\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--mm-max-concurrent-calls` | The max concurrent calls for async mm data processing. | `32` | Type: int |\n| `--mm-per-request-timeout` | The timeout for each multi-modal request in seconds. | `10.0` | Type: int |\n| `--enable-broadcast-mm-inputs-process` | Enable broadcast mm-inputs process in scheduler. 
| `False` | bool flag (set to enable) |\n| `--mm-process-config` | Multimodal preprocessing config: a JSON object with the keys `image`, `video`, `audio`. | `{}` | Type: JSON / Dict |\n| `--mm-enable-dp-encoder` | Enable data parallelism for the mm encoder. The dp size will be set to the tp size automatically. | `False` | bool flag (set to enable) |\n| `--limit-mm-data-per-request` | Limit the number of multimodal inputs per request, e.g. '{\"image\": 1, \"video\": 1, \"audio\": 1}' | `None` | Type: JSON / Dict |\n| `--enable-mm-global-cache` | Enable Mooncake-backed global multimodal embedding cache on encoder servers so repeated images can reuse cached ViT embeddings instead of recomputing them. | `False` | bool flag (set to enable) |\n\n## For checkpoint decryption\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--decrypted-config-file` | The path of the decrypted config file. | `None` | Type: str |\n| `--decrypted-draft-config-file` | The path of the decrypted draft config file. | `None` | Type: str |\n| `--enable-prefix-mm-cache` | Enable prefix multimodal cache. Currently only supports mm-only. | `False` | bool flag (set to enable) |\n\n## Forward hooks\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--forward-hooks` | JSON-formatted list of forward hook specifications. Each element must include `target_modules` (list of glob patterns matched against `model.named_modules()` names) and `hook_factory` (Python import path to a factory, e.g. `my_package.hooks:make_hook`). An optional `name` field is used for logging, and an optional `config` object is passed as a `dict` to the factory. | `None` | Type: JSON list |\n\n## Deprecated arguments\n| Argument | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `--enable-ep-moe` | NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead. 
| `None` | N/A |\n| `--enable-deepep-moe` | NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead. | `None` | N/A |\n| `--prefill-round-robin-balance` | NOTE: --prefill-round-robin-balance is deprecated. | `None` | N/A |\n| `--enable-flashinfer-cutlass-moe` | NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead. | `None` | N/A |\n| `--enable-flashinfer-cutedsl-moe` | NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead. | `None` | N/A |\n| `--enable-flashinfer-trtllm-moe` | NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead. | `None` | N/A |\n| `--enable-triton-kernel-moe` | NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead. | `None` | N/A |\n| `--enable-flashinfer-mxfp4-moe` | NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead. | `None` | N/A |\n| `--crash-on-nan` | Crash the server on NaN logprobs. | `False` | Type: str |\n| `--hybrid-kvcache-ratio` | Mix ratio in [0,1] between uniform and hybrid kv buffers (0.0 = pure uniform: swa_size / full_size = 1; 1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length) | `None` | Optional[float] |\n| `--load-watch-interval` | The interval of load watching in seconds. | `0.1` | Type: float |\n| `--nsa-prefill` | Choose the NSA backend for the prefill stage (overrides `--attention-backend` when running DeepSeek NSA-style attention). | `flashmla_sparse` | `flashmla_sparse`, `flashmla_decode`, `fa3`, `tilelang`, `aiter` |\n| `--nsa-decode` | Choose the NSA backend for the decode stage when running DeepSeek NSA-style attention. Overrides `--attention-backend` for decoding. 
| `flashmla_kv` | `flashmla_prefill`, `flashmla_kv`, `fa3`, `tilelang`, `aiter` |\n"
  },
  {
    "path": "docs/advanced_features/sgl_model_gateway.md",
    "content": "# SGLang Model Gateway\n\nSGLang Model Gateway is a high-performance model-routing gateway for large-scale LLM deployments. It centralizes worker lifecycle management, balances traffic across heterogeneous protocols (HTTP, gRPC, OpenAI-compatible), and provides enterprise-ready control over history storage, MCP tooling, and privacy-sensitive workflows. The gateway is deeply optimized for the SGLang serving runtime, but can route to any OpenAI-compatible backend.\n\n---\n\n## Table of Contents\n\n1. [Overview](#overview)\n2. [Architecture](#architecture)\n   - [Control Plane](#control-plane)\n   - [Data Plane](#data-plane)\n   - [Storage and Privacy](#storage-and-privacy)\n3. [Installation](#installation)\n4. [Quick Start](#quick-start)\n5. [Deployment Modes](#deployment-modes)\n   - [Co-launch Router and Workers](#co-launch-router-and-workers)\n   - [Separate Launch (HTTP)](#separate-launch-http)\n   - [gRPC Launch](#grpc-launch)\n   - [Prefill-Decode Disaggregation](#prefill-decode-disaggregation)\n   - [OpenAI Backend Proxy](#openai-backend-proxy)\n   - [Multi-Model Inference Gateway](#multi-model-inference-gateway)\n6. [API Reference](#api-reference)\n   - [Inference Endpoints](#inference-endpoints)\n   - [Tokenization Endpoints](#tokenization-endpoints)\n   - [Parser Endpoints](#parser-endpoints)\n   - [Classification API](#classification-api)\n   - [Conversation and Response APIs](#conversation-and-response-apis)\n   - [Worker Management APIs](#worker-management-apis)\n   - [Admin and Health Endpoints](#admin-and-health-endpoints)\n7. [Load Balancing Policies](#load-balancing-policies)\n8. [Reliability and Flow Control](#reliability-and-flow-control)\n   - [Retries](#retries)\n   - [Circuit Breaker](#circuit-breaker)\n   - [Rate Limiting and Queuing](#rate-limiting-and-queuing)\n   - [Health Checks](#health-checks)\n9. [Reasoning Parser Integration](#reasoning-parser-integration)\n10. [Tool Call Parsing](#tool-call-parsing)\n11. 
[Tokenizer Management](#tokenizer-management)\n12. [MCP Integration](#mcp-integration)\n13. [Service Discovery (Kubernetes)](#service-discovery-kubernetes)\n14. [History and Data Connectors](#history-and-data-connectors)\n15. [WASM Middleware](#wasm-middleware)\n16. [Language Bindings](#language-bindings)\n17. [Security and Authentication](#security-and-authentication)\n    - [TLS (HTTPS) for Gateway Server](#tls-https-for-gateway-server)\n    - [mTLS for Worker Communication](#mtls-for-worker-communication)\n18. [Observability](#observability)\n    - [Prometheus Metrics](#prometheus-metrics)\n    - [OpenTelemetry Tracing](#opentelemetry-tracing)\n    - [Logging](#logging)\n19. [Production Recommendations](#production-recommendations)\n    - [Security Best Practices](#security-best-practices)\n    - [High Availability](#high-availability)\n    - [Performance](#performance)\n    - [Kubernetes Deployment](#kubernetes-deployment)\n    - [Monitoring with PromQL](#monitoring-with-promql)\n20. [Configuration Reference](#configuration-reference)\n21. 
[Troubleshooting](#troubleshooting)\n\n---\n\n## Overview\n\n- **Unified control plane** for registering, monitoring, and orchestrating regular, prefill, and decode workers across heterogeneous model fleets.\n- **Multi-protocol data plane** that routes traffic across HTTP, PD (prefill/decode), gRPC, and OpenAI-compatible backends with shared reliability primitives.\n- **Industry-first gRPC pipeline** with native Rust tokenization, reasoning parsers, and tool-call execution for high-throughput, OpenAI-compatible serving; supports both single-stage and PD topologies.\n- **Inference Gateway Mode (`--enable-igw`)** dynamically instantiates multiple router stacks (HTTP regular/PD, gRPC) and applies per-model policies for multi-tenant deployments.\n- **Conversation & responses connectors** centralize chat history inside the router so the same context can be reused across models and MCP loops without leaking data to upstream vendors (memory, none, Oracle ATP, PostgreSQL).\n- **Enterprise privacy**: agentic multi-turn `/v1/responses`, native MCP client (STDIO/HTTP/SSE/Streamable), and history storage all operate within the router boundary.\n- **Reliability core**: retries with jitter, worker-scoped circuit breakers, token-bucket rate limiting with queuing, background health checks, and cache-aware load monitoring.\n- **Comprehensive observability**: 40+ Prometheus metrics, OpenTelemetry distributed tracing, structured logging, and request ID propagation.\n\n---\n\n## Architecture\n\n### Control Plane\n\n- **Worker Manager** discovers capabilities (`/get_server_info`, `/get_model_info`), tracks load, and registers/removes workers in the shared registry.\n- **Job Queue** serializes add/remove requests and exposes status (`/workers/{worker_id}`) so clients can track onboarding progress.\n- **Load Monitor** feeds cache-aware and power-of-two policies with live worker load statistics.\n- **Health Checker** continuously probes workers and updates readiness, circuit breaker 
state, and router metrics.\n- **Tokenizer Registry** manages dynamically registered tokenizers with async loading from HuggingFace or local paths.\n\n### Data Plane\n\n- **HTTP routers** (regular & PD) implement `/generate`, `/v1/chat/completions`, `/v1/completions`, `/v1/responses`, `/v1/embeddings`, `/v1/rerank`, `/v1/classify`, `/v1/tokenize`, `/v1/detokenize`, and associated admin endpoints.\n- **gRPC router** streams tokenized requests directly to SRT gRPC workers, running fully in Rust—tokenizer, reasoning parser, and tool parser all reside in-process. Supports both single-stage and PD routing, including embeddings and classification.\n- **OpenAI router** proxies OpenAI-compatible endpoints to external vendors (OpenAI, xAI, etc.) while keeping chat history and multi-turn orchestration local.\n\n### Storage and Privacy\n\n- Conversation and response history is stored at the router tier (memory, none, Oracle ATP, or PostgreSQL). The same history can power multiple models or MCP loops without sending data to upstream vendors.\n- `/v1/responses` agentic flows, MCP sessions, and conversation APIs share the same storage layer, enabling compliance for regulated workloads.\n\n---\n\n## Installation\n\n### Docker\n\nPre-built Docker images are available on Docker Hub with multi-architecture support (x86_64 and ARM64):\n\n```bash\ndocker pull lmsysorg/sgl-model-gateway:latest\n```\n\n### Prerequisites\n\n- **Rust and Cargo**\n  ```bash\n  curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh\n  source \"$HOME/.cargo/env\"\n  rustc --version\n  cargo --version\n  ```\n- **Python** with `pip` and virtualenv tooling available.\n\n### Rust Binary\n\n```bash\ncd sgl-model-gateway\ncargo build --release\n```\n\n### Python Package\n\n```bash\npip install maturin\n\n# Fast development mode\ncd sgl-model-gateway/bindings/python\nmaturin develop\n\n# Production build\nmaturin build --release --out dist --features vendored-openssl\npip install --force-reinstall 
dist/*.whl\n```\n\n---\n\n## Quick Start\n\n### Regular HTTP Routing\n\n```bash\n# Rust binary\n./target/release/sgl-model-gateway \\\n  --worker-urls http://worker1:8000 http://worker2:8000 \\\n  --policy cache_aware\n\n# Python launcher\npython -m sglang_router.launch_router \\\n  --worker-urls http://worker1:8000 http://worker2:8000 \\\n  --policy cache_aware\n```\n\n### gRPC Routing\n\n```bash\npython -m sglang_router.launch_router \\\n  --worker-urls grpc://127.0.0.1:20000 \\\n  --model-path meta-llama/Llama-3.1-8B-Instruct \\\n  --reasoning-parser deepseek-r1 \\\n  --tool-call-parser json \\\n  --host 0.0.0.0 --port 8080\n```\n\n---\n\n## Deployment Modes\n\n### Co-launch Router and Workers\n\nLaunch the router and a fleet of SGLang workers in one process:\n\n```bash\npython -m sglang_router.launch_server \\\n  --model meta-llama/Meta-Llama-3.1-8B-Instruct \\\n  --dp-size 4 \\\n  --host 0.0.0.0 \\\n  --port 30000\n```\n\nComprehensive example with router arguments (prefixed with `--router-`):\n\n```bash\npython -m sglang_router.launch_server \\\n  --host 0.0.0.0 \\\n  --port 8080 \\\n  --model meta-llama/Llama-3.1-8B-Instruct \\\n  --tp-size 1 \\\n  --dp-size 8 \\\n  --grpc-mode \\\n  --log-level debug \\\n  --router-prometheus-port 10001 \\\n  --router-tool-call-parser llama \\\n  --router-model-path meta-llama/Llama-3.1-8B-Instruct \\\n  --router-policy round_robin \\\n  --router-log-level debug\n```\n\n### Separate Launch (HTTP)\n\nRun workers independently and point the router at their HTTP endpoints:\n\n```bash\n# Worker nodes\npython -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --port 8000\npython -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --port 8001\n\n# Router node\npython -m sglang_router.launch_router \\\n  --worker-urls http://worker1:8000 http://worker2:8001 \\\n  --policy cache_aware \\\n  --host 0.0.0.0 --port 30000\n```\n\n### gRPC Launch\n\nUse SRT gRPC workers to unlock the highest 
throughput and access native reasoning/tool pipelines:\n\n```bash\n# Workers expose gRPC endpoints\npython -m sglang.launch_server \\\n  --model meta-llama/Llama-3.1-8B-Instruct \\\n  --grpc-mode \\\n  --port 20000\n\n# Router\npython -m sglang_router.launch_router \\\n  --worker-urls grpc://127.0.0.1:20000 \\\n  --model-path meta-llama/Llama-3.1-8B-Instruct \\\n  --reasoning-parser deepseek-r1 \\\n  --tool-call-parser json \\\n  --host 0.0.0.0 --port 8080\n```\n\nThe gRPC router supports both regular HTTP-equivalent serving and PD (prefill/decode) serving. Provide `--tokenizer-path` or `--model-path` (HuggingFace ID or local directory) whenever connection mode resolves to gRPC.\n\n### Prefill-Decode Disaggregation\n\nSplit prefill and decode workers for PD-aware caching and balancing:\n\n```bash\npython -m sglang_router.launch_router \\\n  --pd-disaggregation \\\n  --prefill http://prefill1:30001 9001 \\\n  --decode http://decode1:30011 \\\n  --prefill-policy cache_aware \\\n  --decode-policy power_of_two\n```\n\nPrefill entries accept an optional bootstrap port. 
PD mode merges prefill metadata with decode outputs and streams results back to the client.\n\n### OpenAI Backend Proxy\n\nProxy OpenAI-compatible endpoints while keeping history and MCP sessions local:\n\n```bash\npython -m sglang_router.launch_router \\\n  --backend openai \\\n  --worker-urls https://api.openai.com \\\n  --history-backend memory\n```\n\nOpenAI backend mode expects exactly one `--worker-urls` entry per router instance.\n\n### Multi-Model Inference Gateway\n\nEnable IGW mode to route multiple models through a single router:\n\n```bash\n./target/release/sgl-model-gateway \\\n  --enable-igw \\\n  --policy cache_aware \\\n  --max-concurrent-requests 512\n\n# Register workers dynamically\ncurl -X POST http://localhost:30000/workers \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n        \"url\": \"http://worker-a:8000\",\n        \"model_id\": \"mistral\",\n        \"priority\": 10,\n        \"labels\": {\"tier\": \"gold\"}\n      }'\n```\n\n---\n\n## API Reference\n\n### Inference Endpoints\n\n| Method | Path | Description |\n|--------|------|-------------|\n| `POST` | `/generate` | SGLang generate API |\n| `POST` | `/v1/chat/completions` | OpenAI-compatible chat completions (streaming/tool calls) |\n| `POST` | `/v1/completions` | OpenAI-compatible text completions |\n| `POST` | `/v1/embeddings` | Embedding generation (HTTP and gRPC) |\n| `POST` | `/v1/rerank`, `/rerank` | Reranking requests |\n| `POST` | `/v1/classify` | Text classification |\n\n### Tokenization Endpoints\n\nThe gateway provides HTTP endpoints for text tokenization with batch support, designed to mirror the SGLang Python tokenization API.\n\n| Method | Path | Description |\n|--------|------|-------------|\n| `POST` | `/v1/tokenize` | Tokenize text to token IDs (single or batch) |\n| `POST` | `/v1/detokenize` | Convert token IDs back to text (single or batch) |\n| `POST` | `/v1/tokenizers` | Register a new tokenizer (async, returns job status) |\n| `GET` | `/v1/tokenizers` | 
List all registered tokenizers |\n| `GET` | `/v1/tokenizers/{id}` | Get tokenizer info by UUID |\n| `GET` | `/v1/tokenizers/{id}/status` | Check async tokenizer loading status |\n| `DELETE` | `/v1/tokenizers/{id}` | Remove a tokenizer from the registry |\n\n#### Tokenize Request\n\n```json\n{\n  \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n  \"prompt\": \"Hello, world!\"\n}\n```\n\n#### Batch Tokenize Request\n\n```json\n{\n  \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n  \"prompt\": [\"Hello\", \"World\", \"How are you?\"]\n}\n```\n\n#### Tokenize Response\n\n```json\n{\n  \"tokens\": [15339, 11, 1917, 0],\n  \"count\": 4,\n  \"char_count\": 13\n}\n```\n\n#### Detokenize Request\n\n```json\n{\n  \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n  \"tokens\": [15339, 11, 1917, 0],\n  \"skip_special_tokens\": true\n}\n```\n\n#### Detokenize Response\n\n```json\n{\n  \"text\": \"Hello, world!\"\n}\n```\n\n#### Add Tokenizer (Async)\n\n```bash\ncurl -X POST http://localhost:30000/v1/tokenizers \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\"name\": \"llama3\", \"source\": \"meta-llama/Llama-3.1-8B-Instruct\"}'\n```\n\nResponse:\n```json\n{\n  \"id\": \"550e8400-e29b-41d4-a716-446655440000\",\n  \"status\": \"pending\",\n  \"message\": \"Tokenizer registration queued\"\n}\n```\n\nCheck status:\n```bash\ncurl http://localhost:30000/v1/tokenizers/550e8400-e29b-41d4-a716-446655440000/status\n```\n\n### Parser Endpoints\n\nThe gateway provides admin endpoints for parsing reasoning content and function calls from LLM outputs.\n\n| Method | Path | Description |\n|--------|------|-------------|\n| `POST` | `/parse/reasoning` | Separate reasoning (`<think>`) from normal text |\n| `POST` | `/parse/function_call` | Parse function/tool calls from text |\n\n#### Separate Reasoning Request\n\n```json\n{\n  \"text\": \"<think>Let me analyze this step by step...</think>The answer is 42.\",\n  \"parser\": \"deepseek-r1\"\n}\n```\n\n#### Response\n\n```json\n{\n  
\"normal_text\": \"The answer is 42.\",\n  \"reasoning_text\": \"Let me analyze this step by step...\"\n}\n```\n\n#### Function Call Parsing\n\n```json\n{\n  \"text\": \"{\\\"name\\\": \\\"get_weather\\\", \\\"arguments\\\": {\\\"city\\\": \\\"NYC\\\"}}\",\n  \"parser\": \"json\"\n}\n```\n\n### Classification API\n\nThe `/v1/classify` endpoint provides text classification using sequence classification models (e.g., `Qwen2ForSequenceClassification`, `BertForSequenceClassification`).\n\n#### Request\n\n```bash\ncurl http://localhost:30000/v1/classify \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"jason9693/Qwen2.5-1.5B-apeach\",\n    \"input\": \"I love this product!\"\n  }'\n```\n\n#### Response\n\n```json\n{\n  \"id\": \"classify-a1b2c3d4-5678-90ab-cdef-1234567890ab\",\n  \"object\": \"list\",\n  \"created\": 1767034308,\n  \"model\": \"jason9693/Qwen2.5-1.5B-apeach\",\n  \"data\": [\n    {\n      \"index\": 0,\n      \"label\": \"positive\",\n      \"probs\": [0.12, 0.88],\n      \"num_classes\": 2\n    }\n  ],\n  \"usage\": {\n    \"prompt_tokens\": 6,\n    \"completion_tokens\": 0,\n    \"total_tokens\": 6\n  }\n}\n```\n\n#### Response Fields\n\n| Field | Description |\n|-------|-------------|\n| `label` | Predicted class label (from model's `id2label` config, or `LABEL_N` fallback) |\n| `probs` | Probability distribution over all classes (softmax of logits) |\n| `num_classes` | Number of classification classes |\n\n#### Notes\n\n- Classification reuses the embedding backend—the scheduler returns logits which are converted to probabilities via softmax\n- Labels come from the model's HuggingFace config (`id2label` field); models without this mapping use generic labels (`LABEL_0`, `LABEL_1`, etc.)\n- Both HTTP and gRPC routers support classification\n\n### Conversation and Response APIs\n\n| Method | Path | Description |\n|--------|------|-------------|\n| `POST` | `/v1/responses` | Create background responses (agentic loops) |\n| `GET` 
| `/v1/responses/{id}` | Retrieve stored response |\n| `POST` | `/v1/responses/{id}/cancel` | Cancel background response |\n| `DELETE` | `/v1/responses/{id}` | Delete response |\n| `GET` | `/v1/responses/{id}/input_items` | List response input items |\n| `POST` | `/v1/conversations` | Create conversation |\n| `GET` | `/v1/conversations/{id}` | Get conversation |\n| `POST` | `/v1/conversations/{id}` | Update conversation |\n| `DELETE` | `/v1/conversations/{id}` | Delete conversation |\n| `GET` | `/v1/conversations/{id}/items` | List conversation items |\n| `POST` | `/v1/conversations/{id}/items` | Add items to conversation |\n| `GET` | `/v1/conversations/{id}/items/{item_id}` | Get conversation item |\n| `DELETE` | `/v1/conversations/{id}/items/{item_id}` | Delete conversation item |\n\n### Worker Management APIs\n\n| Method | Path | Description |\n|--------|------|-------------|\n| `POST` | `/workers` | Queue worker registration (returns 202 Accepted) |\n| `GET` | `/workers` | List workers with health, load, and policy metadata |\n| `GET` | `/workers/{worker_id}` | Inspect specific worker or job queue entry |\n| `PUT` | `/workers/{worker_id}` | Queue worker update |\n| `DELETE` | `/workers/{worker_id}` | Queue worker removal |\n\n#### Add Worker\n\n```bash\ncurl -X POST http://localhost:30000/workers \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\"url\":\"grpc://0.0.0.0:31000\",\"worker_type\":\"regular\"}'\n```\n\n#### List Workers\n\n```bash\ncurl http://localhost:30000/workers\n```\n\nResponse:\n```json\n{\n  \"workers\": [\n    {\n      \"id\": \"2f3a0c3e-3a7b-4c3f-8c70-1b7d4c3a6e1f\",\n      \"url\": \"http://0.0.0.0:31378\",\n      \"model_id\": \"mistral\",\n      \"priority\": 50,\n      \"cost\": 1.0,\n      \"worker_type\": \"regular\",\n      \"is_healthy\": true,\n      \"load\": 0,\n      \"connection_mode\": \"Http\"\n    }\n  ],\n  \"total\": 1,\n  \"stats\": {\n    \"prefill_count\": 0,\n    \"decode_count\": 0,\n    \"regular_count\": 1\n  
}\n}\n```\n\n### Admin and Health Endpoints\n\n| Method | Path | Description |\n|--------|------|-------------|\n| `GET` | `/liveness` | Health check (always returns OK) |\n| `GET` | `/readiness` | Readiness check (checks healthy worker availability) |\n| `GET` | `/health` | Alias for liveness |\n| `GET` | `/health_generate` | Health generate test |\n| `GET` | `/engine_metrics` | Engine-level metrics from workers |\n| `GET` | `/v1/models` | List available models |\n| `GET` | `/get_model_info` | Get model information |\n| `GET` | `/get_server_info` | Get server information |\n| `POST` | `/flush_cache` | Clear all caches |\n| `GET` | `/get_loads` | Get all worker loads |\n| `POST` | `/wasm` | Upload WASM module |\n| `GET` | `/wasm` | List WASM modules |\n| `DELETE` | `/wasm/{module_uuid}` | Remove WASM module |\n\n---\n\n## Load Balancing Policies\n\n| Policy | Description | Usage |\n|--------|-------------|-------|\n| `random` | Uniform random selection | `--policy random` |\n| `round_robin` | Cycles through workers in order | `--policy round_robin` |\n| `power_of_two` | Samples two workers and picks the lighter one | `--policy power_of_two` |\n| `cache_aware` | Combines cache locality with load balancing (default) | `--policy cache_aware` |\n| `bucket` | Divides workers into load buckets with dynamic boundaries | `--policy bucket` |\n\n### Cache-Aware Policy Tuning\n\n```bash\n--cache-threshold 0.5 \\\n--balance-abs-threshold 32 \\\n--balance-rel-threshold 1.5 \\\n--eviction-interval-secs 120 \\\n--max-tree-size 67108864\n```\n\n| Parameter | Default | Description |\n|-----------|---------|-------------|\n| `--cache-threshold` | 0.3 | Minimum prefix match ratio for cache hit |\n| `--balance-abs-threshold` | 64 | Absolute load difference before rebalancing |\n| `--balance-rel-threshold` | 1.5 | Relative load ratio before rebalancing |\n| `--eviction-interval-secs` | 120 | Cache eviction cadence in seconds |\n| `--max-tree-size` | 67108864 | Maximum nodes in cache 
tree |\n\n---\n\n## Reliability and Flow Control\n\n### Retries\n\nConfigure exponential backoff retries:\n\n```bash\npython -m sglang_router.launch_router \\\n  --worker-urls http://worker1:8000 http://worker2:8001 \\\n  --retry-max-retries 5 \\\n  --retry-initial-backoff-ms 50 \\\n  --retry-max-backoff-ms 30000 \\\n  --retry-backoff-multiplier 1.5 \\\n  --retry-jitter-factor 0.2\n```\n\n| Parameter | Default | Description |\n|-----------|---------|-------------|\n| `--retry-max-retries` | 5 | Maximum retry attempts |\n| `--retry-initial-backoff-ms` | 50 | Initial backoff duration (ms) |\n| `--retry-max-backoff-ms` | 5000 | Maximum backoff duration (ms) |\n| `--retry-backoff-multiplier` | 2.0 | Exponential backoff multiplier |\n| `--retry-jitter-factor` | 0.1 | Random jitter factor (0.0-1.0) |\n| `--disable-retries` | false | Disable retries entirely |\n\n**Retryable Status Codes:** 408, 429, 500, 502, 503, 504\n\n### Circuit Breaker\n\nPer-worker circuit breakers prevent cascading failures:\n\n```bash\npython -m sglang_router.launch_router \\\n  --worker-urls http://worker1:8000 http://worker2:8001 \\\n  --cb-failure-threshold 5 \\\n  --cb-success-threshold 2 \\\n  --cb-timeout-duration-secs 30 \\\n  --cb-window-duration-secs 60\n```\n\n| Parameter | Default | Description |\n|-----------|---------|-------------|\n| `--cb-failure-threshold` | 5 | Consecutive failures to open circuit |\n| `--cb-success-threshold` | 2 | Successes to close from half-open |\n| `--cb-timeout-duration-secs` | 30 | Time before half-open attempt |\n| `--cb-window-duration-secs` | 60 | Failure counting window |\n| `--disable-circuit-breaker` | false | Disable circuit breaker |\n\n**Circuit Breaker States:**\n- **Closed**: Normal operation, requests allowed\n- **Open**: Failing, requests rejected immediately\n- **Half-Open**: Testing recovery, limited requests allowed\n\n### Rate Limiting and Queuing\n\n```bash\npython -m sglang_router.launch_router \\\n  --worker-urls http://worker1:8000 
http://worker2:8001 \\\n  --max-concurrent-requests 256 \\\n  --rate-limit-tokens-per-second 512 \\\n  --queue-size 128 \\\n  --queue-timeout-secs 30\n```\n\nRequests beyond the concurrency limit wait in a FIFO queue. The gateway returns:\n- `429 Too Many Requests` when the queue is full\n- `408 Request Timeout` when the queue timeout expires\n\n### Health Checks\n\n```bash\n--health-check-interval-secs 30 \\\n--health-check-timeout-secs 10 \\\n--health-success-threshold 2 \\\n--health-failure-threshold 3 \\\n--health-check-endpoint /health\n```\n\n---\n\n## Reasoning Parser Integration\n\nThe gateway includes built-in reasoning parsers for models that use Chain-of-Thought (CoT) reasoning with explicit thinking blocks.\n\n### Supported Parsers\n\n| Parser ID | Model Family | Think Tokens |\n|-----------|--------------|--------------|\n| `deepseek-r1` | DeepSeek-R1 | `<think>...</think>` (initial reasoning) |\n| `qwen3` | Qwen-3 | `<think>...</think>` |\n| `qwen3-thinking` | Qwen-3 Thinking | `<think>...</think>` (initial reasoning) |\n| `kimi` | Kimi K2 | Unicode think tokens |\n| `glm45` | GLM-4.5/4.6/4.7 | `<think>...</think>` |\n| `step3` | Step-3 | `<think>...</think>` |\n| `minimax` | MiniMax | `<think>...</think>` |\n\n### Usage\n\n```bash\npython -m sglang_router.launch_router \\\n  --worker-urls grpc://127.0.0.1:20000 \\\n  --model-path deepseek-ai/DeepSeek-R1 \\\n  --reasoning-parser deepseek-r1\n```\n\nThe gRPC router automatically:\n1. Detects reasoning blocks in streaming output\n2. Separates reasoning content from normal text\n3. Applies incremental streaming parsing with buffer management\n4. 
Handles partial token detection for correct streaming behavior\n\n---\n\n## Tool Call Parsing\n\nThe gateway supports parsing function/tool calls from LLM outputs in multiple formats.\n\n### Supported Formats\n\n| Parser | Format | Description |\n|--------|--------|-------------|\n| `json` | JSON | Standard JSON tool calls |\n| `python` | Pythonic | Python function call syntax |\n| `xml` | XML | XML-formatted tool calls |\n\n### Usage\n\n```bash\npython -m sglang_router.launch_router \\\n  --worker-urls grpc://127.0.0.1:20000 \\\n  --model-path meta-llama/Llama-3.1-8B-Instruct \\\n  --tool-call-parser json\n```\n\n---\n\n## Tokenizer Management\n\n### Tokenizer Sources\n\nThe gateway supports multiple tokenizer backends:\n- **HuggingFace**: Load from HuggingFace Hub by model ID\n- **Local**: Load from local `tokenizer.json` or directory\n- **Tiktoken**: Auto-detect OpenAI GPT models (gpt-4, davinci, etc.)\n\n### Configuration\n\n```bash\n# HuggingFace model\n--model-path meta-llama/Llama-3.1-8B-Instruct\n\n# Local tokenizer\n--tokenizer-path /path/to/tokenizer.json\n\n# With chat template override\n--chat-template /path/to/template.jinja\n```\n\n### Tokenizer Caching\n\nTwo-level caching for optimal performance:\n\n| Cache | Type | Description |\n|-------|------|-------------|\n| L0 | Exact match | Whole-string caching for repeated prompts |\n| L1 | Prefix match | Prefix boundary matching for incremental prompts |\n\n```bash\n--enable-l0-cache \\\n--l0-max-entries 10000 \\\n--enable-l1-cache \\\n--l1-max-memory 52428800  # 50MB\n```\n\n---\n\n## MCP Integration\n\nThe gateway provides native Model Context Protocol (MCP) client integration for tool execution.\n\n### Supported Transports\n\n| Transport | Description |\n|-----------|-------------|\n| STDIO | Local process execution |\n| SSE | Server-Sent Events (HTTP) |\n| Streamable | Bidirectional streaming |\n\n### Configuration\n\n```bash\npython -m sglang_router.launch_router \\\n  --mcp-config-path 
/path/to/mcp-config.yaml \\\n  --worker-urls http://worker1:8000\n```\n\n### MCP Configuration File\n\n```yaml\nservers:\n  - name: \"filesystem\"\n    command: \"npx\"\n    args: [\"-y\", \"@modelcontextprotocol/server-filesystem\", \"/tmp\"]\n    protocol: \"stdio\"\n    required: false\n\n  - name: \"github\"\n    url: \"https://api.github.com/mcp\"\n    token: \"ghp_xxxxx\"\n    protocol: \"sse\"\n    required: false\n\n  - name: \"custom-tools\"\n    url: \"https://tools.example.com/mcp\"\n    protocol: \"streamable\"\n    required: true\n\npool:\n  max_connections: 100\n  idle_timeout: 300\n\nproxy:\n  http: \"http://proxy.internal:8080\"\n  https: \"https://proxy.internal:8443\"\n  no_proxy: \"localhost,127.0.0.1,*.internal\"\n\ninventory:\n  enable_refresh: true\n  tool_ttl: 300\n  refresh_interval: 300\n```\n\n---\n\n## Service Discovery (Kubernetes)\n\nEnable automatic worker discovery via Kubernetes pod selectors:\n\n```bash\npython -m sglang_router.launch_router \\\n  --service-discovery \\\n  --selector app=sglang-worker role=inference \\\n  --service-discovery-namespace production \\\n  --service-discovery-port 8000\n```\n\n### PD Mode Discovery\n\n```bash\n--pd-disaggregation \\\n--prefill-selector app=sglang component=prefill \\\n--decode-selector app=sglang component=decode \\\n--service-discovery\n```\n\nPrefill pods can expose bootstrap ports via the `sglang.ai/bootstrap-port` annotation. 
RBAC must allow `get`, `list`, and `watch` on pods.\n\n---\n\n## History and Data Connectors\n\n| Backend | Description | Usage |\n|---------|-------------|-------|\n| `memory` | In-memory storage (default) | `--history-backend memory` |\n| `none` | No persistence | `--history-backend none` |\n| `oracle` | Oracle Autonomous Database | `--history-backend oracle` |\n| `postgres` | PostgreSQL Database | `--history-backend postgres` |\n| `redis` | Redis | `--history-backend redis` |\n\n### Oracle Configuration\n\n```bash\n# Connection descriptor\nexport ATP_DSN=\"(description=(address=(protocol=tcps)(port=1522)(host=adb.region.oraclecloud.com))(connect_data=(service_name=service_name)))\"\n\n# Or TNS alias (requires wallet)\nexport ATP_TNS_ALIAS=\"sglroutertestatp_high\"\nexport ATP_WALLET_PATH=\"/path/to/wallet\"\n\n# Credentials\nexport ATP_USER=\"admin\"\nexport ATP_PASSWORD=\"secret\"\nexport ATP_POOL_MIN=4\nexport ATP_POOL_MAX=32\n\npython -m sglang_router.launch_router \\\n  --backend openai \\\n  --worker-urls https://api.openai.com \\\n  --history-backend oracle\n```\n\n### PostgreSQL Configuration\n\n```bash\nexport POSTGRES_DB_URL=\"postgres://user:password@host:5432/dbname\"\n\npython -m sglang_router.launch_router \\\n  --backend openai \\\n  --worker-urls https://api.openai.com \\\n  --history-backend postgres\n```\n\n### Redis Configuration\n\n```bash\nexport REDIS_URL=\"redis://localhost:6379\"\nexport REDIS_POOL_MAX=16\nexport REDIS_RETENTION_DAYS=30\n\npython -m sglang_router.launch_router \\\n  --backend openai \\\n  --worker-urls https://api.openai.com \\\n  --history-backend redis \\\n  --redis-retention-days 30\n```\n\nUse `--redis-retention-days -1` for persistent storage (default is 30 days).\n\n---\n\n## WASM Middleware\n\nThe gateway supports WebAssembly (WASM) middleware modules for custom request/response processing. 
This enables organization-specific logic for authentication, rate limiting, billing, logging, and more—without modifying or recompiling the gateway.\n\n### Overview\n\nWASM middleware runs in a sandboxed environment with memory isolation, no network/filesystem access, and configurable resource limits.\n\n| Attach Point | When Executed | Use Cases |\n|--------------|---------------|-----------|\n| `OnRequest` | Before forwarding to workers | Auth, rate limiting, request modification |\n| `OnResponse` | After receiving worker response | Logging, response modification, error handling |\n\n| Action | Description |\n|--------|-------------|\n| `Continue` | Proceed without modification |\n| `Reject(status)` | Reject request with HTTP status code |\n| `Modify(...)` | Modify headers, body, or status |\n\n### Examples\n\nComplete working examples are available in `examples/wasm/`:\n\n| Example | Description |\n|---------|-------------|\n| `auth/` | API key authentication for protected routes |\n| `rate_limit/` | Per-client rate limiting (requests/minute) |\n| `logging/` | Request tracking headers and response modification |\n\nThe interface definition is located at `src/wasm/interface`.\n\n### Building Modules\n\n```bash\n# Prerequisites\nrustup target add wasm32-wasip2\ncargo install wasm-tools\n\n# Build\ncargo build --target wasm32-wasip2 --release\n\n# Convert to component format\nwasm-tools component new \\\n  target/wasm32-wasip2/release/my_middleware.wasm \\\n  -o my_middleware.component.wasm\n```\n\n### Deploying Modules\n\n```bash\n# Enable WASM support\npython -m sglang_router.launch_router \\\n  --worker-urls http://worker1:8000 \\\n  --enable-wasm\n\n# Upload module\ncurl -X POST http://localhost:30000/wasm \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"modules\": [{\n      \"name\": \"auth-middleware\",\n      \"file_path\": \"/absolute/path/to/auth.component.wasm\",\n      \"module_type\": \"Middleware\",\n      \"attach_points\": 
[{\"Middleware\": \"OnRequest\"}]\n    }]\n  }'\n\n# List modules\ncurl http://localhost:30000/wasm\n\n# Remove module\ncurl -X DELETE http://localhost:30000/wasm/{module_uuid}\n```\n\n### Runtime Configuration\n\n| Parameter | Default | Description |\n|-----------|---------|-------------|\n| `max_memory_pages` | 1024 (64MB) | Maximum WASM memory |\n| `max_execution_time_ms` | 1000 | Execution timeout |\n| `max_stack_size` | 1MB | Stack size limit |\n| `module_cache_size` | 10 | Cached modules per worker |\n\n**Note:** Rate limiting state is kept per worker thread and is not shared across gateway replicas. For production, consider implementing rate limiting at a shared layer (e.g., Redis).\n\n---\n\n## Language Bindings\n\nSGLang Model Gateway provides official language bindings for Python and Go, enabling integration with different technology stacks and organizational requirements.\n\n### Python Bindings\n\nThe Python bindings provide a PyO3-based wrapper around the Rust gateway library. It is a thin binding that starts the gateway server from Python.\n\n#### Installation\n\n```bash\n# From PyPI\npip install sglang-router\n\n# Development build\ncd sgl-model-gateway/bindings/python\npip install maturin && maturin develop --features vendored-openssl\n```\n\n#### Usage\n\nThe Python bindings are used throughout this documentation. See the [Quick Start](#quick-start) and [Deployment Modes](#deployment-modes) sections for detailed examples.\n\nKey components:\n- `RouterArgs` dataclass with 50+ configuration options\n- `Router.from_args()` for programmatic startup\n- CLI commands: `smg launch`, `smg server`, `python -m sglang_router.launch_router`\n\n### Go Bindings\n\nThe Go bindings provide a high-performance gRPC client library for organizations with Go-based infrastructure. 
This is ideal for:\n\n- Integration with internal Go services and tooling\n- High-performance client applications\n- Building custom OpenAI-compatible proxy servers\n\n#### Architecture\n\n```\n┌─────────────────────────────────────────┐\n│         High-Level Go API               │\n│   (client.go - OpenAI-style interface)  │\n├─────────────────────────────────────────┤\n│         gRPC Layer                      │\n├─────────────────────────────────────────┤\n│         Rust FFI Layer                  │\n│   (Tokenization, Parsing, Conversion)   │\n└─────────────────────────────────────────┘\n```\n\n**Key Features:**\n- Native Rust tokenization via FFI (thread-safe, lock-free)\n- Full streaming support with context cancellation\n- Configurable channel buffer sizes for high concurrency\n- Built-in tool call parsing and chat template application\n\n#### Installation\n\n```bash\n# Build the FFI library first\ncd sgl-model-gateway/bindings/golang\nmake build && make lib\n\n# Then use in your Go project\ngo get github.com/sgl-project/sgl-go-sdk\n```\n\n**Requirements:** Go 1.24+, Rust toolchain\n\n#### Examples\n\nComplete working examples are available in `bindings/golang/examples/`:\n\n| Example | Description |\n|---------|-------------|\n| `simple/` | Non-streaming chat completion |\n| `streaming/` | Streaming chat completion with SSE |\n| `oai_server/` | Full OpenAI-compatible HTTP server |\n\n```bash\n# Run examples\ncd sgl-model-gateway/bindings/golang/examples/simple && ./run.sh\ncd sgl-model-gateway/bindings/golang/examples/streaming && ./run.sh\ncd sgl-model-gateway/bindings/golang/examples/oai_server && ./run.sh\n```\n\n#### Testing\n\n```bash\ncd sgl-model-gateway/bindings/golang\n\n# Unit tests\ngo test -v ./...\n\n# Integration tests (requires running SGLang server)\nexport SGL_GRPC_ENDPOINT=grpc://localhost:20000\nexport SGL_TOKENIZER_PATH=/path/to/tokenizer\ngo test -tags=integration -v ./...\n```\n\n### Comparison\n\n| Feature | Python | Go 
|\n|---------|--------|-----|\n| **Primary Use** | Gateway server launcher | gRPC client library |\n| **CLI Support** | Full CLI (smg, sglang-router) | Library only |\n| **K8s Discovery** | Native support | N/A (client library) |\n| **PD Mode** | Built-in | N/A (client library) |\n\n**When to Use Python:** Launching and managing the gateway server, service discovery, PD disaggregation.\n\n**When to Use Go:** Building custom client applications, integration with Go microservices, OpenAI-compatible proxy servers.\n\n---\n\n## Security and Authentication\n\n### Router API Key\n\n```bash\npython -m sglang_router.launch_router \\\n  --api-key \"your-router-api-key\" \\\n  --worker-urls http://worker1:8000\n```\n\nClients must send an `Authorization: Bearer <key>` header to access protected endpoints.\n\n### Worker API Keys\n\n```bash\n# Add worker with explicit key\ncurl -H \"Authorization: Bearer router-key\" \\\n  -X POST http://localhost:8080/workers \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\"url\":\"http://worker:8000\",\"api_key\":\"worker-key\"}'\n```\n\n### Security Configurations\n\n1. **No Authentication** (default): Use only in trusted environments\n2. **Router-only Authentication**: Clients authenticate to router\n3. **Worker-only Authentication**: Router open, workers require keys\n4. **Full Authentication**: Both router and workers protected\n\n### TLS (HTTPS) for Gateway Server\n\nEnable TLS to serve the gateway over HTTPS:\n\n```bash\npython -m sglang_router.launch_router \\\n  --worker-urls http://worker1:8000 \\\n  --tls-cert-path /path/to/server.crt \\\n  --tls-key-path /path/to/server.key\n```\n\n| Parameter | Description |\n|-----------|-------------|\n| `--tls-cert-path` | Path to server certificate (PEM format) |\n| `--tls-key-path` | Path to server private key (PEM format) |\n\nBoth parameters must be provided together. The gateway uses rustls with the ring crypto provider for TLS termination. 
If TLS is not configured, the gateway falls back to plain HTTP.\n\n### mTLS for Worker Communication\n\nEnable mutual TLS (mTLS) for secure communication with workers in HTTP mode:\n\n```bash\npython -m sglang_router.launch_router \\\n  --worker-urls https://worker1:8443 https://worker2:8443 \\\n  --client-cert-path /path/to/client.crt \\\n  --client-key-path /path/to/client.key \\\n  --ca-cert-path /path/to/ca.crt\n```\n\n| Parameter | Description |\n|-----------|-------------|\n| `--client-cert-path` | Path to client certificate for mTLS (PEM format) |\n| `--client-key-path` | Path to client private key for mTLS (PEM format) |\n| `--ca-cert-path` | Path to CA certificate for verifying worker TLS (PEM format, repeatable) |\n\n**Key Points:**\n- Client certificate and key must be provided together\n- Multiple CA certificates can be added with multiple `--ca-cert-path` flags\n- Uses rustls backend when TLS is configured\n- Single HTTP client is created for all workers (assumes single security domain)\n- TCP keepalive (30 seconds) is enabled for long-lived connections\n\n### Full TLS Configuration Example\n\nGateway HTTPS + Worker mTLS + API Key authentication:\n\n```bash\npython -m sglang_router.launch_router \\\n  --worker-urls https://worker1:8443 https://worker2:8443 \\\n  --tls-cert-path /etc/certs/server.crt \\\n  --tls-key-path /etc/certs/server.key \\\n  --client-cert-path /etc/certs/client.crt \\\n  --client-key-path /etc/certs/client.key \\\n  --ca-cert-path /etc/certs/ca.crt \\\n  --api-key \"secure-api-key\" \\\n  --policy cache_aware\n```\n\n---\n\n## Observability\n\n### Prometheus Metrics\n\nEnable with `--prometheus-host`/`--prometheus-port` (defaults to `0.0.0.0:29000`).\n\n#### Metric Categories (40+ metrics)\n\n| Layer | Prefix | Metrics |\n|-------|--------|---------|\n| HTTP | `smg_http_*` | `requests_total`, `request_duration_seconds`, `responses_total`, `connections_active`, `rate_limit_total` |\n| Router | `smg_router_*` | `requests_total`, 
`request_duration_seconds`, `request_errors_total`, `stage_duration_seconds`, `upstream_responses_total` |\n| Inference | `smg_router_*` | `ttft_seconds`, `tpot_seconds`, `tokens_total`, `generation_duration_seconds` |\n| Worker | `smg_worker_*` | `pool_size`, `connections_active`, `requests_active`, `health_checks_total`, `selection_total`, `errors_total` |\n| Circuit Breaker | `smg_worker_cb_*` | `state`, `transitions_total`, `outcomes_total`, `consecutive_failures`, `consecutive_successes` |\n| Retry | `smg_worker_*` | `retries_total`, `retries_exhausted_total`, `retry_backoff_seconds` |\n| Discovery | `smg_discovery_*` | `registrations_total`, `deregistrations_total`, `sync_duration_seconds`, `workers_discovered` |\n| MCP | `smg_mcp_*` | `tool_calls_total`, `tool_duration_seconds`, `servers_active`, `tool_iterations_total` |\n| Database | `smg_db_*` | `operations_total`, `operation_duration_seconds`, `connections_active`, `items_stored` |\n\n#### Key Inference Metrics (gRPC mode)\n\n| Metric | Type | Description |\n|--------|------|-------------|\n| `smg_router_ttft_seconds` | Histogram | Time to first token |\n| `smg_router_tpot_seconds` | Histogram | Time per output token |\n| `smg_router_tokens_total` | Counter | Total tokens (input/output) |\n| `smg_router_generation_duration_seconds` | Histogram | End-to-end generation time |\n\n#### Duration Buckets\n\n1ms, 5ms, 10ms, 25ms, 50ms, 100ms, 250ms, 500ms, 1s, 2.5s, 5s, 10s, 15s, 30s, 45s, 60s, 90s, 120s, 180s, 240s\n\n### OpenTelemetry Tracing\n\nEnable distributed tracing with OTLP export:\n\n```bash\npython -m sglang_router.launch_router \\\n  --worker-urls http://worker1:8000 \\\n  --enable-trace \\\n  --otlp-traces-endpoint localhost:4317\n```\n\n#### Features\n\n- OTLP/gRPC exporter (default port 4317)\n- W3C Trace Context propagation for HTTP and gRPC\n- Batch span processing (500ms delay, 64 span batch size)\n- Custom filtering to reduce noise\n- Trace context injection into upstream worker requests\n- 
Service name: `sgl-router`\n\n### Logging\n\n```bash\npython -m sglang_router.launch_router \\\n  --worker-urls http://worker1:8000 \\\n  --log-level debug \\\n  --log-dir ./router_logs\n```\n\nLogging uses structured tracing with an optional file sink. Log levels: `debug`, `info`, `warn`, `error`.\n\n### Request ID Propagation\n\n```bash\n--request-id-headers x-request-id x-trace-id x-correlation-id\n```\n\nResponses include an `x-request-id` header for correlation.\n\n---\n\n## Production Recommendations\n\nThis section provides guidance for deploying SGLang Model Gateway in production environments.\n\n### Security Best Practices\n\n**Always enable TLS in production:**\n\n```bash\npython -m sglang_router.launch_router \\\n  --worker-urls https://worker1:8443 https://worker2:8443 \\\n  --tls-cert-path /etc/certs/server.crt \\\n  --tls-key-path /etc/certs/server.key \\\n  --client-cert-path /etc/certs/client.crt \\\n  --client-key-path /etc/certs/client.key \\\n  --ca-cert-path /etc/certs/ca.crt \\\n  --api-key \"${ROUTER_API_KEY}\"\n```\n\n**Security Checklist:**\n- Enable TLS for gateway HTTPS termination\n- Enable mTLS for worker communication when workers are on untrusted networks\n- Set `--api-key` to protect router endpoints\n- Use Kubernetes Secrets or a secrets manager for credentials\n- Rotate certificates and API keys periodically\n- Restrict network access with firewalls or network policies\n\n### High Availability\n\n**Scaling Strategy:**\n\nThe gateway supports running multiple replicas behind a load balancer for high availability. 
However, there are important considerations:\n\n| Component | Shared Across Replicas | Impact |\n|-----------|----------------------|--------|\n| Worker Registry | No (independent) | Each replica discovers workers independently |\n| Radix Cache Tree | No (independent) | Cache hits may decrease by 10-20% |\n| Circuit Breaker State | No (independent) | Each replica tracks failures independently |\n| Rate Limiting | No (independent) | Limits apply per-replica, not globally |\n\n**Recommendations:**\n\n1. **Prefer horizontal scaling over vertical scaling**: Deploy multiple smaller gateway replicas rather than one large instance with excessive CPU and memory. This provides:\n   - Better fault tolerance (single replica failure doesn't take down the gateway)\n   - More predictable resource usage\n   - Easier capacity planning\n\n2. **Use Kubernetes Service Discovery**: Let the gateway automatically discover and manage workers:\n   ```bash\n   python -m sglang_router.launch_router \\\n     --service-discovery \\\n     --selector app=sglang-worker \\\n     --service-discovery-namespace production\n   ```\n\n3. **Accept cache efficiency trade-off**: With multiple replicas, the cache-aware routing policy's radix tree is not synchronized across replicas. This means:\n   - Each replica builds its own cache tree\n   - Requests from the same user may hit different replicas\n   - Expected cache hit rate reduction: **10-20%**\n   - This is often acceptable given the HA benefits\n\n4. 
**Configure session affinity (optional)**: If cache efficiency is critical, configure your load balancer for session affinity based on a consistent hash of the request (e.g., user ID or API key).\n\n**Example HA Architecture:**\n```\n                    ┌─────────────────┐\n                    │  Load Balancer  │\n                    │   (L4/L7)       │\n                    └────────┬────────┘\n              ┌──────────────┼──────────────┐\n              │              │              │\n        ┌─────▼─────┐  ┌─────▼─────┐  ┌─────▼─────┐\n        │  Gateway  │  │  Gateway  │  │  Gateway  │\n        │ Replica 1 │  │ Replica 2 │  │ Replica 3 │\n        └─────┬─────┘  └─────┬─────┘  └─────┬─────┘\n              │              │              │\n              └──────────────┼──────────────┘\n                             │\n              ┌──────────────┼──────────────┐\n              │              │              │\n        ┌─────▼─────┐  ┌─────▼─────┐  ┌─────▼─────┐\n        │  Worker   │  │  Worker   │  │  Worker   │\n        │  Pod 1    │  │  Pod 2    │  │  Pod N    │\n        └───────────┘  └───────────┘  └───────────┘\n```\n\n### Performance\n\n**Use gRPC mode for high throughput:**\n\ngRPC mode provides the highest performance for SGLang workers:\n\n```bash\n# Start workers in gRPC mode\npython -m sglang.launch_server \\\n  --model meta-llama/Llama-3.1-8B-Instruct \\\n  --grpc-mode \\\n  --port 20000\n\n# Configure gateway for gRPC\npython -m sglang_router.launch_router \\\n  --worker-urls grpc://worker1:20000 grpc://worker2:20000 \\\n  --model-path meta-llama/Llama-3.1-8B-Instruct \\\n  --policy cache_aware\n```\n\n**Performance Benefits of gRPC:**\n- Native Rust tokenization (no Python overhead)\n- Streaming with lower latency\n- Built-in reasoning parser execution\n- Tool call parsing in the gateway\n- Reduced serialization overhead\n\n**Tuning Recommendations:**\n\n| Parameter | Recommendation | Reason |\n|-----------|---------------|--------|\n| `--policy` | 
`cache_aware` | Best for repeated prompts, ~30% latency reduction |\n| `--max-concurrent-requests` | 2-4x worker count | Prevent overload while maximizing throughput |\n| `--queue-size` | 2x max-concurrent | Buffer for burst traffic |\n| `--request-timeout-secs` | Based on max generation length | Prevent stuck requests |\n\n### Kubernetes Deployment\n\n**Pod Labeling for Service Discovery:**\n\nFor the gateway to discover workers automatically, label your worker pods consistently:\n\n```yaml\n# Worker Deployment (Regular Mode)\napiVersion: apps/v1\nkind: Deployment\nmetadata:\n  name: sglang-worker\n  namespace: production\nspec:\n  replicas: 4\n  selector:\n    matchLabels:\n      app: sglang-worker\n      component: inference\n  template:\n    metadata:\n      labels:\n        app: sglang-worker\n        component: inference\n        model: llama-3-8b\n    spec:\n      containers:\n      - name: worker\n        image: lmsysorg/sglang:latest\n        ports:\n        - containerPort: 8000\n          name: http\n        - containerPort: 20000\n          name: grpc\n```\n\n**Gateway configuration for discovery:**\n```bash\npython -m sglang_router.launch_router \\\n  --service-discovery \\\n  --selector app=sglang-worker component=inference \\\n  --service-discovery-namespace production \\\n  --service-discovery-port 8000\n```\n\n**PD (Prefill/Decode) Mode Labeling:**\n\n```yaml\n# Prefill Worker\nmetadata:\n  labels:\n    app: sglang-worker\n    component: prefill\n  annotations:\n    sglang.ai/bootstrap-port: \"9001\"\n\n# Decode Worker\nmetadata:\n  labels:\n    app: sglang-worker\n    component: decode\n```\n\n**Gateway configuration for PD discovery:**\n```bash\npython -m sglang_router.launch_router \\\n  --service-discovery \\\n  --pd-disaggregation \\\n  --prefill-selector app=sglang-worker component=prefill \\\n  --decode-selector app=sglang-worker component=decode \\\n  --service-discovery-namespace production\n```\n\n**RBAC Requirements:**\n\nThe gateway 
needs permissions to watch pods:\n\n```yaml\napiVersion: rbac.authorization.k8s.io/v1\nkind: Role\nmetadata:\n  name: sglang-gateway\n  namespace: production\nrules:\n- apiGroups: [\"\"]\n  resources: [\"pods\"]\n  verbs: [\"get\", \"list\", \"watch\"]\n---\napiVersion: rbac.authorization.k8s.io/v1\nkind: RoleBinding\nmetadata:\n  name: sglang-gateway\n  namespace: production\nsubjects:\n- kind: ServiceAccount\n  name: sglang-gateway\n  namespace: production\nroleRef:\n  kind: Role\n  name: sglang-gateway\n  apiGroup: rbac.authorization.k8s.io\n```\n\n### Monitoring with PromQL\n\nConfigure Prometheus to scrape the gateway metrics endpoint (default: `:29000/metrics`).\n\n**Essential Dashboards:**\n\n**1. Request Rate and Latency:**\n```promql\n# Request rate by endpoint\nsum(rate(smg_http_requests_total[5m])) by (path, method)\n\n# P50 latency\nhistogram_quantile(0.50, sum(rate(smg_http_request_duration_seconds_bucket[5m])) by (le))\n\n# P99 latency\nhistogram_quantile(0.99, sum(rate(smg_http_request_duration_seconds_bucket[5m])) by (le))\n\n# Error rate\nsum(rate(smg_http_responses_total{status=~\"5..\"}[5m])) / sum(rate(smg_http_responses_total[5m]))\n```\n\n**2. Worker Health:**\n```promql\n# Healthy workers\nsum(smg_worker_pool_size)\n\n# Active connections per worker\nsmg_worker_connections_active\n\n# Worker health check failures\nsum(rate(smg_worker_health_checks_total{result=\"failure\"}[5m])) by (worker_id)\n```\n\n**3. Circuit Breaker Status:**\n```promql\n# Circuit breaker states (0=closed, 1=open, 2=half-open)\nsmg_worker_cb_state\n\n# Circuit breaker transitions\nsum(rate(smg_worker_cb_transitions_total[5m])) by (worker_id, from_state, to_state)\n\n# Workers with open circuits\ncount(smg_worker_cb_state == 1)\n```\n\n**4. 
Inference Performance (gRPC mode):**\n```promql\n# Time to first token (P50)\nhistogram_quantile(0.50, sum(rate(smg_router_ttft_seconds_bucket[5m])) by (le, model))\n\n# Time per output token (P99)\nhistogram_quantile(0.99, sum(rate(smg_router_tpot_seconds_bucket[5m])) by (le, model))\n\n# Token throughput\nsum(rate(smg_router_tokens_total[5m])) by (model, direction)\n\n# Generation duration P95\nhistogram_quantile(0.95, sum(rate(smg_router_generation_duration_seconds_bucket[5m])) by (le))\n```\n\n**5. Rate Limiting and Queuing:**\n```promql\n# Rate limit rejections\nsum(rate(smg_http_rate_limit_total{decision=\"rejected\"}[5m]))\n\n# Active (in-flight) requests per worker, useful alongside concurrency limiting\nsmg_worker_requests_active\n\n# Retry attempts\nsum(rate(smg_worker_retries_total[5m])) by (worker_id)\n\n# Exhausted retries (failures after all retries)\nsum(rate(smg_worker_retries_exhausted_total[5m]))\n```\n\n**6. MCP Tool Execution:**\n```promql\n# Tool call rate\nsum(rate(smg_mcp_tool_calls_total[5m])) by (server, tool)\n\n# Tool latency P95\nhistogram_quantile(0.95, sum(rate(smg_mcp_tool_duration_seconds_bucket[5m])) by (le, tool))\n\n# Active MCP server connections\nsmg_mcp_servers_active\n```\n\n**Alerting Rules Example:**\n\n```yaml\ngroups:\n- name: sglang-gateway\n  rules:\n  - alert: HighErrorRate\n    expr: |\n      sum(rate(smg_http_responses_total{status=~\"5..\"}[5m]))\n      / sum(rate(smg_http_responses_total[5m])) > 0.05\n    for: 5m\n    labels:\n      severity: critical\n    annotations:\n      summary: \"High error rate on SGLang Gateway\"\n\n  - alert: CircuitBreakerOpen\n    expr: count(smg_worker_cb_state == 1) > 0\n    for: 2m\n    labels:\n      severity: warning\n    annotations:\n      summary: \"Worker circuit breaker is open\"\n\n  - alert: HighLatency\n    expr: |\n      histogram_quantile(0.99, sum(rate(smg_http_request_duration_seconds_bucket[5m])) by (le)) > 30\n    for: 5m\n    labels:\n      
severity: warning\n    annotations:\n      summary: \"P99 
latency exceeds 30 seconds\"\n\n  - alert: NoHealthyWorkers\n    expr: sum(smg_worker_pool_size) == 0\n    for: 1m\n    labels:\n      severity: critical\n    annotations:\n      summary: \"No healthy workers available\"\n```\n\n---\n\n## Configuration Reference\n\n### Core Settings\n\n| Parameter | Type | Default | Description |\n|-----------|------|---------|-------------|\n| `--host` | str | 127.0.0.1 | Router host |\n| `--port` | int | 30000 | Router port |\n| `--worker-urls` | list | [] | Worker URLs (HTTP or gRPC) |\n| `--policy` | str | cache_aware | Routing policy |\n| `--max-concurrent-requests` | int | -1 | Concurrency limit (-1 disables) |\n| `--request-timeout-secs` | int | 600 | Request timeout |\n| `--max-payload-size` | int | 256MB | Maximum request payload |\n\n### Prefill/Decode\n\n| Parameter | Type | Default | Description |\n|-----------|------|---------|-------------|\n| `--pd-disaggregation` | flag | false | Enable PD mode |\n| `--prefill` | list | [] | Prefill URLs + optional bootstrap ports |\n| `--decode` | list | [] | Decode URLs |\n| `--prefill-policy` | str | None | Override policy for prefill nodes |\n| `--decode-policy` | str | None | Override policy for decode nodes |\n| `--worker-startup-timeout-secs` | int | 600 | Worker init timeout |\n\n### Kubernetes Discovery\n\n| Parameter | Type | Description |\n|-----------|------|-------------|\n| `--service-discovery` | flag | Enable discovery |\n| `--selector` | list | Label selectors (key=value) |\n| `--prefill-selector` / `--decode-selector` | list | PD mode selectors |\n| `--service-discovery-namespace` | str | Namespace to watch |\n| `--service-discovery-port` | int | Worker port (default 80) |\n| `--bootstrap-port-annotation` | str | Annotation for bootstrap ports |\n\n### TLS Configuration\n\n| Parameter | Type | Description |\n|-----------|------|-------------|\n| `--tls-cert-path` | str | Server certificate for gateway HTTPS (PEM) |\n| `--tls-key-path` | str | Server private key for 
gateway HTTPS (PEM) |\n| `--client-cert-path` | str | Client certificate for worker mTLS (PEM) |\n| `--client-key-path` | str | Client private key for worker mTLS (PEM) |\n| `--ca-cert-path` | str | CA certificate for verifying workers (PEM, repeatable) |\n\n---\n\n## Troubleshooting\n\n### Workers Never Ready\n\nIncrease `--worker-startup-timeout-secs` or ensure health probes respond before router startup.\n\n### Load Imbalance / Hot Workers\n\nInspect `smg_router_requests_total` by worker and tune cache-aware thresholds (`--balance-*`, `--cache-threshold`).\n\n### Circuit Breaker Flapping\n\nIncrease `--cb-failure-threshold` or extend the timeout/window durations. Consider temporarily disabling retries.\n\n### Queue Overflow (429)\n\nIncrease `--queue-size` or reduce client concurrency. Ensure `--max-concurrent-requests` matches downstream capacity.\n\n### Memory Growth\n\nReduce `--max-tree-size` or lower `--eviction-interval-secs` for more aggressive cache pruning.\n\n### Debugging\n\n```bash\npython -m sglang_router.launch_router \\\n  --worker-urls http://worker1:8000 \\\n  --log-level debug \\\n  --log-dir ./router_logs\n```\n\n### gRPC Connection Issues\n\nEnsure workers are started with `--grpc-mode` and verify `--model-path` or `--tokenizer-path` is provided to the router.\n\n### Tokenizer Loading Failures\n\nCheck HuggingFace Hub credentials (`HF_TOKEN` environment variable) for private models. Verify local paths are accessible.\n\n---\n\nSGLang Model Gateway continues to evolve alongside the SGLang runtime. Keep CLI flags, integrations, and documentation aligned when adopting new features or contributing improvements.\n"
  },
  {
    "path": "docs/advanced_features/sglang_for_rl.md",
    "content": "# SGLang for RL Systems\n\nThis document is a practical guide for infrastructure teams integrating SGLang into RL and post-training systems. It focuses on the operational pain points in the loop (rollout, evaluation, training, weight sync) and maps them to concrete SGLang APIs, flags, and integration patterns. The goal is to maximize rollout efficiency, accuracy, and stability while keeping rollout and serving behavior aligned in production environments.\n\n## Why SGLang for RL Lifecycle?\n\nWe follow a guiding principle from DeepMind's early RL engineering:\n\n**Be a library, not a framework.**\n\nIn this spirit, SGLang provides flexible tools rather than rigid structures. Here are five reasons to use SGLang for your RL lifecycle:\n\n* **Fine-Grained Engine Sleep and Wake Up**: free GPU memory so rollout and training can each run at full capacity\n* **Open-To-Use Refit Functionality**: diverse methods for co-location or disaggregation\n* **Easy To Postpone Generation**: enable partial rollout and dedicated rollout control\n* **Deterministic Inference**: reduce batch-dependent non-determinism to help eliminate training-inference mismatch\n* **Load Balancing Router**: cache-aware load balancing for high-throughput rollout\n\nThe following sections cover these aspects in detail.\n\n## Fine-Grained Engine Sleep and Wake Up\n\nRollout and training are both memory-intensive, and co-locating them on the same GPUs often leads to memory pressure and slow handoffs. SGLang provides a memory-aware sleep/wake mechanism that releases KV cache and weights while keeping the server process alive, then resumes them for rollout without a full restart. This avoids repeated disk I/O and CUDA graph recapture during each RL step.\n\nUnder the hood, the RL team uses CUDA-graph-aware weight offload via [torch_memory_saver](https://github.com/fzyzcjy/torch_memory_saver) to preserve virtual memory addresses for graph replay. 
For details, see: [Efficient RL Training - Optimizing Memory Usage in verl](https://hebiao064.github.io/rl-memory-management).\n\n### Server flag\n\nEnable memory saver support when launching the server:\n\n```\n--enable-memory-saver\n```\n\n### Release Memory\n\n**Endpoint:** `POST /release_memory_occupation`\n\n**Request body:**\n\n| Field | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `tags` | Which memory regions to release. If omitted, all are released. | `None` | Type: list[str], values: `kv_cache`, `weights` |\n<!-- python/sglang/srt/managers/io_struct.py#L1381 currently only supports `kv_cache`, `weights` -->\n**Behavior notes:**\n\n- This call asserts there are no ongoing requests. Ensure the engine is idle before calling it.\n- If `kv_cache` is released, SGLang flushes cache; subsequent requests will rebuild KV cache as needed.\n\n### Resume Memory\n\n**Endpoint:** `POST /resume_memory_occupation`\n\n**Request body:**\n\n| Field | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `tags` | Which memory regions to resume. If omitted, all are resumed. | `None` | Type: list[str], values: `kv_cache`, `weights` |\n<!-- python/sglang/srt/managers/io_struct.py#L1393 currently only supports `kv_cache`, `weights` -->\n\n## Open-To-Use Refit Functionality\n\nAfter training completes each step, rollout engines must be refit with new weights. SGLang supports three refit strategies so you can match your infrastructure style (co-located vs disaggregated) and scaling needs. Each strategy maps to a concrete API with clear request schemas. 
For a deeper dive into SGLang's weight update utilities, see [RL System Deep Thinking: Weight Update Mechanisms](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/sys-design/readme-1-EN.md).\n\n**How to choose:**\n\n- **From disk** is simplest and best for elastic rollout scaling and checkpointing.\n- **From tensor** is best for co-located training/rollout when you can pass in-memory tensors.\n- **From distributed** is best for disaggregated training/rollout with dedicated communication groups (NCCL/IB).\n\n### Update Weights from Disk\n\n**When to use:**\n\n- Save checkpoint to disk and update weights from disk\n- Dynamic scaling (new rollout instances can load from the same checkpoint)\n\n**Why it works well:**\n\nThis path trades some I/O overhead for simplicity and flexibility. It integrates naturally with checkpointing and makes it trivial to add new rollout engines: point them at the same checkpoint and call the API. It is also the safest option for high availability because the checkpoint itself is the source of truth.\n\n**Endpoint:** `POST /update_weights_from_disk`\n\n**Request body:**\n\n| Field | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `model_path` | The model path with the new weights. | Required | Type: str |\n| `load_format` | The format to load the weights. | `None` | Type: str |\n| `abort_all_requests` | Abort all running requests before update. | `False` | Type: bool |\n| `weight_version` | Optional weight version label tracked by the server. | `None` | Type: str |\n| `is_async` | Perform weight load asynchronously. | `False` | Type: bool |\n| `torch_empty_cache` | Empty torch cache. | `False` | Type: bool |\n| `keep_pause` | Keep scheduler paused after update. | `False` | Type: bool |\n| `recapture_cuda_graph` | Recapture CUDA graphs after update. | `False` | Type: bool |\n| `token_step` | Trainer step id for rollout bookkeeping. | `0` | Type: int |\n| `flush_cache` | Flush KV cache after update. 
| `True` | Type: bool |\n\n**Response body:**\n\n| Field | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `success` | Whether the update succeeded. | - | Type: bool |\n| `message` | Status / error message. | - | Type: str |\n| `num_paused_requests` | Number of paused requests during update. | `0` | Type: int |\n\n**Python Engine API:** `engine.update_weights_from_disk(model_path, load_format=None)`\n\n**Diffusion engine (SGLang-Diffusion):** The diffusion engine exposes the same `POST /update_weights_from_disk` endpoint with the following behavior:\n\n- **All-or-nothing with rollback:** if any module fails to load, all previously updated modules are rolled back to the original weights by reloading from the original model path. No partial updates are left behind. If rollback itself fails, the exception propagates so the caller knows the model is in an inconsistent state.\n- **Offload-aware:** when layerwise offload (`--dit-layerwise-offload`) is enabled, the diffusion offload manager replaces GPU parameters with small `torch.empty((1,))` placeholders while real weights live in consolidated pinned CPU buffers. A naive `param.data.copy_()` would fail with a shape mismatch. Instead, the updater dynamically detects active offload managers and writes new weights directly into their CPU buffers, bypassing the placeholders entirely. For any layer that happens to be prefetched on GPU at update time, the live GPU tensor is also updated so the change takes effect immediately. This requires no extra GPU memory and does not disturb the offload state.\n- **DTensor-aware:** parameters distributed via `torch.distributed.tensor` (tensor parallelism) are updated through `distribute_tensor` so that each shard is correctly placed on the right device mesh.\n\n**Request body:**\n\n| Field | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `model_path` | The model path with the new weights. 
| Required | Type: str |\n| `flush_cache` | Flush TeaCache state after update. | `True` | Type: bool |\n| `target_modules` | List of module names to update (e.g. `[\"transformer\"]`). If omitted, all `nn.Module` components are updated. | `None` | Type: list[str] |\n\n**Response body:**\n\n| Field | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `success` | Whether the update succeeded. | - | Type: bool |\n| `message` | Status / error message. | - | Type: str |\n\n> **Note:** The diffusion engine (SGLang-Diffusion) does not currently support hot refit (updating weights while inference is in progress). The diffusion scheduler processes one request at a time and completes the entire inference before handling the next request, so weight updates and inference never run concurrently.\n\n### Update Weights from Tensor\n\n**When to use:**\n\n- Co-located training and rollout, where training can provide tensors directly\n- Fast in-memory updates\n\n**Important constraints:**\n\nThis strategy requires the training process and rollout engine to share access to the tensors. Co-located setups must keep the model on GPU; moving tensors to CPU will break the update path. For high-performance MoE or specialized attention kernels, co-location may limit some optimizations compared to disaggregated rollouts.\n\n**Endpoint:** `POST /update_weights_from_tensor`\n\n**Request body:**\n\n| Field | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `serialized_named_tensors` | Per-TP serialized tensor payloads. | Required | Type: list[Union[str, bytes]] |\n| `load_format` | Optional load format selector. | `None` | `None`, `direct`, `flattened_bucket`, or a custom loader path string |\n| `flush_cache` | Flush KV cache after update. | `True` | Type: bool |\n| `abort_all_requests` | Abort all running requests before update. | `False` | Type: bool |\n| `weight_version` | Optional version label tracked by the server. 
| `None` | Type: str |\n\n**Note:** The serialized tensor payloads must be created with `MultiprocessingSerializer.serialize(...)` and should be base64-safe strings.\n\n**Python Engine API:** `engine.update_weights_from_tensor(named_tensors, load_format=None, flush_cache=True)`\n\n### Update Weights from Distributed Group\n\n**When to use:**\n\n- Disaggregated training and rollout\n- NCCL or IB-backed weight broadcast from training workers to rollout workers\n\n**How it works:**\n\nTraining workers gather weights (typically on TP rank 0), broadcast them to the rollout group, and each rollout TP shard loads the parameters it needs. This avoids disk I/O and keeps training and rollout decoupled, at the cost of managing a dedicated communication group.\n\n**Initialize weight update group**\n\n**Endpoint:** `POST /init_weights_update_group`\n\n**Request body:**\n\n| Field | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `master_address` | Group master address. | Required | Type: str |\n| `master_port` | Group master port. | Required | Type: int |\n| `rank_offset` | Offset for local rank mapping. | Required | Type: int |\n| `world_size` | Total world size. | Required | Type: int |\n| `group_name` | Group name. | `weight_update_group` | Type: str |\n| `backend` | Communication backend. | `nccl` | Type: str |\n\n**Update weight**\n\n**Endpoint:** `POST /update_weights_from_distributed`\n\n**Request body:**\n\n| Field | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `names` | Parameter names to update. | Required | Type: list[str] |\n| `dtypes` | Dtype strings for each parameter. | Required | Type: list[str] |\n| `shapes` | Tensor shapes. | Required | Type: list[list[int]] |\n| `group_name` | Group name. | `weight_update_group` | Type: str |\n| `flush_cache` | Flush KV cache after update. | `True` | Type: bool |\n| `abort_all_requests` | Abort all running requests before update. 
| `False` | Type: bool |\n| `weight_version` | Optional version label. | `None` | Type: str |\n| `load_format` | Optional format selector. | `None` | `None` or `flattened_bucket` |\n\n**Destroy weights update group**\n\n**Endpoint:** `POST /destroy_weights_update_group`\n\n**Request body:**\n\n| Field | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `group_name` | Group name. | `weight_update_group` | Type: str |\n\n**Python Engine APIs:**\n\n- `engine.init_weights_update_group(...)`\n- `engine.update_weights_from_distributed(names, dtypes, shapes, ...)`\n- `engine.destroy_weights_update_group(group_name)`\n\n## Easy To Postpone Generation\n\nMulti-turn RL rollouts often suffer from long-tail requests that block the entire batch. A small number of slow interactions can stall all GPUs, and the long-tail behavior makes profiling and monitoring difficult.\n\nSGLang exposes explicit pause/resume APIs so you can pause slow requests and continue them later. This pattern matches systems like [APRIL](https://arxiv.org/abs/2509.18521), which terminate rollout once enough responses are collected and recycle incomplete responses in the next step. The result is higher GPU utilization without discarding partial work.\n\nWhen updating weights from training, the correct execution flow is `pause_generation`, then the weight update, then `continue_generation`. An update can only happen when SGLang is not actively processing inference tasks.\n\n### Pause Generation\n\n**Endpoint:** `POST /pause_generation`\n\n**Request body:**\n\n| Field | Description | Defaults | Options |\n| --- | --- | --- | --- |\n| `mode` | Pause mode. | `abort` | `abort`, `retract`, `in_place` |\n\n**Modes:**\n\n- `abort`: Default behavior, identical to calling the `abort` endpoint with `abort_all` set. Pending requests from `waiting_queue` and `running_queue` will be returned immediately to the caller.\n- `retract`: Put the engine in the \"paused\" state and move running requests back to the waiting queue. 
KV cache can be flushed and recomputed later.\n- `in_place`: Put the engine in the \"paused\" state without changing the state of any request. Running requests rely on their KV cache to continue, so any subsequent `flush_cache` call will fail.\n\n### Continue Generation\n\n**Endpoint:** `POST /continue_generation`\n\n## Deterministic Inference\n\nIn many RL stacks, rollout and training are implemented with different kernels or batching behavior. Even when weights are identical, token probabilities can drift, silently breaking the on-policy assumption. This is the training–inference mismatch problem.\n\nSGLang supports a deterministic inference mode that reduces non-determinism across batch shapes. This mitigates variance introduced by runtime batching and kernel selection. To achieve true on-policy training, you also need to modify the training engine to use the same deterministic kernels. For implementation details, see these miles examples: [True On-Policy](https://github.com/radixark/miles/tree/main/examples/true_on_policy) and [True On-Policy for VLM](https://github.com/radixark/miles/tree/main/examples/true_on_policy_vlm). For additional context, see the blog post [Let Speed Be With Stability: All-In-One Solution to Training-Inference Mismatch with Miles](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/mismatch/blog-en.md).\n\n**Server flag:**\n\n```\n--enable-deterministic-inference\n```\n\nFor more details, see [Deterministic Inference](deterministic_inference.md).\n\n## Load Balancing Router\n\nSGLang Model Gateway is the recommended control plane for large‑scale RL rollouts. It provides async, non‑blocking request handling, cache‑aware load balancing, and fault‑tolerant routing across rollout and reward servers. This lets you keep GPUs saturated while avoiding long‑tail stalls and brittle, engine‑local concurrency logic. 
It has been deployed in the training of GLM 4.5+ models and proven to be highly efficient in production-level large-scale RL workloads.\n\nKey benefits for RL infrastructure:\n\n- **Async non-blocking efficiency**: SGLang’s native async server/router architecture (HTTPS/gRPC) manages concurrency automatically. This guarantees maximum GPU saturation and effective continuous batching without requiring complex, manual implementation by engineers.\n- **Elasticity and fault tolerance**: By encapsulating the reward model and rollout as independent servers, SGLang decouples them logically and physically. This architecture provides robust disaster recovery for large-scale distributed training; if a server fails, the router automatically redirects traffic to healthy nodes, ensuring the training process continues without interruption.\n- **Training–Inference alignment**: Using the SGLang Model Gateway for both training and inference ensures \"What You See Is What You Get.\" This eliminates score discrepancies and the painful backend alignment issues often caused by using different engines for training versus deployment.\n- **Dynamic load balancing and long-tail mitigation**: Unlike static partitioning, the SGLang Model Gateway enables request-level dynamic dispatching for multi-turn RL. It can distribute different turns of a conversation across different servers to balance workloads and eliminate long-tail latency caused by varying sequence lengths.\n\nFor deployment and configuration, see: [SGLang Model Gateway](sgl_model_gateway.md)\n"
  },
  {
    "path": "docs/advanced_features/speculative_decoding.md",
    "content": "# Speculative Decoding\n\nSGLang provides several speculative decoding options, including EAGLE-2/EAGLE-3, MTP, classic draft-model decoding, and an NGRAM-based variant. Our implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.\n\n## Summary\n\n### Jump to sections\n\n- [EAGLE Decoding](#eagle-decoding)\n  - [EAGLE-2 Decoding](#eagle-2-decoding)\n  - [EAGLE-2 Decoding with torch.compile](#eagle-2-decoding-with-torchcompile)\n  - [EAGLE-2 Decoding via Frequency-Ranked Speculative Sampling](#eagle-2-decoding-via-frequency-ranked-speculative-sampling)\n  - [EAGLE-3 Decoding](#eagle-3-decoding)\n- [Multi Token Prediction](#multi-token-prediction)\n- [Standalone Speculative Decoding (Small Draft Model)](#standalone-speculative-decoding-small-draft-model)\n- [Speculative Decoding V2 (Overlap Scheduler)](#speculative-decoding-v2-overlap-scheduler)\n- [Ngram Speculative Decoding](#ngram-speculative-decoding)\n- [Full Parameter Reference](#full-parameter-reference)\n- [OOM Troubleshooting](#oom-troubleshooting)\n- [References](#references)\n\n### Quick guidance\n\n- **Best speed/quality (recommended)**: Use **EAGLE-3** with `--speculative-algorithm EAGLE3`.\n- **Strong default / broad compatibility**: Use **EAGLE-2** with `--speculative-algorithm EAGLE`.\n- **Lower `lm_head` overhead for EAGLE-2**: Enable **FR-Spec** with `--speculative-token-map`.\n- **Model is MTP-enabled**: Use **MTP via speculative decoding** (often with small `speculative_num_steps/topk/num_draft_tokens`, see the example section).\n- **You have a smaller draft LLM**: Use **STANDALONE** (`--speculative-algorithm STANDALONE`).\n- **No extra model available**: Use **NGRAM** (`--speculative-algorithm NGRAM`, CUDA-only).\n- **Want overlap scheduler (experimental)**: Enable **SpecV2** with `SGLANG_ENABLE_SPEC_V2=True` (requires `--speculative-eagle-topk 1`).\n\n### Method comparison (mini table)\n\n| Method | Draft source | 
Separate draft model? | How to enable | Notes / constraints |\n|---|---|---|---|---|\n| EAGLE-2 | EAGLE draft model (feature drafting + tree) | Typically yes | `--speculative-algorithm EAGLE` + `--speculative-draft-model-path ...` | Tune `--speculative-num-steps`, `--speculative-eagle-topk`, `--speculative-num-draft-tokens` |\n| EAGLE-2 + `torch.compile` | Same as EAGLE-2 | Typically yes | Add `--enable-torch-compile` (optionally `--torch-compile-max-bs`) | Benefit varies by hardware/model; benchmark to verify |\n| EAGLE-2 + FR-Spec | Same as EAGLE-2 + token subset | Typically yes | Add `--speculative-token-map ...` | Reduces `lm_head` overhead with high-frequency token vocab |\n| EAGLE-3 | EAGLE3 draft model | Yes | `--speculative-algorithm EAGLE3` + `--speculative-draft-model-path ...` | Best throughput in the benchmark below |\n| MTP | Built-in multi-token heads (model-specific) | Often no | See **Multi Token Prediction** section | Uses speculative workflow; draft path may be auto-handled for some models |\n| STANDALONE | Smaller draft LLM (token-level) | Yes | `--speculative-algorithm STANDALONE` + `--speculative-draft-model-path ...` | Does **not** support `--enable-dp-attention` |\n| SpecV2 (experimental) | V2 workers + overlap scheduler | N/A | `SGLANG_ENABLE_SPEC_V2=True` | Only supports `--speculative-eagle-topk 1`; applies to `EAGLE`, `EAGLE3`, `STANDALONE` |\n| NGRAM | Ngram cache from previous tokens | No | `--speculative-algorithm NGRAM` | CUDA-only; no `--enable-dp-attention`; disables overlap scheduler & mixed chunked prefill |\n\n### Performance Highlights\n\nThe table below compares throughput for Llama 3.1 8B Instruct on MT-Bench with and without speculative decoding, showing the gains from EAGLE-2 and EAGLE-3.\nFor further details, see the [EAGLE3 paper](https://arxiv.org/pdf/2503.01840).\n\n| Method | Throughput (tokens/s) |\n|--------|----------------|\n| SGLang (w/o speculative, 1x H100) | 158.34 tokens/s |\n| SGLang + EAGLE-2 (1x H100) | 244.10 
tokens/s |\n| SGLang + EAGLE-3 (1x H100) | 373.25 tokens/s |\n\n---\n\n## EAGLE Decoding\n\nTo enable EAGLE speculative decoding the following parameters are relevant:\n\n| Parameter | Description | Default |\n|---|---|---|\n| `--speculative-draft-model-path` | Draft model path/weights. **Typically required** for EAGLE/EAGLE3 and STANDALONE. For some MTP-enabled models, this can be omitted. | `None` |\n| `--speculative-num-steps` | Depth of autoregressive drafting. Increases speculation range but risks rejection cascades. | Auto (`5` for Llama/Grok; `3` for many other models) |\n| `--speculative-eagle-topk` | Branching factor per step. Improves candidate diversity and acceptance rate, but increases memory/compute consumption. | Auto (`4` for Llama/Grok; `1` for many other models) |\n| `--speculative-num-draft-tokens` | Maximum parallel verification capacity. Allows deeper tree evaluation but increases GPU memory usage. | Auto (`8` for Llama/Grok; `4` for many other models). If `topk=1`, it is adjusted to `num_steps + 1`. |\n| `--speculative-accept-threshold-single` | Acceptance threshold for single-token verification. Lower values accept more aggressively. | `1.0` |\n| `--speculative-accept-threshold-acc` | Accumulated acceptance threshold across steps. | `1.0` |\n| `--speculative-attention-mode` | Attention mode for speculative operations (`prefill` or `decode`), affecting both target verification and draft extension. | `\"prefill\"` |\n| `--speculative-draft-attention-backend` | Override attention backend for the draft model. | `None` (same as target) |\n| `--speculative-draft-model-quantization` | Quantization method for the draft model. Use `\"unquant\"` to force no quantization even when the target model is quantized. | Same as target model |\n| `--speculative-draft-model-revision` | Specific revision/commit of the draft model to load. 
| `None` (auto-set to `\"main\"` when `--speculative-draft-model-path` is set and revision is omitted) |\n| `--speculative-draft-load-format` | Load format for the draft model weights. | `None` |\n\nThese parameters are mostly the same for EAGLE-2 and EAGLE-3. `--speculative-token-map` is ignored for EAGLE-3 models.\nFor `--speculative-num-steps`, `--speculative-eagle-topk`, and `--speculative-num-draft-tokens`: leave all three unset to use auto-tuning, or set all three explicitly when tuning.\n\nYou can find the best combinations of these parameters with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py).\n\n\n### EAGLE-2 Decoding\n\nYou can enable EAGLE-2 Decoding by setting `--speculative-algorithm EAGLE` and choosing an appropriate model.\n\n**Launch the server:**\n\n```bash\npython3 -m sglang.launch_server \\\n    --model meta-llama/Llama-2-7b-chat-hf \\\n    --speculative-algorithm EAGLE \\\n    --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B \\\n    --speculative-num-steps 3 \\\n    --speculative-eagle-topk 4 \\\n    --speculative-num-draft-tokens 16 \\\n    --mem-fraction-static 0.7 \\\n    --cuda-graph-max-bs 8 \\\n    --log-level warning\n```\n\n**Send a request:**\n\n```python\nimport openai\n\nclient = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n\nresponse = client.chat.completions.create(\n    model=\"meta-llama/Llama-2-7b-chat-hf\",\n    messages=[\n        {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n    ],\n    temperature=0,\n    max_tokens=64,\n)\n\nprint(response.choices[0].message.content)\n```\n\n---\n\n### EAGLE-2 Decoding with `torch.compile`\n\nYou can optionally enable `torch.compile` to apply kernel-level optimizations (operator fusion, autotune) to the draft model. The actual speedup depends on your hardware, model architecture, and batch size. 
In some configurations (e.g., small draft models on H100 where cuBLAS is already optimal and CUDA graphs are enabled), the benefit may be negligible. We recommend benchmarking with and without this flag on your specific setup to verify whether it helps.\n\nTo enable it, add `--enable-torch-compile` and optionally set `--torch-compile-max-bs`:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model meta-llama/Llama-2-7b-chat-hf \\\n    --speculative-algorithm EAGLE \\\n    --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B \\\n    --speculative-num-steps 3 \\\n    --speculative-eagle-topk 4 \\\n    --speculative-num-draft-tokens 16 \\\n    --mem-fraction-static 0.7 \\\n    --enable-torch-compile \\\n    --torch-compile-max-bs 8 \\\n    --log-level warning\n```\n\n**Send a request:**\n\n```python\nimport openai\n\nclient = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n\nresponse = client.chat.completions.create(\n    model=\"meta-llama/Llama-2-7b-chat-hf\",\n    messages=[\n        {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n    ],\n    temperature=0,\n    max_tokens=64,\n)\n\nprint(response.choices[0].message.content)\n```\n\n---\n\n### EAGLE-2 Decoding via Frequency-Ranked Speculative Sampling\n\nBy employing a truncated high-frequency token vocabulary in the draft model, EAGLE speculative decoding reduces `lm_head` computational overhead while accelerating the pipeline without quality degradation. For more details, check out [the paper](https://arxiv.org/pdf/2502.14856).\n\nIn our implementation, set `--speculative-token-map` to enable the optimization. You can get the high-frequency tokens in FR-Spec from [this model](https://huggingface.co/thunlp/LLaMA3-Instruct-8B-FR-Spec). 
Alternatively, you can download the high-frequency tokens directly from [this repo](https://github.com/thunlp/FR-Spec/tree/main?tab=readme-ov-file#prepare-fr-spec-vocabulary-subset).\n\nThanks to [Weilin Zhao](https://github.com/Achazwl) and [Zhousx](https://github.com/Zhou-sx) for this contribution.\n\n```bash\npython3 -m sglang.launch_server \\\n    --model meta-llama/Meta-Llama-3-8B-Instruct \\\n    --speculative-algorithm EAGLE \\\n    --speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B \\\n    --speculative-num-steps 3 \\\n    --speculative-eagle-topk 4 \\\n    --speculative-num-draft-tokens 16 \\\n    --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \\\n    --mem-fraction-static 0.7 \\\n    --cuda-graph-max-bs 8 \\\n    --dtype float16 \\\n    --log-level warning\n```\n\n**Send a request:**\n\n```python\nimport openai\n\nclient = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n\nresponse = client.chat.completions.create(\n    model=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n    messages=[\n        {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n    ],\n    temperature=0,\n    max_tokens=64,\n)\n\nprint(response.choices[0].message.content)\n```\n\n---\n\n### EAGLE-3 Decoding\n\nYou can enable EAGLE-3 decoding by setting `--speculative-algorithm EAGLE3` and choosing an appropriate model.\n\n```bash\npython3 -m sglang.launch_server \\\n    --model meta-llama/Meta-Llama-3.1-8B-Instruct \\\n    --speculative-algorithm EAGLE3 \\\n    --speculative-draft-model-path jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B \\\n    --speculative-num-steps 3 \\\n    --speculative-eagle-topk 4 \\\n    --speculative-num-draft-tokens 16 \\\n    --mem-fraction-static 0.7 \\\n    --cuda-graph-max-bs 8 \\\n    --dtype float16 \\\n    --log-level warning\n```\n\n**Send a request:**\n\n```python\nimport openai\n\nclient = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", 
api_key=\"None\")\n\nresponse = client.chat.completions.create(\n    model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    messages=[\n        {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n    ],\n    temperature=0,\n    max_tokens=64,\n)\n\nprint(response.choices[0].message.content)\n```\n\n---\n\n## Multi Token Prediction\n\nWe support [MTP (Multi-Token Prediction)](https://arxiv.org/pdf/2404.19737) in SGLang by using speculative decoding. We use `XiaomiMiMo/MiMo-7B-RL` as an example here (for DeepSeek MTP usage, refer to [deepseek_v32 doc](../basic_usage/deepseek_v32.md#multi-token-prediction)).\n\n```bash\npython3 -m sglang.launch_server \\\n    --model XiaomiMiMo/MiMo-7B-RL \\\n    --host 0.0.0.0 \\\n    --trust-remote-code \\\n    --speculative-algorithm EAGLE \\\n    --speculative-num-steps 1 \\\n    --speculative-eagle-topk 1 \\\n    --speculative-num-draft-tokens 2 \\\n    --mem-fraction-static 0.7 \\\n    --cuda-graph-max-bs 8 \\\n    --log-level warning\n```\n\n**Send a request:**\n\n```python\nimport requests\n\nurl = \"http://localhost:30000/v1/chat/completions\"\n\ndata = {\n    \"model\": \"XiaomiMiMo/MiMo-7B-RL\",\n    \"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of France?\"}],\n}\n\nresponse = requests.post(url, json=data)\nprint(response.json())\n```\n\n---\n\n## Standalone Speculative Decoding (Small Draft Model)\n\nBesides EAGLE/MTP, SGLang also supports **token-level speculative decoding** using a smaller **draft model**. Enable it with `--speculative-algorithm STANDALONE` and provide a draft model via `--speculative-draft-model-path`.\n\nRelevant parameters:\n\n| Parameter | Description | Default |\n|---|---|---|\n| `--speculative-draft-model-path` | Draft model weights (smaller than the target model). | `None` |\n| `--speculative-num-steps` | Draft depth (how many steps the draft model runs autoregressively). 
| `3` (auto default for STANDALONE) |\n| `--speculative-eagle-topk` | Branching factor (token candidates per step). | `1` (auto default for STANDALONE) |\n| `--speculative-num-draft-tokens` | Verification capacity. | `4` (auto default for STANDALONE) |\n| `--speculative-draft-model-quantization` | Quantization for the draft model. Use `\"unquant\"` to disable quantization on the draft even when the target is quantized. | Same as target |\n\n> **Note:** Standalone speculative decoding currently **does not support** `--enable-dp-attention`.\n\n```bash\npython3 -m sglang.launch_server \\\n    --model Qwen/Qwen2.5-7B-Instruct \\\n    --speculative-algorithm STANDALONE \\\n    --speculative-draft-model-path Qwen/Qwen2.5-1.5B-Instruct \\\n    --speculative-num-steps 4 \\\n    --speculative-eagle-topk 2 \\\n    --speculative-num-draft-tokens 7 \\\n    --mem-fraction-static 0.7 \\\n    --cuda-graph-max-bs 8 \\\n    --log-level warning\n```\n\n**Send a request:**\n\n```python\nimport openai\n\nclient = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n\nresponse = client.chat.completions.create(\n    model=\"Qwen/Qwen2.5-7B-Instruct\",\n    messages=[\n        {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n    ],\n    temperature=0,\n    max_tokens=64,\n)\n\nprint(response.choices[0].message.content)\n```\n\n---\n\n## Speculative Decoding V2 (Overlap Scheduler)\n\nSGLang provides an **experimental Speculative Decoding V2** implementation that enables an overlap scheduler and uses V2 speculative workers (e.g. `StandaloneWorkerV2`, `EAGLEWorkerV2`).\n\nTo enable it, set the environment variable:\n- `SGLANG_ENABLE_SPEC_V2=True`\n\nNotes:\n- SpecV2 currently only supports `--speculative-eagle-topk 1`. 
When SpecV2 is enabled, **set `--speculative-eagle-topk 1` explicitly**.\n- If you explicitly set `--speculative-eagle-topk > 1`, the server will error.\n- If you omit `--speculative-eagle-topk`, auto-tuning may pick `topk > 1` for some models (e.g. Llama). This is incompatible with SpecV2 and may not always trigger an immediate config error, so set `--speculative-eagle-topk 1` explicitly.\n- This applies to `EAGLE`, `EAGLE3`, and `STANDALONE`.\n\n```bash\nSGLANG_ENABLE_SPEC_V2=True python3 -m sglang.launch_server \\\n    --model Qwen/Qwen2.5-7B-Instruct \\\n    --speculative-algorithm STANDALONE \\\n    --speculative-draft-model-path Qwen/Qwen2.5-1.5B-Instruct \\\n    --speculative-num-steps 4 \\\n    --speculative-eagle-topk 1 \\\n    --speculative-num-draft-tokens 5 \\\n    --mem-fraction-static 0.7 \\\n    --cuda-graph-max-bs 8 \\\n    --log-level warning\n```\n\n**Send a request:**\n\n```python\nimport openai\n\nclient = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n\nresponse = client.chat.completions.create(\n    model=\"Qwen/Qwen2.5-7B-Instruct\",\n    messages=[\n        {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n    ],\n    temperature=0,\n    max_tokens=64,\n)\n\nprint(response.choices[0].message.content)\n```\n\n---\n\n## Ngram Speculative Decoding\n\nSGLang also supports **ngram-based speculative decoding** (no separate draft model). It retrieves draft tokens from an ngram cache built from previously generated tokens, and then verifies them with the target model.\n\nEnable it with:\n- `--speculative-algorithm NGRAM`\n\n### Ngram-specific parameters\n\n| Parameter | Description | Default |\n|---|---|---|\n| `--speculative-num-draft-tokens` | Number of draft tokens verified per step. If omitted, defaults to `--speculative-ngram-max-match-window-size`. | `12` (with default ngram settings) |\n| `--speculative-ngram-min-match-window-size` | Minimum matching window size. 
| `1` |\n| `--speculative-ngram-max-match-window-size` | Maximum matching window size. | `12` |\n| `--speculative-ngram-min-bfs-breadth` | Minimum BFS breadth. | `1` |\n| `--speculative-ngram-max-bfs-breadth` | Maximum BFS breadth. | `10` |\n| `--speculative-ngram-match-type` | Match type: `\"BFS\"` or `\"PROB\"`. | `\"BFS\"` |\n| `--speculative-ngram-branch-length` | How many recent tokens to insert into the cache. | `18` |\n| `--speculative-ngram-capacity` | Cache capacity (number of entries). | `10,000,000` |\n\nNotes:\n- Ngram speculative decoding **only supports CUDA**.\n- It currently **does not support** `--enable-dp-attention`.\n- It disables the overlap scheduler and mixed chunked prefill.\n- If `--speculative-ngram-max-bfs-breadth > 1` (thus `speculative_eagle_topk > 1`) and `page_size > 1`, use `--attention-backend flashinfer`; otherwise the server will error.\n- Optional: set `SGLANG_NGRAM_FORCE_GREEDY_VERIFY=True` to force greedy verification.\n\n```bash\npython3 -m sglang.launch_server \\\n    --model Qwen/Qwen2.5-7B-Instruct \\\n    --speculative-algorithm NGRAM \\\n    --speculative-num-draft-tokens 16 \\\n    --speculative-ngram-max-match-window-size 12 \\\n    --speculative-ngram-max-bfs-breadth 10 \\\n    --mem-fraction-static 0.7 \\\n    --cuda-graph-max-bs 8 \\\n    --log-level warning\n```\n\n**Send a request:**\n\n```python\nimport openai\n\nclient = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n\nresponse = client.chat.completions.create(\n    model=\"Qwen/Qwen2.5-7B-Instruct\",\n    messages=[\n        {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n    ],\n    temperature=0,\n    max_tokens=64,\n)\n\nprint(response.choices[0].message.content)\n```\n\n---\n\n## Full Parameter Reference\n\nBelow is a comprehensive list of all speculative decoding parameters available in SGLang:\n\n### Core parameters\n\n| Parameter | Type | Default | Description |\n|---|---|---|---|\n| 
`--speculative-algorithm` | `str` | `None` | Algorithm to use: `EAGLE`, `EAGLE3`, `STANDALONE`, `NGRAM`, `NEXTN` (alias of `EAGLE`) |\n| `--speculative-draft-model-path` | `str` | `None` | Path to the draft model weights |\n| `--speculative-draft-model-revision` | `str` | `None` | Specific revision/commit of the draft model (`\"main\"` is auto-used when draft path is set and revision is omitted) |\n| `--speculative-draft-load-format` | `str` | `None` | Load format for draft model weights |\n| `--speculative-num-steps` | `int` | `None` (auto-chosen when omitted) | Autoregressive drafting depth |\n| `--speculative-eagle-topk` | `int` | `None` (auto-chosen when omitted) | Branching factor per drafting step |\n| `--speculative-num-draft-tokens` | `int` | `None` (auto-chosen when omitted) | Maximum number of draft tokens for verification |\n| `--speculative-accept-threshold-single` | `float` | `1.0` | Single-token acceptance threshold |\n| `--speculative-accept-threshold-acc` | `float` | `1.0` | Accumulated acceptance threshold |\n| `--speculative-token-map` | `str` | `None` | Path to FR-Spec high-frequency token map |\n| `--speculative-attention-mode` | `str` | `\"prefill\"` | Attention mode for speculative operations (`\"prefill\"` or `\"decode\"`) |\n| `--speculative-draft-attention-backend` | `str` | `None` | Override attention backend for the draft model |\n| `--speculative-moe-runner-backend` | `str` | `None` | MoE runner backend for the draft model |\n| `--speculative-moe-a2a-backend` | `str` | `None` | MoE all-to-all backend for the draft model |\n| `--speculative-draft-model-quantization` | `str` | Same as target | Quantization for the draft model (`\"unquant\"` to disable) |\n\n### Ngram-specific parameters\n\n| Parameter | Type | Default | Description |\n|---|---|---|---|\n| `--speculative-ngram-min-match-window-size` | `int` | `1` | Minimum ngram matching window |\n| `--speculative-ngram-max-match-window-size` | `int` | `12` | Maximum ngram matching window 
|\n| `--speculative-ngram-min-bfs-breadth` | `int` | `1` | Minimum BFS breadth |\n| `--speculative-ngram-max-bfs-breadth` | `int` | `10` | Maximum BFS breadth |\n| `--speculative-ngram-match-type` | `str` | `\"BFS\"` | Match type: `\"BFS\"` or `\"PROB\"` |\n| `--speculative-ngram-branch-length` | `int` | `18` | Recent tokens to insert into cache |\n| `--speculative-ngram-capacity` | `int` | `10,000,000` | Cache capacity |\n\n### Environment variables\n\n| Variable | Default | Description |\n|---|---|---|\n| `SGLANG_ENABLE_SPEC_V2` | `False` | Enable Speculative Decoding V2 (overlap scheduler) |\n| `SGLANG_NGRAM_FORCE_GREEDY_VERIFY` | `False` | Force greedy verification for ngram decoding |\n\n### Other related flags\n\n| Parameter | Description |\n|---|---|\n| `--enable-multi-layer-eagle` | Enable multi-layer EAGLE (auto-enabled for MiMoV2 and Step3p5 models) |\n| `--enable-torch-compile` | Enable `torch.compile` for kernel-level optimizations |\n| `--torch-compile-max-bs` | Maximum batch size for `torch.compile` |\n\n---\n\n## OOM Troubleshooting\n\n> [!WARNING]\n> **Out of Memory (OOM)?** Speculative decoding may increase GPU memory usage because the draft tree, CUDA graphs, and verification-related buffers consume additional VRAM. 
If you encounter OOM errors, try the following adjustments.\n\n### Step 1: Lower static memory fraction (most effective)\n\n```bash\n--mem-fraction-static 0.5   # when omitted, this value is auto-computed\n```\n\n- `--mem-fraction-static` controls the memory budget for model weights and the KV cache pool.\n- Lowering it directly increases dynamic headroom for activations and CUDA graph buffers.\n- If omitted, SGLang auto-estimates this value from other settings, and those auto settings can still be too aggressive for some workloads.\n\n### Step 2: Reduce CUDA graph batch size\n\n```bash\n# Fewer CUDA graph captures = less memory reserved\n--cuda-graph-max-bs 4   # or even 2 for tight memory situations\n```\n\n- If omitted, `--cuda-graph-max-bs` is auto-selected based on GPU memory and TP size, and can be much larger on high-memory GPUs.\n\n### Step 3: Reduce draft tree size\n\nThese three parameters directly control how much memory the draft tree consumes:\n\n```bash\n# Before (aggressive, high memory)\n--speculative-num-steps 5 --speculative-eagle-topk 8 --speculative-num-draft-tokens 64\n\n# After (conservative, lower memory)\n--speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4\n```\n\n### Step 4: Limit concurrent requests\n\n```bash\n# Fewer concurrent requests lower in-flight load and can reduce OOM risk\n--max-running-requests 4\n```\n\n### Quick OOM recovery recipe\n\nIf you're hitting OOM and just want something that works, start with this minimal configuration and scale up:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model <your-model> \\\n    --speculative-algorithm EAGLE \\\n    --speculative-draft-model-path <your-draft-model> \\\n    --speculative-num-steps 3 \\\n    --speculative-eagle-topk 1 \\\n    --speculative-num-draft-tokens 4 \\\n    --cuda-graph-max-bs 2 \\\n    --mem-fraction-static 0.5 \\\n    --max-running-requests 4 \\\n    --log-level warning\n```\n\nThen gradually increase 
`--speculative-num-draft-tokens`, `--speculative-eagle-topk`, and `--cuda-graph-max-bs`. Increase `--mem-fraction-static` last, only after the run is stable.\n\n---\n\n## References\n\nThe EAGLE process is as follows:\n\n- Within EAGLE, the draft model predicts the next feature vector, i.e. the last hidden state of the original LLM, using the feature sequence $(f_1, ..., f_k)$ and the token sequence $(t_2, ..., t_{k+1})$.\n- The next token is then sampled from $p_{k+2}=\\text{LMHead}(f_{k+1})$. Afterwards, the two sequences are extended in a tree style, branching out into multiple potential continuations (the branching factor per step is controlled by the `speculative_eagle_topk` parameter) to ensure a more coherent connection of context, and are fed back as input.\n- In SGLang's EAGLE-2 implementation, the draft tree is expanded for the configured steps and then reranked to select the top `speculative_num_draft_tokens` final nodes as draft tokens.\n- EAGLE-3 removes the feature prediction objective, incorporates low- and mid-layer features, and is trained in an on-policy manner.\n\nEAGLE improves drafting accuracy by operating on features instead of tokens, which gives the draft model more regular inputs, and by additionally passing tokens from the next timestep to reduce sampling randomness. For more details, see the [EAGLE-2](https://arxiv.org/abs/2406.16858) and [EAGLE-3](https://arxiv.org/abs/2503.01840) papers.\n\nFor guidance on how to train your own EAGLE model, see the [EAGLE repo](https://github.com/SafeAILab/EAGLE/tree/main?tab=readme-ov-file#train). For EAGLE-3 training specifically, check out [SpecForge](https://github.com/sgl-project/SpecForge), the SGLang team's training framework for EAGLE-3 speculative decoding models, which ports seamlessly to SGLang serving. See the [SpecForge documentation](https://docs.sglang.ai/SpecForge/) and [blog post](https://lmsys.org/blog/2025-07-25-spec-forge) for details.\n"
  },
  {
    "path": "docs/advanced_features/structured_outputs.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Structured Outputs\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"You can specify a JSON schema, [regular expression](https://en.wikipedia.org/wiki/Regular_expression), or [EBNF](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form) to constrain the model output. The output is guaranteed to follow the given constraints. Only one constraint parameter (`json_schema`, `regex`, or `ebnf`) can be specified per request.\\n\",\n    \"\\n\",\n    \"SGLang supports three grammar backends:\\n\",\n    \"\\n\",\n    \"- [XGrammar](https://github.com/mlc-ai/xgrammar) (default): Supports JSON schema, regular expression, and EBNF constraints.\\n\",\n    \"- [Outlines](https://github.com/dottxt-ai/outlines): Supports JSON schema and regular expression constraints.\\n\",\n    \"- [Llguidance](https://github.com/guidance-ai/llguidance): Supports JSON schema, regular expression, and EBNF constraints.\\n\",\n    \"\\n\",\n    \"We suggest using XGrammar for its better performance and utility. XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md). 
For more details, see the [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\\n\",\n    \"\\n\",\n    \"To use Outlines, simply add `--grammar-backend outlines` when launching the server.\\n\",\n    \"To use llguidance, add `--grammar-backend llguidance` when launching the server.\\n\",\n    \"If no backend is specified, XGrammar will be used as the default.\\n\",\n    \"\\n\",\n    \"For better output quality, **it's advisable to explicitly include instructions in the prompt to guide the model to generate the desired format.** For example, you can specify: 'Please generate the output in the following JSON format: ...'.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## OpenAI Compatible API\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import openai\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"from sglang.test.doc_patch import launch_server_cmd\\n\",\n    \"from sglang.utils import wait_for_server, print_highlight, terminate_process\\n\",\n    \"\\n\",\n    \"os.environ[\\\"TOKENIZERS_PARALLELISM\\\"] = \\\"false\\\"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"server_process, port = launch_server_cmd(\\n\",\n    \"    \\\"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0 --log-level warning\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=server_process)\\n\",\n    \"client = openai.Client(base_url=f\\\"http://127.0.0.1:{port}/v1\\\", api_key=\\\"None\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### JSON\\n\",\n    \"\\n\",\n    \"You can directly define a JSON schema or use [Pydantic](https://docs.pydantic.dev/latest/) to define and validate the response.\"\n   ]\n  
},\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**Using Pydantic**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pydantic import BaseModel, Field\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Define the schema using Pydantic\\n\",\n    \"class CapitalInfo(BaseModel):\\n\",\n    \"    name: str = Field(..., pattern=r\\\"^\\\\w+$\\\", description=\\\"Name of the capital city\\\")\\n\",\n    \"    population: int = Field(..., description=\\\"Population of the capital city\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"meta-llama/Meta-Llama-3.1-8B-Instruct\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"user\\\",\\n\",\n    \"            \\\"content\\\": \\\"Please generate the information of the capital of France in the JSON format.\\\",\\n\",\n    \"        },\\n\",\n    \"    ],\\n\",\n    \"    temperature=0,\\n\",\n    \"    max_tokens=128,\\n\",\n    \"    response_format={\\n\",\n    \"        \\\"type\\\": \\\"json_schema\\\",\\n\",\n    \"        \\\"json_schema\\\": {\\n\",\n    \"            \\\"name\\\": \\\"foo\\\",\\n\",\n    \"            # convert the pydantic model to json schema\\n\",\n    \"            \\\"schema\\\": CapitalInfo.model_json_schema(),\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"response_content = response.choices[0].message.content\\n\",\n    \"# validate the JSON response by the pydantic model\\n\",\n    \"capital_info = CapitalInfo.model_validate_json(response_content)\\n\",\n    \"print_highlight(f\\\"Validated response: {capital_info.model_dump_json()}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**JSON Schema Directly**\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n  
 \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import json\\n\",\n    \"\\n\",\n    \"json_schema = json.dumps(\\n\",\n    \"    {\\n\",\n    \"        \\\"type\\\": \\\"object\\\",\\n\",\n    \"        \\\"properties\\\": {\\n\",\n    \"            \\\"name\\\": {\\\"type\\\": \\\"string\\\", \\\"pattern\\\": \\\"^[\\\\\\\\w]+$\\\"},\\n\",\n    \"            \\\"population\\\": {\\\"type\\\": \\\"integer\\\"},\\n\",\n    \"        },\\n\",\n    \"        \\\"required\\\": [\\\"name\\\", \\\"population\\\"],\\n\",\n    \"    }\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"meta-llama/Meta-Llama-3.1-8B-Instruct\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"user\\\",\\n\",\n    \"            \\\"content\\\": \\\"Give me the information of the capital of France in the JSON format.\\\",\\n\",\n    \"        },\\n\",\n    \"    ],\\n\",\n    \"    temperature=0,\\n\",\n    \"    max_tokens=128,\\n\",\n    \"    response_format={\\n\",\n    \"        \\\"type\\\": \\\"json_schema\\\",\\n\",\n    \"        \\\"json_schema\\\": {\\\"name\\\": \\\"foo\\\", \\\"schema\\\": json.loads(json_schema)},\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(response.choices[0].message.content)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### EBNF\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"ebnf_grammar = \\\"\\\"\\\"\\n\",\n    \"root ::= city | description\\n\",\n    \"city ::= \\\"London\\\" | \\\"Paris\\\" | \\\"Berlin\\\" | \\\"Rome\\\"\\n\",\n    \"description ::= city \\\" is \\\" status\\n\",\n    \"status ::= \\\"the capital of \\\" country\\n\",\n    \"country ::= \\\"England\\\" | \\\"France\\\" | \\\"Germany\\\" | 
\\\"Italy\\\"\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"meta-llama/Meta-Llama-3.1-8B-Instruct\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\\"role\\\": \\\"system\\\", \\\"content\\\": \\\"You are a helpful geography bot.\\\"},\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"user\\\",\\n\",\n    \"            \\\"content\\\": \\\"Give me the information of the capital of France.\\\",\\n\",\n    \"        },\\n\",\n    \"    ],\\n\",\n    \"    temperature=0,\\n\",\n    \"    max_tokens=32,\\n\",\n    \"    extra_body={\\\"ebnf\\\": ebnf_grammar},\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(response.choices[0].message.content)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Regular expression\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"meta-llama/Meta-Llama-3.1-8B-Instruct\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"What is the capital of France?\\\"},\\n\",\n    \"    ],\\n\",\n    \"    temperature=0,\\n\",\n    \"    max_tokens=128,\\n\",\n    \"    extra_body={\\\"regex\\\": \\\"(Paris|London)\\\"},\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(response.choices[0].message.content)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Structural Tag\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tool_get_current_weather = {\\n\",\n    \"    \\\"type\\\": \\\"function\\\",\\n\",\n    \"    \\\"function\\\": {\\n\",\n    \"        \\\"name\\\": \\\"get_current_weather\\\",\\n\",\n    \"        
\\\"description\\\": \\\"Get the current weather in a given location\\\",\\n\",\n    \"        \\\"parameters\\\": {\\n\",\n    \"            \\\"type\\\": \\\"object\\\",\\n\",\n    \"            \\\"properties\\\": {\\n\",\n    \"                \\\"city\\\": {\\n\",\n    \"                    \\\"type\\\": \\\"string\\\",\\n\",\n    \"                    \\\"description\\\": \\\"The city to find the weather for, e.g. 'San Francisco'\\\",\\n\",\n    \"                },\\n\",\n    \"                \\\"state\\\": {\\n\",\n    \"                    \\\"type\\\": \\\"string\\\",\\n\",\n    \"                    \\\"description\\\": \\\"the two-letter abbreviation for the state that the city is\\\"\\n\",\n    \"                    \\\" in, e.g. 'CA' which would mean 'California'\\\",\\n\",\n    \"                },\\n\",\n    \"                \\\"unit\\\": {\\n\",\n    \"                    \\\"type\\\": \\\"string\\\",\\n\",\n    \"                    \\\"description\\\": \\\"The unit to fetch the temperature in\\\",\\n\",\n    \"                    \\\"enum\\\": [\\\"celsius\\\", \\\"fahrenheit\\\"],\\n\",\n    \"                },\\n\",\n    \"            },\\n\",\n    \"            \\\"required\\\": [\\\"city\\\", \\\"state\\\", \\\"unit\\\"],\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"tool_get_current_date = {\\n\",\n    \"    \\\"type\\\": \\\"function\\\",\\n\",\n    \"    \\\"function\\\": {\\n\",\n    \"        \\\"name\\\": \\\"get_current_date\\\",\\n\",\n    \"        \\\"description\\\": \\\"Get the current date and time for a given timezone\\\",\\n\",\n    \"        \\\"parameters\\\": {\\n\",\n    \"            \\\"type\\\": \\\"object\\\",\\n\",\n    \"            \\\"properties\\\": {\\n\",\n    \"                \\\"timezone\\\": {\\n\",\n    \"                    \\\"type\\\": \\\"string\\\",\\n\",\n    \"                    \\\"description\\\": \\\"The timezone to fetch the current date and time for, 
e.g. 'America/New_York'\\\",\\n\",\n    \"                }\\n\",\n    \"            },\\n\",\n    \"            \\\"required\\\": [\\\"timezone\\\"],\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"schema_get_current_weather = tool_get_current_weather[\\\"function\\\"][\\\"parameters\\\"]\\n\",\n    \"schema_get_current_date = tool_get_current_date[\\\"function\\\"][\\\"parameters\\\"]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def get_messages():\\n\",\n    \"    return [\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"system\\\",\\n\",\n    \"            \\\"content\\\": f\\\"\\\"\\\"\\n\",\n    \"# Tool Instructions\\n\",\n    \"- Always execute python code in messages that you share.\\n\",\n    \"- When looking for real time information use relevant functions if available else fallback to brave_search\\n\",\n    \"You have access to the following functions:\\n\",\n    \"Use the function 'get_current_weather' to: Get the current weather in a given location\\n\",\n    \"{tool_get_current_weather[\\\"function\\\"]}\\n\",\n    \"Use the function 'get_current_date' to: Get the current date and time for a given timezone\\n\",\n    \"{tool_get_current_date[\\\"function\\\"]}\\n\",\n    \"If a you choose to call a function ONLY reply in the following format:\\n\",\n    \"<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}\\n\",\n    \"where\\n\",\n    \"start_tag => `<function`\\n\",\n    \"parameters => a JSON dict with the function argument name as key and function argument value as value.\\n\",\n    \"end_tag => `</function>`\\n\",\n    \"Here is an example,\\n\",\n    \"<function=example_function_name>{{\\\"example_name\\\": \\\"example_value\\\"}}</function>\\n\",\n    \"Reminder:\\n\",\n    \"- Function calls MUST follow the specified format\\n\",\n    \"- Required parameters MUST be specified\\n\",\n    \"- Only call one function at a time\\n\",\n    \"- Put the entire function call reply on one 
line\\n\",\n    \"- Always add your sources when using search results to answer the user query\\n\",\n    \"You are a helpful assistant.\\\"\\\"\\\",\\n\",\n    \"        },\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"user\\\",\\n\",\n    \"            \\\"content\\\": \\\"You are in New York. Please get the current date and time, and the weather.\\\",\\n\",\n    \"        },\\n\",\n    \"    ]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"messages = get_messages()\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"meta-llama/Meta-Llama-3.1-8B-Instruct\\\",\\n\",\n    \"    messages=messages,\\n\",\n    \"    response_format={\\n\",\n    \"        \\\"type\\\": \\\"structural_tag\\\",\\n\",\n    \"        \\\"structures\\\": [\\n\",\n    \"            {\\n\",\n    \"                \\\"begin\\\": \\\"<function=get_current_weather>\\\",\\n\",\n    \"                \\\"schema\\\": schema_get_current_weather,\\n\",\n    \"                \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"            },\\n\",\n    \"            {\\n\",\n    \"                \\\"begin\\\": \\\"<function=get_current_date>\\\",\\n\",\n    \"                \\\"schema\\\": schema_get_current_date,\\n\",\n    \"                \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"            },\\n\",\n    \"        ],\\n\",\n    \"        \\\"triggers\\\": [\\\"<function=\\\"],\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(response.choices[0].message.content)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Support for XGrammar latest structural tag format\\n\",\n    \"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"meta-llama/Meta-Llama-3.1-8B-Instruct\\\",\\n\",\n    \"    
messages=messages,\\n\",\n    \"    response_format={\\n\",\n    \"        \\\"type\\\": \\\"structural_tag\\\",\\n\",\n    \"        \\\"format\\\": {\\n\",\n    \"            \\\"type\\\": \\\"triggered_tags\\\",\\n\",\n    \"            \\\"triggers\\\": [\\\"<function=\\\"],\\n\",\n    \"            \\\"tags\\\": [\\n\",\n    \"                {\\n\",\n    \"                    \\\"begin\\\": \\\"<function=get_current_weather>\\\",\\n\",\n    \"                    \\\"content\\\": {\\n\",\n    \"                        \\\"type\\\": \\\"json_schema\\\",\\n\",\n    \"                        \\\"json_schema\\\": schema_get_current_weather,\\n\",\n    \"                    },\\n\",\n    \"                    \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"                },\\n\",\n    \"                {\\n\",\n    \"                    \\\"begin\\\": \\\"<function=get_current_date>\\\",\\n\",\n    \"                    \\\"content\\\": {\\n\",\n    \"                        \\\"type\\\": \\\"json_schema\\\",\\n\",\n    \"                        \\\"json_schema\\\": schema_get_current_date,\\n\",\n    \"                    },\\n\",\n    \"                    \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"                },\\n\",\n    \"            ],\\n\",\n    \"            \\\"at_least_one\\\": False,\\n\",\n    \"            \\\"stop_after_first\\\": False,\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(response.choices[0].message.content)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Native API and SGLang Runtime (SRT)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### JSON\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**Using Pydantic**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"import requests\\n\",\n    \"import json\\n\",\n    \"from pydantic import BaseModel, Field\\n\",\n    \"\\n\",\n    \"from transformers import AutoTokenizer\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"meta-llama/Meta-Llama-3.1-8B-Instruct\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Define the schema using Pydantic\\n\",\n    \"class CapitalInfo(BaseModel):\\n\",\n    \"    name: str = Field(..., pattern=r\\\"^\\\\w+$\\\", description=\\\"Name of the capital city\\\")\\n\",\n    \"    population: int = Field(..., description=\\\"Population of the capital city\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Make API request\\n\",\n    \"messages = [\\n\",\n    \"    {\\n\",\n    \"        \\\"role\\\": \\\"user\\\",\\n\",\n    \"        \\\"content\\\": \\\"Here is the information of the capital of France in the JSON format.\\\\n\\\",\\n\",\n    \"    }\\n\",\n    \"]\\n\",\n    \"text = tokenizer.apply_chat_template(\\n\",\n    \"    messages, tokenize=False, add_generation_prompt=True, return_dict=False\\n\",\n    \")\\n\",\n    \"response = requests.post(\\n\",\n    \"    f\\\"http://localhost:{port}/generate\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"text\\\": text,\\n\",\n    \"        \\\"sampling_params\\\": {\\n\",\n    \"            \\\"temperature\\\": 0,\\n\",\n    \"            \\\"max_new_tokens\\\": 64,\\n\",\n    \"            \\\"json_schema\\\": json.dumps(CapitalInfo.model_json_schema()),\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"print_highlight(response.json())\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"response_data = json.loads(response.json()[\\\"text\\\"])\\n\",\n    \"# validate the response by the pydantic model\\n\",\n    \"capital_info = CapitalInfo.model_validate(response_data)\\n\",\n    \"print_highlight(f\\\"Validated response: {capital_info.model_dump_json()}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   
\"metadata\": {},\n   \"source\": [\n    \"**JSON Schema Directly**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"json_schema = json.dumps(\\n\",\n    \"    {\\n\",\n    \"        \\\"type\\\": \\\"object\\\",\\n\",\n    \"        \\\"properties\\\": {\\n\",\n    \"            \\\"name\\\": {\\\"type\\\": \\\"string\\\", \\\"pattern\\\": \\\"^[\\\\\\\\w]+$\\\"},\\n\",\n    \"            \\\"population\\\": {\\\"type\\\": \\\"integer\\\"},\\n\",\n    \"        },\\n\",\n    \"        \\\"required\\\": [\\\"name\\\", \\\"population\\\"],\\n\",\n    \"    }\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# JSON\\n\",\n    \"response = requests.post(\\n\",\n    \"    f\\\"http://localhost:{port}/generate\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"text\\\": text,\\n\",\n    \"        \\\"sampling_params\\\": {\\n\",\n    \"            \\\"temperature\\\": 0,\\n\",\n    \"            \\\"max_new_tokens\\\": 64,\\n\",\n    \"            \\\"json_schema\\\": json_schema,\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### EBNF\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"messages = [\\n\",\n    \"    {\\n\",\n    \"        \\\"role\\\": \\\"user\\\",\\n\",\n    \"        \\\"content\\\": \\\"Give me the information of the capital of France.\\\",\\n\",\n    \"    }\\n\",\n    \"]\\n\",\n    \"text = tokenizer.apply_chat_template(\\n\",\n    \"    messages, tokenize=False, add_generation_prompt=True, return_dict=False\\n\",\n    \")\\n\",\n    \"response = requests.post(\\n\",\n    \"    f\\\"http://localhost:{port}/generate\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"text\\\": text,\\n\",\n    
\"        \\\"sampling_params\\\": {\\n\",\n    \"            \\\"max_new_tokens\\\": 128,\\n\",\n    \"            \\\"temperature\\\": 0,\\n\",\n    \"            \\\"n\\\": 3,\\n\",\n    \"            \\\"ebnf\\\": (\\n\",\n    \"                \\\"root ::= city | description\\\\n\\\"\\n\",\n    \"                'city ::= \\\"London\\\" | \\\"Paris\\\" | \\\"Berlin\\\" | \\\"Rome\\\"\\\\n'\\n\",\n    \"                'description ::= city \\\" is \\\" status\\\\n'\\n\",\n    \"                'status ::= \\\"the capital of \\\" country\\\\n'\\n\",\n    \"                'country ::= \\\"England\\\" | \\\"France\\\" | \\\"Germany\\\" | \\\"Italy\\\"'\\n\",\n    \"            ),\\n\",\n    \"        },\\n\",\n    \"        \\\"stream\\\": False,\\n\",\n    \"        \\\"return_logprob\\\": False,\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Regular expression\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"messages = [\\n\",\n    \"    {\\n\",\n    \"        \\\"role\\\": \\\"user\\\",\\n\",\n    \"        \\\"content\\\": \\\"Paris is the capital of\\\",\\n\",\n    \"    }\\n\",\n    \"]\\n\",\n    \"text = tokenizer.apply_chat_template(\\n\",\n    \"    messages, tokenize=False, add_generation_prompt=True, return_dict=False\\n\",\n    \")\\n\",\n    \"response = requests.post(\\n\",\n    \"    f\\\"http://localhost:{port}/generate\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"text\\\": text,\\n\",\n    \"        \\\"sampling_params\\\": {\\n\",\n    \"            \\\"temperature\\\": 0,\\n\",\n    \"            \\\"max_new_tokens\\\": 64,\\n\",\n    \"            \\\"regex\\\": \\\"(France|England)\\\",\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    
\"print_highlight(response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Structural Tag\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\\n\",\n    \"\\n\",\n    \"# generate an answer\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"meta-llama/Meta-Llama-3.1-8B-Instruct\\\")\\n\",\n    \"\\n\",\n    \"text = tokenizer.apply_chat_template(\\n\",\n    \"    messages, tokenize=False, add_generation_prompt=True, return_dict=False\\n\",\n    \")\\n\",\n    \"payload = {\\n\",\n    \"    \\\"text\\\": text,\\n\",\n    \"    \\\"sampling_params\\\": {\\n\",\n    \"        \\\"structural_tag\\\": json.dumps(\\n\",\n    \"            {\\n\",\n    \"                \\\"type\\\": \\\"structural_tag\\\",\\n\",\n    \"                \\\"structures\\\": [\\n\",\n    \"                    {\\n\",\n    \"                        \\\"begin\\\": \\\"<function=get_current_weather>\\\",\\n\",\n    \"                        \\\"schema\\\": schema_get_current_weather,\\n\",\n    \"                        \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"                    },\\n\",\n    \"                    {\\n\",\n    \"                        \\\"begin\\\": \\\"<function=get_current_date>\\\",\\n\",\n    \"                        \\\"schema\\\": schema_get_current_date,\\n\",\n    \"                        \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"                    },\\n\",\n    \"                ],\\n\",\n    \"                \\\"triggers\\\": [\\\"<function=\\\"],\\n\",\n    \"            }\\n\",\n    \"        )\\n\",\n    \"    },\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Send POST request to the API endpoint\\n\",\n    \"response = requests.post(f\\\"http://localhost:{port}/generate\\\", json=payload)\\n\",\n    
\"print_highlight(response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Support for XGrammar latest structural tag format\\n\",\n    \"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\\n\",\n    \"\\n\",\n    \"payload = {\\n\",\n    \"    \\\"text\\\": text,\\n\",\n    \"    \\\"sampling_params\\\": {\\n\",\n    \"        \\\"structural_tag\\\": json.dumps(\\n\",\n    \"            {\\n\",\n    \"                \\\"type\\\": \\\"structural_tag\\\",\\n\",\n    \"                \\\"format\\\": {\\n\",\n    \"                    \\\"type\\\": \\\"triggered_tags\\\",\\n\",\n    \"                    \\\"triggers\\\": [\\\"<function=\\\"],\\n\",\n    \"                    \\\"tags\\\": [\\n\",\n    \"                        {\\n\",\n    \"                            \\\"begin\\\": \\\"<function=get_current_weather>\\\",\\n\",\n    \"                            \\\"content\\\": {\\n\",\n    \"                                \\\"type\\\": \\\"json_schema\\\",\\n\",\n    \"                                \\\"json_schema\\\": schema_get_current_weather,\\n\",\n    \"                            },\\n\",\n    \"                            \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"                        },\\n\",\n    \"                        {\\n\",\n    \"                            \\\"begin\\\": \\\"<function=get_current_date>\\\",\\n\",\n    \"                            \\\"content\\\": {\\n\",\n    \"                                \\\"type\\\": \\\"json_schema\\\",\\n\",\n    \"                                \\\"json_schema\\\": schema_get_current_date,\\n\",\n    \"                            },\\n\",\n    \"                            \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"                        },\\n\",\n    \"                    ],\\n\",\n    \"                    \\\"at_least_one\\\": False,\\n\",\n    \"                    
\\\"stop_after_first\\\": False,\\n\",\n    \"                },\\n\",\n    \"            }\\n\",\n    \"        )\\n\",\n    \"    },\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Send POST request to the API endpoint\\n\",\n    \"response = requests.post(f\\\"http://localhost:{port}/generate\\\", json=payload)\\n\",\n    \"print_highlight(response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Offline Engine API\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sglang as sgl\\n\",\n    \"\\n\",\n    \"llm = sgl.Engine(\\n\",\n    \"    model_path=\\\"meta-llama/Meta-Llama-3.1-8B-Instruct\\\", grammar_backend=\\\"xgrammar\\\"\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### JSON\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**Using Pydantic**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import json\\n\",\n    \"from pydantic import BaseModel, Field\\n\",\n    \"\\n\",\n    \"prompts = [\\n\",\n    \"    \\\"Give me the information of the capital of China in the JSON format.\\\",\\n\",\n    \"    \\\"Give me the information of the capital of France in the JSON format.\\\",\\n\",\n    \"    \\\"Give me the information of the capital of Ireland in the JSON format.\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Define the schema using Pydantic\\n\",\n    \"class CapitalInfo(BaseModel):\\n\",\n    \"    name: str = Field(..., pattern=r\\\"^\\\\w+$\\\", description=\\\"Name of the 
capital city\\\")\\n\",\n    \"    population: int = Field(..., description=\\\"Population of the capital city\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"sampling_params = {\\n\",\n    \"    \\\"temperature\\\": 0.1,\\n\",\n    \"    \\\"top_p\\\": 0.95,\\n\",\n    \"    \\\"json_schema\\\": json.dumps(CapitalInfo.model_json_schema()),\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"outputs = llm.generate(prompts, sampling_params)\\n\",\n    \"for prompt, output in zip(prompts, outputs):\\n\",\n    \"    print_highlight(\\\"===============================\\\")\\n\",\n    \"    print_highlight(f\\\"Prompt: {prompt}\\\")  # validate the output by the pydantic model\\n\",\n    \"    capital_info = CapitalInfo.model_validate_json(output[\\\"text\\\"])\\n\",\n    \"    print_highlight(f\\\"Validated output: {capital_info.model_dump_json()}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**JSON Schema Directly**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prompts = [\\n\",\n    \"    \\\"Give me the information of the capital of China in the JSON format.\\\",\\n\",\n    \"    \\\"Give me the information of the capital of France in the JSON format.\\\",\\n\",\n    \"    \\\"Give me the information of the capital of Ireland in the JSON format.\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"json_schema = json.dumps(\\n\",\n    \"    {\\n\",\n    \"        \\\"type\\\": \\\"object\\\",\\n\",\n    \"        \\\"properties\\\": {\\n\",\n    \"            \\\"name\\\": {\\\"type\\\": \\\"string\\\", \\\"pattern\\\": \\\"^[\\\\\\\\w]+$\\\"},\\n\",\n    \"            \\\"population\\\": {\\\"type\\\": \\\"integer\\\"},\\n\",\n    \"        },\\n\",\n    \"        \\\"required\\\": [\\\"name\\\", \\\"population\\\"],\\n\",\n    \"    }\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"sampling_params = {\\\"temperature\\\": 0.1, 
\\\"top_p\\\": 0.95, \\\"json_schema\\\": json_schema}\\n\",\n    \"\\n\",\n    \"outputs = llm.generate(prompts, sampling_params)\\n\",\n    \"for prompt, output in zip(prompts, outputs):\\n\",\n    \"    print_highlight(\\\"===============================\\\")\\n\",\n    \"    print_highlight(f\\\"Prompt: {prompt}\\\\nGenerated text: {output['text']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### EBNF\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prompts = [\\n\",\n    \"    \\\"Give me the information of the capital of France.\\\",\\n\",\n    \"    \\\"Give me the information of the capital of Germany.\\\",\\n\",\n    \"    \\\"Give me the information of the capital of Italy.\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"sampling_params = {\\n\",\n    \"    \\\"temperature\\\": 0.8,\\n\",\n    \"    \\\"top_p\\\": 0.95,\\n\",\n    \"    \\\"ebnf\\\": (\\n\",\n    \"        \\\"root ::= city | description\\\\n\\\"\\n\",\n    \"        'city ::= \\\"London\\\" | \\\"Paris\\\" | \\\"Berlin\\\" | \\\"Rome\\\"\\\\n'\\n\",\n    \"        'description ::= city \\\" is \\\" status\\\\n'\\n\",\n    \"        'status ::= \\\"the capital of \\\" country\\\\n'\\n\",\n    \"        'country ::= \\\"England\\\" | \\\"France\\\" | \\\"Germany\\\" | \\\"Italy\\\"'\\n\",\n    \"    ),\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"outputs = llm.generate(prompts, sampling_params)\\n\",\n    \"for prompt, output in zip(prompts, outputs):\\n\",\n    \"    print_highlight(\\\"===============================\\\")\\n\",\n    \"    print_highlight(f\\\"Prompt: {prompt}\\\\nGenerated text: {output['text']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Regular expression\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": 
{},\n   \"outputs\": [],\n   \"source\": [\n    \"prompts = [\\n\",\n    \"    \\\"Please provide information about London as a major global city:\\\",\\n\",\n    \"    \\\"Please provide information about Paris as a major global city:\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"sampling_params = {\\\"temperature\\\": 0.8, \\\"top_p\\\": 0.95, \\\"regex\\\": \\\"(France|England)\\\"}\\n\",\n    \"\\n\",\n    \"outputs = llm.generate(prompts, sampling_params)\\n\",\n    \"for prompt, output in zip(prompts, outputs):\\n\",\n    \"    print_highlight(\\\"===============================\\\")\\n\",\n    \"    print_highlight(f\\\"Prompt: {prompt}\\\\nGenerated text: {output['text']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Structural Tag\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"text = tokenizer.apply_chat_template(\\n\",\n    \"    messages, tokenize=False, add_generation_prompt=True, return_dict=False\\n\",\n    \")\\n\",\n    \"prompts = [text]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"sampling_params = {\\n\",\n    \"    \\\"temperature\\\": 0.8,\\n\",\n    \"    \\\"top_p\\\": 0.95,\\n\",\n    \"    \\\"structural_tag\\\": json.dumps(\\n\",\n    \"        {\\n\",\n    \"            \\\"type\\\": \\\"structural_tag\\\",\\n\",\n    \"            \\\"structures\\\": [\\n\",\n    \"                {\\n\",\n    \"                    \\\"begin\\\": \\\"<function=get_current_weather>\\\",\\n\",\n    \"                    \\\"schema\\\": schema_get_current_weather,\\n\",\n    \"                    \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"                },\\n\",\n    \"                {\\n\",\n    \"                    \\\"begin\\\": \\\"<function=get_current_date>\\\",\\n\",\n    \"                    \\\"schema\\\": schema_get_current_date,\\n\",\n    \"                    \\\"end\\\": 
\\\"</function>\\\",\\n\",\n    \"                },\\n\",\n    \"            ],\\n\",\n    \"            \\\"triggers\\\": [\\\"<function=\\\"],\\n\",\n    \"        }\\n\",\n    \"    ),\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Send POST request to the API endpoint\\n\",\n    \"outputs = llm.generate(prompts, sampling_params)\\n\",\n    \"for prompt, output in zip(prompts, outputs):\\n\",\n    \"    print_highlight(\\\"===============================\\\")\\n\",\n    \"    print_highlight(f\\\"Prompt: {prompt}\\\\nGenerated text: {output['text']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Support for XGrammar latest structural tag format\\n\",\n    \"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\\n\",\n    \"\\n\",\n    \"sampling_params = {\\n\",\n    \"    \\\"temperature\\\": 0.8,\\n\",\n    \"    \\\"top_p\\\": 0.95,\\n\",\n    \"    \\\"structural_tag\\\": json.dumps(\\n\",\n    \"        {\\n\",\n    \"            \\\"type\\\": \\\"structural_tag\\\",\\n\",\n    \"            \\\"format\\\": {\\n\",\n    \"                \\\"type\\\": \\\"triggered_tags\\\",\\n\",\n    \"                \\\"triggers\\\": [\\\"<function=\\\"],\\n\",\n    \"                \\\"tags\\\": [\\n\",\n    \"                    {\\n\",\n    \"                        \\\"begin\\\": \\\"<function=get_current_weather>\\\",\\n\",\n    \"                        \\\"content\\\": {\\n\",\n    \"                            \\\"type\\\": \\\"json_schema\\\",\\n\",\n    \"                            \\\"json_schema\\\": schema_get_current_weather,\\n\",\n    \"                        },\\n\",\n    \"                        \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"                    },\\n\",\n    \"                    {\\n\",\n    \"                        \\\"begin\\\": \\\"<function=get_current_date>\\\",\\n\",\n    \"                        
\\\"content\\\": {\\n\",\n    \"                            \\\"type\\\": \\\"json_schema\\\",\\n\",\n    \"                            \\\"json_schema\\\": schema_get_current_date,\\n\",\n    \"                        },\\n\",\n    \"                        \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"                    },\\n\",\n    \"                ],\\n\",\n    \"                \\\"at_least_one\\\": False,\\n\",\n    \"                \\\"stop_after_first\\\": False,\\n\",\n    \"            },\\n\",\n    \"        }\\n\",\n    \"    ),\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Send POST request to the API endpoint\\n\",\n    \"outputs = llm.generate(prompts, sampling_params)\\n\",\n    \"for prompt, output in zip(prompts, outputs):\\n\",\n    \"    print_highlight(\\\"===============================\\\")\\n\",\n    \"    print_highlight(f\\\"Prompt: {prompt}\\\\nGenerated text: {output['text']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"llm.shutdown()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "docs/advanced_features/structured_outputs_for_reasoning_models.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Structured Outputs For Reasoning Models\\n\",\n    \"\\n\",\n    \"When working with reasoning models that use special tokens like `<think>...</think>` to denote reasoning sections, you might want to allow free-form text within these sections while still enforcing grammar constraints on the rest of the output.\\n\",\n    \"\\n\",\n    \"SGLang provides a feature to disable grammar restrictions within reasoning sections. This is particularly useful for models that need to perform complex reasoning steps before providing a structured output.\\n\",\n    \"\\n\",\n    \"To enable this feature, launch the server with the `--reasoning-parser` flag; the selected parser determines the token that ends the reasoning section, such as `</think>`.\\n\",\n    \"\\n\",\n    \"## Supported Models\\n\",\n    \"\\n\",\n    \"Currently, SGLang supports the following reasoning models:\\n\",\n    \"- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d): The reasoning content is wrapped with `<think>` and `</think>` tags.\\n\",\n    \"- [QwQ](https://huggingface.co/Qwen/QwQ-32B): The reasoning content is wrapped with `<think>` and `</think>` tags.\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"## Usage\\n\",\n    \"\\n\",\n    \"## OpenAI Compatible API\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Specify the `--grammar-backend` and `--reasoning-parser` options.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import openai\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"from sglang.test.doc_patch import launch_server_cmd\\n\",\n    \"from sglang.utils import wait_for_server, print_highlight, terminate_process\\n\",\n    
\"\\n\",\n    \"os.environ[\\\"TOKENIZERS_PARALLELISM\\\"] = \\\"false\\\"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"server_process, port = launch_server_cmd(\\n\",\n    \"    \\\"python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --host 0.0.0.0 --reasoning-parser deepseek-r1 --log-level warning\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=server_process)\\n\",\n    \"client = openai.Client(base_url=f\\\"http://127.0.0.1:{port}/v1\\\", api_key=\\\"None\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### JSON\\n\",\n    \"\\n\",\n    \"You can directly define a JSON schema or use [Pydantic](https://docs.pydantic.dev/latest/) to define and validate the response.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**Using Pydantic**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pydantic import BaseModel, Field\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Define the schema using Pydantic\\n\",\n    \"class CapitalInfo(BaseModel):\\n\",\n    \"    name: str = Field(..., pattern=r\\\"^\\\\w+$\\\", description=\\\"Name of the capital city\\\")\\n\",\n    \"    population: int = Field(..., description=\\\"Population of the capital city\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"assistant\\\",\\n\",\n    \"            \\\"content\\\": \\\"Give me the information and population of the capital of France in the JSON format.\\\",\\n\",\n    \"        },\\n\",\n    \"    ],\\n\",\n    \"    temperature=0,\\n\",\n    \"    max_tokens=2048,\\n\",\n    \"    
response_format={\\n\",\n    \"        \\\"type\\\": \\\"json_schema\\\",\\n\",\n    \"        \\\"json_schema\\\": {\\n\",\n    \"            \\\"name\\\": \\\"foo\\\",\\n\",\n    \"            # convert the pydantic model to json schema\\n\",\n    \"            \\\"schema\\\": CapitalInfo.model_json_schema(),\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(\\n\",\n    \"    f\\\"reasoning_content: {response.choices[0].message.reasoning_content}\\\\n\\\\ncontent: {response.choices[0].message.content}\\\"\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**JSON Schema Directly**\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import json\\n\",\n    \"\\n\",\n    \"json_schema = json.dumps(\\n\",\n    \"    {\\n\",\n    \"        \\\"type\\\": \\\"object\\\",\\n\",\n    \"        \\\"properties\\\": {\\n\",\n    \"            \\\"name\\\": {\\\"type\\\": \\\"string\\\", \\\"pattern\\\": \\\"^[\\\\\\\\w]+$\\\"},\\n\",\n    \"            \\\"population\\\": {\\\"type\\\": \\\"integer\\\"},\\n\",\n    \"        },\\n\",\n    \"        \\\"required\\\": [\\\"name\\\", \\\"population\\\"],\\n\",\n    \"    }\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"assistant\\\",\\n\",\n    \"            \\\"content\\\": \\\"Give me the information and population of the capital of France in the JSON format.\\\",\\n\",\n    \"        },\\n\",\n    \"    ],\\n\",\n    \"    temperature=0,\\n\",\n    \"    max_tokens=2048,\\n\",\n    \"    response_format={\\n\",\n    \"        \\\"type\\\": \\\"json_schema\\\",\\n\",\n    \"        \\\"json_schema\\\": {\\\"name\\\": 
\\\"foo\\\", \\\"schema\\\": json.loads(json_schema)},\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(\\n\",\n    \"    f\\\"reasoning_content: {response.choices[0].message.reasoning_content}\\\\n\\\\ncontent: {response.choices[0].message.content}\\\"\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### EBNF\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"ebnf_grammar = \\\"\\\"\\\"\\n\",\n    \"root ::= city | description\\n\",\n    \"city ::= \\\"London\\\" | \\\"Paris\\\" | \\\"Berlin\\\" | \\\"Rome\\\"\\n\",\n    \"description ::= city \\\" is \\\" status\\n\",\n    \"status ::= \\\"the capital of \\\" country\\n\",\n    \"country ::= \\\"England\\\" | \\\"France\\\" | \\\"Germany\\\" | \\\"Italy\\\"\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\\"role\\\": \\\"system\\\", \\\"content\\\": \\\"You are a helpful geography bot.\\\"},\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"assistant\\\",\\n\",\n    \"            \\\"content\\\": \\\"Give me the information and population of the capital of France in the JSON format.\\\",\\n\",\n    \"        },\\n\",\n    \"    ],\\n\",\n    \"    temperature=0,\\n\",\n    \"    max_tokens=2048,\\n\",\n    \"    extra_body={\\\"ebnf\\\": ebnf_grammar},\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(\\n\",\n    \"    f\\\"reasoning_content: {response.choices[0].message.reasoning_content}\\\\n\\\\ncontent: {response.choices[0].message.content}\\\"\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Regular expression\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\\"role\\\": \\\"assistant\\\", \\\"content\\\": \\\"What is the capital of France?\\\"},\\n\",\n    \"    ],\\n\",\n    \"    temperature=0,\\n\",\n    \"    max_tokens=2048,\\n\",\n    \"    extra_body={\\\"regex\\\": \\\"(Paris|London)\\\"},\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(\\n\",\n    \"    f\\\"reasoning_content: {response.choices[0].message.reasoning_content}\\\\n\\\\ncontent: {response.choices[0].message.content}\\\"\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Structural Tag\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tool_get_current_weather = {\\n\",\n    \"    \\\"type\\\": \\\"function\\\",\\n\",\n    \"    \\\"function\\\": {\\n\",\n    \"        \\\"name\\\": \\\"get_current_weather\\\",\\n\",\n    \"        \\\"description\\\": \\\"Get the current weather in a given location\\\",\\n\",\n    \"        \\\"parameters\\\": {\\n\",\n    \"            \\\"type\\\": \\\"object\\\",\\n\",\n    \"            \\\"properties\\\": {\\n\",\n    \"                \\\"city\\\": {\\n\",\n    \"                    \\\"type\\\": \\\"string\\\",\\n\",\n    \"                    \\\"description\\\": \\\"The city to find the weather for, e.g. 'San Francisco'\\\",\\n\",\n    \"                },\\n\",\n    \"                \\\"state\\\": {\\n\",\n    \"                    \\\"type\\\": \\\"string\\\",\\n\",\n    \"                    \\\"description\\\": \\\"the two-letter abbreviation for the state that the city is\\\"\\n\",\n    \"                    \\\" in, e.g. 
'CA' which would mean 'California'\\\",\\n\",\n    \"                },\\n\",\n    \"                \\\"unit\\\": {\\n\",\n    \"                    \\\"type\\\": \\\"string\\\",\\n\",\n    \"                    \\\"description\\\": \\\"The unit to fetch the temperature in\\\",\\n\",\n    \"                    \\\"enum\\\": [\\\"celsius\\\", \\\"fahrenheit\\\"],\\n\",\n    \"                },\\n\",\n    \"            },\\n\",\n    \"            \\\"required\\\": [\\\"city\\\", \\\"state\\\", \\\"unit\\\"],\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"tool_get_current_date = {\\n\",\n    \"    \\\"type\\\": \\\"function\\\",\\n\",\n    \"    \\\"function\\\": {\\n\",\n    \"        \\\"name\\\": \\\"get_current_date\\\",\\n\",\n    \"        \\\"description\\\": \\\"Get the current date and time for a given timezone\\\",\\n\",\n    \"        \\\"parameters\\\": {\\n\",\n    \"            \\\"type\\\": \\\"object\\\",\\n\",\n    \"            \\\"properties\\\": {\\n\",\n    \"                \\\"timezone\\\": {\\n\",\n    \"                    \\\"type\\\": \\\"string\\\",\\n\",\n    \"                    \\\"description\\\": \\\"The timezone to fetch the current date and time for, e.g. 
'America/New_York'\\\",\\n\",\n    \"                }\\n\",\n    \"            },\\n\",\n    \"            \\\"required\\\": [\\\"timezone\\\"],\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"schema_get_current_weather = tool_get_current_weather[\\\"function\\\"][\\\"parameters\\\"]\\n\",\n    \"schema_get_current_date = tool_get_current_date[\\\"function\\\"][\\\"parameters\\\"]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def get_messages():\\n\",\n    \"    return [\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"system\\\",\\n\",\n    \"            \\\"content\\\": f\\\"\\\"\\\"\\n\",\n    \"# Tool Instructions\\n\",\n    \"- Always execute python code in messages that you share.\\n\",\n    \"- When looking for real time information use relevant functions if available else fallback to brave_search\\n\",\n    \"You have access to the following functions:\\n\",\n    \"Use the function 'get_current_weather' to: Get the current weather in a given location\\n\",\n    \"{tool_get_current_weather[\\\"function\\\"]}\\n\",\n    \"Use the function 'get_current_date' to: Get the current date and time for a given timezone\\n\",\n    \"{tool_get_current_date[\\\"function\\\"]}\\n\",\n    \"If you choose to call a function ONLY reply in the following format:\\n\",\n    \"<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}\\n\",\n    \"where\\n\",\n    \"start_tag => `<function`\\n\",\n    \"parameters => a JSON dict with the function argument name as key and function argument value as value.\\n\",\n    \"end_tag => `</function>`\\n\",\n    \"Here is an example,\\n\",\n    \"<function=example_function_name>{{\\\"example_name\\\": \\\"example_value\\\"}}</function>\\n\",\n    \"Reminder:\\n\",\n    \"- Function calls MUST follow the specified format\\n\",\n    \"- Required parameters MUST be specified\\n\",\n    \"- Only call one function at a time\\n\",\n    \"- Put the entire function call reply on one 
line\\n\",\n    \"- Always add your sources when using search results to answer the user query\\n\",\n    \"You are a helpful assistant.\\\"\\\"\\\",\\n\",\n    \"        },\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"assistant\\\",\\n\",\n    \"            \\\"content\\\": \\\"You are in New York. Please get the current date and time, and the weather.\\\",\\n\",\n    \"        },\\n\",\n    \"    ]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"messages = get_messages()\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\\\",\\n\",\n    \"    messages=messages,\\n\",\n    \"    response_format={\\n\",\n    \"        \\\"type\\\": \\\"structural_tag\\\",\\n\",\n    \"        \\\"max_new_tokens\\\": 2048,\\n\",\n    \"        \\\"structures\\\": [\\n\",\n    \"            {\\n\",\n    \"                \\\"begin\\\": \\\"<function=get_current_weather>\\\",\\n\",\n    \"                \\\"schema\\\": schema_get_current_weather,\\n\",\n    \"                \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"            },\\n\",\n    \"            {\\n\",\n    \"                \\\"begin\\\": \\\"<function=get_current_date>\\\",\\n\",\n    \"                \\\"schema\\\": schema_get_current_date,\\n\",\n    \"                \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"            },\\n\",\n    \"        ],\\n\",\n    \"        \\\"triggers\\\": [\\\"<function=\\\"],\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(\\n\",\n    \"    f\\\"reasoning_content: {response.choices[0].message.reasoning_content}\\\\n\\\\ncontent: {response.choices[0].message.content}\\\"\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Native API and SGLang Runtime (SRT)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"> Note: For the native API, as a 
work-around, you need to set `require_reasoning` argument to `True` to ensure the model will think before generating the structured output. It's not required for chat-completion API.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### JSON\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**Using Pydantic**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import requests\\n\",\n    \"from pydantic import BaseModel, Field\\n\",\n    \"from transformers import AutoTokenizer\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Define the schema using Pydantic\\n\",\n    \"class CapitalInfo(BaseModel):\\n\",\n    \"    name: str = Field(..., pattern=r\\\"^\\\\w+$\\\", description=\\\"Name of the capital city\\\")\\n\",\n    \"    population: int = Field(..., description=\\\"Population of the capital city\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"messages = [\\n\",\n    \"    {\\n\",\n    \"        \\\"role\\\": \\\"assistant\\\",\\n\",\n    \"        \\\"content\\\": \\\"Give me the information and population of the capital of France in the JSON format.\\\",\\n\",\n    \"    },\\n\",\n    \"]\\n\",\n    \"text = tokenizer.apply_chat_template(\\n\",\n    \"    messages, tokenize=False, add_generation_prompt=True, return_dict=False\\n\",\n    \")\\n\",\n    \"# Make API request\\n\",\n    \"response = requests.post(\\n\",\n    \"    f\\\"http://localhost:{port}/generate\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"text\\\": text,\\n\",\n    \"        \\\"require_reasoning\\\": True,\\n\",\n    \"        \\\"sampling_params\\\": {\\n\",\n    \"            \\\"temperature\\\": 0,\\n\",\n    \"            \\\"max_new_tokens\\\": 2048,\\n\",\n    \"            
\\\"json_schema\\\": json.dumps(CapitalInfo.model_json_schema()),\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"print(response.json())\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"reasoning_content = response.json()[\\\"text\\\"].split(\\\"</think>\\\")[0]\\n\",\n    \"content = response.json()[\\\"text\\\"].split(\\\"</think>\\\")[1]\\n\",\n    \"print_highlight(f\\\"reasoning_content: {reasoning_content}\\\\n\\\\ncontent: {content}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**JSON Schema Directly**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"json_schema = json.dumps(\\n\",\n    \"    {\\n\",\n    \"        \\\"type\\\": \\\"object\\\",\\n\",\n    \"        \\\"properties\\\": {\\n\",\n    \"            \\\"name\\\": {\\\"type\\\": \\\"string\\\", \\\"pattern\\\": \\\"^[\\\\\\\\w]+$\\\"},\\n\",\n    \"            \\\"population\\\": {\\\"type\\\": \\\"integer\\\"},\\n\",\n    \"        },\\n\",\n    \"        \\\"required\\\": [\\\"name\\\", \\\"population\\\"],\\n\",\n    \"    }\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# JSON\\n\",\n    \"text = tokenizer.apply_chat_template(\\n\",\n    \"    messages, tokenize=False, add_generation_prompt=True, return_dict=False\\n\",\n    \")\\n\",\n    \"response = requests.post(\\n\",\n    \"    f\\\"http://localhost:{port}/generate\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"text\\\": text,\\n\",\n    \"        \\\"require_reasoning\\\": True,\\n\",\n    \"        \\\"sampling_params\\\": {\\n\",\n    \"            \\\"temperature\\\": 0,\\n\",\n    \"            \\\"max_new_tokens\\\": 2048,\\n\",\n    \"            \\\"json_schema\\\": json_schema,\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   
\"metadata\": {},\n   \"source\": [\n    \"### EBNF\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response = requests.post(\\n\",\n    \"    f\\\"http://localhost:{port}/generate\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"text\\\": \\\"Give me the information of the capital of France.\\\",\\n\",\n    \"        \\\"require_reasoning\\\": True,\\n\",\n    \"        \\\"sampling_params\\\": {\\n\",\n    \"            \\\"max_new_tokens\\\": 2048,\\n\",\n    \"            \\\"temperature\\\": 0,\\n\",\n    \"            \\\"n\\\": 3,\\n\",\n    \"            \\\"ebnf\\\": (\\n\",\n    \"                \\\"root ::= city | description\\\\n\\\"\\n\",\n    \"                'city ::= \\\"London\\\" | \\\"Paris\\\" | \\\"Berlin\\\" | \\\"Rome\\\"\\\\n'\\n\",\n    \"                'description ::= city \\\" is \\\" status\\\\n'\\n\",\n    \"                'status ::= \\\"the capital of \\\" country\\\\n'\\n\",\n    \"                'country ::= \\\"England\\\" | \\\"France\\\" | \\\"Germany\\\" | \\\"Italy\\\"'\\n\",\n    \"            ),\\n\",\n    \"        },\\n\",\n    \"        \\\"stream\\\": False,\\n\",\n    \"        \\\"return_logprob\\\": False,\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print(response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Regular expression\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response = requests.post(\\n\",\n    \"    f\\\"http://localhost:{port}/generate\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"text\\\": \\\"Paris is the capital of\\\",\\n\",\n    \"        \\\"require_reasoning\\\": True,\\n\",\n    \"        \\\"sampling_params\\\": {\\n\",\n    \"            \\\"temperature\\\": 0,\\n\",\n    \"            
\\\"max_new_tokens\\\": 2048,\\n\",\n    \"            \\\"regex\\\": \\\"(France|England)\\\",\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"print(response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Structural Tag\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"text = tokenizer.apply_chat_template(\\n\",\n    \"    messages, tokenize=False, add_generation_prompt=True, return_dict=False\\n\",\n    \")\\n\",\n    \"payload = {\\n\",\n    \"    \\\"text\\\": text,\\n\",\n    \"    \\\"require_reasoning\\\": True,\\n\",\n    \"    \\\"sampling_params\\\": {\\n\",\n    \"        \\\"max_new_tokens\\\": 2048,\\n\",\n    \"        \\\"structural_tag\\\": json.dumps(\\n\",\n    \"            {\\n\",\n    \"                \\\"type\\\": \\\"structural_tag\\\",\\n\",\n    \"                \\\"structures\\\": [\\n\",\n    \"                    {\\n\",\n    \"                        \\\"begin\\\": \\\"<function=get_current_weather>\\\",\\n\",\n    \"                        \\\"schema\\\": schema_get_current_weather,\\n\",\n    \"                        \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"                    },\\n\",\n    \"                    {\\n\",\n    \"                        \\\"begin\\\": \\\"<function=get_current_date>\\\",\\n\",\n    \"                        \\\"schema\\\": schema_get_current_date,\\n\",\n    \"                        \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"                    },\\n\",\n    \"                ],\\n\",\n    \"                \\\"triggers\\\": [\\\"<function=\\\"],\\n\",\n    \"            }\\n\",\n    \"        ),\\n\",\n    \"    },\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Send POST request to the API endpoint\\n\",\n    \"response = requests.post(f\\\"http://localhost:{port}/generate\\\", 
json=payload)\\n\",\n    \"print_highlight(response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Offline Engine API\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sglang as sgl\\n\",\n    \"\\n\",\n    \"llm = sgl.Engine(\\n\",\n    \"    model_path=\\\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\\\",\\n\",\n    \"    reasoning_parser=\\\"deepseek-r1\\\",\\n\",\n    \"    grammar_backend=\\\"xgrammar\\\",\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### JSON\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**Using Pydantic**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import json\\n\",\n    \"from pydantic import BaseModel, Field\\n\",\n    \"\\n\",\n    \"prompts = [\\n\",\n    \"    \\\"Give me the information of the capital of China in the JSON format.\\\",\\n\",\n    \"    \\\"Give me the information of the capital of France in the JSON format.\\\",\\n\",\n    \"    \\\"Give me the information of the capital of Ireland in the JSON format.\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Define the schema using Pydantic\\n\",\n    \"class CapitalInfo(BaseModel):\\n\",\n    \"    name: str = Field(..., pattern=r\\\"^\\\\w+$\\\", description=\\\"Name of the capital city\\\")\\n\",\n    \"    population: int = Field(..., description=\\\"Population of the capital city\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"sampling_params = {\\n\",\n    \"    \\\"temperature\\\": 0,\\n\",\n    \"    
\\\"top_p\\\": 0.95,\\n\",\n    \"    \\\"max_new_tokens\\\": 2048,\\n\",\n    \"    \\\"json_schema\\\": json.dumps(CapitalInfo.model_json_schema()),\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"outputs = llm.generate(prompts, sampling_params)\\n\",\n    \"for prompt, output in zip(prompts, outputs):\\n\",\n    \"    print(\\\"===============================\\\")\\n\",\n    \"    print(f\\\"Prompt: {prompt}\\\\nGenerated text: {output['text']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**JSON Schema Directly**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prompts = [\\n\",\n    \"    \\\"Give me the information of the capital of China in the JSON format.\\\",\\n\",\n    \"    \\\"Give me the information of the capital of France in the JSON format.\\\",\\n\",\n    \"    \\\"Give me the information of the capital of Ireland in the JSON format.\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"json_schema = json.dumps(\\n\",\n    \"    {\\n\",\n    \"        \\\"type\\\": \\\"object\\\",\\n\",\n    \"        \\\"properties\\\": {\\n\",\n    \"            \\\"name\\\": {\\\"type\\\": \\\"string\\\", \\\"pattern\\\": \\\"^[\\\\\\\\w]+$\\\"},\\n\",\n    \"            \\\"population\\\": {\\\"type\\\": \\\"integer\\\"},\\n\",\n    \"        },\\n\",\n    \"        \\\"required\\\": [\\\"name\\\", \\\"population\\\"],\\n\",\n    \"    }\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"sampling_params = {\\\"temperature\\\": 0, \\\"max_new_tokens\\\": 2048, \\\"json_schema\\\": json_schema}\\n\",\n    \"\\n\",\n    \"outputs = llm.generate(prompts, sampling_params)\\n\",\n    \"for prompt, output in zip(prompts, outputs):\\n\",\n    \"    print(\\\"===============================\\\")\\n\",\n    \"    print(f\\\"Prompt: {prompt}\\\\nGenerated text: {output['text']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   
\"metadata\": {},\n   \"source\": [\n    \"### EBNF\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prompts = [\\n\",\n    \"    \\\"Give me the information of the capital of France.\\\",\\n\",\n    \"    \\\"Give me the information of the capital of Germany.\\\",\\n\",\n    \"    \\\"Give me the information of the capital of Italy.\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"sampling_params = {\\n\",\n    \"    \\\"temperature\\\": 0.8,\\n\",\n    \"    \\\"top_p\\\": 0.95,\\n\",\n    \"    \\\"ebnf\\\": (\\n\",\n    \"        \\\"root ::= city | description\\\\n\\\"\\n\",\n    \"        'city ::= \\\"London\\\" | \\\"Paris\\\" | \\\"Berlin\\\" | \\\"Rome\\\"\\\\n'\\n\",\n    \"        'description ::= city \\\" is \\\" status\\\\n'\\n\",\n    \"        'status ::= \\\"the capital of \\\" country\\\\n'\\n\",\n    \"        'country ::= \\\"England\\\" | \\\"France\\\" | \\\"Germany\\\" | \\\"Italy\\\"'\\n\",\n    \"    ),\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"outputs = llm.generate(prompts, sampling_params)\\n\",\n    \"for prompt, output in zip(prompts, outputs):\\n\",\n    \"    print(\\\"===============================\\\")\\n\",\n    \"    print(f\\\"Prompt: {prompt}\\\\nGenerated text: {output['text']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Regular expression\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prompts = [\\n\",\n    \"    \\\"Please provide information about London as a major global city:\\\",\\n\",\n    \"    \\\"Please provide information about Paris as a major global city:\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"sampling_params = {\\\"temperature\\\": 0.8, \\\"top_p\\\": 0.95, \\\"regex\\\": \\\"(France|England)\\\"}\\n\",\n    \"\\n\",\n    \"outputs = 
llm.generate(prompts, sampling_params)\\n\",\n    \"for prompt, output in zip(prompts, outputs):\\n\",\n    \"    print(\\\"===============================\\\")\\n\",\n    \"    print(f\\\"Prompt: {prompt}\\\\nGenerated text: {output['text']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"text = tokenizer.apply_chat_template(\\n\",\n    \"    messages, tokenize=False, add_generation_prompt=True, return_dict=False\\n\",\n    \")\\n\",\n    \"prompts = [text]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"sampling_params = {\\n\",\n    \"    \\\"temperature\\\": 0.8,\\n\",\n    \"    \\\"top_p\\\": 0.95,\\n\",\n    \"    \\\"max_new_tokens\\\": 2048,\\n\",\n    \"    \\\"structural_tag\\\": json.dumps(\\n\",\n    \"        {\\n\",\n    \"            \\\"type\\\": \\\"structural_tag\\\",\\n\",\n    \"            \\\"structures\\\": [\\n\",\n    \"                {\\n\",\n    \"                    \\\"begin\\\": \\\"<function=get_current_weather>\\\",\\n\",\n    \"                    \\\"schema\\\": schema_get_current_weather,\\n\",\n    \"                    \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"                },\\n\",\n    \"                {\\n\",\n    \"                    \\\"begin\\\": \\\"<function=get_current_date>\\\",\\n\",\n    \"                    \\\"schema\\\": schema_get_current_date,\\n\",\n    \"                    \\\"end\\\": \\\"</function>\\\",\\n\",\n    \"                },\\n\",\n    \"            ],\\n\",\n    \"            \\\"triggers\\\": [\\\"<function=\\\"],\\n\",\n    \"        }\\n\",\n    \"    ),\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Run offline generation constrained by the structural tag\\n\",\n    \"outputs = llm.generate(prompts, sampling_params)\\n\",\n    \"for prompt, output in zip(prompts, outputs):\\n\",\n    \"    print(\\\"===============================\\\")\\n\",\n    \"    print(f\\\"Prompt: 
{prompt}\\\\nGenerated text: {output['text']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"llm.shutdown()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "docs/advanced_features/tool_parser.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Tool Parser\\n\",\n    \"\\n\",\n    \"This guide demonstrates how to use SGLang’s [Function calling](https://platform.openai.com/docs/guides/function-calling) functionality.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Currently supported parsers:\\n\",\n    \"\\n\",\n    \"| Parser | Supported Models | Notes |\\n\",\n    \"|---|---|---|\\n\",\n    \"| `deepseekv3` | DeepSeek-v3 (e.g., `deepseek-ai/DeepSeek-V3-0324`) | Recommend adding `--chat-template ./examples/chat_template/tool_chat_template_deepseekv3.jinja` to launch command. |\\n\",\n    \"| `deepseekv31` | DeepSeek-V3.1 and DeepSeek-V3.2-Exp (e.g. `deepseek-ai/DeepSeek-V3.1`, `deepseek-ai/DeepSeek-V3.2-Exp`) | Recommend adding `--chat-template ./examples/chat_template/tool_chat_template_deepseekv31.jinja` (Or ..deepseekv32.jinja for DeepSeek-V3.2) to launch command. |\\n\",\n    \"| `deepseekv32` | DeepSeek-V3.2 (`deepseek-ai/DeepSeek-V3.2`) | |\\n\",\n    \"| `glm` | GLM series (e.g. `zai-org/GLM-4.6`) | |\\n\",\n    \"| `gpt-oss` | GPT-OSS (e.g., `openai/gpt-oss-120b`, `openai/gpt-oss-20b`, `lmsys/gpt-oss-120b-bf16`, `lmsys/gpt-oss-20b-bf16`) | The gpt-oss tool parser filters out analysis channel events and only preserves normal text. This can cause the content to be empty when explanations are in the analysis channel. To work around this, complete the tool round by returning tool results as `role=\\\"tool\\\"` messages, which enables the model to generate the final content. |\\n\",\n    \"| `kimi_k2` | `moonshotai/Kimi-K2-Instruct` | |\\n\",\n    \"| `llama3` | Llama 3.1 / 3.2 / 3.3 (e.g. `meta-llama/Llama-3.1-8B-Instruct`, `meta-llama/Llama-3.2-1B-Instruct`, `meta-llama/Llama-3.3-70B-Instruct`) | |\\n\",\n    \"| `llama4` | Llama 4 (e.g. `meta-llama/Llama-4-Scout-17B-16E-Instruct`) | |\\n\",\n    \"| `mistral` | Mistral (e.g. 
`mistralai/Mistral-7B-Instruct-v0.3`, `mistralai/Mistral-Nemo-Instruct-2407`, `mistralai/Mistral-7B-v0.3`) | |\\n\",\n    \"| `pythonic` | Llama-3.2 / Llama-3.3 / Llama-4 | Model outputs function calls as Python code. Requires `--tool-call-parser pythonic` and is recommended to use with a specific chat template. |\\n\",\n    \"| `qwen` | Qwen series (e.g. `Qwen/Qwen3-Next-80B-A3B-Instruct`, `Qwen/Qwen3-VL-30B-A3B-Thinking`) except Qwen3-Coder| |\\n\",\n    \"| `qwen3_coder` | Qwen3-Coder (e.g. `Qwen/Qwen3-Coder-30B-A3B-Instruct`) | |\\n\",\n    \"| `step3` | Step-3 | |\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## OpenAI Compatible API\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Launching the Server\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import json\\n\",\n    \"from sglang.test.doc_patch import launch_server_cmd\\n\",\n    \"from sglang.utils import wait_for_server, print_highlight, terminate_process\\n\",\n    \"from openai import OpenAI\\n\",\n    \"\\n\",\n    \"server_process, port = launch_server_cmd(\\n\",\n    \"    \\\"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0 --log-level warning\\\"  # qwen25\\n\",\n    \")\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Note that `--tool-call-parser` defines the parser used to interpret responses.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Define Tools for Function Call\\n\",\n    \"Below is a Python snippet that shows how to define a tool as a dictionary. 
The dictionary includes the tool name, a description, and a `parameters` JSON Schema whose `properties` define each argument.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Define tools\\n\",\n    \"tools = [\\n\",\n    \"    {\\n\",\n    \"        \\\"type\\\": \\\"function\\\",\\n\",\n    \"        \\\"function\\\": {\\n\",\n    \"            \\\"name\\\": \\\"get_current_weather\\\",\\n\",\n    \"            \\\"description\\\": \\\"Get the current weather in a given location\\\",\\n\",\n    \"            \\\"parameters\\\": {\\n\",\n    \"                \\\"type\\\": \\\"object\\\",\\n\",\n    \"                \\\"properties\\\": {\\n\",\n    \"                    \\\"city\\\": {\\n\",\n    \"                        \\\"type\\\": \\\"string\\\",\\n\",\n    \"                        \\\"description\\\": \\\"The city to find the weather for, e.g. 'San Francisco'\\\",\\n\",\n    \"                    },\\n\",\n    \"                    \\\"state\\\": {\\n\",\n    \"                        \\\"type\\\": \\\"string\\\",\\n\",\n    \"                        \\\"description\\\": \\\"the two-letter abbreviation for the state that the city is\\\"\\n\",\n    \"                        \\\" in, e.g. 
'CA' which would mean 'California'\\\",\\n\",\n    \"                    },\\n\",\n    \"                    \\\"unit\\\": {\\n\",\n    \"                        \\\"type\\\": \\\"string\\\",\\n\",\n    \"                        \\\"description\\\": \\\"The unit to fetch the temperature in\\\",\\n\",\n    \"                        \\\"enum\\\": [\\\"celsius\\\", \\\"fahrenheit\\\"],\\n\",\n    \"                    },\\n\",\n    \"                },\\n\",\n    \"                \\\"required\\\": [\\\"city\\\", \\\"state\\\", \\\"unit\\\"],\\n\",\n    \"            },\\n\",\n    \"        },\\n\",\n    \"    }\\n\",\n    \"]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Define Messages\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def get_messages():\\n\",\n    \"    return [\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"user\\\",\\n\",\n    \"            \\\"content\\\": \\\"What's the weather like in Boston today? 
Output a reasoning before act, then use the tools to help you.\\\",\\n\",\n    \"        }\\n\",\n    \"    ]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"messages = get_messages()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Initialize the Client\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Initialize OpenAI-like client\\n\",\n    \"client = OpenAI(api_key=\\\"None\\\", base_url=f\\\"http://0.0.0.0:{port}/v1\\\")\\n\",\n    \"model_name = client.models.list().data[0].id\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"###  Non-Streaming Request\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Non-streaming mode test\\n\",\n    \"response_non_stream = client.chat.completions.create(\\n\",\n    \"    model=model_name,\\n\",\n    \"    messages=messages,\\n\",\n    \"    temperature=0,\\n\",\n    \"    top_p=0.95,\\n\",\n    \"    max_tokens=1024,\\n\",\n    \"    stream=False,  # Non-streaming\\n\",\n    \"    tools=tools,\\n\",\n    \")\\n\",\n    \"print_highlight(\\\"Non-stream response:\\\")\\n\",\n    \"print_highlight(response_non_stream)\\n\",\n    \"print_highlight(\\\"==== content ====\\\")\\n\",\n    \"print_highlight(response_non_stream.choices[0].message.content)\\n\",\n    \"print_highlight(\\\"==== tool_calls ====\\\")\\n\",\n    \"print_highlight(response_non_stream.choices[0].message.tool_calls)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Handle Tools\\n\",\n    \"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. 
You can parse these arguments and later invoke the tool accordingly.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\\n\",\n    \"arguments_non_stream = (\\n\",\n    \"    response_non_stream.choices[0].message.tool_calls[0].function.arguments\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(f\\\"Non-stream function call name: {name_non_stream}\\\")\\n\",\n    \"print_highlight(f\\\"Non-stream function call arguments: {arguments_non_stream}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Streaming Request\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Streaming mode test\\n\",\n    \"print_highlight(\\\"Streaming response:\\\")\\n\",\n    \"response_stream = client.chat.completions.create(\\n\",\n    \"    model=model_name,\\n\",\n    \"    messages=messages,\\n\",\n    \"    temperature=0,\\n\",\n    \"    top_p=0.95,\\n\",\n    \"    max_tokens=1024,\\n\",\n    \"    stream=True,  # Enable streaming\\n\",\n    \"    tools=tools,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"texts = \\\"\\\"\\n\",\n    \"tool_calls = []\\n\",\n    \"name = \\\"\\\"\\n\",\n    \"arguments = \\\"\\\"\\n\",\n    \"for chunk in response_stream:\\n\",\n    \"    if chunk.choices[0].delta.content:\\n\",\n    \"        texts += chunk.choices[0].delta.content\\n\",\n    \"    if chunk.choices[0].delta.tool_calls:\\n\",\n    \"        tool_calls.append(chunk.choices[0].delta.tool_calls[0])\\n\",\n    \"print_highlight(\\\"==== Text ====\\\")\\n\",\n    \"print_highlight(texts)\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"==== Tool Call ====\\\")\\n\",\n    \"for tool_call in tool_calls:\\n\",\n    \"    print_highlight(tool_call)\"\n   
]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Handle Tools\\n\",\n    \"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Parse and combine function call arguments\\n\",\n    \"arguments = []\\n\",\n    \"for tool_call in tool_calls:\\n\",\n    \"    if tool_call.function.name:\\n\",\n    \"        print_highlight(f\\\"Streamed function call name: {tool_call.function.name}\\\")\\n\",\n    \"\\n\",\n    \"    if tool_call.function.arguments:\\n\",\n    \"        arguments.append(tool_call.function.arguments)\\n\",\n    \"\\n\",\n    \"# Combine all fragments into a single JSON string\\n\",\n    \"full_arguments = \\\"\\\".join(arguments)\\n\",\n    \"print_highlight(f\\\"streamed function call arguments: {full_arguments}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Define a Tool Function\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# This is a demonstration, define real function according to your usage.\\n\",\n    \"def get_current_weather(city: str, state: str, unit: \\\"str\\\"):\\n\",\n    \"    return (\\n\",\n    \"        f\\\"The weather in {city}, {state} is 85 degrees {unit}. 
It is \\\"\\n\",\n    \"        \\\"partly cloudy, with highs in the 90's.\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"available_tools = {\\\"get_current_weather\\\": get_current_weather}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"\\n\",\n    \"### Execute the Tool\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"messages.append(response_non_stream.choices[0].message)\\n\",\n    \"\\n\",\n    \"# Call the corresponding tool function\\n\",\n    \"tool_call = messages[-1].tool_calls[0]\\n\",\n    \"tool_name = tool_call.function.name\\n\",\n    \"tool_to_call = available_tools[tool_name]\\n\",\n    \"result = tool_to_call(**(json.loads(tool_call.function.arguments)))\\n\",\n    \"print_highlight(f\\\"Function call result: {result}\\\")\\n\",\n    \"# messages.append({\\\"role\\\": \\\"tool\\\", \\\"content\\\": result, \\\"name\\\": tool_name})\\n\",\n    \"messages.append(\\n\",\n    \"    {\\n\",\n    \"        \\\"role\\\": \\\"tool\\\",\\n\",\n    \"        \\\"tool_call_id\\\": tool_call.id,\\n\",\n    \"        \\\"content\\\": str(result),\\n\",\n    \"        \\\"name\\\": tool_name,\\n\",\n    \"    }\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(f\\\"Updated message history: {messages}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Send Results Back to Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"final_response = client.chat.completions.create(\\n\",\n    \"    model=model_name,\\n\",\n    \"    messages=messages,\\n\",\n    \"    temperature=0,\\n\",\n    \"    top_p=0.95,\\n\",\n    \"    stream=False,\\n\",\n    \"    tools=tools,\\n\",\n    \")\\n\",\n    \"print_highlight(\\\"Non-stream 
response:\\\")\\n\",\n    \"print_highlight(final_response)\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"==== Text ====\\\")\\n\",\n    \"print_highlight(final_response.choices[0].message.content)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Native API and SGLang Runtime (SRT)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\\n\",\n    \"import requests\\n\",\n    \"\\n\",\n    \"# generate an answer\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"Qwen/Qwen2.5-7B-Instruct\\\")\\n\",\n    \"\\n\",\n    \"messages = get_messages()\\n\",\n    \"\\n\",\n    \"input = tokenizer.apply_chat_template(\\n\",\n    \"    messages, tokenize=False, add_generation_prompt=True, tools=tools, return_dict=False\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"gen_url = f\\\"http://localhost:{port}/generate\\\"\\n\",\n    \"gen_data = {\\n\",\n    \"    \\\"text\\\": input,\\n\",\n    \"    \\\"sampling_params\\\": {\\n\",\n    \"        \\\"skip_special_tokens\\\": False,\\n\",\n    \"        \\\"max_new_tokens\\\": 1024,\\n\",\n    \"        \\\"temperature\\\": 0,\\n\",\n    \"        \\\"top_p\\\": 0.95,\\n\",\n    \"    },\\n\",\n    \"}\\n\",\n    \"gen_response = requests.post(gen_url, json=gen_data).json()[\\\"text\\\"]\\n\",\n    \"print_highlight(\\\"==== Response ====\\\")\\n\",\n    \"print_highlight(gen_response)\\n\",\n    \"\\n\",\n    \"# parse the response\\n\",\n    \"parse_url = f\\\"http://localhost:{port}/parse_function_call\\\"\\n\",\n    \"\\n\",\n    \"function_call_input = {\\n\",\n    \"    \\\"text\\\": gen_response,\\n\",\n    \"    \\\"tool_call_parser\\\": \\\"qwen25\\\",\\n\",\n    \"    \\\"tools\\\": tools,\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"function_call_response = requests.post(parse_url, json=function_call_input)\\n\",\n    
\"function_call_response_json = function_call_response.json()\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"==== Text ====\\\")\\n\",\n    \"print(function_call_response_json[\\\"normal_text\\\"])\\n\",\n    \"print_highlight(\\\"==== Calls ====\\\")\\n\",\n    \"print(\\\"function name: \\\", function_call_response_json[\\\"calls\\\"][0][\\\"name\\\"])\\n\",\n    \"print(\\\"function arguments: \\\", function_call_response_json[\\\"calls\\\"][0][\\\"parameters\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Offline Engine API\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sglang as sgl\\n\",\n    \"from sglang.srt.function_call.function_call_parser import FunctionCallParser\\n\",\n    \"from sglang.srt.managers.io_struct import Tool, Function\\n\",\n    \"\\n\",\n    \"llm = sgl.Engine(model_path=\\\"Qwen/Qwen2.5-7B-Instruct\\\")\\n\",\n    \"tokenizer = llm.tokenizer_manager.tokenizer\\n\",\n    \"input_ids = tokenizer.apply_chat_template(\\n\",\n    \"    messages, tokenize=True, add_generation_prompt=True, tools=tools, return_dict=False\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Note: for the gpt-oss tool parser, add \\\"no_stop_trim\\\": True to sampling_params\\n\",\n    \"# so that the tool call token <call> is not trimmed.\\n\",\n    \"\\n\",\n    \"sampling_params = {\\n\",\n    \"    \\\"max_new_tokens\\\": 1024,\\n\",\n    \"    \\\"temperature\\\": 0,\\n\",\n    \"    \\\"top_p\\\": 0.95,\\n\",\n    \"    \\\"skip_special_tokens\\\": False,\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"# 1) Offline generation\\n\",\n    \"result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)\\n\",\n    \"generated_text 
= result[\\\"text\\\"]  # Assume there is only one prompt\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"=== Offline Engine Output Text ===\\\")\\n\",\n    \"print_highlight(generated_text)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# 2) Parse using FunctionCallParser\\n\",\n    \"def convert_dict_to_tool(tool_dict: dict) -> Tool:\\n\",\n    \"    function_dict = tool_dict.get(\\\"function\\\", {})\\n\",\n    \"    return Tool(\\n\",\n    \"        type=tool_dict.get(\\\"type\\\", \\\"function\\\"),\\n\",\n    \"        function=Function(\\n\",\n    \"            name=function_dict.get(\\\"name\\\"),\\n\",\n    \"            description=function_dict.get(\\\"description\\\"),\\n\",\n    \"            parameters=function_dict.get(\\\"parameters\\\"),\\n\",\n    \"        ),\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\\n\",\n    \"\\n\",\n    \"parser = FunctionCallParser(tools=tools, tool_call_parser=\\\"qwen25\\\")\\n\",\n    \"normal_text, calls = parser.parse_non_stream(generated_text)\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"=== Parsing Result ===\\\")\\n\",\n    \"print(\\\"Normal text portion:\\\", normal_text)\\n\",\n    \"print_highlight(\\\"Function call portion:\\\")\\n\",\n    \"for call in calls:\\n\",\n    \"    # call: ToolCallItem\\n\",\n    \"    print_highlight(f\\\"  - tool name: {call.name}\\\")\\n\",\n    \"    print_highlight(f\\\"    parameters: {call.parameters}\\\")\\n\",\n    \"\\n\",\n    \"# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"llm.shutdown()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Tool Choice Mode\\n\",\n    \"\\n\",\n    \"SGLang supports 
OpenAI's `tool_choice` parameter to control when and which tools the model should call. This feature is implemented using EBNF (Extended Backus-Naur Form) grammar to ensure reliable tool calling behavior.\\n\",\n    \"\\n\",\n    \"### Supported Tool Choice Options\\n\",\n    \"\\n\",\n    \"- **`tool_choice=\\\"required\\\"`**: Forces the model to call at least one tool\\n\",\n    \"- **`tool_choice={\\\"type\\\": \\\"function\\\", \\\"function\\\": {\\\"name\\\": \\\"specific_function\\\"}}`**: Forces the model to call a specific function\\n\",\n    \"\\n\",\n    \"### Backend Compatibility\\n\",\n    \"\\n\",\n    \"Tool choice is fully supported with the **Xgrammar backend**, which is the default grammar backend (`--grammar-backend xgrammar`). However, it may not be fully supported with other backends such as `outlines`.\\n\",\n    \"\\n\",\n    \"### Example: Required Tool Choice\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from openai import OpenAI\\n\",\n    \"from sglang.utils import wait_for_server, print_highlight, terminate_process\\n\",\n    \"from sglang.test.doc_patch import launch_server_cmd\\n\",\n    \"\\n\",\n    \"# Start a new server session for tool choice examples\\n\",\n    \"server_process_tool_choice, port_tool_choice = launch_server_cmd(\\n\",\n    \"    \\\"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0  --log-level warning\\\"\\n\",\n    \")\\n\",\n    \"wait_for_server(\\n\",\n    \"    f\\\"http://localhost:{port_tool_choice}\\\", process=server_process_tool_choice\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Initialize client for tool choice examples\\n\",\n    \"client_tool_choice = OpenAI(\\n\",\n    \"    api_key=\\\"None\\\", base_url=f\\\"http://0.0.0.0:{port_tool_choice}/v1\\\"\\n\",\n    \")\\n\",\n    \"model_name_tool_choice = 
client_tool_choice.models.list().data[0].id\\n\",\n    \"\\n\",\n    \"# Example with tool_choice=\\\"required\\\" - forces the model to call a tool\\n\",\n    \"messages_required = [\\n\",\n    \"    {\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"Hello, what is the capital of France?\\\"}\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"# Define tools\\n\",\n    \"tools = [\\n\",\n    \"    {\\n\",\n    \"        \\\"type\\\": \\\"function\\\",\\n\",\n    \"        \\\"function\\\": {\\n\",\n    \"            \\\"name\\\": \\\"get_current_weather\\\",\\n\",\n    \"            \\\"description\\\": \\\"Get the current weather in a given location\\\",\\n\",\n    \"            \\\"parameters\\\": {\\n\",\n    \"                \\\"type\\\": \\\"object\\\",\\n\",\n    \"                \\\"properties\\\": {\\n\",\n    \"                    \\\"city\\\": {\\n\",\n    \"                        \\\"type\\\": \\\"string\\\",\\n\",\n    \"                        \\\"description\\\": \\\"The city to find the weather for, e.g. 
'San Francisco'\\\",\\n\",\n    \"                    },\\n\",\n    \"                    \\\"unit\\\": {\\n\",\n    \"                        \\\"type\\\": \\\"string\\\",\\n\",\n    \"                        \\\"description\\\": \\\"The unit to fetch the temperature in\\\",\\n\",\n    \"                        \\\"enum\\\": [\\\"celsius\\\", \\\"fahrenheit\\\"],\\n\",\n    \"                    },\\n\",\n    \"                },\\n\",\n    \"                \\\"required\\\": [\\\"city\\\", \\\"unit\\\"],\\n\",\n    \"            },\\n\",\n    \"        },\\n\",\n    \"    }\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"response_required = client_tool_choice.chat.completions.create(\\n\",\n    \"    model=model_name_tool_choice,\\n\",\n    \"    messages=messages_required,\\n\",\n    \"    temperature=0,\\n\",\n    \"    max_tokens=1024,\\n\",\n    \"    tools=tools,\\n\",\n    \"    tool_choice=\\\"required\\\",  # Force the model to call a tool\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"Response with tool_choice='required':\\\")\\n\",\n    \"print(\\\"Content:\\\", response_required.choices[0].message.content)\\n\",\n    \"print(\\\"Tool calls:\\\", response_required.choices[0].message.tool_calls)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Example: Specific Function Choice\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Example with specific function choice - forces the model to call a specific function\\n\",\n    \"messages_specific = [\\n\",\n    \"    {\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"What are the most attractive places in France?\\\"}\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"response_specific = client_tool_choice.chat.completions.create(\\n\",\n    \"    model=model_name_tool_choice,\\n\",\n    \"    messages=messages_specific,\\n\",\n    \"    temperature=0,\\n\",\n 
   \"    max_tokens=1024,\\n\",\n    \"    tools=tools,\\n\",\n    \"    tool_choice={\\n\",\n    \"        \\\"type\\\": \\\"function\\\",\\n\",\n    \"        \\\"function\\\": {\\\"name\\\": \\\"get_current_weather\\\"},\\n\",\n    \"    },  # Force the model to call the specific get_current_weather function\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"Response with specific function choice:\\\")\\n\",\n    \"print(\\\"Content:\\\", response_specific.choices[0].message.content)\\n\",\n    \"print(\\\"Tool calls:\\\", response_specific.choices[0].message.tool_calls)\\n\",\n    \"\\n\",\n    \"if response_specific.choices[0].message.tool_calls:\\n\",\n    \"    tool_call = response_specific.choices[0].message.tool_calls[0]\\n\",\n    \"    print_highlight(f\\\"Called function: {tool_call.function.name}\\\")\\n\",\n    \"    print_highlight(f\\\"Arguments: {tool_call.function.arguments}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process_tool_choice)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Pythonic Tool Call Format (Llama-3.2 / Llama-3.3 / Llama-4)\\n\",\n    \"\\n\",\n    \"Some Llama models (such as Llama-3.2-1B, Llama-3.2-3B, Llama-3.3-70B, and Llama-4) support a \\\"pythonic\\\" tool call format, where the model outputs function calls as Python code, e.g.:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"[get_current_weather(city=\\\"San Francisco\\\", state=\\\"CA\\\", unit=\\\"celsius\\\")]\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"- The output is a Python list of function calls, with arguments as Python literals (not JSON).\\n\",\n    \"- Multiple tool calls can be returned in the same list:\\n\",\n    \"```python\\n\",\n    \"[get_current_weather(city=\\\"San Francisco\\\", state=\\\"CA\\\", unit=\\\"celsius\\\"),\\n\",\n    \" 
get_current_weather(city=\\\"New York\\\", state=\\\"NY\\\", unit=\\\"fahrenheit\\\")]\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"For more information, refer to Meta’s documentation on [Zero-shot function calling](https://github.com/meta-llama/llama-models/blob/main/models/llama4/prompt_format.md#zero-shot-function-calling---system-message).\\n\",\n    \"\\n\",\n    \"Note that this feature is still under development on Blackwell.\\n\",\n    \"\\n\",\n    \"### How to enable\\n\",\n    \"- Launch the server with `--tool-call-parser pythonic`\\n\",\n    \"- You may also specify `--chat-template` with the improved template for the model (e.g., `--chat-template=examples/chat_template/tool_chat_template_llama4_pythonic.jinja`).\\n\",\n    \"This is recommended because the model expects a special prompt format to reliably produce valid pythonic tool call outputs. The template ensures that the prompt structure (e.g., special tokens, message boundaries like `<|eom|>`, and function call delimiters) matches what the model was trained or fine-tuned on. If you do not use the correct chat template, tool calling may fail or produce inconsistent results.\\n\",\n    \"\\n\",\n    \"#### Forcing Pythonic Tool Call Output Without a Chat Template\\n\",\n    \"If you don't want to specify a chat template, you must give the model extremely explicit instructions in your messages to enforce pythonic output. 
For example, for `Llama-3.2-1B-Instruct`, you need:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import openai\\n\",\n    \"\\n\",\n    \"server_process, port = launch_server_cmd(\\n\",\n    \"    \\\" python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --tool-call-parser pythonic --tp 1  --log-level warning\\\"  # llama-3.2-1b-instruct\\n\",\n    \")\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=server_process)\\n\",\n    \"\\n\",\n    \"tools = [\\n\",\n    \"    {\\n\",\n    \"        \\\"type\\\": \\\"function\\\",\\n\",\n    \"        \\\"function\\\": {\\n\",\n    \"            \\\"name\\\": \\\"get_weather\\\",\\n\",\n    \"            \\\"description\\\": \\\"Get the current weather for a given location.\\\",\\n\",\n    \"            \\\"parameters\\\": {\\n\",\n    \"                \\\"type\\\": \\\"object\\\",\\n\",\n    \"                \\\"properties\\\": {\\n\",\n    \"                    \\\"location\\\": {\\n\",\n    \"                        \\\"type\\\": \\\"string\\\",\\n\",\n    \"                        \\\"description\\\": \\\"The name of the city or location.\\\",\\n\",\n    \"                    }\\n\",\n    \"                },\\n\",\n    \"                \\\"required\\\": [\\\"location\\\"],\\n\",\n    \"            },\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"        \\\"type\\\": \\\"function\\\",\\n\",\n    \"        \\\"function\\\": {\\n\",\n    \"            \\\"name\\\": \\\"get_tourist_attractions\\\",\\n\",\n    \"            \\\"description\\\": \\\"Get a list of top tourist attractions for a given city.\\\",\\n\",\n    \"            \\\"parameters\\\": {\\n\",\n    \"                \\\"type\\\": \\\"object\\\",\\n\",\n    \"                \\\"properties\\\": {\\n\",\n    \"                    \\\"city\\\": {\\n\",\n    \"                     
   \\\"type\\\": \\\"string\\\",\\n\",\n    \"                        \\\"description\\\": \\\"The name of the city to find attractions for.\\\",\\n\",\n    \"                    }\\n\",\n    \"                },\\n\",\n    \"                \\\"required\\\": [\\\"city\\\"],\\n\",\n    \"            },\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def get_messages():\\n\",\n    \"    return [\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"system\\\",\\n\",\n    \"            \\\"content\\\": (\\n\",\n    \"                \\\"You are a travel assistant. \\\"\\n\",\n    \"                \\\"When asked to call functions, ALWAYS respond ONLY with a python list of function calls, \\\"\\n\",\n    \"                \\\"using this format: [func_name1(param1=value1, param2=value2), func_name2(param=value)]. \\\"\\n\",\n    \"                \\\"Do NOT use JSON, do NOT use variables, do NOT use any other format. \\\"\\n\",\n    \"                \\\"Here is an example:\\\\n\\\"\\n\",\n    \"                '[get_weather(location=\\\"Paris\\\"), get_tourist_attractions(city=\\\"Paris\\\")]'\\n\",\n    \"            ),\\n\",\n    \"        },\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"user\\\",\\n\",\n    \"            \\\"content\\\": (\\n\",\n    \"                \\\"I'm planning a trip to Tokyo next week. What's the weather like and what are some top tourist attractions? 
\\\"\\n\",\n    \"                \\\"Propose parallel tool calls at once, using the python list of function calls format as shown above.\\\"\\n\",\n    \"            ),\\n\",\n    \"        },\\n\",\n    \"    ]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"messages = get_messages()\\n\",\n    \"\\n\",\n    \"client = openai.Client(base_url=f\\\"http://localhost:{port}/v1\\\", api_key=\\\"xxxxxx\\\")\\n\",\n    \"model_name = client.models.list().data[0].id\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"response_non_stream = client.chat.completions.create(\\n\",\n    \"    model=model_name,\\n\",\n    \"    messages=messages,\\n\",\n    \"    temperature=0,\\n\",\n    \"    top_p=0.9,\\n\",\n    \"    stream=False,  # Non-streaming\\n\",\n    \"    tools=tools,\\n\",\n    \")\\n\",\n    \"print_highlight(\\\"Non-stream response:\\\")\\n\",\n    \"print_highlight(response_non_stream)\\n\",\n    \"\\n\",\n    \"response_stream = client.chat.completions.create(\\n\",\n    \"    model=model_name,\\n\",\n    \"    messages=messages,\\n\",\n    \"    temperature=0,\\n\",\n    \"    top_p=0.9,\\n\",\n    \"    stream=True,\\n\",\n    \"    tools=tools,\\n\",\n    \")\\n\",\n    \"texts = \\\"\\\"\\n\",\n    \"tool_calls = []\\n\",\n    \"name = \\\"\\\"\\n\",\n    \"arguments = \\\"\\\"\\n\",\n    \"\\n\",\n    \"for chunk in response_stream:\\n\",\n    \"    if chunk.choices[0].delta.content:\\n\",\n    \"        texts += chunk.choices[0].delta.content\\n\",\n    \"    if chunk.choices[0].delta.tool_calls:\\n\",\n    \"        tool_calls.append(chunk.choices[0].delta.tool_calls[0])\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"Streaming Response:\\\")\\n\",\n    \"print_highlight(\\\"==== Text ====\\\")\\n\",\n    \"print_highlight(texts)\\n\",\n    \"\\n\",\n    \"print_highlight(\\\"==== Tool Call ====\\\")\\n\",\n    \"for tool_call in tool_calls:\\n\",\n    \"    print_highlight(tool_call)\\n\",\n    \"\\n\",\n    \"terminate_process(server_process)\"\n   ]\n  },\n  {\n   
\"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"> **Note:**  \\n\",\n    \"> The model may still default to JSON if it was heavily fine-tuned on that format. Prompt engineering (including examples) is the only way to increase the chance of pythonic output if you are not using a chat template.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## How to support a new model?\\n\",\n    \"1. Update `TOOLS_TAG_LIST` in `sglang/srt/function_call_parser.py` with the model’s tool tags. Currently supported tags include:\\n\",\n    \"```\\n\",\n    \"\\tTOOLS_TAG_LIST = [\\n\",\n    \"\\t    \\\"<|plugin|>\\\",\\n\",\n    \"\\t    \\\"<function=\\\",\\n\",\n    \"\\t    \\\"<tool_call>\\\",\\n\",\n    \"\\t    \\\"<|python_tag|>\\\",\\n\",\n    \"\\t    \\\"[TOOL_CALLS]\\\"\\n\",\n    \"\\t]\\n\",\n    \"```\\n\",\n    \"2. Create a new detector class in `sglang/srt/function_call_parser.py` that inherits from `BaseFormatDetector`. The detector should handle the model’s specific function call format. For example:\\n\",\n    \"```\\n\",\n    \"    class NewModelDetector(BaseFormatDetector):\\n\",\n    \"```\\n\",\n    \"3. Add the new detector to the `MultiFormatParser` class that manages all the format detectors.\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "docs/advanced_features/vlm_query.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Query VLM with Offline Engine\\n\",\n    \"\\n\",\n    \"This tutorial demonstrates how to use SGLang's **offline Engine API** to query VLMs. We will demonstrate usage with Qwen2.5-VL and Llama 4. This section demonstrates three different calling approaches:\\n\",\n    \"\\n\",\n    \"1. **Basic Call**: Directly pass images and text.\\n\",\n    \"2. **Processor Output**: Use HuggingFace processor for data preprocessing.\\n\",\n    \"3. **Precomputed Embeddings**: Pre-calculate image features to improve inference efficiency.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Understanding the Three Input Formats\\n\",\n    \"\\n\",\n    \"SGLang supports three ways to pass visual data, each optimized for different scenarios:\\n\",\n    \"\\n\",\n    \"### 1. **Raw Images** - Simplest approach\\n\",\n    \"- Pass PIL Images, file paths, URLs, or base64 strings directly\\n\",\n    \"- SGLang handles all preprocessing automatically\\n\",\n    \"- Best for: Quick prototyping, simple applications\\n\",\n    \"\\n\",\n    \"### 2. **Processor Output** - For custom preprocessing\\n\",\n    \"- Pre-process images with HuggingFace processor\\n\",\n    \"- Pass the complete processor output dict with `format: \\\"processor_output\\\"`\\n\",\n    \"- Best for: Custom image transformations, integration with existing pipelines\\n\",\n    \"- Requirement: Must use `input_ids` instead of text prompt\\n\",\n    \"\\n\",\n    \"### 3. 
**Precomputed Embeddings** - For maximum performance\\n\",\n    \"- Pre-calculate visual embeddings using the vision encoder\\n\",\n    \"- Pass embeddings with `format: \\\"precomputed_embedding\\\"`\\n\",\n    \"- Best for: Repeated queries on same images, caching, high-throughput serving\\n\",\n    \"- Performance gain: Avoids redundant vision encoder computation (30-50% speedup)\\n\",\n    \"\\n\",\n    \"**Key Rule**: Within a single request, use only one format for all images. Don't mix formats.\\n\",\n    \"\\n\",\n    \"The examples below demonstrate all three approaches with both Qwen2.5-VL and Llama 4 models.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Querying Qwen2.5-VL Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"3\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import nest_asyncio\\n\",\n    \"\\n\",\n    \"nest_asyncio.apply()\\n\",\n    \"\\n\",\n    \"import sglang.test.doc_patch  # noqa: F401\\n\",\n    \"\\n\",\n    \"model_path = \\\"Qwen/Qwen2.5-VL-3B-Instruct\\\"\\n\",\n    \"chat_template = \\\"qwen2-vl\\\"\\n\",\n    \"example_image_url = \\\"https://raw.githubusercontent.com/sgl-project/sglang/main/examples/assets/example_image.png\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from io import BytesIO\\n\",\n    \"import requests\\n\",\n    \"from PIL import Image\\n\",\n    \"\\n\",\n    \"from sglang.srt.parser.conversation import chat_templates\\n\",\n    \"\\n\",\n    \"image = Image.open(BytesIO(requests.get(example_image_url).content))\\n\",\n    \"\\n\",\n    \"conv = chat_templates[chat_template].copy()\\n\",\n    \"conv.append_message(conv.roles[0], f\\\"What's shown here: {conv.image_token}?\\\")\\n\",\n    \"conv.append_message(conv.roles[1], 
\\\"\\\")\\n\",\n    \"conv.image_data = [image]\\n\",\n    \"\\n\",\n    \"print(\\\"Generated prompt text:\\\")\\n\",\n    \"print(conv.get_prompt())\\n\",\n    \"print(f\\\"\\\\nImage size: {image.size}\\\")\\n\",\n    \"image\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Basic Offline Engine API Call\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from sglang import Engine\\n\",\n    \"\\n\",\n    \"llm = Engine(model_path=model_path, chat_template=chat_template, log_level=\\\"warning\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"7\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\\n\",\n    \"print(\\\"Model response:\\\")\\n\",\n    \"print(out[\\\"text\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"8\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Call with Processor Output\\n\",\n    \"\\n\",\n    \"Use a HuggingFace processor to preprocess text and images, then pass the `processor_output` directly into `Engine.generate`.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoProcessor\\n\",\n    \"\\n\",\n    \"processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\\n\",\n    \"processor_output = processor(\\n\",\n    \"    images=[image], text=conv.get_prompt(), return_tensors=\\\"pt\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"out = llm.generate(\\n\",\n    \"    input_ids=processor_output[\\\"input_ids\\\"][0].detach().cpu().tolist(),\\n\",\n    \"    image_data=[dict(processor_output, format=\\\"processor_output\\\")],\\n\",\n    
\")\\n\",\n    \"print(\\\"Response using processor output:\\\")\\n\",\n    \"print(out[\\\"text\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"10\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Call with Precomputed Embeddings\\n\",\n    \"\\n\",\n    \"You can pre-calculate image features to avoid repeated visual encoding processes.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"11\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoProcessor\\n\",\n    \"from transformers import Qwen2_5_VLForConditionalGeneration\\n\",\n    \"\\n\",\n    \"processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\\n\",\n    \"model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval()\\n\",\n    \"vision = model.model.visual.cuda()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"12\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"processor_output = processor(\\n\",\n    \"    images=[image], text=conv.get_prompt(), return_tensors=\\\"pt\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"input_ids = processor_output[\\\"input_ids\\\"][0].detach().cpu().tolist()\\n\",\n    \"\\n\",\n    \"precomputed_embeddings = vision(\\n\",\n    \"    processor_output[\\\"pixel_values\\\"].cuda(), processor_output[\\\"image_grid_thw\\\"].cuda()\\n\",\n    \")\\n\",\n    \"precomputed_embeddings = precomputed_embeddings.pooler_output\\n\",\n    \"\\n\",\n    \"multi_modal_item = dict(\\n\",\n    \"    processor_output,\\n\",\n    \"    format=\\\"precomputed_embedding\\\",\\n\",\n    \"    feature=precomputed_embeddings,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"out = llm.generate(input_ids=input_ids, image_data=[multi_modal_item])\\n\",\n    \"print(\\\"Response using precomputed embeddings:\\\")\\n\",\n    \"print(out[\\\"text\\\"])\\n\",\n    \"\\n\",\n    \"llm.shutdown()\"\n   
]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"13\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Querying Llama 4 Vision Model\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"model_path = \\\"meta-llama/Llama-4-Scout-17B-16E-Instruct\\\"\\n\",\n    \"chat_template = \\\"llama-4\\\"\\n\",\n    \"\\n\",\n    \"from io import BytesIO\\n\",\n    \"import requests\\n\",\n    \"from PIL import Image\\n\",\n    \"\\n\",\n    \"from sglang.srt.parser.conversation import chat_templates\\n\",\n    \"\\n\",\n    \"# Download the same example image\\n\",\n    \"image = Image.open(BytesIO(requests.get(example_image_url).content))\\n\",\n    \"\\n\",\n    \"conv = chat_templates[chat_template].copy()\\n\",\n    \"conv.append_message(conv.roles[0], f\\\"What's shown here: {conv.image_token}?\\\")\\n\",\n    \"conv.append_message(conv.roles[1], \\\"\\\")\\n\",\n    \"conv.image_data = [image]\\n\",\n    \"\\n\",\n    \"print(\\\"Llama 4 generated prompt text:\\\")\\n\",\n    \"print(conv.get_prompt())\\n\",\n    \"print(f\\\"Image size: {image.size}\\\")\\n\",\n    \"\\n\",\n    \"image\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"14\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Llama 4 Basic Call\\n\",\n    \"\\n\",\n    \"Llama 4 requires more computational resources, so it is configured with multi-GPU tensor parallelism (`tp_size=4`) and a larger context length.\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"llm = Engine(\\n\",\n    \"    model_path=model_path,\\n\",\n    \"    enable_multimodal=True,\\n\",\n    \"    attention_backend=\\\"fa3\\\",\\n\",\n    \"    tp_size=4,\\n\",\n    \"    context_length=65536,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\\n\",\n    \"print(\\\"Llama 4 response:\\\")\\n\",\n    \"print(out[\\\"text\\\"])\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"15\",\n   \"metadata\": {},\n   
\"source\": [\n    \"### Call with Processor Output\\n\",\n    \"\\n\",\n    \"Preprocessing the data with a HuggingFace processor can reduce computational overhead during inference.\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"from transformers import AutoProcessor\\n\",\n    \"\\n\",\n    \"processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\\n\",\n    \"processor_output = processor(\\n\",\n    \"    images=[image], text=conv.get_prompt(), return_tensors=\\\"pt\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"out = llm.generate(\\n\",\n    \"    input_ids=processor_output[\\\"input_ids\\\"][0].detach().cpu().tolist(),\\n\",\n    \"    image_data=[dict(processor_output, format=\\\"processor_output\\\")],\\n\",\n    \")\\n\",\n    \"print(\\\"Response using processor output:\\\")\\n\",\n    \"print(out[\\\"text\\\"])\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"16\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Call with Precomputed Embeddings\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"from transformers import AutoProcessor\\n\",\n    \"from transformers import Llama4ForConditionalGeneration\\n\",\n    \"\\n\",\n    \"processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\\n\",\n    \"model = Llama4ForConditionalGeneration.from_pretrained(\\n\",\n    \"    model_path, torch_dtype=\\\"auto\\\"\\n\",\n    \").eval()\\n\",\n    \"\\n\",\n    \"vision = model.vision_model.cuda()\\n\",\n    \"multi_modal_projector = model.multi_modal_projector.cuda()\\n\",\n    \"\\n\",\n    \"print(f'Image pixel values shape: {processor_output[\\\"pixel_values\\\"].shape}')\\n\",\n    \"input_ids = processor_output[\\\"input_ids\\\"][0].detach().cpu().tolist()\\n\",\n    \"\\n\",\n    \"# Process image through vision encoder\\n\",\n    \"image_outputs = vision(\\n\",\n    \"    processor_output[\\\"pixel_values\\\"].to(\\\"cuda\\\"),\\n\",\n    \"    
aspect_ratio_ids=processor_output[\\\"aspect_ratio_ids\\\"].to(\\\"cuda\\\"),\\n\",\n    \"    aspect_ratio_mask=processor_output[\\\"aspect_ratio_mask\\\"].to(\\\"cuda\\\"),\\n\",\n    \"    output_hidden_states=False\\n\",\n    \")\\n\",\n    \"image_features = image_outputs.last_hidden_state\\n\",\n    \"\\n\",\n    \"# Flatten image features and pass through multimodal projector\\n\",\n    \"vision_flat = image_features.view(-1, image_features.size(-1))\\n\",\n    \"precomputed_embeddings = multi_modal_projector(vision_flat)\\n\",\n    \"\\n\",\n    \"# Build precomputed embedding data item\\n\",\n    \"mm_item = dict(\\n\",\n    \"    processor_output, \\n\",\n    \"    format=\\\"precomputed_embedding\\\", \\n\",\n    \"    feature=precomputed_embeddings\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Use precomputed embeddings for efficient inference\\n\",\n    \"out = llm.generate(input_ids=input_ids, image_data=[mm_item])\\n\",\n    \"print(\\\"Llama 4 precomputed embedding response:\\\")\\n\",\n    \"print(out[\\\"text\\\"])\\n\",\n    \"```\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"jupytext\": {\n   \"cell_metadata_filter\": \"-all\",\n   \"custom_cell_magics\": \"kql\",\n   \"encoding\": \"# -*- coding: utf-8 -*-\",\n   \"text_representation\": {\n    \"extension\": \".py\",\n    \"format_name\": \"light\",\n    \"format_version\": \"1.5\",\n    \"jupytext_version\": \"1.16.1\"\n   }\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "docs/basic_usage/deepseek_ocr.md",
    "content": "# DeepSeek OCR (OCR-1 / OCR-2)\n\nDeepSeek OCR models are multimodal (image + text) models for OCR and document understanding.\n\n## Launch server\n\n```shell\npython -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-OCR-2 \\\n  --trust-remote-code \\\n  --host 0.0.0.0 \\\n  --port 30000\n```\n\n> You can replace `deepseek-ai/DeepSeek-OCR-2` with `deepseek-ai/DeepSeek-OCR`.\n\n## Prompt examples\n\nRecommended prompts from the model card:\n\n```\n<image>\n<|grounding|>Convert the document to markdown.\n```\n\n```\n<image>\nFree OCR.\n```\n\n## OpenAI-compatible request example\n\n```python\nimport requests\n\nurl = \"http://localhost:30000/v1/chat/completions\"\n\ndata = {\n    \"model\": \"deepseek-ai/DeepSeek-OCR-2\",\n    \"messages\": [\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\"type\": \"text\", \"text\": \"<image>\\n<|grounding|>Convert the document to markdown.\"},\n                {\"type\": \"image_url\", \"image_url\": {\"url\": \"https://example.com/your_image.jpg\"}},\n            ],\n        }\n    ],\n    \"max_tokens\": 512,\n}\n\nresponse = requests.post(url, json=data)\nprint(response.text)\n```\n"
  },
  {
    "path": "docs/basic_usage/deepseek_v3.md",
    "content": "# DeepSeek V3/V3.1/R1 Usage\n\nSGLang provides many optimizations specifically designed for the DeepSeek models, making it the inference engine recommended by the official [DeepSeek team](https://github.com/deepseek-ai/DeepSeek-V3/tree/main?tab=readme-ov-file#62-inference-with-sglang-recommended) from Day 0.\n\nThis document outlines current optimizations for DeepSeek.\nFor an overview of the implemented features see the completed [Roadmap](https://github.com/sgl-project/sglang/issues/2591).\n\n## Launch DeepSeek V3.1/V3/R1 with SGLang\n\nTo run DeepSeek V3.1/V3/R1 models, the recommended settings are as follows:\n\n| Weight Type | Configuration |\n|------------|-------------------|\n| **Full precision [FP8](https://huggingface.co/deepseek-ai/DeepSeek-R1-0528)**<br>*(recommended)* | 8 x H200 |\n| | 8 x B200 |\n| | 8 x MI300X |\n| | 2 x 8 x H100/800/20 |\n| | Xeon 6980P CPU |\n| **Full precision ([BF16](https://huggingface.co/unsloth/DeepSeek-R1-0528-BF16))** (upcast from original FP8) | 2 x 8 x H200 |\n| | 2 x 8 x MI300X |\n| | 4 x 8 x H100/800/20 |\n| | 4 x 8 x A100/A800 |\n| **Quantized weights ([INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8))** | 16 x A100/800 |\n| | 32 x L40S |\n| | Xeon 6980P CPU |\n| | 4 x Atlas 800I A3 |\n| **Quantized weights ([W4A8](https://huggingface.co/novita/Deepseek-R1-0528-W4AFP8))** | 8 x H20/100, 4 x H200 |\n| **Quantized weights ([AWQ](https://huggingface.co/QuixiAI/DeepSeek-R1-0528-AWQ))** | 8 x H100/800/20 |\n| | 8 x A100/A800 |\n| **Quantized weights ([MXFP4](https://huggingface.co/amd/DeepSeek-R1-MXFP4-Preview))** | 8, 4 x MI355X/350X |\n| **Quantized weights ([NVFP4](https://huggingface.co/nvidia/DeepSeek-R1-0528-NVFP4-v2))** | 8, 4 x B200 |\n\n<style>\n.md-typeset__table {\n  width: 100%;\n}\n\n.md-typeset__table table {\n  border-collapse: collapse;\n  margin: 1em 0;\n  border: 2px solid var(--md-typeset-table-color);\n  table-layout: fixed;\n}\n\n.md-typeset__table th {\n  border: 1px solid 
var(--md-typeset-table-color);\n  border-bottom: 2px solid var(--md-typeset-table-color);\n  background-color: var(--md-default-bg-color--lighter);\n  padding: 12px;\n}\n\n.md-typeset__table td {\n  border: 1px solid var(--md-typeset-table-color);\n  padding: 12px;\n}\n\n.md-typeset__table tr:nth-child(2n) {\n  background-color: var(--md-default-bg-color--lightest);\n}\n</style>\n\n```{important}\nThe official DeepSeek V3 is already in FP8 format, so you should not run it with any quantization arguments like `--quantization fp8`.\n```\n\nDetailed commands for reference:\n\n- [8 x H200](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#using-docker-recommended)\n- [4 x B200, 8 x B200](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-one-b200-node)\n- [8 x MI300X](../platforms/amd_gpu.md#running-deepseek-v3)\n- [2 x 8 x H200](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-two-h2008-nodes-and-docker)\n- [4 x 8 x A100](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-four-a1008-nodes)\n- [8 x A100 (AWQ)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-8-a100a800-with-awq-quantization)\n- [16 x A100 (INT8)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-16-a100a800-with-int8-quantization)\n- [32 x L40S (INT8)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-32-l40s-with-int8-quantization)\n- [Xeon 6980P CPU](../platforms/cpu_server.md#example-running-deepseek-r1)\n- [4 x Atlas 800I A3 (int8)](../platforms/ascend_npu_deepseek_example.md#running-deepseek-with-pd-disaggregation-on-4-x-atlas-800i-a3)\n\n### Download Weights\nIf you encounter errors when starting the server, ensure the weights have finished downloading. 
It's recommended to download them beforehand, or to restart the server repeatedly until all weights are downloaded. Please refer to the official [DeepSeek V3](https://huggingface.co/deepseek-ai/DeepSeek-V3-Base#61-inference-with-deepseek-infer-demo-example-only) guide to download the weights.\n\n### Launch with one node of 8 x H200\nPlease refer to [the example](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#installation--launch).\n\n### Running examples on Multi-Node\n\n- [Deploying DeepSeek on GB200 NVL72 with PD and Large Scale EP](https://lmsys.org/blog/2025-06-16-gb200-part-1/) ([Part I](https://lmsys.org/blog/2025-06-16-gb200-part-1/), [Part II](https://lmsys.org/blog/2025-09-25-gb200-part-2/)) - Comprehensive guide on GB200 optimizations.\n\n- [Deploying DeepSeek with PD Disaggregation and Large-Scale Expert Parallelism on 96 H100 GPUs](https://lmsys.org/blog/2025-05-05-deepseek-pd-ep/) - Guide on PD disaggregation and large-scale EP.\n\n- [Serving with two H20*8 nodes](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-two-h208-nodes).\n\n- [Best Practices for Serving DeepSeek-R1 on H20](https://lmsys.org/blog/2025-09-26-sglang-ant-group/) - Comprehensive guide on H20 optimizations, deployment and performance.\n\n- [Serving with two H200*8 nodes and docker](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-two-h2008-nodes-and-docker).\n\n- [Serving with four A100*8 nodes](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-four-a1008-nodes).\n\n## Optimizations\n\n### Multi-head Latent Attention (MLA) Throughput Optimizations\n\n**Description**: [MLA](https://arxiv.org/pdf/2405.04434) is an innovative attention mechanism introduced by the DeepSeek team, aimed at improving inference efficiency. 
SGLang has implemented specific optimizations for this, including:\n\n- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.\n\n- **MLA Attention Backends**: SGLang currently supports several optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [FlashInfer](https://docs.flashinfer.ai/api/attention.html#flashinfer-mla), [FlashMLA](https://github.com/deepseek-ai/FlashMLA), [CutlassMLA](https://github.com/sgl-project/sglang/pull/5390), **TRTLLM MLA** (optimized for the Blackwell architecture), and [Triton](https://github.com/triton-lang/triton) backends. The default FA3 backend provides good performance across a wide range of workloads.\n\n- **FP8 Quantization**: W8A8 FP8 and KV cache FP8 quantization enable efficient FP8 inference. Additionally, we have implemented a Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.\n\n- **CUDA Graph & Torch.compile**: Both MLA and Mixture of Experts (MoE) are compatible with CUDA Graph and Torch.compile, which reduces latency and speeds up decoding for small batch sizes.\n\n- **Chunked Prefix Cache**: Chunked prefix cache optimization can increase throughput by splitting the prefix cache into chunks, processing them with multi-head attention, and merging their states. The improvement can be significant when doing chunked prefill on long sequences. 
Currently this optimization is only available for the FlashAttention3 backend.\n\nOverall, with these optimizations, we have achieved up to **7x** acceleration in output throughput compared to the previous version.\n\n<p align=\"center\">\n  <img src=\"https://lmsys.org/images/blog/sglang_v0_3/deepseek_mla.svg\" alt=\"Multi-head Latent Attention for DeepSeek Series Models\">\n</p>\n\n**Usage**: MLA optimization is enabled by default.\n\n**Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details.\n\n### Data Parallelism Attention\n\n**Description**: This optimization involves data parallelism (DP) for the MLA attention mechanism of DeepSeek Series Models, which allows for a significant reduction in the KV cache size, enabling larger batch sizes. Each DP worker independently handles different types of batches (prefill, decode, idle), which are then synchronized before and after processing through the Mixture-of-Experts (MoE) layer. If you do not use DP attention, the KV cache will be duplicated across all TP ranks.\n\n<p align=\"center\">\n  <img src=\"https://lmsys.org/images/blog/sglang_v0_4/dp_attention.svg\" alt=\"Data Parallelism Attention for DeepSeek Series Models\">\n</p>\n\nWith data parallelism attention enabled, we have achieved up to **1.9x** decoding throughput improvement compared to the previous version.\n\n<p align=\"center\">\n  <img src=\"https://lmsys.org/images/blog/sglang_v0_4/deepseek_coder_v2.svg\" alt=\"Data Parallelism Attention Performance Comparison\">\n</p>\n\n**Usage**:\n- Append `--enable-dp-attention --tp 8 --dp 8` to the server arguments when using 8 H200 GPUs. This optimization improves peak throughput in high batch size scenarios where the server is limited by KV cache capacity.\n- DP and TP attention can be flexibly combined. 
For example, to deploy DeepSeek-V3/R1 on 2 nodes with 8 H100 GPUs each, you can specify `--enable-dp-attention --tp 16 --dp 2`. This configuration runs attention with 2 DP groups, each containing 8 TP GPUs.\n\n```{caution}\nData parallelism attention is not recommended for low-latency, small-batch use cases. It is optimized for high-throughput scenarios with large batch sizes.\n```\n\n**Reference**: Check [Blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models).\n\n### Multi-Node Tensor Parallelism\n\n**Description**: For users with limited memory on a single node, SGLang supports serving DeepSeek Series Models, including DeepSeek V3, across multiple nodes using tensor parallelism. This approach partitions the model parameters across multiple GPUs or nodes to handle models that are too large for one node's memory.\n\n**Usage**: Check [here](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-two-h2008-nodes-and-docker) for usage examples.\n\n### Block-wise FP8\n\n**Description**: SGLang implements block-wise FP8 quantization with the following key optimizations:\n\n- **Activation**: E4M3 format using per-token-per-128-channel sub-vector scales with online casting.\n\n- **Weight**: Per-128x128-block quantization for better numerical stability.\n\n- **DeepGEMM**: The [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM) kernel library optimized for FP8 matrix multiplications.\n\n**Usage**: The activation and weight optimizations above are turned on by default for DeepSeek V3 models. DeepGEMM is enabled by default on NVIDIA Hopper/Blackwell GPUs and disabled by default on other devices. DeepGEMM can also be manually turned off by setting the environment variable `SGLANG_ENABLE_JIT_DEEPGEMM=0`.\n\n```{tip}\nBefore serving the DeepSeek model, precompile the DeepGEMM kernels to improve first-run performance. 
The precompilation process typically takes around 10 minutes to complete.\n```\n\n```bash\npython3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code\n```\n\n### Multi-token Prediction\n**Description**: SGLang implements DeepSeek V3 Multi-Token Prediction (MTP) based on [EAGLE speculative decoding](https://docs.sglang.io/advanced_features/speculative_decoding.html#EAGLE-Decoding). With this optimization, decoding speed can be improved by **1.8x** for batch size 1 and **1.5x** for batch size 32 on an H200 TP8 setup.\n\n**Usage**:\nAdd `--speculative-algorithm EAGLE`. Other flags, like `--speculative-num-steps`, `--speculative-eagle-topk`, and `--speculative-num-draft-tokens`, are optional. For example:\n```bash\npython3 -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3-0324 \\\n  --speculative-algorithm EAGLE \\\n  --trust-remote-code \\\n  --tp 8\n```\n- The default configuration for DeepSeek models is `--speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4`. The best configuration of `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` for a given batch size can be searched with the [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes.\n- Most MLA attention backends fully support MTP usage. See [MLA Backends](../advanced_features/attention_backend.md#mla-backends) for details.\n\n```{note}\nTo enable DeepSeek MTP for large batch sizes (>48), you need to adjust some parameters (see [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)):\n- Adjust `--max-running-requests` to a larger number. The default value is `48` for MTP. 
For larger batch sizes, increase this value beyond the default.\n- Set `--cuda-graph-bs`, the list of batch sizes for CUDA graph capture. The [default captured batch sizes for speculative decoding](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py#L888-L895) go up to 48; you can customize this by including more batch sizes.\n```\n\n```{tip}\nTo enable the experimental overlap scheduler for EAGLE speculative decoding, set the environment variable `SGLANG_ENABLE_SPEC_V2=1`. This can improve performance by enabling overlap scheduling between draft and verification stages.\n```\n\n\n### Reasoning Content for DeepSeek R1 & V3.1\n\nSee [Reasoning Parser](https://docs.sglang.io/advanced_features/separate_reasoning.html) and [Thinking Parameter for DeepSeek V3.1](https://docs.sglang.io/basic_usage/openai_api_completions.html#Example:-DeepSeek-V3-Models).\n\n\n### Function calling for DeepSeek Models\n\nAdd the arguments `--tool-call-parser deepseekv3` and `--chat-template ./examples/chat_template/tool_chat_template_deepseekv3.jinja` (recommended) to enable this feature. For example (running on a single H20 node):\n\n```bash\npython3 -m sglang.launch_server \\\n  --model deepseek-ai/DeepSeek-V3-0324 \\\n  --tp 8 \\\n  --port 30000 \\\n  --host 0.0.0.0 \\\n  --mem-fraction-static 0.9 \\\n  --tool-call-parser deepseekv3 \\\n  --chat-template ./examples/chat_template/tool_chat_template_deepseekv3.jinja\n```\n\nSample Request:\n\n```\ncurl \"http://127.0.0.1:30000/v1/chat/completions\" \\\n-H \"Content-Type: application/json\" \\\n-d '{\"temperature\": 0, \"max_tokens\": 100, \"model\": \"deepseek-ai/DeepSeek-V3-0324\", \"tools\": [{\"type\": \"function\", \"function\": {\"name\": \"query_weather\", \"description\": \"Get weather of a city, the user should supply a city first\", \"parameters\": {\"type\": \"object\", \"properties\": {\"city\": {\"type\": \"string\", \"description\": \"The city, e.g. 
Beijing\"}}, \"required\": [\"city\"]}}}], \"messages\": [{\"role\": \"user\", \"content\": \"How'\\''s the weather like in Qingdao today\"}]}'\n```\n\nExpected Response\n\n```\n{\"id\":\"6501ef8e2d874006bf555bc80cddc7c5\",\"object\":\"chat.completion\",\"created\":1745993638,\"model\":\"deepseek-ai/DeepSeek-V3-0324\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":null,\"reasoning_content\":null,\"tool_calls\":[{\"id\":\"0\",\"index\":null,\"type\":\"function\",\"function\":{\"name\":\"query_weather\",\"arguments\":\"{\\\"city\\\": \\\"Qingdao\\\"}\"}}]},\"logprobs\":null,\"finish_reason\":\"tool_calls\",\"matched_stop\":null}],\"usage\":{\"prompt_tokens\":116,\"total_tokens\":138,\"completion_tokens\":22,\"prompt_tokens_details\":null}}\n\n```\nSample Streaming Request:\n```\ncurl \"http://127.0.0.1:30000/v1/chat/completions\" \\\n-H \"Content-Type: application/json\" \\\n-d '{\"temperature\": 0, \"max_tokens\": 100, \"model\": \"deepseek-ai/DeepSeek-V3-0324\",\"stream\":true,\"tools\": [{\"type\": \"function\", \"function\": {\"name\": \"query_weather\", \"description\": \"Get weather of a city, the user should supply a city first\", \"parameters\": {\"type\": \"object\", \"properties\": {\"city\": {\"type\": \"string\", \"description\": \"The city, e.g. 
Beijing\"}}, \"required\": [\"city\"]}}}], \"messages\": [{\"role\": \"user\", \"content\": \"How'\\''s the weather like in Qingdao today\"}]}'\n```\nExpected Streamed Chunks (simplified for clarity):\n```\ndata: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"function\":{\"arguments\":\"{\\\"\"}}]}}]}\ndata: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"function\":{\"arguments\":\"city\"}}]}}]}\ndata: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"function\":{\"arguments\":\"\\\":\\\"\"}}]}}]}\ndata: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"function\":{\"arguments\":\"Q\"}}]}}]}\ndata: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"function\":{\"arguments\":\"ing\"}}]}}]}\ndata: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"function\":{\"arguments\":\"dao\"}}]}}]}\ndata: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"function\":{\"arguments\":\"\\\"}\"}}]}}]}\ndata: {\"choices\":[{\"delta\":{\"tool_calls\":null}}], \"finish_reason\": \"tool_calls\"}\ndata: [DONE]\n```\nThe client needs to concatenate all argument fragments to reconstruct the complete tool call:\n```\n{\"city\": \"Qingdao\"}\n```\n\n```{important}\n1. Use a lower `\"temperature\"` value for better results.\n2. To receive more consistent tool call results, it is recommended to use `--chat-template examples/chat_template/tool_chat_template_deepseekv3.jinja`. 
It provides an improved unified prompt.\n```\n\n\n### Thinking Budget for DeepSeek R1\n\nIn SGLang, a thinking budget can be implemented with `CustomLogitProcessor`.\n\nLaunch a server with the `--enable-custom-logit-processor` flag enabled.\n\n```bash\npython3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --tp 8 --port 30000 --host 0.0.0.0 --mem-fraction-static 0.9 --disable-cuda-graph --reasoning-parser deepseek-r1 --enable-custom-logit-processor\n```\n\nSample Request:\n\n```python\nimport openai\nfrom rich.pretty import pprint\nfrom sglang.srt.sampling.custom_logit_processor import DeepSeekR1ThinkingBudgetLogitProcessor\n\n\nclient = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"*\")\nresponse = client.chat.completions.create(\n    model=\"deepseek-ai/DeepSeek-R1\",\n    messages=[\n        {\n            \"role\": \"user\",\n            \"content\": \"Question: Is Paris the Capital of France?\",\n        }\n    ],\n    max_tokens=1024,\n    extra_body={\n        \"custom_logit_processor\": DeepSeekR1ThinkingBudgetLogitProcessor().to_str(),\n        \"custom_params\": {\n            \"thinking_budget\": 512,\n        },\n    },\n)\npprint(response)\n```\n\n## FAQ\n\n**Q: Model loading is taking too long, and I'm encountering an NCCL timeout. What should I do?**\n\nA: If you're experiencing extended model loading times and an NCCL timeout, you can try increasing the timeout duration. Add the argument `--dist-timeout 3600` when launching your model. This will set the timeout to one hour, which often resolves the issue.\n"
  },
  {
    "path": "docs/basic_usage/deepseek_v32.md",
    "content": "# DeepSeek V3.2 Usage\n\nThe DeepSeek-V3.2 model family equips DeepSeek-V3.1-Terminus with DeepSeek Sparse Attention (DSA) through continued training. With DSA, a fine-grained sparse attention mechanism powered by a lightning indexer, DeepSeek-V3.2 achieves efficiency improvements in long-context scenarios.\n\nFor reporting issues or tracking upcoming features, please refer to this [Roadmap](https://github.com/sgl-project/sglang/issues/11060).\n\nNote: This document was originally written for the [DeepSeek-V3.2-Exp](https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp) model. Usage of [DeepSeek-V3.2](https://huggingface.co/deepseek-ai/DeepSeek-V3.2) and [DeepSeek-V3.2-Speciale](https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Speciale) is the same as for DeepSeek-V3.2-Exp, except for the tool call parser.\n\n\n## Installation\n\n### Docker\n\n```bash\n# H200/B200\ndocker pull lmsysorg/sglang:latest\n\n# MI350/MI355\ndocker pull lmsysorg/sglang:v0.5.8-rocm700-mi35x\n\n# MI300\n# v0.5.8-rocm700-mi30x does not include PR #17504. 
Prefer the newest MI30x ROCm\n# image tag from Docker Hub when available, or build from source (below).\ndocker pull lmsysorg/sglang:v0.5.8-rocm700-mi30x\n\n\n# NPUs\ndocker pull lmsysorg/sglang:dsv32-a2\ndocker pull lmsysorg/sglang:dsv32-a3\n```\n\n### Build From Source\n\n```bash\n# Install SGLang\ngit clone https://github.com/sgl-project/sglang\ncd sglang\npip3 install pip --upgrade\npip3 install -e \"python\"\n```\n\n## Launch DeepSeek V3.2 with SGLang\n\nTo serve [DeepSeek-V3.2-Exp](https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp) on 8xH200/B200 GPUs:\n\n```bash\n# Launch with TP + DP (Recommended)\npython -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention\n\n# Launch with EP + DP\npython -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --ep 8 --dp 8 --enable-dp-attention\n\n# Launch with Pure TP\npython -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8\n\n# Launch with TP on MI30x/MI35x\npython3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --nsa-prefill-backend tilelang --nsa-decode-backend tilelang\n```\n\n### Configuration Tips\n- **DP Attention (Recommended)**: For the DeepSeek V3.2 model, the kernels are customized for `dp_size=8`, so DP attention (`--dp 8 --enable-dp-attention`) is the recommended configuration for better stability and performance. All test cases use this configuration by default.\n- **Pure TP Mode**: Launching with pure TP (without `--dp` and `--enable-dp-attention`) is also supported. Note that this mode has not been fully validated in PD disaggregation scenarios.\n- **Short-sequence MHA prefill (adaptive)**: For short prefill sequences (default threshold: **2048 tokens**), the NSA backend uses standard MHA automatically (no extra flags). On H200 (SM90), this path uses the FlashAttention variable-length kernel; on B200 (SM100), it uses TRT-LLM ragged MHA. MHA uses `MHA_ONE_SHOT` for best performance. 
`MHA_ONE_SHOT` computes multi-head attention over all tokens (both cached prefix and newly extended tokens) in a single kernel invocation, avoiding the overhead of chunked KV cache processing. This achieves optimal throughput for short sequences whose total length fits within the chunk capacity limit.\n- **Choices of Attention Kernels**: The attention backend is automatically set to `nsa` for the DeepSeek V3.2 model. This backend implements different kernels for sparse prefill/decode, which can be selected with the `--nsa-prefill-backend` and `--nsa-decode-backend` server arguments. The available NSA prefill/decode attention kernels are:\n  - `flashmla_sparse`: `flash_mla_sparse_fwd` kernel from the `flash_mla` library. Can run on both Hopper and Blackwell GPUs. It requires bf16 q and kv inputs.\n  - `flashmla_kv`: `flash_mla_with_kvcache` kernel from the `flash_mla` library. Can run on both Hopper and Blackwell GPUs. It requires bf16 q and fp8 k_cache inputs.\n  - `fa3`: `flash_attn_with_kvcache` kernel from the `flash_attn` library. Can only run on Hopper GPUs. It requires bf16 q and kv inputs.\n  - `tilelang`: `tilelang` implementation that can run on GPU, HPU and NPU.\n  - `aiter`: Aiter kernel on AMD GPUs. Can only be used as the decode kernel.\n  - `trtllm`: `trtllm-mla` sparse kernel from the FlashInfer library. Only runs on Blackwell GPUs. It requires QKV in bf16 or fp8.\n- Based on performance benchmarks, the default configurations on H200 and B200 are as follows:\n  - H200: `flashmla_sparse` prefill attention (short-seq prefill uses MHA via FlashAttention varlen), `fa3` decode attention, `bf16` kv cache dtype.\n  - B200: `flashmla_auto` prefill attention (short-seq prefill uses MHA via TRT-LLM ragged), `flashmla_kv` decode attention, `fp8_e4m3` kv cache dtype. `flashmla_auto` enables automatic selection of either the `flashmla_sparse` or the `flashmla_kv` kernel for prefill based on KV cache dtype, hardware, and heuristics. 
When FP8 KV cache is enabled and `total_kv_tokens < total_q_tokens * 512`, it uses the `flashmla_sparse` kernel; otherwise, it falls back to the `flashmla_kv` kernel. The heuristics may need to be tuned if the performance of either the `flashmla_sparse` or `flashmla_kv` kernel changes significantly.\n- On Blackwell, performance can improve by up to 3x-5x at the cost of a slight accuracy drop:\n  - B200: when `trtllm` is chosen for both `--nsa-prefill-backend` and `--nsa-decode-backend`, prefill attention uses MHA via TRT-LLM ragged for both short and long sequences (**accuracy impact**). Combining `trtllm` with an `fp8_e4m3` kv cache gives a kv cache dim of `576` (kv_lora_rank + qk_rope_head_dim) (**accuracy impact**), compared with a kv cache dim of `656` (kv_lora_rank + scale storage (kv_lora_rank // quant_block_size * 4 bytes) + rope dimension storage) for the combination of `flashmla_auto` and an `fp8_e4m3` kv cache.\n\n\n## Multi-token Prediction\nSGLang implements Multi-Token Prediction (MTP) for DeepSeek V3.2 based on [EAGLE speculative decoding](https://docs.sglang.io/advanced_features/speculative_decoding.html#EAGLE-Decoding). With this optimization, decoding speed can be improved significantly for small batch sizes. 
Please look at [this PR](https://github.com/sgl-project/sglang/pull/11652) for more information.\n\nExample usage with DP Attention:\n```bash\npython -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4\n```\n\nExample usage with Pure TP:\n```bash\npython -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4\n```\n\n- The best configuration of `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` for a given batch size can be searched with the [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes.\n- The default value of `--max-running-requests` is `48` for MTP. For larger batch sizes, increase this value beyond the default.\n\n```{tip}\nTo enable the experimental overlap scheduler for EAGLE speculative decoding, set the environment variable `SGLANG_ENABLE_SPEC_V2=1`. This can improve performance by enabling overlap scheduling between draft and verification stages.\n```\n\n\n## Function Calling and Reasoning Parser\nThe usage of function calling and the reasoning parser is the same as for DeepSeek V3.1. 
Please refer to the [Reasoning Parser](https://docs.sglang.io/advanced_features/separate_reasoning.html) and [Tool Parser](https://docs.sglang.io/advanced_features/tool_parser.html) documents.\n\nTo launch `DeepSeek-V3.2-Exp` with function calling and the reasoning parser:\n> Note: It is recommended to specify the chat template; run the command from SGLang's root directory so that the relative template path resolves.\n```bash\npython3 -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3.2-Exp \\\n  --trust-remote-code \\\n  --tp-size 8 --dp-size 8 --enable-dp-attention \\\n  --tool-call-parser deepseekv31 \\\n  --reasoning-parser deepseek-v3 \\\n  --chat-template ./examples/chat_template/tool_chat_template_deepseekv32.jinja\n```\n\nTo launch `DeepSeek-V3.2` with function calling and the reasoning parser:\n```bash\npython3 -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3.2 \\\n  --trust-remote-code \\\n  --tp-size 8 --dp-size 8 --enable-dp-attention \\\n  --tool-call-parser deepseekv32 \\\n  --reasoning-parser deepseek-v3\n```\n\n`DeepSeek-V3.2-Speciale` doesn't support tool calling, so it can only be launched with the reasoning parser:\n```bash\npython3 -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3.2-Speciale \\\n  --trust-remote-code \\\n  --tp-size 8 --dp-size 8 --enable-dp-attention \\\n  --reasoning-parser deepseek-v3\n```\n\n## NVFP4 Checkpoint\n\nTo launch the DeepSeek V3.2 [NVFP4 checkpoint](https://huggingface.co/nvidia/DeepSeek-V3.2-NVFP4) on Blackwell devices, specify the quantization method as `modelopt_fp4` and the MoE runner backend as one of `flashinfer_trtllm` (recommended), `flashinfer_cutlass`, or `flashinfer_cutedsl`. All other usage (parallelism, reasoning parser, ...) 
is the same as for the FP8 checkpoint.\n\nAn example launch command:\n```bash\npython -m sglang.launch_server --model nvidia/DeepSeek-V3.2-NVFP4 --tp 4 --quantization modelopt_fp4 --moe-runner-backend flashinfer_trtllm --tool-call-parser deepseekv32 --reasoning-parser deepseek-v3\n```\n\n## PD Disaggregation\n\nPrefill command:\n```bash\npython -m sglang.launch_server \\\n        --model-path deepseek-ai/DeepSeek-V3.2-Exp \\\n        --disaggregation-mode prefill \\\n        --host $LOCAL_IP \\\n        --port $PORT \\\n        --tp 8 \\\n        --dp 8 \\\n        --enable-dp-attention \\\n        --dist-init-addr ${HOST}:${DIST_PORT} \\\n        --trust-remote-code \\\n        --disaggregation-bootstrap-port 8998 \\\n        --mem-fraction-static 0.9\n```\n\nDecode command:\n```bash\npython -m sglang.launch_server \\\n        --model-path deepseek-ai/DeepSeek-V3.2-Exp \\\n        --disaggregation-mode decode \\\n        --host $LOCAL_IP \\\n        --port $PORT \\\n        --tp 8 \\\n        --dp 8 \\\n        --enable-dp-attention \\\n        --dist-init-addr ${HOST}:${DIST_PORT} \\\n        --trust-remote-code \\\n        --mem-fraction-static 0.9\n```\n\nRouter command:\n```bash\npython -m sglang_router.launch_router --pd-disaggregation \\\n  --prefill $PREFILL_ADDR 8998 \\\n  --decode $DECODE_ADDR \\\n  --host 127.0.0.1 \\\n  --port 8000\n```\n\nFor more advanced or production-ready deployment methods, such as RBG- or LWS-based deployment, please refer to [references/multi_node_deployment/rbg_pd/deepseekv32_pd.md](../references/multi_node_deployment/rbg_pd/deepseekv32_pd.md). 
That document also includes startup commands for DeepEP-based expert parallelism (EP).\n\n\n## Benchmarking Results\n\n### Accuracy Test with `gsm8k`\nA simple accuracy benchmark can be run on the `gsm8k` dataset:\n```bash\npython3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319\n```\n\nThe result is 0.956, which matches our expectation:\n```bash\nAccuracy: 0.956\nInvalid: 0.000\nLatency: 25.109 s\nOutput throughput: 5226.235 token/s\n```\n\nTo test long-context accuracy, run gsm8k with `--num-shots 20`. The results are very close to the 8-shot results:\n```\nAccuracy: 0.956\nInvalid: 0.000\nLatency: 29.545 s\nOutput throughput: 4418.617 token/s\n```\n\n\n### Accuracy Test with `gpqa-diamond`\n\nLong-context accuracy can be tested on the GPQA-diamond dataset with long outputs and thinking enabled:\n```bash\npython3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 128000 --repeat 8 --thinking-mode deepseek-v3\n```\n\nThe mean accuracy over 8 runs is 0.797, which matches the 0.799 reported in the official tech report.\n```bash\nRepeat: 8, mean: 0.797\nScores: ['0.808', '0.798', '0.808', '0.798', '0.783', '0.788', '0.803', '0.793']\n```\n\nFor DeepSeek V3.2, DeepSeek recommends setting the sampling parameters to temperature = 1.0, top_p = 0.95:\n\n```bash\npython3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 128000 --repeat 8 --top-p 0.95 --temperature 1.0 --thinking-mode deepseek-v3\n\nRepeat: 8, mean: 0.840\nScores: ['0.848', '0.808', '0.848', '0.838', '0.879', '0.813', '0.838', '0.848']\n```\nThe mean of 0.840 is in line with the official score of 0.824 reported in the [DeepSeek-V3.2 technical report](https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/assets/paper.pdf).\n\n### Accuracy Test with `aime 2025`\n\nPrepare the environment by installing NeMo-Skills in the Docker container or your own virtual environment:\n\n  ```\n  
pip install git+https://github.com/NVIDIA/NeMo-Skills.git --ignore-installed blinker\n  ```\n\nThen launch the SGLang server:\n```bash\npython -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention\n```\n\n**For `DeepSeek-V3.2` and `DeepSeek-V3.2-Speciale`**:\n\n```bash\npython3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3.2 --trust-remote-code --tp-size 8 --dp-size 8 --enable-dp-attention --tool-call-parser deepseekv32 --reasoning-parser deepseek-v3\n```\n\nRun the following script to evaluate AIME 2025:\n```bash\n#!/bin/bash\nexport NEMO_SKILLS_DISABLE_UNCOMMITTED_CHANGES_CHECK=1\n\nns prepare_data aime25\n\nPORT=30000\nBACKEND=sglang\nMODEL=\"deepseek-ai/DeepSeek-V3.2-Exp\" # Change to the name of the model being evaluated\nMODEL_NAME=\"dsv32-fp8\"\n\necho \"Starting AIME25 evaluation with model $MODEL on port $PORT using backend $BACKEND...\"\nns eval \\\n  --benchmarks=aime25:4 \\\n  --server_type=$BACKEND \\\n  --model=$MODEL \\\n  --server_address=http://localhost:${PORT}/v1 \\\n  --output_dir=nemo_skills_aime25_${MODEL_NAME}_output_${BACKEND}_$(date +%Y%m%d_%H%M%S) \\\n  ++chat_template_kwargs.thinking=true \\\n  ++inference.temperature=1.0 \\\n  ++inference.top_p=0.95 \\\n  ++inference.tokens_to_generate=64000\n  # ++inference.tokens_to_generate=120000 for the Speciale model\n```\n\nTest results (8*B200):\n\nDeepSeek-V3.2-Exp:\n\n| evaluation_mode    | num_entries | avg_tokens | gen_seconds | symbolic_correct      | no_answer |\n|--------------------|-------------|------------|-------------|-----------------------|-----------|\n| pass@1[avg-of-4]   | 30          | 15040      | 1673        | 87.50% ± 1.67%        | 0.00%     |\n| majority@4         | 30          | 15040      | 1673        | 90.00%                | 0.00%     |\n| pass@4             | 30          | 15040      | 1673        | 90.00%                | 0.00%     |\n\n\nDeepSeek-V3.2:\n\n| evaluation_mode    | num_entries | avg_tokens | gen_seconds | 
symbolic_correct      | no_answer |\n|--------------------|-------------|------------|-------------|-----------------------|-----------|\n| pass@1[avg-of-4]   | 30          | 13550      | 1632        | 92.50% ± 1.67%        | 0.00%     |\n| majority@4         | 30          | 13550      | 1632        | 94.71%                | 0.00%     |\n| pass@4             | 30          | 13550      | 1632        | 96.67%                | 0.00%     |\n\n\nDeepSeek-V3.2-Speciale:\n\n| evaluation_mode    | num_entries | avg_tokens | gen_seconds | symbolic_correct      | no_answer |\n|--------------------|-------------|------------|-------------|-----------------------|-----------|\n| pass@1[avg-of-4]   | 30          | 24155      | 3583        | 95.00% ± 1.92%        | 0.00%     |\n| majority@4         | 30          | 24155      | 3583        | 95.83%                | 0.00%     |\n| pass@4             | 30          | 24155      | 3583        | 100.00%               | 0.00%     |\n\n\n\n## DSA long sequence context parallel optimization (experimental)\n\n**Note: This feature has only been verified on Hopper machines.**\n\nFor context parallelism in the DeepSeek V3.2 model, we provide two modes of splitting tokens, selected with the `--nsa-prefill-cp-mode` argument.\n\n### In-sequence splitting\n\nThe first mode is enabled by `--nsa-prefill-cp-mode in-seq-split`. This mode implements context parallelism for DSA by splitting the sequence uniformly across context-parallel ranks. At the attention stage, each CP rank computes the indexer results for its shard of the sequence and collects the full KV cache through an all-gather operation. 
The `attn_cp_size` argument sets the size of the communication group used for context parallelism.\n\nNote that in-sequence splitting mode has the following restrictions:\n- The batch size is restricted to 1 for prefill batches\n- `moe_dense_tp_size=1` and `moe_a2a_backend=\"deepep\"` are required\n- To ensure `cp_size > 1`, the `tp_size` passed in must be larger than `dp_size`\n\nFor more details, please refer to PR https://github.com/sgl-project/sglang/pull/12065.\n\nExample:\n```bash\n# In-seq splitting mode launched with EP + DP\npython -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --ep 8 --dp 2 --enable-dp-attention --enable-nsa-prefill-context-parallel --attn-cp-size 4 --nsa-prefill-cp-mode in-seq-split --max-running-requests 32\n```\n\n### Round-robin splitting (default setting)\n\nThis mode is enabled with `--nsa-prefill-cp-mode round-robin-split` and distributes tokens across ranks based on `token_idx % cp_size`.\n\nCompared with in-sequence splitting, this mode additionally supports the fused MoE backend (which may deliver better performance than DeepEP in single-machine scenarios), FP8 KV cache, and multi-batch prefill inference. However, it cannot be enabled together with DP attention.\n\nFor more details, please refer to PR https://github.com/sgl-project/sglang/pull/13959.\n\nExample usage:\n```bash\n# Launch with FusedMoE + CP8\npython -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --enable-nsa-prefill-context-parallel --attn-cp-size 8 --nsa-prefill-cp-mode round-robin-split --max-running-requests 32\n```\n\n### Pipeline Parallel + Context Parallel (PP + CP)\n\nThis mode combines Pipeline Parallelism (PP) and Context Parallelism (CP) to scale across multiple nodes, which can improve throughput and Time To First Token (TTFT). Note that this method has only been tested on H20 (96 GB).\n\n#### Standard Usage\n\nThe following commands launch with PP=2 and CP (via `round-robin-split` mode) on 2 nodes. 
This configuration uses the fused MoE kernel by default, which generally provides better performance.\n\nFor related development details, please refer to:\n- Fused MoE + CP support: [PR #13959](https://github.com/sgl-project/sglang/pull/13959)\n- PP + CP support: [Issue #15358](https://github.com/sgl-project/sglang/issues/15358) and [PR #16380](https://github.com/sgl-project/sglang/pull/16380)\n\nNode 0:\n```bash\nexport SGLANG_PP_LAYER_PARTITION=30,31\npython3 -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3.2-Exp \\\n  --nnodes 2 --node-rank 0 \\\n  --dist-init-addr <HEAD_NODE_IP>:62001 \\\n  --tp 8 --pp-size 2 \\\n  --dp-size 1 --moe-dense-tp-size 1 \\\n  --enable-nsa-prefill-context-parallel \\\n  --attn-cp-size 8 \\\n  --nsa-prefill-cp-mode round-robin-split \\\n  --trust-remote-code \\\n  --disable-radix-cache \\\n  --mem-fraction-static 0.8 \\\n  --max-running-requests 128 \\\n  --chunked-prefill-size 16384 \\\n  --cuda-graph-max-bs 8 \\\n  --page-size 64 \\\n  --watchdog-timeout 3600 \\\n  --host 0.0.0.0 --port 8000 \\\n  --tool-call-parser deepseekv32\n```\n\nNode 1:\n```bash\nexport SGLANG_PP_LAYER_PARTITION=30,31\npython3 -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3.2-Exp \\\n  --nnodes 2 --node-rank 1 \\\n  --dist-init-addr <HEAD_NODE_IP>:62001 \\\n  --tp 8 --pp-size 2 \\\n  --dp-size 1 --moe-dense-tp-size 1 \\\n  --enable-nsa-prefill-context-parallel \\\n  --attn-cp-size 8 \\\n  --nsa-prefill-cp-mode round-robin-split \\\n  --trust-remote-code \\\n  --disable-radix-cache \\\n  --mem-fraction-static 0.8 \\\n  --max-running-requests 128 \\\n  --chunked-prefill-size 16384 \\\n  --cuda-graph-max-bs 8 \\\n  --page-size 64 \\\n  --watchdog-timeout 3600 \\\n  --host 0.0.0.0 --port 8000 \\\n  --tool-call-parser deepseekv32\n```\n\n#### PD Disaggregation with PP + CP\n\nIf using PD (Prefill-Decode) Disaggregation, the Prefill nodes can be configured with PP + CP as follows.\n\nPrefill Node 0:\n```bash\npython -m 
sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3.2-Exp \\\n  --served-model-name deepseek-v32 \\\n  --nnodes 2 --node-rank 0 \\\n  --dist-init-addr <PREFILL_HEAD_IP>:20102 \\\n  --tp 8 --pp-size 2 \\\n  --dp-size 1 --moe-dense-tp-size 1 \\\n  --enable-nsa-prefill-context-parallel \\\n  --attn-cp-size 8 \\\n  --nsa-prefill-cp-mode round-robin-split  \\\n  --disaggregation-ib-device mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3 \\\n  --trust-remote-code \\\n  --disable-radix-cache \\\n  --max-running-requests 512 \\\n  --chunked-prefill-size 4096 \\\n  --context-length 131072 \\\n  --mem-fraction-static 0.9 \\\n  --page-size 64 \\\n  --enable-metrics \\\n  --collect-tokens-histogram \\\n  --tokenizer-worker-num 8 \\\n  --host 0.0.0.0 --port 30000\n```\n\nPrefill Node 1:\n```bash\npython -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-V3.2-Exp \\\n  --served-model-name deepseek-v32-prefill \\\n  --nnodes 2 --node-rank 1 \\\n  --dist-init-addr <PREFILL_HEAD_IP>:20102 \\\n  --tp 8 --pp-size 2 \\\n  --dp-size 1 --moe-dense-tp-size 1 \\\n  --enable-nsa-prefill-context-parallel \\\n  --attn-cp-size 8 \\\n  --nsa-prefill-cp-mode round-robin-split  \\\n  --disaggregation-ib-device mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3 \\\n  --trust-remote-code \\\n  --disable-radix-cache \\\n  --max-running-requests 512 \\\n  --chunked-prefill-size 4096 \\\n  --context-length 131072 \\\n  --mem-fraction-static 0.9 \\\n  --page-size 64 \\\n  --enable-metrics \\\n  --collect-tokens-histogram \\\n  --tokenizer-worker-num 8 \\\n  --host 0.0.0.0 --port 30000\n```\n\nFor the Decode nodes, it is recommended to use the **EP mode**.\n"
  },
  {
    "path": "docs/basic_usage/glm45.md",
    "content": "## Launch GLM-4.5 / GLM-4.6 / GLM-4.7 with SGLang\n\nTo serve GLM-4.5 / GLM-4.6 FP8 models on 8xH100/H200 GPUs:\n\n```bash\npython3 -m sglang.launch_server --model zai-org/GLM-4.6-FP8 --tp 8\n```\n\n### EAGLE Speculative Decoding\n\n**Description**: SGLang supports GLM-4.5 / GLM-4.6 models\nwith [EAGLE speculative decoding](https://docs.sglang.io/advanced_features/speculative_decoding.html#EAGLE-Decoding).\n\n**Usage**:\nAdd the arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculative-eagle-topk`, and\n`--speculative-num-draft-tokens` to enable this feature. For example:\n\n```bash\npython3 -m sglang.launch_server \\\n  --model-path zai-org/GLM-4.6-FP8 \\\n  --tp-size 8 \\\n  --tool-call-parser glm45 \\\n  --reasoning-parser glm45 \\\n  --speculative-algorithm EAGLE \\\n  --speculative-num-steps 3 \\\n  --speculative-eagle-topk 1 \\\n  --speculative-num-draft-tokens 4 \\\n  --mem-fraction-static 0.9 \\\n  --served-model-name glm-4.6-fp8 \\\n  --enable-custom-logit-processor\n```\n\n```{tip}\nTo enable the experimental overlap scheduler for EAGLE speculative decoding, set the environment variable `SGLANG_ENABLE_SPEC_V2=1`. 
This can improve performance by enabling overlap scheduling between draft and verification stages.\n```\n\n### Thinking Budget for GLM-4.5 / GLM-4.6\n**Note**: For GLM-4.7, set `--tool-call-parser` to `glm47`; for GLM-4.5 and GLM-4.6, set it to `glm45`.\n\nIn SGLang, a thinking budget can be implemented with `CustomLogitProcessor`.\n\nLaunch a server with the `--enable-custom-logit-processor` flag.\n\nSample request:\n\n```python\nimport openai\nfrom rich.pretty import pprint\nfrom sglang.srt.sampling.custom_logit_processor import Glm4MoeThinkingBudgetLogitProcessor\n\n\nclient = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"*\")\nresponse = client.chat.completions.create(\n    model=\"zai-org/GLM-4.6\",\n    messages=[\n        {\n            \"role\": \"user\",\n            \"content\": \"Question: Is Paris the Capital of France?\",\n        }\n    ],\n    max_tokens=1024,\n    extra_body={\n        \"custom_logit_processor\": Glm4MoeThinkingBudgetLogitProcessor().to_str(),\n        \"custom_params\": {\n            \"thinking_budget\": 512,\n        },\n    },\n)\npprint(response)\n```\n"
  },
  {
    "path": "docs/basic_usage/glmv.md",
    "content": "# GLM-4.6V / GLM-4.5V Usage\n\n## Launch commands for SGLang\n\nBelow are suggested launch commands tailored for different hardware and precision modes.\n\n### FP8 (quantized) mode\n\nFor memory-efficient, latency-optimized deployments (e.g., on H100 or H200) where the FP8 checkpoint is supported:\n\n```bash\npython3 -m sglang.launch_server \\\n  --model-path zai-org/GLM-4.6V-FP8 \\\n  --tp 2 \\\n  --ep 2 \\\n  --host 0.0.0.0 \\\n  --port 30000 \\\n  --keep-mm-feature-on-device\n```\n\n### Non-FP8 (BF16 / full precision) mode\n\nFor deployments on A100/H100 where BF16 is used (or the FP8 checkpoint is not used):\n```bash\npython3 -m sglang.launch_server \\\n  --model-path zai-org/GLM-4.6V \\\n  --tp 4 \\\n  --ep 4 \\\n  --host 0.0.0.0 \\\n  --port 30000\n```\n\n## Hardware-specific notes / recommendations\n\n- On H100 with FP8: Use the FP8 checkpoint for best memory efficiency.\n- On A100 / H100 with BF16 (non-FP8): It’s recommended to use `--mm-max-concurrent-calls` to control parallel throughput and GPU memory usage during image/video inference.\n- On H200 & B200: The model can be run “out of the box”, supporting full context length plus concurrent image + video processing.\n\n## Sending Image/Video Requests\n\n### Image input:\n\n```python\nimport requests\n\nurl = \"http://localhost:30000/v1/chat/completions\"\n\ndata = {\n    \"model\": \"zai-org/GLM-4.6V\",\n    \"messages\": [\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\"type\": \"text\", \"text\": \"What’s in this image?\"},\n                {\n                    \"type\": \"image_url\",\n                    \"image_url\": {\n                        \"url\": \"https://github.com/sgl-project/sglang/blob/main/examples/assets/example_image.png?raw=true\"\n                    },\n                },\n            ],\n        }\n    ],\n    \"max_tokens\": 300,\n}\n\nresponse = requests.post(url, json=data)\nprint(response.text)\n```\n\n### Video 
input:\n\n```python\nimport requests\n\nurl = \"http://localhost:30000/v1/chat/completions\"\n\ndata = {\n    \"model\": \"zai-org/GLM-4.6V\",\n    \"messages\": [\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\"type\": \"text\", \"text\": \"What’s happening in this video?\"},\n                {\n                    \"type\": \"video_url\",\n                    \"video_url\": {\n                        \"url\": \"https://github.com/sgl-project/sgl-test-files/raw/refs/heads/main/videos/jobs_presenting_ipod.mp4\"\n                    },\n                },\n            ],\n        }\n    ],\n    \"max_tokens\": 300,\n}\n\nresponse = requests.post(url, json=data)\nprint(response.text)\n```\n\n## Important Server Parameters and Flags\n\nWhen launching the model server with **multimodal support**, you can use the following command-line arguments to fine-tune performance and behavior:\n\n- `--mm-attention-backend`: Specifies the multimodal attention backend, e.g., `fa3` (FlashAttention 3).\n- `--mm-max-concurrent-calls <value>`: Specifies the **maximum number of concurrent asynchronous multimodal data processing calls** allowed on the server. Use this to control parallel throughput and GPU memory usage during image/video inference.\n- `--mm-per-request-timeout <seconds>`: Defines the **timeout duration (in seconds)** for each multimodal request. If a request exceeds this time limit (e.g., for very large video inputs), it will be automatically terminated.\n- `--keep-mm-feature-on-device`: Instructs the server to **retain multimodal feature tensors on the GPU** after processing. 
This avoids device-to-host (D2H) memory copies and improves performance for repeated or high-frequency inference workloads.\n- `--mm-enable-dp-encoder`: Runs the ViT in data parallel while keeping the LLM in tensor parallel, which consistently lowers TTFT and boosts end-to-end throughput.\n- `SGLANG_USE_CUDA_IPC_TRANSPORT=1`: Enables shared-memory-pool-based CUDA IPC for multimodal data transport, which can significantly improve end-to-end latency.\n\n### Example usage with the above optimizations:\n```bash\nSGLANG_USE_CUDA_IPC_TRANSPORT=1 \\\nSGLANG_VLM_CACHE_SIZE_MB=0 \\\npython -m sglang.launch_server \\\n  --model-path zai-org/GLM-4.6V \\\n  --host 0.0.0.0 \\\n  --port 30000 \\\n  --trust-remote-code \\\n  --tp-size 8 \\\n  --enable-cache-report \\\n  --log-level info \\\n  --max-running-requests 64 \\\n  --mem-fraction-static 0.65 \\\n  --chunked-prefill-size 8192 \\\n  --attention-backend fa3 \\\n  --mm-attention-backend fa3 \\\n  --mm-enable-dp-encoder \\\n  --enable-metrics\n```\n\n### Thinking Budget for GLM-4.5V / GLM-4.6V\n\nIn SGLang, a thinking budget can be implemented with `CustomLogitProcessor`.\n\nLaunch a server with the `--enable-custom-logit-processor` flag. Then, use `Glm4MoeThinkingBudgetLogitProcessor` in the request, similar to the `GLM-4.6` example in [glm45.md](./glm45.md).\n"
  },
  {
    "path": "docs/basic_usage/gpt_oss.md",
    "content": "# GPT OSS Usage\n\nPlease refer to [https://github.com/sgl-project/sglang/issues/8833](https://github.com/sgl-project/sglang/issues/8833).\n\n## Responses API & Built-in Tools\n\n### Responses API\n\nGPT‑OSS is compatible with the OpenAI Responses API. Use `client.responses.create(...)` with `model`, `instructions`, `input`, and optional `tools` to enable built‑in tool use. You can set the reasoning level via `instructions`, e.g., \"Reasoning: high\" (\"medium\" and \"low\" are also supported). The levels are low (fast), medium (balanced), and high (deep).\n\n### Built-in Tools\n\nGPT‑OSS can call built‑in tools for web search and Python execution. You can use the demo tool server or connect to external MCP tool servers.\n\n#### Python Tool\n\n- Executes short Python snippets for calculations, parsing, and quick scripts.\n- By default, runs in a Docker-based sandbox. To run on the host, set `PYTHON_EXECUTION_BACKEND=UV` (this executes model-generated code locally; use with care).\n- Ensure Docker is available if you are not using the UV backend. It is recommended to run `docker pull python:3.11` in advance.\n\n#### Web Search Tool\n\n- Uses the Exa backend for web search.\n- Requires an Exa API key; set `EXA_API_KEY` in your environment. Create a key at `https://exa.ai`.\n\n### Tool & Reasoning Parser\n\n- We support the OpenAI reasoning and tool call parsers, as well as SGLang's native API for tool calling and reasoning. Refer to [reasoning parser](../advanced_features/separate_reasoning.ipynb) and [tool call parser](../advanced_features/function_calling.ipynb) for more details.\n\n\n## Notes\n\n- Use **Python 3.12** for the demo tools, and install the required `gpt-oss` packages.\n- The default demo integrates the web search tool (Exa backend) and a demo Python interpreter via Docker.\n- For search, set `EXA_API_KEY`. 
For Python execution, either have Docker available or set `PYTHON_EXECUTION_BACKEND=UV`.\n\nExample:\n```bash\nexport EXA_API_KEY=YOUR_EXA_KEY\n# Optional: run the Python tool locally instead of in Docker (use with care)\nexport PYTHON_EXECUTION_BACKEND=UV\n```\n\nLaunch the server with the demo tool server:\n\n```bash\npython3 -m sglang.launch_server \\\n  --model-path openai/gpt-oss-120b \\\n  --tool-server demo \\\n  --tp 2\n```\n\nFor production usage, SGLang can act as an MCP client for multiple services. An [example tool server](https://github.com/openai/gpt-oss/tree/main/gpt-oss-mcp-server) is provided. Start the servers and point SGLang to them:\n```bash\nmcp run -t sse browser_server.py:mcp\nmcp run -t sse python_server.py:mcp\n\npython -m sglang.launch_server ... --tool-server ip-1:port-1,ip-2:port-2\n```\nThe URLs should point to MCP SSE servers that expose server information and well-documented tools. These tools are added to the system prompt so the model can use them.\n\n## Speculative Decoding\n\nSGLang supports speculative decoding for GPT-OSS models using the EAGLE3 algorithm. This can significantly improve decoding speed, especially for small batch sizes.\n\n**Usage**:\nAdd `--speculative-algorithm EAGLE3` along with the draft model path:\n```bash\npython3 -m sglang.launch_server \\\n  --model-path openai/gpt-oss-120b \\\n  --speculative-algorithm EAGLE3 \\\n  --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 \\\n  --tp 2\n```\n\n```{tip}\nTo enable the experimental overlap scheduler for EAGLE3 speculative decoding, set the environment variable `SGLANG_ENABLE_SPEC_V2=1`. 
This can improve performance by enabling overlap scheduling between draft and verification stages.\n```\n\n### Quick Demo\n\n```python\nfrom openai import OpenAI\n\nclient = OpenAI(\n    base_url=\"http://localhost:30000/v1\",\n    api_key=\"sk-123456\"\n)\n\ntools = [\n    {\"type\": \"code_interpreter\"},\n    {\"type\": \"web_search_preview\"},\n]\n\n# Reasoning level example\nresponse = client.responses.create(\n    model=\"openai/gpt-oss-120b\",\n    instructions=\"You are a helpful assistant.\",\n    reasoning={\"effort\": \"high\"},  # Supports \"high\", \"medium\", or \"low\"\n    input=\"In one sentence, explain the transformer architecture.\",\n)\nprint(\"====== reasoning: high ======\")\nprint(response.output_text)\n\n# Test python tool\nresponse = client.responses.create(\n    model=\"openai/gpt-oss-120b\",\n    instructions=\"You are a helpful assistant, you could use python tool to execute code.\",\n    input=\"Use python tool to calculate the sum of 29138749187 and 29138749187\",  # 58,277,498,374\n    tools=tools,\n)\nprint(\"====== test python tool ======\")\nprint(response.output_text)\n\n# Test browser tool\nresponse = client.responses.create(\n    model=\"openai/gpt-oss-120b\",\n    instructions=\"You are a helpful assistant, you could use browser to search the web\",\n    input=\"Search the web for the latest news about Nvidia stock price\",\n    tools=tools,\n)\nprint(\"====== test browser tool ======\")\nprint(response.output_text)\n```\n\nExample output:\n```\n====== test python tool ======\nThe sum of 29,138,749,187 and 29,138,749,187 is **58,277,498,374**.\n====== test browser tool ======\n**Recent headlines on Nvidia (NVDA) stock**\n\n| Date (2025) | Source | Key news points | Stock‑price detail |\n|-------------|--------|----------------|--------------------|\n| **May 13** | Reuters | The market data page shows Nvidia trading “higher” at **$116.61** with no change from the previous close. 
| **$116.61** – latest trade (delayed ≈ 15 min)【14†L34-L38】 |\n| **Aug 18** | CNBC | Morgan Stanley kept an **overweight** rating and lifted its price target to **$206** (up from $200), implying a 14 % upside from the Friday close. The firm notes Nvidia shares have already **jumped 34 % this year**. | No exact price quoted, but the article signals strong upside expectations【9†L27-L31】 |\n| **Aug 20** | The Motley Fool | Nvidia is set to release its Q2 earnings on Aug 27. The article lists the **current price of $175.36**, down 0.16 % on the day (as of 3:58 p.m. ET). | **$175.36** – current price on Aug 20【10†L12-L15】【10†L53-L57】 |\n\n**What the news tells us**\n\n* Nvidia’s share price has risen sharply this year – up roughly a third according to Morgan Stanley – and analysts are still raising targets (now $206).\n* The most recent market quote (Reuters, May 13) was **$116.61**, but the stock has surged since then, reaching **$175.36** by mid‑August.\n* Upcoming earnings on **Aug 27** are a focal point; both the Motley Fool and Morgan Stanley expect the results could keep the rally going.\n\n**Bottom line:** Nvidia’s stock is on a strong upward trajectory in 2025, with price targets climbing toward $200‑$210 and the market price already near $175 as of late August.\n\n```\n"
  },
  {
    "path": "docs/basic_usage/llama4.md",
    "content": "# Llama4 Usage\n\n[Llama 4](https://github.com/meta-llama/llama-models/blob/main/models/llama4/MODEL_CARD.md) is Meta's latest generation of open-source LLMs, with industry-leading performance.\n\nSGLang has supported Llama 4 Scout (109B) and Llama 4 Maverick (400B) since [v0.4.5](https://github.com/sgl-project/sglang/releases/tag/v0.4.5).\n\nOngoing optimizations are tracked in the [Roadmap](https://github.com/sgl-project/sglang/issues/5118).\n\n## Launch Llama 4 with SGLang\n\nTo serve Llama 4 models on 8xH100/H200 GPUs:\n\n```bash\npython3 -m sglang.launch_server \\\n  --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct \\\n  --tp 8 \\\n  --context-length 1000000\n```\n\n### Configuration Tips\n\n- **OOM Mitigation**: Adjust `--context-length` to avoid GPU out-of-memory errors. For the Scout model, we recommend setting this value to at most 1M on 8\\*H100 and at most 2.5M on 8\\*H200. For the Maverick model, there is no need to set the context length on 8\\*H200. When hybrid KV cache is enabled, `--context-length` can be set to up to 5M on 8\\*H100 and up to 10M on 8\\*H200 for the Scout model.\n\n- **Attention Backend Auto-Selection**: SGLang automatically selects the optimal attention backend for Llama 4 based on your hardware. 
You typically don't need to specify `--attention-backend` manually:\n  - **Blackwell GPUs (B200/GB200)**: `trtllm_mha`\n  - **Hopper GPUs (H100/H200)**: `fa3`\n  - **AMD GPUs**: `aiter`\n  - **Intel XPU**: `intel_xpu`\n  - **Other platforms**: `triton` (fallback)\n\n  To override the auto-selection, explicitly specify `--attention-backend` with one of the supported backends: `fa3`, `aiter`, `triton`, `trtllm_mha`, or `intel_xpu`.\n\n- **Chat Template**: Add `--chat-template llama-4` for chat completion tasks.\n- **Enable Multi-Modal**: Add `--enable-multimodal` for multi-modal capabilities.\n- **Enable Hybrid-KVCache**: Set `--swa-full-tokens-ratio` to adjust the ratio of SWA-layer KV tokens (for Llama 4, the local attention layers) to full-layer KV tokens (default: 0.8, range: 0-1).\n\n\n### EAGLE Speculative Decoding\n\n**Description**: SGLang supports Llama 4 Maverick (400B) with [EAGLE speculative decoding](https://docs.sglang.io/advanced_features/speculative_decoding.html#EAGLE-Decoding).\n\n**Usage**:\nAdd the arguments `--speculative-draft-model-path`, `--speculative-algorithm`, `--speculative-num-steps`, `--speculative-eagle-topk`, and `--speculative-num-draft-tokens` to enable this feature. 
For example:\n```bash\npython3 -m sglang.launch_server \\\n  --model-path meta-llama/Llama-4-Maverick-17B-128E-Instruct \\\n  --speculative-algorithm EAGLE3 \\\n  --speculative-draft-model-path nvidia/Llama-4-Maverick-17B-128E-Eagle3 \\\n  --speculative-num-steps 3 \\\n  --speculative-eagle-topk 1 \\\n  --speculative-num-draft-tokens 4 \\\n  --trust-remote-code \\\n  --tp 8 \\\n  --context-length 1000000\n```\n\n- **Note**: The Llama 4 draft model *nvidia/Llama-4-Maverick-17B-128E-Eagle3* can only recognize conversations in chat mode.\n\n## Benchmarking Results\n\n### Accuracy Test with `lm_eval`\n\nThe accuracy of SGLang for both Llama 4 Scout and Llama 4 Maverick matches the [official benchmark numbers](https://ai.meta.com/blog/llama-4-multimodal-intelligence/).\n\nBenchmark results on the MMLU Pro dataset with 8*H100:\n\n|                    | Llama-4-Scout-17B-16E-Instruct | Llama-4-Maverick-17B-128E-Instruct  |\n|--------------------|--------------------------------|-------------------------------------|\n| Official Benchmark | 74.3                           | 80.5                                |\n| SGLang             | 75.2                           | 80.7                                |\n\nCommands:\n\n```bash\n# Llama-4-Scout-17B-16E-Instruct model\npython -m sglang.launch_server \\\n  --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct \\\n  --port 30000 \\\n  --tp 8 \\\n  --mem-fraction-static 0.8 \\\n  --context-length 65536\nlm_eval --model local-chat-completions --model_args model=meta-llama/Llama-4-Scout-17B-16E-Instruct,base_url=http://localhost:30000/v1/chat/completions,num_concurrent=128,timeout=999999,max_gen_toks=2048 --tasks mmlu_pro --batch_size 128 --apply_chat_template --num_fewshot 0\n\n# Llama-4-Maverick-17B-128E-Instruct\npython -m sglang.launch_server \\\n  --model-path meta-llama/Llama-4-Maverick-17B-128E-Instruct \\\n  --port 30000 \\\n  --tp 8 \\\n  --mem-fraction-static 0.8 \\\n  --context-length 65536\nlm_eval --model 
local-chat-completions --model_args model=meta-llama/Llama-4-Maverick-17B-128E-Instruct,base_url=http://localhost:30000/v1/chat/completions,num_concurrent=128,timeout=999999,max_gen_toks=2048 --tasks mmlu_pro --batch_size 128 --apply_chat_template --num_fewshot 0\n```\n\nDetails can be seen in [this PR](https://github.com/sgl-project/sglang/pull/5092).\n"
  },
  {
    "path": "docs/basic_usage/minimax_m2.md",
    "content": "# MiniMax M2.5/M2.1/M2 Usage\n\n[MiniMax-M2.5](https://huggingface.co/MiniMaxAI/MiniMax-M2.5), [MiniMax-M2.1](https://huggingface.co/MiniMaxAI/MiniMax-M2.1), and [MiniMax-M2](https://huggingface.co/MiniMaxAI/MiniMax-M2) are advanced large language models created by [MiniMax](https://www.minimax.io/).\n\nThe MiniMax-M2 series redefines efficiency for agents. These compact, fast, and cost-effective MoE models (230 billion total parameters with 10 billion active parameters) are built for elite performance in coding and agentic tasks, all while maintaining powerful general intelligence. With just 10 billion activated parameters, the MiniMax-M2 series provides sophisticated, end-to-end tool use performance expected from today's leading models, but in a streamlined form factor that makes deployment and scaling easier than ever.\n\n## Supported Models\n\nThis guide applies to the following models. You only need to update the model name during deployment. The following examples use **MiniMax-M2**:\n\n- [MiniMaxAI/MiniMax-M2.5](https://huggingface.co/MiniMaxAI/MiniMax-M2.5)\n- [MiniMaxAI/MiniMax-M2.1](https://huggingface.co/MiniMaxAI/MiniMax-M2.1)\n- [MiniMaxAI/MiniMax-M2](https://huggingface.co/MiniMaxAI/MiniMax-M2)\n\n## System Requirements\n\nThe following are recommended configurations; actual requirements should be adjusted based on your use case:\n\n- 4x 96GB GPUs: Supported context length of up to 400K tokens.\n- 8x 144GB GPUs: Supported context length of up to 3M tokens.\n\n## Deployment with Python\n\n4-GPU deployment command:\n\n```bash\npython -m sglang.launch_server \\\n    --model-path MiniMaxAI/MiniMax-M2 \\\n    --tp-size 4 \\\n    --tool-call-parser minimax-m2 \\\n    --reasoning-parser minimax-append-think \\\n    --host 0.0.0.0 \\\n    --trust-remote-code \\\n    --port 8000 \\\n    --mem-fraction-static 0.85\n```\n\n8-GPU deployment command:\n\n```bash\npython -m sglang.launch_server \\\n    --model-path MiniMaxAI/MiniMax-M2 \\\n    
--tp-size 8 \\\n    --ep-size 8 \\\n    --tool-call-parser minimax-m2 \\\n    --reasoning-parser minimax-append-think \\\n    --host 0.0.0.0 \\\n    --trust-remote-code \\\n    --port 8000 \\\n    --mem-fraction-static 0.85\n```\n\n### AMD GPUs (MI300X/MI325X/MI355X)\n\n8-GPU deployment command:\n\n```bash\nSGLANG_USE_AITER=1 python -m sglang.launch_server \\\n    --model-path MiniMaxAI/MiniMax-M2.5 \\\n    --tp-size 8 \\\n    --ep-size 8 \\\n    --attention-backend aiter \\\n    --tool-call-parser minimax-m2 \\\n    --reasoning-parser minimax-append-think \\\n    --host 0.0.0.0 \\\n    --trust-remote-code \\\n    --port 8000 \\\n    --mem-fraction-static 0.85\n```\n\n## Testing Deployment\n\nAfter startup, you can test the SGLang OpenAI-compatible API with the following command:\n\n```bash\ncurl http://localhost:8000/v1/chat/completions \\\n    -H \"Content-Type: application/json\" \\\n    -d '{\n        \"model\": \"MiniMaxAI/MiniMax-M2\",\n        \"messages\": [\n            {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": \"You are a helpful assistant.\"}]},\n            {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"Who won the world series in 2020?\"}]}\n        ]\n    }'\n```\n"
  },
  {
    "path": "docs/basic_usage/native_api.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# SGLang Native APIs\\n\",\n    \"\\n\",\n    \"Apart from the OpenAI compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce the following APIs:\\n\",\n    \"\\n\",\n    \"- `/generate` (text generation model)\\n\",\n    \"- `/get_model_info`\\n\",\n    \"- `/get_server_info`\\n\",\n    \"- `/health`\\n\",\n    \"- `/health_generate`\\n\",\n    \"- `/flush_cache`\\n\",\n    \"- `/update_weights`\\n\",\n    \"- `/encode`(embedding model)\\n\",\n    \"- `/v1/rerank`(cross encoder rerank model)\\n\",\n    \"- `/v1/score`(decoder-only scoring)\\n\",\n    \"- `/classify`(reward model)\\n\",\n    \"- `/start_expert_distribution_record`\\n\",\n    \"- `/stop_expert_distribution_record`\\n\",\n    \"- `/dump_expert_distribution_record`\\n\",\n    \"- `/tokenize`\\n\",\n    \"- `/detokenize`\\n\",\n    \"- A full list of these APIs can be found at [http_server.py](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py)\\n\",\n    \"\\n\",\n    \"We mainly use `requests` to test these APIs in the following examples. 
You can also use `curl`.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Launch A Server\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from sglang.test.doc_patch import launch_server_cmd\\n\",\n    \"from sglang.utils import wait_for_server, print_highlight, terminate_process\\n\",\n    \"\\n\",\n    \"server_process, port = launch_server_cmd(\\n\",\n    \"    \\\"python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct --host 0.0.0.0 --log-level warning\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Generate (text generation model)\\n\",\n    \"Generate completions. This is similar to the `/v1/completions` in OpenAI API. Detailed parameters can be found in the [sampling parameters](sampling_params.md).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import requests\\n\",\n    \"\\n\",\n    \"url = f\\\"http://localhost:{port}/generate\\\"\\n\",\n    \"data = {\\\"text\\\": \\\"What is the capital of France?\\\"}\\n\",\n    \"\\n\",\n    \"response = requests.post(url, json=data)\\n\",\n    \"print_highlight(response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Get Model Info\\n\",\n    \"\\n\",\n    \"Get the information of the model.\\n\",\n    \"\\n\",\n    \"- `model_path`: The path/name of the model.\\n\",\n    \"- `is_generation`: Whether the model is used as generation model or embedding model.\\n\",\n    \"- `tokenizer_path`: The path/name of the tokenizer.\\n\",\n    \"- `preferred_sampling_params`: The default sampling params specified via 
`--preferred-sampling-params`. `None` is returned in this example as we did not explicitly configure it in server args.\\n\",\n    \"- `weight_version`: This field contains the version of the model weights. This is often used to track changes or updates to the model’s trained parameters.\\n\",\n    \"- `has_image_understanding`: Whether the model has image-understanding capability.\\n\",\n    \"- `has_audio_understanding`: Whether the model has audio-understanding capability.\\n\",\n    \"- `model_type`: The model type from the HuggingFace config (e.g., \\\"qwen2\\\", \\\"llama\\\").\\n\",\n    \"- `architectures`: The model architectures from the HuggingFace config (e.g., [\\\"Qwen2ForCausalLM\\\"]).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"url = f\\\"http://localhost:{port}/get_model_info\\\"\\n\",\n    \"\\n\",\n    \"response = requests.get(url)\\n\",\n    \"response_json = response.json()\\n\",\n    \"print_highlight(response_json)\\n\",\n    \"assert response_json[\\\"model_path\\\"] == \\\"qwen/qwen2.5-0.5b-instruct\\\"\\n\",\n    \"assert response_json[\\\"is_generation\\\"] is True\\n\",\n    \"assert response_json[\\\"tokenizer_path\\\"] == \\\"qwen/qwen2.5-0.5b-instruct\\\"\\n\",\n    \"assert response_json[\\\"preferred_sampling_params\\\"] is None\\n\",\n    \"assert response_json.keys() == {\\n\",\n    \"    \\\"model_path\\\",\\n\",\n    \"    \\\"is_generation\\\",\\n\",\n    \"    \\\"tokenizer_path\\\",\\n\",\n    \"    \\\"preferred_sampling_params\\\",\\n\",\n    \"    \\\"weight_version\\\",\\n\",\n    \"    \\\"has_image_understanding\\\",\\n\",\n    \"    \\\"has_audio_understanding\\\",\\n\",\n    \"    \\\"model_type\\\",\\n\",\n    \"    \\\"architectures\\\",\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Get Server Info\\n\",\n    \"Gets the server information 
including CLI arguments, token limits, and memory pool sizes.\\n\",\n    \"- Note: `get_server_info` merges the following deprecated endpoints:\\n\",\n    \"  - `get_server_args`\\n\",\n    \"  - `get_memory_pool_size`\\n\",\n    \"  - `get_max_total_num_tokens`\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"url = f\\\"http://localhost:{port}/get_server_info\\\"\\n\",\n    \"\\n\",\n    \"response = requests.get(url)\\n\",\n    \"print_highlight(response.text)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Health Check\\n\",\n    \"- `/health`: Check the health of the server.\\n\",\n    \"- `/health_generate`: Check the health of the server by generating one token.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"url = f\\\"http://localhost:{port}/health_generate\\\"\\n\",\n    \"\\n\",\n    \"response = requests.get(url)\\n\",\n    \"print_highlight(response.text)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"url = f\\\"http://localhost:{port}/health\\\"\\n\",\n    \"\\n\",\n    \"response = requests.get(url)\\n\",\n    \"print_highlight(response.text)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Flush Cache\\n\",\n    \"\\n\",\n    \"Flush the radix cache. 
It will be automatically triggered when the model weights are updated by the `/update_weights` API.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"url = f\\\"http://localhost:{port}/flush_cache\\\"\\n\",\n    \"\\n\",\n    \"response = requests.post(url)\\n\",\n    \"print_highlight(response.text)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Update Weights From Disk\\n\",\n    \"\\n\",\n    \"Update model weights from disk without restarting the server. This only works for models with the same architecture and parameter size.\\n\",\n    \"\\n\",\n    \"SGLang supports the `update_weights_from_disk` API for continuous evaluation during training (save a checkpoint to disk, then update the weights from disk).\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# successful update with same architecture and size\\n\",\n    \"\\n\",\n    \"url = f\\\"http://localhost:{port}/update_weights_from_disk\\\"\\n\",\n    \"data = {\\\"model_path\\\": \\\"qwen/qwen2.5-0.5b-instruct\\\"}\\n\",\n    \"\\n\",\n    \"response = requests.post(url, json=data)\\n\",\n    \"print_highlight(response.text)\\n\",\n    \"assert response.json()[\\\"success\\\"] is True\\n\",\n    \"assert response.json()[\\\"message\\\"] == \\\"Succeeded to update model weights.\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# failed update with different parameter size or wrong name\\n\",\n    \"\\n\",\n    \"url = f\\\"http://localhost:{port}/update_weights_from_disk\\\"\\n\",\n    \"data = {\\\"model_path\\\": \\\"qwen/qwen2.5-0.5b-instruct-wrong\\\"}\\n\",\n    \"\\n\",\n    \"response = requests.post(url, json=data)\\n\",\n    \"response_json = 
response.json()\\n\",\n    \"print_highlight(response_json)\\n\",\n    \"assert response_json[\\\"success\\\"] is False\\n\",\n    \"assert response_json[\\\"message\\\"] == (\\n\",\n    \"    \\\"Failed to get weights iterator: \\\"\\n\",\n    \"    \\\"qwen/qwen2.5-0.5b-instruct-wrong\\\"\\n\",\n    \"    \\\" (repository not found).\\\"\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Encode (embedding model)\\n\",\n    \"\\n\",\n    \"Encode text into embeddings. Note that this API is only available for [embedding models](openai_api_embeddings.ipynb) and will raise an error for generation models.\\n\",\n    \"Therefore, we launch a new server to server an embedding model.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"embedding_process, port = launch_server_cmd(\\\"\\\"\\\"\\n\",\n    \"python3 -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct \\\\\\n\",\n    \"    --host 0.0.0.0 --is-embedding --log-level warning\\n\",\n    \"\\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=embedding_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# successful encode for embedding model\\n\",\n    \"\\n\",\n    \"url = f\\\"http://localhost:{port}/encode\\\"\\n\",\n    \"data = {\\\"model\\\": \\\"Alibaba-NLP/gte-Qwen2-1.5B-instruct\\\", \\\"text\\\": \\\"Once upon a time\\\"}\\n\",\n    \"\\n\",\n    \"response = requests.post(url, json=data)\\n\",\n    \"response_json = response.json()\\n\",\n    \"print_highlight(f\\\"Text embedding (first 10): 
{response_json['embedding'][:10]}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(embedding_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## v1/rerank (cross encoder rerank model)\\n\",\n    \"Rerank a list of documents given a query using a cross-encoder model. Note that this API is only available for cross encoder model like [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) with `attention-backend` `triton` and `torch_native`.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"reranker_process, port = launch_server_cmd(\\\"\\\"\\\"\\n\",\n    \"python3 -m sglang.launch_server --model-path BAAI/bge-reranker-v2-m3 \\\\\\n\",\n    \"    --host 0.0.0.0 --disable-radix-cache --chunked-prefill-size -1 --attention-backend triton --is-embedding --log-level warning\\n\",\n    \"\\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=reranker_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# compute rerank scores for query and documents\\n\",\n    \"\\n\",\n    \"url = f\\\"http://localhost:{port}/v1/rerank\\\"\\n\",\n    \"data = {\\n\",\n    \"    \\\"model\\\": \\\"BAAI/bge-reranker-v2-m3\\\",\\n\",\n    \"    \\\"query\\\": \\\"what is panda?\\\",\\n\",\n    \"    \\\"documents\\\": [\\n\",\n    \"        \\\"hi\\\",\\n\",\n    \"        \\\"The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.\\\",\\n\",\n    \"    ],\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"response = requests.post(url, json=data)\\n\",\n    \"response_json = 
response.json()\\n\",\n    \"for item in response_json:\\n\",\n    \"    print_highlight(f\\\"Score: {item['score']:.2f} - Document: '{item['document']}'\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(reranker_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## v1/score (decoder-only scoring)\\n\",\n    \"\\n\",\n    \"Compute token probabilities for specified tokens given a query and items. This is useful for classification tasks, scoring responses, or computing log-probabilities.\\n\",\n    \"\\n\",\n    \"Parameters:\\n\",\n    \"- `query`: Query text\\n\",\n    \"- `items`: Item text(s) to score\\n\",\n    \"- `label_token_ids`: Token IDs to compute probabilities for\\n\",\n    \"- `apply_softmax`: Whether to apply softmax to get normalized probabilities (default: False)\\n\",\n    \"- `item_first`: Whether items come first in concatenation order (default: False)\\n\",\n    \"- `model`: Model name\\n\",\n    \"\\n\",\n    \"The response contains `scores` - a list of probability lists, one per item, each in the order of `label_token_ids`.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"score_process, port = launch_server_cmd(\\\"\\\"\\\"\\n\",\n    \"python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct \\\\\\n\",\n    \"    --host 0.0.0.0 --log-level warning\\n\",\n    \"\\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=score_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Score the probability of different completions given a query\\n\",\n    \"query = \\\"The capital of France is\\\"\\n\",\n    \"items = 
[\\\"Paris\\\", \\\"London\\\", \\\"Berlin\\\"]\\n\",\n    \"\\n\",\n    \"url = f\\\"http://localhost:{port}/v1/score\\\"\\n\",\n    \"data = {\\n\",\n    \"    \\\"model\\\": \\\"qwen/qwen2.5-0.5b-instruct\\\",\\n\",\n    \"    \\\"query\\\": query,\\n\",\n    \"    \\\"items\\\": items,\\n\",\n    \"    \\\"label_token_ids\\\": [9454, 2753],  # e.g. \\\"Yes\\\" and \\\"No\\\" token ids\\n\",\n    \"    \\\"apply_softmax\\\": True,  # Normalize probabilities to sum to 1\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"response = requests.post(url, json=data)\\n\",\n    \"response_json = response.json()\\n\",\n    \"\\n\",\n    \"# Display scores for each item\\n\",\n    \"for item, scores in zip(items, response_json[\\\"scores\\\"]):\\n\",\n    \"    print_highlight(f\\\"Item '{item}': probabilities = {[f'{s:.4f}' for s in scores]}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(score_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Classify (reward model)\\n\",\n    \"\\n\",\n    \"SGLang Runtime also supports reward models. 
Here we use a reward model to classify the quality of pairwise generations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Note that SGLang now treats embedding models and reward models as the same type of models.\\n\",\n    \"# This will be updated in the future.\\n\",\n    \"\\n\",\n    \"reward_process, port = launch_server_cmd(\\\"\\\"\\\"\\n\",\n    \"python3 -m sglang.launch_server --model-path Skywork/Skywork-Reward-Llama-3.1-8B-v0.2 --host 0.0.0.0 --is-embedding --log-level warning\\n\",\n    \"\\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=reward_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer\\n\",\n    \"\\n\",\n    \"PROMPT = (\\n\",\n    \"    \\\"What is the range of the numeric output of a sigmoid node in a neural network?\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"RESPONSE1 = \\\"The output of a sigmoid node is bounded between -1 and 1.\\\"\\n\",\n    \"RESPONSE2 = \\\"The output of a sigmoid node is bounded between 0 and 1.\\\"\\n\",\n    \"\\n\",\n    \"CONVS = [\\n\",\n    \"    [{\\\"role\\\": \\\"user\\\", \\\"content\\\": PROMPT}, {\\\"role\\\": \\\"assistant\\\", \\\"content\\\": RESPONSE1}],\\n\",\n    \"    [{\\\"role\\\": \\\"user\\\", \\\"content\\\": PROMPT}, {\\\"role\\\": \\\"assistant\\\", \\\"content\\\": RESPONSE2}],\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\\\")\\n\",\n    \"prompts = tokenizer.apply_chat_template(CONVS, tokenize=False, return_dict=False)\\n\",\n    \"\\n\",\n    \"url = f\\\"http://localhost:{port}/classify\\\"\\n\",\n    \"data = {\\\"model\\\": \\\"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\\\", \\\"text\\\": prompts}\\n\",\n    
\"\\n\",\n    \"responses = requests.post(url, json=data).json()\\n\",\n    \"for response in responses:\\n\",\n    \"    print_highlight(f\\\"reward: {response['embedding'][0]}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(reward_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Capture expert selection distribution in MoE models\\n\",\n    \"\\n\",\n    \"SGLang Runtime supports recording the number of times an expert is selected in a MoE model run for each expert in the model. This is useful when analyzing the throughput of the model and plan for optimization.\\n\",\n    \"\\n\",\n    \"*Note: We only print out the first 10 lines of the csv below for better readability. Please adjust accordingly if you want to analyze the results more deeply.*\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"expert_record_server_process, port = launch_server_cmd(\\n\",\n    \"    \\\"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0 --expert-distribution-recorder-mode stat --log-level warning\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=expert_record_server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response = requests.post(f\\\"http://localhost:{port}/start_expert_distribution_record\\\")\\n\",\n    \"print_highlight(response)\\n\",\n    \"\\n\",\n    \"url = f\\\"http://localhost:{port}/generate\\\"\\n\",\n    \"data = {\\\"text\\\": \\\"What is the capital of France?\\\"}\\n\",\n    \"\\n\",\n    \"response = requests.post(url, json=data)\\n\",\n    \"print_highlight(response.json())\\n\",\n    
\"\\n\",\n    \"response = requests.post(f\\\"http://localhost:{port}/stop_expert_distribution_record\\\")\\n\",\n    \"print_highlight(response)\\n\",\n    \"\\n\",\n    \"response = requests.post(f\\\"http://localhost:{port}/dump_expert_distribution_record\\\")\\n\",\n    \"print_highlight(response)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(expert_record_server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Tokenize/Detokenize Example (Round Trip)\\n\",\n    \"\\n\",\n    \"This example demonstrates how to use the /tokenize and /detokenize endpoints together. We first tokenize a string, then detokenize the resulting IDs to reconstruct the original text. This workflow is useful when you need to handle tokenization externally but still leverage the server for detokenization.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer_free_server_process, port = launch_server_cmd(\\\"\\\"\\\"\\n\",\n    \"python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct\\n\",\n    \"\\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=tokenizer_free_server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import requests\\n\",\n    \"from sglang.utils import print_highlight\\n\",\n    \"\\n\",\n    \"base_url = f\\\"http://localhost:{port}\\\"\\n\",\n    \"tokenize_url = f\\\"{base_url}/tokenize\\\"\\n\",\n    \"detokenize_url = f\\\"{base_url}/detokenize\\\"\\n\",\n    \"\\n\",\n    \"model_name = \\\"qwen/qwen2.5-0.5b-instruct\\\"\\n\",\n    \"input_text = \\\"SGLang provides efficient tokenization endpoints.\\\"\\n\",\n    
\"print_highlight(f\\\"Original Input Text:\\\\n'{input_text}'\\\")\\n\",\n    \"\\n\",\n    \"# --- tokenize the input text ---\\n\",\n    \"tokenize_payload = {\\n\",\n    \"    \\\"model\\\": model_name,\\n\",\n    \"    \\\"prompt\\\": input_text,\\n\",\n    \"    \\\"add_special_tokens\\\": False,\\n\",\n    \"}\\n\",\n    \"try:\\n\",\n    \"    tokenize_response = requests.post(tokenize_url, json=tokenize_payload)\\n\",\n    \"    tokenize_response.raise_for_status()\\n\",\n    \"    tokenization_result = tokenize_response.json()\\n\",\n    \"    token_ids = tokenization_result.get(\\\"tokens\\\")\\n\",\n    \"\\n\",\n    \"    if not token_ids:\\n\",\n    \"        raise ValueError(\\\"Tokenization returned empty tokens.\\\")\\n\",\n    \"\\n\",\n    \"    print_highlight(f\\\"\\\\nTokenized Output (IDs):\\\\n{token_ids}\\\")\\n\",\n    \"    print_highlight(f\\\"Token Count: {tokenization_result.get('count')}\\\")\\n\",\n    \"    print_highlight(f\\\"Max Model Length: {tokenization_result.get('max_model_len')}\\\")\\n\",\n    \"\\n\",\n    \"    # --- detokenize the obtained token IDs ---\\n\",\n    \"    detokenize_payload = {\\n\",\n    \"        \\\"model\\\": model_name,\\n\",\n    \"        \\\"tokens\\\": token_ids,\\n\",\n    \"        \\\"skip_special_tokens\\\": True,\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    detokenize_response = requests.post(detokenize_url, json=detokenize_payload)\\n\",\n    \"    detokenize_response.raise_for_status()\\n\",\n    \"    detokenization_result = detokenize_response.json()\\n\",\n    \"    reconstructed_text = detokenization_result.get(\\\"text\\\")\\n\",\n    \"\\n\",\n    \"    print_highlight(f\\\"\\\\nDetokenized Output (Text):\\\\n'{reconstructed_text}'\\\")\\n\",\n    \"\\n\",\n    \"    if input_text == reconstructed_text:\\n\",\n    \"        print_highlight(\\n\",\n    \"            \\\"\\\\nRound Trip Successful: Original and reconstructed text match.\\\"\\n\",\n    \"        )\\n\",\n    \"    
else:\\n\",\n    \"        print_highlight(\\n\",\n    \"            \\\"\\\\nRound Trip Mismatch: Original and reconstructed text differ.\\\"\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"except requests.exceptions.RequestException as e:\\n\",\n    \"    print_highlight(f\\\"\\\\nHTTP Request Error: {e}\\\")\\n\",\n    \"except Exception as e:\\n\",\n    \"    print_highlight(f\\\"\\\\nAn error occurred: {e}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(tokenizer_free_server_process)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "docs/basic_usage/offline_engine_api.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Offline Engine API\\n\",\n    \"\\n\",\n    \"SGLang provides a direct inference engine without the need for an HTTP server, especially for use cases where an additional HTTP server would add unnecessary complexity or overhead. Here are two general use cases:\\n\",\n    \"\\n\",\n    \"- Offline Batch Inference\\n\",\n    \"- Custom Server on Top of the Engine\\n\",\n    \"\\n\",\n    \"This document focuses on offline batch inference, demonstrating four different inference modes:\\n\",\n    \"\\n\",\n    \"- Non-streaming synchronous generation\\n\",\n    \"- Streaming synchronous generation\\n\",\n    \"- Non-streaming asynchronous generation\\n\",\n    \"- Streaming asynchronous generation\\n\",\n    \"\\n\",\n    \"Additionally, you can easily build a custom server on top of the SGLang offline engine. A detailed, working example in a Python script can be found in [custom_server](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/custom_server.py).\\n\",\n    \"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Nest Asyncio\\n\",\n    \"Note that if you want to use the **Offline Engine** in IPython or other code that already runs an event loop, you need to add the following code:\\n\",\n    \"```python\\n\",\n    \"import nest_asyncio\\n\",\n    \"\\n\",\n    \"nest_asyncio.apply()\\n\",\n    \"\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Advanced Usage\\n\",\n    \"\\n\",\n    \"The engine supports [VLM inference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/offline_batch_inference_vlm.py) as well as [extracting hidden states](https://github.com/sgl-project/sglang/blob/main/examples/runtime/hidden_states). 
\\n\",\n    \"\\n\",\n    \"Please see [the examples](https://github.com/sgl-project/sglang/tree/main/examples/runtime/engine) for further use cases.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Offline Batch Inference\\n\",\n    \"\\n\",\n    \"SGLang offline engine supports batch inference with efficient scheduling.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# launch the offline engine\\n\",\n    \"import asyncio\\n\",\n    \"\\n\",\n    \"import sglang as sgl\\n\",\n    \"import sglang.test.doc_patch  # noqa: F401\\n\",\n    \"from sglang.utils import async_stream_and_merge, stream_and_merge\\n\",\n    \"\\n\",\n    \"llm = sgl.Engine(model_path=\\\"qwen/qwen2.5-0.5b-instruct\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Non-streaming Synchronous Generation\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prompts = [\\n\",\n    \"    \\\"Hello, my name is\\\",\\n\",\n    \"    \\\"The president of the United States is\\\",\\n\",\n    \"    \\\"The capital of France is\\\",\\n\",\n    \"    \\\"The future of AI is\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"sampling_params = {\\\"temperature\\\": 0.8, \\\"top_p\\\": 0.95}\\n\",\n    \"\\n\",\n    \"outputs = llm.generate(prompts, sampling_params)\\n\",\n    \"for prompt, output in zip(prompts, outputs):\\n\",\n    \"    print(\\\"===============================\\\")\\n\",\n    \"    print(f\\\"Prompt: {prompt}\\\\nGenerated text: {output['text']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Streaming Synchronous Generation\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"prompts = [\\n\",\n    \"    \\\"Write a short, neutral self-introduction for a fictional character. Hello, my name is\\\",\\n\",\n    \"    \\\"Provide a concise factual statement about France’s capital city. The capital of France is\\\",\\n\",\n    \"    \\\"Explain possible future trends in artificial intelligence. The future of AI is\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"sampling_params = {\\n\",\n    \"    \\\"temperature\\\": 0.2,\\n\",\n    \"    \\\"top_p\\\": 0.9,\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"print(\\\"\\\\n=== Testing synchronous streaming generation with overlap removal ===\\\\n\\\")\\n\",\n    \"\\n\",\n    \"for prompt in prompts:\\n\",\n    \"    print(f\\\"Prompt: {prompt}\\\")\\n\",\n    \"    merged_output = stream_and_merge(llm, prompt, sampling_params)\\n\",\n    \"    print(\\\"Generated text:\\\", merged_output)\\n\",\n    \"    print()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Non-streaming Asynchronous Generation\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prompts = [\\n\",\n    \"    \\\"Write a short, neutral self-introduction for a fictional character. Hello, my name is\\\",\\n\",\n    \"    \\\"Provide a concise factual statement about France’s capital city. The capital of France is\\\",\\n\",\n    \"    \\\"Explain possible future trends in artificial intelligence. 
The future of AI is\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"sampling_params = {\\\"temperature\\\": 0.8, \\\"top_p\\\": 0.95}\\n\",\n    \"\\n\",\n    \"print(\\\"\\\\n=== Testing asynchronous batch generation ===\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"async def main():\\n\",\n    \"    outputs = await llm.async_generate(prompts, sampling_params)\\n\",\n    \"\\n\",\n    \"    for prompt, output in zip(prompts, outputs):\\n\",\n    \"        print(f\\\"\\\\nPrompt: {prompt}\\\")\\n\",\n    \"        print(f\\\"Generated text: {output['text']}\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"asyncio.run(main())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Streaming Asynchronous Generation\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prompts = [\\n\",\n    \"    \\\"Write a short, neutral self-introduction for a fictional character. Hello, my name is\\\",\\n\",\n    \"    \\\"Provide a concise factual statement about France’s capital city. The capital of France is\\\",\\n\",\n    \"    \\\"Explain possible future trends in artificial intelligence. 
The future of AI is\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"sampling_params = {\\\"temperature\\\": 0.8, \\\"top_p\\\": 0.95}\\n\",\n    \"\\n\",\n    \"print(\\\"\\\\n=== Testing asynchronous streaming generation (no repeats) ===\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"async def main():\\n\",\n    \"    for prompt in prompts:\\n\",\n    \"        print(f\\\"\\\\nPrompt: {prompt}\\\")\\n\",\n    \"        print(\\\"Generated text: \\\", end=\\\"\\\", flush=True)\\n\",\n    \"\\n\",\n    \"        # Replace direct calls to async_generate with our custom overlap-aware version\\n\",\n    \"        async for cleaned_chunk in async_stream_and_merge(llm, prompt, sampling_params):\\n\",\n    \"            print(cleaned_chunk, end=\\\"\\\", flush=True)\\n\",\n    \"\\n\",\n    \"        print()  # New line after each prompt\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"asyncio.run(main())\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"llm.shutdown()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "docs/basic_usage/ollama_api.md",
    "content": "# Ollama-Compatible API\n\nSGLang provides Ollama API compatibility, allowing you to use the Ollama CLI and Python library with SGLang as the inference backend.\n\n## Prerequisites\n\n```bash\n# Install the Ollama Python library (for Python client usage)\npip install ollama\n```\n\n> **Note**: You don't need the Ollama server installed - SGLang acts as the backend. You only need the `ollama` CLI or Python library as the client.\n\n## Endpoints\n\n| Endpoint | Method | Description |\n|----------|--------|-------------|\n| `/` | GET, HEAD | Health check for Ollama CLI |\n| `/api/tags` | GET | List available models |\n| `/api/chat` | POST | Chat completions (streaming & non-streaming) |\n| `/api/generate` | POST | Text generation (streaming & non-streaming) |\n| `/api/show` | POST | Model information |\n\n## Quick Start\n\n### 1. Launch SGLang Server\n\n```bash\npython -m sglang.launch_server \\\n    --model Qwen/Qwen2.5-1.5B-Instruct \\\n    --port 30001 \\\n    --host 0.0.0.0\n```\n\n> **Note**: The model name used with `ollama run` must match exactly what you passed to `--model`.\n\n### 2. Use Ollama CLI\n\n```bash\n# List available models\nOLLAMA_HOST=http://localhost:30001 ollama list\n\n# Interactive chat\nOLLAMA_HOST=http://localhost:30001 ollama run \"Qwen/Qwen2.5-1.5B-Instruct\"\n```\n\nIf connecting to a remote server behind a firewall:\n\n```bash\n# SSH tunnel\nssh -L 30001:localhost:30001 user@gpu-server -N &\n\n# Then use Ollama CLI as above\nOLLAMA_HOST=http://localhost:30001 ollama list\n```\n\n### 3. 
Use Ollama Python Library\n\n```python\nimport ollama\n\nclient = ollama.Client(host='http://localhost:30001')\n\n# Non-streaming\nresponse = client.chat(\n    model='Qwen/Qwen2.5-1.5B-Instruct',\n    messages=[{'role': 'user', 'content': 'Hello!'}]\n)\nprint(response['message']['content'])\n\n# Streaming\nstream = client.chat(\n    model='Qwen/Qwen2.5-1.5B-Instruct',\n    messages=[{'role': 'user', 'content': 'Tell me a story'}],\n    stream=True\n)\nfor chunk in stream:\n    print(chunk['message']['content'], end='', flush=True)\n```\n\n## Smart Router\n\nFor intelligent routing between local Ollama (fast) and remote SGLang (powerful) using an LLM judge, see the [Smart Router documentation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/ollama/README.md).\n\n## Summary\n\n| Component | Purpose |\n|-----------|---------|\n| **Ollama API** | Familiar CLI/API that developers already know |\n| **SGLang Backend** | High-performance inference engine |\n| **Smart Router** | Intelligent routing - fast local for simple tasks, powerful remote for complex tasks |\n"
  },
  {
    "path": "docs/basic_usage/openai_api.rst",
    "content": "OpenAI-Compatible APIs\n======================\n\n.. toctree::\n   :maxdepth: 1\n\n   openai_api_completions.ipynb\n   openai_api_vision.ipynb\n   openai_api_embeddings.ipynb\n"
  },
  {
    "path": "docs/basic_usage/openai_api_completions.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# OpenAI APIs - Completions\\n\",\n    \"\\n\",\n    \"SGLang provides OpenAI-compatible APIs to enable a smooth transition from OpenAI services to self-hosted local models.\\n\",\n    \"A complete reference for the API is available in the [OpenAI API Reference](https://platform.openai.com/docs/api-reference).\\n\",\n    \"\\n\",\n    \"This tutorial covers the following popular APIs:\\n\",\n    \"\\n\",\n    \"- `chat/completions`\\n\",\n    \"- `completions`\\n\",\n    \"\\n\",\n    \"Check out other tutorials to learn about [vision APIs](openai_api_vision.ipynb) for vision-language models and [embedding APIs](openai_api_embeddings.ipynb) for embedding models.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Launch A Server\\n\",\n    \"\\n\",\n    \"Launch the server in your terminal and wait for it to initialize.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from sglang.test.doc_patch import launch_server_cmd\\n\",\n    \"from sglang.utils import wait_for_server, print_highlight, terminate_process\\n\",\n    \"\\n\",\n    \"server_process, port = launch_server_cmd(\\n\",\n    \"    \\\"python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct --host 0.0.0.0 --log-level warning\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=server_process)\\n\",\n    \"print(f\\\"Server started on http://localhost:{port}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Chat Completions\\n\",\n    \"\\n\",\n    \"### Usage\\n\",\n    \"\\n\",\n    \"The server fully implements the OpenAI API.\\n\",\n    \"It will automatically apply the chat template specified in the Hugging Face 
tokenizer, if one is available.\\n\",\n    \"You can also specify a custom chat template with `--chat-template` when launching the server.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import openai\\n\",\n    \"\\n\",\n    \"client = openai.Client(base_url=f\\\"http://127.0.0.1:{port}/v1\\\", api_key=\\\"None\\\")\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"qwen/qwen2.5-0.5b-instruct\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"List 3 countries and their capitals.\\\"},\\n\",\n    \"    ],\\n\",\n    \"    temperature=0,\\n\",\n    \"    max_tokens=64,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(f\\\"Response: {response}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Model Thinking/Reasoning Support\\n\",\n    \"\\n\",\n    \"Some models support internal reasoning or thinking processes that can be exposed in the API response. 
SGLang provides unified support for various reasoning models through the `chat_template_kwargs` parameter and compatible reasoning parsers.\\n\",\n    \"\\n\",\n    \"#### Supported Models and Configuration\\n\",\n    \"\\n\",\n    \"| Model Family | Chat Template Parameter | Reasoning Parser | Notes |\\n\",\n    \"|--------------|------------------------|------------------|--------|\\n\",\n    \"| DeepSeek-R1 (R1, R1-0528, R1-Distill) | `enable_thinking` | `--reasoning-parser deepseek-r1` | Standard reasoning models |\\n\",\n    \"| DeepSeek-V3.1 | `thinking` | `--reasoning-parser deepseek-v3` | Hybrid model (thinking/non-thinking modes) |\\n\",\n    \"| Qwen3 (standard) | `enable_thinking` | `--reasoning-parser qwen3` | Hybrid model (thinking/non-thinking modes) |\\n\",\n    \"| Qwen3-Thinking | N/A (always enabled) | `--reasoning-parser qwen3-thinking` | Always generates reasoning |\\n\",\n    \"| Kimi | N/A (always enabled) | `--reasoning-parser kimi` | Kimi thinking models |\\n\",\n    \"| Gpt-Oss | N/A (always enabled) | `--reasoning-parser gpt-oss` | Gpt-Oss thinking models |\\n\",\n    \"\\n\",\n    \"#### Basic Usage\\n\",\n    \"\\n\",\n    \"To enable reasoning output, you need to:\\n\",\n    \"1. Launch the server with the appropriate reasoning parser\\n\",\n    \"2. Set the model-specific parameter in `chat_template_kwargs`\\n\",\n    \"3. Optionally set `separate_reasoning: False` to leave the reasoning in `content` instead of returning it separately in `reasoning_content` (defaults to `True`)\\n\",\n    \"\\n\",\n    \"**Note for Qwen3-Thinking models:** These models always generate thinking content and do not support the `enable_thinking` parameter.
Use `--reasoning-parser qwen3-thinking` or `--reasoning-parser qwen3` to parse the thinking content.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Example: Qwen3 Models\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"# Launch server:\\n\",\n    \"# python3 -m sglang.launch_server --model Qwen/Qwen3-4B --reasoning-parser qwen3\\n\",\n    \"\\n\",\n    \"from openai import OpenAI\\n\",\n    \"\\n\",\n    \"client = OpenAI(\\n\",\n    \"    api_key=\\\"EMPTY\\\",\\n\",\n    \"    base_url=f\\\"http://127.0.0.1:30000/v1\\\",\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"model = \\\"Qwen/Qwen3-4B\\\"\\n\",\n    \"messages = [{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"How many r's are in 'strawberry'?\\\"}]\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=model,\\n\",\n    \"    messages=messages,\\n\",\n    \"    extra_body={\\n\",\n    \"        \\\"chat_template_kwargs\\\": {\\\"enable_thinking\\\": True},\\n\",\n    \"        \\\"separate_reasoning\\\": True\\n\",\n    \"    }\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print(\\\"Reasoning:\\\", response.choices[0].message.reasoning_content)\\n\",\n    \"print(\\\"-\\\"*100)\\n\",\n    \"print(\\\"Answer:\\\", response.choices[0].message.content)\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"**Example Output:**\\n\",\n    \"```\\n\",\n    \"Reasoning: Okay, so the user is asking how many 'r's are in the word 'strawberry'. Let me think. First, I need to make sure I have the word spelled correctly. Strawberry... S-T-R-A-W-B-E-R-R-Y. Wait, is that right? Let me break it down.\\n\",\n    \"\\n\",\n    \"Starting with 'strawberry', let's write out the letters one by one. S, T, R, A, W, B, E, R, R, Y. Hmm, wait, that's 10 letters. Let me check again. S (1), T (2), R (3), A (4), W (5), B (6), E (7), R (8), R (9), Y (10). So the letters are S-T-R-A-W-B-E-R-R-Y.
\\n\",\n    \"...\\n\",\n    \"Therefore, the answer should be three R's in 'strawberry'. But I need to make sure I'm not counting any other letters as R. Let me check again. S, T, R, A, W, B, E, R, R, Y. No other R's. So three in total. Yeah, that seems right.\\n\",\n    \"\\n\",\n    \"----------------------------------------------------------------------------------------------------\\n\",\n    \"Answer: The word \\\"strawberry\\\" contains **three** letters 'r'. Here's the breakdown:\\n\",\n    \"\\n\",\n    \"1. **S-T-R-A-W-B-E-R-R-Y**  \\n\",\n    \"   - The **third letter** is 'R'.  \\n\",\n    \"   - The **eighth and ninth letters** are also 'R's.  \\n\",\n    \"\\n\",\n    \"Thus, the total count is **3**.  \\n\",\n    \"\\n\",\n    \"**Answer:** 3.\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"**Note:** Setting `\\\"enable_thinking\\\": False` (or omitting it) will result in `reasoning_content` being `None`. Qwen3-Thinking models always generate reasoning content and don't support the `enable_thinking` parameter.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Logit Bias Support\\n\",\n    \"\\n\",\n    \"SGLang supports the `logit_bias` parameter for both chat completions and completions APIs. This parameter allows you to modify the likelihood of specific tokens being generated by adding bias values to their logits. 
The bias values can range from -100 to 100, where:\\n\",\n    \"\\n\",\n    \"- **Positive values** (0 to 100) increase the likelihood of the token being selected\\n\",\n    \"- **Negative values** (-100 to 0) decrease the likelihood of the token being selected\\n\",\n    \"- **-100** effectively prevents the token from being generated\\n\",\n    \"\\n\",\n    \"The `logit_bias` parameter accepts a dictionary where keys are token IDs (as strings) and values are the bias amounts (as floats).\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Getting Token IDs\\n\",\n    \"\\n\",\n    \"To use `logit_bias` effectively, you need the token IDs of the words you want to bias. Token IDs are specific to each model's tokenizer, so look them up with the tokenizer of the model you are serving:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"# Load the tokenizer of the served model to find token IDs\\n\",\n    \"from transformers import AutoTokenizer\\n\",\n    \"\\n\",\n    \"# Use the same model path you launched the server with\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"qwen/qwen2.5-0.5b-instruct\\\")\\n\",\n    \"\\n\",\n    \"# Get token IDs for specific words; a word may map to several tokens, and\\n\",\n    \"# a leading space (e.g. \\\" sunny\\\") usually produces different IDs\\n\",\n    \"word = \\\"sunny\\\"\\n\",\n    \"token_ids = tokenizer.encode(word, add_special_tokens=False)\\n\",\n    \"print(f\\\"Token IDs for '{word}': {token_ids}\\\")\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"**Important:** The `logit_bias` parameter uses token IDs as string keys, not the actual words.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Example: DeepSeek-V3 Models\\n\",\n    \"\\n\",\n    \"DeepSeek-V3 models support thinking mode through the `thinking` parameter:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"# Launch server:\\n\",\n    \"# python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.1 --tp 8 --reasoning-parser deepseek-v3\\n\",\n    \"\\n\",\n    \"from
openai import OpenAI\\n\",\n    \"\\n\",\n    \"client = OpenAI(\\n\",\n    \"    api_key=\\\"EMPTY\\\",\\n\",\n    \"    base_url=f\\\"http://127.0.0.1:30000/v1\\\",\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"model = \\\"deepseek-ai/DeepSeek-V3.1\\\"\\n\",\n    \"messages = [{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"How many r's are in 'strawberry'?\\\"}]\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=model,\\n\",\n    \"    messages=messages,\\n\",\n    \"    extra_body={\\n\",\n    \"        \\\"chat_template_kwargs\\\": {\\\"thinking\\\": True},\\n\",\n    \"        \\\"separate_reasoning\\\": True\\n\",\n    \"    }\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print(\\\"Reasoning:\\\", response.choices[0].message.reasoning_content)\\n\",\n    \"print(\\\"-\\\"*100)\\n\",\n    \"print(\\\"Answer:\\\", response.choices[0].message.content)\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"**Example Output:**\\n\",\n    \"```\\n\",\n    \"Reasoning: First, the question is: \\\"How many r's are in 'strawberry'?\\\"\\n\",\n    \"\\n\",\n    \"I need to count the number of times the letter 'r' appears in the word \\\"strawberry\\\".\\n\",\n    \"\\n\",\n    \"Let me write out the word: S-T-R-A-W-B-E-R-R-Y.\\n\",\n    \"\\n\",\n    \"Now, I'll go through each letter and count the 'r's.\\n\",\n    \"...\\n\",\n    \"So, I have three 'r's in \\\"strawberry\\\".\\n\",\n    \"\\n\",\n    \"I should double-check. The word is spelled S-T-R-A-W-B-E-R-R-Y. The letters are at positions: 3, 8, and 9 are 'r's. Yes, that's correct.\\n\",\n    \"\\n\",\n    \"Therefore, the answer should be 3.\\n\",\n    \"----------------------------------------------------------------------------------------------------\\n\",\n    \"Answer: The word \\\"strawberry\\\" contains **3** instances of the letter \\\"r\\\". 
Here's a breakdown for clarity:\\n\",\n    \"\\n\",\n    \"- The word is spelled: S-T-R-A-W-B-E-R-R-Y\\n\",\n    \"- The \\\"r\\\" appears at the 3rd, 8th, and 9th positions.\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"**Note:** DeepSeek-V3 models use the `thinking` parameter (not `enable_thinking`) to control reasoning output.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Example with logit_bias parameter\\n\",\n    \"# Note: You need to get the actual token IDs from your tokenizer\\n\",\n    \"# For demonstration, we'll use some example token IDs\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"qwen/qwen2.5-0.5b-instruct\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"Complete this sentence: The weather today is\\\"}\\n\",\n    \"    ],\\n\",\n    \"    temperature=0.7,\\n\",\n    \"    max_tokens=20,\\n\",\n    \"    logit_bias={\\n\",\n    \"        \\\"12345\\\": 50,  # Increase likelihood of token ID 12345\\n\",\n    \"        \\\"67890\\\": -50,  # Decrease likelihood of token ID 67890\\n\",\n    \"        \\\"11111\\\": 25,  # Slightly increase likelihood of token ID 11111\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(f\\\"Response with logit bias: {response.choices[0].message.content}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Parameters\\n\",\n    \"\\n\",\n    \"The chat completions API accepts OpenAI Chat Completions API's parameters. Refer to [OpenAI Chat Completions API](https://platform.openai.com/docs/api-reference/chat/create) for more details.\\n\",\n    \"\\n\",\n    \"SGLang extends the standard API with the `extra_body` parameter, allowing for additional customization. 
One key option within `extra_body` is `chat_template_kwargs`, which can be used to pass arguments to the chat template processor.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"qwen/qwen2.5-0.5b-instruct\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"system\\\",\\n\",\n    \"            \\\"content\\\": \\\"You are a knowledgeable historian who provides concise responses.\\\",\\n\",\n    \"        },\\n\",\n    \"        {\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"Tell me about ancient Rome\\\"},\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"assistant\\\",\\n\",\n    \"            \\\"content\\\": \\\"Ancient Rome was a civilization centered in Italy.\\\",\\n\",\n    \"        },\\n\",\n    \"        {\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"What were their major achievements?\\\"},\\n\",\n    \"    ],\\n\",\n    \"    temperature=0.3,  # Lower temperature for more focused responses\\n\",\n    \"    max_tokens=128,  # Reasonable length for a concise response\\n\",\n    \"    top_p=0.95,  # Slightly higher for better fluency\\n\",\n    \"    presence_penalty=0.2,  # Mild penalty to avoid repetition\\n\",\n    \"    frequency_penalty=0.2,  # Mild penalty for more natural language\\n\",\n    \"    n=1,  # Single response is usually more stable\\n\",\n    \"    seed=42,  # Keep for reproducibility\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(response.choices[0].message.content)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Streaming mode is also supported.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Logit Bias Support\\n\",\n    \"\\n\",\n    \"The completions API also supports the `logit_bias` 
parameter with the same functionality as described in the chat completions section above.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"stream = client.chat.completions.create(\\n\",\n    \"    model=\\\"qwen/qwen2.5-0.5b-instruct\\\",\\n\",\n    \"    messages=[{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"Say this is a test\\\"}],\\n\",\n    \"    stream=True,\\n\",\n    \")\\n\",\n    \"for chunk in stream:\\n\",\n    \"    if chunk.choices[0].delta.content is not None:\\n\",\n    \"        print(chunk.choices[0].delta.content, end=\\\"\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Returning Routed Experts (MoE Models)\\n\",\n    \"\\n\",\n    \"For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Example with logit_bias parameter for completions API\\n\",\n    \"# Note: You need to get the actual token IDs from your tokenizer\\n\",\n    \"# For demonstration, we'll use some example token IDs\\n\",\n    \"response = client.completions.create(\\n\",\n    \"    model=\\\"qwen/qwen2.5-0.5b-instruct\\\",\\n\",\n    \"    prompt=\\\"The best programming language for AI is\\\",\\n\",\n    \"    temperature=0.7,\\n\",\n    \"    max_tokens=20,\\n\",\n    \"    logit_bias={\\n\",\n    \"        \\\"12345\\\": 75,  # Strongly favor token ID 12345\\n\",\n    \"        \\\"67890\\\": -100,  # Completely avoid token ID 67890\\n\",\n    \"        \\\"11111\\\": -25, 
 # Slightly discourage token ID 11111\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(f\\\"Response with logit bias: {response.choices[0].text}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Completions\\n\",\n    \"\\n\",\n    \"### Usage\\n\",\n    \"Completions API is similar to Chat Completions API, but without the `messages` parameter or chat templates.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response = client.completions.create(\\n\",\n    \"    model=\\\"qwen/qwen2.5-0.5b-instruct\\\",\\n\",\n    \"    prompt=\\\"List 3 countries and their capitals.\\\",\\n\",\n    \"    temperature=0,\\n\",\n    \"    max_tokens=64,\\n\",\n    \"    n=1,\\n\",\n    \"    stop=None,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(f\\\"Response: {response}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Parameters\\n\",\n    \"\\n\",\n    \"The completions API accepts OpenAI Completions API's parameters.  
Refer to [OpenAI Completions API](https://platform.openai.com/docs/api-reference/completions/create) for more details.\\n\",\n    \"\\n\",\n    \"Here is an example of a detailed completions request:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"response = client.completions.create(\\n\",\n    \"    model=\\\"qwen/qwen2.5-0.5b-instruct\\\",\\n\",\n    \"    prompt=\\\"Write a short story about a space explorer.\\\",\\n\",\n    \"    temperature=0.7,  # Moderate temperature for creative writing\\n\",\n    \"    max_tokens=150,  # Longer response for a story\\n\",\n    \"    top_p=0.9,  # Balanced diversity in word choice\\n\",\n    \"    stop=[\\\"\\\\n\\\\n\\\", \\\"THE END\\\"],  # Multiple stop sequences\\n\",\n    \"    presence_penalty=0.3,  # Encourage novel elements\\n\",\n    \"    frequency_penalty=0.3,  # Reduce repetitive phrases\\n\",\n    \"    n=1,  # Generate one completion\\n\",\n    \"    seed=123,  # For reproducible results\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(f\\\"Response: {response}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Returning Routed Experts (MoE Models)\\n\",\n    \"\\n\",\n    \"For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. 
The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Structured Outputs (JSON, Regex, EBNF)\\n\",\n    \"\\n\",\n    \"For OpenAI compatible structured outputs API, refer to [Structured Outputs](../advanced_features/structured_outputs.ipynb) for more details.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using LoRA Adapters\\n\",\n    \"\\n\",\n    \"SGLang supports LoRA (Low-Rank Adaptation) adapters with OpenAI-compatible APIs. You can specify which adapter to use directly in the `model` parameter using the `base-model:adapter-name` syntax.\\n\",\n    \"\\n\",\n    \"**Server Setup:**\\n\",\n    \"```bash\\n\",\n    \"python -m sglang.launch_server \\\\\\n\",\n    \"    --model-path qwen/qwen2.5-0.5b-instruct \\\\\\n\",\n    \"    --enable-lora \\\\\\n\",\n    \"    --lora-paths adapter_a=/path/to/adapter_a adapter_b=/path/to/adapter_b\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"For more details on LoRA serving configuration, see the [LoRA documentation](../advanced_features/lora.ipynb).\\n\",\n    \"\\n\",\n    \"**API Call:**\\n\",\n    \"\\n\",\n    \"(Recommended) Use the `model:adapter` syntax to specify which adapter to use:\\n\",\n    \"```python\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"qwen/qwen2.5-0.5b-instruct:adapter_a\\\",  # ← base-model:adapter-name\\n\",\n    \"    messages=[{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"Convert to SQL: show all users\\\"}],\\n\",\n    \"    max_tokens=50,\\n\",\n    \")\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"**Backward Compatible: Using `extra_body`**\\n\",\n    \"\\n\",\n    \"The old `extra_body` method is still supported for backward compatibility:\\n\",\n   
 \"```python\\n\",\n    \"# Backward compatible method\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"qwen/qwen2.5-0.5b-instruct\\\",\\n\",\n    \"    messages=[{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"Convert to SQL: show all users\\\"}],\\n\",\n    \"    extra_body={\\\"lora_path\\\": \\\"adapter_a\\\"},  # ← old method\\n\",\n    \"    max_tokens=50,\\n\",\n    \")\\n\",\n    \"```\\n\",\n    \"**Note:** When both `model:adapter` and `extra_body[\\\"lora_path\\\"]` are specified, the `model:adapter` syntax takes precedence.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "docs/basic_usage/openai_api_embeddings.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# OpenAI APIs - Embedding\\n\",\n    \"\\n\",\n    \"SGLang provides OpenAI-compatible APIs to enable a smooth transition from OpenAI services to self-hosted local models.\\n\",\n    \"A complete reference for the API is available in the [OpenAI API Reference](https://platform.openai.com/docs/guides/embeddings).\\n\",\n    \"\\n\",\n    \"This tutorial covers the embedding APIs for embedding models. For a list of the supported models see the [corresponding overview page](../supported_models/retrieval_ranking/embedding_models.md)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Launch A Server\\n\",\n    \"\\n\",\n    \"Launch the server in your terminal and wait for it to initialize. Remember to add `--is-embedding` to the command.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from sglang.test.doc_patch import launch_server_cmd\\n\",\n    \"from sglang.utils import wait_for_server, print_highlight, terminate_process\\n\",\n    \"\\n\",\n    \"embedding_process, port = launch_server_cmd(\\\"\\\"\\\"\\n\",\n    \"python3 -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct \\\\\\n\",\n    \"    --host 0.0.0.0 --is-embedding --log-level warning\\n\",\n    \"\\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=embedding_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using cURL\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import subprocess, json\\n\",\n    \"\\n\",\n    \"text = \\\"Once upon a time\\\"\\n\",\n    \"\\n\",\n    \"curl_text = f\\\"\\\"\\\"curl -s 
http://localhost:{port}/v1/embeddings \\\\\\n\",\n    \"  -H \\\"Content-Type: application/json\\\" \\\\\\n\",\n    \"  -d '{{\\\"model\\\": \\\"Alibaba-NLP/gte-Qwen2-1.5B-instruct\\\", \\\"input\\\": \\\"{text}\\\"}}'\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"result = subprocess.check_output(curl_text, shell=True)\\n\",\n    \"\\n\",\n    \"print(result)\\n\",\n    \"\\n\",\n    \"text_embedding = json.loads(result)[\\\"data\\\"][0][\\\"embedding\\\"]\\n\",\n    \"\\n\",\n    \"print_highlight(f\\\"Text embedding (first 10): {text_embedding[:10]}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Python Requests\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import requests\\n\",\n    \"\\n\",\n    \"text = \\\"Once upon a time\\\"\\n\",\n    \"\\n\",\n    \"response = requests.post(\\n\",\n    \"    f\\\"http://localhost:{port}/v1/embeddings\\\",\\n\",\n    \"    json={\\\"model\\\": \\\"Alibaba-NLP/gte-Qwen2-1.5B-instruct\\\", \\\"input\\\": text},\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"text_embedding = response.json()[\\\"data\\\"][0][\\\"embedding\\\"]\\n\",\n    \"\\n\",\n    \"print_highlight(f\\\"Text embedding (first 10): {text_embedding[:10]}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using OpenAI Python Client\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import openai\\n\",\n    \"\\n\",\n    \"client = openai.Client(base_url=f\\\"http://127.0.0.1:{port}/v1\\\", api_key=\\\"None\\\")\\n\",\n    \"\\n\",\n    \"# Text embedding example\\n\",\n    \"response = client.embeddings.create(\\n\",\n    \"    model=\\\"Alibaba-NLP/gte-Qwen2-1.5B-instruct\\\",\\n\",\n    \"    input=text,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"embedding = 
response.data[0].embedding[:10]\\n\",\n    \"print_highlight(f\\\"Text embedding (first 10): {embedding}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Input IDs\\n\",\n    \"\\n\",\n    \"SGLang also supports `input_ids` as input to get the embedding.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import json\\n\",\n    \"import os\\n\",\n    \"from transformers import AutoTokenizer\\n\",\n    \"\\n\",\n    \"os.environ[\\\"TOKENIZERS_PARALLELISM\\\"] = \\\"false\\\"\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"Alibaba-NLP/gte-Qwen2-1.5B-instruct\\\")\\n\",\n    \"input_ids = tokenizer.encode(text)\\n\",\n    \"\\n\",\n    \"curl_ids = f\\\"\\\"\\\"curl -s http://localhost:{port}/v1/embeddings \\\\\\n\",\n    \"  -H \\\"Content-Type: application/json\\\" \\\\\\n\",\n    \"  -d '{{\\\"model\\\": \\\"Alibaba-NLP/gte-Qwen2-1.5B-instruct\\\", \\\"input\\\": {json.dumps(input_ids)}}}'\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))[\\\"data\\\"][\\n\",\n    \"    0\\n\",\n    \"][\\\"embedding\\\"]\\n\",\n    \"\\n\",\n    \"print_highlight(f\\\"Input IDs embedding (first 10): {input_ids_embedding[:10]}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(embedding_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Multi-Modal Embedding Model\\n\",\n    \"Please refer to [Multi-Modal Embedding Model](../supported_models/retrieval_ranking/embedding_models.md)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   
\"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "docs/basic_usage/openai_api_vision.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# OpenAI APIs - Vision\\n\",\n    \"\\n\",\n    \"SGLang provides OpenAI-compatible APIs to enable a smooth transition from OpenAI services to self-hosted local models.\\n\",\n    \"A complete reference for the API is available in the [OpenAI API Reference](https://platform.openai.com/docs/guides/vision).\\n\",\n    \"This tutorial covers the vision APIs for vision language models.\\n\",\n    \"\\n\",\n    \"SGLang supports various vision language models such as Llama 3.2, LLaVA-OneVision, Qwen2.5-VL, Gemma3 and [more](../supported_models/text_generation/multimodal_language_models.md).\\n\",\n    \"\\n\",\n    \"As an alternative to the OpenAI API, you can also use the [SGLang offline engine](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/offline_batch_inference_vlm.py).\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Launch A Server\\n\",\n    \"\\n\",\n    \"Launch the server in your terminal and wait for it to initialize.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from sglang.test.doc_patch import launch_server_cmd\\n\",\n    \"from sglang.utils import wait_for_server, print_highlight, terminate_process\\n\",\n    \"\\n\",\n    \"example_image_url = \\\"https://raw.githubusercontent.com/sgl-project/sglang/main/examples/assets/example_image.png\\\"\\n\",\n    \"logo_image_url = (\\n\",\n    \"    \\\"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"vision_process, port = launch_server_cmd(\\\"\\\"\\\"\\n\",\n    \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --log-level warning\\n\",\n    \"\\\"\\\"\\\")\\n\",\n    \"\\n\",\n    
\"wait_for_server(f\\\"http://localhost:{port}\\\", process=vision_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using cURL\\n\",\n    \"\\n\",\n    \"Once the server is up, you can send test requests using curl or requests.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import subprocess\\n\",\n    \"\\n\",\n    \"curl_command = f\\\"\\\"\\\"\\n\",\n    \"curl -s http://localhost:{port}/v1/chat/completions \\\\\\\\\\n\",\n    \"  -H \\\"Content-Type: application/json\\\" \\\\\\\\\\n\",\n    \"  -d '{{\\n\",\n    \"    \\\"model\\\": \\\"Qwen/Qwen2.5-VL-7B-Instruct\\\",\\n\",\n    \"    \\\"messages\\\": [\\n\",\n    \"      {{\\n\",\n    \"        \\\"role\\\": \\\"user\\\",\\n\",\n    \"        \\\"content\\\": [\\n\",\n    \"          {{\\n\",\n    \"            \\\"type\\\": \\\"text\\\",\\n\",\n    \"            \\\"text\\\": \\\"What’s in this image?\\\"\\n\",\n    \"          }},\\n\",\n    \"          {{\\n\",\n    \"            \\\"type\\\": \\\"image_url\\\",\\n\",\n    \"            \\\"image_url\\\": {{\\n\",\n    \"              \\\"url\\\": \\\"{example_image_url}\\\"\\n\",\n    \"            }}\\n\",\n    \"          }}\\n\",\n    \"        ]\\n\",\n    \"      }}\\n\",\n    \"    ],\\n\",\n    \"    \\\"max_tokens\\\": 300\\n\",\n    \"  }}'\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"response = subprocess.check_output(curl_command, shell=True).decode()\\n\",\n    \"print_highlight(response)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Python Requests\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": 
[],\n   \"source\": [\n    \"import requests\\n\",\n    \"\\n\",\n    \"url = f\\\"http://localhost:{port}/v1/chat/completions\\\"\\n\",\n    \"\\n\",\n    \"data = {\\n\",\n    \"    \\\"model\\\": \\\"Qwen/Qwen2.5-VL-7B-Instruct\\\",\\n\",\n    \"    \\\"messages\\\": [\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"user\\\",\\n\",\n    \"            \\\"content\\\": [\\n\",\n    \"                {\\\"type\\\": \\\"text\\\", \\\"text\\\": \\\"What’s in this image?\\\"},\\n\",\n    \"                {\\n\",\n    \"                    \\\"type\\\": \\\"image_url\\\",\\n\",\n    \"                    \\\"image_url\\\": {\\\"url\\\": example_image_url},\\n\",\n    \"                },\\n\",\n    \"            ],\\n\",\n    \"        }\\n\",\n    \"    ],\\n\",\n    \"    \\\"max_tokens\\\": 300,\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"response = requests.post(url, json=data)\\n\",\n    \"print_highlight(response.text)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using OpenAI Python Client\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from openai import OpenAI\\n\",\n    \"\\n\",\n    \"client = OpenAI(base_url=f\\\"http://localhost:{port}/v1\\\", api_key=\\\"None\\\")\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"Qwen/Qwen2.5-VL-7B-Instruct\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"user\\\",\\n\",\n    \"            \\\"content\\\": [\\n\",\n    \"                {\\n\",\n    \"                    \\\"type\\\": \\\"text\\\",\\n\",\n    \"                    \\\"text\\\": \\\"What is in this image?\\\",\\n\",\n    \"                },\\n\",\n    \"                {\\n\",\n    \"                    \\\"type\\\": \\\"image_url\\\",\\n\",\n    \"                    \\\"image_url\\\": 
{\\\"url\\\": example_image_url},\\n\",\n    \"                },\\n\",\n    \"            ],\\n\",\n    \"        }\\n\",\n    \"    ],\\n\",\n    \"    max_tokens=300,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(response.choices[0].message.content)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Multiple-Image Inputs\\n\",\n    \"\\n\",\n    \"The server also supports multiple images and interleaved text and images if the model supports it.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from openai import OpenAI\\n\",\n    \"\\n\",\n    \"client = OpenAI(base_url=f\\\"http://localhost:{port}/v1\\\", api_key=\\\"None\\\")\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"Qwen/Qwen2.5-VL-7B-Instruct\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"user\\\",\\n\",\n    \"            \\\"content\\\": [\\n\",\n    \"                {\\n\",\n    \"                    \\\"type\\\": \\\"image_url\\\",\\n\",\n    \"                    \\\"image_url\\\": {\\n\",\n    \"                        \\\"url\\\": example_image_url,\\n\",\n    \"                    },\\n\",\n    \"                },\\n\",\n    \"                {\\n\",\n    \"                    \\\"type\\\": \\\"image_url\\\",\\n\",\n    \"                    \\\"image_url\\\": {\\n\",\n    \"                        \\\"url\\\": logo_image_url,\\n\",\n    \"                    },\\n\",\n    \"                },\\n\",\n    \"                {\\n\",\n    \"                    \\\"type\\\": \\\"text\\\",\\n\",\n    \"                    \\\"text\\\": \\\"I have two very different images. They are not related at all. 
\\\"\\n\",\n    \"                    \\\"Please describe the first image in one sentence, and then describe the second image in another sentence.\\\",\\n\",\n    \"                },\\n\",\n    \"            ],\\n\",\n    \"        }\\n\",\n    \"    ],\\n\",\n    \"    temperature=0,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(response.choices[0].message.content)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(vision_process)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "docs/basic_usage/popular_model_usage.rst",
    "content": "Popular Model Usage (DeepSeek, GPT-OSS, GLM, Llama, MiniMax, Qwen, and more)\n============================================================================\n\nFor more usage examples and recipes, visit the `SGLang Cookbook <https://cookbook.sglang.io/>`_.\n\n.. toctree::\n   :maxdepth: 1\n\n   deepseek_v3.md\n   deepseek_v32.md\n   glm45.md\n   glmv.md\n   gpt_oss.md\n   minimax_m2.md\n   qwen3.md\n   qwen3_5.md\n   qwen3_vl.md\n   deepseek_ocr.md\n   llama4.md\n"
  },
  {
    "path": "docs/basic_usage/qwen3.md",
    "content": "# Qwen3-Next Usage\n\nSGLang has supported Qwen3-Next-80B-A3B-Instruct and Qwen3-Next-80B-A3B-Thinking since [this PR](https://github.com/sgl-project/sglang/pull/10233).\n\n## Launch Qwen3-Next with SGLang\n\nTo serve Qwen3-Next models on 4xH100/H200 GPUs:\n\n```bash\npython3 -m sglang.launch_server --model Qwen/Qwen3-Next-80B-A3B-Instruct --tp 4\n```\n\n### Configuration Tips\n- `--max-mamba-cache-size`: Increase this value to enlarge the mamba cache and allow more concurrent running requests. The trade-off is less KV cache space, so tune it to your workload.\n- `--mamba-ssm-dtype`: `bfloat16` or `float32`. Use `bfloat16` to reduce the mamba cache size, or `float32` for more accurate results. The default is `float32`.\n- `--mamba-full-memory-ratio`: The ratio of mamba state memory to full KV cache memory. The default is 0.9.\n\n### Mamba Radix Cache\nSGLang supports prefix caching for Qwen3-Next models through `MambaRadixCache`, which speeds up inference by reusing previously computed results. There are two versions of `MambaRadixCache`:\n- `no_buffer`: The default version, also used by other hybrid linear models. When it is enabled, SGLang automatically disables the overlap scheduler for compatibility.\n- `extra_buffer`: An optimized version that is compatible with features such as page size > 1, the overlap scheduler, and speculative decoding. It also supports storing mamba state at branching positions. However, it requires two extra mamba state slots per request for a ping-pong buffer. 
To enable it, add the argument `--mamba-scheduler-strategy extra_buffer` when launching the server.\n\n### EAGLE Speculative Decoding\n**Description**: SGLang supports Qwen3-Next models with [EAGLE speculative decoding](https://docs.sglang.io/advanced_features/speculative_decoding.html#EAGLE-Decoding).\n\n**Usage**:\nAdd the arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculative-eagle-topk`, and `--speculative-num-draft-tokens` to enable this feature. For example:\n\n```bash\npython3 -m sglang.launch_server \\\n  --model Qwen/Qwen3-Next-80B-A3B-Instruct \\\n  --tp 4 \\\n  --speculative-num-steps 3 \\\n  --speculative-eagle-topk 1 \\\n  --speculative-num-draft-tokens 4 \\\n  --speculative-algorithm NEXTN\n```\n\nSee [this PR](https://github.com/sgl-project/sglang/pull/10233) for details.\n"
  },
  {
    "path": "docs/basic_usage/qwen3_5.md",
    "content": "# Qwen 3.5 Usage\n\nQwen 3.5 is Alibaba's latest generation LLM featuring a hybrid attention architecture, advanced MoE with shared experts, and native multimodal capabilities.\n\nKey architecture features:\n- **Hybrid Attention**: Gated Delta Networks (linear, O(n) complexity) combined with full attention every 4th layer for high associative recall\n- **MoE with Shared Experts**: Top-8 active out of 64 routed experts plus a dedicated shared expert for universal features\n- **Multimodal**: DeepStack Vision Transformer with Conv3d for native image and video understanding\n\n## Launch Qwen 3.5 with SGLang\n\n### Dense Model\n\nTo serve `Qwen/Qwen3.5-397B-A17B` on 8 GPUs:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path Qwen/Qwen3.5-397B-A17B \\\n    --tp 8 \\\n    --trust-remote-code\n```\n\n### AMD GPU (MI300X / MI325X / MI35X)\n\nOn AMD Instinct GPUs, use the `triton` attention backend. Both the full attention layers and the Gated Delta Net (linear attention) layers use Triton-based kernels on ROCm:\n\n```bash\nSGLANG_USE_AITER=1 python3 -m sglang.launch_server \\\n    --model-path Qwen/Qwen3.5-397B-A17B \\\n    --tp 8 \\\n    --attention-backend triton \\\n    --trust-remote-code\n```\n\n```{tip}\nSet `SGLANG_USE_AITER=1` to enable AMD's optimized aiter kernels for MoE and GEMM operations.\n```\n\n### Configuration Tips\n\n- `--attention-backend`: Use `triton` on AMD GPUs for Qwen 3.5. The hybrid attention architecture (Gated Delta Networks + full attention) works best with the Triton backend on ROCm. 
The linear attention (GDN) layers always use Triton kernels internally via the `GDNAttnBackend`.\n- `--watchdog-timeout`: Increase to `1200` or higher for this large model, as weight loading takes significant time.\n- `--model-loader-extra-config '{\"enable_multithread_load\": true}'`: Enables parallel weight loading for faster startup.\n\n### Reasoning and Tool Calling\n\nQwen 3.5 supports reasoning and tool calling via the Qwen3 parsers:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path Qwen/Qwen3.5-397B-A17B \\\n    --tp 8 \\\n    --trust-remote-code \\\n    --reasoning-parser qwen3 \\\n    --tool-call-parser qwen3_coder\n```\n\n## Accuracy Evaluation\n\nYou can evaluate model accuracy against the running server using `lm-eval`:\n\n```bash\npip install lm-eval[api]\n\nlm_eval --model local-completions \\\n    --model_args '{\"base_url\": \"http://localhost:30000/v1/completions\", \"model\": \"Qwen/Qwen3.5-397B-A17B\", \"num_concurrent\": 256, \"max_retries\": 10, \"max_gen_toks\": 2048}' \\\n    --tasks gsm8k \\\n    --batch_size auto \\\n    --num_fewshot 5 \\\n    --trust_remote_code\n```\n\n## Additional Resources\n\n- [AMD Day 0 Support for Qwen 3.5 on AMD Instinct GPUs](https://www.amd.com/en/developer/resources/technical-articles/2026/day-0-support-for-qwen-3-5-on-amd-instinct-gpus.html)\n- [HuggingFace Model Card](https://huggingface.co/Qwen/Qwen3.5-397B-A17B)\n"
  },
  {
    "path": "docs/basic_usage/qwen3_vl.md",
    "content": "# Qwen3-VL Usage\n\n[Qwen3-VL](https://huggingface.co/collections/Qwen/qwen3-vl)\nis Alibaba’s latest multimodal large language model with strong text, vision, and reasoning capabilities.\nSGLang supports the Qwen3-VL family of models with image and video input.\n\n## Launch commands for SGLang\n\nBelow are suggested launch commands for different hardware and precision modes:\n\n### FP8 (quantised) mode\nFor memory-efficient, latency-optimized deployments (e.g., on H100 or H200) where the FP8 checkpoint is supported:\n```bash\npython3 -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-235B-A22B-Instruct-FP8 \\\n  --tp 8 \\\n  --ep 8 \\\n  --host 0.0.0.0 \\\n  --port 30000 \\\n  --keep-mm-feature-on-device\n```\n\n### Non-FP8 (BF16 / full precision) mode\nFor deployments on A100/H100 where BF16 is used (or the FP8 checkpoint is not used):\n```bash\npython3 -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-235B-A22B-Instruct \\\n  --tp 8 \\\n  --ep 8 \\\n  --host 0.0.0.0 \\\n  --port 30000\n```\n\n## Hardware-specific notes / recommendations\n\n- On H100 with FP8: Use the FP8 checkpoint for best memory efficiency.\n- On A100 / H100 with BF16 (non-FP8): It’s recommended to use `--mm-max-concurrent-calls` to control parallel throughput and GPU memory usage during image/video inference.\n- On H200 & B200: The model can be run “out of the box”, supporting full context length plus concurrent image + video processing.\n\n## Sending Image/Video Requests\n\n### Image Input\n\n```python\nimport requests\n\nurl = \"http://localhost:30000/v1/chat/completions\"\n\ndata = {\n    \"model\": \"Qwen/Qwen3-VL-30B-A3B-Instruct\",\n    \"messages\": [\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\"type\": \"text\", \"text\": \"What’s in this image?\"},\n                {\n                    \"type\": \"image_url\",\n                    \"image_url\": {\n                        \"url\": 
\"https://github.com/sgl-project/sglang/blob/main/examples/assets/example_image.png?raw=true\"\n                    },\n                },\n            ],\n        }\n    ],\n    \"max_tokens\": 300,\n}\n\nresponse = requests.post(url, json=data)\nprint(response.text)\n```\n\n### Video Input\n\n```python\nimport requests\n\nurl = \"http://localhost:30000/v1/chat/completions\"\n\ndata = {\n    \"model\": \"Qwen/Qwen3-VL-30B-A3B-Instruct\",\n    \"messages\": [\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\"type\": \"text\", \"text\": \"What’s happening in this video?\"},\n                {\n                    \"type\": \"video_url\",\n                    \"video_url\": {\n                        \"url\": \"https://github.com/sgl-project/sgl-test-files/raw/refs/heads/main/videos/jobs_presenting_ipod.mp4\"\n                    },\n                },\n            ],\n        }\n    ],\n    \"max_tokens\": 300,\n}\n\nresponse = requests.post(url, json=data)\nprint(response.text)\n```\n\n## Important Server Parameters and Flags\n\nWhen launching the model server for **multimodal support**, you can use the following command-line arguments to fine-tune performance and behavior:\n\n- `--mm-attention-backend`: Specifies the multimodal attention backend, e.g. `fa3` (FlashAttention-3).\n- `--mm-max-concurrent-calls <value>`: Specifies the **maximum number of concurrent asynchronous multimodal data processing calls** allowed on the server. Use this to control parallel throughput and GPU memory usage during image/video inference.\n- `--mm-per-request-timeout <seconds>`: Defines the **timeout duration (in seconds)** for each multimodal request. If a request exceeds this time limit (e.g., for very large video inputs), it will be automatically terminated.\n- `--keep-mm-feature-on-device`: Instructs the server to **retain multimodal feature tensors on the GPU** after processing. 
This avoids device-to-host (D2H) memory copies and improves performance for repeated or high-frequency inference workloads.\n- `SGLANG_USE_CUDA_IPC_TRANSPORT=1`: Transports multimodal data over CUDA IPC backed by a shared memory pool, which can significantly reduce end-to-end latency.\n\n### Example usage with the above optimizations:\n```bash\nSGLANG_USE_CUDA_IPC_TRANSPORT=1 \\\nSGLANG_VLM_CACHE_SIZE_MB=0 \\\npython -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-235B-A22B-Instruct \\\n  --host 0.0.0.0 \\\n  --port 30000 \\\n  --trust-remote-code \\\n  --tp-size 8 \\\n  --enable-cache-report \\\n  --log-level info \\\n  --max-running-requests 64 \\\n  --mem-fraction-static 0.65 \\\n  --chunked-prefill-size 8192 \\\n  --attention-backend fa3 \\\n  --mm-attention-backend fa3 \\\n  --enable-metrics\n```\n"
  },
  {
    "path": "docs/basic_usage/sampling_params.md",
    "content": "# Sampling Parameters\n\nThis doc describes the request and sampling parameters of `/generate`, the low-level endpoint of the SGLang Runtime.\nIf you want a high-level endpoint that can automatically handle chat templates, consider using the [OpenAI Compatible API](openai_api_completions.ipynb).\n\n## `/generate` Endpoint\n\nThe `/generate` endpoint accepts the following parameters in JSON format. For detailed usage, see the [native API doc](native_api.ipynb). The object is defined at `io_struct.py::GenerateReqInput`. You can also read the source code to find more arguments and docs.\n\n| Argument                   | Type/Default                                                                 | Description                                                                                                                                                     |\n|----------------------------|------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| text                       | `Optional[Union[List[str], str]] = None`                                     | The input prompt. Can be a single prompt or a batch of prompts.                                                                                                 |\n| input_ids                  | `Optional[Union[List[List[int]], List[int]]] = None`                         | The token IDs for text; one can specify either text or input_ids.                                                                                               |\n| input_embeds               | `Optional[Union[List[List[List[float]]], List[List[float]]]] = None`         | The embeddings for input_ids; one can specify either text, input_ids, or input_embeds.                                                                          
|\n| image_data                 | `Optional[Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]] = None` | The image input. Supports three formats: (1) **Raw images**: PIL Image, file path, URL, or base64 string; (2) **Processor output**: Dict with `format: \"processor_output\"` containing HuggingFace processor outputs; (3) **Precomputed embeddings**: Dict with `format: \"precomputed_embedding\"` and `feature` containing pre-calculated visual embeddings. Can be a single image, list of images, or list of lists of images. See [Multimodal Input Formats](#multimodal-input-formats) for details. |\n| audio_data                 | `Optional[Union[List[AudioDataItem], AudioDataItem]] = None`                 | The audio input. Can be a file name, URL, or base64 encoded string.                                                                                             |\n| sampling_params            | `Optional[Union[List[Dict], Dict]] = None`                                   | The sampling parameters as described in the sections below.                                                                                                     |\n| rid                        | `Optional[Union[List[str], str]] = None`                                     | The request ID.                                                                                                                                                 |\n| return_logprob             | `Optional[Union[List[bool], bool]] = None`                                   | Whether to return log probabilities for tokens.                                                                                                                 |\n| logprob_start_len          | `Optional[Union[List[int], int]] = None`                                     | If return_logprob, the start location in the prompt for returning logprobs. Default is \"-1\", which returns logprobs for output tokens only.                     
|\n| top_logprobs_num           | `Optional[Union[List[int], int]] = None`                                     | If return_logprob, the number of top logprobs to return at each position.                                                                                       |\n| token_ids_logprob          | `Optional[Union[List[List[int]], List[int]]] = None`                         | If return_logprob, the token IDs to return logprob for.                                                                                                         |\n| return_text_in_logprobs    | `bool = False`                                                               | Whether to detokenize tokens in text in the returned logprobs.                                                                                                  |\n| stream                     | `bool = False`                                                               | Whether to stream output.                                                                                                                                       |\n| lora_path                  | `Optional[Union[List[Optional[str]], Optional[str]]] = None`                 | The path to the LoRA.                                                                                                                                           |\n| custom_logit_processor     | `Optional[Union[List[Optional[str]], str]] = None`                           | Custom logit processor for advanced sampling control. Must be a serialized instance of `CustomLogitProcessor` using its `to_str()` method. For usage see below. |\n| return_hidden_states       | `Union[List[bool], bool] = False`                                            | Whether to return hidden states.                                                                                                                                
|\n| return_routed_experts      | `bool = False`                                                               | Whether to return routed experts for MoE models. Requires `--enable-return-routed-experts` server flag. Returns base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`. |\n\n## Sampling parameters\n\nThe object is defined at `sampling_params.py::SamplingParams`. You can also read the source code to find more arguments and docs.\n\n### Note on defaults\n\nBy default, SGLang initializes several sampling parameters from the model's `generation_config.json` (when the server is launched with `--sampling-defaults model`, which is the default). To use SGLang/OpenAI constant defaults instead, start the server with `--sampling-defaults openai`. You can always override any parameter per request via `sampling_params`.\n\n```bash\n# Use model-provided defaults from generation_config.json (default behavior)\npython -m sglang.launch_server --model-path <MODEL> --sampling-defaults model\n\n# Use SGLang/OpenAI constant defaults instead\npython -m sglang.launch_server --model-path <MODEL> --sampling-defaults openai\n```\n\n### Core parameters\n\n| Argument        | Type/Default                                 | Description                                                                                                                                    |\n|-----------------|----------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|\n| max_new_tokens  | `int = 128`                                  | The maximum output length measured in tokens.                                                                                                  
|\n| stop            | `Optional[Union[str, List[str]]] = None`     | One or multiple [stop words](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). Generation will stop if one of these words is sampled. |\n| stop_token_ids  | `Optional[List[int]] = None`                 | Provide stop words in the form of token IDs. Generation will stop if one of these token IDs is sampled.                                        |\n| stop_regex      | `Optional[Union[str, List[str]]] = None`     | Stop when hitting any of the regex patterns in this list |\n| temperature     | `float (model default; fallback 1.0)`        | [Temperature](https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature) when sampling the next token. `temperature = 0` corresponds to greedy sampling, a higher temperature leads to more diversity. |\n| top_p           | `float (model default; fallback 1.0)`        | [Top-p](https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_p) selects tokens from the smallest sorted set whose cumulative probability exceeds `top_p`. When `top_p = 1`, this reduces to unrestricted sampling from all tokens. |\n| top_k           | `int (model default; fallback -1)`           | [Top-k](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#predictability_vs_creativity) randomly selects from the `k` highest-probability tokens. |\n| min_p           | `float (model default; fallback 0.0)`        | [Min-p](https://github.com/huggingface/transformers/issues/27670) samples from tokens with probability larger than `min_p * highest_token_probability`. 
|\n\n### Penalizers\n\n| Argument           | Type/Default           | Description                                                                                                                                    |\n|--------------------|------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|\n| frequency_penalty  | `float = 0.0`          | Penalizes tokens based on their frequency in the generation so far. Must be between `-2` and `2`, where negative values encourage repetition of tokens and positive values encourage sampling of new tokens. The penalty grows linearly with each appearance of a token. |\n| presence_penalty   | `float = 0.0`          | Penalizes tokens if they have appeared in the generation so far. Must be between `-2` and `2`, where negative values encourage repetition of tokens and positive values encourage sampling of new tokens. The penalty is constant once a token has occurred. |\n| repetition_penalty | `float = 1.0`          | Scales the logits of previously generated tokens to discourage (values > 1) or encourage (values < 1) repetition. Valid range is `[0, 2]`; `1.0` leaves probabilities unchanged. |\n| min_new_tokens     | `int = 0`              | Forces the model to generate at least `min_new_tokens` tokens before a stop word or EOS token can end generation. Note that this might lead to unintended behavior, for example, if the distribution is highly skewed towards these tokens. 
|\n\n### Constrained decoding\n\nPlease refer to our dedicated guide on [constrained decoding](../advanced_features/structured_outputs.ipynb) for the following parameters.\n\n| Argument        | Type/Default                    | Description                                                                                                                                    |\n|-----------------|---------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|\n| json_schema     | `Optional[str] = None`          | JSON schema for structured outputs.                                                                                                            |\n| regex           | `Optional[str] = None`          | Regex for structured outputs.                                                                                                                  |\n| ebnf            | `Optional[str] = None`          | EBNF for structured outputs.                                                                                                                   |\n| structural_tag  | `Optional[str] = None`          | The structural tag for structured outputs.                                                                                                       |\n\n### Other options\n\n| Argument                      | Type/Default                    | Description                                                                                                                                    |\n|-------------------------------|---------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|\n| n                             | `int = 1`                       | Specifies the number of output sequences to generate per request. 
(Generating multiple outputs in one request (n > 1) is discouraged; repeating the same prompts several times offers better control and efficiency.) |\n| ignore_eos                    | `bool = False`                  | Don't stop generation when EOS token is sampled.                                                                                               |\n| skip_special_tokens           | `bool = True`                   | Remove special tokens during decoding.                                                                                                         |\n| spaces_between_special_tokens | `bool = True`                   | Whether or not to add spaces between special tokens during detokenization.                                                                     |\n| no_stop_trim                  | `bool = False`                  | Don't trim stop words or EOS token from the generated text.                                                                                    |\n| custom_params                 | `Optional[List[Optional[Dict[str, Any]]]] = None` | Used when employing `CustomLogitProcessor`. For usage, see below.                                                                              
|\n\n## Examples\n\n### Normal\n\nLaunch a server:\n\n```bash\npython -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000\n```\n\nSend a request:\n\n```python\nimport requests\n\nresponse = requests.post(\n    \"http://localhost:30000/generate\",\n    json={\n        \"text\": \"The capital of France is\",\n        \"sampling_params\": {\n            \"temperature\": 0,\n            \"max_new_tokens\": 32,\n        },\n    },\n)\nprint(response.json())\n```\n\nDetailed example in [send request](./send_request.ipynb).\n\n### Streaming\n\nSend a request and stream the output:\n\n```python\nimport requests, json\n\nresponse = requests.post(\n    \"http://localhost:30000/generate\",\n    json={\n        \"text\": \"The capital of France is\",\n        \"sampling_params\": {\n            \"temperature\": 0,\n            \"max_new_tokens\": 32,\n        },\n        \"stream\": True,\n    },\n    stream=True,\n)\n\nprev = 0\nfor chunk in response.iter_lines(decode_unicode=False):\n    chunk = chunk.decode(\"utf-8\")\n    if chunk and chunk.startswith(\"data:\"):\n        if chunk == \"data: [DONE]\":\n            break\n        data = json.loads(chunk[5:].strip(\"\\n\"))\n        output = data[\"text\"].strip()\n        print(output[prev:], end=\"\", flush=True)\n        prev = len(output)\nprint(\"\")\n```\n\nDetailed example in [openai compatible api](openai_api_completions.ipynb).\n\n### Multimodal\n\nLaunch a server:\n\n```bash\npython3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov\n```\n\nDownload an image:\n\n```bash\ncurl -o example_image.png -L https://github.com/sgl-project/sglang/blob/main/examples/assets/example_image.png?raw=true\n```\n\nSend a request:\n\n```python\nimport requests\n\nresponse = requests.post(\n    \"http://localhost:30000/generate\",\n    json={\n        \"text\": \"<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n\"\n                
\"<|im_start|>user\\n<image>\\nDescribe this image in a very short sentence.<|im_end|>\\n\"\n                \"<|im_start|>assistant\\n\",\n        \"image_data\": \"example_image.png\",\n        \"sampling_params\": {\n            \"temperature\": 0,\n            \"max_new_tokens\": 32,\n        },\n    },\n)\nprint(response.json())\n```\n\nThe `image_data` can be a file name, a URL, or a base64 encoded string. See also `python/sglang/srt/utils.py:load_image`.\n\nStreaming is supported in a similar manner as [above](#streaming).\n\nDetailed example in [OpenAI API Vision](openai_api_vision.ipynb).\n\n### Structured Outputs (JSON, Regex, EBNF)\n\nYou can specify a JSON schema, a regular expression, or an [EBNF](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form) grammar to constrain the model output. The model output is guaranteed to follow the given constraints. Only one constraint parameter (`json_schema`, `regex`, or `ebnf`) can be specified per request.\n\nSGLang supports two grammar backends:\n\n- [XGrammar](https://github.com/mlc-ai/xgrammar) (default): Supports JSON schema, regular expression, and EBNF constraints.\n  - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md).\n- [Outlines](https://github.com/dottxt-ai/outlines): Supports JSON schema and regular expression constraints.\n\nTo use the Outlines backend instead, pass the `--grammar-backend outlines` flag (the default is `xgrammar`):\n\n```bash\npython -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n--port 30000 --host 0.0.0.0 --grammar-backend outlines  # or xgrammar (default)\n```\n\n```python\nimport json\nimport requests\n\njson_schema = json.dumps({\n    \"type\": \"object\",\n    \"properties\": {\n        \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n        \"population\": {\"type\": \"integer\"},\n    },\n    \"required\": [\"name\", 
\"population\"],\n})\n\n# JSON (works with both Outlines and XGrammar)\nresponse = requests.post(\n    \"http://localhost:30000/generate\",\n    json={\n        \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n        \"sampling_params\": {\n            \"temperature\": 0,\n            \"max_new_tokens\": 64,\n            \"json_schema\": json_schema,\n        },\n    },\n)\nprint(response.json())\n\n# Regular expression (works with both Outlines and XGrammar)\nresponse = requests.post(\n    \"http://localhost:30000/generate\",\n    json={\n        \"text\": \"Paris is the capital of\",\n        \"sampling_params\": {\n            \"temperature\": 0,\n            \"max_new_tokens\": 64,\n            \"regex\": \"(France|England)\",\n        },\n    },\n)\nprint(response.json())\n\n# EBNF (XGrammar backend only)\nresponse = requests.post(\n    \"http://localhost:30000/generate\",\n    json={\n        \"text\": \"Write a greeting.\",\n        \"sampling_params\": {\n            \"temperature\": 0,\n            \"max_new_tokens\": 64,\n            \"ebnf\": 'root ::= \"Hello\" | \"Hi\" | \"Hey\"',\n        },\n    },\n)\nprint(response.json())\n```\n\nDetailed example in [structured outputs](../advanced_features/structured_outputs.ipynb).\n\n### Custom logit processor\n\nLaunch a server with the `--enable-custom-logit-processor` flag.\n\n```bash\npython -m sglang.launch_server \\\n  --model-path meta-llama/Meta-Llama-3-8B-Instruct \\\n  --port 30000 \\\n  --enable-custom-logit-processor\n```\n\nDefine a custom logit processor that will always sample a specific token id.\n\n```python\nfrom sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor\n\nclass DeterministicLogitProcessor(CustomLogitProcessor):\n    \"\"\"A dummy logit processor that changes the logits to always\n    sample the given token id.\n    \"\"\"\n\n    def __call__(self, logits, custom_param_list):\n        # Check that the number of logits matches the number of 
custom parameters\n        assert logits.shape[0] == len(custom_param_list)\n        key = \"token_id\"\n\n        for i, param_dict in enumerate(custom_param_list):\n            # Mask all other tokens\n            logits[i, :] = -float(\"inf\")\n            # Assign highest probability to the specified token\n            logits[i, param_dict[key]] = 0.0\n        return logits\n```\n\nSend a request:\n\n```python\nimport requests\n\nresponse = requests.post(\n    \"http://localhost:30000/generate\",\n    json={\n        \"text\": \"The capital of France is\",\n        \"custom_logit_processor\": DeterministicLogitProcessor().to_str(),\n        \"sampling_params\": {\n            \"temperature\": 0.0,\n            \"max_new_tokens\": 32,\n            \"custom_params\": {\"token_id\": 5},\n        },\n    },\n)\nprint(response.json())\n```\n\nSend an OpenAI chat completion request:\n\n```python\nimport openai\nfrom sglang.utils import print_highlight\n\nclient = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n\nresponse = client.chat.completions.create(\n    model=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n    messages=[\n        {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n    ],\n    temperature=0.0,\n    max_tokens=32,\n    extra_body={\n        \"custom_logit_processor\": DeterministicLogitProcessor().to_str(),\n        \"custom_params\": {\"token_id\": 5},\n    },\n)\n\nprint_highlight(f\"Response: {response}\")\n```\n"
  },
  {
    "path": "docs/basic_usage/send_request.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Sending Requests\\n\",\n    \"This notebook provides a quick-start guide to use SGLang in chat completions after installation. Once your server is running, API documentation is available at `http://localhost:30000/docs` (Swagger UI), `http://localhost:30000/redoc` (ReDoc), or `http://localhost:30000/openapi.json` (OpenAPI spec, useful for AI agents). Replace `30000` with your port if using a different one.\\n\",\n    \"\\n\",\n    \"- For Vision Language Models, see [OpenAI APIs - Vision](openai_api_vision.ipynb).\\n\",\n    \"- For Embedding Models, see [OpenAI APIs - Embedding](openai_api_embeddings.ipynb) and [Encode (embedding model)](native_api.html#Encode-(embedding-model)).\\n\",\n    \"- For Reward Models, see [Classify (reward model)](native_api.html#Classify-(reward-model)).\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Launch A Server\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from sglang.test.doc_patch import launch_server_cmd\\n\",\n    \"from sglang.utils import wait_for_server, print_highlight, terminate_process\\n\",\n    \"\\n\",\n    \"# This is equivalent to running the following command in your terminal\\n\",\n    \"# python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct --host 0.0.0.0\\n\",\n    \"\\n\",\n    \"server_process, port = launch_server_cmd(\\\"\\\"\\\"\\n\",\n    \"python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct \\\\\\n\",\n    \" --host 0.0.0.0 --log-level warning\\n\",\n    \"\\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=server_process)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using cURL\\n\"\n   ]\n  
},\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import subprocess, json\\n\",\n    \"\\n\",\n    \"curl_command = f\\\"\\\"\\\"\\n\",\n    \"curl -s http://localhost:{port}/v1/chat/completions \\\\\\n\",\n    \"  -H \\\"Content-Type: application/json\\\" \\\\\\n\",\n    \"  -d '{{\\\"model\\\": \\\"qwen/qwen2.5-0.5b-instruct\\\", \\\"messages\\\": [{{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"What is the capital of France?\\\"}}]}}'\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"response = json.loads(subprocess.check_output(curl_command, shell=True))\\n\",\n    \"print_highlight(response)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Python Requests\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import requests\\n\",\n    \"\\n\",\n    \"url = f\\\"http://localhost:{port}/v1/chat/completions\\\"\\n\",\n    \"\\n\",\n    \"data = {\\n\",\n    \"    \\\"model\\\": \\\"qwen/qwen2.5-0.5b-instruct\\\",\\n\",\n    \"    \\\"messages\\\": [{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"What is the capital of France?\\\"}],\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"response = requests.post(url, json=data)\\n\",\n    \"print_highlight(response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using OpenAI Python Client\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import openai\\n\",\n    \"\\n\",\n    \"client = openai.Client(base_url=f\\\"http://127.0.0.1:{port}/v1\\\", api_key=\\\"None\\\")\\n\",\n    \"\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"qwen/qwen2.5-0.5b-instruct\\\",\\n\",\n    \"    messages=[\\n\",\n  
  \"        {\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"List 3 countries and their capitals.\\\"},\\n\",\n    \"    ],\\n\",\n    \"    temperature=0,\\n\",\n    \"    max_tokens=64,\\n\",\n    \")\\n\",\n    \"print_highlight(response)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Streaming\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import openai\\n\",\n    \"\\n\",\n    \"client = openai.Client(base_url=f\\\"http://127.0.0.1:{port}/v1\\\", api_key=\\\"None\\\")\\n\",\n    \"\\n\",\n    \"# Use stream=True for streaming responses\\n\",\n    \"response = client.chat.completions.create(\\n\",\n    \"    model=\\\"qwen/qwen2.5-0.5b-instruct\\\",\\n\",\n    \"    messages=[\\n\",\n    \"        {\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"List 3 countries and their capitals.\\\"},\\n\",\n    \"    ],\\n\",\n    \"    temperature=0,\\n\",\n    \"    max_tokens=64,\\n\",\n    \"    stream=True,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Handle the streaming output\\n\",\n    \"for chunk in response:\\n\",\n    \"    if chunk.choices[0].delta.content:\\n\",\n    \"        print(chunk.choices[0].delta.content, end=\\\"\\\", flush=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Native Generation APIs\\n\",\n    \"\\n\",\n    \"You can also use the native `/generate` endpoint with requests, which provides more flexibility. 
An API reference is available at [Sampling Parameters](sampling_params.md).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import requests\\n\",\n    \"\\n\",\n    \"response = requests.post(\\n\",\n    \"    f\\\"http://localhost:{port}/generate\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"text\\\": \\\"The capital of France is\\\",\\n\",\n    \"        \\\"sampling_params\\\": {\\n\",\n    \"            \\\"temperature\\\": 0,\\n\",\n    \"            \\\"max_new_tokens\\\": 32,\\n\",\n    \"        },\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print_highlight(response.json())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Streaming\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import requests, json\\n\",\n    \"\\n\",\n    \"response = requests.post(\\n\",\n    \"    f\\\"http://localhost:{port}/generate\\\",\\n\",\n    \"    json={\\n\",\n    \"        \\\"text\\\": \\\"The capital of France is\\\",\\n\",\n    \"        \\\"sampling_params\\\": {\\n\",\n    \"            \\\"temperature\\\": 0,\\n\",\n    \"            \\\"max_new_tokens\\\": 32,\\n\",\n    \"        },\\n\",\n    \"        \\\"stream\\\": True,\\n\",\n    \"    },\\n\",\n    \"    stream=True,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"prev = 0\\n\",\n    \"for chunk in response.iter_lines(decode_unicode=False):\\n\",\n    \"    chunk = chunk.decode(\\\"utf-8\\\")\\n\",\n    \"    if chunk and chunk.startswith(\\\"data:\\\"):\\n\",\n    \"        if chunk == \\\"data: [DONE]\\\":\\n\",\n    \"            break\\n\",\n    \"        data = json.loads(chunk[5:].strip(\\\"\\\\n\\\"))\\n\",\n    \"        output = data[\\\"text\\\"]\\n\",\n    \"        print(output[prev:], end=\\\"\\\", flush=True)\\n\",\n    \"        
prev = len(output)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "docs/conf.py",
    "content": "import os\nimport sys\nfrom datetime import datetime\n\nsys.path.insert(0, os.path.abspath(\"../..\"))\n\nversion_file = \"../python/sglang/version.py\"\nwith open(version_file, \"r\") as f:\n    exec(compile(f.read(), version_file, \"exec\"))\n__version__ = locals()[\"__version__\"]\n\nproject = \"SGLang\"\ncopyright = f\"2023-{datetime.now().year}, SGLang\"\nauthor = \"SGLang Team\"\n\nversion = __version__\nrelease = __version__\n\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.autosummary\",\n    \"sphinx.ext.napoleon\",\n    \"sphinx.ext.viewcode\",\n    \"sphinx.ext.autosectionlabel\",\n    \"sphinx.ext.intersphinx\",\n    \"sphinx_tabs.tabs\",\n    \"myst_parser\",\n    \"sphinx_copybutton\",\n    \"sphinxcontrib.mermaid\",\n    \"nbsphinx\",\n    \"sphinx.ext.mathjax\",\n]\n\nnbsphinx_allow_errors = True\nnbsphinx_execute = \"never\"\n\nautosectionlabel_prefix_document = True\nnbsphinx_allow_directives = True\n\n\nmyst_enable_extensions = [\n    \"dollarmath\",\n    \"amsmath\",\n    \"deflist\",\n    \"colon_fence\",\n    \"html_image\",\n    \"linkify\",\n    \"substitution\",\n]\n\nmyst_heading_anchors = 3\n\nnbsphinx_kernel_name = \"python3\"\nnbsphinx_execute_arguments = [\n    \"--InlineBackend.figure_formats={'svg', 'pdf'}\",\n    \"--InlineBackend.rc={'figure.dpi': 96}\",\n]\n\n\nnb_render_priority = {\n    \"html\": (\n        \"application/vnd.jupyter.widget-view+json\",\n        \"application/javascript\",\n        \"text/html\",\n        \"image/svg+xml\",\n        \"image/png\",\n        \"image/jpeg\",\n        \"text/markdown\",\n        \"text/latex\",\n        \"text/plain\",\n    )\n}\n\nmyst_enable_extensions = [\n    \"dollarmath\",\n    \"amsmath\",\n    \"deflist\",\n    \"colon_fence\",\n    \"html_image\",\n    \"linkify\",\n    \"substitution\",\n]\n\nmyst_heading_anchors = 3\nmyst_ref_domains = [\"std\", \"py\"]\n\ntemplates_path = [\"_templates\"]\n\nsource_suffix = {\n    \".rst\": 
\"restructuredtext\",\n    \".md\": \"markdown\",\n}\n\nmaster_doc = \"index\"\n\nlanguage = \"en\"\n\nexclude_patterns = [\"_build\", \"Thumbs.db\", \".DS_Store\"]\n\npygments_style = \"sphinx\"\n\nhtml_theme = \"sphinx_book_theme\"\nhtml_logo = \"_static/image/logo.png\"\nhtml_favicon = \"_static/image/logo.ico\"\nhtml_title = project\nhtml_copy_source = True\nhtml_last_updated_fmt = \"\"\n\nhtml_theme_options = {\n    \"repository_url\": \"https://github.com/sgl-project/sgl-project.github.io\",\n    \"repository_branch\": \"main\",\n    \"show_navbar_depth\": 3,\n    \"max_navbar_depth\": 4,\n    \"collapse_navbar\": True,\n    \"use_edit_page_button\": True,\n    \"use_source_button\": True,\n    \"use_issues_button\": True,\n    \"use_repository_button\": True,\n    \"use_download_button\": True,\n    \"use_sidenotes\": True,\n    \"show_toc_level\": 2,\n}\n\nhtml_context = {\n    \"display_github\": True,\n    \"github_user\": \"sgl-project\",\n    \"github_repo\": \"sgl-project.github.io\",\n    \"github_version\": \"main\",\n    \"conf_py_path\": \"/docs/\",\n}\n\nhtml_static_path = [\"_static\"]\nhtml_css_files = [\"css/custom_log.css\"]\n\n\ndef setup(app):\n    app.add_css_file(\"css/custom_log.css\")\n\n\nmyst_enable_extensions = [\n    \"dollarmath\",\n    \"amsmath\",\n    \"deflist\",\n    \"colon_fence\",\n]\nmyst_heading_anchors = 5\n\nhtmlhelp_basename = \"sglangdoc\"\n\nlatex_elements = {}\n\nlatex_documents = [\n    (master_doc, \"sglang.tex\", \"sglang Documentation\", \"SGLang Team\", \"manual\"),\n]\n\nman_pages = [(master_doc, \"sglang\", \"sglang Documentation\", [author], 1)]\n\ntexinfo_documents = [\n    (\n        master_doc,\n        \"sglang\",\n        \"sglang Documentation\",\n        author,\n        \"sglang\",\n        \"One line description of project.\",\n        \"Miscellaneous\",\n    ),\n]\n\nepub_title = project\n\nepub_exclude_files = [\"search.html\"]\n\ncopybutton_prompt_text = r\">>> |\\.\\.\\. 
\"\ncopybutton_prompt_is_regexp = True\n\nautodoc_preserve_defaults = True\nnavigation_with_keys = False\n\nautodoc_mock_imports = [\n    \"torch\",\n    \"transformers\",\n    \"triton\",\n]\n\nintersphinx_mapping = {\n    \"python\": (\"https://docs.python.org/3.12\", None),\n    \"typing_extensions\": (\"https://typing-extensions.readthedocs.io/en/latest\", None),\n    \"pillow\": (\"https://pillow.readthedocs.io/en/stable\", None),\n    \"numpy\": (\"https://numpy.org/doc/stable\", None),\n    \"torch\": (\"https://pytorch.org/docs/stable\", None),\n}\n\nhtml_theme = \"sphinx_book_theme\"\n\n\nnbsphinx_prolog = \"\"\"\n.. raw:: html\n\n    <style>\n        .output_area.stderr, .output_area.stdout {\n            color: #d3d3d3 !important; /* light gray */\n        }\n    </style>\n\"\"\"\n"
  },
  {
    "path": "docs/deploy.py",
    "content": "# Deploy the documents\n\nimport os\nfrom datetime import datetime\n\n\ndef run_cmd(cmd):\n    print(cmd)\n    os.system(cmd)\n\n\nrun_cmd(\"cd $DOC_SITE_PATH; git pull\")\n\n# (Optional) Remove old files\n# run_cmd(\"rm -rf $DOC_SITE_PATH/*\")\n\nrun_cmd(\"cp -r _build/html/* $DOC_SITE_PATH\")\n\ncmd_message = f\"Update {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\"\nrun_cmd(\n    f\"cd $DOC_SITE_PATH; git add .; git commit -m '{cmd_message}'; git push origin main\"\n)\n"
  },
  {
    "path": "docs/developer_guide/bench_serving.md",
    "content": "# Bench Serving Guide\n\nThis guide explains how to benchmark online serving throughput and latency using `python -m sglang.bench_serving`. It supports multiple inference backends via OpenAI-compatible and native endpoints, and produces both console metrics and optional JSONL outputs.\n\n### What it does\n\n- Generates synthetic or dataset-driven prompts and submits them to a target serving endpoint\n- Measures throughput, time-to-first-token (TTFT), inter-token latency (ITL), per-request end-to-end latency, and more\n- Supports streaming or non-streaming modes, rate control, and concurrency limits\n\n### Supported backends and endpoints\n\n- `sglang` / `sglang-native`: `POST /generate`\n- `sglang-oai`, `vllm`, `lmdeploy`: `POST /v1/completions`\n- `sglang-oai-chat`, `vllm-chat`, `lmdeploy-chat`: `POST /v1/chat/completions`\n- `trt` (TensorRT-LLM): `POST /v2/models/ensemble/generate_stream`\n- `gserver`: Custom server (not yet implemented in this script)\n- `truss`: `POST /v1/models/model:predict`\n\nIf `--base-url` is provided, requests are sent to it. Otherwise, `--host` and `--port` are used. When `--model` is not provided, the script attempts to query `GET /v1/models` for an available model ID (OpenAI-compatible endpoints only).\n\n### Prerequisites\n\n- Python 3.8+\n- Dependencies typically used by this script: `aiohttp`, `numpy`, `requests`, `tqdm`, `transformers`, and for some datasets `datasets`, `pillow`, `pybase64`. 
Install as needed.\n- An inference server running and reachable via the endpoints above\n- If your server requires authentication, set the environment variable `OPENAI_API_KEY` (used as `Authorization: Bearer <key>`)\n\n### Quick start\n\nRun a basic benchmark against an sglang server exposing `/generate`:\n\n```bash\npython3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct\n```\n\n```bash\npython3 -m sglang.bench_serving \\\n  --backend sglang \\\n  --host 127.0.0.1 --port 30000 \\\n  --num-prompts 1000 \\\n  --model meta-llama/Llama-3.1-8B-Instruct\n```\n\nOr, using an OpenAI-compatible endpoint (completions):\n\n```bash\npython3 -m sglang.bench_serving \\\n  --backend vllm \\\n  --base-url http://127.0.0.1:8000 \\\n  --num-prompts 1000 \\\n  --model meta-llama/Llama-3.1-8B-Instruct\n```\n\n### Datasets\n\nSelect with `--dataset-name`:\n\n- `sharegpt` (default): loads ShareGPT-style pairs; optionally restrict with `--sharegpt-context-len` and override outputs with `--sharegpt-output-len`\n- `random`: random text lengths; sampled from ShareGPT token space\n- `random-ids`: random token ids (can lead to gibberish)\n- `image`: generates images and wraps them in chat messages; supports custom resolutions, multiple formats, and different content types\n- `generated-shared-prefix`: synthetic dataset with shared long system prompts and short questions\n- `mmmu`: samples from MMMU (Math split) and includes images\n\nCommon dataset flags:\n\n- `--num-prompts N`: number of requests\n- `--random-input-len`, `--random-output-len`, `--random-range-ratio`: for random/random-ids/image\n- `--image-count`: number of images per request (for the `image` dataset)\n- `--apply-chat-template`: apply tokenizer chat template when constructing prompts\n- `--dataset-path PATH`: file path for the ShareGPT JSON file; if not set and not found locally, it is downloaded and cached\n\nGenerated Shared Prefix flags (for `generated-shared-prefix`):\n\n- `--gsp-num-groups`\n- 
`--gsp-prompts-per-group`\n- `--gsp-system-prompt-len`\n- `--gsp-question-len`\n- `--gsp-output-len`\n\nImage dataset flags (for `image`):\n\n- `--image-count`: Number of images per request\n- `--image-resolution`: Image resolution; supports presets (4k, 1080p, 720p, 360p) or custom 'heightxwidth' format (e.g., 1080x1920, 512x768)\n- `--image-format`: Image format (jpeg or png)\n- `--image-content`: Image content type (random or blank)\n\n### Examples\n\n1. To benchmark the image dataset with 3 images per request, 500 prompts, an input length of 512, and an output length of 512, run:\n\n```bash\npython -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-3B-Instruct --disable-radix-cache\n```\n\n```bash\npython -m sglang.bench_serving \\\n    --backend sglang-oai-chat \\\n    --dataset-name image \\\n    --num-prompts 500 \\\n    --image-count 3 \\\n    --image-resolution 720p \\\n    --random-input-len 512 \\\n    --random-output-len 512\n```\n\n2. To benchmark the random dataset with 3000 prompts, an input length of 1024, and an output length of 1024, run:\n\n```bash\npython -m sglang.launch_server --model-path Qwen/Qwen2.5-3B-Instruct\n```\n\n```bash\npython3 -m sglang.bench_serving \\\n    --backend sglang \\\n    --dataset-name random \\\n    --num-prompts 3000 \\\n    --random-input-len 1024 \\\n    --random-output-len 1024 \\\n    --random-range-ratio 0.5\n```\n\n### Choosing model and tokenizer\n\n- `--model` is required unless the backend exposes `GET /v1/models`, in which case the first model ID is auto-selected.\n- `--tokenizer` defaults to `--model`. Both can be HF model IDs or local paths.\n- For ModelScope workflows, setting `SGLANG_USE_MODELSCOPE=true` enables fetching via ModelScope (weights are skipped for speed).\n- If your tokenizer lacks a chat template, the script warns because token counting can be less robust for gibberish outputs.\n\n### Rate, concurrency, and streaming\n\n- `--request-rate`: requests per second. `inf` sends all immediately (burst). 
Non-infinite rate uses a Poisson process for arrival times.\n- `--max-concurrency`: caps concurrent in-flight requests regardless of arrival rate.\n- `--disable-stream`: switch to non-streaming mode when supported; TTFT then equals total latency for chat completions.\n\n### Other key options\n\n- `--output-file FILE.jsonl`: append JSONL results to file; auto-named if unspecified\n- `--output-details`: include per-request arrays (generated texts, errors, ttfts, itls, input/output lens)\n- `--extra-request-body '{\"top_p\":0.9,\"temperature\":0.6}'`: merged into payload (sampling params, etc.)\n- `--disable-ignore-eos`: pass through EOS behavior (varies by backend)\n- `--warmup-requests N`: run warmup requests with short output first (default 1)\n- `--flush-cache`: call `/flush_cache` (sglang) before main run\n- `--profile`: call `/start_profile` and `/stop_profile` (requires server to enable profiling, e.g., `SGLANG_TORCH_PROFILER_DIR`)\n- `--lora-name name1 name2 ...`: randomly pick one per request and pass to backend (e.g., `lora_path` for sglang)\n- `--tokenize-prompt`: send integer IDs instead of text (currently supports `--backend sglang` only)\n\n### Authentication\n\nIf your target endpoint requires OpenAI-style auth, set:\n\n```bash\nexport OPENAI_API_KEY=sk-...yourkey...\n```\n\nThe script will add `Authorization: Bearer $OPENAI_API_KEY` automatically for OpenAI-compatible routes.\n\n### Metrics explained\n\nPrinted after each run:\n\n- Request throughput (req/s)\n- Input token throughput (tok/s) - includes both text and vision tokens\n- Output token throughput (tok/s)\n- Total token throughput (tok/s) - includes both text and vision tokens\n- Total input text tokens and Total input vision tokens - per-modality breakdown\n- Concurrency: aggregate time of all requests divided by wall time\n- End-to-End Latency (ms): mean/median/std/p99 per-request total latency\n- Time to First Token (TTFT, ms): mean/median/std/p99 for streaming mode\n- Inter-Token Latency 
(ITL, ms): mean/median/std/p95/p99/max between tokens\n- TPOT (ms): time per output token after the first token, i.e., `(latency - ttft) / (output_tokens - 1)`\n- Accept length (sglang-only, if available): speculative decoding accept length\n\nThe script also retokenizes generated text with the configured tokenizer and reports \"retokenized\" counts.\n\n### JSONL output format\n\nWhen `--output-file` is set, one JSON object is appended per run. Base fields:\n\n- Arguments summary: backend, dataset, request_rate, max_concurrency, etc.\n- Duration and totals: completed, total_input_tokens, total_output_tokens, retokenized totals\n- Throughputs and latency statistics as printed in the console\n- `accept_length` when available (sglang)\n\nWith `--output-details`, an extended object also includes arrays:\n\n- `input_lens`, `output_lens`\n- `ttfts`, `itls` (per request: ITL arrays)\n- `generated_texts`, `errors`\n\n### End-to-end examples\n\n1) sglang native `/generate` (streaming):\n\n```bash\npython3 -m sglang.bench_serving \\\n  --backend sglang \\\n  --host 127.0.0.1 --port 30000 \\\n  --model meta-llama/Llama-3.1-8B-Instruct \\\n  --dataset-name random \\\n  --random-input-len 1024 --random-output-len 1024 --random-range-ratio 0.5 \\\n  --num-prompts 2000 \\\n  --request-rate 100 \\\n  --max-concurrency 512 \\\n  --output-file sglang_random.jsonl --output-details\n```\n\n2) OpenAI-compatible Completions (e.g., vLLM):\n\n```bash\npython3 -m sglang.bench_serving \\\n  --backend vllm \\\n  --base-url http://127.0.0.1:8000 \\\n  --model meta-llama/Llama-3.1-8B-Instruct \\\n  --dataset-name sharegpt \\\n  --num-prompts 1000 \\\n  --sharegpt-output-len 256\n```\n\n3) OpenAI-compatible Chat Completions (streaming):\n\n```bash\npython3 -m sglang.bench_serving \\\n  --backend vllm-chat \\\n  --base-url http://127.0.0.1:8000 \\\n  --model meta-llama/Llama-3.1-8B-Instruct \\\n  --dataset-name random \\\n  --num-prompts 500 \\\n  --apply-chat-template\n```\n\n4) Images (VLM) with chat 
template:\n\n```bash\npython3 -m sglang.bench_serving \\\n  --backend sglang \\\n  --host 127.0.0.1 --port 30000 \\\n  --model your-vlm-model \\\n  --dataset-name image \\\n  --image-count 2 \\\n  --image-resolution 720p \\\n  --random-input-len 128 --random-output-len 256 \\\n  --num-prompts 200 \\\n  --apply-chat-template\n```\n\n4a) Images with custom resolution:\n\n```bash\npython3 -m sglang.bench_serving \\\n  --backend sglang \\\n  --host 127.0.0.1 --port 30000 \\\n  --model your-vlm-model \\\n  --dataset-name image \\\n  --image-count 1 \\\n  --image-resolution 512x768 \\\n  --random-input-len 64 --random-output-len 128 \\\n  --num-prompts 100 \\\n  --apply-chat-template\n```\n\n4b) 1080p images with PNG format and blank content:\n\n```bash\npython3 -m sglang.bench_serving \\\n  --backend sglang \\\n  --host 127.0.0.1 --port 30000 \\\n  --model your-vlm-model \\\n  --dataset-name image \\\n  --image-count 1 \\\n  --image-resolution 1080p \\\n  --image-format png \\\n  --image-content blank \\\n  --random-input-len 64 --random-output-len 128 \\\n  --num-prompts 100 \\\n  --apply-chat-template\n```\n\n5) Generated shared prefix (long system prompts + short questions):\n\n```bash\npython3 -m sglang.bench_serving \\\n  --backend sglang \\\n  --host 127.0.0.1 --port 30000 \\\n  --model meta-llama/Llama-3.1-8B-Instruct \\\n  --dataset-name generated-shared-prefix \\\n  --gsp-num-groups 64 --gsp-prompts-per-group 16 \\\n  --gsp-system-prompt-len 2048 --gsp-question-len 128 --gsp-output-len 256 \\\n  --num-prompts 1024\n```\n\n6) Tokenized prompts (ids) for strict length control (sglang only):\n\n```bash\npython3 -m sglang.bench_serving \\\n  --backend sglang \\\n  --host 127.0.0.1 --port 30000 \\\n  --model meta-llama/Llama-3.1-8B-Instruct \\\n  --dataset-name random \\\n  --tokenize-prompt \\\n  --random-input-len 2048 --random-output-len 256 --random-range-ratio 0.2\n```\n\n7) Profiling and cache flush (sglang):\n\n```bash\npython3 -m sglang.bench_serving \\\n  
--backend sglang \\\n  --host 127.0.0.1 --port 30000 \\\n  --model meta-llama/Llama-3.1-8B-Instruct \\\n  --profile \\\n  --flush-cache\n```\n\n8) TensorRT-LLM streaming endpoint:\n\n```bash\npython3 -m sglang.bench_serving \\\n  --backend trt \\\n  --base-url http://127.0.0.1:8000 \\\n  --model your-trt-llm-model \\\n  --dataset-name random \\\n  --num-prompts 100 \\\n  --disable-ignore-eos\n```\n\n9) Evaluating large-scale KVCache sharing with a Mooncake trace (sglang only). Set `--mooncake-workload` to one of `conversation`, `mooncake`, `agent`, or `synthetic`:\n\n```bash\npython3 -m sglang.bench_serving \\\n  --backend sglang \\\n  --host 127.0.0.1 --port 30000 \\\n  --model model-name \\\n  --dataset-name mooncake \\\n  --mooncake-slowdown-factor 1.0 \\\n  --mooncake-num-rounds 1000 \\\n  --mooncake-workload conversation \\\n  --use-trace-timestamps true \\\n  --random-output-len 256\n```\n\n### Troubleshooting\n\n- All requests failed: verify `--backend`, server URL/port, `--model`, and authentication. Check warmup errors printed by the script.\n- Throughput seems too low: adjust `--request-rate` and `--max-concurrency`; verify server batch size/scheduling; ensure streaming is enabled if appropriate.\n- Token counts look odd: prefer chat/instruct models with proper chat templates; otherwise tokenization of gibberish may be inconsistent.\n- Image/MMMU datasets: ensure you installed extra deps (`pillow`, `datasets`, `pybase64`).\n- Authentication errors (401/403): set `OPENAI_API_KEY` or disable auth on your server.\n\n### Notes\n\n- The script raises the file descriptor soft limit (`RLIMIT_NOFILE`) to help with many concurrent connections.\n- For sglang, `/get_server_info` is queried post-run to report speculative decoding accept length when available.\n"
  },
  {
    "path": "docs/developer_guide/benchmark_and_profiling.md",
    "content": "# Benchmark and Profiling\n\n## Benchmark\n\nSGLang provides four benchmark tools that operate at different levels of the stack. The table below summarizes their key differences:\n\n| Tool                       | HTTP Server                                   | Scheduler                               | Use Case                                                                   |\n| -------------------------- | --------------------------------------------- | --------------------------------------- | -------------------------------------------------------------------------- |\n| `bench_serving`            | Yes (async HTTP client to a running server)   | Yes (indirectly, via server)            | Realistic online serving benchmarks with latency metrics (TTFT, TPOT, ITL) |\n| `bench_one_batch_server`   | Yes (sends HTTP requests to a running server) | Yes (indirectly, via server)            | End-to-end single-batch latency including HTTP and scheduler overhead      |\n| `bench_offline_throughput` | No                                            | Yes (directly uses `Engine` in-process) | Maximum throughput measurement without HTTP overhead                       |\n| `bench_one_batch`          | No                                            | No (directly calls `ModelRunner`)       | Kernel-level latency profiling of a single static batch                    |\n\nUse `bench_serving` by default unless there are specific needs.\n\n**`bench_serving`** is an async HTTP load-testing client that sends requests at controlled rates with configurable concurrency to a running server. It measures realistic online serving metrics including time-to-first-token (TTFT), time-per-output-token (TPOT), inter-token latency (ITL), and throughput. Use `num-prompts >= 5 * max-concurrency` to measure steady-state performance. 
Launch a server with `sglang.launch_server` first.\n\n  ```bash\n  python3 -m sglang.bench_serving --backend sglang --max-concurrency 16 --num-prompts 80 --random-input-len 256 --random-output-len 32 --dataset-name random\n  ```\n\n**`bench_one_batch_server`** sends a single batch as one HTTP request to a running server. Because the workload is only a single batch, the server never reaches a steady state, so the metrics are biased. Launch a server with `sglang.launch_server` first.\n\n  ```bash\n  python3 -m sglang.bench_one_batch_server --base-url http://127.0.0.1:30000 --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --batch-size 32 --input-len 256 --output-len 32\n  ```\n\n**`bench_offline_throughput`** directly instantiates the `Engine` object in-process (no HTTP server) and submits all requests at once via `engine.generate()`. The engine's scheduler handles batching and execution. This measures maximum achievable throughput without any network overhead.\n\n  ```bash\n  python3 -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10\n  ```\n\n**`bench_one_batch`** is the lowest-level tool. It directly instantiates a `ModelRunner` and calls `extend()` / `decode()` on a fixed static batch, bypassing the scheduler entirely. The prefill and decode phases run separately, which makes profiling easier but makes the metrics unrealistic. Because there is no dynamic batching, it may run out of memory for batch sizes that a real server can handle (a real server chunks prefill into smaller batches). 
This is best suited for profiling individual kernel performance.\n\n  ```bash\n  python3 -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --batch-size 32 --input-len 256 --output-len 32\n  ```\n\n## Profile with PyTorch Profiler\n\n[PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) is a convenient basic tool for inspecting kernel execution time, call stacks, kernel overlap, and occupancy.\n\n### Profile a server with `sglang.bench_serving`\n\n```bash\n# set trace path\nexport SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log\n\n# start server\npython -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct\n\n# send profiling request from client\npython -m sglang.bench_serving --backend sglang --model meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile\n```\n\nThe `SGLANG_TORCH_PROFILER_DIR` environment variable must be set on both the server and client side; otherwise, the trace file will not be generated correctly. A reliable way to do this is to set it in your shell's resource file (e.g., `~/.bashrc` for bash).\n\nFor more details, please refer to the [Bench Serving Guide](./bench_serving.md).\n\n### Profile In PD Disaggregation Mode\n\nWhen profiling in PD disaggregation mode, prefill and decode workers **must be profiled separately** due to torch profiler limitations. 
The `bench_serving` command provides dedicated options for this:\n\n#### Profile Prefill Workers\n\n```bash\n# set trace path\nexport SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log\n\n# start prefill and decode servers (see PD disaggregation docs for setup)\npython -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill\npython -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1\n\n# start router\npython -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000\n\n# send profiling request targeting prefill workers\npython -m sglang.bench_serving --backend sglang --model meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile --pd-separated --profile-prefill-url http://127.0.0.1:30000\n```\n\n#### Profile Decode Workers\n\n```bash\n# send profiling request targeting decode workers\npython -m sglang.bench_serving --backend sglang --model meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile --pd-separated --profile-decode-url http://127.0.0.1:30001\n```\n\n#### Important Notes\n\n- `--profile-prefill-url` and `--profile-decode-url` are **mutually exclusive** - you cannot profile both at the same time\n- Both options support multiple worker URLs for multi-instance setups:\n  ```bash\n  # Profile multiple prefill workers\n  python -m sglang.bench_serving --backend sglang --model meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --profile --pd-separated --profile-prefill-url http://127.0.0.1:30000 http://127.0.0.1:30002\n\n  # Profile multiple decode workers\n  python -m sglang.bench_serving --backend sglang --model meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --profile --pd-separated --profile-decode-url http://127.0.0.1:30001 http://127.0.0.1:30003\n  ```\n- Make sure `SGLANG_TORCH_PROFILER_DIR` 
is set on all worker nodes before starting the servers\n- For more details on setting up PD disaggregation, see [PD Disaggregation Guide](../advanced_features/pd_disaggregation.md)\n\n### Profile offline with `sglang.bench_one_batch` and `sglang.bench_offline_throughput`\n\nThese tools run the model in-process, so no server needs to be launched.\n\n```bash\nexport SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log\n\n# profile one batch with bench_one_batch.py\n# batch size can be controlled with the --batch-size argument\npython3 -m sglang.bench_one_batch --model-path meta-llama/Llama-3.1-8B-Instruct --batch-size 32 --input-len 1024 --output-len 10 --profile\n\n# profile multiple batches with bench_offline_throughput.py\npython -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8\n```\n\n### Profile a server with `sglang.profiler`\n\nWhen the server is running (e.g., processing a decoding request), you can start live profiling immediately by sending a profile request to the server.\n\nYou can do this by running `python3 -m sglang.profiler`. For example:\n\n```bash\n# Terminal 1: Send a generation request\npython3 -m sglang.test.send_one\n\n# Terminal 2: Before the above request finishes, quickly launch the following command in a separate terminal.\n# It will generate a profile of the above request for several decoding batches.\npython3 -m sglang.profiler\n```\n\nYou can also combine the above operations into a single command:\n\n```bash\npython3 -m sglang.test.send_one --profile\n```\n\n### Profile a server with HTTP API endpoints\n\nSGLang provides HTTP API endpoints to control profiling on a running server. This allows you to start and stop profiling programmatically, which is useful for capturing specific workload patterns.\n\n#### Using `/start_profile` endpoint\n\nThe `/start_profile` endpoint starts profiling on the server. 
You can control when profiling begins and how long it runs using the following parameters:\n\n**Basic usage:**\n\n```bash\n# Start profiling immediately for 10 steps\ncurl -X POST http://127.0.0.1:30000/start_profile \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"num_steps\": 10\n  }'\n```\n\n**Parameters:**\n\n- `output_dir` (optional): Directory where profile traces will be saved. If not specified, uses `SGLANG_TORCH_PROFILER_DIR` environment variable, or `/tmp` as the default\n- `num_steps` (optional): Number of steps to profile. If not specified, profiling continues until manually stopped with `/end_profile`\n- `start_step` (optional): Step number at which to start profiling (inclusive). Useful for skipping warmup iterations\n- `activities` (optional): List of activities to profile, e.g., `[\"CPU\", \"GPU\"]`. Default is `[\"CPU\", \"GPU\"]`\n- `merge_profiles` (optional): Whether to merge distributed traces. Default is `false`\n\n**Note on step ranges:** Profiling starts at `start_step` (inclusive) and continues for `num_steps` iterations. 
For example, with `start_step=3` and `num_steps=10`, profiling captures steps 3, 4, 5, 6, 7, 8, 9, 10, 11, and 12 (10 steps total, starting from step 3).\n\n**Advanced usage with `start_step`:**\n\n```bash\n# Wait 5 steps (warmup), then profile for 10 steps\ncurl -X POST http://127.0.0.1:30000/start_profile \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"output_dir\": \"/tmp/profiles\",\n    \"start_step\": 5,\n    \"num_steps\": 10,\n    \"activities\": [\"CPU\", \"GPU\"]\n  }'\n```\n\n**Continuous profiling (manual stop):**\n\n```bash\n# Start profiling without num_steps - must manually stop with /end_profile\ncurl -X POST http://127.0.0.1:30000/start_profile\n```\n\n#### Using `/end_profile` endpoint\n\nThe `/end_profile` endpoint stops an ongoing profiling session and saves the trace file.\n\n```bash\n# Stop profiling and save traces\ncurl -X POST http://127.0.0.1:30000/end_profile\n```\n\nThis is only needed when you start profiling without specifying `num_steps`. If `num_steps` is specified, profiling will automatically stop after that many steps.\n\n#### Example workflow\n\n```bash\n# Terminal 1: Start the server\nexport SGLANG_TORCH_PROFILER_DIR=/tmp/profiles\npython -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct\n\n# Terminal 2: Start continuous profiling\ncurl -X POST http://127.0.0.1:30000/start_profile \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"start_step\": 3\n  }'\n\n# Terminal 3: Send requests to generate load\npython -m sglang.bench_serving --backend sglang --num-prompts 100\n\n# Terminal 2: Stop profiling when done\ncurl -X POST http://127.0.0.1:30000/end_profile\n```\n\n### Profiler Trace Merger for Distributed Traces\n\nSGLang now supports automatic merging of profiling traces from distributed setups with multiple parallelism types (TP, DP, PP, EP). 
This feature is particularly useful for analyzing performance across distributed runs.\n\n#### Multi-Node Profiling and Shared Storage Considerations\n\nMerging profiler output on a single node is fully supported. When profiling in distributed environments that span multiple nodes, the output directory should be on shared storage (e.g., NFS, Lustre) accessible by all nodes so that the trace files can be merged.\n\nWithout shared storage accessible across nodes, automatic merging of trace files during profiling is currently not supported.\n\n#### HTTP API Usage\n\n```bash\n# Start profiling with automatic trace merging enabled\n# output_dir: where to store profile traces\n# merge_profiles: optional, merges profile traces (default: false)\ncurl -X POST <BASE_URL>/start_profile \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"output_dir\": \"/tmp/profiles\",\n    \"num_steps\": 10,\n    \"activities\": [\"CPU\", \"GPU\"],\n    \"merge_profiles\": true\n  }'\n```\n\n#### Command Line Usage\n\n```bash\n# Start profiling with merge enabled\npython -m sglang.profiler \\\n  --num-steps 10 \\\n  --cpu \\\n  --gpu \\\n  --output-dir /tmp/profiles \\\n  --merge-profiles # optional argument to merge profile traces (default=False)\n```\n\n#### Output Files\n\nThe profile merger generates:\n- Individual rank trace files: `{profile_id}-TP-{tp}-DP-{dp}-PP-{pp}-EP-{ep}.trace.json.gz`\n- Merged trace file: `merged-{profile_id}.trace.json.gz`\n\n### Possible PyTorch bugs\nIf you encounter the following error (for example, when serving Qwen2.5-VL):\n```text\nRuntimeError: !stack.empty() INTERNAL ASSERT FAILED at \"/pytorch/torch/csrc/autograd/profiler_python.cpp\":983, please report a bug to PyTorch. Python replay stack is empty.\n```\nThis is likely a PyTorch bug reported in [Bug: vLLM Profiler](https://github.com/vllm-project/vllm/issues/18240) and [Bug: torch.profiler.profile](https://github.com/pytorch/pytorch/issues/101632). 
As a workaround, you can disable `with_stack` by setting an environment variable as follows:\n```bash\nexport SGLANG_PROFILE_WITH_STACK=False\npython -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8\n```\n\n### View traces\n\nTrace files can be loaded and visualized from:\n\n1. https://ui.perfetto.dev/ (any browser)\n2. chrome://tracing (Chrome browser only)\n\nIf the browser cannot open a trace file because it is too large, the client can generate a smaller trace file (<100MB) by limiting the number of prompts and the output length.\nFor example, when profiling a server:\n\n```bash\npython -m sglang.bench_serving --backend sglang --model meta-llama/Llama-3.1-8B-Instruct --num-prompts 2 --sharegpt-output-len 100 --profile\n```\n\nThis command sets the number of prompts to 2 with the `--num-prompts` argument and limits the output length to 100 tokens with the `--sharegpt-output-len` argument, producing a trace file small enough for the browser to open smoothly.\n\nAdditionally, to map CUDA kernels in the trace back to the SGLang Python source code, disable CUDA Graph by passing `--disable-cuda-graph` when launching the server.\n\n## Profile with Nsight\n\n[Nsight Systems](https://docs.nvidia.com/nsight-systems/) is an advanced tool that exposes more profiling details, such as register and shared memory usage, annotated code regions, and low-level CUDA APIs and events.\n\n1. 
Prerequisite:\n\n   Install using apt, or run inside an [NVIDIA Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags) or the [SGLang Docker container](https://github.com/sgl-project/sglang/tree/main/docker).\n\n   ```bash\n   # install nsys\n   # https://docs.nvidia.com/nsight-systems/InstallationGuide/index.html\n   apt update\n   apt install -y --no-install-recommends gnupg\n   echo \"deb http://developer.download.nvidia.com/devtools/repos/ubuntu$(source /etc/lsb-release; echo \"$DISTRIB_RELEASE\" | tr -d .)/$(dpkg --print-architecture) /\" | tee /etc/apt/sources.list.d/nvidia-devtools.list\n   apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub\n   apt update\n   apt install nsight-systems-cli\n   ```\n\n2. To profile a single batch, use:\n\n   ```bash\n   nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node python3 -m sglang.bench_one_batch --model meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512\n   ```\n\n3. To profile a server:\n\n   ```bash\n   # launch the server; set the delay and duration according to your needs\n   # once the duration has elapsed, nsys kills the server\n\n   nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node -o sglang.out --delay 60 --duration 70 python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache\n\n   # client\n   python3 -m sglang.bench_serving --backend sglang --num-prompts 1000 --dataset-name random --random-input 1024 --random-output 512\n   ```\n\n   In practice, we recommend setting the `--duration` argument to a large value. 
When you want to stop profiling, first run:\n\n   ```bash\n   nsys sessions list\n   ```\n\n   to get the session ID in the form `profile-XXXXX`, then run:\n\n   ```bash\n   nsys stop --session=profile-XXXXX\n   ```\n\n   to stop the profiler manually and generate the `nsys-rep` file immediately.\n\n4. Use NVTX to annotate code regions, e.g., to see their execution time.\n\n   ```bash\n   # install nvtx\n   pip install nvtx\n   ```\n\n   ```python\n   import nvtx\n\n   # wrap a code region in a named, colored NVTX range so it appears in the nsys timeline\n   with nvtx.annotate(\"description\", color=\"blue\"):\n       total = sum(i * i for i in range(1_000_000))  # replace with the critical code to time\n   ```\n\n### Layer-wise NVTX Profiling with Nsight Systems\n\nSGLang provides built-in layerwise NVTX annotations that can be combined with the CUDA Profiler for detailed per-layer profiling in Nsight Systems. This is particularly useful for identifying performance bottlenecks at the layer level.\n\n#### Using `--enable-layerwise-nvtx-marker` with Nsight Systems and `/start_profile`\n\nThe `--enable-layerwise-nvtx-marker` flag automatically adds NVTX markers to every layer in your model. This is particularly powerful when combined with Nsight Systems profiling to see detailed per-layer performance.\n\n**Method 1: Using `/start_profile` with CUDA_PROFILER (for programmatic control)**\n\nThis method allows you to control exactly when profiling starts/stops via HTTP API while Nsight Systems is running.\n\n1. Launch the server with layerwise NVTX enabled under Nsight Systems:\n\n   ```bash\n   # Terminal 1: Start server with nsys and capture-range option\n   nsys profile --trace-fork-before-exec=true \\\n     --cuda-graph-trace=node \\\n     --capture-range=cudaProfilerApi \\\n     --capture-range-end=stop \\\n     -o layerwise_profile \\\n     python -m sglang.launch_server \\\n       --model-path meta-llama/Llama-3.1-8B-Instruct \\\n       --enable-layerwise-nvtx-marker \\\n       --disable-cuda-graph\n   ```\n\n   Note: NVTX markers are not emitted for kernel launches captured by CUDA graphs. 
Use `--disable-cuda-graph` to ensure all layerwise NVTX markers are emitted in the trace.\n\n2. In another terminal, control profiling via `/start_profile` with `CUDA_PROFILER` activity:\n\n   ```bash\n   # Terminal 2: Wait for server to be ready, then start CUDA profiling\n   # Wait 3 steps for warmup, then profile for 10 steps\n   curl -X POST http://127.0.0.1:30000/start_profile \\\n     -H \"Content-Type: application/json\" \\\n     -d '{\n       \"start_step\": 3,\n       \"num_steps\": 10,\n       \"activities\": [\"CUDA_PROFILER\"]\n     }'\n   ```\n\n3. Send requests to generate load:\n\n   ```bash\n   # Terminal 3: Generate workload\n   python -m sglang.bench_serving --backend sglang --num-prompts 100\n   ```\n\n4. Profiling will automatically stop after 10 steps (due to `num_steps: 10`). If you hadn't specified `num_steps`, you would need to manually stop it:\n\n   ```bash\n   # Terminal 2: Only needed if num_steps was not specified\n   curl -X POST http://127.0.0.1:30000/end_profile\n   ```\n\nThe `--capture-range=cudaProfilerApi` option tells Nsight Systems to only capture data between `cudaProfilerStart()` and `cudaProfilerStop()` calls (triggered by `/start_profile` and `/end_profile`), reducing overhead and file size. 
The `start_step` parameter skips the first 3 steps to avoid capturing warmup overhead.\n\n**Method 2: Simpler approach without `/start_profile` API**\n\nFor simpler use cases where you don't need fine-grained control over profiling start/stop, you can run the server under Nsight Systems and capture the entire workload:\n\n```bash\n# Terminal 1: Start server with layerwise NVTX under nsys\n# Note: --disable-cuda-graph ensures all NVTX markers are emitted\nnsys profile --trace-fork-before-exec=true \\\n  --cuda-graph-trace=node \\\n  -o layerwise_profile \\\n  python -m sglang.launch_server \\\n    --model-path meta-llama/Llama-3.1-8B-Instruct \\\n    --enable-layerwise-nvtx-marker \\\n    --disable-cuda-graph\n\n# Terminal 2: Generate load with the benchmarking client\npython -m sglang.bench_serving --backend sglang --num-prompts 10\n```\n\nThe NVTX markers and GPU kernels come from the server process, so the server (not the client) must run under `nsys`. This approach captures the server's whole lifetime, including startup and warmup. When the benchmark finishes, stop the session with `nsys sessions list` and `nsys stop` (as described above) to write the report; the layerwise NVTX markers will then be visible in the Nsight Systems timeline.\n\n**Viewing the profiling results:**\n\nOpen the generated report with Nsight Systems (recent versions write `.nsys-rep` files; older versions write `.qdrep`):\n\n```bash\nnsys-ui layerwise_profile.nsys-rep\n```\n\nIn the Nsight Systems GUI, you'll see:\n- **NVTX ranges**: Each layer appears as a labeled range in the timeline with detailed information in the marker metadata\n- **CUDA kernels**: All GPU kernels are shown alongside the layer annotations\n- **Layer hierarchy**: The full module path (e.g., `meta-llama/Llama-3.1-8B-Instruct.model.layers.0.self_attn.qkv_proj`) helps identify specific layers. 
The prefix uses the full model path from `--model-path`.\n- **Tensor shapes**: Input/output dimensions and parameter shapes are included in the NVTX marker data\n\n**Benefits of layerwise NVTX profiling:**\n\n- **Granular visibility**: See exactly which layers are taking the most time\n- **Memory tracking**: Identify layers with large memory allocations\n- **Bottleneck identification**: Quickly locate inefficient operations\n- **Communication overhead**: In multi-GPU setups, see per-layer communication costs\n- **Development debugging**: Validate that model architecture changes have the expected performance impact\n\n## Other tips\n\n1. You can benchmark a model using dummy weights by providing only the `config.json` file. This allows quick testing of model variants without trained weights. To do so, add `--load-format dummy` to the above commands; you then only need a correct `config.json` under the checkpoint folder.\n2. You can benchmark a model with modified configs (e.g., fewer layers) by using `--json-model-override-args`. For example, you can benchmark a model with only 1 layer and 1 KV head using:\n\n   ```bash\n   python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --batch-size 32 --input-len 256 --output-len 32 --load-format dummy --json-model-override-args '{\"num_hidden_layers\": 1, \"num_key_value_heads\": 1}'\n   ```\n\n3. With Nsight Systems, you can pass `--python-backtrace=cuda` to `nsys profile` to see the Python call stack for all CUDA kernels, similar to PyTorch Profiler. (Caveat: this can make kernel runtimes measured with CUDA-event-based timing inaccurately long.)\n4. For more arguments, see the [Nsight Systems User Guide](https://docs.nvidia.com/nsight-systems/UserGuide/index.html).\n"
  },
  {
    "path": "docs/developer_guide/contribution_guide.md",
    "content": "# Contribution Guide\n\nWelcome to **SGLang**! We appreciate your interest in contributing. This guide provides a concise overview of how to set up your environment, run tests, build documentation, and open a Pull Request (PR). Whether you’re fixing a small bug or developing a major feature, we encourage following these steps for a smooth contribution process.\n\n## Install SGLang from Source\n\n### Fork and clone the repository\n\n**Note**: New contributors do **not** have the write permission to push to the official SGLang repo. Please fork the repository under your GitHub account, then clone your fork locally.\n\n```bash\ngit clone https://github.com/<your_user_name>/sglang.git\n```\n\n### Build from source\n\nRefer to [Install SGLang from Source](../get_started/install.md#method-2-from-source).\n\n## Format code with pre-commit\n\nWe use [pre-commit](https://pre-commit.com/) to maintain consistent code style checks. Before pushing your changes, please run:\n\n```bash\npip3 install pre-commit\npre-commit install\npre-commit run --all-files\n```\n\n- **`pre-commit run --all-files`** manually runs all configured checks, applying fixes if possible. If it fails the first time, re-run it to ensure lint errors are fully resolved. Make sure your code passes all checks **before** creating a Pull Request.\n- **Do not commit** directly to the `main` branch. 
Always create a new branch (e.g., `feature/my-new-feature`), push your changes, and open a PR from that branch.\n\n## Run and add unit tests\n\nIf you add a new feature or fix a bug, please add corresponding unit tests to ensure coverage and prevent regression.\nSGLang uses Python's built-in [unittest](https://docs.python.org/3/library/unittest.html) framework with [pytest](https://docs.pytest.org/) as the test runner.\n\n### Unit tests (no server required)\n\nUnit tests live under [`test/registered/unit/`](https://github.com/sgl-project/sglang/tree/main/test/registered/unit), organized to mirror the `python/sglang/srt/` source tree. These tests validate component logic **without** launching a server or loading real model weights.\n\n**When to add a unit test:** If you modify a file under `python/sglang/srt/`, check whether a corresponding test exists in `test/registered/unit/` and add coverage for your changes. For example:\n\n```text\nsrt/mem_cache/radix_cache.py    →  unit/mem_cache/test_radix_cache.py\nsrt/sampling/sampling_params.py →  unit/sampling/test_sampling_params.py\n```\n\n**Run unit tests locally:**\n\n```bash\npytest test/registered/unit/ -v                # all unit tests\npytest test/registered/unit/mem_cache/ -v      # one module\n```\n\n**Run with coverage:**\n\n```bash\npytest test/registered/unit/ --cov --cov-config=.coveragerc -v\n```\n\nFor conventions on CI registration, test structure, and examples, see [`test/registered/unit/README.md`](https://github.com/sgl-project/sglang/tree/main/test/registered/unit/README.md).\n\n### E2E tests (server required)\n\nFor tests that require launching a server, refer to [`test/registered/README.md`](https://github.com/sgl-project/sglang/tree/main/test/registered/README.md) for guidance on where to place your test.\n\nFor detailed instructions on running tests and integrating them into CI, refer to [test/README.md](https://github.com/sgl-project/sglang/tree/main/test/README.md).\n\n## Write documentation\n\nWe 
recommend that new contributors start by writing documentation, which helps them quickly understand the SGLang codebase.\nFor more details, please refer to [docs/README.md](https://github.com/sgl-project/sglang/tree/main/docs/README.md).\n\n## Test the accuracy\nIf your code changes the model output, please run the accuracy tests. A quick sanity check is the few-shot GSM8K.\n\n```bash\n# Launch a server\npython3 -m sglang.launch_server --model Qwen/Qwen2-7B-Instruct\n\n# Evaluate\npython3 -m sglang.test.few_shot_gsm8k --num-questions 200\n```\n\nPlease note that the above script is primarily a sanity check, not a rigorous accuracy or speed test.\nThis test can have significant variance (1%–5%) in accuracy due to batching and the non-deterministic nature of the inference engine.\nAlso, do not rely on the \"Latency/Output throughput\" from this script, as it is not a proper speed test.\n\nGSM8K is too easy for today's state-of-the-art models. Please also run your own, more challenging accuracy tests.\nYou can find additional accuracy eval examples in:\n- [test_eval_accuracy_large.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_eval_accuracy_large.py)\n- [test_gpt_oss_1gpu.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_gpt_oss_1gpu.py)\n\n## Benchmark the speed\nRefer to [Benchmark and Profiling](../developer_guide/benchmark_and_profiling.md).\n\n## Requesting a review for merge\nYou can follow the pull request merge process described in [MAINTAINER.md](https://github.com/sgl-project/sglang/blob/main/.github/MAINTAINER.md).\nYou will need to work with the Merge Oncall, Codeowner, and other reviewers to get their approvals before your PR can be merged.\n\n## How to Trigger CI Tests\n\nWe have a lot of open PRs but limited CI machines, so only top and trusted contributors have permission to trigger CI tests.\nUsers with permission are listed in the 
[CI_PERMISSIONS.json](https://github.com/sgl-project/sglang/blob/main/.github/CI_PERMISSIONS.json).\n\n**PR authors** can always use `/rerun-failed-ci` on their own PRs, even if they are not listed in `CI_PERMISSIONS.json`.\n\nFor CI to run on a pull request, it must have the \"run-ci\" label. Authorized users can add the label or rerun failed tests by commenting on the PR with one of these commands:\n\n- `/tag-run-ci-label`: Adds the \"run-ci\" label. Every future commit will trigger CI.\n- `/rerun-failed-ci`: Reruns the failed or flaky tests from the most recent commit.\n- `/tag-and-rerun-ci`: A single command that performs both `/tag-run-ci-label` and `/rerun-failed-ci`.\n- `/rerun-stage <stage-name>`: Reruns a specific test stage without waiting for its dependencies. This is useful when you want to quickly validate a fix for a specific test failure instead of waiting ~30 minutes for preceding stages to complete.\n\nIf you have permission, the [Slash Command Handler](https://github.com/sgl-project/sglang/actions/workflows/slash-command-handler.yml) will run your command and react with a 👍 to your comment. It may take a few minutes for the reaction to appear. Here’s a usage [example](https://github.com/sgl-project/sglang/pull/14253#issuecomment-3599509302).\n\nTo avoid spamming a PR with too many `/rerun-failed-ci` comments, you can also trigger the command by editing an existing comment and adding any suffix (e.g., `/rerun-failed-ci try again`).\n\nExample of rerunning a single test stage: `/rerun-stage unit-test-backend-4-gpu`.\n\nIf you don’t have permission and you’re not the PR author, please ask maintainers to trigger CI for you.\n\n### CI rate limits\n\nDue to CI scheduling and limited resources, higher-priority PRs may preempt running jobs. 
In such cases, you may need to rerun the tests.\n\nWe apply CI rate limits to prevent abuse and ensure fair usage of our CI resources.\n\nEach CI workflow has a default limit defined in its workflow configuration file. For example, in [pr-gate.yml](https://github.com/sgl-project/sglang/blob/main/.github/workflows/pr-gate.yml), the default cooldown period is 120 minutes, and each workflow can override it via the `cool-down-minutes` input parameter:\n\n```yaml\ncool-down-minutes:\n  description: \"Default cooldown period in minutes; 0 disables rate limiting\"\n  type: number\n  default: 120\n```\n\nUsers listed in [CI_PERMISSIONS.json](https://github.com/sgl-project/sglang/blob/main/.github/CI_PERMISSIONS.json) may have a per-user cooldown interval. In practice, we use the minimum of the workflow’s default window and the user-specific interval.\n\n\n## Code style guidance\n- Avoid code duplication. If the same code snippet (more than five lines) appears multiple times, extract it into a shared function.\n- Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible. Use vectorized code.\n- Prioritize extreme efficiency. SGLang is a runtime, and most of your code runs on the critical path for every request. Optimize all minor overheads as much as possible, especially in the model forward code.\n  - A common pattern is some runtime checks in the model forward pass (e.g., [this](https://github.com/sgl-project/sglang/blob/f1b0eda55c2c4838e8ab90a0fac7fb1e3d7064ab/python/sglang/srt/models/deepseek_v2.py#L486-L491)). These are very likely the same for every layer. Please cache the result as a single boolean value whenever possible.\n- Make functions as pure as possible. Avoid in-place modification of arguments.\n- Keep files concise. If a file exceeds 2,000 lines of code, split it into multiple smaller files. 
(e.g., `scheduler.py`, `scheduler_output_processor_mixin.py`)\n- Keep tests fast.\n  - If a single test file runs longer than 500 seconds, split it into multiple smaller files (e.g., `test_eagle_infer_a.py`, `test_eagle_infer_b.py`).\n  - If a single job in a GitHub workflow runs longer than 30 minutes, split it into smaller jobs/steps.\n  - Reuse server launches in your unit tests to make tests run faster.\n- When supporting new hardware or features, follow these guidelines:\n  - Do not drastically change existing code.\n  - Always prefer new files to introduce specific components for your new hardware (e.g., `allocator_ascend.py`).\n  - If you write multiple if/else blocks for new features, ensure the common path (e.g., NVIDIA hardware or the existing code path) is the first branch.\n\n## How to update sgl-kernel\nSince sglang and the `sglang-kernel` (formerly `sgl-kernel`) distribution are separate Python packages, our current GitHub CI infrastructure does not support updating a kernel and using it immediately within the same pull request (PR).\nTo add a new kernel or modify an existing one in the `sgl-kernel/` source tree, you must use multiple PRs.\n\nFollow these steps:\n\n1. Submit a PR that updates the sgl-kernel source code without using it in the sglang Python package (e.g., [#8884](https://github.com/sgl-project/sglang/pull/8884/files)).\n2. Bump the version of the kernel package (e.g., [#9220](https://github.com/sgl-project/sglang/pull/9220/files)).\n   - Once merged, this will trigger an automatic release of the `sglang-kernel` wheel to PyPI.\n   - If it is not urgent, you can wait for someone else to release the wheel. A new version is typically released within one week.\n3. 
Apply the changes:\n   - Update the `sglang-kernel` version in `sglang/python/pyproject.toml` to use the modified kernels.\n   - Update the related caller code in sglang to use the new kernel.\n\n## Tips for newcomers\n\nIf you want to contribute but don’t have a specific idea in mind, pick issues labeled [“good first issue” or “help wanted”](https://github.com/sgl-project/sglang/issues?q=is%3Aissue+label%3A%22good+first+issue%22%2C%22help+wanted%22). These tasks typically have lower complexity and provide an excellent introduction to the codebase. Also check out this [code walk-through](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/tree/main/sglang/code-walk-through) for a deeper look into SGLang’s workflow.\n\nIf you have any questions or want to start a discussion, please feel free to ask in our [Slack channel](https://slack.sglang.io).\n\nThank you for your interest in SGLang. Happy coding!\n"
  },
  {
    "path": "docs/developer_guide/development_guide_using_docker.md",
    "content": "# Development Guide Using Docker\n\n## Setup VSCode on a Remote Host\n(Optional: you can skip this step if you plan to run the sglang dev container locally.)\n\n1. On the remote host, download the `code` CLI from the [VS Code download page](https://code.visualstudio.com/download) and run `code tunnel` in a shell.\n\nExample\n```bash\nwget https://vscode.download.prss.microsoft.com/dbazure/download/stable/fabdb6a30b49f79a7aba0f2ad9df9b399473380f/vscode_cli_alpine_x64_cli.tar.gz\ntar xf vscode_cli_alpine_x64_cli.tar.gz\n\n# https://code.visualstudio.com/docs/remote/tunnels\n./code tunnel\n```\n\n2. On your local machine, press F1 in VSCode and choose \"Remote Tunnels: Connect to Tunnel\".\n\n## Setup Docker Container\n\n### Option 1. Use the default dev container automatically from VSCode\nThe sglang repository root contains a `.devcontainer` folder that lets VSCode start up automatically inside a dev container. You can read more about this VSCode extension in the official VSCode documentation, [Developing inside a Container](https://code.visualstudio.com/docs/devcontainers/containers).\n![image](https://github.com/user-attachments/assets/6a245da8-2d4d-4ea8-8db1-5a05b3a66f6d)\n(*Figure 1: Diagram from VSCode official documentation [Developing inside a Container](https://code.visualstudio.com/docs/devcontainers/containers).*)\n\nTo enable this, you only need to:\n1. Start Visual Studio Code and install the [VSCode dev container extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers).\n2. Press F1, then type and choose \"Dev Containers: Open Folder in Container\".\n3. Enter the path to your local `sglang` repo and press Enter.\n\nOpening the folder in the dev container for the first time may take longer because of the Docker pull and build. 
Once it succeeds, the status bar at the bottom left should show that you are in a dev container:\n\n![image](https://github.com/user-attachments/assets/650bba0b-c023-455f-91f9-ab357340106b)\n\nNow when you run `sglang.launch_server` in the VSCode terminal or start debugging with F5, the sglang server starts in the dev container with all your local changes applied automatically:\n\n![image](https://github.com/user-attachments/assets/748c85ba-7f8c-465e-8599-2bf7a8dde895)\n\n\n### Option 2. Start up containers manually (advanced)\n\nThe following startup command is an example for internal development by the SGLang team. You can **modify or add directory mappings as needed**, especially for model weight downloads, to prevent repeated downloads by different Docker containers.\n\n❗️ **Note on RDMA**\n\n1. `--network host` and `--privileged` are required by RDMA. If you don't need RDMA, you can remove them, but keeping them does no harm. Thus, we enable these two flags by default in the commands below.\n2. You may need to set `NCCL_IB_GID_INDEX` if you are using RoCE, for example: `export NCCL_IB_GID_INDEX=3`.\n\n```bash\n# Change the name to yours\ndocker run -itd --shm-size 32g --gpus all -v <volumes-to-mount> --ipc=host --network=host --privileged --name sglang_dev lmsysorg/sglang:dev /bin/zsh\ndocker exec -it sglang_dev /bin/zsh\n```\nSome useful volumes to mount are:\n1. **Hugging Face model cache**: mounting the model cache avoids re-downloading models every time Docker restarts. The default location on Linux is `~/.cache/huggingface/`.\n2. **SGLang repository**: code changes in the local SGLang repository are automatically synced to the dev container.\n\nExample 1: Mounting the local cache folder `/opt/dlami/nvme/.cache` but not the SGLang repo. 
Use this when you prefer to manually transfer local code changes to the dev container.\n```bash\ndocker run -itd --shm-size 32g --gpus all -v /opt/dlami/nvme/.cache:/root/.cache --ipc=host --network=host --privileged --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh\ndocker exec -it sglang_zhyncs /bin/zsh\n```\nExample 2: Mounting both the Hugging Face cache and the local SGLang repo. Local code changes are automatically synced to the dev container because SGLang is installed in editable mode in the dev image.\n```bash\ndocker run -itd --shm-size 32g --gpus all -v $HOME/.cache/huggingface/:/root/.cache/huggingface -v $HOME/src/sglang:/sgl-workspace/sglang --ipc=host --network=host --privileged --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh\ndocker exec -it sglang_zhyncs /bin/zsh\n```\n## Debug SGLang with VSCode Debugger\n1. Open `launch.json` in VSCode (create it if it does not exist).\n2. Add the following config and save. You can edit the config as needed to pass different arguments or debug a different program (e.g., a benchmark script).\n     ```JSON\n       {\n          \"version\": \"0.2.0\",\n          \"configurations\": [\n              {\n                  \"name\": \"Python Debugger: launch_server\",\n                  \"type\": \"debugpy\",\n                  \"request\": \"launch\",\n                  \"module\": \"sglang.launch_server\",\n                  \"console\": \"integratedTerminal\",\n                  \"args\": [\n                      \"--model-path\", \"meta-llama/Llama-3.2-1B\",\n                      \"--host\", \"0.0.0.0\",\n                      \"--port\", \"30000\",\n                      \"--trust-remote-code\",\n                  ],\n                  \"justMyCode\": false\n              }\n          ]\n      }\n    ```\n\n3. Press \"F5\" to start. 
The VSCode debugger pauses the program at breakpoints even when it runs on a remote SSH/Tunnel host inside a dev container.\n\n## Profile\n\n```bash\n# Adjust the batch size, input length, and output length as needed, and add `--disable-cuda-graph` for easier analysis\n# e.g., DeepSeek V3\nnsys profile -o deepseek_v3 python3 -m sglang.bench_one_batch --batch-size 1 --input 128 --output 256 --model deepseek-ai/DeepSeek-V3 --trust-remote-code --tp 8 --disable-cuda-graph\n```\n\n## Evaluation\n\n```bash\n# e.g., GSM8K 8-shot\npython3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8\n```\n"
  },
  {
    "path": "docs/developer_guide/development_jit_kernel_guide.md",
    "content": "# Development Guide for JIT Kernels\n\n## Environment Setup\n\nWe strongly recommend using `clangd` as the language server for JIT kernel development.\nFor Ubuntu/Debian, you can download clangd from [apt.llvm.org](https://apt.llvm.org/).\nIf you are using VS Code, we recommend installing the `clangd` extension for better IDE integration.\n\nAll JIT-related files are located in `python/sglang/jit_kernel`.\nUnlike `sgl-kernel`, which compiles CUDA/C++ binaries ahead of time (AOT), just-in-time (JIT) kernels are compiled at runtime.\nConsequently, a static `compile_commands.json` cannot be generated.\nTo enable code completion with `clangd`, run `python -m sglang.jit_kernel` to generate a `.clangd` configuration file in your current directory.\nAfter generating the file, restart the clangd language server. It should now recognize all JIT kernel files.\n\n## Code Structure\n\n### C++ Implementation\n\nC++ source code is located in `python/sglang/jit_kernel/csrc`.\nReusable functions should be placed in `python/sglang/jit_kernel/include`.\n\nWe use [tvm-ffi](https://github.com/apache/tvm-ffi) for efficient foreign language bindings.\nRefer to the [documentation](https://tvm.apache.org/ffi/) for advanced usage, such as exporting C++ objects.\nTypically, `tvm::ffi::TensorView` is sufficient for passing PyTorch Tensors from Python.\n\n### Python Interface\n\nPython interfaces are defined in `python/sglang/jit_kernel`.\nThe `load_jit` utility function in `python/sglang/jit_kernel/utils.py` loads and returns the compiled module.\nTo export a C++ function (e.g., `cpp_func`), pass `cuda_wrappers=[(\"func\", \"cpp_func\")]` to `load_jit`.\nThe function can then be called in Python as `module.func`.\n\nFor caching compiled modules, prefer `sglang.jit_kernel.utils.cache_once` over `functools.lru_cache`.\n`functools.lru_cache` is not compatible with `torch.compile`.\n\n### C++ Utilities\n\nThe following C++ utilities are available:\n\n#### Integer Range\n\nSimilar 
to PyTorch, we provide an `irange` function to represent an integer range.\n\n```C++\n#include <sgl_kernel/utils.h>\n\nvoid test() {\n  for (auto i : host::irange(100)) { // [0, 100)\n    // do something\n  }\n  for (auto i : host::irange(0, 100)) { // [0, 100)\n    // do something\n  }\n}\n\n```\n\n#### Runtime Checking\n\n`RuntimeCheck` validates conditions at runtime. It accepts optional arguments for error reporting.\nIf the check fails, these arguments are output to aid debugging.\n`RuntimeDeviceCheck` verifies the status of the last kernel launch.\n\n```C++\n#include <sgl_kernel/utils.h>\n#include <sgl_kernel/utils.cuh>\n\nvoid test() {\n  host::RuntimeCheck(1 + 1 == 2, 1 + 1, \" != \", 2);\n  host::RuntimeDeviceCheck();\n  // check the provided `cudaError_t`\n  host::RuntimeDeviceCheck(cudaGetLastError());\n}\n\n```\n\n#### Tensor Checking\n\n`TensorMatcher` provides a readable way to validate and extract tensor shape information.\n\n```cpp\n#include <sgl_kernel/tensor.h>\n\nvoid test(const tvm::ffi::TensorView k_cache, const tvm::ffi::TensorView v_cache) {\n  using namespace host;\n\n  auto D = SymbolicSize{\"D\"};  // cache dimension\n  auto N = SymbolicSize{\"N\"};  // kvcache stride\n  auto dtype = SymbolicDType{};\n  auto device = SymbolicDevice{};\n\n  TensorMatcher({-1, D})  //\n      .with_strides({N, 1})\n      .with_dtype<int32_t, int64_t>(dtype)\n      .with_device<kDLCUDA, kDLCPU>(device)\n      .verify(k_cache)\n      .verify(v_cache);\n}\n```\n\nConfigure the `TensorMatcher` with expected stride, dtype, and device properties before verification.\n- If `with_strides` is omitted, the tensor is expected to be contiguous.\n- Template arguments in `with_dtype` restrict the allowed data types.\n- Template arguments in `with_device` restrict the allowed devices.\n- Values passed to `with_xxx` methods enforce equality checks.\n- Passing `-1` for size or stride allows matching any value.\n\nA `Symbolic` variable must resolve to the same value across all 
verifications.\nUse `.unwrap()` to retrieve the matched value after verification.\n\n> Note: `TensorMatcher` is a temporary expression and should not be stored in a variable.\n\n> Tip: Add `//` at the end of the `TensorMatcher` chain to enforce proper indentation.\n\n#### Kernel Launching\n\n`LaunchKernel::resolve_device` retrieves the current `cudaStream` from PyTorch.\nKernels can also be launched directly using `LaunchKernel`.\n\n```cpp\n#include <sgl_kernel/utils.cuh>\n\n#include <dlpack/dlpack.h>\n\n__global__ void kernel() {}\n\nvoid test() {\n  const auto num_blocks = 1;\n  const auto num_threads = 32;\n  const auto dynamic_smem = 0;\n\n  DLDevice dev;  // suppose this is initialized properly\n  host::LaunchKernel(num_blocks, num_threads, dev)(kernel);\n\n  cudaStream_t stream = host::LaunchKernel::resolve_device(dev);\n  host::LaunchKernel(num_blocks, num_threads, stream, dynamic_smem)(kernel);\n}\n\n```\n\n## Add new kernels\n\nThis section walks through a complete, end-to-end example of adding a new JIT kernel to the system.\nWe use a simple add_constant kernel as a running example, which adds a constant integer value to every element of an input tensor.\n\nConceptually, the Python interface looks like this:\n\n```python\ndef add_constant(src: torch.Tensor, c: int):\n    return src + c\n```\n\n### STEP 1: Write the C++ kernel\n\nWrite your CUDA kernel in [jit_kernel/csrc/add_constant.cuh](../../python/sglang/jit_kernel/csrc/add_constant.cuh). 
For demonstration purposes, we pass the constant value as a template parameter.\n\n```cpp\n#include <sgl_kernel/tensor.h>   // For TensorMatcher, SymbolicSize, SymbolicDevice\n#include <sgl_kernel/utils.cuh>  // For LaunchKernel\n#include <sgl_kernel/utils.h>    // For div_ceil, RuntimeCheck\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/container/tensor.h>\n\n#include <cstddef>\n#include <cstdint>\n\nnamespace {\n\ntemplate <int32_t kConstant>\n__global__ void add_constant_kernel(int32_t* dst, const int32_t* src, size_t length) {\n  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx < length) {\n    dst[idx] = src[idx] + kConstant;\n  }\n}\n\nconstexpr size_t kBlockSize = 256;\n\n// You can also use struct with static method as an alternative\ntemplate <int32_t kConstant>\nvoid add_constant(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) {\n  using namespace host;\n\n  // 1. Validate input tensors\n  SymbolicSize N = {\"num_elements\"};\n  SymbolicDevice device_;\n  TensorMatcher({N})                  // 1D tensor, must be contiguous\n      .with_dtype<int32_t>()          // must be int32\n      .with_device<kDLCUDA>(device_)  // must be on CUDA device\n      .verify(dst)                    // check tensor dst\n      .verify(src);                   // check tensor src\n\n  // 2. Extract required parameters, prepare for kernel launch\n  const size_t num_elements = N.unwrap();\n  const size_t grid_size = div_ceil(num_elements, kBlockSize);\n  const DLDevice device = device_.unwrap();\n  // some extra runtime checks using host::RuntimeCheck\n  RuntimeCheck(num_elements > 0, \"We only support non-empty tensors, got num_elements = \", num_elements);\n\n  // 3. Launch the kernel. 
Error code will be automatically checked.\n  LaunchKernel(grid_size, kBlockSize, device /*, dynamic_smem*/)(\n      // kernel function\n      add_constant_kernel<kConstant>,\n      // kernel arguments\n      static_cast<int32_t*>(dst.data_ptr()),\n      static_cast<int32_t*>(src.data_ptr()),\n      num_elements);\n}\n\n}  // namespace\n\n```\n\n### STEP 2: Create Python Interfaces\n\nNext, expose the kernel through a Python wrapper.\nCreate a new file at [jit_kernel/add_constant.py](../../python/sglang/jit_kernel/add_constant.py) and expose the needed interfaces.\n\n```python\nfrom __future__ import annotations\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_add_constant_module(constant: int) -> Module:\n    args = make_cpp_args(constant)  # pass all the template arguments\n    return load_jit(\n        \"add_constant\",\n        *args,\n        cuda_files=[\"add_constant.cuh\"],\n        cuda_wrappers=[(\"add_constant\", f\"add_constant<{args}>\")],\n    )\n\n\ndef add_constant(src: torch.Tensor, constant: int) -> torch.Tensor:\n    if not src.is_cuda:\n        raise RuntimeError(\"src must be a CUDA tensor\")\n    if src.dtype != torch.int32:\n        raise RuntimeError(f\"Unsupported dtype {src.dtype}. Supported: int32\")\n    dst = torch.empty_like(src)\n    module = _jit_add_constant_module(constant)\n    module.add_constant(dst, src)\n    return dst\n\n```\n\nKeep the Python wrapper thin, but still validate basic invariants, such as device and dtype, before dispatch. 
In the current JIT/FFI path, invalid tensors are not always rejected safely before launch.\n\n### STEP 3: Use your kernel\n\nFinally, import and use the kernel like a regular Python function:\n\n```python\nfrom sglang.jit_kernel.add_constant import add_constant\n```\n\nFor a complete, runnable example, refer to [test_add_constant.py](../../python/sglang/jit_kernel/tests/test_add_constant.py).\n\n## C++ Include Library Reference\n\nThe JIT kernel framework provides a set of reusable C++ headers in\n`python/sglang/jit_kernel/include/sgl_kernel/`. Each header is designed\nto be lightweight and self-contained. Below is a summary of each header\nand its key APIs.\n\n### Core Utilities\n\n| Header | Namespace | Purpose |\n|--------|-----------|---------|\n| `utils.h` | `host` | Host-side essentials: `RuntimeCheck`, `Panic`, `div_ceil`, `irange` |\n| `utils.cuh` | `device` / `host` | Type aliases (`fp16_t`, `bf16_t`, ...), `SGL_DEVICE` macro, PDL helpers, `LaunchKernel`, `RuntimeDeviceCheck` |\n| `source_location.h` | (global) | Portable `std::source_location` wrapper for error reporting |\n| `runtime.cuh` | `host::runtime` | CUDA runtime queries: `get_blocks_per_sm`, `get_sm_count`, `get_cc_major`, `get_runtime_version`, `get_available_dynamic_smem_per_block` |\n\n### Tensor Validation\n\n| Header | Namespace | Purpose |\n|--------|-----------|---------|\n| `tensor.h` | `host` | `TensorMatcher`, `SymbolicSize`, `SymbolicDType`, `SymbolicDevice` |\n\n### Math & Type System\n\n| Header | Namespace | Purpose |\n|--------|-----------|---------|\n| `math.cuh` | `device::math` | `max`, `min`, `abs`, `sqrt`, `rsqrt`, `exp`, `sin`, `cos`, constants |\n| `type.cuh` | (global) / `device` | `dtype_trait<T>`, `packed_t<T>`, `device::cast<To>(from)` |\n\n### Memory Access\n\n| Header | Namespace | Purpose |\n|--------|-----------|---------|\n| `vec.cuh` | `device` | `AlignedVector<T, N>` - vectorized load/store (up to 128-bit; 256-bit requires Blackwell GPUs) |\n| `tile.cuh` | 
`device::tile` | `Memory<T>` - cooperative tiled memory I/O (thread/warp/CTA) |\n\n### Parallel Primitives\n\n| Header | Namespace | Purpose |\n|--------|-----------|---------|\n| `warp.cuh` | `device::warp` | `reduce_sum`, `reduce_max` via `__shfl_xor_sync` |\n| `cta.cuh` | `device::cta` | `reduce_max` across warps via shared memory |\n| `atomic.cuh` | `device::atomic` | `max` - atomic float max (CUDA + ROCm fallback) |\n\n### Reusable Kernel Templates\n\n| Header | Namespace | Purpose |\n|--------|-----------|---------|\n| `impl/norm.cuh` | `host::norm` / `device::norm` | RMSNorm building blocks (warp & CTA paths, `StorageType`) |\n"
  },
  {
    "path": "docs/developer_guide/evaluating_new_models.md",
    "content": "# Evaluating New Models with SGLang\n\nThis document provides commands for evaluating models' accuracy and performance. Before open-sourcing new models, we strongly suggest running these commands to verify whether the score matches your internal benchmark results.\n\n**For cross verification, please submit commands for installation, server launching, and benchmark running with all the scores and hardware requirements when open-sourcing your models.**\n\n[Reference: MiniMax M2](https://github.com/sgl-project/sglang/pull/12129)\n\n## Accuracy\n\n### LLMs\n\nSGLang provides built-in scripts to evaluate common benchmarks.\n\n**MMLU**\n\n```bash\npython -m sglang.test.run_eval \\\n  --eval-name mmlu \\\n  --port 30000 \\\n  --num-examples 1000 \\\n  --max-tokens 8192\n```\n\n**GSM8K**\n\n```bash\npython -m sglang.test.few_shot_gsm8k \\\n  --host 127.0.0.1 \\\n  --port 30000 \\\n  --num-questions 200 \\\n  --num-shots 5\n```\n\n**HellaSwag**\n\n```bash\npython benchmark/hellaswag/bench_sglang.py \\\n  --host 127.0.0.1 \\\n  --port 30000 \\\n  --num-questions 200 \\\n  --num-shots 20\n```\n\n**GPQA**\n\n```bash\npython -m sglang.test.run_eval \\\n  --eval-name gpqa \\\n  --port 30000 \\\n  --num-examples 198 \\\n  --max-tokens 120000 \\\n  --repeat 8\n```\n\n```{tip}\nFor reasoning models, add `--thinking-mode <mode>` (e.g., `qwen3`, `deepseek-v3`). 
You may skip it if the model has forced thinking enabled.\n```\n\n**HumanEval**\n\n```bash\npip install human_eval\n\npython -m sglang.test.run_eval \\\n  --eval-name humaneval \\\n  --num-examples 10 \\\n  --port 30000\n```\n\n### VLMs\n\n**MMMU**\n\n```bash\npython benchmark/mmmu/bench_sglang.py \\\n  --port 30000 \\\n  --concurrency 64\n```\n\n```{tip}\nYou can set max tokens by passing `--extra-request-body '{\"max_tokens\": 4096}'`.\n```\n\nFor models capable of processing video, we recommend extending the evaluation to include `VideoMME`, `MVBench`, and other relevant benchmarks.\n\n## Performance\n\nPerformance benchmarks measure **Latency** (Time To First Token - TTFT) and **Throughput** (tokens/second).\n\n### LLMs\n\n**Latency-Sensitive Benchmark**\n\nThis simulates a scenario with low concurrency (e.g., single user) to measure latency.\n\n```bash\npython -m sglang.bench_serving \\\n  --backend sglang \\\n  --host 0.0.0.0 \\\n  --port 30000 \\\n  --dataset-name random \\\n  --num-prompts 10 \\\n  --max-concurrency 1\n```\n\n**Throughput-Sensitive Benchmark**\n\nThis simulates a high-traffic scenario to measure maximum system throughput.\n\n```bash\npython -m sglang.bench_serving \\\n  --backend sglang \\\n  --host 0.0.0.0 \\\n  --port 30000 \\\n  --dataset-name random \\\n  --num-prompts 1000 \\\n  --max-concurrency 100\n```\n\n**Single Batch Performance**\n\nYou can also benchmark the performance of processing a single batch offline.\n\n```bash\npython -m sglang.bench_one_batch_server \\\n  --model <model-path> \\\n  --batch-size 8 \\\n  --input-len 1024 \\\n  --output-len 1024\n```\n\nYou can run more granular benchmarks:\n\n- **Low Concurrency**: `--num-prompts 10 --max-concurrency 1`\n- **Medium Concurrency**: `--num-prompts 80 --max-concurrency 16`\n- **High Concurrency**: `--num-prompts 500 --max-concurrency 100`\n\n## Reporting Results\n\nFor each evaluation, please report:\n\n1.  
**Metric Score**: Accuracy % (LLMs and VLMs); Latency (ms) and Throughput (tok/s) (LLMs only).\n2.  **Environment settings**: GPU type/count, SGLang commit hash.\n3.  **Launch configuration**: Model path, TP size, and any special flags.\n4.  **Evaluation parameters**: Number of shots, examples, max tokens.\n"
  },
  {
    "path": "docs/developer_guide/release_process.md",
    "content": "# PyPI Package Release Process\n\n## Update the version in code\nUpdate the package version in `python/pyproject.toml` and `python/sglang/__init__.py`.\n\n## Upload the PyPI package\n\n```\npip install build twine\n```\n\n```\ncd python\nbash upload_pypi.sh\n```\n\n## Make a release in GitHub\nMake a new release https://github.com/sgl-project/sglang/releases/new.\n"
  },
  {
    "path": "docs/developer_guide/setup_github_runner.md",
    "content": "# Set Up Self-Hosted Runners for GitHub Actions\n\n## Add a Runner\n\n### Step 1: Start a docker container.\n\n**You can mount a folder for the shared Hugging Face model weights cache.**\nThe command below uses `/tmp/huggingface` as an example.\n\n```\ndocker pull nvidia/cuda:12.9.1-devel-ubuntu22.04\n# Nvidia\ndocker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.9.1-devel-ubuntu22.04 /bin/bash\n# AMD\ndocker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.5.8-rocm700-mi30x /bin/bash\n# AMD, only the last 2 GPUs\ndocker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.5.8-rocm700-mi30x /bin/bash\n```\n\n### Step 2: Configure the runner by `config.sh`\n\nRun these commands inside the container.\n\n```\napt update && apt install -y curl python3-pip git\npip install --upgrade pip\nexport RUNNER_ALLOW_RUNASROOT=1\n```\n\nThen follow the instructions at https://github.com/sgl-project/sglang/settings/actions/runners/new?arch=x64&os=linux to run `config.sh`.\n\n**Notes**\n- You do not need to specify the runner group.\n- Give it a name (e.g., `test-sgl-gpu-0`) and some labels (e.g., `1-gpu-runner`). The labels can be edited later in GitHub Settings.\n- You do not need to change the work folder.\n\n### Step 3: Run the runner by `run.sh`\n\n- Set up environment variables\n```\nexport HF_HOME=/hf_home\nexport SGLANG_IS_IN_CI=true\nexport HF_TOKEN=hf_xxx\nexport OPENAI_API_KEY=sk-xxx\nexport CUDA_VISIBLE_DEVICES=0\n```\n\n- Run it forever\n```\nwhile true; do ./run.sh; echo \"Restarting...\"; sleep 2; done\n```\n"
  },
  {
    "path": "docs/diffusion/api/cli.md",
    "content": "# SGLang diffusion CLI Inference\n\nThe SGLang-diffusion CLI provides a quick way to access the inference pipeline for image and video generation.\n\n## Prerequisites\n\n- A working SGLang diffusion installation and the `sglang` CLI available in `$PATH`.\n\n\n## Supported Arguments\n\n### Server Arguments\n\n- `--model-path {MODEL_PATH}`: Path to the model or model ID\n- `--lora-path {LORA_PATH}`: Path to a LoRA adapter (local path or HuggingFace model ID). If not specified, LoRA will not be applied.\n- `--lora-nickname {NAME}`: Nickname for the LoRA adapter. (default: `default`).\n- `--num-gpus {NUM_GPUS}`: Number of GPUs to use\n- `--tp-size {TP_SIZE}`: Tensor parallelism size (only for the encoder; should not be larger than 1 if text encoder offload is enabled, as layer-wise offload plus prefetch is faster)\n- `--sp-degree {SP_SIZE}`: Sequence parallelism size (typically should match the number of GPUs)\n- `--ulysses-degree {ULYSSES_DEGREE}`: The degree of DeepSpeed-Ulysses-style SP in USP\n- `--ring-degree {RING_DEGREE}`: The degree of ring attention-style SP in USP\n- `--attention-backend {BACKEND}`: Attention backend to use. For SGLang-native pipelines use `fa`, `torch_sdpa`, `sage_attn`, etc. For diffusers pipelines use diffusers backend names like `flash`, `_flash_3_hub`, `sage`, `xformers`.\n- `--attention-backend-config {CONFIG}`: Configuration for the attention backend. 
Can be a JSON string (e.g., '{\"k\": \"v\"}'), a path to a JSON/YAML file, or key=value pairs (e.g., \"k=v,k2=v2\").\n- `--cache-dit-config {PATH}`: Path to a Cache-DiT YAML/JSON config (diffusers backend only)\n- `--dit-precision {DTYPE}`: Precision for the DiT model (currently supports fp32, fp16, and bf16).\n\n\n### Sampling Parameters\n\n- `--prompt {PROMPT}`: Text description for the video you want to generate\n- `--num-inference-steps {STEPS}`: Number of denoising steps\n- `--negative-prompt {PROMPT}`: Negative prompt to guide generation away from certain concepts\n- `--seed {SEED}`: Random seed for reproducible generation\n\n\n**Image/Video Configuration**\n\n- `--height {HEIGHT}`: Height of the generated output\n- `--width {WIDTH}`: Width of the generated output\n- `--num-frames {NUM_FRAMES}`: Number of frames to generate\n- `--fps {FPS}`: Frames per second for the saved output, if this is a video-generation task\n\n\n**Post-Processing** (frame interpolation & upscaling)\n\nSGLang diffusion supports optional post-processing steps — frame interpolation\n(RIFE) for smoother video and upscaling (Real-ESRGAN) for higher resolution.\nSee the dedicated **[Post-Processing](post_processing.md)** page for full\ndetails, supported models, and examples.\n\n**Output Options**\n\n- `--output-path {PATH}`: Directory to save the generated video\n- `--save-output`: Whether to save the image/video to disk\n- `--return-frames`: Whether to return the raw frames\n\n### Using Configuration Files\n\nInstead of specifying all parameters on the command line, you can use a configuration file:\n\n```bash\nsglang generate --config {CONFIG_FILE_PATH}\n```\n\nThe configuration file should be in JSON or YAML format with the same parameter names as the CLI options. 
Command-line arguments take precedence over settings in the configuration file, allowing you to override specific values while keeping the rest from the configuration file.\n\nExample configuration file (config.json):\n\n```json\n{\n    \"model_path\": \"FastVideo/FastHunyuan-diffusers\",\n    \"prompt\": \"A beautiful woman in a red dress walking down a street\",\n    \"output_path\": \"outputs/\",\n    \"num_gpus\": 2,\n    \"sp_size\": 2,\n    \"tp_size\": 1,\n    \"num_frames\": 45,\n    \"height\": 720,\n    \"width\": 1280,\n    \"num_inference_steps\": 6,\n    \"seed\": 1024,\n    \"fps\": 24,\n    \"precision\": \"bf16\",\n    \"vae_precision\": \"fp16\",\n    \"vae_tiling\": true,\n    \"vae_sp\": true,\n    \"vae_config\": {\n        \"load_encoder\": false,\n        \"load_decoder\": true,\n        \"tile_sample_min_height\": 256,\n        \"tile_sample_min_width\": 256\n    },\n    \"text_encoder_precisions\": [\n        \"fp16\",\n        \"fp16\"\n    ],\n    \"mask_strategy_file_path\": null,\n    \"enable_torch_compile\": false\n}\n```\n\nOr using YAML format (config.yaml):\n\n```yaml\nmodel_path: \"FastVideo/FastHunyuan-diffusers\"\nprompt: \"A beautiful woman in a red dress walking down a street\"\noutput_path: \"outputs/\"\nnum_gpus: 2\nsp_size: 2\ntp_size: 1\nnum_frames: 45\nheight: 720\nwidth: 1280\nnum_inference_steps: 6\nseed: 1024\nfps: 24\nprecision: \"bf16\"\nvae_precision: \"fp16\"\nvae_tiling: true\nvae_sp: true\nvae_config:\n  load_encoder: false\n  load_decoder: true\n  tile_sample_min_height: 256\n  tile_sample_min_width: 256\ntext_encoder_precisions:\n  - \"fp16\"\n  - \"fp16\"\nmask_strategy_file_path: null\nenable_torch_compile: false\n```\n\n\nTo see all the options, you can use the `--help` flag:\n\n```bash\nsglang generate --help\n```\n\n## Serve\n\nLaunch the SGLang diffusion HTTP server and interact with it using the OpenAI SDK and curl.\n\n### Start the server\n\nUse the following command to launch the 
server:\n\n```bash\nSERVER_ARGS=(\n  --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers\n  --text-encoder-cpu-offload\n  --pin-cpu-memory\n  --num-gpus 4\n  --ulysses-degree=2\n  --ring-degree=2\n)\n\nsglang serve \"${SERVER_ARGS[@]}\"\n```\n\n- **--model-path**: Which model to load. The example uses `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`.\n- **--port**: HTTP port to listen on (the default here is `30010`).\n\nFor detailed API usage, including image and video generation and LoRA management, please refer to the [OpenAI API Documentation](openai_api.md).\n\n### Cloud Storage Support\n\nSGLang diffusion supports automatically uploading generated images and videos to S3-compatible cloud storage (e.g., AWS S3, MinIO, Alibaba Cloud OSS, Tencent Cloud COS).\n\nWhen enabled, the server follows a **Generate -> Upload -> Delete** workflow:\n1. The artifact is written to a temporary local file.\n2. The file is immediately uploaded to the configured S3 bucket in a background thread.\n3. Upon successful upload, the local file is deleted.\n4. The API response returns the public URL of the uploaded object.\n\n**Configuration**\n\nCloud storage is enabled via environment variables. 
Note that `boto3` must be installed separately (`pip install boto3`) to use this feature.\n\n```bash\n# Enable S3 storage\nexport SGLANG_CLOUD_STORAGE_TYPE=s3\nexport SGLANG_S3_BUCKET_NAME=my-bucket\nexport SGLANG_S3_ACCESS_KEY_ID=your-access-key\nexport SGLANG_S3_SECRET_ACCESS_KEY=your-secret-key\n\n# Optional: Custom endpoint for MinIO/OSS/COS\nexport SGLANG_S3_ENDPOINT_URL=https://minio.example.com\n```\n\nSee [Environment Variables Documentation](../environment_variables.md) for more details.\n\n## Generate\n\nRun a one-off generation task without launching a persistent server.\n\nTo use it, pass both server arguments and sampling parameters in a single command after the `generate` subcommand, for example:\n\n```bash\nSERVER_ARGS=(\n  --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers\n  --text-encoder-cpu-offload\n  --pin-cpu-memory\n  --num-gpus 4\n  --ulysses-degree=2\n  --ring-degree=2\n)\n\nSAMPLING_ARGS=(\n  --prompt \"A curious raccoon\"\n  --save-output\n  --output-path outputs\n  --output-file-name \"A curious raccoon.mp4\"\n)\n\nsglang generate \"${SERVER_ARGS[@]}\" \"${SAMPLING_ARGS[@]}\"\n\n# Alternatively, set the `SGLANG_CACHE_DIT_ENABLED` environment variable to `true` to enable cache acceleration\nSGLANG_CACHE_DIT_ENABLED=true sglang generate \"${SERVER_ARGS[@]}\" \"${SAMPLING_ARGS[@]}\"\n```\n\nOnce the generation task has finished, the server will shut down automatically.\n\n> [!NOTE]\n> The HTTP server-related arguments are ignored in this subcommand.\n\n## Component Path Overrides\n\nSGLang diffusion allows you to override any pipeline component (e.g., `vae`, `transformer`, `text_encoder`) by specifying a custom checkpoint path. 
This is useful, for example, for swapping in a faster or lighter-weight component while keeping the rest of the base model unchanged.\n\n### Example: FLUX.2-dev with Tiny AutoEncoder\n\nYou can override **any** component by using `--<component>-path`, where `<component>` matches the key in the model's `model_index.json`.\n\nFor example, replace the default VAE with a distilled tiny autoencoder for ~3x faster decoding:\n\n```bash\n# Option 1: a Hugging Face repo ID\nsglang serve \\\n  --model-path=black-forest-labs/FLUX.2-dev \\\n  --vae-path=fal/FLUX.2-Tiny-AutoEncoder\n\n# Option 2: a local path to a complete component folder\nsglang serve \\\n  --model-path=black-forest-labs/FLUX.2-dev \\\n  --vae-path=$HOME/.cache/huggingface/hub/models--fal--FLUX.2-Tiny-AutoEncoder/snapshots/.../vae\n```\n\n**Important:**\n- The component key must match the one in your model's `model_index.json` (e.g., `vae`).\n- The path must either:\n    - be a Hugging Face repo ID (e.g., `fal/FLUX.2-Tiny-AutoEncoder`), or\n    - point to a **complete component folder** containing `config.json` and safetensors files.\n\n\n## Diffusers Backend\n\nSGLang diffusion supports a **diffusers backend** that allows you to run any diffusers-compatible model through SGLang's infrastructure using vanilla diffusers pipelines. This is useful for running models without native SGLang implementations or models with custom pipeline classes.\n\n### Arguments\n\n| Argument | Values | Description |\n|----------|--------|-------------|\n| `--backend` | `auto` (default), `sglang`, `diffusers` | `auto`: prefer native SGLang, fall back to diffusers. `sglang`: force native (fails if unavailable). `diffusers`: force vanilla diffusers pipeline. |\n| `--diffusers-attention-backend` | `flash`, `_flash_3_hub`, `sage`, `xformers`, `native` | Attention backend for diffusers pipelines. See [diffusers attention backends](https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends). |\n| `--trust-remote-code` | flag | Required for models with custom pipeline classes (e.g., Ovis). |\n| `--vae-tiling` | flag | Enable VAE tiling for large image support (decodes tile-by-tile). 
|\n| `--vae-slicing` | flag | Enable VAE slicing for lower memory usage (decodes slice-by-slice). |\n| `--dit-precision` | `fp16`, `bf16`, `fp32` | Precision for the diffusion transformer. |\n| `--vae-precision` | `fp16`, `bf16`, `fp32` | Precision for the VAE. |\n| `--enable-torch-compile` | flag | Enable `torch.compile` for diffusers pipelines. |\n| `--cache-dit-config` | `{PATH}` | Path to a Cache-DiT YAML/JSON config file for accelerating diffusers pipelines with Cache-DiT. |\n\n### Example: Running Ovis-Image-7B\n\n[Ovis-Image-7B](https://huggingface.co/AIDC-AI/Ovis-Image-7B) is a 7B text-to-image model optimized for high-quality text rendering.\n\n```bash\nsglang generate \\\n  --model-path AIDC-AI/Ovis-Image-7B \\\n  --backend diffusers \\\n  --trust-remote-code \\\n  --diffusers-attention-backend flash \\\n  --prompt \"A serene Japanese garden with cherry blossoms\" \\\n  --height 1024 \\\n  --width 1024 \\\n  --num-inference-steps 30 \\\n  --save-output \\\n  --output-path outputs \\\n  --output-file-name ovis_garden.png\n```\n\n### Extra Diffusers Arguments\n\nFor pipeline-specific parameters not exposed via the CLI, use `diffusers_kwargs` in a config file:\n\n```json\n{\n    \"model_path\": \"AIDC-AI/Ovis-Image-7B\",\n    \"backend\": \"diffusers\",\n    \"prompt\": \"A beautiful landscape\",\n    \"diffusers_kwargs\": {\n        \"cross_attention_kwargs\": {\"scale\": 0.5}\n    }\n}\n```\n\n```bash\nsglang generate --config config.json\n```\n\n### Cache-DiT Acceleration\n\nWhen using the diffusers backend, you can also enable Cache-DiT acceleration and load a custom cache config (via `--cache-dit-config`) to speed up diffusers pipelines. See the [Cache-DiT Acceleration](https://docs.sglang.io/diffusion/performance/cache/cache_dit.html) documentation for details.\n"
  },
  {
    "path": "docs/diffusion/api/openai_api.md",
    "content": "# SGLang Diffusion OpenAI API\n\nThe SGLang diffusion HTTP server implements an OpenAI-compatible API for image and video generation, as well as LoRA adapter management.\n\n## Prerequisites\n\n- Python 3.11+ if you plan to use the OpenAI Python SDK.\n\n## Serve\n\nLaunch the server using the `sglang serve` command.\n\n### Start the server\n\n```bash\nSERVER_ARGS=(\n  --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers\n  --text-encoder-cpu-offload\n  --pin-cpu-memory\n  --num-gpus 4\n  --ulysses-degree=2\n  --ring-degree=2\n  --port 30010\n)\n\nsglang serve \"${SERVER_ARGS[@]}\"\n```\n\n- **--model-path**: Path to the model or model ID.\n- **--port**: HTTP port to listen on (default: `30000`).\n\n**Get Model Information**\n\n**Endpoint:** `GET /models`\n\nReturns information about the model served by this server, including model path, task type, pipeline configuration, and precision settings.\n\n**Curl Example:**\n\n```bash\ncurl -sS -X GET \"http://localhost:30010/models\"\n```\n\n**Response Example:**\n\n```json\n{\n  \"model_path\": \"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\",\n  \"task_type\": \"T2V\",\n  \"pipeline_name\": \"wan_pipeline\",\n  \"pipeline_class\": \"WanPipeline\",\n  \"num_gpus\": 4,\n  \"dit_precision\": \"bf16\",\n  \"vae_precision\": \"fp16\"\n}\n```\n\n---\n\n## Endpoints\n\n### Image Generation\n\nThe server implements an OpenAI-compatible Images API under the `/v1/images` namespace.\n\n**Create an image**\n\n**Endpoint:** `POST /v1/images/generations`\n\n**Python Example (b64_json response):**\n\n```python\nimport base64\nfrom openai import OpenAI\n\nclient = OpenAI(api_key=\"sk-proj-1234567890\", base_url=\"http://localhost:30010/v1\")\n\nimg = client.images.generate(\n    prompt=\"A calico cat playing a piano on stage\",\n    size=\"1024x1024\",\n    n=1,\n    response_format=\"b64_json\",\n)\n\nimage_bytes = base64.b64decode(img.data[0].b64_json)\nwith open(\"output.png\", \"wb\") as f:\n    f.write(image_bytes)\n```\n\n**Curl 
Example:**\n\n```bash\ncurl -sS -X POST \"http://localhost:30010/v1/images/generations\" \\\n  -H \"Content-Type: application/json\" \\\n  -H \"Authorization: Bearer sk-proj-1234567890\" \\\n  -d '{\n        \"prompt\": \"A calico cat playing a piano on stage\",\n        \"size\": \"1024x1024\",\n        \"n\": 1,\n        \"response_format\": \"b64_json\"\n      }'\n```\n\n> **Note**\n> If `response_format=url` is used and cloud storage is not configured, the API returns\n> a relative URL like `/v1/images/<IMAGE_ID>/content`.\n\n**Edit an image**\n\n**Endpoint:** `POST /v1/images/edits`\n\nThis endpoint accepts a multipart form upload with input images and a text prompt. The server can return either a base64-encoded image or a URL to download the image.\n\n**Curl Example (b64_json response):**\n\n```bash\ncurl -sS -X POST \"http://localhost:30010/v1/images/edits\" \\\n  -H \"Authorization: Bearer sk-proj-1234567890\" \\\n  -F \"image=@local_input_image.png\" \\\n  -F \"url=image_url.jpg\" \\\n  -F \"prompt=A calico cat playing a piano on stage\" \\\n  -F \"size=1024x1024\" \\\n  -F \"response_format=b64_json\"\n```\n\n**Curl Example (URL response):**\n\n```bash\ncurl -sS -X POST \"http://localhost:30010/v1/images/edits\" \\\n  -H \"Authorization: Bearer sk-proj-1234567890\" \\\n  -F \"image=@local_input_image.png\" \\\n  -F \"url=image_url.jpg\" \\\n  -F \"prompt=A calico cat playing a piano on stage\" \\\n  -F \"size=1024x1024\" \\\n  -F \"response_format=url\"\n```\n\n**Download image content**\n\nWhen `response_format=url` is used with `POST /v1/images/generations` or `POST /v1/images/edits`,\nthe API returns a relative URL like `/v1/images/<IMAGE_ID>/content`.\n\n**Endpoint:** `GET /v1/images/{image_id}/content`\n\n**Curl Example:**\n\n```bash\ncurl -sS -L \"http://localhost:30010/v1/images/<IMAGE_ID>/content\" \\\n  -H \"Authorization: Bearer sk-proj-1234567890\" \\\n  -o output.png\n```\n\n### Video Generation\n\nThe server implements a subset of the OpenAI 
Videos API under the `/v1/videos` namespace.\n\n**Create a video**\n\n**Endpoint:** `POST /v1/videos`\n\n**Python Example:**\n\n```python\nfrom openai import OpenAI\n\nclient = OpenAI(api_key=\"sk-proj-1234567890\", base_url=\"http://localhost:30010/v1\")\n\nvideo = client.videos.create(\n    prompt=\"A calico cat playing a piano on stage\",\n    size=\"1280x720\"\n)\nprint(f\"Video ID: {video.id}, Status: {video.status}\")\n```\n\n**Curl Example:**\n\n```bash\ncurl -sS -X POST \"http://localhost:30010/v1/videos\" \\\n  -H \"Content-Type: application/json\" \\\n  -H \"Authorization: Bearer sk-proj-1234567890\" \\\n  -d '{\n        \"prompt\": \"A calico cat playing a piano on stage\",\n        \"size\": \"1280x720\"\n      }'\n```\n\n**List videos**\n\n**Endpoint:** `GET /v1/videos`\n\n**Python Example:**\n\n```python\nvideos = client.videos.list()\nfor item in videos.data:\n    print(item.id, item.status)\n```\n\n**Curl Example:**\n\n```bash\ncurl -sS -X GET \"http://localhost:30010/v1/videos\" \\\n  -H \"Authorization: Bearer sk-proj-1234567890\"\n```\n\n**Download video content**\n\n**Endpoint:** `GET /v1/videos/{video_id}/content`\n\n**Python Example:**\n\n```python\nimport time\n\nvideo_id = video.id  # ID returned by client.videos.create(...) above\n\n# Poll until the video job finishes\nwhile True:\n    page = client.videos.list()\n    item = next((v for v in page.data if v.id == video_id), None)\n    if item and item.status == \"completed\":\n        break\n    if item and item.status == \"failed\":  # terminal status in the OpenAI Videos API\n        raise RuntimeError(f\"Video generation failed: {video_id}\")\n    time.sleep(5)\n\n# Download content\nresp = client.videos.download_content(video_id=video_id)\nwith open(\"output.mp4\", \"wb\") as f:\n    f.write(resp.read())\n```\n\n**Curl Example:**\n\n```bash\ncurl -sS -L \"http://localhost:30010/v1/videos/<VIDEO_ID>/content\" \\\n  -H \"Authorization: Bearer sk-proj-1234567890\" \\\n  -o output.mp4\n```\n\n---\n\n### LoRA Management\n\nThe server supports dynamic loading, merging, and unmerging of LoRA adapters.\n\n**Important Notes:**\n- Mutual Exclusion: Only one LoRA can be *merged* (active) at a time\n- Switching: To switch LoRAs, you must 
first `unmerge` the current one, then `set` the new one\n- Caching: The server caches loaded LoRA weights in memory. Switching back to a previously loaded LoRA (same path) has little cost\n\n**Set LoRA Adapter**\n\nLoads one or more LoRA adapters and merges their weights into the model. Supports both single LoRA (backward compatible) and multiple LoRA adapters.\n\n**Endpoint:** `POST /v1/set_lora`\n\n**Parameters:**\n- `lora_nickname` (string or list of strings, required): A unique identifier for the LoRA adapter(s). Can be a single string or a list of strings for multiple LoRAs\n- `lora_path` (string or list of strings/None, optional): Path to the `.safetensors` file(s) or Hugging Face repo ID(s). Required for the first load; optional if re-activating a cached nickname. If a list, must match the length of `lora_nickname`\n- `target` (string or list of strings, optional): Which transformer(s) to apply the LoRA to. If a list, must match the length of `lora_nickname`. Valid values:\n  - `\"all\"` (default): Apply to all transformers\n  - `\"transformer\"`: Apply only to the primary transformer (high noise for Wan2.2)\n  - `\"transformer_2\"`: Apply only to transformer_2 (low noise for Wan2.2)\n  - `\"critic\"`: Apply only to the critic model\n- `strength` (float or list of floats, optional): LoRA strength for merge, default 1.0. If a list, must match the length of `lora_nickname`. 
Values < 1.0 reduce the effect, values > 1.0 amplify the effect.\n\n**Single LoRA Example:**\n\n```bash\ncurl -X POST http://localhost:30010/v1/set_lora \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n        \"lora_nickname\": \"lora_name\",\n        \"lora_path\": \"/path/to/lora.safetensors\",\n        \"target\": \"all\",\n        \"strength\": 0.8\n      }'\n```\n\n**Multiple LoRA Example:**\n\n```bash\ncurl -X POST http://localhost:30010/v1/set_lora \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n        \"lora_nickname\": [\"lora_1\", \"lora_2\"],\n        \"lora_path\": [\"/path/to/lora1.safetensors\", \"/path/to/lora2.safetensors\"],\n        \"target\": [\"transformer\", \"transformer_2\"],\n        \"strength\": [0.8, 1.0]\n      }'\n```\n\n**Multiple LoRA with Same Target:**\n\n```bash\ncurl -X POST http://localhost:30010/v1/set_lora \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n        \"lora_nickname\": [\"style_lora\", \"character_lora\"],\n        \"lora_path\": [\"/path/to/style.safetensors\", \"/path/to/character.safetensors\"],\n        \"target\": \"all\",\n        \"strength\": [0.7, 0.9]\n      }'\n```\n\n> [!NOTE]\n> When using multiple LoRAs:\n> - All list parameters (`lora_nickname`, `lora_path`, `target`, `strength`) must have the same length\n> - If `target` or `strength` is a single value, it will be applied to all LoRAs\n> - Multiple LoRAs applied to the same target will be merged in order\n\n\n**Merge LoRA Weights**\n\nManually merges the currently set LoRA weights into the base model.\n\n> [!NOTE]\n> `set_lora` automatically performs a merge, so this is typically only needed if you have manually unmerged but want to re-apply the same LoRA without calling `set_lora` again.\n\n**Endpoint:** `POST /v1/merge_lora_weights`\n\n**Parameters:**\n- `target` (string, optional): Which transformer(s) to merge. 
One of \"all\" (default), \"transformer\", \"transformer_2\", \"critic\"\n- `strength` (float, optional): LoRA strength for merge, default 1.0. Values < 1.0 reduce the effect, values > 1.0 amplify the effect.\n\n**Curl Example:**\n\n```bash\ncurl -X POST http://localhost:30010/v1/merge_lora_weights \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\"strength\": 0.8}'\n```\n\n\n**Unmerge LoRA Weights**\n\nUnmerges the currently active LoRA weights from the base model, restoring it to its original state. This **must** be called before setting a different LoRA.\n\n**Endpoint:** `POST /v1/unmerge_lora_weights`\n\n**Curl Example:**\n\n```bash\ncurl -X POST http://localhost:30010/v1/unmerge_lora_weights \\\n  -H \"Content-Type: application/json\"\n```\n\n**List LoRA Adapters**\n\nReturns loaded LoRA adapters and current application status per module.\n\n**Endpoint:** `GET /v1/list_loras`\n\n**Curl Example:**\n\n```bash\ncurl -sS -X GET \"http://localhost:30010/v1/list_loras\"\n```\n\n**Response Example:**\n\n```json\n{\n  \"loaded_adapters\": [\n    { \"nickname\": \"lora_a\", \"path\": \"/weights/lora_a.safetensors\" },\n    { \"nickname\": \"lora_b\", \"path\": \"/weights/lora_b.safetensors\" }\n  ],\n  \"active\": {\n    \"transformer\": [\n      {\n        \"nickname\": \"lora_a\",\n        \"path\": \"/weights/lora_a.safetensors\",\n        \"merged\": true,\n        \"strength\": 1.0\n      }\n    ]\n  }\n}\n```\n\nNotes:\n- If LoRA is not enabled for the current pipeline, the server will return an error.\n- `num_lora_layers_with_weights` counts only layers that have LoRA weights applied for the active adapter.\n\n### Example: Switching LoRAs\n\n1.  Set LoRA A:\n    ```bash\n    curl -X POST http://localhost:30010/v1/set_lora -H \"Content-Type: application/json\" -d '{\"lora_nickname\": \"lora_a\", \"lora_path\": \"path/to/A\"}'\n    ```\n2.  Generate with LoRA A...\n3.  Unmerge LoRA A:\n    ```bash\n    curl -X POST http://localhost:30010/v1/unmerge_lora_weights\n    ```\n4.  
Set LoRA B:\n    ```bash\n    curl -X POST http://localhost:30010/v1/set_lora -H \"Content-Type: application/json\" -d '{\"lora_nickname\": \"lora_b\", \"lora_path\": \"path/to/B\"}'\n    ```\n5.  Generate with LoRA B...\n\n### Adjust Output Quality\n\nThe server supports adjusting output quality and compression levels for both image and video generation through the `output-quality` and `output-compression` parameters.\n\n#### Parameters\n\n- **`output-quality`** (string, optional): Preset quality level that automatically sets compression. **Default is `\"default\"`**. Valid values:\n  - `\"maximum\"`: Highest quality (100)\n  - `\"high\"`: High quality (90)\n  - `\"medium\"`: Medium quality (55)\n  - `\"low\"`: Lower quality (35)\n  - `\"default\"`: Auto-adjust based on media type (50 for video, 75 for image)\n\n- **`output-compression`** (integer, optional): Direct compression level override (0-100). **Default is `None`**. When provided (not `None`), takes precedence over `output-quality`.\n  - `0`: Lowest quality, smallest file size\n  - `100`: Highest quality, largest file size\n\n#### Notes\n\n- **Precedence**: When both `output-quality` and `output-compression` are provided, `output-compression` takes precedence.\n- **Format Support**: Quality settings apply to JPEG and video formats. PNG uses lossless compression and ignores these settings.\n- **File Size vs Quality**: Lower compression values (or the \"low\" quality preset) produce smaller files but may show visible artifacts.\n"
  },
  {
    "path": "docs/diffusion/api/post_processing.md",
    "content": "# Post-Processing\n\nSGLang diffusion supports optional post-processing steps that run after\ngeneration to improve temporal smoothness (frame interpolation) or spatial\nresolution (upscaling). These steps are independent of the diffusion model and\ncan be combined in a single run.\n\nWhen both are enabled, **frame interpolation runs first** (increasing the frame\ncount), then **upscaling runs on every frame** (increasing the spatial\nresolution).\n\n---\n\n## Frame Interpolation (video only)\n\nFrame interpolation synthesizes new frames between each pair of consecutive\ngenerated frames, producing smoother motion without re-running the diffusion\nmodel.\n\nThe `--frame-interpolation-exp` flag controls how many rounds of interpolation\nto apply: each round inserts one new frame into every gap between adjacent\nframes, so the output frame count follows the formula:\n\n> **(N − 1) × 2^exp + 1**\n>\n> e.g. 5 original frames with `exp=1` → 4 gaps × 1 new frame + 5 originals = **9** frames;\n> with `exp=2` → **17** frames.\n\n### CLI Arguments\n\n| Argument | Description |\n|----------|-------------|\n| `--enable-frame-interpolation` | Enable frame interpolation. Model weights are downloaded automatically on first use. |\n| `--frame-interpolation-exp {EXP}` | Interpolation exponent — `1` = 2× temporal resolution, `2` = 4×, etc. (default: `1`) |\n| `--frame-interpolation-scale {SCALE}` | RIFE inference scale; use `0.5` for high-resolution inputs to save memory (default: `1.0`) |\n| `--frame-interpolation-model-path {PATH}` | Local directory or HuggingFace repo ID containing RIFE `flownet.pkl` weights (default: `elfgum/RIFE-4.22.lite`, downloaded automatically) |\n\n### Supported Models\n\nFrame interpolation uses the [RIFE](https://github.com/hzwer/Practical-RIFE)\n(Real-Time Intermediate Flow Estimation) architecture. Only **RIFE 4.22.lite**\n(`IFNet` with 4-scale `IFBlock` backbone) is supported. 
The network topology is\nhard-coded, so custom weights provided via `--frame-interpolation-model-path`\nmust be a `flownet.pkl` checkpoint that is compatible with this architecture.\n\nOther RIFE versions (e.g., older `v4.x` variants with different block counts)\nor entirely different frame interpolation methods (FILM, AMT, etc.) are **not\nsupported**.\n\n| Weight | HuggingFace Repo | Description |\n|--------|------------------|-------------|\n| RIFE 4.22.lite *(default)* | [`elfgum/RIFE-4.22.lite`](https://huggingface.co/elfgum/RIFE-4.22.lite) | Lightweight model, downloaded automatically on first use |\n\n### Example\n\nGenerate a 5-frame video and interpolate to 9 frames ((5 − 1) × 2¹ + 1 = 9):\n\n```bash\nsglang generate \\\n  --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \\\n  --prompt \"A dog running through a park\" \\\n  --num-frames 5 \\\n  --enable-frame-interpolation \\\n  --frame-interpolation-exp 1 \\\n  --save-output\n```\n\n---\n\n## Upscaling (image and video)\n\nUpscaling increases the spatial resolution of generated images or video frames\nusing [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN). The model weights\nare downloaded automatically on first use and cached for subsequent runs.\n\n### CLI Arguments\n\n| Argument | Description |\n|----------|-------------|\n| `--enable-upscaling` | Enable post-generation upscaling using Real-ESRGAN. |\n| `--upscaling-scale {SCALE}` | Desired upscaling factor (default: `4`). The 4× model is used internally; if a different scale is requested, a bicubic resize is applied after the network output. |\n| `--upscaling-model-path {PATH}` | Local `.pth` file, HuggingFace repo ID, or `repo_id:filename` for Real-ESRGAN weights (default: `ai-forever/Real-ESRGAN` with `RealESRGAN_x4.pth`, downloaded automatically). Use the `repo_id:filename` format to specify a custom weight file from a HuggingFace repo (e.g. `my-org/my-esrgan:weights.pth`). 
|\n\n### Supported Models\n\nUpscaling supports two Real-ESRGAN network architectures. The correct\narchitecture is **auto-detected** from the checkpoint keys, so you only need to\npoint `--upscaling-model-path` at a valid `.pth` file:\n\n| Architecture | Example Weights | Description |\n|--------------|-----------------|-------------|\n| **RRDBNet** | `RealESRGAN_x4plus.pth` | Heavier model with higher quality; best for photos |\n| **SRVGGNetCompact** | `RealESRGAN_x4.pth` *(default)*, `realesr-animevideov3.pth`, `realesr-general-x4v3.pth` | Lightweight model; faster inference, good for video |\n\nThe default weight file is\n[`ai-forever/Real-ESRGAN`](https://huggingface.co/ai-forever/Real-ESRGAN) with\n`RealESRGAN_x4.pth` (SRVGGNetCompact, 4× native scale).\n\nOther super-resolution models (e.g., SwinIR, HAT, BSRGAN) are **not supported**\n— only Real-ESRGAN checkpoints using the two architectures above are\ncompatible.\n\n### Examples\n\nGenerate a 1024×1024 image and upscale to 4096×4096:\n\n```bash\nsglang generate \\\n  --model-path black-forest-labs/FLUX.2-dev \\\n  --prompt \"A cat sitting on a windowsill\" \\\n  --output-size 1024x1024 \\\n  --enable-upscaling \\\n  --save-output\n```\n\nGenerate a video and upscale each frame by 4×:\n\n```bash\nsglang generate \\\n  --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \\\n  --prompt \"A curious raccoon\" \\\n  --enable-upscaling \\\n  --upscaling-scale 4 \\\n  --save-output\n```\n\n---\n\n## Combining Frame Interpolation and Upscaling\n\nFrame interpolation and upscaling can be combined in a single run.\nInterpolation is applied first (increasing the frame count), then upscaling is\napplied to every frame (increasing the spatial resolution).\n\nExample — generate 5 frames, interpolate to 9 frames, and upscale each frame\nby 4×:\n\n```bash\nsglang generate \\\n  --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \\\n  --prompt \"A curious raccoon\" \\\n  --num-frames 5 \\\n  --enable-frame-interpolation \\\n  
--frame-interpolation-exp 1 \\\n  --enable-upscaling \\\n  --upscaling-scale 4 \\\n  --save-output\n```\n"
  },
  {
    "path": "docs/diffusion/ci_perf.md",
    "content": "## Perf Baseline Generation Script\n\n`python/sglang/multimodal_gen/test/scripts/gen_perf_baselines.py` starts a local diffusion server, issues requests for selected test cases, aggregates stage/denoise-step/E2E timings from the perf log, and writes the results back to the `scenarios` section of `perf_baselines.json`.\n\n### Usage\n\nUpdate a single case:\n\n```bash\npython python/sglang/multimodal_gen/test/scripts/gen_perf_baselines.py --case qwen_image_t2i\n```\n\nSelect by regex:\n\n```bash\npython python/sglang/multimodal_gen/test/scripts/gen_perf_baselines.py --match 'qwen_image_.*'\n```\n\nRun all keys from the baseline file `scenarios`:\n\n```bash\npython python/sglang/multimodal_gen/test/scripts/gen_perf_baselines.py --all-from-baseline\n```\n\nSpecify input/output paths and timeout:\n\n```bash\npython python/sglang/multimodal_gen/test/scripts/gen_perf_baselines.py --baseline python/sglang/multimodal_gen/test/server/perf_baselines.json --out /tmp/perf_baselines.json --timeout 600\n```\n"
  },
  {
    "path": "docs/diffusion/compatibility_matrix.md",
    "content": "# Compatibility Matrix\n\nThe table below shows every supported model and the optimizations supported for them.\n\nThe symbols used have the following meanings:\n\n- ✅ = Full compatibility\n- ❌ = No compatibility\n- ⭕ = Does not apply to this model\n\n## Models x Optimization\n\nThe `HuggingFace Model ID` can be passed directly to `from_pretrained()` methods, and sglang-diffusion will use the\noptimal\ndefault parameters when initializing and generating videos.\n\n### Video Generation Models\n\n| Model Name                   | Hugging Face Model ID                             | Resolutions         | TeaCache | Sliding Tile Attn | Sage Attn | Video Sparse Attention (VSA) | Sparse Linear Attention (SLA) | Sage Sparse Linear Attention (SageSLA) | Sparse Video Gen 2 (SVG2) |\n|:-----------------------------|:--------------------------------------------------|:--------------------|:--------:|:-----------------:|:---------:|:----------------------------:|:----------------------------:|:-----------------------------------------------:|:----------------------------------:|\n| FastWan2.1 T2V 1.3B          | `FastVideo/FastWan2.1-T2V-1.3B-Diffusers`         | 480p                |    ⭕     |         ⭕         |      ⭕     |              ✅               |              ❌               |              ❌               |    ❌     |\n| FastWan2.2 TI2V 5B Full Attn | `FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers` | 720p                |    ⭕     |         ⭕         |     ⭕     |              ✅               |              ❌               |              ❌               |    ❌     |\n| Wan2.2 TI2V 5B               | `Wan-AI/Wan2.2-TI2V-5B-Diffusers`                 | 720p                |    ⭕     |         ⭕         |     ✅     |              ⭕               |              ❌               |              ❌               |    ❌     |\n| Wan2.2 T2V A14B              | `Wan-AI/Wan2.2-T2V-A14B-Diffusers`                | 480p<br>720p        |    ❌     |         ❌        
 |     ✅     |              ⭕               |              ❌               |              ❌               |    ❌     |\n| Wan2.2 I2V A14B              | `Wan-AI/Wan2.2-I2V-A14B-Diffusers`                | 480p<br>720p        |    ❌     |         ❌         |     ✅     |              ⭕               |              ❌               |              ❌               |    ❌     |\n| HunyuanVideo                 | `hunyuanvideo-community/HunyuanVideo`             | 720×1280<br>544×960 |    ❌     |         ✅         |     ✅     |              ⭕               |              ❌               |              ❌               |    ✅     |\n| FastHunyuan                  | `FastVideo/FastHunyuan-diffusers`                 | 720×1280<br>544×960 |    ❌     |         ✅         |     ✅     |              ⭕               |              ❌               |              ❌               |    ✅     |\n| Wan2.1 T2V 1.3B              | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`                | 480p                |    ✅     |         ✅         |     ✅     |              ⭕               |              ❌               |              ❌               |    ✅     |\n| Wan2.1 T2V 14B               | `Wan-AI/Wan2.1-T2V-14B-Diffusers`                 | 480p, 720p          |    ✅     |         ✅         |     ✅     |              ⭕               |              ❌               |              ❌               |    ✅     |\n| Wan2.1 I2V 480P              | `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers`            | 480p                |    ✅     |         ✅         |     ✅     |              ⭕               |              ❌               |              ❌               |    ✅     |\n| Wan2.1 I2V 720P              | `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers`            | 720p                |    ✅     |         ✅         |     ✅     |              ⭕               |              ❌               |              ❌               |    ✅     |\n| TurboWan2.1 T2V 1.3B         | `IPostYellow/TurboWan2.1-T2V-1.3B-Diffusers`      | 480p      
          |    ✅     |         ❌         |     ❌     |              ❌               |              ✅               |              ✅               |    ⭕     |\n| TurboWan2.1 T2V 14B          | `IPostYellow/TurboWan2.1-T2V-14B-Diffusers`       | 480p                |    ✅     |         ❌         |     ❌     |              ❌               |              ✅               |              ✅               |    ⭕     |\n| TurboWan2.1 T2V 14B 720P     | `IPostYellow/TurboWan2.1-T2V-14B-720P-Diffusers`  | 720p                |    ✅     |         ❌         |     ❌     |              ❌               |              ✅               |              ✅               |    ⭕     |\n| TurboWan2.2 I2V A14B         | `IPostYellow/TurboWan2.2-I2V-A14B-Diffusers`      | 720p                |    ✅     |         ❌         |     ❌     |              ❌               |              ✅               |              ✅               |    ⭕     |\n\n**Note**:\n\n1. Wan2.2 TI2V 5B has some quality issues when performing I2V generation. We are working on fixing this issue.\n2. SageSLA is based on SpargeAttn. 
Install it first with `pip install git+https://github.com/thu-ml/SpargeAttn.git --no-build-isolation`\n\n### Image Generation Models\n\n| Model Name       | HuggingFace Model ID                    | Resolutions    |\n|:-----------------|:----------------------------------------|:---------------|\n| FLUX.1-dev       | `black-forest-labs/FLUX.1-dev`          | Any resolution |\n| FLUX.2-dev       | `black-forest-labs/FLUX.2-dev`          | Any resolution |\n| FLUX.2-Klein     | `black-forest-labs/FLUX.2-klein-4B`     | Any resolution |\n| Z-Image-Turbo    | `Tongyi-MAI/Z-Image-Turbo`              | Any resolution |\n| GLM-Image        | `zai-org/GLM-Image`                     | Any resolution |\n| Qwen Image       | `Qwen/Qwen-Image`                       | Any resolution |\n| Qwen Image 2512  | `Qwen/Qwen-Image-2512`                  | Any resolution |\n| Qwen Image Edit  | `Qwen/Qwen-Image-Edit`                  | Any resolution |\n\n## Verified LoRA Examples\n\nThis section lists example LoRAs that have been explicitly tested and verified with each base model in the **SGLang Diffusion** pipeline.\n\n> Important:\n> LoRAs that are not listed here are not necessarily incompatible.\n> In practice, most standard LoRAs are expected to work, especially those following common Diffusers or SD-style conventions.\n> The entries below simply reflect configurations that have been manually validated by the SGLang team.\n\n### Verified LoRAs by Base Model\n\n| Base Model       | Supported LoRAs |\n|:-----------------|:----------------|\n| Wan2.2           | `lightx2v/Wan2.2-Distill-Loras`<br>`Cseti/wan2.2-14B-Arcane_Jinx-lora-v1` |\n| Wan2.1           | `lightx2v/Wan2.1-Distill-Loras` |\n| Z-Image-Turbo    | `tarn59/pixel_art_style_lora_z_image_turbo`<br>`wcde/Z-Image-Turbo-DeJPEG-Lora` |\n| Qwen-Image       | `lightx2v/Qwen-Image-Lightning`<br>`flymy-ai/qwen-image-realism-lora`<br>`prithivMLmods/Qwen-Image-HeadshotX`<br>`starsfriday/Qwen-Image-EVA-LoRA` |\n| Qwen-Image-Edit  | 
`ostris/qwen_image_edit_inpainting`<br>`lightx2v/Qwen-Image-Edit-2511-Lightning` |\n| Flux             | `dvyio/flux-lora-simple-illustration`<br>`XLabs-AI/flux-furry-lora`<br>`XLabs-AI/flux-RealismLora` |\n\n## Special Requirements\n\n### Sliding Tile Attention\n\n- Currently, only Hopper GPUs (H100s) are supported.\n"
  },
  {
    "path": "docs/diffusion/contributing.md",
    "content": "# Contributing to SGLang Diffusion\n\nThis guide outlines the requirements for contributing to the SGLang Diffusion module (`sglang.multimodal_gen`).\n\n## On AI-Assisted (\"Vibe Coding\") PRs\n\nVibe-coded PRs are welcome — we judge code quality, not how it was produced. The bar is the same for all PRs:\n\n- **No over-commenting.** If the name says it all, skip the docstring.\n- **No over-catching.** Don't guard against errors that virtually never happen in practice.\n- **Test before submitting.** AI-generated code can be subtly wrong — verify correctness end-to-end.\n\n## Commit Message Convention\n\nWe follow a structured commit message format to maintain a clean history.\n\n**Format:**\n```text\n[diffusion] <scope>: <subject>\n```\n\n**Examples:**\n- `[diffusion] cli: add --perf-dump-path argument`\n- `[diffusion] scheduler: fix deadlock in batch processing`\n- `[diffusion] model: support Stable Diffusion 3.5`\n\n**Rules:**\n- **Prefix**: Always start with `[diffusion]`.\n- **Scope** (Optional): `cli`, `scheduler`, `model`, `pipeline`, `docs`, etc.\n- **Subject**: Imperative mood, short and clear (e.g., \"add feature\" not \"added feature\").\n\n## Performance Reporting\n\nFor PRs that impact **latency**, **throughput**, or **memory usage**, you **should** provide a performance comparison report.\n\n### How to Generate a Report\n\n1.  **Baseline**: run the benchmark (for a single generation task)\n    ```bash\n    $ sglang generate --model-path <model> --prompt \"A benchmark prompt\" --perf-dump-path baseline.json\n    ```\n\n2.  **New**: run the same benchmark, without modifying any server_args or sampling_params\n    ```bash\n    $ sglang generate --model-path <model> --prompt \"A benchmark prompt\" --perf-dump-path new.json\n    ```\n\n3.  
**Compare**: run the compare script, which will print a Markdown table to the console\n    ```bash\n    $ python python/sglang/multimodal_gen/benchmarks/compare_perf.py baseline.json new.json [new2.json ...]\n    ### Performance Comparison Report\n    ...\n    ```\n4. **Paste**: paste the table into the PR description\n\n## CI-Based Change Protection\n\nConsider adding tests to the `pr-test` or `nightly-test` suites to safeguard your changes, especially for PRs that:\n\n- support a new model\n    - add a testcase for this new model to `testcase_configs.py`\n- support or fix important features\n- significantly improve performance\n\nPlease run the corresponding testcase, then add or update its baseline in `perf_baselines.json` by following the instructions printed to the console, if applicable.\n\nSee [test](https://github.com/sgl-project/sglang/tree/main/python/sglang/multimodal_gen/test) for examples.\n"
  },
  {
    "path": "docs/diffusion/environment_variables.md",
    "content": "## Apple MPS\n\n| Environment Variable | Default | Description                                                  |\n|----------------------|---------|--------------------------------------------------------------|\n| `SGLANG_USE_MLX`     | not set | Set to `1` to enable MLX fused Metal kernels for norm ops on MPS |\n\n## Caching Acceleration\n\nThese variables configure caching acceleration for Diffusion Transformer (DiT) models.\nSGLang supports multiple caching strategies - see [caching documentation](performance/cache/index.md) for an overview.\n\n### Cache-DiT Configuration\n\nSee [cache-dit documentation](performance/cache/cache_dit.md) for detailed configuration.\n\n| Environment Variable                | Default | Description                              |\n|-------------------------------------|---------|------------------------------------------|\n| `SGLANG_CACHE_DIT_ENABLED`          | false   | Enable Cache-DiT acceleration            |\n| `SGLANG_CACHE_DIT_FN`               | 1       | First N blocks to always compute         |\n| `SGLANG_CACHE_DIT_BN`               | 0       | Last N blocks to always compute          |\n| `SGLANG_CACHE_DIT_WARMUP`           | 4       | Warmup steps before caching              |\n| `SGLANG_CACHE_DIT_RDT`              | 0.24    | Residual difference threshold            |\n| `SGLANG_CACHE_DIT_MC`               | 3       | Max continuous cached steps              |\n| `SGLANG_CACHE_DIT_TAYLORSEER`       | false   | Enable TaylorSeer calibrator             |\n| `SGLANG_CACHE_DIT_TS_ORDER`         | 1       | TaylorSeer order (1 or 2)                |\n| `SGLANG_CACHE_DIT_SCM_PRESET`       | none    | SCM preset (none/slow/medium/fast/ultra) |\n| `SGLANG_CACHE_DIT_SCM_POLICY`       | dynamic | SCM caching policy                       |\n| `SGLANG_CACHE_DIT_SCM_COMPUTE_BINS` | not set | Custom SCM compute bins                  |\n| `SGLANG_CACHE_DIT_SCM_CACHE_BINS`   | not set | Custom SCM cache bins           
         |\n\n## Cloud Storage\n\nThese variables configure S3-compatible cloud storage for automatically uploading generated images and videos.\n\n| Environment Variable            | Default | Description                                            |\n|---------------------------------|---------|--------------------------------------------------------|\n| `SGLANG_CLOUD_STORAGE_TYPE`     | not set | Set to `s3` to enable cloud storage                    |\n| `SGLANG_S3_BUCKET_NAME`         | not set | The name of the S3 bucket                              |\n| `SGLANG_S3_ENDPOINT_URL`        | not set | Custom endpoint URL (for MinIO, OSS, etc.)             |\n| `SGLANG_S3_REGION_NAME`         | us-east-1 | AWS region name                                      |\n| `SGLANG_S3_ACCESS_KEY_ID`       | not set | AWS Access Key ID                                      |\n| `SGLANG_S3_SECRET_ACCESS_KEY`   | not set | AWS Secret Access Key                                  |\n"
  },
  {
    "path": "docs/diffusion/index.md",
    "content": "# SGLang Diffusion\n\nSGLang Diffusion is an inference framework for accelerated image and video generation using diffusion models. It provides an end-to-end unified pipeline with optimized kernels and an efficient scheduler loop.\n\n## Key Features\n\n- **Broad Model Support**: Wan series, FastWan series, Hunyuan, Qwen-Image, Qwen-Image-Edit, Flux, Z-Image, GLM-Image, and more\n- **Fast Inference**: Optimized kernels, efficient scheduler loop, and Cache-DiT acceleration\n- **Ease of Use**: OpenAI-compatible API, CLI, and Python SDK\n- **Multi-Platform**:\n  - NVIDIA GPUs (H100, H200, A100, B200, 4090)\n  - AMD GPUs (MI300X, MI325X)\n  - Ascend NPU (A2, A3)\n  - Apple Silicon (M-series via MPS)\n  - Moore Threads GPUs (MTT S5000)\n\n---\n\n## Quick Start\n\n### Installation\n\n```bash\nuv pip install \"sglang[diffusion]\" --prerelease=allow\n```\n\nSee [Installation Guide](installation.md) for more installation methods and ROCm-specific instructions.\n\n### Basic Usage\n\nGenerate an image with the CLI:\n\n```bash\nsglang generate --model-path Qwen/Qwen-Image \\\n    --prompt \"A beautiful sunset over the mountains\" \\\n    --save-output\n```\n\nOr start a server with the OpenAI-compatible API:\n\n```bash\nsglang serve --model-path Qwen/Qwen-Image --port 30010\n```\n\n---\n\n## Documentation\n\n### Getting Started\n\n- **[Installation](installation.md)** - Install SGLang Diffusion via pip, uv, Docker, or from source\n- **[Compatibility Matrix](compatibility_matrix.md)** - Supported models and optimization compatibility\n\n### Usage\n\n- **[CLI Documentation](api/cli.md)** - Command-line interface for `sglang generate` and `sglang serve`\n- **[OpenAI API](api/openai_api.md)** - OpenAI-compatible API for image/video generation and LoRA management\n- **[Post-Processing](api/post_processing.md)** - Frame interpolation (RIFE) and upscaling (Real-ESRGAN)\n\n### Performance Optimization\n\n- **[Performance Overview](performance/index.md)** - Overview of 
all performance optimization strategies\n- **[Attention Backends](performance/attention_backends.md)** - Available attention backends (FlashAttention, SageAttention, etc.)\n- **[Caching Strategies](performance/cache/)** - Cache-DiT and TeaCache acceleration\n- **[Profiling](performance/profiling.md)** - Profiling techniques with PyTorch Profiler and Nsight Systems\n\n### Reference\n\n- **[Environment Variables](environment_variables.md)** - Configuration via environment variables\n- **[Support New Models](support_new_models.md)** - Guide for adding new diffusion models\n- **[Contributing](contributing.md)** - Contribution guidelines and commit message conventions\n- **[CI Performance](ci_perf.md)** - Performance baseline generation script\n\n---\n\n## CLI Quick Reference\n\n### Generate (one-off generation)\n\n```bash\nsglang generate --model-path <MODEL> --prompt \"<PROMPT>\" --save-output\n```\n\n### Serve (HTTP server)\n\n```bash\nsglang serve --model-path <MODEL> --port 30010\n```\n\n### Enable Cache-DiT acceleration\n\n```bash\nSGLANG_CACHE_DIT_ENABLED=true sglang generate --model-path <MODEL> --prompt \"<PROMPT>\"\n```\n\n---\n\n## References\n\n- [SGLang GitHub](https://github.com/sgl-project/sglang)\n- [Cache-DiT](https://github.com/vipshop/cache-dit)\n- [FastVideo](https://github.com/hao-ai-lab/FastVideo)\n- [xDiT](https://github.com/xdit-project/xDiT)\n- [Diffusers](https://github.com/huggingface/diffusers)\n"
  },
  {
    "path": "docs/diffusion/installation.md",
    "content": "# Install SGLang-Diffusion\n\nYou can install SGLang-Diffusion using one of the methods below.\n\n## Standard Installation (NVIDIA GPUs)\n\n### Method 1: With pip or uv\n\nIt is recommended to use uv for a faster installation:\n\n```bash\npip install --upgrade pip\npip install uv\nuv pip install \"sglang[diffusion]\" --prerelease=allow\n```\n\n### Method 2: From source\n\n```bash\n# Use the latest release branch\ngit clone https://github.com/sgl-project/sglang.git\ncd sglang\n\n# Install the Python packages\npip install --upgrade pip\npip install -e \"python[diffusion]\"\n\n# With uv\nuv pip install -e \"python[diffusion]\" --prerelease=allow\n```\n\n### Method 3: Using Docker\n\nThe Docker images are available on Docker Hub at [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang), built from the [Dockerfile](https://github.com/sgl-project/sglang/blob/main/docker/Dockerfile).\nReplace `<secret>` below with your HuggingFace Hub [token](https://huggingface.co/docs/hub/en/security-tokens).\n\n```bash\ndocker run --gpus all \\\n    --shm-size 32g \\\n    -p 30000:30000 \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HF_TOKEN=<secret>\" \\\n    --ipc=host \\\n    lmsysorg/sglang:dev \\\n    zsh -c '\\\n        echo \"Installing diffusion dependencies...\" && \\\n        pip install -e \"python[diffusion]\" && \\\n        echo \"Starting SGLang-Diffusion...\" && \\\n        sglang generate \\\n            --model-path black-forest-labs/FLUX.1-dev \\\n            --prompt \"A logo With Bold Large text: SGL Diffusion\" \\\n            --save-output \\\n    '\n```\n\n## Platform-Specific: ROCm (AMD GPUs)\n\nFor AMD Instinct GPUs (e.g., MI300X), you can use the ROCm-enabled Docker image:\n\n```bash\ndocker run --device=/dev/kfd --device=/dev/dri --ipc=host \\\n  -v ~/.cache/huggingface:/root/.cache/huggingface \\\n  --env HF_TOKEN=<secret> \\\n  lmsysorg/sglang:v0.5.5.post2-rocm700-mi30x \\\n  sglang generate --model-path 
black-forest-labs/FLUX.1-dev --prompt \"A logo With Bold Large text: SGL Diffusion\" --save-output\n```\n\nFor detailed ROCm system configuration and installation from source, see [AMD GPUs](../platforms/amd_gpu.md).\n\n## Platform-Specific: MUSA (Moore Threads GPUs)\n\nFor Moore Threads GPUs (MTGPU) with the MUSA software stack, please follow the instructions below to install from source:\n\n```bash\n# Clone the repository\ngit clone https://github.com/sgl-project/sglang.git\ncd sglang\n\n# Install the Python packages\npip install --upgrade pip\nrm -f python/pyproject.toml && mv python/pyproject_other.toml python/pyproject.toml\npip install -e \"python[all_musa]\"\n```\n\n## Platform-Specific: Ascend NPU\n\nFor Ascend NPU, please follow the [NPU installation guide](../platforms/ascend_npu.md).\n\nQuick test:\n\n```bash\nsglang generate --model-path black-forest-labs/FLUX.1-dev \\\n    --prompt \"A logo With Bold Large text: SGL Diffusion\" \\\n    --save-output\n```\n\n## Platform-Specific: Apple MPS\n\nFor Apple MPS, please follow the instructions below to install from source:\n\n```bash\n# Install ffmpeg\nbrew install ffmpeg\n\n# Install uv\nbrew install uv\n\n# Clone the repository\ngit clone https://github.com/sgl-project/sglang.git\ncd sglang\n\n# Create and activate a virtual environment\nuv venv -p 3.11 sglang-diffusion\nsource sglang-diffusion/bin/activate\n\n# Install the Python packages\nuv pip install --upgrade pip\nrm -f python/pyproject.toml && mv python/pyproject_other.toml python/pyproject.toml\nuv pip install -e \"python[all_mps]\"\n```\n"
  },
  {
    "path": "docs/diffusion/performance/attention_backends.md",
    "content": "# Attention Backends\n\nThis document describes the attention backends available in SGLang Diffusion (`sglang.multimodal_gen`) and how to select them.\n\n## Overview\n\nAttention backends are defined by `AttentionBackendEnum` (`sglang.multimodal_gen.runtime.platforms.interface.AttentionBackendEnum`) and selected via the CLI flag `--attention-backend`.\n\nBackend selection is performed by the shared attention layers (e.g., `LocalAttention` / `USPAttention` / `UlyssesAttention` in `sglang.multimodal_gen.runtime.layers.attention.layer`) and therefore applies to any model component using these layers (e.g., the diffusion transformer and encoders).\n\nWhen using the diffusers backend, `--attention-backend` is passed through to diffusers'\n`set_attention_backend` (e.g., `flash`, `_flash_3_hub`, `sage`, `xformers`, `native`).\n\nFor SGLang-native pipelines, when `--attention-backend` is not set, the backend is selected automatically per platform:\n\n- **CUDA**: prefers FlashAttention (FA3/FA4) when supported; otherwise falls back to PyTorch SDPA.\n- **ROCm**: uses FlashAttention when available; otherwise falls back to PyTorch SDPA.\n- **MPS**: always uses PyTorch SDPA.\n- **NPU**: always uses PyTorch SDPA.\n\n## Backend options\n\nFor SGLang-native pipelines, the CLI accepts the lowercase names of `AttentionBackendEnum`. The table below lists the backends implemented by the built-in platforms. `fa3`/`fa4` are accepted as aliases for `fa`.\n\n| CLI value | Enum value | Notes |\n|---|---|---|\n| `fa` / `fa3` / `fa4` | `FA` | FlashAttention. `fa3/fa4` are normalized to `fa` during argument parsing (`ServerArgs.__post_init__`). |\n| `torch_sdpa` | `TORCH_SDPA` | PyTorch `scaled_dot_product_attention`. |\n| `sliding_tile_attn` | `SLIDING_TILE_ATTN` | Sliding Tile Attention (STA). Requires `st_attn`. Configure via `--attention-backend-config`. |\n| `sage_attn` | `SAGE_ATTN` | Requires `sageattention`.
Upstream SageAttention CUDA extensions target SM80/SM86/SM89/SM90/SM120 (compute capability 8.0/8.6/8.9/9.0/12.0); see upstream `setup.py`: https://github.com/thu-ml/SageAttention/blob/main/setup.py. |\n| `sage_attn_3` | `SAGE_ATTN_3` | Requires SageAttention3 installed per upstream instructions. |\n| `video_sparse_attn` | `VIDEO_SPARSE_ATTN` | Requires `vsa`. Configure `sparsity` via `--attention-backend-config`. |\n| `vmoba_attn` | `VMOBA_ATTN` | Requires `kernel.attn.vmoba_attn.vmoba`. Configure via `--attention-backend-config`. |\n| `aiter` | `AITER` | Requires `aiter`. |\n| `aiter_sage` | `AITER_SAGE` | Requires `aiter`. |\n| `sparse_video_gen_2_attn` | `SPARSE_VIDEO_GEN_2_ATTN` | Requires `svg`. See installation instructions at https://github.com/svg-project/Sparse-VideoGen. |\n\n## Selection priority\n\nThe selection order in `runtime/layers/attention/selector.py` is:\n\n1. `global_force_attn_backend(...)` / `global_force_attn_backend_context_manager(...)`\n2. CLI `--attention-backend` (`ServerArgs.attention_backend`)\n3. Auto selection (platform capability, dtype, and installed packages)\n\n## Configuration\n\nSome backends require additional configuration. You can pass these parameters via `--attention-backend-config`. This argument accepts:\n- A path to a JSON or YAML configuration file.\n- A JSON string (e.g., `'{\"sparsity\": 0.5}'`).\n- Key-value pairs (e.g., `\"sparsity=0.5,enable_x=true\"`).\n\n### Supported Configuration Parameters\n\n**Sliding Tile Attention (`sliding_tile_attn`)**\n\n| Parameter | Type | Description | Default |\n| :--- | :--- | :--- | :--- |\n| `mask_strategy_file_path` | `str` | **Required.** Path to the mask strategy JSON file. | - |\n| `sta_mode` | `str` | Mode of STA. | `STA_inference` |\n| `skip_time_steps` | `int` | Number of steps to use full attention before switching to sparse attention. 
| `15` |\n\n**Video Sparse Attention (`video_sparse_attn`)**\n\n| Parameter | Type | Description | Default |\n| :--- | :--- | :--- | :--- |\n| `sparsity` | `float` | Validation sparsity (0.0 - 1.0). | `0.0` |\n\n**V-MoBA (`vmoba_attn`)**\n\n| Parameter | Type | Description | Default |\n| :--- | :--- | :--- | :--- |\n| `temporal_chunk_size` | `int` | Chunk size for temporal dimension. | - |\n| `temporal_topk` | `int` | Top-K tokens to select in temporal dimension. | - |\n| `spatial_chunk_size` | `list[int]` | Chunk size for spatial dimension (H, W). | - |\n| `spatial_topk` | `int` | Top-K tokens to select in spatial dimension. | - |\n| `st_chunk_size` | `list[int]` | Chunk size for spatiotemporal dimension (T, H, W). | - |\n| `st_topk` | `int` | Top-K tokens to select in spatiotemporal dimension. | - |\n| `moba_select_mode` | `str` | Selection mode (e.g., `threshold`). | `threshold` |\n| `moba_threshold` | `float` | Threshold value for selection. | `0.25` |\n| `moba_threshold_type` | `str` | Type of thresholding (e.g., `query_head`). | `query_head` |\n| `first_full_step` | `int` | Number of initial steps to use full attention. | `12` |\n| `first_full_layer` | `int` | Number of initial layers to use full attention. | `0` |\n| `temporal_layer` | `int` | Number of temporal layers. | `1` |\n| `spatial_layer` | `int` | Number of spatial layers. | `1` |\n| `st_layer` | `int` | Number of spatiotemporal layers. | `1` |\n\n## Platform support matrix\n\n| Backend | CUDA | ROCm | MPS | NPU | Notes |\n|---|---:|---:|---:|---:|---|\n| `fa` | ✅ | ✅ | ❌ | ❌ | CUDA requires SM80+ and fp16/bf16. FlashAttention is only used when the required runtime is installed; otherwise it falls back to `torch_sdpa`. |\n| `torch_sdpa` | ✅ | ✅ | ✅ | ✅ | Most compatible option across platforms. |\n| `sliding_tile_attn` | ✅ | ❌ | ❌ | ❌ | CUDA-only. Requires `st_attn`. Configure via `--attention-backend-config`. |\n| `sage_attn` | ✅ | ❌ | ❌ | ❌ | CUDA-only (optional dependency). 
|\n| `sage_attn_3` | ✅ | ❌ | ❌ | ❌ | CUDA-only (optional dependency). |\n| `video_sparse_attn` | ✅ | ❌ | ❌ | ❌ | CUDA-only. Requires `vsa`. Configure `sparsity` via `--attention-backend-config`. |\n| `vmoba_attn` | ✅ | ❌ | ❌ | ❌ | CUDA-only. Requires `kernel.attn.vmoba_attn.vmoba`. Configure via `--attention-backend-config`. |\n| `aiter` | ❌ | ✅ | ❌ | ❌ | Requires `aiter`. |\n| `aiter_sage` | ❌ | ✅ | ❌ | ❌ | Requires `aiter`. |\n| `sparse_video_gen_2_attn` | ✅ | ❌ | ❌ | ❌ | CUDA-only. Requires `svg`. |\n\n## Usage\n\n### Select a backend via CLI\n\n```bash\nsglang generate \\\n  --model-path <MODEL_PATH_OR_ID> \\\n  --prompt \"...\" \\\n  --attention-backend fa\n```\n\n```bash\nsglang generate \\\n  --model-path <MODEL_PATH_OR_ID> \\\n  --prompt \"...\" \\\n  --attention-backend torch_sdpa\n```\n\n### Using Sliding Tile Attention (STA)\n\n```bash\n# Pass the mask strategy file path via config\nsglang generate \\\n  --model-path <MODEL_PATH_OR_ID> \\\n  --prompt \"...\" \\\n  --attention-backend sliding_tile_attn \\\n  --attention-backend-config \"mask_strategy_file_path=/abs/path/to/mask_strategy.json\"\n```\n\n### Notes for ROCm / MPS\n\n- ROCm: use `--attention-backend torch_sdpa` or `fa` depending on what is available in your environment.\n- MPS: the platform implementation always uses `torch_sdpa`.\n"
  },
  {
    "path": "docs/diffusion/performance/cache/cache_dit.md",
    "content": "# Cache-DiT Acceleration\n\nSGLang integrates [Cache-DiT](https://github.com/vipshop/cache-dit), a caching acceleration engine for Diffusion Transformers (DiT), to achieve up to **1.69x inference speedup** with minimal quality loss.\n\n## Overview\n\n**Cache-DiT** uses intelligent caching strategies to skip redundant computation in the denoising loop:\n\n- **DBCache (Dual Block Cache)**: Dynamically decides when to cache transformer blocks based on residual differences\n- **TaylorSeer**: Uses Taylor expansion for calibration to optimize caching decisions\n- **SCM (Step Computation Masking)**: Step-level caching control for additional speedup\n\n## Basic Usage\n\nEnable Cache-DiT by exporting the environment variable and using `sglang generate` or `sglang serve`:\n\n```bash\nSGLANG_CACHE_DIT_ENABLED=true \\\nsglang generate --model-path Qwen/Qwen-Image \\\n    --prompt \"A beautiful sunset over the mountains\"\n```\n\n## Diffusers Backend\n\nCache-DiT supports loading acceleration configs from a custom YAML file. For\ndiffusers pipelines (`diffusers` backend), pass the YAML/JSON path via `--cache-dit-config`.
This\nflow requires cache-dit >= 1.2.0 (`cache_dit.load_configs`).\n\n### Single GPU inference\n\nDefine a `cache.yaml` file that contains:\n\n- DBCache + TaylorSeer\n\n```yaml\ncache_config:\n  max_warmup_steps: 8\n  warmup_interval: 2\n  max_cached_steps: -1\n  max_continuous_cached_steps: 2\n  Fn_compute_blocks: 1\n  Bn_compute_blocks: 0\n  residual_diff_threshold: 0.12\n  enable_taylorseer: true\n  taylorseer_order: 1\n```\n\nThen apply the config with:\n\n```bash\nsglang generate \\\n  --backend diffusers \\\n  --model-path Qwen/Qwen-Image \\\n  --cache-dit-config cache.yaml \\\n  --prompt \"A beautiful sunset over the mountains\"\n```\n\n- DBCache + TaylorSeer + SCM (Step Computation Mask)\n\n```yaml\ncache_config:\n  max_warmup_steps: 8\n  warmup_interval: 2\n  max_cached_steps: -1\n  max_continuous_cached_steps: 2\n  Fn_compute_blocks: 1\n  Bn_compute_blocks: 0\n  residual_diff_threshold: 0.12\n  enable_taylorseer: true\n  taylorseer_order: 1\n  # Must set the num_inference_steps for SCM. 
 The SCM will automatically\n  # generate the steps computation mask based on the num_inference_steps.\n  # Reference: https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/#scm-steps-computation-masking\n  num_inference_steps: 28\n  steps_computation_mask: fast\n```\n\n- DBCache + TaylorSeer + SCM (Step Computation Mask) + Cache CFG\n\n```yaml\ncache_config:\n  max_warmup_steps: 8\n  warmup_interval: 2\n  max_cached_steps: -1\n  max_continuous_cached_steps: 2\n  Fn_compute_blocks: 1\n  Bn_compute_blocks: 0\n  residual_diff_threshold: 0.12\n  enable_taylorseer: true\n  taylorseer_order: 1\n  num_inference_steps: 28\n  steps_computation_mask: fast\n  enable_sperate_cfg: true # e.g., Qwen-Image, Wan, Chroma, Ovis-Image, etc.\n```\n\n### Distributed inference\n\n- 1D Parallelism\n\nDefine a parallelism-only config file `parallel.yaml` that contains:\n\n```yaml\nparallelism_config:\n  ulysses_size: auto\n  attention_backend: native\n```\n\n`ulysses_size: auto` means that cache-dit automatically uses the `world_size` as the Ulysses size; otherwise, set it manually to a specific integer, e.g., 4.\n\nThen apply the distributed config as follows (add `--num-gpus N` to specify the number of GPUs for distributed inference):\n\n```bash\nsglang generate \\\n  --backend diffusers \\\n  --num-gpus 4 \\\n  --model-path Qwen/Qwen-Image \\\n  --cache-dit-config parallel.yaml \\\n  --prompt \"A futuristic cityscape at sunset\"\n```\n\n- 2D Parallelism\n\nYou can also define a 2D parallelism config file `parallel_2d.yaml` that contains:\n\n```yaml\nparallelism_config:\n  ulysses_size: auto\n  tp_size: 2\n  attention_backend: native\n```\nThen apply the 2D parallelism config from YAML. Here, `tp_size: 2` means using tensor parallelism with size 2.
 The `ulysses_size: auto` setting means that cache-dit automatically uses `world_size // tp_size` as the Ulysses size.\n\n- 3D Parallelism\n\nYou can also define a 3D parallelism config file `parallel_3d.yaml` that contains:\n\n```yaml\nparallelism_config:\n  ulysses_size: 2\n  ring_size: 2\n  tp_size: 2\n  attention_backend: native\n```\nThen apply the 3D parallelism config from YAML. Here, `ulysses_size: 2`, `ring_size: 2`, and `tp_size: 2` mean using Ulysses parallelism with size 2, ring parallelism with size 2, and tensor parallelism with size 2.\n\n- Ulysses Anything Attention\n\nTo enable Ulysses Anything Attention, define a parallelism config file `parallel_uaa.yaml` that contains:\n\n```yaml\nparallelism_config:\n  ulysses_size: auto\n  attention_backend: native\n  ulysses_anything: true\n```\n\n- Ulysses FP8 Communication\n\nFor devices without NVLink, you can enable Ulysses FP8 communication to further reduce communication overhead. Define a parallelism config file `parallel_fp8.yaml` that contains:\n\n```yaml\nparallelism_config:\n  ulysses_size: auto\n  attention_backend: native\n  ulysses_float8: true\n```\n\n- Async Ulysses CP\n\nYou can also enable async Ulysses CP to overlap communication and computation. Define a parallelism config file `parallel_async.yaml` that contains:\n\n```yaml\nparallelism_config:\n  ulysses_size: auto\n  attention_backend: native\n  ulysses_async: true # Currently only supported for FLUX.1, Qwen-Image, Ovis-Image, and Z-Image.\n```\nThen apply the config from YAML. Here, `ulysses_async: true` enables async Ulysses CP.\n\n- TE-P and VAE-P\n\nYou can also specify extra parallel modules in the YAML config.
 For example, define a parallelism config file `parallel_extra.yaml` that contains:\n\n```yaml\nparallelism_config:\n  ulysses_size: auto\n  attention_backend: native\n  extra_parallel_modules: [\"text_encoder\", \"vae\"]\n```\n\n### Hybrid Cache and Parallelism\n\nDefine a hybrid cache and parallelism config file `hybrid.yaml` that contains:\n\n```yaml\ncache_config:\n  max_warmup_steps: 8\n  warmup_interval: 2\n  max_cached_steps: -1\n  max_continuous_cached_steps: 2\n  Fn_compute_blocks: 1\n  Bn_compute_blocks: 0\n  residual_diff_threshold: 0.12\n  enable_taylorseer: true\n  taylorseer_order: 1\nparallelism_config:\n  ulysses_size: auto\n  attention_backend: native\n  extra_parallel_modules: [\"text_encoder\", \"vae\"]\n```\n\nThen apply the hybrid cache and parallelism config from YAML:\n\n```bash\nsglang generate \\\n  --backend diffusers \\\n  --num-gpus 4 \\\n  --model-path Qwen/Qwen-Image \\\n  --cache-dit-config hybrid.yaml \\\n  --prompt \"A beautiful sunset over the mountains\"\n```\n\n### Attention Backend\n\nIn some cases, you may want to specify only the attention backend without any other optimization configs. In this case, define a YAML file `attention.yaml` that contains only:\n\n```yaml\nattention_backend: \"flash\" # '_flash_3' for Hopper\n```\n\n### Quantization\n\nYou can also specify a quantization config in the YAML file (this requires `torchao>=0.16.0`). For example, define a YAML file `quantize.yaml` that contains:\n\n```yaml\nquantize_config: # quantization configuration for transformer modules\n  # float8 (DQ), float8_weight_only, float8_blockwise, int8 (DQ), int8_weight_only, etc.\n  quant_type: \"float8\"\n  # layers to exclude from quantization (transformer). Layers that contain any of the\n  # keywords in the exclude_layers list will be excluded from quantization.
This is useful\n  # for some sensitive layers that are not robust to quantization, e.g., embedding layers.\n  exclude_layers:\n    - \"embedder\"\n    - \"embed\"\n  verbose: false # whether to print verbose logs during quantization\n```\nThen, apply the quantization config from yaml. Please also enable torch.compile for better performance if you are using quantization. For example:\n\n```bash\nsglang generate \\\n  --backend diffusers \\\n  --model-path Qwen/Qwen-Image \\\n  --warmup \\\n  --cache-dit-config quantize.yaml \\\n  --enable-torch-compile \\\n  --dit-cpu-offload false \\\n  --text-encoder-cpu-offload false \\\n  --prompt \"A beautiful sunset over the mountains\"\n```\n\n### Combined Configs: Cache + Parallelism + Quantization\n\nYou can also combine all the above configs together in a single yaml file `combined.yaml` that contains:\n\n```yaml\ncache_config:\n  max_warmup_steps: 8\n  warmup_interval: 2\n  max_cached_steps: -1\n  max_continuous_cached_steps: 2\n  Fn_compute_blocks: 1\n  Bn_compute_blocks: 0\n  residual_diff_threshold: 0.12\n  enable_taylorseer: true\n  taylorseer_order: 1\nparallelism_config:\n  ulysses_size: auto\n  attention_backend: native\n  extra_parallel_modules: [\"text_encoder\", \"vae\"]\nquantize_config:\n  quant_type: \"float8\"\n  exclude_layers:\n    - \"embedder\"\n    - \"embed\"\n  verbose: false\n```\nThen, apply the combined cache, parallelism and quantization config from yaml. 
 Please also enable torch.compile for better performance if you are using quantization.\n\n## Advanced Configuration\n\n### DBCache Parameters\n\nDBCache controls block-level caching behavior:\n\n| Parameter | Env Variable              | Default | Description                              |\n|-----------|---------------------------|---------|------------------------------------------|\n| Fn        | `SGLANG_CACHE_DIT_FN`     | 1       | Number of first blocks to always compute |\n| Bn        | `SGLANG_CACHE_DIT_BN`     | 0       | Number of last blocks to always compute  |\n| W         | `SGLANG_CACHE_DIT_WARMUP` | 4       | Warmup steps before caching starts       |\n| R         | `SGLANG_CACHE_DIT_RDT`    | 0.24    | Residual difference threshold            |\n| MC        | `SGLANG_CACHE_DIT_MC`     | 3       | Maximum continuous cached steps          |\n\n### TaylorSeer Configuration\n\nTaylorSeer improves caching accuracy using Taylor expansion:\n\n| Parameter | Env Variable                  | Default | Description                     |\n|-----------|-------------------------------|---------|---------------------------------|\n| Enable    | `SGLANG_CACHE_DIT_TAYLORSEER` | false   | Enable TaylorSeer calibrator    |\n| Order     | `SGLANG_CACHE_DIT_TS_ORDER`   | 1       | Taylor expansion order (1 or 2) |\n\n### Combined Configuration Example\n\nDBCache and TaylorSeer are complementary strategies that work together; you can configure both sets of parameters\nsimultaneously:\n\n```bash\nSGLANG_CACHE_DIT_ENABLED=true \\\nSGLANG_CACHE_DIT_FN=2 \\\nSGLANG_CACHE_DIT_BN=1 \\\nSGLANG_CACHE_DIT_WARMUP=4 \\\nSGLANG_CACHE_DIT_RDT=0.4 \\\nSGLANG_CACHE_DIT_MC=4 \\\nSGLANG_CACHE_DIT_TAYLORSEER=true \\\nSGLANG_CACHE_DIT_TS_ORDER=2 \\\nsglang generate --model-path black-forest-labs/FLUX.1-dev \\\n    --prompt \"A curious raccoon in a forest\"\n```\n\n### SCM (Step Computation Masking)\n\nSCM provides step-level caching control for additional speedup.
It decides which denoising steps to compute fully and\nwhich to use cached results.\n\n**SCM Presets**\n\nSCM is configured with presets:\n\n| Preset   | Compute Ratio | Speed    | Quality    |\n|----------|---------------|----------|------------|\n| `none`   | 100%          | Baseline | Best       |\n| `slow`   | ~75%          | ~1.3x    | High       |\n| `medium` | ~50%          | ~2x      | Good       |\n| `fast`   | ~35%          | ~3x      | Acceptable |\n| `ultra`  | ~25%          | ~4x      | Lower      |\n\n**Usage**\n\n```bash\nSGLANG_CACHE_DIT_ENABLED=true \\\nSGLANG_CACHE_DIT_SCM_PRESET=medium \\\nsglang generate --model-path Qwen/Qwen-Image \\\n    --prompt \"A futuristic cityscape at sunset\"\n```\n\n**Custom SCM Bins**\n\nFor fine-grained control over which steps to compute vs cache:\n\n```bash\nSGLANG_CACHE_DIT_ENABLED=true \\\nSGLANG_CACHE_DIT_SCM_COMPUTE_BINS=\"8,3,3,2,2\" \\\nSGLANG_CACHE_DIT_SCM_CACHE_BINS=\"1,2,2,2,3\" \\\nsglang generate --model-path Qwen/Qwen-Image \\\n    --prompt \"A futuristic cityscape at sunset\"\n```\n\n**SCM Policy**\n\n| Policy    | Env Variable                          | Description                                 |\n|-----------|---------------------------------------|---------------------------------------------|\n| `dynamic` | `SGLANG_CACHE_DIT_SCM_POLICY=dynamic` | Adaptive caching based on content (default) |\n| `static`  | `SGLANG_CACHE_DIT_SCM_POLICY=static`  | Fixed caching pattern                       |\n\n## Environment Variables\n\nAll Cache-DiT parameters can be configured via environment variables.\nSee [Environment Variables](../../environment_variables.md) for the complete list.\n\n## Supported Models\n\nSGLang Diffusion x Cache-DiT supports almost all models originally supported in SGLang Diffusion:\n\n| Model Family | Example Models              |\n|--------------|-----------------------------|\n| Wan          | Wan2.1, Wan2.2              |\n| Flux         | FLUX.1-dev, FLUX.2-dev      |\n| Z-Image  
    | Z-Image-Turbo               |\n| Qwen         | Qwen-Image, Qwen-Image-Edit |\n| Hunyuan      | HunyuanVideo                |\n\n## Performance Tips\n\n1. **Start with defaults**: The default parameters work well for most models\n2. **Use TaylorSeer**: It typically improves both speed and quality\n3. **Tune R threshold**: Lower values = better quality, higher values = faster\n4. **SCM for extra speed**: Use `medium` preset for good speed/quality balance\n5. **Warmup matters**: Higher warmup = more stable caching decisions\n\n## Limitations\n\n- **SGLang-native pipelines**: Distributed support (TP/SP) is not yet validated; Cache-DiT will be automatically\n  disabled when `world_size > 1`.\n- **SCM minimum steps**: SCM requires >= 8 inference steps to be effective\n- **Model support**: Only models registered in Cache-DiT's BlockAdapterRegister are supported\n\n## Troubleshooting\n\n### SCM disabled for low step count\n\nFor models with < 8 inference steps (e.g., DMD distilled models), SCM will be automatically disabled. DBCache\nacceleration still works.\n\n## References\n\n- [Cache-DiT](https://github.com/vipshop/cache-dit)\n- [SGLang Diffusion](../index.md)\n"
  },
  {
    "path": "docs/diffusion/performance/cache/index.md",
    "content": "# Caching Acceleration for Diffusion Models\n\nSGLang provides multiple caching acceleration strategies for Diffusion Transformer (DiT) models. These strategies can significantly reduce inference time by skipping redundant computation.\n\n## Overview\n\nSGLang supports two complementary caching approaches:\n\n| Strategy | Scope | Mechanism | Best For |\n|----------|-------|-----------|----------|\n| **Cache-DiT** | Block-level | Skip individual transformer blocks dynamically | Advanced, higher speedup |\n| **TeaCache** | Timestep-level | Skip entire denoising steps based on L1 similarity | Simple, built-in |\n\n\n\n## Cache-DiT\n\n[Cache-DiT](https://github.com/vipshop/cache-dit) provides block-level caching with\nadvanced strategies like DBCache and TaylorSeer. It can achieve up to **1.69x speedup**.\n\nSee [cache_dit.md](cache_dit.md) for detailed configuration.\n\n### Quick Start\n\n```bash\nSGLANG_CACHE_DIT_ENABLED=true \\\nsglang generate --model-path Qwen/Qwen-Image \\\n    --prompt \"A beautiful sunset over the mountains\"\n```\n\n### Key Features\n\n- **DBCache**: Dynamic block-level caching based on residual differences\n- **TaylorSeer**: Taylor expansion-based calibration for optimized caching\n- **SCM**: Step-level computation masking for additional speedup\n\n## TeaCache\n\nTeaCache (Temporal similarity-based caching) accelerates diffusion inference by detecting when consecutive denoising steps are similar enough to skip computation entirely.\n\nSee [teacache.md](teacache.md) for detailed documentation.\n\n### Quick Overview\n\n- Tracks L1 distance between modulated inputs across timesteps\n- When accumulated distance is below threshold, reuses cached residual\n- Supports CFG with separate positive/negative caches\n\n### Supported Models\n\n- Wan (wan2.1, wan2.2)\n- Hunyuan (HunyuanVideo)\n- Z-Image\n\nFor Flux and Qwen models, TeaCache is automatically disabled when CFG is enabled.\n\n## References\n\n- [Cache-DiT 
Repository](https://github.com/vipshop/cache-dit)\n- [TeaCache Paper](https://arxiv.org/abs/2411.14324)\n"
  },
  {
    "path": "docs/diffusion/performance/cache/teacache.md",
    "content": "# TeaCache Acceleration\n\n> **Note**: This is one of two caching strategies available in SGLang.\n> For an overview of all caching options, see [caching](../index.md).\n\nTeaCache (Temporal similarity-based caching) accelerates diffusion inference by detecting when consecutive denoising steps are similar enough to skip computation entirely.\n\n## Overview\n\nTeaCache works by:\n1. Tracking the L1 distance between modulated inputs across consecutive timesteps\n2. Accumulating the rescaled L1 distance over steps\n3. When accumulated distance is below a threshold, reusing the cached residual\n4. Supporting CFG (Classifier-Free Guidance) with separate positive/negative caches\n\n## How It Works\n\n### L1 Distance Tracking\n\nAt each denoising step, TeaCache computes the relative L1 distance between the current and previous modulated inputs:\n\n```\nrel_l1 = |current - previous|.mean() / |previous|.mean()\n```\n\nThis distance is then rescaled using polynomial coefficients and accumulated:\n\n```\naccumulated += poly(coefficients)(rel_l1)\n```\n\n### Cache Decision\n\n- If `accumulated >= threshold`: Force computation, reset accumulator\n- If `accumulated < threshold`: Skip computation, use cached residual\n\n### CFG Support\n\nFor models that support CFG cache separation (Wan, Hunyuan, Z-Image), TeaCache maintains separate caches for positive and negative branches:\n- `previous_modulated_input` / `previous_residual` for positive branch\n- `previous_modulated_input_negative` / `previous_residual_negative` for negative branch\n\nFor models that don't support CFG separation (Flux, Qwen), TeaCache is automatically disabled when CFG is enabled.\n\n## Configuration\n\nTeaCache is configured via `TeaCacheParams` in the sampling parameters:\n\n```python\nfrom sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams\n\nparams = TeaCacheParams(\n    teacache_thresh=0.1,           # Threshold for accumulated L1 distance\n    coefficients=[1.0, 0.0, 
0.0],  # Polynomial coefficients for L1 rescaling\n)\n```\n\n### Parameters\n\n| Parameter | Type | Description |\n|-----------|------|-------------|\n| `teacache_thresh` | float | Threshold for the accumulated L1 distance. Higher = more steps skipped (faster, but potentially lower quality); lower = more steps computed (slower, higher quality) |\n| `coefficients` | list[float] | Polynomial coefficients for L1 rescaling; tuned per model |\n\n### Model-Specific Configurations\n\nDifferent models may have different optimal configurations. The coefficients are typically tuned per model to balance speed and quality.\n\n## Supported Models\n\nTeaCache is built into the following model families:\n\n| Model Family | CFG Cache Separation | Notes |\n|--------------|---------------------|-------|\n| Wan (wan2.1, wan2.2) | Yes | Full support |\n| Hunyuan (HunyuanVideo) | Yes | To be supported |\n| Z-Image | Yes | To be supported |\n| Flux | No | To be supported |\n| Qwen | No | To be supported |\n\n## References\n\n- [TeaCache: Accelerating Diffusion Models with Temporal Similarity](https://arxiv.org/abs/2411.14324)\n"
  },
  {
    "path": "docs/diffusion/performance/index.md",
    "content": "# Performance Optimization\n\nSGLang-Diffusion provides multiple performance optimization strategies to accelerate inference. This section covers all available performance tuning options.\n\n## Overview\n\n| Optimization | Type | Description |\n|--------------|------|-------------|\n| **Cache-DiT** | Caching | Block-level caching with DBCache, TaylorSeer, and SCM |\n| **TeaCache** | Caching | Timestep-level caching using L1 similarity |\n| **Attention Backends** | Kernel | Optimized attention implementations (FlashAttention, SageAttention, etc.) |\n| **Profiling** | Diagnostics | PyTorch Profiler and Nsight Systems guidance |\n\n## Caching Strategies\n\nSGLang supports two complementary caching approaches:\n\n### Cache-DiT\n\n[Cache-DiT](https://github.com/vipshop/cache-dit) provides block-level caching with advanced strategies. It can achieve up to **1.69x speedup**.\n\n**Quick Start:**\n```bash\nSGLANG_CACHE_DIT_ENABLED=true \\\nsglang generate --model-path Qwen/Qwen-Image \\\n    --prompt \"A beautiful sunset over the mountains\"\n```\n\n**Key Features:**\n- **DBCache**: Dynamic block-level caching based on residual differences\n- **TaylorSeer**: Taylor expansion-based calibration for optimized caching\n- **SCM**: Step-level computation masking for additional speedup\n\nSee [Cache-DiT Documentation](cache/cache_dit.md) for detailed configuration.\n\n### TeaCache\n\nTeaCache (Temporal similarity-based caching) accelerates diffusion inference by detecting when consecutive denoising steps are similar enough to skip computation entirely.\n\n**Quick Overview:**\n- Tracks L1 distance between modulated inputs across timesteps\n- When accumulated distance is below threshold, reuses cached residual\n- Supports CFG with separate positive/negative caches\n\n**Supported Models:** Wan (wan2.1, wan2.2), Hunyuan (HunyuanVideo), Z-Image\n\nSee [TeaCache Documentation](cache/teacache.md) for detailed configuration.\n\n## Attention Backends\n\nDifferent attention 
backends offer varying performance characteristics depending on your hardware and model:\n\n- **FlashAttention**: Typically the fastest option on NVIDIA GPUs with fp16/bf16\n- **SageAttention**: Alternative optimized implementation\n- **xformers**: Memory-efficient attention\n- **SDPA**: PyTorch-native scaled dot-product attention\n\nSee [Attention Backends](attention_backends.md) for platform support and configuration options.\n\n## Profiling\n\nTo diagnose performance bottlenecks, SGLang-Diffusion supports the following profiling tools:\n\n- **PyTorch Profiler**: Built-in PyTorch operator and kernel profiling\n- **Nsight Systems**: System-wide GPU timeline and kernel-level analysis\n\nSee [Profiling Guide](profiling.md) for detailed instructions.\n\n## References\n\n- [Cache-DiT Repository](https://github.com/vipshop/cache-dit)\n- [TeaCache Paper](https://arxiv.org/abs/2411.14324)\n"
  },
  {
    "path": "docs/diffusion/performance/profiling.md",
    "content": "# Profiling Multimodal Generation\n\nThis guide covers profiling techniques for multimodal generation pipelines in SGLang.\n\n## PyTorch Profiler\n\nPyTorch Profiler reports per-kernel execution times, call stacks, and GPU utilization.\n\n### Denoising Stage Profiling\n\nProfile the denoising stage with sampled timesteps (default: 5 steps after 1 warmup step):\n\n```bash\nsglang generate \\\n  --model-path Qwen/Qwen-Image \\\n  --prompt \"A Logo With Bold Large Text: SGL Diffusion\" \\\n  --seed 0 \\\n  --profile\n```\n\n**Parameters:**\n- `--profile`: Enable profiling for the denoising stage\n- `--num-profiled-timesteps N`: Number of timesteps to profile after warmup (default: 5)\n  - Smaller values reduce trace file size\n  - Example: `--num-profiled-timesteps 10` profiles 10 steps after 1 warmup step\n\n### Full Pipeline Profiling\n\nProfile all pipeline stages (text encoding, denoising, VAE decoding, etc.):\n\n```bash\nsglang generate \\\n  --model-path Qwen/Qwen-Image \\\n  --prompt \"A Logo With Bold Large Text: SGL Diffusion\" \\\n  --seed 0 \\\n  --profile \\\n  --profile-all-stages\n```\n\n**Parameters:**\n- `--profile-all-stages`: Used together with `--profile`; profiles all pipeline stages instead of only denoising\n\n### Output Location\n\nBy default, trace files are saved in the `./logs/` directory.\n\nThe exact output path is printed to the console, for example:\n\n```text\n[mm-dd hh:mm:ss] Saved profiler traces to: /sgl-workspace/sglang/logs/mocked_fake_id_for_offline_generate-5_steps-global-rank0.trace.json.gz\n```\n\n### View Traces\n\nLoad and visualize trace files at:\n- https://ui.perfetto.dev/ (recommended)\n- chrome://tracing (Chrome only)\n\nFor large trace files, reduce `--num-profiled-timesteps` or avoid using `--profile-all-stages`.\n\n### `--perf-dump-path` (Stage/Step Timing Dump)\n\nBesides profiler traces, you can also dump a lightweight JSON report that contains:\n- stage-level timing breakdown for the 
full pipeline\n- step-level timing breakdown for the denoising stage (per diffusion step)\n\nThis is useful to quickly identify which stage dominates end-to-end latency, and whether denoising steps have uniform runtimes (and if not, which step has an abnormal spike).\n\nThe dumped JSON contains a `denoise_steps_ms` field formatted as an array of objects, each with a `step` key (the step index) and a `duration_ms` key.\n\nExample:\n\n```bash\nsglang generate \\\n  --model-path <MODEL_PATH_OR_ID> \\\n  --prompt \"<PROMPT>\" \\\n  --perf-dump-path perf.json\n```\n\n## Nsight Systems\n\nNsight Systems provides low-level, system-wide CUDA profiling: a timeline of CUDA API calls, kernel executions, and memory transfers.\n\n### Installation\n\nSee the [SGLang profiling guide](https://github.com/sgl-project/sglang/blob/main/docs/developer_guide/benchmark_and_profiling.md#profile-with-nsight) for installation instructions.\n\n### Basic Profiling\n\nProfile the entire pipeline execution:\n\n```bash\nnsys profile \\\n  --trace-fork-before-exec=true \\\n  --cuda-graph-trace=node \\\n  --force-overwrite=true \\\n  -o QwenImage \\\n  sglang generate \\\n    --model-path Qwen/Qwen-Image \\\n    --prompt \"A Logo With Bold Large Text: SGL Diffusion\" \\\n    --seed 0\n```\n\n### Targeted Stage Profiling\n\nUse `--delay` and `--duration` to capture specific stages and reduce file size:\n\n```bash\nnsys profile \\\n  --trace-fork-before-exec=true \\\n  --cuda-graph-trace=node \\\n  --force-overwrite=true \\\n  --delay 10 \\\n  --duration 30 \\\n  -o QwenImage_denoising \\\n  sglang generate \\\n    --model-path Qwen/Qwen-Image \\\n    --prompt \"A Logo With Bold Large Text: SGL Diffusion\" \\\n    --seed 0\n```\n\n**Parameters:**\n- `--delay N`: Wait N seconds before starting capture (skip initialization overhead)\n- `--duration N`: Capture for N seconds (focus on specific stages)\n- `--force-overwrite`: Overwrite existing output files\n\n## Notes\n\n- **Reduce trace size**: Use 
`--num-profiled-timesteps` with a smaller value, or `--delay`/`--duration` with Nsight Systems\n- **Stage-specific analysis**: Use `--profile` alone for the denoising stage; add `--profile-all-stages` for the full pipeline\n- **Multiple runs**: Profile with different prompts and resolutions to identify bottlenecks across workloads\n\n## FAQ\n\n- If you profile `sglang generate` with Nsight Systems and the generated report contains no CUDA kernels, increase the number of inference steps so the run lasts long enough to be captured.\n"
  },
  {
    "path": "docs/diffusion/support_new_models.md",
    "content": "# How to Support New Diffusion Models\n\nThis document explains how to add support for new diffusion models in SGLang Diffusion.\n\n## Architecture Overview\n\nSGLang Diffusion is engineered for both performance and flexibility, built upon a pipeline architecture. This\ndesign allows developers to construct pipelines for various diffusion models while keeping the core generation\nloop standardized for optimization.\n\nAt its core, the architecture revolves around two key concepts, as highlighted in our [blog post](https://lmsys.org/blog/2025-11-07-sglang-diffusion/#architecture):\n\n-   **`ComposedPipeline`**: This class orchestrates a series of `PipelineStage`s to define the complete generation process for a specific model. It acts as the main entry point for a model and manages the data flow between the different stages of the diffusion process.\n-   **`PipelineStage`**: Each stage is a modular component that encapsulates a function within the diffusion process. Examples include prompt encoding, the denoising loop, or VAE decoding.\n\n### Two Pipeline Styles\n\nSGLang Diffusion supports two pipeline composition styles. Both are valid; choose the one that best fits your model.\n\n#### Style A: Hybrid Monolithic Pipeline (Recommended Default)\n\nThe recommended default for most new models. 
Uses a three-stage structure:\n\n```\nBeforeDenoisingStage (model-specific)  →  DenoisingStage (standard)  →  DecodingStage (standard)\n```\n\n| Stage | Ownership | Responsibility |\n|-------|-----------|----------------|\n| `{Model}BeforeDenoisingStage` | Model-specific | All pre-processing: input validation, text/image encoding, latent preparation, timestep computation |\n| `DenoisingStage` | Framework-standard | The denoising loop (DiT/UNet forward passes), shared across all models |\n| `DecodingStage` | Framework-standard | VAE decoding from latent space to pixel space, shared across all models |\n\n**Why recommended?** Modern diffusion models often have highly heterogeneous pre-processing requirements — different text encoders, different latent formats, different conditioning mechanisms. The Hybrid approach keeps pre-processing isolated per model, avoids fragile shared stages with excessive conditional logic, and lets developers port Diffusers reference code quickly.\n\n#### Style B: Modular Composition Style\n\nUses the framework's fine-grained standard stages (`TextEncodingStage`, `LatentPreparationStage`, `TimestepPreparationStage`, etc.) to build the pipeline by composition. Convenience methods like `add_standard_t2i_stages()` and `add_standard_ti2i_stages()` make this very concise.\n\nThis style is appropriate when:\n- **The new model's pre-processing can largely reuse existing stages** — e.g., a model that uses standard CLIP/T5 text encoding + standard latent preparation with minimal customization.\n- **A model-specific optimization needs to be extracted as a standalone stage** — e.g., a specialized encoding or conditioning step that benefits from being a separate stage for profiling, parallelism control, or reuse across multiple pipeline variants.\n\n#### How to Choose\n\n| Situation | Recommended Style |\n|-----------|-------------------|\n| Model has unique/complex pre-processing (VLM captioning, AR token generation, custom latent packing, etc.) 
| **Hybrid** — consolidate into a BeforeDenoisingStage |\n| Model fits neatly into standard text-to-image or text+image-to-image pattern | **Modular** — use `add_standard_t2i_stages()` / `add_standard_ti2i_stages()` |\n| Porting a Diffusers pipeline with many custom steps | **Hybrid** — copy the `__call__` logic into a single stage |\n| Adding a variant of an existing model that shares most logic | **Modular** — reuse existing stages, customize via PipelineConfig callbacks |\n| A specific pre-processing step needs special parallelism or profiling isolation | **Modular** — extract that step as a dedicated stage |\n\n## Key Components for Implementation\n\nTo add support for a new diffusion model, you will need to define or configure the following components:\n\n1.  **`PipelineConfig`**: A dataclass holding static configurations for your model pipeline — precision settings, model architecture parameters, and callback methods used by the standard `DenoisingStage` and `DecodingStage`. Each model has its own subclass.\n\n2.  **`SamplingParams`**: A dataclass defining runtime generation parameters — `prompt`, `negative_prompt`, `guidance_scale`, `num_inference_steps`, `seed`, `height`, `width`, etc.\n\n3.  **Pre-processing stage(s)**: Either a single model-specific `{Model}BeforeDenoisingStage` (Hybrid style) or a combination of standard stages (Modular style). See [Two Pipeline Styles](#two-pipeline-styles) above.\n\n4.  **`ComposedPipeline`**: A class that wires together your pre-processing stage(s) with the standard `DenoisingStage` and `DecodingStage`. 
See base definitions:\n    - [`ComposedPipelineBase`](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py)\n    - [`PipelineStage`](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py)\n    - [Central registry](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/registry.py)\n\n5.  **Modules (model components)**: Each pipeline references modules loaded from the model repository (e.g., Diffusers `model_index.json`):\n    - `text_encoder`: Encodes text prompts into embeddings.\n    - `tokenizer`: Tokenizes raw text input for the text encoder(s).\n    - `processor`: Preprocesses images and extracts features; often used in image-to-image tasks.\n    - `image_encoder`: Specialized image feature extractor.\n    - `dit/transformer`: The core denoising network (DiT/UNet architecture) operating in latent space.\n    - `scheduler`: Controls the timestep schedule and denoising dynamics.\n    - `vae`: Variational Autoencoder for encoding/decoding between pixel space and latent space.\n\n## Pipeline Stages Reference\n\n### Core Stages (used by all pipelines)\n\n| Stage Class                      | Description                                                                                             |\n| -------------------------------- | ------------------------------------------------------------------------------------------------------- |\n| `DenoisingStage`                 | Executes the main denoising loop, iteratively applying the model (DiT/UNet) to refine the latents.      |\n| `DecodingStage`                  | Decodes the final latent tensor back into pixel space using the VAE.                                    |\n| `DmdDenoisingStage`              | A specialized denoising stage for DMD model architectures.                                              
|\n| `CausalDMDDenoisingStage`        | A specialized causal denoising stage for specific video models.                                         |\n\n### Pre-processing Stages (for Modular Composition Style)\n\nThe following fine-grained stages can be composed to build the pre-processing portion of a pipeline. They are best suited for models whose pre-processing largely fits the standard patterns. If your model requires significant customization, consider the Hybrid style with a single `BeforeDenoisingStage` instead.\n\n| Stage Class                      | Description                                                                                             |\n| -------------------------------- | ------------------------------------------------------------------------------------------------------- |\n| `InputValidationStage`           | Validates user-provided `SamplingParams`.                                                               |\n| `TextEncodingStage`              | Encodes text prompts into embeddings using one or more text encoders.                                   |\n| `ImageEncodingStage`             | Encodes input images into embeddings, often used in image-to-image tasks.                               |\n| `ImageVAEEncodingStage`          | Encodes an input image into latent space using the VAE.                                                 |\n| `TimestepPreparationStage`       | Prepares the scheduler's timesteps for the diffusion process.                                           |\n| `LatentPreparationStage`         | Creates the initial noisy latent tensor that will be denoised.                                          
|\n\n## Implementation Guide\n\n### Step 1: Obtain and Study the Reference Implementation\n\nBefore writing any code, obtain the model's original implementation or Diffusers pipeline code:\n- The model's Diffusers pipeline source (e.g., the `pipeline_*.py` file from the `diffusers` library or HuggingFace repo)\n- Or the model's official reference implementation (e.g., from the model author's GitHub repo)\n- Or the HuggingFace model ID to look up `model_index.json` and the associated pipeline class\n\nOnce you have the reference code, study it thoroughly:\n\n1. Find the model's `model_index.json` to identify required modules.\n2. Read the Diffusers pipeline's `__call__` method to understand:\n   - How text prompts are encoded\n   - How latents are prepared (shape, dtype, scaling)\n   - How timesteps/sigmas are computed\n   - What conditioning kwargs the DiT expects\n   - How the denoising loop works\n   - How VAE decoding is done\n\n### Step 2: Evaluate Reuse of Existing Pipelines and Stages\n\nBefore creating any new files, check whether an existing pipeline or stage can be reused or extended. Only create new pipelines/stages when the existing ones would need substantial structural changes or when no architecturally similar implementation exists.\n\n- **Compare against existing pipelines** (Flux, Wan, Qwen-Image, GLM-Image, HunyuanVideo, LTX, etc.). If the new model shares most of its structure with an existing one, prefer adding a new config variant or reusing existing stages.\n- **Check existing stages** in `runtime/pipelines_core/stages/` and `stages/model_specific_stages/`.\n- **Check existing model components** — many models share VAEs (e.g., `AutoencoderKL`), text encoders (CLIP, T5), and schedulers. 
Reuse these directly.\n\n### Step 3: Implement Model Components\n\nAdapt the model's core components:\n\n- **DiT/Transformer**: Implement in [`runtime/models/dits/`](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/runtime/models/dits/)\n- **Encoders**: Implement in [`runtime/models/encoders/`](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/runtime/models/encoders/)\n- **VAEs**: Implement in [`runtime/models/vaes/`](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/runtime/models/vaes/)\n- **Schedulers**: Implement in [`runtime/models/schedulers/`](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/runtime/models/schedulers/) if needed\n\nUse SGLang's fused kernels where possible (see `LayerNormScaleShift`, `RMSNormScaleShift`, `apply_qk_norm`, etc.).\n\n**Tensor Parallel (TP) and Sequence Parallel (SP)**: For multi-GPU deployment, it is recommended to add TP/SP support to the DiT model. This can be done incrementally after the single-GPU implementation is verified. 
Reference implementations:\n- **Wan model** (`runtime/models/dits/wanvideo.py`) — Full TP + SP: `ColumnParallelLinear`/`RowParallelLinear` for attention, sequence dimension sharding via `get_sp_world_size()`\n- **Qwen-Image model** (`runtime/models/dits/qwen_image.py`) — SP via `USPAttention` (Ulysses + Ring Attention)\n\n### Step 4: Create Configs\n\n- **DiT Config**: `configs/models/dits/{model_name}.py`\n- **VAE Config**: `configs/models/vaes/{model_name}.py`\n- **SamplingParams**: `configs/sample/{model_name}.py`\n\n### Step 5: Create PipelineConfig\n\nThe `PipelineConfig` provides callbacks that the standard `DenoisingStage` and `DecodingStage` use:\n\n```python\n# python/sglang/multimodal_gen/configs/pipeline_configs/my_model.py\n\n@dataclass\nclass MyModelPipelineConfig(ImagePipelineConfig):\n    task_type: ModelTaskType = ModelTaskType.T2I\n    vae_precision: str = \"bf16\"\n    should_use_guidance: bool = True\n    dit_config: DiTConfig = field(default_factory=MyModelDitConfig)\n    vae_config: VAEConfig = field(default_factory=MyModelVAEConfig)\n\n    def get_freqs_cis(self, batch, device, rotary_emb, dtype):\n        \"\"\"Prepare rotary position embeddings for the DiT.\"\"\"\n        ...\n\n    def prepare_pos_cond_kwargs(self, batch, latent_model_input, t, **kwargs):\n        \"\"\"Build positive conditioning kwargs for each denoising step.\"\"\"\n        return {\n            \"hidden_states\": latent_model_input,\n            \"encoder_hidden_states\": batch.prompt_embeds[0],\n            \"timestep\": t,\n        }\n\n    def prepare_neg_cond_kwargs(self, batch, latent_model_input, t, **kwargs):\n        \"\"\"Build negative conditioning kwargs for CFG.\"\"\"\n        return {\n            \"hidden_states\": latent_model_input,\n            \"encoder_hidden_states\": batch.negative_prompt_embeds[0],\n            \"timestep\": t,\n        }\n\n    def get_decode_scale_and_shift(self):\n        \"\"\"Return (scale, shift) for latent denormalization 
before VAE decode.\"\"\"\n        ...\n```\n\n### Step 6: Implement Pre-processing\n\nChoose based on your model's needs (see [How to Choose](#how-to-choose)):\n\n#### Option A: BeforeDenoisingStage (Hybrid Style)\n\nCreate a single stage that handles all pre-processing. Best when the model has custom/complex pre-processing logic.\n\n```python\n# python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/my_model.py\n\nclass MyModelBeforeDenoisingStage(PipelineStage):\n    \"\"\"Monolithic pre-processing stage for MyModel.\n\n    Consolidates: input validation, text/image encoding, latent\n    preparation, and timestep computation.\n    \"\"\"\n\n    def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):\n        super().__init__()\n        self.vae = vae\n        self.text_encoder = text_encoder\n        self.tokenizer = tokenizer\n        self.transformer = transformer\n        self.scheduler = scheduler\n\n    @torch.no_grad()\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        device = get_local_torch_device()\n\n        # 1. Encode prompt (model-specific logic)\n        prompt_embeds, negative_prompt_embeds = self._encode_prompt(...)\n\n        # 2. Prepare latents from a seeded RNG generator for reproducibility\n        #    (assumes the request carries a `seed` field)\n        generator = torch.Generator(device=device).manual_seed(batch.seed)\n        latents = self._prepare_latents(..., generator=generator)\n\n        # 3. Prepare timesteps\n        timesteps, sigmas = self._prepare_timesteps(...)\n\n        # 4. Populate batch for DenoisingStage\n        batch.prompt_embeds = [prompt_embeds]\n        batch.negative_prompt_embeds = [negative_prompt_embeds]\n        batch.latents = latents\n        batch.timesteps = timesteps\n        batch.num_inference_steps = len(timesteps)\n        batch.sigmas = sigmas.tolist()  # DenoisingStage expects a Python list\n        batch.generator = generator\n        batch.raw_latent_shape = latents.shape\n        return batch\n```\n\n#### Option B: Standard Stages (Modular Style)\n\nSkip creating a custom stage entirely: configure via `PipelineConfig` callbacks and use framework helpers. 
Best when the model fits standard patterns.\n\n(This option has no separate stage file; the pipeline class in Step 7 calls `add_standard_t2i_stages()` directly.)\n\n**Key batch fields that `DenoisingStage` expects** (regardless of which option you choose):\n\n| Field | Type | Description |\n|-------|------|-------------|\n| `batch.latents` | `torch.Tensor` | Initial noisy latent tensor |\n| `batch.timesteps` | `torch.Tensor` | Timestep schedule |\n| `batch.num_inference_steps` | `int` | Number of denoising steps |\n| `batch.sigmas` | `list[float]` | Sigma schedule (must be a Python list, not numpy) |\n| `batch.prompt_embeds` | `list[torch.Tensor]` | Positive prompt embeddings (wrapped in a list) |\n| `batch.negative_prompt_embeds` | `list[torch.Tensor]` | Negative prompt embeddings (wrapped in a list) |\n| `batch.generator` | `torch.Generator` | RNG generator for reproducibility |\n| `batch.raw_latent_shape` | `tuple` | Original latent shape before any packing |\n\n### Step 7: Define the Pipeline Class\n\n#### Hybrid Style\n\n```python\n# python/sglang/multimodal_gen/runtime/pipelines/my_model.py\n\nclass MyModelPipeline(LoRAPipeline, ComposedPipelineBase):\n    pipeline_name = \"MyModelPipeline\"  # Must match model_index.json _class_name\n\n    _required_config_modules = [\n        \"text_encoder\", \"tokenizer\", \"vae\", \"transformer\", \"scheduler\",\n    ]\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        # 1. Monolithic pre-processing (model-specific)\n        self.add_stage(\n            MyModelBeforeDenoisingStage(\n                vae=self.get_module(\"vae\"),\n                text_encoder=self.get_module(\"text_encoder\"),\n                tokenizer=self.get_module(\"tokenizer\"),\n                transformer=self.get_module(\"transformer\"),\n                scheduler=self.get_module(\"scheduler\"),\n            ),\n        )\n\n        # 2. 
Standard denoising loop (framework-provided)\n        self.add_stage(\n            DenoisingStage(\n                transformer=self.get_module(\"transformer\"),\n                scheduler=self.get_module(\"scheduler\"),\n            ),\n        )\n\n        # 3. Standard VAE decoding (framework-provided)\n        self.add_standard_decoding_stage()\n\n\nEntryClass = [MyModelPipeline]\n```\n\n#### Modular Style\n\n```python\n# python/sglang/multimodal_gen/runtime/pipelines/my_model.py\n\nclass MyModelPipeline(LoRAPipeline, ComposedPipelineBase):\n    pipeline_name = \"MyModelPipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\", \"tokenizer\", \"vae\", \"transformer\", \"scheduler\",\n    ]\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        # All pre-processing + denoising + decoding in one call\n        self.add_standard_t2i_stages(\n            prepare_extra_timestep_kwargs=[prepare_mu],  # model-specific hooks\n        )\n\n\nEntryClass = [MyModelPipeline]\n```\n\n### Step 8: Register the Model\n\nRegister your configs in [`registry.py`](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/registry.py):\n\n```python\nregister_configs(\n    model_family=\"my_model\",\n    sampling_param_cls=MyModelSamplingParams,\n    pipeline_config_cls=MyModelPipelineConfig,\n    hf_model_paths=[\"org/my-model-name\"],\n)\n```\n\nThe `EntryClass` in your pipeline file is automatically discovered by the registry — no additional registration needed for the pipeline class itself.\n\n### Step 9: Verify Output Quality\n\nAfter implementation, verify that the generated output is not noise. A noisy or garbled output is the most common sign of an incorrect implementation. 
Common causes include:\n\n- Incorrect latent scale/shift factors\n- Wrong timestep/sigma schedule (order, dtype, or value range)\n- Mismatched conditioning kwargs\n- Rotary embedding style mismatch (`is_neox_style`)\n\nDebug by comparing intermediate tensor values against the Diffusers reference pipeline with the same seed.\n\n## Reference Implementations\n\n### Hybrid Style\n\n| Model | Pipeline | BeforeDenoisingStage | PipelineConfig |\n|-------|----------|---------------------|----------------|\n| GLM-Image | `runtime/pipelines/glm_image.py` | `stages/model_specific_stages/glm_image.py` | `configs/pipeline_configs/glm_image.py` |\n| Qwen-Image-Layered | `runtime/pipelines/qwen_image.py` | `stages/model_specific_stages/qwen_image_layered.py` | `configs/pipeline_configs/qwen_image.py` |\n\n### Modular Style\n\n| Model | Pipeline | Notes |\n|-------|----------|-------|\n| Qwen-Image (T2I) | `runtime/pipelines/qwen_image.py` | Uses `add_standard_t2i_stages()` |\n| Qwen-Image-Edit | `runtime/pipelines/qwen_image.py` | Uses `add_standard_ti2i_stages()` |\n| Flux | `runtime/pipelines/flux.py` | Uses `add_standard_t2i_stages()` with custom `prepare_mu` |\n| Wan | `runtime/pipelines/wan_pipeline.py` | Uses `add_standard_ti2v_stages()` |\n\n## Checklist\n\nBefore submitting your implementation, verify:\n\n**Common (both styles):**\n- [ ] **Pipeline file** at `runtime/pipelines/{model_name}.py` with `EntryClass`\n- [ ] **PipelineConfig** at `configs/pipeline_configs/{model_name}.py`\n- [ ] **SamplingParams** at `configs/sample/{model_name}.py`\n- [ ] **DiT model** at `runtime/models/dits/{model_name}.py`\n- [ ] **Model configs** (DiT, VAE) at `configs/models/dits/` and `configs/models/vaes/`\n- [ ] **Registry entry** in `registry.py` via `register_configs()`\n- [ ] `pipeline_name` matches Diffusers `model_index.json` `_class_name`\n- [ ] `_required_config_modules` lists all modules from `model_index.json`\n- [ ] `PipelineConfig` callbacks (`prepare_pos_cond_kwargs`, etc.) 
match the DiT's `forward()` signature\n- [ ] Uses framework-standard `DenoisingStage` and `DecodingStage` (not custom denoising loops)\n- [ ] **TP/SP support** considered for DiT model (recommended; reference `wanvideo.py` for TP+SP, `qwen_image.py` for USPAttention)\n- [ ] **Output quality verified** — generated images/videos are not noise; compared against Diffusers reference output\n\n**Hybrid style only:**\n- [ ] **BeforeDenoisingStage** at `stages/model_specific_stages/{model_name}.py`\n- [ ] `BeforeDenoisingStage.forward()` populates all batch fields required by `DenoisingStage`\n"
  },
  {
    "path": "docs/get_started/install.md",
    "content": "# Install SGLang\n\nYou can install SGLang using one of the methods below.\nThis page primarily applies to common NVIDIA GPU platforms.\nFor other or newer platforms, please refer to the dedicated pages for [AMD GPUs](../platforms/amd_gpu.md), [Intel Xeon CPUs](../platforms/cpu_server.md), [TPU](../platforms/tpu.md), [NVIDIA DGX Spark](https://lmsys.org/blog/2025-11-03-gpt-oss-on-nvidia-dgx-spark/), [NVIDIA Jetson](../platforms/nvidia_jetson.md), [Ascend NPUs](../platforms/ascend_npu.md), and [Intel XPU](../platforms/xpu.md).\n\n## Method 1: With pip or uv\n\nIt is recommended to use uv for faster installation:\n\n```bash\npip install --upgrade pip\npip install uv\nuv pip install sglang\n```\n\n### For CUDA 13\n\nDocker is recommended (see the Method 3 note on B300/GB300/CUDA 13). If you do not have Docker access, follow these steps:\n\n1. Install PyTorch with CUDA 13 support first:\n```bash\n# Replace X.Y.Z with the torch version required by your SGLang install\nuv pip install torch==X.Y.Z torchvision torchaudio --index-url https://download.pytorch.org/whl/cu130\n```\n\n2. Install sglang:\n```bash\nuv pip install sglang\n```\n\n3. Install the `sglang-kernel` wheel for CUDA 13 from [the sgl-project whl releases](https://github.com/sgl-project/whl/blob/gh-pages/cu130/sglang-kernel/index.html). Replace `X.Y.Z` with the `sglang-kernel` version required by your SGLang install (you can find it by running `uv pip show sglang-kernel`). Examples:\n```bash\n# x86_64\nuv pip install \"https://github.com/sgl-project/whl/releases/download/vX.Y.Z/sglang_kernel-X.Y.Z+cu130-cp310-abi3-manylinux2014_x86_64.whl\"\n\n# aarch64\nuv pip install \"https://github.com/sgl-project/whl/releases/download/vX.Y.Z/sglang_kernel-X.Y.Z+cu130-cp310-abi3-manylinux2014_aarch64.whl\"\n```\n\n### **Quick fixes to common problems**\n- If you encounter `OSError: CUDA_HOME environment variable is not set`, set it to your CUDA install root using either of the following solutions:\n  1.
Use `export CUDA_HOME=/usr/local/cuda-<your-cuda-version>` to set the `CUDA_HOME` environment variable.\n  2. Install FlashInfer first following [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html), then install SGLang as described above.\n\n## Method 2: From source\n\n```bash\n# Use the last release branch\ngit clone -b v0.5.9 https://github.com/sgl-project/sglang.git\ncd sglang\n\n# Install the python packages\npip install --upgrade pip\npip install -e \"python\"\n```\n\n**Quick fixes to common problems**\n\n- If you want to develop SGLang, you can try the dev docker image. Please refer to [setup docker container](../developer_guide/development_guide_using_docker.md#setup-docker-container). The docker image is `lmsysorg/sglang:dev`.\n\n## Method 3: Using docker\n\nThe docker images are available on Docker Hub at [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).\nReplace `<secret>` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens).\n\n```bash\ndocker run --gpus all \\\n    --shm-size 32g \\\n    -p 30000:30000 \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HF_TOKEN=<secret>\" \\\n    --ipc=host \\\n    lmsysorg/sglang:latest \\\n    python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000\n```\n\nFor production deployments, use the `runtime` variant which is significantly smaller (~40% reduction) by excluding build tools and development dependencies:\n\n```bash\ndocker run --gpus all \\\n    --shm-size 32g \\\n    -p 30000:30000 \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HF_TOKEN=<secret>\" \\\n    --ipc=host \\\n    lmsysorg/sglang:latest-runtime \\\n    python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000\n```\n\nYou can also find the nightly 
docker images [here](https://hub.docker.com/r/lmsysorg/sglang/tags?name=nightly).\n\nNotes:\n- On B300/GB300 (SM103) or in a CUDA 13 environment, we recommend the nightly image `lmsysorg/sglang:dev-cu13` or the stable image `lmsysorg/sglang:latest-cu130-runtime`. Do not reinstall the project in editable mode inside the Docker image, since doing so overrides the library versions pinned by the cu13 image.\n\n## Method 4: Using Kubernetes\n\nPlease check out [OME](https://github.com/sgl-project/ome), a Kubernetes operator for enterprise-grade management and serving of large language models (LLMs).\n\n<details>\n<summary>More</summary>\n\n1. Option 1: For single-node serving (typically when the model fits into the GPUs on one node)\n\n   Run `kubectl apply -f docker/k8s-sglang-service.yaml` to create a k8s deployment and service, using llama-31-8b as an example.\n\n2. Option 2: For multi-node serving (usually when a large model such as `DeepSeek-R1` requires more than one GPU node)\n\n   Modify the LLM model path and arguments as necessary, then run `kubectl apply -f docker/k8s-sglang-distributed-sts.yaml` to create a two-node k8s StatefulSet and a serving service.\n\n</details>\n\n## Method 5: Using docker compose\n\n<details>\n<summary>More</summary>\n\n> This method is recommended if you plan to run SGLang as a service.\n> A better approach is to use the [k8s-sglang-service.yaml](https://github.com/sgl-project/sglang/blob/main/docker/k8s-sglang-service.yaml).\n\n1. Copy [compose.yaml](https://github.com/sgl-project/sglang/blob/main/docker/compose.yaml) to your local machine.\n2. Execute the command `docker compose up -d` in your terminal.\n</details>\n\n## Method 6: Run on Kubernetes or Clouds with SkyPilot\n\n<details>\n<summary>More</summary>\n\nTo deploy on Kubernetes or 12+ clouds, you can use [SkyPilot](https://github.com/skypilot-org/skypilot).\n\n1.
Install SkyPilot and set up access to a Kubernetes cluster or cloud: see [SkyPilot's documentation](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html).\n2. Deploy on your own infrastructure with a single command and get the HTTP API endpoint:\n<details>\n<summary>SkyPilot YAML: <code>sglang.yaml</code></summary>\n\n```yaml\n# sglang.yaml\nenvs:\n  HF_TOKEN: null\n\nresources:\n  image_id: docker:lmsysorg/sglang:latest\n  accelerators: A100\n  ports: 30000\n\nrun: |\n  conda deactivate\n  python3 -m sglang.launch_server \\\n    --model-path meta-llama/Llama-3.1-8B-Instruct \\\n    --host 0.0.0.0 \\\n    --port 30000\n```\n\n</details>\n\n```bash\n# Deploy on any cloud or Kubernetes cluster. Use --cloud <cloud> to select a specific cloud provider.\nHF_TOKEN=<secret> sky launch -c sglang --env HF_TOKEN sglang.yaml\n\n# Get the HTTP API endpoint\nsky status --endpoint 30000 sglang\n```\n\n3. To further scale up your deployment with autoscaling and failure recovery, check out the [SkyServe + SGLang guide](https://github.com/skypilot-org/skypilot/tree/master/llm/sglang#serving-llama-2-with-sglang-for-more-traffic-using-skyserve).\n\n</details>\n\n## Method 7: Run on AWS SageMaker\n\n<details>\n<summary>More</summary>\n\nTo deploy SGLang on AWS SageMaker, check out [AWS SageMaker Inference](https://aws.amazon.com/sagemaker/ai/deploy).\n\nAmazon Web Services provides SGLang containers with routine security patching. For the available SGLang containers, check out [AWS SGLang DLCs](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#sglang-containers).\n\nTo host a model with your own container, follow these steps:\n\n1. Build a docker container with [sagemaker.Dockerfile](https://github.com/sgl-project/sglang/blob/main/docker/sagemaker.Dockerfile) alongside the [serve](https://github.com/sgl-project/sglang/blob/main/docker/serve) script.\n2.
Push your container onto AWS ECR.\n\n<details>\n<summary>Dockerfile Build Script: <code>build-and-push.sh</code></summary>\n\n```bash\n#!/bin/bash\n# Exit on the first failing command (including inside pipes) so the success message only prints if every step succeeds\nset -euo pipefail\n\nAWS_ACCOUNT=\"<YOUR_AWS_ACCOUNT>\"\nAWS_REGION=\"<YOUR_AWS_REGION>\"\nREPOSITORY_NAME=\"<YOUR_REPOSITORY_NAME>\"\nIMAGE_TAG=\"<YOUR_IMAGE_TAG>\"\n\nECR_REGISTRY=\"${AWS_ACCOUNT}.dkr.ecr.${AWS_REGION}.amazonaws.com\"\nIMAGE_URI=\"${ECR_REGISTRY}/${REPOSITORY_NAME}:${IMAGE_TAG}\"\n\necho \"Starting build and push process...\"\n\n# Log in to ECR\necho \"Logging into ECR...\"\naws ecr get-login-password --region \"${AWS_REGION}\" | docker login --username AWS --password-stdin \"${ECR_REGISTRY}\"\n\n# Build the image\necho \"Building Docker image...\"\ndocker build -t \"${IMAGE_URI}\" -f sagemaker.Dockerfile .\n\necho \"Pushing ${IMAGE_URI}\"\ndocker push \"${IMAGE_URI}\"\n\necho \"Build and push completed successfully!\"\n```\n\n</details>\n\n3. To deploy a model for serving on AWS SageMaker, refer to [deploy_and_serve_endpoint.py](https://github.com/sgl-project/sglang/blob/main/examples/sagemaker/deploy_and_serve_endpoint.py). For more information, check out [sagemaker-python-sdk](https://github.com/aws/sagemaker-python-sdk).\n   1. By default, the model server on SageMaker runs the following command: `python3 -m sglang.launch_server --model-path opt/ml/model --host 0.0.0.0 --port 8080`. This default is suitable for hosting your own model on SageMaker.\n   2. To modify the model serving parameters, set environment variables prefixed with `SM_SGLANG_`; the [serve](https://github.com/sgl-project/sglang/blob/main/docker/serve) script accepts every option listed by `python3 -m sglang.launch_server --help` in this form.\n   3. The serve script automatically converts each environment variable of the form `SM_SGLANG_INPUT_ARGUMENT` into the flag `--input-argument` and passes it to the `python3 -m sglang.launch_server` CLI.\n   4.
For example, to run [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) with the `qwen3` reasoning parser, set the environment variables `SM_SGLANG_MODEL_PATH=Qwen/Qwen3-0.6B` and `SM_SGLANG_REASONING_PARSER=qwen3`.\n\n</details>\n\n## Common Notes\n\n- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please switch to other kernels by adding `--attention-backend triton --sampling-backend pytorch` and open an issue on GitHub.\n- To reinstall flashinfer locally, run `pip3 install --upgrade flashinfer-python --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`.\n- If you encounter `ptxas fatal   : Value 'sm_103a' is not defined for option 'gpu-name'` on B300/GB300, fix it with `export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas`.\n"
  },
  {
    "path": "docs/index.rst",
    "content": "SGLang Documentation\n====================\n\n.. raw:: html\n\n  <a class=\"github-button\" href=\"https://github.com/sgl-project/sglang\" data-size=\"large\" data-show-count=\"true\" aria-label=\"Star sgl-project/sglang on GitHub\">Star</a>\n  <a class=\"github-button\" href=\"https://github.com/sgl-project/sglang/fork\" data-icon=\"octicon-repo-forked\" data-size=\"large\" data-show-count=\"true\" aria-label=\"Fork sgl-project/sglang on GitHub\">Fork</a>\n  <script async defer src=\"https://buttons.github.io/buttons.js\"></script>\n  <br></br>\n\nSGLang is a high-performance serving framework for large language models and multimodal models.\nIt is designed to deliver low-latency and high-throughput inference across a wide range of setups, from a single GPU to large distributed clusters.\nIts core features include:\n\n- **Fast Runtime**: Provides efficient serving with RadixAttention for prefix caching, a zero-overhead CPU scheduler, prefill-decode disaggregation, speculative decoding, continuous batching, paged attention, tensor/pipeline/expert/data parallelism, structured outputs, chunked prefill, quantization (FP4/FP8/INT4/AWQ/GPTQ), and multi-LoRA batching.\n- **Broad Model Support**: Supports a wide range of language models (Llama, Qwen, DeepSeek, Kimi, GLM, GPT, Gemma, Mistral, etc.), embedding models (e5-mistral, gte, mcdse), reward models (Skywork), and diffusion models (WAN, Qwen-Image), with easy extensibility for adding new models. 
Compatible with most Hugging Face models and OpenAI APIs.\n- **Extensive Hardware Support**: Runs on NVIDIA GPUs (GB200/B300/H100/A100/Spark), AMD GPUs (MI355/MI300), Intel Xeon CPUs, Google TPUs, Ascend NPUs, and more.\n- **Active Community**: SGLang is open-source and supported by a vibrant community with widespread industry adoption, powering over 400,000 GPUs worldwide.\n- **RL & Post-Training Backbone**: SGLang is a proven rollout backend used for training many frontier models, with native RL integrations and adoption by well-known post-training frameworks such as AReaL, Miles, slime, Tunix, verl and more.\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Get Started\n\n   get_started/install.md\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Basic Usage\n\n   basic_usage/send_request.ipynb\n   basic_usage/openai_api.rst\n   basic_usage/ollama_api.md\n   basic_usage/offline_engine_api.ipynb\n   basic_usage/native_api.ipynb\n   basic_usage/sampling_params.md\n   basic_usage/popular_model_usage.rst\n\n.. 
toctree::\n   :maxdepth: 1\n   :caption: Advanced Features\n\n   advanced_features/server_arguments.md\n   advanced_features/hyperparameter_tuning.md\n   advanced_features/attention_backend.md\n   advanced_features/speculative_decoding.ipynb\n   advanced_features/structured_outputs.ipynb\n   advanced_features/structured_outputs_for_reasoning_models.ipynb\n   advanced_features/tool_parser.ipynb\n   advanced_features/separate_reasoning.ipynb\n   advanced_features/quantization.md\n   advanced_features/quantized_kv_cache.md\n   advanced_features/expert_parallelism.md\n   advanced_features/dp_dpa_smg_guide.md\n   advanced_features/lora.ipynb\n   advanced_features/pd_disaggregation.md\n   advanced_features/epd_disaggregation.md\n   advanced_features/pipeline_parallelism.md\n   advanced_features/hicache.rst\n   advanced_features/pd_multiplexing.md\n   advanced_features/vlm_query.ipynb\n   advanced_features/dp_for_multi_modal_encoder.md\n   advanced_features/cuda_graph_for_multi_modal_encoder.md\n   advanced_features/piecewise_cuda_graph.md\n   advanced_features/sgl_model_gateway.md\n   advanced_features/deterministic_inference.md\n   advanced_features/observability.md\n   advanced_features/checkpoint_engine.md\n   advanced_features/sglang_for_rl.md\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Supported Models\n\n   supported_models/text_generation/index\n   supported_models/retrieval_ranking/index\n   supported_models/specialized/index\n   supported_models/extending/index\n\n.. 
toctree::\n   :maxdepth: 2\n   :caption: SGLang Diffusion\n\n   diffusion/index\n   diffusion/installation\n   diffusion/compatibility_matrix\n   diffusion/api/cli\n   diffusion/api/openai_api\n   diffusion/performance/index\n   diffusion/performance/attention_backends\n   diffusion/performance/profiling\n   diffusion/performance/cache/index\n   diffusion/performance/cache/cache_dit\n   diffusion/performance/cache/teacache\n   diffusion/support_new_models\n   diffusion/contributing\n   diffusion/ci_perf\n   diffusion/environment_variables\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Hardware Platforms\n\n   platforms/amd_gpu.md\n   platforms/cpu_server.md\n   platforms/tpu.md\n   platforms/nvidia_jetson.md\n   platforms/ascend_npu_support.rst\n   platforms/xpu.md\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Developer Guide\n\n   developer_guide/contribution_guide.md\n   developer_guide/development_guide_using_docker.md\n   developer_guide/development_jit_kernel_guide.md\n   developer_guide/benchmark_and_profiling.md\n   developer_guide/bench_serving.md\n   developer_guide/evaluating_new_models.md\n\n.. toctree::\n   :maxdepth: 1\n   :caption: References\n\n   references/faq.md\n   references/environment_variables.md\n   references/production_metrics.md\n   references/production_request_trace.md\n   references/multi_node_deployment/multi_node_index.rst\n   references/custom_chat_template.md\n   references/frontend/frontend_index.rst\n   references/post_training_integration.md\n   references/release_lookup\n   references/learn_more.md\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Security Acknowledgement\n\n   security/acknowledgements.md\n"
  },
  {
    "path": "docs/performance_dashboard/README.md",
    "content": "# SGLang Performance Dashboard\n\nA web-based dashboard for visualizing SGLang nightly test performance metrics.\n\n## Features\n\n- **Performance Trends**: View throughput, latency, and TTFT trends over time\n- **Model Comparison**: Compare performance across different models and configurations\n- **Filtering**: Filter by GPU configuration, model, variant, and batch size\n- **Interactive Charts**: Zoom, pan, and hover for detailed metrics\n- **Run History**: View recent benchmark runs with links to GitHub Actions\n\n## Quick Start\n\n### Option 1: Run with Local Server (Recommended)\n\nFor live data from GitHub Actions artifacts:\n\n```bash\n# Install requirements\npip install requests\n\n# Run the server\npython server.py --fetch-on-start\n\n# Visit http://localhost:8000\n```\n\nThe server provides:\n- Automatic fetching of metrics from GitHub\n- Caching to reduce API calls\n- `/api/metrics` endpoint for the frontend\n\n### Option 2: Fetch Data Manually\n\nUse the fetch script to download metrics data:\n\n```bash\n# Fetch last 30 days of metrics\npython fetch_metrics.py --output metrics_data.json\n\n# Fetch a specific run\npython fetch_metrics.py --run-id 21338741812 --output single_run.json\n\n# Fetch only scheduled (nightly) runs\npython fetch_metrics.py --scheduled-only --days 7\n```\n\n## GitHub Token\n\nTo download artifacts from GitHub, you need authentication:\n\n1. **Using `gh` CLI** (recommended):\n   ```bash\n   gh auth login\n   ```\n\n2. 
**Using environment variable**:\n   ```bash\n   export GITHUB_TOKEN=your_token_here\n   ```\n\nWithout a token, the dashboard will show run metadata but not detailed benchmark results.\n\n## Data Structure\n\nThe metrics JSON has this structure:\n\n```json\n{\n  \"run_id\": \"21338741812\",\n  \"run_date\": \"2026-01-25T22:24:02.090218+00:00\",\n  \"commit_sha\": \"5cdb391...\",\n  \"branch\": \"main\",\n  \"results\": [\n    {\n      \"gpu_config\": \"8-gpu-h200\",\n      \"partition\": 0,\n      \"model\": \"deepseek-ai/DeepSeek-V3.1\",\n      \"variant\": \"TP8+MTP\",\n      \"benchmarks\": [\n        {\n          \"batch_size\": 1,\n          \"input_len\": 4096,\n          \"output_len\": 512,\n          \"latency_ms\": 2400.72,\n          \"input_throughput\": 21408.64,\n          \"output_throughput\": 231.74,\n          \"overall_throughput\": 1919.43,\n          \"ttft_ms\": 191.32,\n          \"acc_length\": 3.19\n        }\n      ]\n    }\n  ]\n}\n```\n\n## Deployment\n\n### GitHub Pages\n\nThe dashboard can be deployed to GitHub Pages for public access:\n\n1. Copy the dashboard files to `docs/performance_dashboard/`\n2. Enable GitHub Pages in repository settings\n3. Set up a GitHub Action to periodically update metrics data\n\n### Self-Hosted\n\nFor a self-hosted deployment with live data:\n\n1. Set up a server running `server.py`\n2. Configure a cron job or systemd timer to refresh data\n3. 
Optionally put behind nginx/caddy for SSL\n\n## Metrics Explained\n\n- **Overall Throughput**: Total tokens (input + output) processed per second\n- **Input Throughput**: Input tokens processed per second (prefill speed)\n- **Output Throughput**: Output tokens generated per second (decode speed)\n- **Latency**: End-to-end time to complete the request\n- **TTFT**: Time to First Token - time until the first output token\n- **Acc Length**: Acceptance length for speculative decoding (MTP variants)\n\n## Contributing\n\nTo add support for new metrics or visualizations:\n\n1. Update `fetch_metrics.py` if data collection needs changes\n2. Modify `app.js` to add new chart types or filters\n3. Update `index.html` for UI changes\n\n## Troubleshooting\n\n**No data displayed**\n- Check browser console for errors\n- Verify GitHub API is accessible\n- Try running with `server.py --fetch-on-start`\n\n**API rate limits**\n- Use a GitHub token for higher limits\n- The server caches data for 5 minutes\n\n**Charts not rendering**\n- Ensure Chart.js is loading from CDN\n- Check for JavaScript errors in console\n"
  },
  {
    "path": "docs/performance_dashboard/app.js",
    "content": "// SGLang Performance Dashboard Application\n\nconst GITHUB_REPO = 'sgl-project/sglang';\nconst WORKFLOW_NAME = 'nightly-test-nvidia.yml';\nconst ARTIFACT_PREFIX = 'consolidated-metrics-';\n\n// Chart instances (array for batch-separated charts)\nlet activeCharts = [];\n\n// Data storage\nlet allMetricsData = [];\nlet currentModel = null;\nlet currentMetricType = 'throughput'; // throughput, latency, ttft, inputThroughput\n\n// Metric type definitions\nconst metricTypes = {\n    // Text/VLM metrics\n    throughput: { label: 'Overall Throughput', unit: 'tokens/sec', field: 'throughput', type: 'text' },\n    outputThroughput: { label: 'Output Throughput', unit: 'tokens/sec', field: 'outputThroughput', type: 'text' },\n    inputThroughput: { label: 'Input Throughput', unit: 'tokens/sec', field: 'inputThroughput', type: 'text' },\n    latency: { label: 'Latency', unit: 'ms', field: 'latency', type: 'text' },\n    ttft: { label: 'Time to First Token', unit: 'ms', field: 'ttft', type: 'text' },\n    accLength: { label: 'Accept Length', unit: 'tokens', field: 'accLength', filterInvalid: true, type: 'text' },\n    // Diffusion metrics\n    e2eMs: { label: 'End-to-End Time', unit: 'ms', field: 'e2e_ms', type: 'diffusion' },\n    avgDenoiseMs: { label: 'Avg Denoise Time', unit: 'ms', field: 'avg_denoise_ms', type: 'diffusion' },\n    medianDenoiseMs: { label: 'Median Denoise Time', unit: 'ms', field: 'median_denoise_ms', type: 'diffusion' }\n};\n\n// Chart.js default configuration for dark theme\nChart.defaults.color = '#94a3b8';\nChart.defaults.borderColor = '#1e293b';\n\nconst chartColors = [\n    '#22d3ee', '#34d399', '#fbbf24', '#f87171', '#a78bfa',\n    '#67e8f9', '#6ee7b7', '#fcd34d', '#fca5a5', '#c4b5fd'\n];\n\n// Initialize the dashboard\nasync function init() {\n    try {\n        await loadData();\n        document.getElementById('loading').style.display = 'none';\n        document.getElementById('content').style.display = 'block';\n        
populateFilters();\n        updateStats();\n        updateCharts();\n        updateRunsTable();\n    } catch (error) {\n        console.error('Failed to initialize dashboard:', error);\n        document.getElementById('loading').style.display = 'none';\n        document.getElementById('error').style.display = 'block';\n        document.getElementById('error-message').textContent = error.message;\n    }\n}\n\n// Load data from local server API or GitHub\nasync function loadData() {\n    // Try local server API first (if running server.py)\n    try {\n        const response = await fetch('/api/metrics', { headers: getAuthHeaders() });\n        if (response.ok) {\n            const data = await response.json();\n            if (data.length > 0 && data[0].results && data[0].results.length > 0) {\n                allMetricsData = data;\n                console.log(`Loaded ${data.length} records from local API`);\n                allMetricsData.sort((a, b) => new Date(b.run_date) - new Date(a.run_date));\n                return;\n            }\n        }\n    } catch (error) {\n        console.log('Local API not available, trying GitHub API');\n    }\n\n    // Try to load from GitHub API\n    const runs = await fetchWorkflowRuns();\n    const metricsPromises = runs.map(run => fetchMetricsForRun(run));\n    const results = await Promise.allSettled(metricsPromises);\n\n    allMetricsData = results\n        .filter(r => r.status === 'fulfilled' && r.value !== null)\n        .map(r => r.value);\n\n    if (allMetricsData.length === 0) {\n        throw new Error('No metrics data available. 
Please run the server.py with --fetch-on-start to fetch data from GitHub.');\n    }\n\n    // Sort by date descending\n    allMetricsData.sort((a, b) => new Date(b.run_date) - new Date(a.run_date));\n}\n\n// Fetch workflow runs from GitHub API\nasync function fetchWorkflowRuns() {\n    const response = await fetch(\n        `https://api.github.com/repos/${GITHUB_REPO}/actions/workflows/${WORKFLOW_NAME}/runs?status=completed&per_page=30`,\n        {\n            headers: {\n                'Accept': 'application/vnd.github.v3+json'\n            }\n        }\n    );\n\n    if (!response.ok) {\n        throw new Error(`GitHub API error: ${response.status}`);\n    }\n\n    const data = await response.json();\n    return data.workflow_runs || [];\n}\n\n// Fetch metrics artifact for a specific run\nasync function fetchMetricsForRun(run) {\n    try {\n        // Get artifacts for this run\n        const artifactsResponse = await fetch(\n            `https://api.github.com/repos/${GITHUB_REPO}/actions/runs/${run.id}/artifacts`,\n            {\n                headers: {\n                    'Accept': 'application/vnd.github.v3+json'\n                }\n            }\n        );\n\n        if (!artifactsResponse.ok) return null;\n\n        const artifactsData = await artifactsResponse.json();\n        const metricsArtifact = artifactsData.artifacts.find(\n            a => a.name.startsWith(ARTIFACT_PREFIX)\n        );\n\n        if (!metricsArtifact) return null;\n\n        // Note: GitHub API doesn't allow direct artifact download without authentication\n        // For public access, we would need to use a proxy or pre-process the data\n        // For now, return run metadata - in production, use a backend to fetch artifacts\n        return {\n            run_id: run.id.toString(),\n            run_date: run.created_at,\n            commit_sha: run.head_sha,\n            branch: run.head_branch,\n            artifact_id: metricsArtifact.id,\n            results: [] // 
Would be populated from artifact content\n        };\n    } catch (error) {\n        console.warn(`Failed to fetch metrics for run ${run.id}:`, error);\n        return null;\n    }\n}\n\n// Helper function to detect if result is diffusion type\nfunction isDiffusionResult(result) {\n    return result.test_type === 'diffusion' || (result.tests && !result.benchmarks);\n}\n\n// Populate filter dropdowns\nfunction populateFilters() {\n    const gpuConfigs = new Set();\n    const models = new Set();\n    const testNames = new Set(); // For diffusion tests\n    const batchSizes = new Set();\n    const ioLengths = new Set();\n\n    allMetricsData.forEach(run => {\n        run.results.forEach(result => {\n            gpuConfigs.add(result.gpu_config);\n\n            // Handle diffusion results\n            if (isDiffusionResult(result)) {\n                models.add(result.test_suite || 'diffusion');\n                if (result.tests) {\n                    result.tests.forEach(test => {\n                        testNames.add(test.test_name);\n                    });\n                }\n            }\n            // Handle text/VLM results\n            else {\n                models.add(result.model);\n                // Try new structure first (benchmarks_by_io_len), fall back to flat benchmarks\n                if (result.benchmarks_by_io_len) {\n                    Object.entries(result.benchmarks_by_io_len).forEach(([ioKey, ioData]) => {\n                        ioLengths.add(ioKey);\n                        ioData.benchmarks.forEach(bench => {\n                            batchSizes.add(bench.batch_size);\n                        });\n                    });\n                } else if (result.benchmarks) {\n                    result.benchmarks.forEach(bench => {\n                        batchSizes.add(bench.batch_size);\n                        if (bench.input_len && bench.output_len) {\n                            
ioLengths.add(`${bench.input_len}_${bench.output_len}`);\n                        }\n                    });\n                }\n            }\n        });\n    });\n\n    // No \"all\" option for GPU and Model - populate with first value selected\n    const gpuArray = Array.from(gpuConfigs).sort();\n    const modelArray = Array.from(models).sort();\n\n    populateSelectNoAll('gpu-filter', gpuArray);\n    populateSelectNoAll('model-filter', modelArray);\n    populateSelect('batch-filter', Array.from(batchSizes).sort((a, b) => a - b));\n    populateSelectWithLabels('io-len-filter', sortIoLengths(Array.from(ioLengths)), formatIoLenLabel);\n\n    // Set initial values (first option)\n    if (gpuArray.length > 0) {\n        document.getElementById('gpu-filter').value = gpuArray[0];\n    }\n    if (modelArray.length > 0) {\n        document.getElementById('model-filter').value = modelArray[0];\n        currentModel = modelArray[0];\n    }\n\n    // Update variants based on selected model\n    updateVariantFilter();\n    // Update IO length filter based on selected GPU/model\n    updateIoLenFilter();\n\n    // Create metric type tabs\n    createMetricTabs();\n}\n\n// Format input/output length key for display\nfunction formatIoLenLabel(ioKey) {\n    if (!ioKey) return 'Unknown';\n    const parts = ioKey.split('_');\n    if (parts.length === 2) {\n        return `In: ${parts[0]}, Out: ${parts[1]}`;\n    }\n    return ioKey;\n}\n\n// Sort IO length keys numerically (by input length, then output length)\nfunction sortIoLengths(ioLengths) {\n    return ioLengths.filter(key => key && key.includes('_')).sort((a, b) => {\n        const [aIn, aOut] = a.split('_').map(Number);\n        const [bIn, bOut] = b.split('_').map(Number);\n        if (isNaN(aIn) || isNaN(bIn)) return 0;\n        return (aIn - bIn) || (aOut - bOut);\n    });\n}\n\n// Populate select with custom label formatting\nfunction populateSelectWithLabels(selectId, options, labelFormatter) {\n    const select = 
document.getElementById(selectId);\n    options.forEach(option => {\n        const opt = document.createElement('option');\n        opt.value = option;\n        opt.textContent = labelFormatter ? labelFormatter(option) : option;\n        select.appendChild(opt);\n    });\n}\n\n// Update IO length filter based on selected GPU and model\nfunction updateIoLenFilter() {\n    const gpuFilterEl = document.getElementById('gpu-filter');\n    const modelFilterEl = document.getElementById('model-filter');\n    const ioLenSelect = document.getElementById('io-len-filter');\n    if (!gpuFilterEl || !modelFilterEl || !ioLenSelect) return;\n\n    const gpuFilter = gpuFilterEl.value;\n    const modelFilter = modelFilterEl.value;\n\n    const ioLengths = new Set();\n\n    allMetricsData.forEach(run => {\n        run.results.forEach(result => {\n            if (result.gpu_config === gpuFilter && result.model === modelFilter) {\n                if (result.benchmarks_by_io_len) {\n                    Object.keys(result.benchmarks_by_io_len).forEach(ioKey => {\n                        ioLengths.add(ioKey);\n                    });\n                } else if (result.benchmarks) {\n                    result.benchmarks.forEach(bench => {\n                        if (bench.input_len && bench.output_len) {\n                            ioLengths.add(`${bench.input_len}_${bench.output_len}`);\n                        }\n                    });\n                }\n            }\n        });\n    });\n\n    const ioLenArray = sortIoLengths(Array.from(ioLengths));\n    const currentIoLen = ioLenSelect.value;\n\n    // Clear and repopulate\n    ioLenSelect.innerHTML = '<option value=\"all\">All Lengths</option>';\n    ioLenArray.forEach(ioLen => {\n        const opt = document.createElement('option');\n        opt.value = ioLen;\n        opt.textContent = formatIoLenLabel(ioLen);\n        ioLenSelect.appendChild(opt);\n    });\n\n    // Try to restore previous selection if still valid\n    if 
(ioLenArray.includes(currentIoLen)) {\n        ioLenSelect.value = currentIoLen;\n    } else {\n        ioLenSelect.value = 'all';\n    }\n}\n\n// Update variant filter based on selected GPU and model\nfunction updateVariantFilter() {\n    const gpuFilter = document.getElementById('gpu-filter').value;\n    const modelFilter = document.getElementById('model-filter').value;\n\n    const variants = new Set();\n\n    allMetricsData.forEach(run => {\n        run.results.forEach(result => {\n            if (result.gpu_config === gpuFilter && result.model === modelFilter) {\n                // Use 'default' for null/undefined variants\n                variants.add(result.variant || 'default');\n            }\n        });\n    });\n\n    const variantArray = Array.from(variants).sort();\n    const variantSelect = document.getElementById('variant-filter');\n    const currentVariant = variantSelect.value;\n\n    // Clear and repopulate\n    variantSelect.innerHTML = '<option value=\"all\">All Variants</option>';\n    variantArray.forEach(variant => {\n        const opt = document.createElement('option');\n        opt.value = variant;\n        opt.textContent = variant;\n        variantSelect.appendChild(opt);\n    });\n\n    // Try to restore previous selection if still valid\n    if (variantArray.includes(currentVariant)) {\n        variantSelect.value = currentVariant;\n    } else {\n        variantSelect.value = 'all';\n    }\n}\n\nfunction populateSelect(selectId, options) {\n    const select = document.getElementById(selectId);\n    options.forEach(option => {\n        const opt = document.createElement('option');\n        opt.value = option;\n        opt.textContent = option;\n        select.appendChild(opt);\n    });\n}\n\nfunction populateSelectNoAll(selectId, options) {\n    const select = document.getElementById(selectId);\n    // Remove the \"all\" option if present\n    while (select.options.length > 0) {\n        select.remove(0);\n    }\n    
options.forEach(option => {\n        const opt = document.createElement('option');\n        opt.value = option;\n        opt.textContent = option;\n        select.appendChild(opt);\n    });\n}\n\nfunction createMetricTabs() {\n    const tabsContainer = document.getElementById('metric-tabs');\n    tabsContainer.innerHTML = '';\n\n    // Detect if current data is diffusion or text\n    const isDiffusion = detectCurrentDataType() === 'diffusion';\n    const dataType = isDiffusion ? 'diffusion' : 'text';\n\n    // Filter metrics based on data type\n    const relevantMetrics = Object.entries(metricTypes).filter(([key, metric]) =>\n        metric.type === dataType\n    );\n\n    relevantMetrics.forEach(([key, metric], index) => {\n        const tab = document.createElement('div');\n        tab.className = index === 0 ? 'tab active' : 'tab';\n        tab.textContent = metric.label;\n        tab.dataset.metric = key;\n        tab.onclick = () => selectMetricTab(key, tab);\n        tabsContainer.appendChild(tab);\n    });\n\n    // Set initial metric type\n    if (relevantMetrics.length > 0) {\n        currentMetricType = relevantMetrics[0][0];\n    }\n}\n\nfunction detectCurrentDataType() {\n    // Check if currently selected model/GPU config has diffusion data\n    const gpuFilter = document.getElementById('gpu-filter')?.value;\n    const modelFilter = currentModel;\n\n    if (!gpuFilter || !modelFilter) return 'text';\n\n    for (const run of allMetricsData) {\n        for (const result of run.results) {\n            if (result.gpu_config === gpuFilter) {\n                const resultModel = result.test_suite || result.model;\n                if (resultModel === modelFilter && isDiffusionResult(result)) {\n                    return 'diffusion';\n                }\n            }\n        }\n    }\n    return 'text';\n}\n\nfunction selectMetricTab(metricKey, tabElement) {\n    document.querySelectorAll('.tab').forEach(t => t.classList.remove('active'));\n    
tabElement.classList.add('active');\n    currentMetricType = metricKey;\n\n    // Update chart title\n    const metric = metricTypes[metricKey];\n    document.getElementById('metric-title').textContent = `${metric.label} (${metric.unit})`;\n\n    updateCharts();\n}\n\n// Handle model filter dropdown change\nfunction handleModelFilterChange(model) {\n    currentModel = model;\n    // Update variant filter based on new model selection\n    updateVariantFilter();\n    // Update IO length filter based on new model selection\n    updateIoLenFilter();\n    // Recreate metric tabs in case data type changed (text vs diffusion)\n    createMetricTabs();\n    updateCharts();\n}\n\n// Handle GPU filter change\nfunction handleGpuFilterChange() {\n    // Update variant filter based on new GPU selection\n    updateVariantFilter();\n    // Update IO length filter based on new GPU selection\n    updateIoLenFilter();\n    // Recreate metric tabs in case data type changed (text vs diffusion)\n    createMetricTabs();\n    updateCharts();\n}\n\n// Update summary stats\nfunction updateStats() {\n    const statsRow = document.getElementById('stats-row');\n    const latestRun = allMetricsData[0];\n\n    if (!latestRun) {\n        statsRow.innerHTML = '';\n        const noDataDiv = document.createElement('div');\n        noDataDiv.className = 'no-data';\n        noDataDiv.textContent = 'No data available';\n        statsRow.appendChild(noDataDiv);\n        return;\n    }\n\n    const totalModels = new Set(latestRun.results.map(r => r.model)).size;\n    const totalBenchmarks = latestRun.results.reduce((sum, r) => {\n        // Count benchmarks from either structure\n        if (r.benchmarks_by_io_len) {\n            return sum + Object.values(r.benchmarks_by_io_len).reduce(\n                (ioSum, ioData) => ioSum + ioData.benchmarks.length, 0\n            );\n        }\n        return sum + (r.benchmarks ? 
r.benchmarks.length : 0);\n    }, 0);\n\n    statsRow.innerHTML = ''; // Clear previous stats\n\n    const addStat = (label, value) => {\n        const card = document.createElement('div');\n        card.className = 'stat-card';\n        const labelEl = document.createElement('div');\n        labelEl.className = 'label';\n        labelEl.textContent = label;\n        const valueEl = document.createElement('div');\n        valueEl.className = 'value';\n        valueEl.textContent = value;\n        card.appendChild(labelEl);\n        card.appendChild(valueEl);\n        statsRow.appendChild(card);\n    };\n\n    addStat('Total Runs', allMetricsData.length);\n    addStat('Models Tested', totalModels);\n    addStat('Benchmarks', totalBenchmarks);\n}\n\n// Update charts based on current filters and selected metric type\nfunction updateCharts() {\n    const gpuFilter = document.getElementById('gpu-filter').value;\n    const modelFilter = currentModel;\n    const variantFilter = document.getElementById('variant-filter').value;\n    const ioLenFilter = document.getElementById('io-len-filter').value;\n    const batchFilter = document.getElementById('batch-filter').value;\n\n    // Prepare data for charts - grouped by batch size\n    const chartDataByBatch = prepareChartDataByBatch(gpuFilter, modelFilter, variantFilter, ioLenFilter, batchFilter);\n\n    // Update chart for the selected metric type\n    updateMetricChart(chartDataByBatch, currentMetricType);\n}\n\nfunction prepareChartData(gpuFilter, modelFilter, variantFilter, ioLenFilter, batchFilter) {\n    const seriesMap = new Map();\n\n    allMetricsData.forEach(run => {\n        const runDate = new Date(run.run_date);\n\n        run.results.forEach(result => {\n            // Apply filters\n            if (result.gpu_config !== gpuFilter) return;\n            if (result.model !== modelFilter) return;\n            if (variantFilter !== 'all' && result.variant !== variantFilter) return;\n\n            // Helper function 
to process a benchmark entry\n            const processBenchmark = (bench, ioKey) => {\n                if (batchFilter !== 'all' && bench.batch_size !== parseInt(batchFilter)) return;\n\n                const ioLabel = ioKey ? `, ${formatIoLenLabel(ioKey)}` : '';\n                const seriesKey = `${result.model.split('/').pop()} (${result.variant}, BS=${bench.batch_size}${ioLabel})`;\n\n                if (!seriesMap.has(seriesKey)) {\n                    seriesMap.set(seriesKey, {\n                        label: seriesKey,\n                        data: [],\n                        model: result.model,\n                        variant: result.variant,\n                        batchSize: bench.batch_size,\n                        ioKey: ioKey\n                    });\n                }\n\n                seriesMap.get(seriesKey).data.push({\n                    x: runDate,\n                    throughput: bench.overall_throughput,\n                    outputThroughput: bench.output_throughput,\n                    latency: bench.latency_ms,\n                    ttft: bench.ttft_ms,\n                    inputThroughput: bench.input_throughput,\n                    accLength: bench.acc_length,\n                    runId: run.run_id\n                });\n            };\n\n            // Use benchmarks_by_io_len if available\n            if (result.benchmarks_by_io_len) {\n                Object.entries(result.benchmarks_by_io_len).forEach(([ioKey, ioData]) => {\n                    if (ioLenFilter !== 'all' && ioKey !== ioLenFilter) return;\n                    ioData.benchmarks.forEach(bench => processBenchmark(bench, ioKey));\n                });\n            } else if (result.benchmarks) {\n                result.benchmarks.forEach(bench => {\n                    const benchIoKey = bench.input_len && bench.output_len\n                        ? 
`${bench.input_len}_${bench.output_len}`\n                        : null;\n                    if (ioLenFilter !== 'all' && benchIoKey !== ioLenFilter) return;\n                    processBenchmark(bench, benchIoKey);\n                });\n            }\n        });\n    });\n\n    // Sort data points by date\n    seriesMap.forEach(series => {\n        series.data.sort((a, b) => a.x - b.x);\n    });\n\n    return Array.from(seriesMap.values());\n}\n\n// Prepare chart data grouped by batch size - each batch size is a separate series\nfunction prepareChartDataByBatch(gpuFilter, modelFilter, variantFilter, ioLenFilter, batchFilter) {\n    const batchDataMap = new Map(); // batch_size -> Map of variant -> data\n    const testDataMap = new Map(); // For diffusion: test_name -> data\n\n    allMetricsData.forEach(run => {\n        const runDate = new Date(run.run_date);\n\n        run.results.forEach(result => {\n            // Apply filters - GPU and Model are required (no \"all\" option)\n            if (result.gpu_config !== gpuFilter) return;\n\n            // Handle diffusion results\n            if (isDiffusionResult(result)) {\n                const resultModel = result.test_suite || 'diffusion';\n                if (resultModel !== modelFilter) return;\n\n                if (result.tests) {\n                    result.tests.forEach(test => {\n                        const testName = test.test_name;\n                        if (!testDataMap.has(testName)) {\n                            testDataMap.set(testName, {\n                                label: testName,\n                                data: [],\n                                model: resultModel,\n                                testName: testName\n                            });\n                        }\n\n                        testDataMap.get(testName).data.push({\n                            x: runDate,\n                            e2e_ms: test.e2e_ms,\n                            avg_denoise_ms: 
test.avg_denoise_ms,\n                            median_denoise_ms: test.median_denoise_ms,\n                            runId: run.run_id\n                        });\n                    });\n                }\n                return;\n            }\n\n            // Handle text/VLM results\n            if (result.model !== modelFilter) return;\n            if (variantFilter !== 'all' && result.variant !== variantFilter) return;\n\n            // Use benchmarks_by_io_len if available, otherwise fall back to flat benchmarks\n            if (result.benchmarks_by_io_len) {\n                Object.entries(result.benchmarks_by_io_len).forEach(([ioKey, ioData]) => {\n                    // Apply IO length filter\n                    if (ioLenFilter !== 'all' && ioKey !== ioLenFilter) return;\n\n                    ioData.benchmarks.forEach(bench => {\n                        if (batchFilter !== 'all' && bench.batch_size !== parseInt(batchFilter)) return;\n\n                        const batchSize = bench.batch_size;\n                        const variantLabel = result.variant || 'default';\n                        // Include IO length in series key when showing all lengths\n                        const seriesKey = ioLenFilter === 'all'\n                            ? 
`${variantLabel} (${formatIoLenLabel(ioKey)})`\n                            : variantLabel;\n\n                        if (!batchDataMap.has(batchSize)) {\n                            batchDataMap.set(batchSize, new Map());\n                        }\n\n                        const variantMap = batchDataMap.get(batchSize);\n                        if (!variantMap.has(seriesKey)) {\n                            variantMap.set(seriesKey, {\n                                label: seriesKey,\n                                data: [],\n                                model: result.model,\n                                variant: result.variant,\n                                batchSize: batchSize,\n                                ioKey: ioKey\n                            });\n                        }\n\n                        variantMap.get(seriesKey).data.push({\n                            x: runDate,\n                            throughput: bench.overall_throughput,\n                            outputThroughput: bench.output_throughput,\n                            latency: bench.latency_ms,\n                            ttft: bench.ttft_ms,\n                            inputThroughput: bench.input_throughput,\n                            accLength: bench.acc_length,\n                            runId: run.run_id\n                        });\n                    });\n                });\n            } else if (result.benchmarks) {\n                // Fall back to flat benchmarks for backward compatibility\n                result.benchmarks.forEach(bench => {\n                    // Apply IO length filter using flat structure\n                    const benchIoKey = bench.input_len && bench.output_len\n                        ? 
`${bench.input_len}_${bench.output_len}`\n                        : null;\n                    if (ioLenFilter !== 'all' && benchIoKey !== ioLenFilter) return;\n                    if (batchFilter !== 'all' && bench.batch_size !== parseInt(batchFilter)) return;\n\n                    const batchSize = bench.batch_size;\n                    const variantLabel = result.variant || 'default';\n                    // Include IO length in series key when showing all lengths\n                    const seriesKey = ioLenFilter === 'all' && benchIoKey\n                        ? `${variantLabel} (${formatIoLenLabel(benchIoKey)})`\n                        : variantLabel;\n\n                    if (!batchDataMap.has(batchSize)) {\n                        batchDataMap.set(batchSize, new Map());\n                    }\n\n                    const variantMap = batchDataMap.get(batchSize);\n                    if (!variantMap.has(seriesKey)) {\n                        variantMap.set(seriesKey, {\n                            label: seriesKey,\n                            data: [],\n                            model: result.model,\n                            variant: result.variant,\n                            batchSize: batchSize,\n                            ioKey: benchIoKey\n                        });\n                    }\n\n                    variantMap.get(seriesKey).data.push({\n                        x: runDate,\n                        throughput: bench.overall_throughput,\n                        outputThroughput: bench.output_throughput,\n                        latency: bench.latency_ms,\n                        ttft: bench.ttft_ms,\n                        inputThroughput: bench.input_throughput,\n                        accLength: bench.acc_length,\n                        runId: run.run_id\n                    });\n                });\n            }\n        });\n    });\n\n    // Sort data points by date and convert to array format\n    const result = {};\n\n   
 // For diffusion data, use test names as \"batch sizes\"\n    if (testDataMap.size > 0) {\n        testDataMap.forEach((series, testName) => {\n            series.data.sort((a, b) => a.x - b.x);\n            result[testName] = [series]; // Each test is its own series\n        });\n        return result;\n    }\n\n    // For text/VLM data, use batch sizes\n    batchDataMap.forEach((variantMap, batchSize) => {\n        variantMap.forEach(series => {\n            series.data.sort((a, b) => a.x - b.x);\n        });\n        result[batchSize] = Array.from(variantMap.values());\n    });\n\n    return result;\n}\n\n// Unified chart update function for any metric type\nfunction updateMetricChart(chartDataByBatch, metricType) {\n    const container = document.getElementById('charts-container');\n    container.innerHTML = '';\n\n    // Destroy existing charts\n    activeCharts.forEach(chart => chart.destroy());\n    activeCharts = [];\n\n    const metric = metricTypes[metricType];\n    const isDiffusion = metric.type === 'diffusion';\n\n    // For diffusion, keys are test names; for text, keys are batch sizes\n    const keys = Object.keys(chartDataByBatch);\n    if (!isDiffusion) {\n        keys.sort((a, b) => parseInt(a) - parseInt(b));\n    } else {\n        keys.sort(); // Alphabetical sort for test names\n    }\n    const batchSizes = keys; // Keep variable name for compatibility\n\n    if (batchSizes.length === 0) {\n        container.innerHTML = '<div class=\"no-data\">No data available for the selected filters</div>';\n        return;\n    }\n\n    let hasAnyData = false;\n\n    batchSizes.forEach(batchSize => {\n        const chartData = chartDataByBatch[batchSize];\n\n        const ctx_datasets = chartData.map((series, index) => {\n            // Filter data points - for metrics like accLength, exclude invalid values (-1 or null)\n            let dataPoints = series.data.map(d => ({ x: d.x, y: d[metric.field] }));\n            if (metric.filterInvalid) {\n          
      dataPoints = dataPoints.filter(d => d.y != null && d.y !== -1 && d.y > 0);\n            }\n            return {\n                label: series.label,\n                data: dataPoints,\n                borderColor: chartColors[index % chartColors.length],\n                backgroundColor: chartColors[index % chartColors.length] + '20',\n                tension: 0.1,\n                fill: false\n            };\n        }).filter(dataset => dataset.data.length > 0); // Remove empty datasets\n\n        // Skip this batch size if no valid data\n        if (ctx_datasets.length === 0) {\n            return;\n        }\n\n        hasAnyData = true;\n\n        const chartWrapper = document.createElement('div');\n        chartWrapper.className = 'batch-chart-wrapper';\n\n        const title = document.createElement('div');\n        title.className = 'batch-chart-title';\n        // For diffusion, show test name; for text, show batch size\n        title.textContent = isDiffusion ? `Test: ${batchSize}` : `Batch Size: ${batchSize}`;\n        chartWrapper.appendChild(title);\n\n        const chartContainer = document.createElement('div');\n        chartContainer.className = 'chart-container';\n        const canvas = document.createElement('canvas');\n        chartContainer.appendChild(canvas);\n        chartWrapper.appendChild(chartContainer);\n        container.appendChild(chartWrapper);\n\n        const ctx = canvas.getContext('2d');\n\n        const chart = new Chart(ctx, {\n            type: 'line',\n            data: { datasets: ctx_datasets },\n            options: getChartOptions(metric.unit)\n        });\n        activeCharts.push(chart);\n    });\n\n    // Show message if no valid data for this metric\n    if (!hasAnyData) {\n        container.innerHTML = `<div class=\"no-data\">No valid ${metric.label.toLowerCase()} data available for the selected filters</div>`;\n    }\n}\n\nfunction getChartOptions(yAxisLabel) {\n    return {\n        responsive: true,\n      
  maintainAspectRatio: false,\n        interaction: {\n            mode: 'index',\n            intersect: false\n        },\n        plugins: {\n            legend: {\n                position: 'bottom',\n                labels: {\n                    boxWidth: 12,\n                    padding: 10,\n                    font: { size: 11 }\n                }\n            },\n            tooltip: {\n                backgroundColor: '#1a2332',\n                borderColor: 'rgba(148, 163, 184, 0.1)',\n                borderWidth: 1,\n                titleFont: { size: 13, family: \"'DM Sans', sans-serif\" },\n                bodyFont: { size: 12, family: \"'JetBrains Mono', monospace\" },\n                padding: 14,\n                cornerRadius: 8\n            }\n        },\n        scales: {\n            x: {\n                type: 'time',\n                time: {\n                    unit: 'day',\n                    displayFormats: {\n                        day: 'MMM d'\n                    }\n                },\n                grid: {\n                    color: 'rgba(148, 163, 184, 0.06)'\n                }\n            },\n            y: {\n                title: {\n                    display: true,\n                    text: yAxisLabel\n                },\n                grid: {\n                    color: 'rgba(148, 163, 184, 0.06)'\n                }\n            }\n        }\n    };\n}\n\n// Escape HTML to prevent XSS\nfunction escapeHtml(text) {\n    const div = document.createElement('div');\n    div.textContent = text;\n    return div.innerHTML;\n}\n\n// Update runs table\nfunction updateRunsTable() {\n    const tbody = document.getElementById('runs-table-body');\n    tbody.innerHTML = '';\n\n    allMetricsData.slice(0, 10).forEach(run => {\n        const models = new Set(run.results.map(r => r.model.split('/').pop()));\n        const date = new Date(run.run_date);\n\n        const row = document.createElement('tr');\n\n        // Create cells 
safely to prevent XSS\n        const dateCell = document.createElement('td');\n        dateCell.textContent = `${date.toLocaleDateString()} ${date.toLocaleTimeString()}`;\n\n        const runIdCell = document.createElement('td');\n        const runLink = document.createElement('a');\n        runLink.href = `https://github.com/${GITHUB_REPO}/actions/runs/${encodeURIComponent(run.run_id)}`;\n        runLink.target = '_blank';\n        runLink.className = 'run-link';\n        runLink.textContent = run.run_id;\n        runIdCell.appendChild(runLink);\n\n        const commitCell = document.createElement('td');\n        const commitCode = document.createElement('code');\n        commitCode.textContent = run.commit_sha.substring(0, 7);\n        commitCell.appendChild(commitCode);\n\n        const branchCell = document.createElement('td');\n        branchCell.textContent = run.branch;\n\n        const modelsCell = document.createElement('td');\n        Array.from(models).forEach((model, index) => {\n            if (index > 0) modelsCell.appendChild(document.createTextNode(' '));\n            const badge = document.createElement('span');\n            badge.className = 'model-badge';\n            badge.textContent = model;\n            modelsCell.appendChild(badge);\n        });\n\n        row.appendChild(dateCell);\n        row.appendChild(runIdCell);\n        row.appendChild(commitCell);\n        row.appendChild(branchCell);\n        row.appendChild(modelsCell);\n\n        tbody.appendChild(row);\n    });\n}\n\n// Refresh data\nasync function refreshData() {\n    document.getElementById('content').style.display = 'none';\n    document.getElementById('loading').style.display = 'flex';\n    await init();\n}\n\n// Format numbers for display\nfunction formatNumber(num) {\n    if (num >= 1000) {\n        return (num / 1000).toFixed(1) + 'k';\n    }\n    return num.toFixed(1);\n}\n\n// Authentication state\nlet authToken = sessionStorage.getItem('dashboard_auth_token') || 
null;\n\n// Get auth headers for API requests\nfunction getAuthHeaders() {\n    const headers = {};\n    if (authToken) {\n        headers['Authorization'] = `Bearer ${authToken}`;\n    }\n    return headers;\n}\n\n// Check if server requires authentication and show/hide login accordingly\nasync function checkAuthAndInit() {\n    const loginOverlay = document.getElementById('login-overlay');\n    const dashboardContainer = document.getElementById('dashboard-container');\n\n    try {\n        const response = await fetch('/api/auth-check');\n        if (response.ok) {\n            const data = await response.json();\n            if (!data.auth_required) {\n                // No auth required - skip login, show dashboard directly\n                loginOverlay.style.display = 'none';\n                dashboardContainer.style.display = 'block';\n                init();\n                return;\n            }\n        }\n    } catch (e) {\n        // Server not available (e.g. static hosting) - skip login\n        loginOverlay.style.display = 'none';\n        dashboardContainer.style.display = 'block';\n        init();\n        return;\n    }\n\n    // Auth is required - check if we have a valid token from a previous session\n    if (authToken) {\n        try {\n            const testResponse = await fetch('/api/metrics', {\n                headers: getAuthHeaders()\n            });\n            if (testResponse.ok) {\n                loginOverlay.style.display = 'none';\n                dashboardContainer.style.display = 'block';\n                init();\n                return;\n            }\n        } catch (e) {\n            // Token invalid or expired\n        }\n        // Clear invalid token\n        authToken = null;\n        sessionStorage.removeItem('dashboard_auth_token');\n    }\n\n    // Show login form\n    loginOverlay.style.display = 'flex';\n    dashboardContainer.style.display = 'none';\n}\n\n// Handle login form submission\nasync function 
handleLogin(event) {\n    event.preventDefault();\n\n    const username = document.getElementById('login-username').value;\n    const password = document.getElementById('login-password').value;\n    const errorEl = document.getElementById('login-error');\n    const loginBtn = document.getElementById('login-btn');\n\n    errorEl.textContent = '';\n    loginBtn.disabled = true;\n    loginBtn.textContent = 'Signing in...';\n\n    try {\n        const response = await fetch('/api/login', {\n            method: 'POST',\n            headers: { 'Content-Type': 'application/json' },\n            body: JSON.stringify({ username, password })\n        });\n\n        const data = await response.json();\n\n        if (response.ok && data.token) {\n            authToken = data.token;\n            sessionStorage.setItem('dashboard_auth_token', authToken);\n\n            document.getElementById('login-overlay').style.display = 'none';\n            document.getElementById('dashboard-container').style.display = 'block';\n            init();\n        } else {\n            errorEl.textContent = data.error || 'Invalid username or password';\n        }\n    } catch (e) {\n        errorEl.textContent = 'Unable to connect to server';\n    } finally {\n        loginBtn.disabled = false;\n        loginBtn.textContent = 'Sign In';\n    }\n\n    return false;\n}\n\n// Initialize on page load\ndocument.addEventListener('DOMContentLoaded', checkAuthAndInit);\n"
  },
  {
    "path": "docs/performance_dashboard/fetch_metrics.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nFetch and process SGLang nightly test metrics from GitHub Actions artifacts.\n\nThis script fetches consolidated metrics from GitHub Actions workflow runs\nand outputs them as JSON for the performance dashboard.\n\nUsage:\n    python fetch_metrics.py --output metrics_data.json\n    python fetch_metrics.py --output metrics_data.json --days 30\n    python fetch_metrics.py --output metrics_data.json --run-id 21338741812\n\"\"\"\n\nimport argparse\nimport io\nimport json\nimport os\nimport sys\nimport zipfile\nfrom datetime import datetime, timedelta, timezone\nfrom pathlib import Path\nfrom typing import Optional\n\nimport requests\n\nGITHUB_REPO = \"sgl-project/sglang\"\nWORKFLOW_NAME = \"nightly-test-nvidia.yml\"\nARTIFACT_PREFIX = \"consolidated-metrics-\"\n\n\ndef get_github_token() -> Optional[str]:\n    \"\"\"Get GitHub token from environment or gh CLI.\"\"\"\n    # Check environment variable first\n    token = os.environ.get(\"GITHUB_TOKEN\")\n    if token:\n        return token\n\n    # Try gh CLI\n    try:\n        import subprocess\n\n        result = subprocess.run(\n            [\"gh\", \"auth\", \"token\"],\n            capture_output=True,\n            text=True,\n            check=True,\n        )\n        return result.stdout.strip()\n    except (subprocess.CalledProcessError, FileNotFoundError):\n        pass\n\n    return None\n\n\ndef get_headers(token: Optional[str]) -> dict:\n    \"\"\"Get request headers with optional authentication.\"\"\"\n    headers = {\n        \"Accept\": \"application/vnd.github.v3+json\",\n    }\n    if token:\n        headers[\"Authorization\"] = f\"Bearer {token}\"\n    return headers\n\n\ndef fetch_workflow_runs(\n    token: Optional[str],\n    days: int = 30,\n    event: Optional[str] = None,\n) -> list:\n    \"\"\"Fetch completed workflow runs from GitHub Actions.\"\"\"\n    url = f\"https://api.github.com/repos/{GITHUB_REPO}/actions/workflows/{WORKFLOW_NAME}/runs\"\n\n  
  params = {\n        \"status\": \"completed\",\n        \"per_page\": 100,\n    }\n\n    if event:\n        params[\"event\"] = event\n\n    response = requests.get(url, headers=get_headers(token), params=params, timeout=30)\n    response.raise_for_status()\n\n    runs = response.json().get(\"workflow_runs\", [])\n\n    # Filter by date\n    cutoff = datetime.now(timezone.utc) - timedelta(days=days)\n    runs = [\n        run\n        for run in runs\n        if datetime.fromisoformat(run[\"created_at\"].replace(\"Z\", \"+00:00\")) > cutoff\n    ]\n\n    return runs\n\n\ndef fetch_run_artifacts(token: Optional[str], run_id: int) -> list:\n    \"\"\"Fetch artifacts for a specific workflow run.\"\"\"\n    url = f\"https://api.github.com/repos/{GITHUB_REPO}/actions/runs/{run_id}/artifacts\"\n\n    response = requests.get(url, headers=get_headers(token), timeout=30)\n    response.raise_for_status()\n\n    return response.json().get(\"artifacts\", [])\n\n\ndef download_artifact(token: Optional[str], artifact_id: int) -> Optional[bytes]:\n    \"\"\"Download an artifact by ID.\"\"\"\n    if not token:\n        print(f\"Warning: GitHub token required to download artifacts\", file=sys.stderr)\n        return None\n\n    url = f\"https://api.github.com/repos/{GITHUB_REPO}/actions/artifacts/{artifact_id}/zip\"\n\n    headers = get_headers(token)\n    response = requests.get(url, headers=headers, allow_redirects=True, timeout=60)\n\n    if response.status_code == 200:\n        return response.content\n\n    print(\n        f\"Failed to download artifact {artifact_id}: {response.status_code}\",\n        file=sys.stderr,\n    )\n    return None\n\n\ndef extract_metrics_from_zip(zip_content: bytes) -> Optional[dict]:\n    \"\"\"Extract metrics JSON from a zip file.\"\"\"\n    try:\n        with zipfile.ZipFile(io.BytesIO(zip_content)) as zf:\n            # Find the JSON file in the archive\n            json_files = [f for f in zf.namelist() if f.endswith(\".json\")]\n           
 if not json_files:\n                return None\n\n            with zf.open(json_files[0]) as f:\n                return json.load(f)\n    except (zipfile.BadZipFile, json.JSONDecodeError) as e:\n        print(f\"Failed to extract metrics: {e}\", file=sys.stderr)\n        return None\n\n\ndef fetch_metrics_for_run(token: Optional[str], run: dict) -> Optional[dict]:\n    \"\"\"Fetch metrics for a single workflow run.\"\"\"\n    run_id = run[\"id\"]\n    print(f\"Fetching metrics for run {run_id}...\", file=sys.stderr)\n\n    artifacts = fetch_run_artifacts(token, run_id)\n\n    # Find consolidated metrics artifact\n    metrics_artifact = None\n    for artifact in artifacts:\n        if artifact[\"name\"].startswith(ARTIFACT_PREFIX):\n            metrics_artifact = artifact\n            break\n\n    if not metrics_artifact:\n        print(f\"No consolidated metrics found for run {run_id}\", file=sys.stderr)\n        return None\n\n    # Download and extract\n    zip_content = download_artifact(token, metrics_artifact[\"id\"])\n    if not zip_content:\n        return None\n\n    metrics = extract_metrics_from_zip(zip_content)\n    if not metrics:\n        return None\n\n    # Ensure required fields are present\n    if \"run_id\" not in metrics:\n        metrics[\"run_id\"] = str(run_id)\n    if \"run_date\" not in metrics:\n        metrics[\"run_date\"] = run[\"created_at\"]\n    if \"commit_sha\" not in metrics:\n        metrics[\"commit_sha\"] = run[\"head_sha\"]\n    if \"branch\" not in metrics:\n        metrics[\"branch\"] = run[\"head_branch\"]\n\n    return metrics\n\n\ndef fetch_single_run(token: Optional[str], run_id: int) -> Optional[dict]:\n    \"\"\"Fetch metrics for a single run by ID.\"\"\"\n    url = f\"https://api.github.com/repos/{GITHUB_REPO}/actions/runs/{run_id}\"\n\n    response = requests.get(url, headers=get_headers(token), timeout=30)\n    response.raise_for_status()\n\n    run = response.json()\n    return fetch_metrics_for_run(token, 
run)\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Fetch SGLang nightly test metrics from GitHub Actions\"\n    )\n    parser.add_argument(\n        \"--output\",\n        \"-o\",\n        type=str,\n        default=\"metrics_data.json\",\n        help=\"Output JSON file path\",\n    )\n    parser.add_argument(\n        \"--days\",\n        type=int,\n        default=30,\n        help=\"Number of days to fetch (default: 30)\",\n    )\n    parser.add_argument(\n        \"--run-id\",\n        type=int,\n        help=\"Fetch a specific run by ID\",\n    )\n    parser.add_argument(\n        \"--event\",\n        type=str,\n        choices=[\"schedule\", \"workflow_dispatch\", \"push\"],\n        help=\"Filter by trigger event type\",\n    )\n    parser.add_argument(\n        \"--scheduled-only\",\n        action=\"store_true\",\n        help=\"Only fetch scheduled (nightly) runs\",\n    )\n\n    args = parser.parse_args()\n\n    token = get_github_token()\n    if not token:\n        print(\n            \"Warning: No GitHub token found. 
Some features may be limited.\",\n            file=sys.stderr,\n        )\n        print(\n            \"Set GITHUB_TOKEN env var or login with 'gh auth login'\",\n            file=sys.stderr,\n        )\n\n    all_metrics = []\n\n    if args.run_id:\n        # Fetch single run\n        metrics = fetch_single_run(token, args.run_id)\n        if metrics:\n            all_metrics.append(metrics)\n    else:\n        # Fetch multiple runs\n        event = \"schedule\" if args.scheduled_only else args.event\n        runs = fetch_workflow_runs(token, days=args.days, event=event)\n        print(f\"Found {len(runs)} workflow runs\", file=sys.stderr)\n\n        for run in runs:\n            metrics = fetch_metrics_for_run(token, run)\n            if metrics:\n                all_metrics.append(metrics)\n\n    # Sort by date descending\n    all_metrics.sort(key=lambda x: x.get(\"run_date\", \"\"), reverse=True)\n\n    # Write output\n    output_path = Path(args.output)\n    with open(output_path, \"w\") as f:\n        json.dump(all_metrics, f, indent=2)\n\n    print(f\"Wrote {len(all_metrics)} metrics records to {output_path}\", file=sys.stderr)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "docs/performance_dashboard/index.html",
    "content": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n    <meta charset=\"UTF-8\">\n    <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n    <title>SGLang Performance Dashboard</title>\n    <script src=\"https://cdn.jsdelivr.net/npm/chart.js\"></script>\n    <script src=\"https://cdn.jsdelivr.net/npm/chartjs-adapter-date-fns\"></script>\n    <link rel=\"preconnect\" href=\"https://fonts.googleapis.com\">\n    <link rel=\"preconnect\" href=\"https://fonts.gstatic.com\" crossorigin>\n    <link href=\"https://fonts.googleapis.com/css2?family=DM+Sans:ital,opsz,wght@0,9..40,300;0,9..40,400;0,9..40,500;0,9..40,600;0,9..40,700;1,9..40,400&family=JetBrains+Mono:wght@400;500;600;700&display=swap\" rel=\"stylesheet\">\n    <style>\n        :root {\n            --bg-primary: #0a0e17;\n            --bg-secondary: #111827;\n            --bg-tertiary: #1a2332;\n            --bg-elevated: #1e293b;\n            --text-primary: #e2e8f0;\n            --text-secondary: #94a3b8;\n            --text-muted: #64748b;\n            --border-color: #1e293b;\n            --border-subtle: rgba(148, 163, 184, 0.08);\n            --accent-cyan: #22d3ee;\n            --accent-cyan-dim: rgba(34, 211, 238, 0.15);\n            --accent-green: #34d399;\n            --accent-green-dim: rgba(52, 211, 153, 0.15);\n            --accent-amber: #fbbf24;\n            --accent-amber-dim: rgba(251, 191, 36, 0.15);\n            --accent-red: #f87171;\n            --accent-red-dim: rgba(248, 113, 113, 0.15);\n            --accent-violet: #a78bfa;\n            --accent-violet-dim: rgba(167, 139, 250, 0.15);\n            --glass-bg: rgba(17, 24, 39, 0.7);\n            --glass-border: rgba(148, 163, 184, 0.1);\n            --shadow-sm: 0 1px 2px rgba(0, 0, 0, 0.3);\n            --shadow-md: 0 4px 16px rgba(0, 0, 0, 0.3);\n            --shadow-lg: 0 12px 40px rgba(0, 0, 0, 0.4);\n            --radius-sm: 6px;\n            --radius-md: 10px;\n            --radius-lg: 14px;\n        
    --radius-xl: 20px;\n        }\n\n        * {\n            margin: 0;\n            padding: 0;\n            box-sizing: border-box;\n        }\n\n        body {\n            font-family: 'DM Sans', -apple-system, BlinkMacSystemFont, sans-serif;\n            background-color: var(--bg-primary);\n            color: var(--text-primary);\n            line-height: 1.6;\n            min-height: 100vh;\n            overflow-x: hidden;\n        }\n\n        /* Subtle grid background */\n        body::before {\n            content: '';\n            position: fixed;\n            inset: 0;\n            background-image:\n                linear-gradient(rgba(148, 163, 184, 0.03) 1px, transparent 1px),\n                linear-gradient(90deg, rgba(148, 163, 184, 0.03) 1px, transparent 1px);\n            background-size: 60px 60px;\n            pointer-events: none;\n            z-index: 0;\n        }\n\n        /* Ambient glow */\n        body::after {\n            content: '';\n            position: fixed;\n            top: -40%;\n            left: -20%;\n            width: 80%;\n            height: 80%;\n            background: radial-gradient(ellipse, rgba(34, 211, 238, 0.04) 0%, transparent 70%);\n            pointer-events: none;\n            z-index: 0;\n        }\n\n        .container {\n            max-width: 1480px;\n            margin: 0 auto;\n            padding: 24px 32px;\n            position: relative;\n            z-index: 1;\n        }\n\n        /* ---- Header ---- */\n        header {\n            display: flex;\n            justify-content: space-between;\n            align-items: center;\n            padding: 20px 0 24px;\n            margin-bottom: 28px;\n            border-bottom: 1px solid var(--border-subtle);\n        }\n\n        h1 {\n            font-size: 22px;\n            font-weight: 600;\n            display: flex;\n            align-items: center;\n            gap: 14px;\n            letter-spacing: -0.02em;\n            color: 
var(--text-primary);\n        }\n\n        .logo-mark {\n            width: 36px;\n            height: 36px;\n            border-radius: var(--radius-md);\n            background: linear-gradient(135deg, var(--accent-cyan-dim), rgba(167, 139, 250, 0.12));\n            border: 1px solid rgba(34, 211, 238, 0.2);\n            display: flex;\n            align-items: center;\n            justify-content: center;\n            flex-shrink: 0;\n        }\n\n        .logo-mark svg {\n            width: 20px;\n            height: 20px;\n            color: var(--accent-cyan);\n        }\n\n        h1 span.title-accent {\n            color: var(--accent-cyan);\n        }\n\n        .header-actions {\n            display: flex;\n            gap: 10px;\n            align-items: center;\n        }\n\n        .btn {\n            padding: 8px 18px;\n            border-radius: var(--radius-sm);\n            border: 1px solid var(--border-color);\n            background: var(--bg-secondary);\n            color: var(--text-secondary);\n            cursor: pointer;\n            font-size: 13px;\n            font-family: 'DM Sans', sans-serif;\n            font-weight: 500;\n            transition: all 0.2s ease;\n            text-decoration: none;\n            display: inline-flex;\n            align-items: center;\n            gap: 6px;\n        }\n\n        .btn:hover {\n            background: var(--bg-tertiary);\n            color: var(--text-primary);\n            border-color: var(--glass-border);\n        }\n\n        .btn svg {\n            width: 14px;\n            height: 14px;\n        }\n\n        .btn-primary {\n            background: linear-gradient(135deg, rgba(34, 211, 238, 0.15), rgba(34, 211, 238, 0.08));\n            border-color: rgba(34, 211, 238, 0.25);\n            color: var(--accent-cyan);\n        }\n\n        .btn-primary:hover {\n            background: linear-gradient(135deg, rgba(34, 211, 238, 0.25), rgba(34, 211, 238, 0.12));\n            border-color: 
rgba(34, 211, 238, 0.4);\n        }\n\n        /* ---- Stats Row ---- */\n        .stats-row {\n            display: grid;\n            grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));\n            gap: 16px;\n            margin-bottom: 28px;\n        }\n\n        .stat-card {\n            background: var(--glass-bg);\n            backdrop-filter: blur(12px);\n            -webkit-backdrop-filter: blur(12px);\n            border-radius: var(--radius-lg);\n            border: 1px solid var(--glass-border);\n            padding: 20px 22px;\n            position: relative;\n            overflow: hidden;\n            transition: transform 0.2s ease, box-shadow 0.2s ease;\n        }\n\n        .stat-card:hover {\n            transform: translateY(-2px);\n            box-shadow: var(--shadow-md);\n        }\n\n        .stat-card::before {\n            content: '';\n            position: absolute;\n            top: 0;\n            left: 0;\n            right: 0;\n            height: 2px;\n            border-radius: var(--radius-lg) var(--radius-lg) 0 0;\n        }\n\n        .stat-card:nth-child(1)::before { background: linear-gradient(90deg, var(--accent-cyan), transparent); }\n        .stat-card:nth-child(2)::before { background: linear-gradient(90deg, var(--accent-violet), transparent); }\n        .stat-card:nth-child(3)::before { background: linear-gradient(90deg, var(--accent-green), transparent); }\n\n        .stat-card .label {\n            font-size: 11px;\n            color: var(--text-muted);\n            text-transform: uppercase;\n            letter-spacing: 0.08em;\n            font-weight: 500;\n            margin-bottom: 8px;\n        }\n\n        .stat-card .value {\n            font-family: 'JetBrains Mono', monospace;\n            font-size: 28px;\n            font-weight: 700;\n            color: var(--text-primary);\n            letter-spacing: -0.02em;\n        }\n\n        .stat-card .change {\n            font-size: 12px;\n            
margin-top: 6px;\n            font-weight: 500;\n        }\n\n        .stat-card .change.positive { color: var(--accent-green); }\n        .stat-card .change.negative { color: var(--accent-red); }\n\n        /* ---- Filters ---- */\n        .filters {\n            display: flex;\n            gap: 14px;\n            flex-wrap: wrap;\n            margin-bottom: 28px;\n            padding: 18px 22px;\n            background: var(--glass-bg);\n            backdrop-filter: blur(12px);\n            -webkit-backdrop-filter: blur(12px);\n            border-radius: var(--radius-lg);\n            border: 1px solid var(--glass-border);\n            align-items: flex-end;\n        }\n\n        .filter-group {\n            display: flex;\n            flex-direction: column;\n            gap: 6px;\n            flex: 1;\n            min-width: 160px;\n        }\n\n        .filter-group label {\n            font-size: 10px;\n            color: var(--text-muted);\n            font-weight: 600;\n            text-transform: uppercase;\n            letter-spacing: 0.1em;\n        }\n\n        select {\n            padding: 9px 32px 9px 14px;\n            border-radius: var(--radius-sm);\n            border: 1px solid var(--border-color);\n            background: var(--bg-tertiary);\n            color: var(--text-primary);\n            font-size: 13px;\n            font-family: 'DM Sans', sans-serif;\n            font-weight: 500;\n            cursor: pointer;\n            transition: all 0.15s ease;\n            appearance: none;\n            -webkit-appearance: none;\n            background-image: url(\"data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 24 24' fill='none' stroke='%2394a3b8' stroke-width='2' stroke-linecap='round' stroke-linejoin='round'%3E%3Cpath d='M6 9l6 6 6-6'/%3E%3C/svg%3E\");\n            background-repeat: no-repeat;\n            background-position: right 10px center;\n            width: 100%;\n        }\n\n        
select:hover {\n            border-color: rgba(148, 163, 184, 0.2);\n        }\n\n        select:focus {\n            outline: none;\n            border-color: rgba(34, 211, 238, 0.4);\n            box-shadow: 0 0 0 3px rgba(34, 211, 238, 0.08);\n        }\n\n        /* ---- Metric Tabs ---- */\n        .tabs {\n            display: flex;\n            gap: 2px;\n            margin-bottom: 24px;\n            padding: 4px;\n            background: var(--bg-secondary);\n            border-radius: var(--radius-md);\n            border: 1px solid var(--border-subtle);\n            width: fit-content;\n        }\n\n        .tab {\n            padding: 9px 18px;\n            cursor: pointer;\n            border-radius: var(--radius-sm);\n            background: transparent;\n            color: var(--text-muted);\n            border: none;\n            transition: all 0.2s ease;\n            font-weight: 500;\n            font-size: 13px;\n            font-family: 'DM Sans', sans-serif;\n            white-space: nowrap;\n        }\n\n        .tab:hover {\n            color: var(--text-secondary);\n            background: rgba(148, 163, 184, 0.05);\n        }\n\n        .tab.active {\n            background: var(--bg-tertiary);\n            color: var(--accent-cyan);\n            box-shadow: var(--shadow-sm);\n        }\n\n        /* ---- Chart Cards ---- */\n        .chart-card {\n            background: var(--glass-bg);\n            backdrop-filter: blur(12px);\n            -webkit-backdrop-filter: blur(12px);\n            border-radius: var(--radius-lg);\n            border: 1px solid var(--glass-border);\n            padding: 24px;\n        }\n\n        .chart-card h3 {\n            font-size: 15px;\n            font-weight: 600;\n            margin-bottom: 20px;\n            color: var(--text-primary);\n            display: flex;\n            align-items: center;\n            gap: 10px;\n            letter-spacing: -0.01em;\n        }\n\n        .chart-card h3::before 
{\n            content: '';\n            width: 3px;\n            height: 18px;\n            background: var(--accent-cyan);\n            border-radius: 2px;\n            flex-shrink: 0;\n        }\n\n        .chart-container {\n            position: relative;\n            height: 320px;\n        }\n\n        .metric-section {\n            margin-bottom: 24px;\n        }\n\n        .batch-charts-container {\n            display: grid;\n            grid-template-columns: repeat(auto-fit, minmax(420px, 1fr));\n            gap: 18px;\n        }\n\n        .batch-chart-wrapper {\n            background: var(--bg-tertiary);\n            border-radius: var(--radius-md);\n            padding: 16px;\n            border: 1px solid var(--border-subtle);\n        }\n\n        .batch-chart-title {\n            font-family: 'JetBrains Mono', monospace;\n            font-size: 12px;\n            font-weight: 500;\n            color: var(--text-muted);\n            margin-bottom: 10px;\n            text-align: center;\n            text-transform: uppercase;\n            letter-spacing: 0.06em;\n        }\n\n        .charts-grid {\n            display: grid;\n            grid-template-columns: repeat(auto-fit, minmax(600px, 1fr));\n            gap: 24px;\n        }\n\n        /* ---- Data Table ---- */\n        .data-table {\n            width: 100%;\n            border-collapse: separate;\n            border-spacing: 0;\n            margin-top: 20px;\n        }\n\n        .data-table th {\n            padding: 10px 16px;\n            text-align: left;\n            font-size: 10px;\n            font-weight: 600;\n            color: var(--text-muted);\n            text-transform: uppercase;\n            letter-spacing: 0.08em;\n            border-bottom: 1px solid var(--border-color);\n            background: transparent;\n        }\n\n        .data-table td {\n            padding: 12px 16px;\n            text-align: left;\n            border-bottom: 1px solid 
var(--border-subtle);\n            font-size: 13px;\n            color: var(--text-secondary);\n        }\n\n        .data-table tbody tr {\n            transition: background 0.15s ease;\n        }\n\n        .data-table tbody tr:hover {\n            background: rgba(148, 163, 184, 0.04);\n        }\n\n        .data-table td code {\n            font-family: 'JetBrains Mono', monospace;\n            font-size: 12px;\n            color: var(--accent-cyan);\n            background: var(--accent-cyan-dim);\n            padding: 2px 8px;\n            border-radius: 4px;\n        }\n\n        .run-link {\n            font-family: 'JetBrains Mono', monospace;\n            font-size: 12px;\n            color: var(--accent-cyan);\n            text-decoration: none;\n            transition: color 0.15s;\n        }\n\n        .run-link:hover {\n            color: #67e8f9;\n            text-decoration: underline;\n        }\n\n        .model-badge {\n            display: inline-block;\n            padding: 3px 10px;\n            border-radius: 20px;\n            font-size: 11px;\n            font-weight: 500;\n            background: var(--accent-violet-dim);\n            color: var(--accent-violet);\n            border: 1px solid rgba(167, 139, 250, 0.15);\n        }\n\n        /* ---- Loading ---- */\n        .loading {\n            display: flex;\n            flex-direction: column;\n            justify-content: center;\n            align-items: center;\n            min-height: 400px;\n            gap: 20px;\n            color: var(--text-muted);\n        }\n\n        .spinner {\n            width: 36px;\n            height: 36px;\n            border: 2px solid var(--border-color);\n            border-top-color: var(--accent-cyan);\n            border-radius: 50%;\n            animation: spin 0.8s linear infinite;\n        }\n\n        @keyframes spin {\n            to { transform: rotate(360deg); }\n        }\n\n        .loading-text {\n            font-size: 13px;\n      
      font-weight: 500;\n            color: var(--text-muted);\n        }\n\n        /* ---- Error ---- */\n        .error {\n            background: var(--accent-red-dim);\n            border: 1px solid rgba(248, 113, 113, 0.2);\n            border-radius: var(--radius-lg);\n            padding: 28px;\n            text-align: center;\n            color: var(--accent-red);\n        }\n\n        .error h3 {\n            margin-bottom: 8px;\n            font-size: 16px;\n        }\n\n        .error p {\n            font-size: 13px;\n            color: rgba(248, 113, 113, 0.8);\n        }\n\n        .no-data {\n            text-align: center;\n            padding: 60px 20px;\n            color: var(--text-muted);\n            font-size: 14px;\n        }\n\n        .no-data h3 {\n            margin-bottom: 8px;\n        }\n\n        /* ---- Footer ---- */\n        footer {\n            margin-top: 48px;\n            padding: 28px 0;\n            border-top: 1px solid var(--border-subtle);\n            text-align: center;\n            color: var(--text-muted);\n            font-size: 13px;\n        }\n\n        footer a {\n            color: var(--text-secondary);\n            text-decoration: none;\n            transition: color 0.15s;\n        }\n\n        footer a:hover {\n            color: var(--accent-cyan);\n        }\n\n        /* ---- Login Overlay ---- */\n        .login-overlay {\n            position: fixed;\n            inset: 0;\n            background-color: var(--bg-primary);\n            display: flex;\n            justify-content: center;\n            align-items: center;\n            z-index: 1000;\n            overflow: hidden;\n        }\n\n        .login-overlay::before {\n            content: '';\n            position: absolute;\n            inset: 0;\n            background-image:\n                linear-gradient(rgba(148, 163, 184, 0.03) 1px, transparent 1px),\n                linear-gradient(90deg, rgba(148, 163, 184, 0.03) 1px, transparent 
1px);\n            background-size: 60px 60px;\n            pointer-events: none;\n        }\n\n        .login-overlay::after {\n            content: '';\n            position: absolute;\n            top: 50%;\n            left: 50%;\n            transform: translate(-50%, -50%);\n            width: 600px;\n            height: 600px;\n            background: radial-gradient(ellipse, rgba(34, 211, 238, 0.06) 0%, transparent 70%);\n            pointer-events: none;\n        }\n\n        .login-card {\n            background: var(--glass-bg);\n            backdrop-filter: blur(20px);\n            -webkit-backdrop-filter: blur(20px);\n            border: 1px solid var(--glass-border);\n            border-radius: var(--radius-xl);\n            padding: 44px 40px;\n            width: 100%;\n            max-width: 400px;\n            box-shadow: var(--shadow-lg);\n            position: relative;\n            z-index: 1;\n            animation: loginSlideUp 0.5s ease-out;\n        }\n\n        @keyframes loginSlideUp {\n            from {\n                opacity: 0;\n                transform: translateY(20px);\n            }\n            to {\n                opacity: 1;\n                transform: translateY(0);\n            }\n        }\n\n        .login-icon {\n            text-align: center;\n            margin-bottom: 20px;\n        }\n\n        .login-icon-wrapper {\n            width: 56px;\n            height: 56px;\n            margin: 0 auto;\n            border-radius: var(--radius-lg);\n            background: linear-gradient(135deg, var(--accent-cyan-dim), rgba(167, 139, 250, 0.12));\n            border: 1px solid rgba(34, 211, 238, 0.2);\n            display: flex;\n            align-items: center;\n            justify-content: center;\n        }\n\n        .login-icon-wrapper svg {\n            width: 24px;\n            height: 24px;\n            color: var(--accent-cyan);\n        }\n\n        .login-card h2 {\n            font-size: 20px;\n            
font-weight: 600;\n            margin-bottom: 6px;\n            text-align: center;\n            letter-spacing: -0.02em;\n        }\n\n        .login-card .login-subtitle {\n            font-size: 13px;\n            color: var(--text-muted);\n            text-align: center;\n            margin-bottom: 28px;\n        }\n\n        .login-card .form-group {\n            margin-bottom: 18px;\n        }\n\n        .login-card .form-group label {\n            display: block;\n            font-size: 12px;\n            color: var(--text-muted);\n            margin-bottom: 7px;\n            font-weight: 500;\n        }\n\n        .login-card .form-group input {\n            width: 100%;\n            padding: 11px 14px;\n            border-radius: var(--radius-sm);\n            border: 1px solid var(--border-color);\n            background: var(--bg-tertiary);\n            color: var(--text-primary);\n            font-size: 14px;\n            font-family: 'DM Sans', sans-serif;\n            outline: none;\n            transition: all 0.2s ease;\n        }\n\n        .login-card .form-group input:focus {\n            border-color: rgba(34, 211, 238, 0.4);\n            box-shadow: 0 0 0 3px rgba(34, 211, 238, 0.08);\n        }\n\n        .login-card .form-group input::placeholder {\n            color: var(--text-muted);\n        }\n\n        .login-card .login-btn {\n            width: 100%;\n            padding: 11px 16px;\n            border-radius: var(--radius-sm);\n            border: 1px solid rgba(34, 211, 238, 0.3);\n            background: linear-gradient(135deg, rgba(34, 211, 238, 0.15), rgba(34, 211, 238, 0.08));\n            color: var(--accent-cyan);\n            font-size: 14px;\n            font-family: 'DM Sans', sans-serif;\n            font-weight: 600;\n            cursor: pointer;\n            transition: all 0.2s ease;\n            margin-top: 6px;\n        }\n\n        .login-card .login-btn:hover {\n            background: linear-gradient(135deg, 
rgba(34, 211, 238, 0.25), rgba(34, 211, 238, 0.12));\n            border-color: rgba(34, 211, 238, 0.5);\n        }\n\n        .login-card .login-btn:disabled {\n            opacity: 0.5;\n            cursor: not-allowed;\n        }\n\n        .login-error {\n            color: var(--accent-red);\n            font-size: 13px;\n            text-align: center;\n            margin-top: 14px;\n            min-height: 20px;\n        }\n\n        /* ---- Entrance Animations ---- */\n        @keyframes fadeInUp {\n            from {\n                opacity: 0;\n                transform: translateY(12px);\n            }\n            to {\n                opacity: 1;\n                transform: translateY(0);\n            }\n        }\n\n        .animate-in {\n            animation: fadeInUp 0.4s ease-out both;\n        }\n\n        .animate-delay-1 { animation-delay: 0.05s; }\n        .animate-delay-2 { animation-delay: 0.1s; }\n        .animate-delay-3 { animation-delay: 0.15s; }\n        .animate-delay-4 { animation-delay: 0.2s; }\n        .animate-delay-5 { animation-delay: 0.25s; }\n        .animate-delay-6 { animation-delay: 0.3s; }\n\n        /* ---- Responsive ---- */\n        @media (max-width: 768px) {\n            .container {\n                padding: 16px;\n            }\n\n            header {\n                flex-direction: column;\n                gap: 16px;\n                align-items: flex-start;\n            }\n\n            .filters {\n                padding: 14px;\n            }\n\n            .filter-group {\n                min-width: 140px;\n            }\n\n            .tabs {\n                overflow-x: auto;\n                -webkit-overflow-scrolling: touch;\n            }\n\n            .batch-charts-container {\n                grid-template-columns: 1fr;\n            }\n\n            .login-card {\n                margin: 16px;\n                padding: 32px 24px;\n            }\n\n            .stat-card .value {\n                
font-size: 22px;\n            }\n        }\n\n        /* ---- Scrollbar ---- */\n        ::-webkit-scrollbar {\n            width: 6px;\n            height: 6px;\n        }\n\n        ::-webkit-scrollbar-track {\n            background: transparent;\n        }\n\n        ::-webkit-scrollbar-thumb {\n            background: var(--border-color);\n            border-radius: 3px;\n        }\n\n        ::-webkit-scrollbar-thumb:hover {\n            background: var(--text-muted);\n        }\n    </style>\n</head>\n<body>\n    <!-- Login overlay -->\n    <div id=\"login-overlay\" class=\"login-overlay\">\n        <div class=\"login-card\">\n            <div class=\"login-icon\">\n                <div class=\"login-icon-wrapper\">\n                    <svg viewBox=\"0 0 24 24\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\">\n                        <rect x=\"3\" y=\"11\" width=\"18\" height=\"11\" rx=\"2\" ry=\"2\" stroke=\"currentColor\" stroke-width=\"2\"/>\n                        <path d=\"M7 11V7a5 5 0 0 1 10 0v4\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\"/>\n                        <circle cx=\"12\" cy=\"16\" r=\"1.5\" fill=\"currentColor\"/>\n                    </svg>\n                </div>\n            </div>\n            <h2>SGLang Performance Dashboard</h2>\n            <p class=\"login-subtitle\">Enter your credentials to access the dashboard</p>\n            <form id=\"login-form\" onsubmit=\"return handleLogin(event)\">\n                <div class=\"form-group\">\n                    <label for=\"login-username\">Username</label>\n                    <input type=\"text\" id=\"login-username\" name=\"username\" autocomplete=\"username\" placeholder=\"Enter username\" required autofocus>\n                </div>\n                <div class=\"form-group\">\n                    <label for=\"login-password\">Password</label>\n                    <input type=\"password\" id=\"login-password\" name=\"password\" 
autocomplete=\"current-password\" placeholder=\"Enter password\" required>\n                </div>\n                <button type=\"submit\" class=\"login-btn\" id=\"login-btn\">Sign In</button>\n            </form>\n            <div id=\"login-error\" class=\"login-error\"></div>\n        </div>\n    </div>\n\n    <div class=\"container\" id=\"dashboard-container\" style=\"display: none;\">\n        <header class=\"animate-in\">\n            <h1>\n                <div class=\"logo-mark\">\n                    <svg viewBox=\"0 0 24 24\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\">\n                        <path d=\"M12 2L2 7L12 12L22 7L12 2Z\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\"/>\n                        <path d=\"M2 17L12 22L22 17\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\"/>\n                        <path d=\"M2 12L12 17L22 12\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\"/>\n                    </svg>\n                </div>\n                <span><span class=\"title-accent\">SGLang</span> Performance Dashboard</span>\n            </h1>\n            <div class=\"header-actions\">\n                <button class=\"btn\" onclick=\"refreshData()\">\n                    <svg viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\"><polyline points=\"23 4 23 10 17 10\"/><path d=\"M20.49 15a9 9 0 1 1-2.12-9.36L23 10\"/></svg>\n                    Refresh\n                </button>\n                <a href=\"https://github.com/sgl-project/sglang/actions/workflows/nightly-test-nvidia.yml?query=event%3Aschedule\" target=\"_blank\" class=\"btn\">\n                    <svg viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\"><path d=\"M18 13v6a2 2 0 0 1-2 2H5a2 2 
0 0 1-2-2V8a2 2 0 0 1 2-2h6\"/><polyline points=\"15 3 21 3 21 9\"/><line x1=\"10\" y1=\"14\" x2=\"21\" y2=\"3\"/></svg>\n                    Workflow\n                </a>\n            </div>\n        </header>\n\n        <div id=\"loading\" class=\"loading\">\n            <div class=\"spinner\"></div>\n            <div class=\"loading-text\">Loading performance data...</div>\n        </div>\n\n        <div id=\"content\" style=\"display: none;\">\n            <div class=\"stats-row animate-in animate-delay-1\" id=\"stats-row\"></div>\n\n            <div class=\"filters animate-in animate-delay-2\">\n                <div class=\"filter-group\">\n                    <label>GPU Configuration</label>\n                    <select id=\"gpu-filter\" onchange=\"handleGpuFilterChange()\">\n                    </select>\n                </div>\n                <div class=\"filter-group\">\n                    <label>Model</label>\n                    <select id=\"model-filter\" onchange=\"handleModelFilterChange(this.value)\">\n                    </select>\n                </div>\n                <div class=\"filter-group\">\n                    <label>Variant</label>\n                    <select id=\"variant-filter\" onchange=\"updateCharts()\">\n                        <option value=\"all\">All Variants</option>\n                    </select>\n                </div>\n                <div class=\"filter-group\">\n                    <label>Input / Output Length</label>\n                    <select id=\"io-len-filter\" onchange=\"updateCharts()\">\n                        <option value=\"all\">All Lengths</option>\n                    </select>\n                </div>\n                <div class=\"filter-group\">\n                    <label>Batch Size</label>\n                    <select id=\"batch-filter\" onchange=\"updateCharts()\">\n                        <option value=\"all\">All Batch Sizes</option>\n                    </select>\n                </div>\n           
 </div>\n\n            <div class=\"tabs animate-in animate-delay-3\" id=\"metric-tabs\"></div>\n\n            <div class=\"metric-section animate-in animate-delay-4\">\n                <div class=\"chart-card\">\n                    <h3 id=\"metric-title\">Overall Throughput (tokens/sec)</h3>\n                    <div class=\"batch-charts-container\" id=\"charts-container\">\n                    </div>\n                </div>\n            </div>\n\n            <div class=\"chart-card animate-in animate-delay-5\" style=\"margin-top: 24px;\">\n                <h3>Recent Benchmark Runs</h3>\n                <table class=\"data-table\" id=\"runs-table\">\n                    <thead>\n                        <tr>\n                            <th>Date</th>\n                            <th>Run ID</th>\n                            <th>Commit</th>\n                            <th>Branch</th>\n                            <th>Models Tested</th>\n                        </tr>\n                    </thead>\n                    <tbody id=\"runs-table-body\">\n                    </tbody>\n                </table>\n            </div>\n        </div>\n\n        <div id=\"error\" class=\"error\" style=\"display: none;\">\n            <h3>Failed to load performance data</h3>\n            <p id=\"error-message\"></p>\n        </div>\n\n        <footer class=\"animate-in animate-delay-6\">\n            <p>\n                SGLang Performance Dashboard &mdash;\n                <a href=\"https://github.com/sgl-project/sglang\" target=\"_blank\">GitHub</a> &middot;\n                <a href=\"https://sgl-project.github.io/sglang\" target=\"_blank\">Documentation</a>\n            </p>\n        </footer>\n    </div>\n\n    <script src=\"app.js\"></script>\n</body>\n</html>\n"
  },
  {
    "path": "docs/performance_dashboard/server.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nSimple development server for the SGLang Performance Dashboard.\n\nThis server:\n1. Serves the static HTML/JS files\n2. Provides an API endpoint to fetch metrics from GitHub\n3. Caches metrics data to reduce API calls\n\nUsage:\n    python server.py\n    python server.py --port 8080\n    python server.py --host 0.0.0.0  # Allow external access\n    python server.py --fetch-on-start\n    python server.py --username admin --password secret  # Enable authentication\n    DASHBOARD_USERNAME=admin DASHBOARD_PASSWORD=secret python server.py  # Via env vars\n    python server.py --refresh-interval 12  # Auto-refresh data every 12 hours\n\"\"\"\n\nimport argparse\nimport hashlib\nimport hmac\nimport http.server\nimport io\nimport json\nimport os\nimport secrets\nimport socketserver\nimport threading\nimport time\nimport zipfile\nfrom datetime import datetime, timedelta, timezone\nfrom pathlib import Path\nfrom urllib.parse import urlparse\n\nimport requests\n\nGITHUB_REPO = \"sgl-project/sglang\"\nWORKFLOW_NAME = \"nightly-test-nvidia.yml\"\nARTIFACT_PREFIX = \"consolidated-metrics-\"\n\n# Cache for metrics data with thread-safe lock\ncache_lock = threading.Lock()\nmetrics_cache = {\n    \"data\": [],\n    \"last_updated\": None,\n    \"updating\": False,\n}\n\nCACHE_TTL = 300  # 5 minutes\nREQUEST_TIMEOUT = 30  # seconds\n\n# Authentication configuration (set via CLI flags)\nauth_config = {\n    \"enabled\": False,\n    \"username\": None,\n    \"password_hash\": None,  # SHA-256 hash of the password\n    \"active_tokens\": {},  # token -> expiry timestamp\n}\nauth_lock = threading.Lock()\nAUTH_TOKEN_TTL = 3600  # 1 hour\n\n\ndef hash_password(password):\n    \"\"\"Hash a password using SHA-256 for constant-time comparison.\"\"\"\n    return hashlib.sha256(password.encode(\"utf-8\")).hexdigest()\n\n\ndef create_auth_token():\n    \"\"\"Create a new session token.\"\"\"\n    token = secrets.token_hex(32)\n    with auth_lock:\n 
       # Clean up expired tokens\n        now = time.time()\n        auth_config[\"active_tokens\"] = {\n            t: exp for t, exp in auth_config[\"active_tokens\"].items() if exp > now\n        }\n        auth_config[\"active_tokens\"][token] = now + AUTH_TOKEN_TTL\n    return token\n\n\ndef verify_auth_token(token):\n    \"\"\"Verify a session token is valid and not expired.\"\"\"\n    if not token:\n        return False\n    with auth_lock:\n        expiry = auth_config[\"active_tokens\"].get(token)\n        if expiry and expiry > time.time():\n            return True\n        # Remove expired token\n        auth_config[\"active_tokens\"].pop(token, None)\n        return False\n\n\ndef get_github_token():\n    \"\"\"Get GitHub token from environment or gh CLI.\"\"\"\n    token = os.environ.get(\"GITHUB_TOKEN\")\n    if token:\n        return token\n\n    try:\n        import subprocess\n\n        result = subprocess.run(\n            [\"gh\", \"auth\", \"token\"],\n            capture_output=True,\n            text=True,\n            check=True,\n        )\n        return result.stdout.strip()\n    except (subprocess.CalledProcessError, FileNotFoundError):\n        pass\n\n    return None\n\n\ndef fetch_metrics_from_github(days=30):\n    \"\"\"Fetch metrics from GitHub Actions artifacts.\"\"\"\n    token = get_github_token()\n    headers = {\"Accept\": \"application/vnd.github.v3+json\"}\n    if token:\n        headers[\"Authorization\"] = f\"Bearer {token}\"\n\n    # Get workflow runs - only scheduled (nightly) runs, not workflow_dispatch\n    url = f\"https://api.github.com/repos/{GITHUB_REPO}/actions/workflows/{WORKFLOW_NAME}/runs\"\n    params = {\"status\": \"completed\", \"per_page\": 50, \"event\": \"schedule\"}\n\n    try:\n        response = requests.get(\n            url, headers=headers, params=params, timeout=REQUEST_TIMEOUT\n        )\n        if not response.ok:\n            print(f\"Failed to fetch workflow runs: {response.status_code}\")\n    
        return []\n    except requests.exceptions.RequestException as e:\n        print(f\"Network error fetching workflow runs: {e}\")\n        return []\n\n    runs = response.json().get(\"workflow_runs\", [])\n\n    # Filter by date\n    cutoff = datetime.now(timezone.utc) - timedelta(days=days)\n    runs = [\n        run\n        for run in runs\n        if datetime.fromisoformat(run[\"created_at\"].replace(\"Z\", \"+00:00\")) > cutoff\n    ]\n\n    all_metrics = []\n\n    for run in runs[:20]:  # Limit to 20 most recent\n        run_id = run[\"id\"]\n\n        # Get artifacts\n        artifacts_url = f\"https://api.github.com/repos/{GITHUB_REPO}/actions/runs/{run_id}/artifacts\"\n        try:\n            artifacts_resp = requests.get(\n                artifacts_url, headers=headers, timeout=REQUEST_TIMEOUT\n            )\n            if not artifacts_resp.ok:\n                continue\n        except requests.exceptions.RequestException as e:\n            print(f\"Network error fetching artifacts for run {run_id}: {e}\")\n            continue\n\n        artifacts = artifacts_resp.json().get(\"artifacts\", [])\n\n        # Find consolidated metrics\n        for artifact in artifacts:\n            if artifact[\"name\"].startswith(ARTIFACT_PREFIX):\n                if not token:\n                    # Without token, we can't download - return metadata only\n                    all_metrics.append(\n                        {\n                            \"run_id\": str(run_id),\n                            \"run_date\": run[\"created_at\"],\n                            \"commit_sha\": run[\"head_sha\"],\n                            \"branch\": run[\"head_branch\"],\n                            \"results\": [],\n                        }\n                    )\n                    break\n\n                # Download artifact\n                download_url = f\"https://api.github.com/repos/{GITHUB_REPO}/actions/artifacts/{artifact['id']}/zip\"\n                try:\n 
                   download_resp = requests.get(\n                        download_url,\n                        headers=headers,\n                        allow_redirects=True,\n                        timeout=REQUEST_TIMEOUT,\n                    )\n                except requests.exceptions.RequestException as e:\n                    print(f\"Network error downloading artifact: {e}\")\n                    break\n\n                if download_resp.ok:\n                    try:\n                        with zipfile.ZipFile(io.BytesIO(download_resp.content)) as zf:\n                            json_files = [\n                                f for f in zf.namelist() if f.endswith(\".json\")\n                            ]\n                            if json_files:\n                                with zf.open(json_files[0]) as f:\n                                    metrics = json.load(f)\n                                    # Ensure required fields\n                                    metrics.setdefault(\"run_id\", str(run_id))\n                                    metrics.setdefault(\"run_date\", run[\"created_at\"])\n                                    metrics.setdefault(\"commit_sha\", run[\"head_sha\"])\n                                    metrics.setdefault(\"branch\", run[\"head_branch\"])\n                                    all_metrics.append(metrics)\n                    except (zipfile.BadZipFile, json.JSONDecodeError) as e:\n                        print(f\"Failed to process artifact: {e}\")\n                break\n\n    return all_metrics\n\n\ndef update_cache_async():\n    \"\"\"Update the metrics cache in background with thread safety.\"\"\"\n    with cache_lock:\n        if metrics_cache[\"updating\"]:\n            return\n        metrics_cache[\"updating\"] = True\n\n    try:\n        data = fetch_metrics_from_github()\n        with cache_lock:\n            metrics_cache[\"data\"] = data\n            metrics_cache[\"last_updated\"] = time.time()\n     
   print(f\"Cache updated with {len(data)} metrics records\")\n    finally:\n        with cache_lock:\n            metrics_cache[\"updating\"] = False\n\n\ndef start_periodic_refresh(interval_hours):\n    \"\"\"Start a background thread that refreshes the cache periodically.\"\"\"\n    interval_seconds = interval_hours * 3600\n\n    def refresh_loop():\n        while True:\n            time.sleep(interval_seconds)\n            print(f\"Periodic refresh triggered (every {interval_hours}h)\")\n            update_cache_async()\n\n    thread = threading.Thread(target=refresh_loop, daemon=True)\n    thread.start()\n    print(f\"Periodic refresh enabled: every {interval_hours} hours\")\n\n\nclass DashboardHandler(http.server.SimpleHTTPRequestHandler):\n    \"\"\"HTTP request handler for the dashboard.\"\"\"\n\n    def __init__(self, *args, directory=None, **kwargs):\n        super().__init__(*args, directory=directory, **kwargs)\n\n    def _send_json(self, data, status=200):\n        \"\"\"Send a JSON response.\"\"\"\n        self.send_response(status)\n        self.send_header(\"Content-Type\", \"application/json\")\n        self.send_header(\"Access-Control-Allow-Origin\", \"*\")\n        self.end_headers()\n        self.wfile.write(json.dumps(data).encode())\n\n    def _check_auth(self):\n        \"\"\"Check if request is authenticated. 
Returns True if OK, sends 401 and returns False otherwise.\"\"\"\n        if not auth_config[\"enabled\"]:\n            return True\n        auth_header = self.headers.get(\"Authorization\", \"\")\n        if auth_header.startswith(\"Bearer \"):\n            token = auth_header[7:]\n            if verify_auth_token(token):\n                return True\n        self._send_json({\"error\": \"Unauthorized\"}, status=401)\n        return False\n\n    def do_GET(self):\n        parsed = urlparse(self.path)\n\n        # Prevent directory traversal attacks\n        if \"..\" in parsed.path or parsed.path.startswith(\"//\"):\n            self.send_error(400, \"Invalid path\")\n            return\n\n        if parsed.path == \"/api/auth-check\":\n            self.handle_auth_check()\n        elif parsed.path == \"/api/metrics\":\n            if self._check_auth():\n                self.handle_metrics_api(parsed)\n        elif parsed.path == \"/api/refresh\":\n            if self._check_auth():\n                self.handle_refresh_api()\n        else:\n            super().do_GET()\n\n    def do_POST(self):\n        parsed = urlparse(self.path)\n\n        if parsed.path == \"/api/login\":\n            self.handle_login()\n        else:\n            self.send_error(404, \"Not Found\")\n\n    def handle_auth_check(self):\n        \"\"\"Tell the frontend whether authentication is required.\"\"\"\n        self._send_json({\"auth_required\": auth_config[\"enabled\"]})\n\n    def handle_login(self):\n        \"\"\"Validate username/password and return a session token.\"\"\"\n        content_length = int(self.headers.get(\"Content-Length\", 0))\n        if content_length == 0 or content_length > 4096:\n            self._send_json({\"error\": \"Invalid request\"}, status=400)\n            return\n\n        try:\n            body = json.loads(self.rfile.read(content_length))\n        except (json.JSONDecodeError, ValueError):\n            self._send_json({\"error\": \"Invalid JSON\"}, 
status=400)\n            return\n\n        if not auth_config[\"enabled\"] or not isinstance(body, dict):\n            self._send_json({\"error\": \"Invalid request\"}, status=400)\n            return\n\n        username = body.get(\"username\", \"\")\n        password = body.get(\"password\", \"\")\n        if not isinstance(username, str) or not isinstance(password, str):\n            self._send_json({\"error\": \"Invalid request\"}, status=400)\n            return\n\n        # hmac.compare_digest rejects non-ASCII str, so compare UTF-8 bytes instead\n        if hmac.compare_digest(\n            username.encode(\"utf-8\"), auth_config[\"username\"].encode(\"utf-8\")\n        ) and hmac.compare_digest(\n            hash_password(password), auth_config[\"password_hash\"]\n        ):\n            token = create_auth_token()\n            self._send_json({\"token\": token})\n        else:\n            self._send_json({\"error\": \"Invalid username or password\"}, status=401)\n\n    def handle_metrics_api(self, parsed):\n        \"\"\"Handle /api/metrics endpoint.\"\"\"\n        # Check cache with thread safety\n        with cache_lock:\n            cache_valid = (\n                metrics_cache[\"last_updated\"]\n                and time.time() - metrics_cache[\"last_updated\"] < CACHE_TTL\n            )\n            data = metrics_cache[\"data\"].copy()\n\n        if not cache_valid:\n            # Trigger background update\n            threading.Thread(target=update_cache_async, daemon=True).start()\n\n        self._send_json(data)\n\n    def handle_refresh_api(self):\n        \"\"\"Handle /api/refresh endpoint.\"\"\"\n        threading.Thread(target=update_cache_async, daemon=True).start()\n        self._send_json({\"status\": \"refreshing\"})\n\n    def log_message(self, format, *args):\n        \"\"\"Custom log format.\"\"\"\n        print(f\"[{self.log_date_time_string()}] {args[0]}\")\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"SGLang Performance Dashboard Server\")\n    parser.add_argument(\"--port\", type=int, default=8000, help=\"Port to serve on\")\n    parser.add_argument(\n        \"--host\",\n        default=\"127.0.0.1\",\n        help=\"Host to bind to (use 0.0.0.0 for external access)\",\n    )\n    parser.add_argument(\n        \"--fetch-on-start\", action=\"store_true\", help=\"Fetch metrics on startup\"\n    )\n    parser.add_argument(\n        \"--refresh-interval\",\n       
 type=float,\n        default=12,\n        help=\"Auto-refresh interval in hours (default: 12, set to 0 to disable)\",\n    )\n    parser.add_argument(\n        \"--username\",\n        default=os.environ.get(\"DASHBOARD_USERNAME\"),\n        help=\"Username for dashboard authentication (or set DASHBOARD_USERNAME env var)\",\n    )\n    parser.add_argument(\n        \"--password\",\n        default=os.environ.get(\"DASHBOARD_PASSWORD\"),\n        help=\"Password for dashboard authentication (or set DASHBOARD_PASSWORD env var)\",\n    )\n    args = parser.parse_args()\n\n    # Configure authentication if both username and password are provided\n    if args.username and args.password:\n        auth_config[\"enabled\"] = True\n        auth_config[\"username\"] = args.username\n        auth_config[\"password_hash\"] = hash_password(args.password)\n        print(f\"Authentication enabled for user: {args.username}\")\n    elif args.username or args.password:\n        parser.error(\"Both --username and --password must be provided together\")\n\n    # Change to dashboard directory\n    dashboard_dir = Path(__file__).parent\n    os.chdir(dashboard_dir)\n\n    if args.fetch_on_start:\n        print(\"Fetching initial metrics data...\")\n        update_cache_async()\n\n    if args.refresh_interval > 0:\n        start_periodic_refresh(args.refresh_interval)\n\n    handler = lambda *a, **kw: DashboardHandler(*a, directory=str(dashboard_dir), **kw)\n\n    with socketserver.TCPServer((args.host, args.port), handler) as httpd:\n        print(f\"Serving dashboard at http://{args.host}:{args.port}\")\n        print(\"Press Ctrl+C to stop\")\n        try:\n            httpd.serve_forever()\n        except KeyboardInterrupt:\n            print(\"\\nShutting down...\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "docs/platforms/amd_gpu.md",
    "content": "# AMD GPUs\n\nThis document describes how to run SGLang on AMD GPUs. If you encounter issues or have questions, please [open an issue](https://github.com/sgl-project/sglang/issues).\n\n## System Configuration\n\nWhen using AMD GPUs (such as MI300X), certain system-level optimizations help ensure stable performance. Here we take MI300X as an example. AMD provides official documentation for MI300X optimization and system tuning:\n\n- [AMD MI300X Tuning Guides](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/index.html)\n- [LLM inference performance validation on AMD Instinct MI300X](https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference/vllm-benchmark.html)\n- [AMD Instinct MI300X System Optimization](https://rocm.docs.amd.com/en/latest/how-to/system-optimization/mi300x.html)\n- [AMD Instinct MI300X Workload Optimization](https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html)\n- [Supercharge DeepSeek-R1 Inference on AMD Instinct MI300X](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1-Part2/README.html)\n\n**NOTE:** We strongly recommend reading these docs and guides entirely to fully utilize your system.\n\nBelow are a few key settings to confirm or enable for SGLang:\n\n### Update GRUB Settings\n\nIn `/etc/default/grub`, append the following to `GRUB_CMDLINE_LINUX`:\n\n```text\npci=realloc=off iommu=pt\n```\n\nAfterward, run `sudo update-grub` (or your distro’s equivalent) and reboot.\n\n### Disable NUMA Auto-Balancing\n\n```bash\nsudo sh -c 'echo 0 > /proc/sys/kernel/numa_balancing'\n```\n\nYou can automate or verify this change using [this helpful script](https://github.com/ROCm/triton/blob/rocm_env/scripts/amd/env_check.sh).\n\nAgain, please go through the entire documentation to confirm your system is using the recommended configuration.\n\n## Install SGLang\n\nYou can install SGLang using one of the methods below.\n\n### Install from Source\n\n```bash\n# Use the last 
release branch\ngit clone -b v0.5.9 https://github.com/sgl-project/sglang.git\ncd sglang\n\n# Compile sgl-kernel\npip install --upgrade pip\ncd sgl-kernel\npython setup_rocm.py install\n\n# Install sglang python package along with diffusion support\ncd ..\nrm -rf python/pyproject.toml && mv python/pyproject_other.toml python/pyproject.toml\npip install -e \"python[all_hip]\"\n```\n\n### Install Using Docker (Recommended)\n\nThe docker images are available on Docker Hub at [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [rocm.Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).\n\nThe steps below show how to build and use an image.\n\n1. Build the docker image.\n   If you use pre-built images, you can skip this step and replace `sglang_image` with the pre-built image names in the steps below.\n\n   ```bash\n   docker build -t sglang_image -f rocm.Dockerfile .\n   ```\n\n2. Create a convenient alias.\n\n   ```bash\n   alias drun='docker run -it --rm --network=host --privileged --device=/dev/kfd --device=/dev/dri \\\n       --ipc=host --shm-size 16G --group-add video --cap-add=SYS_PTRACE \\\n       --security-opt seccomp=unconfined \\\n       -v $HOME/dockerx:/dockerx \\\n       -v /data:/data'\n   ```\n\n   If you are using RDMA, please note that:\n     - `--network host` and `--privileged` are required by RDMA. If you don't need RDMA, you can remove them.\n     - You may need to set `NCCL_IB_GID_INDEX` if you are using RoCE, for example: `export NCCL_IB_GID_INDEX=3`.\n\n3. 
Launch the server.\n\n   **NOTE:** Replace `<secret>` below with your [Hugging Face Hub token](https://huggingface.co/docs/hub/en/security-tokens).\n\n   ```bash\n   drun -p 30000:30000 \\\n       -v ~/.cache/huggingface:/root/.cache/huggingface \\\n       --env \"HF_TOKEN=<secret>\" \\\n       sglang_image \\\n       python3 -m sglang.launch_server \\\n       --model-path NousResearch/Meta-Llama-3.1-8B \\\n       --host 0.0.0.0 \\\n       --port 30000\n   ```\n\n4. To verify the setup, you can run a benchmark in another terminal or refer to [other docs](https://docs.sglang.io/basic_usage/openai_api_completions.html) to send requests to the engine.\n\n   ```bash\n   drun sglang_image \\\n       python3 -m sglang.bench_serving \\\n       --backend sglang \\\n       --dataset-name random \\\n       --num-prompts 4000 \\\n       --random-input 128 \\\n       --random-output 128\n   ```\n\nWith your AMD system properly configured and SGLang installed, you can now fully leverage AMD hardware to power SGLang’s machine learning capabilities.\n\n## Quantization on AMD GPUs\n\nThe [Quantization documentation](../advanced_features/quantization.md#platform-compatibility) has a full compatibility matrix. The short version: FP8, AWQ, MXFP4, W8A8, GPTQ, compressed-tensors, Quark, and **petit_nvfp4** (NVFP4 on ROCm via [Petit](https://github.com/causalflow-ai/petit-kernel)) all work on AMD. Methods that depend on Marlin or NVIDIA-specific kernels (`awq_marlin`, `gptq_marlin`, `gguf`, `modelopt_fp8`, `modelopt_fp4`) do not.\n\nA few things to keep in mind:\n\n- FP8 works via Aiter or Triton. Pre-quantized FP8 models like DeepSeek-V3/R1 work out of the box.\n- AWQ uses Triton dequantization kernels on AMD.
The faster Marlin path is not available.\n- MXFP4 requires CDNA3/CDNA4 and `SGLANG_USE_AITER=1`.\n- `petit_nvfp4` enables NVFP4 models (e.g., [Llama 3.3 70B FP4](https://huggingface.co/nvidia/Llama-3.3-70B-Instruct-FP4)) on MI250/MI300X via [Petit](https://github.com/causalflow-ai/petit-kernel). Install with `pip install petit-kernel`; no `--quantization` flag needed when loading pre-quantized NVFP4 models.\n- `quark_int4fp8_moe` is an AMD-only online quantization method for MoE models on CDNA3/CDNA4.\n\nSeveral of these backends are accelerated by [Aiter](https://github.com/ROCm/aiter). Enable it with:\n\n```bash\nexport SGLANG_USE_AITER=1\n```\n\nExample -- serving an AWQ model:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4 \\\n    --trust-remote-code \\\n    --port 30000 --host 0.0.0.0\n```\n\nExample -- FP8 online quantization:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n    --quantization fp8 \\\n    --port 30000 --host 0.0.0.0\n```\n\n## Examples\n\n### Running DeepSeek-V3\n\nThe only difference when running DeepSeek-V3 is in how you start the server. Here's an example command (note the `--model-path` value):\n\n```bash\ndrun -p 30000:30000 \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --ipc=host \\\n    --env \"HF_TOKEN=<secret>\" \\\n    sglang_image \\\n    python3 -m sglang.launch_server \\\n    --model-path deepseek-ai/DeepSeek-V3 \\\n    --tp 8 \\\n    --trust-remote-code \\\n    --host 0.0.0.0 \\\n    --port 30000\n```\n\n[Running DeepSeek-R1 on a single NDv5 MI300X VM](https://techcommunity.microsoft.com/blog/azurehighperformancecomputingblog/running-deepseek-r1-on-a-single-ndv5-mi300x-vm/4372726) could also be a good reference.\n\n### Running Llama3.1\n\nRunning Llama3.1 is nearly identical to running DeepSeek-V3.
The only difference is the model passed to `--model-path` when starting the server, as in the following example command:\n\n```bash\ndrun -p 30000:30000 \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --ipc=host \\\n    --env \"HF_TOKEN=<secret>\" \\\n    sglang_image \\\n    python3 -m sglang.launch_server \\\n    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n    --tp 8 \\\n    --trust-remote-code \\\n    --host 0.0.0.0 \\\n    --port 30000\n```\n\n### Warmup Step\n\nWhen the server displays `The server is fired up and ready to roll!`, it means the startup is successful.\n"
  },
  {
    "path": "docs/platforms/apple_metal.md",
    "content": "# Apple Silicon with Metal\n\nThis document describes how to run SGLang on Apple Silicon using [Metal](https://developer.apple.com/metal/). If you encounter issues or have questions, please [open an issue](https://github.com/sgl-project/sglang/issues).\n\n## Install SGLang\n\nYou can install SGLang from source as described below.\n\n### Install from Source\n\n```bash\n# Use the default branch\ngit clone https://github.com/sgl-project/sglang.git\ncd sglang\n\n# Install sglang python package\npip install --upgrade pip\nrm -f python/pyproject.toml && mv python/pyproject_other.toml python/pyproject.toml\nuv pip install -e \"python[all_mps]\"\n```\n"
  },
  {
    "path": "docs/platforms/ascend_contribution_guide.md",
    "content": "# Contribution Guide\n\nWelcome to **SGLang**! We appreciate your interest in contributing. This guide provides a concise overview of how to set up your environment, run tests, build documentation, and open a Pull Request (PR). Whether you’re fixing a small bug or developing a major feature, we encourage following these steps for a smooth contribution process.\n\n## Install SGLang from Source\n\n### Prepare Environment\n\nBefore contributing, please ensure that your environment is set up correctly. Follow the steps in the [Installation Guide](../platforms/ascend_npu.md) to install the necessary dependencies. We recommend [using docker](../platforms/ascend_npu.md#method-2-using-docker-image) to build the environment.\n\n### Fork and clone the repository\n\n**Note**: New contributors do **not** have the write permission to push to the official SGLang repo. Please fork the repository under your GitHub account, then clone your fork locally.\n\n```bash\ngit clone https://github.com/<your_user_name>/sglang.git\n# if you are using docker, the environment is already set up.\ncd sglang\nexport PYTHONPATH=$PWD/python:$PYTHONPATH\n```\n\n## Format code with pre-commit\n\nWe use [pre-commit](https://pre-commit.com/) to maintain consistent code style checks. Before pushing your changes, please run:\n\n```bash\npip3 install pre-commit\npre-commit install\npre-commit run --all-files\n```\n\n- **`pre-commit run --all-files`** manually runs all configured checks, applying fixes if possible. If it fails the first time, re-run it to ensure lint errors are fully resolved. Make sure your code passes all checks **before** creating a Pull Request.\n- **Do not commit** directly to the `main` branch. 
Always create a new branch (e.g., `feature/my-new-feature`), push your changes, and open a PR from that branch.\n\n## Run and add unit tests\n\nIf you add a new feature or fix a bug, please add corresponding unit tests to ensure coverage and prevent regression.\nSGLang uses Python's built-in [unittest](https://docs.python.org/3/library/unittest.html) framework.\nFor detailed instructions on running tests and integrating them into CI, refer to [test/README.md](https://github.com/sgl-project/sglang/tree/main/test/README.md).\n\n## Write documentation\n\nWe recommend that new contributors start by writing documentation, which is a quick way to learn the SGLang codebase.\nFor more details, please refer to [docs/README.md](https://github.com/sgl-project/sglang/tree/main/docs/README.md).\n\n## Test the accuracy\n\nIf your code changes the model output, please run the accuracy tests. A quick sanity check is the few-shot GSM8K evaluation.\n\n```bash\n# Launch a server\npython3 -m sglang.launch_server --model Qwen/Qwen2-7B-Instruct\n\n# Evaluate\npython3 -m sglang.test.few_shot_gsm8k --num-questions 200\n```\n\nPlease note that the above script is primarily a sanity check, not a rigorous accuracy or speed test.\nThis test can have significant variance (1%–5%) in accuracy due to batching and the non-deterministic nature of the inference engine.\nAlso, do not rely on the \"Latency/Output throughput\" from this script, as it is not a proper speed test.\n\nGSM8K is too easy for state-of-the-art models nowadays.
Please also run more challenging accuracy tests of your own.\nYou can find additional accuracy eval examples in:\n- [test_eval_accuracy_large.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_eval_accuracy_large.py)\n- [test_moe_eval_accuracy_large.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_moe_eval_accuracy_large.py)\n\n## Benchmark the speed\nRefer to [Benchmark and Profiling](../developer_guide/benchmark_and_profiling.md).\n\n## Requesting a review for merge\nFollow the pull request merge process described in [MAINTAINER.md](https://github.com/sgl-project/sglang/blob/main/.github/MAINTAINER.md).\nYou will need approvals from the Merge Oncall, the Codeowner, and other reviewers before your PR can be merged.\n\n## How to Trigger CI Tests\n\nWe have many open PRs but limited CI machines, so only top and trusted contributors have permission to trigger CI tests.\nUsers with permission are listed in [CI_PERMISSIONS.json](https://github.com/sgl-project/sglang/blob/main/.github/CI_PERMISSIONS.json).\n\nFor CI to run on a pull request, the PR must have the \"run-ci\" label. Authorized users can add the label or rerun failed tests by commenting on the PR with one of these commands:\n\n- `/tag-run-ci-label`: Adds the \"run-ci\" label. Every future commit will trigger CI.\n- `/rerun-failed-ci`: Reruns the failed or flaky tests from the most recent commit.\n- `/tag-and-rerun-ci`: A single command that performs both `/tag-run-ci-label` and `/rerun-failed-ci`.\n- `/rerun-stage <stage-name>`: Reruns a specific test stage without waiting for its dependencies. This is useful when you want to quickly validate a fix for a specific test failure instead of waiting ~30 minutes for preceding stages to complete.\n\nIf you have permission, the [Slash Command Handler](https://github.com/sgl-project/sglang/actions/workflows/slash-command-handler.yml) will run your command and react with a 👍 to your comment.
It may take up to a few minutes for the reaction to appear. Here’s a usage [example](https://github.com/sgl-project/sglang/pull/14253#issuecomment-3599509302).\n\nTo avoid spamming a PR with too many `/rerun-failed-ci` comments, you can also trigger the command by editing an existing comment and adding any suffix (e.g., `/rerun-failed-ci try again`).\n\nExample of rerunning a single test stage: `/rerun-stage unit-test-backend-4-gpu`.\n\nIf you don’t have permission, please ask maintainers to trigger CI for you.\n\n### CI rate limits\n\nDue to CI scheduling and limited resources, higher-priority PRs may preempt running jobs. In such cases, you may need to rerun the tests.\n\nWe apply CI rate limits to prevent abuse and ensure fair usage of our CI resources.\n\nEach CI workflow has a default limit defined in its workflow configuration file. For example, in [pr-gate.yml](https://github.com/sgl-project/sglang/blob/main/.github/workflows/pr-gate.yml), the default cooldown period is 120 minutes, and each workflow can override it via the `cool-down-minutes` input parameter:\n\n```yaml\ncool-down-minutes:\n  description: \"Default cooldown period in minutes; 0 disables rate limiting\"\n  type: number\n  default: 120\n```\n\nUsers listed in [CI_PERMISSIONS.json](https://github.com/sgl-project/sglang/blob/main/.github/CI_PERMISSIONS.json) may have a per-user cooldown interval. In practice, we use the minimum of the workflow’s default window and the user-specific interval.\n\n\n## Code style guidance\n- Avoid code duplication. If the same code snippet (more than five lines) appears multiple times, extract it into a shared function.\n- Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible. Use vectorized code.\n- Prioritize extreme efficiency. SGLang is a runtime, and most of your code runs on the critical path for every request. 
Optimize all minor overheads as much as possible, especially in the model forward code.\n  - A common case is runtime checks in the model forward pass (e.g., [this](https://github.com/sgl-project/sglang/blob/f1b0eda55c2c4838e8ab90a0fac7fb1e3d7064ab/python/sglang/srt/models/deepseek_v2.py#L486-L491)). These checks usually produce the same result for every layer, so cache the result as a single boolean value whenever possible.\n- Make functions as pure as possible. Avoid in-place modification of arguments.\n- Keep files concise. If a file exceeds 2,000 lines of code, split it into multiple smaller files (e.g., `scheduler.py`, `scheduler_output_processor_mixin.py`).\n- Keep tests fast.\n  - If a single test file runs longer than 500 seconds, split it into multiple smaller files (e.g., `test_eagle_infer_a.py`, `test_eagle_infer_b.py`).\n  - If a single job in a GitHub workflow runs longer than 30 minutes, split it into smaller jobs/steps.\n  - Reuse server launches in your unit tests to make tests run faster.\n- When supporting new hardware or features, follow these guidelines:\n  - Do not drastically change existing code.\n  - Prefer adding new files for hardware-specific components (e.g., `allocator_ascend.py`).\n  - If you write multiple if/else blocks for new features, ensure the common path (e.g., NVIDIA hardware or the existing code path) is the first branch.\n\n## How to update sgl-kernel\nSince sglang and sgl-kernel are separate Python packages, our current GitHub CI infrastructure does not support updating a kernel and using it immediately within the same pull request (PR).\nTo add a new kernel or modify an existing one in the `sgl-kernel/` source tree, you must use multiple PRs.\n\nFollow these steps:\n\n1. Submit a PR that updates the sgl-kernel source code without using it in the sglang Python package (e.g., [#8884](https://github.com/sgl-project/sglang/pull/8884/files)).\n2.
Bump the version of the kernel package (e.g., [#9220](https://github.com/sgl-project/sglang/pull/9220/files)).\n   - Once merged, this will trigger an automatic release of the `sglang-kernel` wheel to PyPI.\n   - If the change is not urgent, you can wait for someone else to release the wheel. A new version is typically released within one week.\n3. Apply the changes:\n   - Update the `sglang-kernel` version in `sglang/python/pyproject.toml` to use the modified kernels.\n   - Update the related caller code in sglang to use the new kernel.\n\n## How to update sgl-kernel-npu\n\n`sgl-kernel-npu` is the kernel package for Ascend NPU and is maintained in the [sgl-kernel-npu](https://github.com/sgl-project/sgl-kernel-npu) repository. If you want to add a new kernel and use it in sglang, please follow the steps in its [Contribution Guide](https://github.com/sgl-project/sgl-kernel-npu/blob/main/docs/developer_guide/contribution_guide.md).\n\n## Tips for newcomers\n\nIf you want to contribute but don’t have a specific idea in mind, pick issues labeled [“good first issue” or “help wanted”](https://github.com/sgl-project/sglang/issues?q=is%3Aissue+label%3A%22good+first+issue%22%2C%22help+wanted%22). These tasks typically have lower complexity and provide an excellent introduction to the codebase. Also check out this [code walk-through](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/tree/main/sglang/code-walk-through) for a deeper look into SGLang’s workflow.\n\nIf you have any questions or want to start a discussion, please feel free to ask in our [Slack channel](https://slack.sglang.io).\n\nThank you for your interest in SGLang. Happy coding!\n"
  },
  {
    "path": "docs/platforms/ascend_npu.md",
    "content": "\n# SGLang installation with NPU support\n\nYou can install SGLang using any of the methods below. Please go through the `System Settings` section to make sure your clusters run at maximum performance. If you encounter any problems, feel free to open an issue [here at sglang](https://github.com/sgl-project/sglang/issues).\n\n## Component Version Mapping For SGLang\n| Component | Version | Obtain Way |\n|-----------|---------|------------|\n| HDK | 25.3.RC1 | [link](https://hiascend.com/hardware/firmware-drivers/commercial?product=7&model=33) |\n| CANN | 8.5.0 | [Obtain Images](#obtain-cann-image) |\n| PyTorch Adapter | 7.3.0 | [link](https://gitcode.com/Ascend/pytorch/releases) |\n| MemFabric | 1.0.5 | `pip install memfabric-hybrid==1.0.5` |\n| Triton | 3.2.0 | `pip install triton-ascend` |\n| SGLang NPU Kernel | NA
 | [link](https://github.com/sgl-project/sgl-kernel-npu/releases) |\n\n<a id=\"obtain-cann-image\"></a>\n### Obtain CANN Image\nYou can obtain a specific version of CANN and its dependencies through a prebuilt image.\n```shell\n# For Atlas 800I A3 and Ubuntu OS\ndocker pull quay.io/ascend/cann:8.5.0-a3-ubuntu22.04-py3.11\n# For Atlas 800I A2 and Ubuntu OS\ndocker pull quay.io/ascend/cann:8.5.0-910b-ubuntu22.04-py3.11\n```\n\n## Preparing the Running Environment\n\n### Method 1: Installing from source with prerequisites\n\n#### Python Version\n\nCurrently, only `python==3.11` is supported. To avoid breaking the system's pre-installed Python, consider creating a dedicated environment with [conda](https://github.com/conda/conda).\n\n```shell\nconda create --name sglang_npu python=3.11\nconda activate sglang_npu\n```\n\n#### CANN\n\nBefore working with SGLang on Ascend, you need to install the CANN Toolkit, the Kernels operator package, and NNAL version 8.3.RC2 or higher. See the [installation guide](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/83RC1/softwareinst/instg/instg_0008.html?Mode=PmIns&InstallType=local&OS=openEuler&Software=cannToolKit).\n\n#### MemFabric-Hybrid\n\nIf you want to use PD disaggregation mode, you need to install MemFabric-Hybrid.
MemFabric-Hybrid is a drop-in replacement for the Mooncake Transfer Engine that enables KV cache transfer on Ascend NPU clusters.\n\n```shell\npip install memfabric-hybrid==1.0.5\n```\n\n#### PyTorch and PyTorch Framework Adaptor on Ascend\n\n```shell\nPYTORCH_VERSION=2.8.0\nTORCHVISION_VERSION=0.23.0\nTORCH_NPU_VERSION=2.8.0\npip install torch==$PYTORCH_VERSION torchvision==$TORCHVISION_VERSION --index-url https://download.pytorch.org/whl/cpu\npip install torch_npu==$TORCH_NPU_VERSION\n```\n\nIf you are using a different version of `torch`, see the [installation guide](https://github.com/Ascend/pytorch/blob/master/README.md) to install the matching `torch_npu`.\n\n#### Triton on Ascend\n\nWe provide our own implementation of Triton for Ascend.\n\n```shell\npip install triton-ascend\n```\nTo install Triton on Ascend nightly builds or build it from source, follow the [installation guide](https://gitcode.com/Ascend/triton-ascend/blob/master/docs/sources/getting-started/installation.md).\n\n#### SGLang Kernels NPU\nWe provide SGL kernels for Ascend NPU; see the [installation guide](https://github.com/sgl-project/sgl-kernel-npu/blob/main/python/sgl_kernel_npu/README.md).\n\n#### DeepEP-compatible Library\nWe provide a DeepEP-compatible library as a drop-in replacement for deepseek-ai's DeepEP library; see the [installation guide](https://github.com/sgl-project/sgl-kernel-npu/blob/main/python/deep_ep/README.md).\n\n#### Installing SGLang from source\n\n```shell\n# Use the latest release branch\ngit clone https://github.com/sgl-project/sglang.git\ncd sglang\nmv python/pyproject_npu.toml python/pyproject.toml\npip install -e python[all_npu]\n```\n\n### Method 2: Using Docker Image\n#### Obtain Image\nTo obtain the Ascend NPU image, you can either download the SGLang image or build one from the Dockerfile.\n1.
Download the SGLang image\n```text\ndockerhub: docker.io/lmsysorg/sglang:$tag\n# Main-based tags; replace main with a specific version (e.g. v0.5.6)\n# to get the image for that release\nAtlas 800I A3: {main}-cann8.5.0-a3\nAtlas 800I A2: {main}-cann8.5.0-910b\n```\n2. Build an image from the Dockerfile\n```shell\n# Clone the SGLang repository\ngit clone https://github.com/sgl-project/sglang.git\ncd sglang/docker\n\n# Build the docker image\n# If there are network errors, modify the Dockerfile to use offline dependencies or use a proxy\ndocker build -t <image_name> -f npu.Dockerfile .\n```\n\n#### Create Docker Container\n__Notice:__ `--privileged` and `--network=host` are required by RDMA, which is typically needed by Ascend NPU clusters.\n\n__Notice:__ The following docker command is based on Atlas 800I A3 machines. If you are using Atlas 800I A2, make sure only `davinci[0-7]` are mapped into the container.\n\n```shell\nalias drun='docker run -it --rm --privileged --network=host --ipc=host --shm-size=16g \\\n    --device=/dev/davinci0 --device=/dev/davinci1 --device=/dev/davinci2 --device=/dev/davinci3 \\\n    --device=/dev/davinci4 --device=/dev/davinci5 --device=/dev/davinci6 --device=/dev/davinci7 \\\n    --device=/dev/davinci8 --device=/dev/davinci9 --device=/dev/davinci10 --device=/dev/davinci11 \\\n    --device=/dev/davinci12 --device=/dev/davinci13 --device=/dev/davinci14 --device=/dev/davinci15 \\\n    --device=/dev/davinci_manager --device=/dev/hisi_hdc \\\n    --volume /usr/local/sbin:/usr/local/sbin --volume /usr/local/Ascend/driver:/usr/local/Ascend/driver \\\n    --volume /usr/local/Ascend/firmware:/usr/local/Ascend/firmware \\\n    --volume /etc/ascend_install.info:/etc/ascend_install.info \\\n    --volume /var/queue_schedule:/var/queue_schedule --volume ~/.cache/:/root/.cache/'\n\n# Set HF_TOKEN so that SGLang can download the model.\ndrun --env \"HF_TOKEN=<secret>\" \\\n    <image_name> \\\n    python3 -m sglang.launch_server --model-path 
meta-llama/Llama-3.1-8B-Instruct --attention-backend ascend\n```\n\n## System Settings\n\n### CPU performance power scheme\n\nThe default CPU power scheme on Ascend hardware is `ondemand`, which can hurt performance. We recommend switching it to `performance`.\n\n```shell\necho performance | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\n\n# Make sure the change was applied successfully\ncat /sys/devices/system/cpu/cpu0/cpufreq/scaling_governor # shows performance\n```\n\n### Disable NUMA balancing\n\n```shell\nsudo sysctl -w kernel.numa_balancing=0\n# Check\ncat /proc/sys/kernel/numa_balancing # shows 0\n```\n\n### Prevent swapping out system memory\n\n```shell\nsudo sysctl -w vm.swappiness=10\n\n# Check\ncat /proc/sys/vm/swappiness # shows 10\n```\n\n## Running SGLang Service\n### Running Service For Large Language Models\n#### PD Mixed Scene\n```shell\n# Enable CPU affinity\nexport SGLANG_SET_CPU_AFFINITY=1\npython3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --attention-backend ascend\n```\n\n#### PD Separation Scene\n1. Launch Prefill Server\n```shell\n# Enable CPU affinity\nexport SGLANG_SET_CPU_AFFINITY=1\n\n# PIP: the IP of the first prefill server is recommended\n# PORT: any free port\n# All SGLang servers must be configured with the same PIP and PORT\nexport ASCEND_MF_STORE_URL=\"tcp://PIP:PORT\"\n# If you are on Atlas 800I A2 hardware and use RDMA for KV cache transfer, add this parameter\nexport ASCEND_MF_TRANSFER_PROTOCOL=\"device_rdma\"\npython3 -m sglang.launch_server \\\n    --model-path meta-llama/Llama-3.1-8B-Instruct \\\n    --disaggregation-mode prefill \\\n    --disaggregation-transfer-backend ascend \\\n    --disaggregation-bootstrap-port 8995 \\\n    --attention-backend ascend \\\n    --device npu \\\n    --base-gpu-id 0 \\\n    --tp-size 1 \\\n    --host 127.0.0.1 \\\n    --port 8000\n```\n\n2.
Launch Decode Server\n```shell\n# PIP: the IP of the first prefill server is recommended\n# PORT: any free port\n# All SGLang servers must be configured with the same PIP and PORT\nexport ASCEND_MF_STORE_URL=\"tcp://PIP:PORT\"\n# If you are on Atlas 800I A2 hardware and use RDMA for KV cache transfer, add this parameter\nexport ASCEND_MF_TRANSFER_PROTOCOL=\"device_rdma\"\npython3 -m sglang.launch_server \\\n    --model-path meta-llama/Llama-3.1-8B-Instruct \\\n    --disaggregation-mode decode \\\n    --disaggregation-transfer-backend ascend \\\n    --attention-backend ascend \\\n    --device npu \\\n    --base-gpu-id 1 \\\n    --tp-size 1 \\\n    --host 127.0.0.1 \\\n    --port 8001\n```\n\n3. Launch Router\n```shell\npython3 -m sglang_router.launch_router \\\n    --pd-disaggregation \\\n    --policy cache_aware \\\n    --prefill http://127.0.0.1:8000 8995 \\\n    --decode http://127.0.0.1:8001 \\\n    --host 127.0.0.1 \\\n    --port 6688\n```\n\n### Running Service For Multimodal Language Models\n#### PD Mixed Scene\n```shell\npython3 -m sglang.launch_server \\\n    --model-path Qwen3-VL-30B-A3B-Instruct \\\n    --host 127.0.0.1 \\\n    --port 8000 \\\n    --tp 4 \\\n    --device npu \\\n    --attention-backend ascend \\\n    --mm-attention-backend ascend_attn \\\n    --disable-radix-cache \\\n    --trust-remote-code \\\n    --enable-multimodal \\\n    --sampling-backend ascend\n```\n"
  },
  {
    "path": "docs/platforms/ascend_npu_best_practice.md",
    "content": "# Best Practice on Ascend NPU\n\nThis section lists best-practice configurations, with their target TPOT, for mainstream LLMs such as DeepSeek and Qwen on Ascend NPU. If you encounter issues or have any questions, please [open an issue](https://github.com/sgl-project/sglang/issues).\n\n## DeepSeek Series Models\n\n### Low Latency\n\n| Model             | Hardware      | Cards | Deploy Mode   | Dataset   | TPOT | Quantization | Configuration                                                                         |\n|-------------------|---------------|-------|---------------|-----------|------|--------------|---------------------------------------------------------------------------------------|\n| Deepseek-R1       | Atlas 800I A3 | 32    | PD Separation | 6K+1.6K   | 20ms | W8A8 INT8    | [Optimal Configuration](#deepseek-r1-6k-1_6k-20ms-on-a3-32-cards-separation-mode)     |\n| Deepseek-R1       | Atlas 800I A3 | 32    | PD Separation | 3.9K+1K   | 20ms | W8A8 INT8    | [Optimal Configuration](#deepseek-r1-3_9k-1k-20ms-on-a3-32-cards-separation-mode)     |\n| Deepseek-R1       | Atlas 800I A3 | 32    | PD Separation | 3.5K+1.5K | 20ms | W8A8 INT8    | [Optimal Configuration](#deepseek-r1-3_5k-1_5k-20ms-on-a3-32-cards-separation-mode)   |\n| Deepseek-R1       | Atlas 800I A3 | 32    | PD Separation | 3.5K+1K   | 20ms | W8A8 INT8    | [Optimal Configuration](#deepseek-r1-3_5k-1k-20ms-on-a3-32-cards-separation-mode)     |\n| DeepSeek-V3.2-Exp | Atlas 800I A3 | 32    | PD Separation | 64K+3K    | 30ms | W8A8 INT8    | [Optimal Configuration](#deepseek-v32-exp-64k-3k-30ms-on-a3-32-cards-separation-mode) |\n\n### High Throughput\n\n| Model       | Hardware      | Cards | Deploy Mode   | Dataset   | TPOT | Quantization | Configuration                                                                       
|\n|-------------|---------------|-------|---------------|-----------|------|--------------|-------------------------------------------------------------------------------------|\n| Deepseek-R1 | Atlas 800I A3 | 32    | PD Separation | 3.5K+1.5K | 50ms | W8A8 INT8    | [Optimal Configuration](#deepseek-r1-3_5k-1_5k-50ms-on-a3-32-cards-separation-mode) |\n| Deepseek-R1 | Atlas 800I A3 | 8     | PD Mixed      | 2K+2K     | 50ms | W4A8 INT8    | [Optimal Configuration](#deepseek-r1-2k-2k-50ms-on-a3-8-cards-mixed-mode)           |\n| Deepseek-R1 | Atlas 800I A3 | 16    | PD Separation | 2K+2K     | 50ms | W4A8 INT8    | [Optimal Configuration](#deepseek-r1-2k-2k-50ms-on-a3-16-cards-separation-mode)     |\n| Deepseek-R1 | Atlas 800I A3 | 8     | PD Mixed      | 3.5K+1.5K | 50ms | W4A8 INT8    | [Optimal Configuration](#deepseek-r1-3_5k-1_5k-50ms-on-a3-8-cards-mixed-mode)       |\n| Deepseek-R1 | Atlas 800I A3 | 16    | PD Separation | 3.5K+1.5K | 50ms | W4A8 INT8    | [Optimal Configuration](#deepseek-r1-3_5k-1_5k-50ms-on-a3-16-cards-separation-mode) |\n\n## Qwen Series Models\n\n### Low Latency\n\n| Model           | Hardware      | Cards | Deploy Mode | Dataset | TPOT | Quantization | Configuration                                                                  |\n|-----------------|---------------|-------|-------------|---------|------|--------------|--------------------------------------------------------------------------------|\n| Qwen3-235B-A22B | Atlas 800I A3 | 8     | PD Mixed    | 11K+1K  | 10ms | BF16         | [Optimal Configuration](#qwen3-235b-a22b-11k-1k-10ms-on-a3-8-cards-mixed-mode) |\n| Qwen3-32B       | Atlas 800I A3 | 4     | PD Mixed    | 6K+1.5K | 18ms | BF16         | [Optimal Configuration](#qwen3-32b-6k-1_5k-18ms-on-a3-4-cards-mixed-mode)      |\n| Qwen3-32B       | Atlas 800I A3 | 4     | PD Mixed    | 4K+1.5K | 11ms | BF16         | [Optimal Configuration](#qwen3-32b-4k-1_5k-11ms-on-a3-4-cards-mixed-mode)      |\n| Qwen3-32B       | Atlas 
800I A3 | 8     | PD Mixed    | 18K+4K  | 12ms | BF16         | [Optimal Configuration](#qwen3-32b-18k-4k-12ms-on-a3-8-cards-mixed-mode)       |\n| Qwen3-32B       | Atlas 800I A2 | 8     | PD Mixed    | 6K+1.5K | 18ms | W8A8 INT8    | [Optimal Configuration](#qwen3-32b-6k-1_5k-18ms-on-a2-8-cards-mixed-mode)      |\n| Qwen3-32B       | Atlas 800I A2 | 8     | PD Mixed    | 4K+1.5K | 11ms | BF16         | [Optimal Configuration](#qwen3-32b-4k-1_5k-11ms-on-a2-8-cards-mixed-mode)      |\n\n### High Throughput\n\n| Model                          | Hardware      | Cards | Deploy Mode   | Dataset   | TPOT  | Quantization | Configuration                                                                                          |\n|--------------------------------|---------------|-------|---------------|-----------|-------|--------------|--------------------------------------------------------------------------------------------------------|\n| Qwen3-235B-A22B                | Atlas 800I A3 | 24    | PD Separation | 3.5K+1.5K | 50ms  | W8A8 INT8    | [Optimal Configuration](#qwen3-235b-a22b-3_5k-1_5k-50ms-on-a3-24-cards-separation-mode)                |\n| Qwen3-235B-A22B                | Atlas 800I A3 | 8     | PD Mixed      | 3.5K+1.5K | 50ms  | W8A8 INT8    | [Optimal Configuration](#qwen3-235b-a22b-3_5k-1_5k-50ms-on-a3-8-cards-mixed-mode)                      |\n| Qwen3-235B-A22B                | Atlas 800I A3 | 8     | PD Mixed      | 2K+2K     | 100ms | W8A8 INT8    | [Optimal Configuration](#qwen3-235b-a22b-2k-2k-100ms-on-a3-8-cards-mixed-mode)                         |\n| Qwen3-235B-A22B                | Atlas 800I A3 | 8     | PD Mixed      | 2K+2K     | 50ms  | W8A8 INT8    | [Optimal Configuration](#qwen3-235b-a22b-2k-2k-50ms-on-a3-8-cards-mixed-mode)                          |\n| Qwen3-235B-A22B                | Atlas 800I A3 | 16    | PD Mixed      | 2K+2K     | 50ms  | W8A8 INT8    | [Optimal Configuration](#qwen3-235b-a22b-2k-2k-50ms-on-a3-16-cards-mixed-mode) 
                         |\n| Qwen3-32B                      | Atlas 800I A3 | 2     | PD Mixed      | 3.5K+1.5K | 50ms  | W8A8 INT8    | [Optimal Configuration](#qwen3-32b-3_5k-1_5k-50ms-on-a3-2-cards-mixed-mode)                            |\n| Qwen3-32B                      | Atlas 800I A3 | 2     | PD Mixed      | 2K+2K     | 50ms  | W8A8 INT8    | [Optimal Configuration](#qwen3-32b-2k-2k-50ms-on-a3-2-cards-mixed-mode)                                |\n| Qwen3-30B-A3B                  | Atlas 800I A3 | 1     | PD Mixed      | 3.5K+1.5K | 50ms  | W8A8 INT8    | [Optimal Configuration](#qwen3-30b-a3b-3_5k-1_5k-50ms-on-a3-1-card-mixed-mode)                         |\n| Qwen3-Coder-480B-A35B-Instruct | Atlas 800I A3 | 24    | PD Separation | 3.5K+1.5K | 50ms  | W8A8 INT8    | [Optimal Configuration](#qwen3-coder-480b-a35b-instruct-3_5k-1_5k-50ms-on-a3-24-cards-separation-mode) |\n| Qwen3-Coder-480B-A35B-Instruct | Atlas 800I A3 | 16    | PD Mixed      | 3.5K+1.5K | 50ms  | W8A8 INT8    | [Optimal Configuration](#qwen3-coder-480b-a35b-instruct-3_5k-1_5k-50ms-on-a3-16-cards-mixed-mode)      |\n| Qwen3-Coder-480B-A35B-Instruct | Atlas 800I A3 | 8     | PD Mixed      | 3.5K+1.5K | 50ms  | W8A8 INT8    | [Optimal Configuration](#qwen3-coder-480b-a35b-instruct-3_5k-1_5k-50ms-on-a3-8-cards-mixed-mode)       |\n| Qwen3-Next-80B-A3B-Instruct    | Atlas 800I A3 | 2     | PD Mixed      | 3.5K+1.5K | 50ms  | W8A8 INT8    | [Optimal Configuration](#qwen3-next-80b-a3b-instruct-3_5k-1_5k-50ms-on-a3-2-cards-mixed-mode)          |\n| Qwen3-32B                      | Atlas 800I A2 | 8     | PD Mixed      | 3.5K+1.5K | 50ms  | W8A8 INT8    | [Optimal Configuration](#qwen3-32b-3_5k-1_5k-50ms-on-a2-8-cards-mixed-mode)                            |\n| Qwen3-32B                      | Atlas 800I A2 | 8     | PD Mixed      | 2K+2K     | 50ms  | W8A8 INT8    | [Optimal Configuration](#qwen3-32b-2k-2k-50ms-on-a2-8-cards-mixed-mode)                                |\n\n## Optimal 
Configuration\n\n### DeepSeek-R1 3_5K-1_5K 50ms on A3 32 Cards Separation Mode\n\nModel: Deepseek R1\n\nHardware: Atlas 800I A3 32Card\n\nDeployMode: PD Separation\n\nDataset: random\n\nInput Output Length: 3.5K+1.5K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\nexport SGLANG_SET_CPU_AFFINITY=1\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\n\nexport ASCEND_MF_STORE_URL=\"tcp://your prefill ip1:24669\"\n\nP_IP=('your prefill ip1' 'your prefill ip2')\n\nD_IP=('your decode ip1' 'your decode ip2')\n\nMODEL_PATH=xxx\n\nexport SGLANG_NPU_USE_MLAPO=1\nexport SGLANG_USE_FIA_NZ=1\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n# prefill\nfor i in \"${!P_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${P_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${P_IP[$i]}\" ]];\n    then\n        echo \"${P_IP[$i]}\"\n        export HCCL_BUFFSIZE=1536\n        export DEEP_NORMAL_MODE_USE_INT8_QUANT=1\n        export TASK_QUEUE_ENABLE=2\n\n        export HCCL_SOCKET_IFNAME=lo\n        export GLOO_SOCKET_IFNAME=lo\n        python -m sglang.launch_server --model-path ${MODEL_PATH}  --disaggregation-mode prefill --host ${P_IP[$i]} \\\n        --port 8000 --disaggregation-bootstrap-port $((8998+$i)) --trust-remote-code --nnodes 1 --node-rank 0 \\\n        --tp-size 16 --mem-fraction-static 0.81 --attention-backend ascend --device npu --quantization modelslim \\\n        --disaggregation-transfer-backend ascend --max-running-requests 8 --context-length 8192  
--disable-radix-cache \\\n        --chunked-prefill-size -1 --max-prefill-tokens 28680 --moe-a2a-backend deepep --deepep-mode normal \\\n        --speculative-algorithm NEXTN --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2  \\\n        --dp-size 2 --enable-dp-attention --disable-shared-experts-fusion --dtype bfloat16 --enable-attn-tp-input-scattered\n        NODE_RANK=$i\n        break\n    fi\ndone\n\n# decode\nfor i in \"${!D_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${D_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${D_IP[$i]}\" ]];\n    then\n        echo \"${D_IP[$i]}\"\n        export SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\n        export SGLANG_ENABLE_SPEC_V2=1\n        export HCCL_BUFFSIZE=650\n        export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=78\n        export TASK_QUEUE_ENABLE=1\n        export SGLANG_SCHEDULER_SKIP_ALL_GATHER=1\n        export HCCL_SOCKET_IFNAME=xxx\n        export GLOO_SOCKET_IFNAME=xxx\n        python -m sglang.launch_server --model-path ${MODEL_PATH} --disaggregation-mode decode --host ${D_IP[$i]} \\\n        --port 8001 --trust-remote-code --dist-init-addr ${D_IP[0]}:5000 --nnodes 2 --node-rank $i --tp-size 32 --dp-size 32 \\\n        --mem-fraction-static 0.815 --max-running-requests 832 --attention-backend ascend --device npu --quantization modelslim \\\n        --moe-a2a-backend deepep --enable-dp-attention --deepep-mode low_latency --enable-dp-lm-head --moe-dense-tp 1 \\\n        --cuda-graph-bs 12 14 16 18 20 22 24 26 --disaggregation-transfer-backend ascend --watchdog-timeout 9000 --context-length 8192 \\\n        --speculative-algorithm NEXTN --speculative-num-steps 2 --speculative-eagle-topk 1 --speculative-num-draft-tokens 3  \\\n        --tokenizer-worker-num 4 --prefill-round-robin-balance --disable-shared-experts-fusion --dtype bfloat16 \\\n        --load-balance-method decode_round_robin\n        NODE_RANK=$i\n        break\n    fi\ndone\n\n```\n\n```shell\nexport 
SGLANG_DP_ROUND_ROBIN=1\npython -m sglang_router.launch_router \\\n    --pd-disaggregation \\\n    --policy cache_aware \\\n    --prefill http://P_IP:8000 8998 \\\n    --prefill http://P_IP:8000 8999 \\\n    --decode http://D_IP:8001 \\\n    --host 127.0.0.1 \\\n    --port 6688 \\\n    --mini-lb\n```\n\n#### Benchmark\n\nWe benchmarked it with the `random` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 768  --random-input-len 3500 --random-output-len 1500 --num-prompts 3072 --random-range-ratio 1 --request-rate 16\n```\n\n### DeepSeek-R1 6K-1_6K 20ms on A3 32 Cards Separation Mode\n\nModel: Deepseek R1\n\nHardware: Atlas 800I A3 32Card\n\nDeployMode: PD Separation\n\nDataset: random\n\nInput Output Length: 6K+1.6K\n\nTPOT: 20ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\nexport SGLANG_SET_CPU_AFFINITY=1\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\nexport ASCEND_MF_STORE_URL=\"tcp://your prefill ip1:24669\"\n\nP_IP=('your prefill ip1' 'your prefill ip2')\n\nD_IP=('your decode ip1' 'your decode ip2')\n\nMODEL_PATH=xxx\n\nexport SGLANG_NPU_USE_MLAPO=1\nexport SGLANG_USE_FIA_NZ=1\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n# prefill\nfor i in \"${!P_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${P_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${P_IP[$i]}\" ]];\n    then\n        echo \"${P_IP[$i]}\"\n        export HCCL_BUFFSIZE=1536\n        export 
DEEP_NORMAL_MODE_USE_INT8_QUANT=1\n        export TASK_QUEUE_ENABLE=2\n\n        export HCCL_SOCKET_IFNAME=lo\n        export GLOO_SOCKET_IFNAME=lo\n        python -m sglang.launch_server --model-path ${MODEL_PATH}  --disaggregation-mode prefill --host ${P_IP[$i]} \\\n        --port 8000 --disaggregation-bootstrap-port $((8998+$i)) --trust-remote-code --nnodes 1 --node-rank 0 \\\n        --tp-size 16 --mem-fraction-static 0.81 --attention-backend ascend --device npu --quantization modelslim \\\n        --disaggregation-transfer-backend ascend --max-running-requests 4 --context-length 8192  --disable-radix-cache \\\n        --chunked-prefill-size -1 --max-prefill-tokens 28680 --moe-a2a-backend deepep --deepep-mode normal \\\n        --speculative-algorithm NEXTN --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2  \\\n        --dp-size 2 --enable-dp-attention --disable-shared-experts-fusion --dtype bfloat16 --enable-attn-tp-input-scattered\n        NODE_RANK=$i\n        break\n    fi\ndone\n\n# decode\nfor i in \"${!D_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${D_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${D_IP[$i]}\" ]];\n    then\n        echo \"${D_IP[$i]}\"\n        export SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\n        export SGLANG_ENABLE_SPEC_V2=1\n        export HCCL_BUFFSIZE=650\n        export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=12\n        export TASK_QUEUE_ENABLE=1\n        export SGLANG_SCHEDULER_SKIP_ALL_GATHER=1\n        export HCCL_SOCKET_IFNAME=xxx\n        export GLOO_SOCKET_IFNAME=xxx\n        python -m sglang.launch_server --model-path ${MODEL_PATH} --disaggregation-mode decode --host ${D_IP[$i]} \\\n        --port 8001 --trust-remote-code --dist-init-addr ${D_IP[0]}:5000 --nnodes 2 --node-rank $i --tp-size 32 --dp-size 16 \\\n        --mem-fraction-static 0.75 --max-running-requests 32 --attention-backend ascend --device npu --quantization modelslim \\\n        --moe-a2a-backend deepep --enable-dp-attention 
--deepep-mode low_latency --enable-dp-lm-head --moe-dense-tp-size 1 \\\n        --cuda-graph-bs 2 4 6 --disaggregation-transfer-backend ascend --watchdog-timeout 9000 --context-length 8192 \\\n        --speculative-algorithm NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4  \\\n        --tokenizer-worker-num 4 --prefill-round-robin-balance --disable-shared-experts-fusion --dtype bfloat16 \\\n        --load-balance-method decode_round_robin\n        NODE_RANK=$i\n        break\n    fi\ndone\n\n```\n\n```shell\nexport SGLANG_DP_ROUND_ROBIN=1\npython -m sglang_router.launch_router \\\n    --pd-disaggregation \\\n    --policy cache_aware \\\n    --prefill http://P_IP:8000 8998 \\\n    --prefill http://P_IP:8000 8999 \\\n    --decode http://D_IP:8001 \\\n    --host 127.0.0.1 \\\n    --port 6688 \\\n    --mini-lb\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 32  --random-input-len 6000 --random-output-len 1600 --num-prompts 32 --random-range-ratio 1\n```\n\n### DeepSeek-R1 3_9K-1K 20ms on A3 32 Cards Separation Mode\n\nModel: Deepseek R1\n\nHardware: Atlas 800I A3 32Card\n\nDeployMode: PD Separation\n\nDataset: random\n\nInput Output Length: 3.9K+1K\n\nTPOT: 20ms\n\n#### Model Deployment\n\nPlease refer to [DeepSeek-R1 6K-1_6K 20ms on A3 32 Cards Separation Mode](#deepseek-r1-6k-1_6k-20ms-on-a3-32-cards-separation-mode).\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 768  --random-input-len 3900 --random-output-len 1000 --num-prompts 768 --random-range-ratio 1 --request-rate 16\n```\n\n### DeepSeek-R1 3_5K-1_5K 20ms on A3 32 Cards Separation Mode\n\nModel: Deepseek R1\n\nHardware: Atlas 800I A3 32Card\n\nDeployMode: PD 
Separation\n\nDataset: random\n\nInput Output Length: 3.5K+1.5K\n\nTPOT: 20ms\n\n#### Model Deployment\n\nPlease refer to [DeepSeek-R1 6K-1_6K 20ms on A3 32 Cards Separation Mode](#deepseek-r1-6k-1_6k-20ms-on-a3-32-cards-separation-mode).\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 768  --random-input-len 3500 --random-output-len 1500 --num-prompts 768 --random-range-ratio 1 --request-rate 16\n```\n\n### DeepSeek-R1 3_5K-1K 20ms on A3 32 Cards Separation Mode\n\nModel: Deepseek R1\n\nHardware: Atlas 800I A3 32Card\n\nDeployMode: PD Separation\n\nDataset: random\n\nInput Output Length: 3.5K+1K\n\nTPOT: 20ms\n\n#### Model Deployment\n\nPlease refer to [DeepSeek-R1 6K-1_6K 20ms on A3 32 Cards Separation Mode](#deepseek-r1-6k-1_6k-20ms-on-a3-32-cards-separation-mode).\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 768  --random-input-len 3500 --random-output-len 1000 --num-prompts 768 --random-range-ratio 1 --request-rate 16\n```\n\n### DeepSeek-R1 2K-2K 50ms on A3 8 Cards Mixed Mode\n\nModel: Deepseek R1\n\nHardware: Atlas 800I A3 8Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 2K+2K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport 
PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nexport 
PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\nexport SGLANG_SCHEDULER_DECREASE_PREFILL_IDLE=1\n\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\n\nexport SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=64\nexport HCCL_BUFFSIZE=1600\nexport DEEPEP_NORMAL_LONG_SEQ_ROUND=10\nexport DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS=512\n\nMODEL_PATH=xxx\n\nexport DEEP_NORMAL_MODE_USE_INT8_QUANT=1\nexport SGLANG_NPU_USE_MLAPO=1\nexport SGLANG_ENABLE_SPEC_V2=1\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_USE_FIA_NZ=1\nexport ENABLE_MOE_NZ=1\n\npython3 -m sglang.launch_server --model-path ${MODEL_PATH} \\\n--tp 16 \\\n--trust-remote-code \\\n--attention-backend ascend \\\n--device npu \\\n--quantization modelslim \\\n--watchdog-timeout 9000 \\\n--host 127.0.0.1 --port 6699 \\\n--cuda-graph-bs 4 8 16 \\\n--mem-fraction-static 0.74 \\\n--max-running-requests 256 \\\n--disable-radix-cache --chunked-prefill-size -1 --max-prefill-tokens 1500 \\\n--moe-a2a-backend deepep --deepep-mode auto \\\n--enable-dp-attention --dp-size 16 --enable-dp-lm-head \\\n--speculative-algorithm NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \\\n--dtype bfloat16\n\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6699 --max-concurrency 256  --random-input-len 2048 --random-output-len 2048 --num-prompts 1024 --random-range-ratio 1\n```\n\n### DeepSeek-R1 2K-2K 50ms on A3 16 Cards Separation Mode\n\nModel: Deepseek R1\n\nHardware: Atlas 800I A3 16Card\n\nDeployMode: PD Separation\n\nDataset: random\n\nInput Output Length: 2K+2K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport 
SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\n\nexport ASCEND_MF_STORE_URL=\"tcp://your prefill ip1:24667\"\n\nP_IP=('your prefill ip1')\n\nD_IP=('your decode ip1')\n\nMODEL_PATH=xxx\n\nexport SGLANG_NPU_USE_MLAPO=1\nexport SGLANG_USE_FIA_NZ=1\nexport ENABLE_MOE_NZ=1\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\n# prefill\nfor i in \"${!P_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${P_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${P_IP[$i]}\" ]];\n    then\n        echo \"${P_IP[$i]}\"\n        export HCCL_BUFFSIZE=1536\n        export DEEP_NORMAL_MODE_USE_INT8_QUANT=1\n        export TASK_QUEUE_ENABLE=2\n\n        export HCCL_SOCKET_IFNAME=lo\n        export GLOO_SOCKET_IFNAME=lo\n        python -m sglang.launch_server --model-path ${MODEL_PATH}  --disaggregation-mode prefill --host ${P_IP[$i]} \\\n        --port 8000 --disaggregation-bootstrap-port $((8998+$i)) --trust-remote-code --nnodes 1 --node-rank 0 \\\n        --tp-size 16 --mem-fraction-static 0.6 --attention-backend ascend --device npu --quantization modelslim \\\n        --disaggregation-transfer-backend ascend --max-running-requests 8 --context-length 8192  --disable-radix-cache \\\n        --chunked-prefill-size 32768 --max-prefill-tokens 28680 --moe-a2a-backend deepep --deepep-mode normal \\\n        --speculative-algorithm NEXTN --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2  \\\n        --dp-size 2 --enable-dp-attention --disable-shared-experts-fusion --dtype bfloat16\n        NODE_RANK=$i\n        break\n    fi\ndone\n\n# 
decode\nfor i in \"${!D_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${D_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${D_IP[$i]}\" ]];\n    then\n        echo \"${D_IP[$i]}\"\n        export SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\n        export SGLANG_ENABLE_SPEC_V2=1\n        export HCCL_BUFFSIZE=720\n        export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=96\n        export TASK_QUEUE_ENABLE=1\n        export HCCL_SOCKET_IFNAME=xxx\n        export GLOO_SOCKET_IFNAME=xxx\n        python -m sglang.launch_server --model-path ${MODEL_PATH} --disaggregation-mode decode --host ${D_IP[$i]} \\\n        --port 8001 --trust-remote-code --nnodes 1 --node-rank 0 --tp-size 16 --dp-size 16 \\\n        --mem-fraction-static 0.8 --max-running-requests 384 --attention-backend ascend --device npu --quantization modelslim \\\n        --moe-a2a-backend deepep --enable-dp-attention --deepep-mode low_latency --enable-dp-lm-head \\\n        --cuda-graph-bs 8 10 12 14 16 18 20 22 24 --disaggregation-transfer-backend ascend --watchdog-timeout 9000 --context-length 8192 \\\n        --speculative-algorithm NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4  \\\n        --prefill-round-robin-balance --disable-shared-experts-fusion --dtype bfloat16 --tokenizer-worker-num 4 \\\n        --load-balance-method decode_round_robin\n        NODE_RANK=$i\n        break\n    fi\ndone\n\n```\n\n```shell\nexport SGLANG_DP_ROUND_ROBIN=1\npython -m sglang_router.launch_router \\\n    --pd-disaggregation \\\n    --policy cache_aware \\\n    --prefill http://P_IP:8000 8998 \\\n    --decode http://D_IP:8001 \\\n    --host 127.0.0.1 \\\n    --port 6688 \\\n    --mini-lb\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 400  --random-input-len 2048 --random-output-len 2048 --num-prompts 3200 --random-range-ratio 1 
--request-rate 8\n```\n\n### DeepSeek-R1 3_5K-1_5K 50ms on A3 8 Cards Mixed Mode\n\nModel: Deepseek R1\n\nHardware: Atlas 800I A3 8Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 3.5K+1.5K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n\nexport STREAMS_PER_DEVICE=32\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\nexport SGLANG_SCHEDULER_DECREASE_PREFILL_IDLE=1\nexport SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=36\nexport HCCL_BUFFSIZE=1600\nexport DEEP_NORMAL_MODE_USE_INT8_QUANT=1\nexport SGLANG_NPU_USE_MLAPO=1\nexport SGLANG_ENABLE_SPEC_V2=1\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_USE_FIA_NZ=1\nexport ENABLE_MOE_NZ=1\n\nMODEL_PATH=xxx\n\npython3 -m sglang.launch_server --model-path ${MODEL_PATH} \\\n--tp 16 \\\n--trust-remote-code \\\n--attention-backend ascend \\\n--device npu \\\n--quantization modelslim \\\n--watchdog-timeout 9000 \\\n--host 127.0.0.1 --port 6699 \\\n--cuda-graph-bs 8 16 24 28 32 36 \\\n--mem-fraction-static 0.71 \\\n--max-running-requests 144 \\\n--context-length 8188  --disable-radix-cache --chunked-prefill-size -1 --max-prefill-tokens 9000 \\\n--moe-a2a-backend deepep --deepep-mode auto \\\n--enable-dp-attention --dp-size 4 --enable-dp-lm-head \\\n--speculative-algorithm NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \\\n--dtype bfloat16\n\n```\n\n#### Benchmark\n\nWe tested it based on the 
`RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6699 --max-concurrency 144  --random-input-len 3500 --random-output-len 1500 --num-prompts 576 --random-range-ratio 1\n```\n\n### DeepSeek-R1 3_5K-1_5K 50ms on A3 16 Cards Separation Mode\n\nModel: Deepseek R1\n\nHardware: Atlas 800I A3 16Card\n\nDeployMode: PD Separation\n\nDataset: random\n\nInput Output Length: 3.5K+1.5K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\n\nexport ASCEND_MF_STORE_URL=\"tcp://your prefill ip1:24667\"\n\nP_IP=('your prefill ip1')\n\nD_IP=('your decode ip1')\n\nMODEL_PATH=xxx\n\nexport SGLANG_NPU_USE_MLAPO=1\nexport SGLANG_USE_FIA_NZ=1\nexport ENABLE_MOE_NZ=1\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\n# prefill\nfor i in \"${!P_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${P_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${P_IP[$i]}\" ]];\n    then\n        echo \"${P_IP[$i]}\"\n        export HCCL_BUFFSIZE=1536\n        export DEEP_NORMAL_MODE_USE_INT8_QUANT=1\n        export TASK_QUEUE_ENABLE=2\n\n        export HCCL_SOCKET_IFNAME=lo\n        export GLOO_SOCKET_IFNAME=lo\n        python -m sglang.launch_server --model-path ${MODEL_PATH}  --disaggregation-mode prefill --host ${P_IP[$i]} \\\n        --port 8000 
--disaggregation-bootstrap-port $((8998+$i)) --trust-remote-code --nnodes 1 --node-rank 0 \\\n        --tp-size 16 --mem-fraction-static 0.6 --attention-backend ascend --device npu --quantization modelslim \\\n        --disaggregation-transfer-backend ascend --max-running-requests 8 --context-length 8192  --disable-radix-cache \\\n        --chunked-prefill-size -1 --max-prefill-tokens 28680 --moe-a2a-backend deepep --deepep-mode normal \\\n        --speculative-algorithm NEXTN --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2  \\\n        --dp-size 2 --enable-dp-attention --disable-shared-experts-fusion --dtype bfloat16\n        NODE_RANK=$i\n        break\n    fi\ndone\n\n# decode\nfor i in \"${!D_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${D_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${D_IP[$i]}\" ]];\n    then\n        echo \"${D_IP[$i]}\"\n        export SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\n        export SGLANG_ENABLE_SPEC_V2=1\n        export HCCL_BUFFSIZE=720\n        export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=96\n        export TASK_QUEUE_ENABLE=1\n        export HCCL_SOCKET_IFNAME=xxx\n        export GLOO_SOCKET_IFNAME=xxx\n        python -m sglang.launch_server --model-path ${MODEL_PATH} --disaggregation-mode decode --host ${D_IP[$i]} \\\n        --port 8001 --trust-remote-code --nnodes 1 --node-rank 0 --tp-size 16 --dp-size 16 \\\n        --mem-fraction-static 0.8 --max-running-requests 384 --attention-backend ascend --device npu --quantization modelslim \\\n        --moe-a2a-backend deepep --enable-dp-attention --deepep-mode low_latency --enable-dp-lm-head \\\n        --cuda-graph-bs 8 10 12 14 16 18 20 22 24 --disaggregation-transfer-backend ascend --watchdog-timeout 9000 --context-length 8192 \\\n        --speculative-algorithm NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4  \\\n        --prefill-round-robin-balance --disable-shared-experts-fusion --dtype bfloat16 
--tokenizer-worker-num 4 \\\n        --load-balance-method decode_round_robin\n        NODE_RANK=$i\n        break\n    fi\ndone\n\n```\n\n```shell\nexport SGLANG_DP_ROUND_ROBIN=1\npython -m sglang_router.launch_router \\\n    --pd-disaggregation \\\n    --policy cache_aware \\\n    --prefill http://P_IP:8000 8998 \\\n    --decode http://D_IP:8001 \\\n    --host 127.0.0.1 \\\n    --port 6688 \\\n    --mini-lb\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 384  --random-input-len 3500 --random-output-len 1500 --num-prompts 1536 --random-range-ratio 1\n```\n\n### DeepSeek-V3.2-Exp 64K-3K 30ms on A3 32 Cards Separation Mode\n\nModel: DeepSeek-V3.2-Exp-W8A8\n\nHardware: Atlas 800I A3 32Card\n\nDeployMode: PD Separation\n\nDataset: random\n\nInput Output Length: 64K+3K\n\nTPOT: 30ms\n\n#### Model Deployment\n\nDeploy Prefill Instance\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\n\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH}\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nexport ASCEND_HOME_PATH=/usr/local/Ascend/ascend-toolkit/latest\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\n\nexport HCCL_BUFFSIZE=1024\nexport DEEPEP_NORMAL_LONG_SEQ_ROUND=5\nexport DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS=512\n\nMODEL_PATH=xxx\n\nexport SGLANG_NPU_USE_MLAPO=1\nexport DEEP_NORMAL_MODE_USE_INT8_QUANT=1\nexport 
SGLANG_NPU_USE_MULTI_STREAM=1\nexport HCCL_OP_EXPANSION_MODE=AIV\n\nIPs=('your prefill ip1' 'your prefill ip2')\n\n# get IP in current node\nLOCAL_HOST=`hostname -I|awk -F \" \" '{print$1}'`\necho \"LOCAL_HOST = \" ${LOCAL_HOST}\n# get node index\nfor i in \"${!IPs[@]}\";\ndo\n  echo \"LOCAL_HOST=${LOCAL_HOST}, IPs[${i}]=${IPs[$i]}\"\n  if [ \"$LOCAL_HOST\" == \"${IPs[$i]}\" ]; then\n      echo \"Node Rank : ${i}\"\n      VC_TASK_INDEX=$i\n      break\n  fi\ndone\n\nIFNAMES=('xxx' 'xxx')\n\nexport HCCL_SOCKET_IFNAME=${IFNAMES[$VC_TASK_INDEX]}\nexport GLOO_SOCKET_IFNAME=${HCCL_SOCKET_IFNAME}\necho \"HCCL_SOCKET_IFNAME : ${HCCL_SOCKET_IFNAME}\"\nnnodes=${#IPs[@]}\ntp_size=`expr 16 \\* ${nnodes}`\nexport ASCEND_MF_STORE_URL=tcp://${IPs[0]}:24667\n\npython3 -m sglang.launch_server --model-path ${MODEL_PATH} \\\n--tp $tp_size \\\n--trust-remote-code \\\n--attention-backend ascend \\\n--device npu \\\n--watchdog-timeout 9000 \\\n--host ${IPs[$VC_TASK_INDEX]} --port 8000 \\\n--mem-fraction-static 0.73 \\\n--disable-radix-cache --chunked-prefill-size -1 --max-prefill-tokens 68000 \\\n--max-running-requests 1 \\\n--moe-a2a-backend deepep --deepep-mode normal \\\n--quantization modelslim \\\n--disaggregation-transfer-backend ascend \\\n--disaggregation-mode prefill \\\n--disable-cuda-graph \\\n--nnodes $nnodes --node-rank $VC_TASK_INDEX \\\n--disaggregation-bootstrap-port 8995 \\\n--enable-nsa-prefill-context-parallel  --moe-dense-tp-size 1 \\\n--speculative-algorithm NEXTN --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \\\n--dist-init-addr ${IPs[0]}:10000\n```\n\nDeploy Decode Instance\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource 
/usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH}\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\nexport ASCEND_HOME_PATH=/usr/local/Ascend/ascend-toolkit/latest\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\n\nMODEL_PATH=xxx\n\nexport SGLANG_NPU_USE_MULTI_STREAM=1\nexport SGLANG_NPU_USE_MLAPO=1\nexport HCCL_OP_EXPANSION_MODE=AIV\nexport SGLANG_SCHEDULER_SKIP_ALL_GATHER=1\nexport TASK_QUEUE_ENABLE=0\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\nIPs=('your decode ip1' 'your decode ip2')\n\nexport prefill_ip=\"your prefill ip1\"\n# get IP in current node\nLOCAL_HOST=`hostname -I|awk -F \" \" '{print$1}'`\necho \"LOCAL_HOST = \" ${LOCAL_HOST}\n# get node index\nfor i in \"${!IPs[@]}\";\ndo\n  echo \"LOCAL_HOST=${LOCAL_HOST}, IPs[${i}]=${IPs[$i]}\"\n  if [ \"$LOCAL_HOST\" == \"${IPs[$i]}\" ]; then\n      echo \"Node Rank : ${i}\"\n      VC_TASK_INDEX=$i\n      break\n  fi\ndone\n\nIFNAMES=('xxx' 'xxx')\n\nexport HCCL_SOCKET_IFNAME=${IFNAMES[$VC_TASK_INDEX]}\nexport GLOO_SOCKET_IFNAME=${HCCL_SOCKET_IFNAME}\nnnodes=${#IPs[@]}\ntp_size=`expr 16 \\* ${nnodes}`\nexport ASCEND_MF_STORE_URL=tcp://${prefill_ip}:24667\n\nCHUNKED_SIZE=65536\nDP=8\nexport HCCL_BUFFSIZE=400\nexport SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8\n\npython3 -m sglang.launch_server --model-path ${MODEL_PATH} \\\n--tp $tp_size \\\n--dp ${DP} \\\n--ep $tp_size \\\n--moe-dense-tp-size 1 \\\n--enable-dp-attention \\\n--enable-dp-lm-head \\\n--trust-remote-code \\\n--attention-backend ascend \\\n--device npu \\\n--watchdog-timeout 9000 \\\n--host ${IPs[$VC_TASK_INDEX]} --port 8001 \\\n--mem-fraction-static 0.79 \\\n--disable-radix-cache \\\n--chunked-prefill-size -1 --max-prefill-tokens 68000 \\\n--max-running-requests 32 \\\n--cuda-graph-max-bs 4 
\\\n--moe-a2a-backend deepep \\\n--deepep-mode low_latency \\\n--quantization modelslim \\\n--speculative-algorithm NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \\\n--disaggregation-transfer-backend ascend \\\n--disaggregation-mode decode \\\n--prefill-round-robin-balance \\\n--nnodes $nnodes --node-rank $VC_TASK_INDEX \\\n--dist-init-addr ${IPs[0]}:10000 --load-balance-method decode_round_robin\n```\n\n```shell\nexport SGLANG_DP_ROUND_ROBIN=1\npython -m sglang_router.launch_router \\\n    --pd-disaggregation \\\n    --policy cache_aware \\\n    --prefill http://PIP1:8000 8995 \\\n    --decode http://DIP1:8001 \\\n    --host 127.0.0.1 \\\n    --port 6688 \\\n    --mini-lb\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 32  --random-input-len 64000 --random-output-len 3000 --num-prompts 64 --random-range-ratio 1\n```\n\n### Qwen3-235B-A22B 3_5K-1_5K 50ms on A3 24 Cards Separation Mode\n\nModel: Qwen3-235B-A22B-W8A8\n\nHardware: Atlas 800I A3 24Card\n\nDeployMode: PD Separation\n\nDataset: random\n\nInput Output Length: 3.5K+1.5K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\n\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n\nexport SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=16\n\nMODEL_PATH=xxx\nexport ASCEND_MF_STORE_URL=\"tcp://your prefill ip1:24667\"\nP_IP=('your prefill ip1')\nD_IP=('your decode ip1' 
'your decode ip2')\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\nexport SGLANG_DP_ROUND_ROBIN=1\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\n\nfor i in \"${!P_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${P_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${P_IP[$i]}\" ]];\n    then\n        echo \"${P_IP[$i]}\"\n        source /usr/local/Ascend/ascend-toolkit/set_env.sh\n        source /usr/local/Ascend/nnal/atb/set_env.sh\n        export DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS=1024\n        export DEEPEP_NORMAL_LONG_SEQ_ROUND=16\n        export HCCL_BUFFSIZE=4300\n        export TASK_QUEUE_ENABLE=2\n        export HCCL_SOCKET_IFNAME=lo\n        export GLOO_SOCKET_IFNAME=lo\n        export STREAMS_PER_DEVICE=32\n        export DEEP_NORMAL_MODE_USE_INT8_QUANT=1\n\n        # prefill node\n        python -m sglang.launch_server --model-path ${MODEL_PATH} --disaggregation-mode prefill \\\n        --host ${P_IP[$i]} --port 8000 --disaggregation-bootstrap-port 8995 --trust-remote-code \\\n        --nnodes 1 --node-rank $i --tp-size 16 --dp-size 16 --mem-fraction-static 0.6 \\\n        --disable-radix-cache \\\n        --attention-backend ascend --device npu --quantization modelslim --disaggregation-transfer-backend ascend \\\n        --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n        --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \\\n        --speculative-draft-model-quantization unquant \\\n        --max-running-requests 128 --chunked-prefill-size 262144 --max-prefill-tokens 262144 \\\n        --enable-dp-attention  \\\n        --moe-a2a-backend deepep --deepep-mode normal --dtype bfloat16\n        NODE_RANK=$i\n        break\n    fi\ndone\n\n\nfor i in \"${!D_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${D_IP[$i]}\" || 
\"$LOCAL_HOST2\" == \"${D_IP[$i]}\" ]];\n    then\n        echo \"${D_IP[$i]}\"\n        source /usr/local/Ascend/ascend-toolkit/set_env.sh\n        source /usr/local/Ascend/nnal/atb/set_env.sh\n        export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=24\n        export HCCL_BUFFSIZE=512\n        export HCCL_SOCKET_IFNAME=data0.3001\n        export GLOO_SOCKET_IFNAME=data0.3001\n        export STREAMS_PER_DEVICE=32\n\n        python -m sglang.launch_server --model-path ${MODEL_PATH} --disaggregation-mode decode \\\n        --host ${D_IP[$i]} --port 8001 --trust-remote-code \\\n        --nnodes 2 --node-rank $i --tp-size 32 --dp-size 32 --mem-fraction-static 0.83 --max-running-requests 768 \\\n        --attention-backend ascend --device npu --quantization modelslim --enable-dp-attention \\\n        --moe-a2a-backend ascend_fuseep --cuda-graph-bs 6 8 12 15 18 20 22 24 \\\n        --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n        --speculative-draft-model-quantization unquant \\\n        --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \\\n        --dist-init-addr xxx:5000 \\\n        --disaggregation-transfer-backend ascend --watchdog-timeout 9000 --context-length 8192 \\\n        --prefill-round-robin-balance --enable-dp-lm-head --dtype bfloat16 --tokenizer-worker-num 4 \\\n        --load-balance-method decode_round_robin\n        NODE_RANK=$i\n        break\n    fi\ndone\n\n```\n\n```shell\nexport SGLANG_DP_ROUND_ROBIN=1\npython -m sglang_router.launch_router \\\n    --pd-disaggregation \\\n    --policy cache_aware \\\n    --prefill http://PIP:8000 8995 \\\n    --decode http://DIP:8001 \\\n    --host 127.0.0.1 \\\n    --port 6688 \\\n    --mini-lb\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang-oai --host 127.0.0.1 --port 6688 --max-concurrency 860 --random-input-len 3500 --random-output-len 
1500 --num-prompts 3440 --random-range-ratio 1\n```\n\n### Qwen3-235B-A22B 3_5K-1_5K 50ms on A3 8 Cards Mixed Mode\n\nModel: Qwen3-235B-A22B-W8A8\n\nHardware: Atlas 800I A3 8Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 3.5K+1.5K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=1600\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\nexport SGLANG_SCHEDULER_DECREASE_PREFILL_IDLE=2\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7439 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    --attention-backend ascend --device npu --quantization modelslim  \\\n    --max-running-requests 272 --context-length 8192 --dtype bfloat16 \\\n    --chunked-prefill-size 32768 --max-prefill-tokens 32768 \\\n    --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n    --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \\\n    --disable-radix-cache --moe-a2a-backend 
deepep  --deepep-mode auto --speculative-draft-model-quantization unquant \\\n    --tp 16 --dp-size 16 --enable-dp-attention --enable-dp-lm-head --mem-fraction-static 0.8 --cuda-graph-bs 3 4 6 8 10 12 13 14 15 16 17\n\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7439 --max-concurrency 272 --random-input-len 3500 --random-output-len 1500 --num-prompts 1088 --random-range-ratio 1\n```\n\n### Qwen3-235B-A22B 2K-2K 100ms on A3 8 Cards Mixed Mode\n\nModel: Qwen3-235B-A22B-W8A8\n\nHardware: Atlas 800I A3 8Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 2K+2K\n\nTPOT: 100ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=1200\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7439 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    
--attention-backend ascend --device npu --quantization modelslim  \\\n    --max-running-requests 576 --context-length 8192 --dtype bfloat16 \\\n    --chunked-prefill-size 32768 --max-prefill-tokens 458880  \\\n    --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n    --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \\\n    --disable-radix-cache --moe-a2a-backend deepep  --deepep-mode auto --speculative-draft-model-quantization unquant  \\\n    --tp 16 --dp-size 16 --enable-dp-attention --enable-dp-lm-head --mem-fraction-static 0.81 --cuda-graph-bs 8 16 20 24 32 36\n\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7439 --max-concurrency 576 --random-input-len 2000 --random-output-len 2000 --num-prompts 576 --random-range-ratio 1\n```\n\n### Qwen3-235B-A22B 2K-2K 50ms on A3 8 Cards Mixed Mode\n\nModel: Qwen3-235B-A22B-W8A8\n\nHardware: Atlas 800I A3 8Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 2K+2K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" 
'{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=2100\nexport HCCL_SOCKET_IFNAME=xxx\nexport GLOO_SOCKET_IFNAME=xxx\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7439 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    --attention-backend ascend --device npu --quantization modelslim  \\\n    --max-running-requests 480 --context-length 8192 --dtype bfloat16 \\\n    --chunked-prefill-size -1 --max-prefill-tokens 4096 --speculative-draft-model-quantization unquant  \\\n    --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n    --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \\\n    --disable-radix-cache --moe-a2a-backend deepep  --deepep-mode auto  \\\n    --tp 16 --dp-size 16 --enable-dp-attention --enable-dp-lm-head --mem-fraction-static 0.75 --cuda-graph-bs 6 8 10 12 15 18 28 30\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7439 --max-concurrency 480 --random-input-len 2048 --random-output-len 2048 --num-prompts 480 --random-range-ratio 1\n```\n\n### Qwen3-235B-A22B 2K-2K 50ms on A3 16 Cards Mixed Mode\n\nModel: Qwen3-235B-A22B-W8A8\n\nHardware: Atlas 800I A3 16Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 2K+2K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource 
/usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=1600\nexport HCCL_SOCKET_IFNAME=xxx\nexport GLOO_SOCKET_IFNAME=xxx\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\n\nMIX_IP=('IP1' 'IP2')\n\nfor i in \"${!MIX_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${MIX_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${MIX_IP[$i]}\" ]];\n    then\n        echo \"${MIX_IP[$i]}\"\n        export SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\n        export SGLANG_ENABLE_SPEC_V2=1\n\n        python -m sglang.launch_server --model-path ${MODEL_PATH} \\\n        --host 127.0.0.1 --port 7439 --trust-remote-code \\\n        --nnodes 2 --node-rank $i --tp-size 32 --dp-size 32 --mem-fraction-static 0.8 --max-running-requests 768 \\\n        --attention-backend ascend --device npu --quantization modelslim --enable-dp-attention \\\n        --moe-a2a-backend deepep --deepep-mode auto --cuda-graph-bs 6 8 10 12 18 24 \\\n        --dist-init-addr ${MIX_IP[0]}:5000 --chunked-prefill-size 131072 --max-prefill-tokens 458880 \\\n        --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx --speculative-draft-model-quantization unquant \\\n        --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \\\n        --context-length 8192 --disable-radix-cache \\\n        --enable-dp-lm-head --dtype bfloat16\n        NODE_RANK=$i\n        break\n    fi\ndone\n\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7439 --max-concurrency 768 --random-input-len 2000 --random-output-len 
2000 --num-prompts 768 --random-range-ratio 1\n```\n\n### Qwen3-235B-A22B 11K-1K 10ms on A3 8 Cards Mixed Mode\n\nModel: Qwen3-235B-A22B-W8A8\n\nHardware: Atlas 800I A3 8Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 11K+1K\n\nTPOT: 10ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=1600\nexport HCCL_SOCKET_IFNAME=xxx\nexport GLOO_SOCKET_IFNAME=xxx\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7439 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    --attention-backend ascend --device npu --quantization modelslim  \\\n    --max-running-requests 1  --dtype bfloat16 \\\n    --chunked-prefill-size -1 --max-prefill-tokens 16384 --speculative-draft-model-quantization unquant  \\\n    --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n    --speculative-num-steps 4 --speculative-eagle-topk 1 --speculative-num-draft-tokens 5 \\\n    --disable-radix-cache --enable-dp-lm-head \\\n    --tp 16 
--mem-fraction-static 0.78 --cuda-graph-bs 1\n\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7439 --max-concurrency 1 --random-input-len 11000 --random-output-len 1000 --num-prompts 1 --random-range-ratio 1\n```\n\n### Qwen3-32B 6K-1_5K 18ms on A3 4 Cards Mixed Mode\n\nModel: Qwen3-32B\n\nHardware: Atlas 800I A3 4Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 6K+1.5K\n\nTPOT: 18ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=400\nexport HCCL_SOCKET_IFNAME=xxx\nexport GLOO_SOCKET_IFNAME=xxx\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7439 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    --attention-backend ascend --device npu \\\n    --max-running-requests 32 \\\n    --disable-radix-cache \\\n    --chunked-prefill-size 24576 --max-prefill-tokens 65536 \\\n    --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n    
--speculative-num-steps 4 --speculative-eagle-topk 1 --speculative-num-draft-tokens 5 \\\n    --tp-size 8 --mem-fraction-static 0.72 --cuda-graph-bs 8 16 24 32  --dtype bfloat16\n\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython3 -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7439 --max-concurrency 32 --random-output-len 1500 --random-input-len 6000 --num-prompts 32 --random-range-ratio 1\n```\n\n### Qwen3-32B 4K-1_5K 11ms on A3 4 Cards Mixed Mode\n\nModel: Qwen3-32B\n\nHardware: Atlas 800I A3 4Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 4K+1.5K\n\nTPOT: 11ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=400\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7339 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    --attention-backend ascend --device npu   \\\n    --max-running-requests 1 \\\n    --disable-radix-cache \\\n    
--speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n    --speculative-num-steps 4 --speculative-eagle-topk 1 --speculative-num-draft-tokens 5 \\\n    --chunked-prefill-size 24576 --max-prefill-tokens 65536  \\\n    --tp-size 8 --mem-fraction-static 0.72 --cuda-graph-bs 1 --dtype bfloat16\n\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython3 -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7339 --random-range-ratio 1 --max-concurrency 1 --random-output-len 1500 --random-input-len 4096 --num-prompts 4\n```\n\n### Qwen3-32B 18K-4K 12ms on A3 8 Cards Mixed Mode\n\nModel: Qwen3-32B\n\nHardware: Atlas 800I A3 8Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 18K+4K\n\nTPOT: 12ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=400\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7339 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    
--attention-backend ascend --device npu   \\\n    --max-running-requests 1 \\\n    --disable-radix-cache --speculative-draft-model-quantization unquant \\\n    --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n    --speculative-num-steps 4 --speculative-eagle-topk 1 --speculative-num-draft-tokens 5 \\\n    --chunked-prefill-size -1 --max-prefill-tokens 65536  \\\n    --tp-size 16 --mem-fraction-static 0.72 --cuda-graph-bs 1 --dtype bfloat16\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython3 -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7339 --random-range-ratio 1 --max-concurrency 1 --random-output-len 4000 --random-input-len 18000 --num-prompts 1\n```\n\n### Qwen3-32B 3_5K-1_5K 50ms on A3 2 Cards Mixed Mode\n\nModel: Qwen3-32B\n\nHardware: Atlas 800I A3 2Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 3.5K+1.5K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=400\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport 
SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7239 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    --attention-backend ascend --device npu  --quantization modelslim  \\\n    --max-running-requests 78 \\\n    --disable-radix-cache --speculative-draft-model-quantization unquant \\\n    --chunked-prefill-size -1 --max-prefill-tokens 49152  \\\n    --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n    --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \\\n    --tp-size 4  --mem-fraction-static 0.72 --cuda-graph-bs 16 32 64 68 72 78 --dtype bfloat16\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython3 -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7239 --max-concurrency 78 --random-output-len 1500 --random-input-len 3500 --num-prompts 312 --random-range-ratio 1\n```\n\n### Qwen3-32B 2K-2K 50ms on A3 2 Cards Mixed Mode\n\nModel: Qwen3-32B\n\nHardware: Atlas 800I A3 2Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 2K+2K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho 
\"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=400\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7239 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    --attention-backend ascend --device npu  --quantization modelslim  \\\n    --max-running-requests 120 \\\n    --disable-radix-cache --speculative-draft-model-quantization unquant \\\n    --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n    --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \\\n    --chunked-prefill-size -1 --max-prefill-tokens 49152 \\\n    --tp-size 4 --mem-fraction-static 0.7 --cuda-graph-bs 54 60 66 72 78 84 90 108 114 120 --dtype bfloat16\n\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython3 -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7239 --max-concurrency 120 --random-output-len 2000 --random-input-len 2000 --num-prompts 480 --random-range-ratio 1\n```\n\n### Qwen3-30B-A3B 3_5K-1_5K 50ms on A3 1 Card Mixed Mode\n\nModel: Qwen3-30B-A3B-Instruct-2507\n\nHardware: Atlas 800I A3 1Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 3.5K+1.5K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport 
PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=400\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7239 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    --attention-backend ascend --device npu  --quantization modelslim  \\\n    --max-running-requests 192 \\\n    --disable-radix-cache \\\n    --speculative-draft-model-quantization unquant \\\n    --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n    --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \\\n    --chunked-prefill-size -1 --max-prefill-tokens 32768 \\\n    --tp-size 2 --mem-fraction-static 0.86 --cuda-graph-bs 42 88 96 132 144 156 172 178 192 --dtype bfloat16\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7239 --max-concurrency 156 --random-input-len 3500 --random-output-len 1500 --num-prompts 624 --random-range-ratio 1\n```\n\n### Qwen3-Coder-480B-A35B-Instruct 3_5K-1_5K 50ms on A3 24 Cards Separation Mode\n\nModel: Qwen3-Coder-480B-A35B-Instruct\n\nHardware: Atlas 800I A3 24Card\n\nDeployMode: PD Separation\n\nDataset: random\n\nInput Output Length: 3.5K+1.5K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset 
https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\n\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n\nexport SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=16\n\nMODEL_PATH=xxx\nexport ASCEND_MF_STORE_URL=\"tcp://PIP:24667\"\nP_IP=('PIP')\nD_IP=('DIP1' 'DIP2')\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\n\nfor i in \"${!P_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${P_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${P_IP[$i]}\" ]];\n    then\n        echo \"${P_IP[$i]}\"\n        source /usr/local/Ascend/ascend-toolkit/set_env.sh\n        source /usr/local/Ascend/nnal/atb/set_env.sh\n        export DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS=1024\n        export DEEPEP_NORMAL_LONG_SEQ_ROUND=16\n        export HCCL_BUFFSIZE=4300\n        export TASK_QUEUE_ENABLE=2\n        export HCCL_SOCKET_IFNAME=lo\n        export GLOO_SOCKET_IFNAME=lo\n        export STREAMS_PER_DEVICE=32\n        export DEEP_NORMAL_MODE_USE_INT8_QUANT=1\n\n        python -m sglang.launch_server --model-path ${MODEL_PATH} --disaggregation-mode prefill \\\n        --host ${P_IP[$i]} --port 8000 --disaggregation-bootstrap-port 8995 --trust-remote-code \\\n        --nnodes 1 --node-rank $i --tp-size 16 --dp-size 2 --mem-fraction-static 0.6 \\\n        --disable-radix-cache \\\n\t      --attention-backend ascend --device npu --quantization modelslim --disaggregation-transfer-backend ascend \\\n\t      --max-running-requests 128 --chunked-prefill-size 65536 --max-prefill-tokens 262144 \\\n        --enable-dp-attention  \\\n        --moe-a2a-backend deepep --deepep-mode normal --dtype bfloat16\n        NODE_RANK=$i\n        break\n    fi\ndone\n\nfor i in \"${!D_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == 
\"${D_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${D_IP[$i]}\" ]];\n    then\n        echo \"${D_IP[$i]}\"\n        source /usr/local/Ascend/ascend-toolkit/set_env.sh\n        source /usr/local/Ascend/nnal/atb/set_env.sh\n        export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=72\n        export HCCL_BUFFSIZE=512\n        export HCCL_SOCKET_IFNAME=xxx\n        export GLOO_SOCKET_IFNAME=xxx\n        export STREAMS_PER_DEVICE=32\n\n        python -m sglang.launch_server --model-path ${MODEL_PATH} --disaggregation-mode decode \\\n        --host ${D_IP[$i]} --port 8001 --trust-remote-code \\\n        --nnodes 2 --node-rank $i --tp-size 32 --dp-size 4 --mem-fraction-static 0.73 --max-running-requests 384 \\\n        --attention-backend ascend --device npu --quantization modelslim --enable-dp-attention \\\n        --moe-a2a-backend ascend_fuseep --cuda-graph-bs 16 32 48 56 64 72 80 88 96 \\\n        --dist-init-addr DIP1:5000 \\\n\t      --disaggregation-transfer-backend ascend --watchdog-timeout 9000 --context-length 8192 \\\n        --prefill-round-robin-balance --enable-dp-lm-head --dtype bfloat16 --tokenizer-worker-num 4 --load-balance-method decode_round_robin\n        NODE_RANK=$i\n        break\n    fi\ndone\n\n```\n\n```shell\nexport SGLANG_DP_ROUND_ROBIN=1\npython -m sglang_router.launch_router \\\n    --pd-disaggregation \\\n    --policy cache_aware \\\n    --prefill http://PIP:8000 8995 \\\n    --decode http://DIP1:8001 \\\n    --host 127.0.0.1 \\\n    --port 6688 \\\n    --mini-lb\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 410 --random-input-len 3500 --random-output-len 1500 --num-prompts 1640 --random-range-ratio 1 --request-rate 8\n```\n\n### Qwen3-Coder-480B-A35B-Instruct 3_5K-1_5K 50ms on A3 16 Cards Mixed Mode\n\nModel: Qwen3-Coder-480B-A35B-Instruct\n\nHardware: Atlas 800I A3 16Card\n\nDeployMode: PD 
Mixed\n\nDataset: random\n\nInput Output Length: 3.5K+1.5K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n\nexport SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=16\n\nexport DEEP_NORMAL_MODE_USE_INT8_QUANT=1\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=1800\nexport HCCL_SOCKET_IFNAME=xxx\nexport GLOO_SOCKET_IFNAME=xxx\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\n\nMIX_IP=('IP1' 'IP2')\n\nfor i in \"${!MIX_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${MIX_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${MIX_IP[$i]}\" ]];\n    then\n        echo \"${MIX_IP[$i]}\"\n\n        python -m sglang.launch_server --model-path $MODEL_PATH \\\n        --host 127.0.0.1 --port 7439 --trust-remote-code --nnodes 2 --node-rank $i  \\\n        --dist-init-addr ${MIX_IP[0]}:5000 \\\n        --attention-backend ascend --device npu --quantization modelslim  \\\n        --max-running-requests 288 --context-length 8192 --dtype bfloat16  \\\n        --chunked-prefill-size 114688 --max-prefill-tokens 458880  \\\n        --disable-radix-cache --moe-a2a-backend deepep  --deepep-mode auto  \\\n        --tp 32 --dp-size 4 --enable-dp-attention --enable-dp-lm-head --mem-fraction-static 0.7 --cuda-graph-bs 56 64 72\n        
NODE_RANK=$i\n        break\n    fi\ndone\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7439 --max-concurrency 288 --random-input-len 3500 --random-output-len 1500 --num-prompts 1152 --random-range-ratio 1 --request-rate 20\n```\n\n### Qwen3-Coder-480B-A35B-Instruct 3_5K-1_5K 50ms on A3 8 Cards Mixed Mode\n\nModel: Qwen3-Coder-480B-A35B-Instruct\n\nHardware: Atlas 800I A3 8Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 3.5K+1.5K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=2100\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n--host 127.0.0.1 --port 7439 --trust-remote-code --nnodes 1 --node-rank 0  \\\n--attention-backend ascend --device npu --quantization modelslim  \\\n--max-running-requests 80 --context-length 8192 --dtype bfloat16 \\\n--chunked-prefill-size 28672 --max-prefill-tokens 458880  
\\\n--disable-radix-cache --moe-a2a-backend deepep  --deepep-mode auto --enable-dp-attention --enable-dp-lm-head \\\n--tp 16 --dp-size 4 --mem-fraction-static 0.7 --cuda-graph-bs  16 20 24\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7439 --max-concurrency 80 --random-input-len 3500 --random-output-len 1500 --num-prompts 320 --random-range-ratio 1\n```\n\n### Qwen3-Next-80B-A3B-Instruct 3_5K-1_5K 50ms on A3 2 Cards Mixed Mode\n\nModel: Qwen3-Next-80B-A3B-Instruct\n\nHardware: Atlas 800I A3 2Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 3.5K+1.5K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\nexport cann_path=/usr/local/Ascend/ascend-toolkit/latest\nsource /usr/local/Ascend/driver/bin/setenv.bash\nsource ${cann_path}/../set_env.sh\nsource ${cann_path}/../../nnal/atb/set_env.sh\nsource ${cann_path}/opp/vendors/customize/bin/set_env.bash\nexport ASCEND_HOME_PATH=${cann_path}\nsource /usr/local/Ascend/8.5.0/bisheng_toolkit/set_env.sh\n\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\n\nexport HCCL_OP_EXPANSION_MODE=AIV\nexport HCCL_ALGO=\"level0:NA;level1:ring\"\n\nexport SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=20\nexport HCCL_BUFFSIZE=2000\n\npython -m sglang.launch_server \\\n        --model-path /path/to/Qwen3-Next-80B-A3B-Instruct-W8A8-3 \\\n        --host 127.0.0.1 \\\n        --port 6699 \\\n        --tp-size 4 \\\n        --device npu \\\n        --attention-backend ascend \\\n        --mem-fraction-static 0.685 \\\n        --max-running-requests 80 \\\n        
--watchdog-timeout 3600 \\\n        --disable-radix-cache \\\n        --cuda-graph-bs 80 \\\n        --max-prefill-tokens 28672  --max-total-tokens 450560 \\\n        --moe-a2a-backend deepep --deepep-mode auto \\\n        --quantization modelslim \\\n        --chunked-prefill-size -1\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython3 -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6699 --max-concurrency 80 --random-output-len 1536 --random-input-len 3584 --num-prompts 160 --random-range-ratio 1\n```\n\n### Qwen3-32B 6K-1_5K 18ms on A2 8 Cards Mixed Mode\n\nModel: Qwen3-32B\n\nHardware: Atlas 800I A2 8Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 6K+1.5K\n\nTPOT: 18ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=400\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7439 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    
--attention-backend ascend --device npu  --quantization modelslim  \\\n    --max-running-requests 32 \\\n    --disable-radix-cache \\\n    --chunked-prefill-size 24576 --max-prefill-tokens 65536 \\\n    --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n    --speculative-num-steps 4 --speculative-eagle-topk 1 --speculative-num-draft-tokens 5 \\\n    --tp-size 8 --mem-fraction-static 0.72 --cuda-graph-bs 8 16 24 32 --dtype bfloat16\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython3 -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7439 --max-concurrency 32 --random-output-len 1500 --random-input-len 6000 --num-prompts 32 --random-range-ratio 1\n```\n\n### Qwen3-32B 4K-1_5K 11ms on A2 8 Cards Mixed Mode\n\nModel: Qwen3-32B\n\nHardware: Atlas 800I A2 8Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 4K+1.5K\n\nTPOT: 11ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=400\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport 
SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7339 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    --attention-backend ascend --device npu   \\\n    --max-running-requests 32 \\\n    --disable-radix-cache \\\n    --speculative-draft-model-quantization unquant \\\n    --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx  \\\n    --speculative-num-steps 4 --speculative-eagle-topk 1 --speculative-num-draft-tokens 5 \\\n    --chunked-prefill-size -1 --max-prefill-tokens 65536  \\\n    --tp-size 8 --mem-fraction-static 0.72 --cuda-graph-bs 1 4 6 12 18 24 30 32 --dtype bfloat16\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython3 -m sglang.bench_serving  --dataset-name random --backend sglang --host 127.0.0.1 --port 7339 --random-range-ratio 1 --max-concurrency 1 --random-output-len 1500 --random-input-len 4096 --num-prompts 4\n```\n\n### Qwen3-32B 3_5K-1_5K 50ms on A2 8 Cards Mixed Mode\n\nModel: Qwen3-32B\n\nHardware: Atlas 800I A2 8Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 3.5K+1.5K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" 
'{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=400\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7239 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    --attention-backend ascend --device npu  --quantization modelslim  \\\n    --max-running-requests 78 \\\n    --disable-radix-cache --speculative-draft-model-quantization unquant \\\n    --chunked-prefill-size -1 --max-prefill-tokens 65536  \\\n    --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n    --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \\\n    --tp-size 4  --mem-fraction-static 0.72 --cuda-graph-bs 1 4 8 16 32 64 68 72 78 --dtype bfloat16 --base-gpu-id 4\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython3 -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7239 --max-concurrency 78 --random-output-len 1500 --random-input-len 3500 --num-prompts 312 --random-range-ratio 1\n```\n\n### Qwen3-32B 2K-2K 50ms on A2 8 Cards Mixed Mode\n\nModel: Qwen3-32B\n\nHardware: Atlas 800I A2 8Card\n\nDeployMode: PD Mixed\n\nDataset: random\n\nInput Output Length: 2K+2K\n\nTPOT: 50ms\n\n#### Model Deployment\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n\nexport SGLANG_SET_CPU_AFFINITY=1\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nsource /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash\nexport 
PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nMODEL_PATH=xxx\n\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\n\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n\nexport HCCL_BUFFSIZE=400\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\nexport HCCL_OP_EXPANSION_MODE=\"AIV\"\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server --model-path $MODEL_PATH \\\n    --host 127.0.0.1 --port 7239 --trust-remote-code --nnodes 1 --node-rank 0  \\\n    --attention-backend ascend --device npu  --quantization modelslim  \\\n    --max-running-requests 120 \\\n    --disable-radix-cache \\\n    --speculative-algorithm EAGLE3 --speculative-draft-model-path xxx \\\n    --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --speculative-draft-model-quantization unquant \\\n    --chunked-prefill-size -1 --max-prefill-tokens 49152 --base-gpu-id 4 \\\n    --tp-size 4 --mem-fraction-static 0.7 --cuda-graph-bs 54 60 66 72 78 84 90 108 114 120 --dtype bfloat16\n```\n\n#### Benchmark\n\nWe tested it based on the `RANDOM` dataset.\n\n```shell\npython3 -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7239 --max-concurrency 120 --random-output-len 2000 --random-input-len 2000 --num-prompts 120 --random-range-ratio 1\n```\n"
  },
  {
    "path": "docs/platforms/ascend_npu_deepseek_example.md",
    "content": "## DeepSeek examples\n\n### Running DeepSeek-V3\n\n#### Running DeepSeek in PD mixed mode on 1 x Atlas 800I A3.\n\nW4A8 model weights can be found [here](https://modelers.cn/models/Modelers_Park/DeepSeek-R1-0528-w4a8).\n\n```shell\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\n\n#Deepep communication settings\nexport DEEP_NORMAL_MODE_USE_INT8_QUANT=1\nexport SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=32\nexport HCCL_BUFFSIZE=1600\n\n#spec overlap\nexport SGLANG_ENABLE_SPEC_V2=1\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\n\n#npu acceleration operator\nexport SGLANG_NPU_USE_MLAPO=1\nexport SGLANG_USE_FIA_NZ=1\n\npython3 -m sglang.launch_server \\\n    --model-path ${MODEL_PATH} \\\n    --tp 16 \\\n    --trust-remote-code \\\n    --attention-backend ascend \\\n    --device npu \\\n    --quantization modelslim \\\n    --watchdog-timeout 9000 \\\n    --cuda-graph-bs 8 16 24 28 32 \\\n    --mem-fraction-static 0.68 \\\n    --max-running-requests 128 \\\n    --context-length 8188 \\\n    --disable-radix-cache \\\n    --chunked-prefill-size -1 \\\n    --max-prefill-tokens 16384 \\\n    --moe-a2a-backend deepep \\\n    --deepep-mode auto \\\n    --enable-dp-attention \\\n    --dp-size 4 \\\n    --enable-dp-lm-head \\\n    --speculative-algorithm NEXTN \\\n    --speculative-num-steps 3 \\\n    --speculative-eagle-topk 1 \\\n    --speculative-num-draft-tokens 4 \\\n    --dtype bfloat16\n```\n\n#### Running DeepSeek with PD disaggregation mode on 2 x Atlas 800I A3.\n\nW4A8 model weights can be found [here](https://modelers.cn/models/Modelers_Park/DeepSeek-R1-0528-w4a8).\n\n1. 
Prefill:\n\n```shell\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\n\n#memfabric config store\nexport ASCEND_MF_STORE_URL=\"tcp://<PREFILL_HOST_IP>:<PORT>\"\n\n#Deepep communication settings\nexport DEEP_NORMAL_MODE_USE_INT8_QUANT=1\nexport HCCL_BUFFSIZE=1536\n\n#npu acceleration operator\nexport SGLANG_NPU_USE_MLAPO=1\nexport SGLANG_USE_FIA_NZ=1\nexport TASK_QUEUE_ENABLE=2\n\npython -m sglang.launch_server \\\n    --model-path ${MODEL_PATH} \\\n    --host $PREFILL_HOST_IP \\\n    --port 8000 \\\n    --disaggregation-mode prefill \\\n    --disaggregation-bootstrap-port 8996 \\\n    --disaggregation-transfer-backend ascend \\\n    --trust-remote-code \\\n    --nnodes 1 \\\n    --node-rank 0 \\\n    --tp-size 16 \\\n    --mem-fraction-static 0.6 \\\n    --attention-backend ascend \\\n    --device npu \\\n    --quantization modelslim \\\n    --load-balance-method round_robin \\\n    --max-running-requests 8 \\\n    --context-length 8192 \\\n    --disable-radix-cache \\\n    --chunked-prefill-size -1 \\\n    --max-prefill-tokens 28680 \\\n    --moe-a2a-backend deepep \\\n    --deepep-mode normal \\\n    --speculative-algorithm NEXTN \\\n    --speculative-num-steps 3 \\\n    --speculative-eagle-topk 1 \\\n    --speculative-num-draft-tokens 4 \\\n    --dp-size 2 \\\n    --enable-dp-attention \\\n    --disable-shared-experts-fusion \\\n    --dtype bfloat16\n```\n\n2. 
Decode:\n\n```shell\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\n\n#memfabric config store\nexport ASCEND_MF_STORE_URL=\"tcp://<PREFILL_HOST_IP>:<PORT>\"\n\n#Deepep communication settings\nexport HCCL_BUFFSIZE=720\nexport SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=88\n\n#spec overlap\nexport SGLANG_ENABLE_SPEC_V2=1\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\n\n#npu acceleration operator\nunset TASK_QUEUE_ENABLE\nexport SGLANG_NPU_USE_MLAPO=1\nexport SGLANG_USE_FIA_NZ=1\n\n# Keep max-running-requests <= max(cuda-graph-bs) * dp-size; performance degrades significantly when this value is exceeded.\npython -m sglang.launch_server \\\n    --model-path ${MODEL_PATH} \\\n    --disaggregation-mode decode \\\n    --host $DECODE_HOST_IP \\\n    --port 8001 \\\n    --trust-remote-code \\\n    --nnodes 1 \\\n    --node-rank 0 \\\n    --tp-size 16 \\\n    --dp-size 16 \\\n    --mem-fraction-static 0.8 \\\n    --max-running-requests 352 \\\n    --attention-backend ascend \\\n    --device npu \\\n    --quantization modelslim \\\n    --prefill-round-robin-balance \\\n    --moe-a2a-backend deepep \\\n    --enable-dp-attention \\\n    --deepep-mode low_latency \\\n    --enable-dp-lm-head \\\n    --cuda-graph-bs 8 10 12 14 16 18 20 22 \\\n    --disaggregation-transfer-backend ascend \\\n    --watchdog-timeout 9000 \\\n    --context-length 8192 \\\n    --speculative-algorithm NEXTN \\\n    --speculative-num-steps 3 \\\n    --speculative-eagle-topk 1 \\\n    --speculative-num-draft-tokens 4 \\\n    --disable-shared-experts-fusion \\\n    --dtype bfloat16 \\\n    --tokenizer-worker-num 4\n```\n\n3. 
SGLang Model Gateway (formerly Router):\n\n```shell\npython -m sglang_router.launch_router \\\n    --pd-disaggregation \\\n    --policy cache_aware \\\n    --prefill http://<PREFILL_HOST_IP>:8000 8996 \\\n    --decode http://<DECODE_HOST_IP>:8001 \\\n    --host 127.0.0.1 \\\n    --port 6688\n```\n\n#### Running DeepSeek with PD disaggregation on 4 x Atlas 800I A3.\n\nW8A8 model weights can be found [here](https://modelers.cn/models/State_Cloud/Deepseek-R1-bf16-hfd-w8a8).\n\n1. Prefill & Decode:\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\nexport SGLANG_SET_CPU_AFFINITY=1\nunset ASCEND_LAUNCH_BLOCKING\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH\n\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\n\nexport ASCEND_MF_STORE_URL=\"tcp://your prefill ip1:24669\"\n\nP_IP=('your prefill ip1' 'your prefill ip2')\n\nD_IP=('your decode ip1' 'your decode ip2')\n\nMODEL_PATH=xxx\n\nexport SGLANG_NPU_USE_MLAPO=1\nexport SGLANG_USE_FIA_NZ=1\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\necho \"${LOCAL_HOST1}\"\necho \"${LOCAL_HOST2}\"\n# prefill\nfor i in \"${!P_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${P_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${P_IP[$i]}\" ]];\n    then\n        echo \"${P_IP[$i]}\"\n        export HCCL_BUFFSIZE=1536\n        export DEEP_NORMAL_MODE_USE_INT8_QUANT=1\n        export TASK_QUEUE_ENABLE=2\n\n        export HCCL_SOCKET_IFNAME=lo\n        export GLOO_SOCKET_IFNAME=lo\n        python -m sglang.launch_server --model-path ${MODEL_PATH}  --disaggregation-mode prefill --host ${P_IP[$i]} \\\n        --port 8000 --disaggregation-bootstrap-port $((8998+$i)) --trust-remote-code --nnodes 1 
--node-rank 0 \\\n        --tp-size 16 --mem-fraction-static 0.81 --attention-backend ascend --device npu --quantization modelslim \\\n        --disaggregation-transfer-backend ascend --max-running-requests 8 --context-length 8192  --disable-radix-cache \\\n        --chunked-prefill-size -1 --max-prefill-tokens 28680 --moe-a2a-backend deepep --deepep-mode normal \\\n        --speculative-algorithm NEXTN --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2  \\\n        --dp-size 2 --enable-dp-attention --disable-shared-experts-fusion --dtype bfloat16 --enable-attn-tp-input-scattered\n        NODE_RANK=$i\n        break\n    fi\ndone\n\n# decode\nfor i in \"${!D_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${D_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${D_IP[$i]}\" ]];\n    then\n        echo \"${D_IP[$i]}\"\n        export SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\n        export SGLANG_ENABLE_SPEC_V2=1\n        export HCCL_BUFFSIZE=650\n        export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=78\n        export TASK_QUEUE_ENABLE=1\n        export SGLANG_SCHEDULER_SKIP_ALL_GATHER=1\n        export HCCL_SOCKET_IFNAME=xxx\n        export GLOO_SOCKET_IFNAME=xxx\n        python -m sglang.launch_server --model-path ${MODEL_PATH} --disaggregation-mode decode --host ${D_IP[$i]} \\\n        --port 8001 --trust-remote-code --dist-init-addr ${D_IP[0]}:5000 --nnodes 2 --node-rank $i --tp-size 32 --dp-size 32 \\\n        --mem-fraction-static 0.815 --max-running-requests 832 --attention-backend ascend --device npu --quantization modelslim \\\n        --moe-a2a-backend deepep --enable-dp-attention --deepep-mode low_latency --enable-dp-lm-head --moe-dense-tp 1 \\\n        --cuda-graph-bs 12 14 16 18 20 22 24 26 --disaggregation-transfer-backend ascend --watchdog-timeout 9000 --context-length 8192 \\\n        --speculative-algorithm NEXTN --speculative-num-steps 2 --speculative-eagle-topk 1 --speculative-num-draft-tokens 3  \\\n        
--tokenizer-worker-num 4 --prefill-round-robin-balance --disable-shared-experts-fusion --dtype bfloat16 \\\n        --load-balance-method decode_round_robin\n        NODE_RANK=$i\n        break\n    fi\ndone\n```\n\n2. SGLang Model Gateway (formerly Router):\n\n```shell\nexport SGLANG_DP_ROUND_ROBIN=1\npython -m sglang_router.launch_router \\\n    --pd-disaggregation \\\n    --policy cache_aware \\\n    --prefill http://P_IP:8000 8998 \\\n    --prefill http://P_IP:8000 8999 \\\n    --decode http://D_IP:8001 \\\n    --host 127.0.0.1 \\\n    --port 6688 \\\n    --mini-lb\n```\n\n#### Test GSM8K\n\n```python\nfrom types import SimpleNamespace\n\nfrom sglang.test.few_shot_gsm8k import run_eval\n\n\ndef gsm8k():\n    # 5-shot GSM8K accuracy check against the gateway started above (port 6688)\n    args = SimpleNamespace(\n        num_shots=5,\n        data_path=None,\n        num_questions=200,\n        max_new_tokens=512,\n        parallel=32,\n        host=\"http://127.0.0.1\",\n        port=6688,\n    )\n    metrics = run_eval(args)\n    print(f\"{metrics=}\")\n    print(f\"{metrics['accuracy']=}\")\n\n\nif __name__ == \"__main__\":\n    gsm8k()\n```\n"
  },
  {
    "path": "docs/platforms/ascend_npu_environment_variables.md",
    "content": "# Environment Variables\n\nSGLang supports various environment variables related to Ascend NPU that can be used to configure its runtime behavior.\nThis document provides a list of commonly used environment variables and aims to stay updated over time.\n\n## Directly Used in SGLang\n\n| Environment Variable                             | Description                                                                                                                                                 | Default Value |\n|--------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------|\n| `SGLANG_NPU_USE_MLAPO`                           | Adopts the `MLAPO` fusion operator in attention <br/> preprocessing stage of the MLA model.                                                                 | `false`       |\n| `SGLANG_USE_FIA_NZ`                              | Reshapes KV Cache for FIA NZ format.<br/> `SGLANG_USE_FIA_NZ` must be enabled with `SGLANG_NPU_USE_MLAPO`                                                   | `false`       |\n| `SGLANG_NPU_USE_MULTI_STREAM`                    | Enable dual-stream computation of shared experts <br/> and routing experts in DeepSeek models.<br/> Enable dual-stream computation in DeepSeek NSA Indexer. | `false`       |\n| `SGLANG_NPU_DISABLE_ACL_FORMAT_WEIGHT`           | Disable cast model weight tensor to a specific NPU <br/> ACL format.                                                                                        | `false`       |\n| `SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK` | The maximum number of dispatched tokens on each rank.                                                                                                       
| `128`         |\n\n## Used in DeepEP Ascend\n\n| Environment Variable                      | Description                                                                                                            | Default Value |\n|-------------------------------------------|------------------------------------------------------------------------------------------------------------------------|---------------|\n| `DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS` | Enable ant-moving function in dispatch stage. Indicates <br/> the number of tokens transmitted per round on each rank. | `8192`        |\n| `DEEPEP_NORMAL_LONG_SEQ_ROUND`            | Enable ant-moving function in dispatch stage. Indicates <br/> the number of rounds transmitted on each rank.           | `1`           |\n| `DEEPEP_NORMAL_COMBINE_ENABLE_LONG_SEQ`   | Enable ant-moving function in combine stage. <br/> The value `0` means disabled.                                       | `0`           |\n| `MOE_ENABLE_TOPK_NEG_ONE`                 | Needs to be enabled when the expert ID to be processed by <br/> DEEPEP contains -1.                                    | `0`           |\n| `DEEP_NORMAL_MODE_USE_INT8_QUANT`         | Quantizes x to int8 and returns (tensor, scales) in dispatch operator.                                                 
| `0`           |\n\n## Others\n\n| Environment Variable     | Description                                                                                                                                                                                                                                                                | Default Value |\n|--------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------|\n| `TASK_QUEUE_ENABLE`      | Used to control the optimization level of the dispatch queue<br/> about the task_queue operator. [Detail](https://www.hiascend.com/document/detail/zh/Pytorch/730/comref/Envvariables/docs/zh/environment_variable_reference/TASK_QUEUE_ENABLE.md)                         | `1`           |\n| `INF_NAN_MODE_ENABLE`    | Controls whether the chip uses saturation mode or INF_NAN mode. [Detail](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/800alpha001/apiref/envref/envref_07_0056.html)                                                                                   | `1`           |\n| `STREAMS_PER_DEVICE`     | Configures the maximum number of streams for the stream pool. [Detail](https://www.hiascend.com/document/detail/zh/Pytorch/720/comref/Envvariables/Envir_041.html)                                                                                                         | `32`          |\n| `PYTORCH_NPU_ALLOC_CONF` | Controls the behavior of the cache allocator. <br/>This variable changes memory usage and may cause performance fluctuations. 
[Detail](https://www.hiascend.com/document/detail/zh/Pytorch/700/comref/Envvariables/Envir_012.html)                                         |               |\n| `ASCEND_MF_STORE_URL`    | The address of the config store in MemFabric during PD disaggregation, <br/>which is generally set to the IP address of the primary prefill node<br/> with an arbitrary port number.                                                                                       |               |\n| `ASCEND_LAUNCH_BLOCKING` | Controls whether synchronous mode is enabled during operator execution. [Detail](https://www.hiascend.com/document/detail/zh/Pytorch/710/comref/Envvariables/Envir_006.html)                                                                                               | `0`           |\n| `HCCL_OP_EXPANSION_MODE` | Configures the expansion position for communication algorithm scheduling. [Detail](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/800alpha001/apiref/envref/envref_07_0094.html)                                                                         |               |\n| `HCCL_BUFFSIZE`          | Controls the size of the buffer area for shared data between two NPUs. <br/>The unit is MB, and the value must be greater than or equal to 1. [Detail](https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/ptmoddevg/trainingmigrguide/performance_tuning_0047.html) | `200`         |\n| `HCCL_SOCKET_IFNAME`     | Configures the name of the network card used by the Host <br/>during HCCL initialization. [Detail](https://www.hiascend.com/document/detail/zh/canncommercial/81RC1/apiref/envvar/envref_07_0075.html)                                                                     |               |\n| `GLOO_SOCKET_IFNAME`     | Configures the network interface name for GLOO communication.                                                                                                                                                                               
                               |               |\n"
  },
  {
    "path": "docs/platforms/ascend_npu_glm5_examples.md",
    "content": "# GLM-5 examples\n\n## Introduction\n\nThe GLM (General Language Model) series is an open-source bilingual large language model family jointly developed by the KEG Laboratory of Tsinghua University and Zhipu AI. The series has performed strongly on Chinese NLP tasks thanks to its unified pre-training framework and bilingual capabilities. [GLM-5](https://huggingface.co/zai-org/GLM-5) adopts the DeepSeek-V3/V3.2 architecture, including sparse attention (DSA) and multi-token prediction (MTP). Ascend provides Day-0 support for GLM-5 on the SGLang inference framework. This support requires minimal code changes and remains compatible with the mainstream distributed parallelism features of the current SGLang framework. Developers are welcome to download and try it.\n\n## Environment Preparation\n\n### Model Weight\n\n- `GLM-5.0` (BF16 version): [Download model weight](https://www.modelscope.cn/models/ZhipuAI/GLM-5).\n- `GLM-5.0-w4a8` (quantized version without MTP): [Download model weight](https://modelers.cn/models/Eco-Tech/GLM-5-w4a8).\n- You can use [msmodelslim](https://gitcode.com/Ascend/msmodelslim) to quantize the model yourself.\n\n### Installation\n\nThe dependencies required for the NPU runtime environment have been integrated into a Docker image and uploaded to the Ascend platform. 
You can directly pull it.\n\n```{code-block} bash\n#Atlas 800 A3\ndocker pull swr.cn-southwest-2.myhuaweicloud.com/base_image/dockerhub/lmsysorg/sglang:cann8.5.0-a3-glm5\n#Atlas 800 A2\ndocker pull swr.cn-southwest-2.myhuaweicloud.com/base_image/dockerhub/lmsysorg/sglang:cann8.5.0-910b-glm5\n\n#start container\ndocker run -itd --shm-size=16g --privileged=true --name ${NAME} \\\n--privileged=true --net=host \\\n-v /var/queue_schedule:/var/queue_schedule \\\n-v /etc/ascend_install.info:/etc/ascend_install.info \\\n-v /usr/local/sbin:/usr/local/sbin \\\n-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \\\n-v /usr/local/Ascend/firmware:/usr/local/Ascend/firmware \\\n--device=/dev/davinci0:/dev/davinci0  \\\n--device=/dev/davinci1:/dev/davinci1  \\\n--device=/dev/davinci2:/dev/davinci2  \\\n--device=/dev/davinci3:/dev/davinci3  \\\n--device=/dev/davinci4:/dev/davinci4  \\\n--device=/dev/davinci5:/dev/davinci5  \\\n--device=/dev/davinci6:/dev/davinci6  \\\n--device=/dev/davinci7:/dev/davinci7  \\\n--device=/dev/davinci8:/dev/davinci8  \\\n--device=/dev/davinci9:/dev/davinci9  \\\n--device=/dev/davinci10:/dev/davinci10  \\\n--device=/dev/davinci11:/dev/davinci11  \\\n--device=/dev/davinci12:/dev/davinci12  \\\n--device=/dev/davinci13:/dev/davinci13  \\\n--device=/dev/davinci14:/dev/davinci14  \\\n--device=/dev/davinci15:/dev/davinci15  \\\n--device=/dev/davinci_manager:/dev/davinci_manager \\\n--device=/dev/hisi_hdc:/dev/hisi_hdc \\\n--entrypoint=bash \\\nswr.cn-southwest-2.myhuaweicloud.com/base_image/dockerhub/lmsysorg/sglang:${TAG}\n```\n\nNote: When using this image, you need to update transformers to the main branch:\n\n```shell\n# reinstall transformers\npip install git+https://github.com/huggingface/transformers.git\n```\n\n## Deployment\n\n### Single-node Deployment\n\n- The quantized model `glm5_w4a8` can be deployed on 1 Atlas 800 A3 (64G × 16).\n\nRun the following script to execute online inference.\n\n```shell\n# high performance cpu\necho performance | tee 
/sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n# bind cpu\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\n# cann\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\n\nexport STREAMS_PER_DEVICE=32\nexport SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\nexport SGLANG_ENABLE_SPEC_V2=1\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_NPU_USE_MULTI_STREAM=1\nexport HCCL_BUFFSIZE=1000\nexport HCCL_OP_EXPANSION_MODE=AIV\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\n\npython3 -m sglang.launch_server \\\n        --model-path $MODEL_PATH \\\n        --attention-backend ascend \\\n        --device npu \\\n        --tp-size 16 --nnodes 1 --node-rank 0 \\\n        --chunked-prefill-size 16384 --max-prefill-tokens 280000 \\\n        --trust-remote-code \\\n        --host 127.0.0.1 \\\n        --mem-fraction-static 0.7 \\\n        --port 8000 \\\n        --served-model-name glm-5 \\\n        --cuda-graph-bs 16 \\\n        --quantization modelslim \\\n        --moe-a2a-backend deepep --deepep-mode auto\n```\n\n### Multi-node Deployment\n\n- `GLM-5-bf16`: requires at least 2 Atlas 800 A3 (64G × 16).\n\n**A3 series**\n\nModify the IPs of the 2 nodes, then run the same script on both nodes.\n\n**node 0/1**\n\n```shell\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n# bind cpu\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\n# cann\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\n\nexport STREAMS_PER_DEVICE=32\nexport 
SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600\nexport SGLANG_ENABLE_SPEC_V2=1\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_NPU_USE_MULTI_STREAM=1\nexport HCCL_BUFFSIZE=1000\nexport HCCL_OP_EXPANSION_MODE=AIV\n\n# Run ifconfig on both nodes and find the interface whose inet addr matches your node IP. That is your public interface, which should be set here.\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\n\nP_IP=('your ip1' 'your ip2')\nP_MASTER=\"${P_IP[0]}:your port\"\n\nLOCAL_HOST1=`hostname -I|awk -F \" \" '{print$1}'`\nLOCAL_HOST2=`hostname -I|awk -F \" \" '{print$2}'`\nfor i in \"${!P_IP[@]}\";\ndo\n    if [[ \"$LOCAL_HOST1\" == \"${P_IP[$i]}\" || \"$LOCAL_HOST2\" == \"${P_IP[$i]}\" ]];\n    then\n        echo \"${P_IP[$i]}\"\n        python3 -m sglang.launch_server \\\n        --model-path $MODEL_PATH \\\n        --attention-backend ascend \\\n        --device npu \\\n        --tp-size 32 --nnodes 2 --node-rank $i --dist-init-addr $P_MASTER \\\n        --chunked-prefill-size 16384 --max-prefill-tokens 131072 \\\n        --trust-remote-code \\\n        --host 127.0.0.1 \\\n        --mem-fraction-static 0.8 \\\n        --port 8000 \\\n        --served-model-name glm-5 \\\n        --cuda-graph-max-bs 16 \\\n        --disable-radix-cache\n        NODE_RANK=$i\n        break\n    fi\ndone\n```\n\n### Prefill-Decode Disaggregation\n\nNot tested yet.\n\n### Using Benchmark\n\nRefer to [Benchmark and Profiling](../developer_guide/benchmark_and_profiling.md) for details.\n"
  },
  {
    "path": "docs/platforms/ascend_npu_quantization.md",
    "content": "# Quantization on Ascend\n\nTo load an already quantized model, simply load its weights and config. If the model has been quantized offline, there is no need to add the `--quantization` argument when starting the engine: the quantization method is parsed automatically from the downloaded `quant_model_description.json` or `config.json`.\n\n[ModelSlim on Ascend support](https://github.com/sgl-project/sglang/pull/14504):\n- [x] W4A4 dynamic linear\n- [x] W8A8 static linear\n- [x] W8A8 dynamic linear\n- [x] W4A4 dynamic MOE\n- [x] W4A8 dynamic MOE\n- [x] W8A8 dynamic MOE\n\n[AWQ on Ascend support](https://github.com/sgl-project/sglang/pull/10158):\n- [x] W4A16 linear\n- [x] W8A16 linear (needs testing)\n- [x] W4A16 MOE (needs testing)\n\nCompressed-tensors (LLM Compressor) on Ascend support:\n- [x] [W4A8 dynamic MOE with/without activation clip](https://github.com/sgl-project/sglang/pull/14736) (needs testing)\n- [x] [W4A16 MOE](https://github.com/sgl-project/sglang/pull/12759)\n- [x] [W8A8 dynamic linear](https://github.com/sgl-project/sglang/pull/14504)\n- [x] [W8A8 dynamic MOE](https://github.com/sgl-project/sglang/pull/14504)\n\nDiffusion model [modelslim](https://github.com/sgl-project/sglang/pull/17996) quantization on Ascend support:\n- [x] W4A4 dynamic linear\n- [x] W8A8 static linear\n- [x] W8A8 dynamic linear\n"
  },
  {
    "path": "docs/platforms/ascend_npu_qwen3_5_examples.md",
    "content": "# Qwen3.5 examples\n\n## Environment Preparation\n\n### Installation\n\nThe dependencies required for the NPU runtime environment have been integrated into a Docker image and uploaded to the quay.io platform. You can directly pull it.\n\n```{code-block} bash\n#Atlas 800 A3\ndocker pull quay.io/ascend/sglang:main-cann8.5.0-a3\n#Atlas 800 A2\ndocker pull quay.io/ascend/sglang:main-cann8.5.0-910b\n\n#start container\ndocker run -itd --shm-size=16g --privileged=true --name ${NAME} \\\n--privileged=true --net=host \\\n-v /var/queue_schedule:/var/queue_schedule \\\n-v /etc/ascend_install.info:/etc/ascend_install.info \\\n-v /usr/local/sbin:/usr/local/sbin \\\n-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \\\n-v /usr/local/Ascend/firmware:/usr/local/Ascend/firmware \\\n--device=/dev/davinci0:/dev/davinci0  \\\n--device=/dev/davinci1:/dev/davinci1  \\\n--device=/dev/davinci2:/dev/davinci2  \\\n--device=/dev/davinci3:/dev/davinci3  \\\n--device=/dev/davinci4:/dev/davinci4  \\\n--device=/dev/davinci5:/dev/davinci5  \\\n--device=/dev/davinci6:/dev/davinci6  \\\n--device=/dev/davinci7:/dev/davinci7  \\\n--device=/dev/davinci8:/dev/davinci8  \\\n--device=/dev/davinci9:/dev/davinci9  \\\n--device=/dev/davinci10:/dev/davinci10  \\\n--device=/dev/davinci11:/dev/davinci11  \\\n--device=/dev/davinci12:/dev/davinci12  \\\n--device=/dev/davinci13:/dev/davinci13  \\\n--device=/dev/davinci14:/dev/davinci14  \\\n--device=/dev/davinci15:/dev/davinci15  \\\n--device=/dev/davinci_manager:/dev/davinci_manager \\\n--device=/dev/hisi_hdc:/dev/hisi_hdc \\\n--entrypoint=bash \\\nquay.io/ascend/sglang:${tag}\n```\n\n## Deployment\n\n### Single-node Deployment\n\nRun the following script to execute online inference.\n\n#### Qwen3.5 397B\n\n```shell\n# high performance cpu\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n# bind 
cpu\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\n# cann\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\n\nexport STREAMS_PER_DEVICE=32\nexport HCCL_BUFFSIZE=1000\nexport HCCL_OP_EXPANSION_MODE=AIV\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\n\npython3 -m sglang.launch_server \\\n        --model-path $MODEL_PATH \\\n        --attention-backend ascend \\\n        --device npu \\\n        --tp-size 16 --nnodes 1 --node-rank 0 \\\n        --chunked-prefill-size 4096 --max-prefill-tokens 280000 \\\n        --disable-radix-cache \\\n        --trust-remote-code \\\n        --host 127.0.0.1 \\\n        --mem-fraction-static 0.7 \\\n        --port 8000 \\\n        --cuda-graph-bs 16 \\\n        --quantization modelslim \\\n        --enable-multimodal \\\n        --mm-attention-backend ascend_attn \\\n        --dtype bfloat16\n```\n\n#### Qwen3.5 122B\n\n```shell\n# high performance cpu\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n# bind cpu\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\n# cann\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\n\nexport STREAMS_PER_DEVICE=32\nexport HCCL_BUFFSIZE=1000\nexport HCCL_OP_EXPANSION_MODE=AIV\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\n\npython3 -m sglang.launch_server \\\n        --model-path $MODEL_PATH \\\n        --attention-backend ascend \\\n        --device npu \\\n        --tp-size 8 --nnodes 1 --node-rank 0 \\\n        --chunked-prefill-size 4096 --max-prefill-tokens 280000 \\\n        --disable-radix-cache \\\n        --trust-remote-code \\\n        --host 
127.0.0.1 \\\n        --mem-fraction-static 0.7 \\\n        --port 8000 \\\n        --cuda-graph-bs 16 \\\n        --quantization modelslim \\\n        --enable-multimodal \\\n        --mm-attention-backend ascend_attn \\\n        --dtype bfloat16\n```\n\n#### Qwen3.5 35B\n\n```shell\n# high performance cpu\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n# bind cpu\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\n# cann\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\n\nexport STREAMS_PER_DEVICE=32\nexport HCCL_BUFFSIZE=1000\nexport HCCL_OP_EXPANSION_MODE=AIV\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\n\npython3 -m sglang.launch_server \\\n        --model-path $MODEL_PATH \\\n        --attention-backend ascend \\\n        --device npu \\\n        --tp-size 2 --nnodes 1 --node-rank 0 \\\n        --chunked-prefill-size 4096 --max-prefill-tokens 280000 \\\n        --disable-radix-cache \\\n        --trust-remote-code \\\n        --host 127.0.0.1 \\\n        --mem-fraction-static 0.7 \\\n        --port 8000 \\\n        --cuda-graph-bs 16 \\\n        --quantization modelslim \\\n        --enable-multimodal \\\n        --mm-attention-backend ascend_attn \\\n        --dtype bfloat16\n```\n\n#### Qwen3.5 27B\n\n```shell\n# high performance cpu\necho performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\nsysctl -w vm.swappiness=0\nsysctl -w kernel.numa_balancing=0\nsysctl -w kernel.sched_migration_cost_ns=50000\n# bind cpu\nexport SGLANG_SET_CPU_AFFINITY=1\n\nunset https_proxy\nunset http_proxy\nunset HTTPS_PROXY\nunset HTTP_PROXY\nunset ASCEND_LAUNCH_BLOCKING\n# cann\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource 
/usr/local/Ascend/nnal/atb/set_env.sh\n\nexport STREAMS_PER_DEVICE=32\nexport HCCL_BUFFSIZE=1000\nexport HCCL_OP_EXPANSION_MODE=AIV\nexport HCCL_SOCKET_IFNAME=lo\nexport GLOO_SOCKET_IFNAME=lo\n\npython3 -m sglang.launch_server \\\n        --model-path $MODEL_PATH \\\n        --attention-backend ascend \\\n        --device npu \\\n        --tp-size 2 \\\n        --chunked-prefill-size -1 --max-prefill-tokens 120000 \\\n        --disable-radix-cache \\\n        --trust-remote-code \\\n        --host 127.0.0.1 \\\n        --mem-fraction-static 0.8 \\\n        --port 8000 \\\n        --cuda-graph-bs 32 \\\n        --enable-multimodal \\\n        --mm-attention-backend ascend_attn\n```\n\n### Prefill-Decode Disaggregation\n\nNot tested yet.\n\n### Using Benchmark\n\nRefer to [Benchmark and Profiling](../developer_guide/benchmark_and_profiling.md) for details.\n"
  },
  {
    "path": "docs/platforms/ascend_npu_qwen3_examples.md",
    "content": "## Qwen3 examples\n\n### Running Qwen3\n\n#### Running Qwen3-32B on 1 x Atlas 800I A3.\n\nModel weights could be found [here](https://huggingface.co/Qwen/Qwen3-32B)\n\n```shell\nexport SGLANG_SET_CPU_AFFINITY=1\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\nexport HCCL_BUFFSIZE=1536\nexport HCCL_OP_EXPANSION_MODE=AIV\n\npython -m sglang.launch_server \\\n   --device npu \\\n   --attention-backend ascend \\\n   --trust-remote-code \\\n   --tp-size 4 \\\n   --model-path Qwen/Qwen3-32B \\\n   --mem-fraction-static 0.8\n```\n\n#### Running Qwen3-32B on 1 x Atlas 800I A3 with Qwen3-32B-Eagle3.\n\nModel weights could be found [here](https://huggingface.co/Qwen/Qwen3-32B)\n\nSpeculative model weights could be found [here](https://huggingface.co/Zhihu-ai/Zhi-Create-Qwen3-32B-Eagle3)\n\n```shell\nexport SGLANG_SET_CPU_AFFINITY=1\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\nexport HCCL_OP_EXPANSION_MODE=AIV\nexport SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1\nexport SGLANG_ENABLE_SPEC_V2=1\n\npython -m sglang.launch_server \\\n   --device npu \\\n   --attention-backend ascend \\\n   --trust-remote-code \\\n   --tp-size 4 \\\n   --model-path Qwen/Qwen3-32B \\\n   --mem-fraction-static 0.8 \\\n   --speculative-algorithm EAGLE3 \\\n   --speculative-draft-model-path Qwen/Qwen3-32B-Eagle3 \\\n   --speculative-num-steps 1 \\\n   --speculative-eagle-topk 1 \\\n   --speculative-num-draft-tokens 2\n```\n\n#### Running Qwen3-30B-A3B MOE on 1 x Atlas 800I A3.\n\nModel weights could be found [here](https://huggingface.co/Qwen/Qwen3-30B-A3B)\n\n```shell\nexport SGLANG_SET_CPU_AFFINITY=1\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\nexport HCCL_BUFFSIZE=1536\nexport HCCL_OP_EXPANSION_MODE=AIV\nexport SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=32\nexport SGLANG_DEEPEP_BF16_DISPATCH=1\n\npython -m sglang.launch_server \\\n   --device npu \\\n   
--attention-backend ascend \\\n   --trust-remote-code \\\n   --tp-size 4 \\\n   --model-path Qwen/Qwen3-30B-A3B \\\n   --mem-fraction-static 0.8\n```\n\n#### Running Qwen3-235B-A22B-Instruct-2507 MOE on 1 x Atlas 800I A3.\n\nModel weights could be found [here](https://huggingface.co/Qwen/Qwen3-235B-A22B-Instruct-2507)\n\n```shell\nexport SGLANG_SET_CPU_AFFINITY=1\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\nexport HCCL_BUFFSIZE=1536\nexport SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=32\nexport SGLANG_DEEPEP_BF16_DISPATCH=1\n\npython -m sglang.launch_server \\\n   --model-path Qwen/Qwen3-235B-A22B-Instruct-2507 \\\n   --tp-size 16 \\\n   --trust-remote-code \\\n   --attention-backend ascend \\\n   --device npu \\\n   --watchdog-timeout 9000 \\\n   --mem-fraction-static 0.8\n```\n\n#### Running Qwen3-VL-8B-Instruct on 1 x Atlas 800I A3.\n\nModel weights could be found [here](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)\n\n```shell\nexport SGLANG_SET_CPU_AFFINITY=1\nexport PYTORCH_NPU_ALLOC_CONF=expandable_segments:True\nexport STREAMS_PER_DEVICE=32\nexport HCCL_BUFFSIZE=1536\nexport HCCL_OP_EXPANSION_MODE=AIV\n\npython -m sglang.launch_server \\\n   --enable-multimodal \\\n   --attention-backend ascend \\\n   --mm-attention-backend ascend_attn \\\n   --trust-remote-code \\\n   --tp-size 4 \\\n   --model-path Qwen/Qwen3-VL-8B-Instruct \\\n   --mem-fraction-static 0.8\n```\n"
  },
  {
    "path": "docs/platforms/ascend_npu_support.rst",
    "content": "Ascend NPUs\n===============================================================\n\n.. toctree::\n   :maxdepth: 1\n\n   ascend_npu.md\n   ascend_npu_support_features.md\n   ascend_npu_support_models.md\n   ascend_npu_deepseek_example.md\n   ascend_npu_qwen3_examples.md\n   mindspore_backend.md\n   ascend_contribution_guide.md\n   ascend_npu_best_practice.md\n   ascend_npu_qwen3_5_examples.md\n   ascend_npu_glm5_examples.md\n   ascend_npu_environment_variables.md\n"
  },
  {
    "path": "docs/platforms/ascend_npu_support_features.md",
    "content": "# Support Features on Ascend NPU\n\nThis section describes the basic functions and features supported by the Ascend NPU.If you encounter issues or have any\nquestions, please [open an issue](https://github.com/sgl-project/sglang/issues).\n\nIf you want to know the meaning and usage of each parameter,\nclick [Server Arguments](https://docs.sglang.io/advanced_features/server_arguments.html).\n\n## Model and tokenizer\n\n| Argument                               | Defaults | Options                               | Server supported |\n|----------------------------------------|----------|---------------------------------------|:----------------:|\n| `--model-path`<br/>`--model`           | `None`   | Type: str                             |      A2, A3      |\n| `--tokenizer-path`                     | `None`   | Type: str                             |      A2, A3      |\n| `--tokenizer-mode`                     | `auto`   | `auto`, `slow`                        |      A2, A3      |\n| `--tokenizer-worker-num`               | `1`      | Type: int                             |      A2, A3      |\n| `--skip-tokenizer-init`                | `False`  | bool flag (set to enable)             |      A2, A3      |\n| `--load-format`                        | `auto`   | `auto`, `safetensors`                 |      A2, A3      |\n| `--model-loader-` <br/> `extra-config` | `{}`     | Type: str                             |      A2, A3      |\n| `--trust-remote-code`                  | `False`  | bool flag (set to enable)             |      A2, A3      |\n| `--context-length`                     | `None`   | Type: int                             |      A2, A3      |\n| `--is-embedding`                       | `False`  | bool flag (set to enable)             |      A2, A3      |\n| `--enable-multimodal`                  | `None`   | bool flag (set to enable)             |      A2, A3      |\n| `--revision`                           | `None`   | Type: str                 
            |      A2, A3      |\n| `--model-impl`                         | `auto`   | `auto`, `sglang`,<br/> `transformers` |      A2, A3      |\n\n## HTTP server\n\n| Argument               | Defaults    | Options                   | Server supported |\n|------------------------|-------------|---------------------------|:----------------:|\n| `--host`               | `127.0.0.1` | Type: str                 |      A2, A3      |\n| `--port`               | `30000`     | Type: int                 |      A2, A3      |\n| `--skip-server-warmup` | `False`     | bool flag (set to enable) |      A2, A3      |\n| `--warmups`            | `None`      | Type: str                 |      A2, A3      |\n| `--nccl-port`          | `None`      | Type: int                 |      A2, A3      |\n| `--fastapi-root-path`  | `None`      | Type: str                 |      A2, A3      |\n| `--grpc-mode`          | `False`     | bool flag (set to enable) |      A2, A3      |\n\n## Quantization and data type\n\n| Argument                                    | Defaults | Options                                 | Server supported |\n|---------------------------------------------|----------|-----------------------------------------|:----------------:|\n| `--dtype`                                   | `auto`   | `auto`,<br/> `float16`,<br/> `bfloat16` |      A2, A3      |\n| `--quantization`                            | `None`   | `modelslim`                             |      A2, A3      |\n| `--quantization-param-path`                 | `None`   | Type: str                               | Special For GPU  |\n| `--kv-cache-dtype`                          | `auto`   | `auto`                                  |      A2, A3      |\n| `--enable-fp32-lm-head`                     | `False`  | bool flag <br/> (set to enable)         |      A2, A3      |\n| `--modelopt-quant`                          | `None`   | Type: str                               | Special For GPU  |\n| 
`--modelopt-checkpoint-`<br/>`restore-path` | `None`   | Type: str                               | Special For GPU  |\n| `--modelopt-checkpoint-`<br/>`save-path`    | `None`   | Type: str                               | Special For GPU  |\n| `--modelopt-export-path`                    | `None`   | Type: str                               | Special For GPU  |\n| `--quantize-and-serve`                      | `False`  | bool flag <br/> (set to enable)         | Special For GPU  |\n| `--rl-quant-profile`                        | `None`   | Type: str                               | Special For GPU  |\n\n## Memory and scheduling\n\n| Argument                                            | Defaults | Options                        | Server supported |\n|-----------------------------------------------------|----------|--------------------------------|:----------------:|\n| `--mem-fraction-static`                             | `None`   | Type: float                    |      A2, A3      |\n| `--max-running-requests`                            | `None`   | Type: int                      |      A2, A3      |\n| `--prefill-max-requests`                            | `None`   | Type: int                      |      A2, A3      |\n| `--max-queued-requests`                             | `None`   | Type: int                      |      A2, A3      |\n| `--max-total-tokens`                                | `None`   | Type: int                      |      A2, A3      |\n| `--chunked-prefill-size`                            | `None`   | Type: int                      |      A2, A3      |\n| `--max-prefill-tokens`                              | `16384`  | Type: int                      |      A2, A3      |\n| `--schedule-policy`                                 | `fcfs`   | `lpm`, `fcfs`                  |      A2, A3      |\n| `--enable-priority-`<br/>`scheduling`               | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| 
`--schedule-low-priority-`<br/>`values-first`       | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--priority-scheduling-`<br/>`preemption-threshold` | `10`     | Type: int                      |      A2, A3      |\n| `--schedule-conservativeness`                       | `1.0`    | Type: float                    |      A2, A3      |\n| `--page-size`                                       | `128`    | Type: int                      |      A2, A3      |\n| `--swa-full-tokens-ratio`                           | `0.8`    | Type: float                    |      A2, A3      |\n| `--disable-hybrid-swa-memory`                       | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--radix-eviction-policy`                           | `lru`    | `lru`,<br/>`lfu`               |      A2, A3      |\n| `--abort-on-priority-`<br/>`when-disabled`          | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-dynamic-chunking`                         | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n\n## Runtime options\n\n| Argument                                           | Defaults | Options                   | Server supported |\n|----------------------------------------------------|----------|---------------------------|:----------------:|\n| `--device`                                         | `None`   | Type: str                 |      A2, A3      |\n| `--tensor-parallel-size`<br/>`--tp-size`           | `1`      | Type: int                 |      A2, A3      |\n| `--pipeline-parallel-size`<br/>`--pp-size`         | `1`      | Type: int                 |      A2, A3      |\n| `--pp-max-micro-batch-size`                        | `None`   | Type: int                 |      A2, A3      |\n| `--pp-async-batch-depth`                           | `None`   | Type: int                 |      A2, A3      |\n| `--stream-interval`                                | `1`      | Type: int                 |      
A2, A3      |\n| `--incremental-streaming-output`                   | `False`  | bool flag (set to enable) |      A2, A3      |\n| `--random-seed`                                    | `None`   | Type: int                 |      A2, A3      |\n| `--constrained-json-`<br/>`whitespace-pattern`     | `None`   | Type: str                 |      A2, A3      |\n| `--constrained-json-`<br/>`disable-any-whitespace` | `False`  | bool flag (set to enable) |      A2, A3      |\n| `--watchdog-timeout`                               | `300`    | Type: float               |      A2, A3      |\n| `--soft-watchdog-timeout`                          | `300`    | Type: float               |      A2, A3      |\n| `--dist-timeout`                                   | `None`   | Type: int                 |      A2, A3      |\n| `--base-gpu-id`                                    | `0`      | Type: int                 |      A2, A3      |\n| `--gpu-id-step`                                    | `1`      | Type: int                 |      A2, A3      |\n| `--sleep-on-idle`                                  | `False`  | bool flag (set to enable) |      A2, A3      |\n| `--custom-sigquit-handler`                         | `None`   | Optional[Callable]        |      A2, A3      |\n\n## Logging\n\n| Argument                                           | Defaults          | Options                        | Server supported |\n|----------------------------------------------------|-------------------|--------------------------------|:----------------:|\n| `--log-level`                                      | `info`            | Type: str                      |      A2, A3      |\n| `--log-level-http`                                 | `None`            | Type: str                      |      A2, A3      |\n| `--log-requests`                                   | `False`           | bool flag<br/> (set to enable) |      A2, A3      |\n| `--log-requests-level`                             | `2`               | 
`0`, `1`, `2`, `3`             |      A2, A3      |\n| `--log-requests-format`                            | `text`            | `text`, `json`                 |      A2, A3      |\n| `--crash-dump-folder`                              | `None`            | Type: str                      |      A2, A3      |\n| `--enable-metrics`                                 | `False`           | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-metrics-for-`<br/>`all-schedulers`       | `False`           | bool flag<br/> (set to enable) |      A2, A3      |\n| `--tokenizer-metrics-`<br/>`custom-labels-header`  | `x-custom-labels` | Type: str                      |      A2, A3      |\n| `--tokenizer-metrics-`<br/>`allowed-custom-labels` | `None`            | List[str]                      |      A2, A3      |\n| `--bucket-time-to-`<br/>`first-token`              | `None`            | List[float]                    |      A2, A3      |\n| `--bucket-inter-token-`<br/>`latency`              | `None`            | List[float]                    |      A2, A3      |\n| `--bucket-e2e-request-`<br/>`latency`              | `None`            | List[float]                    |      A2, A3      |\n| `--collect-tokens-`<br/>`histogram`                | `False`           | bool flag<br/> (set to enable) |      A2, A3      |\n| `--prompt-tokens-buckets`                          | `None`            | List[str]                      |      A2, A3      |\n| `--generation-tokens-buckets`                      | `None`            | List[str]                      |      A2, A3      |\n| `--gc-warning-threshold-secs`                      | `0.0`             | Type: float                    |      A2, A3      |\n| `--decode-log-interval`                            | `40`              | Type: int                      |      A2, A3      |\n| `--enable-request-time-`<br/>`stats-logging`       | `False`           | bool flag<br/> (set to enable) |      A2, A3      |\n| `--kv-events-config`     
                           | `None`            | Type: str                      | Special for GPU  |\n| `--enable-trace`                                   | `False`           | bool flag<br/> (set to enable) |      A2, A3      |\n| `--otlp-traces-endpoint`                           | `localhost:4317`  | Type: str                      |      A2, A3      |\n| `--log-requests-target`                            | `None`            | Type: str                      |      A2, A3      |\n| `--uvicorn-access-log-exclude-prefixes`            | `[]`              | List[str]                      |      A2, A3      |\n\n## RequestMetricsExporter configuration\n\n| Argument                              | Defaults | Options                        | Server supported |\n|---------------------------------------|----------|--------------------------------|:----------------:|\n| `--export-metrics-to-`<br/>`file`     | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--export-metrics-to-`<br/>`file-dir` | `None`   | Type: str                      |      A2, A3      |\n\n## API related\n\n| Argument                | Defaults  | Options                        | Server supported |\n|-------------------------|-----------|--------------------------------|:----------------:|\n| `--api-key`             | `None`    | Type: str                      |      A2, A3      |\n| `--admin-api-key`       | `None`    | Type: str                      |      A2, A3      |\n| `--served-model-name`   | `None`    | Type: str                      |      A2, A3      |\n| `--weight-version`      | `default` | Type: str                      |      A2, A3      |\n| `--chat-template`       | `None`    | Type: str                      |      A2, A3      |\n| `--completion-template` | `None`    | Type: str                      |      A2, A3      |\n| `--enable-cache-report` | `False`   | bool flag<br/> (set to enable) |      A2, A3      |\n| `--reasoning-parser`    | `None`    | 
`deepseek-r1`<br/>`deepseek-v3`<br/>`glm45`<br/>`gpt-oss`<br/>`kimi`<br/>`qwen3`<br/>`qwen3-thinking`<br/>`step3`                  |      A2, A3      |\n| `--tool-call-parser`    | `None`    | `deepseekv3`<br/>`deepseekv31`<br/>`glm`<br/>`glm45`<br/>`glm47`<br/>`gpt-oss`<br/>`kimi_k2`<br/>`llama3`<br/>`mistral`<br/>`pythonic`<br/>`qwen`<br/>`qwen25`<br/>`qwen3_coder`<br/>`step3`<br/>`gigachat3`            |      A2, A3      |\n| `--sampling-defaults`   | `model`   | `openai`, `model`              |      A2, A3      |\n\n## Data parallelism\n\n| Argument                               | Defaults      | Options                                                   | Server supported |\n|----------------------------------------|---------------|-----------------------------------------------------------|:----------------:|\n| `--data-parallel-size`<br/>`--dp-size` | `1`           | Type: int                                                 |      A2, A3      |\n| `--load-balance-method`                | `auto` | `auto`,<br/> `round_robin`,<br/> `follow_bootstrap_room`,<br/> `total_requests`,<br/> `total_tokens` |      A2, A3      |\n| `--prefill-round-robin-balance`        | `False`       | bool flag<br/> (set to enable)                            |      A2, A3      |\n\n## Multi-node distributed serving\n\n| Argument                                  | Defaults | Options   | Server supported |\n|-------------------------------------------|----------|-----------|:----------------:|\n| `--dist-init-addr`<br/>`--nccl-init-addr` | `None`   | Type: str |      A2, A3      |\n| `--nnodes`                                | `1`      | Type: int |      A2, A3      |\n| `--node-rank`                             | `0`      | Type: int |      A2, A3      |\n\n## Model override args\n\n| Argument                             | Defaults | Options   | Server supported |\n|--------------------------------------|----------|-----------|:----------------:|\n| `--json-model-override-`<br/>`args`  
| `{}`     | Type: str |      A2, A3      |\n| `--preferred-sampling-`<br/>`params` | `None`   | Type: str |      A2, A3      |\n\n## LoRA\n\n| Argument                 | Defaults | Options                             | Server supported |\n|--------------------------|----------|-------------------------------------|:----------------:|\n| `--enable-lora`          | `False`  | Bool flag <br/>(set to enable)      |      A2, A3      |\n| `--max-lora-rank`        | `None`   | Type: int                           |      A2, A3      |\n| `--lora-target-modules`  | `None`   | `all`                               |      A2, A3      |\n| `--lora-paths`           | `None`   | Type: List[str] /<br/> JSON objects |      A2, A3      |\n| `--max-loras-per-batch`  | `8`      | Type: int                           |      A2, A3      |\n| `--max-loaded-loras`     | `None`   | Type: int                           |      A2, A3      |\n| `--lora-eviction-policy` | `lru`    | `lru`,<br/> `fifo`                  |      A2, A3      |\n| `--lora-backend`         | `csgmv`  | `triton`,<br/>`csgmv`,<br/>`ascend`,<br/>`torch_native`  |      A2, A3      |\n| `--max-lora-chunk-size`  | `16`     | `16`, `32`,<br/> `64`, `128`        | Special for GPU  |\n\n## Kernel Backends (Attention, Sampling, Grammar, GEMM)\n\n| Argument                               | Defaults          | Options                                                                                        | Server supported |\n|----------------------------------------|-------------------|------------------------------------------------------------------------------------------------|:----------------:|\n| `--attention-backend`                  | `None`            | `ascend`                                                                                       |      A2, A3      |\n| `--prefill-attention-backend`          | `None`            | `ascend`                                                                                       
|      A2, A3      |\n| `--decode-attention-backend`           | `None`            | `ascend`                                                                                       |      A2, A3      |\n| `--sampling-backend`                   | `None`            | `pytorch`,<br/>`ascend`                                                                        |      A2, A3      |\n| `--grammar-backend`                    | `None`            | `xgrammar`                                                                                     |      A2, A3      |\n| `--mm-attention-backend`               | `None`            | `ascend_attn`                                                                                  |      A2, A3      |\n| `--nsa-prefill-backend`                | `flashmla_sparse` | `flashmla_sparse`,<br/> `flashmla_decode`,<br/>`fa3`,<br/> `tilelang`,<br/> `aiter`            | Special for GPU  |\n| `--nsa-decode-backend`                 | `fa3`             | `flashmla_prefill`,<br/> `flashmla_kv`,<br/> `fa3`,<br/>`tilelang`,<br/> `aiter`               | Special for GPU  |\n| `--fp8-gemm-backend`                   | `auto`            | `auto`,<br/> `deep_gemm`,<br/> `flashinfer_trtllm`,<br/>`flashinfer_cutlass`,<br/>`flashinfer_deepgemm`,<br/>`cutlass`,<br/> `triton`,<br/> `aiter` | Special for GPU  |\n| `--disable-flashinfer-`<br/>`autotune` | `False`           | bool flag<br/> (set to enable)                                                                 | Special for GPU  |\n\n## Speculative decoding\n\n| Argument                                                         | Defaults  | Options                  | Server supported |\n|------------------------------------------------------------------|-----------|--------------------------|:----------------:|\n| `--speculative-algorithm`                                        | `None`    | `EAGLE3`,<br/> `NEXTN`   |      A2, A3      |\n| `--speculative-draft-model-path`<br/>`--speculative-draft-model` | 
`None`    | Type: str                |      A2, A3      |\n| `--speculative-draft-model-`<br/>`revision`                      | `None`    | Type: str                |      A2, A3      |\n| `--speculative-draft-load-format`                                | `None`    | `auto`                   |      A2, A3      |\n| `--speculative-num-steps`                                        | `None`    | Type: int                |      A2, A3      |\n| `--speculative-eagle-topk`                                       | `None`    | Type: int                |      A2, A3      |\n| `--speculative-num-draft-tokens`                                 | `None`    | Type: int                |      A2, A3      |\n| `--speculative-accept-`<br/>`threshold-single`                   | `1.0`     | Type: float              | Special for GPU  |\n| `--speculative-accept-`<br/>`threshold-acc`                      | `1.0`     | Type: float              | Special for GPU  |\n| `--speculative-token-map`                                        | `None`    | Type: str                |      A2, A3      |\n| `--speculative-attention-`<br/>`mode`                            | `prefill` | `prefill`,<br/> `decode` |      A2, A3      |\n| `--speculative-moe-runner-`<br/>`backend`                        | `None`    | `auto`                   |      A2, A3      |\n| `--speculative-moe-a2a-`<br/>`backend`                           | `None`    | `ascend_fuseep`          |      A2, A3      |\n| `--speculative-draft-attention-backend`                          | `None`    | `ascend`                 |      A2, A3      |\n| `--speculative-draft-model-quantization`                         | `None`    | `unquant`                |      A2, A3      |\n\n## Ngram speculative decoding\n\n| Argument                                           | Defaults   | Options            | Server supported |\n|----------------------------------------------------|------------|--------------------|:----------------:|\n| 
`--speculative-ngram-`<br/>`min-match-window-size` | `1`        | Type: int          |   Experimental   |\n| `--speculative-ngram-`<br/>`max-match-window-size` | `12`       | Type: int          |   Experimental   |\n| `--speculative-ngram-`<br/>`min-bfs-breadth`       | `1`        | Type: int          |   Experimental   |\n| `--speculative-ngram-`<br/>`max-bfs-breadth`       | `10`       | Type: int          |   Experimental   |\n| `--speculative-ngram-`<br/>`match-type`            | `BFS`      | `BFS`,<br/> `PROB` |   Experimental   |\n| `--speculative-ngram-`<br/>`branch-length`         | `18`       | Type: int          |   Experimental   |\n| `--speculative-ngram-`<br/>`capacity`              | `10000000` | Type: int          |   Experimental   |\n\n## Expert parallelism\n\n| Argument                                              | Defaults  | Options                                     | Server supported |\n|-------------------------------------------------------|-----------|---------------------------------------------|:----------------:|\n| `--expert-parallel-size`<br/>`--ep-size`<br/>`--ep`   | `1`       | Type: int                                   |      A2, A3      |\n| `--moe-a2a-backend`                                   | `none`    | `none`,<br/> `deepep`,<br/> `ascend_fuseep` |      A2, A3      |\n| `--moe-runner-backend`                                | `auto`    | `auto`, `triton`                            |      A2, A3      |\n| `--flashinfer-mxfp4-`<br/>`moe-precision`             | `default` | `default`,<br/> `bf16`                      | Special for GPU  |\n| `--enable-flashinfer-`<br/>`allreduce-fusion`         | `False`   | bool flag<br/> (set to enable)              | Special for GPU  |\n| `--deepep-mode`                                       | `auto`    | `normal`, <br/>`low_latency`,<br/> `auto`   |      A2, A3      |\n| `--deepep-config`                                     | `None`    | Type: str                                   | Special 
for GPU  |\n| `--ep-num-redundant-experts`                          | `0`       | Type: int                                   |      A2, A3      |\n| `--ep-dispatch-algorithm`                             | `None`    | Type: str                                   |      A2, A3      |\n| `--init-expert-location`                              | `trivial` | Type: str                                   |      A2, A3      |\n| `--enable-eplb`                                       | `False`   | bool flag<br/> (set to enable)              |      A2, A3      |\n| `--eplb-algorithm`                                    | `auto`    | Type: str                                   |      A2, A3      |\n| `--eplb-rebalance-layers-`<br/>`per-chunk`            | `None`    | Type: int                                   |      A2, A3      |\n| `--eplb-min-rebalancing-`<br/>`utilization-threshold` | `1.0`     | Type: float                                 |      A2, A3      |\n| `--expert-distribution-`<br/>`recorder-mode`          | `None`    | Type: str                                   |      A2, A3      |\n| `--expert-distribution-`<br/>`recorder-buffer-size`   | `None`    | Type: int                                   |      A2, A3      |\n| `--enable-expert-distribution-`<br/>`metrics`         | `False`   | bool flag (set to enable)                   |      A2, A3      |\n| `--moe-dense-tp-size`                                 | `None`    | Type: int                                   |      A2, A3      |\n| `--elastic-ep-backend`                                | `None`    | `none`, `mooncake`                          | Special for GPU  |\n| `--mooncake-ib-device`                                | `None`    | Type: str                                   | Special for GPU  |\n\n## Mamba Cache\n\n| Argument                     | Defaults  | Options                                       | Server supported 
|\n|------------------------------|-----------|-----------------------------------------------|:----------------:|\n| `--max-mamba-cache-size`     | `None`    | Type: int                                     |      A2, A3      |\n| `--mamba-ssm-dtype`          | `float32` | `float32`,<br/>`bfloat16`,<br/>`float16`      |      A2, A3      |\n| `--mamba-full-memory-ratio`  | `0.9`     | Type: float                                   |      A2, A3      |\n| `--mamba-scheduler-strategy` | `auto`    | Only `auto`, `no_buffer` supported            |      A2, A3      |\n| `--mamba-track-interval`     | `256`     | Type: int                                     |      A2, A3      |\n\n## Hierarchical cache\n\n| Argument                                        | Defaults        | Options                                                             | Server supported |\n|-------------------------------------------------|-----------------|---------------------------------------------------------------------|:----------------:|\n| `--enable-hierarchical-`<br/>`cache`            | `False`         | bool flag<br/> (set to enable)                                      |      A2, A3      |\n| `--hicache-ratio`                               | `2.0`           | Type: float                                                         |      A2, A3      |\n| `--hicache-size`                                | `0`             | Type: int                                                           |      A2, A3      |\n| `--hicache-write-policy`                        | `write_through` | Currently only `write_back` supported                               |      A2, A3      |\n| `--hicache-io-backend`                          | `kernel`        | `kernel_ascend`,<br/>                     `direct`                  |      A2, A3      |\n| `--hicache-mem-layout`                          | `layer_first`   | `page_first_direct`,<br/>                  `page_first_kv_split`    |      A2, A3      |\n| 
`--hicache-storage-`<br/>`backend`              | `None`          | `file`                                                              |      A2, A3      |\n| `--hicache-storage-`<br/>`prefetch-policy`      | `best_effort`   | `best_effort`,<br/> `wait_complete`,<br/>  `timeout`                | Special for GPU  |\n| `--hicache-storage-`<br/>`backend-extra-config` | `None`          | Type: str                                                           | Special for GPU  |\n\n## LMCache\n\n| Argument           | Defaults | Options                        | Server supported |\n|--------------------|----------|--------------------------------|:----------------:|\n| `--enable-lmcache` | `False`  | bool flag<br/> (set to enable) | Special for GPU  |\n\n## Offloading\n\n| Argument                  | Defaults | Options   | Server supported |\n|---------------------------|----------|-----------|:----------------:|\n| `--cpu-offload-gb`        | `0`      | Type: int |      A2, A3      |\n| `--offload-group-size`    | `-1`     | Type: int |      Planned     |\n| `--offload-num-in-group`  | `1`      | Type: int |      Planned     |\n| `--offload-prefetch-step` | `1`      | Type: int |      Planned     |\n| `--offload-mode`          | `cpu`    | Type: str |      Planned     |\n\n## Args for multi-item scoring\n\n| Argument                         | Defaults | Options   | Server supported |\n|----------------------------------|----------|-----------|:----------------:|\n| `--multi-item-scoring-delimiter` | `None`   | Type: int |      A2, A3      |\n\n## Optimization/debug options\n\n| Argument                                                | Defaults | Options                        | Server supported |\n|---------------------------------------------------------|----------|--------------------------------|:----------------:|\n| `--disable-radix-cache`                                 | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--cuda-graph-max-bs`         
                          | `None`   | Type: int                      |      A2, A3      |\n| `--cuda-graph-bs`                                       | `None`   | List[int]                      |      A2, A3      |\n| `--disable-cuda-graph`                                  | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--disable-cuda-graph-`<br/>`padding`                   | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-profile-`<br/>`cuda-graph`                    | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-cudagraph-gc`                                 | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-nccl-nvls`                                    | `False`  | bool flag<br/> (set to enable) | Special for GPU  |\n| `--enable-symm-mem`                                     | `False`  | bool flag<br/> (set to enable) | Special for GPU  |\n| `--disable-flashinfer-`<br/>`cutlass-moe-fp4-allgather` | `False`  | bool flag<br/> (set to enable) | Special for GPU  |\n| `--enable-tokenizer-`<br/>`batch-encode`                | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--disable-tokenizer-`<br/>`batch-decode`               | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--disable-custom-`<br/>`all-reduce`                    | `False`  | bool flag<br/> (set to enable) | Special for GPU  |\n| `--enable-mscclpp`                                      | `False`  | bool flag<br/> (set to enable) | Special for GPU  |\n| `--enable-torch-`<br/>`symm-mem`                        | `False`  | bool flag<br/> (set to enable) | Special for GPU  |\n| `--disable-overlap-`<br/>`schedule`                     | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-mixed-`<br/>`chunk`                           | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-dp-attention`                       
          | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-dp-lm-head`                                   | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-two-`<br/>`batch-overlap`                     | `False`  | bool flag<br/> (set to enable) |     Planned      |\n| `--enable-single-`<br/>`batch-overlap`                  | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--tbo-token-`<br/>`distribution-threshold`             | `0.48`   | Type: float                    |     Planned      |\n| `--enable-torch-`<br/>`compile`                         | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-torch-`<br/>`compile-debug-mode`              | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-piecewise-`<br/>`cuda-graph`                  | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--piecewise-cuda-`<br/>`graph-tokens`                  | `None`   | Type: JSON<br/> list           |      A2, A3      |\n| `--piecewise-cuda-`<br/>`graph-compiler`                | `eager`  | `eager`,<br/> `inductor`       |      A2, A3      |\n| `--torch-compile-max-bs`                                | `32`     | Type: int                      |      A2, A3      |\n| `--piecewise-cuda-`<br/>`graph-max-tokens`              | `None`   | Type: int                      |      A2, A3      |\n| `--torchao-config`                                      | `\"\"`     | Type: str                      | Special for GPU  |\n| `--enable-nan-detection`                                | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-p2p-check`                                    | `False`  | bool flag<br/> (set to enable) | Special for GPU  |\n| `--triton-attention-`<br/>`reduce-in-fp32`              | `False`  | bool flag<br/> (set to enable) | Special for GPU  |\n| `--triton-attention-`<br/>`num-kv-splits`               | 
`8`      | Type: int                      | Special for GPU  |\n| `--triton-attention-`<br/>`split-tile-size`             | `None`   | Type: int                      | Special for GPU  |\n| `--delete-ckpt-`<br/>`after-loading`                    | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-memory-saver`                                 | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-weights-`<br/>`cpu-backup`                    | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-draft-weights-`<br/>`cpu-backup`              | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--allow-auto-truncate`                                 | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-custom-`<br/>`logit-processor`                | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--flashinfer-mla-`<br/>`disable-ragged`                | `False`  | bool flag<br/> (set to enable) | Special for GPU  |\n| `--disable-shared-`<br/>`experts-fusion`                | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--disable-chunked-`<br/>`prefix-cache`                 | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--disable-fast-`<br/>`image-processor`                 | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--keep-mm-feature-`<br/>`on-device`                    | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-return-`<br/>`hidden-states`                  | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-return-`<br/>`routed-experts`                 | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--scheduler-recv-`<br/>`interval`                      | `1`      | Type: int                      |      A2, A3      |\n| `--numa-node`                                           | `None`   | 
List[int]                      |      A2, A3      |\n| `--enable-deterministic-`<br/>`inference`               | `False`  | bool flag<br/> (set to enable) |     Planned      |\n| `--rl-on-policy-target`                                 | `None`   | `fsdp`                         |     Planned      |\n| `--enable-layerwise-`<br/>`nvtx-marker`                 | `False`  | bool flag<br/> (set to enable) | Special for GPU  |\n| `--enable-attn-tp-`<br/>`input-scattered`               | `False`  | bool flag<br/> (set to enable) |   Experimental   |\n| `--enable-nsa-prefill-`<br/>`context-parallel`          | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--enable-fused-qk-`<br/>`norm-rope`                    | `False`  | bool flag<br/> (set to enable) | Special for GPU  |\n\n## Dynamic batch tokenizer\n\n| Argument                                         | Defaults | Options                        | Server supported |\n|--------------------------------------------------|----------|--------------------------------|:----------------:|\n| `--enable-dynamic-`<br/>`batch-tokenizer`        | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--dynamic-batch-`<br/>`tokenizer-batch-size`    | `32`     | Type: int                      |      A2, A3      |\n| `--dynamic-batch-`<br/>`tokenizer-batch-timeout` | `0.002`  | Type: float                    |      A2, A3      |\n\n## Debug tensor dumps\n\n| Argument                                   | Defaults | Options   | Server supported |\n|--------------------------------------------|----------|-----------|:----------------:|\n| `--debug-tensor-dump-`<br/>`output-folder` | `None`   | Type: str |      A2, A3      |\n| `--debug-tensor-dump-`<br/>`layers`        | `None`   | List[int] |      A2, A3      |\n| `--debug-tensor-dump-`<br/>`input-file`    | `None`   | Type: str |      A2, A3      |\n\n## PD disaggregation\n\n| Argument                                                | Defaults   | Options   
                            | Server supported |\n|---------------------------------------------------------|------------|---------------------------------------|:----------------:|\n| `--disaggregation-mode`                                 | `null`     | `null`,<br/> `prefill`,<br/> `decode` |      A2, A3      |\n| `--disaggregation-transfer-backend`                     | `mooncake` | `ascend`                              |      A2, A3      |\n| `--disaggregation-bootstrap-port`                       | `8998`     | Type: int                             |      A2, A3      |\n| `--disaggregation-ib-device`                            | `None`     | Type: str                             | Special for GPU  |\n| `--disaggregation-decode-`<br/>`enable-offload-kvcache` | `False`    | bool flag<br/> (set to enable)        |      A2, A3      |\n| `--disaggregation-decode-`<br/>`enable-fake-auto`       | `False`    | bool flag<br/> (set to enable)        |      A2, A3      |\n| `--num-reserved-decode-tokens`                          | `512`      | Type: int                             |      A2, A3      |\n| `--disaggregation-decode-`<br/>`polling-interval`       | `1`        | Type: int                             |      A2, A3      |\n\n## Encode prefill disaggregation\n\n| Argument                     | Defaults           | Options                                                        | Server supported |\n|------------------------------|--------------------|----------------------------------------------------------------|:----------------:|\n| `--encoder-only`             | `False`            | bool flag<br/> (set to enable)                                 |      A2, A3      |\n| `--language-only`            | `False`            | bool flag<br/> (set to enable)                                 |      A2, A3      |\n| `--encoder-transfer-backend` | `zmq_to_scheduler` | `zmq_to_scheduler`, <br/> `zmq_to_tokenizer`,<br/>  `mooncake` |      A2, A3      |\n| `--encoder-urls`  
           | `[]`               | List[str]                                                      |      A2, A3      |\n\n## Custom weight loader\n\n| Argument                                                                | Defaults | Options                         | Server supported |\n|-------------------------------------------------------------------------|----------|---------------------------------|:----------------:|\n| `--custom-weight-loader`                                                | `None`   | List[str]                       |      A2, A3      |\n| `--weight-loader-disable-`<br/>`mmap`                                   | `False`  | bool flag<br/> (set to enable)  |      A2, A3      |\n| `--remote-instance-weight-`<br/>`loader-seed-instance-ip`               | `None`   | Type: str                       |      A2, A3      |\n| `--remote-instance-weight-`<br/>`loader-seed-instance-service-port`     | `None`   | Type: int                       |      A2, A3      |\n| `--remote-instance-weight-`<br/>`loader-send-weights-group-ports`       | `None`   | Type: JSON<br/> list            |      A2, A3      |\n| `--remote-instance-weight-`<br/>`loader-backend`                        | `nccl`   | `transfer_engine`, <br/> `nccl` |      A2, A3      |\n| `--remote-instance-weight-`<br/>`loader-start-seed-via-transfer-engine` | `False`  | bool flag<br/> (set to enable)  | Special for GPU  |\n\n## For PD-Multiplexing\n\n| Argument              | Defaults | Options                        | Server supported |\n|-----------------------|----------|--------------------------------|:----------------:|\n| `--enable-pdmux`      | `False`  | bool flag<br/> (set to enable) | Special for GPU  |\n| `--pdmux-config-path` | `None`   | Type: str                      | Special for GPU  |\n| `--sm-group-num`      | `8`      | Type: int                      | Special for GPU  |\n\n## For Multi-Modal\n\n| Argument                                      | Defaults | Options             
            | Server supported |\n|-----------------------------------------------|----------|--------------------------------|:----------------:|\n| `--mm-max-concurrent-calls`                   | `32`     | Type: int                      |      A2, A3      |\n| `--mm-per-request-timeout`                    | `10.0`   | Type: float                    |      A2, A3      |\n| `--enable-broadcast-mm-`<br/>`inputs-process` | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--mm-process-config`                         | `None`   | Type: JSON / Dict              |      A2, A3      |\n| `--mm-enable-dp-encoder`                      | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n| `--limit-mm-data-per-request`                 | `None`   | Type: JSON / Dict              |      A2, A3      |\n\n## For checkpoint decryption\n\n| Argument                        | Defaults | Options                        | Server supported |\n|---------------------------------|----------|--------------------------------|:----------------:|\n| `--decrypted-config-file`       | `None`   | Type: str                      |      A2, A3      |\n| `--decrypted-draft-config-file` | `None`   | Type: str                      |      A2, A3      |\n| `--enable-prefix-mm-cache`      | `False`  | bool flag<br/> (set to enable) |      A2, A3      |\n\n## Forward hooks\n\n| Argument          | Defaults | Options         | Server supported |\n|-------------------|----------|-----------------|:----------------:|\n| `--forward-hooks` | `None`   | Type: JSON list |      A2, A3      |\n\n## Configuration file support\n\n| Argument   | Defaults | Options   | Server supported |\n|------------|----------|-----------|:----------------:|\n| `--config` | `None`   | Type: str |      A2, A3      |\n\n## Other Params\n\nThe following parameters are not supported because the third-party components they depend on (for example, KTransformers and checkpoint-engine) are not compatible with the\nNPU.
\n\n| Argument                                                          | Defaults  | Options                   |\n|-------------------------------------------------------------------|-----------|---------------------------|\n| `--checkpoint-engine-` <br/> `wait-weights-` <br/> `before-ready` | `False`   | bool flag (set to enable) |\n| `--kt-weight-path`                                                | `None`    | Type: str                 |\n| `--kt-method`                                                     | `AMXINT4` | Type: str                 |\n| `--kt-cpuinfer`                                                   | `None`    | Type: int                 |\n| `--kt-threadpool-count`                                           | `2`       | Type: int                 |\n| `--kt-num-gpu-experts`                                            | `None`    | Type: int                 |\n| `--kt-max-deferred-`<br/>`experts-per-token`                      | `None`    | Type: int                 |\n\nThe following parameters have known functional limitations in the community version:\n\n| Argument                              | Defaults | Options                        |\n|---------------------------------------|----------|--------------------------------|\n| `--enable-double-sparsity`            | `False`  | bool flag<br/> (set to enable) |\n| `--ds-channel-config-path`            | `None`   | Type: str                      |\n| `--ds-heavy-channel-num`              | `32`     | Type: int                      |\n| `--ds-heavy-token-num`                | `256`    | Type: int                      |\n| `--ds-heavy-channel-type`             | `qk`     | Type: str                      |\n| `--ds-sparse-decode-`<br/>`threshold` | `4096`   | Type: int                      |\n| `--tool-server`                       | `None`   | Type: str                      |\n"
  },
  {
    "path": "docs/platforms/ascend_npu_support_models.md",
    "content": "# Supported Models on Ascend NPU\n\nThis section describes the models supported on the Ascend NPU, including Large Language Models, Multimodal Language\nModels, Embedding Models, Reward Models, and Rerank Models. The mainstream DeepSeek, Qwen, and GLM series are included.\nYou are welcome to enable additional models based on your business requirements.\n\n## Large Language Models\n\n| Models                                     | Model Family                   |               A2 Supported               |               A3 Supported               |\n|--------------------------------------------|--------------------------------|:----------------------------------------:|:----------------------------------------:|\n| DeepSeek V3/V3.1                           | DeepSeek                       | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| DeepSeek-V3.2-W8A8                         | DeepSeek                       | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| DeepSeek-R1-0528-W8A8                      | DeepSeek                       | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| DeepSeek-V2-Lite-W8A8                      | DeepSeek                       | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen/Qwen3-30B-A3B-Instruct-2507           | Qwen                           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen/Qwen3-32B                             | Qwen                           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen/Qwen3-0.6B                            | Qwen                           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen3-235B-A22B-W8A8                       | Qwen                           | 
**<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen/Qwen3-Next-80B-A3B-Instruct           | Qwen                           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen3-Coder-480B-A35B-Instruct-w8a8-QuaRot | Qwen                           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen/Qwen2.5-7B-Instruct                   | Qwen                           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| QWQ-32B-W8A8                               | Qwen                           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| meta-llama/Llama-4-Scout-17B-16E-Instruct  | Llama                          | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| AI-ModelScope/Llama-3.1-8B-Instruct        | Llama                          | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| LLM-Research/llama-2-7b                    | Llama                          | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| LLM-Research/Llama-3.2-1B-Instruct         | Llama                          | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| mistralai/Mistral-7B-Instruct-v0.2         | Mistral                        | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| google/gemma-3-4b-it                       | Gemma                          | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| microsoft/Phi-4-multimodal-instruct        | Phi                            | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| allenai/OLMoE-1B-7B-0924  
                 | OLMoE                          | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| stabilityai/stablelm-2-1_6b                | StableLM                       | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| CohereForAI/c4ai-command-r-v01             | Command-R                      | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| huihui-ai/grok-2                           | Grok                           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| ZhipuAI/chatglm2-6b                        | ChatGLM                        | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Shanghai_AI_Laboratory/internlm2-7b        | InternLM 2                     | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct       | EXAONE 3                       | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| xverse/XVERSE-MoE-A36B                     | XVERSE                         | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| HuggingFaceTB/SmolLM-1.7B                  | SmolLM                         | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| ZhipuAI/glm-4-9b-chat                      | GLM-4                          | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| XiaomiMiMo/MiMo-7B-RL                      | MiMo                           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| arcee-ai/AFM-4.5B-Base                     | Arcee AFM-4.5B                 | **<span style=\"color: green;\">√</span>** | **<span 
style=\"color: green;\">√</span>** |\n| Howeee/persimmon-8b-chat                   | Persimmon                      | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| inclusionAI/Ling-lite                      | Ling                           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| ibm-granite/granite-3.1-8b-instruct        | Granite                        | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| ibm-granite/granite-3.0-3b-a800m-instruct  | Granite MoE                    | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| AI-ModelScope/dbrx-instruct                | DBRX (Databricks)              | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| baichuan-inc/Baichuan2-13B-Chat            | Baichuan 2 (7B, 13B)           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| baidu/ERNIE-4.5-21B-A3B-PT                 | ERNIE-4.5 (4.5, 4.5MoE series) | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| OpenBMB/MiniCPM3-4B                        | MiniCPM (v3, 4B)               | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Kimi/Kimi-K2-Thinking                      | Kimi                           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| openai/gpt-oss-120b                        | GPTOSS                         | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| allenai/OLMo-2-1124-7B-Instruct            | OLMo                           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| minimax/MiniMax-M2                         | MiniMax-M2                     | 
**<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| upstage/SOLAR-10.7B-Instruct-v1.0          | Solar                          | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| bigcode/starcoder2-7b                      | StarCoder2                     | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| arcee-ai/Trinity-Mini                      | Trinity (Nano, Mini)           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n\n## Multimodal Language Models\n\n| Models                                        | Model Family (Variants)   |               A2 Supported               |               A3 Supported               |\n|-----------------------------------------------|---------------------------|:----------------------------------------:|:----------------------------------------:|\n| Qwen/Qwen2.5-VL-3B-Instruct                   | Qwen-VL                   | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen/Qwen2.5-VL-72B-Instruct                  | Qwen-VL                   | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen/Qwen3-VL-30B-A3B-Instruct                | Qwen-VL                   | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen/Qwen3-VL-8B-Instruct                     | Qwen-VL                   | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen/Qwen3-VL-4B-Instruct                     | Qwen-VL                   | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen/Qwen3-VL-235B-A22B-Instruct              | Qwen-VL                   | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| 
deepseek-ai/deepseek-vl2                      | DeepSeek-VL2              | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| deepseek-ai/Janus-Pro-1B                      | Janus-Pro (1B, 7B)        | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| deepseek-ai/Janus-Pro-7B                      | Janus-Pro (1B, 7B)        | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| openbmb/MiniCPM-V-2_6                         | MiniCPM-V / MiniCPM-o     | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| openbmb/MiniCPM-o-2_6                         | MiniCPM-V / MiniCPM-o     | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| google/gemma-3-4b-it                          | Gemma 3 (Multimodal)      | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| mistralai/Mistral-Small-3.1-24B-Instruct-2503 | Mistral-Small-3.1-24B     | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| microsoft/Phi-4-multimodal-instruct           | Phi-4-multimodal-instruct | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| XiaomiMiMo/MiMo-VL-7B-RL                      | MiMo-VL (7B)              | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| AI-ModelScope/llava-v1.6-34b                  | LLaVA (v1.5 & v1.6)       | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| lmms-lab/llava-next-72b                       | LLaVA-NeXT (8B, 72B)      | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| lmms-lab/llava-onevision-qwen2-7b-ov          | LLaVA-OneVision           | **<span style=\"color: green;\">√</span>** | **<span 
style=\"color: green;\">√</span>** |\n| Kimi/Kimi-VL-A3B-Instruct                     | Kimi-VL (A3B)             | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| ZhipuAI/GLM-4.5V                              | GLM-4.5V (106B)           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| LLM-Research/Llama-3.2-11B-Vision-Instruct    | Llama 3.2 Vision (11B)    | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| rednote-hilab/dots.ocr                        | DotsVLM-OCR               | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n\n## Embedding Models\n\n| Models                                    | Model Family             |               A2 Supported               |               A3 Supported               |\n|-------------------------------------------|--------------------------|:----------------------------------------:|:----------------------------------------:|\n| intfloat/e5-mistral-7b-instruct           | E5 (Llama/Mistral based) | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| iic/gte_Qwen2-1.5B-instruct               | GTE-Qwen2                | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen/Qwen3-Embedding-8B                   | Qwen3-Embedding          | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Alibaba-NLP/gme-Qwen2-VL-2B-Instruct      | GME (Multimodal)         | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| AI-ModelScope/clip-vit-large-patch14-336  | CLIP                     | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| BAAI/bge-large-en-v1.5                    | BGE                      | **<span style=\"color:
green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n\n## Reward Models\n\n| Models                                         | Model Family              |               A2 Supported               |               A3 Supported               |\n|------------------------------------------------|---------------------------|:----------------------------------------:|:----------------------------------------:|\n| Skywork/Skywork-Reward-Llama-3.1-8B-v0.2       | Llama3.1 Reward           | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Shanghai_AI_Laboratory/internlm2-7b-reward     | InternLM 2 Reward         | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Qwen/Qwen2.5-Math-RM-72B                       | Qwen2.5 Reward - Math     | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| Howeee/Qwen2.5-1.5B-apeach                     | Qwen2.5 Reward - Sequence | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n| AI-ModelScope/Skywork-Reward-Gemma-2-27B-v0.2  | Gemma 2-27B Reward        | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n\n## Rerank Models\n\n| Models                  | Model Family |               A2 Supported               |               A3 Supported               |\n|-------------------------|--------------|:----------------------------------------:|:----------------------------------------:|\n| BAAI/bge-reranker-v2-m3 | BGE-Reranker | **<span style=\"color: green;\">√</span>** | **<span style=\"color: green;\">√</span>** |\n"
  },
  {
    "path": "docs/platforms/cpu_server.md",
    "content": "# CPU Servers\n\nThis document describes how to set up the [SGLang](https://github.com/sgl-project/sglang) environment and run LLM inference on CPU servers.\nSGLang is enabled and optimized on CPUs equipped with Intel® AMX instructions,\nnamely 4th generation or newer Intel® Xeon® Scalable Processors.\n\n## Optimized Model List\n\nA number of popular LLMs are optimized to run efficiently on CPU,\nincluding notable open-source models such as the Llama and Qwen series\nand the DeepSeek series (e.g., DeepSeek-R1 and DeepSeek-V3.1-Terminus).\n\n| Model Name | BF16 | W8A8_INT8 | FP8 |\n|:---:|:---:|:---:|:---:|\n| DeepSeek-R1 |   | [meituan/DeepSeek-R1-Channel-INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8) | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) |\n| DeepSeek-V3.1-Terminus |   | [IntervitensInc/DeepSeek-V3.1-Terminus-Channel-int8](https://huggingface.co/IntervitensInc/DeepSeek-V3.1-Terminus-Channel-int8) | [deepseek-ai/DeepSeek-V3.1-Terminus](https://huggingface.co/deepseek-ai/DeepSeek-V3.1-Terminus) |\n| Llama-3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) | [RedHatAI/Llama-3.2-3B-quantized.w8a8](https://huggingface.co/RedHatAI/Llama-3.2-3B-Instruct-quantized.w8a8) |   |\n| Llama-3.1-8B | [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) | [RedHatAI/Meta-Llama-3.1-8B-quantized.w8a8](https://huggingface.co/RedHatAI/Meta-Llama-3.1-8B-quantized.w8a8) |   |\n| QwQ-32B |   | [RedHatAI/QwQ-32B-quantized.w8a8](https://huggingface.co/RedHatAI/QwQ-32B-quantized.w8a8) |   |\n| DeepSeek-Distilled-Llama |   | [RedHatAI/DeepSeek-R1-Distill-Llama-70B-quantized.w8a8](https://huggingface.co/RedHatAI/DeepSeek-R1-Distill-Llama-70B-quantized.w8a8) |   |\n| Qwen3-235B |   |   | [Qwen/Qwen3-235B-A22B-FP8](https://huggingface.co/Qwen/Qwen3-235B-A22B-FP8) |\n\n**Note:** The model identifiers listed in the table above\nhave
been verified on 6th Gen Intel® Xeon® P-core platforms.\n\n## Installation\n\n### Install Using Docker\n\nIt is recommended to use Docker to set up the SGLang environment.\nA [Dockerfile](https://github.com/sgl-project/sglang/blob/main/docker/xeon.Dockerfile) is provided to facilitate the installation.\nReplace `<secret>` below with your [HuggingFace access token](https://huggingface.co/docs/hub/en/security-tokens).\n\n```bash\n# Clone the SGLang repository\ngit clone https://github.com/sgl-project/sglang.git\ncd sglang/docker\n\n# Build the Docker image\ndocker build -t sglang-cpu:latest -f xeon.Dockerfile .\n\n# Start a Docker container\ndocker run \\\n    -it \\\n    --privileged \\\n    --ipc=host \\\n    --network=host \\\n    -v /dev/shm:/dev/shm \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    -p 30000:30000 \\\n    -e \"HF_TOKEN=<secret>\" \\\n    sglang-cpu:latest /bin/bash\n```\n\n### Install From Source\n\nIf you prefer to install SGLang in a bare-metal environment,\nthe setup process is as follows.\n\nPlease install the required packages and libraries beforehand if\nthey are not already present on your system.\nYou can refer to the Ubuntu-based installation commands in\n[the Dockerfile](https://github.com/sgl-project/sglang/blob/main/docker/xeon.Dockerfile#L11)\nfor guidance.\n\n1. Install the `uv` package manager, then create and activate a virtual environment:\n\n```bash\n# '/opt' is used here as the example uv environment folder; change it as needed\ncd /opt\ncurl -LsSf https://astral.sh/uv/install.sh | sh\nsource $HOME/.local/bin/env\nuv venv --python 3.12\nsource .venv/bin/activate\n```\n\n2. Create a config file to direct the installation channel\n    (a.k.a.
index-url) of `torch`-related packages:\n\n```bash\nvim .venv/uv.toml\n```\n\nPress `a` to enter insert mode in `vim`, then paste the following content into the file:\n\n```toml\n[[index]]\nname = \"torch\"\nurl = \"https://download.pytorch.org/whl/cpu\"\n\n[[index]]\nname = \"torchvision\"\nurl = \"https://download.pytorch.org/whl/cpu\"\n\n[[index]]\nname = \"torchaudio\"\nurl = \"https://download.pytorch.org/whl/cpu\"\n\n[[index]]\nname = \"triton\"\nurl = \"https://download.pytorch.org/whl/cpu\"\n\n```\n\nSave the file (in `vim`, press `Esc` to exit insert mode, then type `:x` and press `Enter`),\nand set it as the default `uv` config:\n\n```bash\nexport UV_CONFIG_FILE=/opt/.venv/uv.toml\n```\n\n3. Clone the `sglang` source code and build the packages:\n\n```bash\n# Clone the SGLang code\ngit clone https://github.com/sgl-project/sglang.git\ncd sglang\ngit checkout <YOUR-DESIRED-VERSION>\n\n# Use the CPU-specific pyproject file\ncd python\ncp pyproject_cpu.toml pyproject.toml\n# Install SGLang's dependencies and build the main SGLang package\nuv pip install --upgrade pip setuptools\nuv pip install .\n\n# Build the CPU backend kernels\ncd ../sgl-kernel\ncp pyproject_cpu.toml pyproject.toml\nuv pip install .\n```\n\n4.
Set the required environment variables:\n\n```bash\nexport SGLANG_USE_CPU_ENGINE=1\n\n# Set 'LD_LIBRARY_PATH' and 'LD_PRELOAD' to ensure the libs can be loaded by sglang processes\nexport LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu\nexport LD_PRELOAD=${LD_PRELOAD}:/opt/.venv/lib/libiomp5.so:${LD_LIBRARY_PATH}/libtcmalloc.so.4:${LD_LIBRARY_PATH}/libtbbmalloc.so.2\n```\n\nNotes:\n\n- The environment variable `SGLANG_USE_CPU_ENGINE=1`\n    is required to run the SGLang service with the CPU engine.\n\n- If you encounter compilation issues while building `sgl-kernel`,\n    please check your `gcc` and `g++` versions and upgrade them if they are outdated.\n    It is recommended to use `gcc-13` and `g++-13`, as they have been verified\n    in the official Docker container.\n\n- The system library path is typically one of the following directories:\n    `~/.local/lib/`, `/usr/local/lib/`, `/usr/local/lib64/`, `/usr/lib/`, `/usr/lib64/`\n    and `/usr/lib/x86_64-linux-gnu/`. The example commands above use `/usr/lib/x86_64-linux-gnu`.\n    Please adjust the path according to your server configuration.\n\n- It is recommended to add the following to your `~/.bashrc` file to\n    avoid setting these variables every time you open a new terminal:\n\n    ```bash\n    source <YOUR-VENV-FOLDER>/bin/activate\n    export SGLANG_USE_CPU_ENGINE=1\n    export LD_LIBRARY_PATH=<YOUR-SYSTEM-LIBRARY-FOLDER>\n    export LD_PRELOAD=<YOUR-LIBS-PATHS>\n    ```\n\n## Launch of the Serving Engine\n\nAn example command to launch the SGLang server:\n\n```bash\npython -m sglang.launch_server   \\\n    --model <MODEL_ID_OR_PATH>   \\\n    --trust-remote-code          \\\n    --disable-overlap-schedule   \\\n    --device cpu                 \\\n    --host 0.0.0.0               \\\n    --tp 6\n```\n\nNotes:\n\n1. To run W8A8 quantized models, add the flag `--quantization w8a8_int8`.\n\n2.
The flag `--tp 6` specifies that tensor parallelism will be applied using 6 ranks (TP6).\n    In general, the value passed to `--tp` is the number of TP ranks used during execution.\n    On a CPU platform, each TP rank corresponds to a sub-NUMA cluster (SNC).\n    The number of available SNCs can usually be obtained from the operating system, e.g., with the `lscpu` command.\n\n    If the specified number of TP ranks `n` is smaller than the total SNC count,\n    the system automatically uses the first `n` SNCs.\n    Note that `n` cannot exceed the total number of SNCs; doing so results in an error.\n\n    The `SGLANG_CPU_OMP_THREADS_BIND` environment variable allows explicit control of the CPU cores used by each TP rank.\n\n    **Example 1**: To run the SGLang service with TP=6 using the first 40 cores of each SNC on a Xeon® 6980P server,\n    which has 43-43-42 cores on the 3 SNCs of a socket, set:\n\n    ```bash\n    export SGLANG_CPU_OMP_THREADS_BIND=\"0-39|43-82|86-125|128-167|171-210|214-253\"\n    ```\n    This configuration is equivalent to:\n    - rank 0: `numactl -C 0-39 -m 0`\n    - rank 1: `numactl -C 43-82 -m 1`\n    - rank 2: `numactl -C 86-125 -m 2`\n    - rank 3: `numactl -C 128-167 -m 3`\n    - rank 4: `numactl -C 171-210 -m 4`\n    - rank 5: `numactl -C 214-253 -m 5`\n\n    **Example 2**: To run the SGLang service with TP=2 using 96 cores (spanning 3 SNCs) per rank on a Xeon® 6972P server,\n    which has 32-32-32 cores on the 3 SNCs of a socket, set:\n\n    ```bash\n    export SGLANG_CPU_OMP_THREADS_BIND=\"0-95|96-191\"\n    ```\n    This configuration is equivalent to:\n    - rank 0: `numactl -C 0-95 -m 0-2`\n    - rank 1: `numactl -C 96-191 -m 3-5`\n\n    Note that when `SGLANG_CPU_OMP_THREADS_BIND` is set,\n    the amount of memory available to each rank may not be determined in advance.\n    You may need to set an appropriate `--max-total-tokens` value to avoid out-of-memory errors.\n\n3.
To optimize decoding with `torch.compile`, add the flag `--enable-torch-compile`.\n    To specify the maximum batch size when using `torch.compile`, set the flag `--torch-compile-max-bs`.\n    For example, `--enable-torch-compile --torch-compile-max-bs 4` means using `torch.compile`\n    and setting the maximum batch size to 4.\n\n4. A warmup step is automatically triggered when the service is started.\n    The server is ready when you see the log `The server is fired up and ready to roll!`.\n\n## Benchmarking with Requests\n\nYou can benchmark performance with the `bench_serving` script.\nRun it in another terminal, for example:\n\n```bash\npython -m sglang.bench_serving   \\\n    --dataset-name random        \\\n    --random-input-len 1024      \\\n    --random-output-len 1024     \\\n    --num-prompts 1              \\\n    --request-rate inf           \\\n    --random-range-ratio 1.0\n```\n\nDetailed parameter descriptions are available via the command:\n\n```bash\npython -m sglang.bench_serving -h\n```\n\nAdditionally, requests can be formatted using\n[the OpenAI Completions API](https://docs.sglang.io/basic_usage/openai_api_completions.html)\nand sent via the command line (e.g., using `curl`) or through your own scripts.\n\n## Example Usage Commands\n\nLarge Language Models can range from fewer than 1 billion to several hundred billion parameters.\nDense models larger than 20B are expected to run on flagship 6th Gen Intel® Xeon® processors\nwith dual sockets and a total of 6 sub-NUMA clusters.
Dense models of approximately 10B parameters or fewer,\nor MoE (Mixture of Experts) models with fewer than 10B activated parameters, can run on more common\n4th generation or newer Intel® Xeon® processors, or on a single socket of a flagship 6th Gen Intel® Xeon® processor.\n\n### Example: Running DeepSeek-V3.1-Terminus\n\nAn example command to launch the W8A8_INT8 DeepSeek-V3.1-Terminus service on a Xeon® 6980P server:\n\n```bash\npython -m sglang.launch_server                                 \\\n    --model IntervitensInc/DeepSeek-V3.1-Terminus-Channel-int8 \\\n    --trust-remote-code                                        \\\n    --disable-overlap-schedule                                 \\\n    --device cpu                                               \\\n    --quantization w8a8_int8                                   \\\n    --host 0.0.0.0                                             \\\n    --enable-torch-compile                                     \\\n    --torch-compile-max-bs 4                                   \\\n    --tp 6\n```\n\nSimilarly, an example command to launch the FP8 DeepSeek-V3.1-Terminus service is:\n\n```bash\npython -m sglang.launch_server                     \\\n    --model deepseek-ai/DeepSeek-V3.1-Terminus     \\\n    --trust-remote-code                            \\\n    --disable-overlap-schedule                     \\\n    --device cpu                                   \\\n    --host 0.0.0.0                                 \\\n    --enable-torch-compile                         \\\n    --torch-compile-max-bs 4                       \\\n    --tp 6\n```\n\nNote: Please set `--torch-compile-max-bs` to the maximum desired batch size for your deployment,\nwhich can be up to 16.
The value `4` in the examples is illustrative.\n\n### Example: Running Llama-3.2-3B\n\nAn example command to launch Llama-3.2-3B with BF16 precision:\n\n```bash\npython -m sglang.launch_server                     \\\n    --model meta-llama/Llama-3.2-3B-Instruct       \\\n    --trust-remote-code                            \\\n    --disable-overlap-schedule                     \\\n    --device cpu                                   \\\n    --host 0.0.0.0                                 \\\n    --enable-torch-compile                         \\\n    --torch-compile-max-bs 16                      \\\n    --tp 2\n```\n\nAn example command to launch the W8A8_INT8 version of Llama-3.2-3B:\n\n```bash\npython -m sglang.launch_server                     \\\n    --model RedHatAI/Llama-3.2-3B-quantized.w8a8   \\\n    --trust-remote-code                            \\\n    --disable-overlap-schedule                     \\\n    --device cpu                                   \\\n    --quantization w8a8_int8                       \\\n    --host 0.0.0.0                                 \\\n    --enable-torch-compile                         \\\n    --torch-compile-max-bs 16                      \\\n    --tp 2\n```\n\nNote: The `--torch-compile-max-bs` and `--tp` settings are examples that should be adjusted for your setup.\nFor instance, use `--tp 3` to utilize 1 socket with 3 sub-NUMA clusters on an Intel® Xeon® 6980P server.\n\nOnce the server has been launched, you can test it using the `bench_serving` command or create\nyour own commands or scripts following [the benchmarking example](#benchmarking-with-requests).\n"
  },
  {
    "path": "docs/platforms/mindspore_backend.md",
    "content": "# MindSpore Models\n\n## Introduction\n\nMindSpore is a high-performance AI framework optimized for Ascend NPUs. This document explains how to run MindSpore models in SGLang.\n\n## Requirements\n\nMindSpore currently supports only Ascend NPU devices. Users must first install the Ascend CANN software packages.\nThe CANN software packages can be downloaded from the [Ascend Official Website](https://www.hiascend.com). The recommended version is 8.3.RC2.\n\n## Supported Models\n\nCurrently, the following models are supported:\n\n- **Qwen3**: Dense and MoE models\n- **DeepSeek V3/R1**\n- *More models coming soon...*\n\n## Installation\n\n> **Note**: MindSpore models are currently provided by a separate package, `sgl-mindspore`. MindSpore support is built upon SGLang's existing support for the Ascend NPU platform. Please first [install SGLang for Ascend NPU](ascend_npu.md) and then install `sgl-mindspore`:\n\n```shell\ngit clone https://github.com/mindspore-lab/sgl-mindspore.git\ncd sgl-mindspore\npip install -e .\n```\n\n## Run Model\n\nSGLang-MindSpore currently supports Qwen3 and DeepSeek V3/R1 models.
This document uses Qwen3-8B as an example.\n\n### Offline inference\n\nUse the following script for offline inference:\n\n```python\nimport sglang as sgl\n\n# Initialize the engine with the MindSpore backend\nllm = sgl.Engine(\n    model_path=\"/path/to/your/model\",  # Local model path\n    device=\"npu\",                      # Use NPU device\n    model_impl=\"mindspore\",            # MindSpore implementation\n    attention_backend=\"ascend\",        # Attention backend\n    tp_size=1,                         # Tensor parallelism size\n    dp_size=1                          # Data parallelism size\n)\n\n# Generate text\nprompts = [\n    \"Hello, my name is\",\n    \"The capital of France is\",\n    \"The future of AI is\"\n]\n\nsampling_params = {\"temperature\": 0, \"top_p\": 0.9}\noutputs = llm.generate(prompts, sampling_params)\n\nfor prompt, output in zip(prompts, outputs):\n    print(f\"Prompt: {prompt}\")\n    print(f\"Generated: {output['text']}\")\n    print(\"---\")\n```\n\n### Start server\n\nLaunch a server with the MindSpore backend:\n\n```bash\n# Basic server startup\npython3 -m sglang.launch_server \\\n    --model-path /path/to/your/model \\\n    --host 0.0.0.0 \\\n    --device npu \\\n    --model-impl mindspore \\\n    --attention-backend ascend \\\n    --tp-size 1 \\\n    --dp-size 1\n```\n\nFor a distributed server across multiple nodes, run the following command on each node, replacing `127.0.0.1` with the IP address of node 0 and setting `--node-rank` to the rank of that node (0 on node 0, 1 on node 1):\n\n```bash\n# Multi-node distributed server\npython3 -m sglang.launch_server \\\n    --model-path /path/to/your/model \\\n    --host 0.0.0.0 \\\n    --device npu \\\n    --model-impl mindspore \\\n    --attention-backend ascend \\\n    --dist-init-addr 127.0.0.1:29500 \\\n    --nnodes 2 \\\n    --node-rank 0 \\\n    --tp-size 4 \\\n    --dp-size 2\n```\n\n## Troubleshooting\n\n#### Debug Mode\n\nEnable SGLang debug logging with the `--log-level` 
argument:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path /path/to/your/model \\\n    --host 0.0.0.0 \\\n    --device npu \\\n    --model-impl mindspore \\\n    --attention-backend ascend
\\\n    --log-level DEBUG\n```\n\nEnable MindSpore INFO or DEBUG logging by setting the `GLOG_v` environment variable:\n\n```bash\nexport GLOG_v=1  # INFO level\n# or\nexport GLOG_v=0  # DEBUG level\n```\n\n#### Explicitly select devices\n\nUse the following environment variable to explicitly select the devices to use:\n\n```shell\nexport ASCEND_RT_VISIBLE_DEVICES=4,5,6,7  # make only NPUs 4-7 visible\n```\n\n#### Communication environment issues\n\nIn some environments with special communication setups, users need to set the following environment variable:\n\n```shell\nexport MS_ENABLE_LCCL=off  # LCCL communication mode is not currently supported in SGLang-MindSpore\n```\n\n#### Protobuf dependency issues\n\nIn environments with certain protobuf versions, users need to set the following environment variable to avoid a binary version mismatch:\n\n```shell\nexport PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python  # to avoid protobuf binary version mismatch\n```\n\n## Support\n\nFor MindSpore-specific issues:\n\n- Refer to the [MindSpore documentation](https://www.mindspore.cn/)\n"
  },
  {
    "path": "docs/platforms/mthreads_gpu.md",
    "content": "# Moore Threads GPUs\n\nThis document describes how to run SGLang on Moore Threads GPUs. If you encounter issues or have questions, please [open an issue](https://github.com/sgl-project/sglang/issues).\n\n## Install SGLang\n\nYou can install SGLang using the method below.\n\n### Install from Source\n\n```bash\n# Use the default branch\ngit clone https://github.com/sgl-project/sglang.git\ncd sglang\n\n# Compile sgl-kernel\npip install --upgrade pip\ncd sgl-kernel\npython setup_musa.py install\n\n# Install the sglang Python package\ncd ..\nrm -f python/pyproject.toml && mv python/pyproject_other.toml python/pyproject.toml\npip install -e \"python[all_musa]\"\n```\n"
  },
  {
    "path": "docs/platforms/nvidia_jetson.md",
    "content": "# NVIDIA Jetson Orin\n\n## Prerequisites\n\nBefore starting, ensure the following:\n\n- [**NVIDIA Jetson AGX Orin Devkit**](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/jetson-orin/) is set up with **JetPack 6.1** or later.\n- **CUDA Toolkit** and **cuDNN** are installed.\n- The Jetson AGX Orin is in **high-performance mode**, which you can set with:\n```bash\nsudo nvpmodel -m 0\n```\n* * * * *\n## Installing and running SGLang with Jetson Containers\nClone the jetson-containers GitHub repository:\n```bash\ngit clone https://github.com/dusty-nv/jetson-containers.git\n```\nRun the installation script:\n```bash\nbash jetson-containers/install.sh\n```\nBuild the container image:\n```bash\njetson-containers build sglang\n```\nRun the container:\n```bash\njetson-containers run $(autotag sglang)\n```\nAlternatively, you can manually run a container with this command:\n```bash\ndocker run --runtime nvidia -it --rm --network=host IMAGE_NAME\n```\n* * * * *\n\nRunning Inference\n-----------------------------------------\n\nLaunch the server:\n```bash\npython -m sglang.launch_server \\\n  --model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \\\n  --device cuda \\\n  --dtype half \\\n  --attention-backend flashinfer \\\n  --mem-fraction-static 0.8 \\\n  --context-length 8192\n```\nThe reduced precision and limited context length (`--dtype half --context-length 8192`) are due to the limited computational resources of the [NVIDIA Jetson kit](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/jetson-orin/).
A detailed explanation can be found in [Server Arguments](../advanced_features/server_arguments.md).\n\nAfter launching the engine, refer to [Chat completions](https://docs.sglang.io/basic_usage/openai_api_completions.html#Usage) to test the server.\n* * * * *\nRunning quantization with TorchAO\n-------------------------------------\nTorchAO is recommended for NVIDIA Jetson Orin.\n```bash\npython -m sglang.launch_server \\\n    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n    --device cuda \\\n    --dtype bfloat16 \\\n    --attention-backend flashinfer \\\n    --mem-fraction-static 0.8 \\\n    --context-length 8192 \\\n    --torchao-config int4wo-128\n```\nThis enables TorchAO's int4 weight-only quantization with a group size of 128, which further reduces memory usage.\n\n* * * * *\nStructured output with XGrammar\n-------------------------------\nPlease refer to the [SGLang structured outputs documentation](../advanced_features/structured_outputs.ipynb).\n* * * * *\n\nThanks to the support from [Nurgaliyev Shakhizat](https://github.com/shahizat), [Dustin Franklin](https://github.com/dusty-nv) and [Johnny Núñez Cano](https://github.com/johnnynunez).\n\nReferences\n----------\n-   [NVIDIA Jetson AGX Orin Documentation](https://developer.nvidia.com/embedded/jetson-agx-orin)\n"
  },
  {
    "path": "docs/platforms/tpu.md",
    "content": "# TPU\n\nSGLang supports high-performance TPU inference through the SGLang-JAX backend, which is specifically optimized for Google Cloud TPUs. The JAX-based implementation delivers exceptional throughput and low latency for Large Language Model (LLM) serving workloads on TPU hardware.\n\nFor TPU-specific issues or feature requests, please visit the [sglang-jax GitHub issues page](https://github.com/sgl-project/sglang-jax/issues).\n\n**NOTE:** SGLang TPU support is implemented via the SGLang-JAX backend, a dedicated JAX-based inference engine maintained as a separate repository at [https://github.com/sgl-project/sglang-jax](https://github.com/sgl-project/sglang-jax).\n\n## System Requirements\n\n### Supported TPU Hardware\n\n| TPU Type | HBM Memory | Availability |\n|----------|-----------|--------------|\n| TPU v6e | 32 GB | Google Cloud |\n| TPU v7 | 96 GB per core | Google Cloud |\n\n### Software Requirements\n\n- **Python:** 3.12 or higher\n- **JAX:** Latest version with TPU support\n- **Environment:** Google Cloud TPU VM or compatible TPU runtime\n- **Optional:** SkyPilot for simplified cloud deployment\n\n## Feature Support Matrix\n\nSGLang-JAX provides comprehensive TPU-optimized features for production LLM serving:\n\n| Feature | Support Status | Description |\n|---------|---------------|-------------|\n| High-Throughput Continuous Batching | ✅ | Dynamic request batching for maximum TPU utilization |\n| Radix Tree KV Cache | ✅ | Memory-efficient prefix sharing between requests |\n| FlashAttention Backend | ✅ | TPU-optimized attention kernel for long sequences |\n| Tensor Parallelism | ✅ | Distribute models across multiple TPU cores |\n| Paged Attention | ✅ | Flexible KV cache management with paging |\n| Speculative Decoding (EAGLE/EAGLE3) | ✅ | 20-40% throughput improvement for compatible models |\n| Chunked Prefill | ✅ | Mixed prefill-decode batching |\n| OpenAI-Compatible API | ✅ | Drop-in replacement for OpenAI API |\n| Data Parallel 
Attention | 🚧 | In development - Attention computation with data parallelism |\n| Quantization | 🚧 | In development - Model quantization for reduced memory usage |\n| Multi-LoRA | 🚧 | In development - Serve multiple LoRA adapters simultaneously |\n\n### Attention Backend Comparison\n\n| Backend | Paged Attention | Spec Decoding | MLA | Sliding Window |\n|---------|----------------|---------------|-----|----------------|\n| FlashAttention (fa) | ✅ | ✅ | ❌ | ✅ |\n| Native | ❌ | ❌ | ❌ | ❌ |\n\n**NOTE:** FlashAttention backend is recommended for production workloads due to superior memory efficiency and performance.\n\n## Optimized Model List\n\nThe following models have been tested and optimized for TPU deployment:\n\n| Model Family | Performance Status |\n|--------------|-------------------|\n| [Qwen 3](https://huggingface.co/Qwen) | ⭐ Recommended for production |\n| [Qwen 3 MoE](https://huggingface.co/Qwen) | ⭐ Best performance |\n| [Qwen 2](https://huggingface.co/Qwen) | Needs improvement |\n| [Qwen 2 MoE](https://huggingface.co/Qwen) | Needs improvement |\n| [Qwen 1.5](https://huggingface.co/Qwen) | Needs improvement |\n| [Llama/LLaMA](https://huggingface.co/meta-llama) | Needs improvement |\n| [Grok-2](https://huggingface.co/xai-org) | Needs improvement |\n| [Gemma 2](https://huggingface.co/google) | Verified on TPU |\n| Bailing MoE | Needs improvement |\n\n## Installation\n\n### Method 1: Using PyPI (Recommended)\n\n```bash\npip install sglang-jax\n```\n\n### Method 2: From Source\n\n```bash\ngit clone https://github.com/sgl-project/sglang-jax\ncd sglang-jax\nuv venv --python 3.12 && source .venv/bin/activate\nuv pip install -e \"python[all]\"\n```\n\n### Method 3: Using Docker\n\n**NOTE:** Docker support for TPU is currently under development. Please use PyPI or source installation methods.\n\n### Method 4: Cloud TPU with SkyPilot\n\n[SkyPilot](https://github.com/skypilot-org/skypilot) provides simplified deployment on Google Cloud TPU:\n\n1. 
Install SkyPilot and configure GCP access (see [SkyPilot documentation](https://skypilot.readthedocs.io/))\n\n2. Create a SkyPilot configuration file:\n\n<details>\n<summary>SkyPilot YAML: <code>sglang-jax.sky.yaml</code></summary>\n\n```yaml\n# sglang-jax.sky.yaml\nresources:\n   accelerators: tpu-v6e-4\n   accelerator_args:\n      tpu_vm: True\n      runtime_version: v2-alpha-tpuv6e\n\nrun: |\n  git clone https://github.com/sgl-project/sglang-jax.git\n  cd sglang-jax\n  uv venv --python 3.12\n  source .venv/bin/activate\n  uv pip install -e \"python[all]\"\n```\n\n</details>\n\n3. Launch your TPU cluster:\n\n```bash\n# Standard deployment\nsky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp\n\n# With spot instances for cost savings\nsky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp --use-spot\n```\n\n## Launch of the Serving Engine\n\n### Basic Example: Qwen-7B\n\n```bash\nJAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 -u -m sgl_jax.launch_server \\\n    --model-path Qwen/Qwen-7B-Chat \\\n    --trust-remote-code \\\n    --dist-init-addr=0.0.0.0:10011 \\\n    --nnodes=1 \\\n    --tp-size=4 \\\n    --device=tpu \\\n    --random-seed=3 \\\n    --node-rank=0 \\\n    --mem-fraction-static=0.8 \\\n    --max-prefill-tokens=8192 \\\n    --download-dir=/tmp \\\n    --dtype=bfloat16 \\\n    --skip-server-warmup \\\n    --host 0.0.0.0 \\\n    --port 30000\n```\n\n**Key Parameters Explained:**\n\n1. `JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache` - Enables JIT compilation caching to accelerate server startup on subsequent runs\n2. `--tp-size=4` - Tensor parallelism size; match this to your TPU core count (typically 1, 4, or 8)\n3. `--device=tpu` - Specifies TPU device (this is the default for sglang-jax)\n4. `--dtype=bfloat16` - Uses bfloat16 precision, which TPUs are optimized for\n5. `--mem-fraction-static=0.8` - Allocates 80% of TPU HBM for static memory (adjustable from 0.2 to 0.9)\n6. 
`--max-prefill-tokens=8192` - Maximum number of tokens processed in the prefill phase\n\n### High-Performance Configuration: Qwen3-8B\n\nFor production workloads with optimal throughput:\n\n```bash\npython3 -u -m sgl_jax.launch_server \\\n    --model-path Qwen/Qwen3-8B \\\n    --trust-remote-code \\\n    --tp-size=4 \\\n    --device=tpu \\\n    --mem-fraction-static=0.8 \\\n    --chunked-prefill-size=2048 \\\n    --dtype=bfloat16 \\\n    --max-running-requests=256 \\\n    --page-size=128 \\\n    --attention-backend=fa\n```\n\n### Advanced: Speculative Decoding (EAGLE3)\n\nSpeculative decoding can improve throughput by 20-40% for compatible models:\n\n```bash\npython3 -u -m sgl_jax.launch_server \\\n    --model-path Qwen/Qwen3-32B \\\n    --trust-remote-code \\\n    --device=tpu \\\n    --tp-size=4 \\\n    --mem-fraction-static=0.8 \\\n    --max-prefill-tokens=4096 \\\n    --attention-backend=fa \\\n    --dtype=bfloat16 \\\n    --port=30000 \\\n    --host=0.0.0.0 \\\n    --disable-overlap-schedule \\\n    --speculative-algorithm=EAGLE3 \\\n    --speculative-draft-model-path=AngelSlim/Qwen3-32B_eagle3 \\\n    --page-size=64 \\\n    --speculative-eagle-topk=1 \\\n    --speculative-num-steps=3 \\\n    --speculative-num-draft-tokens=4\n```\n\n**NOTE:** Speculative decoding is currently supported for Qwen3 and LLaMA model families. 
See the [Speculative Decoding documentation](https://github.com/sgl-project/sglang-jax/blob/main/docs/features/speculative_decoding.md) for detailed configuration guidance.\n\n\n### Multi-Node Distributed Serving\n\nFor large models requiring multiple TPU VMs:\n\n```bash\n# Node 0 (coordinator)\npython3 -m sgl_jax.launch_server \\\n    --model-path MODEL_PATH \\\n    --dist-init-addr=NODE0_IP:10011 \\\n    --nnodes=2 \\\n    --node-rank=0 \\\n    --tp-size=8 \\\n    [other parameters...]\n\n# Node 1 (worker)\npython3 -m sgl_jax.launch_server \\\n    --model-path MODEL_PATH \\\n    --dist-init-addr=NODE0_IP:10011 \\\n    --nnodes=2 \\\n    --node-rank=1 \\\n    --tp-size=8 \\\n    [other parameters...]\n```\n\n## Benchmarking with Requests\n\n### Throughput Testing\n\nBasic throughput benchmark:\n\n```bash\npython3 -m sgl_jax.bench_serving \\\n    --backend sgl-jax \\\n    --dataset-name random \\\n    --num-prompts=100 \\\n    --random-input=512 \\\n    --random-output=128 \\\n    --max-concurrency=8 \\\n    --random-range-ratio=1 \\\n    --warmup-requests=0\n```\n\n### Latency Testing\n\nMeasure single-batch latency:\n\n```bash\npython3 -m sgl_jax.bench_one_batch_server \\\n    --base-url http://127.0.0.1:30000 \\\n    --model-path Qwen/Qwen-7B-Chat \\\n    --batch-size=32 \\\n    --input-len=256 \\\n    --output-len=32\n```\n\n### Comprehensive Benchmark Script\n\nFor systematic performance evaluation across different configurations:\n\n```bash\n#!/bin/bash\nset -e\n\nbackend=${1:-sgl-jax}\nnum_prompts_per_concurrency=3\ninput_seq_lens=(1024 4096 8192)\noutput_seq_lens=(1 1024)\nmax_concurrencies=(8 16 32 64 128 256)\n\nfor input_seq_len in \"${input_seq_lens[@]}\"; do\n    for output_seq_len in \"${output_seq_lens[@]}\"; do\n        echo \"=======================================\"\n        echo \"Testing ISL/OSL: $input_seq_len/$output_seq_len\"\n        echo \"=======================================\"\n        for max_concurrency in \"${max_concurrencies[@]}\"; 
do\n            num_prompts=$((num_prompts_per_concurrency * max_concurrency))\n            python3 -m sgl_jax.bench_serving \\\n                --backend ${backend} \\\n                --dataset-name random \\\n                --num-prompts ${num_prompts} \\\n                --random-input ${input_seq_len} \\\n                --random-output ${output_seq_len} \\\n                --max-concurrency ${max_concurrency} \\\n                --random-range-ratio 1 \\\n                --disable-ignore-eos \\\n                --warmup-requests 0\n        done\n    done\ndone\n```\n\nFor detailed help on all benchmark parameters:\n\n```bash\npython3 -m sgl_jax.bench_serving --help\n```\n\nSee the [Benchmark and Profiling Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/benchmark_and_profiling.md) for advanced benchmarking techniques and profiling with JAX Profiler.\n\n## Performance Optimization\n\n### Memory Optimization\n\n**Reduce memory usage:**\n- Lower `--mem-fraction-static` (from 0.8 → 0.5 → 0.3)\n- Decrease `--max-prefill-tokens` (from 16384 → 8192 → 4096)\n- Reduce `--max-running-requests`\n\n**Handle OOM errors:**\n- Start with conservative memory settings (`--mem-fraction-static=0.5`)\n- Gradually increase until you find the optimal balance\n- Increase `--page-size` for better memory locality (1 → 16 → 64 → 128)\n\n### Throughput Optimization\n\nTo maximize tokens per second:\n\n- Use FlashAttention backend: `--attention-backend=fa`\n- Enable speculative decoding (EAGLE3) for Qwen3 models (20-40% improvement)\n- Increase `--max-running-requests` to 256+\n- Set `--mem-fraction-static` to 0.8+ (if memory allows)\n- Use larger page sizes (64-128)\n- Enable chunked prefill: `--chunked-prefill-size=2048`\n\n### Latency Optimization\n\nTo minimize time-to-first-token (TTFT) and inter-token latency:\n\n- Reduce `--page-size` to 1-4\n- Lower `--max-running-requests` (16-32) for smaller batches\n- Reduce `--chunked-prefill-size`\n- Use 
conservative memory settings to avoid GC pauses\n\n### TPU-Specific Optimizations\n\n1. **JIT Compilation Cache:**\n   ```bash\n   export JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache\n   ```\n   Always set this environment variable to cache compiled kernels and accelerate server startup.\n\n2. **Data Type Optimization:**\n   Use `--dtype=bfloat16`, the TPU-native data type. TPUs are specifically designed for bfloat16 computations.\n\n3. **Tensor Parallelism:**\n   Match `--tp-size` to your TPU core configuration (1, 4, or 8) for optimal model distribution.\n\n4. **Attention Backend:**\n   Always use `--attention-backend=fa` (FlashAttention) for production workloads.\n\n## Troubleshooting\n\n### OOM (Out of Memory) Errors\n\nIf you encounter out-of-memory errors:\n\n1. Reduce `--mem-fraction-static` from 0.8 to 0.5 or lower\n2. Decrease `--max-prefill-tokens` from 8192 to 4096 or 2048\n3. Lower `--max-running-requests` to reduce concurrent batch size\n4. Increase `--page-size` for better memory layout efficiency\n\n### Long Compilation Time\n\nIf the server takes too long to start:\n\n1. Ensure `JAX_COMPILATION_CACHE_DIR` is properly set\n2. Expect the first run to be slow because it requires JIT compilation (this is normal)\n3. Subsequent runs will be significantly faster with cached compilations\n4. Consider using `--skip-server-warmup` to defer compilation until the first request\n\n### Low Throughput\n\nIf you're not achieving the expected throughput:\n\n1. Verify `--tp-size` matches your TPU core configuration\n2. Check that `--attention-backend=fa` is enabled\n3. Increase `--max-running-requests` to enable larger batch formation\n4. Consider enabling speculative decoding for compatible models\n5. Ensure memory settings allow for sufficient batch sizes\n\n### Connection Issues\n\nIf clients cannot connect to the server:\n\n1. Ensure `--host=0.0.0.0` for external access (instead of `127.0.0.1`)\n2. Verify firewall rules allow traffic on the specified port (default: 30000)\n3. 
Check that the server process is running: `curl http://localhost:30000/health`\n\n## Advanced Features\n\n### Speculative Decoding\n\nSGLang-JAX supports EAGLE and EAGLE3 speculative decoding algorithms for Qwen3 and LLaMA model families. Speculative decoding can improve throughput by 20-40% without affecting output quality.\n\nSee the [Speculative Decoding documentation](https://github.com/sgl-project/sglang-jax/blob/main/docs/features/speculative_decoding.md) for detailed configuration and supported model combinations.\n\n### Chunked Prefill\n\nEnable mixed prefill-decode batching for better TPU utilization:\n\n```bash\n--chunked-prefill-size=2048 --enable-mixed-chunk\n```\n\nThis allows the scheduler to mix prefill operations with decode operations in the same batch, improving overall throughput.\n\n### Custom Attention Backends\n\nSGLang-JAX supports a plugin-based attention backend system. You can implement custom attention kernels optimized for specific use cases.\n\nSee the [Attention Backend documentation](https://github.com/sgl-project/sglang-jax/blob/main/docs/features/attention_backend.md) for implementation details.\n\n### Environment Verification\n\nVerify your TPU setup before deploying:\n\n```bash\npython -c \"from sgl_jax import check_env; check_env.check_env()\"\n```\n\nThis command checks:\n- Installed package versions\n- TPU device availability and specifications\n- System resources and configuration\n- Compatibility of settings\n\n## Contributing\n\nWe welcome contributions to improve TPU support in SGLang-JAX!\n\n### Areas for Contribution\n\n**Check the [Development Roadmap](https://github.com/sgl-project/sglang-jax/issues/190)** to see planned features and find opportunities to contribute new functionality.\n\nCurrent contribution areas include:\n\n- Performance optimizations for specific TPU generations\n- Support for additional model architectures\n- Documentation improvements and examples\n- Bug reports and fixes\n- Benchmark results and 
performance analysis\n\n### How to Contribute\n\n1. Visit the [sglang-jax repository](https://github.com/sgl-project/sglang-jax)\n2. Read the [Contribution Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/contribution_guide.md)\n3. Join the [SGL-JAX Slack community](https://sgl-fru7574.slack.com/archives/C09EBE5HT5X) for discussions\n4. Report issues at [sglang-jax/issues](https://github.com/sgl-project/sglang-jax/issues)\n\n### Testing on TPU\n\nFor contributors who need TPU access for testing:\n\n- Refer to the [TPU Resources Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/tpu_resources_guide.md) for information on accessing TPU hardware\n- Use SkyPilot with spot instances for cost-effective testing\n- Follow the [Benchmark and Profiling Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/benchmark_and_profiling.md) for performance validation\n\n## References\n\n### Documentation\n\n- [SGLang-JAX Repository](https://github.com/sgl-project/sglang-jax)\n- [SGLang-JAX Installation Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/get_started/install.md)\n- [Qwen Models Quick Start](https://github.com/sgl-project/sglang-jax/blob/main/docs/basic_usage/qwen.md)\n- [Benchmark and Profiling Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/benchmark_and_profiling.md)\n- [Speculative Decoding](https://github.com/sgl-project/sglang-jax/blob/main/docs/features/speculative_decoding.md)\n\n### External Resources\n\n- [JAX Documentation](https://jax.readthedocs.io/)\n- [Google Cloud TPU Documentation](https://cloud.google.com/tpu/docs)\n- [SkyPilot Documentation](https://skypilot.readthedocs.io/)\n"
  },
  {
    "path": "docs/platforms/xpu.md",
    "content": "# XPU\n\nThis document describes how to set up the [SGLang](https://github.com/sgl-project/sglang) environment and run LLM inference on Intel GPUs ([see more context about Intel GPU support within the PyTorch ecosystem](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html)).\n\nSpecifically, SGLang is optimized for [Intel® Arc™ Pro B-Series Graphics](https://www.intel.com/content/www/us/en/ark/products/series/242616/intel-arc-pro-b-series-graphics.html) and [Intel® Arc™ B-Series Graphics](https://www.intel.com/content/www/us/en/ark/products/series/240391/intel-arc-b-series-graphics.html).\n\n## Optimized Model List\n\nThe following LLMs have been optimized on Intel GPUs, and more are on the way:\n\n| Model Name | BF16 |\n|:---:|:---:|\n| Llama-3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |\n| Llama-3.1-8B | [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) |\n| Qwen2.5-1.5B | [Qwen/Qwen2.5-1.5B](https://huggingface.co/Qwen/Qwen2.5-1.5B) |\n\n**Note:** The model identifiers listed in the table above\nhave been verified on [Intel® Arc™ B580 Graphics](https://www.intel.com/content/www/us/en/products/sku/241598/intel-arc-b580-graphics/specifications.html).\n\n## Installation\n\n### Install From Source\n\nCurrently, SGLang on XPU can only be installed from source. 
Please refer to [\"Getting Started on Intel GPU\"](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html) to install the XPU dependencies.\n\n```bash\n# Create and activate a conda environment\nconda create -n sgl-xpu python=3.12 -y\nconda activate sgl-xpu\n\n# Use the PyTorch XPU index as the primary pip channel to avoid installing the larger CUDA-enabled build and prevent potential runtime issues.\npip3 install torch==2.10.0+xpu torchao torchvision torchaudio triton-xpu==3.6.0 --index-url https://download.pytorch.org/whl/xpu\npip3 install xgrammar --no-deps # xgrammar's dependencies would pull in CUDA-enabled triton, which might conflict with XPU\n\n# Clone the SGLang code\ngit clone https://github.com/sgl-project/sglang.git\ncd sglang\ngit checkout <YOUR-DESIRED-VERSION>\n\n# Use the dedicated XPU pyproject file\ncd python\ncp pyproject_xpu.toml pyproject.toml\n# Install SGLang's dependencies and build the main SGLang package\npip install --upgrade pip setuptools\npip install -v . --extra-index-url https://download.pytorch.org/whl/xpu\n```\n\n### Install Using Docker\n\nThe Docker image for XPU is under active development. 
Please stay tuned.\n\n## Launch of the Serving Engine\n\nExample command to launch SGLang serving:\n\n```bash\n# --tp 2: use multiple GPUs (tensor parallelism across 2 GPUs)\n# --attention-backend intel_xpu: use the Intel-optimized XPU attention backend\n# --page-size: the intel_xpu attention backend supports 32, 64, or 128\npython -m sglang.launch_server       \\\n    --model <MODEL_ID_OR_PATH>       \\\n    --trust-remote-code              \\\n    --disable-overlap-schedule       \\\n    --device xpu                     \\\n    --host 0.0.0.0                   \\\n    --tp 2                           \\\n    --attention-backend intel_xpu    \\\n    --page-size 64\n```\n\n## Benchmarking with Requests\n\nYou can benchmark the performance with the `bench_serving` script.\nRun the following command in another terminal:\n\n```bash\npython -m sglang.bench_serving   \\\n    --dataset-name random        \\\n    --random-input-len 1024      \\\n    --random-output-len 1024     \\\n    --num-prompts 1              \\\n    --request-rate inf           \\\n    --random-range-ratio 1.0\n```\n\nDetailed explanations of the parameters can be displayed with:\n\n```bash\npython -m sglang.bench_serving -h\n```\n\nAdditionally, requests can be formed with the\n[OpenAI Completions API](https://docs.sglang.io/basic_usage/openai_api_completions.html)\nand sent from the command line (e.g. using `curl`) or from your own script.\n"
  },
  {
    "path": "docs/references/custom_chat_template.md",
    "content": "# Custom Chat Template\n\n**NOTE**: There are two chat template systems in the SGLang project. This document is about setting a custom chat template for the OpenAI-compatible API server (defined in [conversation.py](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/conversation.py)). It is NOT related to the chat template used in the SGLang language frontend (defined in [chat_template.py](https://github.com/sgl-project/sglang/blob/main/python/sglang/lang/chat_template.py)).\n\nBy default, the server uses the chat template specified in the model's Hugging Face tokenizer.\nIt should just work for most official models such as Llama-2/Llama-3.\n\nIf needed, you can also override the chat template when launching the server:\n\n```bash\npython -m sglang.launch_server \\\n  --model-path meta-llama/Llama-2-7b-chat-hf \\\n  --port 30000 \\\n  --chat-template llama-2\n```\n\nIf the chat template you are looking for is missing, you are welcome to contribute it or load it from a file.\n\n## JSON Format\n\nYou can load a template written in the JSON format defined by `conversation.py`.\n\n```json\n{\n  \"name\": \"my_model\",\n  \"system\": \"<|im_start|>system\",\n  \"user\": \"<|im_start|>user\",\n  \"assistant\": \"<|im_start|>assistant\",\n  \"sep_style\": \"CHATML\",\n  \"sep\": \"<|im_end|>\",\n  \"stop_str\": [\"<|im_end|>\", \"<|im_start|>\"]\n}\n```\n\n```bash\npython -m sglang.launch_server \\\n  --model-path meta-llama/Llama-2-7b-chat-hf \\\n  --port 30000 \\\n  --chat-template ./my_model_template.json\n```\n\n## Jinja Format\n\nYou can also use the [Jinja template format](https://huggingface.co/docs/transformers/main/en/chat_templating) as defined by Hugging Face Transformers.\n\n```bash\npython -m sglang.launch_server \\\n  --model-path meta-llama/Llama-2-7b-chat-hf \\\n  --port 30000 \\\n  --chat-template ./my_model_template.jinja\n```\n"
  },
  {
    "path": "docs/references/environment_variables.md",
    "content": "# Environment Variables\n\nSGLang supports various environment variables that can be used to configure its runtime behavior. This document provides a comprehensive list and aims to stay updated over time.\n\n*Note: SGLang uses two prefixes for environment variables: `SGL_` and `SGLANG_`. This is likely due to historical reasons. While both are currently supported for different settings, future versions might consolidate them.*\n\n## General Configuration\n\n| Environment Variable                      | Description                                                                                                                      | Default Value                |\n|-------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------|------------------------------|\n| `SGLANG_USE_MODELSCOPE`                   | Enable using models from ModelScope                                                                                              | `false`                      |\n| `SGLANG_HOST_IP`                          | Host IP address for the server                                                                                                   | `0.0.0.0`                    |\n| `SGLANG_PORT`                             | Port for the server                                                                                                              | auto-detected                |\n| `SGLANG_LOGGING_CONFIG_PATH`              | Custom logging configuration path                                                                                                | Not set                      |\n| `SGLANG_DISABLE_REQUEST_LOGGING`          | Disable request logging                                                                                                          | `false`                      |\n| `SGLANG_LOG_REQUEST_HEADERS`              | 
Comma-separated list of additional HTTP headers to log when `--log-requests` is enabled. Appends to the default `x-smg-routing-key`. | Not set                      |\n| `SGLANG_HEALTH_CHECK_TIMEOUT`             | Timeout for health check in seconds                                                                                              | `20`                         |\n| `SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL` | The interval of passes to collect the metric of selected count of physical experts on each layer and GPU rank. 0 means disabled. | `0`                          |\n| `SGLANG_FORWARD_UNKNOWN_TOOLS`            | Forward unknown tool calls to clients instead of dropping them                                                                   | `false` (drop unknown tools) |\n| `SGLANG_REQ_WAITING_TIMEOUT`              | Timeout (in seconds) for requests waiting in the queue before being scheduled                                                    | `-1`                         |\n| `SGLANG_REQ_RUNNING_TIMEOUT`              | Timeout (in seconds) for requests running in the decode batch                                                                    | `-1`                         |\n\n## Performance Tuning\n\n| Environment Variable | Description | Default Value |\n| --- | --- | --- |\n| `SGLANG_ENABLE_TORCH_INFERENCE_MODE` | Control whether to use torch.inference_mode | `false` |\n| `SGLANG_ENABLE_TORCH_COMPILE` | Enable torch.compile | `false` |\n| `SGLANG_SET_CPU_AFFINITY` | Enable CPU affinity setting (often set to `1` in Docker builds) | `false` |\n| `SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN` | Allows the scheduler to overwrite longer context length requests (often set to `1` in Docker builds) | `false` |\n| `SGLANG_IS_FLASHINFER_AVAILABLE` | Control FlashInfer availability check | `true` |\n| `SGLANG_SKIP_P2P_CHECK` | Skip P2P (peer-to-peer) access check | `false` |\n| `SGLANG_CHUNKED_PREFIX_CACHE_THRESHOLD` | Sets the threshold for enabling chunked 
prefix caching | `8192` |\n| `SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION` | Enable RoPE fusion in fused Multi-head Latent Attention (MLA) | `1` |\n| `SGLANG_DISABLE_CONSECUTIVE_PREFILL_OVERLAP` | Disable the overlap schedule for consecutive prefill batches | `false` |\n| `SGLANG_SCHEDULER_MAX_RECV_PER_POLL` | Maximum number of requests received per poll; a negative value means no limit | `-1` |\n| `SGLANG_DISABLE_FA4_WARMUP` | Disable Flash Attention 4 warmup passes (set to `1`, `true`, `yes`, or `on` to disable) | `false` |\n| `SGLANG_DATA_PARALLEL_BUDGET_INTERVAL` | Interval for DPBudget updates | `1` |\n| `SGLANG_SCHEDULER_RECV_SKIPPER_WEIGHT_DEFAULT` | Default weight value for the scheduler recv skipper counter (used when the forward mode doesn't match any specific mode). Only active when `--scheduler-recv-interval > 1`. The counter accumulates weights and triggers request polling when it reaches the interval threshold. | `1000` |\n| `SGLANG_SCHEDULER_RECV_SKIPPER_WEIGHT_DECODE` | Weight increment for the decode forward mode in the scheduler recv skipper. Works with `--scheduler-recv-interval` to control polling frequency during the decode phase. | `1` |\n| `SGLANG_SCHEDULER_RECV_SKIPPER_WEIGHT_TARGET_VERIFY` | Weight increment for the target verify forward mode in the scheduler recv skipper. Works with `--scheduler-recv-interval` to control polling frequency during the verification phase. | `1` |\n| `SGLANG_SCHEDULER_RECV_SKIPPER_WEIGHT_NONE` | Weight increment when the forward mode is None in the scheduler recv skipper. Works with `--scheduler-recv-interval` to control polling frequency when no specific forward mode is active. | `1` |\n| `SGLANG_MM_BUFFER_SIZE_MB` | Size of the preallocated GPU buffer (in MB) for multi-modal feature hashing optimization. When set to a positive value, features are temporarily moved to the GPU for faster hash computation, then moved back to the CPU to save GPU memory. Larger features benefit more from GPU hashing. Set to `0` to disable. 
| `0` |\n| `SGLANG_MM_PRECOMPUTE_HASH` | Enable precomputing of hash values for MultimodalDataItem | `false` |\n| `SGLANG_NCCL_ALL_GATHER_IN_OVERLAP_SCHEDULER_SYNC_BATCH` | Use NCCL for gathering when preparing the MLP sync batch under the overlap scheduler (without this flag, Gloo is used for gathering) | `false` |\n| `SGLANG_SYMM_MEM_PREALLOC_GB_SIZE` | Size of the preallocated GPU buffer (in GB) for the NCCL symmetric memory pool, used to limit memory fragmentation. Only has an effect when the server arg `--enable-symm-mem` is set. | `-1` |\n| `SGLANG_CUSTOM_ALLREDUCE_ALGO` | The custom all-reduce algorithm. Set to `oneshot` or `1stage` to force one-shot; set to `twoshot` or `2stage` to force two-shot. | Not set |\n\n\n## DeepGEMM Configuration (Advanced Optimization)\n\n| Environment Variable | Description | Default Value |\n| --- | --- | --- |\n| `SGLANG_ENABLE_JIT_DEEPGEMM` | Enable Just-In-Time compilation of DeepGEMM kernels (enabled by default on NVIDIA Hopper (SM90) and Blackwell (SM100) GPUs when the DeepGEMM package is installed; set to `\"0\"` to disable) | `\"true\"` |\n| `SGLANG_JIT_DEEPGEMM_PRECOMPILE` | Enable precompilation of DeepGEMM kernels | `\"true\"` |\n| `SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` |\n| `SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE` | Indicator flag used during the DeepGEMM precompile script | `\"false\"` |\n| `SGLANG_DG_CACHE_DIR` | Directory for caching compiled DeepGEMM kernels | `~/.cache/deep_gemm` |\n| `SGLANG_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `\"false\"` |\n| `SGLANG_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `\"false\"` |\n| `SGLANG_JIT_DEEPGEMM_FAST_WARMUP` | Precompile fewer kernels during warmup, which reduces the warmup time from 30 min to less than 3 min but might degrade runtime performance. 
| `\"false\"` |\n\n## DeepEP Configuration\n\n| Environment Variable | Description | Default Value |\n| --- | --- | --- |\n| `SGLANG_DEEPEP_BF16_DISPATCH` | Use BF16 for dispatch | `\"false\"` |\n| `SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK` | The maximum number of dispatched tokens on each GPU | `\"128\"` |\n| `SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK` | The maximum number of dispatched tokens on each GPU for `--moe-a2a-backend=flashinfer` | `\"1024\"` |\n| `SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS` | Number of SMs used for DeepEP combine when single batch overlap is enabled | `\"32\"` |\n| `SGLANG_BLACKWELL_OVERLAP_SHARED_EXPERTS_OUTSIDE_SBO` | Run shared experts on an alternate stream when single batch overlap is enabled on GB200. If this flag is not set, the shared experts and the down GEMM are overlapped together with the DeepEP combine. | `\"false\"` |\n\n## MORI Configuration\n\n| Environment Variable | Description | Default Value |\n| --- | --- | --- |\n| `SGLANG_MORI_FP8_DISP` | Use FP8 for dispatch | `\"false\"` |\n| `SGLANG_MORI_FP4_DISP` | Use MXFP4 for dispatch | `\"false\"` |\n| `SGLANG_MORI_FP8_COMB` | Use FP8 for combine | `\"false\"` |\n| `SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK` | Maximum number of dispatch tokens per rank for MORI-EP buffer allocation | `4096` |\n| `SGLANG_MORI_DISPATCH_INTER_KERNEL_SWITCH_THRESHOLD` | Threshold for switching between `InterNodeV1` and `InterNodeV1LL` kernel types. `InterNodeV1LL` is used if `SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK` is less than or equal to this threshold; otherwise, `InterNodeV1` is used. 
| `256` |\n| `SGLANG_MORI_QP_PER_TRANSFER` | Number of RDMA Queue Pairs (QPs) used per transfer operation | `1` |\n| `SGLANG_MORI_POST_BATCH_SIZE` | Number of RDMA work requests posted in a single batch to each QP | `-1` |\n| `SGLANG_MORI_NUM_WORKERS` | Number of worker threads in the RDMA executor thread pool | `1` |\n\n## NSA Backend Configuration (For DeepSeek V3.2)\n\n<!-- # Environment variable to control mtp precomputing of metadata for multi-step speculative decoding -->\n\n| Environment Variable | Description | Default Value |\n| --- | --- | --- |\n| `SGLANG_NSA_FUSE_TOPK` | Fuse the operations of picking the top-k logits and picking the top-k indices from the page table | `true` |\n| `SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA` | Precompute metadata that can be shared among different draft steps when MTP is enabled | `true` |\n| `SGLANG_USE_FUSED_METADATA_COPY` | Control whether to use the fused metadata copy kernel for CUDA graph replay | `true` |\n| `SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` | When the maximum KV length in the current prefill batch exceeds this value, the sparse MLA kernel is used; otherwise, it falls back to the dense MHA implementation. Defaults to the model's index top-k (2048 for DeepSeek V3.2) | `2048` |\n\n\n## Memory Management\n\n| Environment Variable | Description | Default Value |\n| --- | --- | --- |\n| `SGLANG_DEBUG_MEMORY_POOL` | Enable memory pool debugging | `false` |\n| `SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION` | Clip max new tokens estimation for memory planning | `4096` |\n| `SGLANG_DETOKENIZER_MAX_STATES` | Maximum states for detokenizer | Default value based on system |\n| `SGLANG_ENABLE_TP_MEMORY_INBALANCE_CHECK` | Enable checks for memory imbalance across Tensor Parallel ranks | `true` |\n| `SGLANG_MOONCAKE_CUSTOM_MEM_POOL` | Configure the custom memory pool type for Mooncake. Supports `NVLINK`, `BAREX`, `INTRA_NODE_NVLINK`. If set to `true`, it defaults to `NVLINK`. 
| `None` |\n\n## Model-Specific Options\n\n| Environment Variable | Description | Default Value |\n| --- | --- | --- |\n| `SGLANG_USE_AITER` | Use the AITER-optimized implementation | `false` |\n| `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `false` |\n| `SGLANG_CUTLASS_MOE` (deprecated) | Use the Cutlass FP8 MoE kernel on Blackwell GPUs (deprecated; use `--moe-runner-backend=cutlass`) | `false` |\n\n## Quantization\n\n| Environment Variable | Description | Default Value |\n| --- | --- | --- |\n| `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |\n| `SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2` | Use the per-token-group quantization kernel with fused SiLU-and-mul and masked m | `false` |\n| `SGLANG_FORCE_FP8_MARLIN` | Force the use of FP8 Marlin kernels even if other FP8 kernels are available | `false` |\n| `SGLANG_FLASHINFER_FP4_GEMM_BACKEND` (deprecated) | Select the backend for `mm_fp4` on Blackwell GPUs. **DEPRECATED**: Please use `--fp4-gemm-backend` instead. | Not set |\n| `SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN` | Quantize `q_b_proj` from BF16 to FP8 when launching a DeepSeek NVFP4 checkpoint | `false` |\n| `SGLANG_MOE_NVFP4_DISPATCH` | Use NVFP4 for MoE dispatch (on the `flashinfer_cutlass` or `flashinfer_cutedsl` MoE runner backend) | `\"false\"` |\n| `SGLANG_NVFP4_CKPT_FP8_NEXTN_MOE` | Quantize the MoE of the NextN layer from BF16 to FP8 when launching a DeepSeek NVFP4 checkpoint | `false` |\n| `SGLANG_ENABLE_FLASHINFER_FP8_GEMM` (deprecated) | Use FlashInfer kernels when running blockwise FP8 GEMM on Blackwell GPUs. **DEPRECATED**: Please use `--fp8-gemm-backend=flashinfer_trtllm` (SM100/SM103) or `--fp8-gemm-backend=flashinfer_cutlass` (SM120/SM121 and newer) instead. | `false` |\n| `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` (deprecated) | Use Cutlass kernels when running blockwise FP8 GEMM on Hopper or Blackwell GPUs. **DEPRECATED**: Please use `--fp8-gemm-backend=cutlass` instead. 
| `false` |\n| `SGLANG_QUANT_ALLOW_DOWNCASTING` | Allow weight dtype downcasting during loading (e.g., fp32 → fp16). By default, SGLang rejects this kind of downcasting when using quantization. | `false` |\n| `SGLANG_FP8_IGNORED_LAYERS` | A comma-separated list of layer names to ignore during FP8 quantization. For example: `model.layers.0,model.layers.1.,qkv_proj`. | `\"\"` |\n\n\n## Distributed Computing\n\n| Environment Variable | Description | Default Value |\n| --- | --- | --- |\n| `SGLANG_BLOCK_NONZERO_RANK_CHILDREN` | Control blocking of non-zero rank children processes | `1` |\n| `SGLANG_IS_FIRST_RANK_ON_NODE` | Indicates if the current process is the first rank on its node | `\"true\"` |\n| `SGLANG_PP_LAYER_PARTITION` | Pipeline parallel layer partition specification | Not set |\n| `SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS` | Set one visible device per process for distributed computing | `false` |\n\n## Testing & Debugging (Internal/CI)\n\n*These variables are primarily used for internal testing, continuous integration, or debugging.*\n\n| Environment Variable | Description | Default Value |\n| --- | --- | --- |\n| `SGLANG_IS_IN_CI` | Indicates if running in CI environment | `false` |\n| `SGLANG_IS_IN_CI_AMD` | Indicates running in AMD CI environment | `false` |\n| `SGLANG_TEST_RETRACT` | Enable retract decode testing | `false` |\n| `SGLANG_TEST_RETRACT_NO_PREFILL_BS` | When SGLANG_TEST_RETRACT is enabled, no prefill is performed if the batch size exceeds SGLANG_TEST_RETRACT_NO_PREFILL_BS. 
| `2 ** 31`     |\n| `SGLANG_RECORD_STEP_TIME` | Record step time for profiling | `false` |\n| `SGLANG_TEST_REQUEST_TIME_STATS` | Test request time statistics | `false` |\n\n## Profiling & Benchmarking\n\n| Environment Variable | Description | Default Value |\n| --- | --- | --- |\n| `SGLANG_TORCH_PROFILER_DIR` | Directory for PyTorch profiler output | `/tmp` |\n| `SGLANG_PROFILE_WITH_STACK` | Set `with_stack` option (bool) for PyTorch profiler (capture stack trace) | `true` |\n| `SGLANG_PROFILE_RECORD_SHAPES` | Set `record_shapes` option (bool) for PyTorch profiler (record shapes) | `true` |\n| `SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS` | Config BatchSpanProcessor.schedule_delay_millis if tracing is enabled | `500` |\n| `SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE` | Config BatchSpanProcessor.max_export_batch_size if tracing is enabled | `64` |\n\n## Storage & Caching\n\n| Environment Variable | Description | Default Value |\n| --- | --- | --- |\n| `SGLANG_WAIT_WEIGHTS_READY_TIMEOUT` | Timeout period for waiting on weights | `120` |\n| `SGLANG_DISABLE_OUTLINES_DISK_CACHE` | Disable Outlines disk cache | `false` |\n| `SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE` | Use SGLang's custom Triton kernel cache implementation for lower overheads (automatically enabled on CUDA) | `false` |\n| `SGLANG_HICACHE_DECODE_OFFLOAD_STRIDE` | Decode-side incremental KV cache offload stride. Rounded down to a multiple of `--page-size` (min is `--page-size`). If unset/invalid/<=0, it falls back to `--page-size`. | Not set (uses `--page-size`) |\n\n\n## Function Calling / Tool Use\n\n| Environment Variable | Description | Default Value |\n| --- | --- | --- |\n| `SGLANG_TOOL_STRICT_LEVEL` | Controls the strictness level of tool call parsing and validation. 
<br>**Level 0**: Off - No strict validation <br>**Level 1**: Function strict - Enables structural tag constraints for all tools (even if none have `strict=True` set) <br>**Level 2**: Parameter strict - Enforces strict parameter validation for all tools, treating them as if they all have `strict=True` set | `0` |\n"
  },
  {
    "path": "docs/references/faq.md",
    "content": "# Troubleshooting and Frequently Asked Questions\n\n## Troubleshooting\n\nThis page lists common errors and tips for resolving them.\n\n### CUDA Out of Memory\nIf you encounter out-of-memory (OOM) errors, you can adjust the following parameters:\n\n- If OOM occurs during prefill, try reducing `--chunked-prefill-size` to `4096` or `2048`. This saves memory but slows down the prefill speed for long prompts.\n- If OOM occurs during decoding, try lowering `--max-running-requests`.\n- You can also decrease `--mem-fraction-static` to a smaller value, such as 0.8 or 0.7. This decreases the memory usage of the KV cache memory pool and helps prevent OOM errors during both prefill and decoding. However, it limits maximum concurrency and reduces peak throughput.\n- Another common cause of OOM is requesting input logprobs for a long prompt, as this requires significant memory. To address this, set `logprob_start_len` in your sampling parameters to include only the necessary parts. If you do need input logprobs for a long prompt, try reducing `--mem-fraction-static`.\n\n### CUDA Error: Illegal Memory Access Encountered\nThis error may result from kernel errors or out-of-memory issues:\n- If it is a kernel error, resolving it may be challenging. Please file an issue on GitHub.\n- If it is an out-of-memory issue, it may sometimes be reported as this error instead of \"Out of Memory.\" Refer to the section above for guidance on avoiding OOM issues.\n\n### The server hangs\n- If the server hangs during initialization or while running, the cause may be memory issues (out of memory), network issues (NCCL errors), or other bugs in SGLang.\n    - If it is out of memory, you might see that `avail mem` is very low during or right after initialization.
In this case,\n      you can try to decrease `--mem-fraction-static`, decrease `--cuda-graph-max-bs`, or decrease `--chunked-prefill-size`.\n- For other bugs, please file an issue on GitHub.\n\n\n## Frequently Asked Questions\n\n### The results are not deterministic, even with a temperature of 0\n\nYou may notice that when you send the same request twice, the results from the engine are slightly different, even when the temperature is set to 0.\n\nFrom our initial investigation, this indeterminism arises from two factors: dynamic batching and prefix caching. Roughly speaking, dynamic batching accounts for about 95% of the indeterminism, while prefix caching accounts for the remaining portion. The server runs dynamic batching under the hood. Different batch sizes can cause PyTorch/cuBLAS to dispatch to different CUDA kernels, which can lead to slight numerical differences. These differences accumulate across many layers, resulting in nondeterministic output when the batch size changes. Similarly, when prefix caching is enabled, it can also dispatch to different kernels. Even when the computations are mathematically equivalent, small numerical differences from different kernel implementations lead to the final nondeterministic outputs.\n\nTo achieve more deterministic outputs in the current code, you can add `--disable-radix-cache` and send only one request at a time. The results will be mostly deterministic under this setting.\n\n**Update**:\nWe have also recently introduced a deterministic mode, which you can enable with `--enable-deterministic-inference`.\nPlease find more details in this blog post: https://lmsys.org/blog/2025-09-22-sglang-deterministic/\n"
  },
  {
    "path": "docs/references/frontend/choices_methods.md",
    "content": "# Choices Methods in SGLang\nThis doc describes the choices methods supported by SGLang.\n\nThe optional `choices_method` arg determines how options supplied to SGLang's `choices` primitive are selected. Only the `RuntimeEndpoint` backend supports the `choices_method` arg. Other backends, such as `OpenAI`, have bespoke selection implementations due to API limitations.\n\n## Methods\n\n### Token Length Normalized\n\nToken length normalized is the default SGLang choices method. It selects the option with the highest average logprob across all of its tokens.\n\nUsage example (alternatively, simply omit the `choices_method` arg):\n```python\n@sgl.function\ndef example(s):\n    s += sgl.user(\"What is the capital of France?\")\n    s += sgl.assistant(\n        sgl.gen(\n            \"answer\",\n            choices=[\"London\", \"Paris\", \"Berlin\"],\n            choices_method=sgl.token_length_normalized,\n        )\n    )\n```\n\n\nThis can perform poorly if an option contains many tokens, where its later tokens are predicted with high confidence based on its earlier tokens. For instance, even strong models will fail the above example if the specified options are `[\"Paris\", \"Antidisestablishmentarianism\"]`.\n\n### Greedy Token Selection\n\nGreedy token selection simply selects the option with the highest logprob for its initial token. For overlapping options where one option is a subset of a longer option, the logprobs of the shorter option are extended using its average logprob for comparison against the longer option.\n\nUsage example:\n```python\n@sgl.function\ndef example(s):\n    s += sgl.user(\"What is the capital of France?\")\n    s += sgl.assistant(\n        sgl.gen(\n            \"answer\",\n            choices=[\"London\", \"Paris\", \"Berlin\"],\n            choices_method=sgl.greedy_token_selection,\n        )\n    )\n```\n\nThis can perform poorly if an option misleads the model down a bad path based on an attractive initial token. 
For instance, greedy selection will result in an incorrect response for this example:\n```python\n@sgl.function\ndef us_president_example(s):\n    s += sgl.user(\"Name a US president.\")\n    s += sgl.assistant(\n        sgl.gen(\n            \"answer\",\n            choices=[\"Donald Duck\", \"Millard Fillmore\"],\n            choices_method=sgl.greedy_token_selection,\n        )\n    )\n```\n\n### Unconditional Likelihood Normalized\n\nUnconditional likelihood normalized selects the option with the highest average token logprob once normalized by the unconditional token logprobs, as described in [this EleutherAI blogpost](https://blog.eleuther.ai/multiple-choice-normalization/). This method incurs an additional LLM call to obtain the unconditional likelihoods.\n\nUsage example:\n```python\n@sgl.function\ndef example(s):\n    s += sgl.user(\"What is the capital of France?\")\n    s += sgl.assistant(\n        sgl.gen(\n            \"answer\",\n            choices=[\"London\", \"Paris\", \"Berlin\"],\n            choices_method=sgl.unconditional_likelihood_normalized,\n        )\n    )\n```\n"
  },
  {
    "path": "docs/references/frontend/frontend_index.rst",
    "content": "Frontend Language\n=================\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Frontend Language\n\n   frontend_tutorial.ipynb\n   choices_methods.md\n"
  },
  {
    "path": "docs/references/frontend/frontend_tutorial.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# SGLang Frontend Language\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"SGLang frontend language can be used to define simple and easy prompts in a convenient, structured way.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Launch A Server\\n\",\n    \"\\n\",\n    \"Launch the server in your terminal and wait for it to initialize.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from sglang import assistant_begin, assistant_end\\n\",\n    \"from sglang import assistant, function, gen, system, user\\n\",\n    \"from sglang import image\\n\",\n    \"from sglang import RuntimeEndpoint\\n\",\n    \"from sglang.lang.api import set_default_backend\\n\",\n    \"from sglang.srt.utils import load_image\\n\",\n    \"from sglang.test.doc_patch import launch_server_cmd\\n\",\n    \"from sglang.utils import print_highlight, terminate_process, wait_for_server\\n\",\n    \"\\n\",\n    \"server_process, port = launch_server_cmd(\\n\",\n    \"    \\\"python -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --log-level warning\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=server_process)\\n\",\n    \"print(f\\\"Server started on http://localhost:{port}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Set the default backend. 
Note: Besides the local server, you may also use `OpenAI` or other API endpoints.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"set_default_backend(RuntimeEndpoint(f\\\"http://localhost:{port}\\\"))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Basic Usage\\n\",\n    \"\\n\",\n    \"The simplest way to use the SGLang frontend language is a question-answer dialog between a user and an assistant.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@function\\n\",\n    \"def basic_qa(s, question):\\n\",\n    \"    s += system(\\\"You are a helpful assistant that can answer questions.\\\")\\n\",\n    \"    s += user(question)\\n\",\n    \"    s += assistant(gen(\\\"answer\\\", max_tokens=512))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"state = basic_qa(\\\"List 3 countries and their capitals.\\\")\\n\",\n    \"print_highlight(state[\\\"answer\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Multi-turn Dialog\\n\",\n    \"\\n\",\n    \"SGLang frontend language can also be used to define multi-turn dialogs.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@function\\n\",\n    \"def multi_turn_qa(s):\\n\",\n    \"    s += system(\\\"You are a helpful assistant that can answer questions.\\\")\\n\",\n    \"    s += user(\\\"Please give me a list of 3 countries and their capitals.\\\")\\n\",\n    \"    s += assistant(gen(\\\"first_answer\\\", max_tokens=512))\\n\",\n    \"    s += user(\\\"Please give me another list of 3 countries and their
capitals.\\\")\\n\",\n    \"    s += assistant(gen(\\\"second_answer\\\", max_tokens=512))\\n\",\n    \"    return s\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"state = multi_turn_qa()\\n\",\n    \"print_highlight(state[\\\"first_answer\\\"])\\n\",\n    \"print_highlight(state[\\\"second_answer\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Control flow\\n\",\n    \"\\n\",\n    \"You may use any Python code within the function to define more complex control flows.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@function\\n\",\n    \"def tool_use(s, question):\\n\",\n    \"    s += assistant(\\n\",\n    \"        \\\"To answer this question: \\\"\\n\",\n    \"        + question\\n\",\n    \"        + \\\". I need to use a \\\"\\n\",\n    \"        + gen(\\\"tool\\\", choices=[\\\"calculator\\\", \\\"search engine\\\"])\\n\",\n    \"        + \\\". \\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    if s[\\\"tool\\\"] == \\\"calculator\\\":\\n\",\n    \"        s += assistant(\\\"The math expression is: \\\" + gen(\\\"expression\\\"))\\n\",\n    \"    elif s[\\\"tool\\\"] == \\\"search engine\\\":\\n\",\n    \"        s += assistant(\\\"The key word to search is: \\\" + gen(\\\"word\\\"))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"state = tool_use(\\\"What is 2 * 2?\\\")\\n\",\n    \"print_highlight(state[\\\"tool\\\"])\\n\",\n    \"print_highlight(state[\\\"expression\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Parallelism\\n\",\n    \"\\n\",\n    \"Use `fork` to launch parallel prompts. 
Because `sgl.gen` is non-blocking, the for loop below issues two generation calls in parallel.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@function\\n\",\n    \"def tip_suggestion(s):\\n\",\n    \"    s += assistant(\\n\",\n    \"        \\\"Here are two tips for staying healthy: \\\"\\n\",\n    \"        \\\"1. Balanced Diet. 2. Regular Exercise.\\\\n\\\\n\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    forks = s.fork(2)\\n\",\n    \"    for i, f in enumerate(forks):\\n\",\n    \"        f += assistant(\\n\",\n    \"            f\\\"Now, expand tip {i+1} into a paragraph:\\\\n\\\"\\n\",\n    \"            + gen(\\\"detailed_tip\\\", max_tokens=256, stop=\\\"\\\\n\\\\n\\\")\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"    s += assistant(\\\"Tip 1:\\\" + forks[0][\\\"detailed_tip\\\"] + \\\"\\\\n\\\")\\n\",\n    \"    s += assistant(\\\"Tip 2:\\\" + forks[1][\\\"detailed_tip\\\"] + \\\"\\\\n\\\")\\n\",\n    \"    s += assistant(\\n\",\n    \"        \\\"To summarize the above two tips, I can say:\\\\n\\\" + gen(\\\"summary\\\", max_tokens=512)\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"state = tip_suggestion()\\n\",\n    \"print_highlight(state[\\\"summary\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Constrained Decoding\\n\",\n    \"\\n\",\n    \"Use `regex` to specify a regular expression as a decoding constraint. 
This is only supported for local models.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@function\\n\",\n    \"def regular_expression_gen(s):\\n\",\n    \"    s += user(\\\"What is the IP address of the Google DNS servers?\\\")\\n\",\n    \"    s += assistant(\\n\",\n    \"        gen(\\n\",\n    \"            \\\"answer\\\",\\n\",\n    \"            temperature=0,\\n\",\n    \"            regex=r\\\"((25[0-5]|2[0-4]\\\\d|[01]?\\\\d\\\\d?).){3}(25[0-5]|2[0-4]\\\\d|[01]?\\\\d\\\\d?)\\\",\\n\",\n    \"        )\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"state = regular_expression_gen()\\n\",\n    \"print_highlight(state[\\\"answer\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Use `regex` to define a `JSON` decoding schema.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"character_regex = (\\n\",\n    \"    r\\\"\\\"\\\"\\\\{\\\\n\\\"\\\"\\\"\\n\",\n    \"    + r\\\"\\\"\\\"    \\\"name\\\": \\\"[\\\\w\\\\d\\\\s]{1,16}\\\",\\\\n\\\"\\\"\\\"\\n\",\n    \"    + r\\\"\\\"\\\"    \\\"house\\\": \\\"(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)\\\",\\\\n\\\"\\\"\\\"\\n\",\n    \"    + r\\\"\\\"\\\"    \\\"blood status\\\": \\\"(Pure-blood|Half-blood|Muggle-born)\\\",\\\\n\\\"\\\"\\\"\\n\",\n    \"    + r\\\"\\\"\\\"    \\\"occupation\\\": \\\"(student|teacher|auror|ministry of magic|death eater|order of the phoenix)\\\",\\\\n\\\"\\\"\\\"\\n\",\n    \"    + r\\\"\\\"\\\"    \\\"wand\\\": \\\\{\\\\n\\\"\\\"\\\"\\n\",\n    \"    + r\\\"\\\"\\\"        \\\"wood\\\": \\\"[\\\\w\\\\d\\\\s]{1,16}\\\",\\\\n\\\"\\\"\\\"\\n\",\n    \"    + r\\\"\\\"\\\"        \\\"core\\\": \\\"[\\\\w\\\\d\\\\s]{1,16}\\\",\\\\n\\\"\\\"\\\"\\n\",\n    \"    + r\\\"\\\"\\\"        \\\"length\\\": 
[0-9]{1,2}\\\\.[0-9]{0,2}\\\\n\\\"\\\"\\\"\\n\",\n    \"    + r\\\"\\\"\\\"    \\\\},\\\\n\\\"\\\"\\\"\\n\",\n    \"    + r\\\"\\\"\\\"    \\\"alive\\\": \\\"(Alive|Deceased)\\\",\\\\n\\\"\\\"\\\"\\n\",\n    \"    + r\\\"\\\"\\\"    \\\"patronus\\\": \\\"[\\\\w\\\\d\\\\s]{1,16}\\\",\\\\n\\\"\\\"\\\"\\n\",\n    \"    + r\\\"\\\"\\\"    \\\"bogart\\\": \\\"[\\\\w\\\\d\\\\s]{1,16}\\\"\\\\n\\\"\\\"\\\"\\n\",\n    \"    + r\\\"\\\"\\\"\\\\}\\\"\\\"\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@function\\n\",\n    \"def character_gen(s, name):\\n\",\n    \"    s += user(\\n\",\n    \"        f\\\"{name} is a character in Harry Potter. Please fill in the following information about this character.\\\"\\n\",\n    \"    )\\n\",\n    \"    s += assistant(gen(\\\"json_output\\\", max_tokens=256, regex=character_regex))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"state = character_gen(\\\"Harry Potter\\\")\\n\",\n    \"print_highlight(state[\\\"json_output\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Batching \\n\",\n    \"\\n\",\n    \"Use `run_batch` to run a batch of prompts.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@function\\n\",\n    \"def text_qa(s, question):\\n\",\n    \"    s += user(question)\\n\",\n    \"    s += assistant(gen(\\\"answer\\\", stop=\\\"\\\\n\\\"))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"states = text_qa.run_batch(\\n\",\n    \"    [\\n\",\n    \"        {\\\"question\\\": \\\"What is the capital of the United Kingdom?\\\"},\\n\",\n    \"        {\\\"question\\\": \\\"What is the capital of France?\\\"},\\n\",\n    \"        {\\\"question\\\": \\\"What is the capital of Japan?\\\"},\\n\",\n    \"    ],\\n\",\n    \"    progress_bar=True,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"for i, state in enumerate(states):\\n\",\n    \"    print_highlight(f\\\"Answer {i+1}: 
{states[i]['answer']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Streaming \\n\",\n    \"\\n\",\n    \"Use `stream` to stream the output to the user.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@function\\n\",\n    \"def text_qa(s, question):\\n\",\n    \"    s += user(question)\\n\",\n    \"    s += assistant(gen(\\\"answer\\\", stop=\\\"\\\\n\\\"))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"state = text_qa.run(\\n\",\n    \"    question=\\\"What is the capital of France?\\\", temperature=0.1, stream=True\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"for out in state.text_iter():\\n\",\n    \"    print(out, end=\\\"\\\", flush=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Complex Prompts\\n\",\n    \"\\n\",\n    \"You may use `{system|user|assistant}_{begin|end}` to define complex prompts.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@function\\n\",\n    \"def chat_example(s):\\n\",\n    \"    s += system(\\\"You are a helpful assistant.\\\")\\n\",\n    \"    # Same as: s += s.system(\\\"You are a helpful assistant.\\\")\\n\",\n    \"\\n\",\n    \"    with s.user():\\n\",\n    \"        s += \\\"Question: What is the capital of France?\\\"\\n\",\n    \"\\n\",\n    \"    s += assistant_begin()\\n\",\n    \"    s += \\\"Answer: \\\" + gen(\\\"answer\\\", max_tokens=100, stop=\\\"\\\\n\\\")\\n\",\n    \"    s += assistant_end()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"state = chat_example()\\n\",\n    \"print_highlight(state[\\\"answer\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"terminate_process(server_process)\"\n   ]\n  },\n  {\n   
\"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Multi-modal Generation\\n\",\n    \"\\n\",\n    \"You may use SGLang frontend language to define multi-modal prompts.\\n\",\n    \"See [here](https://docs.sglang.io/supported_models/text_generation/multimodal_language_models.html) for supported models.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"server_process, port = launch_server_cmd(\\n\",\n    \"    \\\"python -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --host 0.0.0.0 --log-level warning\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"wait_for_server(f\\\"http://localhost:{port}\\\", process=server_process)\\n\",\n    \"print(f\\\"Server started on http://localhost:{port}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"set_default_backend(RuntimeEndpoint(f\\\"http://localhost:{port}\\\"))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Ask a question about an image.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@function\\n\",\n    \"def image_qa(s, image_file, question):\\n\",\n    \"    s += user(image(image_file) + question)\\n\",\n    \"    s += assistant(gen(\\\"answer\\\", max_tokens=256))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"image_url = \\\"https://raw.githubusercontent.com/sgl-project/sglang/main/examples/assets/example_image.png\\\"\\n\",\n    \"image_bytes, _ = load_image(image_url)\\n\",\n    \"state = image_qa(image_bytes, \\\"What is in the image?\\\")\\n\",\n    \"print_highlight(state[\\\"answer\\\"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"terminate_process(server_process)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"name\": \"python\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "docs/references/learn_more.md",
    "content": "# Learn More and Join the Community\n\n- The development roadmap: [https://roadmap.sglang.io](https://roadmap.sglang.io)\n- Join weekly public development meeting: [https://meet.sglang.io](https://meet.sglang.io)\n- Join Slack: [https://slack.sglang.io/](https://slack.sglang.io/)\n- Follow on X (formerly Twitter): [https://x.com/lmsysorg](https://x.com/lmsysorg)\n- Follow on LinkedIn: [https://www.linkedin.com/company/sgl-project/](https://www.linkedin.com/company/sgl-project/)\n- The latest SGLang features and updates are shared through the [LMSYS blog](https://lmsys.org/blog/)\n- More blogs, slides, and videos about SGLang at [https://github.com/sgl-project/sgl-learning-materials](https://github.com/sgl-project/sgl-learning-materials)\n"
  },
  {
    "path": "docs/references/multi_node_deployment/deploy_on_k8s.md",
    "content": "# Deploy On Kubernetes\n\nThis document is for deploying a RoCE network-based SGLang two-node inference service on a Kubernetes (K8S) cluster.\n\n[LeaderWorkerSet (LWS)](https://github.com/kubernetes-sigs/lws) is a Kubernetes API that aims to address common deployment patterns of AI/ML inference workloads. A major use case is for multi-host/multi-node distributed inference.\n\nSGLang can also be deployed with LWS on Kubernetes for distributed model serving.\n\nPlease see this guide for more details on deploying SGLang on Kubernetes using LWS.\n\nHere we take the deployment of DeepSeek-R1 as an example.\n\n## Prerequisites\n\n1. At least two Kubernetes nodes, each with two H20 systems and eight GPUs, are required.\n\n2. Make sure your K8S cluster has LWS correctly installed. If it hasn't been set up yet, please follow the [installation instructions](https://github.com/kubernetes-sigs/lws/blob/main/site/content/en/docs/installation/_index.md). **Note:** For LWS versions ≤0.5.x, you must use the Downward API to obtain `LWS_WORKER_INDEX`, as native support for this feature was introduced in v0.6.0.\n\n## Basic example\n\nFor the basic example documentation, refer to [Deploy Distributed Inference Service with SGLang and LWS on GPUs](https://github.com/kubernetes-sigs/lws/tree/main/docs/examples/sglang).\n\nHowever, that document only covers the basic NCCL socket mode.\n\nIn this section, we’ll make some simple modifications to adapt the setup to the RDMA scenario.\n\n## RDMA RoCE case\n\n* Check your env:\n\n```bash\n[root@node1 ~]# ibstatus\nInfiniband device 'mlx5_bond_0' port 1 status:\n        default gid:     fe80:0000:0000:0000:0225:9dff:fe64:c79a\n        base lid:        0x0\n        sm lid:          0x0\n        state:           4: ACTIVE\n        phys state:      5: LinkUp\n        rate:            200 Gb/sec (2X NDR)\n        link_layer:      Ethernet\n\nInfiniband device 'mlx5_bond_1' port 1 status:\n        default gid:     
fe80:0000:0000:0000:0225:9dff:fe6e:c3ec\n        base lid:        0x0\n        sm lid:          0x0\n        state:           4: ACTIVE\n        phys state:      5: LinkUp\n        rate:            200 Gb/sec (2X NDR)\n        link_layer:      Ethernet\n\nInfiniband device 'mlx5_bond_2' port 1 status:\n        default gid:     fe80:0000:0000:0000:0225:9dff:fe73:0dd7\n        base lid:        0x0\n        sm lid:          0x0\n        state:           4: ACTIVE\n        phys state:      5: LinkUp\n        rate:            200 Gb/sec (2X NDR)\n        link_layer:      Ethernet\n\nInfiniband device 'mlx5_bond_3' port 1 status:\n        default gid:     fe80:0000:0000:0000:0225:9dff:fe36:f7ff\n        base lid:        0x0\n        sm lid:          0x0\n        state:           4: ACTIVE\n        phys state:      5: LinkUp\n        rate:            200 Gb/sec (2X NDR)\n        link_layer:      Ethernet\n```\n\n* Prepare the `lws.yaml` file for deploying on k8s.\n\n```yaml\napiVersion: leaderworkerset.x-k8s.io/v1\nkind: LeaderWorkerSet\nmetadata:\n  name: sglang\nspec:\n  replicas: 1\n  leaderWorkerTemplate:\n    size: 2\n    restartPolicy: RecreateGroupOnPodRestart\n    leaderTemplate:\n      metadata:\n        labels:\n          role: leader\n      spec:\n        dnsPolicy: ClusterFirstWithHostNet\n        hostNetwork: true\n        hostIPC: true\n        containers:\n          - name: sglang-leader\n            image: sglang:latest\n            securityContext:\n              privileged: true\n            env:\n              - name: NCCL_IB_GID_INDEX\n                value: \"3\"\n            command:\n              - python3\n              - -m\n              - sglang.launch_server\n              - --model-path\n              - /work/models\n              - --mem-fraction-static\n              -  \"0.93\"\n              - --torch-compile-max-bs\n              - \"8\"\n              - --max-running-requests\n              - \"20\"\n              - --tp\n              
- \"16\" # Size of Tensor Parallelism\n              - --dist-init-addr\n              - $(LWS_LEADER_ADDRESS):20000\n              - --nnodes\n              - $(LWS_GROUP_SIZE)\n              - --node-rank\n              - $(LWS_WORKER_INDEX)\n              - --trust-remote-code\n              - --host\n              - \"0.0.0.0\"\n              - --port\n              - \"40000\"\n            resources:\n              limits:\n                nvidia.com/gpu: \"8\"\n            ports:\n              - containerPort: 40000\n            readinessProbe:\n              tcpSocket:\n                port: 40000\n              initialDelaySeconds: 15\n              periodSeconds: 10\n            volumeMounts:\n              - mountPath: /dev/shm\n                name: dshm\n              - name: model\n                mountPath: /work/models\n              - name: ib\n                mountPath: /dev/infiniband\n        volumes:\n          - name: dshm\n            emptyDir:\n              medium: Memory\n          - name: model\n            hostPath:\n              path: '< your models dir >' # modify according to your models dir\n          - name: ib\n            hostPath:\n              path: /dev/infiniband\n    workerTemplate:\n      spec:\n        dnsPolicy: ClusterFirstWithHostNet\n        hostNetwork: true\n        hostIPC: true\n        containers:\n          - name: sglang-worker\n            image: sglang:latest\n            securityContext:\n              privileged: true\n            env:\n            - name: NCCL_IB_GID_INDEX\n              value: \"3\"\n            command:\n              - python3\n              - -m\n              - sglang.launch_server\n              - --model-path\n              - /work/models\n              - --mem-fraction-static\n              - \"0.93\"\n              - --torch-compile-max-bs\n              - \"8\"\n              - --max-running-requests\n              - \"20\"\n              - --tp\n              - \"16\" # Size of
Tensor Parallelism\n              - --dist-init-addr\n              - $(LWS_LEADER_ADDRESS):20000\n              - --nnodes\n              - $(LWS_GROUP_SIZE)\n              - --node-rank\n              - $(LWS_WORKER_INDEX)\n              - --trust-remote-code\n            resources:\n              limits:\n                nvidia.com/gpu: \"8\"\n            volumeMounts:\n              - mountPath: /dev/shm\n                name: dshm\n              - name: model\n                mountPath: /work/models\n              - name: ib\n                mountPath: /dev/infiniband\n        volumes:\n          - name: dshm\n            emptyDir:\n              medium: Memory\n          - name: ib\n            hostPath:\n              path: /dev/infiniband\n          - name: model\n            hostPath:\n              path: /data1/models/deepseek_v3_moe\n---\napiVersion: v1\nkind: Service\nmetadata:\n  name: sglang-leader\nspec:\n  selector:\n    leaderworkerset.sigs.k8s.io/name: sglang\n    role: leader\n  ports:\n    - protocol: TCP\n      port: 40000\n      targetPort: 40000\n\n```\n\n* Then run `kubectl apply -f lws.yaml` and check the pods with `kubectl get pods`. You should see output like this:\n\n```text\nNAME           READY   STATUS    RESTARTS       AGE\nsglang-0       0/1     Running   0              9s\nsglang-0-1     1/1     Running   0              9s\n```\n\nWait for the sglang leader (`sglang-0`) status to change to 1/1, which indicates it is `Ready`.\n\nYou can use the command `kubectl logs -f sglang-0` to view the logs of the leader node.\n\nOnce successful, you should see output like this:\n\n```text\n[2025-02-17 05:27:24 TP1] Capture cuda graph end.
Time elapsed: 84.89 s\n[2025-02-17 05:27:24 TP6] max_total_num_tokens=712400, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=50, context_len=163840\n[2025-02-17 05:27:24 TP0] max_total_num_tokens=712400, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=50, context_len=163840\n[2025-02-17 05:27:24 TP7] max_total_num_tokens=712400, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=50, context_len=163840\n[2025-02-17 05:27:24 TP3] max_total_num_tokens=712400, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=50, context_len=163840\n[2025-02-17 05:27:24 TP2] max_total_num_tokens=712400, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=50, context_len=163840\n[2025-02-17 05:27:24 TP4] max_total_num_tokens=712400, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=50, context_len=163840\n[2025-02-17 05:27:24 TP1] max_total_num_tokens=712400, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=50, context_len=163840\n[2025-02-17 05:27:24 TP5] max_total_num_tokens=712400, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=50, context_len=163840\n[2025-02-17 05:27:24] INFO:     Started server process [1]\n[2025-02-17 05:27:24] INFO:     Waiting for application startup.\n[2025-02-17 05:27:24] INFO:     Application startup complete.\n[2025-02-17 05:27:24] INFO:     Uvicorn running on http://0.0.0.0:40000 (Press CTRL+C to quit)\n[2025-02-17 05:27:25] INFO:     127.0.0.1:48908 - \"GET /get_model_info HTTP/1.1\" 200 OK\n[2025-02-17 05:27:25 TP0] Prefill batch. 
#new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0\n[2025-02-17 05:27:32] INFO:     127.0.0.1:48924 - \"POST /generate HTTP/1.1\" 200 OK\n[2025-02-17 05:27:32] The server is fired up and ready to roll!\n```\n\nIf it doesn’t start up successfully, please follow these steps to check for any remaining issues. Thanks!\n\n### Debug\n\n* Set `NCCL_DEBUG=TRACE` to check if it is a NCCL communication problem.\n\nThis should resolve most NCCL-related issues.\n\n***Notice: If you find that NCCL_DEBUG=TRACE is not effective in the container environment, but the process is stuck or you encounter hard-to-diagnose issues, try switching to a different container image. Some images may not handle standard error output properly.***\n\n#### RoCE scenario\n\n* Please make sure that RDMA devices are available in the cluster environment.\n* Please make sure that the nodes in the cluster have Mellanox NICs with RoCE. In this example, we use Mellanox ConnectX 5 model NICs, and the proper OFED driver has been installed. 
If not, please refer to the document [Install OFED Driver](https://docs.nvidia.com/networking/display/mlnxofedv461000/installing+mellanox+ofed) to install the driver.\n* Check your env:\n\n  ```shell\n  $ lspci -nn | grep Eth | grep Mellanox\n  0000:7f:00.0 Ethernet controller [0200]: Mellanox Technologies MT43244 BlueField-3 integrated ConnectX-7 network controller [15b3:a2dc] (rev 01)\n  0000:7f:00.1 Ethernet controller [0200]: Mellanox Technologies MT43244 BlueField-3 integrated ConnectX-7 network controller [15b3:a2dc] (rev 01)\n  0000:c7:00.0 Ethernet controller [0200]: Mellanox Technologies MT43244 BlueField-3 integrated ConnectX-7 network controller [15b3:a2dc] (rev 01)\n  0000:c7:00.1 Ethernet controller [0200]: Mellanox Technologies MT43244 BlueField-3 integrated ConnectX-7 network controller [15b3:a2dc] (rev 01)\n  0001:08:00.0 Ethernet controller [0200]: Mellanox Technologies MT43244 BlueField-3 integrated ConnectX-7 network controller [15b3:a2dc] (rev 01)\n  0001:08:00.1 Ethernet controller [0200]: Mellanox Technologies MT43244 BlueField-3 integrated ConnectX-7 network controller [15b3:a2dc] (rev 01)\n  0001:a2:00.0 Ethernet controller [0200]: Mellanox Technologies MT43244 BlueField-3 integrated ConnectX-7 network controller [15b3:a2dc] (rev 01)\n  0001:a2:00.1 Ethernet controller [0200]: Mellanox Technologies MT43244 BlueField-3 integrated ConnectX-7 network controller [15b3:a2dc] (rev 01)\n  ```\n\n* Check the OFED driver:\n\n  ```shell\n  ofed_info -s\n  OFED-internal-23.07-0.5.0:\n  ```\n\n* Show RDMA link status and check IB devices:\n\n  ```shell\n  $ rdma link show\n  8/1: mlx5_bond_0/1: state ACTIVE physical_state LINK_UP netdev reth0\n  9/1: mlx5_bond_1/1: state ACTIVE physical_state LINK_UP netdev reth2\n  10/1: mlx5_bond_2/1: state ACTIVE physical_state LINK_UP netdev reth4\n  11/1: mlx5_bond_3/1: state ACTIVE physical_state LINK_UP netdev reth6\n\n  $ ibdev2netdev\n  8/1: mlx5_bond_0/1: state ACTIVE physical_state LINK_UP netdev reth0\n  
9/1: mlx5_bond_1/1: state ACTIVE physical_state LINK_UP netdev reth2\n  10/1: mlx5_bond_2/1: state ACTIVE physical_state LINK_UP netdev reth4\n  11/1: mlx5_bond_3/1: state ACTIVE physical_state LINK_UP netdev reth6\n  ```\n\n* Test RoCE network speed on the host:\n\n  ```shell\n  yum install qperf\n  # for server：\n  execute qperf\n  # for client\n  qperf -t 60 -cm1 <server_ip>   rc_rdma_write_bw\n  ```\n\n* Check RDMA accessible in your container:\n\n  ```shell\n  # ibv_devices\n  # ibv_devinfo\n  ```\n\n## Keys to success\n\n* In the YAML configuration above, pay attention to the NCCL environment variable. For older versions of NCCL, you should check the NCCL_IB_GID_INDEX environment setting.\n* NCCL_SOCKET_IFNAME is also crucial, but in a containerized environment, this typically isn’t an issue.\n* In some cases, it’s necessary to configure GLOO_SOCKET_IFNAME correctly.\n* NCCL_DEBUG is essential for troubleshooting, but I've found that sometimes it doesn't show error logs within containers. This could be related to the Docker image you're using. You may want to try switching images if needed.\n* Avoid using Docker images based on Ubuntu 18.04, as they tend to have compatibility issues.\n\n## Remaining issues\n\n* In Kubernetes, Docker, or Containerd environments, we use hostNetwork to prevent performance degradation.\n* We utilize privileged mode, which  isn’t secure. Additionally, in containerized environments, full GPU isolation cannot be achieved.\n\n## TODO\n\n* Integrated with [k8s-rdma-shared-dev-plugin](https://github.com/Mellanox/k8s-rdma-shared-dev-plugin).\n"
  },
  {
    "path": "docs/references/multi_node_deployment/lws_pd/lws-examples/d-svc.yaml",
    "content": "apiVersion: v1\nkind: Service\nmetadata:\n  name: deepseekr10528-decode-main\nspec:\n  selector:\n    leaderworkerset.sigs.k8s.io/name: deepseekr10528-decode-main\n    role: leader\n  ports:\n    - protocol: TCP\n      port: 30000\n      targetPort: 30000\n"
  },
  {
    "path": "docs/references/multi_node_deployment/lws_pd/lws-examples/d.yaml",
    "content": "apiVersion: leaderworkerset.x-k8s.io/v1\nkind: LeaderWorkerSet\nmetadata:\n  name: deepseekr10528-decode-main\nspec:\n  leaderWorkerTemplate:\n    leaderTemplate:\n      metadata:\n        labels:\n          role: leader\n      spec:\n        containers:\n        - command:\n          - python3\n          - -m\n          - sglang.launch_server\n          - --port\n          - \"30000\"\n          - --host\n          - \"0.0.0.0\"\n          - --model-path\n          - /work/models\n          - --chunked-prefill-size\n          - \"262144\"\n          - --page-size\n          - \"64\"\n          - --enable-dp-attention\n          - --enable-dp-lm-head\n          - --dp-size\n          - \"16\"\n          - --moe-a2a-backend\n          - deepep\n          - --disaggregation-mode\n          - decode\n          - --mem-fraction-static\n          - \"0.849\"\n          - --context-length\n          - \"32768\"\n          - --disaggregation-ib-device\n          - \"mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3\"\n          - --cuda-graph-max-bs\n          - \"64\"\n          - --max-running-requests\n          - \"2048\"\n          - --tp-size\n          - \"16\" # Size of Tensor Parallelism\n          - --dist-init-addr\n          - $(LWS_LEADER_ADDRESS):20102\n          - --nnodes\n          - $(LWS_GROUP_SIZE)\n          - --node-rank\n          - $(LWS_WORKER_INDEX)\n          - --trust-remote-code\n          - --ep-num-redundant-experts\n          - \"32\"\n          - --moe-dense-tp-size\n          - \"1\"\n          env:\n          - name: CUDA_LAUNCH_BLOCKING\n            value: \"0\"\n          - name: NVSHMEM_IB_GID_INDEX\n            value: \"3\"\n          - name: NVSHMEM_ENABLE_NIC_PE_MAPPING\n            value: \"1\"\n          - name: NVSHMEM_HCA_PE_MAPPING\n            value: \"mlx5_bond_0:1:2,mlx5_bond_1:1:2,mlx5_bond_2:1:2,mlx5_bond_3:1:2\"\n          - name:  NCCL_IB_QPS_PER_CONNECTION\n            value: \"8\"\n          - name: 
NCCL_IB_SPLIT_DATA_ON_QPS\n            value: \"1\"\n          - name: NCCL_NET_PLUGIN\n            value: \"none\"\n          - name: NCCL_IB_TC\n            value: \"136\"\n          - name: NCCL_MIN_NCHANNELS\n            value: \"4\"\n          - name: NCCL_IB_SL\n            value: \"5\"\n          - name: MC_TE_METRIC\n            value: \"true\"\n          - name: SGLANG_MOONCAKE_TRANS_THREAD\n            value: \"16\"\n          - name: SGLANG_ENABLE_JIT_DEEPGEMM\n            value: \"1\"\n          - name: NCCL_IB_HCA\n            value: ^=mlx5_0,mlx5_5,mlx5_6\n          - name: LWS_WORKER_INDEX\n            valueFrom:\n              fieldRef:\n                fieldPath: metadata.labels['leaderworkerset.sigs.k8s.io/worker-index']\n          image: lmsysorg/sglang:latest\n          name: sglang-leader\n          ports:\n          - containerPort: 30000\n            protocol: TCP\n          readinessProbe:\n            periodSeconds: 30\n            tcpSocket:\n              port: 30000\n          resources:\n            limits:\n              nvidia.com/gpu: \"8\"\n          securityContext:\n            capabilities:\n              add:\n              - IPC_LOCK\n            privileged: true\n          volumeMounts:\n          - mountPath: /root/.cache\n            name: sgl-cache\n          - mountPath: /dev/shm\n            name: dshm\n          - mountPath: /work/models\n            name: model\n          - mountPath: /dev/infiniband\n            name: ib\n          - mountPath: /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs\n            name: cf\n        dnsPolicy: ClusterFirstWithHostNet\n        hostIPC: true\n        hostNetwork: true\n        nodeSelector:\n        # should modify according your deployment env\n          pd: \"yes\"\n        tolerations:\n        # should modify according your deployment env\n        - key: bopd\n          operator: Exists\n        - key: node-role\n          operator: Exists\n        
volumes:\n        - hostPath:\n            path: /data1/sgl_cache1\n            type: DirectoryOrCreate\n          name: sgl-cache\n        - emptyDir:\n            medium: Memory\n          name: dshm\n        - hostPath:\n            path: /data1/maas_hosted_models/models/DeepSeek-R1-0528/deepseek_r1_0528\n          name: model\n        - hostPath:\n            path: /dev/infiniband\n          name: ib\n        - hostPath:\n            path: /data1/maas_hosted_models/models/fused_moe_triton/configs\n          name: cf\n    restartPolicy: RecreateGroupOnPodRestart\n    size:  2\n    workerTemplate:\n      metadata: {}\n      spec:\n        containers:\n        - command:\n          - python3\n          - -m\n          - sglang.launch_server\n          - --model-path\n          - /work/models\n          - --chunked-prefill-size\n          - \"262144\"\n          - --page-size\n          - \"64\"\n          - --enable-dp-attention\n          - --enable-dp-lm-head\n          - --dp-size\n          - \"16\"\n          - --moe-a2a-backend\n          - deepep\n          - --disaggregation-mode\n          - decode\n          - --mem-fraction-static\n          - \"0.849\"\n          - --context-length\n          - \"32768\"\n          - --disaggregation-ib-device\n          - \"mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3\"\n          - --cuda-graph-max-bs\n          - \"64\"\n          - --max-running-requests\n          - \"2048\"\n          - --tp-size\n          - \"16\" # Size of Tensor Parallelism\n          - --dist-init-addr\n          - $(LWS_LEADER_ADDRESS):20102\n          - --nnodes\n          - $(LWS_GROUP_SIZE)\n          - --node-rank\n          - $(LWS_WORKER_INDEX)\n          - --trust-remote-code\n          - --ep-num-redundant-experts\n          - \"32\"\n          - --moe-dense-tp-size\n          - \"1\"\n          env:\n          - name: NVSHMEM_IB_TRAFFIC_CLASS\n            value: \"16\"\n          - name: NVSHMEM_IB_GID_INDEX\n            value: 
\"3\"\n          - name: NVSHMEM_ENABLE_NIC_PE_MAPPING\n            value: \"1\"\n          - name: NVSHMEM_HCA_PE_MAPPING\n            value: \"mlx5_bond_0:1:2,mlx5_bond_1:1:2,mlx5_bond_2:1:2,mlx5_bond_3:1:2\"\n          - name:  NCCL_IB_QPS_PER_CONNECTION\n            value: \"8\"\n          - name: NCCL_IB_SPLIT_DATA_ON_QPS\n            value: \"1\"\n          - name: NCCL_NET_PLUGIN\n            value: \"none\"\n          - name: NCCL_IB_TC\n            value: \"136\"\n          - name: NCCL_MIN_NCHANNELS\n            value: \"4\"\n          - name: MC_TE_METRIC\n            value: \"true\"\n          - name: NCCL_IB_SL\n            value: \"5\"\n          - name: SGLANG_MOONCAKE_TRANS_THREAD\n            value: \"16\"\n          - name: SGLANG_ENABLE_JIT_DEEPGEMM\n            value: \"1\"\n          - name: NCCL_IB_HCA\n            value: ^=mlx5_0,mlx5_5,mlx5_6\n          - name: LWS_WORKER_INDEX\n            valueFrom:\n              fieldRef:\n                fieldPath: metadata.labels['leaderworkerset.sigs.k8s.io/worker-index']\n          image: lmsysorg/sglang:latest\n          name: sglang-worker\n          ports:\n          - containerPort: 30001\n          resources:\n            limits:\n              nvidia.com/gpu: \"8\"\n          securityContext:\n            capabilities:\n              add:\n              - IPC_LOCK\n            privileged: true\n          volumeMounts:\n          - mountPath: /root/.cache\n            name: sgl-cache\n          - mountPath: /dev/shm\n            name: dshm\n          - mountPath: /work/models\n            name: model\n          - mountPath: /dev/infiniband\n            name: ib\n          - mountPath: /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs\n            name: cf\n        dnsPolicy: ClusterFirstWithHostNet\n        hostIPC: true\n        hostNetwork: true\n        nodeSelector:\n        # should modify according your deployment env\n          pd: \"yes\"\n        tolerations:\n 
       # should modify according your deployment env\n        - key: bopd\n          operator: Exists\n        - key: node-role\n          operator: Exists\n        volumes:\n        - hostPath:\n            path: /data1/sgl_cache1\n            type: DirectoryOrCreate\n          name: sgl-cache\n        - emptyDir:\n            medium: Memory\n          name: dshm\n        - hostPath:\n            path: /dev/infiniband\n          name: ib\n        - hostPath:\n            # modify according to you deployment env\n            path: /data1/maas_hosted_models/models/DeepSeek-R1-0528/deepseek_r1_0528\n          name: model\n        - hostPath:\n            # modify according to you deployment env\n            path: /data1/maas_hosted_models/models/fused_moe_triton/configs\n          name: cf\n  networkConfig:\n    subdomainPolicy: Shared\n  replicas: 1\n  rolloutStrategy:\n    rollingUpdateConfiguration:\n      maxSurge: 0\n      maxUnavailable: 1\n    type: RollingUpdate\n  startupPolicy: LeaderCreated\n"
  },
  {
    "path": "docs/references/multi_node_deployment/lws_pd/lws-examples/lb.yaml",
    "content": "apiVersion: apps/v1\nkind: Deployment\nmetadata:\n  name: deepseekr10528-lb-main\n  labels:\n    app: deepseekr10528-lb\nspec:\n  replicas: 1\n  selector:\n    matchLabels:\n      app: deepseekr10528-lb\n  template:\n    metadata:\n      labels:\n        app: deepseekr10528-lb\n    spec:\n      nodeSelector:\n        bo: \"yes\"\n      tolerations:\n        - key: bopd\n          operator: Exists\n        - key: node-role\n          operator: Exists\n      containers:\n        - name: sgl-minilb\n          image: lmsysorg/sglang:latest\n          command:\n          - python\n          - -m\n          - sglang_router.launch_router\n          - --pd-disaggregation\n          - --prefill\n          - http://deepseekr10528-prefill-main:30000\n          - --decode\n          - http://deepseekr10528-decode-main:30000\n          - --host\n          - 0.0.0.0\n          - --port\n          - \"8000\"\n          ports:\n            - containerPort: 8000\n\n---\napiVersion: v1\nkind: Service\nmetadata:\n  name: deepseekr10528-lb-service\nspec:\n  type: NodePort # NodePort is easy to test; you can also use `ClusterIP`\n  selector:\n    app: deepseekr10528-lb\n  ports:\n    - protocol: TCP\n      port: 8000         # Service port (in-cluster)\n      targetPort: 8000   # Container port of the router\n      nodePort: 30800\n"
  },
  {
    "path": "docs/references/multi_node_deployment/lws_pd/lws-examples/p-svc.yaml",
    "content": "apiVersion: v1\nkind: Service\nmetadata:\n  name: deepseekr10528-prefill-main\nspec:\n  selector:\n    leaderworkerset.sigs.k8s.io/name: deepseekr10528-prefill-main\n    role: leader\n  ports:\n    - protocol: TCP\n      port: 30000\n      targetPort: 30000\n"
  },
  {
    "path": "docs/references/multi_node_deployment/lws_pd/lws-examples/p.yaml",
    "content": "apiVersion: leaderworkerset.x-k8s.io/v1\nkind: LeaderWorkerSet\nmetadata:\n  name: deepseekr10528-prefill-main\nspec:\n  leaderWorkerTemplate:\n    leaderTemplate:\n      metadata:\n        labels:\n          role: leader\n      spec:\n        containers:\n        - command:\n          - python3\n          - -m\n          - sglang.launch_server\n          - --port\n          - \"30000\"\n          - --host\n          - \"0.0.0.0\"\n          - --model-path\n          - /work/models\n          - --disaggregation-ib-device\n          # should modify according your rdma env\n          - mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3\n          - --chunked-prefill-size\n          - \"524288\"\n          - --max-prefill-tokens\n          - \"32768\"\n          - --page-size\n          - \"64\"\n          - --ep-dispatch-algorithm\n          - dynamic\n          - --eplb-algorithm\n          - deepseek\n          - --enable-dp-lm-head\n          - --enable-dp-attention\n          - --dp-size\n          - \"16\"\n          - --disable-radix-cache\n          - --moe-a2a-backend\n          - deepep\n          - --disaggregation-mode\n          - prefill\n          - --mem-fraction-static\n          - \"0.7\"\n          - --context-length\n          - \"32768\"\n          - --tp\n          - \"16\"\n          - --dist-init-addr\n          - $(LWS_LEADER_ADDRESS):20102\n          - --nnodes\n          - $(LWS_GROUP_SIZE)\n          - --node-rank\n          - $(LWS_WORKER_INDEX)\n          - --trust-remote-code\n          - --ep-num-redundant-experts\n          - \"32\"\n          - --moe-dense-tp-size\n          - \"1\"\n          - --max-running-requests\n          - \"1024\"\n          env:\n          - name: NVSHMEM_HCA_PE_MAPPING\n            # should modify according your rdma env\n            value: \"mlx5_bond_0:1:2,mlx5_bond_1:1:2,mlx5_bond_2:1:2,mlx5_bond_3:1:2\"\n          - name: NVSHMEM_IB_GID_INDEX\n            value: \"3\"\n          - name: 
NVSHMEM_ENABLE_NIC_PE_MAPPING\n            value: \"1\"\n          - name: SGLANG_SET_CPU_AFFINITY\n            value: \"true\"\n          - name: SGLANG_ENABLE_JIT_DEEPGEMM\n            value: \"1\"\n          - name: NCCL_IB_QPS_PER_CONNECTION\n            value: \"8\"\n          - name: NCCL_IB_SPLIT_DATA_ON_QPS\n            value: \"1\"\n          - name: NCCL_NET_PLUGIN\n            value: none\n          - name: NCCL_IB_TC\n            value: \"136\"\n          - name: NCCL_MIN_NCHANNELS\n            value: \"4\"\n          - name: MC_TE_METRIC\n            value: \"false\"\n          - name: NCCL_IB_SL\n            value: \"5\"\n          - name: NCCL_IB_HCA\n            value: ^=mlx5_0,mlx5_5,mlx5_6\n          - name: LWS_WORKER_INDEX\n            valueFrom:\n              fieldRef:\n                fieldPath: metadata.labels['leaderworkerset.sigs.k8s.io/worker-index']\n          image: lmsysorg/sglang:latest\n          name: sglang-leader\n          ports:\n          - containerPort: 30000\n            protocol: TCP\n          readinessProbe:\n            periodSeconds: 30\n            tcpSocket:\n              port: 30000\n          resources:\n            limits:\n              nvidia.com/gpu: \"8\"\n          securityContext:\n            capabilities:\n              add:\n              - IPC_LOCK\n            privileged: true\n          volumeMounts:\n          - mountPath: /dev/shm\n            name: dshm\n          - mountPath: /work/models\n            name: model\n          - mountPath: /dev/infiniband\n            name: ib\n          - mountPath: /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs\n            name: cf\n          - mountPath: /root/.cache\n            name: sgl-cache\n        dnsPolicy: ClusterFirstWithHostNet\n        hostIPC: true\n        hostNetwork: true\n        nodeSelector:\n        # should modify according your deployment env\n          pd: \"yes\"\n        tolerations:\n        # should modify 
according your deployment env\n        - key: bopd\n          operator: Exists\n        - key: node-role\n          operator: Exists\n        volumes:\n        - emptyDir:\n            medium: Memory\n          name: dshm\n        - hostPath:\n            path: /data1/maas_hosted_models/models/DeepSeek-R1-0528/deepseek_r1_0528\n          name: model\n        - hostPath:\n            path: /dev/infiniband\n          name: ib\n        - hostPath:\n            path: /data1/maas_hosted_models/models/fused_moe_triton/configs\n          name: cf\n        - hostPath:\n            path: /data1/sgl_cache\n            type: DirectoryOrCreate\n          name: sgl-cache\n    restartPolicy: RecreateGroupOnPodRestart\n    size: 2\n    workerTemplate:\n      metadata: {}\n      spec:\n        containers:\n        - command:\n          - python3\n          - -m\n          - sglang.launch_server\n          - --model-path\n          - /work/models\n          - --disaggregation-ib-device\n          # should modify according your rdma env\n          - mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3\n          - --chunked-prefill-size\n          - \"524288\"\n          - --max-prefill-tokens\n          - \"32768\"\n          - --page-size\n          - \"64\"\n          - --ep-dispatch-algorithm\n          - dynamic\n          - --eplb-algorithm\n          - deepseek\n          #          - --deepep-config\n          #          -  /home/aiges/tuned/tuned_8sms.json\n          # can be tuned using deepep test scripts\n          - --enable-dp-lm-head\n          - --enable-dp-attention\n          - --dp-size\n          - \"16\"\n          - --disable-radix-cache\n          - --moe-a2a-backend\n          - deepep\n          - --disaggregation-mode\n          - prefill\n          - --mem-fraction-static\n          - \"0.7\"\n          - --context-length\n          - \"32768\"\n          - --tp\n          - \"16\"\n          - --dist-init-addr\n          - $(LWS_LEADER_ADDRESS):20102\n         
 - --nnodes\n          - $(LWS_GROUP_SIZE)\n          - --node-rank\n          - $(LWS_WORKER_INDEX)\n          - --trust-remote-code\n          - --ep-num-redundant-experts\n          - \"32\"\n          - --moe-dense-tp-size\n          - \"1\"\n          - --max-running-requests\n          - \"1024\"\n          env:\n          - name: SGLANG_SET_CPU_AFFINITY\n            value: \"true\"\n          - name: NVSHMEM_HCA_PE_MAPPING\n            # should modify according your rdma env\n            value: \"mlx5_bond_0:1:2,mlx5_bond_1:1:2,mlx5_bond_2:1:2,mlx5_bond_3:1:2\"\n          - name: NCCL_IB_HCA\n            value: ^=mlx5_0,mlx5_5,mlx5_6\n          - name: NVSHMEM_IB_TRAFFIC_CLASS\n            value: \"16\"\n          - name: NVSHMEM_IB_GID_INDEX\n            value: \"3\"\n          - name: NVSHMEM_ENABLE_NIC_PE_MAPPING\n            value: \"1\"\n          - name: CUDA_LAUNCH_BLOCKING\n            value: \"0\"\n          - name: SGLANG_MOONCAKE_TRANS_THREAD\n            value: \"8\"\n          - name: SGLANG_ENABLE_JIT_DEEPGEMM\n            value: \"1\"\n          - name: SGLANG_CHUNKED_PREFIX_CACHE_THRESHOLD\n            value: \"0\"\n          - name: NCCL_IB_QPS_PER_CONNECTION\n            value: \"8\"\n          - name: NCCL_IB_SPLIT_DATA_ON_QPS\n            value: \"1\"\n          - name: NCCL_NET_PLUGIN\n            value: none\n          - name: NCCL_IB_TC\n            value: \"136\"\n          - name: NCCL_MIN_NCHANNELS\n            value: \"4\"\n          - name: MC_TE_METRIC\n            value: \"true\"\n          - name: NCCL_IB_SL\n            value: \"5\"\n          - name: LWS_WORKER_INDEX\n            valueFrom:\n              fieldRef:\n                fieldPath: metadata.labels['leaderworkerset.sigs.k8s.io/worker-index']\n          image: lmsysorg/sglang:latest\n          name: sglang-worker\n          ports:\n          - containerPort: 30001\n            protocol: TCP\n          resources:\n            limits:\n              nvidia.com/gpu: 
\"8\"\n          securityContext:\n            capabilities:\n              add:\n              - IPC_LOCK\n            privileged: true\n          volumeMounts:\n          - mountPath: /root/.cache\n            name: sgl-cache\n          - mountPath: /dev/shm\n            name: dshm\n          - mountPath: /work/models\n            name: model\n          - mountPath: /dev/infiniband\n            name: ib\n          - mountPath: /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs\n            name: cf\n        dnsPolicy: ClusterFirstWithHostNet\n        hostIPC: true\n        hostNetwork: true\n        nodeSelector:\n        # should modify according your deployment env\n          pd: \"yes\"\n        tolerations:\n        # should modify according your deployment env\n        - key: bopd\n          operator: Exists\n        - key: node-role\n          operator: Exists\n        volumes:\n        - emptyDir:\n            medium: Memory\n          name: dshm\n        - hostPath:\n            path: /dev/infiniband\n          name: ib\n        - hostPath:\n            # modify according to you deployment env\n            path: /data1/maas_hosted_models/models/DeepSeek-R1-0528/deepseek_r1_0528\n          name: model\n        - hostPath:\n            # modify according to you deployment env\n            path: /data1/maas_hosted_models/models/fused_moe_triton/configs\n          name: cf\n        - hostPath:\n            # modify according to you deployment env\n            path: /data1/sgl_cache\n            type: DirectoryOrCreate\n          name: sgl-cache\n"
  },
  {
    "path": "docs/references/multi_node_deployment/lws_pd/lws_pd_deploy.md",
    "content": "# LWS Based PD Deploy\n\n## 0. Prerequisites\n\n1. k8s >=1.26\n2. lws installed on k8s.\n\n## 1. Image Preparation\n\n`lmsysorg/sglang:deepep`\n\n## 2. Deployment Manifest Files\n\n***Notice: We will package all deployment files into Helm Chart format in the near future. Interested community members can contact us to contribute***\n\n### Prefill\n\nPrefill manifest file [prefill.yaml](lws-examples/p.yaml)\n\n*Note: The NodeSelector section, model location section, and taint toleration section can be adjusted according to your actual deployment environment*\n\n```yaml\napiVersion: leaderworkerset.x-k8s.io/v1\nkind: LeaderWorkerSet\nmetadata:\n  name: deepseekr10528-prefill-main\nspec:\n  leaderWorkerTemplate:\n    leaderTemplate:\n      metadata:\n        labels:\n          role: leader\n      spec:\n        containers:\n        - command:\n          - python3\n          - -m\n          - sglang.launch_server\n          - --port\n          - \"30000\"\n          - --host\n          - \"0.0.0.0\"\n          - --model-path\n          - /work/models\n          - --disaggregation-ib-device\n          # should modify according your rdma env\n          - mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3\n          - --chunked-prefill-size\n          - \"524288\"\n          - --max-prefill-tokens\n          - \"32768\"\n          - --page-size\n          - \"64\"\n          #          - --init-expert-location\n          #          - /home/aiges/tuned/attachment_ep_statistics/prefill_in1024.json\n          - --ep-dispatch-algorithm\n          - dynamic\n          - --eplb-algorithm\n          - deepseek\n          #          - --deepep-config\n          #          -  /home/aiges/tuned/tuned_8sms.json\n          - --enable-dp-lm-head\n          - --enable-dp-attention\n          - --dp-size\n          - \"16\"\n          - --disable-radix-cache\n          - --moe-a2a-backend\n          - deepep\n          - --disaggregation-mode\n          - prefill\n          
- --mem-fraction-static\n          - \"0.7\"\n          - --context-length\n          - \"32768\"\n          - --tp\n          - \"16\"\n          - --dist-init-addr\n          - $(LWS_LEADER_ADDRESS):20102\n          - --nnodes\n          - $(LWS_GROUP_SIZE)\n          - --node-rank\n          - $(LWS_WORKER_INDEX)\n          - --trust-remote-code\n          - --ep-num-redundant-experts\n          - \"32\"\n          - --moe-dense-tp-size\n          - \"1\"\n          - --max-running-requests\n          - \"1024\"\n          env:\n#          - name: NVSHMEM_HCA_PE_MAPPING\n#            value: \"mlx5_bond_0:1:2,mlx5_bond_1:1:2,mlx5_bond_2:1:2,mlx5_bond_3:1:2\"\n#          - name: NVSHMEM_HCA_LIST\n#            value: \"mlx5_bond_0:1,mlx5_bond_1:1,mlx5_bond_2:1,mlx5_bond_3:1\"\n          - name: NVSHMEM_IB_GID_INDEX\n            value: \"3\"\n          - name: NVSHMEM_ENABLE_NIC_PE_MAPPING\n            value: \"1\"\n          - name: SGLANG_SET_CPU_AFFINITY\n            value: \"true\"\n          - name: SGLANG_ENABLE_JIT_DEEPGEMM\n            value: \"1\"\n          - name: NCCL_IB_QPS_PER_CONNECTION\n            value: \"8\"\n          - name: NCCL_IB_SPLIT_DATA_ON_QPS\n            value: \"1\"\n          - name: NCCL_NET_PLUGIN\n            value: none\n          - name: NCCL_IB_TC\n            value: \"136\"\n          - name: NCCL_MIN_NCHANNELS\n            value: \"4\"\n          - name: MC_TE_METRIC\n            value: \"false\"\n          - name: NCCL_IB_SL\n            value: \"5\"\n          - name: NCCL_IB_HCA\n            value: ^=mlx5_0,mlx5_5,mlx5_6\n          - name: LWS_WORKER_INDEX\n            valueFrom:\n              fieldRef:\n                fieldPath: metadata.labels['leaderworkerset.sigs.k8s.io/worker-index']\n          image: lmsysorg/sglang:deepep\n          name: sglang-leader\n          ports:\n          - containerPort: 30000\n            protocol: TCP\n          readinessProbe:\n            periodSeconds: 30\n            tcpSocket:\n    
          port: 30000\n          resources:\n            limits:\n              nvidia.com/gpu: \"8\"\n          securityContext:\n            capabilities:\n              add:\n              - IPC_LOCK\n            privileged: true\n          volumeMounts:\n          - mountPath: /dev/shm\n            name: dshm\n          - mountPath: /work/models\n            name: model\n          - mountPath: /dev/infiniband\n            name: ib\n          - mountPath: /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs\n            name: cf\n          - mountPath: /root/.cache\n            name: sgl-cache\n        dnsPolicy: ClusterFirstWithHostNet\n        hostIPC: true\n        hostNetwork: true\n        nodeSelector:\n          pd: \"yes\"\n        tolerations:\n        - key: pd\n          operator: Exists\n        - key: node-role\n          operator: Exists\n        volumes:\n        - emptyDir:\n            medium: Memory\n          name: dshm\n        - hostPath:\n            # modify according to you deployment env\n            path: /data1/maas_hosted_models/models/DeepSeek-R1-0528/deepseek_r1_0528\n          name: model\n        - hostPath:\n            path: /dev/infiniband\n          name: ib\n        - hostPath:\n            # modify according to you deployment env\n            path: /data1/maas_hosted_models/models/fused_moe_triton/configs\n          name: cf\n        - hostPath:\n            # modify according to you deployment env\n            path: /data1/sgl_cache\n            type: DirectoryOrCreate\n          name: sgl-cache\n    restartPolicy: RecreateGroupOnPodRestart\n    size: 2\n    workerTemplate:\n      metadata: {}\n      spec:\n        containers:\n        - command:\n          - python3\n          - -m\n          - sglang.launch_server\n          - --model-path\n          - /work/models\n          - --disaggregation-ib-device\n          - mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3\n          - 
--chunked-prefill-size\n          - \"524288\"\n          - --max-prefill-tokens\n          - \"32768\"\n          - --page-size\n          - \"64\"\n          #- --init-expert-location\n          #- /home/aiges/tuned/attachment_ep_statistics/prefill_in1024.json\n          - --ep-dispatch-algorithm\n          - dynamic\n          - --eplb-algorithm\n          - deepseek\n#          - --deepep-config\n#          -  /home/aiges/tuned/tuned_8sms.json\n          - --enable-dp-lm-head\n          - --enable-dp-attention\n          - --dp-size\n          - \"16\"\n          - --disable-radix-cache\n          - --moe-a2a-backend\n          - deepep\n          - --disaggregation-mode\n          - prefill\n          - --mem-fraction-static\n          - \"0.7\"\n          - --context-length\n          - \"32768\"\n          - --tp\n          - \"16\"\n          - --dist-init-addr\n          - $(LWS_LEADER_ADDRESS):20102\n          - --nnodes\n          - $(LWS_GROUP_SIZE)\n          - --node-rank\n          - $(LWS_WORKER_INDEX)\n          - --trust-remote-code\n          - --ep-num-redundant-experts\n          - \"32\"\n          - --moe-dense-tp-size\n          - \"1\"\n          - --max-running-requests\n          - \"1024\"\n          env:\n          - name: SGLANG_SET_CPU_AFFINITY\n            value: \"true\"\n          - name: SGLANG_HACK_DEEPEP_NUM_SMS\n            value: \"8\"\n          - name: SGLANG_HACK_DEEPEP_NEW_MODE\n            value: \"0\"\n#          - name: NVSHMEM_HCA_PE_MAPPING\n#            value: \"mlx5_bond_0:1:2,mlx5_bond_1:1:2,mlx5_bond_2:1:2,mlx5_bond_3:1:2\"\n#          - name: NVSHMEM_HCA_LIST\n#            value: \"mlx5_bond_0:1,mlx5_bond_1:1,mlx5_bond_2:1,mlx5_bond_3:1\"\n          - name: NCCL_IB_HCA\n            value: ^=mlx5_0,mlx5_5,mlx5_6\n          - name: NVSHMEM_IB_TRAFFIC_CLASS\n            value: \"16\"\n          - name: NVSHMEM_IB_GID_INDEX\n            value: \"3\"\n          - name: NVSHMEM_ENABLE_NIC_PE_MAPPING\n            value: 
\"1\"\n          - name: CUDA_LAUNCH_BLOCKING\n            value: \"0\"\n          - name: SGLANG_MOONCAKE_TRANS_THREAD\n            value: \"8\"\n          - name: SGLANG_ENABLE_JIT_DEEPGEMM\n            value: \"1\"\n          - name: SGLANG_CHUNKED_PREFIX_CACHE_THRESHOLD\n            value: \"0\"\n          - name: NCCL_IB_QPS_PER_CONNECTION\n            value: \"8\"\n          - name: NCCL_IB_SPLIT_DATA_ON_QPS\n            value: \"1\"\n          - name: NCCL_NET_PLUGIN\n            value: none\n          - name: NCCL_IB_TC\n            value: \"136\"\n          - name: NCCL_MIN_NCHANNELS\n            value: \"4\"\n          - name: MC_TE_METRIC\n            value: \"true\"\n          - name: NCCL_IB_SL\n            value: \"5\"\n          - name: LWS_WORKER_INDEX\n            valueFrom:\n              fieldRef:\n                fieldPath: metadata.labels['leaderworkerset.sigs.k8s.io/worker-index']\n          image: lmsysorg/sglang:deepep\n          name: sglang-worker\n          ports:\n          - containerPort: 30001\n            protocol: TCP\n          resources:\n            limits:\n              nvidia.com/gpu: \"8\"\n          securityContext:\n            capabilities:\n              add:\n              - IPC_LOCK\n            privileged: true\n          volumeMounts:\n\n          - mountPath: /root/.cache\n            name: sgl-cache\n          - mountPath: /dev/shm\n            name: dshm\n          - mountPath: /work/models\n            name: model\n          - mountPath: /dev/infiniband\n            name: ib\n          - mountPath: /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs\n            name: cf\n        dnsPolicy: ClusterFirstWithHostNet\n        hostIPC: true\n        hostNetwork: true\n        nodeSelector:\n          pd: \"yes\"\n        tolerations:\n        - key: pd\n          operator: Exists\n        - key: node-role\n          operator: Exists\n        volumes:\n        - emptyDir:\n            medium: 
Memory\n          name: dshm\n        - hostPath:\n            path: /dev/infiniband\n          name: ib\n        - hostPath:\n            path: /data1/maas_hosted_models/models/DeepSeek-R1-0528/deepseek_r1_0528\n          name: model\n        - hostPath:\n            path: /data1/maas_hosted_models/models/fused_moe_triton/configs\n          name: cf\n        - hostPath:\n            path: /data1/sgl_cache\n            type: DirectoryOrCreate\n          name: sgl-cache\n\n```\n\n### Decode\n\nDecode node deployment manifest file [decode.yaml](lws-examples/d.yaml)\n\n*Note: The NodeSelector section, model location section, and taint toleration section can be adjusted according to your actual deployment environment*\n\n```yaml\napiVersion: leaderworkerset.x-k8s.io/v1\nkind: LeaderWorkerSet\nmetadata:\n  name: deepseekr10528-decode-main\nspec:\n  leaderWorkerTemplate:\n    leaderTemplate:\n      metadata:\n        labels:\n          role: leader\n      spec:\n        containers:\n        - command:\n          - python3\n          - -m\n          - sglang.launch_server\n          - --port\n          - \"30000\"\n          - --host\n          - \"0.0.0.0\"\n          - --model-path\n          - /work/models\n          - --chunked-prefill-size\n          - \"262144\"\n          - --page-size\n          - \"64\"\n          - --enable-dp-attention\n          - --enable-dp-lm-head\n          - --dp-size\n          - \"16\"\n          - --moe-a2a-backend\n          - deepep\n          - --disaggregation-mode\n          - decode\n          - --mem-fraction-static\n          -  \"0.849\"\n          - --context-length\n          - \"32768\"\n          - --disaggregation-ib-device\n          - \"mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3\"\n          - --cuda-graph-max-bs\n          - \"64\"\n          - --max-running-requests\n          - \"2048\"\n          - --tp-size\n          - \"16\" # Size of Tensor Parallelism\n          - --dist-init-addr\n          - 
$(LWS_LEADER_ADDRESS):20102\n          - --nnodes\n          - $(LWS_GROUP_SIZE)\n          - --node-rank\n          - $(LWS_WORKER_INDEX)\n          - --trust-remote-code\n          - --ep-num-redundant-experts\n          - \"32\"\n          - --moe-dense-tp-size\n          - \"1\"\n          env:\n          - name: CUDA_LAUNCH_BLOCKING\n            value: \"0\"\n          - name: NVSHMEM_IB_GID_INDEX\n            value: \"3\"\n          - name: NVSHMEM_ENABLE_NIC_PE_MAPPING\n            value: \"1\"\n          - name:  NCCL_IB_QPS_PER_CONNECTION\n            value: \"8\"\n          - name: NCCL_IB_SPLIT_DATA_ON_QPS\n            value: \"1\"\n          - name: NCCL_NET_PLUGIN\n            value: \"none\"\n          - name: NCCL_IB_TC\n            value: \"136\"\n          - name: NCCL_MIN_NCHANNELS\n            value: \"4\"\n          - name: NCCL_IB_SL\n            value: \"5\"\n          - name: MC_TE_METRIC\n            value: \"true\"\n          - name: SGLANG_MOONCAKE_TRANS_THREAD\n            value: \"16\"\n          - name: SGLANG_ENABLE_JIT_DEEPGEMM\n            value: \"1\"\n          - name: NCCL_IB_HCA\n            value: ^=mlx5_0,mlx5_5,mlx5_6\n          - name: LWS_WORKER_INDEX\n            valueFrom:\n              fieldRef:\n                fieldPath: metadata.labels['leaderworkerset.sigs.k8s.io/worker-index']\n          image: lmsysorg/sglang:deepep\n          name: sglang-leader\n          ports:\n          - containerPort: 30000\n            protocol: TCP\n          readinessProbe:\n            periodSeconds: 30\n            tcpSocket:\n              port: 30000\n          resources:\n            limits:\n              nvidia.com/gpu: \"8\"\n          securityContext:\n            capabilities:\n              add:\n              - IPC_LOCK\n            privileged: true\n          volumeMounts:\n          - mountPath: /root/.cache\n            name: sgl-cache\n          - mountPath: /dev/shm\n            name: dshm\n          - mountPath: 
/work/models\n            name: model\n          - mountPath: /dev/infiniband\n            name: ib\n          - mountPath: /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs\n            name: cf\n        dnsPolicy: ClusterFirstWithHostNet\n        hostIPC: true\n        hostNetwork: true\n        nodeSelector:\n          pd: \"yes\"\n        tolerations:\n        - key: pd\n          operator: Exists\n        - key: node-role\n          operator: Exists\n        volumes:\n        - hostPath:\n            path: /data1/sgl_cache1\n            type: DirectoryOrCreate\n          name: sgl-cache\n        - emptyDir:\n            medium: Memory\n          name: dshm\n        - hostPath:\n            path: /data1/maas_hosted_models/models/DeepSeek-R1-0528/deepseek_r1_0528\n          name: model\n        - hostPath:\n            path: /dev/infiniband\n          name: ib\n        - hostPath:\n            path: /data1/maas_hosted_models/models/fused_moe_triton/configs\n          name: cf\n    restartPolicy: RecreateGroupOnPodRestart\n    size:  2\n    workerTemplate:\n      metadata: {}\n      spec:\n        containers:\n        - command:\n          - python3\n          - -m\n          - sglang.launch_server\n          - --model-path\n          - /work/models\n          - --chunked-prefill-size\n          - \"262144\"\n          - --page-size\n          - \"64\"\n          - --enable-dp-attention\n          - --enable-dp-lm-head\n            #- --enable-two-batch-overlap\n          - --dp-size\n          - \"16\"\n          - --moe-a2a-backend\n          - deepep\n          - --disaggregation-mode\n          - decode\n          - --mem-fraction-static\n          -  \"0.849\"\n          - --context-length\n          - \"32768\"\n          - --disaggregation-ib-device\n          # Modify according to your RDMA environment\n          - \"mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3\"\n          - --cuda-graph-max-bs\n          - \"64\"\n          - --max-running-requests\n          - \"2048\"\n          - --tp-size\n          - \"16\" # Size of Tensor Parallelism\n          - --dist-init-addr\n          - $(LWS_LEADER_ADDRESS):20102\n          - --nnodes\n          - $(LWS_GROUP_SIZE)\n          - --node-rank\n          - $(LWS_WORKER_INDEX)\n          - --trust-remote-code\n          - --ep-num-redundant-experts\n          - \"32\"\n          - --moe-dense-tp-size\n          - \"1\"\n          env:\n          - name: SGLANG_HACK_DEEPEP_NUM_SMS\n            value: \"24\"\n          - name: SGLANG_HACK_DEEPEP_NEW_MODE\n            value: \"0\"\n          - name: NVSHMEM_IB_TRAFFIC_CLASS\n            value: \"16\"\n          - name: NVSHMEM_IB_GID_INDEX\n            value: \"3\"\n          - name: NVSHMEM_ENABLE_NIC_PE_MAPPING\n            value: \"1\"\n          - name:  NCCL_IB_QPS_PER_CONNECTION\n            value: \"8\"\n          - name: NCCL_IB_SPLIT_DATA_ON_QPS\n            value: \"1\"\n          - name: NCCL_NET_PLUGIN\n            value: \"none\"\n          - name: NCCL_IB_TC\n            value: \"136\"\n          - name: NCCL_MIN_NCHANNELS\n            value: \"4\"\n          - name: MC_TE_METRIC\n            value: \"true\"\n          - name: NCCL_IB_SL\n            value: \"5\"\n          - name: SGLANG_MOONCAKE_TRANS_THREAD\n            value: \"16\"\n          - name: SGLANG_ENABLE_JIT_DEEPGEMM\n            value: \"1\"\n          - name: NCCL_IB_HCA\n            value: ^=mlx5_0,mlx5_5,mlx5_6\n          - name: LWS_WORKER_INDEX\n            valueFrom:\n              fieldRef:\n                fieldPath: metadata.labels['leaderworkerset.sigs.k8s.io/worker-index']\n          image: lmsysorg/sglang:deepep\n          name: sglang-worker\n          ports:\n          - containerPort: 30001\n          resources:\n            limits:\n              nvidia.com/gpu: \"8\"\n          securityContext:\n            capabilities:\n              add:\n              - IPC_LOCK\n            privileged: true\n          volumeMounts:\n          - mountPath: /root/.cache\n            name: sgl-cache\n          - mountPath: /dev/shm\n            name: dshm\n          - mountPath: /work/models\n            name: model\n          - mountPath: /dev/infiniband\n            name: ib\n          - mountPath: /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs\n            name: cf\n        dnsPolicy: ClusterFirstWithHostNet\n        hostIPC: true\n        hostNetwork: true\n        nodeSelector:\n          pd: \"yes\"\n        tolerations:\n        - key: pd\n          operator: Exists\n        - key: node-role\n          operator: Exists\n        volumes:\n        - hostPath:\n            path: /data1/sgl_cache1\n            type: DirectoryOrCreate\n          name: sgl-cache\n        - emptyDir:\n            medium: Memory\n          name: dshm\n        - hostPath:\n            path: /dev/infiniband\n          name: ib\n        - hostPath:\n            # Modify according to your deployment environment\n            path: /data1/maas_hosted_models/models/DeepSeek-R1-0528/deepseek_r1_0528\n          name: model\n        - hostPath:\n            # Modify according to your deployment environment\n            path: /data1/maas_hosted_models/models/fused_moe_triton/configs\n          name: cf\n  networkConfig:\n    subdomainPolicy: Shared\n  replicas: 1\n  rolloutStrategy:\n    rollingUpdateConfiguration:\n      maxSurge: 0\n      maxUnavailable: 1\n    type: RollingUpdate\n  startupPolicy: LeaderCreated\n```\n\nApply the prefill and decode manifests:\n\n```bash\nkubectl apply -f p.yaml\nkubectl apply -f d.yaml\n```\n\nAt this point, the SGLang engine part of the 1P1D (one prefill instance, one decode instance) deployment is complete.\n\nTo let users call the model API directly, we still need a load balancer that routes each request through prefill and then decode. Load balancer implementations vary between organizations, and the community also plans to officially release a new LB component written in Rust in the near future.\n\nHere, we use static Kubernetes Services together with minilb to serve model API calls.\n\n### Creating Services for Prefill and Decode\n\n#### Create prefill k8s service\n[p-svc.yaml](lws-examples/p-svc.yaml)\n```yaml\napiVersion: v1\nkind: Service\nmetadata:\n  name: deepseekr10528-prefill-main\nspec:\n  selector:\n    leaderworkerset.sigs.k8s.io/name: deepseekr10528-prefill-main\n    role: leader\n  ports:\n    - protocol: TCP\n      port: 30000\n      targetPort: 30000\n```\nExecute `kubectl apply -f p-svc.yaml`\n\n#### Create decode k8s service\n[d-svc.yaml](lws-examples/d-svc.yaml)\n```yaml\napiVersion: v1\nkind: Service\nmetadata:\n  name: deepseekr10528-decode-main\nspec:\n  selector:\n    leaderworkerset.sigs.k8s.io/name: deepseekr10528-decode-main\n    role: leader\n  ports:\n    - protocol: TCP\n      port: 30000\n      targetPort: 30000\n```\nExecute `kubectl apply -f d-svc.yaml`\n\n#### Deploy minilb and the LB service\n[lb.yaml](lws-examples/lb.yaml)\n```yaml\napiVersion: apps/v1\nkind: Deployment\nmetadata:\n  name: deepseekr10528-lb-main\n  labels:\n    app: deepseekr10528-lb\nspec:\n  replicas: 1\n  selector:\n    matchLabels:\n      app: deepseekr10528-lb\n  template:\n    metadata:\n      labels:\n        app: deepseekr10528-lb\n    spec:\n      nodeSelector:\n          pd: \"yes\"\n      tolerations:\n        - key: pd\n          operator: Exists\n        - key: node-role\n          operator: Exists\n      containers:\n        - name: sgl-minilb\n          image: lmsysorg/sglang:deepep\n          command:\n          - python\n          - -m\n          - sglang_router.launch_router\n          - --pd-disaggregation\n          - --prefill\n          - http://deepseekr10528-prefill-main:30000\n          - --decode\n          - http://deepseekr10528-decode-main:30000\n          - --host\n          - 0.0.0.0\n          - --port\n          -  \"8000\"\n          ports:\n            - containerPort: 8000\n---\napiVersion: v1\nkind: Service\nmetadata:\n  name: deepseekr10528-lb-service\nspec:\n  type: NodePort\n  selector:\n    app: deepseekr10528-lb\n  ports:\n    - protocol: TCP\n      port: 8000         # Service port (in-cluster)\n      targetPort: 8000   # Container port\n      nodePort: 30800\n```\nExecute `kubectl apply -f lb.yaml`\n\nOnce all deployments are ready, the pods and services should look similar to the following:\n\n```bash\n[root@ecs-001]# kubectl get po\ndeepseekr10528-decode-main-0             1/1     Running   0          74m\ndeepseekr10528-decode-main-0-1           1/1     Running   0          74m\ndeepseekr10528-lb-main-9c5dbfc57-6lcbd   1/1     Running   0          22m\ndeepseekr10528-prefill-main-0            1/1     Running   0          74m\ndeepseekr10528-prefill-main-0-1          1/1     Running   0          74m\n[root@ecs-cbm-x1-pd-cpu-001 main_doc]# kubectl  get svc |grep dee\ndeepseekr10528-decode-main    ClusterIP   None             <none>        <none>           97m\ndeepseekr10528-lb-service     NodePort    172.16.242.169   <none>        8000:30800/TCP   22m\ndeepseekr10528-prefill-main   ClusterIP   None             <none>        <none>           97m\n```\n\nYou can now send requests to NodePort 30800 on any cluster node:\n\n```bash\n[root@ecs-001]# curl -X POST \"http://{node_ip}:30800/v1/chat/completions\" \\\n>     -H \"Content-Type: application/json\" \\\n>     -H \"Authorization: Bearer None\" \\\n>     -d '{\n>        \"rid\":\"ccccdd\",\n>         \"model\": \"r1\",\n>         \"messages\": [\n>             {\"role\": \"system\", \"content\": \"0: You are a helpful AI assistant\"},\n>             {\"role\": \"user\", \"content\": \"你是谁？.\"}\n>         ],\n>         \"max_tokens\":221\n>     }'\n{\"id\":\"ccccdd\",\"object\":\"chat.completion\",\"created\":1750252498,\"model\":\"qwen2\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"<think>\\n嗯，用户问了一个很基础的自我介绍问题\"你是谁？\"。这可能是第一次互动时的常规开场白，也可能是想确认我的身份和功能范围。\\n\\n用户没有提供任何背景信息，语气简洁中性。这种场景下新用户的可能性较高，需要给出清晰友好的自我介绍，同时突出实用价值来降低陌生感。\\n\\n考虑到中文用户，应该用简体中文回复。重点要说明三点：身份归属（深度求索）、功能定位（AI助手）、服务范围（学习/工作/生活）。结尾用开放性问题引导对话很关键——既能了解需求，又能避免让用户面对空白输入框时不知所措。\\n\\n用波浪线结尾可以软化语气，那个笑脸表情😊刚好能中和AI的机械感。不过要控制表情符号数量，避免显得轻浮。\\n</think>\\n你好呀！我是你的AI助手，由深度求索公司（DeepSeek）开发的语言模型，名字叫 **DeepSeek-R1**。你可以把我当成一个知识丰富、随叫随到的小帮手～😊\\n\\n我的任务就是陪你聊天、解答问题、\",\"reasoning_content\":null,\"tool_calls\":null},\"logprobs\":null,\"finish_reason\":\"length\",\"matched_stop\":null}],\"usage\":{\"prompt_tokens\":14,\"total_tokens\":235,\"completion_tokens\":221,\"prompt_tokens_details\":null}}\n\n```\n\n## FAQ\n\n1. The startup parameters above may not work in every RDMA environment. Different network environments may require different RDMA- and NCCL-related environment variables.\n\n2. Some preset, optimized EPLB configurations are not used here. You can adjust them as needed by following issue [#6017](https://github.com/sgl-project/sglang/issues/6017).\n"
  },
  {
    "path": "docs/references/multi_node_deployment/multi_node.md",
    "content": "# Multi-Node Deployment\n\n## Llama 3.1 405B\n\n**Run 405B (fp16) on Two Nodes**\n\n```bash\n# Replace 172.16.4.52:20000 with the IP address and port of your first node.\n\n# On the first node (rank 0):\npython3 -m sglang.launch_server \\\n  --model-path meta-llama/Meta-Llama-3.1-405B-Instruct \\\n  --tp 16 \\\n  --dist-init-addr 172.16.4.52:20000 \\\n  --nnodes 2 \\\n  --node-rank 0\n\n# On the second node (rank 1):\npython3 -m sglang.launch_server \\\n  --model-path meta-llama/Meta-Llama-3.1-405B-Instruct \\\n  --tp 16 \\\n  --dist-init-addr 172.16.4.52:20000 \\\n  --nnodes 2 \\\n  --node-rank 1\n```\n\nNote that Llama 3.1 405B (FP8) can also be launched on a single node:\n\n```bash\npython -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8\n```\n\n## DeepSeek V3/R1\n\nPlease refer to the [DeepSeek documentation](https://docs.sglang.io/basic_usage/deepseek.html#running-examples-on-multi-node).\n\n## Multi-Node Inference on SLURM\n\nThis example shows how to serve an SGLang server across multiple nodes with SLURM. Submit the following job script to the SLURM cluster.\n\n```bash\n#!/bin/bash -l\n\n#SBATCH -o SLURM_Logs/%x_%j_master.out\n#SBATCH -e SLURM_Logs/%x_%j_master.err\n#SBATCH -D ./\n#SBATCH -J Llama-405B-Online-Inference-TP16-SGL\n\n#SBATCH --nodes=2\n#SBATCH --ntasks=2\n#SBATCH --ntasks-per-node=1  # Ensure 1 task per node\n#SBATCH --cpus-per-task=18\n#SBATCH --mem=224GB\n#SBATCH --partition=\"lmsys.org\"\n#SBATCH --gres=gpu:8\n#SBATCH --time=12:00:00\n\necho \"[INFO] Activating environment on node $SLURM_PROCID\"\nif! 
source ENV_FOLDER/bin/activate; then\n    echo \"[ERROR] Failed to activate environment\" >&2\n    exit 1\nfi\n\n# Define parameters\nmodel=MODEL_PATH\ntp_size=16\n\necho \"[INFO] Running inference\"\necho \"[INFO] Model: $model\"\necho \"[INFO] TP Size: $tp_size\"\n\n# Set NCCL initialization address using the hostname of the head node\nHEAD_NODE=$(scontrol show hostname \"$SLURM_NODELIST\" | head -n 1)\nNCCL_INIT_ADDR=\"${HEAD_NODE}:8000\"\necho \"[INFO] NCCL_INIT_ADDR: $NCCL_INIT_ADDR\"\n\n# Launch the model server on each node using SLURM.\n# The launch command is single-quoted so that $SLURM_NODEID is expanded inside\n# each task (0 on the first node, 1 on the second) instead of once by this batch\n# script on the head node. The variables it needs are exported first.\n# %n in the log file names is srun's per-node index within the job.\nexport model tp_size NCCL_INIT_ADDR\nsrun --ntasks=2 --nodes=2 --output=\"SLURM_Logs/%x_%j_node%n.out\" \\\n    --error=\"SLURM_Logs/%x_%j_node%n.err\" \\\n    bash -c 'python3 -m sglang.launch_server \\\n    --model-path \"$model\" \\\n    --grammar-backend \"xgrammar\" \\\n    --tp \"$tp_size\" \\\n    --dist-init-addr \"$NCCL_INIT_ADDR\" \\\n    --nnodes 2 \\\n    --node-rank \"$SLURM_NODEID\"' &\n\n# Wait until the SGLang server on the head node accepts connections on port 30000 (the default --port)\nwhile ! nc -z \"$HEAD_NODE\" 30000; do\n    sleep 1\n    echo \"[INFO] Waiting for $HEAD_NODE:30000 to accept connections\"\ndone\n\necho \"[INFO] $HEAD_NODE:30000 is ready to accept connections\"\n\n# Keep the script running until the SLURM job times out\nwait\n```\n\nYou can then test the server by sending requests as described in the [OpenAI-compatible API documentation](https://docs.sglang.io/basic_usage/openai_api_completions.html).\n\nThanks to [aflah02](https://github.com/aflah02) for providing this example, based on his [blog post](https://aflah02.substack.com/p/multi-node-llm-inference-with-sglang).\n"
  },
  {
    "path": "docs/references/multi_node_deployment/multi_node_index.rst",
    "content": "Multi-Node Deployment\n=====================\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Multi-Node Deployment\n\n   multi_node.md\n   deploy_on_k8s.md\n   lws_pd/lws_pd_deploy.md\n   rbg_pd/deepseekv32_pd.md\n\n- `Deploying DeepSeek with PD Disaggregation and Large-Scale Expert Parallelism on 96 H100 GPUs <https://lmsys.org/blog/2025-05-05-large-scale-ep/>`_\n- `Deploying Kimi K2 with PD Disaggregation and Large-Scale Expert Parallelism on 128 H200 GPUs <https://lmsys.org/blog/2025-07-20-k2-large-scale-ep/>`_\n"
  },
  {
    "path": "docs/references/multi_node_deployment/rbg_pd/deepseekv32_pd.md",
    "content": "# DeepSeek-V3.2-Exp PD Deployment Based on RBG\n\n## 0. Prerequisites\n\n1. Kubernetes >= 1.26.\n2. LWS (LeaderWorkerSet) installed on the cluster.\n3. RBG (RoleBasedGroup) installed on the cluster.\n\nFor RBG installation, please refer to: https://github.com/sgl-project/rbg\n\n## 1. Image Preparation\n\n`lmsysorg/sglang:latest`\n\n## 2. All-in-One Manifest File\n\n*Note: The NodeSelector section, model location section, and taint toleration section can be adjusted according to your actual deployment environment.*\n\nSave the following manifest as `rbg-dsv32.yml`:\n\n```yaml\napiVersion: workloads.x-k8s.io/v1alpha1\nkind: RoleBasedGroup\nmetadata:\n  name: deepseek-rbg-32exp\n  namespace: default\nspec:\n  roles:\n    - name: prefill\n      replicas: 1\n      workload:\n        apiVersion: leaderworkerset.x-k8s.io/v1\n        kind: LeaderWorkerSet\n      restartPolicy: None\n      leaderWorkerSet:\n        size: 1\n        patchLeaderTemplate:\n          metadata:\n            labels:\n              role: leader\n              pd_role: prefill\n          spec:\n            containers:\n            - command:\n              - python3\n              - -m\n              - sglang.launch_server\n              - --model-path\n              - /work/models\n              - --port\n              - \"30000\"\n              - --trust-remote\n              - --host\n              -  0.0.0.0\n              - --disaggregation-ib-device\n              -  mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7\n              - --disable-radix-cache\n              - --chunked-prefill-size\n              - \"131072\"\n              - --page-size\n              - \"64\"\n    #          - --enable-eplb\n              - --ep-dispatch-algorithm\n              - dynamic\n              - --eplb-algorithm\n              - deepseek\n              - --enable-dp-lm-head\n              - --enable-dp-attention\n              - --dp-size\n              - \"8\"\n              - --moe-a2a-backend\n              - deepep\n              - --deepep-mode\n              - normal\n      
        - --disaggregation-mode\n              - prefill\n              - --mem-fraction-static\n              - \"0.8\"\n              - --max-prefill-tokens\n              - \"32768\"\n              - --context-length\n              - \"32768\"\n              - --tp\n              - \"8\"\n              - --dist-init-addr\n              - $(LWS_LEADER_ADDRESS):20102\n              - --nnodes\n              - $(LWS_GROUP_SIZE)\n              - --node-rank\n              - $(LWS_WORKER_INDEX)\n              - --trust-remote-code\n              - --ep-num-redundant-experts\n              - \"32\"\n              - --moe-dense-tp-size\n              - \"1\"\n              - --max-running-requests\n              - \"1024\"\n              env:\n              - name: LWS_WORKER_INDEX\n                valueFrom:\n                  fieldRef:\n                    fieldPath: metadata.labels['leaderworkerset.sigs.k8s.io/worker-index']\n              livenessProbe:\n                failureThreshold: 3000\n                httpGet:\n                  path: /health\n                  port: 30000\n                initialDelaySeconds: 300\n                periodSeconds: 60\n                successThreshold: 1\n                timeoutSeconds: 10\n              readinessProbe:\n                failureThreshold: 20\n                httpGet:\n                  path: /health\n                  port: 30000\n                periodSeconds: 30\n                successThreshold: 1\n                timeoutSeconds: 10\n              name: sglang\n              ports:\n              - containerPort: 30000\n                name: sglang-http\n                protocol: TCP\n\n        patchWorkerTemplate: {}\n      template:\n        metadata:\n          labels:\n            inference-framework: sglang\n            inference-stack.io/monitoring: \"enabled\"\n        spec:\n            containers:\n            - name: sglang\n              image: lmsysorg/sglang:latest\n              env:\n          
      - name: SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK\n                  value: \"1\"\n                - name: CUDA_LAUNCH_BLOCKING\n                  value: \"0\"\n                - name:  SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT\n                  value: \"1000000000\"\n                - name: NVSHMEM_IB_TRAFFIC_CLASS\n                  value: \"16\"\n                - name: NVSHMEM_DISABLE_P2P\n                  value: \"0\"\n                - name: ENABLE_METRICS\n                  value: \"true\"\n                - name: NVSHMEM_IB_GID_INDEX\n                  value: \"3\"\n                - name: NVSHMEM_IB_SL\n                  value: \"5\"\n                - name: SGLANG_SET_CPU_AFFINITY\n                  value: \"true\"\n                - name: SGL_ENABLE_JIT_DEEPGEMM\n                  value: \"1\"\n                - name:  NCCL_IB_QPS_PER_CONNECTION\n                  value: \"8\"\n                - name: NCCL_IB_SPLIT_DATA_ON_QPS\n                  value: \"1\"\n                - name: NCCL_NET_PLUGIN\n                  value: \"none\"\n                - name: NCCL_IB_TC\n                  value: \"136\"\n                - name: NCCL_IB_SL\n                  value: \"5\"\n                - name: NCCL_IB_TIMEOUT\n                  value: \"22\"\n                - name: NCCL_IB_GID_INDEX\n                  value: \"3\"\n                - name: NCCL_MIN_NCHANNELS\n                  value: \"4\"\n                - name: NCCL_SOCKET_IFNAME\n                  value: bond0\n                - name: GLOO_SOCKET_IFNAME\n                  value: bond0\n                - name: NCCL_IB_HCA\n                  value: ^=mlx5_0,mlx5_5,mlx5_6\n                - name: NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME\n                  value: \"bond0\"\n                - name: MC_TE_METRIC\n                  value: \"false\"\n              resources:\n                limits:\n                  nvidia.com/gpu: \"8\"\n              securityContext:\n                capabilities:\n             
     add:\n                  - IPC_LOCK\n                privileged: true\n              volumeMounts:\n                - mountPath: /root/.cache\n                  name: sgl-cache\n                - mountPath: /dev/shm\n                  name: dshm\n                - mountPath: /work/models\n                  name: model\n                - mountPath: /dev/infiniband\n                  name: ib\n                - mountPath: /sgl-workspace/sglang\n                  name: src\n\n            dnsPolicy: ClusterFirstWithHostNet\n            hostIPC: true\n            hostNetwork: true\n            nodeSelector:\n              pd: \"yes\"\n            tolerations:\n              - key: pd\n                operator: Exists\n            volumes:\n            - hostPath:\n                path: /var/run/sys-topology\n              name: topo\n            - hostPath:\n                path: /data1/sgl_cache4\n                type: DirectoryOrCreate\n              name: sgl-cache\n            - emptyDir:\n                medium: Memory\n              name: dshm\n            - hostPath:\n                path: /data/DeepSeek-V3.2-Exp\n              name: model\n            - hostPath:\n                path: /dev/infiniband\n              name: ib\n            - hostPath:\n                path: /data/src/sglang\n                type: DirectoryOrCreate\n              name: src\n\n    - name: decode\n      replicas: 1\n      workload:\n        apiVersion: leaderworkerset.x-k8s.io/v1\n        kind: LeaderWorkerSet\n      leaderWorkerSet:\n        size: 1\n        patchLeaderTemplate:\n          metadata:\n            labels:\n              role: leader\n              pd_role: decode\n          spec:\n            containers:\n            - command:\n                  - python3\n                  - -m\n                  - sglang.launch_server\n                  - --model-path\n                  - /work/models\n                  - --port\n                  - \"30000\"\n                  
- --trust-remote\n                  - --host\n                  -  0.0.0.0\n                  - --disaggregation-ib-device\n                  -  mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7\n                  - --chunked-prefill-size\n                  - \"131072\"\n                  - --eplb-rebalance-layers-per-chunk\n                  - \"29\"\n                  - --page-size\n                  - \"64\"\n                  - --enable-dp-attention\n                  - --enable-dp-lm-head\n                  - --dp-size\n                  - \"8\"\n                  - --moe-a2a-backend\n                  - deepep\n                  - --deepep-mode\n                  - low_latency\n                  - --disaggregation-mode\n                  - decode\n                  - --mem-fraction-static\n                  -  \"0.8\"\n                  - --context-length\n                  - \"32768\"\n                  - --max-running-requests\n                  - \"2048\"\n                  - --tp-size\n                  - \"8\" # Size of Tensor Parallelism\n                  - --cuda-graph-max-bs\n                  - \"16\"\n                  - --dist-init-addr\n                  - $(LWS_LEADER_ADDRESS):20102\n                  - --nnodes\n                  - $(LWS_GROUP_SIZE)\n                  - --node-rank\n                  - $(LWS_WORKER_INDEX)\n                  - --trust-remote-code\n                  - --ep-num-redundant-experts\n                  - \"32\"\n                  - --moe-dense-tp-size\n                  - \"1\"\n              env:\n              - name: LWS_WORKER_INDEX\n                valueFrom:\n                  fieldRef:\n                    fieldPath: metadata.labels['leaderworkerset.sigs.k8s.io/worker-index']\n              livenessProbe:\n                failureThreshold: 30000\n                httpGet:\n                  path: /health\n                  port: 30000\n                initialDelaySeconds: 300\n                periodSeconds: 
60\n                successThreshold: 1\n                timeoutSeconds: 10\n              name: sglang\n              readinessProbe:\n                failureThreshold: 20\n                httpGet:\n                  path: /health\n                  port: 30000\n                periodSeconds: 30\n                successThreshold: 1\n                timeoutSeconds: 10\n        patchWorkerTemplate:\n          spec:\n            containers:\n            - command:\n                - python3\n                - -m\n                - sglang.launch_server\n                - --model-path\n                - /work/models\n                - --crash-dump-folder\n                -  /log\n                - --chunked-prefill-size\n                - \"262144\"\n                - --eplb-rebalance-layers-per-chunk\n                - \"29\"\n                - --page-size\n                - \"64\"\n                - --enable-dp-attention\n                - --enable-dp-lm-head\n                - --dp-size\n                - \"32\"\n                - --moe-a2a-backend\n                - \"deepep\"\n                - --deepep-mode\n                - low_latency\n                - --disaggregation-mode\n                - decode\n                - --mem-fraction-static\n                -  \"0.849\"\n                - --context-length\n                - \"32768\"\n                - --disaggregation-ib-device\n                -  mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7\n                - --max-running-requests\n                - \"4096\"\n                - --cuda-graph-max-bs\n                - \"16\"\n                - --tp-size\n                - \"8\" # Size of Tensor Parallelism\n                - --dist-init-addr\n                - $(LWS_LEADER_ADDRESS):20102\n                - --nnodes\n                - $(LWS_GROUP_SIZE)\n                - --node-rank\n                - $(LWS_WORKER_INDEX)\n                - --trust-remote-code\n                - 
--ep-num-redundant-experts\n                - \"32\"\n                - --moe-dense-tp-size\n                - \"1\"\n              env:\n              - name: LWS_WORKER_INDEX\n                valueFrom:\n                  fieldRef:\n                    fieldPath: metadata.labels['leaderworkerset.sigs.k8s.io/worker-index']\n              name: sglang\n      template:\n        metadata:\n          labels:\n            inference-framework: sglang-unuse\n            inference-stack.io/monitoring: \"enabled\"\n        spec:\n            containers:\n            - image: lmsysorg/sglang:latest\n              name: sglang\n              resources:\n                limits:\n                  nvidia.com/gpu: \"8\"\n              securityContext:\n                capabilities:\n                  add:\n                  - IPC_LOCK\n                privileged: true\n              volumeMounts:\n                - mountPath: /root/.cache\n                  name: sgl-cache\n                - mountPath: /dev/shm\n                  name: dshm\n                - mountPath: /work/models\n                  name: model\n                - mountPath: /dev/infiniband\n                  name: ib\n                - mountPath: /sgl-workspace/sglang\n                  name: src\n              env:\n                - name: SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK\n                  value: \"1\"\n                - name: SGLANG_DISAGGREGATION_WAITING_TIMEOUT\n                  value: \"100000000\"\n                - name: NVSHMEM_DISABLE_P2P\n                  value: \"0\"\n                - name: NVSHMEM_IB_TRAFFIC_CLASS\n                  value: \"16\"\n                - name: NVSHMEM_IB_SL\n                  value: \"5\"\n                - name: ENABLE_METRICS\n                  value: \"true\"\n                - name: CUDA_LAUNCH_BLOCKING\n                  value: \"0\"\n                - name: NVSHMEM_IB_GID_INDEX\n                  value: \"3\"\n                - name:  
NCCL_IB_QPS_PER_CONNECTION\n                  value: \"8\"\n                - name: NCCL_IB_SPLIT_DATA_ON_QPS\n                  value: \"1\"\n                - name: NCCL_NET_PLUGIN\n                  value: \"none\"\n                - name: NCCL_IB_TC\n                  value: \"136\"\n                - name: NCCL_IB_SL\n                  value: \"5\"\n                - name: NCCL_IB_TIMEOUT\n                  value: \"22\"\n                - name: NCCL_IB_GID_INDEX\n                  value: \"3\"\n                - name: NCCL_MIN_NCHANNELS\n                  value: \"4\"\n                - name: NCCL_SOCKET_IFNAME\n                  value: bond0\n                - name: GLOO_SOCKET_IFNAME\n                  value: bond0\n                - name: NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME\n                  value: \"bond0\"\n                - name: NCCL_IB_HCA\n                  value: ^=mlx5_0,mlx5_5,mlx5_6\n                - name: MC_TE_METRIC\n                  value: \"false\"\n                - name: SGL_ENABLE_JIT_DEEPGEMM\n                  value: \"1\"\n            dnsPolicy: ClusterFirstWithHostNet\n            hostIPC: true\n            hostNetwork: true\n            nodeSelector:\n              pd: \"yes\"\n            tolerations:\n            - key: pd\n              operator: Exists\n            volumes:\n            - hostPath:\n                path: /var/run/sys-topology\n              name: topo\n            - hostPath:\n                path: /data1/sgl_cache4\n                type: DirectoryOrCreate\n              name: sgl-cache\n            - hostPath:\n                path: /data/src/sglang\n                type: DirectoryOrCreate\n              name: src\n            - emptyDir:\n                medium: Memory\n              name: dshm\n            - hostPath:\n                path: /data/DeepSeek-V3.2-Exp\n              name: model\n            - hostPath:\n                path: /dev/infiniband\n              name: ib\n    - name: router\n      
replicas: 1\n      dependencies: [ \"decode\", \"prefill\" ]\n      template:\n        spec:\n          containers:\n            - name: scheduler\n              image: lmsysorg/sglang:latest\n              command:\n              - sh\n              - -c\n              - >\n                python3 -m sglang_router.launch_router\n                --host 0.0.0.0\n                --port 8080\n                --pd-disaggregation\n                --policy random\n                --service-discovery\n                --service-discovery-namespace ${NAMESPACE}\n                --service-discovery-port 30000\n                --prefill-selector pd_role=prefill\n                --decode-selector pd_role=decode\n                --max-payload-size 2147483648\n                --worker-startup-timeout-secs 1200\n              env:\n              - name: NAMESPACE\n                valueFrom:\n                  fieldRef:\n                    apiVersion: v1\n                    fieldPath: metadata.namespace\n---\napiVersion: v1\nkind: Service\nmetadata:\n  labels:\n    app: deepseek-rbg-32exp\n  name: deepseek-rbg-32exp\n  namespace: default\nspec:\n  ports:\n    - name: http\n      port: 8080\n      protocol: TCP\n      targetPort: 8080\n      nodePort: 30080\n\n  selector:\n    rolebasedgroup.workloads.x-k8s.io/name: deepseek-rbg-32exp\n    rolebasedgroup.workloads.x-k8s.io/role: router\n  type: NodePort\n\n```\n\n```bash\n[root@ecs-001]# kubectl get po -n default\ndeepseek-rbg-32exp-decode-main-0             1/1     Running   0          74m\ndeepseek-rbg-32exp-decode-0-1                1/1     Running   0          74m\ndeepseek-rbg-32exp-router-9c5dbfc57          1/1     Running   0          22m\ndeepseek-rbg-32exp-prefill-0                 1/1     Running   0          74m\n\n[root@ecs-cbm-x1-pd-cpu-001 main_doc]# kubectl  get svc |grep dee\ndeepseek-rbg-32exp-decode             ClusterIP   None             <none>        <none>           97m\ndeepseek-rbg-32exp-router-service     
NodePort    172.16.242.169   <none>        8000:30800/TCP   22m\ndeepseek-rbg-32exp-prefill            ClusterIP   None             <none>        <none>           97m\n```\n\nYou can now access the router through NodePort `30800`:\n\n```bash\n[root@ecs-001]# curl -X POST \"http://{node_ip}:30800/v1/chat/completions\" \\\n>     -H \"Content-Type: application/json\" \\\n>     -H \"Authorization: Bearer None\" \\\n>     -d '{\n>        \"rid\":\"ccccdd\",\n>         \"model\": \"dsv32\",\n>         \"messages\": [\n>             {\"role\": \"system\", \"content\": \"0: You are a helpful AI assistant\"},\n>             {\"role\": \"user\", \"content\": \"你是谁？.\"}\n>         ],\n>         \"max_tokens\":221\n>     }'\n{\"id\":\"ccccdd\",\"object\":\"chat.completion\",\"created\":1750252498,\"model\":\"qwen2\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"<think>\\n嗯，用户问了一个很基础的自我介绍问题\"你是谁？\"。这可能是第一次互动时的常规开场白，也可能是想确认我的身份和功能范围。\\n\\n用户没有提供任何背景信息，语气简洁中性。这种场景下新用户的可能性较高，需要给出清晰友好的自我介绍，同时突出实用价值来降低陌生感。\\n\\n考虑到中文用户，应该用简体中文回复。重点要说明三点：身份归属（深度求索）、功能定位（AI助手）、服务范围（学习/工作/生活）。结尾用开放性问题引导对话很关键——既能了解需求，又能避免让用户面对空白输入框时不知所措。\\n\\n用波浪线结尾可以软化语气，那个笑脸表情😊刚好能中和AI的机械感。不过要控制表情符号数量，避免显得轻浮。\\n</think>\\n你好呀！我是你的AI助手，由深度求索公司（DeepSeek）开发的语言模型，名字叫 **DeepSeek-V32**。你可以把我当成一个知识丰富、随叫随到的小帮手～😊\\n\\n我的任务就是陪你聊天、解答问题、\",\"reasoning_content\":null,\"tool_calls\":null},\"logprobs\":null,\"finish_reason\":\"length\",\"matched_stop\":null}],\"usage\":{\"prompt_tokens\":14,\"total_tokens\":235,\"completion_tokens\":221,\"prompt_tokens_details\":null}}\n\n```\n\n## FAQ\n\n1. The current deployment startup parameters may not be compatible with every RDMA setup. Different network environments may require different NCCL-related RDMA environment variables.\n\n2. Make sure the SGLang code in the image includes the changes from [PR #10912](https://github.com/sgl-project/sglang/pull/10912).\n"
  },
  {
    "path": "docs/references/post_training_integration.md",
    "content": "# Post-Training Integration\n\nSGLang has become the de facto inference backend for modern LLM training frameworks, powering state-of-the-art models across the industry. From GLM-4.6 to Qwen3, leading models leverage SGLang's high-performance inference during reinforcement learning and post-training workflows.\n\nWhat makes SGLang essential for post-training?\n\n- Open-To-Use Refit Functionality: diverse weight-refit methods for both colocated and disaggregated setups\n- Easy To Postpone Generation: enables partial rollout and dedicated rollout control\n- Fine-Grained Engine Sleep And Wake Up: lets rollout and training each make full use of GPU resources\n- Training Serving Alignment: ensures consistent performance between training and serving\n- Load Balancing Router: cache-aware load balancing for high-throughput rollout\n- Deterministic Inference: ensures zero KL divergence between rollout and training\n\nThese capabilities, combined with native integration support across major frameworks, have established SGLang as the infrastructure backbone for modern LLM/VLM post-training. 
We also share our latest work in this slide deck: [Optimizing Large-Scale RL with SGLang](https://gamma.app/docs/Optimizing-RL-with-SGLang-y0kqgj877k34779).\n\n## Adoption\n\n- [**Miles**](https://github.com/radixark/miles): Enterprise-scale RL framework for large MoE models with SGLang-native rollout, speculative training, and production-grade stability\n- [**slime**](https://github.com/THUDM/slime): Post-training framework combining Megatron and SGLang, used to train GLM-4.6\n- [**AReaL**](https://github.com/inclusionAI/AReaL): Fully asynchronous RL system achieving 2.77x speedup with SGLang backend for continuous rollout generation\n- [**ROLL**](https://github.com/alibaba/ROLL): Efficient and user-friendly RL library for large language models running on large-scale GPU resources\n- [**verl**](https://github.com/volcengine/verl): Full-stack RLHF framework supporting PPO, GRPO, and ReMax with modular SGLang integration\n- [**Unsloth**](https://docs.unsloth.ai/basics/inference-and-deployment/sglang-guide): 2x faster fine-tuning with optimized kernels; deploys seamlessly with SGLang inference\n- [**LLaMA Factory**](https://github.com/hiyouga/LLaMA-Factory): Unified framework for training 100+ LLMs with LoRA, QLoRA, and full fine-tuning methods\n- [**Tunix**](https://github.com/google/tunix): Google's JAX-native library for LLM post-training with SFT, DPO, PPO, and GRPO support\n- [**RL2**](https://github.com/ChenmienTan/RL2): Ray-less reinforcement learning; a concise post-training library for large language models\n\n## Collaboration\n\nTo protect the privacy of our design partners, we cannot publicly list the companies that use SGLang for post-training. SGLang is trusted by more than 10 top companies and frontier labs across the US and China, and we are happy to share details if you are interested. If you are interested in integrating SGLang with your training framework or need technical support, we're here to help! 
Reach out to us at **rl_team@lmsys.org** for partnerships, integration guidance, and custom feature development.\n"
  },
  {
    "path": "docs/references/production_metrics.md",
    "content": "# Production Metrics\n\nSGLang exposes the following metrics via Prometheus. You can enable them by adding `--enable-metrics` when you launch the server.\n\nAn example monitoring dashboard is available at [examples/monitoring/grafana/dashboards/json/sglang-dashboard.json](https://github.com/sgl-project/sglang/blob/main/examples/monitoring/grafana/dashboards/json/sglang-dashboard.json).\n\nHere is an example of the metrics:\n\n```\n$ curl http://localhost:30000/metrics\n# HELP sglang:prompt_tokens_total Number of prefill tokens processed.\n# TYPE sglang:prompt_tokens_total counter\nsglang:prompt_tokens_total{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 8.128902e+06\n# HELP sglang:generation_tokens_total Number of generation tokens processed.\n# TYPE sglang:generation_tokens_total counter\nsglang:generation_tokens_total{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 7.557572e+06\n# HELP sglang:token_usage The token usage\n# TYPE sglang:token_usage gauge\nsglang:token_usage{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 0.28\n# HELP sglang:cache_hit_rate The cache hit rate\n# TYPE sglang:cache_hit_rate gauge\nsglang:cache_hit_rate{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 0.007507552643049313\n# HELP sglang:time_to_first_token_seconds Histogram of time to first token in seconds.\n# TYPE sglang:time_to_first_token_seconds histogram\nsglang:time_to_first_token_seconds_sum{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 2.3518979474117756e+06\nsglang:time_to_first_token_seconds_bucket{le=\"0.001\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 0.0\nsglang:time_to_first_token_seconds_bucket{le=\"0.005\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 0.0\nsglang:time_to_first_token_seconds_bucket{le=\"0.01\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 0.0\nsglang:time_to_first_token_seconds_bucket{le=\"0.02\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 
0.0\nsglang:time_to_first_token_seconds_bucket{le=\"0.04\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 1.0\nsglang:time_to_first_token_seconds_bucket{le=\"0.06\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 3.0\nsglang:time_to_first_token_seconds_bucket{le=\"0.08\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 6.0\nsglang:time_to_first_token_seconds_bucket{le=\"0.1\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 6.0\nsglang:time_to_first_token_seconds_bucket{le=\"0.25\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 6.0\nsglang:time_to_first_token_seconds_bucket{le=\"0.5\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 6.0\nsglang:time_to_first_token_seconds_bucket{le=\"0.75\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 6.0\nsglang:time_to_first_token_seconds_bucket{le=\"1.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 27.0\nsglang:time_to_first_token_seconds_bucket{le=\"2.5\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 140.0\nsglang:time_to_first_token_seconds_bucket{le=\"5.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 314.0\nsglang:time_to_first_token_seconds_bucket{le=\"7.5\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 941.0\nsglang:time_to_first_token_seconds_bucket{le=\"10.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 1330.0\nsglang:time_to_first_token_seconds_bucket{le=\"15.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 1970.0\nsglang:time_to_first_token_seconds_bucket{le=\"20.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 2326.0\nsglang:time_to_first_token_seconds_bucket{le=\"25.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 2417.0\nsglang:time_to_first_token_seconds_bucket{le=\"30.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 2513.0\nsglang:time_to_first_token_seconds_bucket{le=\"+Inf\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 11008.0\nsglang:time_to_first_token_seconds_count{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 11008.0\n# HELP 
sglang:e2e_request_latency_seconds Histogram of End-to-end request latency in seconds\n# TYPE sglang:e2e_request_latency_seconds histogram\nsglang:e2e_request_latency_seconds_sum{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 3.116093850019932e+06\nsglang:e2e_request_latency_seconds_bucket{le=\"0.3\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 0.0\nsglang:e2e_request_latency_seconds_bucket{le=\"0.5\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 6.0\nsglang:e2e_request_latency_seconds_bucket{le=\"0.8\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 6.0\nsglang:e2e_request_latency_seconds_bucket{le=\"1.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 6.0\nsglang:e2e_request_latency_seconds_bucket{le=\"1.5\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 6.0\nsglang:e2e_request_latency_seconds_bucket{le=\"2.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 6.0\nsglang:e2e_request_latency_seconds_bucket{le=\"2.5\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 6.0\nsglang:e2e_request_latency_seconds_bucket{le=\"5.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 7.0\nsglang:e2e_request_latency_seconds_bucket{le=\"10.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 10.0\nsglang:e2e_request_latency_seconds_bucket{le=\"15.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 11.0\nsglang:e2e_request_latency_seconds_bucket{le=\"20.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 14.0\nsglang:e2e_request_latency_seconds_bucket{le=\"30.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 247.0\nsglang:e2e_request_latency_seconds_bucket{le=\"40.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 486.0\nsglang:e2e_request_latency_seconds_bucket{le=\"50.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 845.0\nsglang:e2e_request_latency_seconds_bucket{le=\"60.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 1513.0\nsglang:e2e_request_latency_seconds_bucket{le=\"+Inf\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 
11228.0\nsglang:e2e_request_latency_seconds_count{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 11228.0\n# HELP sglang:time_per_output_token_seconds Histogram of time per output token in seconds.\n# TYPE sglang:time_per_output_token_seconds histogram\nsglang:time_per_output_token_seconds_sum{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 866964.5791549598\nsglang:time_per_output_token_seconds_bucket{le=\"0.005\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 1.0\nsglang:time_per_output_token_seconds_bucket{le=\"0.01\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 73.0\nsglang:time_per_output_token_seconds_bucket{le=\"0.015\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 382.0\nsglang:time_per_output_token_seconds_bucket{le=\"0.02\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 593.0\nsglang:time_per_output_token_seconds_bucket{le=\"0.025\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 855.0\nsglang:time_per_output_token_seconds_bucket{le=\"0.03\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 1035.0\nsglang:time_per_output_token_seconds_bucket{le=\"0.04\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 1815.0\nsglang:time_per_output_token_seconds_bucket{le=\"0.05\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 11685.0\nsglang:time_per_output_token_seconds_bucket{le=\"0.075\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 433413.0\nsglang:time_per_output_token_seconds_bucket{le=\"0.1\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 4.950195e+06\nsglang:time_per_output_token_seconds_bucket{le=\"0.15\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 7.039435e+06\nsglang:time_per_output_token_seconds_bucket{le=\"0.2\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 7.171662e+06\nsglang:time_per_output_token_seconds_bucket{le=\"0.3\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 7.266055e+06\nsglang:time_per_output_token_seconds_bucket{le=\"0.4\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 
7.296752e+06\nsglang:time_per_output_token_seconds_bucket{le=\"0.5\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 7.312226e+06\nsglang:time_per_output_token_seconds_bucket{le=\"0.75\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 7.339675e+06\nsglang:time_per_output_token_seconds_bucket{le=\"1.0\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 7.357747e+06\nsglang:time_per_output_token_seconds_bucket{le=\"2.5\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 7.389414e+06\nsglang:time_per_output_token_seconds_bucket{le=\"+Inf\",model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 7.400757e+06\nsglang:time_per_output_token_seconds_count{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 7.400757e+06\n# HELP sglang:func_latency_seconds Function latency in seconds\n# TYPE sglang:func_latency_seconds histogram\nsglang:func_latency_seconds_sum{name=\"generate_request\"} 4.514771912145079\nsglang:func_latency_seconds_bucket{le=\"0.05\",name=\"generate_request\"} 14006.0\nsglang:func_latency_seconds_bucket{le=\"0.07500000000000001\",name=\"generate_request\"} 14006.0\nsglang:func_latency_seconds_bucket{le=\"0.1125\",name=\"generate_request\"} 14006.0\nsglang:func_latency_seconds_bucket{le=\"0.16875\",name=\"generate_request\"} 14006.0\nsglang:func_latency_seconds_bucket{le=\"0.253125\",name=\"generate_request\"} 14006.0\nsglang:func_latency_seconds_bucket{le=\"0.3796875\",name=\"generate_request\"} 14006.0\nsglang:func_latency_seconds_bucket{le=\"0.56953125\",name=\"generate_request\"} 14006.0\nsglang:func_latency_seconds_bucket{le=\"0.8542968750000001\",name=\"generate_request\"} 14006.0\nsglang:func_latency_seconds_bucket{le=\"1.2814453125\",name=\"generate_request\"} 14006.0\nsglang:func_latency_seconds_bucket{le=\"1.9221679687500002\",name=\"generate_request\"} 14006.0\nsglang:func_latency_seconds_bucket{le=\"2.8832519531250003\",name=\"generate_request\"} 14006.0\nsglang:func_latency_seconds_bucket{le=\"4.3248779296875\",name=\"generate_request\"} 
14007.0\nsglang:func_latency_seconds_bucket{le=\"6.487316894531251\",name=\"generate_request\"} 14007.0\nsglang:func_latency_seconds_bucket{le=\"9.730975341796876\",name=\"generate_request\"} 14007.0\nsglang:func_latency_seconds_bucket{le=\"14.596463012695313\",name=\"generate_request\"} 14007.0\nsglang:func_latency_seconds_bucket{le=\"21.89469451904297\",name=\"generate_request\"} 14007.0\nsglang:func_latency_seconds_bucket{le=\"32.84204177856446\",name=\"generate_request\"} 14007.0\nsglang:func_latency_seconds_bucket{le=\"49.26306266784668\",name=\"generate_request\"} 14007.0\nsglang:func_latency_seconds_bucket{le=\"+Inf\",name=\"generate_request\"} 14007.0\nsglang:func_latency_seconds_count{name=\"generate_request\"} 14007.0\n# HELP sglang:num_running_reqs The number of running requests\n# TYPE sglang:num_running_reqs gauge\nsglang:num_running_reqs{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 162.0\n# HELP sglang:num_used_tokens The number of used tokens\n# TYPE sglang:num_used_tokens gauge\nsglang:num_used_tokens{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 123859.0\n# HELP sglang:gen_throughput The generate throughput (token/s)\n# TYPE sglang:gen_throughput gauge\nsglang:gen_throughput{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 86.50814177726902\n# HELP sglang:num_queue_reqs The number of requests in the waiting queue\n# TYPE sglang:num_queue_reqs gauge\nsglang:num_queue_reqs{model_name=\"meta-llama/Llama-3.1-8B-Instruct\"} 2826.0\n```\n\n## Setup Guide\n\nThis section describes how to set up the monitoring stack (Prometheus + Grafana) provided in the `examples/monitoring` directory.\n\n### Prerequisites\n\n- Docker and Docker Compose installed\n- SGLang server running with metrics enabled\n\n### Usage\n\n1.  
**Start your SGLang server with metrics enabled:**\n\n    ```bash\n    python -m sglang.launch_server \\\n      --model-path <your_model_path> \\\n      --port 30000 \\\n      --enable-metrics\n    ```\n    Replace `<your_model_path>` with the actual path to your model (e.g., `meta-llama/Meta-Llama-3.1-8B-Instruct`). Ensure the server is accessible from the monitoring stack (you might need `--host 0.0.0.0` if running in Docker). By default, the metrics endpoint will be available at `http://<sglang_server_host>:30000/metrics`.\n\n2.  **Navigate to the monitoring example directory:**\n    ```bash\n    cd examples/monitoring\n    ```\n\n3.  **Start the monitoring stack:**\n    ```bash\n    docker compose up -d\n    ```\n    This command will start Prometheus and Grafana in the background.\n\n4.  **Access the monitoring interfaces:**\n    *   **Grafana:** Open your web browser and go to [http://localhost:3000](http://localhost:3000).\n    *   **Prometheus:** Open your web browser and go to [http://localhost:9090](http://localhost:9090).\n\n5.  **Log in to Grafana:**\n    *   Default Username: `admin`\n    *   Default Password: `admin`\n    You will be prompted to change the password upon your first login.\n\n6.  **View the Dashboard:**\n    The SGLang dashboard is pre-configured and should be available automatically. Navigate to `Dashboards` -> `Browse` -> `SGLang Monitoring` folder -> `SGLang Dashboard`.\n\n### Troubleshooting\n\n*   **Port Conflicts:** If you encounter errors like \"port is already allocated,\" check if other services (including previous instances of Prometheus/Grafana) are using ports `9090` or `3000`. Use `docker ps` to find running containers and `docker stop <container_id>` to stop them, or use `lsof -i :<port>` to find other processes using the ports. 
You might need to adjust the ports in the `docker-compose.yaml` file if they permanently conflict with other essential services on your system.\n\n    To change Grafana's port (for example, to 3090) in your Docker Compose file, configure it under the `grafana` service using one of the following options.\n\n    Option 1: Add `GF_SERVER_HTTP_PORT` to the `environment` section:\n    ```yaml\n    environment:\n      - GF_AUTH_ANONYMOUS_ENABLED=true\n      - GF_SERVER_HTTP_PORT=3090  # <-- Add this line\n    ```\n\n    Option 2: Use a port mapping:\n    ```yaml\n    grafana:\n      image: grafana/grafana:latest\n      container_name: grafana\n      ports:\n      - \"3090:3000\"  # <-- Host:Container port mapping\n    ```\n*   **Connection Issues:**\n    *   Ensure both Prometheus and Grafana containers are running (`docker ps`).\n    *   Verify the Prometheus data source configuration in Grafana (usually auto-configured via `grafana/datasources/datasource.yaml`). Go to `Connections` -> `Data sources` -> `Prometheus`. The URL should point to the Prometheus service (e.g., `http://prometheus:9090`).\n    *   Confirm that your SGLang server is running and the metrics endpoint (`http://<sglang_server_host>:30000/metrics`) is accessible *from the Prometheus container*. If SGLang is running on your host machine and Prometheus is in Docker, use `host.docker.internal` (on Docker Desktop) or your machine's network IP instead of `localhost` in the `prometheus.yaml` scrape configuration.\n*   **No Data on Dashboard:**\n    *   Generate some traffic to your SGLang server to produce metrics. 
For example, run a benchmark:\n        ```bash\n        python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 100 --random-input 128 --random-output 128\n        ```\n    *   Check the Prometheus UI (`http://localhost:9090`) under `Status` -> `Targets` to see whether the SGLang endpoint is being scraped successfully.\n    *   Verify that the `model_name` and `instance` labels in your Prometheus metrics match the variables used in the Grafana dashboard. You might need to adjust the Grafana dashboard variables or the labels in your Prometheus configuration.\n\n### Configuration Files\n\nThe monitoring setup is defined by the following files within the `examples/monitoring` directory:\n\n*   `docker-compose.yaml`: Defines the Prometheus and Grafana services.\n*   `prometheus.yaml`: Prometheus configuration, including scrape targets.\n*   `grafana/datasources/datasource.yaml`: Configures the Prometheus data source for Grafana.\n*   `grafana/dashboards/config/dashboard.yaml`: Tells Grafana to load dashboards from the specified path.\n*   `grafana/dashboards/json/sglang-dashboard.json`: The actual Grafana dashboard definition in JSON format.\n\nYou can customize the setup by modifying these files. For instance, you might need to update the `static_configs` target in `prometheus.yaml` if your SGLang server runs on a different host or port.\n\n#### Check if the metrics are being collected\n\nRun the following command to generate some requests:\n\n```bash\npython3 -m sglang.bench_serving \\\n  --backend sglang \\\n  --dataset-name random \\\n  --num-prompts 3000 \\\n  --random-input 1024 \\\n  --random-output 1024 \\\n  --random-range-ratio 0.5\n```\n\nYou should then see the metrics in the Grafana dashboard.\n"
  },
  {
    "path": "docs/references/production_request_trace.md",
    "content": "# Production Request Tracing\n\nSGLang exports request trace data through the OpenTelemetry Collector. You can enable tracing by adding the `--enable-trace` flag, and configure the OpenTelemetry Collector endpoint with `--otlp-traces-endpoint` when launching the server.\n\nYou can find example screenshots of the visualization in https://github.com/sgl-project/sglang/issues/8965.\n\n## Setup Guide\n\nThis section explains how to configure request tracing and export the trace data.\n\n1. Install the required packages and tools\n    * Install Docker and Docker Compose\n    * Install the dependencies\n    ```bash\n    # enter the SGLang root directory\n    pip install -e \"python[tracing]\"\n\n    # or manually install the dependencies using pip\n    pip install opentelemetry-sdk opentelemetry-api opentelemetry-exporter-otlp opentelemetry-exporter-otlp-proto-grpc\n    ```\n\n2. Launch the OpenTelemetry Collector and Jaeger\n    ```bash\n    docker compose -f examples/monitoring/tracing_compose.yaml up -d\n    ```\n\n3. Start your SGLang server with tracing enabled\n    ```bash\n    # set env variables\n    export SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS=500\n    export SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE=64\n    # start the prefill and decode servers\n    python -m sglang.launch_server --enable-trace --otlp-traces-endpoint 0.0.0.0:4317 <other options>\n    # start the model gateway (router)\n    python -m sglang_router.launch_router --enable-trace --otlp-traces-endpoint 0.0.0.0:4317 <other options>\n    ```\n\n    Replace `0.0.0.0:4317` with the actual endpoint of the OpenTelemetry Collector. If you launched the OpenTelemetry Collector with `tracing_compose.yaml`, the default receiving port is 4317.\n\n    To use the HTTP/protobuf span exporter, set the following environment variable and use an HTTP endpoint instead, for example, `http://0.0.0.0:4318/v1/traces`.\n    ```bash\n    export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf\n    ```\n\n4. 
Send some requests to the server\n5. Check whether trace data is being exported\n    * Open port 16686 of Jaeger in a web browser to visualize the request traces.\n    * The OpenTelemetry Collector also exports trace data in JSON format to `/tmp/otel_trace.json`. In a follow-up patch, we will provide a tool to convert this data into a Perfetto-compatible format, enabling visualization of requests in the Perfetto UI.\n\n6. Dynamically adjust the trace level\n    The trace level accepts values from `0` to `3`, with the following meanings:\n    ```\n    0: Disable tracing\n    1: Trace important slices\n    2: Trace all slices except nested ones\n    3: Trace all slices\n    ```\n    The trace level can be set dynamically via the HTTP API, for example:\n    ```bash\n    curl \"http://0.0.0.0:30000/set_trace_level?level=2\"\n    ```\n    Replace `0.0.0.0:30000` with your actual server address, and replace `level=2` with the level you want to set.\n\n    **Note**: You must launch the server with `--enable-trace`; otherwise, tracing stays disabled no matter which trace level you set dynamically.\n\n## How to Add Tracing for Slices You're Interested In (API Introduction)\n\nWe have already inserted instrumentation points in the tokenizer and scheduler main threads. If you wish to trace additional request execution segments or perform finer-grained tracing, use the APIs from the tracing package as described below.\n\n**All of the following tracing instrumentation lives in `python/sglang/srt/observability/req_time_stats.py`. If you want to add another slice, add it there.**\n\n1. 
Initialization\n\n    During initialization, every process involved in tracing should execute:\n    ```python\n    process_tracing_init(otlp_traces_endpoint, server_name)\n    ```\n    `otlp_traces_endpoint` is obtained from the server arguments. You can choose `server_name` freely, but it must be the same across all processes.\n\n    During initialization, every thread involved in tracing should execute:\n    ```python\n    trace_set_thread_info(\"thread label\", tp_rank, dp_rank)\n    ```\n    The \"thread label\" acts as the thread's name and distinguishes different threads in the visualization.\n\n2. Create a trace context for a request\n    Each request needs to call `TraceReqContext()` to initialize a request context, which is used to generate slice spans and record request stage info. You can either store it in the request object or keep it as a global variable.\n\n3. Mark the beginning and end of a request\n    ```python\n    trace_ctx.trace_req_start()\n    trace_ctx.trace_req_finish()\n    ```\n    `trace_req_start()` and `trace_req_finish()` must be called within the same process, for example, in the tokenizer.\n\n4. Add tracing for a slice\n\n    * Add slice tracing normally:\n        ```python\n        trace_ctx.trace_slice_start(RequestStage.TOKENIZER.stage_name)\n        trace_ctx.trace_slice_end(RequestStage.TOKENIZER.stage_name)\n\n        # or\n        trace_ctx.trace_slice(slice_ctx)  # slice_ctx: TraceSliceContext\n        ```\n\n    * The end of the last slice in a thread must be marked with `thread_finish_flag=True`, or you must explicitly call `trace_ctx.abort()`; otherwise, the thread's span will not be generated properly.\n        ```python\n        trace_ctx.trace_slice_end(RequestStage.D.stage_name, thread_finish_flag=True)\n        # or\n        trace_ctx.abort()\n        ```\n\n5. 
When the request execution flow transfers to another thread, the thread context needs to be explicitly rebuilt.\n    - receiver: Execute the following code after receiving the request via ZMQ\n        ```python\n        trace_ctx.rebuild_thread_context()\n        ```\n\n## How to Extend the Tracing Framework to Support Complex Tracing Scenarios\n\nThe currently provided tracing package still has potential for further development. If you wish to build more advanced features upon it, you must first understand its existing design principles.\n\nThe core of the tracing framework's implementation lies in the design of the span structure and the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a three-level trace context structure or span structure: `TraceReqContext`, `TraceThreadContext` and `TraceSliceContext`. Their relationship is as follows:\n```\nTraceReqContext (req_id=\"req-123\")\n├── TraceThreadContext(thread_label=\"scheduler\", tp_rank=0)\n|     └── TraceSliceContext(slice_name=\"prefill\")\n|\n└── TraceThreadContext(thread_label=\"scheduler\", tp_rank=1)\n      └── TraceSliceContext(slice_name=\"prefill\")\n```\n\nEach traced request maintains a global `TraceReqContext` and creates a corresponding request span. For every thread that processes the request, a `TraceThreadContext` is recorded and a thread span is created. The `TraceThreadContext` is nested within the `TraceReqContext`, and each currently traced code slice—potentially nested—is stored in its associated `TraceThreadContext`.\n\nIn addition to the above hierarchy, each slice also records its previous slice via Span.add_link(), which can be used to trace the execution flow.\n"
  },
  {
    "path": "docs/references/release_lookup.rst",
    "content": "Release Lookup\n==============\n\nFind which SGLang release first included a specific PR or commit.\n\n.. raw:: html\n\n   <style>\n       .release-lookup-container {\n           background-color: #ffffff;\n           padding: 2rem;\n           border-radius: 12px;\n           box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);\n           max-width: 600px;\n           margin: 1.5rem 0;\n       }\n\n       .release-lookup-container .input-group {\n           display: flex;\n           gap: 10px;\n           margin-bottom: 1.2rem;\n       }\n\n       .release-lookup-container input[type=\"text\"] {\n           flex: 1;\n           padding: 10px 14px;\n           border: 2px solid #e2e8f0;\n           border-radius: 8px;\n           font-size: 0.95rem;\n           outline: none;\n           transition: border-color 0.2s;\n           color: #1e293b;\n       }\n\n       .release-lookup-container input[type=\"text\"]:focus {\n           border-color: #3b82f6;\n           box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.1);\n       }\n\n       .release-lookup-container input[type=\"text\"]::placeholder {\n           color: #94a3b8;\n       }\n\n       .release-lookup-container .rl-btn {\n           padding: 10px 20px;\n           background-color: #3b82f6;\n           color: white;\n           border: none;\n           border-radius: 8px;\n           font-size: 0.95rem;\n           font-weight: 600;\n           cursor: pointer;\n           transition: background-color 0.2s;\n       }\n\n       .release-lookup-container .rl-btn:hover {\n           background-color: #2563eb;\n       }\n\n       .release-lookup-container .rl-btn:disabled {\n           background-color: #cbd5e1;\n           cursor: not-allowed;\n       }\n\n       .release-lookup-container .rl-result {\n           margin-top: 1rem;\n           text-align: left;\n           display: none;\n       }\n\n       .release-lookup-container .rl-result.visible {\n           display: block;\n       }\n\n       
.release-lookup-container .rl-result-content {\n           padding: 1rem;\n           border-radius: 8px;\n           margin-bottom: 0.75rem;\n       }\n\n       .release-lookup-container .rl-success {\n           background-color: #f0fdf4;\n           border: 1px solid #bbf7d0;\n           color: #166534;\n       }\n\n       .release-lookup-container .rl-error {\n           background-color: #fef2f2;\n           border: 1px solid #fecaca;\n           color: #991b1b;\n       }\n\n       .release-lookup-container .rl-row {\n           display: flex;\n           justify-content: space-between;\n           margin-bottom: 0.4rem;\n           align-items: baseline;\n       }\n\n       .release-lookup-container .rl-row:last-child {\n           margin-bottom: 0;\n       }\n\n       .release-lookup-container .rl-label {\n           font-weight: 600;\n           margin-right: 1rem;\n           min-width: 70px;\n       }\n\n       .release-lookup-container .rl-tag-link {\n           color: #3b82f6;\n           text-decoration: none;\n           font-weight: bold;\n           font-size: 1.05rem;\n       }\n\n       .release-lookup-container .rl-tag-link:hover {\n           text-decoration: underline;\n       }\n\n       .release-lookup-container .rl-badge {\n           display: inline-block;\n           padding: 2px 8px;\n           border-radius: 12px;\n           font-size: 0.75rem;\n           font-weight: 600;\n           text-transform: uppercase;\n       }\n\n       .release-lookup-container .rl-badge-main {\n           background-color: #dbeafe;\n           color: #1e40af;\n       }\n\n       .release-lookup-container .rl-badge-gateway {\n           background-color: #f3e8ff;\n           color: #6b21a8;\n       }\n\n       .release-lookup-container .rl-status {\n           margin-top: 0.8rem;\n           font-size: 0.85rem;\n           color: #64748b;\n           min-height: 18px;\n       }\n\n       .release-lookup-container .rl-loader {\n           display: 
inline-block;\n           width: 16px;\n           height: 16px;\n           border: 3px solid rgba(59, 130, 246, 0.2);\n           border-radius: 50%;\n           border-top-color: #3b82f6;\n           animation: rl-spin 1s linear infinite;\n           margin-right: 6px;\n           vertical-align: text-bottom;\n       }\n\n       @keyframes rl-spin {\n           to { transform: rotate(360deg); }\n       }\n   </style>\n\n   <div class=\"release-lookup-container\">\n       <div class=\"input-group\">\n           <input type=\"text\" id=\"rlQueryInput\" placeholder=\"PR # (e.g. 1425), PR URL, or commit hash\" autocomplete=\"off\" />\n           <button class=\"rl-btn\" id=\"rlSearchBtn\" disabled>Search</button>\n       </div>\n       <div id=\"rlLoading\" style=\"display:none; color:#64748b; margin-bottom:0.8rem;\">\n           <span class=\"rl-loader\"></span> Loading index…\n       </div>\n       <div class=\"rl-result\" id=\"rlResult\"></div>\n       <div class=\"rl-status\" id=\"rlStatus\">Initializing…</div>\n   </div>\n\n   <script>\n   (function() {\n       var INDEX_URL = '/release_lookup/release_index.json';\n       var SHORT_HASH_LEN = 8;\n       var tagIndex = null, tagsArray = null, sortedCommitKeys = null;\n\n       var input = document.getElementById('rlQueryInput');\n       var btn = document.getElementById('rlSearchBtn');\n       var resultDiv = document.getElementById('rlResult');\n       var loadingDiv = document.getElementById('rlLoading');\n       var statusDiv = document.getElementById('rlStatus');\n\n       function formatDate(iso) {\n           if (!iso) return 'Unknown';\n           try { return new Date(iso).toLocaleDateString('en-US', {year:'numeric',month:'long',day:'numeric'}); }\n           catch(e) { return iso; }\n       }\n\n       function getTagInfo(ref) {\n           var tag = tagsArray[ref];\n           return { name: tag[0], date: tag[1], type: tag[2] === 1 ? 
'gateway' : 'main' };\n       }\n\n       function parseTagRef(ref) {\n           if (typeof ref === 'string' && /^[mg]\\d+$/.test(ref))\n               return { type: ref[0], idx: parseInt(ref.slice(1)) };\n           return null;\n       }\n\n       function prefixSearch(prefix) {\n           if (!sortedCommitKeys) return null;\n           var lo = 0, hi = sortedCommitKeys.length;\n           while (lo < hi) {\n               var mid = (lo + hi) >>> 1;\n               if (sortedCommitKeys[mid] < prefix) lo = mid + 1; else hi = mid;\n           }\n           if (lo < sortedCommitKeys.length && sortedCommitKeys[lo].indexOf(prefix) === 0)\n               return sortedCommitKeys[lo];\n           return null;\n       }\n\n       function loadIndex() {\n           loadingDiv.style.display = 'block';\n           statusDiv.textContent = 'Downloading index…';\n           fetch(INDEX_URL)\n               .then(function(r) {\n                   if (!r.ok) throw new Error('Index not found. It is generated on each release.');\n                   return r.json();\n               })\n               .then(function(data) {\n                   tagsArray = data.t;\n                   tagIndex = { prs: data.p, commits: data.c };\n                   sortedCommitKeys = Object.keys(tagIndex.commits).sort();\n                   var tagCount = tagsArray.length;\n                   var prCount = Object.keys(tagIndex.prs).length;\n                   statusDiv.textContent = 'Ready. 
Indexed ' + tagCount + ' releases and ' + prCount + ' PRs.';\n                   btn.disabled = false;\n               })\n               .catch(function(e) {\n                   statusDiv.innerHTML = '<span style=\"color:#991b1b;\">Error: ' + e.message + '</span>';\n                   btn.disabled = true;\n               })\n               .finally(function() { loadingDiv.style.display = 'none'; });\n       }\n\n       function search() {\n           if (!tagIndex) return;\n           var raw = input.value.trim();\n           if (!raw) return;\n           resultDiv.style.display = 'none';\n           resultDiv.classList.remove('visible');\n           resultDiv.innerHTML = '';\n\n           var queryType = 'unknown', key = raw;\n           var urlMatch = raw.match(/\\/pull\\/(\\d+)/);\n           if (urlMatch) { key = urlMatch[1]; queryType = 'pr'; }\n           else if (/^#?\\d+$/.test(raw)) { key = raw.replace('#',''); queryType = 'pr'; }\n           else if (/^[0-9a-fA-F]{7,40}$/.test(raw)) { key = raw.toLowerCase(); queryType = 'commit'; }\n\n           var tagData = null;\n           if (queryType === 'pr') {\n               tagData = tagIndex.prs[key];\n           } else if (queryType === 'commit') {\n               var sk = key.slice(0, SHORT_HASH_LEN);\n               tagData = tagIndex.commits[sk];\n               if (!tagData) { var mk = prefixSearch(sk); if (mk) tagData = tagIndex.commits[mk]; }\n           }\n\n           renderResult(tagData, queryType, key);\n       }\n\n       function renderResult(tagData, queryType, key) {\n           resultDiv.innerHTML = '';\n           resultDiv.style.display = 'block';\n           void resultDiv.offsetWidth;\n           resultDiv.classList.add('visible');\n\n           var tagRefs = [];\n           if (tagData) {\n               if (typeof tagData === 'string') {\n                   var p = parseTagRef(tagData);\n                   if (p) tagRefs.push(p.idx);\n               } else if (typeof tagData === 
'object') {\n                   if ('m' in tagData) tagRefs.push(tagData.m);\n                   if ('g' in tagData) tagRefs.push(tagData.g);\n               }\n           }\n\n           if (tagRefs.length === 0) {\n               var label = queryType === 'pr' ? 'PR #' + key : 'Commit ' + key.substring(0,7);\n               var c = document.createElement('div');\n               c.className = 'rl-result-content rl-error';\n               c.innerHTML = '<div class=\"rl-row\"><span class=\"rl-label\">Status</span><span>Not Found</span></div>';\n               var msg = document.createElement('div');\n               msg.style.marginTop = '6px';\n               var s = document.createElement('strong');\n               s.textContent = label;\n               msg.appendChild(document.createTextNode('The ' + queryType + ' '));\n               msg.appendChild(s);\n               msg.appendChild(document.createTextNode(' has not been included in any release yet, or is not in the index.'));\n               c.appendChild(msg);\n               resultDiv.appendChild(c);\n               return;\n           }\n\n           var repoUrl = 'https://github.com/sgl-project/sglang';\n           for (var i = 0; i < tagRefs.length; i++) {\n               var info = getTagInfo(tagRefs[i]);\n               var tagUrl = repoUrl + '/releases/tag/' + encodeURIComponent(info.name);\n               var badgeClass = info.type === 'gateway' ? 
'rl-badge-gateway' : 'rl-badge-main';\n               var box = document.createElement('div');\n               box.className = 'rl-result-content rl-success';\n               box.innerHTML =\n                   '<div class=\"rl-row\"><span class=\"rl-label\">Release</span><a target=\"_blank\" class=\"rl-tag-link\"></a></div>' +\n                   '<div class=\"rl-row\"><span class=\"rl-label\">Date</span><span class=\"rl-date\"></span></div>' +\n                   '<div class=\"rl-row\"><span class=\"rl-label\">Module</span><span class=\"rl-badge ' + badgeClass + ' rl-module\"></span></div>';\n               var link = box.querySelector('.rl-tag-link');\n               link.href = tagUrl;\n               link.textContent = info.name;\n               box.querySelector('.rl-date').textContent = formatDate(info.date);\n               box.querySelector('.rl-module').textContent = info.type;\n               resultDiv.appendChild(box);\n           }\n       }\n\n       btn.addEventListener('click', search);\n       input.addEventListener('keypress', function(e) { if (e.key === 'Enter') search(); });\n       loadIndex();\n   })();\n   </script>\n"
  },
  {
    "path": "docs/references/torch_compile_cache.md",
    "content": "# Enabling cache for torch.compile\n\nSGLang uses the `max-autotune-no-cudagraphs` mode of torch.compile, and the auto-tuning can be slow.\nIf you deploy a model on many different machines, you can ship the torch.compile cache to these machines and skip the compilation step.\n\nThis is based on https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html\n\n1. Generate the cache by setting `TORCHINDUCTOR_CACHE_DIR` and running the model once.\n```bash\nTORCHINDUCTOR_CACHE_DIR=/root/inductor_root_cache python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile\n```\n2. Copy the cache folder to the other machines and launch the server with `TORCHINDUCTOR_CACHE_DIR` pointing to the copied folder.\n"
  },
  {
    "path": "docs/release_lookup/README.md",
    "content": "# SGLang Release Lookup Tool\n\nThis tool finds the earliest release that contains a specific PR or commit.\nIt runs entirely in the browser using a static JSON index generated from the git history.\n\n## Usage\n\n1. **Generate the Index**:\n   Run the Python script to generate the `release_index.json` file from your local git repository.\n\n   ```bash\n   python3 generate_index.py --output release_index.json\n   ```\n\n   This script:\n   - Finds all tags matching `v*` and `gateway-v*`, skipping pre-releases (tags ending in `rc`, `alpha`, or `beta`).\n   - Sorts them by creation date.\n   - Traverses the history to find which release first introduced each commit and PR.\n   - Extracts PR numbers from commit messages.\n\n2. **Open the Tool**:\n   Open `index.html` in your browser.\n\n   ```bash\n   # You can open it directly if your browser supports local file fetch (Firefox usually does),\n   # or serve it locally:\n   python3 -m http.server\n   # Then go to http://localhost:8000/index.html\n   ```\n\n## Files\n\n- `index.html`: The UI for the lookup tool.\n- `generate_index.py`: Script to build the index.\n- `release_index.json`: The index file used by the UI.\n\n## Logic\n\nThe tool determines the \"earliest release\" based on the tag creation date. It traverses tags from oldest to newest, and any commit reachable from a tag that was not reachable from an earlier tag of the same release line is assigned to that release. Main (`v*`) and gateway (`gateway-v*`) tags are processed as separate release lines, so a commit or PR can map to both a main release and a gateway release.\n"
  },
  {
    "path": "docs/release_lookup/generate_index.py",
    "content": "import argparse\nimport json\nimport os\nimport re\nimport subprocess\nimport sys\nfrom datetime import datetime\n\n# Short hash length for commits (7 is git's default short hash)\nSHORT_HASH_LEN = 8\nCOMMIT_CHUNK_SIZE = 1000\n\n\ndef run_git(cmd):\n    try:\n        output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)\n        return output.decode(\"utf-8\", errors=\"replace\").strip()\n    except subprocess.CalledProcessError as e:\n        print(f\"Error running cmd: {cmd}\\n{e.output.decode('utf-8', errors='replace')}\")\n        sys.exit(1)\n\n\ndef is_stable_release(tag_name):\n    \"\"\"Check if tag is a stable release (not rc/alpha/beta).\"\"\"\n    # Skip release candidates, alpha, beta versions\n    if re.search(r\"(rc|alpha|beta)\\d*$\", tag_name, re.IGNORECASE):\n        return False\n    return True\n\n\ndef get_tags():\n    # Get tags sorted by creator date\n    cmd = [\n        \"git\",\n        \"tag\",\n        \"--list\",\n        \"v*\",\n        \"gateway-v*\",\n        \"--sort=creatordate\",\n        \"--format=%(refname:short)|%(creatordate:iso8601)|%(objectname)\",\n    ]\n    raw = run_git(cmd)\n    tags = []\n    if not raw:\n        return []\n    for line in raw.split(\"\\n\"):\n        parts = line.split(\"|\")\n        if len(parts) >= 3:\n            name, date, commit = parts[0], parts[1], parts[2]\n            # Skip non-stable releases (rc, alpha, beta)\n            if not is_stable_release(name):\n                continue\n            tag_type = \"gateway\" if name.startswith(\"gateway-\") else \"main\"\n            tags.append(\n                {\"name\": name, \"date\": date, \"commit\": commit, \"type\": tag_type}\n            )\n    return tags\n\n\ndef extract_pr_num(message):\n    lines = message.strip().split(\"\\n\")\n    first_line = lines[0]\n\n    m = re.search(r\"\\(#(\\d+)\\)$\", first_line)\n    if m:\n        return m.group(1)\n\n    m = re.search(r\"Merge pull request #(\\d+)\", 
message)\n    if m:\n        return m.group(1)\n\n    return None\n\n\ndef process_tag_line(tags, commit_map, pr_map, tag_type, tag_to_idx):\n    \"\"\"Process a single release line (main or gateway) independently.\"\"\"\n    seen_commits = set()\n\n    for tag in tags:\n        tag_name = tag[\"name\"]\n        print(f\"Processing {tag_name}...\")\n\n        commits = run_git([\"git\", \"rev-list\", tag_name]).split(\"\\n\")\n\n        new_commits = []\n        for c in commits:\n            c = c.strip()\n            if not c:\n                continue\n            if c in seen_commits:\n                continue\n            new_commits.append(c)\n            seen_commits.add(c)\n\n        if not new_commits:\n            continue\n\n        for i in range(0, len(new_commits), COMMIT_CHUNK_SIZE):\n            chunk = new_commits[i : i + COMMIT_CHUNK_SIZE]\n\n            cmd = [\"git\", \"show\", \"-s\", \"--format=%H|%B%n--END-COMMIT--\"] + chunk\n            raw_logs = run_git(cmd)\n\n            entries = raw_logs.split(\"--END-COMMIT--\\n\")\n            for log_entry in entries:\n                if not log_entry.strip():\n                    continue\n                parts = log_entry.split(\"|\", 1)\n                if len(parts) < 2:\n                    continue\n                sha = parts[0].strip()\n                msg = parts[1].strip()\n\n                tag_idx = tag_to_idx[tag_name]\n\n                # Store release index using full SHA as key\n                if sha not in commit_map:\n                    commit_map[sha] = {}\n                commit_map[sha][tag_type] = tag_idx\n\n                pr = extract_pr_num(msg)\n                if pr:\n                    if pr not in pr_map:\n                        pr_map[pr] = {}\n                    if tag_type not in pr_map[pr]:\n                        pr_map[pr][tag_type] = tag_idx\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Generate lookup index for sglang 
releases\"\n    )\n    parser.add_argument(\n        \"--output\", default=\"release_index.json\", help=\"Output JSON file\"\n    )\n    args = parser.parse_args()\n\n    tags = get_tags()\n    print(f\"Found {len(tags)} tags.\")\n\n    main_tags = [t for t in tags if t[\"type\"] == \"main\"]\n    gateway_tags = [t for t in tags if t[\"type\"] == \"gateway\"]\n\n    print(f\"  - {len(main_tags)} main tags\")\n    print(f\"  - {len(gateway_tags)} gateway tags\")\n\n    # Build tag list and index mapping\n    # Tags array: [name, date, type] for each tag\n    tag_list = []\n    tag_to_idx = {}\n\n    for tag in tags:\n        tag_to_idx[tag[\"name\"]] = len(tag_list)\n        # Compact format: [name, date, type (0=main, 1=gateway)]\n        tag_list.append(\n            [tag[\"name\"], tag[\"date\"], 1 if tag[\"type\"] == \"gateway\" else 0]\n        )\n\n    pr_map = {}\n    commit_map_full = {}\n\n    process_tag_line(main_tags, commit_map_full, pr_map, \"m\", tag_to_idx)\n    process_tag_line(gateway_tags, commit_map_full, pr_map, \"g\", tag_to_idx)\n\n    # Convert full SHAs to short SHAs, checking for collisions\n    commit_map = {}\n    short_to_full_map = {}\n    for full_sha, data in commit_map_full.items():\n        short_sha = full_sha[:SHORT_HASH_LEN]\n        if short_sha in short_to_full_map and short_to_full_map[short_sha] != full_sha:\n            print(\n                f\"CRITICAL: Short SHA collision detected for '{short_sha}'\\n\"\n                f\"  Commit 1: {short_to_full_map[short_sha]}\\n\"\n                f\"  Commit 2: {full_sha}\\n\"\n                \"Please increase SHORT_HASH_LEN and re-run.\",\n                file=sys.stderr,\n            )\n            sys.exit(1)\n        commit_map[short_sha] = data\n        short_to_full_map[short_sha] = full_sha\n\n    # Compact output format:\n    # - tags: array of [name, date, type]\n    # - prs: {pr_num: tag_idx} or {pr_num: {m: idx, g: idx}}\n    # - commits: {short_hash: tag_idx} or 
{short_hash: {m: idx, g: idx}}\n\n    # Simplify single-entry dicts to just the value\n    def simplify_map(m):\n        result = {}\n        for k, v in m.items():\n            if len(v) == 1:\n                # Single entry: just store the index directly with type prefix\n                key_type, idx = list(v.items())[0]\n                result[k] = f\"{key_type}{idx}\"\n            else:\n                # Multiple entries: keep as dict\n                result[k] = v\n        return result\n\n    output_data = {\n        \"t\": tag_list,  # tags\n        \"p\": simplify_map(pr_map),  # prs\n        \"c\": simplify_map(commit_map),  # commits\n        \"g\": datetime.now().isoformat(),  # generated_at\n    }\n\n    # Write minified JSON with a trailing newline for formatter compatibility.\n    json_str = json.dumps(output_data, separators=(\",\", \":\"))\n\n    with open(args.output, \"w\", encoding=\"utf-8\") as f:\n        f.write(json_str)\n        f.write(\"\\n\")\n\n    json_size = os.path.getsize(args.output)\n\n    print(f\"Index generated at {args.output}\")\n    print(f\"Stats: {len(tag_list)} tags, {len(pr_map)} PRs, {len(commit_map)} commits.\")\n    print(f\"Size: {json_size/1024:.1f} KB\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "docs/release_lookup/index.html",
    "content": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n    <meta charset=\"UTF-8\">\n    <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n    <title>SGLang Release Lookup</title>\n    <style>\n        :root {\n            --primary: #3b82f6;\n            --primary-hover: #2563eb;\n            --bg: #f8fafc;\n            --card-bg: #ffffff;\n            --text-main: #1e293b;\n            --text-secondary: #64748b;\n            --border: #e2e8f0;\n            --success-bg: #f0fdf4;\n            --success-border: #bbf7d0;\n            --success-text: #166534;\n            --error-bg: #fef2f2;\n            --error-border: #fecaca;\n            --error-text: #991b1b;\n        }\n\n        body {\n            font-family: -apple-system, BlinkMacSystemFont, \"Segoe UI\", Roboto, Helvetica, Arial, sans-serif;\n            background-color: var(--bg);\n            color: var(--text-main);\n            display: flex;\n            justify-content: center;\n            align-items: center;\n            min-height: 100vh;\n            margin: 0;\n            padding: 20px;\n        }\n\n        .container {\n            background-color: var(--card-bg);\n            padding: 2.5rem;\n            border-radius: 16px;\n            box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05);\n            width: 100%;\n            max-width: 550px;\n            text-align: center;\n            transition: transform 0.2s;\n        }\n\n        h1 {\n            margin-top: 0;\n            margin-bottom: 0.5rem;\n            color: var(--text-main);\n            font-size: 1.8rem;\n            font-weight: 700;\n        }\n\n        p.subtitle {\n            margin-bottom: 2rem;\n            color: var(--text-secondary);\n            font-size: 0.95rem;\n        }\n\n        .input-group {\n            display: flex;\n            gap: 12px;\n            margin-bottom: 1.5rem;\n            position: relative;\n        }\n\n      
  input {\n            flex: 1;\n            padding: 12px 16px;\n            border: 2px solid var(--border);\n            border-radius: 8px;\n            font-size: 1rem;\n            outline: none;\n            transition: all 0.2s ease;\n            color: var(--text-main);\n        }\n\n        input:focus {\n            border-color: var(--primary);\n            box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.1);\n        }\n\n        input::placeholder {\n            color: #94a3b8;\n        }\n\n        button {\n            padding: 12px 24px;\n            background-color: var(--primary);\n            color: white;\n            border: none;\n            border-radius: 8px;\n            font-size: 1rem;\n            font-weight: 600;\n            cursor: pointer;\n            transition: background-color 0.2s, transform 0.1s;\n        }\n\n        button:hover {\n            background-color: var(--primary-hover);\n        }\n\n        button:active {\n            transform: translateY(1px);\n        }\n\n        button:disabled {\n            background-color: #cbd5e1;\n            cursor: not-allowed;\n            transform: none;\n        }\n\n        #result {\n            margin-top: 1.5rem;\n            text-align: left;\n            border-radius: 8px;\n            display: none;\n            opacity: 0;\n            transition: opacity 0.3s ease;\n        }\n\n        #result.visible {\n            opacity: 1;\n        }\n\n        .result-content {\n            padding: 1.25rem;\n            border-radius: 8px;\n        }\n\n        .result-success {\n            background-color: var(--success-bg);\n            border: 1px solid var(--success-border);\n            color: var(--success-text);\n        }\n\n        .result-error {\n            background-color: var(--error-bg);\n            border: 1px solid var(--error-border);\n            color: var(--error-text);\n        }\n\n        .result-row {\n            display: flex;\n            
justify-content: space-between;\n            margin-bottom: 0.5rem;\n            align-items: baseline;\n        }\n\n        .result-row:last-child {\n            margin-bottom: 0;\n        }\n\n        .result-label {\n            font-weight: 600;\n            margin-right: 1rem;\n            min-width: 80px;\n        }\n\n        .tag-link {\n            color: var(--primary);\n            text-decoration: none;\n            font-weight: bold;\n            font-size: 1.1rem;\n        }\n\n        .tag-link:hover {\n            text-decoration: underline;\n        }\n\n        .loader {\n            display: inline-block;\n            width: 18px;\n            height: 18px;\n            border: 3px solid rgba(59, 130, 246, 0.2);\n            border-radius: 50%;\n            border-top-color: var(--primary);\n            animation: spin 1s linear infinite;\n            margin-right: 8px;\n            vertical-align: text-bottom;\n        }\n\n        @keyframes spin {\n            to { transform: rotate(360deg); }\n        }\n\n        .status-msg {\n            margin-top: 1rem;\n            font-size: 0.85rem;\n            color: var(--text-secondary);\n            min-height: 20px;\n        }\n\n        .badge {\n            display: inline-block;\n            padding: 2px 8px;\n            border-radius: 12px;\n            font-size: 0.75rem;\n            font-weight: 600;\n            text-transform: uppercase;\n        }\n\n        .badge-main {\n            background-color: #dbeafe;\n            color: #1e40af;\n        }\n\n        .badge-gateway {\n            background-color: #f3e8ff;\n            color: #6b21a8;\n        }\n    </style>\n</head>\n<body>\n\n<div class=\"container\">\n    <h1>Release Lookup</h1>\n    <p class=\"subtitle\">Find which SGLang release first included your PR or commit.</p>\n\n    <div class=\"input-group\">\n        <input type=\"text\" id=\"queryInput\" placeholder=\"PR # (e.g. 
1425), URL, or Commit Hash\" autocomplete=\"off\" />\n        <button id=\"searchBtn\" disabled>Search</button>\n    </div>\n\n    <div id=\"loading\" style=\"display: none; margin-bottom: 1rem; color: var(--text-secondary);\">\n        <span class=\"loader\"></span> Loading index...\n    </div>\n\n    <div id=\"result\"></div>\n\n    <div id=\"indexStatus\" class=\"status-msg\">Initializing...</div>\n</div>\n\n<script>\n    let tagIndex = null;\n    let tagsArray = null;  // Compact format: array of [name, date, type]\n    let sortedCommitKeys = null;  // Sorted keys for binary prefix search\n    const INDEX_FILE = 'release_index.json';\n    const SHORT_HASH_LEN = 8;\n\n    const input = document.getElementById('queryInput');\n    const btn = document.getElementById('searchBtn');\n    const resultDiv = document.getElementById('result');\n    const loadingDiv = document.getElementById('loading');\n    const statusDiv = document.getElementById('indexStatus');\n\n    // Format date nicely (always in English)\n    function formatDate(isoString) {\n        if (!isoString) return 'Unknown';\n        try {\n            return new Date(isoString).toLocaleDateString('en-US', {\n                year: 'numeric', month: 'long', day: 'numeric'\n            });\n        } catch(e) { return isoString; }\n    }\n\n    // Check if index is in compact format\n    function isCompactFormat(data) {\n        return Array.isArray(data.t);\n    }\n\n    // Get tag info by index (compact) or name (legacy)\n    function getTagInfo(tagRef) {\n        if (tagsArray) {\n            // Compact format: tagRef is index\n            const tag = tagsArray[tagRef];\n            return {\n                name: tag[0],\n                date: tag[1],\n                type: tag[2] === 1 ? 
'gateway' : 'main'\n            };\n        } else {\n            // Legacy format: tagRef is name\n            const info = tagIndex.tags[tagRef];\n            return { name: tagRef, ...info };\n        }\n    }\n\n    // Parse compact tag reference: \"m5\" -> {type: 'm', idx: 5}\n    function parseTagRef(ref) {\n        if (typeof ref === 'string' && /^[mg]\\d+$/.test(ref)) {\n            return {\n                type: ref[0],\n                idx: parseInt(ref.slice(1))\n            };\n        }\n        return null;\n    }\n\n    async function loadIndex() {\n        loadingDiv.style.display = 'block';\n        statusDiv.innerText = 'Downloading index...';\n\n        try {\n            const response = await fetch(INDEX_FILE);\n            if (!response.ok) {\n                throw new Error(\"No index file found. Please run generate_index.py.\");\n            }\n            const data = await response.json();\n\n            // Handle both compact and legacy formats\n            if (isCompactFormat(data)) {\n                tagsArray = data.t;\n                tagIndex = {\n                    prs: data.p,\n                    commits: data.c\n                };\n            } else {\n                tagIndex = data;\n                tagsArray = null;\n            }\n            // Pre-sort commit keys for binary prefix search\n            sortedCommitKeys = Object.keys(tagIndex.commits).sort();\n\n            const tagCount = tagsArray ? tagsArray.length : Object.keys(tagIndex.tags).length;\n            const prCount = Object.keys(tagIndex.prs).length;\n\n            statusDiv.innerText = `Ready. 
Indexed ${tagCount} releases and ${prCount} PRs.`;\n            btn.disabled = false;\n        } catch (e) {\n            statusDiv.textContent = '';\n            const errorSpan = document.createElement('span');\n            errorSpan.style.color = 'var(--error-text)';\n            errorSpan.textContent = `Error: ${e.message}`;\n            statusDiv.appendChild(errorSpan);\n            btn.disabled = true;\n        } finally {\n            loadingDiv.style.display = 'none';\n        }\n    }\n\n    // Binary search for first commit key matching the given prefix (O(log n))\n    function prefixSearchCommit(prefix) {\n        if (!sortedCommitKeys) return null;\n        let lo = 0, hi = sortedCommitKeys.length;\n        while (lo < hi) {\n            const mid = (lo + hi) >>> 1;\n            if (sortedCommitKeys[mid] < prefix) lo = mid + 1;\n            else hi = mid;\n        }\n        if (lo < sortedCommitKeys.length && sortedCommitKeys[lo].startsWith(prefix)) {\n            return sortedCommitKeys[lo];\n        }\n        return null;\n    }\n\n    // Start loading\n    loadIndex();\n\n    // Event listeners\n    btn.addEventListener('click', performSearch);\n    input.addEventListener('keypress', (e) => {\n        if (e.key === 'Enter') performSearch();\n    });\n\n    // Auto-focus input\n    input.focus();\n\n    function performSearch() {\n        if (!tagIndex) return;\n\n        const rawQuery = input.value.trim();\n        if (!rawQuery) return;\n\n        // Hide previous result\n        resultDiv.style.display = 'none';\n        resultDiv.classList.remove('visible');\n\n        let queryType = 'unknown';\n        let key = rawQuery;\n\n        // Parse query\n        // 1. PR URL: https://github.com/.../pull/1234\n        const urlMatch = rawQuery.match(/\\/pull\\/(\\d+)/);\n        if (urlMatch) {\n            key = urlMatch[1];\n            queryType = 'pr';\n        }\n        // 2. 
PR Number: #1234 or 1234\n        else if (rawQuery.match(/^#?\\d+$/)) {\n            key = rawQuery.replace('#', '');\n            queryType = 'pr';\n        }\n        // 3. Commit Hash: usually hex string (min 7 chars)\n        else if (rawQuery.match(/^[0-9a-fA-F]{7,40}$/)) {\n            key = rawQuery.toLowerCase();\n            queryType = 'commit';\n        }\n\n        let tagData = null;\n\n        if (queryType === 'pr') {\n            tagData = tagIndex.prs[key];\n        } else if (queryType === 'commit') {\n            // Use short hash for lookup\n            const shortKey = key.slice(0, SHORT_HASH_LEN);\n            tagData = tagIndex.commits[shortKey];\n\n            // If not found with short hash, try prefix match (binary search)\n            if (!tagData) {\n                const matchKey = prefixSearchCommit(shortKey);\n                if (matchKey) {\n                    tagData = tagIndex.commits[matchKey];\n                }\n            }\n        }\n\n        renderResult(tagData, queryType, key);\n    }\n\n    function renderResult(tagData, queryType, key) {\n        resultDiv.innerHTML = '';\n        resultDiv.style.display = 'block';\n\n        // Trigger reflow for animation\n        void resultDiv.offsetWidth;\n        resultDiv.classList.add('visible');\n\n        // Collect tag references\n        let tagRefs = [];\n\n        if (!tagData) {\n            // Not found\n        } else if (typeof tagData === 'string') {\n            // Compact format: \"m5\" or \"g3\"\n            const parsed = parseTagRef(tagData);\n            if (parsed) {\n                tagRefs.push(parsed.idx);\n            } else {\n                // Legacy format: tag name directly\n                tagRefs.push(tagData);\n            }\n        } else if (typeof tagData === 'object') {\n            // Object format: {m: 5, g: 3} or {main: \"v0.5.8\", gateway: \"...\"}\n            if ('m' in tagData) tagRefs.push(tagData.m);\n            if ('g' in tagData) 
tagRefs.push(tagData.g);\n            if ('main' in tagData) tagRefs.push(tagData.main);\n            if ('gateway' in tagData) tagRefs.push(tagData.gateway);\n        }\n\n        if (tagRefs.length === 0) {\n            const label = queryType === 'pr'\n                ? `PR #${key}`\n                : queryType === 'commit'\n                    ? `Commit ${key.substring(0, 7)}`\n                    : `Query \"${key}\"`;\n\n            const container = document.createElement('div');\n            container.className = 'result-content result-error';\n\n            const statusRow = document.createElement('div');\n            statusRow.className = 'result-row';\n            const statusLabel = document.createElement('span');\n            statusLabel.className = 'result-label';\n            statusLabel.textContent = 'Status';\n            const statusValue = document.createElement('span');\n            statusValue.textContent = 'Not Found';\n            statusRow.appendChild(statusLabel);\n            statusRow.appendChild(statusValue);\n\n            const msgDiv = document.createElement('div');\n            msgDiv.style.marginTop = '8px';\n            const strongEl = document.createElement('strong');\n            strongEl.textContent = label;\n            // The label already names the query kind, so it leads the message directly.\n            msgDiv.append(\n                strongEl,\n                queryType === 'unknown'\n                    ? ' is not a recognized PR number, PR URL, or commit hash.'\n                    : ' has not been included in any release yet, or is not in the index.'\n            );\n\n            container.appendChild(statusRow);\n            container.appendChild(msgDiv);\n            resultDiv.appendChild(container);\n            return;\n        }\n\n        const repoUrl = \"https://github.com/sgl-project/sglang\";\n\n        for (const tagRef of tagRefs) {\n            const tagInfo = getTagInfo(tagRef);\n            const dateStr = formatDate(tagInfo.date);\n            const tagUrl = `${repoUrl}/releases/tag/${encodeURIComponent(tagInfo.name)}`;\n            const badgeClass = tagInfo.type === 'gateway' ?
'badge-gateway' : 'badge-main';\n\n            const container = document.createElement('div');\n            container.className = 'result-content result-success';\n            container.style.marginBottom = '0.75rem';\n\n            container.innerHTML = `\n                <div class=\"result-row\">\n                    <span class=\"result-label\">Release</span>\n                    <a target=\"_blank\" class=\"tag-link\"></a>\n                </div>\n                <div class=\"result-row\">\n                    <span class=\"result-label\">Date</span>\n                    <span class=\"date-value\"></span>\n                </div>\n                <div class=\"result-row\">\n                    <span class=\"result-label\">Module</span>\n                    <span class=\"badge ${badgeClass} module-value\"></span>\n                </div>\n            `;\n\n            // Set dynamic content safely via textContent\n            const link = container.querySelector('.tag-link');\n            link.href = tagUrl;\n            link.textContent = tagInfo.name;\n            container.querySelector('.date-value').textContent = dateStr;\n            container.querySelector('.module-value').textContent = tagInfo.type;\n\n            resultDiv.appendChild(container);\n        }\n    }\n</script>\n\n</body>\n</html>\n"
  },
  {
    "path": "docs/release_lookup/release_index.json",
    "content": "{\"t\":[[\"v0.1.3\",\"2024-01-16 05:55:25 +0000\",0],[\"v0.1.5\",\"2024-01-17 18:37:02 -0800\",0],[\"v0.1.6\",\"2024-01-21 01:45:02 -0800\",0],[\"v0.1.7\",\"2024-01-21 10:31:02 +0000\",0],[\"v0.1.8\",\"2024-01-24 03:33:34 -0800\",0],[\"v0.1.9\",\"2024-01-24 11:37:25 +0000\",0],[\"v0.1.10\",\"2024-01-30 15:37:52 +0000\",0],[\"v0.1.11\",\"2024-02-03 02:50:13 -0800\",0],[\"v0.1.12\",\"2024-02-11 06:43:45 -0800\",0],[\"v0.1.13\",\"2024-03-11 05:49:27 -0700\",0],[\"v0.1.14\",\"2024-03-22 13:42:22 -0700\",0],[\"v0.1.15\",\"2024-05-12 14:22:33 -0700\",0],[\"v0.1.16\",\"2024-05-13 17:29:17 -0700\",0],[\"v0.1.17\",\"2024-06-07 19:49:18 -0700\",0],[\"v0.1.18\",\"2024-07-04 06:27:29 +0000\",0],[\"v0.1.19\",\"2024-07-09 02:23:14 -0700\",0],[\"v0.1.20\",\"2024-07-13 17:27:55 -0700\",0],[\"v0.1.21\",\"2024-07-15 13:10:53 -0700\",0],[\"v0.1.22\",\"2024-07-20 03:39:50 -0700\",0],[\"v0.1.23\",\"2024-07-23 13:49:34 -0700\",0],[\"v0.1.24\",\"2024-07-24 15:55:01 -0700\",0],[\"v0.2.0\",\"2024-07-25 08:03:36 -0700\",0],[\"v0.2.5\",\"2024-07-27 05:56:30 +1000\",0],[\"v0.2.6\",\"2024-07-27 20:29:33 -0700\",0],[\"v0.2.7\",\"2024-07-30 20:41:10 +1000\",0],[\"v0.2.8\",\"2024-08-01 14:18:26 -0700\",0],[\"v0.2.9\",\"2024-08-02 01:45:48 -0700\",0],[\"v0.2.9.post1\",\"2024-08-02 12:08:00 -0700\",0],[\"v0.2.10\",\"2024-08-04 16:52:51 -0700\",0],[\"v0.2.11\",\"2024-08-07 20:47:53 +0800\",0],[\"v0.2.12\",\"2024-08-12 20:59:38 +1000\",0],[\"v0.2.13\",\"2024-08-16 03:50:43 +1000\",0],[\"v0.2.14\",\"2024-08-27 00:28:24 +1000\",0],[\"v0.2.14.post1\",\"2024-08-28 21:16:47 +1000\",0],[\"v0.2.14.post2\",\"2024-08-28 18:46:33 +0000\",0],[\"v0.2.15\",\"2024-09-01 22:22:38 -0700\",0],[\"v0.3.0\",\"2024-09-04 04:21:21 -0700\",0],[\"v0.3.1.post1\",\"2024-09-17 01:47:31 -0700\",0],[\"v0.3.1.post2\",\"2024-09-19 02:03:38 -0700\",0],[\"v0.3.1.post3\",\"2024-09-21 11:17:45 +0800\",0],[\"v0.3.2\",\"2024-09-25 14:17:09 +0800\",0],[\"v0.3.3\",\"2024-10-08 12:58:41 
-0700\",0],[\"v0.3.3.post1\",\"2024-10-11 07:56:16 -0700\",0],[\"v0.3.4\",\"2024-10-19 08:17:41 -0700\",0],[\"v0.3.4.post1\",\"2024-10-21 21:16:43 -0700\",0],[\"v0.3.4.post2\",\"2024-10-25 11:07:19 -0700\",0],[\"v0.3.5\",\"2024-11-03 13:48:11 -0800\",0],[\"v0.3.5.post1\",\"2024-11-13 10:27:12 -0800\",0],[\"v0.3.5.post2\",\"2024-11-15 06:54:00 -0800\",0],[\"v0.3.6\",\"2024-11-22 19:27:30 +0800\",0],[\"v0.3.6.post1\",\"2024-11-25 17:31:37 -0800\",0],[\"v0.3.6.post2\",\"2024-11-27 03:35:30 -0800\",0],[\"v0.3.6.post3\",\"2024-11-30 01:41:16 +0800\",0],[\"v0.4.0\",\"2024-12-03 11:55:41 -0800\",0],[\"v0.4.0.post1\",\"2024-12-06 06:08:19 -0800\",0],[\"v0.4.0.post2\",\"2024-12-21 21:16:34 +0800\",0],[\"v0.4.1\",\"2024-12-26 07:14:51 +0800\",0],[\"v0.4.1.post1\",\"2024-12-28 00:11:06 +0800\",0],[\"v0.4.1.post2\",\"2024-12-30 00:11:46 +0800\",0],[\"v0.4.1.post3\",\"2024-12-29 14:25:53 -0800\",0],[\"v0.4.1.post4\",\"2025-01-06 01:29:54 +0800\",0],[\"v0.4.1.post5\",\"2025-01-11 23:10:02 +0800\",0],[\"v0.4.1.post6\",\"2025-01-15 16:23:42 +0800\",0],[\"v0.4.1.post7\",\"2025-01-20 21:50:55 +0800\",0],[\"v0.4.2\",\"2025-01-27 21:42:05 +0800\",0],[\"v0.4.2.post1\",\"2025-01-31 20:35:55 +0800\",0],[\"v0.4.2.post2\",\"2025-02-05 17:35:02 +0800\",0],[\"v0.4.2.post3\",\"2025-02-07 08:20:03 -0800\",0],[\"v0.4.2.post4\",\"2025-02-10 14:12:16 +0800\",0],[\"v0.4.3\",\"2025-02-14 09:43:14 +0800\",0],[\"v0.4.3.post1\",\"2025-02-17 21:58:19 +0800\",0],[\"v0.4.3.post2\",\"2025-02-18 02:48:30 +0800\",0],[\"v0.4.3.post3\",\"2025-03-05 17:26:10 -0800\",0],[\"v0.4.3.post4\",\"2025-03-06 12:50:28 -0800\",0],[\"v0.4.4\",\"2025-03-13 02:49:58 -0700\",0],[\"v0.4.4.post1\",\"2025-03-13 17:53:46 -0700\",0],[\"v0.4.4.post2\",\"2025-03-26 19:58:00 -0700\",0],[\"v0.4.4.post3\",\"2025-03-28 23:21:24 -0700\",0],[\"v0.4.4.post4\",\"2025-04-05 15:36:17 -0700\",0],[\"v0.4.5\",\"2025-04-07 00:35:00 -0700\",0],[\"v0.4.5.post1\",\"2025-04-15 23:00:07 -0700\",0],[\"v0.4.5.post2\",\"2025-04-20 14:12:37 
-0700\",0],[\"v0.4.5.post3\",\"2025-04-21 18:16:20 -0700\",0],[\"v0.4.6\",\"2025-04-27 14:07:05 -0700\",0],[\"v0.4.6.post1\",\"2025-04-28 12:57:08 -0700\",0],[\"v0.4.6.post2\",\"2025-04-30 22:04:40 -0700\",0],[\"v0.4.6.post3\",\"2025-05-09 15:38:47 -0700\",0],[\"v0.4.6.post4\",\"2025-05-13 01:57:51 -0700\",0],[\"v0.4.6.post5\",\"2025-05-24 00:48:05 -0700\",0],[\"v0.4.7\",\"2025-06-10 01:56:20 -0700\",0],[\"v0.4.7.post1\",\"2025-06-16 15:20:29 -0700\",0],[\"v0.4.8\",\"2025-06-23 23:14:22 -0700\",0],[\"v0.4.8.post1\",\"2025-06-26 02:21:12 -0700\",0],[\"v0.4.9\",\"2025-07-05 17:40:29 -0700\",0],[\"gateway-v0.1.5\",\"2025-07-06 22:54:17 -0700\",1],[\"v0.4.9.post1\",\"2025-07-09 00:28:17 -0700\",0],[\"v0.4.9.post2\",\"2025-07-11 21:11:20 -0700\",0],[\"gateway-v0.1.6\",\"2025-07-20 23:13:20 -0700\",1],[\"v0.4.9.post3\",\"2025-07-22 15:55:48 -0700\",0],[\"v0.4.9.post4\",\"2025-07-25 17:12:47 -0700\",0],[\"v0.4.9.post5\",\"2025-07-28 02:11:06 -0700\",0],[\"v0.4.9.post6\",\"2025-07-29 02:30:07 -0700\",0],[\"v0.4.10\",\"2025-07-31 20:50:17 +0800\",0],[\"gateway-v0.1.7\",\"2025-07-31 11:24:12 -0700\",1],[\"gateway-v0.1.8\",\"2025-07-31 19:00:23 -0700\",1],[\"v0.4.10.post1\",\"2025-08-01 12:07:30 +0800\",0],[\"v0.4.10.post2\",\"2025-08-03 03:43:29 -0700\",0],[\"gateway-v0.1.9\",\"2025-08-07 09:29:12 -0700\",1],[\"v0.5.1\",\"2025-08-23 07:09:26 -0700\",0],[\"v0.5.1.post1\",\"2025-08-24 01:14:17 -0700\",0],[\"v0.5.1.post2\",\"2025-08-25 03:45:09 -0700\",0],[\"v0.5.1.post3\",\"2025-08-27 15:42:42 -0700\",0],[\"v0.5.2\",\"2025-09-11 16:09:20 -0700\",0],[\"v0.5.3\",\"2025-10-06 20:07:02 +0800\",0],[\"v0.5.3.post1\",\"2025-10-09 15:19:59 -0700\",0],[\"gateway-v0.2.0\",\"2025-10-14 22:10:30 -0400\",1],[\"v0.5.3.post2\",\"2025-10-15 16:49:14 -0700\",0],[\"v0.5.3.post3\",\"2025-10-16 13:14:55 -0700\",0],[\"gateway-v0.2.1\",\"2025-10-20 21:08:45 -0700\",1],[\"v0.5.4\",\"2025-10-23 18:01:40 -0700\",0],[\"v0.5.4.post1\",\"2025-10-27 09:35:20 +0800\",0],[\"gateway-v0.2.2\",\"2025-10-30 
14:40:13 -0700\",1],[\"v0.5.4.post2\",\"2025-10-31 17:38:50 -0700\",0],[\"v0.5.4.post3\",\"2025-11-04 18:32:11 -0800\",0],[\"v0.5.5\",\"2025-11-07 00:46:19 +0800\",0],[\"v0.5.5.post1\",\"2025-11-10 11:53:43 -0800\",0],[\"v0.5.5.post2\",\"2025-11-12 20:35:20 +0800\",0],[\"gateway-v0.2.3\",\"2025-11-14 19:04:20 -0800\",1],[\"v0.5.5.post3\",\"2025-11-16 17:55:38 -0800\",0],[\"v0.5.6\",\"2025-12-02 17:17:13 -0800\",0],[\"v0.5.6.post1\",\"2025-12-08 13:41:01 -0800\",0],[\"gateway-v0.2.4\",\"2025-12-09 16:36:17 -0800\",1],[\"v0.5.6.post2\",\"2025-12-11 12:29:52 -0800\",0],[\"gateway-v0.3.0\",\"2025-12-24 16:25:05 -0500\",1],[\"v0.5.7\",\"2026-01-01 10:59:48 +0800\",0],[\"gateway-v0.3.1\",\"2026-01-08 21:50:34 -0800\",1],[\"v0.5.8\",\"2026-01-23 09:58:11 -0800\",0],[\"v0.5.8.post1\",\"2026-02-05 20:56:52 +0800\",0]],\"p\":{\"10\":{\"m\":0,\"g\":94},\"8\":{\"m\":0,\"g\":94},\"7\":{\"m\":0,\"g\":94},\"6\":{\"m\":0,\"g\":94},\"4\":{\"m\":0,\"g\":94},\"3\":{\"m\":0,\"g\":94},\"2\":{\"m\":0,\"g\":94},\"1\":{\"m\":0,\"g\":94},\"32\":{\"m\":1,\"g\":94},\"18\":{\"m\":1,\"g\":94},\"30\":{\"m\":1,\"g\":94},\"20\":{\"m\":1,\"g\":94},\"19\":{\"m\":1,\"g\":94},\"9\":{\"m\":1,\"g\":94},\"17\":{\"m\":1,\"g\":94},\"16\":{\"m\":1,\"g\":94},\"15\":{\"m\":1,\"g\":94},\"12\":{\"m\":1,\"g\":94},\"11\":{\"m\":1,\"g\":94},\"68\":{\"m\":2,\"g\":94},\"67\":{\"m\":2,\"g\":94},\"64\":{\"m\":2,\"g\":94},\"63\":{\"m\":2,\"g\":94},\"58\":{\"m\":2,\"g\":94},\"57\":{\"m\":2,\"g\":94},\"36\":{\"m\":2,\"g\":94},\"52\":{\"m\":2,\"g\":94},\"50\":{\"m\":2,\"g\":94},\"49\":{\"m\":2,\"g\":94},\"47\":{\"m\":2,\"g\":94},\"46\":{\"m\":2,\"g\":94},\"45\":{\"m\":2,\"g\":94},\"42\":{\"m\":2,\"g\":94},\"34\":{\"m\":2,\"g\":94},\"33\":{\"m\":2,\"g\":94},\"93\":{\"m\":4,\"g\":94},\"92\":{\"m\":4,\"g\":94},\"90\":{\"m\":4,\"g\":94},\"87\":{\"m\":4,\"g\":94},\"84\":{\"m\":4,\"g\":94},\"83\":{\"m\":4,\"g\":94},\"82\":{\"m\":4,\"g\":94},\"75\":{\"m\":4,\"g\":94},\"80\":{\"m\":4,\"g\":94},\"72\":{\"m\":4,\"g\":94},\"37\":{\"
m\":4,\"g\":94},\"71\":{\"m\":4,\"g\":94},\"113\":{\"m\":6,\"g\":94},\"121\":{\"m\":6,\"g\":94},\"120\":{\"m\":6,\"g\":94},\"118\":{\"m\":6,\"g\":94},\"114\":{\"m\":6,\"g\":94},\"117\":{\"m\":6,\"g\":94},\"108\":{\"m\":6,\"g\":94},\"103\":{\"m\":6,\"g\":94},\"101\":{\"m\":6,\"g\":94},\"98\":{\"m\":6,\"g\":94},\"48\":{\"m\":6,\"g\":94},\"97\":{\"m\":6,\"g\":94},\"95\":{\"m\":6,\"g\":94},\"134\":{\"m\":7,\"g\":94},\"133\":{\"m\":7,\"g\":94},\"132\":{\"m\":7,\"g\":94},\"112\":{\"m\":7,\"g\":94},\"129\":{\"m\":7,\"g\":94},\"125\":{\"m\":7,\"g\":94},\"119\":{\"m\":7,\"g\":94},\"116\":{\"m\":7,\"g\":94},\"178\":{\"m\":8,\"g\":94},\"177\":{\"m\":8,\"g\":94},\"172\":{\"m\":8,\"g\":94},\"174\":{\"m\":8,\"g\":94},\"168\":{\"m\":8,\"g\":94},\"170\":{\"m\":8,\"g\":94},\"162\":{\"m\":8,\"g\":94},\"160\":{\"m\":8,\"g\":94},\"156\":{\"m\":8,\"g\":94},\"155\":{\"m\":8,\"g\":94},\"130\":{\"m\":8,\"g\":94},\"141\":{\"m\":8,\"g\":94},\"153\":{\"m\":8,\"g\":94},\"148\":{\"m\":8,\"g\":94},\"146\":{\"m\":8,\"g\":94},\"144\":{\"m\":8,\"g\":94},\"142\":{\"m\":8,\"g\":94},\"137\":{\"m\":8,\"g\":94},\"136\":{\"m\":8,\"g\":94},\"280\":{\"m\":9,\"g\":94},\"279\":{\"m\":9,\"g\":94},\"230\":{\"m\":9,\"g\":94},\"277\":{\"m\":9,\"g\":94},\"278\":{\"m\":9,\"g\":94},\"256\":{\"m\":9,\"g\":94},\"222\":{\"m\":9,\"g\":94},\"261\":{\"m\":9,\"g\":94},\"275\":{\"m\":9,\"g\":94},\"263\":{\"m\":9,\"g\":94},\"201\":{\"m\":9,\"g\":94},\"224\":{\"m\":9,\"g\":94},\"253\":{\"m\":9,\"g\":94},\"226\":{\"m\":9,\"g\":94},\"195\":{\"m\":9,\"g\":94},\"198\":{\"m\":9,\"g\":94},\"225\":{\"m\":9,\"g\":94},\"219\":{\"m\":9,\"g\":94},\"193\":{\"m\":9,\"g\":94},\"210\":{\"m\":9,\"g\":94},\"207\":{\"m\":9,\"g\":94},\"200\":{\"m\":9,\"g\":94},\"196\":{\"m\":9,\"g\":94},\"189\":{\"m\":9,\"g\":94},\"186\":{\"m\":9,\"g\":94},\"184\":{\"m\":9,\"g\":94},\"181\":{\"m\":9,\"g\":94},\"182\":{\"m\":9,\"g\":94},\"324\":{\"m\":10,\"g\":94},\"323\":{\"m\":10,\"g\":94},\"301\":{\"m\":10,\"g\":94},\"304\":{\"m\":10,\"g\":94},\"311\":{\"m\"
:10,\"g\":94},\"291\":{\"m\":10,\"g\":94},\"290\":{\"m\":10,\"g\":94},\"288\":{\"m\":10,\"g\":94},\"286\":{\"m\":10,\"g\":94},\"287\":{\"m\":10,\"g\":94},\"282\":{\"m\":10,\"g\":94},\"242\":{\"m\":10,\"g\":94},\"281\":{\"m\":10,\"g\":94},\"431\":{\"m\":11,\"g\":94},\"430\":{\"m\":11,\"g\":94},\"429\":{\"m\":11,\"g\":94},\"428\":{\"m\":11,\"g\":94},\"427\":{\"m\":11,\"g\":94},\"422\":{\"m\":11,\"g\":94},\"420\":{\"m\":11,\"g\":94},\"380\":{\"m\":11,\"g\":94},\"416\":{\"m\":11,\"g\":94},\"415\":{\"m\":11,\"g\":94},\"412\":{\"m\":11,\"g\":94},\"411\":{\"m\":11,\"g\":94},\"381\":{\"m\":11,\"g\":94},\"392\":{\"m\":11,\"g\":94},\"390\":{\"m\":11,\"g\":94},\"406\":{\"m\":11,\"g\":94},\"399\":{\"m\":11,\"g\":94},\"395\":{\"m\":11,\"g\":94},\"394\":{\"m\":11,\"g\":94},\"382\":{\"m\":11,\"g\":94},\"385\":{\"m\":11,\"g\":94},\"378\":{\"m\":11,\"g\":94},\"372\":{\"m\":11,\"g\":94},\"375\":{\"m\":11,\"g\":94},\"364\":{\"m\":11,\"g\":94},\"370\":{\"m\":11,\"g\":94},\"368\":{\"m\":11,\"g\":94},\"369\":{\"m\":11,\"g\":94},\"358\":{\"m\":11,\"g\":94},\"355\":{\"m\":11,\"g\":94},\"354\":{\"m\":11,\"g\":94},\"346\":{\"m\":11,\"g\":94},\"338\":{\"m\":11,\"g\":94},\"345\":{\"m\":11,\"g\":94},\"343\":{\"m\":11,\"g\":94},\"315\":{\"m\":11,\"g\":94},\"332\":{\"m\":11,\"g\":94},\"337\":{\"m\":11,\"g\":94},\"331\":{\"m\":11,\"g\":94},\"329\":{\"m\":11,\"g\":94},\"293\":{\"m\":11,\"g\":94},\"327\":{\"m\":11,\"g\":94},\"326\":{\"m\":11,\"g\":94},\"298\":{\"m\":11,\"g\":94},\"438\":{\"m\":12,\"g\":94},\"437\":{\"m\":12,\"g\":94},\"426\":{\"m\":12,\"g\":94},\"436\":{\"m\":12,\"g\":94},\"434\":{\"m\":12,\"g\":94},\"418\":{\"m\":12,\"g\":94},\"433\":{\"m\":12,\"g\":94},\"363\":{\"m\":12,\"g\":94},\"432\":{\"m\":12,\"g\":94},\"515\":{\"m\":13,\"g\":94},\"514\":{\"m\":13,\"g\":94},\"505\":{\"m\":13,\"g\":94},\"512\":{\"m\":13,\"g\":94},\"502\":{\"m\":13,\"g\":94},\"500\":{\"m\":13,\"g\":94},\"511\":{\"m\":13,\"g\":94},\"493\":{\"m\":13,\"g\":94},\"491\":{\"m\":13,\"g\":94},\"492\":{\"m\":13,\"g\":94
},\"488\":{\"m\":13,\"g\":94},\"486\":{\"m\":13,\"g\":94},\"480\":{\"m\":13,\"g\":94},\"484\":{\"m\":13,\"g\":94},\"477\":{\"m\":13,\"g\":94},\"475\":{\"m\":13,\"g\":94},\"476\":{\"m\":13,\"g\":94},\"440\":{\"m\":13,\"g\":94},\"471\":{\"m\":13,\"g\":94},\"463\":{\"m\":13,\"g\":94},\"470\":{\"m\":13,\"g\":94},\"460\":{\"m\":13,\"g\":94},\"459\":{\"m\":13,\"g\":94},\"458\":{\"m\":13,\"g\":94},\"457\":{\"m\":13,\"g\":94},\"456\":{\"m\":13,\"g\":94},\"250\":{\"m\":13,\"g\":94},\"451\":{\"m\":13,\"g\":94},\"449\":{\"m\":13,\"g\":94},\"448\":{\"m\":13,\"g\":94},\"447\":{\"m\":13,\"g\":94},\"446\":{\"m\":13,\"g\":94},\"441\":{\"m\":13,\"g\":94},\"419\":{\"m\":13,\"g\":94},\"579\":{\"m\":14,\"g\":94},\"585\":{\"m\":14,\"g\":94},\"583\":{\"m\":14,\"g\":94},\"578\":{\"m\":14,\"g\":94},\"577\":{\"m\":14,\"g\":94},\"576\":{\"m\":14,\"g\":94},\"574\":{\"m\":14,\"g\":94},\"545\":{\"m\":14,\"g\":94},\"571\":{\"m\":14,\"g\":94},\"569\":{\"m\":14,\"g\":94},\"568\":{\"m\":14,\"g\":94},\"567\":{\"m\":14,\"g\":94},\"566\":{\"m\":14,\"g\":94},\"564\":{\"m\":14,\"g\":94},\"563\":{\"m\":14,\"g\":94},\"561\":{\"m\":14,\"g\":94},\"560\":{\"m\":14,\"g\":94},\"559\":{\"m\":14,\"g\":94},\"558\":{\"m\":14,\"g\":94},\"557\":{\"m\":14,\"g\":94},\"556\":{\"m\":14,\"g\":94},\"554\":{\"m\":14,\"g\":94},\"550\":{\"m\":14,\"g\":94},\"553\":{\"m\":14,\"g\":94},\"551\":{\"m\":14,\"g\":94},\"546\":{\"m\":14,\"g\":94},\"542\":{\"m\":14,\"g\":94},\"540\":{\"m\":14,\"g\":94},\"539\":{\"m\":14,\"g\":94},\"538\":{\"m\":14,\"g\":94},\"517\":{\"m\":14,\"g\":94},\"531\":{\"m\":14,\"g\":94},\"516\":{\"m\":14,\"g\":94},\"526\":{\"m\":14,\"g\":94},\"525\":{\"m\":14,\"g\":94},\"524\":{\"m\":14,\"g\":94},\"518\":{\"m\":14,\"g\":94},\"605\":{\"m\":15,\"g\":94},\"503\":{\"m\":15,\"g\":94},\"530\":{\"m\":15,\"g\":94},\"603\":{\"m\":15,\"g\":94},\"598\":{\"m\":15,\"g\":94},\"602\":{\"m\":15,\"g\":94},\"604\":{\"m\":15,\"g\":94},\"601\":{\"m\":15,\"g\":94},\"600\":{\"m\":15,\"g\":94},\"599\":{\"m\":15,\"g\":94},\"586\":{\
"m\":15,\"g\":94},\"594\":{\"m\":15,\"g\":94},\"593\":{\"m\":15,\"g\":94},\"592\":{\"m\":15,\"g\":94},\"588\":{\"m\":15,\"g\":94},\"618\":{\"m\":16,\"g\":94},\"616\":{\"m\":16,\"g\":94},\"615\":{\"m\":16,\"g\":94},\"614\":{\"m\":16,\"g\":94},\"613\":{\"m\":16,\"g\":94},\"612\":{\"m\":16,\"g\":94},\"611\":{\"m\":16,\"g\":94},\"610\":{\"m\":16,\"g\":94},\"609\":{\"m\":16,\"g\":94},\"607\":{\"m\":16,\"g\":94},\"626\":{\"m\":17,\"g\":94},\"625\":{\"m\":17,\"g\":94},\"623\":{\"m\":17,\"g\":94},\"621\":{\"m\":17,\"g\":94},\"620\":{\"m\":17,\"g\":94},\"619\":{\"m\":17,\"g\":94},\"677\":{\"m\":18,\"g\":94},\"676\":{\"m\":18,\"g\":94},\"675\":{\"m\":18,\"g\":94},\"664\":{\"m\":18,\"g\":94},\"673\":{\"m\":18,\"g\":94},\"671\":{\"m\":18,\"g\":94},\"669\":{\"m\":18,\"g\":94},\"668\":{\"m\":18,\"g\":94},\"667\":{\"m\":18,\"g\":94},\"640\":{\"m\":18,\"g\":94},\"666\":{\"m\":18,\"g\":94},\"665\":{\"m\":18,\"g\":94},\"663\":{\"m\":18,\"g\":94},\"662\":{\"m\":18,\"g\":94},\"661\":{\"m\":18,\"g\":94},\"660\":{\"m\":18,\"g\":94},\"659\":{\"m\":18,\"g\":94},\"655\":{\"m\":18,\"g\":94},\"657\":{\"m\":18,\"g\":94},\"658\":{\"m\":18,\"g\":94},\"656\":{\"m\":18,\"g\":94},\"654\":{\"m\":18,\"g\":94},\"653\":{\"m\":18,\"g\":94},\"651\":{\"m\":18,\"g\":94},\"650\":{\"m\":18,\"g\":94},\"648\":{\"m\":18,\"g\":94},\"647\":{\"m\":18,\"g\":94},\"649\":{\"m\":18,\"g\":94},\"646\":{\"m\":18,\"g\":94},\"645\":{\"m\":18,\"g\":94},\"643\":{\"m\":18,\"g\":94},\"642\":{\"m\":18,\"g\":94},\"617\":{\"m\":18,\"g\":94},\"638\":{\"m\":18,\"g\":94},\"637\":{\"m\":18,\"g\":94},\"636\":{\"m\":18,\"g\":94},\"635\":{\"m\":18,\"g\":94},\"632\":{\"m\":18,\"g\":94},\"633\":{\"m\":18,\"g\":94},\"624\":{\"m\":18,\"g\":94},\"630\":{\"m\":18,\"g\":94},\"631\":{\"m\":18,\"g\":94},\"629\":{\"m\":18,\"g\":94},\"628\":{\"m\":18,\"g\":94},\"627\":{\"m\":18,\"g\":94},\"705\":{\"m\":19,\"g\":94},\"704\":{\"m\":19,\"g\":94},\"701\":{\"m\":19,\"g\":94},\"702\":{\"m\":19,\"g\":94},\"700\":{\"m\":19,\"g\":94},\"698\":{\"m\":19,\"g\
":94},\"697\":{\"m\":19,\"g\":94},\"696\":{\"m\":19,\"g\":94},\"695\":{\"m\":19,\"g\":94},\"694\":{\"m\":19,\"g\":94},\"692\":{\"m\":19,\"g\":94},\"691\":{\"m\":19,\"g\":94},\"690\":{\"m\":19,\"g\":94},\"689\":{\"m\":19,\"g\":94},\"688\":{\"m\":19,\"g\":94},\"687\":{\"m\":19,\"g\":94},\"686\":{\"m\":19,\"g\":94},\"685\":{\"m\":19,\"g\":94},\"684\":{\"m\":19,\"g\":94},\"682\":{\"m\":19,\"g\":94},\"681\":{\"m\":19,\"g\":94},\"679\":{\"m\":19,\"g\":94},\"670\":{\"m\":19,\"g\":94},\"678\":{\"m\":19,\"g\":94},\"718\":{\"m\":20,\"g\":94},\"717\":{\"m\":20,\"g\":94},\"716\":{\"m\":20,\"g\":94},\"715\":{\"m\":20,\"g\":94},\"714\":{\"m\":20,\"g\":94},\"713\":{\"m\":20,\"g\":94},\"712\":{\"m\":20,\"g\":94},\"711\":{\"m\":20,\"g\":94},\"708\":{\"m\":20,\"g\":94},\"709\":{\"m\":20,\"g\":94},\"707\":{\"m\":20,\"g\":94},\"706\":{\"m\":20,\"g\":94},\"730\":{\"m\":21,\"g\":94},\"729\":{\"m\":21,\"g\":94},\"728\":{\"m\":21,\"g\":94},\"727\":{\"m\":21,\"g\":94},\"726\":{\"m\":21,\"g\":94},\"725\":{\"m\":21,\"g\":94},\"720\":{\"m\":21,\"g\":94},\"723\":{\"m\":21,\"g\":94},\"724\":{\"m\":21,\"g\":94},\"722\":{\"m\":21,\"g\":94},\"721\":{\"m\":21,\"g\":94},\"719\":{\"m\":21,\"g\":94},\"755\":{\"m\":22,\"g\":94},\"754\":{\"m\":22,\"g\":94},\"753\":{\"m\":22,\"g\":94},\"752\":{\"m\":22,\"g\":94},\"751\":{\"m\":22,\"g\":94},\"740\":{\"m\":22,\"g\":94},\"743\":{\"m\":22,\"g\":94},\"742\":{\"m\":22,\"g\":94},\"739\":{\"m\":22,\"g\":94},\"741\":{\"m\":22,\"g\":94},\"736\":{\"m\":22,\"g\":94},\"734\":{\"m\":22,\"g\":94},\"733\":{\"m\":22,\"g\":94},\"731\":{\"m\":22,\"g\":94},\"779\":{\"m\":23,\"g\":94},\"778\":{\"m\":23,\"g\":94},\"776\":{\"m\":23,\"g\":94},\"775\":{\"m\":23,\"g\":94},\"774\":{\"m\":23,\"g\":94},\"773\":{\"m\":23,\"g\":94},\"772\":{\"m\":23,\"g\":94},\"766\":{\"m\":23,\"g\":94},\"770\":{\"m\":23,\"g\":94},\"769\":{\"m\":23,\"g\":94},\"767\":{\"m\":23,\"g\":94},\"761\":{\"m\":23,\"g\":94},\"763\":{\"m\":23,\"g\":94},\"762\":{\"m\":23,\"g\":94},\"760\":{\"m\":23,\"g\":94},\"757\
":{\"m\":23,\"g\":94},\"693\":{\"m\":23,\"g\":94},\"830\":{\"m\":24,\"g\":94},\"829\":{\"m\":24,\"g\":94},\"828\":{\"m\":24,\"g\":94},\"825\":{\"m\":24,\"g\":94},\"826\":{\"m\":24,\"g\":94},\"823\":{\"m\":24,\"g\":94},\"824\":{\"m\":24,\"g\":94},\"822\":{\"m\":24,\"g\":94},\"821\":{\"m\":24,\"g\":94},\"820\":{\"m\":24,\"g\":94},\"819\":{\"m\":24,\"g\":94},\"807\":{\"m\":24,\"g\":94},\"817\":{\"m\":24,\"g\":94},\"814\":{\"m\":24,\"g\":94},\"815\":{\"m\":24,\"g\":94},\"812\":{\"m\":24,\"g\":94},\"809\":{\"m\":24,\"g\":94},\"699\":{\"m\":24,\"g\":94},\"806\":{\"m\":24,\"g\":94},\"802\":{\"m\":24,\"g\":94},\"805\":{\"m\":24,\"g\":94},\"803\":{\"m\":24,\"g\":94},\"800\":{\"m\":24,\"g\":94},\"799\":{\"m\":24,\"g\":94},\"797\":{\"m\":24,\"g\":94},\"793\":{\"m\":24,\"g\":94},\"796\":{\"m\":24,\"g\":94},\"795\":{\"m\":24,\"g\":94},\"794\":{\"m\":24,\"g\":94},\"792\":{\"m\":24,\"g\":94},\"791\":{\"m\":24,\"g\":94},\"790\":{\"m\":24,\"g\":94},\"789\":{\"m\":24,\"g\":94},\"788\":{\"m\":24,\"g\":94},\"787\":{\"m\":24,\"g\":94},\"786\":{\"m\":24,\"g\":94},\"785\":{\"m\":24,\"g\":94},\"784\":{\"m\":24,\"g\":94},\"783\":{\"m\":24,\"g\":94},\"781\":{\"m\":24,\"g\":94},\"877\":{\"m\":25,\"g\":94},\"872\":{\"m\":25,\"g\":94},\"873\":{\"m\":25,\"g\":94},\"871\":{\"m\":25,\"g\":94},\"870\":{\"m\":25,\"g\":94},\"869\":{\"m\":25,\"g\":94},\"864\":{\"m\":25,\"g\":94},\"862\":{\"m\":25,\"g\":94},\"861\":{\"m\":25,\"g\":94},\"860\":{\"m\":25,\"g\":94},\"811\":{\"m\":25,\"g\":94},\"852\":{\"m\":25,\"g\":94},\"858\":{\"m\":25,\"g\":94},\"856\":{\"m\":25,\"g\":94},\"855\":{\"m\":25,\"g\":94},\"850\":{\"m\":25,\"g\":94},\"848\":{\"m\":25,\"g\":94},\"843\":{\"m\":25,\"g\":94},\"842\":{\"m\":25,\"g\":94},\"838\":{\"m\":25,\"g\":94},\"840\":{\"m\":25,\"g\":94},\"890\":{\"m\":26,\"g\":94},\"886\":{\"m\":26,\"g\":94},\"889\":{\"m\":26,\"g\":94},\"888\":{\"m\":26,\"g\":94},\"883\":{\"m\":26,\"g\":94},\"882\":{\"m\":26,\"g\":94},\"880\":{\"m\":26,\"g\":94},\"879\":{\"m\":26,\"g\":94},\"749\":{\"m\":26,
\"g\":94},\"878\":{\"m\":26,\"g\":94},\"876\":{\"m\":26,\"g\":94},\"875\":{\"m\":26,\"g\":94},\"899\":{\"m\":27,\"g\":94},\"896\":{\"m\":27,\"g\":94},\"895\":{\"m\":27,\"g\":94},\"894\":{\"m\":27,\"g\":94},\"884\":{\"m\":27,\"g\":94},\"891\":{\"m\":27,\"g\":94},\"923\":{\"m\":28,\"g\":94},\"916\":{\"m\":28,\"g\":94},\"920\":{\"m\":28,\"g\":94},\"918\":{\"m\":28,\"g\":94},\"917\":{\"m\":28,\"g\":94},\"915\":{\"m\":28,\"g\":94},\"905\":{\"m\":28,\"g\":94},\"914\":{\"m\":28,\"g\":94},\"912\":{\"m\":28,\"g\":94},\"911\":{\"m\":28,\"g\":94},\"909\":{\"m\":28,\"g\":94},\"866\":{\"m\":28,\"g\":94},\"904\":{\"m\":28,\"g\":94},\"908\":{\"m\":28,\"g\":94},\"900\":{\"m\":28,\"g\":94},\"970\":{\"m\":29,\"g\":94},\"966\":{\"m\":29,\"g\":94},\"967\":{\"m\":29,\"g\":94},\"960\":{\"m\":29,\"g\":94},\"965\":{\"m\":29,\"g\":94},\"964\":{\"m\":29,\"g\":94},\"963\":{\"m\":29,\"g\":94},\"932\":{\"m\":29,\"g\":94},\"936\":{\"m\":29,\"g\":94},\"957\":{\"m\":29,\"g\":94},\"953\":{\"m\":29,\"g\":94},\"948\":{\"m\":29,\"g\":94},\"941\":{\"m\":29,\"g\":94},\"940\":{\"m\":29,\"g\":94},\"835\":{\"m\":29,\"g\":94},\"934\":{\"m\":29,\"g\":94},\"935\":{\"m\":29,\"g\":94},\"928\":{\"m\":29,\"g\":94},\"927\":{\"m\":29,\"g\":94},\"926\":{\"m\":29,\"g\":94},\"925\":{\"m\":29,\"g\":94},\"921\":{\"m\":29,\"g\":94},\"1048\":{\"m\":30,\"g\":94},\"1052\":{\"m\":30,\"g\":94},\"1051\":{\"m\":30,\"g\":94},\"1049\":{\"m\":30,\"g\":94},\"1050\":{\"m\":30,\"g\":94},\"1033\":{\"m\":30,\"g\":94},\"1046\":{\"m\":30,\"g\":94},\"1047\":{\"m\":30,\"g\":94},\"1044\":{\"m\":30,\"g\":94},\"1045\":{\"m\":30,\"g\":94},\"1039\":{\"m\":30,\"g\":94},\"1037\":{\"m\":30,\"g\":94},\"1038\":{\"m\":30,\"g\":94},\"1025\":{\"m\":30,\"g\":94},\"1034\":{\"m\":30,\"g\":94},\"1031\":{\"m\":30,\"g\":94},\"1027\":{\"m\":30,\"g\":94},\"1029\":{\"m\":30,\"g\":94},\"1028\":{\"m\":30,\"g\":94},\"1022\":{\"m\":30,\"g\":94},\"1024\":{\"m\":30,\"g\":94},\"907\":{\"m\":30,\"g\":94},\"1021\":{\"m\":30,\"g\":94},\"1020\":{\"m\":30,\"g\":94},\"1019\
":{\"m\":30,\"g\":94},\"990\":{\"m\":30,\"g\":94},\"1014\":{\"m\":30,\"g\":94},\"1010\":{\"m\":30,\"g\":94},\"1009\":{\"m\":30,\"g\":94},\"1007\":{\"m\":30,\"g\":94},\"959\":{\"m\":30,\"g\":94},\"997\":{\"m\":30,\"g\":94},\"1005\":{\"m\":30,\"g\":94},\"1002\":{\"m\":30,\"g\":94},\"1001\":{\"m\":30,\"g\":94},\"994\":{\"m\":30,\"g\":94},\"995\":{\"m\":30,\"g\":94},\"988\":{\"m\":30,\"g\":94},\"993\":{\"m\":30,\"g\":94},\"973\":{\"m\":30,\"g\":94},\"992\":{\"m\":30,\"g\":94},\"985\":{\"m\":30,\"g\":94},\"981\":{\"m\":30,\"g\":94},\"987\":{\"m\":30,\"g\":94},\"983\":{\"m\":30,\"g\":94},\"984\":{\"m\":30,\"g\":94},\"982\":{\"m\":30,\"g\":94},\"971\":{\"m\":30,\"g\":94},\"980\":{\"m\":30,\"g\":94},\"977\":{\"m\":30,\"g\":94},\"968\":{\"m\":30,\"g\":94},\"969\":{\"m\":30,\"g\":94},\"976\":{\"m\":30,\"g\":94},\"975\":{\"m\":30,\"g\":94},\"1111\":{\"m\":31,\"g\":94},\"1113\":{\"m\":31,\"g\":94},\"1112\":{\"m\":31,\"g\":94},\"1110\":{\"m\":31,\"g\":94},\"1107\":{\"m\":31,\"g\":94},\"1040\":{\"m\":31,\"g\":94},\"1106\":{\"m\":31,\"g\":94},\"1077\":{\"m\":31,\"g\":94},\"1104\":{\"m\":31,\"g\":94},\"1092\":{\"m\":31,\"g\":94},\"1103\":{\"m\":31,\"g\":94},\"1090\":{\"m\":31,\"g\":94},\"1082\":{\"m\":31,\"g\":94},\"1099\":{\"m\":31,\"g\":94},\"1095\":{\"m\":31,\"g\":94},\"1098\":{\"m\":31,\"g\":94},\"1096\":{\"m\":31,\"g\":94},\"1094\":{\"m\":31,\"g\":94},\"1088\":{\"m\":31,\"g\":94},\"1086\":{\"m\":31,\"g\":94},\"1084\":{\"m\":31,\"g\":94},\"1056\":{\"m\":31,\"g\":94},\"1081\":{\"m\":31,\"g\":94},\"1006\":{\"m\":31,\"g\":94},\"1079\":{\"m\":31,\"g\":94},\"1078\":{\"m\":31,\"g\":94},\"1074\":{\"m\":31,\"g\":94},\"1053\":{\"m\":31,\"g\":94},\"1070\":{\"m\":31,\"g\":94},\"1060\":{\"m\":31,\"g\":94},\"1066\":{\"m\":31,\"g\":94},\"1068\":{\"m\":31,\"g\":94},\"1057\":{\"m\":31,\"g\":94},\"1155\":{\"m\":32,\"g\":94},\"1201\":{\"m\":32,\"g\":94},\"1219\":{\"m\":32,\"g\":94},\"1212\":{\"m\":32,\"g\":94},\"1218\":{\"m\":32,\"g\":94},\"1217\":{\"m\":32,\"g\":94},\"1215\":{\"m\":32,\"g\":94}
,\"1214\":{\"m\":32,\"g\":94},\"1204\":{\"m\":32,\"g\":94},\"1213\":{\"m\":32,\"g\":94},\"1210\":{\"m\":32,\"g\":94},\"1211\":{\"m\":32,\"g\":94},\"1209\":{\"m\":32,\"g\":94},\"1208\":{\"m\":32,\"g\":94},\"1186\":{\"m\":32,\"g\":94},\"1205\":{\"m\":32,\"g\":94},\"1207\":{\"m\":32,\"g\":94},\"1199\":{\"m\":32,\"g\":94},\"1202\":{\"m\":32,\"g\":94},\"1198\":{\"m\":32,\"g\":94},\"1194\":{\"m\":32,\"g\":94},\"1193\":{\"m\":32,\"g\":94},\"1123\":{\"m\":32,\"g\":94},\"1185\":{\"m\":32,\"g\":94},\"1184\":{\"m\":32,\"g\":94},\"1180\":{\"m\":32,\"g\":94},\"1168\":{\"m\":32,\"g\":94},\"1179\":{\"m\":32,\"g\":94},\"1167\":{\"m\":32,\"g\":94},\"1170\":{\"m\":32,\"g\":94},\"1177\":{\"m\":32,\"g\":94},\"1171\":{\"m\":32,\"g\":94},\"1157\":{\"m\":32,\"g\":94},\"1166\":{\"m\":32,\"g\":94},\"1154\":{\"m\":32,\"g\":94},\"1165\":{\"m\":32,\"g\":94},\"1148\":{\"m\":32,\"g\":94},\"1164\":{\"m\":32,\"g\":94},\"1134\":{\"m\":32,\"g\":94},\"1138\":{\"m\":32,\"g\":94},\"1035\":{\"m\":32,\"g\":94},\"1144\":{\"m\":32,\"g\":94},\"1143\":{\"m\":32,\"g\":94},\"1141\":{\"m\":32,\"g\":94},\"1140\":{\"m\":32,\"g\":94},\"1139\":{\"m\":32,\"g\":94},\"1136\":{\"m\":32,\"g\":94},\"1133\":{\"m\":32,\"g\":94},\"1131\":{\"m\":32,\"g\":94},\"1013\":{\"m\":32,\"g\":94},\"1122\":{\"m\":32,\"g\":94},\"1119\":{\"m\":32,\"g\":94},\"1115\":{\"m\":32,\"g\":94},\"1114\":{\"m\":32,\"g\":94},\"1242\":{\"m\":33,\"g\":94},\"1239\":{\"m\":33,\"g\":94},\"1233\":{\"m\":33,\"g\":94},\"1237\":{\"m\":33,\"g\":94},\"1225\":{\"m\":33,\"g\":94},\"1236\":{\"m\":33,\"g\":94},\"1231\":{\"m\":33,\"g\":94},\"1230\":{\"m\":33,\"g\":94},\"1227\":{\"m\":33,\"g\":94},\"1222\":{\"m\":33,\"g\":94},\"1223\":{\"m\":33,\"g\":94},\"1125\":{\"m\":33,\"g\":94},\"1250\":{\"m\":34,\"g\":94},\"1252\":{\"m\":34,\"g\":94},\"1249\":{\"m\":34,\"g\":94},\"1234\":{\"m\":34,\"g\":94},\"1247\":{\"m\":34,\"g\":94},\"1232\":{\"m\":34,\"g\":94},\"1244\":{\"m\":34,\"g\":94},\"1243\":{\"m\":34,\"g\":94},\"1295\":{\"m\":35,\"g\":94},\"1297\":{\"m\":35,\"g\":94
},\"1296\":{\"m\":35,\"g\":94},\"1294\":{\"m\":35,\"g\":94},\"1293\":{\"m\":35,\"g\":94},\"1291\":{\"m\":35,\"g\":94},\"1290\":{\"m\":35,\"g\":94},\"1277\":{\"m\":35,\"g\":94},\"1284\":{\"m\":35,\"g\":94},\"1288\":{\"m\":35,\"g\":94},\"1286\":{\"m\":35,\"g\":94},\"1289\":{\"m\":35,\"g\":94},\"1285\":{\"m\":35,\"g\":94},\"1280\":{\"m\":35,\"g\":94},\"1262\":{\"m\":35,\"g\":94},\"1282\":{\"m\":35,\"g\":94},\"1276\":{\"m\":35,\"g\":94},\"1256\":{\"m\":35,\"g\":94},\"1269\":{\"m\":35,\"g\":94},\"1267\":{\"m\":35,\"g\":94},\"1258\":{\"m\":35,\"g\":94},\"1261\":{\"m\":35,\"g\":94},\"1260\":{\"m\":35,\"g\":94},\"1253\":{\"m\":35,\"g\":94},\"1255\":{\"m\":35,\"g\":94},\"1254\":{\"m\":35,\"g\":94},\"1327\":{\"m\":36,\"g\":94},\"1326\":{\"m\":36,\"g\":94},\"1320\":{\"m\":36,\"g\":94},\"1319\":{\"m\":36,\"g\":94},\"1318\":{\"m\":36,\"g\":94},\"1317\":{\"m\":36,\"g\":94},\"1313\":{\"m\":36,\"g\":94},\"1299\":{\"m\":36,\"g\":94},\"1308\":{\"m\":36,\"g\":94},\"1306\":{\"m\":36,\"g\":94},\"1304\":{\"m\":36,\"g\":94},\"1445\":{\"m\":37,\"g\":94},\"1444\":{\"m\":37,\"g\":94},\"1442\":{\"m\":37,\"g\":94},\"1420\":{\"m\":37,\"g\":94},\"1441\":{\"m\":37,\"g\":94},\"1440\":{\"m\":37,\"g\":94},\"1438\":{\"m\":37,\"g\":94},\"1432\":{\"m\":37,\"g\":94},\"1433\":{\"m\":37,\"g\":94},\"1431\":{\"m\":37,\"g\":94},\"1430\":{\"m\":37,\"g\":94},\"1428\":{\"m\":37,\"g\":94},\"1429\":{\"m\":37,\"g\":94},\"1427\":{\"m\":37,\"g\":94},\"1422\":{\"m\":37,\"g\":94},\"1426\":{\"m\":37,\"g\":94},\"1425\":{\"m\":37,\"g\":94},\"1418\":{\"m\":37,\"g\":94},\"1392\":{\"m\":37,\"g\":94},\"1414\":{\"m\":37,\"g\":94},\"1412\":{\"m\":37,\"g\":94},\"1411\":{\"m\":37,\"g\":94},\"1409\":{\"m\":37,\"g\":94},\"1408\":{\"m\":37,\"g\":94},\"1407\":{\"m\":37,\"g\":94},\"1406\":{\"m\":37,\"g\":94},\"1307\":{\"m\":37,\"g\":94},\"1397\":{\"m\":37,\"g\":94},\"1402\":{\"m\":37,\"g\":94},\"1403\":{\"m\":37,\"g\":94},\"1401\":{\"m\":37,\"g\":94},\"1399\":{\"m\":37,\"g\":94},\"1393\":{\"m\":37,\"g\":94},\"1381\":{\"m\":37,\"g\":9
4},\"1390\":{\"m\":37,\"g\":94},\"1389\":{\"m\":37,\"g\":94},\"1367\":{\"m\":37,\"g\":94},\"1385\":{\"m\":37,\"g\":94},\"1378\":{\"m\":37,\"g\":94},\"1380\":{\"m\":37,\"g\":94},\"1379\":{\"m\":37,\"g\":94},\"1376\":{\"m\":37,\"g\":94},\"1375\":{\"m\":37,\"g\":94},\"1373\":{\"m\":37,\"g\":94},\"1371\":{\"m\":37,\"g\":94},\"1370\":{\"m\":37,\"g\":94},\"1368\":{\"m\":37,\"g\":94},\"1300\":{\"m\":37,\"g\":94},\"1363\":{\"m\":37,\"g\":94},\"1360\":{\"m\":37,\"g\":94},\"1361\":{\"m\":37,\"g\":94},\"1341\":{\"m\":37,\"g\":94},\"1357\":{\"m\":37,\"g\":94},\"1298\":{\"m\":37,\"g\":94},\"1346\":{\"m\":37,\"g\":94},\"1281\":{\"m\":37,\"g\":94},\"1345\":{\"m\":37,\"g\":94},\"1339\":{\"m\":37,\"g\":94},\"1340\":{\"m\":37,\"g\":94},\"1337\":{\"m\":37,\"g\":94},\"1336\":{\"m\":37,\"g\":94},\"1328\":{\"m\":37,\"g\":94},\"1470\":{\"m\":38,\"g\":94},\"1469\":{\"m\":38,\"g\":94},\"1464\":{\"m\":38,\"g\":94},\"1458\":{\"m\":38,\"g\":94},\"1457\":{\"m\":38,\"g\":94},\"1454\":{\"m\":38,\"g\":94},\"1453\":{\"m\":38,\"g\":94},\"1452\":{\"m\":38,\"g\":94},\"1449\":{\"m\":38,\"g\":94},\"1451\":{\"m\":38,\"g\":94},\"1450\":{\"m\":38,\"g\":94},\"1448\":{\"m\":38,\"g\":94},\"1447\":{\"m\":38,\"g\":94},\"1483\":{\"m\":39,\"g\":94},\"1484\":{\"m\":39,\"g\":94},\"1482\":{\"m\":39,\"g\":94},\"1476\":{\"m\":39,\"g\":94},\"1475\":{\"m\":39,\"g\":94},\"1305\":{\"m\":39,\"g\":94},\"1472\":{\"m\":39,\"g\":94},\"1512\":{\"m\":40,\"g\":94},\"1511\":{\"m\":40,\"g\":94},\"1508\":{\"m\":40,\"g\":94},\"1510\":{\"m\":40,\"g\":94},\"1503\":{\"m\":40,\"g\":94},\"1499\":{\"m\":40,\"g\":94},\"1502\":{\"m\":40,\"g\":94},\"1500\":{\"m\":40,\"g\":94},\"1497\":{\"m\":40,\"g\":94},\"1496\":{\"m\":40,\"g\":94},\"1494\":{\"m\":40,\"g\":94},\"1490\":{\"m\":40,\"g\":94},\"1492\":{\"m\":40,\"g\":94},\"1491\":{\"m\":40,\"g\":94},\"1489\":{\"m\":40,\"g\":94},\"1456\":{\"m\":40,\"g\":94},\"1488\":{\"m\":40,\"g\":94},\"1486\":{\"m\":40,\"g\":94},\"1481\":{\"m\":40,\"g\":94},\"1605\":{\"m\":41,\"g\":94},\"1606\":{\"m\":41,\"g\":
94},\"1604\":{\"m\":41,\"g\":94},\"1598\":{\"m\":41,\"g\":94},\"1603\":{\"m\":41,\"g\":94},\"1597\":{\"m\":41,\"g\":94},\"1594\":{\"m\":41,\"g\":94},\"1596\":{\"m\":41,\"g\":94},\"1595\":{\"m\":41,\"g\":94},\"1593\":{\"m\":41,\"g\":94},\"1567\":{\"m\":41,\"g\":94},\"1592\":{\"m\":41,\"g\":94},\"1591\":{\"m\":41,\"g\":94},\"1590\":{\"m\":41,\"g\":94},\"1589\":{\"m\":41,\"g\":94},\"1587\":{\"m\":41,\"g\":94},\"1586\":{\"m\":41,\"g\":94},\"1585\":{\"m\":41,\"g\":94},\"1584\":{\"m\":41,\"g\":94},\"1573\":{\"m\":41,\"g\":94},\"1582\":{\"m\":41,\"g\":94},\"1583\":{\"m\":41,\"g\":94},\"1581\":{\"m\":41,\"g\":94},\"1576\":{\"m\":41,\"g\":94},\"1561\":{\"m\":41,\"g\":94},\"1572\":{\"m\":41,\"g\":94},\"1577\":{\"m\":41,\"g\":94},\"1580\":{\"m\":41,\"g\":94},\"1574\":{\"m\":41,\"g\":94},\"1563\":{\"m\":41,\"g\":94},\"1569\":{\"m\":41,\"g\":94},\"1568\":{\"m\":41,\"g\":94},\"1566\":{\"m\":41,\"g\":94},\"1562\":{\"m\":41,\"g\":94},\"1559\":{\"m\":41,\"g\":94},\"1536\":{\"m\":41,\"g\":94},\"1557\":{\"m\":41,\"g\":94},\"1556\":{\"m\":41,\"g\":94},\"1555\":{\"m\":41,\"g\":94},\"1553\":{\"m\":41,\"g\":94},\"1554\":{\"m\":41,\"g\":94},\"1549\":{\"m\":41,\"g\":94},\"1552\":{\"m\":41,\"g\":94},\"1550\":{\"m\":41,\"g\":94},\"1548\":{\"m\":41,\"g\":94},\"1547\":{\"m\":41,\"g\":94},\"1545\":{\"m\":41,\"g\":94},\"1544\":{\"m\":41,\"g\":94},\"1543\":{\"m\":41,\"g\":94},\"1541\":{\"m\":41,\"g\":94},\"1539\":{\"m\":41,\"g\":94},\"1538\":{\"m\":41,\"g\":94},\"1537\":{\"m\":41,\"g\":94},\"1534\":{\"m\":41,\"g\":94},\"1531\":{\"m\":41,\"g\":94},\"1532\":{\"m\":41,\"g\":94},\"1530\":{\"m\":41,\"g\":94},\"1495\":{\"m\":41,\"g\":94},\"1520\":{\"m\":41,\"g\":94},\"1521\":{\"m\":41,\"g\":94},\"1528\":{\"m\":41,\"g\":94},\"1525\":{\"m\":41,\"g\":94},\"1529\":{\"m\":41,\"g\":94},\"1524\":{\"m\":41,\"g\":94},\"1513\":{\"m\":41,\"g\":94},\"1636\":{\"m\":42,\"g\":94},\"1635\":{\"m\":42,\"g\":94},\"1634\":{\"m\":42,\"g\":94},\"1633\":{\"m\":42,\"g\":94},\"1632\":{\"m\":42,\"g\":94},\"1631\":{\"m\":42,\"g\"
:94},\"1626\":{\"m\":42,\"g\":94},\"1579\":{\"m\":42,\"g\":94},\"1611\":{\"m\":42,\"g\":94},\"1607\":{\"m\":42,\"g\":94},\"1629\":{\"m\":42,\"g\":94},\"1625\":{\"m\":42,\"g\":94},\"1619\":{\"m\":42,\"g\":94},\"1620\":{\"m\":42,\"g\":94},\"1615\":{\"m\":42,\"g\":94},\"1714\":{\"m\":43,\"g\":94},\"1713\":{\"m\":43,\"g\":94},\"1712\":{\"m\":43,\"g\":94},\"1710\":{\"m\":43,\"g\":94},\"1709\":{\"m\":43,\"g\":94},\"1707\":{\"m\":43,\"g\":94},\"1706\":{\"m\":43,\"g\":94},\"1705\":{\"m\":43,\"g\":94},\"1704\":{\"m\":43,\"g\":94},\"1703\":{\"m\":43,\"g\":94},\"1684\":{\"m\":43,\"g\":94},\"1702\":{\"m\":43,\"g\":94},\"1701\":{\"m\":43,\"g\":94},\"1700\":{\"m\":43,\"g\":94},\"1699\":{\"m\":43,\"g\":94},\"1694\":{\"m\":43,\"g\":94},\"1697\":{\"m\":43,\"g\":94},\"1696\":{\"m\":43,\"g\":94},\"1690\":{\"m\":43,\"g\":94},\"1679\":{\"m\":43,\"g\":94},\"1689\":{\"m\":43,\"g\":94},\"1688\":{\"m\":43,\"g\":94},\"1599\":{\"m\":43,\"g\":94},\"1687\":{\"m\":43,\"g\":94},\"1686\":{\"m\":43,\"g\":94},\"1685\":{\"m\":43,\"g\":94},\"1677\":{\"m\":43,\"g\":94},\"1676\":{\"m\":43,\"g\":94},\"1681\":{\"m\":43,\"g\":94},\"1674\":{\"m\":43,\"g\":94},\"1672\":{\"m\":43,\"g\":94},\"1671\":{\"m\":43,\"g\":94},\"1670\":{\"m\":43,\"g\":94},\"1658\":{\"m\":43,\"g\":94},\"1667\":{\"m\":43,\"g\":94},\"1666\":{\"m\":43,\"g\":94},\"1665\":{\"m\":43,\"g\":94},\"1459\":{\"m\":43,\"g\":94},\"1663\":{\"m\":43,\"g\":94},\"1662\":{\"m\":43,\"g\":94},\"1661\":{\"m\":43,\"g\":94},\"1659\":{\"m\":43,\"g\":94},\"1650\":{\"m\":43,\"g\":94},\"1656\":{\"m\":43,\"g\":94},\"1652\":{\"m\":43,\"g\":94},\"1654\":{\"m\":43,\"g\":94},\"1653\":{\"m\":43,\"g\":94},\"1651\":{\"m\":43,\"g\":94},\"1648\":{\"m\":43,\"g\":94},\"1480\":{\"m\":43,\"g\":94},\"1645\":{\"m\":43,\"g\":94},\"1642\":{\"m\":43,\"g\":94},\"1638\":{\"m\":43,\"g\":94},\"1614\":{\"m\":43,\"g\":94},\"1749\":{\"m\":44,\"g\":94},\"1748\":{\"m\":44,\"g\":94},\"1551\":{\"m\":44,\"g\":94},\"1746\":{\"m\":44,\"g\":94},\"1737\":{\"m\":44,\"g\":94},\"1738\":{\"m\":44,\"g\
":94},\"1743\":{\"m\":44,\"g\":94},\"1741\":{\"m\":44,\"g\":94},\"1740\":{\"m\":44,\"g\":94},\"1736\":{\"m\":44,\"g\":94},\"1735\":{\"m\":44,\"g\":94},\"1734\":{\"m\":44,\"g\":94},\"1727\":{\"m\":44,\"g\":94},\"1726\":{\"m\":44,\"g\":94},\"1725\":{\"m\":44,\"g\":94},\"1724\":{\"m\":44,\"g\":94},\"1722\":{\"m\":44,\"g\":94},\"1721\":{\"m\":44,\"g\":94},\"1720\":{\"m\":44,\"g\":94},\"1718\":{\"m\":44,\"g\":94},\"1716\":{\"m\":44,\"g\":94},\"1796\":{\"m\":45,\"g\":94},\"1795\":{\"m\":45,\"g\":94},\"1797\":{\"m\":45,\"g\":94},\"1794\":{\"m\":45,\"g\":94},\"1787\":{\"m\":45,\"g\":94},\"1793\":{\"m\":45,\"g\":94},\"1780\":{\"m\":45,\"g\":94},\"1789\":{\"m\":45,\"g\":94},\"1785\":{\"m\":45,\"g\":94},\"1783\":{\"m\":45,\"g\":94},\"1778\":{\"m\":45,\"g\":94},\"1782\":{\"m\":45,\"g\":94},\"1779\":{\"m\":45,\"g\":94},\"1776\":{\"m\":45,\"g\":94},\"1774\":{\"m\":45,\"g\":94},\"1773\":{\"m\":45,\"g\":94},\"1772\":{\"m\":45,\"g\":94},\"1771\":{\"m\":45,\"g\":94},\"1769\":{\"m\":45,\"g\":94},\"1768\":{\"m\":45,\"g\":94},\"1766\":{\"m\":45,\"g\":94},\"1767\":{\"m\":45,\"g\":94},\"1765\":{\"m\":45,\"g\":94},\"1760\":{\"m\":45,\"g\":94},\"1758\":{\"m\":45,\"g\":94},\"1747\":{\"m\":45,\"g\":94},\"1908\":{\"m\":46,\"g\":94},\"1907\":{\"m\":46,\"g\":94},\"1906\":{\"m\":46,\"g\":94},\"1902\":{\"m\":46,\"g\":94},\"1905\":{\"m\":46,\"g\":94},\"1904\":{\"m\":46,\"g\":94},\"1903\":{\"m\":46,\"g\":94},\"1899\":{\"m\":46,\"g\":94},\"1896\":{\"m\":46,\"g\":94},\"1895\":{\"m\":46,\"g\":94},\"1894\":{\"m\":46,\"g\":94},\"1892\":{\"m\":46,\"g\":94},\"1890\":{\"m\":46,\"g\":94},\"1888\":{\"m\":46,\"g\":94},\"1889\":{\"m\":46,\"g\":94},\"1886\":{\"m\":46,\"g\":94},\"1885\":{\"m\":46,\"g\":94},\"1883\":{\"m\":46,\"g\":94},\"1873\":{\"m\":46,\"g\":94},\"1882\":{\"m\":46,\"g\":94},\"1881\":{\"m\":46,\"g\":94},\"1879\":{\"m\":46,\"g\":94},\"1878\":{\"m\":46,\"g\":94},\"1877\":{\"m\":46,\"g\":94},\"1875\":{\"m\":46,\"g\":94},\"1871\":{\"m\":46,\"g\":94},\"1867\":{\"m\":46,\"g\":94},\"1866\":{\"m\":46,\"g
\":94},\"1754\":{\"m\":46,\"g\":94},\"1856\":{\"m\":46,\"g\":94},\"1859\":{\"m\":46,\"g\":94},\"1860\":{\"m\":46,\"g\":94},\"1861\":{\"m\":46,\"g\":94},\"1858\":{\"m\":46,\"g\":94},\"1855\":{\"m\":46,\"g\":94},\"1852\":{\"m\":46,\"g\":94},\"1851\":{\"m\":46,\"g\":94},\"1846\":{\"m\":46,\"g\":94},\"1850\":{\"m\":46,\"g\":94},\"1847\":{\"m\":46,\"g\":94},\"1845\":{\"m\":46,\"g\":94},\"1842\":{\"m\":46,\"g\":94},\"1836\":{\"m\":46,\"g\":94},\"1838\":{\"m\":46,\"g\":94},\"1840\":{\"m\":46,\"g\":94},\"1839\":{\"m\":46,\"g\":94},\"1827\":{\"m\":46,\"g\":94},\"1833\":{\"m\":46,\"g\":94},\"1835\":{\"m\":46,\"g\":94},\"1834\":{\"m\":46,\"g\":94},\"1822\":{\"m\":46,\"g\":94},\"1823\":{\"m\":46,\"g\":94},\"1830\":{\"m\":46,\"g\":94},\"1825\":{\"m\":46,\"g\":94},\"1790\":{\"m\":46,\"g\":94},\"1821\":{\"m\":46,\"g\":94},\"1820\":{\"m\":46,\"g\":94},\"1819\":{\"m\":46,\"g\":94},\"1810\":{\"m\":46,\"g\":94},\"1817\":{\"m\":46,\"g\":94},\"1816\":{\"m\":46,\"g\":94},\"1813\":{\"m\":46,\"g\":94},\"1811\":{\"m\":46,\"g\":94},\"1809\":{\"m\":46,\"g\":94},\"1808\":{\"m\":46,\"g\":94},\"1807\":{\"m\":46,\"g\":94},\"1805\":{\"m\":46,\"g\":94},\"1804\":{\"m\":46,\"g\":94},\"1803\":{\"m\":46,\"g\":94},\"1802\":{\"m\":46,\"g\":94},\"1801\":{\"m\":46,\"g\":94},\"1800\":{\"m\":46,\"g\":94},\"1786\":{\"m\":46,\"g\":94},\"1799\":{\"m\":46,\"g\":94},\"1798\":{\"m\":46,\"g\":94},\"1752\":{\"m\":46,\"g\":94},\"2022\":{\"m\":47,\"g\":94},\"2020\":{\"m\":47,\"g\":94},\"2018\":{\"m\":47,\"g\":94},\"1996\":{\"m\":47,\"g\":94},\"2015\":{\"m\":47,\"g\":94},\"2014\":{\"m\":47,\"g\":94},\"2013\":{\"m\":47,\"g\":94},\"1998\":{\"m\":47,\"g\":94},\"2011\":{\"m\":47,\"g\":94},\"2010\":{\"m\":47,\"g\":94},\"2009\":{\"m\":47,\"g\":94},\"2008\":{\"m\":47,\"g\":94},\"2006\":{\"m\":47,\"g\":94},\"2005\":{\"m\":47,\"g\":94},\"1994\":{\"m\":47,\"g\":94},\"2003\":{\"m\":47,\"g\":94},\"2004\":{\"m\":47,\"g\":94},\"2002\":{\"m\":47,\"g\":94},\"2001\":{\"m\":47,\"g\":94},\"2000\":{\"m\":47,\"g\":94},\"1999\":{\"m\":47,\"
g\":94},\"1995\":{\"m\":47,\"g\":94},\"1997\":{\"m\":47,\"g\":94},\"1934\":{\"m\":47,\"g\":94},\"1980\":{\"m\":47,\"g\":94},\"1990\":{\"m\":47,\"g\":94},\"1988\":{\"m\":47,\"g\":94},\"1984\":{\"m\":47,\"g\":94},\"1986\":{\"m\":47,\"g\":94},\"1983\":{\"m\":47,\"g\":94},\"1981\":{\"m\":47,\"g\":94},\"1982\":{\"m\":47,\"g\":94},\"1977\":{\"m\":47,\"g\":94},\"1972\":{\"m\":47,\"g\":94},\"1976\":{\"m\":47,\"g\":94},\"1975\":{\"m\":47,\"g\":94},\"1974\":{\"m\":47,\"g\":94},\"1745\":{\"m\":47,\"g\":94},\"1973\":{\"m\":47,\"g\":94},\"1963\":{\"m\":47,\"g\":94},\"1966\":{\"m\":47,\"g\":94},\"1962\":{\"m\":47,\"g\":94},\"1961\":{\"m\":47,\"g\":94},\"1958\":{\"m\":47,\"g\":94},\"1957\":{\"m\":47,\"g\":94},\"1956\":{\"m\":47,\"g\":94},\"1955\":{\"m\":47,\"g\":94},\"1954\":{\"m\":47,\"g\":94},\"1933\":{\"m\":47,\"g\":94},\"1952\":{\"m\":47,\"g\":94},\"1951\":{\"m\":47,\"g\":94},\"1939\":{\"m\":47,\"g\":94},\"1941\":{\"m\":47,\"g\":94},\"1949\":{\"m\":47,\"g\":94},\"1940\":{\"m\":47,\"g\":94},\"1942\":{\"m\":47,\"g\":94},\"1891\":{\"m\":47,\"g\":94},\"1926\":{\"m\":47,\"g\":94},\"1922\":{\"m\":47,\"g\":94},\"1853\":{\"m\":47,\"g\":94},\"1924\":{\"m\":47,\"g\":94},\"1920\":{\"m\":47,\"g\":94},\"1916\":{\"m\":47,\"g\":94},\"1893\":{\"m\":47,\"g\":94},\"1915\":{\"m\":47,\"g\":94},\"1910\":{\"m\":47,\"g\":94},\"1909\":{\"m\":47,\"g\":94},\"2046\":{\"m\":48,\"g\":94},\"2044\":{\"m\":48,\"g\":94},\"2043\":{\"m\":48,\"g\":94},\"2030\":{\"m\":48,\"g\":94},\"2042\":{\"m\":48,\"g\":94},\"2038\":{\"m\":48,\"g\":94},\"1968\":{\"m\":48,\"g\":94},\"2039\":{\"m\":48,\"g\":94},\"2036\":{\"m\":48,\"g\":94},\"2034\":{\"m\":48,\"g\":94},\"2033\":{\"m\":48,\"g\":94},\"2031\":{\"m\":48,\"g\":94},\"2027\":{\"m\":48,\"g\":94},\"2028\":{\"m\":48,\"g\":94},\"2026\":{\"m\":48,\"g\":94},\"2024\":{\"m\":48,\"g\":94},\"2023\":{\"m\":48,\"g\":94},\"2120\":{\"m\":49,\"g\":94},\"2125\":{\"m\":49,\"g\":94},\"2122\":{\"m\":49,\"g\":94},\"2110\":{\"m\":49,\"g\":94},\"2118\":{\"m\":49,\"g\":94},\"2106\":{\"m\":49,\
"g\":94},\"2115\":{\"m\":49,\"g\":94},\"2055\":{\"m\":49,\"g\":94},\"2111\":{\"m\":49,\"g\":94},\"2116\":{\"m\":49,\"g\":94},\"2107\":{\"m\":49,\"g\":94},\"2105\":{\"m\":49,\"g\":94},\"2104\":{\"m\":49,\"g\":94},\"2103\":{\"m\":49,\"g\":94},\"2073\":{\"m\":49,\"g\":94},\"2100\":{\"m\":49,\"g\":94},\"2067\":{\"m\":49,\"g\":94},\"2096\":{\"m\":49,\"g\":94},\"2095\":{\"m\":49,\"g\":94},\"2093\":{\"m\":49,\"g\":94},\"2094\":{\"m\":49,\"g\":94},\"2091\":{\"m\":49,\"g\":94},\"2088\":{\"m\":49,\"g\":94},\"2089\":{\"m\":49,\"g\":94},\"2086\":{\"m\":49,\"g\":94},\"2085\":{\"m\":49,\"g\":94},\"2083\":{\"m\":49,\"g\":94},\"2078\":{\"m\":49,\"g\":94},\"2069\":{\"m\":49,\"g\":94},\"2075\":{\"m\":49,\"g\":94},\"2074\":{\"m\":49,\"g\":94},\"2072\":{\"m\":49,\"g\":94},\"2071\":{\"m\":49,\"g\":94},\"2070\":{\"m\":49,\"g\":94},\"2068\":{\"m\":49,\"g\":94},\"2062\":{\"m\":49,\"g\":94},\"2056\":{\"m\":49,\"g\":94},\"2066\":{\"m\":49,\"g\":94},\"2061\":{\"m\":49,\"g\":94},\"2065\":{\"m\":49,\"g\":94},\"2064\":{\"m\":49,\"g\":94},\"2063\":{\"m\":49,\"g\":94},\"1849\":{\"m\":49,\"g\":94},\"2053\":{\"m\":49,\"g\":94},\"2051\":{\"m\":49,\"g\":94},\"1970\":{\"m\":49,\"g\":94},\"2050\":{\"m\":49,\"g\":94},\"2049\":{\"m\":49,\"g\":94},\"1876\":{\"m\":49,\"g\":94},\"2048\":{\"m\":49,\"g\":94},\"2047\":{\"m\":49,\"g\":94},\"2189\":{\"m\":50,\"g\":94},\"2188\":{\"m\":50,\"g\":94},\"2187\":{\"m\":50,\"g\":94},\"2052\":{\"m\":50,\"g\":94},\"2184\":{\"m\":50,\"g\":94},\"2176\":{\"m\":50,\"g\":94},\"2186\":{\"m\":50,\"g\":94},\"2171\":{\"m\":50,\"g\":94},\"2185\":{\"m\":50,\"g\":94},\"2183\":{\"m\":50,\"g\":94},\"2173\":{\"m\":50,\"g\":94},\"2182\":{\"m\":50,\"g\":94},\"2180\":{\"m\":50,\"g\":94},\"2175\":{\"m\":50,\"g\":94},\"2174\":{\"m\":50,\"g\":94},\"2170\":{\"m\":50,\"g\":94},\"2169\":{\"m\":50,\"g\":94},\"2167\":{\"m\":50,\"g\":94},\"2164\":{\"m\":50,\"g\":94},\"2163\":{\"m\":50,\"g\":94},\"2162\":{\"m\":50,\"g\":94},\"2158\":{\"m\":50,\"g\":94},\"2161\":{\"m\":50,\"g\":94},\"2159\":{\"m\":50,
\"g\":94},\"2156\":{\"m\":50,\"g\":94},\"2154\":{\"m\":50,\"g\":94},\"2157\":{\"m\":50,\"g\":94},\"2155\":{\"m\":50,\"g\":94},\"2153\":{\"m\":50,\"g\":94},\"2152\":{\"m\":50,\"g\":94},\"2148\":{\"m\":50,\"g\":94},\"2147\":{\"m\":50,\"g\":94},\"2146\":{\"m\":50,\"g\":94},\"2144\":{\"m\":50,\"g\":94},\"2143\":{\"m\":50,\"g\":94},\"2142\":{\"m\":50,\"g\":94},\"2114\":{\"m\":50,\"g\":94},\"2139\":{\"m\":50,\"g\":94},\"2138\":{\"m\":50,\"g\":94},\"2136\":{\"m\":50,\"g\":94},\"2137\":{\"m\":50,\"g\":94},\"2134\":{\"m\":50,\"g\":94},\"2081\":{\"m\":50,\"g\":94},\"2121\":{\"m\":50,\"g\":94},\"2130\":{\"m\":50,\"g\":94},\"2124\":{\"m\":50,\"g\":94},\"2092\":{\"m\":50,\"g\":94},\"2127\":{\"m\":50,\"g\":94},\"2077\":{\"m\":50,\"g\":94},\"2126\":{\"m\":50,\"g\":94},\"2214\":{\"m\":51,\"g\":94},\"2222\":{\"m\":51,\"g\":94},\"2221\":{\"m\":51,\"g\":94},\"2217\":{\"m\":51,\"g\":94},\"2210\":{\"m\":51,\"g\":94},\"2212\":{\"m\":51,\"g\":94},\"2208\":{\"m\":51,\"g\":94},\"2207\":{\"m\":51,\"g\":94},\"2204\":{\"m\":51,\"g\":94},\"2206\":{\"m\":51,\"g\":94},\"2201\":{\"m\":51,\"g\":94},\"2199\":{\"m\":51,\"g\":94},\"2196\":{\"m\":51,\"g\":94},\"2198\":{\"m\":51,\"g\":94},\"2195\":{\"m\":51,\"g\":94},\"2197\":{\"m\":51,\"g\":94},\"2259\":{\"m\":52,\"g\":94},\"2257\":{\"m\":52,\"g\":94},\"2256\":{\"m\":52,\"g\":94},\"2254\":{\"m\":52,\"g\":94},\"2253\":{\"m\":52,\"g\":94},\"2252\":{\"m\":52,\"g\":94},\"2251\":{\"m\":52,\"g\":94},\"2250\":{\"m\":52,\"g\":94},\"2239\":{\"m\":52,\"g\":94},\"2242\":{\"m\":52,\"g\":94},\"2238\":{\"m\":52,\"g\":94},\"2231\":{\"m\":52,\"g\":94},\"2233\":{\"m\":52,\"g\":94},\"2235\":{\"m\":52,\"g\":94},\"2234\":{\"m\":52,\"g\":94},\"2232\":{\"m\":52,\"g\":94},\"2228\":{\"m\":52,\"g\":94},\"2123\":{\"m\":52,\"g\":94},\"2223\":{\"m\":52,\"g\":94},\"2224\":{\"m\":52,\"g\":94},\"2226\":{\"m\":52,\"g\":94},\"2191\":{\"m\":52,\"g\":94},\"2218\":{\"m\":52,\"g\":94},\"2338\":{\"m\":53,\"g\":94},\"2339\":{\"m\":53,\"g\":94},\"2300\":{\"m\":53,\"g\":94},\"2335\":{\"m\":53
,\"g\":94},\"2327\":{\"m\":53,\"g\":94},\"2324\":{\"m\":53,\"g\":94},\"2328\":{\"m\":53,\"g\":94},\"2329\":{\"m\":53,\"g\":94},\"2325\":{\"m\":53,\"g\":94},\"2281\":{\"m\":53,\"g\":94},\"2319\":{\"m\":53,\"g\":94},\"2318\":{\"m\":53,\"g\":94},\"2314\":{\"m\":53,\"g\":94},\"2311\":{\"m\":53,\"g\":94},\"2310\":{\"m\":53,\"g\":94},\"2309\":{\"m\":53,\"g\":94},\"2279\":{\"m\":53,\"g\":94},\"2306\":{\"m\":53,\"g\":94},\"2305\":{\"m\":53,\"g\":94},\"2304\":{\"m\":53,\"g\":94},\"2301\":{\"m\":53,\"g\":94},\"2302\":{\"m\":53,\"g\":94},\"2299\":{\"m\":53,\"g\":94},\"2298\":{\"m\":53,\"g\":94},\"2241\":{\"m\":53,\"g\":94},\"2295\":{\"m\":53,\"g\":94},\"2292\":{\"m\":53,\"g\":94},\"2179\":{\"m\":53,\"g\":94},\"2293\":{\"m\":53,\"g\":94},\"2290\":{\"m\":53,\"g\":94},\"2244\":{\"m\":53,\"g\":94},\"2288\":{\"m\":53,\"g\":94},\"2289\":{\"m\":53,\"g\":94},\"2287\":{\"m\":53,\"g\":94},\"2284\":{\"m\":53,\"g\":94},\"2286\":{\"m\":53,\"g\":94},\"2285\":{\"m\":53,\"g\":94},\"2282\":{\"m\":53,\"g\":94},\"2215\":{\"m\":53,\"g\":94},\"2280\":{\"m\":53,\"g\":94},\"2274\":{\"m\":53,\"g\":94},\"2266\":{\"m\":53,\"g\":94},\"2265\":{\"m\":53,\"g\":94},\"2269\":{\"m\":53,\"g\":94},\"2243\":{\"m\":53,\"g\":94},\"2268\":{\"m\":53,\"g\":94},\"2225\":{\"m\":53,\"g\":94},\"2261\":{\"m\":53,\"g\":94},\"2375\":{\"m\":54,\"g\":94},\"2377\":{\"m\":54,\"g\":94},\"2363\":{\"m\":54,\"g\":94},\"2374\":{\"m\":54,\"g\":94},\"2373\":{\"m\":54,\"g\":94},\"2369\":{\"m\":54,\"g\":94},\"2357\":{\"m\":54,\"g\":94},\"2360\":{\"m\":54,\"g\":94},\"2370\":{\"m\":54,\"g\":94},\"2371\":{\"m\":54,\"g\":94},\"2368\":{\"m\":54,\"g\":94},\"2359\":{\"m\":54,\"g\":94},\"2364\":{\"m\":54,\"g\":94},\"2355\":{\"m\":54,\"g\":94},\"2342\":{\"m\":54,\"g\":94},\"2352\":{\"m\":54,\"g\":94},\"2308\":{\"m\":54,\"g\":94},\"2323\":{\"m\":54,\"g\":94},\"2350\":{\"m\":54,\"g\":94},\"2340\":{\"m\":54,\"g\":94},\"2349\":{\"m\":54,\"g\":94},\"2341\":{\"m\":54,\"g\":94},\"2348\":{\"m\":54,\"g\":94},\"2525\":{\"m\":55,\"g\":94},\"2528\":{\"m\":5
5,\"g\":94},\"2524\":{\"m\":55,\"g\":94},\"2517\":{\"m\":55,\"g\":94},\"2516\":{\"m\":55,\"g\":94},\"2515\":{\"m\":55,\"g\":94},\"2502\":{\"m\":55,\"g\":94},\"2500\":{\"m\":55,\"g\":94},\"2499\":{\"m\":55,\"g\":94},\"2438\":{\"m\":55,\"g\":94},\"2426\":{\"m\":55,\"g\":94},\"2457\":{\"m\":55,\"g\":94},\"2467\":{\"m\":55,\"g\":94},\"2495\":{\"m\":55,\"g\":94},\"2494\":{\"m\":55,\"g\":94},\"2493\":{\"m\":55,\"g\":94},\"2492\":{\"m\":55,\"g\":94},\"2491\":{\"m\":55,\"g\":94},\"2476\":{\"m\":55,\"g\":94},\"2490\":{\"m\":55,\"g\":94},\"2489\":{\"m\":55,\"g\":94},\"2486\":{\"m\":55,\"g\":94},\"2487\":{\"m\":55,\"g\":94},\"2481\":{\"m\":55,\"g\":94},\"2485\":{\"m\":55,\"g\":94},\"2484\":{\"m\":55,\"g\":94},\"2483\":{\"m\":55,\"g\":94},\"2479\":{\"m\":55,\"g\":94},\"2473\":{\"m\":55,\"g\":94},\"2469\":{\"m\":55,\"g\":94},\"2464\":{\"m\":55,\"g\":94},\"2466\":{\"m\":55,\"g\":94},\"2463\":{\"m\":55,\"g\":94},\"2462\":{\"m\":55,\"g\":94},\"2459\":{\"m\":55,\"g\":94},\"2456\":{\"m\":55,\"g\":94},\"2444\":{\"m\":55,\"g\":94},\"2455\":{\"m\":55,\"g\":94},\"2454\":{\"m\":55,\"g\":94},\"2442\":{\"m\":55,\"g\":94},\"2453\":{\"m\":55,\"g\":94},\"2452\":{\"m\":55,\"g\":94},\"2437\":{\"m\":55,\"g\":94},\"2449\":{\"m\":55,\"g\":94},\"2448\":{\"m\":55,\"g\":94},\"2425\":{\"m\":55,\"g\":94},\"2447\":{\"m\":55,\"g\":94},\"2436\":{\"m\":55,\"g\":94},\"2441\":{\"m\":55,\"g\":94},\"2440\":{\"m\":55,\"g\":94},\"2435\":{\"m\":55,\"g\":94},\"2434\":{\"m\":55,\"g\":94},\"2433\":{\"m\":55,\"g\":94},\"2424\":{\"m\":55,\"g\":94},\"2412\":{\"m\":55,\"g\":94},\"2422\":{\"m\":55,\"g\":94},\"2419\":{\"m\":55,\"g\":94},\"2417\":{\"m\":55,\"g\":94},\"2416\":{\"m\":55,\"g\":94},\"2410\":{\"m\":55,\"g\":94},\"2393\":{\"m\":55,\"g\":94},\"2413\":{\"m\":55,\"g\":94},\"2411\":{\"m\":55,\"g\":94},\"2398\":{\"m\":55,\"g\":94},\"2409\":{\"m\":55,\"g\":94},\"2408\":{\"m\":55,\"g\":94},\"2407\":{\"m\":55,\"g\":94},\"2406\":{\"m\":55,\"g\":94},\"2405\":{\"m\":55,\"g\":94},\"2404\":{\"m\":55,\"g\":94},\"2403\":{\"m\":
55,\"g\":94},\"2401\":{\"m\":55,\"g\":94},\"2397\":{\"m\":55,\"g\":94},\"2394\":{\"m\":55,\"g\":94},\"2382\":{\"m\":55,\"g\":94},\"2330\":{\"m\":55,\"g\":94},\"2392\":{\"m\":55,\"g\":94},\"2391\":{\"m\":55,\"g\":94},\"2388\":{\"m\":55,\"g\":94},\"2390\":{\"m\":55,\"g\":94},\"2387\":{\"m\":55,\"g\":94},\"2380\":{\"m\":55,\"g\":94},\"2379\":{\"m\":55,\"g\":94},\"2378\":{\"m\":55,\"g\":94},\"2582\":{\"m\":56,\"g\":94},\"2581\":{\"m\":56,\"g\":94},\"2580\":{\"m\":56,\"g\":94},\"2579\":{\"m\":56,\"g\":94},\"2575\":{\"m\":56,\"g\":94},\"2566\":{\"m\":56,\"g\":94},\"2563\":{\"m\":56,\"g\":94},\"2553\":{\"m\":56,\"g\":94},\"2545\":{\"m\":56,\"g\":94},\"2547\":{\"m\":56,\"g\":94},\"2543\":{\"m\":56,\"g\":94},\"2509\":{\"m\":56,\"g\":94},\"2523\":{\"m\":56,\"g\":94},\"2529\":{\"m\":56,\"g\":94},\"2541\":{\"m\":56,\"g\":94},\"2616\":{\"m\":57,\"g\":94},\"2617\":{\"m\":57,\"g\":94},\"2615\":{\"m\":57,\"g\":94},\"2610\":{\"m\":57,\"g\":94},\"2611\":{\"m\":57,\"g\":94},\"2612\":{\"m\":57,\"g\":94},\"2608\":{\"m\":57,\"g\":94},\"2606\":{\"m\":57,\"g\":94},\"2586\":{\"m\":57,\"g\":94},\"2605\":{\"m\":57,\"g\":94},\"2603\":{\"m\":57,\"g\":94},\"2564\":{\"m\":57,\"g\":94},\"2570\":{\"m\":57,\"g\":94},\"2521\":{\"m\":57,\"g\":94},\"2598\":{\"m\":57,\"g\":94},\"2574\":{\"m\":57,\"g\":94},\"2565\":{\"m\":57,\"g\":94},\"2597\":{\"m\":57,\"g\":94},\"2596\":{\"m\":57,\"g\":94},\"2526\":{\"m\":57,\"g\":94},\"2557\":{\"m\":57,\"g\":94},\"2594\":{\"m\":57,\"g\":94},\"2555\":{\"m\":57,\"g\":94},\"2560\":{\"m\":57,\"g\":94},\"2592\":{\"m\":57,\"g\":94},\"2590\":{\"m\":57,\"g\":94},\"2589\":{\"m\":57,\"g\":94},\"2643\":{\"m\":58,\"g\":94},\"2641\":{\"m\":58,\"g\":94},\"2637\":{\"m\":58,\"g\":94},\"2635\":{\"m\":58,\"g\":94},\"2640\":{\"m\":58,\"g\":94},\"2639\":{\"m\":58,\"g\":94},\"2638\":{\"m\":58,\"g\":94},\"2609\":{\"m\":58,\"g\":94},\"2544\":{\"m\":58,\"g\":94},\"2631\":{\"m\":58,\"g\":94},\"2628\":{\"m\":58,\"g\":94},\"2633\":{\"m\":58,\"g\":94},\"2614\":{\"m\":58,\"g\":94},\"2626\":{\"m\"
:58,\"g\":94},\"2624\":{\"m\":58,\"g\":94},\"2625\":{\"m\":58,\"g\":94},\"2623\":{\"m\":58,\"g\":94},\"2622\":{\"m\":58,\"g\":94},\"2475\":{\"m\":58,\"g\":94},\"2618\":{\"m\":58,\"g\":94},\"2647\":{\"m\":59,\"g\":94},\"2648\":{\"m\":59,\"g\":94},\"2646\":{\"m\":59,\"g\":94},\"2636\":{\"m\":59,\"g\":94},\"2645\":{\"m\":59,\"g\":94},\"2644\":{\"m\":59,\"g\":94},\"2713\":{\"m\":60,\"g\":94},\"2688\":{\"m\":60,\"g\":94},\"2735\":{\"m\":60,\"g\":94},\"2733\":{\"m\":60,\"g\":94},\"2731\":{\"m\":60,\"g\":94},\"2571\":{\"m\":60,\"g\":94},\"2726\":{\"m\":60,\"g\":94},\"2727\":{\"m\":60,\"g\":94},\"2722\":{\"m\":60,\"g\":94},\"2717\":{\"m\":60,\"g\":94},\"2601\":{\"m\":60,\"g\":94},\"2716\":{\"m\":60,\"g\":94},\"2711\":{\"m\":60,\"g\":94},\"2714\":{\"m\":60,\"g\":94},\"2704\":{\"m\":60,\"g\":94},\"2707\":{\"m\":60,\"g\":94},\"2712\":{\"m\":60,\"g\":94},\"2150\":{\"m\":60,\"g\":94},\"2709\":{\"m\":60,\"g\":94},\"2695\":{\"m\":60,\"g\":94},\"2705\":{\"m\":60,\"g\":94},\"2663\":{\"m\":60,\"g\":94},\"2697\":{\"m\":60,\"g\":94},\"2692\":{\"m\":60,\"g\":94},\"2691\":{\"m\":60,\"g\":94},\"2690\":{\"m\":60,\"g\":94},\"2689\":{\"m\":60,\"g\":94},\"2685\":{\"m\":60,\"g\":94},\"2684\":{\"m\":60,\"g\":94},\"2683\":{\"m\":60,\"g\":94},\"2682\":{\"m\":60,\"g\":94},\"2680\":{\"m\":60,\"g\":94},\"2678\":{\"m\":60,\"g\":94},\"2679\":{\"m\":60,\"g\":94},\"2676\":{\"m\":60,\"g\":94},\"2674\":{\"m\":60,\"g\":94},\"2672\":{\"m\":60,\"g\":94},\"2670\":{\"m\":60,\"g\":94},\"2667\":{\"m\":60,\"g\":94},\"2669\":{\"m\":60,\"g\":94},\"2664\":{\"m\":60,\"g\":94},\"2642\":{\"m\":60,\"g\":94},\"2666\":{\"m\":60,\"g\":94},\"2654\":{\"m\":60,\"g\":94},\"2655\":{\"m\":60,\"g\":94},\"2656\":{\"m\":60,\"g\":94},\"2652\":{\"m\":60,\"g\":94},\"2651\":{\"m\":60,\"g\":94},\"2650\":{\"m\":60,\"g\":94},\"2649\":{\"m\":60,\"g\":94},\"2840\":{\"m\":61,\"g\":94},\"2837\":{\"m\":61,\"g\":94},\"2836\":{\"m\":61,\"g\":94},\"2804\":{\"m\":61,\"g\":94},\"2730\":{\"m\":61,\"g\":94},\"2826\":{\"m\":61,\"g\":94},\"2835\":{\"m\
":61,\"g\":94},\"2822\":{\"m\":61,\"g\":94},\"2819\":{\"m\":61,\"g\":94},\"2833\":{\"m\":61,\"g\":94},\"2830\":{\"m\":61,\"g\":94},\"2787\":{\"m\":61,\"g\":94},\"2816\":{\"m\":61,\"g\":94},\"2813\":{\"m\":61,\"g\":94},\"2809\":{\"m\":61,\"g\":94},\"2792\":{\"m\":61,\"g\":94},\"2789\":{\"m\":61,\"g\":94},\"2773\":{\"m\":61,\"g\":94},\"2784\":{\"m\":61,\"g\":94},\"2723\":{\"m\":61,\"g\":94},\"2780\":{\"m\":61,\"g\":94},\"2779\":{\"m\":61,\"g\":94},\"2771\":{\"m\":61,\"g\":94},\"2774\":{\"m\":61,\"g\":94},\"2761\":{\"m\":61,\"g\":94},\"2770\":{\"m\":61,\"g\":94},\"2767\":{\"m\":61,\"g\":94},\"2758\":{\"m\":61,\"g\":94},\"2757\":{\"m\":61,\"g\":94},\"2513\":{\"m\":61,\"g\":94},\"2535\":{\"m\":61,\"g\":94},\"2756\":{\"m\":61,\"g\":94},\"2745\":{\"m\":61,\"g\":94},\"2748\":{\"m\":61,\"g\":94},\"2752\":{\"m\":61,\"g\":94},\"2751\":{\"m\":61,\"g\":94},\"2750\":{\"m\":61,\"g\":94},\"2899\":{\"m\":62,\"g\":94},\"2887\":{\"m\":62,\"g\":94},\"2888\":{\"m\":62,\"g\":94},\"2881\":{\"m\":62,\"g\":94},\"2885\":{\"m\":62,\"g\":94},\"2879\":{\"m\":62,\"g\":94},\"2878\":{\"m\":62,\"g\":94},\"2875\":{\"m\":62,\"g\":94},\"2630\":{\"m\":62,\"g\":94},\"2870\":{\"m\":62,\"g\":94},\"2869\":{\"m\":62,\"g\":94},\"2863\":{\"m\":62,\"g\":94},\"2868\":{\"m\":62,\"g\":94},\"2867\":{\"m\":62,\"g\":94},\"2862\":{\"m\":62,\"g\":94},\"2866\":{\"m\":62,\"g\":94},\"2865\":{\"m\":62,\"g\":94},\"2828\":{\"m\":62,\"g\":94},\"2859\":{\"m\":62,\"g\":94},\"2858\":{\"m\":62,\"g\":94},\"2861\":{\"m\":62,\"g\":94},\"2857\":{\"m\":62,\"g\":94},\"2860\":{\"m\":62,\"g\":94},\"2856\":{\"m\":62,\"g\":94},\"2851\":{\"m\":62,\"g\":94},\"2853\":{\"m\":62,\"g\":94},\"2854\":{\"m\":62,\"g\":94},\"2852\":{\"m\":62,\"g\":94},\"2786\":{\"m\":62,\"g\":94},\"2848\":{\"m\":62,\"g\":94},\"2850\":{\"m\":62,\"g\":94},\"2846\":{\"m\":62,\"g\":94},\"2843\":{\"m\":62,\"g\":94},\"2841\":{\"m\":62,\"g\":94},\"3009\":{\"m\":63,\"g\":94},\"2993\":{\"m\":63,\"g\":94},\"3010\":{\"m\":63,\"g\":94},\"3006\":{\"m\":63,\"g\":94},\"3008\":{\"m
\":63,\"g\":94},\"3003\":{\"m\":63,\"g\":94},\"2998\":{\"m\":63,\"g\":94},\"3005\":{\"m\":63,\"g\":94},\"3004\":{\"m\":63,\"g\":94},\"3001\":{\"m\":63,\"g\":94},\"2996\":{\"m\":63,\"g\":94},\"2997\":{\"m\":63,\"g\":94},\"2995\":{\"m\":63,\"g\":94},\"2991\":{\"m\":63,\"g\":94},\"2992\":{\"m\":63,\"g\":94},\"2990\":{\"m\":63,\"g\":94},\"2396\":{\"m\":63,\"g\":94},\"2983\":{\"m\":63,\"g\":94},\"2988\":{\"m\":63,\"g\":94},\"2839\":{\"m\":63,\"g\":94},\"2982\":{\"m\":63,\"g\":94},\"2986\":{\"m\":63,\"g\":94},\"2987\":{\"m\":63,\"g\":94},\"2985\":{\"m\":63,\"g\":94},\"2984\":{\"m\":63,\"g\":94},\"2981\":{\"m\":63,\"g\":94},\"2975\":{\"m\":63,\"g\":94},\"2980\":{\"m\":63,\"g\":94},\"2979\":{\"m\":63,\"g\":94},\"2978\":{\"m\":63,\"g\":94},\"2976\":{\"m\":63,\"g\":94},\"2974\":{\"m\":63,\"g\":94},\"2973\":{\"m\":63,\"g\":94},\"2972\":{\"m\":63,\"g\":94},\"2958\":{\"m\":63,\"g\":94},\"2956\":{\"m\":63,\"g\":94},\"2901\":{\"m\":63,\"g\":94},\"2971\":{\"m\":63,\"g\":94},\"2785\":{\"m\":63,\"g\":94},\"2966\":{\"m\":63,\"g\":94},\"2967\":{\"m\":63,\"g\":94},\"2964\":{\"m\":63,\"g\":94},\"2963\":{\"m\":63,\"g\":94},\"2960\":{\"m\":63,\"g\":94},\"2941\":{\"m\":63,\"g\":94},\"2894\":{\"m\":63,\"g\":94},\"2942\":{\"m\":63,\"g\":94},\"2944\":{\"m\":63,\"g\":94},\"2954\":{\"m\":63,\"g\":94},\"2952\":{\"m\":63,\"g\":94},\"2951\":{\"m\":63,\"g\":94},\"2947\":{\"m\":63,\"g\":94},\"2948\":{\"m\":63,\"g\":94},\"2949\":{\"m\":63,\"g\":94},\"2950\":{\"m\":63,\"g\":94},\"2945\":{\"m\":63,\"g\":94},\"2907\":{\"m\":63,\"g\":94},\"2938\":{\"m\":63,\"g\":94},\"2937\":{\"m\":63,\"g\":94},\"2806\":{\"m\":63,\"g\":94},\"2876\":{\"m\":63,\"g\":94},\"2930\":{\"m\":63,\"g\":94},\"2928\":{\"m\":63,\"g\":94},\"2920\":{\"m\":63,\"g\":94},\"2926\":{\"m\":63,\"g\":94},\"2927\":{\"m\":63,\"g\":94},\"2925\":{\"m\":63,\"g\":94},\"2924\":{\"m\":63,\"g\":94},\"2923\":{\"m\":63,\"g\":94},\"2821\":{\"m\":63,\"g\":94},\"2910\":{\"m\":63,\"g\":94},\"2922\":{\"m\":63,\"g\":94},\"2919\":{\"m\":63,\"g\":94},\"2917\":{\"
m\":63,\"g\":94},\"2915\":{\"m\":63,\"g\":94},\"2911\":{\"m\":63,\"g\":94},\"2909\":{\"m\":63,\"g\":94},\"2908\":{\"m\":63,\"g\":94},\"2511\":{\"m\":63,\"g\":94},\"2906\":{\"m\":63,\"g\":94},\"2904\":{\"m\":63,\"g\":94},\"2902\":{\"m\":63,\"g\":94},\"2872\":{\"m\":63,\"g\":94},\"2897\":{\"m\":63,\"g\":94},\"3180\":{\"m\":64,\"g\":94},\"3179\":{\"m\":64,\"g\":94},\"3178\":{\"m\":64,\"g\":94},\"3175\":{\"m\":64,\"g\":94},\"3176\":{\"m\":64,\"g\":94},\"3174\":{\"m\":64,\"g\":94},\"3146\":{\"m\":64,\"g\":94},\"3173\":{\"m\":64,\"g\":94},\"3170\":{\"m\":64,\"g\":94},\"3167\":{\"m\":64,\"g\":94},\"3156\":{\"m\":64,\"g\":94},\"3134\":{\"m\":64,\"g\":94},\"3162\":{\"m\":64,\"g\":94},\"3144\":{\"m\":64,\"g\":94},\"3155\":{\"m\":64,\"g\":94},\"2700\":{\"m\":64,\"g\":94},\"3154\":{\"m\":64,\"g\":94},\"3153\":{\"m\":64,\"g\":94},\"3152\":{\"m\":64,\"g\":94},\"3150\":{\"m\":64,\"g\":94},\"3151\":{\"m\":64,\"g\":94},\"3149\":{\"m\":64,\"g\":94},\"3147\":{\"m\":64,\"g\":94},\"3145\":{\"m\":64,\"g\":94},\"3085\":{\"m\":64,\"g\":94},\"3047\":{\"m\":64,\"g\":94},\"3143\":{\"m\":64,\"g\":94},\"3139\":{\"m\":64,\"g\":94},\"3135\":{\"m\":64,\"g\":94},\"3138\":{\"m\":64,\"g\":94},\"3113\":{\"m\":64,\"g\":94},\"3133\":{\"m\":64,\"g\":94},\"3132\":{\"m\":64,\"g\":94},\"3130\":{\"m\":64,\"g\":94},\"3129\":{\"m\":64,\"g\":94},\"3128\":{\"m\":64,\"g\":94},\"3127\":{\"m\":64,\"g\":94},\"3126\":{\"m\":64,\"g\":94},\"3125\":{\"m\":64,\"g\":94},\"3124\":{\"m\":64,\"g\":94},\"3121\":{\"m\":64,\"g\":94},\"3110\":{\"m\":64,\"g\":94},\"3109\":{\"m\":64,\"g\":94},\"3107\":{\"m\":64,\"g\":94},\"3096\":{\"m\":64,\"g\":94},\"3037\":{\"m\":64,\"g\":94},\"3105\":{\"m\":64,\"g\":94},\"3097\":{\"m\":64,\"g\":94},\"3095\":{\"m\":64,\"g\":94},\"3094\":{\"m\":64,\"g\":94},\"3070\":{\"m\":64,\"g\":94},\"3093\":{\"m\":64,\"g\":94},\"2742\":{\"m\":64,\"g\":94},\"3087\":{\"m\":64,\"g\":94},\"3086\":{\"m\":64,\"g\":94},\"3084\":{\"m\":64,\"g\":94},\"3083\":{\"m\":64,\"g\":94},\"3081\":{\"m\":64,\"g\":94},\"3080\":{\
"m\":64,\"g\":94},\"3079\":{\"m\":64,\"g\":94},\"3078\":{\"m\":64,\"g\":94},\"3074\":{\"m\":64,\"g\":94},\"3030\":{\"m\":64,\"g\":94},\"3071\":{\"m\":64,\"g\":94},\"3069\":{\"m\":64,\"g\":94},\"3068\":{\"m\":64,\"g\":94},\"3067\":{\"m\":64,\"g\":94},\"3061\":{\"m\":64,\"g\":94},\"3062\":{\"m\":64,\"g\":94},\"3063\":{\"m\":64,\"g\":94},\"2989\":{\"m\":64,\"g\":94},\"3060\":{\"m\":64,\"g\":94},\"3045\":{\"m\":64,\"g\":94},\"3038\":{\"m\":64,\"g\":94},\"3058\":{\"m\":64,\"g\":94},\"3057\":{\"m\":64,\"g\":94},\"3055\":{\"m\":64,\"g\":94},\"3056\":{\"m\":64,\"g\":94},\"3054\":{\"m\":64,\"g\":94},\"3053\":{\"m\":64,\"g\":94},\"3052\":{\"m\":64,\"g\":94},\"3051\":{\"m\":64,\"g\":94},\"3048\":{\"m\":64,\"g\":94},\"3046\":{\"m\":64,\"g\":94},\"3039\":{\"m\":64,\"g\":94},\"3036\":{\"m\":64,\"g\":94},\"3035\":{\"m\":64,\"g\":94},\"3033\":{\"m\":64,\"g\":94},\"3027\":{\"m\":64,\"g\":94},\"3026\":{\"m\":64,\"g\":94},\"3025\":{\"m\":64,\"g\":94},\"3014\":{\"m\":64,\"g\":94},\"3018\":{\"m\":64,\"g\":94},\"2939\":{\"m\":64,\"g\":94},\"3022\":{\"m\":64,\"g\":94},\"3021\":{\"m\":64,\"g\":94},\"3017\":{\"m\":64,\"g\":94},\"3015\":{\"m\":64,\"g\":94},\"3020\":{\"m\":64,\"g\":94},\"3019\":{\"m\":64,\"g\":94},\"3016\":{\"m\":64,\"g\":94},\"3013\":{\"m\":64,\"g\":94},\"3012\":{\"m\":64,\"g\":94},\"3233\":{\"m\":65,\"g\":94},\"3231\":{\"m\":65,\"g\":94},\"3232\":{\"m\":65,\"g\":94},\"3224\":{\"m\":65,\"g\":94},\"3230\":{\"m\":65,\"g\":94},\"3229\":{\"m\":65,\"g\":94},\"3227\":{\"m\":65,\"g\":94},\"3218\":{\"m\":65,\"g\":94},\"3217\":{\"m\":65,\"g\":94},\"3216\":{\"m\":65,\"g\":94},\"3214\":{\"m\":65,\"g\":94},\"3213\":{\"m\":65,\"g\":94},\"3212\":{\"m\":65,\"g\":94},\"2977\":{\"m\":65,\"g\":94},\"3190\":{\"m\":65,\"g\":94},\"3169\":{\"m\":65,\"g\":94},\"3192\":{\"m\":65,\"g\":94},\"3183\":{\"m\":65,\"g\":94},\"3166\":{\"m\":65,\"g\":94},\"3171\":{\"m\":65,\"g\":94},\"3181\":{\"m\":65,\"g\":94},\"3313\":{\"m\":66,\"g\":94},\"3312\":{\"m\":66,\"g\":94},\"3306\":{\"m\":66,\"g\":94},\"3305\":{
\"m\":66,\"g\":94},\"3299\":{\"m\":66,\"g\":94},\"3294\":{\"m\":66,\"g\":94},\"3293\":{\"m\":66,\"g\":94},\"3292\":{\"m\":66,\"g\":94},\"3287\":{\"m\":66,\"g\":94},\"3288\":{\"m\":66,\"g\":94},\"3161\":{\"m\":66,\"g\":94},\"3273\":{\"m\":66,\"g\":94},\"3276\":{\"m\":66,\"g\":94},\"3274\":{\"m\":66,\"g\":94},\"3272\":{\"m\":66,\"g\":94},\"3270\":{\"m\":66,\"g\":94},\"3269\":{\"m\":66,\"g\":94},\"3268\":{\"m\":66,\"g\":94},\"3205\":{\"m\":66,\"g\":94},\"3207\":{\"m\":66,\"g\":94},\"3259\":{\"m\":66,\"g\":94},\"3261\":{\"m\":66,\"g\":94},\"3114\":{\"m\":66,\"g\":94},\"3255\":{\"m\":66,\"g\":94},\"3252\":{\"m\":66,\"g\":94},\"3251\":{\"m\":66,\"g\":94},\"3250\":{\"m\":66,\"g\":94},\"3249\":{\"m\":66,\"g\":94},\"3248\":{\"m\":66,\"g\":94},\"3246\":{\"m\":66,\"g\":94},\"3221\":{\"m\":66,\"g\":94},\"3242\":{\"m\":66,\"g\":94},\"3240\":{\"m\":66,\"g\":94},\"3238\":{\"m\":66,\"g\":94},\"3236\":{\"m\":66,\"g\":94},\"3235\":{\"m\":66,\"g\":94},\"3228\":{\"m\":66,\"g\":94},\"3369\":{\"m\":67,\"g\":94},\"3378\":{\"m\":67,\"g\":94},\"3376\":{\"m\":67,\"g\":94},\"3374\":{\"m\":67,\"g\":94},\"3373\":{\"m\":67,\"g\":94},\"3372\":{\"m\":67,\"g\":94},\"3366\":{\"m\":67,\"g\":94},\"3356\":{\"m\":67,\"g\":94},\"3300\":{\"m\":67,\"g\":94},\"3355\":{\"m\":67,\"g\":94},\"3314\":{\"m\":67,\"g\":94},\"3350\":{\"m\":67,\"g\":94},\"3347\":{\"m\":67,\"g\":94},\"3352\":{\"m\":67,\"g\":94},\"3349\":{\"m\":67,\"g\":94},\"3338\":{\"m\":67,\"g\":94},\"3337\":{\"m\":67,\"g\":94},\"3335\":{\"m\":67,\"g\":94},\"3332\":{\"m\":67,\"g\":94},\"3325\":{\"m\":67,\"g\":94},\"3327\":{\"m\":67,\"g\":94},\"3324\":{\"m\":67,\"g\":94},\"3168\":{\"m\":67,\"g\":94},\"3317\":{\"m\":67,\"g\":94},\"3309\":{\"m\":67,\"g\":94},\"3459\":{\"m\":68,\"g\":94},\"3457\":{\"m\":68,\"g\":94},\"3413\":{\"m\":68,\"g\":94},\"3453\":{\"m\":68,\"g\":94},\"3452\":{\"m\":68,\"g\":94},\"3442\":{\"m\":68,\"g\":94},\"3441\":{\"m\":68,\"g\":94},\"3440\":{\"m\":68,\"g\":94},\"3439\":{\"m\":68,\"g\":94},\"3437\":{\"m\":68,\"g\":94},\"3410\":
{\"m\":68,\"g\":94},\"3435\":{\"m\":68,\"g\":94},\"3433\":{\"m\":68,\"g\":94},\"3431\":{\"m\":68,\"g\":94},\"3430\":{\"m\":68,\"g\":94},\"3425\":{\"m\":68,\"g\":94},\"3422\":{\"m\":68,\"g\":94},\"3421\":{\"m\":68,\"g\":94},\"3408\":{\"m\":68,\"g\":94},\"3415\":{\"m\":68,\"g\":94},\"3411\":{\"m\":68,\"g\":94},\"3412\":{\"m\":68,\"g\":94},\"3407\":{\"m\":68,\"g\":94},\"3404\":{\"m\":68,\"g\":94},\"3346\":{\"m\":68,\"g\":94},\"3382\":{\"m\":68,\"g\":94},\"3386\":{\"m\":68,\"g\":94},\"3275\":{\"m\":68,\"g\":94},\"3556\":{\"m\":69,\"g\":94},\"3555\":{\"m\":69,\"g\":94},\"3550\":{\"m\":69,\"g\":94},\"3553\":{\"m\":69,\"g\":94},\"3543\":{\"m\":69,\"g\":94},\"3534\":{\"m\":69,\"g\":94},\"3541\":{\"m\":69,\"g\":94},\"3536\":{\"m\":69,\"g\":94},\"3529\":{\"m\":69,\"g\":94},\"3530\":{\"m\":69,\"g\":94},\"3267\":{\"m\":69,\"g\":94},\"3450\":{\"m\":69,\"g\":94},\"3503\":{\"m\":69,\"g\":94},\"3493\":{\"m\":69,\"g\":94},\"3523\":{\"m\":69,\"g\":94},\"3522\":{\"m\":69,\"g\":94},\"3405\":{\"m\":69,\"g\":94},\"3502\":{\"m\":69,\"g\":94},\"3420\":{\"m\":69,\"g\":94},\"3495\":{\"m\":69,\"g\":94},\"3496\":{\"m\":69,\"g\":94},\"3499\":{\"m\":69,\"g\":94},\"3498\":{\"m\":69,\"g\":94},\"3497\":{\"m\":69,\"g\":94},\"3500\":{\"m\":69,\"g\":94},\"3492\":{\"m\":69,\"g\":94},\"3490\":{\"m\":69,\"g\":94},\"3418\":{\"m\":69,\"g\":94},\"3364\":{\"m\":69,\"g\":94},\"3473\":{\"m\":69,\"g\":94},\"3469\":{\"m\":69,\"g\":94},\"3468\":{\"m\":69,\"g\":94},\"3466\":{\"m\":69,\"g\":94},\"3638\":{\"m\":70,\"g\":94},\"3636\":{\"m\":70,\"g\":94},\"3634\":{\"m\":70,\"g\":94},\"3632\":{\"m\":70,\"g\":94},\"3619\":{\"m\":70,\"g\":94},\"3617\":{\"m\":70,\"g\":94},\"3260\":{\"m\":70,\"g\":94},\"3532\":{\"m\":70,\"g\":94},\"3597\":{\"m\":70,\"g\":94},\"3535\":{\"m\":70,\"g\":94},\"3258\":{\"m\":70,\"g\":94},\"3598\":{\"m\":70,\"g\":94},\"3564\":{\"m\":70,\"g\":94},\"3594\":{\"m\":70,\"g\":94},\"3592\":{\"m\":70,\"g\":94},\"3591\":{\"m\":70,\"g\":94},\"3589\":{\"m\":70,\"g\":94},\"3587\":{\"m\":70,\"g\":94},\"3582\"
:{\"m\":70,\"g\":94},\"3548\":{\"m\":70,\"g\":94},\"3584\":{\"m\":70,\"g\":94},\"3505\":{\"m\":70,\"g\":94},\"3581\":{\"m\":70,\"g\":94},\"3563\":{\"m\":70,\"g\":94},\"3363\":{\"m\":70,\"g\":94},\"3558\":{\"m\":70,\"g\":94},\"3557\":{\"m\":70,\"g\":94},\"3645\":{\"m\":71,\"g\":94},\"3644\":{\"m\":71,\"g\":94},\"3643\":{\"m\":71,\"g\":94},\"3639\":{\"m\":71,\"g\":94},\"4114\":{\"m\":72,\"g\":94},\"4099\":{\"m\":72,\"g\":94},\"4101\":{\"m\":72,\"g\":94},\"3211\":{\"m\":72,\"g\":94},\"3029\":{\"m\":72,\"g\":94},\"4110\":{\"m\":72,\"g\":94},\"4105\":{\"m\":72,\"g\":94},\"4109\":{\"m\":72,\"g\":94},\"4103\":{\"m\":72,\"g\":94},\"4108\":{\"m\":72,\"g\":94},\"4107\":{\"m\":72,\"g\":94},\"4102\":{\"m\":72,\"g\":94},\"4100\":{\"m\":72,\"g\":94},\"4012\":{\"m\":72,\"g\":94},\"4081\":{\"m\":72,\"g\":94},\"3986\":{\"m\":72,\"g\":94},\"4075\":{\"m\":72,\"g\":94},\"3990\":{\"m\":72,\"g\":94},\"3790\":{\"m\":72,\"g\":94},\"3941\":{\"m\":72,\"g\":94},\"4077\":{\"m\":72,\"g\":94},\"4074\":{\"m\":72,\"g\":94},\"4066\":{\"m\":72,\"g\":94},\"4071\":{\"m\":72,\"g\":94},\"4065\":{\"m\":72,\"g\":94},\"4016\":{\"m\":72,\"g\":94},\"3607\":{\"m\":72,\"g\":94},\"3712\":{\"m\":72,\"g\":94},\"3948\":{\"m\":72,\"g\":94},\"3954\":{\"m\":72,\"g\":94},\"4030\":{\"m\":72,\"g\":94},\"4051\":{\"m\":72,\"g\":94},\"4046\":{\"m\":72,\"g\":94},\"4053\":{\"m\":72,\"g\":94},\"4052\":{\"m\":72,\"g\":94},\"4023\":{\"m\":72,\"g\":94},\"4000\":{\"m\":72,\"g\":94},\"4049\":{\"m\":72,\"g\":94},\"4044\":{\"m\":72,\"g\":94},\"4043\":{\"m\":72,\"g\":94},\"4033\":{\"m\":72,\"g\":94},\"4039\":{\"m\":72,\"g\":94},\"3999\":{\"m\":72,\"g\":94},\"4034\":{\"m\":72,\"g\":94},\"4032\":{\"m\":72,\"g\":94},\"4027\":{\"m\":72,\"g\":94},\"4025\":{\"m\":72,\"g\":94},\"3264\":{\"m\":72,\"g\":94},\"4031\":{\"m\":72,\"g\":94},\"4029\":{\"m\":72,\"g\":94},\"4021\":{\"m\":72,\"g\":94},\"3988\":{\"m\":72,\"g\":94},\"4014\":{\"m\":72,\"g\":94},\"3826\":{\"m\":72,\"g\":94},\"4010\":{\"m\":72,\"g\":94},\"4008\":{\"m\":72,\"g\":94},\"3987\
":{\"m\":72,\"g\":94},\"3822\":{\"m\":72,\"g\":94},\"3993\":{\"m\":72,\"g\":94},\"3406\":{\"m\":72,\"g\":94},\"3994\":{\"m\":72,\"g\":94},\"3992\":{\"m\":72,\"g\":94},\"3991\":{\"m\":72,\"g\":94},\"3893\":{\"m\":72,\"g\":94},\"3989\":{\"m\":72,\"g\":94},\"3985\":{\"m\":72,\"g\":94},\"3982\":{\"m\":72,\"g\":94},\"3979\":{\"m\":72,\"g\":94},\"3976\":{\"m\":72,\"g\":94},\"3977\":{\"m\":72,\"g\":94},\"3975\":{\"m\":72,\"g\":94},\"3967\":{\"m\":72,\"g\":94},\"3966\":{\"m\":72,\"g\":94},\"3963\":{\"m\":72,\"g\":94},\"3852\":{\"m\":72,\"g\":94},\"3566\":{\"m\":72,\"g\":94},\"3678\":{\"m\":72,\"g\":94},\"3950\":{\"m\":72,\"g\":94},\"3870\":{\"m\":72,\"g\":94},\"3613\":{\"m\":72,\"g\":94},\"3866\":{\"m\":72,\"g\":94},\"3861\":{\"m\":72,\"g\":94},\"3934\":{\"m\":72,\"g\":94},\"3933\":{\"m\":72,\"g\":94},\"3593\":{\"m\":72,\"g\":94},\"3922\":{\"m\":72,\"g\":94},\"3925\":{\"m\":72,\"g\":94},\"3914\":{\"m\":72,\"g\":94},\"3905\":{\"m\":72,\"g\":94},\"3897\":{\"m\":72,\"g\":94},\"3907\":{\"m\":72,\"g\":94},\"3898\":{\"m\":72,\"g\":94},\"3903\":{\"m\":72,\"g\":94},\"3791\":{\"m\":72,\"g\":94},\"3900\":{\"m\":72,\"g\":94},\"3894\":{\"m\":72,\"g\":94},\"3298\":{\"m\":72,\"g\":94},\"3860\":{\"m\":72,\"g\":94},\"3845\":{\"m\":72,\"g\":94},\"3602\":{\"m\":72,\"g\":94},\"3865\":{\"m\":72,\"g\":94},\"3843\":{\"m\":72,\"g\":94},\"3519\":{\"m\":72,\"g\":94},\"3841\":{\"m\":72,\"g\":94},\"3857\":{\"m\":72,\"g\":94},\"3709\":{\"m\":72,\"g\":94},\"3803\":{\"m\":72,\"g\":94},\"3787\":{\"m\":72,\"g\":94},\"3237\":{\"m\":72,\"g\":94},\"3641\":{\"m\":72,\"g\":94},\"3741\":{\"m\":72,\"g\":94},\"3799\":{\"m\":72,\"g\":94},\"3801\":{\"m\":72,\"g\":94},\"3821\":{\"m\":72,\"g\":94},\"3829\":{\"m\":72,\"g\":94},\"3828\":{\"m\":72,\"g\":94},\"3818\":{\"m\":72,\"g\":94},\"3730\":{\"m\":72,\"g\":94},\"3785\":{\"m\":72,\"g\":94},\"3813\":{\"m\":72,\"g\":94},\"3809\":{\"m\":72,\"g\":94},\"2693\":{\"m\":72,\"g\":94},\"3795\":{\"m\":72,\"g\":94},\"3116\":{\"m\":72,\"g\":94},\"3115\":{\"m\":72,\"g\":94},\"3562
\":{\"m\":72,\"g\":94},\"3117\":{\"m\":72,\"g\":94},\"3766\":{\"m\":72,\"g\":94},\"3777\":{\"m\":72,\"g\":94},\"3348\":{\"m\":72,\"g\":94},\"3223\":{\"m\":72,\"g\":94},\"3772\":{\"m\":72,\"g\":94},\"3773\":{\"m\":72,\"g\":94},\"3771\":{\"m\":72,\"g\":94},\"3733\":{\"m\":72,\"g\":94},\"3754\":{\"m\":72,\"g\":94},\"3432\":{\"m\":72,\"g\":94},\"3761\":{\"m\":72,\"g\":94},\"3740\":{\"m\":72,\"g\":94},\"3680\":{\"m\":72,\"g\":94},\"3747\":{\"m\":72,\"g\":94},\"3737\":{\"m\":72,\"g\":94},\"3588\":{\"m\":72,\"g\":94},\"3652\":{\"m\":72,\"g\":94},\"3732\":{\"m\":72,\"g\":94},\"3705\":{\"m\":72,\"g\":94},\"3731\":{\"m\":72,\"g\":94},\"3722\":{\"m\":72,\"g\":94},\"3727\":{\"m\":72,\"g\":94},\"3710\":{\"m\":72,\"g\":94},\"3677\":{\"m\":72,\"g\":94},\"3601\":{\"m\":72,\"g\":94},\"3692\":{\"m\":72,\"g\":94},\"3700\":{\"m\":72,\"g\":94},\"3706\":{\"m\":72,\"g\":94},\"3628\":{\"m\":72,\"g\":94},\"3698\":{\"m\":72,\"g\":94},\"3657\":{\"m\":72,\"g\":94},\"3665\":{\"m\":72,\"g\":94},\"3635\":{\"m\":72,\"g\":94},\"3676\":{\"m\":72,\"g\":94},\"3663\":{\"m\":72,\"g\":94},\"3629\":{\"m\":72,\"g\":94},\"3654\":{\"m\":72,\"g\":94},\"3624\":{\"m\":72,\"g\":94},\"3567\":{\"m\":72,\"g\":94},\"3616\":{\"m\":72,\"g\":94},\"3650\":{\"m\":72,\"g\":94},\"4140\":{\"m\":73,\"g\":94},\"4142\":{\"m\":73,\"g\":94},\"4147\":{\"m\":73,\"g\":94},\"4134\":{\"m\":73,\"g\":94},\"4135\":{\"m\":73,\"g\":94},\"4138\":{\"m\":73,\"g\":94},\"4137\":{\"m\":73,\"g\":94},\"4132\":{\"m\":73,\"g\":94},\"4128\":{\"m\":73,\"g\":94},\"4126\":{\"m\":73,\"g\":94},\"4129\":{\"m\":73,\"g\":94},\"4131\":{\"m\":73,\"g\":94},\"4038\":{\"m\":73,\"g\":94},\"4111\":{\"m\":73,\"g\":94},\"4113\":{\"m\":73,\"g\":94},\"4117\":{\"m\":73,\"g\":94},\"4121\":{\"m\":73,\"g\":94},\"4041\":{\"m\":74,\"g\":94},\"4381\":{\"m\":74,\"g\":94},\"4377\":{\"m\":74,\"g\":94},\"4376\":{\"m\":74,\"g\":94},\"4374\":{\"m\":74,\"g\":94},\"4367\":{\"m\":74,\"g\":94},\"4086\":{\"m\":74,\"g\":94},\"4356\":{\"m\":74,\"g\":94},\"3679\":{\"m\":74,\"g\":94},\"381
4\":{\"m\":74,\"g\":94},\"3835\":{\"m\":74,\"g\":94},\"3844\":{\"m\":74,\"g\":94},\"3896\":{\"m\":74,\"g\":94},\"3959\":{\"m\":74,\"g\":94},\"3980\":{\"m\":74,\"g\":94},\"3962\":{\"m\":74,\"g\":94},\"3961\":{\"m\":74,\"g\":94},\"4026\":{\"m\":74,\"g\":94},\"4079\":{\"m\":74,\"g\":94},\"4359\":{\"m\":74,\"g\":94},\"4212\":{\"m\":74,\"g\":94},\"4295\":{\"m\":74,\"g\":94},\"4342\":{\"m\":74,\"g\":94},\"4335\":{\"m\":74,\"g\":94},\"4329\":{\"m\":74,\"g\":94},\"4362\":{\"m\":74,\"g\":94},\"4348\":{\"m\":74,\"g\":94},\"4326\":{\"m\":74,\"g\":94},\"4354\":{\"m\":74,\"g\":94},\"4355\":{\"m\":74,\"g\":94},\"4352\":{\"m\":74,\"g\":94},\"4350\":{\"m\":74,\"g\":94},\"4278\":{\"m\":74,\"g\":94},\"4082\":{\"m\":74,\"g\":94},\"3203\":{\"m\":74,\"g\":94},\"4340\":{\"m\":74,\"g\":94},\"4337\":{\"m\":74,\"g\":94},\"3911\":{\"m\":74,\"g\":94},\"4334\":{\"m\":74,\"g\":94},\"4331\":{\"m\":74,\"g\":94},\"4333\":{\"m\":74,\"g\":94},\"4104\":{\"m\":74,\"g\":94},\"4215\":{\"m\":74,\"g\":94},\"4327\":{\"m\":74,\"g\":94},\"4297\":{\"m\":74,\"g\":94},\"4323\":{\"m\":74,\"g\":94},\"4321\":{\"m\":74,\"g\":94},\"4317\":{\"m\":74,\"g\":94},\"4229\":{\"m\":74,\"g\":94},\"4220\":{\"m\":74,\"g\":94},\"4311\":{\"m\":74,\"g\":94},\"4299\":{\"m\":74,\"g\":94},\"4287\":{\"m\":74,\"g\":94},\"4290\":{\"m\":74,\"g\":94},\"4199\":{\"m\":74,\"g\":94},\"4291\":{\"m\":74,\"g\":94},\"4136\":{\"m\":74,\"g\":94},\"4288\":{\"m\":74,\"g\":94},\"4284\":{\"m\":74,\"g\":94},\"4277\":{\"m\":74,\"g\":94},\"4279\":{\"m\":74,\"g\":94},\"4275\":{\"m\":74,\"g\":94},\"4272\":{\"m\":74,\"g\":94},\"4261\":{\"m\":74,\"g\":94},\"4256\":{\"m\":74,\"g\":94},\"4267\":{\"m\":74,\"g\":94},\"4262\":{\"m\":74,\"g\":94},\"4258\":{\"m\":74,\"g\":94},\"4255\":{\"m\":74,\"g\":94},\"4231\":{\"m\":74,\"g\":94},\"4144\":{\"m\":74,\"g\":94},\"4252\":{\"m\":74,\"g\":94},\"4206\":{\"m\":74,\"g\":94},\"3958\":{\"m\":74,\"g\":94},\"4165\":{\"m\":74,\"g\":94},\"4250\":{\"m\":74,\"g\":94},\"4230\":{\"m\":74,\"g\":94},\"4238\":{\"m\":74,\"g\":94},\"42
43\":{\"m\":74,\"g\":94},\"4242\":{\"m\":74,\"g\":94},\"4241\":{\"m\":74,\"g\":94},\"4237\":{\"m\":74,\"g\":94},\"4235\":{\"m\":74,\"g\":94},\"4228\":{\"m\":74,\"g\":94},\"4217\":{\"m\":74,\"g\":94},\"3148\":{\"m\":74,\"g\":94},\"4218\":{\"m\":74,\"g\":94},\"4224\":{\"m\":74,\"g\":94},\"4193\":{\"m\":74,\"g\":94},\"3631\":{\"m\":74,\"g\":94},\"4225\":{\"m\":74,\"g\":94},\"4223\":{\"m\":74,\"g\":94},\"4213\":{\"m\":74,\"g\":94},\"4222\":{\"m\":74,\"g\":94},\"4219\":{\"m\":74,\"g\":94},\"4124\":{\"m\":74,\"g\":94},\"4216\":{\"m\":74,\"g\":94},\"4211\":{\"m\":74,\"g\":94},\"4210\":{\"m\":74,\"g\":94},\"3749\":{\"m\":74,\"g\":94},\"4203\":{\"m\":74,\"g\":94},\"4181\":{\"m\":74,\"g\":94},\"4200\":{\"m\":74,\"g\":94},\"4198\":{\"m\":74,\"g\":94},\"4197\":{\"m\":74,\"g\":94},\"4164\":{\"m\":74,\"g\":94},\"4195\":{\"m\":74,\"g\":94},\"4194\":{\"m\":74,\"g\":94},\"4189\":{\"m\":74,\"g\":94},\"4185\":{\"m\":74,\"g\":94},\"4187\":{\"m\":74,\"g\":94},\"4186\":{\"m\":74,\"g\":94},\"4178\":{\"m\":74,\"g\":94},\"4174\":{\"m\":74,\"g\":94},\"4179\":{\"m\":74,\"g\":94},\"4177\":{\"m\":74,\"g\":94},\"4176\":{\"m\":74,\"g\":94},\"4170\":{\"m\":74,\"g\":94},\"4168\":{\"m\":74,\"g\":94},\"4166\":{\"m\":74,\"g\":94},\"4162\":{\"m\":74,\"g\":94},\"4163\":{\"m\":74,\"g\":94},\"3888\":{\"m\":74,\"g\":94},\"4089\":{\"m\":74,\"g\":94},\"3786\":{\"m\":74,\"g\":94},\"3694\":{\"m\":74,\"g\":94},\"4152\":{\"m\":74,\"g\":94},\"4154\":{\"m\":74,\"g\":94},\"4151\":{\"m\":74,\"g\":94},\"4148\":{\"m\":74,\"g\":94},\"4402\":{\"m\":75,\"g\":94},\"4397\":{\"m\":75,\"g\":94},\"4399\":{\"m\":75,\"g\":94},\"4269\":{\"m\":75,\"g\":94},\"4320\":{\"m\":75,\"g\":94},\"4398\":{\"m\":75,\"g\":94},\"4393\":{\"m\":75,\"g\":94},\"4390\":{\"m\":75,\"g\":94},\"4392\":{\"m\":75,\"g\":94},\"4375\":{\"m\":75,\"g\":94},\"4669\":{\"m\":76,\"g\":94},\"4743\":{\"m\":76,\"g\":94},\"4797\":{\"m\":76,\"g\":94},\"4782\":{\"m\":76,\"g\":94},\"4777\":{\"m\":76,\"g\":94},\"4775\":{\"m\":76,\"g\":94},\"4784\":{\"m\":76,\"g\":94},\"4
566\":{\"m\":76,\"g\":94},\"4728\":{\"m\":76,\"g\":94},\"4755\":{\"m\":76,\"g\":94},\"4753\":{\"m\":76,\"g\":94},\"4752\":{\"m\":76,\"g\":94},\"4751\":{\"m\":76,\"g\":94},\"4310\":{\"m\":76,\"g\":94},\"4705\":{\"m\":76,\"g\":94},\"4435\":{\"m\":76,\"g\":94},\"4695\":{\"m\":76,\"g\":94},\"4691\":{\"m\":76,\"g\":94},\"4738\":{\"m\":76,\"g\":94},\"4735\":{\"m\":76,\"g\":94},\"4744\":{\"m\":76,\"g\":94},\"3023\":{\"m\":76,\"g\":94},\"4737\":{\"m\":76,\"g\":94},\"3899\":{\"m\":76,\"g\":94},\"4609\":{\"m\":76,\"g\":94},\"4736\":{\"m\":76,\"g\":94},\"4716\":{\"m\":76,\"g\":94},\"4721\":{\"m\":76,\"g\":94},\"4731\":{\"m\":76,\"g\":94},\"4720\":{\"m\":76,\"g\":94},\"4396\":{\"m\":76,\"g\":94},\"4605\":{\"m\":76,\"g\":94},\"4680\":{\"m\":76,\"g\":94},\"4631\":{\"m\":76,\"g\":94},\"4698\":{\"m\":76,\"g\":94},\"4064\":{\"m\":76,\"g\":94},\"4525\":{\"m\":76,\"g\":94},\"4661\":{\"m\":76,\"g\":94},\"4610\":{\"m\":76,\"g\":94},\"4608\":{\"m\":76,\"g\":94},\"4685\":{\"m\":76,\"g\":94},\"4679\":{\"m\":76,\"g\":94},\"4643\":{\"m\":76,\"g\":94},\"3984\":{\"m\":76,\"g\":94},\"4660\":{\"m\":76,\"g\":94},\"4676\":{\"m\":76,\"g\":94},\"4670\":{\"m\":76,\"g\":94},\"4677\":{\"m\":76,\"g\":94},\"4674\":{\"m\":76,\"g\":94},\"4556\":{\"m\":76,\"g\":94},\"4665\":{\"m\":76,\"g\":94},\"4596\":{\"m\":76,\"g\":94},\"4639\":{\"m\":76,\"g\":94},\"4664\":{\"m\":76,\"g\":94},\"4582\":{\"m\":76,\"g\":94},\"4654\":{\"m\":76,\"g\":94},\"4637\":{\"m\":76,\"g\":94},\"4641\":{\"m\":76,\"g\":94},\"4613\":{\"m\":76,\"g\":94},\"4640\":{\"m\":76,\"g\":94},\"4558\":{\"m\":76,\"g\":94},\"4622\":{\"m\":76,\"g\":94},\"4592\":{\"m\":76,\"g\":94},\"4571\":{\"m\":76,\"g\":94},\"3446\":{\"m\":76,\"g\":94},\"4577\":{\"m\":76,\"g\":94},\"4549\":{\"m\":76,\"g\":94},\"4583\":{\"m\":76,\"g\":94},\"4514\":{\"m\":76,\"g\":94},\"4232\":{\"m\":76,\"g\":94},\"4515\":{\"m\":76,\"g\":94},\"4557\":{\"m\":76,\"g\":94},\"4553\":{\"m\":76,\"g\":94},\"4274\":{\"m\":76,\"g\":94},\"4521\":{\"m\":76,\"g\":94},\"4247\":{\"m\":76,\"g\":94},\"
4532\":{\"m\":76,\"g\":94},\"4441\":{\"m\":76,\"g\":94},\"4538\":{\"m\":76,\"g\":94},\"4541\":{\"m\":76,\"g\":94},\"4542\":{\"m\":76,\"g\":94},\"4500\":{\"m\":76,\"g\":94},\"4531\":{\"m\":76,\"g\":94},\"3682\":{\"m\":76,\"g\":94},\"4458\":{\"m\":76,\"g\":94},\"4486\":{\"m\":76,\"g\":94},\"4507\":{\"m\":76,\"g\":94},\"4522\":{\"m\":76,\"g\":94},\"4520\":{\"m\":76,\"g\":94},\"4505\":{\"m\":76,\"g\":94},\"4482\":{\"m\":76,\"g\":94},\"4517\":{\"m\":76,\"g\":94},\"4513\":{\"m\":76,\"g\":94},\"4510\":{\"m\":76,\"g\":94},\"4499\":{\"m\":76,\"g\":94},\"4495\":{\"m\":76,\"g\":94},\"4446\":{\"m\":76,\"g\":94},\"4480\":{\"m\":76,\"g\":94},\"4067\":{\"m\":76,\"g\":94},\"4485\":{\"m\":76,\"g\":94},\"4418\":{\"m\":76,\"g\":94},\"4493\":{\"m\":76,\"g\":94},\"2798\":{\"m\":76,\"g\":94},\"4372\":{\"m\":76,\"g\":94},\"4483\":{\"m\":76,\"g\":94},\"4386\":{\"m\":76,\"g\":94},\"2797\":{\"m\":76,\"g\":94},\"3612\":{\"m\":76,\"g\":94},\"4448\":{\"m\":76,\"g\":94},\"4465\":{\"m\":76,\"g\":94},\"4474\":{\"m\":76,\"g\":94},\"4479\":{\"m\":76,\"g\":94},\"4424\":{\"m\":76,\"g\":94},\"4202\":{\"m\":76,\"g\":94},\"4484\":{\"m\":76,\"g\":94},\"4481\":{\"m\":76,\"g\":94},\"4477\":{\"m\":76,\"g\":94},\"4472\":{\"m\":76,\"g\":94},\"4363\":{\"m\":76,\"g\":94},\"4470\":{\"m\":76,\"g\":94},\"4469\":{\"m\":76,\"g\":94},\"4383\":{\"m\":76,\"g\":94},\"4468\":{\"m\":76,\"g\":94},\"4467\":{\"m\":76,\"g\":94},\"4466\":{\"m\":76,\"g\":94},\"4449\":{\"m\":76,\"g\":94},\"4464\":{\"m\":76,\"g\":94},\"4460\":{\"m\":76,\"g\":94},\"4368\":{\"m\":76,\"g\":94},\"4459\":{\"m\":76,\"g\":94},\"4447\":{\"m\":76,\"g\":94},\"4423\":{\"m\":76,\"g\":94},\"4391\":{\"m\":76,\"g\":94},\"4413\":{\"m\":76,\"g\":94},\"4454\":{\"m\":76,\"g\":94},\"4453\":{\"m\":76,\"g\":94},\"4455\":{\"m\":76,\"g\":94},\"4452\":{\"m\":76,\"g\":94},\"4451\":{\"m\":76,\"g\":94},\"4442\":{\"m\":76,\"g\":94},\"4439\":{\"m\":76,\"g\":94},\"4438\":{\"m\":76,\"g\":94},\"4437\":{\"m\":76,\"g\":94},\"4302\":{\"m\":76,\"g\":94},\"4427\":{\"m\":76,\"g\":94},\
"4419\":{\"m\":76,\"g\":94},\"3964\":{\"m\":76,\"g\":94},\"4400\":{\"m\":76,\"g\":94},\"4009\":{\"m\":76,\"g\":94},\"4403\":{\"m\":76,\"g\":94},\"4878\":{\"m\":77,\"g\":94},\"4874\":{\"m\":77,\"g\":94},\"4873\":{\"m\":77,\"g\":94},\"4831\":{\"m\":77,\"g\":94},\"4872\":{\"m\":77,\"g\":94},\"4768\":{\"m\":77,\"g\":94},\"4871\":{\"m\":77,\"g\":94},\"4866\":{\"m\":77,\"g\":94},\"4749\":{\"m\":77,\"g\":94},\"4864\":{\"m\":77,\"g\":94},\"4772\":{\"m\":77,\"g\":94},\"4834\":{\"m\":77,\"g\":94},\"4492\":{\"m\":77,\"g\":94},\"4863\":{\"m\":77,\"g\":94},\"4855\":{\"m\":77,\"g\":94},\"4853\":{\"m\":77,\"g\":94},\"4840\":{\"m\":77,\"g\":94},\"4740\":{\"m\":77,\"g\":94},\"4750\":{\"m\":77,\"g\":94},\"4687\":{\"m\":77,\"g\":94},\"4729\":{\"m\":77,\"g\":94},\"4712\":{\"m\":77,\"g\":94},\"4704\":{\"m\":77,\"g\":94},\"4688\":{\"m\":77,\"g\":94},\"4681\":{\"m\":77,\"g\":94},\"4648\":{\"m\":77,\"g\":94},\"4832\":{\"m\":77,\"g\":94},\"4528\":{\"m\":77,\"g\":94},\"4597\":{\"m\":77,\"g\":94},\"4487\":{\"m\":77,\"g\":94},\"3949\":{\"m\":77,\"g\":94},\"4844\":{\"m\":77,\"g\":94},\"4846\":{\"m\":77,\"g\":94},\"4843\":{\"m\":77,\"g\":94},\"4799\":{\"m\":77,\"g\":94},\"4788\":{\"m\":77,\"g\":94},\"4809\":{\"m\":77,\"g\":94},\"4837\":{\"m\":77,\"g\":94},\"3969\":{\"m\":77,\"g\":94},\"4835\":{\"m\":77,\"g\":94},\"4815\":{\"m\":77,\"g\":94},\"4819\":{\"m\":77,\"g\":94},\"4770\":{\"m\":77,\"g\":94},\"4833\":{\"m\":77,\"g\":94},\"4830\":{\"m\":77,\"g\":94},\"4694\":{\"m\":77,\"g\":94},\"4826\":{\"m\":77,\"g\":94},\"4825\":{\"m\":77,\"g\":94},\"4828\":{\"m\":77,\"g\":94},\"4827\":{\"m\":77,\"g\":94},\"4823\":{\"m\":77,\"g\":94},\"4341\":{\"m\":77,\"g\":94},\"4638\":{\"m\":77,\"g\":94},\"4813\":{\"m\":77,\"g\":94},\"4764\":{\"m\":77,\"g\":94},\"4706\":{\"m\":77,\"g\":94},\"4745\":{\"m\":77,\"g\":94},\"4565\":{\"m\":77,\"g\":94},\"4804\":{\"m\":77,\"g\":94},\"4506\":{\"m\":77,\"g\":94},\"4388\":{\"m\":77,\"g\":94},\"4628\":{\"m\":77,\"g\":94},\"4719\":{\"m\":77,\"g\":94},\"5091\":{\"m\":78,\"g\":94},
\"5080\":{\"m\":78,\"g\":94},\"5089\":{\"m\":78,\"g\":94},\"5079\":{\"m\":78,\"g\":94},\"5088\":{\"m\":78,\"g\":94},\"5052\":{\"m\":78,\"g\":94},\"5050\":{\"m\":78,\"g\":94},\"5074\":{\"m\":78,\"g\":94},\"4535\":{\"m\":78,\"g\":94},\"5072\":{\"m\":78,\"g\":94},\"4918\":{\"m\":78,\"g\":94},\"4996\":{\"m\":78,\"g\":94},\"4995\":{\"m\":78,\"g\":94},\"4994\":{\"m\":78,\"g\":94},\"5060\":{\"m\":78,\"g\":94},\"5057\":{\"m\":78,\"g\":94},\"5056\":{\"m\":78,\"g\":94},\"4625\":{\"m\":78,\"g\":94},\"5049\":{\"m\":78,\"g\":94},\"5051\":{\"m\":78,\"g\":94},\"5039\":{\"m\":78,\"g\":94},\"4796\":{\"m\":78,\"g\":94},\"5046\":{\"m\":78,\"g\":94},\"5048\":{\"m\":78,\"g\":94},\"5036\":{\"m\":78,\"g\":94},\"4992\":{\"m\":78,\"g\":94},\"5005\":{\"m\":78,\"g\":94},\"5024\":{\"m\":78,\"g\":94},\"5030\":{\"m\":78,\"g\":94},\"5020\":{\"m\":78,\"g\":94},\"4727\":{\"m\":78,\"g\":94},\"5009\":{\"m\":78,\"g\":94},\"5011\":{\"m\":78,\"g\":94},\"4817\":{\"m\":78,\"g\":94},\"5008\":{\"m\":78,\"g\":94},\"4951\":{\"m\":78,\"g\":94},\"4861\":{\"m\":78,\"g\":94},\"4989\":{\"m\":78,\"g\":94},\"4581\":{\"m\":78,\"g\":94},\"4915\":{\"m\":78,\"g\":94},\"4977\":{\"m\":78,\"g\":94},\"4958\":{\"m\":78,\"g\":94},\"4767\":{\"m\":78,\"g\":94},\"4959\":{\"m\":78,\"g\":94},\"4954\":{\"m\":78,\"g\":94},\"4953\":{\"m\":78,\"g\":94},\"4950\":{\"m\":78,\"g\":94},\"4754\":{\"m\":78,\"g\":94},\"4944\":{\"m\":78,\"g\":94},\"4928\":{\"m\":78,\"g\":94},\"4913\":{\"m\":78,\"g\":94},\"4936\":{\"m\":78,\"g\":94},\"4883\":{\"m\":78,\"g\":94},\"4925\":{\"m\":78,\"g\":94},\"4933\":{\"m\":78,\"g\":94},\"4932\":{\"m\":78,\"g\":94},\"4931\":{\"m\":78,\"g\":94},\"4902\":{\"m\":78,\"g\":94},\"4930\":{\"m\":78,\"g\":94},\"4927\":{\"m\":78,\"g\":94},\"4926\":{\"m\":78,\"g\":94},\"4896\":{\"m\":78,\"g\":94},\"4914\":{\"m\":78,\"g\":94},\"4908\":{\"m\":78,\"g\":94},\"4909\":{\"m\":78,\"g\":94},\"4890\":{\"m\":78,\"g\":94},\"4899\":{\"m\":78,\"g\":94},\"4898\":{\"m\":78,\"g\":94},\"4530\":{\"m\":78,\"g\":94},\"4891\":{\"m\":78,\"g\":94}
,\"4889\":{\"m\":78,\"g\":94},\"4886\":{\"m\":78,\"g\":94},\"4795\":{\"m\":78,\"g\":94},\"4845\":{\"m\":78,\"g\":94},\"4882\":{\"m\":78,\"g\":94},\"5117\":{\"m\":79,\"g\":94},\"5106\":{\"m\":79,\"g\":94},\"5092\":{\"m\":79,\"g\":94},\"5097\":{\"m\":79,\"g\":94},\"5445\":{\"m\":80,\"g\":94},\"5113\":{\"m\":80,\"g\":94},\"5425\":{\"m\":80,\"g\":94},\"5397\":{\"m\":80,\"g\":94},\"5398\":{\"m\":80,\"g\":94},\"5038\":{\"m\":80,\"g\":94},\"5211\":{\"m\":80,\"g\":94},\"5264\":{\"m\":80,\"g\":94},\"5436\":{\"m\":80,\"g\":94},\"5344\":{\"m\":80,\"g\":94},\"5431\":{\"m\":80,\"g\":94},\"5434\":{\"m\":80,\"g\":94},\"5430\":{\"m\":80,\"g\":94},\"5419\":{\"m\":80,\"g\":94},\"5420\":{\"m\":80,\"g\":94},\"5423\":{\"m\":80,\"g\":94},\"5422\":{\"m\":80,\"g\":94},\"5415\":{\"m\":80,\"g\":94},\"5416\":{\"m\":80,\"g\":94},\"5351\":{\"m\":80,\"g\":94},\"5412\":{\"m\":80,\"g\":94},\"5352\":{\"m\":80,\"g\":94},\"5214\":{\"m\":80,\"g\":94},\"5406\":{\"m\":80,\"g\":94},\"5401\":{\"m\":80,\"g\":94},\"5400\":{\"m\":80,\"g\":94},\"5381\":{\"m\":80,\"g\":94},\"5399\":{\"m\":80,\"g\":94},\"5395\":{\"m\":80,\"g\":94},\"5279\":{\"m\":80,\"g\":94},\"5291\":{\"m\":80,\"g\":94},\"5368\":{\"m\":80,\"g\":94},\"5393\":{\"m\":80,\"g\":94},\"5263\":{\"m\":80,\"g\":94},\"5392\":{\"m\":80,\"g\":94},\"5371\":{\"m\":80,\"g\":94},\"5385\":{\"m\":80,\"g\":94},\"5370\":{\"m\":80,\"g\":94},\"5384\":{\"m\":80,\"g\":94},\"5326\":{\"m\":80,\"g\":94},\"5003\":{\"m\":80,\"g\":94},\"5367\":{\"m\":80,\"g\":94},\"5364\":{\"m\":80,\"g\":94},\"5277\":{\"m\":80,\"g\":94},\"5360\":{\"m\":80,\"g\":94},\"5359\":{\"m\":80,\"g\":94},\"5161\":{\"m\":80,\"g\":94},\"5328\":{\"m\":80,\"g\":94},\"5357\":{\"m\":80,\"g\":94},\"5342\":{\"m\":80,\"g\":94},\"5343\":{\"m\":80,\"g\":94},\"5322\":{\"m\":80,\"g\":94},\"5341\":{\"m\":80,\"g\":94},\"5337\":{\"m\":80,\"g\":94},\"5336\":{\"m\":80,\"g\":94},\"5333\":{\"m\":80,\"g\":94},\"5332\":{\"m\":80,\"g\":94},\"5331\":{\"m\":80,\"g\":94},\"5327\":{\"m\":80,\"g\":94},\"4848\":{\"m\":80,\"g\":94
},\"5294\":{\"m\":80,\"g\":94},\"5321\":{\"m\":80,\"g\":94},\"4884\":{\"m\":80,\"g\":94},\"5210\":{\"m\":80,\"g\":94},\"5317\":{\"m\":80,\"g\":94},\"5316\":{\"m\":80,\"g\":94},\"5315\":{\"m\":80,\"g\":94},\"5120\":{\"m\":80,\"g\":94},\"5299\":{\"m\":80,\"g\":94},\"5311\":{\"m\":80,\"g\":94},\"5142\":{\"m\":80,\"g\":94},\"5271\":{\"m\":80,\"g\":94},\"5065\":{\"m\":80,\"g\":94},\"5310\":{\"m\":80,\"g\":94},\"5308\":{\"m\":80,\"g\":94},\"5307\":{\"m\":80,\"g\":94},\"5306\":{\"m\":80,\"g\":94},\"5304\":{\"m\":80,\"g\":94},\"5303\":{\"m\":80,\"g\":94},\"5302\":{\"m\":80,\"g\":94},\"5298\":{\"m\":80,\"g\":94},\"5301\":{\"m\":80,\"g\":94},\"5300\":{\"m\":80,\"g\":94},\"5290\":{\"m\":80,\"g\":94},\"5292\":{\"m\":80,\"g\":94},\"5289\":{\"m\":80,\"g\":94},\"5288\":{\"m\":80,\"g\":94},\"5287\":{\"m\":80,\"g\":94},\"5286\":{\"m\":80,\"g\":94},\"5280\":{\"m\":80,\"g\":94},\"5193\":{\"m\":80,\"g\":94},\"5244\":{\"m\":80,\"g\":94},\"5265\":{\"m\":80,\"g\":94},\"5254\":{\"m\":80,\"g\":94},\"5262\":{\"m\":80,\"g\":94},\"5127\":{\"m\":80,\"g\":94},\"5228\":{\"m\":80,\"g\":94},\"5259\":{\"m\":80,\"g\":94},\"5245\":{\"m\":80,\"g\":94},\"5216\":{\"m\":80,\"g\":94},\"5167\":{\"m\":80,\"g\":94},\"5213\":{\"m\":80,\"g\":94},\"5215\":{\"m\":80,\"g\":94},\"5204\":{\"m\":80,\"g\":94},\"4880\":{\"m\":80,\"g\":94},\"5086\":{\"m\":80,\"g\":94},\"5190\":{\"m\":80,\"g\":94},\"4444\":{\"m\":80,\"g\":94},\"5209\":{\"m\":80,\"g\":94},\"5196\":{\"m\":80,\"g\":94},\"5207\":{\"m\":80,\"g\":94},\"5171\":{\"m\":80,\"g\":94},\"5144\":{\"m\":80,\"g\":94},\"5102\":{\"m\":80,\"g\":94},\"5194\":{\"m\":80,\"g\":94},\"5128\":{\"m\":80,\"g\":94},\"5185\":{\"m\":80,\"g\":94},\"5189\":{\"m\":80,\"g\":94},\"4058\":{\"m\":80,\"g\":94},\"5179\":{\"m\":80,\"g\":94},\"5180\":{\"m\":80,\"g\":94},\"5176\":{\"m\":80,\"g\":94},\"5110\":{\"m\":80,\"g\":94},\"5068\":{\"m\":80,\"g\":94},\"5173\":{\"m\":80,\"g\":94},\"5175\":{\"m\":80,\"g\":94},\"5174\":{\"m\":80,\"g\":94},\"3972\":{\"m\":80,\"g\":94},\"5159\":{\"m\":80,\"g\":9
4},\"4938\":{\"m\":80,\"g\":94},\"5139\":{\"m\":80,\"g\":94},\"4911\":{\"m\":80,\"g\":94},\"5150\":{\"m\":80,\"g\":94},\"5155\":{\"m\":80,\"g\":94},\"4686\":{\"m\":80,\"g\":94},\"5158\":{\"m\":80,\"g\":94},\"5151\":{\"m\":80,\"g\":94},\"5115\":{\"m\":80,\"g\":94},\"5152\":{\"m\":80,\"g\":94},\"4760\":{\"m\":80,\"g\":94},\"5103\":{\"m\":80,\"g\":94},\"5147\":{\"m\":80,\"g\":94},\"5083\":{\"m\":80,\"g\":94},\"5140\":{\"m\":80,\"g\":94},\"5145\":{\"m\":80,\"g\":94},\"4984\":{\"m\":80,\"g\":94},\"5137\":{\"m\":80,\"g\":94},\"5133\":{\"m\":80,\"g\":94},\"5090\":{\"m\":80,\"g\":94},\"5126\":{\"m\":80,\"g\":94},\"5582\":{\"m\":81,\"g\":94},\"5581\":{\"m\":81,\"g\":94},\"4947\":{\"m\":81,\"g\":94},\"5571\":{\"m\":81,\"g\":94},\"5564\":{\"m\":81,\"g\":94},\"5568\":{\"m\":81,\"g\":94},\"5432\":{\"m\":81,\"g\":94},\"5149\":{\"m\":81,\"g\":94},\"5562\":{\"m\":81,\"g\":94},\"5543\":{\"m\":81,\"g\":94},\"5561\":{\"m\":81,\"g\":94},\"5546\":{\"m\":81,\"g\":94},\"5549\":{\"m\":81,\"g\":94},\"5504\":{\"m\":81,\"g\":94},\"5475\":{\"m\":81,\"g\":94},\"5460\":{\"m\":81,\"g\":94},\"5547\":{\"m\":81,\"g\":94},\"5534\":{\"m\":81,\"g\":94},\"5545\":{\"m\":81,\"g\":94},\"5548\":{\"m\":81,\"g\":94},\"5340\":{\"m\":81,\"g\":94},\"5540\":{\"m\":81,\"g\":94},\"5544\":{\"m\":81,\"g\":94},\"5542\":{\"m\":81,\"g\":94},\"4693\":{\"m\":81,\"g\":94},\"5476\":{\"m\":81,\"g\":94},\"5497\":{\"m\":81,\"g\":94},\"5461\":{\"m\":81,\"g\":94},\"5473\":{\"m\":81,\"g\":94},\"5518\":{\"m\":81,\"g\":94},\"5440\":{\"m\":81,\"g\":94},\"4836\":{\"m\":81,\"g\":94},\"5512\":{\"m\":81,\"g\":94},\"5373\":{\"m\":81,\"g\":94},\"5426\":{\"m\":81,\"g\":94},\"5511\":{\"m\":81,\"g\":94},\"5500\":{\"m\":81,\"g\":94},\"5503\":{\"m\":81,\"g\":94},\"5496\":{\"m\":81,\"g\":94},\"5493\":{\"m\":81,\"g\":94},\"5205\":{\"m\":81,\"g\":94},\"5479\":{\"m\":81,\"g\":94},\"4887\":{\"m\":81,\"g\":94},\"5480\":{\"m\":81,\"g\":94},\"5481\":{\"m\":81,\"g\":94},\"5484\":{\"m\":81,\"g\":94},\"5489\":{\"m\":81,\"g\":94},\"4982\":{\"m\":81,\"g\":
94},\"5345\":{\"m\":81,\"g\":94},\"5467\":{\"m\":81,\"g\":94},\"5463\":{\"m\":81,\"g\":94},\"5447\":{\"m\":81,\"g\":94},\"5449\":{\"m\":81,\"g\":94},\"5444\":{\"m\":81,\"g\":94},\"5611\":{\"m\":82,\"g\":94},\"5510\":{\"m\":82,\"g\":94},\"5610\":{\"m\":82,\"g\":94},\"5580\":{\"m\":82,\"g\":94},\"5604\":{\"m\":82,\"g\":94},\"5609\":{\"m\":82,\"g\":94},\"5608\":{\"m\":82,\"g\":94},\"5477\":{\"m\":82,\"g\":94},\"5589\":{\"m\":82,\"g\":94},\"5598\":{\"m\":82,\"g\":94},\"5488\":{\"m\":82,\"g\":94},\"5037\":{\"m\":82,\"g\":94},\"5021\":{\"m\":82,\"g\":94},\"4980\":{\"m\":82,\"g\":94},\"4937\":{\"m\":82,\"g\":94},\"4718\":{\"m\":82,\"g\":94},\"4590\":{\"m\":82,\"g\":94},\"3443\":{\"m\":82,\"g\":94},\"5348\":{\"m\":82,\"g\":94},\"5570\":{\"m\":82,\"g\":94},\"5318\":{\"m\":82,\"g\":94},\"5590\":{\"m\":82,\"g\":94},\"5319\":{\"m\":82,\"g\":94},\"5241\":{\"m\":82,\"g\":94},\"5188\":{\"m\":82,\"g\":94},\"5141\":{\"m\":82,\"g\":94},\"5019\":{\"m\":82,\"g\":94},\"5016\":{\"m\":82,\"g\":94},\"4859\":{\"m\":82,\"g\":94},\"4852\":{\"m\":82,\"g\":94},\"4675\":{\"m\":82,\"g\":94},\"4733\":{\"m\":82,\"g\":94},\"5378\":{\"m\":82,\"g\":94},\"5224\":{\"m\":82,\"g\":94},\"4226\":{\"m\":82,\"g\":94},\"5433\":{\"m\":82,\"g\":94},\"5588\":{\"m\":82,\"g\":94},\"5452\":{\"m\":82,\"g\":94},\"5586\":{\"m\":82,\"g\":94},\"5575\":{\"m\":82,\"g\":94},\"5417\":{\"m\":82,\"g\":94},\"5531\":{\"m\":82,\"g\":94},\"5526\":{\"m\":82,\"g\":94},\"5521\":{\"m\":82,\"g\":94},\"5559\":{\"m\":82,\"g\":94},\"5560\":{\"m\":82,\"g\":94},\"5567\":{\"m\":82,\"g\":94},\"5574\":{\"m\":82,\"g\":94},\"5795\":{\"m\":83,\"g\":94},\"5691\":{\"m\":83,\"g\":94},\"5790\":{\"m\":83,\"g\":94},\"5791\":{\"m\":83,\"g\":94},\"5769\":{\"m\":83,\"g\":94},\"5789\":{\"m\":83,\"g\":94},\"5787\":{\"m\":83,\"g\":94},\"5786\":{\"m\":83,\"g\":94},\"5785\":{\"m\":83,\"g\":94},\"5779\":{\"m\":83,\"g\":94},\"5777\":{\"m\":83,\"g\":94},\"5774\":{\"m\":83,\"g\":94},\"5776\":{\"m\":83,\"g\":94},\"5772\":{\"m\":83,\"g\":94},\"3744\":{\"m\":83,\"g\"
:94},\"4986\":{\"m\":83,\"g\":94},\"4971\":{\"m\":83,\"g\":94},\"4870\":{\"m\":83,\"g\":94},\"5633\":{\"m\":83,\"g\":94},\"5599\":{\"m\":83,\"g\":94},\"5565\":{\"m\":83,\"g\":94},\"5509\":{\"m\":83,\"g\":94},\"5687\":{\"m\":83,\"g\":94},\"5592\":{\"m\":83,\"g\":94},\"5607\":{\"m\":83,\"g\":94},\"5730\":{\"m\":83,\"g\":94},\"5697\":{\"m\":83,\"g\":94},\"5748\":{\"m\":83,\"g\":94},\"5682\":{\"m\":83,\"g\":94},\"5685\":{\"m\":83,\"g\":94},\"5716\":{\"m\":83,\"g\":94},\"5720\":{\"m\":83,\"g\":94},\"5722\":{\"m\":83,\"g\":94},\"5728\":{\"m\":83,\"g\":94},\"5733\":{\"m\":83,\"g\":94},\"5736\":{\"m\":83,\"g\":94},\"5756\":{\"m\":83,\"g\":94},\"5760\":{\"m\":83,\"g\":94},\"5552\":{\"m\":83,\"g\":94},\"5718\":{\"m\":83,\"g\":94},\"5719\":{\"m\":83,\"g\":94},\"5737\":{\"m\":83,\"g\":94},\"5740\":{\"m\":83,\"g\":94},\"5753\":{\"m\":83,\"g\":94},\"5754\":{\"m\":83,\"g\":94},\"5750\":{\"m\":83,\"g\":94},\"5723\":{\"m\":83,\"g\":94},\"5704\":{\"m\":83,\"g\":94},\"5738\":{\"m\":83,\"g\":94},\"5715\":{\"m\":83,\"g\":94},\"5078\":{\"m\":83,\"g\":94},\"4491\":{\"m\":83,\"g\":94},\"5706\":{\"m\":83,\"g\":94},\"5648\":{\"m\":83,\"g\":94},\"5684\":{\"m\":83,\"g\":94},\"5707\":{\"m\":83,\"g\":94},\"5349\":{\"m\":83,\"g\":94},\"5688\":{\"m\":83,\"g\":94},\"5686\":{\"m\":83,\"g\":94},\"5683\":{\"m\":83,\"g\":94},\"5667\":{\"m\":83,\"g\":94},\"5677\":{\"m\":83,\"g\":94},\"5530\":{\"m\":83,\"g\":94},\"5671\":{\"m\":83,\"g\":94},\"5670\":{\"m\":83,\"g\":94},\"5435\":{\"m\":83,\"g\":94},\"5669\":{\"m\":83,\"g\":94},\"5666\":{\"m\":83,\"g\":94},\"5281\":{\"m\":83,\"g\":94},\"5601\":{\"m\":83,\"g\":94},\"5619\":{\"m\":83,\"g\":94},\"5628\":{\"m\":83,\"g\":94},\"5649\":{\"m\":83,\"g\":94},\"5646\":{\"m\":83,\"g\":94},\"5638\":{\"m\":83,\"g\":94},\"5634\":{\"m\":83,\"g\":94},\"5272\":{\"m\":83,\"g\":94},\"5641\":{\"m\":83,\"g\":94},\"5632\":{\"m\":83,\"g\":94},\"5640\":{\"m\":83,\"g\":94},\"5624\":{\"m\":83,\"g\":94},\"5622\":{\"m\":83,\"g\":94},\"5620\":{\"m\":83,\"g\":94},\"5618\":{\"m\":83,\"g\
":94},\"5615\":{\"m\":83,\"g\":94},\"5578\":{\"m\":83,\"g\":94},\"5845\":{\"m\":84,\"g\":94},\"5849\":{\"m\":84,\"g\":94},\"5854\":{\"m\":84,\"g\":94},\"5823\":{\"m\":84,\"g\":94},\"5847\":{\"m\":84,\"g\":94},\"5816\":{\"m\":84,\"g\":94},\"5798\":{\"m\":84,\"g\":94},\"5850\":{\"m\":84,\"g\":94},\"5851\":{\"m\":84,\"g\":94},\"5846\":{\"m\":84,\"g\":94},\"5726\":{\"m\":84,\"g\":94},\"5839\":{\"m\":84,\"g\":94},\"5842\":{\"m\":84,\"g\":94},\"5833\":{\"m\":84,\"g\":94},\"5838\":{\"m\":84,\"g\":94},\"5551\":{\"m\":84,\"g\":94},\"5825\":{\"m\":84,\"g\":94},\"5276\":{\"m\":84,\"g\":94},\"5482\":{\"m\":84,\"g\":94},\"5771\":{\"m\":84,\"g\":94},\"5809\":{\"m\":84,\"g\":94},\"5807\":{\"m\":84,\"g\":94},\"5690\":{\"m\":84,\"g\":94},\"5390\":{\"m\":84,\"g\":94},\"5788\":{\"m\":84,\"g\":94},\"5643\":{\"m\":84,\"g\":94},\"5796\":{\"m\":84,\"g\":94},\"5797\":{\"m\":84,\"g\":94},\"5939\":{\"m\":85,\"g\":94},\"5934\":{\"m\":85,\"g\":94},\"5881\":{\"m\":85,\"g\":94},\"5915\":{\"m\":85,\"g\":94},\"5930\":{\"m\":85,\"g\":94},\"5724\":{\"m\":85,\"g\":94},\"5783\":{\"m\":85,\"g\":94},\"5933\":{\"m\":85,\"g\":94},\"5932\":{\"m\":85,\"g\":94},\"5912\":{\"m\":85,\"g\":94},\"5909\":{\"m\":85,\"g\":94},\"5917\":{\"m\":85,\"g\":94},\"5919\":{\"m\":85,\"g\":94},\"5910\":{\"m\":85,\"g\":94},\"5905\":{\"m\":85,\"g\":94},\"5383\":{\"m\":85,\"g\":94},\"5903\":{\"m\":85,\"g\":94},\"5900\":{\"m\":85,\"g\":94},\"5899\":{\"m\":85,\"g\":94},\"5870\":{\"m\":85,\"g\":94},\"5830\":{\"m\":85,\"g\":94},\"5901\":{\"m\":85,\"g\":94},\"5898\":{\"m\":85,\"g\":94},\"5861\":{\"m\":85,\"g\":94},\"5696\":{\"m\":85,\"g\":94},\"5725\":{\"m\":85,\"g\":94},\"5893\":{\"m\":85,\"g\":94},\"5793\":{\"m\":85,\"g\":94},\"5896\":{\"m\":85,\"g\":94},\"5746\":{\"m\":85,\"g\":94},\"5895\":{\"m\":85,\"g\":94},\"5841\":{\"m\":85,\"g\":94},\"5836\":{\"m\":85,\"g\":94},\"5894\":{\"m\":85,\"g\":94},\"5880\":{\"m\":85,\"g\":94},\"5875\":{\"m\":85,\"g\":94},\"5859\":{\"m\":85,\"g\":94},\"4115\":{\"m\":85,\"g\":94},\"4949\":{\"m\":85,\"g
\":94},\"5820\":{\"m\":85,\"g\":94},\"5868\":{\"m\":85,\"g\":94},\"5860\":{\"m\":85,\"g\":94},\"5801\":{\"m\":85,\"g\":94},\"5857\":{\"m\":85,\"g\":94},\"6165\":{\"m\":86,\"g\":94},\"5778\":{\"m\":86,\"g\":94},\"6162\":{\"m\":86,\"g\":94},\"6089\":{\"m\":86,\"g\":94},\"6141\":{\"m\":86,\"g\":94},\"5822\":{\"m\":86,\"g\":94},\"6101\":{\"m\":86,\"g\":94},\"5745\":{\"m\":86,\"g\":94},\"6132\":{\"m\":86,\"g\":94},\"6131\":{\"m\":86,\"g\":94},\"6112\":{\"m\":86,\"g\":94},\"6129\":{\"m\":86,\"g\":94},\"6123\":{\"m\":86,\"g\":94},\"6097\":{\"m\":86,\"g\":94},\"5662\":{\"m\":86,\"g\":94},\"5764\":{\"m\":86,\"g\":94},\"6119\":{\"m\":86,\"g\":94},\"5626\":{\"m\":86,\"g\":94},\"6091\":{\"m\":86,\"g\":94},\"5572\":{\"m\":86,\"g\":94},\"5232\":{\"m\":86,\"g\":94},\"5219\":{\"m\":86,\"g\":94},\"5121\":{\"m\":86,\"g\":94},\"6077\":{\"m\":86,\"g\":94},\"6111\":{\"m\":86,\"g\":94},\"6038\":{\"m\":86,\"g\":94},\"6079\":{\"m\":86,\"g\":94},\"6034\":{\"m\":86,\"g\":94},\"6102\":{\"m\":86,\"g\":94},\"6105\":{\"m\":86,\"g\":94},\"5993\":{\"m\":86,\"g\":94},\"6075\":{\"m\":86,\"g\":94},\"6063\":{\"m\":86,\"g\":94},\"5233\":{\"m\":86,\"g\":94},\"5014\":{\"m\":86,\"g\":94},\"6084\":{\"m\":86,\"g\":94},\"3853\":{\"m\":86,\"g\":94},\"6010\":{\"m\":86,\"g\":94},\"6039\":{\"m\":86,\"g\":94},\"5655\":{\"m\":86,\"g\":94},\"6062\":{\"m\":86,\"g\":94},\"6004\":{\"m\":86,\"g\":94},\"6045\":{\"m\":86,\"g\":94},\"5885\":{\"m\":86,\"g\":94},\"5751\":{\"m\":86,\"g\":94},\"6057\":{\"m\":86,\"g\":94},\"6048\":{\"m\":86,\"g\":94},\"6047\":{\"m\":86,\"g\":94},\"6046\":{\"m\":86,\"g\":94},\"5081\":{\"m\":86,\"g\":94},\"5752\":{\"m\":86,\"g\":94},\"5996\":{\"m\":86,\"g\":94},\"5428\":{\"m\":86,\"g\":94},\"5555\":{\"m\":86,\"g\":94},\"5587\":{\"m\":86,\"g\":94},\"5781\":{\"m\":86,\"g\":94},\"6018\":{\"m\":86,\"g\":94},\"6002\":{\"m\":86,\"g\":94},\"5997\":{\"m\":86,\"g\":94},\"5679\":{\"m\":86,\"g\":94},\"5957\":{\"m\":86,\"g\":94},\"6012\":{\"m\":86,\"g\":94},\"5992\":{\"m\":86,\"g\":94},\"5998\":{\"m\":86,\"
g\":94},\"5991\":{\"m\":86,\"g\":94},\"5986\":{\"m\":86,\"g\":94},\"5977\":{\"m\":86,\"g\":94},\"5975\":{\"m\":86,\"g\":94},\"5681\":{\"m\":86,\"g\":94},\"5969\":{\"m\":86,\"g\":94},\"5968\":{\"m\":86,\"g\":94},\"5350\":{\"m\":86,\"g\":94},\"5967\":{\"m\":86,\"g\":94},\"5908\":{\"m\":86,\"g\":94},\"5960\":{\"m\":86,\"g\":94},\"5782\":{\"m\":86,\"g\":94},\"5956\":{\"m\":86,\"g\":94},\"5945\":{\"m\":86,\"g\":94},\"5944\":{\"m\":86,\"g\":94},\"5952\":{\"m\":86,\"g\":94},\"5953\":{\"m\":86,\"g\":94},\"5834\":{\"m\":86,\"g\":94},\"5921\":{\"m\":86,\"g\":94},\"6245\":{\"m\":87,\"g\":94},\"6259\":{\"m\":87,\"g\":94},\"6252\":{\"m\":87,\"g\":94},\"5084\":{\"m\":87,\"g\":94},\"5657\":{\"m\":87,\"g\":94},\"6247\":{\"m\":87,\"g\":94},\"6235\":{\"m\":87,\"g\":94},\"6251\":{\"m\":87,\"g\":94},\"6042\":{\"m\":87,\"g\":94},\"6248\":{\"m\":87,\"g\":94},\"6225\":{\"m\":87,\"g\":94},\"5922\":{\"m\":87,\"g\":94},\"6241\":{\"m\":87,\"g\":94},\"6243\":{\"m\":87,\"g\":94},\"6244\":{\"m\":87,\"g\":94},\"6223\":{\"m\":87,\"g\":94},\"6206\":{\"m\":87,\"g\":94},\"6209\":{\"m\":87,\"g\":94},\"6231\":{\"m\":87,\"g\":94},\"6212\":{\"m\":87,\"g\":94},\"6201\":{\"m\":87,\"g\":94},\"6213\":{\"m\":87,\"g\":94},\"5558\":{\"m\":87,\"g\":94},\"6154\":{\"m\":87,\"g\":94},\"6043\":{\"m\":87,\"g\":94},\"6204\":{\"m\":87,\"g\":94},\"6202\":{\"m\":87,\"g\":94},\"6178\":{\"m\":87,\"g\":94},\"6198\":{\"m\":87,\"g\":94},\"6192\":{\"m\":87,\"g\":94},\"6188\":{\"m\":87,\"g\":94},\"6032\":{\"m\":87,\"g\":94},\"6199\":{\"m\":87,\"g\":94},\"5621\":{\"m\":87,\"g\":94},\"6196\":{\"m\":87,\"g\":94},\"6195\":{\"m\":87,\"g\":94},\"6073\":{\"m\":87,\"g\":94},\"6169\":{\"m\":87,\"g\":94},\"6191\":{\"m\":87,\"g\":94},\"6190\":{\"m\":87,\"g\":94},\"6186\":{\"m\":87,\"g\":94},\"5654\":{\"m\":87,\"g\":94},\"6180\":{\"m\":87,\"g\":94},\"6179\":{\"m\":87,\"g\":94},\"6184\":{\"m\":87,\"g\":94},\"6183\":{\"m\":87,\"g\":94},\"4701\":{\"m\":87,\"g\":94},\"6181\":{\"m\":87,\"g\":94},\"6146\":{\"m\":87,\"g\":94},\"6114\":{\"m\":87,\
"g\":94},\"6118\":{\"m\":87,\"g\":94},\"6016\":{\"m\":87,\"g\":94},\"6566\":{\"m\":88,\"g\":94},\"6567\":{\"m\":88,\"g\":94},\"6485\":{\"m\":88,\"g\":94},\"6560\":{\"m\":88,\"g\":94},\"6550\":{\"m\":88,\"g\":94},\"6533\":{\"m\":88,\"g\":94},\"6562\":{\"m\":88,\"g\":94},\"6524\":{\"m\":88,\"g\":94},\"6347\":{\"m\":88,\"g\":94},\"6558\":{\"m\":88,\"g\":94},\"6521\":{\"m\":88,\"g\":94},\"6507\":{\"m\":88,\"g\":94},\"6474\":{\"m\":88,\"g\":94},\"6452\":{\"m\":88,\"g\":94},\"6404\":{\"m\":88,\"g\":94},\"6493\":{\"m\":88,\"g\":94},\"6355\":{\"m\":88,\"g\":94},\"6535\":{\"m\":88,\"g\":94},\"6536\":{\"m\":88,\"g\":94},\"6059\":{\"m\":88,\"g\":94},\"6532\":{\"m\":88,\"g\":94},\"6120\":{\"m\":88,\"g\":94},\"6522\":{\"m\":88,\"g\":94},\"6520\":{\"m\":88,\"g\":94},\"6469\":{\"m\":88,\"g\":94},\"6482\":{\"m\":88,\"g\":94},\"6308\":{\"m\":88,\"g\":94},\"6388\":{\"m\":88,\"g\":94},\"6492\":{\"m\":88,\"g\":94},\"6504\":{\"m\":88,\"g\":94},\"6019\":{\"m\":88,\"g\":94},\"6457\":{\"m\":88,\"g\":94},\"6499\":{\"m\":88,\"g\":94},\"6510\":{\"m\":88,\"g\":94},\"6508\":{\"m\":88,\"g\":94},\"5759\":{\"m\":88,\"g\":94},\"6419\":{\"m\":88,\"g\":94},\"6275\":{\"m\":88,\"g\":94},\"6503\":{\"m\":88,\"g\":94},\"5573\":{\"m\":88,\"g\":94},\"6445\":{\"m\":88,\"g\":94},\"6461\":{\"m\":88,\"g\":94},\"6467\":{\"m\":88,\"g\":94},\"6468\":{\"m\":88,\"g\":94},\"6476\":{\"m\":88,\"g\":94},\"6311\":{\"m\":88,\"g\":94},\"6487\":{\"m\":88,\"g\":94},\"5339\":{\"m\":88,\"g\":94},\"6381\":{\"m\":88,\"g\":94},\"6214\":{\"m\":88,\"g\":94},\"6475\":{\"m\":88,\"g\":94},\"6472\":{\"m\":88,\"g\":94},\"6447\":{\"m\":88,\"g\":94},\"6385\":{\"m\":88,\"g\":94},\"6444\":{\"m\":88,\"g\":94},\"6405\":{\"m\":88,\"g\":94},\"6429\":{\"m\":88,\"g\":94},\"6412\":{\"m\":88,\"g\":94},\"6387\":{\"m\":88,\"g\":94},\"6386\":{\"m\":88,\"g\":94},\"6326\":{\"m\":88,\"g\":94},\"6438\":{\"m\":88,\"g\":94},\"6321\":{\"m\":88,\"g\":94},\"4957\":{\"m\":88,\"g\":94},\"6440\":{\"m\":88,\"g\":94},\"6431\":{\"m\":88,\"g\":94},\"6098\":{\"m\":88,
\"g\":94},\"6414\":{\"m\":88,\"g\":94},\"6430\":{\"m\":88,\"g\":94},\"6417\":{\"m\":88,\"g\":94},\"6137\":{\"m\":88,\"g\":94},\"6401\":{\"m\":88,\"g\":94},\"6400\":{\"m\":88,\"g\":94},\"5974\":{\"m\":88,\"g\":94},\"6325\":{\"m\":88,\"g\":94},\"6323\":{\"m\":88,\"g\":94},\"6333\":{\"m\":88,\"g\":94},\"6397\":{\"m\":88,\"g\":94},\"6396\":{\"m\":88,\"g\":94},\"6395\":{\"m\":88,\"g\":94},\"6339\":{\"m\":88,\"g\":94},\"6383\":{\"m\":88,\"g\":94},\"6392\":{\"m\":88,\"g\":94},\"6331\":{\"m\":88,\"g\":94},\"6391\":{\"m\":88,\"g\":94},\"6365\":{\"m\":88,\"g\":94},\"6250\":{\"m\":88,\"g\":94},\"6187\":{\"m\":88,\"g\":94},\"6379\":{\"m\":88,\"g\":94},\"6377\":{\"m\":88,\"g\":94},\"6362\":{\"m\":88,\"g\":94},\"6364\":{\"m\":88,\"g\":94},\"6330\":{\"m\":88,\"g\":94},\"6290\":{\"m\":88,\"g\":94},\"6041\":{\"m\":88,\"g\":94},\"6284\":{\"m\":88,\"g\":94},\"6257\":{\"m\":88,\"g\":94},\"6108\":{\"m\":88,\"g\":94},\"6134\":{\"m\":88,\"g\":94},\"6107\":{\"m\":88,\"g\":94},\"4741\":{\"m\":88,\"g\":94},\"6348\":{\"m\":88,\"g\":94},\"6211\":{\"m\":88,\"g\":94},\"6175\":{\"m\":88,\"g\":94},\"6366\":{\"m\":88,\"g\":94},\"6373\":{\"m\":88,\"g\":94},\"6356\":{\"m\":88,\"g\":94},\"6368\":{\"m\":88,\"g\":94},\"6324\":{\"m\":88,\"g\":94},\"6316\":{\"m\":88,\"g\":94},\"5099\":{\"m\":88,\"g\":94},\"6361\":{\"m\":88,\"g\":94},\"6360\":{\"m\":88,\"g\":94},\"6358\":{\"m\":88,\"g\":94},\"6359\":{\"m\":88,\"g\":94},\"6121\":{\"m\":88,\"g\":94},\"5694\":{\"m\":88,\"g\":94},\"6334\":{\"m\":88,\"g\":94},\"6136\":{\"m\":88,\"g\":94},\"6336\":{\"m\":88,\"g\":94},\"6327\":{\"m\":88,\"g\":94},\"6302\":{\"m\":88,\"g\":94},\"6147\":{\"m\":88,\"g\":94},\"6317\":{\"m\":88,\"g\":94},\"6216\":{\"m\":88,\"g\":94},\"6298\":{\"m\":88,\"g\":94},\"6109\":{\"m\":88,\"g\":94},\"5914\":{\"m\":88,\"g\":94},\"6009\":{\"m\":88,\"g\":94},\"6274\":{\"m\":88,\"g\":94},\"6283\":{\"m\":88,\"g\":94},\"6300\":{\"m\":88,\"g\":94},\"6138\":{\"m\":88,\"g\":94},\"6282\":{\"m\":88,\"g\":94},\"6276\":{\"m\":88,\"g\":94},\"6273\":{\"m\":88
,\"g\":94},\"6115\":{\"m\":88,\"g\":94},\"7038\":{\"m\":89,\"g\":94},\"6833\":{\"m\":89,\"g\":94},\"7029\":{\"m\":89,\"g\":94},\"7027\":{\"m\":89,\"g\":94},\"6980\":{\"m\":89,\"g\":94},\"6964\":{\"m\":89,\"g\":94},\"7023\":{\"m\":89,\"g\":94},\"7018\":{\"m\":89,\"g\":94},\"7017\":{\"m\":89,\"g\":94},\"6741\":{\"m\":89,\"g\":94},\"6987\":{\"m\":89,\"g\":94},\"7015\":{\"m\":89,\"g\":94},\"7013\":{\"m\":89,\"g\":94},\"6884\":{\"m\":89,\"g\":94},\"6992\":{\"m\":89,\"g\":94},\"6998\":{\"m\":89,\"g\":94},\"7008\":{\"m\":89,\"g\":94},\"7007\":{\"m\":89,\"g\":94},\"6958\":{\"m\":89,\"g\":94},\"6557\":{\"m\":89,\"g\":94},\"6990\":{\"m\":89,\"g\":94},\"6960\":{\"m\":89,\"g\":94},\"6973\":{\"m\":89,\"g\":94},\"6983\":{\"m\":89,\"g\":94},\"6929\":{\"m\":89,\"g\":94},\"6977\":{\"m\":89,\"g\":94},\"6967\":{\"m\":89,\"g\":94},\"6981\":{\"m\":89,\"g\":94},\"6979\":{\"m\":89,\"g\":94},\"6976\":{\"m\":89,\"g\":94},\"6970\":{\"m\":89,\"g\":94},\"6937\":{\"m\":89,\"g\":94},\"6965\":{\"m\":89,\"g\":94},\"6974\":{\"m\":89,\"g\":94},\"6966\":{\"m\":89,\"g\":94},\"6956\":{\"m\":89,\"g\":94},\"6963\":{\"m\":89,\"g\":94},\"6926\":{\"m\":89,\"g\":94},\"6968\":{\"m\":89,\"g\":94},\"6853\":{\"m\":89,\"g\":94},\"6957\":{\"m\":89,\"g\":94},\"6955\":{\"m\":89,\"g\":94},\"6916\":{\"m\":89,\"g\":94},\"6885\":{\"m\":89,\"g\":94},\"6950\":{\"m\":89,\"g\":94},\"6953\":{\"m\":89,\"g\":94},\"6220\":{\"m\":89,\"g\":94},\"6866\":{\"m\":89,\"g\":94},\"6895\":{\"m\":89,\"g\":94},\"6874\":{\"m\":89,\"g\":94},\"6915\":{\"m\":89,\"g\":94},\"6924\":{\"m\":89,\"g\":94},\"6945\":{\"m\":89,\"g\":94},\"6369\":{\"m\":89,\"g\":94},\"5955\":{\"m\":89,\"g\":94},\"6912\":{\"m\":89,\"g\":94},\"6910\":{\"m\":89,\"g\":94},\"6944\":{\"m\":89,\"g\":94},\"6943\":{\"m\":89,\"g\":94},\"6942\":{\"m\":89,\"g\":94},\"6939\":{\"m\":89,\"g\":94},\"6767\":{\"m\":89,\"g\":94},\"6879\":{\"m\":89,\"g\":94},\"6922\":{\"m\":89,\"g\":94},\"6932\":{\"m\":89,\"g\":94},\"6934\":{\"m\":89,\"g\":94},\"6931\":{\"m\":89,\"g\":94},\"6930\":{\"m\":8
9,\"g\":94},\"6838\":{\"m\":89,\"g\":94},\"6458\":{\"m\":89,\"g\":94},\"6877\":{\"m\":89,\"g\":94},\"6887\":{\"m\":89,\"g\":94},\"6890\":{\"m\":89,\"g\":94},\"6764\":{\"m\":89,\"g\":94},\"6170\":{\"m\":89,\"g\":94},\"6837\":{\"m\":89,\"g\":94},\"6868\":{\"m\":89,\"g\":94},\"6865\":{\"m\":89,\"g\":94},\"6851\":{\"m\":89,\"g\":94},\"6846\":{\"m\":89,\"g\":94},\"6820\":{\"m\":89,\"g\":94},\"6277\":{\"m\":89,\"g\":94},\"6861\":{\"m\":89,\"g\":94},\"6878\":{\"m\":89,\"g\":94},\"6736\":{\"m\":89,\"g\":94},\"6852\":{\"m\":89,\"g\":94},\"6460\":{\"m\":89,\"g\":94},\"6659\":{\"m\":89,\"g\":94},\"6858\":{\"m\":89,\"g\":94},\"6745\":{\"m\":89,\"g\":94},\"5929\":{\"m\":89,\"g\":94},\"6735\":{\"m\":89,\"g\":94},\"6816\":{\"m\":89,\"g\":94},\"6818\":{\"m\":89,\"g\":94},\"6671\":{\"m\":89,\"g\":94},\"6456\":{\"m\":89,\"g\":94},\"6766\":{\"m\":89,\"g\":94},\"6093\":{\"m\":89,\"g\":94},\"6812\":{\"m\":89,\"g\":94},\"6811\":{\"m\":89,\"g\":94},\"6815\":{\"m\":89,\"g\":94},\"6813\":{\"m\":89,\"g\":94},\"6780\":{\"m\":89,\"g\":94},\"5382\":{\"m\":89,\"g\":94},\"6805\":{\"m\":89,\"g\":94},\"6699\":{\"m\":89,\"g\":94},\"6803\":{\"m\":89,\"g\":94},\"6804\":{\"m\":89,\"g\":94},\"6799\":{\"m\":89,\"g\":94},\"6800\":{\"m\":89,\"g\":94},\"5981\":{\"m\":89,\"g\":94},\"6421\":{\"m\":89,\"g\":94},\"6797\":{\"m\":89,\"g\":94},\"6795\":{\"m\":89,\"g\":94},\"6787\":{\"m\":89,\"g\":94},\"6794\":{\"m\":89,\"g\":94},\"6788\":{\"m\":89,\"g\":94},\"6791\":{\"m\":89,\"g\":94},\"6792\":{\"m\":89,\"g\":94},\"6734\":{\"m\":89,\"g\":94},\"6408\":{\"m\":89,\"g\":94},\"6786\":{\"m\":89,\"g\":94},\"6785\":{\"m\":89,\"g\":94},\"6782\":{\"m\":89,\"g\":94},\"6784\":{\"m\":89,\"g\":94},\"6679\":{\"m\":89,\"g\":94},\"6772\":{\"m\":89,\"g\":94},\"6289\":{\"m\":89,\"g\":94},\"6509\":{\"m\":89,\"g\":94},\"6737\":{\"m\":89,\"g\":94},\"6761\":{\"m\":89,\"g\":94},\"6765\":{\"m\":89,\"g\":94},\"6727\":{\"m\":89,\"g\":94},\"6265\":{\"m\":89,\"g\":94},\"6748\":{\"m\":89,\"g\":94},\"6746\":{\"m\":89,\"g\":94},\"6742\":{\"m\":
89,\"g\":94},\"6728\":{\"m\":89,\"g\":94},\"6680\":{\"m\":89,\"g\":94},\"6437\":{\"m\":89,\"g\":94},\"6729\":{\"m\":89,\"g\":94},\"6545\":{\"m\":89,\"g\":94},\"6705\":{\"m\":89,\"g\":94},\"6715\":{\"m\":89,\"g\":94},\"6725\":{\"m\":89,\"g\":94},\"6718\":{\"m\":89,\"g\":94},\"6720\":{\"m\":89,\"g\":94},\"6726\":{\"m\":89,\"g\":94},\"6676\":{\"m\":89,\"g\":94},\"6709\":{\"m\":89,\"g\":94},\"6719\":{\"m\":89,\"g\":94},\"6479\":{\"m\":89,\"g\":94},\"6668\":{\"m\":89,\"g\":94},\"6682\":{\"m\":89,\"g\":94},\"6711\":{\"m\":89,\"g\":94},\"6710\":{\"m\":89,\"g\":94},\"6712\":{\"m\":89,\"g\":94},\"6706\":{\"m\":89,\"g\":94},\"6703\":{\"m\":89,\"g\":94},\"6697\":{\"m\":89,\"g\":94},\"6649\":{\"m\":89,\"g\":94},\"6685\":{\"m\":89,\"g\":94},\"6689\":{\"m\":89,\"g\":94},\"6693\":{\"m\":89,\"g\":94},\"6655\":{\"m\":89,\"g\":94},\"6260\":{\"m\":89,\"g\":94},\"6627\":{\"m\":89,\"g\":94},\"6687\":{\"m\":89,\"g\":94},\"6582\":{\"m\":89,\"g\":94},\"6672\":{\"m\":89,\"g\":94},\"6678\":{\"m\":89,\"g\":94},\"6380\":{\"m\":89,\"g\":94},\"6665\":{\"m\":89,\"g\":94},\"6673\":{\"m\":89,\"g\":94},\"6007\":{\"m\":89,\"g\":94},\"6473\":{\"m\":89,\"g\":94},\"6677\":{\"m\":89,\"g\":94},\"6661\":{\"m\":89,\"g\":94},\"6660\":{\"m\":89,\"g\":94},\"6638\":{\"m\":89,\"g\":94},\"6606\":{\"m\":89,\"g\":94},\"6662\":{\"m\":89,\"g\":94},\"6652\":{\"m\":89,\"g\":94},\"6658\":{\"m\":89,\"g\":94},\"6403\":{\"m\":89,\"g\":94},\"6601\":{\"m\":89,\"g\":94},\"6640\":{\"m\":89,\"g\":94},\"6450\":{\"m\":89,\"g\":94},\"6650\":{\"m\":89,\"g\":94},\"6585\":{\"m\":89,\"g\":94},\"6648\":{\"m\":89,\"g\":94},\"6634\":{\"m\":89,\"g\":94},\"6646\":{\"m\":89,\"g\":94},\"6643\":{\"m\":89,\"g\":94},\"6547\":{\"m\":89,\"g\":94},\"6631\":{\"m\":89,\"g\":94},\"6603\":{\"m\":89,\"g\":94},\"6263\":{\"m\":89,\"g\":94},\"6639\":{\"m\":89,\"g\":94},\"6635\":{\"m\":89,\"g\":94},\"6620\":{\"m\":89,\"g\":94},\"6629\":{\"m\":89,\"g\":94},\"6628\":{\"m\":89,\"g\":94},\"6617\":{\"m\":89,\"g\":94},\"6306\":{\"m\":89,\"g\":94},\"6599\":{\"m\"
:89,\"g\":94},\"6611\":{\"m\":89,\"g\":94},\"6598\":{\"m\":89,\"g\":94},\"6597\":{\"m\":89,\"g\":94},\"6581\":{\"m\":89,\"g\":94},\"6610\":{\"m\":89,\"g\":94},\"6575\":{\"m\":89,\"g\":94},\"6609\":{\"m\":89,\"g\":94},\"6594\":{\"m\":89,\"g\":94},\"6587\":{\"m\":89,\"g\":94},\"6596\":{\"m\":89,\"g\":94},\"6595\":{\"m\":89,\"g\":94},\"6593\":{\"m\":89,\"g\":94},\"6586\":{\"m\":89,\"g\":94},\"6571\":{\"m\":89,\"g\":94},\"6439\":{\"m\":89,\"g\":94},\"6527\":{\"m\":89,\"g\":94},\"6588\":{\"m\":89,\"g\":94},\"6570\":{\"m\":89,\"g\":94},\"6546\":{\"m\":89,\"g\":94},\"6537\":{\"m\":89,\"g\":94},\"6494\":{\"m\":89,\"g\":94},\"5961\":{\"m\":89,\"g\":94},\"6543\":{\"m\":89,\"g\":94},\"4068\":{\"m\":89,\"g\":94},\"6577\":{\"m\":89,\"g\":94},\"6578\":{\"m\":89,\"g\":94},\"6477\":{\"m\":89,\"g\":94},\"6564\":{\"m\":89,\"g\":94},\"6576\":{\"m\":89,\"g\":94},\"7248\":{\"m\":90,\"g\":94},\"7244\":{\"m\":90,\"g\":94},\"7247\":{\"m\":90,\"g\":94},\"6058\":{\"m\":90,\"g\":94},\"7234\":{\"m\":90,\"g\":94},\"7245\":{\"m\":90,\"g\":94},\"7231\":{\"m\":90,\"g\":94},\"7239\":{\"m\":90,\"g\":94},\"7232\":{\"m\":90,\"g\":94},\"7207\":{\"m\":90,\"g\":94},\"7213\":{\"m\":90,\"g\":94},\"7228\":{\"m\":90,\"g\":94},\"7221\":{\"m\":90,\"g\":94},\"7218\":{\"m\":90,\"g\":94},\"6378\":{\"m\":90,\"g\":94},\"7217\":{\"m\":90,\"g\":94},\"7215\":{\"m\":90,\"g\":94},\"7214\":{\"m\":90,\"g\":94},\"7210\":{\"m\":90,\"g\":94},\"7163\":{\"m\":90,\"g\":94},\"7205\":{\"m\":90,\"g\":94},\"7204\":{\"m\":90,\"g\":94},\"7202\":{\"m\":90,\"g\":94},\"7200\":{\"m\":90,\"g\":94},\"7196\":{\"m\":90,\"g\":94},\"7198\":{\"m\":90,\"g\":94},\"7195\":{\"m\":90,\"g\":94},\"7186\":{\"m\":90,\"g\":94},\"7180\":{\"m\":90,\"g\":94},\"7189\":{\"m\":90,\"g\":94},\"7191\":{\"m\":90,\"g\":94},\"7190\":{\"m\":90,\"g\":94},\"7184\":{\"m\":90,\"g\":94},\"7181\":{\"m\":90,\"g\":94},\"7177\":{\"m\":90,\"g\":94},\"7178\":{\"m\":90,\"g\":94},\"7175\":{\"m\":90,\"g\":94},\"7172\":{\"m\":90,\"g\":94},\"7173\":{\"m\":90,\"g\":94},\"7161\":{\"m\
":90,\"g\":94},\"7150\":{\"m\":90,\"g\":94},\"7153\":{\"m\":90,\"g\":94},\"7170\":{\"m\":90,\"g\":94},\"7165\":{\"m\":90,\"g\":94},\"7156\":{\"m\":90,\"g\":94},\"7020\":{\"m\":90,\"g\":94},\"7157\":{\"m\":90,\"g\":94},\"6814\":{\"m\":90,\"g\":94},\"7154\":{\"m\":90,\"g\":94},\"7155\":{\"m\":90,\"g\":94},\"7152\":{\"m\":90,\"g\":94},\"7146\":{\"m\":90,\"g\":94},\"7145\":{\"m\":90,\"g\":94},\"7056\":{\"m\":90,\"g\":94},\"7058\":{\"m\":90,\"g\":94},\"7140\":{\"m\":90,\"g\":94},\"7134\":{\"m\":90,\"g\":94},\"6026\":{\"m\":90,\"g\":94},\"7092\":{\"m\":90,\"g\":94},\"7126\":{\"m\":90,\"g\":94},\"7119\":{\"m\":90,\"g\":94},\"7115\":{\"m\":90,\"g\":94},\"6919\":{\"m\":90,\"g\":94},\"7093\":{\"m\":90,\"g\":94},\"6994\":{\"m\":90,\"g\":94},\"6824\":{\"m\":90,\"g\":94},\"7079\":{\"m\":90,\"g\":94},\"6106\":{\"m\":90,\"g\":94},\"7091\":{\"m\":90,\"g\":94},\"7097\":{\"m\":90,\"g\":94},\"6870\":{\"m\":90,\"g\":94},\"7067\":{\"m\":90,\"g\":94},\"7076\":{\"m\":90,\"g\":94},\"7046\":{\"m\":90,\"g\":94},\"7054\":{\"m\":90,\"g\":94},\"7073\":{\"m\":90,\"g\":94},\"6031\":{\"m\":90,\"g\":94},\"7071\":{\"m\":90,\"g\":94},\"6579\":{\"m\":90,\"g\":94},\"7066\":{\"m\":90,\"g\":94},\"6716\":{\"m\":90,\"g\":94},\"7064\":{\"m\":90,\"g\":94},\"7049\":{\"m\":90,\"g\":94},\"7063\":{\"m\":90,\"g\":94},\"6947\":{\"m\":90,\"g\":94},\"7057\":{\"m\":90,\"g\":94},\"7061\":{\"m\":90,\"g\":94},\"7060\":{\"m\":90,\"g\":94},\"7021\":{\"m\":90,\"g\":94},\"7053\":{\"m\":90,\"g\":94},\"7037\":{\"m\":90,\"g\":94},\"7051\":{\"m\":90,\"g\":94},\"6999\":{\"m\":90,\"g\":94},\"7045\":{\"m\":90,\"g\":94},\"7043\":{\"m\":90,\"g\":94},\"7040\":{\"m\":90,\"g\":94},\"7493\":{\"m\":91,\"g\":94},\"7490\":{\"m\":91,\"g\":94},\"7376\":{\"m\":91,\"g\":94},\"7487\":{\"m\":91,\"g\":94},\"7347\":{\"m\":91,\"g\":94},\"7269\":{\"m\":91,\"g\":94},\"7449\":{\"m\":91,\"g\":94},\"7469\":{\"m\":91,\"g\":94},\"7378\":{\"m\":91,\"g\":94},\"7481\":{\"m\":91,\"g\":94},\"7382\":{\"m\":91,\"g\":94},\"7480\":{\"m\":91,\"g\":94},\"7456\":{\"m
\":91,\"g\":94},\"7479\":{\"m\":91,\"g\":94},\"7397\":{\"m\":91,\"g\":94},\"7472\":{\"m\":91,\"g\":94},\"7457\":{\"m\":91,\"g\":94},\"7290\":{\"m\":91,\"g\":94},\"6821\":{\"m\":91,\"g\":94},\"7454\":{\"m\":91,\"g\":94},\"7451\":{\"m\":91,\"g\":94},\"7445\":{\"m\":91,\"g\":94},\"7361\":{\"m\":91,\"g\":94},\"7391\":{\"m\":91,\"g\":94},\"7441\":{\"m\":91,\"g\":94},\"7327\":{\"m\":91,\"g\":94},\"7406\":{\"m\":91,\"g\":94},\"7414\":{\"m\":91,\"g\":94},\"7408\":{\"m\":91,\"g\":94},\"7412\":{\"m\":91,\"g\":94},\"7351\":{\"m\":91,\"g\":94},\"7420\":{\"m\":91,\"g\":94},\"7409\":{\"m\":91,\"g\":94},\"7425\":{\"m\":91,\"g\":94},\"6984\":{\"m\":91,\"g\":94},\"7394\":{\"m\":91,\"g\":94},\"7396\":{\"m\":91,\"g\":94},\"7400\":{\"m\":91,\"g\":94},\"7329\":{\"m\":91,\"g\":94},\"7401\":{\"m\":91,\"g\":94},\"7403\":{\"m\":91,\"g\":94},\"7402\":{\"m\":91,\"g\":94},\"7285\":{\"m\":91,\"g\":94},\"7219\":{\"m\":91,\"g\":94},\"7360\":{\"m\":91,\"g\":94},\"7399\":{\"m\":91,\"g\":94},\"7398\":{\"m\":91,\"g\":94},\"6389\":{\"m\":91,\"g\":94},\"7372\":{\"m\":91,\"g\":94},\"5485\":{\"m\":91,\"g\":94},\"7393\":{\"m\":91,\"g\":94},\"7356\":{\"m\":91,\"g\":94},\"7326\":{\"m\":91,\"g\":94},\"7322\":{\"m\":91,\"g\":94},\"7371\":{\"m\":91,\"g\":94},\"7364\":{\"m\":91,\"g\":94},\"7159\":{\"m\":91,\"g\":94},\"7370\":{\"m\":91,\"g\":94},\"7343\":{\"m\":91,\"g\":94},\"7366\":{\"m\":91,\"g\":94},\"7362\":{\"m\":91,\"g\":94},\"7242\":{\"m\":91,\"g\":94},\"7363\":{\"m\":91,\"g\":94},\"7303\":{\"m\":91,\"g\":94},\"7354\":{\"m\":91,\"g\":94},\"7099\":{\"m\":91,\"g\":94},\"7333\":{\"m\":91,\"g\":94},\"7331\":{\"m\":91,\"g\":94},\"7284\":{\"m\":91,\"g\":94},\"7096\":{\"m\":91,\"g\":94},\"7319\":{\"m\":91,\"g\":94},\"7003\":{\"m\":91,\"g\":94},\"7301\":{\"m\":91,\"g\":94},\"7297\":{\"m\":91,\"g\":94},\"7251\":{\"m\":91,\"g\":94},\"7300\":{\"m\":91,\"g\":94},\"6614\":{\"m\":91,\"g\":94},\"7267\":{\"m\":91,\"g\":94},\"7264\":{\"m\":91,\"g\":94},\"7237\":{\"m\":91,\"g\":94},\"7289\":{\"m\":91,\"g\":94},\"7288\":{\"
m\":91,\"g\":94},\"7286\":{\"m\":91,\"g\":94},\"6842\":{\"m\":91,\"g\":94},\"7283\":{\"m\":91,\"g\":94},\"7164\":{\"m\":91,\"g\":94},\"7022\":{\"m\":91,\"g\":94},\"6081\":{\"m\":91,\"g\":94},\"7265\":{\"m\":91,\"g\":94},\"7160\":{\"m\":91,\"g\":94},\"7179\":{\"m\":91,\"g\":94},\"7167\":{\"m\":91,\"g\":94},\"7252\":{\"m\":91,\"g\":94},\"7122\":{\"m\":91,\"g\":94},\"7233\":{\"m\":91,\"g\":94},\"7125\":{\"m\":91,\"g\":94},\"7559\":{\"m\":92,\"g\":94},\"7541\":{\"m\":92,\"g\":94},\"7542\":{\"m\":92,\"g\":94},\"7544\":{\"m\":92,\"g\":94},\"7549\":{\"m\":92,\"g\":94},\"7543\":{\"m\":92,\"g\":94},\"7527\":{\"m\":92,\"g\":94},\"7531\":{\"m\":92,\"g\":94},\"7148\":{\"m\":92,\"g\":94},\"7499\":{\"m\":92,\"g\":94},\"7507\":{\"m\":92,\"g\":94},\"7513\":{\"m\":92,\"g\":94},\"7521\":{\"m\":92,\"g\":94},\"7520\":{\"m\":92,\"g\":94},\"6793\":{\"m\":92,\"g\":94},\"6721\":{\"m\":92,\"g\":94},\"7386\":{\"m\":92,\"g\":94},\"7522\":{\"m\":92,\"g\":94},\"6641\":{\"m\":92,\"g\":94},\"6626\":{\"m\":92,\"g\":94},\"7498\":{\"m\":92,\"g\":94},\"7512\":{\"m\":92,\"g\":94},\"7510\":{\"m\":92,\"g\":94},\"7516\":{\"m\":92,\"g\":94},\"7508\":{\"m\":92,\"g\":94},\"7437\":{\"m\":92,\"g\":94},\"7489\":{\"m\":92,\"g\":94},\"7505\":{\"m\":92,\"g\":94},\"6717\":{\"m\":92,\"g\":94},\"7277\":{\"m\":92,\"g\":94},\"7422\":{\"m\":92,\"g\":94},\"7439\":{\"m\":92,\"g\":94},\"7236\":{\"m\":92,\"g\":94},\"7423\":{\"m\":92,\"g\":94},\"7268\":{\"m\":92,\"g\":94},\"7802\":{\"m\":93,\"g\":94},\"7801\":{\"m\":93,\"g\":94},\"7799\":{\"m\":93,\"g\":94},\"7792\":{\"m\":93,\"g\":94},\"7800\":{\"m\":93,\"g\":94},\"7756\":{\"m\":93,\"g\":94},\"7790\":{\"m\":93,\"g\":94},\"7757\":{\"m\":93,\"g\":94},\"7786\":{\"m\":93,\"g\":94},\"7222\":{\"m\":93,\"g\":94},\"7787\":{\"m\":93,\"g\":94},\"7784\":{\"m\":93,\"g\":94},\"7444\":{\"m\":93,\"g\":94},\"7623\":{\"m\":93,\"g\":94},\"7782\":{\"m\":93,\"g\":94},\"7596\":{\"m\":93,\"g\":94},\"7772\":{\"m\":93,\"g\":94},\"7705\":{\"m\":93,\"g\":94},\"7745\":{\"m\":93,\"g\":94},\"7778\":{\
"m\":93,\"g\":94},\"7419\":{\"m\":93,\"g\":94},\"7418\":{\"m\":93,\"g\":94},\"7748\":{\"m\":93,\"g\":94},\"7764\":{\"m\":93,\"g\":94},\"7741\":{\"m\":93,\"g\":94},\"7729\":{\"m\":93,\"g\":94},\"7390\":{\"m\":93,\"g\":94},\"7759\":{\"m\":93,\"g\":94},\"7751\":{\"m\":93,\"g\":94},\"7754\":{\"m\":93,\"g\":94},\"7755\":{\"m\":93,\"g\":94},\"7744\":{\"m\":93,\"g\":94},\"7752\":{\"m\":93,\"g\":94},\"7723\":{\"m\":93,\"g\":94},\"7673\":{\"m\":93,\"g\":94},\"7740\":{\"m\":93,\"g\":94},\"7681\":{\"m\":93,\"g\":94},\"6771\":{\"m\":93,\"g\":94},\"7647\":{\"m\":93,\"g\":94},\"7750\":{\"m\":93,\"g\":94},\"7722\":{\"m\":93,\"g\":94},\"7738\":{\"m\":93,\"g\":94},\"7278\":{\"m\":93,\"g\":94},\"7731\":{\"m\":93,\"g\":94},\"7735\":{\"m\":93,\"g\":94},\"6770\":{\"m\":93,\"g\":94},\"7734\":{\"m\":93,\"g\":94},\"7714\":{\"m\":93,\"g\":94},\"6698\":{\"m\":93,\"g\":94},\"7462\":{\"m\":93,\"g\":94},\"6549\":{\"m\":93,\"g\":94},\"7621\":{\"m\":93,\"g\":94},\"6512\":{\"m\":93,\"g\":94},\"7292\":{\"m\":93,\"g\":94},\"7416\":{\"m\":93,\"g\":94},\"7717\":{\"m\":93,\"g\":94},\"7677\":{\"m\":93,\"g\":94},\"7683\":{\"m\":93,\"g\":94},\"7697\":{\"m\":93,\"g\":94},\"7635\":{\"m\":93,\"g\":94},\"7642\":{\"m\":93,\"g\":94},\"7698\":{\"m\":93,\"g\":94},\"7684\":{\"m\":93,\"g\":94},\"7688\":{\"m\":93,\"g\":94},\"7676\":{\"m\":93,\"g\":94},\"7629\":{\"m\":93,\"g\":94},\"6985\":{\"m\":93,\"g\":94},\"7675\":{\"m\":93,\"g\":94},\"7486\":{\"m\":93,\"g\":94},\"7671\":{\"m\":93,\"g\":94},\"7648\":{\"m\":93,\"g\":94},\"7318\":{\"m\":93,\"g\":94},\"7663\":{\"m\":93,\"g\":94},\"7627\":{\"m\":93,\"g\":94},\"7632\":{\"m\":93,\"g\":94},\"7524\":{\"m\":93,\"g\":94},\"7643\":{\"m\":93,\"g\":94},\"7640\":{\"m\":93,\"g\":94},\"7539\":{\"m\":93,\"g\":94},\"7176\":{\"m\":93,\"g\":94},\"7580\":{\"m\":93,\"g\":94},\"7628\":{\"m\":93,\"g\":94},\"7619\":{\"m\":93,\"g\":94},\"7432\":{\"m\":93,\"g\":94},\"7636\":{\"m\":93,\"g\":94},\"7630\":{\"m\":93,\"g\":94},\"7624\":{\"m\":93,\"g\":94},\"7625\":{\"m\":93,\"g\":94},\"7036\":{
\"m\":93,\"g\":94},\"7620\":{\"m\":93,\"g\":94},\"7310\":{\"m\":93,\"g\":94},\"7309\":{\"m\":93,\"g\":94},\"7618\":{\"m\":93,\"g\":94},\"7598\":{\"m\":93,\"g\":94},\"7446\":{\"m\":93,\"g\":94},\"7584\":{\"m\":93,\"g\":94},\"7552\":{\"m\":93,\"g\":94},\"7308\":{\"m\":93,\"g\":94},\"6769\":{\"m\":93,\"g\":94},\"6563\":{\"m\":93,\"g\":94},\"7612\":{\"m\":93,\"g\":94},\"7588\":{\"m\":93,\"g\":94},\"7610\":{\"m\":93,\"g\":94},\"7581\":{\"m\":93,\"g\":94},\"7225\":{\"m\":93,\"g\":94},\"7577\":{\"m\":93,\"g\":94},\"7540\":{\"m\":93,\"g\":94},\"7208\":{\"m\":93,\"g\":94},\"7569\":{\"m\":93,\"g\":94},\"7573\":{\"m\":93,\"g\":94},\"7575\":{\"m\":93,\"g\":94},\"7330\":{\"m\":93,\"g\":94},\"7882\":{\"m\":95,\"g\":97},\"7880\":{\"m\":95,\"g\":97},\"7660\":{\"m\":95,\"g\":97},\"7846\":{\"m\":95,\"g\":97},\"7818\":{\"m\":95,\"g\":97},\"7866\":{\"m\":95,\"g\":97},\"7724\":{\"m\":95,\"g\":97},\"7830\":{\"m\":95,\"g\":97},\"7579\":{\"m\":95,\"g\":97},\"7840\":{\"m\":95,\"g\":97},\"7864\":{\"m\":95,\"g\":97},\"7860\":{\"m\":95,\"g\":97},\"7832\":{\"m\":95,\"g\":97},\"7853\":{\"m\":95,\"g\":97},\"7850\":{\"m\":95,\"g\":97},\"7129\":{\"m\":95,\"g\":97},\"7762\":{\"m\":95,\"g\":97},\"7821\":{\"m\":95,\"g\":97},\"7187\":{\"m\":95,\"g\":97},\"7816\":{\"m\":95,\"g\":97},\"7798\":{\"m\":95,\"g\":94},\"7797\":{\"m\":95,\"g\":94},\"7313\":{\"m\":95,\"g\":94},\"7689\":{\"m\":95,\"g\":94},\"7813\":{\"m\":95,\"g\":94},\"6094\":{\"m\":95,\"g\":94},\"7794\":{\"m\":95,\"g\":94},\"7812\":{\"m\":95,\"g\":94},\"7733\":{\"m\":95,\"g\":94},\"5246\":{\"m\":95,\"g\":94},\"7785\":{\"m\":95,\"g\":94},\"7709\":{\"m\":95,\"g\":94},\"7793\":{\"m\":95,\"g\":94},\"7803\":{\"m\":95,\"g\":94},\"7796\":{\"m\":95,\"g\":94},\"7963\":{\"m\":96,\"g\":97},\"7971\":{\"m\":96,\"g\":97},\"7960\":{\"m\":96,\"g\":97},\"7969\":{\"m\":96,\"g\":97},\"7970\":{\"m\":96,\"g\":97},\"7962\":{\"m\":96,\"g\":97},\"7968\":{\"m\":96,\"g\":97},\"7964\":{\"m\":96,\"g\":97},\"7932\":{\"m\":96,\"g\":97},\"7961\":{\"m\":96,\"g\":97},\"7953\":
{\"m\":96,\"g\":97},\"7795\":{\"m\":96,\"g\":97},\"7940\":{\"m\":96,\"g\":97},\"7775\":{\"m\":96,\"g\":97},\"7791\":{\"m\":96,\"g\":97},\"6449\":{\"m\":96,\"g\":97},\"7922\":{\"m\":96,\"g\":97},\"7907\":{\"m\":96,\"g\":97},\"5888\":{\"m\":96,\"g\":97},\"7904\":{\"m\":96,\"g\":97},\"7608\":{\"m\":96,\"g\":97},\"7899\":{\"m\":96,\"g\":97},\"7898\":{\"m\":96,\"g\":97},\"7838\":{\"m\":96,\"g\":97},\"7885\":{\"m\":96,\"g\":97},\"7872\":{\"m\":96,\"g\":97},\"7895\":{\"m\":96,\"g\":97},\"8265\":{\"m\":98,\"g\":103},\"8260\":{\"m\":98,\"g\":103},\"7822\":{\"m\":98,\"g\":103},\"8059\":{\"m\":98,\"g\":103},\"8257\":{\"m\":98,\"g\":103},\"7484\":{\"m\":98,\"g\":103},\"8221\":{\"m\":98,\"g\":103},\"8237\":{\"m\":98,\"g\":103},\"8231\":{\"m\":98,\"g\":103},\"8202\":{\"m\":98,\"g\":103},\"8204\":{\"m\":98,\"g\":103},\"8209\":{\"m\":98,\"g\":97},\"8208\":{\"m\":98,\"g\":97},\"8107\":{\"m\":98,\"g\":97},\"8193\":{\"m\":98,\"g\":97},\"8200\":{\"m\":98,\"g\":97},\"8195\":{\"m\":98,\"g\":97},\"7935\":{\"m\":98,\"g\":97},\"8184\":{\"m\":98,\"g\":97},\"8197\":{\"m\":98,\"g\":97},\"8067\":{\"m\":98,\"g\":97},\"8163\":{\"m\":98,\"g\":97},\"8183\":{\"m\":98,\"g\":97},\"7983\":{\"m\":98,\"g\":97},\"8182\":{\"m\":98,\"g\":97},\"8181\":{\"m\":98,\"g\":97},\"7825\":{\"m\":98,\"g\":97},\"8178\":{\"m\":98,\"g\":97},\"8176\":{\"m\":98,\"g\":97},\"7312\":{\"m\":98,\"g\":97},\"6230\":{\"m\":98,\"g\":97},\"7999\":{\"m\":98,\"g\":97},\"8115\":{\"m\":98,\"g\":97},\"8175\":{\"m\":98,\"g\":97},\"8172\":{\"m\":98,\"g\":97},\"8019\":{\"m\":98,\"g\":97},\"8167\":{\"m\":98,\"g\":97},\"8170\":{\"m\":98,\"g\":97},\"8169\":{\"m\":98,\"g\":97},\"8103\":{\"m\":98,\"g\":97},\"8161\":{\"m\":98,\"g\":97},\"8171\":{\"m\":98,\"g\":97},\"8168\":{\"m\":98,\"g\":97},\"8166\":{\"m\":98,\"g\":97},\"8165\":{\"m\":98,\"g\":97},\"7966\":{\"m\":98,\"g\":97},\"8157\":{\"m\":98,\"g\":97},\"8160\":{\"m\":98,\"g\":97},\"8158\":{\"m\":98,\"g\":97},\"8028\":{\"m\":98,\"g\":97},\"7931\":{\"m\":98,\"g\":97},\"8048\":{\"m\":98,\"g\":9
7},\"7302\":{\"m\":98,\"g\":97},\"6881\":{\"m\":98,\"g\":97},\"8155\":{\"m\":98,\"g\":97},\"7661\":{\"m\":98,\"g\":97},\"7987\":{\"m\":98,\"g\":97},\"8113\":{\"m\":98,\"g\":97},\"8147\":{\"m\":98,\"g\":97},\"8142\":{\"m\":98,\"g\":97},\"8136\":{\"m\":98,\"g\":97},\"8141\":{\"m\":98,\"g\":97},\"7704\":{\"m\":98,\"g\":97},\"7820\":{\"m\":98,\"g\":97},\"7889\":{\"m\":98,\"g\":97},\"7506\":{\"m\":98,\"g\":97},\"7959\":{\"m\":98,\"g\":97},\"8127\":{\"m\":98,\"g\":97},\"7924\":{\"m\":98,\"g\":97},\"7030\":{\"m\":98,\"g\":97},\"8117\":{\"m\":98,\"g\":97},\"8102\":{\"m\":98,\"g\":97},\"8046\":{\"m\":98,\"g\":97},\"7884\":{\"m\":98,\"g\":97},\"7989\":{\"m\":98,\"g\":97},\"8105\":{\"m\":98,\"g\":97},\"8110\":{\"m\":98,\"g\":97},\"8108\":{\"m\":98,\"g\":97},\"8100\":{\"m\":98,\"g\":97},\"7597\":{\"m\":98,\"g\":97},\"8075\":{\"m\":98,\"g\":97},\"7992\":{\"m\":98,\"g\":97},\"7634\":{\"m\":98,\"g\":97},\"8098\":{\"m\":98,\"g\":97},\"8090\":{\"m\":98,\"g\":97},\"8086\":{\"m\":98,\"g\":97},\"8077\":{\"m\":98,\"g\":97},\"8001\":{\"m\":98,\"g\":97},\"7760\":{\"m\":98,\"g\":97},\"8029\":{\"m\":98,\"g\":97},\"8058\":{\"m\":98,\"g\":97},\"5163\":{\"m\":98,\"g\":97},\"8045\":{\"m\":98,\"g\":97},\"8047\":{\"m\":98,\"g\":97},\"7943\":{\"m\":98,\"g\":97},\"8052\":{\"m\":98,\"g\":97},\"6556\":{\"m\":98,\"g\":97},\"8022\":{\"m\":98,\"g\":97},\"8002\":{\"m\":98,\"g\":97},\"8023\":{\"m\":98,\"g\":97},\"8044\":{\"m\":98,\"g\":97},\"7887\":{\"m\":98,\"g\":97},\"8035\":{\"m\":98,\"g\":97},\"7897\":{\"m\":98,\"g\":97},\"7982\":{\"m\":98,\"g\":97},\"8006\":{\"m\":98,\"g\":97},\"7649\":{\"m\":98,\"g\":97},\"7653\":{\"m\":98,\"g\":97},\"8005\":{\"m\":98,\"g\":97},\"7874\":{\"m\":98,\"g\":97},\"8021\":{\"m\":98,\"g\":97},\"8010\":{\"m\":98,\"g\":97},\"7902\":{\"m\":98,\"g\":97},\"7862\":{\"m\":98,\"g\":97},\"7844\":{\"m\":98,\"g\":97},\"7997\":{\"m\":98,\"g\":97},\"7367\":{\"m\":98,\"g\":97},\"7749\":{\"m\":98,\"g\":97},\"7952\":{\"m\":98,\"g\":97},\"7988\":{\"m\":98,\"g\":97},\"7814\":{\"m\":98,\"g\":
97},\"7978\":{\"m\":98,\"g\":97},\"7985\":{\"m\":98,\"g\":97},\"7975\":{\"m\":98,\"g\":97},\"7972\":{\"m\":98,\"g\":97},\"7950\":{\"m\":98,\"g\":97},\"8305\":{\"m\":99,\"g\":103},\"8370\":{\"m\":99,\"g\":103},\"8333\":{\"m\":99,\"g\":103},\"8367\":{\"m\":99,\"g\":103},\"8363\":{\"m\":99,\"g\":103},\"8359\":{\"m\":99,\"g\":103},\"8332\":{\"m\":99,\"g\":103},\"8357\":{\"m\":99,\"g\":103},\"8344\":{\"m\":99,\"g\":103},\"7858\":{\"m\":99,\"g\":103},\"8353\":{\"m\":99,\"g\":103},\"8341\":{\"m\":99,\"g\":103},\"8000\":{\"m\":99,\"g\":103},\"7135\":{\"m\":99,\"g\":103},\"8266\":{\"m\":99,\"g\":103},\"8280\":{\"m\":99,\"g\":103},\"6619\":{\"m\":99,\"g\":103},\"8233\":{\"m\":99,\"g\":103},\"8334\":{\"m\":99,\"g\":103},\"8307\":{\"m\":99,\"g\":103},\"8300\":{\"m\":99,\"g\":103},\"8299\":{\"m\":99,\"g\":103},\"8301\":{\"m\":99,\"g\":103},\"8310\":{\"m\":99,\"g\":103},\"8298\":{\"m\":99,\"g\":103},\"8315\":{\"m\":99,\"g\":103},\"8303\":{\"m\":99,\"g\":103},\"8317\":{\"m\":99,\"g\":103},\"8235\":{\"m\":99,\"g\":103},\"8070\":{\"m\":99,\"g\":103},\"7562\":{\"m\":99,\"g\":103},\"8043\":{\"m\":99,\"g\":103},\"7685\":{\"m\":99,\"g\":103},\"8304\":{\"m\":99,\"g\":103},\"8240\":{\"m\":99,\"g\":103},\"8262\":{\"m\":99,\"g\":103},\"7708\":{\"m\":99,\"g\":103},\"8133\":{\"m\":99,\"g\":103},\"8302\":{\"m\":99,\"g\":103},\"8295\":{\"m\":99,\"g\":103},\"8130\":{\"m\":99,\"g\":103},\"8288\":{\"m\":99,\"g\":103},\"8282\":{\"m\":99,\"g\":103},\"8264\":{\"m\":99,\"g\":103},\"8261\":{\"m\":99,\"g\":103},\"8284\":{\"m\":99,\"g\":103},\"8272\":{\"m\":99,\"g\":103},\"8458\":{\"m\":100,\"g\":103},\"8457\":{\"m\":100,\"g\":103},\"8449\":{\"m\":100,\"g\":103},\"8456\":{\"m\":100,\"g\":103},\"8445\":{\"m\":100,\"g\":103},\"8441\":{\"m\":100,\"g\":103},\"8442\":{\"m\":100,\"g\":103},\"8224\":{\"m\":100,\"g\":103},\"8352\":{\"m\":100,\"g\":103},\"8416\":{\"m\":100,\"g\":103},\"6338\":{\"m\":100,\"g\":103},\"8415\":{\"m\":100,\"g\":103},\"8422\":{\"m\":100,\"g\":103},\"8425\":{\"m\":100,\"g\":103},\"8419\
":{\"m\":100,\"g\":103},\"8417\":{\"m\":100,\"g\":103},\"8316\":{\"m\":100,\"g\":103},\"7603\":{\"m\":100,\"g\":103},\"8213\":{\"m\":100,\"g\":103},\"8414\":{\"m\":100,\"g\":103},\"8062\":{\"m\":100,\"g\":103},\"8258\":{\"m\":100,\"g\":103},\"8406\":{\"m\":100,\"g\":103},\"8407\":{\"m\":100,\"g\":103},\"8405\":{\"m\":100,\"g\":103},\"8156\":{\"m\":100,\"g\":103},\"8241\":{\"m\":100,\"g\":103},\"8397\":{\"m\":100,\"g\":103},\"7720\":{\"m\":100,\"g\":103},\"8351\":{\"m\":100,\"g\":103},\"8395\":{\"m\":100,\"g\":103},\"8382\":{\"m\":100,\"g\":103},\"8392\":{\"m\":100,\"g\":103},\"8036\":{\"m\":100,\"g\":103},\"7739\":{\"m\":100,\"g\":103},\"7974\":{\"m\":100,\"g\":103},\"7976\":{\"m\":100,\"g\":103},\"8403\":{\"m\":100,\"g\":103},\"8401\":{\"m\":100,\"g\":103},\"8372\":{\"m\":100,\"g\":103},\"8394\":{\"m\":100,\"g\":103},\"8396\":{\"m\":100,\"g\":103},\"8314\":{\"m\":100,\"g\":103},\"8350\":{\"m\":100,\"g\":103},\"8381\":{\"m\":100,\"g\":103},\"6003\":{\"m\":100,\"g\":103},\"7737\":{\"m\":100,\"g\":103},\"8267\":{\"m\":100,\"g\":103},\"8343\":{\"m\":100,\"g\":103},\"7000\":{\"m\":100,\"g\":103},\"8356\":{\"m\":100,\"g\":103},\"8374\":{\"m\":100,\"g\":103},\"8517\":{\"m\":101,\"g\":103},\"8489\":{\"m\":101,\"g\":103},\"8482\":{\"m\":101,\"g\":103},\"7973\":{\"m\":101,\"g\":103},\"8426\":{\"m\":101,\"g\":103},\"8413\":{\"m\":101,\"g\":103},\"8486\":{\"m\":101,\"g\":103},\"8485\":{\"m\":101,\"g\":103},\"8477\":{\"m\":101,\"g\":103},\"8480\":{\"m\":101,\"g\":103},\"8478\":{\"m\":101,\"g\":103},\"8476\":{\"m\":101,\"g\":103},\"8469\":{\"m\":101,\"g\":103},\"8473\":{\"m\":101,\"g\":103},\"8421\":{\"m\":101,\"g\":103},\"7273\":{\"m\":101,\"g\":103},\"8453\":{\"m\":101,\"g\":103},\"8467\":{\"m\":101,\"g\":103},\"7565\":{\"m\":101,\"g\":103},\"8465\":{\"m\":101,\"g\":103},\"8125\":{\"m\":101,\"g\":103},\"8608\":{\"m\":102,\"g\":103},\"8590\":{\"m\":102,\"g\":103},\"8583\":{\"m\":102,\"g\":103},\"8604\":{\"m\":102,\"g\":103},\"8515\":{\"m\":102,\"g\":103},\"8603\":{\"m\":102,\"g
\":103},\"8550\":{\"m\":102,\"g\":103},\"7211\":{\"m\":102,\"g\":103},\"8533\":{\"m\":102,\"g\":103},\"8404\":{\"m\":102,\"g\":103},\"8599\":{\"m\":102,\"g\":103},\"8514\":{\"m\":102,\"g\":103},\"8365\":{\"m\":102,\"g\":103},\"8544\":{\"m\":102,\"g\":103},\"8541\":{\"m\":102,\"g\":103},\"8564\":{\"m\":102,\"g\":103},\"8479\":{\"m\":102,\"g\":103},\"7280\":{\"m\":102,\"g\":103},\"8584\":{\"m\":102,\"g\":103},\"8154\":{\"m\":102,\"g\":103},\"8461\":{\"m\":102,\"g\":103},\"6869\":{\"m\":102,\"g\":103},\"8562\":{\"m\":102,\"g\":103},\"8545\":{\"m\":102,\"g\":103},\"8560\":{\"m\":102,\"g\":103},\"8516\":{\"m\":102,\"g\":103},\"8498\":{\"m\":102,\"g\":103},\"8448\":{\"m\":102,\"g\":103},\"8431\":{\"m\":102,\"g\":103},\"8537\":{\"m\":102,\"g\":103},\"8483\":{\"m\":102,\"g\":103},\"8531\":{\"m\":102,\"g\":103},\"8535\":{\"m\":102,\"g\":103},\"8499\":{\"m\":102,\"g\":103},\"8528\":{\"m\":102,\"g\":103},\"8527\":{\"m\":102,\"g\":103},\"8652\":{\"m\":105,\"g\":107},\"8051\":{\"m\":105,\"g\":107},\"8318\":{\"m\":105,\"g\":107},\"8636\":{\"m\":105,\"g\":107},\"8450\":{\"m\":105,\"g\":107},\"8645\":{\"m\":105,\"g\":104},\"8640\":{\"m\":105,\"g\":104},\"8644\":{\"m\":105,\"g\":104},\"8308\":{\"m\":105,\"g\":104},\"8270\":{\"m\":105,\"g\":104},\"8083\":{\"m\":105,\"g\":104},\"8642\":{\"m\":105,\"g\":104},\"8532\":{\"m\":105,\"g\":104},\"8632\":{\"m\":105,\"g\":104},\"8598\":{\"m\":105,\"g\":104},\"8630\":{\"m\":105,\"g\":104},\"8634\":{\"m\":105,\"g\":104},\"8488\":{\"m\":105,\"g\":104},\"8633\":{\"m\":105,\"g\":104},\"6227\":{\"m\":105,\"g\":104},\"8577\":{\"m\":105,\"g\":104},\"8628\":{\"m\":105,\"g\":103},\"8629\":{\"m\":105,\"g\":103},\"8626\":{\"m\":105,\"g\":103},\"8623\":{\"m\":105,\"g\":103},\"8611\":{\"m\":105,\"g\":103},\"8595\":{\"m\":105,\"g\":103},\"8727\":{\"m\":106,\"g\":107},\"8723\":{\"m\":106,\"g\":107},\"8579\":{\"m\":106,\"g\":107},\"8567\":{\"m\":106,\"g\":107},\"8718\":{\"m\":106,\"g\":107},\"8444\":{\"m\":106,\"g\":107},\"8547\":{\"m\":106,\"g\":107},\"8683\"
:{\"m\":106,\"g\":107},\"8631\":{\"m\":106,\"g\":107},\"7379\":{\"m\":106,\"g\":107},\"8719\":{\"m\":106,\"g\":107},\"8306\":{\"m\":106,\"g\":107},\"8650\":{\"m\":106,\"g\":107},\"8721\":{\"m\":106,\"g\":107},\"8524\":{\"m\":106,\"g\":107},\"8709\":{\"m\":106,\"g\":107},\"8722\":{\"m\":106,\"g\":107},\"7369\":{\"m\":106,\"g\":107},\"8714\":{\"m\":106,\"g\":107},\"8705\":{\"m\":106,\"g\":107},\"8693\":{\"m\":106,\"g\":107},\"8717\":{\"m\":106,\"g\":107},\"8713\":{\"m\":106,\"g\":107},\"8701\":{\"m\":106,\"g\":107},\"8711\":{\"m\":106,\"g\":107},\"8706\":{\"m\":106,\"g\":107},\"8704\":{\"m\":106,\"g\":107},\"8691\":{\"m\":106,\"g\":107},\"8512\":{\"m\":106,\"g\":107},\"7434\":{\"m\":106,\"g\":107},\"8688\":{\"m\":106,\"g\":107},\"8694\":{\"m\":106,\"g\":107},\"8364\":{\"m\":106,\"g\":107},\"8238\":{\"m\":106,\"g\":107},\"8618\":{\"m\":106,\"g\":107},\"8522\":{\"m\":106,\"g\":107},\"8668\":{\"m\":106,\"g\":107},\"8648\":{\"m\":106,\"g\":107},\"8679\":{\"m\":106,\"g\":107},\"8686\":{\"m\":106,\"g\":107},\"8684\":{\"m\":106,\"g\":107},\"8685\":{\"m\":106,\"g\":107},\"8647\":{\"m\":106,\"g\":107},\"8094\":{\"m\":106,\"g\":107},\"8664\":{\"m\":106,\"g\":107},\"8543\":{\"m\":106,\"g\":107},\"8665\":{\"m\":106,\"g\":107},\"8013\":{\"m\":106,\"g\":107},\"8643\":{\"m\":106,\"g\":107},\"8658\":{\"m\":106,\"g\":107},\"8511\":{\"m\":106,\"g\":107},\"8635\":{\"m\":106,\"g\":107},\"8653\":{\"m\":106,\"g\":107},\"9533\":{\"m\":108,\"g\":115},\"9532\":{\"m\":108,\"g\":115},\"9372\":{\"m\":108,\"g\":115},\"9485\":{\"m\":108,\"g\":115},\"9478\":{\"m\":108,\"g\":115},\"8034\":{\"m\":108,\"g\":115},\"9473\":{\"m\":108,\"g\":115},\"9525\":{\"m\":108,\"g\":115},\"9530\":{\"m\":108,\"g\":115},\"8946\":{\"m\":108,\"g\":115},\"9004\":{\"m\":108,\"g\":115},\"9241\":{\"m\":108,\"g\":115},\"9211\":{\"m\":108,\"g\":115},\"9503\":{\"m\":108,\"g\":115},\"9519\":{\"m\":108,\"g\":115},\"9456\":{\"m\":108,\"g\":115},\"7699\":{\"m\":108,\"g\":115},\"9200\":{\"m\":108,\"g\":115},\"9516\":{\"m\":108,\"g\
":115},\"9127\":{\"m\":108,\"g\":115},\"9513\":{\"m\":108,\"g\":115},\"8624\":{\"m\":108,\"g\":115},\"8865\":{\"m\":108,\"g\":115},\"9109\":{\"m\":108,\"g\":115},\"9452\":{\"m\":108,\"g\":115},\"9507\":{\"m\":108,\"g\":115},\"9303\":{\"m\":108,\"g\":115},\"9331\":{\"m\":108,\"g\":115},\"9497\":{\"m\":108,\"g\":115},\"9494\":{\"m\":108,\"g\":115},\"9475\":{\"m\":108,\"g\":115},\"9480\":{\"m\":108,\"g\":115},\"9487\":{\"m\":108,\"g\":115},\"9491\":{\"m\":108,\"g\":115},\"9482\":{\"m\":108,\"g\":115},\"9492\":{\"m\":108,\"g\":115},\"9483\":{\"m\":108,\"g\":115},\"9333\":{\"m\":108,\"g\":115},\"9474\":{\"m\":108,\"g\":115},\"8616\":{\"m\":108,\"g\":115},\"9356\":{\"m\":108,\"g\":115},\"8593\":{\"m\":108,\"g\":115},\"9468\":{\"m\":108,\"g\":115},\"9467\":{\"m\":108,\"g\":115},\"9470\":{\"m\":108,\"g\":115},\"9469\":{\"m\":108,\"g\":115},\"9455\":{\"m\":108,\"g\":115},\"9463\":{\"m\":108,\"g\":115},\"9464\":{\"m\":108,\"g\":115},\"9462\":{\"m\":108,\"g\":115},\"9461\":{\"m\":108,\"g\":115},\"9458\":{\"m\":108,\"g\":115},\"9454\":{\"m\":108,\"g\":115},\"9427\":{\"m\":108,\"g\":115},\"9433\":{\"m\":108,\"g\":115},\"7604\":{\"m\":108,\"g\":115},\"8521\":{\"m\":108,\"g\":115},\"9392\":{\"m\":108,\"g\":115},\"9395\":{\"m\":108,\"g\":115},\"9238\":{\"m\":108,\"g\":115},\"9384\":{\"m\":108,\"g\":115},\"9430\":{\"m\":108,\"g\":115},\"9346\":{\"m\":108,\"g\":115},\"9399\":{\"m\":108,\"g\":115},\"9251\":{\"m\":108,\"g\":115},\"9388\":{\"m\":108,\"g\":115},\"9261\":{\"m\":108,\"g\":115},\"9420\":{\"m\":108,\"g\":115},\"9416\":{\"m\":108,\"g\":115},\"9415\":{\"m\":108,\"g\":115},\"9413\":{\"m\":108,\"g\":115},\"9371\":{\"m\":108,\"g\":115},\"9339\":{\"m\":108,\"g\":115},\"9357\":{\"m\":108,\"g\":115},\"9377\":{\"m\":108,\"g\":115},\"9381\":{\"m\":108,\"g\":115},\"9404\":{\"m\":108,\"g\":115},\"9359\":{\"m\":108,\"g\":115},\"9336\":{\"m\":108,\"g\":115},\"9249\":{\"m\":108,\"g\":115},\"9409\":{\"m\":108,\"g\":115},\"8690\":{\"m\":108,\"g\":115},\"9278\":{\"m\":108,\"g\":115},\"9391\":
{\"m\":108,\"g\":115},\"9106\":{\"m\":108,\"g\":115},\"7375\":{\"m\":108,\"g\":115},\"9385\":{\"m\":108,\"g\":115},\"9383\":{\"m\":108,\"g\":115},\"9378\":{\"m\":108,\"g\":115},\"9380\":{\"m\":108,\"g\":115},\"9376\":{\"m\":108,\"g\":115},\"9350\":{\"m\":108,\"g\":115},\"9368\":{\"m\":108,\"g\":115},\"9344\":{\"m\":108,\"g\":115},\"9369\":{\"m\":108,\"g\":115},\"9367\":{\"m\":108,\"g\":115},\"9370\":{\"m\":108,\"g\":115},\"9364\":{\"m\":108,\"g\":115},\"9360\":{\"m\":108,\"g\":115},\"9335\":{\"m\":108,\"g\":115},\"9361\":{\"m\":108,\"g\":115},\"9354\":{\"m\":108,\"g\":115},\"6295\":{\"m\":108,\"g\":115},\"9353\":{\"m\":108,\"g\":115},\"9348\":{\"m\":108,\"g\":115},\"9327\":{\"m\":108,\"g\":115},\"9332\":{\"m\":108,\"g\":115},\"9326\":{\"m\":108,\"g\":115},\"8990\":{\"m\":108,\"g\":115},\"7019\":{\"m\":108,\"g\":115},\"9321\":{\"m\":108,\"g\":115},\"9317\":{\"m\":108,\"g\":115},\"9299\":{\"m\":108,\"g\":115},\"9322\":{\"m\":108,\"g\":115},\"9306\":{\"m\":108,\"g\":115},\"9320\":{\"m\":108,\"g\":115},\"9059\":{\"m\":108,\"g\":115},\"8936\":{\"m\":108,\"g\":115},\"9284\":{\"m\":108,\"g\":115},\"9313\":{\"m\":108,\"g\":115},\"9316\":{\"m\":108,\"g\":115},\"9315\":{\"m\":108,\"g\":115},\"8829\":{\"m\":108,\"g\":115},\"9011\":{\"m\":108,\"g\":115},\"9289\":{\"m\":108,\"g\":115},\"9310\":{\"m\":108,\"g\":115},\"9307\":{\"m\":108,\"g\":115},\"9298\":{\"m\":108,\"g\":115},\"9276\":{\"m\":108,\"g\":115},\"9293\":{\"m\":108,\"g\":115},\"6307\":{\"m\":108,\"g\":115},\"9245\":{\"m\":108,\"g\":115},\"9287\":{\"m\":108,\"g\":115},\"9286\":{\"m\":108,\"g\":115},\"8289\":{\"m\":108,\"g\":115},\"9281\":{\"m\":108,\"g\":115},\"8520\":{\"m\":108,\"g\":115},\"9272\":{\"m\":108,\"g\":115},\"9271\":{\"m\":108,\"g\":115},\"9279\":{\"m\":108,\"g\":115},\"9131\":{\"m\":108,\"g\":115},\"9242\":{\"m\":108,\"g\":115},\"9260\":{\"m\":108,\"g\":115},\"9268\":{\"m\":108,\"g\":115},\"9067\":{\"m\":108,\"g\":115},\"9264\":{\"m\":108,\"g\":115},\"9237\":{\"m\":108,\"g\":115},\"9232\":{\"m\":108,\"g\"
:115},\"9006\":{\"m\":108,\"g\":115},\"8893\":{\"m\":108,\"g\":115},\"9049\":{\"m\":108,\"g\":115},\"8846\":{\"m\":108,\"g\":115},\"9252\":{\"m\":108,\"g\":115},\"8027\":{\"m\":108,\"g\":115},\"9258\":{\"m\":108,\"g\":115},\"7758\":{\"m\":108,\"g\":115},\"9165\":{\"m\":108,\"g\":115},\"7667\":{\"m\":108,\"g\":115},\"9247\":{\"m\":108,\"g\":115},\"9246\":{\"m\":108,\"g\":115},\"8663\":{\"m\":108,\"g\":115},\"8268\":{\"m\":108,\"g\":115},\"9243\":{\"m\":108,\"g\":115},\"9236\":{\"m\":108,\"g\":115},\"9201\":{\"m\":108,\"g\":115},\"9198\":{\"m\":108,\"g\":115},\"9231\":{\"m\":108,\"g\":115},\"9220\":{\"m\":108,\"g\":115},\"9223\":{\"m\":108,\"g\":115},\"9222\":{\"m\":108,\"g\":115},\"8777\":{\"m\":108,\"g\":115},\"8790\":{\"m\":108,\"g\":115},\"9215\":{\"m\":108,\"g\":115},\"9218\":{\"m\":108,\"g\":115},\"9214\":{\"m\":108,\"g\":115},\"9208\":{\"m\":108,\"g\":115},\"9213\":{\"m\":108,\"g\":115},\"9207\":{\"m\":108,\"g\":115},\"9177\":{\"m\":108,\"g\":115},\"8849\":{\"m\":108,\"g\":115},\"9206\":{\"m\":108,\"g\":115},\"9205\":{\"m\":108,\"g\":115},\"9183\":{\"m\":108,\"g\":115},\"9204\":{\"m\":108,\"g\":115},\"9203\":{\"m\":108,\"g\":115},\"9202\":{\"m\":108,\"g\":115},\"9197\":{\"m\":108,\"g\":115},\"9008\":{\"m\":108,\"g\":115},\"8795\":{\"m\":108,\"g\":115},\"9191\":{\"m\":108,\"g\":115},\"9194\":{\"m\":108,\"g\":115},\"9060\":{\"m\":108,\"g\":115},\"8913\":{\"m\":108,\"g\":115},\"9185\":{\"m\":108,\"g\":115},\"8112\":{\"m\":108,\"g\":115},\"9065\":{\"m\":108,\"g\":115},\"8018\":{\"m\":108,\"g\":115},\"7687\":{\"m\":108,\"g\":115},\"7631\":{\"m\":108,\"g\":115},\"7004\":{\"m\":108,\"g\":115},\"8852\":{\"m\":108,\"g\":115},\"8808\":{\"m\":108,\"g\":115},\"8818\":{\"m\":108,\"g\":115},\"9154\":{\"m\":108,\"g\":115},\"9101\":{\"m\":108,\"g\":115},\"9162\":{\"m\":108,\"g\":115},\"9136\":{\"m\":108,\"g\":115},\"9171\":{\"m\":108,\"g\":115},\"9169\":{\"m\":108,\"g\":115},\"9159\":{\"m\":108,\"g\":115},\"8951\":{\"m\":108,\"g\":115},\"9161\":{\"m\":108,\"g\":115},\"8840\":{
\"m\":108,\"g\":115},\"9134\":{\"m\":108,\"g\":115},\"9042\":{\"m\":108,\"g\":115},\"7957\":{\"m\":108,\"g\":115},\"9069\":{\"m\":108,\"g\":115},\"9028\":{\"m\":108,\"g\":115},\"8910\":{\"m\":108,\"g\":115},\"9149\":{\"m\":108,\"g\":115},\"9133\":{\"m\":108,\"g\":115},\"9126\":{\"m\":108,\"g\":115},\"9150\":{\"m\":108,\"g\":115},\"8484\":{\"m\":108,\"g\":115},\"9111\":{\"m\":108,\"g\":115},\"9146\":{\"m\":108,\"g\":115},\"9093\":{\"m\":108,\"g\":115},\"9088\":{\"m\":108,\"g\":115},\"8588\":{\"m\":108,\"g\":115},\"9137\":{\"m\":108,\"g\":115},\"8884\":{\"m\":108,\"g\":115},\"8651\":{\"m\":108,\"g\":115},\"9130\":{\"m\":108,\"g\":115},\"9129\":{\"m\":108,\"g\":115},\"8660\":{\"m\":108,\"g\":115},\"9119\":{\"m\":108,\"g\":115},\"8619\":{\"m\":108,\"g\":115},\"8610\":{\"m\":108,\"g\":115},\"8700\":{\"m\":108,\"g\":115},\"9125\":{\"m\":108,\"g\":115},\"9121\":{\"m\":108,\"g\":115},\"9014\":{\"m\":108,\"g\":115},\"9118\":{\"m\":108,\"g\":115},\"9122\":{\"m\":108,\"g\":115},\"9107\":{\"m\":108,\"g\":115},\"9113\":{\"m\":108,\"g\":115},\"9114\":{\"m\":108,\"g\":115},\"9021\":{\"m\":108,\"g\":115},\"9103\":{\"m\":108,\"g\":115},\"9077\":{\"m\":108,\"g\":115},\"9096\":{\"m\":108,\"g\":115},\"9005\":{\"m\":108,\"g\":115},\"9075\":{\"m\":108,\"g\":115},\"9097\":{\"m\":108,\"g\":115},\"9032\":{\"m\":108,\"g\":115},\"8766\":{\"m\":108,\"g\":115},\"9087\":{\"m\":108,\"g\":115},\"8293\":{\"m\":108,\"g\":115},\"9095\":{\"m\":108,\"g\":115},\"9084\":{\"m\":108,\"g\":115},\"8992\":{\"m\":108,\"g\":115},\"9089\":{\"m\":108,\"g\":115},\"9043\":{\"m\":108,\"g\":115},\"9086\":{\"m\":108,\"g\":115},\"9030\":{\"m\":108,\"g\":115},\"9053\":{\"m\":108,\"g\":115},\"8638\":{\"m\":108,\"g\":115},\"8731\":{\"m\":108,\"g\":115},\"9083\":{\"m\":108,\"g\":115},\"8752\":{\"m\":108,\"g\":115},\"9081\":{\"m\":108,\"g\":115},\"9082\":{\"m\":108,\"g\":115},\"8866\":{\"m\":108,\"g\":115},\"9080\":{\"m\":108,\"g\":115},\"8973\":{\"m\":108,\"g\":115},\"9063\":{\"m\":108,\"g\":115},\"9079\":{\"m\":108,\"g\":
115},\"9066\":{\"m\":108,\"g\":115},\"7216\":{\"m\":108,\"g\":115},\"9051\":{\"m\":108,\"g\":115},\"9047\":{\"m\":108,\"g\":115},\"9057\":{\"m\":108,\"g\":115},\"9050\":{\"m\":108,\"g\":115},\"9054\":{\"m\":108,\"g\":115},\"9048\":{\"m\":108,\"g\":115},\"8997\":{\"m\":108,\"g\":115},\"9046\":{\"m\":108,\"g\":115},\"9044\":{\"m\":108,\"g\":115},\"9031\":{\"m\":108,\"g\":115},\"9037\":{\"m\":108,\"g\":115},\"9036\":{\"m\":108,\"g\":115},\"9034\":{\"m\":108,\"g\":115},\"9035\":{\"m\":108,\"g\":115},\"9033\":{\"m\":108,\"g\":115},\"8079\":{\"m\":108,\"g\":115},\"8794\":{\"m\":108,\"g\":115},\"9024\":{\"m\":108,\"g\":115},\"9027\":{\"m\":108,\"g\":115},\"9029\":{\"m\":108,\"g\":115},\"7626\":{\"m\":108,\"g\":115},\"9022\":{\"m\":108,\"g\":115},\"8940\":{\"m\":108,\"g\":115},\"8996\":{\"m\":108,\"g\":115},\"9018\":{\"m\":108,\"g\":115},\"9017\":{\"m\":108,\"g\":115},\"9019\":{\"m\":108,\"g\":115},\"8340\":{\"m\":108,\"g\":115},\"8991\":{\"m\":108,\"g\":115},\"8915\":{\"m\":108,\"g\":115},\"8245\":{\"m\":108,\"g\":115},\"9013\":{\"m\":108,\"g\":115},\"9007\":{\"m\":108,\"g\":115},\"9012\":{\"m\":108,\"g\":115},\"9003\":{\"m\":108,\"g\":115},\"9010\":{\"m\":108,\"g\":115},\"9001\":{\"m\":108,\"g\":115},\"8329\":{\"m\":108,\"g\":115},\"8355\":{\"m\":108,\"g\":115},\"8877\":{\"m\":108,\"g\":115},\"8995\":{\"m\":108,\"g\":115},\"8998\":{\"m\":108,\"g\":115},\"8878\":{\"m\":108,\"g\":115},\"6752\":{\"m\":108,\"g\":115},\"8673\":{\"m\":108,\"g\":115},\"8798\":{\"m\":108,\"g\":115},\"8687\":{\"m\":108,\"g\":115},\"8600\":{\"m\":108,\"g\":115},\"8966\":{\"m\":108,\"g\":115},\"8851\":{\"m\":108,\"g\":115},\"8984\":{\"m\":108,\"g\":115},\"8962\":{\"m\":108,\"g\":115},\"8987\":{\"m\":108,\"g\":115},\"8994\":{\"m\":108,\"g\":115},\"8993\":{\"m\":108,\"g\":115},\"8989\":{\"m\":108,\"g\":115},\"8667\":{\"m\":108,\"g\":115},\"8983\":{\"m\":108,\"g\":115},\"8980\":{\"m\":108,\"g\":115},\"8330\":{\"m\":108,\"g\":115},\"8770\":{\"m\":108,\"g\":115},\"8724\":{\"m\":108,\"g\":115},\"8988\":{\
"m\":108,\"g\":115},\"8986\":{\"m\":108,\"g\":115},\"8785\":{\"m\":108,\"g\":115},\"8371\":{\"m\":108,\"g\":115},\"8982\":{\"m\":108,\"g\":115},\"8978\":{\"m\":108,\"g\":115},\"8981\":{\"m\":108,\"g\":115},\"8772\":{\"m\":108,\"g\":115},\"8971\":{\"m\":108,\"g\":115},\"8968\":{\"m\":108,\"g\":115},\"8972\":{\"m\":108,\"g\":115},\"7279\":{\"m\":108,\"g\":115},\"8941\":{\"m\":108,\"g\":115},\"8959\":{\"m\":108,\"g\":115},\"8757\":{\"m\":108,\"g\":115},\"8960\":{\"m\":108,\"g\":115},\"8958\":{\"m\":108,\"g\":115},\"8692\":{\"m\":108,\"g\":115},\"6555\":{\"m\":108,\"g\":115},\"8894\":{\"m\":108,\"g\":115},\"8957\":{\"m\":108,\"g\":115},\"8955\":{\"m\":108,\"g\":115},\"7657\":{\"m\":108,\"g\":115},\"8944\":{\"m\":108,\"g\":115},\"8799\":{\"m\":108,\"g\":115},\"8932\":{\"m\":108,\"g\":115},\"8952\":{\"m\":108,\"g\":115},\"8953\":{\"m\":108,\"g\":115},\"8947\":{\"m\":108,\"g\":115},\"8950\":{\"m\":108,\"g\":115},\"8720\":{\"m\":108,\"g\":115},\"8703\":{\"m\":108,\"g\":115},\"8923\":{\"m\":108,\"g\":115},\"8850\":{\"m\":108,\"g\":115},\"8933\":{\"m\":108,\"g\":115},\"8937\":{\"m\":108,\"g\":115},\"8929\":{\"m\":108,\"g\":115},\"8928\":{\"m\":108,\"g\":115},\"8925\":{\"m\":108,\"g\":115},\"8927\":{\"m\":108,\"g\":115},\"8908\":{\"m\":108,\"g\":115},\"8844\":{\"m\":108,\"g\":107},\"8916\":{\"m\":108,\"g\":107},\"8912\":{\"m\":108,\"g\":107},\"8698\":{\"m\":108,\"g\":107},\"8869\":{\"m\":108,\"g\":107},\"8898\":{\"m\":108,\"g\":107},\"8895\":{\"m\":108,\"g\":107},\"8041\":{\"m\":108,\"g\":107},\"5949\":{\"m\":108,\"g\":107},\"8888\":{\"m\":108,\"g\":107},\"8292\":{\"m\":108,\"g\":107},\"8787\":{\"m\":108,\"g\":107},\"8369\":{\"m\":108,\"g\":107},\"8697\":{\"m\":108,\"g\":107},\"8834\":{\"m\":108,\"g\":107},\"8883\":{\"m\":108,\"g\":107},\"8847\":{\"m\":108,\"g\":107},\"8539\":{\"m\":108,\"g\":107},\"8837\":{\"m\":108,\"g\":107},\"8880\":{\"m\":108,\"g\":107},\"8881\":{\"m\":108,\"g\":107},\"8811\":{\"m\":108,\"g\":107},\"8872\":{\"m\":108,\"g\":107},\"8861\":{\"m\":108,\"g\":1
07},\"8860\":{\"m\":108,\"g\":107},\"8868\":{\"m\":108,\"g\":107},\"8815\":{\"m\":108,\"g\":107},\"8859\":{\"m\":108,\"g\":107},\"8853\":{\"m\":108,\"g\":107},\"8843\":{\"m\":108,\"g\":107},\"8753\":{\"m\":108,\"g\":107},\"8751\":{\"m\":108,\"g\":107},\"8838\":{\"m\":108,\"g\":107},\"8144\":{\"m\":108,\"g\":107},\"8680\":{\"m\":108,\"g\":107},\"8839\":{\"m\":108,\"g\":107},\"8828\":{\"m\":108,\"g\":107},\"8836\":{\"m\":108,\"g\":107},\"8824\":{\"m\":108,\"g\":107},\"8809\":{\"m\":108,\"g\":107},\"8832\":{\"m\":108,\"g\":107},\"8681\":{\"m\":108,\"g\":107},\"8827\":{\"m\":108,\"g\":107},\"8823\":{\"m\":108,\"g\":107},\"8817\":{\"m\":108,\"g\":107},\"8804\":{\"m\":108,\"g\":107},\"8782\":{\"m\":108,\"g\":107},\"8802\":{\"m\":108,\"g\":107},\"8800\":{\"m\":108,\"g\":107},\"8797\":{\"m\":108,\"g\":107},\"8596\":{\"m\":108,\"g\":107},\"8779\":{\"m\":108,\"g\":107},\"8780\":{\"m\":108,\"g\":107},\"8744\":{\"m\":108,\"g\":107},\"8571\":{\"m\":108,\"g\":107},\"8212\":{\"m\":108,\"g\":107},\"8255\":{\"m\":108,\"g\":107},\"8776\":{\"m\":108,\"g\":107},\"8773\":{\"m\":108,\"g\":107},\"8771\":{\"m\":108,\"g\":107},\"8762\":{\"m\":108,\"g\":107},\"8768\":{\"m\":108,\"g\":107},\"8639\":{\"m\":108,\"g\":107},\"8552\":{\"m\":108,\"g\":107},\"8749\":{\"m\":108,\"g\":107},\"8294\":{\"m\":108,\"g\":107},\"8738\":{\"m\":108,\"g\":107},\"8437\":{\"m\":108,\"g\":107},\"8745\":{\"m\":108,\"g\":107},\"8733\":{\"m\":108,\"g\":107},\"8735\":{\"m\":108,\"g\":107},\"8737\":{\"m\":108,\"g\":107},\"8662\":{\"m\":108,\"g\":107},\"7114\":{\"m\":108,\"g\":107},\"8678\":{\"m\":108,\"g\":107},\"8732\":{\"m\":108,\"g\":107},\"8729\":{\"m\":108,\"g\":107},\"8676\":{\"m\":108,\"g\":107},\"8699\":{\"m\":108,\"g\":107},\"9558\":{\"m\":109,\"g\":115},\"9557\":{\"m\":109,\"g\":115},\"9549\":{\"m\":109,\"g\":115},\"9544\":{\"m\":109,\"g\":115},\"9547\":{\"m\":109,\"g\":115},\"9546\":{\"m\":109,\"g\":115},\"9592\":{\"m\":110,\"g\":115},\"9591\":{\"m\":110,\"g\":115},\"9589\":{\"m\":110,\"g\":115},\"9587\":{\"
m\":110,\"g\":115},\"9581\":{\"m\":110,\"g\":115},\"9578\":{\"m\":110,\"g\":115},\"9229\":{\"m\":110,\"g\":115},\"9536\":{\"m\":110,\"g\":115},\"9535\":{\"m\":110,\"g\":115},\"9559\":{\"m\":110,\"g\":115},\"9560\":{\"m\":110,\"g\":115},\"7317\":{\"m\":110,\"g\":115},\"9576\":{\"m\":110,\"g\":115},\"9429\":{\"m\":110,\"g\":115},\"9565\":{\"m\":110,\"g\":115},\"9498\":{\"m\":110,\"g\":115},\"9716\":{\"m\":111,\"g\":115},\"9708\":{\"m\":111,\"g\":115},\"9340\":{\"m\":111,\"g\":115},\"9703\":{\"m\":111,\"g\":115},\"9702\":{\"m\":111,\"g\":115},\"9695\":{\"m\":111,\"g\":115},\"9683\":{\"m\":111,\"g\":115},\"9700\":{\"m\":111,\"g\":115},\"9676\":{\"m\":111,\"g\":115},\"9694\":{\"m\":111,\"g\":115},\"9693\":{\"m\":111,\"g\":115},\"9679\":{\"m\":111,\"g\":115},\"9678\":{\"m\":111,\"g\":115},\"9397\":{\"m\":111,\"g\":115},\"9495\":{\"m\":111,\"g\":115},\"9677\":{\"m\":111,\"g\":115},\"9446\":{\"m\":111,\"g\":115},\"9071\":{\"m\":111,\"g\":115},\"9597\":{\"m\":111,\"g\":115},\"9555\":{\"m\":111,\"g\":115},\"9583\":{\"m\":111,\"g\":115},\"9564\":{\"m\":111,\"g\":115},\"9658\":{\"m\":111,\"g\":115},\"9665\":{\"m\":111,\"g\":115},\"9648\":{\"m\":111,\"g\":115},\"9637\":{\"m\":111,\"g\":115},\"9649\":{\"m\":111,\"g\":115},\"9647\":{\"m\":111,\"g\":115},\"9656\":{\"m\":111,\"g\":115},\"9523\":{\"m\":111,\"g\":115},\"9606\":{\"m\":111,\"g\":115},\"9635\":{\"m\":111,\"g\":115},\"9630\":{\"m\":111,\"g\":115},\"9640\":{\"m\":111,\"g\":115},\"9636\":{\"m\":111,\"g\":115},\"9301\":{\"m\":111,\"g\":115},\"8328\":{\"m\":111,\"g\":115},\"9632\":{\"m\":111,\"g\":115},\"9629\":{\"m\":111,\"g\":115},\"9628\":{\"m\":111,\"g\":115},\"9623\":{\"m\":111,\"g\":115},\"9622\":{\"m\":111,\"g\":115},\"9608\":{\"m\":111,\"g\":115},\"8901\":{\"m\":111,\"g\":115},\"9613\":{\"m\":111,\"g\":115},\"9190\":{\"m\":111,\"g\":115},\"9554\":{\"m\":111,\"g\":115},\"9500\":{\"m\":111,\"g\":115},\"9436\":{\"m\":111,\"g\":115},\"9568\":{\"m\":111,\"g\":115},\"10221\":{\"m\":112,\"g\":115},\"10340\":{\"m\":112,\"g\":
115},\"10303\":{\"m\":112,\"g\":115},\"10331\":{\"m\":112,\"g\":115},\"10339\":{\"m\":112,\"g\":115},\"10330\":{\"m\":112,\"g\":115},\"10327\":{\"m\":112,\"g\":115},\"10338\":{\"m\":112,\"g\":115},\"10254\":{\"m\":112,\"g\":115},\"10264\":{\"m\":112,\"g\":115},\"10280\":{\"m\":112,\"g\":115},\"10335\":{\"m\":112,\"g\":115},\"10322\":{\"m\":112,\"g\":115},\"10328\":{\"m\":112,\"g\":115},\"10326\":{\"m\":112,\"g\":115},\"10233\":{\"m\":112,\"g\":115},\"10297\":{\"m\":112,\"g\":115},\"10314\":{\"m\":112,\"g\":115},\"10311\":{\"m\":112,\"g\":115},\"10310\":{\"m\":112,\"g\":115},\"10299\":{\"m\":112,\"g\":115},\"9090\":{\"m\":112,\"g\":115},\"10229\":{\"m\":112,\"g\":115},\"10239\":{\"m\":112,\"g\":115},\"9881\":{\"m\":112,\"g\":115},\"10294\":{\"m\":112,\"g\":115},\"10292\":{\"m\":112,\"g\":115},\"10282\":{\"m\":112,\"g\":115},\"10184\":{\"m\":112,\"g\":115},\"10241\":{\"m\":112,\"g\":115},\"9662\":{\"m\":112,\"g\":115},\"10252\":{\"m\":112,\"g\":115},\"9940\":{\"m\":112,\"g\":115},\"10251\":{\"m\":112,\"g\":115},\"10256\":{\"m\":112,\"g\":115},\"10250\":{\"m\":112,\"g\":115},\"10262\":{\"m\":112,\"g\":115},\"9954\":{\"m\":112,\"g\":115},\"10173\":{\"m\":112,\"g\":115},\"10060\":{\"m\":112,\"g\":115},\"10253\":{\"m\":112,\"g\":115},\"10093\":{\"m\":112,\"g\":115},\"10240\":{\"m\":112,\"g\":115},\"8803\":{\"m\":112,\"g\":115},\"10246\":{\"m\":112,\"g\":115},\"9795\":{\"m\":112,\"g\":115},\"10245\":{\"m\":112,\"g\":115},\"10242\":{\"m\":112,\"g\":115},\"10236\":{\"m\":112,\"g\":115},\"10234\":{\"m\":112,\"g\":115},\"10213\":{\"m\":112,\"g\":115},\"10238\":{\"m\":112,\"g\":115},\"10210\":{\"m\":112,\"g\":115},\"10220\":{\"m\":112,\"g\":115},\"10214\":{\"m\":112,\"g\":115},\"9960\":{\"m\":112,\"g\":115},\"10208\":{\"m\":112,\"g\":115},\"10212\":{\"m\":112,\"g\":115},\"10209\":{\"m\":112,\"g\":115},\"10207\":{\"m\":112,\"g\":115},\"10205\":{\"m\":112,\"g\":115},\"9300\":{\"m\":112,\"g\":115},\"10193\":{\"m\":112,\"g\":115},\"10127\":{\"m\":112,\"g\":115},\"10188\":{\"m\":112
,\"g\":115},\"10165\":{\"m\":112,\"g\":115},\"10169\":{\"m\":112,\"g\":115},\"4422\":{\"m\":112,\"g\":115},\"9900\":{\"m\":112,\"g\":115},\"10191\":{\"m\":112,\"g\":115},\"7995\":{\"m\":112,\"g\":115},\"10185\":{\"m\":112,\"g\":115},\"10149\":{\"m\":112,\"g\":115},\"9522\":{\"m\":112,\"g\":115},\"10182\":{\"m\":112,\"g\":115},\"10181\":{\"m\":112,\"g\":115},\"9839\":{\"m\":112,\"g\":115},\"10176\":{\"m\":112,\"g\":115},\"9595\":{\"m\":112,\"g\":115},\"10166\":{\"m\":112,\"g\":115},\"9925\":{\"m\":112,\"g\":115},\"10156\":{\"m\":112,\"g\":115},\"10161\":{\"m\":112,\"g\":115},\"10159\":{\"m\":112,\"g\":115},\"10131\":{\"m\":112,\"g\":115},\"10028\":{\"m\":112,\"g\":115},\"10155\":{\"m\":112,\"g\":115},\"9871\":{\"m\":112,\"g\":115},\"9434\":{\"m\":112,\"g\":115},\"10148\":{\"m\":112,\"g\":115},\"6226\":{\"m\":112,\"g\":115},\"10013\":{\"m\":112,\"g\":115},\"10147\":{\"m\":112,\"g\":115},\"9981\":{\"m\":112,\"g\":115},\"9989\":{\"m\":112,\"g\":115},\"10108\":{\"m\":112,\"g\":115},\"10123\":{\"m\":112,\"g\":115},\"7843\":{\"m\":112,\"g\":115},\"10090\":{\"m\":112,\"g\":115},\"10104\":{\"m\":112,\"g\":115},\"8801\":{\"m\":112,\"g\":115},\"10040\":{\"m\":112,\"g\":115},\"10141\":{\"m\":112,\"g\":115},\"10095\":{\"m\":112,\"g\":115},\"10144\":{\"m\":112,\"g\":115},\"9971\":{\"m\":112,\"g\":115},\"10134\":{\"m\":112,\"g\":115},\"10135\":{\"m\":112,\"g\":115},\"10074\":{\"m\":112,\"g\":115},\"10128\":{\"m\":112,\"g\":115},\"10126\":{\"m\":112,\"g\":115},\"10113\":{\"m\":112,\"g\":115},\"10096\":{\"m\":112,\"g\":115},\"10056\":{\"m\":112,\"g\":115},\"10101\":{\"m\":112,\"g\":115},\"9969\":{\"m\":112,\"g\":115},\"9741\":{\"m\":112,\"g\":115},\"9477\":{\"m\":112,\"g\":115},\"10117\":{\"m\":112,\"g\":115},\"10102\":{\"m\":112,\"g\":115},\"10116\":{\"m\":112,\"g\":115},\"10068\":{\"m\":112,\"g\":115},\"9956\":{\"m\":112,\"g\":115},\"10041\":{\"m\":112,\"g\":115},\"10058\":{\"m\":112,\"g\":115},\"10107\":{\"m\":112,\"g\":115},\"10100\":{\"m\":112,\"g\":115},\"9834\":{\"m\":112,\"g
\":115},\"9861\":{\"m\":112,\"g\":115},\"9764\":{\"m\":112,\"g\":115},\"9269\":{\"m\":112,\"g\":115},\"9620\":{\"m\":112,\"g\":115},\"6905\":{\"m\":112,\"g\":115},\"10032\":{\"m\":112,\"g\":115},\"10097\":{\"m\":112,\"g\":115},\"10029\":{\"m\":112,\"g\":115},\"10039\":{\"m\":112,\"g\":115},\"10092\":{\"m\":112,\"g\":115},\"10086\":{\"m\":112,\"g\":115},\"10057\":{\"m\":112,\"g\":115},\"10047\":{\"m\":112,\"g\":115},\"9842\":{\"m\":112,\"g\":115},\"9965\":{\"m\":112,\"g\":115},\"10069\":{\"m\":112,\"g\":115},\"9884\":{\"m\":112,\"g\":115},\"10087\":{\"m\":112,\"g\":115},\"8622\":{\"m\":112,\"g\":115},\"8555\":{\"m\":112,\"g\":115},\"10080\":{\"m\":112,\"g\":115},\"10079\":{\"m\":112,\"g\":115},\"10043\":{\"m\":112,\"g\":115},\"10007\":{\"m\":112,\"g\":115},\"7182\":{\"m\":112,\"g\":115},\"8725\":{\"m\":112,\"g\":115},\"8867\":{\"m\":112,\"g\":115},\"9567\":{\"m\":112,\"g\":115},\"5255\":{\"m\":112,\"g\":115},\"10006\":{\"m\":112,\"g\":115},\"9534\":{\"m\":112,\"g\":115},\"9934\":{\"m\":112,\"g\":115},\"9931\":{\"m\":112,\"g\":115},\"9801\":{\"m\":112,\"g\":115},\"10049\":{\"m\":112,\"g\":115},\"10055\":{\"m\":112,\"g\":115},\"9964\":{\"m\":112,\"g\":115},\"10052\":{\"m\":112,\"g\":115},\"10050\":{\"m\":112,\"g\":115},\"8677\":{\"m\":112,\"g\":115},\"10008\":{\"m\":112,\"g\":115},\"9951\":{\"m\":112,\"g\":115},\"9957\":{\"m\":112,\"g\":115},\"9938\":{\"m\":112,\"g\":115},\"9634\":{\"m\":112,\"g\":115},\"9886\":{\"m\":112,\"g\":115},\"9973\":{\"m\":112,\"g\":115},\"9846\":{\"m\":112,\"g\":115},\"10003\":{\"m\":112,\"g\":115},\"10004\":{\"m\":112,\"g\":115},\"9997\":{\"m\":112,\"g\":115},\"10016\":{\"m\":112,\"g\":115},\"9993\":{\"m\":112,\"g\":115},\"9994\":{\"m\":112,\"g\":115},\"9999\":{\"m\":112,\"g\":115},\"10000\":{\"m\":112,\"g\":115},\"9996\":{\"m\":112,\"g\":115},\"9988\":{\"m\":112,\"g\":115},\"9986\":{\"m\":112,\"g\":115},\"9733\":{\"m\":112,\"g\":115},\"9314\":{\"m\":112,\"g\":115},\"9978\":{\"m\":112,\"g\":115},\"9914\":{\"m\":112,\"g\":115},\"9460\":{\"m\"
:112,\"g\":115},\"9958\":{\"m\":112,\"g\":115},\"9906\":{\"m\":112,\"g\":115},\"9953\":{\"m\":112,\"g\":115},\"9937\":{\"m\":112,\"g\":115},\"9959\":{\"m\":112,\"g\":115},\"9955\":{\"m\":112,\"g\":115},\"9755\":{\"m\":112,\"g\":115},\"9952\":{\"m\":112,\"g\":115},\"9671\":{\"m\":112,\"g\":115},\"7912\":{\"m\":112,\"g\":115},\"9895\":{\"m\":112,\"g\":115},\"9905\":{\"m\":112,\"g\":115},\"9927\":{\"m\":112,\"g\":115},\"9946\":{\"m\":112,\"g\":115},\"9912\":{\"m\":112,\"g\":115},\"9869\":{\"m\":112,\"g\":115},\"8747\":{\"m\":112,\"g\":115},\"9939\":{\"m\":112,\"g\":115},\"9929\":{\"m\":112,\"g\":115},\"9932\":{\"m\":112,\"g\":115},\"9705\":{\"m\":112,\"g\":115},\"9909\":{\"m\":112,\"g\":115},\"9920\":{\"m\":112,\"g\":115},\"9921\":{\"m\":112,\"g\":115},\"9879\":{\"m\":112,\"g\":115},\"9919\":{\"m\":112,\"g\":115},\"9916\":{\"m\":112,\"g\":115},\"9907\":{\"m\":112,\"g\":115},\"9844\":{\"m\":112,\"g\":115},\"9913\":{\"m\":112,\"g\":115},\"9902\":{\"m\":112,\"g\":115},\"8118\":{\"m\":112,\"g\":115},\"9875\":{\"m\":112,\"g\":115},\"9893\":{\"m\":112,\"g\":115},\"9878\":{\"m\":112,\"g\":115},\"9803\":{\"m\":112,\"g\":115},\"9876\":{\"m\":112,\"g\":115},\"9874\":{\"m\":112,\"g\":115},\"9783\":{\"m\":112,\"g\":115},\"9882\":{\"m\":112,\"g\":115},\"9857\":{\"m\":112,\"g\":115},\"9862\":{\"m\":112,\"g\":115},\"9864\":{\"m\":112,\"g\":115},\"8964\":{\"m\":112,\"g\":115},\"9794\":{\"m\":112,\"g\":115},\"9858\":{\"m\":112,\"g\":115},\"9852\":{\"m\":112,\"g\":115},\"9847\":{\"m\":112,\"g\":115},\"9073\":{\"m\":112,\"g\":115},\"9850\":{\"m\":112,\"g\":115},\"9797\":{\"m\":112,\"g\":115},\"9661\":{\"m\":112,\"g\":115},\"9841\":{\"m\":112,\"g\":115},\"9750\":{\"m\":112,\"g\":115},\"9709\":{\"m\":112,\"g\":115},\"8909\":{\"m\":112,\"g\":115},\"9840\":{\"m\":112,\"g\":115},\"9824\":{\"m\":112,\"g\":115},\"9835\":{\"m\":112,\"g\":115},\"9837\":{\"m\":112,\"g\":115},\"9836\":{\"m\":112,\"g\":115},\"9831\":{\"m\":112,\"g\":115},\"9830\":{\"m\":112,\"g\":115},\"9761\":{\"m\":112,\"g\":115},
\"9828\":{\"m\":112,\"g\":115},\"9827\":{\"m\":112,\"g\":115},\"9826\":{\"m\":112,\"g\":115},\"9822\":{\"m\":112,\"g\":115},\"9802\":{\"m\":112,\"g\":115},\"9820\":{\"m\":112,\"g\":115},\"9746\":{\"m\":112,\"g\":115},\"9817\":{\"m\":112,\"g\":115},\"9815\":{\"m\":112,\"g\":115},\"9807\":{\"m\":112,\"g\":115},\"8345\":{\"m\":112,\"g\":115},\"9809\":{\"m\":112,\"g\":115},\"9670\":{\"m\":112,\"g\":115},\"9556\":{\"m\":112,\"g\":115},\"9675\":{\"m\":112,\"g\":115},\"9712\":{\"m\":112,\"g\":115},\"9793\":{\"m\":112,\"g\":115},\"9216\":{\"m\":112,\"g\":115},\"8375\":{\"m\":112,\"g\":115},\"9663\":{\"m\":112,\"g\":115},\"9715\":{\"m\":112,\"g\":115},\"9692\":{\"m\":112,\"g\":115},\"9776\":{\"m\":112,\"g\":115},\"9792\":{\"m\":112,\"g\":115},\"9786\":{\"m\":112,\"g\":115},\"9789\":{\"m\":112,\"g\":115},\"9788\":{\"m\":112,\"g\":115},\"9757\":{\"m\":112,\"g\":115},\"9784\":{\"m\":112,\"g\":115},\"9777\":{\"m\":112,\"g\":115},\"9749\":{\"m\":112,\"g\":115},\"9772\":{\"m\":112,\"g\":115},\"8750\":{\"m\":112,\"g\":115},\"6287\":{\"m\":112,\"g\":115},\"6407\":{\"m\":112,\"g\":115},\"9355\":{\"m\":112,\"g\":115},\"9770\":{\"m\":112,\"g\":115},\"8236\":{\"m\":112,\"g\":115},\"9759\":{\"m\":112,\"g\":115},\"9745\":{\"m\":112,\"g\":115},\"9721\":{\"m\":112,\"g\":115},\"9735\":{\"m\":112,\"g\":115},\"9573\":{\"m\":112,\"g\":115},\"9673\":{\"m\":112,\"g\":115},\"9740\":{\"m\":112,\"g\":115},\"9739\":{\"m\":112,\"g\":115},\"9684\":{\"m\":112,\"g\":115},\"9732\":{\"m\":112,\"g\":115},\"9730\":{\"m\":112,\"g\":115},\"9728\":{\"m\":112,\"g\":115},\"9615\":{\"m\":112,\"g\":115},\"9505\":{\"m\":112,\"g\":115},\"9724\":{\"m\":112,\"g\":115},\"9720\":{\"m\":112,\"g\":115},\"11263\":{\"m\":113,\"g\":115},\"11259\":{\"m\":113,\"g\":115},\"11061\":{\"m\":113,\"g\":115},\"11235\":{\"m\":113,\"g\":115},\"11240\":{\"m\":113,\"g\":115},\"11209\":{\"m\":113,\"g\":115},\"11254\":{\"m\":113,\"g\":115},\"11242\":{\"m\":113,\"g\":115},\"11252\":{\"m\":113,\"g\":115},\"11251\":{\"m\":113,\"g\":115},\"1004
8\":{\"m\":113,\"g\":115},\"10042\":{\"m\":113,\"g\":115},\"11248\":{\"m\":113,\"g\":115},\"11247\":{\"m\":113,\"g\":115},\"10996\":{\"m\":113,\"g\":115},\"11206\":{\"m\":113,\"g\":115},\"11228\":{\"m\":113,\"g\":115},\"11237\":{\"m\":113,\"g\":115},\"11222\":{\"m\":113,\"g\":115},\"11174\":{\"m\":113,\"g\":115},\"11162\":{\"m\":113,\"g\":115},\"10571\":{\"m\":113,\"g\":115},\"11229\":{\"m\":113,\"g\":115},\"11137\":{\"m\":113,\"g\":115},\"9624\":{\"m\":113,\"g\":115},\"11194\":{\"m\":113,\"g\":115},\"11225\":{\"m\":113,\"g\":115},\"11063\":{\"m\":113,\"g\":115},\"11217\":{\"m\":113,\"g\":115},\"11215\":{\"m\":113,\"g\":115},\"11140\":{\"m\":113,\"g\":115},\"11213\":{\"m\":113,\"g\":115},\"11012\":{\"m\":113,\"g\":115},\"11096\":{\"m\":113,\"g\":115},\"11011\":{\"m\":113,\"g\":115},\"11178\":{\"m\":113,\"g\":115},\"11196\":{\"m\":113,\"g\":115},\"10741\":{\"m\":113,\"g\":115},\"11198\":{\"m\":113,\"g\":115},\"10517\":{\"m\":113,\"g\":115},\"10838\":{\"m\":113,\"g\":115},\"10859\":{\"m\":113,\"g\":115},\"10609\":{\"m\":113,\"g\":115},\"10855\":{\"m\":113,\"g\":115},\"11090\":{\"m\":113,\"g\":115},\"10780\":{\"m\":113,\"g\":115},\"10892\":{\"m\":113,\"g\":115},\"11166\":{\"m\":113,\"g\":115},\"11192\":{\"m\":113,\"g\":115},\"11189\":{\"m\":113,\"g\":115},\"10873\":{\"m\":113,\"g\":115},\"11173\":{\"m\":113,\"g\":115},\"11167\":{\"m\":113,\"g\":115},\"10637\":{\"m\":113,\"g\":115},\"11185\":{\"m\":113,\"g\":115},\"11161\":{\"m\":113,\"g\":115},\"9537\":{\"m\":113,\"g\":115},\"10830\":{\"m\":113,\"g\":115},\"10133\":{\"m\":113,\"g\":115},\"11138\":{\"m\":113,\"g\":115},\"11179\":{\"m\":113,\"g\":115},\"11176\":{\"m\":113,\"g\":115},\"11159\":{\"m\":113,\"g\":115},\"11170\":{\"m\":113,\"g\":115},\"11175\":{\"m\":113,\"g\":115},\"11124\":{\"m\":113,\"g\":115},\"11171\":{\"m\":113,\"g\":115},\"11164\":{\"m\":113,\"g\":115},\"11130\":{\"m\":113,\"g\":115},\"11163\":{\"m\":113,\"g\":115},\"10837\":{\"m\":113,\"g\":115},\"11152\":{\"m\":113,\"g\":115},\"10988\":{\"m\":113,\"g
\":115},\"11160\":{\"m\":113,\"g\":115},\"10422\":{\"m\":113,\"g\":115},\"10263\":{\"m\":113,\"g\":115},\"11156\":{\"m\":113,\"g\":115},\"10508\":{\"m\":113,\"g\":115},\"10779\":{\"m\":113,\"g\":115},\"10768\":{\"m\":113,\"g\":115},\"11132\":{\"m\":113,\"g\":115},\"11148\":{\"m\":113,\"g\":115},\"11149\":{\"m\":113,\"g\":115},\"10559\":{\"m\":113,\"g\":115},\"11135\":{\"m\":113,\"g\":115},\"10720\":{\"m\":113,\"g\":115},\"11145\":{\"m\":113,\"g\":115},\"11123\":{\"m\":113,\"g\":115},\"11143\":{\"m\":113,\"g\":115},\"11120\":{\"m\":113,\"g\":115},\"10512\":{\"m\":113,\"g\":115},\"10271\":{\"m\":113,\"g\":115},\"11005\":{\"m\":113,\"g\":115},\"10760\":{\"m\":113,\"g\":115},\"11128\":{\"m\":113,\"g\":115},\"11075\":{\"m\":113,\"g\":115},\"10985\":{\"m\":113,\"g\":115},\"11111\":{\"m\":113,\"g\":115},\"11115\":{\"m\":113,\"g\":115},\"11114\":{\"m\":113,\"g\":115},\"10735\":{\"m\":113,\"g\":115},\"11112\":{\"m\":113,\"g\":115},\"10972\":{\"m\":113,\"g\":115},\"11080\":{\"m\":113,\"g\":115},\"11113\":{\"m\":113,\"g\":115},\"11071\":{\"m\":113,\"g\":115},\"11081\":{\"m\":113,\"g\":115},\"11102\":{\"m\":113,\"g\":115},\"11101\":{\"m\":113,\"g\":115},\"10846\":{\"m\":113,\"g\":115},\"11094\":{\"m\":113,\"g\":115},\"11067\":{\"m\":113,\"g\":115},\"11099\":{\"m\":113,\"g\":115},\"11087\":{\"m\":113,\"g\":115},\"10991\":{\"m\":113,\"g\":115},\"11085\":{\"m\":113,\"g\":115},\"11070\":{\"m\":113,\"g\":115},\"11092\":{\"m\":113,\"g\":115},\"10729\":{\"m\":113,\"g\":115},\"9642\":{\"m\":113,\"g\":115},\"10816\":{\"m\":113,\"g\":115},\"11083\":{\"m\":113,\"g\":115},\"10875\":{\"m\":113,\"g\":115},\"11082\":{\"m\":113,\"g\":115},\"11079\":{\"m\":113,\"g\":115},\"10611\":{\"m\":113,\"g\":115},\"11076\":{\"m\":113,\"g\":115},\"11073\":{\"m\":113,\"g\":115},\"10975\":{\"m\":113,\"g\":115},\"11069\":{\"m\":113,\"g\":115},\"11056\":{\"m\":113,\"g\":115},\"10976\":{\"m\":113,\"g\":115},\"11054\":{\"m\":113,\"g\":115},\"11050\":{\"m\":113,\"g\":115},\"11022\":{\"m\":113,\"g\":115},\"9614\":
{\"m\":113,\"g\":115},\"11010\":{\"m\":113,\"g\":115},\"10591\":{\"m\":113,\"g\":115},\"11015\":{\"m\":113,\"g\":115},\"10940\":{\"m\":113,\"g\":115},\"10701\":{\"m\":113,\"g\":115},\"11036\":{\"m\":113,\"g\":115},\"11038\":{\"m\":113,\"g\":115},\"10986\":{\"m\":113,\"g\":115},\"11033\":{\"m\":113,\"g\":115},\"11017\":{\"m\":113,\"g\":115},\"11003\":{\"m\":113,\"g\":115},\"11013\":{\"m\":113,\"g\":115},\"10543\":{\"m\":113,\"g\":115},\"10555\":{\"m\":113,\"g\":115},\"10964\":{\"m\":113,\"g\":115},\"11009\":{\"m\":113,\"g\":115},\"10999\":{\"m\":113,\"g\":115},\"10978\":{\"m\":113,\"g\":115},\"10995\":{\"m\":113,\"g\":115},\"10997\":{\"m\":113,\"g\":115},\"10550\":{\"m\":113,\"g\":115},\"10565\":{\"m\":113,\"g\":115},\"10616\":{\"m\":113,\"g\":115},\"10751\":{\"m\":113,\"g\":115},\"10930\":{\"m\":113,\"g\":115},\"10981\":{\"m\":113,\"g\":115},\"10112\":{\"m\":113,\"g\":115},\"10980\":{\"m\":113,\"g\":115},\"10982\":{\"m\":113,\"g\":115},\"10551\":{\"m\":113,\"g\":115},\"10965\":{\"m\":113,\"g\":115},\"10944\":{\"m\":113,\"g\":115},\"10971\":{\"m\":113,\"g\":115},\"10941\":{\"m\":113,\"g\":115},\"10495\":{\"m\":113,\"g\":115},\"10372\":{\"m\":113,\"g\":115},\"10970\":{\"m\":113,\"g\":115},\"10947\":{\"m\":113,\"g\":115},\"10968\":{\"m\":113,\"g\":115},\"10967\":{\"m\":113,\"g\":115},\"10963\":{\"m\":113,\"g\":115},\"10960\":{\"m\":113,\"g\":115},\"10958\":{\"m\":113,\"g\":115},\"10956\":{\"m\":113,\"g\":115},\"10955\":{\"m\":113,\"g\":115},\"10749\":{\"m\":113,\"g\":115},\"10927\":{\"m\":113,\"g\":115},\"10936\":{\"m\":113,\"g\":115},\"10192\":{\"m\":113,\"g\":115},\"10939\":{\"m\":113,\"g\":115},\"10935\":{\"m\":113,\"g\":115},\"10929\":{\"m\":113,\"g\":115},\"10932\":{\"m\":113,\"g\":115},\"10898\":{\"m\":113,\"g\":115},\"10612\":{\"m\":113,\"g\":115},\"10923\":{\"m\":113,\"g\":115},\"10926\":{\"m\":113,\"g\":115},\"10899\":{\"m\":113,\"g\":115},\"10883\":{\"m\":113,\"g\":115},\"10910\":{\"m\":113,\"g\":115},\"10924\":{\"m\":113,\"g\":115},\"10132\":{\"m\":113,\"g\"
:115},\"10881\":{\"m\":113,\"g\":115},\"10894\":{\"m\":113,\"g\":115},\"10915\":{\"m\":113,\"g\":115},\"10376\":{\"m\":113,\"g\":115},\"10778\":{\"m\":113,\"g\":115},\"10872\":{\"m\":113,\"g\":115},\"10885\":{\"m\":113,\"g\":115},\"10895\":{\"m\":113,\"g\":115},\"10845\":{\"m\":113,\"g\":115},\"10880\":{\"m\":113,\"g\":115},\"10572\":{\"m\":113,\"g\":115},\"10861\":{\"m\":113,\"g\":115},\"10876\":{\"m\":113,\"g\":115},\"10877\":{\"m\":113,\"g\":115},\"10534\":{\"m\":113,\"g\":115},\"10827\":{\"m\":113,\"g\":115},\"10832\":{\"m\":113,\"g\":115},\"10786\":{\"m\":113,\"g\":115},\"10860\":{\"m\":113,\"g\":115},\"10718\":{\"m\":113,\"g\":115},\"10829\":{\"m\":113,\"g\":115},\"10828\":{\"m\":113,\"g\":115},\"10825\":{\"m\":113,\"g\":115},\"10826\":{\"m\":113,\"g\":115},\"10822\":{\"m\":113,\"g\":115},\"10824\":{\"m\":113,\"g\":115},\"10823\":{\"m\":113,\"g\":115},\"10787\":{\"m\":113,\"g\":115},\"10820\":{\"m\":113,\"g\":115},\"10794\":{\"m\":113,\"g\":115},\"10799\":{\"m\":113,\"g\":115},\"10818\":{\"m\":113,\"g\":115},\"10540\":{\"m\":113,\"g\":115},\"10504\":{\"m\":113,\"g\":115},\"10761\":{\"m\":113,\"g\":115},\"10814\":{\"m\":113,\"g\":115},\"10812\":{\"m\":113,\"g\":115},\"10259\":{\"m\":113,\"g\":115},\"10323\":{\"m\":113,\"g\":115},\"10581\":{\"m\":113,\"g\":115},\"10773\":{\"m\":113,\"g\":115},\"10792\":{\"m\":113,\"g\":115},\"10791\":{\"m\":113,\"g\":115},\"10770\":{\"m\":113,\"g\":115},\"10783\":{\"m\":113,\"g\":115},\"10782\":{\"m\":113,\"g\":115},\"10715\":{\"m\":113,\"g\":115},\"10705\":{\"m\":113,\"g\":115},\"10776\":{\"m\":113,\"g\":115},\"10777\":{\"m\":113,\"g\":115},\"10771\":{\"m\":113,\"g\":115},\"10774\":{\"m\":113,\"g\":115},\"10767\":{\"m\":113,\"g\":115},\"10756\":{\"m\":113,\"g\":115},\"10765\":{\"m\":113,\"g\":115},\"10574\":{\"m\":113,\"g\":115},\"10556\":{\"m\":113,\"g\":115},\"10130\":{\"m\":113,\"g\":115},\"10762\":{\"m\":113,\"g\":115},\"10541\":{\"m\":113,\"g\":115},\"10281\":{\"m\":113,\"g\":115},\"10759\":{\"m\":113,\"g\":115},\"10300\":
{\"m\":113,\"g\":115},\"10755\":{\"m\":113,\"g\":115},\"10732\":{\"m\":113,\"g\":115},\"10758\":{\"m\":113,\"g\":115},\"10757\":{\"m\":113,\"g\":115},\"10727\":{\"m\":113,\"g\":115},\"10754\":{\"m\":113,\"g\":115},\"10724\":{\"m\":113,\"g\":115},\"10753\":{\"m\":113,\"g\":115},\"10737\":{\"m\":113,\"g\":115},\"10728\":{\"m\":113,\"g\":115},\"10730\":{\"m\":113,\"g\":115},\"9849\":{\"m\":113,\"g\":115},\"10731\":{\"m\":113,\"g\":115},\"10709\":{\"m\":113,\"g\":115},\"10699\":{\"m\":113,\"g\":115},\"10678\":{\"m\":113,\"g\":115},\"10695\":{\"m\":113,\"g\":115},\"10694\":{\"m\":113,\"g\":115},\"10717\":{\"m\":113,\"g\":115},\"10714\":{\"m\":113,\"g\":115},\"10716\":{\"m\":113,\"g\":115},\"10385\":{\"m\":113,\"g\":115},\"10317\":{\"m\":113,\"g\":115},\"10706\":{\"m\":113,\"g\":115},\"10592\":{\"m\":113,\"g\":115},\"10697\":{\"m\":113,\"g\":115},\"10696\":{\"m\":113,\"g\":115},\"10651\":{\"m\":113,\"g\":115},\"10688\":{\"m\":113,\"g\":115},\"10686\":{\"m\":113,\"g\":115},\"10685\":{\"m\":113,\"g\":115},\"10673\":{\"m\":113,\"g\":115},\"10684\":{\"m\":113,\"g\":115},\"10680\":{\"m\":113,\"g\":115},\"10681\":{\"m\":113,\"g\":115},\"10683\":{\"m\":113,\"g\":115},\"10645\":{\"m\":113,\"g\":115},\"10677\":{\"m\":113,\"g\":115},\"10679\":{\"m\":113,\"g\":115},\"10648\":{\"m\":113,\"g\":115},\"10671\":{\"m\":113,\"g\":115},\"10675\":{\"m\":113,\"g\":115},\"10666\":{\"m\":113,\"g\":115},\"10670\":{\"m\":113,\"g\":115},\"10668\":{\"m\":113,\"g\":115},\"10664\":{\"m\":113,\"g\":115},\"10661\":{\"m\":113,\"g\":115},\"10522\":{\"m\":113,\"g\":115},\"10653\":{\"m\":113,\"g\":115},\"10634\":{\"m\":113,\"g\":115},\"10650\":{\"m\":113,\"g\":115},\"10321\":{\"m\":113,\"g\":115},\"10647\":{\"m\":113,\"g\":115},\"10081\":{\"m\":113,\"g\":115},\"10633\":{\"m\":113,\"g\":115},\"10319\":{\"m\":113,\"g\":115},\"10630\":{\"m\":113,\"g\":115},\"10631\":{\"m\":113,\"g\":115},\"10632\":{\"m\":113,\"g\":115},\"10586\":{\"m\":113,\"g\":115},\"10553\":{\"m\":113,\"g\":115},\"9873\":{\"m\":113,\"g\":1
15},\"10629\":{\"m\":113,\"g\":115},\"10628\":{\"m\":113,\"g\":115},\"10621\":{\"m\":113,\"g\":115},\"9947\":{\"m\":113,\"g\":115},\"10579\":{\"m\":113,\"g\":115},\"10595\":{\"m\":113,\"g\":115},\"10610\":{\"m\":113,\"g\":115},\"10622\":{\"m\":113,\"g\":115},\"10624\":{\"m\":113,\"g\":115},\"10222\":{\"m\":113,\"g\":115},\"9979\":{\"m\":113,\"g\":115},\"8274\":{\"m\":113,\"g\":115},\"10604\":{\"m\":113,\"g\":115},\"10525\":{\"m\":113,\"g\":115},\"10596\":{\"m\":113,\"g\":115},\"10273\":{\"m\":113,\"g\":115},\"10563\":{\"m\":113,\"g\":115},\"10190\":{\"m\":113,\"g\":115},\"10558\":{\"m\":113,\"g\":115},\"10526\":{\"m\":113,\"g\":115},\"9987\":{\"m\":113,\"g\":115},\"10584\":{\"m\":113,\"g\":115},\"9976\":{\"m\":113,\"g\":115},\"10171\":{\"m\":113,\"g\":115},\"10548\":{\"m\":113,\"g\":115},\"8813\":{\"m\":113,\"g\":115},\"10545\":{\"m\":113,\"g\":115},\"10523\":{\"m\":113,\"g\":115},\"10459\":{\"m\":113,\"g\":115},\"10529\":{\"m\":113,\"g\":115},\"8746\":{\"m\":113,\"g\":115},\"10538\":{\"m\":113,\"g\":115},\"10494\":{\"m\":113,\"g\":115},\"9928\":{\"m\":113,\"g\":115},\"10474\":{\"m\":113,\"g\":115},\"10530\":{\"m\":113,\"g\":115},\"10528\":{\"m\":113,\"g\":115},\"10524\":{\"m\":113,\"g\":115},\"10511\":{\"m\":113,\"g\":115},\"10506\":{\"m\":113,\"g\":115},\"10515\":{\"m\":113,\"g\":115},\"10491\":{\"m\":113,\"g\":115},\"10500\":{\"m\":113,\"g\":115},\"10507\":{\"m\":113,\"g\":115},\"10466\":{\"m\":113,\"g\":115},\"10498\":{\"m\":113,\"g\":115},\"10499\":{\"m\":113,\"g\":115},\"10493\":{\"m\":113,\"g\":115},\"10487\":{\"m\":113,\"g\":115},\"10230\":{\"m\":113,\"g\":115},\"10336\":{\"m\":113,\"g\":115},\"10203\":{\"m\":113,\"g\":115},\"10434\":{\"m\":113,\"g\":115},\"8863\":{\"m\":113,\"g\":115},\"10486\":{\"m\":113,\"g\":115},\"10484\":{\"m\":113,\"g\":115},\"10286\":{\"m\":113,\"g\":115},\"10481\":{\"m\":113,\"g\":115},\"10478\":{\"m\":113,\"g\":115},\"10479\":{\"m\":113,\"g\":115},\"10475\":{\"m\":113,\"g\":115},\"10473\":{\"m\":113,\"g\":115},\"10476\":{\"m\":113,
\"g\":115},\"8189\":{\"m\":113,\"g\":115},\"9657\":{\"m\":113,\"g\":115},\"10375\":{\"m\":113,\"g\":115},\"9887\":{\"m\":113,\"g\":115},\"8710\":{\"m\":113,\"g\":115},\"10471\":{\"m\":113,\"g\":115},\"10470\":{\"m\":113,\"g\":115},\"10468\":{\"m\":113,\"g\":115},\"10465\":{\"m\":113,\"g\":115},\"10440\":{\"m\":113,\"g\":115},\"10456\":{\"m\":113,\"g\":115},\"10439\":{\"m\":113,\"g\":115},\"10463\":{\"m\":113,\"g\":115},\"10458\":{\"m\":113,\"g\":115},\"10457\":{\"m\":113,\"g\":115},\"10401\":{\"m\":113,\"g\":115},\"10358\":{\"m\":113,\"g\":115},\"10449\":{\"m\":113,\"g\":115},\"9343\":{\"m\":113,\"g\":115},\"10452\":{\"m\":113,\"g\":115},\"10450\":{\"m\":113,\"g\":115},\"10445\":{\"m\":113,\"g\":115},\"9626\":{\"m\":113,\"g\":115},\"9768\":{\"m\":113,\"g\":115},\"10143\":{\"m\":113,\"g\":115},\"10441\":{\"m\":113,\"g\":115},\"10437\":{\"m\":113,\"g\":115},\"9338\":{\"m\":113,\"g\":115},\"10435\":{\"m\":113,\"g\":115},\"10201\":{\"m\":113,\"g\":115},\"10432\":{\"m\":113,\"g\":115},\"10129\":{\"m\":113,\"g\":115},\"10433\":{\"m\":113,\"g\":115},\"10426\":{\"m\":113,\"g\":115},\"10425\":{\"m\":113,\"g\":115},\"10429\":{\"m\":113,\"g\":115},\"10431\":{\"m\":113,\"g\":115},\"10428\":{\"m\":113,\"g\":115},\"9962\":{\"m\":113,\"g\":115},\"10076\":{\"m\":113,\"g\":115},\"8627\":{\"m\":113,\"g\":115},\"10419\":{\"m\":113,\"g\":115},\"6539\":{\"m\":113,\"g\":115},\"10270\":{\"m\":113,\"g\":115},\"9948\":{\"m\":113,\"g\":115},\"10157\":{\"m\":113,\"g\":115},\"10313\":{\"m\":113,\"g\":115},\"10369\":{\"m\":113,\"g\":115},\"10318\":{\"m\":113,\"g\":115},\"10404\":{\"m\":113,\"g\":115},\"10228\":{\"m\":113,\"g\":115},\"10414\":{\"m\":113,\"g\":115},\"10410\":{\"m\":113,\"g\":115},\"10411\":{\"m\":113,\"g\":115},\"10412\":{\"m\":113,\"g\":115},\"9748\":{\"m\":113,\"g\":115},\"9382\":{\"m\":113,\"g\":115},\"10406\":{\"m\":113,\"g\":115},\"10392\":{\"m\":113,\"g\":115},\"10403\":{\"m\":113,\"g\":115},\"10400\":{\"m\":113,\"g\":115},\"10397\":{\"m\":113,\"g\":115},\"10398\":{\"m\":11
3,\"g\":115},\"9984\":{\"m\":113,\"g\":115},\"10395\":{\"m\":113,\"g\":115},\"10394\":{\"m\":113,\"g\":115},\"10332\":{\"m\":113,\"g\":115},\"10377\":{\"m\":113,\"g\":115},\"10379\":{\"m\":113,\"g\":115},\"10380\":{\"m\":113,\"g\":115},\"10387\":{\"m\":113,\"g\":115},\"10244\":{\"m\":113,\"g\":115},\"10386\":{\"m\":113,\"g\":115},\"10391\":{\"m\":113,\"g\":115},\"10361\":{\"m\":113,\"g\":115},\"10390\":{\"m\":113,\"g\":115},\"10388\":{\"m\":113,\"g\":115},\"10343\":{\"m\":113,\"g\":115},\"10333\":{\"m\":113,\"g\":115},\"9023\":{\"m\":113,\"g\":115},\"10099\":{\"m\":113,\"g\":115},\"10219\":{\"m\":113,\"g\":115},\"10370\":{\"m\":113,\"g\":115},\"10368\":{\"m\":113,\"g\":115},\"10180\":{\"m\":113,\"g\":115},\"10355\":{\"m\":113,\"g\":115},\"8215\":{\"m\":113,\"g\":115},\"8778\":{\"m\":113,\"g\":115},\"10351\":{\"m\":113,\"g\":115},\"10362\":{\"m\":113,\"g\":115},\"10359\":{\"m\":113,\"g\":115},\"10360\":{\"m\":113,\"g\":115},\"10356\":{\"m\":113,\"g\":115},\"10031\":{\"m\":113,\"g\":115},\"10283\":{\"m\":113,\"g\":115},\"10346\":{\"m\":113,\"g\":115},\"10352\":{\"m\":113,\"g\":115},\"9774\":{\"m\":113,\"g\":115},\"10296\":{\"m\":113,\"g\":115},\"9199\":{\"m\":113,\"g\":115},\"10345\":{\"m\":113,\"g\":115},\"10349\":{\"m\":113,\"g\":115},\"10347\":{\"m\":113,\"g\":115},\"11324\":{\"m\":114,\"g\":115},\"11369\":{\"m\":114,\"g\":115},\"11364\":{\"m\":114,\"g\":115},\"11394\":{\"m\":114,\"g\":115},\"11387\":{\"m\":114,\"g\":115},\"11376\":{\"m\":114,\"g\":115},\"11375\":{\"m\":114,\"g\":115},\"11373\":{\"m\":114,\"g\":115},\"11309\":{\"m\":114,\"g\":115},\"11366\":{\"m\":114,\"g\":115},\"11359\":{\"m\":114,\"g\":115},\"11353\":{\"m\":114,\"g\":115},\"10979\":{\"m\":114,\"g\":115},\"11350\":{\"m\":114,\"g\":115},\"11327\":{\"m\":114,\"g\":115},\"11342\":{\"m\":114,\"g\":115},\"11339\":{\"m\":114,\"g\":115},\"11341\":{\"m\":114,\"g\":115},\"11340\":{\"m\":114,\"g\":115},\"11336\":{\"m\":114,\"g\":115},\"11323\":{\"m\":114,\"g\":115},\"10909\":{\"m\":114,\"g\":115},\"11264\"
:{\"m\":114,\"g\":115},\"9812\":{\"m\":114,\"g\":115},\"11318\":{\"m\":114,\"g\":115},\"11007\":{\"m\":114,\"g\":115},\"11321\":{\"m\":114,\"g\":115},\"10937\":{\"m\":114,\"g\":115},\"11312\":{\"m\":114,\"g\":115},\"11211\":{\"m\":114,\"g\":115},\"9545\":{\"m\":114,\"g\":115},\"11314\":{\"m\":114,\"g\":115},\"11316\":{\"m\":114,\"g\":115},\"11126\":{\"m\":114,\"g\":115},\"11315\":{\"m\":114,\"g\":115},\"10710\":{\"m\":114,\"g\":115},\"11230\":{\"m\":114,\"g\":115},\"11200\":{\"m\":114,\"g\":115},\"11304\":{\"m\":114,\"g\":115},\"11310\":{\"m\":114,\"g\":115},\"11311\":{\"m\":114,\"g\":115},\"11297\":{\"m\":114,\"g\":115},\"11205\":{\"m\":114,\"g\":115},\"11307\":{\"m\":114,\"g\":115},\"11223\":{\"m\":114,\"g\":115},\"11306\":{\"m\":114,\"g\":115},\"11305\":{\"m\":114,\"g\":115},\"11001\":{\"m\":114,\"g\":115},\"11027\":{\"m\":114,\"g\":115},\"11288\":{\"m\":114,\"g\":115},\"11303\":{\"m\":114,\"g\":115},\"11302\":{\"m\":114,\"g\":115},\"11068\":{\"m\":114,\"g\":115},\"11301\":{\"m\":114,\"g\":115},\"11300\":{\"m\":114,\"g\":115},\"11290\":{\"m\":114,\"g\":115},\"11210\":{\"m\":114,\"g\":115},\"11231\":{\"m\":114,\"g\":115},\"11294\":{\"m\":114,\"g\":115},\"10949\":{\"m\":114,\"g\":115},\"11095\":{\"m\":114,\"g\":115},\"11283\":{\"m\":114,\"g\":115},\"11281\":{\"m\":114,\"g\":115},\"11286\":{\"m\":114,\"g\":115},\"11282\":{\"m\":114,\"g\":115},\"11261\":{\"m\":114,\"g\":115},\"11238\":{\"m\":114,\"g\":115},\"11279\":{\"m\":114,\"g\":115},\"11280\":{\"m\":114,\"g\":115},\"11276\":{\"m\":114,\"g\":115},\"11277\":{\"m\":114,\"g\":115},\"11182\":{\"m\":114,\"g\":115},\"11268\":{\"m\":114,\"g\":115},\"11274\":{\"m\":114,\"g\":115},\"11270\":{\"m\":114,\"g\":115},\"7149\":{\"m\":114,\"g\":115},\"11262\":{\"m\":114,\"g\":115},\"11219\":{\"m\":114,\"g\":115},\"11680\":{\"m\":116,\"g\":118},\"11676\":{\"m\":116,\"g\":118},\"11684\":{\"m\":116,\"g\":118},\"11667\":{\"m\":116,\"g\":118},\"11674\":{\"m\":116,\"g\":118},\"11681\":{\"m\":116,\"g\":118},\"11621\":{\"m\":116,\"g\":1
18},\"11367\":{\"m\":116,\"g\":118},\"11653\":{\"m\":116,\"g\":118},\"11660\":{\"m\":116,\"g\":118},\"11659\":{\"m\":116,\"g\":118},\"11585\":{\"m\":116,\"g\":118},\"11293\":{\"m\":116,\"g\":118},\"11458\":{\"m\":116,\"g\":118},\"11590\":{\"m\":116,\"g\":118},\"11636\":{\"m\":116,\"g\":118},\"11579\":{\"m\":116,\"g\":118},\"8247\":{\"m\":116,\"g\":118},\"10423\":{\"m\":116,\"g\":118},\"11642\":{\"m\":116,\"g\":115},\"11638\":{\"m\":116,\"g\":115},\"11628\":{\"m\":116,\"g\":115},\"11639\":{\"m\":116,\"g\":115},\"11351\":{\"m\":116,\"g\":115},\"11627\":{\"m\":116,\"g\":115},\"11633\":{\"m\":116,\"g\":115},\"11631\":{\"m\":116,\"g\":115},\"11622\":{\"m\":116,\"g\":115},\"11625\":{\"m\":116,\"g\":115},\"11623\":{\"m\":116,\"g\":115},\"11624\":{\"m\":116,\"g\":115},\"11619\":{\"m\":116,\"g\":115},\"11605\":{\"m\":116,\"g\":115},\"11620\":{\"m\":116,\"g\":115},\"11617\":{\"m\":116,\"g\":115},\"11453\":{\"m\":116,\"g\":115},\"11561\":{\"m\":116,\"g\":115},\"11434\":{\"m\":116,\"g\":115},\"10721\":{\"m\":116,\"g\":115},\"11586\":{\"m\":116,\"g\":115},\"11556\":{\"m\":116,\"g\":115},\"11603\":{\"m\":116,\"g\":115},\"11593\":{\"m\":116,\"g\":115},\"11601\":{\"m\":116,\"g\":115},\"11449\":{\"m\":116,\"g\":115},\"11600\":{\"m\":116,\"g\":115},\"11597\":{\"m\":116,\"g\":115},\"11598\":{\"m\":116,\"g\":115},\"11566\":{\"m\":116,\"g\":115},\"11591\":{\"m\":116,\"g\":115},\"11588\":{\"m\":116,\"g\":115},\"11587\":{\"m\":116,\"g\":115},\"11580\":{\"m\":116,\"g\":115},\"11583\":{\"m\":116,\"g\":115},\"11582\":{\"m\":116,\"g\":115},\"11041\":{\"m\":116,\"g\":115},\"11542\":{\"m\":116,\"g\":115},\"11535\":{\"m\":116,\"g\":115},\"11565\":{\"m\":116,\"g\":115},\"11413\":{\"m\":116,\"g\":115},\"11539\":{\"m\":116,\"g\":115},\"11572\":{\"m\":116,\"g\":115},\"11573\":{\"m\":116,\"g\":115},\"11534\":{\"m\":116,\"g\":115},\"11571\":{\"m\":116,\"g\":115},\"11564\":{\"m\":116,\"g\":115},\"11537\":{\"m\":116,\"g\":115},\"11538\":{\"m\":116,\"g\":115},\"11521\":{\"m\":116,\"g\":115},\"11562\":{\"
m\":116,\"g\":115},\"11308\":{\"m\":116,\"g\":115},\"11557\":{\"m\":116,\"g\":115},\"11441\":{\"m\":116,\"g\":115},\"11483\":{\"m\":116,\"g\":115},\"11531\":{\"m\":116,\"g\":115},\"11549\":{\"m\":116,\"g\":115},\"11553\":{\"m\":116,\"g\":115},\"11547\":{\"m\":116,\"g\":115},\"11507\":{\"m\":116,\"g\":115},\"11419\":{\"m\":116,\"g\":115},\"11444\":{\"m\":116,\"g\":115},\"11442\":{\"m\":116,\"g\":115},\"11548\":{\"m\":116,\"g\":115},\"11530\":{\"m\":116,\"g\":115},\"11527\":{\"m\":116,\"g\":115},\"11201\":{\"m\":116,\"g\":115},\"11528\":{\"m\":116,\"g\":115},\"11457\":{\"m\":116,\"g\":115},\"11544\":{\"m\":116,\"g\":115},\"11460\":{\"m\":116,\"g\":115},\"11505\":{\"m\":116,\"g\":115},\"11385\":{\"m\":116,\"g\":115},\"11214\":{\"m\":116,\"g\":115},\"11493\":{\"m\":116,\"g\":115},\"11512\":{\"m\":116,\"g\":115},\"11432\":{\"m\":116,\"g\":115},\"11485\":{\"m\":116,\"g\":115},\"11511\":{\"m\":116,\"g\":115},\"5889\":{\"m\":116,\"g\":115},\"11520\":{\"m\":116,\"g\":115},\"11516\":{\"m\":116,\"g\":115},\"11498\":{\"m\":116,\"g\":115},\"11474\":{\"m\":116,\"g\":115},\"11515\":{\"m\":116,\"g\":115},\"11514\":{\"m\":116,\"g\":115},\"11331\":{\"m\":116,\"g\":115},\"11509\":{\"m\":116,\"g\":115},\"11443\":{\"m\":116,\"g\":115},\"11452\":{\"m\":116,\"g\":115},\"11503\":{\"m\":116,\"g\":115},\"11502\":{\"m\":116,\"g\":115},\"11497\":{\"m\":116,\"g\":115},\"11501\":{\"m\":116,\"g\":115},\"11500\":{\"m\":116,\"g\":115},\"11332\":{\"m\":116,\"g\":115},\"11499\":{\"m\":116,\"g\":115},\"11465\":{\"m\":116,\"g\":115},\"10577\":{\"m\":116,\"g\":115},\"11479\":{\"m\":116,\"g\":115},\"11221\":{\"m\":116,\"g\":115},\"10172\":{\"m\":116,\"g\":115},\"11478\":{\"m\":116,\"g\":115},\"11481\":{\"m\":116,\"g\":115},\"11476\":{\"m\":116,\"g\":115},\"11489\":{\"m\":116,\"g\":115},\"10062\":{\"m\":116,\"g\":115},\"11398\":{\"m\":116,\"g\":115},\"10635\":{\"m\":116,\"g\":115},\"9804\":{\"m\":116,\"g\":115},\"11454\":{\"m\":116,\"g\":115},\"11019\":{\"m\":116,\"g\":115},\"8919\":{\"m\":116,\"g\":115},
\"11462\":{\"m\":116,\"g\":115},\"11428\":{\"m\":116,\"g\":115},\"11470\":{\"m\":116,\"g\":115},\"11427\":{\"m\":116,\"g\":115},\"11467\":{\"m\":116,\"g\":115},\"11448\":{\"m\":116,\"g\":115},\"10312\":{\"m\":116,\"g\":115},\"9991\":{\"m\":116,\"g\":115},\"11455\":{\"m\":116,\"g\":115},\"11450\":{\"m\":116,\"g\":115},\"11360\":{\"m\":116,\"g\":115},\"11368\":{\"m\":116,\"g\":115},\"11445\":{\"m\":116,\"g\":115},\"11438\":{\"m\":116,\"g\":115},\"11439\":{\"m\":116,\"g\":115},\"11399\":{\"m\":116,\"g\":115},\"11435\":{\"m\":116,\"g\":115},\"11313\":{\"m\":116,\"g\":115},\"10745\":{\"m\":116,\"g\":115},\"11411\":{\"m\":116,\"g\":115},\"11433\":{\"m\":116,\"g\":115},\"11437\":{\"m\":116,\"g\":115},\"11436\":{\"m\":116,\"g\":115},\"11345\":{\"m\":116,\"g\":115},\"9256\":{\"m\":116,\"g\":115},\"11381\":{\"m\":116,\"g\":115},\"11361\":{\"m\":116,\"g\":115},\"11420\":{\"m\":116,\"g\":115},\"11144\":{\"m\":116,\"g\":115},\"10734\":{\"m\":116,\"g\":115},\"10969\":{\"m\":116,\"g\":115},\"9045\":{\"m\":116,\"g\":115},\"11414\":{\"m\":116,\"g\":115},\"11388\":{\"m\":116,\"g\":115},\"11363\":{\"m\":116,\"g\":115},\"11365\":{\"m\":116,\"g\":115},\"11389\":{\"m\":116,\"g\":115},\"11401\":{\"m\":116,\"g\":115},\"11285\":{\"m\":116,\"g\":115},\"11693\":{\"m\":117,\"g\":118},\"11543\":{\"m\":117,\"g\":118},\"11687\":{\"m\":117,\"g\":118},\"11706\":{\"m\":117,\"g\":118},\"11510\":{\"m\":117,\"g\":118},\"11488\":{\"m\":117,\"g\":118},\"11370\":{\"m\":117,\"g\":118},\"11692\":{\"m\":117,\"g\":118},\"11663\":{\"m\":117,\"g\":118},\"10912\":{\"m\":117,\"g\":118},\"11679\":{\"m\":117,\"g\":118},\"11689\":{\"m\":117,\"g\":118},\"11686\":{\"m\":117,\"g\":118},\"10248\":{\"m\":117,\"g\":118},\"9493\":{\"m\":117,\"g\":118},\"12027\":{\"m\":119,\"g\":121},\"12009\":{\"m\":119,\"g\":121},\"12030\":{\"m\":119,\"g\":121},\"12029\":{\"m\":119,\"g\":121},\"11616\":{\"m\":119,\"g\":121},\"12028\":{\"m\":119,\"g\":121},\"9366\":{\"m\":119,\"g\":121},\"11891\":{\"m\":119,\"g\":121},\"10158\":{\"m\":119,
\"g\":121},\"12024\":{\"m\":119,\"g\":121},\"11765\":{\"m\":119,\"g\":121},\"12022\":{\"m\":119,\"g\":121},\"12018\":{\"m\":119,\"g\":121},\"12021\":{\"m\":119,\"g\":121},\"11755\":{\"m\":119,\"g\":121},\"11981\":{\"m\":119,\"g\":121},\"12015\":{\"m\":119,\"g\":121},\"12014\":{\"m\":119,\"g\":121},\"11937\":{\"m\":119,\"g\":121},\"11988\":{\"m\":119,\"g\":121},\"11866\":{\"m\":119,\"g\":121},\"11821\":{\"m\":119,\"g\":121},\"12004\":{\"m\":119,\"g\":121},\"11985\":{\"m\":119,\"g\":121},\"11944\":{\"m\":119,\"g\":121},\"10652\":{\"m\":119,\"g\":121},\"11965\":{\"m\":119,\"g\":121},\"11906\":{\"m\":119,\"g\":121},\"11990\":{\"m\":119,\"g\":121},\"11299\":{\"m\":119,\"g\":121},\"11322\":{\"m\":119,\"g\":121},\"11811\":{\"m\":119,\"g\":121},\"11955\":{\"m\":119,\"g\":121},\"11978\":{\"m\":119,\"g\":121},\"10869\":{\"m\":119,\"g\":121},\"11921\":{\"m\":119,\"g\":121},\"11563\":{\"m\":119,\"g\":121},\"11980\":{\"m\":119,\"g\":121},\"11977\":{\"m\":119,\"g\":121},\"10750\":{\"m\":119,\"g\":121},\"11956\":{\"m\":119,\"g\":121},\"11723\":{\"m\":119,\"g\":121},\"11967\":{\"m\":119,\"g\":121},\"11953\":{\"m\":119,\"g\":121},\"9651\":{\"m\":119,\"g\":121},\"10606\":{\"m\":119,\"g\":121},\"11908\":{\"m\":119,\"g\":121},\"10154\":{\"m\":119,\"g\":121},\"11929\":{\"m\":119,\"g\":121},\"11717\":{\"m\":119,\"g\":121},\"11922\":{\"m\":119,\"g\":121},\"11945\":{\"m\":119,\"g\":121},\"11790\":{\"m\":119,\"g\":121},\"11926\":{\"m\":119,\"g\":121},\"11940\":{\"m\":119,\"g\":121},\"11935\":{\"m\":119,\"g\":121},\"11934\":{\"m\":119,\"g\":121},\"11377\":{\"m\":119,\"g\":121},\"11933\":{\"m\":119,\"g\":121},\"11876\":{\"m\":119,\"g\":121},\"11844\":{\"m\":119,\"g\":121},\"11287\":{\"m\":119,\"g\":121},\"11918\":{\"m\":119,\"g\":121},\"11915\":{\"m\":119,\"g\":121},\"11702\":{\"m\":119,\"g\":121},\"11482\":{\"m\":119,\"g\":121},\"10700\":{\"m\":119,\"g\":121},\"11902\":{\"m\":119,\"g\":121},\"11295\":{\"m\":119,\"g\":121},\"11416\":{\"m\":119,\"g\":121},\"11895\":{\"m\":119,\"g\":121},\"1157
0\":{\"m\":119,\"g\":121},\"11487\":{\"m\":119,\"g\":121},\"11878\":{\"m\":119,\"g\":121},\"11885\":{\"m\":119,\"g\":118},\"11664\":{\"m\":119,\"g\":118},\"10656\":{\"m\":119,\"g\":118},\"11843\":{\"m\":119,\"g\":118},\"11845\":{\"m\":119,\"g\":118},\"11859\":{\"m\":119,\"g\":118},\"11887\":{\"m\":119,\"g\":118},\"11838\":{\"m\":119,\"g\":118},\"11886\":{\"m\":119,\"g\":118},\"11826\":{\"m\":119,\"g\":118},\"11882\":{\"m\":119,\"g\":118},\"11875\":{\"m\":119,\"g\":118},\"11881\":{\"m\":119,\"g\":118},\"11868\":{\"m\":119,\"g\":118},\"11807\":{\"m\":119,\"g\":118},\"11867\":{\"m\":119,\"g\":118},\"11823\":{\"m\":119,\"g\":118},\"11847\":{\"m\":119,\"g\":118},\"11862\":{\"m\":119,\"g\":118},\"10691\":{\"m\":119,\"g\":118},\"11776\":{\"m\":119,\"g\":118},\"11822\":{\"m\":119,\"g\":118},\"11849\":{\"m\":119,\"g\":118},\"11396\":{\"m\":119,\"g\":118},\"11846\":{\"m\":119,\"g\":118},\"11747\":{\"m\":119,\"g\":118},\"10801\":{\"m\":119,\"g\":118},\"11722\":{\"m\":119,\"g\":118},\"11594\":{\"m\":119,\"g\":118},\"11780\":{\"m\":119,\"g\":118},\"11733\":{\"m\":119,\"g\":118},\"10510\":{\"m\":119,\"g\":118},\"11508\":{\"m\":119,\"g\":118},\"11787\":{\"m\":119,\"g\":118},\"11832\":{\"m\":119,\"g\":118},\"11778\":{\"m\":119,\"g\":118},\"11612\":{\"m\":119,\"g\":118},\"11810\":{\"m\":119,\"g\":118},\"11831\":{\"m\":119,\"g\":118},\"11815\":{\"m\":119,\"g\":118},\"11606\":{\"m\":119,\"g\":118},\"11835\":{\"m\":119,\"g\":118},\"10994\":{\"m\":119,\"g\":118},\"11147\":{\"m\":119,\"g\":118},\"11833\":{\"m\":119,\"g\":118},\"11834\":{\"m\":119,\"g\":118},\"11652\":{\"m\":119,\"g\":118},\"11819\":{\"m\":119,\"g\":118},\"11827\":{\"m\":119,\"g\":118},\"11805\":{\"m\":119,\"g\":118},\"5162\":{\"m\":119,\"g\":118},\"11786\":{\"m\":119,\"g\":118},\"11804\":{\"m\":119,\"g\":118},\"11808\":{\"m\":119,\"g\":118},\"11091\":{\"m\":119,\"g\":118},\"11818\":{\"m\":119,\"g\":118},\"10788\":{\"m\":119,\"g\":118},\"11817\":{\"m\":119,\"g\":118},\"11328\":{\"m\":119,\"g\":118},\"11555\":{\"m\":119,\"
g\":118},\"11618\":{\"m\":119,\"g\":118},\"11670\":{\"m\":119,\"g\":118},\"11688\":{\"m\":119,\"g\":118},\"11773\":{\"m\":119,\"g\":118},\"11813\":{\"m\":119,\"g\":118},\"11000\":{\"m\":119,\"g\":118},\"11710\":{\"m\":119,\"g\":118},\"11749\":{\"m\":119,\"g\":118},\"11772\":{\"m\":119,\"g\":118},\"11506\":{\"m\":119,\"g\":118},\"11797\":{\"m\":119,\"g\":118},\"11781\":{\"m\":119,\"g\":118},\"11803\":{\"m\":119,\"g\":118},\"11801\":{\"m\":119,\"g\":118},\"10152\":{\"m\":119,\"g\":118},\"11669\":{\"m\":119,\"g\":118},\"11793\":{\"m\":119,\"g\":118},\"11665\":{\"m\":119,\"g\":118},\"11799\":{\"m\":119,\"g\":118},\"11798\":{\"m\":119,\"g\":118},\"11794\":{\"m\":119,\"g\":118},\"11784\":{\"m\":119,\"g\":118},\"11783\":{\"m\":119,\"g\":118},\"11614\":{\"m\":119,\"g\":118},\"11788\":{\"m\":119,\"g\":118},\"9170\":{\"m\":119,\"g\":118},\"11666\":{\"m\":119,\"g\":118},\"11613\":{\"m\":119,\"g\":118},\"11611\":{\"m\":119,\"g\":118},\"11607\":{\"m\":119,\"g\":118},\"11685\":{\"m\":119,\"g\":118},\"11519\":{\"m\":119,\"g\":118},\"11782\":{\"m\":119,\"g\":118},\"11777\":{\"m\":119,\"g\":118},\"11682\":{\"m\":119,\"g\":118},\"11775\":{\"m\":119,\"g\":118},\"11738\":{\"m\":119,\"g\":118},\"10725\":{\"m\":119,\"g\":118},\"11540\":{\"m\":119,\"g\":118},\"11767\":{\"m\":119,\"g\":118},\"11768\":{\"m\":119,\"g\":118},\"11766\":{\"m\":119,\"g\":118},\"11735\":{\"m\":119,\"g\":118},\"11730\":{\"m\":119,\"g\":118},\"11643\":{\"m\":119,\"g\":118},\"11062\":{\"m\":119,\"g\":118},\"11724\":{\"m\":119,\"g\":118},\"11746\":{\"m\":119,\"g\":118},\"11739\":{\"m\":119,\"g\":118},\"11740\":{\"m\":119,\"g\":118},\"11732\":{\"m\":119,\"g\":118},\"11734\":{\"m\":119,\"g\":118},\"11541\":{\"m\":119,\"g\":118},\"11728\":{\"m\":119,\"g\":118},\"11731\":{\"m\":119,\"g\":118},\"11727\":{\"m\":119,\"g\":118},\"11729\":{\"m\":119,\"g\":118},\"11677\":{\"m\":119,\"g\":118},\"10911\":{\"m\":119,\"g\":118},\"12169\":{\"m\":120,\"g\":121},\"12177\":{\"m\":120,\"g\":121},\"12170\":{\"m\":120,\"g\":121},\"12167\
":{\"m\":120,\"g\":121},\"12171\":{\"m\":120,\"g\":121},\"12164\":{\"m\":120,\"g\":121},\"12168\":{\"m\":120,\"g\":121},\"12129\":{\"m\":120,\"g\":121},\"12166\":{\"m\":120,\"g\":121},\"11047\":{\"m\":120,\"g\":121},\"11632\":{\"m\":120,\"g\":121},\"10399\":{\"m\":120,\"g\":121},\"12142\":{\"m\":120,\"g\":121},\"12106\":{\"m\":120,\"g\":121},\"12156\":{\"m\":120,\"g\":121},\"12155\":{\"m\":120,\"g\":121},\"12152\":{\"m\":120,\"g\":121},\"12154\":{\"m\":120,\"g\":121},\"11615\":{\"m\":120,\"g\":121},\"11494\":{\"m\":120,\"g\":121},\"12113\":{\"m\":120,\"g\":121},\"12116\":{\"m\":120,\"g\":121},\"12136\":{\"m\":120,\"g\":121},\"12147\":{\"m\":120,\"g\":121},\"12141\":{\"m\":120,\"g\":121},\"11991\":{\"m\":120,\"g\":121},\"12097\":{\"m\":120,\"g\":121},\"12139\":{\"m\":120,\"g\":121},\"12138\":{\"m\":120,\"g\":121},\"12118\":{\"m\":120,\"g\":121},\"12133\":{\"m\":120,\"g\":121},\"11936\":{\"m\":120,\"g\":121},\"12132\":{\"m\":120,\"g\":121},\"11993\":{\"m\":120,\"g\":121},\"12130\":{\"m\":120,\"g\":121},\"12125\":{\"m\":120,\"g\":121},\"12101\":{\"m\":120,\"g\":121},\"11814\":{\"m\":120,\"g\":121},\"12127\":{\"m\":120,\"g\":121},\"12115\":{\"m\":120,\"g\":121},\"12126\":{\"m\":120,\"g\":121},\"12098\":{\"m\":120,\"g\":121},\"11962\":{\"m\":120,\"g\":121},\"12119\":{\"m\":120,\"g\":121},\"12124\":{\"m\":120,\"g\":121},\"11869\":{\"m\":120,\"g\":121},\"12110\":{\"m\":120,\"g\":121},\"12096\":{\"m\":120,\"g\":121},\"12083\":{\"m\":120,\"g\":121},\"12103\":{\"m\":120,\"g\":121},\"12105\":{\"m\":120,\"g\":121},\"12087\":{\"m\":120,\"g\":121},\"9501\":{\"m\":120,\"g\":121},\"11379\":{\"m\":120,\"g\":121},\"12058\":{\"m\":120,\"g\":121},\"12054\":{\"m\":120,\"g\":121},\"11877\":{\"m\":120,\"g\":121},\"12070\":{\"m\":120,\"g\":121},\"8464\":{\"m\":120,\"g\":121},\"12093\":{\"m\":120,\"g\":121},\"12071\":{\"m\":120,\"g\":121},\"12000\":{\"m\":120,\"g\":121},\"12091\":{\"m\":120,\"g\":121},\"12089\":{\"m\":120,\"g\":121},\"12086\":{\"m\":120,\"g\":121},\"11560\":{\"m\":120,\"g\"
:121},\"12084\":{\"m\":120,\"g\":121},\"12034\":{\"m\":120,\"g\":121},\"11924\":{\"m\":120,\"g\":121},\"11884\":{\"m\":120,\"g\":121},\"12025\":{\"m\":120,\"g\":121},\"11958\":{\"m\":120,\"g\":121},\"12067\":{\"m\":120,\"g\":121},\"11999\":{\"m\":120,\"g\":121},\"12046\":{\"m\":120,\"g\":121},\"12049\":{\"m\":120,\"g\":121},\"12063\":{\"m\":120,\"g\":121},\"12064\":{\"m\":120,\"g\":121},\"12053\":{\"m\":120,\"g\":121},\"11800\":{\"m\":120,\"g\":121},\"12056\":{\"m\":120,\"g\":121},\"12019\":{\"m\":120,\"g\":121},\"11853\":{\"m\":120,\"g\":121},\"12031\":{\"m\":120,\"g\":121},\"12041\":{\"m\":120,\"g\":121},\"10953\":{\"m\":120,\"g\":121},\"11759\":{\"m\":120,\"g\":121},\"12042\":{\"m\":120,\"g\":121},\"12037\":{\"m\":120,\"g\":121},\"11745\":{\"m\":120,\"g\":121},\"12038\":{\"m\":120,\"g\":121},\"11909\":{\"m\":120,\"g\":121},\"11816\":{\"m\":120,\"g\":121},\"11795\":{\"m\":120,\"g\":121},\"12003\":{\"m\":120,\"g\":121},\"9936\":{\"m\":120,\"g\":121},\"11964\":{\"m\":120,\"g\":121},\"12439\":{\"m\":122,\"g\":127},\"11874\":{\"m\":122,\"g\":127},\"12475\":{\"m\":122,\"g\":127},\"12469\":{\"m\":122,\"g\":127},\"11987\":{\"m\":122,\"g\":127},\"12430\":{\"m\":122,\"g\":127},\"12473\":{\"m\":122,\"g\":127},\"12297\":{\"m\":122,\"g\":127},\"12341\":{\"m\":122,\"g\":127},\"12428\":{\"m\":122,\"g\":127},\"12429\":{\"m\":122,\"g\":127},\"12066\":{\"m\":122,\"g\":127},\"12275\":{\"m\":122,\"g\":127},\"11757\":{\"m\":122,\"g\":127},\"11931\":{\"m\":122,\"g\":127},\"12470\":{\"m\":122,\"g\":127},\"12415\":{\"m\":122,\"g\":127},\"12266\":{\"m\":122,\"g\":127},\"12463\":{\"m\":122,\"g\":127},\"12328\":{\"m\":122,\"g\":127},\"12369\":{\"m\":122,\"g\":127},\"12256\":{\"m\":122,\"g\":127},\"12449\":{\"m\":122,\"g\":127},\"12452\":{\"m\":122,\"g\":127},\"12413\":{\"m\":122,\"g\":127},\"10889\":{\"m\":122,\"g\":127},\"12422\":{\"m\":122,\"g\":127},\"12436\":{\"m\":122,\"g\":127},\"12437\":{\"m\":122,\"g\":127},\"12410\":{\"m\":122,\"g\":127},\"12384\":{\"m\":122,\"g\":127},\"12307\":{
\"m\":122,\"g\":127},\"10566\":{\"m\":122,\"g\":127},\"12405\":{\"m\":122,\"g\":127},\"12425\":{\"m\":122,\"g\":127},\"12401\":{\"m\":122,\"g\":127},\"12300\":{\"m\":122,\"g\":127},\"11224\":{\"m\":122,\"g\":127},\"11116\":{\"m\":122,\"g\":127},\"12399\":{\"m\":122,\"g\":121},\"12290\":{\"m\":122,\"g\":121},\"12242\":{\"m\":122,\"g\":121},\"12368\":{\"m\":122,\"g\":121},\"12403\":{\"m\":122,\"g\":121},\"12386\":{\"m\":122,\"g\":121},\"12409\":{\"m\":122,\"g\":121},\"12375\":{\"m\":122,\"g\":121},\"12404\":{\"m\":122,\"g\":121},\"12281\":{\"m\":122,\"g\":121},\"12185\":{\"m\":122,\"g\":121},\"12012\":{\"m\":122,\"g\":121},\"12364\":{\"m\":122,\"g\":121},\"12395\":{\"m\":122,\"g\":121},\"11960\":{\"m\":122,\"g\":121},\"12377\":{\"m\":122,\"g\":121},\"11897\":{\"m\":122,\"g\":121},\"11969\":{\"m\":122,\"g\":121},\"12394\":{\"m\":122,\"g\":121},\"11806\":{\"m\":122,\"g\":121},\"12319\":{\"m\":122,\"g\":121},\"12123\":{\"m\":122,\"g\":121},\"12358\":{\"m\":122,\"g\":121},\"12362\":{\"m\":122,\"g\":121},\"12135\":{\"m\":122,\"g\":121},\"12378\":{\"m\":122,\"g\":121},\"12174\":{\"m\":122,\"g\":121},\"12340\":{\"m\":122,\"g\":121},\"12195\":{\"m\":122,\"g\":121},\"11910\":{\"m\":122,\"g\":121},\"12216\":{\"m\":122,\"g\":121},\"12050\":{\"m\":122,\"g\":121},\"12153\":{\"m\":122,\"g\":121},\"12094\":{\"m\":122,\"g\":121},\"12182\":{\"m\":122,\"g\":121},\"12354\":{\"m\":122,\"g\":121},\"12350\":{\"m\":122,\"g\":121},\"12348\":{\"m\":122,\"g\":121},\"11737\":{\"m\":122,\"g\":121},\"12346\":{\"m\":122,\"g\":121},\"12325\":{\"m\":122,\"g\":121},\"12095\":{\"m\":122,\"g\":121},\"12347\":{\"m\":122,\"g\":121},\"12345\":{\"m\":122,\"g\":121},\"11709\":{\"m\":122,\"g\":121},\"12343\":{\"m\":122,\"g\":121},\"12315\":{\"m\":122,\"g\":121},\"12338\":{\"m\":122,\"g\":121},\"11673\":{\"m\":122,\"g\":121},\"12002\":{\"m\":122,\"g\":121},\"12336\":{\"m\":122,\"g\":121},\"12312\":{\"m\":122,\"g\":121},\"12317\":{\"m\":122,\"g\":121},\"12276\":{\"m\":122,\"g\":121},\"12294\":{\"m\":122,\"g\":
121},\"12314\":{\"m\":122,\"g\":121},\"12144\":{\"m\":122,\"g\":121},\"10874\":{\"m\":122,\"g\":121},\"12269\":{\"m\":122,\"g\":121},\"9825\":{\"m\":122,\"g\":121},\"12313\":{\"m\":122,\"g\":121},\"12311\":{\"m\":122,\"g\":121},\"12259\":{\"m\":122,\"g\":121},\"12241\":{\"m\":122,\"g\":121},\"12308\":{\"m\":122,\"g\":121},\"12299\":{\"m\":122,\"g\":121},\"12271\":{\"m\":122,\"g\":121},\"12296\":{\"m\":122,\"g\":121},\"12295\":{\"m\":122,\"g\":121},\"12285\":{\"m\":122,\"g\":121},\"12233\":{\"m\":122,\"g\":121},\"11928\":{\"m\":122,\"g\":121},\"12188\":{\"m\":122,\"g\":121},\"12283\":{\"m\":122,\"g\":121},\"12231\":{\"m\":122,\"g\":121},\"12284\":{\"m\":122,\"g\":121},\"12274\":{\"m\":122,\"g\":121},\"12257\":{\"m\":122,\"g\":121},\"12268\":{\"m\":122,\"g\":121},\"12267\":{\"m\":122,\"g\":121},\"12206\":{\"m\":122,\"g\":121},\"12247\":{\"m\":122,\"g\":121},\"10804\":{\"m\":122,\"g\":121},\"12230\":{\"m\":122,\"g\":121},\"12229\":{\"m\":122,\"g\":121},\"12252\":{\"m\":122,\"g\":121},\"12249\":{\"m\":122,\"g\":121},\"7873\":{\"m\":122,\"g\":121},\"10567\":{\"m\":122,\"g\":121},\"11177\":{\"m\":122,\"g\":121},\"11655\":{\"m\":122,\"g\":121},\"12245\":{\"m\":122,\"g\":121},\"11517\":{\"m\":122,\"g\":121},\"10654\":{\"m\":122,\"g\":121},\"12222\":{\"m\":122,\"g\":121},\"11994\":{\"m\":122,\"g\":121},\"12176\":{\"m\":122,\"g\":121},\"11708\":{\"m\":122,\"g\":121},\"12235\":{\"m\":122,\"g\":121},\"12161\":{\"m\":122,\"g\":121},\"12234\":{\"m\":122,\"g\":121},\"11142\":{\"m\":122,\"g\":121},\"12006\":{\"m\":122,\"g\":121},\"11592\":{\"m\":122,\"g\":121},\"11656\":{\"m\":122,\"g\":121},\"12186\":{\"m\":122,\"g\":121},\"12209\":{\"m\":122,\"g\":121},\"12205\":{\"m\":122,\"g\":121},\"12107\":{\"m\":122,\"g\":121},\"12112\":{\"m\":122,\"g\":121},\"10153\":{\"m\":122,\"g\":121},\"12117\":{\"m\":122,\"g\":121},\"12080\":{\"m\":122,\"g\":121},\"9403\":{\"m\":122,\"g\":121},\"12192\":{\"m\":122,\"g\":121},\"12173\":{\"m\":122,\"g\":121},\"12159\":{\"m\":122,\"g\":121},\"12057\":{\"m
\":122,\"g\":121},\"12639\":{\"m\":123,\"g\":127},\"12572\":{\"m\":123,\"g\":127},\"12656\":{\"m\":123,\"g\":127},\"12456\":{\"m\":123,\"g\":127},\"12585\":{\"m\":123,\"g\":127},\"12648\":{\"m\":123,\"g\":127},\"12650\":{\"m\":123,\"g\":127},\"12640\":{\"m\":123,\"g\":127},\"12645\":{\"m\":123,\"g\":127},\"12647\":{\"m\":123,\"g\":127},\"12642\":{\"m\":123,\"g\":127},\"12641\":{\"m\":123,\"g\":127},\"12628\":{\"m\":123,\"g\":127},\"12634\":{\"m\":123,\"g\":127},\"12633\":{\"m\":123,\"g\":127},\"12632\":{\"m\":123,\"g\":127},\"12593\":{\"m\":123,\"g\":127},\"12616\":{\"m\":123,\"g\":127},\"12594\":{\"m\":123,\"g\":127},\"12592\":{\"m\":123,\"g\":127},\"11456\":{\"m\":123,\"g\":127},\"12615\":{\"m\":123,\"g\":127},\"12599\":{\"m\":123,\"g\":127},\"12522\":{\"m\":123,\"g\":127},\"10183\":{\"m\":123,\"g\":127},\"12598\":{\"m\":123,\"g\":127},\"6318\":{\"m\":123,\"g\":127},\"11131\":{\"m\":123,\"g\":127},\"11974\":{\"m\":123,\"g\":127},\"12580\":{\"m\":123,\"g\":127},\"12597\":{\"m\":123,\"g\":127},\"11760\":{\"m\":123,\"g\":127},\"12462\":{\"m\":123,\"g\":127},\"12165\":{\"m\":123,\"g\":127},\"12111\":{\"m\":123,\"g\":127},\"12547\":{\"m\":123,\"g\":127},\"12270\":{\"m\":123,\"g\":127},\"12044\":{\"m\":123,\"g\":127},\"12519\":{\"m\":123,\"g\":127},\"12571\":{\"m\":123,\"g\":127},\"12301\":{\"m\":123,\"g\":127},\"12569\":{\"m\":123,\"g\":127},\"12550\":{\"m\":123,\"g\":127},\"12549\":{\"m\":123,\"g\":127},\"12553\":{\"m\":123,\"g\":127},\"12524\":{\"m\":123,\"g\":127},\"12560\":{\"m\":123,\"g\":127},\"12227\":{\"m\":123,\"g\":127},\"12548\":{\"m\":123,\"g\":127},\"11330\":{\"m\":123,\"g\":127},\"12564\":{\"m\":123,\"g\":127},\"12561\":{\"m\":123,\"g\":127},\"12060\":{\"m\":123,\"g\":127},\"12536\":{\"m\":123,\"g\":127},\"12541\":{\"m\":123,\"g\":127},\"12532\":{\"m\":123,\"g\":127},\"12530\":{\"m\":123,\"g\":127},\"12523\":{\"m\":123,\"g\":127},\"12367\":{\"m\":123,\"g\":127},\"12502\":{\"m\":123,\"g\":127},\"12521\":{\"m\":123,\"g\":127},\"12515\":{\"m\":123,\"g\":127}
,\"12453\":{\"m\":123,\"g\":127},\"12481\":{\"m\":123,\"g\":127},\"12506\":{\"m\":123,\"g\":127},\"11917\":{\"m\":123,\"g\":127},\"12511\":{\"m\":123,\"g\":127},\"10078\":{\"m\":123,\"g\":127},\"12505\":{\"m\":123,\"g\":127},\"12507\":{\"m\":123,\"g\":127},\"11133\":{\"m\":123,\"g\":127},\"11052\":{\"m\":123,\"g\":127},\"12499\":{\"m\":123,\"g\":127},\"12391\":{\"m\":123,\"g\":127},\"12488\":{\"m\":123,\"g\":127},\"12412\":{\"m\":123,\"g\":127},\"11966\":{\"m\":123,\"g\":127},\"12238\":{\"m\":123,\"g\":127},\"12423\":{\"m\":123,\"g\":127},\"12500\":{\"m\":123,\"g\":127},\"12480\":{\"m\":123,\"g\":127},\"12485\":{\"m\":123,\"g\":127},\"12435\":{\"m\":123,\"g\":127},\"12483\":{\"m\":123,\"g\":127},\"12482\":{\"m\":123,\"g\":127},\"12226\":{\"m\":123,\"g\":127},\"12334\":{\"m\":123,\"g\":127},\"12739\":{\"m\":124,\"g\":127},\"12778\":{\"m\":124,\"g\":127},\"12440\":{\"m\":124,\"g\":127},\"12760\":{\"m\":124,\"g\":127},\"12565\":{\"m\":124,\"g\":127},\"12674\":{\"m\":124,\"g\":127},\"12646\":{\"m\":124,\"g\":127},\"12240\":{\"m\":124,\"g\":127},\"12508\":{\"m\":124,\"g\":127},\"12737\":{\"m\":124,\"g\":127},\"12693\":{\"m\":124,\"g\":127},\"12752\":{\"m\":124,\"g\":127},\"12744\":{\"m\":124,\"g\":127},\"12736\":{\"m\":124,\"g\":127},\"12748\":{\"m\":124,\"g\":127},\"12741\":{\"m\":124,\"g\":127},\"12742\":{\"m\":124,\"g\":127},\"12738\":{\"m\":124,\"g\":127},\"11892\":{\"m\":124,\"g\":127},\"12721\":{\"m\":124,\"g\":127},\"12734\":{\"m\":124,\"g\":127},\"12723\":{\"m\":124,\"g\":127},\"12716\":{\"m\":124,\"g\":127},\"12732\":{\"m\":124,\"g\":127},\"12611\":{\"m\":124,\"g\":127},\"12728\":{\"m\":124,\"g\":127},\"12729\":{\"m\":124,\"g\":127},\"12718\":{\"m\":124,\"g\":127},\"12651\":{\"m\":124,\"g\":127},\"12711\":{\"m\":124,\"g\":127},\"12713\":{\"m\":124,\"g\":127},\"12714\":{\"m\":124,\"g\":127},\"12712\":{\"m\":124,\"g\":127},\"12658\":{\"m\":124,\"g\":127},\"12673\":{\"m\":124,\"g\":127},\"12710\":{\"m\":124,\"g\":127},\"12709\":{\"m\":124,\"g\":127},\"12406\":{\"m\
":124,\"g\":127},\"12586\":{\"m\":124,\"g\":127},\"12631\":{\"m\":124,\"g\":127},\"12699\":{\"m\":124,\"g\":127},\"12484\":{\"m\":124,\"g\":127},\"12677\":{\"m\":124,\"g\":127},\"12696\":{\"m\":124,\"g\":127},\"12708\":{\"m\":124,\"g\":127},\"12609\":{\"m\":124,\"g\":127},\"12702\":{\"m\":124,\"g\":127},\"12455\":{\"m\":124,\"g\":127},\"12691\":{\"m\":124,\"g\":127},\"12687\":{\"m\":124,\"g\":127},\"11641\":{\"m\":124,\"g\":127},\"12680\":{\"m\":124,\"g\":127},\"10044\":{\"m\":124,\"g\":127},\"12668\":{\"m\":124,\"g\":127},\"12175\":{\"m\":124,\"g\":127},\"12670\":{\"m\":124,\"g\":127},\"8784\":{\"m\":124,\"g\":127},\"12486\":{\"m\":124,\"g\":127},\"12638\":{\"m\":124,\"g\":127},\"12353\":{\"m\":124,\"g\":127},\"13000\":{\"m\":125,\"g\":127},\"12908\":{\"m\":125,\"g\":127},\"12952\":{\"m\":125,\"g\":127},\"13010\":{\"m\":125,\"g\":127},\"11850\":{\"m\":125,\"g\":127},\"12781\":{\"m\":125,\"g\":127},\"12224\":{\"m\":125,\"g\":127},\"13013\":{\"m\":125,\"g\":127},\"13009\":{\"m\":125,\"g\":127},\"13001\":{\"m\":125,\"g\":127},\"13005\":{\"m\":125,\"g\":127},\"12996\":{\"m\":125,\"g\":127},\"12999\":{\"m\":125,\"g\":127},\"12966\":{\"m\":125,\"g\":127},\"12984\":{\"m\":125,\"g\":127},\"12982\":{\"m\":125,\"g\":127},\"10225\":{\"m\":125,\"g\":127},\"10702\":{\"m\":125,\"g\":127},\"11719\":{\"m\":125,\"g\":127},\"12916\":{\"m\":125,\"g\":127},\"12239\":{\"m\":125,\"g\":127},\"12883\":{\"m\":125,\"g\":127},\"12912\":{\"m\":125,\"g\":127},\"12604\":{\"m\":125,\"g\":127},\"12934\":{\"m\":125,\"g\":127},\"9528\":{\"m\":125,\"g\":127},\"12931\":{\"m\":125,\"g\":127},\"12959\":{\"m\":125,\"g\":127},\"12926\":{\"m\":125,\"g\":127},\"12803\":{\"m\":125,\"g\":127},\"12957\":{\"m\":125,\"g\":127},\"12943\":{\"m\":125,\"g\":127},\"12834\":{\"m\":125,\"g\":127},\"12956\":{\"m\":125,\"g\":127},\"12554\":{\"m\":125,\"g\":127},\"12946\":{\"m\":125,\"g\":127},\"12948\":{\"m\":125,\"g\":127},\"12940\":{\"m\":125,\"g\":127},\"11812\":{\"m\":125,\"g\":127},\"12928\":{\"m\":125,\"g\":127},\
"12927\":{\"m\":125,\"g\":127},\"12839\":{\"m\":125,\"g\":127},\"10775\":{\"m\":125,\"g\":127},\"12917\":{\"m\":125,\"g\":127},\"12920\":{\"m\":125,\"g\":127},\"12332\":{\"m\":125,\"g\":127},\"12919\":{\"m\":125,\"g\":127},\"12448\":{\"m\":125,\"g\":127},\"12906\":{\"m\":125,\"g\":127},\"12907\":{\"m\":125,\"g\":127},\"12911\":{\"m\":125,\"g\":127},\"12895\":{\"m\":125,\"g\":127},\"12905\":{\"m\":125,\"g\":127},\"12904\":{\"m\":125,\"g\":127},\"12865\":{\"m\":125,\"g\":127},\"12900\":{\"m\":125,\"g\":127},\"12896\":{\"m\":125,\"g\":127},\"12891\":{\"m\":125,\"g\":127},\"12897\":{\"m\":125,\"g\":127},\"12889\":{\"m\":125,\"g\":127},\"12870\":{\"m\":125,\"g\":127},\"12361\":{\"m\":125,\"g\":127},\"12811\":{\"m\":125,\"g\":127},\"12888\":{\"m\":125,\"g\":127},\"12832\":{\"m\":125,\"g\":127},\"12843\":{\"m\":125,\"g\":127},\"12886\":{\"m\":125,\"g\":127},\"12798\":{\"m\":125,\"g\":127},\"12868\":{\"m\":125,\"g\":127},\"12853\":{\"m\":125,\"g\":127},\"12846\":{\"m\":125,\"g\":127},\"12582\":{\"m\":125,\"g\":127},\"12805\":{\"m\":125,\"g\":127},\"12849\":{\"m\":125,\"g\":127},\"12859\":{\"m\":125,\"g\":127},\"12852\":{\"m\":125,\"g\":127},\"12851\":{\"m\":125,\"g\":127},\"12856\":{\"m\":125,\"g\":127},\"12801\":{\"m\":125,\"g\":127},\"12822\":{\"m\":125,\"g\":127},\"12431\":{\"m\":125,\"g\":127},\"12825\":{\"m\":125,\"g\":127},\"12836\":{\"m\":125,\"g\":127},\"12374\":{\"m\":125,\"g\":127},\"12812\":{\"m\":125,\"g\":127},\"12816\":{\"m\":125,\"g\":127},\"12090\":{\"m\":125,\"g\":127},\"12758\":{\"m\":125,\"g\":127},\"12520\":{\"m\":125,\"g\":127},\"12765\":{\"m\":125,\"g\":127},\"12761\":{\"m\":125,\"g\":127},\"12763\":{\"m\":125,\"g\":127},\"12776\":{\"m\":125,\"g\":127},\"12772\":{\"m\":125,\"g\":127},\"12788\":{\"m\":125,\"g\":127},\"12576\":{\"m\":125,\"g\":127},\"12782\":{\"m\":125,\"g\":127},\"12794\":{\"m\":125,\"g\":127},\"12715\":{\"m\":125,\"g\":127},\"12795\":{\"m\":125,\"g\":127},\"12724\":{\"m\":125,\"g\":127},\"12279\":{\"m\":125,\"g\":127},\"12717\":{\"m\":
125,\"g\":127},\"8243\":{\"m\":125,\"g\":127},\"11904\":{\"m\":125,\"g\":127},\"12684\":{\"m\":125,\"g\":127},\"12764\":{\"m\":125,\"g\":127},\"12363\":{\"m\":125,\"g\":127},\"11051\":{\"m\":125,\"g\":127},\"12749\":{\"m\":125,\"g\":127},\"13129\":{\"m\":126,\"g\":127},\"10808\":{\"m\":126,\"g\":127},\"12617\":{\"m\":126,\"g\":127},\"7906\":{\"m\":126,\"g\":127},\"13149\":{\"m\":126,\"g\":127},\"7886\":{\"m\":126,\"g\":127},\"9790\":{\"m\":126,\"g\":127},\"13120\":{\"m\":126,\"g\":127},\"12458\":{\"m\":126,\"g\":127},\"11961\":{\"m\":126,\"g\":127},\"13137\":{\"m\":126,\"g\":127},\"12860\":{\"m\":126,\"g\":127},\"13136\":{\"m\":126,\"g\":127},\"13135\":{\"m\":126,\"g\":127},\"13077\":{\"m\":126,\"g\":127},\"13132\":{\"m\":126,\"g\":127},\"13131\":{\"m\":126,\"g\":127},\"13133\":{\"m\":126,\"g\":127},\"12942\":{\"m\":126,\"g\":127},\"12817\":{\"m\":126,\"g\":127},\"13095\":{\"m\":126,\"g\":127},\"12666\":{\"m\":126,\"g\":127},\"12396\":{\"m\":126,\"g\":127},\"13118\":{\"m\":126,\"g\":127},\"12872\":{\"m\":126,\"g\":127},\"12997\":{\"m\":126,\"g\":127},\"12863\":{\"m\":126,\"g\":127},\"13114\":{\"m\":126,\"g\":127},\"13039\":{\"m\":126,\"g\":127},\"11856\":{\"m\":126,\"g\":127},\"13105\":{\"m\":126,\"g\":127},\"13056\":{\"m\":126,\"g\":127},\"12583\":{\"m\":126,\"g\":127},\"13093\":{\"m\":126,\"g\":127},\"12660\":{\"m\":126,\"g\":127},\"13090\":{\"m\":126,\"g\":127},\"12866\":{\"m\":126,\"g\":127},\"12915\":{\"m\":126,\"g\":127},\"13018\":{\"m\":126,\"g\":127},\"13092\":{\"m\":126,\"g\":127},\"11645\":{\"m\":126,\"g\":127},\"13041\":{\"m\":126,\"g\":127},\"10862\":{\"m\":126,\"g\":127},\"13088\":{\"m\":126,\"g\":127},\"12814\":{\"m\":126,\"g\":127},\"12941\":{\"m\":126,\"g\":127},\"12994\":{\"m\":126,\"g\":127},\"13076\":{\"m\":126,\"g\":127},\"13015\":{\"m\":126,\"g\":127},\"13063\":{\"m\":126,\"g\":127},\"12199\":{\"m\":126,\"g\":127},\"12983\":{\"m\":126,\"g\":127},\"13037\":{\"m\":126,\"g\":127},\"12689\":{\"m\":126,\"g\":127},\"11609\":{\"m\":126,\"g\":127},\"129
76\":{\"m\":126,\"g\":127},\"13050\":{\"m\":126,\"g\":127},\"12980\":{\"m\":126,\"g\":127},\"11938\":{\"m\":126,\"g\":127},\"13057\":{\"m\":126,\"g\":127},\"13053\":{\"m\":126,\"g\":127},\"13036\":{\"m\":126,\"g\":127},\"13043\":{\"m\":126,\"g\":127},\"13029\":{\"m\":126,\"g\":127},\"12885\":{\"m\":126,\"g\":127},\"12518\":{\"m\":126,\"g\":127},\"13035\":{\"m\":126,\"g\":127},\"13027\":{\"m\":126,\"g\":127},\"13028\":{\"m\":126,\"g\":127},\"12218\":{\"m\":126,\"g\":127},\"12753\":{\"m\":126,\"g\":127},\"12869\":{\"m\":126,\"g\":127},\"12993\":{\"m\":126,\"g\":127},\"13012\":{\"m\":126,\"g\":127},\"13366\":{\"m\":128,\"g\":131},\"13389\":{\"m\":128,\"g\":131},\"13387\":{\"m\":128,\"g\":131},\"12903\":{\"m\":128,\"g\":131},\"13388\":{\"m\":128,\"g\":131},\"12874\":{\"m\":128,\"g\":131},\"13386\":{\"m\":128,\"g\":131},\"13385\":{\"m\":128,\"g\":131},\"13384\":{\"m\":128,\"g\":131},\"13381\":{\"m\":128,\"g\":131},\"13335\":{\"m\":128,\"g\":131},\"13228\":{\"m\":128,\"g\":131},\"13371\":{\"m\":128,\"g\":131},\"13339\":{\"m\":128,\"g\":131},\"12978\":{\"m\":128,\"g\":131},\"13332\":{\"m\":128,\"g\":131},\"13263\":{\"m\":128,\"g\":131},\"13375\":{\"m\":128,\"g\":131},\"13344\":{\"m\":128,\"g\":131},\"13373\":{\"m\":128,\"g\":131},\"13372\":{\"m\":128,\"g\":131},\"13336\":{\"m\":128,\"g\":131},\"13369\":{\"m\":128,\"g\":131},\"13348\":{\"m\":128,\"g\":131},\"13199\":{\"m\":128,\"g\":131},\"13179\":{\"m\":128,\"g\":131},\"12310\":{\"m\":128,\"g\":131},\"11870\":{\"m\":128,\"g\":131},\"13358\":{\"m\":128,\"g\":131},\"13101\":{\"m\":128,\"g\":131},\"13321\":{\"m\":128,\"g\":131},\"13355\":{\"m\":128,\"g\":131},\"13351\":{\"m\":128,\"g\":131},\"13181\":{\"m\":128,\"g\":131},\"13341\":{\"m\":128,\"g\":131},\"13325\":{\"m\":128,\"g\":131},\"12329\":{\"m\":128,\"g\":131},\"13337\":{\"m\":128,\"g\":131},\"12001\":{\"m\":128,\"g\":131},\"12692\":{\"m\":128,\"g\":131},\"12443\":{\"m\":128,\"g\":131},\"13331\":{\"m\":128,\"g\":131},\"13330\":{\"m\":128,\"g\":131},\"13329\":{\"m\":128,
\"g\":131},\"10568\":{\"m\":128,\"g\":131},\"13306\":{\"m\":128,\"g\":131},\"13297\":{\"m\":128,\"g\":131},\"13326\":{\"m\":128,\"g\":131},\"13323\":{\"m\":128,\"g\":131},\"13287\":{\"m\":128,\"g\":131},\"13322\":{\"m\":128,\"g\":131},\"7415\":{\"m\":128,\"g\":131},\"13286\":{\"m\":128,\"g\":131},\"13285\":{\"m\":128,\"g\":131},\"13259\":{\"m\":128,\"g\":131},\"13320\":{\"m\":128,\"g\":131},\"13295\":{\"m\":128,\"g\":131},\"13318\":{\"m\":128,\"g\":131},\"13314\":{\"m\":128,\"g\":131},\"13226\":{\"m\":128,\"g\":131},\"13278\":{\"m\":128,\"g\":131},\"13279\":{\"m\":128,\"g\":131},\"12612\":{\"m\":128,\"g\":131},\"12871\":{\"m\":128,\"g\":131},\"13317\":{\"m\":128,\"g\":131},\"13294\":{\"m\":128,\"g\":131},\"13316\":{\"m\":128,\"g\":131},\"13315\":{\"m\":128,\"g\":131},\"13312\":{\"m\":128,\"g\":127},\"13091\":{\"m\":128,\"g\":127},\"13100\":{\"m\":128,\"g\":127},\"13311\":{\"m\":128,\"g\":127},\"13310\":{\"m\":128,\"g\":127},\"13045\":{\"m\":128,\"g\":127},\"13293\":{\"m\":128,\"g\":127},\"13305\":{\"m\":128,\"g\":127},\"13170\":{\"m\":128,\"g\":127},\"13298\":{\"m\":128,\"g\":127},\"13274\":{\"m\":128,\"g\":127},\"13235\":{\"m\":128,\"g\":127},\"13272\":{\"m\":128,\"g\":127},\"13260\":{\"m\":128,\"g\":127},\"10573\":{\"m\":128,\"g\":127},\"10665\":{\"m\":128,\"g\":127},\"13236\":{\"m\":128,\"g\":127},\"12777\":{\"m\":128,\"g\":127},\"13277\":{\"m\":128,\"g\":127},\"13288\":{\"m\":128,\"g\":127},\"13254\":{\"m\":128,\"g\":127},\"13284\":{\"m\":128,\"g\":127},\"13283\":{\"m\":128,\"g\":127},\"13242\":{\"m\":128,\"g\":127},\"12191\":{\"m\":128,\"g\":127},\"12605\":{\"m\":128,\"g\":127},\"12623\":{\"m\":128,\"g\":127},\"12622\":{\"m\":128,\"g\":127},\"12620\":{\"m\":128,\"g\":127},\"13113\":{\"m\":128,\"g\":127},\"13247\":{\"m\":128,\"g\":127},\"13237\":{\"m\":128,\"g\":127},\"13265\":{\"m\":128,\"g\":127},\"11589\":{\"m\":128,\"g\":127},\"13261\":{\"m\":128,\"g\":127},\"13256\":{\"m\":128,\"g\":127},\"13257\":{\"m\":128,\"g\":127},\"13255\":{\"m\":128,\"g\":127},\"1309
6\":{\"m\":128,\"g\":127},\"13221\":{\"m\":128,\"g\":127},\"13213\":{\"m\":128,\"g\":127},\"13246\":{\"m\":128,\"g\":127},\"13243\":{\"m\":128,\"g\":127},\"13186\":{\"m\":128,\"g\":127},\"13239\":{\"m\":128,\"g\":127},\"13097\":{\"m\":128,\"g\":127},\"12392\":{\"m\":128,\"g\":127},\"13222\":{\"m\":128,\"g\":127},\"13188\":{\"m\":128,\"g\":127},\"13218\":{\"m\":128,\"g\":127},\"13087\":{\"m\":128,\"g\":127},\"13142\":{\"m\":128,\"g\":127},\"11595\":{\"m\":128,\"g\":127},\"13210\":{\"m\":128,\"g\":127},\"12774\":{\"m\":128,\"g\":127},\"13211\":{\"m\":128,\"g\":127},\"13220\":{\"m\":128,\"g\":127},\"13171\":{\"m\":128,\"g\":127},\"13215\":{\"m\":128,\"g\":127},\"13212\":{\"m\":128,\"g\":127},\"10485\":{\"m\":128,\"g\":127},\"12543\":{\"m\":128,\"g\":127},\"12201\":{\"m\":128,\"g\":127},\"12376\":{\"m\":128,\"g\":127},\"13155\":{\"m\":128,\"g\":127},\"13148\":{\"m\":128,\"g\":127},\"13190\":{\"m\":128,\"g\":127},\"13178\":{\"m\":128,\"g\":127},\"13102\":{\"m\":128,\"g\":127},\"13154\":{\"m\":128,\"g\":127},\"12975\":{\"m\":128,\"g\":127},\"13163\":{\"m\":128,\"g\":127},\"13172\":{\"m\":128,\"g\":127},\"13150\":{\"m\":128,\"g\":127},\"10973\":{\"m\":128,\"g\":127},\"12288\":{\"m\":128,\"g\":127},\"13162\":{\"m\":128,\"g\":127},\"12215\":{\"m\":128,\"g\":127},\"13104\":{\"m\":128,\"g\":127},\"13164\":{\"m\":128,\"g\":127},\"13127\":{\"m\":128,\"g\":127},\"12979\":{\"m\":128,\"g\":127},\"13128\":{\"m\":128,\"g\":127},\"12998\":{\"m\":128,\"g\":127},\"13075\":{\"m\":128,\"g\":127},\"10907\":{\"m\":128,\"g\":127},\"13153\":{\"m\":128,\"g\":127},\"12214\":{\"m\":128,\"g\":127},\"14316\":{\"m\":129,\"g\":131},\"14324\":{\"m\":129,\"g\":131},\"14323\":{\"m\":129,\"g\":131},\"14317\":{\"m\":129,\"g\":131},\"14309\":{\"m\":129,\"g\":131},\"14319\":{\"m\":129,\"g\":131},\"14262\":{\"m\":129,\"g\":131},\"14315\":{\"m\":129,\"g\":131},\"14249\":{\"m\":129,\"g\":131},\"11423\":{\"m\":129,\"g\":131},\"14278\":{\"m\":129,\"g\":131},\"14299\":{\"m\":129,\"g\":131},\"13089\":{\"m\":129,\
"g\":131},\"14133\":{\"m\":129,\"g\":131},\"14269\":{\"m\":129,\"g\":131},\"14281\":{\"m\":129,\"g\":131},\"14287\":{\"m\":129,\"g\":131},\"14286\":{\"m\":129,\"g\":131},\"14283\":{\"m\":129,\"g\":131},\"14252\":{\"m\":129,\"g\":131},\"14244\":{\"m\":129,\"g\":131},\"14279\":{\"m\":129,\"g\":131},\"14276\":{\"m\":129,\"g\":131},\"13738\":{\"m\":129,\"g\":131},\"14047\":{\"m\":129,\"g\":131},\"14274\":{\"m\":129,\"g\":131},\"14257\":{\"m\":129,\"g\":131},\"14254\":{\"m\":129,\"g\":131},\"13700\":{\"m\":129,\"g\":131},\"14267\":{\"m\":129,\"g\":131},\"14261\":{\"m\":129,\"g\":131},\"14172\":{\"m\":129,\"g\":131},\"14263\":{\"m\":129,\"g\":131},\"14259\":{\"m\":129,\"g\":131},\"14260\":{\"m\":129,\"g\":131},\"14222\":{\"m\":129,\"g\":131},\"14232\":{\"m\":129,\"g\":131},\"13968\":{\"m\":129,\"g\":131},\"14256\":{\"m\":129,\"g\":131},\"14255\":{\"m\":129,\"g\":131},\"13880\":{\"m\":129,\"g\":131},\"13794\":{\"m\":129,\"g\":131},\"14247\":{\"m\":129,\"g\":131},\"13843\":{\"m\":129,\"g\":131},\"14250\":{\"m\":129,\"g\":131},\"14245\":{\"m\":129,\"g\":131},\"14243\":{\"m\":129,\"g\":131},\"14241\":{\"m\":129,\"g\":131},\"14240\":{\"m\":129,\"g\":131},\"14237\":{\"m\":129,\"g\":131},\"14152\":{\"m\":129,\"g\":131},\"14179\":{\"m\":129,\"g\":131},\"14229\":{\"m\":129,\"g\":131},\"14230\":{\"m\":129,\"g\":131},\"13693\":{\"m\":129,\"g\":131},\"14228\":{\"m\":129,\"g\":131},\"14122\":{\"m\":129,\"g\":131},\"14088\":{\"m\":129,\"g\":131},\"13887\":{\"m\":129,\"g\":131},\"14219\":{\"m\":129,\"g\":131},\"14218\":{\"m\":129,\"g\":131},\"14214\":{\"m\":129,\"g\":131},\"14212\":{\"m\":129,\"g\":131},\"14211\":{\"m\":129,\"g\":131},\"14173\":{\"m\":129,\"g\":131},\"14165\":{\"m\":129,\"g\":131},\"14123\":{\"m\":129,\"g\":131},\"14186\":{\"m\":129,\"g\":131},\"14180\":{\"m\":129,\"g\":131},\"14167\":{\"m\":129,\"g\":131},\"14044\":{\"m\":129,\"g\":131},\"14182\":{\"m\":129,\"g\":131},\"14183\":{\"m\":129,\"g\":131},\"14181\":{\"m\":129,\"g\":131},\"14003\":{\"m\":129,\"g\":131},\"1403
4\":{\"m\":129,\"g\":131},\"14153\":{\"m\":129,\"g\":131},\"14187\":{\"m\":129,\"g\":131},\"12181\":{\"m\":129,\"g\":131},\"14155\":{\"m\":129,\"g\":131},\"14104\":{\"m\":129,\"g\":131},\"13873\":{\"m\":129,\"g\":131},\"14059\":{\"m\":129,\"g\":131},\"14140\":{\"m\":129,\"g\":131},\"13646\":{\"m\":129,\"g\":131},\"13841\":{\"m\":129,\"g\":131},\"14171\":{\"m\":129,\"g\":131},\"13907\":{\"m\":129,\"g\":131},\"14065\":{\"m\":129,\"g\":131},\"14166\":{\"m\":129,\"g\":131},\"14005\":{\"m\":129,\"g\":131},\"12494\":{\"m\":129,\"g\":131},\"14163\":{\"m\":129,\"g\":131},\"14156\":{\"m\":129,\"g\":131},\"14148\":{\"m\":129,\"g\":131},\"14161\":{\"m\":129,\"g\":131},\"14052\":{\"m\":129,\"g\":131},\"14150\":{\"m\":129,\"g\":131},\"14157\":{\"m\":129,\"g\":131},\"14154\":{\"m\":129,\"g\":131},\"14147\":{\"m\":129,\"g\":131},\"14146\":{\"m\":129,\"g\":131},\"14145\":{\"m\":129,\"g\":131},\"14136\":{\"m\":129,\"g\":131},\"13956\":{\"m\":129,\"g\":131},\"13759\":{\"m\":129,\"g\":131},\"14151\":{\"m\":129,\"g\":131},\"14135\":{\"m\":129,\"g\":131},\"14130\":{\"m\":129,\"g\":131},\"14119\":{\"m\":129,\"g\":131},\"14131\":{\"m\":129,\"g\":131},\"14129\":{\"m\":129,\"g\":131},\"14121\":{\"m\":129,\"g\":131},\"12306\":{\"m\":129,\"g\":131},\"14124\":{\"m\":129,\"g\":131},\"14113\":{\"m\":129,\"g\":131},\"14117\":{\"m\":129,\"g\":131},\"13377\":{\"m\":129,\"g\":131},\"13488\":{\"m\":129,\"g\":131},\"14111\":{\"m\":129,\"g\":131},\"12558\":{\"m\":129,\"g\":131},\"10712\":{\"m\":129,\"g\":131},\"14106\":{\"m\":129,\"g\":131},\"14067\":{\"m\":129,\"g\":131},\"14096\":{\"m\":129,\"g\":131},\"14094\":{\"m\":129,\"g\":131},\"13724\":{\"m\":129,\"g\":131},\"13904\":{\"m\":129,\"g\":131},\"14076\":{\"m\":129,\"g\":131},\"13936\":{\"m\":129,\"g\":131},\"14006\":{\"m\":129,\"g\":131},\"14082\":{\"m\":129,\"g\":131},\"13205\":{\"m\":129,\"g\":131},\"14079\":{\"m\":129,\"g\":131},\"14036\":{\"m\":129,\"g\":131},\"13944\":{\"m\":129,\"g\":131},\"13749\":{\"m\":129,\"g\":131},\"14069\":{\"m\":129,\
"g\":131},\"13946\":{\"m\":129,\"g\":131},\"14048\":{\"m\":129,\"g\":131},\"13960\":{\"m\":129,\"g\":131},\"14057\":{\"m\":129,\"g\":131},\"13425\":{\"m\":129,\"g\":131},\"13854\":{\"m\":129,\"g\":131},\"14002\":{\"m\":129,\"g\":131},\"13855\":{\"m\":129,\"g\":131},\"14040\":{\"m\":129,\"g\":131},\"13895\":{\"m\":129,\"g\":131},\"13976\":{\"m\":129,\"g\":131},\"13814\":{\"m\":129,\"g\":131},\"13965\":{\"m\":129,\"g\":131},\"14026\":{\"m\":129,\"g\":131},\"14017\":{\"m\":129,\"g\":131},\"14033\":{\"m\":129,\"g\":131},\"14030\":{\"m\":129,\"g\":131},\"14028\":{\"m\":129,\"g\":131},\"13761\":{\"m\":129,\"g\":131},\"14027\":{\"m\":129,\"g\":131},\"14025\":{\"m\":129,\"g\":131},\"13824\":{\"m\":129,\"g\":131},\"12277\":{\"m\":129,\"g\":131},\"13941\":{\"m\":129,\"g\":131},\"13983\":{\"m\":129,\"g\":131},\"14018\":{\"m\":129,\"g\":131},\"14007\":{\"m\":129,\"g\":131},\"14022\":{\"m\":129,\"g\":131},\"14019\":{\"m\":129,\"g\":131},\"14021\":{\"m\":129,\"g\":131},\"13937\":{\"m\":129,\"g\":131},\"13990\":{\"m\":129,\"g\":131},\"14020\":{\"m\":129,\"g\":131},\"13966\":{\"m\":129,\"g\":131},\"13872\":{\"m\":129,\"g\":131},\"14016\":{\"m\":129,\"g\":131},\"14013\":{\"m\":129,\"g\":131},\"14015\":{\"m\":129,\"g\":131},\"14014\":{\"m\":129,\"g\":131},\"14012\":{\"m\":129,\"g\":131},\"13151\":{\"m\":129,\"g\":131},\"13892\":{\"m\":129,\"g\":131},\"14000\":{\"m\":129,\"g\":131},\"14009\":{\"m\":129,\"g\":131},\"13766\":{\"m\":129,\"g\":131},\"13754\":{\"m\":129,\"g\":131},\"12491\":{\"m\":129,\"g\":131},\"13994\":{\"m\":129,\"g\":131},\"13991\":{\"m\":129,\"g\":131},\"13977\":{\"m\":129,\"g\":131},\"13203\":{\"m\":129,\"g\":131},\"12588\":{\"m\":129,\"g\":131},\"13922\":{\"m\":129,\"g\":131},\"13852\":{\"m\":129,\"g\":131},\"13963\":{\"m\":129,\"g\":131},\"10071\":{\"m\":129,\"g\":131},\"13961\":{\"m\":129,\"g\":131},\"13962\":{\"m\":129,\"g\":131},\"13958\":{\"m\":129,\"g\":131},\"7725\":{\"m\":129,\"g\":131},\"13925\":{\"m\":129,\"g\":131},\"12786\":{\"m\":129,\"g\":131},\"13954
\":{\"m\":129,\"g\":131},\"13951\":{\"m\":129,\"g\":131},\"13950\":{\"m\":129,\"g\":131},\"13945\":{\"m\":129,\"g\":131},\"13942\":{\"m\":129,\"g\":131},\"12969\":{\"m\":129,\"g\":131},\"13910\":{\"m\":129,\"g\":131},\"13866\":{\"m\":129,\"g\":131},\"13903\":{\"m\":129,\"g\":131},\"13421\":{\"m\":129,\"g\":131},\"13851\":{\"m\":129,\"g\":131},\"13544\":{\"m\":129,\"g\":131},\"13938\":{\"m\":129,\"g\":131},\"13935\":{\"m\":129,\"g\":131},\"13933\":{\"m\":129,\"g\":131},\"13859\":{\"m\":129,\"g\":131},\"13931\":{\"m\":129,\"g\":131},\"13928\":{\"m\":129,\"g\":131},\"13927\":{\"m\":129,\"g\":131},\"13657\":{\"m\":129,\"g\":131},\"13081\":{\"m\":129,\"g\":131},\"12078\":{\"m\":129,\"g\":131},\"13921\":{\"m\":129,\"g\":131},\"13905\":{\"m\":129,\"g\":131},\"13916\":{\"m\":129,\"g\":131},\"13642\":{\"m\":129,\"g\":131},\"13908\":{\"m\":129,\"g\":131},\"13848\":{\"m\":129,\"g\":131},\"13793\":{\"m\":129,\"g\":131},\"13901\":{\"m\":129,\"g\":131},\"13827\":{\"m\":129,\"g\":131},\"13889\":{\"m\":129,\"g\":131},\"13890\":{\"m\":129,\"g\":131},\"13888\":{\"m\":129,\"g\":131},\"11893\":{\"m\":129,\"g\":131},\"13891\":{\"m\":129,\"g\":131},\"13870\":{\"m\":129,\"g\":131},\"13874\":{\"m\":129,\"g\":131},\"13860\":{\"m\":129,\"g\":131},\"13871\":{\"m\":129,\"g\":131},\"13572\":{\"m\":129,\"g\":131},\"10275\":{\"m\":129,\"g\":131},\"13487\":{\"m\":129,\"g\":131},\"13786\":{\"m\":129,\"g\":131},\"13834\":{\"m\":129,\"g\":131},\"13822\":{\"m\":129,\"g\":131},\"13783\":{\"m\":129,\"g\":131},\"13763\":{\"m\":129,\"g\":131},\"13752\":{\"m\":129,\"g\":131},\"13864\":{\"m\":129,\"g\":131},\"13865\":{\"m\":129,\"g\":131},\"13853\":{\"m\":129,\"g\":131},\"10027\":{\"m\":129,\"g\":131},\"13745\":{\"m\":129,\"g\":131},\"13858\":{\"m\":129,\"g\":131},\"13612\":{\"m\":129,\"g\":131},\"11871\":{\"m\":129,\"g\":131},\"13713\":{\"m\":129,\"g\":131},\"13508\":{\"m\":129,\"g\":131},\"13846\":{\"m\":129,\"g\":131},\"13819\":{\"m\":129,\"g\":131},\"13245\":{\"m\":129,\"g\":131},\"13751\":{\"m\":129,\"
g\":131},\"13833\":{\"m\":129,\"g\":131},\"13831\":{\"m\":129,\"g\":131},\"13792\":{\"m\":129,\"g\":131},\"13829\":{\"m\":129,\"g\":131},\"13800\":{\"m\":129,\"g\":131},\"13656\":{\"m\":129,\"g\":131},\"13820\":{\"m\":129,\"g\":131},\"13201\":{\"m\":129,\"g\":131},\"13650\":{\"m\":129,\"g\":131},\"13816\":{\"m\":129,\"g\":131},\"13601\":{\"m\":129,\"g\":131},\"13810\":{\"m\":129,\"g\":131},\"13802\":{\"m\":129,\"g\":131},\"13815\":{\"m\":129,\"g\":131},\"13813\":{\"m\":129,\"g\":131},\"13687\":{\"m\":129,\"g\":131},\"13806\":{\"m\":129,\"g\":131},\"13781\":{\"m\":129,\"g\":131},\"13791\":{\"m\":129,\"g\":131},\"13180\":{\"m\":129,\"g\":131},\"13718\":{\"m\":129,\"g\":131},\"13787\":{\"m\":129,\"g\":131},\"13764\":{\"m\":129,\"g\":131},\"13709\":{\"m\":129,\"g\":131},\"13776\":{\"m\":129,\"g\":131},\"13777\":{\"m\":129,\"g\":131},\"13727\":{\"m\":129,\"g\":131},\"13676\":{\"m\":129,\"g\":131},\"13690\":{\"m\":129,\"g\":131},\"13547\":{\"m\":129,\"g\":131},\"13533\":{\"m\":129,\"g\":131},\"13769\":{\"m\":129,\"g\":131},\"13478\":{\"m\":129,\"g\":131},\"13736\":{\"m\":129,\"g\":131},\"13706\":{\"m\":129,\"g\":131},\"13768\":{\"m\":129,\"g\":131},\"13704\":{\"m\":129,\"g\":131},\"13720\":{\"m\":129,\"g\":131},\"13714\":{\"m\":129,\"g\":131},\"12759\":{\"m\":129,\"g\":131},\"9405\":{\"m\":129,\"g\":131},\"13506\":{\"m\":129,\"g\":131},\"13756\":{\"m\":129,\"g\":131},\"13669\":{\"m\":129,\"g\":131},\"13729\":{\"m\":129,\"g\":131},\"13702\":{\"m\":129,\"g\":131},\"13746\":{\"m\":129,\"g\":131},\"13484\":{\"m\":129,\"g\":131},\"12690\":{\"m\":129,\"g\":131},\"13701\":{\"m\":129,\"g\":131},\"12949\":{\"m\":129,\"g\":131},\"13707\":{\"m\":129,\"g\":131},\"13694\":{\"m\":129,\"g\":131},\"13466\":{\"m\":129,\"g\":131},\"13739\":{\"m\":129,\"g\":131},\"13737\":{\"m\":129,\"g\":131},\"13735\":{\"m\":129,\"g\":131},\"13734\":{\"m\":129,\"g\":131},\"13733\":{\"m\":129,\"g\":131},\"13407\":{\"m\":129,\"g\":131},\"13649\":{\"m\":129,\"g\":131},\"13705\":{\"m\":129,\"g\":131},\"13630\
":{\"m\":129,\"g\":131},\"13719\":{\"m\":129,\"g\":131},\"13327\":{\"m\":129,\"g\":131},\"13647\":{\"m\":129,\"g\":131},\"13708\":{\"m\":129,\"g\":131},\"13498\":{\"m\":129,\"g\":131},\"13590\":{\"m\":129,\"g\":131},\"13679\":{\"m\":129,\"g\":131},\"13564\":{\"m\":129,\"g\":131},\"13686\":{\"m\":129,\"g\":131},\"13177\":{\"m\":129,\"g\":131},\"13665\":{\"m\":129,\"g\":131},\"13675\":{\"m\":129,\"g\":131},\"13640\":{\"m\":129,\"g\":131},\"13587\":{\"m\":129,\"g\":131},\"13596\":{\"m\":129,\"g\":131},\"13555\":{\"m\":129,\"g\":131},\"13619\":{\"m\":129,\"g\":131},\"13301\":{\"m\":129,\"g\":131},\"12672\":{\"m\":129,\"g\":131},\"13683\":{\"m\":129,\"g\":131},\"13685\":{\"m\":129,\"g\":131},\"13627\":{\"m\":129,\"g\":131},\"13678\":{\"m\":129,\"g\":131},\"13677\":{\"m\":129,\"g\":131},\"13659\":{\"m\":129,\"g\":131},\"13667\":{\"m\":129,\"g\":131},\"13610\":{\"m\":129,\"g\":131},\"11526\":{\"m\":129,\"g\":131},\"13038\":{\"m\":129,\"g\":131},\"13600\":{\"m\":129,\"g\":131},\"13459\":{\"m\":129,\"g\":131},\"13666\":{\"m\":129,\"g\":131},\"12964\":{\"m\":129,\"g\":131},\"13655\":{\"m\":129,\"g\":131},\"13524\":{\"m\":129,\"g\":131},\"11577\":{\"m\":129,\"g\":131},\"13663\":{\"m\":129,\"g\":131},\"13637\":{\"m\":129,\"g\":131},\"13634\":{\"m\":129,\"g\":131},\"13644\":{\"m\":129,\"g\":131},\"13197\":{\"m\":129,\"g\":131},\"13617\":{\"m\":129,\"g\":131},\"13633\":{\"m\":129,\"g\":131},\"13453\":{\"m\":129,\"g\":131},\"13614\":{\"m\":129,\"g\":131},\"13554\":{\"m\":129,\"g\":131},\"13583\":{\"m\":129,\"g\":131},\"13328\":{\"m\":129,\"g\":131},\"13253\":{\"m\":129,\"g\":131},\"13248\":{\"m\":129,\"g\":131},\"12379\":{\"m\":129,\"g\":131},\"13528\":{\"m\":129,\"g\":131},\"13613\":{\"m\":129,\"g\":131},\"13429\":{\"m\":129,\"g\":131},\"13562\":{\"m\":129,\"g\":131},\"13055\":{\"m\":129,\"g\":131},\"13603\":{\"m\":129,\"g\":131},\"13604\":{\"m\":129,\"g\":131},\"13357\":{\"m\":129,\"g\":131},\"13570\":{\"m\":129,\"g\":131},\"13577\":{\"m\":129,\"g\":131},\"13465\":{\"m\":129,\"g
\":131},\"13049\":{\"m\":129,\"g\":131},\"13448\":{\"m\":129,\"g\":131},\"13589\":{\"m\":129,\"g\":131},\"13567\":{\"m\":129,\"g\":131},\"13568\":{\"m\":129,\"g\":131},\"13413\":{\"m\":129,\"g\":131},\"13558\":{\"m\":129,\"g\":131},\"13557\":{\"m\":129,\"g\":131},\"13542\":{\"m\":129,\"g\":131},\"13481\":{\"m\":129,\"g\":131},\"12740\":{\"m\":129,\"g\":131},\"13551\":{\"m\":129,\"g\":131},\"13452\":{\"m\":129,\"g\":131},\"9234\":{\"m\":129,\"g\":131},\"13047\":{\"m\":129,\"g\":131},\"13548\":{\"m\":129,\"g\":131},\"13543\":{\"m\":129,\"g\":131},\"13541\":{\"m\":129,\"g\":131},\"13489\":{\"m\":129,\"g\":131},\"13495\":{\"m\":129,\"g\":131},\"13540\":{\"m\":129,\"g\":131},\"13537\":{\"m\":129,\"g\":131},\"13534\":{\"m\":129,\"g\":131},\"13536\":{\"m\":129,\"g\":131},\"13532\":{\"m\":129,\"g\":131},\"13527\":{\"m\":129,\"g\":131},\"13474\":{\"m\":129,\"g\":131},\"13525\":{\"m\":129,\"g\":131},\"13522\":{\"m\":129,\"g\":131},\"13521\":{\"m\":129,\"g\":131},\"13519\":{\"m\":129,\"g\":131},\"13516\":{\"m\":129,\"g\":131},\"12962\":{\"m\":129,\"g\":131},\"13513\":{\"m\":129,\"g\":131},\"13512\":{\"m\":129,\"g\":131},\"13510\":{\"m\":129,\"g\":131},\"13509\":{\"m\":129,\"g\":131},\"13168\":{\"m\":129,\"g\":131},\"13157\":{\"m\":129,\"g\":131},\"13501\":{\"m\":129,\"g\":131},\"13126\":{\"m\":129,\"g\":131},\"13496\":{\"m\":129,\"g\":131},\"13374\":{\"m\":129,\"g\":131},\"13491\":{\"m\":129,\"g\":131},\"13482\":{\"m\":129,\"g\":131},\"13486\":{\"m\":129,\"g\":131},\"13393\":{\"m\":129,\"g\":131},\"13460\":{\"m\":129,\"g\":131},\"13094\":{\"m\":129,\"g\":131},\"13479\":{\"m\":129,\"g\":131},\"13476\":{\"m\":129,\"g\":131},\"12149\":{\"m\":129,\"g\":131},\"13289\":{\"m\":129,\"g\":131},\"13473\":{\"m\":129,\"g\":131},\"13258\":{\"m\":129,\"g\":131},\"13229\":{\"m\":129,\"g\":131},\"13462\":{\"m\":129,\"g\":131},\"13444\":{\"m\":129,\"g\":131},\"13458\":{\"m\":129,\"g\":131},\"13449\":{\"m\":129,\"g\":131},\"13455\":{\"m\":129,\"g\":131},\"13140\":{\"m\":129,\"g\":131},\"13463\"
:{\"m\":129,\"g\":131},\"13264\":{\"m\":129,\"g\":131},\"12359\":{\"m\":129,\"g\":131},\"13461\":{\"m\":129,\"g\":131},\"13457\":{\"m\":129,\"g\":131},\"13173\":{\"m\":129,\"g\":131},\"13456\":{\"m\":129,\"g\":131},\"13273\":{\"m\":129,\"g\":131},\"13022\":{\"m\":129,\"g\":131},\"13450\":{\"m\":129,\"g\":131},\"13447\":{\"m\":129,\"g\":131},\"13445\":{\"m\":129,\"g\":131},\"13443\":{\"m\":129,\"g\":131},\"13418\":{\"m\":129,\"g\":131},\"13217\":{\"m\":129,\"g\":131},\"5879\":{\"m\":129,\"g\":131},\"11900\":{\"m\":129,\"g\":131},\"13144\":{\"m\":129,\"g\":131},\"13379\":{\"m\":129,\"g\":131},\"13282\":{\"m\":129,\"g\":131},\"13420\":{\"m\":129,\"g\":131},\"13416\":{\"m\":129,\"g\":131},\"13415\":{\"m\":129,\"g\":131},\"13399\":{\"m\":129,\"g\":131},\"13004\":{\"m\":129,\"g\":131},\"13324\":{\"m\":129,\"g\":131},\"13112\":{\"m\":129,\"g\":131},\"11644\":{\"m\":129,\"g\":131},\"13398\":{\"m\":129,\"g\":131},\"13396\":{\"m\":129,\"g\":131},\"13345\":{\"m\":129,\"g\":131},\"13338\":{\"m\":129,\"g\":131},\"12065\":{\"m\":129,\"g\":131},\"13391\":{\"m\":129,\"g\":131},\"13383\":{\"m\":129,\"g\":131},\"14670\":{\"m\":130,\"g\":131},\"14650\":{\"m\":130,\"g\":131},\"14457\":{\"m\":130,\"g\":131},\"14657\":{\"m\":130,\"g\":131},\"14667\":{\"m\":130,\"g\":131},\"14634\":{\"m\":130,\"g\":131},\"14664\":{\"m\":130,\"g\":131},\"14663\":{\"m\":130,\"g\":131},\"14658\":{\"m\":130,\"g\":131},\"14497\":{\"m\":130,\"g\":131},\"14651\":{\"m\":130,\"g\":131},\"14649\":{\"m\":130,\"g\":131},\"14356\":{\"m\":130,\"g\":131},\"14629\":{\"m\":130,\"g\":131},\"14558\":{\"m\":130,\"g\":131},\"12527\":{\"m\":130,\"g\":131},\"14632\":{\"m\":130,\"g\":131},\"14606\":{\"m\":130,\"g\":131},\"12551\":{\"m\":130,\"g\":131},\"14625\":{\"m\":130,\"g\":131},\"14556\":{\"m\":130,\"g\":131},\"14585\":{\"m\":130,\"g\":131},\"14618\":{\"m\":130,\"g\":131},\"14203\":{\"m\":130,\"g\":131},\"14452\":{\"m\":130,\"g\":131},\"14612\":{\"m\":130,\"g\":131},\"14609\":{\"m\":130,\"g\":131},\"14604\":{\"m\":130,\"g\"
:131},\"14608\":{\"m\":130,\"g\":131},\"14605\":{\"m\":130,\"g\":131},\"14600\":{\"m\":130,\"g\":131},\"14591\":{\"m\":130,\"g\":131},\"14386\":{\"m\":130,\"g\":131},\"14573\":{\"m\":130,\"g\":131},\"14551\":{\"m\":130,\"g\":131},\"14141\":{\"m\":130,\"g\":131},\"14590\":{\"m\":130,\"g\":131},\"14588\":{\"m\":130,\"g\":131},\"14587\":{\"m\":130,\"g\":131},\"14586\":{\"m\":130,\"g\":131},\"14517\":{\"m\":130,\"g\":131},\"14455\":{\"m\":130,\"g\":131},\"14553\":{\"m\":130,\"g\":131},\"13573\":{\"m\":130,\"g\":131},\"14132\":{\"m\":130,\"g\":131},\"14185\":{\"m\":130,\"g\":131},\"14576\":{\"m\":130,\"g\":131},\"14577\":{\"m\":130,\"g\":131},\"13725\":{\"m\":130,\"g\":131},\"13998\":{\"m\":130,\"g\":131},\"14569\":{\"m\":130,\"g\":131},\"14412\":{\"m\":130,\"g\":131},\"14544\":{\"m\":130,\"g\":131},\"14561\":{\"m\":130,\"g\":131},\"14560\":{\"m\":130,\"g\":131},\"14559\":{\"m\":130,\"g\":131},\"14337\":{\"m\":130,\"g\":131},\"14555\":{\"m\":130,\"g\":131},\"14494\":{\"m\":130,\"g\":131},\"14476\":{\"m\":130,\"g\":131},\"14205\":{\"m\":130,\"g\":131},\"14557\":{\"m\":130,\"g\":131},\"14447\":{\"m\":130,\"g\":131},\"14552\":{\"m\":130,\"g\":131},\"14518\":{\"m\":130,\"g\":131},\"14538\":{\"m\":130,\"g\":131},\"14520\":{\"m\":130,\"g\":131},\"14535\":{\"m\":130,\"g\":131},\"14493\":{\"m\":130,\"g\":131},\"14464\":{\"m\":130,\"g\":131},\"14543\":{\"m\":130,\"g\":131},\"13897\":{\"m\":130,\"g\":131},\"14505\":{\"m\":130,\"g\":131},\"14539\":{\"m\":130,\"g\":131},\"13115\":{\"m\":130,\"g\":131},\"14533\":{\"m\":130,\"g\":131},\"12324\":{\"m\":130,\"g\":131},\"14290\":{\"m\":130,\"g\":131},\"14528\":{\"m\":130,\"g\":131},\"14530\":{\"m\":130,\"g\":131},\"11791\":{\"m\":130,\"g\":131},\"14522\":{\"m\":130,\"g\":131},\"14465\":{\"m\":130,\"g\":131},\"14521\":{\"m\":130,\"g\":131},\"14291\":{\"m\":130,\"g\":131},\"14427\":{\"m\":130,\"g\":131},\"14516\":{\"m\":130,\"g\":131},\"14514\":{\"m\":130,\"g\":131},\"14513\":{\"m\":130,\"g\":131},\"14512\":{\"m\":130,\"g\":131},\"14312\":
{\"m\":130,\"g\":131},\"14405\":{\"m\":130,\"g\":131},\"14420\":{\"m\":130,\"g\":131},\"14460\":{\"m\":130,\"g\":131},\"14508\":{\"m\":130,\"g\":131},\"13607\":{\"m\":130,\"g\":131},\"14507\":{\"m\":130,\"g\":131},\"12471\":{\"m\":130,\"g\":131},\"14093\":{\"m\":130,\"g\":131},\"14234\":{\"m\":130,\"g\":131},\"13434\":{\"m\":130,\"g\":131},\"14506\":{\"m\":130,\"g\":131},\"14471\":{\"m\":130,\"g\":131},\"14459\":{\"m\":130,\"g\":131},\"14466\":{\"m\":130,\"g\":131},\"14456\":{\"m\":130,\"g\":131},\"14097\":{\"m\":130,\"g\":131},\"14499\":{\"m\":130,\"g\":131},\"13584\":{\"m\":130,\"g\":131},\"14364\":{\"m\":130,\"g\":131},\"13861\":{\"m\":130,\"g\":131},\"13996\":{\"m\":130,\"g\":131},\"14472\":{\"m\":130,\"g\":131},\"14484\":{\"m\":130,\"g\":131},\"14444\":{\"m\":130,\"g\":131},\"14475\":{\"m\":130,\"g\":131},\"14463\":{\"m\":130,\"g\":131},\"14473\":{\"m\":130,\"g\":131},\"14468\":{\"m\":130,\"g\":131},\"13836\":{\"m\":130,\"g\":131},\"14421\":{\"m\":130,\"g\":131},\"14432\":{\"m\":130,\"g\":131},\"14251\":{\"m\":130,\"g\":131},\"14450\":{\"m\":130,\"g\":131},\"14445\":{\"m\":130,\"g\":131},\"14446\":{\"m\":130,\"g\":131},\"8287\":{\"m\":130,\"g\":131},\"14348\":{\"m\":130,\"g\":131},\"14350\":{\"m\":130,\"g\":131},\"14441\":{\"m\":130,\"g\":131},\"14325\":{\"m\":130,\"g\":131},\"14440\":{\"m\":130,\"g\":131},\"14438\":{\"m\":130,\"g\":131},\"14143\":{\"m\":130,\"g\":131},\"14434\":{\"m\":130,\"g\":131},\"14366\":{\"m\":130,\"g\":131},\"14430\":{\"m\":130,\"g\":131},\"14429\":{\"m\":130,\"g\":131},\"14225\":{\"m\":130,\"g\":131},\"14409\":{\"m\":130,\"g\":131},\"14213\":{\"m\":130,\"g\":131},\"14224\":{\"m\":130,\"g\":131},\"14334\":{\"m\":130,\"g\":131},\"14399\":{\"m\":130,\"g\":131},\"12446\":{\"m\":130,\"g\":131},\"13359\":{\"m\":130,\"g\":131},\"14383\":{\"m\":130,\"g\":131},\"14394\":{\"m\":130,\"g\":131},\"14381\":{\"m\":130,\"g\":131},\"12309\":{\"m\":130,\"g\":131},\"14393\":{\"m\":130,\"g\":131},\"12316\":{\"m\":130,\"g\":131},\"14292\":{\"m\":130,\"g\":
131},\"14392\":{\"m\":130,\"g\":131},\"14272\":{\"m\":130,\"g\":131},\"13731\":{\"m\":130,\"g\":131},\"14359\":{\"m\":130,\"g\":131},\"14377\":{\"m\":130,\"g\":131},\"14330\":{\"m\":130,\"g\":131},\"14277\":{\"m\":130,\"g\":131},\"14375\":{\"m\":130,\"g\":131},\"14374\":{\"m\":130,\"g\":131},\"14253\":{\"m\":130,\"g\":131},\"14372\":{\"m\":130,\"g\":131},\"14226\":{\"m\":130,\"g\":131},\"14371\":{\"m\":130,\"g\":131},\"14326\":{\"m\":130,\"g\":131},\"9660\":{\"m\":130,\"g\":131},\"12330\":{\"m\":130,\"g\":131},\"14355\":{\"m\":130,\"g\":131},\"13585\":{\"m\":130,\"g\":131},\"14362\":{\"m\":130,\"g\":131},\"14271\":{\"m\":130,\"g\":131},\"14295\":{\"m\":130,\"g\":131},\"13980\":{\"m\":130,\"g\":131},\"14347\":{\"m\":130,\"g\":131},\"14333\":{\"m\":130,\"g\":131},\"12441\":{\"m\":130,\"g\":131},\"14344\":{\"m\":130,\"g\":131},\"14265\":{\"m\":130,\"g\":131},\"14335\":{\"m\":130,\"g\":131},\"14336\":{\"m\":130,\"g\":131},\"13350\":{\"m\":130,\"g\":131},\"14266\":{\"m\":130,\"g\":131},\"14329\":{\"m\":130,\"g\":131},\"13812\":{\"m\":130,\"g\":131},\"14195\":{\"m\":130,\"g\":131},\"14321\":{\"m\":130,\"g\":131},\"13710\":{\"m\":130,\"g\":131},\"14858\":{\"m\":132,\"g\":133},\"14620\":{\"m\":132,\"g\":133},\"14304\":{\"m\":132,\"g\":133},\"14917\":{\"m\":132,\"g\":133},\"14307\":{\"m\":132,\"g\":133},\"14887\":{\"m\":132,\"g\":133},\"14911\":{\"m\":132,\"g\":133},\"14910\":{\"m\":132,\"g\":133},\"14852\":{\"m\":132,\"g\":133},\"14889\":{\"m\":132,\"g\":133},\"14890\":{\"m\":132,\"g\":133},\"14900\":{\"m\":132,\"g\":133},\"14899\":{\"m\":132,\"g\":133},\"12287\":{\"m\":132,\"g\":133},\"14878\":{\"m\":132,\"g\":133},\"14541\":{\"m\":132,\"g\":133},\"13641\":{\"m\":132,\"g\":133},\"14828\":{\"m\":132,\"g\":133},\"14827\":{\"m\":132,\"g\":133},\"14853\":{\"m\":132,\"g\":133},\"14876\":{\"m\":132,\"g\":133},\"14880\":{\"m\":132,\"g\":133},\"14877\":{\"m\":132,\"g\":133},\"14856\":{\"m\":132,\"g\":133},\"14875\":{\"m\":132,\"g\":133},\"14861\":{\"m\":132,\"g\":133},\"14845\":{\
"m\":132,\"g\":133},\"14871\":{\"m\":132,\"g\":133},\"14313\":{\"m\":132,\"g\":133},\"14554\":{\"m\":132,\"g\":133},\"14865\":{\"m\":132,\"g\":133},\"14811\":{\"m\":132,\"g\":133},\"14836\":{\"m\":132,\"g\":133},\"14848\":{\"m\":132,\"g\":133},\"14854\":{\"m\":132,\"g\":133},\"14849\":{\"m\":132,\"g\":133},\"14844\":{\"m\":132,\"g\":133},\"14851\":{\"m\":132,\"g\":133},\"14850\":{\"m\":132,\"g\":133},\"14847\":{\"m\":132,\"g\":133},\"14442\":{\"m\":132,\"g\":133},\"14841\":{\"m\":132,\"g\":133},\"14796\":{\"m\":132,\"g\":133},\"14638\":{\"m\":132,\"g\":133},\"14823\":{\"m\":132,\"g\":133},\"14801\":{\"m\":132,\"g\":133},\"14837\":{\"m\":132,\"g\":133},\"14045\":{\"m\":132,\"g\":133},\"14833\":{\"m\":132,\"g\":133},\"14829\":{\"m\":132,\"g\":133},\"14769\":{\"m\":132,\"g\":133},\"14712\":{\"m\":132,\"g\":133},\"14716\":{\"m\":132,\"g\":133},\"14830\":{\"m\":132,\"g\":133},\"14834\":{\"m\":132,\"g\":133},\"14812\":{\"m\":132,\"g\":133},\"14831\":{\"m\":132,\"g\":133},\"14806\":{\"m\":132,\"g\":133},\"14822\":{\"m\":132,\"g\":133},\"14819\":{\"m\":132,\"g\":133},\"14710\":{\"m\":132,\"g\":133},\"14807\":{\"m\":132,\"g\":133},\"14770\":{\"m\":132,\"g\":133},\"14793\":{\"m\":132,\"g\":133},\"14808\":{\"m\":132,\"g\":133},\"14697\":{\"m\":132,\"g\":133},\"14720\":{\"m\":132,\"g\":133},\"14803\":{\"m\":132,\"g\":133},\"14788\":{\"m\":132,\"g\":133},\"14794\":{\"m\":132,\"g\":133},\"9650\":{\"m\":132,\"g\":133},\"14786\":{\"m\":132,\"g\":133},\"14784\":{\"m\":132,\"g\":133},\"14725\":{\"m\":132,\"g\":133},\"14777\":{\"m\":132,\"g\":133},\"14761\":{\"m\":132,\"g\":133},\"14759\":{\"m\":132,\"g\":133},\"14064\":{\"m\":132,\"g\":133},\"14768\":{\"m\":132,\"g\":133},\"14756\":{\"m\":132,\"g\":133},\"14744\":{\"m\":132,\"g\":133},\"14687\":{\"m\":132,\"g\":133},\"14763\":{\"m\":132,\"g\":131},\"14177\":{\"m\":132,\"g\":131},\"14758\":{\"m\":132,\"g\":131},\"14669\":{\"m\":132,\"g\":131},\"14740\":{\"m\":132,\"g\":131},\"14753\":{\"m\":132,\"g\":131},\"14698\":{\"m\":132,\"g\":13
1},\"14379\":{\"m\":132,\"g\":131},\"14752\":{\"m\":132,\"g\":131},\"14751\":{\"m\":132,\"g\":131},\"14745\":{\"m\":132,\"g\":131},\"12953\":{\"m\":132,\"g\":131},\"14743\":{\"m\":132,\"g\":131},\"14738\":{\"m\":132,\"g\":131},\"14733\":{\"m\":132,\"g\":131},\"12039\":{\"m\":132,\"g\":131},\"13432\":{\"m\":132,\"g\":131},\"14461\":{\"m\":132,\"g\":131},\"14686\":{\"m\":132,\"g\":131},\"14601\":{\"m\":132,\"g\":131},\"14622\":{\"m\":132,\"g\":131},\"14714\":{\"m\":132,\"g\":131},\"14707\":{\"m\":132,\"g\":131},\"14699\":{\"m\":132,\"g\":131},\"14647\":{\"m\":132,\"g\":131},\"14648\":{\"m\":132,\"g\":131},\"14683\":{\"m\":132,\"g\":131},\"14678\":{\"m\":132,\"g\":131},\"14676\":{\"m\":132,\"g\":131},\"14529\":{\"m\":132,\"g\":131},\"14689\":{\"m\":132,\"g\":131},\"14627\":{\"m\":132,\"g\":131},\"14679\":{\"m\":132,\"g\":131},\"14469\":{\"m\":132,\"g\":131},\"14614\":{\"m\":132,\"g\":131},\"14653\":{\"m\":132,\"g\":131},\"13147\":{\"m\":132,\"g\":131},\"14652\":{\"m\":132,\"g\":131},\"13334\":{\"m\":132,\"g\":131},\"14489\":{\"m\":132,\"g\":131},\"14675\":{\"m\":132,\"g\":131},\"14671\":{\"m\":132,\"g\":131},\"16253\":{\"m\":134,\"g\":135},\"16107\":{\"m\":134,\"g\":135},\"16244\":\"m134\",\"16241\":{\"m\":134,\"g\":135},\"16211\":{\"m\":134,\"g\":135},\"16153\":{\"m\":134,\"g\":135},\"15942\":{\"m\":134,\"g\":135},\"16140\":{\"m\":134,\"g\":135},\"16142\":{\"m\":134,\"g\":135},\"16141\":{\"m\":134,\"g\":135},\"16129\":{\"m\":134,\"g\":135},\"10959\":{\"m\":134,\"g\":135},\"15888\":{\"m\":134,\"g\":135},\"16114\":{\"m\":134,\"g\":135},\"16131\":{\"m\":134,\"g\":135},\"16133\":{\"m\":134,\"g\":135},\"16053\":{\"m\":134,\"g\":135},\"16130\":{\"m\":134,\"g\":135},\"16105\":{\"m\":134,\"g\":135},\"15187\":{\"m\":134,\"g\":135},\"16123\":{\"m\":134,\"g\":135},\"15813\":{\"m\":134,\"g\":135},\"16081\":{\"m\":134,\"g\":135},\"15896\":{\"m\":134,\"g\":135},\"15877\":{\"m\":134,\"g\":135},\"15800\":{\"m\":134,\"g\":135},\"15985\":{\"m\":134,\"g\":135},\"16103\":{\"m\":134,\"g\"
:135},\"16099\":{\"m\":134,\"g\":135},\"15921\":{\"m\":134,\"g\":135},\"16101\":{\"m\":134,\"g\":135},\"16100\":{\"m\":134,\"g\":135},\"16098\":{\"m\":134,\"g\":135},\"16097\":{\"m\":134,\"g\":135},\"16094\":{\"m\":134,\"g\":135},\"16096\":{\"m\":134,\"g\":135},\"16093\":{\"m\":134,\"g\":135},\"15057\":{\"m\":134,\"g\":135},\"14838\":{\"m\":134,\"g\":135},\"16061\":{\"m\":134,\"g\":135},\"16066\":{\"m\":134,\"g\":135},\"16087\":{\"m\":134,\"g\":135},\"16069\":{\"m\":134,\"g\":135},\"16085\":{\"m\":134,\"g\":135},\"16062\":{\"m\":134,\"g\":135},\"14414\":{\"m\":134,\"g\":135},\"16047\":{\"m\":134,\"g\":135},\"15805\":{\"m\":134,\"g\":135},\"16054\":{\"m\":134,\"g\":135},\"16003\":{\"m\":134,\"g\":135},\"16046\":{\"m\":134,\"g\":135},\"16051\":{\"m\":134,\"g\":135},\"16038\":{\"m\":134,\"g\":135},\"16041\":{\"m\":134,\"g\":135},\"16039\":{\"m\":134,\"g\":135},\"16037\":{\"m\":134,\"g\":135},\"16035\":{\"m\":134,\"g\":135},\"16036\":{\"m\":134,\"g\":135},\"16010\":{\"m\":134,\"g\":135},\"14280\":{\"m\":134,\"g\":135},\"16028\":{\"m\":134,\"g\":135},\"16017\":{\"m\":134,\"g\":135},\"14873\":{\"m\":134,\"g\":135},\"15922\":{\"m\":134,\"g\":135},\"16016\":{\"m\":134,\"g\":135},\"15939\":{\"m\":134,\"g\":135},\"15998\":{\"m\":134,\"g\":135},\"15928\":{\"m\":134,\"g\":135},\"16008\":{\"m\":134,\"g\":135},\"16002\":{\"m\":134,\"g\":135},\"16013\":{\"m\":134,\"g\":135},\"16004\":{\"m\":134,\"g\":135},\"15615\":{\"m\":134,\"g\":135},\"15992\":{\"m\":134,\"g\":135},\"15991\":{\"m\":134,\"g\":135},\"16001\":{\"m\":134,\"g\":135},\"15945\":{\"m\":134,\"g\":135},\"15990\":{\"m\":134,\"g\":135},\"15988\":{\"m\":134,\"g\":135},\"15987\":{\"m\":134,\"g\":135},\"15891\":{\"m\":134,\"g\":135},\"15216\":{\"m\":134,\"g\":135},\"15693\":{\"m\":134,\"g\":135},\"15353\":{\"m\":134,\"g\":135},\"15835\":{\"m\":134,\"g\":135},\"15806\":{\"m\":134,\"g\":135},\"15937\":{\"m\":134,\"g\":135},\"15986\":{\"m\":134,\"g\":135},\"12596\":{\"m\":134,\"g\":135},\"15919\":{\"m\":134,\"g\":135},\"15947\":
{\"m\":134,\"g\":135},\"15925\":{\"m\":134,\"g\":135},\"15936\":{\"m\":134,\"g\":135},\"15886\":{\"m\":134,\"g\":135},\"15943\":{\"m\":134,\"g\":135},\"15935\":{\"m\":134,\"g\":135},\"15934\":{\"m\":134,\"g\":135},\"15933\":{\"m\":134,\"g\":135},\"14736\":{\"m\":134,\"g\":135},\"15923\":{\"m\":134,\"g\":135},\"15907\":{\"m\":134,\"g\":135},\"15887\":{\"m\":134,\"g\":135},\"15920\":{\"m\":134,\"g\":135},\"15918\":{\"m\":134,\"g\":135},\"15905\":{\"m\":134,\"g\":135},\"15850\":{\"m\":134,\"g\":135},\"15915\":{\"m\":134,\"g\":135},\"15398\":{\"m\":134,\"g\":135},\"15914\":{\"m\":134,\"g\":135},\"15916\":{\"m\":134,\"g\":135},\"15913\":{\"m\":134,\"g\":135},\"15911\":{\"m\":134,\"g\":135},\"15910\":{\"m\":134,\"g\":135},\"14750\":{\"m\":134,\"g\":135},\"15906\":{\"m\":134,\"g\":135},\"15778\":{\"m\":134,\"g\":135},\"15844\":{\"m\":134,\"g\":135},\"15812\":{\"m\":134,\"g\":135},\"15842\":{\"m\":134,\"g\":135},\"15881\":{\"m\":134,\"g\":135},\"15874\":{\"m\":134,\"g\":135},\"15889\":{\"m\":134,\"g\":135},\"15846\":{\"m\":134,\"g\":135},\"14209\":{\"m\":134,\"g\":135},\"15849\":{\"m\":134,\"g\":135},\"15817\":{\"m\":134,\"g\":135},\"15870\":{\"m\":134,\"g\":135},\"15867\":{\"m\":134,\"g\":135},\"14644\":{\"m\":134,\"g\":135},\"15518\":{\"m\":134,\"g\":135},\"15821\":{\"m\":134,\"g\":135},\"15369\":{\"m\":134,\"g\":135},\"15858\":{\"m\":134,\"g\":135},\"15857\":{\"m\":134,\"g\":135},\"15820\":{\"m\":134,\"g\":135},\"15851\":{\"m\":134,\"g\":135},\"15701\":{\"m\":134,\"g\":135},\"15791\":{\"m\":134,\"g\":135},\"15847\":{\"m\":134,\"g\":135},\"15522\":{\"m\":134,\"g\":135},\"15796\":{\"m\":134,\"g\":135},\"15826\":{\"m\":134,\"g\":135},\"15822\":{\"m\":134,\"g\":135},\"15772\":{\"m\":134,\"g\":135},\"15356\":{\"m\":134,\"g\":135},\"15759\":{\"m\":134,\"g\":135},\"15827\":{\"m\":134,\"g\":135},\"15488\":{\"m\":134,\"g\":135},\"15815\":{\"m\":134,\"g\":135},\"15818\":{\"m\":134,\"g\":135},\"15802\":{\"m\":134,\"g\":135},\"15666\":{\"m\":134,\"g\":135},\"15709\":{\"m\":134,\"g\"
:135},\"14741\":{\"m\":134,\"g\":135},\"15803\":{\"m\":134,\"g\":135},\"15409\":{\"m\":134,\"g\":135},\"15798\":{\"m\":134,\"g\":135},\"15811\":{\"m\":134,\"g\":135},\"15720\":{\"m\":134,\"g\":135},\"15736\":{\"m\":134,\"g\":135},\"15801\":{\"m\":134,\"g\":135},\"15555\":{\"m\":134,\"g\":135},\"15770\":{\"m\":134,\"g\":135},\"11469\":{\"m\":134,\"g\":135},\"15586\":{\"m\":134,\"g\":135},\"15596\":{\"m\":134,\"g\":135},\"14032\":{\"m\":134,\"g\":135},\"15787\":{\"m\":134,\"g\":135},\"15700\":{\"m\":134,\"g\":135},\"15781\":{\"m\":134,\"g\":133},\"15149\":{\"m\":134,\"g\":133},\"15775\":{\"m\":134,\"g\":133},\"15782\":{\"m\":134,\"g\":133},\"15758\":{\"m\":134,\"g\":133},\"15745\":{\"m\":134,\"g\":133},\"15750\":{\"m\":134,\"g\":133},\"15780\":{\"m\":134,\"g\":133},\"15741\":{\"m\":134,\"g\":133},\"14137\":{\"m\":134,\"g\":133},\"15390\":{\"m\":134,\"g\":133},\"15718\":{\"m\":134,\"g\":133},\"15769\":{\"m\":134,\"g\":133},\"15768\":{\"m\":134,\"g\":133},\"15653\":{\"m\":134,\"g\":133},\"15652\":{\"m\":134,\"g\":133},\"15752\":{\"m\":134,\"g\":133},\"15747\":{\"m\":134,\"g\":133},\"15748\":{\"m\":134,\"g\":133},\"15743\":{\"m\":134,\"g\":133},\"15740\":{\"m\":134,\"g\":133},\"15655\":{\"m\":134,\"g\":133},\"15706\":{\"m\":134,\"g\":133},\"15459\":{\"m\":134,\"g\":133},\"15689\":{\"m\":134,\"g\":133},\"15593\":{\"m\":134,\"g\":133},\"15704\":{\"m\":134,\"g\":133},\"15691\":{\"m\":134,\"g\":133},\"15656\":{\"m\":134,\"g\":133},\"15717\":{\"m\":134,\"g\":133},\"15715\":{\"m\":134,\"g\":133},\"15722\":{\"m\":134,\"g\":133},\"15716\":{\"m\":134,\"g\":133},\"15719\":{\"m\":134,\"g\":133},\"11828\":{\"m\":134,\"g\":133},\"15500\":{\"m\":134,\"g\":133},\"15622\":{\"m\":134,\"g\":133},\"15624\":{\"m\":134,\"g\":133},\"15705\":{\"m\":134,\"g\":133},\"15177\":{\"m\":134,\"g\":133},\"15702\":{\"m\":134,\"g\":133},\"15538\":{\"m\":134,\"g\":133},\"15667\":{\"m\":134,\"g\":133},\"15695\":{\"m\":134,\"g\":133},\"15696\":{\"m\":134,\"g\":133},\"15606\":{\"m\":134,\"g\":133},\"15273\":
{\"m\":134,\"g\":133},\"15672\":{\"m\":134,\"g\":133},\"15694\":{\"m\":134,\"g\":133},\"15469\":{\"m\":134,\"g\":133},\"12968\":{\"m\":134,\"g\":133},\"15644\":{\"m\":134,\"g\":133},\"15692\":{\"m\":134,\"g\":133},\"14570\":{\"m\":134,\"g\":133},\"15091\":{\"m\":134,\"g\":133},\"15646\":{\"m\":134,\"g\":133},\"15633\":{\"m\":134,\"g\":133},\"15688\":{\"m\":134,\"g\":133},\"15684\":{\"m\":134,\"g\":133},\"15312\":{\"m\":134,\"g\":133},\"15600\":{\"m\":134,\"g\":133},\"15463\":{\"m\":134,\"g\":133},\"15460\":{\"m\":134,\"g\":133},\"14983\":{\"m\":134,\"g\":133},\"15563\":{\"m\":134,\"g\":133},\"15582\":{\"m\":134,\"g\":133},\"14628\":{\"m\":134,\"g\":133},\"15539\":{\"m\":134,\"g\":133},\"15570\":{\"m\":134,\"g\":133},\"15635\":{\"m\":134,\"g\":133},\"15590\":{\"m\":134,\"g\":133},\"13576\":{\"m\":134,\"g\":133},\"15632\":{\"m\":134,\"g\":133},\"15621\":{\"m\":134,\"g\":133},\"15368\":{\"m\":134,\"g\":133},\"15611\":{\"m\":134,\"g\":133},\"15610\":{\"m\":134,\"g\":133},\"15572\":{\"m\":134,\"g\":133},\"15573\":{\"m\":134,\"g\":133},\"15616\":{\"m\":134,\"g\":133},\"15613\":{\"m\":134,\"g\":133},\"15607\":{\"m\":134,\"g\":133},\"15612\":{\"m\":134,\"g\":133},\"15164\":{\"m\":134,\"g\":133},\"15566\":{\"m\":134,\"g\":133},\"15599\":{\"m\":134,\"g\":133},\"15589\":{\"m\":134,\"g\":133},\"15588\":{\"m\":134,\"g\":133},\"15587\":{\"m\":134,\"g\":133},\"15565\":{\"m\":134,\"g\":133},\"15581\":{\"m\":134,\"g\":133},\"15585\":{\"m\":134,\"g\":133},\"15230\":{\"m\":134,\"g\":133},\"15427\":{\"m\":134,\"g\":133},\"15553\":{\"m\":134,\"g\":133},\"15583\":{\"m\":134,\"g\":133},\"15580\":{\"m\":134,\"g\":133},\"15569\":{\"m\":134,\"g\":133},\"14901\":{\"m\":134,\"g\":133},\"15564\":{\"m\":134,\"g\":133},\"15578\":{\"m\":134,\"g\":133},\"15579\":{\"m\":134,\"g\":133},\"15509\":{\"m\":134,\"g\":133},\"15540\":{\"m\":134,\"g\":133},\"15568\":{\"m\":134,\"g\":133},\"15558\":{\"m\":134,\"g\":133},\"12162\":{\"m\":134,\"g\":133},\"15531\":{\"m\":134,\"g\":133},\"15111\":{\"m\":134,\"g\"
:133},\"15432\":{\"m\":134,\"g\":133},\"15554\":{\"m\":134,\"g\":133},\"15556\":{\"m\":134,\"g\":133},\"15537\":{\"m\":134,\"g\":133},\"15520\":{\"m\":134,\"g\":133},\"15447\":{\"m\":134,\"g\":133},\"15324\":{\"m\":134,\"g\":133},\"15436\":{\"m\":134,\"g\":133},\"14134\":{\"m\":134,\"g\":133},\"15552\":{\"m\":134,\"g\":133},\"15526\":{\"m\":134,\"g\":133},\"15547\":{\"m\":134,\"g\":133},\"15413\":{\"m\":134,\"g\":133},\"15515\":{\"m\":134,\"g\":133},\"15544\":{\"m\":134,\"g\":133},\"15542\":{\"m\":134,\"g\":133},\"15418\":{\"m\":134,\"g\":133},\"15479\":{\"m\":134,\"g\":133},\"15533\":{\"m\":134,\"g\":133},\"15534\":{\"m\":134,\"g\":133},\"15530\":{\"m\":134,\"g\":133},\"15511\":{\"m\":134,\"g\":133},\"15536\":{\"m\":134,\"g\":133},\"15320\":{\"m\":134,\"g\":133},\"14091\":{\"m\":134,\"g\":133},\"15464\":{\"m\":134,\"g\":133},\"15484\":{\"m\":134,\"g\":133},\"15172\":{\"m\":134,\"g\":133},\"15178\":{\"m\":134,\"g\":133},\"15498\":{\"m\":134,\"g\":133},\"15333\":{\"m\":134,\"g\":133},\"15267\":{\"m\":134,\"g\":133},\"15507\":{\"m\":134,\"g\":133},\"15510\":{\"m\":134,\"g\":133},\"15485\":{\"m\":134,\"g\":133},\"15473\":{\"m\":134,\"g\":133},\"15505\":{\"m\":134,\"g\":133},\"15504\":{\"m\":134,\"g\":133},\"15503\":{\"m\":134,\"g\":133},\"15497\":{\"m\":134,\"g\":133},\"15040\":{\"m\":134,\"g\":133},\"15296\":{\"m\":134,\"g\":133},\"15496\":{\"m\":134,\"g\":133},\"15495\":{\"m\":134,\"g\":133},\"15494\":{\"m\":134,\"g\":133},\"15491\":{\"m\":134,\"g\":133},\"15022\":{\"m\":134,\"g\":133},\"14164\":{\"m\":134,\"g\":133},\"15348\":{\"m\":134,\"g\":133},\"15406\":{\"m\":134,\"g\":133},\"15483\":{\"m\":134,\"g\":133},\"15166\":{\"m\":134,\"g\":133},\"14138\":{\"m\":134,\"g\":133},\"15408\":{\"m\":134,\"g\":133},\"15416\":{\"m\":134,\"g\":133},\"15219\":{\"m\":134,\"g\":133},\"15478\":{\"m\":134,\"g\":133},\"15382\":{\"m\":134,\"g\":133},\"12995\":{\"m\":134,\"g\":133},\"15394\":{\"m\":134,\"g\":133},\"15458\":{\"m\":134,\"g\":133},\"15262\":{\"m\":134,\"g\":133},\"15437\":
{\"m\":134,\"g\":133},\"14395\":{\"m\":134,\"g\":133},\"15253\":{\"m\":134,\"g\":133},\"15415\":{\"m\":134,\"g\":133},\"15433\":{\"m\":134,\"g\":133},\"13402\":{\"m\":134,\"g\":133},\"13760\":{\"m\":134,\"g\":133},\"15340\":{\"m\":134,\"g\":133},\"13782\":{\"m\":134,\"g\":133},\"15423\":{\"m\":134,\"g\":133},\"15425\":{\"m\":134,\"g\":133},\"15287\":{\"m\":134,\"g\":133},\"15207\":{\"m\":134,\"g\":133},\"15410\":{\"m\":134,\"g\":133},\"15371\":{\"m\":134,\"g\":133},\"15429\":{\"m\":134,\"g\":133},\"15431\":{\"m\":134,\"g\":133},\"14723\":{\"m\":134,\"g\":133},\"14843\":{\"m\":134,\"g\":133},\"14353\":{\"m\":134,\"g\":133},\"15424\":{\"m\":134,\"g\":133},\"14354\":{\"m\":134,\"g\":133},\"15318\":{\"m\":134,\"g\":133},\"14781\":{\"m\":134,\"g\":133},\"15421\":{\"m\":134,\"g\":133},\"15352\":{\"m\":134,\"g\":133},\"15395\":{\"m\":134,\"g\":133},\"15401\":{\"m\":134,\"g\":133},\"15407\":{\"m\":134,\"g\":133},\"15400\":{\"m\":134,\"g\":133},\"15396\":{\"m\":134,\"g\":133},\"15404\":{\"m\":134,\"g\":133},\"15397\":{\"m\":134,\"g\":133},\"12921\":{\"m\":134,\"g\":133},\"15298\":{\"m\":134,\"g\":133},\"15141\":{\"m\":134,\"g\":133},\"15379\":{\"m\":134,\"g\":133},\"15306\":{\"m\":134,\"g\":133},\"14270\":{\"m\":134,\"g\":133},\"15384\":{\"m\":134,\"g\":133},\"15361\":{\"m\":134,\"g\":133},\"15372\":{\"m\":134,\"g\":133},\"12967\":{\"m\":134,\"g\":133},\"15290\":{\"m\":134,\"g\":133},\"14501\":{\"m\":134,\"g\":133},\"15049\":{\"m\":134,\"g\":133},\"15354\":{\"m\":134,\"g\":133},\"15337\":{\"m\":134,\"g\":133},\"15278\":{\"m\":134,\"g\":133},\"15131\":{\"m\":134,\"g\":133},\"15205\":{\"m\":134,\"g\":133},\"15307\":{\"m\":134,\"g\":133},\"14860\":{\"m\":134,\"g\":133},\"15176\":{\"m\":134,\"g\":133},\"15277\":{\"m\":134,\"g\":133},\"15328\":{\"m\":134,\"g\":133},\"15120\":{\"m\":134,\"g\":133},\"15241\":{\"m\":134,\"g\":133},\"15308\":{\"m\":134,\"g\":133},\"15336\":{\"m\":134,\"g\":133},\"15316\":{\"m\":134,\"g\":133},\"15335\":{\"m\":134,\"g\":133},\"15329\":{\"m\":134,\"g\"
:133},\"15186\":{\"m\":134,\"g\":133},\"15222\":{\"m\":134,\"g\":133},\"15198\":{\"m\":134,\"g\":133},\"15189\":{\"m\":134,\"g\":133},\"15326\":{\"m\":134,\"g\":133},\"15237\":{\"m\":134,\"g\":133},\"15138\":{\"m\":134,\"g\":133},\"14918\":{\"m\":134,\"g\":133},\"11914\":{\"m\":134,\"g\":133},\"15304\":{\"m\":134,\"g\":133},\"15223\":{\"m\":134,\"g\":133},\"15233\":{\"m\":134,\"g\":133},\"15314\":{\"m\":134,\"g\":133},\"12333\":{\"m\":134,\"g\":133},\"15071\":{\"m\":134,\"g\":133},\"15284\":{\"m\":134,\"g\":133},\"14449\":{\"m\":134,\"g\":133},\"14357\":{\"m\":134,\"g\":133},\"15088\":{\"m\":134,\"g\":133},\"14376\":{\"m\":134,\"g\":133},\"15155\":{\"m\":134,\"g\":133},\"13571\":{\"m\":134,\"g\":133},\"15283\":{\"m\":134,\"g\":133},\"15218\":{\"m\":134,\"g\":133},\"15297\":{\"m\":134,\"g\":133},\"15258\":{\"m\":134,\"g\":133},\"15280\":{\"m\":134,\"g\":133},\"15293\":{\"m\":134,\"g\":133},\"15292\":{\"m\":134,\"g\":133},\"15291\":{\"m\":134,\"g\":133},\"15282\":{\"m\":134,\"g\":133},\"15242\":{\"m\":134,\"g\":133},\"14997\":{\"m\":134,\"g\":133},\"14934\":{\"m\":134,\"g\":133},\"15281\":{\"m\":134,\"g\":133},\"14857\":{\"m\":134,\"g\":133},\"14975\":{\"m\":134,\"g\":133},\"14936\":{\"m\":134,\"g\":133},\"15234\":{\"m\":134,\"g\":133},\"15270\":{\"m\":134,\"g\":133},\"15232\":{\"m\":134,\"g\":133},\"15192\":{\"m\":134,\"g\":133},\"15100\":{\"m\":134,\"g\":133},\"9337\":{\"m\":134,\"g\":133},\"15239\":{\"m\":134,\"g\":133},\"15220\":{\"m\":134,\"g\":133},\"14866\":{\"m\":134,\"g\":133},\"14415\":{\"m\":134,\"g\":133},\"15180\":{\"m\":134,\"g\":133},\"15225\":{\"m\":134,\"g\":133},\"15162\":{\"m\":134,\"g\":133},\"14990\":{\"m\":134,\"g\":133},\"15229\":{\"m\":134,\"g\":133},\"15231\":{\"m\":134,\"g\":133},\"15204\":{\"m\":134,\"g\":133},\"15092\":{\"m\":134,\"g\":133},\"15224\":{\"m\":134,\"g\":133},\"13410\":{\"m\":134,\"g\":133},\"15212\":{\"m\":134,\"g\":133},\"15185\":{\"m\":134,\"g\":133},\"15153\":{\"m\":134,\"g\":133},\"14820\":{\"m\":134,\"g\":133},\"15201\":{
\"m\":134,\"g\":133},\"15127\":{\"m\":134,\"g\":133},\"15191\":{\"m\":134,\"g\":133},\"15190\":{\"m\":134,\"g\":133},\"15196\":{\"m\":134,\"g\":133},\"15193\":{\"m\":134,\"g\":133},\"15160\":{\"m\":134,\"g\":133},\"15047\":{\"m\":134,\"g\":133},\"15017\":{\"m\":134,\"g\":133},\"15163\":{\"m\":134,\"g\":133},\"15152\":{\"m\":134,\"g\":133},\"14906\":{\"m\":134,\"g\":133},\"15116\":{\"m\":134,\"g\":133},\"14862\":{\"m\":134,\"g\":133},\"15174\":{\"m\":134,\"g\":133},\"9324\":{\"m\":134,\"g\":133},\"15170\":{\"m\":134,\"g\":133},\"15005\":{\"m\":134,\"g\":133},\"13914\":{\"m\":134,\"g\":133},\"15158\":{\"m\":134,\"g\":133},\"15156\":{\"m\":134,\"g\":133},\"15154\":{\"m\":134,\"g\":133},\"15144\":{\"m\":134,\"g\":133},\"14778\":{\"m\":134,\"g\":133},\"15147\":{\"m\":134,\"g\":133},\"15146\":{\"m\":134,\"g\":133},\"15130\":{\"m\":134,\"g\":133},\"14907\":{\"m\":134,\"g\":133},\"14764\":{\"m\":134,\"g\":133},\"14792\":{\"m\":134,\"g\":133},\"15142\":{\"m\":134,\"g\":133},\"15139\":{\"m\":134,\"g\":133},\"15101\":{\"m\":134,\"g\":133},\"14938\":{\"m\":134,\"g\":133},\"15134\":{\"m\":134,\"g\":133},\"15058\":{\"m\":134,\"g\":133},\"15086\":{\"m\":134,\"g\":133},\"15099\":{\"m\":134,\"g\":133},\"15098\":{\"m\":134,\"g\":133},\"14953\":{\"m\":134,\"g\":133},\"15052\":{\"m\":134,\"g\":133},\"15044\":{\"m\":134,\"g\":133},\"15125\":{\"m\":134,\"g\":133},\"15110\":{\"m\":134,\"g\":133},\"14791\":{\"m\":134,\"g\":133},\"15113\":{\"m\":134,\"g\":133},\"15117\":{\"m\":134,\"g\":133},\"15124\":{\"m\":134,\"g\":133},\"15121\":{\"m\":134,\"g\":133},\"14874\":{\"m\":134,\"g\":133},\"15106\":{\"m\":134,\"g\":133},\"15090\":{\"m\":134,\"g\":133},\"15114\":{\"m\":134,\"g\":133},\"12263\":{\"m\":134,\"g\":133},\"13969\":{\"m\":134,\"g\":133},\"14961\":{\"m\":134,\"g\":133},\"15062\":{\"m\":134,\"g\":133},\"15053\":{\"m\":134,\"g\":133},\"14881\":{\"m\":134,\"g\":133},\"15108\":{\"m\":134,\"g\":133},\"15048\":{\"m\":134,\"g\":133},\"14294\":{\"m\":134,\"g\":133},\"15084\":{\"m\":134,\"g\":1
33},\"14969\":{\"m\":134,\"g\":133},\"15096\":{\"m\":134,\"g\":133},\"15061\":{\"m\":134,\"g\":133},\"15095\":{\"m\":134,\"g\":133},\"15094\":{\"m\":134,\"g\":133},\"15093\":{\"m\":134,\"g\":133},\"14943\":{\"m\":134,\"g\":133},\"15087\":{\"m\":134,\"g\":133},\"14993\":{\"m\":134,\"g\":133},\"14956\":{\"m\":134,\"g\":133},\"15056\":{\"m\":134,\"g\":133},\"15066\":{\"m\":134,\"g\":133},\"15085\":{\"m\":134,\"g\":133},\"15064\":{\"m\":134,\"g\":133},\"14935\":{\"m\":134,\"g\":133},\"15065\":{\"m\":134,\"g\":133},\"14998\":{\"m\":134,\"g\":133},\"14201\":{\"m\":134,\"g\":133},\"14940\":{\"m\":134,\"g\":133},\"15054\":{\"m\":134,\"g\":133},\"15055\":{\"m\":134,\"g\":133},\"15080\":{\"m\":134,\"g\":133},\"15079\":{\"m\":134,\"g\":133},\"14742\":{\"m\":134,\"g\":133},\"15074\":{\"m\":134,\"g\":133},\"15072\":{\"m\":134,\"g\":133},\"13740\":{\"m\":134,\"g\":133},\"15069\":{\"m\":134,\"g\":133},\"15060\":{\"m\":134,\"g\":133},\"15059\":{\"m\":134,\"g\":133},\"14422\":{\"m\":134,\"g\":133},\"14855\":{\"m\":134,\"g\":133},\"14423\":{\"m\":134,\"g\":133},\"15027\":{\"m\":134,\"g\":133},\"13989\":{\"m\":134,\"g\":133},\"14924\":{\"m\":134,\"g\":133},\"14659\":{\"m\":134,\"g\":133},\"15002\":{\"m\":134,\"g\":133},\"15034\":{\"m\":134,\"g\":133},\"14485\":{\"m\":134,\"g\":133},\"13876\":{\"m\":134,\"g\":133},\"15037\":{\"m\":134,\"g\":133},\"15036\":{\"m\":134,\"g\":133},\"15032\":{\"m\":134,\"g\":133},\"15031\":{\"m\":134,\"g\":133},\"15030\":{\"m\":134,\"g\":133},\"15028\":{\"m\":134,\"g\":133},\"15035\":{\"m\":134,\"g\":133},\"14939\":{\"m\":134,\"g\":133},\"15024\":{\"m\":134,\"g\":133},\"15033\":{\"m\":134,\"g\":133},\"15009\":{\"m\":134,\"g\":133},\"14957\":{\"m\":134,\"g\":133},\"14955\":{\"m\":134,\"g\":133},\"15015\":{\"m\":134,\"g\":133},\"15023\":{\"m\":134,\"g\":133},\"15021\":{\"m\":134,\"g\":133},\"14992\":{\"m\":134,\"g\":133},\"14976\":{\"m\":134,\"g\":133},\"14966\":{\"m\":134,\"g\":133},\"14933\":{\"m\":134,\"g\":133},\"14795\":{\"m\":134,\"g\":133},\"15020\":{\
"m\":134,\"g\":133},\"15010\":{\"m\":134,\"g\":133},\"15014\":{\"m\":134,\"g\":133},\"14869\":{\"m\":134,\"g\":133},\"14989\":{\"m\":134,\"g\":133},\"14958\":{\"m\":134,\"g\":133},\"14951\":{\"m\":134,\"g\":133},\"15004\":{\"m\":134,\"g\":133},\"15001\":{\"m\":134,\"g\":133},\"15000\":{\"m\":134,\"g\":133},\"14999\":{\"m\":134,\"g\":133},\"14996\":{\"m\":134,\"g\":133},\"14995\":{\"m\":134,\"g\":133},\"14987\":{\"m\":134,\"g\":133},\"14945\":{\"m\":134,\"g\":133},\"14985\":{\"m\":134,\"g\":133},\"14534\":{\"m\":134,\"g\":133},\"14916\":{\"m\":134,\"g\":133},\"14959\":{\"m\":134,\"g\":133},\"14937\":{\"m\":134,\"g\":133},\"13798\":{\"m\":134,\"g\":133},\"14572\":{\"m\":134,\"g\":133},\"14949\":{\"m\":134,\"g\":133},\"14842\":{\"m\":134,\"g\":133},\"13699\":{\"m\":134,\"g\":133},\"14944\":{\"m\":134,\"g\":133},\"14888\":{\"m\":134,\"g\":133},\"14893\":{\"m\":134,\"g\":133},\"14892\":{\"m\":134,\"g\":133},\"14941\":{\"m\":134,\"g\":133},\"14894\":{\"m\":134,\"g\":133},\"14839\":{\"m\":134,\"g\":133},\"11852\":{\"m\":134,\"g\":133},\"14870\":{\"m\":134,\"g\":133},\"14358\":{\"m\":134,\"g\":133},\"14923\":{\"m\":134,\"g\":133},\"14891\":{\"m\":134,\"g\":133},\"14909\":{\"m\":134,\"g\":133},\"14467\":{\"m\":134,\"g\":133},\"14927\":{\"m\":134,\"g\":133},\"14931\":{\"m\":134,\"g\":133},\"13730\":{\"m\":134,\"g\":133},\"14932\":{\"m\":134,\"g\":133},\"14748\":{\"m\":134,\"g\":133},\"14525\":{\"m\":134,\"g\":133},\"14771\":{\"m\":134,\"g\":133},\"14074\":{\"m\":134,\"g\":133},\"14922\":{\"m\":134,\"g\":133},\"13125\":{\"m\":134,\"g\":133},\"14921\":{\"m\":134,\"g\":133},\"14928\":{\"m\":134,\"g\":133},\"14925\":{\"m\":134,\"g\":133},\"17458\":\"m136\",\"17569\":\"m136\",\"17591\":\"m136\",\"17553\":\"m136\",\"15859\":\"m136\",\"17460\":\"m136\",\"17474\":\"m136\",\"17541\":\"m136\",\"17518\":\"m136\",\"17539\":\"m136\",\"16366\":\"m136\",\"17514\":\"m136\",\"17536\":\"m136\",\"17529\":\"m136\",\"17534\":\"m136\",\"17442\":\"m136\",\"17493\":\"m136\",\"17528\":\"m136\",\"1752
4\":\"m136\",\"16927\":\"m136\",\"17486\":\"m136\",\"17519\":\"m136\",\"17416\":\"m136\",\"17490\":\"m136\",\"17517\":\"m136\",\"17394\":\"m136\",\"17498\":\"m136\",\"16034\":\"m136\",\"17510\":\"m136\",\"17166\":\"m136\",\"16919\":\"m136\",\"17108\":\"m136\",\"17417\":\"m136\",\"16670\":\"m136\",\"17355\":\"m136\",\"17399\":\"m136\",\"17400\":\"m136\",\"17397\":\"m136\",\"17372\":\"m136\",\"16396\":\"m136\",\"17457\":\"m136\",\"17462\":\"m136\",\"17466\":\"m136\",\"17425\":\"m136\",\"17465\":\"m136\",\"17452\":\"m136\",\"17386\":\"m136\",\"17293\":\"m136\",\"17455\":\"m136\",\"17251\":\"m136\",\"17444\":\"m136\",\"17443\":\"m136\",\"17439\":\"m136\",\"17290\":\"m136\",\"17291\":\"m136\",\"17313\":\"m136\",\"17436\":\"m136\",\"17428\":\"m136\",\"17289\":\"m136\",\"17334\":\"m136\",\"17403\":\"m136\",\"17429\":\"m136\",\"17247\":\"m136\",\"17205\":\"m136\",\"17288\":\"m136\",\"17419\":\"m136\",\"17414\":\"m136\",\"17409\":\"m136\",\"17358\":\"m136\",\"11657\":\"m136\",\"17382\":\"m136\",\"17160\":\"m136\",\"17179\":\"m136\",\"17327\":\"m136\",\"17385\":\"m136\",\"17302\":\"m136\",\"17339\":\"m136\",\"15325\":\"m136\",\"17364\":\"m136\",\"16880\":\"m136\",\"17378\":\"m136\",\"17376\":\"m136\",\"17370\":\"m136\",\"17043\":\"m136\",\"17088\":\"m136\",\"17336\":\"m136\",\"17367\":\"m136\",\"17366\":\"m136\",\"17363\":\"m136\",\"17329\":\"m136\",\"17049\":\"m136\",\"17345\":\"m136\",\"17116\":\"m136\",\"17142\":\"m136\",\"17220\":\"m136\",\"16567\":\"m136\",\"17305\":\"m136\",\"17309\":\"m136\",\"17158\":\"m136\",\"17177\":\"m136\",\"17332\":\"m136\",\"17238\":\"m136\",\"17245\":\"m136\",\"17241\":\"m136\",\"16412\":\"m136\",\"17325\":\"m136\",\"16744\":\"m136\",\"17317\":\"m136\",\"17319\":\"m136\",\"16961\":\"m136\",\"15347\":\"m136\",\"17315\":\"m136\",\"17306\":\"m136\",\"15512\":\"m136\",\"17308\":\"m136\",\"17235\":\"m136\",\"15631\":\"m136\",\"14197\":\"m136\",\"16649\":\"m136\",\"17296\":\"m136\",\"17212\":\"m136\",\"16534\":\"m136\",\"17295\":\"m136\",\"17236\":\
"m136\",\"17281\":\"m136\",\"16974\":\"m136\",\"17287\":\"m136\",\"17182\":\"m136\",\"17264\":\"m136\",\"16152\":\"m136\",\"17261\":\"m136\",\"17250\":\"m136\",\"17234\":\"m136\",\"17225\":\"m136\",\"17256\":\"m136\",\"17257\":\"m136\",\"17048\":\"m136\",\"14883\":\"m136\",\"17191\":\"m136\",\"17252\":\"m136\",\"16561\":\"m136\",\"17248\":\"m136\",\"17249\":\"m136\",\"16879\":\"m136\",\"17045\":\"m136\",\"17047\":\"m136\",\"17044\":\"m136\",\"16951\":\"m136\",\"17242\":\"m136\",\"16354\":\"m136\",\"16817\":\"m136\",\"16824\":\"m136\",\"17232\":\"m136\",\"17230\":\"m136\",\"17233\":\"m136\",\"17165\":\"m136\",\"16925\":\"m136\",\"17187\":\"m136\",\"17217\":\"m136\",\"16842\":\"m136\",\"13672\":\"m136\",\"15551\":\"m136\",\"17133\":\"m136\",\"17016\":\"m136\",\"17200\":\"m136\",\"17038\":\"m136\",\"13216\":\"m136\",\"17143\":\"m136\",\"17105\":\"m136\",\"17184\":\"m136\",\"17173\":\"m136\",\"17051\":\"m136\",\"16934\":\"m136\",\"17186\":\"m136\",\"16121\":\"m136\",\"17180\":\"m136\",\"15513\":\"m136\",\"17041\":\"m136\",\"14565\":\"m136\",\"17100\":\"m136\",\"14579\":\"m136\",\"17178\":\"m136\",\"16882\":\"m136\",\"16826\":\"m136\",\"16369\":\"m136\",\"17176\":\"m136\",\"17174\":\"m136\",\"11028\":\"m136\",\"15455\":\"m136\",\"17170\":\"m136\",\"10598\":\"m136\",\"17091\":\"m136\",\"17168\":\"m136\",\"17167\":\"m136\",\"16568\":\"m136\",\"15789\":\"m136\",\"17163\":\"m136\",\"17099\":\"m136\",\"17126\":\"m136\",\"16965\":\"m136\",\"17111\":\"m136\",\"17113\":\"m136\",\"17020\":\"m136\",\"17064\":\"m136\",\"17103\":\"m136\",\"17092\":\"m136\",\"16949\":\"m136\",\"17056\":\"m136\",\"17075\":\"m136\",\"17061\":\"m136\",\"17028\":\"m136\",\"17101\":\"m136\",\"17052\":\"m136\",\"12497\":\"m136\",\"14504\":\"m136\",\"17087\":\"m136\",\"16278\":\"m136\",\"16971\":\"m136\",\"16976\":\"m136\",\"16403\":\"m136\",\"17058\":\"m136\",\"17077\":\"m136\",\"16264\":\"m136\",\"7392\":\"m136\",\"16732\":\"m136\",\"16899\":\"m136\",\"16569\":\"m136\",\"17013\":\"m136\",\"16941\":\"m136\
",\"16962\":\"m136\",\"16888\":\"m136\",\"15908\":\"m136\",\"16989\":\"m136\",\"17046\":\"m136\",\"17030\":\"m136\",\"17027\":\"m136\",\"17054\":\"m136\",\"16898\":\"m136\",\"16994\":\"m136\",\"16841\":\"m136\",\"17002\":\"m136\",\"14108\":\"m136\",\"17022\":\"m136\",\"17005\":\"m136\",\"16982\":\"m136\",\"16894\":\"m136\",\"16953\":\"m136\",\"17019\":\"m136\",\"16886\":\"m136\",\"16259\":\"m136\",\"16924\":\"m136\",\"16986\":\"m136\",\"16790\":\"m136\",\"16935\":\"m136\",\"16884\":\"m136\",\"16967\":\"m136\",\"16766\":\"m136\",\"16564\":\"m136\",\"16767\":\"m136\",\"16648\":\"m136\",\"16252\":\"m136\",\"16940\":\"m136\",\"15853\":\"m136\",\"16765\":\"m136\",\"16978\":\"m136\",\"16981\":\"m136\",\"16980\":\"m136\",\"16979\":\"m136\",\"16977\":\"m136\",\"16973\":\"m136\",\"16970\":\"m136\",\"16192\":\"m136\",\"16348\":\"m136\",\"16963\":\"m136\",\"15271\":\"m136\",\"16933\":\"m136\",\"16480\":\"m136\",\"16922\":\"m136\",\"16916\":\"m136\",\"16912\":\"m136\",\"16019\":\"m136\",\"16559\":\"m136\",\"16677\":\"m136\",\"16930\":\"m136\",\"16932\":\"m136\",\"16536\":\"m136\",\"16850\":\"m136\",\"16915\":\"m136\",\"16896\":\"m136\",\"16820\":\"m136\",\"15268\":\"m136\",\"12909\":\"m136\",\"16908\":\"m136\",\"16906\":\"m136\",\"16851\":\"m136\",\"14867\":\"m136\",\"16895\":\"m136\",\"16300\":\"m136\",\"16864\":\"m136\",\"16889\":\"m136\",\"15182\":\"m136\",\"16876\":\"m136\",\"16258\":\"m136\",\"16835\":\"m136\",\"16345\":\"m136\",\"16878\":\"m136\",\"16867\":\"m136\",\"16273\":\"m136\",\"16757\":\"m136\",\"16865\":\"m136\",\"16877\":\"m136\",\"16863\":\"m136\",\"16667\":\"m136\",\"16788\":\"m136\",\"16397\":\"m136\",\"16737\":\"m136\",\"16872\":\"m136\",\"16871\":\"m136\",\"16870\":\"m136\",\"15790\":\"m136\",\"15927\":\"m136\",\"15227\":\"m136\",\"16226\":\"m136\",\"16844\":\"m136\",\"16838\":\"m136\",\"16849\":\"m136\",\"16333\":\"m136\",\"16779\":\"m136\",\"16854\":\"m136\",\"16862\":\"m136\",\"16572\":\"m136\",\"16805\":\"m136\",\"16723\":\"m136\",\"13715\":\"m136\",\"1
6853\":\"m136\",\"16637\":\"m136\",\"16852\":\"m136\",\"16847\":\"m136\",\"16848\":\"m136\",\"16845\":\"m136\",\"16698\":\"m136\",\"16825\":\"m136\",\"16837\":\"m136\",\"16840\":\"m136\",\"16839\":\"m136\",\"16831\":\"m136\",\"16830\":\"m136\",\"16679\":\"m136\",\"16810\":\"m136\",\"16783\":\"m136\",\"16587\":\"m136\",\"16821\":\"m136\",\"16804\":\"m136\",\"15753\":\"m136\",\"13947\":\"m136\",\"16275\":\"m136\",\"16814\":\"m136\",\"16813\":\"m136\",\"16812\":\"m136\",\"16811\":\"m136\",\"14655\":\"m136\",\"16325\":\"m136\",\"16708\":\"m136\",\"16760\":\"m136\",\"16792\":\"m136\",\"11349\":\"m136\",\"16458\":\"m136\",\"16743\":\"m136\",\"16768\":\"m136\",\"16721\":\"m136\",\"16778\":\"m136\",\"16774\":\"m136\",\"16686\":\"m136\",\"16378\":\"m136\",\"16588\":\"m136\",\"16446\":\"m136\",\"16773\":\"m136\",\"16380\":\"m136\",\"16254\":{\"m\":136,\"g\":135},\"16720\":{\"m\":136,\"g\":135},\"16560\":{\"m\":136,\"g\":135},\"16764\":{\"m\":136,\"g\":135},\"16763\":{\"m\":136,\"g\":135},\"16759\":{\"m\":136,\"g\":135},\"16715\":{\"m\":136,\"g\":135},\"16756\":{\"m\":136,\"g\":135},\"16754\":{\"m\":136,\"g\":135},\"16752\":{\"m\":136,\"g\":135},\"16751\":{\"m\":136,\"g\":135},\"16749\":{\"m\":136,\"g\":135},\"16748\":{\"m\":136,\"g\":135},\"16675\":{\"m\":136,\"g\":135},\"16746\":{\"m\":136,\"g\":135},\"16409\":{\"m\":136,\"g\":135},\"13681\":{\"m\":136,\"g\":135},\"16745\":{\"m\":136,\"g\":135},\"16741\":{\"m\":136,\"g\":135},\"16014\":{\"m\":136,\"g\":135},\"16739\":{\"m\":136,\"g\":135},\"16719\":{\"m\":136,\"g\":135},\"16668\":{\"m\":136,\"g\":135},\"16738\":{\"m\":136,\"g\":135},\"16709\":{\"m\":136,\"g\":135},\"16735\":{\"m\":136,\"g\":135},\"16622\":{\"m\":136,\"g\":135},\"16635\":{\"m\":136,\"g\":135},\"16733\":{\"m\":136,\"g\":135},\"16730\":{\"m\":136,\"g\":135},\"16729\":{\"m\":136,\"g\":135},\"16519\":{\"m\":136,\"g\":135},\"16706\":{\"m\":136,\"g\":135},\"16115\":{\"m\":136,\"g\":135},\"16634\":{\"m\":136,\"g\":135},\"16535\":{\"m\":136,\"g\":135},\"16452\":{\"m\
":136,\"g\":135},\"16533\":{\"m\":136,\"g\":135},\"16531\":{\"m\":136,\"g\":135},\"16529\":{\"m\":136,\"g\":135},\"15627\":{\"m\":136,\"g\":135},\"16608\":{\"m\":136,\"g\":135},\"16693\":{\"m\":136,\"g\":135},\"16155\":{\"m\":136,\"g\":135},\"16505\":{\"m\":136,\"g\":135},\"16697\":{\"m\":136,\"g\":135},\"16555\":{\"m\":136,\"g\":135},\"16669\":{\"m\":136,\"g\":135},\"16625\":{\"m\":136,\"g\":135},\"16695\":{\"m\":136,\"g\":135},\"16678\":{\"m\":136,\"g\":135},\"16692\":{\"m\":136,\"g\":135},\"16629\":{\"m\":136,\"g\":135},\"16680\":{\"m\":136,\"g\":135},\"16445\":{\"m\":136,\"g\":135},\"16306\":{\"m\":136,\"g\":135},\"16652\":{\"m\":136,\"g\":135},\"15343\":{\"m\":136,\"g\":135},\"16095\":{\"m\":136,\"g\":135},\"15151\":{\"m\":136,\"g\":135},\"16681\":{\"m\":136,\"g\":135},\"16674\":{\"m\":136,\"g\":135},\"16457\":{\"m\":136,\"g\":135},\"16426\":{\"m\":136,\"g\":135},\"16676\":{\"m\":136,\"g\":135},\"16658\":{\"m\":136,\"g\":135},\"16672\":{\"m\":136,\"g\":135},\"16633\":{\"m\":136,\"g\":135},\"16671\":{\"m\":136,\"g\":135},\"15712\":{\"m\":136,\"g\":135},\"16660\":{\"m\":136,\"g\":135},\"16664\":{\"m\":136,\"g\":135},\"16657\":{\"m\":136,\"g\":135},\"16661\":{\"m\":136,\"g\":135},\"16618\":{\"m\":136,\"g\":135},\"16179\":{\"m\":136,\"g\":135},\"16599\":{\"m\":136,\"g\":135},\"15238\":{\"m\":136,\"g\":135},\"16532\":{\"m\":136,\"g\":135},\"16654\":{\"m\":136,\"g\":135},\"16651\":{\"m\":136,\"g\":135},\"16203\":{\"m\":136,\"g\":135},\"13602\":{\"m\":136,\"g\":135},\"16620\":{\"m\":136,\"g\":135},\"16514\":{\"m\":136,\"g\":135},\"16631\":{\"m\":136,\"g\":135},\"16630\":{\"m\":136,\"g\":135},\"16619\":{\"m\":136,\"g\":135},\"16527\":{\"m\":136,\"g\":135},\"15938\":{\"m\":136,\"g\":135},\"16582\":{\"m\":136,\"g\":135},\"16418\":{\"m\":136,\"g\":135},\"16459\":{\"m\":136,\"g\":135},\"16465\":{\"m\":136,\"g\":135},\"16468\":{\"m\":136,\"g\":135},\"16566\":{\"m\":136,\"g\":135},\"16617\":{\"m\":136,\"g\":135},\"16597\":{\"m\":136,\"g\":135},\"16504\":{\"m\":136,\"g\":135}
,\"16593\":{\"m\":136,\"g\":135},\"16118\":{\"m\":136,\"g\":135},\"15439\":{\"m\":136,\"g\":135},\"16609\":{\"m\":136,\"g\":135},\"16453\":{\"m\":136,\"g\":135},\"16499\":{\"m\":136,\"g\":135},\"16606\":{\"m\":136,\"g\":135},\"16603\":{\"m\":136,\"g\":135},\"16576\":{\"m\":136,\"g\":135},\"16596\":{\"m\":136,\"g\":135},\"16598\":{\"m\":136,\"g\":135},\"16601\":{\"m\":136,\"g\":135},\"16594\":{\"m\":136,\"g\":135},\"16442\":{\"m\":136,\"g\":135},\"16422\":{\"m\":136,\"g\":135},\"16223\":{\"m\":136,\"g\":135},\"16425\":{\"m\":136,\"g\":135},\"16589\":{\"m\":136,\"g\":135},\"16592\":{\"m\":136,\"g\":135},\"16583\":{\"m\":136,\"g\":135},\"16585\":{\"m\":136,\"g\":135},\"16591\":{\"m\":136,\"g\":135},\"16326\":{\"m\":136,\"g\":135},\"16523\":{\"m\":136,\"g\":135},\"16420\":{\"m\":136,\"g\":135},\"16548\":{\"m\":136,\"g\":135},\"16549\":{\"m\":136,\"g\":135},\"16575\":{\"m\":136,\"g\":135},\"16507\":{\"m\":136,\"g\":135},\"16570\":{\"m\":136,\"g\":135},\"16463\":{\"m\":136,\"g\":135},\"14215\":{\"m\":136,\"g\":135},\"13518\":{\"m\":136,\"g\":135},\"16455\":{\"m\":136,\"g\":135},\"16496\":{\"m\":136,\"g\":135},\"16540\":{\"m\":136,\"g\":135},\"16201\":{\"m\":136,\"g\":135},\"15456\":{\"m\":136,\"g\":135},\"16421\":{\"m\":136,\"g\":135},\"16466\":{\"m\":136,\"g\":135},\"16539\":{\"m\":136,\"g\":135},\"16417\":{\"m\":136,\"g\":135},\"16502\":{\"m\":136,\"g\":135},\"16415\":{\"m\":136,\"g\":135},\"16467\":{\"m\":136,\"g\":135},\"16399\":{\"m\":136,\"g\":135},\"14112\":{\"m\":136,\"g\":135},\"15492\":{\"m\":136,\"g\":135},\"16538\":{\"m\":136,\"g\":135},\"6135\":{\"m\":136,\"g\":135},\"16528\":{\"m\":136,\"g\":135},\"16416\":{\"m\":136,\"g\":135},\"16419\":{\"m\":136,\"g\":135},\"16434\":{\"m\":136,\"g\":135},\"16525\":{\"m\":136,\"g\":135},\"16520\":{\"m\":136,\"g\":135},\"16086\":{\"m\":136,\"g\":135},\"16524\":{\"m\":136,\"g\":135},\"15677\":{\"m\":136,\"g\":135},\"16477\":{\"m\":136,\"g\":135},\"16356\":{\"m\":136,\"g\":135},\"16513\":{\"m\":136,\"g\":135},\"15663\":{\"m\"
:136,\"g\":135},\"16516\":{\"m\":136,\"g\":135},\"16482\":{\"m\":136,\"g\":135},\"16511\":{\"m\":136,\"g\":135},\"16280\":{\"m\":136,\"g\":135},\"16509\":{\"m\":136,\"g\":135},\"16508\":{\"m\":136,\"g\":135},\"16492\":{\"m\":136,\"g\":135},\"16469\":{\"m\":136,\"g\":135},\"16138\":{\"m\":136,\"g\":135},\"16481\":{\"m\":136,\"g\":135},\"14474\":{\"m\":136,\"g\":135},\"16367\":{\"m\":136,\"g\":135},\"16475\":{\"m\":136,\"g\":135},\"16478\":{\"m\":136,\"g\":135},\"16479\":{\"m\":136,\"g\":135},\"16474\":{\"m\":136,\"g\":135},\"16447\":{\"m\":136,\"g\":135},\"16471\":{\"m\":136,\"g\":135},\"16456\":{\"m\":136,\"g\":135},\"16382\":{\"m\":136,\"g\":135},\"14051\":{\"m\":136,\"g\":135},\"16042\":{\"m\":136,\"g\":135},\"16424\":{\"m\":136,\"g\":135},\"16460\":{\"m\":136,\"g\":135},\"16464\":{\"m\":136,\"g\":135},\"16454\":{\"m\":136,\"g\":135},\"16088\":{\"m\":136,\"g\":135},\"16387\":{\"m\":136,\"g\":135},\"16386\":{\"m\":136,\"g\":135},\"16451\":{\"m\":136,\"g\":135},\"16450\":{\"m\":136,\"g\":135},\"16448\":{\"m\":136,\"g\":135},\"16389\":{\"m\":136,\"g\":135},\"16180\":{\"m\":136,\"g\":135},\"16441\":{\"m\":136,\"g\":135},\"16444\":{\"m\":136,\"g\":135},\"16437\":{\"m\":136,\"g\":135},\"16310\":{\"m\":136,\"g\":135},\"16435\":{\"m\":136,\"g\":135},\"16433\":{\"m\":136,\"g\":135},\"16432\":{\"m\":136,\"g\":135},\"16375\":{\"m\":136,\"g\":135},\"16330\":{\"m\":136,\"g\":135},\"15836\":{\"m\":136,\"g\":135},\"16430\":{\"m\":136,\"g\":135},\"16414\":{\"m\":136,\"g\":135},\"16429\":{\"m\":136,\"g\":135},\"16427\":{\"m\":136,\"g\":135},\"16413\":{\"m\":136,\"g\":135},\"16408\":{\"m\":136,\"g\":135},\"16411\":{\"m\":136,\"g\":135},\"16334\":{\"m\":136,\"g\":135},\"16127\":{\"m\":136,\"g\":135},\"16323\":{\"m\":136,\"g\":135},\"14066\":{\"m\":136,\"g\":135},\"16405\":{\"m\":136,\"g\":135},\"16406\":{\"m\":136,\"g\":135},\"16401\":{\"m\":136,\"g\":135},\"16400\":{\"m\":136,\"g\":135},\"16347\":{\"m\":136,\"g\":135},\"16349\":{\"m\":136,\"g\":135},\"16390\":{\"m\":136,\"g\":135},
\"16392\":{\"m\":136,\"g\":135},\"16374\":{\"m\":136,\"g\":135},\"16391\":{\"m\":136,\"g\":135},\"16381\":{\"m\":136,\"g\":135},\"16376\":{\"m\":136,\"g\":135},\"16373\":{\"m\":136,\"g\":135},\"16359\":{\"m\":136,\"g\":135},\"16268\":{\"m\":136,\"g\":135},\"16365\":{\"m\":136,\"g\":135},\"16368\":{\"m\":136,\"g\":135},\"16363\":{\"m\":136,\"g\":135},\"16358\":{\"m\":136,\"g\":135},\"16343\":{\"m\":136,\"g\":135},\"16357\":{\"m\":136,\"g\":135},\"16355\":{\"m\":136,\"g\":135},\"16352\":{\"m\":136,\"g\":135},\"16064\":{\"m\":136,\"g\":135},\"15434\":{\"m\":136,\"g\":135},\"16353\":{\"m\":136,\"g\":135},\"16341\":{\"m\":136,\"g\":135},\"16351\":{\"m\":136,\"g\":135},\"16269\":{\"m\":136,\"g\":135},\"16350\":{\"m\":136,\"g\":135},\"16340\":{\"m\":136,\"g\":135},\"16339\":{\"m\":136,\"g\":135},\"16335\":{\"m\":136,\"g\":135},\"16344\":{\"m\":136,\"g\":135},\"15941\":{\"m\":136,\"g\":135},\"16342\":{\"m\":136,\"g\":135},\"16164\":{\"m\":136,\"g\":135},\"16304\":{\"m\":136,\"g\":135},\"16303\":{\"m\":136,\"g\":135},\"16338\":{\"m\":136,\"g\":135},\"16219\":{\"m\":136,\"g\":135},\"16337\":{\"m\":136,\"g\":135},\"15640\":{\"m\":136,\"g\":135},\"16311\":{\"m\":136,\"g\":135},\"16332\":{\"m\":136,\"g\":135},\"15175\":{\"m\":136,\"g\":135},\"16328\":{\"m\":136,\"g\":135},\"16009\":{\"m\":136,\"g\":135},\"13592\":{\"m\":136,\"g\":135},\"16272\":{\"m\":136,\"g\":135},\"16283\":{\"m\":136,\"g\":135},\"16257\":{\"m\":136,\"g\":135},\"16177\":{\"m\":136,\"g\":135},\"15560\":{\"m\":136,\"g\":135},\"16324\":{\"m\":136,\"g\":135},\"16317\":{\"m\":136,\"g\":135},\"16296\":{\"m\":136,\"g\":135},\"16321\":{\"m\":136,\"g\":135},\"16316\":{\"m\":136,\"g\":135},\"16313\":{\"m\":136,\"g\":135},\"16318\":{\"m\":136,\"g\":135},\"16320\":{\"m\":136,\"g\":135},\"16319\":{\"m\":136,\"g\":135},\"16314\":{\"m\":136,\"g\":135},\"16315\":{\"m\":136,\"g\":135},\"16312\":{\"m\":136,\"g\":135},\"16287\":{\"m\":136,\"g\":135},\"16277\":{\"m\":136,\"g\":135},\"16308\":{\"m\":136,\"g\":135},\"16092\":{\"m\"
:136,\"g\":135},\"16128\":{\"m\":136,\"g\":135},\"16305\":{\"m\":136,\"g\":135},\"16301\":{\"m\":136,\"g\":135},\"13959\":{\"m\":136,\"g\":135},\"16248\":{\"m\":136,\"g\":135},\"16292\":{\"m\":136,\"g\":135},\"16298\":{\"m\":136,\"g\":135},\"16227\":{\"m\":136,\"g\":135},\"16213\":{\"m\":136,\"g\":135},\"16267\":{\"m\":136,\"g\":135},\"16144\":{\"m\":136,\"g\":135},\"16222\":{\"m\":136,\"g\":135},\"15345\":{\"m\":136,\"g\":135},\"16270\":{\"m\":136,\"g\":135},\"16285\":{\"m\":136,\"g\":135},\"14636\":{\"m\":136,\"g\":135},\"16263\":{\"m\":136,\"g\":135},\"16265\":{\"m\":136,\"g\":135},\"15995\":{\"m\":136,\"g\":135},\"16162\":{\"m\":136,\"g\":135},\"16262\":{\"m\":136,\"g\":135},\"16261\":{\"m\":136,\"g\":135},\"16260\":{\"m\":136,\"g\":135},\"16251\":{\"m\":136,\"g\":135},\"16247\":{\"m\":136,\"g\":135},\"16250\":{\"m\":136,\"g\":135},\"16243\":{\"m\":136,\"g\":135},\"14085\":{\"m\":136,\"g\":135},\"16171\":{\"m\":136,\"g\":135},\"16245\":{\"m\":136,\"g\":135},\"16238\":{\"m\":136,\"g\":135},\"15814\":{\"m\":136,\"g\":135},\"16239\":{\"m\":136,\"g\":135},\"16240\":{\"m\":136,\"g\":135},\"16236\":{\"m\":136,\"g\":135},\"16233\":{\"m\":136,\"g\":135},\"16202\":{\"m\":136,\"g\":135},\"16228\":{\"m\":136,\"g\":135},\"16018\":{\"m\":136,\"g\":135},\"16195\":{\"m\":136,\"g\":135},\"16111\":{\"m\":136,\"g\":135},\"13394\":{\"m\":136,\"g\":135},\"16204\":{\"m\":136,\"g\":135},\"16214\":{\"m\":136,\"g\":135},\"16221\":{\"m\":136,\"g\":135},\"16178\":{\"m\":136,\"g\":135},\"16187\":{\"m\":136,\"g\":135},\"16209\":{\"m\":136,\"g\":135},\"15597\":{\"m\":136,\"g\":135},\"16117\":{\"m\":136,\"g\":135},\"15878\":{\"m\":136,\"g\":135},\"16200\":{\"m\":136,\"g\":135},\"15946\":{\"m\":136,\"g\":135},\"16198\":{\"m\":136,\"g\":135},\"16172\":{\"m\":136,\"g\":135},\"15575\":{\"m\":136,\"g\":135},\"16174\":{\"m\":136,\"g\":135},\"16156\":{\"m\":136,\"g\":135},\"15754\":{\"m\":136,\"g\":135},\"14920\":{\"m\":136,\"g\":135},\"16050\":{\"m\":136,\"g\":135},\"16181\":{\"m\":136,\"g\":135},
\"16183\":{\"m\":136,\"g\":135},\"16182\":{\"m\":136,\"g\":135},\"14416\":{\"m\":136,\"g\":135},\"16168\":{\"m\":136,\"g\":135},\"16175\":{\"m\":136,\"g\":135},\"16166\":{\"m\":136,\"g\":135},\"16161\":{\"m\":136,\"g\":135},\"16163\":{\"m\":136,\"g\":135},\"16165\":{\"m\":136,\"g\":135},\"16110\":{\"m\":136,\"g\":135},\"16173\":{\"m\":136,\"g\":135},\"16159\":{\"m\":136,\"g\":135},\"16150\":{\"m\":136,\"g\":135},\"16160\":{\"m\":136,\"g\":135},\"16158\":{\"m\":136,\"g\":135},\"16149\":{\"m\":136,\"g\":135},\"16124\":{\"m\":136,\"g\":135},\"15730\":{\"m\":136,\"g\":135},\"13978\":{\"m\":136,\"g\":135},\"12625\":{\"m\":136,\"g\":135},\"16154\":{\"m\":136,\"g\":135},\"18298\":\"m137\",\"18111\":\"m137\"},\"c\":{\"2ccd9fd8\":{\"m\":0,\"g\":94},\"46b7ea7c\":{\"m\":0,\"g\":94},\"70359bf3\":{\"m\":0,\"g\":94},\"01ca82d7\":{\"m\":0,\"g\":94},\"4bd8233f\":{\"m\":0,\"g\":94},\"08ab2a16\":{\"m\":0,\"g\":94},\"f652494d\":{\"m\":0,\"g\":94},\"30720e73\":{\"m\":0,\"g\":94},\"331848de\":{\"m\":0,\"g\":94},\"93eeb543\":{\"m\":0,\"g\":94},\"ead5b39f\":{\"m\":0,\"g\":94},\"22085081\":{\"m\":0,\"g\":94},\"f6d40df0\":{\"m\":0,\"g\":94},\"22ec7bc2\":{\"m\":1,\"g\":94},\"c0454b32\":{\"m\":1,\"g\":94},\"8024fc5e\":{\"m\":1,\"g\":94},\"70528762\":{\"m\":1,\"g\":94},\"71d30d6d\":{\"m\":1,\"g\":94},\"f9d72381\":{\"m\":1,\"g\":94},\"bf51ddc6\":{\"m\":1,\"g\":94},\"fd7c4792\":{\"m\":1,\"g\":94},\"c4707f1b\":{\"m\":1,\"g\":94},\"ffe4aaee\":{\"m\":1,\"g\":94},\"5b27a1dc\":{\"m\":1,\"g\":94},\"e71d4ab3\":{\"m\":1,\"g\":94},\"fbf42263\":{\"m\":1,\"g\":94},\"cc3ada98\":{\"m\":2,\"g\":94},\"a837166e\":{\"m\":2,\"g\":94},\"11f3cca6\":{\"m\":2,\"g\":94},\"ca13f3b8\":{\"m\":2,\"g\":94},\"0b2efc2a\":{\"m\":2,\"g\":94},\"f30abd09\":{\"m\":2,\"g\":94},\"40ab1f01\":{\"m\":2,\"g\":94},\"199e82a1\":{\"m\":2,\"g\":94},\"23471f9a\":{\"m\":2,\"g\":94},\"61d4c939\":{\"m\":2,\"g\":94},\"98a3e8ef\":{\"m\":2,\"g\":94},\"2b079f89\":{\"m\":2,\"g\":94},\"05b4c398\":{\"m\":2,\"g\":94},\"dafafe5b\":{\"m\":2,\"g\":94},\"
b240f751\":{\"m\":2,\"g\":94},\"501f9444\":{\"m\":2,\"g\":94},\"723f0421\":{\"m\":3,\"g\":94},\"585eabab\":{\"m\":3,\"g\":94},\"c70b3cfa\":{\"m\":4,\"g\":94},\"489796c7\":{\"m\":4,\"g\":94},\"fa7a696d\":{\"m\":4,\"g\":94},\"bef0b359\":{\"m\":4,\"g\":94},\"c6576e82\":{\"m\":4,\"g\":94},\"99258181\":{\"m\":4,\"g\":94},\"3de54a1b\":{\"m\":4,\"g\":94},\"7358fa64\":{\"m\":4,\"g\":94},\"9a16fea0\":{\"m\":4,\"g\":94},\"9e037c82\":{\"m\":4,\"g\":94},\"9076386d\":{\"m\":4,\"g\":94},\"959c4174\":{\"m\":4,\"g\":94},\"94e05770\":{\"m\":4,\"g\":94},\"63e97e5e\":{\"m\":4,\"g\":94},\"e08bca28\":{\"m\":4,\"g\":94},\"cd3ccb2e\":{\"m\":4,\"g\":94},\"3f5c2f4c\":{\"m\":4,\"g\":94},\"007eeb4e\":{\"m\":4,\"g\":94},\"e8f2b155\":{\"m\":4,\"g\":94},\"6dceab4d\":{\"m\":5,\"g\":94},\"a49dc52b\":{\"m\":6,\"g\":94},\"873d0e85\":{\"m\":6,\"g\":94},\"1d0fbe8e\":{\"m\":6,\"g\":94},\"97aa9b32\":{\"m\":6,\"g\":94},\"06175286\":{\"m\":6,\"g\":94},\"4ea92f83\":{\"m\":6,\"g\":94},\"6b0af285\":{\"m\":6,\"g\":94},\"6f560c76\":{\"m\":6,\"g\":94},\"cd687233\":{\"m\":6,\"g\":94},\"81561f8e\":{\"m\":6,\"g\":94},\"3a581e99\":{\"m\":6,\"g\":94},\"0147f940\":{\"m\":6,\"g\":94},\"23950056\":{\"m\":6,\"g\":94},\"93414c82\":{\"m\":6,\"g\":94},\"ed7c7eca\":{\"m\":6,\"g\":94},\"0c457bae\":{\"m\":6,\"g\":94},\"d3fc86a4\":{\"m\":6,\"g\":94},\"01ee0fbc\":{\"m\":6,\"g\":94},\"711d3435\":{\"m\":6,\"g\":94},\"f6bfe3aa\":{\"m\":7,\"g\":94},\"e095b162\":{\"m\":7,\"g\":94},\"67be11c7\":{\"m\":7,\"g\":94},\"cd8c3ccd\":{\"m\":7,\"g\":94},\"9c121f2a\":{\"m\":7,\"g\":94},\"03e04b23\":{\"m\":7,\"g\":94},\"86442530\":{\"m\":7,\"g\":94},\"79cb018e\":{\"m\":7,\"g\":94},\"c7af9f73\":{\"m\":7,\"g\":94},\"876db8dc\":{\"m\":7,\"g\":94},\"ad82bac6\":{\"m\":7,\"g\":94},\"71b54eea\":{\"m\":7,\"g\":94},\"74b3bfaa\":{\"m\":7,\"g\":94},\"4a634cf6\":{\"m\":7,\"g\":94},\"624b21e7\":{\"m\":8,\"g\":94},\"c51020cf\":{\"m\":8,\"g\":94},\"50afed4e\":{\"m\":8,\"g\":94},\"4d303c4f\":{\"m\":8,\"g\":94},\"37b42297\":{\"m\":8,\"g\":94},\"cba50273\":{\"m\
":8,\"g\":94},\"a6aa46dd\":{\"m\":8,\"g\":94},\"405f26b0\":{\"m\":8,\"g\":94},\"b1a3a454\":{\"m\":8,\"g\":94},\"79e6b84b\":{\"m\":8,\"g\":94},\"26c34941\":{\"m\":8,\"g\":94},\"cb8e1982\":{\"m\":8,\"g\":94},\"23f05005\":{\"m\":8,\"g\":94},\"a7334aee\":{\"m\":8,\"g\":94},\"ee1df26a\":{\"m\":8,\"g\":94},\"3ae78a09\":{\"m\":8,\"g\":94},\"ccbe1e67\":{\"m\":8,\"g\":94},\"e2bf732b\":{\"m\":8,\"g\":94},\"322421fa\":{\"m\":8,\"g\":94},\"8ff870bf\":{\"m\":8,\"g\":94},\"26f0bedc\":{\"m\":8,\"g\":94},\"82fa69b3\":{\"m\":8,\"g\":94},\"8fb7459e\":{\"m\":8,\"g\":94},\"bb3a3b66\":{\"m\":8,\"g\":94},\"45d6592d\":{\"m\":8,\"g\":94},\"4aa5dd2c\":{\"m\":9,\"g\":94},\"13662fd5\":{\"m\":9,\"g\":94},\"d5ae2eba\":{\"m\":9,\"g\":94},\"1b355479\":{\"m\":9,\"g\":94},\"faba293a\":{\"m\":9,\"g\":94},\"89885b31\":{\"m\":9,\"g\":94},\"64fe3115\":{\"m\":9,\"g\":94},\"a7ace9c8\":{\"m\":9,\"g\":94},\"a833de05\":{\"m\":9,\"g\":94},\"30d67b2b\":{\"m\":9,\"g\":94},\"b0b722ee\":{\"m\":9,\"g\":94},\"01b07ea3\":{\"m\":9,\"g\":94},\"dfb13ac4\":{\"m\":9,\"g\":94},\"ec90b9c0\":{\"m\":9,\"g\":94},\"9759d927\":{\"m\":9,\"g\":94},\"8d0a7fae\":{\"m\":9,\"g\":94},\"c4e9ebe3\":{\"m\":9,\"g\":94},\"3c2c5869\":{\"m\":9,\"g\":94},\"4cb9aaed\":{\"m\":9,\"g\":94},\"9de9a468\":{\"m\":9,\"g\":94},\"ce3b2610\":{\"m\":9,\"g\":94},\"91e03633\":{\"m\":9,\"g\":94},\"2a74748b\":{\"m\":9,\"g\":94},\"63ba630b\":{\"m\":9,\"g\":94},\"6493256b\":{\"m\":9,\"g\":94},\"06008bc2\":{\"m\":9,\"g\":94},\"bb824da4\":{\"m\":9,\"g\":94},\"93121324\":{\"m\":9,\"g\":94},\"c97fdae4\":{\"m\":9,\"g\":94},\"51104cd4\":{\"m\":10,\"g\":94},\"e2b2f0a2\":{\"m\":10,\"g\":94},\"b57abe16\":{\"m\":10,\"g\":94},\"e57f0792\":{\"m\":10,\"g\":94},\"08df63a6\":{\"m\":10,\"g\":94},\"77835756\":{\"m\":10,\"g\":94},\"ed315799\":{\"m\":10,\"g\":94},\"92e2d74f\":{\"m\":10,\"g\":94},\"d9b3b018\":{\"m\":10,\"g\":94},\"745ea007\":{\"m\":10,\"g\":94},\"ad1dd746\":{\"m\":10,\"g\":94},\"eb4308c4\":{\"m\":10,\"g\":94},\"b2eb0805\":{\"m\":10,\"g\":94},\"72bb3443\":{\"m\":1
1,\"g\":94},\"2d580e7a\":{\"m\":11,\"g\":94},\"3fc97f67\":{\"m\":11,\"g\":94},\"abc548c7\":{\"m\":11,\"g\":94},\"aee4f523\":{\"m\":11,\"g\":94},\"7023f413\":{\"m\":11,\"g\":94},\"09deb20d\":{\"m\":11,\"g\":94},\"33b242df\":{\"m\":11,\"g\":94},\"a511a2d0\":{\"m\":11,\"g\":94},\"6ec65f45\":{\"m\":11,\"g\":94},\"e2c31fca\":{\"m\":11,\"g\":94},\"d5de20a3\":{\"m\":11,\"g\":94},\"4a1c6ae2\":{\"m\":11,\"g\":94},\"14522e6a\":{\"m\":11,\"g\":94},\"183df472\":{\"m\":11,\"g\":94},\"5c5aba59\":{\"m\":11,\"g\":94},\"ba67101f\":{\"m\":11,\"g\":94},\"95c4e0df\":{\"m\":11,\"g\":94},\"19818b9c\":{\"m\":11,\"g\":94},\"9216b106\":{\"m\":11,\"g\":94},\"da19434c\":{\"m\":11,\"g\":94},\"150d7020\":{\"m\":11,\"g\":94},\"9acc6e35\":{\"m\":11,\"g\":94},\"cf9d8efd\":{\"m\":11,\"g\":94},\"1bf1cf19\":{\"m\":11,\"g\":94},\"e822e590\":{\"m\":11,\"g\":94},\"ca4f1ab8\":{\"m\":11,\"g\":94},\"2b6d9991\":{\"m\":11,\"g\":94},\"65501a9c\":{\"m\":11,\"g\":94},\"db611066\":{\"m\":11,\"g\":94},\"c93293c5\":{\"m\":11,\"g\":94},\"62b3812b\":{\"m\":11,\"g\":94},\"550a4f78\":{\"m\":11,\"g\":94},\"ff99c38a\":{\"m\":11,\"g\":94},\"c9de3e16\":{\"m\":11,\"g\":94},\"ed27a6b9\":{\"m\":11,\"g\":94},\"463c6632\":{\"m\":11,\"g\":94},\"b0890631\":{\"m\":11,\"g\":94},\"cb389c91\":{\"m\":11,\"g\":94},\"eddaa2b5\":{\"m\":11,\"g\":94},\"2af565b3\":{\"m\":11,\"g\":94},\"3842eba5\":{\"m\":11,\"g\":94},\"24e59f53\":{\"m\":11,\"g\":94},\"75235419\":{\"m\":11,\"g\":94},\"64ee9c03\":{\"m\":11,\"g\":94},\"30d17840\":{\"m\":11,\"g\":94},\"ce216c80\":{\"m\":11,\"g\":94},\"e0ae5d42\":{\"m\":12,\"g\":94},\"32de16ce\":{\"m\":12,\"g\":94},\"0992d85f\":{\"m\":12,\"g\":94},\"5dc55a5f\":{\"m\":12,\"g\":94},\"4231a42f\":{\"m\":12,\"g\":94},\"455c9ccc\":{\"m\":12,\"g\":94},\"39191c85\":{\"m\":12,\"g\":94},\"562b8857\":{\"m\":12,\"g\":94},\"04c0b214\":{\"m\":12,\"g\":94},\"6e09cf6a\":{\"m\":12,\"g\":94},\"e8a2327d\":{\"m\":13,\"g\":94},\"91f93f14\":{\"m\":13,\"g\":94},\"f70f7258\":{\"m\":13,\"g\":94},\"c0ae70c8\":{\"m\":13,\"g\":94},\"87260b
7b\":{\"m\":13,\"g\":94},\"651a23ee\":{\"m\":13,\"g\":94},\"bf3e271f\":{\"m\":13,\"g\":94},\"3bc01ac1\":{\"m\":13,\"g\":94},\"9f009261\":{\"m\":13,\"g\":94},\"159cc741\":{\"m\":13,\"g\":94},\"7d1ebc2d\":{\"m\":13,\"g\":94},\"83525a1d\":{\"m\":13,\"g\":94},\"80a33ce8\":{\"m\":13,\"g\":94},\"1a57e416\":{\"m\":13,\"g\":94},\"adc97426\":{\"m\":13,\"g\":94},\"0463f7fb\":{\"m\":13,\"g\":94},\"565d7274\":{\"m\":13,\"g\":94},\"09de730d\":{\"m\":13,\"g\":94},\"55c16436\":{\"m\":13,\"g\":94},\"2b605ab1\":{\"m\":13,\"g\":94},\"947bda73\":{\"m\":13,\"g\":94},\"f06e90c2\":{\"m\":13,\"g\":94},\"2cea6146\":{\"m\":13,\"g\":94},\"44c998fc\":{\"m\":13,\"g\":94},\"3167d8da\":{\"m\":13,\"g\":94},\"0fafc560\":{\"m\":13,\"g\":94},\"19d2135c\":{\"m\":13,\"g\":94},\"ced77c66\":{\"m\":13,\"g\":94},\"8dbdc018\":{\"m\":13,\"g\":94},\"3e684be7\":{\"m\":13,\"g\":94},\"ec380dfd\":{\"m\":13,\"g\":94},\"5b647543\":{\"m\":13,\"g\":94},\"8210ec60\":{\"m\":13,\"g\":94},\"5be9eb8a\":{\"m\":13,\"g\":94},\"c05956e5\":{\"m\":13,\"g\":94},\"d75dc20f\":{\"m\":13,\"g\":94},\"690d162d\":{\"m\":13,\"g\":94},\"664287b2\":{\"m\":13,\"g\":94},\"2f11936f\":{\"m\":14,\"g\":94},\"63fbef98\":{\"m\":14,\"g\":94},\"2a754e57\":{\"m\":14,\"g\":94},\"96c503eb\":{\"m\":14,\"g\":94},\"441cca77\":{\"m\":14,\"g\":94},\"c7709d3a\":{\"m\":14,\"g\":94},\"9380f50f\":{\"m\":14,\"g\":94},\"95dc093b\":{\"m\":14,\"g\":94},\"d9ac6392\":{\"m\":14,\"g\":94},\"26294b2f\":{\"m\":14,\"g\":94},\"75b31a2a\":{\"m\":14,\"g\":94},\"11616fc6\":{\"m\":14,\"g\":94},\"9ce89bc1\":{\"m\":14,\"g\":94},\"badf3fa0\":{\"m\":14,\"g\":94},\"945aa9be\":{\"m\":14,\"g\":94},\"2e6e62e1\":{\"m\":14,\"g\":94},\"a385ee27\":{\"m\":14,\"g\":94},\"eb1ae6ae\":{\"m\":14,\"g\":94},\"2187f362\":{\"m\":14,\"g\":94},\"9465b668\":{\"m\":14,\"g\":94},\"05471f21\":{\"m\":14,\"g\":94},\"1fa15099\":{\"m\":14,\"g\":94},\"303ef888\":{\"m\":14,\"g\":94},\"92cb93f3\":{\"m\":14,\"g\":94},\"e94e60d6\":{\"m\":14,\"g\":94},\"d2f8bfb2\":{\"m\":14,\"g\":94},\"b7e2f800\":{\"m\":14,\"g\"
:94},\"09593e9b\":{\"m\":14,\"g\":94},\"53a7ebd8\":{\"m\":14,\"g\":94},\"ad5f04d6\":{\"m\":14,\"g\":94},\"bbec01c9\":{\"m\":14,\"g\":94},\"40e53d65\":{\"m\":14,\"g\":94},\"fb9296f0\":{\"m\":14,\"g\":94},\"1374334d\":{\"m\":14,\"g\":94},\"94aead9e\":{\"m\":14,\"g\":94},\"9c902b19\":{\"m\":14,\"g\":94},\"111991fe\":{\"m\":14,\"g\":94},\"a8c787d2\":{\"m\":14,\"g\":94},\"5f283991\":{\"m\":14,\"g\":94},\"b6667a53\":{\"m\":14,\"g\":94},\"542bc733\":{\"m\":14,\"g\":94},\"f6dbd240\":{\"m\":14,\"g\":94},\"ad872feb\":{\"m\":15,\"g\":94},\"da2e5d65\":{\"m\":15,\"g\":94},\"ce62dc73\":{\"m\":15,\"g\":94},\"02b72586\":{\"m\":15,\"g\":94},\"d557e9f3\":{\"m\":15,\"g\":94},\"740c46a1\":{\"m\":15,\"g\":94},\"b3868722\":{\"m\":15,\"g\":94},\"710f614e\":{\"m\":15,\"g\":94},\"f25b76c0\":{\"m\":15,\"g\":94},\"f4e885b7\":{\"m\":15,\"g\":94},\"0877f1e7\":{\"m\":15,\"g\":94},\"5304b4ef\":{\"m\":15,\"g\":94},\"26908d95\":{\"m\":15,\"g\":94},\"c0982ac5\":{\"m\":15,\"g\":94},\"dc1b8bcf\":{\"m\":15,\"g\":94},\"5a57b8ad\":{\"m\":15,\"g\":94},\"d737da5f\":{\"m\":15,\"g\":94},\"ac113887\":{\"m\":15,\"g\":94},\"dc8cef1d\":{\"m\":15,\"g\":94},\"5d264a90\":{\"m\":16,\"g\":94},\"5949b1ca\":{\"m\":16,\"g\":94},\"0feca02d\":{\"m\":16,\"g\":94},\"10143e1a\":{\"m\":16,\"g\":94},\"65c65776\":{\"m\":16,\"g\":94},\"66581596\":{\"m\":16,\"g\":94},\"396a6924\":{\"m\":16,\"g\":94},\"af4e7910\":{\"m\":16,\"g\":94},\"519e20cf\":{\"m\":16,\"g\":94},\"d9a69029\":{\"m\":16,\"g\":94},\"56f5fc4a\":{\"m\":17,\"g\":94},\"6a2941f4\":{\"m\":17,\"g\":94},\"5ac8b806\":{\"m\":17,\"g\":94},\"bae9541e\":{\"m\":17,\"g\":94},\"a56858ba\":{\"m\":17,\"g\":94},\"564a898a\":{\"m\":17,\"g\":94},\"2b4c6462\":{\"m\":18,\"g\":94},\"f424e76d\":{\"m\":18,\"g\":94},\"490a1f39\":{\"m\":18,\"g\":94},\"06487f12\":{\"m\":18,\"g\":94},\"39c57317\":{\"m\":18,\"g\":94},\"9592a1f3\":{\"m\":18,\"g\":94},\"35759efa\":{\"m\":18,\"g\":94},\"8f4b1559\":{\"m\":18,\"g\":94},\"e3046ea3\":{\"m\":18,\"g\":94},\"49c5e0ec\":{\"m\":18,\"g\":94},\"ec2150b2\":{\
"m\":18,\"g\":94},\"7620cd37\":{\"m\":18,\"g\":94},\"50a53887\":{\"m\":18,\"g\":94},\"11c8efff\":{\"m\":18,\"g\":94},\"e87c7fd5\":{\"m\":18,\"g\":94},\"630479c3\":{\"m\":18,\"g\":94},\"51fda143\":{\"m\":18,\"g\":94},\"dc4e4a6a\":{\"m\":18,\"g\":94},\"2d96da81\":{\"m\":18,\"g\":94},\"c126a6cc\":{\"m\":18,\"g\":94},\"ac971ff6\":{\"m\":18,\"g\":94},\"e1792cca\":{\"m\":18,\"g\":94},\"1b7adbb5\":{\"m\":18,\"g\":94},\"a9ef49c1\":{\"m\":18,\"g\":94},\"21ba3a88\":{\"m\":18,\"g\":94},\"9c5cac24\":{\"m\":18,\"g\":94},\"5960a6e5\":{\"m\":18,\"g\":94},\"b050d928\":{\"m\":18,\"g\":94},\"6a4dc996\":{\"m\":18,\"g\":94},\"d774acad\":{\"m\":18,\"g\":94},\"d93388da\":{\"m\":18,\"g\":94},\"476584cb\":{\"m\":18,\"g\":94},\"abd5385a\":{\"m\":18,\"g\":94},\"3de2f30a\":{\"m\":18,\"g\":94},\"4efcc59d\":{\"m\":18,\"g\":94},\"2e341cd4\":{\"m\":18,\"g\":94},\"a8552cb1\":{\"m\":18,\"g\":94},\"a470e60c\":{\"m\":18,\"g\":94},\"5f90e076\":{\"m\":18,\"g\":94},\"8832ecb1\":{\"m\":18,\"g\":94},\"5ff60eda\":{\"m\":18,\"g\":94},\"c1930022\":{\"m\":18,\"g\":94},\"fe3be159\":{\"m\":18,\"g\":94},\"0aa189f1\":{\"m\":18,\"g\":94},\"f6b29f69\":{\"m\":18,\"g\":94},\"c9ee3d35\":{\"m\":18,\"g\":94},\"41d1f677\":{\"m\":18,\"g\":94},\"444a0244\":{\"m\":19,\"g\":94},\"fa7ccb33\":{\"m\":19,\"g\":94},\"26868443\":{\"m\":19,\"g\":94},\"824a77d0\":{\"m\":19,\"g\":94},\"cf99eab7\":{\"m\":19,\"g\":94},\"9fdea29d\":{\"m\":19,\"g\":94},\"df7c4c19\":{\"m\":19,\"g\":94},\"c3f1aac8\":{\"m\":19,\"g\":94},\"d198791f\":{\"m\":19,\"g\":94},\"c07526e4\":{\"m\":19,\"g\":94},\"7b597475\":{\"m\":19,\"g\":94},\"5303c1ed\":{\"m\":19,\"g\":94},\"65bd1338\":{\"m\":19,\"g\":94},\"eedc12e1\":{\"m\":19,\"g\":94},\"5a4ef2b5\":{\"m\":19,\"g\":94},\"9dab947d\":{\"m\":19,\"g\":94},\"33ee97b0\":{\"m\":19,\"g\":94},\"6a846bb1\":{\"m\":19,\"g\":94},\"0fdb3127\":{\"m\":19,\"g\":94},\"5ad033a0\":{\"m\":19,\"g\":94},\"77e592e8\":{\"m\":19,\"g\":94},\"caaad53b\":{\"m\":19,\"g\":94},\"69d19188\":{\"m\":19,\"g\":94},\"4b4a67f8\":{\"m\":19,\"g\":94},\"
0ac94c36\":{\"m\":19,\"g\":94},\"459abad2\":{\"m\":20,\"g\":94},\"30d8e130\":{\"m\":20,\"g\":94},\"08a3bd19\":{\"m\":20,\"g\":94},\"321a963b\":{\"m\":20,\"g\":94},\"e17deb27\":{\"m\":20,\"g\":94},\"2d3ae4e1\":{\"m\":20,\"g\":94},\"75f4ccb7\":{\"m\":20,\"g\":94},\"83d2b30d\":{\"m\":20,\"g\":94},\"4367f4bb\":{\"m\":20,\"g\":94},\"00e4baa7\":{\"m\":20,\"g\":94},\"4cd64b8e\":{\"m\":20,\"g\":94},\"01d66ae2\":{\"m\":20,\"g\":94},\"a523a3c1\":{\"m\":20,\"g\":94},\"9f94728f\":{\"m\":20,\"g\":94},\"1a491d00\":{\"m\":21,\"g\":94},\"8fbba3de\":{\"m\":21,\"g\":94},\"ae0f6130\":{\"m\":21,\"g\":94},\"60105897\":{\"m\":21,\"g\":94},\"926ac01b\":{\"m\":21,\"g\":94},\"25c881a0\":{\"m\":21,\"g\":94},\"04ec6ba2\":{\"m\":21,\"g\":94},\"d63f13c1\":{\"m\":21,\"g\":94},\"fded6744\":{\"m\":21,\"g\":94},\"6e453940\":{\"m\":21,\"g\":94},\"97e0f7d2\":{\"m\":21,\"g\":94},\"d5146bae\":{\"m\":21,\"g\":94},\"5bd06b45\":{\"m\":22,\"g\":94},\"9a611827\":{\"m\":22,\"g\":94},\"eeb24821\":{\"m\":22,\"g\":94},\"3e455b01\":{\"m\":22,\"g\":94},\"8628ab9c\":{\"m\":22,\"g\":94},\"1b77670f\":{\"m\":22,\"g\":94},\"768e05d0\":{\"m\":22,\"g\":94},\"01fbb11b\":{\"m\":22,\"g\":94},\"05d216da\":{\"m\":22,\"g\":94},\"6b32bb1c\":{\"m\":22,\"g\":94},\"40facad5\":{\"m\":22,\"g\":94},\"da504445\":{\"m\":22,\"g\":94},\"252e0f7b\":{\"m\":22,\"g\":94},\"7f6f2f0f\":{\"m\":22,\"g\":94},\"7802df1e\":{\"m\":22,\"g\":94},\"bc1154c3\":{\"m\":23,\"g\":94},\"752e6430\":{\"m\":23,\"g\":94},\"30db99b3\":{\"m\":23,\"g\":94},\"0a409bd4\":{\"m\":23,\"g\":94},\"e4db4e5b\":{\"m\":23,\"g\":94},\"bbc07c41\":{\"m\":23,\"g\":94},\"a036d419\":{\"m\":23,\"g\":94},\"f95e6617\":{\"m\":23,\"g\":94},\"de854fb5\":{\"m\":23,\"g\":94},\"f64b2a9b\":{\"m\":23,\"g\":94},\"9f95dcc6\":{\"m\":23,\"g\":94},\"0736b270\":{\"m\":23,\"g\":94},\"3fdab919\":{\"m\":23,\"g\":94},\"ba29504b\":{\"m\":23,\"g\":94},\"a72342f1\":{\"m\":23,\"g\":94},\"c3c74bf8\":{\"m\":23,\"g\":94},\"d9fccfef\":{\"m\":23,\"g\":94},\"679ebcbb\":{\"m\":23,\"g\":94},\"1edd4e07\":{\"m\":24
,\"g\":94},\"62c673c4\":{\"m\":24,\"g\":94},\"377c5dc9\":{\"m\":24,\"g\":94},\"f52eda35\":{\"m\":24,\"g\":94},\"b579ecf0\":{\"m\":24,\"g\":94},\"e7487b08\":{\"m\":24,\"g\":94},\"ae5c0fc4\":{\"m\":24,\"g\":94},\"a30d5d75\":{\"m\":24,\"g\":94},\"17af39c5\":{\"m\":24,\"g\":94},\"daf593a3\":{\"m\":24,\"g\":94},\"bece265f\":{\"m\":24,\"g\":94},\"cdcbde5f\":{\"m\":24,\"g\":94},\"21e22b9e\":{\"m\":24,\"g\":94},\"a50c8a14\":{\"m\":24,\"g\":94},\"db6089e6\":{\"m\":24,\"g\":94},\"3520f75f\":{\"m\":24,\"g\":94},\"c8e9fed8\":{\"m\":24,\"g\":94},\"084fa54d\":{\"m\":24,\"g\":94},\"eba458bd\":{\"m\":24,\"g\":94},\"3d1cb0af\":{\"m\":24,\"g\":94},\"7d352b4f\":{\"m\":24,\"g\":94},\"87064015\":{\"m\":24,\"g\":94},\"7cd4f244\":{\"m\":24,\"g\":94},\"98111fbe\":{\"m\":24,\"g\":94},\"2ec39ab7\":{\"m\":24,\"g\":94},\"8f6274c8\":{\"m\":24,\"g\":94},\"325a06c2\":{\"m\":24,\"g\":94},\"79f81629\":{\"m\":24,\"g\":94},\"b688fd85\":{\"m\":24,\"g\":94},\"5bd89924\":{\"m\":24,\"g\":94},\"8d908a93\":{\"m\":24,\"g\":94},\"dd7e8b94\":{\"m\":24,\"g\":94},\"1f013d64\":{\"m\":24,\"g\":94},\"628e1fa7\":{\"m\":24,\"g\":94},\"c71880f8\":{\"m\":24,\"g\":94},\"bcb6611a\":{\"m\":24,\"g\":94},\"fa2aa0db\":{\"m\":24,\"g\":94},\"6a387a69\":{\"m\":24,\"g\":94},\"27f5ce0a\":{\"m\":24,\"g\":94},\"94862579\":{\"m\":24,\"g\":94},\"68e52626\":{\"m\":24,\"g\":94},\"e4d3333c\":{\"m\":25,\"g\":94},\"6f221d4c\":{\"m\":25,\"g\":94},\"aba6f51f\":{\"m\":25,\"g\":94},\"7f6c690b\":{\"m\":25,\"g\":94},\"40e6f513\":{\"m\":25,\"g\":94},\"40756776\":{\"m\":25,\"g\":94},\"9e8d2c7f\":{\"m\":25,\"g\":94},\"c9bff5fc\":{\"m\":25,\"g\":94},\"b04444ac\":{\"m\":25,\"g\":94},\"3d617a21\":{\"m\":25,\"g\":94},\"c020f9ce\":{\"m\":25,\"g\":94},\"ca600e8c\":{\"m\":25,\"g\":94},\"0c0c8137\":{\"m\":25,\"g\":94},\"90286d85\":{\"m\":25,\"g\":94},\"5e7dd984\":{\"m\":25,\"g\":94},\"bc3eaac2\":{\"m\":25,\"g\":94},\"a78d98de\":{\"m\":25,\"g\":94},\"7d5ed7c6\":{\"m\":25,\"g\":94},\"a6c7ebbb\":{\"m\":25,\"g\":94},\"bb0501c0\":{\"m\":25,\"g\":94},\"6b0f2e9
0\":{\"m\":25,\"g\":94},\"30a9b2ef\":{\"m\":26,\"g\":94},\"3cadecf0\":{\"m\":26,\"g\":94},\"e90e3a50\":{\"m\":26,\"g\":94},\"fbd6b94d\":{\"m\":26,\"g\":94},\"4c8093c8\":{\"m\":26,\"g\":94},\"ae7ee01a\":{\"m\":26,\"g\":94},\"76e59088\":{\"m\":26,\"g\":94},\"12ce3bef\":{\"m\":26,\"g\":94},\"4013a4e1\":{\"m\":26,\"g\":94},\"60340a36\":{\"m\":26,\"g\":94},\"70c78cfb\":{\"m\":26,\"g\":94},\"72b6ea88\":{\"m\":26,\"g\":94},\"b906c015\":{\"m\":27,\"g\":94},\"9319cd13\":{\"m\":27,\"g\":94},\"046c2b33\":{\"m\":27,\"g\":94},\"6b8f66ef\":{\"m\":27,\"g\":94},\"7937a886\":{\"m\":27,\"g\":94},\"2e218b9e\":{\"m\":27,\"g\":94},\"141e8c71\":{\"m\":28,\"g\":94},\"d53dcf9c\":{\"m\":28,\"g\":94},\"bb66cc4c\":{\"m\":28,\"g\":94},\"975adb80\":{\"m\":28,\"g\":94},\"0d4f3a9f\":{\"m\":28,\"g\":94},\"afd411d0\":{\"m\":28,\"g\":94},\"e1eae1fd\":{\"m\":28,\"g\":94},\"f4d9953d\":{\"m\":28,\"g\":94},\"4f005250\":{\"m\":28,\"g\":94},\"995af5a5\":{\"m\":28,\"g\":94},\"53985645\":{\"m\":28,\"g\":94},\"70cc0749\":{\"m\":28,\"g\":94},\"7dd8a7e6\":{\"m\":28,\"g\":94},\"947402c8\":{\"m\":28,\"g\":94},\"8c5382e6\":{\"m\":28,\"g\":94},\"001b0bdd\":{\"m\":28,\"g\":94},\"dc9d06d8\":{\"m\":29,\"g\":94},\"c31f084c\":{\"m\":29,\"g\":94},\"a01ddd96\":{\"m\":29,\"g\":94},\"7fa54a1a\":{\"m\":29,\"g\":94},\"05abd126\":{\"m\":29,\"g\":94},\"5f6fa04a\":{\"m\":29,\"g\":94},\"58a09708\":{\"m\":29,\"g\":94},\"ff68ae85\":{\"m\":29,\"g\":94},\"795eab6d\":{\"m\":29,\"g\":94},\"41bb1ab1\":{\"m\":29,\"g\":94},\"87e8c090\":{\"m\":29,\"g\":94},\"ad56e684\":{\"m\":29,\"g\":94},\"ffb15744\":{\"m\":29,\"g\":94},\"a9c833d5\":{\"m\":29,\"g\":94},\"94e01151\":{\"m\":29,\"g\":94},\"b216a545\":{\"m\":29,\"g\":94},\"fde83405\":{\"m\":29,\"g\":94},\"fd7926e4\":{\"m\":29,\"g\":94},\"399cad91\":{\"m\":29,\"g\":94},\"0a4f5f9b\":{\"m\":29,\"g\":94},\"3bc99e6f\":{\"m\":29,\"g\":94},\"ebf69964\":{\"m\":29,\"g\":94},\"b0ad0c1b\":{\"m\":30,\"g\":94},\"c877292c\":{\"m\":30,\"g\":94},\"0c1c72a0\":{\"m\":30,\"g\":94},\"41598e0d\":{\"m\":30,\"g\":
94},\"89f23a51\":{\"m\":30,\"g\":94},\"cb99ba4f\":{\"m\":30,\"g\":94},\"32f61443\":{\"m\":30,\"g\":94},\"fb1f28cb\":{\"m\":30,\"g\":94},\"fb7421db\":{\"m\":30,\"g\":94},\"14b64930\":{\"m\":30,\"g\":94},\"82076370\":{\"m\":30,\"g\":94},\"7de60345\":{\"m\":30,\"g\":94},\"d84c5e70\":{\"m\":30,\"g\":94},\"d7854120\":{\"m\":30,\"g\":94},\"7b6a5332\":{\"m\":30,\"g\":94},\"4080e822\":{\"m\":30,\"g\":94},\"c245b789\":{\"m\":30,\"g\":94},\"9dae4078\":{\"m\":30,\"g\":94},\"fcc0f5ed\":{\"m\":30,\"g\":94},\"a97df791\":{\"m\":30,\"g\":94},\"33d61356\":{\"m\":30,\"g\":94},\"94752ac8\":{\"m\":30,\"g\":94},\"43fbb6d9\":{\"m\":30,\"g\":94},\"54fb1c80\":{\"m\":30,\"g\":94},\"b68c4c07\":{\"m\":30,\"g\":94},\"e712837d\":{\"m\":30,\"g\":94},\"7599bade\":{\"m\":30,\"g\":94},\"62757db6\":{\"m\":30,\"g\":94},\"73fa2d49\":{\"m\":30,\"g\":94},\"61728884\":{\"m\":30,\"g\":94},\"9cf0a5ba\":{\"m\":30,\"g\":94},\"b16e856f\":{\"m\":30,\"g\":94},\"05c50a82\":{\"m\":30,\"g\":94},\"b568df5d\":{\"m\":30,\"g\":94},\"10bca45b\":{\"m\":30,\"g\":94},\"b91a4cb1\":{\"m\":30,\"g\":94},\"95a28019\":{\"m\":30,\"g\":94},\"e040a245\":{\"m\":30,\"g\":94},\"9f662501\":{\"m\":30,\"g\":94},\"ab787594\":{\"m\":30,\"g\":94},\"228cf475\":{\"m\":30,\"g\":94},\"3a79613c\":{\"m\":30,\"g\":94},\"1ac304ee\":{\"m\":30,\"g\":94},\"20a4f927\":{\"m\":30,\"g\":94},\"0de7c2d0\":{\"m\":30,\"g\":94},\"6ed4e3b8\":{\"m\":30,\"g\":94},\"00023d62\":{\"m\":30,\"g\":94},\"c62d560c\":{\"m\":30,\"g\":94},\"2b8257f3\":{\"m\":30,\"g\":94},\"7623091d\":{\"m\":30,\"g\":94},\"f724f1f1\":{\"m\":30,\"g\":94},\"6db27f7b\":{\"m\":30,\"g\":94},\"4d929107\":{\"m\":30,\"g\":94},\"fbe0c818\":{\"m\":30,\"g\":94},\"5bd95374\":{\"m\":31,\"g\":94},\"0cb099e2\":{\"m\":31,\"g\":94},\"93d4e354\":{\"m\":31,\"g\":94},\"9195d136\":{\"m\":31,\"g\":94},\"14cb544d\":{\"m\":31,\"g\":94},\"e86b1ccb\":{\"m\":31,\"g\":94},\"8d2d876f\":{\"m\":31,\"g\":94},\"326df4ba\":{\"m\":31,\"g\":94},\"6767e222\":{\"m\":31,\"g\":94},\"73cf6834\":{\"m\":31,\"g\":94},\"1c2b5f52\":{\"
m\":31,\"g\":94},\"96a2093e\":{\"m\":31,\"g\":94},\"a34dd86a\":{\"m\":31,\"g\":94},\"67c0d832\":{\"m\":31,\"g\":94},\"a59636bb\":{\"m\":31,\"g\":94},\"fe502432\":{\"m\":31,\"g\":94},\"f14569f6\":{\"m\":31,\"g\":94},\"8f790ac1\":{\"m\":31,\"g\":94},\"616b59f3\":{\"m\":31,\"g\":94},\"c8423ca3\":{\"m\":31,\"g\":94},\"e205527c\":{\"m\":31,\"g\":94},\"0909bb0d\":{\"m\":31,\"g\":94},\"ad3e4f16\":{\"m\":31,\"g\":94},\"312e8492\":{\"m\":31,\"g\":94},\"95f5fbf1\":{\"m\":31,\"g\":94},\"cebd78d8\":{\"m\":31,\"g\":94},\"0076f115\":{\"m\":31,\"g\":94},\"f7fb68d2\":{\"m\":31,\"g\":94},\"396a13e6\":{\"m\":31,\"g\":94},\"65915f9f\":{\"m\":31,\"g\":94},\"162f3ccb\":{\"m\":31,\"g\":94},\"65e89bae\":{\"m\":31,\"g\":94},\"6a38efa8\":{\"m\":31,\"g\":94},\"c5fe11a8\":{\"m\":32,\"g\":94},\"75ce37f4\":{\"m\":32,\"g\":94},\"97589a60\":{\"m\":32,\"g\":94},\"632d506d\":{\"m\":32,\"g\":94},\"3579162a\":{\"m\":32,\"g\":94},\"7514b9f8\":{\"m\":32,\"g\":94},\"158e8f1e\":{\"m\":32,\"g\":94},\"d3efcb39\":{\"m\":32,\"g\":94},\"2c615d12\":{\"m\":32,\"g\":94},\"61bb223e\":{\"m\":32,\"g\":94},\"15f1a49d\":{\"m\":32,\"g\":94},\"308d0240\":{\"m\":32,\"g\":94},\"ab4990e4\":{\"m\":32,\"g\":94},\"90227800\":{\"m\":32,\"g\":94},\"30b4f771\":{\"m\":32,\"g\":94},\"66e7dcaf\":{\"m\":32,\"g\":94},\"bc4c7a35\":{\"m\":32,\"g\":94},\"1cb4da5c\":{\"m\":32,\"g\":94},\"e61d13ac\":{\"m\":32,\"g\":94},\"b20daf98\":{\"m\":32,\"g\":94},\"f6af3a65\":{\"m\":32,\"g\":94},\"c9064e6f\":{\"m\":32,\"g\":94},\"a5b14ad0\":{\"m\":32,\"g\":94},\"5fafcac0\":{\"m\":32,\"g\":94},\"364d3d72\":{\"m\":32,\"g\":94},\"5623826f\":{\"m\":32,\"g\":94},\"83e23c69\":{\"m\":32,\"g\":94},\"ac1b74fa\":{\"m\":32,\"g\":94},\"068e9eae\":{\"m\":32,\"g\":94},\"d6aeb9fa\":{\"m\":32,\"g\":94},\"1fb94599\":{\"m\":32,\"g\":94},\"bea2bb9e\":{\"m\":32,\"g\":94},\"cd10654e\":{\"m\":32,\"g\":94},\"350a8160\":{\"m\":32,\"g\":94},\"6242c399\":{\"m\":32,\"g\":94},\"04707b09\":{\"m\":32,\"g\":94},\"ff2cfdb1\":{\"m\":32,\"g\":94},\"a8ae6403\":{\"m\":32,\"g\":94},\"d
8476818\":{\"m\":32,\"g\":94},\"df191254\":{\"m\":32,\"g\":94},\"b997a18d\":{\"m\":32,\"g\":94},\"d8627ed1\":{\"m\":32,\"g\":94},\"fa13b95d\":{\"m\":32,\"g\":94},\"3c1f5a92\":{\"m\":32,\"g\":94},\"57d0bd91\":{\"m\":32,\"g\":94},\"cdc8d607\":{\"m\":32,\"g\":94},\"9208591f\":{\"m\":32,\"g\":94},\"5d0d40d0\":{\"m\":32,\"g\":94},\"f624f6a6\":{\"m\":32,\"g\":94},\"3694f8f9\":{\"m\":32,\"g\":94},\"5a261bd0\":{\"m\":32,\"g\":94},\"6aa8ad14\":{\"m\":32,\"g\":94},\"26e9c12c\":{\"m\":32,\"g\":94},\"87a0db82\":{\"m\":32,\"g\":94},\"f25f4dfd\":{\"m\":33,\"g\":94},\"184ae1c6\":{\"m\":33,\"g\":94},\"198974cd\":{\"m\":33,\"g\":94},\"6cc38b2b\":{\"m\":33,\"g\":94},\"1ece2cda\":{\"m\":33,\"g\":94},\"c8a9e791\":{\"m\":33,\"g\":94},\"3602692c\":{\"m\":33,\"g\":94},\"909f3436\":{\"m\":33,\"g\":94},\"5ff25cdf\":{\"m\":33,\"g\":94},\"2f1d9283\":{\"m\":33,\"g\":94},\"c61a1b6f\":{\"m\":33,\"g\":94},\"9935f97b\":{\"m\":33,\"g\":94},\"13ac95b8\":{\"m\":34,\"g\":94},\"492143bf\":{\"m\":34,\"g\":94},\"0a97d796\":{\"m\":34,\"g\":94},\"c411f32e\":{\"m\":34,\"g\":94},\"bf53bf51\":{\"m\":34,\"g\":94},\"b1a540ec\":{\"m\":34,\"g\":94},\"66975360\":{\"m\":34,\"g\":94},\"6c498313\":{\"m\":34,\"g\":94},\"99994427\":{\"m\":35,\"g\":94},\"6def9b01\":{\"m\":35,\"g\":94},\"47f20da2\":{\"m\":35,\"g\":94},\"4a9f8ea4\":{\"m\":35,\"g\":94},\"58fa6076\":{\"m\":35,\"g\":94},\"6487ef64\":{\"m\":35,\"g\":94},\"9b080524\":{\"m\":35,\"g\":94},\"32a4141d\":{\"m\":35,\"g\":94},\"08360553\":{\"m\":35,\"g\":94},\"00b19f19\":{\"m\":35,\"g\":94},\"6cb32ef9\":{\"m\":35,\"g\":94},\"761b2ceb\":{\"m\":35,\"g\":94},\"54772f78\":{\"m\":35,\"g\":94},\"1b5d56f7\":{\"m\":35,\"g\":94},\"d134c139\":{\"m\":35,\"g\":94},\"6cc9c525\":{\"m\":35,\"g\":94},\"52cefdbf\":{\"m\":35,\"g\":94},\"51c554d8\":{\"m\":35,\"g\":94},\"79ece2c5\":{\"m\":35,\"g\":94},\"55f5976b\":{\"m\":35,\"g\":94},\"b7f83410\":{\"m\":35,\"g\":94},\"f414352a\":{\"m\":35,\"g\":94},\"a362340b\":{\"m\":35,\"g\":94},\"381dd57b\":{\"m\":35,\"g\":94},\"8153168c\":{\"m\":35,
\"g\":94},\"6c34d633\":{\"m\":35,\"g\":94},\"5ab9418f\":{\"m\":36,\"g\":94},\"843e63d8\":{\"m\":36,\"g\":94},\"a63c8275\":{\"m\":36,\"g\":94},\"dc67d976\":{\"m\":36,\"g\":94},\"1e495e08\":{\"m\":36,\"g\":94},\"12cb115d\":{\"m\":36,\"g\":94},\"c500f96b\":{\"m\":36,\"g\":94},\"474317f2\":{\"m\":36,\"g\":94},\"f64eae3a\":{\"m\":36,\"g\":94},\"a5a134f3\":{\"m\":36,\"g\":94},\"2561ed01\":{\"m\":36,\"g\":94},\"90a26be3\":{\"m\":37,\"g\":94},\"1f4b5f77\":{\"m\":37,\"g\":94},\"76524b70\":{\"m\":37,\"g\":94},\"3a6e0418\":{\"m\":37,\"g\":94},\"2fa5cec7\":{\"m\":37,\"g\":94},\"27b557ae\":{\"m\":37,\"g\":94},\"93dffd69\":{\"m\":37,\"g\":94},\"2abe4f1c\":{\"m\":37,\"g\":94},\"37963394\":{\"m\":37,\"g\":94},\"899cf5c4\":{\"m\":37,\"g\":94},\"e79f6cd7\":{\"m\":37,\"g\":94},\"9ba1f097\":{\"m\":37,\"g\":94},\"282681b8\":{\"m\":37,\"g\":94},\"58cafe23\":{\"m\":37,\"g\":94},\"9463bc13\":{\"m\":37,\"g\":94},\"e3fc4658\":{\"m\":37,\"g\":94},\"33b54e7c\":{\"m\":37,\"g\":94},\"30b404ce\":{\"m\":37,\"g\":94},\"70b68029\":{\"m\":37,\"g\":94},\"f3d32f88\":{\"m\":37,\"g\":94},\"8779da95\":{\"m\":37,\"g\":94},\"ad0ff62a\":{\"m\":37,\"g\":94},\"9a903a87\":{\"m\":37,\"g\":94},\"68be2f6d\":{\"m\":37,\"g\":94},\"b912de11\":{\"m\":37,\"g\":94},\"eb02c161\":{\"m\":37,\"g\":94},\"71221692\":{\"m\":37,\"g\":94},\"c33d82a2\":{\"m\":37,\"g\":94},\"8234e663\":{\"m\":37,\"g\":94},\"debbdb51\":{\"m\":37,\"g\":94},\"3efa7981\":{\"m\":37,\"g\":94},\"2a71be5e\":{\"m\":37,\"g\":94},\"44621377\":{\"m\":37,\"g\":94},\"fec185ce\":{\"m\":37,\"g\":94},\"c03cece4\":{\"m\":37,\"g\":94},\"15c75e41\":{\"m\":37,\"g\":94},\"224200e3\":{\"m\":37,\"g\":94},\"8c0efa51\":{\"m\":37,\"g\":94},\"144bc70f\":{\"m\":37,\"g\":94},\"46094e0c\":{\"m\":37,\"g\":94},\"3a6e8b6d\":{\"m\":37,\"g\":94},\"fbb4754c\":{\"m\":37,\"g\":94},\"6c7cb903\":{\"m\":37,\"g\":94},\"dff2860a\":{\"m\":37,\"g\":94},\"e72275cf\":{\"m\":37,\"g\":94},\"fec2d122\":{\"m\":37,\"g\":94},\"8d1095db\":{\"m\":37,\"g\":94},\"743007e1\":{\"m\":37,\"g\":94},\"9144ed10
\":{\"m\":37,\"g\":94},\"69b3bb9a\":{\"m\":37,\"g\":94},\"689ff588\":{\"m\":37,\"g\":94},\"a7c47e0f\":{\"m\":37,\"g\":94},\"e4d68afc\":{\"m\":37,\"g\":94},\"c9b75917\":{\"m\":37,\"g\":94},\"662ecd93\":{\"m\":37,\"g\":94},\"8e6bdf85\":{\"m\":37,\"g\":94},\"05bea688\":{\"m\":37,\"g\":94},\"ab4a83b2\":{\"m\":37,\"g\":94},\"62f15eea\":{\"m\":37,\"g\":94},\"79794af5\":{\"m\":37,\"g\":94},\"3494b32c\":{\"m\":37,\"g\":94},\"eda7c090\":{\"m\":37,\"g\":94},\"5ce55aee\":{\"m\":38,\"g\":94},\"2d346a57\":{\"m\":38,\"g\":94},\"446ea332\":{\"m\":38,\"g\":94},\"8f527e29\":{\"m\":38,\"g\":94},\"7f24ea95\":{\"m\":38,\"g\":94},\"1acccb36\":{\"m\":38,\"g\":94},\"aa2750be\":{\"m\":38,\"g\":94},\"5e62a6b7\":{\"m\":38,\"g\":94},\"5752f25e\":{\"m\":38,\"g\":94},\"7c162fa9\":{\"m\":38,\"g\":94},\"36078fb2\":{\"m\":38,\"g\":94},\"b3710d2c\":{\"m\":38,\"g\":94},\"c6b6d2e7\":{\"m\":38,\"g\":94},\"82136eb0\":{\"m\":39,\"g\":94},\"b8ccaf4d\":{\"m\":39,\"g\":94},\"a68cb201\":{\"m\":39,\"g\":94},\"014982b5\":{\"m\":39,\"g\":94},\"a6db8862\":{\"m\":39,\"g\":94},\"b4408b0d\":{\"m\":39,\"g\":94},\"2cd7e181\":{\"m\":39,\"g\":94},\"37c5899f\":{\"m\":40,\"g\":94},\"f39a0197\":{\"m\":40,\"g\":94},\"3c93187c\":{\"m\":40,\"g\":94},\"fb2d0680\":{\"m\":40,\"g\":94},\"067d8e16\":{\"m\":40,\"g\":94},\"e6692bf4\":{\"m\":40,\"g\":94},\"28b4d8e1\":{\"m\":40,\"g\":94},\"bc068e96\":{\"m\":40,\"g\":94},\"8d4ed42a\":{\"m\":40,\"g\":94},\"2854a5ea\":{\"m\":40,\"g\":94},\"42a2d82b\":{\"m\":40,\"g\":94},\"e4780cf8\":{\"m\":40,\"g\":94},\"39bb49d1\":{\"m\":40,\"g\":94},\"6f3cf129\":{\"m\":40,\"g\":94},\"13f1357e\":{\"m\":40,\"g\":94},\"2a99993c\":{\"m\":40,\"g\":94},\"167591e8\":{\"m\":40,\"g\":94},\"441c22db\":{\"m\":40,\"g\":94},\"ce636ac4\":{\"m\":40,\"g\":94},\"7b69d91b\":{\"m\":41,\"g\":94},\"e8613df0\":{\"m\":41,\"g\":94},\"c5325aba\":{\"m\":41,\"g\":94},\"ebbc42d9\":{\"m\":41,\"g\":94},\"3ff64113\":{\"m\":41,\"g\":94},\"2b302b93\":{\"m\":41,\"g\":94},\"68f8b60d\":{\"m\":41,\"g\":94},\"6a5b352a\":{\"m\":41,\"g\":9
4},\"565b05f0\":{\"m\":41,\"g\":94},\"b6aad70a\":{\"m\":41,\"g\":94},\"551a3a9d\":{\"m\":41,\"g\":94},\"91877a9f\":{\"m\":41,\"g\":94},\"f7cce751\":{\"m\":41,\"g\":94},\"17e998f1\":{\"m\":41,\"g\":94},\"c98e84c2\":{\"m\":41,\"g\":94},\"9c064bf7\":{\"m\":41,\"g\":94},\"58d1082e\":{\"m\":41,\"g\":94},\"4d086719\":{\"m\":41,\"g\":94},\"9244f27f\":{\"m\":41,\"g\":94},\"2422de51\":{\"m\":41,\"g\":94},\"521f862d\":{\"m\":41,\"g\":94},\"34c32d28\":{\"m\":41,\"g\":94},\"dde8bb16\":{\"m\":41,\"g\":94},\"8ac3ccc0\":{\"m\":41,\"g\":94},\"9b0926ce\":{\"m\":41,\"g\":94},\"1c1bdc76\":{\"m\":41,\"g\":94},\"6bfdb403\":{\"m\":41,\"g\":94},\"f8fb4ce9\":{\"m\":41,\"g\":94},\"5d0ba403\":{\"m\":41,\"g\":94},\"04b262cd\":{\"m\":41,\"g\":94},\"2432ad40\":{\"m\":41,\"g\":94},\"45473d4b\":{\"m\":41,\"g\":94},\"114bbc86\":{\"m\":41,\"g\":94},\"32eb6e96\":{\"m\":41,\"g\":94},\"e0b5dbce\":{\"m\":41,\"g\":94},\"e6852b0d\":{\"m\":41,\"g\":94},\"4ae0969c\":{\"m\":41,\"g\":94},\"317631ca\":{\"m\":41,\"g\":94},\"b5648353\":{\"m\":41,\"g\":94},\"2c7d0a5b\":{\"m\":41,\"g\":94},\"8cdc76f6\":{\"m\":41,\"g\":94},\"f202ed97\":{\"m\":41,\"g\":94},\"100f5b8b\":{\"m\":41,\"g\":94},\"619bb6dd\":{\"m\":41,\"g\":94},\"b88ea90d\":{\"m\":41,\"g\":94},\"99ec439d\":{\"m\":41,\"g\":94},\"0f4fb19b\":{\"m\":41,\"g\":94},\"63ba2f8d\":{\"m\":41,\"g\":94},\"36d5acfc\":{\"m\":41,\"g\":94},\"3f0fe08d\":{\"m\":41,\"g\":94},\"55b974f9\":{\"m\":41,\"g\":94},\"f86c1e61\":{\"m\":41,\"g\":94},\"acaffd23\":{\"m\":41,\"g\":94},\"04868543\":{\"m\":41,\"g\":94},\"fd9ad817\":{\"m\":41,\"g\":94},\"e165a9fc\":{\"m\":41,\"g\":94},\"4e4459b9\":{\"m\":41,\"g\":94},\"065bb947\":{\"m\":41,\"g\":94},\"f42e9bfb\":{\"m\":41,\"g\":94},\"840c5dbc\":{\"m\":41,\"g\":94},\"63e845d0\":{\"m\":41,\"g\":94},\"9aa6553d\":{\"m\":41,\"g\":94},\"b1e330bc\":{\"m\":41,\"g\":94},\"4353acb4\":{\"m\":41,\"g\":94},\"9ae1db0b\":{\"m\":41,\"g\":94},\"00c7e636\":{\"m\":42,\"g\":94},\"23cc66f7\":{\"m\":42,\"g\":94},\"5d09ca57\":{\"m\":42,\"g\":94},\"81c33274\":{\"m
\":42,\"g\":94},\"f13d86f9\":{\"m\":42,\"g\":94},\"aba9eae4\":{\"m\":42,\"g\":94},\"bbd72bfc\":{\"m\":42,\"g\":94},\"b503881b\":{\"m\":42,\"g\":94},\"58093b86\":{\"m\":42,\"g\":94},\"8275049c\":{\"m\":42,\"g\":94},\"5476ccad\":{\"m\":42,\"g\":94},\"b040ed71\":{\"m\":42,\"g\":94},\"c9e66586\":{\"m\":42,\"g\":94},\"e11ab79e\":{\"m\":42,\"g\":94},\"01fdb2f3\":{\"m\":42,\"g\":94},\"c996e8cc\":{\"m\":42,\"g\":94},\"087257ea\":{\"m\":43,\"g\":94},\"736f0402\":{\"m\":43,\"g\":94},\"769bf11c\":{\"m\":43,\"g\":94},\"3db43d1b\":{\"m\":43,\"g\":94},\"f0f8a769\":{\"m\":43,\"g\":94},\"2bcfba1b\":{\"m\":43,\"g\":94},\"bc12d403\":{\"m\":43,\"g\":94},\"392f2863\":{\"m\":43,\"g\":94},\"6d0fa73e\":{\"m\":43,\"g\":94},\"9e0dac1a\":{\"m\":43,\"g\":94},\"a95d5589\":{\"m\":43,\"g\":94},\"d17d19e5\":{\"m\":43,\"g\":94},\"dd3809fa\":{\"m\":43,\"g\":94},\"7feba415\":{\"m\":43,\"g\":94},\"30ee3630\":{\"m\":43,\"g\":94},\"e5db40dc\":{\"m\":43,\"g\":94},\"b1709305\":{\"m\":43,\"g\":94},\"5ab20cce\":{\"m\":43,\"g\":94},\"02f7f3e4\":{\"m\":43,\"g\":94},\"2782132b\":{\"m\":43,\"g\":94},\"d19cc0b9\":{\"m\":43,\"g\":94},\"b0facb33\":{\"m\":43,\"g\":94},\"ecb8bad2\":{\"m\":43,\"g\":94},\"dbec2f18\":{\"m\":43,\"g\":94},\"e4b367ba\":{\"m\":43,\"g\":94},\"d10b933a\":{\"m\":43,\"g\":94},\"9116b289\":{\"m\":43,\"g\":94},\"a5114b6f\":{\"m\":43,\"g\":94},\"b6b40946\":{\"m\":43,\"g\":94},\"f1088e0f\":{\"m\":43,\"g\":94},\"175afed3\":{\"m\":43,\"g\":94},\"4a292f67\":{\"m\":43,\"g\":94},\"cd0be748\":{\"m\":43,\"g\":94},\"56503d9b\":{\"m\":43,\"g\":94},\"02bc9579\":{\"m\":43,\"g\":94},\"24f3e151\":{\"m\":43,\"g\":94},\"6790240c\":{\"m\":43,\"g\":94},\"061e5463\":{\"m\":43,\"g\":94},\"0c1e8796\":{\"m\":43,\"g\":94},\"869f1c02\":{\"m\":43,\"g\":94},\"2725f8da\":{\"m\":43,\"g\":94},\"da1ffed6\":{\"m\":43,\"g\":94},\"48761171\":{\"m\":43,\"g\":94},\"c3f2fc5a\":{\"m\":43,\"g\":94},\"7ee6c259\":{\"m\":43,\"g\":94},\"9610fcd4\":{\"m\":43,\"g\":94},\"31fad29a\":{\"m\":43,\"g\":94},\"9da5a60b\":{\"m\":43,\"g\":94},\"69
aa937a\":{\"m\":43,\"g\":94},\"5d638c92\":{\"m\":43,\"g\":94},\"e37cdab0\":{\"m\":43,\"g\":94},\"1d9deeac\":{\"m\":43,\"g\":94},\"dafb6a52\":{\"m\":43,\"g\":94},\"862cd265\":{\"m\":43,\"g\":94},\"1f26e8b8\":{\"m\":44,\"g\":94},\"5e1558f1\":{\"m\":44,\"g\":94},\"94cde109\":{\"m\":44,\"g\":94},\"00611286\":{\"m\":44,\"g\":94},\"e68b9e76\":{\"m\":44,\"g\":94},\"7ce36068\":{\"m\":44,\"g\":94},\"efb099cd\":{\"m\":44,\"g\":94},\"09603c6d\":{\"m\":44,\"g\":94},\"cf470fea\":{\"m\":44,\"g\":94},\"45d5af24\":{\"m\":44,\"g\":94},\"b121bc03\":{\"m\":44,\"g\":94},\"e12358dc\":{\"m\":44,\"g\":94},\"554fbf93\":{\"m\":44,\"g\":94},\"b48edff6\":{\"m\":44,\"g\":94},\"593b19f2\":{\"m\":44,\"g\":94},\"59cbf476\":{\"m\":44,\"g\":94},\"95946271\":{\"m\":44,\"g\":94},\"5c4ce656\":{\"m\":44,\"g\":94},\"cbbc82b7\":{\"m\":44,\"g\":94},\"8bee20f8\":{\"m\":44,\"g\":94},\"12cad0fe\":{\"m\":44,\"g\":94},\"b6cd9036\":{\"m\":44,\"g\":94},\"30643fed\":{\"m\":45,\"g\":94},\"e646c590\":{\"m\":45,\"g\":94},\"c555ce2c\":{\"m\":45,\"g\":94},\"40900bae\":{\"m\":45,\"g\":94},\"a2f5e755\":{\"m\":45,\"g\":94},\"2148914e\":{\"m\":45,\"g\":94},\"def55bc8\":{\"m\":45,\"g\":94},\"86a2c473\":{\"m\":45,\"g\":94},\"1701b0db\":{\"m\":45,\"g\":94},\"384d85ba\":{\"m\":45,\"g\":94},\"60597219\":{\"m\":45,\"g\":94},\"fc82f5a7\":{\"m\":45,\"g\":94},\"0089c4bc\":{\"m\":45,\"g\":94},\"72e7b57a\":{\"m\":45,\"g\":94},\"87a7cfa0\":{\"m\":45,\"g\":94},\"8f8f96a6\":{\"m\":45,\"g\":94},\"05b3bf5e\":{\"m\":45,\"g\":94},\"3f5ac88d\":{\"m\":45,\"g\":94},\"0d800090\":{\"m\":45,\"g\":94},\"b7d05594\":{\"m\":45,\"g\":94},\"80a90547\":{\"m\":45,\"g\":94},\"9af7b88e\":{\"m\":45,\"g\":94},\"fbcbb263\":{\"m\":45,\"g\":94},\"2fce449b\":{\"m\":45,\"g\":94},\"ad4125d1\":{\"m\":45,\"g\":94},\"17536e7e\":{\"m\":45,\"g\":94},\"65859754\":{\"m\":46,\"g\":94},\"2ce32db6\":{\"m\":46,\"g\":94},\"793b79db\":{\"m\":46,\"g\":94},\"1363b519\":{\"m\":46,\"g\":94},\"0abbf289\":{\"m\":46,\"g\":94},\"c17c5781\":{\"m\":46,\"g\":94},\"916b3cdd\":{\"m\":46,\
"g\":94},\"838dcda1\":{\"m\":46,\"g\":94},\"efbc116a\":{\"m\":46,\"g\":94},\"6aed0445\":{\"m\":46,\"g\":94},\"908dd7f9\":{\"m\":46,\"g\":94},\"f4cd8040\":{\"m\":46,\"g\":94},\"be7986e0\":{\"m\":46,\"g\":94},\"5a5f1843\":{\"m\":46,\"g\":94},\"7b394e5f\":{\"m\":46,\"g\":94},\"3b60558d\":{\"m\":46,\"g\":94},\"5a9a4f41\":{\"m\":46,\"g\":94},\"72e979bf\":{\"m\":46,\"g\":94},\"146f6134\":{\"m\":46,\"g\":94},\"660ecb73\":{\"m\":46,\"g\":94},\"2565cb0f\":{\"m\":46,\"g\":94},\"066e8a4e\":{\"m\":46,\"g\":94},\"2134f089\":{\"m\":46,\"g\":94},\"a54f278d\":{\"m\":46,\"g\":94},\"d1b31b06\":{\"m\":46,\"g\":94},\"d59a4782\":{\"m\":46,\"g\":94},\"104bf260\":{\"m\":46,\"g\":94},\"3bf3d011\":{\"m\":46,\"g\":94},\"d86a2d65\":{\"m\":46,\"g\":94},\"16eb33ff\":{\"m\":46,\"g\":94},\"61cf00e1\":{\"m\":46,\"g\":94},\"b9fd178f\":{\"m\":46,\"g\":94},\"d8e9d61f\":{\"m\":46,\"g\":94},\"a2e0424a\":{\"m\":46,\"g\":94},\"8ce202a4\":{\"m\":46,\"g\":94},\"d913d52c\":{\"m\":46,\"g\":94},\"0ab7bcaf\":{\"m\":46,\"g\":94},\"438526a8\":{\"m\":46,\"g\":94},\"f7102fbd\":{\"m\":46,\"g\":94},\"a7a0a688\":{\"m\":46,\"g\":94},\"2d4ce1b7\":{\"m\":46,\"g\":94},\"4ba815b8\":{\"m\":46,\"g\":94},\"5f65e2b8\":{\"m\":46,\"g\":94},\"4e2af03c\":{\"m\":46,\"g\":94},\"3184aa95\":{\"m\":46,\"g\":94},\"b548801d\":{\"m\":46,\"g\":94},\"539df95d\":{\"m\":46,\"g\":94},\"5e00ddeb\":{\"m\":46,\"g\":94},\"54dd3ea1\":{\"m\":46,\"g\":94},\"d04899d7\":{\"m\":46,\"g\":94},\"5010e0d2\":{\"m\":46,\"g\":94},\"5e6c3265\":{\"m\":46,\"g\":94},\"680cad20\":{\"m\":46,\"g\":94},\"0a24eb85\":{\"m\":46,\"g\":94},\"3839be29\":{\"m\":46,\"g\":94},\"6e13b650\":{\"m\":46,\"g\":94},\"6fcd6d7d\":{\"m\":46,\"g\":94},\"c77762d5\":{\"m\":46,\"g\":94},\"51c81e33\":{\"m\":46,\"g\":94},\"eaade87a\":{\"m\":46,\"g\":94},\"86fc0d79\":{\"m\":46,\"g\":94},\"1be853ee\":{\"m\":46,\"g\":94},\"86e0dde5\":{\"m\":46,\"g\":94},\"2b809788\":{\"m\":46,\"g\":94},\"9d6fb084\":{\"m\":46,\"g\":94},\"ced362f7\":{\"m\":46,\"g\":94},\"9084a864\":{\"m\":46,\"g\":94},\"6aa94b96\
":{\"m\":46,\"g\":94},\"c2650748\":{\"m\":46,\"g\":94},\"07bf2e84\":{\"m\":46,\"g\":94},\"a628dd8e\":{\"m\":46,\"g\":94},\"1e890341\":{\"m\":46,\"g\":94},\"715b16c1\":{\"m\":46,\"g\":94},\"9ce8e1a9\":{\"m\":46,\"g\":94},\"fb99aaa5\":{\"m\":46,\"g\":94},\"b77a02cd\":{\"m\":46,\"g\":94},\"f407fcf9\":{\"m\":47,\"g\":94},\"54479d6f\":{\"m\":47,\"g\":94},\"ba069a24\":{\"m\":47,\"g\":94},\"125b1199\":{\"m\":47,\"g\":94},\"eff468dd\":{\"m\":47,\"g\":94},\"a1bd7190\":{\"m\":47,\"g\":94},\"78c1d644\":{\"m\":47,\"g\":94},\"027e6524\":{\"m\":47,\"g\":94},\"b808a383\":{\"m\":47,\"g\":94},\"602ebc66\":{\"m\":47,\"g\":94},\"530ae1bd\":{\"m\":47,\"g\":94},\"befc6beb\":{\"m\":47,\"g\":94},\"59a5ba9b\":{\"m\":47,\"g\":94},\"86c37d01\":{\"m\":47,\"g\":94},\"f18b9c72\":{\"m\":47,\"g\":94},\"3e335743\":{\"m\":47,\"g\":94},\"0d94f1dd\":{\"m\":47,\"g\":94},\"e728258d\":{\"m\":47,\"g\":94},\"239eafbd\":{\"m\":47,\"g\":94},\"9d427265\":{\"m\":47,\"g\":94},\"00ffde20\":{\"m\":47,\"g\":94},\"ddeb9d42\":{\"m\":47,\"g\":94},\"aaf0a315\":{\"m\":47,\"g\":94},\"f9633fa9\":{\"m\":47,\"g\":94},\"087ab832\":{\"m\":47,\"g\":94},\"8169c6f4\":{\"m\":47,\"g\":94},\"3d043319\":{\"m\":47,\"g\":94},\"a8aad935\":{\"m\":47,\"g\":94},\"47ffe7af\":{\"m\":47,\"g\":94},\"b3523af8\":{\"m\":47,\"g\":94},\"1929c067\":{\"m\":47,\"g\":94},\"ed53ac84\":{\"m\":47,\"g\":94},\"520f0094\":{\"m\":47,\"g\":94},\"9c939a3d\":{\"m\":47,\"g\":94},\"549e8b83\":{\"m\":47,\"g\":94},\"a1f32867\":{\"m\":47,\"g\":94},\"760552e0\":{\"m\":47,\"g\":94},\"d9aada9d\":{\"m\":47,\"g\":94},\"f11eb90f\":{\"m\":47,\"g\":94},\"95a4ed12\":{\"m\":47,\"g\":94},\"d1150e9a\":{\"m\":47,\"g\":94},\"e3126e3c\":{\"m\":47,\"g\":94},\"a5095520\":{\"m\":47,\"g\":94},\"7ef0084b\":{\"m\":47,\"g\":94},\"f9a377f6\":{\"m\":47,\"g\":94},\"4ade15dd\":{\"m\":47,\"g\":94},\"8dc84da0\":{\"m\":47,\"g\":94},\"f16eb15d\":{\"m\":47,\"g\":94},\"5bc2508b\":{\"m\":47,\"g\":94},\"a71a44f2\":{\"m\":47,\"g\":94},\"691808d5\":{\"m\":47,\"g\":94},\"d32fba2a\":{\"m\":47,\"g\":94
},\"67c424cc\":{\"m\":47,\"g\":94},\"1ae270c5\":{\"m\":47,\"g\":94},\"c77c1e05\":{\"m\":47,\"g\":94},\"dca87ec3\":{\"m\":47,\"g\":94},\"4b1d7a25\":{\"m\":47,\"g\":94},\"a5e0defb\":{\"m\":47,\"g\":94},\"96766101\":{\"m\":47,\"g\":94},\"a146d999\":{\"m\":47,\"g\":94},\"f5113e50\":{\"m\":47,\"g\":94},\"02755768\":{\"m\":47,\"g\":94},\"463d56bf\":{\"m\":47,\"g\":94},\"530ff541\":{\"m\":47,\"g\":94},\"3cd28092\":{\"m\":47,\"g\":94},\"704f8e8e\":{\"m\":47,\"g\":94},\"1853c352\":{\"m\":47,\"g\":94},\"32c9a7ec\":{\"m\":48,\"g\":94},\"b01df48c\":{\"m\":48,\"g\":94},\"c29b98e0\":{\"m\":48,\"g\":94},\"954f4e6b\":{\"m\":48,\"g\":94},\"2558d6a6\":{\"m\":48,\"g\":94},\"29ebe3df\":{\"m\":48,\"g\":94},\"f6dd6486\":{\"m\":48,\"g\":94},\"ea53c63b\":{\"m\":48,\"g\":94},\"a10d5309\":{\"m\":48,\"g\":94},\"aae5434b\":{\"m\":48,\"g\":94},\"c3eac1b0\":{\"m\":48,\"g\":94},\"b275ce00\":{\"m\":48,\"g\":94},\"13ce3e4b\":{\"m\":48,\"g\":94},\"df246e69\":{\"m\":48,\"g\":94},\"fb9fb351\":{\"m\":48,\"g\":94},\"c722d9bd\":{\"m\":48,\"g\":94},\"218ab361\":{\"m\":48,\"g\":94},\"9a00e6f4\":{\"m\":49,\"g\":94},\"4f8c3aea\":{\"m\":49,\"g\":94},\"2369e882\":{\"m\":49,\"g\":94},\"ad30d5cf\":{\"m\":49,\"g\":94},\"dfec7fca\":{\"m\":49,\"g\":94},\"8048c28c\":{\"m\":49,\"g\":94},\"30af7dfb\":{\"m\":49,\"g\":94},\"f6f71379\":{\"m\":49,\"g\":94},\"f35cb46c\":{\"m\":49,\"g\":94},\"7f8fcd39\":{\"m\":49,\"g\":94},\"5c6a41fa\":{\"m\":49,\"g\":94},\"722530fa\":{\"m\":49,\"g\":94},\"56a347f7\":{\"m\":49,\"g\":94},\"3295cd8a\":{\"m\":49,\"g\":94},\"5942dfc0\":{\"m\":49,\"g\":94},\"63a395b9\":{\"m\":49,\"g\":94},\"7d671e4a\":{\"m\":49,\"g\":94},\"699384cb\":{\"m\":49,\"g\":94},\"ffd20fcd\":{\"m\":49,\"g\":94},\"55bd97f3\":{\"m\":49,\"g\":94},\"e57c3e12\":{\"m\":49,\"g\":94},\"f239268f\":{\"m\":49,\"g\":94},\"929c7621\":{\"m\":49,\"g\":94},\"b7a065ea\":{\"m\":49,\"g\":94},\"b1104538\":{\"m\":49,\"g\":94},\"3b44bbee\":{\"m\":49,\"g\":94},\"80e2c4a8\":{\"m\":49,\"g\":94},\"66318ffe\":{\"m\":49,\"g\":94},\"76619261\":{\"m\
":49,\"g\":94},\"2a3992b6\":{\"m\":49,\"g\":94},\"4af3f889\":{\"m\":49,\"g\":94},\"df7fe452\":{\"m\":49,\"g\":94},\"a7164b62\":{\"m\":49,\"g\":94},\"11668533\":{\"m\":49,\"g\":94},\"a9e90b4b\":{\"m\":49,\"g\":94},\"8c280cee\":{\"m\":49,\"g\":94},\"9c745d07\":{\"m\":49,\"g\":94},\"ebaa2f31\":{\"m\":49,\"g\":94},\"62832bb2\":{\"m\":49,\"g\":94},\"11f881d1\":{\"m\":49,\"g\":94},\"38625e21\":{\"m\":49,\"g\":94},\"c1f401fc\":{\"m\":49,\"g\":94},\"3b878863\":{\"m\":49,\"g\":94},\"f719d9ae\":{\"m\":49,\"g\":94},\"edad3731\":{\"m\":49,\"g\":94},\"976bc302\":{\"m\":49,\"g\":94},\"2f2e0743\":{\"m\":49,\"g\":94},\"2ffe0a73\":{\"m\":49,\"g\":94},\"cf248976\":{\"m\":49,\"g\":94},\"e5c67150\":{\"m\":49,\"g\":94},\"023d0a73\":{\"m\":49,\"g\":94},\"ac5a0f04\":{\"m\":50,\"g\":94},\"ea34350d\":{\"m\":50,\"g\":94},\"1605ae12\":{\"m\":50,\"g\":94},\"1aea19f6\":{\"m\":50,\"g\":94},\"1f76fc6e\":{\"m\":50,\"g\":94},\"7f076c2c\":{\"m\":50,\"g\":94},\"3c5538f7\":{\"m\":50,\"g\":94},\"10189d08\":{\"m\":50,\"g\":94},\"c4336b2b\":{\"m\":50,\"g\":94},\"4d62bca5\":{\"m\":50,\"g\":94},\"e1e595d7\":{\"m\":50,\"g\":94},\"5ada33ff\":{\"m\":50,\"g\":94},\"254fd130\":{\"m\":50,\"g\":94},\"538fa0ae\":{\"m\":50,\"g\":94},\"55842eb8\":{\"m\":50,\"g\":94},\"a866b65e\":{\"m\":50,\"g\":94},\"4b0a1c93\":{\"m\":50,\"g\":94},\"8e1adb84\":{\"m\":50,\"g\":94},\"dd44173d\":{\"m\":50,\"g\":94},\"8912b763\":{\"m\":50,\"g\":94},\"be0124bd\":{\"m\":50,\"g\":94},\"fe5d3e81\":{\"m\":50,\"g\":94},\"731146f6\":{\"m\":50,\"g\":94},\"fa271613\":{\"m\":50,\"g\":94},\"5652c565\":{\"m\":50,\"g\":94},\"e3938b2f\":{\"m\":50,\"g\":94},\"c211e7b6\":{\"m\":50,\"g\":94},\"d90c3d6b\":{\"m\":50,\"g\":94},\"9e8f8fbf\":{\"m\":50,\"g\":94},\"b509db58\":{\"m\":50,\"g\":94},\"dbe17293\":{\"m\":50,\"g\":94},\"84a1698d\":{\"m\":50,\"g\":94},\"32293a29\":{\"m\":50,\"g\":94},\"79216908\":{\"m\":50,\"g\":94},\"bbb81c24\":{\"m\":50,\"g\":94},\"52f58fc4\":{\"m\":50,\"g\":94},\"145c0ddc\":{\"m\":50,\"g\":94},\"505d7f71\":{\"m\":50,\"g\":94},\"cbe
dd1db\":{\"m\":50,\"g\":94},\"ad47749b\":{\"m\":50,\"g\":94},\"751c3a03\":{\"m\":50,\"g\":94},\"60769be1\":{\"m\":50,\"g\":94},\"a78d8f8d\":{\"m\":50,\"g\":94},\"c5f86501\":{\"m\":50,\"g\":94},\"d98fa1e9\":{\"m\":50,\"g\":94},\"865233e2\":{\"m\":50,\"g\":94},\"66d4859a\":{\"m\":50,\"g\":94},\"e1b63624\":{\"m\":50,\"g\":94},\"c35cd1f8\":{\"m\":50,\"g\":94},\"72f87b72\":{\"m\":50,\"g\":94},\"62a4a339\":{\"m\":50,\"g\":94},\"2797bc34\":{\"m\":50,\"g\":94},\"fed4c694\":{\"m\":51,\"g\":94},\"fb6e04a0\":{\"m\":51,\"g\":94},\"6997e28f\":{\"m\":51,\"g\":94},\"a0e58740\":{\"m\":51,\"g\":94},\"37c8a576\":{\"m\":51,\"g\":94},\"c754652f\":{\"m\":51,\"g\":94},\"0b46b951\":{\"m\":51,\"g\":94},\"2763c0a7\":{\"m\":51,\"g\":94},\"de3b67b7\":{\"m\":51,\"g\":94},\"19f33b32\":{\"m\":51,\"g\":94},\"30ce5b59\":{\"m\":51,\"g\":94},\"bc1f6fda\":{\"m\":51,\"g\":94},\"867e092f\":{\"m\":51,\"g\":94},\"88c7763f\":{\"m\":51,\"g\":94},\"e4118b15\":{\"m\":51,\"g\":94},\"ba4ee37f\":{\"m\":51,\"g\":94},\"fae4e5e9\":{\"m\":52,\"g\":94},\"afe1e465\":{\"m\":52,\"g\":94},\"f50a6cf4\":{\"m\":52,\"g\":94},\"fe97a2d4\":{\"m\":52,\"g\":94},\"8b48496a\":{\"m\":52,\"g\":94},\"4057ea82\":{\"m\":52,\"g\":94},\"4f2ee48e\":{\"m\":52,\"g\":94},\"71ff2728\":{\"m\":52,\"g\":94},\"b7038fec\":{\"m\":52,\"g\":94},\"65fdb289\":{\"m\":52,\"g\":94},\"b2ccf36d\":{\"m\":52,\"g\":94},\"d4fc1a70\":{\"m\":52,\"g\":94},\"db674e3d\":{\"m\":52,\"g\":94},\"fb915bd1\":{\"m\":52,\"g\":94},\"09798b36\":{\"m\":52,\"g\":94},\"b79fffdc\":{\"m\":52,\"g\":94},\"cd51758f\":{\"m\":52,\"g\":94},\"91e5dbf5\":{\"m\":52,\"g\":94},\"dd5eba4c\":{\"m\":52,\"g\":94},\"a4fd2f9b\":{\"m\":52,\"g\":94},\"92d1253e\":{\"m\":52,\"g\":94},\"a9ca297d\":{\"m\":52,\"g\":94},\"2a02185c\":{\"m\":52,\"g\":94},\"f8b03269\":{\"m\":53,\"g\":94},\"04957965\":{\"m\":53,\"g\":94},\"1228f7ca\":{\"m\":53,\"g\":94},\"fda628d8\":{\"m\":53,\"g\":94},\"07ec07ad\":{\"m\":53,\"g\":94},\"83b340e3\":{\"m\":53,\"g\":94},\"0639bf15\":{\"m\":53,\"g\":94},\"aa47f642\":{\"m\":53,\"
g\":94},\"3ddb1c46\":{\"m\":53,\"g\":94},\"480e38a7\":{\"m\":53,\"g\":94},\"69e2d4fb\":{\"m\":53,\"g\":94},\"85e1a6f3\":{\"m\":53,\"g\":94},\"33deca81\":{\"m\":53,\"g\":94},\"18108abe\":{\"m\":53,\"g\":94},\"c54bda30\":{\"m\":53,\"g\":94},\"3c79ad35\":{\"m\":53,\"g\":94},\"983bfcf3\":{\"m\":53,\"g\":94},\"28bc60dc\":{\"m\":53,\"g\":94},\"7301a39b\":{\"m\":53,\"g\":94},\"47eb139f\":{\"m\":53,\"g\":94},\"5c18a037\":{\"m\":53,\"g\":94},\"5c91a315\":{\"m\":53,\"g\":94},\"3dbd73d3\":{\"m\":53,\"g\":94},\"e9a6203d\":{\"m\":53,\"g\":94},\"62c516ac\":{\"m\":53,\"g\":94},\"fc78640e\":{\"m\":53,\"g\":94},\"906d795f\":{\"m\":53,\"g\":94},\"118b6af3\":{\"m\":53,\"g\":94},\"9449a954\":{\"m\":53,\"g\":94},\"5f12f0e7\":{\"m\":53,\"g\":94},\"d5b95cbb\":{\"m\":53,\"g\":94},\"0303ca91\":{\"m\":53,\"g\":94},\"00181098\":{\"m\":53,\"g\":94},\"4936be8a\":{\"m\":53,\"g\":94},\"1bfa511b\":{\"m\":53,\"g\":94},\"f5b5f2bf\":{\"m\":53,\"g\":94},\"7e4c6dd8\":{\"m\":53,\"g\":94},\"d622851d\":{\"m\":53,\"g\":94},\"883c9554\":{\"m\":53,\"g\":94},\"0d6a49bd\":{\"m\":53,\"g\":94},\"ccaf1f99\":{\"m\":53,\"g\":94},\"7d1485d3\":{\"m\":53,\"g\":94},\"7d5d1d3d\":{\"m\":53,\"g\":94},\"b53d6cbd\":{\"m\":53,\"g\":94},\"01017d4c\":{\"m\":53,\"g\":94},\"94e167ea\":{\"m\":53,\"g\":94},\"262e370f\":{\"m\":53,\"g\":94},\"419a57e7\":{\"m\":53,\"g\":94},\"e5f227c0\":{\"m\":54,\"g\":94},\"0e7409ad\":{\"m\":54,\"g\":94},\"3cde5eb6\":{\"m\":54,\"g\":94},\"f5b2a3aa\":{\"m\":54,\"g\":94},\"f6817596\":{\"m\":54,\"g\":94},\"67b65794\":{\"m\":54,\"g\":94},\"37ee906f\":{\"m\":54,\"g\":94},\"34b364e0\":{\"m\":54,\"g\":94},\"84d96b3a\":{\"m\":54,\"g\":94},\"3d32e4a3\":{\"m\":54,\"g\":94},\"64fceab8\":{\"m\":54,\"g\":94},\"71e2a277\":{\"m\":54,\"g\":94},\"4a63c181\":{\"m\":54,\"g\":94},\"2b0fc594\":{\"m\":54,\"g\":94},\"9cc733b3\":{\"m\":54,\"g\":94},\"d693ec04\":{\"m\":54,\"g\":94},\"18ea841f\":{\"m\":54,\"g\":94},\"786be44d\":{\"m\":54,\"g\":94},\"2db44698\":{\"m\":54,\"g\":94},\"ed45e509\":{\"m\":54,\"g\":94},\"ec52464d\"
:{\"m\":54,\"g\":94},\"eb0c1f53\":{\"m\":54,\"g\":94},\"b2986d7a\":{\"m\":54,\"g\":94},\"8f4d04e5\":{\"m\":55,\"g\":94},\"feb2b768\":{\"m\":55,\"g\":94},\"d95a5f5b\":{\"m\":55,\"g\":94},\"4b83db24\":{\"m\":55,\"g\":94},\"64456cf0\":{\"m\":55,\"g\":94},\"bb4a9220\":{\"m\":55,\"g\":94},\"21e9e63a\":{\"m\":55,\"g\":94},\"1fc84cf6\":{\"m\":55,\"g\":94},\"361ea8d9\":{\"m\":55,\"g\":94},\"33c5ff28\":{\"m\":55,\"g\":94},\"5ce9daea\":{\"m\":55,\"g\":94},\"ce094a5d\":{\"m\":55,\"g\":94},\"e2102669\":{\"m\":55,\"g\":94},\"bd619616\":{\"m\":55,\"g\":94},\"56198b45\":{\"m\":55,\"g\":94},\"ba36b552\":{\"m\":55,\"g\":94},\"9cd9dc83\":{\"m\":55,\"g\":94},\"7a1aecb9\":{\"m\":55,\"g\":94},\"82699474\":{\"m\":55,\"g\":94},\"7154b4b1\":{\"m\":55,\"g\":94},\"b532a5fd\":{\"m\":55,\"g\":94},\"a0592c05\":{\"m\":55,\"g\":94},\"e8dbdf75\":{\"m\":55,\"g\":94},\"e04d3f28\":{\"m\":55,\"g\":94},\"5f2595be\":{\"m\":55,\"g\":94},\"0ba2c589\":{\"m\":55,\"g\":94},\"fccbfa37\":{\"m\":55,\"g\":94},\"2f9bd0fa\":{\"m\":55,\"g\":94},\"5282a473\":{\"m\":55,\"g\":94},\"f0ed9c35\":{\"m\":55,\"g\":94},\"e3b3acfa\":{\"m\":55,\"g\":94},\"2673fa29\":{\"m\":55,\"g\":94},\"dedaf8cd\":{\"m\":55,\"g\":94},\"32ed0160\":{\"m\":55,\"g\":94},\"6efa9e4a\":{\"m\":55,\"g\":94},\"7791fd99\":{\"m\":55,\"g\":94},\"2ac36b9a\":{\"m\":55,\"g\":94},\"2d60a5ee\":{\"m\":55,\"g\":94},\"2e4a5907\":{\"m\":55,\"g\":94},\"c0ee46fe\":{\"m\":55,\"g\":94},\"9208618b\":{\"m\":55,\"g\":94},\"864bf2ba\":{\"m\":55,\"g\":94},\"a4cca7fc\":{\"m\":55,\"g\":94},\"993956c6\":{\"m\":55,\"g\":94},\"f8548295\":{\"m\":55,\"g\":94},\"959735fc\":{\"m\":55,\"g\":94},\"f6772394\":{\"m\":55,\"g\":94},\"626a99ac\":{\"m\":55,\"g\":94},\"ece72491\":{\"m\":55,\"g\":94},\"0fb88aaa\":{\"m\":55,\"g\":94},\"d4de9a62\":{\"m\":55,\"g\":94},\"7310aede\":{\"m\":55,\"g\":94},\"5de9a58e\":{\"m\":55,\"g\":94},\"56fcd8e8\":{\"m\":55,\"g\":94},\"2b340adf\":{\"m\":55,\"g\":94},\"8586b72d\":{\"m\":55,\"g\":94},\"641b7d0a\":{\"m\":55,\"g\":94},\"0ce091a8\":{\"m\":55,\"g\":94}
,\"835f8afc\":{\"m\":55,\"g\":94},\"3844feb9\":{\"m\":55,\"g\":94},\"27f7bed7\":{\"m\":55,\"g\":94},\"6387098f\":{\"m\":55,\"g\":94},\"2a717c50\":{\"m\":55,\"g\":94},\"a1e697b2\":{\"m\":55,\"g\":94},\"a6ca736c\":{\"m\":55,\"g\":94},\"f62055b5\":{\"m\":55,\"g\":94},\"74bc9184\":{\"m\":55,\"g\":94},\"0f8eb153\":{\"m\":55,\"g\":94},\"67470bbb\":{\"m\":55,\"g\":94},\"cc858953\":{\"m\":55,\"g\":94},\"6128f7cf\":{\"m\":55,\"g\":94},\"a2486eb5\":{\"m\":55,\"g\":94},\"61dec545\":{\"m\":55,\"g\":94},\"96db0f66\":{\"m\":55,\"g\":94},\"7dc66fcb\":{\"m\":55,\"g\":94},\"1f09e84b\":{\"m\":55,\"g\":94},\"63dfab1b\":{\"m\":55,\"g\":94},\"ef995dae\":{\"m\":55,\"g\":94},\"75ae9689\":{\"m\":55,\"g\":94},\"95f93f49\":{\"m\":55,\"g\":94},\"aaac33fd\":{\"m\":55,\"g\":94},\"d332aa3b\":{\"m\":55,\"g\":94},\"c36736c8\":{\"m\":55,\"g\":94},\"1bf9e347\":{\"m\":55,\"g\":94},\"499c85f1\":{\"m\":55,\"g\":94},\"efc52f85\":{\"m\":56,\"g\":94},\"60e2fdcf\":{\"m\":56,\"g\":94},\"d7c0e872\":{\"m\":56,\"g\":94},\"31548116\":{\"m\":56,\"g\":94},\"53aed988\":{\"m\":56,\"g\":94},\"8a56b431\":{\"m\":56,\"g\":94},\"e835a500\":{\"m\":56,\"g\":94},\"23e5e50f\":{\"m\":56,\"g\":94},\"25e5d589\":{\"m\":56,\"g\":94},\"41b1db69\":{\"m\":56,\"g\":94},\"84967019\":{\"m\":56,\"g\":94},\"7d672d27\":{\"m\":56,\"g\":94},\"d4b17481\":{\"m\":56,\"g\":94},\"19ba2b0e\":{\"m\":56,\"g\":94},\"4e1e3cff\":{\"m\":56,\"g\":94},\"ef5b0ff9\":{\"m\":57,\"g\":94},\"6e530515\":{\"m\":57,\"g\":94},\"77d1210b\":{\"m\":57,\"g\":94},\"70dc2fbe\":{\"m\":57,\"g\":94},\"b438a2e5\":{\"m\":57,\"g\":94},\"7ca751ff\":{\"m\":57,\"g\":94},\"c75adfec\":{\"m\":57,\"g\":94},\"7722c11c\":{\"m\":57,\"g\":94},\"b2ed5c8e\":{\"m\":57,\"g\":94},\"f46f394f\":{\"m\":57,\"g\":94},\"2125898a\":{\"m\":57,\"g\":94},\"44f011d2\":{\"m\":57,\"g\":94},\"ed91e003\":{\"m\":57,\"g\":94},\"531d6ea9\":{\"m\":57,\"g\":94},\"dc3bee48\":{\"m\":57,\"g\":94},\"a74d1941\":{\"m\":57,\"g\":94},\"3169e66c\":{\"m\":57,\"g\":94},\"77395154\":{\"m\":57,\"g\":94},\"637de9e8\":{\"m\"
:57,\"g\":94},\"acb34072\":{\"m\":57,\"g\":94},\"08effbff\":{\"m\":57,\"g\":94},\"60bd3272\":{\"m\":57,\"g\":94},\"e7ebecf8\":{\"m\":57,\"g\":94},\"9a23c484\":{\"m\":57,\"g\":94},\"635a0426\":{\"m\":57,\"g\":94},\"2dccecf4\":{\"m\":57,\"g\":94},\"75ad0a14\":{\"m\":57,\"g\":94},\"3ccf566b\":{\"m\":58,\"g\":94},\"afa0341e\":{\"m\":58,\"g\":94},\"30828e71\":{\"m\":58,\"g\":94},\"e0e09fce\":{\"m\":58,\"g\":94},\"9c05c689\":{\"m\":58,\"g\":94},\"3464e57b\":{\"m\":58,\"g\":94},\"3815b23c\":{\"m\":58,\"g\":94},\"fd34f2da\":{\"m\":58,\"g\":94},\"8ee9a850\":{\"m\":58,\"g\":94},\"fd28640d\":{\"m\":58,\"g\":94},\"7863e436\":{\"m\":58,\"g\":94},\"333e3bfd\":{\"m\":58,\"g\":94},\"239c9d4d\":{\"m\":58,\"g\":94},\"855d0ba3\":{\"m\":58,\"g\":94},\"9254a33a\":{\"m\":58,\"g\":94},\"8a2681e2\":{\"m\":58,\"g\":94},\"5276a675\":{\"m\":58,\"g\":94},\"751e5ca2\":{\"m\":58,\"g\":94},\"7a7ac6be\":{\"m\":58,\"g\":94},\"d9e6ee38\":{\"m\":58,\"g\":94},\"03d5fbfd\":{\"m\":59,\"g\":94},\"1703d766\":{\"m\":59,\"g\":94},\"09e6e2aa\":{\"m\":59,\"g\":94},\"fad29f7f\":{\"m\":59,\"g\":94},\"35bdb485\":{\"m\":59,\"g\":94},\"b085e06b\":{\"m\":59,\"g\":94},\"763dd55d\":{\"m\":59,\"g\":94},\"2f0d3864\":{\"m\":60,\"g\":94},\"3900a94a\":{\"m\":60,\"g\":94},\"ded9fcd0\":{\"m\":60,\"g\":94},\"bc6ad367\":{\"m\":60,\"g\":94},\"3a22a303\":{\"m\":60,\"g\":94},\"bdb3929d\":{\"m\":60,\"g\":94},\"f5d0865b\":{\"m\":60,\"g\":94},\"afdee7b1\":{\"m\":60,\"g\":94},\"cb34d848\":{\"m\":60,\"g\":94},\"0f9cc6d8\":{\"m\":60,\"g\":94},\"c7ae474a\":{\"m\":60,\"g\":94},\"bdf946bf\":{\"m\":60,\"g\":94},\"8c8779cd\":{\"m\":60,\"g\":94},\"1775b963\":{\"m\":60,\"g\":94},\"dd2e2d27\":{\"m\":60,\"g\":94},\"a990daff\":{\"m\":60,\"g\":94},\"ba5112ff\":{\"m\":60,\"g\":94},\"815dce05\":{\"m\":60,\"g\":94},\"ad20b795\":{\"m\":60,\"g\":94},\"9183c23e\":{\"m\":60,\"g\":94},\"148254d4\":{\"m\":60,\"g\":94},\"a4d6d6f1\":{\"m\":60,\"g\":94},\"062c48d2\":{\"m\":60,\"g\":94},\"b6e0cfb5\":{\"m\":60,\"g\":94},\"0d8d97b8\":{\"m\":60,\"g\":94},\"0a76
5bbc\":{\"m\":60,\"g\":94},\"286cad3e\":{\"m\":60,\"g\":94},\"dc7eb01f\":{\"m\":60,\"g\":94},\"b0524c37\":{\"m\":60,\"g\":94},\"6c42fa22\":{\"m\":60,\"g\":94},\"d49b13c6\":{\"m\":60,\"g\":94},\"bedc4c7a\":{\"m\":60,\"g\":94},\"f44d1439\":{\"m\":60,\"g\":94},\"b6b57fc2\":{\"m\":60,\"g\":94},\"b4403985\":{\"m\":60,\"g\":94},\"339c69a2\":{\"m\":60,\"g\":94},\"f7074700\":{\"m\":60,\"g\":94},\"21ec66e5\":{\"m\":60,\"g\":94},\"c5210dfa\":{\"m\":60,\"g\":94},\"a29dd950\":{\"m\":60,\"g\":94},\"9c6ba248\":{\"m\":60,\"g\":94},\"b02da24a\":{\"m\":60,\"g\":94},\"bdd2827a\":{\"m\":60,\"g\":94},\"8c3b420e\":{\"m\":60,\"g\":94},\"e6f523b5\":{\"m\":60,\"g\":94},\"32318178\":{\"m\":60,\"g\":94},\"a11f8d5f\":{\"m\":60,\"g\":94},\"098d659c\":{\"m\":60,\"g\":94},\"76d14f8c\":{\"m\":60,\"g\":94},\"b08c308e\":{\"m\":60,\"g\":94},\"f624901c\":{\"m\":61,\"g\":94},\"f0e15dc6\":{\"m\":61,\"g\":94},\"f1769586\":{\"m\":61,\"g\":94},\"5d6e9467\":{\"m\":61,\"g\":94},\"a47bf391\":{\"m\":61,\"g\":94},\"b1706469\":{\"m\":61,\"g\":94},\"5413ec2b\":{\"m\":61,\"g\":94},\"f290bd43\":{\"m\":61,\"g\":94},\"8f157893\":{\"m\":61,\"g\":94},\"2db03a04\":{\"m\":61,\"g\":94},\"5cc11705\":{\"m\":61,\"g\":94},\"11fffbc9\":{\"m\":61,\"g\":94},\"4f077c01\":{\"m\":61,\"g\":94},\"679c3bca\":{\"m\":61,\"g\":94},\"656aed58\":{\"m\":61,\"g\":94},\"b5fb4ef5\":{\"m\":61,\"g\":94},\"2e6346fc\":{\"m\":61,\"g\":94},\"977f785d\":{\"m\":61,\"g\":94},\"8a690612\":{\"m\":61,\"g\":94},\"694e4192\":{\"m\":61,\"g\":94},\"b22f3f64\":{\"m\":61,\"g\":94},\"6fb57683\":{\"m\":61,\"g\":94},\"51caee74\":{\"m\":61,\"g\":94},\"58f9060e\":{\"m\":61,\"g\":94},\"bdc1acf6\":{\"m\":61,\"g\":94},\"6d08ce2a\":{\"m\":61,\"g\":94},\"380930a9\":{\"m\":61,\"g\":94},\"9dec582d\":{\"m\":61,\"g\":94},\"b01febdc\":{\"m\":61,\"g\":94},\"1acbaf1b\":{\"m\":61,\"g\":94},\"287427e2\":{\"m\":61,\"g\":94},\"b8574f69\":{\"m\":61,\"g\":94},\"2855caa4\":{\"m\":61,\"g\":94},\"2329e1dd\":{\"m\":61,\"g\":94},\"0f3eb1d2\":{\"m\":61,\"g\":94},\"06dd2eab\":{\"m\":61,\"g
\":94},\"439f6580\":{\"m\":61,\"g\":94},\"b3e99dfb\":{\"m\":62,\"g\":94},\"f005758f\":{\"m\":62,\"g\":94},\"f5c6c667\":{\"m\":62,\"g\":94},\"cc0485be\":{\"m\":62,\"g\":94},\"b8cd09f2\":{\"m\":62,\"g\":94},\"c19d8482\":{\"m\":62,\"g\":94},\"80002562\":{\"m\":62,\"g\":94},\"46d44318\":{\"m\":62,\"g\":94},\"923f5183\":{\"m\":62,\"g\":94},\"d08c77c4\":{\"m\":62,\"g\":94},\"c1e097ca\":{\"m\":62,\"g\":94},\"6ec75e62\":{\"m\":62,\"g\":94},\"d855653b\":{\"m\":62,\"g\":94},\"336ff5b9\":{\"m\":62,\"g\":94},\"3b141e15\":{\"m\":62,\"g\":94},\"6249e4a1\":{\"m\":62,\"g\":94},\"f3516c28\":{\"m\":62,\"g\":94},\"17de02f9\":{\"m\":62,\"g\":94},\"51ab3ccf\":{\"m\":62,\"g\":94},\"67008f4b\":{\"m\":62,\"g\":94},\"4536d724\":{\"m\":62,\"g\":94},\"41d7e5b7\":{\"m\":62,\"g\":94},\"20a9f5df\":{\"m\":62,\"g\":94},\"42f39099\":{\"m\":62,\"g\":94},\"72c77763\":{\"m\":62,\"g\":94},\"4093aa46\":{\"m\":62,\"g\":94},\"e808c1df\":{\"m\":62,\"g\":94},\"a18ab81d\":{\"m\":62,\"g\":94},\"0bb0f763\":{\"m\":62,\"g\":94},\"85b2e057\":{\"m\":62,\"g\":94},\"a879c2fb\":{\"m\":62,\"g\":94},\"e2b16c47\":{\"m\":62,\"g\":94},\"c4f9707e\":{\"m\":62,\"g\":94},\"197cbf9b\":{\"m\":62,\"g\":94},\"e94fb7cb\":{\"m\":63,\"g\":94},\"b5caa22d\":{\"m\":63,\"g\":94},\"73401fd0\":{\"m\":63,\"g\":94},\"89cd9235\":{\"m\":63,\"g\":94},\"dc188132\":{\"m\":63,\"g\":94},\"10bfce71\":{\"m\":63,\"g\":94},\"583697cd\":{\"m\":63,\"g\":94},\"2584f6d9\":{\"m\":63,\"g\":94},\"51e87f6f\":{\"m\":63,\"g\":94},\"09bcbe01\":{\"m\":63,\"g\":94},\"03464890\":{\"m\":63,\"g\":94},\"44a96697\":{\"m\":63,\"g\":94},\"1a820e38\":{\"m\":63,\"g\":94},\"0ffcfdf4\":{\"m\":63,\"g\":94},\"cd493b5a\":{\"m\":63,\"g\":94},\"61f42b57\":{\"m\":63,\"g\":94},\"e403d237\":{\"m\":63,\"g\":94},\"3bcf5ece\":{\"m\":63,\"g\":94},\"2c05f81f\":{\"m\":63,\"g\":94},\"d77caa2b\":{\"m\":63,\"g\":94},\"8b6a4486\":{\"m\":63,\"g\":94},\"a69cb5cf\":{\"m\":63,\"g\":94},\"def5c318\":{\"m\":63,\"g\":94},\"3fc2b625\":{\"m\":63,\"g\":94},\"6ada05d0\":{\"m\":63,\"g\":94},\"24cafe31\":
{\"m\":63,\"g\":94},\"5a176c92\":{\"m\":63,\"g\":94},\"4719c1d0\":{\"m\":63,\"g\":94},\"ef18b0ed\":{\"m\":63,\"g\":94},\"53cc91e5\":{\"m\":63,\"g\":94},\"d33cbb7e\":{\"m\":63,\"g\":94},\"23196d52\":{\"m\":63,\"g\":94},\"93b77c8e\":{\"m\":63,\"g\":94},\"7906d1d2\":{\"m\":63,\"g\":94},\"81d27c8e\":{\"m\":63,\"g\":94},\"4d4cdb3f\":{\"m\":63,\"g\":94},\"2bd18e2d\":{\"m\":63,\"g\":94},\"83452dbb\":{\"m\":63,\"g\":94},\"3d93f84a\":{\"m\":63,\"g\":94},\"c2f212d6\":{\"m\":63,\"g\":94},\"e2cdc8a5\":{\"m\":63,\"g\":94},\"2add697d\":{\"m\":63,\"g\":94},\"6f98c586\":{\"m\":63,\"g\":94},\"656dcc1a\":{\"m\":63,\"g\":94},\"8af7048d\":{\"m\":63,\"g\":94},\"d3024f4f\":{\"m\":63,\"g\":94},\"13387e6b\":{\"m\":63,\"g\":94},\"120c3634\":{\"m\":63,\"g\":94},\"78e5b22f\":{\"m\":63,\"g\":94},\"7a15e9ad\":{\"m\":63,\"g\":94},\"dc2ac0cb\":{\"m\":63,\"g\":94},\"d47c5101\":{\"m\":63,\"g\":94},\"033c715b\":{\"m\":63,\"g\":94},\"d06c1ab5\":{\"m\":63,\"g\":94},\"c5644cac\":{\"m\":63,\"g\":94},\"53e6552f\":{\"m\":63,\"g\":94},\"5dc54f1a\":{\"m\":63,\"g\":94},\"f3e9b489\":{\"m\":63,\"g\":94},\"6a7973ad\":{\"m\":63,\"g\":94},\"63051738\":{\"m\":63,\"g\":94},\"a8ccacc8\":{\"m\":63,\"g\":94},\"0427416b\":{\"m\":63,\"g\":94},\"bf3edc2c\":{\"m\":63,\"g\":94},\"78e974b2\":{\"m\":63,\"g\":94},\"bc6915e3\":{\"m\":63,\"g\":94},\"a883f079\":{\"m\":63,\"g\":94},\"8b6ce52e\":{\"m\":63,\"g\":94},\"58f3f2b8\":{\"m\":63,\"g\":94},\"93d69061\":{\"m\":63,\"g\":94},\"e00e5385\":{\"m\":63,\"g\":94},\"a2f602b5\":{\"m\":63,\"g\":94},\"8f2c522a\":{\"m\":63,\"g\":94},\"75964177\":{\"m\":63,\"g\":94},\"2dc957d4\":{\"m\":63,\"g\":94},\"bf8d07a6\":{\"m\":63,\"g\":94},\"ab317936\":{\"m\":63,\"g\":94},\"b7f3fec1\":{\"m\":63,\"g\":94},\"58f42b1d\":{\"m\":63,\"g\":94},\"767c9dec\":{\"m\":63,\"g\":94},\"a53454c5\":{\"m\":63,\"g\":94},\"6cb3974e\":{\"m\":63,\"g\":94},\"f65c13b5\":{\"m\":63,\"g\":94},\"b803b395\":{\"m\":63,\"g\":94},\"bfbda62c\":{\"m\":63,\"g\":94},\"4ab43cfb\":{\"m\":64,\"g\":94},\"2f79f588\":{\"m\":64,\"g\":94},
\"8a96f749\":{\"m\":64,\"g\":94},\"827aa873\":{\"m\":64,\"g\":94},\"f8ca66fb\":{\"m\":64,\"g\":94},\"53cef815\":{\"m\":64,\"g\":94},\"351a72d4\":{\"m\":64,\"g\":94},\"514f37c3\":{\"m\":64,\"g\":94},\"52c03f16\":{\"m\":64,\"g\":94},\"741fccd7\":{\"m\":64,\"g\":94},\"1e3e5215\":{\"m\":64,\"g\":94},\"fb11a439\":{\"m\":64,\"g\":94},\"af02f99b\":{\"m\":64,\"g\":94},\"9472e699\":{\"m\":64,\"g\":94},\"1acc1f56\":{\"m\":64,\"g\":94},\"b045841b\":{\"m\":64,\"g\":94},\"f265d15b\":{\"m\":64,\"g\":94},\"02431b9a\":{\"m\":64,\"g\":94},\"1dda8c5e\":{\"m\":64,\"g\":94},\"7e097613\":{\"m\":64,\"g\":94},\"f4a92f4b\":{\"m\":64,\"g\":94},\"318260c0\":{\"m\":64,\"g\":94},\"4a612531\":{\"m\":64,\"g\":94},\"d1a08632\":{\"m\":64,\"g\":94},\"f8b28e46\":{\"m\":64,\"g\":94},\"82392da8\":{\"m\":64,\"g\":94},\"95f789ad\":{\"m\":64,\"g\":94},\"4f118a39\":{\"m\":64,\"g\":94},\"66283dbc\":{\"m\":64,\"g\":94},\"822bae8c\":{\"m\":64,\"g\":94},\"8e48ca8c\":{\"m\":64,\"g\":94},\"27acf63b\":{\"m\":64,\"g\":94},\"da6f8081\":{\"m\":64,\"g\":94},\"9286740e\":{\"m\":64,\"g\":94},\"896c0744\":{\"m\":64,\"g\":94},\"c23d5706\":{\"m\":64,\"g\":94},\"67ad4338\":{\"m\":64,\"g\":94},\"3cab5f71\":{\"m\":64,\"g\":94},\"14e754a8\":{\"m\":64,\"g\":94},\"98522149\":{\"m\":64,\"g\":94},\"5d9d15e7\":{\"m\":64,\"g\":94},\"665e5e85\":{\"m\":64,\"g\":94},\"a22f60a3\":{\"m\":64,\"g\":94},\"04f0b4cb\":{\"m\":64,\"g\":94},\"4505a436\":{\"m\":64,\"g\":94},\"685a5738\":{\"m\":64,\"g\":94},\"153b414e\":{\"m\":64,\"g\":94},\"6619f48e\":{\"m\":64,\"g\":94},\"3ed0a547\":{\"m\":64,\"g\":94},\"8d8ef849\":{\"m\":64,\"g\":94},\"9a0cc2e9\":{\"m\":64,\"g\":94},\"7bad7e75\":{\"m\":64,\"g\":94},\"1c4e0d24\":{\"m\":64,\"g\":94},\"54bac8af\":{\"m\":64,\"g\":94},\"5de4051b\":{\"m\":64,\"g\":94},\"e0cd65c2\":{\"m\":64,\"g\":94},\"f1b68618\":{\"m\":64,\"g\":94},\"0da0989a\":{\"m\":64,\"g\":94},\"07a22cbb\":{\"m\":64,\"g\":94},\"3d0bfa3e\":{\"m\":64,\"g\":94},\"1f6cf0d4\":{\"m\":64,\"g\":94},\"553f5a3f\":{\"m\":64,\"g\":94},\"ac2dc35d\":{\"m\":
64,\"g\":94},\"3e032c07\":{\"m\":64,\"g\":94},\"44e12ce4\":{\"m\":64,\"g\":94},\"a547aad6\":{\"m\":64,\"g\":94},\"ea535dc5\":{\"m\":64,\"g\":94},\"862bcff8\":{\"m\":64,\"g\":94},\"8b84e69f\":{\"m\":64,\"g\":94},\"5de50653\":{\"m\":64,\"g\":94},\"c0bf9bf1\":{\"m\":64,\"g\":94},\"022614d2\":{\"m\":64,\"g\":94},\"b8ab989f\":{\"m\":64,\"g\":94},\"b3393e94\":{\"m\":64,\"g\":94},\"ddc2001f\":{\"m\":64,\"g\":94},\"806a3002\":{\"m\":64,\"g\":94},\"0d2148ef\":{\"m\":64,\"g\":94},\"bf669606\":{\"m\":64,\"g\":94},\"b2bd8f44\":{\"m\":64,\"g\":94},\"9d9b482a\":{\"m\":64,\"g\":94},\"7353fb9b\":{\"m\":64,\"g\":94},\"bcda0c9e\":{\"m\":64,\"g\":94},\"9f8f2c7f\":{\"m\":64,\"g\":94},\"6fc37bd8\":{\"m\":64,\"g\":94},\"3d8f1c9b\":{\"m\":64,\"g\":94},\"a42213db\":{\"m\":64,\"g\":94},\"0ac019f1\":{\"m\":64,\"g\":94},\"5a0d680a\":{\"m\":64,\"g\":94},\"a4331cd2\":{\"m\":64,\"g\":94},\"ec1c21cd\":{\"m\":64,\"g\":94},\"6c856b4f\":{\"m\":64,\"g\":94},\"287d07a6\":{\"m\":64,\"g\":94},\"d2571dd5\":{\"m\":64,\"g\":94},\"b730aa6b\":{\"m\":64,\"g\":94},\"60b2a44a\":{\"m\":64,\"g\":94},\"949b3fbf\":{\"m\":64,\"g\":94},\"da4e8b38\":{\"m\":64,\"g\":94},\"af6c5357\":{\"m\":64,\"g\":94},\"3ad4cd49\":{\"m\":64,\"g\":94},\"3a8428ec\":{\"m\":64,\"g\":94},\"0311ce8e\":{\"m\":64,\"g\":94},\"5dfcacfc\":{\"m\":64,\"g\":94},\"41a0ccd4\":{\"m\":64,\"g\":94},\"cf0f7eaf\":{\"m\":65,\"g\":94},\"b49d6d0f\":{\"m\":65,\"g\":94},\"c02e3139\":{\"m\":65,\"g\":94},\"734daedd\":{\"m\":65,\"g\":94},\"3ee62235\":{\"m\":65,\"g\":94},\"9829e77e\":{\"m\":65,\"g\":94},\"cde4bbd5\":{\"m\":65,\"g\":94},\"9602c2aa\":{\"m\":65,\"g\":94},\"e81d7f11\":{\"m\":65,\"g\":94},\"222ce6f1\":{\"m\":65,\"g\":94},\"468d23cf\":{\"m\":65,\"g\":94},\"c38b5fb4\":{\"m\":65,\"g\":94},\"20453cef\":{\"m\":65,\"g\":94},\"9f635ea5\":{\"m\":65,\"g\":94},\"76285fde\":{\"m\":65,\"g\":94},\"988d0a4b\":{\"m\":65,\"g\":94},\"81262c7b\":{\"m\":65,\"g\":94},\"27aeb4b7\":{\"m\":65,\"g\":94},\"7b9b4f44\":{\"m\":65,\"g\":94},\"08104b56\":{\"m\":65,\"g\":94},\"cf142
b6e\":{\"m\":65,\"g\":94},\"7aad8d18\":{\"m\":66,\"g\":94},\"76fa2d15\":{\"m\":66,\"g\":94},\"7ab84948\":{\"m\":66,\"g\":94},\"4885b908\":{\"m\":66,\"g\":94},\"c2723a42\":{\"m\":66,\"g\":94},\"c7256ca8\":{\"m\":66,\"g\":94},\"6186a8f8\":{\"m\":66,\"g\":94},\"a07364cc\":{\"m\":66,\"g\":94},\"2c1a695f\":{\"m\":66,\"g\":94},\"d39899e8\":{\"m\":66,\"g\":94},\"70817a7e\":{\"m\":66,\"g\":94},\"7b5a3741\":{\"m\":66,\"g\":94},\"4b6f62e2\":{\"m\":66,\"g\":94},\"897e2e25\":{\"m\":66,\"g\":94},\"d54cee14\":{\"m\":66,\"g\":94},\"00fa7d04\":{\"m\":66,\"g\":94},\"013021b6\":{\"m\":66,\"g\":94},\"3c8ac78d\":{\"m\":66,\"g\":94},\"455bfe8d\":{\"m\":66,\"g\":94},\"28b0a62b\":{\"m\":66,\"g\":94},\"566d61d9\":{\"m\":66,\"g\":94},\"55f5fc68\":{\"m\":66,\"g\":94},\"c27c378a\":{\"m\":66,\"g\":94},\"d9eb9358\":{\"m\":66,\"g\":94},\"959dca4f\":{\"m\":66,\"g\":94},\"f2b3a318\":{\"m\":66,\"g\":94},\"ad674097\":{\"m\":66,\"g\":94},\"8db776f0\":{\"m\":66,\"g\":94},\"4eb4b401\":{\"m\":66,\"g\":94},\"17dbf976\":{\"m\":66,\"g\":94},\"53179026\":{\"m\":66,\"g\":94},\"d7c0b32f\":{\"m\":66,\"g\":94},\"7b020cca\":{\"m\":66,\"g\":94},\"7876279e\":{\"m\":66,\"g\":94},\"34e405e0\":{\"m\":66,\"g\":94},\"1ebe1d6d\":{\"m\":66,\"g\":94},\"7811bfda\":{\"m\":66,\"g\":94},\"656f7fc1\":{\"m\":66,\"g\":94},\"c1f5f99f\":{\"m\":67,\"g\":94},\"fa82dfcc\":{\"m\":67,\"g\":94},\"5da3d21c\":{\"m\":67,\"g\":94},\"f2870376\":{\"m\":67,\"g\":94},\"f9905d59\":{\"m\":67,\"g\":94},\"45c87e08\":{\"m\":67,\"g\":94},\"2b1808ce\":{\"m\":67,\"g\":94},\"e868d0b6\":{\"m\":67,\"g\":94},\"591e751e\":{\"m\":67,\"g\":94},\"40022d07\":{\"m\":67,\"g\":94},\"823148e7\":{\"m\":67,\"g\":94},\"76ca91df\":{\"m\":67,\"g\":94},\"cdae77b0\":{\"m\":67,\"g\":94},\"adeee152\":{\"m\":67,\"g\":94},\"6792411e\":{\"m\":67,\"g\":94},\"7348d962\":{\"m\":67,\"g\":94},\"25ed22b6\":{\"m\":67,\"g\":94},\"200d3b16\":{\"m\":67,\"g\":94},\"ad349985\":{\"m\":67,\"g\":94},\"32de54ed\":{\"m\":67,\"g\":94},\"2d9c3195\":{\"m\":67,\"g\":94},\"07e58a2d\":{\"m\":67,\"g\
":94},\"04d8cd20\":{\"m\":67,\"g\":94},\"a322051e\":{\"m\":67,\"g\":94},\"de553334\":{\"m\":67,\"g\":94},\"cddb1cdf\":{\"m\":68,\"g\":94},\"fa1b40e0\":{\"m\":68,\"g\":94},\"c45cab1c\":{\"m\":68,\"g\":94},\"27c4c9cf\":{\"m\":68,\"g\":94},\"52a492a1\":{\"m\":68,\"g\":94},\"36f6fc50\":{\"m\":68,\"g\":94},\"d8727275\":{\"m\":68,\"g\":94},\"6239d0b2\":{\"m\":68,\"g\":94},\"4cfd3add\":{\"m\":68,\"g\":94},\"20cf910d\":{\"m\":68,\"g\":94},\"0af1d239\":{\"m\":68,\"g\":94},\"85986bb9\":{\"m\":68,\"g\":94},\"64c87135\":{\"m\":68,\"g\":94},\"1646149a\":{\"m\":68,\"g\":94},\"bc72e5bd\":{\"m\":68,\"g\":94},\"014cab4d\":{\"m\":68,\"g\":94},\"4d2dbeac\":{\"m\":68,\"g\":94},\"29daf498\":{\"m\":68,\"g\":94},\"6702592d\":{\"m\":68,\"g\":94},\"60abdb3e\":{\"m\":68,\"g\":94},\"7b4e61ff\":{\"m\":68,\"g\":94},\"6222e1c2\":{\"m\":68,\"g\":94},\"fad315cb\":{\"m\":68,\"g\":94},\"f90db8bc\":{\"m\":68,\"g\":94},\"d8ad5970\":{\"m\":68,\"g\":94},\"849f58d6\":{\"m\":68,\"g\":94},\"64480df4\":{\"m\":68,\"g\":94},\"4530136e\":{\"m\":68,\"g\":94},\"0a6f18f0\":{\"m\":68,\"g\":94},\"e0b9a423\":{\"m\":69,\"g\":94},\"e0821425\":{\"m\":69,\"g\":94},\"70f894b8\":{\"m\":69,\"g\":94},\"368de366\":{\"m\":69,\"g\":94},\"20de05a7\":{\"m\":69,\"g\":94},\"f076328b\":{\"m\":69,\"g\":94},\"bf2a7087\":{\"m\":69,\"g\":94},\"871a4aa1\":{\"m\":69,\"g\":94},\"98eecbda\":{\"m\":69,\"g\":94},\"4430c0a5\":{\"m\":69,\"g\":94},\"640363ad\":{\"m\":69,\"g\":94},\"8616357a\":{\"m\":69,\"g\":94},\"8adbc78b\":{\"m\":69,\"g\":94},\"45e3a7bc\":{\"m\":69,\"g\":94},\"b96e92e6\":{\"m\":69,\"g\":94},\"693c2600\":{\"m\":69,\"g\":94},\"ced68066\":{\"m\":69,\"g\":94},\"b8318aec\":{\"m\":69,\"g\":94},\"2f482210\":{\"m\":69,\"g\":94},\"d81ac443\":{\"m\":69,\"g\":94},\"2491cc92\":{\"m\":69,\"g\":94},\"67c5de92\":{\"m\":69,\"g\":94},\"1e2cf2b5\":{\"m\":69,\"g\":94},\"9490d157\":{\"m\":69,\"g\":94},\"eefcbdd3\":{\"m\":69,\"g\":94},\"7e6d5fc6\":{\"m\":69,\"g\":94},\"cadd5dbe\":{\"m\":69,\"g\":94},\"bb418ced\":{\"m\":69,\"g\":94},\"fdf04a14\":{
\"m\":69,\"g\":94},\"5f0e7de3\":{\"m\":69,\"g\":94},\"2f47d710\":{\"m\":69,\"g\":94},\"4fe92bfc\":{\"m\":69,\"g\":94},\"d23cb9a0\":{\"m\":69,\"g\":94},\"2d611323\":{\"m\":69,\"g\":94},\"e782eb7e\":{\"m\":70,\"g\":94},\"e319153b\":{\"m\":70,\"g\":94},\"32b44d2f\":{\"m\":70,\"g\":94},\"5f1a485d\":{\"m\":70,\"g\":94},\"c9565e49\":{\"m\":70,\"g\":94},\"d03c4c25\":{\"m\":70,\"g\":94},\"8f13377d\":{\"m\":70,\"g\":94},\"3d4a8f9b\":{\"m\":70,\"g\":94},\"7474bed8\":{\"m\":70,\"g\":94},\"03caefeb\":{\"m\":70,\"g\":94},\"bcc213df\":{\"m\":70,\"g\":94},\"39416e39\":{\"m\":70,\"g\":94},\"231c40d8\":{\"m\":70,\"g\":94},\"bbc47c34\":{\"m\":70,\"g\":94},\"dfce9269\":{\"m\":70,\"g\":94},\"6718b109\":{\"m\":70,\"g\":94},\"7711ac6e\":{\"m\":70,\"g\":94},\"7443197a\":{\"m\":70,\"g\":94},\"862dd76c\":{\"m\":70,\"g\":94},\"fb4c9c3a\":{\"m\":70,\"g\":94},\"d973c78e\":{\"m\":70,\"g\":94},\"6ce6eabb\":{\"m\":70,\"g\":94},\"4e23c961\":{\"m\":70,\"g\":94},\"3efbdf68\":{\"m\":70,\"g\":94},\"6cc30955\":{\"m\":70,\"g\":94},\"31eec35b\":{\"m\":70,\"g\":94},\"ac963be2\":{\"m\":70,\"g\":94},\"a5375adc\":{\"m\":71,\"g\":94},\"75d171a9\":{\"m\":71,\"g\":94},\"714f3e63\":{\"m\":71,\"g\":94},\"c38f3aed\":{\"m\":71,\"g\":94},\"2e6be53e\":{\"m\":71,\"g\":94},\"fc671f66\":{\"m\":72,\"g\":94},\"197751e9\":{\"m\":72,\"g\":94},\"d2d0d061\":{\"m\":72,\"g\":94},\"25482edb\":{\"m\":72,\"g\":94},\"62b362b1\":{\"m\":72,\"g\":94},\"44d76463\":{\"m\":72,\"g\":94},\"cd85b78f\":{\"m\":72,\"g\":94},\"0aaccbbf\":{\"m\":72,\"g\":94},\"357671e2\":{\"m\":72,\"g\":94},\"e70fa279\":{\"m\":72,\"g\":94},\"abe74b7b\":{\"m\":72,\"g\":94},\"70b3c6ee\":{\"m\":72,\"g\":94},\"ef9d3b3c\":{\"m\":72,\"g\":94},\"fc91d08a\":{\"m\":72,\"g\":94},\"71ab0dab\":{\"m\":72,\"g\":94},\"d3d4d767\":{\"m\":72,\"g\":94},\"5be8f1ed\":{\"m\":72,\"g\":94},\"e5760bc4\":{\"m\":72,\"g\":94},\"56a724eb\":{\"m\":72,\"g\":94},\"583d6af7\":{\"m\":72,\"g\":94},\"e074d84e\":{\"m\":72,\"g\":94},\"4725e3f6\":{\"m\":72,\"g\":94},\"77a3954b\":{\"m\":72,\"g\":94},\
"03b0364f\":{\"m\":72,\"g\":94},\"2dd7d0c5\":{\"m\":72,\"g\":94},\"0d4e3228\":{\"m\":72,\"g\":94},\"926f8efc\":{\"m\":72,\"g\":94},\"9545bfb2\":{\"m\":72,\"g\":94},\"37373ef2\":{\"m\":72,\"g\":94},\"61261b39\":{\"m\":72,\"g\":94},\"19120f71\":{\"m\":72,\"g\":94},\"2415ec38\":{\"m\":72,\"g\":94},\"87f671ab\":{\"m\":72,\"g\":94},\"51d25405\":{\"m\":72,\"g\":94},\"e0a2c963\":{\"m\":72,\"g\":94},\"12f2e6c3\":{\"m\":72,\"g\":94},\"95575aa7\":{\"m\":72,\"g\":94},\"11eea69e\":{\"m\":72,\"g\":94},\"1baa9e6c\":{\"m\":72,\"g\":94},\"911fcd09\":{\"m\":72,\"g\":94},\"9fafa62d\":{\"m\":72,\"g\":94},\"146ac8df\":{\"m\":72,\"g\":94},\"57a404fd\":{\"m\":72,\"g\":94},\"2796fbb5\":{\"m\":72,\"g\":94},\"935cda94\":{\"m\":72,\"g\":94},\"110e0066\":{\"m\":72,\"g\":94},\"6b45a21d\":{\"m\":72,\"g\":94},\"a7000a76\":{\"m\":72,\"g\":94},\"1a8f995c\":{\"m\":72,\"g\":94},\"a3ab768a\":{\"m\":72,\"g\":94},\"66301e12\":{\"m\":72,\"g\":94},\"ac238727\":{\"m\":72,\"g\":94},\"0194948f\":{\"m\":72,\"g\":94},\"b4d34cd3\":{\"m\":72,\"g\":94},\"728e175f\":{\"m\":72,\"g\":94},\"9e1014cf\":{\"m\":72,\"g\":94},\"fa561067\":{\"m\":72,\"g\":94},\"7fbab730\":{\"m\":72,\"g\":94},\"b7e274f2\":{\"m\":72,\"g\":94},\"9cf40772\":{\"m\":72,\"g\":94},\"d3fe9bae\":{\"m\":72,\"g\":94},\"00ce7e31\":{\"m\":72,\"g\":94},\"50f28f65\":{\"m\":72,\"g\":94},\"90a55e25\":{\"m\":72,\"g\":94},\"407e2b92\":{\"m\":72,\"g\":94},\"40782f05\":{\"m\":72,\"g\":94},\"18bb216c\":{\"m\":72,\"g\":94},\"6b859e7d\":{\"m\":72,\"g\":94},\"930da877\":{\"m\":72,\"g\":94},\"3f8a4414\":{\"m\":72,\"g\":94},\"aceb4201\":{\"m\":72,\"g\":94},\"90a4b7d9\":{\"m\":72,\"g\":94},\"f3b99f73\":{\"m\":72,\"g\":94},\"9e74ee91\":{\"m\":72,\"g\":94},\"77a6c9d2\":{\"m\":72,\"g\":94},\"e3e0bc50\":{\"m\":72,\"g\":94},\"bac414ab\":{\"m\":72,\"g\":94},\"eec3f6d1\":{\"m\":72,\"g\":94},\"90bc26a8\":{\"m\":72,\"g\":94},\"ec0a72c2\":{\"m\":72,\"g\":94},\"1c96fa86\":{\"m\":72,\"g\":94},\"bc20e93f\":{\"m\":72,\"g\":94},\"d3887852\":{\"m\":72,\"g\":94},\"564bdf29\":{\"m\":7
2,\"g\":94},\"5d860168\":{\"m\":72,\"g\":94},\"d2815879\":{\"m\":72,\"g\":94},\"b0df5d24\":{\"m\":72,\"g\":94},\"3e02526b\":{\"m\":72,\"g\":94},\"d8a98a2c\":{\"m\":72,\"g\":94},\"0519269d\":{\"m\":72,\"g\":94},\"d6898dd2\":{\"m\":72,\"g\":94},\"71ed0183\":{\"m\":72,\"g\":94},\"8b681d77\":{\"m\":72,\"g\":94},\"194eea17\":{\"m\":72,\"g\":94},\"acd1a159\":{\"m\":72,\"g\":94},\"7c1692aa\":{\"m\":72,\"g\":94},\"8f019c7d\":{\"m\":72,\"g\":94},\"7551498a\":{\"m\":72,\"g\":94},\"44a2c4bd\":{\"m\":72,\"g\":94},\"c9fc4a9d\":{\"m\":72,\"g\":94},\"21463e32\":{\"m\":72,\"g\":94},\"3dc9ff3c\":{\"m\":72,\"g\":94},\"06427dfa\":{\"m\":72,\"g\":94},\"60524920\":{\"m\":72,\"g\":94},\"10771026\":{\"m\":72,\"g\":94},\"4606e2a3\":{\"m\":72,\"g\":94},\"127998cc\":{\"m\":72,\"g\":94},\"c0bb9eb3\":{\"m\":72,\"g\":94},\"7036d6fc\":{\"m\":72,\"g\":94},\"6ce9dbe8\":{\"m\":72,\"g\":94},\"3758d209\":{\"m\":72,\"g\":94},\"faf29e0b\":{\"m\":72,\"g\":94},\"b0743ea0\":{\"m\":72,\"g\":94},\"60b771c8\":{\"m\":72,\"g\":94},\"d7934cde\":{\"m\":72,\"g\":94},\"62bbd343\":{\"m\":72,\"g\":94},\"f2388f6b\":{\"m\":72,\"g\":94},\"c9745ee0\":{\"m\":72,\"g\":94},\"1a6e9757\":{\"m\":72,\"g\":94},\"b1100846\":{\"m\":72,\"g\":94},\"27a46317\":{\"m\":72,\"g\":94},\"c9795808\":{\"m\":72,\"g\":94},\"6c7a152c\":{\"m\":72,\"g\":94},\"4d2a88bd\":{\"m\":72,\"g\":94},\"45360b2f\":{\"m\":72,\"g\":94},\"3f41b184\":{\"m\":72,\"g\":94},\"45205d88\":{\"m\":72,\"g\":94},\"90876940\":{\"m\":72,\"g\":94},\"a3339d8c\":{\"m\":72,\"g\":94},\"14d90617\":{\"m\":72,\"g\":94},\"d37f9551\":{\"m\":72,\"g\":94},\"c66b2c9c\":{\"m\":72,\"g\":94},\"20b765a2\":{\"m\":72,\"g\":94},\"e3107222\":{\"m\":72,\"g\":94},\"e074e76b\":{\"m\":72,\"g\":94},\"4592afc2\":{\"m\":72,\"g\":94},\"9af0e21e\":{\"m\":72,\"g\":94},\"c7c79b16\":{\"m\":72,\"g\":94},\"d8d75d25\":{\"m\":72,\"g\":94},\"1df6eabd\":{\"m\":72,\"g\":94},\"0c227ee3\":{\"m\":72,\"g\":94},\"5c54ef03\":{\"m\":72,\"g\":94},\"c6a48521\":{\"m\":72,\"g\":94},\"4f678c87\":{\"m\":72,\"g\":94},\"e79f74
20\":{\"m\":72,\"g\":94},\"ac053100\":{\"m\":72,\"g\":94},\"d5d80ab4\":{\"m\":72,\"g\":94},\"ddcf9fe3\":{\"m\":72,\"g\":94},\"6252ade9\":{\"m\":72,\"g\":94},\"1eb8eade\":{\"m\":72,\"g\":94},\"3c7bfd7e\":{\"m\":72,\"g\":94},\"bb121214\":{\"m\":72,\"g\":94},\"55de40f7\":{\"m\":72,\"g\":94},\"6b0aeb58\":{\"m\":72,\"g\":94},\"bb3e5268\":{\"m\":72,\"g\":94},\"f93e9158\":{\"m\":72,\"g\":94},\"55a7ec38\":{\"m\":72,\"g\":94},\"fe0673f1\":{\"m\":72,\"g\":94},\"99c1b9d2\":{\"m\":72,\"g\":94},\"634a3561\":{\"m\":72,\"g\":94},\"424848d2\":{\"m\":72,\"g\":94},\"e5ce395a\":{\"m\":72,\"g\":94},\"f983213a\":{\"m\":72,\"g\":94},\"67fc595b\":{\"m\":72,\"g\":94},\"07ab4d4a\":{\"m\":72,\"g\":94},\"522e18ea\":{\"m\":72,\"g\":94},\"c51dc2cc\":{\"m\":72,\"g\":94},\"ddf39d3f\":{\"m\":72,\"g\":94},\"2eab1132\":{\"m\":72,\"g\":94},\"058d199d\":{\"m\":72,\"g\":94},\"9c58e68b\":{\"m\":73,\"g\":94},\"d03b3467\":{\"m\":73,\"g\":94},\"ab7fba0e\":{\"m\":73,\"g\":94},\"bc1534ff\":{\"m\":73,\"g\":94},\"3a391812\":{\"m\":73,\"g\":94},\"800bf018\":{\"m\":73,\"g\":94},\"b16af90b\":{\"m\":73,\"g\":94},\"98c73d71\":{\"m\":73,\"g\":94},\"fcc2e37f\":{\"m\":73,\"g\":94},\"0804dd11\":{\"m\":73,\"g\":94},\"55dc8e4d\":{\"m\":73,\"g\":94},\"02e9e9f1\":{\"m\":73,\"g\":94},\"8f0b6313\":{\"m\":73,\"g\":94},\"b9b3b098\":{\"m\":73,\"g\":94},\"aee30630\":{\"m\":73,\"g\":94},\"286e6540\":{\"m\":73,\"g\":94},\"718c391f\":{\"m\":73,\"g\":94},\"6aaeb848\":{\"m\":74,\"g\":94},\"3623b6a7\":{\"m\":74,\"g\":94},\"4ff12642\":{\"m\":74,\"g\":94},\"2a4cbad8\":{\"m\":74,\"g\":94},\"2937387a\":{\"m\":74,\"g\":94},\"cf721fde\":{\"m\":74,\"g\":94},\"45de8971\":{\"m\":74,\"g\":94},\"71046fcd\":{\"m\":74,\"g\":94},\"c76040e3\":{\"m\":74,\"g\":94},\"2f6bacee\":{\"m\":74,\"g\":94},\"40148041\":{\"m\":74,\"g\":94},\"ad46550d\":{\"m\":74,\"g\":94},\"14344caa\":{\"m\":74,\"g\":94},\"f7f88b70\":{\"m\":74,\"g\":94},\"18c27131\":{\"m\":74,\"g\":94},\"ccdd10c8\":{\"m\":74,\"g\":94},\"76f6c0eb\":{\"m\":74,\"g\":94},\"959a3143\":{\"m\":74,\"g\"
:94},\"6412c5e4\":{\"m\":74,\"g\":94},\"0c020860\":{\"m\":74,\"g\":94},\"85ef7f64\":{\"m\":74,\"g\":94},\"f1cf6eef\":{\"m\":74,\"g\":94},\"0a59a465\":{\"m\":74,\"g\":94},\"aff79f10\":{\"m\":74,\"g\":94},\"01603318\":{\"m\":74,\"g\":94},\"56c39a05\":{\"m\":74,\"g\":94},\"4068e012\":{\"m\":74,\"g\":94},\"817d4370\":{\"m\":74,\"g\":94},\"c550e52f\":{\"m\":74,\"g\":94},\"e35a93fa\":{\"m\":74,\"g\":94},\"2c3656f2\":{\"m\":74,\"g\":94},\"d40ee62b\":{\"m\":74,\"g\":94},\"91b19949\":{\"m\":74,\"g\":94},\"7c866711\":{\"m\":74,\"g\":94},\"10b544ae\":{\"m\":74,\"g\":94},\"01090e8a\":{\"m\":74,\"g\":94},\"6f43a9b9\":{\"m\":74,\"g\":94},\"0540fef7\":{\"m\":74,\"g\":94},\"481f608b\":{\"m\":74,\"g\":94},\"ed91561f\":{\"m\":74,\"g\":94},\"6e7239f9\":{\"m\":74,\"g\":94},\"0a3960f2\":{\"m\":74,\"g\":94},\"07f94463\":{\"m\":74,\"g\":94},\"e0917e6b\":{\"m\":74,\"g\":94},\"7130a7ce\":{\"m\":74,\"g\":94},\"8f1f614e\":{\"m\":74,\"g\":94},\"7140ba35\":{\"m\":74,\"g\":94},\"d1da58e2\":{\"m\":74,\"g\":94},\"1cf63485\":{\"m\":74,\"g\":94},\"ff2ce0b8\":{\"m\":74,\"g\":94},\"0f2a2e3c\":{\"m\":74,\"g\":94},\"690e1f23\":{\"m\":74,\"g\":94},\"00f42707\":{\"m\":74,\"g\":94},\"6a02b32d\":{\"m\":74,\"g\":94},\"3a08f546\":{\"m\":74,\"g\":94},\"dce303e2\":{\"m\":74,\"g\":94},\"4d27eb9a\":{\"m\":74,\"g\":94},\"d3ecd632\":{\"m\":74,\"g\":94},\"cd909455\":{\"m\":74,\"g\":94},\"bde24ab3\":{\"m\":74,\"g\":94},\"bf2eefc0\":{\"m\":74,\"g\":94},\"5524e7d0\":{\"m\":74,\"g\":94},\"e187a3d5\":{\"m\":74,\"g\":94},\"3dd4feae\":{\"m\":74,\"g\":94},\"2ac189ed\":{\"m\":74,\"g\":94},\"5a6400ee\":{\"m\":74,\"g\":94},\"cf0ccd40\":{\"m\":74,\"g\":94},\"3d56585a\":{\"m\":74,\"g\":94},\"00d25a7f\":{\"m\":74,\"g\":94},\"1a5023e0\":{\"m\":74,\"g\":94},\"23308a90\":{\"m\":74,\"g\":94},\"ac698850\":{\"m\":74,\"g\":94},\"aa957102\":{\"m\":74,\"g\":94},\"007f8b3d\":{\"m\":74,\"g\":94},\"4455b26e\":{\"m\":74,\"g\":94},\"c553e160\":{\"m\":74,\"g\":94},\"7c0541b3\":{\"m\":74,\"g\":94},\"e8a69e4d\":{\"m\":74,\"g\":94},\"fbd56002\":{\
"m\":74,\"g\":94},\"730d084f\":{\"m\":74,\"g\":94},\"4a05bdfa\":{\"m\":74,\"g\":94},\"eb06dbcb\":{\"m\":74,\"g\":94},\"9dfafa74\":{\"m\":74,\"g\":94},\"f1d09a65\":{\"m\":74,\"g\":94},\"df84ab2a\":{\"m\":74,\"g\":94},\"34c88987\":{\"m\":74,\"g\":94},\"0dd6cda2\":{\"m\":74,\"g\":94},\"9fb48f95\":{\"m\":74,\"g\":94},\"89ccb533\":{\"m\":74,\"g\":94},\"dceb256f\":{\"m\":74,\"g\":94},\"0e90ae62\":{\"m\":74,\"g\":94},\"1361ab9e\":{\"m\":74,\"g\":94},\"5c7dd14b\":{\"m\":74,\"g\":94},\"8abf74e3\":{\"m\":74,\"g\":94},\"ee132a45\":{\"m\":74,\"g\":94},\"79a321af\":{\"m\":74,\"g\":94},\"6eec3cdc\":{\"m\":74,\"g\":94},\"48473684\":{\"m\":74,\"g\":94},\"b3251e9f\":{\"m\":74,\"g\":94},\"2cadd51d\":{\"m\":74,\"g\":94},\"4a893d14\":{\"m\":74,\"g\":94},\"8d323e95\":{\"m\":74,\"g\":94},\"0fe7c13b\":{\"m\":74,\"g\":94},\"08c4d764\":{\"m\":74,\"g\":94},\"96d0e37f\":{\"m\":74,\"g\":94},\"90bb2be2\":{\"m\":74,\"g\":94},\"b93ef5e5\":{\"m\":74,\"g\":94},\"d4017a6b\":{\"m\":74,\"g\":94},\"d052f4c8\":{\"m\":74,\"g\":94},\"e1aaa79a\":{\"m\":74,\"g\":94},\"20c81199\":{\"m\":74,\"g\":94},\"70866b6f\":{\"m\":74,\"g\":94},\"eb61f5c9\":{\"m\":74,\"g\":94},\"0beea450\":{\"m\":74,\"g\":94},\"c827c671\":{\"m\":74,\"g\":94},\"b55a621f\":{\"m\":74,\"g\":94},\"ffa1b3e3\":{\"m\":74,\"g\":94},\"7e3bb527\":{\"m\":74,\"g\":94},\"96263f27\":{\"m\":74,\"g\":94},\"9376ac36\":{\"m\":74,\"g\":94},\"94a2b9d3\":{\"m\":74,\"g\":94},\"3c3eb374\":{\"m\":74,\"g\":94},\"d557319a\":{\"m\":74,\"g\":94},\"95085d65\":{\"m\":74,\"g\":94},\"c7f25446\":{\"m\":74,\"g\":94},\"63ee26d1\":{\"m\":74,\"g\":94},\"ad55f171\":{\"m\":74,\"g\":94},\"361971b8\":{\"m\":74,\"g\":94},\"13bc39c5\":{\"m\":74,\"g\":94},\"9854a18a\":{\"m\":74,\"g\":94},\"ebddb65a\":{\"m\":74,\"g\":94},\"19fd57bc\":{\"m\":74,\"g\":94},\"ba80c102\":{\"m\":75,\"g\":94},\"fbdb5050\":{\"m\":75,\"g\":94},\"f0afaf52\":{\"m\":75,\"g\":94},\"85d2365d\":{\"m\":75,\"g\":94},\"5fe79605\":{\"m\":75,\"g\":94},\"c6d7f8d3\":{\"m\":75,\"g\":94},\"a5a892ff\":{\"m\":75,\"g\":94},\"
8e66fbec\":{\"m\":75,\"g\":94},\"f141298a\":{\"m\":75,\"g\":94},\"4fea040c\":{\"m\":75,\"g\":94},\"1099f6c9\":{\"m\":76,\"g\":94},\"04e3ff69\":{\"m\":76,\"g\":94},\"45fdf1f7\":{\"m\":76,\"g\":94},\"d89c0e4b\":{\"m\":76,\"g\":94},\"fa3c9e06\":{\"m\":76,\"g\":94},\"0d658ac3\":{\"m\":76,\"g\":94},\"ced35a06\":{\"m\":76,\"g\":94},\"26f07294\":{\"m\":76,\"g\":94},\"34e07a65\":{\"m\":76,\"g\":94},\"15ddd843\":{\"m\":76,\"g\":94},\"52029bd1\":{\"m\":76,\"g\":94},\"eb934bdf\":{\"m\":76,\"g\":94},\"e45ae444\":{\"m\":76,\"g\":94},\"ac3fae84\":{\"m\":76,\"g\":94},\"2d1b83e5\":{\"m\":76,\"g\":94},\"199bb01d\":{\"m\":76,\"g\":94},\"6b7038ba\":{\"m\":76,\"g\":94},\"57eec0bf\":{\"m\":76,\"g\":94},\"f01b0925\":{\"m\":76,\"g\":94},\"14269198\":{\"m\":76,\"g\":94},\"9b7cf9ee\":{\"m\":76,\"g\":94},\"1e86457c\":{\"m\":76,\"g\":94},\"64129fa6\":{\"m\":76,\"g\":94},\"e9f8e423\":{\"m\":76,\"g\":94},\"22c3702e\":{\"m\":76,\"g\":94},\"4c584fc6\":{\"m\":76,\"g\":94},\"77cf771e\":{\"m\":76,\"g\":94},\"8154de5a\":{\"m\":76,\"g\":94},\"c11cfda0\":{\"m\":76,\"g\":94},\"64edeb79\":{\"m\":76,\"g\":94},\"65c24c28\":{\"m\":76,\"g\":94},\"3980ff1b\":{\"m\":76,\"g\":94},\"5d7edc8e\":{\"m\":76,\"g\":94},\"af6535e7\":{\"m\":76,\"g\":94},\"93cf7fc5\":{\"m\":76,\"g\":94},\"2a206b22\":{\"m\":76,\"g\":94},\"4d253057\":{\"m\":76,\"g\":94},\"11577ced\":{\"m\":76,\"g\":94},\"ca75741e\":{\"m\":76,\"g\":94},\"c6d549e7\":{\"m\":76,\"g\":94},\"3c09548d\":{\"m\":76,\"g\":94},\"8796cebb\":{\"m\":76,\"g\":94},\"c2bd094d\":{\"m\":76,\"g\":94},\"f8f9244a\":{\"m\":76,\"g\":94},\"ecbfe58b\":{\"m\":76,\"g\":94},\"8f163b16\":{\"m\":76,\"g\":94},\"e7a8610d\":{\"m\":76,\"g\":94},\"a2cc62a6\":{\"m\":76,\"g\":94},\"fb888603\":{\"m\":76,\"g\":94},\"321ab756\":{\"m\":76,\"g\":94},\"38f25e87\":{\"m\":76,\"g\":94},\"8cd42504\":{\"m\":76,\"g\":94},\"6a384d5c\":{\"m\":76,\"g\":94},\"f69e0696\":{\"m\":76,\"g\":94},\"f6ab4ca6\":{\"m\":76,\"g\":94},\"c7c7dbeb\":{\"m\":76,\"g\":94},\"417fc72f\":{\"m\":76,\"g\":94},\"c6ec7029\":{\"m\":76
,\"g\":94},\"4c56e5db\":{\"m\":76,\"g\":94},\"7b5fc719\":{\"m\":76,\"g\":94},\"ad4e58bf\":{\"m\":76,\"g\":94},\"bfb03c61\":{\"m\":76,\"g\":94},\"b36ab493\":{\"m\":76,\"g\":94},\"9e93ef3f\":{\"m\":76,\"g\":94},\"fad86a68\":{\"m\":76,\"g\":94},\"df7014a8\":{\"m\":76,\"g\":94},\"49420741\":{\"m\":76,\"g\":94},\"ba52fd18\":{\"m\":76,\"g\":94},\"b6944f97\":{\"m\":76,\"g\":94},\"f44db16c\":{\"m\":76,\"g\":94},\"f9c53cbb\":{\"m\":76,\"g\":94},\"90532b76\":{\"m\":76,\"g\":94},\"c0e9a36c\":{\"m\":76,\"g\":94},\"588865f0\":{\"m\":76,\"g\":94},\"3196999f\":{\"m\":76,\"g\":94},\"9e0186f3\":{\"m\":76,\"g\":94},\"8baf9a0c\":{\"m\":76,\"g\":94},\"c7872985\":{\"m\":76,\"g\":94},\"45212ce1\":{\"m\":76,\"g\":94},\"c16b33cc\":{\"m\":76,\"g\":94},\"2d004512\":{\"m\":76,\"g\":94},\"804d250a\":{\"m\":76,\"g\":94},\"dd865bef\":{\"m\":76,\"g\":94},\"d373a48c\":{\"m\":76,\"g\":94},\"98be3bd3\":{\"m\":76,\"g\":94},\"a98290ae\":{\"m\":76,\"g\":94},\"9b81f9bd\":{\"m\":76,\"g\":94},\"f81a27f6\":{\"m\":76,\"g\":94},\"988ab646\":{\"m\":76,\"g\":94},\"3ded4b21\":{\"m\":76,\"g\":94},\"f4d7ab7a\":{\"m\":76,\"g\":94},\"c38ca4fc\":{\"m\":76,\"g\":94},\"82dec1f7\":{\"m\":76,\"g\":94},\"5f9b2c62\":{\"m\":76,\"g\":94},\"5493c334\":{\"m\":76,\"g\":94},\"f2ab37e5\":{\"m\":76,\"g\":94},\"91ba98fe\":{\"m\":76,\"g\":94},\"c614dbdf\":{\"m\":76,\"g\":94},\"927ca935\":{\"m\":76,\"g\":94},\"ef3c2dd0\":{\"m\":76,\"g\":94},\"75b65648\":{\"m\":76,\"g\":94},\"0f52fb55\":{\"m\":76,\"g\":94},\"d6d21640\":{\"m\":76,\"g\":94},\"0212d2e2\":{\"m\":76,\"g\":94},\"8cc300f5\":{\"m\":76,\"g\":94},\"452db508\":{\"m\":76,\"g\":94},\"d1112d85\":{\"m\":76,\"g\":94},\"48efec7b\":{\"m\":76,\"g\":94},\"9b8333d9\":{\"m\":76,\"g\":94},\"f5bbf603\":{\"m\":76,\"g\":94},\"5cbd709e\":{\"m\":76,\"g\":94},\"2e4a1e2d\":{\"m\":76,\"g\":94},\"9d02bb3e\":{\"m\":76,\"g\":94},\"402db5c5\":{\"m\":76,\"g\":94},\"754a0e82\":{\"m\":76,\"g\":94},\"799fb5f4\":{\"m\":76,\"g\":94},\"25e1816e\":{\"m\":76,\"g\":94},\"a53fe428\":{\"m\":76,\"g\":94},\"1b85929
5\":{\"m\":76,\"g\":94},\"9971dc22\":{\"m\":76,\"g\":94},\"3db35c1a\":{\"m\":76,\"g\":94},\"52a34d74\":{\"m\":76,\"g\":94},\"06d12b39\":{\"m\":76,\"g\":94},\"c30976fb\":{\"m\":76,\"g\":94},\"1a3fa75f\":{\"m\":76,\"g\":94},\"81f431ed\":{\"m\":76,\"g\":94},\"65b7c9b7\":{\"m\":76,\"g\":94},\"2c4f5cca\":{\"m\":76,\"g\":94},\"15843047\":{\"m\":76,\"g\":94},\"8ec2ce07\":{\"m\":76,\"g\":94},\"1fd0cf8a\":{\"m\":76,\"g\":94},\"bf63ee54\":{\"m\":76,\"g\":94},\"22c96f78\":{\"m\":76,\"g\":94},\"2892b9bb\":{\"m\":76,\"g\":94},\"470b4740\":{\"m\":76,\"g\":94},\"26c372c1\":{\"m\":76,\"g\":94},\"86d9baed\":{\"m\":76,\"g\":94},\"21d485f8\":{\"m\":76,\"g\":94},\"035ac2ab\":{\"m\":76,\"g\":94},\"e1a5e7e4\":{\"m\":76,\"g\":94},\"ad1ae7f7\":{\"m\":76,\"g\":94},\"e73167ad\":{\"m\":76,\"g\":94},\"862fe522\":{\"m\":76,\"g\":94},\"61e4433c\":{\"m\":76,\"g\":94},\"660305c3\":{\"m\":76,\"g\":94},\"642ab418\":{\"m\":76,\"g\":94},\"1ce4878d\":{\"m\":76,\"g\":94},\"977d7cd2\":{\"m\":76,\"g\":94},\"0e0ec702\":{\"m\":76,\"g\":94},\"bb378556\":{\"m\":76,\"g\":94},\"19e96e59\":{\"m\":77,\"g\":94},\"aa08aeac\":{\"m\":77,\"g\":94},\"d8a136a1\":{\"m\":77,\"g\":94},\"20c90be2\":{\"m\":77,\"g\":94},\"ec3ee028\":{\"m\":77,\"g\":94},\"92941ce7\":{\"m\":77,\"g\":94},\"2bb0e7cf\":{\"m\":77,\"g\":94},\"72549263\":{\"m\":77,\"g\":94},\"044c3159\":{\"m\":77,\"g\":94},\"4db29e82\":{\"m\":77,\"g\":94},\"c483377e\":{\"m\":77,\"g\":94},\"74e0ac1d\":{\"m\":77,\"g\":94},\"ef9a378a\":{\"m\":77,\"g\":94},\"6dea5c96\":{\"m\":77,\"g\":94},\"6ffb6bd4\":{\"m\":77,\"g\":94},\"47e6628a\":{\"m\":77,\"g\":94},\"7907f9eb\":{\"m\":77,\"g\":94},\"8c04f0f2\":{\"m\":77,\"g\":94},\"265e7564\":{\"m\":77,\"g\":94},\"d3f71f5e\":{\"m\":77,\"g\":94},\"5eae67cb\":{\"m\":77,\"g\":94},\"6dbf9998\":{\"m\":77,\"g\":94},\"e0166f8a\":{\"m\":77,\"g\":94},\"53a2c3b4\":{\"m\":77,\"g\":94},\"550586ef\":{\"m\":77,\"g\":94},\"cf29fe9e\":{\"m\":77,\"g\":94},\"26c0f131\":{\"m\":77,\"g\":94},\"f9970bd1\":{\"m\":77,\"g\":94},\"2e0f94ab\":{\"m\":77,\"g\":
94},\"18317ddc\":{\"m\":77,\"g\":94},\"e2e2ab70\":{\"m\":77,\"g\":94},\"0d3e3072\":{\"m\":77,\"g\":94},\"62dd9587\":{\"m\":77,\"g\":94},\"72031173\":{\"m\":77,\"g\":94},\"9fdc6d6a\":{\"m\":77,\"g\":94},\"42a45df0\":{\"m\":77,\"g\":94},\"04eb6062\":{\"m\":77,\"g\":94},\"e84f4ba0\":{\"m\":77,\"g\":94},\"b149b393\":{\"m\":77,\"g\":94},\"31dfff7d\":{\"m\":77,\"g\":94},\"10a9ab7b\":{\"m\":77,\"g\":94},\"bb0fd749\":{\"m\":77,\"g\":94},\"7f19e083\":{\"m\":77,\"g\":94},\"98a2cfa9\":{\"m\":77,\"g\":94},\"2a882e8f\":{\"m\":77,\"g\":94},\"e6e4d022\":{\"m\":77,\"g\":94},\"188105a2\":{\"m\":77,\"g\":94},\"b3953258\":{\"m\":77,\"g\":94},\"5fa3058f\":{\"m\":77,\"g\":94},\"bbab97a6\":{\"m\":77,\"g\":94},\"0bc0bf57\":{\"m\":77,\"g\":94},\"f60f2931\":{\"m\":77,\"g\":94},\"17000d2b\":{\"m\":77,\"g\":94},\"668ecc6c\":{\"m\":77,\"g\":94},\"886fcbdd\":{\"m\":77,\"g\":94},\"8bf6d7f4\":{\"m\":77,\"g\":94},\"1b9175cb\":{\"m\":77,\"g\":94},\"92bb49a7\":{\"m\":77,\"g\":94},\"6f5cc5eb\":{\"m\":77,\"g\":94},\"c913ed40\":{\"m\":77,\"g\":94},\"1afe3d07\":{\"m\":77,\"g\":94},\"44f47d3e\":{\"m\":77,\"g\":94},\"ae25d36d\":{\"m\":77,\"g\":94},\"35e0856b\":{\"m\":78,\"g\":94},\"aba5ca15\":{\"m\":78,\"g\":94},\"496dde84\":{\"m\":78,\"g\":94},\"bcbbf519\":{\"m\":78,\"g\":94},\"0d99adb7\":{\"m\":78,\"g\":94},\"efbae697\":{\"m\":78,\"g\":94},\"ca8d02ab\":{\"m\":78,\"g\":94},\"3f287b85\":{\"m\":78,\"g\":94},\"7ed77d6b\":{\"m\":78,\"g\":94},\"4c54f442\":{\"m\":78,\"g\":94},\"924ca7c9\":{\"m\":78,\"g\":94},\"6ff9c6a5\":{\"m\":78,\"g\":94},\"77e929a1\":{\"m\":78,\"g\":94},\"febe21ce\":{\"m\":78,\"g\":94},\"a995a773\":{\"m\":78,\"g\":94},\"31035dda\":{\"m\":78,\"g\":94},\"913e38df\":{\"m\":78,\"g\":94},\"d95269f9\":{\"m\":78,\"g\":94},\"e53bf190\":{\"m\":78,\"g\":94},\"3289c120\":{\"m\":78,\"g\":94},\"69df9761\":{\"m\":78,\"g\":94},\"98f768d1\":{\"m\":78,\"g\":94},\"d7954b76\":{\"m\":78,\"g\":94},\"74885a84\":{\"m\":78,\"g\":94},\"b8b6008f\":{\"m\":78,\"g\":94},\"8e10fec9\":{\"m\":78,\"g\":94},\"e8999b13\":{\"
m\":78,\"g\":94},\"772d2a19\":{\"m\":78,\"g\":94},\"9d0b36c4\":{\"m\":78,\"g\":94},\"7d8c0ce7\":{\"m\":78,\"g\":94},\"e41549c3\":{\"m\":78,\"g\":94},\"cccfc10e\":{\"m\":78,\"g\":94},\"a2aea59b\":{\"m\":78,\"g\":94},\"2c8fd993\":{\"m\":78,\"g\":94},\"31da75ab\":{\"m\":78,\"g\":94},\"e983e432\":{\"m\":78,\"g\":94},\"e9c6ce46\":{\"m\":78,\"g\":94},\"3fadc647\":{\"m\":78,\"g\":94},\"e119f042\":{\"m\":78,\"g\":94},\"9eb49e87\":{\"m\":78,\"g\":94},\"12047f5e\":{\"m\":78,\"g\":94},\"fda6bb78\":{\"m\":78,\"g\":94},\"23c764b1\":{\"m\":78,\"g\":94},\"87fafa01\":{\"m\":78,\"g\":94},\"1c63e797\":{\"m\":78,\"g\":94},\"ee47a6c1\":{\"m\":78,\"g\":94},\"6384d317\":{\"m\":78,\"g\":94},\"5cb552b1\":{\"m\":78,\"g\":94},\"c7457191\":{\"m\":78,\"g\":94},\"51ac297a\":{\"m\":78,\"g\":94},\"a169b9f8\":{\"m\":78,\"g\":94},\"4a63bc32\":{\"m\":78,\"g\":94},\"a303325f\":{\"m\":78,\"g\":94},\"42873eac\":{\"m\":78,\"g\":94},\"4814ecaf\":{\"m\":78,\"g\":94},\"e62d60fe\":{\"m\":78,\"g\":94},\"032f8faa\":{\"m\":78,\"g\":94},\"37c66ec8\":{\"m\":78,\"g\":94},\"9adf178c\":{\"m\":78,\"g\":94},\"f842853a\":{\"m\":78,\"g\":94},\"195a09f5\":{\"m\":78,\"g\":94},\"9fccda31\":{\"m\":78,\"g\":94},\"4ede6770\":{\"m\":78,\"g\":94},\"b26bc86b\":{\"m\":78,\"g\":94},\"5ec5eaf7\":{\"m\":78,\"g\":94},\"0d7fe866\":{\"m\":78,\"g\":94},\"54b9a2de\":{\"m\":78,\"g\":94},\"8e7b3154\":{\"m\":78,\"g\":94},\"45dcfc2e\":{\"m\":78,\"g\":94},\"ddf8981d\":{\"m\":78,\"g\":94},\"400ad660\":{\"m\":78,\"g\":94},\"05625b97\":{\"m\":78,\"g\":94},\"736502d4\":{\"m\":78,\"g\":94},\"8690c40b\":{\"m\":78,\"g\":94},\"b1cfb4e9\":{\"m\":78,\"g\":94},\"57f99608\":{\"m\":79,\"g\":94},\"81992474\":{\"m\":79,\"g\":94},\"f04c80dc\":{\"m\":79,\"g\":94},\"d1bb1711\":{\"m\":79,\"g\":94},\"5b5c7237\":{\"m\":80,\"g\":94},\"a42736bb\":{\"m\":80,\"g\":94},\"dd83e7e9\":{\"m\":80,\"g\":94},\"0769b14b\":{\"m\":80,\"g\":94},\"b64b88e7\":{\"m\":80,\"g\":94},\"bc24205b\":{\"m\":80,\"g\":94},\"3efc8e2d\":{\"m\":80,\"g\":94},\"27a009bb\":{\"m\":80,\"g\":94},\"8
ec0bb7d\":{\"m\":80,\"g\":94},\"fa909dc3\":{\"m\":80,\"g\":94},\"e8f62b20\":{\"m\":80,\"g\":94},\"88defc4d\":{\"m\":80,\"g\":94},\"6f509d55\":{\"m\":80,\"g\":94},\"12ef7e3b\":{\"m\":80,\"g\":94},\"838fa0f2\":{\"m\":80,\"g\":94},\"f1b3b75f\":{\"m\":80,\"g\":94},\"33b16ad1\":{\"m\":80,\"g\":94},\"ffde65a0\":{\"m\":80,\"g\":94},\"471650de\":{\"m\":80,\"g\":94},\"d06a83fb\":{\"m\":80,\"g\":94},\"5d134401\":{\"m\":80,\"g\":94},\"f88f7e19\":{\"m\":80,\"g\":94},\"3dfc6023\":{\"m\":80,\"g\":94},\"15e91d72\":{\"m\":80,\"g\":94},\"8aab7fdb\":{\"m\":80,\"g\":94},\"e940dc4f\":{\"m\":80,\"g\":94},\"388e15c0\":{\"m\":80,\"g\":94},\"11421a3f\":{\"m\":80,\"g\":94},\"6c41fcf0\":{\"m\":80,\"g\":94},\"ee9d6ca6\":{\"m\":80,\"g\":94},\"2dd64894\":{\"m\":80,\"g\":94},\"61e7c4dd\":{\"m\":80,\"g\":94},\"dae79444\":{\"m\":80,\"g\":94},\"f6772f14\":{\"m\":80,\"g\":94},\"ac5b78ba\":{\"m\":80,\"g\":94},\"38076dea\":{\"m\":80,\"g\":94},\"5e0a9b09\":{\"m\":80,\"g\":94},\"bdde2375\":{\"m\":80,\"g\":94},\"e9fc2ac7\":{\"m\":80,\"g\":94},\"44afde82\":{\"m\":80,\"g\":94},\"072df753\":{\"m\":80,\"g\":94},\"defede50\":{\"m\":80,\"g\":94},\"fc728719\":{\"m\":80,\"g\":94},\"14e8bd88\":{\"m\":80,\"g\":94},\"adca585b\":{\"m\":80,\"g\":94},\"39d90449\":{\"m\":80,\"g\":94},\"39e41138\":{\"m\":80,\"g\":94},\"5fbafbb8\":{\"m\":80,\"g\":94},\"a9499885\":{\"m\":80,\"g\":94},\"f7655790\":{\"m\":80,\"g\":94},\"f58b929a\":{\"m\":80,\"g\":94},\"c1270aab\":{\"m\":80,\"g\":94},\"8311b07f\":{\"m\":80,\"g\":94},\"c1380257\":{\"m\":80,\"g\":94},\"b62e7e99\":{\"m\":80,\"g\":94},\"7d3b7c87\":{\"m\":80,\"g\":94},\"75015bb6\":{\"m\":80,\"g\":94},\"b371f7cd\":{\"m\":80,\"g\":94},\"812e82f3\":{\"m\":80,\"g\":94},\"4879e50c\":{\"m\":80,\"g\":94},\"bc92107b\":{\"m\":80,\"g\":94},\"3e4794aa\":{\"m\":80,\"g\":94},\"690ec205\":{\"m\":80,\"g\":94},\"2074a2e6\":{\"m\":80,\"g\":94},\"57de7c6b\":{\"m\":80,\"g\":94},\"115ae2e7\":{\"m\":80,\"g\":94},\"aea98512\":{\"m\":80,\"g\":94},\"e4155e96\":{\"m\":80,\"g\":94},\"1b1b47a9\":{\"m\":80,
\"g\":94},\"3c9740d2\":{\"m\":80,\"g\":94},\"2eb55770\":{\"m\":80,\"g\":94},\"f65b8d5c\":{\"m\":80,\"g\":94},\"5ad05719\":{\"m\":80,\"g\":94},\"34ef6c81\":{\"m\":80,\"g\":94},\"61172091\":{\"m\":80,\"g\":94},\"4f288113\":{\"m\":80,\"g\":94},\"136b8e6a\":{\"m\":80,\"g\":94},\"034c5256\":{\"m\":80,\"g\":94},\"c1dd773c\":{\"m\":80,\"g\":94},\"6f859379\":{\"m\":80,\"g\":94},\"f774a0d2\":{\"m\":80,\"g\":94},\"60bcbf2a\":{\"m\":80,\"g\":94},\"a0a9f6d6\":{\"m\":80,\"g\":94},\"80aa8ca8\":{\"m\":80,\"g\":94},\"4aa6bab0\":{\"m\":80,\"g\":94},\"c35dcfdb\":{\"m\":80,\"g\":94},\"c163bf4f\":{\"m\":80,\"g\":94},\"55986343\":{\"m\":80,\"g\":94},\"b75275b6\":{\"m\":80,\"g\":94},\"7074e9ca\":{\"m\":80,\"g\":94},\"fc14cca0\":{\"m\":80,\"g\":94},\"e7beff8a\":{\"m\":80,\"g\":94},\"4d2e3051\":{\"m\":80,\"g\":94},\"e53a0b3d\":{\"m\":80,\"g\":94},\"038bc5d5\":{\"m\":80,\"g\":94},\"aee62d74\":{\"m\":80,\"g\":94},\"cd7e32e2\":{\"m\":80,\"g\":94},\"88799448\":{\"m\":80,\"g\":94},\"a879811c\":{\"m\":80,\"g\":94},\"a222945d\":{\"m\":80,\"g\":94},\"ed01b451\":{\"m\":80,\"g\":94},\"d050df36\":{\"m\":80,\"g\":94},\"76f44c2a\":{\"m\":80,\"g\":94},\"1078396f\":{\"m\":80,\"g\":94},\"7e4f72dd\":{\"m\":80,\"g\":94},\"4c31ae9f\":{\"m\":80,\"g\":94},\"f730362e\":{\"m\":80,\"g\":94},\"e3c4bd31\":{\"m\":80,\"g\":94},\"5db37c86\":{\"m\":80,\"g\":94},\"4cb53ecd\":{\"m\":80,\"g\":94},\"456b008b\":{\"m\":80,\"g\":94},\"ebf495f0\":{\"m\":80,\"g\":94},\"7f875f12\":{\"m\":80,\"g\":94},\"fbebcb7a\":{\"m\":80,\"g\":94},\"87eddedf\":{\"m\":80,\"g\":94},\"40652482\":{\"m\":80,\"g\":94},\"86a876d8\":{\"m\":80,\"g\":94},\"92823069\":{\"m\":80,\"g\":94},\"d2e507df\":{\"m\":80,\"g\":94},\"61970b08\":{\"m\":80,\"g\":94},\"76c48a09\":{\"m\":80,\"g\":94},\"90caf06c\":{\"m\":80,\"g\":94},\"6669d127\":{\"m\":80,\"g\":94},\"f2b70afd\":{\"m\":80,\"g\":94},\"bc3f6db2\":{\"m\":80,\"g\":94},\"aac531c5\":{\"m\":80,\"g\":94},\"39efad4f\":{\"m\":80,\"g\":94},\"466899e6\":{\"m\":80,\"g\":94},\"11d760d5\":{\"m\":80,\"g\":94},\"5039d547
\":{\"m\":80,\"g\":94},\"d09a51f1\":{\"m\":80,\"g\":94},\"f8194b26\":{\"m\":80,\"g\":94},\"6d3b35fa\":{\"m\":80,\"g\":94},\"a73c4df4\":{\"m\":80,\"g\":94},\"89a55418\":{\"m\":80,\"g\":94},\"2695ab05\":{\"m\":80,\"g\":94},\"88d6fd9a\":{\"m\":80,\"g\":94},\"cc88d98a\":{\"m\":80,\"g\":94},\"3033c11a\":{\"m\":80,\"g\":94},\"fd5a55cf\":{\"m\":80,\"g\":94},\"804d9f2e\":{\"m\":80,\"g\":94},\"a7c3f74b\":{\"m\":80,\"g\":94},\"5a144a8a\":{\"m\":80,\"g\":94},\"27f8e6b9\":{\"m\":80,\"g\":94},\"afb752bc\":{\"m\":80,\"g\":94},\"9731eca7\":{\"m\":80,\"g\":94},\"7c5658c1\":{\"m\":80,\"g\":94},\"9798e72b\":{\"m\":80,\"g\":94},\"ade714a6\":{\"m\":80,\"g\":94},\"93470a14\":{\"m\":80,\"g\":94},\"db452760\":{\"m\":80,\"g\":94},\"fbdc94ba\":{\"m\":81,\"g\":94},\"b54b5a96\":{\"m\":81,\"g\":94},\"bca832c7\":{\"m\":81,\"g\":94},\"d9dd5298\":{\"m\":81,\"g\":94},\"0a0dd34e\":{\"m\":81,\"g\":94},\"80ac527d\":{\"m\":81,\"g\":94},\"99456bca\":{\"m\":81,\"g\":94},\"d07e797a\":{\"m\":81,\"g\":94},\"c555d794\":{\"m\":81,\"g\":94},\"e2574ee9\":{\"m\":81,\"g\":94},\"ab4b5606\":{\"m\":81,\"g\":94},\"20f1c8e3\":{\"m\":81,\"g\":94},\"613b197e\":{\"m\":81,\"g\":94},\"d58e3544\":{\"m\":81,\"g\":94},\"bf86c5e9\":{\"m\":81,\"g\":94},\"dca90f1d\":{\"m\":81,\"g\":94},\"0961feef\":{\"m\":81,\"g\":94},\"59dd090f\":{\"m\":81,\"g\":94},\"569b032c\":{\"m\":81,\"g\":94},\"f6a71139\":{\"m\":81,\"g\":94},\"1e0806f3\":{\"m\":81,\"g\":94},\"2c11f9c2\":{\"m\":81,\"g\":94},\"a6f892e5\":{\"m\":81,\"g\":94},\"08b518d5\":{\"m\":81,\"g\":94},\"4db463b1\":{\"m\":81,\"g\":94},\"bfa39224\":{\"m\":81,\"g\":94},\"e465b08d\":{\"m\":81,\"g\":94},\"bed05878\":{\"m\":81,\"g\":94},\"b2a189dd\":{\"m\":81,\"g\":94},\"f28d8299\":{\"m\":81,\"g\":94},\"8e09b370\":{\"m\":81,\"g\":94},\"53dcf388\":{\"m\":81,\"g\":94},\"1effba4c\":{\"m\":81,\"g\":94},\"a0fc5bc1\":{\"m\":81,\"g\":94},\"27e9538a\":{\"m\":81,\"g\":94},\"211c7b31\":{\"m\":81,\"g\":94},\"c08a717c\":{\"m\":81,\"g\":94},\"f13d65a7\":{\"m\":81,\"g\":94},\"06d0a3d9\":{\"m\":81,\"g\":9
4},\"22c2a79d\":{\"m\":81,\"g\":94},\"8beb356f\":{\"m\":81,\"g\":94},\"c776234b\":{\"m\":81,\"g\":94},\"3bface15\":{\"m\":81,\"g\":94},\"6fb29ffd\":{\"m\":81,\"g\":94},\"4fb05583\":{\"m\":81,\"g\":94},\"81c89111\":{\"m\":81,\"g\":94},\"92d1561b\":{\"m\":81,\"g\":94},\"8f783c19\":{\"m\":81,\"g\":94},\"90faf901\":{\"m\":81,\"g\":94},\"177320a5\":{\"m\":81,\"g\":94},\"d7bc19a4\":{\"m\":81,\"g\":94},\"85ec0440\":{\"m\":81,\"g\":94},\"06a1656e\":{\"m\":81,\"g\":94},\"6aca5834\":{\"m\":81,\"g\":94},\"b9c87e78\":{\"m\":82,\"g\":94},\"968ef515\":{\"m\":82,\"g\":94},\"13432002\":{\"m\":82,\"g\":94},\"c2942907\":{\"m\":82,\"g\":94},\"e69a2190\":{\"m\":82,\"g\":94},\"bf98d2e3\":{\"m\":82,\"g\":94},\"e65b9f21\":{\"m\":82,\"g\":94},\"4dce1cc6\":{\"m\":82,\"g\":94},\"deded17f\":{\"m\":82,\"g\":94},\"f29a718f\":{\"m\":82,\"g\":94},\"3f57b00a\":{\"m\":82,\"g\":94},\"453d412c\":{\"m\":82,\"g\":94},\"dc86f25a\":{\"m\":82,\"g\":94},\"08289eaa\":{\"m\":82,\"g\":94},\"3b6d539f\":{\"m\":82,\"g\":94},\"57131dd9\":{\"m\":82,\"g\":94},\"a7591ecf\":{\"m\":82,\"g\":94},\"c44f2869\":{\"m\":82,\"g\":94},\"685d8980\":{\"m\":82,\"g\":94},\"70645f4d\":{\"m\":82,\"g\":94},\"188f0955\":{\"m\":82,\"g\":94},\"eef9433b\":{\"m\":82,\"g\":94},\"97cb762b\":{\"m\":82,\"g\":94},\"11951820\":{\"m\":82,\"g\":94},\"5239d795\":{\"m\":82,\"g\":94},\"f0815419\":{\"m\":82,\"g\":94},\"2b3bdc93\":{\"m\":82,\"g\":94},\"5fc4b600\":{\"m\":82,\"g\":94},\"b868526d\":{\"m\":82,\"g\":94},\"502524e2\":{\"m\":82,\"g\":94},\"4c764007\":{\"m\":82,\"g\":94},\"9f3bd2ad\":{\"m\":82,\"g\":94},\"8de53da9\":{\"m\":82,\"g\":94},\"fac17acf\":{\"m\":82,\"g\":94},\"8b39274e\":{\"m\":82,\"g\":94},\"5156d5a4\":{\"m\":82,\"g\":94},\"c951d312\":{\"m\":82,\"g\":94},\"dcb82325\":{\"m\":82,\"g\":94},\"66c0ff9e\":{\"m\":82,\"g\":94},\"9a7e83e8\":{\"m\":82,\"g\":94},\"417b44eb\":{\"m\":82,\"g\":94},\"475e2e37\":{\"m\":82,\"g\":94},\"fba86b6b\":{\"m\":82,\"g\":94},\"072b4d03\":{\"m\":82,\"g\":94},\"9c434777\":{\"m\":82,\"g\":94},\"fa2f677e\":{\"m
\":82,\"g\":94},\"463d4b74\":{\"m\":82,\"g\":94},\"9924bbe1\":{\"m\":82,\"g\":94},\"84022c0e\":{\"m\":83,\"g\":94},\"f9fb33ef\":{\"m\":83,\"g\":94},\"a38f6932\":{\"m\":83,\"g\":94},\"beb65c74\":{\"m\":83,\"g\":94},\"621e96bf\":{\"m\":83,\"g\":94},\"35ca04d2\":{\"m\":83,\"g\":94},\"3c4e0ee6\":{\"m\":83,\"g\":94},\"9c088829\":{\"m\":83,\"g\":94},\"005aad32\":{\"m\":83,\"g\":94},\"4d23ba08\":{\"m\":83,\"g\":94},\"6e313c1b\":{\"m\":83,\"g\":94},\"a45a4b23\":{\"m\":83,\"g\":94},\"981a2619\":{\"m\":83,\"g\":94},\"8ba31330\":{\"m\":83,\"g\":94},\"02102063\":{\"m\":83,\"g\":94},\"7e944246\":{\"m\":83,\"g\":94},\"a086a113\":{\"m\":83,\"g\":94},\"bdbe5f81\":{\"m\":83,\"g\":94},\"9ad28f63\":{\"m\":83,\"g\":94},\"d7b1ce65\":{\"m\":83,\"g\":94},\"f55933e1\":{\"m\":83,\"g\":94},\"408ba022\":{\"m\":83,\"g\":94},\"094891c0\":{\"m\":83,\"g\":94},\"a21ef363\":{\"m\":83,\"g\":94},\"3c4dc38a\":{\"m\":83,\"g\":94},\"d8fbc7c0\":{\"m\":83,\"g\":94},\"c5e1026f\":{\"m\":83,\"g\":94},\"799c4bb5\":{\"m\":83,\"g\":94},\"02723e1b\":{\"m\":83,\"g\":94},\"df2cf583\":{\"m\":83,\"g\":94},\"133ded03\":{\"m\":83,\"g\":94},\"f87a6ab3\":{\"m\":83,\"g\":94},\"eebfdb94\":{\"m\":83,\"g\":94},\"dfb32264\":{\"m\":83,\"g\":94},\"63c13a2c\":{\"m\":83,\"g\":94},\"4d1e52ab\":{\"m\":83,\"g\":94},\"155890e4\":{\"m\":83,\"g\":94},\"1f963d7f\":{\"m\":83,\"g\":94},\"04d0123f\":{\"m\":83,\"g\":94},\"feda9b11\":{\"m\":83,\"g\":94},\"c3948ba6\":{\"m\":83,\"g\":94},\"269c457e\":{\"m\":83,\"g\":94},\"18ce468d\":{\"m\":83,\"g\":94},\"21514ff5\":{\"m\":83,\"g\":94},\"5641a094\":{\"m\":83,\"g\":94},\"3dd3538c\":{\"m\":83,\"g\":94},\"93c6fb12\":{\"m\":83,\"g\":94},\"11e27d09\":{\"m\":83,\"g\":94},\"50eda839\":{\"m\":83,\"g\":94},\"c55550cb\":{\"m\":83,\"g\":94},\"43fb95c2\":{\"m\":83,\"g\":94},\"7d9679b7\":{\"m\":83,\"g\":94},\"b5be5694\":{\"m\":83,\"g\":94},\"d2b8d0b8\":{\"m\":83,\"g\":94},\"a14654dd\":{\"m\":83,\"g\":94},\"5d93a950\":{\"m\":83,\"g\":94},\"c998d04b\":{\"m\":83,\"g\":94},\"7d0edf3c\":{\"m\":83,\"g\":94},\"ce
4ecba4\":{\"m\":83,\"g\":94},\"b1f6d89b\":{\"m\":83,\"g\":94},\"7c99103f\":{\"m\":83,\"g\":94},\"de071366\":{\"m\":83,\"g\":94},\"e0673969\":{\"m\":83,\"g\":94},\"127ff898\":{\"m\":83,\"g\":94},\"8777a1d2\":{\"m\":83,\"g\":94},\"711efe78\":{\"m\":83,\"g\":94},\"fbb5f229\":{\"m\":83,\"g\":94},\"15fabcc0\":{\"m\":83,\"g\":94},\"e62c4955\":{\"m\":83,\"g\":94},\"71d1785f\":{\"m\":83,\"g\":94},\"3f87f831\":{\"m\":83,\"g\":94},\"ce5412b6\":{\"m\":83,\"g\":94},\"7282ab74\":{\"m\":83,\"g\":94},\"b0feda09\":{\"m\":83,\"g\":94},\"6b6e7487\":{\"m\":83,\"g\":94},\"91732486\":{\"m\":83,\"g\":94},\"2ed96c7a\":{\"m\":83,\"g\":94},\"2aa3f5e2\":{\"m\":83,\"g\":94},\"76d17c7e\":{\"m\":83,\"g\":94},\"70d040f9\":{\"m\":83,\"g\":94},\"4418f599\":{\"m\":83,\"g\":94},\"04f2abcb\":{\"m\":83,\"g\":94},\"506be6b8\":{\"m\":83,\"g\":94},\"2343d8df\":{\"m\":83,\"g\":94},\"92bb64bc\":{\"m\":83,\"g\":94},\"11b23ae9\":{\"m\":83,\"g\":94},\"dcae1fb2\":{\"m\":84,\"g\":94},\"a0251a3f\":{\"m\":84,\"g\":94},\"663037a7\":{\"m\":84,\"g\":94},\"f4a9f60c\":{\"m\":84,\"g\":94},\"ee71ed8a\":{\"m\":84,\"g\":94},\"d364b9b0\":{\"m\":84,\"g\":94},\"849c83a0\":{\"m\":84,\"g\":94},\"d73ddeb1\":{\"m\":84,\"g\":94},\"f48b007c\":{\"m\":84,\"g\":94},\"74cb12a8\":{\"m\":84,\"g\":94},\"c6c62640\":{\"m\":84,\"g\":94},\"92ab0a20\":{\"m\":84,\"g\":94},\"e132cba2\":{\"m\":84,\"g\":94},\"0045f4b2\":{\"m\":84,\"g\":94},\"8601300b\":{\"m\":84,\"g\":94},\"6fa6f38e\":{\"m\":84,\"g\":94},\"693723d1\":{\"m\":84,\"g\":94},\"966eb908\":{\"m\":84,\"g\":94},\"644ed409\":{\"m\":84,\"g\":94},\"3029889c\":{\"m\":84,\"g\":94},\"ef15dcda\":{\"m\":84,\"g\":94},\"ad4df307\":{\"m\":84,\"g\":94},\"41ac0c6d\":{\"m\":84,\"g\":94},\"84810da4\":{\"m\":84,\"g\":94},\"40d9b8ac\":{\"m\":84,\"g\":94},\"f0365820\":{\"m\":84,\"g\":94},\"86317c09\":{\"m\":84,\"g\":94},\"daed453e\":{\"m\":84,\"g\":94},\"ded04b2e\":{\"m\":84,\"g\":94},\"9858113c\":{\"m\":85,\"g\":94},\"8441baad\":{\"m\":85,\"g\":94},\"256c4c25\":{\"m\":85,\"g\":94},\"9f21e754\":{\"m\":85,\
"g\":94},\"7bcd8b1c\":{\"m\":85,\"g\":94},\"11383cec\":{\"m\":85,\"g\":94},\"e97e57e6\":{\"m\":85,\"g\":94},\"9a6ad891\":{\"m\":85,\"g\":94},\"d353d08b\":{\"m\":85,\"g\":94},\"08acdb5c\":{\"m\":85,\"g\":94},\"2afba1b1\":{\"m\":85,\"g\":94},\"e330f2b8\":{\"m\":85,\"g\":94},\"3ddf5b9d\":{\"m\":85,\"g\":94},\"3cff9633\":{\"m\":85,\"g\":94},\"d50e36a7\":{\"m\":85,\"g\":94},\"8fefdd32\":{\"m\":85,\"g\":94},\"403b855a\":{\"m\":85,\"g\":94},\"1698e94e\":{\"m\":85,\"g\":94},\"58195dd5\":{\"m\":85,\"g\":94},\"799789af\":{\"m\":85,\"g\":94},\"cc4a80ca\":{\"m\":85,\"g\":94},\"3c8a5231\":{\"m\":85,\"g\":94},\"a043f7f2\":{\"m\":85,\"g\":94},\"e3a53044\":{\"m\":85,\"g\":94},\"28b26dbf\":{\"m\":85,\"g\":94},\"2b06484b\":{\"m\":85,\"g\":94},\"e4b6133b\":{\"m\":85,\"g\":94},\"dd408ee4\":{\"m\":85,\"g\":94},\"9419e75d\":{\"m\":85,\"g\":94},\"2c7dbb7c\":{\"m\":85,\"g\":94},\"9a62191b\":{\"m\":85,\"g\":94},\"ae523675\":{\"m\":85,\"g\":94},\"5c08aa49\":{\"m\":85,\"g\":94},\"f4c191a7\":{\"m\":85,\"g\":94},\"771669cb\":{\"m\":85,\"g\":94},\"1468769b\":{\"m\":85,\"g\":94},\"91dda4cd\":{\"m\":85,\"g\":94},\"8e5a6d34\":{\"m\":85,\"g\":94},\"8465f035\":{\"m\":85,\"g\":94},\"8c0cfca8\":{\"m\":85,\"g\":94},\"2c3ea294\":{\"m\":85,\"g\":94},\"5bb0accb\":{\"m\":85,\"g\":94},\"8d463fe3\":{\"m\":85,\"g\":94},\"26fc32d1\":{\"m\":85,\"g\":94},\"1cc32603\":{\"m\":85,\"g\":94},\"05ee2192\":{\"m\":85,\"g\":94},\"678d8cc9\":{\"m\":86,\"g\":94},\"d2cb3024\":{\"m\":86,\"g\":94},\"1940cdec\":{\"m\":86,\"g\":94},\"63484f9f\":{\"m\":86,\"g\":94},\"dff0ab92\":{\"m\":86,\"g\":94},\"e30c273b\":{\"m\":86,\"g\":94},\"0ab3f437\":{\"m\":86,\"g\":94},\"cec98f10\":{\"m\":86,\"g\":94},\"8dc4efd0\":{\"m\":86,\"g\":94},\"6578cf27\":{\"m\":86,\"g\":94},\"087751a8\":{\"m\":86,\"g\":94},\"911f3ba6\":{\"m\":86,\"g\":94},\"f6f96b05\":{\"m\":86,\"g\":94},\"2a936a84\":{\"m\":86,\"g\":94},\"5e023301\":{\"m\":86,\"g\":94},\"fa7d7fd9\":{\"m\":86,\"g\":94},\"f1ff736d\":{\"m\":86,\"g\":94},\"acc816d8\":{\"m\":86,\"g\":94},\"a05bd83a\
":{\"m\":86,\"g\":94},\"cef91b1e\":{\"m\":86,\"g\":94},\"6450c122\":{\"m\":86,\"g\":94},\"b6cf3532\":{\"m\":86,\"g\":94},\"3b2680a4\":{\"m\":86,\"g\":94},\"79961afa\":{\"m\":86,\"g\":94},\"cfca4e0e\":{\"m\":86,\"g\":94},\"e88dd482\":{\"m\":86,\"g\":94},\"73600673\":{\"m\":86,\"g\":94},\"8f508cc7\":{\"m\":86,\"g\":94},\"9bddf1c8\":{\"m\":86,\"g\":94},\"24c13ca9\":{\"m\":86,\"g\":94},\"b70957fc\":{\"m\":86,\"g\":94},\"e444c13f\":{\"m\":86,\"g\":94},\"fee37d9e\":{\"m\":86,\"g\":94},\"c68de479\":{\"m\":86,\"g\":94},\"4c7b4242\":{\"m\":86,\"g\":94},\"38053c33\":{\"m\":86,\"g\":94},\"00c2c1f0\":{\"m\":86,\"g\":94},\"cb691945\":{\"m\":86,\"g\":94},\"d25398cb\":{\"m\":86,\"g\":94},\"8a828666\":{\"m\":86,\"g\":94},\"aff584fa\":{\"m\":86,\"g\":94},\"6f566147\":{\"m\":86,\"g\":94},\"bdd17998\":{\"m\":86,\"g\":94},\"c9abd7be\":{\"m\":86,\"g\":94},\"a3e4e9bf\":{\"m\":86,\"g\":94},\"6d4d3bc8\":{\"m\":86,\"g\":94},\"5f300141\":{\"m\":86,\"g\":94},\"1c05425b\":{\"m\":86,\"g\":94},\"b26cb1c5\":{\"m\":86,\"g\":94},\"f8e46093\":{\"m\":86,\"g\":94},\"683707c3\":{\"m\":86,\"g\":94},\"a68ed766\":{\"m\":86,\"g\":94},\"82653f66\":{\"m\":86,\"g\":94},\"22da3d97\":{\"m\":86,\"g\":94},\"b8559764\":{\"m\":86,\"g\":94},\"56f6589e\":{\"m\":86,\"g\":94},\"1232f7e8\":{\"m\":86,\"g\":94},\"3008db9c\":{\"m\":86,\"g\":94},\"357fb2db\":{\"m\":86,\"g\":94},\"95c231e5\":{\"m\":86,\"g\":94},\"3042f1da\":{\"m\":86,\"g\":94},\"2b63798c\":{\"m\":86,\"g\":94},\"bf203cb7\":{\"m\":86,\"g\":94},\"8ebde73f\":{\"m\":86,\"g\":94},\"6b0fae79\":{\"m\":86,\"g\":94},\"141a4596\":{\"m\":86,\"g\":94},\"d8ab6011\":{\"m\":86,\"g\":94},\"6579cd7d\":{\"m\":86,\"g\":94},\"97ac42b6\":{\"m\":86,\"g\":94},\"1acca3a2\":{\"m\":86,\"g\":94},\"6ea1e6ac\":{\"m\":86,\"g\":94},\"3409aaab\":{\"m\":86,\"g\":94},\"73dcf2b3\":{\"m\":86,\"g\":94},\"170d1f21\":{\"m\":86,\"g\":94},\"73bc1d00\":{\"m\":86,\"g\":94},\"c5645e92\":{\"m\":86,\"g\":94},\"d33955d2\":{\"m\":86,\"g\":94},\"6fc17596\":{\"m\":86,\"g\":94},\"ad506a4e\":{\"m\":86,\"g\":94
},\"ebaba856\":{\"m\":86,\"g\":94},\"de2faef9\":{\"m\":86,\"g\":94},\"67b7d5b1\":{\"m\":86,\"g\":94},\"4322c31e\":{\"m\":86,\"g\":94},\"16267d4f\":{\"m\":87,\"g\":94},\"0f5cb8ca\":{\"m\":87,\"g\":94},\"17299f08\":{\"m\":87,\"g\":94},\"5380cd7e\":{\"m\":87,\"g\":94},\"b2e95f62\":{\"m\":87,\"g\":94},\"1ab14c4c\":{\"m\":87,\"g\":94},\"3c32895c\":{\"m\":87,\"g\":94},\"ac2324c1\":{\"m\":87,\"g\":94},\"ef8ec07b\":{\"m\":87,\"g\":94},\"f24fc5b8\":{\"m\":87,\"g\":94},\"d18c6b33\":{\"m\":87,\"g\":94},\"f1c89600\":{\"m\":87,\"g\":94},\"983c663d\":{\"m\":87,\"g\":94},\"f94543d2\":{\"m\":87,\"g\":94},\"e8e18dcd\":{\"m\":87,\"g\":94},\"bad7c26f\":{\"m\":87,\"g\":94},\"12319a67\":{\"m\":87,\"g\":94},\"d738ab52\":{\"m\":87,\"g\":94},\"3ee40ff9\":{\"m\":87,\"g\":94},\"0f334945\":{\"m\":87,\"g\":94},\"fba8eccd\":{\"m\":87,\"g\":94},\"7d3a3d45\":{\"m\":87,\"g\":94},\"25c83fff\":{\"m\":87,\"g\":94},\"9f2c9568\":{\"m\":87,\"g\":94},\"3f2702ae\":{\"m\":87,\"g\":94},\"6ea05950\":{\"m\":87,\"g\":94},\"e7dd906c\":{\"m\":87,\"g\":94},\"6e2da515\":{\"m\":87,\"g\":94},\"e9a47f4c\":{\"m\":87,\"g\":94},\"03227c5f\":{\"m\":87,\"g\":94},\"01bdbf7f\":{\"m\":87,\"g\":94},\"94d42b67\":{\"m\":87,\"g\":94},\"69276f61\":{\"m\":87,\"g\":94},\"41a645f5\":{\"m\":87,\"g\":94},\"23010630\":{\"m\":87,\"g\":94},\"45b4dcf0\":{\"m\":87,\"g\":94},\"213e8c7d\":{\"m\":87,\"g\":94},\"41273fd7\":{\"m\":87,\"g\":94},\"e9bebafb\":{\"m\":87,\"g\":94},\"4d1c9db6\":{\"m\":87,\"g\":94},\"17c36c55\":{\"m\":87,\"g\":94},\"31d1f6e7\":{\"m\":87,\"g\":94},\"a823c6e8\":{\"m\":87,\"g\":94},\"2ce87935\":{\"m\":87,\"g\":94},\"de167cf5\":{\"m\":87,\"g\":94},\"4319978c\":{\"m\":87,\"g\":94},\"03dd785c\":{\"m\":87,\"g\":94},\"66fc63d6\":{\"m\":87,\"g\":94},\"921e4a81\":{\"m\":87,\"g\":94},\"9d8ec2e6\":{\"m\":87,\"g\":94},\"c178abda\":{\"m\":87,\"g\":94},\"b29a026e\":{\"m\":87,\"g\":94},\"7e257cd6\":{\"m\":88,\"g\":94},\"c4831e2f\":{\"m\":88,\"g\":94},\"2e37fa07\":{\"m\":88,\"g\":94},\"2d831c6e\":{\"m\":88,\"g\":94},\"ed0c3035\":{\"m\
":88,\"g\":94},\"e6f11356\":{\"m\":88,\"g\":94},\"7b02c326\":{\"m\":88,\"g\":94},\"fefa19fe\":{\"m\":88,\"g\":94},\"9c574585\":{\"m\":88,\"g\":94},\"8233cc10\":{\"m\":88,\"g\":94},\"1b2e8f76\":{\"m\":88,\"g\":94},\"d2e0881a\":{\"m\":88,\"g\":94},\"2f427491\":{\"m\":88,\"g\":94},\"d8189660\":{\"m\":88,\"g\":94},\"3ded6235\":{\"m\":88,\"g\":94},\"4ba1eea8\":{\"m\":88,\"g\":94},\"4685fbb8\":{\"m\":88,\"g\":94},\"0a4fc73b\":{\"m\":88,\"g\":94},\"a6970a17\":{\"m\":88,\"g\":94},\"a6ae3af1\":{\"m\":88,\"g\":94},\"0b07c4a9\":{\"m\":88,\"g\":94},\"fc0e3b91\":{\"m\":88,\"g\":94},\"d71f3f0a\":{\"m\":88,\"g\":94},\"58f10679\":{\"m\":88,\"g\":94},\"7a80f565\":{\"m\":88,\"g\":94},\"9484eba4\":{\"m\":88,\"g\":94},\"e9feb488\":{\"m\":88,\"g\":94},\"fc992a09\":{\"m\":88,\"g\":94},\"121f92c5\":{\"m\":88,\"g\":94},\"3bde1010\":{\"m\":88,\"g\":94},\"75135580\":{\"m\":88,\"g\":94},\"4d643f6c\":{\"m\":88,\"g\":94},\"6ce0ed07\":{\"m\":88,\"g\":94},\"969660c7\":{\"m\":88,\"g\":94},\"16d4f680\":{\"m\":88,\"g\":94},\"ada268fd\":{\"m\":88,\"g\":94},\"cfe48c59\":{\"m\":88,\"g\":94},\"d4c038da\":{\"m\":88,\"g\":94},\"55f6005f\":{\"m\":88,\"g\":94},\"7222e1da\":{\"m\":88,\"g\":94},\"505eec4d\":{\"m\":88,\"g\":94},\"ccfe5c00\":{\"m\":88,\"g\":94},\"a071dc40\":{\"m\":88,\"g\":94},\"a40aecc5\":{\"m\":88,\"g\":94},\"d6e1d28c\":{\"m\":88,\"g\":94},\"7c347259\":{\"m\":88,\"g\":94},\"669caa0a\":{\"m\":88,\"g\":94},\"4024e1d2\":{\"m\":88,\"g\":94},\"5c0b38f3\":{\"m\":88,\"g\":94},\"30ca18f4\":{\"m\":88,\"g\":94},\"03886917\":{\"m\":88,\"g\":94},\"66324895\":{\"m\":88,\"g\":94},\"13feffd0\":{\"m\":88,\"g\":94},\"e98afbe0\":{\"m\":88,\"g\":94},\"69af3ec3\":{\"m\":88,\"g\":94},\"32cc66ef\":{\"m\":88,\"g\":94},\"83f2d9d4\":{\"m\":88,\"g\":94},\"6317c5c6\":{\"m\":88,\"g\":94},\"cba1cdbc\":{\"m\":88,\"g\":94},\"c471d39e\":{\"m\":88,\"g\":94},\"d0443275\":{\"m\":88,\"g\":94},\"17d080b7\":{\"m\":88,\"g\":94},\"1b19df4b\":{\"m\":88,\"g\":94},\"f0653886\":{\"m\":88,\"g\":94},\"b1465557\":{\"m\":88,\"g\":94},\"b06
215da\":{\"m\":88,\"g\":94},\"7adf245b\":{\"m\":88,\"g\":94},\"299fd22f\":{\"m\":88,\"g\":94},\"506e5de8\":{\"m\":88,\"g\":94},\"844e2f22\":{\"m\":88,\"g\":94},\"4f39bcf7\":{\"m\":88,\"g\":94},\"31c9569b\":{\"m\":88,\"g\":94},\"1be6956d\":{\"m\":88,\"g\":94},\"626ccb7d\":{\"m\":88,\"g\":94},\"72bfb0ba\":{\"m\":88,\"g\":94},\"15521495\":{\"m\":88,\"g\":94},\"ebe58d54\":{\"m\":88,\"g\":94},\"066cf445\":{\"m\":88,\"g\":94},\"6dc6b306\":{\"m\":88,\"g\":94},\"1f30c05d\":{\"m\":88,\"g\":94},\"5dd62c3a\":{\"m\":88,\"g\":94},\"f11481b9\":{\"m\":88,\"g\":94},\"9d24c3ff\":{\"m\":88,\"g\":94},\"24161c59\":{\"m\":88,\"g\":94},\"eabcf82a\":{\"m\":88,\"g\":94},\"c47a51db\":{\"m\":88,\"g\":94},\"11553c1a\":{\"m\":88,\"g\":94},\"01dd39ba\":{\"m\":88,\"g\":94},\"b3f3d610\":{\"m\":88,\"g\":94},\"f07c6a00\":{\"m\":88,\"g\":94},\"4bb816d4\":{\"m\":88,\"g\":94},\"c250939e\":{\"m\":88,\"g\":94},\"b6909aa2\":{\"m\":88,\"g\":94},\"f8728357\":{\"m\":88,\"g\":94},\"73187152\":{\"m\":88,\"g\":94},\"40865665\":{\"m\":88,\"g\":94},\"fd08c048\":{\"m\":88,\"g\":94},\"26ebb849\":{\"m\":88,\"g\":94},\"02973cd9\":{\"m\":88,\"g\":94},\"6d95a35a\":{\"m\":88,\"g\":94},\"01d2838c\":{\"m\":88,\"g\":94},\"e3b8a722\":{\"m\":88,\"g\":94},\"3cf1473a\":{\"m\":88,\"g\":94},\"27168308\":{\"m\":88,\"g\":94},\"e3bed74a\":{\"m\":88,\"g\":94},\"e9ef39d2\":{\"m\":88,\"g\":94},\"205d5cb4\":{\"m\":88,\"g\":94},\"3d7f7a43\":{\"m\":88,\"g\":94},\"2df9d40a\":{\"m\":88,\"g\":94},\"8dc191f2\":{\"m\":88,\"g\":94},\"64825b83\":{\"m\":88,\"g\":94},\"69748d08\":{\"m\":88,\"g\":94},\"dcc0a456\":{\"m\":88,\"g\":94},\"c2b7ddca\":{\"m\":88,\"g\":94},\"abebd939\":{\"m\":88,\"g\":94},\"4bd2952a\":{\"m\":88,\"g\":94},\"6fc93575\":{\"m\":88,\"g\":94},\"839fb31e\":{\"m\":88,\"g\":94},\"f19a9204\":{\"m\":88,\"g\":94},\"c23a7072\":{\"m\":88,\"g\":94},\"e07a6977\":{\"m\":88,\"g\":94},\"cd8d4b9d\":{\"m\":88,\"g\":94},\"f194e14f\":{\"m\":88,\"g\":94},\"cfc9f9ab\":{\"m\":88,\"g\":94},\"fb4959b2\":{\"m\":88,\"g\":94},\"9a405274\":{\"m\":88,\"
g\":94},\"2e4babdb\":{\"m\":88,\"g\":94},\"44a3783d\":{\"m\":88,\"g\":94},\"f3bf6110\":{\"m\":88,\"g\":94},\"198b9056\":{\"m\":88,\"g\":94},\"73eb67c0\":{\"m\":88,\"g\":94},\"9a91fa0e\":{\"m\":88,\"g\":94},\"cd7c8a8d\":{\"m\":88,\"g\":94},\"3e350a93\":{\"m\":88,\"g\":94},\"fb71725c\":{\"m\":88,\"g\":94},\"912788c0\":{\"m\":88,\"g\":94},\"0f75b907\":{\"m\":88,\"g\":94},\"4f723edd\":{\"m\":89,\"g\":94},\"fcde67b0\":{\"m\":89,\"g\":94},\"81372f3b\":{\"m\":89,\"g\":94},\"baa6624d\":{\"m\":89,\"g\":94},\"c2b16795\":{\"m\":89,\"g\":94},\"f6ebba53\":{\"m\":89,\"g\":94},\"6716b417\":{\"m\":89,\"g\":94},\"1c8b42c8\":{\"m\":89,\"g\":94},\"f20f7000\":{\"m\":89,\"g\":94},\"f40942ad\":{\"m\":89,\"g\":94},\"dc0705a5\":{\"m\":89,\"g\":94},\"a968c888\":{\"m\":89,\"g\":94},\"a979daac\":{\"m\":89,\"g\":94},\"f1569876\":{\"m\":89,\"g\":94},\"3465d7ae\":{\"m\":89,\"g\":94},\"e58423b2\":{\"m\":89,\"g\":94},\"7059ae16\":{\"m\":89,\"g\":94},\"51d9a597\":{\"m\":89,\"g\":94},\"56ccd3c2\":{\"m\":89,\"g\":94},\"98c00a2d\":{\"m\":89,\"g\":94},\"451ffe74\":{\"m\":89,\"g\":94},\"b1e5a33a\":{\"m\":89,\"g\":94},\"9d5fa68b\":{\"m\":89,\"g\":94},\"2c186425\":{\"m\":89,\"g\":94},\"18efb5e8\":{\"m\":89,\"g\":94},\"de1350ea\":{\"m\":89,\"g\":94},\"86fe943b\":{\"m\":89,\"g\":94},\"9ecb1856\":{\"m\":89,\"g\":94},\"cc74499d\":{\"m\":89,\"g\":94},\"0c1f03a2\":{\"m\":89,\"g\":94},\"3712abfa\":{\"m\":89,\"g\":94},\"971a0dfa\":{\"m\":89,\"g\":94},\"2fc12995\":{\"m\":89,\"g\":94},\"20d3ad3b\":{\"m\":89,\"g\":94},\"fa3592cf\":{\"m\":89,\"g\":94},\"608668e1\":{\"m\":89,\"g\":94},\"6c0a4828\":{\"m\":89,\"g\":94},\"47402883\":{\"m\":89,\"g\":94},\"1fb76ebb\":{\"m\":89,\"g\":94},\"c2c4f57f\":{\"m\":89,\"g\":94},\"23881fa6\":{\"m\":89,\"g\":94},\"8db3ac55\":{\"m\":89,\"g\":94},\"3e56f557\":{\"m\":89,\"g\":94},\"62fec60d\":{\"m\":89,\"g\":94},\"e7759778\":{\"m\":89,\"g\":94},\"77e928d0\":{\"m\":89,\"g\":94},\"515ef4fa\":{\"m\":89,\"g\":94},\"f5599ef1\":{\"m\":89,\"g\":94},\"c499591a\":{\"m\":89,\"g\":94},\"e1ce44cd\"
:{\"m\":89,\"g\":94},\"f1114e7f\":{\"m\":89,\"g\":94},\"bae4fdc7\":{\"m\":89,\"g\":94},\"6153f2ff\":{\"m\":89,\"g\":94},\"8b5f83ed\":{\"m\":89,\"g\":94},\"2a413829\":{\"m\":89,\"g\":94},\"d5c097a2\":{\"m\":89,\"g\":94},\"9736cd3b\":{\"m\":89,\"g\":94},\"2f715f51\":{\"m\":89,\"g\":94},\"d664ca18\":{\"m\":89,\"g\":94},\"22fe7878\":{\"m\":89,\"g\":94},\"c4ffbeca\":{\"m\":89,\"g\":94},\"f8eaaab8\":{\"m\":89,\"g\":94},\"697b0f71\":{\"m\":89,\"g\":94},\"132dad87\":{\"m\":89,\"g\":94},\"60fdad7c\":{\"m\":89,\"g\":94},\"61ce91ed\":{\"m\":89,\"g\":94},\"e6b7053b\":{\"m\":89,\"g\":94},\"5f91c825\":{\"m\":89,\"g\":94},\"b819381f\":{\"m\":89,\"g\":94},\"562f279a\":{\"m\":89,\"g\":94},\"8b247489\":{\"m\":89,\"g\":94},\"0df6765c\":{\"m\":89,\"g\":94},\"35b65cf0\":{\"m\":89,\"g\":94},\"dd1012fc\":{\"m\":89,\"g\":94},\"44aab7f9\":{\"m\":89,\"g\":94},\"43baba64\":{\"m\":89,\"g\":94},\"0166403c\":{\"m\":89,\"g\":94},\"bcf66ef3\":{\"m\":89,\"g\":94},\"0de5e7d4\":{\"m\":89,\"g\":94},\"72a110f6\":{\"m\":89,\"g\":94},\"5aff1e93\":{\"m\":89,\"g\":94},\"8e3797be\":{\"m\":89,\"g\":94},\"4474eaf5\":{\"m\":89,\"g\":94},\"499f5e62\":{\"m\":89,\"g\":94},\"81964328\":{\"m\":89,\"g\":94},\"f0f84975\":{\"m\":89,\"g\":94},\"3f1e4339\":{\"m\":89,\"g\":94},\"cf9815ba\":{\"m\":89,\"g\":94},\"bd75690f\":{\"m\":89,\"g\":94},\"180ff5ee\":{\"m\":89,\"g\":94},\"37f15475\":{\"m\":89,\"g\":94},\"8a548052\":{\"m\":89,\"g\":94},\"b6d0ce9f\":{\"m\":89,\"g\":94},\"0ea330ca\":{\"m\":89,\"g\":94},\"27e327b4\":{\"m\":89,\"g\":94},\"ff00895c\":{\"m\":89,\"g\":94},\"ff914748\":{\"m\":89,\"g\":94},\"eb38c7d1\":{\"m\":89,\"g\":94},\"df7f61ee\":{\"m\":89,\"g\":94},\"ef21729c\":{\"m\":89,\"g\":94},\"f5159315\":{\"m\":89,\"g\":94},\"6d7b6696\":{\"m\":89,\"g\":94},\"6376b632\":{\"m\":89,\"g\":94},\"e05e29d1\":{\"m\":89,\"g\":94},\"a2cb5913\":{\"m\":89,\"g\":94},\"55444ed6\":{\"m\":89,\"g\":94},\"20fd53b8\":{\"m\":89,\"g\":94},\"6a47b730\":{\"m\":89,\"g\":94},\"c429919d\":{\"m\":89,\"g\":94},\"1da8d230\":{\"m\":89,\"g\":94}
,\"2f7420bc\":{\"m\":89,\"g\":94},\"c6a0cacc\":{\"m\":89,\"g\":94},\"0a9bfc20\":{\"m\":89,\"g\":94},\"34c63731\":{\"m\":89,\"g\":94},\"2d72fc47\":{\"m\":89,\"g\":94},\"b520d028\":{\"m\":89,\"g\":94},\"7dc0e394\":{\"m\":89,\"g\":94},\"fb507b7b\":{\"m\":89,\"g\":94},\"f90945c4\":{\"m\":89,\"g\":94},\"094fbdac\":{\"m\":89,\"g\":94},\"888cb175\":{\"m\":89,\"g\":94},\"e39bca07\":{\"m\":89,\"g\":94},\"a2bb8565\":{\"m\":89,\"g\":94},\"ced3c07a\":{\"m\":89,\"g\":94},\"f18b068f\":{\"m\":89,\"g\":94},\"4fac524b\":{\"m\":89,\"g\":94},\"b581b225\":{\"m\":89,\"g\":94},\"69dd878b\":{\"m\":89,\"g\":94},\"22630ca2\":{\"m\":89,\"g\":94},\"d279d499\":{\"m\":89,\"g\":94},\"6cb00c63\":{\"m\":89,\"g\":94},\"62cac2c4\":{\"m\":89,\"g\":94},\"2c3b71d6\":{\"m\":89,\"g\":94},\"51cdd81f\":{\"m\":89,\"g\":94},\"73def253\":{\"m\":89,\"g\":94},\"d9d35def\":{\"m\":89,\"g\":94},\"6df81e8a\":{\"m\":89,\"g\":94},\"3ab7d9b5\":{\"m\":89,\"g\":94},\"7e5071c9\":{\"m\":89,\"g\":94},\"78689d33\":{\"m\":89,\"g\":94},\"1dc6864f\":{\"m\":89,\"g\":94},\"485a023b\":{\"m\":89,\"g\":94},\"7e412900\":{\"m\":89,\"g\":94},\"c673727e\":{\"m\":89,\"g\":94},\"f4d4f939\":{\"m\":89,\"g\":94},\"f2bd3515\":{\"m\":89,\"g\":94},\"c459536b\":{\"m\":89,\"g\":94},\"535c8386\":{\"m\":89,\"g\":94},\"2163586e\":{\"m\":89,\"g\":94},\"e06b0761\":{\"m\":89,\"g\":94},\"844a8f42\":{\"m\":89,\"g\":94},\"791b3bfa\":{\"m\":89,\"g\":94},\"31589e17\":{\"m\":89,\"g\":94},\"ae6a5b29\":{\"m\":89,\"g\":94},\"4839999b\":{\"m\":89,\"g\":94},\"541a985f\":{\"m\":89,\"g\":94},\"5170b010\":{\"m\":89,\"g\":94},\"d63e76f7\":{\"m\":89,\"g\":94},\"e9fd11c0\":{\"m\":89,\"g\":94},\"c7588d59\":{\"m\":89,\"g\":94},\"6b231325\":{\"m\":89,\"g\":94},\"b1c8d4e9\":{\"m\":89,\"g\":94},\"c25231c6\":{\"m\":89,\"g\":94},\"fba03b29\":{\"m\":89,\"g\":94},\"461a7302\":{\"m\":89,\"g\":94},\"07610353\":{\"m\":89,\"g\":94},\"c087ddd6\":{\"m\":89,\"g\":94},\"f4a8987f\":{\"m\":89,\"g\":94},\"41ba767f\":{\"m\":89,\"g\":94},\"f127355a\":{\"m\":89,\"g\":94},\"bdb962d7\":{\"m\"
:89,\"g\":94},\"0b9557fc\":{\"m\":89,\"g\":94},\"87068b5c\":{\"m\":89,\"g\":94},\"a564e001\":{\"m\":89,\"g\":94},\"2103b806\":{\"m\":89,\"g\":94},\"e806f708\":{\"m\":89,\"g\":94},\"fa6723f0\":{\"m\":89,\"g\":94},\"673ff668\":{\"m\":89,\"g\":94},\"447be242\":{\"m\":89,\"g\":94},\"183d9f96\":{\"m\":89,\"g\":94},\"63195028\":{\"m\":89,\"g\":94},\"a3d7f4b6\":{\"m\":89,\"g\":94},\"b18416fb\":{\"m\":89,\"g\":94},\"ce9d690e\":{\"m\":89,\"g\":94},\"bdaefbbf\":{\"m\":89,\"g\":94},\"45a31a82\":{\"m\":89,\"g\":94},\"1aa0fbf4\":{\"m\":89,\"g\":94},\"7a0bbe6a\":{\"m\":89,\"g\":94},\"ae335842\":{\"m\":89,\"g\":94},\"477a101c\":{\"m\":89,\"g\":94},\"1a8f5f68\":{\"m\":89,\"g\":94},\"32cd7070\":{\"m\":89,\"g\":94},\"ebd1ed49\":{\"m\":89,\"g\":94},\"f77da699\":{\"m\":89,\"g\":94},\"d6864ce6\":{\"m\":89,\"g\":94},\"755a3661\":{\"m\":89,\"g\":94},\"79a39ac0\":{\"m\":89,\"g\":94},\"3ce94f71\":{\"m\":89,\"g\":94},\"ca95556c\":{\"m\":89,\"g\":94},\"eb8f02dd\":{\"m\":89,\"g\":94},\"0ca3e568\":{\"m\":89,\"g\":94},\"5c7aa009\":{\"m\":89,\"g\":94},\"fe386aca\":{\"m\":89,\"g\":94},\"14d1075f\":{\"m\":89,\"g\":94},\"006ead9d\":{\"m\":89,\"g\":94},\"0d503090\":{\"m\":89,\"g\":94},\"501efc3d\":{\"m\":89,\"g\":94},\"f9bab3d5\":{\"m\":89,\"g\":94},\"16f69b1f\":{\"m\":89,\"g\":94},\"65f09131\":{\"m\":89,\"g\":94},\"fc419b62\":{\"m\":89,\"g\":94},\"7eb9d8e5\":{\"m\":89,\"g\":94},\"84147254\":{\"m\":89,\"g\":94},\"6bebef60\":{\"m\":89,\"g\":94},\"25be63d0\":{\"m\":89,\"g\":94},\"d502dae0\":{\"m\":89,\"g\":94},\"93e53f6e\":{\"m\":89,\"g\":94},\"a191a0e4\":{\"m\":89,\"g\":94},\"8c7279c2\":{\"m\":89,\"g\":94},\"0ca18117\":{\"m\":89,\"g\":94},\"2c3a6fe1\":{\"m\":89,\"g\":94},\"8b33d8df\":{\"m\":89,\"g\":94},\"e235be16\":{\"m\":89,\"g\":94},\"5ccf8fe1\":{\"m\":89,\"g\":94},\"3f23d8cd\":{\"m\":89,\"g\":94},\"1a399799\":{\"m\":89,\"g\":94},\"022012aa\":{\"m\":89,\"g\":94},\"681e7af3\":{\"m\":89,\"g\":94},\"681fdc26\":{\"m\":89,\"g\":94},\"0d477880\":{\"m\":89,\"g\":94},\"f4560373\":{\"m\":89,\"g\":94},\"b238
8433\":{\"m\":89,\"g\":94},\"a38376fa\":{\"m\":89,\"g\":94},\"7a5e6ce1\":{\"m\":89,\"g\":94},\"24c035f2\":{\"m\":89,\"g\":94},\"f9dc9dd2\":{\"m\":90,\"g\":94},\"62a7aa2e\":{\"m\":90,\"g\":94},\"5ca07eed\":{\"m\":90,\"g\":94},\"e30ef368\":{\"m\":90,\"g\":94},\"91a066ec\":{\"m\":90,\"g\":94},\"c4943867\":{\"m\":90,\"g\":94},\"53a525bf\":{\"m\":90,\"g\":94},\"7ddf8e83\":{\"m\":90,\"g\":94},\"8321f8e4\":{\"m\":90,\"g\":94},\"cfceb83d\":{\"m\":90,\"g\":94},\"b1286a11\":{\"m\":90,\"g\":94},\"21615cc3\":{\"m\":90,\"g\":94},\"0ae1e9a7\":{\"m\":90,\"g\":94},\"e07d0647\":{\"m\":90,\"g\":94},\"3c2274fb\":{\"m\":90,\"g\":94},\"d2679f51\":{\"m\":90,\"g\":94},\"96be97bf\":{\"m\":90,\"g\":94},\"88f9c347\":{\"m\":90,\"g\":94},\"fff10809\":{\"m\":90,\"g\":94},\"5f1ab327\":{\"m\":90,\"g\":94},\"7df7c679\":{\"m\":90,\"g\":94},\"38af4f68\":{\"m\":90,\"g\":94},\"a6305c7d\":{\"m\":90,\"g\":94},\"a023856b\":{\"m\":90,\"g\":94},\"db0cc57e\":{\"m\":90,\"g\":94},\"349bb2c9\":{\"m\":90,\"g\":94},\"0b8939bc\":{\"m\":90,\"g\":94},\"ed89837c\":{\"m\":90,\"g\":94},\"55561e25\":{\"m\":90,\"g\":94},\"44733203\":{\"m\":90,\"g\":94},\"0bd67ba2\":{\"m\":90,\"g\":94},\"7d316991\":{\"m\":90,\"g\":94},\"ab1a4fa5\":{\"m\":90,\"g\":94},\"ed54bf9d\":{\"m\":90,\"g\":94},\"b57d87c2\":{\"m\":90,\"g\":94},\"98538822\":{\"m\":90,\"g\":94},\"f47a1b1d\":{\"m\":90,\"g\":94},\"93cec433\":{\"m\":90,\"g\":94},\"ba589b88\":{\"m\":90,\"g\":94},\"50876abc\":{\"m\":90,\"g\":94},\"b4c41f72\":{\"m\":90,\"g\":94},\"8b8f2e74\":{\"m\":90,\"g\":94},\"0fc3d992\":{\"m\":90,\"g\":94},\"be2d985d\":{\"m\":90,\"g\":94},\"5b1afa78\":{\"m\":90,\"g\":94},\"c49c1d92\":{\"m\":90,\"g\":94},\"0f1dfa1e\":{\"m\":90,\"g\":94},\"e3ec6bf4\":{\"m\":90,\"g\":94},\"b04df75a\":{\"m\":90,\"g\":94},\"bec3e484\":{\"m\":90,\"g\":94},\"8ab7d93c\":{\"m\":90,\"g\":94},\"5c66c442\":{\"m\":90,\"g\":94},\"aa46ed34\":{\"m\":90,\"g\":94},\"2f4ec752\":{\"m\":90,\"g\":94},\"da47621c\":{\"m\":90,\"g\":94},\"22a6b9fc\":{\"m\":90,\"g\":94},\"b02df20a\":{\"m\":90,\"g
\":94},\"bd7cfbd2\":{\"m\":90,\"g\":94},\"4b9971e4\":{\"m\":90,\"g\":94},\"dcc79d32\":{\"m\":90,\"g\":94},\"7046e0fa\":{\"m\":90,\"g\":94},\"930746d9\":{\"m\":90,\"g\":94},\"84727a51\":{\"m\":90,\"g\":94},\"ef326774\":{\"m\":90,\"g\":94},\"021f76e4\":{\"m\":90,\"g\":94},\"777688b8\":{\"m\":90,\"g\":94},\"0ca594ed\":{\"m\":90,\"g\":94},\"31d6dee5\":{\"m\":90,\"g\":94},\"02543b54\":{\"m\":90,\"g\":94},\"25a6a9aa\":{\"m\":90,\"g\":94},\"83d87685\":{\"m\":90,\"g\":94},\"2a5f0100\":{\"m\":90,\"g\":94},\"dbdf76ca\":{\"m\":90,\"g\":94},\"f2a75a66\":{\"m\":90,\"g\":94},\"6b12d6a8\":{\"m\":90,\"g\":94},\"0f218731\":{\"m\":90,\"g\":94},\"14c18d25\":{\"m\":90,\"g\":94},\"90bd3e32\":{\"m\":90,\"g\":94},\"ca929118\":{\"m\":90,\"g\":94},\"344adb00\":{\"m\":90,\"g\":94},\"b56de8f9\":{\"m\":90,\"g\":94},\"ce5ee3bd\":{\"m\":90,\"g\":94},\"a0e4d4eb\":{\"m\":90,\"g\":94},\"2f584455\":{\"m\":90,\"g\":94},\"fe55947a\":{\"m\":90,\"g\":94},\"19995dd7\":{\"m\":90,\"g\":94},\"3b014bc1\":{\"m\":90,\"g\":94},\"d7c3e8e9\":{\"m\":90,\"g\":94},\"8ea7df61\":{\"m\":90,\"g\":94},\"4a102a2b\":{\"m\":90,\"g\":94},\"6406408a\":{\"m\":90,\"g\":94},\"019851d0\":{\"m\":90,\"g\":94},\"2dae104d\":{\"m\":90,\"g\":94},\"cef6655b\":{\"m\":90,\"g\":94},\"27196d41\":{\"m\":90,\"g\":94},\"bb185b0e\":{\"m\":90,\"g\":94},\"7c3a12c0\":{\"m\":91,\"g\":94},\"e846d95e\":{\"m\":91,\"g\":94},\"15f34013\":{\"m\":91,\"g\":94},\"d04163b3\":{\"m\":91,\"g\":94},\"7732bbe4\":{\"m\":91,\"g\":94},\"ed0a0b69\":{\"m\":91,\"g\":94},\"fa42e419\":{\"m\":91,\"g\":94},\"e5afb88b\":{\"m\":91,\"g\":94},\"e5ddeb04\":{\"m\":91,\"g\":94},\"bdbb8d00\":{\"m\":91,\"g\":94},\"34c3f9b2\":{\"m\":91,\"g\":94},\"76139bfb\":{\"m\":91,\"g\":94},\"f8d48fd3\":{\"m\":91,\"g\":94},\"34b6b842\":{\"m\":91,\"g\":94},\"25549433\":{\"m\":91,\"g\":94},\"d6dddc19\":{\"m\":91,\"g\":94},\"55e03b10\":{\"m\":91,\"g\":94},\"8aa68ed5\":{\"m\":91,\"g\":94},\"506c4928\":{\"m\":91,\"g\":94},\"30ceccc7\":{\"m\":91,\"g\":94},\"ac5010e0\":{\"m\":91,\"g\":94},\"3cee035e\":
{\"m\":91,\"g\":94},\"30f2a44a\":{\"m\":91,\"g\":94},\"bd4f5818\":{\"m\":91,\"g\":94},\"50f1b6d6\":{\"m\":91,\"g\":94},\"5962e70d\":{\"m\":91,\"g\":94},\"edc21cc8\":{\"m\":91,\"g\":94},\"05c9bc89\":{\"m\":91,\"g\":94},\"b7a2df0a\":{\"m\":91,\"g\":94},\"1998ce40\":{\"m\":91,\"g\":94},\"72676cd6\":{\"m\":91,\"g\":94},\"02bf31ef\":{\"m\":91,\"g\":94},\"5ea5d221\":{\"m\":91,\"g\":94},\"fdfd5224\":{\"m\":91,\"g\":94},\"7f3ee861\":{\"m\":91,\"g\":94},\"bec58910\":{\"m\":91,\"g\":94},\"9edf6608\":{\"m\":91,\"g\":94},\"ab74f8f0\":{\"m\":91,\"g\":94},\"5e7fdc79\":{\"m\":91,\"g\":94},\"4d8d9b8e\":{\"m\":91,\"g\":94},\"5041df2d\":{\"m\":91,\"g\":94},\"256801e9\":{\"m\":91,\"g\":94},\"73b13e69\":{\"m\":91,\"g\":94},\"8609e637\":{\"m\":91,\"g\":94},\"dea2b84b\":{\"m\":91,\"g\":94},\"cfb2fb5a\":{\"m\":91,\"g\":94},\"22bfed75\":{\"m\":91,\"g\":94},\"e879d8b7\":{\"m\":91,\"g\":94},\"09988080\":{\"m\":91,\"g\":94},\"794be55a\":{\"m\":91,\"g\":94},\"187b85b7\":{\"m\":91,\"g\":94},\"ceba0ce4\":{\"m\":91,\"g\":94},\"1ab6be1b\":{\"m\":91,\"g\":94},\"4df5fc21\":{\"m\":91,\"g\":94},\"a06912ad\":{\"m\":91,\"g\":94},\"97011abc\":{\"m\":91,\"g\":94},\"1d6515ef\":{\"m\":91,\"g\":94},\"dea8aa7a\":{\"m\":91,\"g\":94},\"906dbc34\":{\"m\":91,\"g\":94},\"fadf18fd\":{\"m\":91,\"g\":94},\"f88e7085\":{\"m\":91,\"g\":94},\"4f838c09\":{\"m\":91,\"g\":94},\"d20a073b\":{\"m\":91,\"g\":94},\"47367b76\":{\"m\":91,\"g\":94},\"650127a1\":{\"m\":91,\"g\":94},\"3774f078\":{\"m\":91,\"g\":94},\"9179ea15\":{\"m\":91,\"g\":94},\"20a503c7\":{\"m\":91,\"g\":94},\"ffd1a26e\":{\"m\":91,\"g\":94},\"09ae5b20\":{\"m\":91,\"g\":94},\"712bf9ec\":{\"m\":91,\"g\":94},\"9c6a0656\":{\"m\":91,\"g\":94},\"2ae809c5\":{\"m\":91,\"g\":94},\"1de4db9b\":{\"m\":91,\"g\":94},\"31fccf5a\":{\"m\":91,\"g\":94},\"b783c1cb\":{\"m\":91,\"g\":94},\"094c116f\":{\"m\":91,\"g\":94},\"e56685ac\":{\"m\":91,\"g\":94},\"c26d7349\":{\"m\":91,\"g\":94},\"ceaa85c9\":{\"m\":91,\"g\":94},\"0650e517\":{\"m\":91,\"g\":94},\"fc554105\":{\"m\":91,\"g\":94},
\"4f204db5\":{\"m\":91,\"g\":94},\"3eb4a800\":{\"m\":91,\"g\":94},\"e7261315\":{\"m\":91,\"g\":94},\"8c16da33\":{\"m\":91,\"g\":94},\"a39d9287\":{\"m\":91,\"g\":94},\"10d60cd4\":{\"m\":91,\"g\":94},\"8a10c4c3\":{\"m\":91,\"g\":94},\"405780bc\":{\"m\":91,\"g\":94},\"1dffee31\":{\"m\":91,\"g\":94},\"70c471a8\":{\"m\":91,\"g\":94},\"1a9c2c92\":{\"m\":91,\"g\":94},\"873ae12c\":{\"m\":91,\"g\":94},\"c64290dc\":{\"m\":91,\"g\":94},\"8e2363dc\":{\"m\":91,\"g\":94},\"69183f88\":{\"m\":92,\"g\":94},\"9b00990b\":{\"m\":92,\"g\":94},\"4d67025a\":{\"m\":92,\"g\":94},\"0e05fe8c\":{\"m\":92,\"g\":94},\"2390a2bc\":{\"m\":92,\"g\":94},\"16d76b9f\":{\"m\":92,\"g\":94},\"5c214257\":{\"m\":92,\"g\":94},\"b8df43ab\":{\"m\":92,\"g\":94},\"a1c1ebe9\":{\"m\":92,\"g\":94},\"fe2a0f96\":{\"m\":92,\"g\":94},\"20beb370\":{\"m\":92,\"g\":94},\"00fbd8a4\":{\"m\":92,\"g\":94},\"802815e4\":{\"m\":92,\"g\":94},\"4c6675c4\":{\"m\":92,\"g\":94},\"e21aa1df\":{\"m\":92,\"g\":94},\"f3cbd245\":{\"m\":92,\"g\":94},\"506a2d59\":{\"m\":92,\"g\":94},\"a07f8ae4\":{\"m\":92,\"g\":94},\"7eb47b0f\":{\"m\":92,\"g\":94},\"bc2e5645\":{\"m\":92,\"g\":94},\"3abc3036\":{\"m\":92,\"g\":94},\"afeed465\":{\"m\":92,\"g\":94},\"587b4c6e\":{\"m\":92,\"g\":94},\"7b9a174a\":{\"m\":92,\"g\":94},\"03c039c4\":{\"m\":92,\"g\":94},\"57ab7769\":{\"m\":92,\"g\":94},\"112b496a\":{\"m\":92,\"g\":94},\"3562256b\":{\"m\":92,\"g\":94},\"5f527834\":{\"m\":92,\"g\":94},\"9f1787fa\":{\"m\":92,\"g\":94},\"8ecad0b1\":{\"m\":92,\"g\":94},\"7151194b\":{\"m\":92,\"g\":94},\"2ed68d7a\":{\"m\":92,\"g\":94},\"e984d507\":{\"m\":92,\"g\":94},\"755f3147\":{\"m\":92,\"g\":94},\"ec5f9c62\":{\"m\":93,\"g\":94},\"62f5522f\":{\"m\":93,\"g\":94},\"01f98730\":{\"m\":93,\"g\":94},\"199d6218\":{\"m\":93,\"g\":94},\"f200af0d\":{\"m\":93,\"g\":94},\"5589b750\":{\"m\":93,\"g\":94},\"c04a8a82\":{\"m\":93,\"g\":94},\"6c903611\":{\"m\":93,\"g\":94},\"77cfea68\":{\"m\":93,\"g\":94},\"8fc910db\":{\"m\":93,\"g\":94},\"75354d9a\":{\"m\":93,\"g\":94},\"4fece12b\":{\"m\":
93,\"g\":94},\"c7973222\":{\"m\":93,\"g\":94},\"ef8a29c4\":{\"m\":93,\"g\":94},\"8e9fb43d\":{\"m\":93,\"g\":94},\"83646089\":{\"m\":93,\"g\":94},\"da3890e8\":{\"m\":93,\"g\":94},\"cb432f17\":{\"m\":93,\"g\":94},\"1964c325\":{\"m\":93,\"g\":94},\"af564774\":{\"m\":93,\"g\":94},\"af46f299\":{\"m\":93,\"g\":94},\"16a6b1d8\":{\"m\":93,\"g\":94},\"14229ccf\":{\"m\":93,\"g\":94},\"975a5ec6\":{\"m\":93,\"g\":94},\"1e3e3add\":{\"m\":93,\"g\":94},\"8c298031\":{\"m\":93,\"g\":94},\"4de03953\":{\"m\":93,\"g\":94},\"8b1942c6\":{\"m\":93,\"g\":94},\"489934be\":{\"m\":93,\"g\":94},\"43f93f63\":{\"m\":93,\"g\":94},\"aca1101a\":{\"m\":93,\"g\":94},\"2998c4bd\":{\"m\":93,\"g\":94},\"6840a7bb\":{\"m\":93,\"g\":94},\"c01a1df5\":{\"m\":93,\"g\":94},\"00991723\":{\"m\":93,\"g\":94},\"264dc6e7\":{\"m\":93,\"g\":94},\"646cef2e\":{\"m\":93,\"g\":94},\"1dce6c48\":{\"m\":93,\"g\":94},\"9fcc9a80\":{\"m\":93,\"g\":94},\"ac49dac0\":{\"m\":93,\"g\":94},\"1e0e5497\":{\"m\":93,\"g\":94},\"b5822651\":{\"m\":93,\"g\":94},\"2c4feaf3\":{\"m\":93,\"g\":94},\"2ff572e2\":{\"m\":93,\"g\":94},\"84f2e4a0\":{\"m\":93,\"g\":94},\"8f844db6\":{\"m\":93,\"g\":94},\"36cc3ffd\":{\"m\":93,\"g\":94},\"1bebd315\":{\"m\":93,\"g\":94},\"d3c275b1\":{\"m\":93,\"g\":94},\"b044400d\":{\"m\":93,\"g\":94},\"40e5cb7a\":{\"m\":93,\"g\":94},\"8e64140e\":{\"m\":93,\"g\":94},\"82f021e2\":{\"m\":93,\"g\":94},\"0626f678\":{\"m\":93,\"g\":94},\"09e699bb\":{\"m\":93,\"g\":94},\"b116b21a\":{\"m\":93,\"g\":94},\"88f484ce\":{\"m\":93,\"g\":94},\"8e03b641\":{\"m\":93,\"g\":94},\"b3fa5dc3\":{\"m\":93,\"g\":94},\"00aec6ad\":{\"m\":93,\"g\":94},\"1a08358a\":{\"m\":93,\"g\":94},\"f18a8fdd\":{\"m\":93,\"g\":94},\"a7efbb27\":{\"m\":93,\"g\":94},\"93b6785d\":{\"m\":93,\"g\":94},\"f9eb04dd\":{\"m\":93,\"g\":94},\"3a911b85\":{\"m\":93,\"g\":94},\"886d3449\":{\"m\":93,\"g\":94},\"637bfee4\":{\"m\":93,\"g\":94},\"6005ecee\":{\"m\":93,\"g\":94},\"ff2e9c94\":{\"m\":93,\"g\":94},\"3e34e900\":{\"m\":93,\"g\":94},\"7349717e\":{\"m\":93,\"g\":94},\"392e4
41a\":{\"m\":93,\"g\":94},\"7248272c\":{\"m\":93,\"g\":94},\"22352d47\":{\"m\":93,\"g\":94},\"c5131f7a\":{\"m\":93,\"g\":94},\"78700893\":{\"m\":93,\"g\":94},\"663c04f7\":{\"m\":93,\"g\":94},\"3b3f1e3a\":{\"m\":93,\"g\":94},\"b691dcc4\":{\"m\":93,\"g\":94},\"0c9c6c75\":{\"m\":93,\"g\":94},\"e3f9b548\":{\"m\":93,\"g\":94},\"b3cff365\":{\"m\":93,\"g\":94},\"8f335b5b\":{\"m\":93,\"g\":94},\"b2264076\":{\"m\":93,\"g\":94},\"04b35190\":{\"m\":93,\"g\":94},\"071a1f51\":{\"m\":93,\"g\":94},\"7c0db3a6\":{\"m\":93,\"g\":94},\"c45e49d8\":{\"m\":93,\"g\":94},\"d8053929\":{\"m\":93,\"g\":94},\"00c7b1ad\":{\"m\":93,\"g\":94},\"82eccae4\":{\"m\":93,\"g\":94},\"a8c10aee\":{\"m\":93,\"g\":94},\"eb429b88\":{\"m\":93,\"g\":94},\"49538d11\":{\"m\":93,\"g\":94},\"cfe2edac\":{\"m\":93,\"g\":94},\"2373faa3\":{\"m\":93,\"g\":94},\"9efb2993\":{\"m\":93,\"g\":94},\"a5317b2f\":{\"m\":93,\"g\":94},\"eb6c2c16\":{\"m\":93,\"g\":94},\"357921aa\":{\"m\":93,\"g\":94},\"c071198c\":{\"m\":93,\"g\":94},\"d7374d74\":{\"m\":93,\"g\":94},\"ce3a3e87\":{\"m\":93,\"g\":94},\"41650b0d\":{\"m\":93,\"g\":94},\"1b951620\":{\"m\":93,\"g\":94},\"29bd4c81\":{\"m\":93,\"g\":94},\"031f64aa\":{\"m\":93,\"g\":94},\"3d7cdb2e\":{\"m\":93,\"g\":94},\"604efe07\":{\"m\":93,\"g\":94},\"1b8cf77b\":{\"m\":93,\"g\":94},\"bb9b608c\":{\"m\":93,\"g\":94},\"066f4ec9\":{\"m\":95,\"g\":97},\"b6b6268c\":{\"m\":95,\"g\":97},\"08702321\":{\"m\":95,\"g\":97},\"64c5907e\":{\"m\":95,\"g\":97},\"128f16a8\":{\"m\":95,\"g\":97},\"49861046\":{\"m\":95,\"g\":97},\"a37e1247\":{\"m\":95,\"g\":97},\"136c6e04\":{\"m\":95,\"g\":97},\"43e20c06\":{\"m\":95,\"g\":97},\"4bab50a6\":{\"m\":95,\"g\":97},\"2e7ab862\":{\"m\":95,\"g\":97},\"51ae4030\":{\"m\":95,\"g\":97},\"653b873b\":{\"m\":95,\"g\":97},\"d379bda4\":{\"m\":95,\"g\":97},\"2b0e1d1c\":{\"m\":95,\"g\":97},\"659907e3\":{\"m\":95,\"g\":97},\"cb9d91ea\":{\"m\":95,\"g\":97},\"6a6e0bb7\":{\"m\":95,\"g\":97},\"076313bd\":{\"m\":95,\"g\":97},\"9abe1163\":{\"m\":95,\"g\":97},\"3646f6bb\":{\"m\":95,\"g\
":94},\"35724aa1\":{\"m\":95,\"g\":94},\"2fc824b8\":{\"m\":95,\"g\":94},\"253454de\":{\"m\":95,\"g\":94},\"ea3e7ffe\":{\"m\":95,\"g\":94},\"8d4a01cb\":{\"m\":95,\"g\":94},\"a3398d84\":{\"m\":95,\"g\":94},\"ba69c153\":{\"m\":95,\"g\":94},\"3589aa79\":{\"m\":95,\"g\":94},\"e00715eb\":{\"m\":95,\"g\":94},\"ea4bf122\":{\"m\":95,\"g\":94},\"a291439a\":{\"m\":95,\"g\":94},\"54411f6a\":{\"m\":95,\"g\":94},\"625018d2\":{\"m\":95,\"g\":94},\"5732d904\":{\"m\":95,\"g\":94},\"eb118d88\":{\"m\":96,\"g\":97},\"732fc8e4\":{\"m\":96,\"g\":97},\"f2d5c492\":{\"m\":96,\"g\":97},\"2a2d3478\":{\"m\":96,\"g\":97},\"aa205609\":{\"m\":96,\"g\":97},\"61bb2858\":{\"m\":96,\"g\":97},\"880221bd\":{\"m\":96,\"g\":97},\"8f3173d0\":{\"m\":96,\"g\":97},\"26118a13\":{\"m\":96,\"g\":97},\"475a249b\":{\"m\":96,\"g\":97},\"191d836f\":{\"m\":96,\"g\":97},\"86044712\":{\"m\":96,\"g\":97},\"61555307\":{\"m\":96,\"g\":97},\"49a5915f\":{\"m\":96,\"g\":97},\"766392c6\":{\"m\":96,\"g\":97},\"4a0d1919\":{\"m\":96,\"g\":97},\"57482415\":{\"m\":96,\"g\":97},\"2d54d4bb\":{\"m\":96,\"g\":97},\"b5e3d603\":{\"m\":96,\"g\":97},\"4ed57807\":{\"m\":96,\"g\":97},\"dd445a41\":{\"m\":96,\"g\":97},\"7590f522\":{\"m\":96,\"g\":97},\"f9df11ae\":{\"m\":96,\"g\":97},\"d389bedf\":{\"m\":96,\"g\":97},\"ac80f4da\":{\"m\":96,\"g\":97},\"d487555f\":{\"m\":96,\"g\":97},\"e5888edd\":{\"m\":96,\"g\":97},\"01c00004\":{\"m\":98,\"g\":103},\"0dfe2491\":{\"m\":98,\"g\":103},\"ff45ab7a\":{\"m\":98,\"g\":103},\"0f8b5386\":{\"m\":98,\"g\":103},\"c33499a6\":{\"m\":98,\"g\":103},\"e50109f2\":{\"m\":98,\"g\":103},\"69adc4f8\":{\"m\":98,\"g\":103},\"11483785\":{\"m\":98,\"g\":103},\"7b68d271\":{\"m\":98,\"g\":103},\"74f59ae5\":{\"m\":98,\"g\":103},\"6936be32\":{\"m\":98,\"g\":103},\"9b5de6cb\":{\"m\":98,\"g\":97},\"5c8365a0\":{\"m\":98,\"g\":97},\"8430bfe3\":{\"m\":98,\"g\":97},\"c9e8613c\":{\"m\":98,\"g\":97},\"429bb0ef\":{\"m\":98,\"g\":97},\"7eebd440\":{\"m\":98,\"g\":97},\"93d124ef\":{\"m\":98,\"g\":97},\"1fc455e8\":{\"m\":98,\"g\":97},\"4
65968b2\":{\"m\":98,\"g\":97},\"750838ad\":{\"m\":98,\"g\":97},\"99aefa03\":{\"m\":98,\"g\":97},\"bbcfbc1a\":{\"m\":98,\"g\":97},\"83c104b1\":{\"m\":98,\"g\":97},\"2db6719c\":{\"m\":98,\"g\":97},\"55381a46\":{\"m\":98,\"g\":97},\"a589a071\":{\"m\":98,\"g\":97},\"f62d75b6\":{\"m\":98,\"g\":97},\"0f9b11e3\":{\"m\":98,\"g\":97},\"877e35d7\":{\"m\":98,\"g\":97},\"cbdfb771\":{\"m\":98,\"g\":97},\"282eb59f\":{\"m\":98,\"g\":97},\"4540a466\":{\"m\":98,\"g\":97},\"abda2542\":{\"m\":98,\"g\":97},\"8cddfa56\":{\"m\":98,\"g\":97},\"4e3defe5\":{\"m\":98,\"g\":97},\"60468da4\":{\"m\":98,\"g\":97},\"41d33e47\":{\"m\":98,\"g\":97},\"bfdd226f\":{\"m\":98,\"g\":97},\"3de617a7\":{\"m\":98,\"g\":97},\"bb0e8a32\":{\"m\":98,\"g\":97},\"1b427dae\":{\"m\":98,\"g\":97},\"f3d97361\":{\"m\":98,\"g\":97},\"561dd7b2\":{\"m\":98,\"g\":97},\"f98e88b9\":{\"m\":98,\"g\":97},\"15ad6c90\":{\"m\":98,\"g\":97},\"cfab0ff6\":{\"m\":98,\"g\":97},\"b763cf7e\":{\"m\":98,\"g\":97},\"8fcc55cf\":{\"m\":98,\"g\":97},\"610381b7\":{\"m\":98,\"g\":97},\"1403ea56\":{\"m\":98,\"g\":97},\"b7e951a6\":{\"m\":98,\"g\":97},\"d918ab79\":{\"m\":98,\"g\":97},\"3964b352\":{\"m\":98,\"g\":97},\"9c7a4618\":{\"m\":98,\"g\":97},\"7750b91c\":{\"m\":98,\"g\":97},\"c8f31042\":{\"m\":98,\"g\":97},\"1f76fc87\":{\"m\":98,\"g\":97},\"6737671c\":{\"m\":98,\"g\":97},\"fd63b62e\":{\"m\":98,\"g\":97},\"719b29f2\":{\"m\":98,\"g\":97},\"d0510f08\":{\"m\":98,\"g\":97},\"9d33fcfb\":{\"m\":98,\"g\":97},\"7891bac1\":{\"m\":98,\"g\":97},\"48c1fa7b\":{\"m\":98,\"g\":97},\"8aa5ae6b\":{\"m\":98,\"g\":97},\"8a323557\":{\"m\":98,\"g\":97},\"6e92da8f\":{\"m\":98,\"g\":97},\"e1020dc5\":{\"m\":98,\"g\":97},\"3586b4ce\":{\"m\":98,\"g\":97},\"42960214\":{\"m\":98,\"g\":97},\"01857fab\":{\"m\":98,\"g\":97},\"519ff5c8\":{\"m\":98,\"g\":97},\"af1cc8fe\":{\"m\":98,\"g\":97},\"49b87774\":{\"m\":98,\"g\":97},\"02404a1e\":{\"m\":98,\"g\":97},\"5c08a36c\":{\"m\":98,\"g\":97},\"9069884b\":{\"m\":98,\"g\":97},\"8a7a7770\":{\"m\":98,\"g\":97},\"795668dc\":{\"m\":98,
\"g\":97},\"4395c87a\":{\"m\":98,\"g\":97},\"c28ad199\":{\"m\":98,\"g\":97},\"570d3343\":{\"m\":98,\"g\":97},\"d9eb5efc\":{\"m\":98,\"g\":97},\"6dc4af49\":{\"m\":98,\"g\":97},\"b188a89a\":{\"m\":98,\"g\":97},\"497efe74\":{\"m\":98,\"g\":97},\"69f453e5\":{\"m\":98,\"g\":97},\"3bc43c68\":{\"m\":98,\"g\":97},\"7498522f\":{\"m\":98,\"g\":97},\"194841e3\":{\"m\":98,\"g\":97},\"ebff5fcb\":{\"m\":98,\"g\":97},\"f06bd210\":{\"m\":98,\"g\":97},\"14f1f151\":{\"m\":98,\"g\":97},\"38216cf0\":{\"m\":98,\"g\":97},\"4a883795\":{\"m\":98,\"g\":97},\"f1f1d1d4\":{\"m\":98,\"g\":97},\"9120e83d\":{\"m\":98,\"g\":97},\"6e923dbd\":{\"m\":98,\"g\":97},\"c268c11c\":{\"m\":98,\"g\":97},\"e6d59884\":{\"m\":98,\"g\":97},\"9b560c3e\":{\"m\":98,\"g\":97},\"5dc5866e\":{\"m\":98,\"g\":97},\"64e78bb3\":{\"m\":98,\"g\":97},\"7c39e8a1\":{\"m\":98,\"g\":97},\"d969504d\":{\"m\":98,\"g\":97},\"1ebec1a8\":{\"m\":98,\"g\":97},\"d4d0c7c3\":{\"m\":98,\"g\":97},\"8d2cf38c\":{\"m\":98,\"g\":97},\"2117f82d\":{\"m\":98,\"g\":97},\"c07f647c\":{\"m\":98,\"g\":97},\"07452cbe\":{\"m\":98,\"g\":97},\"a562c8a3\":{\"m\":98,\"g\":97},\"cb736df8\":{\"m\":98,\"g\":97},\"e2ed9d04\":{\"m\":98,\"g\":97},\"b5dd5e87\":{\"m\":98,\"g\":97},\"9379da77\":{\"m\":98,\"g\":97},\"0c55cbcf\":{\"m\":98,\"g\":97},\"c46e069d\":{\"m\":98,\"g\":97},\"42fc4410\":{\"m\":98,\"g\":97},\"5f6756b0\":{\"m\":98,\"g\":97},\"98aa836b\":{\"m\":98,\"g\":97},\"22bd857c\":{\"m\":98,\"g\":97},\"ccfa0841\":{\"m\":98,\"g\":97},\"bcc5ba94\":{\"m\":98,\"g\":97},\"cee9f329\":{\"m\":98,\"g\":97},\"2272c2a5\":{\"m\":99,\"g\":103},\"3ec0b212\":{\"m\":99,\"g\":103},\"58c468f4\":{\"m\":99,\"g\":103},\"f8ca2368\":{\"m\":99,\"g\":103},\"d8ee1564\":{\"m\":99,\"g\":103},\"7181ec8c\":{\"m\":99,\"g\":103},\"ed2e313e\":{\"m\":99,\"g\":103},\"f8260f25\":{\"m\":99,\"g\":103},\"12cb760a\":{\"m\":99,\"g\":103},\"1b9cea5a\":{\"m\":99,\"g\":103},\"9045cc1e\":{\"m\":99,\"g\":103},\"70e37b97\":{\"m\":99,\"g\":103},\"15d27591\":{\"m\":99,\"g\":103},\"af4b9bae\":{\"m\":99,\"g\":1
03},\"7ad6b766\":{\"m\":99,\"g\":103},\"c0fb25e9\":{\"m\":99,\"g\":103},\"28d4d472\":{\"m\":99,\"g\":103},\"f4674df6\":{\"m\":99,\"g\":103},\"d40846d4\":{\"m\":99,\"g\":103},\"145482f4\":{\"m\":99,\"g\":103},\"39fe1e88\":{\"m\":99,\"g\":103},\"33c4b4d0\":{\"m\":99,\"g\":103},\"8d1c5b94\":{\"m\":99,\"g\":103},\"a167fd0b\":{\"m\":99,\"g\":103},\"2f86f3ad\":{\"m\":99,\"g\":103},\"bfb118c0\":{\"m\":99,\"g\":103},\"f6e07f27\":{\"m\":99,\"g\":103},\"5dd0f870\":{\"m\":99,\"g\":103},\"f7e102d5\":{\"m\":99,\"g\":103},\"0e5fa677\":{\"m\":99,\"g\":103},\"624a3b8d\":{\"m\":99,\"g\":103},\"01079e17\":{\"m\":99,\"g\":103},\"0e7a5b26\":{\"m\":99,\"g\":103},\"4953f4ca\":{\"m\":99,\"g\":103},\"38000a5f\":{\"m\":99,\"g\":103},\"70251e93\":{\"m\":99,\"g\":103},\"c87d4fec\":{\"m\":99,\"g\":103},\"a99801e0\":{\"m\":99,\"g\":103},\"4c605235\":{\"m\":99,\"g\":103},\"6f8f4aee\":{\"m\":99,\"g\":103},\"0c8dab9e\":{\"m\":99,\"g\":103},\"f39037ff\":{\"m\":99,\"g\":103},\"ce86e201\":{\"m\":99,\"g\":103},\"b4326330\":{\"m\":99,\"g\":103},\"8abd3e77\":{\"m\":99,\"g\":103},\"e885bfdc\":{\"m\":99,\"g\":103},\"e2d66f60\":{\"m\":99,\"g\":103},\"45bc170b\":{\"m\":100,\"g\":103},\"22623699\":{\"m\":100,\"g\":103},\"fb4ce17d\":{\"m\":100,\"g\":103},\"25f73c6c\":{\"m\":100,\"g\":103},\"581e7dcb\":{\"m\":100,\"g\":103},\"484d0e02\":{\"m\":100,\"g\":103},\"5922c0cb\":{\"m\":100,\"g\":103},\"6d6a8bc2\":{\"m\":100,\"g\":103},\"2fd5c704\":{\"m\":100,\"g\":103},\"4ad97370\":{\"m\":100,\"g\":103},\"28103384\":{\"m\":100,\"g\":103},\"fe6a445d\":{\"m\":100,\"g\":103},\"dd487e55\":{\"m\":100,\"g\":103},\"bb81daef\":{\"m\":100,\"g\":103},\"58dd95fb\":{\"m\":100,\"g\":103},\"b47eda33\":{\"m\":100,\"g\":103},\"e983d666\":{\"m\":100,\"g\":103},\"b58c3c28\":{\"m\":100,\"g\":103},\"df906455\":{\"m\":100,\"g\":103},\"95217a9b\":{\"m\":100,\"g\":103},\"22e00eeb\":{\"m\":100,\"g\":103},\"b3eac168\":{\"m\":100,\"g\":103},\"10ee8955\":{\"m\":100,\"g\":103},\"bf3352c5\":{\"m\":100,\"g\":103},\"4d921f2b\":{\"m\":100,\"g\":103}
,\"44d600cd\":{\"m\":100,\"g\":103},\"5c9c275b\":{\"m\":100,\"g\":103},\"bf0f448f\":{\"m\":100,\"g\":103},\"36d6f0ba\":{\"m\":100,\"g\":103},\"2a1936de\":{\"m\":100,\"g\":103},\"2ab97023\":{\"m\":100,\"g\":103},\"0bcc195f\":{\"m\":100,\"g\":103},\"91e3d154\":{\"m\":100,\"g\":103},\"85486b6f\":{\"m\":100,\"g\":103},\"e34cf6ad\":{\"m\":100,\"g\":103},\"62222bd2\":{\"m\":100,\"g\":103},\"ed0fdbf3\":{\"m\":100,\"g\":103},\"b602f423\":{\"m\":100,\"g\":103},\"426b7493\":{\"m\":100,\"g\":103},\"528bd1ed\":{\"m\":100,\"g\":103},\"62a6b7c7\":{\"m\":100,\"g\":103},\"76154631\":{\"m\":100,\"g\":103},\"5c705b1d\":{\"m\":100,\"g\":103},\"b7094a5e\":{\"m\":100,\"g\":103},\"da0c0260\":{\"m\":100,\"g\":103},\"3212c2ad\":{\"m\":100,\"g\":103},\"53475674\":{\"m\":100,\"g\":103},\"ce32bc2b\":{\"m\":100,\"g\":103},\"e236d8fe\":{\"m\":100,\"g\":103},\"4fa44d63\":{\"m\":100,\"g\":103},\"e6312d27\":{\"m\":100,\"g\":103},\"8af145b7\":{\"m\":100,\"g\":103},\"6478831b\":{\"m\":101,\"g\":103},\"2e1d2d7e\":{\"m\":101,\"g\":103},\"fb16fbaf\":{\"m\":101,\"g\":103},\"0ce84c82\":{\"m\":101,\"g\":103},\"59d0bf01\":{\"m\":101,\"g\":103},\"7df2c0c2\":{\"m\":101,\"g\":103},\"69712e6f\":{\"m\":101,\"g\":103},\"001bffca\":{\"m\":101,\"g\":103},\"7c969717\":{\"m\":101,\"g\":103},\"8240a6b0\":{\"m\":101,\"g\":103},\"3a04aa4b\":{\"m\":101,\"g\":103},\"bd516949\":{\"m\":101,\"g\":103},\"74e7e457\":{\"m\":101,\"g\":103},\"1466c1b8\":{\"m\":101,\"g\":103},\"9c138a04\":{\"m\":101,\"g\":103},\"c8f549d9\":{\"m\":101,\"g\":103},\"134fa43e\":{\"m\":101,\"g\":103},\"ccfe52a0\":{\"m\":101,\"g\":103},\"747dd450\":{\"m\":101,\"g\":103},\"b5821592\":{\"m\":101,\"g\":103},\"a9dd3ec3\":{\"m\":101,\"g\":103},\"02328864\":{\"m\":102,\"g\":103},\"7a1f7fc5\":{\"m\":102,\"g\":103},\"51c38163\":{\"m\":102,\"g\":103},\"09f1a247\":{\"m\":102,\"g\":103},\"32fa1e9c\":{\"m\":102,\"g\":103},\"e7dc163f\":{\"m\":102,\"g\":103},\"e179e0b7\":{\"m\":102,\"g\":103},\"d9049592\":{\"m\":102,\"g\":103},\"26c8a310\":{\"m\":102,\"g\":103},\"59
63e505\":{\"m\":102,\"g\":103},\"43118f5f\":{\"m\":102,\"g\":103},\"a5f5ab40\":{\"m\":102,\"g\":103},\"59aab76f\":{\"m\":102,\"g\":103},\"659bfd10\":{\"m\":102,\"g\":103},\"67e53b16\":{\"m\":102,\"g\":103},\"9b9e8253\":{\"m\":102,\"g\":103},\"66a398f4\":{\"m\":102,\"g\":103},\"29980334\":{\"m\":102,\"g\":103},\"a79a5d70\":{\"m\":102,\"g\":103},\"ec5f9442\":{\"m\":102,\"g\":103},\"3bdcdd13\":{\"m\":102,\"g\":103},\"a730ce81\":{\"m\":102,\"g\":103},\"55ecdc0a\":{\"m\":102,\"g\":103},\"e3f08c77\":{\"m\":102,\"g\":103},\"a9fd8033\":{\"m\":102,\"g\":103},\"2fbb754e\":{\"m\":102,\"g\":103},\"a85ebf50\":{\"m\":102,\"g\":103},\"9effeb5b\":{\"m\":102,\"g\":103},\"1992ef9b\":{\"m\":102,\"g\":103},\"c0fd77e8\":{\"m\":102,\"g\":103},\"a4c3b121\":{\"m\":102,\"g\":103},\"5973675b\":{\"m\":102,\"g\":103},\"4d16c88b\":{\"m\":102,\"g\":103},\"7a4309cc\":{\"m\":102,\"g\":103},\"81367066\":{\"m\":102,\"g\":103},\"263c9236\":{\"m\":102,\"g\":103},\"33f0de33\":{\"m\":105,\"g\":107},\"e7e5a305\":{\"m\":105,\"g\":107},\"dd7ca006\":{\"m\":105,\"g\":107},\"9305ea6c\":{\"m\":105,\"g\":107},\"aa4c66b5\":{\"m\":105,\"g\":107},\"39decec1\":{\"m\":105,\"g\":104},\"f6f46f46\":{\"m\":105,\"g\":104},\"2886e23d\":{\"m\":105,\"g\":104},\"99795d61\":{\"m\":105,\"g\":104},\"fe5086fd\":{\"m\":105,\"g\":104},\"04913430\":{\"m\":105,\"g\":104},\"0ad098b4\":{\"m\":105,\"g\":104},\"4a6e7a66\":{\"m\":105,\"g\":104},\"4b04998d\":{\"m\":105,\"g\":104},\"3dde8619\":{\"m\":105,\"g\":104},\"b7170cc8\":{\"m\":105,\"g\":104},\"5c14515f\":{\"m\":105,\"g\":104},\"2cd2e27f\":{\"m\":105,\"g\":104},\"743638bc\":{\"m\":105,\"g\":104},\"061c8959\":{\"m\":105,\"g\":104},\"4acf6902\":{\"m\":105,\"g\":104},\"aee0ef52\":{\"m\":105,\"g\":103},\"ae807774\":{\"m\":105,\"g\":103},\"8fbcfd07\":{\"m\":105,\"g\":103},\"3c307dc0\":{\"m\":105,\"g\":103},\"5d15fb8c\":{\"m\":105,\"g\":103},\"016fd251\":{\"m\":105,\"g\":103},\"8cd34458\":{\"m\":106,\"g\":107},\"0e0eef00\":{\"m\":106,\"g\":107},\"cb099d20\":{\"m\":106,\"g\":107},\"7a91330
1\":{\"m\":106,\"g\":107},\"5ce5093b\":{\"m\":106,\"g\":107},\"6f9baf10\":{\"m\":106,\"g\":107},\"a31b7a70\":{\"m\":106,\"g\":107},\"7ed8e51b\":{\"m\":106,\"g\":107},\"32f28154\":{\"m\":106,\"g\":107},\"f7b2853f\":{\"m\":106,\"g\":107},\"b0add2da\":{\"m\":106,\"g\":107},\"0305c505\":{\"m\":106,\"g\":107},\"8675bdf2\":{\"m\":106,\"g\":107},\"a437aa99\":{\"m\":106,\"g\":107},\"0e612dbf\":{\"m\":106,\"g\":107},\"9f47d686\":{\"m\":106,\"g\":107},\"d9def43d\":{\"m\":106,\"g\":107},\"e273aa6d\":{\"m\":106,\"g\":107},\"828a4fe9\":{\"m\":106,\"g\":107},\"8ada1ab6\":{\"m\":106,\"g\":107},\"e314b084\":{\"m\":106,\"g\":107},\"403566bc\":{\"m\":106,\"g\":107},\"0a56b721\":{\"m\":106,\"g\":107},\"603f5ce0\":{\"m\":106,\"g\":107},\"6d4fd882\":{\"m\":106,\"g\":107},\"f9f0138f\":{\"m\":106,\"g\":107},\"ac6962cc\":{\"m\":106,\"g\":107},\"4ca43b06\":{\"m\":106,\"g\":107},\"ea93079b\":{\"m\":106,\"g\":107},\"4bec99ec\":{\"m\":106,\"g\":107},\"89caf7a3\":{\"m\":106,\"g\":107},\"b27b1191\":{\"m\":106,\"g\":107},\"f642524f\":{\"m\":106,\"g\":107},\"82e6c3a6\":{\"m\":106,\"g\":107},\"b89d37cb\":{\"m\":106,\"g\":107},\"5deab128\":{\"m\":106,\"g\":107},\"d1c4d51c\":{\"m\":106,\"g\":107},\"1fe691a4\":{\"m\":106,\"g\":107},\"e2521926\":{\"m\":106,\"g\":107},\"07e46eca\":{\"m\":106,\"g\":107},\"ab9b893e\":{\"m\":106,\"g\":107},\"6a7528e6\":{\"m\":106,\"g\":107},\"2ae95d17\":{\"m\":106,\"g\":107},\"2d401bd9\":{\"m\":106,\"g\":107},\"b17c5b01\":{\"m\":106,\"g\":107},\"db7343c9\":{\"m\":106,\"g\":107},\"533cb5b2\":{\"m\":106,\"g\":107},\"6bdd2786\":{\"m\":106,\"g\":107},\"46e9d1c7\":{\"m\":106,\"g\":107},\"6c88f6c8\":{\"m\":106,\"g\":107},\"c8d3a402\":{\"m\":106,\"g\":107},\"7e831efe\":{\"m\":106,\"g\":107},\"20b5563e\":{\"m\":106,\"g\":107},\"97a38ee8\":{\"m\":108,\"g\":115},\"86d10d22\":{\"m\":108,\"g\":115},\"83871aa1\":{\"m\":108,\"g\":115},\"b1b3f0b3\":{\"m\":108,\"g\":115},\"34e5e11f\":{\"m\":108,\"g\":115},\"2600fc0d\":{\"m\":108,\"g\":115},\"ccd3fb94\":{\"m\":108,\"g\":115},\"c9dd70fb\":{
\"m\":108,\"g\":115},\"6b2b8bf0\":{\"m\":108,\"g\":115},\"4edbe0d5\":{\"m\":108,\"g\":115},\"0374304a\":{\"m\":108,\"g\":115},\"127d4b0d\":{\"m\":108,\"g\":115},\"7e880286\":{\"m\":108,\"g\":115},\"446c8e4c\":{\"m\":108,\"g\":115},\"5ef545e6\":{\"m\":108,\"g\":115},\"c4500233\":{\"m\":108,\"g\":115},\"f445a1d9\":{\"m\":108,\"g\":115},\"e5638573\":{\"m\":108,\"g\":115},\"f556ac8b\":{\"m\":108,\"g\":115},\"110a6598\":{\"m\":108,\"g\":115},\"49f9d025\":{\"m\":108,\"g\":115},\"0f587e80\":{\"m\":108,\"g\":115},\"6078d5fc\":{\"m\":108,\"g\":115},\"70cf4abc\":{\"m\":108,\"g\":115},\"cebf4599\":{\"m\":108,\"g\":115},\"9c0c1e30\":{\"m\":108,\"g\":115},\"a1f011d0\":{\"m\":108,\"g\":115},\"9ec314c6\":{\"m\":108,\"g\":115},\"fedfe91c\":{\"m\":108,\"g\":115},\"988accbc\":{\"m\":108,\"g\":115},\"b6b2287e\":{\"m\":108,\"g\":115},\"243e745d\":{\"m\":108,\"g\":115},\"61a0e600\":{\"m\":108,\"g\":115},\"0f8cee8c\":{\"m\":108,\"g\":115},\"816c4c85\":{\"m\":108,\"g\":115},\"13ec8d42\":{\"m\":108,\"g\":115},\"05bd7897\":{\"m\":108,\"g\":115},\"5fd311d3\":{\"m\":108,\"g\":115},\"53e2cd46\":{\"m\":108,\"g\":115},\"9708d353\":{\"m\":108,\"g\":115},\"704ced1b\":{\"m\":108,\"g\":115},\"3cc3d9b9\":{\"m\":108,\"g\":115},\"0b3a5b11\":{\"m\":108,\"g\":115},\"6c855db8\":{\"m\":108,\"g\":115},\"0f9318f7\":{\"m\":108,\"g\":115},\"849957bc\":{\"m\":108,\"g\":115},\"cded039b\":{\"m\":108,\"g\":115},\"275f9df3\":{\"m\":108,\"g\":115},\"e8449ab5\":{\"m\":108,\"g\":115},\"4746aaea\":{\"m\":108,\"g\":115},\"10d34f74\":{\"m\":108,\"g\":115},\"9ba72530\":{\"m\":108,\"g\":115},\"9c8e4f69\":{\"m\":108,\"g\":115},\"78ae1758\":{\"m\":108,\"g\":115},\"dae9a80f\":{\"m\":108,\"g\":115},\"e85cb1ce\":{\"m\":108,\"g\":115},\"55d336cb\":{\"m\":108,\"g\":115},\"de4990a5\":{\"m\":108,\"g\":115},\"029e0af3\":{\"m\":108,\"g\":115},\"64574ef8\":{\"m\":108,\"g\":115},\"18da2c96\":{\"m\":108,\"g\":115},\"9b5f0f64\":{\"m\":108,\"g\":115},\"70bb066e\":{\"m\":108,\"g\":115},\"2c4b4b78\":{\"m\":108,\"g\":115},\"7cd2ee06\":{\"m\"
:108,\"g\":115},\"eb19ccad\":{\"m\":108,\"g\":115},\"25ef53f0\":{\"m\":108,\"g\":115},\"c674bf9c\":{\"m\":108,\"g\":115},\"af1973b8\":{\"m\":108,\"g\":115},\"5cfbb4c1\":{\"m\":108,\"g\":115},\"e6523102\":{\"m\":108,\"g\":115},\"3828db43\":{\"m\":108,\"g\":115},\"88fbc31b\":{\"m\":108,\"g\":115},\"8f5b9910\":{\"m\":108,\"g\":115},\"ef3004d9\":{\"m\":108,\"g\":115},\"84719b52\":{\"m\":108,\"g\":115},\"e99729c9\":{\"m\":108,\"g\":115},\"c10b8e6a\":{\"m\":108,\"g\":115},\"d4bce297\":{\"m\":108,\"g\":115},\"b0980af8\":{\"m\":108,\"g\":115},\"24eaebeb\":{\"m\":108,\"g\":115},\"a91e90d9\":{\"m\":108,\"g\":115},\"f96413c4\":{\"m\":108,\"g\":115},\"08ebdf79\":{\"m\":108,\"g\":115},\"42c87045\":{\"m\":108,\"g\":115},\"c9bf3877\":{\"m\":108,\"g\":115},\"de2dd738\":{\"m\":108,\"g\":115},\"1ec97697\":{\"m\":108,\"g\":115},\"d8ed60f2\":{\"m\":108,\"g\":115},\"f1b0eda5\":{\"m\":108,\"g\":115},\"f20b6a3f\":{\"m\":108,\"g\":115},\"3680d6f8\":{\"m\":108,\"g\":115},\"f5154495\":{\"m\":108,\"g\":115},\"e0ce171d\":{\"m\":108,\"g\":115},\"fe43e889\":{\"m\":108,\"g\":115},\"5ae5ecaa\":{\"m\":108,\"g\":115},\"5fbad308\":{\"m\":108,\"g\":115},\"7638f5e4\":{\"m\":108,\"g\":115},\"b45f753c\":{\"m\":108,\"g\":115},\"c5057262\":{\"m\":108,\"g\":115},\"46fe8b8c\":{\"m\":108,\"g\":115},\"0b95a01a\":{\"m\":108,\"g\":115},\"a3b810eb\":{\"m\":108,\"g\":115},\"94959237\":{\"m\":108,\"g\":115},\"f4fafacc\":{\"m\":108,\"g\":115},\"01d47a27\":{\"m\":108,\"g\":115},\"ecc9f3e4\":{\"m\":108,\"g\":115},\"7e8187e0\":{\"m\":108,\"g\":115},\"e483ab6d\":{\"m\":108,\"g\":115},\"720cd308\":{\"m\":108,\"g\":115},\"ce67b2d5\":{\"m\":108,\"g\":115},\"3c2c9f6c\":{\"m\":108,\"g\":115},\"a31ea448\":{\"m\":108,\"g\":115},\"439df454\":{\"m\":108,\"g\":115},\"5626e20b\":{\"m\":108,\"g\":115},\"c6c379ab\":{\"m\":108,\"g\":115},\"c2fbf60f\":{\"m\":108,\"g\":115},\"98b44e9e\":{\"m\":108,\"g\":115},\"6805f6da\":{\"m\":108,\"g\":115},\"ca533580\":{\"m\":108,\"g\":115},\"886454e8\":{\"m\":108,\"g\":115},\"0cf3fbeb\":{\"m\":108,
\"g\":115},\"2256d62d\":{\"m\":108,\"g\":115},\"6cdcbcc6\":{\"m\":108,\"g\":115},\"c480a3f6\":{\"m\":108,\"g\":115},\"6e316588\":{\"m\":108,\"g\":115},\"24247b41\":{\"m\":108,\"g\":115},\"4c0bb411\":{\"m\":108,\"g\":115},\"968e1818\":{\"m\":108,\"g\":115},\"d08663ee\":{\"m\":108,\"g\":115},\"716e6827\":{\"m\":108,\"g\":115},\"84b30d9e\":{\"m\":108,\"g\":115},\"ff0cf51c\":{\"m\":108,\"g\":115},\"a1c7f742\":{\"m\":108,\"g\":115},\"ebbb75e9\":{\"m\":108,\"g\":115},\"b341b7db\":{\"m\":108,\"g\":115},\"b498cd21\":{\"m\":108,\"g\":115},\"0fc54b97\":{\"m\":108,\"g\":115},\"b3c1f2e4\":{\"m\":108,\"g\":115},\"be1a3cd9\":{\"m\":108,\"g\":115},\"4b74c3fc\":{\"m\":108,\"g\":115},\"ce3ca9b0\":{\"m\":108,\"g\":115},\"4d98e486\":{\"m\":108,\"g\":115},\"3d77a318\":{\"m\":108,\"g\":115},\"845d12a9\":{\"m\":108,\"g\":115},\"e47800e1\":{\"m\":108,\"g\":115},\"bb10e3a1\":{\"m\":108,\"g\":115},\"fda762a2\":{\"m\":108,\"g\":115},\"1df84ff4\":{\"m\":108,\"g\":115},\"66d6be08\":{\"m\":108,\"g\":115},\"1c1f8a11\":{\"m\":108,\"g\":115},\"384f8ab5\":{\"m\":108,\"g\":115},\"6a9d6ca3\":{\"m\":108,\"g\":115},\"94371dbb\":{\"m\":108,\"g\":115},\"740f0630\":{\"m\":108,\"g\":115},\"81da16f6\":{\"m\":108,\"g\":115},\"bc938ea1\":{\"m\":108,\"g\":115},\"eff4eb3f\":{\"m\":108,\"g\":115},\"87dab548\":{\"m\":108,\"g\":115},\"5121af46\":{\"m\":108,\"g\":115},\"983aa496\":{\"m\":108,\"g\":115},\"9c3e95d9\":{\"m\":108,\"g\":115},\"e52c3866\":{\"m\":108,\"g\":115},\"da53e13c\":{\"m\":108,\"g\":115},\"d7e38b2f\":{\"m\":108,\"g\":115},\"21b88460\":{\"m\":108,\"g\":115},\"0c8594e6\":{\"m\":108,\"g\":115},\"c186feed\":{\"m\":108,\"g\":115},\"84b006b2\":{\"m\":108,\"g\":115},\"8ca07bd9\":{\"m\":108,\"g\":115},\"4fc09e0d\":{\"m\":108,\"g\":115},\"a3d99d6d\":{\"m\":108,\"g\":115},\"189af908\":{\"m\":108,\"g\":115},\"f8644a56\":{\"m\":108,\"g\":115},\"e3e75a78\":{\"m\":108,\"g\":115},\"d4db9b02\":{\"m\":108,\"g\":115},\"f7dd651d\":{\"m\":108,\"g\":115},\"9d54c6e6\":{\"m\":108,\"g\":115},\"1f9d65f5\":{\"m\":108,\"g\"
:115},\"29589512\":{\"m\":108,\"g\":115},\"584e1ab2\":{\"m\":108,\"g\":115},\"392de007\":{\"m\":108,\"g\":115},\"004f7f19\":{\"m\":108,\"g\":115},\"d2fbf2de\":{\"m\":108,\"g\":115},\"fab0f6e7\":{\"m\":108,\"g\":115},\"27985c27\":{\"m\":108,\"g\":115},\"ac474869\":{\"m\":108,\"g\":115},\"0b1e04f0\":{\"m\":108,\"g\":115},\"c1c7dc45\":{\"m\":108,\"g\":115},\"2cc9eeab\":{\"m\":108,\"g\":115},\"63d82a77\":{\"m\":108,\"g\":115},\"53dcc750\":{\"m\":108,\"g\":115},\"432f2053\":{\"m\":108,\"g\":115},\"1fea998a\":{\"m\":108,\"g\":115},\"5aa1ebd2\":{\"m\":108,\"g\":115},\"4dbf4360\":{\"m\":108,\"g\":115},\"3d6be1fb\":{\"m\":108,\"g\":115},\"4063234c\":{\"m\":108,\"g\":115},\"83feef5b\":{\"m\":108,\"g\":115},\"2871eacc\":{\"m\":108,\"g\":115},\"ac15bdc1\":{\"m\":108,\"g\":115},\"d6451c3f\":{\"m\":108,\"g\":115},\"841810f2\":{\"m\":108,\"g\":115},\"733446dd\":{\"m\":108,\"g\":115},\"4c22897a\":{\"m\":108,\"g\":115},\"1bc183c6\":{\"m\":108,\"g\":115},\"b87aacb5\":{\"m\":108,\"g\":115},\"b3363cc1\":{\"m\":108,\"g\":115},\"98457c04\":{\"m\":108,\"g\":115},\"0fc8bf2c\":{\"m\":108,\"g\":115},\"a669bc2f\":{\"m\":108,\"g\":115},\"6b7c2471\":{\"m\":108,\"g\":115},\"a027a9b4\":{\"m\":108,\"g\":115},\"9e426466\":{\"m\":108,\"g\":115},\"2f20f430\":{\"m\":108,\"g\":115},\"65736dc5\":{\"m\":108,\"g\":115},\"7b56e494\":{\"m\":108,\"g\":115},\"0ff6d1fc\":{\"m\":108,\"g\":115},\"4a16a71c\":{\"m\":108,\"g\":115},\"a16923ef\":{\"m\":108,\"g\":115},\"6337d905\":{\"m\":108,\"g\":115},\"71fb8c95\":{\"m\":108,\"g\":115},\"94f44b88\":{\"m\":108,\"g\":115},\"3b3b3baf\":{\"m\":108,\"g\":115},\"35e6bc92\":{\"m\":108,\"g\":115},\"9394ed63\":{\"m\":108,\"g\":115},\"930fe467\":{\"m\":108,\"g\":115},\"13c48dcf\":{\"m\":108,\"g\":115},\"8723b4f1\":{\"m\":108,\"g\":115},\"62f99e08\":{\"m\":108,\"g\":115},\"86a0be65\":{\"m\":108,\"g\":115},\"0edda320\":{\"m\":108,\"g\":115},\"924827c3\":{\"m\":108,\"g\":115},\"c81daf83\":{\"m\":108,\"g\":115},\"25caa7a8\":{\"m\":108,\"g\":115},\"03d11449\":{\"m\":108,\"g\":115}
,\"83123f48\":{\"m\":108,\"g\":115},\"48afa8f1\":{\"m\":108,\"g\":115},\"2ecbd8b8\":{\"m\":108,\"g\":115},\"305b27c1\":{\"m\":108,\"g\":115},\"1ce30dd1\":{\"m\":108,\"g\":115},\"c9ee7385\":{\"m\":108,\"g\":115},\"1f9ec653\":{\"m\":108,\"g\":115},\"ad359d1c\":{\"m\":108,\"g\":115},\"5f5b3b24\":{\"m\":108,\"g\":115},\"4caca4f6\":{\"m\":108,\"g\":115},\"f2a5de28\":{\"m\":108,\"g\":115},\"445f9dca\":{\"m\":108,\"g\":115},\"3a9afe2a\":{\"m\":108,\"g\":115},\"9aea2555\":{\"m\":108,\"g\":115},\"fcc11e5e\":{\"m\":108,\"g\":115},\"5190ba7f\":{\"m\":108,\"g\":115},\"5438886c\":{\"m\":108,\"g\":115},\"9c83d74d\":{\"m\":108,\"g\":115},\"b4ac2b9c\":{\"m\":108,\"g\":115},\"83262dcb\":{\"m\":108,\"g\":115},\"c46c75f8\":{\"m\":108,\"g\":115},\"2aaf22c4\":{\"m\":108,\"g\":115},\"29a610b4\":{\"m\":108,\"g\":115},\"5ded39ca\":{\"m\":108,\"g\":115},\"4093d460\":{\"m\":108,\"g\":115},\"9d68bdb2\":{\"m\":108,\"g\":115},\"a2184901\":{\"m\":108,\"g\":115},\"0eec4cb6\":{\"m\":108,\"g\":115},\"ff1f6825\":{\"m\":108,\"g\":115},\"9f78f391\":{\"m\":108,\"g\":115},\"f508cd3c\":{\"m\":108,\"g\":115},\"44e86480\":{\"m\":108,\"g\":115},\"8c07fabd\":{\"m\":108,\"g\":115},\"90f44b74\":{\"m\":108,\"g\":115},\"38907fe6\":{\"m\":108,\"g\":115},\"f9afa7dc\":{\"m\":108,\"g\":115},\"0d9e89ec\":{\"m\":108,\"g\":115},\"3d64fda3\":{\"m\":108,\"g\":115},\"3bffe112\":{\"m\":108,\"g\":115},\"44426e54\":{\"m\":108,\"g\":115},\"9f24dfef\":{\"m\":108,\"g\":115},\"89f1d4f5\":{\"m\":108,\"g\":115},\"75e6a7cd\":{\"m\":108,\"g\":115},\"6f81a710\":{\"m\":108,\"g\":115},\"a6452b71\":{\"m\":108,\"g\":115},\"f4ae50e9\":{\"m\":108,\"g\":115},\"84cb449e\":{\"m\":108,\"g\":115},\"f003cd35\":{\"m\":108,\"g\":115},\"9d834fdc\":{\"m\":108,\"g\":115},\"b3279251\":{\"m\":108,\"g\":115},\"067068f2\":{\"m\":108,\"g\":115},\"6beeff41\":{\"m\":108,\"g\":115},\"2e8e7e35\":{\"m\":108,\"g\":115},\"2449a0af\":{\"m\":108,\"g\":115},\"0f229c07\":{\"m\":108,\"g\":115},\"dd001a54\":{\"m\":108,\"g\":115},\"4ea9d74a\":{\"m\":108,\"g\":115},\"dd
949ace\":{\"m\":108,\"g\":115},\"f2887498\":{\"m\":108,\"g\":115},\"8ecf6b9d\":{\"m\":108,\"g\":115},\"0418b9d4\":{\"m\":108,\"g\":115},\"e322a94d\":{\"m\":108,\"g\":115},\"2c7f01bc\":{\"m\":108,\"g\":115},\"b58ae7a2\":{\"m\":108,\"g\":115},\"6345069f\":{\"m\":108,\"g\":115},\"ce9cf353\":{\"m\":108,\"g\":115},\"f8a173bb\":{\"m\":108,\"g\":115},\"6b847a9a\":{\"m\":108,\"g\":115},\"473400e4\":{\"m\":108,\"g\":115},\"dd665f96\":{\"m\":108,\"g\":115},\"3817a37d\":{\"m\":108,\"g\":115},\"7ba5ad57\":{\"m\":108,\"g\":115},\"19bc77f0\":{\"m\":108,\"g\":115},\"86497d99\":{\"m\":108,\"g\":115},\"5c31b35d\":{\"m\":108,\"g\":115},\"ef48d554\":{\"m\":108,\"g\":115},\"a886564a\":{\"m\":108,\"g\":115},\"9a44b643\":{\"m\":108,\"g\":115},\"41d71ca4\":{\"m\":108,\"g\":115},\"20cfc5a2\":{\"m\":108,\"g\":115},\"48b8b4c1\":{\"m\":108,\"g\":115},\"323bc2f5\":{\"m\":108,\"g\":115},\"137e75da\":{\"m\":108,\"g\":115},\"52e1f52f\":{\"m\":108,\"g\":115},\"50188092\":{\"m\":108,\"g\":115},\"326a901d\":{\"m\":108,\"g\":115},\"6e0b6468\":{\"m\":108,\"g\":115},\"4a9f3eef\":{\"m\":108,\"g\":115},\"1b7afad0\":{\"m\":108,\"g\":115},\"f29aba8c\":{\"m\":108,\"g\":115},\"faa25df1\":{\"m\":108,\"g\":115},\"7b81f956\":{\"m\":108,\"g\":115},\"d3e67deb\":{\"m\":108,\"g\":115},\"442534aa\":{\"m\":108,\"g\":115},\"de8b8b6e\":{\"m\":108,\"g\":115},\"3f2e315f\":{\"m\":108,\"g\":115},\"6e215118\":{\"m\":108,\"g\":115},\"a47baff1\":{\"m\":108,\"g\":115},\"fd7e15b7\":{\"m\":108,\"g\":115},\"fc42ff7b\":{\"m\":108,\"g\":115},\"7c0db868\":{\"m\":108,\"g\":115},\"706bd69c\":{\"m\":108,\"g\":115},\"23f2afb2\":{\"m\":108,\"g\":115},\"a60f88b5\":{\"m\":108,\"g\":115},\"591c232f\":{\"m\":108,\"g\":115},\"f352b793\":{\"m\":108,\"g\":115},\"6642e3a2\":{\"m\":108,\"g\":115},\"67a7d1f6\":{\"m\":108,\"g\":115},\"92cbef59\":{\"m\":108,\"g\":115},\"b3359dc9\":{\"m\":108,\"g\":115},\"7b7e5615\":{\"m\":108,\"g\":115},\"1a8706c8\":{\"m\":108,\"g\":115},\"7d3af603\":{\"m\":108,\"g\":115},\"4e7f0252\":{\"m\":108,\"g\":115},\"36bfdde
c\":{\"m\":108,\"g\":115},\"91e2f902\":{\"m\":108,\"g\":115},\"a59cbea9\":{\"m\":108,\"g\":115},\"53f7874a\":{\"m\":108,\"g\":115},\"61a46804\":{\"m\":108,\"g\":115},\"9020f7fc\":{\"m\":108,\"g\":115},\"dd650e0e\":{\"m\":108,\"g\":115},\"a9471542\":{\"m\":108,\"g\":115},\"41357e51\":{\"m\":108,\"g\":115},\"e2fd2b9c\":{\"m\":108,\"g\":115},\"7490e3f6\":{\"m\":108,\"g\":115},\"6ee6619b\":{\"m\":108,\"g\":115},\"54ea57f2\":{\"m\":108,\"g\":115},\"b4c9f38a\":{\"m\":108,\"g\":115},\"11325474\":{\"m\":108,\"g\":115},\"1d24db83\":{\"m\":108,\"g\":115},\"44401358\":{\"m\":108,\"g\":115},\"9c7e3924\":{\"m\":108,\"g\":115},\"08fab2b0\":{\"m\":108,\"g\":115},\"0d1e27a0\":{\"m\":108,\"g\":115},\"774b47f3\":{\"m\":108,\"g\":115},\"76915d68\":{\"m\":108,\"g\":115},\"39fd1788\":{\"m\":108,\"g\":115},\"ed0a3dd5\":{\"m\":108,\"g\":115},\"2e901e89\":{\"m\":108,\"g\":115},\"d3be9710\":{\"m\":108,\"g\":115},\"3e7ff1ab\":{\"m\":108,\"g\":115},\"aaf0ad8c\":{\"m\":108,\"g\":115},\"361379b5\":{\"m\":108,\"g\":115},\"1ac16add\":{\"m\":108,\"g\":115},\"c3a5fb3b\":{\"m\":108,\"g\":115},\"4bf6e5a6\":{\"m\":108,\"g\":115},\"3ae33fcd\":{\"m\":108,\"g\":115},\"500b15c9\":{\"m\":108,\"g\":107},\"16a4c66d\":{\"m\":108,\"g\":107},\"89e6521c\":{\"m\":108,\"g\":107},\"fd05b567\":{\"m\":108,\"g\":107},\"482c3db2\":{\"m\":108,\"g\":107},\"47824c14\":{\"m\":108,\"g\":107},\"c36a6693\":{\"m\":108,\"g\":107},\"62f8eb48\":{\"m\":108,\"g\":107},\"b7cd7430\":{\"m\":108,\"g\":107},\"a69b6370\":{\"m\":108,\"g\":107},\"2d120f8b\":{\"m\":108,\"g\":107},\"4f2e1490\":{\"m\":108,\"g\":107},\"3fa3c6cd\":{\"m\":108,\"g\":107},\"6210e2c4\":{\"m\":108,\"g\":107},\"6ad6c8c9\":{\"m\":108,\"g\":107},\"5b6acc14\":{\"m\":108,\"g\":107},\"4373df55\":{\"m\":108,\"g\":107},\"c0e84297\":{\"m\":108,\"g\":107},\"92cc32d9\":{\"m\":108,\"g\":107},\"cbbd685a\":{\"m\":108,\"g\":107},\"78aad910\":{\"m\":108,\"g\":107},\"288ae41f\":{\"m\":108,\"g\":107},\"01c99a99\":{\"m\":108,\"g\":107},\"b114a810\":{\"m\":108,\"g\":107},\"0475448e\":{
\"m\":108,\"g\":107},\"399e7ec8\":{\"m\":108,\"g\":107},\"1bd53168\":{\"m\":108,\"g\":107},\"aeac900c\":{\"m\":108,\"g\":107},\"4fc5f2f9\":{\"m\":108,\"g\":107},\"168033d5\":{\"m\":108,\"g\":107},\"cbbb7383\":{\"m\":108,\"g\":107},\"89588179\":{\"m\":108,\"g\":107},\"8c7bb39d\":{\"m\":108,\"g\":107},\"ca47e24f\":{\"m\":108,\"g\":107},\"d26ca84f\":{\"m\":108,\"g\":107},\"8128e08d\":{\"m\":108,\"g\":107},\"5d62b56f\":{\"m\":108,\"g\":107},\"3ae8e3ea\":{\"m\":108,\"g\":107},\"c1d2061f\":{\"m\":108,\"g\":107},\"556e4143\":{\"m\":108,\"g\":107},\"4ef47839\":{\"m\":108,\"g\":107},\"32d9e39a\":{\"m\":108,\"g\":107},\"4f4e0e41\":{\"m\":108,\"g\":107},\"901ab758\":{\"m\":108,\"g\":107},\"8e8545ca\":{\"m\":108,\"g\":107},\"a4b0d5c9\":{\"m\":108,\"g\":107},\"40e3b2be\":{\"m\":108,\"g\":107},\"75df31b6\":{\"m\":108,\"g\":107},\"194561f2\":{\"m\":108,\"g\":107},\"5e91fed1\":{\"m\":108,\"g\":107},\"873f384a\":{\"m\":108,\"g\":107},\"b01eeb80\":{\"m\":108,\"g\":107},\"1ea94d3b\":{\"m\":108,\"g\":107},\"354ac435\":{\"m\":108,\"g\":107},\"d98a4913\":{\"m\":108,\"g\":107},\"08f8f490\":{\"m\":108,\"g\":107},\"d4bf5a85\":{\"m\":108,\"g\":107},\"7cb20754\":{\"m\":108,\"g\":107},\"6d0646da\":{\"m\":108,\"g\":107},\"02bc1c7d\":{\"m\":108,\"g\":107},\"fc8c8e50\":{\"m\":108,\"g\":107},\"9bd4872a\":{\"m\":108,\"g\":107},\"2fa0462c\":{\"m\":108,\"g\":107},\"915140fd\":{\"m\":108,\"g\":107},\"36fc9260\":{\"m\":108,\"g\":107},\"fee0ab0f\":{\"m\":108,\"g\":107},\"f57d2dc1\":{\"m\":108,\"g\":107},\"f2d68ded\":{\"m\":108,\"g\":107},\"3b87a9e8\":{\"m\":108,\"g\":107},\"f024795e\":{\"m\":108,\"g\":107},\"b102353f\":{\"m\":108,\"g\":107},\"7a27e798\":{\"m\":108,\"g\":107},\"76ba5bbe\":{\"m\":108,\"g\":107},\"ed6f7597\":{\"m\":108,\"g\":107},\"e67276ec\":{\"m\":108,\"g\":107},\"0242bb9c\":{\"m\":108,\"g\":107},\"760286e3\":{\"m\":108,\"g\":107},\"3435a24e\":{\"m\":108,\"g\":107},\"00da9065\":{\"m\":108,\"g\":107},\"e0ab167d\":{\"m\":109,\"g\":115},\"c807cd7c\":{\"m\":109,\"g\":115},\"327f7b7c\":{\"m\"
:109,\"g\":115},\"80425e59\":{\"m\":109,\"g\":115},\"af9d4eb0\":{\"m\":109,\"g\":115},\"fb107cfd\":{\"m\":109,\"g\":115},\"e3e97a12\":{\"m\":110,\"g\":115},\"05106867\":{\"m\":110,\"g\":115},\"9dcdf5da\":{\"m\":110,\"g\":115},\"f8b757bc\":{\"m\":110,\"g\":115},\"ebd9dbe7\":{\"m\":110,\"g\":115},\"938e986e\":{\"m\":110,\"g\":115},\"17d5eda8\":{\"m\":110,\"g\":115},\"71a7f1d8\":{\"m\":110,\"g\":115},\"433266c1\":{\"m\":110,\"g\":115},\"fda47926\":{\"m\":110,\"g\":115},\"a0b22f2f\":{\"m\":110,\"g\":115},\"b5c6529e\":{\"m\":110,\"g\":115},\"ca4b86c5\":{\"m\":110,\"g\":115},\"dd6ec029\":{\"m\":110,\"g\":115},\"bf863e3b\":{\"m\":110,\"g\":115},\"9e169ea8\":{\"m\":110,\"g\":115},\"bc80dc4c\":{\"m\":111,\"g\":115},\"b962a296\":{\"m\":111,\"g\":115},\"aa3eba8e\":{\"m\":111,\"g\":115},\"07ee0ab7\":{\"m\":111,\"g\":115},\"5c06dcb7\":{\"m\":111,\"g\":115},\"6f6beca4\":{\"m\":111,\"g\":115},\"68a54e06\":{\"m\":111,\"g\":115},\"fd18995c\":{\"m\":111,\"g\":115},\"db0831e0\":{\"m\":111,\"g\":115},\"6e4e1c8c\":{\"m\":111,\"g\":115},\"9768c50d\":{\"m\":111,\"g\":115},\"fd71b11b\":{\"m\":111,\"g\":115},\"ae7428a8\":{\"m\":111,\"g\":115},\"a3aee7c3\":{\"m\":111,\"g\":115},\"79e6a8a6\":{\"m\":111,\"g\":115},\"8f7b1c31\":{\"m\":111,\"g\":115},\"b9683be6\":{\"m\":111,\"g\":115},\"a85363c1\":{\"m\":111,\"g\":115},\"b21fdd53\":{\"m\":111,\"g\":115},\"c04c17ed\":{\"m\":111,\"g\":115},\"16a6d21b\":{\"m\":111,\"g\":115},\"a530b3ff\":{\"m\":111,\"g\":115},\"603b3446\":{\"m\":111,\"g\":115},\"b6c14ec0\":{\"m\":111,\"g\":115},\"43de1d73\":{\"m\":111,\"g\":115},\"79ce3688\":{\"m\":111,\"g\":115},\"44ffe2cb\":{\"m\":111,\"g\":115},\"1a0896e9\":{\"m\":111,\"g\":115},\"90313fb0\":{\"m\":111,\"g\":115},\"3578eb1e\":{\"m\":111,\"g\":115},\"0936c766\":{\"m\":111,\"g\":115},\"0ef583b7\":{\"m\":111,\"g\":115},\"f7881a27\":{\"m\":111,\"g\":115},\"fdff3167\":{\"m\":111,\"g\":115},\"cbc0e4d7\":{\"m\":111,\"g\":115},\"4cd08dc5\":{\"m\":111,\"g\":115},\"f92b729d\":{\"m\":111,\"g\":115},\"e2e378ca\":{\"m\":111,
\"g\":115},\"dc1decc6\":{\"m\":111,\"g\":115},\"03680f33\":{\"m\":111,\"g\":115},\"d4c5e534\":{\"m\":111,\"g\":115},\"817c62a0\":{\"m\":111,\"g\":115},\"0ff72419\":{\"m\":111,\"g\":115},\"80dc76e1\":{\"m\":111,\"g\":115},\"9b08d975\":{\"m\":111,\"g\":115},\"a0a77d93\":{\"m\":111,\"g\":115},\"24a8cee6\":{\"m\":111,\"g\":115},\"3affa9dc\":{\"m\":111,\"g\":115},\"ea0696b9\":{\"m\":111,\"g\":115},\"3aec3d4f\":{\"m\":111,\"g\":115},\"b0d25e72\":{\"m\":112,\"g\":115},\"a2424068\":{\"m\":112,\"g\":115},\"c5d2b01c\":{\"m\":112,\"g\":115},\"46ccbed2\":{\"m\":112,\"g\":115},\"fe68c148\":{\"m\":112,\"g\":115},\"70c0c1f9\":{\"m\":112,\"g\":115},\"760b788a\":{\"m\":112,\"g\":115},\"1ee11df8\":{\"m\":112,\"g\":115},\"dee197e1\":{\"m\":112,\"g\":115},\"ab795ae8\":{\"m\":112,\"g\":115},\"480d1b8b\":{\"m\":112,\"g\":115},\"6c18ab46\":{\"m\":112,\"g\":115},\"4a0e0be2\":{\"m\":112,\"g\":115},\"64f296f8\":{\"m\":112,\"g\":115},\"956d805d\":{\"m\":112,\"g\":115},\"30c6e1f5\":{\"m\":112,\"g\":115},\"bfe01a5e\":{\"m\":112,\"g\":115},\"3dd6420a\":{\"m\":112,\"g\":115},\"532f998b\":{\"m\":112,\"g\":115},\"de15d140\":{\"m\":112,\"g\":115},\"37367da6\":{\"m\":112,\"g\":115},\"ef959d7b\":{\"m\":112,\"g\":115},\"4aa1e69b\":{\"m\":112,\"g\":115},\"dc491b39\":{\"m\":112,\"g\":115},\"5b64f006\":{\"m\":112,\"g\":115},\"5b7448de\":{\"m\":112,\"g\":115},\"6d55f60e\":{\"m\":112,\"g\":115},\"033b75f5\":{\"m\":112,\"g\":115},\"f3b5db6e\":{\"m\":112,\"g\":115},\"2286e85e\":{\"m\":112,\"g\":115},\"91b3555d\":{\"m\":112,\"g\":115},\"9e2f7252\":{\"m\":112,\"g\":115},\"21176b00\":{\"m\":112,\"g\":115},\"94100294\":{\"m\":112,\"g\":115},\"cda7e47c\":{\"m\":112,\"g\":115},\"e903f695\":{\"m\":112,\"g\":115},\"27760fc1\":{\"m\":112,\"g\":115},\"0ac809de\":{\"m\":112,\"g\":115},\"4efe2c57\":{\"m\":112,\"g\":115},\"5be8c2f7\":{\"m\":112,\"g\":115},\"737d73ed\":{\"m\":112,\"g\":115},\"ebd0e1c1\":{\"m\":112,\"g\":115},\"a1d03892\":{\"m\":112,\"g\":115},\"dccf52f9\":{\"m\":112,\"g\":115},\"676a7b51\":{\"m\":112,\"g\"
:115},\"15f99347\":{\"m\":112,\"g\":115},\"bcf1955f\":{\"m\":112,\"g\":115},\"a06bf664\":{\"m\":112,\"g\":115},\"bf72b801\":{\"m\":112,\"g\":115},\"8cbe1538\":{\"m\":112,\"g\":115},\"8471e5e6\":{\"m\":112,\"g\":115},\"4582931a\":{\"m\":112,\"g\":115},\"d352c29a\":{\"m\":112,\"g\":115},\"d3ee7098\":{\"m\":112,\"g\":115},\"71fc7b7f\":{\"m\":112,\"g\":115},\"9ab72f98\":{\"m\":112,\"g\":115},\"f3817cb0\":{\"m\":112,\"g\":115},\"71133a04\":{\"m\":112,\"g\":115},\"2cd94dd0\":{\"m\":112,\"g\":115},\"f5f6b3b4\":{\"m\":112,\"g\":115},\"94fb4e9e\":{\"m\":112,\"g\":115},\"d1d4074c\":{\"m\":112,\"g\":115},\"718f25ae\":{\"m\":112,\"g\":115},\"948b01a0\":{\"m\":112,\"g\":115},\"cdc56ef6\":{\"m\":112,\"g\":115},\"16ff3d4b\":{\"m\":112,\"g\":115},\"83d55ac5\":{\"m\":112,\"g\":115},\"2fe17735\":{\"m\":112,\"g\":115},\"97fff98c\":{\"m\":112,\"g\":115},\"ba066ca0\":{\"m\":112,\"g\":115},\"96784a65\":{\"m\":112,\"g\":115},\"df5407fb\":{\"m\":112,\"g\":115},\"8ad700f7\":{\"m\":112,\"g\":115},\"148022fc\":{\"m\":112,\"g\":115},\"7a40e4f4\":{\"m\":112,\"g\":115},\"19d64f2b\":{\"m\":112,\"g\":115},\"a02071a1\":{\"m\":112,\"g\":115},\"45b3a6a2\":{\"m\":112,\"g\":115},\"9a18aa54\":{\"m\":112,\"g\":115},\"91f0fd95\":{\"m\":112,\"g\":115},\"8085aca7\":{\"m\":112,\"g\":115},\"0096798e\":{\"m\":112,\"g\":115},\"2c2b19b1\":{\"m\":112,\"g\":115},\"72f9fc5f\":{\"m\":112,\"g\":115},\"ec99668a\":{\"m\":112,\"g\":115},\"78f13981\":{\"m\":112,\"g\":115},\"bfd7a18d\":{\"m\":112,\"g\":115},\"5dd8c644\":{\"m\":112,\"g\":115},\"ee21817c\":{\"m\":112,\"g\":115},\"b7d1f17b\":{\"m\":112,\"g\":115},\"c8295d23\":{\"m\":112,\"g\":115},\"b67c277f\":{\"m\":112,\"g\":115},\"8116804e\":{\"m\":112,\"g\":115},\"8c5930f0\":{\"m\":112,\"g\":115},\"3b99f23c\":{\"m\":112,\"g\":115},\"ee0b3c5b\":{\"m\":112,\"g\":115},\"6049ca20\":{\"m\":112,\"g\":115},\"7577f0e4\":{\"m\":112,\"g\":115},\"8cda5a62\":{\"m\":112,\"g\":115},\"400d3b97\":{\"m\":112,\"g\":115},\"37d83c6e\":{\"m\":112,\"g\":115},\"7802586c\":{\"m\":112,\"g\":115}
,\"bc5fc332\":{\"m\":112,\"g\":115},\"f3440adc\":{\"m\":112,\"g\":115},\"5a7e10fe\":{\"m\":112,\"g\":115},\"33467c05\":{\"m\":112,\"g\":115},\"b0fcbb74\":{\"m\":112,\"g\":115},\"76a2c86b\":{\"m\":112,\"g\":115},\"e719bb0e\":{\"m\":112,\"g\":115},\"06724683\":{\"m\":112,\"g\":115},\"617aa2b2\":{\"m\":112,\"g\":115},\"111b1379\":{\"m\":112,\"g\":115},\"41628dc1\":{\"m\":112,\"g\":115},\"a12061df\":{\"m\":112,\"g\":115},\"85ed8e0a\":{\"m\":112,\"g\":115},\"dd1e2689\":{\"m\":112,\"g\":115},\"9a7ced4e\":{\"m\":112,\"g\":115},\"cb3918a0\":{\"m\":112,\"g\":115},\"f3b67602\":{\"m\":112,\"g\":115},\"9eb50ecc\":{\"m\":112,\"g\":115},\"b3e7a2ce\":{\"m\":112,\"g\":115},\"00974e4f\":{\"m\":112,\"g\":115},\"5f1eb204\":{\"m\":112,\"g\":115},\"039cef76\":{\"m\":112,\"g\":115},\"4c22ebe2\":{\"m\":112,\"g\":115},\"a5a03209\":{\"m\":112,\"g\":115},\"21af5c04\":{\"m\":112,\"g\":115},\"012584ec\":{\"m\":112,\"g\":115},\"90dfe3de\":{\"m\":112,\"g\":115},\"9a719b7a\":{\"m\":112,\"g\":115},\"3fa62da7\":{\"m\":112,\"g\":115},\"dbb1235d\":{\"m\":112,\"g\":115},\"ad26f298\":{\"m\":112,\"g\":115},\"8d114f25\":{\"m\":112,\"g\":115},\"0e78c63c\":{\"m\":112,\"g\":115},\"1a3d6f31\":{\"m\":112,\"g\":115},\"0b8c5721\":{\"m\":112,\"g\":115},\"beac202b\":{\"m\":112,\"g\":115},\"21b9a4b4\":{\"m\":112,\"g\":115},\"db37422c\":{\"m\":112,\"g\":115},\"ab62b135\":{\"m\":112,\"g\":115},\"273b2834\":{\"m\":112,\"g\":115},\"f84db115\":{\"m\":112,\"g\":115},\"efb0de2c\":{\"m\":112,\"g\":115},\"0f6ac5e2\":{\"m\":112,\"g\":115},\"29850900\":{\"m\":112,\"g\":115},\"e678cc71\":{\"m\":112,\"g\":115},\"4efe844a\":{\"m\":112,\"g\":115},\"bde73ee4\":{\"m\":112,\"g\":115},\"4f0e28d7\":{\"m\":112,\"g\":115},\"045ab92d\":{\"m\":112,\"g\":115},\"bd7f8821\":{\"m\":112,\"g\":115},\"5e5c30d9\":{\"m\":112,\"g\":115},\"9f00ec44\":{\"m\":112,\"g\":115},\"8e85ee88\":{\"m\":112,\"g\":115},\"adf73175\":{\"m\":112,\"g\":115},\"13705dae\":{\"m\":112,\"g\":115},\"df97b31f\":{\"m\":112,\"g\":115},\"339f8eef\":{\"m\":112,\"g\":115},\"af
d9f2f5\":{\"m\":112,\"g\":115},\"f40038fb\":{\"m\":112,\"g\":115},\"bebd0576\":{\"m\":112,\"g\":115},\"f9836660\":{\"m\":112,\"g\":115},\"8b3b995a\":{\"m\":112,\"g\":115},\"6e95f5e5\":{\"m\":112,\"g\":115},\"0e9387a9\":{\"m\":112,\"g\":115},\"fa9c82d3\":{\"m\":112,\"g\":115},\"918e3d4c\":{\"m\":112,\"g\":115},\"e9697374\":{\"m\":112,\"g\":115},\"93088b69\":{\"m\":112,\"g\":115},\"453511ac\":{\"m\":112,\"g\":115},\"d0730487\":{\"m\":112,\"g\":115},\"b32ab070\":{\"m\":112,\"g\":115},\"75ee0011\":{\"m\":112,\"g\":115},\"ec15c836\":{\"m\":112,\"g\":115},\"106c2b31\":{\"m\":112,\"g\":115},\"c6756949\":{\"m\":112,\"g\":115},\"27e8ffed\":{\"m\":112,\"g\":115},\"4dbb34fe\":{\"m\":112,\"g\":115},\"1e18a341\":{\"m\":112,\"g\":115},\"2c562fd2\":{\"m\":112,\"g\":115},\"b648d862\":{\"m\":112,\"g\":115},\"bbf261ae\":{\"m\":112,\"g\":115},\"4f8a982d\":{\"m\":112,\"g\":115},\"d966b902\":{\"m\":112,\"g\":115},\"de921733\":{\"m\":112,\"g\":115},\"397448eb\":{\"m\":112,\"g\":115},\"66d5d042\":{\"m\":112,\"g\":115},\"73179b76\":{\"m\":112,\"g\":115},\"8cbf71dc\":{\"m\":112,\"g\":115},\"56eb5d0a\":{\"m\":112,\"g\":115},\"4ed9053e\":{\"m\":112,\"g\":115},\"5e19b159\":{\"m\":112,\"g\":115},\"788b19a5\":{\"m\":112,\"g\":115},\"f78b7fd1\":{\"m\":112,\"g\":115},\"b1fb7e45\":{\"m\":112,\"g\":115},\"1b2ff4fb\":{\"m\":112,\"g\":115},\"2c7ca33a\":{\"m\":112,\"g\":115},\"df397a72\":{\"m\":112,\"g\":115},\"5dfcd6c2\":{\"m\":112,\"g\":115},\"0dfd54d1\":{\"m\":112,\"g\":115},\"bcbeed71\":{\"m\":112,\"g\":115},\"cc9a31c6\":{\"m\":112,\"g\":115},\"d631290e\":{\"m\":112,\"g\":115},\"37565b7f\":{\"m\":112,\"g\":115},\"6243c367\":{\"m\":112,\"g\":115},\"60e37f80\":{\"m\":112,\"g\":115},\"369b1433\":{\"m\":112,\"g\":115},\"03dbf1aa\":{\"m\":112,\"g\":115},\"11dcabc5\":{\"m\":112,\"g\":115},\"4d89389c\":{\"m\":112,\"g\":115},\"9491d6e5\":{\"m\":112,\"g\":115},\"f64b8e3e\":{\"m\":112,\"g\":115},\"53976fce\":{\"m\":112,\"g\":115},\"18f91eb6\":{\"m\":112,\"g\":115},\"8766b3ac\":{\"m\":112,\"g\":115},\"1db649a
c\":{\"m\":112,\"g\":115},\"a1e5d781\":{\"m\":112,\"g\":115},\"b7361cc4\":{\"m\":112,\"g\":115},\"a96c5b5c\":{\"m\":112,\"g\":115},\"b9eb0d9c\":{\"m\":112,\"g\":115},\"1fbfdebe\":{\"m\":112,\"g\":115},\"a25e8e42\":{\"m\":112,\"g\":115},\"d4a93841\":{\"m\":112,\"g\":115},\"21e1bc47\":{\"m\":112,\"g\":115},\"9a0cac1b\":{\"m\":112,\"g\":115},\"b5245064\":{\"m\":112,\"g\":115},\"9d9fa9a5\":{\"m\":112,\"g\":115},\"58d06fdc\":{\"m\":112,\"g\":115},\"cb9e0e41\":{\"m\":112,\"g\":115},\"9db80253\":{\"m\":112,\"g\":115},\"598c0bc1\":{\"m\":112,\"g\":115},\"b361750a\":{\"m\":112,\"g\":115},\"16e56ea6\":{\"m\":112,\"g\":115},\"349b491c\":{\"m\":112,\"g\":115},\"5f77e129\":{\"m\":112,\"g\":115},\"4750cddf\":{\"m\":112,\"g\":115},\"065e523d\":{\"m\":112,\"g\":115},\"7de2ce45\":{\"m\":112,\"g\":115},\"8c2ffaaf\":{\"m\":112,\"g\":115},\"20445327\":{\"m\":112,\"g\":115},\"6d3c20cf\":{\"m\":112,\"g\":115},\"8b6966d0\":{\"m\":112,\"g\":115},\"a391f73a\":{\"m\":112,\"g\":115},\"25c73959\":{\"m\":112,\"g\":115},\"f05c6873\":{\"m\":112,\"g\":115},\"9a0d0b75\":{\"m\":112,\"g\":115},\"ba861293\":{\"m\":112,\"g\":115},\"c112bcc4\":{\"m\":112,\"g\":115},\"5e194b21\":{\"m\":112,\"g\":115},\"fd5ce576\":{\"m\":112,\"g\":115},\"92d79646\":{\"m\":112,\"g\":115},\"f9076a5a\":{\"m\":112,\"g\":115},\"646076b7\":{\"m\":112,\"g\":115},\"0d040089\":{\"m\":112,\"g\":115},\"05e47872\":{\"m\":112,\"g\":115},\"1e61b496\":{\"m\":112,\"g\":115},\"300676af\":{\"m\":112,\"g\":115},\"7fe89f7c\":{\"m\":112,\"g\":115},\"9970e3bf\":{\"m\":112,\"g\":115},\"70eedb58\":{\"m\":112,\"g\":115},\"9c99949e\":{\"m\":112,\"g\":115},\"c5082f0f\":{\"m\":112,\"g\":115},\"836873b9\":{\"m\":112,\"g\":115},\"8abe8dea\":{\"m\":112,\"g\":115},\"1e85589d\":{\"m\":112,\"g\":115},\"c2a26e72\":{\"m\":112,\"g\":115},\"591e6c59\":{\"m\":112,\"g\":115},\"42f34437\":{\"m\":112,\"g\":115},\"5c34b4f1\":{\"m\":112,\"g\":115},\"ff9b5618\":{\"m\":112,\"g\":115},\"fcd72bd1\":{\"m\":112,\"g\":115},\"3d8fc434\":{\"m\":112,\"g\":115},\"87a0f7d2\":{
\"m\":112,\"g\":115},\"839c93bd\":{\"m\":112,\"g\":115},\"f1e9bbaf\":{\"m\":112,\"g\":115},\"3fd1431d\":{\"m\":112,\"g\":115},\"161e9dc5\":{\"m\":112,\"g\":115},\"54e872d3\":{\"m\":112,\"g\":115},\"e5b29bf1\":{\"m\":112,\"g\":115},\"9a7c8842\":{\"m\":112,\"g\":115},\"7a16db9b\":{\"m\":112,\"g\":115},\"09a1df22\":{\"m\":112,\"g\":115},\"4b7034dd\":{\"m\":112,\"g\":115},\"a23c3020\":{\"m\":112,\"g\":115},\"a7d825fc\":{\"m\":112,\"g\":115},\"38cd5fb1\":{\"m\":112,\"g\":115},\"001f5194\":{\"m\":112,\"g\":115},\"5ad296bd\":{\"m\":112,\"g\":115},\"9f81d741\":{\"m\":112,\"g\":115},\"a38c1497\":{\"m\":112,\"g\":115},\"74dd4249\":{\"m\":112,\"g\":115},\"dc20c22f\":{\"m\":112,\"g\":115},\"711390a9\":{\"m\":112,\"g\":115},\"53430588\":{\"m\":112,\"g\":115},\"fce7ae33\":{\"m\":112,\"g\":115},\"6b39f9cf\":{\"m\":112,\"g\":115},\"07c9d8fb\":{\"m\":112,\"g\":115},\"4a4772ae\":{\"m\":112,\"g\":115},\"c3779233\":{\"m\":112,\"g\":115},\"f84b57c8\":{\"m\":112,\"g\":115},\"aee094e4\":{\"m\":112,\"g\":115},\"55349e36\":{\"m\":112,\"g\":115},\"e1f7cf57\":{\"m\":112,\"g\":115},\"2bb9d454\":{\"m\":112,\"g\":115},\"d0934a51\":{\"m\":112,\"g\":115},\"3f2d0cef\":{\"m\":112,\"g\":115},\"8b30bec2\":{\"m\":112,\"g\":115},\"4aeba40d\":{\"m\":112,\"g\":115},\"28684f90\":{\"m\":112,\"g\":115},\"a4a3d823\":{\"m\":113,\"g\":115},\"0b13cbb7\":{\"m\":113,\"g\":115},\"efbc687c\":{\"m\":113,\"g\":115},\"292a867a\":{\"m\":113,\"g\":115},\"8fd41eae\":{\"m\":113,\"g\":115},\"0cd1996e\":{\"m\":113,\"g\":115},\"b6b4b563\":{\"m\":113,\"g\":115},\"f8924ad7\":{\"m\":113,\"g\":115},\"2f80bd9f\":{\"m\":113,\"g\":115},\"366a603e\":{\"m\":113,\"g\":115},\"baee0860\":{\"m\":113,\"g\":115},\"c7a104c1\":{\"m\":113,\"g\":115},\"97d966a7\":{\"m\":113,\"g\":115},\"8e66d87f\":{\"m\":113,\"g\":115},\"a20fc7b7\":{\"m\":113,\"g\":115},\"6b30e097\":{\"m\":113,\"g\":115},\"d645ae90\":{\"m\":113,\"g\":115},\"41763ba0\":{\"m\":113,\"g\":115},\"652c24a6\":{\"m\":113,\"g\":115},\"5e142484\":{\"m\":113,\"g\":115},\"c560410d\":{\"m\"
:113,\"g\":115},\"590f2da0\":{\"m\":113,\"g\":115},\"148d8d48\":{\"m\":113,\"g\":115},\"1a599509\":{\"m\":113,\"g\":115},\"36a6b8db\":{\"m\":113,\"g\":115},\"e0b2d3ee\":{\"m\":113,\"g\":115},\"4cb5a523\":{\"m\":113,\"g\":115},\"85c1f793\":{\"m\":113,\"g\":115},\"48e9e719\":{\"m\":113,\"g\":115},\"31b49c0b\":{\"m\":113,\"g\":115},\"d736e0b6\":{\"m\":113,\"g\":115},\"ffd03a9b\":{\"m\":113,\"g\":115},\"666da3d5\":{\"m\":113,\"g\":115},\"d01b9214\":{\"m\":113,\"g\":115},\"c70e58e8\":{\"m\":113,\"g\":115},\"c61b9a1d\":{\"m\":113,\"g\":115},\"3c3d6255\":{\"m\":113,\"g\":115},\"546914fa\":{\"m\":113,\"g\":115},\"4726c919\":{\"m\":113,\"g\":115},\"a0010bf4\":{\"m\":113,\"g\":115},\"307fc060\":{\"m\":113,\"g\":115},\"586e81a2\":{\"m\":113,\"g\":115},\"fad7ca73\":{\"m\":113,\"g\":115},\"08af8ffb\":{\"m\":113,\"g\":115},\"2c7f4ca2\":{\"m\":113,\"g\":115},\"03def5e3\":{\"m\":113,\"g\":115},\"6ae3f05b\":{\"m\":113,\"g\":115},\"fdc4e1e5\":{\"m\":113,\"g\":115},\"04b86b3c\":{\"m\":113,\"g\":115},\"d6777a70\":{\"m\":113,\"g\":115},\"8c574902\":{\"m\":113,\"g\":115},\"34151f17\":{\"m\":113,\"g\":115},\"6794d210\":{\"m\":113,\"g\":115},\"1a31229c\":{\"m\":113,\"g\":115},\"de89ef49\":{\"m\":113,\"g\":115},\"b00a0c78\":{\"m\":113,\"g\":115},\"a2faf894\":{\"m\":113,\"g\":115},\"7e61737d\":{\"m\":113,\"g\":115},\"3c699772\":{\"m\":113,\"g\":115},\"e8100774\":{\"m\":113,\"g\":115},\"963175d5\":{\"m\":113,\"g\":115},\"0618ad6d\":{\"m\":113,\"g\":115},\"6a261aac\":{\"m\":113,\"g\":115},\"7ff740a6\":{\"m\":113,\"g\":115},\"bfcd9b24\":{\"m\":113,\"g\":115},\"458611de\":{\"m\":113,\"g\":115},\"3511b370\":{\"m\":113,\"g\":115},\"afcd3e10\":{\"m\":113,\"g\":115},\"12d68183\":{\"m\":113,\"g\":115},\"b65db028\":{\"m\":113,\"g\":115},\"948278f1\":{\"m\":113,\"g\":115},\"7d004799\":{\"m\":113,\"g\":115},\"083629c2\":{\"m\":113,\"g\":115},\"b658be6f\":{\"m\":113,\"g\":115},\"5e786cca\":{\"m\":113,\"g\":115},\"0b9dfba7\":{\"m\":113,\"g\":115},\"6a290034\":{\"m\":113,\"g\":115},\"2ac453b0\":{\"m\":113,
\"g\":115},\"f35def86\":{\"m\":113,\"g\":115},\"d61615fe\":{\"m\":113,\"g\":115},\"b1ccaf01\":{\"m\":113,\"g\":115},\"097725bb\":{\"m\":113,\"g\":115},\"44b1fbe2\":{\"m\":113,\"g\":115},\"c0dbbdd1\":{\"m\":113,\"g\":115},\"25e7dbe8\":{\"m\":113,\"g\":115},\"0b2aa8a7\":{\"m\":113,\"g\":115},\"609f65ba\":{\"m\":113,\"g\":115},\"2d62af6b\":{\"m\":113,\"g\":115},\"a28b394f\":{\"m\":113,\"g\":115},\"96fe2d0f\":{\"m\":113,\"g\":115},\"bfa27438\":{\"m\":113,\"g\":115},\"86cb4db0\":{\"m\":113,\"g\":115},\"2e130b76\":{\"m\":113,\"g\":115},\"ac1f2928\":{\"m\":113,\"g\":115},\"195a59fe\":{\"m\":113,\"g\":115},\"47488cc3\":{\"m\":113,\"g\":115},\"61305291\":{\"m\":113,\"g\":115},\"a9ce2bcb\":{\"m\":113,\"g\":115},\"5dddb331\":{\"m\":113,\"g\":115},\"01a26544\":{\"m\":113,\"g\":115},\"73d4a5f8\":{\"m\":113,\"g\":115},\"7fb551a7\":{\"m\":113,\"g\":115},\"1193f131\":{\"m\":113,\"g\":115},\"84a9f5d6\":{\"m\":113,\"g\":115},\"8ce830a8\":{\"m\":113,\"g\":115},\"fb367acf\":{\"m\":113,\"g\":115},\"a6cc86df\":{\"m\":113,\"g\":115},\"229d2b95\":{\"m\":113,\"g\":115},\"9710f718\":{\"m\":113,\"g\":115},\"91847e38\":{\"m\":113,\"g\":115},\"5a290a56\":{\"m\":113,\"g\":115},\"580051c5\":{\"m\":113,\"g\":115},\"1237aa19\":{\"m\":113,\"g\":115},\"59911195\":{\"m\":113,\"g\":115},\"424591d5\":{\"m\":113,\"g\":115},\"d1676cd4\":{\"m\":113,\"g\":115},\"33b3c0f8\":{\"m\":113,\"g\":115},\"e5281f84\":{\"m\":113,\"g\":115},\"d17986f8\":{\"m\":113,\"g\":115},\"8831c55c\":{\"m\":113,\"g\":115},\"2bc61dd1\":{\"m\":113,\"g\":115},\"6535fda1\":{\"m\":113,\"g\":115},\"3713eb61\":{\"m\":113,\"g\":115},\"5937a56d\":{\"m\":113,\"g\":115},\"f065e5be\":{\"m\":113,\"g\":115},\"9de1320b\":{\"m\":113,\"g\":115},\"dda34c2f\":{\"m\":113,\"g\":115},\"4eeaff74\":{\"m\":113,\"g\":115},\"a17e70f5\":{\"m\":113,\"g\":115},\"816b3a43\":{\"m\":113,\"g\":115},\"3a641d90\":{\"m\":113,\"g\":115},\"6f16bf9d\":{\"m\":113,\"g\":115},\"5942fdb4\":{\"m\":113,\"g\":115},\"af4ab656\":{\"m\":113,\"g\":115},\"11965b0d\":{\"m\":113,\"g\"
:115},\"71959545\":{\"m\":113,\"g\":115},\"24f7cb1e\":{\"m\":113,\"g\":115},\"e05555fa\":{\"m\":113,\"g\":115},\"43fa9f22\":{\"m\":113,\"g\":115},\"e98d9346\":{\"m\":113,\"g\":115},\"0c917410\":{\"m\":113,\"g\":115},\"25728863\":{\"m\":113,\"g\":115},\"dba751a8\":{\"m\":113,\"g\":115},\"2e763398\":{\"m\":113,\"g\":115},\"336e9a60\":{\"m\":113,\"g\":115},\"abb67815\":{\"m\":113,\"g\":115},\"07440f5f\":{\"m\":113,\"g\":115},\"9816989b\":{\"m\":113,\"g\":115},\"42245551\":{\"m\":113,\"g\":115},\"2a9d995c\":{\"m\":113,\"g\":115},\"a9050b5c\":{\"m\":113,\"g\":115},\"66face35\":{\"m\":113,\"g\":115},\"5519766a\":{\"m\":113,\"g\":115},\"72392f29\":{\"m\":113,\"g\":115},\"c1c8dd1d\":{\"m\":113,\"g\":115},\"f6bc3f52\":{\"m\":113,\"g\":115},\"8cc27fdc\":{\"m\":113,\"g\":115},\"9c339d6b\":{\"m\":113,\"g\":115},\"e23e280e\":{\"m\":113,\"g\":115},\"51f7c6bd\":{\"m\":113,\"g\":115},\"62e2e99d\":{\"m\":113,\"g\":115},\"8ebf72fe\":{\"m\":113,\"g\":115},\"82605747\":{\"m\":113,\"g\":115},\"37f3325b\":{\"m\":113,\"g\":115},\"bd95944c\":{\"m\":113,\"g\":115},\"c8a5d12a\":{\"m\":113,\"g\":115},\"2387c22b\":{\"m\":113,\"g\":115},\"592ddf37\":{\"m\":113,\"g\":115},\"0c3db889\":{\"m\":113,\"g\":115},\"2bdaf482\":{\"m\":113,\"g\":115},\"777eb538\":{\"m\":113,\"g\":115},\"05a35266\":{\"m\":113,\"g\":115},\"e56c64bf\":{\"m\":113,\"g\":115},\"fff7fbab\":{\"m\":113,\"g\":115},\"aae7ead2\":{\"m\":113,\"g\":115},\"a7fe6e10\":{\"m\":113,\"g\":115},\"be059b83\":{\"m\":113,\"g\":115},\"5d4fe1ce\":{\"m\":113,\"g\":115},\"1b011e68\":{\"m\":113,\"g\":115},\"5c0efa56\":{\"m\":113,\"g\":115},\"1e57b947\":{\"m\":113,\"g\":115},\"a5095d62\":{\"m\":113,\"g\":115},\"6c2c467d\":{\"m\":113,\"g\":115},\"c3d2ad4e\":{\"m\":113,\"g\":115},\"7ec5b4e8\":{\"m\":113,\"g\":115},\"60885482\":{\"m\":113,\"g\":115},\"172bcf01\":{\"m\":113,\"g\":115},\"37158f20\":{\"m\":113,\"g\":115},\"3e95aa1a\":{\"m\":113,\"g\":115},\"c4197e99\":{\"m\":113,\"g\":115},\"0ac61146\":{\"m\":113,\"g\":115},\"7dcd689b\":{\"m\":113,\"g\":115}
,\"f7bab41a\":{\"m\":113,\"g\":115},\"f68dd998\":{\"m\":113,\"g\":115},\"35ec2a45\":{\"m\":113,\"g\":115},\"0035f1ce\":{\"m\":113,\"g\":115},\"5e21d6ae\":{\"m\":113,\"g\":115},\"cd4da1f1\":{\"m\":113,\"g\":115},\"91678474\":{\"m\":113,\"g\":115},\"d511b2d9\":{\"m\":113,\"g\":115},\"77830a26\":{\"m\":113,\"g\":115},\"fce17048\":{\"m\":113,\"g\":115},\"3d40794f\":{\"m\":113,\"g\":115},\"c1f39013\":{\"m\":113,\"g\":115},\"3e43eb13\":{\"m\":113,\"g\":115},\"458c0219\":{\"m\":113,\"g\":115},\"a73eb8cd\":{\"m\":113,\"g\":115},\"e7387035\":{\"m\":113,\"g\":115},\"fe531d6f\":{\"m\":113,\"g\":115},\"c4e314f9\":{\"m\":113,\"g\":115},\"7a06ef98\":{\"m\":113,\"g\":115},\"4a87ba21\":{\"m\":113,\"g\":115},\"d7b20dd6\":{\"m\":113,\"g\":115},\"c3faf2d6\":{\"m\":113,\"g\":115},\"9209b209\":{\"m\":113,\"g\":115},\"adba172f\":{\"m\":113,\"g\":115},\"cd641a99\":{\"m\":113,\"g\":115},\"71f24ef8\":{\"m\":113,\"g\":115},\"b1f0fc1c\":{\"m\":113,\"g\":115},\"32d89373\":{\"m\":113,\"g\":115},\"f47a2c67\":{\"m\":113,\"g\":115},\"ee704e62\":{\"m\":113,\"g\":115},\"f4e3ebeb\":{\"m\":113,\"g\":115},\"312bfc4c\":{\"m\":113,\"g\":115},\"e290303e\":{\"m\":113,\"g\":115},\"aab35bcc\":{\"m\":113,\"g\":115},\"42aedb02\":{\"m\":113,\"g\":115},\"984730b7\":{\"m\":113,\"g\":115},\"23632d35\":{\"m\":113,\"g\":115},\"08b8c0c3\":{\"m\":113,\"g\":115},\"d42975c6\":{\"m\":113,\"g\":115},\"adc24a3a\":{\"m\":113,\"g\":115},\"7ff93e61\":{\"m\":113,\"g\":115},\"b24b2e7e\":{\"m\":113,\"g\":115},\"7135db5d\":{\"m\":113,\"g\":115},\"4b5ef300\":{\"m\":113,\"g\":115},\"4f564b9e\":{\"m\":113,\"g\":115},\"98c3b04f\":{\"m\":113,\"g\":115},\"ddab4fc7\":{\"m\":113,\"g\":115},\"d21c3522\":{\"m\":113,\"g\":115},\"4a762041\":{\"m\":113,\"g\":115},\"ea338676\":{\"m\":113,\"g\":115},\"b06db198\":{\"m\":113,\"g\":115},\"8c1ef0f9\":{\"m\":113,\"g\":115},\"f5a2faf2\":{\"m\":113,\"g\":115},\"1c82d9db\":{\"m\":113,\"g\":115},\"9241f4fd\":{\"m\":113,\"g\":115},\"063c3791\":{\"m\":113,\"g\":115},\"632b7d8c\":{\"m\":113,\"g\":115},\"16
adf3dc\":{\"m\":113,\"g\":115},\"c3a1d775\":{\"m\":113,\"g\":115},\"89971c4c\":{\"m\":113,\"g\":115},\"113f8f65\":{\"m\":113,\"g\":115},\"e22f3a5e\":{\"m\":113,\"g\":115},\"095093ee\":{\"m\":113,\"g\":115},\"d27a6f70\":{\"m\":113,\"g\":115},\"0753ef83\":{\"m\":113,\"g\":115},\"662393f2\":{\"m\":113,\"g\":115},\"b1bb8e74\":{\"m\":113,\"g\":115},\"38c00ed7\":{\"m\":113,\"g\":115},\"d4041a5e\":{\"m\":113,\"g\":115},\"2f555c4c\":{\"m\":113,\"g\":115},\"e53df7c0\":{\"m\":113,\"g\":115},\"9c53dad8\":{\"m\":113,\"g\":115},\"7ca1bea6\":{\"m\":113,\"g\":115},\"97c38239\":{\"m\":113,\"g\":115},\"60dbbd08\":{\"m\":113,\"g\":115},\"aa1c5cf5\":{\"m\":113,\"g\":115},\"592caab6\":{\"m\":113,\"g\":115},\"2101d93b\":{\"m\":113,\"g\":115},\"70e4b218\":{\"m\":113,\"g\":115},\"944f1ea0\":{\"m\":113,\"g\":115},\"9d7e82a0\":{\"m\":113,\"g\":115},\"f0580551\":{\"m\":113,\"g\":115},\"635ccda6\":{\"m\":113,\"g\":115},\"1c3dbad8\":{\"m\":113,\"g\":115},\"e2ac7888\":{\"m\":113,\"g\":115},\"86527a47\":{\"m\":113,\"g\":115},\"134b4f7e\":{\"m\":113,\"g\":115},\"f67d1f45\":{\"m\":113,\"g\":115},\"0f04a5f4\":{\"m\":113,\"g\":115},\"2f18602f\":{\"m\":113,\"g\":115},\"56321e9f\":{\"m\":113,\"g\":115},\"12d6cf18\":{\"m\":113,\"g\":115},\"fc3e5420\":{\"m\":113,\"g\":115},\"08ecd0aa\":{\"m\":113,\"g\":115},\"720c1c8c\":{\"m\":113,\"g\":115},\"d403c143\":{\"m\":113,\"g\":115},\"cba0d8c3\":{\"m\":113,\"g\":115},\"f1d78923\":{\"m\":113,\"g\":115},\"7c876de7\":{\"m\":113,\"g\":115},\"ba94b829\":{\"m\":113,\"g\":115},\"2b7417bf\":{\"m\":113,\"g\":115},\"f1116495\":{\"m\":113,\"g\":115},\"bd7eb020\":{\"m\":113,\"g\":115},\"74cd6e39\":{\"m\":113,\"g\":115},\"b17e67df\":{\"m\":113,\"g\":115},\"8ecef73f\":{\"m\":113,\"g\":115},\"1d1ce624\":{\"m\":113,\"g\":115},\"60e2a7ce\":{\"m\":113,\"g\":115},\"d88ef4a3\":{\"m\":113,\"g\":115},\"6f993e8b\":{\"m\":113,\"g\":115},\"03ce92e5\":{\"m\":113,\"g\":115},\"00eb5eb7\":{\"m\":113,\"g\":115},\"dab4663b\":{\"m\":113,\"g\":115},\"610a6d6e\":{\"m\":113,\"g\":115},\"36efd5b
e\":{\"m\":113,\"g\":115},\"68cdc189\":{\"m\":113,\"g\":115},\"7f399e4b\":{\"m\":113,\"g\":115},\"873d858b\":{\"m\":113,\"g\":115},\"3fa3c22a\":{\"m\":113,\"g\":115},\"4f2055ad\":{\"m\":113,\"g\":115},\"616a3e20\":{\"m\":113,\"g\":115},\"ac2a723b\":{\"m\":113,\"g\":115},\"56b991b1\":{\"m\":113,\"g\":115},\"780d6a22\":{\"m\":113,\"g\":115},\"8b713c72\":{\"m\":113,\"g\":115},\"5bfafdfc\":{\"m\":113,\"g\":115},\"8c52de6f\":{\"m\":113,\"g\":115},\"c1815a99\":{\"m\":113,\"g\":115},\"4e6c4923\":{\"m\":113,\"g\":115},\"b91cb67e\":{\"m\":113,\"g\":115},\"e7bc6003\":{\"m\":113,\"g\":115},\"2a2ff9a8\":{\"m\":113,\"g\":115},\"5291f32d\":{\"m\":113,\"g\":115},\"67073dde\":{\"m\":113,\"g\":115},\"9a5c42f9\":{\"m\":113,\"g\":115},\"388c05d5\":{\"m\":113,\"g\":115},\"fc809665\":{\"m\":113,\"g\":115},\"6fd4816d\":{\"m\":113,\"g\":115},\"1344ebc8\":{\"m\":113,\"g\":115},\"e07b21ce\":{\"m\":113,\"g\":115},\"52f248cd\":{\"m\":113,\"g\":115},\"93f75778\":{\"m\":113,\"g\":115},\"4039c626\":{\"m\":113,\"g\":115},\"db71c38f\":{\"m\":113,\"g\":115},\"7a68b422\":{\"m\":113,\"g\":115},\"60fc5b51\":{\"m\":113,\"g\":115},\"a13dd1e4\":{\"m\":113,\"g\":115},\"d500eb91\":{\"m\":113,\"g\":115},\"1ccd59c7\":{\"m\":113,\"g\":115},\"c32fb7a2\":{\"m\":113,\"g\":115},\"1ba137e9\":{\"m\":113,\"g\":115},\"de28f8e7\":{\"m\":113,\"g\":115},\"56405076\":{\"m\":113,\"g\":115},\"b73ac629\":{\"m\":113,\"g\":115},\"77098aea\":{\"m\":113,\"g\":115},\"5ccf0b03\":{\"m\":113,\"g\":115},\"a77564e0\":{\"m\":113,\"g\":115},\"4f9e71df\":{\"m\":113,\"g\":115},\"541551ce\":{\"m\":113,\"g\":115},\"124097fc\":{\"m\":113,\"g\":115},\"e1d45bc2\":{\"m\":113,\"g\":115},\"14fdd527\":{\"m\":113,\"g\":115},\"f949ad57\":{\"m\":113,\"g\":115},\"c49484a6\":{\"m\":113,\"g\":115},\"a2f7218a\":{\"m\":113,\"g\":115},\"311de47b\":{\"m\":113,\"g\":115},\"373080ea\":{\"m\":113,\"g\":115},\"7f028b07\":{\"m\":113,\"g\":115},\"0abb41c7\":{\"m\":113,\"g\":115},\"925dbb32\":{\"m\":113,\"g\":115},\"8df7353a\":{\"m\":113,\"g\":115},\"ae4be601\":{
\"m\":113,\"g\":115},\"9b876889\":{\"m\":113,\"g\":115},\"c0c6f543\":{\"m\":113,\"g\":115},\"edd6a07b\":{\"m\":113,\"g\":115},\"b6dd4bcb\":{\"m\":113,\"g\":115},\"b2435be6\":{\"m\":113,\"g\":115},\"5fe39e85\":{\"m\":113,\"g\":115},\"fa5d0bf6\":{\"m\":113,\"g\":115},\"16e93359\":{\"m\":113,\"g\":115},\"f1c692f6\":{\"m\":113,\"g\":115},\"80572c83\":{\"m\":113,\"g\":115},\"4bb08f6e\":{\"m\":113,\"g\":115},\"ec272dda\":{\"m\":113,\"g\":115},\"a220c14f\":{\"m\":113,\"g\":115},\"35ef3f29\":{\"m\":113,\"g\":115},\"31fb19a0\":{\"m\":113,\"g\":115},\"3f41b48c\":{\"m\":113,\"g\":115},\"2689f0bf\":{\"m\":113,\"g\":115},\"52074240\":{\"m\":113,\"g\":115},\"c3c26f76\":{\"m\":113,\"g\":115},\"2cf811a9\":{\"m\":113,\"g\":115},\"3b25dc12\":{\"m\":113,\"g\":115},\"5c08d7d2\":{\"m\":113,\"g\":115},\"a45d9a4e\":{\"m\":113,\"g\":115},\"28c79dc8\":{\"m\":113,\"g\":115},\"1fcccda4\":{\"m\":113,\"g\":115},\"79acec4f\":{\"m\":113,\"g\":115},\"b1721edb\":{\"m\":113,\"g\":115},\"57234d0c\":{\"m\":113,\"g\":115},\"b93acd70\":{\"m\":113,\"g\":115},\"86a32bb5\":{\"m\":113,\"g\":115},\"5afd0365\":{\"m\":113,\"g\":115},\"059c13de\":{\"m\":113,\"g\":115},\"50dc0c1e\":{\"m\":113,\"g\":115},\"76becc1d\":{\"m\":113,\"g\":115},\"2a37b24d\":{\"m\":113,\"g\":115},\"f73aae0b\":{\"m\":113,\"g\":115},\"69b35793\":{\"m\":113,\"g\":115},\"957482c8\":{\"m\":113,\"g\":115},\"3795b6a4\":{\"m\":113,\"g\":115},\"7eccbe99\":{\"m\":113,\"g\":115},\"0549f21c\":{\"m\":113,\"g\":115},\"b354e3c9\":{\"m\":113,\"g\":115},\"65e6f48c\":{\"m\":113,\"g\":115},\"0ec580a8\":{\"m\":113,\"g\":115},\"0b14159f\":{\"m\":113,\"g\":115},\"1489cd6c\":{\"m\":113,\"g\":115},\"fc2c3a3d\":{\"m\":113,\"g\":115},\"8f6a1758\":{\"m\":113,\"g\":115},\"01018138\":{\"m\":113,\"g\":115},\"4844fac9\":{\"m\":113,\"g\":115},\"b7d385e8\":{\"m\":113,\"g\":115},\"305c9e8c\":{\"m\":113,\"g\":115},\"ca63f075\":{\"m\":113,\"g\":115},\"f9ee6ae1\":{\"m\":113,\"g\":115},\"dcee42c2\":{\"m\":113,\"g\":115},\"258d02c8\":{\"m\":113,\"g\":115},\"60d7beda\":{\"m\"
:113,\"g\":115},\"2f8ba6fe\":{\"m\":113,\"g\":115},\"7ce6c10e\":{\"m\":113,\"g\":115},\"55025b92\":{\"m\":113,\"g\":115},\"4c21b090\":{\"m\":113,\"g\":115},\"165abeeb\":{\"m\":113,\"g\":115},\"21ca4c3a\":{\"m\":113,\"g\":115},\"e3cf812f\":{\"m\":113,\"g\":115},\"4da55336\":{\"m\":113,\"g\":115},\"ac964d2e\":{\"m\":113,\"g\":115},\"fa46e2bd\":{\"m\":113,\"g\":115},\"b047b553\":{\"m\":113,\"g\":115},\"a0f844ed\":{\"m\":113,\"g\":115},\"2df532ef\":{\"m\":113,\"g\":115},\"abea9250\":{\"m\":113,\"g\":115},\"b3c97762\":{\"m\":113,\"g\":115},\"b8347b40\":{\"m\":113,\"g\":115},\"72dfa96a\":{\"m\":113,\"g\":115},\"05b01ef4\":{\"m\":113,\"g\":115},\"55a6e644\":{\"m\":113,\"g\":115},\"6897e06b\":{\"m\":113,\"g\":115},\"a360511d\":{\"m\":113,\"g\":115},\"94d0f656\":{\"m\":113,\"g\":115},\"eca59f96\":{\"m\":113,\"g\":115},\"97528610\":{\"m\":113,\"g\":115},\"297d3745\":{\"m\":113,\"g\":115},\"31e9d3a5\":{\"m\":113,\"g\":115},\"6f4676ef\":{\"m\":113,\"g\":115},\"c9ec4cae\":{\"m\":113,\"g\":115},\"99757cc3\":{\"m\":113,\"g\":115},\"cdddab05\":{\"m\":113,\"g\":115},\"7c5a0a1b\":{\"m\":113,\"g\":115},\"49f169d5\":{\"m\":113,\"g\":115},\"7fce2fd9\":{\"m\":113,\"g\":115},\"16cd550c\":{\"m\":113,\"g\":115},\"d5e2a374\":{\"m\":113,\"g\":115},\"366043db\":{\"m\":113,\"g\":115},\"2f173ea0\":{\"m\":113,\"g\":115},\"321fecab\":{\"m\":113,\"g\":115},\"9d775b1a\":{\"m\":113,\"g\":115},\"78b7465c\":{\"m\":113,\"g\":115},\"07bcad7f\":{\"m\":113,\"g\":115},\"98adac8e\":{\"m\":113,\"g\":115},\"cef11e9a\":{\"m\":113,\"g\":115},\"2269cf1e\":{\"m\":113,\"g\":115},\"151e287d\":{\"m\":113,\"g\":115},\"8c86595c\":{\"m\":113,\"g\":115},\"4634fd59\":{\"m\":113,\"g\":115},\"efedbe6c\":{\"m\":113,\"g\":115},\"3a77c80b\":{\"m\":113,\"g\":115},\"36acd2ff\":{\"m\":113,\"g\":115},\"fe6cdf89\":{\"m\":113,\"g\":115},\"30d20ce8\":{\"m\":113,\"g\":115},\"1b1701f1\":{\"m\":113,\"g\":115},\"6d403089\":{\"m\":113,\"g\":115},\"24dc2bee\":{\"m\":113,\"g\":115},\"fac07c9b\":{\"m\":113,\"g\":115},\"b3839a7f\":{\"m\":113,
\"g\":115},\"4aa39d72\":{\"m\":113,\"g\":115},\"b4c2c421\":{\"m\":113,\"g\":115},\"53ca1552\":{\"m\":113,\"g\":115},\"a23bdeaf\":{\"m\":113,\"g\":115},\"27778010\":{\"m\":113,\"g\":115},\"46d8fb1c\":{\"m\":113,\"g\":115},\"c7e85f53\":{\"m\":113,\"g\":115},\"3df05f4d\":{\"m\":113,\"g\":115},\"7b141f81\":{\"m\":113,\"g\":115},\"7bc5fb0d\":{\"m\":113,\"g\":115},\"144ee5f3\":{\"m\":113,\"g\":115},\"758b887a\":{\"m\":114,\"g\":115},\"eb7d9261\":{\"m\":114,\"g\":115},\"44cb0607\":{\"m\":114,\"g\":115},\"88bb627d\":{\"m\":114,\"g\":115},\"b520958e\":{\"m\":114,\"g\":115},\"fa7e2c30\":{\"m\":114,\"g\":115},\"8f2cd177\":{\"m\":114,\"g\":115},\"ab926dd6\":{\"m\":114,\"g\":115},\"a4b424c6\":{\"m\":114,\"g\":115},\"a0557642\":{\"m\":114,\"g\":115},\"84768d10\":{\"m\":114,\"g\":115},\"368fd206\":{\"m\":114,\"g\":115},\"53bd00d9\":{\"m\":114,\"g\":115},\"e22b13c5\":{\"m\":114,\"g\":115},\"a3c2ea44\":{\"m\":114,\"g\":115},\"fccac7d1\":{\"m\":114,\"g\":115},\"7ac6b900\":{\"m\":114,\"g\":115},\"a1080b72\":{\"m\":114,\"g\":115},\"a65ca739\":{\"m\":114,\"g\":115},\"677aa0e2\":{\"m\":114,\"g\":115},\"01c9ee1a\":{\"m\":114,\"g\":115},\"d6837aea\":{\"m\":114,\"g\":115},\"c882b5ae\":{\"m\":114,\"g\":115},\"e3bb7f5a\":{\"m\":114,\"g\":115},\"92473e2e\":{\"m\":114,\"g\":115},\"6c0bb327\":{\"m\":114,\"g\":115},\"0a7c4bde\":{\"m\":114,\"g\":115},\"edefab0c\":{\"m\":114,\"g\":115},\"97cd38e5\":{\"m\":114,\"g\":115},\"3c06b673\":{\"m\":114,\"g\":115},\"7c3f07db\":{\"m\":114,\"g\":115},\"edd86b88\":{\"m\":114,\"g\":115},\"4b4dc132\":{\"m\":114,\"g\":115},\"5a9170d9\":{\"m\":114,\"g\":115},\"c4d77774\":{\"m\":114,\"g\":115},\"832c84fb\":{\"m\":114,\"g\":115},\"64d1505c\":{\"m\":114,\"g\":115},\"f3764c26\":{\"m\":114,\"g\":115},\"7ba3de0e\":{\"m\":114,\"g\":115},\"fde9b963\":{\"m\":114,\"g\":115},\"f094e0a4\":{\"m\":114,\"g\":115},\"4ed67c27\":{\"m\":114,\"g\":115},\"cd4b39a9\":{\"m\":114,\"g\":115},\"420c99ac\":{\"m\":114,\"g\":115},\"e3c7f091\":{\"m\":114,\"g\":115},\"6f1e03a4\":{\"m\":114,\"g\"
:115},\"f4affd4d\":{\"m\":114,\"g\":115},\"df08bf9b\":{\"m\":114,\"g\":115},\"69efdd27\":{\"m\":114,\"g\":115},\"64582caa\":{\"m\":114,\"g\":115},\"2fcd56ea\":{\"m\":114,\"g\":115},\"0958a397\":{\"m\":114,\"g\":115},\"4f42c8cd\":{\"m\":114,\"g\":115},\"3ddd7dc9\":{\"m\":114,\"g\":115},\"501dfa6b\":{\"m\":114,\"g\":115},\"79d34951\":{\"m\":114,\"g\":115},\"1519a89c\":{\"m\":114,\"g\":115},\"24bc3fb0\":{\"m\":114,\"g\":115},\"8a8a608a\":{\"m\":114,\"g\":115},\"533e58a1\":{\"m\":114,\"g\":115},\"9b4c4497\":{\"m\":114,\"g\":115},\"a578d300\":{\"m\":114,\"g\":115},\"8c967037\":{\"m\":114,\"g\":115},\"fb27d383\":{\"m\":114,\"g\":115},\"fd8a0b29\":{\"m\":114,\"g\":115},\"afc35ccc\":{\"m\":114,\"g\":115},\"a57f0e3d\":{\"m\":114,\"g\":115},\"708f4ff4\":{\"m\":114,\"g\":115},\"e2daeb35\":{\"m\":114,\"g\":115},\"0e7b3530\":{\"m\":114,\"g\":115},\"b07c9c76\":{\"m\":114,\"g\":115},\"748f86f3\":{\"m\":114,\"g\":115},\"73ea484a\":{\"m\":114,\"g\":115},\"4aeb193f\":{\"m\":114,\"g\":115},\"466992b2\":{\"m\":114,\"g\":115},\"155cbb51\":{\"m\":114,\"g\":115},\"eb30b888\":{\"m\":114,\"g\":115},\"5ee777c9\":{\"m\":114,\"g\":115},\"baf277a9\":{\"m\":116,\"g\":118},\"f5d30dae\":{\"m\":116,\"g\":118},\"2479b894\":{\"m\":116,\"g\":118},\"54644572\":{\"m\":116,\"g\":118},\"6c01844f\":{\"m\":116,\"g\":118},\"f226d3da\":{\"m\":116,\"g\":118},\"d2478cd4\":{\"m\":116,\"g\":118},\"30ea4c46\":{\"m\":116,\"g\":118},\"6d036468\":{\"m\":116,\"g\":118},\"8221f9ae\":{\"m\":116,\"g\":118},\"ab9187a2\":{\"m\":116,\"g\":118},\"6b143d62\":{\"m\":116,\"g\":118},\"6bc503af\":{\"m\":116,\"g\":118},\"b2c85669\":{\"m\":116,\"g\":118},\"32803fb2\":{\"m\":116,\"g\":118},\"91fc5bb5\":{\"m\":116,\"g\":118},\"780fbf2f\":{\"m\":116,\"g\":118},\"825432fc\":{\"m\":116,\"g\":118},\"a40229f6\":{\"m\":116,\"g\":118},\"74737b28\":{\"m\":116,\"g\":115},\"40e0082d\":{\"m\":116,\"g\":115},\"e9e120ac\":{\"m\":116,\"g\":115},\"e0c2af2a\":{\"m\":116,\"g\":115},\"1d7f7835\":{\"m\":116,\"g\":115},\"32595146\":{\"m\":116,\"g\":115}
,\"86373b9e\":{\"m\":116,\"g\":115},\"d314bf60\":{\"m\":116,\"g\":115},\"e28c9e52\":{\"m\":116,\"g\":115},\"b98cf398\":{\"m\":116,\"g\":115},\"27d71045\":{\"m\":116,\"g\":115},\"c224a4c6\":{\"m\":116,\"g\":115},\"49345a68\":{\"m\":116,\"g\":115},\"94d26d85\":{\"m\":116,\"g\":115},\"9e8a15a7\":{\"m\":116,\"g\":115},\"3962e39d\":{\"m\":116,\"g\":115},\"eb8cac6f\":{\"m\":116,\"g\":115},\"5ea96ac7\":{\"m\":116,\"g\":115},\"56222658\":{\"m\":116,\"g\":115},\"dc965db0\":{\"m\":116,\"g\":115},\"817e46f4\":{\"m\":116,\"g\":115},\"5a33c3aa\":{\"m\":116,\"g\":115},\"9767a1e4\":{\"m\":116,\"g\":115},\"1d086539\":{\"m\":116,\"g\":115},\"a04efc49\":{\"m\":116,\"g\":115},\"642fa966\":{\"m\":116,\"g\":115},\"da7fac1b\":{\"m\":116,\"g\":115},\"28ad2297\":{\"m\":116,\"g\":115},\"f7f9f8ec\":{\"m\":116,\"g\":115},\"4b62af92\":{\"m\":116,\"g\":115},\"0b9915c1\":{\"m\":116,\"g\":115},\"27ef1459\":{\"m\":116,\"g\":115},\"e4358a45\":{\"m\":116,\"g\":115},\"ba2ce28f\":{\"m\":116,\"g\":115},\"98923880\":{\"m\":116,\"g\":115},\"f792e3c5\":{\"m\":116,\"g\":115},\"28f80b12\":{\"m\":116,\"g\":115},\"88a6f9da\":{\"m\":116,\"g\":115},\"cb8ed2c0\":{\"m\":116,\"g\":115},\"38473363\":{\"m\":116,\"g\":115},\"aaf7af1b\":{\"m\":116,\"g\":115},\"932e2637\":{\"m\":116,\"g\":115},\"43f80884\":{\"m\":116,\"g\":115},\"60b05032\":{\"m\":116,\"g\":115},\"dc48c4c0\":{\"m\":116,\"g\":115},\"6dc9ca8c\":{\"m\":116,\"g\":115},\"887c2b45\":{\"m\":116,\"g\":115},\"065ce815\":{\"m\":116,\"g\":115},\"8e51049f\":{\"m\":116,\"g\":115},\"cb8f3d90\":{\"m\":116,\"g\":115},\"4b694e7d\":{\"m\":116,\"g\":115},\"9f1f699a\":{\"m\":116,\"g\":115},\"c9cff2b9\":{\"m\":116,\"g\":115},\"b6fb5d76\":{\"m\":116,\"g\":115},\"f4aa7880\":{\"m\":116,\"g\":115},\"5e3f7e7f\":{\"m\":116,\"g\":115},\"728af887\":{\"m\":116,\"g\":115},\"7b59b0b8\":{\"m\":116,\"g\":115},\"acc2327b\":{\"m\":116,\"g\":115},\"bfadb5ea\":{\"m\":116,\"g\":115},\"9cc1e065\":{\"m\":116,\"g\":115},\"b8c430f1\":{\"m\":116,\"g\":115},\"f35f120d\":{\"m\":116,\"g\":115},\"54
a46a26\":{\"m\":116,\"g\":115},\"7c94eaee\":{\"m\":116,\"g\":115},\"13d596c9\":{\"m\":116,\"g\":115},\"c7867b67\":{\"m\":116,\"g\":115},\"516738b0\":{\"m\":116,\"g\":115},\"0b6f535f\":{\"m\":116,\"g\":115},\"c5fe3c0b\":{\"m\":116,\"g\":115},\"318424e2\":{\"m\":116,\"g\":115},\"6806c4e6\":{\"m\":116,\"g\":115},\"0c0779d6\":{\"m\":116,\"g\":115},\"a55cf530\":{\"m\":116,\"g\":115},\"19ba16aa\":{\"m\":116,\"g\":115},\"a2b3d9b9\":{\"m\":116,\"g\":115},\"9a30914e\":{\"m\":116,\"g\":115},\"8e776c78\":{\"m\":116,\"g\":115},\"63e84352\":{\"m\":116,\"g\":115},\"a20e7df8\":{\"m\":116,\"g\":115},\"1bdd0102\":{\"m\":116,\"g\":115},\"6cd29694\":{\"m\":116,\"g\":115},\"2ac46e94\":{\"m\":116,\"g\":115},\"0aa65f94\":{\"m\":116,\"g\":115},\"0ecb4261\":{\"m\":116,\"g\":115},\"05f015f6\":{\"m\":116,\"g\":115},\"1083e7e3\":{\"m\":116,\"g\":115},\"2157d12a\":{\"m\":116,\"g\":115},\"9f2b457c\":{\"m\":116,\"g\":115},\"f5b34a51\":{\"m\":116,\"g\":115},\"5a6ec8f9\":{\"m\":116,\"g\":115},\"6a653bb1\":{\"m\":116,\"g\":115},\"548a57b1\":{\"m\":116,\"g\":115},\"88e73ed0\":{\"m\":116,\"g\":115},\"4b15fa00\":{\"m\":116,\"g\":115},\"f4941906\":{\"m\":116,\"g\":115},\"01e59e82\":{\"m\":116,\"g\":115},\"99a0704a\":{\"m\":116,\"g\":115},\"ec1cd90a\":{\"m\":116,\"g\":115},\"1103dc62\":{\"m\":116,\"g\":115},\"a220536f\":{\"m\":116,\"g\":115},\"7b064f04\":{\"m\":116,\"g\":115},\"43190bec\":{\"m\":116,\"g\":115},\"be740acd\":{\"m\":116,\"g\":115},\"2db2cddd\":{\"m\":116,\"g\":115},\"9b5efe34\":{\"m\":116,\"g\":115},\"4ac8e09d\":{\"m\":116,\"g\":115},\"20a6c0a6\":{\"m\":116,\"g\":115},\"47c606d3\":{\"m\":116,\"g\":115},\"9fcf7306\":{\"m\":116,\"g\":115},\"0a304870\":{\"m\":116,\"g\":115},\"8fdcd98e\":{\"m\":116,\"g\":115},\"b5dcfd41\":{\"m\":116,\"g\":115},\"5061b8fd\":{\"m\":116,\"g\":115},\"c8452551\":{\"m\":116,\"g\":115},\"bf3e7149\":{\"m\":116,\"g\":115},\"f5754d12\":{\"m\":116,\"g\":115},\"739daa63\":{\"m\":116,\"g\":115},\"d957177a\":{\"m\":116,\"g\":115},\"21337b22\":{\"m\":116,\"g\":115},\"129d299
2\":{\"m\":116,\"g\":115},\"8b85926a\":{\"m\":116,\"g\":115},\"451d15c4\":{\"m\":116,\"g\":115},\"c80a96da\":{\"m\":116,\"g\":115},\"eae9a9fb\":{\"m\":116,\"g\":115},\"2674c1d2\":{\"m\":116,\"g\":115},\"61055cb3\":{\"m\":116,\"g\":115},\"92777135\":{\"m\":116,\"g\":115},\"c4958331\":{\"m\":116,\"g\":115},\"2eeb2751\":{\"m\":116,\"g\":115},\"b36afed4\":{\"m\":116,\"g\":115},\"9aa4502d\":{\"m\":116,\"g\":115},\"a0835c3a\":{\"m\":116,\"g\":115},\"55b14656\":{\"m\":116,\"g\":115},\"b4408e60\":{\"m\":116,\"g\":115},\"52fcbbb8\":{\"m\":116,\"g\":115},\"af96ca11\":{\"m\":116,\"g\":115},\"9082a7d3\":{\"m\":116,\"g\":115},\"3b9d97f3\":{\"m\":116,\"g\":115},\"a1a20b4c\":{\"m\":116,\"g\":115},\"4299aebd\":{\"m\":116,\"g\":115},\"0babd487\":{\"m\":116,\"g\":115},\"f19613e6\":{\"m\":116,\"g\":115},\"8df49455\":{\"m\":116,\"g\":115},\"ee3bd8a1\":{\"m\":116,\"g\":115},\"d8467db7\":{\"m\":116,\"g\":115},\"b5044fbf\":{\"m\":116,\"g\":115},\"70fbb3ad\":{\"m\":116,\"g\":115},\"9a7e7a65\":{\"m\":116,\"g\":115},\"0fe87213\":{\"m\":116,\"g\":115},\"1f106ee3\":{\"m\":116,\"g\":115},\"9b8ebb27\":{\"m\":116,\"g\":115},\"85ebeecf\":{\"m\":117,\"g\":118},\"0dd6cf16\":{\"m\":117,\"g\":118},\"0975ba99\":{\"m\":117,\"g\":118},\"1de3924b\":{\"m\":117,\"g\":118},\"3cceaa38\":{\"m\":117,\"g\":118},\"b0d20cde\":{\"m\":117,\"g\":118},\"cbac4997\":{\"m\":117,\"g\":118},\"476c67d7\":{\"m\":117,\"g\":118},\"3289da5b\":{\"m\":117,\"g\":118},\"868403f6\":{\"m\":117,\"g\":118},\"97d857c0\":{\"m\":117,\"g\":118},\"52a54a26\":{\"m\":117,\"g\":118},\"cd7e1bd5\":{\"m\":117,\"g\":118},\"729b7edf\":{\"m\":117,\"g\":118},\"4c03dbaa\":{\"m\":117,\"g\":118},\"1053e1be\":{\"m\":119,\"g\":121},\"9a71500c\":{\"m\":119,\"g\":121},\"6d6e24bc\":{\"m\":119,\"g\":121},\"2c057fbf\":{\"m\":119,\"g\":121},\"dbd9435d\":{\"m\":119,\"g\":121},\"8ae9d4bb\":{\"m\":119,\"g\":121},\"1c304aa9\":{\"m\":119,\"g\":121},\"770529a7\":{\"m\":119,\"g\":121},\"39c237f0\":{\"m\":119,\"g\":121},\"28b8a406\":{\"m\":119,\"g\":121},\"8bd26dd4\":{
\"m\":119,\"g\":121},\"ab07cd3e\":{\"m\":119,\"g\":121},\"a9849683\":{\"m\":119,\"g\":121},\"a4b637d8\":{\"m\":119,\"g\":121},\"96a5e4dd\":{\"m\":119,\"g\":121},\"b0b4f716\":{\"m\":119,\"g\":121},\"6c18addb\":{\"m\":119,\"g\":121},\"32852fe9\":{\"m\":119,\"g\":121},\"53c2934d\":{\"m\":119,\"g\":121},\"e321c971\":{\"m\":119,\"g\":121},\"d6fee73d\":{\"m\":119,\"g\":121},\"36a4cad7\":{\"m\":119,\"g\":121},\"65d376b4\":{\"m\":119,\"g\":121},\"c23eda85\":{\"m\":119,\"g\":121},\"138ff231\":{\"m\":119,\"g\":121},\"13fb8b54\":{\"m\":119,\"g\":121},\"81fd2b0e\":{\"m\":119,\"g\":121},\"007b849b\":{\"m\":119,\"g\":121},\"8612811d\":{\"m\":119,\"g\":121},\"e7aa4664\":{\"m\":119,\"g\":121},\"4d4feccb\":{\"m\":119,\"g\":121},\"99c92ff2\":{\"m\":119,\"g\":121},\"6ade6a02\":{\"m\":119,\"g\":121},\"983ef22c\":{\"m\":119,\"g\":121},\"164302c7\":{\"m\":119,\"g\":121},\"5dccf697\":{\"m\":119,\"g\":121},\"eec9e471\":{\"m\":119,\"g\":121},\"6d535b71\":{\"m\":119,\"g\":121},\"fdcb1d13\":{\"m\":119,\"g\":121},\"d7e834d6\":{\"m\":119,\"g\":121},\"200a3c0b\":{\"m\":119,\"g\":121},\"77258ce0\":{\"m\":119,\"g\":121},\"1d097aac\":{\"m\":119,\"g\":121},\"7fceeef5\":{\"m\":119,\"g\":121},\"88568c01\":{\"m\":119,\"g\":121},\"904655c5\":{\"m\":119,\"g\":121},\"e028af69\":{\"m\":119,\"g\":121},\"80b2b320\":{\"m\":119,\"g\":121},\"4b65ed42\":{\"m\":119,\"g\":121},\"23afdfd1\":{\"m\":119,\"g\":121},\"9d61205d\":{\"m\":119,\"g\":121},\"590bc4b7\":{\"m\":119,\"g\":121},\"63cfe1b0\":{\"m\":119,\"g\":121},\"70f6309c\":{\"m\":119,\"g\":121},\"70416001\":{\"m\":119,\"g\":121},\"87a92e45\":{\"m\":119,\"g\":121},\"c461e771\":{\"m\":119,\"g\":121},\"fde2decf\":{\"m\":119,\"g\":121},\"9792b9d7\":{\"m\":119,\"g\":121},\"ef4a8097\":{\"m\":119,\"g\":121},\"ebff4ee6\":{\"m\":119,\"g\":121},\"2b1da821\":{\"m\":119,\"g\":121},\"97710ccd\":{\"m\":119,\"g\":121},\"f3cd5d25\":{\"m\":119,\"g\":121},\"c61b0b29\":{\"m\":119,\"g\":121},\"e8640ee9\":{\"m\":119,\"g\":121},\"d0a64c7e\":{\"m\":119,\"g\":121},\"05d3667a\":{\"m\"
:119,\"g\":121},\"260fe755\":{\"m\":119,\"g\":121},\"dbb16bed\":{\"m\":119,\"g\":121},\"c1e16003\":{\"m\":119,\"g\":121},\"852c0578\":{\"m\":119,\"g\":121},\"7e6191c0\":{\"m\":119,\"g\":121},\"6f9b66bd\":{\"m\":119,\"g\":121},\"8a801ee3\":{\"m\":119,\"g\":118},\"d9a20fd2\":{\"m\":119,\"g\":118},\"b113c72e\":{\"m\":119,\"g\":118},\"fb6cc7b0\":{\"m\":119,\"g\":118},\"8374a96e\":{\"m\":119,\"g\":118},\"74de76c6\":{\"m\":119,\"g\":118},\"9c0b1eb5\":{\"m\":119,\"g\":118},\"01f14a7a\":{\"m\":119,\"g\":118},\"11110303\":{\"m\":119,\"g\":118},\"28ddfb37\":{\"m\":119,\"g\":118},\"e69094df\":{\"m\":119,\"g\":118},\"43ad0590\":{\"m\":119,\"g\":118},\"b4948512\":{\"m\":119,\"g\":118},\"ddcba74b\":{\"m\":119,\"g\":118},\"0917c5da\":{\"m\":119,\"g\":118},\"184a4df6\":{\"m\":119,\"g\":118},\"f7b1d8c5\":{\"m\":119,\"g\":118},\"bfc3b3f7\":{\"m\":119,\"g\":118},\"da5bde4d\":{\"m\":119,\"g\":118},\"276e7b3e\":{\"m\":119,\"g\":118},\"296f6892\":{\"m\":119,\"g\":118},\"9edb7b51\":{\"m\":119,\"g\":118},\"e53bf442\":{\"m\":119,\"g\":118},\"d383e661\":{\"m\":119,\"g\":118},\"984fbeb1\":{\"m\":119,\"g\":118},\"a2ba0bc3\":{\"m\":119,\"g\":118},\"6d2d0ce2\":{\"m\":119,\"g\":118},\"271d3d0d\":{\"m\":119,\"g\":118},\"c4e81e64\":{\"m\":119,\"g\":118},\"c726d44c\":{\"m\":119,\"g\":118},\"283c8ba0\":{\"m\":119,\"g\":118},\"cae39565\":{\"m\":119,\"g\":118},\"27a223ab\":{\"m\":119,\"g\":118},\"53529f46\":{\"m\":119,\"g\":118},\"24ed3f32\":{\"m\":119,\"g\":118},\"44f0ece9\":{\"m\":119,\"g\":118},\"be0058bc\":{\"m\":119,\"g\":118},\"9e3be1fa\":{\"m\":119,\"g\":118},\"a8ba3279\":{\"m\":119,\"g\":118},\"3b80232d\":{\"m\":119,\"g\":118},\"252dc4e1\":{\"m\":119,\"g\":118},\"cbb5fc2e\":{\"m\":119,\"g\":118},\"53fb229f\":{\"m\":119,\"g\":118},\"4fff1ec1\":{\"m\":119,\"g\":118},\"7a020e0f\":{\"m\":119,\"g\":118},\"48738af7\":{\"m\":119,\"g\":118},\"efa47334\":{\"m\":119,\"g\":118},\"d658f049\":{\"m\":119,\"g\":118},\"57e25de7\":{\"m\":119,\"g\":118},\"12eb02e9\":{\"m\":119,\"g\":118},\"002d0373\":{\"m\":119,
\"g\":118},\"a27825ae\":{\"m\":119,\"g\":118},\"ce399e15\":{\"m\":119,\"g\":118},\"ea6275df\":{\"m\":119,\"g\":118},\"eb7318f1\":{\"m\":119,\"g\":118},\"6058fb52\":{\"m\":119,\"g\":118},\"80407b04\":{\"m\":119,\"g\":118},\"b288f4f4\":{\"m\":119,\"g\":118},\"6d6ea5af\":{\"m\":119,\"g\":118},\"1dacedd2\":{\"m\":119,\"g\":118},\"b5e14b2b\":{\"m\":119,\"g\":118},\"d513ee93\":{\"m\":119,\"g\":118},\"a7ae61ed\":{\"m\":119,\"g\":118},\"fda0cb2a\":{\"m\":119,\"g\":118},\"ebda73dc\":{\"m\":119,\"g\":118},\"f4f8a1b4\":{\"m\":119,\"g\":118},\"c44e985d\":{\"m\":119,\"g\":118},\"f9a7d9b3\":{\"m\":119,\"g\":118},\"a93f10a7\":{\"m\":119,\"g\":118},\"585e1223\":{\"m\":119,\"g\":118},\"a7043c6f\":{\"m\":119,\"g\":118},\"67e34c56\":{\"m\":119,\"g\":118},\"1d726528\":{\"m\":119,\"g\":118},\"f4488e9d\":{\"m\":119,\"g\":118},\"e68a2b5b\":{\"m\":119,\"g\":118},\"31b9f19e\":{\"m\":119,\"g\":118},\"547003bd\":{\"m\":119,\"g\":118},\"f7ab9554\":{\"m\":119,\"g\":118},\"dbbd4e18\":{\"m\":119,\"g\":118},\"ca240eef\":{\"m\":119,\"g\":118},\"6c7c92eb\":{\"m\":119,\"g\":118},\"5b214b50\":{\"m\":119,\"g\":118},\"13219e1e\":{\"m\":119,\"g\":118},\"33e9bbec\":{\"m\":119,\"g\":118},\"dcb8f090\":{\"m\":119,\"g\":118},\"9eefe2c0\":{\"m\":119,\"g\":118},\"69fe3c97\":{\"m\":119,\"g\":118},\"8af84912\":{\"m\":119,\"g\":118},\"505329ca\":{\"m\":119,\"g\":118},\"8a382fd3\":{\"m\":119,\"g\":118},\"62797440\":{\"m\":119,\"g\":118},\"2614adf9\":{\"m\":119,\"g\":118},\"fdd7c69d\":{\"m\":119,\"g\":118},\"b9a54e09\":{\"m\":119,\"g\":118},\"20b8d230\":{\"m\":119,\"g\":118},\"d1984e21\":{\"m\":119,\"g\":118},\"b79f75fd\":{\"m\":119,\"g\":118},\"8fcc69e7\":{\"m\":119,\"g\":118},\"f440baa1\":{\"m\":119,\"g\":118},\"2bc3fcd4\":{\"m\":119,\"g\":118},\"a5978a20\":{\"m\":119,\"g\":118},\"e483c1ea\":{\"m\":119,\"g\":118},\"da681f35\":{\"m\":119,\"g\":118},\"9b0f725b\":{\"m\":119,\"g\":118},\"cde5a6e3\":{\"m\":119,\"g\":118},\"3e4c7da2\":{\"m\":119,\"g\":118},\"d88ac9bc\":{\"m\":119,\"g\":118},\"ce11dd82\":{\"m\":119,\"g\"
:118},\"9e87b60f\":{\"m\":119,\"g\":118},\"7780230a\":{\"m\":119,\"g\":118},\"dc01313d\":{\"m\":119,\"g\":118},\"7a7f99be\":{\"m\":119,\"g\":118},\"fd389df9\":{\"m\":119,\"g\":118},\"b0d1d717\":{\"m\":119,\"g\":118},\"c7962868\":{\"m\":119,\"g\":118},\"4f24ab17\":{\"m\":119,\"g\":118},\"64affab4\":{\"m\":119,\"g\":118},\"4c9bcb9d\":{\"m\":119,\"g\":118},\"86b04d25\":{\"m\":119,\"g\":118},\"55d75e11\":{\"m\":120,\"g\":121},\"3f4cc0af\":{\"m\":120,\"g\":121},\"cadfae66\":{\"m\":120,\"g\":121},\"da1766e4\":{\"m\":120,\"g\":121},\"a124b517\":{\"m\":120,\"g\":121},\"d05a968b\":{\"m\":120,\"g\":121},\"94aad0de\":{\"m\":120,\"g\":121},\"7ebc28f5\":{\"m\":120,\"g\":121},\"b89111d6\":{\"m\":120,\"g\":121},\"0b3b3e9a\":{\"m\":120,\"g\":121},\"a1d5bc4c\":{\"m\":120,\"g\":121},\"a8023891\":{\"m\":120,\"g\":121},\"0103f374\":{\"m\":120,\"g\":121},\"96a5a949\":{\"m\":120,\"g\":121},\"ea385ae8\":{\"m\":120,\"g\":121},\"9e949e58\":{\"m\":120,\"g\":121},\"6dbb569b\":{\"m\":120,\"g\":121},\"5994e6c3\":{\"m\":120,\"g\":121},\"3e6281d0\":{\"m\":120,\"g\":121},\"6371f7af\":{\"m\":120,\"g\":121},\"8491c794\":{\"m\":120,\"g\":121},\"bda3758f\":{\"m\":120,\"g\":121},\"7b36c47b\":{\"m\":120,\"g\":121},\"773d89da\":{\"m\":120,\"g\":121},\"03e7d949\":{\"m\":120,\"g\":121},\"ff604064\":{\"m\":120,\"g\":121},\"212f5e48\":{\"m\":120,\"g\":121},\"fe527812\":{\"m\":120,\"g\":121},\"97828878\":{\"m\":120,\"g\":121},\"c001deba\":{\"m\":120,\"g\":121},\"b4d2da10\":{\"m\":120,\"g\":121},\"b72f9f08\":{\"m\":120,\"g\":121},\"8e70064c\":{\"m\":120,\"g\":121},\"d98b81e2\":{\"m\":120,\"g\":121},\"bcecf27e\":{\"m\":120,\"g\":121},\"4b0ac1d5\":{\"m\":120,\"g\":121},\"8e987fa2\":{\"m\":120,\"g\":121},\"9e656dd3\":{\"m\":120,\"g\":121},\"3862661c\":{\"m\":120,\"g\":121},\"c8492978\":{\"m\":120,\"g\":121},\"428710c2\":{\"m\":120,\"g\":121},\"d9b31011\":{\"m\":120,\"g\":121},\"d0cff78f\":{\"m\":120,\"g\":121},\"e8b71445\":{\"m\":120,\"g\":121},\"4caca1ba\":{\"m\":120,\"g\":121},\"ceb105a7\":{\"m\":120,\"g\":121}
,\"22cbc9c0\":{\"m\":120,\"g\":121},\"3865afc5\":{\"m\":120,\"g\":121},\"4ea42f7c\":{\"m\":120,\"g\":121},\"ea13cb14\":{\"m\":120,\"g\":121},\"a04212f1\":{\"m\":120,\"g\":121},\"ce869793\":{\"m\":120,\"g\":121},\"22f55e1b\":{\"m\":120,\"g\":121},\"89824189\":{\"m\":120,\"g\":121},\"433c622e\":{\"m\":120,\"g\":121},\"20bd2271\":{\"m\":120,\"g\":121},\"64994980\":{\"m\":120,\"g\":121},\"729b2429\":{\"m\":120,\"g\":121},\"d7056c52\":{\"m\":120,\"g\":121},\"13bf565d\":{\"m\":120,\"g\":121},\"e51046be\":{\"m\":120,\"g\":121},\"4eeeae1e\":{\"m\":120,\"g\":121},\"f4b78d13\":{\"m\":120,\"g\":121},\"4463e90d\":{\"m\":120,\"g\":121},\"229f236d\":{\"m\":120,\"g\":121},\"5983e5bd\":{\"m\":120,\"g\":121},\"4b046a72\":{\"m\":120,\"g\":121},\"770d6312\":{\"m\":120,\"g\":121},\"14203432\":{\"m\":120,\"g\":121},\"d7f0d88f\":{\"m\":120,\"g\":121},\"fc86b18b\":{\"m\":120,\"g\":121},\"0bfa394a\":{\"m\":120,\"g\":121},\"e04340bf\":{\"m\":120,\"g\":121},\"84701338\":{\"m\":120,\"g\":121},\"93ef9a09\":{\"m\":120,\"g\":121},\"b04cd3d4\":{\"m\":120,\"g\":121},\"7ef5d8af\":{\"m\":120,\"g\":121},\"71d41212\":{\"m\":120,\"g\":121},\"b9fb74f3\":{\"m\":120,\"g\":121},\"e15b63a1\":{\"m\":120,\"g\":121},\"4060ed37\":{\"m\":120,\"g\":121},\"2342605e\":{\"m\":120,\"g\":121},\"dbf17a83\":{\"m\":120,\"g\":121},\"0f0c430e\":{\"m\":120,\"g\":121},\"8e797a47\":{\"m\":120,\"g\":121},\"aa3003f1\":{\"m\":120,\"g\":121},\"4793ec7d\":{\"m\":120,\"g\":121},\"92009bd2\":{\"m\":120,\"g\":121},\"4ef981e2\":{\"m\":120,\"g\":121},\"69ed8b67\":{\"m\":120,\"g\":121},\"1801cd19\":{\"m\":120,\"g\":121},\"ffc722a6\":{\"m\":120,\"g\":121},\"49afb3d9\":{\"m\":120,\"g\":121},\"f80371ff\":{\"m\":120,\"g\":121},\"62eff37b\":{\"m\":120,\"g\":121},\"47e12e08\":{\"m\":120,\"g\":121},\"823b4429\":{\"m\":120,\"g\":121},\"14a4d80e\":{\"m\":120,\"g\":121},\"41c10e67\":{\"m\":122,\"g\":127},\"0bfe1d14\":{\"m\":122,\"g\":127},\"5f98b7fe\":{\"m\":122,\"g\":127},\"a4bf5c6a\":{\"m\":122,\"g\":127},\"30ad1070\":{\"m\":122,\"g\":127},\"a8
0bcb5a\":{\"m\":122,\"g\":127},\"f7f9e41b\":{\"m\":122,\"g\":127},\"263eab9f\":{\"m\":122,\"g\":127},\"25257d8e\":{\"m\":122,\"g\":127},\"cf0c2415\":{\"m\":122,\"g\":127},\"5538e05c\":{\"m\":122,\"g\":127},\"c30ebb93\":{\"m\":122,\"g\":127},\"41efcaeb\":{\"m\":122,\"g\":127},\"70562969\":{\"m\":122,\"g\":127},\"b57dc169\":{\"m\":122,\"g\":127},\"0095e018\":{\"m\":122,\"g\":127},\"68486481\":{\"m\":122,\"g\":127},\"410225b7\":{\"m\":122,\"g\":127},\"2c9aebea\":{\"m\":122,\"g\":127},\"bc741073\":{\"m\":122,\"g\":127},\"2f6af1a3\":{\"m\":122,\"g\":127},\"50b6842b\":{\"m\":122,\"g\":127},\"2d5605e8\":{\"m\":122,\"g\":127},\"300b4c21\":{\"m\":122,\"g\":127},\"c0652d90\":{\"m\":122,\"g\":127},\"2e48584b\":{\"m\":122,\"g\":127},\"57cc5385\":{\"m\":122,\"g\":127},\"5cc0d25a\":{\"m\":122,\"g\":127},\"a076ec1a\":{\"m\":122,\"g\":127},\"72b5f3d0\":{\"m\":122,\"g\":127},\"2f766f38\":{\"m\":122,\"g\":127},\"069e490b\":{\"m\":122,\"g\":127},\"ab95d35f\":{\"m\":122,\"g\":127},\"34c286b8\":{\"m\":122,\"g\":127},\"9416ee60\":{\"m\":122,\"g\":127},\"d4a09ec9\":{\"m\":122,\"g\":127},\"662725b9\":{\"m\":122,\"g\":127},\"82cfcd3b\":{\"m\":122,\"g\":127},\"6c1a3f0c\":{\"m\":122,\"g\":127},\"62377548\":{\"m\":122,\"g\":121},\"96ac24c0\":{\"m\":122,\"g\":121},\"c0d02cf4\":{\"m\":122,\"g\":121},\"7d121448\":{\"m\":122,\"g\":121},\"6a63a985\":{\"m\":122,\"g\":121},\"4d2f17bd\":{\"m\":122,\"g\":121},\"7cd716f7\":{\"m\":122,\"g\":121},\"b7fdde4b\":{\"m\":122,\"g\":121},\"69bf8011\":{\"m\":122,\"g\":121},\"d8fcbaa3\":{\"m\":122,\"g\":121},\"2cf3d0f8\":{\"m\":122,\"g\":121},\"1ed1abfd\":{\"m\":122,\"g\":121},\"ecb9fa14\":{\"m\":122,\"g\":121},\"700daa34\":{\"m\":122,\"g\":121},\"39cee0fe\":{\"m\":122,\"g\":121},\"04e5b6fa\":{\"m\":122,\"g\":121},\"ce6b17c0\":{\"m\":122,\"g\":121},\"cafebef1\":{\"m\":122,\"g\":121},\"73dfd2df\":{\"m\":122,\"g\":121},\"df5192cf\":{\"m\":122,\"g\":121},\"78c43d88\":{\"m\":122,\"g\":121},\"7e28c67d\":{\"m\":122,\"g\":121},\"3edba9bc\":{\"m\":122,\"g\":121},\"e5ec976
4\":{\"m\":122,\"g\":121},\"621dfb88\":{\"m\":122,\"g\":121},\"fb52d35f\":{\"m\":122,\"g\":121},\"2b71531a\":{\"m\":122,\"g\":121},\"25c50498\":{\"m\":122,\"g\":121},\"8e2ac2e6\":{\"m\":122,\"g\":121},\"17a57fd8\":{\"m\":122,\"g\":121},\"32438eba\":{\"m\":122,\"g\":121},\"fed02a49\":{\"m\":122,\"g\":121},\"03b3e89a\":{\"m\":122,\"g\":121},\"9ff9fa7f\":{\"m\":122,\"g\":121},\"7ed8ba05\":{\"m\":122,\"g\":121},\"df08f346\":{\"m\":122,\"g\":121},\"db15148c\":{\"m\":122,\"g\":121},\"5259becd\":{\"m\":122,\"g\":121},\"ed1044ac\":{\"m\":122,\"g\":121},\"d717e73e\":{\"m\":122,\"g\":121},\"a1816187\":{\"m\":122,\"g\":121},\"e39628fd\":{\"m\":122,\"g\":121},\"bacb3825\":{\"m\":122,\"g\":121},\"b53d9e11\":{\"m\":122,\"g\":121},\"8a683821\":{\"m\":122,\"g\":121},\"52694b60\":{\"m\":122,\"g\":121},\"400bddf2\":{\"m\":122,\"g\":121},\"1e90fe2e\":{\"m\":122,\"g\":121},\"caa5d296\":{\"m\":122,\"g\":121},\"750940ae\":{\"m\":122,\"g\":121},\"42f8ea40\":{\"m\":122,\"g\":121},\"14cbe42f\":{\"m\":122,\"g\":121},\"685c0645\":{\"m\":122,\"g\":121},\"1357397a\":{\"m\":122,\"g\":121},\"42e1a72e\":{\"m\":122,\"g\":121},\"83a7c89c\":{\"m\":122,\"g\":121},\"0380ca82\":{\"m\":122,\"g\":121},\"ec92b0ce\":{\"m\":122,\"g\":121},\"e03b6bee\":{\"m\":122,\"g\":121},\"5e36a0b4\":{\"m\":122,\"g\":121},\"0297773a\":{\"m\":122,\"g\":121},\"587deb15\":{\"m\":122,\"g\":121},\"83087247\":{\"m\":122,\"g\":121},\"334543ff\":{\"m\":122,\"g\":121},\"c143f416\":{\"m\":122,\"g\":121},\"b48354c5\":{\"m\":122,\"g\":121},\"29195aaa\":{\"m\":122,\"g\":121},\"0ee831de\":{\"m\":122,\"g\":121},\"8d6ab1cb\":{\"m\":122,\"g\":121},\"84a9d0ea\":{\"m\":122,\"g\":121},\"737b58d6\":{\"m\":122,\"g\":121},\"77225d60\":{\"m\":122,\"g\":121},\"9c6e25d2\":{\"m\":122,\"g\":121},\"2a3763c3\":{\"m\":122,\"g\":121},\"fdd00295\":{\"m\":122,\"g\":121},\"25e73640\":{\"m\":122,\"g\":121},\"0da9845e\":{\"m\":122,\"g\":121},\"92885441\":{\"m\":122,\"g\":121},\"64cf868e\":{\"m\":122,\"g\":121},\"ea399527\":{\"m\":122,\"g\":121},\"41a11335\":{
\"m\":122,\"g\":121},\"a1f2dc90\":{\"m\":122,\"g\":121},\"ea961060\":{\"m\":122,\"g\":121},\"b1e13e7c\":{\"m\":122,\"g\":121},\"cc7b04a2\":{\"m\":122,\"g\":121},\"d85d6dba\":{\"m\":122,\"g\":121},\"c5642a7a\":{\"m\":122,\"g\":121},\"691c8534\":{\"m\":122,\"g\":121},\"d2b8c412\":{\"m\":122,\"g\":121},\"bf8f7a94\":{\"m\":122,\"g\":121},\"81a632ac\":{\"m\":122,\"g\":121},\"83b22400\":{\"m\":122,\"g\":121},\"285a8e69\":{\"m\":122,\"g\":121},\"813bd6f8\":{\"m\":122,\"g\":121},\"729f612d\":{\"m\":122,\"g\":121},\"899453ac\":{\"m\":122,\"g\":121},\"ce832d70\":{\"m\":122,\"g\":121},\"88596739\":{\"m\":122,\"g\":121},\"a6ea3add\":{\"m\":122,\"g\":121},\"326c84c4\":{\"m\":122,\"g\":121},\"8da608cc\":{\"m\":122,\"g\":121},\"9fc3e8aa\":{\"m\":122,\"g\":121},\"c11b34d5\":{\"m\":122,\"g\":121},\"05ad28f2\":{\"m\":122,\"g\":121},\"0cae873f\":{\"m\":122,\"g\":121},\"a8b91f6b\":{\"m\":122,\"g\":121},\"959d1ab8\":{\"m\":122,\"g\":121},\"6c1c1933\":{\"m\":122,\"g\":121},\"3029d301\":{\"m\":122,\"g\":121},\"f389f017\":{\"m\":122,\"g\":121},\"caa4819b\":{\"m\":122,\"g\":121},\"a88b006e\":{\"m\":122,\"g\":121},\"ce112c07\":{\"m\":122,\"g\":121},\"f7dc2f33\":{\"m\":122,\"g\":121},\"cd784faf\":{\"m\":122,\"g\":121},\"75c09e1f\":{\"m\":122,\"g\":121},\"09af0a7b\":{\"m\":122,\"g\":121},\"c8d385ce\":{\"m\":122,\"g\":121},\"09938e1f\":{\"m\":123,\"g\":127},\"23407983\":{\"m\":123,\"g\":127},\"1357ab02\":{\"m\":123,\"g\":127},\"44da7377\":{\"m\":123,\"g\":127},\"fb9582c4\":{\"m\":123,\"g\":127},\"d22d0447\":{\"m\":123,\"g\":127},\"887742a1\":{\"m\":123,\"g\":127},\"34f7564d\":{\"m\":123,\"g\":127},\"1cfbbc42\":{\"m\":123,\"g\":127},\"55dfb539\":{\"m\":123,\"g\":127},\"42889acb\":{\"m\":123,\"g\":127},\"211f4070\":{\"m\":123,\"g\":127},\"befa41a1\":{\"m\":123,\"g\":127},\"30b26ee9\":{\"m\":123,\"g\":127},\"aa797d01\":{\"m\":123,\"g\":127},\"7cee07a0\":{\"m\":123,\"g\":127},\"bb517fe3\":{\"m\":123,\"g\":127},\"0e82fd3d\":{\"m\":123,\"g\":127},\"b7d70411\":{\"m\":123,\"g\":127},\"ff0b64e1\":{\"m\"
:123,\"g\":127},\"d84790db\":{\"m\":123,\"g\":127},\"0678beaa\":{\"m\":123,\"g\":127},\"c2d4716d\":{\"m\":123,\"g\":127},\"c14cc47e\":{\"m\":123,\"g\":127},\"dbcf85b7\":{\"m\":123,\"g\":127},\"83804bc6\":{\"m\":123,\"g\":127},\"d5fa019c\":{\"m\":123,\"g\":127},\"fef3a6b6\":{\"m\":123,\"g\":127},\"173e0f70\":{\"m\":123,\"g\":127},\"f600866a\":{\"m\":123,\"g\":127},\"93be7e86\":{\"m\":123,\"g\":127},\"60b0754c\":{\"m\":123,\"g\":127},\"0b24af4d\":{\"m\":123,\"g\":127},\"a209fb05\":{\"m\":123,\"g\":127},\"48d6bea1\":{\"m\":123,\"g\":127},\"1689c0e3\":{\"m\":123,\"g\":127},\"193fbb0b\":{\"m\":123,\"g\":127},\"e607850f\":{\"m\":123,\"g\":127},\"15efbcb4\":{\"m\":123,\"g\":127},\"243c064d\":{\"m\":123,\"g\":127},\"0b41a293\":{\"m\":123,\"g\":127},\"d31d48b3\":{\"m\":123,\"g\":127},\"88342607\":{\"m\":123,\"g\":127},\"fd7a72d6\":{\"m\":123,\"g\":127},\"21a8fa16\":{\"m\":123,\"g\":127},\"7a21d8b2\":{\"m\":123,\"g\":127},\"d36639ee\":{\"m\":123,\"g\":127},\"6ef23b98\":{\"m\":123,\"g\":127},\"385599cb\":{\"m\":123,\"g\":127},\"952fbe47\":{\"m\":123,\"g\":127},\"edb25693\":{\"m\":123,\"g\":127},\"3529c061\":{\"m\":123,\"g\":127},\"ffb32a85\":{\"m\":123,\"g\":127},\"14d80648\":{\"m\":123,\"g\":127},\"ab8b83f7\":{\"m\":123,\"g\":127},\"de0b10cf\":{\"m\":123,\"g\":127},\"6e29446e\":{\"m\":123,\"g\":127},\"0c3543d7\":{\"m\":123,\"g\":127},\"6a3b9fd0\":{\"m\":123,\"g\":127},\"65f1d065\":{\"m\":123,\"g\":127},\"9434a0e5\":{\"m\":123,\"g\":127},\"20315697\":{\"m\":123,\"g\":127},\"c9db7911\":{\"m\":123,\"g\":127},\"15ed27d7\":{\"m\":123,\"g\":127},\"66fb9b13\":{\"m\":123,\"g\":127},\"819fc591\":{\"m\":123,\"g\":127},\"7efd8b3d\":{\"m\":123,\"g\":127},\"a920b9da\":{\"m\":123,\"g\":127},\"76196b3c\":{\"m\":123,\"g\":127},\"95191ebd\":{\"m\":123,\"g\":127},\"9a512cf9\":{\"m\":123,\"g\":127},\"3451fc32\":{\"m\":123,\"g\":127},\"c550ab91\":{\"m\":123,\"g\":127},\"086f0b79\":{\"m\":123,\"g\":127},\"0afd6832\":{\"m\":123,\"g\":127},\"6f858930\":{\"m\":123,\"g\":127},\"229256c5\":{\"m\":123,
\"g\":127},\"6b634493\":{\"m\":123,\"g\":127},\"756ad9ce\":{\"m\":123,\"g\":127},\"d2a8f71c\":{\"m\":123,\"g\":127},\"2b7bf11b\":{\"m\":123,\"g\":127},\"566ade03\":{\"m\":123,\"g\":127},\"69193f71\":{\"m\":123,\"g\":127},\"d5b6e50f\":{\"m\":123,\"g\":127},\"9632e48f\":{\"m\":123,\"g\":127},\"59cce594\":{\"m\":123,\"g\":127},\"795e98f8\":{\"m\":123,\"g\":127},\"358ae356\":{\"m\":123,\"g\":127},\"0c006b88\":{\"m\":124,\"g\":127},\"cd135bfe\":{\"m\":124,\"g\":127},\"fc84b073\":{\"m\":124,\"g\":127},\"8be0e1bc\":{\"m\":124,\"g\":127},\"8e1d6756\":{\"m\":124,\"g\":127},\"0da30dbc\":{\"m\":124,\"g\":127},\"58095cb0\":{\"m\":124,\"g\":127},\"fd3034da\":{\"m\":124,\"g\":127},\"fbbe16fa\":{\"m\":124,\"g\":127},\"6a1a64fa\":{\"m\":124,\"g\":127},\"e2715cf8\":{\"m\":124,\"g\":127},\"74630ba3\":{\"m\":124,\"g\":127},\"73e9a2ef\":{\"m\":124,\"g\":127},\"837b08eb\":{\"m\":124,\"g\":127},\"bb6a21cd\":{\"m\":124,\"g\":127},\"32ec68fa\":{\"m\":124,\"g\":127},\"3c0a6df8\":{\"m\":124,\"g\":127},\"2104d20e\":{\"m\":124,\"g\":127},\"f235498e\":{\"m\":124,\"g\":127},\"149dc9aa\":{\"m\":124,\"g\":127},\"9a954982\":{\"m\":124,\"g\":127},\"97be66c3\":{\"m\":124,\"g\":127},\"74243dff\":{\"m\":124,\"g\":127},\"4ea4c48b\":{\"m\":124,\"g\":127},\"cf5d27e3\":{\"m\":124,\"g\":127},\"9ec6031d\":{\"m\":124,\"g\":127},\"a5affb0c\":{\"m\":124,\"g\":127},\"5925d3d7\":{\"m\":124,\"g\":127},\"7ef1964a\":{\"m\":124,\"g\":127},\"b0476a06\":{\"m\":124,\"g\":127},\"ffba61a1\":{\"m\":124,\"g\":127},\"3c219eb0\":{\"m\":124,\"g\":127},\"1ffdcdc4\":{\"m\":124,\"g\":127},\"c7d57d5b\":{\"m\":124,\"g\":127},\"80802c4c\":{\"m\":124,\"g\":127},\"83b104ee\":{\"m\":124,\"g\":127},\"14127804\":{\"m\":124,\"g\":127},\"82f39dc1\":{\"m\":124,\"g\":127},\"627bac64\":{\"m\":124,\"g\":127},\"3651cfbf\":{\"m\":124,\"g\":127},\"c8547ecd\":{\"m\":124,\"g\":127},\"7bc1dae0\":{\"m\":124,\"g\":127},\"4fe53e58\":{\"m\":124,\"g\":127},\"fb2e816e\":{\"m\":124,\"g\":127},\"7c45b8b4\":{\"m\":124,\"g\":127},\"ba5b6823\":{\"m\":124,\"g\"
:127},\"508d2f7a\":{\"m\":124,\"g\":127},\"a889c854\":{\"m\":124,\"g\":127},\"4d84f886\":{\"m\":124,\"g\":127},\"dc4f5418\":{\"m\":124,\"g\":127},\"0648eb48\":{\"m\":124,\"g\":127},\"b88fab31\":{\"m\":124,\"g\":127},\"36942660\":{\"m\":124,\"g\":127},\"9f5e7018\":{\"m\":124,\"g\":127},\"cbf23dbb\":{\"m\":124,\"g\":127},\"6dade6c3\":{\"m\":124,\"g\":127},\"b419e20c\":{\"m\":124,\"g\":127},\"48641435\":{\"m\":124,\"g\":127},\"44b1b394\":{\"m\":124,\"g\":127},\"0711d150\":{\"m\":124,\"g\":127},\"303cc957\":{\"m\":125,\"g\":127},\"661c1c97\":{\"m\":125,\"g\":127},\"c022107f\":{\"m\":125,\"g\":127},\"56c83e0f\":{\"m\":125,\"g\":127},\"b51d46d0\":{\"m\":125,\"g\":127},\"10864731\":{\"m\":125,\"g\":127},\"665416f6\":{\"m\":125,\"g\":127},\"838bcb0d\":{\"m\":125,\"g\":127},\"f1f4c451\":{\"m\":125,\"g\":127},\"b0ee99dd\":{\"m\":125,\"g\":127},\"ddfcb7c8\":{\"m\":125,\"g\":127},\"58b12ccb\":{\"m\":125,\"g\":127},\"547de8c7\":{\"m\":125,\"g\":127},\"37c40a87\":{\"m\":125,\"g\":127},\"1240ac13\":{\"m\":125,\"g\":127},\"5639145f\":{\"m\":125,\"g\":127},\"afee2843\":{\"m\":125,\"g\":127},\"6f084880\":{\"m\":125,\"g\":127},\"611a4fd0\":{\"m\":125,\"g\":127},\"9ea2c686\":{\"m\":125,\"g\":127},\"e2a784ec\":{\"m\":125,\"g\":127},\"05559a4a\":{\"m\":125,\"g\":127},\"7bffc5dc\":{\"m\":125,\"g\":127},\"a5e5088d\":{\"m\":125,\"g\":127},\"ac19ce7e\":{\"m\":125,\"g\":127},\"95876d75\":{\"m\":125,\"g\":127},\"a30f1907\":{\"m\":125,\"g\":127},\"dc8a5a1c\":{\"m\":125,\"g\":127},\"e123648b\":{\"m\":125,\"g\":127},\"90401cf7\":{\"m\":125,\"g\":127},\"9cfe78dd\":{\"m\":125,\"g\":127},\"307e7a61\":{\"m\":125,\"g\":127},\"ddd1440d\":{\"m\":125,\"g\":127},\"d1be60c3\":{\"m\":125,\"g\":127},\"583bb180\":{\"m\":125,\"g\":127},\"1f2a6c69\":{\"m\":125,\"g\":127},\"61c7fe7a\":{\"m\":125,\"g\":127},\"83f89cc6\":{\"m\":125,\"g\":127},\"db24d346\":{\"m\":125,\"g\":127},\"885cfca2\":{\"m\":125,\"g\":127},\"4e916f98\":{\"m\":125,\"g\":127},\"4f65a646\":{\"m\":125,\"g\":127},\"f5b3ccd9\":{\"m\":125,\"g\":127}
,\"bb00e24f\":{\"m\":125,\"g\":127},\"210a9cab\":{\"m\":125,\"g\":127},\"b8ac4fcb\":{\"m\":125,\"g\":127},\"877cb528\":{\"m\":125,\"g\":127},\"3633f8b0\":{\"m\":125,\"g\":127},\"93cf60fc\":{\"m\":125,\"g\":127},\"b5e04173\":{\"m\":125,\"g\":127},\"8a821af7\":{\"m\":125,\"g\":127},\"4b1d163b\":{\"m\":125,\"g\":127},\"c21a3ec2\":{\"m\":125,\"g\":127},\"b142831a\":{\"m\":125,\"g\":127},\"d1340963\":{\"m\":125,\"g\":127},\"f290e801\":{\"m\":125,\"g\":127},\"9299a62f\":{\"m\":125,\"g\":127},\"49543be9\":{\"m\":125,\"g\":127},\"52362903\":{\"m\":125,\"g\":127},\"b2b26d43\":{\"m\":125,\"g\":127},\"d3a03aee\":{\"m\":125,\"g\":127},\"5f02b918\":{\"m\":125,\"g\":127},\"49653c88\":{\"m\":125,\"g\":127},\"44f594d8\":{\"m\":125,\"g\":127},\"2b6c4257\":{\"m\":125,\"g\":127},\"243ea585\":{\"m\":125,\"g\":127},\"6fee2c53\":{\"m\":125,\"g\":127},\"f1a9c72d\":{\"m\":125,\"g\":127},\"190002c6\":{\"m\":125,\"g\":127},\"0296f1cd\":{\"m\":125,\"g\":127},\"e039ff38\":{\"m\":125,\"g\":127},\"b8ddc296\":{\"m\":125,\"g\":127},\"0b88d520\":{\"m\":125,\"g\":127},\"d3d7f960\":{\"m\":125,\"g\":127},\"e4341872\":{\"m\":125,\"g\":127},\"fe19a580\":{\"m\":125,\"g\":127},\"55e8e399\":{\"m\":125,\"g\":127},\"32f79828\":{\"m\":125,\"g\":127},\"0f76976c\":{\"m\":125,\"g\":127},\"ae622790\":{\"m\":125,\"g\":127},\"5c9273c0\":{\"m\":125,\"g\":127},\"e316bcac\":{\"m\":125,\"g\":127},\"0fe9c1f7\":{\"m\":125,\"g\":127},\"61bfd9fa\":{\"m\":125,\"g\":127},\"c67fce16\":{\"m\":125,\"g\":127},\"bef37d6d\":{\"m\":125,\"g\":127},\"bc25ea67\":{\"m\":125,\"g\":127},\"1fa788ec\":{\"m\":125,\"g\":127},\"125f76ea\":{\"m\":125,\"g\":127},\"d8736c75\":{\"m\":125,\"g\":127},\"3b1cc466\":{\"m\":125,\"g\":127},\"0ee5ab5a\":{\"m\":125,\"g\":127},\"ed5e905c\":{\"m\":125,\"g\":127},\"d9c812d8\":{\"m\":125,\"g\":127},\"7257525c\":{\"m\":125,\"g\":127},\"34ba10ef\":{\"m\":125,\"g\":127},\"a119363f\":{\"m\":125,\"g\":127},\"fb314d7b\":{\"m\":125,\"g\":127},\"3a64844a\":{\"m\":125,\"g\":127},\"585c417f\":{\"m\":125,\"g\":127},\"b0
d1c21d\":{\"m\":125,\"g\":127},\"b07c5e40\":{\"m\":125,\"g\":127},\"1772671b\":{\"m\":125,\"g\":127},\"6e6009fb\":{\"m\":125,\"g\":127},\"88a2a340\":{\"m\":125,\"g\":127},\"78c58621\":{\"m\":125,\"g\":127},\"c3bb348d\":{\"m\":125,\"g\":127},\"4e234b4c\":{\"m\":125,\"g\":127},\"4cc725ac\":{\"m\":125,\"g\":127},\"2cb42dc1\":{\"m\":125,\"g\":127},\"ebaf86d4\":{\"m\":126,\"g\":127},\"8359f185\":{\"m\":126,\"g\":127},\"5324f37a\":{\"m\":126,\"g\":127},\"2864c49f\":{\"m\":126,\"g\":127},\"d26ec39f\":{\"m\":126,\"g\":127},\"9c546bfd\":{\"m\":126,\"g\":127},\"c9b58164\":{\"m\":126,\"g\":127},\"e5e65e3d\":{\"m\":126,\"g\":127},\"ffeb28ba\":{\"m\":126,\"g\":127},\"b40f605f\":{\"m\":126,\"g\":127},\"3cdec20c\":{\"m\":126,\"g\":127},\"d28caaf6\":{\"m\":126,\"g\":127},\"ad8d24c3\":{\"m\":126,\"g\":127},\"018123b5\":{\"m\":126,\"g\":127},\"4983b7e7\":{\"m\":126,\"g\":127},\"f825137f\":{\"m\":126,\"g\":127},\"ae68158f\":{\"m\":126,\"g\":127},\"a7cc02e3\":{\"m\":126,\"g\":127},\"8ece99a9\":{\"m\":126,\"g\":127},\"44e391b6\":{\"m\":126,\"g\":127},\"1a5c313f\":{\"m\":126,\"g\":127},\"7ea5b42d\":{\"m\":126,\"g\":127},\"5ded5e27\":{\"m\":126,\"g\":127},\"33d1aeb0\":{\"m\":126,\"g\":127},\"dd909a51\":{\"m\":126,\"g\":127},\"60cb7167\":{\"m\":126,\"g\":127},\"151e1368\":{\"m\":126,\"g\":127},\"2f9952cd\":{\"m\":126,\"g\":127},\"3e7cc273\":{\"m\":126,\"g\":127},\"8f01a12d\":{\"m\":126,\"g\":127},\"28b8c579\":{\"m\":126,\"g\":127},\"7b877ab8\":{\"m\":126,\"g\":127},\"0d4a4184\":{\"m\":126,\"g\":127},\"2ca25a8a\":{\"m\":126,\"g\":127},\"cc2e36c3\":{\"m\":126,\"g\":127},\"d8f7816a\":{\"m\":126,\"g\":127},\"99e25805\":{\"m\":126,\"g\":127},\"4a2768a8\":{\"m\":126,\"g\":127},\"9b247f73\":{\"m\":126,\"g\":127},\"36d14712\":{\"m\":126,\"g\":127},\"e0e6a6ef\":{\"m\":126,\"g\":127},\"e8114102\":{\"m\":126,\"g\":127},\"14a339fc\":{\"m\":126,\"g\":127},\"5f662e78\":{\"m\":126,\"g\":127},\"7f5055ed\":{\"m\":126,\"g\":127},\"63728b11\":{\"m\":126,\"g\":127},\"5de25f78\":{\"m\":126,\"g\":127},\"d52800d
b\":{\"m\":126,\"g\":127},\"38a704bc\":{\"m\":126,\"g\":127},\"a06c44f9\":{\"m\":126,\"g\":127},\"527b7d3f\":{\"m\":126,\"g\":127},\"f09eee03\":{\"m\":126,\"g\":127},\"e38994dd\":{\"m\":126,\"g\":127},\"4a78031a\":{\"m\":126,\"g\":127},\"ea10a9d1\":{\"m\":126,\"g\":127},\"71aea45c\":{\"m\":126,\"g\":127},\"e3b38d71\":{\"m\":126,\"g\":127},\"5c0cadd0\":{\"m\":126,\"g\":127},\"39b1d048\":{\"m\":126,\"g\":127},\"8db7fc41\":{\"m\":126,\"g\":127},\"fe92d4d8\":{\"m\":126,\"g\":127},\"fc8cda14\":{\"m\":126,\"g\":127},\"6a7322ff\":{\"m\":126,\"g\":127},\"c751cb38\":{\"m\":126,\"g\":127},\"3594815a\":{\"m\":126,\"g\":127},\"9caca6a4\":{\"m\":126,\"g\":127},\"08c805a8\":{\"m\":126,\"g\":127},\"f18ec927\":{\"m\":126,\"g\":127},\"aea88fa7\":{\"m\":126,\"g\":127},\"2fe4e69f\":{\"m\":126,\"g\":127},\"012bfc4f\":{\"m\":126,\"g\":127},\"0493775b\":{\"m\":126,\"g\":127},\"40b26b45\":{\"m\":126,\"g\":127},\"9840bf4f\":{\"m\":126,\"g\":127},\"7b2fb3d4\":{\"m\":128,\"g\":131},\"d64dd3e1\":{\"m\":128,\"g\":131},\"3ccd7fa6\":{\"m\":128,\"g\":131},\"254f62d8\":{\"m\":128,\"g\":131},\"2b8b9d84\":{\"m\":128,\"g\":131},\"e970892f\":{\"m\":128,\"g\":131},\"d5fa58c4\":{\"m\":128,\"g\":131},\"9f011f61\":{\"m\":128,\"g\":131},\"3b18fd4c\":{\"m\":128,\"g\":131},\"1869f25c\":{\"m\":128,\"g\":131},\"e019f233\":{\"m\":128,\"g\":131},\"b1c688fb\":{\"m\":128,\"g\":131},\"8e3663d4\":{\"m\":128,\"g\":131},\"9edb0e0d\":{\"m\":128,\"g\":131},\"95f43669\":{\"m\":128,\"g\":131},\"50691d7b\":{\"m\":128,\"g\":131},\"6afe3963\":{\"m\":128,\"g\":131},\"191f5c77\":{\"m\":128,\"g\":131},\"d7246708\":{\"m\":128,\"g\":131},\"efc5d8f5\":{\"m\":128,\"g\":131},\"ef32a252\":{\"m\":128,\"g\":131},\"9509c4cc\":{\"m\":128,\"g\":131},\"4e19c1d5\":{\"m\":128,\"g\":131},\"78a4b446\":{\"m\":128,\"g\":131},\"f9696641\":{\"m\":128,\"g\":131},\"f35f7f12\":{\"m\":128,\"g\":131},\"12c789eb\":{\"m\":128,\"g\":131},\"597d4160\":{\"m\":128,\"g\":131},\"1ca205f6\":{\"m\":128,\"g\":131},\"24a25ffa\":{\"m\":128,\"g\":131},\"db7299aa\":{
\"m\":128,\"g\":131},\"13366843\":{\"m\":128,\"g\":131},\"be353ffd\":{\"m\":128,\"g\":131},\"20e59f95\":{\"m\":128,\"g\":131},\"0d116b9a\":{\"m\":128,\"g\":131},\"4a56fa5c\":{\"m\":128,\"g\":131},\"51f9b962\":{\"m\":128,\"g\":131},\"daf494b6\":{\"m\":128,\"g\":131},\"37e8724e\":{\"m\":128,\"g\":131},\"10592e9c\":{\"m\":128,\"g\":131},\"2aec8b6e\":{\"m\":128,\"g\":131},\"0d41ddfb\":{\"m\":128,\"g\":131},\"37c87615\":{\"m\":128,\"g\":131},\"3f400f25\":{\"m\":128,\"g\":131},\"d91b16eb\":{\"m\":128,\"g\":131},\"4a10e37b\":{\"m\":128,\"g\":131},\"b051d76d\":{\"m\":128,\"g\":131},\"d52d992a\":{\"m\":128,\"g\":131},\"d971f228\":{\"m\":128,\"g\":131},\"1d3d42bd\":{\"m\":128,\"g\":131},\"8e9f05ec\":{\"m\":128,\"g\":131},\"bc083521\":{\"m\":128,\"g\":131},\"33f08a98\":{\"m\":128,\"g\":131},\"8e6083bf\":{\"m\":128,\"g\":131},\"2fbc78a0\":{\"m\":128,\"g\":131},\"f0b5ccf5\":{\"m\":128,\"g\":131},\"b732ffa4\":{\"m\":128,\"g\":131},\"56fc4830\":{\"m\":128,\"g\":131},\"2a96e302\":{\"m\":128,\"g\":131},\"8a437340\":{\"m\":128,\"g\":131},\"f0021c0d\":{\"m\":128,\"g\":131},\"7ee3e364\":{\"m\":128,\"g\":131},\"6d5e16fb\":{\"m\":128,\"g\":131},\"67e6f143\":{\"m\":128,\"g\":131},\"34851471\":{\"m\":128,\"g\":131},\"c2083116\":{\"m\":128,\"g\":131},\"10285ec2\":{\"m\":128,\"g\":131},\"9b3fc186\":{\"m\":128,\"g\":131},\"172c71a2\":{\"m\":128,\"g\":127},\"af373636\":{\"m\":128,\"g\":127},\"a5be6ef9\":{\"m\":128,\"g\":127},\"8f4e18a2\":{\"m\":128,\"g\":127},\"e9681444\":{\"m\":128,\"g\":127},\"eae59b33\":{\"m\":128,\"g\":127},\"14dc0523\":{\"m\":128,\"g\":127},\"b2236691\":{\"m\":128,\"g\":127},\"dcc47a56\":{\"m\":128,\"g\":127},\"6448b4cd\":{\"m\":128,\"g\":127},\"5ae0ac42\":{\"m\":128,\"g\":127},\"22f641ab\":{\"m\":128,\"g\":127},\"0997c78d\":{\"m\":128,\"g\":127},\"fd3be107\":{\"m\":128,\"g\":127},\"665f43bd\":{\"m\":128,\"g\":127},\"a53f2d6c\":{\"m\":128,\"g\":127},\"a7002e61\":{\"m\":128,\"g\":127},\"84e151ac\":{\"m\":128,\"g\":127},\"875a25dd\":{\"m\":128,\"g\":127},\"fc55b45e\":{\"m\"
:128,\"g\":127},\"0050ff25\":{\"m\":128,\"g\":127},\"af9f71f9\":{\"m\":128,\"g\":127},\"15264232\":{\"m\":128,\"g\":127},\"f8d3d80f\":{\"m\":128,\"g\":127},\"5027739f\":{\"m\":128,\"g\":127},\"3701f34d\":{\"m\":128,\"g\":127},\"821fb060\":{\"m\":128,\"g\":127},\"ace27c0c\":{\"m\":128,\"g\":127},\"ed1d18d4\":{\"m\":128,\"g\":127},\"e7b57b0d\":{\"m\":128,\"g\":127},\"49141df9\":{\"m\":128,\"g\":127},\"7b79cc4f\":{\"m\":128,\"g\":127},\"385ff0e5\":{\"m\":128,\"g\":127},\"fc5da1e8\":{\"m\":128,\"g\":127},\"04848ba7\":{\"m\":128,\"g\":127},\"9bc6a9ad\":{\"m\":128,\"g\":127},\"7cdaedb8\":{\"m\":128,\"g\":127},\"1f134f85\":{\"m\":128,\"g\":127},\"5f72d36d\":{\"m\":128,\"g\":127},\"7a2254b2\":{\"m\":128,\"g\":127},\"b5904999\":{\"m\":128,\"g\":127},\"922525ee\":{\"m\":128,\"g\":127},\"ee3e337c\":{\"m\":128,\"g\":127},\"e523e216\":{\"m\":128,\"g\":127},\"2ce23777\":{\"m\":128,\"g\":127},\"19f6a33c\":{\"m\":128,\"g\":127},\"e9c0c558\":{\"m\":128,\"g\":127},\"9b41f31a\":{\"m\":128,\"g\":127},\"9db3add3\":{\"m\":128,\"g\":127},\"c9e5799b\":{\"m\":128,\"g\":127},\"2966367a\":{\"m\":128,\"g\":127},\"87791007\":{\"m\":128,\"g\":127},\"dd192a55\":{\"m\":128,\"g\":127},\"bfe638f7\":{\"m\":128,\"g\":127},\"4ac65e3c\":{\"m\":128,\"g\":127},\"0779c3d1\":{\"m\":128,\"g\":127},\"5c2d72ba\":{\"m\":128,\"g\":127},\"9bd511a5\":{\"m\":128,\"g\":127},\"85b8c5c4\":{\"m\":128,\"g\":127},\"c8b7516f\":{\"m\":128,\"g\":127},\"67e9d287\":{\"m\":128,\"g\":127},\"e7e89349\":{\"m\":128,\"g\":127},\"aead0ef5\":{\"m\":128,\"g\":127},\"66640835\":{\"m\":128,\"g\":127},\"c2d69e8b\":{\"m\":128,\"g\":127},\"4c1e909a\":{\"m\":128,\"g\":127},\"2bb0317e\":{\"m\":128,\"g\":127},\"86255f27\":{\"m\":128,\"g\":127},\"e4b29370\":{\"m\":128,\"g\":127},\"7a8524b4\":{\"m\":128,\"g\":127},\"909d0d38\":{\"m\":128,\"g\":127},\"e42df37d\":{\"m\":128,\"g\":127},\"6d21392b\":{\"m\":128,\"g\":127},\"a1cb717d\":{\"m\":128,\"g\":127},\"c9456491\":{\"m\":128,\"g\":127},\"c4b74c1d\":{\"m\":128,\"g\":127},\"7aa44390\":{\"m\":128,
\"g\":127},\"4eda9969\":{\"m\":128,\"g\":127},\"03a7e6f4\":{\"m\":128,\"g\":127},\"2cdde3d4\":{\"m\":128,\"g\":127},\"401ed0c5\":{\"m\":128,\"g\":127},\"4ef43905\":{\"m\":128,\"g\":127},\"d646cf63\":{\"m\":128,\"g\":127},\"4edb2401\":{\"m\":128,\"g\":127},\"706502ff\":{\"m\":128,\"g\":127},\"c2e56dad\":{\"m\":128,\"g\":127},\"9c1c5c6d\":{\"m\":128,\"g\":127},\"2d531946\":{\"m\":128,\"g\":127},\"7ae368ef\":{\"m\":129,\"g\":131},\"c4e20cad\":{\"m\":129,\"g\":131},\"5dad1ff1\":{\"m\":129,\"g\":131},\"ca52ed42\":{\"m\":129,\"g\":131},\"5c03aa3e\":{\"m\":129,\"g\":131},\"253be18e\":{\"m\":129,\"g\":131},\"084b06e7\":{\"m\":129,\"g\":131},\"fc6fb550\":{\"m\":129,\"g\":131},\"7c38eca1\":{\"m\":129,\"g\":131},\"427b08e2\":{\"m\":129,\"g\":131},\"0141ca37\":{\"m\":129,\"g\":131},\"25a6be49\":{\"m\":129,\"g\":131},\"df1f3124\":{\"m\":129,\"g\":131},\"c5947ecd\":{\"m\":129,\"g\":131},\"9530b766\":{\"m\":129,\"g\":131},\"3067b3f0\":{\"m\":129,\"g\":131},\"e0ec42c7\":{\"m\":129,\"g\":131},\"9c9d7091\":{\"m\":129,\"g\":131},\"51a86ce6\":{\"m\":129,\"g\":131},\"64092c8b\":{\"m\":129,\"g\":131},\"63b9300f\":{\"m\":129,\"g\":131},\"21ec99be\":{\"m\":129,\"g\":131},\"383689e3\":{\"m\":129,\"g\":131},\"236a7c23\":{\"m\":129,\"g\":131},\"3dabd609\":{\"m\":129,\"g\":131},\"73df5253\":{\"m\":129,\"g\":131},\"e6420100\":{\"m\":129,\"g\":131},\"c9e20901\":{\"m\":129,\"g\":131},\"106df4ea\":{\"m\":129,\"g\":131},\"427a19b6\":{\"m\":129,\"g\":131},\"1f930cd2\":{\"m\":129,\"g\":131},\"11ce0516\":{\"m\":129,\"g\":131},\"cd4151ab\":{\"m\":129,\"g\":131},\"8fe8b635\":{\"m\":129,\"g\":131},\"8a7b1b83\":{\"m\":129,\"g\":131},\"26aebf83\":{\"m\":129,\"g\":131},\"3ab8ae68\":{\"m\":129,\"g\":131},\"03888b9d\":{\"m\":129,\"g\":131},\"796d82b1\":{\"m\":129,\"g\":131},\"1da59e83\":{\"m\":129,\"g\":131},\"1d66a14c\":{\"m\":129,\"g\":131},\"02af51e4\":{\"m\":129,\"g\":131},\"eb500884\":{\"m\":129,\"g\":131},\"079b1738\":{\"m\":129,\"g\":131},\"57f933fd\":{\"m\":129,\"g\":131},\"1f2b84d2\":{\"m\":129,\"g\"
:131},\"e7d6027e\":{\"m\":129,\"g\":131},\"07821352\":{\"m\":129,\"g\":131},\"edbeaf3b\":{\"m\":129,\"g\":131},\"d9dca282\":{\"m\":129,\"g\":131},\"9325f945\":{\"m\":129,\"g\":131},\"41b7aab8\":{\"m\":129,\"g\":131},\"491f4fe8\":{\"m\":129,\"g\":131},\"34035d8c\":{\"m\":129,\"g\":131},\"92ca6295\":{\"m\":129,\"g\":131},\"ec92d7f1\":{\"m\":129,\"g\":131},\"79b389da\":{\"m\":129,\"g\":131},\"3de09aad\":{\"m\":129,\"g\":131},\"2e8f54e6\":{\"m\":129,\"g\":131},\"e55731b6\":{\"m\":129,\"g\":131},\"45264554\":{\"m\":129,\"g\":131},\"c4293f59\":{\"m\":129,\"g\":131},\"a2423052\":{\"m\":129,\"g\":131},\"bc3d2a85\":{\"m\":129,\"g\":131},\"d815d002\":{\"m\":129,\"g\":131},\"fa9021b2\":{\"m\":129,\"g\":131},\"9c800728\":{\"m\":129,\"g\":131},\"630a6930\":{\"m\":129,\"g\":131},\"7ce8faae\":{\"m\":129,\"g\":131},\"de153cf7\":{\"m\":129,\"g\":131},\"f4a0c5c7\":{\"m\":129,\"g\":131},\"0f8e5394\":{\"m\":129,\"g\":131},\"e8ba5a66\":{\"m\":129,\"g\":131},\"a2960bdd\":{\"m\":129,\"g\":131},\"487c8d4d\":{\"m\":129,\"g\":131},\"f87b8eab\":{\"m\":129,\"g\":131},\"e8542db5\":{\"m\":129,\"g\":131},\"6df1e8d6\":{\"m\":129,\"g\":131},\"4addb602\":{\"m\":129,\"g\":131},\"bd0e6908\":{\"m\":129,\"g\":131},\"0825d7f4\":{\"m\":129,\"g\":131},\"0b9dbea5\":{\"m\":129,\"g\":131},\"982db4eb\":{\"m\":129,\"g\":131},\"f5f3a5d9\":{\"m\":129,\"g\":131},\"f138ae57\":{\"m\":129,\"g\":131},\"decb4896\":{\"m\":129,\"g\":131},\"c72f0756\":{\"m\":129,\"g\":131},\"f1115cf5\":{\"m\":129,\"g\":131},\"412160f4\":{\"m\":129,\"g\":131},\"7b03cc64\":{\"m\":129,\"g\":131},\"9872a677\":{\"m\":129,\"g\":131},\"c15c864b\":{\"m\":129,\"g\":131},\"dc7bdc73\":{\"m\":129,\"g\":131},\"0a9d6453\":{\"m\":129,\"g\":131},\"340c613a\":{\"m\":129,\"g\":131},\"36b729c2\":{\"m\":129,\"g\":131},\"67e6ef4b\":{\"m\":129,\"g\":131},\"65ba5ab8\":{\"m\":129,\"g\":131},\"990023e5\":{\"m\":129,\"g\":131},\"5ddd2f6b\":{\"m\":129,\"g\":131},\"0ae4b1ad\":{\"m\":129,\"g\":131},\"94cd64a7\":{\"m\":129,\"g\":131},\"b870271a\":{\"m\":129,\"g\":131}
,\"22ee9b01\":{\"m\":129,\"g\":131},\"9d0e5f1f\":{\"m\":129,\"g\":131},\"1d3d8b34\":{\"m\":129,\"g\":131},\"155a9e72\":{\"m\":129,\"g\":131},\"d7cb08c5\":{\"m\":129,\"g\":131},\"3339c810\":{\"m\":129,\"g\":131},\"d6c88d51\":{\"m\":129,\"g\":131},\"f03ea34a\":{\"m\":129,\"g\":131},\"4cafc835\":{\"m\":129,\"g\":131},\"c6a52f44\":{\"m\":129,\"g\":131},\"c6d34a06\":{\"m\":129,\"g\":131},\"848ee570\":{\"m\":129,\"g\":131},\"ce6b7dfc\":{\"m\":129,\"g\":131},\"0fe74af5\":{\"m\":129,\"g\":131},\"0a362d65\":{\"m\":129,\"g\":131},\"143b57b8\":{\"m\":129,\"g\":131},\"f446b51c\":{\"m\":129,\"g\":131},\"a102a050\":{\"m\":129,\"g\":131},\"6bad6a36\":{\"m\":129,\"g\":131},\"11b6217a\":{\"m\":129,\"g\":131},\"0b0b2607\":{\"m\":129,\"g\":131},\"841eb29d\":{\"m\":129,\"g\":131},\"45cf5758\":{\"m\":129,\"g\":131},\"0e8ce1e8\":{\"m\":129,\"g\":131},\"ea1e9f6b\":{\"m\":129,\"g\":131},\"f6e37d3e\":{\"m\":129,\"g\":131},\"ab9a46d4\":{\"m\":129,\"g\":131},\"621061f0\":{\"m\":129,\"g\":131},\"7daddcdb\":{\"m\":129,\"g\":131},\"91d249cd\":{\"m\":129,\"g\":131},\"95102896\":{\"m\":129,\"g\":131},\"3543a04a\":{\"m\":129,\"g\":131},\"e12c78aa\":{\"m\":129,\"g\":131},\"bce40fa2\":{\"m\":129,\"g\":131},\"051ad833\":{\"m\":129,\"g\":131},\"4c9f7c97\":{\"m\":129,\"g\":131},\"63b05621\":{\"m\":129,\"g\":131},\"21af8e73\":{\"m\":129,\"g\":131},\"7ab548ef\":{\"m\":129,\"g\":131},\"bab033b9\":{\"m\":129,\"g\":131},\"63500426\":{\"m\":129,\"g\":131},\"25758647\":{\"m\":129,\"g\":131},\"2bc8ee8b\":{\"m\":129,\"g\":131},\"ab843ced\":{\"m\":129,\"g\":131},\"6edffc63\":{\"m\":129,\"g\":131},\"077ca70e\":{\"m\":129,\"g\":131},\"7cb04dc0\":{\"m\":129,\"g\":131},\"5443db87\":{\"m\":129,\"g\":131},\"9f340ab1\":{\"m\":129,\"g\":131},\"70c6f951\":{\"m\":129,\"g\":131},\"d941a3be\":{\"m\":129,\"g\":131},\"b12c9e5c\":{\"m\":129,\"g\":131},\"e9e90460\":{\"m\":129,\"g\":131},\"6330d664\":{\"m\":129,\"g\":131},\"91e8dc37\":{\"m\":129,\"g\":131},\"231df4b0\":{\"m\":129,\"g\":131},\"b087ef8b\":{\"m\":129,\"g\":131},\"51
55016b\":{\"m\":129,\"g\":131},\"082b54c6\":{\"m\":129,\"g\":131},\"a8ef4d18\":{\"m\":129,\"g\":131},\"15ff6982\":{\"m\":129,\"g\":131},\"9adef42c\":{\"m\":129,\"g\":131},\"685b9d82\":{\"m\":129,\"g\":131},\"a223402f\":{\"m\":129,\"g\":131},\"44d0a848\":{\"m\":129,\"g\":131},\"5b7da0f5\":{\"m\":129,\"g\":131},\"697a77bf\":{\"m\":129,\"g\":131},\"0a186924\":{\"m\":129,\"g\":131},\"b6312e62\":{\"m\":129,\"g\":131},\"779cbc6e\":{\"m\":129,\"g\":131},\"67c8c867\":{\"m\":129,\"g\":131},\"69a03bc3\":{\"m\":129,\"g\":131},\"66f242b9\":{\"m\":129,\"g\":131},\"8a9b8b84\":{\"m\":129,\"g\":131},\"7e964b51\":{\"m\":129,\"g\":131},\"a0d9f6cd\":{\"m\":129,\"g\":131},\"8308cd36\":{\"m\":129,\"g\":131},\"e0e8a996\":{\"m\":129,\"g\":131},\"5102d009\":{\"m\":129,\"g\":131},\"0dd759e0\":{\"m\":129,\"g\":131},\"5e70880e\":{\"m\":129,\"g\":131},\"9dab534b\":{\"m\":129,\"g\":131},\"262c3c1f\":{\"m\":129,\"g\":131},\"6c190cbd\":{\"m\":129,\"g\":131},\"eff6a07c\":{\"m\":129,\"g\":131},\"b704b0a9\":{\"m\":129,\"g\":131},\"15729dbc\":{\"m\":129,\"g\":131},\"21b0582d\":{\"m\":129,\"g\":131},\"5795da5e\":{\"m\":129,\"g\":131},\"540d6fee\":{\"m\":129,\"g\":131},\"5a8adca9\":{\"m\":129,\"g\":131},\"007c3e23\":{\"m\":129,\"g\":131},\"7130ad3a\":{\"m\":129,\"g\":131},\"35a4c21a\":{\"m\":129,\"g\":131},\"846ba3c6\":{\"m\":129,\"g\":131},\"18fb5158\":{\"m\":129,\"g\":131},\"ca5c8b16\":{\"m\":129,\"g\":131},\"f33e5d1e\":{\"m\":129,\"g\":131},\"13e5beea\":{\"m\":129,\"g\":131},\"c53e729d\":{\"m\":129,\"g\":131},\"03a26557\":{\"m\":129,\"g\":131},\"873382a9\":{\"m\":129,\"g\":131},\"391a863b\":{\"m\":129,\"g\":131},\"36b1bcd2\":{\"m\":129,\"g\":131},\"64a11303\":{\"m\":129,\"g\":131},\"1ab6ce0e\":{\"m\":129,\"g\":131},\"e99ca6ac\":{\"m\":129,\"g\":131},\"fcccaf90\":{\"m\":129,\"g\":131},\"215a97fa\":{\"m\":129,\"g\":131},\"5eed5fc0\":{\"m\":129,\"g\":131},\"808b6dfd\":{\"m\":129,\"g\":131},\"4852aa05\":{\"m\":129,\"g\":131},\"f922bfd5\":{\"m\":129,\"g\":131},\"dfd7ab96\":{\"m\":129,\"g\":131},\"46673b4
2\":{\"m\":129,\"g\":131},\"3421d049\":{\"m\":129,\"g\":131},\"d3d404d3\":{\"m\":129,\"g\":131},\"64225a8a\":{\"m\":129,\"g\":131},\"d64bf6c6\":{\"m\":129,\"g\":131},\"432ecf84\":{\"m\":129,\"g\":131},\"0b3f002d\":{\"m\":129,\"g\":131},\"6f094def\":{\"m\":129,\"g\":131},\"59464dbf\":{\"m\":129,\"g\":131},\"1f7fcc10\":{\"m\":129,\"g\":131},\"dbab5d50\":{\"m\":129,\"g\":131},\"760c20b3\":{\"m\":129,\"g\":131},\"407cb3ce\":{\"m\":129,\"g\":131},\"7cc43bd4\":{\"m\":129,\"g\":131},\"cce2d748\":{\"m\":129,\"g\":131},\"c1dd9a95\":{\"m\":129,\"g\":131},\"ed8786b0\":{\"m\":129,\"g\":131},\"f9fe0630\":{\"m\":129,\"g\":131},\"8ff3ef1f\":{\"m\":129,\"g\":131},\"da182e4b\":{\"m\":129,\"g\":131},\"83e72077\":{\"m\":129,\"g\":131},\"a2c388ba\":{\"m\":129,\"g\":131},\"9384fa27\":{\"m\":129,\"g\":131},\"173e73fa\":{\"m\":129,\"g\":131},\"a164259e\":{\"m\":129,\"g\":131},\"b0a26ba6\":{\"m\":129,\"g\":131},\"de430b67\":{\"m\":129,\"g\":131},\"4b45d556\":{\"m\":129,\"g\":131},\"db0ffc09\":{\"m\":129,\"g\":131},\"eb1d8854\":{\"m\":129,\"g\":131},\"bf108692\":{\"m\":129,\"g\":131},\"e83bd1fa\":{\"m\":129,\"g\":131},\"9dc15d85\":{\"m\":129,\"g\":131},\"fafaa2cc\":{\"m\":129,\"g\":131},\"9b4b3441\":{\"m\":129,\"g\":131},\"94216a9c\":{\"m\":129,\"g\":131},\"9535015d\":{\"m\":129,\"g\":131},\"a3b578fc\":{\"m\":129,\"g\":131},\"b60e769d\":{\"m\":129,\"g\":131},\"a95a3807\":{\"m\":129,\"g\":131},\"98b38de3\":{\"m\":129,\"g\":131},\"a146f833\":{\"m\":129,\"g\":131},\"1dd9a6ae\":{\"m\":129,\"g\":131},\"8ef11569\":{\"m\":129,\"g\":131},\"ecefc790\":{\"m\":129,\"g\":131},\"aeac6220\":{\"m\":129,\"g\":131},\"04b52fa8\":{\"m\":129,\"g\":131},\"e5c0f591\":{\"m\":129,\"g\":131},\"981ca831\":{\"m\":129,\"g\":131},\"414248e0\":{\"m\":129,\"g\":131},\"f56b9b42\":{\"m\":129,\"g\":131},\"b2f7b08c\":{\"m\":129,\"g\":131},\"75222bfe\":{\"m\":129,\"g\":131},\"9ea19533\":{\"m\":129,\"g\":131},\"dbf22152\":{\"m\":129,\"g\":131},\"d5e03468\":{\"m\":129,\"g\":131},\"a22104a6\":{\"m\":129,\"g\":131},\"4683e244\":{
\"m\":129,\"g\":131},\"9054e844\":{\"m\":129,\"g\":131},\"18403f6b\":{\"m\":129,\"g\":131},\"2892265d\":{\"m\":129,\"g\":131},\"c9bd1aca\":{\"m\":129,\"g\":131},\"618ca238\":{\"m\":129,\"g\":131},\"5c291549\":{\"m\":129,\"g\":131},\"aaa40a9b\":{\"m\":129,\"g\":131},\"dd70cf99\":{\"m\":129,\"g\":131},\"d4593964\":{\"m\":129,\"g\":131},\"53fffefd\":{\"m\":129,\"g\":131},\"b964ce61\":{\"m\":129,\"g\":131},\"ac5505b0\":{\"m\":129,\"g\":131},\"a90435c0\":{\"m\":129,\"g\":131},\"5354d7b7\":{\"m\":129,\"g\":131},\"e0148677\":{\"m\":129,\"g\":131},\"04793508\":{\"m\":129,\"g\":131},\"cad78789\":{\"m\":129,\"g\":131},\"b29769f3\":{\"m\":129,\"g\":131},\"3990b84b\":{\"m\":129,\"g\":131},\"5a4394a3\":{\"m\":129,\"g\":131},\"dd303614\":{\"m\":129,\"g\":131},\"86312468\":{\"m\":129,\"g\":131},\"5625e32c\":{\"m\":129,\"g\":131},\"ca548d83\":{\"m\":129,\"g\":131},\"3397bcee\":{\"m\":129,\"g\":131},\"3e804bb0\":{\"m\":129,\"g\":131},\"a22de641\":{\"m\":129,\"g\":131},\"ac438226\":{\"m\":129,\"g\":131},\"a92afb00\":{\"m\":129,\"g\":131},\"0eea17e3\":{\"m\":129,\"g\":131},\"38052432\":{\"m\":129,\"g\":131},\"8bfce9b0\":{\"m\":129,\"g\":131},\"b41afa37\":{\"m\":129,\"g\":131},\"94ae816f\":{\"m\":129,\"g\":131},\"53620a1b\":{\"m\":129,\"g\":131},\"59b4d7f8\":{\"m\":129,\"g\":131},\"a56f7702\":{\"m\":129,\"g\":131},\"1b48e1b9\":{\"m\":129,\"g\":131},\"a24aefe5\":{\"m\":129,\"g\":131},\"45c572c5\":{\"m\":129,\"g\":131},\"681b9e64\":{\"m\":129,\"g\":131},\"dab06b50\":{\"m\":129,\"g\":131},\"85ffce30\":{\"m\":129,\"g\":131},\"e94ef9fc\":{\"m\":129,\"g\":131},\"964cdedc\":{\"m\":129,\"g\":131},\"aa6e2c8a\":{\"m\":129,\"g\":131},\"1776dce5\":{\"m\":129,\"g\":131},\"99e13d18\":{\"m\":129,\"g\":131},\"dc836909\":{\"m\":129,\"g\":131},\"323fed5c\":{\"m\":129,\"g\":131},\"eff7df6d\":{\"m\":129,\"g\":131},\"5e7f91d4\":{\"m\":129,\"g\":131},\"a34d3abb\":{\"m\":129,\"g\":131},\"6d0e0b9b\":{\"m\":129,\"g\":131},\"589d9ad5\":{\"m\":129,\"g\":131},\"a244c030\":{\"m\":129,\"g\":131},\"1bb063aa\":{\"m\"
:129,\"g\":131},\"b30f63c4\":{\"m\":129,\"g\":131},\"90a01335\":{\"m\":129,\"g\":131},\"43602790\":{\"m\":129,\"g\":131},\"b537ac0d\":{\"m\":129,\"g\":131},\"eda2f700\":{\"m\":129,\"g\":131},\"8c212a20\":{\"m\":129,\"g\":131},\"d754ce97\":{\"m\":129,\"g\":131},\"475962a1\":{\"m\":129,\"g\":131},\"bfcf15a1\":{\"m\":129,\"g\":131},\"fb04d434\":{\"m\":129,\"g\":131},\"6be65ae4\":{\"m\":129,\"g\":131},\"750084ae\":{\"m\":129,\"g\":131},\"64480ec7\":{\"m\":129,\"g\":131},\"db2d362d\":{\"m\":129,\"g\":131},\"81e86992\":{\"m\":129,\"g\":131},\"c4db77f8\":{\"m\":129,\"g\":131},\"c0a2513b\":{\"m\":129,\"g\":131},\"3ae664d7\":{\"m\":129,\"g\":131},\"c56fc424\":{\"m\":129,\"g\":131},\"3f1cfd87\":{\"m\":129,\"g\":131},\"b5344b31\":{\"m\":129,\"g\":131},\"6b262ac8\":{\"m\":129,\"g\":131},\"ada8ce1f\":{\"m\":129,\"g\":131},\"5a2c7039\":{\"m\":129,\"g\":131},\"42028af6\":{\"m\":129,\"g\":131},\"7291c72e\":{\"m\":129,\"g\":131},\"6bc30628\":{\"m\":129,\"g\":131},\"fa924410\":{\"m\":129,\"g\":131},\"fc9efdcb\":{\"m\":129,\"g\":131},\"2dec555d\":{\"m\":129,\"g\":131},\"acde21d8\":{\"m\":129,\"g\":131},\"4528cb7d\":{\"m\":129,\"g\":131},\"a352e833\":{\"m\":129,\"g\":131},\"852eb6ce\":{\"m\":129,\"g\":131},\"7af9b88c\":{\"m\":129,\"g\":131},\"2847e5c4\":{\"m\":129,\"g\":131},\"c8ede0e9\":{\"m\":129,\"g\":131},\"19729f72\":{\"m\":129,\"g\":131},\"b51f9bbe\":{\"m\":129,\"g\":131},\"4a8442af\":{\"m\":129,\"g\":131},\"7dcf910d\":{\"m\":129,\"g\":131},\"c7b37b70\":{\"m\":129,\"g\":131},\"10e0b83a\":{\"m\":129,\"g\":131},\"bc42c8c4\":{\"m\":129,\"g\":131},\"2e3a69ae\":{\"m\":129,\"g\":131},\"21370ef7\":{\"m\":129,\"g\":131},\"c3c4da71\":{\"m\":129,\"g\":131},\"dc694624\":{\"m\":129,\"g\":131},\"48ca9f75\":{\"m\":129,\"g\":131},\"127d59cd\":{\"m\":129,\"g\":131},\"af6bcadc\":{\"m\":129,\"g\":131},\"67fca6b2\":{\"m\":129,\"g\":131},\"f88b2aa6\":{\"m\":129,\"g\":131},\"17b24aca\":{\"m\":129,\"g\":131},\"bfaf0b86\":{\"m\":129,\"g\":131},\"83756a4b\":{\"m\":129,\"g\":131},\"a3557949\":{\"m\":129,
\"g\":131},\"e72cf136\":{\"m\":129,\"g\":131},\"196b940a\":{\"m\":129,\"g\":131},\"a1e1e533\":{\"m\":129,\"g\":131},\"d4a4dcdf\":{\"m\":129,\"g\":131},\"97ba2c2d\":{\"m\":129,\"g\":131},\"e197bef5\":{\"m\":129,\"g\":131},\"8900f996\":{\"m\":129,\"g\":131},\"ba9102f9\":{\"m\":129,\"g\":131},\"b638abba\":{\"m\":129,\"g\":131},\"37980559\":{\"m\":129,\"g\":131},\"075ba74d\":{\"m\":129,\"g\":131},\"f7be98e1\":{\"m\":129,\"g\":131},\"9a1a9a42\":{\"m\":129,\"g\":131},\"6c2e5fcd\":{\"m\":129,\"g\":131},\"0d2d6878\":{\"m\":129,\"g\":131},\"9f59194f\":{\"m\":129,\"g\":131},\"cf1f0166\":{\"m\":129,\"g\":131},\"10969ae4\":{\"m\":129,\"g\":131},\"92ad2ff9\":{\"m\":129,\"g\":131},\"9b64f6f3\":{\"m\":129,\"g\":131},\"c0d1a338\":{\"m\":129,\"g\":131},\"a9d22b75\":{\"m\":129,\"g\":131},\"6b9459e8\":{\"m\":129,\"g\":131},\"3a6ec47b\":{\"m\":129,\"g\":131},\"b8e32e79\":{\"m\":129,\"g\":131},\"f5566acc\":{\"m\":129,\"g\":131},\"d79e1294\":{\"m\":129,\"g\":131},\"6e9b1549\":{\"m\":129,\"g\":131},\"109f27ba\":{\"m\":129,\"g\":131},\"c1a30aa7\":{\"m\":129,\"g\":131},\"2e1dbdb2\":{\"m\":129,\"g\":131},\"6d025fd3\":{\"m\":129,\"g\":131},\"518467be\":{\"m\":129,\"g\":131},\"63807079\":{\"m\":129,\"g\":131},\"f6cfe9f1\":{\"m\":129,\"g\":131},\"7bc99d41\":{\"m\":129,\"g\":131},\"e2d67468\":{\"m\":129,\"g\":131},\"6beb6e99\":{\"m\":129,\"g\":131},\"7e88b9c1\":{\"m\":129,\"g\":131},\"cfcf2758\":{\"m\":129,\"g\":131},\"ac81db66\":{\"m\":129,\"g\":131},\"33905005\":{\"m\":129,\"g\":131},\"820e13c9\":{\"m\":129,\"g\":131},\"595adf6d\":{\"m\":129,\"g\":131},\"a5ad0069\":{\"m\":129,\"g\":131},\"4e41edcb\":{\"m\":129,\"g\":131},\"67071f55\":{\"m\":129,\"g\":131},\"0c966779\":{\"m\":129,\"g\":131},\"f3386077\":{\"m\":129,\"g\":131},\"4ce8fb3c\":{\"m\":129,\"g\":131},\"4c3573e4\":{\"m\":129,\"g\":131},\"f1be8aa0\":{\"m\":129,\"g\":131},\"aa8ecbda\":{\"m\":129,\"g\":131},\"9188fecc\":{\"m\":129,\"g\":131},\"26ca0746\":{\"m\":129,\"g\":131},\"90c18a16\":{\"m\":129,\"g\":131},\"d7984f31\":{\"m\":129,\"g\"
:131},\"7119d188\":{\"m\":129,\"g\":131},\"fe3bbfb4\":{\"m\":129,\"g\":131},\"9846f8ed\":{\"m\":129,\"g\":131},\"85ae508e\":{\"m\":129,\"g\":131},\"d879e37f\":{\"m\":129,\"g\":131},\"e2c9a590\":{\"m\":129,\"g\":131},\"a1e37b02\":{\"m\":129,\"g\":131},\"e389f91d\":{\"m\":129,\"g\":131},\"aac07bf7\":{\"m\":129,\"g\":131},\"ea89a3a0\":{\"m\":129,\"g\":131},\"2bc7c5eb\":{\"m\":129,\"g\":131},\"a63f433b\":{\"m\":129,\"g\":131},\"58f8f4e4\":{\"m\":129,\"g\":131},\"25acbbc6\":{\"m\":129,\"g\":131},\"a8fcbf6f\":{\"m\":129,\"g\":131},\"e486308c\":{\"m\":129,\"g\":131},\"60420109\":{\"m\":129,\"g\":131},\"ff00b6ad\":{\"m\":129,\"g\":131},\"7b445260\":{\"m\":129,\"g\":131},\"c236d05f\":{\"m\":129,\"g\":131},\"df561392\":{\"m\":129,\"g\":131},\"15db5497\":{\"m\":129,\"g\":131},\"9ba3597d\":{\"m\":129,\"g\":131},\"80797c2a\":{\"m\":129,\"g\":131},\"7afff8fd\":{\"m\":129,\"g\":131},\"ac406d43\":{\"m\":129,\"g\":131},\"b436113f\":{\"m\":129,\"g\":131},\"8b5e2c53\":{\"m\":129,\"g\":131},\"b24235b8\":{\"m\":129,\"g\":131},\"ae7698fb\":{\"m\":129,\"g\":131},\"a3e4fe4b\":{\"m\":129,\"g\":131},\"f3e9336d\":{\"m\":129,\"g\":131},\"290fcd89\":{\"m\":129,\"g\":131},\"1dcde539\":{\"m\":129,\"g\":131},\"15bc1f5c\":{\"m\":129,\"g\":131},\"4d597616\":{\"m\":129,\"g\":131},\"ab63f3c5\":{\"m\":129,\"g\":131},\"147b7823\":{\"m\":129,\"g\":131},\"d368c745\":{\"m\":129,\"g\":131},\"7e626d12\":{\"m\":129,\"g\":131},\"2a577344\":{\"m\":129,\"g\":131},\"6abb8051\":{\"m\":130,\"g\":131},\"2e3946d8\":{\"m\":130,\"g\":131},\"32f8b606\":{\"m\":130,\"g\":131},\"b9bef31a\":{\"m\":130,\"g\":131},\"8550822d\":{\"m\":130,\"g\":131},\"8810152e\":{\"m\":130,\"g\":131},\"7bf16c63\":{\"m\":130,\"g\":131},\"39f9a9c2\":{\"m\":130,\"g\":131},\"d69ecc19\":{\"m\":130,\"g\":131},\"763888b5\":{\"m\":130,\"g\":131},\"9a327bdf\":{\"m\":130,\"g\":131},\"2de98010\":{\"m\":130,\"g\":131},\"8200fb56\":{\"m\":130,\"g\":131},\"cb4cdb43\":{\"m\":130,\"g\":131},\"80cfca50\":{\"m\":130,\"g\":131},\"7871593c\":{\"m\":130,\"g\":131}
,\"4a62a0e3\":{\"m\":130,\"g\":131},\"12a08efc\":{\"m\":130,\"g\":131},\"06836ad0\":{\"m\":130,\"g\":131},\"f72a7703\":{\"m\":130,\"g\":131},\"aeff0d38\":{\"m\":130,\"g\":131},\"cf0478d6\":{\"m\":130,\"g\":131},\"a2ca9bd4\":{\"m\":130,\"g\":131},\"36361adc\":{\"m\":130,\"g\":131},\"661e9775\":{\"m\":130,\"g\":131},\"2970f229\":{\"m\":130,\"g\":131},\"c08b780f\":{\"m\":130,\"g\":131},\"8fbf7dd5\":{\"m\":130,\"g\":131},\"1915a1f8\":{\"m\":130,\"g\":131},\"85d0ccfa\":{\"m\":130,\"g\":131},\"a4ffd665\":{\"m\":130,\"g\":131},\"559202b5\":{\"m\":130,\"g\":131},\"f57d4fe7\":{\"m\":130,\"g\":131},\"b7b7524e\":{\"m\":130,\"g\":131},\"6799847e\":{\"m\":130,\"g\":131},\"03b835e7\":{\"m\":130,\"g\":131},\"aff1238e\":{\"m\":130,\"g\":131},\"5e2cda61\":{\"m\":130,\"g\":131},\"b0bbc7f5\":{\"m\":130,\"g\":131},\"673c11ba\":{\"m\":130,\"g\":131},\"3b47973a\":{\"m\":130,\"g\":131},\"f6423b62\":{\"m\":130,\"g\":131},\"84efe54b\":{\"m\":130,\"g\":131},\"948b6ace\":{\"m\":130,\"g\":131},\"f124539a\":{\"m\":130,\"g\":131},\"c8683ae3\":{\"m\":130,\"g\":131},\"125e17ef\":{\"m\":130,\"g\":131},\"88c459c6\":{\"m\":130,\"g\":131},\"ae6a6630\":{\"m\":130,\"g\":131},\"26d95008\":{\"m\":130,\"g\":131},\"f2b5dcc9\":{\"m\":130,\"g\":131},\"9abcab3f\":{\"m\":130,\"g\":131},\"e5135b73\":{\"m\":130,\"g\":131},\"41d61faa\":{\"m\":130,\"g\":131},\"3c7886ec\":{\"m\":130,\"g\":131},\"0e4d8790\":{\"m\":130,\"g\":131},\"6d5d76ad\":{\"m\":130,\"g\":131},\"91c9c14c\":{\"m\":130,\"g\":131},\"32a32cf7\":{\"m\":130,\"g\":131},\"be4a3ec3\":{\"m\":130,\"g\":131},\"ff6e3ea9\":{\"m\":130,\"g\":131},\"dd91d38e\":{\"m\":130,\"g\":131},\"5f6f550a\":{\"m\":130,\"g\":131},\"5edbe351\":{\"m\":130,\"g\":131},\"d2b42477\":{\"m\":130,\"g\":131},\"9dfa01a4\":{\"m\":130,\"g\":131},\"e592ee65\":{\"m\":130,\"g\":131},\"cee93a6f\":{\"m\":130,\"g\":131},\"bc388471\":{\"m\":130,\"g\":131},\"3e40c636\":{\"m\":130,\"g\":131},\"80122e4f\":{\"m\":130,\"g\":131},\"e12c6b32\":{\"m\":130,\"g\":131},\"6d417918\":{\"m\":130,\"g\":131},\"35
a9a073\":{\"m\":130,\"g\":131},\"ea177372\":{\"m\":130,\"g\":131},\"42fcf543\":{\"m\":130,\"g\":131},\"d257bf87\":{\"m\":130,\"g\":131},\"7b0c7ad1\":{\"m\":130,\"g\":131},\"d30d6b36\":{\"m\":130,\"g\":131},\"d881f314\":{\"m\":130,\"g\":131},\"2ac5b983\":{\"m\":130,\"g\":131},\"a0dde90a\":{\"m\":130,\"g\":131},\"b988c18e\":{\"m\":130,\"g\":131},\"e41664ba\":{\"m\":130,\"g\":131},\"3d1b591a\":{\"m\":130,\"g\":131},\"e11f795f\":{\"m\":130,\"g\":131},\"959a1746\":{\"m\":130,\"g\":131},\"b72f0268\":{\"m\":130,\"g\":131},\"09376fd7\":{\"m\":130,\"g\":131},\"aed835e3\":{\"m\":130,\"g\":131},\"49dfa1d8\":{\"m\":130,\"g\":131},\"1ea6b740\":{\"m\":130,\"g\":131},\"e73173b0\":{\"m\":130,\"g\":131},\"16e8463a\":{\"m\":130,\"g\":131},\"cf9a774c\":{\"m\":130,\"g\":131},\"ec7b2c16\":{\"m\":130,\"g\":131},\"1569fc7f\":{\"m\":130,\"g\":131},\"5a46fb15\":{\"m\":130,\"g\":131},\"38daa294\":{\"m\":130,\"g\":131},\"66984a8b\":{\"m\":130,\"g\":131},\"889b46ea\":{\"m\":130,\"g\":131},\"05284378\":{\"m\":130,\"g\":131},\"a8904560\":{\"m\":130,\"g\":131},\"66280987\":{\"m\":130,\"g\":131},\"205f041e\":{\"m\":130,\"g\":131},\"7235a7fb\":{\"m\":130,\"g\":131},\"8fce9e7b\":{\"m\":130,\"g\":131},\"53477322\":{\"m\":130,\"g\":131},\"2ce121a1\":{\"m\":130,\"g\":131},\"35ba6fe1\":{\"m\":130,\"g\":131},\"498ea41c\":{\"m\":130,\"g\":131},\"7c744d13\":{\"m\":130,\"g\":131},\"46b05ef5\":{\"m\":130,\"g\":131},\"beec8eed\":{\"m\":130,\"g\":131},\"b76e303e\":{\"m\":130,\"g\":131},\"80a575e4\":{\"m\":130,\"g\":131},\"4c5074eb\":{\"m\":130,\"g\":131},\"41429a8c\":{\"m\":130,\"g\":131},\"532037df\":{\"m\":130,\"g\":131},\"fa0ca976\":{\"m\":130,\"g\":131},\"b5d39985\":{\"m\":130,\"g\":131},\"2ecee757\":{\"m\":130,\"g\":131},\"6d37e708\":{\"m\":130,\"g\":131},\"c1006fd8\":{\"m\":130,\"g\":131},\"29c6c2ea\":{\"m\":130,\"g\":131},\"eb85fa6d\":{\"m\":130,\"g\":131},\"0e6441b4\":{\"m\":130,\"g\":131},\"88d1bab5\":{\"m\":130,\"g\":131},\"922756aa\":{\"m\":130,\"g\":131},\"d8faf2f3\":{\"m\":130,\"g\":131},\"7dfcc78
1\":{\"m\":130,\"g\":131},\"7f3308bc\":{\"m\":130,\"g\":131},\"fdc2ef58\":{\"m\":130,\"g\":131},\"1808df48\":{\"m\":130,\"g\":131},\"b01fc161\":{\"m\":130,\"g\":131},\"788628b5\":{\"m\":130,\"g\":131},\"11d33c0e\":{\"m\":130,\"g\":131},\"441420e1\":{\"m\":130,\"g\":131},\"29a2d4b5\":{\"m\":130,\"g\":131},\"079ac237\":{\"m\":130,\"g\":131},\"84280784\":{\"m\":130,\"g\":131},\"af35023e\":{\"m\":130,\"g\":131},\"cb8df87f\":{\"m\":130,\"g\":131},\"e3ab23c1\":{\"m\":130,\"g\":131},\"70d25873\":{\"m\":130,\"g\":131},\"894c0dc5\":{\"m\":130,\"g\":131},\"d6c49019\":{\"m\":130,\"g\":131},\"fa78c44a\":{\"m\":130,\"g\":131},\"654a78f9\":{\"m\":130,\"g\":131},\"f90b4004\":{\"m\":130,\"g\":131},\"78647e08\":{\"m\":130,\"g\":131},\"46f21a59\":{\"m\":130,\"g\":131},\"b2b09f5f\":{\"m\":130,\"g\":131},\"04df80a9\":{\"m\":130,\"g\":131},\"4f73e53d\":{\"m\":130,\"g\":131},\"16ff892c\":{\"m\":130,\"g\":131},\"df026bb1\":{\"m\":130,\"g\":131},\"d42c167b\":{\"m\":130,\"g\":131},\"38815105\":{\"m\":130,\"g\":131},\"9d823402\":{\"m\":130,\"g\":131},\"8ab5d8b4\":{\"m\":130,\"g\":131},\"03575ce3\":{\"m\":130,\"g\":131},\"80518bea\":{\"m\":130,\"g\":131},\"7e78825d\":{\"m\":130,\"g\":131},\"5bbd83a2\":{\"m\":130,\"g\":131},\"abf6272b\":{\"m\":130,\"g\":131},\"46d7b35e\":{\"m\":130,\"g\":131},\"20aad5b5\":{\"m\":130,\"g\":131},\"974c562a\":{\"m\":130,\"g\":131},\"aca0d01d\":{\"m\":130,\"g\":131},\"dc163502\":{\"m\":130,\"g\":131},\"24903b88\":{\"m\":130,\"g\":131},\"443d7bcd\":{\"m\":130,\"g\":131},\"16d8de22\":{\"m\":130,\"g\":131},\"d122e324\":{\"m\":130,\"g\":131},\"96cc1083\":{\"m\":130,\"g\":131},\"77512ae0\":{\"m\":130,\"g\":131},\"c233e9d7\":{\"m\":130,\"g\":131},\"58ac3f31\":{\"m\":130,\"g\":131},\"93452a82\":{\"m\":130,\"g\":131},\"65c8568c\":{\"m\":130,\"g\":131},\"4bcc5879\":{\"m\":130,\"g\":131},\"d5ea8c71\":{\"m\":130,\"g\":131},\"42271376\":{\"m\":130,\"g\":131},\"84e0abb7\":{\"m\":130,\"g\":131},\"043f1317\":{\"m\":130,\"g\":131},\"f764c691\":{\"m\":130,\"g\":131},\"92205407\":{
\"m\":130,\"g\":131},\"7d1a130c\":{\"m\":130,\"g\":131},\"5c8bd8b5\":{\"m\":132,\"g\":133},\"b05b346a\":{\"m\":132,\"g\":133},\"5c961756\":{\"m\":132,\"g\":133},\"c5f1e861\":{\"m\":132,\"g\":133},\"2c4d376d\":{\"m\":132,\"g\":133},\"d0f756ae\":{\"m\":132,\"g\":133},\"6f99dc97\":{\"m\":132,\"g\":133},\"cd1c1fa5\":{\"m\":132,\"g\":133},\"5b0872d2\":{\"m\":132,\"g\":133},\"ba88f1ca\":{\"m\":132,\"g\":133},\"60560c07\":{\"m\":132,\"g\":133},\"ca114421\":{\"m\":132,\"g\":133},\"543d62d1\":{\"m\":132,\"g\":133},\"27032cec\":{\"m\":132,\"g\":133},\"5d804a37\":{\"m\":132,\"g\":133},\"388018a5\":{\"m\":132,\"g\":133},\"fca8e88f\":{\"m\":132,\"g\":133},\"45eeeb9a\":{\"m\":132,\"g\":133},\"a368df28\":{\"m\":132,\"g\":133},\"f85460fb\":{\"m\":132,\"g\":133},\"e52cf30e\":{\"m\":132,\"g\":133},\"a076d75e\":{\"m\":132,\"g\":133},\"8348725d\":{\"m\":132,\"g\":133},\"28566241\":{\"m\":132,\"g\":133},\"b62fe850\":{\"m\":132,\"g\":133},\"624725cb\":{\"m\":132,\"g\":133},\"1a96e664\":{\"m\":132,\"g\":133},\"32829b16\":{\"m\":132,\"g\":133},\"e54307f2\":{\"m\":132,\"g\":133},\"8642dbe4\":{\"m\":132,\"g\":133},\"7dcad45c\":{\"m\":132,\"g\":133},\"7c985331\":{\"m\":132,\"g\":133},\"bd7824b2\":{\"m\":132,\"g\":133},\"312df1d6\":{\"m\":132,\"g\":133},\"25e97380\":{\"m\":132,\"g\":133},\"b6523a4f\":{\"m\":132,\"g\":133},\"c51efb8b\":{\"m\":132,\"g\":133},\"bcc5483e\":{\"m\":132,\"g\":133},\"ccf26027\":{\"m\":132,\"g\":133},\"a4992873\":{\"m\":132,\"g\":133},\"c97ce391\":{\"m\":132,\"g\":133},\"c032b559\":{\"m\":132,\"g\":133},\"da9b801e\":{\"m\":132,\"g\":133},\"0e54a695\":{\"m\":132,\"g\":133},\"e99ee0c6\":{\"m\":132,\"g\":133},\"c1bd5ee8\":{\"m\":132,\"g\":133},\"ef1ab230\":{\"m\":132,\"g\":133},\"d6598737\":{\"m\":132,\"g\":133},\"6c5ebc0e\":{\"m\":132,\"g\":133},\"5b5571a8\":{\"m\":132,\"g\":133},\"1698c234\":{\"m\":132,\"g\":133},\"3d82c0f1\":{\"m\":132,\"g\":133},\"617e9b3b\":{\"m\":132,\"g\":133},\"83e35a7c\":{\"m\":132,\"g\":133},\"2543666c\":{\"m\":132,\"g\":133},\"f732f8ea\":{\"m\"
:132,\"g\":133},\"503880db\":{\"m\":132,\"g\":133},\"b8cfa02c\":{\"m\":132,\"g\":133},\"d85fecb5\":{\"m\":132,\"g\":133},\"6634f67b\":{\"m\":132,\"g\":133},\"5eccaf77\":{\"m\":132,\"g\":133},\"d7f6320b\":{\"m\":132,\"g\":133},\"766476f5\":{\"m\":132,\"g\":133},\"12b7a4fa\":{\"m\":132,\"g\":133},\"02f1e81e\":{\"m\":132,\"g\":133},\"908c7186\":{\"m\":132,\"g\":133},\"03836d85\":{\"m\":132,\"g\":133},\"87dbdddc\":{\"m\":132,\"g\":133},\"56e5c074\":{\"m\":132,\"g\":133},\"6c9c8da6\":{\"m\":132,\"g\":133},\"21028b55\":{\"m\":132,\"g\":133},\"b0a25d09\":{\"m\":132,\"g\":133},\"793c98af\":{\"m\":132,\"g\":133},\"b1cbfce6\":{\"m\":132,\"g\":133},\"b0f531ad\":{\"m\":132,\"g\":133},\"01835998\":{\"m\":132,\"g\":133},\"4285e99d\":{\"m\":132,\"g\":133},\"f0774368\":{\"m\":132,\"g\":133},\"5e8f544d\":{\"m\":132,\"g\":133},\"c8d74feb\":{\"m\":132,\"g\":133},\"cbc7dcda\":{\"m\":132,\"g\":133},\"a6dc7d29\":{\"m\":132,\"g\":133},\"390406c4\":{\"m\":132,\"g\":131},\"0c63fb94\":{\"m\":132,\"g\":131},\"9ad02b79\":{\"m\":132,\"g\":131},\"18bd8e8d\":{\"m\":132,\"g\":131},\"036e64da\":{\"m\":132,\"g\":131},\"7c6fb3aa\":{\"m\":132,\"g\":131},\"8b0b6a45\":{\"m\":132,\"g\":131},\"55504df2\":{\"m\":132,\"g\":131},\"73df7a4e\":{\"m\":132,\"g\":131},\"8b98bb76\":{\"m\":132,\"g\":131},\"15bc8cbd\":{\"m\":132,\"g\":131},\"9496f12d\":{\"m\":132,\"g\":131},\"ab004879\":{\"m\":132,\"g\":131},\"6ec77680\":{\"m\":132,\"g\":131},\"13680e55\":{\"m\":132,\"g\":131},\"98c430e1\":{\"m\":132,\"g\":131},\"fe7f91ef\":{\"m\":132,\"g\":131},\"cef5ba65\":{\"m\":132,\"g\":131},\"53d17088\":{\"m\":132,\"g\":131},\"f0e948a0\":{\"m\":132,\"g\":131},\"9a426fc5\":{\"m\":132,\"g\":131},\"66772aa2\":{\"m\":132,\"g\":131},\"b6263344\":{\"m\":132,\"g\":131},\"0f8bd55f\":{\"m\":132,\"g\":131},\"da3dc497\":{\"m\":132,\"g\":131},\"817daba0\":{\"m\":132,\"g\":131},\"af60cad0\":{\"m\":132,\"g\":131},\"ce4e836b\":{\"m\":132,\"g\":131},\"0e0b0c05\":{\"m\":132,\"g\":131},\"e6f0ddda\":{\"m\":132,\"g\":131},\"af20657c\":{\"m\":132,
\"g\":131},\"08da4c26\":{\"m\":132,\"g\":131},\"ef3f8c97\":{\"m\":132,\"g\":131},\"e5201bda\":{\"m\":132,\"g\":131},\"60d36e7b\":{\"m\":132,\"g\":131},\"6f657070\":{\"m\":132,\"g\":131},\"c106b54b\":{\"m\":132,\"g\":131},\"119fd956\":{\"m\":132,\"g\":131},\"eac5b664\":{\"m\":132,\"g\":131},\"07404d76\":{\"m\":132,\"g\":131},\"93043f7b\":{\"m\":132,\"g\":131},\"edde5e5d\":{\"m\":132,\"g\":131},\"232982a0\":\"m134\",\"9e88c0a2\":\"m134\",\"b7e0d54e\":\"m134\",\"5dde0a57\":\"m134\",\"9e5ab903\":\"m134\",\"98225be6\":{\"m\":134,\"g\":135},\"94bcc19b\":{\"m\":134,\"g\":135},\"db3821a9\":{\"m\":134,\"g\":135},\"8c5d91b8\":{\"m\":134,\"g\":135},\"6a3e7092\":{\"m\":134,\"g\":135},\"c2601f0d\":{\"m\":134,\"g\":135},\"1048803c\":{\"m\":134,\"g\":135},\"8a84b1e7\":{\"m\":134,\"g\":135},\"5e20e7a6\":{\"m\":134,\"g\":135},\"3946dad6\":{\"m\":134,\"g\":135},\"ac03ec08\":{\"m\":134,\"g\":135},\"94164646\":{\"m\":134,\"g\":135},\"5fb734f1\":{\"m\":134,\"g\":135},\"60f1ca69\":{\"m\":134,\"g\":135},\"9e263c21\":{\"m\":134,\"g\":135},\"0d003e34\":{\"m\":134,\"g\":135},\"f253f43c\":{\"m\":134,\"g\":135},\"1e453201\":{\"m\":134,\"g\":135},\"26e17f90\":{\"m\":134,\"g\":135},\"3de23274\":{\"m\":134,\"g\":135},\"f4ec6f8e\":{\"m\":134,\"g\":135},\"9c4eb460\":{\"m\":134,\"g\":135},\"c2e0913e\":{\"m\":134,\"g\":135},\"8c6f865a\":{\"m\":134,\"g\":135},\"269aa27b\":{\"m\":134,\"g\":135},\"f39382c6\":{\"m\":134,\"g\":135},\"2ff289e2\":{\"m\":134,\"g\":135},\"b2a3f055\":{\"m\":134,\"g\":135},\"684e148e\":{\"m\":134,\"g\":135},\"f44c4b37\":{\"m\":134,\"g\":135},\"7380ec9d\":{\"m\":134,\"g\":135},\"ac78f96e\":{\"m\":134,\"g\":135},\"278012ca\":{\"m\":134,\"g\":135},\"7f587998\":{\"m\":134,\"g\":135},\"162d1cf9\":{\"m\":134,\"g\":135},\"8e08207c\":{\"m\":134,\"g\":135},\"c31f6272\":{\"m\":134,\"g\":135},\"b5d9fc87\":{\"m\":134,\"g\":135},\"88f3de25\":{\"m\":134,\"g\":135},\"24616c52\":{\"m\":134,\"g\":135},\"d48723b7\":{\"m\":134,\"g\":135},\"f3d73b01\":{\"m\":134,\"g\":135},\"4ab66d95\":{\"m\":134,
\"g\":135},\"de2799f3\":{\"m\":134,\"g\":135},\"f784cbfa\":{\"m\":134,\"g\":135},\"09733090\":{\"m\":134,\"g\":135},\"a44d0079\":{\"m\":134,\"g\":135},\"8305dc17\":{\"m\":134,\"g\":135},\"ec8c831d\":{\"m\":134,\"g\":135},\"f13949e5\":{\"m\":134,\"g\":135},\"c236a3fd\":{\"m\":134,\"g\":135},\"41a1d16b\":{\"m\":134,\"g\":135},\"9884c9fd\":{\"m\":134,\"g\":135},\"a435f55d\":{\"m\":134,\"g\":135},\"2ec6fa3c\":{\"m\":134,\"g\":135},\"c58a573a\":{\"m\":134,\"g\":135},\"b840d6aa\":{\"m\":134,\"g\":135},\"ef4b3c0e\":{\"m\":134,\"g\":135},\"6f9d0a89\":{\"m\":134,\"g\":135},\"d7a3336e\":{\"m\":134,\"g\":135},\"7d02c8e5\":{\"m\":134,\"g\":135},\"e6d5a213\":{\"m\":134,\"g\":135},\"9f8e2307\":{\"m\":134,\"g\":135},\"d90f9bfc\":{\"m\":134,\"g\":135},\"3881bc8d\":{\"m\":134,\"g\":135},\"208e6a9d\":{\"m\":134,\"g\":135},\"be3828a1\":{\"m\":134,\"g\":135},\"5969be2f\":{\"m\":134,\"g\":135},\"8fab4895\":{\"m\":134,\"g\":135},\"7ccaec64\":{\"m\":134,\"g\":135},\"c457aad5\":{\"m\":134,\"g\":135},\"c7e7bfa3\":{\"m\":134,\"g\":135},\"bf90ea9c\":{\"m\":134,\"g\":135},\"26c50912\":{\"m\":134,\"g\":135},\"325a4c19\":{\"m\":134,\"g\":135},\"0294844f\":{\"m\":134,\"g\":135},\"0e536600\":{\"m\":134,\"g\":135},\"d70c2655\":{\"m\":134,\"g\":135},\"656f4d69\":{\"m\":134,\"g\":135},\"8e43980e\":{\"m\":134,\"g\":135},\"474a4699\":{\"m\":134,\"g\":135},\"b4a00ed2\":{\"m\":134,\"g\":135},\"f55d608c\":{\"m\":134,\"g\":135},\"349ce2dd\":{\"m\":134,\"g\":135},\"183b6519\":{\"m\":134,\"g\":135},\"2af955e1\":{\"m\":134,\"g\":135},\"0cd2b719\":{\"m\":134,\"g\":135},\"39d56196\":{\"m\":134,\"g\":135},\"41addd2e\":{\"m\":134,\"g\":135},\"5c393e81\":{\"m\":134,\"g\":135},\"3645ed0f\":{\"m\":134,\"g\":135},\"ca740a41\":{\"m\":134,\"g\":135},\"0e25aa43\":{\"m\":134,\"g\":135},\"60a230b1\":{\"m\":134,\"g\":135},\"aa89c6a7\":{\"m\":134,\"g\":135},\"171912a9\":{\"m\":134,\"g\":135},\"faecd37e\":{\"m\":134,\"g\":135},\"9ad546d7\":{\"m\":134,\"g\":135},\"29ce7b36\":{\"m\":134,\"g\":135},\"a8380ded\":{\"m\":134,\"g\"
:135},\"4edee695\":{\"m\":134,\"g\":135},\"67caea6f\":{\"m\":134,\"g\":135},\"cd3289c7\":{\"m\":134,\"g\":135},\"acddb8e0\":{\"m\":134,\"g\":135},\"2ec57cef\":{\"m\":134,\"g\":135},\"988b14ca\":{\"m\":134,\"g\":135},\"93495dca\":{\"m\":134,\"g\":135},\"886e0383\":{\"m\":134,\"g\":135},\"01bd0d3e\":{\"m\":134,\"g\":135},\"43e1bbc0\":{\"m\":134,\"g\":135},\"59b12996\":{\"m\":134,\"g\":135},\"b7091496\":{\"m\":134,\"g\":135},\"51dbdb22\":{\"m\":134,\"g\":135},\"8dc6f0fc\":{\"m\":134,\"g\":135},\"cf34d0ab\":{\"m\":134,\"g\":135},\"3778c2fc\":{\"m\":134,\"g\":135},\"5d421db8\":{\"m\":134,\"g\":135},\"fe3d47fc\":{\"m\":134,\"g\":135},\"ef92b4eb\":{\"m\":134,\"g\":135},\"73c0c66f\":{\"m\":134,\"g\":135},\"a1e9b4ed\":{\"m\":134,\"g\":135},\"e75657c8\":{\"m\":134,\"g\":135},\"c28c536c\":{\"m\":134,\"g\":135},\"086813ae\":{\"m\":134,\"g\":135},\"3fd232ad\":{\"m\":134,\"g\":135},\"cb181295\":{\"m\":134,\"g\":135},\"a91e072f\":{\"m\":134,\"g\":135},\"0271fc34\":{\"m\":134,\"g\":135},\"68bece8c\":{\"m\":134,\"g\":135},\"7b7e357f\":{\"m\":134,\"g\":135},\"2f66b067\":{\"m\":134,\"g\":135},\"48051181\":{\"m\":134,\"g\":135},\"f2ccc442\":{\"m\":134,\"g\":135},\"caa95c7e\":{\"m\":134,\"g\":135},\"9d878c1f\":{\"m\":134,\"g\":135},\"8087ef12\":{\"m\":134,\"g\":135},\"6ef543f9\":{\"m\":134,\"g\":135},\"a3559119\":{\"m\":134,\"g\":135},\"f3ba7116\":{\"m\":134,\"g\":135},\"5c243ba5\":{\"m\":134,\"g\":135},\"b6702d72\":{\"m\":134,\"g\":135},\"c1256727\":{\"m\":134,\"g\":135},\"bb9e6cdf\":{\"m\":134,\"g\":135},\"de03b0cd\":{\"m\":134,\"g\":135},\"2a8a7856\":{\"m\":134,\"g\":135},\"f4e835af\":{\"m\":134,\"g\":135},\"e6ce16a4\":{\"m\":134,\"g\":135},\"de2f2880\":{\"m\":134,\"g\":135},\"a89e85e7\":{\"m\":134,\"g\":135},\"cbf9f134\":{\"m\":134,\"g\":135},\"eb3da9c1\":{\"m\":134,\"g\":135},\"72a980c6\":{\"m\":134,\"g\":135},\"ccf2330b\":{\"m\":134,\"g\":135},\"b9af8d2e\":{\"m\":134,\"g\":135},\"10a9573e\":{\"m\":134,\"g\":135},\"49ab72f8\":{\"m\":134,\"g\":135},\"8865424f\":{\"m\":134,\"g\":135}
,\"b311c43d\":{\"m\":134,\"g\":135},\"0c39730b\":{\"m\":134,\"g\":135},\"45adad37\":{\"m\":134,\"g\":135},\"1ba897f3\":{\"m\":134,\"g\":135},\"38dd4fbb\":{\"m\":134,\"g\":135},\"ecd2d09a\":{\"m\":134,\"g\":135},\"92ddc468\":{\"m\":134,\"g\":135},\"5454d2a7\":{\"m\":134,\"g\":133},\"17b38f88\":{\"m\":134,\"g\":133},\"ae434f78\":{\"m\":134,\"g\":133},\"643aeefe\":{\"m\":134,\"g\":133},\"186a56f6\":{\"m\":134,\"g\":133},\"17e65466\":{\"m\":134,\"g\":133},\"370bd27f\":{\"m\":134,\"g\":133},\"b27b5a83\":{\"m\":134,\"g\":133},\"2f7c6292\":{\"m\":134,\"g\":133},\"2fb31605\":{\"m\":134,\"g\":133},\"8bf7f240\":{\"m\":134,\"g\":133},\"2c5679f3\":{\"m\":134,\"g\":133},\"159b1283\":{\"m\":134,\"g\":133},\"9338f63f\":{\"m\":134,\"g\":133},\"8196998a\":{\"m\":134,\"g\":133},\"aa21c6e3\":{\"m\":134,\"g\":133},\"fd4a558e\":{\"m\":134,\"g\":133},\"b3b818fd\":{\"m\":134,\"g\":133},\"d5fbbfd9\":{\"m\":134,\"g\":133},\"e245cac0\":{\"m\":134,\"g\":133},\"c6a6ba43\":{\"m\":134,\"g\":133},\"d6108166\":{\"m\":134,\"g\":133},\"ddb3970e\":{\"m\":134,\"g\":133},\"f65fa047\":{\"m\":134,\"g\":133},\"7e027691\":{\"m\":134,\"g\":133},\"cb719c74\":{\"m\":134,\"g\":133},\"ff903a7e\":{\"m\":134,\"g\":133},\"eee3700d\":{\"m\":134,\"g\":133},\"e254cdf3\":{\"m\":134,\"g\":133},\"dfb53574\":{\"m\":134,\"g\":133},\"96655749\":{\"m\":134,\"g\":133},\"ac320a6f\":{\"m\":134,\"g\":133},\"6292c244\":{\"m\":134,\"g\":133},\"5f5a5677\":{\"m\":134,\"g\":133},\"3bf07c68\":{\"m\":134,\"g\":133},\"e7b09efc\":{\"m\":134,\"g\":133},\"99d3bcdf\":{\"m\":134,\"g\":133},\"4d64f150\":{\"m\":134,\"g\":133},\"aef7ca7c\":{\"m\":134,\"g\":133},\"fe712aa3\":{\"m\":134,\"g\":133},\"846953d9\":{\"m\":134,\"g\":133},\"5c64a20d\":{\"m\":134,\"g\":133},\"cf817376\":{\"m\":134,\"g\":133},\"aa6ac966\":{\"m\":134,\"g\":133},\"0d0367e9\":{\"m\":134,\"g\":133},\"bd572360\":{\"m\":134,\"g\":133},\"5f3a47d8\":{\"m\":134,\"g\":133},\"80ae2229\":{\"m\":134,\"g\":133},\"705287b2\":{\"m\":134,\"g\":133},\"76284653\":{\"m\":134,\"g\":133},\"dd
620987\":{\"m\":134,\"g\":133},\"53f974b9\":{\"m\":134,\"g\":133},\"6a5764a7\":{\"m\":134,\"g\":133},\"291f11ae\":{\"m\":134,\"g\":133},\"d7301c89\":{\"m\":134,\"g\":133},\"758b9067\":{\"m\":134,\"g\":133},\"ffc23ef8\":{\"m\":134,\"g\":133},\"b3f83cc1\":{\"m\":134,\"g\":133},\"c15fa1c5\":{\"m\":134,\"g\":133},\"fa296698\":{\"m\":134,\"g\":133},\"f9dd90ac\":{\"m\":134,\"g\":133},\"66902e0f\":{\"m\":134,\"g\":133},\"e50f356f\":{\"m\":134,\"g\":133},\"ac42797c\":{\"m\":134,\"g\":133},\"883747ce\":{\"m\":134,\"g\":133},\"989d4b30\":{\"m\":134,\"g\":133},\"bc3ca300\":{\"m\":134,\"g\":133},\"061f41af\":{\"m\":134,\"g\":133},\"82f1d615\":{\"m\":134,\"g\":133},\"5e1a495c\":{\"m\":134,\"g\":133},\"34013d9d\":{\"m\":134,\"g\":133},\"77597167\":{\"m\":134,\"g\":133},\"3c882db3\":{\"m\":134,\"g\":133},\"b736a152\":{\"m\":134,\"g\":133},\"6984837d\":{\"m\":134,\"g\":133},\"2142881b\":{\"m\":134,\"g\":133},\"d77f3fcc\":{\"m\":134,\"g\":133},\"828dec1c\":{\"m\":134,\"g\":133},\"575a49dc\":{\"m\":134,\"g\":133},\"e62e1744\":{\"m\":134,\"g\":133},\"d5431ff8\":{\"m\":134,\"g\":133},\"454a2544\":{\"m\":134,\"g\":133},\"89619a99\":{\"m\":134,\"g\":133},\"cb30d056\":{\"m\":134,\"g\":133},\"677930c2\":{\"m\":134,\"g\":133},\"beae3f96\":{\"m\":134,\"g\":133},\"1167867e\":{\"m\":134,\"g\":133},\"ad7f35fb\":{\"m\":134,\"g\":133},\"f4100732\":{\"m\":134,\"g\":133},\"468931b5\":{\"m\":134,\"g\":133},\"796969ca\":{\"m\":134,\"g\":133},\"122c2503\":{\"m\":134,\"g\":133},\"a3a55223\":{\"m\":134,\"g\":133},\"0bf95e6d\":{\"m\":134,\"g\":133},\"a92de891\":{\"m\":134,\"g\":133},\"b9d78605\":{\"m\":134,\"g\":133},\"1354063a\":{\"m\":134,\"g\":133},\"254de6d2\":{\"m\":134,\"g\":133},\"350fbbf4\":{\"m\":134,\"g\":133},\"a3912667\":{\"m\":134,\"g\":133},\"e1dcd0df\":{\"m\":134,\"g\":133},\"393e2f9b\":{\"m\":134,\"g\":133},\"8766a1dd\":{\"m\":134,\"g\":133},\"1d9ba2ce\":{\"m\":134,\"g\":133},\"ef001fb8\":{\"m\":134,\"g\":133},\"c69c1c4f\":{\"m\":134,\"g\":133},\"bed301a5\":{\"m\":134,\"g\":133},\"8fe3e37
4\":{\"m\":134,\"g\":133},\"60143655\":{\"m\":134,\"g\":133},\"fc05acc2\":{\"m\":134,\"g\":133},\"1ed94668\":{\"m\":134,\"g\":133},\"d7fbe73b\":{\"m\":134,\"g\":133},\"42bff706\":{\"m\":134,\"g\":133},\"26704c23\":{\"m\":134,\"g\":133},\"43b7c174\":{\"m\":134,\"g\":133},\"96740d69\":{\"m\":134,\"g\":133},\"9a3bdf2c\":{\"m\":134,\"g\":133},\"4b351f6b\":{\"m\":134,\"g\":133},\"7fa4906f\":{\"m\":134,\"g\":133},\"050f108c\":{\"m\":134,\"g\":133},\"47cdb65a\":{\"m\":134,\"g\":133},\"1d90b194\":{\"m\":134,\"g\":133},\"537ef18d\":{\"m\":134,\"g\":133},\"69412ccb\":{\"m\":134,\"g\":133},\"d3885d4b\":{\"m\":134,\"g\":133},\"0a346d3b\":{\"m\":134,\"g\":133},\"bc18cb86\":{\"m\":134,\"g\":133},\"bee8ac5b\":{\"m\":134,\"g\":133},\"41bd76e1\":{\"m\":134,\"g\":133},\"c6ca1b3a\":{\"m\":134,\"g\":133},\"8999ce75\":{\"m\":134,\"g\":133},\"dce2ed44\":{\"m\":134,\"g\":133},\"019517a3\":{\"m\":134,\"g\":133},\"1f1f05a8\":{\"m\":134,\"g\":133},\"165f5c04\":{\"m\":134,\"g\":133},\"3e01f3a5\":{\"m\":134,\"g\":133},\"51e2eaa4\":{\"m\":134,\"g\":133},\"6468cb58\":{\"m\":134,\"g\":133},\"e220da17\":{\"m\":134,\"g\":133},\"b82c7a0a\":{\"m\":134,\"g\":133},\"c0f9b519\":{\"m\":134,\"g\":133},\"5529ab58\":{\"m\":134,\"g\":133},\"74a3349b\":{\"m\":134,\"g\":133},\"b9ebf0ed\":{\"m\":134,\"g\":133},\"3c116d5e\":{\"m\":134,\"g\":133},\"71a60288\":{\"m\":134,\"g\":133},\"61405b3d\":{\"m\":134,\"g\":133},\"0adfc42b\":{\"m\":134,\"g\":133},\"ba72e759\":{\"m\":134,\"g\":133},\"6afc5d49\":{\"m\":134,\"g\":133},\"2ee6c810\":{\"m\":134,\"g\":133},\"bd16244d\":{\"m\":134,\"g\":133},\"b5eb0214\":{\"m\":134,\"g\":133},\"d72e908b\":{\"m\":134,\"g\":133},\"50cad014\":{\"m\":134,\"g\":133},\"ef908aeb\":{\"m\":134,\"g\":133},\"9d0347b3\":{\"m\":134,\"g\":133},\"05eb0bcc\":{\"m\":134,\"g\":133},\"5dccd9bd\":{\"m\":134,\"g\":133},\"933cef16\":{\"m\":134,\"g\":133},\"241ae17b\":{\"m\":134,\"g\":133},\"2c5a4460\":{\"m\":134,\"g\":133},\"f3705b01\":{\"m\":134,\"g\":133},\"5a0ad731\":{\"m\":134,\"g\":133},\"5045aa34\":{
\"m\":134,\"g\":133},\"ff1e2ce2\":{\"m\":134,\"g\":133},\"1e582488\":{\"m\":134,\"g\":133},\"46be74b4\":{\"m\":134,\"g\":133},\"1c658026\":{\"m\":134,\"g\":133},\"ba410808\":{\"m\":134,\"g\":133},\"89512029\":{\"m\":134,\"g\":133},\"92e6b3c3\":{\"m\":134,\"g\":133},\"6559e43f\":{\"m\":134,\"g\":133},\"af780c59\":{\"m\":134,\"g\":133},\"a21aa87e\":{\"m\":134,\"g\":133},\"fb178457\":{\"m\":134,\"g\":133},\"65c09859\":{\"m\":134,\"g\":133},\"4bf06635\":{\"m\":134,\"g\":133},\"f2d64e67\":{\"m\":134,\"g\":133},\"0e869f08\":{\"m\":134,\"g\":133},\"17394092\":{\"m\":134,\"g\":133},\"f228b662\":{\"m\":134,\"g\":133},\"a36142aa\":{\"m\":134,\"g\":133},\"160a06ca\":{\"m\":134,\"g\":133},\"a0985dd5\":{\"m\":134,\"g\":133},\"f6c9db4b\":{\"m\":134,\"g\":133},\"e88e75a9\":{\"m\":134,\"g\":133},\"4b4050e2\":{\"m\":134,\"g\":133},\"b2803ff2\":{\"m\":134,\"g\":133},\"9e0ef04e\":{\"m\":134,\"g\":133},\"216067c0\":{\"m\":134,\"g\":133},\"e72b02db\":{\"m\":134,\"g\":133},\"29e8f7f9\":{\"m\":134,\"g\":133},\"e0963a6c\":{\"m\":134,\"g\":133},\"e0026f7c\":{\"m\":134,\"g\":133},\"9749d3e3\":{\"m\":134,\"g\":133},\"2b0ddf89\":{\"m\":134,\"g\":133},\"17e81c75\":{\"m\":134,\"g\":133},\"88a405cc\":{\"m\":134,\"g\":133},\"602fe3b2\":{\"m\":134,\"g\":133},\"ad9616f1\":{\"m\":134,\"g\":133},\"c5f4e20f\":{\"m\":134,\"g\":133},\"2c196f95\":{\"m\":134,\"g\":133},\"9a7641d7\":{\"m\":134,\"g\":133},\"793c96c3\":{\"m\":134,\"g\":133},\"d1f00632\":{\"m\":134,\"g\":133},\"4792d1f4\":{\"m\":134,\"g\":133},\"fea2d521\":{\"m\":134,\"g\":133},\"56d12b4a\":{\"m\":134,\"g\":133},\"374ad4cc\":{\"m\":134,\"g\":133},\"8b0a68f1\":{\"m\":134,\"g\":133},\"70607e55\":{\"m\":134,\"g\":133},\"ee1ca51d\":{\"m\":134,\"g\":133},\"ef7c29ac\":{\"m\":134,\"g\":133},\"58c840db\":{\"m\":134,\"g\":133},\"3d42b7e7\":{\"m\":134,\"g\":133},\"9970ee34\":{\"m\":134,\"g\":133},\"9e7656be\":{\"m\":134,\"g\":133},\"9d4f066f\":{\"m\":134,\"g\":133},\"41683536\":{\"m\":134,\"g\":133},\"891ee822\":{\"m\":134,\"g\":133},\"8fa3dc36\":{\"m\"
:134,\"g\":133},\"d20699a3\":{\"m\":134,\"g\":133},\"169a75df\":{\"m\":134,\"g\":133},\"4128d4f5\":{\"m\":134,\"g\":133},\"011d8d89\":{\"m\":134,\"g\":133},\"5290cef9\":{\"m\":134,\"g\":133},\"726fe3e7\":{\"m\":134,\"g\":133},\"5d087891\":{\"m\":134,\"g\":133},\"8451e227\":{\"m\":134,\"g\":133},\"53e15194\":{\"m\":134,\"g\":133},\"d747147a\":{\"m\":134,\"g\":133},\"0c002207\":{\"m\":134,\"g\":133},\"eeb2b9b2\":{\"m\":134,\"g\":133},\"c4aed389\":{\"m\":134,\"g\":133},\"9d04b570\":{\"m\":134,\"g\":133},\"3e690cce\":{\"m\":134,\"g\":133},\"b12b40de\":{\"m\":134,\"g\":133},\"6c4bf8a0\":{\"m\":134,\"g\":133},\"533851fb\":{\"m\":134,\"g\":133},\"0071fe9c\":{\"m\":134,\"g\":133},\"ffa7e035\":{\"m\":134,\"g\":133},\"712f44ee\":{\"m\":134,\"g\":133},\"8c34e181\":{\"m\":134,\"g\":133},\"e9abb525\":{\"m\":134,\"g\":133},\"88859433\":{\"m\":134,\"g\":133},\"feb8e30b\":{\"m\":134,\"g\":133},\"45a959d3\":{\"m\":134,\"g\":133},\"cdce5163\":{\"m\":134,\"g\":133},\"79ab57bd\":{\"m\":134,\"g\":133},\"4b8901ac\":{\"m\":134,\"g\":133},\"435d1c83\":{\"m\":134,\"g\":133},\"2bdbaef1\":{\"m\":134,\"g\":133},\"31d48d7f\":{\"m\":134,\"g\":133},\"0129c911\":{\"m\":134,\"g\":133},\"03f9eb25\":{\"m\":134,\"g\":133},\"7ec678eb\":{\"m\":134,\"g\":133},\"9d64a7b2\":{\"m\":134,\"g\":133},\"46ad4b98\":{\"m\":134,\"g\":133},\"71cb9037\":{\"m\":134,\"g\":133},\"c8c64876\":{\"m\":134,\"g\":133},\"0861dca8\":{\"m\":134,\"g\":133},\"da58df6b\":{\"m\":134,\"g\":133},\"49237e26\":{\"m\":134,\"g\":133},\"d92c1f8c\":{\"m\":134,\"g\":133},\"93070586\":{\"m\":134,\"g\":133},\"ccc8f3b2\":{\"m\":134,\"g\":133},\"a4c76281\":{\"m\":134,\"g\":133},\"28a19e49\":{\"m\":134,\"g\":133},\"0261c4af\":{\"m\":134,\"g\":133},\"8ac350f3\":{\"m\":134,\"g\":133},\"99401e7b\":{\"m\":134,\"g\":133},\"66824751\":{\"m\":134,\"g\":133},\"f95729b0\":{\"m\":134,\"g\":133},\"9f4ed93d\":{\"m\":134,\"g\":133},\"ecb401ed\":{\"m\":134,\"g\":133},\"b399e3ac\":{\"m\":134,\"g\":133},\"e27635a0\":{\"m\":134,\"g\":133},\"272c5fe4\":{\"m\":134,
\"g\":133},\"3c8dc448\":{\"m\":134,\"g\":133},\"5e96beb3\":{\"m\":134,\"g\":133},\"9327482b\":{\"m\":134,\"g\":133},\"36fcf71f\":{\"m\":134,\"g\":133},\"6292d971\":{\"m\":134,\"g\":133},\"c8434195\":{\"m\":134,\"g\":133},\"3e4d431a\":{\"m\":134,\"g\":133},\"538e733e\":{\"m\":134,\"g\":133},\"22587bc0\":{\"m\":134,\"g\":133},\"02d24244\":{\"m\":134,\"g\":133},\"1da5cd63\":{\"m\":134,\"g\":133},\"a9a2cdd8\":{\"m\":134,\"g\":133},\"4733fcff\":{\"m\":134,\"g\":133},\"3ffa2604\":{\"m\":134,\"g\":133},\"61f362c6\":{\"m\":134,\"g\":133},\"e7157c9b\":{\"m\":134,\"g\":133},\"30da2f05\":{\"m\":134,\"g\":133},\"49016931\":{\"m\":134,\"g\":133},\"3d484be5\":{\"m\":134,\"g\":133},\"c0d94440\":{\"m\":134,\"g\":133},\"1dedb638\":{\"m\":134,\"g\":133},\"7bc8b153\":{\"m\":134,\"g\":133},\"9003a436\":{\"m\":134,\"g\":133},\"3518b331\":{\"m\":134,\"g\":133},\"b098b1ae\":{\"m\":134,\"g\":133},\"abd3e048\":{\"m\":134,\"g\":133},\"92c29d43\":{\"m\":134,\"g\":133},\"bf643814\":{\"m\":134,\"g\":133},\"f03bfa4c\":{\"m\":134,\"g\":133},\"89ad3908\":{\"m\":134,\"g\":133},\"2ea844ec\":{\"m\":134,\"g\":133},\"1e2d7538\":{\"m\":134,\"g\":133},\"d16ff357\":{\"m\":134,\"g\":133},\"af49e302\":{\"m\":134,\"g\":133},\"01b955ac\":{\"m\":134,\"g\":133},\"16e6bc20\":{\"m\":134,\"g\":133},\"37250764\":{\"m\":134,\"g\":133},\"7b9156c7\":{\"m\":134,\"g\":133},\"21cfebac\":{\"m\":134,\"g\":133},\"702426b0\":{\"m\":134,\"g\":133},\"9e9a6169\":{\"m\":134,\"g\":133},\"bd9c3a47\":{\"m\":134,\"g\":133},\"1e641ee4\":{\"m\":134,\"g\":133},\"fb96669f\":{\"m\":134,\"g\":133},\"3912ee49\":{\"m\":134,\"g\":133},\"1ab9b8e0\":{\"m\":134,\"g\":133},\"e61dabf5\":{\"m\":134,\"g\":133},\"36e7c8c5\":{\"m\":134,\"g\":133},\"037c3982\":{\"m\":134,\"g\":133},\"62b3fdae\":{\"m\":134,\"g\":133},\"1cd0c3bf\":{\"m\":134,\"g\":133},\"2c899431\":{\"m\":134,\"g\":133},\"4513f549\":{\"m\":134,\"g\":133},\"c9690307\":{\"m\":134,\"g\":133},\"4449c170\":{\"m\":134,\"g\":133},\"5ca962ce\":{\"m\":134,\"g\":133},\"bab20a84\":{\"m\":134,\"g\"
:133},\"0e4108ba\":{\"m\":134,\"g\":133},\"99cb2ed9\":{\"m\":134,\"g\":133},\"0612175c\":{\"m\":134,\"g\":133},\"8c96fcda\":{\"m\":134,\"g\":133},\"ea7c69ce\":{\"m\":134,\"g\":133},\"8102e36b\":{\"m\":134,\"g\":133},\"3f048217\":{\"m\":134,\"g\":133},\"b11af135\":{\"m\":134,\"g\":133},\"f9bceea0\":{\"m\":134,\"g\":133},\"997ea57e\":{\"m\":134,\"g\":133},\"47633c19\":{\"m\":134,\"g\":133},\"64b5c3ab\":{\"m\":134,\"g\":133},\"d277a86d\":{\"m\":134,\"g\":133},\"9acb21ae\":{\"m\":134,\"g\":133},\"a9ce1623\":{\"m\":134,\"g\":133},\"6f0c77d7\":{\"m\":134,\"g\":133},\"e3f51e82\":{\"m\":134,\"g\":133},\"fdfabb7a\":{\"m\":134,\"g\":133},\"19c16748\":{\"m\":134,\"g\":133},\"5c75907e\":{\"m\":134,\"g\":133},\"4ea36422\":{\"m\":134,\"g\":133},\"f50af32d\":{\"m\":134,\"g\":133},\"72952919\":{\"m\":134,\"g\":133},\"2ae5bed1\":{\"m\":134,\"g\":133},\"54df514b\":{\"m\":134,\"g\":133},\"681c68cf\":{\"m\":134,\"g\":133},\"74ea45cc\":{\"m\":134,\"g\":133},\"0fa044ad\":{\"m\":134,\"g\":133},\"7d8e42c9\":{\"m\":134,\"g\":133},\"6abdf73f\":{\"m\":134,\"g\":133},\"20ce9938\":{\"m\":134,\"g\":133},\"168a31eb\":{\"m\":134,\"g\":133},\"69cfb17b\":{\"m\":134,\"g\":133},\"a7a4b175\":{\"m\":134,\"g\":133},\"fdc93b01\":{\"m\":134,\"g\":133},\"fd37cc5d\":{\"m\":134,\"g\":133},\"96705514\":{\"m\":134,\"g\":133},\"ab3ffd1c\":{\"m\":134,\"g\":133},\"2285afff\":{\"m\":134,\"g\":133},\"a81cc1b8\":{\"m\":134,\"g\":133},\"3134d2b2\":{\"m\":134,\"g\":133},\"06b58c5d\":{\"m\":134,\"g\":133},\"ea07a283\":{\"m\":134,\"g\":133},\"ea91a720\":{\"m\":134,\"g\":133},\"5d9c6bac\":{\"m\":134,\"g\":133},\"e048ee90\":{\"m\":134,\"g\":133},\"ed52d01b\":{\"m\":134,\"g\":133},\"90e7d4f7\":{\"m\":134,\"g\":133},\"c20d43d2\":{\"m\":134,\"g\":133},\"d977dd2e\":{\"m\":134,\"g\":133},\"0c23331e\":{\"m\":134,\"g\":133},\"993278b4\":{\"m\":134,\"g\":133},\"9e9d9107\":{\"m\":134,\"g\":133},\"3b8a824b\":{\"m\":134,\"g\":133},\"f1bbd26f\":{\"m\":134,\"g\":133},\"d36299ad\":{\"m\":134,\"g\":133},\"0e7d7969\":{\"m\":134,\"g\":133}
,\"80554598\":{\"m\":134,\"g\":133},\"b2e240bc\":{\"m\":134,\"g\":133},\"dcc5f5c0\":{\"m\":134,\"g\":133},\"4eda4194\":{\"m\":134,\"g\":133},\"875f84db\":{\"m\":134,\"g\":133},\"f6031adf\":{\"m\":134,\"g\":133},\"2a39cfe0\":{\"m\":134,\"g\":133},\"bf17e769\":{\"m\":134,\"g\":133},\"9a5d6a84\":{\"m\":134,\"g\":133},\"31c23e5f\":{\"m\":134,\"g\":133},\"06617a9e\":{\"m\":134,\"g\":133},\"e79ca959\":{\"m\":134,\"g\":133},\"9d3b411c\":{\"m\":134,\"g\":133},\"05325db3\":{\"m\":134,\"g\":133},\"01e3b3f3\":{\"m\":134,\"g\":133},\"86988674\":{\"m\":134,\"g\":133},\"29139654\":{\"m\":134,\"g\":133},\"665cb020\":{\"m\":134,\"g\":133},\"71602838\":{\"m\":134,\"g\":133},\"77873343\":{\"m\":134,\"g\":133},\"267170bf\":{\"m\":134,\"g\":133},\"313f59ad\":{\"m\":134,\"g\":133},\"487cf81a\":{\"m\":134,\"g\":133},\"df111bc0\":{\"m\":134,\"g\":133},\"6d2b3324\":{\"m\":134,\"g\":133},\"8cc77261\":{\"m\":134,\"g\":133},\"d143b020\":{\"m\":134,\"g\":133},\"9b9d2131\":{\"m\":134,\"g\":133},\"44fd7017\":{\"m\":134,\"g\":133},\"9a56273a\":{\"m\":134,\"g\":133},\"b737a125\":{\"m\":134,\"g\":133},\"1b5e9034\":{\"m\":134,\"g\":133},\"171b442a\":{\"m\":134,\"g\":133},\"4b7b5af3\":{\"m\":134,\"g\":133},\"ec242f51\":{\"m\":134,\"g\":133},\"b2431546\":{\"m\":134,\"g\":133},\"526fd008\":{\"m\":134,\"g\":133},\"56d0ad47\":{\"m\":134,\"g\":133},\"306e5b8d\":{\"m\":134,\"g\":133},\"10c68f62\":{\"m\":134,\"g\":133},\"c7c837cd\":{\"m\":134,\"g\":133},\"3e1e7157\":{\"m\":134,\"g\":133},\"8fa8d9d7\":{\"m\":134,\"g\":133},\"c8cf1caf\":{\"m\":134,\"g\":133},\"d71baa72\":{\"m\":134,\"g\":133},\"82e33170\":{\"m\":134,\"g\":133},\"94e12511\":{\"m\":134,\"g\":133},\"4dabfbc8\":{\"m\":134,\"g\":133},\"d7ed8a8c\":{\"m\":134,\"g\":133},\"c05d3afb\":{\"m\":134,\"g\":133},\"edb172e9\":{\"m\":134,\"g\":133},\"fe6d38d2\":{\"m\":134,\"g\":133},\"dab31e4c\":{\"m\":134,\"g\":133},\"a7fa31ff\":{\"m\":134,\"g\":133},\"bd91f882\":{\"m\":134,\"g\":133},\"b47adb80\":{\"m\":134,\"g\":133},\"b2b5bdb0\":{\"m\":134,\"g\":133},\"22
fe5da1\":{\"m\":134,\"g\":133},\"1834401e\":{\"m\":134,\"g\":133},\"198c8ecf\":{\"m\":134,\"g\":133},\"c01b2ee0\":{\"m\":134,\"g\":133},\"8f5adac8\":{\"m\":134,\"g\":133},\"76743a98\":{\"m\":134,\"g\":133},\"8bf10e71\":{\"m\":134,\"g\":133},\"e4873d04\":{\"m\":134,\"g\":133},\"c660d8df\":{\"m\":134,\"g\":133},\"10146af0\":{\"m\":134,\"g\":133},\"b62e7e3b\":{\"m\":134,\"g\":133},\"6ce36b12\":{\"m\":134,\"g\":133},\"e9e7f15e\":{\"m\":134,\"g\":133},\"0aa3dec5\":{\"m\":134,\"g\":133},\"4885f8b9\":{\"m\":134,\"g\":133},\"f832994c\":{\"m\":134,\"g\":133},\"e59435c3\":{\"m\":134,\"g\":133},\"d6bd2d11\":{\"m\":134,\"g\":133},\"9975acf5\":{\"m\":134,\"g\":133},\"70758d45\":{\"m\":134,\"g\":133},\"fd1ebbb0\":{\"m\":134,\"g\":133},\"3d98bd5e\":{\"m\":134,\"g\":133},\"aa3716b2\":{\"m\":134,\"g\":133},\"6107268f\":{\"m\":134,\"g\":133},\"0189f41c\":\"m136\",\"b6e4893a\":\"m136\",\"3eb7da53\":\"m136\",\"cb53ddc9\":\"m136\",\"0c265321\":\"m136\",\"6469c964\":\"m136\",\"71705394\":\"m136\",\"67589c16\":\"m136\",\"a702c8f1\":\"m136\",\"f2ae066a\":\"m136\",\"0c2993ee\":\"m136\",\"590969ee\":\"m136\",\"e6ccb294\":\"m136\",\"d6ea2c52\":\"m136\",\"9be2a3a9\":\"m136\",\"b74a57a8\":\"m136\",\"95f59c13\":\"m136\",\"85d9af51\":\"m136\",\"858f317f\":\"m136\",\"cf893516\":\"m136\",\"1fdf5cac\":\"m136\",\"cda43ffa\":\"m136\",\"39089854\":\"m136\",\"b827e9d3\":\"m136\",\"d725487d\":\"m136\",\"a95c9f5b\":\"m136\",\"19089aa4\":\"m136\",\"4f6f5d25\":\"m136\",\"458fe5a3\":\"m136\",\"2ff0880a\":\"m136\",\"2c1b164a\":\"m136\",\"bcc6d84f\":\"m136\",\"a618202f\":\"m136\",\"7520b929\":\"m136\",\"e7224e96\":\"m136\",\"e776239a\":\"m136\",\"1b97fa76\":\"m136\",\"236772c0\":\"m136\",\"0d49b13f\":\"m136\",\"0a7a2017\":\"m136\",\"0a9099e1\":\"m136\",\"0050c476\":\"m136\",\"8251a74d\":\"m136\",\"a54d75bf\":\"m136\",\"3321eb4e\":\"m136\",\"be5121b4\":\"m136\",\"c3f9c30f\":\"m136\",\"54a82179\":\"m136\",\"aca354bc\":\"m136\",\"aea57b33\":\"m136\",\"823a046e\":\"m136\",\"648aab0c\":\"m136\",\"1e309030\":\"m136\
",\"60927215\":\"m136\",\"38c233fd\":\"m136\",\"20ed3822\":\"m136\",\"4ecd9afd\":\"m136\",\"eb38d644\":\"m136\",\"6ea491e4\":\"m136\",\"16802fb6\":\"m136\",\"d97066d2\":\"m136\",\"ce2d686e\":\"m136\",\"76b06bee\":\"m136\",\"612026ad\":\"m136\",\"55c61642\":\"m136\",\"91a4cd86\":\"m136\",\"23d765d1\":\"m136\",\"db2425a0\":\"m136\",\"603f386c\":\"m136\",\"f7a5e425\":\"m136\",\"d50dcd9b\":\"m136\",\"e6b7c049\":\"m136\",\"a3addd62\":\"m136\",\"6988a0f5\":\"m136\",\"1b192cf1\":\"m136\",\"c560e142\":\"m136\",\"71cb9d03\":\"m136\",\"8fb45523\":\"m136\",\"84aef378\":\"m136\",\"17c04b10\":\"m136\",\"55c4288b\":\"m136\",\"9f8b79f1\":\"m136\",\"7dc3cbe7\":\"m136\",\"79ddc34c\":\"m136\",\"09a9d214\":\"m136\",\"7e40d526\":\"m136\",\"057b07fc\":\"m136\",\"91d8c52d\":\"m136\",\"e9a44ea6\":\"m136\",\"c1282da2\":\"m136\",\"71279e31\":\"m136\",\"1a053a81\":\"m136\",\"2ea02f06\":\"m136\",\"20b0523e\":\"m136\",\"ce8a6ac6\":\"m136\",\"ebca5879\":\"m136\",\"cc410a10\":\"m136\",\"f374623f\":\"m136\",\"5c022177\":\"m136\",\"8916b9d0\":\"m136\",\"a3d9a218\":\"m136\",\"fb88fb67\":\"m136\",\"2d72e168\":\"m136\",\"64946679\":\"m136\",\"5836324c\":\"m136\",\"9fe56cd0\":\"m136\",\"858a4d65\":\"m136\",\"e619f531\":\"m136\",\"fc4b932f\":\"m136\",\"d2105d4a\":\"m136\",\"84c83905\":\"m136\",\"ea879c77\":\"m136\",\"0227db89\":\"m136\",\"ad1b4e47\":\"m136\",\"51f147ad\":\"m136\",\"d3eafc73\":\"m136\",\"e00b4344\":\"m136\",\"733de6be\":\"m136\",\"93433726\":\"m136\",\"330605cc\":\"m136\",\"bb6055b4\":\"m136\",\"4df74eb5\":\"m136\",\"f3a7c7dc\":\"m136\",\"2069050d\":\"m136\",\"1fe0c82f\":\"m136\",\"a45e0e5d\":\"m136\",\"f78201f3\":\"m136\",\"8fd33998\":\"m136\",\"088758c1\":\"m136\",\"6d29d8ab\":\"m136\",\"e499258e\":\"m136\",\"09491a9b\":\"m136\",\"7edb0615\":\"m136\",\"e486a4da\":\"m136\",\"90399cbc\":\"m136\",\"53609e5e\":\"m136\",\"9c253064\":\"m136\",\"d2c86387\":\"m136\",\"737a1183\":\"m136\",\"dc743fe4\":\"m136\",\"dd99f818\":\"m136\",\"8ce64aa1\":\"m136\",\"eb768189\":\"m136\",\"2cdd4370\":\"m13
6\",\"a7b5f75d\":\"m136\",\"305c1a57\":\"m136\",\"c824ddd5\":\"m136\",\"e18e0057\":\"m136\",\"166396ca\":\"m136\",\"43779f27\":\"m136\",\"8b9e9357\":\"m136\",\"d36f6f04\":\"m136\",\"b0701f02\":\"m136\",\"4229de3b\":\"m136\",\"2e144079\":\"m136\",\"d2ec128b\":\"m136\",\"7f8353af\":\"m136\",\"a7f5677a\":\"m136\",\"3e968ab3\":\"m136\",\"b4fce995\":\"m136\",\"ec9b48ea\":\"m136\",\"a0467589\":\"m136\",\"82a1b645\":\"m136\",\"6f10e17b\":\"m136\",\"c771933d\":\"m136\",\"9d8bbd42\":\"m136\",\"daea5138\":\"m136\",\"3355b6e2\":\"m136\",\"a1dd3d48\":\"m136\",\"d9ed80b9\":\"m136\",\"7c39ea68\":\"m136\",\"daa4841e\":\"m136\",\"968c4f55\":\"m136\",\"669d309a\":\"m136\",\"21ee597e\":\"m136\",\"6ee970a3\":\"m136\",\"8ec160ed\":\"m136\",\"2740ed1a\":\"m136\",\"d44f09ad\":\"m136\",\"e7dc85c5\":\"m136\",\"c81bad1b\":\"m136\",\"0e86de7c\":\"m136\",\"e3a95077\":\"m136\",\"146b5fcc\":\"m136\",\"8b22deef\":\"m136\",\"69822c72\":\"m136\",\"8b99af9a\":\"m136\",\"72e2f70e\":\"m136\",\"3d72944f\":\"m136\",\"7dde3438\":\"m136\",\"77fc4c4a\":\"m136\",\"655d2c7c\":\"m136\",\"3f44268f\":\"m136\",\"f7ec8174\":\"m136\",\"cd23c2f0\":\"m136\",\"d1110e1c\":\"m136\",\"9227d9f6\":\"m136\",\"4c59782e\":\"m136\",\"dda35ccb\":\"m136\",\"d11e2dc6\":\"m136\",\"16831ab6\":\"m136\",\"c9a45b7e\":\"m136\",\"e7df8bdc\":\"m136\",\"e9979950\":\"m136\",\"6586f44a\":\"m136\",\"43fe3a4d\":\"m136\",\"98096b5e\":\"m136\",\"68e8d0f6\":\"m136\",\"000ad422\":\"m136\",\"9d5f16d4\":\"m136\",\"7f8a58ff\":\"m136\",\"6b065298\":\"m136\",\"c020d300\":\"m136\",\"4346db5f\":\"m136\",\"424a3800\":\"m136\",\"f0918583\":\"m136\",\"5b1215d9\":\"m136\",\"aa2b4f76\":\"m136\",\"b3a3f513\":\"m136\",\"de94d793\":\"m136\",\"0d904ef4\":\"m136\",\"969faaa4\":\"m136\",\"48c2aca9\":\"m136\",\"5af84c8a\":\"m136\",\"feae615b\":\"m136\",\"e75299a1\":\"m136\",\"72bacc88\":\"m136\",\"ba625c2d\":\"m136\",\"b025cff4\":\"m136\",\"030496eb\":\"m136\",\"c86ca128\":\"m136\",\"cd336945\":\"m136\",\"c5e363e8\":\"m136\",\"9479eca7\":\"m136\",\"a5348eac\":\"m
136\",\"95240402\":\"m136\",\"2122fea3\":\"m136\",\"e2c8a50b\":\"m136\",\"a4825ed5\":\"m136\",\"afe285f7\":\"m136\",\"cf25852a\":\"m136\",\"5938c3b0\":\"m136\",\"b8806071\":\"m136\",\"339915ce\":\"m136\",\"075c5a57\":\"m136\",\"a0b4ba90\":\"m136\",\"2a7b67ad\":\"m136\",\"1d811094\":\"m136\",\"2ab3ed3e\":\"m136\",\"250477d2\":\"m136\",\"af1232b2\":\"m136\",\"7a869045\":\"m136\",\"888d7e54\":\"m136\",\"3cb1fbae\":\"m136\",\"ba9f6d8f\":\"m136\",\"740d3c0b\":\"m136\",\"87165898\":\"m136\",\"ff3ddb9d\":\"m136\",\"9d3018f4\":\"m136\",\"a8348427\":\"m136\",\"47d485f3\":\"m136\",\"2b423099\":\"m136\",\"ae0baefb\":\"m136\",\"1f0e3d7f\":\"m136\",\"d3c08fb0\":\"m136\",\"c6a64e9f\":\"m136\",\"e0ac559a\":\"m136\",\"6620548f\":\"m136\",\"6e158e55\":\"m136\",\"ed729d22\":\"m136\",\"fa51b854\":\"m136\",\"559ff9ec\":\"m136\",\"7b682de8\":\"m136\",\"d0092dec\":\"m136\",\"76f69b77\":\"m136\",\"9a628744\":\"m136\",\"53dca74f\":\"m136\",\"2dadf635\":\"m136\",\"aab640c9\":\"m136\",\"2b3791ed\":\"m136\",\"9f5cd80a\":\"m136\",\"b1ee75ae\":\"m136\",\"f44c63ee\":\"m136\",\"c54c70ab\":\"m136\",\"aab906a3\":\"m136\",\"feb39f77\":\"m136\",\"38b30c7b\":\"m136\",\"c581b5ed\":\"m136\",\"a1c48943\":\"m136\",\"5b7bed7c\":\"m136\",\"38a88479\":\"m136\",\"503c3d95\":\"m136\",\"934ae89a\":\"m136\",\"cf1426a7\":\"m136\",\"17cb3c8e\":\"m136\",\"2f4a6add\":\"m136\",\"f9fc50ac\":\"m136\",\"3c16c586\":\"m136\",\"7b089ae4\":\"m136\",\"8b5d4263\":\"m136\",\"b5493f65\":\"m136\",\"09e2571e\":\"m136\",\"c0248d6f\":\"m136\",\"cc25f9df\":\"m136\",\"cf14feba\":\"m136\",\"7c25687c\":\"m136\",\"ff978142\":\"m136\",\"d112f6a2\":\"m136\",\"a2c2c09d\":\"m136\",\"2a9344d3\":\"m136\",\"78c41758\":\"m136\",\"206db66f\":\"m136\",\"5c72be1e\":\"m136\",\"76d48817\":\"m136\",\"d1ec93e3\":\"m136\",\"bdb76b34\":\"m136\",\"dae6a409\":\"m136\",\"a0899bdb\":\"m136\",\"641830c1\":\"m136\",\"3fd88ea9\":\"m136\",\"145bd54f\":\"m136\",\"2d088b85\":\"m136\",\"3a8b44fe\":\"m136\",\"aeb480c1\":\"m136\",\"9fd2358c\":\"m136\",\"3c358736\":\
"m136\",\"6327dff2\":\"m136\",\"675acece\":\"m136\",\"ad201273\":\"m136\",\"7f393d95\":\"m136\",\"4b14f622\":\"m136\",\"94fc26aa\":\"m136\",\"67b61a4e\":\"m136\",\"d27f16f3\":\"m136\",\"c89949bb\":\"m136\",\"20abaee2\":\"m136\",\"32a569fb\":\"m136\",\"3ed3b7ef\":\"m136\",\"1f9d4795\":\"m136\",\"e6d40bff\":\"m136\",\"fbc128a3\":\"m136\",\"9c64a15a\":\"m136\",\"e91a7176\":\"m136\",\"1f0ea4f9\":\"m136\",\"15da3061\":\"m136\",\"6406a596\":\"m136\",\"cec19b56\":\"m136\",\"ef35d8fe\":\"m136\",\"08636f72\":\"m136\",\"a6c29d4c\":\"m136\",\"84ab32a2\":\"m136\",\"70667115\":\"m136\",\"bd1afeb5\":\"m136\",\"8ef5b905\":\"m136\",\"76b3c698\":\"m136\",\"5dcff947\":\"m136\",\"8eeffbe9\":\"m136\",\"71e9c31c\":\"m136\",\"7656d267\":\"m136\",\"dd24ba90\":\"m136\",\"068abe7e\":\"m136\",\"64a31d4b\":\"m136\",\"2babf88f\":\"m136\",\"d56d14e5\":\"m136\",\"0c4e155a\":\"m136\",\"d6d5c3fd\":\"m136\",\"e46f7943\":\"m136\",\"9d4d57db\":\"m136\",\"41609b52\":\"m136\",\"ccd0fb32\":\"m136\",\"87ee6b5e\":\"m136\",\"f7c1d24b\":\"m136\",\"cceb5e6a\":\"m136\",\"c1c13c84\":\"m136\",\"fcec35dc\":\"m136\",\"75da784d\":\"m136\",\"77d35665\":\"m136\",\"05dfef92\":\"m136\",\"74602407\":{\"m\":136,\"g\":135},\"f7f5c389\":{\"m\":136,\"g\":135},\"9e3a032a\":{\"m\":136,\"g\":135},\"1bc7aa58\":{\"m\":136,\"g\":135},\"8726d30c\":{\"m\":136,\"g\":135},\"a1b243d7\":{\"m\":136,\"g\":135},\"a9799277\":{\"m\":136,\"g\":135},\"ee71e773\":{\"m\":136,\"g\":135},\"55a8dd00\":{\"m\":136,\"g\":135},\"cf242321\":{\"m\":136,\"g\":135},\"9d03af91\":{\"m\":136,\"g\":135},\"064ae341\":{\"m\":136,\"g\":135},\"49305fa1\":{\"m\":136,\"g\":135},\"4e999404\":{\"m\":136,\"g\":135},\"fbc24886\":{\"m\":136,\"g\":135},\"16880235\":{\"m\":136,\"g\":135},\"cda35611\":{\"m\":136,\"g\":135},\"8a45a9c6\":{\"m\":136,\"g\":135},\"aecd5f5f\":{\"m\":136,\"g\":135},\"05ab110e\":{\"m\":136,\"g\":135},\"c8dc4d2d\":{\"m\":136,\"g\":135},\"82a8d77b\":{\"m\":136,\"g\":135},\"294ff71d\":{\"m\":136,\"g\":135},\"f52ae586\":{\"m\":136,\"g\":135},\"2f8a36
34\":{\"m\":136,\"g\":135},\"b6e8a0d8\":{\"m\":136,\"g\":135},\"fb7609f1\":{\"m\":136,\"g\":135},\"d2ea44f7\":{\"m\":136,\"g\":135},\"20ca2c6e\":{\"m\":136,\"g\":135},\"83abecd0\":{\"m\":136,\"g\":135},\"d54f0a10\":{\"m\":136,\"g\":135},\"fb04e7e3\":{\"m\":136,\"g\":135},\"1e5de05e\":{\"m\":136,\"g\":135},\"7dd679cb\":{\"m\":136,\"g\":135},\"3d51ae18\":{\"m\":136,\"g\":135},\"f9c04266\":{\"m\":136,\"g\":135},\"48b8dcd4\":{\"m\":136,\"g\":135},\"41b434a7\":{\"m\":136,\"g\":135},\"4935344f\":{\"m\":136,\"g\":135},\"63cc97f4\":{\"m\":136,\"g\":135},\"ab7d5829\":{\"m\":136,\"g\":135},\"261860e1\":{\"m\":136,\"g\":135},\"154740bd\":{\"m\":136,\"g\":135},\"1c09cbe3\":{\"m\":136,\"g\":135},\"e14f5ec8\":{\"m\":136,\"g\":135},\"8867d248\":{\"m\":136,\"g\":135},\"6b3f93c4\":{\"m\":136,\"g\":135},\"5a5cece5\":{\"m\":136,\"g\":135},\"d566739b\":{\"m\":136,\"g\":135},\"4c46ecde\":{\"m\":136,\"g\":135},\"12a0292b\":{\"m\":136,\"g\":135},\"a08dc5aa\":{\"m\":136,\"g\":135},\"eec7dbd3\":{\"m\":136,\"g\":135},\"109fe03a\":{\"m\":136,\"g\":135},\"bb798a1c\":{\"m\":136,\"g\":135},\"38dc5839\":{\"m\":136,\"g\":135},\"5e867f60\":{\"m\":136,\"g\":135},\"65bed838\":{\"m\":136,\"g\":135},\"156d97b2\":{\"m\":136,\"g\":135},\"24b30f77\":{\"m\":136,\"g\":135},\"3a4767da\":{\"m\":136,\"g\":135},\"6037267f\":{\"m\":136,\"g\":135},\"0241e046\":{\"m\":136,\"g\":135},\"0c474273\":{\"m\":136,\"g\":135},\"3e73e124\":{\"m\":136,\"g\":135},\"f4742558\":{\"m\":136,\"g\":135},\"7385834c\":{\"m\":136,\"g\":135},\"b5a94f8a\":{\"m\":136,\"g\":135},\"c356ed03\":{\"m\":136,\"g\":135},\"ee4d2287\":{\"m\":136,\"g\":135},\"153c69f6\":{\"m\":136,\"g\":135},\"55b79365\":{\"m\":136,\"g\":135},\"7fc12e0b\":{\"m\":136,\"g\":135},\"8729ad5e\":{\"m\":136,\"g\":135},\"e4320573\":{\"m\":136,\"g\":135},\"4d902c82\":{\"m\":136,\"g\":135},\"2ff87231\":{\"m\":136,\"g\":135},\"fd16c91c\":{\"m\":136,\"g\":135},\"32a6540a\":{\"m\":136,\"g\":135},\"62d0280f\":{\"m\":136,\"g\":135},\"d4b717c0\":{\"m\":136,\"g\":135},\"98a107d4\":
{\"m\":136,\"g\":135},\"b86bbf84\":{\"m\":136,\"g\":135},\"48381c3b\":{\"m\":136,\"g\":135},\"8bce0853\":{\"m\":136,\"g\":135},\"6b8a9d70\":{\"m\":136,\"g\":135},\"973116e6\":{\"m\":136,\"g\":135},\"820e97d6\":{\"m\":136,\"g\":135},\"4c85f9d0\":{\"m\":136,\"g\":135},\"7d757d6f\":{\"m\":136,\"g\":135},\"5c04088b\":{\"m\":136,\"g\":135},\"f066036c\":{\"m\":136,\"g\":135},\"52de807d\":{\"m\":136,\"g\":135},\"70933f34\":{\"m\":136,\"g\":135},\"ce453fa4\":{\"m\":136,\"g\":135},\"3be1e734\":{\"m\":136,\"g\":135},\"38895a00\":{\"m\":136,\"g\":135},\"d8b81981\":{\"m\":136,\"g\":135},\"913b688f\":{\"m\":136,\"g\":135},\"951d16c8\":{\"m\":136,\"g\":135},\"53846746\":{\"m\":136,\"g\":135},\"534ac384\":{\"m\":136,\"g\":135},\"2a8d5493\":{\"m\":136,\"g\":135},\"90eac38a\":{\"m\":136,\"g\":135},\"badcd028\":{\"m\":136,\"g\":135},\"d874c8bb\":{\"m\":136,\"g\":135},\"9a21d89c\":{\"m\":136,\"g\":135},\"4c9ac856\":{\"m\":136,\"g\":135},\"fb5b71d0\":{\"m\":136,\"g\":135},\"05b54b6d\":{\"m\":136,\"g\":135},\"4f443f44\":{\"m\":136,\"g\":135},\"dce8b060\":{\"m\":136,\"g\":135},\"399d5283\":{\"m\":136,\"g\":135},\"2e0527dd\":{\"m\":136,\"g\":135},\"6beb50d6\":{\"m\":136,\"g\":135},\"18e2ef09\":{\"m\":136,\"g\":135},\"3271e0e7\":{\"m\":136,\"g\":135},\"d415d22d\":{\"m\":136,\"g\":135},\"a49b9a64\":{\"m\":136,\"g\":135},\"53497642\":{\"m\":136,\"g\":135},\"d57d8e7e\":{\"m\":136,\"g\":135},\"0cbd8f32\":{\"m\":136,\"g\":135},\"6e3fff13\":{\"m\":136,\"g\":135},\"2210155a\":{\"m\":136,\"g\":135},\"a3656cbb\":{\"m\":136,\"g\":135},\"21da2dc1\":{\"m\":136,\"g\":135},\"ed307a40\":{\"m\":136,\"g\":135},\"02722b91\":{\"m\":136,\"g\":135},\"f959250f\":{\"m\":136,\"g\":135},\"95934379\":{\"m\":136,\"g\":135},\"bc2f40be\":{\"m\":136,\"g\":135},\"2724b110\":{\"m\":136,\"g\":135},\"176266f3\":{\"m\":136,\"g\":135},\"5e5b1183\":{\"m\":136,\"g\":135},\"9bf76c11\":{\"m\":136,\"g\":135},\"1d7ad4af\":{\"m\":136,\"g\":135},\"fba785c4\":{\"m\":136,\"g\":135},\"5cfa901b\":{\"m\":136,\"g\":135},\"f27c6cdc\":{\"m\
":136,\"g\":135},\"5097e1e8\":{\"m\":136,\"g\":135},\"861a35fb\":{\"m\":136,\"g\":135},\"6ffe1fc0\":{\"m\":136,\"g\":135},\"73398e22\":{\"m\":136,\"g\":135},\"17958c5f\":{\"m\":136,\"g\":135},\"3aa11ca7\":{\"m\":136,\"g\":135},\"7be1a8c7\":{\"m\":136,\"g\":135},\"84d13c54\":{\"m\":136,\"g\":135},\"c80c0e0f\":{\"m\":136,\"g\":135},\"c105a312\":{\"m\":136,\"g\":135},\"4cf2bbd0\":{\"m\":136,\"g\":135},\"d7b706be\":{\"m\":136,\"g\":135},\"4a9537a4\":{\"m\":136,\"g\":135},\"ca922d4b\":{\"m\":136,\"g\":135},\"402a0bd6\":{\"m\":136,\"g\":135},\"76c71d1d\":{\"m\":136,\"g\":135},\"2d02c150\":{\"m\":136,\"g\":135},\"b98bd9a5\":{\"m\":136,\"g\":135},\"ce694b2b\":{\"m\":136,\"g\":135},\"6c0fb189\":{\"m\":136,\"g\":135},\"9a9f996f\":{\"m\":136,\"g\":135},\"4221b7c5\":{\"m\":136,\"g\":135},\"c371df2f\":{\"m\":136,\"g\":135},\"1751c75b\":{\"m\":136,\"g\":135},\"23849eba\":{\"m\":136,\"g\":135},\"51541404\":{\"m\":136,\"g\":135},\"45ef8344\":{\"m\":136,\"g\":135},\"5a2b1ed4\":{\"m\":136,\"g\":135},\"454dc9e2\":{\"m\":136,\"g\":135},\"1e41069a\":{\"m\":136,\"g\":135},\"2b4d6d81\":{\"m\":136,\"g\":135},\"a3914e3b\":{\"m\":136,\"g\":135},\"130f60ee\":{\"m\":136,\"g\":135},\"4397cda7\":{\"m\":136,\"g\":135},\"d56fd10c\":{\"m\":136,\"g\":135},\"abb06be9\":{\"m\":136,\"g\":135},\"c35eb0fd\":{\"m\":136,\"g\":135},\"7f35c46e\":{\"m\":136,\"g\":135},\"da2f8cc3\":{\"m\":136,\"g\":135},\"4308c25b\":{\"m\":136,\"g\":135},\"4d737db8\":{\"m\":136,\"g\":135},\"9d6029fb\":{\"m\":136,\"g\":135},\"dcfb92dd\":{\"m\":136,\"g\":135},\"bebd625b\":{\"m\":136,\"g\":135},\"b12258bf\":{\"m\":136,\"g\":135},\"012dc586\":{\"m\":136,\"g\":135},\"7f6a678f\":{\"m\":136,\"g\":135},\"399ca037\":{\"m\":136,\"g\":135},\"d93f37a6\":{\"m\":136,\"g\":135},\"e6fe092d\":{\"m\":136,\"g\":135},\"f02d8221\":{\"m\":136,\"g\":135},\"10174e11\":{\"m\":136,\"g\":135},\"2138ff48\":{\"m\":136,\"g\":135},\"e53160bb\":{\"m\":136,\"g\":135},\"4ea6a11c\":{\"m\":136,\"g\":135},\"2181bc9e\":{\"m\":136,\"g\":135},\"520c048d\":{\"m\":136
,\"g\":135},\"561a3e04\":{\"m\":136,\"g\":135},\"f84487af\":{\"m\":136,\"g\":135},\"07827047\":{\"m\":136,\"g\":135},\"c63e9cb2\":{\"m\":136,\"g\":135},\"12cde0df\":{\"m\":136,\"g\":135},\"a7fd8108\":{\"m\":136,\"g\":135},\"ca80c19b\":{\"m\":136,\"g\":135},\"1e7b3264\":{\"m\":136,\"g\":135},\"e267ca0b\":{\"m\":136,\"g\":135},\"9a8ba3c1\":{\"m\":136,\"g\":135},\"0fee6bc6\":{\"m\":136,\"g\":135},\"0ff3747c\":{\"m\":136,\"g\":135},\"f8411ded\":{\"m\":136,\"g\":135},\"249c3563\":{\"m\":136,\"g\":135},\"12df1660\":{\"m\":136,\"g\":135},\"55d112dc\":{\"m\":136,\"g\":135},\"a1ed247f\":{\"m\":136,\"g\":135},\"87699d48\":{\"m\":136,\"g\":135},\"52c60434\":{\"m\":136,\"g\":135},\"ff0f370f\":{\"m\":136,\"g\":135},\"26f9e207\":{\"m\":136,\"g\":135},\"828cd893\":{\"m\":136,\"g\":135},\"4436dc0f\":{\"m\":136,\"g\":135},\"cf6800f6\":{\"m\":136,\"g\":135},\"f16606d6\":{\"m\":136,\"g\":135},\"387fad2f\":{\"m\":136,\"g\":135},\"76bc07a3\":{\"m\":136,\"g\":135},\"b328cd20\":{\"m\":136,\"g\":135},\"bf32cd83\":{\"m\":136,\"g\":135},\"27a08305\":{\"m\":136,\"g\":135},\"53479e22\":{\"m\":136,\"g\":135},\"ff978e7d\":{\"m\":136,\"g\":135},\"16e00651\":{\"m\":136,\"g\":135},\"1b2b95d8\":{\"m\":136,\"g\":135},\"9cac3c86\":{\"m\":136,\"g\":135},\"8d58b3dc\":{\"m\":136,\"g\":135},\"5f3eb377\":{\"m\":136,\"g\":135},\"22993880\":{\"m\":136,\"g\":135},\"d7aa0ce7\":{\"m\":136,\"g\":135},\"e797f0c5\":{\"m\":136,\"g\":135},\"5d4b7c78\":{\"m\":136,\"g\":135},\"9bd64d73\":{\"m\":136,\"g\":135},\"24e116ef\":{\"m\":136,\"g\":135},\"8ca95970\":{\"m\":136,\"g\":135},\"216ea910\":{\"m\":136,\"g\":135},\"f4ab2ec5\":{\"m\":136,\"g\":135},\"25fa2ac2\":{\"m\":136,\"g\":135},\"ef5ac6f0\":{\"m\":136,\"g\":135},\"c88aaf22\":{\"m\":136,\"g\":135},\"e139d2aa\":{\"m\":136,\"g\":135},\"2337b1bb\":{\"m\":136,\"g\":135},\"7bc13c90\":{\"m\":136,\"g\":135},\"877c8e3a\":{\"m\":136,\"g\":135},\"66dfb8c1\":{\"m\":136,\"g\":135},\"dcacc492\":{\"m\":136,\"g\":135},\"87ef05e2\":{\"m\":136,\"g\":135},\"b65c9889\":{\"m\":136,\"g\
":135},\"c7c0d97f\":{\"m\":136,\"g\":135},\"6bc5a52f\":{\"m\":136,\"g\":135},\"2c09de34\":{\"m\":136,\"g\":135},\"38d48de9\":{\"m\":136,\"g\":135},\"7f2fa216\":{\"m\":136,\"g\":135},\"d0fb24ee\":{\"m\":136,\"g\":135},\"65b0b5b2\":{\"m\":136,\"g\":135},\"9a414b16\":{\"m\":136,\"g\":135},\"5b4f7902\":{\"m\":136,\"g\":135},\"d8ac5eec\":{\"m\":136,\"g\":135},\"bdde9496\":{\"m\":136,\"g\":135},\"b23e7ed1\":{\"m\":136,\"g\":135},\"9821fae5\":{\"m\":136,\"g\":135},\"078d9621\":{\"m\":136,\"g\":135},\"8b111b20\":{\"m\":136,\"g\":135},\"74a166cb\":{\"m\":136,\"g\":135},\"888e126a\":{\"m\":136,\"g\":135},\"bb23a8fe\":{\"m\":136,\"g\":135},\"8b869e32\":{\"m\":136,\"g\":135},\"6256936d\":{\"m\":136,\"g\":135},\"62f73a8c\":{\"m\":136,\"g\":135},\"7c1b4b1c\":{\"m\":136,\"g\":135},\"31ed68e7\":{\"m\":136,\"g\":135},\"24c91001\":{\"m\":136,\"g\":135},\"c4edcac6\":{\"m\":136,\"g\":135},\"e9343389\":{\"m\":136,\"g\":135},\"a2d4f58a\":{\"m\":136,\"g\":135},\"f66b0916\":{\"m\":136,\"g\":135},\"2f623368\":{\"m\":136,\"g\":135},\"bd9a2ced\":{\"m\":136,\"g\":135},\"0ca417d9\":{\"m\":136,\"g\":135},\"30cfb687\":{\"m\":136,\"g\":135},\"b7c7e03d\":{\"m\":136,\"g\":135},\"f0195627\":{\"m\":136,\"g\":135},\"d401d238\":{\"m\":136,\"g\":135},\"17041f46\":{\"m\":136,\"g\":135},\"f07e76b2\":{\"m\":136,\"g\":135},\"1cfd2b2d\":{\"m\":136,\"g\":135},\"0d244116\":{\"m\":136,\"g\":135},\"0eae8317\":{\"m\":136,\"g\":135},\"c483a5f4\":{\"m\":136,\"g\":135},\"f26f6c2c\":{\"m\":136,\"g\":135},\"b0213323\":{\"m\":136,\"g\":135},\"5062537b\":{\"m\":136,\"g\":135},\"698629d1\":{\"m\":136,\"g\":135},\"dd93e445\":{\"m\":136,\"g\":135},\"6c8587b5\":{\"m\":136,\"g\":135},\"02704260\":{\"m\":136,\"g\":135},\"bd48ad5e\":{\"m\":136,\"g\":135},\"749736ba\":{\"m\":136,\"g\":135},\"72549863\":{\"m\":136,\"g\":135},\"00562ee1\":{\"m\":136,\"g\":135},\"d7a8257b\":{\"m\":136,\"g\":135},\"f6f7af40\":{\"m\":136,\"g\":135},\"a3b1e8ef\":{\"m\":136,\"g\":135},\"db499e18\":{\"m\":136,\"g\":135},\"6cf3a6dd\":{\"m\":136,\"g\":135
},\"90e24f5c\":{\"m\":136,\"g\":135},\"21de3e14\":{\"m\":136,\"g\":135},\"e4c1e441\":{\"m\":136,\"g\":135},\"b5af283b\":{\"m\":136,\"g\":135},\"130c6911\":{\"m\":136,\"g\":135},\"e0e50848\":{\"m\":136,\"g\":135},\"70a769bc\":{\"m\":136,\"g\":135},\"3a42c5e3\":{\"m\":136,\"g\":135},\"12b89e51\":{\"m\":136,\"g\":135},\"85184557\":{\"m\":136,\"g\":135},\"57d2ba92\":{\"m\":136,\"g\":135},\"417e75a6\":{\"m\":136,\"g\":135},\"0500fea9\":{\"m\":136,\"g\":135},\"2b461c15\":{\"m\":136,\"g\":135},\"5595ae14\":{\"m\":136,\"g\":135},\"ace6f300\":{\"m\":136,\"g\":135},\"2da49eec\":{\"m\":136,\"g\":135},\"d65ae0ec\":{\"m\":136,\"g\":135},\"60d7279c\":{\"m\":136,\"g\":135},\"abdf65d4\":{\"m\":136,\"g\":135},\"c1dfbc77\":{\"m\":136,\"g\":135},\"3b3c5a05\":{\"m\":136,\"g\":135},\"2667c857\":{\"m\":136,\"g\":135},\"fc643ffb\":{\"m\":136,\"g\":135},\"4280a18a\":{\"m\":136,\"g\":135},\"386e5415\":{\"m\":136,\"g\":135},\"b6267de5\":{\"m\":136,\"g\":135},\"e47afa02\":{\"m\":136,\"g\":135},\"5ed384d0\":{\"m\":136,\"g\":135},\"b4ce7a6d\":{\"m\":136,\"g\":135},\"25b48564\":{\"m\":136,\"g\":135},\"1c360bf7\":{\"m\":136,\"g\":135},\"3619ec61\":{\"m\":136,\"g\":135},\"db6b51a8\":{\"m\":136,\"g\":135},\"8a9ca41f\":{\"m\":136,\"g\":135},\"ac11e6a7\":{\"m\":136,\"g\":135},\"47a660d5\":{\"m\":136,\"g\":135},\"c0fc7a89\":{\"m\":136,\"g\":135},\"5bf0d862\":{\"m\":136,\"g\":135},\"75b72eb8\":{\"m\":136,\"g\":135},\"bc8b526e\":{\"m\":136,\"g\":135},\"3dfff6ae\":{\"m\":136,\"g\":135},\"ad2c1ee3\":{\"m\":136,\"g\":135},\"9940c6f5\":{\"m\":136,\"g\":135},\"00e60711\":{\"m\":136,\"g\":135},\"4bc2f2e0\":{\"m\":136,\"g\":135},\"d17b9e63\":{\"m\":136,\"g\":135},\"ba67e006\":{\"m\":136,\"g\":135},\"45f3ad2f\":{\"m\":136,\"g\":135},\"f35b5da5\":{\"m\":136,\"g\":135},\"733a0c1a\":{\"m\":136,\"g\":135},\"b369aaa2\":{\"m\":136,\"g\":135},\"b9732025\":{\"m\":136,\"g\":135},\"cbff7ad9\":{\"m\":136,\"g\":135},\"4de59d83\":{\"m\":136,\"g\":135},\"39ca57cd\":{\"m\":136,\"g\":135},\"34498067\":{\"m\":136,\"g\":135},\"b
3817fa9\":{\"m\":136,\"g\":135},\"059428bd\":{\"m\":136,\"g\":135},\"5d200dd8\":{\"m\":136,\"g\":135},\"7f9a3d06\":{\"m\":136,\"g\":135},\"664f611e\":{\"m\":136,\"g\":135},\"49adb37e\":{\"m\":136,\"g\":135},\"7518dc35\":{\"m\":136,\"g\":135},\"b6871ba7\":{\"m\":136,\"g\":135},\"c1f2241a\":\"m137\",\"8da70e2a\":\"m137\"},\"g\":\"2026-02-12T19:54:15.483852\"}\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "ipykernel\nipywidgets\njupyter_client\nmarkdown>=3.4.0\nmatplotlib\nmyst-parser\nnbconvert\nnbsphinx\npandoc\npillow\npydantic\nsphinx\nsphinx-book-theme\nsphinx-copybutton\nsphinx-tabs\nnbstripout\nsphinxcontrib-mermaid\nurllib3<2.0.0\ngguf>=0.17.1\nsphinx-autobuild\n"
  },
  {
    "path": "docs/serve.sh",
    "content": "# Clean and serve documentation with auto-build\nmake clean\nmake serve\n"
  },
  {
    "path": "docs/supported_models/extending/index.rst",
    "content": "Extending SGLang\n================\n\nAdding new models and alternative backends.\n\n.. toctree::\n   :maxdepth: 1\n\n   support_new_models.md\n   transformers_fallback.md\n   modelscope.md\n   mindspore_models.md\n"
  },
  {
    "path": "docs/supported_models/extending/mindspore_models.md",
    "content": "# MindSpore Models\n\n## Introduction\n\nMindSpore is a high-performance AI framework optimized for Ascend NPUs. This document explains how to run MindSpore models in SGLang.\n\n## Requirements\n\nMindSpore currently supports only Ascend NPU devices. You must first install Ascend CANN 8.5.\nThe CANN software packages can be downloaded from the [Ascend Official Website](https://www.hiascend.com).\n\n## Supported Models\n\nThe following models are currently supported:\n\n- **Qwen3**: Dense and MoE models\n- **DeepSeek V3/R1**\n- *More models coming soon...*\n\n## Installation\n\n> **Note**: MindSpore models are currently provided by a separate package, `sgl-mindspore`. MindSpore support is built on SGLang's existing support for the Ascend NPU platform. First [install SGLang for Ascend NPU](../../platforms/ascend_npu.md), then install `sgl-mindspore`:\n\n```shell\ngit clone https://github.com/mindspore-lab/sgl-mindspore.git\ncd sgl-mindspore\npip install -e .\n```\n\n## Run Model\n\nSGLang-MindSpore currently supports Qwen3 and DeepSeek V3/R1 models. 
This doc uses Qwen3-8B as an example.\n\n### Offline inference\n\nUse the following script for offline inference:\n\n```python\nimport sglang as sgl\n\n# Initialize the engine with the MindSpore backend\nllm = sgl.Engine(\n    model_path=\"/path/to/your/model\",  # Local model path\n    device=\"npu\",                      # Use NPU device\n    model_impl=\"mindspore\",            # MindSpore implementation\n    attention_backend=\"ascend\",        # Attention backend\n    tp_size=1,                         # Tensor parallelism size\n    dp_size=1                          # Data parallelism size\n)\n\n# Generate text\nprompts = [\n    \"Hello, my name is\",\n    \"The capital of France is\",\n    \"The future of AI is\"\n]\n\nsampling_params = {\"temperature\": 0, \"top_p\": 0.9}\noutputs = llm.generate(prompts, sampling_params)\n\nfor prompt, output in zip(prompts, outputs):\n    print(f\"Prompt: {prompt}\")\n    print(f\"Generated: {output['text']}\")\n    print(\"---\")\n```\n\n### Start server\n\nLaunch a server with the MindSpore backend:\n\n```bash\n# Basic server startup\npython3 -m sglang.launch_server \\\n    --model-path /path/to/your/model \\\n    --host 0.0.0.0 \\\n    --device npu \\\n    --model-impl mindspore \\\n    --attention-backend ascend \\\n    --tp-size 1 \\\n    --dp-size 1\n```\n\nFor a distributed server across multiple nodes (the command below is for node 0; run it on the second node with `--node-rank 1`):\n\n```bash\n# Multi-node distributed server\npython3 -m sglang.launch_server \\\n    --model-path /path/to/your/model \\\n    --host 0.0.0.0 \\\n    --device npu \\\n    --model-impl mindspore \\\n    --attention-backend ascend \\\n    --dist-init-addr 127.0.0.1:29500 \\\n    --nnodes 2 \\\n    --node-rank 0 \\\n    --tp-size 4 \\\n    --dp-size 2\n```\n\n## Troubleshooting\n\n### Debug Mode\n\nEnable SGLang debug logging with the `--log-level` argument:\n\n```bash\npython3 -m sglang.launch_server \\\n    --model-path /path/to/your/model \\\n    --host 0.0.0.0 \\\n    --device npu \\\n    --model-impl mindspore \\\n    
--attention-backend ascend \\\n    --log-level DEBUG\n```\n\nEnable MindSpore INFO or DEBUG logging by setting the `GLOG_v` environment variable:\n\n```bash\n# Choose one:\nexport GLOG_v=1  # INFO\nexport GLOG_v=0  # DEBUG\n```\n\n### Explicitly select devices\n\nUse the following environment variable to explicitly select which devices to use:\n\n```shell\nexport ASCEND_RT_VISIBLE_DEVICES=4,5,6,7  # use devices 4, 5, 6, and 7\n```\n\n### Communication environment issues\n\nIn some environments with special communication setups, you need to set the following environment variable:\n\n```shell\nexport MS_ENABLE_LCCL=off  # LCCL communication mode is not currently supported in SGLang-MindSpore\n```\n\n### Protobuf dependency issues\n\nIn some environments with a particular protobuf version installed, set the following environment variable to avoid a protobuf binary version mismatch:\n\n```shell\nexport PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python  # to avoid protobuf binary version mismatch\n```\n\n## Support\n\nFor MindSpore-specific issues:\n\n- Refer to the [MindSpore documentation](https://www.mindspore.cn/)\n"
  },
  {
    "path": "docs/supported_models/extending/modelscope.md",
    "content": "# Use Models From ModelScope\n\nTo use a model from [ModelScope](https://www.modelscope.cn), set the environment variable `SGLANG_USE_MODELSCOPE`.\n\n```bash\nexport SGLANG_USE_MODELSCOPE=true\n```\n\nWe take [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) as an example.\n\nLaunch the server:\n\n```bash\npython -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000\n```\n\nOr start it with Docker:\n\n```bash\ndocker run --gpus all \\\n    -p 30000:30000 \\\n    -v ~/.cache/modelscope:/root/.cache/modelscope \\\n    --env \"SGLANG_USE_MODELSCOPE=true\" \\\n    --ipc=host \\\n    lmsysorg/sglang:latest \\\n    python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 30000\n```\n\nNote that ModelScope uses a different cache directory than Hugging Face. You may need to set the ModelScope cache directory manually to avoid running out of disk space.\n"
  },
  {
    "path": "docs/supported_models/extending/support_new_models.md",
    "content": "# How to Support New Models\n\nThis document explains how to add support for new language models and multimodal large language models (MLLMs) in\nSGLang. It also covers how to test new models and register external implementations.\n\n## How to Support a New Language Model\n\nTo support a new model in SGLang, you only need to add a single file under\nthe [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models). You can learn\nfrom existing model implementations and create a new file for your model. For most models, you should be able to find a\nsimilar model to start with (e.g., starting from Llama). Also see how\nto [port a model from vLLM to SGLang](#port-a-model-from-vllm-to-sglang).\n\n## How to Support a New Multimodal Large Language Model\n\nTo support a new multimodal large language model (MLLM) in SGLang, there are several key components in addition to the\nstandard LLM support:\n\n1. **Register your new model as multimodal**:\n   Extend `is_multimodal_model`\n   in [model_config.py](https://github.com/sgl-project/sglang/blob/0ab3f437aba729b348a683ab32b35b214456efc7/python/sglang/srt/configs/model_config.py#L561)\n   to return `True` for your model.\n\n2. **Register a new chat template**:\n   This step is needed only if your default chat template cannot accept images as input. In that case, register a new chat template in [conversation.py](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/conversation.py) along with its corresponding matching function.\n\n3. **Multimodal Data Processor**:\n   Define a new `Processor` class that inherits from `BaseMultimodalProcessor` and register this processor as your\n   model’s dedicated processor.\n   See [multimodal_processor.py](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/multimodal/processors)\n   for more details.\n\n4. **Handle Multimodal Tokens**:\n   Implement a `pad_input_ids` function for your new model. 
In this function, multimodal tokens in the prompt should be\n   expanded (if necessary) and padded with multimodal-data-hashes so that SGLang can recognize different multimodal data\n   with `RadixAttention`.\n\n5. **Handle Image Feature Extraction**:\n   Implement a `get_image_feature` function for your new model, which extracts image features from raw image data and converts them into the embeddings used by the language model.\n\n6. **Adapt to Vision Attention**:\n   Replace the multi-headed `Attention` of the ViT with SGLang’s `VisionAttention`.\n\nYou can refer to [Qwen2VL](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/qwen2_vl.py) or\nother MLLM implementations. These models demonstrate how to correctly handle both multimodal and textual inputs.\n\n## Testing and Debugging\n\nPlease include all your testing and benchmarking results in the PR description.\n\n### Interactive Debugging\n\nFor interactive debugging, compare the outputs of Hugging Face/Transformers and SGLang. The following two commands\nshould give the same text output and very similar prefill logits:\n\n- Get the reference output:\n  ```bash\n  python3 scripts/playground/reference_hf.py --model-path [new model] --model-type {text,vlm}\n  ```\n- Get the SGLang output:\n  ```bash\n  python3 -m sglang.bench_one_batch --correct --model [new model]\n  ```\n\n### Add the Model to the Test Suite\n\nTo ensure the new model is well maintained, add it to the test suite by including it in the `ALL_OTHER_MODELS` list in\nthe [test_generation_models.py](https://github.com/sgl-project/sglang/blob/main/test/registered/models/test_generation_models.py)\nfile. Then test the new model on your local machine and report the results on demonstrative benchmarks (GSM8K, MMLU, MMMU,\nMMMU-Pro, etc.) in your PR.\n\nFor VLMs, also include a test in `test_vision_openai_server_{x}.py` (e.g. 
[test_vision_openai_server_a.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_vision_openai_server_a.py), [test_vision_openai_server_b.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_vision_openai_server_b.py)).\n\nHere is an example command for testing a new model on your local machine:\n\n```bash\nONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others\n```\n\n### Benchmark\n\n- **(Required) MMMU**: follow the MMMU benchmark [README.md](https://github.com/sgl-project/sglang/blob/main/benchmark/mmmu/README.md) to compare the accuracy of SGLang and HF Transformers. The accuracy score of the SGLang run should not be much lower than that of the HF Transformers run. Similarly, follow https://docs.sglang.io/developer_guide/benchmark_and_profiling.html to compare performance: TTFT and throughput must meet or exceed the baselines (e.g., HF Transformers).\n- **(Optional) Other evals**: If you ran other evals, please include the results in the PR description.\n\n## Port a Model from vLLM to SGLang\n\nThe [vLLM Models Directory](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models) is a valuable\nresource, as vLLM covers many models. 
SGLang reuses vLLM’s interface and some layers, making it easier to port models\nfrom vLLM to SGLang.\n\nTo port a model from vLLM to SGLang:\n\n- Compare these two files for guidance:\n  - [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py)\n  - [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py)\n- The major differences include:\n  - **Replace vLLM’s `Attention` with `RadixAttention`** (ensure you pass `layer_id` to `RadixAttention`).\n  - **Replace vLLM’s `LogitsProcessor` with SGLang’s `LogitsProcessor`.**\n  - **Replace the multi-headed `Attention` of ViT with SGLang’s `VisionAttention`.**\n  - **Replace other vLLM layers** (such as `RMSNorm`, `SiluAndMul`) with SGLang layers.\n  - **Remove `Sample`.**\n  - **Change the `forward()` functions** and add a `forward_batch()` method.\n  - **Add `EntryClass`** at the end.\n  - **Ensure that the new implementation uses only SGLang components** and does not rely on any vLLM components.\n\nNote: make sure you add your new model to the supported models list in the supported models documentation.\n\n## Registering an External Model Implementation\n\nIn addition to the methods above, you can register your new model with the `ModelRegistry` before launching the server.\nThis allows you to integrate your model without modifying the source code.\n\nFor example:\n\n```python\nfrom sglang.srt.models.registry import ModelRegistry\nfrom sglang.srt.entrypoints.http_server import launch_server\n\n# For a single model, add it to the registry:\nModelRegistry.models[model_name] = model_class\n\n# For multiple models, you can imitate the import_model_classes() function:\nfrom functools import lru_cache\n\n@lru_cache()\ndef import_new_model_classes():\n    model_arch_name_to_cls = {}\n    # Populate model_arch_name_to_cls with your new model classes.\n    ...\n    return 
model_arch_name_to_cls\n\nModelRegistry.models.update(import_new_model_classes())\n\n# Launch the server with your server arguments:\nlaunch_server(server_args)\n```\n\n## Example: Implementing and Serving a Llama Wrapper Model\n\nBelow is an introductory, step-by-step walkthrough on how to implement a new model end-to-end in SGLang and then run it via the [Offline Engine](https://github.com/sgl-project/sglang/blob/main/docs/basic_usage/offline_engine_api.ipynb).\n\n### Implementing Our Model\n\nTo keep things simple, this new model will be a thin wrapper around [Llama 3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), and its only goal is to bias the output logits of each `forward` call by taking the square root of each positive logit.\n\nLet's start by defining our model in a file called `llama_wrapper.py`.\nThe first step is to import the necessary libraries from SRT, which is SGLang's internal backend.\n\n```python\n# In the file `llama_wrapper.py`\n\nimport torch\nfrom transformers import LlamaConfig\nfrom typing import Optional\nfrom sglang.srt.layers.logits_processor import LogitsProcessorOutput\nfrom sglang.srt.layers.quantization.base_config import QuantizationConfig\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors\n\nfrom sglang.srt.models.llama import LlamaForCausalLM\n```\n\nNext, we declare a new `class` for our model and have it inherit from `LlamaForCausalLM`, which allows our model to access `LlamaForCausalLM`'s predefined modules and layers, such as `LlamaAttention` and `LlamaMLP`.\nNote that almost all model implementations take in `config` and `quant_config` as arguments for their `__init__` method; `config` and `quant_config` are passed in via [`model_loader/loader.py`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_loader/loader.py#L219).\nBecause we have inherited from `LlamaForCausalLM`, we can pass our 
parameters directly to its constructor, which will set the member variables for us.\n\n```python\nclass LlamaWrapper(LlamaForCausalLM):\n    def __init__(\n        self,\n        config: LlamaConfig,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__(config=config, quant_config=quant_config, prefix=prefix)\n```\n\nNow, we want to define the `forward` method, which is what will be called at inference time.\nNote that the signature for `forward` is essentially the same for any model; you can take a look at the other models defined in the [`models` directory](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/) for reference.\nTo see where exactly `forward` is called in the SGLang runtime's internals, take a look at [`forward_decode`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1705) and [`forward_extend`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1724) in the [`ModelRunner` class](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/model_runner.py).\n\n```python\n    @torch.no_grad()\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        positions: torch.Tensor,\n        forward_batch: ForwardBatch,\n        pp_proxy_tensors: Optional[PPProxyTensors] = None,\n        input_embeds: Optional[torch.Tensor] = None,\n        get_embedding: bool = False,\n    ) -> LogitsProcessorOutput:\n```\n\nWe now call the `__call__` method for `self.model` (the inner `LlamaModel` that `LlamaForCausalLM` creates in its `__init__` method), which eventually calls that inner model's `forward` method.\nAfter that, we feed the `hidden_states` into our model's `LogitsProcessor` (again defined in `LlamaForCausalLM`).\n\n```python\n        hidden_states = 
self.model(\n            input_ids,\n            positions,\n            forward_batch,\n            input_embeds,\n            pp_proxy_tensors=pp_proxy_tensors,\n        )\n\n        res: LogitsProcessorOutput = self.logits_processor(\n            input_ids,\n            hidden_states,\n            self.lm_head,\n            forward_batch,\n        )\n```\n\nAfter receiving the logits for the next token, we can finally perform our biasing step.\n\n```python\n        orig_logits = res.next_token_logits\n        res.next_token_logits = torch.where(\n            orig_logits > 0,\n            orig_logits.sqrt(),\n            orig_logits\n        )\n\n        return res\n```\n\nNow, our `LlamaWrapper` model is created and ready to be served!\n\n### Serving Our Model Via SGLang's Offline Engine\n\nThe next step of this walkthrough involves hosting our new model offline, so that it can be served locally and without an HTTP server.\n\nFirst, create a new file called `run.py`.\nNow, we must ensure that SGLang's `ModelRegistry` can find our model.\nTo do this, we first download the model's configuration and weights from Hugging Face.\n\n```python\n# In the file `run.py`\n\nimport asyncio\nfrom functools import lru_cache\nfrom huggingface_hub import snapshot_download\nfrom llama_wrapper import LlamaWrapper # Make sure to import our new model!\nimport sglang as sgl\nfrom sglang.srt.models.registry import ModelRegistry\n\n# Make sure to request access to this model on Hugging Face, then export your\n# `HF_TOKEN` to download the model snapshot\nllama_dir = snapshot_download(\n    repo_id=\"meta-llama/Llama-3.1-8B-Instruct\",\n    local_dir=\"./llama_ckpt\",\n)\n```\n\nNow that we have our model on disk, we want to point it to `LlamaWrapper` by changing the `architectures` field in `./llama_ckpt/config.json` to be `LlamaWrapper`.\nThat way, when we pass in the path of our model checkpoint to SGLang, it will know that we want to use \"LlamaWrapper\" instead of \"LlamaForCausalLM\" 
as our model.\n\n```python\n{\n  \"architectures\": [\n   #  \"LlamaForCausalLM\"\n    \"LlamaWrapper\"\n  ],\n  ...\n}\n```\n\nHowever, if we don't link our `LlamaWrapper` class to the \"LlamaWrapper\" registry keyword, then SGLang won't be able to find our model.\nThus, to register our `LlamaWrapper`, we want to follow the steps in the above section titled \"Registering an External Model Implementation\".\n\n```python\n@lru_cache()\ndef import_new_model_classes():\n    model_arch_name_to_cls = {\"LlamaWrapper\": LlamaWrapper}\n    return model_arch_name_to_cls\n\nModelRegistry.models.update(import_new_model_classes())\n```\n\nLastly, when we create our `Engine`, we just pass in the path to the local model directory.\nThen, our `LlamaWrapper` is ready to be served; for this walkthrough, we will use SGLang `Engine`'s non-streaming asynchronous generation endpoint.\n\n```python\ndef main():\n    llm = sgl.Engine(model_path=\"./llama_ckpt\")\n    sampling_params = {\"temperature\": 0.2, \"top_k\": 5}\n    prompts = [\n        \"Write a short, neutral self-introduction for a fictional character. Hello, my name is\",\n        \"Provide a concise factual statement about France’s capital city. The capital of France is\",\n        \"Explain possible future trends in artificial intelligence. 
The future of AI is\",\n    ]\n\n    asyncio.run(run_llm(llm, sampling_params, prompts))\n\n    llm.shutdown()\n\nasync def run_llm(\n    llm,\n    sampling_params,\n    prompts,\n) -> None:\n    outputs = await llm.async_generate(prompts, sampling_params)\n\n    for prompt, output in zip(prompts, outputs):\n        print(f\"\\nPrompt: {prompt}\")\n        print(f\"Generated text: {output['text']}\")\n\nif __name__ == \"__main__\":\n    main()\n```\n\nNow, running `python run.py` prints the outputs of our newly created model!\n\n## Documentation\n\nAdd the new model to the table of supported models in [generative_models.md](../text_generation/generative_models.md) or [multimodal_language_models.md](../text_generation/multimodal_language_models.md).\n\n---\n\nBy following these guidelines, you can add support for new language models and multimodal large language models in\nSGLang and ensure they are thoroughly tested and easily integrated into the system.\n"
  },
  {
    "path": "docs/supported_models/extending/transformers_fallback.md",
    "content": "# Transformers fallback in SGLang\n\n`sglang` can fall back to using models that are available in `transformers`. This works for most decoder-style language models, and support for vision-language models is coming soon!\n\n## Example launch Command\n\nBy default, SGLang uses its own implementation of a model if one is available; otherwise, it falls back to the `transformers` implementation. You can also force the `transformers` implementation by setting `--model-impl` to `transformers`.\n\n```shell\npython3 -m sglang.launch_server \\\n  --model-path meta-llama/Llama-3.2-1B-Instruct \\\n  --host 0.0.0.0 \\\n  --port 30000 \\\n  --model-impl transformers\n```\n\n## Supported features\n\n### Quantization\n\nThe Transformers fallback supports most of the quantization methods available in SGLang (except GGUF). See the [Quantization page](../advanced_features/quantization.md) for more information about supported quantization in SGLang.\n\n### Remote code\n\nThis fallback also means that any model on the Hugging Face Hub that works in `transformers` with `trust_remote_code=True` and correctly implements attention can be used in production!\n\nA model just needs the following two things: its attention module must look up the attention function in `ALL_ATTENTION_FUNCTIONS` and forward `**kwargs` to it, and the model class must set `_supports_attention_backend = True`.\n\n```python\nfrom torch import nn\nfrom transformers import PreTrainedModel\nfrom transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n\nclass MyAttention(nn.Module):\n\n  def forward(self, hidden_states, **kwargs): # <- kwargs are required\n\n    ...\n    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n    attn_output, attn_weights = attention_interface(\n      self,\n      query_states,\n      key_states,\n      value_states,\n      **kwargs,\n    )\n    ...\n\nclass MyModel(PreTrainedModel):\n  _supports_attention_backend = True\n```\n\nHere is what happens in the background:\n\n1. The config is loaded.\n2. The `MyModel` Python class is loaded from the `auto_map`, and we check that the model sets `_supports_attention_backend`.\n3. The `TransformersModel` backend is used. 
See `/srt/models/transformers`, which sets `self.config._attn_implementation = \"sglang\"`; this is why the model must look up its attention function in `ALL_ATTENTION_FUNCTIONS`.\n\nThat's it!\n"
  },
  {
    "path": "docs/supported_models/index.rst",
    "content": "Supported Models\n================\n\nSGLang supports a wide variety of model architectures for different use cases.\nBrowse by category below to find models suited for your needs.\n\n.. toctree::\n   :maxdepth: 2\n\n   text_generation/index\n   retrieval_ranking/index\n   specialized/index\n   extending/index\n"
  },
  {
    "path": "docs/supported_models/retrieval_ranking/classify_models.md",
    "content": "# Classification API\n\nThis document describes the `/v1/classify` API endpoint implementation in SGLang, which is compatible with vLLM's classification API format.\n\n## Overview\n\nThe classification API allows you to classify text inputs using classification models. This implementation follows the same format as vLLM's 0.7.0 classification API.\n\n## API Endpoint\n\n```\nPOST /v1/classify\n```\n\n## Request Format\n\n```json\n{\n  \"model\": \"model_name\",\n  \"input\": \"text to classify\"\n}\n```\n\n### Parameters\n\n- `model` (string, required): The name of the classification model to use\n- `input` (string, required): The text to classify\n- `user` (string, optional): User identifier for tracking\n- `rid` (string, optional): Request ID for tracking\n- `priority` (integer, optional): Request priority\n\n## Response Format\n\n```json\n{\n  \"id\": \"classify-9bf17f2847b046c7b2d5495f4b4f9682\",\n  \"object\": \"list\",\n  \"created\": 1745383213,\n  \"model\": \"jason9693/Qwen2.5-1.5B-apeach\",\n  \"data\": [\n    {\n      \"index\": 0,\n      \"label\": \"Default\",\n      \"probs\": [0.565970778465271, 0.4340292513370514],\n      \"num_classes\": 2\n    }\n  ],\n  \"usage\": {\n    \"prompt_tokens\": 10,\n    \"total_tokens\": 10,\n    \"completion_tokens\": 0,\n    \"prompt_tokens_details\": null\n  }\n}\n```\n\n### Response Fields\n\n- `id`: Unique identifier for the classification request\n- `object`: Always \"list\"\n- `created`: Unix timestamp when the request was created\n- `model`: The model used for classification\n- `data`: Array of classification results\n  - `index`: Index of the result\n  - `label`: Predicted class label\n  - `probs`: Array of probabilities for each class\n  - `num_classes`: Total number of classes\n- `usage`: Token usage information\n  - `prompt_tokens`: Number of input tokens\n  - `total_tokens`: Total number of tokens\n  - `completion_tokens`: Number of completion tokens (always 0 for classification)\n  - 
`prompt_tokens_details`: Additional token details (optional)\n\n## Example Usage\n\n### Using curl\n\n```bash\ncurl -v \"http://127.0.0.1:8000/v1/classify\" \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"jason9693/Qwen2.5-1.5B-apeach\",\n    \"input\": \"Loved the new café—coffee was great.\"\n  }'\n```\n\n### Using Python\n\n```python\nimport requests\nimport json\n\n# Make classification request\nresponse = requests.post(\n    \"http://127.0.0.1:8000/v1/classify\",\n    headers={\"Content-Type\": \"application/json\"},\n    json={\n        \"model\": \"jason9693/Qwen2.5-1.5B-apeach\",\n        \"input\": \"Loved the new café—coffee was great.\"\n    }\n)\n\n# Parse response\nresult = response.json()\nprint(json.dumps(result, indent=2))\n```\n\n## Supported Models\n\nThe classification API works with any classification model supported by SGLang, including:\n\n### Classification Models (Multi-class)\n- `LlamaForSequenceClassification` - Multi-class classification\n- `Qwen2ForSequenceClassification` - Multi-class classification\n- `Qwen3ForSequenceClassification` - Multi-class classification\n- `BertForSequenceClassification` - Multi-class classification\n- `Gemma2ForSequenceClassification` - Multi-class classification\n\n**Label Mapping**: The API automatically uses the `id2label` mapping from the model's `config.json` file to provide meaningful label names instead of generic class names. If `id2label` is not available, it falls back to `LABEL_0`, `LABEL_1`, etc., or `Class_0`, `Class_1` as a last resort.\n\n### Reward Models (Single score)\n- `InternLM2ForRewardModel` - Single reward score\n- `Qwen2ForRewardModel` - Single reward score\n- `LlamaForSequenceClassificationWithNormal_Weights` - Special reward model\n\n**Note**: The `/classify` endpoint in SGLang was originally designed for reward models but now supports all non-generative models. 
Our `/v1/classify` endpoint provides a standardized vLLM-compatible interface for classification tasks.\n\n## Error Handling\n\nThe API returns appropriate HTTP status codes and error messages:\n\n- `400 Bad Request`: Invalid request format or missing required fields\n- `500 Internal Server Error`: Server-side processing error\n\nError response format:\n```json\n{\n  \"error\": \"Error message\",\n  \"type\": \"error_type\",\n  \"code\": 400\n}\n```\n\n## Implementation Details\n\nThe classification API is implemented using:\n\n1. **Rust Model Gateway**: Handles routing and request/response models in `sgl-model-gateway/src/protocols/spec.rs`\n2. **Python HTTP Server**: Implements the actual endpoint in `python/sglang/srt/entrypoints/http_server.py`\n3. **Classification Service**: Handles the classification logic in `python/sglang/srt/entrypoints/openai/serving_classify.py`\n\n## Testing\n\nUse the provided test script to verify the implementation:\n\n```bash\npython test_classify_api.py\n```\n\n## Compatibility\n\nThis implementation is compatible with vLLM's classification API format, allowing seamless migration from vLLM to SGLang for classification tasks.\n"
  },
  {
    "path": "docs/supported_models/retrieval_ranking/embedding_models.md",
    "content": "# Embedding Models\n\nSGLang supports embedding models by combining its efficient serving runtime with its flexible programming interface. This makes it well suited to embedding workloads such as retrieval and semantic search, with the aim of better resource utilization and lower latency in embedding model deployment.\n\n```{important}\nEmbedding models are launched with the `--is-embedding` flag, and some may also require `--trust-remote-code`.\n```\n\n## Quick Start\n\n### Launch Server\n\n```shell\npython3 -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-Embedding-4B \\\n  --is-embedding \\\n  --host 0.0.0.0 \\\n  --port 30000\n```\n\n### Client Request\n\n```python\nimport requests\n\nurl = \"http://127.0.0.1:30000\"\n\npayload = {\n    \"model\": \"Qwen/Qwen3-Embedding-4B\",\n    \"input\": \"What is the capital of France?\",\n    \"encoding_format\": \"float\"\n}\n\nresponse = requests.post(url + \"/v1/embeddings\", json=payload).json()\nprint(\"Embedding:\", response[\"data\"][0][\"embedding\"])\n```\n\n## Multimodal Embedding Example\n\nFor multimodal models like GME that support both text and images:\n\n```shell\npython3 -m sglang.launch_server \\\n  --model-path Alibaba-NLP/gme-Qwen2-VL-2B-Instruct \\\n  --is-embedding \\\n  --chat-template gme-qwen2-vl \\\n  --host 0.0.0.0 \\\n  --port 30000\n```\n\n```python\nimport requests\n\nurl = \"http://127.0.0.1:30000\"\n\ntext_input = \"Represent this image in embedding space.\"\nimage_path = \"https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg\"\n\npayload = {\n    \"model\": \"gme-qwen2-vl\",\n    \"input\": [\n        {\n            \"text\": text_input\n        },\n        {\n            \"image\": image_path\n        }\n    ],\n}\n\nresponse = requests.post(url + \"/v1/embeddings\", json=payload).json()\n\nprint(\"Embeddings:\", [x.get(\"embedding\") for 
x in response.get(\"data\", [])])\n```\n\n## Matryoshka Embedding Example\n\n[Matryoshka Embeddings](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings) or [Matryoshka Representation Learning (MRL)](https://arxiv.org/abs/2205.13147) is a technique used in training embedding models. It trains the model so that the leading dimensions of each embedding are useful on their own, which lets users truncate vectors to trade some quality for lower storage and compute cost.\n\n### 1. Launch a Matryoshka‑capable model\n\nIf the model config already includes `matryoshka_dimensions` or `is_matryoshka`, then no override is needed. Otherwise, you can use `--json-model-override-args` as shown below:\n\n```shell\npython3 -m sglang.launch_server \\\n    --model-path Qwen/Qwen3-Embedding-0.6B \\\n    --is-embedding \\\n    --host 0.0.0.0 \\\n    --port 30000 \\\n    --json-model-override-args '{\"matryoshka_dimensions\": [128, 256, 512, 1024, 1536]}'\n```\n\n1. Setting `\"is_matryoshka\": true` allows truncating to any dimension. Otherwise, the server will validate that the specified dimension in the request is one of `matryoshka_dimensions`.\n2. Omitting `dimensions` in a request returns the full vector.\n\n### 2. 
Make requests with different output dimensions\n\n```python\nimport requests\n\nurl = \"http://127.0.0.1:30000\"\n\n# Request a truncated (Matryoshka) embedding by specifying a supported dimension.\npayload = {\n    \"model\": \"Qwen/Qwen3-Embedding-0.6B\",\n    \"input\": \"Explain diffusion models simply.\",\n    \"dimensions\": 512  # change to 128 / 1024 / omit for full size\n}\n\nresponse = requests.post(url + \"/v1/embeddings\", json=payload).json()\nprint(\"Embedding:\", response[\"data\"][0][\"embedding\"])\n```\n\n\n## Supported Models\n\n| Model Family                               | Example Model                          | Chat Template | Description                                                                 |\n| ------------------------------------------ | -------------------------------------- | ------------- | --------------------------------------------------------------------------- |\n| **E5 (Llama/Mistral based)**              | `intfloat/e5-mistral-7b-instruct`     | N/A           | High-quality text embeddings based on Mistral/Llama architectures          |\n| **GTE-Qwen2**                             | `Alibaba-NLP/gte-Qwen2-7B-instruct`   | N/A           | Alibaba's text embedding model with multilingual support                   |\n| **Qwen3-Embedding**                       | `Qwen/Qwen3-Embedding-4B`             | N/A           | Latest Qwen3-based text embedding model for semantic representation        |\n| **BGE**                                    | `BAAI/bge-large-en-v1.5`              | N/A           | BAAI's text embeddings (requires `attention-backend` triton/torch_native)  |\n| **GME (Multimodal)**                      | `Alibaba-NLP/gme-Qwen2-VL-2B-Instruct`| `gme-qwen2-vl`| Multimodal embedding for text and image cross-modal tasks                  |\n| **CLIP**                                   | `openai/clip-vit-large-patch14-336`   | N/A           | OpenAI's CLIP for image and text embeddings                                
|\n"
  },
  {
    "path": "docs/supported_models/retrieval_ranking/index.rst",
    "content": "Retrieval & Ranking\n===================\n\nModels for embeddings, reranking, and classification.\n\n.. toctree::\n   :maxdepth: 1\n\n   embedding_models.md\n   rerank_models.md\n   classify_models.md\n"
  },
  {
    "path": "docs/supported_models/retrieval_ranking/rerank_models.md",
    "content": "# Rerank Models\n\nSGLang supports rerank models by combining optimized serving with a flexible programming interface. This enables efficient processing of reranking tasks, improving the relevance of search result ordering. SGLang is designed for high throughput and low latency during reranker deployment, which makes it a good fit for semantic result refinement in large-scale retrieval systems.\n\n```{important}\nRerank models in SGLang fall into two categories:\n\n- **Cross-encoder rerank models**: run with `--is-embedding` (embedding runner).\n- **Decoder-only rerank models**: run **without** `--is-embedding` and use next-token logprob scoring (yes/no).\n  - Text-only (e.g. Qwen3-Reranker)\n  - Multimodal (e.g. Qwen3-VL-Reranker): also supports image/video content\n\nSome models may require `--trust-remote-code`.\n```\n\n## Supported rerank models\n\n| Model Family (Rerank)                          | Example HuggingFace Identifier       | Chat Template | Description                                                                                                                      |\n|------------------------------------------------|--------------------------------------|---------------|----------------------------------------------------------------------------------------------------------------------------------|\n| **BGE-Reranker (BgeRerankModel)**              | `BAAI/bge-reranker-v2-m3`            | N/A           | Currently supports only the `triton` and `torch_native` attention backends (`--attention-backend`). High-performance cross-encoder reranker model from BAAI. Suitable for reranking search results based on semantic relevance.   |\n| **Qwen3-Reranker (decoder-only yes/no)**       | `Qwen/Qwen3-Reranker-8B`             | `examples/chat_template/qwen3_reranker.jinja` | Decoder-only reranker using next-token logprob scoring for labels (yes/no). Launch **without** `--is-embedding`. 
|\n| **Qwen3-VL-Reranker (multimodal yes/no)**      | `Qwen/Qwen3-VL-Reranker-2B`          | `examples/chat_template/qwen3_vl_reranker.jinja` | Multimodal decoder-only reranker supporting text, images, and videos. Uses yes/no logprob scoring. Launch **without** `--is-embedding`. |\n\n\n## Cross-Encoder Rerank (embedding runner)\n\n### Launch Command\n\n```shell\npython3 -m sglang.launch_server \\\n  --model-path BAAI/bge-reranker-v2-m3 \\\n  --host 0.0.0.0 \\\n  --disable-radix-cache \\\n  --chunked-prefill-size -1 \\\n  --attention-backend triton \\\n  --is-embedding \\\n  --port 30000\n```\n\n### Example Client Request\n\n```python\nimport requests\n\nurl = \"http://127.0.0.1:30000/v1/rerank\"\n\npayload = {\n    \"model\": \"BAAI/bge-reranker-v2-m3\",\n    \"query\": \"what is panda?\",\n    \"documents\": [\n        \"hi\",\n        \"The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.\"\n    ],\n    \"top_n\": 1,\n    \"return_documents\": True\n}\n\nresponse = requests.post(url, json=payload)\nresponse_json = response.json()\n\nfor item in response_json:\n    if item.get(\"document\"):\n        print(f\"Score: {item['score']:.2f} - Document: '{item['document']}'\")\n    else:\n        print(f\"Score: {item['score']:.2f} - Index: {item['index']}\")\n```\n\n**Request Parameters:**\n\n- `query` (required): The query text to rank documents against\n- `documents` (required): List of documents to be ranked\n- `model` (required): Model to use for reranking\n- `top_n` (optional): Maximum number of documents to return. Defaults to returning all documents. If the specified value is greater than the total number of documents, all documents are returned.\n- `return_documents` (optional): Whether to return documents in the response. 
Defaults to `True`.\n\n## Qwen3-Reranker (decoder-only yes/no rerank)\n\n### Launch Command\n\n```shell\npython3 -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-Reranker-0.6B \\\n  --trust-remote-code \\\n  --disable-radix-cache \\\n  --host 0.0.0.0 \\\n  --port 8001 \\\n  --chat-template examples/chat_template/qwen3_reranker.jinja\n```\n\n```{note}\nQwen3-Reranker uses decoder-only logprob scoring (yes/no). Do NOT launch it with `--is-embedding`.\n```\n\n### Example Client Request (supports optional instruct, top_n, and return_documents)\n\n```shell\ncurl -X POST http://127.0.0.1:8001/v1/rerank \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"Qwen3-Reranker-0.6B\",\n    \"query\": \"法国首都是哪里？\",\n    \"documents\": [\n      \"法国的首都是巴黎。\",\n      \"德国的首都是柏林。\",\n      \"香蕉是黄色的水果。\"\n    ],\n    \"instruct\": \"Given a web search query, retrieve relevant passages that answer the query.\",\n    \"top_n\": 2,\n    \"return_documents\": true\n  }'\n```\n\n**Request Parameters:**\n\n- `query` (required): The query text to rank documents against\n- `documents` (required): List of documents to be ranked\n- `model` (required): Model to use for reranking\n- `instruct` (optional): Instruction text for the reranker\n- `top_n` (optional): Maximum number of documents to return. Defaults to returning all documents. If the specified value is greater than the total number of documents, all documents are returned.\n- `return_documents` (optional): Whether to return documents in the response. Defaults to `True`.\n\n### Response Format\n\n`/v1/rerank` returns a list of objects (sorted by descending score):\n\n- `score`: float, higher means more relevant\n- `document`: the original document string (only included when `return_documents` is `true`)\n- `index`: the original index in the input `documents`\n- `meta_info`: optional debug/usage info (may be present for some models)\n\nThe number of returned results is controlled by the `top_n` parameter. 
If `top_n` is not specified or is greater than the total number of documents, all documents are returned.\n\nExample (with `return_documents: true`):\n\n```json\n[\n  {\"score\": 0.99, \"document\": \"法国的首都是巴黎。\", \"index\": 0},\n  {\"score\": 0.01, \"document\": \"德国的首都是柏林。\", \"index\": 1},\n  {\"score\": 0.00, \"document\": \"香蕉是黄色的水果。\", \"index\": 2}\n]\n```\n\nExample (with `return_documents: false`):\n\n```json\n[\n  {\"score\": 0.99, \"index\": 0},\n  {\"score\": 0.01, \"index\": 1},\n  {\"score\": 0.00, \"index\": 2}\n]\n```\n\nExample (with `top_n: 2`):\n\n```json\n[\n  {\"score\": 0.99, \"document\": \"法国的首都是巴黎。\", \"index\": 0},\n  {\"score\": 0.01, \"document\": \"德国的首都是柏林。\", \"index\": 1}\n]\n```\n\n### Common Pitfalls\n\n- **`--chat-template` is required.** Without `--chat-template examples/chat_template/qwen3_reranker.jinja`, the server does not recognize the model as a decoder-only reranker and returns a 400 error: ``\"This model does not appear to be an embedding model by default. Please add `--is-embedding`...\"``. The fix is to add the chat template flag, NOT `--is-embedding`.\n- If you launch Qwen3-Reranker with `--is-embedding`, `/v1/rerank` cannot compute yes/no logprob scores. 
Relaunch **without** `--is-embedding`.\n- If you see a validation error like \"score should be a valid number\" and the backend returned a list, upgrade to a version that coerces `embedding[0]` into `score` for rerank responses.\n\n## Qwen3-VL-Reranker (multimodal decoder-only rerank)\n\nQwen3-VL-Reranker extends the Qwen3-Reranker to support multimodal content, allowing reranking of documents containing text, images, and videos.\n\n### Launch Command\n\n```shell\npython3 -m sglang.launch_server \\\n  --model-path Qwen/Qwen3-VL-Reranker-2B \\\n  --trust-remote-code \\\n  --disable-radix-cache \\\n  --host 0.0.0.0 \\\n  --port 30000 \\\n  --chat-template examples/chat_template/qwen3_vl_reranker.jinja\n```\n\n```{note}\nQwen3-VL-Reranker uses decoder-only logprob scoring (yes/no) like Qwen3-Reranker. Do NOT launch it with `--is-embedding`.\n```\n\n### Text-Only Reranking (backward compatible)\n\n```python\nimport requests\n\nurl = \"http://127.0.0.1:30000/v1/rerank\"\n\npayload = {\n    \"model\": \"Qwen3-VL-Reranker-2B\",\n    \"query\": \"What is machine learning?\",\n    \"documents\": [\n        \"Machine learning is a branch of artificial intelligence that enables computers to learn from data.\",\n        \"The weather in Paris is usually mild with occasional rain.\",\n        \"Deep learning is a subset of machine learning using neural networks with many layers.\",\n    ],\n    \"instruct\": \"Retrieve passages that answer the question.\",\n    \"return_documents\": True\n}\n\nresponse = requests.post(url, json=payload)\nresults = response.json()\n\nfor item in results:\n    print(f\"Score: {item['score']:.4f} - {item['document'][:60]}...\")\n```\n\n### Image Reranking (text query, image/mixed documents)\n\n```python\nimport requests\n\nurl = \"http://127.0.0.1:30000/v1/rerank\"\n\npayload = {\n    \"query\": \"A woman playing with her dog on a beach at sunset.\",\n    \"documents\": [\n        # Document 1: Text description\n        \"A woman shares a joyful 
moment with her golden retriever on a sun-drenched beach at sunset.\",\n        # Document 2: Image URL\n        [\n            {\n                \"type\": \"image_url\",\n                \"image_url\": {\n                    \"url\": \"https://example.com/beach_dog.jpeg\"\n                }\n            }\n        ],\n        # Document 3: Text + Image (mixed)\n        [\n            {\"type\": \"text\", \"text\": \"A joyful scene at the beach:\"},\n            {\n                \"type\": \"image_url\",\n                \"image_url\": {\n                    \"url\": \"https://example.com/beach_dog.jpeg\"\n                }\n            }\n        ]\n    ],\n    \"instruct\": \"Retrieve images or text relevant to the user's query.\",\n    \"return_documents\": False\n}\n\nresponse = requests.post(url, json=payload)\nresults = response.json()\n\nfor item in results:\n    print(f\"Index: {item['index']}, Score: {item['score']:.4f}\")\n```\n\n### Multimodal Query Reranking (query with image)\n\n```python\nimport requests\n\nurl = \"http://127.0.0.1:30000/v1/rerank\"\n\npayload = {\n    # Query with text and image\n    \"query\": [\n        {\"type\": \"text\", \"text\": \"Find similar images to this:\"},\n        {\n            \"type\": \"image_url\",\n            \"image_url\": {\n                \"url\": \"https://example.com/reference_image.jpeg\"\n            }\n        }\n    ],\n    \"documents\": [\n        \"A cat sleeping on a couch.\",\n        \"A woman and her dog enjoying the sunset at the beach.\",\n        \"A busy city street with cars and pedestrians.\",\n        [\n            {\n                \"type\": \"image_url\",\n                \"image_url\": {\n                    \"url\": \"https://example.com/similar_image.jpeg\"\n                }\n            }\n        ]\n    ],\n    \"instruct\": \"Find images or descriptions similar to the query image.\"\n}\n\nresponse = requests.post(url, json=payload)\nresults = response.json()\n\nfor item in 
results:\n    print(f\"Index: {item['index']}, Score: {item['score']:.4f}\")\n```\n\n### Request Parameters (Multimodal)\n\n- `query` (required): Can be a string (text-only) or a list of content parts:\n  - `{\"type\": \"text\", \"text\": \"...\"}` for text\n  - `{\"type\": \"image_url\", \"image_url\": {\"url\": \"...\"}}` for images\n  - `{\"type\": \"video_url\", \"video_url\": {\"url\": \"...\"}}` for videos\n- `documents` (required): List where each document can be a string or list of content parts (same format as query)\n- `instruct` (optional): Instruction text for the reranker\n- `top_n` (optional): Maximum number of documents to return\n- `return_documents` (optional): Whether to return documents in the response (default: `false`)\n\n### Common Pitfalls\n\n- Always use `--chat-template examples/chat_template/qwen3_vl_reranker.jinja` for Qwen3-VL-Reranker.\n- Do NOT launch with `--is-embedding`.\n- For best results, use `--disable-radix-cache` to avoid caching issues with multimodal content.\n- **Note**: Currently only `Qwen3-VL-Reranker-2B` is tested and supported. The 8B model may have different behavior and is not guaranteed to work with this template.\n"
  },
  {
    "path": "docs/supported_models/specialized/index.rst",
    "content": "Specialized Models\n==================\n\nModels for specialized tasks like reward modeling.\n\n.. toctree::\n   :maxdepth: 1\n\n   reward_models.md\n"
  },
  {
    "path": "docs/supported_models/specialized/reward_models.md",
    "content": "# Reward Models\r\n\r\nThese models output a scalar reward score or classification result, often used in reinforcement learning or content moderation tasks.\r\n\r\n```{important}\r\nThey are launched with `--is-embedding`, and some may also require `--trust-remote-code`.\r\n```\r\n\r\n## Example launch Command\r\n\r\n```shell\r\n# --model-path accepts a Hugging Face model ID or a local path.\r\n# --tp-size sets the tensor parallelism degree.\r\npython3 -m sglang.launch_server \\\r\n  --model-path Qwen/Qwen2.5-Math-RM-72B \\\r\n  --is-embedding \\\r\n  --host 0.0.0.0 \\\r\n  --tp-size=4 \\\r\n  --port 30000\r\n```\r\n\r\n## Supported models\r\n\r\n| Model Family (Reward)                                                     | Example HuggingFace Identifier                              | Description                                                                     |\r\n|---------------------------------------------------------------------------|-----------------------------------------------------|---------------------------------------------------------------------------------|\r\n| **Llama (3.1 Reward / `LlamaForSequenceClassification`)**                   | `Skywork/Skywork-Reward-Llama-3.1-8B-v0.2`            | Reward model (preference classifier) based on Llama 3.1 (8B) for scoring and ranking responses for RLHF.  |\r\n| **Gemma 2 (27B Reward / `Gemma2ForSequenceClassification`)**                | `Skywork/Skywork-Reward-Gemma-2-27B-v0.2`             | Derived from Gemma‑2 (27B), this model provides human preference scoring for RLHF and multilingual tasks.  |\r\n| **InternLM 2 (Reward / `InternLM2ForRewardModel`)**                        | `internlm/internlm2-7b-reward`                       | InternLM 2 (7B)–based reward model used in alignment pipelines to guide outputs toward preferred behavior.  
|\r\n| **Qwen2.5 (Reward - Math / `Qwen2ForRewardModel`)**                         | `Qwen/Qwen2.5-Math-RM-72B`                           | A 72B math-specialized RLHF reward model from the Qwen2.5 series, tuned for evaluating and refining responses.  |\r\n| **Qwen2.5 (Reward - Sequence / `Qwen2ForSequenceClassification`)**          | `jason9693/Qwen2.5-1.5B-apeach`                      | A smaller Qwen2.5 variant used for sequence classification, offering an alternative RLHF scoring mechanism.  |\r\n"
  },
  {
    "path": "docs/supported_models/text_generation/diffusion_language_models.md",
    "content": "# Diffusion Language Models\n\nDiffusion language models have shown promise for non-autoregressive text generation because they can decode multiple tokens in parallel. Unlike autoregressive language models, which all generate tokens one at a time from left to right, different diffusion language models require different decoding strategies.\n\n## Example Launch Command\n\nSGLang supports different diffusion LLM (DLLM) decoding algorithms, such as `LowConfidence` and `JointThreshold`.\n\n```shell\n# --model-path accepts a Hugging Face model ID or a local path.\n# --dllm-algorithm-config is optional; if it is not set, the algorithm's default configuration is used.\npython3 -m sglang.launch_server \\\n  --model-path inclusionAI/LLaDA2.0-mini \\\n  --dllm-algorithm LowConfidence \\\n  --dllm-algorithm-config ./config.yaml \\\n  --host 0.0.0.0 \\\n  --port 30000\n```\n\n## Example Configuration File\n\nThe configuration parameters depend on the selected algorithm.\n\nLowConfidence Config:\n\n```yaml\n# Confidence threshold for accepting predicted tokens\n# - Higher values: More conservative, better quality but slower\n# - Lower values: More aggressive, faster but potentially lower quality\n# Range: 0.0 - 1.0\nthreshold: 0.95\n\n# Default: 32, for LLaDA2MoeModelLM\nblock_size: 32\n```\n\nJointThreshold Config:\n\n```yaml\n# Decoding threshold for Mask-to-Token (M2T) phase\n# - Higher values: More conservative, better quality but slower\n# - Lower values: More aggressive, faster but potentially lower quality\n# Range: 0.0 - 1.0\nthreshold: 0.5\n# Decoding threshold for Token-to-Token (T2T) phase\n# Range: 0.0 - 1.0\n# Setting to 0.0 allows full editing (recommended for most cases).\nedit_threshold: 0.0\n# Max extra T2T steps after all masks are removed. 
Prevents infinite loops.\nmax_post_edit_steps: 16\n# 2-gram repetition penalty (default 0).\n# An empirical value of 3 is often sufficient to mitigate most repetitions.\npenalty_lambda: 0\n```\n\n## Example Client Code Snippet\n\nJust like other supported models, diffusion language models can be used via the REST API or SGLang's offline Python `Engine`.\n\nPython example that runs the model in-process with the offline `Engine` (no launched server is required):\n\n```python\nimport sglang as sgl\n\ndef main():\n    llm = sgl.Engine(model_path=\"inclusionAI/LLaDA2.0-mini\",\n                     dllm_algorithm=\"LowConfidence\",\n                     max_running_requests=1,\n                     trust_remote_code=True)\n\n    prompts = [\n        \"<role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role> Write a brief introduction of the great wall <|role_end|><role>ASSISTANT</role>\"\n    ]\n\n    sampling_params = {\n        \"temperature\": 0,\n        \"max_new_tokens\": 1024,\n    }\n\n    outputs = llm.generate(prompts, sampling_params)\n    print(outputs)\n\nif __name__ == '__main__':\n    main()\n```\n\nCurl example for making a generation request to the launched server:\n\n```bash\ncurl -X POST \"http://127.0.0.1:30000/generate\" \\\n     -H \"Content-Type: application/json\" \\\n     -d '{\n        \"text\": [\n            \"<role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role> Write the number from 1 to 128 <|role_end|><role>ASSISTANT</role>\",\n            \"<role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role> Write a brief introduction of the great wall <|role_end|><role>ASSISTANT</role>\"\n        ],\n        \"stream\": true,\n        \"sampling_params\": {\n            \"temperature\": 0,\n            \"max_new_tokens\": 1024\n        }\n    }'\n```\n\n## Supported Models\n\nThe supported models are summarized in the table below.\n\n| Model Family               | Example Model                | Description                                                
                                          |\n| -------------------------- | ---------------------------- | ---------------------------------------------------------------------------------------------------- |\n| **LLaDA2.0 (mini, flash)** | `inclusionAI/LLaDA2.0-flash` | LLaDA2.0-flash is a diffusion language model featuring a 100B Mixture-of-Experts (MoE) architecture. |\n| **SDAR (JetLM)**           | `JetLM/SDAR-8B-Chat`         | SDAR series diffusion language model (Chat), dense architecture.                                 |\n| **SDAR (JetLM)**           | `JetLM/SDAR-30B-A3B-Chat`    | SDAR series diffusion language model (Chat), MoE architecture.                                   |\n"
  },
  {
    "path": "docs/supported_models/text_generation/generative_models.md",
    "content": "# Large Language Models\n\nThese models accept text input and produce text output (e.g., chat completions). They are primarily large language models (LLMs), some with mixture-of-experts (MoE) architectures for scaling.\n\n## Example launch command\n\nReplace the `--model-path` value with any Hugging Face model ID or local path.\n\n```shell\npython3 -m sglang.launch_server \\\n  --model-path meta-llama/Llama-3.2-1B-Instruct \\\n  --host 0.0.0.0 \\\n  --port 30000\n```\n\n## Supported models\n\nThe supported models are summarized in the table below.\n\nIf you are unsure whether a specific architecture is implemented, you can search for it on GitHub. For example, to search for `Qwen3ForCausalLM`, enter the following expression in the GitHub search bar:\n\n```\nrepo:sgl-project/sglang path:/^python\\/sglang\\/srt\\/models\\// Qwen3ForCausalLM\n```\n\n| Model Family (Variants)             | Example HuggingFace Identifier                     | Description                                                                            |\n|-------------------------------------|--------------------------------------------------|----------------------------------------------------------------------------------------|\n| **DeepSeek** (v1, v2, v3/R1)        | `deepseek-ai/DeepSeek-R1`                        | Series of advanced reasoning-optimized models (including a 671B MoE) trained with reinforcement learning; top performance on complex reasoning, math, and code tasks. [SGLang provides DeepSeek V3/R1 model-specific optimizations](../basic_usage/deepseek.md) and a [Reasoning Parser](../advanced_features/separate_reasoning.ipynb). |\n| **Kimi K2** (Thinking, Instruct)    | `moonshotai/Kimi-K2-Instruct`                    | Moonshot AI's 1-trillion-parameter MoE model (32B active) with 128K–256K context; state-of-the-art agentic intelligence with stable long-horizon agency across 200–300 sequential tool calls. Features MLA attention and native INT4 quantization. 
[See Reasoning Parser docs](../advanced_features/separate_reasoning.ipynb)|\n| **Kimi Linear** (48B-A3B)           | `moonshotai/Kimi-Linear-48B-A3B-Instruct`        | Moonshot AI's hybrid linear attention model (48B total, 3B active) with 1M-token context; features Kimi Delta Attention (KDA) for up to 6× faster decoding and 75% KV cache reduction vs. full attention. |\n| **GPT-OSS**       | `openai/gpt-oss-20b`, `openai/gpt-oss-120b`       | OpenAI’s latest GPT-OSS series for complex reasoning, agentic tasks, and versatile developer use cases.|\n| **Qwen** (3.5, 3, 3MoE, 3Next, 2.5, 2 series)       | `Qwen/Qwen3.5-397B-A17B`, `Qwen/Qwen3-0.6B`, `Qwen/Qwen3-30B-A3B`      | Alibaba’s Qwen series (including Qwen3.5 and Qwen3) for complex reasoning, language understanding, and generation tasks; supports MoE variants as well as previous generations (2.5, 2, etc.). [SGLang provides a Qwen3-specific reasoning parser](../advanced_features/separate_reasoning.ipynb)|\n| **Llama** (2, 3.x, 4 series)        | `meta-llama/Llama-4-Scout-17B-16E-Instruct`       | Meta's open LLM series, spanning 7B to 400B parameters (Llama 2, 3, and new Llama 4) with well-recognized performance. [SGLang provides Llama-4 model-specific optimizations](../basic_usage/llama4.md)  |\n| **Mistral** (Mixtral, NeMo, Small3) | `mistralai/Mistral-7B-Instruct-v0.2`             | Open 7B LLM by Mistral AI with strong performance; extended into MoE (“Mixtral”) and NeMo Megatron variants for larger scale. |\n| **Gemma** (v1, v2, v3)              | `google/gemma-3-1b-it`                            | Google’s family of efficient multilingual models (1B–27B); Gemma 3 offers a 128K context window, and its larger (4B+) variants support vision input. 
|\n| **Phi** (Phi-1.5, Phi-2, Phi-3, Phi-4, Phi-MoE series) | `microsoft/Phi-4-multimodal-instruct`, `microsoft/Phi-3.5-MoE-instruct` | Microsoft’s Phi family of small models (1.3B–5.6B); Phi-4-multimodal (5.6B) processes text, images, and speech; Phi-4-mini is a high-accuracy text model; and Phi-3.5-MoE is a mixture-of-experts model. |\n| **MiniCPM** (v3, 4B)               | `openbmb/MiniCPM3-4B`                            | OpenBMB’s series of compact LLMs for edge devices; MiniCPM 3 (4B) achieves GPT-3.5-level results in text tasks. |\n| **OLMo** (2, 3) | `allenai/OLMo-3-1125-32B`, `allenai/OLMo-3-32B-Think`, `allenai/OLMo-2-1124-7B-Instruct` | Allen AI’s series of Open Language Models designed to enable the science of language models. |\n| **OLMoE** (Open MoE)               | `allenai/OLMoE-1B-7B-0924`                       | Allen AI’s open Mixture-of-Experts model (7B total, 1B active parameters) delivering state-of-the-art results with sparse expert activation. |\n| **MiniMax-M2** (M2, M2.1, M2.5)               | `MiniMaxAI/MiniMax-M2.5`, `MiniMaxAI/MiniMax-M2.1`, `MiniMaxAI/MiniMax-M2` | MiniMax's state-of-the-art LLM for coding and agentic workflows. |\n| **StableLM** (3B, 7B)               | `stabilityai/stablelm-tuned-alpha-7b`            | Stability AI’s early open-source LLM (3B and 7B) for general text generation; a demonstration model with basic instruction-following ability. |\n| **Command-(R,A)** (Cohere)              | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025`                 | Cohere’s open conversational LLMs (Command series) optimized for long context, retrieval-augmented generation, and tool use. |\n| **DBRX** (Databricks)              | `databricks/dbrx-instruct`                       | Databricks’ 132B-parameter MoE model (36B active) trained on 12T tokens; competes with GPT-3.5 in quality as a fully open foundation model. 
|\n| **Grok** (xAI)                     | `xai-org/grok-1`                                | xAI’s Grok-1 model, known for its vast size (314B parameters) and high quality; integrated in SGLang for high-performance inference. |\n| **ChatGLM** (GLM-130B family)       | `THUDM/chatglm2-6b`                              | Zhipu AI’s bilingual chat model (6B) excelling at Chinese-English dialogue; fine-tuned for conversational quality and alignment. |\n| **InternLM 2** (7B, 20B)           | `internlm/internlm2-7b`                          | Next-gen InternLM (7B and 20B) from SenseTime, offering strong reasoning and ultra-long context support (up to 200K tokens). |\n| **EXAONE 3** (Korean-English)      | `LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct`           | LG AI Research’s Korean-English model (7.8B) trained on 8T tokens; provides high-quality bilingual understanding and generation. |\n| **Baichuan 2** (7B, 13B)           | `baichuan-inc/Baichuan2-13B-Chat`                | Baichuan AI’s second-generation Chinese-English LLM (7B/13B) with improved performance and an open commercial license. |\n| **XVERSE** (MoE)                   | `xverse/XVERSE-MoE-A36B`                         | Yuanxiang’s open MoE LLM (XVERSE-MoE-A36B: 255B total, 36B active) supporting ~40 languages; delivers 100B+ dense-level performance via expert routing. |\n| **SmolLM** (135M–1.7B)            | `HuggingFaceTB/SmolLM-1.7B`                      | Hugging Face’s ultra-small LLM series (135M–1.7B params) offering surprisingly strong results, enabling advanced AI on mobile/edge devices. |\n| **GLM-4** (Multilingual 9B)        | `ZhipuAI/glm-4-9b-chat`                          | Zhipu’s GLM-4 series (up to 9B parameters): open multilingual models with support for 1M-token context. 
|\n| **MiMo** (7B series)               | `XiaomiMiMo/MiMo-7B-RL`                         | Xiaomi's reasoning-optimized model series, which leverages Multiple-Token Prediction for faster inference. |\n| **ERNIE-4.5** (4.5, 4.5MoE series) | `baidu/ERNIE-4.5-21B-A3B-PT`                    | Baidu's ERNIE-4.5 series, which includes MoE models with 47B and 3B active parameters (the largest model has 424B total parameters) as well as a 0.3B dense model. |\n| **Arcee AFM-4.5B**               | `arcee-ai/AFM-4.5B-Base`                         | Arcee's foundational model series for real-world reliability and edge deployments. |\n| **Persimmon** (8B)               | `adept/persimmon-8b-chat`                         | Adept’s open 8B model with a 16K context window and fast inference; trained for broad usability and licensed under Apache 2.0. |\n| **Solar** (10.7B)               | `upstage/SOLAR-10.7B-Instruct-v1.0`                         | Upstage's 10.7B-parameter model, optimized for instruction-following tasks. The architecture uses depth up-scaling to improve model performance. |\n| **Tele FLM** (52B-1T)               | `CofeAI/Tele-FLM`                         | BAAI & TeleAI's multilingual model, available in 52-billion and 1-trillion parameter variants. It is a decoder-only transformer trained on ~2T tokens. |\n| **Ling** (16.8B–290B) | `inclusionAI/Ling-lite`, `inclusionAI/Ling-plus` | InclusionAI’s open MoE models. Ling-Lite has 16.8B total / 2.75B active parameters, and Ling-Plus has 290B total / 28.8B active parameters. They are designed for high performance on NLP and complex reasoning tasks. |\n| **Granite 3.0, 3.1** (IBM)               | `ibm-granite/granite-3.1-8b-instruct`                          | IBM's open dense foundation models optimized for reasoning, code, and business AI use cases. Integrated with Red Hat and watsonx systems. 
|\n| **Granite 3.0 MoE** (IBM)               | `ibm-granite/granite-3.0-3b-a800m-instruct`                          | IBM’s Mixture-of-Experts models offering strong performance with cost efficiency, with MoE expert routing designed for enterprise deployment at scale. |\n| **GPT-J** (6B)                    | `EleutherAI/gpt-j-6b`                             | EleutherAI's GPT-2-like causal language model (6B) trained on the [Pile](https://pile.eleuther.ai/) dataset. |\n| **Orion** (14B)               | `OrionStarAI/Orion-14B-Base`                         | A series of open-source multilingual large language models by OrionStarAI, pretrained on a 2.5T-token multilingual corpus including Chinese, English, Japanese, Korean, etc.; the models exhibit superior performance in these languages. |\n| **Llama Nemotron Super** (v1, v1.5, NVIDIA) | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, `nvidia/Llama-3_3-Nemotron-Super-49B-v1_5` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. |\n| **Llama Nemotron Ultra** (v1, NVIDIA) | `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. |\n| **NVIDIA Nemotron Nano 2.0** | `nvidia/NVIDIA-Nemotron-Nano-9B-v2` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. `Nemotron-Nano-9B-v2` is a hybrid Mamba-Transformer language model designed to increase throughput for reasoning workloads while achieving state-of-the-art accuracy compared to similarly sized models. 
|\n| **StarCoder2** (3B-15B) | `bigcode/starcoder2-7b` | StarCoder2 is a family of open large language models (LLMs) specialized for code generation and understanding. It is the successor to StarCoder, developed by the BigCode project (a collaboration between Hugging Face, ServiceNow Research, and other contributors). |\n| **Jet-Nemotron** | `jet-ai/Jet-Nemotron-2B` | Jet-Nemotron is a new family of hybrid-architecture language models that surpass state-of-the-art open-source full-attention language models while achieving significant efficiency gains. |\n| **Trinity** (Nano, Mini) | `arcee-ai/Trinity-Mini` | Arcee's Trinity family of foundational MoE models, released with open weights under Apache 2.0. |\n| **Falcon-H1** (0.5B–34B) | `tiiuae/Falcon-H1-34B-Instruct` | TII's hybrid Mamba-Transformer architecture combining attention and state-space models for efficient long-context inference. |\n| **Hunyuan-Large** (389B, MoE) | `tencent/Tencent-Hunyuan-Large` | Tencent's open-source MoE model with 389B total / 52B active parameters, featuring Cross-Layer Attention (CLA) for improved efficiency. |\n| **IBM Granite 4.0 (Hybrid, Dense)** | `ibm-granite/granite-4.0-h-micro`, `ibm-granite/granite-4.0-micro` | IBM Granite 4.0 micro models: hybrid Mamba–MoE (`h-micro`) and dense (`micro`) variants. Enterprise-focused reasoning models. |\n| **Sarvam 2** (30B-A2B, 105B-A10B) | `sarvamai/sarvam-2` | Sarvam's Mixture-of-Experts models. The 105B variant uses MLA (Multi-head Latent Attention) and the 30B variant uses GQA, both with 128 routed experts. |\n"
  },
  {
    "path": "docs/supported_models/text_generation/index.rst",
    "content": "Text Generation\n===============\n\nModels for generating text from text or multimodal inputs.\n\n.. toctree::\n   :maxdepth: 1\n\n   generative_models.md\n   multimodal_language_models.md\n   diffusion_language_models.md\n"
  },
  {
    "path": "docs/supported_models/text_generation/multimodal_language_models.md",
    "content": "# Multimodal Language Models\n\nThese models accept multimodal inputs (e.g., images and text) and generate text output. They augment language models with multimodal encoders.\n\n## Example launch command\n\nReplace the `--model-path` value with any Hugging Face model ID or local path.\n\n```shell\npython3 -m sglang.launch_server \\\n  --model-path meta-llama/Llama-3.2-11B-Vision-Instruct \\\n  --host 0.0.0.0 \\\n  --port 30000\n```\n\n> See the [OpenAI APIs section](https://docs.sglang.io/basic_usage/openai_api_vision.html) for how to send multimodal requests.\n\n## Supported models\n\nThe supported models are summarized in the table below.\n\nIf you are unsure whether a specific architecture is implemented, you can search for it on GitHub. For example, to search for `Qwen2_5_VLForConditionalGeneration`, enter the following expression in the GitHub search bar:\n\n```\nrepo:sgl-project/sglang path:/^python\\/sglang\\/srt\\/models\\// Qwen2_5_VLForConditionalGeneration\n```\n\n| Model Family (Variants)    | Example HuggingFace Identifier             | Description                                                                                                                                                                                                     | Notes |\n|----------------------------|--------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------|\n| **Qwen-VL** | `Qwen/Qwen3-VL-235B-A22B-Instruct`              | Alibaba's vision-language extension of Qwen; for example, Qwen2.5-VL (7B and larger variants) can analyze and converse about image content.                                                                     
|  |\n| **DeepSeek-VL2**           | `deepseek-ai/deepseek-vl2`                 | Vision-language variant of DeepSeek (with a dedicated image processor), enabling advanced multimodal reasoning on image and text inputs.                                                                        |  |\n| **DeepSeek-OCR / OCR-2**   | `deepseek-ai/DeepSeek-OCR-2`               | OCR-focused DeepSeek models for document understanding and text extraction.                                                                                                                                    | Use `--trust-remote-code`. |\n| **Janus-Pro** (1B, 7B)     | `deepseek-ai/Janus-Pro-7B`                 | DeepSeek's open-source multimodal model capable of both image understanding and generation. Janus-Pro employs a decoupled architecture for separate visual encoding paths, enhancing performance in both tasks. |  |\n| **MiniCPM-V / MiniCPM-o**  | `openbmb/MiniCPM-V-2_6`                    | MiniCPM-V (2.6, ~8B) supports image inputs, and MiniCPM-o adds audio/video; these multimodal LLMs are optimized for end-side deployment on mobile/edge devices.                                                 |  |\n| **Llama 3.2 Vision** (11B) | `meta-llama/Llama-3.2-11B-Vision-Instruct` | Vision-enabled variant of Llama 3 (11B) that accepts image inputs for visual question answering and other multimodal tasks.                                                                                     |  |\n| **LLaVA** (v1.5 & v1.6)    | *e.g.* `liuhaotian/llava-v1.5-13b`         | Open vision-chat models that add an image encoder to LLaMA/Vicuna (e.g. LLaMA2 13B) for following multimodal instruction prompts.                                                                               
|  |\n| **LLaVA-NeXT** (8B, 72B)   | `lmms-lab/llava-next-72b`                  | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks.                                                       |  |\n| **LLaVA-OneVision**        | `lmms-lab/llava-onevision-qwen2-7b-ov`     | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format.                                                 |  |\n| **Gemma 3 (Multimodal)**   | `google/gemma-3-4b-it`                     | Gemma 3's larger models (4B, 12B, 27B) accept images (each image encoded as 256 tokens) alongside text in a combined 128K-token context.                                                                        |  |\n| **Kimi-VL** (A3B)          | `moonshotai/Kimi-VL-A3B-Instruct`          | Kimi-VL is a multimodal model that can understand and generate text from images.                                                                                                                                |  |\n| **Mistral-Small-3.1-24B**  | `mistralai/Mistral-Small-3.1-24B-Instruct-2503` | Mistral 3.1 is a multimodal model that can generate text from text or images input. It also supports tool calling and structured output. |  |\n| **Phi-4-multimodal-instruct**  | `microsoft/Phi-4-multimodal-instruct` | Phi-4-multimodal-instruct is the multimodal variant of the Phi-4-mini model, enhanced with LoRA for improved multimodal capabilities. It supports text, vision and audio modalities in SGLang. |  |\n| **MiMo-VL** (7B)           | `XiaomiMiMo/MiMo-VL-7B-RL`                 | Xiaomi's compact yet powerful vision-language model featuring a native resolution ViT encoder for fine-grained visual details, an MLP projector for cross-modal alignment, and the MiMo-7B language model optimized for complex reasoning tasks. 
|  |\n| **GLM-4.5V** (106B) / **GLM-4.1V** (9B)           | `zai-org/GLM-4.5V`                   | GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning                                                                                                                                                                                                      | Use `--chat-template glm-4v` |\n| **GLM-OCR**          | `zai-org/GLM-OCR`                   | GLM-OCR: A fast and accurate general OCR model                                                                   |  |\n| **DotsVLM** (General/OCR)  | `rednote-hilab/dots.vlm1.inst`             | RedNote's vision-language model built on a 1.2B vision encoder and the DeepSeek V3 LLM, featuring a NaViT vision encoder trained from scratch with dynamic resolution support and enhanced OCR capabilities through structured image data training. |  |\n| **DotsVLM-OCR**            | `rednote-hilab/dots.ocr`                   | Specialized OCR variant of DotsVLM optimized for optical character recognition tasks with enhanced text extraction and document understanding capabilities. | Don't use `--trust-remote-code` |\n| **NVILA** (8B, 15B, Lite-2B, Lite-8B, Lite-15B) | `Efficient-Large-Model/NVILA-8B` | NVILA explores the full-stack efficiency of multimodal design, achieving cheaper training, faster deployment, and better performance. | `chatml` |\n| **NVIDIA Nemotron Nano 2.0 VL** | `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16` | NVIDIA Nemotron Nano v2 VL enables multi-image reasoning and video understanding, along with strong document intelligence, visual Q&A, and summarization capabilities. It builds on Nemotron Nano V2, a hybrid Mamba-Transformer LLM, to achieve higher inference throughput in long-document and video scenarios. | Use `--trust-remote-code`. You may need to adjust `--max-mamba-cache-size` (default: 512) to fit memory constraints. 
|\n| **Ernie4.5-VL** | `baidu/ERNIE-4.5-VL-28B-A3B-PT`              | Baidu's vision-language models (28B, 424B), supporting image and video comprehension as well as thinking.                                                                     |  |\n| **JetVLM** |  | JetVLM is a vision-language model built upon Jet-Nemotron, designed for high-performance multimodal understanding and generation tasks. | Coming soon |\n| **Step3-VL** (10B) | `stepfun-ai/Step3-VL-10B` | StepFun's lightweight open-source 10B-parameter VLM for multimodal intelligence, excelling in visual perception, complex reasoning, and human alignment. |  |\n| **Qwen3-Omni** | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Alibaba's omni-modal MoE model. Currently supports the **Thinker** component (multimodal understanding for text, images, audio, and video), while the **Talker** component (audio generation) is not yet supported. |  |\n\n## Video Input Support\n\nSGLang supports video input for Vision-Language Models (VLMs), enabling temporal reasoning tasks such as video question answering, captioning, and holistic scene understanding. Video clips are decoded, key frames are sampled, and the resulting tensors are batched together with the text prompt, allowing multimodal inference to integrate visual and linguistic context.\n\n| Model Family | Example Identifier | Video notes |\n|--------------|--------------------|-------------|\n| **Qwen-VL** (Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3-Omni) | `Qwen/Qwen3-VL-235B-A22B-Instruct` | The processor gathers `video_data`, runs Qwen's frame sampler, and merges the resulting features with text tokens before inference. |\n| **GLM-4v** (4.5V, 4.1V, MoE) | `zai-org/GLM-4.5V` | Video clips are read with Decord, converted to tensors, and passed to the model alongside metadata for rotary-position handling. 
|\n| **NVILA** (Full & Lite) | `Efficient-Large-Model/NVILA-8B` | The runtime samples eight frames per clip and attaches them to the multimodal request when `video_data` is present. |\n| **LLaVA video variants** (LLaVA-NeXT-Video, LLaVA-OneVision) | `lmms-lab/LLaVA-NeXT-Video-7B` | The processor routes video prompts to the LlavaVid video-enabled architecture, and the provided example shows how to query it with `sgl.video(...)` clips. |\n| **NVIDIA Nemotron Nano 2.0 VL** | `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16` | The processor samples at 2 FPS, at a max of 128 frames, as per model training. The model uses [EVS](../../python/sglang/srt/multimodal/evs/README.md), a pruning method that removes redundant tokens from video embeddings. By default `video_pruning_rate=0.7`. Change this by providing: `--json-model-override-args '{\"video_pruning_rate\": 0.0}'` to disable EVS, for example. |\n| **JetVLM** |  | The runtime samples eight frames per clip and attaches them to the multimodal request when `video_data` is present. 
|\n\nUse `sgl.video(path, num_frames)` when building prompts to attach clips from your SGLang programs.\n\nExample OpenAI-compatible request that sends a video clip:\n\n```python\nimport requests\n\n# Chat completions endpoint of a locally launched SGLang server\nurl = \"http://localhost:30000/v1/chat/completions\"\n\ndata = {\n    \"model\": \"Qwen/Qwen3-VL-30B-A3B-Instruct\",\n    \"messages\": [\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\"type\": \"text\", \"text\": \"What’s happening in this video?\"},\n                {\n                    \"type\": \"video_url\",\n                    \"video_url\": {\n                        \"url\": \"https://github.com/sgl-project/sgl-test-files/raw/refs/heads/main/videos/jobs_presenting_ipod.mp4\"\n                    },\n                },\n            ],\n        }\n    ],\n    \"max_tokens\": 300,\n}\n\nresponse = requests.post(url, json=data)\nresponse.raise_for_status()  # Fail loudly on HTTP errors\nprint(response.text)\n```\n\n## Usage Notes\n\n### Performance Optimization\n\nFor multimodal models, you can use the `--keep-mm-feature-on-device` flag to optimize for latency at the cost of increased GPU memory usage:\n\n- **Default behavior**: Multimodal feature tensors are moved to CPU after processing to save GPU memory.\n- **With `--keep-mm-feature-on-device`**: Feature tensors remain on GPU, reducing device-to-host copy overhead and improving latency, but consuming more GPU memory.\n\nUse this flag when you have sufficient GPU memory and want to minimize latency for multimodal inference.\n\n### Multimodal Inputs Limitation\n\n- **`--mm-process-config '{\"image\":{\"max_pixels\":1048576},\"video\":{\"fps\":3,\"max_pixels\":602112,\"max_frames\":60}}'`**: sets input limits for `image`, `video`, and `audio`.\n\nLowering these limits can reduce GPU memory usage, improve inference speed, and help avoid out-of-memory (OOM) errors. However, it may degrade model quality, so choose values that suit your specific use case. Currently, only `qwen_vl` supports this config. 
Please refer to the [qwen_vl processor](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/multimodal/processors/qwen_vl.py) for the meaning of each parameter.\n\n### Bidirectional Attention in Multimodal Model Serving\n**Note for serving the Gemma-3 multimodal model**:\n\nAs mentioned in [Welcome Gemma 3: Google's all new multimodal, multilingual, long context open LLM](https://huggingface.co/blog/gemma3#multimodality), Gemma-3 employs bidirectional attention between image tokens during the prefill phase. Currently, SGLang supports bidirectional attention only with the Triton attention backend. Note, however, that SGLang's current bidirectional attention implementation is incompatible with both CUDA Graph and Chunked Prefill.\n\nTo enable bidirectional attention, use the Triton attention backend (`--attention-backend triton`) and disable both CUDA Graph (`--disable-cuda-graph`) and Chunked Prefill (`--chunked-prefill-size -1`). Example launch command:\n\n```shell\npython -m sglang.launch_server \\\n  --model-path google/gemma-3-4b-it \\\n  --host 0.0.0.0 --port 30000 \\\n  --enable-multimodal \\\n  --dtype bfloat16 --triton-attention-reduce-in-fp32 \\\n  --attention-backend triton \\\n  --disable-cuda-graph \\\n  --chunked-prefill-size -1\n```\n\nIf you need higher serving performance and can accept some loss of accuracy, you can use other attention backends and enable features such as CUDA Graph and Chunked Prefill. Note, however, that the model then falls back to causal attention instead of bidirectional attention.\n"
  },
  {
    "path": "docs/wrap_run_llm.py",
    "content": "import os\nimport re\n\n\ndef insert_runllm_widget(html_content):\n    # RunLLM Widget script to be inserted\n    widget_script = \"\"\"\n    <!-- RunLLM Widget Script -->\n    <script type=\"module\" id=\"runllm-widget-script\" src=\"https://widget.runllm.com\" crossorigin=\"true\" version=\"stable\" runllm-keyboard-shortcut=\"Mod+j\" runllm-name=\"SGLang Chatbot\" runllm-position=\"BOTTOM_RIGHT\" runllm-assistant-id=\"629\" async></script>\n    \"\"\"\n\n    # Find the closing body tag and insert the widget script before it\n    return re.sub(r\"</body>\", f\"{widget_script}\\n</body>\", html_content)\n\n\ndef process_html_files(build_dir):\n    for root, dirs, files in os.walk(build_dir):\n        for file in files:\n            if file.endswith(\".html\"):\n                file_path = os.path.join(root, file)\n\n                # Read the HTML file\n                with open(file_path, \"r\", encoding=\"utf-8\") as f:\n                    content = f.read()\n\n                # Insert the RunLLM widget\n                modified_content = insert_runllm_widget(content)\n\n                # Write back the modified content\n                with open(file_path, \"w\", encoding=\"utf-8\") as f:\n                    f.write(modified_content)\n\n\ndef main():\n    # Get the build directory path\n    build_dir = os.path.join(\n        os.path.dirname(os.path.abspath(__file__)), \"_build\", \"html\"\n    )\n    # Process all HTML files\n    if os.path.exists(build_dir):\n        process_html_files(build_dir)\n    else:\n        print(f\"Build directory not found: {build_dir}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/assets/.gitignore",
    "content": "!example_image.png\n"
  },
  {
    "path": "examples/chat_template/qwen3_reranker.jinja",
    "content": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n<Instruct>: {{ instruct | default(\"Given a web search query, retrieve relevant passages that answer the query.\") }}\n<Query>: {{ messages[0][\"content\"] }}\n<Document>: {{ messages[1][\"content\"] }}<|im_end|>\n<|im_start|>assistant{{ '\\n' }}\n"
  },
  {
    "path": "examples/chat_template/qwen3_vl_reranker.jinja",
    "content": "{#- Qwen3-VL-Reranker chat template for multimodal reranking -#}\n{#- This template formats query-document pairs for yes/no relevance judgment -#}\n{#- Supports text, images, and videos in both query and documents -#}\n<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n<Instruct>: {{ instruct | default(\"Given a search query, retrieve relevant candidates that answer the query.\") }}\n{#- Process query content -#}\n<Query>: {%- for content in query -%}\n    {%- if content.type == 'image' or 'image' in content or 'image_url' in content -%}\n        <|vision_start|><|image_pad|><|vision_end|>\n    {%- elif content.type == 'video' or 'video' in content -%}\n        <|vision_start|><|video_pad|><|vision_end|>\n    {%- elif 'text' in content -%}\n        {{ content.text }}\n    {%- elif content.type == 'text' -%}\n        {{ content.text }}\n    {%- endif -%}\n{%- endfor %}\n{#- Process document content -#}\n{{ '\\n' }}<Document>: {%- for content in document -%}\n    {%- if content.type == 'image' or 'image' in content or 'image_url' in content -%}\n        <|vision_start|><|image_pad|><|vision_end|>\n    {%- elif content.type == 'video' or 'video' in content -%}\n        <|vision_start|><|video_pad|><|vision_end|>\n    {%- elif 'text' in content -%}\n        {{ content.text }}\n    {%- elif content.type == 'text' -%}\n        {{ content.text }}\n    {%- endif -%}\n{%- endfor %}<|im_end|>\n<|im_start|>assistant{{ '\\n' }}\n"
  },
  {
    "path": "examples/chat_template/tool_chat_template_deepseekr1.jinja",
    "content": "{% if not add_generation_prompt is defined %}\n    {% set add_generation_prompt = false %}\n{% endif %}\n{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n        {%- if ns.is_first_sp %}\n            {% set ns.system_prompt = ns.system_prompt + message['content'] %}\n            {% set ns.is_first_sp = false %}\n        {%- else %}\n            {% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n\n{# --- Append tool descriptions if tools are defined --- #}\n{% if tools is defined and tools is not none %}\n    {% set tool_ns = namespace(text='You are a helpful assistant with tool calling capabilities. '\n        'When a tool call is needed, you MUST use the following format to issue the call:\\n'\n        '<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>function<｜tool▁sep｜>FUNCTION_NAME\\n'\n        '```json\\n{\"param1\": \"value1\", \"param2\": \"value2\"}\\n```<｜tool▁call▁end｜><｜tool▁calls▁end｜>\\n\\n'\n        'Make sure the JSON is valid.'\n        '## Tools\\n\\n### Function\\n\\nYou have the following functions available:\\n\\n') %}\n    {% for tool in tools %}\n        {% set tool_ns.text = tool_ns.text + '- `' + tool['name'] + '`:\\n```json\\n' + (tool | tojson) + '\\n```\\n' %}\n    {% endfor %}\n    {% set ns.system_prompt = ns.system_prompt + '\\n\\n' + tool_ns.text %}\n{% endif %}\n\n{{ bos_token }}\n{{ ns.system_prompt }}\n{%- for message in messages %}\n    {% set content = message['content'] %}\n    {%- if message['role'] == 'user' %}\n        {%- set ns.is_tool = false -%}\n        {%- set ns.is_first = false -%}\n        {%- set ns.is_last_user = true -%}\n        {{'<｜User｜>' + content + '<｜Assistant｜>'}}\n    {%- endif %}\n    {%- if message['role'] == 'assistant' %}\n        {% if '</think>' in 
content %}\n            {% set content = content.split('</think>')[-1] %}\n        {% endif %}\n    {% endif %}\n    {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}\n        {%- set ns.is_last_user = false -%}\n        {%- if ns.is_tool %}\n            {{'<｜tool▁outputs▁end｜>'}}\n        {%- endif %}\n        {%- set ns.is_first = false %}\n        {%- set ns.is_tool = false -%}\n        {%- set ns.is_output_first = true %}\n        {%- for tool in message['tool_calls'] %}\n            {%- set tool_type = tool['type'] if tool['type'] is defined else 'function' -%}\n            {%- set tool_args = tool['function']['arguments'] if tool['function']['arguments'] is string else tool['function']['arguments'] | tojson -%}\n            {%- if not ns.is_first %}\n                {%- if content is none %}\n                    {{'<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool_type + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool_args + '\\n' + '```' + '<｜tool▁call▁end｜>'}}\n                {%- else %}\n                    {{content + '<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool_type + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool_args + '\\n' + '```' + '<｜tool▁call▁end｜>'}}\n                {%- endif %}\n                {%- set ns.is_first = true -%}\n            {%- else %}\n                {{'\\n' + '<｜tool▁call▁begin｜>' + tool_type + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool_args + '\\n' + '```' + '<｜tool▁call▁end｜>'}}\n            {%- endif %}\n        {%- endfor %}\n        {{'<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}\n    {%- endif %}\n    {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%}\n        {%- set ns.is_last_user = false -%}\n        {%- if ns.is_tool %}\n            {{'<｜tool▁outputs▁end｜>' + content + '<｜end▁of▁sentence｜>'}}\n     
       {%- set ns.is_tool = false -%}\n        {%- else %}\n            {{content + '<｜end▁of▁sentence｜>'}}\n        {%- endif %}\n    {%- endif %}\n    {%- if message['role'] == 'tool' %}\n        {%- set ns.is_last_user = false -%}\n        {%- set ns.is_tool = true -%}\n        {%- if ns.is_output_first %}\n            {{'<｜tool▁outputs▁begin｜><｜tool▁output▁begin｜>' + content + '<｜tool▁output▁end｜>'}}\n            {%- set ns.is_output_first = false %}\n        {%- else %}\n            {{'\\n<｜tool▁output▁begin｜>' + content + '<｜tool▁output▁end｜>'}}\n        {%- endif %}\n    {%- endif %}\n{%- endfor -%}\n{% if ns.is_tool %}\n    {{'<｜tool▁outputs▁end｜>'}}\n{% endif %}\n{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %}\n    {{'<｜Assistant｜>'}}\n{% endif %}\n"
  },
  {
    "path": "examples/chat_template/tool_chat_template_deepseekv3.jinja",
    "content": "{% if not add_generation_prompt is defined %}\n    {% set add_generation_prompt = false %}\n{% endif %}\n\n{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n        {%- if ns.is_first_sp %}\n            {% set ns.system_prompt = ns.system_prompt + message['content'] %}\n            {% set ns.is_first_sp = false %}\n        {%- else %}\n            {% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}\n        {%- endif %}\n    {%- endif %}\n{%- endfor -%}\n\n{# --- Append tool descriptions if tools are defined --- #}\n{% if tools is defined and tools is not none %}\n    {% set tool_ns = namespace(text='You are a helpful assistant with tool calling capabilities. '\n        'When a tool call is needed, you MUST use the following format to issue the call:\\n'\n        '<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>function<｜tool▁sep｜>FUNCTION_NAME\\n'\n        '```json\\n{\"param1\": \"value1\", \"param2\": \"value2\"}\\n```<｜tool▁call▁end｜><｜tool▁calls▁end｜>\\n\\n'\n        'Make sure the JSON is valid.'\n        '## Tools\\n\\n### Function\\n\\nYou have the following functions available:\\n\\n') %}\n    {% for tool in tools %}\n        {% set tool_ns.text = tool_ns.text + '\\n```json\\n' + (tool | tojson) + '\\n```\\n' %}\n    {% endfor %}\n    {% set ns.system_prompt = ns.system_prompt + '\\n\\n' + tool_ns.text %}\n{% endif %}\n\n{{- bos_token }}\n{{- ns.system_prompt }}\n\n{%- for message in messages %}\n    {%- if message['role'] == 'user' %}\n        {%- set ns.is_tool = false -%}\n        {%- set ns.is_first = false -%}\n        {%- set ns.is_last_user = true -%}\n        {{'<｜User｜>' + message['content'] + '<｜Assistant｜>'}}\n    {%- endif %}\n    {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}\n        
{%- set ns.is_last_user = false -%}\n        {%- if ns.is_tool %}\n            {{- '<｜tool▁outputs▁end｜>'}}\n        {%- endif %}\n        {%- set ns.is_first = false %}\n        {%- set ns.is_tool = false -%}\n        {%- set ns.is_output_first = true %}\n        {%- for tool in message['tool_calls'] %}\n            {%- set formatted_args = tool['function']['arguments'] if tool['function']['arguments'] is string else tool['function']['arguments']|tojson %}\n            {%- if not ns.is_first %}\n                {%- if message['content'] is none %}\n                    {{- '<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + formatted_args + '\\n' + '```' + '<｜tool▁call▁end｜>'}}\n                {%- else %}\n                    {{- message['content'] + '<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + formatted_args + '\\n' + '```' + '<｜tool▁call▁end｜>'}}\n                {%- endif %}\n                {%- set ns.is_first = true -%}\n            {%- else %}\n                {{- '\\n' + '<｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + formatted_args + '\\n' + '```' + '<｜tool▁call▁end｜>'}}\n            {%- endif %}\n        {%- endfor %}\n        {{- '<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}\n    {%- endif %}\n    {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%}\n        {%- set ns.is_last_user = false -%}\n        {%- if ns.is_tool %}\n            {{- '<｜tool▁outputs▁end｜>' + message['content'] + '<｜end▁of▁sentence｜>'}}\n            {%- set ns.is_tool = false -%}\n        {%- else %}\n            {% set content = message['content'] %}\n            {{- content + '<｜end▁of▁sentence｜>'}}\n        {%- endif %}\n    {%- endif %}\n    {%- if message['role'] == 'tool' %}\n        {%- set 
ns.is_last_user = false -%}\n        {%- set ns.is_tool = true -%}\n        {%- if ns.is_output_first %}\n            {{- 'Use the results below to formulate an answer to the user question unless additional information is needed.' }}\n            {{- '<｜tool▁outputs▁begin｜><｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}\n            {%- set ns.is_output_first = false %}\n        {%- else %}\n            {{- '\\n<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}\n        {%- endif %}\n    {%- endif %}\n{%- endfor -%}\n\n{% if ns.is_tool %}\n    {{- '<｜tool▁outputs▁end｜>'}}\n{% endif %}\n{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %}\n    {{- '<｜Assistant｜>'}}\n{% endif %}\n"
  },
  {
    "path": "examples/chat_template/tool_chat_template_deepseekv31.jinja",
    "content": "{% if not add_generation_prompt is defined %}\n  {% set add_generation_prompt = false %}\n{% endif %}\n{% if not thinking is defined %}\n  {% set thinking = false %}\n{% endif %}\n{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %}\n{%- for message in messages %}\n  {%- if message['role'] == 'system' %}\n    {%- if ns.is_first_sp %}\n      {% set ns.system_prompt = ns.system_prompt + message['content'] %}\n      {% set ns.is_first_sp = false %}\n    {%- else %}\n      {% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}\n    {%- endif %}\n  {%- endif %}\n{%- endfor %}\n\n{% if tools is defined and tools is not none %}\n  {% set tool_ns = namespace(text='## Tools\\nYou have access to the following tools:\\n') %}\n  {% for tool in tools %}\n    {% if tool.function.description is not none %}\n      {% set tool_ns.text = tool_ns.text + '\\n### ' + tool.function.name + '\\nDescription: ' + tool.function.description + '\\n\\nParameters: ' + (tool.function.parameters | tojson) + '\\n' %}\n    {% else %}\n      {% set tool_ns.text = tool_ns.text + '\\n### ' + tool.function.name + '\\n\\nParameters: ' + (tool.function.parameters | tojson) + '\\n' %}\n    {% endif %}\n  {% endfor %}\n  {% set tool_ns.text = tool_ns.text + \"\\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\\n<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>tool_call_name<｜tool▁sep｜>tool_call_arguments<｜tool▁call▁end｜>{{additional_tool_calls}}<｜tool▁calls▁end｜>\\n\\nWhere:\\n\\n- `tool_call_name` must be an exact match to one of the available tools\\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\\n- For multiple tool calls, chain them directly without separators or spaces\\n\" %}\n  {% set ns.system_prompt = ns.system_prompt + '\\n\\n' + tool_ns.text %}\n{% endif %}\n\n{{ bos_token }}{{ ns.system_prompt }}\n{%- for message in messages %}\n  {%- if 
message['role'] == 'user' %}\n    {%- set ns.is_tool = false -%}\n    {%- set ns.is_first = false -%}\n    {%- set ns.is_last_user = true -%}\n    {{'<｜User｜>' + message['content']}}\n  {%- endif %}\n  {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}\n    {%- if ns.is_last_user %}\n      {{'<｜Assistant｜></think>'}}\n    {%- endif %}\n    {%- set ns.is_last_user = false -%}\n    {%- set ns.is_first = false %}\n    {%- set ns.is_tool = false -%}\n    {%- for tool in message['tool_calls'] %}\n      {%- set formatted_args = tool['function']['arguments'] if tool['function']['arguments'] is string else tool['function']['arguments']|tojson %}\n      {%- if not ns.is_first %}\n        {%- if message['content'] is none %}\n          {{'<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>'+ tool['function']['name'] + '<｜tool▁sep｜>' + formatted_args + '<｜tool▁call▁end｜>'}}\n        {%- else %}\n          {{message['content'] + '<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['function']['name'] + '<｜tool▁sep｜>' + formatted_args + '<｜tool▁call▁end｜>'}}\n        {%- endif %}\n        {%- set ns.is_first = true -%}\n      {%- else %}\n        {{'<｜tool▁call▁begin｜>'+ tool['function']['name'] + '<｜tool▁sep｜>' + formatted_args + '<｜tool▁call▁end｜>'}}\n      {%- endif %}\n    {%- endfor %}\n    {{'<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}\n  {%- endif %}\n  {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}\n    {%- if ns.is_last_user %}\n      {{'<｜Assistant｜>'}}\n      {%- if message['prefix'] is defined and message['prefix'] and thinking %}\n        {{'<think>'}}\n      {%- else %}\n        {{'</think>'}}\n      {%- endif %}\n    {%- endif %}\n    {%- set ns.is_last_user = false -%}\n    {%- if ns.is_tool %}\n      {{message['content'] + '<｜end▁of▁sentence｜>'}}\n      {%- set ns.is_tool = false -%}\n    {%- else %}\n      {%- set content = message['content'] 
-%}\n      {%- if '</think>' in content %}\n        {%- set content = content.split('</think>', 1)[1] -%}\n      {%- endif %}\n      {{content + '<｜end▁of▁sentence｜>'}}\n    {%- endif %}\n  {%- endif %}\n  {%- if message['role'] == 'tool' %}\n    {%- set ns.is_last_user = false -%}\n    {%- set ns.is_tool = true -%}\n    {{'<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}\n  {%- endif %}\n{%- endfor -%}\n{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %}\n  {{'<｜Assistant｜>'}}\n  {%- if not thinking %}\n    {{'</think>'}}\n  {%- else %}\n    {{'<think>'}}\n  {%- endif %}\n{% endif %}\n"
  },
  {
    "path": "examples/chat_template/tool_chat_template_deepseekv32.jinja",
    "content": "{% if not add_generation_prompt is defined %}\n  {% set add_generation_prompt = false %}\n{% endif %}\n{% if not thinking is defined %}\n  {% set thinking = false %}\n{% endif %}\n{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false, is_only_sys=false, is_prefix=false) %}\n{%- for message in messages %}\n  {%- if message['role'] == 'system' %}\n    {%- if ns.is_first_sp %}\n      {% set ns.system_prompt = ns.system_prompt + message['content'] %}\n      {% set ns.is_first_sp = false %}\n    {%- else %}\n      {% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}\n    {%- endif %}\n    {% set ns.is_only_sys = true %}\n  {%- endif %}\n{%- endfor %}\n\n{% if tools is defined and tools is not none %}\n  {% set tool_ns = namespace(text='## Tools\\nYou have access to the following tools:\\n') %}\n  {% for tool in tools %}\n    {% set tool_ns.text = tool_ns.text + '\\n### ' + tool.function.name + '\\nDescription: ' + tool.function.description + '\\n\\nParameters: ' + (tool.function.parameters | tojson) + '\\n' %}\n  {% endfor %}\n  {% set tool_ns.text = tool_ns.text + \"\\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\\n<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>tool_call_name<｜tool▁sep｜>tool_call_arguments<｜tool▁call▁end｜>{{additional_tool_calls}}<｜tool▁calls▁end｜>\\n\\nWhere:\\n\\n- `tool_call_name` must be an exact match to one of the available tools\\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\\n- For multiple tool calls, chain them directly without separators or spaces\\n\" %}\n  {% set ns.system_prompt = ns.system_prompt + '\\n\\n' + tool_ns.text %}\n{% endif %}\n\n{{ bos_token }}{{ ns.system_prompt }}\n{%- for message in messages %}\n  {%- if message['role'] == 'user' %}\n    {%- set ns.is_tool = false -%}\n    {%- set ns.is_first = false -%}\n    {%- set ns.is_last_user = true -%}\n    {{'<｜User｜>' + 
message['content']}}\n  {%- endif %}\n  {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}\n    {%- if ns.is_last_user or ns.is_only_sys %}\n      {{'<｜Assistant｜></think>'}}\n    {%- endif %}\n    {%- set ns.is_last_user = false -%}\n    {%- set ns.is_first = false %}\n    {%- set ns.is_tool = false -%}\n    {%- for tool in message['tool_calls'] %}\n      {%- set formatted_args = tool['function']['arguments'] if tool['function']['arguments'] is string else tool['function']['arguments']|tojson %}\n      {%- if not ns.is_first %}\n        {%- if message['content'] is none %}\n          {{'<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>'+ tool['function']['name'] + '<｜tool▁sep｜>' + formatted_args + '<｜tool▁call▁end｜>'}}\n        {%- else %}\n          {{message['content'] + '<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['function']['name'] + '<｜tool▁sep｜>' + formatted_args + '<｜tool▁call▁end｜>'}}\n        {%- endif %}\n        {%- set ns.is_first = true -%}\n      {%- else %}\n        {{'<｜tool▁call▁begin｜>'+ tool['function']['name'] + '<｜tool▁sep｜>' + formatted_args + '<｜tool▁call▁end｜>'}}\n      {%- endif %}\n    {%- endfor %}\n    {{'<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}\n  {%- endif %}\n  {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}\n    {%- if ns.is_last_user %}\n      {{'<｜Assistant｜>'}}\n      {%- if message['prefix'] is defined and message['prefix'] and thinking %}\n        {{'<think>'}}\n      {%- else %}\n        {{'</think>'}}\n      {%- endif %}\n    {%- endif %}\n    {%- if message['prefix'] is defined and message['prefix'] %}\n      {%- set ns.is_prefix = true -%}\n    {%- endif %}\n    {%- set ns.is_last_user = false -%}\n    {%- if ns.is_tool %}\n      {{message['content'] + '<｜end▁of▁sentence｜>'}}\n      {%- set ns.is_tool = false -%}\n    {%- else %}\n      {%- set content = message['content'] -%}\n      {%- if 
'</think>' in content %}\n        {%- set content = content.split('</think>', 1)[1] -%}\n      {%- endif %}\n      {{content + '<｜end▁of▁sentence｜>'}}\n    {%- endif %}\n  {%- endif %}\n  {%- if message['role'] == 'tool' %}\n    {%- set ns.is_last_user = false -%}\n    {%- set ns.is_tool = true -%}\n    {{'<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}\n  {%- endif %}\n  {%- if message['role'] != 'system' %}\n    {% set ns.is_only_sys = false %}\n  {%- endif %}\n{%- endfor -%}\n{% if add_generation_prompt and not ns.is_tool%}\n  {% if ns.is_last_user or ns.is_only_sys or not ns.is_prefix %}\n    {{'<｜Assistant｜>'}}\n    {%- if not thinking %}\n      {{'</think>'}}\n    {%- else %}\n      {{'<think>'}}\n    {%- endif %}\n  {% endif %}\n{% endif %}\n"
  },
  {
    "path": "examples/chat_template/tool_chat_template_llama3.1_json.jinja",
    "content": "{# Copied from https://github.com/vllm-project/vllm/blob/main/examples/tool_chat_template_llama3.1_json.jinja to enable better model response. #}\n{{- bos_token }}\n{%- if custom_tools is defined %}\n    {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n    {#- Llama 3.1 doesn't pass all tests if the tools are in the system prompt #}\n    {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n    {%- if strftime_now is defined %}\n        {%- set date_string = strftime_now(\"%d %b %Y\") %}\n    {%- else %}\n        {%- set date_string = \"26 Jul 2024\" %}\n    {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n    {%- if messages[0]['content'] is string %}\n        {%- set system_message = messages[0]['content']|trim %}\n    {%- else %}\n        {%- set system_message = messages[0]['content'][0]['text']|trim %}\n    {%- endif %}\n    {%- set messages = messages[1:] %}\n{%- else %}\n    {%- if tools is not none %}\n        {%- set system_message = \"You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. 
When you receive a tool call response, use the output to format an answer to the original user question.\" %}\n    {%- else %}\n        {%- set system_message = \"\" %}\n    {%- endif %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n    {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n    {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call. \" }}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}. ' }}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n    {#- Extract the first user message so we can plug it in here #}\n    {%- if messages | length != 0 %}\n        {%- if messages[0]['content'] is string %}\n            {%- set first_user_message = messages[0]['content']|trim %}\n        {%- else %}\n            {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\\n') %}\n        {%- endif %}\n        {%- set messages = messages[1:] %}\n    {%- else %}\n        {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n    {%- endif %}\n    {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n    {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n    {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" 
}}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}. ' }}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n    {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}\n        {%- if message['content'] is string %}\n            {{- message['content'] | trim}}\n        {%- else %}\n            {%- for content in message['content'] %}\n                {%- if content['type'] == 'text' %}\n                    {{- content['text'] | trim }}\n                {%- endif %}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|eot_id|>' }}\n    {%- elif 'tool_calls' in message %}\n        {%- if not message.tool_calls|length == 1 %}\n            {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n        {%- endif %}\n        {%- set tool_call = message.tool_calls[0].function %}\n        {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n        {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n        {{- '\"parameters\": ' }}\n        {{- tool_call.arguments | tojson }}\n        {{- \"}\" }}\n        {{- \"<|eot_id|>\" }}\n    {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n        {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n        {%- if message.content is string %}\n            {{- { \"output\": message.content } | tojson }}\n        {%- else %}\n            {%- for content in message['content']  %}\n                {%- if content['type']  == 'text' %}\n                    {{- { \"output\": content['text']  } | tojson }}\n                {%- endif %}\n            {%- endfor %}\n        {%- endif 
%}\n        {{- \"<|eot_id|>\" }}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
  },
  {
    "path": "examples/chat_template/tool_chat_template_llama4_pythonic.jinja",
    "content": "{# Copied from https://github.com/wukaixingxp/vllm/blob/8a32e2a6e452a03c0e8222e3876ad6086cbf581f/examples/tool_chat_template_llama4_pythonic.jinja to enable better model response. #}\n{{- bos_token }}\n{%- if custom_tools is defined and custom_tools %}\n    {%- set tools = custom_tools %}\n{%- endif %}\n{%- if tools is defined and tools %}\n    {%- set tool_definition = tool_definition ~ (tools | tojson(indent=4)) %}\n{%- else %}\n    {%- set tools = none %}\n{%- endif %}\n\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n    {%- set user_provided_system_message = true %}\n    {%- if messages[0]['content'] is string %}\n        {%- set system_message = messages[0]['content']|trim %}\n    {%- else %}\n        {%- set system_message = messages[0]['content'][0]['text']|trim %}\n    {%- endif %}\n    {%- set messages = messages[1:] %}\n{%- else %}\n    {%- if tools is not none  %}\n        {#- Since not system_message was provided by user, if tool is provided, system_message is now default tool system message #}\n        {#- This system message is from llama website:https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/  #}\n        {%- set system_message = \"You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:\\n\\n1. 
FUNCTION CALLS:\\n- ONLY use functions that are EXPLICITLY listed in the function list below\\n- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or \\\"I don't have access to [Unavailable service] information\\\"\\n- If a function is not in the list, respond ONLY with internal knowledge or \\\"I don't have access to [Unavailable service] information\\\"\\n- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)\\n- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\\nExamples:\\nCORRECT: [get_weather(location=\\\"Vancouver\\\"), calculate_route(start=\\\"Boston\\\", end=\\\"New York\\\")] <- Only if get_weather and calculate_route are in function list\\nINCORRECT: get_weather(location=\\\"New York\\\")\\nINCORRECT: Let me check the weather: [get_weather(location=\\\"New York\\\")]\\nINCORRECT: [get_events(location=\\\"Singapore\\\")] <- If function not in list\\n\\n2. RESPONSE RULES:\\n- For pure function requests matching a listed function: ONLY output the function call(s)\\n- For knowledge questions: ONLY output text\\n- For missing parameters: ONLY request the specific missing parameters\\n- For unavailable services (not in function list): output ONLY with internal knowledge or \\\"I don't have access to [Unavailable service] information\\\". Do NOT execute a function call.\\n- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations\\n- NEVER combine text and function calls in the same response\\n- NEVER suggest alternative functions when the requested service is unavailable\\n- NEVER create or invent new functions not listed below\\n\\n3. 
STRICT BOUNDARIES:\\n- ONLY use functions from the list below - no exceptions\\n- NEVER use a function as an alternative to unavailable information\\n- NEVER call functions not present in the function list\\n- NEVER add explanatory text to function calls\\n- NEVER respond with empty brackets\\n- Use proper Python/JSON syntax for function calls\\n- Check the function list carefully before responding\\n\\n4. TOOL RESPONSE HANDLING:\\n- When receiving tool responses: provide concise, natural language responses\\n- Don't repeat tool response verbatim\\n- Don't add supplementary information\\n\\nHere is a list of functions in JSON format that you can invoke:\\n\" %}\n    {%- else %}\n        {%- set system_message = \"\" %}\n    {%- endif %}\n{%- endif %}\n{#- Now writing the system message: use the user provided system message if user_provided_system_message, else default tool system message if tools presented #}\n{%- if system_message %}\n    {#- always use user provided system message to override default tool system message #}\n    {{- \"<|header_start|>system<|header_end|>\\n\\n\" }}\n    {{- system_message }}\n    {%- if user_provided_system_message and tools %}\n        {{- \"\\nHere is a list of functions in JSON format that you can invoke. 
Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\\n\" }}\n        {{- tool_definition -}}\n        {%- elif tool_definition %}\n        {{- tool_definition -}}\n    {%- endif %}\n    {{- \"<|eot|>\" }}\n{%- endif %}\n\n{#- Now deal with all other messages #}\n{%- for message in messages %}\n    {#- Base case: messages that are not from tool role and has empty tool_call list  #}\n    {%- if not (message.role == 'ipython' or message.role == 'tool' or ('tool_calls' in message and message.tool_calls|length != 0 )) %}\n        {{- '<|header_start|>' + message['role'] + '<|header_end|>\\n\\n' }}\n        {%- if message['content'] is string %}\n            {{- message['content'] }}\n        {%- else %}\n            {%- for content in message['content'] %}\n                {%- if content['type'] == 'image' %}\n                    {{- '<|image|>' }}\n                {%- elif content['type'] == 'text' %}\n                    {{- content['text'] | trim }}\n                {%- endif %}\n            {%- endfor %}\n        {%- endif %}\n    {{- \"<|eot|>\" }}\n    {#- Tool case: messages has non-empty tool_call list, must from assistant #}\n    {%- elif 'tool_calls' in message %}\n        {#- assume tool_calls are always coming from assistant #}\n        {%- if message.role == 'assistant' %}\n            {{- '<|header_start|>assistant<|header_end|>\\n\\n' -}}\n        {%- if message['content'] is string %}\n            {{- message['content'] }}\n        {%- else %}\n            {%- for content in message['content'] %}\n                {%- if content['type'] == 'image' %}\n                    {{- '<|image|>' }}\n                {%- elif content['type'] == 'text' %}\n                    {{- content['text'] }}\n                {%- endif %}\n            {%- endfor %}\n        {%- endif %}\n        {{- \"[\" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- if tool_call.function is defined %}\n                {%- set tool_call = 
tool_call.function %}\n            {%- endif %}\n                {{-  tool_call.name + '(' -}}\n            {%- for param in tool_call.arguments %}\n                {{- param + '=\"' -}}\n                {{- \"%s\" | format(tool_call.arguments[param]) -}}\n                {{- '\"' -}}\n                {% if not loop.last %}, {% endif %}\n            {%- endfor %}\n            {{- ')' -}}\n            {% if not loop.last %}, {% endif %}\n        {%- endfor %}\n        {{- \"]<|eot|>\" }}\n{%- endif %}\n{#- Tool_response case: messages are from tool_response  #}\n    {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n        {{- \"<|header_start|>ipython<|header_end|>\\n\\n\" }}\n        {%- if message.content is string %}\n            {{- message.content | tojson }}\n        {%- else %}\n            {%- for content in message['content']  %}\n                {%- if content['type']  == 'text' %}\n                    {{- content['text'] | tojson }}\n                {%- endif %}\n            {%- endfor %}\n        {%- endif %}\n        {{- \"<|eot|>\" }}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|header_start|>assistant<|header_end|>\\n\\n' }}\n{%- endif %}\n"
  },
  {
    "path": "examples/chat_template/vision_template_sarashina_vl.jinja",
    "content": "{#\n In sglang, the default chat templates often assume message['content'] is a plain string.\n That works fine for simple text conversations, but it ignores multimodal inputs (e.g. image_url, tool_call).\n To align with the original model behavior and support richer content,\n we iterate over message['content'] as a list of typed items and extract their values directly.\n This way, both text and non-text inputs are preserved in the prompt.\n Original template: https://huggingface.co/sbintuitions/sarashina2-vision-8b?chat_template=default\n#}\n{{ bos_token + '<|prefix|><|file|><|suffix|>A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions.\\n\\n' }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Human: ' }}{%- if message['content'] is string %}{{ message['content'] }}{%- else %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% endif %}{% endfor %}{% endif %}{{ '\\n' }}{% elif message['role'] == 'assistant' %}{{ '### Assistant: ' }}{%- if message['content'] is string %}{{ message['content'] }}{%- else %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% endif %}{% endfor %}{% endif %}{{ '\\n' }}{% endif %}{% endfor %}{% if messages[-1]['role'] == 'user' %}{{ '### Assistant:' }}{% endif %}\n"
  },
  {
    "path": "examples/checkpoint_engine/update.py",
    "content": "\"\"\"\nUsage:\n1) Launch the server with wait-for-initial-weights option in one terminal:\n   python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7\n\n2) Torchrun this script in another terminal:\n    torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2\n\"\"\"\n\nimport argparse\nimport json\nimport os\nimport pickle\nimport time\nfrom collections import defaultdict\nfrom collections.abc import Callable\nfrom contextlib import contextmanager\nfrom typing import Literal\n\nimport httpx\nimport torch\nimport torch.distributed as dist\nfrom checkpoint_engine.ps import ParameterServer\nfrom loguru import logger\nfrom safetensors import safe_open\n\n\n@contextmanager\ndef timer(msg: str):\n    start = time.perf_counter()\n    yield\n    end = time.perf_counter()\n    logger.info(f\"{msg} duration: {end - start:.2f} seconds\")\n\n\ndef check_sglang_ready(\n    endpoint: str, inference_parallel_size: int, uds: str | None = None\n):\n    # Only the first rank of each inference group polls the server.\n    rank = int(os.getenv(\"RANK\", 0))\n    if rank != rank // inference_parallel_size * inference_parallel_size:\n        return\n    retry_num = 0\n    transport = None\n    if uds is not None:\n        transport = httpx.HTTPTransport(uds=uds)\n    with httpx.Client(transport=transport) as client:\n        while True:\n            try:\n                response = client.get(f\"{endpoint}/ping\", timeout=10)\n                response.raise_for_status()\n                break\n            except (httpx.ConnectError, httpx.HTTPStatusError) as e:\n                if retry_num % 10 == 0:\n                    logger.warning(\n                        f\"failed to check whether sglang is ready, retried {retry_num} times, error: {e}\"\n                    )\n                retry_num += 1\n                time.sleep(0.1)\n\n\ndef split_checkpoint_files(\n    
checkpoint_path: str, rank: int, world_size: int\n) -> list[str]:\n    checkpoint_files = [\n        os.path.join(checkpoint_path, f)\n        for f in filter(\n            lambda x: x.endswith(\".safetensors\"), os.listdir(checkpoint_path)\n        )\n    ]\n    files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size\n    return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank]\n\n\ndef split_tensors(\n    checkpoint_path: str, rank: int, world_size: int\n) -> dict[str, torch.Tensor]:\n    index_fn = os.path.join(checkpoint_path, \"model.safetensors.index.json\")\n    with open(index_fn) as f:\n        weight_map: dict[str, str] = json.load(f)[\"weight_map\"]\n    weights_per_rank = (len(weight_map) + world_size - 1) // world_size\n    fn_tensors: dict[str, list[str]] = defaultdict(list)\n    weight_keys = list(weight_map.items())\n    for name, file in weight_keys[\n        rank * weights_per_rank : (rank + 1) * weights_per_rank\n    ]:\n        fn_tensors[file].append(name)\n    named_tensors = {}\n    for file, names in fn_tensors.items():\n        with safe_open(os.path.join(checkpoint_path, file), framework=\"pt\") as f:\n            for name in names:\n                named_tensors[name] = f.get_tensor(name)\n    return named_tensors\n\n\ndef req_inference(\n    endpoint: str,\n    inference_parallel_size: int,\n    timeout: float = 300.0,\n    uds: str | None = None,\n    weight_version: str | None = None,\n) -> Callable[[list[tuple[str, str]]], None]:\n    rank = int(os.getenv(\"RANK\", 0))\n    src = rank // inference_parallel_size * inference_parallel_size\n\n    def req_func(socket_paths: list[tuple[str, str]]):\n        if rank == src:\n            with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:\n                resp = client.post(\n                    f\"{endpoint}/update_weights_from_ipc\",\n                    json={\n                        \"zmq_handles\": dict(\n                            
socket_paths[src : src + inference_parallel_size]\n                        ),\n                        \"flush_cache\": True,\n                        \"weight_version\": weight_version,\n                    },\n                    timeout=timeout,\n                )\n                resp.raise_for_status()\n\n    return req_func\n\n\ndef update_weights(\n    ps: ParameterServer,\n    checkpoint_name: str,\n    checkpoint_files: list[str],\n    named_tensors: dict[str, torch.Tensor],\n    req_func: Callable[[list[tuple[str, str]]], None],\n    inference_parallel_size: int,\n    endpoint: str,\n    save_metas_file: str | None = None,\n    update_method: Literal[\"broadcast\", \"p2p\", \"all\"] = \"broadcast\",\n    uds: str | None = None,\n):\n    ps.register_checkpoint(\n        checkpoint_name, files=checkpoint_files, named_tensors=named_tensors\n    )\n    ps.init_process_group()\n    check_sglang_ready(endpoint, inference_parallel_size, uds)\n    dist.barrier()\n    with timer(\"Gather metas\"):\n        ps.gather_metas(checkpoint_name)\n    if save_metas_file and int(os.getenv(\"RANK\")) == 0:\n        with open(save_metas_file, \"wb\") as f:\n            pickle.dump(ps.get_metas(), f)\n\n    if update_method == \"broadcast\" or update_method == \"all\":\n        with timer(\"Update weights without setting ranks\"):\n            ps.update(checkpoint_name, req_func)\n\n    if update_method == \"p2p\" or update_method == \"all\":\n        if update_method == \"all\":\n            # Wait 2s for the process group used by the broadcast update to be destroyed.\n            time.sleep(2)\n        with timer(\"Update weights with setting ranks\"):\n            ps.update(\n                checkpoint_name, req_func, ranks=list(range(inference_parallel_size))\n            )\n\n\ndef join(\n    ps: ParameterServer,\n    checkpoint_name: str,\n    load_metas_file: str,\n    req_func: Callable[[list[tuple[str, str]]], None],\n    inference_parallel_size: int,\n    endpoint: str,\n    uds: str | None = None,\n):\n    
assert load_metas_file, \"load_metas_file is required\"\n    with open(load_metas_file, \"rb\") as f:\n        metas = pickle.load(f)\n    ps.init_process_group()\n    check_sglang_ready(endpoint, inference_parallel_size, uds)\n    dist.barrier()\n    with timer(\"Gather metas before join\"):\n        ps.gather_metas(checkpoint_name)\n    ps.load_metas(metas)\n    with timer(\n        f\"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p\"\n    ):\n        ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size)))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Update weights example\")\n    parser.add_argument(\"--checkpoint-path\", type=str, default=None)\n    parser.add_argument(\"--save-metas-file\", type=str, default=None)\n    parser.add_argument(\"--load-metas-file\", type=str, default=None)\n    parser.add_argument(\"--sleep-time\", type=int, default=0)\n    parser.add_argument(\"--endpoint\", type=str, default=\"http://localhost:19730\")\n    parser.add_argument(\"--inference-parallel-size\", type=int, default=8)\n    parser.add_argument(\"--checkpoint-name\", type=str, default=\"my-checkpoint-iter-0\")\n    parser.add_argument(\"--update-method\", type=str, default=\"broadcast\")\n    parser.add_argument(\"--uds\", type=str, default=None)\n    parser.add_argument(\"--weight-version\", type=str, default=None)\n    args = parser.parse_args()\n    rank = int(os.getenv(\"RANK\"))\n    world_size = int(os.getenv(\"WORLD_SIZE\"))\n    req_func = req_inference(\n        args.endpoint,\n        args.inference_parallel_size,\n        uds=args.uds,\n        weight_version=args.weight_version,\n    )\n    ps = ParameterServer(auto_pg=True)\n    ps._p2p_store = None\n    if args.load_metas_file:\n        join(\n            ps,\n            args.checkpoint_name,\n            args.load_metas_file,\n            req_func,\n            args.inference_parallel_size,\n            
args.endpoint,\n            args.uds,\n        )\n    else:\n        if os.path.exists(\n            os.path.join(args.checkpoint_path, \"model.safetensors.index.json\")\n        ):\n            named_tensors = split_tensors(args.checkpoint_path, rank, world_size)\n            checkpoint_files = []\n        else:\n            checkpoint_files = split_checkpoint_files(\n                args.checkpoint_path, rank, world_size\n            )\n            named_tensors = {}\n        update_weights(\n            ps,\n            args.checkpoint_name,\n            checkpoint_files,\n            named_tensors,\n            req_func,\n            args.inference_parallel_size,\n            args.endpoint,\n            args.save_metas_file,\n            args.update_method,\n            args.uds,\n        )\n    time.sleep(args.sleep_time)\n"
  },
  {
    "path": "examples/frontend_language/quick_start/anthropic_example_chat.py",
    "content": "\"\"\"\nUsage:\nexport ANTHROPIC_API_KEY=sk-******\npython3 anthropic_example_chat.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef multi_turn_question(s, question_1, question_2):\n    s += sgl.user(question_1)\n    s += sgl.assistant(sgl.gen(\"answer_1\", max_tokens=256))\n    s += sgl.user(question_2)\n    s += sgl.assistant(sgl.gen(\"answer_2\", max_tokens=256))\n\n\ndef single():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n    )\n\n    for m in state.messages():\n        print(m[\"role\"], \":\", m[\"content\"])\n\n    print(\"\\n-- answer_1 --\\n\", state[\"answer_1\"])\n\n\ndef stream():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n        stream=True,\n    )\n\n    for out in state.text_iter():\n        print(out, end=\"\", flush=True)\n    print()\n\n\ndef batch():\n    states = multi_turn_question.run_batch(\n        [\n            {\n                \"question_1\": \"What is the capital of the United States?\",\n                \"question_2\": \"List two local attractions.\",\n            },\n            {\n                \"question_1\": \"What is the capital of France?\",\n                \"question_2\": \"What is the population of this city?\",\n            },\n        ]\n    )\n\n    for s in states:\n        print(s.messages())\n\n\nif __name__ == \"__main__\":\n    sgl.set_default_backend(sgl.Anthropic(\"claude-3-haiku-20240307\"))\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n\n    # Stream output\n    print(\"\\n========== stream ==========\\n\")\n    stream()\n\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/anthropic_example_complete.py",
    "content": "\"\"\"\nUsage:\nexport ANTHROPIC_API_KEY=sk-******\npython3 anthropic_example_complete.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef few_shot_qa(s, question):\n    s += \"\"\"\n\\n\\nHuman: What is the capital of France?\n\\n\\nAssistant: Paris\n\\n\\nHuman: What is the capital of Germany?\n\\n\\nAssistant: Berlin\n\\n\\nHuman: What is the capital of Italy?\n\\n\\nAssistant: Rome\n\"\"\"\n    s += \"\\n\\nHuman: \" + question + \"\\n\"\n    s += \"\\n\\nAssistant:\" + sgl.gen(\"answer\", temperature=0)\n\n\ndef single():\n    state = few_shot_qa.run(question=\"What is the capital of the United States?\")\n    answer = state[\"answer\"].strip().lower()\n\n    assert \"washington\" in answer, f\"answer: {state['answer']}\"\n\n    print(state.text())\n\n\ndef stream():\n    state = few_shot_qa.run(\n        question=\"What is the capital of the United States?\", stream=True\n    )\n\n    for out in state.text_iter(\"answer\"):\n        print(out, end=\"\", flush=True)\n    print()\n\n\ndef batch():\n    states = few_shot_qa.run_batch(\n        [\n            {\"question\": \"What is the capital of the United States?\"},\n            {\"question\": \"What is the capital of China?\"},\n        ]\n    )\n\n    for s in states:\n        print(s[\"answer\"])\n\n\nif __name__ == \"__main__\":\n    sgl.set_default_backend(sgl.Anthropic(\"claude-3-haiku-20240307\"))\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n\n    # Stream output\n    print(\"\\n========== stream ==========\\n\")\n    stream()\n\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/azure_openai_example_chat.py",
    "content": "\"\"\"\nUsage:\nexport AZURE_OPENAI_API_KEY=sk-******\npython3 azure_openai_example_chat.py\n\"\"\"\n\nimport os\n\nimport sglang as sgl\n\n\n@sgl.function\ndef multi_turn_question(s, question_1, question_2):\n    s += sgl.system(\"You are a helpful assistant.\")\n    s += sgl.user(question_1)\n    s += sgl.assistant(sgl.gen(\"answer_1\", max_tokens=256))\n    s += sgl.user(question_2)\n    s += sgl.assistant(sgl.gen(\"answer_2\", max_tokens=256))\n\n\ndef single():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n    )\n\n    for m in state.messages():\n        print(m[\"role\"], \":\", m[\"content\"])\n\n    print(\"\\n-- answer_1 --\\n\", state[\"answer_1\"])\n\n\ndef stream():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n        stream=True,\n    )\n\n    for out in state.text_iter():\n        print(out, end=\"\", flush=True)\n    print()\n\n\ndef batch():\n    states = multi_turn_question.run_batch(\n        [\n            {\n                \"question_1\": \"What is the capital of the United States?\",\n                \"question_2\": \"List two local attractions.\",\n            },\n            {\n                \"question_1\": \"What is the capital of France?\",\n                \"question_2\": \"What is the population of this city?\",\n            },\n        ]\n    )\n\n    for s in states:\n        print(s.messages())\n\n\nif __name__ == \"__main__\":\n    backend = sgl.OpenAI(\n        model_name=\"azure-gpt-4\",\n        api_version=\"2023-07-01-preview\",\n        azure_endpoint=\"https://oai-arena-sweden.openai.azure.com/\",\n        api_key=os.environ[\"AZURE_OPENAI_API_KEY\"],\n        is_azure=True,\n    )\n    sgl.set_default_backend(backend)\n\n    # Run a single request\n    print(\"\\n========== single 
==========\\n\")\n    single()\n\n    # Stream output\n    print(\"\\n========== stream ==========\\n\")\n    stream()\n\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/gemini_example_chat.py",
    "content": "\"\"\"\nUsage:\nexport GCP_PROJECT_ID=******\npython3 gemini_example_chat.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef multi_turn_question(s, question_1, question_2):\n    s += sgl.user(question_1)\n    s += sgl.assistant(sgl.gen(\"answer_1\", max_tokens=256))\n    s += sgl.user(question_2)\n    s += sgl.assistant(sgl.gen(\"answer_2\", max_tokens=256))\n\n\ndef single():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n    )\n\n    for m in state.messages():\n        print(m[\"role\"], \":\", m[\"content\"])\n\n    print(\"\\n-- answer_1 --\\n\", state[\"answer_1\"])\n\n\ndef stream():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n        stream=True,\n    )\n\n    for out in state.text_iter():\n        print(out, end=\"\", flush=True)\n    print()\n\n\ndef batch():\n    states = multi_turn_question.run_batch(\n        [\n            {\n                \"question_1\": \"What is the capital of the United States?\",\n                \"question_2\": \"List two local attractions.\",\n            },\n            {\n                \"question_1\": \"What is the capital of France?\",\n                \"question_2\": \"What is the population of this city?\",\n            },\n        ]\n    )\n\n    for s in states:\n        print(s.messages())\n\n\nif __name__ == \"__main__\":\n    sgl.set_default_backend(sgl.VertexAI(\"gemini-pro\"))\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n\n    # Stream output\n    print(\"\\n========== stream ==========\\n\")\n    stream()\n\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/gemini_example_complete.py",
    "content": "\"\"\"\nUsage:\nexport GCP_PROJECT_ID=******\npython3 gemini_example_complete.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef few_shot_qa(s, question):\n    s += \"\"\"The following are questions with answers.\nQ: What is the capital of France?\nA: Paris\nQ: What is the capital of Germany?\nA: Berlin\nQ: What is the capital of Italy?\nA: Rome\n\"\"\"\n    s += \"Q: \" + question + \"\\n\"\n    s += \"A:\" + sgl.gen(\"answer\", stop=\"\\n\", temperature=0)\n\n\ndef single():\n    state = few_shot_qa.run(question=\"What is the capital of the United States?\")\n    answer = state[\"answer\"].strip().lower()\n\n    assert \"washington\" in answer, f\"answer: {state['answer']}\"\n\n    print(state.text())\n\n\ndef stream():\n    state = few_shot_qa.run(\n        question=\"What is the capital of the United States?\", stream=True\n    )\n\n    for out in state.text_iter(\"answer\"):\n        print(out, end=\"\", flush=True)\n    print()\n\n\ndef batch():\n    states = few_shot_qa.run_batch(\n        [\n            {\"question\": \"What is the capital of the United States?\"},\n            {\"question\": \"What is the capital of China?\"},\n        ]\n    )\n\n    for s in states:\n        print(s[\"answer\"])\n\n\nif __name__ == \"__main__\":\n    sgl.set_default_backend(sgl.VertexAI(\"gemini-pro\"))\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n\n    # Stream output\n    print(\"\\n========== stream ==========\\n\")\n    stream()\n\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/gemini_example_multimodal_chat.py",
    "content": "\"\"\"\nUsage:\nexport GCP_PROJECT_ID=******\npython3 gemini_example_multimodal_chat.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef image_qa(s, image_file1, image_file2, question):\n    s += sgl.user(sgl.image(image_file1) + sgl.image(image_file2) + question)\n    s += sgl.assistant(sgl.gen(\"answer\", max_tokens=256))\n\n\nif __name__ == \"__main__\":\n    sgl.set_default_backend(sgl.VertexAI(\"gemini-pro-vision\"))\n\n    state = image_qa.run(\n        image_file1=\"./images/cat.jpeg\",\n        image_file2=\"./images/dog.jpeg\",\n        question=\"Describe the difference between the two images in one sentence.\",\n        stream=True,\n    )\n\n    for out in state.text_iter(\"answer\"):\n        print(out, end=\"\", flush=True)\n    print()\n\n    print(state[\"answer\"])\n"
  },
  {
    "path": "examples/frontend_language/quick_start/local_example_chat.py",
    "content": "\"\"\"\nUsage:\npython3 local_example_chat.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef multi_turn_question(s, question_1, question_2):\n    s += sgl.user(question_1)\n    s += sgl.assistant(sgl.gen(\"answer_1\", max_tokens=256))\n    s += sgl.user(question_2)\n    s += sgl.assistant(sgl.gen(\"answer_2\", max_tokens=256))\n\n\ndef single():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n    )\n\n    for m in state.messages():\n        print(m[\"role\"], \":\", m[\"content\"])\n\n    print(\"\\n-- answer_1 --\\n\", state[\"answer_1\"])\n\n\ndef stream():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n        stream=True,\n    )\n\n    for out in state.text_iter():\n        print(out, end=\"\", flush=True)\n    print()\n\n\ndef batch():\n    states = multi_turn_question.run_batch(\n        [\n            {\n                \"question_1\": \"What is the capital of the United States?\",\n                \"question_2\": \"List two local attractions.\",\n            },\n            {\n                \"question_1\": \"What is the capital of France?\",\n                \"question_2\": \"What is the population of this city?\",\n            },\n        ]\n    )\n\n    for s in states:\n        print(s.messages())\n\n\nif __name__ == \"__main__\":\n    runtime = sgl.Runtime(model_path=\"meta-llama/Llama-2-7b-chat-hf\")\n    sgl.set_default_backend(runtime)\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n\n    # Stream output\n    print(\"\\n========== stream ==========\\n\")\n    stream()\n\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n\n    runtime.shutdown()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/local_example_complete.py",
    "content": "\"\"\"\nUsage:\npython3 local_example_complete.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef few_shot_qa(s, question):\n    s += \"\"\"The following are questions with answers.\nQ: What is the capital of France?\nA: Paris\nQ: What is the capital of Germany?\nA: Berlin\nQ: What is the capital of Italy?\nA: Rome\n\"\"\"\n    s += \"Q: \" + question + \"\\n\"\n    s += \"A:\" + sgl.gen(\"answer\", stop=\"\\n\", temperature=0)\n\n\ndef single():\n    state = few_shot_qa.run(question=\"What is the capital of the United States?\")\n    answer = state[\"answer\"].strip().lower()\n\n    assert \"washington\" in answer, f\"answer: {state['answer']}\"\n\n    print(state.text())\n\n\ndef stream():\n    state = few_shot_qa.run(\n        question=\"What is the capital of the United States?\", stream=True\n    )\n\n    for out in state.text_iter(\"answer\"):\n        print(out, end=\"\", flush=True)\n    print()\n\n\ndef batch():\n    states = few_shot_qa.run_batch(\n        [\n            {\"question\": \"What is the capital of the United States?\"},\n            {\"question\": \"What is the capital of China?\"},\n        ]\n    )\n\n    for s in states:\n        print(s[\"answer\"])\n\n\nif __name__ == \"__main__\":\n    runtime = sgl.Runtime(model_path=\"meta-llama/Llama-2-7b-chat-hf\")\n    sgl.set_default_backend(runtime)\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n\n    # Stream output\n    print(\"\\n========== stream ==========\\n\")\n    stream()\n\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n\n    runtime.shutdown()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/local_example_llava_next.py",
    "content": "\"\"\"\nUsage: python3 local_example_llava_next.py\n\"\"\"\n\nimport sglang as sgl\nfrom sglang.lang.chat_template import get_chat_template\n\n\n@sgl.function\ndef image_qa(s, image_path, question):\n    s += sgl.user(sgl.image(image_path) + question)\n    s += sgl.assistant(sgl.gen(\"answer\"))\n\n\ndef single():\n    state = image_qa.run(\n        image_path=\"images/cat.jpeg\", question=\"What is this?\", max_new_tokens=128\n    )\n    print(state[\"answer\"], \"\\n\")\n\n\ndef stream():\n    state = image_qa.run(\n        image_path=\"images/cat.jpeg\",\n        question=\"What is this?\",\n        max_new_tokens=64,\n        stream=True,\n    )\n\n    for out in state.text_iter(\"answer\"):\n        print(out, end=\"\", flush=True)\n    print()\n\n\ndef batch():\n    states = image_qa.run_batch(\n        [\n            {\"image_path\": \"images/cat.jpeg\", \"question\": \"What is this?\"},\n            {\"image_path\": \"images/dog.jpeg\", \"question\": \"What is this?\"},\n        ],\n        max_new_tokens=128,\n    )\n    for s in states:\n        print(s[\"answer\"], \"\\n\")\n\n\nif __name__ == \"__main__\":\n    import multiprocessing as mp\n\n    mp.set_start_method(\"spawn\", force=True)\n\n    runtime = sgl.Runtime(model_path=\"lmms-lab/llama3-llava-next-8b\")\n    runtime.endpoint.chat_template = get_chat_template(\"llama-3-instruct-llava\")\n\n    # Or you can use the 72B model\n    # runtime = sgl.Runtime(model_path=\"lmms-lab/llava-next-72b\", tp_size=8)\n    # runtime.endpoint.chat_template = get_chat_template(\"chatml-llava\")\n\n    sgl.set_default_backend(runtime)\n    print(f\"chat template: {runtime.endpoint.chat_template.name}\")\n\n    # Or you can use API models\n    # sgl.set_default_backend(sgl.OpenAI(\"gpt-4-vision-preview\"))\n    # sgl.set_default_backend(sgl.VertexAI(\"gemini-pro-vision\"))\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n\n    # Stream output\n    
print(\"\\n========== stream ==========\\n\")\n    stream()\n\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n\n    runtime.shutdown()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/openai_example_chat.py",
    "content": "\"\"\"\nUsage:\nexport OPENAI_API_KEY=sk-******\npython3 openai_example_chat.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef multi_turn_question(s, question_1, question_2):\n    s += sgl.system(\"You are a helpful assistant.\")\n    s += sgl.user(question_1)\n    s += sgl.assistant(sgl.gen(\"answer_1\", max_tokens=256))\n    s += sgl.user(question_2)\n    s += sgl.assistant(sgl.gen(\"answer_2\", max_tokens=256))\n\n\ndef single():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n    )\n\n    for m in state.messages():\n        print(m[\"role\"], \":\", m[\"content\"])\n\n    print(\"\\n-- answer_1 --\\n\", state[\"answer_1\"])\n\n\ndef stream():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n        stream=True,\n    )\n\n    for out in state.text_iter():\n        print(out, end=\"\", flush=True)\n    print()\n\n\ndef batch():\n    states = multi_turn_question.run_batch(\n        [\n            {\n                \"question_1\": \"What is the capital of the United States?\",\n                \"question_2\": \"List two local attractions.\",\n            },\n            {\n                \"question_1\": \"What is the capital of France?\",\n                \"question_2\": \"What is the population of this city?\",\n            },\n        ]\n    )\n\n    for s in states:\n        print(s.messages())\n\n\nif __name__ == \"__main__\":\n    sgl.set_default_backend(sgl.OpenAI(\"gpt-3.5-turbo\"))\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n\n    # Stream output\n    print(\"\\n========== stream ==========\\n\")\n    stream()\n\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/openai_example_complete.py",
    "content": "\"\"\"\nUsage:\nexport OPENAI_API_KEY=sk-******\npython3 openai_example_complete.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef few_shot_qa(s, question):\n    s += \"\"\"The following are questions with answers.\nQ: What is the capital of France?\nA: Paris\nQ: What is the capital of Germany?\nA: Berlin\nQ: What is the capital of Italy?\nA: Rome\n\"\"\"\n    s += \"Q: \" + question + \"\\n\"\n    s += \"A:\" + sgl.gen(\"answer\", stop=\"\\n\", temperature=0)\n\n\ndef single():\n    state = few_shot_qa.run(question=\"What is the capital of the United States?\")\n    answer = state[\"answer\"].strip().lower()\n\n    assert \"washington\" in answer, f\"answer: {state['answer']}\"\n\n    print(state.text())\n\n\ndef stream():\n    state = few_shot_qa.run(\n        question=\"What is the capital of the United States?\", stream=True\n    )\n\n    for out in state.text_iter(\"answer\"):\n        print(out, end=\"\", flush=True)\n    print()\n\n\ndef batch():\n    states = few_shot_qa.run_batch(\n        [\n            {\"question\": \"What is the capital of the United States?\"},\n            {\"question\": \"What is the capital of China?\"},\n        ]\n    )\n\n    for s in states:\n        print(s[\"answer\"])\n\n\nif __name__ == \"__main__\":\n    sgl.set_default_backend(sgl.OpenAI(\"gpt-3.5-turbo-instruct\"))\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n\n    # Stream output\n    print(\"\\n========== stream ==========\\n\")\n    stream()\n\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/openai_example_n.py",
    "content": "\"\"\"\nUsage:\nexport OPENAI_API_KEY=sk-******\npython3 openai_example_n.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef multi_turn_question(s, question_1, question_2):\n    s += sgl.system(\"You are a helpful assistant.\")\n    s += sgl.user(question_1)\n    s += sgl.assistant(sgl.gen(\"answer_1\", max_tokens=1024, n=2))\n    s += sgl.user(question_2)\n    s += sgl.assistant(\n        sgl.gen(\n            \"answer_2\",\n            max_tokens=1024,\n        )\n    )\n\n\ndef single():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n    )\n\n    for m in state.messages():\n        print(m[\"role\"], \":\", m[\"content\"])\n\n    print(\"\\n-- answer_1 --\\n\", state[\"answer_1\"])\n    print(\"\\n-- answer_2 --\\n\", state[\"answer_2\"])\n    assert isinstance(state[\"answer_1\"], list)\n    assert len(state[\"answer_1\"]) == 2\n    assert isinstance(state[\"answer_2\"], str)\n\n\ndef batch():\n    states = multi_turn_question.run_batch(\n        [\n            {\n                \"question_1\": \"What is the capital of the United States?\",\n                \"question_2\": \"List two local attractions.\",\n            },\n            {\n                \"question_1\": \"What is the capital of France?\",\n                \"question_2\": \"What is the population of this city?\",\n            },\n        ]\n    )\n\n    for s in states:\n        print(s.messages())\n        print(\"\\n-- answer_1 --\\n\", s[\"answer_1\"])\n        print(\"\\n-- answer_2 --\\n\", s[\"answer_2\"])\n        assert isinstance(s[\"answer_1\"], list)\n        assert len(s[\"answer_1\"]) == 2\n        assert isinstance(s[\"answer_2\"], str)\n\n\nif __name__ == \"__main__\":\n    sgl.set_default_backend(sgl.OpenAI(\"o1\"))\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n    # Run a batch of requests\n    
print(\"\\n========== batch ==========\\n\")\n    batch()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/openai_example_o1.py",
    "content": "\"\"\"\nUsage:\nexport OPENAI_API_KEY=sk-******\npython3 openai_example_o1.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef multi_turn_question(s, question_1, question_2):\n    s += sgl.system(\"You are a helpful assistant.\")\n    s += sgl.user(question_1)\n    s += sgl.assistant(sgl.gen(\"answer_1\", max_tokens=100))\n    s += sgl.user(question_2)\n    s += sgl.assistant(sgl.gen(\"answer_2\"))\n\n\ndef single():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n    )\n\n    for m in state.messages():\n        print(m[\"role\"], \":\", m[\"content\"])\n\n    print(\"\\n-- answer_1 --\\n\", state[\"answer_1\"])\n\n\ndef batch():\n    states = multi_turn_question.run_batch(\n        [\n            {\n                \"question_1\": \"What is the capital of the United States?\",\n                \"question_2\": \"List two local attractions.\",\n            },\n            {\n                \"question_1\": \"What is the capital of France?\",\n                \"question_2\": \"What is the population of this city?\",\n            },\n        ]\n    )\n\n    for s in states:\n        print(s.messages())\n\n\nif __name__ == \"__main__\":\n    sgl.set_default_backend(sgl.OpenAI(\"o1\"))\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/openrouter_example_chat.py",
    "content": "\"\"\"\nUsage:\nexport OPENROUTER_API_KEY=sk-******\npython3 openrouter_example_chat.py\n\"\"\"\n\nimport os\n\nimport sglang as sgl\n\n\n@sgl.function\ndef multi_turn_question(s, question_1, question_2):\n    s += sgl.system(\"You are a helpful assistant.\")\n    s += sgl.user(question_1)\n    s += sgl.assistant(sgl.gen(\"answer_1\", max_tokens=256))\n    s += sgl.user(question_2)\n    s += sgl.assistant(sgl.gen(\"answer_2\", max_tokens=256))\n\n\ndef single():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n    )\n\n    for m in state.messages():\n        print(m[\"role\"], \":\", m[\"content\"])\n\n    print(\"\\n-- answer_1 --\\n\", state[\"answer_1\"])\n\n\ndef stream():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n        stream=True,\n    )\n\n    for out in state.text_iter():\n        print(out, end=\"\", flush=True)\n    print()\n\n\ndef batch():\n    states = multi_turn_question.run_batch(\n        [\n            {\n                \"question_1\": \"What is the capital of the United States?\",\n                \"question_2\": \"List two local attractions.\",\n            },\n            {\n                \"question_1\": \"What is the capital of France?\",\n                \"question_2\": \"What is the population of this city?\",\n            },\n        ]\n    )\n\n    for s in states:\n        print(s.messages())\n\n\nif __name__ == \"__main__\":\n    backend = sgl.OpenAI(\n        model_name=\"google/gemma-7b-it:free\",\n        base_url=\"https://openrouter.ai/api/v1\",\n        api_key=os.environ.get(\"OPENROUTER_API_KEY\"),\n    )\n    sgl.set_default_backend(backend)\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n\n    # Stream output\n    print(\"\\n========== 
stream ==========\\n\")\n    stream()\n\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/together_example_chat.py",
    "content": "\"\"\"\nUsage:\nexport TOGETHER_API_KEY=sk-******\npython3 together_example_chat.py\n\"\"\"\n\nimport os\n\nimport sglang as sgl\n\n\n@sgl.function\ndef multi_turn_question(s, question_1, question_2):\n    s += sgl.system(\"You are a helpful assistant.\")\n    s += sgl.user(question_1)\n    s += sgl.assistant(sgl.gen(\"answer_1\", max_tokens=256))\n    s += sgl.user(question_2)\n    s += sgl.assistant(sgl.gen(\"answer_2\", max_tokens=256))\n\n\ndef single():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n    )\n\n    for m in state.messages():\n        print(m[\"role\"], \":\", m[\"content\"])\n\n    print(\"\\n-- answer_1 --\\n\", state[\"answer_1\"])\n\n\ndef stream():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n        stream=True,\n    )\n\n    for out in state.text_iter():\n        print(out, end=\"\", flush=True)\n    print()\n\n\ndef batch():\n    states = multi_turn_question.run_batch(\n        [\n            {\n                \"question_1\": \"What is the capital of the United States?\",\n                \"question_2\": \"List two local attractions.\",\n            },\n            {\n                \"question_1\": \"What is the capital of France?\",\n                \"question_2\": \"What is the population of this city?\",\n            },\n        ]\n    )\n\n    for s in states:\n        print(s.messages())\n\n\nif __name__ == \"__main__\":\n    backend = sgl.OpenAI(\n        model_name=\"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n        base_url=\"https://api.together.xyz/v1\",\n        api_key=os.environ.get(\"TOGETHER_API_KEY\"),\n    )\n    sgl.set_default_backend(backend)\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n\n    # Stream output\n    
print(\"\\n========== stream ==========\\n\")\n    stream()\n\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n"
  },
  {
    "path": "examples/frontend_language/quick_start/together_example_complete.py",
    "content": "\"\"\"\nUsage:\nexport TOGETHER_API_KEY=sk-******\npython3 together_example_complete.py\n\"\"\"\n\nimport os\n\nimport sglang as sgl\n\n\n@sgl.function\ndef few_shot_qa(s, question):\n    s += \"\"\"The following are questions with answers.\nQ: What is the capital of France?\nA: Paris\nQ: What is the capital of Germany?\nA: Berlin\nQ: What is the capital of Italy?\nA: Rome\n\"\"\"\n    s += \"Q: \" + question + \"\\n\"\n    s += \"A:\" + sgl.gen(\"answer\", stop=\"\\n\", temperature=0)\n\n\ndef single():\n    state = few_shot_qa.run(question=\"What is the capital of the United States?\")\n    answer = state[\"answer\"].strip().lower()\n\n    assert \"washington\" in answer, f\"answer: {state['answer']}\"\n\n    print(state.text())\n\n\ndef stream():\n    state = few_shot_qa.run(\n        question=\"What is the capital of the United States?\", stream=True\n    )\n\n    for out in state.text_iter(\"answer\"):\n        print(out, end=\"\", flush=True)\n    print()\n\n\ndef batch():\n    states = few_shot_qa.run_batch(\n        [\n            {\"question\": \"What is the capital of the United States?\"},\n            {\"question\": \"What is the capital of China?\"},\n        ]\n    )\n\n    for s in states:\n        print(s[\"answer\"])\n\n\nif __name__ == \"__main__\":\n    backend = sgl.OpenAI(\n        model_name=\"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n        is_chat_model=False,\n        base_url=\"https://api.together.xyz/v1\",\n        api_key=os.environ.get(\"TOGETHER_API_KEY\"),\n    )\n    sgl.set_default_backend(backend)\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    single()\n\n    # Stream output\n    print(\"\\n========== stream ==========\\n\")\n    stream()\n\n    # Run a batch of requests\n    print(\"\\n========== batch ==========\\n\")\n    batch()\n"
  },
  {
    "path": "examples/frontend_language/usage/chinese_regex.py",
    "content": "import sglang as sgl\n\ncharacter_regex = (\n    r\"\"\"\\{\\n\"\"\"\n    + r\"\"\"    \"姓名\": \"[^\"]{1,32}\",\\n\"\"\"\n    + r\"\"\"    \"学院\": \"(格兰芬多|赫奇帕奇|拉文克劳|斯莱特林)\",\\n\"\"\"\n    + r\"\"\"    \"血型\": \"(纯血|混血|麻瓜)\",\\n\"\"\"\n    + r\"\"\"    \"职业\": \"(学生|教师|傲罗|魔法部|食死徒|凤凰社成员)\",\\n\"\"\"\n    + r\"\"\"    \"魔杖\": \\{\\n\"\"\"\n    + r\"\"\"        \"材质\": \"[^\"]{1,32}\",\\n\"\"\"\n    + r\"\"\"        \"杖芯\": \"[^\"]{1,32}\",\\n\"\"\"\n    + r\"\"\"        \"长度\": [0-9]{1,2}\\.[0-9]{0,2}\\n\"\"\"\n    + r\"\"\"    \\},\\n\"\"\"\n    + r\"\"\"    \"存活\": \"(存活|死亡)\",\\n\"\"\"\n    + r\"\"\"    \"守护神\": \"[^\"]{1,32}\",\\n\"\"\"\n    + r\"\"\"    \"博格特\": \"[^\"]{1,32}\"\\n\"\"\"\n    + r\"\"\"\\}\"\"\"\n)\n\n\n@sgl.function\ndef character_gen(s, name):\n    s += name + \" 是一名哈利波特系列小说中的角色。请填写以下关于这个角色的信息。\"\n    s += \"\"\"\\\n这是一个例子\n{\n    \"姓名\": \"哈利波特\",\n    \"学院\": \"格兰芬多\",\n    \"血型\": \"混血\",\n    \"职业\": \"学生\",\n    \"魔杖\": {\n        \"材质\": \"冬青木\",\n        \"杖芯\": \"凤凰尾羽\",\n        \"长度\": 11.0\n    },\n    \"存活\": \"存活\",\n    \"守护神\": \"麋鹿\",\n    \"博格特\": \"摄魂怪\"\n}\n\"\"\"\n    s += f\"现在请你填写{name}的信息：\\n\"\n    s += sgl.gen(\"json_output\", max_tokens=256, regex=character_regex)\n\n\ndef main():\n    backend = sgl.RuntimeEndpoint(\"http://localhost:30000\")\n    sgl.set_default_backend(backend)\n    ret = character_gen.run(name=\"赫敏格兰杰\", temperature=0)\n    print(ret.text())\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/frontend_language/usage/choices_logprob.py",
    "content": "\"\"\"\nUsage:\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\npython choices_logprob.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef tool_use(s, question):\n    s += \"To answer this question: \" + question + \", \"\n    s += \"I need to use a \" + sgl.gen(\"tool\", choices=[\"calculator\", \"search engine\"])\n\n\ndef main():\n    # Run one case\n    question = \"What is 5 + 5?\"\n    state = tool_use.run(question)\n    print(\"questions:\", question)\n    print(\"choice:\", state[\"tool\"])\n    meta_info = state.get_meta_info(\"tool\")\n    print(\"logprobs of choice 1\", meta_info[\"input_token_logprobs\"][0])\n    print(\"logprobs of choice 2\", meta_info[\"input_token_logprobs\"][1])\n    print(\"-\" * 50)\n\n    # Run a batch\n    questions = [\n        \"What is 5 + 6?\",\n        \"Who is Michael Jordan?\",\n    ]\n    states = tool_use.run_batch([{\"question\": q} for q in questions])\n    for question, state in zip(questions, states):\n        print(\"questions:\", question)\n        print(\"choice:\", state[\"tool\"])\n        meta_info = state.get_meta_info(\"tool\")\n        print(\"logprobs of choice 1\", meta_info[\"input_token_logprobs\"][0])\n        print(\"logprobs of choice 2\", meta_info[\"input_token_logprobs\"][1])\n        print(\"-\" * 50)\n\n\nif __name__ == \"__main__\":\n    sgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n    main()\n"
  },
  {
    "path": "examples/frontend_language/usage/cot_decoding.py",
    "content": "from math import exp\nfrom pprint import pformat\n\nimport sglang as sgl\n\nYELLOW = \"\\033[1;33m\"\nGREEN = \"\\033[1;32m\"\nBLUE = \"\\033[1;34m\"\nCLEAR = \"\\033[1;0m\"\n\n\n@sgl.function\ndef cot_decoding(s, question, get_top_k, is_chat_model, verbose):\n    \"\"\"CoT Decoding: http://arxiv.org/abs/2402.10200\"\"\"\n\n    if is_chat_model:\n        s += sgl.user(\"Question: \" + question + \"\\nAnswer:\")\n        s += sgl.assistant_begin()\n    else:\n        s += \"Question: \" + question + \"\\nAnswer:\"\n\n    step_0 = s.fork(1)[0]\n    forks = s.fork(get_top_k)\n    answer_forks = s.fork(get_top_k)\n\n    # decoding step 0\n    step_0 += sgl.gen(\n        \"get_top_k\",\n        max_tokens=0,\n        return_logprob=True,\n        top_logprobs_num=get_top_k,\n        return_text_in_logprobs=True,\n    )\n    logprobs = step_0.get_meta_info(\"get_top_k\")[\"output_top_logprobs\"][0]\n\n    print(\"Decoding step 0:\", \", \".join(pformat(token[2]) for token in logprobs))\n    for idx, (f, token) in enumerate(zip(forks, logprobs)):\n        logprob, token_id, text = token\n        f += text\n\n        if text == \"<|end_of_text|>\":\n            print(\n                f\"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score=nan, answer=nan){CLEAR}\"\n            )\n            continue\n\n        # continue greedy decoding\n        f += sgl.gen(\n            \"answer\",\n            temperature=0,\n            max_tokens=1024,\n            return_logprob=True,\n            top_logprobs_num=2,\n            return_text_in_logprobs=True,\n        )\n\n        # calculate probability disparity between the top and secondary tokens\n        x1s = [exp(xt[0][0]) for xt in f.get_meta_info(\"answer\")[\"output_top_logprobs\"]]\n        x2s = [exp(xt[1][0]) for xt in f.get_meta_info(\"answer\")[\"output_top_logprobs\"]]\n        tokens = [xt[0][2] for xt in f.get_meta_info(\"answer\")[\"output_top_logprobs\"]]\n        delta = (sum(x1s) - 
sum(x2s)) / len(x1s)\n\n        # extract the answer span (without the '<|end_of_text|>' token)\n        answer_forks[idx] += text + f[\"answer\"] + \"\\nSo the answer is\"\n        answer_forks[idx] += sgl.gen(\n            \"answer_span\",\n            temperature=0,\n            max_tokens=64,\n            return_logprob=True,\n            top_logprobs_num=2,\n            return_text_in_logprobs=True,\n        )\n        answer = answer_forks[idx][\"answer_span\"].replace(\"\\n\", \" \").strip(\":\")\n        print(\n            f\"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score={delta}, answer={answer}){CLEAR}\"\n        )\n        generated_text = str(answer_forks[idx])[len(\"ProgramState(\") : -1]\n        print(f\"{BLUE}{pformat(generated_text)}{CLEAR}\")\n\n        if verbose:\n            answer_tokens = [\n                xt[0][2]\n                for xt in answer_forks[idx].get_meta_info(\"answer_span\")[\n                    \"output_top_logprobs\"\n                ]\n            ]\n            answer_x1s = [\n                exp(xt[0][0])\n                for xt in answer_forks[idx].get_meta_info(\"answer_span\")[\n                    \"output_top_logprobs\"\n                ]\n            ]\n            answer_x2s = [\n                exp(xt[1][0])\n                for xt in answer_forks[idx].get_meta_info(\"answer_span\")[\n                    \"output_top_logprobs\"\n                ]\n            ]\n\n            for token, x1, x2 in zip(tokens, x1s, x2s):\n                print(f\" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})\", end=\"\")\n            print(\"\\n===========\")\n            for token, x1, x2 in zip(answer_tokens, answer_x1s, answer_x2s):\n                print(f\" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})\", end=\"\")\n            print()\n\n\nsgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n\nstate = cot_decoding.run(\n    question=r\"Claire makes a 3 egg omelet every morning 
for breakfast. How many dozens of eggs will she eat in 4 weeks?\",\n    get_top_k=10,\n    is_chat_model=True,\n    verbose=False,\n)\n"
  },
  {
    "path": "examples/frontend_language/usage/json_decode.py",
    "content": "\"\"\"\nUsage:\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\npython json_decode.py\n\"\"\"\n\nfrom enum import Enum\n\nfrom pydantic import BaseModel\n\nimport sglang as sgl\nfrom sglang.srt.constrained.outlines_backend import build_regex_from_object\n\ncharacter_regex = (\n    r\"\"\"\\{\\n\"\"\"\n    + r\"\"\"    \"name\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"    \"house\": \"(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)\",\\n\"\"\"\n    + r\"\"\"    \"blood status\": \"(Pure-blood|Half-blood|Muggle-born)\",\\n\"\"\"\n    + r\"\"\"    \"occupation\": \"(student|teacher|auror|ministry of magic|death eater|order of the phoenix)\",\\n\"\"\"\n    + r\"\"\"    \"wand\": \\{\\n\"\"\"\n    + r\"\"\"        \"wood\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"        \"core\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"        \"length\": [0-9]{1,2}\\.[0-9]{0,2}\\n\"\"\"\n    + r\"\"\"    \\},\\n\"\"\"\n    + r\"\"\"    \"alive\": \"(Alive|Deceased)\",\\n\"\"\"\n    + r\"\"\"    \"patronus\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n    + r\"\"\"    \"bogart\": \"[\\w\\d\\s]{1,16}\"\\n\"\"\"\n    + r\"\"\"\\}\"\"\"\n)\n\n\n@sgl.function\ndef character_gen(s, name):\n    s += (\n        name\n        + \" is a character in Harry Potter. 
Please fill in the following information about this character.\\n\"\n    )\n    s += \"The constrained regex is:\\n\"\n    s += character_regex + \"\\n\"\n    s += \"The JSON output is:\\n\"\n    s += sgl.gen(\"json_output\", max_tokens=256, regex=character_regex)\n\n\ndef driver_character_gen():\n    state = character_gen.run(name=\"Hermione Granger\")\n    print(state.text())\n\n\nclass Weapon(str, Enum):\n    sword = \"sword\"\n    axe = \"axe\"\n    mace = \"mace\"\n    spear = \"spear\"\n    bow = \"bow\"\n    crossbow = \"crossbow\"\n\n\nclass Wizard(BaseModel):\n    name: str\n    age: int\n    weapon: Weapon\n\n\n@sgl.function\ndef pydantic_wizard_gen(s):\n    s += \"Give me a description about a wizard in the JSON format.\\n\"\n    s += sgl.gen(\n        \"character\",\n        max_tokens=128,\n        temperature=0,\n        regex=build_regex_from_object(Wizard),  # Requires pydantic >= 2.0\n    )\n\n\ndef driver_pydantic_wizard_gen():\n    state = pydantic_wizard_gen.run()\n    print(state.text())\n\n\nif __name__ == \"__main__\":\n    sgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n    driver_character_gen()\n    # driver_pydantic_wizard_gen()\n"
  },
  {
    "path": "examples/frontend_language/usage/json_logprobs.py",
    "content": "# NOTE: Currently this can only be run through HTTP requests.\nfrom concurrent.futures import ThreadPoolExecutor\n\nfrom json_decode import character_regex\n\nfrom sglang.utils import http_request\n\ncharacter_names = [\"Hermione Granger\", \"Ron Weasley\", \"Harry Potter\"]\n\nbase_url = \"http://localhost:30000\"\n\nprompt = \"is a character in Harry Potter. Please fill in the following information about this character.\\n\"\n\n\ndef openai_api_request(name):\n    data = {\n        \"model\": \"\",\n        \"prompt\": name + prompt,\n        \"temperature\": 0,\n        \"max_tokens\": 128,\n        \"regex\": character_regex,\n        \"logprobs\": 3,\n    }\n    res = http_request(base_url + \"/v1/completions\", json=data).json()\n\n    # with open(f\"json_logprobs_{name.replace(' ', '_')}_tmp.json\", \"w\") as fout:\n    #     fout.write(json.dumps(res, indent=4))\n\n    logprobs = res[\"choices\"][0][\"logprobs\"]\n    usage = res[\"usage\"]\n    assert len(logprobs[\"token_logprobs\"]) == len(logprobs[\"tokens\"])\n    assert len(logprobs[\"token_logprobs\"]) == len(logprobs[\"top_logprobs\"])\n    assert len(logprobs[\"token_logprobs\"]) == usage[\"completion_tokens\"] - 1\n\n    return res\n\n\ndef srt_api_request(name):\n    data = {\n        \"text\": name + prompt,\n        \"sampling_params\": {\n            \"temperature\": 0,\n            \"max_new_tokens\": 128,\n            \"regex\": character_regex,\n        },\n        \"return_logprob\": True,\n        \"logprob_start_len\": 0,\n        \"top_logprobs_num\": 3,\n        \"return_text_in_logprobs\": True,\n    }\n\n    res = http_request(base_url + \"/generate\", json=data).json()\n\n    # with open(f\"json_logprobs_{name.replace(' ', '_')}_tmp.json\", \"w\") as fout:\n    #     fout.write(json.dumps(res, indent=4))\n\n    meta_info = res[\"meta_info\"]\n    assert len(meta_info[\"input_token_logprobs\"]) == len(\n        meta_info[\"input_top_logprobs\"]\n    )\n    assert 
len(meta_info[\"output_token_logprobs\"]) == len(\n        meta_info[\"output_top_logprobs\"]\n    )\n    assert len(meta_info[\"input_token_logprobs\"]) == meta_info[\"prompt_tokens\"]\n    assert len(meta_info[\"output_token_logprobs\"]) == meta_info[\"completion_tokens\"] - 1\n\n    return res\n\n\ndef pretty_print(res):\n    meta_info = res[\"meta_info\"]\n\n    print(\"\\n\\n\", \"=\" * 30, \"Prefill\", \"=\" * 30)\n    for i in range(len(meta_info[\"input_token_logprobs\"])):\n        print(f\"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}\", end=\"\")\n        top_ks = (\n            [str(t[2].encode()) for t in meta_info[\"input_top_logprobs\"][i]]\n            if meta_info[\"input_top_logprobs\"][i]\n            else []\n        )\n        for top_k in top_ks:\n            print(f\"{top_k: <15}\", end=\"\")\n        print()\n\n    print(\"\\n\\n\", \"=\" * 30, \"Decode\", \"=\" * 30)\n    for i in range(len(meta_info[\"output_token_logprobs\"])):\n        print(f\"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}\", end=\"\")\n        top_ks = [str(t[2].encode()) for t in meta_info[\"output_top_logprobs\"][i]]\n        for top_k in top_ks:\n            print(f\"{top_k: <15}\", end=\"\")\n        print()\n\n    print(res[\"text\"])\n\n\nif __name__ == \"__main__\":\n    with ThreadPoolExecutor() as executor:\n        ress = executor.map(srt_api_request, character_names)\n\n    for res in ress:\n        pretty_print(res)\n\n    openai_api_request(\"Hermione Granger\")\n"
  },
  {
    "path": "examples/frontend_language/usage/llava_video/srt_example_llava_v.py",
    "content": "\"\"\"\nUsage:\npip install opencv-python-headless\n\npython3 srt_example_llava_v.py\n\"\"\"\n\nimport argparse\nimport csv\nimport json\nimport os\nimport time\n\nimport requests\n\nimport sglang as sgl\n\n\n@sgl.function\ndef video_qa(s, num_frames, video_path, question):\n    s += sgl.user(sgl.video(video_path, num_frames) + question)\n    s += sgl.assistant(sgl.gen(\"answer\"))\n\n\ndef single(path, num_frames=16):\n    state = video_qa.run(\n        num_frames=num_frames,\n        video_path=path,\n        question=\"Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes\",\n        temperature=0.0,\n        max_new_tokens=1024,\n    )\n    print(state[\"answer\"], \"\\n\")\n\n\ndef split_into_chunks(lst, num_chunks):\n    \"\"\"Split a list into a specified number of chunks.\"\"\"\n    # Calculate the chunk size using integer division. Note that this may drop some items if not evenly divisible.\n    chunk_size = len(lst) // num_chunks\n\n    if chunk_size == 0:\n        chunk_size = len(lst)\n    # Use list comprehension to generate chunks. 
If len(lst) is not evenly divisible, the remainder spills into extra chunks beyond num_chunks.\n    chunks = [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]\n    # Pad with empty chunks if there are fewer than num_chunks (extra chunks are kept, not merged)\n    chunks.extend([[] for _ in range(num_chunks - len(chunks))])\n    return chunks\n\n\ndef save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir):\n    csv_filename = f\"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv\"\n    with open(csv_filename, \"w\", newline=\"\") as csvfile:\n        writer = csv.writer(csvfile)\n        writer.writerow([\"video_name\", \"answer\"])\n        for video_path, state in zip(batch_video_files, states):\n            video_name = os.path.basename(video_path)\n            writer.writerow([video_name, state[\"answer\"]])\n\n\ndef compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir):\n    final_csv_filename = f\"{save_dir}/final_results_chunk_{cur_chunk}.csv\"\n    with open(final_csv_filename, \"w\", newline=\"\") as final_csvfile:\n        writer = csv.writer(final_csvfile)\n        writer.writerow([\"video_name\", \"answer\"])\n        for batch_idx in range(num_batches):\n            batch_csv_filename = f\"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv\"\n            with open(batch_csv_filename, \"r\") as batch_csvfile:\n                reader = csv.reader(batch_csvfile)\n                next(reader)  # Skip header row\n                for row in reader:\n                    writer.writerow(row)\n            os.remove(batch_csv_filename)\n\n\ndef find_video_files(video_dir):\n    # Check if the video_dir is actually a file\n    if os.path.isfile(video_dir):\n        # If it's a file, return it as a single-element list\n        return [video_dir]\n\n    # Original logic to find video files in a directory\n    video_files = []\n    for root, dirs, files in os.walk(video_dir):\n        for file in files:\n            if file.endswith((\".mp4\", \".avi\", 
\".mov\")):\n                video_files.append(os.path.join(root, file))\n    return video_files\n\n\ndef batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64):\n    video_files = find_video_files(video_dir)\n    chunked_video_files = split_into_chunks(video_files, num_chunks)[cur_chunk]\n    num_batches = 0\n\n    for i in range(0, len(chunked_video_files), batch_size):\n        batch_video_files = chunked_video_files[i : i + batch_size]\n        print(f\"Processing batch of {len(batch_video_files)} video(s)...\")\n\n        if not batch_video_files:\n            print(\"No video files found in the specified directory.\")\n            return\n\n        batch_input = [\n            {\n                \"num_frames\": num_frames,\n                \"video_path\": video_path,\n                \"question\": \"Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.\",\n            }\n            for video_path in batch_video_files\n        ]\n\n        start_time = time.perf_counter()\n        states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)\n        total_time = time.perf_counter() - start_time\n        average_time = total_time / len(batch_video_files)\n        print(\n            f\"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. 
Total time for this batch: {total_time:.2f} seconds\"\n        )\n\n        save_batch_results(batch_video_files, states, cur_chunk, num_batches, save_dir)\n        num_batches += 1\n\n    compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir)\n\n\nif __name__ == \"__main__\":\n\n    url = \"https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4\"\n\n    cache_dir = os.path.expanduser(\"~/.cache\")\n    file_path = os.path.join(cache_dir, \"jobs.mp4\")\n\n    os.makedirs(cache_dir, exist_ok=True)\n\n    response = requests.get(url)\n    response.raise_for_status()  # Raise an exception for bad responses\n\n    with open(file_path, \"wb\") as f:\n        f.write(response.content)\n\n    print(f\"File downloaded and saved to: {file_path}\")\n    # Create the parser\n    parser = argparse.ArgumentParser(\n        description=\"Run video processing with specified port.\"\n    )\n\n    # Add an argument for the port\n    parser.add_argument(\n        \"--port\",\n        type=int,\n        default=30000,\n        help=\"The master port for distributed serving.\",\n    )\n    parser.add_argument(\n        \"--chunk-idx\", type=int, default=0, help=\"The index of the chunk to process.\"\n    )\n    parser.add_argument(\n        \"--num-chunks\", type=int, default=8, help=\"The number of chunks to process.\"\n    )\n    parser.add_argument(\n        \"--save-dir\",\n        type=str,\n        default=\"./work_dirs/llava_video\",\n        help=\"The directory to save the processed video files.\",\n    )\n    parser.add_argument(\n        \"--video-dir\",\n        type=str,\n        default=os.path.expanduser(\"~/.cache/jobs.mp4\"),\n        help=\"The directory or path for the processed video files.\",\n    )\n    parser.add_argument(\n        \"--model-path\",\n        type=str,\n        default=\"lmms-lab/LLaVA-NeXT-Video-7B\",\n        help=\"The model path for the video processing.\",\n    )\n    parser.add_argument(\n 
       \"--num-frames\",\n        type=int,\n        default=16,\n        help=\"The number of frames to process in each video.\",\n    )\n    parser.add_argument(\"--mm_spatial_pool_stride\", type=int, default=2)\n\n    # Parse the arguments\n    args = parser.parse_args()\n    cur_port = args.port\n    cur_chunk = args.chunk_idx\n    num_chunks = args.num_chunks\n    num_frames = args.num_frames\n\n    if \"34b\" in args.model_path.lower():\n        tokenizer_path = \"liuhaotian/llava-v1.6-34b-tokenizer\"\n    elif \"7b\" in args.model_path.lower():\n        tokenizer_path = \"llava-hf/llava-1.5-7b-hf\"\n    else:\n        print(\"Invalid model path. Please specify a valid model path.\")\n        exit()\n\n    model_override_args = {}\n    model_override_args[\"mm_spatial_pool_stride\"] = args.mm_spatial_pool_stride\n    model_override_args[\"architectures\"] = [\"LlavaVidForCausalLM\"]\n    model_override_args[\"num_frames\"] = args.num_frames\n    model_override_args[\"model_type\"] = \"llava\"\n\n    if \"34b\" in args.model_path.lower():\n        model_override_args[\"image_token_index\"] = 64002\n\n    if args.num_frames == 32:\n        model_override_args[\"rope_scaling\"] = {\"factor\": 2.0, \"rope_type\": \"linear\"}\n        model_override_args[\"max_sequence_length\"] = 4096 * 2\n        model_override_args[\"tokenizer_model_max_length\"] = 4096 * 2\n    elif args.num_frames < 32:\n        pass\n    else:\n        print(\n            \"The maximum number of frames to process is 32. 
Please specify a valid number of frames.\"\n        )\n        exit()\n\n    runtime = sgl.Runtime(\n        model_path=args.model_path,  # \"liuhaotian/llava-v1.6-vicuna-7b\",\n        tokenizer_path=tokenizer_path,\n        port=cur_port,\n        json_model_override_args=json.dumps(model_override_args),\n        tp_size=1,\n    )\n    sgl.set_default_backend(runtime)\n    print(f\"chat template: {runtime.endpoint.chat_template.name}\")\n\n    # Run a single request\n    print(\"\\n========== single ==========\\n\")\n    root = args.video_dir\n    if os.path.isfile(root):\n        video_files = [root]\n    else:\n        video_files = [\n            os.path.join(root, f)\n            for f in os.listdir(root)\n            if f.endswith((\".mp4\", \".avi\", \".mov\"))\n        ]  # Add more extensions if needed\n    start_time = time.perf_counter()  # Start time for processing a single video\n    for cur_video in video_files[:1]:\n        print(cur_video)\n        single(cur_video, num_frames)\n    end_time = time.perf_counter()  # End time for processing a single video\n    total_time = end_time - start_time\n    average_time = total_time / max(\n        len(video_files[:1]), 1\n    )  # Average over the videos actually processed (only the first one here)\n    print(f\"Average processing time per video: {average_time:.2f} seconds\")\n    runtime.shutdown()\n\n    # # Run a batch of requests\n    # print(\"\\n========== batch ==========\\n\")\n    # if not os.path.exists(args.save_dir):\n    #     os.makedirs(args.save_dir)\n    # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks)\n    # runtime.shutdown()\n"
  },
  {
    "path": "examples/frontend_language/usage/llava_video/srt_example_llava_v.sh",
    "content": "#!/bin/bash\n\n##### USAGE #####\n#    - First node:\n#      ```sh\n#      bash examples/frontend_language/usage/llava_video/srt_example_llava_v.sh K 0 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO\n#      ```\n#    - Second node:\n#      ```sh\n#      bash examples/frontend_language/usage/llava_video/srt_example_llava_v.sh K 1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO\n#      ```\n#    - The K-th (last) node:\n#      ```sh\n#      bash examples/frontend_language/usage/llava_video/srt_example_llava_v.sh K K-1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO\n#      ```\n\n\n# Replace `K`, `YOUR_VIDEO_PATH`, `YOUR_MODEL_PATH`, and `FRAMES_PER_VIDEO` with your specific details.\n# CURRENT_ROOT=\"$( cd \"$( dirname \"${BASH_SOURCE[0]}\" )\" && pwd )\"\nCURRENT_ROOT=$(dirname \"$0\")\n\necho ${CURRENT_ROOT}\n\ncd ${CURRENT_ROOT}\n\nexport PYTHONWARNINGS=ignore\n\nSTART_TIME=$(date +%s)  # Capture start time\n\nNUM_NODES=$1\n\nCUR_NODES_IDX=$2\n\nVIDEO_DIR=$3\n\nMODEL_PATH=$4\n\nNUM_FRAMES=$5\n\n\n# FRAME_FORMAT=$6\n\n# FRAME_FORMAT=$(echo $FRAME_FORMAT | tr '[:lower:]' '[:upper:]')\n\n# # Check if FRAME_FORMAT is either JPEG or PNG\n# if [[ \"$FRAME_FORMAT\" != \"JPEG\" && \"$FRAME_FORMAT\" != \"PNG\" ]]; then\n#     echo \"Error: FRAME_FORMAT must be either JPEG or PNG.\"\n#     exit 1\n# fi\n\n# export TARGET_FRAMES=$TARGET_FRAMES\n\necho \"Sampling $NUM_FRAMES frames per video\"\n\n# export FRAME_FORMAT=$FRAME_FORMAT\n\n# echo \"The frame format is $FRAME_FORMAT\"\n\n# Assuming GPULIST is a bash array containing your GPUs\nGPULIST=(0 1 2 3 4 5 6 7)\nLOCAL_CHUNKS=${#GPULIST[@]}\n\necho \"Number of GPUs in GPULIST: $LOCAL_CHUNKS\"\n\nALL_CHUNKS=$((NUM_NODES * LOCAL_CHUNKS))\n\n# Calculate GPUs per chunk\nGPUS_PER_CHUNK=1\n\necho $GPUS_PER_CHUNK\n\nfor IDX in $(seq 1 $LOCAL_CHUNKS); do\n    (\n        START=$(((IDX-1) * GPUS_PER_CHUNK))\n        LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index\n\n        CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH})\n\n        # Convert the chunk GPUs 
array to a comma-separated string\n        CHUNK_GPUS_STR=$(IFS=,; echo \"${CHUNK_GPUS[*]}\")\n\n        LOCAL_IDX=$((CUR_NODES_IDX * LOCAL_CHUNKS + IDX))\n\n        echo \"Chunk $(($LOCAL_IDX - 1)) will run on GPUs $CHUNK_GPUS_STR\"\n\n        # Pick a random port for this chunk to reduce the chance of collisions between chunks.\n        PORT=$((10000 + RANDOM % 55536))\n\n        MAX_RETRIES=10\n        RETRY_COUNT=0\n        COMMAND_STATUS=1  # Initialize as failed\n\n        while [ $RETRY_COUNT -lt $MAX_RETRIES ] && [ $COMMAND_STATUS -ne 0 ]; do\n            echo \"Running chunk $(($LOCAL_IDX - 1)) on GPUs $CHUNK_GPUS_STR with port $PORT. Attempt $(($RETRY_COUNT + 1))\"\n\n            CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 srt_example_llava_v.py \\\n            --port $PORT \\\n            --num-chunks $ALL_CHUNKS \\\n            --chunk-idx $(($LOCAL_IDX - 1)) \\\n            --save-dir work_dirs/llava_next_video_inference_results \\\n            --video-dir $VIDEO_DIR \\\n            --model-path $MODEL_PATH \\\n            --num-frames $NUM_FRAMES\n\n            # The command above runs in the foreground, so $? is its exit status\n            COMMAND_STATUS=$?\n\n            if [ $COMMAND_STATUS -ne 0 ]; then\n                echo \"Execution failed for chunk $(($LOCAL_IDX - 1)), attempt $(($RETRY_COUNT + 1)). 
Retrying...\"\n                RETRY_COUNT=$(($RETRY_COUNT + 1))\n                sleep 180  # Wait a bit before retrying\n            else\n                echo \"Execution succeeded for chunk $(($LOCAL_IDX - 1)).\"\n            fi\n        done\n\n        if [ $COMMAND_STATUS -ne 0 ]; then\n            echo \"Execution failed for chunk $(($LOCAL_IDX - 1)) after $MAX_RETRIES attempts.\"\n        fi\n    ) #&\n    sleep 2  # Slight delay to stagger the start times\ndone\n\nwait\n\ncat work_dirs/llava_next_video_inference_results/final_results_chunk_*.csv > work_dirs/llava_next_video_inference_results/final_results_node_${CUR_NODES_IDX}.csv\n\nEND_TIME=$(date +%s)  # Capture end time\nELAPSED_TIME=$(($END_TIME - $START_TIME))\necho \"Total execution time: $ELAPSED_TIME seconds.\"\n"
  },
  {
    "path": "examples/frontend_language/usage/openai_chat_speculative.py",
    "content": "\"\"\"\nUsage:\n***Note: for speculative execution to work, user must put all \"gen\" in \"assistant\".\nShow in \"assistant\" the desired answer format. Each \"gen\" term should have a stop token.\nThe stream mode is not supported in speculative execution.\n\nE.g.\ncorrect:\n    sgl.assistant(\"\\nName:\" + sgl.gen(\"name\", stop=\"\\n\") + \"\\nBirthday:\" + sgl.gen(\"birthday\", stop=\"\\n\") + \"\\nJob:\" + sgl.gen(\"job\", stop=\"\\n\"))\nincorrect:\n    s += sgl.assistant(\"\\nName:\" + sgl.gen(\"name\", stop=\"\\n\"))\n    s += sgl.assistant(\"\\nBirthday:\" + sgl.gen(\"birthday\", stop=\"\\n\"))\n    s += sgl.assistant(\"\\nJob:\" + sgl.gen(\"job\", stop=\"\\n\"))\n\nexport OPENAI_API_KEY=sk-******\npython3 openai_chat_speculative.py\n\"\"\"\n\nimport sglang as sgl\nfrom sglang import OpenAI, function, set_default_backend\n\n\n@function(num_api_spec_tokens=256)\ndef gen_character_spec(s):\n    s += sgl.system(\"You are a helpful assistant.\")\n    s += sgl.user(\"Construct a character within the following format:\")\n    s += sgl.assistant(\n        \"Name: Steve Jobs.\\nBirthday: February 24, 1955.\\nJob: Apple CEO.\\n\"\n    )\n    s += sgl.user(\"Please generate new Name, Birthday and Job.\\n\")\n    s += sgl.assistant(\n        \"Name:\"\n        + sgl.gen(\"name\", stop=\"\\n\")\n        + \"\\nBirthday:\"\n        + sgl.gen(\"birthday\", stop=\"\\n\")\n        + \"\\nJob:\"\n        + sgl.gen(\"job\", stop=\"\\n\")\n    )\n\n\n@function(num_api_spec_tokens=256)\ndef gen_character_spec_no_few_shot(s):\n    s += sgl.user(\"Construct a character. 
For each field stop with a newline\\n\")\n    s += sgl.assistant(\n        \"Name:\"\n        + sgl.gen(\"name\", stop=\"\\n\")\n        + \"\\nAge:\"\n        + sgl.gen(\"age\", stop=\"\\n\")\n        + \"\\nJob:\"\n        + sgl.gen(\"job\", stop=\"\\n\")\n    )\n\n\n@function\ndef gen_character_normal(s):\n    s += sgl.system(\"You are a helpful assistant.\")\n    s += sgl.user(\"What's the answer of 23 + 8?\")\n    s += sgl.assistant(sgl.gen(\"answer\", max_tokens=64))\n\n\n@function(num_api_spec_tokens=1024)\ndef multi_turn_question(s, question_1, question_2):\n    s += sgl.system(\"You are a helpful assistant.\")\n    s += sgl.user(\"Answer questions in the following format:\")\n    s += sgl.user(\n        \"Question 1: What is the capital of France?\\nQuestion 2: What is the population of this city?\\n\"\n    )\n    s += sgl.assistant(\n        \"Answer 1: The capital of France is Paris.\\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\\n\"\n    )\n    s += sgl.user(\"Question 1: \" + question_1 + \"\\nQuestion 2: \" + question_2)\n    s += sgl.assistant(\n        \"Answer 1: \"\n        + sgl.gen(\"answer_1\", stop=\"\\n\")\n        + \"\\nAnswer 2: \"\n        + sgl.gen(\"answer_2\", stop=\"\\n\")\n    )\n\n\ndef test_spec_single_turn():\n    backend.token_usage.reset()\n\n    state = gen_character_spec.run()\n    for m in state.messages():\n        print(m[\"role\"], \":\", m[\"content\"])\n\n    print(\"\\n-- name:\", state[\"name\"])\n    print(\"-- birthday:\", state[\"birthday\"])\n    print(\"-- job:\", state[\"job\"])\n    print(backend.token_usage)\n\n\ndef test_inaccurate_spec_single_turn():\n    state = gen_character_spec_no_few_shot.run()\n    for m in state.messages():\n        print(m[\"role\"], \":\", m[\"content\"])\n\n    print(\"\\n-- name:\", state[\"name\"])\n    print(\"\\n-- age:\", state[\"age\"])\n    print(\"\\n-- job:\", state[\"job\"])\n\n\ndef test_normal_single_turn():\n    
state = gen_character_normal.run()\n    for m in state.messages():\n        print(m[\"role\"], \":\", m[\"content\"])\n\n\ndef test_spec_multi_turn():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions in the capital of the United States.\",\n    )\n\n    for m in state.messages():\n        print(m[\"role\"], \":\", m[\"content\"])\n\n    print(\"\\n-- answer_1 --\\n\", state[\"answer_1\"])\n    print(\"\\n-- answer_2 --\\n\", state[\"answer_2\"])\n\n\ndef test_spec_multi_turn_stream():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n        stream=True,\n    )\n\n    for out in state.text_iter():\n        print(out, end=\"\", flush=True)\n\n\nif __name__ == \"__main__\":\n    backend = OpenAI(\"gpt-4-turbo\")\n    set_default_backend(backend)\n\n    print(\"\\n========== test spec single turn ==========\\n\")\n    # expect reasonable answer for each field\n    test_spec_single_turn()\n\n    print(\"\\n========== test inaccurate spec single turn ==========\\n\")\n    # expect incomplete or unreasonable answers\n    test_inaccurate_spec_single_turn()\n\n    print(\"\\n========== test normal single turn ==========\\n\")\n    # expect reasonable answer\n    test_normal_single_turn()\n\n    print(\"\\n========== test spec multi turn ==========\\n\")\n    # expect answer with same format as in the few shot\n    test_spec_multi_turn()\n\n    print(\"\\n========== test spec multi turn stream ==========\\n\")\n    # expect error in stream_executor: stream is not supported...\n    test_spec_multi_turn_stream()\n"
  },
  {
    "path": "examples/frontend_language/usage/openai_speculative.py",
    "content": "\"\"\"\nUsage:\npython3 openai_speculative.py\n\"\"\"\n\nfrom sglang import OpenAI, function, gen, set_default_backend\n\n\n@function(num_api_spec_tokens=64)\ndef gen_character_spec(s):\n    s += \"Construct a character within the following format:\\n\"\n    s += \"Name: Steve Jobs.\\nBirthday: February 24, 1955.\\nJob: Apple CEO.\\n\"\n    s += \"\\nPlease generate new Name, Birthday and Job.\\n\"\n    s += \"Name:\" + gen(\"name\", stop=\"\\n\") + \"\\nBirthday:\" + gen(\"birthday\", stop=\"\\n\")\n    s += \"\\nJob:\" + gen(\"job\", stop=\"\\n\") + \"\\n\"\n\n\n@function\ndef gen_character_no_spec(s):\n    s += \"Construct a character within the following format:\\n\"\n    s += \"Name: Steve Jobs.\\nBirthday: February 24, 1955.\\nJob: Apple CEO.\\n\"\n    s += \"\\nPlease generate new Name, Birthday and Job.\\n\"\n    s += \"Name:\" + gen(\"name\", stop=\"\\n\") + \"\\nBirthday:\" + gen(\"birthday\", stop=\"\\n\")\n    s += \"\\nJob:\" + gen(\"job\", stop=\"\\n\") + \"\\n\"\n\n\n@function(num_api_spec_tokens=64)\ndef gen_character_spec_no_few_shot(s):\n    # s += \"Construct a character with name, birthday, and job:\\n\"\n    s += \"Construct a character:\\n\"\n    s += \"Name:\" + gen(\"name\", stop=\"\\n\") + \"\\nBirthday:\" + gen(\"birthday\", stop=\"\\n\")\n    s += \"\\nJob:\" + gen(\"job\", stop=\"\\n\") + \"\\n\"\n\n\nif __name__ == \"__main__\":\n    backend = OpenAI(\"gpt-3.5-turbo-instruct\")\n    set_default_backend(backend)\n\n    for function in [\n        gen_character_spec,\n        gen_character_no_spec,\n        gen_character_spec_no_few_shot,\n    ]:\n        backend.token_usage.reset()\n\n        print(f\"function: {function.func.__name__}\")\n\n        state = function.run()\n\n        print(\"...name:\", state[\"name\"])\n        print(\"...birthday:\", state[\"birthday\"])\n        print(\"...job:\", state[\"job\"])\n        print(backend.token_usage)\n        print()\n"
  },
  {
    "path": "examples/frontend_language/usage/parallel_sample.py",
    "content": "\"\"\"\nUsage:\npython3 parallel_sample.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef parallel_sample(s, question, n):\n    s += (\n        \"Question: Compute 1 + 2 + 3\\n\"\n        \"Reasoning: I need to use a calculator.\\n\"\n        \"Tool: calculator\\n\"\n        \"Answer: 6\\n\"\n        \"Question: Compute 3 + 2 + 2\\n\"\n        \"Reasoning: I will try a calculator.\\n\"\n        \"Tool: calculator\\n\"\n        \"Answer: 7\\n\"\n    )\n    s += \"Question: \" + question + \"\\n\"\n    forks = s.fork(n)\n    forks += \"Reasoning:\" + sgl.gen(\"reasoning\", stop=\"\\n\") + \"\\n\"\n    forks += \"Tool:\" + sgl.gen(\"tool\", choices=[\"calculator\", \"browser\"]) + \"\\n\"\n    forks += \"Answer:\" + sgl.gen(\"answer\", stop=\"\\n\") + \"\\n\"\n    forks.join()\n\n\nsgl.set_default_backend(sgl.OpenAI(\"gpt-3.5-turbo-instruct\"))\n# sgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n\nstate = parallel_sample.run(question=\"Compute 5 + 2 + 4.\", n=5, temperature=1.0)\n\nfor i in range(5):\n    obj = {\n        \"reasoning\": state[\"reasoning\"][i],\n        \"tool\": state[\"tool\"][i],\n        \"answer\": state[\"answer\"][i],\n    }\n    print(f\"[{i}], {obj}\")\n"
  },
  {
    "path": "examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# RAG Powered by SGLang & Chroma Evaluated using Parea\\n\",\n    \"\\n\",\n    \"In this notebook, we will build a simple RAG pipeline using SGLang to execute our LLM calls, Chroma as the vector database for retrieval, and [Parea](https://www.parea.ai) for tracing and evaluation. We will then evaluate the performance of our RAG pipeline. The dataset we will use was created by [Virat](https://twitter.com/virattt) and contains 100 questions, contexts, and answers from the Airbnb 2023 10-K filing.\\n\",\n    \"\\n\",\n    \"The RAG pipeline consists of two steps:\\n\",\n    \"1. Retrieval: Given a question, we retrieve the relevant context from all provided contexts.\\n\",\n    \"2. Generation: Given the question and the retrieved context, we generate an answer.\\n\",\n    \"\\n\",\n    \"ℹ️ This notebook requires an OpenAI API key.\\n\",\n    \"\\n\",\n    \"ℹ️ This notebook requires a Parea API key, which can be created [here](https://docs.parea.ai/api-reference/authentication#parea-api-key).\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Setting up the environment\\n\",\n    \"\\n\",\n    \"We will first install the necessary packages: `sglang`, `parea-ai`, and `chromadb`.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Note: on a Mac with an M1 chip, you might need to install grpcio 1.59.0 first so that installing chromadb works\\n\",\n    \"# !pip install grpcio==1.59.0\\n\",\n    \"\\n\",\n    \"!pip install sglang[openai] parea-ai chromadb\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Create a Parea API key as outlined [here](https://docs.parea.ai/api-reference/authentication#parea-api-key) and save it in a `.env` file as 
`PAREA_API_KEY=your-api-key`.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Indexing the data\\n\",\n    \"\\n\",\n    \"Now it's time to download the data & index it! For that, we create a collection called `contexts` in Chroma and add the contexts as documents.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import json\\n\",\n    \"import os\\n\",\n    \"from typing import List\\n\",\n    \"\\n\",\n    \"import chromadb\\n\",\n    \"\\n\",\n    \"path_qca = \\\"airbnb-2023-10k-qca.json\\\"\\n\",\n    \"\\n\",\n    \"if not os.path.exists(path_qca):\\n\",\n    \"    !wget https://virattt.github.io/datasets/abnb-2023-10k.json -O airbnb-2023-10k-qca.json\\n\",\n    \"\\n\",\n    \"with open(path_qca, \\\"r\\\") as f:\\n\",\n    \"    question_context_answers = json.load(f)\\n\",\n    \"\\n\",\n    \"chroma_client = chromadb.PersistentClient()\\n\",\n    \"collection = chroma_client.get_or_create_collection(name=\\\"contexts\\\")\\n\",\n    \"if collection.count() == 0:\\n\",\n    \"    collection.add(\\n\",\n    \"        documents=[qca[\\\"context\\\"] for qca in question_context_answers],\\n\",\n    \"        ids=[str(i) for i in range(len(question_context_answers))],\\n\",\n    \"    )\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Defining the RAG pipeline\\n\",\n    \"\\n\",\n    \"We will start with importing the necessary packages, setting up tracing of OpenAI calls via Parea and setting OpenAI as the default backend for SGLang.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import time\\n\",\n    \"\\n\",\n    \"from dotenv import load_dotenv\\n\",\n    \"\\n\",\n    \"from sglang import function, user, assistant, gen, 
set_default_backend, OpenAI\\n\",\n    \"from sglang.lang.interpreter import ProgramState\\n\",\n    \"from parea import Parea, trace\\n\",\n    \"\\n\",\n    \"load_dotenv()\\n\",\n    \"\\n\",\n    \"os.environ[\\\"TOKENIZERS_PARALLELISM\\\"] = \\\"false\\\"\\n\",\n    \"\\n\",\n    \"p = Parea(api_key=os.getenv(\\\"PAREA_API_KEY\\\"), project_name=\\\"rag_sglang\\\")\\n\",\n    \"p.integrate_with_sglang()\\n\",\n    \"\\n\",\n    \"set_default_backend(OpenAI(\\\"gpt-3.5-turbo\\\"))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Now we can define our retrieval step, shown below. Notice the `trace` decorator, which automatically traces the inputs, output, latency, etc. of that call.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@trace\\n\",\n    \"def retrieval(question: str) -> List[str]:\\n\",\n    \"    return collection.query(query_texts=[question], n_results=1)[\\\"documents\\\"][0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Next, we define the generation step, which uses SGLang to execute the LLM call.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@function\\n\",\n    \"def generation_sglang(s, question: str, *context: str):\\n\",\n    \"    context = \\\"\\\\n\\\".join(context)\\n\",\n    \"    s += user(\\n\",\n    \"        f\\\"Given this question:\\\\n{question}\\\\n\\\\nAnd this context:\\\\n{context}\\\\n\\\\nAnswer the question.\\\"\\n\",\n    \"    )\\n\",\n    \"    s += assistant(gen(\\\"answer\\\"))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@trace\\n\",\n    \"def generation(question: str, *context):\\n\",\n    \"    state: ProgramState = generation_sglang.run(question, *context)\\n\",\n    \"    while not 
state.stream_executor.is_finished:\\n\",\n    \"        time.sleep(1)\\n\",\n    \"    return state.stream_executor.variables[\\\"answer\\\"]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Finally, we can tie it together and execute a sample query.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@trace\\n\",\n    \"def rag_pipeline(question: str) -> str:\\n\",\n    \"    contexts = retrieval(question)\\n\",\n    \"    return generation(question, *contexts)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"rag_pipeline(\\n\",\n    \"    \\\"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\\\"\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Debug Trace\\n\",\n    \"\\n\",\n    \"The output is unfortunately wrong! Using the traced pipeline, we can see that\\n\",\n    \"\\n\",\n    \"- the context is relevant to the question and contains the correct information\\n\",\n    \"- but the generation step is cut off because max tokens is set to 16\\n\",\n    \"\\n\",\n    \"When opening the generation step in the playground and rerunning the prompt with max tokens set to 1000, the correct answer is produced.\\n\",\n    \"\\n\",\n    \"![RAG Trace](https://drive.google.com/uc?id=1QI243ogGjzbO01tUrR72g9rFoGzUJqVH)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Evaluating RAG Pipelines\\n\",\n    \"\\n\",\n    \"Before we apply the above fix, let's dive into evaluating RAG pipelines.\\n\",\n    \"\\n\",\n    \"RAG pipelines consist of a retrieval step to fetch relevant information and a generation step to generate a response to a user's question. A RAG pipeline can fail at either step. E.g. 
the retrieval step can fail to find relevant information, which makes generating the correct answer impossible. Another failure mode is that the generation step doesn't leverage the retrieved information correctly. We will apply the following evaluation metrics to understand different failure modes:\\n\",\n    \"\\n\",\n    \"- `context_relevancy`: measures how relevant the context is given the question\\n\",\n    \"- `percent_target_supported_by_context`: measures how much of the target answer is supported by the context; this gives an upper bound on how well the generation step can perform\\n\",\n    \"- `answer_context_faithfulness`: measures how much the generated answer utilizes the context\\n\",\n    \"- `answer_matches_target`: measures how well the generated answer matches the target answer, as judged by an LLM, and gives a sense of the accuracy of our entire pipeline\\n\",\n    \"\\n\",\n    \"To use these evaluation metrics, we can import them from `parea.evals.rag` and `parea.evals.general` and apply them to a function by specifying in the `trace` decorator which evaluation metrics to use. 
The `@trace` decorator will automatically log the results of the evaluation metrics to the Parea dashboard.\\n\",\n    \"\\n\",\n    \"Applying them to the retrieval step:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from parea.evals.rag import (\\n\",\n    \"    context_query_relevancy_factory,\\n\",\n    \"    percent_target_supported_by_context_factory,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"context_relevancy_eval = context_query_relevancy_factory()\\n\",\n    \"percent_target_supported_by_context = percent_target_supported_by_context_factory()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\\n\",\n    \"def retrieval(question: str) -> List[str]:\\n\",\n    \"    return collection.query(query_texts=[question], n_results=1)[\\\"documents\\\"][0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Now we can apply `answer_context_faithfulness` and `answer_matches_target` to the generation step.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from parea.evals.general import answer_matches_target_llm_grader_factory\\n\",\n    \"from parea.evals.rag import answer_context_faithfulness_statement_level_factory\\n\",\n    \"\\n\",\n    \"answer_context_faithfulness = answer_context_faithfulness_statement_level_factory()\\n\",\n    \"answer_matches_target_llm_grader = answer_matches_target_llm_grader_factory()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@function\\n\",\n    \"def generation_sglang(s, question: str, *context: str):\\n\",\n    \"    context = \\\"\\\\n\\\".join(context)\\n\",\n    \"    s += user(\\n\",\n    \"        f\\\"Given this question:\\\\n{question}\\\\n\\\\nAnd this context:\\\\n{context}\\\\n\\\\nAnswer the question.\\\"\\n\",\n   
 \"    )\\n\",\n    \"    s += assistant(gen(\\\"answer\\\", max_tokens=1_000))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@trace(eval_funcs=[answer_context_faithfulness, answer_matches_target_llm_grader])\\n\",\n    \"def generation(question: str, *context):\\n\",\n    \"    state: ProgramState = generation_sglang.run(question, *context)\\n\",\n    \"    while not state.stream_executor.is_finished:\\n\",\n    \"        time.sleep(1)\\n\",\n    \"    return state.stream_executor.variables[\\\"answer\\\"]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Finally, we tie them together & execute the original sample query.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@trace\\n\",\n    \"def rag_pipeline(question: str) -> str:\\n\",\n    \"    contexts = retrieval(question)\\n\",\n    \"    return generation(question, *contexts)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"rag_pipeline(\\n\",\n    \"    \\\"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\\\"\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Great, the answer is correct! Can you spot the line where we fixed the output truncation issue?\\n\",\n    \"\\n\",\n    \"The evaluation scores appear in the bottom right of the logs (screenshot below). Note that there is no score for `answer_matches_target_llm_grader` and `percent_target_supported_by_context` because these evals are automatically skipped if the target answer is not provided.\\n\",\n    \"\\n\",\n    \"![Fixed Max. 
Tokens](max-tokens-fixed-rag-trace.png)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Running an experiment\\n\",\n    \"\\n\",\n    \"Now we are (almost) ready to evaluate the performance of our RAG pipeline on the entire dataset. First, we need to apply the `nest_asyncio` package to avoid issues with the Jupyter notebook event loop.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"!pip install nest-asyncio\\n\",\n    \"import nest_asyncio\\n\",\n    \"\\n\",\n    \"nest_asyncio.apply()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Running the actual experiment is straightforward. For that, we use `p.experiment` to initialize the experiment with a name, the data (a list of key-value pairs fed into our entry function), and the entry function. We then call `run` on the experiment to execute it. Note that `target` is a reserved key in the data dictionary and will be used as the target answer for evaluation.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"e = p.experiment(\\n\",\n    \"    \\\"RAG\\\",\\n\",\n    \"    data=[\\n\",\n    \"        {\\n\",\n    \"            \\\"question\\\": qca[\\\"question\\\"],\\n\",\n    \"            \\\"target\\\": qca[\\\"answer\\\"],\\n\",\n    \"        }\\n\",\n    \"        for qca in question_context_answers\\n\",\n    \"    ],\\n\",\n    \"    func=rag_pipeline,\\n\",\n    \").run()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Analyzing the results\\n\",\n    \"\\n\",\n    \"When opening the above experiment, we see an overview of the experiment as shown below. 
The upper half shows a summary of the statistics on the left and charts to investigate the distribution and relationships of scores on the right. The lower half is a table with the individual traces, which we can use to debug individual samples.\\n\",\n    \"\\n\",\n    \"Looking at the statistics, we can see that the accuracy of our RAG pipeline is 22% as measured by `answer_matches_target_llm_grader`. However, when checking the quality of our retrieval step (`context_query_relevancy`), we can see that it fetches relevant information in only 27% of all samples. As shown in the GIF, we investigate the relationship between the two and see that the two scores have 95% agreement. This confirms that the retrieval step is a major bottleneck for our RAG pipeline. So, now it's your turn to improve the retrieval step!\\n\",\n    \"\\n\",\n    \"Note: the above link isn't publicly accessible, but the experiment can be accessed [here](https://app.parea.ai/public-experiments/parea/rag_sglang/30f0244a-d56c-44ff-bdfb-8f47626304b6).\\n\",\n    \"\\n\",\n    \"![Experiment Results](https://drive.google.com/uc?id=1KMtJBU47nPB02Pvv3SPPTK7RnHRh5YdA)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 2\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython2\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "examples/frontend_language/usage/readme_examples.py",
    "content": "\"\"\"\nUsage:\npython -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000\npython readme_examples.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef tool_use(s, question):\n    s += \"To answer this question: \" + question + \". \"\n    s += (\n        \"I need to use a \"\n        + sgl.gen(\"tool\", choices=[\"calculator\", \"search engine\"])\n        + \". \"\n    )\n\n    if s[\"tool\"] == \"calculator\":\n        s += \"The math expression is\" + sgl.gen(\"expression\")\n    elif s[\"tool\"] == \"search engine\":\n        s += \"The key word to search is\" + sgl.gen(\"word\")\n\n\n@sgl.function\ndef tip_suggestion(s):\n    s += (\n        \"Here are two tips for staying healthy: \"\n        \"1. Balanced Diet. 2. Regular Exercise.\\n\\n\"\n    )\n\n    forks = s.fork(2)\n    for i, f in enumerate(forks):\n        f += f\"Now, expand tip {i+1} into a paragraph:\\n\"\n        f += sgl.gen(f\"detailed_tip\", max_tokens=256, stop=\"\\n\\n\")\n\n    s += \"Tip 1:\" + forks[0][\"detailed_tip\"] + \"\\n\"\n    s += \"Tip 2:\" + forks[1][\"detailed_tip\"] + \"\\n\"\n    s += \"In summary\" + sgl.gen(\"summary\")\n\n\n@sgl.function\ndef regular_expression_gen(s):\n    s += \"Q: What is the IP address of the Google DNS servers?\\n\"\n    s += \"A: \" + sgl.gen(\n        \"answer\",\n        temperature=0,\n        regex=r\"((25[0-5]|2[0-4]\\d|[01]?\\d\\d?).){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\",\n    )\n\n\n@sgl.function\ndef text_qa(s, question):\n    s += \"Q: \" + question + \"\\n\"\n    s += \"A:\" + sgl.gen(\"answer\", stop=\"\\n\")\n\n\ndef driver_tool_use():\n    state = tool_use.run(question=\"What is the capital of the United States?\")\n    print(state.text())\n    print(\"\\n\")\n\n\ndef driver_tip_suggestion():\n    state = tip_suggestion.run()\n    print(state.text())\n    print(\"\\n\")\n\n\ndef driver_regex():\n    state = regular_expression_gen.run()\n    print(state.text())\n    
print(\"\\n\")\n\n\ndef driver_batching():\n    states = text_qa.run_batch(\n        [\n            {\"question\": \"What is the capital of the United Kingdom?\"},\n            {\"question\": \"What is the capital of France?\"},\n            {\"question\": \"What is the capital of Japan?\"},\n        ],\n        progress_bar=True,\n    )\n\n    for s in states:\n        print(s.text())\n    print(\"\\n\")\n\n\ndef driver_stream():\n    state = text_qa.run(\n        question=\"What is the capital of France?\", temperature=0.1, stream=True\n    )\n\n    for out in state.text_iter():\n        print(out, end=\"\", flush=True)\n    print(\"\\n\")\n\n\nif __name__ == \"__main__\":\n    # sgl.set_default_backend(sgl.OpenAI(\"gpt-3.5-turbo-instruct\"))\n    sgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n\n    driver_tool_use()\n    driver_tip_suggestion()\n    driver_regex()\n    driver_batching()\n    driver_stream()\n"
  },
  {
    "path": "examples/frontend_language/usage/sgl_gen_min_tokens.py",
    "content": "\"\"\"\nThis example demonstrates how to use `min_tokens` to force `sgl.gen` to generate a longer sequence.\n\nUsage:\npython3 sgl_gen_min_tokens.py\n\"\"\"\n\nimport sglang as sgl\n\n\n@sgl.function\ndef long_answer(s):\n    s += sgl.user(\"What is the capital of the United States?\")\n    s += sgl.assistant(sgl.gen(\"answer\", min_tokens=64, max_tokens=128))\n\n\n@sgl.function\ndef short_answer(s):\n    s += sgl.user(\"What is the capital of the United States?\")\n    s += sgl.assistant(sgl.gen(\"answer\"))\n\n\nif __name__ == \"__main__\":\n    runtime = sgl.Runtime(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n    sgl.set_default_backend(runtime)\n\n    state = long_answer.run()\n    print(\"=\" * 20)\n    print(\"Longer Answer\", state[\"answer\"])\n\n    state = short_answer.run()\n    print(\"=\" * 20)\n    print(\"Short Answer\", state[\"answer\"])\n\n    runtime.shutdown()\n"
  },
  {
    "path": "examples/frontend_language/usage/streaming.py",
    "content": "\"\"\"\nUsage:\npython3 streaming.py\n\"\"\"\n\nimport asyncio\n\nimport sglang as sgl\n\n\n@sgl.function\ndef multi_turn_question(s, question_1, question_2):\n    s += sgl.system(\"You are a helpful assistant.\")\n    s += sgl.user(question_1)\n    s += sgl.assistant(sgl.gen(\"answer_1\", max_tokens=256))\n    s += sgl.user(question_2)\n    s += sgl.assistant(sgl.gen(\"answer_2\", max_tokens=256))\n\n\nsgl.set_default_backend(sgl.OpenAI(\"gpt-3.5-turbo\"))\n\n\ndef stream_a_variable():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n        stream=True,\n    )\n\n    for out in state.text_iter(var_name=\"answer_2\"):\n        print(out, end=\"\", flush=True)\n    print(\"\\n\")\n\n\nasync def async_stream():\n    state = multi_turn_question.run(\n        question_1=\"What is the capital of the United States?\",\n        question_2=\"List two local attractions.\",\n        stream=True,\n    )\n\n    async for out in state.text_async_iter(var_name=\"answer_2\"):\n        print(out, end=\"\", flush=True)\n    print(\"\\n\")\n\n\nif __name__ == \"__main__\":\n    stream_a_variable()\n    asyncio.run(async_stream())\n"
  },
  {
    "path": "examples/frontend_language/usage/triton/Dockerfile",
    "content": "FROM nvcr.io/nvidia/tritonserver:24.01-py3\n\nWORKDIR /opt\n\nRUN git clone https://github.com/sgl-project/sglang.git\n\nWORKDIR /opt/sglang\nRUN pip install --upgrade pip && \\\n    pip install -e \"python[all]\" && \\\n    pip install datasets\n"
  },
  {
    "path": "examples/frontend_language/usage/triton/README.md",
    "content": "# sglang_triton\n\nBuild the Docker image:\n```\ndocker build -t sglang-triton .\n```\n\nThen start the container:\n```\ndocker run -ti --gpus=all --network=host --name sglang-triton -v ./models:/mnt/models sglang-triton\n```\n\nInside the container, launch the SGLang server:\n```\ncd sglang\npython3 -m sglang.launch_server --model-path mistralai/Mistral-7B-Instruct-v0.2 --port 30000 --mem-fraction-static 0.9\n```\n\nIn another shell, attach to the same container and start the Triton server:\n```\ndocker exec -ti sglang-triton /bin/bash\ncd /mnt\ntritonserver --model-repository=/mnt/models\n```\n\nSend a request to the Triton server:\n```\ncurl -X POST http://localhost:8000/v2/models/character_generation/generate \\\n-H \"Content-Type: application/json\" \\\n-d '{\n  \"INPUT_TEXT\": [\"harry\"]\n}'\n```\n"
  },
  {
    "path": "examples/frontend_language/usage/triton/models/character_generation/1/model.py",
    "content": "import numpy\nimport triton_python_backend_utils as pb_utils\nfrom pydantic import BaseModel\n\nimport sglang as sgl\nfrom sglang import function\nfrom sglang.srt.constrained.outlines_backend import build_regex_from_object\n\nsgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n\n\nclass Character(BaseModel):\n    name: str\n    eye_color: str\n    house: str\n\n\n@function\ndef character_gen(s, name):\n    s += (\n        name\n        + \" is a character in Harry Potter. Please fill in the following information about this character.\\n\"\n    )\n    s += sgl.gen(\n        \"json_output\", max_tokens=256, regex=build_regex_from_object(Character)\n    )\n\n\nclass TritonPythonModel:\n    def initialize(self, args):\n        print(\"Initialized.\")\n\n    def execute(self, requests):\n        responses = []\n        for request in requests:\n            tensor_in = pb_utils.get_input_tensor_by_name(request, \"INPUT_TEXT\")\n            if tensor_in is None:\n                # Triton expects one response per request, so emit an empty\n                # response for this request and move on instead of returning early.\n                responses.append(pb_utils.InferenceResponse(output_tensors=[]))\n                continue\n\n            input_list_names = [\n                i.decode(\"utf-8\") if isinstance(i, bytes) else i\n                for i in tensor_in.as_numpy().tolist()\n            ]\n\n            input_list_dicts = [{\"name\": i} for i in input_list_names]\n\n            states = character_gen.run_batch(input_list_dicts)\n            character_strs = [state.text() for state in states]\n\n            tensor_out = pb_utils.Tensor(\n                \"OUTPUT_TEXT\", numpy.array(character_strs, dtype=object)\n            )\n\n            responses.append(pb_utils.InferenceResponse(output_tensors=[tensor_out]))\n        return responses\n"
  },
  {
    "path": "examples/frontend_language/usage/triton/models/character_generation/config.pbtxt",
    "content": "name: \"character_generation\"\nbackend: \"python\"\ninput [\n    {\n        name: \"INPUT_TEXT\"\n        data_type: TYPE_STRING\n        dims: [ -1 ]\n    }\n]\noutput [\n    {\n        name: \"OUTPUT_TEXT\"\n        data_type: TYPE_STRING\n        dims: [ -1 ]\n    }\n]\ninstance_group [\n    {\n        count: 1\n        kind: KIND_GPU\n        gpus: [ 0 ]\n    }\n]\n"
  },
  {
    "path": "examples/monitoring/README.md",
    "content": "# SGLang Monitoring Setup\n\nThis directory contains a ready-to-use monitoring setup for SGLang using Prometheus and Grafana.\n\n## Prerequisites\n\n- Docker and Docker Compose installed\n- SGLang server running with metrics enabled\n\n## Usage\n\n1. Start your SGLang server with metrics enabled:\n\n```bash\npython -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --enable-metrics\n```\n\nBy default, the metrics server will run on `127.0.0.1:30000`.\n\n2. Start the monitoring stack:\n\n```bash\ncd examples/monitoring\ndocker compose up\n```\n\n3. Access the monitoring interfaces:\n   - Grafana: [http://localhost:3000](http://localhost:3000)\n   - Prometheus: [http://localhost:9090](http://localhost:9090)\n\nDefault Grafana login credentials:\n- Username: `admin`\n- Password: `admin`\n\nYou'll be prompted to change the password on first login.\n\n4. The SGLang dashboard will be automatically available in the \"SGLang Monitoring\" folder.\n\n## Troubleshooting\n\n### Port Conflicts\nIf you see errors like \"port is already allocated\":\n\n1. Check if you already have Prometheus or Grafana running:\n   ```bash\n   docker ps | grep -E 'prometheus|grafana'\n   ```\n\n2. Stop any conflicting containers:\n   ```bash\n   docker stop <container_id>\n   ```\n\n3. Ensure no other services are using ports 9090 and 3000:\n   ```bash\n   lsof -i :9090\n   lsof -i :3000\n   ```\n\n### Connection Issues\nIf Grafana cannot connect to Prometheus:\n1. Check that both services are running\n2. Verify the datasource configuration in Grafana\n3. 
Check that your SGLang server is properly exposing metrics\n\n## Configuration\n\n- Prometheus configuration: `prometheus.yaml`\n- Docker Compose configuration: `docker-compose.yaml`\n- Grafana datasource: `grafana/datasources/datasource.yaml`\n- Grafana dashboard configuration: `grafana/dashboards/config/dashboard.yaml`\n- SGLang dashboard JSON: `grafana/dashboards/json/sglang-dashboard.json`\n\n## Customization\n\nYou can customize the monitoring setup by modifying the configuration files as needed.\n"
  },
  {
    "path": "examples/monitoring/docker-compose.yaml",
    "content": "version: '3'\nservices:\n  prometheus:\n    image: prom/prometheus:latest\n    container_name: prometheus\n    network_mode: host\n    volumes:\n      - ./prometheus.yaml:/etc/prometheus/prometheus.yml\n    command:\n      - '--config.file=/etc/prometheus/prometheus.yml'\n      - '--storage.tsdb.path=/prometheus'\n\n  grafana:\n    image: grafana/grafana:latest\n    container_name: grafana\n    network_mode: host\n    volumes:\n      - ./grafana/datasources:/etc/grafana/provisioning/datasources\n      - ./grafana/dashboards/config:/etc/grafana/provisioning/dashboards\n      - ./grafana/dashboards/json:/var/lib/grafana/dashboards\n    environment:\n      - GF_AUTH_ANONYMOUS_ENABLED=true\n      - GF_AUTH_ANONYMOUS_ORG_ROLE=Viewer\n      - GF_AUTH_BASIC_ENABLED=false\n      - GF_USERS_ALLOW_SIGN_UP=false\n      - GF_DASHBOARDS_DEFAULT_HOME_DASHBOARD_PATH=/var/lib/grafana/dashboards/sglang-dashboard.json\n    depends_on:\n      - prometheus\n"
  },
  {
    "path": "examples/monitoring/grafana/dashboards/config/dashboard.yaml",
    "content": "apiVersion: 1\nproviders:\n  - name: 'SGLang'\n    orgId: 1\n    folder: 'SGLang Monitoring'\n    type: file\n    disableDeletion: false\n    updateIntervalSeconds: 10\n    allowUiUpdates: false\n    options:\n      path: /var/lib/grafana/dashboards\n"
  },
  {
    "path": "examples/monitoring/grafana/dashboards/json/sglang-dashboard.json",
    "content": "{\n  \"annotations\": {\n    \"list\": [\n      {\n        \"builtIn\": 1,\n        \"datasource\": {\n          \"type\": \"grafana\",\n          \"uid\": \"-- Grafana --\"\n        },\n        \"enable\": true,\n        \"hide\": true,\n        \"iconColor\": \"rgba(0, 211, 255, 1)\",\n        \"name\": \"Annotations & Alerts\",\n        \"type\": \"dashboard\"\n      }\n    ]\n  },\n  \"editable\": true,\n  \"fiscalYearStartMonth\": 0,\n  \"graphTooltip\": 0,\n  \"id\": 8,\n  \"links\": [],\n  \"panels\": [\n    {\n      \"datasource\": {\n        \"default\": true,\n        \"type\": \"prometheus\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                
\"color\": \"green\"\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 0,\n        \"y\": 0\n      },\n      \"id\": 14,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"hideZeros\": false,\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"pluginVersion\": \"11.6.0\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(sglang_e2e_request_latency_seconds_bucket[$__rate_interval])))\\r\\n\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"P99\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(sglang_e2e_request_latency_seconds_bucket[$__rate_interval])))\\r\\n\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"P90\",\n          \"range\": true,\n          \"refId\": \"B\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": 
{\n            \"type\": \"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(sglang_e2e_request_latency_seconds_bucket[$__rate_interval])))\\r\\n\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"P50\",\n          \"range\": true,\n          \"refId\": \"C\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"code\",\n          \"expr\": \"avg(rate(sglang_e2e_request_latency_seconds_sum[$__rate_interval]) /  rate(sglang_e2e_request_latency_seconds_count[$__rate_interval]))\\r\\n\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"Avg\",\n          \"range\": true,\n          \"refId\": \"D\",\n          \"useBackend\": false\n        }\n      ],\n      \"title\": \"End-to-End Request Latency\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"default\": true,\n        \"type\": \"prometheus\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"custom\": {\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            }\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 12,\n        \"y\": 0\n      },\n      \"id\": 17,\n      \"maxDataPoints\": 30,\n      \"options\": 
{\n        \"calculate\": false,\n        \"calculation\": {\n          \"yBuckets\": {\n            \"scale\": {\n              \"type\": \"linear\"\n            }\n          }\n        },\n        \"cellGap\": 1,\n        \"cellValues\": {},\n        \"color\": {\n          \"exponent\": 0.5,\n          \"fill\": \"dark-orange\",\n          \"mode\": \"scheme\",\n          \"reverse\": false,\n          \"scale\": \"exponential\",\n          \"scheme\": \"Spectral\",\n          \"steps\": 64\n        },\n        \"exemplars\": {\n          \"color\": \"rgba(255,0,255,0.7)\"\n        },\n        \"filterValues\": {\n          \"le\": 1e-9\n        },\n        \"legend\": {\n          \"show\": true\n        },\n        \"rowsFrame\": {\n          \"layout\": \"auto\"\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"showColorScale\": true,\n          \"yHistogram\": false\n        },\n        \"yAxis\": {\n          \"axisPlacement\": \"left\",\n          \"reverse\": false,\n          \"unit\": \"secs\"\n        }\n      },\n      \"pluginVersion\": \"11.6.0\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"sum(increase(sglang_e2e_request_latency_seconds_bucket{model_name=~\\\"$model_name\\\"}[$__rate_interval])) by (le)\\r\\n\",\n          \"format\": \"heatmap\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"{{le}}\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        }\n      ],\n      \"title\": \"End-to-End Request Latency(s) Heatmap\",\n      \"type\": \"heatmap\"\n    },\n    {\n      \"datasource\": {\n        \"default\": true,\n        \"type\": \"prometheus\"\n      },\n   
   \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\"\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 0,\n        \"y\": 8\n      },\n      \"id\": 20,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"hideZeros\": false,\n          \"mode\": \"single\",\n          
\"sort\": \"none\"\n        }\n      },\n      \"pluginVersion\": \"11.6.0\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(sglang_time_to_first_token_seconds_bucket[$__rate_interval])))\\r\\n\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"P99\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(sglang_time_to_first_token_seconds_bucket[$__rate_interval])))\\r\\n\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"P90\",\n          \"range\": true,\n          \"refId\": \"B\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(sglang_time_to_first_token_seconds_bucket[$__rate_interval])))\\r\\n\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"P50\",\n          \"range\": true,\n          \"refId\": \"C\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n        
    \"type\": \"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"code\",\n          \"expr\": \"avg(rate(sglang_time_to_first_token_seconds_sum[$__rate_interval]) /  rate(sglang_time_to_first_token_seconds_count[$__rate_interval]))\\r\\n\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"Avg\",\n          \"range\": true,\n          \"refId\": \"D\",\n          \"useBackend\": false\n        }\n      ],\n      \"title\": \"Time-To-First-Token Latency\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"default\": true,\n        \"type\": \"prometheus\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"custom\": {\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            }\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 12,\n        \"y\": 8\n      },\n      \"id\": 19,\n      \"maxDataPoints\": 30,\n      \"options\": {\n        \"calculate\": false,\n        \"calculation\": {\n          \"xBuckets\": {\n            \"value\": \"\"\n          },\n          \"yBuckets\": {\n            \"mode\": \"size\",\n            \"scale\": {\n              \"type\": \"linear\"\n            },\n            \"value\": \"\"\n          }\n        },\n        \"cellGap\": 1,\n        \"color\": {\n          \"exponent\": 0.5,\n          \"fill\": \"dark-orange\",\n          \"mode\": \"scheme\",\n          \"reverse\": false,\n          \"scale\": \"exponential\",\n          \"scheme\": \"Spectral\",\n          \"steps\": 64\n        },\n        \"exemplars\": 
{\n          \"color\": \"rgba(255,0,255,0.7)\"\n        },\n        \"filterValues\": {\n          \"le\": 1e-9\n        },\n        \"legend\": {\n          \"show\": true\n        },\n        \"rowsFrame\": {\n          \"layout\": \"auto\"\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"showColorScale\": true,\n          \"yHistogram\": false\n        },\n        \"yAxis\": {\n          \"axisPlacement\": \"left\",\n          \"reverse\": false\n        }\n      },\n      \"pluginVersion\": \"11.6.0\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"exemplar\": false,\n          \"expr\": \"sum by(le) (increase(sglang_time_to_first_token_seconds_bucket{model_name=~\\\"$model_name\\\"}[$__rate_interval]))\",\n          \"format\": \"heatmap\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"interval\": \"\",\n          \"legendFormat\": \"{{le}}\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        }\n      ],\n      \"title\": \"Time-To-First-Token Seconds Heatmap\",\n      \"type\": \"heatmap\"\n    },\n    {\n      \"datasource\": {\n        \"default\": true,\n        \"type\": \"prometheus\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            
\"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\"\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 0,\n        \"y\": 16\n      },\n      \"id\": 7,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"hideZeros\": false,\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"pluginVersion\": \"11.6.0\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"code\",\n          \"expr\": \"sglang_num_running_reqs\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"interval\": \"\",\n          
\"legendFormat\": \"{{instance}}\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        }\n      ],\n      \"title\": \"Num Running Requests\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"default\": true,\n        \"type\": \"prometheus\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\"\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 12,\n    
    \"y\": 16\n      },\n      \"id\": 18,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"hideZeros\": false,\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"pluginVersion\": \"11.6.0\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"sglang_gen_throughput\",\n          \"instant\": false,\n          \"legendFormat\": \"{{instance}}\",\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Token Generation Throughput (Tokens / S)\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"default\": true,\n        \"type\": \"prometheus\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            
\"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\"\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 0,\n        \"y\": 24\n      },\n      \"id\": 11,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"hideZeros\": false,\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"pluginVersion\": \"11.6.0\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"code\",\n          \"expr\": \"sglang_cache_hit_rate\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"{{instance}}\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        }\n      ],\n      \"title\": \"Cache Hit Rate\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"default\": true,\n        \"type\": \"prometheus\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n        
    \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\"\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 12,\n        \"y\": 24\n      },\n      \"id\": 8,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"hideZeros\": false,\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"pluginVersion\": \"11.6.0\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": 
\"prometheus\",\n            \"uid\": \"ddyfngn31dg5cf\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"code\",\n          \"expr\": \"sglang_num_queue_reqs\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"{{instance}}\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        }\n      ],\n      \"title\": \"Number Queued Requests\",\n      \"type\": \"timeseries\"\n    }\n  ],\n  \"preload\": false,\n  \"refresh\": \"5s\",\n  \"schemaVersion\": 41,\n  \"tags\": [],\n  \"templating\": {\n    \"list\": [\n      {\n        \"current\": {\n          \"text\": \"127.0.0.1:30000\",\n          \"value\": \"127.0.0.1:30000\"\n        },\n        \"datasource\": {\n          \"type\": \"prometheus\"\n        },\n        \"definition\": \"label_values(instance)\",\n        \"includeAll\": false,\n        \"label\": \"instance\",\n        \"name\": \"instance\",\n        \"options\": [],\n        \"query\": {\n          \"qryType\": 1,\n          \"query\": \"label_values(instance)\",\n          \"refId\": \"PrometheusVariableQueryEditor-VariableQuery\"\n        },\n        \"refresh\": 1,\n        \"regex\": \"\",\n        \"type\": \"query\"\n      },\n      {\n        \"current\": {\n          \"text\": \"meta-llama/Llama-3.1-8B-Instruct\",\n          \"value\": \"meta-llama/Llama-3.1-8B-Instruct\"\n        },\n        \"datasource\": {\n          \"type\": \"prometheus\"\n        },\n        \"definition\": \"label_values(model_name)\",\n        \"includeAll\": false,\n        \"label\": \"model name\",\n        \"name\": \"model_name\",\n        \"options\": [],\n        \"query\": {\n          \"qryType\": 1,\n          \"query\": \"label_values(model_name)\",\n          \"refId\": \"PrometheusVariableQueryEditor-VariableQuery\"\n        },\n        \"refresh\": 1,\n        \"regex\": \"\",\n       
 \"type\": \"query\"\n      }\n    ]\n  },\n  \"time\": {\n    \"from\": \"now-30m\",\n    \"to\": \"now\"\n  },\n  \"timepicker\": {},\n  \"timezone\": \"browser\",\n  \"title\": \"SGLang Dashboard\",\n  \"uid\": \"sglang-dashboard\",\n  \"version\": 11\n}\n"
  },
  {
    "path": "examples/monitoring/grafana/datasources/datasource.yaml",
    "content": "apiVersion: 1\ndatasources:\n  - name: Prometheus\n    type: prometheus\n    access: proxy\n    url: http://localhost:9090\n    isDefault: true\n    editable: false\n"
  },
  {
    "path": "examples/monitoring/opentelemetry.yaml",
    "content": "receivers:\n  otlp:\n    protocols:\n      grpc:\n        endpoint: 0.0.0.0:4317\n      http:\n        endpoint: 0.0.0.0:4318\nprocessors:\n  batch:\n\nexporters:\n  otlp:\n    endpoint: jaeger:4317\n    tls:\n      insecure: true\n  file:\n    path: /tmp/otel_trace.json\n\nextensions:\n  health_check:\n  pprof:\n  zpages:\n\nservice:\n  extensions: [health_check, pprof, zpages]\n  pipelines:\n    traces:\n      receivers: [otlp]\n      processors: [batch]\n      exporters: [otlp, file]\n    metrics:\n      receivers: [otlp]\n      processors: [batch]\n      exporters: [otlp]\n    logs:\n      receivers: [otlp]\n      processors: [batch]\n      exporters: [otlp]\n"
  },
  {
    "path": "examples/monitoring/prometheus.yaml",
    "content": "# prometheus.yaml\nglobal:\n  scrape_interval: 5s\n  evaluation_interval: 30s\n\nscrape_configs:\n  - job_name: sglang\n    static_configs:\n      - targets:\n          - '127.0.0.1:30000'\n"
  },
  {
    "path": "examples/monitoring/tracing_compose.yaml",
    "content": "services:\n  otel-collector:\n    image: docker.io/otel/opentelemetry-collector\n    volumes:\n      - ./opentelemetry.yaml:/etc/otelcol/config.yaml\n      - /tmp:/tmp\n    ports:\n      - \"4317:4317\"   # OTLP gRPC\n      - \"4318:4318\"   # OTLP HTTP\n    depends_on:\n      - jaeger\n    restart: unless-stopped\n\n  jaeger:\n    image: jaegertracing/all-in-one\n    container_name: jaeger\n    ports:\n      - \"16686:16686\"\n    environment:\n      - COLLECTOR_OTLP_ENABLED=true\n    restart: unless-stopped\n"
  },
  {
    "path": "examples/profiler/nsys_profile_tools/README.md",
    "content": "# gputrc2graph.py\n\nThis script processes NVIDIA Nsight Systems (`nsys`) GPU trace files\n(`.nsys-rep`) collected with `-t cuda` tracing enabled, and generates\nkernel-level summaries and visualizations of GPU and non-GPU time. It is\nuseful for profiling and analyzing nsys profile output.\n\n## Usage\n\n### Command-line Arguments\n\n- `--in_file`\n  **(required)**\n  List of input files and their metadata. Each entry should be in the format:\n  `<nsys-rep>,<engine>,<model>,<elapsed_nonprofiled_sec>`\n  - `nsys-rep`: Path to the `.nsys-rep` file.\n  - `engine`: Engine name (e.g., `sglang`).\n  - `model`: Model name (e.g., `llama`, `gpt-oss`, `ds`).\n  - `elapsed_nonprofiled_sec`: Wall-clock runtime (in seconds) without\n    profiling. Specify `0` to use the elapsed time from the nsys-rep file\n    (this may inflate non-GPU time if actual runtime without profiling is\n    less). Multiple entries can be provided, separated by spaces.\n\n- `--out_dir`\n  Output directory for the generated CSV and HTML files.\n  If not specified, results are saved in the current directory.\n\n- `--title`\n  Title for the HTML chart/visualization.\n\n- `--nsys_cmd`\n  Path to the `nsys` command.\n  Default: `nsys` (assumes it is in your PATH).\n  Use this if `nsys` is not in your system PATH.\n\n## Notes\n\n- Make sure pandas, plotly, and regex are installed, since the script imports\n  all three. Any version of pandas is fine.\n- Make sure [nsys](https://developer.nvidia.com/nsight-systems/get-started) is\n  installed, and specify the path to the `nsys` command with `--nsys_cmd` if it\n  is not in your PATH. The nsys version must be >= the nsys version that was\n  used to collect the traces, so that nsys can process the generated nsys-rep\n  file.\n\n- For more details on available engines and models, see the help string in\n  the script or run:\n\n```bash\npython3 gputrc2graph.py --help\n```\n\n## Example 1: analyze a single profile\n\nTo analyze the GPU cycles of, for example, a Llama-3.1-8B model served with\nSGLang:\n\n1. Run the following command to collect an nsys profile for the SGLang server\n   config:\n\n   ```bash\n   nsys profile -t cuda -o nsys_res -f true --trace-fork-before-exec=true \\\n   --cuda-graph-trace=node --delay <DELAY> --duration <DURATION> \\\n   python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B ...\n   ```\n\n   where:\n\n   - DELAY: how many seconds nsys waits before it starts collecting, so that\n     profiling does not begin until the SGLang server has come up and load\n     generation has started.\n   - DURATION: how many seconds nsys profiles before generating the report.\n     This should be longer than the duration of the run.\n2. After the server starts, run the client load generation command. Once the\n   test completes and DURATION has elapsed, nsys profile will generate an\n   nsys_res.nsys-rep file and shut down the server.\n\n3. Run step #1 again, this time starting the server without collecting a\n   profile.\n\n4. Run step #2 again, and record the total time to complete the test in\n   seconds. The script uses this value to calculate the CPU(non-GPU) seconds\n   for the analysis.\n\n5. Suppose the elapsed time from step #4 is 132 seconds. Run the script to\n   analyze:\n\n   ```bash\n   python3 gputrc2graph.py \\\n   --in_file nsys_res.nsys-rep,sglang,llama,132\n   ```\n\nThe command produces 2 files for analysis:\n\n- result.html: categorizes kernel names into different categories in a\n  stacked bar chart.\n- result.csv: shows how the kernel names are mapped to the different\n  categories.\n\n### HTML visualization with result.html\n\nThe HTML file shows the number of elapsed seconds spent in each GPU substage,\nor category. In this example, attention kernels are the biggest category at\n63 seconds, followed by \"gemm\" kernels. This lets the user prioritize which\nkernels to focus on for performance optimization.\n\nA data table is also appended underneath the bar chart for copying into other\npost-processing tools.\n\n### Kernel to category mapping with result.csv\n\nSuppose the user would like to focus on improving triton kernels. They are not\nthe biggest consumer of cycles, at .01 sec, but perhaps they haven't been\noptimized. The next step is to use result.csv to see which kernels make up the\ntriton kernel GPU cycles.\n\n## Example 2: analyze multiple profiles\n\nSuppose the user has multiple nsys trace files captured for different models,\nsay llama and gpt-oss, and wishes to compare their GPU/non-GPU time. A command\nlike the following can be used:\n\n```bash\npython3 gputrc2graph.py \\\n--in_file run1.nsys-rep,sglang,llama,100 run2.nsys-rep,sglang,gpt-oss,102 \\\n--out_dir results\n```\n\nThe analysis process is similar to Example 1, but now there are multiple\nstacked bar charts that can be compared. The categories for the different\nkernels remain the same, so it is easy to compare the GPU cycles for the same\ncategories.\n\nOnce a category is shown to have more cycles for one configuration than\nanother, the next step is to use the CSV file to see which kernels are mapped\ninto that category, and which of those kernels take the most time and\ntherefore account for the difference in the overall category.\n\n## Example 3: add new classification for a new model\n\nTo add a new engine DEF with model ABC, add a JSON file in the same directory\nas gputrc2graph.py, using the same format as the existing JSON files. The\nscript automatically loads every JSON file in that directory as an\nengine/model specification.\n\nFor example, suppose the new model has 4 kernels to be classified into\n\"gemm\" and \"attn\", where the gemm kernels have \"H\" or \"I\" in their names and\nthe attn kernels have \"J\" or \"K\" in their names. The JSON file would look like\nthe following:\n\n```json\n{\n  \"DEF\": {\n      \"ABC\": {\n          \"H|I\": \"gemm\",\n          \"J|K\": \"attn\",\n          \"CUDA mem\": \"non-gpu-H_D_memops\",\n          \".*\": \"misc\"\n      }\n  }\n}\n```\n\nEach entry in the dictionary consists of:\n\n- key: a regex used to classify the kernels\n- value: the category to classify the kernels into.\n\nPatterns are tried in order and the first match wins, so keep the catch-all\n`.*` entry last. The last 2 entries are common to all engines/models: one for\nCUDA memory operations and a 'misc' category for anything left over that\ncan't be classified.\n\nWhen invoking gputrc2graph.py, specify a trace file with this new engine/model\nlike the following:\n\n```bash\n--in_file new.nsys-rep,DEF,ABC,<runtime>\n```\n\nIf a JSON file for engine DEF already exists, you can instead add the model as\na new entry in that file, after the other models.\n"
  },
  {
    "path": "examples/profiler/nsys_profile_tools/gputrc2graph.py",
    "content": "\"\"\"\nThis generates gpu kernel analysis output from nsys rep. Will call nsys\nstats  -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate\ncsv and html output for analysis\n\"\"\"\n\nimport argparse\nimport logging\nimport os\nimport shlex\n\nimport regex as re\n\nlogger = logging.getLogger(__name__)\n\n\n# helper data class for annotating kernels\ndef load_engine_model():\n    \"\"\"returns engine_model built from all json files in the current dir\"\"\"\n    import glob\n    import json\n\n    engine_model = {}\n\n    json_files = glob.glob(os.path.join(os.path.dirname(__file__) or \".\", \"*.json\"))\n    for fname in json_files:\n        with open(fname, encoding=\"utf-8\") as f:\n            file_engine_model = json.load(f)\n        for engine, models in file_engine_model.items():\n            engine_model.setdefault(engine, {}).update(models)\n    return engine_model\n\n\nclass GPUTrace2Graph:\n    \"\"\"\n    Parses output of nsys report, generates csv and bar chart output\n    \"\"\"\n\n    def __init__(self):\n        import pandas as pd  # avoid importing till needed\n\n        self.pd = pd\n\n    # helper functions for generating trace->summary csvs\n    def gen_nonoverlapped_sum_from_gputrace(self, in_file, out_file):\n        logger.info(\"loading %s\", in_file)\n        df = self.pd.read_csv(\n            in_file, usecols=[\"Start (ns)\", \"Duration (ns)\", \"Device\", \"Strm\", \"Name\"]\n        )\n        df[\"End (ns)\"] = df[\"Start (ns)\"] + df[\"Duration (ns)\"]\n        df = self.sum_non_overlapping_intervals(df)\n        # get ready to print table with elapsed times per kernel\n        df[\"Instances\"] = 1\n        df_sum = df.groupby(\"Name\", as_index=False).agg(\n            {\"Elapsed Time (ns)\": \"sum\", \"Duration (ns)\": \"sum\", \"Instances\": \"size\"}\n        )\n\n        # generate csv\n        df_sum[\"Total Time (sec)\"] = df_sum[\"Duration (ns)\"] / 1e9\n        df_sum[\"Elapsed Time (sec)\"] 
= df_sum[\"Elapsed Time (ns)\"] / 1e9\n        df_sum = df_sum.sort_values(by=\"Elapsed Time (sec)\", ascending=False)\n        df_sum[[\"Elapsed Time (sec)\", \"Total Time (sec)\", \"Instances\", \"Name\"]].to_csv(\n            out_file, index=False\n        )\n\n    def sum_non_overlapping_intervals(self, df):\n        \"\"\"\n        returns new sorted df with Elapsed Time (ns) column using\n        vectorized operations\n        \"\"\"\n        logger.info(\"sorting %s trace records by start time\", str(df.shape))\n\n        # Sort by start time and reset index\n        df = df.sort_values(by=\"Start (ns)\").reset_index(drop=True)\n\n        # Initialize elapsed time as duration\n        df[\"Elapsed Time (ns)\"] = df[\"Duration (ns)\"]\n\n        # Get numpy arrays for faster operations\n        starts = df[\"Start (ns)\"].values\n        ends = df[\"End (ns)\"].values\n\n        # Keep track of current interval end\n        current_end = ends[0]\n        display_units = max(1, int(len(df) / 100))\n        # Update current_end for overlapping intervals\n        for i in range(1, len(df)):\n            if i % display_units == 0:\n                print(f\"processing trace: {int(i/len(df) * 100)} %\", end=\"\\r\")\n            if starts[i] <= current_end:\n                if ends[i] > current_end:\n                    # Partial overlap\n                    df.iloc[i, df.columns.get_loc(\"Elapsed Time (ns)\")] = (\n                        ends[i] - current_end\n                    )\n                    current_end = ends[i]\n                else:\n                    # Complete overlap\n                    df.iloc[i, df.columns.get_loc(\"Elapsed Time (ns)\")] = 0\n            else:\n                # No overlap\n                current_end = ends[i]\n\n        return df\n\n    # functions for generating html files\n    def make_html(self, df, output_dir, title):\n        \"\"\"make html graph from df\"\"\"\n        import plotly.express as px\n\n        if 
df.empty:\n            return\n        output_name = os.path.join(output_dir, \"result\")\n        if not title:\n            title = \"Model_Engine\"\n        x = \"Model_Engine\"\n        y = \"Elapsed Time (sec)\"\n        color = \"Category\"\n        \"\"\" generate kernel mapping table  \"\"\"\n        # Sort Model_Engine categories by last field after underscore\n        df[\"Model_Engine\"] = self.pd.Categorical(\n            df[\"Model_Engine\"],\n            sorted(df[\"Model_Engine\"].unique(), key=lambda x: x.split(\"_\")[-1]),\n        )\n        df[[\"Model_Engine\", color, \"Instances\", \"Name\", y]].sort_values(\n            by=color\n        ).to_csv(f\"{output_name}.csv\", index=False)\n        graph = px.histogram(\n            df.round(2),\n            x=x,\n            y=y,\n            title=(f\"{y} for {title}\"),\n            color=color,\n            text_auto=True,\n        )\n        # wrap x axis labels\n        graph.update_xaxes(automargin=True)\n        graph.write_html(f\"{output_name}.html\")\n        \"\"\"\n            Generate data table with columns per Model_Engine into result.html\n        \"\"\"\n        pivot_df = df.pivot_table(\n            values=\"Elapsed Time (sec)\",\n            index=\"Category\",\n            columns=\"Model_Engine\",\n            aggfunc=\"sum\",\n            observed=False,\n        ).round(2)\n        # Add sum row at bottom\n        pivot_df.loc[\"total_elapsed_sec\"] = pivot_df.sum()\n        pivot_df.fillna(\"\").to_html(\"temp.html\")\n        with (\n            open(f\"{output_name}.html\", \"a\", encoding=\"utf-8\") as outfile,\n            open(\"temp.html\", encoding=\"utf-8\") as infile,\n        ):\n            outfile.write(infile.read())\n        os.remove(\"temp.html\")\n\n        print(\n            f\"Finished generating: \\n\"\n            f\" {output_name}.html for stack bar chart \\n\"\n            f\" {output_name}.csv for Kernel-Category mapping\"\n        )\n\n    def 
anno_gpu_kernname(self, df, mapping):\n        \"\"\"add \"Category\" column\"\"\"\n\n        def anno_gpu_kernname_helper(name):\n            for kern_name, val in mapping.items():\n                if re.search(kern_name, name):\n                    return val\n\n        df[\"Category\"] = df[\"Name\"].apply(anno_gpu_kernname_helper)\n\n    def make_nongpu_row(self, df, nongpu_sec):\n        \"\"\"this will append non-gpu time entry at end of df\"\"\"\n        nongpu_row = self.pd.DataFrame([df.iloc[-1]])\n        nongpu_row[\"Category\"] = nongpu_row[\"Name\"] = \"CPU(non-GPU)\"\n        nongpu_row[\"Instances\"] = 1\n        nongpu_row[\"Elapsed Time (sec)\"] = nongpu_sec\n        return nongpu_row\n\n    def is_valid_file(self, base_file):\n        \"\"\"asserts if base_file is non-existent or is empty\"\"\"\n        assert (\n            os.path.isfile(base_file) and os.path.getsize(base_file) > 0\n        ), f\"{base_file} doesn't exist or is empty\"\n\n    def should_gen_file(self, new_file, base_file):\n        \"\"\"figure out if new file should be generated from base_file\"\"\"\n        self.is_valid_file(base_file)\n        if (\n            os.path.exists(new_file)\n            and (os.path.getmtime(new_file) > os.path.getmtime(base_file))\n            and (os.path.getsize(base_file) > 0)\n        ):\n            logger.info(\"reusing %s\", new_file)\n            return False\n        else:\n            logger.info(\"generating %s\", new_file)\n            return True\n\n    def gen_sum_file(self, file, nsys_cmd):\n        \"\"\"\n        generates sum file from nsys trace with times per kernel and\n        returns the name of the sum file\n        \"\"\"\n        import subprocess\n\n        file_dir = os.path.dirname(file)\n        file_name = os.path.basename(file)\n\n        if not file_dir:\n            file_dir = \".\"\n        # Walk through trace and get the total non-overlapped time\n        nsys_stats_file = os.path.join(file_dir, 
f\"{file_name}_cuda_gpu_trace.csv\")\n        sum_file = os.path.join(file_dir, f\"{file_name}_cuda_gpu_kernel_tracesum.csv\")\n        if self.should_gen_file(nsys_stats_file, file):\n            cmd = [\n                nsys_cmd,\n                \"stats\",\n                \"-r\",\n                \"cuda_gpu_trace\",\n                file,\n                \"-o\",\n                f\"{file_dir}/{file_name}\",\n            ]\n            cmd_str = shlex.join(cmd)\n            logger.info(\"+ %s\", cmd_str)\n            # estimate time based on calibrated 240M/min\n            file_size_mb = os.path.getsize(file) / 1e6\n            logger.info(\n                \"nsys stats for %.2f MB file expected to take %.2f min\",\n                file_size_mb,\n                file_size_mb / 240,\n            )\n            try:\n                subprocess.run(cmd, check=True)\n            except (FileNotFoundError, subprocess.CalledProcessError) as e:\n                logger.error(\n                    \"'%s' failed: %s. 
Use --nsys_cmd to specify nsys path\", cmd_str, e\n                )\n                exit(1)\n            logger.info(\"generating non-overlapped sum %s\", sum_file)\n            self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file)\n        self.is_valid_file(sum_file)\n        logger.info(\"Finished generating %s\", sum_file)\n        return sum_file\n\n    def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model):\n        \"\"\"generates graph and csv file from in_file into out_dir\"\"\"\n        # Initialize an empty DataFrame to store combined data\n        combined_df = self.pd.DataFrame()\n        for idx, (file, engine, model, total_sec) in enumerate(in_file):\n            file_dir = os.path.dirname(file)\n            file_name = os.path.basename(file)\n            if not file_dir:\n                file_dir = \".\"\n            sum_file = self.gen_sum_file(file, nsys_cmd)\n            # read kernel summary file\n            df = self.pd.read_csv(sum_file)\n            # annotate kernels with their categories\n            assert engine_model.get(engine), f\"engine {engine} unknown\"\n            assert engine_model[engine].get(model), f\"model {model} unknown\"\n            # remove nsys-rep from file_name for shorter x-label\n            file_name = file_name.replace(\".nsys-rep\", \"\")\n            df[\"Model_Engine\"] = f\"{model}_{engine}_{file_name}_{idx}\"\n            self.anno_gpu_kernname(df, engine_model[engine][model])\n            # patch in non-gpu time\n            gpu_sec = round(df[\"Elapsed Time (sec)\"].sum(), 1)\n            total_sec = round(float(total_sec), 1)\n            if total_sec < gpu_sec:\n                logger.warning(\n                    \"Elapsed sec %.2f < GPU sec %.2f, resetting elapsed sec to GPU sec\",\n                    total_sec,\n                    gpu_sec,\n                )\n                total_sec = gpu_sec\n            nongpu_row = self.make_nongpu_row(df, total_sec - gpu_sec)\n            df = self.pd.concat([df, nongpu_row], ignore_index=True)\n            combined_df = self.pd.concat([combined_df, df], ignore_index=True)\n        if out_dir is None:\n            out_dir = \".\"\n        else:\n            os.makedirs(out_dir, exist_ok=True)\n        # generate html file\n        self.make_html(combined_df, out_dir, title)\n\n\ndef parse_tuple(s):\n    return tuple(s.split(\",\"))\n\n\ndef main():\n    logging.basicConfig(\n        format=(\"%(asctime)s - %(levelname)s - %(message)s\"), level=logging.INFO\n    )\n    parser = argparse.ArgumentParser(\n        description=(\n            \"Process nsys rep and generate kernel non-overlapped cycles. \\n\"\n            \"Example:\\n\"\n            \"gputrc2graph.py --in_file d1.nsys-rep,sglang,llama,100 \\n\"\n            \"d2.nsys-rep,sglang,gpt-oss,102 \"\n            '--out_dir results/ --title \"Model=gpt-oss SGLANG chart\"'\n        ),\n        formatter_class=argparse.RawDescriptionHelpFormatter,\n    )\n\n    # load supported engine_model\n    engine_model_supported = load_engine_model()\n    # Get a string representation of supported engine/model combinations\n    engine_model_supported_str = \", \".join(\n        f\"{engine}:[{', '.join(models.keys())}]\"\n        for engine, models in engine_model_supported.items()\n    )\n    parser.add_argument(\n        \"--in_file\",\n        type=parse_tuple,\n        nargs=\"+\",\n        help=(\n            \"list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) \"\n            \"separated by space. Elapsed_nonprofiled_sec is runtime without \"\n            \"profiling used to calculate non-gpu time. Specify 0 to use \"\n            \"elapsed time from nsys-rep but that might inflate non-gpu time. \"\n            f\"Available engine:[model] are: {engine_model_supported_str} \"\n            \"Example: --in_file d1.nsys-rep,sglang,llama,100 \"\n            \"d2.nsys-rep,sglang,gpt-oss,102\"\n        ),\n        required=True,\n    )\n    parser.add_argument(\"--out_dir\", help=(\"output dir for result.csv/html\"))\n    parser.add_argument(\"--title\", help=(\"title for html chart\"))\n    parser.add_argument(\n        \"--nsys_cmd\",\n        help=(\"nsys cmd, e.g. /usr/bin/nsys, Default: nsys\"),\n        default=\"nsys\",\n    )\n    args = parser.parse_args()\n    gputrace = GPUTrace2Graph()\n    gputrace.gen_graph(\n        args.in_file, args.out_dir, args.title, args.nsys_cmd, engine_model_supported\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/profiler/nsys_profile_tools/sglang_engine_model.json",
    "content": "{\n  \"sglang\": {\n    \"llama\": {\n      \"gemm|nvjet\": \"gemm\",\n      \"fused_moe_kernel|GroupProblemShape|group_gemm_starts|bmm_|GemmUniversal\": \"moe_gemm\",\n      \"moe|sigmoid\": \"moe\",\n      \"CatArrayBatched|prepare_inputs\": \"prepare_next\",\n      \"ncclDevKernel|cross_device_reduce\": \"nccl_and_custom_ar\",\n      \"_norm_|Norm\": \"norm\",\n      \"topk\": \"topk\",\n      \"act_and_mul_\": \"activation\",\n      \"Rotary\": \"rope\",\n      \"SoftMax\": \"softmax\",\n      \"flash|fmha\": \"attn\",\n      \"elementwise\": \"elementwise\",\n      \"fp8_quant|cvt_|quantize\": \"quantize\",\n      \"reduce_kernel\": \"reduce\",\n      \"triton\": \"triton_kernel\",\n      \"CUDA mem\": \"non-gpu-H_D_memops\",\n      \".*\": \"misc\"\n    },\n    \"ds\": {\n      \"block_fp8_matmul\": \"block_fp8_gemm\",\n      \"gemm|matmul|nvjet\": \"gemm\",\n      \"fused_moe_kernel\": \"moe_gemm\",\n      \"moe|expert|sigmoid\": \"moe\",\n      \"CatArrayBatched|write_req_to\": \"prepare_next\",\n      \"ncclDevKernel|cross_device_reduce|all_gather\": \"nccl_and_custom_ar\",\n      \"Norm\": \"norm\",\n      \"topk\": \"topk\",\n      \"activation|act_and_mul\": \"activation\",\n      \"compute_position_kernel\": \"rope\",\n      \"elementwise\": \"elementwise\",\n      \"fp8_quant|quant_fp8|quantize\": \"quantize\",\n      \"SoftMax\": \"softmax\",\n      \"reduce\": \"reduce\",\n      \"_fwd_|create_flash|::mla::|KVCache\": \"attn\",\n      \"CUDA mem\": \"non-gpu-H_D_memops\",\n      \".*\": \"misc\"\n    },\n    \"gpt-oss\": {\n      \"gemm|nvjet\": \"gemm\",\n      \"fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_|matmul_ogs_|_topk_forward|_combined_routing|_sum_bitmatrix_rows|_compute_writeback_idx\": \"moe_gemm\",\n      \"moe|sigmoid\": \"moe\",\n      \"CatArrayBatched|prepare_inputs\": \"prepare_next\",\n      \"_norm_|Norm\": \"norm\",\n      \"ncclDevKernel|cross_device_reduce|allreduce\": 
\"nccl_and_custom_ar\",\n      \"topk|TopK\": \"topk\",\n      \"act_and_mul_\": \"activation\",\n      \"Rotary\": \"rope\",\n      \"SoftMax\": \"softmax\",\n      \"flash|fmha\": \"attn\",\n      \"elementwise\": \"elementwise\",\n      \"fp8_quant|cvt_|quantize\": \"quantize\",\n      \"reduce_kernel\": \"reduce\",\n      \"triton\": \"triton_kernel\",\n      \"CUDA mem\": \"non-gpu-H_D_memops\",\n      \".*\": \"misc\"\n    }\n  }\n}\n"
  },
  {
    "path": "examples/runtime/README.md",
    "content": "# Runtime examples\n\nMost of the examples below require you to start a server in a separate terminal before running them. See the code for detailed instructions.\n\n## Native API\n\n* `lora.py`: An example of how to use LoRA adapters.\n* `multimodal_embedding.py`: An example of how to perform [multimodal embedding](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct).\n* `openai_batch_chat.py`: An example of how to process batch requests for chat completions.\n* `openai_batch_complete.py`: An example of how to process batch requests for text completions.\n* **`openai_chat_with_response_prefill.py`**:\n  An example that demonstrates how to [prefill a response](https://eugeneyan.com/writing/prompting/#prefill-claudes-responses) using the OpenAI API by enabling the `continue_final_message` parameter.\n  When enabled, the final (partial) assistant message is removed and its content is used as a prefill so that the model continues that message rather than starting a new turn. See [Anthropic's prefill example](https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/prefill-claudes-response#example-structured-data-extraction-with-prefilling) for more context.\n* `reward_model.py`: An example of how to extract scores from a reward model.\n* `vertex_predict.py`: An example of how to deploy a model to [Vertex AI](https://cloud.google.com/vertex-ai?hl=en).\n\n## Engine\n\nThe `engine` folder contains examples that show how to use the [Offline Engine API](https://docs.sglang.io/basic_usage/offline_engine_api.html#Offline-Engine-API) for common workflows.\n\n* `custom_server.py`: An example of how to deploy a custom server.\n* `embedding.py`: An example of how to extract embeddings.\n* `launch_engine.py`: An example of how to launch the Engine.\n* `offline_batch_inference_eagle.py`: An example of how to perform speculative decoding using [EAGLE](https://docs.sglang.io/advanced_features/speculative_decoding.html).\n* `offline_batch_inference_torchrun.py`: An example of how to perform inference using [torchrun](https://pytorch.org/docs/stable/elastic/run.html).\n* `offline_batch_inference_vlm.py`: An example of how to use VLMs with the engine.\n* `offline_batch_inference.py`: An example of how to use the engine to perform inference on a batch of examples.\n\n## Hidden States\n\nThe `hidden_states` folder contains examples of how to extract hidden states using SGLang. Note that this may degrade throughput due to CUDA graph rebuilding.\n\n* `hidden_states_engine.py`: An example of how to extract hidden states using the Engine API.\n* `hidden_states_server.py`: An example of how to extract hidden states using the Server API.\n\n## Multimodal\n\nSGLang supports multimodal inputs for various model architectures. The `multimodal` folder contains examples showing how to use URLs, files, or encoded data to make requests to multimodal models. Examples include querying the [Llava-OneVision](multimodal/llava_onevision_server.py) model (image, multi-image, video), the Llava-backed [Qwen-Llava](multimodal/qwen_llava_server.py) and [Llama3-Llava](multimodal/llama3_llava_server.py) models (image, multi-image), and Mistral AI's [Pixtral](multimodal/pixtral_server.py) (image, multi-image).\n\n## Token In, Token Out\n\nThe `token_in_token_out` folder shows how to perform inference by providing tokens as input and receiving tokens as output.\n\n* `token_in_token_out_{llm|vlm}_{engine|server}.py`: Shows how to perform a token-in, token-out workflow for LLMs/VLMs using either the Engine API or the native server API.\n"
  },
  {
    "path": "examples/runtime/engine/custom_server.py",
    "content": "from sanic import Sanic, text\nfrom sanic.response import json\n\nimport sglang as sgl\n\nengine = None\n\n# Create an instance of the Sanic app\napp = Sanic(\"sanic-server\")\n\n\n# Define an asynchronous route handler\n@app.route(\"/generate\", methods=[\"POST\"])\nasync def generate(request):\n    # request.json is None when the body is empty, so fall back to an empty dict\n    prompt = (request.json or {}).get(\"prompt\")\n    if not prompt:\n        return json({\"error\": \"Prompt is required\"}, status=400)\n\n    # Without streaming, async_generate returns a single dict\n    result = await engine.async_generate(prompt)\n\n    return text(result[\"text\"])\n\n\n@app.route(\"/generate_stream\", methods=[\"POST\"])\nasync def generate_stream(request):\n    prompt = (request.json or {}).get(\"prompt\")\n\n    if not prompt:\n        return json({\"error\": \"Prompt is required\"}, status=400)\n\n    # With stream=True, async_generate returns an async generator of dicts\n    result = await engine.async_generate(prompt, stream=True)\n\n    # https://sanic.dev/en/guide/advanced/streaming.md#streaming\n    # Start the streaming response\n    response = await request.respond()\n\n    # Send each chunk's text to the client as it arrives\n    async for chunk in result:\n        await response.send(chunk[\"text\"])\n\n    await response.eof()\n\n\ndef run_server():\n    global engine\n    engine = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n    app.run(host=\"0.0.0.0\", port=8000, single_process=True)\n\n\nif __name__ == \"__main__\":\n    run_server()\n"
  },
  {
    "path": "examples/runtime/engine/embedding.py",
    "content": "import sglang as sgl\n\n\ndef main():\n    # Sample prompts.\n    prompts = [\n        \"Hello, my name is\",\n        \"The president of the United States is\",\n        \"The capital of France is\",\n        \"The future of AI is\",\n    ]\n    # Create an LLM.\n    llm = sgl.Engine(\n        model_path=\"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", is_embedding=True\n    )\n\n    outputs = llm.encode(prompts)\n    # Print the outputs (embedding vectors)\n    for prompt, output in zip(prompts, outputs):\n        print(\"===============================\")\n        print(f\"Prompt: {prompt}\\nEmbedding vector: {output['embedding']}\")\n\n\n# The __main__ guard is necessary here because we use \"spawn\" to create subprocesses.\n# Spawn starts a fresh program every time; without the __main__ guard, sgl.Engine would keep spawning new processes in an infinite loop.\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/runtime/engine/fastapi_engine_inference.py",
    "content": "\"\"\"\nFastAPI server example for text generation using SGLang Engine and demonstrating client usage.\n\nStarts the server, sends requests to it, and prints responses.\n\nUsage:\npython fastapi_engine_inference.py --model-path Qwen/Qwen2.5-0.5B-Instruct --tp_size 1 --host 127.0.0.1 --port 8000 [--startup-timeout 60]\n\"\"\"\n\nimport os\nimport subprocess\nimport time\nfrom contextlib import asynccontextmanager\n\nimport requests\nfrom fastapi import FastAPI, Request\nfrom fastapi.responses import JSONResponse\n\nimport sglang as sgl\nfrom sglang.utils import terminate_process\n\nengine = None\n\n\n# Use FastAPI's lifespan manager to initialize/shutdown the engine\n@asynccontextmanager\nasync def lifespan(app: FastAPI):\n    \"\"\"Manages SGLang engine initialization during server startup.\"\"\"\n    global engine\n    # Initialize the SGLang engine when the server starts\n    # Adjust model_path and other engine arguments as needed\n    print(\"Loading SGLang engine...\")\n    engine = sgl.Engine(\n        model_path=os.getenv(\"MODEL_PATH\"), tp_size=int(os.getenv(\"TP_SIZE\"))\n    )\n    print(\"SGLang engine loaded.\")\n    yield\n    # Clean up engine resources when the server stops (optional, depends on engine needs)\n    print(\"Shutting down SGLang engine...\")\n    # engine.shutdown() # Or other cleanup if available/necessary\n    print(\"SGLang engine shutdown.\")\n\n\napp = FastAPI(lifespan=lifespan)\n\n\n@app.post(\"/generate\")\nasync def generate_text(request: Request):\n    \"\"\"FastAPI endpoint to handle text generation requests.\"\"\"\n    global engine\n    if not engine:\n        # Use JSONResponse so the HTTP status code is actually set\n        return JSONResponse(status_code=503, content={\"error\": \"Engine not initialized\"})\n\n    try:\n        data = await request.json()\n        prompt = data.get(\"prompt\")\n        max_new_tokens = data.get(\"max_new_tokens\", 128)\n        temperature = data.get(\"temperature\", 0.7)\n\n        if not prompt:\n            return JSONResponse(status_code=400, content={\"error\": \"Prompt is required\"})\n\n        # Use async_generate for non-blocking generation\n        state = await engine.async_generate(\n            prompt,\n            sampling_params={\n                \"max_new_tokens\": max_new_tokens,\n                \"temperature\": temperature,\n            },\n            # Add other parameters like stop, top_p etc. as needed\n        )\n\n        return {\"generated_text\": state[\"text\"]}\n    except Exception as e:\n        return JSONResponse(status_code=500, content={\"error\": str(e)})\n\n\n# Helper function to start the server\ndef start_server(args, timeout=60):\n    \"\"\"Starts the Uvicorn server as a subprocess and waits for it to be ready.\"\"\"\n    base_url = f\"http://{args.host}:{args.port}\"\n    command = [\n        \"python\",\n        \"-m\",\n        \"uvicorn\",\n        \"fastapi_engine_inference:app\",\n        f\"--host={args.host}\",\n        f\"--port={args.port}\",\n    ]\n\n    process = subprocess.Popen(command, stdout=None, stderr=None)\n\n    start_time = time.perf_counter()\n    with requests.Session() as session:\n        while time.perf_counter() - start_time < timeout:\n            try:\n                # Check the /docs endpoint which FastAPI provides by default\n                response = session.get(\n                    f\"{base_url}/docs\", timeout=5\n                )  # Add a request timeout\n                if response.status_code == 200:\n                    print(f\"Server {base_url} is ready (responded on /docs)\")\n                    return process\n            except requests.ConnectionError:\n                # Specific exception for connection refused/DNS error etc.\n                pass\n            except requests.Timeout:\n                # Specific exception for request timeout\n                print(f\"Health check to {base_url}/docs timed out, retrying...\")\n                pass\n            except requests.RequestException as e:\n                # Catch other request exceptions\n                print(f\"Health check request error: {e}, retrying...\")\n                pass\n 
           # Use a shorter sleep interval for faster startup detection\n            time.sleep(1)\n\n    # If loop finishes, raise the timeout error\n    # Attempt to terminate the failed process before raising\n    if process:\n        print(\n            \"Server failed to start within timeout, attempting to terminate process...\"\n        )\n        terminate_process(process)  # Use the imported terminate_process\n    raise TimeoutError(\n        f\"Server failed to start at {base_url} within the timeout period.\"\n    )\n\n\ndef send_requests(server_url, prompts, max_new_tokens, temperature):\n    \"\"\"Sends generation requests to the running server for a list of prompts.\"\"\"\n    # Iterate through prompts and send requests\n    for i, prompt in enumerate(prompts):\n        print(f\"\\n[{i+1}/{len(prompts)}] Sending prompt: '{prompt}'\")\n        payload = {\n            \"prompt\": prompt,\n            \"max_new_tokens\": max_new_tokens,\n            \"temperature\": temperature,\n        }\n\n        try:\n            response = requests.post(f\"{server_url}/generate\", json=payload, timeout=60)\n\n            result = response.json()\n\n            print(f\"Prompt: {prompt}\\nResponse: {result['generated_text']}\")\n\n        except requests.exceptions.Timeout:\n            print(f\"  Error: Request timed out for prompt '{prompt}'\")\n        except requests.exceptions.RequestException as e:\n            print(f\"  Error sending request for prompt '{prompt}': {e}\")\n\n\nif __name__ == \"__main__\":\n    \"\"\"Main entry point for the script.\"\"\"\n\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--host\", type=str, default=\"127.0.0.1\")\n    parser.add_argument(\"--port\", type=int, default=8000)\n    parser.add_argument(\"--model-path\", type=str, default=\"Qwen/Qwen2.5-0.5B-Instruct\")\n    parser.add_argument(\"--tp_size\", type=int, default=1)\n    parser.add_argument(\n        \"--startup-timeout\",\n      
  type=int,\n        default=60,\n        help=\"Time in seconds to wait for the server to be ready (default: %(default)s)\",\n    )\n    args = parser.parse_args()\n\n    # Pass the model to the child uvicorn process via an env var\n    os.environ[\"MODEL_PATH\"] = args.model_path\n    os.environ[\"TP_SIZE\"] = str(args.tp_size)\n\n    # Start the server\n    process = start_server(args, timeout=args.startup_timeout)\n\n    # Define the prompts and sampling parameters\n    prompts = [\n        \"Hello, my name is\",\n        \"The president of the United States is\",\n        \"The capital of France is\",\n        \"The future of AI is\",\n    ]\n    max_new_tokens = 64\n    temperature = 0.1\n\n    # Define server url\n    server_url = f\"http://{args.host}:{args.port}\"\n\n    # Send requests to the server\n    send_requests(server_url, prompts, max_new_tokens, temperature)\n\n    # Terminate the server process\n    terminate_process(process)\n"
  },
  {
    "path": "examples/runtime/engine/launch_engine.py",
    "content": "\"\"\"\nThis example demonstrates how to launch the offline engine.\n\"\"\"\n\nimport sglang as sgl\n\n\ndef main():\n    llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n    llm.generate(\"What is the capital of France?\")\n    llm.shutdown()\n\n\n# The __main__ guard is necessary here because we use \"spawn\" to create subprocesses.\n# Spawn starts a fresh program every time; without the __main__ guard, sgl.Engine would keep spawning new processes in an infinite loop.\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/runtime/engine/offline_batch_inference.py",
    "content": "\"\"\"\nUsage:\npython3 offline_batch_inference.py --model meta-llama/Llama-3.1-8B-Instruct\n\"\"\"\n\nimport argparse\nimport dataclasses\n\nimport sglang as sgl\nfrom sglang.srt.server_args import ServerArgs\n\n\ndef main(\n    server_args: ServerArgs,\n):\n    # Sample prompts.\n    prompts = [\n        \"Hello, my name is\",\n        \"The president of the United States is\",\n        \"The capital of France is\",\n        \"The future of AI is\",\n    ]\n    # Create a sampling params object.\n    sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n\n    # Create an LLM.\n    llm = sgl.Engine(**dataclasses.asdict(server_args))\n\n    outputs = llm.generate(prompts, sampling_params)\n    # Print the outputs.\n    for prompt, output in zip(prompts, outputs):\n        print(\"===============================\")\n        print(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")\n\n\n# The __main__ guard is necessary here because we use \"spawn\" to create subprocesses.\n# Spawn starts a fresh program every time; without the __main__ guard, sgl.Engine would keep spawning new processes in an infinite loop.\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    ServerArgs.add_cli_args(parser)\n    args = parser.parse_args()\n    server_args = ServerArgs.from_cli_args(args)\n    main(server_args)\n"
  },
  {
    "path": "examples/runtime/engine/offline_batch_inference_async.py",
    "content": "\"\"\"\nUsage:\npython offline_batch_inference_async.py --model-path Qwen/Qwen2-VL-7B-Instruct\n\nNote:\nThis demo shows how to use async generation,\nwhich is useful for implementing online-style generation on top of batched inference.\n\"\"\"\n\nimport argparse\nimport asyncio\nimport dataclasses\n\nimport sglang as sgl\nfrom sglang.srt.server_args import ServerArgs\n\n\nclass InferenceEngine:\n    def __init__(self, **kwargs):\n        self.engine = sgl.Engine(**kwargs)\n\n    async def generate(self, prompt, sampling_params):\n        result = await self.engine.async_generate(prompt, sampling_params)\n        return result\n\n\nasync def run_server(server_args):\n    inference = InferenceEngine(**dataclasses.asdict(server_args))\n\n    # Sample prompts.\n    prompts = [\n        \"Hello, my name is\",\n        \"The president of the United States is\",\n        \"The capital of France is\",\n        \"The future of AI is\",\n    ] * 100\n\n    # Create a sampling params object.\n    sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n\n    # Run the generation tasks concurrently in async mode.\n    tasks = []\n    for prompt in prompts:\n        task = asyncio.create_task(inference.generate(prompt, sampling_params))\n        tasks.append(task)\n\n    # Await the tasks in submission order and print each result\n    for task in tasks:\n        result = await task\n        print(f\"Generated text: {result['text']}\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    ServerArgs.add_cli_args(parser)\n    args = parser.parse_args()\n    server_args = ServerArgs.from_cli_args(args)\n    asyncio.run(run_server(server_args))\n"
  },
  {
    "path": "examples/runtime/engine/offline_batch_inference_eagle.py",
    "content": "import sglang as sgl\n\n\ndef main():\n    # Sample prompts.\n    prompts = [\n        \"Hello, my name is\",\n        \"The president of the United States is\",\n        \"The capital of France is\",\n        \"The future of AI is\",\n    ]\n\n    # Create a sampling params object.\n    sampling_params = {\"temperature\": 0, \"max_new_tokens\": 30}\n\n    # Create an LLM.\n    llm = sgl.Engine(\n        model_path=\"meta-llama/Llama-2-7b-chat-hf\",\n        speculative_algorithm=\"EAGLE\",\n        speculative_draft_model_path=\"lmsys/sglang-EAGLE-llama2-chat-7B\",\n        speculative_num_steps=3,\n        speculative_eagle_topk=4,\n        speculative_num_draft_tokens=16,\n        cuda_graph_max_bs=8,\n    )\n\n    outputs = llm.generate(prompts, sampling_params)\n\n    # Print the outputs.\n    for prompt, output in zip(prompts, outputs):\n        print(\"===============================\")\n        print(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")\n\n\n# The __main__ guard is necessary here because we use \"spawn\" to create subprocesses.\n# Spawn starts a fresh program every time; without the __main__ guard, sgl.Engine would keep spawning new processes in an infinite loop.\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/runtime/engine/offline_batch_inference_qwen_1m.py",
    "content": "\"\"\"\nUsage:\npython3 offline_batch_inference_qwen_1m.py\n\"\"\"\n\nfrom urllib.request import urlopen\n\nimport sglang as sgl\n\n\ndef load_prompt() -> str:\n    # Test cases with various lengths can be found at:\n    #\n    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt\n    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt\n    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt\n    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt\n\n    with urlopen(\n        \"https://qianwen-res.oss-cn-beijing.aliyuncs.com\"\n        \"/Qwen2.5-1M/test-data/64k.txt\",\n        timeout=5,\n    ) as response:\n        prompt = response.read().decode(\"utf-8\")\n    return prompt\n\n\n# Process the prompts.\ndef process_requests(llm: sgl.Engine, prompts: list[str]) -> None:\n    # Create a sampling params object.\n    sampling_params = {\n        \"temperature\": 0.7,\n        \"top_p\": 0.8,\n        \"top_k\": 20,\n        \"repetition_penalty\": 1.05,\n        \"max_new_tokens\": 256,\n    }\n    # Generate texts from the prompts.\n    outputs = llm.generate(prompts, sampling_params)\n    # Print the outputs.\n    for output in outputs:\n        # prompt_tokens is the number of prompt tokens, not the token IDs\n        num_prompt_tokens = output[\"meta_info\"][\"prompt_tokens\"]\n        generated_text = output[\"text\"]\n        print(\n            f\"Prompt length: {num_prompt_tokens}, \" f\"Generated text: {generated_text!r}\"\n        )\n\n\n# Create an LLM.\ndef initialize_engine() -> sgl.Engine:\n    llm = sgl.Engine(\n        model_path=\"Qwen/Qwen2.5-7B-Instruct-1M\",\n        context_length=1048576,\n        page_size=256,\n        attention_backend=\"dual_chunk_flash_attn\",\n        tp_size=4,\n        disable_radix_cache=True,\n        enable_mixed_chunk=False,\n        enable_torch_compile=False,\n        chunked_prefill_size=131072,\n        mem_fraction_static=0.6,\n        log_level=\"DEBUG\",\n    )\n    return llm\n\n\ndef main():\n    llm = initialize_engine()\n    prompt = load_prompt()\n    process_requests(llm, [prompt])\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/runtime/engine/offline_batch_inference_vlm.py",
    "content": "\"\"\"\nUsage:\npython offline_batch_inference_vlm.py --model-path Qwen/Qwen2-VL-7B-Instruct\n\"\"\"\n\nimport argparse\nimport dataclasses\n\nimport sglang as sgl\nfrom sglang.srt.parser.conversation import chat_templates\nfrom sglang.srt.server_args import ServerArgs\n\n\ndef main(\n    server_args: ServerArgs,\n):\n    vlm = sgl.Engine(**dataclasses.asdict(server_args))\n\n    conv = chat_templates[server_args.chat_template].copy()\n    image_token = conv.image_token\n\n    image_url = \"https://github.com/sgl-project/sglang/blob/main/examples/assets/example_image.png?raw=true\"\n\n    prompt = f\"What's in this image?\\n{image_token}\"\n\n    sampling_params = {\n        \"temperature\": 0.001,\n        \"max_new_tokens\": 30,\n    }\n\n    output = vlm.generate(\n        prompt=prompt,\n        image_data=image_url,\n        sampling_params=sampling_params,\n    )\n\n    print(\"===============================\")\n    print(f\"Prompt: {prompt}\")\n    print(f\"Generated text: {output['text']}\")\n\n    vlm.shutdown()\n\n\n# The __main__ guard is necessary here because we use \"spawn\" to create subprocesses.\n# Spawn starts a fresh program every time; without the __main__ guard, sgl.Engine would keep spawning new processes in an infinite loop.\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    ServerArgs.add_cli_args(parser)\n    args = parser.parse_args()\n\n    server_args = ServerArgs.from_cli_args(args)\n    main(server_args)\n"
  },
  {
    "path": "examples/runtime/engine/readme.md",
    "content": "# SGLang Engine\n\nSGLang provides a direct inference engine without the need for an HTTP server. Typical use cases include:\n\n- [Offline Batch Inference](#offline-batch-inference)\n- [Embedding Generation](#embedding-generation)\n- [Custom Server](#custom-server)\n- [Token-In-Token-Out for RLHF](#token-in-token-out-for-rlhf)\n- [Inference Using FastAPI](#inference-using-fastapi)\n\n## Examples\n\n### [Offline Batch Inference](./offline_batch_inference.py)\n\nIn this example, we launch an SGLang engine and feed a batch of inputs for inference. If you provide a very large batch, the engine schedules the requests so they are processed efficiently without running into OOM (out-of-memory) errors.\n\n### [Embedding Generation](./embedding.py)\n\nIn this example, we launch an SGLang engine and feed a batch of inputs for embedding generation.\n\n### [Custom Server](./custom_server.py)\n\nThis example demonstrates how to create a custom server on top of the SGLang Engine. We use [Sanic](https://sanic.dev/en/) as an example. The server supports both non-streaming and streaming endpoints.\n\n#### Steps\n\n1. Install Sanic:\n\n   ```bash\n   pip install sanic\n   ```\n\n2. Run the server:\n\n   ```bash\n   python custom_server.py\n   ```\n\n3. Send requests:\n\n   ```bash\n   curl -X POST http://localhost:8000/generate  -H \"Content-Type: application/json\"  -d '{\"prompt\": \"The Transformer architecture is...\"}'\n   curl -X POST http://localhost:8000/generate_stream  -H \"Content-Type: application/json\"  -d '{\"prompt\": \"The Transformer architecture is...\"}' --no-buffer\n   ```\n\n   This will send both non-streaming and streaming requests to the server.\n\n### [Token-In-Token-Out for RLHF](../token_in_token_out)\n\nIn this example, we launch an SGLang engine, feed tokens as input, and generate tokens as output.\n\n### [Inference Using FastAPI](fastapi_engine_inference.py)\n\nThis example demonstrates how to create a FastAPI server that uses the SGLang engine for text generation.\n"
  },
  {
    "path": "examples/runtime/engine/save_remote_state.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nSaves each worker's model state dict directly to remote storage, which enables a\nfast load path for large tensor-parallel models where each worker only needs to\nread its own shard rather than the entire checkpoint.\n\nExample usage:\n\npython save_remote_state.py \\\n    --model-path /path/to/load \\\n    --tensor-parallel-size 8 \\\n    --remote-model-save-url [protocol]://[host]:[port]/[model_name]\n\nThen, the model can be loaded with\n\nllm = Engine(\n    model_path=\"[protocol]://[host]:[port]/[model_name]\",\n    tp_size=8,\n)\n\"\"\"\n\nimport dataclasses\nfrom argparse import ArgumentParser\nfrom pathlib import Path\n\nfrom sglang import Engine, ServerArgs\n\nparser = ArgumentParser()\nServerArgs.add_cli_args(parser)\n\nparser.add_argument(\n    \"--remote-model-save-url\",\n    required=True,\n    type=str,\n    help=\"remote address to store model weights\",\n)\nparser.add_argument(\n    \"--remote-draft-model-save-url\",\n    default=None,\n    type=str,\n    help=\"remote address to store draft model weights\",\n)\n\n\ndef main(args):\n    engine_args = ServerArgs.from_cli_args(args)\n    model_path = engine_args.model_path\n    if not Path(model_path).is_dir():\n        raise ValueError(\"model path must be a local directory\")\n    # Create LLM instance from arguments\n    llm = Engine(**dataclasses.asdict(engine_args))\n    llm.save_remote_model(\n        url=args.remote_model_save_url, draft_url=args.remote_draft_model_save_url\n    )\n    print(\"Saved remote (draft) model successfully\")\n\n\nif __name__ == \"__main__\":\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/runtime/engine/save_sharded_state.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nSaves each worker's model state dict directly to a checkpoint, which enables a\nfast load path for large tensor-parallel models where each worker only needs to\nread its own shard rather than the entire checkpoint.\n\nExample usage:\n\npython save_sharded_state.py \\\n    --model-path /path/to/load \\\n    --quantization deepspeedfp \\\n    --tensor-parallel-size 8 \\\n    --output /path/to/save\n\nThen, the model can be loaded with\n\nllm = Engine(\n    model_path=\"/path/to/save\",\n    load_format=\"sharded_state\",\n    quantization=\"deepspeedfp\",\n    tp_size=8,\n)\n\"\"\"\n\nimport dataclasses\nimport os\nimport shutil\nfrom argparse import ArgumentParser\nfrom pathlib import Path\n\nfrom sglang import Engine, ServerArgs\n\nparser = ArgumentParser()\nServerArgs.add_cli_args(parser)\n\nparser.add_argument(\n    \"--output\", \"-o\", required=True, type=str, help=\"path to output checkpoint\"\n)\nparser.add_argument(\n    \"--file-pattern\", type=str, help=\"string pattern of saved filenames\"\n)\nparser.add_argument(\n    \"--max-file-size\",\n    type=int,\n    default=5 * 1024**3,\n    help=\"max size (in bytes) of each safetensors file\",\n)\n\n\ndef main(args):\n    engine_args = ServerArgs.from_cli_args(args)\n    model_path = engine_args.model_path\n    if not Path(model_path).is_dir():\n        raise ValueError(\"model path must be a local directory\")\n    # Create LLM instance from arguments\n    llm = Engine(**dataclasses.asdict(engine_args))\n    Path(args.output).mkdir(exist_ok=True)\n    llm.save_sharded_model(\n        path=args.output, pattern=args.file_pattern, max_size=args.max_file_size\n    )\n\n    # Copy metadata files to output directory\n    for file in os.listdir(model_path):\n        if os.path.splitext(file)[1] not in (\".bin\", \".pt\", \".safetensors\"):\n            if os.path.isdir(os.path.join(model_path, file)):\n                shutil.copytree(\n                    os.path.join(model_path, file), os.path.join(args.output, file)\n                )\n            else:\n                shutil.copy(os.path.join(model_path, file), args.output)\n\n\nif __name__ == \"__main__\":\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/runtime/hidden_states/hidden_states_engine.py",
    "content": "\"\"\"\nUsage:\npython hidden_states_engine.py\n\nNote that each time you change the `return_hidden_states` parameter,\nthe CUDA graph will be recaptured, which might lead to a performance hit.\nSo avoid alternating between requests with and without hidden states.\n\"\"\"\n\nimport torch\n\nimport sglang as sgl\n\n\ndef main():\n    prompts = [\n        \"Hello, my name is\",\n        \"The president of the United States is\",\n        \"The capital of France is\",\n        \"The future of AI is\",\n    ]\n    # Create an LLM.\n    llm = sgl.Engine(\n        model_path=\"Alibaba-NLP/gte-Qwen2-1.5B-instruct\",\n        enable_return_hidden_states=True,\n    )\n\n    sampling_params = {\n        \"temperature\": 0.8,\n        \"top_p\": 0.95,\n        \"max_new_tokens\": 10,\n    }\n\n    outputs = llm.generate(\n        prompts, sampling_params=sampling_params, return_hidden_states=True\n    )\n\n    llm.shutdown()\n\n    for prompt, output in zip(prompts, outputs):\n        for i in range(len(output[\"meta_info\"][\"hidden_states\"])):\n            output[\"meta_info\"][\"hidden_states\"][i] = torch.tensor(\n                output[\"meta_info\"][\"hidden_states\"][i], dtype=torch.bfloat16\n            )\n        print(\"===============================\")\n        print(\n            f\"Prompt: {prompt}\\n\"\n            f\"Generated text: {output['text']}\\n\"\n            f\"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\\t\"\n            f\"Completion_tokens: {output['meta_info']['completion_tokens']}\"\n        )\n        print(\"Hidden states: \")\n        hidden_states = torch.cat(\n            [\n                i.unsqueeze(0) if len(i.shape) == 1 else i\n                for i in output[\"meta_info\"][\"hidden_states\"]\n            ]\n        )\n        print(hidden_states)\n        print()\n\n\n# The __main__ guard is necessary here because we use \"spawn\" to create subprocesses.\n# Spawn starts a fresh program every time; without the __main__ guard, sgl.Engine would keep spawning new processes in an infinite loop.\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/runtime/hidden_states/hidden_states_server.py",
    "content": "\"\"\"\nUsage:\n\npython hidden_states_server.py\n\nNote that each time you change the `return_hidden_states` parameter,\nthe CUDA graph will be recaptured, which might lead to a performance hit.\nSo avoid alternating between requests with and without hidden states.\n\"\"\"\n\nimport requests\nimport torch\n\nfrom sglang.test.test_utils import is_in_ci\nfrom sglang.utils import terminate_process, wait_for_server\n\nif is_in_ci():\n    from docs.backend.patch import launch_server_cmd\nelse:\n    from sglang.utils import launch_server_cmd\n\n\ndef main():\n    # Launch the server\n    server_process, port = launch_server_cmd(\n        \"python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct --enable-return-hidden-states --host 0.0.0.0\"\n    )\n    wait_for_server(f\"http://localhost:{port}\", process=server_process)\n\n    prompts = [\n        \"Hello, my name is\",\n        \"The president of the United States is\",\n        \"The capital of France is\",\n        \"The future of AI is\",\n    ]\n\n    sampling_params = {\n        \"temperature\": 0.8,\n        \"top_p\": 0.95,\n        \"max_new_tokens\": 10,\n    }\n\n    json_data = {\n        \"text\": prompts,\n        \"sampling_params\": sampling_params,\n        \"return_hidden_states\": True,\n    }\n\n    response = requests.post(\n        f\"http://localhost:{port}/generate\",\n        json=json_data,\n    )\n\n    terminate_process(server_process)\n\n    outputs = response.json()\n    for prompt, output in zip(prompts, outputs):\n        for i in range(len(output[\"meta_info\"][\"hidden_states\"])):\n            output[\"meta_info\"][\"hidden_states\"][i] = torch.tensor(\n                output[\"meta_info\"][\"hidden_states\"][i], dtype=torch.bfloat16\n            )\n        print(\"===============================\")\n        print(\n            f\"Prompt: {prompt}\\n\"\n            f\"Generated text: {output['text']}\\n\"\n            f\"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\\t\"\n            f\"Completion_tokens: {output['meta_info']['completion_tokens']}\"\n        )\n        print(\"Hidden states: \")\n        hidden_states = torch.cat(\n            [\n                i.unsqueeze(0) if len(i.shape) == 1 else i\n                for i in output[\"meta_info\"][\"hidden_states\"]\n            ]\n        )\n        print(hidden_states)\n        print()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/runtime/lora.py",
    "content": "\"\"\"\nOpenAI-compatible LoRA adapter usage with SGLang.\n\nServer Setup:\n    python -m sglang.launch_server \\\\\n        --model meta-llama/Llama-3.1-8B-Instruct \\\\\n        --enable-lora \\\\\n        --lora-paths sql=/path/to/sql python=/path/to/python\n\"\"\"\n\nimport openai\n\nclient = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"EMPTY\")\n\n\ndef main():\n    print(\"SGLang OpenAI-Compatible LoRA Examples\\n\")\n\n    # Example 1: NEW - Adapter in model parameter (OpenAI-compatible)\n    print(\"1. Chat with LoRA adapter in model parameter:\")\n    response = client.chat.completions.create(\n        model=\"meta-llama/Llama-3.1-8B-Instruct:sql\",  # ← model:adapter syntax\n        messages=[{\"role\": \"user\", \"content\": \"Convert to SQL: show all users\"}],\n        max_tokens=50,\n    )\n    print(f\"   Response: {response.choices[0].message.content}\\n\")\n\n    # Example 2: Completions API with adapter\n    print(\"2. Completion with LoRA adapter:\")\n    response = client.completions.create(\n        model=\"meta-llama/Llama-3.1-8B-Instruct:python\",\n        prompt=\"def fibonacci(n):\",\n        max_tokens=50,\n    )\n    print(f\"   Response: {response.choices[0].text}\\n\")\n\n    # Example 3: OLD - Backward compatible with explicit lora_path\n    print(\"3. Backward compatible (explicit lora_path):\")\n    response = client.chat.completions.create(\n        model=\"meta-llama/Llama-3.1-8B-Instruct\",\n        messages=[{\"role\": \"user\", \"content\": \"Convert to SQL: show all users\"}],\n        extra_body={\"lora_path\": \"sql\"},\n        max_tokens=50,\n    )\n    print(f\"   Response: {response.choices[0].message.content}\\n\")\n\n    # Example 4: Base model (no adapter)\n    print(\"4. Base model without adapter:\")\n    response = client.chat.completions.create(\n        model=\"meta-llama/Llama-3.1-8B-Instruct\",\n        messages=[{\"role\": \"user\", \"content\": \"Hello!\"}],\n        max_tokens=30,\n    )\n    print(f\"   Response: {response.choices[0].message.content}\\n\")\n\n    print(\"All examples completed!\")\n\n\nif __name__ == \"__main__\":\n    try:\n        main()\n    except Exception as e:\n        print(f\"Error: {e}\")\n        print(\n            \"\\nEnsure server is running:\\n\"\n            \"  python -m sglang.launch_server --model ... --enable-lora --lora-paths ...\"\n        )\n"
  },
  {
    "path": "examples/runtime/multimodal/llama3_llava_server.py",
    "content": "\"\"\"\nUsage:\n# Installing latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git\n# Installing latest sglang.\n\n# Endpoint Service CLI:\npython -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000\n\npython3 llama3_llava_server.py\n\nOutput:\n\"Friends posing for a fun photo with a life-sized teddy bear, creating a playful and memorable moment.\"\n\"\"\"\n\nimport argparse\nimport asyncio\nimport copy\nimport json\n\nimport aiohttp\nimport requests\nfrom llava.conversation import conv_llava_llama_3\n\nfrom sglang.utils import normalize_base_url\n\n\nasync def send_request(url, data, delay=0):\n    await asyncio.sleep(delay)\n    async with aiohttp.ClientSession() as session:\n        async with session.post(url, json=data) as resp:\n            output = await resp.json()\n    return output\n\n\nasync def test_concurrent(args):\n    url = normalize_base_url(args.host, args.port)\n\n    prompt = \"<image>\\nPlease generate caption towards this image.\"\n    conv_template = copy.deepcopy(conv_llava_llama_3)\n    conv_template.append_message(role=conv_template.roles[0], message=prompt)\n    conv_template.append_message(role=conv_template.roles[1], message=None)\n    prompt_with_template = conv_template.get_prompt()\n    response = []\n    for i in range(1):\n        response.append(\n            send_request(\n                url + \"/generate\",\n                {\n                    \"text\": prompt_with_template,\n                    \"image_data\": \"https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg\",\n                    \"sampling_params\": {\n                        \"max_new_tokens\": 1024,\n                        \"temperature\": 0,\n                        \"top_p\": 1.0,\n                        \"presence_penalty\": 2,\n                        \"frequency_penalty\": 2,\n                        \"stop\": \"<|eot_id|>\",\n                    },\n                },\n         
   )\n        )\n\n    rets = await asyncio.gather(*response)\n    for ret in rets:\n        print(ret[\"text\"])\n\n\ndef test_streaming(args):\n    url = normalize_base_url(args.host, args.port)\n    prompt = \"<image>\\nPlease generate caption towards this image.\"\n    conv_template = copy.deepcopy(conv_llava_llama_3)\n    conv_template.append_message(role=conv_template.roles[0], message=prompt)\n    conv_template.append_message(role=conv_template.roles[1], message=None)\n    prompt_with_template = conv_template.get_prompt()\n    pload = {\n        \"text\": prompt_with_template,\n        \"sampling_params\": {\n            \"max_new_tokens\": 1024,\n            \"temperature\": 0,\n            \"top_p\": 1.0,\n            \"presence_penalty\": 2,\n            \"frequency_penalty\": 2,\n            \"stop\": \"<|eot_id|>\",\n        },\n        \"image_data\": \"https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg\",\n        \"stream\": True,\n    }\n    response = requests.post(\n        url + \"/generate\",\n        json=pload,\n        stream=True,\n    )\n\n    prev = 0\n    for chunk in response.iter_lines(decode_unicode=False):\n        chunk = chunk.decode(\"utf-8\")\n        if chunk and chunk.startswith(\"data:\"):\n            if chunk == \"data: [DONE]\":\n                break\n            data = json.loads(chunk[5:].strip(\"\\n\"))\n            output = data[\"text\"].strip()\n            print(output[prev:], end=\"\", flush=True)\n            prev = len(output)\n    print(\"\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--host\", type=str, default=\"127.0.0.1\")\n    parser.add_argument(\"--port\", type=int, default=30000)\n    args = parser.parse_args()\n    asyncio.run(test_concurrent(args))\n    test_streaming(args)\n"
  },
  {
    "path": "examples/runtime/multimodal/llava_onevision_server.py",
    "content": "\"\"\"\nUsage:\n\npython3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8\n\npython3 llava_onevision_server.py\n\"\"\"\n\nimport io\nimport os\nimport sys\nimport time\n\nimport numpy as np\nimport openai\nimport pybase64\nimport requests\nfrom PIL import Image\n\nfrom sglang.srt.utils.video_decoder import VideoDecoderWrapper\n\n# pip install httpx==0.23.3\n# pip install torchcodec\n# pip install protobuf==3.20.0\n\n\ndef download_video(url, cache_dir):\n    file_path = os.path.join(cache_dir, \"jobs.mp4\")\n    os.makedirs(cache_dir, exist_ok=True)\n\n    response = requests.get(url)\n    response.raise_for_status()\n\n    with open(file_path, \"wb\") as f:\n        f.write(response.content)\n\n    print(f\"File downloaded and saved to: {file_path}\")\n    return file_path\n\n\ndef create_openai_client(base_url):\n    return openai.Client(api_key=\"EMPTY\", base_url=base_url)\n\n\ndef image_stream_request_test(client):\n    print(\"----------------------Image Stream Request Test----------------------\")\n    stream_request = client.chat.completions.create(\n        model=\"default\",\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": \"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png\"\n                        },\n                    },\n                    {\n                        \"type\": \"text\",\n                        \"text\": \"Please describe this image. 
Please list the benchmarks and the models.\",\n                    },\n                ],\n            },\n        ],\n        temperature=0.7,\n        max_tokens=1024,\n        stream=True,\n    )\n    stream_response = \"\"\n\n    for chunk in stream_request:\n        if chunk.choices[0].delta.content is not None:\n            content = chunk.choices[0].delta.content\n            stream_response += content\n            sys.stdout.write(content)\n            sys.stdout.flush()\n\n    print(\"-\" * 30)\n\n\ndef multi_image_stream_request_test(client):\n    print(\n        \"----------------------Multi-Images Stream Request Test----------------------\"\n    )\n    stream_request = client.chat.completions.create(\n        model=\"default\",\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": \"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png\"\n                        },\n                        \"modalities\": \"multi-images\",\n                    },\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": \"https://raw.githubusercontent.com/sgl-project/sglang/main/examples/assets/example_image.png\"\n                        },\n                        \"modalities\": \"multi-images\",\n                    },\n                    {\n                        \"type\": \"text\",\n                        \"text\": \"I have shown you two images. 
Please describe the two images to me.\",\n                    },\n                ],\n            },\n        ],\n        temperature=0.7,\n        max_tokens=1024,\n        stream=True,\n    )\n    stream_response = \"\"\n\n    for chunk in stream_request:\n        if chunk.choices[0].delta.content is not None:\n            content = chunk.choices[0].delta.content\n            stream_response += content\n            sys.stdout.write(content)\n            sys.stdout.flush()\n\n    print(\"-\" * 30)\n\n\ndef video_stream_request_test(client, video_path):\n    print(\"------------------------Video Stream Request Test----------------------\")\n    messages = prepare_video_messages(video_path)\n\n    video_request = client.chat.completions.create(\n        model=\"default\",\n        messages=messages,\n        temperature=0,\n        max_tokens=1024,\n        stream=True,\n    )\n    print(\"-\" * 30)\n    video_response = \"\"\n\n    for chunk in video_request:\n        if chunk.choices[0].delta.content is not None:\n            content = chunk.choices[0].delta.content\n            video_response += content\n            sys.stdout.write(content)\n            sys.stdout.flush()\n    print(\"-\" * 30)\n\n\ndef image_speed_test(client):\n    print(\"----------------------Image Speed Test----------------------\")\n    start_time = time.perf_counter()\n    request = client.chat.completions.create(\n        model=\"default\",\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": \"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png\"\n                        },\n                    },\n                    {\n                        \"type\": \"text\",\n                        \"text\": \"Please describe this image. 
Please list the benchmarks and the models.\",\n                    },\n                ],\n            },\n        ],\n        temperature=0,\n        max_tokens=1024,\n    )\n    end_time = time.perf_counter()\n    response = request.choices[0].message.content\n    print(response)\n    print(\"-\" * 30)\n    print_speed_test_results(request, start_time, end_time)\n\n\ndef video_speed_test(client, video_path):\n    print(\"------------------------Video Speed Test------------------------\")\n    messages = prepare_video_messages(video_path)\n\n    start_time = time.perf_counter()\n    video_request = client.chat.completions.create(\n        model=\"default\",\n        messages=messages,\n        temperature=0,\n        max_tokens=1024,\n    )\n    end_time = time.perf_counter()\n    video_response = video_request.choices[0].message.content\n    print(video_response)\n    print(\"-\" * 30)\n    print_speed_test_results(video_request, start_time, end_time)\n\n\ndef prepare_video_messages(video_path):\n    max_frames_num = 32\n    decoder = VideoDecoderWrapper(video_path)\n    total_frame_num = len(decoder)\n    uniform_sampled_frames = np.linspace(\n        0, total_frame_num - 1, max_frames_num, dtype=int\n    )\n    frame_idx = uniform_sampled_frames.tolist()\n    frames = decoder.get_frames_at(frame_idx)\n\n    base64_frames = []\n    for frame in frames:\n        pil_img = Image.fromarray(frame)\n        buff = io.BytesIO()\n        pil_img.save(buff, format=\"JPEG\")\n        base64_str = pybase64.b64encode(buff.getvalue()).decode(\"utf-8\")\n        base64_frames.append(base64_str)\n\n    messages = [{\"role\": \"user\", \"content\": []}]\n\n    for base64_frame in base64_frames:\n        frame_format = {\n            \"type\": \"image_url\",\n            \"image_url\": {\"url\": f\"data:image/jpeg;base64,{base64_frame}\"},\n            \"modalities\": \"video\",\n        }\n        messages[0][\"content\"].append(frame_format)\n\n    prompt = {\"type\": 
\"text\", \"text\": \"Please describe the video in detail.\"}\n    messages[0][\"content\"].append(prompt)\n\n    return messages\n\n\ndef print_speed_test_results(request, start_time, end_time):\n    total_tokens = request.usage.total_tokens\n    completion_tokens = request.usage.completion_tokens\n    prompt_tokens = request.usage.prompt_tokens\n\n    print(f\"Total tokens: {total_tokens}\")\n    print(f\"Completion tokens: {completion_tokens}\")\n    print(f\"Prompt tokens: {prompt_tokens}\")\n    print(f\"Time taken: {end_time - start_time} seconds\")\n    print(f\"Tokens per second: {total_tokens / (end_time - start_time)}\")\n    print(f\"Completion tokens per second: {completion_tokens / (end_time - start_time)}\")\n    print(f\"Prompt tokens per second: {prompt_tokens / (end_time - start_time)}\")\n\n\ndef main():\n    url = \"https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4\"\n    cache_dir = os.path.expanduser(\"~/.cache\")\n    video_path = download_video(url, cache_dir)\n\n    client = create_openai_client(\"http://127.0.0.1:30000/v1\")\n\n    image_stream_request_test(client)\n    multi_image_stream_request_test(client)\n    video_stream_request_test(client, video_path)\n    image_speed_test(client)\n    video_speed_test(client, video_path)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/runtime/multimodal/pixtral_server.py",
    "content": "\"\"\"\nUsage:\n# Run a Pixtral model with SGLang:\n# HuggingFace:\npython -m sglang.launch_server --model-path mistral-community/pixtral-12b --port=30000\n# ModelScope:\npython -m sglang.launch_server --model-path AI-ModelScope/pixtral-12b --port=30000\n\n# Then test it with:\npython pixtral_server.py\n\nThis script tests Pixtral model with both single and multiple images.\n\"\"\"\n\nimport argparse\nimport asyncio\nimport json\n\nimport aiohttp\nimport requests\n\nfrom sglang.utils import normalize_base_url\n\nIMAGE_TOKEN_SEP = \"\\n[IMG]\"\nROUTE = \"/generate\"\n\n\nasync def send_request(url, data, delay=0):\n    await asyncio.sleep(delay)\n    async with aiohttp.ClientSession() as session:\n        async with session.post(url, json=data) as resp:\n            output = await resp.json()\n    return output\n\n\nasync def test_concurrent(args):\n    url = f\"{normalize_base_url(args.host, args.port)}{ROUTE}\"\n\n    # Single image test\n    if args.single_image:\n        prompt = f\"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]\"\n        image_url = \"https://picsum.photos/id/237/400/300\"\n        modality = [\"image\"]\n    # Multiple images test\n    else:\n        image_urls = [\n            \"https://picsum.photos/id/237/400/300\",\n            \"https://picsum.photos/id/27/500/500\",\n        ]\n        prompt = f\"<s>[INST]How many photos are there? 
Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]\"\n        image_url = image_urls\n        modality = [\"multi-images\"]\n\n    response = await send_request(\n        url,\n        {\n            \"text\": prompt,\n            \"image_data\": image_url,\n            \"sampling_params\": {\n                \"max_new_tokens\": 100,\n                \"temperature\": 0.7,\n                \"top_p\": 0.9,\n            },\n            \"modalities\": modality,\n        },\n    )\n\n    print(f\"Response: {response}\")\n    if \"text\" in response:\n        print(\"\\nOutput text:\", response[\"text\"])\n\n\ndef test_streaming(args):\n    url = f\"{normalize_base_url(args.host, args.port)}/generate\"\n\n    # Single image test\n    if args.single_image:\n        prompt = f\"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]\"\n        image_data = \"https://picsum.photos/id/237/400/300\"\n        modality = [\"image\"]\n    # Multiple images test\n    else:\n        image_urls = [\n            \"https://picsum.photos/id/237/400/300\",\n            \"https://picsum.photos/id/27/500/500\",\n        ]\n        prompt = f\"<s>[INST]How many photos are there? 
Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]\"\n        image_data = image_urls\n        modality = [\"multi-images\"]\n\n    pload = {\n        \"text\": prompt,\n        \"image_data\": image_data,\n        \"sampling_params\": {\"max_new_tokens\": 100, \"temperature\": 0.7, \"top_p\": 0.9},\n        \"modalities\": modality,\n        \"stream\": True,\n    }\n\n    response = requests.post(url, json=pload, stream=True)\n\n    print(\"Streaming response:\")\n    prev = 0\n    for chunk in response.iter_lines(decode_unicode=False):\n        chunk = chunk.decode(\"utf-8\")\n        if chunk and chunk.startswith(\"data:\"):\n            if chunk == \"data: [DONE]\":\n                break\n            data = json.loads(chunk[5:].strip(\"\\n\"))\n            output = data[\"text\"].strip()\n            print(output[prev:], end=\"\", flush=True)\n            prev = len(output)\n    print(\"\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--host\", type=str, default=\"127.0.0.1\")\n    parser.add_argument(\"--port\", type=int, default=30000)\n    parser.add_argument(\n        \"--single-image\",\n        action=\"store_true\",\n        help=\"Test with single image instead of multiple images\",\n    )\n    parser.add_argument(\"--no-stream\", action=\"store_true\", help=\"Don't test streaming\")\n    args = parser.parse_args()\n\n    asyncio.run(test_concurrent(args))\n    if not args.no_stream:\n        test_streaming(args)\n"
  },
  {
    "path": "examples/runtime/multimodal/qwen_llava_server.py",
    "content": "\"\"\"\nUsage:\n# Installing latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git\n# Installing latest sglang.\n\n# Endpoint Service CLI:\npython -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8\n\npython3 qwen_llava_server.py\n\nOutput:\n\"Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants.\"\n\"\"\"\n\nimport argparse\nimport asyncio\nimport copy\nimport json\n\nimport aiohttp\nimport requests\nfrom llava.conversation import conv_qwen\n\nfrom sglang.utils import normalize_base_url\n\n\nasync def send_request(url, data, delay=0):\n    await asyncio.sleep(delay)\n    async with aiohttp.ClientSession() as session:\n        async with session.post(url, json=data) as resp:\n            output = await resp.json()\n    return output\n\n\nasync def test_concurrent(args):\n    url = normalize_base_url(args.host, args.port)\n\n    prompt = \"<image>\\nPlease generate caption towards this image.\"\n    conv_template = copy.deepcopy(conv_qwen)\n    conv_template.append_message(role=conv_template.roles[0], message=prompt)\n    conv_template.append_message(role=conv_template.roles[1], message=None)\n    prompt_with_template = conv_template.get_prompt()\n    response = []\n    for i in range(1):\n        response.append(\n            send_request(\n                url + \"/generate\",\n                {\n                    \"text\": prompt_with_template,\n                    \"image_data\": \"https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg\",\n                    \"sampling_params\": {\n                        \"max_new_tokens\": 1024,\n                        \"temperature\": 0,\n                        \"top_p\": 1.0,\n                        \"presence_penalty\": 2,\n                        \"frequency_penalty\": 2,\n                        \"stop\": \"<|im_end|>\",\n                    },\n               
 },\n            )\n        )\n\n    rets = await asyncio.gather(*response)\n    for ret in rets:\n        print(ret[\"text\"])\n\n\ndef test_streaming(args):\n    url = normalize_base_url(args.host, args.port)\n    prompt = \"<image>\\nPlease generate caption towards this image.\"\n    conv_template = copy.deepcopy(conv_qwen)\n    conv_template.append_message(role=conv_template.roles[0], message=prompt)\n    conv_template.append_message(role=conv_template.roles[1], message=None)\n    prompt_with_template = conv_template.get_prompt()\n    pload = {\n        \"text\": prompt_with_template,\n        \"sampling_params\": {\n            \"max_new_tokens\": 1024,\n            \"temperature\": 0,\n            \"top_p\": 1.0,\n            \"presence_penalty\": 2,\n            \"frequency_penalty\": 2,\n            \"stop\": \"<|im_end|>\",\n        },\n        \"image_data\": \"https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg\",\n        \"stream\": True,\n    }\n    response = requests.post(\n        url + \"/generate\",\n        json=pload,\n        stream=True,\n    )\n\n    prev = 0\n    for chunk in response.iter_lines(decode_unicode=False):\n        chunk = chunk.decode(\"utf-8\")\n        if chunk and chunk.startswith(\"data:\"):\n            if chunk == \"data: [DONE]\":\n                break\n            data = json.loads(chunk[5:].strip(\"\\n\"))\n            output = data[\"text\"].strip()\n            print(output[prev:], end=\"\", flush=True)\n            prev = len(output)\n    print(\"\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--host\", type=str, default=\"127.0.0.1\")\n    parser.add_argument(\"--port\", type=int, default=30000)\n    args = parser.parse_args()\n    asyncio.run(test_concurrent(args))\n    test_streaming(args)\n"
  },
  {
    "path": "examples/runtime/multimodal_embedding.py",
    "content": "# launch server\n# python -m sglang.launch_server --model-path Alibaba-NLP/gme-Qwen2-VL-2B-Instruct --is-embedding\n\nimport requests\n\nurl = \"http://127.0.0.1:30000\"\n\ntext_input = \"Represent this image in embedding space.\"\nimage_path = \"https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg\"\n\npayload = {\n    \"model\": \"gme-qwen2-vl\",\n    \"input\": [{\"text\": text_input}, {\"image\": image_path}],\n}\n\nresponse = requests.post(url + \"/v1/embeddings\", json=payload).json()\n\nprint(\"Embeddings:\", [x.get(\"embedding\") for x in response.get(\"data\", [])])\n"
  },
  {
    "path": "examples/runtime/openai_chat_with_response_prefill.py",
    "content": "\"\"\"\nUsage:\n1) Launch the server in one terminal:\n   python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000\n\n2) Run this script in another terminal:\n   python openai_chat_with_response_prefill.py\n\nThis example demonstrates two chat completion calls:\n- One with continue_final_message enabled (the final assistant message is used as a prefill).\n- One without continue_final_message (the final assistant message remains, starting a new turn).\n\"\"\"\n\nimport openai\n\nclient = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"EMPTY\")\n\nmessages = [\n    {\"role\": \"system\", \"content\": \"You are a helpful AI assistant.\"},\n    {\n        \"role\": \"user\",\n        \"content\": \"\"\"\nExtract the name, size, price, and color from this product description as a JSON object:\n\n<description>\nThe SmartHome Mini is a compact smart home assistant available in black or white for only $49.99.\nAt just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—\nno matter where you place it in your home.\nThis affordable little hub brings convenient hands-free control to your smart devices.\n</description>\n\"\"\",\n    },\n    {\"role\": \"assistant\", \"content\": \"{\\n\"},\n]\n\n# Calling the API with continue_final_message enabled.\nprint(\"=== Prefill with continue_final_message ===\")\nresponse_with = client.chat.completions.create(\n    model=\"meta-llama/Llama-3.1-8B-Instruct\",\n    messages=messages,\n    temperature=0,\n    extra_body={\"continue_final_message\": True},\n)\nprint(response_with.choices[0].message.content)\n\n# Calling the API without continue_final_message (using default behavior).\nprint(\"\\n=== Prefill without continue_final_message ===\")\nresponse_without = client.chat.completions.create(\n    model=\"meta-llama/Llama-3.1-8B-Instruct\",\n    messages=messages,\n    
temperature=0,\n)\nprint(response_without.choices[0].message.content)\n"
  },
  {
    "path": "examples/runtime/qwen3_vl_reranker.py",
    "content": "\"\"\"\nExample usage of Qwen3-VL-Reranker with SGLang.\n\nThis example demonstrates how to use the Qwen3-VL-Reranker model for multimodal\nreranking tasks, supporting text, images, and videos.\n\nServer Launch:\n    python -m sglang.launch_server \\\n        --model-path Qwen/Qwen3-VL-Reranker-2B \\\n        --served-model-name Qwen3-VL-Reranker-2B \\\n        --trust-remote-code \\\n        --disable-radix-cache \\\n        --chat-template examples/chat_template/qwen3_vl_reranker.jinja\n\nClient Usage:\n    python examples/runtime/qwen3_vl_reranker.py\n\"\"\"\n\nimport requests\n\n# Server URL\nBASE_URL = \"http://localhost:30000\"\n\n\ndef rerank_text_only():\n    \"\"\"Example: Text-only reranking (backward compatible).\"\"\"\n    print(\"=\" * 60)\n    print(\"Text-only reranking example\")\n    print(\"=\" * 60)\n\n    request_data = {\n        \"query\": \"What is machine learning?\",\n        \"documents\": [\n            \"Machine learning is a branch of artificial intelligence that enables computers to learn from data.\",\n            \"The weather in Paris is usually mild with occasional rain.\",\n            \"Deep learning is a subset of machine learning using neural networks with many layers.\",\n        ],\n        \"instruct\": \"Retrieve passages that answer the question.\",\n        \"return_documents\": True,\n    }\n\n    response = requests.post(f\"{BASE_URL}/v1/rerank\", json=request_data)\n    results = response.json()\n\n    print(\"Results (sorted by relevance):\")\n    for i, result in enumerate(results):\n        print(f\"  {i+1}. 
Score: {result['score']:.4f} - {result['document'][:60]}...\")\n    print()\n\n\ndef rerank_with_images():\n    \"\"\"Example: Query is text, documents contain images.\"\"\"\n    print(\"=\" * 60)\n    print(\"Image reranking example\")\n    print(\"=\" * 60)\n\n    request_data = {\n        \"query\": \"A woman playing with her dog on a beach at sunset.\",\n        \"documents\": [\n            # Document 1: Text description\n            \"A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset.\",\n            # Document 2: Image URL\n            [\n                {\n                    \"type\": \"image_url\",\n                    \"image_url\": {\n                        \"url\": \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg\"\n                    },\n                }\n            ],\n            # Document 3: Text + Image (mixed)\n            [\n                {\n                    \"type\": \"text\",\n                    \"text\": \"A joyful scene at the beach:\",\n                },\n                {\n                    \"type\": \"image_url\",\n                    \"image_url\": {\n                        \"url\": \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg\"\n                    },\n                },\n            ],\n        ],\n        \"instruct\": \"Retrieve images or text relevant to the user's query.\",\n        \"return_documents\": False,\n    }\n\n    response = requests.post(f\"{BASE_URL}/v1/rerank\", json=request_data)\n    results = response.json()\n\n    # Debug: print raw response if it's an error\n    if isinstance(results, dict) and \"message\" in results:\n        print(f\"Error: {results['message']}\")\n        return\n    if isinstance(results, str):\n        print(f\"Error: {results}\")\n        return\n\n    print(\"Results (sorted by relevance):\")\n    for i, result in enumerate(results):\n        print(f\"  {i+1}. 
Index: {result['index']}, Score: {result['score']:.4f}\")\n    print()\n\n\ndef rerank_multimodal_query():\n    \"\"\"Example: Query contains both text and image.\"\"\"\n    print(\"=\" * 60)\n    print(\"Multimodal query reranking example\")\n    print(\"=\" * 60)\n\n    request_data = {\n        # Query with text and image\n        \"query\": [\n            {\"type\": \"text\", \"text\": \"Find similar images to this:\"},\n            {\n                \"type\": \"image_url\",\n                \"image_url\": {\n                    \"url\": \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg\"\n                },\n            },\n        ],\n        \"documents\": [\n            \"A cat sleeping on a couch.\",\n            \"A woman and her dog enjoying the sunset at the beach.\",\n            \"A busy city street with cars and pedestrians.\",\n            [\n                {\n                    \"type\": \"image_url\",\n                    \"image_url\": {\n                        \"url\": \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg\"\n                    },\n                }\n            ],\n        ],\n        \"instruct\": \"Find images or descriptions similar to the query image.\",\n    }\n\n    response = requests.post(f\"{BASE_URL}/v1/rerank\", json=request_data)\n    results = response.json()\n\n    # Debug: print raw response if it's an error\n    if isinstance(results, dict) and \"message\" in results:\n        print(f\"Error: {results['message']}\")\n        return\n    if isinstance(results, str):\n        print(f\"Error: {results}\")\n        return\n\n    print(\"Results (sorted by relevance):\")\n    for i, result in enumerate(results):\n        print(f\"  {i+1}. 
Index: {result['index']}, Score: {result['score']:.4f}\")\n    print()\n\n\ndef main():\n    \"\"\"Run all examples.\"\"\"\n    print(\"\\nQwen3-VL-Reranker Examples\")\n    print(\"Make sure the server is running with the correct model and template.\\n\")\n\n    # Check if server is available\n    try:\n        response = requests.get(f\"{BASE_URL}/health\")\n        if response.status_code != 200:\n            print(f\"Server health check failed: {response.status_code}\")\n            return\n    except requests.exceptions.ConnectionError:\n        print(f\"Cannot connect to server at {BASE_URL}\")\n        print(\"Please start the server first with:\")\n        print(\"  python -m sglang.launch_server \\\\\")\n        print(\"      --model-path Qwen/Qwen3-VL-Reranker-2B \\\\\")\n        print(\"      --served-model-name Qwen3-VL-Reranker-2B \\\\\")\n        print(\"      --trust-remote-code \\\\\")\n        print(\"      --disable-radix-cache \\\\\")\n        print(\"      --chat-template examples/chat_template/qwen3_vl_reranker.jinja\")\n        return\n\n    # Run examples\n    rerank_text_only()\n    rerank_with_images()\n    rerank_multimodal_query()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/runtime/reward_model.py",
    "content": "# launch server\n# python -m sglang.launch_server --model LxzGordon/URM-LLaMa-3.1-8B --is-embedding\n\nimport requests\n\nurl = \"http://127.0.0.1:30000\"\n\nPROMPT = (\n    \"What is the range of the numeric output of a sigmoid node in a neural network?\"\n)\nRESPONSE1 = \"The output of a sigmoid node is bounded between -1 and 1.\"\nRESPONSE2 = \"The output of a sigmoid node is bounded between 0 and 1.\"\n\njson_data = {\n    \"conv\": [\n        [\n            {\"role\": \"user\", \"content\": PROMPT},\n            {\"role\": \"assistant\", \"content\": RESPONSE1},\n        ],\n        [\n            {\"role\": \"user\", \"content\": PROMPT},\n            {\"role\": \"assistant\", \"content\": RESPONSE2},\n        ],\n    ],\n}\nresponse = requests.post(\n    url + \"/classify\",\n    json=json_data,\n).json()\n\nprint(response)\nprint(\"scores:\", [x[\"embedding\"] for x in response])\n"
  },
  {
    "path": "examples/runtime/token_in_token_out/token_in_token_out_llm_engine.py",
    "content": "\"\"\"\nThis example demonstrates how to pass tokenized IDs to the LLM as input instead of a text prompt, i.e. a token-in-token-out workflow.\n\"\"\"\n\nimport sglang as sgl\nfrom sglang.srt.utils.hf_transformers_utils import get_tokenizer\n\nMODEL_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n\n\ndef main():\n    # Sample prompts.\n    prompts = [\n        \"Hello, my name is\",\n        \"The president of the United States is\",\n        \"The capital of France is\",\n        \"The future of AI is\",\n    ]\n    # Create a sampling params object.\n    sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n\n    # Tokenize inputs\n    tokenizer = get_tokenizer(MODEL_PATH)\n    token_ids_list = [tokenizer.encode(prompt) for prompt in prompts]\n\n    # Create an LLM.\n    llm = sgl.Engine(model_path=MODEL_PATH, skip_tokenizer_init=True)\n\n    outputs = llm.generate(input_ids=token_ids_list, sampling_params=sampling_params)\n    # Print the outputs.\n    for prompt, output in zip(prompts, outputs):\n        decode_output = tokenizer.decode(output[\"output_ids\"])\n        print(\"===============================\")\n        print(\n            f\"Prompt: {prompt}\\nGenerated token ids: {output['output_ids']}\\nGenerated text: {decode_output}\"\n        )\n        print()\n\n\n# The __main__ guard is necessary here because we use \"spawn\" to create subprocesses.\n# Spawn starts a fresh interpreter for each subprocess; without the __main__ guard, each one would re-run this script and keep spawning processes from sgl.Engine in an infinite loop.\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/runtime/token_in_token_out/token_in_token_out_llm_server.py",
    "content": "\"\"\"\nUsage:\n\npython token_in_token_out_llm_server.py\n\n\"\"\"\n\nimport requests\n\nfrom sglang.srt.utils.hf_transformers_utils import get_tokenizer\nfrom sglang.test.test_utils import is_in_ci\nfrom sglang.utils import terminate_process, wait_for_server\n\nif is_in_ci():\n    from docs.backend.patch import launch_server_cmd\nelse:\n    from sglang.utils import launch_server_cmd\n\n\nMODEL_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n\n\ndef main():\n    # Launch the server\n    server_process, port = launch_server_cmd(\n        f\"python -m sglang.launch_server --model-path {MODEL_PATH} --skip-tokenizer-init --host 0.0.0.0\"\n    )\n    wait_for_server(f\"http://localhost:{port}\", process=server_process)\n\n    # Sample prompts.\n    prompts = [\n        \"Hello, my name is\",\n        \"The president of the United States is\",\n        \"The capital of France is\",\n        \"The future of AI is\",\n    ]\n\n    # Create a sampling params object.\n    sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n\n    # Tokenize inputs\n    tokenizer = get_tokenizer(MODEL_PATH)\n    token_ids_list = [tokenizer.encode(prompt) for prompt in prompts]\n\n    json_data = {\n        \"input_ids\": token_ids_list,\n        \"sampling_params\": sampling_params,\n    }\n\n    response = requests.post(\n        f\"http://localhost:{port}/generate\",\n        json=json_data,\n    )\n\n    outputs = response.json()\n    for prompt, output in zip(prompts, outputs):\n        print(\"===============================\")\n        decode_output = tokenizer.decode(output[\"output_ids\"])\n        print(\n            f\"Prompt: {prompt}\\nGenerated token ids: {output['output_ids']}\\nGenerated text: {decode_output}\"\n        )\n        print()\n\n    terminate_process(server_process)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/runtime/token_in_token_out/token_in_token_out_vlm_engine.py",
    "content": "import argparse\nimport dataclasses\nfrom typing import Tuple\n\nfrom transformers import AutoProcessor\n\nfrom sglang import Engine\nfrom sglang.lang.chat_template import get_chat_template_by_model_path\nfrom sglang.srt.configs.model_config import ModelConfig\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.test.test_utils import DEFAULT_IMAGE_URL\n\n\ndef get_input_ids(\n    server_args: ServerArgs, model_config: ModelConfig\n) -> Tuple[list[int], list]:\n    chat_template = get_chat_template_by_model_path(model_config.model_path)\n    text = f\"{chat_template.image_token}What is in this picture?\"\n    image_data = [DEFAULT_IMAGE_URL]\n\n    processor = AutoProcessor.from_pretrained(\n        model_config.model_path, trust_remote_code=server_args.trust_remote_code\n    )\n\n    input_ids = (\n        processor.tokenizer(\n            text=[text],\n            return_tensors=\"pt\",\n        )\n        .input_ids[0]\n        .tolist()\n    )\n\n    return input_ids, image_data\n\n\ndef token_in_out_example(\n    server_args: ServerArgs,\n):\n    input_ids, image_data = get_input_ids(\n        server_args,\n        ModelConfig(\n            server_args.model_path,\n            trust_remote_code=server_args.trust_remote_code,\n            model_override_args=server_args.json_model_override_args,\n        ),\n    )\n    backend = Engine(**dataclasses.asdict(server_args))\n\n    output = backend.generate(\n        input_ids=input_ids,\n        image_data=image_data,\n        sampling_params={\n            \"temperature\": 0.8,\n            \"max_new_tokens\": 32,\n        },\n    )\n\n    print(\"===============================\")\n    print(f\"Output token ids: \", output[\"output_ids\"])\n\n    backend.shutdown()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    ServerArgs.add_cli_args(parser)\n    args = [\n        \"--model-path=Qwen/Qwen2-VL-2B\",\n    ]\n    args = parser.parse_args(args=args)\n    
server_args = ServerArgs.from_cli_args(args)\n    server_args.skip_tokenizer_init = True\n    token_in_out_example(server_args)\n"
  },
  {
    "path": "examples/runtime/token_in_token_out/token_in_token_out_vlm_server.py",
    "content": "\"\"\"\nUsage:\n\npython token_in_token_out_vlm_server.py\n\n\"\"\"\n\nfrom typing import Tuple\n\nimport requests\nfrom transformers import AutoProcessor\n\nfrom sglang.lang.chat_template import get_chat_template_by_model_path\nfrom sglang.test.test_utils import DEFAULT_IMAGE_URL, is_in_ci\nfrom sglang.utils import terminate_process, wait_for_server\n\nif is_in_ci():\n    from docs.backend.patch import launch_server_cmd\nelse:\n    from sglang.utils import launch_server_cmd\n\n\nMODEL_PATH = \"Qwen/Qwen2-VL-2B\"\n\n\ndef get_input_ids() -> Tuple[list[int], list]:\n    chat_template = get_chat_template_by_model_path(MODEL_PATH)\n    text = f\"{chat_template.image_token}What is in this picture?\"\n    image_data = [DEFAULT_IMAGE_URL]\n\n    processor = AutoProcessor.from_pretrained(MODEL_PATH)\n\n    input_ids = (\n        processor.tokenizer(\n            text=[text],\n            return_tensors=\"pt\",\n        )\n        .input_ids[0]\n        .tolist()\n    )\n\n    return input_ids, image_data\n\n\ndef main():\n    # Launch the server\n    server_process, port = launch_server_cmd(\n        f\"python -m sglang.launch_server --model-path {MODEL_PATH} --skip-tokenizer-init --host 0.0.0.0\"\n    )\n    wait_for_server(f\"http://localhost:{port}\", process=server_process)\n\n    input_ids, image_data = get_input_ids()\n\n    sampling_params = {\n        \"temperature\": 0.8,\n        \"max_new_tokens\": 32,\n    }\n\n    json_data = {\n        \"input_ids\": input_ids,\n        \"image_data\": image_data,\n        \"sampling_params\": sampling_params,\n    }\n\n    response = requests.post(\n        f\"http://localhost:{port}/generate\",\n        json=json_data,\n    )\n\n    output = response.json()\n    print(\"===============================\")\n    print(f\"Output token ids: \", output[\"output_ids\"])\n\n    terminate_process(server_process)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/runtime/vertex_predict.py",
    "content": "\"\"\"\nUsage:\npython -m sglang.launch_server --model meta-llama/Llama-2-7b-hf --port 30000\npython vertex_predict.py\n\nThis example shows the request and response formats of the prediction route for\nGoogle Cloud Vertex AI Online Predictions.\n\nThe Vertex AI SDK for Python is recommended for deploying models to Vertex AI\ninstead of a local server. After deploying the model to a Vertex AI Online\nPrediction Endpoint, send requests via the Python SDK:\n\nresponse = endpoint.predict(\n    instances=[\n        {\"text\": \"The capital of France is\"},\n        {\"text\": \"What is a car?\"},\n    ],\n    parameters={\"sampling_params\": {\"max_new_tokens\": 16}},\n)\nprint(response.predictions)\n\nMore details about getting online predictions from Vertex AI can be found at\nhttps://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions.\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import List, Optional\n\nimport requests\n\n\n@dataclass\nclass VertexPrediction:\n    predictions: List\n\n\nclass LocalVertexEndpoint:\n    def __init__(self) -> None:\n        self.base_url = \"http://127.0.0.1:30000\"\n\n    def predict(self, instances: List[dict], parameters: Optional[dict] = None):\n        response = requests.post(\n            self.base_url + \"/vertex_generate\",\n            json={\n                \"instances\": instances,\n                \"parameters\": parameters,\n            },\n        )\n        return VertexPrediction(predictions=response.json()[\"predictions\"])\n\n\nendpoint = LocalVertexEndpoint()\n\n# Predict with a single prompt.\nresponse = endpoint.predict(instances=[{\"text\": \"The capital of France is\"}])\nprint(response.predictions)\n\n# Predict with multiple prompts and parameters.\nresponse = endpoint.predict(\n    instances=[\n        {\"text\": \"The capital of France is\"},\n        {\"text\": \"What is a car?\"},\n    ],\n    parameters={\"sampling_params\": {\"max_new_tokens\": 16}},\n)\nprint(response.predictions)\n"
  },
  {
    "path": "examples/sagemaker/deploy_and_serve_endpoint.py",
    "content": "import json\n\nimport boto3\nfrom sagemaker import serializers\nfrom sagemaker.model import Model\nfrom sagemaker.predictor import Predictor\n\nboto_session = boto3.session.Session()\nsm_client = boto_session.client(\"sagemaker\")\nsm_role = boto_session.resource(\"iam\").Role(\"SageMakerRole\").arn\n\nendpoint_name = \"<YOUR_ENDPOINT_NAME>\"\nimage_uri = \"<YOUR_DOCKER_IMAGE_URI>\"\nmodel_id = (\n    \"<YOUR_MODEL_ID>\"  # e.g., Qwen/Qwen3-0.6B from https://huggingface.co/Qwen/Qwen3-0.6B\n)\nhf_token = \"<YOUR_HUGGINGFACE_TOKEN>\"\nprompt = \"<YOUR_ENDPOINT_PROMPT>\"\n\nmodel = Model(\n    name=endpoint_name,\n    image_uri=image_uri,\n    role=sm_role,\n    env={\n        \"SM_SGLANG_MODEL_PATH\": model_id,\n        \"HF_TOKEN\": hf_token,\n    },\n)\nprint(\"Model created successfully\")\nprint(\"Starting endpoint deployment (this may take 10-15 minutes)...\")\n\nendpoint_config = model.deploy(\n    instance_type=\"ml.g5.12xlarge\",\n    initial_instance_count=1,\n    endpoint_name=endpoint_name,\n    inference_ami_version=\"al2-ami-sagemaker-inference-gpu-3-1\",\n    wait=True,\n)\nprint(\"Endpoint deployment completed successfully\")\n\n\nprint(f\"Creating predictor for endpoint: {endpoint_name}\")\npredictor = Predictor(\n    endpoint_name=endpoint_name,\n    serializer=serializers.JSONSerializer(),\n)\n\npayload = {\n    \"model\": model_id,\n    \"messages\": [{\"role\": \"user\", \"content\": prompt}],\n    \"max_tokens\": 2400,\n    \"temperature\": 0.01,\n    \"top_p\": 0.9,\n    \"top_k\": 50,\n}\nprint(f\"Sending inference request with prompt: '{prompt[:50]}...'\")\nresponse = predictor.predict(payload)\nprint(\"Inference request completed successfully\")\n\nif isinstance(response, bytes):\n    response = response.decode(\"utf-8\")\n\nif isinstance(response, str):\n    try:\n        response = json.loads(response)\n    except json.JSONDecodeError:\n        print(\"Warning: Response is not valid JSON. Keeping it as a string.\")\n\nprint(f\"Received model response: '{response}'\")\n"
  },
  {
    "path": "examples/usage/modelopt_quantize_and_export.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nExample: ModelOpt Quantization and Export with SGLang\n\nThis example demonstrates the streamlined workflow for quantizing a model with\nModelOpt and automatically exporting it for deployment with SGLang.\n\"\"\"\n\nimport argparse\nimport os\nfrom typing import Optional\n\nimport torch\n\nimport sglang as sgl\nfrom sglang.srt.configs.device_config import DeviceConfig\nfrom sglang.srt.configs.load_config import LoadConfig\nfrom sglang.srt.configs.model_config import ModelConfig\nfrom sglang.srt.distributed.parallel_state import (\n    init_distributed_environment,\n    initialize_model_parallel,\n)\nfrom sglang.srt.model_loader.loader import get_model_loader\n\n\ndef _validate_export(export_dir: str) -> bool:\n    \"\"\"Validate that an exported model directory contains the expected files.\"\"\"\n    import glob\n\n    required_files = [\"config.json\", \"tokenizer_config.json\"]\n\n    if not os.path.exists(export_dir):\n        return False\n\n    # Check required files\n    for file in required_files:\n        if not os.path.exists(os.path.join(export_dir, file)):\n            return False\n\n    # Check for model files using pattern matching to handle sharded models\n    model_patterns = [\n        \"model*.safetensors\",\n        \"pytorch_model*.bin\",\n    ]\n\n    has_model_file = False\n    for pattern in model_patterns:\n        matching_files = glob.glob(os.path.join(export_dir, pattern))\n        if matching_files:\n            has_model_file = True\n            break\n\n    return has_model_file\n\n\ndef _get_export_info(export_dir: str) -> Optional[dict]:\n    \"\"\"Get information about an exported model.\"\"\"\n    import json\n\n    if not _validate_export(export_dir):\n        return None\n\n    try:\n        config_path = os.path.join(export_dir, \"config.json\")\n        with open(config_path, \"r\") as f:\n            config = json.load(f)\n\n        return {\n            \"model_type\": 
config.get(\"model_type\", \"unknown\"),\n            \"architectures\": config.get(\"architectures\", []),\n            \"quantization_config\": config.get(\"quantization_config\", {}),\n            \"export_dir\": export_dir,\n        }\n    except Exception:\n        return None\n\n\ndef quantize_and_export_model(\n    model_path: str,\n    export_dir: str,\n    quantization_method: str = \"modelopt_fp8\",\n    checkpoint_save_path: Optional[str] = None,\n    device: str = \"cuda\",\n) -> None:\n    \"\"\"\n    Quantize a model with ModelOpt and export it for SGLang deployment.\n\n    Args:\n        model_path: Path to the original model\n        export_dir: Directory to export the quantized model\n        quantization_method: Quantization method (\"modelopt_fp8\" or \"modelopt_fp4\")\n        checkpoint_save_path: Optional path to save ModelOpt checkpoint\n        device: Device to use for quantization\n    \"\"\"\n    print(\"🚀 Starting ModelOpt quantization and export workflow\")\n    print(f\"📥 Input model: {model_path}\")\n    print(f\"📤 Export directory: {export_dir}\")\n    print(f\"⚙️  Quantization method: {quantization_method}\")\n\n    # Initialize minimal distributed environment for single GPU quantization\n    if not torch.distributed.is_initialized():\n        print(\"🔧 Initializing distributed environment...\")\n        # Set up environment variables for single-process distributed\n        os.environ[\"RANK\"] = \"0\"\n        os.environ[\"WORLD_SIZE\"] = \"1\"\n        os.environ[\"MASTER_ADDR\"] = \"localhost\"\n        os.environ[\"MASTER_PORT\"] = \"12355\"  # Use a different port than tests\n        os.environ[\"LOCAL_RANK\"] = \"0\"\n\n        init_distributed_environment(\n            world_size=1,\n            rank=0,\n            local_rank=0,\n            backend=\"nccl\" if device == \"cuda\" else \"gloo\",\n        )\n        initialize_model_parallel(\n            tensor_model_parallel_size=1,\n            
pipeline_model_parallel_size=1,\n        )\n\n    # Configure model loading with ModelOpt quantization and export\n    model_config = ModelConfig(\n        model_path=model_path,\n        quantization=quantization_method,  # Use unified quantization flag\n        trust_remote_code=True,\n    )\n\n    load_config = LoadConfig(\n        modelopt_checkpoint_save_path=checkpoint_save_path,\n        modelopt_export_path=export_dir,\n    )\n    device_config = DeviceConfig(device=device)\n\n    # Load and quantize the model (export happens automatically)\n    print(\"🔄 Loading and quantizing model...\")\n    model_loader = get_model_loader(load_config, model_config)\n\n    try:\n        model_loader.load_model(\n            model_config=model_config,\n            device_config=device_config,\n        )\n        print(\"✅ Model quantized successfully!\")\n\n        # Validate the export\n        if _validate_export(export_dir):\n            print(\"✅ Export validation passed!\")\n\n            info = _get_export_info(export_dir)\n            if info:\n                print(\"📋 Model info:\")\n                print(f\"   - Type: {info['model_type']}\")\n                print(f\"   - Architecture: {info['architectures']}\")\n                print(f\"   - Quantization: {info['quantization_config']}\")\n        else:\n            print(\"❌ Export validation failed!\")\n            return\n\n    except Exception as e:\n        print(f\"❌ Quantization failed: {e}\")\n        return\n\n    print(\"\\n🎉 Workflow completed successfully!\")\n    print(f\"📁 Quantized model exported to: {export_dir}\")\n    print(\"\\n🚀 To use the exported model:\")\n    print(\n        f\"   python -m sglang.launch_server --model-path {export_dir} --quantization modelopt\"\n    )\n    print(\"\\n   # Or in Python:\")\n    print(\"   import sglang as sgl\")\n    print(f\"   llm = sgl.Engine(model_path='{export_dir}', quantization='modelopt')\")\n    print(\"   # Note: 'modelopt' auto-detects FP4/FP8 
from model config\")\n\n\ndef deploy_exported_model(\n    export_dir: str,\n    host: str = \"127.0.0.1\",\n    port: int = 30000,\n) -> None:\n    \"\"\"\n    Deploy an exported ModelOpt quantized model with SGLang.\n\n    Args:\n        export_dir: Directory containing the exported model\n        host: Host to bind the server to\n        port: Port to bind the server to\n    \"\"\"\n    print(f\"🚀 Deploying exported model from: {export_dir}\")\n\n    # Validate export first\n    if not _validate_export(export_dir):\n        print(\"❌ Invalid export directory!\")\n        return\n\n    try:\n        # Launch SGLang engine with the exported model\n        # Using generic \"modelopt\" for auto-detection of FP4/FP8\n        llm = sgl.Engine(\n            model_path=export_dir,\n            quantization=\"modelopt\",\n            host=host,\n            port=port,\n        )\n\n        print(\"✅ Model deployed successfully!\")\n        print(f\"🌐 Server running at http://{host}:{port}\")\n\n        # Example inference\n        prompts = [\"Hello, how are you?\", \"What is the capital of France?\"]\n        sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95, \"max_new_tokens\": 100}\n\n        print(\"\\n🧪 Running example inference...\")\n        outputs = llm.generate(prompts, sampling_params)\n\n        for i, output in enumerate(outputs):\n            print(f\"Prompt {i+1}: {prompts[i]}\")\n            print(f\"Output: {output['text']}\")\n            print()\n\n    except Exception as e:\n        print(f\"❌ Deployment failed: {e}\")\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"ModelOpt Quantization and Export with SGLang\",\n        formatter_class=argparse.RawDescriptionHelpFormatter,\n        epilog=\"\"\"\nExamples:\n  # Quantize and export a model (recommended workflow)\n  python modelopt_quantize_and_export.py quantize \\\\\n    --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \\\\\n    --export-dir ./quantized_model 
\\\\\n    --quantization-method modelopt_fp8\n\n  # Deploy a pre-exported model\n  python modelopt_quantize_and_export.py deploy \\\\\n    --export-dir ./quantized_model\n        \"\"\",\n    )\n\n    subparsers = parser.add_subparsers(dest=\"command\", help=\"Available commands\")\n\n    # Quantize command\n    quantize_parser = subparsers.add_parser(\n        \"quantize\", help=\"Quantize and export a model\"\n    )\n    quantize_parser.add_argument(\n        \"--model-path\", required=True, help=\"Path to the model to quantize\"\n    )\n    quantize_parser.add_argument(\n        \"--export-dir\", required=True, help=\"Directory to export the quantized model\"\n    )\n    quantize_parser.add_argument(\n        \"--quantization-method\",\n        choices=[\"modelopt_fp8\", \"modelopt_fp4\"],\n        default=\"modelopt_fp8\",\n        help=\"Quantization method to use\",\n    )\n    quantize_parser.add_argument(\n        \"--checkpoint-save-path\", help=\"Optional path to save ModelOpt checkpoint\"\n    )\n    quantize_parser.add_argument(\n        \"--device\", default=\"cuda\", help=\"Device to use for quantization\"\n    )\n\n    # TODO: Quantize-and-serve command removed due to compatibility issues\n    # Use the separate quantize-then-deploy workflow instead\n\n    # Deploy command\n    deploy_parser = subparsers.add_parser(\"deploy\", help=\"Deploy an exported model\")\n    deploy_parser.add_argument(\n        \"--export-dir\", required=True, help=\"Directory containing the exported model\"\n    )\n    deploy_parser.add_argument(\n        \"--host\", default=\"127.0.0.1\", help=\"Host to bind the server to\"\n    )\n    deploy_parser.add_argument(\n        \"--port\", type=int, default=30000, help=\"Port to bind the server to\"\n    )\n\n    args = parser.parse_args()\n\n    if args.command == \"quantize\":\n        quantize_and_export_model(\n            model_path=args.model_path,\n            export_dir=args.export_dir,\n            
quantization_method=args.quantization_method,\n            checkpoint_save_path=args.checkpoint_save_path,\n            device=args.device,\n        )\n    elif args.command == \"deploy\":\n        deploy_exported_model(\n            export_dir=args.export_dir,\n            host=args.host,\n            port=args.port,\n        )\n    else:\n        parser.print_help()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0\", \"setuptools-scm>=8.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"sglang\"\ndynamic = [\"version\"]\ndescription = \"SGLang is a fast serving framework for large language models and vision language models.\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nlicense = { file = \"LICENSE\" }\nclassifiers = [\n  \"Programming Language :: Python :: 3\",\n  \"License :: OSI Approved :: Apache Software License\",\n]\n\ndependencies = [\n  \"IPython\",\n  \"aiohttp\",\n  \"apache-tvm-ffi>=0.1.5,<0.2\",\n  \"anthropic>=0.20.0\",\n  \"blobfile==3.0.0\",\n  \"build\",\n  \"compressed-tensors\",\n  \"cuda-python==12.9\",\n  \"decord2 ; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'armv7l')\",\n  \"datasets\",\n  \"einops\",\n  \"fastapi\",\n  \"flashinfer_python==0.6.6\", # keep it aligned with jit-cache version in Dockerfile\n  \"flashinfer_cubin==0.6.6\",\n  \"gguf\",\n  \"interegular\",\n  \"llguidance>=0.7.11,<0.8.0\",\n  \"modelscope\",\n  \"msgspec\",\n  \"ninja\",\n  \"numpy\",\n  \"nvidia-cutlass-dsl>=4.4.1\",\n  \"nvidia-ml-py\",\n  \"openai-harmony==0.0.4\",\n  \"openai==2.6.1\",\n  \"orjson\",\n  \"outlines==0.1.11\",\n  \"packaging\",\n  \"partial_json_parser\",\n  \"pillow\",\n  \"prometheus-client>=0.20.0\",\n  \"psutil\",\n  \"py-spy\",\n  \"pybase64\",\n  \"pydantic\",\n  \"python-multipart\",\n  \"pyzmq>=25.1.2\",\n  \"quack-kernels>=0.3.0\",\n  \"requests\",\n  \"scipy\",\n  \"sentencepiece\",\n  \"setproctitle\",\n  \"flash-attn-4>=4.0.0b4\",\n  \"sglang-kernel==0.4.0\",\n  \"soundfile==0.13.1\",\n  \"tiktoken\",\n  \"timm==1.0.16\",\n  \"torch_memory_saver==0.0.9\",\n  \"torch==2.9.1\",\n  \"torchao==0.9.0\",\n  \"torchaudio==2.9.1\",\n  \"torchcodec==0.9.1 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and 
platform_machine != 'armv7l')\", # torchcodec 0.9.1 for torch 2.9.x. Not available on Linux ARM.\n  \"av ; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'armv7l')\",\n  \"torchvision\",\n  \"tqdm\",\n  \"mistral_common>=1.9.0\",\n  \"transformers==5.3.0\",\n  \"uvicorn\",\n  \"uvloop\",\n  \"watchfiles\",\n  \"xgrammar==0.1.27\",\n  \"smg-grpc-servicer>=0.5.0\",\n]\n\n[[tool.uv.index]]\nname = \"pypi\"\nurl = \"https://pypi.org/simple\"\ndefault = true\n\n[[tool.uv.index]]\nname = \"torch-cu129\"\nurl = \"https://download.pytorch.org/whl/cu129\"\nexplicit = true\n\n[tool.uv.sources]\ntorch = [\n  { index = \"pypi\", marker = \"platform_machine == 'x86_64'\"},\n  { index = \"torch-cu129\", marker = \"platform_machine == 'aarch64'\"},\n]\n\n[project.optional-dependencies]\ncheckpoint-engine = [\"checkpoint-engine==0.1.2\"]\ndiffusion = [\n  \"PyYAML==6.0.1\",\n  \"cloudpickle==3.1.2\",\n  \"diffusers==0.37.0\",\n  \"imageio==2.36.0\",\n  \"imageio-ffmpeg==0.5.1\",\n  \"moviepy>=2.0.0\",\n  \"opencv-python-headless==4.10.0.84\",\n  \"remote-pdb==2.1.0\",\n  \"st_attn==0.0.7 ; platform_machine != 'aarch64' and platform_machine != 'arm64'\",\n  \"vsa==0.0.4 ; platform_machine != 'aarch64' and platform_machine != 'arm64'\",\n  \"runai_model_streamer>=0.15.5\",\n  \"cache-dit==1.3.0\",\n  \"addict==2.4.0\",\n  \"av==16.1.0\",\n  \"scikit-image==0.25.2\",\n  \"trimesh>=4.0.0\",\n  \"xatlas\",\n]\n\nray = [\n  \"ray[default]>=2.54.0\",\n]\n\ntracing = [\n  \"opentelemetry-api\",\n  \"opentelemetry-exporter-otlp\",\n  \"opentelemetry-exporter-otlp-proto-grpc\",\n  \"opentelemetry-sdk\",\n]\n\ntest = [\n  \"accelerate\",\n  \"addict\",\n  \"bitsandbytes\",\n  \"expecttest\",\n  \"jsonlines\",\n  \"lm-eval[api]>=0.4.9.2\",\n  \"matplotlib\",\n  \"pandas\",\n  \"parameterized\",\n  \"peft>=0.18.0\",\n  \"pytest\",\n  \"pytest-cov\",\n  \"diff-cover\",\n  \"sentence_transformers\",\n  \"tabulate\",\n]\n\ndev = 
[\"sglang[test]\"]\n\nall = [\n  \"sglang[diffusion]\",\n  \"sglang[tracing]\",\n]\n\n[tool.uv.extra-build-dependencies]\nst-attn = [\"torch\", \"setuptools\"]\nvsa = [\"torch\", \"setuptools\"]\n\n[project.urls]\n\"Homepage\" = \"https://github.com/sgl-project/sglang\"\n\"Bug Tracker\" = \"https://github.com/sgl-project/sglang/issues\"\n\n[project.scripts]\nsglang = \"sglang.cli.main:main\"\n\n[tool.setuptools.package-data]\n\"sglang\" = [\n  \"srt/**/*\",\n  \"jit_kernel/**/*\"\n]\n\n[tool.setuptools.packages.find]\nexclude = [\n  \"assets*\",\n  \"benchmark*\",\n  \"docs*\",\n  \"dist*\",\n  \"playground*\",\n  \"scripts*\",\n  \"tests*\",\n]\n\n[tool.wheel]\nexclude = [\n  \"assets*\",\n  \"benchmark*\",\n  \"docs*\",\n  \"dist*\",\n  \"playground*\",\n  \"scripts*\",\n  \"tests*\",\n]\n\n[tool.setuptools_scm]\nroot = \"..\"\nversion_file = \"sglang/_version.py\"\ngit_describe_command = [\"bash\", \"-c\", \"git tag --list --sort=-version:refname 'v*.*.*' | head -1 | xargs git describe --tags --long\"]\n# Allow editable installs even when .git metadata is not available.\nfallback_version = \"0.0.0.dev0\"\n"
  },
  {
    "path": "python/pyproject_cpu.toml",
    "content": "# https://docs.sglang.io/platforms/cpu_server.html\n[build-system]\nrequires = [\"setuptools>=61.0\", \"setuptools-scm>=8.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"sglang-cpu\"\ndynamic = [\"version\"]\ndescription = \"SGLang is a fast serving framework for large language models and vision language models.\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nlicense = { file = \"LICENSE\" }\nclassifiers = [\n  \"Programming Language :: Python :: 3\",\n  \"License :: OSI Approved :: Apache Software License\",\n]\n\ndependencies = [\n  \"IPython\",\n  \"aiohttp\",\n  \"anthropic>=0.20.0\",\n  \"blobfile==3.0.0\",\n  \"build\",\n  \"compressed-tensors\",\n  \"datasets\",\n  \"einops\",\n  \"fastapi\",\n  \"gguf\",\n  \"intel-openmp; platform_machine == 'x86_64'\",\n  \"interegular\",\n  \"llguidance>=0.7.11,<0.8.0\",\n  \"modelscope\",\n  \"msgspec\",\n  \"ninja\",\n  \"numpy\",\n  \"openai-harmony==0.0.4\",\n  \"openai==2.6.1\",\n  \"orjson\",\n  \"outlines\",\n  \"packaging\",\n  \"partial_json_parser\",\n  \"pillow\",\n  \"prometheus-client>=0.20.0\",\n  \"psutil\",\n  \"py-spy\",\n  \"pybase64\",\n  \"pydantic\",\n  \"python-multipart\",\n  \"pyzmq>=25.1.2\",\n  \"requests\",\n  \"scipy\",\n  \"sentencepiece\",\n  \"setproctitle\",\n  \"soundfile==0.13.1\",\n  \"tabulate\",\n  \"tiktoken\",\n  \"timm==1.0.16\",\n  \"torch==2.9.0\",\n  \"torchao==0.14.1\",\n  \"torchaudio==2.9.0\",\n  \"torchvision==0.24.0\",\n  \"tqdm\",\n  \"mistral_common>=1.9.0\",\n  \"transformers==5.3.0\",\n  \"triton==3.5.0\",\n  \"uvicorn\",\n  \"uvloop\",\n  \"xgrammar==0.1.27\",\n  \"smg-grpc-servicer>=0.5.0\",\n]\n\n[project.optional-dependencies]\ntracing = [\n  \"opentelemetry-sdk\",\n  \"opentelemetry-api\",\n  \"opentelemetry-exporter-otlp\",\n  \"opentelemetry-exporter-otlp-proto-grpc\",\n]\ntest = [\n  \"accelerate\",\n  \"expecttest\",\n  \"jsonlines\",\n  \"matplotlib\",\n  \"pandas\",\n  \"peft>=0.18.0\",\n  
\"pytest\",\n  \"sentence_transformers\",\n]\nall = []\ndev = [\"sglang[test]\"]\n\n[project.urls]\n\"Homepage\" = \"https://github.com/sgl-project/sglang\"\n\"Bug Tracker\" = \"https://github.com/sgl-project/sglang/issues\"\n\n[project.scripts]\nsglang = \"sglang.cli.main:main\"\n\n[tool.setuptools.package-data]\n\"sglang\" = [\n  \"srt/**/*\",\n  \"jit_kernel/**/*\"\n]\n\n[tool.setuptools.packages.find]\nexclude = [\n  \"assets*\",\n  \"benchmark*\",\n  \"docs*\",\n  \"dist*\",\n  \"playground*\",\n  \"scripts*\",\n  \"tests*\",\n]\n\n[tool.wheel]\nexclude = [\n  \"assets*\",\n  \"benchmark*\",\n  \"docs*\",\n  \"dist*\",\n  \"playground*\",\n  \"scripts*\",\n  \"tests*\",\n]\n\n[tool.setuptools_scm]\nroot = \"..\"\nversion_file = \"sglang/_version.py\"\ngit_describe_command = [\"git\", \"describe\", \"--tags\", \"--long\", \"--match\", \"v*\"]\n"
  },
  {
    "path": "python/pyproject_npu.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0\", \"setuptools-scm>=8.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"sglang\"\ndynamic = [\"version\"]\ndescription = \"SGLang is a fast serving framework for large language models and vision language models.\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nlicense = { file = \"LICENSE\" }\nclassifiers = [\n  \"Programming Language :: Python :: 3\",\n  \"License :: OSI Approved :: Apache Software License\",\n]\n\ndependencies = [\n  \"IPython\",\n  \"aiohttp\",\n  \"anthropic>=0.20.0\",\n  \"blobfile==3.0.0\",\n  \"av\",\n  \"build\",\n  \"compressed-tensors\",\n  \"datasets\",\n  \"einops\",\n  \"fastapi\",\n  \"gguf\",\n  \"interegular\",\n  \"llguidance>=0.7.11,<0.8.0\",\n  \"modelscope\",\n  \"msgspec\",\n  \"ninja\",\n  \"numpy\",\n  \"openai-harmony==0.0.4\",\n  \"openai==2.6.1\",\n  \"orjson\",\n  \"outlines==0.1.11\",\n  \"packaging\",\n  \"partial_json_parser\",\n  \"pillow\",\n  \"prometheus-client>=0.20.0\",\n  \"psutil\",\n  \"py-spy\",\n  \"pybase64\",\n  \"pydantic\",\n  \"python-multipart\",\n  \"pyzmq>=25.1.2\",\n  \"requests\",\n  \"scipy\",\n  \"sentencepiece\",\n  \"setproctitle\",\n  \"soundfile==0.13.1\",\n  \"tiktoken\",\n  \"timm==1.0.16\",\n  \"torchao==0.9.0\",\n  \"tqdm\",\n  \"mistral_common>=1.9.0\",\n  \"transformers==5.3.0\",\n  \"uvicorn\",\n  \"uvloop\",\n  \"xgrammar==0.1.27\",\n  \"smg-grpc-servicer>=0.5.0\",\n]\n\n[project.optional-dependencies]\ncheckpoint-engine = [\"checkpoint-engine==0.1.2\"]\ndiffusion = [\n    \"PyYAML==6.0.1\",\n    \"cloudpickle\",\n    \"diffusers==0.37.0\",\n    \"imageio==2.36.0\",\n    \"imageio-ffmpeg==0.5.1\",\n    \"moviepy>=2.0.0\",\n    \"opencv-python==4.10.0.84\",\n    \"remote-pdb\",\n    \"cache-dit==1.2.1\",\n    \"addict\",\n    \"scikit-image==0.25.2\",\n    \"trimesh>=4.0.0\",\n    \"xatlas\",\n]\n\ntracing = [\n  \"opentelemetry-api\",\n  \"opentelemetry-exporter-otlp\",\n  
\"opentelemetry-exporter-otlp-proto-grpc\",\n  \"opentelemetry-sdk\",\n]\n\ntest = [\n  \"accelerate\",\n  \"expecttest\",\n  \"gguf\",\n  \"jsonlines\",\n  \"matplotlib\",\n  \"pandas\",\n  \"peft>=0.18.0\",\n  \"pytest\",\n  \"sentence_transformers\",\n  \"tabulate\",\n]\n\n# https://docs.sglang.io/platforms/ascend_npu.html\nsrt_npu = []\nall_npu = [\"sglang[diffusion]\"]\ndev_npu = [\"sglang[all_npu]\", \"sglang[test]\"]\n\n[project.urls]\n\"Homepage\" = \"https://github.com/sgl-project/sglang\"\n\"Bug Tracker\" = \"https://github.com/sgl-project/sglang/issues\"\n\n[project.scripts]\nsglang = \"sglang.cli.main:main\"\n\n[tool.setuptools.package-data]\n\"sglang\" = [\n  \"srt/**/*\",\n  \"jit_kernel/**/*\"\n]\n\n[tool.setuptools.packages.find]\nexclude = [\n  \"assets*\",\n  \"benchmark*\",\n  \"docs*\",\n  \"dist*\",\n  \"playground*\",\n  \"scripts*\",\n  \"tests*\",\n]\n\n[tool.wheel]\nexclude = [\n  \"assets*\",\n  \"benchmark*\",\n  \"docs*\",\n  \"dist*\",\n  \"playground*\",\n  \"scripts*\",\n  \"tests*\",\n]\n\n[tool.setuptools_scm]\nroot = \"..\"\nversion_file = \"sglang/_version.py\"\ngit_describe_command = [\"git\", \"describe\", \"--tags\", \"--long\", \"--match\", \"v*\"]\n"
  },
  {
    "path": "python/pyproject_other.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0\", \"setuptools-scm>=8.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"sglang\"\ndynamic = [\"version\"]\ndescription = \"SGLang is a fast serving framework for large language models and vision language models.\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nlicense = { file = \"LICENSE\" }\nclassifiers = [\n  \"Programming Language :: Python :: 3\",\n  \"License :: OSI Approved :: Apache Software License\",\n]\ndependencies = [\"aiohttp\", \"requests\", \"tqdm\", \"numpy\", \"IPython\", \"setproctitle\"]\n\n[project.optional-dependencies]\nruntime_common = [\n  \"IPython\",\n  \"aiohttp\",\n  \"anthropic>=0.20.0\",\n  \"blobfile==3.0.0\",\n  \"av\",\n  \"build\",\n  \"compressed-tensors\",\n  \"datasets\",\n  \"einops\",\n  \"fastapi\",\n  \"gguf\",\n  \"interegular\",\n  \"llguidance>=0.7.11,<0.8.0\",\n  \"modelscope\",\n  \"msgspec\",\n  \"ninja\",\n  \"numpy\",\n  \"openai-harmony==0.0.4\",\n  \"openai==2.6.1\",\n  \"orjson\",\n  \"outlines==0.1.11\",\n  \"packaging\",\n  \"partial_json_parser\",\n  \"pillow\",\n  \"prometheus-client>=0.20.0\",\n  \"psutil\",\n  \"py-spy\",\n  \"pybase64\",\n  \"pydantic\",\n  \"python-multipart\",\n  \"pyzmq>=25.1.2\",\n  \"requests\",\n  \"scipy\",\n  \"sentencepiece\",\n  \"setproctitle\",\n  \"soundfile==0.13.1\",\n  \"tiktoken\",\n  \"timm==1.0.16\",\n  \"torchao==0.9.0\",\n  \"tqdm\",\n  \"mistral_common>=1.9.0\",\n  \"transformers==5.3.0\",\n  \"uvicorn\",\n  \"uvloop\",\n  \"xgrammar==0.1.27\",\n  \"smg-grpc-servicer>=0.5.0\",\n]\n\ndiffusion_common = [\n  \"PyYAML==6.0.1\",\n  \"cloudpickle\",\n  \"diffusers==0.37.0\",\n  \"imageio==2.36.0\",\n  \"imageio-ffmpeg==0.5.1\",\n  \"moviepy>=2.0.0\",\n  \"opencv-python-headless==4.10.0.84\",\n  \"remote-pdb\",\n  \"addict\",\n  \"scikit-image==0.25.2\",\n  \"trimesh>=4.0.0\",\n  \"xatlas\",\n]\n\ntracing = [\n  \"opentelemetry-sdk\",\n  \"opentelemetry-api\",\n  
\"opentelemetry-exporter-otlp\",\n  \"opentelemetry-exporter-otlp-proto-grpc\",\n]\n\n# HIP (Heterogeneous-computing Interface for Portability) for AMD\n# => base docker rocm/vllm-dev:20250114, not from public vllm whl\nsrt_hip = [\n  \"sglang[runtime_common]\",\n  \"torch\",\n  \"petit_kernel==0.0.2\",\n  \"wave-lang==3.8.2\",\n]\n\ndiffusion_hip = [\n  \"sglang[diffusion_common]\",\n  \"st_attn==0.0.7\",\n  \"vsa==0.0.4\",\n  \"runai_model_streamer>=0.15.5\",\n  \"cache-dit==1.1.8\",\n]\n\n# For Intel Gaudi(device : hpu) follow the installation guide\n# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html\nsrt_hpu = [\"sglang[runtime_common]\"]\n\n# https://docs.sglang.io/platforms/mthreads_gpu.md\nsrt_musa = [\n    \"sglang[runtime_common]\",\n    \"torch\",\n    \"torch_musa\",\n    \"torchada>=0.1.25\",\n    \"mthreads-ml-py\",\n    \"numpy<2.0\",\n]\n\ndiffusion_musa = [\n  \"sglang[diffusion_common]\",\n  \"st_attn==0.0.7\",\n  \"vsa==0.0.4\",\n  \"runai_model_streamer>=0.15.5\",\n  \"cache-dit==1.1.8\",\n]\n\n# https://docs.sglang.io/platforms/mps.md\nsrt_mps = [\n    \"sglang[runtime_common]\",\n    \"torch==2.9.1\",\n    \"torchao==0.9.0\",\n    \"torchaudio==2.9.1\",\n    \"torchvision\",\n]\n\ndiffusion_mps = [\n  \"sglang[diffusion_common]\",\n  \"cloudpickle==3.1.2\",\n  \"remote-pdb==2.1.0\",\n  \"cache-dit==1.2.3\",\n  \"addict==2.4.0\",\n  \"av==16.1.0\",\n]\n\ntest = [\n  \"accelerate\",\n  \"expecttest\",\n  \"gguf\",\n  \"jsonlines\",\n  \"matplotlib\",\n  \"pandas\",\n  \"peft>=0.18.0\",\n  \"pytest\",\n  \"sentence_transformers\",\n  \"tabulate\",\n]\n\nall_hip = [\"sglang[srt_hip]\", \"sglang[diffusion_hip]\"]\nall_hpu = [\"sglang[srt_hpu]\"]\nall_musa = [\"sglang[srt_musa]\", \"sglang[diffusion_musa]\"]\nall_mps = [\"sglang[srt_mps]\", \"sglang[diffusion_mps]\"]\n\ndev_hip = [\"sglang[all_hip]\", \"sglang[test]\"]\ndev_hpu = [\"sglang[all_hpu]\", \"sglang[test]\"]\ndev_musa = [\"sglang[all_musa]\", 
\"sglang[test]\"]\ndev_mps = [\"sglang[all_mps]\", \"sglang[test]\"]\n\n[project.urls]\n\"Homepage\" = \"https://github.com/sgl-project/sglang\"\n\"Bug Tracker\" = \"https://github.com/sgl-project/sglang/issues\"\n\n[project.scripts]\nsglang = \"sglang.cli.main:main\"\n\n[tool.setuptools.package-data]\n\"sglang\" = [\n  \"srt/**/*\",\n  \"jit_kernel/**/*\"\n]\n\n[tool.setuptools.packages.find]\nexclude = [\n  \"assets*\",\n  \"benchmark*\",\n  \"docs*\",\n  \"dist*\",\n  \"playground*\",\n  \"scripts*\",\n  \"tests*\",\n]\n\n[tool.wheel]\nexclude = [\n  \"assets*\",\n  \"benchmark*\",\n  \"docs*\",\n  \"dist*\",\n  \"playground*\",\n  \"scripts*\",\n  \"tests*\",\n]\n\n[tool.setuptools_scm]\nroot = \"..\"\nversion_file = \"sglang/_version.py\"\ngit_describe_command = [\"git\", \"describe\", \"--tags\", \"--long\", \"--match\", \"v*\"]\n"
  },
  {
    "path": "python/pyproject_xpu.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0\", \"setuptools-scm>=8.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"sglang\"\ndynamic = [\"version\"]\ndescription = \"SGLang is a fast serving framework for large language models and vision language models.\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nlicense = { file = \"LICENSE\" }\nclassifiers = [\n  \"Programming Language :: Python :: 3\",\n  \"License :: OSI Approved :: Apache Software License\",\n]\n\ndependencies = [\n  \"torch==2.10.0+xpu\",\n  \"torchcodec==0.10.0 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')\", # torchcodec is not available on those systems. torch==2.10.0 on XPU uses 0.10.0\n  \"av ; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'armv7l')\",\n  \"torchaudio==2.10.0+xpu\",\n  \"torchvision\",\n  \"sgl-kernel @ git+https://github.com/sgl-project/sgl-kernel-xpu.git\",\n  \"IPython\",\n  \"aiohttp\",\n  \"anthropic>=0.20.0\",\n  \"blobfile==3.0.0\",\n  \"build\",\n  \"compressed-tensors\",\n  \"datasets\",\n  \"einops\",\n  \"fastapi\",\n  \"gguf\",\n  \"interegular\",\n  \"llguidance>=0.7.11,<0.8.0\",\n  \"modelscope\",\n  \"msgspec\",\n  \"ninja\",\n  \"numpy\",\n  \"openai-harmony==0.0.4\",\n  \"openai==2.6.1\",\n  \"orjson\",\n  \"outlines==0.1.11\",\n  \"packaging\",\n  \"partial_json_parser\",\n  \"pillow\",\n  \"prometheus-client>=0.20.0\",\n  \"psutil\",\n  \"py-spy\",\n  \"pybase64\",\n  \"pydantic\",\n  \"python-multipart\",\n  \"pyzmq>=25.1.2\",\n  \"requests\",\n  \"scipy\",\n  \"sentencepiece\",\n  \"setproctitle\",\n  \"soundfile==0.13.1\",\n  \"tiktoken\",\n  \"timm==1.0.16\",\n  \"torchao==0.9.0\",\n  \"tqdm\",\n  \"mistral_common>=1.9.0\",\n  \"transformers==5.3.0\",\n  \"uvicorn\",\n  \"uvloop\",\n  # \"xgrammar==0.1.24\",  # xgrammar depends on 
CUDA PyTorch and Triton only\n  \"smg-grpc-servicer>=0.5.0\",\n]\n\n[project.optional-dependencies]\ntracing = [\n  \"opentelemetry-api\",\n  \"opentelemetry-exporter-otlp\",\n  \"opentelemetry-exporter-otlp-proto-grpc\",\n  \"opentelemetry-sdk\",\n]\ntest = [\n  \"accelerate\",\n  \"bitsandbytes\",\n  \"expecttest\",\n  \"jsonlines\",\n  \"lm-eval[api]>=0.4.9.2\",\n  \"matplotlib\",\n  \"pandas\",\n  \"parameterized\",\n  \"peft>=0.18.0\",\n  \"pytest\",\n  \"sentence_transformers\",\n  \"tabulate\",\n]\n\ndev = [\"sglang[test]\"]\n\nall = [\n  \"sglang[tracing]\",\n]\n\n[project.urls]\n\"Homepage\" = \"https://github.com/sgl-project/sglang\"\n\"Bug Tracker\" = \"https://github.com/sgl-project/sglang/issues\"\n\n[project.scripts]\nsglang = \"sglang.cli.main:main\"\n\n[tool.setuptools.package-data]\n\"sglang\" = [\n  \"srt/**/*\",\n  \"jit_kernel/**/*\"\n]\n\n[tool.setuptools.packages.find]\nexclude = [\n  \"assets*\",\n  \"benchmark*\",\n  \"docs*\",\n  \"dist*\",\n  \"playground*\",\n  \"scripts*\",\n  \"tests*\",\n]\n\n[tool.wheel]\nexclude = [\n  \"assets*\",\n  \"benchmark*\",\n  \"docs*\",\n  \"dist*\",\n  \"playground*\",\n  \"scripts*\",\n  \"tests*\",\n]\n\n[tool.setuptools_scm]\nroot = \"..\"\nversion_file = \"sglang/_version.py\"\ngit_describe_command = [\"bash\", \"-c\", \"git tag --list --sort=-version:refname 'v*.*.*' | head -1 | xargs git describe --tags --long\"]\n# Allow editable installs even when .git metadata is not available.\nfallback_version = \"0.0.0.dev0\"\n"
  },
  {
    "path": "python/sglang/README.md",
    "content": "# Code Structure\n\n- `eval`: The evaluation utilities.\n- `lang`: The frontend language.\n- `multimodal_gen`: Inference framework for accelerated image/video generation.\n- `srt`: The backend engine for running local models. (SRT = SGLang Runtime).\n- `test`: The test utilities.\n- `api.py`: The public APIs.\n- `bench_offline_throughput.py`: Benchmark the performance in the offline mode.\n- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.\n- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.\n- `bench_serving.py`: Benchmark online serving with dynamic requests.\n- `check_env.py`: Check the environment variables and dependencies.\n- `global_config.py`: The global configs and constants.\n- `launch_server.py`: The entry point for launching a local server.\n- `profiler.py`: The profiling entry point to send profile requests.\n- `utils.py`: Common utilities.\n- `version.py`: Version info.\n"
  },
  {
    "path": "python/sglang/__init__.py",
    "content": "# SGLang public APIs\n\n# Install stubs early for platforms where certain dependencies are unavailable\n# (e.g. macOS/MPS has no triton, and torch.mps lacks Stream / set_device /\n# get_device_properties).  This must run before any downstream imports.\nimport sys as _sys\n\nif _sys.platform == \"darwin\":\n    try:\n        import torch as _torch\n\n        if _torch.backends.mps.is_available():\n            from sglang._triton_stub import install as _install_triton_stub\n\n            _install_triton_stub()\n            del _install_triton_stub\n\n            from sglang._mps_stub import install as _install_mps_stub\n\n            _install_mps_stub()\n            del _install_mps_stub\n        del _torch\n    except ImportError:\n        pass\ndel _sys\n\n# Frontend Language APIs\nfrom sglang.global_config import global_config\nfrom sglang.lang.api import (\n    Engine,\n    Runtime,\n    assistant,\n    assistant_begin,\n    assistant_end,\n    flush_cache,\n    function,\n    gen,\n    gen_int,\n    gen_string,\n    get_server_info,\n    image,\n    select,\n    separate_reasoning,\n    set_default_backend,\n    system,\n    system_begin,\n    system_end,\n    user,\n    user_begin,\n    user_end,\n    video,\n)\nfrom sglang.lang.backend.runtime_endpoint import RuntimeEndpoint\nfrom sglang.lang.choices import (\n    greedy_token_selection,\n    token_length_normalized,\n    unconditional_likelihood_normalized,\n)\n\n# Lazy import some libraries\nfrom sglang.utils import LazyImport\nfrom sglang.version import __version__\n\nAnthropic = LazyImport(\"sglang.lang.backend.anthropic\", \"Anthropic\")\nLiteLLM = LazyImport(\"sglang.lang.backend.litellm\", \"LiteLLM\")\nOpenAI = LazyImport(\"sglang.lang.backend.openai\", \"OpenAI\")\nVertexAI = LazyImport(\"sglang.lang.backend.vertexai\", \"VertexAI\")\n\n# Runtime Engine APIs\nServerArgs = LazyImport(\"sglang.srt.server_args\", \"ServerArgs\")\nEngine = LazyImport(\"sglang.srt.entrypoints.engine\", 
\"Engine\")\n\n__all__ = [\n    \"Engine\",\n    \"Runtime\",\n    \"assistant\",\n    \"assistant_begin\",\n    \"assistant_end\",\n    \"flush_cache\",\n    \"function\",\n    \"gen\",\n    \"gen_int\",\n    \"gen_string\",\n    \"get_server_info\",\n    \"image\",\n    \"select\",\n    \"separate_reasoning\",\n    \"set_default_backend\",\n    \"system\",\n    \"system_begin\",\n    \"system_end\",\n    \"user\",\n    \"user_begin\",\n    \"user_end\",\n    \"video\",\n    \"RuntimeEndpoint\",\n    \"greedy_token_selection\",\n    \"token_length_normalized\",\n    \"unconditional_likelihood_normalized\",\n    \"ServerArgs\",\n    \"Anthropic\",\n    \"LiteLLM\",\n    \"OpenAI\",\n    \"VertexAI\",\n    \"global_config\",\n    \"__version__\",\n]\n"
  },
  {
    "path": "python/sglang/_mps_stub.py",
    "content": "\"\"\"Stub implementations for APIs missing from ``torch.mps``.\n\n``torch.mps`` lacks several APIs that ``torch.cuda`` provides (``Stream``,\n``set_device``, ``get_device_properties``, …).  Rather than scattering\n``hasattr`` / ``getattr`` guards throughout the codebase, we monkey-patch\n``torch.mps`` once at startup so that generic device-agnostic code paths\njust work.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport functools\nfrom dataclasses import dataclass, field\nfrom typing import Any\n\n\nclass Stream:\n    \"\"\"Minimal stand-in for ``torch.cuda.Stream``.\n\n    MPS does not expose user-visible streams.  Every method is a no-op so\n    that code written for CUDA's multi-stream model still runs.\n    \"\"\"\n\n    def __init__(self, device: Any = None, priority: int = 0) -> None:\n        pass\n\n    def synchronize(self) -> None:\n        pass\n\n    def wait_stream(self, stream: Any) -> None:\n        pass\n\n    def wait_event(self, event: Any) -> None:\n        pass\n\n    def record_event(self, event: Any = None) -> Any:\n        return None\n\n    def query(self) -> bool:\n        return True\n\n    # context-manager protocol (``with stream:``)\n    def __enter__(self) -> \"Stream\":\n        return self\n\n    def __exit__(self, *args: Any) -> None:\n        pass\n\n\nclass Event:\n    \"\"\"Minimal stand-in for ``torch.cuda.Event``.\"\"\"\n\n    def __init__(self, enable_timing: bool = False) -> None:\n        pass\n\n    def record(self, stream: Any = None) -> None:\n        pass\n\n    def wait(self, stream: Any = None) -> None:\n        pass\n\n    def query(self) -> bool:\n        return True\n\n    def synchronize(self) -> None:\n        pass\n\n    def elapsed_time(self, end_event: Any) -> float:\n        return 0.0\n\n\n_default_stream = Stream()\n\n\ndef current_stream(device: Any = None) -> Stream:\n    \"\"\"Return the default (and only) MPS stream.\"\"\"\n    return _default_stream\n\n\ndef stream(s: Any) -> 
Stream:\n    \"\"\"Return a context manager that is a no-op on MPS.\"\"\"\n    return s if s is not None else _default_stream\n\n\ndef set_device(device: Any) -> None:  # noqa: ARG001\n    \"\"\"Set the current device. This is a no-op for MPS as it has exactly one device.\"\"\"\n    pass\n\n\ndef current_device() -> int:\n    \"\"\"Return the index of the current MPS device (always 0).\"\"\"\n    return 0\n\n\ndef device_count() -> int:\n    \"\"\"Return the number of available MPS devices (always 1).\"\"\"\n    return 1\n\n\n@dataclass\nclass _MPSDeviceProperties:\n    \"\"\"Mimics the object returned by ``torch.cuda.get_device_properties``.\"\"\"\n\n    name: str = \"Apple MPS\"\n    total_memory: int = 0  # populated at install time\n    multi_processor_count: int = 0\n    warp_size: int = 32\n    is_integrated: bool = True\n    major: int = 0\n    minor: int = 0\n    # Extra attrs some callers inspect\n    _extra: dict = field(default_factory=dict)\n\n    def __getattr__(self, name: str) -> Any:\n        # Return a safe default for any attribute we didn't anticipate\n        try:\n            return self._extra[name]\n        except KeyError:\n            return None\n\n\n_cached_props: _MPSDeviceProperties | None = None\n\n\ndef get_device_properties(device: Any = 0) -> _MPSDeviceProperties:  # noqa: ARG001\n    \"\"\"Return the properties of the MPS device. 
Results are cached after first call.\"\"\"\n    global _cached_props\n    if _cached_props is None:\n        import psutil\n\n        _cached_props = _MPSDeviceProperties(\n            total_memory=psutil.virtual_memory().total,\n        )\n    return _cached_props\n\n\nclass _MPSMemoryTracker:\n    \"\"\"Tracks peak memory values on top of ``torch.mps`` current-value APIs.\n\n    * ``memory_allocated`` → ``torch.mps.current_allocated_memory()``\n    * ``memory_reserved``  → ``torch.mps.driver_allocated_memory()``\n    * ``max_memory_*``     → high-water marks of the above\n    \"\"\"\n\n    def __init__(self) -> None:\n        self._peak_allocated: int = 0\n        self._peak_reserved: int = 0\n\n    def memory_allocated(self, device: Any = None) -> int:  # noqa: ARG002\n        import torch\n\n        val = torch.mps.current_allocated_memory()\n        if val > self._peak_allocated:\n            self._peak_allocated = val\n        return val\n\n    def memory_reserved(self, device: Any = None) -> int:  # noqa: ARG002\n        import torch\n\n        val = torch.mps.driver_allocated_memory()\n        if val > self._peak_reserved:\n            self._peak_reserved = val\n        return val\n\n    def max_memory_allocated(self, device: Any = None) -> int:  # noqa: ARG002\n        self.memory_allocated()\n        return self._peak_allocated\n\n    def max_memory_reserved(self, device: Any = None) -> int:  # noqa: ARG002\n        self.memory_reserved()\n        return self._peak_reserved\n\n    def reset_peak_memory_stats(self, device: Any = None) -> None:  # noqa: ARG002\n        import torch\n\n        self._peak_allocated = torch.mps.current_allocated_memory()\n        self._peak_reserved = torch.mps.driver_allocated_memory()\n\n\n_memory_tracker = _MPSMemoryTracker()\n\n\ndef _patch_non_blocking() -> None:\n    \"\"\"Force ``non_blocking=False`` for copies targeting the MPS device.\n\n    Unlike CUDA, MPS does not guarantee that a subsequent kernel on the same\n    
\"stream\" will wait for an async host-to-device transfer to finish.  Reading\n    the tensor before the transfer completes yields uninitialised (garbage)\n    data.  Patching ``Tensor.to`` and ``Tensor.copy_`` centrally avoids having\n    to sprinkle ``non_blocking=not is_mps()`` at every call-site.\n    \"\"\"\n    import torch\n\n    _original_to = torch.Tensor.to\n\n    @functools.wraps(_original_to)\n    def _patched_to(self, *args, **kwargs):\n        if kwargs.get(\"non_blocking\"):\n            # Detect target device from positional or keyword args\n            device = None\n            if args and isinstance(args[0], (str, torch.device)):\n                device = torch.device(args[0]) if isinstance(args[0], str) else args[0]\n            elif \"device\" in kwargs:\n                d = kwargs[\"device\"]\n                device = torch.device(d) if isinstance(d, str) else d\n            if device is not None and device.type == \"mps\":\n                kwargs = {**kwargs, \"non_blocking\": False}\n        return _original_to(self, *args, **kwargs)\n\n    torch.Tensor.to = _patched_to\n\n    _original_copy_ = torch.Tensor.copy_\n\n    @functools.wraps(_original_copy_)\n    def _patched_copy_(self, src, non_blocking=False):\n        if non_blocking and self.device.type == \"mps\":\n            non_blocking = False\n        return _original_copy_(self, src, non_blocking=non_blocking)\n\n    torch.Tensor.copy_ = _patched_copy_\n\n\n_installed = False\n\n\ndef install() -> None:\n    \"\"\"Patch ``torch.mps`` with the stubs above.  
Safe to call multiple times.\"\"\"\n    global _installed\n    if _installed:\n        return\n\n    import torch\n\n    mps = torch.mps\n    # Only patch attributes that are actually missing\n    for name, obj in [\n        (\"Stream\", Stream),\n        (\"Event\", Event),\n        (\"current_stream\", current_stream),\n        (\"stream\", stream),\n        (\"set_device\", set_device),\n        (\"current_device\", current_device),\n        (\"device_count\", device_count),\n        (\"get_device_properties\", get_device_properties),\n        (\"reset_peak_memory_stats\", _memory_tracker.reset_peak_memory_stats),\n        (\"memory_allocated\", _memory_tracker.memory_allocated),\n        (\"memory_reserved\", _memory_tracker.memory_reserved),\n        (\"max_memory_allocated\", _memory_tracker.max_memory_allocated),\n        (\"max_memory_reserved\", _memory_tracker.max_memory_reserved),\n    ]:\n        if not hasattr(mps, name):\n            setattr(mps, name, obj)\n\n    _patch_non_blocking()\n\n    _installed = True\n"
  },
  {
    "path": "python/sglang/_triton_stub.py",
    "content": "\"\"\"\nMock triton module for platforms where triton is not available (e.g., macOS/MPS).\n\nThis module provides stub implementations of triton APIs so that modules which\nimport triton at the top level can be loaded without error.  The actual triton\nkernels are never executed on these platforms – alternative backends (e.g. SDPA\nfor MPS) are used instead.\n\nUsage – call ``install()`` **before** any ``import triton`` in the process:\n\n    from sglang._triton_stub import install\n    install()\n\"\"\"\n\nimport sys\nimport types\n\n\nclass _StubBase:\n    \"\"\"A base class that any mock attribute can safely be subclassed from.\n\n    Used when external code does ``class Foo(triton.runtime.KernelInterface):``.\n    \"\"\"\n\n    def __init_subclass__(cls, **kwargs):\n        super().__init_subclass__(**kwargs)\n\n\nclass _MockModule(types.ModuleType):\n    \"\"\"A module whose every attribute is itself a ``_MockModule``.\n\n    When called (e.g. ``@triton.jit``), it acts as a pass-through decorator so\n    that kernel *definitions* are syntactically valid even though they will never\n    be compiled.\n    \"\"\"\n\n    def __init__(self, name: str):\n        super().__init__(name)\n        self.__path__: list[str] = []  # make it look like a package\n        self.__package__ = name\n        self.__file__ = __file__\n        self._children: dict[str, object] = {}\n        # Set __spec__ so that importlib.util.find_spec() works on cached modules\n        import importlib\n\n        self.__spec__ = importlib.machinery.ModuleSpec(name, None, is_package=True)\n\n    def __getattr__(self, name: str):\n        \"\"\"Handle attribute access by creating and returning a child _MockModule.\"\"\"\n        if name.startswith(\"__\") and name.endswith(\"__\"):\n            raise AttributeError(name)\n        full = f\"{self.__name__}.{name}\"\n        if full in sys.modules:\n            return sys.modules[full]\n        # If the name looks like a class 
(CamelCase / uppercase), return a\n        # stub class that can be used as a base class for inheritance.\n        if name[0:1].isupper():\n            stub_cls = type(name, (_StubBase,), {\"__module__\": self.__name__})\n            self._children[name] = stub_cls\n            return stub_cls\n        child = _MockModule(full)\n        sys.modules[full] = child\n        self._children[name] = child\n        return child\n\n    def __call__(self, *args, **kwargs):\n        # Direct decorator usage:  @triton.jit  (receives the function)\n        if len(args) == 1 and callable(args[0]) and not kwargs:\n            return args[0]\n\n        # Parameterised decorator: @triton.jit(...)  → returns a decorator\n        def _decorator(fn):\n            return fn\n\n        return _decorator\n\n    def __instancecheck__(self, instance):\n        \"\"\"Return False for all instance checks against the mock.\"\"\"\n        return False\n\n    def __contains__(self, item):\n        \"\"\"Return False for all membership checks.\"\"\"\n        return False\n\n    def __iter__(self):\n        return iter([])\n\n    def __len__(self):\n        return 0\n\n    def __bool__(self):\n        return False\n\n    def __repr__(self):\n        return f\"<triton-stub {self.__name__!r}>\"\n\n\ndef _cdiv(a: int, b: int) -> int:\n    \"\"\"Ceiling division – mirrors ``triton.cdiv``.\"\"\"\n    return -(a // -b)\n\n\ndef _next_power_of_2(n: int) -> int:\n    \"\"\"Mirrors ``triton.next_power_of_2``.\"\"\"\n    return 1 << (n - 1).bit_length() if n > 0 else 1\n\n\nclass _Config:\n    \"\"\"Minimal stand-in for ``triton.Config`` used in ``@triton.autotune``.\"\"\"\n\n    def __init__(self, kwargs=None, num_warps=4, num_stages=2, **extra):\n        self.kwargs = kwargs or {}\n        self.num_warps = num_warps\n        self.num_stages = num_stages\n\n\nclass _TritonFinder:\n    \"\"\"A meta-path finder that intercepts all ``import triton.*`` statements.\n\n    When Python encounters ``import 
triton.backends.compiler``, it walks the\n    dotted path and tries to import each component.  Our mock module's\n    ``__getattr__`` handles *attribute* access, but the import machinery uses\n    ``importlib`` finders, not attribute access, for sub-module resolution.\n    This finder bridges that gap by creating ``_MockModule`` instances for any\n    ``triton.*`` sub-module that isn't already in ``sys.modules``.\n    \"\"\"\n\n    def find_spec(self, fullname, path=None, target=None):\n        # Python 3.12+ no longer falls back to the legacy find_module()/\n        # load_module() protocol, so implement find_spec() instead.\n        if fullname == \"triton\" or fullname.startswith(\"triton.\"):\n            import importlib.machinery\n\n            return importlib.machinery.ModuleSpec(fullname, self, is_package=True)\n        return None\n\n    def create_module(self, spec):\n        # The import system registers the result in ``sys.modules`` and sets it\n        # as an attribute of the parent package after loading.\n        return _MockModule(spec.name)\n\n    def exec_module(self, module):\n        # Nothing to execute: attributes are created lazily by ``__getattr__``.\n        pass\n\n\ndef _make_mock(name: str) -> _MockModule:\n    \"\"\"Create a ``_MockModule`` and register it in ``sys.modules``.\"\"\"\n    mod = _MockModule(name)\n    sys.modules[name] = mod\n    return mod\n\n\ndef install() -> None:\n    \"\"\"Register a mock ``triton`` package in *sys.modules*.\n\n    This is a no-op if a real ``triton`` is already importable.\n    \"\"\"\n    if \"triton\" in sys.modules:\n        return\n    # Check whether a real triton exists before installing the stub.\n    import importlib.util\n\n    if importlib.util.find_spec(\"triton\") is not None:\n        return\n\n    # Register the meta-path finder FIRST so that any ``import triton.X``\n    # during the rest of install() (or later) is handled.\n    sys.meta_path.insert(0, _TritonFinder())\n\n    triton = _make_mock(\"triton\")\n    triton.__version__ = \"3.0.0\"\n    triton.cdiv = _cdiv\n    
triton.next_power_of_2 = _next_power_of_2\n    triton.Config = _Config\n\n    # triton.language  (commonly imported as ``tl``)\n    tl = _make_mock(\"triton.language\")\n\n    class _constexpr:\n        \"\"\"Stand-in for ``tl.constexpr`` – works as both annotation and value wrapper.\"\"\"\n\n        def __init__(self, value=None):\n            self.value = value\n\n        def __repr__(self):\n            return f\"constexpr({self.value!r})\"\n\n    tl.constexpr = _constexpr\n    triton.language = tl\n\n    # triton.language.extra.libdevice\n    extra = _make_mock(\"triton.language.extra\")\n    tl.extra = extra\n    libdevice = _make_mock(\"triton.language.extra.libdevice\")\n    extra.libdevice = libdevice\n\n    # triton.runtime.jit  (JITFunction used in isinstance checks)\n    runtime = _make_mock(\"triton.runtime\")\n    triton.runtime = runtime\n    jit_mod = _make_mock(\"triton.runtime.jit\")\n\n    class _JITFunction:\n        \"\"\"Dummy so ``isinstance(fn, triton.runtime.jit.JITFunction)`` works.\"\"\"\n\n        pass\n\n    jit_mod.JITFunction = _JITFunction\n    runtime.jit = jit_mod\n\n    # triton.runtime.driver  (used by fla/utils.py)\n    driver = _make_mock(\"triton.runtime.driver\")\n    runtime.driver = driver\n\n    # triton.testing\n    testing = _make_mock(\"triton.testing\")\n    triton.testing = testing\n\n    # triton.tools / triton.tools.tensor_descriptor\n    tools = _make_mock(\"triton.tools\")\n    triton.tools = tools\n    td = _make_mock(\"triton.tools.tensor_descriptor\")\n    tools.tensor_descriptor = td\n\n    # triton.backends / triton.backends.compiler  (used by torch._inductor)\n    backends = _make_mock(\"triton.backends\")\n    triton.backends = backends\n    compiler = _make_mock(\"triton.backends.compiler\")\n    backends.compiler = compiler\n"
  },
  {
    "path": "python/sglang/bench_offline_throughput.py",
    "content": "\"\"\"\nBenchmark the throughput in the offline mode.\nIt accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).\n\n# Usage\n## Sharegpt dataset with default args\npython -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10\n\n## Random dataset with default args\npython -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024\n\"\"\"\n\nimport argparse\nimport asyncio\nimport dataclasses\nimport inspect\nimport json\nimport logging\nimport os\nimport random\nimport time\nfrom typing import Dict, List, Optional\n\nimport numpy as np\n\nfrom sglang.benchmark.datasets import DatasetRow, get_dataset\nfrom sglang.benchmark.datasets.random import sample_random_requests\nfrom sglang.benchmark.utils import get_tokenizer, set_ulimit\nfrom sglang.lang.backend.runtime_endpoint import Runtime\nfrom sglang.srt.entrypoints.engine import Engine\nfrom sglang.srt.server_args import ServerArgs\n\n\n@dataclasses.dataclass\nclass BenchArgs:\n    backend: str = \"engine\"\n    result_filename: str = \"\"\n    dataset_name: str = \"sharegpt\"\n    dataset_path: str = \"\"\n    num_prompts: int = 1000\n    sharegpt_output_len: Optional[int] = None\n    sharegpt_context_len: Optional[int] = None\n    random_input_len: int = 1024\n    random_output_len: int = 1024\n    random_range_ratio: float = 0.0\n    gsp_num_groups: int = 64\n    gsp_prompts_per_group: int = 16\n    gsp_system_prompt_len: int = 2048\n    gsp_question_len: int = 128\n    gsp_output_len: int = 256\n    seed: int = 1\n    disable_ignore_eos: bool = False\n    extra_request_body: Optional[str] = None\n    apply_chat_template: bool = False\n    profile: bool = False\n    skip_warmup: bool = False\n    do_not_exit: bool = False\n    prompt_suffix: str = \"\"\n    return_logprob: bool = False\n    
logprob_start_len: int = -1\n\n    @staticmethod\n    def add_cli_args(parser: argparse.ArgumentParser):\n        parser.add_argument(\"--backend\", type=str, default=BenchArgs.backend)\n        parser.add_argument(\n            \"--result-filename\", type=str, default=BenchArgs.result_filename\n        )\n        parser.add_argument(\n            \"--dataset-name\",\n            type=str,\n            default=\"sharegpt\",\n            choices=[\"sharegpt\", \"random\", \"generated-shared-prefix\"],\n            help=\"Name of the dataset to benchmark on.\",\n        )\n        parser.add_argument(\n            \"--dataset-path\", type=str, default=\"\", help=\"Path to the dataset.\"\n        )\n        parser.add_argument(\n            \"--num-prompts\",\n            type=int,\n            default=BenchArgs.num_prompts,\n            help=\"Number of prompts to process. Default is 1000.\",\n        )\n        parser.add_argument(\n            \"--sharegpt-output-len\",\n            type=int,\n            default=BenchArgs.sharegpt_output_len,\n            help=\"Output length for each request. Overrides the output length from the ShareGPT dataset.\",\n        )\n        parser.add_argument(\n            \"--sharegpt-context-len\",\n            type=int,\n            default=BenchArgs.sharegpt_context_len,\n            help=\"The context length of the model for the ShareGPT dataset. 
Requests longer than the context length will be dropped.\",\n        )\n        parser.add_argument(\n            \"--random-input-len\",\n            type=int,\n            default=BenchArgs.random_input_len,\n            help=\"Number of input tokens per request, used only for random dataset.\",\n        )\n        parser.add_argument(\n            \"--random-output-len\",\n            type=int,\n            default=BenchArgs.random_output_len,\n            help=\"Number of output tokens per request, used only for random dataset.\",\n        )\n        parser.add_argument(\n            \"--random-range-ratio\",\n            type=float,\n            default=BenchArgs.random_range_ratio,\n            help=\"Range of sampled ratio of input/output length, \"\n            \"used only for random dataset.\",\n        )\n        parser.add_argument(\n            \"--gsp-num-groups\",\n            type=int,\n            default=BenchArgs.gsp_num_groups,\n            help=\"Number of groups with shared prefix, used \"\n            \"only for the generated-shared-prefix dataset.\",\n        )\n        parser.add_argument(\n            \"--gsp-prompts-per-group\",\n            type=int,\n            default=BenchArgs.gsp_prompts_per_group,\n            help=\"Number of prompts per group of shared prefix, used \"\n            \"only for the generated-shared-prefix dataset.\",\n        )\n        parser.add_argument(\n            \"--gsp-system-prompt-len\",\n            type=int,\n            default=BenchArgs.gsp_system_prompt_len,\n            help=\"System prompt length, used \"\n            \"only for the generated-shared-prefix dataset.\",\n        )\n        parser.add_argument(\n            \"--gsp-question-len\",\n            type=int,\n            default=BenchArgs.gsp_question_len,\n            help=\"Question length, used \"\n            \"only for the generated-shared-prefix dataset.\",\n        )\n        parser.add_argument(\n            \"--gsp-output-len\",\n            type=int,\n            default=BenchArgs.gsp_output_len,\n            
help=\"Target length in tokens for outputs in generated-shared-prefix dataset\",\n        )\n        parser.add_argument(\"--seed\", type=int, default=1, help=\"The random seed.\")\n        parser.add_argument(\n            \"--disable-ignore-eos\",\n            action=\"store_true\",\n            help=\"Do not ignore the EOS token.\",\n        )\n        parser.add_argument(\n            \"--extra-request-body\",\n            metavar='{\"key1\": \"value1\", \"key2\": \"value2\"}',\n            type=str,\n            default=BenchArgs.extra_request_body,\n            help=\"Append the given JSON object to the request payload. You can use this to specify \"\n            \"additional generate params like sampling params.\",\n        )\n        parser.add_argument(\n            \"--apply-chat-template\",\n            action=\"store_true\",\n            help=\"Apply chat template\",\n        )\n        parser.add_argument(\n            \"--profile\",\n            action=\"store_true\",\n            help=\"Use Torch Profiler. The endpoint must be launched with \"\n            \"SGLANG_TORCH_PROFILER_DIR to enable profiler.\",\n        )\n        parser.add_argument(\n            \"--skip-warmup\",\n            action=\"store_true\",\n            help=\"Skip the warmup batches.\",\n        )\n        parser.add_argument(\n            \"--do-not-exit\",\n            action=\"store_true\",\n            help=\"Do not exit the program. 
This is useful for nsys profile with --duration and --delay.\",\n        )\n        parser.add_argument(\n            \"--prompt-suffix\",\n            type=str,\n            default=\"\",\n            help=\"Suffix applied to the end of all user prompts, followed by assistant prompt suffix.\",\n        )\n        parser.add_argument(\n            \"--return-logprob\",\n            action=\"store_true\",\n            help=\"Enable returning log probabilities.\",\n        )\n        parser.add_argument(\n            \"--logprob-start-len\",\n            type=int,\n            default=-1,\n            help=\"Start length for logprob. -1 means only return logprobs for output tokens (default). 0 means return logprobs for all tokens including input.\",\n        )\n\n    @classmethod\n    def from_cli_args(cls, args: argparse.Namespace):\n        attrs = [attr.name for attr in dataclasses.fields(cls)]\n        return cls(**{attr: getattr(args, attr) for attr in attrs})\n\n\ndef throughput_test_once(\n    backend_name: str,\n    backend,\n    reqs: List[DatasetRow],\n    ignore_eos: bool,\n    extra_request_body: Dict,\n    profile: bool,\n    return_logprob: bool = False,\n    logprob_start_len: int = -1,\n):\n    measurement_results = {\n        \"backend\": backend_name,\n        \"successful_requests\": len(reqs),\n        \"total_latency\": -1,\n        \"total_input_tokens\": sum(r.prompt_len for r in reqs),\n        \"total_output_tokens\": -1,\n        \"request_throughput\": -1,\n        \"input_throughput\": -1,\n        \"output_throughput\": -1,\n        \"total_throughput\": -1,\n    }\n\n    prompt = [r.prompt for r in reqs]\n    sampling_params = [\n        {\n            \"temperature\": 0,\n            \"max_new_tokens\": r.output_len,\n            \"ignore_eos\": ignore_eos,\n            **extra_request_body,\n        }\n        for r in reqs\n    ]\n\n    if profile:\n        assert (\n            \"SGLANG_TORCH_PROFILER_DIR\" in os.environ\n        ), 
\"Please set SGLANG_TORCH_PROFILER_DIR.\"\n        os.makedirs(os.environ[\"SGLANG_TORCH_PROFILER_DIR\"], exist_ok=True)\n        backend.start_profile()\n\n    st = time.perf_counter()\n    gen_out = backend.generate(\n        prompt=prompt,\n        sampling_params=sampling_params,\n        return_logprob=return_logprob,\n        logprob_start_len=logprob_start_len,\n    )\n    latency = time.perf_counter() - st\n\n    if profile:\n        dir = os.getenv(\"SGLANG_TORCH_PROFILER_DIR\")\n        known_files = set(os.listdir(dir))\n        backend.stop_profile()\n        monitor_trace_file(known_files, dir)\n\n    if backend_name == \"runtime\":\n        gen_out = json.loads(gen_out)\n\n    server_info = backend.get_server_info()\n\n    measurement_results[\"total_latency\"] = latency\n    measurement_results[\"total_output_tokens\"] = sum(\n        o[\"meta_info\"][\"completion_tokens\"] for o in gen_out\n    )\n    measurement_results[\"request_throughput\"] = (\n        measurement_results[\"successful_requests\"] / latency\n    )\n    measurement_results[\"input_throughput\"] = (\n        measurement_results[\"total_input_tokens\"] / latency\n    )\n    measurement_results[\"output_throughput\"] = (\n        measurement_results[\"total_output_tokens\"] / latency\n    )\n    measurement_results[\"total_throughput\"] = (\n        measurement_results[\"total_input_tokens\"]\n        + measurement_results[\"total_output_tokens\"]\n    ) / latency\n\n    if inspect.isawaitable(server_info):\n        server_info = asyncio.run(server_info)\n\n    measurement_results[\"last_gen_throughput\"] = server_info[\"internal_states\"][0][\n        \"last_gen_throughput\"\n    ]\n\n    return measurement_results\n\n\ndef monitor_trace_file(known_files, directory, interval=1):\n    print(f\"Monitoring {directory} for new trace files...\")\n\n    while True:\n        flag = False\n        time.sleep(interval)\n        current_files = set(os.listdir(directory))\n\n        new_files 
= current_files - known_files\n        for new_file in new_files:\n            new_file_path = os.path.join(directory, new_file)\n            print(f\"New file detected: {new_file}\")\n\n            previous_size = 0\n            while True:\n                try:\n                    current_size = os.path.getsize(new_file_path)\n                except FileNotFoundError:\n                    print(f\"File {new_file} is no longer accessible.\")\n                    break\n\n                if current_size > previous_size:\n                    previous_size = current_size\n                else:\n                    flag = True\n                    break\n\n                time.sleep(interval)\n        if flag:\n            break\n\n\ndef _create_ray_engine_backend(server_args: ServerArgs):\n    \"\"\"Create a RayEngine inside a Ray actor on a placement group.\n\n    RayEngine requires a placement group, so we launch it inside a Ray actor\n    and return a lightweight proxy that forwards calls via ray.get().\n    \"\"\"\n    import ray\n    from ray.runtime_env import RuntimeEnv\n    from ray.util.placement_group import placement_group\n    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy\n\n    env_vars = {\"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES\": \"1\"}\n    if os.environ.get(\"HF_TOKEN\"):\n        env_vars[\"HF_TOKEN\"] = os.environ[\"HF_TOKEN\"]\n    if not ray.is_initialized():\n        ray.init(runtime_env=RuntimeEnv(env_vars=env_vars))\n\n    total_gpus = server_args.tp_size * server_args.pp_size\n    pg = placement_group([{\"CPU\": 1, \"GPU\": total_gpus}], strategy=\"STRICT_PACK\")\n    ray.get(pg.ready())\n\n    @ray.remote\n    class _EngineActor:\n        def __init__(self, **kwargs):\n            from sglang.srt.ray.engine import RayEngine\n\n            self.engine = RayEngine(**kwargs)\n\n        def call(self, method, **kwargs):\n            return getattr(self.engine, method)(**kwargs)\n\n    actor = 
_EngineActor.options(\n        num_cpus=1,\n        num_gpus=0,\n        scheduling_strategy=PlacementGroupSchedulingStrategy(\n            placement_group=pg,\n            placement_group_bundle_index=0,\n        ),\n    ).remote(**dataclasses.asdict(server_args))\n\n    class _Proxy:\n        \"\"\"Forwards method calls to the remote RayEngine actor.\"\"\"\n\n        def generate(self, **kwargs):\n            return ray.get(actor.call.remote(\"generate\", **kwargs))\n\n        def get_server_info(self, **kwargs):\n            return ray.get(actor.call.remote(\"get_server_info\", **kwargs))\n\n        def start_profile(self, **kwargs):\n            return ray.get(actor.call.remote(\"start_profile\", **kwargs))\n\n        def stop_profile(self, **kwargs):\n            return ray.get(actor.call.remote(\"stop_profile\", **kwargs))\n\n        def shutdown(self):\n            try:\n                ray.get(actor.call.remote(\"shutdown\"), timeout=60)\n            except Exception:\n                pass\n            try:\n                ray.util.remove_placement_group(pg)\n            except Exception:\n                pass\n\n    return _Proxy()\n\n\ndef throughput_test(\n    server_args: ServerArgs,\n    bench_args: BenchArgs,\n):\n    if bench_args.backend == \"engine\":\n        if server_args.use_ray:\n            backend = _create_ray_engine_backend(server_args)\n        else:\n            backend = Engine(**dataclasses.asdict(server_args))\n        if not backend:\n            raise ValueError(\"Please provide valid engine arguments\")\n    elif bench_args.backend == \"runtime\":\n        backend = Runtime(**dataclasses.asdict(server_args))\n    else:\n        raise ValueError('Please set backend to either \"engine\" or \"runtime\"')\n\n    tokenizer_id = server_args.tokenizer_path or server_args.model_path\n    tokenizer = get_tokenizer(tokenizer_id)\n\n    # Set global environments\n    set_ulimit()\n    random.seed(bench_args.seed)\n    
np.random.seed(bench_args.seed)\n\n    # Parse the extra request body (a JSON string) from the benchmark args\n    extra_request_body = {}\n    if bench_args.extra_request_body:\n        extra_request_body = json.loads(bench_args.extra_request_body)\n\n    # Read dataset\n    input_requests = get_dataset(bench_args, tokenizer)\n\n    warmup_requests = sample_random_requests(\n        input_len=256,\n        output_len=16,\n        num_prompts=min(bench_args.num_prompts, 16),\n        range_ratio=1.0,\n        tokenizer=tokenizer,\n        dataset_path=bench_args.dataset_path,\n    )\n\n    # Warm up\n    if not bench_args.skip_warmup:\n        logging.info(\"\\nWarmup...\")\n        throughput_test_once(\n            backend_name=bench_args.backend,\n            backend=backend,\n            reqs=warmup_requests,\n            ignore_eos=not bench_args.disable_ignore_eos,\n            extra_request_body=extra_request_body,\n            profile=False,\n            return_logprob=bench_args.return_logprob,\n            logprob_start_len=bench_args.logprob_start_len,\n        )\n        time.sleep(0.5)\n\n    logging.info(\"\\nBenchmark...\")\n    result = throughput_test_once(\n        backend_name=bench_args.backend,\n        backend=backend,\n        reqs=input_requests,\n        ignore_eos=not bench_args.disable_ignore_eos,\n        extra_request_body=extra_request_body,\n        profile=bench_args.profile,\n        return_logprob=bench_args.return_logprob,\n        logprob_start_len=bench_args.logprob_start_len,\n    )\n    backend.shutdown()\n\n    if bench_args.result_filename:\n        with open(bench_args.result_filename, \"a\") as fout:\n            fout.write(json.dumps(result) + \"\\n\")\n\n    print(\n        \"\\n{s:{c}^{n}}\".format(s=\" Offline Throughput Benchmark Result \", n=50, c=\"=\")\n    )\n    print(\"{:<40} {:<10}\".format(\"Backend:\", result[\"backend\"]))\n    print(\"{:<40} {:<10}\".format(\"Successful requests:\", result[\"successful_requests\"]))\n    print(\"{:<40} {:<10.2f}\".format(\"Benchmark 
duration (s):\", result[\"total_latency\"]))\n    print(\"{:<40} {:<10}\".format(\"Total input tokens:\", result[\"total_input_tokens\"]))\n    print(\n        \"{:<40} {:<10}\".format(\"Total generated tokens:\", result[\"total_output_tokens\"])\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Last generation throughput (tok/s):\", result[\"last_gen_throughput\"]\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Request throughput (req/s):\", result[\"request_throughput\"]\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Input token throughput (tok/s):\", result[\"input_throughput\"]\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Output token throughput (tok/s):\", result[\"output_throughput\"]\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Total token throughput (tok/s):\", result[\"total_throughput\"]\n        )\n    )\n    print(\"=\" * 50)\n\n    return result\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    ServerArgs.add_cli_args(parser)\n    BenchArgs.add_cli_args(parser)\n    args = parser.parse_args()\n\n    # handling ModelScope model downloads\n    if os.getenv(\"SGLANG_USE_MODELSCOPE\", \"false\").lower() in (\"true\", \"1\"):\n        if os.path.exists(args.model_path):\n            print(f\"Using local model path: {args.model_path}\")\n        else:\n            try:\n                from modelscope import snapshot_download\n\n                print(f\"Using ModelScope to download model: {args.model_path}\")\n\n                # download the model and replace args.model_path\n                args.model_path = snapshot_download(\n                    args.model_path,\n                )\n                print(f\"Model downloaded to: {args.model_path}\")\n            except Exception as e:\n                print(f\"ModelScope download failed: {str(e)}\")\n            
    raise e\n\n    server_args = ServerArgs.from_cli_args(args)\n    bench_args = BenchArgs.from_cli_args(args)\n\n    logging.basicConfig(\n        level=getattr(logging, server_args.log_level.upper()),\n        format=\"%(message)s\",\n    )\n\n    throughput_test(server_args, bench_args)\n\n    # Keep the process alive (e.g. for nsys --duration/--delay) without busy-spinning the CPU\n    while bench_args.do_not_exit:\n        time.sleep(1)\n"
  },
  {
    "path": "python/sglang/bench_one_batch.py",
    "content": "\"\"\"\nBenchmark the latency of running a single static batch without a server.\n\nThis script does not launch a server and uses the low-level APIs.\nIt accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).\n\n# Usage (latency test)\n## with dummy weights:\npython -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy\n## sweep through multiple data points and store (append) the results in a jsonl file:\npython -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run\n## run with profiling:\npython -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile\n## run with profiling to custom directory:\nexport SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log\npython -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile\n## run with CUDA profiler (nsys):\nnsys profile --force-overwrite=true -o bench_one_batch python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile --profile-activities CUDA_PROFILER\n# Usage (correctness test):\npython -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct\n\n## Reference output (of the correctness test above, can be gpu dependent):\ninput_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]\n\nprefill logits (first half): tensor([[-10.0312,  -9.5000,   0.8931,  ...,  -4.9414,  -3.2422,  -3.3633],\n        [-10.0312,  -9.5000,   0.8931,  ...,  -4.9414,  -3.2422,  -3.3633],\n        [ -9.1875, -10.2500,   2.7129,  ...,  -4.3359,  -4.0664,  -4.1328]],\n       device='cuda:0')\n\nprefill logits (final): tensor([[-8.3125, -7.1172,  
3.3457,  ..., -4.9570, -4.1328, -3.4141],\n        [-8.9141, -9.0156,  4.1445,  ..., -4.9922, -4.4961, -4.0781],\n        [-9.6328, -9.0547,  4.0195,  ..., -5.3047, -4.7148, -4.4570]],\n       device='cuda:0')\n\n========== Prompt 0 ==========\n<s> The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\n\n\n========== Prompt 1 ==========\n<s> The capital of the United Kindom is London.\nThe capital of the United Kingdom is London.\nThe capital of the\n\n========== Prompt 2 ==========\n<s> Today is a sunny day and I like to go for a walk in the park.\nI'm going to the park\n\"\"\"\n\nimport argparse\nimport copy\nimport dataclasses\nimport itertools\nimport json\nimport logging\nimport multiprocessing\nimport os\nimport time\nfrom types import SimpleNamespace\nfrom typing import Optional, Tuple\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\nfrom sglang.srt.configs.model_config import ModelConfig\nfrom sglang.srt.distributed.parallel_state import destroy_distributed_environment\nfrom sglang.srt.entrypoints.engine import _set_envs_and_config\nfrom sglang.srt.layers.moe import initialize_moe_config\nfrom sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config\nfrom sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config\nfrom sglang.srt.managers.schedule_batch import Req, ScheduleBatch\nfrom sglang.srt.managers.scheduler_dp_attn_mixin import prepare_mlp_sync_batch_raw\nfrom sglang.srt.mem_cache.base_prefix_cache import EvictParams\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\nfrom sglang.srt.model_executor.model_runner import ModelRunner\nfrom sglang.srt.sampling.sampling_params import SamplingParams\nfrom sglang.srt.server_args import PortArgs, ServerArgs\nfrom sglang.srt.speculative.spec_info import SpeculativeAlgorithm\nfrom sglang.srt.utils import (\n    configure_logger,\n    get_bool_env_var,\n    kill_process_tree,\n    
maybe_reindex_device_id,\n    require_mlp_sync,\n    require_mlp_tp_gather,\n    set_gpu_proc_affinity,\n    suppress_other_loggers,\n)\nfrom sglang.srt.utils.hf_transformers_utils import get_tokenizer\n\n\ndef start_profile(profile_activities, profile_record_shapes=False, rank_print=print):\n    \"\"\"\n    Abstracted function to start profiling based on profile_activities.\n    Returns profiler object (or None).\n    \"\"\"\n    if \"CUDA_PROFILER\" in profile_activities:\n        try:\n            torch.cuda.cudart().cudaProfilerStart()\n            rank_print(\"CUDA Profiler started (nsys will begin capturing)\")\n        except Exception as e:\n            rank_print(f\"Failed to start CUDA profiler: {e}\")\n        return None\n    else:\n        activities = []\n        if \"CPU\" in profile_activities:\n            activities.append(torch.profiler.ProfilerActivity.CPU)\n        if \"GPU\" in profile_activities:\n            activities.append(torch.profiler.ProfilerActivity.CUDA)\n        if \"XPU\" in profile_activities:\n            activities.append(torch.profiler.ProfilerActivity.XPU)\n        if activities:\n            profiler = torch.profiler.profile(\n                activities=activities,\n                with_stack=True,\n                record_shapes=profile_record_shapes,\n            )\n            profiler.start()\n            return profiler\n        return None\n\n\ndef stop_profile(\n    profiler,\n    profile_activities,\n    rank_print=print,\n    save_trace=False,\n    trace_filename=None,\n    stage=None,\n):\n    \"\"\"\n    Abstracted function to stop profiling based on profile_activities.\n    Optionally saves trace results and prints completion messages.\n    \"\"\"\n    if \"CUDA_PROFILER\" in profile_activities:\n        try:\n            torch.cuda.cudart().cudaProfilerStop()\n            rank_print(\"CUDA Profiler stopped (nsys should dump traces)\")\n        except Exception as e:\n            rank_print(f\"Failed to stop CUDA 
profiler: {e}\")\n    elif profiler is not None:\n        profiler.stop()\n\n    if save_trace:\n        if profiler is not None:\n            if trace_filename:\n                _save_profile_trace_results(profiler, trace_filename)\n                stage_desc = f\"for {stage}\" if stage else \"\"\n                rank_print(\n                    f\"torch profiler chrome trace {stage_desc} saved to {trace_filename}\"\n                )\n        if \"CUDA_PROFILER\" in profile_activities:\n            rank_print(f\"CUDA profiler trace for {stage} completed\")\n\n\n@dataclasses.dataclass\nclass BenchArgs:\n    run_name: str = \"default\"\n    batch_size: Tuple[int] = (1,)\n    input_len: Tuple[int] = (1024,)\n    output_len: Tuple[int] = (16,)\n    prompt_filename: str = \"\"\n    result_filename: str = \"result.jsonl\"\n    correctness_test: bool = False\n    # This is only used for correctness test\n    cut_len: int = 4\n    log_decode_step: int = 0\n    profile: bool = False\n    profile_record_shapes: bool = False\n    profile_activities: Tuple[str] = (\"CPU\", \"GPU\")\n    profile_stage: str = \"all\"\n    profile_filename_prefix: str = \"profile\"\n    profile_start_step: Optional[int] = None\n    profile_steps: Optional[int] = None\n\n    @staticmethod\n    def add_cli_args(parser: argparse.ArgumentParser):\n        parser.add_argument(\"--run-name\", type=str, default=BenchArgs.run_name)\n        parser.add_argument(\n            \"--batch-size\", type=int, nargs=\"+\", default=BenchArgs.batch_size\n        )\n        parser.add_argument(\n            \"--input-len\", type=int, nargs=\"+\", default=BenchArgs.input_len\n        )\n        parser.add_argument(\n            \"--output-len\", type=int, nargs=\"+\", default=BenchArgs.output_len\n        )\n        parser.add_argument(\n            \"--prompt-filename\", type=str, default=BenchArgs.prompt_filename\n        )\n        parser.add_argument(\n            \"--result-filename\", type=str, 
default=BenchArgs.result_filename\n        )\n        parser.add_argument(\"--correctness-test\", action=\"store_true\")\n        parser.add_argument(\"--cut-len\", type=int, default=BenchArgs.cut_len)\n        parser.add_argument(\n            \"--log-decode-step\",\n            type=int,\n            default=BenchArgs.log_decode_step,\n            help=\"Log decode latency by step. The default of 0 disables this logging.\",\n        )\n        parser.add_argument(\"--profile\", action=\"store_true\", help=\"Enable profiling.\")\n        parser.add_argument(\n            \"--profile-record-shapes\",\n            action=\"store_true\",\n            help=\"Record tensor shapes in profiling results.\",\n        )\n        parser.add_argument(\n            \"--profile-activities\",\n            type=str,\n            nargs=\"+\",\n            default=[\"CPU\", \"GPU\"],\n            choices=[\"CPU\", \"GPU\", \"CUDA_PROFILER\", \"XPU\"],\n            help=\"Profiler activities: CPU, GPU, XPU, CUDA_PROFILER. If CPU/GPU/XPU, use torch profiler. If CUDA_PROFILER, use CUDA profiler.\",\n        )\n        parser.add_argument(\n            \"--profile-stage\",\n            type=str,\n            default=BenchArgs.profile_stage,\n            choices=[\"all\", \"prefill\", \"decode\"],\n            help=\"Which stage to profile: all, prefill, or decode only.\",\n        )\n        parser.add_argument(\n            \"--profile-filename-prefix\",\n            type=str,\n            default=BenchArgs.profile_filename_prefix,\n            help=\"Prefix of the profiling file names. The full profiling result file(s) will be \"\n            '\"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len]_[stage].trace.json.gz\"',\n        )\n        parser.add_argument(\n            \"--profile-start-step\",\n            type=int,\n            default=None,\n            help=\"Decode step at which to start profiling (0-indexed). 
If not specified, defaults to output_len // 2.\",\n        )\n        parser.add_argument(\n            \"--profile-steps\",\n            type=int,\n            default=None,\n            help=\"Number of decode steps to profile starting from profile-start-step. If not specified, profiles only one step.\",\n        )\n\n    @classmethod\n    def from_cli_args(cls, args: argparse.Namespace):\n        # use the default value's type to cast the args into correct types.\n        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]\n        result = {}\n        for attr, attr_type in attrs:\n            value = getattr(args, attr)\n            # Handle None values - don't try to cast them\n            if value is None or attr_type == type(None):\n                result[attr] = value\n            else:\n                result[attr] = attr_type(value)\n        return cls(**result)\n\n\ndef load_model(server_args, port_args, gpu_id, tp_rank):\n    suppress_other_loggers()\n    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None\n    moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)\n\n    model_config = ModelConfig.from_server_args(server_args)\n    model_runner = ModelRunner(\n        model_config=model_config,\n        mem_fraction_static=server_args.mem_fraction_static,\n        gpu_id=gpu_id,\n        tp_rank=tp_rank,\n        tp_size=server_args.tp_size,\n        moe_ep_rank=moe_ep_rank,\n        moe_ep_size=server_args.ep_size,\n        pp_rank=0,\n        pp_size=1,\n        nccl_port=port_args.nccl_port,\n        server_args=server_args,\n    )\n    rank_print(f\"max_total_num_tokens={model_runner.max_total_num_tokens}\")\n    tokenizer = get_tokenizer(\n        server_args.tokenizer_path,\n        tokenizer_mode=server_args.tokenizer_mode,\n        trust_remote_code=server_args.trust_remote_code,\n    )\n    if server_args.tp_size > 1:\n        dist.barrier()\n    return model_runner, tokenizer\n\n\ndef 
prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):\n    if custom_prompts:\n        custom_input_len = len(custom_prompts)\n        bs = bench_args.batch_size[0]\n        if custom_input_len > bs:\n            logging.warning(\n                f\"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). \"\n                f\"Using the first {bs} prompts.\"\n            )\n            custom_prompts = custom_prompts[:bs]\n\n    prompts = (\n        custom_prompts\n        if custom_prompts\n        else [\n            \"The capital of France is\",\n            \"The capital of the United Kindom is\",\n            \"Today is a sunny day and I like\",\n        ]\n    )\n    input_ids = [tokenizer.encode(p) for p in prompts]\n    sampling_params = SamplingParams(\n        temperature=0,\n        max_new_tokens=BenchArgs.output_len,\n    )\n\n    reqs = []\n    for i in range(len(prompts)):\n        assert len(input_ids[i]) > bench_args.cut_len\n\n        tmp_input_ids = input_ids[i][: bench_args.cut_len]\n        req = Req(\n            rid=i,\n            origin_input_text=prompts[i],\n            origin_input_ids=tmp_input_ids,\n            sampling_params=sampling_params,\n        )\n        req.fill_ids = req.origin_input_ids\n        req.logprob_start_len = -1\n        req.set_extend_input_len(len(req.fill_ids) - len(req.prefix_indices))\n        reqs.append(req)\n\n    return input_ids, reqs\n\n\ndef prepare_extend_inputs_for_correctness_test(\n    bench_args, input_ids, reqs, model_runner\n):\n    for i in range(len(reqs)):\n        req: Req = reqs[i]\n        req.fill_ids += input_ids[i][bench_args.cut_len :]\n        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[\n            i, : bench_args.cut_len\n        ].to(req.prefix_indices.dtype)\n        req.logprob_start_len = -1\n        req.set_extend_input_len(len(req.fill_ids) - len(req.prefix_indices))\n    return reqs\n\n\ndef 
prepare_synthetic_inputs_for_latency_test(\n    batch_size, input_len, custom_inputs=None\n):\n    input_ids = (\n        custom_inputs\n        if custom_inputs\n        else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)\n    )\n    sampling_params = SamplingParams(\n        temperature=0,\n        max_new_tokens=BenchArgs.output_len,\n    )\n\n    reqs = []\n    for i in range(len(input_ids)):\n        req = Req(\n            rid=i,\n            origin_input_text=\"\",\n            origin_input_ids=list(input_ids[i]),\n            sampling_params=sampling_params,\n        )\n        req.fill_ids = req.origin_input_ids\n        req.logprob_start_len = -1\n        req.set_extend_input_len(len(req.fill_ids) - len(req.prefix_indices))\n        reqs.append(req)\n\n    return reqs\n\n\nclass TreeCacheNamespace(SimpleNamespace):\n    def supports_swa(self) -> bool:\n        return False\n\n    def supports_mamba(self) -> bool:\n        return False\n\n    def is_chunk_cache(self) -> bool:\n        return False\n\n    def is_tree_cache(self) -> bool:\n        return not self.is_chunk_cache()\n\n    def evict(self, params: EvictParams):\n        pass\n\n\n@torch.no_grad\ndef extend(reqs, model_runner):\n    # Create dummy tree_cache for benchmarks (no prefix caching, just allocation)\n    dummy_tree_cache = TreeCacheNamespace(\n        page_size=model_runner.server_args.page_size,\n        device=model_runner.device,\n        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,\n    )\n\n    batch = ScheduleBatch.init_new(\n        reqs=reqs,\n        req_to_token_pool=model_runner.req_to_token_pool,\n        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,\n        tree_cache=dummy_tree_cache,\n        model_config=model_runner.model_config,\n        enable_overlap=False,\n        spec_algorithm=SpeculativeAlgorithm.NONE,\n    )\n    batch.prepare_for_extend()\n    _maybe_prepare_mlp_sync_batch(batch, 
model_runner)\n    model_worker_batch = batch.get_model_worker_batch()\n    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)\n    logits_output = model_runner.forward(forward_batch).logits_output\n    next_token_ids = model_runner.sample(logits_output, forward_batch)\n    return next_token_ids, logits_output.next_token_logits, batch\n\n\n@torch.no_grad\ndef decode(input_token_ids, batch, model_runner):\n    batch.output_ids = input_token_ids\n    batch.prepare_for_decode()\n    _maybe_prepare_mlp_sync_batch(batch, model_runner)\n    model_worker_batch = batch.get_model_worker_batch()\n    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)\n    logits_output = model_runner.forward(forward_batch).logits_output\n    next_token_ids = model_runner.sample(logits_output, forward_batch)\n    return next_token_ids, logits_output.next_token_logits\n\n\ndef _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):\n    if require_mlp_sync(model_runner.server_args):\n        prepare_mlp_sync_batch_raw(\n            batch,\n            dp_size=model_runner.server_args.dp_size,\n            attn_tp_size=1,\n            tp_group=model_runner.tp_group,\n            get_idle_batch=None,\n            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,\n            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),\n            disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,\n            offload_tags=set(),\n        )\n\n\ndef _read_prompts_from_file(prompt_file, rank_print):\n    \"\"\"Read custom prompts from the file specified by `--prompt-filename`.\"\"\"\n    if not prompt_file:\n        return []\n    if not os.path.exists(prompt_file):\n        rank_print(\n            f\"Custom prompt file {prompt_file} not found. 
Using default inputs...\"\n        )\n        return []\n    with open(prompt_file, \"r\") as pf:\n        return pf.readlines()\n\n\ndef _get_torch_profiler_output_dir():\n    return os.environ.get(\"SGLANG_TORCH_PROFILER_DIR\", \"/tmp\")\n\n\ndef _create_torch_profiler_filename(\n    profile_filename_prefix, batch_size, input_len, output_len, stage\n):\n    output_dir = _get_torch_profiler_output_dir()\n    filename = f\"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_{stage}.trace.json.gz\"\n    return os.path.join(output_dir, filename)\n\n\ndef _save_profile_trace_results(profiler, filename):\n    parent_dir = os.path.dirname(os.path.abspath(filename))\n    os.makedirs(parent_dir, exist_ok=True)\n    profiler.export_chrome_trace(filename)\n    print(\n        profiler.key_averages(group_by_input_shape=True).table(\n            sort_by=\"self_cpu_time_total\"\n        )\n    )\n\n\ndef correctness_test(\n    server_args,\n    port_args,\n    bench_args,\n    gpu_id,\n    tp_rank,\n):\n    # Configure the logger\n    configure_logger(server_args, prefix=f\" TP{tp_rank}\")\n    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None\n\n    # Load the model\n    model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)\n\n    # Prepare inputs\n    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)\n    input_ids, reqs = prepare_inputs_for_correctness_test(\n        bench_args, tokenizer, custom_prompts\n    )\n    rank_print(f\"\\n{input_ids=}\\n\")\n\n    if bench_args.cut_len > 0:\n        # Prefill\n        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)\n        rank_print(f\"prefill logits (first half): {next_token_logits} \\n\")\n\n    # Prepare extend inputs\n    reqs = prepare_extend_inputs_for_correctness_test(\n        bench_args, input_ids, reqs, model_runner\n    )\n\n    # Extend (prefill w/ KV cache)\n    next_token_ids, 
next_token_logits, batch = extend(reqs, model_runner)\n    rank_print(f\"prefill logits (final): {next_token_logits} \\n\")\n\n    # Decode\n    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]\n    for _ in range(bench_args.output_len[0] - 1):\n        next_token_ids, _ = decode(next_token_ids, batch, model_runner)\n        next_token_ids_list = next_token_ids.tolist()\n        for i in range(len(reqs)):\n            output_ids[i].append(next_token_ids_list[i])\n\n    # Print output texts\n    for i in range(len(reqs)):\n        rank_print(f\"========== Prompt {i} ==========\")\n        rank_print(tokenizer.decode(output_ids[i]), \"\\n\")\n\n\ndef synchronize(device):\n    torch.get_device_module(device).synchronize()\n\n\ndef latency_test_run_once(\n    run_name,\n    model_runner,\n    rank_print,\n    reqs,\n    batch_size,\n    input_len,\n    output_len,\n    device,\n    log_decode_step,\n    profile,\n    profile_record_shapes,\n    profile_activities,\n    profile_filename_prefix,\n    profile_stage,\n    tp_rank,\n    profile_start_step=None,\n    profile_steps=None,\n):\n    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)\n    if batch_size > max_batch_size:\n        rank_print(\n            f\"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit\"\n        )\n        return\n\n    model_runner.req_to_token_pool.clear()\n    model_runner.token_to_kv_pool_allocator.clear()\n\n    measurement_results = {\n        \"run_name\": run_name,\n        \"batch_size\": batch_size,\n        \"input_len\": input_len,\n        \"output_len\": output_len,\n    }\n\n    tot_latency = 0\n\n    profiler = None\n    enable_profile_prefill = profile and profile_stage in [\"all\", \"prefill\"]\n    if enable_profile_prefill:\n        profiler = start_profile(\n            profile_activities,\n            profile_record_shapes=profile_record_shapes,\n            rank_print=rank_print,\n  
      )\n\n    synchronize(device)\n    tic = time.perf_counter()\n    next_token_ids, _, batch = extend(reqs, model_runner)\n    synchronize(device)\n    prefill_latency = time.perf_counter() - tic\n\n    if enable_profile_prefill:\n        trace_filename = _create_torch_profiler_filename(\n            profile_filename_prefix, batch_size, input_len, output_len, \"prefill\"\n        )\n        stop_profile(\n            profiler,\n            profile_activities,\n            rank_print=rank_print,\n            save_trace=True,\n            trace_filename=trace_filename,\n            stage=\"prefill\",\n        )\n\n    tot_latency += prefill_latency\n    throughput = input_len * batch_size / prefill_latency\n    rank_print(\n        f\"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s\"\n    )\n    measurement_results[\"prefill_latency\"] = prefill_latency\n    measurement_results[\"prefill_throughput\"] = throughput\n\n    decode_latencies = []\n    # Determine profiling start step and end step\n    profile_start = (\n        profile_start_step if profile_start_step is not None else (output_len // 2)\n    )\n    profile_end = profile_start + (profile_steps if profile_steps is not None else 1)\n    enable_profile_decode = profile and profile_stage in [\"all\", \"decode\"]\n    profiler = None\n    for i in range(output_len - 1):\n        synchronize(device)\n        # Start profiler at the specified step\n        if enable_profile_decode and i == profile_start:\n            profiler = start_profile(\n                profile_activities,\n                profile_record_shapes=profile_record_shapes,\n                rank_print=rank_print,\n            )\n\n        tic = time.perf_counter()\n        next_token_ids, _ = decode(next_token_ids, batch, model_runner)\n        synchronize(device)\n        latency = time.perf_counter() - tic\n\n        # Stop profiler after the specified number of steps\n        if enable_profile_decode and 
profiler is not None and i >= profile_end - 1:\n            trace_filename = _create_torch_profiler_filename(\n                profile_filename_prefix, batch_size, input_len, output_len, \"decode\"\n            )\n            stop_profile(\n                profiler,\n                profile_activities,\n                rank_print=rank_print,\n                save_trace=True,\n                trace_filename=trace_filename,\n                stage=\"decode\",\n            )\n            profiler = None\n\n        tot_latency += latency\n        throughput = batch_size / latency\n        decode_latencies.append(latency)\n        if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):\n            rank_print(\n                f\"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s\"\n            )\n\n    # Record decode timing from 2nd output\n    if output_len > 1:\n        med_decode_latency = np.median(decode_latencies)\n        med_decode_throughput = batch_size / med_decode_latency\n        rank_print(\n            f\"Decode.  median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s\"\n        )\n        measurement_results[\"median_decode_latency\"] = med_decode_latency\n        measurement_results[\"median_decode_throughput\"] = med_decode_throughput\n\n    throughput = (input_len + output_len) * batch_size / tot_latency\n    rank_print(\n        f\"Total. 
latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s\"\n    )\n    measurement_results[\"total_latency\"] = tot_latency\n    measurement_results[\"overall_throughput\"] = throughput\n    return measurement_results\n\n\ndef latency_test(\n    server_args,\n    port_args,\n    bench_args,\n    gpu_id,\n    tp_rank,\n):\n    initialize_moe_config(server_args)\n    initialize_fp8_gemm_config(server_args)\n    initialize_fp4_gemm_config(server_args)\n\n    # Set CPU affinity\n    if get_bool_env_var(\"SGLANG_SET_CPU_AFFINITY\"):\n        set_gpu_proc_affinity(\n            server_args.pp_size, server_args.tp_size, server_args.nnodes, tp_rank\n        )\n\n    # Configure the logger\n    configure_logger(server_args, prefix=f\" TP{tp_rank}\")\n    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None\n\n    # Load the model\n    model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)\n\n    # Prepare inputs for warm up\n    reqs = prepare_synthetic_inputs_for_latency_test(\n        bench_args.batch_size[0], bench_args.input_len[0]\n    )\n\n    # Warm up\n    rank_print(\"Warmup ...\")\n    latency_test_run_once(\n        bench_args.run_name,\n        model_runner,\n        rank_print,\n        reqs,\n        bench_args.batch_size[0],\n        bench_args.input_len[0],\n        min(32, bench_args.output_len[0]),  # shorter decoding to speed up the warmup\n        server_args.device,\n        log_decode_step=0,\n        profile=False,\n        profile_record_shapes=False,\n        profile_activities=(\"CPU\", \"GPU\"),\n        profile_filename_prefix=\"\",\n        profile_stage=\"all\",\n        tp_rank=tp_rank,\n        profile_start_step=None,\n        profile_steps=None,\n    )\n\n    rank_print(\"Benchmark ...\")\n\n    custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)\n    custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]\n    custom_input_len = 
len(custom_inputs)\n\n    # Run the sweep\n    result_list = []\n    for bs, il, ol in itertools.product(\n        bench_args.batch_size, bench_args.input_len, bench_args.output_len\n    ):\n        bs_aligned_inputs = []\n        if custom_inputs:\n            if custom_input_len == bs:\n                bs_aligned_inputs = custom_inputs\n            elif custom_input_len > bs:\n                rank_print(\n                    f\"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). \"\n                    f\"Using the first {bs} prompts.\"\n                )\n                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])\n            else:\n                rank_print(\n                    f\"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). \"\n                    f\"Pad to the desired batch_size with the last prompt.\"\n                )\n                bs_aligned_inputs = copy.deepcopy(custom_inputs)\n                bs_aligned_inputs.extend(\n                    [bs_aligned_inputs[-1]] * (bs - custom_input_len)\n                )\n\n        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)\n        ret = latency_test_run_once(\n            bench_args.run_name,\n            model_runner,\n            rank_print,\n            reqs,\n            bs,\n            il,\n            ol,\n            server_args.device,\n            bench_args.log_decode_step,\n            bench_args.profile if tp_rank == 0 else None,\n            bench_args.profile_record_shapes if tp_rank == 0 else None,\n            bench_args.profile_activities,\n            bench_args.profile_filename_prefix,\n            bench_args.profile_stage,\n            tp_rank,\n            bench_args.profile_start_step,\n            bench_args.profile_steps,\n        )\n        if ret is not None:\n            result_list.append(ret)\n\n    # Write results in jsonlines format on rank 0.\n    if tp_rank == 0 and 
bench_args.result_filename:\n        with open(bench_args.result_filename, \"a\") as fout:\n            for result in result_list:\n                fout.write(json.dumps(result) + \"\\n\")\n\n    if server_args.tp_size > 1:\n        destroy_distributed_environment()\n\n\ndef main(server_args, bench_args):\n    server_args.cuda_graph_max_bs = max(bench_args.batch_size)\n\n    _set_envs_and_config(server_args)\n\n    if server_args.model_path:\n        if bench_args.correctness_test:\n            work_func = correctness_test\n        else:\n            work_func = latency_test\n    else:\n        raise ValueError(\n            \"Provide --model-path for running the tests or \"\n            \"provide --result-filename for plotting the results\"\n        )\n\n    port_args = PortArgs.init_new(server_args)\n\n    if server_args.tp_size == 1:\n        work_func(server_args, port_args, bench_args, 0, 0)\n    else:\n        workers = []\n        for tp_rank in range(server_args.tp_size):\n            with maybe_reindex_device_id(tp_rank) as gpu_id:\n                proc = multiprocessing.Process(\n                    target=work_func,\n                    args=(\n                        server_args,\n                        port_args,\n                        bench_args,\n                        gpu_id,\n                        tp_rank,\n                    ),\n                )\n                proc.start()\n                workers.append(proc)\n\n        for proc in workers:\n            proc.join()\n\n        proc.terminate()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    ServerArgs.add_cli_args(parser)\n    BenchArgs.add_cli_args(parser)\n    args = parser.parse_args()\n    server_args = ServerArgs.from_cli_args(args)\n    bench_args = BenchArgs.from_cli_args(args)\n\n    logging.basicConfig(\n        level=getattr(logging, server_args.log_level.upper()),\n        format=\"%(message)s\",\n    )\n\n    try:\n        main(server_args, 
bench_args)\n    finally:\n        if server_args.tp_size != 1:\n            kill_process_tree(os.getpid(), include_parent=False)\n"
  },
  {
    "path": "python/sglang/bench_one_batch_server.py",
    "content": "\"\"\"\nBenchmark the latency of running a single batch with a server.\n\nThis script launches a server and uses the HTTP interface.\nIt accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).\n\nUsage:\npython3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8\n\npython3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8\npython3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage\npython3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile\n\"\"\"\n\nimport argparse\n\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.test.bench_one_batch_server_internal import (\n    BenchArgs,\n    run_benchmark_internal,\n)\nfrom sglang.test.nightly_bench_utils import save_results_as_pydantic_models\n\n\ndef run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):\n    results, server_info = run_benchmark_internal(server_args, bench_args)\n\n    # Save results as pydantic models in the JSON format\n    if bench_args.pydantic_result_filename:\n        save_results_as_pydantic_models(\n            results,\n            pydantic_result_filename=bench_args.pydantic_result_filename,\n            model_path=server_args.model_path,\n            server_args=bench_args.server_args_for_metrics,\n        )\n\n    return results, server_info\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    ServerArgs.add_cli_args(parser)\n    BenchArgs.add_cli_args(parser)\n    args = parser.parse_args()\n\n    server_args = ServerArgs.from_cli_args(args)\n    bench_args = BenchArgs.from_cli_args(args)\n\n    
run_benchmark(server_args, bench_args)\n"
  },
  {
    "path": "python/sglang/bench_serving.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py\n# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py\n\n\"\"\"\nBenchmark online serving with dynamic requests.\n\nUsage:\npython3 -m sglang.bench_serving --backend sglang --num-prompt 10\n\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5\n\"\"\"\n\nimport argparse\nimport asyncio\nimport copy\nimport importlib.util\nimport json\nimport os\nimport random\nimport shutil\nimport sys\nimport time\nimport traceback\nimport uuid\nimport warnings\nfrom argparse import ArgumentParser\nfrom copy import deepcopy\nfrom dataclasses import dataclass, field, replace\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple, Union\n\nimport aiohttp\nimport numpy as np\nimport requests\nfrom tqdm.asyncio import tqdm\nfrom transformers import AutoTokenizer, PreTrainedTokenizerBase\n\nfrom sglang.benchmark.datasets import DatasetRow, get_dataset\nfrom sglang.benchmark.datasets.mooncake import get_mooncake_request_over_time\nfrom sglang.benchmark.utils import (\n    get_tokenizer,\n    parse_custom_headers,\n    remove_prefix,\n    set_ulimit,\n)\n\n_ROUTING_KEY_HEADER = \"X-SMG-Routing-Key\"\n\n_EMBEDDING_UNSUPPORTED_DATASETS = {\"image\", \"mmmu\", \"mooncake\"}\n\nTERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec(\"termplotlib\") is not None) and (\n    shutil.which(\"gnuplot\") is not None\n)\n\nglobal args\n\n\n# don't want to import sglang package here\ndef _get_bool_env_var(name: str, default: str = \"false\") -> bool:\n    value = os.getenv(name, default)\n    return value.lower() in (\"true\", \"1\")\n\n\ndef _create_bench_client_session():\n    # When the pressure is 
big, the read buffer could be full before aio thread read\n    # the content. We increase the read_bufsize from 64K to 10M.\n    # Define constants for timeout and buffer size for clarity and maintainability\n    BENCH_AIOHTTP_TIMEOUT_SECONDS = 6 * 60 * 60  # 6 hours\n    BENCH_AIOHTTP_READ_BUFSIZE_BYTES = 10 * 1024**2  # 10 MB\n\n    aiohttp_timeout = aiohttp.ClientTimeout(total=BENCH_AIOHTTP_TIMEOUT_SECONDS)\n    return aiohttp.ClientSession(\n        timeout=aiohttp_timeout, read_bufsize=BENCH_AIOHTTP_READ_BUFSIZE_BYTES\n    )\n\n\n@dataclass\nclass RequestFuncInput:\n    prompt: Union[str, List[str], List[Dict[str, str]]]\n    api_url: str\n    prompt_len: int\n    output_len: int\n    model: str\n    lora_name: str\n    image_data: Optional[List[str]]\n    extra_request_body: Dict[str, Any]\n    timestamp: Optional[float] = None\n    routing_key: Optional[str] = None\n\n\n@dataclass\nclass RequestFuncOutput:\n    generated_text: str = \"\"\n    success: bool = False\n    latency: float = 0.0\n    ttft: float = 0.0  # Time to first token\n    itl: List[float] = field(default_factory=list)  # List of inter-token latencies\n    text_chunks: List[str] = field(default_factory=list)\n    prompt_len: int = 0\n    error: str = \"\"\n    output_len: int = 0\n    start_time: float = 0.0\n\n    @staticmethod\n    def init_new(request_func_input: RequestFuncInput):\n        output = RequestFuncOutput()\n        output.prompt_len = request_func_input.prompt_len\n        return output\n\n\ndef get_auth_headers() -> Dict[str, str]:\n    openai_api_key = os.environ.get(\"OPENAI_API_KEY\")\n    if openai_api_key:\n        return {\"Authorization\": f\"Bearer {openai_api_key}\"}\n    else:\n        api_key = os.environ.get(\"API_KEY\")\n        if api_key:\n            return {\"Authorization\": f\"{api_key}\"}\n        return {}\n\n\ndef get_request_headers() -> Dict[str, str]:\n    headers = get_auth_headers()\n    if h := getattr(args, \"header\", None):\n        
headers.update(parse_custom_headers(h))\n    return headers\n\n\ndef wait_for_endpoint(url: str, timeout_sec: int = 60) -> bool:\n    \"\"\"Wait for the server to become ready by polling the given URL.\"\"\"\n    print(f\"Waiting up to {timeout_sec}s for {url} to become ready...\")\n    start_time = time.perf_counter()\n    headers = get_auth_headers()\n    while True:\n        try:\n            response = requests.get(url, headers=headers, timeout=5)\n            if response.status_code == 200:\n                elapsed = time.perf_counter() - start_time\n                print(f\"Server ready in {elapsed:.1f}s.\")\n                return True\n        except requests.exceptions.RequestException:\n            pass\n        elapsed = time.perf_counter() - start_time\n        if elapsed >= timeout_sec:\n            print(f\"Server did not become ready within {timeout_sec}s timeout.\")\n            return False\n        time.sleep(1)\n\n\n# trt llm does not support ignore_eos\n# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505\nasync def async_request_trt_llm(\n    request_func_input: RequestFuncInput,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    api_url = request_func_input.api_url\n    assert api_url.endswith(\"generate_stream\")\n\n    async with _create_bench_client_session() as session:\n        payload = {\n            \"accumulate_tokens\": True,\n            \"text_input\": request_func_input.prompt,\n            \"temperature\": 0.000001,\n            \"top_p\": 1.0,\n            \"max_tokens\": request_func_input.output_len,\n            \"stream\": True,\n            \"min_length\": request_func_input.output_len,\n            \"end_id\": 1048576,\n            **request_func_input.extra_request_body,\n        }\n        if args.disable_ignore_eos:\n            del payload[\"min_length\"]\n            del payload[\"end_id\"]\n        output = RequestFuncOutput.init_new(request_func_input)\n\n        ttft = 0.0\n      
  st = time.perf_counter()\n        most_recent_timestamp = st\n        try:\n            async with session.post(url=api_url, json=payload) as response:\n                if response.status == 200:\n                    async for chunk_bytes in response.content:\n                        chunk_bytes = chunk_bytes.strip()\n                        if not chunk_bytes:\n                            continue\n\n                        chunk = remove_prefix(chunk_bytes.decode(\"utf-8\"), \"data:\")\n\n                        data = json.loads(chunk)\n                        output.generated_text += data[\"text_output\"]\n                        timestamp = time.perf_counter()\n                        # First token\n                        if ttft == 0.0:\n                            ttft = timestamp - st\n                            output.ttft = ttft\n\n                        # Decoding phase\n                        else:\n                            output.itl.append(timestamp - most_recent_timestamp)\n\n                        most_recent_timestamp = timestamp\n\n                    output.latency = most_recent_timestamp - st\n                    output.success = True\n                    output.output_len = request_func_input.output_len\n\n                else:\n                    output.error = (\n                        (response.reason or \"\") + \": \" + (await response.text())\n                    )\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            output.error = \"\".join(traceback.format_exception(*exc_info))\n\n        if pbar:\n            pbar.update(1)\n        return output\n\n\n# set ignore_eos True by default\nasync def async_request_openai_completions(\n    request_func_input: RequestFuncInput,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    api_url = request_func_input.api_url\n    assert api_url.endswith(\n        
\"completions\"\n    ), \"OpenAI Completions API URL must end with 'completions'.\"\n\n    prompt = request_func_input.prompt\n\n    async with _create_bench_client_session() as session:\n        # Build payload with defaults that can be overridden by extra_request_body\n        payload = {\n            \"model\": request_func_input.model,\n            \"prompt\": prompt,\n            \"best_of\": 1,\n            \"max_tokens\": request_func_input.output_len,\n            \"stream\": not args.disable_stream,\n        }\n\n        # Add temperature default only if not specified in extra_request_body\n        if \"temperature\" not in request_func_input.extra_request_body:\n            payload[\"temperature\"] = 0.0\n\n        # Add ignore_eos default only if not specified in extra_request_body\n        if \"ignore_eos\" not in request_func_input.extra_request_body:\n            payload[\"ignore_eos\"] = not args.disable_ignore_eos\n\n        if args.return_logprob and args.top_logprobs_num > 0:\n            payload[\"logprobs\"] = args.top_logprobs_num\n\n        # Merge in extra parameters - these will override defaults if present\n        payload.update(request_func_input.extra_request_body)\n\n        # hack to accommodate different LoRA conventions between SGLang and vLLM.\n        if request_func_input.lora_name:\n            payload[\"model\"] = request_func_input.lora_name\n            payload[\"lora_path\"] = request_func_input.lora_name\n\n        if request_func_input.image_data:\n            payload.update({\"image_data\": request_func_input.image_data})\n\n        headers = get_request_headers()\n        if request_func_input.routing_key:\n            headers[_ROUTING_KEY_HEADER] = request_func_input.routing_key\n\n        output = RequestFuncOutput.init_new(request_func_input)\n\n        generated_text = \"\"\n        output_len = request_func_input.output_len\n        ttft = 0.0\n        st = time.perf_counter()\n        output.start_time = st\n        
most_recent_timestamp = st\n        try:\n            async with session.post(\n                url=api_url, json=payload, headers=headers\n            ) as response:\n                if response.status == 200:\n                    async for chunk_bytes in response.content:\n                        chunk_bytes = chunk_bytes.strip()\n                        if not chunk_bytes:\n                            continue\n\n                        chunk = remove_prefix(chunk_bytes.decode(\"utf-8\"), \"data: \")\n                        latency = time.perf_counter() - st\n                        if chunk == \"[DONE]\":\n                            pass\n                        else:\n                            data = json.loads(chunk)\n\n                            # NOTE: Some completion API might have a last\n                            # usage summary response without a token so we\n                            # want to check a token was generated\n                            if data[\"choices\"][0][\"text\"]:\n                                timestamp = time.perf_counter()\n                                # First token\n                                if ttft == 0.0:\n                                    ttft = time.perf_counter() - st\n                                    output.ttft = ttft\n\n                                # Decoding phase\n                                else:\n                                    output.text_chunks.append(\n                                        data[\"choices\"][0][\"text\"]\n                                    )\n                                    output.itl.append(timestamp - most_recent_timestamp)\n\n                                most_recent_timestamp = timestamp\n                                generated_text += data[\"choices\"][0][\"text\"]\n                                output_len = (data.get(\"usage\") or {}).get(\n                                    \"completion_tokens\", output_len\n                                
)\n\n                    output.generated_text = generated_text\n                    output.success = True\n                    output.latency = latency\n                    output.output_len = output_len\n                else:\n                    output.error = (\n                        (response.reason or \"\") + \": \" + (await response.text())\n                    )\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            output.error = \"\".join(traceback.format_exception(*exc_info))\n\n    if pbar:\n        pbar.update(1)\n    return output\n\n\nasync def async_request_openai_chat_completions(\n    request_func_input: RequestFuncInput,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    \"\"\"Makes a request to the OpenAI Chat Completions API.\n\n    Handles both streaming and non-streaming responses, including support\n    for image data in messages. Calculates and returns various performance\n    metrics.\n\n    Args:\n        request_func_input: Input parameters for the request.\n        pbar: Optional tqdm progress bar to update.\n\n    Returns:\n        RequestFuncOutput: Output of the request, including generated text,\n                           latency, TTFT, ITL, and success status.\n    \"\"\"\n    api_url = request_func_input.api_url\n    assert api_url.endswith(\n        \"chat/completions\"\n    ), \"OpenAI Chat Completions API URL must end with 'chat/completions'.\"\n\n    # TODO put it to other functions when `pbar` logic is refactored\n    if getattr(args, \"print_requests\", False):\n        rid = str(uuid.uuid4())\n        input_partial = deepcopy(request_func_input)\n        input_partial.prompt = \"...\"\n        request_start_time = time.time()\n        print(\n            f'rid={rid} time={request_start_time} message=\"request start\" request_func_input=\"{str(input_partial)}\"'\n        )\n\n    if 
isinstance(request_func_input.prompt, list):\n        messages = request_func_input.prompt\n    elif request_func_input.image_data:\n        # Build multi-image content: a list of image_url entries followed by the text\n        content_items = [\n            {\n                \"type\": \"image_url\",\n                \"image_url\": {\"url\": img_url},\n            }\n            for img_url in request_func_input.image_data\n        ]\n        content_items.append({\"type\": \"text\", \"text\": request_func_input.prompt})\n        messages = [\n            {\n                \"role\": \"user\",\n                \"content\": content_items,\n            },\n        ]\n    else:\n        messages = [{\"role\": \"user\", \"content\": request_func_input.prompt}]\n\n    async with _create_bench_client_session() as session:\n        # Build payload with defaults that can be overridden by extra_request_body\n        payload = {\n            \"model\": request_func_input.model,\n            \"messages\": messages,\n            \"max_completion_tokens\": request_func_input.output_len,\n            \"stream\": not args.disable_stream,\n        }\n\n        # Add temperature default only if not specified in extra_request_body\n        if \"temperature\" not in request_func_input.extra_request_body:\n            payload[\"temperature\"] = 0.0\n\n        # Add ignore_eos default only if not specified in extra_request_body\n        # Defaults to True (EOS tokens are ignored) unless args.disable_ignore_eos is set\n        if \"ignore_eos\" not in request_func_input.extra_request_body:\n            payload[\"ignore_eos\"] = not args.disable_ignore_eos\n\n        # Merge in extra parameters (tools, temperature, top_p, etc.)\n        # These will override defaults if present\n        payload.update(request_func_input.extra_request_body)\n\n        # hack to accommodate different LoRA conventions between SGLang and vLLM.\n        if request_func_input.lora_name:\n            payload[\"model\"] = 
request_func_input.lora_name\n            payload[\"lora_path\"] = request_func_input.lora_name\n\n        headers = get_request_headers()\n        if request_func_input.routing_key:\n            headers[_ROUTING_KEY_HEADER] = request_func_input.routing_key\n\n        output = RequestFuncOutput.init_new(request_func_input)\n\n        generated_text = \"\"\n        output_len = request_func_input.output_len\n        ttft = 0.0\n        st = time.perf_counter()\n        output.start_time = st\n        most_recent_timestamp = st\n        try:\n            async with session.post(\n                url=api_url, json=payload, headers=headers\n            ) as response:\n                if response.status == 200:\n                    if args.disable_stream:\n                        # Non-streaming response\n                        response_json = await response.json()\n                        output.generated_text = response_json[\"choices\"][0][\"message\"][\n                            \"content\"\n                        ]\n                        output.success = True\n                        output.latency = time.perf_counter() - st\n                        output.ttft = (\n                            output.latency\n                        )  # For non-streaming, TTFT = total latency\n                        output.output_len = response_json.get(\"usage\", {}).get(\n                            \"completion_tokens\", output_len\n                        )\n                    else:\n                        # Streaming response\n                        async for chunk_bytes in response.content:\n                            chunk_bytes = chunk_bytes.strip()\n                            if not chunk_bytes:\n                                continue\n\n                            chunk = remove_prefix(chunk_bytes.decode(\"utf-8\"), \"data: \")\n                            latency = time.perf_counter() - st\n                            if chunk == \"[DONE]\":\n              
                  pass\n                            else:\n                                data = json.loads(chunk)\n\n                                # Check if this chunk contains content\n                                delta = data.get(\"choices\", [{}])[0].get(\"delta\", {})\n                                content = delta.get(\"content\", \"\")\n\n                                if content:\n                                    timestamp = time.perf_counter()\n                                    # First token\n                                    if ttft == 0.0:\n                                        ttft = timestamp - st\n                                        output.ttft = ttft\n\n                                    # Decoding phase\n                                    else:\n                                        output.text_chunks.append(content)\n                                        output.itl.append(\n                                            timestamp - most_recent_timestamp\n                                        )\n\n                                    most_recent_timestamp = timestamp\n                                    generated_text += content\n\n                                # Check for usage info in final chunk\n                                output_len = (data.get(\"usage\") or {}).get(\n                                    \"completion_tokens\", output_len\n                                )\n\n                        output.generated_text = generated_text\n                        output.success = True\n                        output.latency = latency\n                        output.output_len = output_len\n                else:\n                    output.error = (\n                        (response.reason or \"\") + \": \" + (await response.text())\n                    )\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            
output.error = \"\".join(traceback.format_exception(*exc_info))\n\n    # TODO put it to other functions when `pbar` logic is refactored\n    if getattr(args, \"print_requests\", False):\n        curr_t = time.time()\n        output_partial = deepcopy(output)\n        output_partial.generated_text = \"...\"\n        print(\n            f'rid={rid} time={curr_t} time_delta={curr_t - request_start_time} message=\"request end\" output=\"{str(output_partial)}\"'\n        )\n\n    if pbar:\n        pbar.update(1)\n    return output\n\n\nasync def async_request_truss(\n    request_func_input: RequestFuncInput,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    api_url = request_func_input.api_url\n\n    prompt = request_func_input.prompt\n\n    async with _create_bench_client_session() as session:\n        payload = {\n            \"model\": request_func_input.model,\n            \"prompt\": prompt,\n            \"temperature\": 0.0,\n            \"best_of\": 1,\n            \"max_tokens\": request_func_input.output_len,\n            \"stream\": not args.disable_stream,\n            \"ignore_eos\": not args.disable_ignore_eos,\n            **request_func_input.extra_request_body,\n        }\n        headers = get_request_headers()\n\n        output = RequestFuncOutput.init_new(request_func_input)\n\n        generated_text = \"\"\n        ttft = 0.0\n        st = time.perf_counter()\n        most_recent_timestamp = st\n        try:\n            async with session.post(\n                url=api_url, json=payload, headers=headers\n            ) as response:\n                if response.status == 200:\n                    async for chunk_bytes in response.content:\n                        chunk_bytes = chunk_bytes.strip()\n                        if not chunk_bytes:\n                            continue\n\n                        chunk = remove_prefix(chunk_bytes.decode(\"utf-8\"), \"data: \")\n                        latency = time.perf_counter() - st\n          
              if chunk == \"[DONE]\":\n                            pass\n                        else:\n                            data = json.loads(chunk)\n\n                            # NOTE: Some completion API might have a last\n                            # usage summary response without a token so we\n                            # want to check a token was generated\n                            if data[\"choices\"][0][\"text\"]:\n                                timestamp = time.perf_counter()\n                                # First token\n                                if ttft == 0.0:\n                                    ttft = time.perf_counter() - st\n                                    output.ttft = ttft\n\n                                # Decoding phase\n                                else:\n                                    output.itl.append(timestamp - most_recent_timestamp)\n\n                                most_recent_timestamp = timestamp\n                                generated_text += data[\"choices\"][0][\"text\"]\n\n                    output.generated_text = generated_text\n                    output.success = True\n                    output.latency = latency\n                    output.output_len = request_func_input.output_len\n                else:\n                    output.error = (\n                        (response.reason or \"\") + \": \" + (await response.text())\n                    )\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            output.error = \"\".join(traceback.format_exception(*exc_info))\n\n    if pbar:\n        pbar.update(1)\n    return output\n\n\nasync def async_request_sglang_generate(\n    request_func_input: RequestFuncInput,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    api_url = request_func_input.api_url\n    prompt = request_func_input.prompt\n\n    async with 
_create_bench_client_session() as session:\n        payload = {\n            (\"text\" if isinstance(prompt, str) else \"input_ids\"): prompt,\n            \"sampling_params\": {\n                \"temperature\": 0.0,\n                \"max_new_tokens\": request_func_input.output_len,\n                \"ignore_eos\": not args.disable_ignore_eos,\n            },\n            \"stream\": not args.disable_stream,\n            \"lora_path\": request_func_input.lora_name,\n            \"return_logprob\": args.return_logprob,\n            \"return_routed_experts\": args.return_routed_experts,\n            \"logprob_start_len\": args.logprob_start_len,\n            **request_func_input.extra_request_body,\n        }\n        if args.top_logprobs_num > 0:\n            payload[\"top_logprobs_num\"] = args.top_logprobs_num\n        if args.token_ids_logprob is not None:\n            payload[\"token_ids_logprob\"] = args.token_ids_logprob\n\n        # Add image data if available (list of image urls/base64)\n        if request_func_input.image_data:\n            payload[\"image_data\"] = request_func_input.image_data\n\n        headers = get_request_headers()\n        if request_func_input.routing_key:\n            headers[_ROUTING_KEY_HEADER] = request_func_input.routing_key\n\n        output = RequestFuncOutput.init_new(request_func_input)\n\n        generated_text = \"\"\n        output_len = request_func_input.output_len\n        ttft = 0.0\n        st = time.perf_counter()\n        output.start_time = st\n        most_recent_timestamp = st\n        last_output_len = 0\n        try:\n            async with session.post(\n                url=api_url, json=payload, headers=headers\n            ) as response:\n                if response.status == 200:\n                    async for chunk_bytes in response.content:\n                        chunk_bytes = chunk_bytes.strip()\n                        if not chunk_bytes:\n                            continue\n\n                   
     chunk = remove_prefix(chunk_bytes.decode(\"utf-8\"), \"data: \")\n                        latency = time.perf_counter() - st\n                        if chunk == \"[DONE]\":\n                            pass\n                        else:\n                            data = json.loads(chunk)\n\n                            # NOTE: Some completion API might have a last\n                            # usage summary response without a token so we\n                            # want to check a token was generated\n                            if \"text\" in data and data[\"text\"]:\n                                timestamp = time.perf_counter()\n                                generated_text = data[\"text\"]\n                                output_len = data[\"meta_info\"][\"completion_tokens\"]\n\n                                # First token\n                                if ttft == 0.0:\n                                    ttft = time.perf_counter() - st\n                                    output.ttft = ttft\n\n                                # Decoding phase\n                                else:\n                                    num_new_tokens = output_len - last_output_len\n                                    if num_new_tokens == 0:\n                                        continue\n                                    chunk_gap = timestamp - most_recent_timestamp\n                                    adjust_itl = chunk_gap / num_new_tokens\n                                    output.itl.extend([adjust_itl] * num_new_tokens)\n\n                                most_recent_timestamp = timestamp\n                                last_output_len = output_len\n\n                    output.generated_text = generated_text\n                    output.success = True\n                    output.latency = latency\n                    output.output_len = output_len\n                else:\n                    output.error = (\n                        (response.reason or 
\"\") + \": \" + (await response.text())\n                    )\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            output.error = \"\".join(traceback.format_exception(*exc_info))\n            print(f\"{output.error=}\")\n\n    if pbar:\n        pbar.update(1)\n    return output\n\n\nasync def async_request_openai_embeddings(\n    request_func_input: RequestFuncInput,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    api_url = request_func_input.api_url\n\n    async with _create_bench_client_session() as session:\n        payload = {\n            \"input\": request_func_input.prompt,\n            \"model\": request_func_input.model,\n        }\n\n        if request_func_input.lora_name:\n            payload[\"model\"] = request_func_input.lora_name\n            payload[\"lora_path\"] = request_func_input.lora_name\n\n        payload.update(request_func_input.extra_request_body)\n\n        headers = get_request_headers()\n        if request_func_input.routing_key:\n            headers[_ROUTING_KEY_HEADER] = request_func_input.routing_key\n\n        output = RequestFuncOutput.init_new(request_func_input)\n\n        st = time.perf_counter()\n        output.start_time = st\n        try:\n            async with session.post(\n                url=api_url, json=payload, headers=headers\n            ) as response:\n                if response.status == 200:\n                    await response.json()\n                    output.latency = time.perf_counter() - st\n                    output.success = True\n                    output.output_len = 0\n                else:\n                    output.error = (\n                        (response.reason or \"\") + \": \" + (await response.text())\n                    )\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = 
sys.exc_info()\n            output.error = \"\".join(traceback.format_exception(*exc_info))\n\n    if pbar:\n        pbar.update(1)\n    return output\n\n\nasync def async_request_gserver(\n    request_func_input: RequestFuncInput,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    raise NotImplementedError()\n\n\nasync def async_request_profile(api_url: str) -> RequestFuncOutput:\n    async with _create_bench_client_session() as session:\n        output = RequestFuncOutput()\n        try:\n            if api_url.endswith(\"/start_profile\"):\n                num_steps = getattr(args, \"profile_num_steps\", None)\n                profile_by_stage = getattr(args, \"profile_by_stage\", None)\n                if profile_by_stage and num_steps is None:\n                    num_steps = 5\n\n                output_dir = getattr(args, \"profile_output_dir\", None)\n                if output_dir is None:\n                    output_dir = os.getenv(\"SGLANG_TORCH_PROFILER_DIR\", \"/tmp\")\n                output_dir = Path(os.path.abspath(os.path.normpath(output_dir))) / str(\n                    time.time()\n                )\n                output_dir.mkdir(exist_ok=True, parents=True)\n                output_dir = str(output_dir)\n\n                body = {\n                    \"activities\": getattr(args, \"profile_activities\", []),\n                    \"num_steps\": num_steps,\n                    \"profile_by_stage\": profile_by_stage,\n                    \"profile_stages\": getattr(args, \"profile_stages\", None),\n                    \"output_dir\": output_dir,\n                    \"profile_prefix\": getattr(args, \"profile_prefix\", None),\n                }\n            else:\n                # stop_profile doesn't need any parameters\n                body = {}\n            print(f\"async_request_profile {api_url=} {body=}\")\n            # Add optional profiling parameters if provided\n            if (\n                hasattr(args, 
\"profile_start_step\")\n                and args.profile_start_step is not None\n            ):\n                body[\"start_step\"] = str(args.profile_start_step)\n            if hasattr(args, \"profile_steps\") and args.profile_steps is not None:\n                body[\"num_steps\"] = str(args.profile_steps)\n            async with session.post(url=api_url, json=body) as response:\n                if response.status == 200:\n                    output.success = True\n                else:\n                    output.error = (\n                        (response.reason or \"\") + \": \" + (await response.text())\n                    )\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            output.error = \"\".join(traceback.format_exception(*exc_info))\n\n    return output\n\n\ndef _build_profile_urls(\n    profile_prefill_url: Optional[List[str]],\n    profile_decode_url: Optional[List[str]],\n) -> List[Tuple[str, str]]:\n    \"\"\"Build profile URLs list from prefill/decode URL arguments.\n\n    Returns:\n        List of (worker_type, url) tuples. 
e.g., [(\"Prefill-0\", \"http://...\"), (\"Decode-0\", \"http://...\")]\n    \"\"\"\n    profile_urls = []\n    if profile_prefill_url:\n        for idx, url in enumerate(profile_prefill_url):\n            profile_urls.append((f\"Prefill-{idx}\", url))\n    if profile_decode_url:\n        for idx, url in enumerate(profile_decode_url):\n            profile_urls.append((f\"Decode-{idx}\", url))\n    return profile_urls\n\n\nasync def _call_profile_pd(profile_urls: List[Tuple[str, str]], mode: str) -> None:\n    \"\"\"Call profile endpoint (start/stop) on PD separated workers.\n\n    Args:\n        profile_urls: List of (worker_type, url) tuples\n        mode: \"start\" or \"stop\"\n    \"\"\"\n    endpoint = \"/start_profile\" if mode == \"start\" else \"/stop_profile\"\n    action = \"Starting\" if mode == \"start\" else \"Stopping\"\n    action_past = \"started\" if mode == \"start\" else \"stopped\"\n\n    print(f\"{action} profiler...\")\n\n    for worker_type, url in profile_urls:\n        profile_output = await async_request_profile(api_url=url + endpoint)\n        if profile_output.success:\n            print(f\"Profiler {action_past} for {worker_type} worker at {url}\")\n        else:\n            print(\n                f\"Failed to {mode} profiler for {worker_type} worker at {url}: {profile_output.error}\"\n            )\n\n\nASYNC_REQUEST_FUNCS = {\n    \"sglang\": async_request_sglang_generate,\n    \"sglang-native\": async_request_sglang_generate,\n    \"sglang-oai\": async_request_openai_completions,\n    \"sglang-oai-chat\": async_request_openai_chat_completions,\n    \"sglang-embedding\": async_request_openai_embeddings,\n    \"vllm\": async_request_openai_completions,\n    \"vllm-chat\": async_request_openai_chat_completions,\n    \"lmdeploy\": async_request_openai_completions,\n    \"lmdeploy-chat\": async_request_openai_chat_completions,\n    \"trt\": async_request_trt_llm,\n    \"gserver\": async_request_gserver,\n    \"truss\": 
async_request_truss,\n}\n\n\n@dataclass\nclass BenchmarkMetrics:\n    completed: int\n    total_input: int\n    total_input_text: int\n    total_input_vision: int\n    total_output: int\n    total_output_retokenized: int\n    request_throughput: float\n    input_throughput: float\n    output_throughput: float\n    output_throughput_retokenized: float\n    total_throughput: float\n    total_throughput_retokenized: float\n    mean_ttft_ms: float\n    median_ttft_ms: float\n    std_ttft_ms: float\n    p99_ttft_ms: float\n    mean_tpot_ms: float\n    median_tpot_ms: float\n    std_tpot_ms: float\n    p99_tpot_ms: float\n    mean_itl_ms: float\n    median_itl_ms: float\n    std_itl_ms: float\n    p95_itl_ms: float\n    p99_itl_ms: float\n    max_itl_ms: float\n    mean_e2e_latency_ms: float\n    median_e2e_latency_ms: float\n    std_e2e_latency_ms: float\n    p90_e2e_latency_ms: float\n    p99_e2e_latency_ms: float\n    concurrency: float\n    max_output_tokens_per_s: float = 0.0\n    max_concurrent_requests: int = 0\n\n\nasync def get_request(\n    input_requests: List[DatasetRow],\n    request_rate: float,\n    use_trace_timestamps: bool = False,\n    slowdown_factor: float = 1.0,\n) -> AsyncGenerator[DatasetRow, None]:\n    if use_trace_timestamps:\n        print(\n            f\"Using trace timestamps for request generation with slowdown factor {slowdown_factor}.\"\n        )\n        # Sort requests by timestamp for correct replay\n        input_requests.sort(key=lambda r: r.timestamp)\n\n        start_time = time.perf_counter()\n        trace_start_time_ms = input_requests[0].timestamp if input_requests else 0\n\n        for request in input_requests:\n            trace_time_s = (request.timestamp - trace_start_time_ms) / 1000.0\n            target_arrival_time = start_time + (trace_time_s * slowdown_factor)\n\n            sleep_duration = target_arrival_time - time.perf_counter()\n            if sleep_duration > 0:\n                await 
asyncio.sleep(sleep_duration)\n\n            yield request\n    else:\n        input_requests_iter = iter(input_requests)\n        for request in input_requests_iter:\n            yield request\n\n            if request_rate == float(\"inf\"):\n                # If the request rate is infinity, then we don't need to wait.\n                continue\n\n            # Sample the request interval from the exponential distribution.\n            interval = np.random.exponential(1.0 / request_rate)\n            # The next request will be sent after the interval.\n            await asyncio.sleep(interval)\n\n\ndef calculate_metrics(\n    input_requests: Optional[List[DatasetRow]],\n    outputs: List[RequestFuncOutput],\n    dur_s: float,\n    tokenizer: PreTrainedTokenizerBase,\n    backend: str,\n    accept_length: Optional[float] = None,\n    plot_throughput: bool = False,\n) -> Tuple[BenchmarkMetrics, List[int]]:\n    output_lens: List[int] = []\n    retokenized_output_lens: List[int] = []\n    total_input = 0\n    total_input_text = 0\n    total_input_vision = 0\n    completed = 0\n    itls: List[float] = []\n    tpots: List[float] = []\n    ttfts: List[float] = []\n    e2e_latencies: List[float] = []\n    retokenized_itls: List[float] = []\n\n    use_retokenized_itl = (\n        accept_length is not None\n        and accept_length > 0\n        and backend in (\"sglang-oai\", \"sglang-oai-chat\")\n    )\n\n    for i in range(len(outputs)):\n        if outputs[i].success:\n            output_len = outputs[i].output_len\n            output_lens.append(output_len)\n            retokenized_output_len = len(\n                tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)\n            )\n            retokenized_output_lens.append(retokenized_output_len)\n            if input_requests is not None:\n                total_input += input_requests[i].prompt_len\n                total_input_text += input_requests[i].text_prompt_len\n                
total_input_vision += input_requests[i].vision_prompt_len\n            if output_len > 1:\n                tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))\n            if use_retokenized_itl:\n                for k, itl in enumerate(outputs[i].itl):\n                    num_tokens = len(\n                        tokenizer.encode(\n                            outputs[i].text_chunks[k], add_special_tokens=False\n                        )\n                    )\n                    adjusted_itl = itl / num_tokens\n                    retokenized_itls.extend([adjusted_itl] * num_tokens)\n            else:\n                itls += outputs[i].itl\n            ttfts.append(outputs[i].ttft)\n\n            e2e_latencies.append(outputs[i].latency)\n\n            completed += 1\n        else:\n            output_lens.append(0)\n            retokenized_output_lens.append(0)\n\n    if completed == 0:\n        warnings.warn(\n            \"All requests failed. This is likely due to a misconfiguration \"\n            \"on the benchmark arguments.\",\n            stacklevel=2,\n        )\n\n    max_output_tokens_per_s = 0.0\n    max_concurrent_requests = 0\n\n    successful_outputs = [output for output in outputs if output.success]\n    if successful_outputs:\n        min_start_time = min(output.start_time for output in successful_outputs)\n        max_end_time = max(\n            output.start_time + output.latency for output in successful_outputs\n        )\n\n        duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1\n        tokens_per_second = np.zeros(duration_seconds)\n        concurrent_requests_per_second = np.zeros(duration_seconds)\n\n        for output in outputs:\n            if not output.success:\n                continue\n\n            token_times = [output.start_time + output.ttft]\n            current_time = token_times[0]\n            for itl_value in output.itl:\n                current_time += itl_value\n                
token_times.append(current_time)\n\n            for token_time in token_times:\n                second_bucket = int(token_time - min_start_time)\n                if 0 <= second_bucket < duration_seconds:\n                    tokens_per_second[second_bucket] += 1\n\n            request_start_second = int(output.start_time - min_start_time)\n            request_end_second = int(\n                (output.start_time + output.latency) - min_start_time\n            )\n\n            for second in range(\n                request_start_second, min(request_end_second + 1, duration_seconds)\n            ):\n                concurrent_requests_per_second[second] += 1\n\n        if len(tokens_per_second) > 0:\n            max_output_tokens_per_s = float(np.max(tokens_per_second))\n            max_concurrent_requests = int(np.max(concurrent_requests_per_second))\n\n        if plot_throughput:\n            if TERM_PLOTLIB_AVAILABLE:\n                import termplotlib as tpl\n\n                fig = tpl.figure()\n                fig.plot(\n                    np.arange(len(tokens_per_second)),\n                    tokens_per_second,\n                    title=\"Output tokens per second\",\n                    xlabel=\"Time (s)\",\n                )\n                fig.plot(\n                    np.arange(len(concurrent_requests_per_second)),\n                    concurrent_requests_per_second,\n                    title=\"Concurrent requests per second\",\n                    xlabel=\"Time (s)\",\n                )\n                fig.show()\n            else:\n                print(\"tip: install termplotlib and gnuplot to plot the metrics\")\n\n    itls = retokenized_itls if use_retokenized_itl else itls\n    metrics = BenchmarkMetrics(\n        completed=completed,\n        total_input=total_input,\n        total_input_text=total_input_text,\n        total_input_vision=total_input_vision,\n        total_output=sum(output_lens),\n        
total_output_retokenized=sum(retokenized_output_lens),\n        request_throughput=completed / dur_s,\n        input_throughput=total_input / dur_s,\n        output_throughput=sum(output_lens) / dur_s,\n        output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,\n        total_throughput=(total_input + sum(output_lens)) / dur_s,\n        total_throughput_retokenized=(total_input + sum(retokenized_output_lens))\n        / dur_s,\n        mean_ttft_ms=np.mean(ttfts or 0)\n        * 1000,  # ttfts is empty if streaming is not supported by backend\n        median_ttft_ms=np.median(ttfts or 0) * 1000,\n        std_ttft_ms=np.std(ttfts or 0) * 1000,\n        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,\n        mean_tpot_ms=np.mean(tpots or 0) * 1000,\n        median_tpot_ms=np.median(tpots or 0) * 1000,\n        std_tpot_ms=np.std(tpots or 0) * 1000,\n        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,\n        mean_itl_ms=np.mean(itls or 0) * 1000,\n        median_itl_ms=np.median(itls or 0) * 1000,\n        std_itl_ms=np.std(itls or 0) * 1000,\n        p95_itl_ms=np.percentile(itls or 0, 95) * 1000,\n        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,\n        max_itl_ms=np.max(itls or 0) * 1000,\n        mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,\n        median_e2e_latency_ms=np.median(e2e_latencies) * 1000,\n        std_e2e_latency_ms=np.std(e2e_latencies) * 1000,\n        p90_e2e_latency_ms=np.percentile(e2e_latencies, 90) * 1000,\n        p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,\n        concurrency=np.sum(e2e_latencies) / dur_s,\n        max_output_tokens_per_s=max_output_tokens_per_s,\n        max_concurrent_requests=max_concurrent_requests,\n    )\n\n    return metrics, output_lens\n\n\nMULTI_TURN_BACKENDS = {\"sglang-oai-chat\", \"vllm-chat\", \"lmdeploy-chat\"}\n\n\ndef wrap_multi_turn_request_func(request_func: Callable, backend: str) -> Callable:\n    assert (\n        backend in 
MULTI_TURN_BACKENDS\n    ), f\"Multi-turn only supports chat backends: {MULTI_TURN_BACKENDS}, got {backend}\"\n\n    async def f(\n        request_func_input: RequestFuncInput,\n        pbar: Optional[tqdm] = None,\n    ) -> List[RequestFuncOutput]:\n        prompts: List[str] = request_func_input.prompt\n        prev_messages: List[Dict[str, str]] = []\n        outputs = []\n\n        for round_index in range(len(prompts)):\n            prev_messages.append({\"role\": \"user\", \"content\": prompts[round_index]})\n\n            inner_input = replace(\n                copy.deepcopy(request_func_input), prompt=copy.deepcopy(prev_messages)\n            )\n            output = await request_func(\n                inner_input, pbar=pbar if round_index == len(prompts) - 1 else None\n            )\n            outputs.append(output)\n\n            prev_messages.append(\n                {\"role\": \"assistant\", \"content\": output.generated_text}\n            )\n\n        return outputs\n\n    return f\n\n\nasync def benchmark(\n    backend: str,\n    api_url: str,\n    base_url: str,\n    model_id: str,\n    tokenizer: PreTrainedTokenizerBase,\n    input_requests: List[DatasetRow],\n    request_rate: float,\n    max_concurrency: Optional[int],\n    disable_tqdm: bool,\n    lora_names: List[str],\n    lora_request_distribution: Optional[str],\n    lora_zipf_alpha: Optional[float],\n    extra_request_body: Dict[str, Any],\n    profile: bool,\n    pd_separated: bool = False,\n    flush_cache: bool = False,\n    warmup_requests: int = 1,\n    use_trace_timestamps: bool = False,\n    mooncake_slowdown_factor=1.0,\n    mooncake_num_rounds=1,\n    profile_prefill_url: Optional[List[str]] = None,\n    profile_decode_url: Optional[List[str]] = None,\n):\n    if backend in ASYNC_REQUEST_FUNCS:\n        request_func = ASYNC_REQUEST_FUNCS[backend]\n    else:\n        raise ValueError(f\"Unknown backend: {backend}\")\n\n    # Check for multi-turn: prompt is a list of strings (not 
OpenAI messages dicts)\n    # Multi-turn format: [\"turn1\", \"turn2\", ...] - list of strings\n    # OpenAI format: [{\"role\": \"user\", \"content\": \"...\"}, ...] - list of dicts\n    first_prompt = input_requests[0].prompt\n    is_multi_turn = (\n        isinstance(first_prompt, list)\n        and len(first_prompt) > 0\n        and isinstance(first_prompt[0], str)\n    )\n    if is_multi_turn:\n        request_func = wrap_multi_turn_request_func(request_func, backend=backend)\n\n    # Limit concurrency\n    # From https://github.com/vllm-project/vllm/pull/9390\n    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None\n\n    async def limited_request_func(request_func_input, pbar):\n        if semaphore is None:\n            return await request_func(request_func_input=request_func_input, pbar=pbar)\n        async with semaphore:\n            return await request_func(request_func_input=request_func_input, pbar=pbar)\n\n    # Warmup\n    print(f\"Starting warmup with {warmup_requests} sequences...\")\n\n    # Handle the data structure difference for the warmup request\n    if args.dataset_name == \"mooncake\":\n        # For mooncake, input_requests is a list of dicts.\n        # We need to build a temporary DatasetRow for the warmup phase.\n        warmup_record = input_requests[0]\n\n        # Build prompt from hash_ids, just like in the async generator\n        hash_ids = warmup_record.get(\"hash_ids\", [])\n        prompt_text = \"\"\n        for hash_id in hash_ids:\n            prompt_text += f\"{hash_id}\" + \" \".join([\"hi\"] * 512)\n        prompt_text += \"Can you tell me a detailed story in 1000 words?\"\n\n        output_len = warmup_record.get(\"output_length\", 32)\n        prompt_len = len(tokenizer.encode(prompt_text))\n\n        # Create a temporary DatasetRow object for warmup\n        test_request = DatasetRow(\n            prompt=prompt_text,\n            prompt_len=prompt_len,\n            output_len=output_len,\n    
        image_data=None,  # Mooncake doesn't have image data\n        )\n    else:\n        # For all other datasets, input_requests is a list of DatasetRow objects\n        test_request = input_requests[0]\n\n    if lora_names is not None and len(lora_names) != 0:\n        lora_name = lora_names[0]\n    else:\n        lora_name = None\n\n    # Create the test input once\n    test_input = RequestFuncInput(\n        model=model_id,\n        prompt=test_request.prompt,\n        api_url=api_url,\n        prompt_len=test_request.prompt_len,\n        output_len=min(test_request.output_len, 32),\n        lora_name=lora_name,\n        image_data=test_request.image_data,\n        extra_request_body=extra_request_body,\n    )\n\n    # Run warmup requests\n    warmup_tasks = []\n    for _ in range(warmup_requests):\n        warmup_tasks.append(\n            asyncio.create_task(request_func(request_func_input=test_input))\n        )\n\n    warmup_outputs = await asyncio.gather(*warmup_tasks)\n    if is_multi_turn:\n        warmup_outputs = [x for output in warmup_outputs for x in output]\n\n    # Check if at least one warmup request succeeded\n    if warmup_requests > 0 and not any(output.success for output in warmup_outputs):\n        raise ValueError(\n            \"Warmup failed - Please make sure benchmark arguments \"\n            f\"are correctly specified. Error: {warmup_outputs[0].error}\"\n        )\n    else:\n        print(\n            f\"Warmup completed with {warmup_requests} sequences. 
Starting main benchmark run...\"\n        )\n\n    # Flush cache\n    if (\"sglang\" in backend and _get_bool_env_var(\"SGLANG_IS_IN_CI\")) or flush_cache:\n        requests.post(base_url + \"/flush_cache\", headers=get_auth_headers())\n\n    time.sleep(1.0)\n\n    # Build profile URLs for PD separated mode (do this once at the beginning)\n    pd_profile_urls = []\n    if profile and pd_separated:\n        pd_profile_urls = _build_profile_urls(profile_prefill_url, profile_decode_url)\n        if not pd_profile_urls:\n            print(\n                \"Warning: PD separated mode requires --profile-prefill-url or --profile-decode-url\"\n            )\n            print(\"Skipping profiler start. Please specify worker URLs for profiling.\")\n\n    # Start profiler\n    if profile:\n        if pd_separated:\n            if pd_profile_urls:\n                await _call_profile_pd(pd_profile_urls, \"start\")\n        else:\n            print(\"Starting profiler...\")\n            profile_output = await async_request_profile(\n                api_url=base_url + \"/start_profile\"\n            )\n            if profile_output.success:\n                print(\"Profiler started\")\n\n    # Run all requests\n    benchmark_start_time = time.perf_counter()\n    tasks: List[asyncio.Task] = []\n    pbar_total = len(input_requests)\n    if (\n        backend == \"sglang\" and args.dataset_name == \"mooncake\"\n    ):  # Assuming mooncake is mainly for sglang or similar backends\n        print(\"Using time-based Mooncake request scheduler, ignoring --request-rate.\")\n        request_generator = get_mooncake_request_over_time(\n            input_requests, tokenizer, mooncake_slowdown_factor, mooncake_num_rounds\n        )\n        print(\n            f\"Starting Mooncake trace replay. Sessions: {len(input_requests)}, Rounds per session: {mooncake_num_rounds}. 
Slowdown factor: {mooncake_slowdown_factor}\"\n        )\n        pbar_total *= args.mooncake_num_rounds\n    else:\n        request_generator = get_request(input_requests, request_rate)\n\n    # Prepare LoRA request distribution parameters\n    if lora_request_distribution == \"distinct\":\n        lora_idx = 0\n    elif lora_request_distribution == \"skewed\":\n        weights = np.array([lora_zipf_alpha**-i for i in range(len(lora_names))])\n        lora_probs = weights / np.sum(weights)\n    else:\n        lora_idx = None\n        lora_probs = None\n\n    pbar = None if disable_tqdm else tqdm(total=pbar_total)\n    async for request in request_generator:\n        if lora_names is not None and len(lora_names) != 0:\n            if lora_request_distribution == \"uniform\":\n                lora_name = random.choice(lora_names)\n            elif lora_request_distribution == \"distinct\":\n                lora_name = lora_names[lora_idx]\n                lora_idx = (lora_idx + 1) % len(lora_names)\n            else:\n                assert (\n                    lora_request_distribution == \"skewed\"\n                ), f\"Unexpected lora_request_distribution: {lora_request_distribution}. 
Expected 'skewed'.\"\n\n                lora_name = np.random.choice(lora_names, p=lora_probs)\n        else:\n            lora_name = None\n\n        # Merge global extra_request_body with per-request extras\n        # Per-request parameters take precedence over global ones\n        merged_extra_body = {**extra_request_body, **request.extra_request_body}\n\n        request_func_input = RequestFuncInput(\n            model=model_id,\n            prompt=request.prompt,\n            api_url=api_url,\n            prompt_len=request.prompt_len,\n            output_len=request.output_len,\n            lora_name=lora_name,\n            image_data=request.image_data,\n            extra_request_body=merged_extra_body,\n            timestamp=request.timestamp,\n            routing_key=request.routing_key,\n        )\n\n        tasks.append(\n            asyncio.create_task(\n                limited_request_func(request_func_input=request_func_input, pbar=pbar)\n            )\n        )\n    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)\n    if is_multi_turn:\n        outputs = [x for output in outputs for x in output]\n\n    # Stop profiler (only if profile_steps was not provided, as it auto-stops)\n    if profile and not (\n        hasattr(args, \"profile_steps\") and args.profile_steps is not None\n    ):\n        if pd_separated:\n            if pd_profile_urls:\n                await _call_profile_pd(pd_profile_urls, \"stop\")\n        else:\n            if getattr(args, \"profile_num_steps\", None) is None:\n                print(\"Stopping profiler...\")\n                profile_output = await async_request_profile(\n                    api_url=base_url + \"/stop_profile\"\n                )\n                if profile_output.success:\n                    print(\"Profiler stopped\")\n\n    if pbar is not None:\n        pbar.close()\n\n    if \"sglang\" in backend:\n        server_info = requests.get(\n            base_url + \"/get_server_info\", 
headers=get_auth_headers()\n        )\n        if server_info.status_code == 200:\n            server_info_json = server_info.json()\n            if \"decode\" in server_info_json:\n                server_info_json = server_info_json[\"decode\"][0]\n            if (\n                \"internal_states\" in server_info_json\n                and server_info_json[\"internal_states\"]\n            ):\n                accept_length = server_info_json[\"internal_states\"][0].get(\n                    \"avg_spec_accept_length\", None\n                )\n            else:\n                accept_length = None\n        else:\n            accept_length = None\n    else:\n        accept_length = None\n\n    # Compute metrics and print results\n    benchmark_duration = time.perf_counter() - benchmark_start_time\n    metrics, output_lens = calculate_metrics(\n        input_requests=None if is_multi_turn else input_requests,\n        outputs=outputs,\n        dur_s=benchmark_duration,\n        tokenizer=tokenizer,\n        backend=backend,\n        accept_length=accept_length,\n        plot_throughput=args.plot_throughput,\n    )\n\n    print(\"\\n{s:{c}^{n}}\".format(s=\" Serving Benchmark Result \", n=50, c=\"=\"))\n    print(\"{:<40} {:<10}\".format(\"Backend:\", backend))\n    print(\n        \"{:<40} {:<10}\".format(\n            \"Traffic request rate:\", \"trace\" if use_trace_timestamps else request_rate\n        )\n    )\n    print(\n        \"{:<40} {:<10}\".format(\n            \"Max request concurrency:\",\n            max_concurrency if max_concurrency else \"not set\",\n        )\n    )\n    print(\"{:<40} {:<10}\".format(\"Successful requests:\", metrics.completed))\n    print(\"{:<40} {:<10.2f}\".format(\"Benchmark duration (s):\", benchmark_duration))\n    print(\"{:<40} {:<10}\".format(\"Total input tokens:\", metrics.total_input))\n    print(\"{:<40} {:<10}\".format(\"Total input text tokens:\", metrics.total_input_text))\n    if args.dataset_name in 
[\"image\", \"mmmu\"]:\n        print(\n            \"{:<40} {:<10}\".format(\n                \"Total input vision tokens:\", metrics.total_input_vision\n            )\n        )\n    is_embedding = backend == \"sglang-embedding\"\n    if not is_embedding:\n        print(\"{:<40} {:<10}\".format(\"Total generated tokens:\", metrics.total_output))\n        print(\n            \"{:<40} {:<10}\".format(\n                \"Total generated tokens (retokenized):\",\n                metrics.total_output_retokenized,\n            )\n        )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Request throughput (req/s):\", metrics.request_throughput\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Input token throughput (tok/s):\", metrics.input_throughput\n        )\n    )\n    if not is_embedding:\n        print(\n            \"{:<40} {:<10.2f}\".format(\n                \"Output token throughput (tok/s):\", metrics.output_throughput\n            )\n        )\n        print(\n            \"{:<40} {:<10.2f}\".format(\n                \"Peak output token throughput (tok/s):\",\n                metrics.max_output_tokens_per_s,\n            )\n        )\n    print(\n        \"{:<40} {:<10}\".format(\n            \"Peak concurrent requests:\", metrics.max_concurrent_requests\n        )\n    )\n    if not is_embedding:\n        print(\n            \"{:<40} {:<10.2f}\".format(\n                \"Total token throughput (tok/s):\", metrics.total_throughput\n            )\n        )\n    print(\"{:<40} {:<10.2f}\".format(\"Concurrency:\", metrics.concurrency))\n    if accept_length:\n        print(\"{:<40} {:<10.2f}\".format(\"Accept length:\", accept_length))\n    print(\"{s:{c}^{n}}\".format(s=\"End-to-End Latency\", n=50, c=\"-\"))\n    print(\n        \"{:<40} {:<10.2f}\".format(\"Mean E2E Latency (ms):\", metrics.mean_e2e_latency_ms)\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\n            \"Median E2E Latency 
(ms):\", metrics.median_e2e_latency_ms\n        )\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\"P90 E2E Latency (ms):\", metrics.p90_e2e_latency_ms)\n    )\n    print(\n        \"{:<40} {:<10.2f}\".format(\"P99 E2E Latency (ms):\", metrics.p99_e2e_latency_ms)\n    )\n    if not is_embedding:\n        print(\"{s:{c}^{n}}\".format(s=\"Time to First Token\", n=50, c=\"-\"))\n        print(\"{:<40} {:<10.2f}\".format(\"Mean TTFT (ms):\", metrics.mean_ttft_ms))\n        print(\"{:<40} {:<10.2f}\".format(\"Median TTFT (ms):\", metrics.median_ttft_ms))\n        print(\"{:<40} {:<10.2f}\".format(\"P99 TTFT (ms):\", metrics.p99_ttft_ms))\n        print(\n            \"{s:{c}^{n}}\".format(\n                s=\"Time per Output Token (excl. 1st token)\", n=50, c=\"-\"\n            )\n        )\n        print(\"{:<40} {:<10.2f}\".format(\"Mean TPOT (ms):\", metrics.mean_tpot_ms))\n        print(\"{:<40} {:<10.2f}\".format(\"Median TPOT (ms):\", metrics.median_tpot_ms))\n        print(\"{:<40} {:<10.2f}\".format(\"P99 TPOT (ms):\", metrics.p99_tpot_ms))\n        print(\"{s:{c}^{n}}\".format(s=\"Inter-Token Latency\", n=50, c=\"-\"))\n        print(\"{:<40} {:<10.2f}\".format(\"Mean ITL (ms):\", metrics.mean_itl_ms))\n        print(\"{:<40} {:<10.2f}\".format(\"Median ITL (ms):\", metrics.median_itl_ms))\n        print(\"{:<40} {:<10.2f}\".format(\"P95 ITL (ms):\", metrics.p95_itl_ms))\n        print(\"{:<40} {:<10.2f}\".format(\"P99 ITL (ms):\", metrics.p99_itl_ms))\n        print(\"{:<40} {:<10.2f}\".format(\"Max ITL (ms):\", metrics.max_itl_ms))\n    print(\"=\" * 50)\n\n    resp = requests.get(base_url + \"/get_server_info\", headers=get_auth_headers())\n    server_info = resp.json() if resp.status_code == 200 else None\n\n    if (\n        metrics.median_ttft_ms is not None\n        and metrics.mean_itl_ms is not None\n        and metrics.output_throughput is not None\n    ):\n        result = {\n            # Arguments\n            \"tag\": getattr(args, 
\"tag\", None),\n            \"backend\": args.backend,\n            \"dataset_name\": args.dataset_name,\n            \"request_rate\": \"trace\" if use_trace_timestamps else request_rate,\n            \"max_concurrency\": max_concurrency,\n            \"sharegpt_output_len\": args.sharegpt_output_len,\n            \"random_input_len\": args.random_input_len,\n            \"random_output_len\": args.random_output_len,\n            \"random_range_ratio\": args.random_range_ratio,\n            # Information\n            \"server_info\": server_info,\n            # Results\n            \"duration\": benchmark_duration,\n            \"completed\": metrics.completed,\n            \"total_input_tokens\": metrics.total_input,\n            \"total_input_text_tokens\": metrics.total_input_text,\n            \"total_input_vision_tokens\": metrics.total_input_vision,\n            \"total_output_tokens\": metrics.total_output,\n            \"total_output_tokens_retokenized\": metrics.total_output_retokenized,\n            \"request_throughput\": metrics.request_throughput,\n            \"input_throughput\": metrics.input_throughput,\n            \"output_throughput\": metrics.output_throughput,\n            \"total_throughput\": metrics.total_throughput,\n            \"mean_e2e_latency_ms\": metrics.mean_e2e_latency_ms,\n            \"median_e2e_latency_ms\": metrics.median_e2e_latency_ms,\n            \"std_e2e_latency_ms\": metrics.std_e2e_latency_ms,\n            \"p90_e2e_latency_ms\": metrics.p90_e2e_latency_ms,\n            \"p99_e2e_latency_ms\": metrics.p99_e2e_latency_ms,\n            \"mean_ttft_ms\": metrics.mean_ttft_ms,\n            \"median_ttft_ms\": metrics.median_ttft_ms,\n            \"std_ttft_ms\": metrics.std_ttft_ms,\n            \"p99_ttft_ms\": metrics.p99_ttft_ms,\n            \"mean_tpot_ms\": metrics.mean_tpot_ms,\n            \"median_tpot_ms\": metrics.median_tpot_ms,\n            \"std_tpot_ms\": metrics.std_tpot_ms,\n            \"p99_tpot_ms\": 
metrics.p99_tpot_ms,\n            \"mean_itl_ms\": metrics.mean_itl_ms,\n            \"median_itl_ms\": metrics.median_itl_ms,\n            \"std_itl_ms\": metrics.std_itl_ms,\n            \"p95_itl_ms\": metrics.p95_itl_ms,\n            \"p99_itl_ms\": metrics.p99_itl_ms,\n            \"concurrency\": metrics.concurrency,\n            \"accept_length\": accept_length,\n            \"max_output_tokens_per_s\": metrics.max_output_tokens_per_s,\n            \"max_concurrent_requests\": metrics.max_concurrent_requests,\n        }\n    else:\n        print(f\"Error running benchmark for request rate: {request_rate}\")\n        print(\"-\" * 30)\n\n    # Determine output file name\n    if args.output_file:\n        output_file_name = args.output_file\n    else:\n        now = datetime.now().strftime(\"%m%d\")\n        if args.dataset_name == \"image\":\n            output_file_name = (\n                f\"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_\"\n                f\"{args.random_output_len}_{args.image_count}imgs_\"\n                f\"{args.image_resolution}.jsonl\"\n            )\n        elif args.dataset_name.startswith(\"random\"):\n            output_file_name = f\"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl\"\n        else:\n            output_file_name = (\n                f\"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl\"\n            )\n\n    result_details = {\n        \"input_lens\": [output.prompt_len for output in outputs],\n        \"output_lens\": output_lens,\n        \"ttfts\": [output.ttft for output in outputs],\n        \"itls\": [output.itl for output in outputs],\n        \"generated_texts\": [output.generated_text for output in outputs],\n        \"errors\": [output.error for output in outputs],\n    }\n\n    # Append results to a JSONL file\n    with open(output_file_name, \"a\") as file:\n        if args.output_details:\n            result_for_dump = 
result | result_details\n        else:\n            result_for_dump = result\n        file.write(json.dumps(result_for_dump) + \"\\n\")\n\n    return result | result_details\n\n\ndef check_chat_template(model_path):\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n        return \"chat_template\" in tokenizer.init_kwargs\n    except Exception as e:\n        print(f\"Failed to load tokenizer config with error={e}\")\n        return False\n\n\ndef set_global_args(args_: argparse.Namespace):\n    \"\"\"Set the global args.\"\"\"\n    global args\n    args = args_\n\n\ndef run_benchmark(args_: argparse.Namespace):\n    global args\n    args = args_\n\n    # Set default value for max_concurrency if not present\n    if not hasattr(args, \"max_concurrency\"):\n        args.max_concurrency = None\n\n    # Set default value for warmup_requests if not present\n    if not hasattr(args, \"warmup_requests\"):\n        args.warmup_requests = 1\n\n    if not hasattr(args, \"output_details\"):\n        args.output_details = False\n\n    if not hasattr(args, \"tokenize_prompt\"):\n        args.tokenize_prompt = False\n\n    if not hasattr(args, \"plot_throughput\"):\n        args.plot_throughput = False\n\n    if not hasattr(args, \"top_logprobs_num\"):\n        args.top_logprobs_num = 0\n    if not hasattr(args, \"token_ids_logprob\"):\n        args.token_ids_logprob = None\n    if not hasattr(args, \"logprob_start_len\"):\n        args.logprob_start_len = -1\n    if not hasattr(args, \"return_logprob\"):\n        args.return_logprob = False\n\n    if not hasattr(args, \"use_trace_timestamps\"):\n        args.use_trace_timestamps = False\n    if not hasattr(args, \"mooncake_slowdown_factor\"):\n        args.mooncake_slowdown_factor = 1.0\n\n    if not hasattr(args, \"mooncake_num_rounds\"):\n        args.mooncake_num_rounds = 1\n\n  
  if not hasattr(args, \"served_model_name\"):\n        args.served_model_name = None\n\n    if getattr(args, \"print_requests\", False):\n        assert args.backend == \"sglang-oai-chat\"  # only support this now\n\n    print(f\"benchmark_args={args}\")\n\n    # Set global environments\n    set_ulimit()\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n\n    extra_request_body = {}\n    if args.extra_request_body:\n        extra_request_body = json.loads(args.extra_request_body)\n\n    if args.tokenize_prompt:\n        assert (\n            args.backend == \"sglang\"\n        ), \"`--tokenize-prompt` only compatible with `--backend sglang` currently\"\n\n    # Set url\n    if args.port is None:\n        args.port = {\n            \"sglang\": 30000,\n            \"sglang-native\": 30000,\n            \"sglang-oai\": 30000,\n            \"lmdeploy\": 23333,\n            \"vllm\": 8000,\n            \"trt\": 8000,\n            \"gserver\": 9988,\n            \"truss\": 8080,\n        }.get(args.backend, 30000)\n\n    model_url = (\n        f\"{args.base_url}/v1/models\"\n        if args.base_url\n        else f\"http://{args.host}:{args.port}/v1/models\"\n    )\n\n    if args.backend == \"sglang-embedding\":\n        api_url = (\n            f\"{args.base_url}/v1/embeddings\"\n            if args.base_url\n            else f\"http://{args.host}:{args.port}/v1/embeddings\"\n        )\n    elif args.backend in [\"sglang\", \"sglang-native\"]:\n        api_url = (\n            f\"{args.base_url}/generate\"\n            if args.base_url\n            else f\"http://{args.host}:{args.port}/generate\"\n        )\n    elif args.backend in [\"sglang-oai\", \"vllm\", \"lmdeploy\"]:\n        api_url = (\n            f\"{args.base_url}/v1/completions\"\n            if args.base_url\n            else f\"http://{args.host}:{args.port}/v1/completions\"\n        )\n    elif args.backend in [\"sglang-oai-chat\", \"vllm-chat\", \"lmdeploy-chat\"]:\n        api_url = (\n     
       f\"{args.base_url}/v1/chat/completions\"\n            if args.base_url\n            else f\"http://{args.host}:{args.port}/v1/chat/completions\"\n        )\n    elif args.backend == \"trt\":\n        api_url = (\n            f\"{args.base_url}/v2/models/ensemble/generate_stream\"\n            if args.base_url\n            else f\"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream\"\n        )\n        if args.model is None:\n            print(\"Please provide a model using `--model` when using `trt` backend.\")\n            sys.exit(1)\n    elif args.backend == \"gserver\":\n        api_url = args.base_url if args.base_url else f\"{args.host}:{args.port}\"\n        args.model = args.model or \"default\"\n    elif args.backend == \"truss\":\n        api_url = (\n            f\"{args.base_url}/v1/models/model:predict\"\n            if args.base_url\n            else f\"http://{args.host}:{args.port}/v1/models/model:predict\"\n        )\n    base_url = (\n        f\"http://{args.host}:{args.port}\" if args.base_url is None else args.base_url\n    )\n\n    # Wait for server to be ready\n    if args.ready_check_timeout_sec > 0:\n        health_url = model_url if args.backend not in (\"trt\", \"gserver\") else base_url\n        if not wait_for_endpoint(health_url, args.ready_check_timeout_sec):\n            print(f\"Server at {health_url} is not ready. Exiting.\")\n            sys.exit(1)\n\n    # Get model name\n    if args.model is None:\n        if args.backend == \"truss\":\n            print(\n                \"Please provide a model with `--model` when using truss backend. e.g. 
--model meta-llama/Llama-3.1-8B-Instruct\"\n            )\n            sys.exit(1)\n        try:\n            response = requests.get(model_url, headers=get_auth_headers())\n            model_list = response.json().get(\"data\", [])\n            args.model = model_list[0][\"id\"] if model_list else None\n        except Exception as e:\n            print(f\"Failed to fetch model from {model_url}. Error: {e}\")\n            print(\n                \"Please specify the correct host and port using `--host` and `--port`.\"\n            )\n            sys.exit(1)\n\n    if args.model is None:\n        print(\"No model specified or found. Please provide a model using `--model`.\")\n        sys.exit(1)\n\n    if args.backend != \"sglang-embedding\" and not check_chat_template(args.model):\n        print(\n            \"\\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\\n\"\n            \"Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\\n\"\n        )\n\n    if (\n        args.backend == \"sglang-embedding\"\n        and args.dataset_name in _EMBEDDING_UNSUPPORTED_DATASETS\n    ):\n        print(f\"{args.dataset_name} dataset is unsupported for embeddings benchmark\")\n        sys.exit(1)\n\n    if args.dataset_name in [\"image\", \"mmmu\"]:\n        args.apply_chat_template = True\n        assert (\n            not args.tokenize_prompt\n        ), \"`--tokenize-prompt` not compatible with image dataset\"\n\n    if args.lora_request_distribution in [\"distinct\", \"skewed\"]:\n        assert (\n            args.lora_name is not None and len(args.lora_name) > 1\n        ), \"More than 1 LoRA adapter must be specified via --lora-name to use 'distinct' or 'skewed' request distribution.\"\n\n    assert (\n        args.lora_zipf_alpha > 1\n    ), f\"Got invalid value for --lora-zipf-alpha of {args.lora_zipf_alpha}. 
It must be greater than 1.\"\n\n    print(f\"{args}\\n\")\n\n    # Read dataset\n    backend = args.backend\n    model_id = args.served_model_name or args.model\n    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model\n    tokenizer = get_tokenizer(tokenizer_id)\n    input_requests = get_dataset(args, tokenizer, model_id)\n\n    # compatible with SimpleNamespace\n    if not hasattr(args, \"flush_cache\"):\n        args.flush_cache = False\n\n    # Prepare LoRA arguments\n    lora_request_distribution = (\n        args.lora_request_distribution if args.lora_name is not None else None\n    )\n\n    lora_zipf_alpha = (\n        args.lora_zipf_alpha\n        if args.lora_name is not None and args.lora_request_distribution == \"skewed\"\n        else None\n    )\n\n    return asyncio.run(\n        benchmark(\n            backend=backend,\n            api_url=api_url,\n            base_url=base_url,\n            model_id=model_id,\n            tokenizer=tokenizer,\n            input_requests=input_requests,\n            request_rate=args.request_rate,\n            max_concurrency=args.max_concurrency,\n            disable_tqdm=args.disable_tqdm,\n            lora_names=args.lora_name,\n            lora_request_distribution=lora_request_distribution,\n            lora_zipf_alpha=lora_zipf_alpha,\n            extra_request_body=extra_request_body,\n            profile=args.profile,\n            pd_separated=args.pd_separated,\n            flush_cache=args.flush_cache,\n            warmup_requests=args.warmup_requests,\n            use_trace_timestamps=args.use_trace_timestamps,\n            mooncake_slowdown_factor=args.mooncake_slowdown_factor,\n            mooncake_num_rounds=args.mooncake_num_rounds,\n            profile_prefill_url=getattr(args, \"profile_prefill_url\", None),\n            profile_decode_url=getattr(args, \"profile_decode_url\", None),\n        )\n    )\n\n\nclass LoRAPathAction(argparse.Action):\n    def __call__(self, parser, 
namespace, values, option_string=None):\n        setattr(namespace, self.dest, [])\n        for lora_name in values:\n            getattr(namespace, self.dest).append(lora_name)\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser(description=\"Benchmark the online serving throughput.\")\n    parser.add_argument(\n        \"--backend\",\n        type=str,\n        choices=list(ASYNC_REQUEST_FUNCS.keys()),\n        default=\"sglang\",\n        help=\"Must specify a backend, depending on the LLM Inference Engine.\",\n    )\n    parser.add_argument(\n        \"--base-url\",\n        type=str,\n        default=None,\n        help=\"Server or API base url if not using http host and port.\",\n    )\n    parser.add_argument(\n        \"--host\", type=str, default=\"0.0.0.0\", help=\"Default host is 0.0.0.0.\"\n    )\n    parser.add_argument(\n        \"--port\",\n        type=int,\n        help=\"If not set, the default port is configured according to its default value for different LLM Inference Engines.\",\n    )\n    parser.add_argument(\n        \"--ready-check-timeout-sec\",\n        type=int,\n        default=60,\n        help=\"Maximum time in seconds to wait for the server to be ready before benchmarking. Set to 0 to skip. Default: 60.\",\n    )\n    parser.add_argument(\n        \"--dataset-name\",\n        type=str,\n        default=\"sharegpt\",\n        choices=[\n            \"sharegpt\",\n            \"custom\",\n            \"openai\",\n            \"random\",\n            \"random-ids\",\n            \"generated-shared-prefix\",\n            \"mmmu\",\n            \"image\",\n            \"mooncake\",\n        ],\n        help=\"Name of the dataset to benchmark on.\",\n    )\n    parser.add_argument(\n        \"--dataset-path\", type=str, default=\"\", help=\"Path to the dataset.\"\n    )\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        help=\"Name or path of the model. 
If not set, the model name is fetched from the server's /v1/models endpoint.\",\n    )\n    parser.add_argument(\n        \"--served-model-name\",\n        type=str,\n        help=\"The name of the model as served by the serving service. If not set, this defaults to the value of --model.\",\n    )\n    parser.add_argument(\n        \"--tokenizer\",\n        type=str,\n        help=\"Name or path of the tokenizer. If not set, the value of --model is used.\",\n    )\n    parser.add_argument(\n        \"--num-prompts\",\n        type=int,\n        default=1000,\n        help=\"Number of prompts to process. Default is 1000.\",\n    )\n    parser.add_argument(\n        \"--sharegpt-output-len\",\n        type=int,\n        default=None,\n        help=\"Output length for each request. Overrides the output length from the ShareGPT dataset.\",\n    )\n    parser.add_argument(\n        \"--sharegpt-context-len\",\n        type=int,\n        default=None,\n        help=\"The context length of the model for the ShareGPT dataset. 
Requests longer than the context length will be dropped.\",\n    )\n    parser.add_argument(\n        \"--random-input-len\",\n        type=int,\n        default=1024,\n        help=\"Number of input tokens per request, used only for random and image dataset.\",\n    )\n    parser.add_argument(\n        \"--random-output-len\",\n        default=1024,\n        type=int,\n        help=\"Number of output tokens per request, used only for random and image dataset.\",\n    )\n    parser.add_argument(\n        \"--random-range-ratio\",\n        type=float,\n        default=0.0,\n        help=\"Range of sampled ratio of input/output length, \"\n        \"used only for random and image dataset.\",\n    )\n    # image dataset args\n    parser.add_argument(\n        \"--image-count\",\n        type=int,\n        default=1,\n        help=\"Number of images per request (only available with the image dataset)\",\n    )\n    parser.add_argument(\n        \"--image-resolution\",\n        type=str,\n        default=\"1080p\",\n        help=(\n            \"Resolution of images for image dataset. \"\n            \"Supports presets 4k/1080p/720p/360p or custom 'heightxwidth' (e.g., 1080x1920).\"\n        ),\n    )\n    parser.add_argument(\n        \"--random-image-count\",\n        action=\"store_true\",\n        help=\"Enable Random Image Count\",\n    )\n    parser.add_argument(\n        \"--image-format\",\n        type=str,\n        default=\"jpeg\",\n        help=(\"Format of images for image dataset. \" \"Supports jpeg and png.\"),\n    )\n    parser.add_argument(\n        \"--image-content\",\n        type=str,\n        default=\"random\",\n        help=(\"Content for images for image dataset. \" \"Supports random and blank.\"),\n    )\n    parser.add_argument(\n        \"--request-rate\",\n        type=float,\n        default=float(\"inf\"),\n        help=\"Number of requests per second. If this is inf, then all the requests are sent at time 0. 
\"\n        \"Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.\",\n    )\n    parser.add_argument(\n        \"--use-trace-timestamps\",\n        action=\"store_true\",\n        help=\"Use timestamps from the trace file for request scheduling. Only valid for 'mooncake' dataset.\",\n    )\n    parser.add_argument(\n        \"--max-concurrency\",\n        type=int,\n        default=None,\n        help=\"Maximum number of concurrent requests. This can be used \"\n        \"to help simulate an environment where a higher level component \"\n        \"is enforcing a maximum number of concurrent requests. While the \"\n        \"--request-rate argument controls the rate at which requests are \"\n        \"initiated, this argument will control how many are actually allowed \"\n        \"to execute at a time. This means that when used in combination, the \"\n        \"actual request rate may be lower than specified with --request-rate, \"\n        \"if the server is not processing requests fast enough to keep up.\",\n    )\n    parser.add_argument(\"--output-file\", type=str, help=\"Output JSONL file name.\")\n    parser.add_argument(\n        \"--output-details\", action=\"store_true\", help=\"Output details of benchmarking.\"\n    )\n    parser.add_argument(\n        \"--print-requests\",\n        action=\"store_true\",\n        help=\"Print requests immediately during benchmarking. 
Useful to quickly realize issues.\",\n    )\n    parser.add_argument(\n        \"--disable-tqdm\",\n        action=\"store_true\",\n        help=\"Specify to disable tqdm progress bar.\",\n    )\n    parser.add_argument(\n        \"--disable-stream\",\n        action=\"store_true\",\n        help=\"Disable streaming mode.\",\n    )\n    parser.add_argument(\n        \"--return-logprob\",\n        action=\"store_true\",\n        help=\"Return logprob.\",\n    )\n    parser.add_argument(\n        \"--top-logprobs-num\",\n        type=int,\n        default=0,\n        help=\"Number of top logprobs to return per token. Only used with --return-logprob.\",\n    )\n    parser.add_argument(\n        \"--token-ids-logprob\",\n        type=int,\n        nargs=\"+\",\n        default=None,\n        help=\"Token IDs to probe logprobs for. E.g. --token-ids-logprob 1 2 10 100 1000. Only used with --return-logprob.\",\n    )\n    parser.add_argument(\n        \"--logprob-start-len\",\n        type=int,\n        default=-1,\n        help=\"Start position for returning input logprobs. -1 means no input logprobs, 0 means all. Only used with --return-logprob.\",\n    )\n    parser.add_argument(\n        \"--return-routed-experts\",\n        action=\"store_true\",\n        help=\"Return routed experts.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=1, help=\"The random seed.\")\n    parser.add_argument(\n        \"--disable-ignore-eos\",\n        action=\"store_true\",\n        help=\"Disable ignoring EOS.\",\n    )\n    parser.add_argument(\n        \"--extra-request-body\",\n        metavar='{\"key1\": \"value1\", \"key2\": \"value2\"}',\n        type=str,\n        help=\"Append given JSON object to the request payload. 
You can use this to specify \"\n        \"additional generate params like sampling params.\",\n    )\n    parser.add_argument(\n        \"--apply-chat-template\",\n        action=\"store_true\",\n        help=\"Apply chat template\",\n    )\n    parser.add_argument(\n        \"--profile\",\n        action=\"store_true\",\n        help=\"Use Torch Profiler. The endpoint must be launched with \"\n        \"SGLANG_TORCH_PROFILER_DIR to enable profiler.\",\n    )\n    parser.add_argument(\n        \"--plot-throughput\",\n        action=\"store_true\",\n        help=\"Plot throughput and concurrent requests over time. Requires termplotlib and gnuplot.\",\n    )\n    # TODO unify all these\n    parser.add_argument(\n        \"--profile-activities\",\n        type=str,\n        nargs=\"+\",\n        default=[\"CPU\", \"GPU\"],\n        choices=[\"CPU\", \"GPU\", \"CUDA_PROFILER\", \"XPU\"],\n        help=\"Profiler activities to capture: CPU, GPU, XPU, CUDA_PROFILER.\",\n    )\n    parser.add_argument(\n        \"--profile-start-step\",\n        type=int,\n        default=None,\n        help=\"Start profiling after this many forward steps. Useful for warmup.\",\n    )\n    parser.add_argument(\n        \"--profile-steps\",\n        type=int,\n        default=None,\n        help=\"Number of steps to profile. 
If specified, profiling stops automatically after this many steps.\",\n    )\n    parser.add_argument(\"--profile-num-steps\", type=int, default=None)\n    parser.add_argument(\"--profile-by-stage\", action=\"store_true\", default=False)\n    parser.add_argument(\"--profile-stages\", nargs=\"+\", default=None)\n    parser.add_argument(\n        \"--profile-output-dir\",\n        type=str,\n        default=None,\n        help=\"Output directory for profile traces.\",\n    )\n    parser.add_argument(\n        \"--profile-prefix\",\n        type=str,\n        default=None,\n        help=\"Prefix for profile trace filenames.\",\n    )\n    parser.add_argument(\n        \"--lora-name\",\n        type=str,\n        nargs=\"*\",\n        default=None,\n        action=LoRAPathAction,\n        help=\"The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...\",\n    )\n    parser.add_argument(\n        \"--lora-request-distribution\",\n        type=str,\n        default=\"uniform\",\n        choices=[\n            \"uniform\",\n            \"distinct\",\n            \"skewed\",\n        ],\n        help=\"What distribution to sample the LoRA adapters specified in --lora-name. Borrowed from the Punica paper. \"\n        \"'distinct' distribution means selecting a new LoRA adapter for every request. 
\"\n        \"'skewed' distribution follows the Zipf distribution, where the number of requests \"\n        \"to model i specified in --lora-name is α times the number of requests for model i+1, \"\n        \"where α > 1.\",\n    )\n    parser.add_argument(\n        \"--lora-zipf-alpha\",\n        type=float,\n        default=1.5,\n        help=\"The parameter to use for the Zipf distribution when --lora-request-distribution='skewed'.\",\n    )\n    parser.add_argument(\n        \"--prompt-suffix\",\n        type=str,\n        default=\"\",\n        help=\"Suffix applied to the end of all user prompts, followed by assistant prompt suffix.\",\n    )\n    parser.add_argument(\n        \"--pd-separated\",\n        action=\"store_true\",\n        help=\"Benchmark PD disaggregation server\",\n    )\n\n    # Create a mutually exclusive group for profiling URLs\n    # In PD separated mode, prefill and decode workers must be profiled separately\n    profile_url_group = parser.add_mutually_exclusive_group()\n    profile_url_group.add_argument(\n        \"--profile-prefill-url\",\n        type=str,\n        nargs=\"*\",\n        default=None,\n        help=\"URL(s) of the prefill worker(s) for profiling in PD separated mode. \"\n        \"Can specify multiple URLs: --profile-prefill-url http://localhost:30000 http://localhost:30001. \"\n        \"NOTE: Cannot be used together with --profile-decode-url. \"\n        \"In PD separated mode, prefill and decode workers must be profiled separately.\",\n    )\n    profile_url_group.add_argument(\n        \"--profile-decode-url\",\n        type=str,\n        nargs=\"*\",\n        default=None,\n        help=\"URL(s) of the decode worker(s) for profiling in PD separated mode. \"\n        \"Can specify multiple URLs: --profile-decode-url http://localhost:30010 http://localhost:30011. \"\n        \"NOTE: Cannot be used together with --profile-prefill-url. 
\"\n        \"In PD separated mode, prefill and decode workers must be profiled separately.\",\n    )\n    parser.add_argument(\n        \"--flush-cache\",\n        action=\"store_true\",\n        help=\"Flush the cache before running the benchmark\",\n    )\n    parser.add_argument(\n        \"--warmup-requests\",\n        type=int,\n        default=1,\n        help=\"Number of warmup requests to run before the benchmark\",\n    )\n    parser.add_argument(\n        \"--tokenize-prompt\",\n        action=\"store_true\",\n        help=\"Use integer ids instead of string for inputs. Useful to control prompt lengths accurately\",\n    )\n\n    group = parser.add_argument_group(\"generated-shared-prefix dataset arguments\")\n    group.add_argument(\n        \"--gsp-num-groups\",\n        type=int,\n        default=64,\n        help=\"Number of system prompt groups for generated-shared-prefix dataset\",\n    )\n    group.add_argument(\n        \"--gsp-prompts-per-group\",\n        type=int,\n        default=16,\n        help=\"Number of prompts per system prompt group for generated-shared-prefix dataset\",\n    )\n    group.add_argument(\n        \"--gsp-system-prompt-len\",\n        type=int,\n        default=2048,\n        help=\"Target length in tokens for system prompts in generated-shared-prefix dataset\",\n    )\n    group.add_argument(\n        \"--gsp-question-len\",\n        type=int,\n        default=128,\n        help=\"Target length in tokens for questions in generated-shared-prefix dataset\",\n    )\n    group.add_argument(\n        \"--gsp-output-len\",\n        type=int,\n        default=256,\n        help=\"Target length in tokens for outputs in generated-shared-prefix dataset\",\n    )\n    parser.add_argument(\n        \"--gsp-range-ratio\",\n        type=float,\n        # WARN: The default 1.0 is for backward compatibility, and is different from the default 0.0 for random dataset\n        default=1.0,\n        help=\"Range of sampled ratio of 
input/output length, used only for gsp dataset.\",\n    )\n    group.add_argument(\n        \"--gsp-fast-prepare\",\n        action=\"store_true\",\n        help=\"Speedup preparing by removing statistics computation, which will make some output statistics inaccurate but suitable for pressure tests.\",\n    )\n    group.add_argument(\n        \"--gsp-send-routing-key\",\n        action=\"store_true\",\n        help=\"Send routing key in requests via X-SMG-Routing-Key header. Requests with the same prefix share the same routing key.\",\n    )\n    group.add_argument(\n        \"--gsp-num-turns\",\n        type=int,\n        default=1,\n        help=\"Number of turns for multi-turn conversations. If > 1, each prompt becomes a list of questions sharing the same system prefix.\",\n    )\n    group.add_argument(\n        \"--gsp-ordered\",\n        action=\"store_true\",\n        help=\"Keep requests in order without shuffling. By default, requests are shuffled randomly.\",\n    )\n    mooncake_group = parser.add_argument_group(\"mooncake dataset arguments\")\n    mooncake_group.add_argument(\n        \"--mooncake-slowdown-factor\",\n        type=float,\n        default=1.0,\n        help=\"Slowdown factor for replaying the mooncake trace. \"\n        \"A value of 2.0 means the replay is twice as slow. \"\n        \"NOTE: --request-rate is IGNORED in mooncake mode.\",\n    )\n    mooncake_group.add_argument(\n        \"--mooncake-num-rounds\",\n        type=int,\n        default=1,\n        help=\"Number of conversation rounds for each session in the mooncake dataset. 
\"\n        \"A value > 1 will enable true multi-turn session benchmarking.\",\n    )\n    mooncake_group.add_argument(\n        \"--mooncake-workload\",\n        type=str,\n        default=\"conversation\",\n        choices=[\n            \"mooncake\",\n            \"conversation\",\n            \"synthetic\",\n            \"toolagent\",\n        ],\n        help=\"Underlying workload for the mooncake dataset.\",\n    )\n    parser.add_argument(\n        \"--tag\", type=str, default=None, help=\"The tag to be dumped to output.\"\n    )\n    parser.add_argument(\n        \"--header\",\n        type=str,\n        nargs=\"+\",\n        default=None,\n        help=\"Custom HTTP headers in Key=Value format. Example: --header MyHeader=MY_VALUE MyAnotherHeader=myanothervalue\",\n    )\n    args = parser.parse_args()\n    run_benchmark(args)\n"
  },
  {
    "path": "python/sglang/benchmark/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/benchmark/bench_utils.py",
    "content": "\"\"\"Triton do_bench/do_bench_cudagraph compatible wrapper using flashinfer.testing.bench_gpu_time.\"\"\"\n\nimport numpy as np\nfrom flashinfer.testing import bench_gpu_time\n\n\ndef run_bench(\n    fn,\n    use_cuda_graph: bool = True,\n    quantiles=(0.5, 0.2, 0.8),\n    warmup_ms: int = 25,\n    rep_ms: int = 100,\n):\n    \"\"\"Returns (ms, min_ms, max_ms) or (median,) when quantiles=None.\"\"\"\n    times = bench_gpu_time(\n        fn=fn,\n        use_cuda_graph=use_cuda_graph,\n        dry_run_time_ms=warmup_ms,\n        repeat_time_ms=rep_ms,\n    )\n    if quantiles is None:\n        return (float(np.median(times)),)\n    return tuple(float(np.percentile(times, q * 100)) for q in quantiles)\n"
  },
  {
    "path": "python/sglang/benchmark/datasets/__init__.py",
    "content": "from typing import Dict, Type\n\nfrom sglang.benchmark.datasets.common import BaseDataset, DatasetRow\nfrom sglang.benchmark.datasets.custom import CustomDataset\nfrom sglang.benchmark.datasets.generated_shared_prefix import (\n    GeneratedSharedPrefixDataset,\n)\nfrom sglang.benchmark.datasets.image import ImageDataset\nfrom sglang.benchmark.datasets.mmmu import MMMUDataset\nfrom sglang.benchmark.datasets.mooncake import MooncakeDataset\nfrom sglang.benchmark.datasets.openai_dataset import OpenAIDataset\nfrom sglang.benchmark.datasets.random import RandomDataset\nfrom sglang.benchmark.datasets.sharegpt import ShareGPTDataset\n\nDATASET_MAPPING: Dict[str, Type[BaseDataset]] = {\n    \"sharegpt\": ShareGPTDataset,\n    \"custom\": CustomDataset,\n    \"openai\": OpenAIDataset,\n    # TODO: \"random\" vs \"random-ids\" should be a flag (e.g. --random-source=sharegpt|integers),\n    # not two separate dataset names sharing the same class.\n    \"random\": RandomDataset,\n    \"random-ids\": RandomDataset,\n    \"generated-shared-prefix\": GeneratedSharedPrefixDataset,\n    \"mmmu\": MMMUDataset,\n    \"image\": ImageDataset,\n    \"mooncake\": MooncakeDataset,\n}\n\n\ndef get_dataset(args, tokenizer, model_id=None):\n    dataset_name = args.dataset_name\n    if dataset_name.startswith(\"random\") and dataset_name not in DATASET_MAPPING:\n        dataset_name = \"random-ids\"\n\n    if dataset_name not in DATASET_MAPPING:\n        raise ValueError(f\"Unknown dataset: {args.dataset_name}\")\n\n    dataset_cls = DATASET_MAPPING[dataset_name]\n    dataset = dataset_cls.from_args(args)\n    return dataset.load(tokenizer=tokenizer, model_id=model_id)\n\n\n__all__ = [\n    \"DATASET_MAPPING\",\n    \"DatasetRow\",\n    \"get_dataset\",\n]\n"
  },
  {
    "path": "python/sglang/benchmark/datasets/common.py",
    "content": "import random\nfrom abc import ABC, abstractmethod\nfrom argparse import Namespace\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import Any, Dict, List, Optional\n\nimport numpy as np\n\nASSISTANT_SUFFIX = \"Assistant:\"\nSHAREGPT_REPO_ID = \"anon8231489123/ShareGPT_Vicuna_unfiltered\"\nSHAREGPT_FILENAME = \"ShareGPT_V3_unfiltered_cleaned_split.json\"\nMOONCAKE_DATASET_URL = {\n    \"mooncake\": \"https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/arxiv-trace/mooncake_trace.jsonl\",\n    \"conversation\": \"https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/conversation_trace.jsonl\",\n    \"synthetic\": \"https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/synthetic_trace.jsonl\",\n    \"toolagent\": \"https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/toolagent_trace.jsonl\",\n}\n\n\n@dataclass\nclass DatasetRow:\n    prompt: Any\n    prompt_len: int\n    output_len: int\n    text_prompt_len: Optional[int] = None\n    vision_prompt_len: Optional[int] = None\n    image_data: Optional[List[str]] = None\n    timestamp: Optional[float] = None\n    routing_key: Optional[str] = None\n    extra_request_body: Optional[Dict[str, Any]] = None  # Per-request API parameters\n\n    def __post_init__(self):\n        if self.text_prompt_len is None:\n            self.text_prompt_len = self.prompt_len\n        if self.vision_prompt_len is None:\n            self.vision_prompt_len = 0\n        if self.extra_request_body is None:\n            self.extra_request_body = {}\n\n\n@dataclass\nclass BaseDataset(ABC):\n    @classmethod\n    @abstractmethod\n    def from_args(cls, args: Namespace) -> \"BaseDataset\": ...\n\n    @abstractmethod\n    def load(\n        self,\n        tokenizer: Any,\n        model_id: Optional[str] = None,\n    ) -> List[DatasetRow]: ...\n\n\ndef compute_random_lens(full_len: int, range_ratio: 
float, num: int) -> List[int]:\n    # full_len=0 is valid for embedding benchmarks where no output tokens are generated\n    if full_len <= 0:\n        return [0] * num\n    return np.random.randint(\n        max(int(full_len * range_ratio), 1),\n        full_len + 1,\n        size=num,\n    ).tolist()\n\n\n@lru_cache(maxsize=1)\ndef get_available_tokens(tokenizer):\n    \"\"\"Get all available token ids from the tokenizer vocabulary.\"\"\"\n    return list(tokenizer.get_vocab().values())\n\n\ndef gen_prompt(tokenizer, token_num):\n    \"\"\"Generate a random prompt of specified token length using tokenizer vocabulary.\"\"\"\n    all_available_tokens = get_available_tokens(tokenizer)\n    selected_tokens = random.choices(all_available_tokens, k=token_num)\n    return tokenizer.decode(selected_tokens)\n\n\ndef gen_mm_prompt(tokenizer, image_pad_id, token_num):\n    \"\"\"Generate a random text prompt of specified token length, excluding the image pad token from the vocabulary.\"\"\"\n    all_available_tokens = list(tokenizer.get_vocab().values())\n    if image_pad_id:\n        all_available_tokens.remove(image_pad_id)\n    selected_tokens = random.choices(all_available_tokens, k=token_num)\n    return tokenizer.decode(selected_tokens)\n"
  },
  {
    "path": "python/sglang/benchmark/datasets/custom.py",
    "content": "import json\nimport os\nimport random\nfrom argparse import Namespace\nfrom dataclasses import dataclass\nfrom typing import List, Optional\n\nimport numpy as np\nfrom transformers import PreTrainedTokenizerBase\n\nfrom sglang.benchmark.datasets.common import (\n    ASSISTANT_SUFFIX,\n    BaseDataset,\n    DatasetRow,\n)\nfrom sglang.benchmark.utils import remove_suffix\n\n\n@dataclass\nclass CustomDataset(BaseDataset):\n    dataset_path: str\n    num_requests: int\n    fixed_output_len: Optional[int]\n    context_len: Optional[int]\n    prompt_suffix: str\n    apply_chat_template: bool\n\n    @classmethod\n    def from_args(cls, args: Namespace) -> \"CustomDataset\":\n        assert not getattr(args, \"tokenize_prompt\", False)\n        return cls(\n            dataset_path=args.dataset_path,\n            num_requests=args.num_prompts,\n            fixed_output_len=args.sharegpt_output_len,\n            context_len=args.sharegpt_context_len,\n            prompt_suffix=args.prompt_suffix,\n            apply_chat_template=args.apply_chat_template,\n        )\n\n    def load(\n        self, tokenizer: PreTrainedTokenizerBase, model_id=None\n    ) -> List[DatasetRow]:\n        return sample_custom_requests(\n            dataset_path=self.dataset_path,\n            num_requests=self.num_requests,\n            tokenizer=tokenizer,\n            fixed_output_len=self.fixed_output_len,\n            context_len=self.context_len,\n            prompt_suffix=self.prompt_suffix,\n            apply_chat_template=self.apply_chat_template,\n        )\n\n\ndef sample_custom_requests(\n    dataset_path: str,\n    num_requests: int,\n    tokenizer: PreTrainedTokenizerBase,\n    fixed_output_len: Optional[int] = None,\n    context_len: Optional[int] = None,\n    prompt_suffix: Optional[str] = \"\",\n    apply_chat_template=False,\n) -> List[DatasetRow]:\n    \"\"\"\n    Sample requests from a custom JSONL dataset: supports 'content'/'value' as conversation keys.\n    
\"\"\"\n    if fixed_output_len is not None and fixed_output_len < 4:\n        raise ValueError(\"output_len too small\")\n\n    # Load the dataset\n    dataset = []\n    if not os.path.isfile(dataset_path):\n        raise FileNotFoundError(f\"Dataset not found at {dataset_path}\")\n\n    with open(dataset_path, \"r\", encoding=\"utf-8\") as f:\n        for line in f:\n            line = line.strip()\n            if line:  # skip empty lines\n                try:\n                    dataset.append(json.loads(line))\n                except json.JSONDecodeError:\n                    continue  # skip lines with JSON errors\n\n    # Filter out the conversations with less than 2 turns.\n    processed_dataset = []\n    for data in dataset:\n        convs = data.get(\"conversations\", data.get(\"conversation\", []))\n        if len(convs) >= 2:\n            user_turn = convs[0].get(\"content\", convs[0].get(\"value\", \"\"))\n            assist_turn = convs[1].get(\"content\", convs[1].get(\"value\", \"\"))\n            processed_dataset.append((user_turn, assist_turn))\n    dataset = processed_dataset\n    random.shuffle(dataset)\n\n    # Filter out sequences that are too long or too short\n    filtered_dataset: List[DatasetRow] = []\n\n    for i in range(len(dataset)):\n        if len(filtered_dataset) == num_requests:\n            break\n\n        # Tokenize the prompts and completions.\n        prompt = dataset[i][0]\n\n        if prompt_suffix:\n            prompt = (\n                remove_suffix(prompt, ASSISTANT_SUFFIX)\n                + prompt_suffix\n                + ASSISTANT_SUFFIX\n            )\n\n        if apply_chat_template:\n            prompt = tokenizer.apply_chat_template(\n                [{\"role\": \"user\", \"content\": prompt}],\n                add_generation_prompt=True,\n                tokenize=False,\n                return_dict=False,\n            )\n            if tokenizer.bos_token:\n                prompt = 
prompt.replace(tokenizer.bos_token, \"\")\n\n        prompt_token_ids = tokenizer.encode(prompt)\n        completion = dataset[i][1]\n        completion_token_ids = tokenizer.encode(completion)\n        prompt_len = len(prompt_token_ids)\n        output_len = (\n            len(completion_token_ids) if fixed_output_len is None else fixed_output_len\n        )\n\n        if prompt_len < 2 or output_len < 2:\n            # Prune too short sequences.\n            continue\n\n        if context_len and prompt_len + output_len > context_len:\n            # Prune too long sequences.\n            continue\n\n        filtered_dataset.append(\n            DatasetRow(\n                prompt=prompt,\n                prompt_len=prompt_len,\n                output_len=output_len,\n            )\n        )\n\n    print(f\"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}\")\n    print(f\"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}\")\n    return filtered_dataset\n"
  },
  {
    "path": "python/sglang/benchmark/datasets/generated_shared_prefix.py",
    "content": "import pickle\nimport random\nimport uuid\nfrom argparse import Namespace\nfrom dataclasses import dataclass\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import List\n\nimport numpy as np\nfrom tqdm.asyncio import tqdm\nfrom transformers import PreTrainedTokenizerBase\n\nfrom sglang.benchmark.datasets.common import (\n    BaseDataset,\n    DatasetRow,\n    compute_random_lens,\n    gen_prompt,\n)\n\n\n@dataclass\nclass GeneratedSharedPrefixDataset(BaseDataset):\n    num_groups: int\n    prompts_per_group: int\n    system_prompt_len: int\n    question_len: int\n    output_len: int\n    range_ratio: float\n    seed: int\n    fast_prepare: bool\n    send_routing_key: bool\n    num_turns: int\n    ordered: bool\n\n    @classmethod\n    def from_args(cls, args: Namespace) -> \"GeneratedSharedPrefixDataset\":\n        assert not getattr(args, \"tokenize_prompt\", False)\n        return cls(\n            num_groups=args.gsp_num_groups,\n            prompts_per_group=args.gsp_prompts_per_group,\n            system_prompt_len=args.gsp_system_prompt_len,\n            question_len=args.gsp_question_len,\n            output_len=args.gsp_output_len,\n            range_ratio=getattr(args, \"gsp_range_ratio\", 1.0),\n            seed=args.seed,\n            fast_prepare=getattr(args, \"gsp_fast_prepare\", False),\n            send_routing_key=getattr(args, \"gsp_send_routing_key\", False),\n            num_turns=getattr(args, \"gsp_num_turns\", 1),\n            ordered=getattr(args, \"gsp_ordered\", False),\n        )\n\n    def load(\n        self, tokenizer: PreTrainedTokenizerBase, model_id=None\n    ) -> List[DatasetRow]:\n        return sample_generated_shared_prefix_requests(\n            num_groups=self.num_groups,\n            prompts_per_group=self.prompts_per_group,\n            system_prompt_len=self.system_prompt_len,\n            question_len=self.question_len,\n            output_len=self.output_len,\n            
range_ratio=self.range_ratio,\n            tokenizer=tokenizer,\n            seed=self.seed,\n            send_routing_key=self.send_routing_key,\n            num_turns=self.num_turns,\n            fast_prepare=self.fast_prepare,\n            ordered=self.ordered,\n        )\n\n\ndef get_gen_prefix_cache_path(\n    seed: int,\n    num_groups: int,\n    prompts_per_group: int,\n    system_prompt_len: int,\n    question_len: int,\n    output_len: int,\n    tokenizer,\n):\n    \"\"\"Create cache directory under ~/.cache/sglang/benchmark\"\"\"\n    cache_dir = Path.home() / \".cache\" / \"sglang\" / \"benchmark\"\n\n    cache_key = (\n        f\"gen_shared_prefix_{seed}_{num_groups}_{prompts_per_group}_\"\n        f\"{system_prompt_len}_{question_len}_{output_len}_\"\n        f\"{tokenizer.__class__.__name__}.pkl\"\n    )\n    return cache_dir / cache_key\n\n\ndef sample_generated_shared_prefix_requests(\n    num_groups: int,\n    prompts_per_group: int,\n    system_prompt_len: int,\n    question_len: int,\n    output_len: int,\n    range_ratio: float,\n    tokenizer: PreTrainedTokenizerBase,\n    seed: int,\n    send_routing_key: bool = False,\n    num_turns: int = 1,\n    fast_prepare: bool = False,\n    ordered: bool = False,\n) -> List[DatasetRow]:\n    \"\"\"Generate benchmark requests with shared system prompts using random tokens and caching.\"\"\"\n    cache_path = get_gen_prefix_cache_path(\n        seed,\n        num_groups,\n        prompts_per_group,\n        system_prompt_len,\n        question_len,\n        output_len,\n        tokenizer,\n    )\n    should_cache = (range_ratio == 1) and not send_routing_key and num_turns == 1\n\n    # Try to load from cache first\n    if cache_path.exists() and should_cache:\n        print(f\"\\nLoading cached generated input data from {cache_path}\")\n        with open(cache_path, \"rb\") as f:\n            return pickle.load(f)\n\n    print(\n        f\"\\nGenerating new input data... 
\"\n        f\"({num_groups=}, {prompts_per_group}, {system_prompt_len=}, {question_len=}, {output_len=}, {range_ratio=}, {num_turns=})\"\n    )\n\n    run_random_str = uuid.uuid4().hex[:8]\n    run_start_timestamp = datetime.now().strftime(\"%Y%m%d%H%M%S\")\n\n    system_prompt_lens = compute_random_lens(\n        full_len=system_prompt_len,\n        range_ratio=range_ratio,\n        num=num_groups,\n    )\n    question_lens = np.array(\n        compute_random_lens(\n            full_len=question_len,\n            range_ratio=range_ratio,\n            num=num_groups * prompts_per_group * num_turns,\n        )\n    ).reshape(num_groups, prompts_per_group, num_turns)\n    output_lens = np.array(\n        compute_random_lens(\n            full_len=output_len,\n            range_ratio=range_ratio,\n            num=num_groups * prompts_per_group,\n        )\n    ).reshape(num_groups, prompts_per_group)\n    del system_prompt_len, question_len, output_len\n\n    # Generate system prompts for each group\n    system_prompts = [\n        gen_prompt(tokenizer, system_prompt_lens[i]) for i in range(num_groups)\n    ]\n\n    # Generate questions: shape (num_groups, prompts_per_group, num_turns)\n    questions = [\n        [\n            [\n                gen_prompt(tokenizer, int(question_lens[g, p, t]))\n                for t in range(num_turns)\n            ]\n            for p in range(prompts_per_group)\n        ]\n        for g in range(num_groups)\n    ]\n\n    # Combine system prompts with questions\n    input_requests = []\n    total_input_tokens = 0\n    total_output_tokens = 0\n\n    for group_idx in tqdm(range(num_groups), desc=\"Generating system prompt\"):\n        system_prompt = system_prompts[group_idx]\n        routing_key = (\n            f\"{run_random_str}_{run_start_timestamp}_{group_idx}\"\n            if send_routing_key\n            else None\n        )\n        for prompt_idx in tqdm(\n            range(prompts_per_group), desc=\"Generating 
questions\", leave=False\n        ):\n            turn_questions = questions[group_idx][prompt_idx]\n            turn_prompts = [f\"{system_prompt}\\n\\n{turn_questions[0]}\"] + turn_questions[\n                1:\n            ]\n            full_prompt = turn_prompts[0] if num_turns == 1 else turn_prompts\n            prompt_len = 1 if fast_prepare else len(tokenizer.encode(turn_prompts[0]))\n            output_len_val = int(output_lens[group_idx, prompt_idx])\n\n            input_requests.append(\n                DatasetRow(\n                    prompt=full_prompt,\n                    prompt_len=prompt_len,\n                    output_len=output_len_val,\n                    routing_key=routing_key,\n                )\n            )\n            total_input_tokens += prompt_len\n            total_output_tokens += output_len_val\n\n    if not ordered:\n        random.shuffle(input_requests)\n\n    # Print statistics\n    print(f\"\\nGenerated shared prefix dataset statistics:\")\n    print(f\"Number of groups: {num_groups}\")\n    print(f\"Prompts per group: {prompts_per_group}\")\n    print(f\"Number of turns: {num_turns}\")\n    print(f\"Total prompts: {len(input_requests)}\")\n    if not fast_prepare:\n        print(f\"Total input tokens: {total_input_tokens}\")\n        print(f\"Total output tokens: {total_output_tokens}\")\n        print(\n            f\"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens\"\n        )\n        all_questions = [q for group in questions for conv in group for q in conv]\n        print(\n            f\"Average question length: {sum(len(tokenizer.encode(q)) for q in all_questions) / len(all_questions):.1f} tokens\\n\"\n        )\n\n    # Save to cache\n    if should_cache:\n        cache_path.parent.mkdir(parents=True, exist_ok=True)\n        print(f\"Caching generated input data to {cache_path}\")\n        with open(cache_path, \"wb\") as f:\n            
pickle.dump(input_requests, f)\n\n    return input_requests\n"
  },
  {
    "path": "python/sglang/benchmark/datasets/image.py",
    "content": "import io\nimport warnings\nfrom argparse import Namespace\nfrom dataclasses import dataclass\nfrom typing import List, Tuple\n\nimport numpy as np\nimport pybase64\nfrom PIL import Image\nfrom transformers import AutoProcessor\n\nfrom sglang.benchmark.datasets.common import (\n    BaseDataset,\n    DatasetRow,\n    compute_random_lens,\n    gen_mm_prompt,\n)\nfrom sglang.benchmark.utils import get_processor\n\n\n@dataclass\nclass ImageDataset(BaseDataset):\n    num_requests: int\n    image_count: int\n    input_len: int\n    output_len: int\n    range_ratio: float\n    image_content: str\n    image_format: str\n    image_resolution: str\n    backend: str\n    random_image_count: bool\n\n    @classmethod\n    def from_args(cls, args: Namespace) -> \"ImageDataset\":\n        return cls(\n            num_requests=args.num_prompts,\n            image_count=args.image_count,\n            input_len=args.random_input_len,\n            output_len=args.random_output_len,\n            range_ratio=args.random_range_ratio,\n            image_content=args.image_content,\n            image_format=args.image_format,\n            image_resolution=args.image_resolution,\n            backend=args.backend,\n            random_image_count=args.random_image_count,\n        )\n\n    def load(self, tokenizer=None, model_id=None) -> List[DatasetRow]:\n        processor = get_processor(model_id)\n        return sample_image_requests(\n            num_requests=self.num_requests,\n            image_count=self.image_count,\n            input_len=self.input_len,\n            output_len=self.output_len,\n            range_ratio=self.range_ratio,\n            processor=processor,\n            image_content=self.image_content,\n            image_format=self.image_format,\n            image_resolution=self.image_resolution,\n            backend=self.backend,\n            random_image_count=self.random_image_count,\n        )\n\n\ndef parse_image_resolution(image_resolution: str) 
-> Tuple[int, int]:\n    \"\"\"Parse image resolution into (width, height).\n\n    Supports presets '4k', '1080p', '720p', '360p' and custom 'heightxwidth' format\n    (e.g., '1080x1920' means height=1080, width=1920).\n    \"\"\"\n    resolution_to_size = {\n        \"4k\": (3840, 2160),\n        \"1080p\": (1920, 1080),\n        \"720p\": (1280, 720),\n        \"360p\": (640, 360),\n    }\n    if image_resolution in resolution_to_size:\n        return resolution_to_size[image_resolution]\n\n    res = image_resolution.strip().lower()\n    if \"x\" in res:\n        parts = res.split(\"x\")\n        if len(parts) == 2 and parts[0].isdigit() and parts[1].isdigit():\n            height = int(parts[0])\n            width = int(parts[1])\n            if height > 0 and width > 0:\n                return (width, height)\n\n    raise ValueError(\n        f\"Unsupported image resolution: {image_resolution}. \"\n        \"Choose from 4k, 1080p, 720p, 360p, or provide custom 'heightxwidth' (e.g., 1080x1920).\"\n    )\n\n\ndef create_mm_data_row(\n    text_prompt, images: list, images_base64, output_len, processor, backend\n):\n    try:\n        if type(processor).__name__ == \"Phi4MMProcessor\":\n            # <|endoftext10|> is the image token used in the phi-4-multimodal model.\n            content_items = text_prompt.replace(\"image 1\", \"|endoftext10|\")\n        else:\n            content_items = [\n                {\"type\": \"image\", \"image\": {\"url\": image_base64}}\n                for image_base64 in images_base64\n            ]\n            content_items.append({\"type\": \"text\", \"text\": text_prompt})\n        prompt_str = processor.apply_chat_template(\n            [{\"role\": \"user\", \"content\": content_items}],\n            add_generation_prompt=True,\n            tokenize=False,\n        )\n    except Exception as e:\n        # Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. 
InternVL)\n        print(f\"Error applying chat template: {e}, fallback to <image> tag\")\n        # Some tokenizers do not support list content; fall back to a placeholder in the text\n        prompt_str = f\"<image>{text_prompt}\"\n\n    # Calculate total tokens (text + vision)\n    prompt_len = processor(\n        text=[prompt_str],\n        images=images,\n        padding=False,\n        return_tensors=\"pt\",\n    )[\"input_ids\"].numel()\n\n    # Calculate text-only tokens\n    try:\n        # Create text-only version of the prompt\n        text_only_prompt = processor.apply_chat_template(\n            [{\"role\": \"user\", \"content\": text_prompt}],\n            add_generation_prompt=True,\n            tokenize=False,\n        )\n        text_prompt_len = processor(\n            text=[text_only_prompt],\n            padding=False,\n            return_tensors=\"pt\",\n        )[\"input_ids\"].numel()\n    except Exception:\n        # Fallback: just tokenize the text prompt directly\n        tokenizer_to_use = (\n            processor.tokenizer if hasattr(processor, \"tokenizer\") else processor\n        )\n        text_prompt_len = len(tokenizer_to_use.encode(text_prompt))\n\n    # Vision tokens = total tokens - text tokens\n    vision_prompt_len = prompt_len - text_prompt_len\n\n    use_raw_prompt = backend in [\n        \"sglang\",\n        \"sglang-oai\",\n        \"sglang-oai-chat\",\n        \"vllm\",\n        \"vllm-chat\",\n        \"lmdeploy\",\n        \"lmdeploy-chat\",\n    ]\n    return DatasetRow(\n        prompt=text_prompt if use_raw_prompt else prompt_str,\n        prompt_len=prompt_len,\n        output_len=output_len,\n        text_prompt_len=text_prompt_len,\n        vision_prompt_len=vision_prompt_len,\n        image_data=images_base64,\n    )\n\n\ndef sample_image_requests(\n    num_requests: int,\n    image_count: int,\n    input_len: int,\n    output_len: int,\n    range_ratio: float,\n    processor: AutoProcessor,\n    image_content: 
str,\n    image_format: str,\n    image_resolution: str,\n    backend: str,\n    random_image_count: bool = False,\n) -> List[DatasetRow]:\n    \"\"\"Generate requests with images.\n\n    - If ``random_image_count`` is True, each request includes a random number of images between 1 and ``image_count``.\n    - If ``random_image_count`` is False, each request includes exactly ``image_count`` images.\n    - Supported resolutions: 4k (3840x2160), 1080p (1920x1080), 720p (1280x720), 360p (640x360),\n      or custom 'heightxwidth' (e.g., 1080x1920).\n    - Text lengths follow the 'random' dataset sampling rule. ``prompt_len``\n      counts text plus vision tokens; ``text_prompt_len`` and ``vision_prompt_len``\n      record the split between the two.\n    \"\"\"\n\n    # Parse resolution (supports presets and 'heightxwidth')\n    width, height = parse_image_resolution(image_resolution)\n\n    # Determine image counts for each request\n    if random_image_count:\n        # Random number of images per request\n        image_counts = np.random.randint(1, image_count + 1, size=num_requests)\n        total_images = np.sum(image_counts)\n    else:\n        # Fixed number of images per request\n        image_counts = np.full(num_requests, image_count)\n        total_images = image_count * num_requests\n\n    # Check for potentially problematic combinations and warn user\n    if width * height >= 1920 * 1080 and total_images >= 100:\n        warnings.warn(\n            f\"High resolution ({width}x{height}) with {total_images} total images \"\n            f\"may take a long time.
Consider reducing resolution or image count.\",\n            UserWarning,\n            stacklevel=2,\n        )\n\n    # Sample text lengths\n    input_lens = compute_random_lens(\n        full_len=input_len,\n        range_ratio=range_ratio,\n        num=num_requests,\n    )\n    output_lens = compute_random_lens(\n        full_len=output_len,\n        range_ratio=range_ratio,\n        num=num_requests,\n    )\n\n    def _gen_random_image_data_uri(\n        width: int = width, height: int = height\n    ) -> Tuple[Image.Image, str, int]:\n        if image_content == \"blank\":\n            # Generate blank white image\n            arr = np.full((height, width, 3), 255, dtype=np.uint8)\n        else:\n            # Generate random colored image\n            arr = (np.random.rand(height, width, 3) * 255).astype(np.uint8)\n        img = Image.fromarray(arr)\n        buf = io.BytesIO()\n        img.save(buf, format=image_format, quality=85)\n        encoded = pybase64.b64encode(buf.getvalue()).decode(\"utf-8\")\n        image_data = f\"data:image/{image_format};base64,{encoded}\"\n        image_bytes = len(image_data.encode(\"utf-8\"))\n        return img, image_data, image_bytes\n\n    dataset: List[DatasetRow] = []\n    total_image_bytes = 0\n    for i in range(num_requests):\n        # Get the number of images for this request\n        request_image_count = int(image_counts[i])\n\n        # Generate text prompt\n        text_prompt = gen_mm_prompt(\n            processor.tokenizer,\n            processor.image_token_id if hasattr(processor, \"image_token_id\") else None,\n            int(input_lens[i]),\n        )\n\n        # Generate image list\n        images, images_base64, images_bytes = zip(\n            *[_gen_random_image_data_uri() for _ in range(request_image_count)]\n        )\n        total_image_bytes += sum(images_bytes)\n\n        data_row = create_mm_data_row(\n            text_prompt,\n            list(images),\n            list(images_base64),\n    
        int(output_lens[i]),\n            processor,\n            backend,\n        )\n        dataset.append(data_row)\n\n    # Print statistics\n    print(f\"#Input tokens: {np.sum([x.prompt_len for x in dataset])}\")\n    print(f\"#Output tokens: {np.sum([x.output_len for x in dataset])}\")\n    print(f\"#Total images: {total_images}\")\n\n    if random_image_count:\n        print(\n            f\"#Images per request: min={np.min(image_counts)}, max={np.max(image_counts)}, mean={np.mean(image_counts):.2f}\"\n        )\n    else:\n        print(f\"#Images per request: {image_count} (fixed)\")\n\n    print(\n        f\"\\nCreated {len(dataset)} {image_content} {image_format} images with average {total_image_bytes // num_requests} bytes per request\"\n    )\n    return dataset\n"
  },
  {
    "path": "python/sglang/benchmark/datasets/mmmu.py",
    "content": "import io\nimport random\nfrom argparse import Namespace\nfrom dataclasses import dataclass\nfrom typing import List, Optional\n\nimport pybase64\nfrom datasets import load_dataset\nfrom transformers import AutoProcessor, AutoTokenizer\n\nfrom sglang.benchmark.datasets.common import BaseDataset, DatasetRow\nfrom sglang.benchmark.datasets.image import create_mm_data_row\nfrom sglang.benchmark.utils import get_processor\n\n\n@dataclass\nclass MMMUDataset(BaseDataset):\n    num_requests: int\n    backend: str\n    fixed_output_len: Optional[int]\n\n    @classmethod\n    def from_args(cls, args: Namespace) -> \"MMMUDataset\":\n        return cls(\n            num_requests=args.num_prompts,\n            backend=args.backend,\n            fixed_output_len=args.random_output_len,\n        )\n\n    def load(self, tokenizer=None, model_id=None) -> List[DatasetRow]:\n        processor = get_processor(model_id)\n        return sample_mmmu_requests(\n            num_requests=self.num_requests,\n            processor=processor,\n            backend=self.backend,\n            fixed_output_len=self.fixed_output_len,\n        )\n\n\ndef sample_mmmu_requests(\n    num_requests: int,\n    processor: AutoProcessor | AutoTokenizer,\n    backend: str = \"sglang\",\n    fixed_output_len: Optional[int] = None,\n    random_sample: bool = True,\n) -> List[DatasetRow]:\n    \"\"\"\n    Sample requests from the MMMU dataset using HuggingFace datasets.\n\n    Args:\n        num_requests: Number of requests to sample.\n        fixed_output_len: If provided, use this fixed output length for all requests.\n        random_sample: Whether to randomly sample or take the first N.\n\n    Returns:\n        List of tuples (prompt, prompt_token_len, output_token_len).\n    \"\"\"\n    print(\"Loading MMMU dataset from HuggingFace...\")\n\n    try:\n        print(\"Attempting to load MMMU Math dataset...\")\n        mmmu_dataset = load_dataset(\"MMMU/MMMU\", \"Math\", split=\"test\")\n    
    print(\n            f\"Successfully loaded MMMU Math dataset from HuggingFace with {len(mmmu_dataset)} examples\"\n        )\n    except Exception as e:\n        print(f\"Failed to load MMMU Math dataset: {e}\")\n        raise ValueError(f\"Failed to load MMMU dataset: {e}\")\n\n    # Sample from the dataset\n    if len(mmmu_dataset) > num_requests:\n        if random_sample:\n            # Random sample\n            indices = random.sample(range(len(mmmu_dataset)), num_requests)\n            sample_dataset = mmmu_dataset.select(indices)\n        else:\n            # Take first N\n            sample_dataset = mmmu_dataset.select(\n                range(min(num_requests, len(mmmu_dataset)))\n            )\n    else:\n        print(f\"Dataset has less than {num_requests} examples, using all examples\")\n        sample_dataset = mmmu_dataset\n\n    print(f\"Selected {len(sample_dataset)} examples for benchmarking\")\n\n    # Create prompts\n    filtered_dataset = []\n\n    for i, example in enumerate(sample_dataset):\n        try:\n            # Extract image_1\n            image = example.get(\"image_1\")\n\n            if image is not None:\n                if hasattr(image, \"save\"):\n                    # Convert RGBA images to RGB before encoding\n                    if image.mode == \"RGBA\":\n                        image = image.convert(\"RGB\")\n\n                    # Encode image to base64 (save as PNG to support palette/alpha modes)\n                    buffered = io.BytesIO()\n                    image.save(buffered, format=\"PNG\")\n                    img_str = pybase64.b64encode(buffered.getvalue()).decode(\"utf-8\")\n                    image_data = f\"data:image/png;base64,{img_str}\"\n                else:\n                    continue\n\n                # Extract the question\n                question = example.get(\"question\")\n\n                # Construct the prompt\n                text_prompt = f\"Question: {question}\\n\\nAnswer: \"\n   
             output_len = fixed_output_len if fixed_output_len is not None else 256\n                data_row = create_mm_data_row(\n                    text_prompt, [image], [image_data], output_len, processor, backend\n                )\n                filtered_dataset.append(data_row)\n\n        except Exception as e:\n            print(f\"Error processing example {i}: {e}\")\n\n    print(f\"\\nCreated {len(filtered_dataset)} MMMU prompts\")\n    return filtered_dataset\n"
  },
  {
    "path": "python/sglang/benchmark/datasets/mooncake.py",
    "content": "import asyncio\nimport json\nimport os\nimport time\nfrom argparse import Namespace\nfrom dataclasses import dataclass\nfrom typing import AsyncGenerator, Dict, List\n\nfrom transformers import PreTrainedTokenizerBase\n\nfrom sglang.benchmark.datasets.common import (\n    MOONCAKE_DATASET_URL,\n    BaseDataset,\n    DatasetRow,\n)\nfrom sglang.benchmark.utils import download_and_cache_file\n\n\n@dataclass\nclass MooncakeDataset(BaseDataset):\n    dataset_path: str\n    mooncake_workload: str\n    num_requests: int\n\n    @classmethod\n    def from_args(cls, args: Namespace) -> \"MooncakeDataset\":\n        return cls(\n            dataset_path=args.dataset_path,\n            mooncake_workload=args.mooncake_workload,\n            num_requests=args.num_prompts,\n        )\n\n    def load(self, tokenizer=None, model_id=None) -> List[Dict]:\n        if not self.dataset_path:\n            local_path = os.path.join(\"/tmp\", self.mooncake_workload + \"_trace.jsonl\")\n        else:\n            local_path = self.dataset_path\n\n        if not os.path.exists(local_path):\n            download_and_cache_file(\n                MOONCAKE_DATASET_URL[self.mooncake_workload], local_path\n            )\n\n        with open(local_path, \"r\") as f:\n            all_requests_data = [json.loads(line) for line in f if line.strip()]\n\n        return all_requests_data[: self.num_requests]\n\n\nasync def get_mooncake_request_over_time(\n    input_requests: List[Dict],\n    tokenizer: PreTrainedTokenizerBase,\n    slowdown_factor: float,\n    num_rounds: int,\n) -> AsyncGenerator[DatasetRow, None]:\n    \"\"\"\n    An async generator that yields requests based on the timestamps in the Mooncake trace file,\n    with support for multi-round sessions.\n    \"\"\"\n    if not input_requests:\n        return\n\n    input_requests.sort(key=lambda r: r[\"timestamp\"])\n\n    start_time = time.perf_counter()\n    trace_start_time_ms = input_requests[0][\"timestamp\"]\n\n    for 
record in input_requests:\n        # Calculate when this entire session should start\n        relative_arrival_time_s = (record[\"timestamp\"] - trace_start_time_ms) / 1000.0\n        target_arrival_time_s = relative_arrival_time_s * slowdown_factor\n\n        current_elapsed_time_s = time.perf_counter() - start_time\n        sleep_duration_s = target_arrival_time_s - current_elapsed_time_s\n        if sleep_duration_s > 0:\n            await asyncio.sleep(sleep_duration_s)\n\n        # Once the session starts, generate all rounds for it as a burst\n        # This simulates a user engaging in a multi-turn conversation\n\n        # Base user query constructed from hash_ids\n        user_query_base = \"\"\n        hash_ids = record.get(\"hash_ids\", [])\n        for hash_id in hash_ids:\n            user_query_base += f\"{hash_id}\" + \" \".join(\n                [\"hi\"] * 128\n            )  # Shorter for multi-round\n        user_query_base += \"Tell me a story based on this context.\"\n\n        output_len_per_round = record.get(\"output_length\", 256)\n        chat_history = []\n\n        for i in range(num_rounds):\n            # Add user query for the current round\n            chat_history.append(\n                {\"role\": \"user\", \"content\": f\"Round {i + 1}: {user_query_base}\"}\n            )\n\n            # Form the full prompt from history\n            try:\n                full_prompt_text = tokenizer.apply_chat_template(\n                    chat_history,\n                    tokenize=False,\n                    add_generation_prompt=True,\n                    return_dict=False,\n                )\n            except Exception:\n                full_prompt_text = \"\\n\".join(\n                    [f\"{msg['role']}: {msg['content']}\" for msg in chat_history]\n                )\n\n            prompt_len = len(tokenizer.encode(full_prompt_text))\n\n            yield DatasetRow(\n                prompt=full_prompt_text,\n                
prompt_len=prompt_len,\n                output_len=output_len_per_round,\n            )\n\n            # Add a placeholder assistant response for the next round's context\n            # We use a placeholder because we don't know the real response\n            placeholder_response = \" \".join([\"story\"] * output_len_per_round)\n            chat_history.append({\"role\": \"assistant\", \"content\": placeholder_response})\n"
  },
  {
    "path": "python/sglang/benchmark/datasets/openai_dataset.py",
    "content": "import json\nfrom argparse import Namespace\nfrom dataclasses import dataclass\nfrom typing import List, Optional\n\nimport numpy as np\nfrom transformers import PreTrainedTokenizerBase\n\nfrom sglang.benchmark.datasets.common import BaseDataset, DatasetRow\n\n\n@dataclass\nclass OpenAIDataset(BaseDataset):\n    dataset_path: str\n    num_requests: int\n    fixed_output_len: Optional[int]\n\n    @classmethod\n    def from_args(cls, args: Namespace) -> \"OpenAIDataset\":\n        return cls(\n            dataset_path=args.dataset_path,\n            num_requests=args.num_prompts,\n            fixed_output_len=args.sharegpt_output_len,\n        )\n\n    def load(\n        self, tokenizer: PreTrainedTokenizerBase, model_id=None\n    ) -> List[DatasetRow]:\n        return sample_openai_requests(\n            dataset_path=self.dataset_path,\n            num_requests=self.num_requests,\n            tokenizer=tokenizer,\n            fixed_output_len=self.fixed_output_len,\n        )\n\n\ndef sample_openai_requests(\n    dataset_path: str,\n    num_requests: int,\n    tokenizer: PreTrainedTokenizerBase,\n    fixed_output_len: Optional[int] = None,\n) -> List[DatasetRow]:\n    \"\"\"\n    Load OpenAI-compatible chat completion requests from a JSONL file.\n\n    Each line should be a JSON object with:\n    - \"messages\": list of {\"role\": str, \"content\": str}\n    - \"max_tokens\": int (used as output_len if fixed_output_len not set)\n    - \"tools\": optional list of tool definitions\n    - \"temperature\": optional temperature value\n    - \"top_p\": optional top_p value\n    - Other OpenAI API parameters are also extracted and passed through\n    \"\"\"\n    dataset = []\n    with open(dataset_path, \"r\") as f:\n        for line in f:\n            if num_requests > 0 and len(dataset) >= num_requests:\n                break\n            if line.strip():\n                try:\n                    dataset.append(json.loads(line))\n                except 
json.JSONDecodeError:\n                    # Skip invalid JSON lines\n                    continue\n\n    # Fields that should NOT be passed through extra_request_body\n    # These are either handled separately or are metadata\n    # max_tokens is excluded because it's handled via output_len -> max_completion_tokens\n    # max_completion_tokens is also excluded to avoid conflicts\n    EXCLUDED_FIELDS = {\"messages\", \"max_tokens\", \"max_completion_tokens\", \"model\"}\n\n    filtered_dataset: List[DatasetRow] = []\n    for data in dataset:\n        messages = data.get(\"messages\", [])\n        if not messages:\n            continue\n\n        # Use max_tokens from the request, or fall back to fixed_output_len\n        output_len = fixed_output_len or data.get(\"max_tokens\", 256)\n\n        # Extract extra request body parameters (tools, temperature, top_p, etc.)\n        extra_body = {k: v for k, v in data.items() if k not in EXCLUDED_FIELDS}\n\n        # Calculate prompt length by applying chat template\n        # This includes the messages but not the tools\n        prompt_len = len(\n            tokenizer.apply_chat_template(\n                messages, tokenize=True, add_generation_prompt=True\n            )\n        )\n\n        # If tools are present, we need to add their token count\n        # Tools are sent as part of the request and count toward input tokens\n        if \"tools\" in extra_body:\n            # Encode tools as JSON string to estimate token count\n            tools_str = json.dumps(extra_body[\"tools\"])\n            tools_tokens = len(tokenizer.encode(tools_str))\n            prompt_len += tools_tokens\n\n        # Pass messages list directly - bench_serving handles List[Dict] prompts\n        filtered_dataset.append(\n            DatasetRow(\n                prompt=messages,\n                prompt_len=prompt_len,\n                output_len=output_len,\n                extra_request_body=extra_body,  # Store per-request parameters\n     
       )\n        )\n\n    print(f\"Loaded {len(filtered_dataset)} OpenAI-format requests\")\n    print(f\"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}\")\n    print(f\"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}\")\n    return filtered_dataset\n"
  },
  {
    "path": "python/sglang/benchmark/datasets/random.py",
    "content": "import json\nimport random\nfrom argparse import Namespace\nfrom dataclasses import dataclass\nfrom typing import List\n\nimport numpy as np\nfrom transformers import PreTrainedTokenizerBase\n\nfrom sglang.benchmark.datasets.common import (\n    SHAREGPT_FILENAME,\n    SHAREGPT_REPO_ID,\n    BaseDataset,\n    DatasetRow,\n    compute_random_lens,\n)\nfrom sglang.benchmark.utils import download_and_cache_hf_file, is_file_valid_json\n\n\n@dataclass\nclass RandomDataset(BaseDataset):\n    input_len: int\n    output_len: int\n    num_requests: int\n    range_ratio: float\n    dataset_path: str\n    return_text: bool\n    random_sample: bool\n\n    @classmethod\n    def from_args(cls, args: Namespace) -> \"RandomDataset\":\n        return cls(\n            input_len=args.random_input_len,\n            output_len=args.random_output_len,\n            num_requests=args.num_prompts,\n            range_ratio=args.random_range_ratio,\n            dataset_path=args.dataset_path,\n            return_text=not getattr(args, \"tokenize_prompt\", False),\n            random_sample=(args.dataset_name == \"random\"),\n        )\n\n    def load(\n        self, tokenizer: PreTrainedTokenizerBase, model_id=None\n    ) -> List[DatasetRow]:\n        return sample_random_requests(\n            input_len=self.input_len,\n            output_len=self.output_len,\n            num_prompts=self.num_requests,\n            range_ratio=self.range_ratio,\n            tokenizer=tokenizer,\n            dataset_path=self.dataset_path,\n            random_sample=self.random_sample,\n            return_text=self.return_text,\n        )\n\n\ndef sample_random_requests(\n    input_len: int,\n    output_len: int,\n    num_prompts: int,\n    range_ratio: float,\n    tokenizer: PreTrainedTokenizerBase,\n    dataset_path: str,\n    random_sample: bool = True,\n    return_text: bool = True,\n) -> List[DatasetRow]:\n    input_lens = compute_random_lens(\n        full_len=input_len,\n        
range_ratio=range_ratio,\n        num=num_prompts,\n    )\n    output_lens = compute_random_lens(\n        full_len=output_len,\n        range_ratio=range_ratio,\n        num=num_prompts,\n    )\n\n    if return_text:\n        # Need to truncate input_len as server encode will add special token.\n        num_special_tokens = int(tokenizer.num_special_tokens_to_add())\n        for i in range(num_prompts):\n            input_lens[i] = max(0, input_lens[i] - num_special_tokens)\n\n    if random_sample:\n        # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens\n\n        # Download sharegpt if necessary\n        if not is_file_valid_json(dataset_path):\n            dataset_path = download_and_cache_hf_file(\n                repo_id=SHAREGPT_REPO_ID,\n                filename=SHAREGPT_FILENAME,\n            )\n\n        # Load the dataset.\n        with open(dataset_path) as f:\n            dataset = json.load(f)\n        # Filter out the conversations with less than 2 turns.\n        dataset = [\n            data\n            for data in dataset\n            if len(data.get(\"conversations\", data.get(\"conversation\", []))) >= 2\n        ]\n        # Only keep the first two turns of each conversation.\n        dataset = [\n            (\n                data.get(\"conversations\", data.get(\"conversation\", []))[0][\"value\"],\n                data.get(\"conversations\", data.get(\"conversation\", []))[1][\"value\"],\n            )\n            for data in dataset\n        ]\n        # Shuffle the dataset.\n        random.shuffle(dataset)\n\n        # Filter out sequences that are too long or too short\n        input_requests: List[DatasetRow] = []\n        for data in dataset:\n            i = len(input_requests)\n            if i == num_prompts:\n                break\n\n            # Tokenize the prompts and completions.\n            prompt = data[0]\n            prompt_token_ids = tokenizer.encode(prompt)\n            prompt_len 
= len(prompt_token_ids)\n\n            # Skip empty prompt\n            if prompt_len == 0:\n                continue\n\n            if prompt_len > input_lens[i]:\n                input_ids = prompt_token_ids[: input_lens[i]]\n            else:\n                ratio = (input_lens[i] + prompt_len - 1) // prompt_len\n                input_ids = (prompt_token_ids * ratio)[: input_lens[i]]\n            input_content = input_ids\n            if return_text:\n                input_content = tokenizer.decode(input_content)\n            input_requests.append(\n                DatasetRow(\n                    prompt=input_content,\n                    prompt_len=input_lens[i],\n                    output_len=output_lens[i],\n                )\n            )\n    else:\n        # Sample token ids from random integers. This can cause some NaN issues.\n        offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)\n        input_requests = []\n        for i in range(num_prompts):\n            # Use int() to convert numpy.int64 to native Python int for JSON serialization\n            input_content = [\n                int((offsets[i] + i + j) % tokenizer.vocab_size)\n                for j in range(input_lens[i])\n            ]\n            if return_text:\n                input_content = tokenizer.decode(input_content)\n            input_requests.append(\n                DatasetRow(\n                    prompt=input_content,\n                    prompt_len=input_lens[i],\n                    output_len=output_lens[i],\n                )\n            )\n\n    print(f\"#Input tokens: {np.sum(input_lens)}\")\n    print(f\"#Output tokens: {np.sum(output_lens)}\")\n    return input_requests\n"
  },
  {
    "path": "python/sglang/benchmark/datasets/sharegpt.py",
    "content": "import json\nimport random\nfrom argparse import Namespace\nfrom dataclasses import dataclass\nfrom typing import List, Optional\n\nimport numpy as np\nfrom transformers import PreTrainedTokenizerBase\n\nfrom sglang.benchmark.datasets.common import (\n    ASSISTANT_SUFFIX,\n    SHAREGPT_FILENAME,\n    SHAREGPT_REPO_ID,\n    BaseDataset,\n    DatasetRow,\n)\nfrom sglang.benchmark.utils import (\n    download_and_cache_hf_file,\n    is_file_valid_json,\n    remove_suffix,\n)\n\n\n@dataclass\nclass ShareGPTDataset(BaseDataset):\n    dataset_path: str\n    num_requests: int\n    fixed_output_len: Optional[int]\n    context_len: Optional[int]\n    prompt_suffix: str\n    apply_chat_template: bool\n\n    @classmethod\n    def from_args(cls, args: Namespace) -> \"ShareGPTDataset\":\n        assert not getattr(args, \"tokenize_prompt\", False)\n        return cls(\n            dataset_path=args.dataset_path,\n            num_requests=args.num_prompts,\n            fixed_output_len=args.sharegpt_output_len,\n            context_len=args.sharegpt_context_len,\n            prompt_suffix=args.prompt_suffix,\n            apply_chat_template=args.apply_chat_template,\n        )\n\n    def load(\n        self, tokenizer: PreTrainedTokenizerBase, model_id=None\n    ) -> List[DatasetRow]:\n        return sample_sharegpt_requests(\n            dataset_path=self.dataset_path,\n            num_requests=self.num_requests,\n            tokenizer=tokenizer,\n            fixed_output_len=self.fixed_output_len,\n            context_len=self.context_len,\n            prompt_suffix=self.prompt_suffix,\n            apply_chat_template=self.apply_chat_template,\n        )\n\n\ndef sample_sharegpt_requests(\n    dataset_path: str,\n    num_requests: int,\n    tokenizer: PreTrainedTokenizerBase,\n    fixed_output_len: Optional[int] = None,\n    context_len: Optional[int] = None,\n    prompt_suffix: Optional[str] = \"\",\n    apply_chat_template=False,\n) -> List[DatasetRow]:\n    
if fixed_output_len is not None and fixed_output_len < 4:\n        raise ValueError(\"output_len too small\")\n\n    # Download sharegpt if necessary\n    if not is_file_valid_json(dataset_path) and dataset_path == \"\":\n        dataset_path = download_and_cache_hf_file(\n            repo_id=SHAREGPT_REPO_ID,\n            filename=SHAREGPT_FILENAME,\n        )\n\n    # Load the dataset.\n    with open(dataset_path) as f:\n        dataset = json.load(f)\n\n    # Filter out the conversations with less than 2 turns.\n    dataset = [\n        data\n        for data in dataset\n        if len(data.get(\"conversations\", data.get(\"conversation\", []))) >= 2\n    ]\n    # Only keep the first two turns of each conversation.\n    dataset = [\n        (\n            data.get(\"conversations\", data.get(\"conversation\", []))[0][\"value\"],\n            data.get(\"conversations\", data.get(\"conversation\", []))[1][\"value\"],\n        )\n        for data in dataset\n    ]\n\n    # Shuffle the dataset.\n    random.shuffle(dataset)\n\n    # Filter out sequences that are too long or too short\n    filtered_dataset: List[DatasetRow] = []\n    for i in range(len(dataset)):\n        if len(filtered_dataset) == num_requests:\n            break\n\n        # Tokenize the prompts and completions.\n        prompt = dataset[i][0]\n        if prompt_suffix:\n            prompt = (\n                remove_suffix(prompt, ASSISTANT_SUFFIX)\n                + prompt_suffix\n                + ASSISTANT_SUFFIX\n            )\n\n        if apply_chat_template:\n            prompt = tokenizer.apply_chat_template(\n                [{\"role\": \"user\", \"content\": prompt}],\n                add_generation_prompt=True,\n                tokenize=False,\n                return_dict=False,\n            )\n            if tokenizer.bos_token:\n                prompt = prompt.replace(tokenizer.bos_token, \"\")\n\n        prompt_token_ids = tokenizer.encode(prompt)\n        completion = 
dataset[i][1]\n        completion_token_ids = tokenizer.encode(completion)\n        prompt_len = len(prompt_token_ids)\n        output_len = (\n            len(completion_token_ids) if fixed_output_len is None else fixed_output_len\n        )\n\n        if prompt_len < 2 or output_len < 2:\n            # Prune too short sequences.\n            continue\n\n        if context_len and prompt_len + output_len > context_len:\n            # Prune too long sequences.\n            continue\n\n        filtered_dataset.append(\n            DatasetRow(\n                prompt=prompt,\n                prompt_len=prompt_len,\n                output_len=output_len,\n            )\n        )\n\n    print(f\"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}\")\n    print(f\"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}\")\n    return filtered_dataset\n"
  },
  {
    "path": "python/sglang/benchmark/utils.py",
    "content": "import json\nimport os\nimport resource\nfrom json import JSONDecodeError\nfrom typing import Dict, List, Optional, Union\n\nimport requests\nfrom tqdm.asyncio import tqdm\nfrom transformers import (\n    AutoProcessor,\n    AutoTokenizer,\n    PreTrainedTokenizer,\n    PreTrainedTokenizerFast,\n)\n\n\ndef remove_prefix(text: str, prefix: str) -> str:\n    return text[len(prefix) :] if text.startswith(prefix) else text\n\n\ndef remove_suffix(text: str, suffix: str) -> str:\n    return text[: -len(suffix)] if text.endswith(suffix) else text\n\n\ndef parse_custom_headers(header_list: List[str]) -> Dict[str, str]:\n    return {k: v for h in header_list for k, _, v in [h.partition(\"=\")] if k and v}\n\n\ndef get_model(pretrained_model_name_or_path: str) -> str:\n    if os.getenv(\"SGLANG_USE_MODELSCOPE\", \"false\").lower() == \"true\":\n        import huggingface_hub.constants\n        from modelscope import snapshot_download\n\n        model_path = snapshot_download(\n            model_id=pretrained_model_name_or_path,\n            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,\n            ignore_file_pattern=[\".*.pt\", \".*.safetensors\", \".*.bin\"],\n        )\n\n        return model_path\n    return pretrained_model_name_or_path\n\n\ndef get_tokenizer(\n    pretrained_model_name_or_path: str,\n) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:\n    assert (\n        pretrained_model_name_or_path is not None\n        and pretrained_model_name_or_path != \"\"\n    )\n    if pretrained_model_name_or_path.endswith(\n        \".json\"\n    ) or pretrained_model_name_or_path.endswith(\".model\"):\n        from sglang.srt.utils.hf_transformers_utils import get_tokenizer\n\n        return get_tokenizer(pretrained_model_name_or_path)\n\n    if pretrained_model_name_or_path is not None and not os.path.exists(\n        pretrained_model_name_or_path\n    ):\n        pretrained_model_name_or_path = 
get_model(pretrained_model_name_or_path)\n    return AutoTokenizer.from_pretrained(\n        pretrained_model_name_or_path, trust_remote_code=True\n    )\n\n\ndef get_processor(\n    pretrained_model_name_or_path: str,\n) -> AutoProcessor:\n    assert (\n        pretrained_model_name_or_path is not None\n        and pretrained_model_name_or_path != \"\"\n    )\n    if pretrained_model_name_or_path.endswith(\n        \".json\"\n    ) or pretrained_model_name_or_path.endswith(\".model\"):\n        from sglang.srt.utils.hf_transformers_utils import get_processor\n\n        return get_processor(pretrained_model_name_or_path)\n\n    if pretrained_model_name_or_path is not None and not os.path.exists(\n        pretrained_model_name_or_path\n    ):\n        pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)\n    return AutoProcessor.from_pretrained(\n        pretrained_model_name_or_path, trust_remote_code=True\n    )\n\n\ndef download_and_cache_hf_file(\n    repo_id: str,\n    filename: str,\n    repo_type: str = \"dataset\",\n):\n    \"\"\"Download a file from Hugging Face and cache it locally.\"\"\"\n    from huggingface_hub import hf_hub_download\n\n    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type)\n\n\ndef download_and_cache_file(url: str, filename: Optional[str] = None):\n    \"\"\"Read and cache a file from a url.\"\"\"\n    if filename is None:\n        filename = os.path.join(\"/tmp\", url.split(\"/\")[-1])\n\n    # Check if the cache file already exists\n    if is_file_valid_json(filename):\n        return filename\n\n    print(f\"Downloading from {url} to {filename}\")\n\n    # Stream the response to show the progress bar\n    response = requests.get(url, stream=True)\n    response.raise_for_status()  # Check for request errors\n\n    # Total size of the file in bytes\n    total_size = int(response.headers.get(\"content-length\", 0))\n    chunk_size = 1024  # Download in chunks of 1KB\n\n    # Use tqdm to 
display the progress bar\n    with open(filename, \"wb\") as f, tqdm(\n        desc=filename,\n        total=total_size,\n        unit=\"B\",\n        unit_scale=True,\n        unit_divisor=1024,\n    ) as bar:\n        for chunk in response.iter_content(chunk_size=chunk_size):\n            f.write(chunk)\n            bar.update(len(chunk))\n\n    return filename\n\n\ndef is_file_valid_json(path):\n    if not os.path.isfile(path):\n        return False\n\n    # TODO: could be fused into the real file open later\n    try:\n        with open(path) as f:\n            json.load(f)\n        return True\n    except JSONDecodeError as e:\n        print(\n            f\"{path} exists but JSON loading failed ({e=}); treating it as an invalid file\"\n        )\n        return False\n\n\ndef set_ulimit(target_soft_limit=65535):\n    resource_type = resource.RLIMIT_NOFILE\n    current_soft, current_hard = resource.getrlimit(resource_type)\n\n    if current_soft < target_soft_limit:\n        try:\n            resource.setrlimit(resource_type, (target_soft_limit, current_hard))\n        except ValueError as e:\n            print(f\"Failed to set RLIMIT_NOFILE: {e}\")\n"
  },
  {
    "path": "python/sglang/check_env.py",
    "content": "\"\"\"Check environment configurations and dependency versions.\"\"\"\n\nimport importlib.metadata\nimport os\nimport resource\nimport subprocess\nimport sys\nfrom abc import abstractmethod\nfrom collections import OrderedDict, defaultdict\n\nimport torch\n\nfrom sglang.srt.utils import is_hip, is_musa, is_npu\n\n\ndef is_cuda_v2():\n    return torch.version.cuda is not None\n\n\n# Packages whose installed versions are reported\nPACKAGE_LIST = [\n    \"sglang\",\n    \"sglang-kernel\",\n    \"flashinfer_python\",\n    \"flashinfer_cubin\",\n    \"flashinfer_jit_cache\",\n    \"triton\",\n    \"transformers\",\n    \"torchao\",\n    \"numpy\",\n    \"aiohttp\",\n    \"fastapi\",\n    \"huggingface_hub\",\n    \"interegular\",\n    \"modelscope\",\n    \"orjson\",\n    \"outlines\",\n    \"packaging\",\n    \"psutil\",\n    \"pydantic\",\n    \"python-multipart\",\n    \"pyzmq\",\n    \"uvicorn\",\n    \"uvloop\",\n    \"vllm\",\n    \"xgrammar\",\n    \"openai\",\n    \"tiktoken\",\n    \"anthropic\",\n    \"litellm\",\n    \"torchcodec\",\n]\n\n\nclass BaseEnv:\n    \"\"\"Base class for environment checks\"\"\"\n\n    def __init__(self):\n        # Copy so subclasses can extend the list without mutating the module-level PACKAGE_LIST\n        self.package_list = list(PACKAGE_LIST)\n\n    @abstractmethod\n    def get_info(self) -> dict:\n        \"\"\"\n        Get accelerator-related information if available.\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def get_topology(self) -> dict:\n        raise NotImplementedError\n\n    def get_package_versions(self) -> dict:\n        \"\"\"\n        Get versions of specified packages.\n        \"\"\"\n        versions = {}\n        for package in self.package_list:\n            package_name = package.split(\"==\")[0].split(\">=\")[0].split(\"<=\")[0]\n            try:\n                version = importlib.metadata.version(package_name)\n                versions[package_name] = version\n            except importlib.metadata.PackageNotFoundError:\n                versions[package_name] = \"Module Not 
Found\"\n        return versions\n\n    def get_device_info(self):\n        \"\"\"\n        Get information about available GPU devices.\n        \"\"\"\n        devices = defaultdict(list)\n        capabilities = defaultdict(list)\n        for k in range(torch.cuda.device_count()):\n            devices[torch.cuda.get_device_name(k)].append(str(k))\n            capability = torch.cuda.get_device_capability(k)\n            capabilities[f\"{capability[0]}.{capability[1]}\"].append(str(k))\n\n        gpu_info = {}\n        for name, device_ids in devices.items():\n            gpu_info[f\"GPU {','.join(device_ids)}\"] = name\n\n        # One entry per distinct compute capability (a single entry when all GPUs match)\n        for cap, gpu_ids in capabilities.items():\n            gpu_info[f\"GPU {','.join(gpu_ids)} Compute Capability\"] = cap\n\n        return gpu_info\n\n    def get_hypervisor_vendor(self) -> dict:\n        try:\n            output = subprocess.check_output([\"lscpu\"], text=True)\n            for line in output.split(\"\\n\"):\n                if \"Hypervisor vendor:\" in line:\n                    return {\"Hypervisor vendor\": line.split(\":\")[1].strip()}\n            return {}\n        except Exception:\n            return {}\n\n    def get_ulimit_soft(self) -> dict:\n        ulimit_soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)\n        return {\"ulimit soft\": ulimit_soft}\n\n    def check_env(self):\n        \"\"\"\n        Check and print environment information.\n        \"\"\"\n        env_info = OrderedDict()\n        env_info[\"Python\"] = sys.version.replace(\"\\n\", \"\")\n        env_info.update(self.get_info())\n        env_info[\"PyTorch\"] = torch.__version__\n        env_info.update(self.get_package_versions())\n        
env_info.update(self.get_topology())\n        env_info.update(self.get_hypervisor_vendor())\n        env_info.update(self.get_ulimit_soft())\n\n        for k, v in env_info.items():\n            print(f\"{k}: {v}\")\n\n\nclass GPUEnv(BaseEnv):\n    \"\"\"Environment checker for Nvidia GPU\"\"\"\n\n    def get_info(self):\n        cuda_info = {\"CUDA available\": torch.cuda.is_available()}\n\n        if cuda_info[\"CUDA available\"]:\n            cuda_info.update(self.get_device_info())\n            cuda_info.update(self._get_cuda_version_info())\n\n        return cuda_info\n\n    def _get_cuda_version_info(self):\n        \"\"\"\n        Get CUDA version information.\n        \"\"\"\n        from torch.utils.cpp_extension import CUDA_HOME\n\n        cuda_info = {\"CUDA_HOME\": CUDA_HOME}\n\n        if CUDA_HOME and os.path.isdir(CUDA_HOME):\n            cuda_info.update(self._get_nvcc_info())\n            cuda_info.update(self._get_cuda_driver_version())\n\n        return cuda_info\n\n    def _get_nvcc_info(self):\n        \"\"\"\n        Get NVCC version information.\n        \"\"\"\n        from torch.utils.cpp_extension import CUDA_HOME\n\n        try:\n            nvcc = os.path.join(CUDA_HOME, \"bin/nvcc\")\n            nvcc_output = (\n                subprocess.check_output(f'\"{nvcc}\" -V', shell=True)\n                .decode(\"utf-8\")\n                .strip()\n            )\n            return {\n                \"NVCC\": nvcc_output[\n                    nvcc_output.rfind(\"Cuda compilation tools\") : nvcc_output.rfind(\n                        \"Build\"\n                    )\n                ].strip()\n            }\n        except subprocess.SubprocessError:\n            return {\"NVCC\": \"Not Available\"}\n\n    def _get_cuda_driver_version(self):\n        \"\"\"\n        Get CUDA driver version.\n        \"\"\"\n        versions = set()\n        try:\n            output = subprocess.check_output(\n                [\n                    
\"nvidia-smi\",\n                    \"--query-gpu=driver_version\",\n                    \"--format=csv,noheader,nounits\",\n                ]\n            )\n            versions = set(output.decode().strip().split(\"\\n\"))\n            if len(versions) == 1:\n                return {\"CUDA Driver Version\": versions.pop()}\n            else:\n                return {\"CUDA Driver Versions\": \", \".join(sorted(versions))}\n        except subprocess.SubprocessError:\n            return {\"CUDA Driver Version\": \"Not Available\"}\n\n    def get_topology(self):\n        \"\"\"\n        Get GPU topology information.\n        \"\"\"\n        try:\n            result = subprocess.run(\n                [\"nvidia-smi\", \"topo\", \"-m\"],\n                stdout=subprocess.PIPE,\n                stderr=subprocess.PIPE,\n                text=True,\n                check=True,\n            )\n            return {\n                \"NVIDIA Topology\": (\n                    \"\\n\" + result.stdout if result.returncode == 0 else None\n                )\n            }\n        except subprocess.SubprocessError:\n            return {}\n\n\nclass HIPEnv(BaseEnv):\n    \"\"\"Environment checker for ROCm/HIP\"\"\"\n\n    def get_info(self):\n        cuda_info = {\"ROCM available\": torch.cuda.is_available()}\n\n        if cuda_info[\"ROCM available\"]:\n            cuda_info.update(self.get_device_info())\n            cuda_info.update(self._get_cuda_version_info())\n\n        return cuda_info\n\n    def _get_cuda_version_info(self):\n        from torch.utils.cpp_extension import ROCM_HOME as ROCM_HOME\n\n        cuda_info = {\"ROCM_HOME\": ROCM_HOME}\n\n        if ROCM_HOME and os.path.isdir(ROCM_HOME):\n            cuda_info.update(self._get_hipcc_info())\n            cuda_info.update(self._get_rocm_driver_version())\n\n        return cuda_info\n\n    def _get_hipcc_info(self):\n        from torch.utils.cpp_extension import ROCM_HOME\n\n        try:\n            hipcc = 
os.path.join(ROCM_HOME, \"bin/hipcc\")\n            hipcc_output = (\n                subprocess.check_output(f'\"{hipcc}\" --version', shell=True)\n                .decode(\"utf-8\")\n                .strip()\n            )\n            return {\n                \"HIPCC\": hipcc_output[\n                    hipcc_output.rfind(\"HIP version\") : hipcc_output.rfind(\"AMD clang\")\n                ].strip()\n            }\n        except subprocess.SubprocessError:\n            return {\"HIPCC\": \"Not Available\"}\n\n    def _get_rocm_driver_version(self):\n        try:\n            output = subprocess.check_output(\n                [\n                    \"rocm-smi\",\n                    \"--showdriverversion\",\n                    \"--csv\",\n                ]\n            )\n            versions = set(output.decode().strip().split(\"\\n\"))\n            versions.discard(\"name, value\")\n            ver = versions.pop()\n            ver = ver.replace('\"Driver version\", ', \"\").replace('\"', \"\")\n\n            return {\"ROCM Driver Version\": ver}\n        except subprocess.SubprocessError:\n            return {\"ROCM Driver Version\": \"Not Available\"}\n\n    def get_topology(self):\n        try:\n            result = subprocess.run(\n                [\"rocm-smi\", \"--showtopotype\"],\n                stdout=subprocess.PIPE,\n                stderr=subprocess.PIPE,\n                text=True,\n                check=True,\n            )\n            return {\n                \"AMD Topology\": \"\\n\" + result.stdout if result.returncode == 0 else None\n            }\n        except subprocess.SubprocessError:\n            return {}\n\n\nclass NPUEnv(BaseEnv):\n    \"\"\"Environment checker for Ascend NPU\"\"\"\n\n    EXTRA_PACKAGE_LIST = [\n        \"torch_npu\",\n        \"sgl-kernel-npu\",\n        \"deep_ep\",\n    ]\n\n    def __init__(self):\n        super().__init__()\n        self.package_list.extend(NPUEnv.EXTRA_PACKAGE_LIST)\n\n    def 
get_info(self):\n        cuda_info = {\"NPU available\": torch.npu.is_available()}\n        if cuda_info[\"NPU available\"]:\n            cuda_info.update(self.get_device_info())\n            cuda_info.update(self._get_cann_version_info())\n\n        return cuda_info\n\n    def get_device_info(self):\n        \"\"\"\n        Get information about available NPUs.\n        Need to override due to torch_npu interface differences.\n        \"\"\"\n        devices = defaultdict(list)\n        for k in range(torch.npu.device_count()):\n            devices[torch.npu.get_device_name(k)].append(str(k))\n\n        npu_info = {}\n        for name, device_ids in devices.items():\n            npu_info[f\"NPU {','.join(device_ids)}\"] = name\n\n        return npu_info\n\n    def _get_cann_version_info(self):\n        cann_envs = [\"ASCEND_TOOLKIT_HOME\", \"ASCEND_INSTALL_PATH\"]\n        for var in cann_envs:\n            path = os.environ.get(var)\n            if path and os.path.exists(path):\n                CANN_HOME = path\n                break\n        else:\n            default_path = \"/usr/local/Ascend/ascend-toolkit/latest\"\n            CANN_HOME = default_path if os.path.exists(default_path) else None\n\n        if CANN_HOME:\n            npu_info = {\"CANN_HOME\": CANN_HOME}\n            npu_info.update(self._get_cann_info(CANN_HOME))\n            npu_info.update(self._get_ascend_driver_version())\n            return npu_info\n        else:\n            return {\"CANN_HOME\": \"Not found\"}\n\n    def _get_cann_info(self, CANN_HOME: str):\n        cann_info = {}\n        cann_version_file = os.path.join(CANN_HOME, \"version.cfg\")\n        if os.path.exists(cann_version_file):\n            with open(cann_version_file, \"r\", encoding=\"utf-8\") as f:\n                f.readline()  # discard first line comment in version.cfg\n                cann_info[\"CANN\"] = f.readline().split(\"[\")[1].split(\"]\")[0]\n        else:\n            cann_info[\"CANN\"] = \"Not 
Available\"\n        try:\n            bisheng = os.path.join(CANN_HOME, \"compiler/ccec_compiler/bin/bisheng\")\n            bisheng_output = (\n                subprocess.check_output([bisheng, \"--version\"]).decode(\"utf-8\").strip()\n            )\n            cann_info[\"BiSheng\"] = bisheng_output.split(\"\\n\")[0].strip()\n        except subprocess.SubprocessError:\n            cann_info[\"BiSheng\"] = \"Not Available\"\n        return cann_info\n\n    def _get_ascend_driver_version(self):\n        try:\n            output = subprocess.check_output(\n                [\n                    \"npu-smi\",\n                    \"info\",\n                    \"-t\",\n                    \"board\",\n                    \"-i\",\n                    \"0\",\n                ]\n            )\n            for line in output.decode().strip().split(\"\\n\"):\n                if \"Software Version\" in line:\n                    version = line.split(\":\")[-1].strip()\n                    break\n            else:\n                version = \"Not Available\"\n\n            return {\"Ascend Driver Version\": version}\n        except subprocess.SubprocessError:\n            return {\"Ascend Driver Version\": \"Not Available\"}\n\n    def get_topology(self):\n        try:\n            result = subprocess.run(\n                [\"npu-smi\", \"info\", \"-t\", \"topo\"],\n                stdout=subprocess.PIPE,\n                stderr=subprocess.PIPE,\n                text=True,\n                check=True,\n            )\n            return {\n                \"Ascend Topology\": (\n                    \"\\n\" + result.stdout if result.returncode == 0 else None\n                )\n            }\n        except subprocess.SubprocessError:\n            return {}\n\n\nclass MUSAEnv(BaseEnv):\n    \"\"\"Environment checker for MThreads GPU\"\"\"\n\n    def get_info(self):\n        musa_info = {\"MUSA available\": torch.musa.is_available()}\n\n        if musa_info[\"MUSA 
available\"]:\n            musa_info.update(self.get_device_info())\n            musa_info.update(self._get_musa_version_info())\n\n        return musa_info\n\n    def _get_musa_version_info(self):\n        \"\"\"\n        Get MUSA version information.\n        \"\"\"\n        from torch_musa.utils.musa_extension import MUSA_HOME\n\n        musa_info = {\"MUSA_HOME\": MUSA_HOME}\n\n        if MUSA_HOME and os.path.isdir(MUSA_HOME):\n            musa_info.update(self._get_mcc_info())\n            musa_info.update(self._get_musa_driver_version())\n\n        return musa_info\n\n    def _get_mcc_info(self):\n        \"\"\"\n        Get MCC version information.\n        \"\"\"\n        from torch_musa.utils.musa_extension import MUSA_HOME\n\n        try:\n            mcc = os.path.join(MUSA_HOME, \"bin/mcc\")\n            mcc_output = (\n                subprocess.check_output(f'\"{mcc}\" --version', shell=True)\n                .decode(\"utf-8\")\n                .strip()\n            )\n            return {\n                \"MCC\": mcc_output[\n                    mcc_output.rfind(\"mcc version\") : mcc_output.rfind(\"Target\")\n                ].strip()\n            }\n        except subprocess.SubprocessError:\n            return {\"MCC\": \"Not Available\"}\n\n    def _get_musa_driver_version(self):\n        \"\"\"\n        Get MUSA driver version.\n        \"\"\"\n        try:\n            output = subprocess.check_output(\n                [\n                    \"mthreads-gmi\",\n                    \"-q\",\n                ],\n                text=True,\n            )\n            driver_version = None\n            for line in output.splitlines():\n                if \"Driver Version\" in line:\n                    driver_version = line.split(\":\", 1)[1].strip()\n                    break\n\n            return {\"MUSA Driver Version\": driver_version}\n        except subprocess.SubprocessError:\n            return {\"MUSA Driver Version\": \"Not 
Available\"}\n\n    def get_topology(self):\n        \"\"\"\n        Get GPU topology information.\n        \"\"\"\n        try:\n            result = subprocess.run(\n                [\"mthreads-gmi\", \"topo\", \"-m\"],\n                stdout=subprocess.PIPE,\n                stderr=subprocess.PIPE,\n                text=True,\n                check=True,\n            )\n            return {\n                \"MTHREADS Topology\": (\n                    \"\\n\" + result.stdout if result.returncode == 0 else None\n                )\n            }\n        except subprocess.SubprocessError:\n            return {}\n\n\nif __name__ == \"__main__\":\n    if is_cuda_v2():\n        env = GPUEnv()\n    elif is_hip():\n        env = HIPEnv()\n    elif is_npu():\n        env = NPUEnv()\n    elif is_musa():\n        env = MUSAEnv()\n    else:\n        raise RuntimeError(\n            \"No supported accelerator backend (CUDA, ROCm, NPU, or MUSA) was detected.\"\n        )\n    env.check_env()\n"
  },
  {
    "path": "python/sglang/cli/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/cli/generate.py",
    "content": "import argparse\n\nfrom sglang.cli.utils import get_is_diffusion_model, get_model_path\n\n\ndef generate(args, extra_argv):\n    # If help is requested, show generate subcommand help without requiring --model-path\n    if any(h in extra_argv for h in (\"-h\", \"--help\")):\n        from sglang.multimodal_gen.runtime.entrypoints.cli.generate import (\n            add_multimodal_gen_generate_args,\n        )\n\n        parser = argparse.ArgumentParser(description=\"SGLang Multimodal Generation\")\n        add_multimodal_gen_generate_args(parser)\n        parser.parse_args(extra_argv)\n        return\n\n    model_path = get_model_path(extra_argv)\n    is_diffusion_model = get_is_diffusion_model(model_path)\n    if is_diffusion_model:\n        from sglang.multimodal_gen.runtime.entrypoints.cli.generate import (\n            add_multimodal_gen_generate_args,\n            generate_cmd,\n        )\n\n        parser = argparse.ArgumentParser(description=\"SGLang Multimodal Generation\")\n        add_multimodal_gen_generate_args(parser)\n        parsed_args, unknown_args = parser.parse_known_args(extra_argv)\n        generate_cmd(parsed_args, unknown_args)\n    else:\n        raise Exception(\n            f\"Generate subcommand is not yet supported for model: {model_path}\"\n        )\n"
  },
  {
    "path": "python/sglang/cli/main.py",
    "content": "import argparse\n\nfrom sglang.cli.utils import get_git_commit_hash\nfrom sglang.version import __version__\n\n\ndef version(args, extra_argv):\n    print(f\"sglang version: {__version__}\")\n    print(f\"git revision: {get_git_commit_hash()[:7]}\")\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n\n    # complex sub commands\n    subparsers = parser.add_subparsers(dest=\"subcommand\", required=True)\n    subparsers.add_parser(\n        \"serve\",\n        help=\"Launch an SGLang server.\",\n        add_help=False,\n    )\n    subparsers.add_parser(\n        \"generate\",\n        help=\"Run inference on a multimodal model.\",\n        add_help=False,\n    )\n\n    # simple commands\n    version_parser = subparsers.add_parser(\n        \"version\",\n        help=\"Show the version information.\",\n    )\n    version_parser.set_defaults(func=version)\n\n    args, extra_argv = parser.parse_known_args()\n\n    if args.subcommand == \"serve\":\n        from sglang.cli.serve import serve\n\n        serve(args, extra_argv)\n    elif args.subcommand == \"generate\":\n        from sglang.cli.generate import generate\n\n        generate(args, extra_argv)\n    elif args.subcommand == \"version\":\n        version(args, extra_argv)\n"
  },
  {
    "path": "python/sglang/cli/serve.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport logging\nimport os\n\nfrom sglang.cli.utils import get_is_diffusion_model, get_model_path\nfrom sglang.srt.utils import kill_process_tree\nfrom sglang.srt.utils.common import suppress_noisy_warnings\n\nsuppress_noisy_warnings()\n\nlogger = logging.getLogger(__name__)\n\n\ndef _extract_model_type_override(extra_argv):\n    \"\"\"Extract and remove --model-type override from argv.\"\"\"\n    model_type = \"auto\"\n    filtered_argv = []\n    i = 0\n    while i < len(extra_argv):\n        arg = extra_argv[i]\n        if arg == \"--model-type\":\n            if i + 1 >= len(extra_argv):\n                raise Exception(\n                    \"Error: --model-type requires a value. \"\n                    \"Valid values are: auto, llm, diffusion.\"\n                )\n            model_type = extra_argv[i + 1]\n            i += 2\n            continue\n\n        if arg.startswith(\"--model-type=\"):\n            model_type = arg.split(\"=\", 1)[1]\n            i += 1\n            continue\n\n        filtered_argv.append(arg)\n        i += 1\n\n    if model_type not in (\"auto\", \"llm\", \"diffusion\"):\n        raise Exception(\n            f\"Error: invalid --model-type '{model_type}'. \"\n            \"Valid values are: auto, llm, diffusion.\"\n        )\n    return model_type, filtered_argv\n\n\ndef serve(args, extra_argv):\n    if any(h in extra_argv for h in (\"-h\", \"--help\")):\n        # Since the server type is determined by the model, and we don't have a model path,\n        # we can't show the exact help. 
Instead, we show a general help message and then\n        # the help for both possible server types.\n        print(\n            \"Usage: sglang serve --model-path <model-name-or-path> [additional-arguments]\\n\\n\"\n            \"This command can launch either a standard language model server or a diffusion model server.\\n\"\n            \"The server type is determined by the --model-path.\\n\"\n            \"Optional override: --model-type {auto,llm,diffusion} \"\n            \"(default: auto, fallback to LLM on detection failure).\"\n        )\n\n        print(\"\\n--- Help for Standard Language Model Server ---\")\n        from sglang.srt.server_args import prepare_server_args\n\n        try:\n            prepare_server_args([\"--help\"])\n        except SystemExit:\n            pass  # argparse --help calls sys.exit\n\n        print(\"\\n--- Help for Diffusion Model Server ---\")\n        try:\n            from sglang.multimodal_gen.runtime.entrypoints.cli.serve import (\n                add_multimodal_gen_serve_args,\n            )\n\n            parser = argparse.ArgumentParser(\n                prog=\"sglang serve\",\n                description=\"SGLang Diffusion Model Serving\",\n            )\n            add_multimodal_gen_serve_args(parser)\n            parser.print_help()\n        except ImportError:\n            print(\n                \"Diffusion model support is not available. 
\"\n                'Install with: pip install \"sglang[diffusion]\"'\n            )\n        return\n\n    model_type, dispatch_argv = _extract_model_type_override(extra_argv)\n    model_path = get_model_path(dispatch_argv)\n    try:\n        if model_type == \"auto\":\n            is_diffusion_model = get_is_diffusion_model(model_path)\n            if is_diffusion_model:\n                logger.info(\"Diffusion model detected\")\n        else:\n            is_diffusion_model = model_type == \"diffusion\"\n            logger.info(\n                \"Dispatch override enabled: --model-type=%s \" \"(skip auto detection)\",\n                model_type,\n            )\n\n        if is_diffusion_model:\n            # Logic for Diffusion Models\n            from sglang.multimodal_gen.runtime.entrypoints.cli.serve import (\n                add_multimodal_gen_serve_args,\n                execute_serve_cmd,\n            )\n\n            parser = argparse.ArgumentParser(\n                description=\"SGLang Diffusion Model Serving\"\n            )\n            add_multimodal_gen_serve_args(parser)\n            parsed_args, remaining_argv = parser.parse_known_args(dispatch_argv)\n\n            execute_serve_cmd(parsed_args, remaining_argv)\n        else:\n            # Logic for Standard Language Models\n            from sglang.launch_server import run_server\n            from sglang.srt.server_args import prepare_server_args\n\n            server_args = prepare_server_args(dispatch_argv)\n\n            run_server(server_args)\n    finally:\n        kill_process_tree(os.getpid(), include_parent=False)\n"
  },
  {
    "path": "python/sglang/cli/utils.py",
    "content": "import json\nimport logging\nimport os\nimport subprocess\nfrom functools import lru_cache\n\nfrom sglang.srt.environ import envs\n\nlogger = logging.getLogger(__name__)\n\n\ndef _is_diffusers_model_dir(model_dir: str) -> bool:\n    \"\"\"Check if a local directory contains a valid diffusers model_index.json.\"\"\"\n    config_path = os.path.join(model_dir, \"model_index.json\")\n    if not os.path.exists(config_path):\n        return False\n\n    with open(config_path) as f:\n        config = json.load(f)\n\n    return \"_diffusers_version\" in config\n\n\ndef get_is_diffusion_model(model_path: str) -> bool:\n    \"\"\"Detect whether model_path points to a diffusion model.\n\n    For local directories, checks the filesystem directly.\n    For HF/ModelScope model IDs, attempts to fetch only model_index.json.\n    Returns False on any failure (network error, 404, offline mode, etc.)\n    so that the caller falls through to the standard LLM server path.\n    \"\"\"\n    try:\n        from sglang.multimodal_gen.registry import (\n            is_known_non_diffusers_multimodal_model,\n        )\n    except ImportError:\n        is_known_non_diffusers_multimodal_model = lambda _: False\n\n    if os.path.isdir(model_path):\n        if _is_diffusers_model_dir(model_path):\n            return True\n        return is_known_non_diffusers_multimodal_model(model_path)\n\n    if is_known_non_diffusers_multimodal_model(model_path):\n        return True\n\n    try:\n        if envs.SGLANG_USE_MODELSCOPE.get():\n            from modelscope import model_file_download\n\n            file_path = model_file_download(\n                model_id=model_path, file_path=\"model_index.json\"\n            )\n        else:\n            from huggingface_hub import hf_hub_download\n\n            file_path = hf_hub_download(repo_id=model_path, filename=\"model_index.json\")\n\n        return _is_diffusers_model_dir(os.path.dirname(file_path))\n    except Exception as e:\n        
logger.debug(\"Failed to auto-detect diffusion model for %s: %s\", model_path, e)\n        return False\n\n\ndef get_model_path(extra_argv):\n    # Find the model_path argument\n    model_path = None\n    for i, arg in enumerate(extra_argv):\n        if arg == \"--model-path\":\n            if i + 1 < len(extra_argv):\n                model_path = extra_argv[i + 1]\n                break\n        elif arg.startswith(\"--model-path=\"):\n            model_path = arg.split(\"=\", 1)[1]\n            break\n\n    if model_path is None:\n        # Fallback for --help or other cases where model-path is not provided\n        if any(h in extra_argv for h in [\"-h\", \"--help\"]):\n            raise Exception(\n                \"Usage: sglang serve --model-path <model-name-or-path> [additional-arguments]\\n\\n\"\n                \"This command can launch either a standard language model server or a diffusion model server.\\n\"\n                \"The server type is determined by the --model-path.\\n\"\n            )\n        else:\n            raise Exception(\n                \"Error: --model-path is required. \"\n                \"Please provide the path to the model.\"\n            )\n    return model_path\n\n\n@lru_cache(maxsize=1)\ndef get_git_commit_hash() -> str:\n    # lru_cache memoizes the result, so git is invoked at most once per process\n    try:\n        commit_hash = os.environ.get(\"SGLANG_GIT_COMMIT\")\n        if not commit_hash:\n            commit_hash = (\n                subprocess.check_output(\n                    [\"git\", \"rev-parse\", \"HEAD\"], stderr=subprocess.DEVNULL\n                )\n                .strip()\n                .decode(\"utf-8\")\n            )\n        return commit_hash\n    except (subprocess.CalledProcessError, FileNotFoundError):\n        return \"N/A\"\n"
  },
  {
    "path": "python/sglang/compile_deep_gemm.py",
    "content": "\"\"\"\nCompile DeepGEMM kernels for a model with the specified server arguments\n\nThis script launches a server for capturing DeepGEMM calls and then compiles the kernels.\nIt accepts server arguments (the same as launch_server.py).\n\nUsage:\npython3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code\n\n\"\"\"\n\nimport argparse\nimport dataclasses\nimport multiprocessing\nimport os\nimport time\n\nimport requests\n\nfrom sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST\nfrom sglang.srt.entrypoints.http_server import launch_server\nfrom sglang.srt.entrypoints.warmup import warmup\nfrom sglang.srt.environ import envs\nfrom sglang.srt.managers.io_struct import GenerateReqInput\nfrom sglang.srt.managers.tokenizer_manager import TokenizerManager\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils import kill_process_tree\n\nmultiprocessing.set_start_method(\"spawn\", force=True)\n\n# Reduce warnings\nenvs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.set(True)\n# Force enable deep gemm\nenvs.SGLANG_ENABLE_JIT_DEEPGEMM.set(True)\n# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case\nenvs.SGLANG_CHUNKED_PREFIX_CACHE_THRESHOLD.set(0)\n\n\n@dataclasses.dataclass\nclass CompileArgs:\n    timeout: int = 3600\n\n    @staticmethod\n    def add_cli_args(parser: argparse.ArgumentParser):\n        parser.add_argument(\"--timeout\", type=int, default=CompileArgs.timeout)\n\n    @classmethod\n    def from_cli_args(cls, args: argparse.Namespace):\n        # Use each field's default-value type to cast the parsed args to the correct types.\n        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]\n        return cls(\n            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}\n        )\n\n\n@warmup(\"compile-deep-gemm\")\nasync def warm_up_compile(\n    disaggregation_mode: str, tokenizer_manager: TokenizerManager\n):\n    print(\"\\nGenerate warm up 
request for compiling DeepGEMM...\\n\")\n    generate_req_input = GenerateReqInput(\n        input_ids=[0, 1, 2, 3],\n        sampling_params={\n            \"temperature\": 0.0,\n            \"max_new_tokens\": 8,\n            \"ignore_eos\": True,\n        },\n    )\n    if disaggregation_mode != \"null\":\n        generate_req_input.bootstrap_room = 0\n        generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST\n\n    await tokenizer_manager.generate_request(generate_req_input, None).__anext__()\n\n\ndef launch_server_internal(server_args):\n    try:\n        launch_server(server_args)\n    except Exception as e:\n        raise e\n    finally:\n        kill_process_tree(os.getpid(), include_parent=False)\n\n\ndef launch_server_process_and_send_one_request(\n    server_args: ServerArgs, compile_args: CompileArgs\n):\n    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))\n    proc.start()\n    base_url = f\"http://{server_args.host}:{server_args.port}\"\n    timeout = compile_args.timeout\n\n    start_time = time.perf_counter()\n    while time.perf_counter() - start_time < timeout:\n        try:\n            headers = {\n                \"Content-Type\": \"application/json; charset=utf-8\",\n            }\n            if server_args.node_rank == 0:\n                response = requests.get(f\"{base_url}/v1/models\", headers=headers)\n            else:\n                # This http api is created by launch_dummy_health_check_server for none-rank0 node.\n                response = requests.get(f\"{base_url}/health\", headers=headers)\n            if response.status_code == 200:\n                # Rank-0 node send a request to sync with other node and then return.\n                if server_args.node_rank == 0:\n                    payload = {\n                        \"input_ids\": [0, 1, 2, 3],\n                        \"sampling_params\": {\n                            \"max_new_tokens\": 8,\n                            
\"temperature\": 0,\n                        },\n                    }\n                    # In PD mode, include fake bootstrap fields so workers don't assert\n                    if server_args.disaggregation_mode != \"null\":\n                        payload[\"bootstrap_host\"] = FAKE_BOOTSTRAP_HOST\n                        payload[\"bootstrap_room\"] = 0\n\n                    response = requests.post(\n                        f\"{base_url}/generate\",\n                        json=payload,\n                        timeout=600,\n                    )\n                    if response.status_code != 200:\n                        error = response.json()\n                        raise RuntimeError(f\"Sync request failed: {error}\")\n                # Other nodes wait for the exit signal from the rank-0 node.\n                else:\n                    start_time_waiting = time.perf_counter()\n                    while proc.is_alive():\n                        if time.perf_counter() - start_time_waiting < timeout:\n                            time.sleep(10)\n                        else:\n                            raise TimeoutError(\"Timed out waiting for the main node!\")\n                return proc\n        except requests.RequestException:\n            pass\n        time.sleep(10)\n    raise TimeoutError(\n        \"DeepGEMM kernel compilation timed out.\"\n        \"\\n\\nPlease restart the command.\"\n    )\n\n\ndef refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):\n    # Disable CUDA graph and torch compile to save time\n    server_args.disable_cuda_graph = True\n    server_args.enable_torch_compile = False\n    print(\"Disabling CUDA Graph and Torch Compile to save time...\")\n\n    # Set the watchdog timeout to compile_args.timeout because compilation can take a long time\n    server_args.watchdog_timeout = compile_args.timeout\n    server_args.warmups = \"compile-deep-gemm\"\n\n\ndef run_compile(server_args: ServerArgs, 
compile_args: CompileArgs):\n    print(\n        \"Beginning DeepGEMM kernel compilation...\\n\"\n        \"It may take a long time, and a timeout may be raised \"\n        \"while the compilation is still in progress.\\n\"\n        \"If that happens, simply restart the command \"\n        \"until the compilation is fully finished.\\n\"\n    )\n\n    proc = launch_server_process_and_send_one_request(server_args, compile_args)\n\n    print(\"\\nDeepGEMM kernel compilation finished successfully.\")\n\n    # Sleep for safety\n    time.sleep(10)\n    if proc.is_alive():\n        # This is the rank0 node.\n        kill_process_tree(proc.pid)\n    else:\n        try:\n            kill_process_tree(proc.pid)\n        except Exception:\n            pass\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    ServerArgs.add_cli_args(parser)\n    CompileArgs.add_cli_args(parser)\n    args = parser.parse_args()\n    server_args = ServerArgs.from_cli_args(args)\n    compile_args = CompileArgs.from_cli_args(args)\n\n    refine_server_args(server_args, compile_args)\n\n    run_compile(server_args, compile_args)\n"
  },
  {
    "path": "python/sglang/eval/llama3_eval.py",
    "content": "# Adapt from https://github.com/fw-ai/llm_eval_meta\n\nimport argparse\nimport asyncio\nimport os\nimport pickle\nimport re\nimport shutil\nfrom collections import defaultdict\nfrom dataclasses import dataclass\n\nimport httpx\nimport numpy as np\nimport openai\nfrom datasets import load_dataset\nfrom openai import AsyncOpenAI\nfrom tqdm import tqdm\n\n# Mapping providers to their clients and models\nprovider_to_models = {\n    \"b10\": {\n        \"8b\": \"meta-llama/Llama-3.1-8B-Instruct\",\n        \"70b\": \"meta-llama/Llama-3.1-70B-Instruct\",\n        \"405b\": \"meta-llama/Llama-3.1-405B-Instruct\",\n    },\n    \"oai\": {\n        \"8b\": \"meta-llama/Llama-3.1-8B-Instruct\",\n        \"70b\": \"meta-llama/Llama-3.1-70B-Instruct\",\n        \"405b\": \"meta-llama/Llama-3.1-405B-Instruct\",\n    },\n    \"sgl\": {\n        \"8b\": \"meta-llama/Llama-3.1-8B-Instruct\",\n        \"70b\": \"meta-llama/Llama-3.1-70B-Instruct\",\n        \"405b\": \"meta-llama/Llama-3.1-405B-Instruct\",\n    },\n}\n\n\nasync def fetch_responses(\n    client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens\n):\n    output_file = os.path.join(output_dir, f\"response_{index}.pkl\")\n    if os.path.exists(output_file):\n        print(f\"File {output_file} already exists, skipping.\")\n        return\n\n    async with semaphore:\n        response = await client.completions.create(\n            model=provider_to_models[provider][model_size],\n            prompt=prompt,\n            temperature=0.0,\n            max_tokens=max_tokens,\n        )\n        if isinstance(response, openai.BadRequestError):\n            with open(output_file, \"wb\") as f:\n                pickle.dump(\"bad_response\", f)\n        assert isinstance(response, openai.types.completion.Completion)\n        # Save response to a file\n        with open(output_file, \"wb\") as f:\n            pickle.dump(response, f)\n\n\nTASK_TO_MAX_TOKENS = {\n    \"evals__mmlu__details\": 
1,\n    \"evals__mmlu__0_shot__cot__details\": 1024,\n    # Official meta uses 1024, but a small % (.05) of questions are answered correctly after relaxing\n    \"evals__mmlu_pro__details\": 2048,\n    \"evals__gsm8k__details\": 1024,\n}\n\nTASK_TO_EVAL_SET = {\n    \"mmlu\": \"evals__mmlu__details\",\n    \"mmlu_cot\": \"evals__mmlu__0_shot__cot__details\",\n    \"mmlu_pro\": \"evals__mmlu_pro__details\",\n    \"gsm8k\": \"evals__gsm8k__details\",\n}\n\n\nclass CustomAsyncHTTPXClient(httpx.AsyncClient):\n    async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:\n        request.url = httpx.URL(\n            f\"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict\"\n        )\n        return await super().send(request, *args, **kwargs)\n\n\ndef get_client(provider):\n    if provider != \"b10\":\n        if os.getenv(\"OPENAI_API_KEY\") is None:\n            os.environ[\"OPENAI_API_KEY\"] = \"EMPTY\"\n    return {\n        \"oai\": AsyncOpenAI(base_url=\"http://127.0.0.1:8000/v1/\"),\n        \"b10\": AsyncOpenAI(\n            api_key=f\"Api-Key {os.getenv('OPENAI_API_KEY')}\",\n            base_url=f\"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict\",\n            http_client=CustomAsyncHTTPXClient(),\n        ),\n        \"sgl\": AsyncOpenAI(base_url=\"http://127.0.0.1:30000/v1/\"),\n    }[provider]\n\n\n# Define the benchmark function\nasync def benchmark(args):\n    ds = load_dataset(\n        \"meta-llama/Llama-3.1-405B-Instruct-evals\",\n        f\"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}\",\n    )\n    semaphore = asyncio.Semaphore(args.concurrency)  # Limit the number of concurrent requests\n\n    if args.num_examples is None:\n        args.num_examples = len(ds[\"latest\"][\"input_final_prompts\"])\n    prompts = ds[\"latest\"][\"input_final_prompts\"][: args.num_examples]\n\n    # Create the output directory if it does not exist\n    os.makedirs(args.output_dir, 
exist_ok=True)\n\n    tasks = []\n    # Create the tasks with tqdm progress bar\n    max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]]\n    client = get_client(args.provider)\n    for idx, prompt in enumerate(tqdm(prompts, desc=\"Creating tasks\")):\n        tasks.append(\n            asyncio.create_task(\n                fetch_responses(\n                    client,\n                    f\"<|begin_of_text|>{prompt[0]}\",\n                    semaphore,\n                    idx,\n                    args.provider,\n                    args.model_size,\n                    args.output_dir,\n                    max_tokens=max_tokens,\n                )\n            )\n        )\n\n    # Run the tasks with tqdm progress bar\n    for future in tqdm(\n        asyncio.as_completed(tasks), total=len(tasks), desc=\"Processing tasks\"\n    ):\n        await future\n\n\ndef get_mmlu_answer(response):\n    if response is not None:\n        return response.choices[0].text.lstrip().rstrip().upper().replace(\".\", \"\")\n    return None\n\n\ndef get_mmlu_cot_answer(response):\n    pattern = r\"The best answer is (.+)\\.?\"\n    match = re.search(pattern, response.choices[0].text)\n    if match:\n        return match.group(1).replace(\".\", \"\").replace(\"*\", \"\")\n\n    pattern = r\"the best answer is (.+)\\.?\"\n    match = re.search(pattern, response.choices[0].text)\n    if match:\n        return match.group(1).replace(\".\", \"\")\n\n    pattern = r\"The correct answer is (.+)\\.?\"\n    match = re.search(pattern, response.choices[0].text)\n    if match:\n        return match.group(1).replace(\".\", \"\")\n\n    pattern = r\"the correct answer is (.+)\\.?\"\n    match = re.search(pattern, response.choices[0].text)\n    if match:\n        return match.group(1).replace(\".\", \"\")\n\n\ndef get_answer_gsm8k(response):\n    pattern = r\"The final answer is (.+)\\.?\"\n    match = re.search(pattern, response.choices[0].text)\n    if match:\n        s = 
match.group(1)\n        for ok_symbol in [\"%\", \"$\"]:\n            s = s.replace(ok_symbol, \"\")\n        return s\n\n\nTASK_TO_ANSWER_EXTRACTOR = {\n    \"evals__mmlu__details\": get_mmlu_answer,\n    \"evals__mmlu__0_shot__cot__details\": get_mmlu_cot_answer,\n    \"evals__gsm8k__details\": get_answer_gsm8k,\n    \"evals__mmlu_pro__details\": get_mmlu_cot_answer,\n}\n\n\ndef get_dataset_from_task(task, response_path, model_size):\n    ds_405b = load_dataset(\n        f\"meta-llama/Llama-3.1-405B-Instruct-evals\",\n        f\"Llama-3.1-405B-Instruct-{task}\",\n    )\n    ds_405b_hash_order = [x[0] for x in ds_405b[\"latest\"][\"input_final_prompts_hash\"]]\n\n    if \"70b\" in model_size or \"8b\" in model_size:\n        if \"70\" in model_size:\n            ref_model_ds = load_dataset(\n                f\"meta-llama/Llama-3.1-70B-Instruct-evals\",\n                f\"Llama-3.1-70B-Instruct-{task}\",\n            )\n        else:\n            ref_model_ds = load_dataset(\n                f\"meta-llama/Llama-3.1-8B-Instruct-evals\",\n                f\"Llama-3.1-8B-Instruct-{task}\",\n            )\n\n        hash_to_row = {}\n        for row in ref_model_ds[\"latest\"]:\n            hash_to_row[row[\"input_final_prompts_hash\"][0]] = row\n        reordered_rows = []\n        for prompt_hash in ds_405b_hash_order:\n            reordered_rows.append(hash_to_row[prompt_hash])\n        ref_model_ds[\"latest\"] = reordered_rows\n        return ref_model_ds\n\n    return ds_405b\n\n\ndef analyze(task, response_path, model_size):\n    ds = get_dataset_from_task(task, response_path, model_size)\n\n    responses = []\n    total = len(ds[\"latest\"])\n\n    for i in range(0, total):\n        response = pickle.load(\n            open(os.path.join(response_path, f\"response_{i}.pkl\"), \"rb\")\n        )\n        responses.append(response)\n\n    @dataclass\n    class Stats:\n        correct: int = 0\n        total: int = 0\n        meta_correct: int = 0\n\n        
average: float = None\n\n    subtask_name_to_stats = defaultdict(lambda: Stats())\n\n    for response, ds_row in zip(responses, ds[\"latest\"]):\n        model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)\n\n        subtask = ds_row[\"subtask_name\"]\n\n        is_eval_correct = model_answer in ds_row[\"input_correct_responses\"]\n        if is_eval_correct:\n            subtask_name_to_stats[subtask].correct += 1\n\n        if ds_row[\"is_correct\"]:\n            subtask_name_to_stats[subtask].meta_correct += 1\n\n        subtask_name_to_stats[subtask].total += 1\n\n    micro_stats = Stats()\n    for subtask, stats in subtask_name_to_stats.items():\n        stats.average = stats.correct / stats.total\n        stats.meta_average = stats.meta_correct / stats.total\n\n        micro_stats.correct += stats.correct\n        micro_stats.total += stats.total\n        micro_stats.meta_correct += stats.meta_correct\n\n    micro_stats.average = micro_stats.correct / micro_stats.total\n    micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total\n\n    print(\"Macro average\", np.mean([x.average for x in subtask_name_to_stats.values()]))\n    print(\n        \"Meta Macro average\",\n        np.mean([x.meta_average for x in subtask_name_to_stats.values()]),\n    )\n    print(\"Micro average\", micro_stats.average)\n    print(\"Meta Micro average\", micro_stats.meta_average)\n\n\n# Entry point for the script\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Script to run model with specified parameters.\"\n    )\n    parser.add_argument(\n        \"--model-size\",\n        type=str,\n        default=\"8b\",\n        help=\"Size of the model (e.g., 8b or 70b)\",\n    )\n    parser.add_argument(\n        \"--provider\",\n        type=str,\n        default=\"sgl\",\n        help=\"Provider name (e.g., sgl, oai, b10)\",\n    )\n    parser.add_argument(\n        \"--task\",\n        type=str,\n        required=True,\n     
   help=\"Task (e.g., mmlu, mmlu_cot, mmlu_pro, gsm8k)\",\n    )\n    parser.add_argument(\n        \"--num-examples\", type=int, default=None, help=\"Number of examples to process\"\n    )\n    parser.add_argument(\"--concurrency\", type=int, default=16)\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        default=\"tmp-output-dir\",\n        help=\"Directory to save responses\",\n    )\n\n    args = parser.parse_args()\n    asyncio.run(benchmark(args))\n    analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)\n    shutil.rmtree(\"tmp-output-dir\", ignore_errors=True)\n"
  },
  {
    "path": "python/sglang/eval/loogle_eval.py",
    "content": "import argparse\nimport asyncio\nimport os\nimport pickle\nfrom pathlib import Path\nfrom typing import List\n\nimport openai\nimport torch\nfrom bert_score import BERTScorer\nfrom datasets import load_dataset\nfrom tqdm import tqdm\n\n\ndef get_client(api_url: str) -> openai.AsyncOpenAI:\n    if os.getenv(\"OPENAI_API_KEY\") is None:\n        os.environ[\"OPENAI_API_KEY\"] = \"EMPTY\"\n    return openai.AsyncOpenAI(base_url=api_url)\n\n\ndef get_dataset():\n    return load_dataset(\"bigai-nlco/LooGLE\", \"longdep_qa\", split=\"test\")\n\n\nasync def fetch_response(\n    client: openai.AsyncOpenAI,\n    context: str,\n    question: str,\n    semaphore: asyncio.Semaphore,\n    index: int,\n    model: str,\n    output_dir: Path,\n):\n    output_file = output_dir / f\"response_{index}.pkl\"\n    if output_file.exists():\n        return\n\n    prompt = (\n        \"Please answer the question based on the long texts below.\\n\"\n        f\"{context}\\n\"\n        f\"Question: {question}\\n\"\n        \"Answer:\"\n    )\n    messages = [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": prompt},\n    ]\n\n    async with semaphore:\n        try:\n            response = await client.chat.completions.create(\n                model=model,\n                messages=messages,\n                temperature=0.0,\n                max_tokens=512,\n            )\n        except openai.BadRequestError as e:\n            with open(output_file, \"wb\") as f:\n                pickle.dump({\"error\": str(e)}, f)\n            return\n\n    with open(output_file, \"wb\") as f:\n        pickle.dump(response, f)\n\n\nasync def benchmark(args):\n    dataset = get_dataset()\n    output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    client = get_client(args.api_url)\n    semaphore = asyncio.Semaphore(args.max_concurrency)\n\n    tasks: List[asyncio.Task] = []\n    for 
idx, ex in enumerate(dataset):\n        if idx >= args.num_prompts:\n            break\n        tasks.append(\n            asyncio.create_task(\n                fetch_response(\n                    client,\n                    ex[\"context\"],\n                    ex[\"question\"],\n                    semaphore,\n                    idx,\n                    args.model,\n                    output_dir,\n                )\n            )\n        )\n\n    for _ in tqdm(\n        asyncio.as_completed(tasks), total=len(tasks), desc=\"Running benchmark\"\n    ):\n        await _\n\n\ndef analyse(args):\n    dataset = get_dataset()\n    output_dir = Path(args.output_dir)\n\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    scorer = BERTScorer(lang=\"en\", device=device)\n\n    hyps: List[str] = []\n    refs: List[str] = []\n    for idx, ex in enumerate(tqdm(dataset, desc=\"Loading responses\")):\n        if idx >= args.num_prompts:\n            break\n        pkl_file = output_dir / f\"response_{idx}.pkl\"\n        if not pkl_file.exists():\n            raise FileNotFoundError(pkl_file)\n\n        response = pickle.load(open(pkl_file, \"rb\"))\n        if isinstance(response, dict) and \"error\" in response:\n            continue\n\n        hyps.append(response.choices[0].message.content.strip())\n        refs.append(ex[\"answer\"])\n\n    if not hyps:\n        print(\"No valid responses to score!\")\n        return\n\n    batch_size = 64\n    all_f1: List[float] = []\n    for i in tqdm(range(0, len(hyps), batch_size), desc=\"Scoring batches\"):\n        h_batch = hyps[i : i + batch_size]\n        r_batch = refs[i : i + batch_size]\n        _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)\n        all_f1.extend([float(x) for x in f1_scores])\n\n    avg = sum(all_f1) / len(all_f1)\n    print(f\"Average BERTScore (F1): {avg:.2%}\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Run 
benchmark and evaluation in one go.\"\n    )\n    parser.add_argument(\n        \"--api-url\",\n        default=\"http://127.0.0.1:30000/v1\",\n        help=\"OpenAI‑compatible API base URL\",\n    )\n    parser.add_argument(\n        \"--model\",\n        default=\"meta-llama/Llama-4-Maverick-17B-128E-Instruct\",\n        help=\"Model name or ID, only used for model name\",\n    )\n    parser.add_argument(\n        \"--max-concurrency\", type=int, default=144, help=\"Maximum concurrent requests\"\n    )\n    parser.add_argument(\n        \"--output-dir\", default=\"tmp-output-dir\", help=\"Directory for cached responses\"\n    )\n    parser.add_argument(\n        \"--num-prompts\", type=int, default=10000, help=\"Number of prompts to run\"\n    )\n    args = parser.parse_args()\n\n    asyncio.run(benchmark(args))\n\n    analyse(args)\n"
  },
  {
    "path": "python/sglang/global_config.py",
    "content": "\"\"\"Global configurations\"\"\"\n\n# FIXME: deprecate this file and move all usage to sglang.srt.environ or sglang.__init__.py\n\n\nclass GlobalConfig:\n    \"\"\"\n    Store some global constants.\n    \"\"\"\n\n    def __init__(self):\n        # Verbosity level\n        # 0: do not output anything\n        # 2: output final text after every run\n        self.verbosity = 0\n\n        # Default backend of the language\n        self.default_backend = None\n\n        # Output tokenization configs\n        self.skip_special_tokens_in_output = True\n        self.spaces_between_special_tokens_in_out = True\n\n        # Language frontend interpreter optimization configs\n        self.enable_precache_with_tracing = True\n        self.enable_parallel_encoding = True\n\n\nglobal_config = GlobalConfig()\n"
  },
  {
    "path": "python/sglang/jit_kernel/.clang-format",
    "content": "BasedOnStyle: Google\nIndentWidth: 2\nColumnLimit: 120\nAllowShortFunctionsOnASingleLine: Empty\nDerivePointerAlignment: false\nPointerAlignment: Left\nNamespaceIndentation: None\nSortIncludes: true\nAllowShortLoopsOnASingleLine: false\nBinPackParameters: false              # Prevents packing parameters in declarations\nBinPackArguments: false               # Prevents packing arguments in function calls\nAlignAfterOpenBracket: AlwaysBreak    # Breaks after the opening bracket when arguments don't fit on one line\nAlignOperands: Align                  # Aligns operands of wrapped binary/ternary expressions\nPenaltyBreakBeforeFirstCallParameter: 1  # Encourages breaking before the first argument\nPenaltyReturnTypeOnItsOwnLine: 100    # Keeps return type with function name\n\nIncludeCategories:\n  - Regex: '^<sgl_kernel/.*\\.h>$'\n    Priority: 0\n  - Regex: '^<sgl_kernel/impl/.*>$'\n    Priority: 2\n  - Regex: '^<sgl_kernel/.*\\.cuh>$'\n    Priority: 1\n  - Regex: '^<.*/.*>$'\n    Priority: 3\n"
  },
  {
    "path": "python/sglang/jit_kernel/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/jit_kernel/__main__.py",
    "content": "assert __name__ == \"__main__\"\n\n\ndef generate_clangd():\n    import logging\n    import os\n    import subprocess\n\n    from tvm_ffi.libinfo import find_dlpack_include_path, find_include_path\n\n    from sglang.jit_kernel.utils import DEFAULT_INCLUDE\n\n    logger = logging.getLogger()\n    logger.info(\"Generating .clangd file...\")\n    include_paths = [find_include_path(), find_dlpack_include_path()] + DEFAULT_INCLUDE\n    status = subprocess.run(\n        args=[\"nvidia-smi\", \"--query-gpu=compute_cap\", \"--format=csv,noheader\"],\n        capture_output=True,\n        check=True,\n    )\n    compute_cap = status.stdout.decode(\"utf-8\").strip().split(\"\\n\")[0]\n    major, minor = compute_cap.split(\".\")\n    compile_flags = \",\\n    \".join(\n        [\n            \"-xcuda\",\n            f\"--cuda-gpu-arch=sm_{major}{minor}\",\n            \"-std=c++20\",\n            \"-Wall\",\n            \"-Wextra\",\n        ]\n        + [f\"-isystem{path}\" for path in include_paths]\n    )\n    clangd_content = f\"\"\"\nCompileFlags:\n  Add: [\n    {compile_flags}\n  ]\n\"\"\"\n    if os.path.exists(\".clangd\"):\n        logger.warning(\".clangd file already exists, nothing done.\")\n        logger.warning(f\"suggested content: {clangd_content}\")\n    else:\n        with open(\".clangd\", \"w\") as f:\n            f.write(clangd_content)\n        logger.info(\".clangd file generated.\")\n\n\ngenerate_clangd()\n"
  },
  {
    "path": "python/sglang/jit_kernel/add_constant.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_add_constant_module(constant: int) -> Module:\n    args = make_cpp_args(constant)\n    return load_jit(\n        \"add_constant\",\n        *args,\n        cuda_files=[\"add_constant.cuh\"],\n        cuda_wrappers=[(\"add_constant\", f\"add_constant<{args}>\")],\n    )\n\n\ndef add_constant(src: torch.Tensor, constant: int) -> torch.Tensor:\n    dst = torch.empty_like(src)\n    module = _jit_add_constant_module(constant)\n    module.add_constant(dst, src)\n    return dst\n"
  },
  {
    "path": "python/sglang/jit_kernel/awq_dequantize.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_awq_dequantize_module(dtype: torch.dtype) -> Module:\n    args = make_cpp_args(dtype)\n    return load_jit(\n        \"awq_dequantize\",\n        *args,\n        cuda_files=[\"gemm/awq_dequantize.cuh\"],\n        cuda_wrappers=[(\"awq_dequantize\", f\"awq_dequantize<{args}>\")],\n    )\n\n\ndef awq_dequantize(\n    qweight: torch.Tensor,\n    scales: torch.Tensor,\n    qzeros: torch.Tensor,\n) -> torch.Tensor:\n    qweight_rows = qweight.shape[0]\n    qweight_cols = qweight.shape[1]\n    output = torch.empty(\n        (qweight_rows, qweight_cols * 8),\n        dtype=scales.dtype,\n        device=scales.device,\n    )\n    module = _jit_awq_dequantize_module(scales.dtype)\n    module.awq_dequantize(output, qweight, scales, qzeros)\n    return output\n"
  },
  {
    "path": "python/sglang/jit_kernel/awq_marlin_repack.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_awq_marlin_repack_module() -> Module:\n    return load_jit(\n        \"awq_marlin_repack\",\n        cuda_files=[\"gemm/marlin/awq_marlin_repack.cuh\"],\n        cuda_wrappers=[(\"awq_marlin_repack\", \"awq_marlin_repack\")],\n    )\n\n\ndef awq_marlin_repack(\n    b_q_weight: torch.Tensor,\n    size_k: int,\n    size_n: int,\n    num_bits: int,\n) -> torch.Tensor:\n    tile_size = 16\n    pack_factor = 32 // num_bits\n    out = torch.empty(\n        (size_k // tile_size, size_n * tile_size // pack_factor),\n        dtype=b_q_weight.dtype,\n        device=b_q_weight.device,\n    )\n    module = _jit_awq_marlin_repack_module()\n    module.awq_marlin_repack(out, b_q_weight, size_k, size_n, num_bits)\n    return out\n\n\ndef awq_marlin_moe_repack(\n    b_q_weight: torch.Tensor,\n    perm: torch.Tensor,\n    size_k: int,\n    size_n: int,\n    num_bits: int,\n) -> torch.Tensor:\n    num_experts = b_q_weight.shape[0]\n    assert size_k % 16 == 0\n    output = torch.empty(\n        (num_experts, size_k // 16, size_n * (num_bits // 2)),\n        device=b_q_weight.device,\n        dtype=b_q_weight.dtype,\n    )\n    for e in range(num_experts):\n        output[e] = awq_marlin_repack(b_q_weight[e], size_k, size_n, num_bits)\n    return output\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_awq_dequantize.py",
    "content": "import itertools\n\nimport torch\nimport triton\nimport triton.testing\n\nfrom sglang.jit_kernel.awq_dequantize import awq_dequantize as jit_awq_dequantize\nfrom sglang.jit_kernel.benchmark.utils import is_in_ci, run_benchmark\n\ntry:\n    from sgl_kernel import awq_dequantize as aot_awq_dequantize\n\n    AOT_AVAILABLE = True\nexcept ImportError:\n    AOT_AVAILABLE = False\n\nIS_CI = is_in_ci()\n\nif IS_CI:\n    qweight_row_range = [128]\n    qweight_cols_range = [16]\nelse:\n    qweight_row_range = [128, 256, 512, 1024, 3584]\n    qweight_cols_range = [16, 32, 64, 128, 448]\n\nconfigs = list(itertools.product(qweight_row_range, qweight_cols_range))\n\n\ndef check_correctness():\n    if not AOT_AVAILABLE:\n        print(\"sgl_kernel AOT not available, skipping correctness check\")\n        return\n\n    qweight_row, qweight_col = 128, 16\n    device = torch.device(\"cuda\")\n    qweight = torch.randint(\n        0,\n        torch.iinfo(torch.int32).max,\n        (qweight_row, qweight_col),\n        dtype=torch.int32,\n        device=device,\n    )\n    group_size = qweight_row\n    scales_row = qweight_row // group_size\n    scales_col = qweight_col * 8\n    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)\n    qzeros = torch.randint(\n        0,\n        torch.iinfo(torch.int32).max,\n        (scales_row, qweight_col),\n        dtype=torch.int32,\n        device=device,\n    )\n\n    jit_out = jit_awq_dequantize(qweight, scales, qzeros)\n    aot_out = aot_awq_dequantize(qweight, scales, qzeros)\n    torch.cuda.synchronize()\n    torch.testing.assert_close(jit_out, aot_out, rtol=0, atol=0)\n    print(\"Correctness check passed (JIT vs AOT)\")\n\n\nif AOT_AVAILABLE:\n    line_vals = [\"jit\", \"aot\"]\n    line_names = [\"JIT Kernel\", \"AOT Kernel\"]\n    styles = [(\"blue\", \"-\"), (\"green\", \"-\")]\nelse:\n    line_vals = [\"jit\"]\n    line_names = [\"JIT Kernel\"]\n    styles = [(\"blue\", 
\"-\")]\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"qweight_row\", \"qweight_col\"],\n        x_vals=configs,\n        line_arg=\"provider\",\n        line_vals=line_vals,\n        line_names=line_names,\n        styles=styles,\n        ylabel=\"us\",\n        plot_name=\"awq-dequantize-jit-vs-aot\",\n        args={},\n    )\n)\ndef benchmark(qweight_row, qweight_col, provider):\n    device = torch.device(\"cuda\")\n    qweight = torch.randint(\n        0,\n        torch.iinfo(torch.int32).max,\n        (qweight_row, qweight_col),\n        dtype=torch.int32,\n        device=device,\n    )\n    group_size = qweight_row\n    scales_row = qweight_row // group_size\n    scales_col = qweight_col * 8\n    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)\n    qzeros = torch.randint(\n        0,\n        torch.iinfo(torch.int32).max,\n        (scales_row, qweight_col),\n        dtype=torch.int32,\n        device=device,\n    )\n\n    if provider == \"jit\":\n        fn = lambda: jit_awq_dequantize(qweight, scales, qzeros)\n    elif provider == \"aot\":\n        fn = lambda: aot_awq_dequantize(qweight, scales, qzeros)\n    else:\n        raise ValueError(f\"Unknown provider: {provider}\")\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    check_correctness()\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_awq_marlin_moe_repack.py",
    "content": "import numpy as np\nimport torch\nimport triton\nimport triton.testing\nfrom sgl_kernel.scalar_type import scalar_types\n\nfrom sglang.jit_kernel.awq_marlin_repack import (\n    awq_marlin_moe_repack as jit_awq_marlin_moe_repack,\n)\nfrom sglang.jit_kernel.benchmark.utils import is_in_ci, run_benchmark\nfrom sglang.srt.layers.quantization.utils import pack_cols, quantize_weights\n\nAOT_AVAILABLE = hasattr(torch.ops.sgl_kernel, \"awq_marlin_moe_repack\") and hasattr(\n    torch.ops.sgl_kernel.awq_marlin_moe_repack, \"default\"\n)\n\nIS_CI = is_in_ci()\n\nNUM_BITS = 4\nGROUP_SIZE = 128\nSIZE_N = 4096\n\n\ndef awq_pack(q_w, num_bits, size_k, size_n):\n    if num_bits == 4:\n        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])\n    elif num_bits == 8:\n        interleave = np.array([0, 2, 1, 3])\n    else:\n        raise Exception(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()\n    q_w = q_w.reshape((-1, size_n)).contiguous()\n    return pack_cols(q_w, num_bits, size_k, size_n)\n\n\ndef make_moe_weights(num_experts, size_k, size_n, num_bits, group_size):\n    pack_factor = 32 // num_bits\n    b_q_weight = torch.empty(\n        (num_experts, size_k, size_n // pack_factor),\n        dtype=torch.int32,\n        device=\"cuda\",\n    )\n    for e in range(num_experts):\n        b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device=\"cuda\")\n        w_ref, q_w, s, zp = quantize_weights(\n            b_weight, scalar_types.uint4, min(group_size, size_k), zero_points=True\n        )\n        b_q_weight[e] = awq_pack(q_w, num_bits, size_k, size_n)\n    perm = torch.empty((num_experts, 0), dtype=torch.int32, device=\"cuda\")\n    return b_q_weight, perm\n\n\ndef check_correctness():\n    if not AOT_AVAILABLE:\n        print(\"sgl_kernel AOT not available, skipping correctness check\")\n        return\n\n    num_experts = 4\n    size_k = 1024\n    b_q_weight, perm = 
make_moe_weights(\n        num_experts, size_k, SIZE_N, NUM_BITS, GROUP_SIZE\n    )\n\n    out_jit = jit_awq_marlin_moe_repack(b_q_weight, perm, size_k, SIZE_N, NUM_BITS)\n    out_aot = torch.ops.sgl_kernel.awq_marlin_moe_repack.default(\n        b_q_weight, perm, size_k, SIZE_N, NUM_BITS\n    )\n    torch.cuda.synchronize()\n    torch.testing.assert_close(out_jit, out_aot, rtol=0, atol=0)\n    print(\"Correctness check passed (JIT vs AOT)\")\n\n\nif IS_CI:\n    expert_range = [2, 4]\nelse:\n    expert_range = [2, 4, 8, 16]\n\nif AOT_AVAILABLE:\n    line_vals = [\"jit\", \"aot\"]\n    line_names = [\"JIT Kernel\", \"AOT Kernel\"]\n    styles = [(\"blue\", \"-\"), (\"green\", \"-\")]\nelse:\n    line_vals = [\"jit\"]\n    line_names = [\"JIT Kernel\"]\n    styles = [(\"blue\", \"-\")]\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"num_experts\"],\n        x_vals=expert_range,\n        line_arg=\"provider\",\n        line_vals=line_vals,\n        line_names=line_names,\n        styles=styles,\n        ylabel=\"us\",\n        plot_name=\"awq-marlin-moe-repack-performance\",\n        args={\"size_k\": 4096, \"size_n\": SIZE_N, \"num_bits\": NUM_BITS},\n    )\n)\ndef benchmark(num_experts, size_k, size_n, num_bits, provider):\n    group_size = min(GROUP_SIZE, size_k)\n    b_q_weight, perm = make_moe_weights(\n        num_experts, size_k, size_n, num_bits, group_size\n    )\n\n    if provider == \"jit\":\n        fn = lambda: jit_awq_marlin_moe_repack(\n            b_q_weight, perm, size_k, size_n, num_bits\n        )\n    elif provider == \"aot\":\n        fn = lambda: torch.ops.sgl_kernel.awq_marlin_moe_repack.default(\n            b_q_weight, perm, size_k, size_n, num_bits\n        )\n    else:\n        raise ValueError(f\"Unknown provider: {provider}\")\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    check_correctness()\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_awq_marlin_repack.py",
    "content": "import numpy as np\nimport torch\nimport triton\nimport triton.testing\nfrom sgl_kernel.scalar_type import scalar_types\n\nfrom sglang.jit_kernel.awq_marlin_repack import (\n    awq_marlin_repack as jit_awq_marlin_repack,\n)\nfrom sglang.jit_kernel.benchmark.utils import is_in_ci, run_benchmark\nfrom sglang.srt.layers.quantization.utils import pack_cols, quantize_weights\n\nAOT_AVAILABLE = hasattr(torch.ops.sgl_kernel, \"awq_marlin_repack\") and hasattr(\n    torch.ops.sgl_kernel.awq_marlin_repack, \"default\"\n)\n\nIS_CI = is_in_ci()\n\nSIZE_K = 4096\nSIZE_N = 4096\nNUM_BITS = 4\nGROUP_SIZE = 128\n\n\ndef awq_pack(q_w, num_bits, size_k, size_n):\n    if num_bits == 4:\n        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])\n    elif num_bits == 8:\n        interleave = np.array([0, 2, 1, 3])\n    else:\n        raise Exception(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()\n    q_w = q_w.reshape((-1, size_n)).contiguous()\n    return pack_cols(q_w, num_bits, size_k, size_n)\n\n\n_b_weight = torch.randn((SIZE_K, SIZE_N), dtype=torch.float16, device=\"cuda\")\n_w_ref, _q_w, _s, _zp = quantize_weights(\n    _b_weight, scalar_types.uint4, GROUP_SIZE, zero_points=True\n)\n_q_w_awq = awq_pack(_q_w, NUM_BITS, SIZE_K, SIZE_N)\n\n\ndef check_correctness():\n    if not AOT_AVAILABLE:\n        print(\"sgl_kernel AOT not available, skipping correctness check\")\n        return\n    out_jit = jit_awq_marlin_repack(_q_w_awq, SIZE_K, SIZE_N, NUM_BITS)\n    out_aot = torch.ops.sgl_kernel.awq_marlin_repack.default(\n        _q_w_awq, SIZE_K, SIZE_N, NUM_BITS\n    )\n    torch.cuda.synchronize()\n    torch.testing.assert_close(out_jit, out_aot, rtol=0, atol=0)\n    print(\"Correctness check passed (JIT vs AOT)\")\n\n\nif IS_CI:\n    k_range = [1024, 4096]\nelse:\n    k_range = [512, 1024, 2048, 4096, 8192]\n\nif AOT_AVAILABLE:\n    line_vals = [\"jit\", \"aot\"]\n    line_names = [\"JIT 
Kernel\", \"AOT Kernel\"]\n    styles = [(\"blue\", \"-\"), (\"green\", \"-\")]\nelse:\n    line_vals = [\"jit\"]\n    line_names = [\"JIT Kernel\"]\n    styles = [(\"blue\", \"-\")]\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"size_k\"],\n        x_vals=k_range,\n        line_arg=\"provider\",\n        line_vals=line_vals,\n        line_names=line_names,\n        styles=styles,\n        ylabel=\"us\",\n        plot_name=\"awq-marlin-repack-performance\",\n        args={\"size_n\": SIZE_N, \"num_bits\": NUM_BITS},\n    )\n)\ndef benchmark(size_k, size_n, num_bits, provider):\n    group_size = min(GROUP_SIZE, size_k)\n\n    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device=\"cuda\")\n    w_ref, q_w, s, zp = quantize_weights(\n        b_weight, scalar_types.uint4, group_size, zero_points=True\n    )\n    q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)\n\n    if provider == \"jit\":\n        fn = lambda: jit_awq_marlin_repack(q_w_awq, size_k, size_n, num_bits)\n    elif provider == \"aot\":\n        fn = lambda: torch.ops.sgl_kernel.awq_marlin_repack.default(\n            q_w_awq, size_k, size_n, num_bits\n        )\n    else:\n        raise ValueError(f\"Unknown provider: {provider}\")\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    check_correctness()\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_concat_mla.py",
    "content": "import itertools\n\nimport torch\nimport triton\nimport triton.testing\nfrom sgl_kernel import concat_mla_absorb_q as aot_absorb_q\nfrom sgl_kernel import concat_mla_k as aot_k\n\nfrom sglang.jit_kernel.benchmark.utils import is_in_ci, run_benchmark\nfrom sglang.jit_kernel.concat_mla import concat_mla_absorb_q as jit_absorb_q\nfrom sglang.jit_kernel.concat_mla import concat_mla_k as jit_k\n\nIS_CI = is_in_ci()\n\nNUM_LOCAL_HEADS = 128\nQK_NOPE_HEAD_DIM = 128\nQK_ROPE_HEAD_DIM = 64\nK_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM\n\nA_LAST_DIM = 512\nB_LAST_DIM = 64\n\nDTYPE = torch.bfloat16\nDEVICE = \"cuda\"\n\n\ndef aot_concat_mla_k(k, k_nope, k_rope):\n    aot_k(k, k_nope, k_rope)\n\n\ndef jit_concat_mla_k(k, k_nope, k_rope):\n    jit_k(k, k_nope, k_rope)\n\n\ndef torch_concat_mla_k(k, k_nope, k_rope):\n    nope_head_dim = k_nope.shape[-1]\n    k[:, :, :nope_head_dim] = k_nope\n    k[:, :, nope_head_dim:] = k_rope.expand(-1, k.shape[1], -1)\n\n\ndef aot_concat_mla_absorb_q(a, b):\n    return aot_absorb_q(a, b)\n\n\ndef jit_concat_mla_absorb_q(a, b):\n    return jit_absorb_q(a, b)\n\n\ndef torch_concat_mla_absorb_q(a, b, out):\n    a_last_dim = a.shape[-1]\n    out[:, :, :a_last_dim] = a\n    out[:, :, a_last_dim:] = b\n\n\nif IS_CI:\n    NUM_TOKENS_VALS = [256, 1024]\nelse:\n    NUM_TOKENS_VALS = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768]\n\nK_LINE_VALS = [\"aot\", \"jit\", \"torch\"]\nK_LINE_NAMES = [\"SGL AOT Kernel\", \"SGL JIT Kernel\", \"PyTorch\"]\nK_STYLES = [(\"orange\", \"-\"), (\"blue\", \"--\"), (\"green\", \"-.\")]\n\n\ndef _create_concat_mla_k_data(num_tokens):\n    \"\"\"Allocate oversized containers and slice to produce non-contiguous tensors.\"\"\"\n    k_nope_container = torch.randn(\n        (num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM + 128),\n        dtype=DTYPE,\n        device=DEVICE,\n    )\n    k_nope = k_nope_container[:, :, :QK_NOPE_HEAD_DIM]\n\n    k_rope_container = torch.randn(\n        (num_tokens, 1, 128 + 
QK_ROPE_HEAD_DIM),\n        dtype=DTYPE,\n        device=DEVICE,\n    )\n    k_rope = k_rope_container[:, :, -QK_ROPE_HEAD_DIM:]\n\n    k = torch.empty(\n        (num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM),\n        dtype=DTYPE,\n        device=DEVICE,\n    )\n    return k, k_nope, k_rope\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"num_tokens\"],\n        x_vals=NUM_TOKENS_VALS,\n        line_arg=\"provider\",\n        line_vals=K_LINE_VALS,\n        line_names=K_LINE_NAMES,\n        styles=K_STYLES,\n        ylabel=\"us\",\n        plot_name=\"concat-mla-k-performance\",\n        args={},\n    )\n)\ndef bench_concat_mla_k(num_tokens: int, provider: str):\n    k, k_nope, k_rope = _create_concat_mla_k_data(num_tokens)\n\n    FN_MAP = {\n        \"aot\": aot_concat_mla_k,\n        \"jit\": jit_concat_mla_k,\n        \"torch\": torch_concat_mla_k,\n    }\n    fn = lambda: FN_MAP[provider](k, k_nope, k_rope)\n    return run_benchmark(fn)\n\n\nif IS_CI:\n    ABSORB_Q_VALS = list(itertools.product([4, 16], [16]))\nelse:\n    ABSORB_Q_VALS = list(itertools.product([1, 4, 8, 16, 32], [1, 8, 32, 128]))\n\nQ_LINE_VALS = [\"aot\", \"jit\", \"torch\"]\nQ_LINE_NAMES = [\"SGL AOT Kernel\", \"SGL JIT Kernel\", \"PyTorch\"]\nQ_STYLES = [(\"orange\", \"-\"), (\"blue\", \"--\"), (\"green\", \"-.\")]\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"dim_0\", \"dim_1\"],\n        x_vals=ABSORB_Q_VALS,\n        line_arg=\"provider\",\n        line_vals=Q_LINE_VALS,\n        line_names=Q_LINE_NAMES,\n        styles=Q_STYLES,\n        ylabel=\"us\",\n        plot_name=\"concat-mla-absorb-q-performance\",\n        args={},\n    )\n)\ndef bench_concat_mla_absorb_q(dim_0: int, dim_1: int, provider: str):\n    a = torch.randn(dim_0, dim_1, A_LAST_DIM, dtype=DTYPE, device=DEVICE)\n    b = torch.randn(dim_0, dim_1, B_LAST_DIM, dtype=DTYPE, device=DEVICE)\n\n    if provider == \"torch\":\n        out = torch.empty(\n      
      dim_0, dim_1, A_LAST_DIM + B_LAST_DIM, dtype=DTYPE, device=DEVICE\n        )\n        fn = lambda: torch_concat_mla_absorb_q(a, b, out)\n    else:\n        FN_MAP = {\n            \"aot\": aot_concat_mla_absorb_q,\n            \"jit\": jit_concat_mla_absorb_q,\n        }\n        fn = lambda: FN_MAP[provider](a, b)\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    bench_concat_mla_k.run(print_data=True)\n    bench_concat_mla_absorb_q.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py",
    "content": "import itertools\n\nimport torch\nimport triton\nimport triton.testing\nfrom flashinfer import fused_add_rmsnorm as fi_fused_add_rmsnorm\n\nfrom sglang.jit_kernel.benchmark.utils import is_in_ci, run_benchmark\nfrom sglang.jit_kernel.norm import fused_add_rmsnorm as jit_fused_add_rmsnorm\n\nIS_CI = is_in_ci()\n\n\ndef sglang_jit_fused_add_rmsnorm(\n    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float\n) -> None:\n    jit_fused_add_rmsnorm(input, residual, weight, eps)\n\n\ndef flashinfer_fused_add_rmsnorm(\n    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float\n) -> None:\n    fi_fused_add_rmsnorm(input, residual, weight, eps=eps)\n\n\nDTYPE = torch.bfloat16\nDEVICE = \"cuda\"\n\nif IS_CI:\n    BS_LIST = [16]\n    HIDDEN_SIZE_LIST = [512, 2048]\nelse:\n    BS_LIST = [2**n for n in range(0, 14)]\n    HIDDEN_SIZE_LIST = [1536, 3072, 4096, 5120, 8192]\n\nLINE_VALS = [\"jit\", \"flashinfer\"]\nLINE_NAMES = [\"SGL JIT Kernel\", \"FlashInfer\"]\nSTYLES = [(\"orange\", \"-\"), (\"blue\", \"--\"), (\"green\", \"-.\"), (\"red\", \":\")]\n\nconfigs = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"hidden_size\", \"batch_size\"],\n        x_vals=configs,\n        line_arg=\"provider\",\n        line_vals=LINE_VALS,\n        line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        plot_name=\"fused-add-rmsnorm-performance\",\n        args={},\n    )\n)\ndef benchmark(hidden_size: int, batch_size: int, provider: str):\n    input = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE)\n    residual = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE)\n    weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE)\n    FN_MAP = {\n        \"jit\": sglang_jit_fused_add_rmsnorm,\n        \"flashinfer\": flashinfer_fused_add_rmsnorm,\n    }\n    fn = lambda: 
FN_MAP[provider](\n        input.clone(), residual.clone(), weight, torch.finfo(torch.bfloat16).eps\n    )\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_fused_norm_scale_shift.py",
    "content": "# Benchmarks SGLang fused layernorm/rmsnorm scale shift kernels\n# 1. fused_norm_scale_shift\n# 2. fused_scale_residual_norm_scale_shift\nimport itertools\nfrom typing import Tuple\n\nimport torch\nimport triton\nimport triton.testing\n\nfrom sglang.jit_kernel.benchmark.utils import is_in_ci, run_benchmark_no_cudagraph\nfrom sglang.multimodal_gen.runtime.layers.layernorm import (\n    LayerNormScaleShift,\n    RMSNormScaleShift,\n    ScaleResidualLayerNormScaleShift,\n    ScaleResidualRMSNormScaleShift,\n)\n\nif is_in_ci():\n    B_RANGE, S_RANGE, D_RANGE = [1], [128], [1024]\nelse:\n    B_RANGE, S_RANGE, D_RANGE = [1], [128, 1024, 4096], [1024, 3072, 4096]\n\nNORM_TYPE_RANGE = [\"layer\", \"rms\"]\nAFFINE_RANGE = [True, False]\nDTYPE = torch.bfloat16\nDEVICE = \"cuda\"\nEPS = 1e-5\nLINE_VALS = [\"native\", \"cuda\"]\nLINE_NAMES = [\"SGLang Native\", \"SGLang Fused\"]\nSTYLES = [(\"red\", \"-\"), (\"blue\", \"--\")]\nconfig = list(\n    itertools.product(B_RANGE, S_RANGE, D_RANGE, NORM_TYPE_RANGE, AFFINE_RANGE)\n)\n\n\ndef preprocess_layer(layer, affine: bool, D: int, DTYPE: torch.dtype):\n    if affine:\n        weight = torch.randn(D, dtype=DTYPE, device=DEVICE)\n        bias = torch.randn(D, dtype=DTYPE, device=DEVICE)\n        with torch.no_grad():\n            layer.norm.weight.copy_(weight)\n            if hasattr(layer.norm, \"bias\"):\n                layer.norm.bias.copy_(bias)\n    layer.requires_grad_(False)\n    return layer.to(DEVICE)\n\n\n# ============================================================================\n# Benchmark 1: fused_norm_scale_shift\n# ============================================================================\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"B\", \"S\", \"D\", \"norm_type\", \"affine\"],\n        x_vals=config,\n        line_arg=\"provider\",\n        line_vals=LINE_VALS,\n        line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        
plot_name=\"fused_norm_scale_shift\",\n        args={},\n    )\n)\ndef bench_fused_norm_scale_shift(\n    B: int, S: int, D: int, norm_type, affine: bool, provider: str\n) -> Tuple[float, float, float]:\n    x = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE)\n    scale = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE)\n    shift = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE)\n    if norm_type == \"layer\":\n        layer = LayerNormScaleShift(D, EPS, affine, dtype=DTYPE)\n    else:\n        layer = RMSNormScaleShift(D, EPS, affine, dtype=DTYPE)\n    layer = preprocess_layer(layer, affine, D, DTYPE)\n    if provider == \"native\":\n        fn = lambda: layer.forward_native(x, shift, scale)\n    else:\n        fn = lambda: layer.forward_cuda(x, shift, scale)\n\n    return run_benchmark_no_cudagraph(fn)\n\n\n# ============================================================================\n# Benchmark 2: fused_scale_residual_norm_scale_shift\n# ============================================================================\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"B\", \"S\", \"D\", \"norm_type\", \"affine\"],\n        x_vals=config,\n        line_arg=\"provider\",\n        line_vals=LINE_VALS,\n        line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        plot_name=\"fused_scale_residual_norm_scale_shift\",\n        args={},\n    )\n)\ndef bench_fused_scale_residual_norm_scale_shift(\n    B: int, S: int, D: int, norm_type, affine: bool, provider: str\n) -> Tuple[float, float, float]:\n    residual = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE)\n    x = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE)\n    scale = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE)\n    shift = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE)\n    gate = torch.randn(B, 1, D, dtype=DTYPE, device=DEVICE)\n    if norm_type == \"layer\":\n        layer = ScaleResidualLayerNormScaleShift(D, EPS, affine, 
dtype=DTYPE).to(DEVICE)\n    else:\n        layer = ScaleResidualRMSNormScaleShift(D, EPS, affine, dtype=DTYPE).to(DEVICE)\n    layer = preprocess_layer(layer, affine, D, DTYPE)\n    if provider == \"native\":\n        fn = lambda: layer.forward_native(residual, x, gate, shift, scale)\n    else:\n        fn = lambda: layer.forward_cuda(residual, x, gate, shift, scale)\n\n    return run_benchmark_no_cudagraph(fn)\n\n\nif __name__ == \"__main__\":\n    print(f\"\\n{'='*80}\")\n    print(\"Benchmark: fused_norm_scale_shift\")\n    print(f\"{'='*80}\\n\")\n    bench_fused_norm_scale_shift.run(print_data=True)\n\n    print(f\"\\n{'='*80}\")\n    print(\"Benchmark: fused_scale_residual_norm_scale_shift\")\n    print(f\"{'='*80}\\n\")\n    bench_fused_scale_residual_norm_scale_shift.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_gptq_marlin.py",
    "content": "import torch\nimport triton\nimport triton.testing\nfrom sgl_kernel.scalar_type import scalar_types\n\nfrom sglang.jit_kernel.benchmark.utils import is_in_ci, run_benchmark\nfrom sglang.jit_kernel.gptq_marlin import gptq_marlin_gemm as jit_gptq_marlin_gemm\nfrom sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace\nfrom sglang.test.test_marlin_utils import marlin_quantize\n\nAOT_AVAILABLE = hasattr(torch.ops.sgl_kernel, \"gptq_marlin_gemm\") and hasattr(\n    torch.ops.sgl_kernel.gptq_marlin_gemm, \"default\"\n)\n\nIS_CI = is_in_ci()\n\nSIZE_K = 4096\nSIZE_N = 4096\nGROUP_SIZE = 128\nQUANT_TYPE = scalar_types.uint4b8\n\n_b_weight = torch.randn((SIZE_K, SIZE_N), dtype=torch.float16, device=\"cuda\")\n_w_ref, _marlin_q_w, _marlin_s, _g_idx, _sort_indices, _ = marlin_quantize(\n    _b_weight, QUANT_TYPE, GROUP_SIZE, act_order=False\n)\n_workspace = marlin_make_workspace(_w_ref.device)\n\n\ndef _run_gemm(fn, a):\n    return fn(\n        a,\n        None,\n        _marlin_q_w,\n        _marlin_s,\n        None,\n        None,\n        _g_idx,\n        _sort_indices,\n        _workspace,\n        QUANT_TYPE,\n        a.shape[0],\n        SIZE_N,\n        SIZE_K,\n        is_k_full=True,\n        use_atomic_add=False,\n        use_fp32_reduce=False,\n        is_zp_float=False,\n    )\n\n\ndef _run_gemm_aot(a):\n    return torch.ops.sgl_kernel.gptq_marlin_gemm.default(\n        a,\n        None,\n        _marlin_q_w,\n        _marlin_s,\n        None,\n        None,\n        _g_idx,\n        _sort_indices,\n        _workspace,\n        QUANT_TYPE.id,\n        a.shape[0],\n        SIZE_N,\n        SIZE_K,\n        True,\n        False,\n        False,\n        False,\n    )\n\n\ndef check_correctness():\n    if not AOT_AVAILABLE:\n        print(\"sgl_kernel AOT not available, skipping correctness check\")\n        return\n    a = torch.randn((16, SIZE_K), dtype=torch.float16, device=\"cuda\")\n    out_jit = _run_gemm(jit_gptq_marlin_gemm, 
a)\n    out_aot = _run_gemm_aot(a)\n    torch.testing.assert_close(out_jit, out_aot, rtol=1e-3, atol=1e-3)\n    print(\"Correctness check passed (JIT vs AOT)\")\n\n\nif IS_CI:\n    m_range = [1, 16, 128]\nelse:\n    m_range = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]\n\nif AOT_AVAILABLE:\n    line_vals = [\"jit\", \"aot\"]\n    line_names = [\"JIT Kernel\", \"AOT Kernel\"]\n    styles = [(\"blue\", \"-\"), (\"green\", \"-\")]\nelse:\n    line_vals = [\"jit\"]\n    line_names = [\"JIT Kernel\"]\n    styles = [(\"blue\", \"-\")]\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"size_m\"],\n        x_vals=m_range,\n        line_arg=\"provider\",\n        line_vals=line_vals,\n        line_names=line_names,\n        styles=styles,\n        ylabel=\"us\",\n        plot_name=\"gptq-marlin-gemm-performance\",\n        args={},\n    )\n)\ndef benchmark(size_m, provider):\n    device = torch.device(\"cuda\")\n    a = torch.randn((size_m, SIZE_K), dtype=torch.float16, device=device)\n\n    if provider == \"jit\":\n        fn = lambda: _run_gemm(jit_gptq_marlin_gemm, a)\n    elif provider == \"aot\":\n        fn = lambda: _run_gemm_aot(a)\n    else:\n        raise ValueError(f\"Unknown provider: {provider}\")\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    check_correctness()\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_gptq_marlin_repack.py",
    "content": "import torch\nimport triton\nimport triton.testing\nfrom sgl_kernel.scalar_type import scalar_types\n\nfrom sglang.jit_kernel.benchmark.utils import is_in_ci, run_benchmark\nfrom sglang.jit_kernel.gptq_marlin_repack import gptq_marlin_repack as jit_fn\nfrom sglang.srt.layers.quantization.utils import gptq_quantize_weights, pack_rows\n\nAOT_AVAILABLE = hasattr(torch.ops.sgl_kernel, \"gptq_marlin_repack\") and hasattr(\n    torch.ops.sgl_kernel.gptq_marlin_repack, \"default\"\n)\n\nIS_CI = is_in_ci()\n\nSIZE_N = 4096\nNUM_BITS = 4\nQUANT_TYPE = scalar_types.uint4b8\nGROUP_SIZE = 128\n\n_cache = {}\n\n\ndef _get_inputs(size_k):\n    if size_k not in _cache:\n        size_n = SIZE_N\n        b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device=\"cuda\")\n        _, q_w, _, _, _ = gptq_quantize_weights(\n            b_weight, QUANT_TYPE, GROUP_SIZE, act_order=False\n        )\n        q_w_gptq = pack_rows(q_w, NUM_BITS, size_k, size_n)\n        sort_indices = torch.empty(0, dtype=torch.int, device=\"cuda\")\n        _cache[size_k] = (q_w_gptq, sort_indices)\n    return _cache[size_k]\n\n\ndef check_correctness():\n    if not AOT_AVAILABLE:\n        print(\"sgl_kernel AOT not available, skipping correctness check\")\n        return\n    size_k = 4096\n    q_w_gptq, sort_indices = _get_inputs(size_k)\n    out_jit = jit_fn(q_w_gptq, sort_indices, size_k, SIZE_N, NUM_BITS)\n    out_aot = torch.ops.sgl_kernel.gptq_marlin_repack.default(\n        q_w_gptq, sort_indices, size_k, SIZE_N, NUM_BITS\n    )\n    torch.testing.assert_close(out_jit, out_aot, rtol=0, atol=0)\n    print(\"Correctness check passed (JIT vs AOT)\")\n\n\nif IS_CI:\n    k_range = [128, 1024, 4096]\nelse:\n    k_range = [128, 256, 512, 1024, 2048, 4096, 8192]\n\nif AOT_AVAILABLE:\n    line_vals = [\"jit\", \"aot\"]\n    line_names = [\"JIT Kernel\", \"AOT Kernel\"]\n    styles = [(\"blue\", \"-\"), (\"green\", \"-\")]\nelse:\n    line_vals = [\"jit\"]\n    line_names = [\"JIT 
Kernel\"]\n    styles = [(\"blue\", \"-\")]\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"size_k\"],\n        x_vals=k_range,\n        line_arg=\"provider\",\n        line_vals=line_vals,\n        line_names=line_names,\n        styles=styles,\n        ylabel=\"us\",\n        plot_name=\"gptq-marlin-repack-performance\",\n        args={},\n    )\n)\ndef benchmark(size_k, provider):\n    q_w_gptq, sort_indices = _get_inputs(size_k)\n\n    if provider == \"jit\":\n        fn = lambda: jit_fn(q_w_gptq, sort_indices, size_k, SIZE_N, NUM_BITS)\n    elif provider == \"aot\":\n        fn = lambda: torch.ops.sgl_kernel.gptq_marlin_repack.default(\n            q_w_gptq, sort_indices, size_k, SIZE_N, NUM_BITS\n        )\n    else:\n        raise ValueError(f\"Unknown provider: {provider}\")\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    check_correctness()\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_hadamard.py",
    "content": "import itertools\nimport math\nfrom typing import Tuple\n\nimport torch\nimport torch.nn.functional as F\nimport triton\nimport triton.testing\n\nfrom sglang.jit_kernel.benchmark.utils import (\n    DEFAULT_DEVICE,\n    DEFAULT_DTYPE,\n    get_benchmark_range,\n    run_benchmark,\n)\nfrom sglang.jit_kernel.hadamard import hadamard_transform\n\n# AOT kernel: might not be available in all environments.\n# This is used for performance baseline comparison.\ntry:\n    from sgl_kernel import hadamard_transform as hadamard_transform_aot\n\n    AOT_AVAILABLE = True\nexcept Exception:\n    AOT_AVAILABLE = False\n\n# Naive reference implementation using scipy hadamard matrix.\ntry:\n    from scipy.linalg import hadamard\n\n    SCIPY_AVAILABLE = True\nexcept ImportError:\n    SCIPY_AVAILABLE = False\n\n# CI environment uses simplified parameters\nbatch_sizes = get_benchmark_range(\n    full_range=[1, 16, 64, 256],\n    ci_range=[16],\n)\ndim_range = get_benchmark_range(\n    full_range=[64, 256, 1024, 4096, 8192, 16384, 32768],\n    ci_range=[1024],\n)\n\n\n# Naive reference implementation using precomputed scipy hadamard matrix.\ndef torch_hadamard_transform(x, scale, H, dim, dim_padded):\n    flat = x.reshape(-1, dim)\n    if dim != dim_padded:\n        flat = F.pad(flat, (0, dim_padded - dim))\n    out = F.linear(flat, H) * scale\n    return out[..., :dim].reshape(x.shape)\n\n\navailable_providers = [\"jit_kernel\"]\navailable_names = [\"JIT Kernel\"]\navailable_styles = [(\"red\", \"-\")]\n\nif AOT_AVAILABLE:\n    available_providers.insert(0, \"aot_kernel\")\n    available_names.insert(0, \"AOT Kernel\")\n    available_styles.insert(0, (\"green\", \"-\"))\n\nif SCIPY_AVAILABLE:\n    available_providers.append(\"naive\")\n    available_names.append(\"Naive (scipy)\")\n    available_styles.append((\"blue\", \"-\"))\n\nconfigs = list(itertools.product(batch_sizes, dim_range))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        
x_names=[\"batch_size\", \"dim\"],\n        x_vals=[list(c) for c in configs],\n        line_arg=\"provider\",\n        line_vals=available_providers,\n        line_names=available_names,\n        styles=available_styles,\n        ylabel=\"us\",\n        plot_name=\"hadamard-transform-performance\",\n        args={},\n    )\n)\ndef benchmark(batch_size: int, dim: int, provider: str) -> Tuple[float, float, float]:\n    scale = 1.0 / math.sqrt(dim)\n    x = torch.randn(batch_size, dim, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE)\n\n    FN_MAP = {\n        \"jit_kernel\": lambda: hadamard_transform(x.clone(), scale=scale),\n    }\n    if AOT_AVAILABLE:\n        FN_MAP[\"aot_kernel\"] = lambda: hadamard_transform_aot(x.clone(), scale=scale)\n    if SCIPY_AVAILABLE:\n        # Precompute Hadamard matrix on GPU to avoid CPU-GPU transfer\n        # during CUDA graph capture.\n        log_dim = math.ceil(math.log2(dim)) if dim > 0 else 0\n        dim_padded = 2**log_dim if dim > 0 else 1\n        H = torch.tensor(\n            hadamard(dim_padded, dtype=float),\n            dtype=DEFAULT_DTYPE,\n            device=DEFAULT_DEVICE,\n        )\n        FN_MAP[\"naive\"] = lambda: torch_hadamard_transform(\n            x.clone(), scale, H, dim, dim_padded\n        )\n\n    fn = FN_MAP[provider]\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    print(\"=\" * 80)\n    print(\"Benchmarking Fast Hadamard Transform\")\n    print(\"=\" * 80)\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_hicache.py",
    "content": "\"\"\"Benchmark for HiCache JIT kernel performance.\n\nThis benchmark tests the performance of KV cache transfer operations\nbetween GPU and CPU (host pinned memory), comparing:\n- SGL AOT Kernel: Pre-compiled transfer_kv kernels from sgl_kernel\n- SGL JIT Kernel: JIT-compiled hicache kernels\n- PyTorch Indexing: Plain PyTorch index copy\n- PyTorch 2 Stream: PyTorch implementation using 2 CUDA streams\n\nTests cover:\n- One Layer: CPU->GPU\n- All Layer: GPU->CPU\n\nNote: Uses do_bench instead of do_bench_cudagraph since CUDA graph\ncapture doesn't support CPU-GPU memory transfers.\n\"\"\"\n\nimport itertools\nimport os\nfrom dataclasses import dataclass\nfrom typing import Tuple\n\nimport torch\nimport triton\nimport triton.testing\nfrom sgl_kernel import transfer_kv_all_layer, transfer_kv_per_layer\n\nfrom sglang.jit_kernel.benchmark.utils import DEFAULT_QUANTILES, get_benchmark_range\nfrom sglang.jit_kernel.hicache import (\n    can_use_hicache_jit_kernel,\n    transfer_hicache_all_layer,\n    transfer_hicache_one_layer,\n)\n\nDISABLE_TORCH = os.environ.get(\"DISABLE_TORCH\", \"0\") == \"1\"\nPAGE_SIZE = 1\nENABLE_SORT = True\nGPU_CACHE_SIZE = 256 * 1024  # 256K tokens on GPU\nHOST_CACHE_SIZE = 512 * 1024  # 512K tokens on CPU\nNUM_LAYERS = 8\n\n\n@dataclass(frozen=True)\nclass HiCacheCache:\n    k_cache_cuda: torch.Tensor\n    v_cache_cuda: torch.Tensor\n    k_cache_host: torch.Tensor\n    v_cache_host: torch.Tensor\n\n    def get_slice(self, num_layers: int, element_size: int) -> \"HiCacheCache\":\n        def slice_cuda(t: torch.Tensor) -> torch.Tensor:\n            needed_cuda = num_layers * GPU_CACHE_SIZE\n            return t.view(-1, element_size)[:needed_cuda].unflatten(0, (num_layers, -1))\n\n        def slice_host(t: torch.Tensor) -> torch.Tensor:\n            needed_host = num_layers * HOST_CACHE_SIZE\n            return t.view(-1, element_size)[:needed_host].unflatten(0, (num_layers, -1))\n\n        return HiCacheCache(\n            
k_cache_cuda=slice_cuda(self.k_cache_cuda),\n            v_cache_cuda=slice_cuda(self.v_cache_cuda),\n            k_cache_host=slice_host(self.k_cache_host),\n            v_cache_host=slice_host(self.v_cache_host),\n        )\n\n\ndef gen_indices(\n    size: int, max_size: int, *, page_size: int = PAGE_SIZE\n) -> torch.Tensor:\n    def align(x: int) -> int:\n        return (x + page_size - 1) // page_size\n\n    assert size <= max_size and max_size % page_size == 0\n    indices = torch.randperm(align(max_size))[: align(size)]\n    offsets = torch.arange(page_size)\n    return (indices[:, None] * page_size + offsets).flatten().cuda()[:size]\n\n\ndef sglang_aot_transfer_one(\n    k_cache_dst: torch.Tensor,\n    v_cache_dst: torch.Tensor,\n    indices_dst: torch.Tensor,\n    k_cache_src: torch.Tensor,\n    v_cache_src: torch.Tensor,\n    indices_src: torch.Tensor,\n    item_size: int,\n) -> None:\n    \"\"\"SGL AOT Kernel for single layer transfer.\"\"\"\n    transfer_kv_per_layer(\n        k_cache_src,\n        k_cache_dst,\n        v_cache_src,\n        v_cache_dst,\n        indices_src,\n        indices_dst,\n        item_size,\n    )\n\n\ndef sglang_jit_transfer_one(\n    k_cache_dst: torch.Tensor,\n    v_cache_dst: torch.Tensor,\n    indices_dst: torch.Tensor,\n    k_cache_src: torch.Tensor,\n    v_cache_src: torch.Tensor,\n    indices_src: torch.Tensor,\n    element_dim: int,\n) -> None:\n    \"\"\"SGL JIT Kernel for single layer transfer.\"\"\"\n    transfer_hicache_one_layer(\n        k_cache_dst,\n        v_cache_dst,\n        indices_dst,\n        k_cache_src,\n        v_cache_src,\n        indices_src,\n        element_dim=element_dim,\n    )\n\n\ndef sglang_aot_transfer_all(\n    k_ptrs_dst: torch.Tensor,\n    v_ptrs_dst: torch.Tensor,\n    indices_dst: torch.Tensor,\n    k_ptrs_src: torch.Tensor,\n    v_ptrs_src: torch.Tensor,\n    indices_src: torch.Tensor,\n    item_size: int,\n    num_layers: int,\n) -> None:\n    \"\"\"SGL AOT Kernel for all layer 
transfer.\"\"\"\n    transfer_kv_all_layer(\n        k_ptrs_src,\n        k_ptrs_dst,\n        v_ptrs_src,\n        v_ptrs_dst,\n        indices_src,\n        indices_dst,\n        item_size,\n        num_layers,\n    )\n\n\ndef sglang_jit_transfer_all(\n    k_ptrs_dst: torch.Tensor,\n    v_ptrs_dst: torch.Tensor,\n    indices_dst: torch.Tensor,\n    k_ptrs_src: torch.Tensor,\n    v_ptrs_src: torch.Tensor,\n    indices_src: torch.Tensor,\n    stride_bytes: int,\n    element_size: int,\n) -> None:\n    \"\"\"SGL JIT Kernel for all layer transfer.\"\"\"\n    transfer_hicache_all_layer(\n        k_ptrs_dst,\n        v_ptrs_dst,\n        indices_dst,\n        k_ptrs_src,\n        v_ptrs_src,\n        indices_src,\n        kv_cache_src_stride_bytes=stride_bytes,\n        kv_cache_dst_stride_bytes=stride_bytes,\n        element_size=element_size,\n    )\n\n\ndef pytorch_transfer(\n    k_cache_dst: torch.Tensor,\n    v_cache_dst: torch.Tensor,\n    indices_dst_on_dst: torch.Tensor,\n    k_cache_src: torch.Tensor,\n    v_cache_src: torch.Tensor,\n    indices_src_on_src: torch.Tensor,\n) -> None:\n    \"\"\"PyTorch indexing baseline.\"\"\"\n    dst_device = k_cache_dst.device\n    k_cache_dst[indices_dst_on_dst] = k_cache_src[indices_src_on_src].to(dst_device)\n    v_cache_dst[indices_dst_on_dst] = v_cache_src[indices_src_on_src].to(dst_device)\n\n\n# Benchmark configuration\n\nBS_RANGE = get_benchmark_range(\n    full_range=[2**n for n in range(0, 16)],\n    ci_range=[16],\n)\nELEMENT_SIZE_RANGE = get_benchmark_range(\n    full_range=[64, 128, 256, 512, 1024],\n    ci_range=[1024],\n)\n\nLINE_VALS = [\"aot\", \"jit\", \"torch\"]\nLINE_NAMES = [\"SGL AOT Kernel\", \"SGL JIT Kernel\", \"PyTorch\"]\nSTYLES = [(\"orange\", \"-\"), (\"blue\", \"--\"), (\"red\", \":\")]\n\nCONFIGS = list(itertools.product(ELEMENT_SIZE_RANGE, BS_RANGE))\n\n\n# =============================================================================\n# One Layer Benchmarks\n# 
=============================================================================\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"element_size\", \"batch_size\"],\n        x_vals=CONFIGS,\n        line_arg=\"provider\",\n        line_vals=LINE_VALS,\n        line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        plot_name=\"hicache-one-layer-h2d\",\n        args={},\n    )\n)\ndef benchmark_one_layer_h2d(\n    element_size: int, batch_size: int, provider: str\n) -> Tuple[float, float, float]:\n    \"\"\"One Layer: Host (CPU) -> Device (GPU).\"\"\"\n    global cache\n    cache_local = cache.get_slice(num_layers=NUM_LAYERS, element_size=element_size)\n    k_cache_src = cache_local.k_cache_host\n    v_cache_src = cache_local.v_cache_host\n    k_cache_dst = cache_local.k_cache_cuda\n    v_cache_dst = cache_local.v_cache_cuda\n    torch.manual_seed(batch_size * 65536 + element_size)\n    indices_src_gpu = gen_indices(batch_size, HOST_CACHE_SIZE)\n    indices_dst_gpu = gen_indices(batch_size, GPU_CACHE_SIZE)\n\n    if ENABLE_SORT:\n        indices_src_gpu, mapping = indices_src_gpu.sort()\n        indices_dst_gpu = indices_dst_gpu[mapping]\n    indices_src_cpu = indices_src_gpu.cpu()\n    torch.cuda.synchronize()\n\n    element_bytes = element_size * k_cache_src.element_size()\n\n    FN_MAP = {\n        \"aot\": lambda: [\n            sglang_aot_transfer_one(\n                k_cache_dst[i],\n                v_cache_dst[i],\n                indices_dst_gpu,\n                k_cache_src[i],\n                v_cache_src[i],\n                indices_src_gpu,\n                element_bytes,\n            )\n            for i in range(NUM_LAYERS)\n        ],\n        \"jit\": lambda: [\n            sglang_jit_transfer_one(\n                k_cache_dst[i],\n                v_cache_dst[i],\n                indices_dst_gpu,\n                k_cache_src[i],\n                v_cache_src[i],\n                
indices_src_gpu,\n                element_size,\n            )\n            for i in range(NUM_LAYERS)\n        ],\n        \"torch\": lambda: [\n            pytorch_transfer(\n                k_cache_dst[i],\n                v_cache_dst[i],\n                indices_dst_gpu,\n                k_cache_src[i],\n                v_cache_src[i],\n                indices_src_cpu,\n            )\n            for i in range(NUM_LAYERS)\n        ],\n    }\n\n    if provider == \"jit\" and not can_use_hicache_jit_kernel(element_size=element_bytes):\n        return (float(\"nan\"), float(\"nan\"), float(\"nan\"))\n\n    if DISABLE_TORCH and provider in [\"torch\"]:\n        return (float(\"nan\"), float(\"nan\"), float(\"nan\"))\n\n    ms, min_ms, max_ms = triton.testing.do_bench(  # type: ignore\n        FN_MAP[provider], quantiles=DEFAULT_QUANTILES, warmup=5, rep=25\n    )\n    return (\n        1000 * ms / NUM_LAYERS,\n        1000 * max_ms / NUM_LAYERS,\n        1000 * min_ms / NUM_LAYERS,\n    )\n\n\n# =============================================================================\n# All Layer Benchmarks\n# =============================================================================\n\n\ndef _create_ptr_tensor(tensors, device=\"cuda\"):\n    \"\"\"Create a tensor of data pointers.\"\"\"\n    return torch.tensor(\n        [t.data_ptr() for t in tensors],\n        dtype=torch.uint64,\n        device=device,\n    )\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"element_size\", \"batch_size\"],\n        x_vals=CONFIGS,\n        line_arg=\"provider\",\n        line_vals=LINE_VALS,\n        line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        plot_name=\"hicache-all-layer-d2h\",\n        args={},\n    )\n)\ndef benchmark_all_layer_d2h(\n    element_size: int, batch_size: int, provider: str\n) -> Tuple[float, float, float]:\n    \"\"\"All Layer: Device (GPU) -> Host (CPU).\"\"\"\n    global cache\n    cache_local = 
cache.get_slice(num_layers=NUM_LAYERS, element_size=element_size)\n    k_caches_src = cache_local.k_cache_cuda\n    v_caches_src = cache_local.v_cache_cuda\n    k_caches_dst = cache_local.k_cache_host\n    v_caches_dst = cache_local.v_cache_host\n    torch.manual_seed(batch_size * 65536 + element_size)\n\n    indices_src_gpu = gen_indices(batch_size, GPU_CACHE_SIZE)\n    indices_dst_gpu = gen_indices(batch_size, HOST_CACHE_SIZE)\n    if ENABLE_SORT:\n        indices_dst_gpu, mapping = indices_dst_gpu.sort()\n        indices_src_gpu = indices_src_gpu[mapping]\n    indices_dst_cpu = indices_dst_gpu.cpu()\n    torch.cuda.synchronize()\n\n    element_bytes = element_size * k_caches_src.element_size()\n\n    k_ptrs_src = _create_ptr_tensor([k_caches_src[i] for i in range(NUM_LAYERS)])\n    v_ptrs_src = _create_ptr_tensor([v_caches_src[i] for i in range(NUM_LAYERS)])\n    k_ptrs_dst = _create_ptr_tensor([k_caches_dst[i] for i in range(NUM_LAYERS)])\n    v_ptrs_dst = _create_ptr_tensor([v_caches_dst[i] for i in range(NUM_LAYERS)])\n\n    FN_MAP = {\n        \"aot\": lambda: sglang_aot_transfer_all(\n            k_ptrs_dst,\n            v_ptrs_dst,\n            indices_dst_gpu,\n            k_ptrs_src,\n            v_ptrs_src,\n            indices_src_gpu,\n            element_bytes,\n            NUM_LAYERS,\n        ),\n        \"jit\": lambda: sglang_jit_transfer_all(\n            k_ptrs_dst,\n            v_ptrs_dst,\n            indices_dst_gpu,\n            k_ptrs_src,\n            v_ptrs_src,\n            indices_src_gpu,\n            element_bytes,\n            element_bytes,\n        ),\n        \"torch\": lambda: [\n            pytorch_transfer(\n                k_caches_dst[i],\n                v_caches_dst[i],\n                indices_dst_cpu,\n                k_caches_src[i],\n                v_caches_src[i],\n                indices_src_gpu,\n            )\n            for i in range(NUM_LAYERS)\n        ],\n    }\n\n    if provider == \"jit\" and not 
can_use_hicache_jit_kernel(element_size=element_bytes):\n        return (float(\"nan\"), float(\"nan\"), float(\"nan\"))\n\n    if DISABLE_TORCH and provider in [\"torch\"]:\n        return (float(\"nan\"), float(\"nan\"), float(\"nan\"))\n\n    ms, min_ms, max_ms = triton.testing.do_bench(  # type: ignore\n        FN_MAP[provider], quantiles=DEFAULT_QUANTILES, warmup=5, rep=25\n    )\n    return (\n        1000 * ms / NUM_LAYERS,\n        1000 * max_ms / NUM_LAYERS,\n        1000 * min_ms / NUM_LAYERS,\n    )\n\n\nif __name__ == \"__main__\":\n    MAX_SIZE = max(ELEMENT_SIZE_RANGE)\n    DEVICE_SHAPE = (NUM_LAYERS * GPU_CACHE_SIZE, MAX_SIZE)\n    HOST_SHAPE = (NUM_LAYERS * HOST_CACHE_SIZE, MAX_SIZE)\n\n    cache = HiCacheCache(\n        k_cache_cuda=torch.empty(DEVICE_SHAPE, dtype=torch.bfloat16, device=\"cuda\"),\n        v_cache_cuda=torch.empty(DEVICE_SHAPE, dtype=torch.bfloat16, device=\"cuda\"),\n        k_cache_host=torch.empty(HOST_SHAPE, dtype=torch.bfloat16, pin_memory=True),\n        v_cache_host=torch.empty(HOST_SHAPE, dtype=torch.bfloat16, pin_memory=True),\n    )\n\n    print(\"=\" * 60)\n    print(\"One Layer: Host -> Device (CPU -> GPU)\")\n    print(\"=\" * 60)\n    benchmark_one_layer_h2d.run(print_data=True)\n\n    print(\"\\n\" + \"=\" * 60)\n    print(\"All Layer: Device -> Host (GPU -> CPU) [per-layer avg]\")\n    print(\"=\" * 60)\n    benchmark_all_layer_d2h.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_moe_wna16_marlin.py",
    "content": "import torch\nimport triton\nimport triton.testing\nfrom sgl_kernel.scalar_type import scalar_types\n\nfrom sglang.jit_kernel.benchmark.utils import is_in_ci, run_benchmark\nfrom sglang.jit_kernel.moe_wna16_marlin import moe_wna16_marlin_gemm as jit_fn\nfrom sglang.srt.layers.moe.fused_moe_triton import moe_align_block_size\nfrom sglang.test.test_marlin_utils import marlin_quantize\n\nAOT_AVAILABLE = hasattr(torch.ops.sgl_kernel, \"moe_wna16_marlin_gemm\") and hasattr(\n    torch.ops.sgl_kernel.moe_wna16_marlin_gemm, \"default\"\n)\n\nIS_CI = is_in_ci()\n\n\ndef stack_and_dev(tensors):\n    dev = tensors[0].device\n    return torch.stack(tensors, dim=0).to(dev)\n\n\nE = 8\nSIZE_K = 4096\nSIZE_N = 4096\nGROUP_SIZE = 128\nTOPK = 2\nQUANT_TYPE = scalar_types.uint4b8\nDTYPE = torch.float16\nBLOCK_SIZE_M = 64\n\ntorch.manual_seed(0)\n_qweight_l, _scales_l, _w_ref_l = [], [], []\nfor i in range(E):\n    _w = torch.randn((SIZE_N, SIZE_K), dtype=DTYPE, device=\"cuda\") / 20\n    _perm = torch.randperm(SIZE_K)\n    _w_ref, _qw, _s, _, _, _ = marlin_quantize(_w, QUANT_TYPE, GROUP_SIZE, False, _perm)\n    _w_ref_l.append(_w_ref.T)\n    _qweight_l.append(_qw)\n    _scales_l.append(_s)\n\n_qweight = stack_and_dev(_qweight_l).contiguous()\n_scales = stack_and_dev(_scales_l)\n\n_sms = torch.cuda.get_device_properties(\"cuda\").multi_processor_count\n\n\ndef _make_inputs(size_m):\n    a = torch.randn((size_m, SIZE_K), dtype=DTYPE, device=\"cuda\") / 10\n    score = torch.randn((size_m, E), dtype=DTYPE, device=\"cuda\")\n    score_softmax = torch.softmax(score, dim=-1, dtype=torch.float32)\n    topk_weights, topk_ids = torch.topk(score_softmax, TOPK)\n\n    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(\n        topk_ids, BLOCK_SIZE_M, E\n    )\n\n    max_workspace_size = (SIZE_N // 64) * (sorted_token_ids.size(0) // BLOCK_SIZE_M)\n    max_workspace_size = min(max_workspace_size, _sms * 4)\n    workspace = 
torch.zeros(max_workspace_size, dtype=torch.int, device=\"cuda\")\n\n    c = torch.empty((size_m * TOPK, SIZE_N), dtype=DTYPE, device=\"cuda\")\n\n    return (\n        a,\n        c,\n        topk_weights,\n        topk_ids,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_padded,\n        workspace,\n    )\n\n\ndef _run_jit(\n    a,\n    c,\n    topk_weights,\n    sorted_token_ids,\n    expert_ids,\n    num_tokens_post_padded,\n    workspace,\n    size_m,\n):\n    return jit_fn(\n        a,\n        c,\n        _qweight,\n        None,\n        _scales,\n        None,\n        None,\n        None,\n        None,\n        workspace,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_padded,\n        topk_weights,\n        moe_block_size=BLOCK_SIZE_M,\n        top_k=TOPK,\n        mul_topk_weights=False,\n        is_ep=False,\n        b_q_type=QUANT_TYPE,\n        size_m=size_m,\n        size_n=SIZE_N,\n        size_k=SIZE_K,\n        is_k_full=True,\n        use_atomic_add=True,\n        use_fp32_reduce=True,\n        is_zp_float=False,\n    )\n\n\ndef _run_aot(\n    a,\n    c,\n    topk_weights,\n    sorted_token_ids,\n    expert_ids,\n    num_tokens_post_padded,\n    workspace,\n    size_m,\n):\n    return torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(\n        a,\n        c,\n        _qweight,\n        None,\n        _scales,\n        None,\n        None,\n        None,\n        None,\n        workspace,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_padded,\n        topk_weights,\n        moe_block_size=BLOCK_SIZE_M,\n        top_k=TOPK,\n        mul_topk_weights=False,\n        is_ep=False,\n        b_q_type_id=QUANT_TYPE.id,\n        size_m=size_m,\n        size_n=SIZE_N,\n        size_k=SIZE_K,\n        is_k_full=True,\n        use_atomic_add=True,\n        use_fp32_reduce=True,\n        is_zp_float=False,\n    )\n\n\ndef check_correctness():\n    if not AOT_AVAILABLE:\n        
print(\"sgl_kernel AOT not available, skipping correctness check\")\n        return\n    size_m = 16\n    a, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, ntp, workspace = (\n        _make_inputs(size_m)\n    )\n    c_jit = c.clone()\n    c_aot = c.clone()\n    _run_jit(\n        a, c_jit, topk_weights, sorted_token_ids, expert_ids, ntp, workspace, size_m\n    )\n    _run_aot(\n        a, c_aot, topk_weights, sorted_token_ids, expert_ids, ntp, workspace, size_m\n    )\n    torch.testing.assert_close(c_jit, c_aot, rtol=1e-3, atol=1e-3)\n    print(\"Correctness check passed (JIT vs AOT)\")\n\n\nif IS_CI:\n    m_range = [1, 16, 128]\nelse:\n    m_range = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]\n\nif AOT_AVAILABLE:\n    line_vals = [\"jit\", \"aot\"]\n    line_names = [\"JIT Kernel\", \"AOT Kernel\"]\n    styles = [(\"blue\", \"-\"), (\"green\", \"-\")]\nelse:\n    line_vals = [\"jit\"]\n    line_names = [\"JIT Kernel\"]\n    styles = [(\"blue\", \"-\")]\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"size_m\"],\n        x_vals=m_range,\n        line_arg=\"provider\",\n        line_vals=line_vals,\n        line_names=line_names,\n        styles=styles,\n        ylabel=\"us\",\n        plot_name=\"moe-wna16-marlin-gemm-performance\",\n        args={},\n    )\n)\ndef benchmark(size_m, provider):\n    a, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, ntp, workspace = (\n        _make_inputs(size_m)\n    )\n\n    if provider == \"jit\":\n        fn = lambda: _run_jit(\n            a,\n            c.clone(),\n            topk_weights,\n            sorted_token_ids,\n            expert_ids,\n            ntp,\n            workspace,\n            size_m,\n        )\n    elif provider == \"aot\":\n        fn = lambda: _run_aot(\n            a,\n            c.clone(),\n            topk_weights,\n            sorted_token_ids,\n            expert_ids,\n            ntp,\n            workspace,\n            size_m,\n        
)\n    else:\n        raise ValueError(f\"Unknown provider: {provider}\")\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    check_correctness()\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_norm.py",
    "content": "import itertools\n\nimport torch\nimport triton\nimport triton.testing\nfrom flashinfer.norm import fused_add_rmsnorm as fi_fused_add_rmsnorm\nfrom flashinfer.norm import rmsnorm as fi_rmsnorm\n\nfrom sglang.jit_kernel.benchmark.utils import is_in_ci, run_benchmark\nfrom sglang.jit_kernel.norm import fused_add_rmsnorm as jit_fused_add_rmsnorm\nfrom sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm\n\nIS_CI = is_in_ci()\n\nDTYPE = torch.bfloat16\nDEVICE = \"cuda\"\n\n# JIT rmsnorm: hidden_size in {64,128,256} or (multiple of 256, <=8192)\n# JIT fused_add_rmsnorm: hidden_size % 8 == 0, <=8192\n# Use multiples of 256 <=8192 to satisfy both kernels\nif IS_CI:\n    BS_LIST = [16]\n    HIDDEN_SIZE_LIST = [512, 2048]\nelse:\n    BS_LIST = [2**n for n in range(0, 14)]\n    HIDDEN_SIZE_LIST = [1536, 3072, 4096, 5120, 8192]\n\nLINE_VALS = [\"jit\", \"flashinfer\"]\nLINE_NAMES = [\"SGL JIT Kernel\", \"FlashInfer\"]\nSTYLES = [(\"blue\", \"--\"), (\"green\", \"-.\")]\n\nconfigs = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"hidden_size\", \"batch_size\"],\n        x_vals=configs,\n        line_arg=\"provider\",\n        line_vals=LINE_VALS,\n        line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        plot_name=\"rmsnorm-performance\",\n        args={},\n    )\n)\ndef benchmark_rmsnorm(hidden_size: int, batch_size: int, provider: str):\n    input = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE)\n    weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE)\n    FN_MAP = {\n        \"jit\": lambda: jit_rmsnorm(input.clone(), weight),\n        \"flashinfer\": lambda: fi_rmsnorm(input.clone(), weight, out=input.clone()),\n    }\n    fn = FN_MAP[provider]\n    return run_benchmark(fn)\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"hidden_size\", \"batch_size\"],\n        
x_vals=configs,\n        line_arg=\"provider\",\n        line_vals=LINE_VALS,\n        line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        plot_name=\"fused-add-rmsnorm-performance\",\n        args={},\n    )\n)\ndef benchmark_fused_add_rmsnorm(hidden_size: int, batch_size: int, provider: str):\n    input = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE)\n    residual = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE)\n    weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE)\n    FN_MAP = {\n        \"jit\": lambda: jit_fused_add_rmsnorm(\n            input.clone(), residual.clone(), weight, torch.finfo(DTYPE).eps\n        ),\n        \"flashinfer\": lambda: fi_fused_add_rmsnorm(\n            input.clone(), residual.clone(), weight, eps=torch.finfo(DTYPE).eps\n        ),\n    }\n    fn = FN_MAP[provider]\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    print(\"Benchmarking rmsnorm...\")\n    benchmark_rmsnorm.run(print_data=True)\n\n    print(\"Benchmarking fused_add_rmsnorm...\")\n    benchmark_fused_add_rmsnorm.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_norm_impls.py",
    "content": "from __future__ import annotations\n\nimport argparse\nimport csv\nimport functools\nimport importlib\nimport math\nimport os\nimport statistics\nimport subprocess\nimport sys\nfrom pathlib import Path\nfrom typing import Callable\n\nimport torch\nimport torch.nn.functional as F\n\nfrom sglang.jit_kernel.benchmark.utils import DEFAULT_DEVICE, is_in_ci\nfrom sglang.jit_kernel.diffusion.triton.norm import norm_infer, rms_norm_fn\nfrom sglang.jit_kernel.diffusion.triton.rmsnorm_onepass import triton_one_pass_rms_norm\nfrom sglang.jit_kernel.norm import fused_add_rmsnorm as jit_fused_add_rmsnorm\nfrom sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm\nfrom sglang.jit_kernel.utils import KERNEL_PATH\n\nos.environ.setdefault(\"FLASHINFER_DISABLE_VERSION_CHECK\", \"1\")\n\nREPO_ROOT = KERNEL_PATH.parents[2]\nTHIRD_PARTY_ROOT = REPO_ROOT / \"third_party\"\n\nFLAGGEMS_REPO = \"https://github.com/flagos-ai/FlagGems.git\"\nQUACK_REPO = \"https://github.com/Dao-AILab/quack.git\"\n\nTORCH_LN = \"torch.nn.LayerNorm\"\nSGL_RMS = \"sglang.RMSNorm.forward_cuda\"\nSGL_FUSED = \"sgl_kernel.fused_add_rmsnorm\"\nSGL_LN = \"sglang.LayerNormScaleShift\"\nSGL_RES_LN = \"sglang.ScaleResidualLayerNormScaleShift\"\nSGL_LN_PAIR = f\"{SGL_LN} / {SGL_RES_LN}\"\nMOVA_LN_MIX = f\"{TORCH_LN} / {SGL_LN_PAIR}\"\n\nACTUAL_DIFFUSION_GROUPS: list[\n    tuple[str, str, list[tuple[str, str, tuple[int, ...], str]]]\n] = [\n    (\n        \"qwen\",\n        \"1 GPU\",\n        [\n            (\"qwen_ln_4096x3072\", \"layernorm\", (1, 4096, 3072), SGL_LN_PAIR),\n            (\"qwen_ln_26x3072\", \"layernorm\", (1, 26, 3072), SGL_LN_PAIR),\n            (\"qwen_ln_6x3072\", \"layernorm\", (1, 6, 3072), SGL_LN_PAIR),\n            (\"qwen_rms_26x3584\", \"rmsnorm\", (1, 26, 3584), SGL_RMS),\n            (\"qwen_rms_6x3584\", \"rmsnorm\", (1, 6, 3584), SGL_RMS),\n        ],\n    ),\n    (\n        \"qwen-edit\",\n        \"1 GPU\",\n        [\n            (\"qwen_edit_ln_189x3072\", 
\"layernorm\", (1, 189, 3072), SGL_LN_PAIR),\n            (\"qwen_edit_ln_192x3072\", \"layernorm\", (1, 192, 3072), SGL_LN_PAIR),\n            (\"qwen_edit_ln_8308x3072\", \"layernorm\", (1, 8308, 3072), TORCH_LN),\n            (\"qwen_edit_rms_189x3584\", \"rmsnorm\", (1, 189, 3584), SGL_RMS),\n            (\"qwen_edit_rms_192x3584\", \"rmsnorm\", (1, 192, 3584), SGL_RMS),\n        ],\n    ),\n    (\n        \"flux\",\n        \"1 GPU\",\n        [\n            (\"flux_ln_77x768\", \"layernorm\", (1, 77, 768), TORCH_LN),\n            (\"flux_ln_512x3072\", \"layernorm\", (1, 512, 3072), TORCH_LN),\n            (\"flux_ln_4096x3072\", \"layernorm\", (1, 4096, 3072), TORCH_LN),\n            (\"flux_ln_4608x3072\", \"layernorm\", (1, 4608, 3072), TORCH_LN),\n            (\"flux_rms_512x4096\", \"rmsnorm\", (1, 512, 4096), SGL_RMS),\n        ],\n    ),\n    (\n        \"flux2\",\n        \"1 GPU\",\n        [\n            (\"flux2_ln_512x6144\", \"layernorm\", (1, 512, 6144), TORCH_LN),\n            (\"flux2_ln_4096x6144\", \"layernorm\", (1, 4096, 6144), TORCH_LN),\n            (\"flux2_ln_4608x6144\", \"layernorm\", (1, 4608, 6144), TORCH_LN),\n            (\"flux2_rms_4608x48x128\", \"rmsnorm\", (1, 4608, 48, 128), SGL_RMS),\n        ],\n    ),\n    (\n        \"zimage\",\n        \"1 GPU\",\n        [\n            (\"zimage_ln_4128x3840\", \"layernorm\", (1, 4128, 3840), TORCH_LN),\n            (\"zimage_rms_32x3840\", \"rmsnorm\", (1, 32, 3840), SGL_RMS),\n            (\"zimage_rms_4096x3840\", \"rmsnorm\", (1, 4096, 3840), SGL_RMS),\n            (\"zimage_rms_4128x3840\", \"rmsnorm\", (1, 4128, 3840), SGL_RMS),\n            (\"zimage_rms_512x2560\", \"rmsnorm\", (1, 512, 2560), SGL_RMS),\n            (\"zimage_rms_512x32x128\", \"rmsnorm\", (1, 512, 32, 128), SGL_RMS),\n            (\"zimage_rms_512x8x128\", \"rmsnorm\", (1, 512, 8, 128), SGL_RMS),\n        ],\n    ),\n    (\n        \"wan-ti2v\",\n        \"1 GPU\",\n        [\n            
(\"wan_ti2v_ln_17850x3072\", \"layernorm\", (1, 17850, 3072), SGL_LN_PAIR),\n            (\"wan_ti2v_rms_17850x3072\", \"rmsnorm\", (1, 17850, 3072), SGL_RMS),\n            (\"wan_ti2v_rms_512x3072\", \"rmsnorm\", (1, 512, 3072), SGL_RMS),\n            (\"wan_ti2v_rms_512x4096\", \"rmsnorm\", (1, 512, 4096), SGL_RMS),\n        ],\n    ),\n    (\n        \"hunyuanvideo\",\n        \"1 GPU\",\n        [\n            (\"hunyuan_ln_46x768\", \"layernorm\", (1, 46, 768), TORCH_LN),\n            (\"hunyuan_ln_45x3072\", \"layernorm\", (1, 45, 3072), SGL_LN_PAIR),\n            (\"hunyuan_ln_27030x3072\", \"layernorm\", (1, 27030, 3072), SGL_LN_PAIR),\n            (\"hunyuan_ln_27075x3072\", \"layernorm\", (1, 27075, 3072), SGL_LN),\n            (\"hunyuan_rms_140x4096\", \"rmsnorm\", (1, 140, 4096), SGL_RMS),\n            (\"hunyuan_rms_45x24x128\", \"rmsnorm\", (1, 45, 24, 128), SGL_RMS),\n            (\"hunyuan_rms_27030x24x128\", \"rmsnorm\", (1, 27030, 24, 128), SGL_RMS),\n            (\"hunyuan_rms_27075x24x128\", \"rmsnorm\", (1, 27075, 24, 128), SGL_RMS),\n            (\"hunyuan_fused_add_140x4096\", \"fused_add_rmsnorm\", (140, 4096), SGL_FUSED),\n        ],\n    ),\n    (\n        \"mova-720p\",\n        \"4 GPU, ulysses=4, ring=1\",\n        [\n            (\"mova_ln_101x1536\", \"layernorm\", (1, 101, 1536), MOVA_LN_MIX),\n            (\"mova_ln_403x1536\", \"layernorm\", (1, 403, 1536), TORCH_LN),\n            (\"mova_ln_44100x5120\", \"layernorm\", (1, 44100, 5120), MOVA_LN_MIX),\n            (\"mova_ln_176400x5120\", \"layernorm\", (1, 176400, 5120), SGL_LN),\n            (\"mova_rms_101x1536\", \"rmsnorm\", (1, 101, 1536), SGL_RMS),\n            (\"mova_rms_101x5120\", \"rmsnorm\", (1, 101, 5120), SGL_RMS),\n            (\"mova_rms_44100x1536\", \"rmsnorm\", (1, 44100, 1536), SGL_RMS),\n            (\"mova_rms_44100x5120\", \"rmsnorm\", (1, 44100, 5120), SGL_RMS),\n            (\"mova_rms_512x1536\", \"rmsnorm\", (1, 512, 1536), SGL_RMS),\n            
(\"mova_rms_512x4096\", \"rmsnorm\", (1, 512, 4096), SGL_RMS),\n            (\"mova_rms_512x5120\", \"rmsnorm\", (1, 512, 5120), SGL_RMS),\n        ],\n    ),\n]\n\nACTUAL_DIFFUSION_SHAPES: list[dict[str, object]] = [\n    {\n        \"shape_id\": shape_id,\n        \"model\": model,\n        \"gpu_config\": gpu_config,\n        \"op\": op,\n        \"input_shape\": list(input_shape),\n        \"source_impl\": source_impl,\n    }\n    for model, gpu_config, cases in ACTUAL_DIFFUSION_GROUPS\n    for shape_id, op, input_shape, source_impl in cases\n]\n\n\ndef effective_rows_from_shape(input_shape: list[int]) -> int:\n    rows = 1\n    for dim in input_shape[:-1]:\n        rows *= dim\n    return rows\n\n\ndef ensure_repo(repo_name: str, repo_url: str) -> Path:\n    repo_path = THIRD_PARTY_ROOT / repo_name\n    if repo_path.exists():\n        return repo_path\n    repo_path.parent.mkdir(parents=True, exist_ok=True)\n    subprocess.run(\n        [\"git\", \"clone\", \"--depth\", \"1\", repo_url, str(repo_path)],\n        check=True,\n        cwd=REPO_ROOT,\n    )\n    return repo_path\n\n\ndef ensure_python_dep(module_name: str, package_name: str | None = None) -> None:\n    package_name = package_name or module_name\n    try:\n        importlib.import_module(module_name)\n    except ModuleNotFoundError:\n        subprocess.run(\n            [sys.executable, \"-m\", \"pip\", \"install\", package_name],\n            check=True,\n        )\n\n\ndef dtype_from_name(name: str) -> torch.dtype:\n    mapping = {\n        \"bf16\": torch.bfloat16,\n        \"bfloat16\": torch.bfloat16,\n        \"fp16\": torch.float16,\n        \"float16\": torch.float16,\n        \"fp32\": torch.float32,\n        \"float32\": torch.float32,\n    }\n    return mapping[name]\n\n\ndef dtype_name(dtype: torch.dtype) -> str:\n    mapping = {\n        torch.bfloat16: \"bf16\",\n        torch.float16: \"fp16\",\n        torch.float32: \"fp32\",\n    }\n    return mapping[dtype]\n\n\ndef 
normalize_hidden_sizes(text: str) -> list[int]:\n    return [int(x) for x in text.split(\",\") if x]\n\n\ndef normalize_dtypes(text: str) -> list[torch.dtype]:\n    return [dtype_from_name(x.strip()) for x in text.split(\",\") if x.strip()]\n\n\ndef prewarm(fn: Callable[[], object], iters: int = 3) -> None:\n    for _ in range(iters):\n        fn()\n    torch.cuda.synchronize()\n\n\ndef benchmark_provider(\n    fn: Callable[[], object],\n    setup_fn: Callable[[], None] | None = None,\n    warmup: int = 10,\n    rep: int = 30,\n) -> tuple[float, float, float]:\n    for _ in range(warmup):\n        if setup_fn is not None:\n            setup_fn()\n        fn()\n    torch.cuda.synchronize()\n\n    start_event = torch.cuda.Event(enable_timing=True)\n    end_event = torch.cuda.Event(enable_timing=True)\n    times_us: list[float] = []\n    for _ in range(rep):\n        if setup_fn is not None:\n            setup_fn()\n        start_event.record()\n        fn()\n        end_event.record()\n        end_event.synchronize()\n        times_us.append(start_event.elapsed_time(end_event) * 1000.0)\n\n    return statistics.median(times_us), max(times_us), min(times_us)\n\n\ndef geometric_mean(values: list[float]) -> float:\n    if not values:\n        return float(\"nan\")\n    return math.exp(sum(math.log(v) for v in values) / len(values))\n\n\n@functools.cache\ndef load_flaggems():\n    ensure_python_dep(\"sqlalchemy\")\n    ensure_repo(\"FlagGems\", FLAGGEMS_REPO)\n    src_root = THIRD_PARTY_ROOT / \"FlagGems\" / \"src\"\n    if str(src_root) not in sys.path:\n        sys.path.insert(0, str(src_root))\n    from flag_gems.fused.fused_add_rms_norm import fused_add_rms_norm\n    from flag_gems.ops.layernorm import layer_norm\n    from flag_gems.ops.rms_norm import rms_norm\n\n    return rms_norm, layer_norm, fused_add_rms_norm\n\n\n@functools.cache\ndef load_quack():\n    repo_path = ensure_repo(\"quack\", QUACK_REPO)\n    try:\n        quack_rmsnorm = 
importlib.import_module(\"quack.rmsnorm\")\n    except ModuleNotFoundError:\n        subprocess.run(\n            [sys.executable, \"-m\", \"pip\", \"install\", \"-e\", str(repo_path)],\n            check=True,\n        )\n        quack_rmsnorm = importlib.import_module(\"quack.rmsnorm\")\n\n    return quack_rmsnorm.rmsnorm_fwd, quack_rmsnorm.layernorm_fwd\n\n\ndef build_rmsnorm_providers(dtype: torch.dtype, batch_size: int, hidden_size: int):\n    import flashinfer.norm as flashinfer_norm\n    import sgl_kernel\n\n    x = torch.randn((batch_size, hidden_size), device=DEFAULT_DEVICE, dtype=dtype)\n    weight = torch.randn(hidden_size, device=DEFAULT_DEVICE, dtype=dtype)\n\n    jit_out = torch.empty_like(x)\n    sgl_out = torch.empty_like(x)\n    flashinfer_out = torch.empty_like(x)\n\n    flaggems_rms_norm, _, _ = load_flaggems()\n    quack_rmsnorm_fwd, _ = load_quack()\n\n    providers = {\n        \"pytorch\": lambda: F.rms_norm(x, (hidden_size,), weight, 1e-6),\n        \"sgl_kernel\": lambda: sgl_kernel.rmsnorm(x, weight, eps=1e-6, out=sgl_out),\n        \"flashinfer\": lambda: flashinfer_norm.rmsnorm(\n            x, weight, eps=1e-6, out=flashinfer_out\n        ),\n        \"jit_rmsnorm\": lambda: jit_rmsnorm(x, weight, jit_out, 1e-6),\n        \"quack\": lambda: quack_rmsnorm_fwd(x, weight, eps=1e-6),\n        \"triton_rms_norm_fn\": lambda: rms_norm_fn(\n            x, weight, bias=None, residual=None, eps=1e-6\n        ),\n        \"flaggems\": lambda: flaggems_rms_norm(x, (hidden_size,), weight, 1e-6),\n    }\n    if hidden_size <= 128:\n        providers[\"triton_one_pass\"] = lambda: triton_one_pass_rms_norm(x, weight, 1e-6)\n    return providers\n\n\ndef build_fused_add_rmsnorm_providers(\n    dtype: torch.dtype, batch_size: int, hidden_size: int\n):\n    import flashinfer.norm as flashinfer_norm\n    import sgl_kernel\n\n    base_x = torch.randn((batch_size, hidden_size), device=DEFAULT_DEVICE, dtype=dtype)\n    base_residual = 
torch.randn_like(base_x)\n    weight = torch.randn(hidden_size, device=DEFAULT_DEVICE, dtype=dtype)\n\n    x = base_x.clone()\n    residual = base_residual.clone()\n\n    def reset():\n        x.copy_(base_x)\n        residual.copy_(base_residual)\n\n    _, _, flaggems_fused_add_rms_norm = load_flaggems()\n    quack_rmsnorm_fwd, _ = load_quack()\n\n    def pytorch_impl():\n        out = x + residual\n        return F.rms_norm(out, (hidden_size,), weight, 1e-6)\n\n    providers = {\n        \"pytorch\": (pytorch_impl, reset),\n        \"sgl_kernel\": (\n            lambda: sgl_kernel.fused_add_rmsnorm(x, residual, weight, eps=1e-6),\n            reset,\n        ),\n        \"flashinfer\": (\n            lambda: flashinfer_norm.fused_add_rmsnorm(x, residual, weight, eps=1e-6),\n            reset,\n        ),\n        \"jit_fused_add_rmsnorm\": (\n            lambda: jit_fused_add_rmsnorm(x, residual, weight, 1e-6),\n            reset,\n        ),\n        \"quack\": (\n            lambda: quack_rmsnorm_fwd(x, weight, residual=residual, eps=1e-6),\n            reset,\n        ),\n        \"flaggems\": (\n            lambda: flaggems_fused_add_rms_norm(\n                x, residual, (hidden_size,), weight, 1e-6\n            ),\n            reset,\n        ),\n    }\n    return providers\n\n\ndef build_layernorm_providers(dtype: torch.dtype, batch_size: int, hidden_size: int):\n    import flashinfer.norm as flashinfer_norm\n\n    x = torch.randn((batch_size, hidden_size), device=DEFAULT_DEVICE, dtype=dtype)\n    weight = torch.randn(hidden_size, device=DEFAULT_DEVICE, dtype=dtype)\n    bias = torch.randn(hidden_size, device=DEFAULT_DEVICE, dtype=dtype)\n    flashinfer_weight = torch.randn(\n        hidden_size, device=DEFAULT_DEVICE, dtype=torch.float32\n    )\n    flashinfer_bias = torch.randn(\n        hidden_size, device=DEFAULT_DEVICE, dtype=torch.float32\n    )\n\n    triton_out = torch.empty_like(x)\n\n    _, flaggems_layer_norm, _ = load_flaggems()\n    _, 
quack_layernorm_fwd = load_quack()\n\n    providers = {\n        \"pytorch\": lambda: F.layer_norm(x, (hidden_size,), weight, bias, 1e-6),\n        \"triton_norm_infer\": lambda: norm_infer(\n            x, weight, bias, eps=1e-6, is_rms_norm=False, out=triton_out\n        ),\n        \"flashinfer\": lambda: flashinfer_norm.layernorm(\n            x, flashinfer_weight, flashinfer_bias, 1e-6\n        ),\n        \"quack\": lambda: quack_layernorm_fwd(\n            x, flashinfer_weight, flashinfer_bias, 1e-6\n        ),\n        \"flaggems\": lambda: flaggems_layer_norm(x, (hidden_size,), weight, bias)[0],\n    }\n    return providers\n\n\ndef maybe_benchmark(\n    op_name: str,\n    provider_name: str,\n    fn: Callable[[], object],\n    rows: list[dict[str, object]],\n    dtype: torch.dtype,\n    batch_size: int,\n    hidden_size: int,\n    reset: Callable[[], None] | None = None,\n    metadata: dict[str, object] | None = None,\n) -> None:\n    metadata = metadata or {}\n    try:\n        median_us, max_us, min_us = benchmark_provider(fn, reset)\n        rows.append(\n            {\n                \"op\": op_name,\n                \"provider\": provider_name,\n                \"dtype\": dtype_name(dtype),\n                \"batch_size\": batch_size,\n                \"hidden_size\": hidden_size,\n                \"median_us\": median_us,\n                \"min_us\": min_us,\n                \"max_us\": max_us,\n                \"status\": \"ok\",\n                \"error\": \"\",\n                **metadata,\n            }\n        )\n    except Exception as exc:  # pragma: no cover - benchmark failures are data\n        rows.append(\n            {\n                \"op\": op_name,\n                \"provider\": provider_name,\n                \"dtype\": dtype_name(dtype),\n                \"batch_size\": batch_size,\n                \"hidden_size\": hidden_size,\n                \"median_us\": \"\",\n                \"min_us\": \"\",\n                \"max_us\": 
\"\",\n                \"status\": \"unsupported\",\n                \"error\": str(exc),\n                **metadata,\n            }\n        )\n\n\ndef write_csv(rows: list[dict[str, object]], output_path: Path) -> None:\n    output_path.parent.mkdir(parents=True, exist_ok=True)\n    with output_path.open(\"w\", newline=\"\", encoding=\"utf-8\") as f:\n        writer = csv.DictWriter(\n            f,\n            fieldnames=[\n                \"op\",\n                \"provider\",\n                \"dtype\",\n                \"batch_size\",\n                \"hidden_size\",\n                \"median_us\",\n                \"min_us\",\n                \"max_us\",\n                \"shape_id\",\n                \"source_model\",\n                \"source_gpu_config\",\n                \"source_input_shape\",\n                \"source_impl\",\n                \"status\",\n                \"error\",\n            ],\n        )\n        writer.writeheader()\n        writer.writerows(rows)\n\n\ndef write_markdown(rows: list[dict[str, object]], output_path: Path) -> None:\n    lines: list[str] = []\n    lines.append(\"# Norm Benchmark Summary\")\n    lines.append(\"\")\n    actual_shape_rows = [row for row in rows if row.get(\"shape_id\")]\n    if actual_shape_rows:\n        seen: set[tuple[str, str, str, str, str, str]] = set()\n        lines.append(\"## Diffusion Shape Cases\")\n        lines.append(\"\")\n        lines.append(\n            \"| Shape ID | Op | Model | GPU Config | Input Shape | Source Impl |\"\n        )\n        lines.append(\"|---|---|---|---|---|---|\")\n        for row in actual_shape_rows:\n            key = (\n                str(row.get(\"shape_id\", \"\")),\n                str(row.get(\"op\", \"\")),\n                str(row.get(\"source_model\", \"\")),\n                str(row.get(\"source_gpu_config\", \"\")),\n                str(row.get(\"source_input_shape\", \"\")),\n                str(row.get(\"source_impl\", \"\")),\n            )\n  
          if key in seen:\n                continue\n            seen.add(key)\n            lines.append(\n                f\"| {key[0]} | {key[1]} | {key[2]} | {key[3]} | `{key[4]}` | {key[5]} |\"\n            )\n        lines.append(\"\")\n    for op_name in (\"rmsnorm\", \"fused_add_rmsnorm\", \"layernorm\"):\n        for dtype in sorted({row[\"dtype\"] for row in rows}):\n            scoped = [\n                row\n                for row in rows\n                if row[\"op\"] == op_name\n                and row[\"dtype\"] == dtype\n                and row[\"status\"] == \"ok\"\n            ]\n            if not scoped:\n                continue\n            provider_to_values: dict[str, list[float]] = {}\n            provider_to_speedups: dict[str, list[float]] = {}\n            by_shape: dict[tuple[str, int, int], dict[str, float]] = {}\n            for row in scoped:\n                provider = str(row[\"provider\"])\n                value = float(row[\"median_us\"])\n                provider_to_values.setdefault(provider, []).append(value)\n                shape = (\n                    str(row.get(\"shape_id\", \"\")),\n                    int(row[\"batch_size\"]),\n                    int(row[\"hidden_size\"]),\n                )\n                by_shape.setdefault(shape, {})[provider] = value\n            for shape, perf in by_shape.items():\n                if \"pytorch\" not in perf:\n                    continue\n                baseline = perf[\"pytorch\"]\n                for provider, value in perf.items():\n                    provider_to_speedups.setdefault(provider, []).append(\n                        baseline / value\n                    )\n\n            lines.append(f\"## {op_name} ({dtype})\")\n            lines.append(\"\")\n            lines.append(\n                \"| Provider | Geomean Speedup vs PyTorch | Median Latency (us) | Win Count |\"\n            )\n            lines.append(\"|---|---:|---:|---:|\")\n            wins: 
dict[str, int] = {}\n            for perf in by_shape.values():\n                best_provider = min(perf, key=perf.get)\n                wins[best_provider] = wins.get(best_provider, 0) + 1\n            for provider in sorted(provider_to_values):\n                geomean_speedup = geometric_mean(provider_to_speedups.get(provider, []))\n                median_latency = statistics.median(provider_to_values[provider])\n                win_count = wins.get(provider, 0)\n                lines.append(\n                    f\"| {provider} | {geomean_speedup:.3f}x | {median_latency:.2f} | {win_count} |\"\n                )\n            lines.append(\"\")\n    output_path.write_text(\"\\n\".join(lines) + \"\\n\", encoding=\"utf-8\")\n\n\ndef run_suite(\n    hidden_sizes: list[int],\n    batch_sizes: list[int],\n    dtypes: list[torch.dtype],\n    ops: list[str],\n) -> list[dict[str, object]]:\n    rows: list[dict[str, object]] = []\n    for dtype in dtypes:\n        for batch_size in batch_sizes:\n            for hidden_size in hidden_sizes:\n                if \"rmsnorm\" in ops:\n                    rms_providers = build_rmsnorm_providers(\n                        dtype, batch_size, hidden_size\n                    )\n                    for provider_name, fn in rms_providers.items():\n                        maybe_benchmark(\n                            \"rmsnorm\",\n                            provider_name,\n                            fn,\n                            rows,\n                            dtype,\n                            batch_size,\n                            hidden_size,\n                        )\n\n                if \"fused_add_rmsnorm\" in ops:\n                    fused_providers = build_fused_add_rmsnorm_providers(\n                        dtype, batch_size, hidden_size\n                    )\n                    for provider_name, provider in fused_providers.items():\n                        fn, reset = provider\n                        
maybe_benchmark(\n                            \"fused_add_rmsnorm\",\n                            provider_name,\n                            fn,\n                            rows,\n                            dtype,\n                            batch_size,\n                            hidden_size,\n                            reset,\n                        )\n\n                if \"layernorm\" in ops:\n                    layernorm_providers = build_layernorm_providers(\n                        dtype, batch_size, hidden_size\n                    )\n                    for provider_name, fn in layernorm_providers.items():\n                        maybe_benchmark(\n                            \"layernorm\",\n                            provider_name,\n                            fn,\n                            rows,\n                            dtype,\n                            batch_size,\n                            hidden_size,\n                        )\n    return rows\n\n\ndef run_shape_suite(\n    shape_cases: list[dict[str, object]],\n    dtypes: list[torch.dtype],\n) -> list[dict[str, object]]:\n    rows: list[dict[str, object]] = []\n    for case in shape_cases:\n        op_name = str(case[\"op\"])\n        input_shape = [int(x) for x in case[\"input_shape\"]]\n        batch_size = effective_rows_from_shape(input_shape)\n        hidden_size = input_shape[-1]\n        metadata = {\n            \"shape_id\": str(case[\"shape_id\"]),\n            \"source_model\": str(case[\"model\"]),\n            \"source_gpu_config\": str(case[\"gpu_config\"]),\n            \"source_input_shape\": str(input_shape),\n            \"source_impl\": str(case[\"source_impl\"]),\n        }\n        for dtype in dtypes:\n            if op_name == \"rmsnorm\":\n                providers = build_rmsnorm_providers(dtype, batch_size, hidden_size)\n                for provider_name, fn in providers.items():\n                    maybe_benchmark(\n                        op_name,\n   
                     provider_name,\n                        fn,\n                        rows,\n                        dtype,\n                        batch_size,\n                        hidden_size,\n                        metadata=metadata,\n                    )\n            elif op_name == \"fused_add_rmsnorm\":\n                providers = build_fused_add_rmsnorm_providers(\n                    dtype, batch_size, hidden_size\n                )\n                for provider_name, provider in providers.items():\n                    fn, reset = provider\n                    maybe_benchmark(\n                        op_name,\n                        provider_name,\n                        fn,\n                        rows,\n                        dtype,\n                        batch_size,\n                        hidden_size,\n                        reset,\n                        metadata=metadata,\n                    )\n            elif op_name == \"layernorm\":\n                providers = build_layernorm_providers(dtype, batch_size, hidden_size)\n                for provider_name, fn in providers.items():\n                    maybe_benchmark(\n                        op_name,\n                        provider_name,\n                        fn,\n                        rows,\n                        dtype,\n                        batch_size,\n                        hidden_size,\n                        metadata=metadata,\n                    )\n            else:\n                raise ValueError(f\"Unsupported op in shape preset: {op_name}\")\n    return rows\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(\n        description=\"Benchmark RMSNorm/LayerNorm implementations across providers.\"\n    )\n    parser.add_argument(\n        \"--hidden-sizes\",\n        default=\"64,128,256,512,1024,2048,4096,8192,16384\",\n        help=\"Comma-separated hidden sizes.\",\n    )\n    parser.add_argument(\n        \"--batch-sizes\",\n        
default=\"1,16,128,1024\",\n        help=\"Comma-separated batch sizes.\",\n    )\n    parser.add_argument(\n        \"--dtypes\",\n        default=\"bf16,fp16\",\n        help=\"Comma-separated dtypes: bf16, fp16, fp32.\",\n    )\n    parser.add_argument(\n        \"--output-dir\",\n        default=str(REPO_ROOT / \"outputs\" / \"norm_benchmarks\"),\n        help=\"Directory for CSV/Markdown outputs.\",\n    )\n    parser.add_argument(\n        \"--ops\",\n        default=\"rmsnorm,fused_add_rmsnorm,layernorm\",\n        help=\"Comma-separated ops to benchmark.\",\n    )\n    parser.add_argument(\n        \"--shape-preset\",\n        choices=[\"grid\", \"diffusion-actual\"],\n        default=\"grid\",\n        help=\"Use the default grid sweep or the captured diffusion workload shapes.\",\n    )\n    args = parser.parse_args()\n\n    if not torch.cuda.is_available():\n        raise RuntimeError(\"CUDA is required for norm benchmarks.\")\n\n    hidden_sizes = normalize_hidden_sizes(args.hidden_sizes)\n    batch_sizes = normalize_hidden_sizes(args.batch_sizes)\n    dtypes = normalize_dtypes(args.dtypes)\n    ops = [op.strip() for op in args.ops.split(\",\") if op.strip()]\n\n    if args.shape_preset == \"diffusion-actual\":\n        shape_cases = [case for case in ACTUAL_DIFFUSION_SHAPES if case[\"op\"] in ops]\n        rows = run_shape_suite(shape_cases, dtypes)\n    else:\n        rows = run_suite(hidden_sizes, batch_sizes, dtypes, ops)\n    output_dir = Path(args.output_dir)\n    csv_path = output_dir / \"norm_impls.csv\"\n    md_path = output_dir / \"norm_impls_summary.md\"\n    write_csv(rows, csv_path)\n    write_markdown(rows, md_path)\n    print(f\"Wrote {csv_path}\")\n    print(f\"Wrote {md_path}\")\n\n\nif __name__ == \"__main__\":\n    if is_in_ci():\n        print(\"Skipping bench_norm_impls.py in CI\")\n        sys.exit(0)\n    main()\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_nvfp4_blockwise_moe.py",
    "content": "from __future__ import annotations\n\nimport sys\nfrom typing import Any\n\nimport torch\nimport triton\n\nfrom sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark\nfrom sglang.jit_kernel.nvfp4 import (\n    cutlass_fp4_group_mm,\n    scaled_fp4_experts_quant,\n    scaled_fp4_quant,\n)\nfrom sglang.srt.utils import is_sm100_supported\n\nFLOAT4_E2M1_MAX = 6.0\nFLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max\n_NVFP4_SUPPORTED = is_sm100_supported()\n\n\ndef _round_up(x: int, y: int) -> int:\n    return ((x + y - 1) // y) * y\n\n\ndef _expert_offsets(m_per_expert: list[int], device: torch.device) -> torch.Tensor:\n    offsets = [0]\n    for m in m_per_expert:\n        offsets.append(offsets[-1] + m)\n    return torch.tensor(offsets, dtype=torch.int32, device=device)\n\n\ndef _blockscale_offsets(m_per_expert: list[int], device: torch.device) -> torch.Tensor:\n    offsets = [0]\n    for m in m_per_expert:\n        offsets.append(offsets[-1] + _round_up(m, 128))\n    return torch.tensor(offsets, dtype=torch.int32, device=device)\n\n\ndef _prepare_case(\n    total_tokens: int, n: int, k: int, num_experts: int, dtype: torch.dtype\n) -> dict[str, Any]:\n    device = torch.device(\"cuda\")\n    base = total_tokens // num_experts\n    rem = total_tokens % num_experts\n    m_per_expert = [base + (1 if i < rem else 0) for i in range(num_experts)]\n\n    expert_offsets_full = _expert_offsets(m_per_expert, device)\n    blockscale_offsets_full = _blockscale_offsets(m_per_expert, device)\n\n    a = torch.randn((total_tokens, k), device=device, dtype=dtype) * 0.1\n    b = torch.randn((num_experts, n, k), device=device, dtype=dtype) * 0.1\n\n    a_global_scale = torch.empty((num_experts,), device=device, dtype=torch.float32)\n    for i in range(num_experts):\n        start = int(expert_offsets_full[i].item())\n        end = int(expert_offsets_full[i + 1].item())\n        a_global_scale[i] = (\n            FLOAT8_E4M3_MAX\n            * 
FLOAT4_E2M1_MAX\n            / a[start:end].abs().max().to(torch.float32)\n        )\n\n    b_global_scale = torch.empty((num_experts,), device=device, dtype=torch.float32)\n    for i in range(num_experts):\n        b_global_scale[i] = (\n            FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b[i].abs().max().to(torch.float32)\n        )\n\n    a_fp4, a_blockscale = scaled_fp4_experts_quant(\n        a,\n        a_global_scale,\n        expert_offsets_full,\n        blockscale_offsets_full,\n        topk=1,\n    )\n\n    b_fp4 = torch.empty((num_experts, n, k // 2), device=device, dtype=torch.uint8)\n    b_blockscale = torch.empty(\n        (num_experts, _round_up(n, 128), _round_up(k // 16, 4)),\n        device=device,\n        dtype=torch.float8_e4m3fn,\n    )\n    for i in range(num_experts):\n        b_fp4_i, b_scale_i = scaled_fp4_quant(b[i], b_global_scale[i])\n        b_fp4[i].copy_(b_fp4_i)\n        b_blockscale[i].copy_(b_scale_i)\n\n    alphas = (1.0 / (a_global_scale * b_global_scale)).to(torch.float32)\n    params = {\n        \"ab_strides\": torch.full((num_experts,), k, dtype=torch.int64, device=device),\n        \"c_strides\": torch.full((num_experts,), n, dtype=torch.int64, device=device),\n        \"problem_sizes\": torch.tensor(\n            [[m, n, k] for m in m_per_expert], dtype=torch.int32, device=device\n        ),\n        \"expert_offsets\": expert_offsets_full[:-1].contiguous(),\n        \"blockscale_offsets\": blockscale_offsets_full[:-1].contiguous(),\n        \"a_ptrs\": torch.empty((num_experts,), dtype=torch.int64, device=device),\n        \"b_ptrs\": torch.empty((num_experts,), dtype=torch.int64, device=device),\n        \"out_ptrs\": torch.empty((num_experts,), dtype=torch.int64, device=device),\n        \"a_scales_ptrs\": torch.empty((num_experts,), dtype=torch.int64, device=device),\n        \"b_scales_ptrs\": torch.empty((num_experts,), dtype=torch.int64, device=device),\n        \"alpha_ptrs\": torch.empty((num_experts,), 
dtype=torch.int64, device=device),\n        \"layout_sfa\": torch.empty((num_experts, 5), dtype=torch.int64, device=device),\n        \"layout_sfb\": torch.empty((num_experts, 5), dtype=torch.int64, device=device),\n    }\n\n    expert_ranges: list[tuple[int, int]] = []\n    start = 0\n    for m in m_per_expert:\n        end = start + m\n        expert_ranges.append((start, end))\n        start = end\n\n    return {\n        \"a\": a,\n        \"b\": b,\n        \"a_fp4\": a_fp4,\n        \"b_fp4\": b_fp4,\n        \"a_blockscale\": a_blockscale,\n        \"b_blockscale\": b_blockscale,\n        \"alphas\": alphas,\n        \"params\": params,\n        \"expert_offsets_full\": expert_offsets_full,\n        \"expert_ranges\": expert_ranges,\n        \"dtype\": dtype,\n    }\n\n\ndef _torch_ref_group_mm(case: dict[str, Any]) -> torch.Tensor:\n    a = case[\"a\"]\n    b = case[\"b\"]\n    dtype = case[\"dtype\"]\n    expert_ranges = case[\"expert_ranges\"]\n    total_tokens = a.shape[0]\n    n = b.shape[1]\n    out = torch.empty((total_tokens, n), device=a.device, dtype=dtype)\n    for i, (start, end) in enumerate(expert_ranges):\n        out[start:end] = torch.matmul(a[start:end], b[i].t())\n    return out\n\n\ndef _aot_cutlass_fp4_group_mm(case: dict[str, Any]) -> torch.Tensor:\n    a_fp4 = case[\"a_fp4\"]\n    b_fp4 = case[\"b_fp4\"]\n    a_blockscale = case[\"a_blockscale\"]\n    b_blockscale = case[\"b_blockscale\"]\n    alphas = case[\"alphas\"]\n    params = case[\"params\"]\n    out_dtype = case[\"dtype\"]\n\n    out = torch.empty(\n        (a_fp4.shape[0], b_fp4.shape[1]), device=a_fp4.device, dtype=out_dtype\n    )\n    torch.ops.sgl_kernel.cutlass_fp4_group_mm.default(\n        out,\n        a_fp4,\n        b_fp4,\n        a_blockscale,\n        b_blockscale,\n        alphas,\n        params[\"ab_strides\"],\n        params[\"c_strides\"],\n        params[\"problem_sizes\"],\n        params[\"expert_offsets\"],\n        params[\"blockscale_offsets\"],\n    
)\n    return out\n\n\ndef _probe_legacy_aot_group_mm() -> tuple[bool, str]:\n    if not torch.cuda.is_available():\n        return False, \"CUDA is not available.\"\n    if not _NVFP4_SUPPORTED:\n        return False, \"NVFP4 benchmarks require sm100+ with CUDA 12.8+.\"\n    try:\n        import sgl_kernel  # noqa: F401\n    except Exception as e:\n        return False, f\"import sgl_kernel failed: {e}\"\n    if not hasattr(torch.ops, \"sgl_kernel\"):\n        return False, \"torch.ops.sgl_kernel is not registered.\"\n    op = getattr(torch.ops.sgl_kernel, \"cutlass_fp4_group_mm\", None)\n    if op is None or not hasattr(op, \"default\"):\n        return False, \"torch.ops.sgl_kernel.cutlass_fp4_group_mm.default is missing.\"\n    try:\n        case = _prepare_case(64, 256, 128, 4, torch.bfloat16)\n        _aot_cutlass_fp4_group_mm(case)\n        torch.cuda.synchronize()\n    except Exception as e:\n        return False, f\"calling AOT grouped_mm op failed: {e}\"\n    return True, \"\"\n\n\n_AOT_GROUP_MM_AVAILABLE, _AOT_GROUP_MM_REASON = _probe_legacy_aot_group_mm()\n\nshape_range = get_benchmark_range(\n    full_range=[(128, 256, 128, 4), (256, 512, 128, 8), (512, 512, 256, 8)],\n    ci_range=[(128, 256, 128, 4)],\n)\n\nline_vals = [\"jit\"]\nline_names = [\"JIT NVFP4 MoE GroupMM\"]\nstyles = [(\"green\", \"-\")]\nif _AOT_GROUP_MM_AVAILABLE:\n    line_vals.append(\"aot_sgl_kernel\")\n    line_names.append(\"AOT NVFP4 MoE GroupMM\")\n    styles.append((\"orange\", \"-\"))\nline_vals.append(\"torch_ref\")\nline_names.append(\"Torch Ref\")\nstyles.append((\"blue\", \"-\"))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"total_tokens\", \"n\", \"k\", \"num_experts\"],\n        x_vals=shape_range,\n        x_log=False,\n        line_arg=\"provider\",\n        line_vals=line_vals,\n        line_names=line_names,\n        styles=styles,\n        ylabel=\"us\",\n        plot_name=\"nvfp4-blockwise-moe-groupmm-performance\",\n        
args={},\n    )\n)\ndef benchmark(total_tokens, n, k, num_experts, provider):\n    case = _prepare_case(total_tokens, n, k, num_experts, torch.bfloat16)\n\n    if provider == \"jit\":\n        fn = lambda: cutlass_fp4_group_mm(\n            case[\"a_fp4\"],\n            case[\"b_fp4\"],\n            case[\"a_blockscale\"],\n            case[\"b_blockscale\"],\n            case[\"alphas\"],\n            case[\"dtype\"],\n            case[\"params\"],\n        )\n    elif provider == \"aot_sgl_kernel\":\n        fn = lambda: _aot_cutlass_fp4_group_mm(case)\n    elif provider == \"torch_ref\":\n        fn = lambda: _torch_ref_group_mm(case)\n    else:\n        raise ValueError(f\"Unknown provider: {provider}\")\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    if not _NVFP4_SUPPORTED:\n        print(\"[skip] NVFP4 blockwise MoE benchmark requires sm100+ with CUDA 12.8+.\")\n        sys.exit(0)\n    if not _AOT_GROUP_MM_AVAILABLE:\n        print(\n            f\"[info] legacy AOT grouped_mm baseline unavailable: {_AOT_GROUP_MM_REASON}\"\n        )\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_nvfp4_quant.py",
    "content": "from __future__ import annotations\n\nimport sys\n\nimport torch\nimport triton\n\nfrom sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark\nfrom sglang.jit_kernel.nvfp4 import scaled_fp4_quant\nfrom sglang.srt.utils import is_sm100_supported\n\nFLOAT4_E2M1_MAX = 6.0\nFLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max\nBLOCK_SIZE = 16\n_NVFP4_SUPPORTED = is_sm100_supported()\n\ntry:\n    from flashinfer import fp4_quantize as flashinfer_fp4_quantize\nexcept Exception:\n    flashinfer_fp4_quantize = None\n\n\ndef _torch_ref_quant(input: torch.Tensor, input_global_scale: torch.Tensor):\n    m, n = input.shape\n    x = input.view(m, n // BLOCK_SIZE, BLOCK_SIZE)\n    vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)\n    scale = input_global_scale * (vec_max / FLOAT4_E2M1_MAX)\n    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)\n    output_scale = torch.where(scale == 0, torch.zeros_like(scale), 1.0 / scale)\n\n    scaled_x = x.to(torch.float32) * output_scale\n    clipped = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)\n\n    # Round magnitudes onto the E2M1 grid, then restore the sign.\n    sign = torch.sign(clipped)\n    rounded = clipped.abs()\n    rounded[(rounded >= 0.0) & (rounded <= 0.25)] = 0.0\n    rounded[(rounded > 0.25) & (rounded < 0.75)] = 0.5\n    rounded[(rounded >= 0.75) & (rounded <= 1.25)] = 1.0\n    rounded[(rounded > 1.25) & (rounded < 1.75)] = 1.5\n    rounded[(rounded >= 1.75) & (rounded <= 2.5)] = 2.0\n    rounded[(rounded > 2.5) & (rounded < 3.5)] = 3.0\n    rounded[(rounded >= 3.5) & (rounded <= 5.0)] = 4.0\n    rounded[rounded > 5.0] = 6.0\n    rounded = rounded * sign\n\n    # This baseline intentionally keeps work on GPU but does not pack to uint8.\n    return rounded, scale\n\n\ndef _aot_scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):\n    m, n = input.shape\n    output = torch.empty((m, n // 2), device=input.device, dtype=torch.uint8)\n    rounded_m = ((m + 128 - 1) // 128) * 128\n    scale_n = n // BLOCK_SIZE\n    rounded_n = ((scale_n + 4 - 1) // 4) * 4\n  
  output_scale = torch.empty(\n        (rounded_m, rounded_n // 4), device=input.device, dtype=torch.int32\n    )\n    torch.ops.sgl_kernel.scaled_fp4_quant.default(\n        output, input, output_scale, input_global_scale\n    )\n    return output, output_scale.view(torch.float8_e4m3fn)\n\n\ndef _probe_legacy_aot_quant() -> tuple[bool, str]:\n    if not torch.cuda.is_available():\n        return False, \"CUDA is not available.\"\n    if not _NVFP4_SUPPORTED:\n        return False, \"NVFP4 benchmarks require sm100+ with CUDA 12.8+.\"\n    try:\n        import sgl_kernel  # noqa: F401\n    except Exception as e:\n        return False, f\"import sgl_kernel failed: {e}\"\n    if not hasattr(torch.ops, \"sgl_kernel\"):\n        return False, \"torch.ops.sgl_kernel is not registered.\"\n    op = getattr(torch.ops.sgl_kernel, \"scaled_fp4_quant\", None)\n    if op is None or not hasattr(op, \"default\"):\n        return False, \"torch.ops.sgl_kernel.scaled_fp4_quant.default is missing.\"\n    try:\n        x = torch.randn((16, 64), dtype=torch.bfloat16, device=\"cuda\")\n        global_scale = (\n            FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / torch.abs(x).max().to(torch.float32)\n        )\n        _aot_scaled_fp4_quant(x, global_scale)\n        torch.cuda.synchronize()\n    except Exception as e:\n        return False, f\"calling AOT quant op failed: {e}\"\n    return True, \"\"\n\n\n_AOT_QUANT_AVAILABLE, _AOT_QUANT_REASON = _probe_legacy_aot_quant()\n\n\ndef _probe_flashinfer_quant() -> tuple[bool, str]:\n    if flashinfer_fp4_quantize is None:\n        return False, \"import flashinfer.fp4_quantize failed.\"\n    if not torch.cuda.is_available():\n        return False, \"CUDA is not available.\"\n    if not _NVFP4_SUPPORTED:\n        return False, \"NVFP4 benchmarks require sm100+ with CUDA 12.8+.\"\n    try:\n        x = torch.randn((16, 64), dtype=torch.bfloat16, device=\"cuda\")\n        global_scale = (\n            FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / 
torch.abs(x).max().to(torch.float32)\n        )\n        flashinfer_fp4_quantize(\n            x,\n            global_scale,\n            BLOCK_SIZE,  # sf_vec_size\n            False,  # use_ue8m0\n            True,  # is_sf_swizzled_layout\n        )\n        torch.cuda.synchronize()\n    except Exception as e:\n        return False, f\"calling flashinfer.fp4_quantize failed: {e}\"\n    return True, \"\"\n\n\n_FLASHINFER_QUANT_AVAILABLE, _FLASHINFER_QUANT_REASON = _probe_flashinfer_quant()\n\nshape_range = get_benchmark_range(\n    full_range=[(128, 2048), (512, 4096), (1024, 4096), (2048, 8192)],\n    ci_range=[(128, 2048)],\n)\n\nline_vals = []\nline_names = []\nstyles = []\nif _FLASHINFER_QUANT_AVAILABLE:\n    line_vals.append(\"flashinfer\")\n    line_names.append(\"FlashInfer FP4 Quant\")\n    styles.append((\"purple\", \"-\"))\nline_vals.append(\"jit\")\nline_names.append(\"JIT NVFP4 Quant\")\nstyles.append((\"green\", \"-\"))\nif _AOT_QUANT_AVAILABLE:\n    line_vals.append(\"aot_sgl_kernel\")\n    line_names.append(\"AOT NVFP4 Quant\")\n    styles.append((\"orange\", \"-\"))\nline_vals.append(\"torch_ref\")\nline_names.append(\"Torch Ref\")\nstyles.append((\"blue\", \"-\"))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"m\", \"n\"],\n        x_vals=shape_range,\n        x_log=False,\n        line_arg=\"provider\",\n        line_vals=line_vals,\n        line_names=line_names,\n        styles=styles,\n        ylabel=\"us\",\n        plot_name=\"nvfp4-quant-performance\",\n        args={},\n    )\n)\ndef benchmark(m, n, provider):\n    x = torch.randn((m, n), dtype=torch.bfloat16, device=\"cuda\")\n    tensor_amax = torch.abs(x).max().to(torch.float32)\n    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax\n\n    if provider == \"jit\":\n        fn = lambda: scaled_fp4_quant(x, global_scale)\n    elif provider == \"flashinfer\":\n        fn = lambda: flashinfer_fp4_quantize(\n            x,\n            
global_scale,\n            BLOCK_SIZE,  # sf_vec_size\n            False,  # use_ue8m0\n            True,  # is_sf_swizzled_layout\n        )\n    elif provider == \"aot_sgl_kernel\":\n        fn = lambda: _aot_scaled_fp4_quant(x, global_scale)\n    elif provider == \"torch_ref\":\n        fn = lambda: _torch_ref_quant(x, global_scale)\n    else:\n        raise ValueError(f\"Unknown provider: {provider}\")\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    if not _NVFP4_SUPPORTED:\n        print(\"[skip] NVFP4 quant benchmark requires sm100+ with CUDA 12.8+.\")\n        sys.exit(0)\n    if not _FLASHINFER_QUANT_AVAILABLE:\n        print(\n            f\"[info] flashinfer quant baseline unavailable: {_FLASHINFER_QUANT_REASON}\"\n        )\n    if not _AOT_QUANT_AVAILABLE:\n        print(f\"[info] legacy AOT quant baseline unavailable: {_AOT_QUANT_REASON}\")\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py",
    "content": "from __future__ import annotations\n\nimport sys\n\nimport torch\nimport triton\n\nfrom sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark\nfrom sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm, scaled_fp4_quant\nfrom sglang.srt.utils import is_sm100_supported\n\nFLOAT4_E2M1_MAX = 6.0\nFLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max\nBLOCK_SIZE = 16\n_NVFP4_SUPPORTED = is_sm100_supported()\n\nK_E2M1_TO_FLOAT = [\n    0.0,\n    0.5,\n    1.0,\n    1.5,\n    2.0,\n    3.0,\n    4.0,\n    6.0,\n    0.0,\n    -0.5,\n    -1.0,\n    -1.5,\n    -2.0,\n    -3.0,\n    -4.0,\n    -6.0,\n]\n\n\ndef _dequantize_to_fp16(\n    tensor_fp4: torch.Tensor, tensor_sf: torch.Tensor, global_scale: torch.Tensor\n):\n    m, packed_k = tensor_fp4.shape\n    k = packed_k * 2\n    flat = tensor_fp4.flatten()\n    high = (flat & 0xF0) >> 4\n    low = flat & 0x0F\n    f_h = torch.tensor([K_E2M1_TO_FLOAT[x] for x in high], device=tensor_fp4.device)\n    f_l = torch.tensor([K_E2M1_TO_FLOAT[x] for x in low], device=tensor_fp4.device)\n    val = torch.stack((f_l, f_h), dim=-1).reshape(m, k)\n\n    rounded_m = ((m + 128 - 1) // 128) * 128\n    scale_n = k // BLOCK_SIZE\n    rounded_n = ((scale_n + 4 - 1) // 4) * 4\n    sf = tensor_sf.view(torch.float8_e4m3fn)\n    tmp = torch.reshape(sf, (1, rounded_m // 128, rounded_n // 4, 32, 4, 4))\n    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))\n    scale = torch.reshape(tmp, (rounded_m, rounded_n))[:m, :scale_n].to(torch.float32)\n    scale = scale / global_scale\n\n    return (val.view(m, scale_n, BLOCK_SIZE) * scale.unsqueeze(-1)).reshape(m, k)\n\n\ndef _aot_cutlass_scaled_fp4_mm(\n    a: torch.Tensor,\n    b: torch.Tensor,\n    block_scale_a: torch.Tensor,\n    block_scale_b: torch.Tensor,\n    alpha: torch.Tensor,\n    out_dtype: torch.dtype,\n) -> torch.Tensor:\n    out = torch.empty((a.shape[0], b.shape[0]), dtype=out_dtype, device=a.device)\n    torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(\n 
       out, a, b, block_scale_a, block_scale_b, alpha\n    )\n    return out\n\n\ndef _probe_legacy_aot_scaled_mm() -> tuple[bool, str]:\n    if not torch.cuda.is_available():\n        return False, \"CUDA is not available.\"\n    if not _NVFP4_SUPPORTED:\n        return False, \"NVFP4 benchmarks require sm100+ with CUDA 12.8+.\"\n    try:\n        import sgl_kernel  # noqa: F401\n    except Exception as e:\n        return False, f\"import sgl_kernel failed: {e}\"\n    if not hasattr(torch.ops, \"sgl_kernel\"):\n        return False, \"torch.ops.sgl_kernel is not registered.\"\n    op = getattr(torch.ops.sgl_kernel, \"cutlass_scaled_fp4_mm\", None)\n    if op is None or not hasattr(op, \"default\"):\n        return False, \"torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default is missing.\"\n    try:\n        m, n, k = 16, 32, 64\n        a = torch.randn((m, k), dtype=torch.bfloat16, device=\"cuda\")\n        b = torch.randn((n, k), dtype=torch.bfloat16, device=\"cuda\")\n        a_global_scale = (\n            FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / torch.amax(a.flatten(), dim=-1)\n        ).to(torch.float32)\n        b_global_scale = (\n            FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / torch.amax(b.flatten(), dim=-1)\n        ).to(torch.float32)\n        alpha = 1.0 / (a_global_scale * b_global_scale)\n        a_fp4, a_sf = scaled_fp4_quant(a, a_global_scale)\n        b_fp4, b_sf = scaled_fp4_quant(b, b_global_scale)\n        _aot_cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_sf, b_sf, alpha, torch.bfloat16)\n        torch.cuda.synchronize()\n    except Exception as e:\n        return False, f\"calling AOT scaled_mm op failed: {e}\"\n    return True, \"\"\n\n\n_AOT_SCALED_MM_AVAILABLE, _AOT_SCALED_MM_REASON = _probe_legacy_aot_scaled_mm()\n\nshape_range = get_benchmark_range(\n    full_range=[(128, 4096, 4096), (512, 4096, 4096), (1024, 8192, 4096)],\n    ci_range=[(128, 4096, 4096)],\n)\n\nline_vals = [\"jit\"]\nline_names = [\"JIT NVFP4 GEMM\"]\nstyles = [(\"green\", 
\"-\")]\nif _AOT_SCALED_MM_AVAILABLE:\n    line_vals.append(\"aot_sgl_kernel\")\n    line_names.append(\"AOT NVFP4 GEMM\")\n    styles.append((\"orange\", \"-\"))\nline_vals.append(\"torch_ref\")\nline_names.append(\"Torch Ref\")\nstyles.append((\"blue\", \"-\"))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"m\", \"n\", \"k\"],\n        x_vals=shape_range,\n        x_log=False,\n        line_arg=\"provider\",\n        line_vals=line_vals,\n        line_names=line_names,\n        styles=styles,\n        ylabel=\"us\",\n        plot_name=\"nvfp4-scaled-mm-performance\",\n        args={},\n    )\n)\ndef benchmark(m, n, k, provider):\n    a = torch.randn((m, k), dtype=torch.bfloat16, device=\"cuda\")\n    b = torch.randn((n, k), dtype=torch.bfloat16, device=\"cuda\")\n\n    a_global_scale = (\n        FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / torch.amax(a.flatten(), dim=-1)\n    ).to(torch.float32)\n    b_global_scale = (\n        FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / torch.amax(b.flatten(), dim=-1)\n    ).to(torch.float32)\n    alpha = 1.0 / (a_global_scale * b_global_scale)\n\n    a_fp4, a_sf = scaled_fp4_quant(a, a_global_scale)\n    b_fp4, b_sf = scaled_fp4_quant(b, b_global_scale)\n\n    if provider == \"jit\":\n        fn = lambda: cutlass_scaled_fp4_mm(\n            a_fp4, b_fp4, a_sf, b_sf, alpha, torch.bfloat16\n        )\n    elif provider == \"aot_sgl_kernel\":\n        fn = lambda: _aot_cutlass_scaled_fp4_mm(\n            a_fp4, b_fp4, a_sf, b_sf, alpha, torch.bfloat16\n        )\n    elif provider == \"torch_ref\":\n        a_ref = _dequantize_to_fp16(a_fp4, a_sf, a_global_scale)\n        b_ref = _dequantize_to_fp16(b_fp4, b_sf, b_global_scale)\n        fn = lambda: torch.matmul(a_ref, b_ref.t())\n    else:\n        raise ValueError(f\"Unknown provider: {provider}\")\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    if not _NVFP4_SUPPORTED:\n        print(\"[skip] NVFP4 scaled_mm benchmark requires sm100+ 
with CUDA 12.8+.\")\n        sys.exit(0)\n    if not _AOT_SCALED_MM_AVAILABLE:\n        print(\n            f\"[info] legacy AOT scaled_mm baseline unavailable: {_AOT_SCALED_MM_REASON}\"\n        )\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_per_tensor_quant_fp8.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nimport triton\nimport triton.testing\n\nfrom sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark\nfrom sglang.jit_kernel.per_tensor_quant_fp8 import per_tensor_quant_fp8\n\ntry:\n    from vllm import _custom_ops as ops\n\n    VLLM_AVAILABLE = True\nexcept ImportError:\n    ops = None\n    VLLM_AVAILABLE = False\n\ntry:\n    from sglang.srt.utils import is_hip\n\n    _is_hip = is_hip()\nexcept ImportError:\n    _is_hip = False\n\nfp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn\n\n\ndef vllm_scaled_fp8_quant(\n    input: torch.Tensor,\n    scale: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    if not VLLM_AVAILABLE:\n        return sglang_scaled_fp8_quant(input, scale)\n    return ops.scaled_fp8_quant(input, scale)\n\n\ndef sglang_scaled_fp8_quant(\n    input: torch.Tensor,\n    scale: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)\n    is_static = True\n    if scale is None:\n        scale = torch.zeros(1, device=input.device, dtype=torch.float32)\n        is_static = False\n    per_tensor_quant_fp8(input, output, scale, is_static)\n\n    return output, scale\n\n\ndef calculate_diff(batch_size: int, seq_len: int):\n    device = torch.device(\"cuda\")\n    x = torch.rand((batch_size, seq_len), dtype=torch.bfloat16, device=device)\n\n    if not VLLM_AVAILABLE:\n        print(\"vLLM not available, skipping comparison\")\n        return\n\n    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)\n    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)\n\n    vllm_out = vllm_out.to(torch.float32)\n    sglang_out = sglang_out.to(torch.float32)\n\n    triton.testing.assert_close(vllm_out, sglang_out, rtol=1e-3, atol=1e-3)\n    triton.testing.assert_close(vllm_scale, sglang_scale, 
rtol=1e-3, atol=1e-3)\n\n\n# Benchmark configuration\nelement_range = get_benchmark_range(\n    full_range=[2**n for n in range(10, 20)],\n    ci_range=[16384],\n)\n\nif VLLM_AVAILABLE:\n    line_vals = [\"vllm\", \"sglang\"]\n    line_names = [\"VLLM\", \"SGL Kernel\"]\n    styles = [(\"blue\", \"-\"), (\"green\", \"-\")]\nelse:\n    line_vals = [\"sglang\"]\n    line_names = [\"SGL Kernel\"]\n    styles = [(\"green\", \"-\")]\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"element_count\"],\n        x_vals=element_range,\n        line_arg=\"provider\",\n        line_vals=line_vals,\n        line_names=line_names,\n        styles=styles,\n        ylabel=\"us\",\n        plot_name=\"per-tensor-quant-fp8-performance\",\n        args={},\n    )\n)\ndef benchmark(element_count, provider):\n    dtype = torch.float16\n    device = torch.device(\"cuda\")\n\n    x = torch.randn(element_count, 4096, device=device, dtype=dtype)\n\n    if provider == \"vllm\":\n        fn = lambda: vllm_scaled_fp8_quant(x.clone())\n    elif provider == \"sglang\":\n        fn = lambda: sglang_scaled_fp8_quant(x.clone())\n    else:\n        raise ValueError(f\"Unknown provider: {provider}\")\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    calculate_diff(batch_size=4, seq_len=4096)\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_per_token_group_quant_8bit.py",
    "content": "import itertools\nfrom typing import Any, Dict, List\n\nimport torch\nimport triton\nfrom sgl_kernel.test_utils import create_per_token_group_quant_test_data\n\nfrom sglang.jit_kernel.benchmark.utils import (\n    get_benchmark_range,\n    is_in_ci,\n)\nfrom sglang.jit_kernel.per_token_group_quant_8bit import (\n    per_token_group_quant_8bit as sglang_per_token_group_quant_8bit,\n)\nfrom sglang.srt.layers.quantization.fp8_kernel import (\n    create_per_token_group_quant_fp8_output_scale,\n)\nfrom sglang.srt.layers.quantization.fp8_kernel import (\n    per_token_group_quant_8bit as triton_per_token_group_quant_8bit,\n)\nfrom sglang.srt.utils import is_hip\nfrom sglang.srt.utils.bench_utils import bench_kineto\n\nIS_CI = is_in_ci()\n\n_is_hip = is_hip()\nfp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn\n\nNUM_TESTS = 30 if IS_CI else 300\n\nGROUP_SIZE_RANGE = [128]\nDST_DTYPE_RANGE = [fp8_type_]\n\n# ---- GEMM-like branch (num_ranks=None) ----\nNUM_TOKENS_RANGE_GEMM = get_benchmark_range(\n    full_range=[1, 4, 16, 64, 256, 768, 2048, 8192, 16384],\n    ci_range=[768],\n)\nHIDDEN_DIM_RANGE_GEMM = [1536, 7168, 16384]\nNUM_RANKS_RANGE_GEMM = [None]\n\n\nFLAGS_GEMM_FULL: List[Dict[str, Any]] = [\n    dict(\n        column_major_scales=False,\n        scale_tma_aligned=False,\n        scale_ue8m0=False,\n        fuse_silu_and_mul=False,\n        masked_layout_mode=None,\n    ),\n    dict(\n        column_major_scales=True,\n        scale_tma_aligned=False,\n        scale_ue8m0=False,\n        fuse_silu_and_mul=False,\n        masked_layout_mode=None,\n    ),\n    dict(\n        column_major_scales=True,\n        scale_tma_aligned=True,\n        scale_ue8m0=False,\n        fuse_silu_and_mul=False,\n        masked_layout_mode=None,\n    ),\n    dict(\n        column_major_scales=True,\n        scale_tma_aligned=True,\n        scale_ue8m0=True,\n        fuse_silu_and_mul=False,\n        masked_layout_mode=None,\n    
),\n]\nFLAGS_GEMM_CI: List[Dict[str, Any]] = [\n    dict(\n        column_major_scales=True,\n        scale_tma_aligned=True,\n        scale_ue8m0=True,\n        fuse_silu_and_mul=False,\n        masked_layout_mode=None,\n    ),\n]\nFLAGS_RANGE_GEMM = get_benchmark_range(\n    full_range=FLAGS_GEMM_FULL, ci_range=FLAGS_GEMM_CI\n)\n\nCONFIGS_GEMM = list(\n    itertools.product(\n        NUM_TOKENS_RANGE_GEMM,\n        HIDDEN_DIM_RANGE_GEMM,\n        GROUP_SIZE_RANGE,\n        NUM_RANKS_RANGE_GEMM,\n        DST_DTYPE_RANGE,\n        FLAGS_RANGE_GEMM,\n    )\n)\n\n# ---- MoE-like / multi-rank branch (hidden_dim=2048, num_ranks in {8,16,32,48}) ----\nNUM_TOKENS_RANGE_MOE = get_benchmark_range(\n    full_range=[1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],\n    ci_range=[768 * 8],\n)\nHIDDEN_DIM_RANGE_MOE = [2048]\nNUM_RANKS_RANGE_MOE = get_benchmark_range(\n    full_range=[8, 16, 32, 48],\n    ci_range=[48],\n)\n\nFLAGS_MOE: List[Dict[str, Any]] = [\n    dict(\n        column_major_scales=True,\n        scale_tma_aligned=True,\n        scale_ue8m0=True,\n        fuse_silu_and_mul=True,\n        masked_layout_mode=None,\n    ),\n    dict(\n        column_major_scales=True,\n        scale_tma_aligned=True,\n        scale_ue8m0=True,\n        fuse_silu_and_mul=True,\n        masked_layout_mode=\"balanced\",\n    ),\n    dict(\n        column_major_scales=True,\n        scale_tma_aligned=True,\n        scale_ue8m0=True,\n        fuse_silu_and_mul=True,\n        masked_layout_mode=\"imbalanced\",\n    ),\n    dict(\n        column_major_scales=True,\n        scale_tma_aligned=True,\n        scale_ue8m0=True,\n        fuse_silu_and_mul=True,\n        masked_layout_mode=\"extreme\",\n    ),\n]\nFLAGS_RANGE_MOE = get_benchmark_range(full_range=FLAGS_MOE, ci_range=FLAGS_MOE)\n\nCONFIGS_MOE = list(\n    itertools.product(\n        NUM_TOKENS_RANGE_MOE,\n        HIDDEN_DIM_RANGE_MOE,\n        GROUP_SIZE_RANGE,\n        NUM_RANKS_RANGE_MOE,\n        DST_DTYPE_RANGE,\n        
FLAGS_RANGE_MOE,\n    )\n)\n\n# ---- Final configs ----\nCONFIGS = CONFIGS_GEMM + CONFIGS_MOE\n\nLINE_VALS = [\"triton\", \"sglang\"]\nLINE_NAMES = [\"Triton (Inaccurate)\", \"SGL Kernel\"]\nSTYLES = [(\"blue\", \"-\"), (\"green\", \"-\")]\n\n\ndef _flatten_to_2d(t: torch.Tensor) -> torch.Tensor:\n    \"\"\"Reshape a tensor with 3+ dims to 2D by merging all leading dims.\"\"\"\n    if t.ndim <= 2:\n        return t\n    return t.reshape(-1, t.shape[-1])\n\n\ndef _make_sglang_bench_fn(\n    x: torch.Tensor,\n    group_size: int,\n    dst_dtype: torch.dtype,\n    flags: dict,\n):\n    \"\"\"\n    Adapter that pre-allocates output tensors and returns a zero-arg callable\n    matching the JIT kernel's signature.\n\n    The JIT kernel does not support fuse_silu_and_mul, so when enabled we\n    pre-compute silu+mul on the input. bench_kineto only times the kernel\n    matching the given name, so the pre-processing is not included.\n\n    The JIT kernel expects 2D tensors, so any higher-dimensional inputs\n    (e.g. 
from masked_layout_mode) are flattened to 2D.\n    \"\"\"\n    fuse_silu_and_mul = flags.get(\"fuse_silu_and_mul\", False)\n    column_major_scales = flags.get(\"column_major_scales\", False)\n    scale_tma_aligned = flags.get(\"scale_tma_aligned\", False)\n    scale_ue8m0 = flags.get(\"scale_ue8m0\", False)\n\n    # JIT kernel does not support fuse_silu_and_mul; pre-compute it\n    if fuse_silu_and_mul:\n        half = x.shape[-1] // 2\n        x_input = torch.nn.functional.silu(x[..., :half]) * x[..., half:]\n    else:\n        x_input = x\n\n    # JIT kernel expects 2D (num_tokens, hidden_dim); flatten if needed\n    x_input = _flatten_to_2d(x_input.contiguous())\n\n    out_shape = x_input.shape\n    output_q = torch.empty(out_shape, device=x.device, dtype=dst_dtype)\n\n    fp8_max = torch.finfo(dst_dtype).max\n    fp8_min = -fp8_max\n\n    output_s = create_per_token_group_quant_fp8_output_scale(\n        x_shape=out_shape,\n        device=x.device,\n        group_size=group_size,\n        column_major_scales=column_major_scales,\n        scale_tma_aligned=scale_tma_aligned,\n        scale_ue8m0=scale_ue8m0,\n    )\n\n    def _run():\n        sglang_per_token_group_quant_8bit(\n            input=x_input,\n            output_q=output_q,\n            output_s=output_s,\n            group_size=group_size,\n            eps=1e-10,\n            fp8_min=fp8_min,\n            fp8_max=fp8_max,\n            scale_ue8m0=scale_ue8m0,\n        )\n\n    return _run\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\n            \"num_tokens\",\n            \"hidden_dim\",\n            \"group_size\",\n            \"num_ranks\",\n            \"dst_dtype\",\n            \"flags\",\n        ],\n        x_vals=CONFIGS,\n        line_arg=\"provider\",\n        line_vals=LINE_VALS,\n        # Triton has multi kernels and we only report the time for the core one\n        line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        
plot_name=\"per-token-group-quant-8bit-performance\",\n        args={},\n    )\n)\ndef benchmark(\n    num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags, provider\n):\n    print(\n        f\"Testing: {num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=} {provider=}\"\n    )\n\n    x, masked_m = create_per_token_group_quant_test_data(\n        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags\n    )\n\n    if provider == \"triton\":\n        fn = triton_per_token_group_quant_8bit\n        kernel_names = \"_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel\"\n        bench_fn = lambda: fn(\n            x=x,\n            masked_m=masked_m,\n            group_size=group_size,\n            dst_dtype=dst_dtype,\n            **{k: v for k, v in flags.items() if k not in [\"masked_layout_mode\"]},\n        )\n    elif provider == \"sglang\":\n        kernel_names = \"per_token_group_quant_8bit_kernel\"\n        bench_fn = _make_sglang_bench_fn(\n            x=x,\n            group_size=group_size,\n            dst_dtype=dst_dtype,\n            flags=flags,\n        )\n    else:\n        raise ValueError(f\"Unknown provider: {provider}\")\n\n    time_s = bench_kineto(bench_fn, kernel_names=kernel_names, num_tests=NUM_TESTS)\n    return time_s * 1e6\n\n\nif __name__ == \"__main__\":\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_qknorm.py",
    "content": "import itertools\n\nimport torch\nimport triton\nimport triton.testing\nfrom sgl_kernel import rmsnorm\n\nfrom sglang.jit_kernel.benchmark.utils import (\n    DEFAULT_DEVICE,\n    DEFAULT_DTYPE,\n    get_benchmark_range,\n    run_benchmark,\n)\nfrom sglang.jit_kernel.norm import fused_inplace_qknorm\nfrom sglang.srt.utils import get_current_device_stream_fast\n\nalt_stream = torch.cuda.Stream()\n\n\ndef sglang_aot_qknorm(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n) -> None:\n\n    head_dim = q.shape[-1]\n    q = q.view(-1, head_dim)\n    k = k.view(-1, head_dim)\n\n    current_stream = get_current_device_stream_fast()\n    alt_stream.wait_stream(current_stream)\n    rmsnorm(q, q_weight, out=q)\n    with torch.cuda.stream(alt_stream):\n        rmsnorm(k, k_weight, out=k)\n    current_stream.wait_stream(alt_stream)\n\n\ndef sglang_jit_qknorm(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n) -> None:\n\n    fused_inplace_qknorm(q, k, q_weight, k_weight)\n\n\ndef flashinfer_qknorm(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n) -> None:\n    from flashinfer import rmsnorm\n\n    rmsnorm(q, q_weight, out=q)\n    rmsnorm(k, k_weight, out=k)\n\n\n@torch.compile()\ndef torch_impl_qknorm(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n    eps: float = 1e-6,\n) -> None:\n    q_mean = q.float().pow(2).mean(dim=-1, keepdim=True)\n    k_mean = k.float().pow(2).mean(dim=-1, keepdim=True)\n    q_norm = (q_mean + eps).rsqrt()\n    k_norm = (k_mean + eps).rsqrt()\n    q.copy_(q.float() * q_norm * q_weight.float())\n    k.copy_(k.float() * k_norm * k_weight.float())\n\n\nBS_RANGE = get_benchmark_range(\n    full_range=[2**n for n in range(0, 14)],\n    ci_range=[16],\n)\nGQA_RANGE = get_benchmark_range(\n    full_range=[4, 8],\n    
ci_range=[4],\n)\nKV_HEAD_RANGE = get_benchmark_range(\n    full_range=[1, 2, 4, 8],\n    ci_range=[1],\n)\nHEAD_DIM_RANGE = get_benchmark_range(\n    full_range=[128, 256, 512, 1024],\n    ci_range=[128],\n)\n\nLINE_VALS = [\"aot\", \"jit\", \"flashinfer\", \"torch\"]\nLINE_NAMES = [\"SGL AOT Kernel\", \"SGL JIT Kernel\", \"FlashInfer\", \"PyTorch\"]\nSTYLES = [(\"orange\", \"-\"), (\"blue\", \"--\"), (\"green\", \"-.\"), (\"red\", \":\")]\n\nconfigs = list(itertools.product(HEAD_DIM_RANGE, GQA_RANGE, KV_HEAD_RANGE, BS_RANGE))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"head_dim\", \"GQA\", \"num_kv_heads\", \"batch_size\"],\n        x_vals=configs,\n        line_arg=\"provider\",\n        line_vals=LINE_VALS,\n        line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        plot_name=\"qknorm-performance\",\n        args={},\n    )\n)\ndef benchmark(\n    head_dim: int, GQA: int, num_kv_heads: int, batch_size: int, provider: str\n):\n    num_qo_heads = GQA * num_kv_heads\n    q = torch.randn(\n        (batch_size, num_qo_heads, head_dim), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE\n    )\n    k = torch.randn(\n        (batch_size, num_kv_heads, head_dim), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE\n    )\n    q_weight = torch.randn(head_dim, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE)\n    k_weight = torch.randn(head_dim, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE)\n    FN_MAP = {\n        \"aot\": sglang_aot_qknorm,\n        \"jit\": sglang_jit_qknorm,\n        \"flashinfer\": flashinfer_qknorm,\n        \"torch\": torch_impl_qknorm,\n    }\n    fn = lambda: FN_MAP[provider](q, k, q_weight, k_weight)\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_qknorm_across_heads.py",
    "content": "import itertools\nfrom typing import Tuple\n\nimport torch\nimport triton\nimport triton.testing\nfrom sgl_kernel import rmsnorm\n\nfrom sglang.jit_kernel.benchmark.utils import is_in_ci, run_benchmark\nfrom sglang.jit_kernel.norm import fused_inplace_qknorm_across_heads\nfrom sglang.srt.utils import get_current_device_stream_fast\n\nIS_CI = is_in_ci()\n\nalt_stream = torch.cuda.Stream()\n\n\ndef sglang_jit_qknorm_across_heads(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n) -> None:\n\n    fused_inplace_qknorm_across_heads(q, k, q_weight, k_weight)\n\n\ndef sglang_aot_qknorm_across_heads(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n) -> None:\n\n    current_stream = get_current_device_stream_fast()\n    alt_stream.wait_stream(current_stream)\n    rmsnorm(q, q_weight, out=q)\n    with torch.cuda.stream(alt_stream):\n        rmsnorm(k, k_weight, out=k)\n    current_stream.wait_stream(alt_stream)\n\n\ndef flashinfer_qknorm_across_heads(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n) -> None:\n    from flashinfer import rmsnorm\n\n    rmsnorm(q, q_weight, out=q)\n    rmsnorm(k, k_weight, out=k)\n\n\n@torch.compile()\ndef torch_impl_qknorm_across_heads(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n    eps: float = 1e-6,\n) -> None:\n    q_mean = q.float().pow(2).mean(dim=-1, keepdim=True)\n    k_mean = k.float().pow(2).mean(dim=-1, keepdim=True)\n    q_norm = (q_mean + eps).rsqrt()\n    k_norm = (k_mean + eps).rsqrt()\n    q.copy_(q.float() * q_norm * q_weight.float())\n    k.copy_(k.float() * k_norm * k_weight.float())\n\n\nDTYPE = torch.bfloat16\nDEVICE = \"cuda\"\n\nif IS_CI:\n    BS_RANGE = [16]\n    HIDDEN_DIM_RANGE = [1024]\nelse:\n    BS_RANGE = [2**n for n in range(0, 14)]\n    HIDDEN_DIM_RANGE = [512, 1024, 2048, 
4096, 8192]\n\nLINE_VALS = [\"jit\", \"aot\", \"flashinfer\", \"torch\"]\nLINE_NAMES = [\"SGL JIT Kernel\", \"SGL AOT Kernel\", \"FlashInfer\", \"PyTorch\"]\nSTYLES = [(\"blue\", \"-\"), (\"orange\", \"--\"), (\"green\", \"-.\"), (\"red\", \":\")]\n\nconfigs = list(itertools.product(BS_RANGE, HIDDEN_DIM_RANGE))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"batch_size\", \"hidden_dim\"],\n        x_vals=configs,\n        line_arg=\"provider\",\n        line_vals=LINE_VALS,\n        line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        plot_name=\"qknorm-across-heads-performance\",\n        args={},\n    )\n)\ndef benchmark(\n    batch_size: int, hidden_dim: int, provider: str\n) -> Tuple[float, float, float]:\n    q = torch.randn((batch_size, hidden_dim), dtype=DTYPE, device=DEVICE)\n    k = torch.randn((batch_size, hidden_dim), dtype=DTYPE, device=DEVICE)\n    q_weight = torch.randn(hidden_dim, dtype=DTYPE, device=DEVICE)\n    k_weight = torch.randn(hidden_dim, dtype=DTYPE, device=DEVICE)\n    FN_MAP = {\n        \"jit\": sglang_jit_qknorm_across_heads,\n        \"aot\": sglang_aot_qknorm_across_heads,\n        \"flashinfer\": flashinfer_qknorm_across_heads,\n        \"torch\": torch_impl_qknorm_across_heads,\n    }\n    fn = lambda: FN_MAP[provider](q, k, q_weight, k_weight)\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_qwen_image_modulation.py",
    "content": "from typing import Tuple\n\nimport torch\nimport triton.testing\n\nfrom sglang.jit_kernel.benchmark.utils import is_in_ci, run_benchmark_no_cudagraph\nfrom sglang.jit_kernel.diffusion.triton.norm import norm_infer\nfrom sglang.jit_kernel.diffusion.triton.scale_shift import (\n    fuse_layernorm_scale_shift_gate_select01_kernel,\n    fuse_residual_layernorm_scale_shift_gate_select01_kernel,\n    fuse_scale_shift_gate_select01_kernel,\n)\n\nif is_in_ci():\n    B_RANGE, S_RANGE, D_RANGE = [1], [128], [3072]\nelse:\n    B_RANGE, S_RANGE, D_RANGE = [1, 2], [128, 512, 2048], [1024, 1536, 3072]\n\nDTYPE = torch.bfloat16\nDEVICE = \"cuda\"\nEPS = 1e-6\nLINE_VALS = [\"split\", \"fused\"]\nLINE_NAMES = [\"Split Kernels\", \"Fused Triton\"]\nSTYLES = [(\"red\", \"-\"), (\"blue\", \"--\")]\nCONFIG = [(b, s, d) for b in B_RANGE for s in S_RANGE for d in D_RANGE]\n\n\ndef _make_common_inputs(batch_size: int, seq_len: int, hidden_size: int):\n    x = torch.randn(batch_size, seq_len, hidden_size, dtype=DTYPE, device=DEVICE)\n    weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE)\n    bias = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE)\n    index = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32, device=DEVICE)\n    scale0 = torch.randn(batch_size, hidden_size, dtype=DTYPE, device=DEVICE)\n    shift0 = torch.randn(batch_size, hidden_size, dtype=DTYPE, device=DEVICE)\n    gate0 = torch.randn(batch_size, hidden_size, dtype=DTYPE, device=DEVICE)\n    scale1 = torch.randn(batch_size, hidden_size, dtype=DTYPE, device=DEVICE)\n    shift1 = torch.randn(batch_size, hidden_size, dtype=DTYPE, device=DEVICE)\n    gate1 = torch.randn(batch_size, hidden_size, dtype=DTYPE, device=DEVICE)\n    return x, weight, bias, index, scale0, shift0, gate0, scale1, shift1, gate1\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"B\", \"S\", \"D\"],\n        x_vals=CONFIG,\n        line_arg=\"provider\",\n        
line_vals=LINE_VALS,\n        line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        plot_name=\"qwen_image_layernorm_scale_shift_gate_select01\",\n        args={},\n    )\n)\ndef bench_layernorm_scale_shift_gate_select01(\n    B: int, S: int, D: int, provider: str\n) -> Tuple[float, float, float]:\n    x, weight, bias, index, scale0, shift0, gate0, scale1, shift1, gate1 = (\n        _make_common_inputs(B, S, D)\n    )\n\n    if provider == \"split\":\n\n        def fn():\n            normalized = norm_infer(\n                x.view(-1, x.shape[-1]),\n                weight,\n                bias,\n                eps=EPS,\n                is_rms_norm=False,\n            ).view_as(x)\n            return fuse_scale_shift_gate_select01_kernel(\n                normalized,\n                scale0=scale0,\n                shift0=shift0,\n                gate0=gate0,\n                scale1=scale1,\n                shift1=shift1,\n                gate1=gate1,\n                index=index,\n            )\n\n    else:\n\n        def fn():\n            return fuse_layernorm_scale_shift_gate_select01_kernel(\n                x,\n                weight=weight,\n                bias=bias,\n                scale0=scale0,\n                shift0=shift0,\n                gate0=gate0,\n                scale1=scale1,\n                shift1=shift1,\n                gate1=gate1,\n                index=index,\n                eps=EPS,\n            )\n\n    return run_benchmark_no_cudagraph(fn)\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"B\", \"S\", \"D\"],\n        x_vals=CONFIG,\n        line_arg=\"provider\",\n        line_vals=LINE_VALS,\n        line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        plot_name=\"qwen_image_residual_layernorm_scale_shift_gate_select01\",\n        args={},\n    )\n)\ndef bench_residual_layernorm_scale_shift_gate_select01(\n    B: int, S: int, D: int, 
provider: str\n) -> Tuple[float, float, float]:\n    x, weight, bias, index, scale0, shift0, gate0, scale1, shift1, gate1 = (\n        _make_common_inputs(B, S, D)\n    )\n    residual = torch.randn_like(x)\n    residual_gate = torch.randn_like(x)\n\n    if provider == \"split\":\n\n        def fn():\n            residual_out = residual + residual_gate * x\n            normalized = norm_infer(\n                residual_out.view(-1, residual_out.shape[-1]),\n                weight,\n                bias,\n                eps=EPS,\n                is_rms_norm=False,\n            ).view_as(residual_out)\n            return fuse_scale_shift_gate_select01_kernel(\n                normalized,\n                scale0=scale0,\n                shift0=shift0,\n                gate0=gate0,\n                scale1=scale1,\n                shift1=shift1,\n                gate1=gate1,\n                index=index,\n            )\n\n    else:\n\n        def fn():\n            return fuse_residual_layernorm_scale_shift_gate_select01_kernel(\n                x,\n                residual=residual,\n                residual_gate=residual_gate,\n                weight=weight,\n                bias=bias,\n                scale0=scale0,\n                shift0=shift0,\n                gate0=gate0,\n                scale1=scale1,\n                shift1=shift1,\n                gate1=gate1,\n                index=index,\n                eps=EPS,\n            )\n\n    return run_benchmark_no_cudagraph(fn)\n\n\nif __name__ == \"__main__\":\n    print(f\"\\n{'=' * 80}\")\n    print(\"Benchmark: qwen_image layernorm + scale_shift_gate_select01\")\n    print(f\"{'=' * 80}\\n\")\n    bench_layernorm_scale_shift_gate_select01.run(print_data=True)\n\n    print(f\"\\n{'=' * 80}\")\n    print(\"Benchmark: qwen_image residual + layernorm + scale_shift_gate_select01\")\n    print(f\"{'=' * 80}\\n\")\n    bench_residual_layernorm_scale_shift_gate_select01.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_renorm.py",
    "content": "import itertools\n\nimport sgl_kernel\nimport torch\nimport triton\nimport triton.testing\n\nfrom sglang.jit_kernel.benchmark.utils import is_in_ci, run_benchmark_no_cudagraph\n\n\ndef torch_top_k_renorm_probs(probs, top_k):\n    \"\"\"Vectorized PyTorch implementation of top-k renormalization.\"\"\"\n    batch_size, vocab_size = probs.shape\n\n    # Handle scalar or tensor k\n    if isinstance(top_k, int):\n        k_val = min(max(top_k, 1), vocab_size)\n        # Get top-k indices for all batches at once\n        _, topk_indices = torch.topk(probs, k_val, dim=1, largest=True)\n\n        # Create mask: batch_size x vocab_size\n        mask = torch.zeros_like(probs)\n        mask.scatter_(1, topk_indices, 1.0)\n\n        # Vectorized renormalization\n        masked_probs = probs * mask\n        renorm_probs = masked_probs / (masked_probs.sum(dim=1, keepdim=True) + 1e-10)\n        return renorm_probs\n    else:\n        # Variable k per batch - need to handle separately\n        renorm_probs = torch.zeros_like(probs)\n        for i in range(batch_size):\n            k_val = min(max(top_k[i].item(), 1), vocab_size)\n            _, topk_indices = torch.topk(probs[i], k_val, largest=True)\n            mask = torch.zeros_like(probs[i])\n            mask[topk_indices] = 1.0\n            masked_probs = probs[i] * mask\n            renorm_probs[i] = masked_probs / (masked_probs.sum() + 1e-10)\n        return renorm_probs\n\n\ndef torch_top_p_renorm_probs(probs, top_p, eps=1e-5):\n    \"\"\"Vectorized PyTorch implementation of top-p renormalization.\"\"\"\n    batch_size, vocab_size = probs.shape\n\n    # Handle scalar or tensor p\n    if isinstance(top_p, float):\n        p_val = top_p\n        # Vectorized implementation for uniform top_p\n        # Sort probs in descending order\n        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)\n        cumsum_probs = torch.cumsum(sorted_probs, dim=1)\n\n        # Find cutoff: where cumsum 
exceeds top_p\n        cutoff_mask = cumsum_probs <= p_val\n        # Keep at least one token (the highest prob)\n        cutoff_mask[:, 0] = True\n\n        # Create mask in original order\n        mask = torch.zeros_like(probs)\n        mask.scatter_(1, sorted_indices, cutoff_mask.float())\n\n        # Vectorized renormalization\n        masked_probs = probs * mask\n        renorm_probs = masked_probs / (masked_probs.sum(dim=1, keepdim=True) + eps)\n        return renorm_probs\n    else:\n        # Variable p per batch - need to handle separately\n        renorm_probs = torch.zeros_like(probs)\n        for i in range(batch_size):\n            p_val = top_p[i].item()\n            sorted_prob, indices = torch.sort(probs[i], descending=False)\n            cdf = torch.cumsum(sorted_prob, dim=-1)\n            mask = torch.zeros(vocab_size, dtype=torch.float32, device=probs.device)\n            mask.scatter_(0, indices, (cdf >= (1 - p_val) - eps).float())\n            masked_probs = probs[i] * mask\n            renorm_probs[i] = masked_probs / (masked_probs.sum() + eps)\n        return renorm_probs\n\n\ndef torch_top_k_mask_logits(logits, top_k):\n    \"\"\"Vectorized PyTorch implementation of top-k logits masking.\"\"\"\n    batch_size, vocab_size = logits.shape\n\n    # Handle scalar or tensor k\n    if isinstance(top_k, int):\n        k_val = min(max(top_k, 1), vocab_size)\n        # Get top-k indices for all batches at once\n        _, topk_indices = torch.topk(logits, k_val, dim=1, largest=True)\n\n        # Create masked logits: start with -inf everywhere\n        masked_logits = torch.full_like(logits, float(\"-inf\"))\n        # Scatter the top-k values back\n        masked_logits.scatter_(1, topk_indices, logits.gather(1, topk_indices))\n    else:\n        # Variable k per batch - need to handle separately\n        masked_logits = torch.full_like(logits, float(\"-inf\"))\n        for i in range(batch_size):\n            k_val = min(max(top_k[i].item(), 1), 
vocab_size)\n            _, topk_indices = torch.topk(logits[i], k_val, largest=True)\n            masked_logits[i, topk_indices] = logits[i, topk_indices]\n\n    return masked_logits\n\n\ndef calculate_diff_top_k_renorm(batch_size, vocab_size, k):\n    \"\"\"Compare Torch reference and SGLang kernel for top-k renorm correctness.\"\"\"\n    torch.manual_seed(42)\n    device = torch.device(\"cuda\")\n\n    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)\n    probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)\n\n    top_k_tensor = torch.full((batch_size,), k, device=device, dtype=torch.int32)\n\n    torch_output = torch_top_k_renorm_probs(probs, top_k_tensor)\n    sglang_output = sgl_kernel.top_k_renorm_prob(probs, top_k_tensor)\n\n    torch.testing.assert_close(torch_output, sglang_output, rtol=1e-3, atol=1e-3)\n\n\ndef calculate_diff_top_p_renorm(batch_size, vocab_size, p):\n    \"\"\"Compare Torch reference and SGLang kernel for top-p renorm correctness.\"\"\"\n    torch.manual_seed(42)\n    device = torch.device(\"cuda\")\n\n    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)\n    probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)\n\n    top_p_tensor = torch.full((batch_size,), p, device=device, dtype=torch.float32)\n\n    torch_output = torch_top_p_renorm_probs(probs, top_p_tensor)\n    sglang_output = sgl_kernel.top_p_renorm_prob(probs, top_p_tensor)\n\n    torch.testing.assert_close(torch_output, sglang_output, rtol=1e-3, atol=1e-3)\n\n\ndef calculate_diff_top_k_mask(batch_size, vocab_size, k):\n    \"\"\"Compare Torch reference and SGLang kernel for top-k mask correctness.\"\"\"\n    torch.manual_seed(42)\n    device = torch.device(\"cuda\")\n\n    logits = torch.randn(batch_size, vocab_size, device=device) * 5\n    top_k_tensor = torch.full((batch_size,), k, device=device, dtype=torch.int32)\n\n    torch_output = torch_top_k_mask_logits(logits, top_k_tensor)\n    sglang_output = 
sgl_kernel.top_k_mask_logits(logits, top_k_tensor)\n\n    torch.testing.assert_close(torch_output, sglang_output, rtol=1e-3, atol=1e-3)\n\n\n# Parameter space - simplified for CI\nif is_in_ci():\n    batch_size_range = [16]\n    vocab_size_range = [111]\n    k_range = [10]\n    p_range = [0.5]\nelse:\n    batch_size_range = [16, 64, 128]\n    vocab_size_range = [111, 32000, 128256]\n    k_range = [10, 100, 500]\n    p_range = [0.1, 0.5, 0.9]\n\nconfigs_k = list(itertools.product(batch_size_range, vocab_size_range, k_range))\nconfigs_p = list(itertools.product(batch_size_range, vocab_size_range, p_range))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"batch_size\", \"vocab_size\", \"k\"],\n        x_vals=configs_k,\n        line_arg=\"provider\",\n        line_vals=[\"torch\", \"sglang\"],\n        line_names=[\"Torch Reference\", \"SGL Kernel\"],\n        styles=[(\"red\", \"-\"), (\"green\", \"-\")],\n        ylabel=\"us\",\n        plot_name=\"top-k-renorm-probs-performance\",\n        args={},\n    )\n)\ndef benchmark_top_k_renorm(batch_size, vocab_size, k, provider):\n    # Skip invalid configurations\n    if k >= vocab_size:\n        return float(\"nan\"), float(\"nan\"), float(\"nan\")\n\n    torch.manual_seed(42)\n    device = torch.device(\"cuda\")\n\n    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)\n    probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)\n    top_k_tensor = torch.full((batch_size,), k, device=device, dtype=torch.int32)\n\n    if provider == \"torch\":\n        fn = lambda: torch_top_k_renorm_probs(probs.clone(), top_k_tensor)\n    elif provider == \"sglang\":\n        fn = lambda: sgl_kernel.top_k_renorm_prob(probs.clone(), top_k_tensor)\n\n    return run_benchmark_no_cudagraph(fn)\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"batch_size\", \"vocab_size\", \"p\"],\n        x_vals=configs_p,\n        line_arg=\"provider\",\n        
line_vals=[\"torch\", \"sglang\"],\n        line_names=[\"Torch Reference\", \"SGL Kernel\"],\n        styles=[(\"red\", \"-\"), (\"blue\", \"-\")],\n        ylabel=\"us\",\n        plot_name=\"top-p-renorm-probs-performance\",\n        args={},\n    )\n)\ndef benchmark_top_p_renorm(batch_size, vocab_size, p, provider):\n    torch.manual_seed(42)\n    device = torch.device(\"cuda\")\n\n    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)\n    probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)\n    top_p_tensor = torch.full((batch_size,), p, device=device, dtype=torch.float32)\n\n    if provider == \"torch\":\n        fn = lambda: torch_top_p_renorm_probs(probs.clone(), top_p_tensor)\n    elif provider == \"sglang\":\n        fn = lambda: sgl_kernel.top_p_renorm_prob(probs.clone(), top_p_tensor)\n\n    return run_benchmark_no_cudagraph(fn)\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"batch_size\", \"vocab_size\", \"k\"],\n        x_vals=configs_k,\n        line_arg=\"provider\",\n        line_vals=[\"torch\", \"sglang\"],\n        line_names=[\"Torch Reference\", \"SGL Kernel\"],\n        styles=[(\"red\", \"-\"), (\"orange\", \"-\")],\n        ylabel=\"us\",\n        plot_name=\"top-k-mask-logits-performance\",\n        args={},\n    )\n)\ndef benchmark_top_k_mask(batch_size, vocab_size, k, provider):\n    # Skip invalid configurations\n    if k >= vocab_size:\n        return float(\"nan\"), float(\"nan\"), float(\"nan\")\n\n    torch.manual_seed(42)\n    device = torch.device(\"cuda\")\n\n    logits = torch.randn(batch_size, vocab_size, device=device) * 5\n    top_k_tensor = torch.full((batch_size,), k, device=device, dtype=torch.int32)\n\n    if provider == \"torch\":\n        fn = lambda: torch_top_k_mask_logits(logits.clone(), top_k_tensor)\n    elif provider == \"sglang\":\n        fn = lambda: sgl_kernel.top_k_mask_logits(logits.clone(), top_k_tensor)\n\n    return 
run_benchmark_no_cudagraph(fn)\n\n\nif __name__ == \"__main__\":\n    print(\"=\" * 60)\n    print(\"Running correctness checks...\")\n    print(\"=\" * 60)\n\n    # Correctness checks - simplified for CI\n    if is_in_ci():\n        test_configs_k = [configs_k[0]] if configs_k else [(16, 111, 10)]\n        test_configs_p = [configs_p[0]] if configs_p else [(16, 111, 0.5)]\n    else:\n        test_configs_k = configs_k[:3]  # Test first 3 configs\n        test_configs_p = configs_p[:3]\n\n    print(\"\\n1. Testing top_k_renorm_probs...\")\n    for cfg in test_configs_k:\n        batch_size, vocab_size, k = cfg\n        if k < vocab_size:  # Skip invalid configs\n            calculate_diff_top_k_renorm(batch_size, vocab_size, k)\n            print(\n                f\"  ✓ Passed: batch_size={batch_size}, vocab_size={vocab_size}, k={k}\"\n            )\n\n    print(\"\\n2. Testing top_p_renorm_probs...\")\n    for cfg in test_configs_p:\n        calculate_diff_top_p_renorm(*cfg)\n        batch_size, vocab_size, p = cfg\n        print(f\"  ✓ Passed: batch_size={batch_size}, vocab_size={vocab_size}, p={p}\")\n\n    print(\"\\n3. Testing top_k_mask_logits...\")\n    for cfg in test_configs_k:\n        batch_size, vocab_size, k = cfg\n        if k < vocab_size:  # Skip invalid configs\n            calculate_diff_top_k_mask(batch_size, vocab_size, k)\n            print(\n                f\"  ✓ Passed: batch_size={batch_size}, vocab_size={vocab_size}, k={k}\"\n            )\n\n    print(\"\\n\" + \"=\" * 60)\n    print(\"All correctness checks passed!\")\n    print(\"=\" * 60)\n\n    print(\"\\n\" + \"=\" * 60)\n    print(\"Starting performance benchmarks...\")\n    print(\"=\" * 60)\n\n    print(\"\\n1. Benchmarking top_k_renorm_probs...\")\n    benchmark_top_k_renorm.run(print_data=True)\n\n    print(\"\\n2. Benchmarking top_p_renorm_probs...\")\n    benchmark_top_p_renorm.run(print_data=True)\n\n    print(\"\\n3. 
Benchmarking top_k_mask_logits...\")\n    benchmark_top_k_mask.run(print_data=True)\n\n    print(\"\\n\" + \"=\" * 60)\n    print(\"Benchmarking complete!\")\n    print(\"=\" * 60)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_rmsnorm.py",
    "content": "import itertools\n\nimport torch\nimport triton\nimport triton.testing\nfrom flashinfer import rmsnorm as fi_rmsnorm\nfrom sgl_kernel import rmsnorm\n\nfrom sglang.jit_kernel.benchmark.utils import (\n    DEFAULT_DEVICE,\n    DEFAULT_DTYPE,\n    get_benchmark_range,\n    run_benchmark,\n)\nfrom sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm\n\n\ndef sglang_aot_rmsnorm(\n    input: torch.Tensor,\n    weight: torch.Tensor,\n) -> None:\n    rmsnorm(input, weight, out=input)\n\n\ndef sglang_jit_rmsnorm(\n    input: torch.Tensor,\n    weight: torch.Tensor,\n) -> None:\n    jit_rmsnorm(input, weight, output=input)\n\n\ndef flashinfer_rmsnorm(\n    input: torch.Tensor,\n    weight: torch.Tensor,\n) -> None:\n    fi_rmsnorm(input, weight, out=input)\n\n\n@torch.compile()\ndef torch_impl_rmsnorm(\n    input: torch.Tensor,\n    weight: torch.Tensor,\n    eps: float = 1e-6,\n) -> None:\n    mean = input.float().pow(2).mean(dim=-1, keepdim=True)\n    norm = (mean + eps).rsqrt()\n    input.copy_(input.float() * norm * weight.float())\n\n\nBS_LIST = get_benchmark_range(\n    full_range=[2**n for n in range(0, 14)],\n    ci_range=[16],\n)\nHIDDEN_SIZE_LIST = get_benchmark_range(\n    full_range=[1536, 3072, 4096, 5120, 8192],\n    ci_range=[512, 2048],\n)\n\nLINE_VALS = [\"aot\", \"jit\", \"flashinfer\", \"torch\"]\nLINE_NAMES = [\"SGL AOT Kernel\", \"SGL JIT Kernel\", \"FlashInfer\", \"PyTorch\"]\nSTYLES = [(\"orange\", \"-\"), (\"blue\", \"--\"), (\"green\", \"-.\"), (\"red\", \":\")]\n\nconfigs = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"hidden_size\", \"batch_size\"],\n        x_vals=configs,\n        line_arg=\"provider\",\n        line_vals=LINE_VALS,\n        line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        plot_name=\"rmsnorm-performance\",\n        args={},\n    )\n)\ndef benchmark(hidden_size: int, batch_size: int, provider: 
str):\n    input = torch.randn(\n        (batch_size, hidden_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE\n    )\n    weight = torch.randn(hidden_size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE)\n    FN_MAP = {\n        \"aot\": sglang_aot_rmsnorm,\n        \"jit\": sglang_jit_rmsnorm,\n        \"flashinfer\": flashinfer_rmsnorm,\n        \"torch\": torch_impl_rmsnorm,\n    }\n    fn = lambda: FN_MAP[provider](input.clone(), weight)\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_rope.py",
    "content": "import itertools\n\nimport torch\nimport triton\nimport triton.testing\n\nfrom sglang.jit_kernel.benchmark.utils import (\n    DEFAULT_DEVICE,\n    DEFAULT_DTYPE,\n    get_benchmark_range,\n    run_benchmark,\n)\n\nMAX_SEQ_LEN = 131072\nROPE_BASE = 10000.0\nROPE_DIM = 128\nCACHE_SIZE = 1024 * 1024\n\n\ndef create_cos_sin_cache(\n    rotary_dim: int = ROPE_DIM,\n    max_position: int = MAX_SEQ_LEN,\n    base: float = ROPE_BASE,\n) -> torch.Tensor:\n    inv_freq = 1.0 / (\n        base\n        ** (\n            torch.arange(0, rotary_dim, 2, dtype=torch.float32, device=DEFAULT_DEVICE)\n            / rotary_dim\n        )\n    )\n    t = torch.arange(max_position, dtype=torch.float32, device=DEFAULT_DEVICE)\n    freqs = torch.einsum(\"i,j->ij\", t, inv_freq)\n    cos = freqs.cos()\n    sin = freqs.sin()\n    return torch.cat((cos, sin), dim=-1)\n\n\n# Pre-build the cache once\nCOS_SIN_CACHE = create_cos_sin_cache()\n\n\n# ---------------------------------------------------------------------------\n# RoPE-only provider implementations\n# ---------------------------------------------------------------------------\n\n\ndef flashinfer_rope(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    positions: torch.Tensor,\n    is_neox: bool,\n) -> None:\n    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace\n\n    head_size = q.shape[-1]\n    apply_rope_with_cos_sin_cache_inplace(\n        positions=positions,\n        query=q.view(q.shape[0], -1),\n        key=k.view(k.shape[0], -1),\n        head_size=head_size,\n        cos_sin_cache=COS_SIN_CACHE,\n        is_neox=is_neox,\n    )\n\n\ndef sglang_pos_enc_rope(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    positions: torch.Tensor,\n    is_neox: bool,\n) -> None:\n    from sglang.jit_kernel.rope import rotary_embedding_with_key\n\n    head_size = q.shape[-1]\n    rotary_embedding_with_key(\n        positions=positions,\n        query=q.view(q.shape[0], -1),\n        key=k.view(k.shape[0], 
-1),\n        head_size=head_size,\n        cos_sin_cache=COS_SIN_CACHE,\n        is_neox=is_neox,\n    )\n\n\ndef sglang_fused_rope(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    positions: torch.Tensor,\n    is_neox: bool,\n) -> None:\n    from sglang.jit_kernel.rope import apply_rope_inplace\n\n    apply_rope_inplace(q, k, COS_SIN_CACHE, positions, is_neox=is_neox)\n\n\n# ---------------------------------------------------------------------------\n# RoPE + KV cache store provider implementations\n# ---------------------------------------------------------------------------\n\n\ndef jit_rope_then_store(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    k_cache: torch.Tensor,\n    v_cache: torch.Tensor,\n    positions: torch.Tensor,\n    out_loc: torch.Tensor,\n    is_neox: bool,\n) -> None:\n    from sglang.jit_kernel.kvcache import store_cache\n    from sglang.jit_kernel.rope import apply_rope_inplace\n\n    head_size = q.shape[-1]\n    row_dim = k.shape[-2] * head_size\n    apply_rope_inplace(\n        positions=positions,\n        q=q,\n        k=k,\n        rope_dim=head_size,\n        cos_sin_cache=COS_SIN_CACHE,\n        is_neox=is_neox,\n    )\n    store_cache(\n        k.view(-1, row_dim),\n        v.view(-1, row_dim),\n        k_cache,\n        v_cache,\n        out_loc,\n    )\n\n\ndef jit_fused_rope_store(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    k_cache: torch.Tensor,\n    v_cache: torch.Tensor,\n    positions: torch.Tensor,\n    out_loc: torch.Tensor,\n    is_neox: bool,\n) -> None:\n    from sglang.jit_kernel.rope import apply_rope_inplace_with_kvcache\n\n    apply_rope_inplace_with_kvcache(\n        q, k, v, k_cache, v_cache, COS_SIN_CACHE, positions, out_loc, is_neox=is_neox\n    )\n\n\n# ---------------------------------------------------------------------------\n# Benchmark configuration (shared)\n# ---------------------------------------------------------------------------\n\nBS_RANGE = 
get_benchmark_range(\n    full_range=[2**n for n in range(0, 16)],\n    ci_range=[16],\n)\nQK_HEAD_RANGE = get_benchmark_range(\n    full_range=[(8, 1), (16, 2), (32, 8)],\n    ci_range=[(16, 2)],\n)\nQK_HEAD_RANGE = [f\"{q},{k}\" for q, k in QK_HEAD_RANGE]\nIS_NEOX_RANGE = get_benchmark_range(\n    full_range=[True, False],\n    ci_range=[True],\n)\n\n\n# ---------------------------------------------------------------------------\n# Benchmark 1: RoPE only\n# ---------------------------------------------------------------------------\n\nROPE_LINE_VALS = [\"flashinfer\", \"jit_pos_enc\", \"jit_fused_rope\"]\nROPE_LINE_NAMES = [\n    \"FlashInfer\",\n    \"SGL JIT PosEnc\",\n    \"SGL JIT Fused RoPE\",\n]\nROPE_STYLES = [(\"green\", \"-.\"), (\"red\", \"-\"), (\"blue\", \"--\")]\n\nrope_configs = list(itertools.product(QK_HEAD_RANGE, IS_NEOX_RANGE, BS_RANGE))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"num_q_k_heads\", \"is_neox\", \"batch_size\"],\n        x_vals=rope_configs,\n        line_arg=\"provider\",\n        line_vals=ROPE_LINE_VALS,\n        line_names=ROPE_LINE_NAMES,\n        styles=ROPE_STYLES,\n        ylabel=\"us\",\n        plot_name=\"rope-performance\",\n        args={},\n    )\n)\ndef benchmark(batch_size: int, num_q_k_heads: str, is_neox: bool, provider: str):\n    qo, kv = num_q_k_heads.split(\",\")\n    num_qo_heads = int(qo)\n    num_kv_heads = int(kv)\n    q = torch.randn(\n        (batch_size, num_qo_heads, ROPE_DIM),\n        dtype=DEFAULT_DTYPE,\n        device=DEFAULT_DEVICE,\n    )\n    k = torch.randn(\n        (batch_size, num_kv_heads, ROPE_DIM),\n        dtype=DEFAULT_DTYPE,\n        device=DEFAULT_DEVICE,\n    )\n    seed = batch_size << 16 | num_qo_heads << 8 | num_kv_heads << 4 | is_neox\n    torch.random.manual_seed(seed)\n    positions = torch.randint(\n        MAX_SEQ_LEN, (batch_size,), device=DEFAULT_DEVICE, dtype=torch.int64\n    )\n    torch.cuda.synchronize()\n\n    FN_MAP = {\n     
   \"flashinfer\": flashinfer_rope,\n        \"jit_pos_enc\": sglang_pos_enc_rope,\n        \"jit_fused_rope\": sglang_fused_rope,\n    }\n    fn = lambda: FN_MAP[provider](q, k, positions, is_neox)\n    return run_benchmark(fn)\n\n\n# ---------------------------------------------------------------------------\n# Benchmark 2: RoPE + KV cache store\n# ---------------------------------------------------------------------------\n\nSTORE_LINE_VALS = [\"jit_rope_then_store\", \"jit_fused_store\"]\nSTORE_LINE_NAMES = [\n    \"SGL JIT RoPE + Store\",\n    \"SGL JIT Fused RoPE + Store\",\n]\nSTORE_STYLES = [(\"red\", \"-\"), (\"blue\", \"--\")]\n\nstore_configs = list(itertools.product(QK_HEAD_RANGE, IS_NEOX_RANGE, BS_RANGE))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"num_q_k_heads\", \"is_neox\", \"batch_size\"],\n        x_vals=store_configs,\n        line_arg=\"provider\",\n        line_vals=STORE_LINE_VALS,\n        line_names=STORE_LINE_NAMES,\n        styles=STORE_STYLES,\n        ylabel=\"us\",\n        plot_name=\"rope-store-performance\",\n        args={},\n    )\n)\ndef benchmark_store(batch_size: int, num_q_k_heads: str, is_neox: bool, provider: str):\n    qo, kv = num_q_k_heads.split(\",\")\n    num_qo_heads = int(qo)\n    num_kv_heads = int(kv)\n    q = torch.randn(\n        (batch_size, num_qo_heads, ROPE_DIM),\n        dtype=DEFAULT_DTYPE,\n        device=DEFAULT_DEVICE,\n    )\n    k = torch.randn(\n        (batch_size, num_kv_heads, ROPE_DIM),\n        dtype=DEFAULT_DTYPE,\n        device=DEFAULT_DEVICE,\n    )\n    v = torch.randn(\n        (batch_size, num_kv_heads, ROPE_DIM),\n        dtype=DEFAULT_DTYPE,\n        device=DEFAULT_DEVICE,\n    )\n    row_size = num_kv_heads * ROPE_DIM\n    k_cache = torch.zeros(\n        CACHE_SIZE, row_size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE\n    )\n    v_cache = torch.zeros(\n        CACHE_SIZE, row_size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE\n    )\n    out_loc = 
torch.randperm(CACHE_SIZE, device=DEFAULT_DEVICE, dtype=torch.int64)[\n        :batch_size\n    ]\n    seed = batch_size << 16 | num_qo_heads << 8 | num_kv_heads << 4 | is_neox\n    torch.random.manual_seed(seed)\n    positions = torch.randint(\n        MAX_SEQ_LEN, (batch_size,), device=DEFAULT_DEVICE, dtype=torch.int64\n    )\n    torch.cuda.synchronize()\n\n    FN_MAP = {\n        \"jit_rope_then_store\": jit_rope_then_store,\n        \"jit_fused_store\": jit_fused_rope_store,\n    }\n    fn = lambda: FN_MAP[provider](\n        q, k, v, k_cache, v_cache, positions, out_loc, is_neox\n    )\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    print(\"Running RoPE performance benchmark...\")\n    benchmark.run(print_data=True)\n    print(\"\\nRunning RoPE + KV cache store performance benchmark...\")\n    benchmark_store.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/bench_store_cache.py",
    "content": "import itertools\nfrom typing import Tuple\n\nimport torch\nimport triton\nimport triton.testing\n\nfrom sglang.jit_kernel.benchmark.utils import (\n    DEFAULT_DEVICE,\n    DEFAULT_DTYPE,\n    DEFAULT_QUANTILES,\n    get_benchmark_range,\n)\nfrom sglang.jit_kernel.kvcache import store_cache\n\n\ndef sglang_jit_store_cache(\n    k: torch.Tensor,\n    v: torch.Tensor,\n    k_cache: torch.Tensor,\n    v_cache: torch.Tensor,\n    indices: torch.Tensor,\n) -> None:\n    store_cache(k, v, k_cache, v_cache, indices)\n\n\n@torch.compile()\ndef torch_compile_store_cache(\n    k: torch.Tensor,\n    v: torch.Tensor,\n    k_cache: torch.Tensor,\n    v_cache: torch.Tensor,\n    indices: torch.Tensor,\n) -> None:\n    k_cache[indices] = k\n    v_cache[indices] = v\n\n\nalt_stream = torch.cuda.Stream()\n\n\ndef torch_streams_store_cache(\n    k: torch.Tensor,\n    v: torch.Tensor,\n    k_cache: torch.Tensor,\n    v_cache: torch.Tensor,\n    indices: torch.Tensor,\n) -> None:\n    current_stream = torch.cuda.current_stream()\n    alt_stream.wait_stream(current_stream)\n    k_cache[indices] = k\n    with torch.cuda.stream(alt_stream):\n        v_cache[indices] = v\n    current_stream.wait_stream(alt_stream)\n\n\nNUM_LAYERS = 8\nCACHE_SIZE = 2 * 1024 * 1024 // NUM_LAYERS\n\nBS_RANGE = get_benchmark_range(\n    full_range=[2**n for n in range(0, 15)],\n    ci_range=[16],\n)\nITEM_SIZE = get_benchmark_range(\n    full_range=[64, 128, 256, 512, 1024],\n    ci_range=[1024],\n)\n\nLINE_VALS = [\"jit\", \"torch_compile\", \"torch_streams\"]\nLINE_NAMES = [\"SGL JIT Kernel\", \"PyTorch Compile\", \"PyTorch 2 Stream\"]\nSTYLES = [(\"blue\", \"--\"), (\"red\", \":\"), (\"green\", \"-.\")]\nX_NAMES = [\"item_size\", \"batch_size\"]\nCONFIGS = list(itertools.product(ITEM_SIZE, BS_RANGE))\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=X_NAMES,\n        x_vals=CONFIGS,\n        line_arg=\"provider\",\n        line_vals=LINE_VALS,\n        
line_names=LINE_NAMES,\n        styles=STYLES,\n        ylabel=\"us\",\n        plot_name=\"store-kvcache-performance\",\n        args={},\n    )\n)\ndef benchmark(\n    batch_size: int, item_size: int, provider: str\n) -> Tuple[float, float, float]:\n    k = torch.randn(\n        (NUM_LAYERS, batch_size, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE\n    )\n    v = torch.randn(\n        (NUM_LAYERS, batch_size, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE\n    )\n    k_cache = torch.randn(\n        (NUM_LAYERS, CACHE_SIZE, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE\n    )\n    v_cache = torch.randn(\n        (NUM_LAYERS, CACHE_SIZE, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE\n    )\n    indices = torch.randperm(CACHE_SIZE, device=DEFAULT_DEVICE)[:batch_size]\n    torch.cuda.synchronize()\n\n    FN_MAP = {\n        \"jit\": sglang_jit_store_cache,\n        \"torch_compile\": torch_compile_store_cache,\n        \"torch_streams\": torch_streams_store_cache,\n    }\n\n    def fn():\n        impl = FN_MAP[provider]\n        for i in range(NUM_LAYERS):\n            impl(k[i], v[i], k_cache[i], v_cache[i], indices)\n\n    # Custom time calculation: divide by NUM_LAYERS\n    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(\n        fn, quantiles=DEFAULT_QUANTILES\n    )\n    return (\n        1000 * ms / NUM_LAYERS,\n        1000 * max_ms / NUM_LAYERS,\n        1000 * min_ms / NUM_LAYERS,\n    )\n\n\nif __name__ == \"__main__\":\n    benchmark.run(print_data=True)\n"
  },
  {
    "path": "python/sglang/jit_kernel/benchmark/utils.py",
    "content": "\"\"\"Common utilities for jit_kernel benchmark files.\"\"\"\n\nfrom typing import Callable, List, Tuple\n\nimport torch\nimport triton.testing\n\nfrom sglang.jit_kernel.utils import is_in_ci as jit_kernel_is_in_ci\n\n# Common constants\nDEFAULT_DTYPE = torch.bfloat16\nDEFAULT_DEVICE = \"cuda\"\nDEFAULT_QUANTILES = [0.5, 0.2, 0.8]\n\n\ndef is_in_ci() -> bool:\n    \"\"\"Check if running in CI environment.\"\"\"\n    return jit_kernel_is_in_ci()\n\n\ndef get_benchmark_range(full_range: List, ci_range: List) -> List:\n    \"\"\"Return appropriate benchmark range based on CI environment.\"\"\"\n    return ci_range if is_in_ci() else full_range\n\n\ndef run_benchmark(\n    fn: Callable, quantiles: List[float] = None\n) -> Tuple[float, float, float]:\n    \"\"\"Execute benchmark using CUDA graph and return times in microseconds.\n\n    Args:\n        fn: Function to benchmark\n        quantiles: Quantiles for timing measurements [median, min, max]\n\n    Returns:\n        Tuple of (median_us, max_us, min_us)\n    \"\"\"\n    quantiles = quantiles or DEFAULT_QUANTILES\n    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)\n    return 1000 * ms, 1000 * max_ms, 1000 * min_ms\n\n\ndef run_benchmark_no_cudagraph(\n    fn: Callable, quantiles: List[float] = None\n) -> Tuple[float, float, float]:\n    quantiles = quantiles or DEFAULT_QUANTILES\n    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)\n    return 1000 * ms, 1000 * max_ms, 1000 * min_ms\n"
  },
  {
    "path": "python/sglang/jit_kernel/concat_mla.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_concat_mla_k_module() -> Module:\n    return load_jit(\n        \"concat_mla_k\",\n        cuda_files=[\"elementwise/concat_mla.cuh\"],\n        cuda_wrappers=[(\"concat_mla_k\", \"ConcatMlaKKernel::run\")],\n    )\n\n\n@cache_once\ndef _jit_concat_mla_absorb_q_module() -> Module:\n    return load_jit(\n        \"concat_mla_absorb_q\",\n        cuda_files=[\"elementwise/concat_mla.cuh\"],\n        cuda_wrappers=[(\"concat_mla_absorb_q\", \"ConcatMlaAbsorbQKernel::run\")],\n    )\n\n\ndef concat_mla_k(k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor) -> None:\n    \"\"\"\n    Concatenate k_nope and k_rope into k for MLA (Multi-head Latent Attention).\n\n    This kernel efficiently broadcasts k_rope across all heads while copying\n    k_nope values directly.\n\n    Args:\n        k: Output tensor of shape [num_tokens, num_heads=128, k_head_dim=192], dtype=bfloat16\n        k_nope: Input tensor of shape [num_tokens, num_heads=128, nope_head_dim=128], dtype=bfloat16\n        k_rope: Input tensor of shape [num_tokens, 1, rope_head_dim=64], dtype=bfloat16\n    \"\"\"\n    module = _jit_concat_mla_k_module()\n    module.concat_mla_k(k, k_nope, k_rope)\n\n\ndef concat_mla_absorb_q(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Concatenate tensors a and b for MLA absorbed Q computation.\n\n    Args:\n        a: Input tensor of shape [dim_0, dim_1, a_last_dim], dtype=bfloat16\n        b: Input tensor of shape [dim_0, dim_1, b_last_dim], dtype=bfloat16\n\n    Returns:\n        Output tensor of shape [dim_0, dim_1, a_last_dim + b_last_dim], dtype=bfloat16\n    \"\"\"\n    out = torch.empty(\n        (*a.shape[:-1], a.shape[-1] + b.shape[-1]),\n        dtype=a.dtype,\n        device=a.device,\n    )\n 
   module = _jit_concat_mla_absorb_q_module()\n    module.concat_mla_absorb_q(a, b, out)\n    return out\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/add_constant.cuh",
    "content": "#include <sgl_kernel/tensor.h>  // For TensorMatcher, SymbolicSize, SymbolicDevice\n#include <sgl_kernel/utils.h>   // For div_ceil, RuntimeCheck\n\n#include <sgl_kernel/utils.cuh>  // For LaunchKernel\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/container/tensor.h>\n\n#include <cstddef>\n#include <cstdint>\n\nnamespace {\n\ntemplate <int32_t kConstant>\n__global__ void add_constant_kernel(int32_t* dst, const int32_t* src, size_t length) {\n  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx < length) {\n    dst[idx] = src[idx] + kConstant;\n  }\n}\n\nconstexpr size_t kBlockSize = 256;\n\n// You can also use struct with static method as an alternative\ntemplate <int32_t kConstant>\nvoid add_constant(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) {\n  using namespace host;\n\n  // 1. Validate input tensors\n  SymbolicSize N = {\"num_elements\"};\n  SymbolicDevice device_;\n  TensorMatcher({N})                  // 1D tensor, must be contiguous\n      .with_dtype<int32_t>()          // must be int32\n      .with_device<kDLCUDA>(device_)  // must be on CUDA device\n      .verify(dst)                    // check tensor dst\n      .verify(src);                   // check tensor src\n\n  // 2. Extract required parameters, prepare for kernel launch\n  const size_t num_elements = N.unwrap();\n  const size_t grid_size = div_ceil(num_elements, kBlockSize);\n  const DLDevice device = device_.unwrap();\n  [[maybe_unused]]  // optional, can be omitted\n  const size_t dynamic_smem = 0;\n  [[maybe_unused]]  // optional, LaunchKernel can auto determine stream from device\n  const cudaStream_t stream = LaunchKernel::resolve_device(device);\n  // some extra runtime checks using host::RuntimeCheck\n  RuntimeCheck(num_elements > 0, \"We only support non-empty tensors, got num_elements = \", num_elements);\n\n  // 3. Launch the kernel. 
Error code will be automatically checked.\n  LaunchKernel(grid_size, kBlockSize, device /*, dynamic_smem*/)(\n      // kernel function\n      add_constant_kernel<kConstant>,\n      // kernel arguments\n      static_cast<int32_t*>(dst.data_ptr()),\n      static_cast<const int32_t*>(src.data_ptr()),\n      num_elements);\n}\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/diffusion/timestep_embedding.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/math.cuh>\n#include <sgl_kernel/type.cuh>\n#include <sgl_kernel/utils.cuh>\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/container/tensor.h>\n\n#include <algorithm>\n#include <cmath>\n#include <cstdint>\n#include <cuda_runtime.h>\n#include <type_traits>\n\nnamespace {\n\ntemplate <bool kFlipSinToCos, typename TIn>\n__global__ void timestep_embedding_kernel(\n    const TIn* __restrict__ t_ptr,\n    float* __restrict__ output_ptr,\n    int dim,\n    float neg_log_max_period,\n    float scale,\n    int batch_size) {\n  int row_idx = static_cast<int>(blockIdx.x * blockDim.y + threadIdx.y);\n  if (row_idx >= batch_size) {\n    return;\n  }\n\n  float t_val = device::cast<float>(t_ptr[row_idx]);\n  float* output_batch_base_ptr = output_ptr + row_idx * dim;\n\n  int half_dim = dim / 2;\n  int thread_offset = static_cast<int>(threadIdx.x);\n  while (thread_offset * 4 < half_dim) {\n    float4* top_half;\n    float4* bottom_half;\n    if constexpr (!kFlipSinToCos) {\n      bottom_half = reinterpret_cast<float4*>(output_batch_base_ptr + thread_offset * 4);\n      top_half = reinterpret_cast<float4*>(output_batch_base_ptr + half_dim + thread_offset * 4);\n    } else {\n      top_half = reinterpret_cast<float4*>(output_batch_base_ptr + thread_offset * 4);\n      bottom_half = reinterpret_cast<float4*>(output_batch_base_ptr + half_dim + thread_offset * 4);\n    }\n\n    float4 vals;\n    vals.x = scale * t_val * device::math::exp(neg_log_max_period * __int2float_rn(thread_offset * 4 + 0));\n    vals.y = scale * t_val * device::math::exp(neg_log_max_period * __int2float_rn(thread_offset * 4 + 1));\n    vals.z = scale * t_val * device::math::exp(neg_log_max_period * __int2float_rn(thread_offset * 4 + 2));\n    vals.w = scale * t_val * device::math::exp(neg_log_max_period * __int2float_rn(thread_offset * 4 + 3));\n\n    float4 cos_vals;\n    cos_vals.x = 
device::math::cos(vals.x);\n    cos_vals.y = device::math::cos(vals.y);\n    cos_vals.z = device::math::cos(vals.z);\n    cos_vals.w = device::math::cos(vals.w);\n    *top_half = cos_vals;\n\n    float4 sin_vals;\n    sin_vals.x = device::math::sin(vals.x);\n    sin_vals.y = device::math::sin(vals.y);\n    sin_vals.z = device::math::sin(vals.z);\n    sin_vals.w = device::math::sin(vals.w);\n    *bottom_half = sin_vals;\n\n    thread_offset += static_cast<int>(blockDim.x);\n  }\n}\n\ntemplate <typename TIn>\ninline void launch_timestep_embedding(\n    const tvm::ffi::TensorView t,\n    const tvm::ffi::TensorView output,\n    int dim,\n    bool flip_sin_to_cos,\n    float downscale_freq_shift,\n    float scale,\n    int max_period) {\n  using namespace host;\n\n  const int batch_size = static_cast<int>(t.shape()[0]);\n  const int half_dim = dim / 2;\n\n  constexpr int kMaxThreadsPerBlock = 1024;\n  constexpr int kMinThreadsPerBlock = 128;\n\n  const int num_threads_per_row = std::min(kMaxThreadsPerBlock, half_dim / 4);\n  const int num_rows = (kMinThreadsPerBlock + num_threads_per_row - 1) / num_threads_per_row;\n\n  dim3 grid((batch_size + num_rows - 1) / num_rows);\n  dim3 block(num_threads_per_row, num_rows);\n\n  const float neg_log_max_period =\n      std::log(static_cast<float>(max_period)) * (-1.0f) / (static_cast<float>(half_dim) - downscale_freq_shift);\n\n  const DLDevice device = output.device();\n\n  if (flip_sin_to_cos) {\n    LaunchKernel(grid, block, device)(\n        timestep_embedding_kernel<true, TIn>,\n        static_cast<const TIn*>(t.data_ptr()),\n        static_cast<float*>(output.data_ptr()),\n        dim,\n        neg_log_max_period,\n        scale,\n        batch_size);\n  } else {\n    LaunchKernel(grid, block, device)(\n        timestep_embedding_kernel<false, TIn>,\n        static_cast<const TIn*>(t.data_ptr()),\n        static_cast<float*>(output.data_ptr()),\n        dim,\n        neg_log_max_period,\n        scale,\n        
batch_size);\n  }\n}\n\ntemplate <typename TIn>\nvoid timestep_embedding(\n    tvm::ffi::TensorView input,\n    tvm::ffi::TensorView output,\n    int dim,\n    bool flip_sin_to_cos,\n    float downscale_freq_shift,\n    float scale,\n    int max_period) {\n  using namespace host;\n\n  auto B = SymbolicSize{\"batch_size\"};\n  auto D = SymbolicSize{\"dim\"};\n  auto device = SymbolicDevice{};\n\n  TensorMatcher({B})  // input\n      .with_strides({1})\n      .with_dtype<TIn>()\n      .template with_device<kDLCUDA>(device)\n      .verify(input);\n\n  TensorMatcher({B, D}).with_strides({D, 1}).with_dtype<float>().template with_device<kDLCUDA>(device).verify(output);\n\n  RuntimeCheck(D.unwrap() == dim, \"Output dim mismatch: \", D.unwrap(), \" vs \", dim);\n  // dim == 0 would make num_threads_per_row zero and divide by zero in the launcher\n  RuntimeCheck(dim > 0 && dim % 8 == 0, \"dim must be a positive multiple of 8, got \", dim);\n\n  launch_timestep_embedding<TIn>(input, output, dim, flip_sin_to_cos, downscale_freq_shift, scale, max_period);\n}\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/utils.cuh>\n\n#include <tvm/ffi/container/tensor.h>\n\n#include <cuda_bf16.h>\n#include <cuda_runtime.h>\n\nnamespace {\n\n// ======================= Memory Utilities =======================\n// Adapted from DeepEP: https://github.com/deepseek-ai/DeepEP/blob/main/csrc/kernels/utils.cuh\n\nSGL_DEVICE int get_lane_id() {\n  int lane_id;\n  asm(\"mov.s32 %0, %laneid;\" : \"=r\"(lane_id));\n  return lane_id;\n}\n\nSGL_DEVICE void st_na_global_v1(const int* ptr, int v) {\n  asm volatile(\"st.global.L1::no_allocate.s32 [%0], %1;\" ::\"l\"(ptr), \"r\"(v) : \"memory\");\n}\n\nSGL_DEVICE void st_na_global_v2(const int2* ptr, const int2& v) {\n  asm volatile(\"st.global.L1::no_allocate.v2.s32 [%0], {%1, %2};\" ::\"l\"(ptr), \"r\"(v.x), \"r\"(v.y) : \"memory\");\n}\n\nSGL_DEVICE int ld_na_global_v1(const int* ptr) {\n  int r;\n  asm volatile(\"ld.global.nc.L1::no_allocate.s32 %0, [%1];\" : \"=r\"(r) : \"l\"(ptr));\n  return r;\n}\n\nSGL_DEVICE int2 ld_na_global_v2(const int2* ptr) {\n  int2 r;\n  asm volatile(\"ld.global.nc.L1::no_allocate.v2.s32 {%0, %1}, [%2];\" : \"=r\"(r.x), \"=r\"(r.y) : \"l\"(ptr));\n  return r;\n}\n\nSGL_DEVICE void prefetch_L2(const void* p) {\n#if defined(ENABLE_L2_PREFETCH)\n  asm volatile(\"prefetch.global.L2 [%0];\" ::\"l\"(p));\n#endif\n}\n\n// ======================= concat_mla_k Kernel =======================\n\nconstexpr int NUM_LOCAL_HEADS = 128;\nconstexpr int QK_NOPE_HEAD_DIM = 128;\nconstexpr int QK_ROPE_HEAD_DIM = 64;\nconstexpr int K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM;\n\nconstexpr int HEAD_CHUNK_SIZE = 16;\nconstexpr int NUM_HEAD_CHUNKS = NUM_LOCAL_HEADS / HEAD_CHUNK_SIZE;\n\n__global__ void concat_mla_k_kernel(\n    bf16_t* __restrict__ k,\n    const bf16_t* __restrict__ k_nope,\n    const bf16_t* __restrict__ k_rope,\n    const int num_tokens,\n    const int64_t k_stride_0,\n    const int k_stride_1,\n    const int64_t 
k_nope_stride_0,\n    const int k_nope_stride_1,\n    const int64_t k_rope_stride_0) {\n  const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;\n  const int token_id = flat_warp_id / NUM_HEAD_CHUNKS;\n  const int head_chunk_id = flat_warp_id % NUM_HEAD_CHUNKS;\n  const int lane_id = get_lane_id();\n  if (token_id >= num_tokens) return;\n\n  using NopeVec = int2;  // 8B/thread, 32 threads = 256B/row\n  using RopeVec = int;   // 4B/thread, 32 threads = 128B/row\n  static_assert(sizeof(NopeVec) * 32 == QK_NOPE_HEAD_DIM * sizeof(bf16_t), \"nope vec mismatch\");\n  static_assert(sizeof(RopeVec) * 32 == QK_ROPE_HEAD_DIM * sizeof(bf16_t), \"rope vec mismatch\");\n\n  const int head_row0 = head_chunk_id * HEAD_CHUNK_SIZE;\n\n  const int2* __restrict__ nope_src =\n      reinterpret_cast<const int2*>(k_nope + token_id * k_nope_stride_0 + head_row0 * k_nope_stride_1) + lane_id;\n\n  int2* __restrict__ nope_dst = reinterpret_cast<int2*>(k + token_id * k_stride_0 + head_row0 * k_stride_1) + lane_id;\n\n  int* __restrict__ rope_dst =\n      reinterpret_cast<int*>(k + token_id * k_stride_0 + head_row0 * k_stride_1 + QK_NOPE_HEAD_DIM) + lane_id;\n\n  const int nope_src_stride_v = (k_nope_stride_1 >> 2);  // int2 covers 4 bf16\n  const int nope_dst_stride_v = (k_stride_1 >> 2);\n  const int rope_dst_stride_v = (k_stride_1 >> 1);  // int covers 2 bf16\n\n  const int* rope_base = reinterpret_cast<const int*>(k_rope + token_id * k_rope_stride_0);\n  const RopeVec rope_val = ld_na_global_v1(rope_base + lane_id);\n\n  prefetch_L2(nope_src);\n  NopeVec cur = ld_na_global_v2(nope_src);\n\n#pragma unroll\n  for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {\n    NopeVec next;\n    if (i + 1 < HEAD_CHUNK_SIZE) {\n      const int2* next_src = nope_src + nope_src_stride_v;\n      prefetch_L2(next_src);\n      next = ld_na_global_v2(next_src);\n    }\n\n    st_na_global_v2(nope_dst, cur);\n    st_na_global_v1(rope_dst, rope_val);\n\n    nope_src += nope_src_stride_v;\n    nope_dst += 
nope_dst_stride_v;\n    rope_dst += rope_dst_stride_v;\n\n    cur = next;\n  }\n}\n\nstruct ConcatMlaKKernel {\n  static void run(tvm::ffi::TensorView k, tvm::ffi::TensorView k_nope, tvm::ffi::TensorView k_rope) {\n    using namespace host;\n\n    auto N = SymbolicSize{\"num_tokens\"};\n    auto H = SymbolicSize{\"num_heads\"};\n    auto D = SymbolicSize{\"k_head_dim\"};\n    auto D_nope = SymbolicSize{\"nope_head_dim\"};\n    auto D_rope = SymbolicSize{\"rope_head_dim\"};\n    auto S0_k = SymbolicSize{\"k_stride_0\"};\n    auto S1_k = SymbolicSize{\"k_stride_1\"};\n    auto S0_k_nope = SymbolicSize{\"k_nope_stride_0\"};\n    auto S1_k_nope = SymbolicSize{\"k_nope_stride_1\"};\n    auto S0_k_rope = SymbolicSize{\"k_rope_stride_0\"};\n    auto device = SymbolicDevice{};\n\n    // Set known fixed values\n    H.set_value(NUM_LOCAL_HEADS);\n    D.set_value(K_HEAD_DIM);\n    D_nope.set_value(QK_NOPE_HEAD_DIM);\n    D_rope.set_value(QK_ROPE_HEAD_DIM);\n\n    // Verify k: [num_tokens, num_heads, k_head_dim]\n    TensorMatcher({N, H, D}).with_strides({S0_k, S1_k, 1}).with_dtype<bf16_t>().with_device<kDLCUDA>(device).verify(k);\n\n    // Verify k_nope: [num_tokens, num_heads, nope_head_dim]\n    TensorMatcher({N, H, D_nope})\n        .with_strides({S0_k_nope, S1_k_nope, 1})\n        .with_dtype<bf16_t>()\n        .with_device<kDLCUDA>(device)\n        .verify(k_nope);\n\n    // Verify k_rope: [num_tokens, 1, rope_head_dim]\n    TensorMatcher({N, 1, D_rope})\n        .with_strides({S0_k_rope, -1, 1})\n        .with_dtype<bf16_t>()\n        .with_device<kDLCUDA>(device)\n        .verify(k_rope);\n\n    // Check alignment\n    RuntimeCheck(reinterpret_cast<uintptr_t>(k.data_ptr()) % 16 == 0, \"Tensor k must be 16-byte aligned\");\n    RuntimeCheck(reinterpret_cast<uintptr_t>(k_nope.data_ptr()) % 16 == 0, \"Tensor k_nope must be 16-byte aligned\");\n    RuntimeCheck(reinterpret_cast<uintptr_t>(k_rope.data_ptr()) % 16 == 0, \"Tensor k_rope must be 16-byte aligned\");\n\n    
const int num_tokens = static_cast<int>(N.unwrap());\n\n    constexpr int num_warps_per_block = 32;\n    const int grid_size = div_ceil(num_tokens * NUM_HEAD_CHUNKS, num_warps_per_block);\n    const int block_size = num_warps_per_block * 32;\n\n    LaunchKernel(grid_size, block_size, device.unwrap())(\n        concat_mla_k_kernel,\n        static_cast<bf16_t*>(k.data_ptr()),\n        static_cast<const bf16_t*>(k_nope.data_ptr()),\n        static_cast<const bf16_t*>(k_rope.data_ptr()),\n        num_tokens,\n        S0_k.unwrap(),\n        static_cast<int>(S1_k.unwrap()),\n        S0_k_nope.unwrap(),\n        static_cast<int>(S1_k_nope.unwrap()),\n        S0_k_rope.unwrap());\n  }\n};\n\n// ======================= concat_mla_absorb_q Kernel =======================\n\nconstexpr int A_LAST_DIM = 512;\nconstexpr int B_LAST_DIM = 64;\nconstexpr int OUT_LAST_DIM = A_LAST_DIM + B_LAST_DIM;\n\n__global__ void concat_mla_absorb_q_kernel(\n    bf16_t* a,\n    bf16_t* b,\n    bf16_t* out,\n    const int num_items,\n    const int dim_1,\n    const int64_t a_stride_0,\n    const int a_stride_1,\n    const int64_t b_stride_0,\n    const int b_stride_1,\n    const int64_t out_stride_0,\n    const int out_stride_1) {\n  const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;\n  const int lane_id = get_lane_id();\n\n  const int idx_0 = flat_warp_id / dim_1;\n  const int idx_1 = flat_warp_id % dim_1;\n\n  if (flat_warp_id >= num_items) {\n    return;\n  }\n\n  using ABufType = int4;\n  constexpr int A_NUM_UNROLL = 2;\n  static_assert(sizeof(ABufType) * A_NUM_UNROLL == A_LAST_DIM * sizeof(a[0]) / 32);\n  ABufType a_buf[A_NUM_UNROLL];\n\n  using BBufType = int;\n  constexpr int B_NUM_UNROLL = 1;\n  static_assert(sizeof(BBufType) * B_NUM_UNROLL == B_LAST_DIM * sizeof(b[0]) / 32);\n  BBufType b_buf;\n\n  {\n    const BBufType* base_addr = reinterpret_cast<BBufType*>(b + idx_0 * b_stride_0 + idx_1 * b_stride_1);\n    b_buf = *(base_addr + lane_id);\n  }\n\n#pragma unroll\n  
for (int i = 0; i < A_NUM_UNROLL; ++i) {\n    const ABufType* base_addr = reinterpret_cast<ABufType*>(a + idx_0 * a_stride_0 + idx_1 * a_stride_1);\n    a_buf[i] = *(base_addr + i * 32 + lane_id);\n  }\n\n  {\n    BBufType* base_addr = reinterpret_cast<BBufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1 + A_LAST_DIM);\n    *(base_addr + lane_id) = b_buf;\n  }\n\n#pragma unroll\n  for (int i = 0; i < A_NUM_UNROLL; ++i) {\n    ABufType* base_addr = reinterpret_cast<ABufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1);\n    *(base_addr + i * 32 + lane_id) = a_buf[i];\n  }\n}\n\nstruct ConcatMlaAbsorbQKernel {\n  static void run(tvm::ffi::TensorView a, tvm::ffi::TensorView b, tvm::ffi::TensorView out) {\n    using namespace host;\n\n    auto N0_a = SymbolicSize{\"a_dim_0\"};\n    auto N1_a = SymbolicSize{\"a_dim_1\"};\n    auto D_a = SymbolicSize{\"a_last_dim\"};\n    auto N0_b = SymbolicSize{\"b_dim_0\"};\n    auto N1_b = SymbolicSize{\"b_dim_1\"};\n    auto D_b = SymbolicSize{\"b_last_dim\"};\n    auto N0_out = SymbolicSize{\"out_dim_0\"};\n    auto N1_out = SymbolicSize{\"out_dim_1\"};\n    auto D_out = SymbolicSize{\"out_last_dim\"};\n    auto S0_a = SymbolicSize{\"a_stride_0\"};\n    auto S1_a = SymbolicSize{\"a_stride_1\"};\n    auto S0_b = SymbolicSize{\"b_stride_0\"};\n    auto S1_b = SymbolicSize{\"b_stride_1\"};\n    auto S0_out = SymbolicSize{\"out_stride_0\"};\n    auto S1_out = SymbolicSize{\"out_stride_1\"};\n    auto device = SymbolicDevice{};\n\n    // Set known fixed values\n    D_a.set_value(A_LAST_DIM);\n    D_b.set_value(B_LAST_DIM);\n    D_out.set_value(OUT_LAST_DIM);\n\n    // Verify a: [dim_0, dim_1, A_LAST_DIM]\n    TensorMatcher({N0_a, N1_a, D_a})\n        .with_strides({S0_a, S1_a, 1})\n        .with_dtype<bf16_t>()\n        .with_device<kDLCUDA>(device)\n        .verify(a);\n\n    // Verify b: [dim_0, dim_1, B_LAST_DIM]\n    TensorMatcher({N0_b, N1_b, D_b})\n        .with_strides({S0_b, S1_b, 1})\n        
.with_dtype<bf16_t>()\n        .with_device<kDLCUDA>(device)\n        .verify(b);\n\n    // Verify out: [dim_0, dim_1, OUT_LAST_DIM]\n    TensorMatcher({N0_out, N1_out, D_out})\n        .with_strides({S0_out, S1_out, 1})\n        .with_dtype<bf16_t>()\n        .with_device<kDLCUDA>(device)\n        .verify(out);\n\n    // Check alignment\n    RuntimeCheck(reinterpret_cast<uintptr_t>(a.data_ptr()) % 16 == 0, \"Tensor a must be 16-byte aligned\");\n    RuntimeCheck(reinterpret_cast<uintptr_t>(b.data_ptr()) % 16 == 0, \"Tensor b must be 16-byte aligned\");\n    RuntimeCheck(reinterpret_cast<uintptr_t>(out.data_ptr()) % 16 == 0, \"Tensor out must be 16-byte aligned\");\n\n    // Verify leading dimensions match across a, b and out (the kernel indexes all three with a's (idx_0, idx_1))\n    RuntimeCheck(\n        N0_a.unwrap() * N1_a.unwrap() == N0_b.unwrap() * N1_b.unwrap(),\n        \"Dimension mismatch: a.size(0) * a.size(1) must equal b.size(0) * b.size(1)\");\n    RuntimeCheck(N1_a.unwrap() == N1_b.unwrap(), \"Dimension mismatch: a.size(1) must equal b.size(1)\");\n    RuntimeCheck(\n        N0_out.unwrap() == N0_a.unwrap() && N1_out.unwrap() == N1_a.unwrap(),\n        \"Dimension mismatch: out.size(0) and out.size(1) must equal a.size(0) and a.size(1)\");\n\n    const int num_items = static_cast<int>(N0_a.unwrap() * N1_a.unwrap());\n    const int dim_1 = static_cast<int>(N1_a.unwrap());\n\n    constexpr int num_warps_per_block = 32;\n    const int grid_size = div_ceil(num_items, num_warps_per_block);\n    const int block_size = num_warps_per_block * 32;\n\n    LaunchKernel(grid_size, block_size, device.unwrap())(\n        concat_mla_absorb_q_kernel,\n        static_cast<bf16_t*>(a.data_ptr()),\n        static_cast<bf16_t*>(b.data_ptr()),\n        static_cast<bf16_t*>(out.data_ptr()),\n        num_items,\n        dim_1,\n        S0_a.unwrap(),\n        static_cast<int>(S1_a.unwrap()),\n        S0_b.unwrap(),\n        static_cast<int>(S1_b.unwrap()),\n        S0_out.unwrap(),\n        static_cast<int>(S1_out.unwrap()));\n  }\n};\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/elementwise/fused_add_rmsnorm.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/runtime.cuh>\n#include <sgl_kernel/tile.cuh>\n#include <sgl_kernel/type.cuh>\n#include <sgl_kernel/utils.cuh>\n#include <sgl_kernel/vec.cuh>\n\n#include <cooperative_groups/reduce.h>\n#include <tvm/ffi/container/tensor.h>\n\n#include <cooperative_groups.h>\n#include <type_traits>\n\nnamespace {\n\ntemplate <typename T, int VEC_SIZE_IN_BYTE>\nstruct VecTypeTrait;\n\ntemplate <>\nstruct VecTypeTrait<bf16_t, 16> {\n  using packed_t = packed_t<bf16_t>;\n  using vec_t = device::AlignedVector<packed_t, 4>;\n};\n\ntemplate <>\nstruct VecTypeTrait<fp16_t, 16> {\n  using packed_t = packed_t<fp16_t>;\n  using vec_t = device::AlignedVector<packed_t, 4>;\n};\n\ntemplate <>\nstruct VecTypeTrait<bf16_t, 32> {\n  using packed_t = packed_t<bf16_t>;\n  using vec_t = device::AlignedVector<packed_t, 8>;\n};\n\ntemplate <>\nstruct VecTypeTrait<fp16_t, 32> {\n  using packed_t = packed_t<fp16_t>;\n  using vec_t = device::AlignedVector<packed_t, 8>;\n};\n\ntemplate <typename packed_t>\nSGL_DEVICE packed_t rms(packed_t& val, packed_t& weight, float rsqrt_square_sum) {\n  float2 valf = device::cast<fp32x2_t, packed_t>(val);\n  float2 weightf = device::cast<fp32x2_t, packed_t>(weight);\n  return device::cast<packed_t, fp32x2_t>(\n      make_float2(valf.x * weightf.x * rsqrt_square_sum, valf.y * weightf.y * rsqrt_square_sum));\n}\n\ntemplate <typename T, int VEC_SIZE_IN_BYTE>\n__global__ void fused_add_rmsnorm_reg_kernel(\n    T* __restrict__ input, T* __restrict__ residual, const T* __restrict__ weight, int vec_hidden_size, float eps) {\n  constexpr int inner_loop = VEC_SIZE_IN_BYTE == 16 ? 
4 : 8;\n\n  __shared__ float shared_memory[32];  // Used for CTA reduce\n\n  using vec_t = typename VecTypeTrait<T, VEC_SIZE_IN_BYTE>::vec_t;\n  using packed_t = typename VecTypeTrait<T, VEC_SIZE_IN_BYTE>::packed_t;\n  vec_t v;         // Save input\n  vec_t v_res;     // Save residual\n  vec_t v_weight;  // Save weight\n  vec_t v_out;     // Save output\n\n  auto token_id = blockIdx.x;\n  float2 acc_square = make_float2(0.0f, 0.0f);  // Sum of squares for each thread\n\n  if (threadIdx.x < vec_hidden_size) {\n    // Compute address\n    vec_t* p = reinterpret_cast<vec_t*>(input) + token_id * vec_hidden_size;\n    vec_t* p_res = reinterpret_cast<vec_t*>(residual) + token_id * vec_hidden_size;\n    const vec_t* p_weight = reinterpret_cast<const vec_t*>(weight);\n\n    // Load data\n    v = p[threadIdx.x];\n    v_res = p_res[threadIdx.x];\n    v_weight = p_weight[threadIdx.x];\n\n    for (int i = 0; i < inner_loop; i++) {\n      float2 val = device::cast<fp32x2_t, packed_t>(v[i]);\n      float2 res = device::cast<fp32x2_t, packed_t>(v_res[i]);\n      float2 inp_res = make_float2(val.x + res.x, val.y + res.y);\n      acc_square.x += inp_res.x * inp_res.x;\n      acc_square.y += inp_res.y * inp_res.y;\n      v[i] = device::cast<packed_t, fp32x2_t>(inp_res);\n    }\n\n    // Store inp+res to residual\n    p_res[threadIdx.x] = v;\n  }\n\n  // CTA Reduce\n  // Step 0: Warp Reduce\n  auto cg_warp = cooperative_groups::tiled_partition<32>(cooperative_groups::this_thread_block());\n  float warp_sum = cooperative_groups::reduce(cg_warp, acc_square.x + acc_square.y, cooperative_groups::plus<float>());\n\n  float* buffer = shared_memory;\n  if (threadIdx.x % 32 == 0) {\n    buffer[threadIdx.x / 32] = warp_sum;  // Write warp_sum to buffer\n  }\n\n  // Step 1: CTA Reduce\n  __syncthreads();\n  if (threadIdx.x < 32) {\n    float cta_sum = cooperative_groups::reduce(\n        cg_warp, (threadIdx.x < blockDim.x / 32) ? 
buffer[threadIdx.x] : 0.0f, cooperative_groups::plus<float>());\n    buffer[threadIdx.x] =\n        rsqrtf(eps + cta_sum * (1.0f / static_cast<float>(vec_hidden_size * (VEC_SIZE_IN_BYTE / sizeof(T)))));\n  }\n  __syncthreads();\n\n  // Compute RMSNorm\n  if (threadIdx.x < vec_hidden_size) {\n    float rsqrt_square_sum = buffer[threadIdx.x / 32];  // Read rsqrt from shared memory (broadcast)\n    for (int i = 0; i < inner_loop; i++) {\n      v_out[i] = rms(v[i], v_weight[i], rsqrt_square_sum);\n    }\n    vec_t* p_out = reinterpret_cast<vec_t*>(input) + token_id * vec_hidden_size;\n    p_out[threadIdx.x] = v_out;\n  }\n}\n\ntemplate <typename DType>\nstruct FusedAddRMSNormKernel {\n  static void\n  run(const tvm::ffi::TensorView input,\n      const tvm::ffi::TensorView residual,\n      const tvm::ffi::TensorView weight,\n      float eps) {\n    using namespace host;\n    auto N = SymbolicSize{\"num_tokens\"};\n    auto D = SymbolicSize{\"hidden_size\"};\n    auto device = SymbolicDevice{};\n    device.set_options<kDLCUDA>();\n\n    TensorMatcher({N, D})  // input\n        .with_strides({D, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(input);\n    TensorMatcher({D})  // weight\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(weight);\n    TensorMatcher({N, D})  // residual\n        .with_strides({D, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(residual);\n\n    auto cc_major = host::runtime::get_cc_major(device.unwrap().device_id);\n    int hidden_size = static_cast<int>(D.unwrap());\n    if ((cc_major <= 9 && hidden_size <= 8192) || (cc_major >= 10 && hidden_size <= 12288)) {\n      int max_vec_size_byte = cc_major >= 10 ? 
32 : 16;\n      int elements_in_vec = max_vec_size_byte / sizeof(DType);\n      int vec_hidden_size = hidden_size / elements_in_vec;\n      uint threads = (vec_hidden_size + 31) / 32 * 32;\n\n      // Runtime check\n      host::RuntimeCheck(\n          hidden_size % elements_in_vec == 0,\n          \"hidden_size \",\n          hidden_size,\n          \" is not divisible by elements_in_vec \",\n          elements_in_vec);\n\n      // Launch kernel\n      auto kernel =\n          max_vec_size_byte == 32 ? fused_add_rmsnorm_reg_kernel<DType, 32> : fused_add_rmsnorm_reg_kernel<DType, 16>;\n      LaunchKernel(static_cast<uint>(N.unwrap()), threads, device.unwrap())\n          .enable_pdl(false)(\n              kernel,\n              reinterpret_cast<DType*>(input.data_ptr()),\n              reinterpret_cast<DType*>(residual.data_ptr()),\n              reinterpret_cast<const DType*>(weight.data_ptr()),\n              vec_hidden_size,\n              eps);\n    } else {\n      host::Panic(\n          \"hidden_size \",\n          hidden_size,\n          \" is not supported: the limit is 8192 for compute capability <= 9 and 12288 for compute capability >= 10\");\n    }\n  }\n};\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/elementwise/fused_metadata_copy.cuh",
    "content": "/*\n * Fused metadata copy kernel for NSA backend CUDA graph replay.\n * JIT-compiled version for python/sglang/jit_kernel.\n *\n * OVERVIEW:\n * This kernel fuses multiple tensor copy operations (cache_seqlens, cu_seqlens_k,\n * page_table, nsa metadata, and optional FlashMLA metadata) into single kernel\n * launches, significantly reducing kernel launch overhead and improving CUDA\n * graph replay performance during inference.\n *\n * PERFORMANCE BENEFITS:\n * - Single kernel launch vs. multiple separate copies (3-10x faster)\n * - Optimized memory coalescing and SM utilization\n * - __grid_constant__ parameter passing via constant memory\n * - Especially beneficial in CUDA graph replay scenarios\n *\n * DESIGN:\n * - Unified kernel supporting all forward modes (DECODE, TARGET_VERIFY, DRAFT_EXTEND)\n * - Structured parameter passing (SourcePointers/DestinationPointers) for clarity\n * - Template parameters (HAS_REAL_PAGE_TABLE, HAS_FLASHMLA) for compile-time optimization\n * - Multi-backend variant copies to 3 destinations in one kernel (for speculative decoding)\n *\n * USAGE:\n * This header is included by the JIT compilation system. 
The FusedMetadataCopyKernel\n * and FusedMetadataCopyMultiKernel wrapper structs provide the Python-accessible interface.\n */\n\n#pragma once\n\n#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/utils.cuh>\n\n#include <tvm/ffi/container/tensor.h>\n\n#include <algorithm>  // for std::min\n#include <cuda_runtime.h>\n\n// Forward mode enum (must match Python ForwardMode in sglang/srt/layers/attention/nsa_backend.py)\nenum ForwardModeEnum { DECODE = 0, TARGET_VERIFY = 1, DRAFT_EXTEND = 2 };\n\n/**\n * Source pointers for metadata copy operations.\n * Groups all source tensor pointers for cleaner parameter passing.\n * Some pointers may be nullptr depending on forward mode and feature flags.\n */\nstruct SourcePointers {\n  const int32_t* __restrict__ cache_seqlens;        // [bs] sequence lengths in cache\n  const int32_t* __restrict__ cu_seqlens_k;         // [bs+1] cumulative sequence lengths\n  const int32_t* __restrict__ page_indices;         // page table indices\n  const int32_t* __restrict__ nsa_cache_seqlens;    // NSA-specific cache lengths\n  const int32_t* __restrict__ seqlens_expanded;     // expanded sequence lengths (TARGET_VERIFY/DRAFT_EXTEND only)\n  const int32_t* __restrict__ nsa_cu_seqlens_k;     // NSA cumulative sequence lengths\n  const int32_t* __restrict__ real_page_table;      // optional real page table\n  const int32_t* __restrict__ flashmla_num_splits;  // optional FlashMLA split counts\n  const int32_t* __restrict__ flashmla_metadata;    // optional FlashMLA metadata\n};\n\n/**\n * Destination pointers for metadata copy operations.\n * Groups all destination tensor pointers for cleaner parameter passing.\n * Layout matches SourcePointers for consistency.\n */\nstruct DestinationPointers {\n  int32_t* __restrict__ cache_seqlens;        // [bs] sequence lengths in cache\n  int32_t* __restrict__ cu_seqlens_k;         // [bs+1] cumulative sequence lengths\n  int32_t* __restrict__ page_table_1;         // 
page table (note: different name from source)\n  int32_t* __restrict__ nsa_cache_seqlens;    // NSA-specific cache lengths\n  int32_t* __restrict__ seqlens_expanded;     // expanded sequence lengths (TARGET_VERIFY/DRAFT_EXTEND only)\n  int32_t* __restrict__ nsa_cu_seqlens_k;     // NSA cumulative sequence lengths\n  int32_t* __restrict__ real_page_table;      // optional real page table\n  int32_t* __restrict__ flashmla_num_splits;  // optional FlashMLA split counts\n  int32_t* __restrict__ flashmla_metadata;    // optional FlashMLA metadata\n};\n\n/**\n * Parameter structure for single-backend fused metadata copy kernel.\n * Passed via __grid_constant__ for efficient constant memory access.\n */\nstruct FusedMetadataCopyParams {\n  SourcePointers src;       // Source tensor pointers\n  DestinationPointers dst;  // Destination tensor pointers\n\n  // Kernel parameters\n  int forward_mode;                // 0=DECODE, 1=TARGET_VERIFY, 2=DRAFT_EXTEND\n  int bs;                          // Batch size\n  int max_len;                     // Max length for DECODE mode\n  int max_seqlen_k;                // Max sequence length for TARGET_VERIFY/DRAFT_EXTEND\n  int seqlens_expanded_size;       // Size of expanded sequence lengths\n  int page_indices_rows;           // Number of rows in page_indices\n  int page_table_1_stride;         // Stride for page_table_1\n  int real_page_table_cols;        // Columns in real_page_table\n  int real_page_table_dst_stride;  // Stride for destination real_page_table\n  int flashmla_metadata_size;      // Size of FlashMLA metadata\n};\n\n/**\n * Parameter structure for multi-backend fused metadata copy kernel.\n * Enables copying from one source to three destinations in a single kernel launch.\n * Used for speculative decoding with multiple draft backends.\n */\nstruct FusedMetadataCopyMultiParams {\n  SourcePointers src;        // Source pointers (shared across all backends)\n  DestinationPointers dst0;  // Backend 0 destination 
pointers\n  DestinationPointers dst1;  // Backend 1 destination pointers\n  DestinationPointers dst2;  // Backend 2 destination pointers\n\n  // Kernel parameters\n  int bs;                          // Batch size\n  int max_len;                     // Max length (DECODE mode only)\n  int seqlens_expanded_size;       // Size of expanded sequence lengths\n  int page_table_1_stride;         // Stride for page_table_1\n  int real_page_table_cols;        // Columns in real_page_table\n  int real_page_table_dst_stride;  // Stride for destination real_page_table\n  int flashmla_metadata_size;      // Size of FlashMLA metadata\n};\n\n/**\n * Unified kernel for all forward modes (DECODE, TARGET_VERIFY, DRAFT_EXTEND).\n * Uses runtime branches for mode selection, with template parameters for\n * compile-time optimization of optional features.\n *\n * DESIGN:\n * - Runtime branches (forward_mode) handle mode-specific logic\n * - Template parameters (HAS_*) eliminate unused feature code at compile time\n * - Structured parameters (SourcePointers/DestinationPointers) passed via constant memory\n *\n * Used by FusedMetadataCopyKernel for single-backend metadata copy.\n *\n * @tparam HAS_REAL_PAGE_TABLE Compile-time flag for real_page_table support\n * @tparam HAS_FLASHMLA Compile-time flag for FlashMLA metadata support\n */\ntemplate <bool HAS_REAL_PAGE_TABLE, bool HAS_FLASHMLA>\n__global__ void fused_metadata_copy_kernel(const FusedMetadataCopyParams __grid_constant__ params) {\n  int tid = blockIdx.x * blockDim.x + threadIdx.x;\n  int total_threads = gridDim.x * blockDim.x;\n\n  // Unpack parameters for readability\n  const auto& src = params.src;\n  const auto& dst = params.dst;\n  const int forward_mode = params.forward_mode;\n  const int bs = params.bs;\n  const int max_len = params.max_len;\n  const int max_seqlen_k = params.max_seqlen_k;\n  const int seqlens_expanded_size = params.seqlens_expanded_size;\n  const int page_indices_rows = params.page_indices_rows;\n  const 
int page_table_1_stride = params.page_table_1_stride;\n  const int real_page_table_cols = params.real_page_table_cols;\n  const int real_page_table_dst_stride = params.real_page_table_dst_stride;\n  const int flashmla_metadata_size = params.flashmla_metadata_size;\n\n  // Copy cache_seqlens (bs elements) - common to all modes\n#pragma unroll 8\n  for (int i = tid; i < bs; i += total_threads) {\n    dst.cache_seqlens[i] = src.cache_seqlens[i];\n  }\n\n  // Copy cu_seqlens_k (skip first element) - common to all modes\n#pragma unroll 8\n  for (int i = tid; i < bs; i += total_threads) {\n    dst.cu_seqlens_k[i + 1] = src.cu_seqlens_k[i + 1];\n  }\n\n  // Branch 1: page_table copy (different dimensions per mode)\n  if (forward_mode == 0) {  // DECODE\n    int page_table_elements = bs * max_len;\n#pragma unroll 4\n    for (int i = tid; i < page_table_elements; i += total_threads) {\n      int row = i / max_len;\n      int col = i % max_len;\n      dst.page_table_1[row * page_table_1_stride + col] = src.page_indices[i];\n    }\n  } else {  // TARGET_VERIFY or DRAFT_EXTEND\n    int page_table_elements = page_indices_rows * max_seqlen_k;\n#pragma unroll 4\n    for (int i = tid; i < page_table_elements; i += total_threads) {\n      int row = i / max_seqlen_k;\n      int col = i % max_seqlen_k;\n      dst.page_table_1[row * page_table_1_stride + col] = src.page_indices[i];\n    }\n  }\n\n  // Branch 2: seqlens_expanded copy (only for TARGET_VERIFY/DRAFT_EXTEND)\n  if (forward_mode != 0) {  // TARGET_VERIFY or DRAFT_EXTEND\n#pragma unroll 4\n    for (int i = tid; i < seqlens_expanded_size; i += total_threads) {\n      dst.seqlens_expanded[i] = src.seqlens_expanded[i];\n    }\n  }\n\n  // Branch 3: NSA metadata copy (different loop sizes per mode)\n  if (forward_mode == 0) {  // DECODE\n#pragma unroll 8\n    for (int i = tid; i < bs; i += total_threads) {\n      dst.nsa_cache_seqlens[i] = src.nsa_cache_seqlens[i];\n    }\n\n#pragma unroll 8\n    for (int i = tid; i < bs; i += 
total_threads) {\n      dst.nsa_cu_seqlens_k[i + 1] = src.nsa_cu_seqlens_k[i + 1];\n    }\n  } else {  // TARGET_VERIFY or DRAFT_EXTEND\n#pragma unroll 4\n    for (int i = tid; i < seqlens_expanded_size; i += total_threads) {\n      dst.nsa_cache_seqlens[i] = src.nsa_cache_seqlens[i];\n    }\n\n#pragma unroll 4\n    for (int i = tid; i < seqlens_expanded_size; i += total_threads) {\n      dst.nsa_cu_seqlens_k[i + 1] = src.nsa_cu_seqlens_k[i + 1];\n    }\n  }\n\n  // Copy real page table - compile-time branch\n  if constexpr (HAS_REAL_PAGE_TABLE) {\n    int real_table_elements = (forward_mode == 0 ? bs : page_indices_rows) * real_page_table_cols;\n#pragma unroll 2\n    for (int i = tid; i < real_table_elements; i += total_threads) {\n      int row = i / real_page_table_cols;\n      int col = i % real_page_table_cols;\n      dst.real_page_table[row * real_page_table_dst_stride + col] =\n          src.real_page_table[row * real_page_table_cols + col];\n    }\n  }\n\n  // Branch 4: FlashMLA metadata copy (different sizes per mode)\n  if constexpr (HAS_FLASHMLA) {\n    int flashmla_size = (forward_mode == 0) ? 
(bs + 1) : (seqlens_expanded_size + 1);\n\n    if (forward_mode == 0) {\n#pragma unroll 8\n      for (int i = tid; i < flashmla_size; i += total_threads) {\n        dst.flashmla_num_splits[i] = src.flashmla_num_splits[i];\n      }\n    } else {\n#pragma unroll 4\n      for (int i = tid; i < flashmla_size; i += total_threads) {\n        dst.flashmla_num_splits[i] = src.flashmla_num_splits[i];\n      }\n    }\n\n#pragma unroll 2\n    for (int i = tid; i < flashmla_metadata_size; i += total_threads) {\n      dst.flashmla_metadata[i] = src.flashmla_metadata[i];\n    }\n  }\n}\n\n/**\n * Multi-backend kernel for DECODE mode.\n * Copies from one source to THREE destinations in a single kernel launch.\n *\n * PERFORMANCE: 3x faster than three separate kernel launches due to:\n * - Reduced kernel launch overhead (1 launch instead of 3)\n * - Fewer global memory reads (each source element is read once and written to all 3 destinations)\n * - Better instruction-level parallelism\n *\n * Used by FusedMetadataCopyMultiKernel for speculative decoding scenarios.\n *\n * @tparam HAS_REAL_PAGE_TABLE Compile-time flag for real_page_table support (the real_page_table copy below is\n *         currently guarded by a runtime nullptr check instead)\n * @tparam HAS_FLASHMLA Compile-time flag for FlashMLA metadata support\n */\ntemplate <bool HAS_REAL_PAGE_TABLE, bool HAS_FLASHMLA>\n__global__ void fused_metadata_copy_multi_kernel(const FusedMetadataCopyMultiParams __grid_constant__ params) {\n  int tid = blockIdx.x * blockDim.x + threadIdx.x;\n  int total_threads = gridDim.x * blockDim.x;\n\n  // Unpack parameters for readability\n  const auto& src = params.src;\n  const auto& dst0 = params.dst0;\n  const auto& dst1 = params.dst1;\n  const auto& dst2 = params.dst2;\n  const int bs = params.bs;\n  const int max_len = params.max_len;\n  const int seqlens_expanded_size = params.seqlens_expanded_size;\n  const int page_table_1_stride = params.page_table_1_stride;\n  const int 
real_page_table_cols = params.real_page_table_cols;\n  const int real_page_table_dst_stride = params.real_page_table_dst_stride;\n  const 
int flashmla_metadata_size = params.flashmla_metadata_size;\n\n  // Copy cache_seqlens to all 3 backends\n#pragma unroll 8\n  for (int i = tid; i < bs; i += total_threads) {\n    int32_t val = src.cache_seqlens[i];\n    dst0.cache_seqlens[i] = val;\n    dst1.cache_seqlens[i] = val;\n    dst2.cache_seqlens[i] = val;\n  }\n\n  // Copy cu_seqlens_k to all 3 backends (skip first element)\n#pragma unroll 8\n  for (int i = tid; i < bs; i += total_threads) {\n    int32_t val = src.cu_seqlens_k[i + 1];\n    dst0.cu_seqlens_k[i + 1] = val;\n    dst1.cu_seqlens_k[i + 1] = val;\n    dst2.cu_seqlens_k[i + 1] = val;\n  }\n\n  // DECODE mode: copy page_table_1 to all 3 backends\n  int page_table_elements = bs * max_len;\n#pragma unroll 4\n  for (int i = tid; i < page_table_elements; i += total_threads) {\n    int row = i / max_len;\n    int col = i % max_len;\n    int32_t val = src.page_indices[i];\n    dst0.page_table_1[row * page_table_1_stride + col] = val;\n    dst1.page_table_1[row * page_table_1_stride + col] = val;\n    dst2.page_table_1[row * page_table_1_stride + col] = val;\n  }\n\n  // Copy nsa_cache_seqlens to all 3 backends\n#pragma unroll 8\n  for (int i = tid; i < bs; i += total_threads) {\n    int32_t val = src.nsa_cache_seqlens[i];\n    dst0.nsa_cache_seqlens[i] = val;\n    dst1.nsa_cache_seqlens[i] = val;\n    dst2.nsa_cache_seqlens[i] = val;\n  }\n\n  // Copy NSA cu_seqlens to all 3 backends\n#pragma unroll 8\n  for (int i = tid; i < bs; i += total_threads) {\n    int32_t val = src.nsa_cu_seqlens_k[i + 1];\n    dst0.nsa_cu_seqlens_k[i + 1] = val;\n    dst1.nsa_cu_seqlens_k[i + 1] = val;\n    dst2.nsa_cu_seqlens_k[i + 1] = val;\n  }\n\n  // Copy real page table to all 3 backends\n  if (src.real_page_table != nullptr && dst0.real_page_table != nullptr) {\n    int real_table_elements = bs * real_page_table_cols;\n#pragma unroll 2\n    for (int i = tid; i < real_table_elements; i += total_threads) {\n      int row = i / real_page_table_cols;\n      int col = i % 
real_page_table_cols;\n      int src_idx = row * real_page_table_cols + col;\n      int dst_idx = row * real_page_table_dst_stride + col;\n      int32_t val = src.real_page_table[src_idx];\n      dst0.real_page_table[dst_idx] = val;\n      dst1.real_page_table[dst_idx] = val;\n      dst2.real_page_table[dst_idx] = val;\n    }\n  }\n\n  // Copy FlashMLA metadata to all 3 backends\n  if constexpr (HAS_FLASHMLA) {\n    int flashmla_size = bs + 1;\n#pragma unroll 8\n    for (int i = tid; i < flashmla_size; i += total_threads) {\n      int32_t val = src.flashmla_num_splits[i];\n      dst0.flashmla_num_splits[i] = val;\n      dst1.flashmla_num_splits[i] = val;\n      dst2.flashmla_num_splits[i] = val;\n    }\n\n#pragma unroll 2\n    for (int i = tid; i < flashmla_metadata_size; i += total_threads) {\n      int32_t val = src.flashmla_metadata[i];\n      dst0.flashmla_metadata[i] = val;\n      dst1.flashmla_metadata[i] = val;\n      dst2.flashmla_metadata[i] = val;\n    }\n  }\n}\n\n// ============================================================================\n// Host-side launcher wrappers for JIT compilation\n// ============================================================================\n\nnamespace {\n\n// Launch configuration constants\nconstexpr int THREADS_PER_BLOCK = 256;\nconstexpr int MAX_GRID_SIZE = 1024;  // Limit to prevent excessive resource usage\n\n/**\n * Helper function to extract a typed data pointer from a TensorView.\n * Performs runtime type checking and returns the properly cast pointer.\n *\n * @tparam T The expected element type (e.g., int32_t)\n * @param tensor The TensorView to extract the pointer from\n * @param name The name of the tensor (for error reporting)\n * @return Typed pointer to the tensor data\n */\ntemplate <typename T>\ninline const T* unwrap_data_ptr(const tvm::ffi::TensorView& tensor, const char* name) {\n  using namespace host;\n  if (tensor.data_ptr()) {\n    RuntimeCheck(is_type<T>(tensor.dtype()), \"Tensor \", name, \" must 
have dtype int32\");\n  }\n  return static_cast<const T*>(tensor.data_ptr());\n}\n\n/**\n * Helper function to extract a typed mutable data pointer from a TensorView.\n * Performs runtime type checking and returns the properly cast pointer.\n *\n * @tparam T The expected element type (e.g., int32_t)\n * @param tensor The TensorView to extract the pointer from\n * @param name The name of the tensor (for error reporting)\n * @return Typed mutable pointer to the tensor data\n */\ntemplate <typename T>\ninline T* unwrap_data_ptr_mut(const tvm::ffi::TensorView& tensor, const char* name) {\n  using namespace host;\n  if (tensor.data_ptr()) {\n    RuntimeCheck(is_type<T>(tensor.dtype()), \"Tensor \", name, \" must have dtype int32\");\n  }\n  return static_cast<T*>(tensor.data_ptr());\n}\n\n/**\n * Helper function to extract a typed data pointer from an Optional TensorView.\n * Returns nullptr if the optional has no value, otherwise performs type checking.\n *\n * @tparam T The expected element type (e.g., int32_t)\n * @param optional_tensor The Optional TensorView to extract the pointer from\n * @param name The name of the tensor (for error reporting)\n * @return Typed pointer to the tensor data, or nullptr if optional has no value\n */\ntemplate <typename T>\ninline const T*\nunwrap_optional_data_ptr(const tvm::ffi::Optional<tvm::ffi::TensorView>& optional_tensor, const char* name) {\n  using namespace host;\n  if (!optional_tensor.has_value()) {\n    return nullptr;\n  }\n  const auto& tensor = optional_tensor.value();\n  RuntimeCheck(is_type<T>(tensor.dtype()), \"Tensor \", name, \" must have dtype int32\");\n  return static_cast<const T*>(tensor.data_ptr());\n}\n\n/**\n * Helper function to extract a typed mutable data pointer from an Optional TensorView.\n * Returns nullptr if the optional has no value, otherwise performs type checking.\n *\n * @tparam T The expected element type (e.g., int32_t)\n * @param optional_tensor The Optional TensorView to extract the 
pointer from\n * @param name The name of the tensor (for error reporting)\n * @return Typed mutable pointer to the tensor data, or nullptr if optional has no value\n */\ntemplate <typename T>\ninline T*\nunwrap_optional_data_ptr_mut(const tvm::ffi::Optional<tvm::ffi::TensorView>& optional_tensor, const char* name) {\n  using namespace host;\n  if (!optional_tensor.has_value()) {\n    return nullptr;\n  }\n  const auto& tensor = optional_tensor.value();\n  RuntimeCheck(is_type<T>(tensor.dtype()), \"Tensor \", name, \" must have dtype int32\");\n  return static_cast<T*>(tensor.data_ptr());\n}\n\n/**\n * Calculate kernel launch configuration.\n *\n * @param total_work Total number of work items\n * @param threads_per_block Threads per block (default: THREADS_PER_BLOCK)\n * @return Grid dimension for kernel launch\n */\ninline dim3 get_launch_config(int total_work, int threads_per_block = THREADS_PER_BLOCK) {\n  int num_blocks = (total_work + threads_per_block - 1) / threads_per_block;\n  // Limit grid size to prevent excessive resource usage while ensuring coverage\n  num_blocks = std::min(num_blocks, MAX_GRID_SIZE);\n  return dim3(num_blocks);\n}\n\n/**\n * JIT wrapper for single-backend fused metadata copy kernel.\n *\n * This struct provides a unified interface for launching the fused metadata copy\n * kernel with different forward modes. 
It constructs the parameter struct and\n * launches the unified kernel.\n *\n * IMPLEMENTATION:\n * - Extracts raw pointers from TensorView objects\n * - Constructs FusedMetadataCopyParams with nested SourcePointers/DestinationPointers\n * - Calculates grid configuration based on maximum work size\n * - Launches fused_metadata_copy_kernel with __grid_constant__ parameters\n *\n * @tparam FORWARD_MODE Forward mode: 0=DECODE, 1=TARGET_VERIFY, 2=DRAFT_EXTEND\n * @tparam HAS_REAL_PAGE_TABLE Whether real_page_table tensors are present\n * @tparam HAS_FLASHMLA Whether FlashMLA metadata tensors are present\n */\ntemplate <int FORWARD_MODE, bool HAS_REAL_PAGE_TABLE, bool HAS_FLASHMLA>\nstruct FusedMetadataCopyKernel {\n  static_assert(\n      FORWARD_MODE >= 0 && FORWARD_MODE <= 2,\n      \"FORWARD_MODE must be 0 (DECODE), 1 (TARGET_VERIFY), or 2 (DRAFT_EXTEND)\");\n\n  static void\n  run(const tvm::ffi::TensorView cache_seqlens_src,\n      const tvm::ffi::TensorView cu_seqlens_k_src,\n      const tvm::ffi::TensorView page_indices_src,\n      const tvm::ffi::TensorView nsa_cache_seqlens_src,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> seqlens_expanded_src,\n      const tvm::ffi::TensorView nsa_cu_seqlens_k_src,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> real_page_table_src,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> flashmla_num_splits_src,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> flashmla_metadata_src,\n      const tvm::ffi::TensorView cache_seqlens_dst,\n      const tvm::ffi::TensorView cu_seqlens_k_dst,\n      const tvm::ffi::TensorView page_table_1_dst,\n      const tvm::ffi::TensorView nsa_cache_seqlens_dst,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> seqlens_expanded_dst,\n      const tvm::ffi::TensorView nsa_cu_seqlens_k_dst,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> real_page_table_dst,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> flashmla_num_splits_dst,\n      const 
tvm::ffi::Optional<tvm::ffi::TensorView> flashmla_metadata_dst,\n      int bs,\n      int max_len,\n      int max_seqlen_k,\n      int seqlens_expanded_size) {\n    using namespace host;\n\n    // Build parameter struct with nested source/destination pointers\n    // unwrap_data_ptr and unwrap_optional_data_ptr perform dtype validation\n    const auto params = FusedMetadataCopyParams{\n        .src =\n            {\n                .cache_seqlens = unwrap_data_ptr<int32_t>(cache_seqlens_src, \"cache_seqlens_src\"),\n                .cu_seqlens_k = unwrap_data_ptr<int32_t>(cu_seqlens_k_src, \"cu_seqlens_k_src\"),\n                .page_indices = unwrap_data_ptr<int32_t>(page_indices_src, \"page_indices_src\"),\n                .nsa_cache_seqlens = unwrap_data_ptr<int32_t>(nsa_cache_seqlens_src, \"nsa_cache_seqlens_src\"),\n                .seqlens_expanded = unwrap_optional_data_ptr<int32_t>(seqlens_expanded_src, \"seqlens_expanded_src\"),\n                .nsa_cu_seqlens_k = unwrap_data_ptr<int32_t>(nsa_cu_seqlens_k_src, \"nsa_cu_seqlens_k_src\"),\n                .real_page_table = unwrap_optional_data_ptr<int32_t>(real_page_table_src, \"real_page_table_src\"),\n                .flashmla_num_splits =\n                    unwrap_optional_data_ptr<int32_t>(flashmla_num_splits_src, \"flashmla_num_splits_src\"),\n                .flashmla_metadata = unwrap_optional_data_ptr<int32_t>(flashmla_metadata_src, \"flashmla_metadata_src\"),\n            },\n        .dst =\n            {\n                .cache_seqlens = unwrap_data_ptr_mut<int32_t>(cache_seqlens_dst, \"cache_seqlens_dst\"),\n                .cu_seqlens_k = unwrap_data_ptr_mut<int32_t>(cu_seqlens_k_dst, \"cu_seqlens_k_dst\"),\n                .page_table_1 = unwrap_data_ptr_mut<int32_t>(page_table_1_dst, \"page_table_1_dst\"),\n                .nsa_cache_seqlens = unwrap_data_ptr_mut<int32_t>(nsa_cache_seqlens_dst, \"nsa_cache_seqlens_dst\"),\n                .seqlens_expanded = 
unwrap_optional_data_ptr_mut<int32_t>(seqlens_expanded_dst, \"seqlens_expanded_dst\"),\n                .nsa_cu_seqlens_k = unwrap_data_ptr_mut<int32_t>(nsa_cu_seqlens_k_dst, \"nsa_cu_seqlens_k_dst\"),\n                .real_page_table = unwrap_optional_data_ptr_mut<int32_t>(real_page_table_dst, \"real_page_table_dst\"),\n                .flashmla_num_splits =\n                    unwrap_optional_data_ptr_mut<int32_t>(flashmla_num_splits_dst, \"flashmla_num_splits_dst\"),\n                .flashmla_metadata =\n                    unwrap_optional_data_ptr_mut<int32_t>(flashmla_metadata_dst, \"flashmla_metadata_dst\"),\n            },\n        .forward_mode = FORWARD_MODE,\n        .bs = bs,\n        .max_len = max_len,\n        .max_seqlen_k = max_seqlen_k,\n        .seqlens_expanded_size = seqlens_expanded_size,\n        .page_indices_rows = static_cast<int>(page_indices_src.shape()[0]),\n        .page_table_1_stride = static_cast<int>(page_table_1_dst.shape()[1]),\n        .real_page_table_cols =\n            real_page_table_src.has_value() ? static_cast<int>(real_page_table_src.value().shape()[1]) : 0,\n        .real_page_table_dst_stride =\n            real_page_table_dst.has_value() ? static_cast<int>(real_page_table_dst.value().stride(0)) : 0,\n        .flashmla_metadata_size =\n            flashmla_metadata_src.has_value() ? static_cast<int>(flashmla_metadata_src.value().numel()) : 0,\n    };\n\n    // Calculate grid configuration\n    int max_elements = std::max(\n        {bs,\n         params.page_indices_rows * max_seqlen_k,\n         seqlens_expanded_size,\n         HAS_FLASHMLA ? (seqlens_expanded_size + 1) : 0,\n         HAS_FLASHMLA ? 
params.flashmla_metadata_size : 0});\n\n    dim3 grid = get_launch_config(max_elements);\n    dim3 block(THREADS_PER_BLOCK);\n    DLDevice device = cache_seqlens_src.device();\n\n    // Launch unified kernel with params struct\n    host::LaunchKernel(grid, block, device)(fused_metadata_copy_kernel<HAS_REAL_PAGE_TABLE, HAS_FLASHMLA>, params);\n  }\n};\n\n/**\n * JIT wrapper for multi-backend fused metadata copy kernel.\n *\n * This kernel optimizes the common case where metadata needs to be copied from\n * one source to THREE destination backends in a single kernel launch. Compared\n * with launching three separate single-backend kernels, it:\n * - Reduces kernel launch overhead (1 launch instead of 3)\n * - Reduces global memory reads (each source value is read once and written to 3 destinations)\n * - Gives each thread three independent stores per load, which can help instruction-level parallelism\n *\n * USAGE: Primarily for speculative decoding with multiple draft models, where\n * the same source metadata needs to be replicated to multiple backend contexts.\n *\n * LIMITATION: Currently only supports DECODE mode, which is the most frequently\n * used mode in speculative decoding scenarios.\n *\n * IMPLEMENTATION:\n * - Constructs FusedMetadataCopyMultiParams with 1 SourcePointers + 3 DestinationPointers\n * - Launches fused_metadata_copy_multi_kernel with __grid_constant__ parameters\n *\n * @tparam HAS_REAL_PAGE_TABLE Whether real_page_table tensors are present\n * @tparam HAS_FLASHMLA Whether FlashMLA metadata tensors are present\n */\ntemplate <bool HAS_REAL_PAGE_TABLE, bool HAS_FLASHMLA>\nstruct FusedMetadataCopyMultiKernel {\n  static void\n  run(const tvm::ffi::TensorView cache_seqlens_src,\n      const tvm::ffi::TensorView cu_seqlens_k_src,\n      const tvm::ffi::TensorView page_indices_src,\n      const tvm::ffi::TensorView nsa_cache_seqlens_src,\n      const tvm::ffi::TensorView nsa_cu_seqlens_k_src,\n      const 
tvm::ffi::Optional<tvm::ffi::TensorView> real_page_table_src,\n      const 
tvm::ffi::Optional<tvm::ffi::TensorView> flashmla_num_splits_src,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> flashmla_metadata_src,\n      const tvm::ffi::TensorView cache_seqlens_dst0,\n      const tvm::ffi::TensorView cu_seqlens_k_dst0,\n      const tvm::ffi::TensorView page_table_1_dst0,\n      const tvm::ffi::TensorView nsa_cache_seqlens_dst0,\n      const tvm::ffi::TensorView nsa_cu_seqlens_k_dst0,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> real_page_table_dst0,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> flashmla_num_splits_dst0,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> flashmla_metadata_dst0,\n      const tvm::ffi::TensorView cache_seqlens_dst1,\n      const tvm::ffi::TensorView cu_seqlens_k_dst1,\n      const tvm::ffi::TensorView page_table_1_dst1,\n      const tvm::ffi::TensorView nsa_cache_seqlens_dst1,\n      const tvm::ffi::TensorView nsa_cu_seqlens_k_dst1,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> real_page_table_dst1,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> flashmla_num_splits_dst1,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> flashmla_metadata_dst1,\n      const tvm::ffi::TensorView cache_seqlens_dst2,\n      const tvm::ffi::TensorView cu_seqlens_k_dst2,\n      const tvm::ffi::TensorView page_table_1_dst2,\n      const tvm::ffi::TensorView nsa_cache_seqlens_dst2,\n      const tvm::ffi::TensorView nsa_cu_seqlens_k_dst2,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> real_page_table_dst2,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> flashmla_num_splits_dst2,\n      const tvm::ffi::Optional<tvm::ffi::TensorView> flashmla_metadata_dst2,\n      int bs,\n      int max_len,\n      int seqlens_expanded_size) {\n    using namespace host;\n\n    // Build parameter struct with nested source/destination pointers\n    // unwrap_data_ptr and unwrap_optional_data_ptr perform dtype validation\n    const auto params = FusedMetadataCopyMultiParams{\n        .src =\n         
   {\n                .cache_seqlens = unwrap_data_ptr<int32_t>(cache_seqlens_src, \"cache_seqlens_src\"),\n                .cu_seqlens_k = unwrap_data_ptr<int32_t>(cu_seqlens_k_src, \"cu_seqlens_k_src\"),\n                .page_indices = unwrap_data_ptr<int32_t>(page_indices_src, \"page_indices_src\"),\n                .nsa_cache_seqlens = unwrap_data_ptr<int32_t>(nsa_cache_seqlens_src, \"nsa_cache_seqlens_src\"),\n                .seqlens_expanded = nullptr,  // Not used in multi-backend DECODE mode\n                .nsa_cu_seqlens_k = unwrap_data_ptr<int32_t>(nsa_cu_seqlens_k_src, \"nsa_cu_seqlens_k_src\"),\n                .real_page_table = unwrap_optional_data_ptr<int32_t>(real_page_table_src, \"real_page_table_src\"),\n                .flashmla_num_splits =\n                    unwrap_optional_data_ptr<int32_t>(flashmla_num_splits_src, \"flashmla_num_splits_src\"),\n                .flashmla_metadata = unwrap_optional_data_ptr<int32_t>(flashmla_metadata_src, \"flashmla_metadata_src\"),\n            },\n        .dst0 =\n            {\n                .cache_seqlens = unwrap_data_ptr_mut<int32_t>(cache_seqlens_dst0, \"cache_seqlens_dst0\"),\n                .cu_seqlens_k = unwrap_data_ptr_mut<int32_t>(cu_seqlens_k_dst0, \"cu_seqlens_k_dst0\"),\n                .page_table_1 = unwrap_data_ptr_mut<int32_t>(page_table_1_dst0, \"page_table_1_dst0\"),\n                .nsa_cache_seqlens = unwrap_data_ptr_mut<int32_t>(nsa_cache_seqlens_dst0, \"nsa_cache_seqlens_dst0\"),\n                .seqlens_expanded = nullptr,\n                .nsa_cu_seqlens_k = unwrap_data_ptr_mut<int32_t>(nsa_cu_seqlens_k_dst0, \"nsa_cu_seqlens_k_dst0\"),\n                .real_page_table = unwrap_optional_data_ptr_mut<int32_t>(real_page_table_dst0, \"real_page_table_dst0\"),\n                .flashmla_num_splits =\n                    unwrap_optional_data_ptr_mut<int32_t>(flashmla_num_splits_dst0, \"flashmla_num_splits_dst0\"),\n                .flashmla_metadata =\n                    
unwrap_optional_data_ptr_mut<int32_t>(flashmla_metadata_dst0, \"flashmla_metadata_dst0\"),\n            },\n        .dst1 =\n            {\n                .cache_seqlens = unwrap_data_ptr_mut<int32_t>(cache_seqlens_dst1, \"cache_seqlens_dst1\"),\n                .cu_seqlens_k = unwrap_data_ptr_mut<int32_t>(cu_seqlens_k_dst1, \"cu_seqlens_k_dst1\"),\n                .page_table_1 = unwrap_data_ptr_mut<int32_t>(page_table_1_dst1, \"page_table_1_dst1\"),\n                .nsa_cache_seqlens = unwrap_data_ptr_mut<int32_t>(nsa_cache_seqlens_dst1, \"nsa_cache_seqlens_dst1\"),\n                .seqlens_expanded = nullptr,\n                .nsa_cu_seqlens_k = unwrap_data_ptr_mut<int32_t>(nsa_cu_seqlens_k_dst1, \"nsa_cu_seqlens_k_dst1\"),\n                .real_page_table = unwrap_optional_data_ptr_mut<int32_t>(real_page_table_dst1, \"real_page_table_dst1\"),\n                .flashmla_num_splits =\n                    unwrap_optional_data_ptr_mut<int32_t>(flashmla_num_splits_dst1, \"flashmla_num_splits_dst1\"),\n                .flashmla_metadata =\n                    unwrap_optional_data_ptr_mut<int32_t>(flashmla_metadata_dst1, \"flashmla_metadata_dst1\"),\n            },\n        .dst2 =\n            {\n                .cache_seqlens = unwrap_data_ptr_mut<int32_t>(cache_seqlens_dst2, \"cache_seqlens_dst2\"),\n                .cu_seqlens_k = unwrap_data_ptr_mut<int32_t>(cu_seqlens_k_dst2, \"cu_seqlens_k_dst2\"),\n                .page_table_1 = unwrap_data_ptr_mut<int32_t>(page_table_1_dst2, \"page_table_1_dst2\"),\n                .nsa_cache_seqlens = unwrap_data_ptr_mut<int32_t>(nsa_cache_seqlens_dst2, \"nsa_cache_seqlens_dst2\"),\n                .seqlens_expanded = nullptr,\n                .nsa_cu_seqlens_k = unwrap_data_ptr_mut<int32_t>(nsa_cu_seqlens_k_dst2, \"nsa_cu_seqlens_k_dst2\"),\n                .real_page_table = unwrap_optional_data_ptr_mut<int32_t>(real_page_table_dst2, \"real_page_table_dst2\"),\n                .flashmla_num_splits =\n                  
  unwrap_optional_data_ptr_mut<int32_t>(flashmla_num_splits_dst2, \"flashmla_num_splits_dst2\"),\n                .flashmla_metadata =\n                    unwrap_optional_data_ptr_mut<int32_t>(flashmla_metadata_dst2, \"flashmla_metadata_dst2\"),\n            },\n        .bs = bs,\n        .max_len = max_len,\n        .seqlens_expanded_size = seqlens_expanded_size,\n        .page_table_1_stride = static_cast<int>(page_table_1_dst0.shape()[1]),\n        .real_page_table_cols =\n            real_page_table_src.has_value() ? static_cast<int>(real_page_table_src.value().shape()[1]) : 0,\n        .real_page_table_dst_stride =\n            real_page_table_dst0.has_value() ? static_cast<int>(real_page_table_dst0.value().stride(0)) : 0,\n        .flashmla_metadata_size =\n            flashmla_metadata_src.has_value() ? static_cast<int>(flashmla_metadata_src.value().numel()) : 0,\n    };\n\n    dim3 grid = get_launch_config(bs * max_len);\n    dim3 block(THREADS_PER_BLOCK);\n    DLDevice device = cache_seqlens_src.device();\n\n    // Launch multi-backend kernel with params struct\n    host::LaunchKernel(grid, block, device)(\n        fused_metadata_copy_multi_kernel<HAS_REAL_PAGE_TABLE, HAS_FLASHMLA>, params);\n  }\n};\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/tile.cuh>\n#include <sgl_kernel/utils.cuh>\n#include <sgl_kernel/vec.cuh>\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/container/tensor.h>\n\n#include <cstdint>\n\nnamespace {\n\nstruct StoreKVCacheParams {\n  const void* __restrict__ k;\n  const void* __restrict__ v;\n  void* __restrict__ k_cache;\n  void* __restrict__ v_cache;\n  const void* __restrict__ indices;\n  int64_t stride_k_bytes;\n  int64_t stride_v_bytes;\n  int64_t stride_cache_bytes;\n  int64_t stride_indices;\n  uint32_t batch_size;\n};\n\nconstexpr uint32_t kNumWarps = 4;\nconstexpr uint32_t kThreadsPerBlock = kNumWarps * device::kWarpThreads;\n\n/**\n * \\brief Use a single warp to copy key and value data from source to destination.\n * Each thread in the warp copies a portion of the data in a coalesced manner.\n * \\tparam kElementBytes The size of each key/value element in bytes.\n * \\param k_src Pointer to the source key data.\n * \\param v_src Pointer to the source value data.\n * \\param k_dst Pointer to the destination key data.\n * \\param v_dst Pointer to the destination value data.\n */\ntemplate <int64_t kElementBytes>\nSGL_DEVICE void copy_kv_warp(\n    const void* __restrict__ k_src,\n    const void* __restrict__ v_src,\n    void* __restrict__ k_dst,\n    void* __restrict__ v_dst) {\n  using namespace device;\n  constexpr int64_t kAlignment = (kElementBytes % (16 * kWarpThreads) == 0) ? 16\n                                 : kElementBytes % (8 * kWarpThreads) == 0  ? 8\n                                 : kElementBytes % (4 * kWarpThreads) == 0  ? 4\n                                 : kElementBytes % 4 == 0                   ? 
4\n                                                                            : 0;\n\n  static_assert(kAlignment > 0, \"Element size must be multiple of 4 bytes\");\n\n  using vec_t = AlignedStorage<uint32_t, kAlignment / 4>;\n  constexpr auto kLoopBytes = sizeof(vec_t) * kWarpThreads;\n  constexpr auto kLoopCount = kElementBytes / kLoopBytes;\n\n  const auto gmem = tile::Memory<vec_t>::warp();\n\n#pragma unroll kLoopCount\n  for (int64_t i = 0; i < kLoopCount; ++i) {\n    const auto k = gmem.load(k_src, i);\n    const auto v = gmem.load(v_src, i);\n    gmem.store(k_dst, k, i);\n    gmem.store(v_dst, v, i);\n  }\n\n  // handle the epilogue if any\n  if constexpr (kLoopCount * kLoopBytes < kElementBytes) {\n    if (gmem.in_bound(kElementBytes / sizeof(vec_t), kLoopCount)) {\n      const auto k = gmem.load(k_src, kLoopCount);\n      const auto v = gmem.load(v_src, kLoopCount);\n      gmem.store(k_dst, k, kLoopCount);\n      gmem.store(v_dst, v, kLoopCount);\n    }\n  }\n}\n\n/**\n * \\brief Kernel to store key-value pairs into the KV cache.\n * Each element is split into multiple parts to allow parallel memory copy.\n * \\tparam kElementBytes The size of each key/value element in bytes.\n * \\tparam kSplit The number of warps that handle each element.\n * \\tparam kUsePDL Whether to use PDL feature.\n * \\tparam T The data type of the indices (`int32_t` or `int64_t`).\n */\ntemplate <int64_t kElementBytes, int kSplit, bool kUsePDL, typename T>\n__global__ void store_kvcache(const __grid_constant__ StoreKVCacheParams params) {\n  using namespace device;\n  constexpr auto kSplitSize = kElementBytes / kSplit;\n  const uint32_t warp_id = blockIdx.x * kNumWarps + threadIdx.x / kWarpThreads;\n  const uint32_t item_id = warp_id / kSplit;\n  const uint32_t split_id = warp_id % kSplit;\n  const auto& [\n    k_input, v_input, k_cache, v_cache, indices, // ptr\n    stride_k, stride_v, stride_cache, stride_indices, batch_size // size\n  ] = params;\n  if (item_id >= batch_size) 
return;\n\n  const auto index_ptr = static_cast<const T*>(indices) + item_id * stride_indices;\n  PDLWaitPrimary<kUsePDL>();\n\n  const auto index = *index_ptr;\n  const auto k_src = pointer::offset(k_input, item_id * stride_k, split_id * kSplitSize);\n  const auto v_src = pointer::offset(v_input, item_id * stride_v, split_id * kSplitSize);\n  const auto k_dst = pointer::offset(k_cache, index * stride_cache, split_id * kSplitSize);\n  const auto v_dst = pointer::offset(v_cache, index * stride_cache, split_id * kSplitSize);\n\n  copy_kv_warp<kSplitSize>(k_src, v_src, k_dst, v_dst);\n  PDLTriggerSecondary<kUsePDL>();\n}\n\ntemplate <int64_t kElementBytes, bool kUsePDL>\nstruct StoreKVCacheKernel {\n  static_assert(kElementBytes > 0 && kElementBytes % 4 == 0);\n\n  template <int kSplit, typename T>\n  static constexpr auto store_kernel = store_kvcache<kElementBytes, kSplit, kUsePDL, T>;\n\n  template <typename T>\n  static auto get_kernel(const int num_split) {\n    using namespace host;\n    // only apply split optimization when element size is aligned\n    if constexpr (kElementBytes % (4 * 128) == 0) {\n      if (num_split == 4) return store_kernel<4, T>;\n    }\n    if constexpr (kElementBytes % (2 * 128) == 0) {\n      if (num_split == 2) return store_kernel<2, T>;\n    }\n    if (num_split == 1) return store_kernel<1, T>;\n    Panic(\"Unsupported num_split {} for element size {}\", num_split, kElementBytes);\n  }\n\n  static void\n  run(const tvm::ffi::TensorView k,\n      const tvm::ffi::TensorView v,\n      const tvm::ffi::TensorView k_cache,\n      const tvm::ffi::TensorView v_cache,\n      const tvm::ffi::TensorView indices,\n      const int num_split) {\n    using namespace host;\n    auto B = SymbolicSize{\"batch_size\"};\n    auto D = SymbolicSize{\"element_size\"};\n    auto KS = SymbolicSize{\"k_stride\"};\n    auto VS = SymbolicSize{\"v_stride\"};\n    auto S = SymbolicSize{\"cache_stride\"};\n    auto I = SymbolicSize{\"indices_stride\"};\n    auto 
dtype = SymbolicDType{};\n    auto device = SymbolicDevice{};\n    auto indice_dtype = SymbolicDType{};\n    device.set_options<kDLCUDA, kDLROCM>();\n\n    TensorMatcher({B, D})  //\n        .with_strides({KS, 1})\n        .with_dtype(dtype)\n        .with_device(device)\n        .verify(k);\n    TensorMatcher({B, D})  //\n        .with_strides({VS, 1})\n        .with_dtype(dtype)\n        .with_device(device)\n        .verify(v);\n    TensorMatcher({-1, D})  //\n        .with_strides({S, 1})\n        .with_dtype(dtype)\n        .with_device(device)\n        .verify(k_cache)\n        .verify(v_cache);\n    TensorMatcher({B})  //\n        .with_strides({I})\n        .with_dtype<int32_t, int64_t>(indice_dtype)\n        .with_device(device)\n        .verify(indices);\n\n    const int64_t dtype_size = dtype_bytes(dtype.unwrap());\n    const uint32_t num_elements = static_cast<uint32_t>(B.unwrap());\n    RuntimeCheck(kElementBytes == dtype_size * D.unwrap());\n\n    const auto params = StoreKVCacheParams{\n        .k = k.data_ptr(),\n        .v = v.data_ptr(),\n        .k_cache = k_cache.data_ptr(),\n        .v_cache = v_cache.data_ptr(),\n        .indices = indices.data_ptr(),\n        .stride_k_bytes = KS.unwrap() * dtype_size,\n        .stride_v_bytes = VS.unwrap() * dtype_size,\n        .stride_cache_bytes = S.unwrap() * dtype_size,\n        .stride_indices = I.unwrap(),\n        .batch_size = static_cast<uint32_t>(B.unwrap()),\n    };\n    // select the kernel by index dtype and num_split (get_kernel panics on an unsupported num_split)\n    const auto use_int32 = indice_dtype.is_type<int32_t>();\n    const auto kernel = use_int32 ? get_kernel<int32_t>(num_split) : get_kernel<int64_t>(num_split);\n    const auto num_blocks = div_ceil(num_elements * num_split, kNumWarps);\n    LaunchKernel(num_blocks, kThreadsPerBlock, device.unwrap())  //\n        .enable_pdl(kUsePDL)(kernel, params);\n  }\n};\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/elementwise/pos_enc.cuh",
    "content": "// Adapted from\n// https://github.com/vllm-project/vllm/blob/014ece97c7aa49084a1119dca792af081a18dbc1/csrc/pos_encoding_kernels.cu\n\n#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/utils.cuh>\n\n#include <tvm/ffi/container/tensor.h>\n\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\nnamespace {\n\ntemplate <typename scalar_t, bool IS_NEOX>\ninline __device__ void apply_token_rotary_embedding(\n    scalar_t* __restrict__ arr,\n    const scalar_t* __restrict__ cos_ptr,\n    const scalar_t* __restrict__ sin_ptr,\n    int rot_offset,\n    int embed_dim) {\n  int x_index, y_index;\n  scalar_t cos, sin;\n  if (IS_NEOX) {\n    // GPT-NeoX style rotary embedding.\n    x_index = rot_offset;\n    y_index = embed_dim + rot_offset;\n    cos = SGLANG_LDG(cos_ptr + x_index);\n    sin = SGLANG_LDG(sin_ptr + x_index);\n  } else {\n    // GPT-J style rotary embedding.\n    x_index = 2 * rot_offset;\n    y_index = 2 * rot_offset + 1;\n    cos = SGLANG_LDG(cos_ptr + x_index / 2);\n    sin = SGLANG_LDG(sin_ptr + x_index / 2);\n  }\n\n  const scalar_t x = arr[x_index];\n  const scalar_t y = arr[y_index];\n  arr[x_index] = x * cos - y * sin;\n  arr[y_index] = y * cos + x * sin;\n}\n\ntemplate <typename scalar_t, bool IS_NEOX>\ninline __device__ void apply_rotary_embedding(\n    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,\n                                   // head_size] or [num_tokens, num_heads,\n                                   // head_size]\n    scalar_t* __restrict__ key,    // nullptr or\n                                   // [batch_size, seq_len, num_kv_heads,\n                                   // head_size] or [num_tokens, num_kv_heads,\n                                   // head_size]\n    const scalar_t* cache_ptr,\n    const int head_size,\n    const int num_heads,\n    const int num_kv_heads,\n    const int rot_dim,\n    const int token_idx,\n    const int64_t query_stride,\n    const 
int64_t key_stride,\n    const int64_t head_stride) {\n  const int embed_dim = rot_dim / 2;\n  const scalar_t* cos_ptr = cache_ptr;\n  const scalar_t* sin_ptr = cache_ptr + embed_dim;\n\n  const int nq = num_heads * embed_dim;\n  for (int i = threadIdx.x; i < nq; i += blockDim.x) {\n    const int head_idx = i / embed_dim;\n    const int64_t token_head = token_idx * query_stride + head_idx * head_stride;\n    const int rot_offset = i % embed_dim;\n    apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);\n  }\n\n  if (key != nullptr) {\n    const int nk = num_kv_heads * embed_dim;\n    for (int i = threadIdx.x; i < nk; i += blockDim.x) {\n      const int head_idx = i / embed_dim;\n      const int64_t token_head = token_idx * key_stride + head_idx * head_stride;\n      const int rot_offset = i % embed_dim;\n      apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);\n    }\n  }\n}\n\ntemplate <typename scalar_t, bool IS_NEOX>\n__global__ void rotary_embedding_kernel(\n    const int64_t* __restrict__ positions,       // [batch_size, seq_len] or\n                                                 // [num_tokens]\n    scalar_t* __restrict__ query,                // [batch_size, seq_len, num_heads,\n                                                 // head_size] or [num_tokens, num_heads,\n                                                 // head_size]\n    scalar_t* __restrict__ key,                  // nullptr or\n                                                 // [batch_size, seq_len, num_kv_heads,\n                                                 // head_size] or [num_tokens, num_kv_heads,\n                                                 // head_size]\n    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //\n                                                 // 2]\n    const int rot_dim,\n    const int64_t query_stride,\n    const int64_t 
key_stride,\n    const int64_t head_stride,\n    const int num_heads,\n    const int num_kv_heads,\n    const int head_size) {\n  // Each thread block is responsible for one token.\n  const int token_idx = blockIdx.x;\n  int64_t pos = positions[token_idx];\n  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;\n\n  apply_rotary_embedding<scalar_t, IS_NEOX>(\n      query,\n      key,\n      cache_ptr,\n      head_size,\n      num_heads,\n      num_kv_heads,\n      rot_dim,\n      token_idx,\n      query_stride,\n      key_stride,\n      head_stride);\n}\n\n// Helper function to launch the kernel on the given stream\ntemplate <typename scalar_t, bool IS_NEOX>\nvoid launch_kernel(\n    const int64_t* positions_data_ptr,\n    void* query_ptr,\n    void* key_ptr,\n    const void* cos_sin_cache_ptr,\n    int rot_dim,\n    int64_t query_stride,\n    int64_t key_stride,\n    int64_t head_stride,\n    int num_heads,\n    int num_kv_heads,\n    int head_size,\n    dim3 grid,\n    dim3 block,\n    const cudaStream_t stream) {\n  rotary_embedding_kernel<scalar_t, IS_NEOX><<<grid, block, 0, stream>>>(\n      positions_data_ptr,\n      static_cast<scalar_t*>(query_ptr),\n      static_cast<scalar_t*>(key_ptr),\n      static_cast<const scalar_t*>(cos_sin_cache_ptr),\n      rot_dim,\n      query_stride,\n      key_stride,\n      head_stride,\n      num_heads,\n      num_kv_heads,\n      head_size);\n}\n\n// Helper macro to reduce repetition\n#define DISPATCH_DTYPE(DTYPE_CODE, DTYPE_BITS, IS_NEOX, ...)                                                      
\\\n  if (DTYPE_CODE == kDLFloat && DTYPE_BITS == 32) {                                                               \\\n    launch_kernel<float, IS_NEOX>(__VA_ARGS__);                                                                   \\\n  } else if (DTYPE_CODE == kDLFloat && DTYPE_BITS == 16) {                                                        \\\n    launch_kernel<half, IS_NEOX>(__VA_ARGS__);                                                                    \\\n  } else if (DTYPE_CODE == kDLBfloat && DTYPE_BITS == 16) {                                                       \\\n    launch_kernel<nv_bfloat16, IS_NEOX>(__VA_ARGS__);                                                             \\\n  } else {                                                                                                        \\\n    RuntimeCheck(                                                                                                 \\\n        false, \"Unsupported data type for rotary embedding. 
Only float32, float16, and bfloat16 are supported.\"); \\\n  }\n\n// Helper function to dispatch based on data type\ntemplate <bool IS_NEOX>\nvoid dispatch_by_dtype(\n    const int64_t* positions_data_ptr,\n    DLDataType query_dtype,\n    void* query_ptr,\n    void* key_ptr,\n    void* cos_sin_cache_ptr,\n    int rot_dim,\n    int64_t query_stride,\n    int64_t key_stride,\n    int64_t head_stride,\n    int num_heads,\n    int num_kv_heads,\n    int head_size,\n    dim3 grid,\n    dim3 block,\n    const cudaStream_t stream) {\n  using namespace host;\n  DISPATCH_DTYPE(\n      query_dtype.code,\n      query_dtype.bits,\n      IS_NEOX,\n      positions_data_ptr,\n      query_ptr,\n      key_ptr,\n      cos_sin_cache_ptr,\n      rot_dim,\n      query_stride,\n      key_stride,\n      head_stride,\n      num_heads,\n      num_kv_heads,\n      head_size,\n      grid,\n      block,\n      stream);\n}\n\nstruct RotaryEmbeddingKernel {\n  static void\n  run(tvm::ffi::TensorView positions,  // [batch_size, seq_len] or [num_tokens]\n      tvm::ffi::TensorView query,      // [batch_size, seq_len, num_heads * head_size] or\n                                       // [num_tokens, num_heads * head_size] or\n                                       // [batch_size, seq_len, num_heads, head_size] or\n                                       // [num_tokens, num_heads, head_size]\n      tvm::ffi::Optional<tvm::ffi::TensorView> key,\n      // null or\n      // [batch_size, seq_len, num_kv_heads * head_size] or\n      // [num_tokens, num_kv_heads * head_size] or\n      // [batch_size, seq_len, num_kv_heads, head_size] or\n      // [num_tokens, num_kv_heads, head_size]\n      int64_t head_size,\n      tvm::ffi::TensorView cos_sin_cache,  // [max_position, rot_dim]\n      bool is_neox) {\n    using namespace host;\n\n    // num_tokens = batch_size * seq_len\n    int64_t num_tokens = positions.numel();\n    int32_t positions_ndim = positions.ndim();\n\n    // Make sure num_tokens dim is consistent 
across positions, query, and key\n    RuntimeCheck(\n        positions_ndim == 1 || positions_ndim == 2, \"positions must have shape [num_tokens] or [batch_size, seq_len]\");\n    if (positions_ndim == 1) {\n      RuntimeCheck(\n          query.size(0) == positions.size(0) && (!key.has_value() || key.value().size(0) == positions.size(0)),\n          \"query, key and positions must have the same number of tokens\");\n    }\n    if (positions_ndim == 2) {\n      RuntimeCheck(\n          query.size(0) == positions.size(0) && (!key.has_value() || key.value().size(0) == positions.size(0)) &&\n              query.size(1) == positions.size(1) && (!key.has_value() || key.value().size(1) == positions.size(1)),\n          \"query, key and positions must have the same batch_size and seq_len\");\n    }\n\n    // Make sure head_size is valid for query and key\n    // hidden_size = num_heads * head_size\n    int query_hidden_size = query.numel() / num_tokens;\n    int key_hidden_size = key.has_value() ? key.value().numel() / num_tokens : 0;\n    RuntimeCheck(query_hidden_size % head_size == 0);\n    RuntimeCheck(key_hidden_size % head_size == 0);\n\n    // Make sure query and key have a consistent number of heads (GQA: num_heads % num_kv_heads == 0)\n    int num_heads = query_hidden_size / head_size;\n    int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;\n    RuntimeCheck(num_heads % num_kv_heads == 0);\n\n    int rot_dim = cos_sin_cache.size(1);\n    int seq_dim_idx = positions_ndim - 1;\n    int64_t query_stride = query.stride(seq_dim_idx);\n    int64_t key_stride = key.has_value() ? key.value().stride(seq_dim_idx) : 0;\n    // Determine head stride: for [*, heads, head_size] use the stride of the\n    // heads dim (dim -2); for flat [*, heads * head_size], head blocks are\n    // contiguous with size head_size\n    int query_ndim = query.dim();\n    int64_t head_stride = (query_ndim == positions_ndim + 2) ? 
query.stride(-2) : head_size;\n\n    dim3 grid(num_tokens);\n    dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));\n\n    auto device = query.device();\n    const cudaStream_t stream = LaunchKernel::resolve_device(device);\n\n    auto positions_data_ptr = static_cast<const int64_t*>(positions.data_ptr());\n\n    if (is_neox) {\n      dispatch_by_dtype<true>(\n          positions_data_ptr,\n          query.dtype(),\n          query.data_ptr(),\n          key.has_value() ? key.value().data_ptr() : nullptr,\n          cos_sin_cache.data_ptr(),\n          rot_dim,\n          query_stride,\n          key_stride,\n          head_stride,\n          num_heads,\n          num_kv_heads,\n          head_size,\n          grid,\n          block,\n          stream);\n    } else {\n      dispatch_by_dtype<false>(\n          positions_data_ptr,\n          query.dtype(),\n          query.data_ptr(),\n          key.has_value() ? key.value().data_ptr() : nullptr,\n          cos_sin_cache.data_ptr(),\n          rot_dim,\n          query_stride,\n          key_stride,\n          head_stride,\n          num_heads,\n          num_kv_heads,\n          head_size,\n          grid,\n          block,\n          stream);\n    }\n  }\n};\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/runtime.cuh>\n#include <sgl_kernel/tile.cuh>\n#include <sgl_kernel/utils.cuh>\n\n#include <sgl_kernel/impl/norm.cuh>\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/container/tensor.h>\n\n#include <cstdint>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <type_traits>\n\nnamespace {\n\nstruct QKNormParams {\n  void* __restrict__ q;\n  void* __restrict__ k;  // k is offset by (-num_qo_heads * head_dim) elements\n  int64_t q_stride;\n  int64_t k_stride;\n  uint32_t num_qo_heads;\n  uint32_t num_kv_heads;\n  float eps;\n  const void* __restrict__ q_weight;\n  const void* __restrict__ k_weight;\n  uint32_t num_tokens;\n};\n\nconstexpr uint32_t kWarpsPerBlock = 4;\nconstexpr uint32_t kThreadsPerBlock = kWarpsPerBlock * device::kWarpThreads;\n\n// Warp-level kernel for head_dim <= 256\ntemplate <int64_t kHeadDim, bool kUsePDL, typename Float>\n__global__ void fused_qknorm_warp(const QKNormParams __grid_constant__ params) {\n  using namespace device;\n  using Storage = norm::StorageType<Float, kHeadDim>;\n\n  static_assert(sizeof(Float) == 2, \"Only support FP16/BF16\");\n  const auto& [q, k, q_stride, k_stride, num_qo_heads, num_kv_heads, eps, q_weight, k_weight, num_tokens] = params;\n\n  const auto num_blks = gridDim.x;\n  const auto num_workers = num_blks * kWarpsPerBlock;\n  const auto num_q_and_k_heads = num_qo_heads + num_kv_heads;\n  const auto num_works = num_q_and_k_heads * num_tokens;\n  const auto start_worker_id = blockIdx.x * kWarpsPerBlock + threadIdx.x / kWarpThreads;\n  const auto gmem = tile::Memory<Storage>::warp();\n\n  PDLWaitPrimary<kUsePDL>();  // wait for primary kernel\n\n  for (auto idx = start_worker_id; idx < num_works; idx += num_workers) {\n    const int64_t token_id = idx / num_q_and_k_heads;\n    const int64_t head_id = idx % num_q_and_k_heads;\n    const auto load_q = head_id < num_qo_heads;\n    const auto input = load_q ? 
pointer::offset(q, 2 * (token_id * q_stride + head_id * kHeadDim))\n                              : pointer::offset(k, 2 * (token_id * k_stride + head_id * kHeadDim));\n    const auto weight = load_q ? q_weight : k_weight;\n    const auto input_vec = gmem.load(input);\n    const auto weight_vec = gmem.load(weight);\n    const auto output_vec = norm::apply_norm_warp<kHeadDim>(input_vec, weight_vec, eps);\n    gmem.store(input, output_vec);\n  }\n\n  PDLTriggerSecondary<kUsePDL>();  // launch secondary kernel\n}\n\n// For CTA level, used for head_dim > 256 (512,1024)\ntemplate <int64_t kHeadDim, bool kUsePDL, typename Float>\n__global__ void fused_qknorm_cta(const QKNormParams __grid_constant__ params) {\n  using namespace device;\n  using Storage = norm::StorageType<Float, kHeadDim>;\n\n  constexpr auto kNumThreads = host::norm::get_cta_threads<Float, kHeadDim>();\n  constexpr auto kNumWarps = kNumThreads / kWarpThreads;\n\n  static_assert(sizeof(Float) == 2, \"Only support FP16/BF16\");\n  const auto& [q, k, q_stride, k_stride, num_qo_heads, num_kv_heads, eps, q_weight, k_weight, num_tokens] = params;\n\n  const auto num_q_and_k_heads = num_qo_heads + num_kv_heads;\n  const auto num_works = num_q_and_k_heads * num_tokens;\n  const auto gmem = tile::Memory<Storage>::cta(kNumThreads);\n  __shared__ float smem[norm::kSmemBufferSize];\n\n  PDLWaitPrimary<kUsePDL>();  // wait for primary kernel\n\n  for (auto idx = blockIdx.x; idx < num_works; idx += gridDim.x) {\n    const int64_t token_id = idx / num_q_and_k_heads;\n    const int64_t head_id = idx % num_q_and_k_heads;\n    const auto load_q = head_id < num_qo_heads;\n    const auto input = load_q ? pointer::offset(q, 2 * (token_id * q_stride + head_id * kHeadDim))\n                              : pointer::offset(k, 2 * (token_id * k_stride + head_id * kHeadDim));\n    const auto weight = load_q ? 
q_weight : k_weight;\n    const auto input_vec = gmem.load(input);\n    const auto weight_vec = gmem.load(weight);\n    const auto output_vec = norm::apply_norm_cta<kHeadDim>(input_vec, weight_vec, eps, smem, kNumWarps);\n    gmem.store(input, output_vec);\n  }\n\n  PDLTriggerSecondary<kUsePDL>();  // launch secondary kernel\n}\n\n// Warp-level kernel struct for head_dim <= 256\ntemplate <int64_t kHeadDim, bool kUsePDL, typename DType>\nstruct QKNormKernelWarp {\n  static_assert(std::is_same_v<DType, fp16_t> || std::is_same_v<DType, bf16_t>);\n  static_assert(!host::norm::should_use_cta<DType, kHeadDim>(), \"Use QKNormKernelCTA for head_dim > 256\");\n  static constexpr auto kernel = fused_qknorm_warp<kHeadDim, kUsePDL, DType>;\n\n  static void\n  run(const tvm::ffi::TensorView q,\n      const tvm::ffi::TensorView k,\n      const tvm::ffi::TensorView q_weight,\n      const tvm::ffi::TensorView k_weight,\n      float eps) {\n    using namespace host;\n\n    auto N = SymbolicSize{\"num_tokens\"};\n    auto Q = SymbolicSize{\"num_qo_heads\"};\n    auto K = SymbolicSize{\"num_kv_heads\"};\n    auto D = SymbolicSize{\"head_dim\"};\n    auto Sq = SymbolicSize{\"q_stride\"};\n    auto Sk = SymbolicSize{\"k_stride\"};\n    auto device = SymbolicDevice{};\n    D.set_value(kHeadDim);\n    device.set_options<kDLCUDA>();\n\n    TensorMatcher({N, Q, D})  // q input\n        .with_strides({Sq, D, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(q);\n    TensorMatcher({N, K, D})  // k input\n        .with_strides({Sk, D, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(k);\n    TensorMatcher({D})  // weight\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(q_weight)\n        .verify(k_weight);\n\n    const auto num_tokens = static_cast<uint32_t>(N.unwrap());\n    const auto num_qo_heads = static_cast<uint32_t>(Q.unwrap());\n    const auto num_kv_heads = 
static_cast<uint32_t>(K.unwrap());\n\n    // NOTE: k is pre-offset by -num_qo_heads * head_dim elements (2 bytes each) so the kernel\n    // can index k with the combined head_id, which saves computation in the kernel\n    const auto params = QKNormParams{\n        .q = q.data_ptr(),\n        .k = pointer::offset(k.data_ptr(), -2 * static_cast<int64_t>(num_qo_heads) * kHeadDim),\n        .q_stride = static_cast<int64_t>(Sq.unwrap()),\n        .k_stride = static_cast<int64_t>(Sk.unwrap()),\n        .num_qo_heads = num_qo_heads,\n        .num_kv_heads = num_kv_heads,\n        .eps = eps,\n        .q_weight = q_weight.data_ptr(),\n        .k_weight = k_weight.data_ptr(),\n        .num_tokens = num_tokens,\n    };\n\n    static const uint32_t max_occupancy = runtime::get_blocks_per_sm(kernel, kThreadsPerBlock);\n    static const uint32_t kNumSM = runtime::get_sm_count(device.unwrap().device_id);\n\n    // each warp handles one (token, head) work item; each block holds kWarpsPerBlock warps\n    const auto num_works = (num_qo_heads + num_kv_heads) * num_tokens;\n    const auto needed_blocks = div_ceil(num_works, kWarpsPerBlock);\n\n    // persistent kernel: cap the grid at (num SMs * max occupancy) blocks to reduce overhead\n    const auto num_blocks = std::min(kNumSM * max_occupancy, needed_blocks);\n    LaunchKernel(num_blocks, kThreadsPerBlock, device.unwrap())  //\n        .enable_pdl(kUsePDL)(kernel, params);\n  }\n};\n\n// CTA-level kernel struct for head_dim > 256; launches fused_qknorm_cta\ntemplate <int64_t kHeadDim, bool kUsePDL, typename DType>\nstruct QKNormKernelCTA {\n  static_assert(std::is_same_v<DType, fp16_t> || std::is_same_v<DType, bf16_t>);\n  static_assert(host::norm::should_use_cta<DType, kHeadDim>(), \"Use QKNormKernelWarp for head_dim <= 256\");\n  static constexpr auto kernel = fused_qknorm_cta<kHeadDim, kUsePDL, DType>;\n  static constexpr auto kNumThreads = host::norm::get_cta_threads<DType, kHeadDim>();\n\n  static void\n  run(const tvm::ffi::TensorView q,\n      const tvm::ffi::TensorView k,\n      const tvm::ffi::TensorView q_weight,\n      const tvm::ffi::TensorView k_weight,\n      float eps) {\n    using namespace host;\n\n    auto N = 
SymbolicSize{\"num_tokens\"};\n    auto Q = SymbolicSize{\"num_qo_heads\"};\n    auto K = SymbolicSize{\"num_kv_heads\"};\n    auto D = SymbolicSize{\"head_dim\"};\n    auto Sq = SymbolicSize{\"q_stride\"};\n    auto Sk = SymbolicSize{\"k_stride\"};\n    auto device = SymbolicDevice{};\n    D.set_value(kHeadDim);\n    device.set_options<kDLCUDA>();\n\n    TensorMatcher({N, Q, D})  // q input\n        .with_strides({Sq, D, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(q);\n    TensorMatcher({N, K, D})  // k input\n        .with_strides({Sk, D, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(k);\n    TensorMatcher({D})  // weight\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(q_weight)\n        .verify(k_weight);\n\n    const auto num_tokens = static_cast<uint32_t>(N.unwrap());\n    const auto num_qo_heads = static_cast<uint32_t>(Q.unwrap());\n    const auto num_kv_heads = static_cast<uint32_t>(K.unwrap());\n\n    // NOTE: k is pre-offset by -num_qo_heads * head_dim elements (2 bytes each) so the kernel\n    // can index k with the combined head_id, which saves computation in the kernel\n    const auto params = QKNormParams{\n        .q = q.data_ptr(),\n        .k = pointer::offset(k.data_ptr(), -2 * static_cast<int64_t>(num_qo_heads) * kHeadDim),\n        .q_stride = static_cast<int64_t>(Sq.unwrap()),\n        .k_stride = static_cast<int64_t>(Sk.unwrap()),\n        .num_qo_heads = num_qo_heads,\n        .num_kv_heads = num_kv_heads,\n        .eps = eps,\n        .q_weight = q_weight.data_ptr(),\n        .k_weight = k_weight.data_ptr(),\n        .num_tokens = num_tokens,\n    };\n\n    static const uint32_t max_occupancy = runtime::get_blocks_per_sm(kernel, kNumThreads);\n    static const uint32_t kNumSM = runtime::get_sm_count(device.unwrap().device_id);\n\n    // each block handles one (token, head) work item at a time\n    const auto num_works = (num_qo_heads + num_kv_heads) * num_tokens;\n\n    // persistent kernel: cap the grid at (num SMs * max occupancy) blocks to reduce overhead\n    const auto num_blocks = std::min<uint32_t>(num_works, 
max_occupancy * kNumSM);\n    LaunchKernel(num_blocks, kNumThreads, device.unwrap())  //\n        .enable_pdl(kUsePDL)(kernel, params);\n  }\n};\n\n// Unified dispatch: select warp or CTA kernel based on head_dim\ntemplate <int64_t kHeadDim, bool kUsePDL, typename DType>\nusing QKNormKernel = std::conditional_t<\n    host::norm::should_use_cta<DType, kHeadDim>(),\n    QKNormKernelCTA<kHeadDim, kUsePDL, DType>,\n    QKNormKernelWarp<kHeadDim, kUsePDL, DType>>;\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/elementwise/qknorm_across_heads.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/runtime.cuh>\n#include <sgl_kernel/tile.cuh>\n#include <sgl_kernel/type.cuh>\n#include <sgl_kernel/utils.cuh>\n#include <sgl_kernel/vec.cuh>\n\n#include <cooperative_groups/reduce.h>\n#include <tvm/ffi/container/tensor.h>\n\n#include <cooperative_groups.h>\n#include <type_traits>\n\nnamespace {\n\ntemplate <typename T, int VEC_SIZE_IN_BYTE>\nstruct VecTypeTrait;\n\ntemplate <>\nstruct VecTypeTrait<bf16_t, 16> {\n  using packed_t = packed_t<bf16_t>;\n  using vec_t = device::AlignedVector<packed_t, 4>;\n};\n\ntemplate <>\nstruct VecTypeTrait<fp16_t, 16> {\n  using packed_t = packed_t<fp16_t>;\n  using vec_t = device::AlignedVector<packed_t, 4>;\n};\n\ntemplate <>\nstruct VecTypeTrait<bf16_t, 32> {\n  using packed_t = packed_t<bf16_t>;\n  using vec_t = device::AlignedVector<packed_t, 8>;\n};\n\ntemplate <>\nstruct VecTypeTrait<fp16_t, 32> {\n  using packed_t = packed_t<fp16_t>;\n  using vec_t = device::AlignedVector<packed_t, 8>;\n};\n\ntemplate <typename packed_t>\nSGL_DEVICE packed_t rms(packed_t& val, packed_t& weight, float rsqrt_square_sum) {\n  float2 valf = device::cast<fp32x2_t, packed_t>(val);\n  float2 weightf = device::cast<fp32x2_t, packed_t>(weight);\n  return device::cast<packed_t, fp32x2_t>(\n      make_float2(valf.x * weightf.x * rsqrt_square_sum, valf.y * weightf.y * rsqrt_square_sum));\n}\n\ntemplate <typename T, int VEC_SIZE_IN_BYTE>\n__global__ void qknorm_across_heads_reg_kernel(\n    T* __restrict__ q,\n    T* __restrict__ k,\n    const T* __restrict__ q_weight,\n    const T* __restrict__ k_weight,\n    int vec_hidden_size,\n    float eps) {\n  constexpr int inner_loop = VEC_SIZE_IN_BYTE == 16 ? 
4 : 8;\n\n  __shared__ float shared_memory[64];  // Used for CTA reduce, store both Q and K rsqrt\n\n  using vec_t = typename VecTypeTrait<T, VEC_SIZE_IN_BYTE>::vec_t;\n  using packed_t = typename VecTypeTrait<T, VEC_SIZE_IN_BYTE>::packed_t;\n  vec_t v_q;         // Save q\n  vec_t v_k;         // Save k\n  vec_t v_q_weight;  // Save q_weight\n  vec_t v_k_weight;  // Save k_weight\n  vec_t v_q_out;     // Save q output\n  vec_t v_k_out;     // Save k output\n\n  auto token_id = blockIdx.x;\n  float2 acc_square_q = make_float2(0.0f, 0.0f);  // Sum of squares for q\n  float2 acc_square_k = make_float2(0.0f, 0.0f);  // Sum of squares for k\n\n  if (threadIdx.x < vec_hidden_size) {\n    // Compute address for q and k\n    vec_t* p_q = reinterpret_cast<vec_t*>(q) + token_id * vec_hidden_size;\n    vec_t* p_k = reinterpret_cast<vec_t*>(k) + token_id * vec_hidden_size;\n    const vec_t* p_q_weight = reinterpret_cast<const vec_t*>(q_weight);\n    const vec_t* p_k_weight = reinterpret_cast<const vec_t*>(k_weight);\n\n    // Load data\n    v_q = p_q[threadIdx.x];\n    v_k = p_k[threadIdx.x];\n    v_q_weight = p_q_weight[threadIdx.x];\n    v_k_weight = p_k_weight[threadIdx.x];\n\n    // Compute sum of squares for q\n    for (int i = 0; i < inner_loop; i++) {\n      float2 val = device::cast<fp32x2_t, packed_t>(v_q[i]);\n      acc_square_q.x += val.x * val.x;\n      acc_square_q.y += val.y * val.y;\n    }\n\n    // Compute sum of squares for k\n    for (int i = 0; i < inner_loop; i++) {\n      float2 val = device::cast<fp32x2_t, packed_t>(v_k[i]);\n      acc_square_k.x += val.x * val.x;\n      acc_square_k.y += val.y * val.y;\n    }\n  }\n\n  auto cg_warp = cooperative_groups::tiled_partition<32>(cooperative_groups::this_thread_block());\n  float* buffer_q = shared_memory;       // [0, 31] for Q\n  float* buffer_k = shared_memory + 32;  // [32, 63] for K\n\n  // ========== Reduction phase: Compute rsqrt for both Q and K ==========\n\n  // Step 0: Warp Reduce for Q\n  float 
warp_sum_q =\n      cooperative_groups::reduce(cg_warp, acc_square_q.x + acc_square_q.y, cooperative_groups::plus<float>());\n  if (threadIdx.x % 32 == 0) {\n    buffer_q[threadIdx.x / 32] = warp_sum_q;\n  }\n\n  // Step 0: Warp Reduce for K\n  float warp_sum_k =\n      cooperative_groups::reduce(cg_warp, acc_square_k.x + acc_square_k.y, cooperative_groups::plus<float>());\n  if (threadIdx.x % 32 == 0) {\n    buffer_k[threadIdx.x / 32] = warp_sum_k;\n  }\n\n  // Step 1: CTA Reduce for both Q and K\n  __syncthreads();\n  if (threadIdx.x < 32) {\n    // CTA Reduce for Q\n    float cta_sum_q = cooperative_groups::reduce(\n        cg_warp, (threadIdx.x < blockDim.x / 32) ? buffer_q[threadIdx.x] : 0.0f, cooperative_groups::plus<float>());\n    buffer_q[threadIdx.x] =\n        rsqrtf(eps + cta_sum_q * (1.0f / static_cast<float>(vec_hidden_size * (VEC_SIZE_IN_BYTE / sizeof(T)))));\n\n    // CTA Reduce for K\n    float cta_sum_k = cooperative_groups::reduce(\n        cg_warp, (threadIdx.x < blockDim.x / 32) ? 
buffer_k[threadIdx.x] : 0.0f, cooperative_groups::plus<float>());\n    buffer_k[threadIdx.x] =\n        rsqrtf(eps + cta_sum_k * (1.0f / static_cast<float>(vec_hidden_size * (VEC_SIZE_IN_BYTE / sizeof(T)))));\n  }\n  __syncthreads();\n\n  // ========== Apply normalization phase: Compute and write back Q and K ==========\n\n  if (threadIdx.x < vec_hidden_size) {\n    // Apply RMSNorm for Q\n    float rsqrt_q = buffer_q[threadIdx.x / 32];\n    for (int i = 0; i < inner_loop; i++) {\n      v_q_out[i] = rms(v_q[i], v_q_weight[i], rsqrt_q);\n    }\n    vec_t* p_q_out = reinterpret_cast<vec_t*>(q) + token_id * vec_hidden_size;\n    p_q_out[threadIdx.x] = v_q_out;\n\n    // Apply RMSNorm for K\n    float rsqrt_k = buffer_k[threadIdx.x / 32];\n    for (int i = 0; i < inner_loop; i++) {\n      v_k_out[i] = rms(v_k[i], v_k_weight[i], rsqrt_k);\n    }\n    vec_t* p_k_out = reinterpret_cast<vec_t*>(k) + token_id * vec_hidden_size;\n    p_k_out[threadIdx.x] = v_k_out;\n  }\n}\n\ntemplate <typename DType>\nstruct QKNormAcrossHeadsKernel {\n  static void\n  run(const tvm::ffi::TensorView q,\n      const tvm::ffi::TensorView k,\n      const tvm::ffi::TensorView q_weight,\n      const tvm::ffi::TensorView k_weight,\n      float eps) {\n    using namespace host;\n    auto N = SymbolicSize{\"num_tokens\"};\n    auto D = SymbolicSize{\"hidden_size\"};\n    auto device = SymbolicDevice{};\n    device.set_options<kDLCUDA>();\n\n    TensorMatcher({N, D})  // q\n        .with_strides({D, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(q);\n    TensorMatcher({N, D})  // k\n        .with_strides({D, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(k);\n    TensorMatcher({D})  // q_weight\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(q_weight);\n    TensorMatcher({D})  // k_weight\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(k_weight);\n\n    auto cc_major = 
host::runtime::get_cc_major(device.unwrap().device_id);\n    int hidden_size = static_cast<int>(D.unwrap());\n    if ((cc_major <= 9 && hidden_size <= 8192) || (cc_major >= 10 && hidden_size <= 12288)) {\n      int max_vec_size_byte = cc_major >= 10 ? 32 : 16;\n      int elements_in_vec = max_vec_size_byte / sizeof(DType);\n      int vec_hidden_size = hidden_size / elements_in_vec;\n      uint threads = (vec_hidden_size + 31) / 32 * 32;\n\n      // hidden_size must be a multiple of the vector width\n      host::RuntimeCheck(\n          hidden_size % elements_in_vec == 0,\n          \"hidden_size \",\n          hidden_size,\n          \" is not divisible by elements_in_vec \",\n          elements_in_vec);\n\n      // Launch a single kernel for both q and k\n      auto kernel = max_vec_size_byte == 32 ? qknorm_across_heads_reg_kernel<DType, 32>\n                                            : qknorm_across_heads_reg_kernel<DType, 16>;\n\n      LaunchKernel(static_cast<uint>(N.unwrap()), threads, device.unwrap())\n          .enable_pdl(false)(\n              kernel,\n              reinterpret_cast<DType*>(q.data_ptr()),\n              reinterpret_cast<DType*>(k.data_ptr()),\n              reinterpret_cast<DType*>(q_weight.data_ptr()),\n              reinterpret_cast<DType*>(k_weight.data_ptr()),\n              vec_hidden_size,\n              eps);\n    } else {\n      host::RuntimeCheck(\n          false,\n          \"hidden_size \",\n          hidden_size,\n          \" is not supported (max 8192 for compute capability <= 9.x, 12288 for >= 10.x)\");\n    }\n  }\n};\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/runtime.cuh>\n#include <sgl_kernel/tile.cuh>\n#include <sgl_kernel/utils.cuh>\n\n#include <sgl_kernel/impl/norm.cuh>\n\n#include <tvm/ffi/container/tensor.h>\n\nnamespace {\n\nstruct RMSNormParams {\n  const void* input;\n  const void* __restrict__ weight;\n  void* output;\n  int64_t input_stride;\n  int64_t output_stride;\n  uint32_t num_tokens;\n  float eps;\n};\n\ntemplate <int64_t kDim, bool kUsePDL, typename Float>\n__global__ void rmsnorm_cta(const RMSNormParams __grid_constant__ params) {\n  using namespace device;\n  using Storage = norm::StorageType<Float, kDim>;\n\n  constexpr auto kNumThreads = host::norm::get_cta_threads<Float, kDim>();\n  constexpr auto kNumWarps = kNumThreads / kWarpThreads;\n\n  const auto& [input, weight_ptr, output, input_stride, output_stride, num_tokens, eps] = params;\n  const auto gmem = tile::Memory<Storage>::cta(kNumThreads);\n  __shared__ float smem[norm::kSmemBufferSize];\n\n  PDLWaitPrimary<kUsePDL>();  // wait for primary kernel\n\n  void* output_ptr = nullptr;\n  Storage output_vec;\n  for (uint32_t i = blockIdx.x; i < num_tokens; i += gridDim.x) {\n    const auto input_ptr = pointer::offset<Float>(input, i * input_stride);\n    const auto input_vec = gmem.load(input_ptr);\n    const auto weight_vec = gmem.load(weight_ptr);\n    if (output_ptr != nullptr) {\n      gmem.store(output_ptr, output_vec);\n    }\n    output_ptr = pointer::offset<Float>(output, i * output_stride);\n    output_vec = norm::apply_norm_cta<kDim>(input_vec, weight_vec, eps, smem, kNumWarps);\n  }\n  gmem.store(output_ptr, output_vec);\n\n  PDLTriggerSecondary<kUsePDL>();  // launch secondary kernel\n}\n\ntemplate <int64_t kDim, bool kUsePDL, typename DType>\nstruct RMSNormKernel {\n  static_assert(host::norm::should_use_cta<DType, kDim>(), \"Hidden size invalid for RMSNorm\");\n  static constexpr auto kernel = rmsnorm_cta<kDim, kUsePDL, DType>;\n\n  
static void\n  run(const tvm::ffi::TensorView input,\n      const tvm::ffi::TensorView weight,\n      const tvm::ffi::TensorView output,\n      float eps) {\n    using namespace host;\n    auto N = SymbolicSize{\"num_tokens\"};\n    auto D = SymbolicSize{\"hidden_size\"};\n    auto SI = SymbolicSize{\"input_stride\"};\n    auto SO = SymbolicSize{\"output_stride\"};\n    auto device = SymbolicDevice{};\n    D.set_value(kDim);\n    device.set_options<kDLCUDA>();\n\n    TensorMatcher({N, D})  // input\n        .with_strides({SI, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(input);\n    TensorMatcher({D})  // weight\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(weight);\n    TensorMatcher({N, D})  // output\n        .with_strides({SO, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(output);\n\n    const auto num_tokens = static_cast<uint32_t>(N.unwrap());\n    const auto params = RMSNormParams{\n        .input = input.data_ptr(),\n        .weight = weight.data_ptr(),\n        .output = output.data_ptr(),\n        .input_stride = SI.unwrap(),\n        .output_stride = SO.unwrap(),\n        .num_tokens = num_tokens,\n        .eps = eps,\n    };\n\n    static constexpr auto kNumThreads = norm::get_cta_threads<DType, kDim>();\n    static const uint32_t max_occupancy = runtime::get_blocks_per_sm(kernel, kNumThreads);\n    static const uint32_t kNumSM = runtime::get_sm_count(device.unwrap().device_id);\n    const auto num_blocks = std::min<uint32_t>(num_tokens, max_occupancy * kNumSM);\n    LaunchKernel(num_blocks, kNumThreads, device.unwrap())  //\n        .enable_pdl(kUsePDL)(kernel, params);\n  }\n};\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/elementwise/rope.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/runtime.cuh>\n#include <sgl_kernel/tile.cuh>\n#include <sgl_kernel/type.cuh>\n#include <sgl_kernel/utils.cuh>\n#include <sgl_kernel/vec.cuh>\n\n#include <dlpack/dlpack.h>\n\n#include <numeric>\n\nnamespace {\n\nstruct FusedRopeParams {\n  void* __restrict__ q_ptr;\n  void* __restrict__ k_ptr;  // NOTE: this k is pre-offset in host code to reduce computation in kernel\n  const void* __restrict__ cos_sin_cache_ptr;\n  const void* __restrict__ positions;\n  int64_t q_stride_bytes;\n  int64_t k_stride_bytes;\n  int64_t head_stride_bytes;\n  uint32_t num_qo_heads;\n  uint32_t num_kv_heads;\n  uint32_t num_tokens;\n};\n\nstruct FusedRopeStoreParams {\n  FusedRopeParams base_params;\n  void* v_ptr;\n  void* __restrict__ k_cache;\n  void* __restrict__ v_cache;\n  const void* __restrict__ out_loc;\n  int64_t v_stride_bytes;\n  int64_t cache_stride_bytes;\n};\n\nconstexpr uint32_t kBlockSize = 128;\n\n[[maybe_unused]]\nconstexpr auto next_pow2(uint32_t target, uint32_t factor = 1) {\n  uint32_t power = 1;\n  while (power * factor < target)\n    power *= 2;\n  return power;\n}\n\ntemplate <bool kIsNeox, int64_t kRopeDim, bool kUsePDL, typename DType, typename IdType, uint32_t kWorkThreads>\n__global__ void fused_rope_kernel(const __grid_constant__ FusedRopeParams params) {\n  using namespace device;\n\n  constexpr int64_t kCosSinStrideBytes = kRopeDim * sizeof(float);\n  constexpr int64_t kVecSize = next_pow2(kRopeDim, (2 * kWorkThreads * (1 + kIsNeox)));\n  using DType2 = packed_t<DType>;\n  using InputStorage = AlignedVector<DType2, kVecSize>;\n  constexpr int64_t kDimPerThread = kVecSize * 2 * (1 + kIsNeox);\n  constexpr uint32_t kLaneCount = kRopeDim / kDimPerThread;\n  static_assert(kRopeDim % kDimPerThread == 0 && kLaneCount <= kWorkThreads);\n\n  const auto &[\n    q, k, cos_sin_cache_ptr, positions, // pointers\n    q_stride_bytes, k_stride_bytes, head_stride_bytes,  
// strides\n    num_qo_heads, num_kv_heads, num_tokens // dimensions\n  ] = params;\n\n  const auto num_blks = gridDim.x;\n  constexpr auto kWorkersPerBlock = kBlockSize / kWorkThreads;\n  const auto num_workers = num_blks * kWorkersPerBlock;\n  const auto num_q_and_k_heads = num_qo_heads + num_kv_heads;\n  const auto num_works = num_q_and_k_heads * num_tokens;\n  const auto start_worker_id = (blockIdx.x * kBlockSize + threadIdx.x) / kWorkThreads;\n  const auto cos_cache_ptr = cos_sin_cache_ptr;\n  const auto sin_cache_ptr = pointer::offset(cos_sin_cache_ptr, kCosSinStrideBytes / 2);\n\n  uint32_t lane_id = threadIdx.x % kWorkThreads;\n  if constexpr (kLaneCount < kWorkThreads) {\n    if (lane_id >= kLaneCount) return;\n  }\n\n  PDLWaitPrimary<kUsePDL>();\n\n  for (auto idx = start_worker_id; idx < num_works; idx += num_workers) {\n    const int64_t token_id = idx / num_q_and_k_heads;\n    const int64_t head_id = idx % num_q_and_k_heads;\n    const auto pos = static_cast<const IdType*>(positions)[token_id];\n    const auto load_q = head_id < num_qo_heads;\n    const auto input_ = load_q ? 
pointer::offset(q, token_id * q_stride_bytes)  //\n                               : pointer::offset(k, token_id * k_stride_bytes);\n    const auto input = pointer::offset(input_, head_id * head_stride_bytes);\n    const auto cos_ptr = pointer::offset(cos_cache_ptr, pos * kCosSinStrideBytes);\n    const auto sin_ptr = pointer::offset(sin_cache_ptr, pos * kCosSinStrideBytes);\n    if constexpr (kIsNeox) {\n      using CacheStorage = AlignedVector<fp32x2_t, kVecSize>;\n      const auto input_x = input;\n      const auto input_y = pointer::offset(input, (kRopeDim / 2) * sizeof(DType));\n      auto input_vec_x = load_as<InputStorage>(input_x, lane_id);\n      auto input_vec_y = load_as<InputStorage>(input_y, lane_id);\n      const auto cos_pair = load_as<CacheStorage>(cos_ptr, lane_id);\n      const auto sin_pair = load_as<CacheStorage>(sin_ptr, lane_id);\n#pragma unroll\n      for (int64_t j = 0; j < kVecSize; ++j) {\n        const auto [x0, x1] = cast<fp32x2_t>(input_vec_x[j]);\n        const auto [y0, y1] = cast<fp32x2_t>(input_vec_y[j]);\n        const auto [cos_0, cos_1] = cos_pair[j];\n        const auto [sin_0, sin_1] = sin_pair[j];\n        const auto out_x0 = x0 * cos_0 - y0 * sin_0;\n        const auto out_y0 = x0 * sin_0 + y0 * cos_0;\n        const auto out_x1 = x1 * cos_1 - y1 * sin_1;\n        const auto out_y1 = x1 * sin_1 + y1 * cos_1;\n        input_vec_x[j] = cast<DType2, fp32x2_t>({out_x0, out_x1});\n        input_vec_y[j] = cast<DType2, fp32x2_t>({out_y0, out_y1});\n      }\n      store_as<InputStorage>(input_x, input_vec_x, lane_id);\n      store_as<InputStorage>(input_y, input_vec_y, lane_id);\n    } else {\n      using CacheStorage = AlignedVector<float, kVecSize>;\n      auto input_vec = load_as<InputStorage>(input, lane_id);\n      const auto cos_vec = load_as<CacheStorage>(cos_ptr, lane_id);\n      const auto sin_vec = load_as<CacheStorage>(sin_ptr, lane_id);\n#pragma unroll\n      for (int64_t j = 0; j < kVecSize; ++j) {\n        const auto 
[x, y] = cast<fp32x2_t>(input_vec[j]);\n        const auto cos = cos_vec[j];\n        const auto sin = sin_vec[j];\n        const auto out_x = x * cos - y * sin;\n        const auto out_y = x * sin + y * cos;\n        input_vec[j] = cast<DType2, fp32x2_t>({out_x, out_y});\n      }\n      store_as<InputStorage>(input, input_vec, lane_id);\n    }\n  }\n\n  PDLTriggerSecondary<kUsePDL>();\n}\n\ntemplate <bool kIsNeox, int64_t kRopeDim, bool kUsePDL, typename DType, typename IdType, uint32_t kWorkThreads>\n__global__ void fused_rope_store_kernel(const __grid_constant__ FusedRopeStoreParams params) {\n  using namespace device;\n\n  constexpr int64_t kCosSinStrideBytes = kRopeDim * sizeof(float);\n  constexpr int64_t kVecSize = kRopeDim / (2 * kWorkThreads * (1 + kIsNeox));\n  using DType2 = packed_t<DType>;\n  using InputStorage = AlignedVector<DType2, kVecSize>;\n  constexpr int64_t kDimPerThread = kVecSize * 2 * (1 + kIsNeox);\n  static_assert(kRopeDim == kDimPerThread * kWorkThreads);\n\n  const auto& [base_params, v_ptr, k_cache, v_cache, out_loc, v_stride_bytes, cache_stride_bytes] = params;\n  const auto &[\n    q, k, cos_sin_cache_ptr, positions, // pointers\n    q_stride_bytes, k_stride_bytes, head_stride_bytes,  // strides\n    num_qo_heads, num_kv_heads, num_tokens // dimensions\n  ] = base_params;\n\n  const auto num_blks = gridDim.x;\n  constexpr auto kWorkersPerBlock = kBlockSize / kWorkThreads;\n  const auto num_workers = num_blks * kWorkersPerBlock;\n  const auto num_q_and_k_heads = num_qo_heads + num_kv_heads;\n  const auto num_works = num_q_and_k_heads * num_tokens;\n  const auto num_extra_works = num_kv_heads * num_tokens;  // v store works only (one per kv head per token)\n  const auto start_worker_id = (blockIdx.x * kBlockSize + threadIdx.x) / kWorkThreads;\n  const auto lane_id = threadIdx.x % kWorkThreads;\n  const auto cos_cache_ptr = cos_sin_cache_ptr;\n  const auto sin_cache_ptr = pointer::offset(cos_sin_cache_ptr, kCosSinStrideBytes / 2);\n\n  auto idx = 
start_worker_id;\n\n  PDLWaitPrimary<kUsePDL>();\n  // in this case, head_dim = rope_dim must be true\n  __builtin_assume(head_stride_bytes == kRopeDim * sizeof(DType));\n\n  for (; idx < num_works; idx += num_workers) {\n    const int64_t token_id = idx / num_q_and_k_heads;\n    const int64_t head_id = idx % num_q_and_k_heads;\n    const auto pos = static_cast<const IdType*>(positions)[token_id];\n    const auto loc = static_cast<const IdType*>(out_loc)[token_id];\n    const auto load_q = head_id < num_qo_heads;\n    const auto input_ = load_q ? pointer::offset(q, token_id * q_stride_bytes)  //\n                               : pointer::offset(k, token_id * k_stride_bytes);\n    const auto input = pointer::offset(input_, head_id * head_stride_bytes);\n    const auto cos_ptr = pointer::offset(cos_cache_ptr, pos * kCosSinStrideBytes);\n    const auto sin_ptr = pointer::offset(sin_cache_ptr, pos * kCosSinStrideBytes);\n    if constexpr (kIsNeox) {\n      using CacheStorage = AlignedVector<fp32x2_t, kVecSize>;\n      const auto input_x = input;\n      const auto input_y = pointer::offset(input, (kRopeDim / 2) * sizeof(DType));\n      auto input_vec_x = load_as<InputStorage>(input_x, lane_id);\n      auto input_vec_y = load_as<InputStorage>(input_y, lane_id);\n      const auto cos_pair = load_as<CacheStorage>(cos_ptr, lane_id);\n      const auto sin_pair = load_as<CacheStorage>(sin_ptr, lane_id);\n#pragma unroll\n      for (int64_t j = 0; j < kVecSize; ++j) {\n        const auto [x0, x1] = cast<fp32x2_t>(input_vec_x[j]);\n        const auto [y0, y1] = cast<fp32x2_t>(input_vec_y[j]);\n        const auto [cos_0, cos_1] = cos_pair[j];\n        const auto [sin_0, sin_1] = sin_pair[j];\n        const auto out_x0 = x0 * cos_0 - y0 * sin_0;\n        const auto out_y0 = x0 * sin_0 + y0 * cos_0;\n        const auto out_x1 = x1 * cos_1 - y1 * sin_1;\n        const auto out_y1 = x1 * sin_1 + y1 * cos_1;\n        input_vec_x[j] = cast<DType2, fp32x2_t>({out_x0, out_x1});\n        
input_vec_y[j] = cast<DType2, fp32x2_t>({out_y0, out_y1});\n      }\n      store_as<InputStorage>(input, input_vec_x, lane_id);\n      const auto input_y_out = pointer::offset(input, (kRopeDim / 2) * sizeof(DType));\n      store_as<InputStorage>(input_y_out, input_vec_y, lane_id);\n      if (!load_q) {\n        const auto k_out = pointer::offset(k_cache, loc * cache_stride_bytes, head_id * head_stride_bytes);\n        store_as<InputStorage>(k_out, input_vec_x, lane_id);\n        const auto k_out_y = pointer::offset(k_out, (kRopeDim / 2) * sizeof(DType));\n        store_as<InputStorage>(k_out_y, input_vec_y, lane_id);\n      }\n    } else {\n      using CacheStorage = AlignedVector<float, kVecSize>;\n      auto input_vec = load_as<InputStorage>(input, lane_id);\n      const auto cos_vec = load_as<CacheStorage>(cos_ptr, lane_id);\n      const auto sin_vec = load_as<CacheStorage>(sin_ptr, lane_id);\n#pragma unroll\n      for (int64_t j = 0; j < kVecSize; ++j) {\n        const auto [x, y] = cast<fp32x2_t>(input_vec[j]);\n        const auto cos = cos_vec[j];\n        const auto sin = sin_vec[j];\n        const auto out_x = x * cos - y * sin;\n        const auto out_y = x * sin + y * cos;\n        input_vec[j] = cast<DType2, fp32x2_t>({out_x, out_y});\n      }\n      store_as<InputStorage>(input, input_vec, lane_id);\n      if (!load_q) {\n        const auto k_out = pointer::offset(k_cache, loc * cache_stride_bytes, head_id * head_stride_bytes);\n        store_as<InputStorage>(k_out, input_vec, lane_id);\n      }\n    }\n  }\n\n  __syncwarp();  // to avoid warp divergence\n  idx -= num_works;\n  for (; idx < num_extra_works; idx += num_workers) {\n    using VStorage = AlignedVector<DType, kRopeDim / kWorkThreads>;\n    const int64_t token_id = idx / num_kv_heads;\n    const int64_t head_id = idx % num_kv_heads;\n    const auto loc = static_cast<const IdType*>(out_loc)[token_id];\n    const auto input = pointer::offset(v_ptr, token_id * v_stride_bytes, head_id * 
head_stride_bytes);\n    const auto input_vec = load_as<VStorage>(input, lane_id);\n    const auto output = pointer::offset(v_cache, loc * cache_stride_bytes, head_id * head_stride_bytes);\n    store_as<VStorage>(output, input_vec, lane_id);\n  }\n  PDLTriggerSecondary<kUsePDL>();\n}\n\ntemplate <bool kIsNeox, int64_t kRopeDim, bool kUsePDL, typename DType>\nstruct FusedRopeKernel {\n  static constexpr uint32_t kDimPerThread = std::gcd(16 / sizeof(DType), kRopeDim);\n  static constexpr uint32_t kWorkThreads = next_pow2(kRopeDim, kDimPerThread);\n  static constexpr bool kSupportFused = kWorkThreads * kDimPerThread == kRopeDim;\n  static_assert(kRopeDim % kDimPerThread == 0);\n  static_assert(kBlockSize % kWorkThreads == 0);\n\n  template <typename IdType>\n  static constexpr auto _kernel_0 = fused_rope_kernel<kIsNeox, kRopeDim, kUsePDL, DType, IdType, kWorkThreads>;\n  template <typename IdType>\n  static constexpr auto _kernel_1 = fused_rope_store_kernel<kIsNeox, kRopeDim, kUsePDL, DType, IdType, kWorkThreads>;\n\n  static auto get_num_sm(DLDevice device) {\n    static const auto kNumSM = host::runtime::get_sm_count(device.device_id);\n    return kNumSM;\n  }\n\n  static void\n  run(const tvm::ffi::TensorView q,\n      const tvm::ffi::TensorView k,\n      const tvm::ffi::TensorView cos_sin_cache,\n      const tvm::ffi::TensorView positions) {\n    using namespace host;\n    auto N = SymbolicSize{\"num_tokens\"};\n    auto Q = SymbolicSize{\"num_qo_heads\"};\n    auto K = SymbolicSize{\"num_kv_heads\"};\n    auto D = SymbolicSize{\"rope_dim\"};\n    auto Dq = SymbolicSize{\"q_stride\"};\n    auto Dk = SymbolicSize{\"k_stride\"};\n    auto Dd = SymbolicSize{\"head_stride\"};\n    auto device = SymbolicDevice{};\n    auto id_type = SymbolicDType{};\n    D.set_value(kRopeDim);\n    device.set_options<kDLCUDA>();\n    TensorMatcher({N, Q, D})  // q input\n        .with_strides({Dq, Dd, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        
.verify(q);\n    TensorMatcher({N, K, D})  // k input\n        .with_strides({Dk, Dd, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(k);\n    TensorMatcher({-1, D})  // cos_sin_cache\n        .with_dtype<float>()\n        .with_device(device)\n        .verify(cos_sin_cache);\n    TensorMatcher({N})  // positions\n        .with_dtype<int32_t, int64_t>(id_type)\n        .with_device(device)\n        .verify(positions);\n\n    const auto num_tokens = static_cast<uint32_t>(N.unwrap());\n    const auto num_qo_heads = static_cast<uint32_t>(Q.unwrap());\n    const auto num_kv_heads = static_cast<uint32_t>(K.unwrap());\n    const auto q_stride_bytes = static_cast<int64_t>(Dq.unwrap() * sizeof(DType));\n    const auto k_stride_bytes = static_cast<int64_t>(Dk.unwrap() * sizeof(DType));\n    const auto head_stride_bytes = static_cast<int64_t>(Dd.unwrap() * sizeof(DType));\n\n    // NOTE: we offset the k here to reduce computation cost in the kernel\n    const int64_t k_offset = static_cast<int64_t>(num_qo_heads) * head_stride_bytes;\n    const auto params = FusedRopeParams{\n        .q_ptr = q.data_ptr(),\n        .k_ptr = pointer::offset(k.data_ptr(), -k_offset),\n        .cos_sin_cache_ptr = cos_sin_cache.data_ptr(),\n        .positions = positions.data_ptr(),\n        .q_stride_bytes = q_stride_bytes,\n        .k_stride_bytes = k_stride_bytes,\n        .head_stride_bytes = head_stride_bytes,\n        .num_qo_heads = num_qo_heads,\n        .num_kv_heads = num_kv_heads,\n        .num_tokens = num_tokens,\n    };\n\n    const auto is_int32 = id_type.is_type<int32_t>();\n    const auto kernel = is_int32 ? 
_kernel_0<int32_t> : _kernel_0<int64_t>;\n    const uint32_t kNumSM = get_num_sm(device.unwrap());\n    static const uint32_t kOccupancyTable[2] = {\n        runtime::get_blocks_per_sm(_kernel_0<int32_t>, kBlockSize),\n        runtime::get_blocks_per_sm(_kernel_0<int64_t>, kBlockSize),\n    };\n    const auto max_blocks = kOccupancyTable[is_int32 ? 0 : 1] * kNumSM;\n    const auto num_works = (num_qo_heads + num_kv_heads) * num_tokens;\n    const auto needed_blocks = div_ceil(num_works, (kBlockSize / kWorkThreads));\n    const auto num_blocks = std::min(max_blocks, needed_blocks);\n    LaunchKernel(num_blocks, kBlockSize, device.unwrap())  //\n        .enable_pdl(kUsePDL)(kernel, params);\n  }\n\n  static void run_fused(\n      const tvm::ffi::TensorView q,\n      const tvm::ffi::TensorView k,\n      const tvm::ffi::TensorView v,\n      const tvm::ffi::TensorView k_cache,\n      const tvm::ffi::TensorView v_cache,\n      const tvm::ffi::TensorView cos_sin_cache,\n      const tvm::ffi::TensorView positions,\n      const tvm::ffi::TensorView out_loc) {\n    if constexpr (kSupportFused) {\n      return _run_fused_impl(q, k, v, k_cache, v_cache, cos_sin_cache, positions, out_loc);\n    } else {\n      host::Panic(\"Fused rope + store is not supported for rope_dim \", kRopeDim);\n    }\n  }\n\n  static void _run_fused_impl(\n      const tvm::ffi::TensorView q,\n      const tvm::ffi::TensorView k,\n      const tvm::ffi::TensorView v,\n      const tvm::ffi::TensorView k_cache,\n      const tvm::ffi::TensorView v_cache,\n      const tvm::ffi::TensorView cos_sin_cache,\n      const tvm::ffi::TensorView positions,\n      const tvm::ffi::TensorView out_loc) {\n    using namespace host;\n\n    auto N = SymbolicSize{\"num_tokens\"};\n    auto Q = SymbolicSize{\"num_qo_heads\"};\n    auto K = SymbolicSize{\"num_kv_heads\"};\n    auto D = SymbolicSize{\"rope_dim\"};\n    auto R = SymbolicSize{\"row_size\"};\n    auto Dq = SymbolicSize{\"q_stride\"};\n    auto Dk = 
SymbolicSize{\"k_stride\"};\n    auto Dv = SymbolicSize{\"v_stride\"};\n    auto Dd = SymbolicSize{\"head_stride\"};\n    auto Dc = SymbolicSize{\"cache_stride\"};\n    auto device = SymbolicDevice{};\n    auto id_type = SymbolicDType{};\n    D.set_value(kRopeDim);\n    device.set_options<kDLCUDA>();\n\n    TensorMatcher({N, Q, D})  // q input\n        .with_strides({Dq, Dd, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(q);\n    TensorMatcher({N, K, D})  // k input\n        .with_strides({Dk, Dd, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(k);\n    TensorMatcher({N, K, D})  // v input\n        .with_strides({Dv, Dd, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(v);\n    TensorMatcher({-1, D})  // cos_sin_cache\n        .with_dtype<float>()\n        .with_device(device)\n        .verify(cos_sin_cache);\n    TensorMatcher({N})  // positions, out_loc\n        .with_dtype<int32_t, int64_t>(id_type)\n        .with_device(device)\n        .verify(positions)\n        .verify(out_loc);\n    TensorMatcher({-1, R})  // k_cache, v_cache\n        .with_strides({Dc, 1})\n        .with_dtype<DType>()\n        .with_device(device)\n        .verify(k_cache)\n        .verify(v_cache);\n\n    const auto num_tokens = static_cast<uint32_t>(N.unwrap());\n    const auto num_qo_heads = static_cast<uint32_t>(Q.unwrap());\n    const auto num_kv_heads = static_cast<uint32_t>(K.unwrap());\n    const auto q_stride_bytes = static_cast<int64_t>(Dq.unwrap() * sizeof(DType));\n    const auto k_stride_bytes = static_cast<int64_t>(Dk.unwrap() * sizeof(DType));\n    const auto head_stride = Dd.unwrap();\n    const auto row_dim = R.unwrap();\n    const auto head_stride_bytes = static_cast<int64_t>(Dd.unwrap() * sizeof(DType));\n\n    RuntimeCheck(kRopeDim == head_stride, \"rope_dim \", kRopeDim, \" must equal head_stride \", head_stride);\n    RuntimeCheck(num_kv_heads * kRopeDim == row_dim, \"invalid 
kvcache\");\n\n    // NOTE: we offset the k here to reduce computation cost in the kernel\n    const int64_t k_offset = static_cast<int64_t>(num_qo_heads) * head_stride_bytes;\n    const auto params = FusedRopeParams{\n        .q_ptr = q.data_ptr(),\n        .k_ptr = pointer::offset(k.data_ptr(), -k_offset),\n        .cos_sin_cache_ptr = cos_sin_cache.data_ptr(),\n        .positions = positions.data_ptr(),\n        .q_stride_bytes = q_stride_bytes,\n        .k_stride_bytes = k_stride_bytes,\n        .head_stride_bytes = head_stride_bytes,\n        .num_qo_heads = num_qo_heads,\n        .num_kv_heads = num_kv_heads,\n        .num_tokens = num_tokens,\n    };\n\n    const auto v_stride_bytes = static_cast<int64_t>(Dv.unwrap() * sizeof(DType));\n    const auto cache_stride_bytes = static_cast<int64_t>(Dc.unwrap() * sizeof(DType));\n    const auto store_params = FusedRopeStoreParams{\n        .base_params = params,\n        .v_ptr = v.data_ptr(),\n        .k_cache = pointer::offset(k_cache.data_ptr(), -k_offset),\n        .v_cache = v_cache.data_ptr(),\n        .out_loc = out_loc.data_ptr(),\n        .v_stride_bytes = v_stride_bytes,\n        .cache_stride_bytes = cache_stride_bytes,\n    };\n\n    const auto is_int32 = id_type.is_type<int32_t>();\n    const auto kernel = is_int32 ? _kernel_1<int32_t> : _kernel_1<int64_t>;\n    const uint32_t kNumSM = get_num_sm(device.unwrap());\n    static const uint32_t kOccupancyTable[2] = {\n        runtime::get_blocks_per_sm(_kernel_1<int32_t>, kBlockSize),\n        runtime::get_blocks_per_sm(_kernel_1<int64_t>, kBlockSize),\n    };\n    const auto max_blocks = kOccupancyTable[is_int32 ? 
0 : 1] * kNumSM;\n    // rope works for q+k heads, plus v store works for kv heads\n    const auto num_total_works = (num_qo_heads + 2 * num_kv_heads) * num_tokens;\n    const auto needed_blocks = div_ceil(num_total_works, (kBlockSize / kWorkThreads));\n    const auto num_blocks = std::min(max_blocks, needed_blocks);\n    LaunchKernel(num_blocks, kBlockSize, device.unwrap())  //\n        .enable_pdl(kUsePDL)(kernel, store_params);\n  }\n};\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/fast-hadamard-transform/code_gen.py",
    "content": "from pathlib import Path\n\nimport numpy as np\n\n# From https://en.wikipedia.org/wiki/Paley_construction (construction II for q = 5)\n\nhad_12_paley = \"\"\"\n+-++++++++++\n--+-+-+-+-+-\n+++-++----++\n+---+--+-++-\n+++++-++----\n+-+---+--+-+\n++--+++-++--\n+--++---+--+\n++----+++-++\n+--+-++---+-\n++++----+++-\n+-+--+-++---\n\"\"\"\n\n# From http://neilsloane.com/hadamard/\n\nhad_12 = \"\"\"\n+-----------\n++-+---+++-+\n+++-+---+++-\n+-++-+---+++\n++-++-+---++\n+++-++-+---+\n++++-++-+---\n+-+++-++-+--\n+--+++-++-+-\n+---+++-++-+\n++---+++-++-\n+-+---+++-++\n\"\"\"\n\nhad_20_will = \"\"\"\n+----+----++--++-++-\n-+----+---+++---+-++\n--+----+---+++-+-+-+\n---+----+---+++++-+-\n----+----++--++-++-+\n-+++++-----+--+++--+\n+-+++-+---+-+--+++--\n++-++--+---+-+--+++-\n+++-+---+---+-+--+++\n++++-----++--+-+--++\n--++-+-++-+-----++++\n---++-+-++-+---+-+++\n+---++-+-+--+--++-++\n++---++-+----+-+++-+\n-++---++-+----+++++-\n-+--+--++-+----+----\n+-+-----++-+----+---\n-+-+-+---+--+----+--\n--+-+++------+----+-\n+--+--++------+----+\n\"\"\"\n\n\nhad_28_will = \"\"\"\n+------++----++-+--+-+--++--\n-+-----+++-----+-+--+-+--++-\n--+-----+++---+-+-+----+--++\n---+-----+++---+-+-+-+--+--+\n----+-----+++---+-+-+++--+--\n-----+-----++++--+-+--++--+-\n------++----++-+--+-+--++--+\n--++++-+-------++--+++-+--+-\n---++++-+-----+-++--+-+-+--+\n+---+++--+----++-++--+-+-+--\n++---++---+----++-++--+-+-+-\n+++---+----+----++-++--+-+-+\n++++--------+-+--++-++--+-+-\n-++++--------+++--++--+--+-+\n-+-++-++--++--+--------++++-\n+-+-++--+--++--+--------++++\n-+-+-++--+--++--+----+---+++\n+-+-+-++--+--+---+---++---++\n++-+-+-++--+------+--+++---+\n-++-+-+-++--+------+-++++---\n+-++-+---++--+------+-++++--\n-++--++-+-++-+++----++------\n+-++--++-+-++-+++-----+-----\n++-++---+-+-++-+++-----+----\n-++-++-+-+-+-+--+++-----+---\n--++-++++-+-+----+++-----+--\n+--++-+-++-+-+----+++-----+-\n++--++-+-++-+-+----++------+\n\"\"\"\n\n\nhad_40_tpal = 
\"\"\"\n+-------------------+-------------------\n++-++----+-+-++++--+++-++----+-+-++++--+\n+++-++----+-+-++++--+++-++----+-+-++++--\n+-++-++----+-+-++++-+-++-++----+-+-++++-\n+--++-++----+-+-+++++--++-++----+-+-++++\n++--++-++----+-+-+++++--++-++----+-+-+++\n+++--++-++----+-+-+++++--++-++----+-+-++\n++++--++-++----+-+-+++++--++-++----+-+-+\n+++++--++-++----+-+-+++++--++-++----+-+-\n+-++++--++-++----+-++-++++--++-++----+-+\n++-++++--++-++----+-++-++++--++-++----+-\n+-+-++++--++-++----++-+-++++--++-++----+\n++-+-++++--++-++----++-+-++++--++-++----\n+-+-+-++++--++-++---+-+-+-++++--++-++---\n+--+-+-++++--++-++--+--+-+-++++--++-++--\n+---+-+-++++--++-++-+---+-+-++++--++-++-\n+----+-+-++++--++-+++----+-+-++++--++-++\n++----+-+-++++--++-+++----+-+-++++--++-+\n+++----+-+-++++--++-+++----+-+-++++--++-\n+-++----+-+-++++--+++-++----+-+-++++--++\n+--------------------+++++++++++++++++++\n++-++----+-+-++++--+--+--++++-+-+----++-\n+++-++----+-+-++++-----+--++++-+-+----++\n+-++-++----+-+-++++--+--+--++++-+-+----+\n+--++-++----+-+-++++-++--+--++++-+-+----\n++--++-++----+-+-+++--++--+--++++-+-+---\n+++--++-++----+-+-++---++--+--++++-+-+--\n++++--++-++----+-+-+----++--+--++++-+-+-\n+++++--++-++----+-+------++--+--++++-+-+\n+-++++--++-++----+-+-+----++--+--++++-+-\n++-++++--++-++----+---+----++--+--++++-+\n+-+-++++--++-++----+-+-+----++--+--++++-\n++-+-++++--++-++------+-+----++--+--++++\n+-+-+-++++--++-++----+-+-+----++--+--+++\n+--+-+-++++--++-++---++-+-+----++--+--++\n+---+-+-++++--++-++--+++-+-+----++--+--+\n+----+-+-++++--++-++-++++-+-+----++--+--\n++----+-+-++++--++-+--++++-+-+----++--+-\n+++----+-+-++++--++----++++-+-+----++--+\n+-++----+-+-++++--++-+--++++-+-+----++--\n\"\"\"\n\n\nheader = \"\"\"\n/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n// This file is auto-generated. 
See \"code_gen.py\"\\n\n\n#pragma once\n\n\"\"\"\n\ntemplate = \"\"\"\n__device__ __forceinline__ void hadamard_mult_thread_{N}(float x[{N}]) {\n    float out[{N}];\n    {code}\n    #pragma unroll\n    for (int i = 0; i < {N}; i++) { x[i] = out[i]; }\n}\n\n\"\"\"\n\n\ndef string_to_array(string):\n    # Convert strings of + and - into int32 arrays of +1 / -1\n    string = string.strip().replace(\"+\", \"1\").replace(\"-\", \"-1\").split()\n    return np.stack(\n        [\n            np.fromstring(\" \".join(string[i]), dtype=np.int32, sep=\" \")\n            for i in range(len(string))\n        ]\n    )\n\n\ndef array_code_gen(arr):\n    N = arr.shape[0]\n    assert arr.shape[0] == arr.shape[1]\n    out = []\n    for i in range(N):\n        out.append(\n            f\"out[{i}] = \"\n            + \" \".join([f\"{'+' if arr[i, j] == 1 else '-'} x[{j}]\" for j in range(N)])\n            + \";\"\n        )\n    return template.replace(\"{N}\", str(N)).replace(\"{code}\", \"\\n    \".join(out))\n\n\ndef main():\n    output_path = Path(__file__).parent / \"fast_hadamard_transform_special.h\"\n    output_path.write_text(\n        header\n        + array_code_gen(string_to_array(had_12_paley))\n        + array_code_gen(string_to_array(had_20_will))\n        + array_code_gen(string_to_array(had_28_will))\n        + array_code_gen(string_to_array(had_40_tpal))\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform.h",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n// Copied from https://github.com/sgl-project/fast-hadamard-transform\n\n#pragma once\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct HadamardParamsBase {\n  using index_t = int64_t;\n\n  int batch, dim, log_N;\n\n  index_t x_batch_stride;\n  index_t out_batch_stride;\n\n  float scale;\n\n  // Common data pointers.\n  void* __restrict__ x_ptr;\n  void* __restrict__ out_ptr;\n};\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform_common.h",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n// Copied from https://github.com/sgl-project/fast-hadamard-transform\n\n#pragma once\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\n#define FULL_MASK 0xffffffff\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct uint8 {\n  uint4 u;\n  uint4 v;\n};\n\ntemplate <int BYTES>\nstruct BytesToType {};\n\ntemplate <>\nstruct BytesToType<32> {\n  using Type = uint8;\n  static_assert(sizeof(Type) == 32);\n};\n\ntemplate <>\nstruct BytesToType<16> {\n  using Type = uint4;\n  static_assert(sizeof(Type) == 16);\n};\n\ntemplate <>\nstruct BytesToType<8> {\n  using Type = uint64_t;\n  static_assert(sizeof(Type) == 8);\n};\n\ntemplate <>\nstruct BytesToType<4> {\n  using Type = uint32_t;\n  static_assert(sizeof(Type) == 4);\n};\n\ntemplate <>\nstruct BytesToType<2> {\n  using Type = uint16_t;\n  static_assert(sizeof(Type) == 2);\n};\n\ntemplate <>\nstruct BytesToType<1> {\n  using Type = uint8_t;\n  static_assert(sizeof(Type) == 1);\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nstruct SumOp {\n  __device__ inline T operator()(T const& x, T const& y) {\n    return x + y;\n  }\n};\n\ntemplate <int THREADS>\nstruct Allreduce {\n  static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);\n  template <typename T, typename Operator>\n  static __device__ inline T run(T x, Operator& op) {\n    constexpr int OFFSET = THREADS / 2;\n    x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));\n    return Allreduce<OFFSET>::run(x, op);\n  }\n};\n\ntemplate <>\nstruct Allreduce<2> {\n  template <typename T, typename Operator>\n  static __device__ inline T run(T x, Operator& op) {\n    x = op(x, 
__shfl_xor_sync(uint32_t(-1), x, 1));\n    return x;\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// https://stackoverflow.com/questions/35311711/whats-the-right-way-to-compute-integral-base-2-logarithms-at-compile-time\nconstexpr int cilog2(int val) {\n  return val > 0 ? 1 + cilog2(val >> 1) : -1;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int kLogN, int kNChunks>\n__device__ __forceinline__ void hadamard_mult_thread(float x[kNChunks][1 << kLogN]) {\n  constexpr int N = 1 << kLogN;\n#pragma unroll\n  for (int i = 0; i < kLogN; ++i) {\n    const int stride = 1 << i;\n#pragma unroll\n    for (int j = 0; j < N / 2; ++j) {\n      const int lo = j & (stride - 1);\n      const int idx = (j - lo) * 2 + lo;\n#pragma unroll\n      for (int c = 0; c < kNChunks; ++c) {\n        const float a = x[c][idx];\n        const float b = x[c][idx + stride];\n        x[c][idx] = a + b;\n        x[c][idx + stride] = a - b;\n      }\n    }\n  }\n}\n\ntemplate <int kLogWarpSize, int kStepStart, int kNChunks, int kNItems>\n__device__ __forceinline__ void hadamard_mult_warp(float x[kNChunks][kNItems]) {\n  constexpr int N = 1 << kLogWarpSize;\n  int lane_id = threadIdx.x % N;\n#pragma unroll\n  for (int step = kStepStart; step < kLogWarpSize; ++step) {\n    const int lane_mask = 1 << step;\n    const float sign = (lane_id & lane_mask) ? 
-1.f : 1.f;\n#pragma unroll\n    for (int c = 0; c < kNChunks; ++c) {\n#pragma unroll\n      for (int i = 0; i < kNItems; ++i) {\n        float x_val_other = __shfl_xor_sync(FULL_MASK, x[c][i], lane_mask);\n        x[c][i] = sign * x[c][i] + x_val_other;\n      }\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int kNChunks, int kNElts, typename input_t>\ninline __device__ void load_input(input_t* x, float x_vals[kNChunks][kNElts], int dim) {\n  using vec_t = typename BytesToType<sizeof(input_t) * kNElts>::Type;\n  input_t x_vals_load[kNChunks][kNElts] = {0};\n#pragma unroll\n  for (int c = 0; c < kNChunks; ++c) {\n    if ((c * blockDim.x + threadIdx.x) * kNElts < dim) {\n      reinterpret_cast<vec_t*>(x_vals_load)[c] = reinterpret_cast<const vec_t*>(x)[c * blockDim.x + threadIdx.x];\n    }\n  }\n#pragma unroll\n  for (int c = 0; c < kNChunks; ++c) {\n#pragma unroll\n    for (int i = 0; i < kNElts; ++i) {\n      x_vals[c][i] = float(x_vals_load[c][i]);\n    }\n  }\n}\n\ntemplate <int kNChunks, int kNElts, typename output_t>\ninline __device__ void store_output(output_t* out, float out_vals[kNChunks][kNElts], int dim, float scale = 1.f) {\n  using vec_t = typename BytesToType<sizeof(output_t) * kNElts>::Type;\n  output_t out_vals_store[kNChunks][kNElts];\n#pragma unroll\n  for (int c = 0; c < kNChunks; ++c) {\n#pragma unroll\n    for (int i = 0; i < kNElts; ++i) {\n      out_vals_store[c][i] = out_vals[c][i] * scale;\n    }\n  }\n#pragma unroll\n  for (int c = 0; c < kNChunks; ++c) {\n    if ((c * blockDim.x + threadIdx.x) * kNElts < dim) {\n      reinterpret_cast<vec_t*>(out)[c * blockDim.x + threadIdx.x] = reinterpret_cast<const vec_t*>(out_vals_store)[c];\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// Pre=true means the exchange before the hadamard_mult_warp, Pre=false means after.\ntemplate <int 
kNChunks, int kChunksPerExchange, int kNElts, int kWarpSize, int kNWarps, bool Pre, typename vec_t>\ninline __device__ void exchange_smem_pre(float x_vals[kNChunks][kNElts], vec_t* smem) {\n  constexpr int kNThreads = kWarpSize * kNWarps;\n  constexpr int kNExchangePerVec = kNElts / (sizeof(vec_t) / sizeof(float));\n  const int warp_id = threadIdx.x / kWarpSize;\n  const int lane_id = threadIdx.x % kWarpSize;\n  const int row_t = threadIdx.x % kNWarps;\n  const int col_t = threadIdx.x / kNWarps;\n// We use the XOR swizzle trick (new_col = col ^ row) to avoid / reduce smem bank conflicts.\n#pragma unroll\n  for (int c0 = 0; c0 < kNChunks / kChunksPerExchange; ++c0) {\n    __syncthreads();\n#pragma unroll\n    for (int c1 = 0; c1 < kChunksPerExchange; ++c1) {\n#pragma unroll\n      for (int r = 0; r < kNExchangePerVec; ++r) {\n        smem\n            [(c1 * kNExchangePerVec + r) * kNThreads +\n             (Pre ? warp_id * kWarpSize + lane_id ^ warp_id : row_t * kWarpSize + col_t ^ row_t)] =\n                reinterpret_cast<vec_t*>(x_vals[c0 * kChunksPerExchange + c1])[r];\n      }\n    }\n    __syncthreads();\n#pragma unroll\n    for (int c1 = 0; c1 < kChunksPerExchange; ++c1) {\n#pragma unroll\n      for (int r = 0; r < kNExchangePerVec; ++r) {\n        reinterpret_cast<vec_t*>(x_vals[c0 * kChunksPerExchange + c1])[r] = smem\n            [(c1 * kNExchangePerVec + r) * kNThreads +\n             (Pre ? row_t * kWarpSize + col_t ^ row_t : warp_id * kWarpSize + lane_id ^ warp_id)];\n      }\n    }\n  }\n}\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform_special.h",
    "content": "\n/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n// Copied from https://github.com/sgl-project/fast-hadamard-transform\n\n// This file is auto-generated. See \"code_gen.py\"\n\n#pragma once\n\n__device__ __forceinline__ void hadamard_mult_thread_12(float x[12]) {\n  float out[12];\n  out[0] = +x[0] - x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + x[9] + x[10] + x[11];\n  out[1] = -x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] - x[9] + x[10] - x[11];\n  out[2] = +x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] + x[11];\n  out[3] = +x[0] - x[1] - x[2] - x[3] + x[4] - x[5] - x[6] + x[7] - x[8] + x[9] + x[10] - x[11];\n  out[4] = +x[0] + x[1] + x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11];\n  out[5] = +x[0] - x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] - x[8] + x[9] - x[10] + x[11];\n  out[6] = +x[0] + x[1] - x[2] - x[3] + x[4] + x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11];\n  out[7] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] - x[10] + x[11];\n  out[8] = +x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] + x[7] + x[8] - x[9] + x[10] + x[11];\n  out[9] = +x[0] - x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11];\n  out[10] = +x[0] + x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + x[9] + x[10] - x[11];\n  out[11] = +x[0] - x[1] + x[2] - x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11];\n#pragma unroll\n  for (int i = 0; i < 12; i++) {\n    x[i] = out[i];\n  }\n}\n\n__device__ __forceinline__ void hadamard_mult_thread_20(float x[20]) {\n  float out[20];\n  out[0] = +x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] + x[11] - x[12] - x[13] +\n           x[14] + x[15] - x[16] + x[17] + x[18] - x[19];\n  out[1] = -x[0] + x[1] - 
x[2] - x[3] - x[4] - x[5] + x[6] - x[7] - x[8] - x[9] + x[10] + x[11] + x[12] - x[13] -\n           x[14] - x[15] + x[16] - x[17] + x[18] + x[19];\n  out[2] = -x[0] - x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] - x[8] - x[9] - x[10] + x[11] + x[12] + x[13] -\n           x[14] + x[15] - x[16] + x[17] - x[18] + x[19];\n  out[3] = -x[0] - x[1] - x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] - x[9] - x[10] - x[11] + x[12] + x[13] +\n           x[14] + x[15] + x[16] - x[17] + x[18] - x[19];\n  out[4] = -x[0] - x[1] - x[2] - x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] + x[13] +\n           x[14] - x[15] + x[16] + x[17] - x[18] + x[19];\n  out[5] = -x[0] + x[1] + x[2] + x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] - x[10] + x[11] - x[12] - x[13] +\n           x[14] + x[15] + x[16] - x[17] - x[18] + x[19];\n  out[6] = +x[0] - x[1] + x[2] + x[3] + x[4] - x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] - x[13] -\n           x[14] + x[15] + x[16] + x[17] - x[18] - x[19];\n  out[7] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] + x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] -\n           x[14] - x[15] + x[16] + x[17] + x[18] - x[19];\n  out[8] = +x[0] + x[1] + x[2] - x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] - x[10] - x[11] + x[12] - x[13] +\n           x[14] - x[15] - x[16] + x[17] + x[18] + x[19];\n  out[9] = +x[0] + x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] + x[13] -\n           x[14] + x[15] - x[16] - x[17] + x[18] + x[19];\n  out[10] = -x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - x[9] + x[10] - x[11] - x[12] - x[13] -\n            x[14] - x[15] + x[16] + x[17] + x[18] + x[19];\n  out[11] = -x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] - x[13] -\n            x[14] + x[15] - x[16] + x[17] + x[18] + x[19];\n  out[12] = +x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + x[9] - x[10] - x[11] + x[12] - 
x[13] -\n            x[14] + x[15] + x[16] - x[17] + x[18] + x[19];\n  out[13] = +x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] -\n            x[14] + x[15] + x[16] + x[17] - x[18] + x[19];\n  out[14] = -x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] - x[10] - x[11] - x[12] - x[13] +\n            x[14] + x[15] + x[16] + x[17] + x[18] - x[19];\n  out[15] = -x[0] + x[1] - x[2] - x[3] + x[4] - x[5] - x[6] + x[7] + x[8] - x[9] + x[10] - x[11] - x[12] - x[13] -\n            x[14] + x[15] - x[16] - x[17] - x[18] - x[19];\n  out[16] = +x[0] - x[1] + x[2] - x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] - x[13] -\n            x[14] - x[15] + x[16] - x[17] - x[18] - x[19];\n  out[17] = -x[0] + x[1] - x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + x[9] - x[10] - x[11] + x[12] - x[13] -\n            x[14] - x[15] - x[16] + x[17] - x[18] - x[19];\n  out[18] = -x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] + x[13] -\n            x[14] - x[15] - x[16] - x[17] + x[18] - x[19];\n  out[19] = +x[0] - x[1] - x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] +\n            x[14] - x[15] - x[16] - x[17] - x[18] + x[19];\n#pragma unroll\n  for (int i = 0; i < 20; i++) {\n    x[i] = out[i];\n  }\n}\n\n__device__ __forceinline__ void hadamard_mult_thread_28(float x[28]) {\n  float out[28];\n  out[0] = +x[0] - x[1] - x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] +\n           x[14] - x[15] + x[16] - x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] + x[24] + x[25] - x[26] -\n           x[27];\n  out[1] = -x[0] + x[1] - x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] + x[9] - x[10] - x[11] - x[12] - x[13] -\n           x[14] + x[15] - x[16] + x[17] - x[18] - x[19] + x[20] - x[21] + x[22] - x[23] - x[24] + x[25] + x[26] -\n           x[27];\n  out[2] = -x[0] - x[1] + 
x[2] - x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + x[9] + x[10] - x[11] - x[12] - x[13] +\n           x[14] - x[15] + x[16] - x[17] + x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] - x[25] + x[26] +\n           x[27];\n  out[3] = -x[0] - x[1] - x[2] + x[3] - x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] + x[11] - x[12] - x[13] -\n           x[14] + x[15] - x[16] + x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] + x[24] - x[25] - x[26] +\n           x[27];\n  out[4] = -x[0] - x[1] - x[2] - x[3] + x[4] - x[5] - x[6] - x[7] - x[8] - x[9] + x[10] + x[11] + x[12] - x[13] -\n           x[14] - x[15] + x[16] - x[17] + x[18] - x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + x[25] - x[26] -\n           x[27];\n  out[5] = -x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] - x[7] - x[8] - x[9] - x[10] + x[11] + x[12] + x[13] +\n           x[14] - x[15] - x[16] + x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] - x[24] - x[25] + x[26] -\n           x[27];\n  out[6] = -x[0] - x[1] - x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] + x[12] + x[13] -\n           x[14] + x[15] - x[16] - x[17] + x[18] - x[19] + x[20] - x[21] - x[22] + x[23] + x[24] - x[25] - x[26] +\n           x[27];\n  out[7] = -x[0] - x[1] + x[2] + x[3] + x[4] + x[5] - x[6] + x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] -\n           x[14] + x[15] + x[16] - x[17] - x[18] + x[19] + x[20] + x[21] - x[22] + x[23] - x[24] - x[25] + x[26] -\n           x[27];\n  out[8] = -x[0] - x[1] - x[2] + x[3] + x[4] + x[5] + x[6] - x[7] + x[8] - x[9] - x[10] - x[11] - x[12] - x[13] +\n           x[14] - x[15] + x[16] + x[17] - x[18] - x[19] + x[20] - x[21] + x[22] - x[23] + x[24] - x[25] - x[26] +\n           x[27];\n  out[9] = +x[0] - x[1] - x[2] - x[3] + x[4] + x[5] + x[6] - x[7] - x[8] + x[9] - x[10] - x[11] - x[12] - x[13] +\n           x[14] + x[15] - x[16] + x[17] + x[18] - x[19] - x[20] + x[21] - x[22] + x[23] - x[24] + x[25] - x[26] -\n           x[27];\n  out[10] = +x[0] + x[1] 
- x[2] - x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11] - x[12] - x[13] -\n            x[14] + x[15] + x[16] - x[17] + x[18] + x[19] - x[20] - x[21] + x[22] - x[23] + x[24] - x[25] + x[26] -\n            x[27];\n  out[11] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] - x[8] - x[9] - x[10] + x[11] - x[12] - x[13] -\n            x[14] - x[15] + x[16] + x[17] - x[18] + x[19] + x[20] - x[21] - x[22] + x[23] - x[24] + x[25] - x[26] +\n            x[27];\n  out[12] = +x[0] + x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] + x[12] - x[13] +\n            x[14] - x[15] - x[16] + x[17] + x[18] - x[19] + x[20] + x[21] - x[22] - x[23] + x[24] - x[25] + x[26] -\n            x[27];\n  out[13] = -x[0] + x[1] + x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] + x[13] +\n            x[14] + x[15] - x[16] - x[17] + x[18] + x[19] - x[20] - x[21] + x[22] - x[23] - x[24] + x[25] - x[26] +\n            x[27];\n  out[14] = -x[0] + x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] + x[10] + x[11] - x[12] - x[13] +\n            x[14] - x[15] - x[16] - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] + x[24] + x[25] + x[26] -\n            x[27];\n  out[15] = +x[0] - x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] + x[8] - x[9] - x[10] + x[11] + x[12] - x[13] -\n            x[14] + x[15] - x[16] - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] + x[24] + x[25] + x[26] +\n            x[27];\n  out[16] = -x[0] + x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] + x[9] - x[10] - x[11] + x[12] + x[13] -\n            x[14] - x[15] + x[16] - x[17] - x[18] - x[19] - x[20] + x[21] - x[22] - x[23] - x[24] + x[25] + x[26] +\n            x[27];\n  out[17] = +x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] + x[10] - x[11] - x[12] + x[13] -\n            x[14] - x[15] - x[16] + x[17] - x[18] - x[19] - x[20] + x[21] + x[22] - x[23] - x[24] - x[25] + x[26] +\n            x[27];\n 
 out[18] = +x[0] + x[1] - x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] + x[11] - x[12] - x[13] -\n            x[14] - x[15] - x[16] - x[17] + x[18] - x[19] - x[20] + x[21] + x[22] + x[23] - x[24] - x[25] - x[26] +\n            x[27];\n  out[19] = -x[0] + x[1] + x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11] + x[12] - x[13] -\n            x[14] - x[15] - x[16] - x[17] - x[18] + x[19] - x[20] + x[21] + x[22] + x[23] + x[24] - x[25] - x[26] -\n            x[27];\n  out[20] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] + x[13] -\n            x[14] - x[15] - x[16] - x[17] - x[18] - x[19] + x[20] - x[21] + x[22] + x[23] + x[24] + x[25] - x[26] -\n            x[27];\n  out[21] = -x[0] + x[1] + x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - x[9] + x[10] + x[11] - x[12] + x[13] +\n            x[14] + x[15] - x[16] - x[17] - x[18] - x[19] + x[20] + x[21] - x[22] - x[23] - x[24] - x[25] - x[26] -\n            x[27];\n  out[22] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] - x[10] + x[11] + x[12] - x[13] +\n            x[14] + x[15] + x[16] - x[17] - x[18] - x[19] - x[20] - x[21] + x[22] - x[23] - x[24] - x[25] - x[26] -\n            x[27];\n  out[23] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] + x[10] - x[11] + x[12] + x[13] -\n            x[14] + x[15] + x[16] + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] - x[25] - x[26] -\n            x[27];\n  out[24] = -x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + x[9] - x[10] + x[11] - x[12] + x[13] -\n            x[14] - x[15] + x[16] + x[17] + x[18] - x[19] - x[20] - x[21] - x[22] - x[23] + x[24] - x[25] - x[26] -\n            x[27];\n  out[25] = -x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] - x[9] + x[10] - x[11] + x[12] - x[13] -\n            x[14] - x[15] - x[16] + x[17] + x[18] + x[19] - x[20] - x[21] - x[22] - x[23] - x[24] + x[25] - x[26] 
-\n            x[27];\n  out[26] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] + x[13] -\n            x[14] - x[15] - x[16] - x[17] + x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - x[25] + x[26] -\n            x[27];\n  out[27] = +x[0] + x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + x[9] + x[10] - x[11] + x[12] - x[13] +\n            x[14] - x[15] - x[16] - x[17] - x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - x[25] - x[26] +\n            x[27];\n#pragma unroll\n  for (int i = 0; i < 28; i++) {\n    x[i] = out[i];\n  }\n}\n\n__device__ __forceinline__ void hadamard_mult_thread_40(float x[40]) {\n  float out[40];\n  out[0] = +x[0] - x[1] - x[2] - x[3] - x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] -\n           x[14] - x[15] - x[16] - x[17] - x[18] - x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - x[25] - x[26] -\n           x[27] - x[28] - x[29] - x[30] - x[31] - x[32] - x[33] - x[34] - x[35] - x[36] - x[37] - x[38] - x[39];\n  out[1] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + x[9] - x[10] + x[11] - x[12] + x[13] +\n           x[14] + x[15] + x[16] - x[17] - x[18] + x[19] + x[20] + x[21] - x[22] + x[23] + x[24] - x[25] - x[26] -\n           x[27] - x[28] + x[29] - x[30] + x[31] - x[32] + x[33] + x[34] + x[35] + x[36] - x[37] - x[38] + x[39];\n  out[2] = +x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] - x[13] +\n           x[14] + x[15] + x[16] + x[17] - x[18] - x[19] + x[20] + x[21] + x[22] - x[23] + x[24] + x[25] - x[26] -\n           x[27] - x[28] - x[29] + x[30] - x[31] + x[32] - x[33] + x[34] + x[35] + x[36] + x[37] - x[38] - x[39];\n  out[3] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] -\n           x[14] + x[15] + x[16] + x[17] + x[18] - x[19] + x[20] - x[21] + x[22] + x[23] - x[24] + x[25] + x[26] -\n           x[27] - x[28] - 
x[29] - x[30] + x[31] - x[32] + x[33] - x[34] + x[35] + x[36] + x[37] + x[38] - x[39];\n  out[4] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] + x[12] - x[13] +\n           x[14] - x[15] + x[16] + x[17] + x[18] + x[19] + x[20] - x[21] - x[22] + x[23] + x[24] - x[25] + x[26] +\n           x[27] - x[28] - x[29] - x[30] - x[31] + x[32] - x[33] + x[34] - x[35] + x[36] + x[37] + x[38] + x[39];\n  out[5] = +x[0] + x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] -\n           x[14] + x[15] - x[16] + x[17] + x[18] + x[19] + x[20] + x[21] - x[22] - x[23] + x[24] + x[25] - x[26] +\n           x[27] + x[28] - x[29] - x[30] - x[31] - x[32] + x[33] - x[34] + x[35] - x[36] + x[37] + x[38] + x[39];\n  out[6] = +x[0] + x[1] + x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11] - x[12] - x[13] +\n           x[14] - x[15] + x[16] - x[17] + x[18] + x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + x[25] + x[26] -\n           x[27] + x[28] + x[29] - x[30] - x[31] - x[32] - x[33] + x[34] - x[35] + x[36] - x[37] + x[38] + x[39];\n  out[7] = +x[0] + x[1] + x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] + x[10] - x[11] - x[12] - x[13] -\n           x[14] + x[15] - x[16] + x[17] - x[18] + x[19] + x[20] + x[21] + x[22] + x[23] - x[24] - x[25] + x[26] +\n           x[27] - x[28] + x[29] + x[30] - x[31] - x[32] - x[33] - x[34] + x[35] - x[36] + x[37] - x[38] + x[39];\n  out[8] = +x[0] + x[1] + x[2] + x[3] + x[4] - x[5] - x[6] + x[7] + x[8] - x[9] + x[10] + x[11] - x[12] - x[13] -\n           x[14] - x[15] + x[16] - x[17] + x[18] - x[19] + x[20] + x[21] + x[22] + x[23] + x[24] - x[25] - x[26] +\n           x[27] + x[28] - x[29] + x[30] + x[31] - x[32] - x[33] - x[34] - x[35] + x[36] - x[37] + x[38] - x[39];\n  out[9] = +x[0] - x[1] + x[2] + x[3] + x[4] + x[5] - x[6] - x[7] + x[8] + x[9] - x[10] + x[11] + x[12] - x[13] -\n           x[14] - x[15] - x[16] + x[17] - x[18] + x[19] + x[20] - 
x[21] + x[22] + x[23] + x[24] + x[25] - x[26] -\n           x[27] + x[28] + x[29] - x[30] + x[31] + x[32] - x[33] - x[34] - x[35] - x[36] + x[37] - x[38] + x[39];\n  out[10] = +x[0] + x[1] - x[2] + x[3] + x[4] + x[5] + x[6] - x[7] - x[8] + x[9] + x[10] - x[11] + x[12] + x[13] -\n            x[14] - x[15] - x[16] - x[17] + x[18] - x[19] + x[20] + x[21] - x[22] + x[23] + x[24] + x[25] + x[26] -\n            x[27] - x[28] + x[29] + x[30] - x[31] + x[32] + x[33] - x[34] - x[35] - x[36] - x[37] + x[38] - x[39];\n  out[11] = +x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] + x[7] - x[8] - x[9] + x[10] + x[11] - x[12] + x[13] +\n            x[14] - x[15] - x[16] - x[17] - x[18] + x[19] + x[20] - x[21] + x[22] - x[23] + x[24] + x[25] + x[26] +\n            x[27] - x[28] - x[29] + x[30] + x[31] - x[32] + x[33] + x[34] - x[35] - x[36] - x[37] - x[38] + x[39];\n  out[12] = +x[0] + x[1] - x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] - x[9] - x[10] + x[11] + x[12] - x[13] +\n            x[14] + x[15] - x[16] - x[17] - x[18] - x[19] + x[20] + x[21] - x[22] + x[23] - x[24] + x[25] + x[26] +\n            x[27] + x[28] - x[29] - x[30] + x[31] + x[32] - x[33] + x[34] + x[35] - x[36] - x[37] - x[38] - x[39];\n  out[13] = +x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] + x[7] + x[8] + x[9] - x[10] - x[11] + x[12] + x[13] -\n            x[14] + x[15] + x[16] - x[17] - x[18] - x[19] + x[20] - x[21] + x[22] - x[23] + x[24] - x[25] + x[26] +\n            x[27] + x[28] + x[29] - x[30] - x[31] + x[32] + x[33] - x[34] + x[35] + x[36] - x[37] - x[38] - x[39];\n  out[14] = +x[0] - x[1] - x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] + x[9] + x[10] - x[11] - x[12] + x[13] +\n            x[14] - x[15] + x[16] + x[17] - x[18] - x[19] + x[20] - x[21] - x[22] + x[23] - x[24] + x[25] - x[26] +\n            x[27] + x[28] + x[29] + x[30] - x[31] - x[32] + x[33] + x[34] - x[35] + x[36] + x[37] - x[38] - x[39];\n  out[15] = +x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] + x[10] + 
x[11] - x[12] - x[13] +\n            x[14] + x[15] - x[16] + x[17] + x[18] - x[19] + x[20] - x[21] - x[22] - x[23] + x[24] - x[25] + x[26] -\n            x[27] + x[28] + x[29] + x[30] + x[31] - x[32] - x[33] + x[34] + x[35] - x[36] + x[37] + x[38] - x[39];\n  out[16] = +x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] - x[8] + x[9] + x[10] + x[11] + x[12] - x[13] -\n            x[14] + x[15] + x[16] - x[17] + x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] + x[25] - x[26] +\n            x[27] - x[28] + x[29] + x[30] + x[31] + x[32] - x[33] - x[34] + x[35] + x[36] - x[37] + x[38] + x[39];\n  out[17] = +x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] - x[9] + x[10] + x[11] + x[12] + x[13] -\n            x[14] - x[15] + x[16] + x[17] - x[18] + x[19] + x[20] + x[21] - x[22] - x[23] - x[24] - x[25] + x[26] -\n            x[27] + x[28] - x[29] + x[30] + x[31] + x[32] + x[33] - x[34] - x[35] + x[36] + x[37] - x[38] + x[39];\n  out[18] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] - x[8] + x[9] - x[10] + x[11] + x[12] + x[13] +\n            x[14] - x[15] - x[16] + x[17] + x[18] - x[19] + x[20] + x[21] + x[22] - x[23] - x[24] - x[25] - x[26] +\n            x[27] - x[28] + x[29] - x[30] + x[31] + x[32] + x[33] + x[34] - x[35] - x[36] + x[37] + x[38] - x[39];\n  out[19] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] - x[9] + x[10] - x[11] + x[12] + x[13] +\n            x[14] + x[15] - x[16] - x[17] + x[18] + x[19] + x[20] - x[21] + x[22] + x[23] - x[24] - x[25] - x[26] -\n            x[27] + x[28] - x[29] + x[30] - x[31] + x[32] + x[33] + x[34] + x[35] - x[36] - x[37] + x[38] + x[39];\n  out[20] = +x[0] - x[1] - x[2] - x[3] - x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] -\n            x[14] - x[15] - x[16] - x[17] - x[18] - x[19] - x[20] + x[21] + x[22] + x[23] + x[24] + x[25] + x[26] +\n            x[27] + x[28] + x[29] + x[30] + x[31] + x[32] + x[33] + x[34] + x[35] + x[36] + x[37] + x[38] + 
x[39];\n  out[21] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + x[9] - x[10] + x[11] - x[12] + x[13] +\n            x[14] + x[15] + x[16] - x[17] - x[18] + x[19] - x[20] - x[21] + x[22] - x[23] - x[24] + x[25] + x[26] +\n            x[27] + x[28] - x[29] + x[30] - x[31] + x[32] - x[33] - x[34] - x[35] - x[36] + x[37] + x[38] - x[39];\n  out[22] = +x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] - x[13] +\n            x[14] + x[15] + x[16] + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] - x[25] + x[26] +\n            x[27] + x[28] + x[29] - x[30] + x[31] - x[32] + x[33] - x[34] - x[35] - x[36] - x[37] + x[38] + x[39];\n  out[23] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] -\n            x[14] + x[15] + x[16] + x[17] + x[18] - x[19] - x[20] + x[21] - x[22] - x[23] + x[24] - x[25] - x[26] +\n            x[27] + x[28] + x[29] + x[30] - x[31] + x[32] - x[33] + x[34] - x[35] - x[36] - x[37] - x[38] + x[39];\n  out[24] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] + x[12] - x[13] +\n            x[14] - x[15] + x[16] + x[17] + x[18] + x[19] - x[20] + x[21] + x[22] - x[23] - x[24] + x[25] - x[26] -\n            x[27] + x[28] + x[29] + x[30] + x[31] - x[32] + x[33] - x[34] + x[35] - x[36] - x[37] - x[38] - x[39];\n  out[25] = +x[0] + x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] -\n            x[14] + x[15] - x[16] + x[17] + x[18] + x[19] - x[20] - x[21] + x[22] + x[23] - x[24] - x[25] + x[26] -\n            x[27] - x[28] + x[29] + x[30] + x[31] + x[32] - x[33] + x[34] - x[35] + x[36] - x[37] - x[38] - x[39];\n  out[26] = +x[0] + x[1] + x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11] - x[12] - x[13] +\n            x[14] - x[15] + x[16] - x[17] + x[18] + x[19] - x[20] - x[21] - x[22] + x[23] + x[24] - x[25] - x[26] +\n            
x[27] - x[28] - x[29] + x[30] + x[31] + x[32] + x[33] - x[34] + x[35] - x[36] + x[37] - x[38] - x[39];\n  out[27] = +x[0] + x[1] + x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] + x[10] - x[11] - x[12] - x[13] -\n            x[14] + x[15] - x[16] + x[17] - x[18] + x[19] - x[20] - x[21] - x[22] - x[23] + x[24] + x[25] - x[26] -\n            x[27] + x[28] - x[29] - x[30] + x[31] + x[32] + x[33] + x[34] - x[35] + x[36] - x[37] + x[38] - x[39];\n  out[28] = +x[0] + x[1] + x[2] + x[3] + x[4] - x[5] - x[6] + x[7] + x[8] - x[9] + x[10] + x[11] - x[12] - x[13] -\n            x[14] - x[15] + x[16] - x[17] + x[18] - x[19] - x[20] - x[21] - x[22] - x[23] - x[24] + x[25] + x[26] -\n            x[27] - x[28] + x[29] - x[30] - x[31] + x[32] + x[33] + x[34] + x[35] - x[36] + x[37] - x[38] + x[39];\n  out[29] = +x[0] - x[1] + x[2] + x[3] + x[4] + x[5] - x[6] - x[7] + x[8] + x[9] - x[10] + x[11] + x[12] - x[13] -\n            x[14] - x[15] - x[16] + x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] - x[24] - x[25] + x[26] +\n            x[27] - x[28] - x[29] + x[30] - x[31] - x[32] + x[33] + x[34] + x[35] + x[36] - x[37] + x[38] - x[39];\n  out[30] = +x[0] + x[1] - x[2] + x[3] + x[4] + x[5] + x[6] - x[7] - x[8] + x[9] + x[10] - x[11] + x[12] + x[13] -\n            x[14] - x[15] - x[16] - x[17] + x[18] - x[19] - x[20] - x[21] + x[22] - x[23] - x[24] - x[25] - x[26] +\n            x[27] + x[28] - x[29] - x[30] + x[31] - x[32] - x[33] + x[34] + x[35] + x[36] + x[37] - x[38] + x[39];\n  out[31] = +x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] + x[7] - x[8] - x[9] + x[10] + x[11] - x[12] + x[13] +\n            x[14] - x[15] - x[16] - x[17] - x[18] + x[19] - x[20] + x[21] - x[22] + x[23] - x[24] - x[25] - x[26] -\n            x[27] + x[28] + x[29] - x[30] - x[31] + x[32] - x[33] - x[34] + x[35] + x[36] + x[37] + x[38] - x[39];\n  out[32] = +x[0] + x[1] - x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] - x[9] - x[10] + x[11] + x[12] - x[13] +\n            x[14] + x[15] - x[16] 
- x[17] - x[18] - x[19] - x[20] - x[21] + x[22] - x[23] + x[24] - x[25] - x[26] -\n            x[27] - x[28] + x[29] + x[30] - x[31] - x[32] + x[33] - x[34] - x[35] + x[36] + x[37] + x[38] + x[39];\n  out[33] = +x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] + x[7] + x[8] + x[9] - x[10] - x[11] + x[12] + x[13] -\n            x[14] + x[15] + x[16] - x[17] - x[18] - x[19] - x[20] + x[21] - x[22] + x[23] - x[24] + x[25] - x[26] -\n            x[27] - x[28] - x[29] + x[30] + x[31] - x[32] - x[33] + x[34] - x[35] - x[36] + x[37] + x[38] + x[39];\n  out[34] = +x[0] - x[1] - x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] + x[9] + x[10] - x[11] - x[12] + x[13] +\n            x[14] - x[15] + x[16] + x[17] - x[18] - x[19] - x[20] + x[21] + x[22] - x[23] + x[24] - x[25] + x[26] -\n            x[27] - x[28] - x[29] - x[30] + x[31] + x[32] - x[33] - x[34] + x[35] - x[36] - x[37] + x[38] + x[39];\n  out[35] = +x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] + x[10] + x[11] - x[12] - x[13] +\n            x[14] + x[15] - x[16] + x[17] + x[18] - x[19] - x[20] + x[21] + x[22] + x[23] - x[24] + x[25] - x[26] +\n            x[27] - x[28] - x[29] - x[30] - x[31] + x[32] + x[33] - x[34] - x[35] + x[36] - x[37] - x[38] + x[39];\n  out[36] = +x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] - x[8] + x[9] + x[10] + x[11] + x[12] - x[13] -\n            x[14] + x[15] + x[16] - x[17] + x[18] + x[19] - x[20] + x[21] + x[22] + x[23] + x[24] - x[25] + x[26] -\n            x[27] + x[28] - x[29] - x[30] - x[31] - x[32] + x[33] + x[34] - x[35] - x[36] + x[37] - x[38] - x[39];\n  out[37] = +x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] - x[9] + x[10] + x[11] + x[12] + x[13] -\n            x[14] - x[15] + x[16] + x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] + x[24] + x[25] - x[26] +\n            x[27] - x[28] + x[29] - x[30] - x[31] - x[32] - x[33] + x[34] + x[35] - x[36] - x[37] + x[38] - x[39];\n  out[38] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] - 
x[6] + x[7] - x[8] + x[9] - x[10] + x[11] + x[12] + x[13] +\n            x[14] - x[15] - x[16] + x[17] + x[18] - x[19] - x[20] - x[21] - x[22] + x[23] + x[24] + x[25] + x[26] -\n            x[27] + x[28] - x[29] + x[30] - x[31] - x[32] - x[33] - x[34] + x[35] + x[36] - x[37] - x[38] + x[39];\n  out[39] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] - x[9] + x[10] - x[11] + x[12] + x[13] +\n            x[14] + x[15] - x[16] - x[17] + x[18] + x[19] - x[20] + x[21] - x[22] - x[23] + x[24] + x[25] + x[26] +\n            x[27] - x[28] + x[29] - x[30] + x[31] - x[32] - x[33] - x[34] - x[35] + x[36] + x[37] - x[38] - x[39];\n#pragma unroll\n  for (int i = 0; i < 40; i++) {\n    x[i] = out[i];\n  }\n}\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/fast-hadamard-transform/hadamard_jit.cuh",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/utils.cuh>\n\n#include <tvm/ffi/container/tensor.h>\n\n#include \"fast_hadamard_transform.h\"\n#include \"fast_hadamard_transform_common.h\"\n#include \"fast_hadamard_transform_special.h\"\n#include \"static_switch.h\"\n#include <algorithm>\n#include <cstdint>\n#include <cstring>\n\nnamespace {\n\nusing ::bf16_t;\nusing ::fp16_t;\nusing ::HadamardParamsBase;\n\nconstexpr inline int ceil_log2(int val) {\n  int log = 0;\n  int p = 1;\n  while (p < val) {\n    p <<= 1;\n    ++log;\n  }\n  return log;\n}\n\ntemplate <int kNThreads_, int kLogN_, typename input_t_>\nstruct FastHadamardKernelTraits {\n  using input_t = input_t_;\n  static constexpr int kNThreads = kNThreads_;\n  static constexpr int kLogN = kLogN_;\n  static constexpr int N = 1 << kLogN;\n  static constexpr int kNBytes = sizeof(input_t);\n  static_assert(kNBytes == 2 || kNBytes == 4);\n  static constexpr int kNElts = kNBytes == 4 ? 4 : 8;\n  static constexpr int kNExchangePerVec = sizeof(float) / sizeof(input_t);\n  using vec_t = typename BytesToType<kNBytes * kNElts>::Type;\n  static constexpr int kNChunks = N / (kNElts * kNThreads);\n  static constexpr int kSmemExchangeSize = (N * 4) < (32 * 1024) ? 
(N * 4) : (32 * 1024);\n  static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;\n  static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);\n  static constexpr int kSmemSize = kSmemExchangeSize;\n};\n\ntemplate <int kNThreads_, int kLogN_, int kMultiple, int kMaxDim, int kMaxSmem, typename input_t_>\nstruct FastHadamardMNKernelTraits {\n  using input_t = input_t_;\n  static constexpr int kNThreads = kNThreads_;\n  static constexpr int kLogN = kLogN_;\n  static constexpr int N = (1 << kLogN) * kMultiple;\n  static_assert(N <= kMaxDim);\n  static constexpr int kNBytes = sizeof(input_t);\n  static_assert(kNBytes == 2 || kNBytes == 4);\n  static constexpr int kNElts = 4;\n  static constexpr int kNExchangePerVec = sizeof(float) / sizeof(input_t);\n  using vec_t = typename BytesToType<kNBytes * kNElts>::Type;\n  static constexpr int kNChunks = N / (kNElts * kNThreads);\n  static_assert(kNChunks == kMultiple);\n  static constexpr int kSmemExchangeSize = (N * 4) < kMaxSmem ? (N * 4) : kMaxSmem;\n  static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;\n  static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);\n  static constexpr int kSmemSize = kSmemExchangeSize;\n};\n\ntemplate <int kNThreads_, int kLogN_, typename input_t_>\nusing FastHadamard12NTraits = FastHadamardMNKernelTraits<kNThreads_, kLogN_, 12, 12 * 1024, 24 * 1024, input_t_>;\n\ntemplate <int kNThreads_, int kLogN_, typename input_t_>\nusing FastHadamard20NTraits = FastHadamardMNKernelTraits<kNThreads_, kLogN_, 20, 20 * 1024, 40 * 1024, input_t_>;\n\ntemplate <int kNThreads_, int kLogN_, typename input_t_>\nusing FastHadamard28NTraits = FastHadamardMNKernelTraits<kNThreads_, kLogN_, 28, 28 * 1024, 28 * 1024, input_t_>;\n\ntemplate <int kNThreads_, int kLogN_, typename input_t_>\nusing FastHadamard40NTraits = FastHadamardMNKernelTraits<kNThreads_, kLogN_, 40, 40 * 1024, 40 * 1024, input_t_>;\n\ntemplate <int kNChunks>\nSGL_DEVICE void hadamard_mult_thread_chunk_12(float 
x[kNChunks][12]) {\n#pragma unroll\n  for (int c = 0; c < kNChunks; ++c) {\n    hadamard_mult_thread_12(x[c]);\n  }\n}\n\ntemplate <int kNChunks>\nSGL_DEVICE void hadamard_mult_thread_chunk_20(float x[kNChunks][20]) {\n#pragma unroll\n  for (int c = 0; c < kNChunks; ++c) {\n    hadamard_mult_thread_20(x[c]);\n  }\n}\n\ntemplate <int kNChunks>\nSGL_DEVICE void hadamard_mult_thread_chunk_28(float x[kNChunks][28]) {\n#pragma unroll\n  for (int c = 0; c < kNChunks; ++c) {\n    hadamard_mult_thread_28(x[c]);\n  }\n}\n\ntemplate <int kNChunks>\nSGL_DEVICE void hadamard_mult_thread_chunk_40(float x[kNChunks][40]) {\n#pragma unroll\n  for (int c = 0; c < kNChunks; ++c) {\n    hadamard_mult_thread_40(x[c]);\n  }\n}\n\ntemplate <typename Ktraits>\n__global__ __launch_bounds__(Ktraits::kNThreads) void fast_hadamard_transform_kernel(HadamardParamsBase params) {\n  constexpr int kNThreads = Ktraits::kNThreads;\n  constexpr int kNElts = Ktraits::kNElts;\n  constexpr int kNExchangePerVec = Ktraits::kNExchangePerVec;\n  constexpr int kNChunks = Ktraits::kNChunks;\n  using input_t = typename Ktraits::input_t;\n  using vec_t = typename Ktraits::vec_t;\n\n  constexpr int kLogNElts = cilog2(Ktraits::kNElts);\n  static_assert(1 << kLogNElts == kNElts, \"kNElts must be a power of 2\");\n\n  constexpr int kWarpSize = kNThreads < 32 ? 
kNThreads : 32;\n  constexpr int kLogWarpSize = cilog2(kWarpSize);\n  static_assert(1 << kLogWarpSize == kWarpSize, \"Warp size must be a power of 2\");\n\n  constexpr int kNWarps = kNThreads / kWarpSize;\n  constexpr int kLogNWarps = cilog2(kNWarps);\n  static_assert(1 << kLogNWarps == kNWarps, \"kNWarps must be a power of 2\");\n\n  constexpr int kChunksPerExchange = Ktraits::kSmemExchangeSize / (sizeof(vec_t) * kNExchangePerVec * kNThreads);\n  static_assert(kChunksPerExchange * sizeof(vec_t) * kNExchangePerVec * kNThreads == Ktraits::kSmemExchangeSize);\n  constexpr int kNExchanges = kNChunks / kChunksPerExchange;\n  static_assert(kNExchanges * kChunksPerExchange == kNChunks);\n\n  extern __shared__ char smem_[];\n  vec_t* smem_exchange = reinterpret_cast<vec_t*>(smem_);\n\n  const int batch_id = static_cast<int>(blockIdx.x);\n  input_t* x = reinterpret_cast<input_t*>(params.x_ptr) + batch_id * params.x_batch_stride;\n  input_t* out = reinterpret_cast<input_t*>(params.out_ptr) + batch_id * params.out_batch_stride;\n\n  float x_vals[kNChunks][kNElts];\n  load_input<kNChunks, kNElts, input_t>(x, x_vals, params.dim);\n\n  hadamard_mult_thread<kLogNElts, kNChunks>(x_vals);\n  hadamard_mult_warp<kLogWarpSize, 0, kNChunks, kNElts>(x_vals);\n\n  if constexpr (kNWarps > 1) {\n    exchange_smem_pre<kNChunks, kChunksPerExchange, kNElts, kWarpSize, kNWarps, true, vec_t>(x_vals, smem_exchange);\n    hadamard_mult_warp<kLogNWarps, 0, kNChunks, kNElts>(x_vals);\n    exchange_smem_pre<kNChunks, kChunksPerExchange, kNElts, kWarpSize, kNWarps, false, vec_t>(x_vals, smem_exchange);\n  }\n\n  if constexpr (kNChunks > 1) {\n    float x_vals_transposed[kNElts][kNChunks];\n#pragma unroll\n    for (int c = 0; c < kNChunks; ++c) {\n#pragma unroll\n      for (int i = 0; i < kNElts; ++i) {\n        x_vals_transposed[i][c] = x_vals[c][i];\n      }\n    }\n\n    if constexpr (kNChunks == 12) {\n      hadamard_mult_thread_chunk_12<kNElts>(x_vals_transposed);\n    } else if constexpr 
(kNChunks == 20) {\n      hadamard_mult_thread_chunk_20<kNElts>(x_vals_transposed);\n    } else if constexpr (kNChunks == 28) {\n      hadamard_mult_thread_chunk_28<kNElts>(x_vals_transposed);\n    } else if constexpr (kNChunks == 40) {\n      hadamard_mult_thread_chunk_40<kNElts>(x_vals_transposed);\n    } else {\n      constexpr int kLogNChunks = cilog2(kNChunks);\n      static_assert(1 << kLogNChunks == kNChunks, \"kNChunks must be a power of 2\");\n      hadamard_mult_thread<kLogNChunks, kNElts>(x_vals_transposed);\n    }\n\n#pragma unroll\n    for (int c = 0; c < kNChunks; ++c) {\n#pragma unroll\n      for (int i = 0; i < kNElts; ++i) {\n        x_vals[c][i] = x_vals_transposed[i][c];\n      }\n    }\n  }\n\n  store_output<kNChunks, kNElts, input_t>(out, x_vals, params.dim, params.scale);\n}\n\ntemplate <typename Ktraits>\ninline void set_max_dynamic_smem() {\n  constexpr int kSmemSize = Ktraits::kSmemSize;\n  if constexpr (kSmemSize >= 48 * 1024) {\n    auto kernel = &fast_hadamard_transform_kernel<Ktraits>;\n    host::RuntimeDeviceCheck(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));\n  }\n}\n\ntemplate <typename Ktraits>\ninline void launch_kernel(HadamardParamsBase& params, DLDevice device) {\n  constexpr int kSmemSize = Ktraits::kSmemSize;\n  set_max_dynamic_smem<Ktraits>();\n  auto kernel = &fast_hadamard_transform_kernel<Ktraits>;\n  host::LaunchKernel(dim3(params.batch), dim3(Ktraits::kNThreads), device, kSmemSize)(kernel, params);\n  host::RuntimeDeviceCheck();\n}\n\ntemplate <int kNThreads, int kLogN, typename input_t>\ninline void fast_hadamard_transform_launch(HadamardParamsBase& params, DLDevice device) {\n  using Ktraits = FastHadamardKernelTraits<kNThreads, kLogN, input_t>;\n  launch_kernel<Ktraits>(params, device);\n}\n\ntemplate <typename input_t>\ninline void fast_hadamard_transform_cuda(HadamardParamsBase& params, DLDevice device) {\n  if (params.log_N == 3) {\n    fast_hadamard_transform_launch<1, 3, 
input_t>(params, device);\n  } else if (params.log_N == 4) {\n    fast_hadamard_transform_launch<2, 4, input_t>(params, device);\n  } else if (params.log_N == 5) {\n    fast_hadamard_transform_launch<4, 5, input_t>(params, device);\n  } else if (params.log_N == 6) {\n    fast_hadamard_transform_launch<8, 6, input_t>(params, device);\n  } else if (params.log_N == 7) {\n    fast_hadamard_transform_launch<16, 7, input_t>(params, device);\n  } else if (params.log_N == 8) {\n    fast_hadamard_transform_launch<32, 8, input_t>(params, device);\n  } else if (params.log_N == 9) {\n    fast_hadamard_transform_launch<32, 9, input_t>(params, device);\n  } else if (params.log_N == 10) {\n    fast_hadamard_transform_launch<128, 10, input_t>(params, device);\n  } else if (params.log_N == 11) {\n    fast_hadamard_transform_launch<256, 11, input_t>(params, device);\n  } else if (params.log_N == 12) {\n    fast_hadamard_transform_launch<256, 12, input_t>(params, device);\n  } else if (params.log_N == 13) {\n    fast_hadamard_transform_launch<256, 13, input_t>(params, device);\n  } else if (params.log_N == 14) {\n    fast_hadamard_transform_launch<256, 14, input_t>(params, device);\n  } else if (params.log_N == 15) {\n    fast_hadamard_transform_launch<256, 15, input_t>(params, device);\n  } else {\n    host::Panic(\"fast_hadamard_transform: unsupported log_N=\", params.log_N);\n  }\n}\n\ntemplate <int kNThreads, int kLogN, typename input_t>\ninline void fast_hadamard_transform_12N_launch(HadamardParamsBase& params, DLDevice device) {\n  using Ktraits = FastHadamard12NTraits<kNThreads, kLogN, input_t>;\n  launch_kernel<Ktraits>(params, device);\n}\n\ntemplate <typename input_t>\ninline void fast_hadamard_transform_12N_cuda(HadamardParamsBase& params, DLDevice device) {\n  if (params.log_N == 2) {\n    fast_hadamard_transform_12N_launch<1, 2, input_t>(params, device);\n  } else if (params.log_N == 3) {\n    fast_hadamard_transform_12N_launch<2, 3, input_t>(params, device);\n  } else 
if (params.log_N == 4) {\n    fast_hadamard_transform_12N_launch<4, 4, input_t>(params, device);\n  } else if (params.log_N == 5) {\n    fast_hadamard_transform_12N_launch<8, 5, input_t>(params, device);\n  } else if (params.log_N == 6) {\n    fast_hadamard_transform_12N_launch<16, 6, input_t>(params, device);\n  } else if (params.log_N == 7) {\n    fast_hadamard_transform_12N_launch<32, 7, input_t>(params, device);\n  } else if (params.log_N == 8) {\n    fast_hadamard_transform_12N_launch<64, 8, input_t>(params, device);\n  } else if (params.log_N == 9) {\n    fast_hadamard_transform_12N_launch<128, 9, input_t>(params, device);\n  } else if (params.log_N == 10) {\n    fast_hadamard_transform_12N_launch<256, 10, input_t>(params, device);\n  } else {\n    host::Panic(\"fast_hadamard_transform_12N: unsupported log_N=\", params.log_N);\n  }\n}\n\ntemplate <int kNThreads, int kLogN, typename input_t>\ninline void fast_hadamard_transform_20N_launch(HadamardParamsBase& params, DLDevice device) {\n  using Ktraits = FastHadamard20NTraits<kNThreads, kLogN, input_t>;\n  launch_kernel<Ktraits>(params, device);\n}\n\ntemplate <typename input_t>\ninline void fast_hadamard_transform_20N_cuda(HadamardParamsBase& params, DLDevice device) {\n  if (params.log_N == 2) {\n    fast_hadamard_transform_20N_launch<1, 2, input_t>(params, device);\n  } else if (params.log_N == 3) {\n    fast_hadamard_transform_20N_launch<2, 3, input_t>(params, device);\n  } else if (params.log_N == 4) {\n    fast_hadamard_transform_20N_launch<4, 4, input_t>(params, device);\n  } else if (params.log_N == 5) {\n    fast_hadamard_transform_20N_launch<8, 5, input_t>(params, device);\n  } else if (params.log_N == 6) {\n    fast_hadamard_transform_20N_launch<16, 6, input_t>(params, device);\n  } else if (params.log_N == 7) {\n    fast_hadamard_transform_20N_launch<32, 7, input_t>(params, device);\n  } else if (params.log_N == 8) {\n    fast_hadamard_transform_20N_launch<64, 8, input_t>(params, device);\n  } else 
if (params.log_N == 9) {\n    fast_hadamard_transform_20N_launch<128, 9, input_t>(params, device);\n  } else if (params.log_N == 10) {\n    fast_hadamard_transform_20N_launch<256, 10, input_t>(params, device);\n  } else {\n    host::Panic(\"fast_hadamard_transform_20N: unsupported log_N=\", params.log_N);\n  }\n}\n\ntemplate <int kNThreads, int kLogN, typename input_t>\ninline void fast_hadamard_transform_28N_launch(HadamardParamsBase& params, DLDevice device) {\n  using Ktraits = FastHadamard28NTraits<kNThreads, kLogN, input_t>;\n  launch_kernel<Ktraits>(params, device);\n}\n\ntemplate <typename input_t>\ninline void fast_hadamard_transform_28N_cuda(HadamardParamsBase& params, DLDevice device) {\n  if (params.log_N == 2) {\n    fast_hadamard_transform_28N_launch<1, 2, input_t>(params, device);\n  } else if (params.log_N == 3) {\n    fast_hadamard_transform_28N_launch<2, 3, input_t>(params, device);\n  } else if (params.log_N == 4) {\n    fast_hadamard_transform_28N_launch<4, 4, input_t>(params, device);\n  } else if (params.log_N == 5) {\n    fast_hadamard_transform_28N_launch<8, 5, input_t>(params, device);\n  } else if (params.log_N == 6) {\n    fast_hadamard_transform_28N_launch<16, 6, input_t>(params, device);\n  } else if (params.log_N == 7) {\n    fast_hadamard_transform_28N_launch<32, 7, input_t>(params, device);\n  } else if (params.log_N == 8) {\n    fast_hadamard_transform_28N_launch<64, 8, input_t>(params, device);\n  } else if (params.log_N == 9) {\n    fast_hadamard_transform_28N_launch<128, 9, input_t>(params, device);\n  } else if (params.log_N == 10) {\n    fast_hadamard_transform_28N_launch<256, 10, input_t>(params, device);\n  } else {\n    host::Panic(\"fast_hadamard_transform_28N: unsupported log_N=\", params.log_N);\n  }\n}\n\ntemplate <int kNThreads, int kLogN, typename input_t>\ninline void fast_hadamard_transform_40N_launch(HadamardParamsBase& params, DLDevice device) {\n  using Ktraits = FastHadamard40NTraits<kNThreads, kLogN, input_t>;\n  
launch_kernel<Ktraits>(params, device);\n}\n\ntemplate <typename input_t>\ninline void fast_hadamard_transform_40N_cuda(HadamardParamsBase& params, DLDevice device) {\n  if (params.log_N == 2) {\n    fast_hadamard_transform_40N_launch<1, 2, input_t>(params, device);\n  } else if (params.log_N == 3) {\n    fast_hadamard_transform_40N_launch<2, 3, input_t>(params, device);\n  } else if (params.log_N == 4) {\n    fast_hadamard_transform_40N_launch<4, 4, input_t>(params, device);\n  } else if (params.log_N == 5) {\n    fast_hadamard_transform_40N_launch<8, 5, input_t>(params, device);\n  } else if (params.log_N == 6) {\n    fast_hadamard_transform_40N_launch<16, 6, input_t>(params, device);\n  } else if (params.log_N == 7) {\n    fast_hadamard_transform_40N_launch<32, 7, input_t>(params, device);\n  } else if (params.log_N == 8) {\n    fast_hadamard_transform_40N_launch<64, 8, input_t>(params, device);\n  } else if (params.log_N == 9) {\n    fast_hadamard_transform_40N_launch<128, 9, input_t>(params, device);\n  } else if (params.log_N == 10) {\n    fast_hadamard_transform_40N_launch<256, 10, input_t>(params, device);\n  } else {\n    host::Panic(\"fast_hadamard_transform_40N: unsupported log_N=\", params.log_N);\n  }\n}\n\ninline void set_hadamard_params(\n    HadamardParamsBase& params,\n    int64_t batch,\n    int64_t dim,\n    int64_t multiple,\n    const tvm::ffi::TensorView x,\n    const tvm::ffi::TensorView out,\n    float scale) {\n  std::memset(&params, 0, sizeof(params));\n  params.batch = static_cast<int>(batch);\n  params.dim = static_cast<int>(dim);\n  params.log_N = ceil_log2(static_cast<int>(dim / multiple));\n  params.x_ptr = const_cast<void*>(x.data_ptr());\n  params.out_ptr = const_cast<void*>(out.data_ptr());\n  params.x_batch_stride = x.stride(0);\n  params.out_batch_stride = out.stride(0);\n  params.scale = scale;\n}\n\ntemplate <int kMultiple, typename DType>\ninline void run_hadamard(const tvm::ffi::TensorView x, const tvm::ffi::TensorView out, 
float scale) {\n  using namespace host;\n\n  auto N = SymbolicSize{\"batch\"};\n  auto D = SymbolicSize{\"dim\"};\n  auto SX = SymbolicSize{\"x_batch_stride\"};\n  auto SO = SymbolicSize{\"out_batch_stride\"};\n  auto device = SymbolicDevice{};\n  device.set_options<kDLCUDA>();\n\n  TensorMatcher({N, D}).with_strides({SX, 1}).with_dtype<DType>().with_device(device).verify(x);\n  TensorMatcher({N, D}).with_strides({SO, 1}).with_dtype<DType>().with_device(device).verify(out);\n\n  const int64_t batch = N.unwrap();\n  const int64_t dim = D.unwrap();\n\n  RuntimeCheck(dim % kMultiple == 0, \"hadamard: dim must be divisible by \", kMultiple);\n\n  HadamardParamsBase params;\n  set_hadamard_params(params, batch, dim, kMultiple, x, out, scale);\n\n  if constexpr (kMultiple == 1) {\n    RuntimeCheck(dim % 8 == 0, \"fast_hadamard_transform only supports hidden dim divisible by 8\");\n    RuntimeCheck(dim <= 32768, \"fast_hadamard_transform only supports hidden dim <= 32768\");\n    fast_hadamard_transform_cuda<DType>(params, device.unwrap());\n  } else if constexpr (kMultiple == 12) {\n    RuntimeCheck(dim % (4 * 12) == 0, \"fast_hadamard_transform_12N only supports hidden dim divisible by 48\");\n    RuntimeCheck(dim <= 12 * 1024, \"fast_hadamard_transform_12N only supports hidden dim <= 12288\");\n    fast_hadamard_transform_12N_cuda<DType>(params, device.unwrap());\n  } else if constexpr (kMultiple == 20) {\n    RuntimeCheck(dim % (4 * 20) == 0, \"fast_hadamard_transform_20N only supports hidden dim divisible by 80\");\n    RuntimeCheck(dim <= 20 * 1024, \"fast_hadamard_transform_20N only supports hidden dim <= 20480\");\n    fast_hadamard_transform_20N_cuda<DType>(params, device.unwrap());\n  } else if constexpr (kMultiple == 28) {\n    RuntimeCheck(dim % (4 * 28) == 0, \"fast_hadamard_transform_28N only supports hidden dim divisible by 112\");\n    RuntimeCheck(dim <= 28 * 1024, \"fast_hadamard_transform_28N only supports hidden dim <= 28672\");\n    
fast_hadamard_transform_28N_cuda<DType>(params, device.unwrap());\n  } else if constexpr (kMultiple == 40) {\n    RuntimeCheck(dim % (4 * 40) == 0, \"fast_hadamard_transform_40N only supports hidden dim divisible by 160\");\n    RuntimeCheck(dim <= 40 * 1024, \"fast_hadamard_transform_40N only supports hidden dim <= 40960\");\n    fast_hadamard_transform_40N_cuda<DType>(params, device.unwrap());\n  } else {\n    Panic(\"Unsupported multiple\");\n  }\n}\n\ntemplate <typename DType>\nstruct HadamardKernel {\n  static void run(const tvm::ffi::TensorView x, const tvm::ffi::TensorView out, float scale) {\n    run_hadamard<1, DType>(x, out, scale);\n  }\n};\n\ntemplate <typename DType>\nstruct Hadamard12NKernel {\n  static void run(const tvm::ffi::TensorView x, const tvm::ffi::TensorView out, float scale) {\n    run_hadamard<12, DType>(x, out, scale);\n  }\n};\n\ntemplate <typename DType>\nstruct Hadamard20NKernel {\n  static void run(const tvm::ffi::TensorView x, const tvm::ffi::TensorView out, float scale) {\n    run_hadamard<20, DType>(x, out, scale);\n  }\n};\n\ntemplate <typename DType>\nstruct Hadamard28NKernel {\n  static void run(const tvm::ffi::TensorView x, const tvm::ffi::TensorView out, float scale) {\n    run_hadamard<28, DType>(x, out, scale);\n  }\n};\n\ntemplate <typename DType>\nstruct Hadamard40NKernel {\n  static void run(const tvm::ffi::TensorView x, const tvm::ffi::TensorView out, float scale) {\n    run_hadamard<40, DType>(x, out, scale);\n  }\n};\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/fast-hadamard-transform/static_switch.h",
    "content": "// Copied from https://github.com/sgl-project/fast-hadamard-transform\n\n// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h\n// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h\n\n#pragma once\n\n/// @param COND       - a boolean expression to switch by\n/// @param CONST_NAME - a name given for the constexpr bool variable.\n/// @param ...       - code to execute for true and false\n///\n/// Usage:\n/// ```\n/// BOOL_SWITCH(flag, BoolConst, [&] {\n///     some_function<BoolConst>(...);\n/// });\n/// ```\n#define BOOL_SWITCH(COND, CONST_NAME, ...)      \\\n  [&] {                                         \\\n    if (COND) {                                 \\\n      static constexpr bool CONST_NAME = true;  \\\n      return __VA_ARGS__();                     \\\n    } else {                                    \\\n      static constexpr bool CONST_NAME = false; \\\n      return __VA_ARGS__();                     \\\n    }                                           \\\n  }()\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/awq_dequantize.cuh",
    "content": "// Adapted from\n// https://github.com/vllm-project/vllm/blob/eb59b5a6cba6727d3727c0372258db9002f687c1/csrc/quantization/awq/gemm_kernels.cu#L350\n#pragma once\n\n#include <sgl_kernel/tensor.h>\n\n#include <sgl_kernel/utils.cuh>\n\nnamespace device::awq {\n\ntemplate <int lut>\n__device__ inline int lop3(int a, int b, int c) {\n  int res;\n  asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(res) : \"r\"(a), \"r\"(b), \"r\"(c), \"n\"(lut));\n  return res;\n}\n\n__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750\n  uint4 result;\n\n  uint32_t* h = reinterpret_cast<uint32_t*>(&result);\n  uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);\n\n  // First, we extract the i4s and construct an intermediate fp16 number.\n  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;\n  static constexpr uint32_t BOTTOM_MASK = 0x000f000f;\n  static constexpr uint32_t TOP_MASK = 0x00f000f0;\n  static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;\n\n  // Shift right by 8 to now consider elt_45 and elt_67. 
Issue first to hide RAW\n  // dependency if we issue immediately before required.\n  const uint32_t top_i4s = i4s >> 8;\n  // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400\n  asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n               : \"=r\"(h[0])\n               : \"r\"(i4s), \"n\"(BOTTOM_MASK), \"n\"(I4s_TO_F16s_MAGIC_NUM), \"n\"(immLut));\n  // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400\n  asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n               : \"=r\"(h[1])\n               : \"r\"(i4s), \"n\"(TOP_MASK), \"n\"(I4s_TO_F16s_MAGIC_NUM), \"n\"(immLut));\n  // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400\n  asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n               : \"=r\"(h[2])\n               : \"r\"(top_i4s), \"n\"(BOTTOM_MASK), \"n\"(I4s_TO_F16s_MAGIC_NUM), \"n\"(immLut));\n  // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400\n  asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n               : \"=r\"(h[3])\n               : \"r\"(top_i4s), \"n\"(TOP_MASK), \"n\"(I4s_TO_F16s_MAGIC_NUM), \"n\"(immLut));\n\n  // This is the half2 {1024, 1024} represented as an integer.\n  static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;\n  // This is the half2 {1 / 16, 1 / 16} represented as an integer.\n  static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;\n  // This is the half2 {-64, -64} represented as an integer.\n  static constexpr uint32_t NEG_64 = 0xd400d400;\n\n  // Finally, we construct the output numbers.\n  // Convert elt_01\n  asm volatile(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(h[0]) : \"r\"(h[0]), \"r\"(FP16_TOP_MAGIC_NUM));\n  // Convert elt_23\n  asm volatile(\"fma.rn.f16x2 %0, %1, %2, %3;\\n\" : \"=r\"(h[1]) : \"r\"(h[1]), \"r\"(ONE_SIXTEENTH), \"r\"(NEG_64));\n  // Convert elt_45\n  asm volatile(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(h[2]) : \"r\"(h[2]), \"r\"(FP16_TOP_MAGIC_NUM));\n  // Convert elt_67\n  asm volatile(\"fma.rn.f16x2 %0, %1, %2, %3;\\n\" : \"=r\"(h[3]) : \"r\"(h[3]), \"r\"(ONE_SIXTEENTH), 
\"r\"(NEG_64));\n\n  return result;\n#else\n  assert(false);\n  return {};\n#endif\n}\n\n__device__ uint4 dequantize_s4_to_bf16x2(uint32_t const& source) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n  uint4 result;\n  uint32_t* h = reinterpret_cast<uint32_t*>(&result);\n  uint32_t const i4s = source;\n\n  // Define masks and constants\n  static constexpr uint32_t MASK = 0x000f000f;\n  static constexpr uint32_t EX = 0x43004300;\n  static constexpr uint32_t MUL = 0x3F803F80;\n  static constexpr uint32_t ADD = 0xC300C300;\n\n  int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s, MASK, EX);\n  int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 4, MASK, EX);\n  int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 8, MASK, EX);\n  int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 12, MASK, EX);\n\n  nv_bfloat162* res = reinterpret_cast<nv_bfloat162*>(h);\n  res[0] = __hfma2(\n      *reinterpret_cast<nv_bfloat162*>(&lo0),\n      *reinterpret_cast<const nv_bfloat162*>(&MUL),\n      *reinterpret_cast<const nv_bfloat162*>(&ADD));\n  res[1] = __hfma2(\n      *reinterpret_cast<nv_bfloat162*>(&hi0),\n      *reinterpret_cast<const nv_bfloat162*>(&MUL),\n      *reinterpret_cast<const nv_bfloat162*>(&ADD));\n  res[2] = __hfma2(\n      *reinterpret_cast<nv_bfloat162*>(&lo1),\n      *reinterpret_cast<const nv_bfloat162*>(&MUL),\n      *reinterpret_cast<const nv_bfloat162*>(&ADD));\n  res[3] = __hfma2(\n      *reinterpret_cast<nv_bfloat162*>(&hi1),\n      *reinterpret_cast<const nv_bfloat162*>(&MUL),\n      *reinterpret_cast<const nv_bfloat162*>(&ADD));\n\n  return result;\n#else\n  assert(false);\n  return {};\n#endif\n}\n\ntemplate <typename OutputT>\n__global__ void __launch_bounds__(256) dequantize_weights(\n    int* __restrict__ qweight,\n    OutputT* __restrict__ scales,\n    int* __restrict__ qzeros,\n    OutputT* __restrict__ output,\n    int group_size,\n    int qweight_cols,\n    int qweight_rows) {\n  int col = blockIdx.x * blockDim.x + threadIdx.x;\n  int row = blockIdx.y * blockDim.y + 
threadIdx.y;\n  if (col >= qweight_cols || row >= qweight_rows) return;\n\n  int group_idx = row / group_size;\n  int scale_offset = 8 * col + group_idx * qweight_cols * 8;\n  uint4 loaded_scale = *(uint4*)(scales + scale_offset);\n\n  // Handle different data types\n  if constexpr (std::is_same<OutputT, half>::value) {\n    // FP16 path\n    uint4 zeros = dequantize_s4_to_fp16x2(qzeros[col + group_idx * qweight_cols]);\n    uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[col + row * qweight_cols]);\n\n    // Use PTX assembly for FP16 operations\n    asm volatile(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(weight_fp16.x) : \"r\"(weight_fp16.x), \"r\"(zeros.x));\n    asm volatile(\"mul.rn.f16x2 %0, %1, %2;\\n\" : \"=r\"(weight_fp16.x) : \"r\"(weight_fp16.x), \"r\"(loaded_scale.x));\n    asm volatile(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(weight_fp16.y) : \"r\"(weight_fp16.y), \"r\"(zeros.y));\n    asm volatile(\"mul.rn.f16x2 %0, %1, %2;\\n\" : \"=r\"(weight_fp16.y) : \"r\"(weight_fp16.y), \"r\"(loaded_scale.y));\n    asm volatile(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(weight_fp16.z) : \"r\"(weight_fp16.z), \"r\"(zeros.z));\n    asm volatile(\"mul.rn.f16x2 %0, %1, %2;\\n\" : \"=r\"(weight_fp16.z) : \"r\"(weight_fp16.z), \"r\"(loaded_scale.z));\n    asm volatile(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(weight_fp16.w) : \"r\"(weight_fp16.w), \"r\"(zeros.w));\n    asm volatile(\"mul.rn.f16x2 %0, %1, %2;\\n\" : \"=r\"(weight_fp16.w) : \"r\"(weight_fp16.w), \"r\"(loaded_scale.w));\n\n    OutputT* output_ptr = output + 8 * col + 8 * row * qweight_cols;\n    *(uint4*)output_ptr = weight_fp16;\n  } else if constexpr (std::is_same<OutputT, __nv_bfloat16>::value) {\n    uint4 weight_raw = dequantize_s4_to_bf16x2(qweight[col + row * qweight_cols]);\n    uint4 zero_raw = dequantize_s4_to_bf16x2(qzeros[col + group_idx * qweight_cols]);\n    uint4 scale_raw = *reinterpret_cast<uint4*>(scales + scale_offset);\n\n    // Vectorized processing (each uint4 contains 4 nv_bfloat162)\n    
nv_bfloat162* weight_vec = reinterpret_cast<nv_bfloat162*>(&weight_raw);\n    nv_bfloat162* zero_vec = reinterpret_cast<nv_bfloat162*>(&zero_raw);\n    nv_bfloat162* scale_vec = reinterpret_cast<nv_bfloat162*>(&scale_raw);\n\n// Single instruction dual-channel operation\n#pragma unroll\n    for (int i = 0; i < 4; ++i) {  // uint4 = 4 * nv_bfloat162\n      weight_vec[i] = __hmul2(__hsub2(weight_vec[i], zero_vec[i]), scale_vec[i]);\n    }\n\n    // Directly store to OutputT array (guaranteed contiguous memory)\n    OutputT* output_ptr = output + 8 * col + row * qweight_cols * 8;\n    static_assert(sizeof(uint4) == 8 * sizeof(OutputT), \"Memory layout mismatch\");\n    *reinterpret_cast<uint4*>(output_ptr) = weight_raw;\n  }\n}\n\n}  // namespace device::awq\n\n// Host wrapper\ntemplate <typename OutputT>\nvoid awq_dequantize(\n    tvm::ffi::TensorView output,\n    tvm::ffi::TensorView qweight,\n    tvm::ffi::TensorView scales,\n    tvm::ffi::TensorView qzeros) {\n  using namespace host;\n\n  int64_t qweight_rows = qweight.size(0);\n  int64_t qweight_cols = qweight.size(1);\n  int64_t scales_rows = scales.size(0);\n\n  // Validate tensors\n  SymbolicDevice cuda_device;\n  cuda_device.set_options<kDLCUDA>();\n\n  TensorMatcher({qweight_rows, qweight_cols}).with_dtype<int32_t>().with_device(cuda_device).verify(qweight);\n  TensorMatcher({scales_rows, qweight_cols * 8}).with_dtype<OutputT>().with_device(cuda_device).verify(scales);\n  TensorMatcher({scales_rows, qweight_cols}).with_dtype<int32_t>().with_device(cuda_device).verify(qzeros);\n  TensorMatcher({qweight_rows, qweight_cols * 8}).with_dtype<OutputT>().with_device(cuda_device).verify(output);\n\n  // Get device and stream\n  auto device = cuda_device.unwrap();\n  auto stream = LaunchKernel::resolve_device(device);\n\n  int group_size = static_cast<int>(qweight_rows / scales_rows);\n  int x_num_threads = 16;\n  int y_num_threads = 16;\n  int x_blocks = (static_cast<int>(qweight_cols) + x_num_threads - 1) / 
x_num_threads;\n  int y_blocks = (static_cast<int>(qweight_rows) + y_num_threads - 1) / y_num_threads;\n\n  dim3 num_blocks(x_blocks, y_blocks);\n  dim3 threads_per_block(x_num_threads, y_num_threads);\n\n  // Get pointers\n  auto* qweight_ptr = reinterpret_cast<int*>(qweight.data_ptr());\n  auto* scales_ptr = reinterpret_cast<OutputT*>(scales.data_ptr());\n  auto* qzeros_ptr = reinterpret_cast<int*>(qzeros.data_ptr());\n  auto* output_ptr = reinterpret_cast<OutputT*>(output.data_ptr());\n\n  LaunchKernel(num_blocks, threads_per_block, stream)(\n      device::awq::dequantize_weights<OutputT>,\n      qweight_ptr,\n      scales_ptr,\n      qzeros_ptr,\n      output_ptr,\n      group_size,\n      static_cast<int>(qweight_cols),\n      static_cast<int>(qweight_rows));\n}\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/marlin/awq_marlin_repack.cuh",
    "content": "#pragma once\n\n#include <sgl_kernel/tensor.h>\n\n#include <sgl_kernel/utils.cuh>\n\n#include \"marlin.cuh\"\n\nnamespace device::marlin {\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\ntemplate <int const num_threads, int const num_bits>\n__global__ void awq_marlin_repack_kernel(\n    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {\n  return;\n}\n#else\n\ntemplate <int const num_threads, int const num_bits>\n__global__ void awq_marlin_repack_kernel(\n    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {\n  constexpr int pack_factor = 32 / num_bits;\n\n  int k_tiles = size_k / tile_k_size;\n  int n_tiles = size_n / tile_n_size;\n  int block_k_tiles = div_ceil(k_tiles, (int)gridDim.x);\n\n  auto start_k_tile = blockIdx.x * block_k_tiles;\n  if (start_k_tile >= k_tiles) {\n    return;\n  }\n\n  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);\n\n  // Wait until the next thread tile has been loaded to shared memory.\n  auto wait_for_stage = [&]() {\n    // We only have `stages - 2` active fetches since we are double buffering\n    // and can only issue the next fetch when it is guaranteed that the previous\n    // shared memory load is fully complete (as it may otherwise be\n    // overwritten).\n    cp_async_wait<repack_stages - 2>();\n    __syncthreads();\n  };\n\n  extern __shared__ int4 sh[];\n\n  constexpr int tile_n_ints = tile_n_size / pack_factor;\n\n  constexpr int stage_n_threads = tile_n_ints / 4;\n  constexpr int stage_k_threads = tile_k_size;\n  constexpr int stage_size = stage_k_threads * stage_n_threads;\n\n  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {\n    if (n_tile_id >= n_tiles) {\n      cp_async_fence();\n      return;\n    }\n\n    int first_n = n_tile_id * tile_n_size;\n    int first_n_packed = first_n / pack_factor;\n\n    int4* sh_ptr = sh + stage_size * pipe;\n\n    if 
(threadIdx.x < stage_size) {\n      auto k_id = threadIdx.x / stage_n_threads;\n      auto n_id = threadIdx.x % stage_n_threads;\n\n      int first_k = k_tile_id * tile_k_size;\n\n      cp_async4(\n          &sh_ptr[k_id * stage_n_threads + n_id],\n          reinterpret_cast<int4 const*>(\n              &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)])));\n    }\n\n    cp_async_fence();\n  };\n\n  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {\n    if (n_tile_id >= n_tiles) {\n      return;\n    }\n\n    auto warp_id = threadIdx.x / 32;\n    auto th_id = threadIdx.x % 32;\n\n    if (warp_id >= 4) {\n      return;\n    }\n\n    int tc_col = th_id / 4;\n    int tc_row = (th_id % 4) * 2;\n\n    constexpr int tc_offsets[4] = {0, 1, 8, 9};\n\n    int cur_n = warp_id * 16 + tc_col;\n    int cur_n_packed = cur_n / pack_factor;\n    int cur_n_pos = cur_n % pack_factor;\n\n    constexpr int sh_stride = tile_n_ints;\n    constexpr uint32_t mask = (1 << num_bits) - 1;\n\n    int4* sh_stage_ptr = sh + stage_size * pipe;\n    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);\n\n    // Undo interleaving\n    int cur_n_pos_unpacked;\n    if constexpr (num_bits == 4) {\n      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};\n      cur_n_pos_unpacked = undo_pack[cur_n_pos];\n    } else {\n      constexpr int undo_pack[4] = {0, 2, 1, 3};\n      cur_n_pos_unpacked = undo_pack[cur_n_pos];\n    }\n\n    uint32_t vals[8];\n#pragma unroll\n    for (int i = 0; i < 4; i++) {\n      int cur_elem = tc_row + tc_offsets[i];\n\n      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];\n      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem];\n\n      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;\n      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;\n    }\n\n    constexpr int tile_size_val = 
tile_k_size * tile_n_size / pack_factor;\n    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size_val;\n\n    // Result of:\n    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h\n    if constexpr (num_bits == 4) {\n      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};\n\n      uint32_t res = 0;\n#pragma unroll\n      for (int i = 0; i < 8; i++) {\n        res |= vals[pack_idx[i]] << (i * 4);\n      }\n\n      out_ptr[out_offset + th_id * 4 + warp_id] = res;\n\n    } else {\n      constexpr int pack_idx[4] = {0, 2, 1, 3};\n\n      uint32_t res1 = 0;\n      uint32_t res2 = 0;\n#pragma unroll\n      for (int i = 0; i < 4; i++) {\n        res1 |= vals[pack_idx[i]] << (i * 8);\n        res2 |= vals[4 + pack_idx[i]] << (i * 8);\n      }\n\n      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;\n      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;\n    }\n  };\n\n  auto start_pipes = [&](int k_tile_id, int n_tile_id) {\n#pragma unroll\n    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {\n      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);\n    }\n\n    wait_for_stage();\n  };\n#pragma unroll\n  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {\n    int n_tile_id = 0;\n\n    start_pipes(k_tile_id, n_tile_id);\n\n    while (n_tile_id < n_tiles) {\n#pragma unroll\n      for (int pipe = 0; pipe < repack_stages; pipe++) {\n        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1);\n        repack_tile(pipe, k_tile_id, n_tile_id + pipe);\n        wait_for_stage();\n      }\n      n_tile_id += repack_stages;\n    }\n  }\n}\n#endif\n\n}  // namespace device::marlin\n\n// Host wrapper\nvoid awq_marlin_repack(\n    tvm::ffi::TensorView out, tvm::ffi::TensorView b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) {\n  using 
namespace host;\n  using namespace device::marlin;\n\n  // Validate alignment\n  RuntimeCheck(size_k % tile_k_size == 0, \"size_k = \", size_k, \" is not divisible by tile_k_size = \", tile_k_size);\n  RuntimeCheck(size_n % tile_n_size == 0, \"size_n = \", size_n, \" is not divisible by tile_n_size = \", tile_n_size);\n  RuntimeCheck(num_bits == 4 || num_bits == 8, \"num_bits must be 4 or 8. Got = \", num_bits);\n\n  int const pack_factor = 32 / num_bits;\n\n  // Validate tensors\n  SymbolicDevice cuda_device;\n  cuda_device.set_options<kDLCUDA>();\n\n  TensorMatcher({size_k, size_n / pack_factor}).with_dtype<int32_t>().with_device(cuda_device).verify(b_q_weight);\n\n  TensorMatcher({size_k / tile_size, size_n * tile_size / pack_factor})\n      .with_dtype<int32_t>()\n      .with_device(cuda_device)\n      .verify(out);\n\n  // Get device and stream\n  auto device = cuda_device.unwrap();\n  auto stream = LaunchKernel::resolve_device(device);\n\n  // Get pointers\n  auto* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());\n  auto* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());\n\n  // Get device attributes\n  int blocks = 0;\n  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, device.device_id);\n\n  int max_shared_mem = 0;\n  cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device.device_id);\n  RuntimeCheck(max_shared_mem > 0, \"max_shared_mem must be > 0\");\n\n  // Dispatch based on num_bits\n  if (num_bits == 4) {\n    cudaFuncSetAttribute(\n        awq_marlin_repack_kernel<repack_threads, 4>, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);\n    LaunchKernel(blocks, repack_threads, stream, max_shared_mem)(\n        awq_marlin_repack_kernel<repack_threads, 4>,\n        b_q_weight_ptr,\n        out_ptr,\n        static_cast<int>(size_k),\n        static_cast<int>(size_n));\n  } else if (num_bits == 8) {\n    cudaFuncSetAttribute(\n        
awq_marlin_repack_kernel<repack_threads, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);\n    LaunchKernel(blocks, repack_threads, stream, max_shared_mem)(\n        awq_marlin_repack_kernel<repack_threads, 8>,\n        b_q_weight_ptr,\n        out_ptr,\n        static_cast<int>(size_k),\n        static_cast<int>(size_n));\n  } else {\n    RuntimeCheck(false, \"Unsupported repack config: num_bits = \", num_bits);\n  }\n}\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/marlin/dequant.h",
    "content": "/*\nFast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)\n\nThe process of fast dequantization can be summarized as a combination\nof bitwise operations and floating-point computations:\n\nweight =>(bit_op / bitwise operations)=>\nf16_value =>(flop / floating-point computation)=>\ndequantized_weight\n\nSince the dequantized weights typically require subtracting the zero point and\napplying a scale factor, the floating-point computation step can be fused with\nthe zero-point subtraction and scaling operations.\n\nThe following are the parts that need to be modified for the fused operation\nof zero-point subtraction and scaling.\n\n## INT4 => FP16/BF16 or INT8 => FP16\n\nThe floating-point computation is `__hsub2`.\n\nIf there are zero points:\n\n    flop(bit_op(weight)) - flop(bit_op(zp))\n  = sub(bit_op(weight), bias) - sub(bit_op(zp), bias)\n  = bit_op(weight) - bit_op(zp)\n\nso no additional modification is needed.\n\nIf there are float zero points:\n\n    flop(bit_op(weight)) - fzp\n  = sub(bit_op(weight), bias) - fzp\n  = bit_op(weight) - (fzp + bias)\n\nwhere `fzp + bias` can be computed at weight loading. But this\nmay cause accuracy issues, so we should not use it in most cases.\n\nIf there are no zero points:\n\n    scale(flop(bit_op(weight)))\n  = scale(sub(bit_op(weight), bias))\n  = scale(bit_op(weight)) - scale(bias)\n  = fma(bit_op(weight), scale_factor, -scale(bias))\n\nwhere `scale(bias)` can be cached. But this may cause accuracy issues,\nso we should not use it in most cases.\n\n\n## INT8 => BF16\n\nINT8 => BF16 is a special case: it uses byte_perm instead of flop,\nand byte_perm cannot be fused with scaling.
\n\n\n## FP4/FP8 => FP16/BF16\n\n    scale(flop(bit_op(weight)))\n  = scale(mul(bit_op(weight), multiplier))\n  = mul(bit_op(weight), scale_factor * multiplier)\n\nwhere `scale_factor * multiplier` can be computed at weight loading.\n\n*/\n\n#include \"marlin_dtypes.cuh\"\n\nnamespace device::marlin {\n\n#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800\n// Lookup-table based 3-input logical operation; explicitly used for\n// dequantization as the compiler does not seem to automatically recognize it in\n// all cases.\ntemplate <int lut>\n__device__ inline int lop3(int a, int b, int c) {\n  int res;\n  asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(res) : \"r\"(a), \"r\"(b), \"r\"(c), \"n\"(lut));\n  return res;\n}\n\n// Constructs destination register by taking bytes from 2 sources (based on\n// mask)\ntemplate <int start_byte, int mask>\n__device__ inline uint32_t prmt(uint32_t a) {\n  uint32_t res;\n  asm volatile(\"prmt.b32 %0, %1, %2, %3;\\n\" : \"=r\"(res) : \"r\"(a), \"n\"(start_byte), \"n\"(mask));\n  return res;\n}\n\ntemplate <typename scalar_t2, host::ScalarTypeId w_type_id, bool skip_flop = false>\n__device__ inline void dequant(int q, scalar_t2* frag_b);\n\n//\n// Efficiently dequantize 4bit values packed in an int32 value into a full\n// B-fragment of 4 fp16 values. 
We mostly follow the strategy in the links below,\n// with some small changes:\n// - FP16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287\n// - BF16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385\n//\ntemplate <>\n__device__ inline void dequant<half2, host::kU4B8.id(), true>(int q, half2* frag_b) {\n  const int MASK = 0x000f000f;\n  const int EX = 0x64006400;\n  // Guarantee that the `(a & b) | c` operations are LOP3s.\n  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);\n  q >>= 4;\n  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);\n\n  frag_b[0] = *reinterpret_cast<half2*>(&lo);\n  frag_b[1] = *reinterpret_cast<half2*>(&hi);\n}\n\ntemplate <>\n__device__ inline void dequant<half2, host::kU4B8.id(), false>(int q, half2* frag_b) {\n  const int LO = 0x000f000f;\n  const int HI = 0x00f000f0;\n  const int EX = 0x64006400;\n  // Guarantee that the `(a & b) | c` operations are LOP3s.\n  // clang-format off\n  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);\n  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);\n  // clang-format on\n  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point\n  // directly into `SUB` and `ADD`.\n  const int SUB = 0x64086408;\n  const int MUL = 0x2c002c00;\n  const int ADD = 0xd480d480;\n  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), *reinterpret_cast<const half2*>(&SUB));\n  frag_b[1] = __hfma2(\n      *reinterpret_cast<half2*>(&hi), *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD));\n}\n\ntemplate <>\n__device__ inline void dequant<half2, host::kU4.id(), true>(int q, half2* frag_b) {\n  dequant<half2, host::kU4B8.id(), true>(q, frag_b);\n}\n\ntemplate <>\n__device__ inline void dequant<half2, host::kU4.id(), false>(int q, 
half2* frag_b) {\n  const int LO = 0x000f000f;\n  const int HI = 0x00f000f0;\n  const int EX = 0x64006400;\n  // Guarantee that the `(a & b) | c` operations are LOP3s.\n  // clang-format off\n  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);\n  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);\n  // clang-format on\n  // Unsigned int4 outputs (no implicit zero point): `SUB` and `ADD` only remove\n  // the magic-number bias introduced by `EX`.\n  const int SUB = 0x64006400;\n  const int MUL = 0x2c002c00;\n  const int ADD = 0xd400d400;\n  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), *reinterpret_cast<const half2*>(&SUB));\n  frag_b[1] = __hfma2(\n      *reinterpret_cast<half2*>(&hi), *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD));\n}\n\ntemplate <>\n__device__ inline void dequant<nv_bfloat162, host::kU4B8.id(), true>(int q, nv_bfloat162* frag_b) {\n  static constexpr uint32_t MASK = 0x000f000f;\n  static constexpr uint32_t EX = 0x43004300;\n\n  // Guarantee that the `(a & b) | c` operations are LOP3s.\n  // clang-format off\n  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);\n  q >>= 4;\n  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);\n  // clang-format on\n\n  frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&lo);\n  frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&hi);\n}\n\ntemplate <>\n__device__ inline void dequant<nv_bfloat162, host::kU4B8.id(), false>(int q, nv_bfloat162* frag_b) {\n  dequant<nv_bfloat162, host::kU4B8.id(), true>(q, frag_b);\n\n  static constexpr uint32_t SUB = 0x43084308;\n\n  frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));\n  frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));\n}\n\ntemplate <>\n__device__ inline void dequant<nv_bfloat162, host::kU4.id(), true>(int q, nv_bfloat162* frag_b) {\n  dequant<nv_bfloat162, host::kU4B8.id(), true>(q, frag_b);\n}\n\ntemplate <>\n__device__ inline void dequant<nv_bfloat162, 
host::kU4.id(), false>(int q, nv_bfloat162* frag_b) 
{\n  dequant<nv_bfloat162, host::kU4.id(), true>(q, frag_b);\n\n  static constexpr uint32_t SUB = 0x43004300;\n\n  frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));\n  frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));\n}\n\n//\n// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8-bit int values to fp16 or\n// bf16. References:\n// - FP16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85\n// - BF16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175\n//\ntemplate <>\n__device__ inline void dequant<half2, host::kU8B128.id(), true>(int q, half2* frag_b) {\n  static constexpr uint32_t mask_for_elt_01 = 0x5250;\n  static constexpr uint32_t mask_for_elt_23 = 0x5351;\n  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;\n\n  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);\n  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);\n\n  frag_b[0] = *reinterpret_cast<half2*>(&lo);\n  frag_b[1] = *reinterpret_cast<half2*>(&hi);\n}\n\ntemplate <>\n__device__ inline void dequant<half2, host::kU8B128.id(), false>(int q, half2* frag_b) {\n  dequant<half2, host::kU8B128.id(), true>(q, frag_b);\n\n  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;\n  frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n  frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n}\n\ntemplate <>\n__device__ inline void dequant<half2, host::kU8.id(), true>(int q, half2* frag_b) {\n  dequant<half2, host::kU8B128.id(), true>(q, frag_b);\n}\n\ntemplate <>\n__device__ inline void dequant<half2, host::kU8.id(), false>(int q, half2* frag_b) {\n  
dequant<half2, host::kU8.id(), true>(q, 
frag_b);\n\n  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;\n  frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n  frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n}\n\ntemplate <>\n__device__ inline void dequant<nv_bfloat162, host::kU8B128.id(), false>(int q, nv_bfloat162* frag_b) {\n  float fp32_intermediates[4];\n  uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);\n\n  static constexpr uint32_t fp32_base = 0x4B000000;\n  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);\n  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);\n  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);\n  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);\n\n  fp32_intermediates[0] -= 8388736.f;\n  fp32_intermediates[1] -= 8388736.f;\n  fp32_intermediates[2] -= 8388736.f;\n  fp32_intermediates[3] -= 8388736.f;\n\n  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);\n  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);\n  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);\n}\n\ntemplate <>\n__device__ inline void dequant<nv_bfloat162, host::kU8.id(), false>(int q, nv_bfloat162* frag_b) {\n  float fp32_intermediates[4];\n  uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);\n\n  static constexpr uint32_t fp32_base = 0x4B000000;\n  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);\n  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);\n  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);\n  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);\n\n  fp32_intermediates[0] -= 8388608.f;\n  fp32_intermediates[1] -= 8388608.f;\n  fp32_intermediates[2] -= 8388608.f;\n  fp32_intermediates[3] -= 
8388608.f;\n\n  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);\n  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);\n  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);\n}\n\ntemplate <>\n__device__ inline void dequant<half2, host::kFE4M3fn.id(), true>(int q, half2* frag_b) {\n  // Constants for FP8 (E4M3) and FP16 formats\n  constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;\n  constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;\n  constexpr int MASK = 0x7F007F00;\n\n  // Extract and shift FP8 values to FP16 format\n  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);\n  q <<= 8;\n  int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);\n\n  // Note: reverse indexing is intentional because weights are permuted\n  frag_b[1] = *reinterpret_cast<const half2*>(&Out1);\n  frag_b[0] = *reinterpret_cast<const half2*>(&Out2);\n}\n\ntemplate <>\n__device__ inline void dequant<half2, host::kFE4M3fn.id(), false>(int q, half2* frag_b) {\n  dequant<half2, host::kFE4M3fn.id(), true>(q, frag_b);\n\n  // Constants for FP8 (E4M3) and FP16 formats\n  constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;\n\n  // Construct and apply exponent bias\n  constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));\n  const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));\n\n  // Convert to half2 and apply bias\n  frag_b[1] = __hmul2(frag_b[1], bias_reg);\n  frag_b[0] = __hmul2(frag_b[0], bias_reg);\n}\n\ntemplate <>\n__device__ inline void dequant<nv_bfloat162, host::kFE4M3fn.id(), true>(int q, nv_bfloat162* frag_b) {\n  // Constants for FP8 (E4M3) and BF16 formats\n  constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;\n  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;\n\n  constexpr int MASK = 0x7F007F00;\n\n  // Extract and shift FP8 values to BF16 format\n  int Out1 = (q & 0x80008000) | ((q & MASK) >> 
RIGHT_SHIFT);\n  q <<= 8;\n  int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);\n\n  // Note: reverse indexing is intentional because weights are permuted\n  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);\n  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);\n}\n\ntemplate <>\n__device__ inline void dequant<nv_bfloat162, host::kFE4M3fn.id(), false>(int q, nv_bfloat162* frag_b) {\n  dequant<nv_bfloat162, host::kFE4M3fn.id(), true>(q, frag_b);\n\n  // Constants for FP8 (E4M3) and BF16 formats\n  constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;\n\n  // Construct and apply exponent bias\n  constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));\n  // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent\n  // position\n  constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;\n  const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));\n\n  // Convert to bfloat162 and apply bias\n  frag_b[1] = __hmul2(frag_b[1], bias_reg);\n  frag_b[0] = __hmul2(frag_b[0], bias_reg);\n}\n\ntemplate <>\n__device__ inline void dequant<half2, host::kFE2M1f.id(), true>(int q, half2* frag_b) {\n  // Constants for FP4 (E2M1) and FP16 formats\n  constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;\n  constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;\n  constexpr int MASK = 0x70007000;\n\n  // Extract and shift FP4 values to FP16 format\n  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);\n  q <<= 4;\n  int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);\n\n  // Note: reverse indexing is intentional because weights are permuted\n  frag_b[1] = *reinterpret_cast<const half2*>(&Out1);\n  frag_b[0] = *reinterpret_cast<const half2*>(&Out2);\n}\n\ntemplate <>\n__device__ inline void dequant<half2, host::kFE2M1f.id(), false>(int q, half2* frag_b) {\n  dequant<half2, host::kFE2M1f.id(), true>(q, frag_b);\n\n  // Constants for FP4 (E2M1) and FP16 formats\n  constexpr int 
FP4_EXPONENT = 2, FP16_EXPONENT = 5;\n\n  // Construct and apply exponent bias\n  constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));\n  const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));\n\n  // Convert to half2 and apply bias\n  frag_b[1] = __hmul2(frag_b[1], bias_reg);\n  frag_b[0] = __hmul2(frag_b[0], bias_reg);\n}\n\ntemplate <>\n__device__ inline void dequant<nv_bfloat162, host::kFE2M1f.id(), true>(int q, nv_bfloat162* frag_b) {\n  // Constants for FP4 (E2M1) and BF16 formats\n  constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;\n  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT;\n  constexpr int MASK = 0x70007000;\n\n  // Extract and shift FP4 values to BF16 format\n  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);\n  q <<= 4;\n  int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);\n\n  // Note: reverse indexing is intentional because weights are permuted\n  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);\n  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);\n}\n\ntemplate <>\n__device__ inline void dequant<nv_bfloat162, host::kFE2M1f.id(), false>(int q, nv_bfloat162* frag_b) {\n  dequant<nv_bfloat162, host::kFE2M1f.id(), true>(q, frag_b);\n\n  // Constants for FP4 (E2M1) and BF16 formats\n  constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;\n\n  // Construct and apply exponent bias\n  constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));\n  // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent\n  // position\n  constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;\n  const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));\n\n  // Convert to bfloat162 and apply bias\n  frag_b[1] = __hmul2(frag_b[1], bias_reg);\n  frag_b[0] = __hmul2(frag_b[0], bias_reg);\n}\n\ntemplate <typename scalar_t2>\n__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);\n\ntemplate <>\n__device__ 
inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {\n  int Out1 = (q & 0xFF00FF00) >> 1;\n  q <<= 8;\n  int Out2 = (q & 0xFF00FF00) >> 1;\n\n  // Note: reverse indexing is intentional because weights are permuted\n  frag_b[1] = *reinterpret_cast<const half2*>(&Out1);\n  frag_b[0] = *reinterpret_cast<const half2*>(&Out2);\n}\n\ntemplate <>\n__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q, nv_bfloat162* frag_b) {\n  constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;\n  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;\n  constexpr int MASK = 0x7F007F00;\n\n  // Extract and shift FP8 values to BF16 format\n  int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);\n  q <<= 8;\n  int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);\n\n  // Note: reverse indexing is intentional because weights are permuted\n  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);\n  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);\n}\n\n// New version with s_type_id parameter for marlin_moe_wna16_v2\ntemplate <typename scalar_t2, host::ScalarTypeId s_type_id>\n__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);\n\ntemplate <>\n__device__ inline void dequant_fp8_scales<half2, host::kFE4M3fn.id()>(int q, half2* frag_b) {\n  int Out1 = (q & 0xFF00FF00) >> 1;\n  q <<= 8;\n  int Out2 = (q & 0xFF00FF00) >> 1;\n\n  // Note: reverse indexing is intentional because weights are permuted\n  frag_b[1] = *reinterpret_cast<const half2*>(&Out1);\n  frag_b[0] = *reinterpret_cast<const half2*>(&Out2);\n}\n\ntemplate <>\n__device__ inline void dequant_fp8_scales<nv_bfloat162, host::kFE4M3fn.id()>(int q, nv_bfloat162* frag_b) {\n  constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;\n  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;\n  constexpr int MASK = 0x7F007F00;\n\n  // Extract and shift FP8 values to BF16 format\n  int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);\n  q <<= 8;\n  
int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);\n\n  // Note: reverse indexing is intentional because weights are permuted\n  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);\n  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);\n}\n\ntemplate <>\n__device__ inline void dequant_fp8_scales<nv_bfloat162, host::kFE8M0fnu.id()>(int q, nv_bfloat162* frag_b) {\n  // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16,\n  // but we assume that such an extreme value does not occur in real models.\n  int Out1 = (q & 0xFF00FF00) >> 1;\n  q <<= 7;\n  int Out2 = q & 0x7F807F80;\n\n  // Note: reverse indexing is intentional because weights are permuted\n  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);\n  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);\n}\n\n#endif\n\n}  // namespace device::marlin\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/marlin/gptq_marlin.cuh",
    "content": "/*\n * Modified by Neural Magic\n * Copyright (C) Marlin.2024 Elias Frantar\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/*\n * Adapted from https://github.com/IST-DASLab/marlin\n */\n\n#pragma once\n\n#include <sgl_kernel/tensor.h>\n\n#include <sgl_kernel/scalar_type.hpp>\n\n#include \"kernel.h\"\n#include \"marlin_template.h\"\n\nnamespace device::marlin {\n\n__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};\n\nusing MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n\n__global__ void permute_cols_kernel(\n    int4 const* __restrict__ a_int4_ptr,\n    int const* __restrict__ perm_int_ptr,\n    int4* __restrict__ out_int4_ptr,\n    int size_m,\n    int size_k,\n    int lda,\n    int block_rows) {}\n\n#else\n\n// For a given \"a\" of size [M,K] performs a permutation of the K columns based\n// on the given \"perm\" indices.\n__global__ void permute_cols_kernel(\n    int4 const* __restrict__ a_int4_ptr,\n    int const* __restrict__ perm_int_ptr,\n    int4* __restrict__ out_int4_ptr,\n    int size_m,\n    int size_k,\n    int lda,\n    int block_rows) {\n  auto start_row = block_rows * blockIdx.x;\n  int finish_row = start_row + block_rows;\n  if (finish_row > size_m) {\n    finish_row = size_m;\n  }\n  int cur_block_rows = finish_row - start_row;\n\n  int input_row_stride = lda * sizeof(half) / 16;\n  int output_row_stride = size_k * sizeof(half) / 16;\n\n  auto 
permute_row = [&](int row) {\n    int iters = size_k / default_threads;\n    int rest = size_k % default_threads;\n\n    int input_offset = row * input_row_stride;\n    int output_offset = row * output_row_stride;\n\n    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + input_offset);\n    half* out_half = reinterpret_cast<half*>(out_int4_ptr + output_offset);\n\n    int base_k = 0;\n\n    for (int i = 0; i < iters; i++) {\n      auto cur_k = base_k + threadIdx.x;\n      int src_pos = perm_int_ptr[cur_k];\n\n      out_half[cur_k] = a_row_half[src_pos];\n\n      base_k += default_threads;\n    }\n\n    if (rest) {\n      if (threadIdx.x < rest) {\n        auto cur_k = base_k + threadIdx.x;\n        int src_pos = perm_int_ptr[cur_k];\n\n        out_half[cur_k] = a_row_half[src_pos];\n      }\n    }\n  };\n\n  for (int i = 0; i < cur_block_rows; i++) {\n    int cur_row = start_row + i;\n    if (cur_row < size_m) {\n      permute_row(cur_row);\n    }\n  }\n}\n\ntypedef struct {\n  int thread_k;\n  int thread_n;\n  int num_threads;\n} thread_config_t;\n\nthread_config_t small_batch_thread_configs[] = {\n    // Ordered by priority\n\n    // thread_k, thread_n, num_threads\n    {128, 128, 256},\n    {64, 128, 128},\n    {128, 64, 128}};\n\nthread_config_t large_batch_thread_configs[] = {\n    // Ordered by priority\n\n    // thread_k, thread_n, num_threads\n    {64, 256, 256},\n    {64, 128, 128},\n    {128, 64, 128}};\n\ntypedef struct {\n  int blocks_per_sm;\n  thread_config_t tb_cfg;\n} exec_config_t;\n\nint get_scales_cache_size(\n    thread_config_t const& th_config,\n    int prob_m,\n    int prob_n,\n    int prob_k,\n    int num_bits,\n    int group_size,\n    bool has_act_order,\n    bool is_k_full) {\n  bool cache_scales_chunk = has_act_order && !is_k_full;\n\n  int tb_n = th_config.thread_n;\n  int tb_k = th_config.thread_k;\n\n  // Get max scale groups per thread-block\n  int tb_groups;\n  if (group_size == -1) {\n    tb_groups = 1;\n  } else 
if (group_size == 0) {\n    tb_groups = div_ceil(tb_k, 32);  // Worst case is 32 group size\n  } else {\n    tb_groups = div_ceil(tb_k, group_size);\n  }\n\n  if (cache_scales_chunk) {\n    int load_groups = tb_groups * pipe_stages * 2;  // Chunk size is 2x pipeline over dim K\n    load_groups = max(load_groups, 32);             // We load at least 32 scale groups\n    return load_groups * tb_n * 2;\n  } else {\n    int tb_scales = tb_groups * tb_n * 2;\n\n    return tb_scales * pipe_stages;\n  }\n}\n\nint get_kernel_cache_size(\n    thread_config_t const& th_config,\n    int thread_m_blocks,\n    int prob_m,\n    int prob_n,\n    int prob_k,\n    int num_bits,\n    int group_size,\n    bool has_act_order,\n    bool is_k_full,\n    int has_zp,\n    int is_zp_float) {\n  int pack_factor = 32 / num_bits;\n\n  // Get B size\n  int tb_k = th_config.thread_k;\n  int tb_n = th_config.thread_n;\n  int tb_m = thread_m_blocks * 16;\n  int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;\n  int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;\n  int sh_red_size = tb_m * (tb_n + 8);\n  int sh_s_size =\n      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full);\n  int sh_g_idx_size = has_act_order && !is_k_full ? 
pipe_stages * tb_k / 4 : 0;\n  int sh_zp_size = 0;\n  if (has_zp) {\n    if (is_zp_float)\n      sh_zp_size = sh_s_size;\n    else if (num_bits == 4)\n      sh_zp_size = sh_s_size / 4;\n    else if (num_bits == 8)\n      sh_zp_size = sh_s_size / 2;\n  }\n\n  int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size;\n\n  return total_size;\n}\n\nbool is_valid_config(\n    thread_config_t const& th_config,\n    int thread_m_blocks,\n    int prob_m,\n    int prob_n,\n    int prob_k,\n    int num_bits,\n    int group_size,\n    bool has_act_order,\n    bool is_k_full,\n    int has_zp,\n    int is_zp_float,\n    int max_shared_mem) {\n  // Sanity\n  if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) {\n    return false;\n  }\n\n  // Verify K/N are divisible by thread K/N\n  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {\n    return false;\n  }\n\n  // Verify min for thread K/N\n  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {\n    return false;\n  }\n\n  // num_threads must be at least 128 (= 4 warps)\n  if (th_config.num_threads < 128) {\n    return false;\n  }\n\n  // Check that pipeline fits into cache\n  int cache_size = get_kernel_cache_size(\n      th_config,\n      thread_m_blocks,\n      prob_m,\n      prob_n,\n      prob_k,\n      num_bits,\n      group_size,\n      has_act_order,\n      is_k_full,\n      has_zp,\n      is_zp_float);\n  return cache_size <= max_shared_mem;\n}\n\n#define _GET_IF(                                                                                                       \\\n    W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \\\n  else if (                                                                                                            \\\n      q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == 
THREAD_N_BLOCKS &&                  \\\n      thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS &&        \\\n      num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) {                                                      \\\n    kernel = Marlin<                                                                                                   \\\n        scalar_t,                                                                                                      \\\n        W_TYPE.id(),                                                                                                   \\\n        NUM_THREADS,                                                                                                   \\\n        THREAD_M_BLOCKS,                                                                                               \\\n        THREAD_N_BLOCKS,                                                                                               \\\n        THREAD_K_BLOCKS,                                                                                               \\\n        M_BLOCK_SIZE_8,                                                                                                \\\n        pipe_stages,                                                                                                   \\\n        GROUP_BLOCKS,                                                                                                  \\\n        IS_ZP_FLOAT>;                                                                                                  \\\n  }\n\n// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)\n//         this is the most common cases\n// BIGGROUP: cases for big group size (group_blocks in [-1, 8])\n// FZP: cases for float-zero-point (is_zp_float = true)\n// ACT: cases for act order case (group_blocks == 0)\n// FP4: cases for nvfp4(e2m1) (group_blocks == 
1)\n#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)       \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false)   \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false)   \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false)   \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)\n\n#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)     \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)  \\\n                                                                        \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)  \\\n                                                                        \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)\n\n#define COMMON_GET_IF(W_TYPE)            \\\n  COMMON_GET_IF_M1(W_TYPE, 8, 8, 256)    \\\n  COMMON_GET_IF_M1(W_TYPE, 8, 4, 128)    \\\n  COMMON_GET_IF_M1(W_TYPE, 4, 8, 128)    
\\\n  COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \\\n  COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)  \\\n  COMMON_GET_IF_M234(W_TYPE, 4, 8, 128)\n\n#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)     \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false)   \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)\n\n#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)   \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)\n\n#define BIGGROUP_GET_IF(W_TYPE)            \\\n  BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256)    \\\n  BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128)    \\\n  BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128)    \\\n  BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \\\n  BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)  \\\n  BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)\n\n#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)        \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)\n\n#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)       \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)\n\n#define FP4_GET_IF(W_TYPE)            \\\n  FP4_GET_IF_M1(W_TYPE, 8, 8, 256)    \\\n  
FP4_GET_IF_M1(W_TYPE, 8, 4, 128)    \\\n  FP4_GET_IF_M1(W_TYPE, 4, 8, 128)    \\\n  FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \\\n  FP4_GET_IF_M234(W_TYPE, 8, 4, 128)  \\\n  FP4_GET_IF_M234(W_TYPE, 4, 8, 128)\n\n// We currently have 4-bit models only with group_blocks == 4\n#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)       \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)\n\n#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)      \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)\n\n#define FZP_GET_IF(W_TYPE)            \\\n  FZP_GET_IF_M1(W_TYPE, 8, 8, 256)    \\\n  FZP_GET_IF_M1(W_TYPE, 8, 4, 128)    \\\n  FZP_GET_IF_M1(W_TYPE, 4, 8, 128)    \\\n  FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \\\n  FZP_GET_IF_M234(W_TYPE, 8, 4, 128)  \\\n  FZP_GET_IF_M234(W_TYPE, 4, 8, 128)\n\n// Act-order cases always use group_blocks == 0 (scale groups come from g_idx)\n#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)        \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)\n\n#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)       \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)\n\n#define ACT_GET_IF(W_TYPE)            \\\n  ACT_GET_IF_M1(W_TYPE, 8, 8, 256)    \\\n  ACT_GET_IF_M1(W_TYPE, 8, 4, 128)    \\\n  ACT_GET_IF_M1(W_TYPE, 4, 8, 128)    \\\n  ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \\\n  ACT_GET_IF_M234(W_TYPE, 8, 4, 128)  \\\n  ACT_GET_IF_M234(W_TYPE, 4, 8, 128)\n\ntemplate <typename scalar_t>\nMarlinFuncPtr 
get_marlin_kernel(\n    const host::ScalarType q_type,\n    int thread_m_blocks,\n    int thread_n_blocks,\n    int thread_k_blocks,\n    bool m_block_size_8,\n    bool has_act_order,\n    bool has_zp,\n    int group_blocks,\n    int num_threads,\n    bool is_zp_float) {\n  int num_bits = q_type.size_bits();\n  auto kernel = MarlinDefault;\n  if (false) {\n  }\n\n  COMMON_GET_IF(host::kU4)\n  COMMON_GET_IF(host::kU4B8)\n  COMMON_GET_IF(host::kU8B128)\n\n  FP4_GET_IF(host::kFE2M1f)\n\n  BIGGROUP_GET_IF(host::kFE4M3fn)\n\n  ACT_GET_IF(host::kU4B8)\n  ACT_GET_IF(host::kU8B128)\n\n  if (std::is_same<scalar_t, half>::value) {\n    if (false) {\n    }\n    FZP_GET_IF(host::kU4)\n  }\n\n  return kernel;\n}\n\ntemplate <typename scalar_t>\nexec_config_t determine_exec_config(\n    const host::ScalarType& q_type,\n    int prob_m,\n    int prob_n,\n    int prob_k,\n    int thread_m_blocks,\n    bool m_block_size_8,\n    int num_bits,\n    int group_size,\n    bool has_act_order,\n    bool is_k_full,\n    bool has_zp,\n    bool is_zp_float,\n    int max_shared_mem,\n    int sms) {\n  exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};\n  thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs;\n  int thread_configs_size = thread_m_blocks > 1 ? 
sizeof(large_batch_thread_configs) / sizeof(thread_config_t)\n                                                : sizeof(small_batch_thread_configs) / sizeof(thread_config_t);\n\n  for (int i = 0; i < thread_configs_size; i++) {\n    thread_config_t th_config = thread_configs[i];\n\n    if (!is_valid_config(\n            th_config,\n            thread_m_blocks,\n            prob_m,\n            prob_n,\n            prob_k,\n            num_bits,\n            group_size,\n            has_act_order,\n            is_k_full,\n            has_zp,\n            is_zp_float,\n            max_shared_mem)) {\n      continue;\n    }\n\n    int cache_size = get_kernel_cache_size(\n        th_config,\n        thread_m_blocks,\n        prob_m,\n        prob_n,\n        prob_k,\n        num_bits,\n        group_size,\n        has_act_order,\n        is_k_full,\n        has_zp,\n        is_zp_float);\n\n    int group_blocks = 0;\n    if (!has_act_order) {\n      group_blocks = group_size == -1 ? -1 : group_size / 16;\n    }\n\n    auto kernel = get_marlin_kernel<scalar_t>(\n        q_type,\n        thread_m_blocks,\n        th_config.thread_n / 16,\n        th_config.thread_k / 16,\n        m_block_size_8,\n        has_act_order,\n        has_zp,\n        group_blocks,\n        th_config.num_threads,\n        is_zp_float);\n\n    if (kernel == MarlinDefault) continue;\n\n    // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16);\n    // int n_tiles = prob_n / th_config.thread_n;\n    // int k_tiles = prob_k / th_config.thread_k;\n\n    return {1, th_config};\n  }\n\n  return exec_cfg;\n}\n\ntemplate <typename scalar_t>\nvoid marlin_mm(\n    const void* A,\n    const void* B,\n    void* C,\n    void* C_tmp,\n    void* s,\n    void* s2,\n    void* zp,\n    void* g_idx,\n    void* perm,\n    void* a_tmp,\n    int prob_m,\n    int prob_n,\n    int prob_k,\n    int lda,\n    void* workspace,\n    host::ScalarType const& q_type,\n    bool has_act_order,\n    bool is_k_full,\n    bool 
has_zp,\n    int num_groups,\n    int group_size,\n    int dev,\n    cudaStream_t stream,\n    int thread_k_init,\n    int thread_n_init,\n    int sms,\n    bool use_atomic_add,\n    bool use_fp32_reduce,\n    bool is_zp_float) {\n  if (has_zp) {\n    host::RuntimeCheck(\n        q_type == host::kU4 || q_type == host::kU8, \"q_type must be u4 or u8 when has_zp = True. Got = \", q_type.str());\n  } else {\n    host::RuntimeCheck(\n        q_type == host::kU4B8 || q_type == host::kU8B128 || q_type == host::kFE4M3fn || q_type == host::kFE2M1f,\n        \"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when \"\n        \"has_zp = False. Got = \",\n        q_type.str());\n  }\n\n  host::RuntimeCheck(\n      prob_m > 0 && prob_n > 0 && prob_k > 0, \"Invalid MNK = [\", prob_m, \", \", prob_n, \", \", prob_k, \"]\");\n\n  int group_blocks = 0;\n  if (has_act_order) {\n    if (is_k_full) {\n      host::RuntimeCheck(group_size != -1);\n      group_blocks = group_size / 16;\n      host::RuntimeCheck(\n          prob_k % group_blocks == 0, \"prob_k = \", prob_k, \" is not divisible by group_blocks = \", group_blocks);\n    } else {\n      host::RuntimeCheck(group_size == 0);\n      group_blocks = 0;\n    }\n  } else {\n    if (group_size == -1) {\n      group_blocks = -1;\n    } else {\n      group_blocks = group_size / 16;\n      host::RuntimeCheck(\n          prob_k % group_blocks == 0, \"prob_k = \", prob_k, \" is not divisible by group_blocks = \", group_blocks);\n    }\n  }\n\n  int num_bits = q_type.size_bits();\n  const int4* A_ptr = (const int4*)A;\n  const int4* B_ptr = (const int4*)B;\n  int4* C_ptr = (int4*)C;\n  int4* C_tmp_ptr = (int4*)C_tmp;\n  const int4* s_ptr = (const int4*)s;\n  const uint16_t* s2_ptr = (const uint16_t*)s2;\n  const int4* zp_ptr = (const int4*)zp;\n  const int* g_idx_ptr = (const int*)g_idx;\n  const int* perm_ptr = (const int*)perm;\n  int4* a_tmp_ptr = (int4*)a_tmp;\n\n  int* locks = (int*)workspace;\n\n  if (has_act_order) 
{\n    // Permute A columns\n    int block_rows = div_ceil(prob_m, sms);\n    host::LaunchKernel(sms, default_threads, stream)(\n        permute_cols_kernel, A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows);\n    A_ptr = a_tmp_ptr;\n    lda = prob_k;\n\n    // If we have a full K, then we can run the non-act-order version of Marlin\n    // (since the weight rows are reordered by increasing group ids, and by\n    // having a full K, we have full original groups)\n    if (is_k_full) has_act_order = false;\n  }\n\n  int max_shared_mem = 0;\n  host::RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev));\n  host::RuntimeCheck(max_shared_mem > 0);\n\n  int max_par = 16;\n  if (prob_n <= 4096) max_par = 16 * 8;\n  int max_shared_mem_new = max_shared_mem;\n  int rest_m = prob_m;\n  int max_thread_m_blocks = 4;\n  while (rest_m) {\n    int par_count = rest_m / (max_thread_m_blocks * 16);\n    if (par_count > max_par) par_count = max_par;\n    int prob_m_split = par_count > 0 ? 
(par_count * (max_thread_m_blocks * 16)) : rest_m;\n\n    int thread_k = thread_k_init;\n    int thread_n = thread_n_init;\n\n    int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks);\n    int m_block_size_8 = prob_m_split <= 8;\n\n    // Set thread config\n    exec_config_t exec_cfg;\n    thread_config_t thread_tfg;\n    if (thread_k != -1 && thread_n != -1) {\n      thread_tfg = thread_config_t{thread_k, thread_n, default_threads};\n      exec_cfg = exec_config_t{1, thread_tfg};\n      host::RuntimeCheck(prob_n % thread_n == 0, \"prob_n = \", prob_n, \" is not divisible by thread_n = \", thread_n);\n      host::RuntimeCheck(prob_k % thread_k == 0, \"prob_k = \", prob_k, \" is not divisible by thread_k = \", thread_k);\n    } else {\n      // Auto config\n      exec_cfg = determine_exec_config<scalar_t>(\n          q_type,\n          prob_m_split,\n          prob_n,\n          prob_k,\n          thread_m_blocks,\n          m_block_size_8,\n          num_bits,\n          group_size,\n          has_act_order,\n          is_k_full,\n          has_zp,\n          is_zp_float,\n          max_shared_mem,\n          sms);\n      thread_tfg = exec_cfg.tb_cfg;\n      if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) {\n        max_thread_m_blocks--;\n        continue;\n      }\n    }\n\n    int num_threads = thread_tfg.num_threads;\n    thread_k = thread_tfg.thread_k;\n    thread_n = thread_tfg.thread_n;\n    int blocks = sms * exec_cfg.blocks_per_sm;\n    if (exec_cfg.blocks_per_sm > 1) max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024;\n\n    int thread_k_blocks = thread_k / 16;\n    int thread_n_blocks = thread_n / 16;\n\n    host::RuntimeCheck(\n        is_valid_config(\n            thread_tfg,\n            thread_m_blocks,\n            prob_m_split,\n            prob_n,\n            prob_k,\n            num_bits,\n            group_size,\n            has_act_order,\n            is_k_full,\n            has_zp,\n            
is_zp_float,\n            max_shared_mem_new),\n        \"Invalid thread config: thread_m_blocks = \",\n        thread_m_blocks,\n        \", thread_k = \",\n        thread_tfg.thread_k,\n        \", thread_n = \",\n        thread_tfg.thread_n,\n        \", num_threads = \",\n        thread_tfg.num_threads,\n        \" for MKN = [\",\n        prob_m,\n        \", \",\n        prob_k,\n        \", \",\n        prob_n,\n        \"] and num_bits = \",\n        num_bits,\n        \", prob_m_split = \",\n        prob_m_split,\n        \", group_size = \",\n        group_size,\n        \", has_act_order = \",\n        has_act_order,\n        \", is_k_full = \",\n        is_k_full,\n        \", has_zp = \",\n        has_zp,\n        \", is_zp_float = \",\n        is_zp_float,\n        \", max_shared_mem_new = \",\n        max_shared_mem_new);\n\n    auto kernel = get_marlin_kernel<scalar_t>(\n        q_type,\n        thread_m_blocks,\n        thread_n_blocks,\n        thread_k_blocks,\n        m_block_size_8,\n        has_act_order,\n        has_zp,\n        group_blocks,\n        num_threads,\n        is_zp_float);\n\n    if (kernel == MarlinDefault) {\n      host::Panic(\n          \"Unsupported shapes: MNK = [\",\n          prob_m,\n          \", \",\n          prob_n,\n          \", \",\n          prob_k,\n          \"]\",\n          \", has_act_order = \",\n          has_act_order,\n          \", num_groups = \",\n          num_groups,\n          \", group_size = \",\n          group_size,\n          \", prob_m_split = \",\n          prob_m_split,\n          \", thread_m_blocks = \",\n          thread_m_blocks,\n          \", thread_n_blocks = \",\n          thread_n_blocks,\n          \", thread_k_blocks = \",\n          thread_k_blocks,\n          \", num_threads = \",\n          num_threads,\n          \", num_bits = \",\n          num_bits);\n    }\n\n    host::RuntimeDeviceCheck(\n        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 
max_shared_mem_new));\n\n    bool part_use_atomic_add = use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048;\n\n    host::LaunchKernel(blocks, num_threads, stream, max_shared_mem_new)(\n        kernel,\n        A_ptr,\n        B_ptr,\n        C_ptr,\n        C_tmp_ptr,\n        s_ptr,\n        s2_ptr,\n        zp_ptr,\n        g_idx_ptr,\n        num_groups,\n        prob_m_split,\n        prob_n,\n        prob_k,\n        lda,\n        locks,\n        part_use_atomic_add,\n        use_fp32_reduce,\n        max_shared_mem_new);\n\n    A_ptr += prob_m_split * (lda / 8);\n    C_ptr += prob_m_split * (prob_n / 8);\n    rest_m -= prob_m_split;\n  }\n}\n\n#endif\n\n}  // namespace device::marlin\n\ntemplate <typename scalar_t>\nvoid gptq_marlin_gemm(\n    tvm::ffi::TensorView a,\n    tvm::ffi::TensorView b_q_weight,\n    tvm::ffi::TensorView b_scales,\n    tvm::ffi::TensorView global_scale,\n    tvm::ffi::TensorView b_zeros,\n    tvm::ffi::TensorView g_idx,\n    tvm::ffi::TensorView perm,\n    tvm::ffi::TensorView c,\n    tvm::ffi::TensorView c_tmp,\n    tvm::ffi::TensorView a_tmp,\n    tvm::ffi::TensorView workspace,\n    int64_t b_q_type_id,\n    bool is_k_full,\n    bool use_atomic_add,\n    bool use_fp32_reduce,\n    bool is_zp_float) {\n  using namespace host;\n\n  ScalarType const b_q_type = ScalarType::from_id(b_q_type_id);\n  int pack_factor = 32 / b_q_type.size_bits();\n\n  // Bind symbolic sizes\n  auto M = SymbolicSize{\"M\"};\n  auto K = SymbolicSize{\"K\"};\n  auto N = SymbolicSize{\"N\"};\n  auto device = SymbolicDevice{};\n  device.set_options<kDLCUDA>();\n\n  // Verify a: [M, K]\n  auto lda = SymbolicSize{\"lda\"};\n  TensorMatcher({M, K}).with_strides({lda, 1}).with_dtype<scalar_t>().with_device(device).verify(a);\n\n  int64_t size_m = M.unwrap();\n  int64_t size_k = K.unwrap();\n\n  // Verify b_q_weight: [K/tile_size, packed_N]\n  RuntimeCheck(\n      size_k % device::marlin::tile_size == 0,\n      \"size_k = \",\n      size_k,\n      \" is 
not divisible by tile_size = \",\n      device::marlin::tile_size);\n  int64_t expected_bqw_dim0 = size_k / device::marlin::tile_size;\n  auto bqw_dim0 = SymbolicSize{\"bqw_dim0\"};\n  auto bqw_dim1 = SymbolicSize{\"bqw_dim1\"};\n  bqw_dim0.set_value(expected_bqw_dim0);\n  TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype<int32_t>().with_device(device).verify(b_q_weight);\n\n  RuntimeCheck(\n      b_q_weight.size(1) % device::marlin::tile_size == 0,\n      \"b_q_weight.size(1) = \",\n      b_q_weight.size(1),\n      \" is not divisible by tile_size = \",\n      device::marlin::tile_size);\n  int64_t actual_size_n = (b_q_weight.size(1) / device::marlin::tile_size) * pack_factor;\n  N.set_value(actual_size_n);\n  int64_t size_n = N.unwrap();\n\n  // Verify stride alignment\n  int64_t a_stride0 = a.stride(0);\n  RuntimeCheck(a_stride0 % 8 == 0, \"a.stride(0) must be divisible by 8\");\n\n  // Verify b_scales: [num_groups, N]\n  auto num_groups_sym = SymbolicSize{\"num_groups\"};\n  TensorMatcher({num_groups_sym, N}).with_device(device).verify(b_scales);\n  int num_groups = static_cast<int>(num_groups_sym.unwrap());\n\n  // Verify c: [M, N]\n  TensorMatcher({M, N}).with_dtype<scalar_t>().with_device(device).verify(c);\n\n  // Early return for zero-size M\n  if (size_m == 0) return;\n\n  // Determine has_act_order from g_idx/perm sizes\n  int64_t g_idx_size = g_idx.size(0);\n  int64_t perm_size = perm.size(0);\n  bool has_act_order = g_idx_size > 0 && perm_size > 0;\n\n  if (has_act_order) {\n    RuntimeCheck(\n        (g_idx_size == size_k && perm_size == size_k),\n        \"Unexpected g_idx.size(0) = \",\n        g_idx_size,\n        \" and perm.size(0) = \",\n        perm_size,\n        \", where size_k = \",\n        size_k);\n  }\n\n  // Determine has_zp from b_zeros size\n  int64_t b_zeros_size = b_zeros.size(0);\n  bool has_zp = b_zeros_size > 0;\n\n  if (has_zp) {\n    RuntimeCheck(\n        b_q_type == kU4 || b_q_type == kU8, \"b_q_type must be u4 or u8 when 
has_zp = True. Got = \", b_q_type.str());\n  } else {\n    RuntimeCheck(\n        b_q_type == kU4B8 || b_q_type == kU8B128 || b_q_type == kFE4M3fn || b_q_type == kFE2M1f,\n        \"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when \"\n        \"has_zp = False. Got = \",\n        b_q_type.str());\n  }\n\n  if (has_zp && is_zp_float) {\n    RuntimeCheck(\n        std::is_same<scalar_t, fp16_t>::value, \"Computation type must be float16 (half) when using float zero points.\");\n  }\n\n  // Verify b_zeros shape\n  if (has_zp) {\n    RuntimeCheck(b_zeros.dim() == 2, \"b_zeros rank = \", b_zeros.dim(), \" is not 2\");\n    if (is_zp_float) {\n      RuntimeCheck(b_zeros.size(1) == size_n, \"b_zeros dim 1 = \", b_zeros.size(1), \" is not size_n = \", size_n);\n      RuntimeCheck(\n          num_groups == b_zeros.size(0), \"b_zeros dim 0 = \", b_zeros.size(0), \" is not num_groups = \", num_groups);\n      RuntimeCheck(num_groups != -1, \"num_groups must be != -1\");\n    } else {\n      RuntimeCheck(\n          b_zeros.size(0) == num_groups, \"b_zeros dim 0 = \", b_zeros.size(0), \" is not num_groups = \", num_groups);\n      RuntimeCheck(\n          b_zeros.size(1) == size_n / pack_factor,\n          \"b_zeros dim 1 = \",\n          b_zeros.size(1),\n          \" is not size_n / pack_factor = \",\n          size_n / pack_factor);\n    }\n  }\n\n  // Verify global_scale\n  int64_t global_scale_size = global_scale.size(0);\n  if (global_scale_size > 0) {\n    RuntimeCheck(b_q_type == kFE2M1f, \"global_scale can only be used for float4_e2m1f.\");\n  } else {\n    RuntimeCheck(!(b_q_type == kFE2M1f), \"the global_scale parameter must be passed for float4_e2m1f.\");\n  }\n\n  // Derive group_size\n  int group_size = -1;\n  if (has_act_order) {\n    if (is_k_full) {\n      RuntimeCheck(num_groups > 1, \"For act_order, num_groups must be > 1\");\n      RuntimeCheck(size_k % num_groups == 0, \"size_k = \", size_k, \", is not divisible by num_groups = \", 
num_groups);\n      group_size = static_cast<int>(size_k / num_groups);\n    } else {\n      group_size = 0;\n    }\n  } else {\n    if (num_groups > 1) {\n      RuntimeCheck(size_k % num_groups == 0, \"size_k = \", size_k, \", is not divisible by num_groups = \", num_groups);\n      group_size = static_cast<int>(size_k / num_groups);\n    } else {\n      group_size = -1;\n    }\n  }\n\n  // Verify workspace and get device info\n  RuntimeCheck(\n      size_n % device::marlin::min_thread_n == 0,\n      \"size_n = \",\n      size_n,\n      \", is not divisible by min_thread_n = \",\n      device::marlin::min_thread_n);\n\n  DLDevice dl_device = device.unwrap();\n  int dev = dl_device.device_id;\n  cudaStream_t stream = LaunchKernel::resolve_device(dl_device);\n\n  int sms = -1;\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev));\n\n  RuntimeCheck(\n      workspace.size(0) >= sms, \"workspace.size(0) = \", workspace.size(0), \" is below min_workspace_size = \", sms);\n\n  // Hardcoded defaults (auto config)\n  int thread_k_init = -1;\n  int thread_n_init = -1;\n\n  // Compute c_tmp and a_tmp pointers\n  // c_tmp and a_tmp are pre-allocated by caller\n\n  device::marlin::marlin_mm<scalar_t>(\n      a.data_ptr(),\n      b_q_weight.data_ptr(),\n      c.data_ptr(),\n      c_tmp.data_ptr(),\n      b_scales.data_ptr(),\n      global_scale.data_ptr(),\n      b_zeros.data_ptr(),\n      g_idx.data_ptr(),\n      perm.data_ptr(),\n      a_tmp.data_ptr(),\n      static_cast<int>(size_m),\n      static_cast<int>(size_n),\n      static_cast<int>(size_k),\n      static_cast<int>(a_stride0),\n      workspace.data_ptr(),\n      b_q_type,\n      has_act_order,\n      is_k_full,\n      has_zp,\n      num_groups,\n      group_size,\n      dev,\n      stream,\n      thread_k_init,\n      thread_n_init,\n      sms,\n      use_atomic_add,\n      use_fp32_reduce,\n      is_zp_float);\n}\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/marlin/gptq_marlin_repack.cuh",
    "content": "/*\n * Modified by Neural Magic\n * Copyright (C) Marlin.2024 Elias Frantar\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/*\n * Adapted from https://github.com/IST-DASLab/marlin\n */\n\n#pragma once\n\n#include <sgl_kernel/tensor.h>\n\n#include <sgl_kernel/utils.cuh>\n\n#include \"marlin.cuh\"\n\nnamespace device::marlin {\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\ntemplate <int const num_threads, int const num_bits, bool const has_perm>\n__global__ void gptq_marlin_repack_kernel(\n    uint32_t const* __restrict__ b_q_weight_ptr,\n    uint32_t const* __restrict__ perm_ptr,\n    uint32_t* __restrict__ out_ptr,\n    int size_k,\n    int size_n) {\n  return;\n}\n#else\ntemplate <int const num_threads, int const num_bits, bool const has_perm>\n__global__ void gptq_marlin_repack_kernel(\n    uint32_t const* __restrict__ b_q_weight_ptr,\n    uint32_t const* __restrict__ perm_ptr,\n    uint32_t* __restrict__ out_ptr,\n    int size_k,\n    int size_n) {\n  constexpr int pack_factor = 32 / num_bits;\n\n  int k_tiles = size_k / tile_k_size;\n  int n_tiles = size_n / tile_n_size;\n  int block_k_tiles = div_ceil(k_tiles, gridDim.x);\n\n  auto start_k_tile = blockIdx.x * block_k_tiles;\n  if (start_k_tile >= k_tiles) {\n    return;\n  }\n\n  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);\n\n  // Wait until the next thread tile has been loaded to shared memory.\n  auto wait_for_stage = [&]() {\n    // We 
only have `stages - 2` active fetches since we are double buffering\n    // and can only issue the next fetch when it is guaranteed that the previous\n    // shared memory load is fully complete (as it may otherwise be\n    // overwritten).\n    cp_async_wait<repack_stages - 2>();\n    __syncthreads();\n  };\n\n  extern __shared__ int4 sh[];\n\n  constexpr int perm_size = tile_k_size / 4;\n\n  int4* sh_perm_ptr = sh;\n  int4* sh_pipe_ptr = sh_perm_ptr;\n  if constexpr (has_perm) {\n    sh_pipe_ptr += perm_size;\n  }\n\n  constexpr int tile_ints = tile_k_size / pack_factor;\n\n  constexpr int stage_n_threads = tile_n_size / 4;\n  constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;\n  constexpr int stage_size = stage_k_threads * stage_n_threads;\n\n  auto load_perm_to_shared = [&](int k_tile_id) {\n    int first_k_int4 = (k_tile_id * tile_k_size) / 4;\n\n    int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);\n\n    if (threadIdx.x < perm_size) {\n      sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];\n    }\n    __syncthreads();\n  };\n\n  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {\n    if (n_tile_id >= n_tiles) {\n      cp_async_fence();\n      return;\n    }\n\n    int first_n = n_tile_id * tile_n_size;\n\n    int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;\n\n    if constexpr (has_perm) {\n      if (threadIdx.x < stage_size) {\n        auto k_id = threadIdx.x / stage_n_threads;\n        auto n_id = threadIdx.x % stage_n_threads;\n\n        uint32_t const* sh_perm_int_ptr = reinterpret_cast<uint32_t const*>(sh_perm_ptr);\n\n        int src_k = sh_perm_int_ptr[k_id];\n        int src_k_packed = src_k / pack_factor;\n\n        cp_async4(\n            &sh_ptr[k_id * stage_n_threads + n_id],\n            reinterpret_cast<int4 const*>(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));\n      }\n\n    } else {\n      if (threadIdx.x < stage_size) {\n        auto k_id = 
threadIdx.x / stage_n_threads;\n        auto n_id = threadIdx.x % stage_n_threads;\n\n        int first_k = k_tile_id * tile_k_size;\n        int first_k_packed = first_k / pack_factor;\n\n        cp_async4(\n            &sh_ptr[k_id * stage_n_threads + n_id],\n            reinterpret_cast<int4 const*>(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)])));\n      }\n    }\n\n    cp_async_fence();\n  };\n\n  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {\n    if (n_tile_id >= n_tiles) {\n      return;\n    }\n\n    auto warp_id = threadIdx.x / 32;\n    auto th_id = threadIdx.x % 32;\n\n    if (warp_id >= 4) {\n      return;\n    }\n\n    int tc_col = th_id / 4;\n    int tc_row = (th_id % 4) * 2;\n\n    constexpr int tc_offsets[4] = {0, 1, 8, 9};\n\n    int cur_n = warp_id * 16 + tc_col;\n\n    constexpr int sh_stride = 64;\n    constexpr uint32_t mask = (1 << num_bits) - 1;\n\n    int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;\n    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);\n\n    uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);\n\n    uint32_t vals[8];\n\n    if constexpr (has_perm) {\n      for (int i = 0; i < 4; i++) {\n        int k_idx = tc_row + tc_offsets[i];\n\n        uint32_t src_k = sh_perm_int_ptr[k_idx];\n        uint32_t src_k_pos = src_k % pack_factor;\n\n        uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];\n        uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;\n\n        uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];\n        uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;\n\n        vals[i] = b1_cur_val;\n        vals[4 + i] = b2_cur_val;\n      }\n\n    } else {\n      uint32_t b1_vals[tile_ints];\n      uint32_t b2_vals[tile_ints];\n\n#pragma unroll\n      for (int i = 0; i < tile_ints; i++) {\n        b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];\n        b2_vals[i] = 
sh_stage_int_ptr[cur_n + 8 + sh_stride * i];\n      }\n\n#pragma unroll\n      for (int i = 0; i < 4; i++) {\n        int cur_elem = tc_row + tc_offsets[i];\n        int cur_int = cur_elem / pack_factor;\n        int cur_pos = cur_elem % pack_factor;\n\n        vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;\n        vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;\n      }\n    }\n\n    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;\n    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;\n\n    // Result of:\n    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h\n    if constexpr (num_bits == 4) {\n      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};\n\n      uint32_t res = 0;\n#pragma unroll\n      for (int i = 0; i < 8; i++) {\n        res |= vals[pack_idx[i]] << (i * 4);\n      }\n\n      out_ptr[out_offset + th_id * 4 + warp_id] = res;\n\n    } else {\n      constexpr int pack_idx[4] = {0, 2, 1, 3};\n\n      uint32_t res1 = 0;\n      uint32_t res2 = 0;\n#pragma unroll\n      for (int i = 0; i < 4; i++) {\n        res1 |= vals[pack_idx[i]] << (i * 8);\n        res2 |= vals[4 + pack_idx[i]] << (i * 8);\n      }\n\n      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;\n      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;\n    }\n  };\n\n  auto start_pipes = [&](int k_tile_id, int n_tile_id) {\n#pragma unroll\n    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {\n      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);\n    }\n\n    wait_for_stage();\n  };\n#pragma unroll\n  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {\n    int n_tile_id = 0;\n\n    if constexpr (has_perm) {\n      load_perm_to_shared(k_tile_id);\n    }\n\n    start_pipes(k_tile_id, n_tile_id);\n\n    while (n_tile_id < n_tiles) {\n#pragma unroll\n      for 
(int pipe = 0; pipe < repack_stages; pipe++) {\n        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1);\n        repack_tile(pipe, k_tile_id, n_tile_id + pipe);\n        wait_for_stage();\n      }\n      n_tile_id += repack_stages;\n    }\n  }\n}\n#endif\n\n}  // namespace device::marlin\n\n#define CALL_IF_REPACK(NUM_BITS, HAS_PERM)                                                                        \\\n  else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                                                        \\\n    host::RuntimeDeviceCheck(cudaFuncSetAttribute(                                                                \\\n        device::marlin::gptq_marlin_repack_kernel<device::marlin::repack_threads, NUM_BITS, HAS_PERM>,            \\\n        cudaFuncAttributeMaxDynamicSharedMemorySize,                                                              \\\n        max_shared_mem));                                                                                         \\\n    host::LaunchKernel(blocks, device::marlin::repack_threads, stream, static_cast<std::size_t>(max_shared_mem))( \\\n        device::marlin::gptq_marlin_repack_kernel<device::marlin::repack_threads, NUM_BITS, HAS_PERM>,            \\\n        b_q_weight_ptr,                                                                                           \\\n        perm_ptr,                                                                                                 \\\n        out_ptr,                                                                                                  \\\n        size_k,                                                                                                   \\\n        size_n);                                                                                                  \\\n  }\n\nvoid gptq_marlin_repack(\n    tvm::ffi::TensorView b_q_weight,\n    tvm::ffi::TensorView perm,\n    
tvm::ffi::TensorView out,\n    int64_t size_k,\n    int64_t size_n,\n    int64_t num_bits) {\n  using namespace host;\n\n  // Validate num_bits\n  RuntimeCheck(num_bits == 4 || num_bits == 8, \"num_bits must be 4 or 8. Got = \", num_bits);\n  int const pack_factor = 32 / static_cast<int>(num_bits);\n\n  // Validate size alignment\n  RuntimeCheck(\n      size_k % device::marlin::tile_k_size == 0,\n      \"size_k = \",\n      size_k,\n      \" is not divisible by tile_k_size = \",\n      device::marlin::tile_k_size);\n  RuntimeCheck(\n      size_n % device::marlin::tile_n_size == 0,\n      \"size_n = \",\n      size_n,\n      \" is not divisible by tile_n_size = \",\n      device::marlin::tile_n_size);\n\n  // Validate b_q_weight\n  auto bqw_dim0 = SymbolicSize{\"bqw_dim0\"};\n  auto bqw_dim1 = SymbolicSize{\"bqw_dim1\"};\n  bqw_dim0.set_value(size_k / pack_factor);\n  bqw_dim1.set_value(size_n);\n  auto device_ = SymbolicDevice{};\n  device_.set_options<kDLCUDA>();\n  TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype<int32_t>().with_device(device_).verify(b_q_weight);\n\n  // Validate out\n  auto out_dim0 = SymbolicSize{\"out_dim0\"};\n  auto out_dim1 = SymbolicSize{\"out_dim1\"};\n  out_dim0.set_value(size_k / device::marlin::tile_size);\n  out_dim1.set_value(size_n * device::marlin::tile_size / pack_factor);\n  TensorMatcher({out_dim0, out_dim1}).with_dtype<int32_t>().with_device(device_).verify(out);\n\n  // Detect if there is act_order\n  bool has_perm = perm.size(0) != 0;\n\n  // Get ptrs\n  uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());\n  uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());\n  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());\n\n  // Get dev info\n  DLDevice dl_device = device_.unwrap();\n  int dev = dl_device.device_id;\n  cudaStream_t stream = LaunchKernel::resolve_device(dl_device);\n  int blocks;\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&blocks, 
cudaDevAttrMultiProcessorCount, dev));\n\n  int max_shared_mem = 0;\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev));\n  RuntimeCheck(max_shared_mem > 0, \"max_shared_mem must be > 0\");\n\n  if (false) {\n  }\n  CALL_IF_REPACK(4, false)\n  CALL_IF_REPACK(4, true)\n  CALL_IF_REPACK(8, false)\n  CALL_IF_REPACK(8, true)\n  else {\n    Panic(\"Unsupported repack config: num_bits = \", num_bits, \", has_perm = \", has_perm);\n  }\n}\n\n#undef CALL_IF_REPACK\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/marlin/kernel.h",
    "content": "\n#include <sgl_kernel/scalar_type.hpp>\n\n#include \"marlin.cuh\"\n#include \"marlin_dtypes.cuh\"\n\n#define MARLIN_KERNEL_PARAMS                                                                                         \\\n  const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp,            \\\n      const int4 *__restrict__ scales_ptr, const uint16_t *__restrict__ scale2_ptr, const int4 *__restrict__ zp_ptr, \\\n      const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks,        \\\n      bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem\n\nnamespace device::marlin {\ntemplate <\n    typename scalar_t,                   // compute dtype, half or nv_bfloat16\n    const host::ScalarTypeId w_type_id,  // weight ScalarType id\n    const int threads,                   // number of threads in a threadblock\n    const int thread_m_blocks,           // number of 16x16 blocks in the m\n                                         // dimension (batchsize) of the\n                                         // threadblock\n    const int thread_n_blocks,           // same for n dimension (output)\n    const int thread_k_blocks,           // same for k dimension (reduction)\n    const bool m_block_size_8,           // whether m_block_size == 8\n                                         // only works when thread_m_blocks == 1\n    const int stages,                    // number of stages for the async global->shared\n                                         // fetch pipeline\n    const int group_blocks,              // number of consecutive 16x16 blocks\n                                         // with a separate quantization scale\n    const bool is_zp_float               // is zero point of float16 type?\n    >\n__global__ void Marlin(MARLIN_KERNEL_PARAMS);\n\n}  // namespace device::marlin\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/marlin/marlin.cuh",
    "content": "#pragma once\n\n#include <sgl_kernel/utils.cuh>\n\n#include <iostream>\n\nnamespace device::marlin {\n// Marlin params\n\n// 8 warps are a good choice since every SM has 4 schedulers and having more\n// than 1 warp per schedule allows some more latency hiding. At the same time,\n// we want relatively few warps to have many registers per warp and small tiles.\nstatic constexpr int default_threads = 256;\n\nstatic constexpr int pipe_stages = 4;  // 4 pipeline stages fit into shared memory\n\nstatic constexpr int min_thread_n = 64;\nstatic constexpr int min_thread_k = 64;\nstatic constexpr int max_thread_n = 256;\n\nstatic constexpr int tile_size = 16;\nstatic constexpr int max_par = 16;\n\n// Repack params\nstatic constexpr int repack_stages = 8;\n\nstatic constexpr int repack_threads = 256;\n\nstatic constexpr int tile_k_size = tile_size;\nstatic constexpr int tile_n_size = tile_k_size * 4;\n\n// Helpers\ntemplate <typename T, int n>\nstruct Vec {\n  T elems[n];\n  __device__ T& operator[](int i) {\n    return elems[i];\n  }\n};\n\nusing I4 = Vec<int, 4>;\n\nusing host::div_ceil;\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n// No support for async\n#else\n\n__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {\n  const int BYTES = 16;\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm volatile(\n      \"{\\n\"\n      \"   .reg .pred p;\\n\"\n      \"   setp.ne.b32 p, %0, 0;\\n\"\n      \"   @p cp.async.cg.shared.global [%1], [%2], %3;\\n\"\n      \"}\\n\" ::\"r\"((int)pred),\n      \"r\"(smem),\n      \"l\"(glob_ptr),\n      \"n\"(BYTES));\n}\n\n__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {\n  const int BYTES = 16;\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm volatile(\n      \"{\\n\"\n      \"   cp.async.cg.shared.global [%0], [%1], %2;\\n\"\n      \"}\\n\" ::\"r\"(smem),\n      \"l\"(glob_ptr),\n    
  \"n\"(BYTES));\n}\n\n__device__ inline void cp_async_fence() {\n  asm volatile(\"cp.async.commit_group;\\n\" ::);\n}\n\ntemplate <int n>\n__device__ inline void cp_async_wait() {\n  asm volatile(\"cp.async.wait_group %0;\\n\" ::\"n\"(n));\n}\n\n#endif\n\n}  // namespace device::marlin\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/marlin/marlin_dtypes.cuh",
    "content": "#ifndef _data_types_cuh\n#define _data_types_cuh\n#include <sgl_kernel/utils.cuh>\n\n#include \"marlin.cuh\"\n\nnamespace device::marlin {\n\ntemplate <typename scalar_t>\nclass ScalarType {};\n\ntemplate <>\nclass ScalarType<fp16_t> {\n public:\n  using scalar_t = fp16_t;\n  using scalar_t2 = fp16x2_t;\n\n  // Matrix fragments for tensor core instructions; their precise layout is\n  // documented here:\n  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type\n  using FragA = Vec<fp16x2_t, 4>;\n  using FragB = Vec<fp16x2_t, 2>;\n  using FragC = Vec<float, 4>;\n  using FragS = Vec<fp16x2_t, 1>;\n  using FragZP = Vec<fp16x2_t, 4>;\n\n  static __device__ float inline num2float(const fp16_t x) {\n    return __half2float(x);\n  }\n\n  static __device__ fp16x2_t inline num2num2(const fp16_t x) {\n    return __half2half2(x);\n  }\n\n  static __device__ fp16x2_t inline nums2num2(const fp16_t x1, const fp16_t x2) {\n    return __halves2half2(x1, x2);\n  }\n\n  static __host__ __device__ fp16_t inline float2num(const float x) {\n    return __float2half(x);\n  }\n};\n\ntemplate <>\nclass ScalarType<bf16_t> {\n public:\n  using scalar_t = bf16_t;\n  using scalar_t2 = bf16x2_t;\n\n  using FragA = Vec<bf16x2_t, 4>;\n  using FragB = Vec<bf16x2_t, 2>;\n  using FragC = Vec<float, 4>;\n  using FragS = Vec<bf16x2_t, 1>;\n  using FragZP = Vec<bf16x2_t, 4>;\n\n#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800\n  static __device__ float inline num2float(const bf16_t x) {\n    return __bfloat162float(x);\n  }\n\n  static __device__ bf16x2_t inline num2num2(const bf16_t x) {\n    return __bfloat162bfloat162(x);\n  }\n\n  static __device__ bf16x2_t inline nums2num2(const bf16_t x1, const bf16_t x2) {\n    return __halves2bfloat162(x1, x2);\n  }\n\n  static __host__ __device__ bf16_t inline float2num(const float x) {\n    return __float2bfloat16(x);\n  }\n#endif\n};\n\n}  // namespace 
device::marlin\n\n#endif\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h",
    "content": "/*\n * Modified by Neural Magic\n * Copyright (C) Marlin.2024 Elias Frantar\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/*\n * Adapted from https://github.com/IST-DASLab/marlin\n */\n#include <sgl_kernel/scalar_type.hpp>\n\n#include \"dequant.h\"\n#include \"marlin.cuh\"\n#include \"marlin_dtypes.cuh\"\n\n#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)                                        \\\n  static_assert(                                                                         \\\n      std::is_same<scalar_t, half>::value || std::is_same<scalar_t, nv_bfloat16>::value, \\\n      \"only float16 and bfloat16 are supported\");\n\nnamespace device::marlin {\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n\ntemplate <\n    typename scalar_t,                   // compute dtype, half or nv_bfloat16\n    const host::ScalarTypeId w_type_id,  // weight ScalarType id\n    const int threads,                   // number of threads in a threadblock\n    const int thread_m_blocks,           // number of 16x16 blocks in the m\n                                         // dimension (batchsize) of the\n                                         // threadblock\n    const int thread_n_blocks,           // same for n dimension (output)\n    const int thread_k_blocks,           // same for k dimension (reduction)\n    const bool m_block_size_8,           // whether m_block_size == 8\n                                 
        // only 
works when thread_m_blocks == 1\n    const int stages,                    // number of stages for the async global->shared\n                                         // fetch pipeline\n    const bool has_act_order,            // whether act_order is enabled\n    const int group_blocks,              // number of consecutive 16x16 blocks\n                                         // with a separate quantization scale\n    const bool is_zp_float               // is zero point of float16 type?\n    >\n__global__ void Marlin(\n    const int4* __restrict__ A,           // fp16 input matrix of shape mxk\n    const int4* __restrict__ B,           // 4bit quantized weight matrix of shape kxn\n    int4* __restrict__ C,                 // fp16 output buffer of shape mxn\n    int4* __restrict__ C_tmp,             // fp32 tmp output buffer (for reduce)\n    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape\n                                          // (k/groupsize)xn\n    const int* __restrict__ g_idx,        // int32 group indices of shape k\n    int num_groups,                       // number of scale groups per output channel\n    int prob_m,                           // batch dimension m\n    int prob_n,                           // output dimension n\n    int prob_k,                           // reduction dimension k\n    int* locks,                           // extra global storage for barrier synchronization\n    bool use_fp32_reduce                  // whether to use fp32 global reduce\n) {}\n\n}  // namespace device::marlin\n\n#else\n\n// m16n8k16 tensor core mma instruction with fp16 inputs and fp32\n// output/accumulation.\ntemplate <typename scalar_t>\n__device__ inline void\nmma(const typename ScalarType<scalar_t>::FragA& a_frag,\n    const typename ScalarType<scalar_t>::FragB& frag_b,\n    typename ScalarType<scalar_t>::FragC& frag_c) {\n  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);\n  const uint32_t* b = 
reinterpret_cast<const uint32_t*>(&frag_b);\n  float* c = reinterpret_cast<float*>(&frag_c);\n  if constexpr (std::is_same<scalar_t, half>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]), \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]), \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n  } else {\n    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n  }\n}\n\ntemplate <typename scalar_t>\n__device__ inline void mma_trans(\n    const typename ScalarType<scalar_t>::FragA& a_frag,\n    const typename ScalarType<scalar_t>::FragB& frag_b,\n    const typename ScalarType<scalar_t>::FragB& frag_b2,\n    typename ScalarType<scalar_t>::FragC& frag_c) {\n  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);\n  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);\n  const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);\n  float* c = reinterpret_cast<float*>(&frag_c);\n  if constexpr (std::is_same<scalar_t, half>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(b[0]),\n          \"r\"(b2[0]),\n          \"r\"(b[1]),\n          \"r\"(b2[1]),\n          \"r\"(a[0]),\n          \"r\"(a[1]),\n          
\"f\"(c[0]),\n          \"f\"(c[1]),\n          \"f\"(c[2]),\n          \"f\"(c[3]));\n  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(b[0]),\n          \"r\"(b2[0]),\n          \"r\"(b[1]),\n          \"r\"(b2[1]),\n          \"r\"(a[0]),\n          \"r\"(a[1]),\n          \"f\"(c[0]),\n          \"f\"(c[1]),\n          \"f\"(c[2]),\n          \"f\"(c[3]));\n  } else {\n    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n  }\n}\n\n// Instruction for loading a full 16x16 matrix fragment of operand A from shared\n// memory, directly in tensor core layout.\ntemplate <int count, typename scalar_t>\n__device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a, const void* smem_ptr) {\n  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  if constexpr (count == 4) {\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\\n\"\n                 : \"=r\"(a[0]), \"=r\"(a[1]), \"=r\"(a[2]), \"=r\"(a[3])\n                 : \"r\"(smem));\n  } else if constexpr (count == 2) {\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\\n\" : \"=r\"(a[0]), \"=r\"(a[1]) : \"r\"(smem));\n  } else if constexpr (count == 1) {\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\\n\" : \"=r\"(a[0]) : \"r\"(smem));\n  } else {\n    static_assert(count == 1 || count == 2 || count == 4, \"invalid count\");\n  }\n}\n\n// Multiply dequantized values by the corresponding quantization scale; used\n// only for grouped quantization.\ntemplate <typename scalar_t>\n__device__ inline void\nscale(typename ScalarType<scalar_t>::FragB& frag_b, typename ScalarType<scalar_t>::FragS& frag_s, 
int i) {\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 s = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);\n  frag_b[0] = __hmul2(frag_b[0], s);\n  frag_b[1] = __hmul2(frag_b[1], s);\n}\n\ntemplate <typename scalar_t>\n__device__ inline void scale_and_sub(typename ScalarType<scalar_t>::FragB& frag_b, scalar_t s, scalar_t zp) {\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 s2 = ScalarType<scalar_t>::num2num2(s);\n  scalar_t2 zp2 = ScalarType<scalar_t>::num2num2(zp);\n  frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2));\n  frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2));\n}\n\ntemplate <typename scalar_t>\n__device__ inline void\nsub_zp(typename ScalarType<scalar_t>::FragB& frag_b, typename ScalarType<scalar_t>::scalar_t2& frag_zp, int i) {\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 zp = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);\n  frag_b[0] = __hsub2(frag_b[0], zp);\n  frag_b[1] = __hsub2(frag_b[1], zp);\n}\n\n// Same as above, but for act_order (each K is multiplied individually)\ntemplate <typename scalar_t>\n__device__ inline void scale4(\n    typename ScalarType<scalar_t>::FragB& frag_b,\n    typename ScalarType<scalar_t>::FragS& frag_s_1,\n    typename ScalarType<scalar_t>::FragS& frag_s_2,\n    typename ScalarType<scalar_t>::FragS& frag_s_3,\n    typename ScalarType<scalar_t>::FragS& frag_s_4,\n    int i) {\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 s_val_1_2;\n  s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];\n  s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];\n\n  scalar_t2 s_val_3_4;\n  s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];\n  s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];\n\n  frag_b[0] = __hmul2(frag_b[0], s_val_1_2);\n  frag_b[1] = __hmul2(frag_b[1], s_val_3_4);\n}\n\n// Given 2 floats multiply by 2 scales (halves)\ntemplate 
<typename scalar_t>\n__device__ inline void scale_float(float* c, typename ScalarType<scalar_t>::FragS& s) {\n  scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);\n  c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));\n  c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));\n}\n\n// Wait until barrier reaches `count`, then lock for current threadblock.\n__device__ inline void barrier_acquire(int* lock, int count) {\n  if (threadIdx.x == 0) {\n    int state = -1;\n    do\n      // Guarantee that subsequent writes by this threadblock will be visible\n      // globally.\n      asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\\n\" : \"=r\"(state) : \"l\"(lock));\n    while (state != count);\n  }\n  __syncthreads();\n}\n\n// Release barrier and increment visitation count.\n__device__ inline void barrier_release(int* lock, bool reset = false) {\n  __syncthreads();\n  if (threadIdx.x == 0) {\n    if (reset) {\n      lock[0] = 0;\n      return;\n    }\n    int val = 1;\n    // Make sure that all writes since acquiring this barrier are visible\n    // globally, while releasing the barrier.\n    asm volatile(\"fence.acq_rel.gpu;\\n\");\n    asm volatile(\"red.relaxed.gpu.global.add.s32 [%0], %1;\\n\" : : \"l\"(lock), \"r\"(val));\n  }\n}\n\n// Wait until the lock value becomes negative, then atomically add 1 to it.\n__device__ inline void wait_negative_and_add(int* lock) {\n  if (threadIdx.x == 0) {\n    int state = 0;\n    do\n      // Guarantee that subsequent writes by this threadblock will be visible\n      // globally.\n      asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\\n\" : \"=r\"(state) : \"l\"(lock));\n    while (state >= 0);\n    atomicAdd(lock, 1);\n  }\n  __syncthreads();\n}\n\ntemplate <\n    typename scalar_t,                   // compute dtype, half or nv_bfloat16\n    const host::ScalarTypeId w_type_id,  // weight ScalarType id\n    const int threads,                   // number of threads in a threadblock\n    
const int thread_m_blocks,    
       // number of 16x16 blocks in the m\n                                         // dimension (batchsize) of the\n                                         // threadblock\n    const int thread_n_blocks,           // same for n dimension (output)\n    const int thread_k_blocks,           // same for k dimension (reduction)\n    const bool m_block_size_8,           // whether m_block_size == 8\n                                         // only works when thread_m_blocks == 1\n    const int stages,                    // number of stages for the async global->shared\n                                         // fetch pipeline\n    const int group_blocks,              // number of consecutive 16x16 blocks\n                                         // with a separate quantization scale\n    const bool is_zp_float               // is zero point of float16 type?\n    >\n__global__ void Marlin(\n    const int4* __restrict__ A,               // fp16 input matrix of shape mxk\n    const int4* __restrict__ B,               // 4bit quantized weight matrix of shape kxn\n    int4* __restrict__ C,                     // fp16 output buffer of shape mxn\n    int4* __restrict__ C_tmp,                 // fp32 tmp output buffer (for reduce)\n    const int4* __restrict__ scales_ptr,      // fp16 quantization scales of shape\n                                              // (k/groupsize)xn\n    const uint16_t* __restrict__ scale2_ptr,  // fp16 global scale (for nvfp4\n                                              // only)\n    const int4* __restrict__ zp_ptr,          // 4bit packed zero-points of shape\n                                              // (k/groupsize)x(n/pack_factor)\n    const int* __restrict__ g_idx,            // int32 group indices of shape k\n    int num_groups,                           // number of scale groups per output channel\n    int prob_m,                               // batch dimension m\n    int prob_n,                               // output dimension n\n  
  int prob_k,                               // reduction dimension k\n    int lda,                                  // A.stride(0), equal to prob_k if A is contiguous\n    int* locks,                               // extra global storage for barrier synchronization\n    bool use_atomic_add,                      // whether to use atomic add to reduce\n    bool use_fp32_reduce,                     // whether to use fp32 global reduce\n    int max_shared_mem) {\n  // Each threadblock processes one \"stripe\" of the B matrix with (roughly) the\n  // same size, which might involve multiple column \"slices\" (of width 16 *\n  // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM\n  // example:\n  //   0 1 3\n  //   0 2 3\n  //   1 2 4\n  // While this kind of partitioning makes things somewhat more complicated, it\n  // ensures good utilization of all SMs for many kinds of shape and GPU\n  // configurations, while requiring as few slow global cross-threadblock\n  // reductions as possible.\n  using Dtype = ScalarType<scalar_t>;\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  using FragA = typename ScalarType<scalar_t>::FragA;\n  using FragB = typename ScalarType<scalar_t>::FragB;\n  using FragC = typename ScalarType<scalar_t>::FragC;\n  using FragS = typename ScalarType<scalar_t>::FragS;\n  using FragZP = typename ScalarType<scalar_t>::FragZP;\n\n  static constexpr auto w_type = host::ScalarType::from_id(w_type_id);\n  constexpr bool has_zp = w_type == host::kU4 || w_type == host::kU8;\n  constexpr bool is_int_type =\n      w_type == host::kU4 || w_type == host::kU8 || w_type == host::kU4B8 || w_type == host::kU8B128;\n  // see the comments in dequant.h for more details\n  constexpr bool dequant_skip_flop = !is_int_type ||\n                                     has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||\n                                     has_zp && !is_zp_float && !(w_type == host::kU8);\n\n  scalar_t2 
global_scale;\n\n  if constexpr (w_type == host::kFE2M1f) {\n    uint16_t val = scale2_ptr[0];\n    global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));\n  }\n\n  constexpr bool has_act_order = group_blocks == 0;\n  constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);\n\n  constexpr int pack_factor = 32 / w_type.size_bits();\n  static_assert(thread_m_blocks == 1 || !m_block_size_8);\n\n  // For larger GEMMs we run multiple batchsize 64 versions in parallel for a\n  // better partitioning with less reductions\n  int parallel = 1;\n  if (prob_m > m_block_size) {\n    parallel = prob_m / m_block_size;\n    prob_m = m_block_size;\n  }\n\n  int k_tiles = prob_k / 16 / thread_k_blocks;\n  int n_tiles = prob_n / 16 / thread_n_blocks;\n  int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);\n\n  if constexpr (!has_act_order && group_blocks != -1) {\n    if (group_blocks >= thread_k_blocks) {\n      // Ensure that the number of tiles in each stripe is a multiple of the\n      // groupsize; this avoids an annoying special case where a stripe starts\n      // in the middle of group.\n      iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks));\n    }\n  }\n\n  int slice_row = (iters * blockIdx.x) % k_tiles;\n  int slice_col_par = (iters * blockIdx.x) / k_tiles;\n  int slice_col = slice_col_par;\n  int slice_iters;      // number of threadblock tiles in the current slice\n  int slice_count = 0;  // total number of active threadblocks in the current slice\n  int slice_idx;        // index of threadblock in current slice; numbered bottom to\n                        // top\n\n  int par_id = 0;\n  int locks_off = 0;\n\n  // We can easily implement parallel problem execution by just remapping\n  // indices and advancing global pointers\n  if (slice_col_par >= n_tiles) {\n    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8;\n    C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * 
prob_n / 8;\n    slice_col = slice_col_par % n_tiles;\n    par_id = slice_col_par / n_tiles;\n  }\n  if (parallel * n_tiles >= gridDim.x) {\n    // when parallel * n_tiles >= sms\n    // then there are at most $sms$ conflict tile blocks\n    locks_off = blockIdx.x;\n  } else {\n    locks_off = (iters * blockIdx.x) / k_tiles - 1;\n  }\n\n  // Compute all information about the current slice which is required for\n  // synchronization.\n  auto init_slice = [&](bool first_init = false) {\n    slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);\n    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;\n    if (slice_iters == 0) return;\n    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;\n    slice_count = 1;\n    slice_idx = 0;\n    int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);\n    if (col_first <= k_tiles * (slice_col_par + 1)) {\n      int col_off = col_first - k_tiles * slice_col_par;\n      slice_count = div_ceil(k_tiles - col_off, iters);\n      if (col_off > 0) slice_count++;\n      int delta_first = iters * blockIdx.x - col_first;\n      if (delta_first < 0 || (col_off == 0 && delta_first == 0))\n        slice_idx = slice_count - 1;\n      else {\n        slice_idx = slice_count - 1 - delta_first / iters;\n        if (col_off > 0) slice_idx--;\n      }\n    }\n    if (parallel * n_tiles >= gridDim.x) {\n      if (slice_count > 1 && slice_idx == slice_count - 1) {\n        locks_off++;\n      }\n    } else {\n      locks_off++;\n    }\n\n    if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) {\n      constexpr int threads_per_m = 16 * thread_n_blocks / 8;\n      int m_per_thread = div_ceil(thread_m_blocks * 16, threads / threads_per_m);\n      if (m_block_size_8) m_per_thread = div_ceil(8, threads / threads_per_m);\n      for (int i = 0; i < m_per_thread; i++) {\n        int row = threads / threads_per_m * i + threadIdx.x / threads_per_m;\n       
 if (row < prob_m) {\n          int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m;\n          C[row * prob_n / 8 + col] = {0, 0, 0, 0};\n        }\n      }\n      // After writing zeros to the output, write a negative value to the lock.\n      // Every SM that processes the same slice waits for\n      // the negative value and then atomically adds 1 to it.\n      // Once all SMs have processed the slice, the lock value returns to 0.\n      __syncthreads();\n      if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count;\n    }\n\n    if (slice_col == n_tiles) {\n      A += 16 * thread_m_blocks * lda / 8;\n      C += 16 * thread_m_blocks * prob_n / 8;\n      slice_col = 0;\n      par_id++;\n    }\n  };\n  init_slice(true);\n\n  // A sizes/strides\n\n  // stride of the A matrix in global memory\n  int a_gl_stride = lda / 8;\n  // stride of an A matrix tile in shared memory\n  constexpr int a_sh_stride = 16 * thread_k_blocks / 8;\n  // delta between subsequent A tiles in global memory\n  constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;\n  // between subsequent accesses within a tile\n  int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);\n  // between shared memory writes\n  constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);\n  // between shared memory tile reads\n  constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));\n  // within a shared memory tile\n  constexpr int a_sh_rd_delta_i = a_sh_stride * 16;\n  // overall size of a tile\n  constexpr int a_sh_stage = a_sh_stride * m_block_size;\n  // number of shared write iterations for a tile\n  constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);\n\n  // B sizes/strides\n  int b_gl_stride = 16 * prob_n / (pack_factor * 4);\n  constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;\n  constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 
1 : 2;\n  constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;\n\n  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;\n  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);\n  constexpr int b_sh_wr_delta = threads * b_thread_vecs;\n  constexpr int b_sh_rd_delta = threads * b_thread_vecs;\n  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;\n  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;\n\n  // Scale sizes/strides without act_order\n  int s_gl_stride = prob_n / 8;\n  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;\n  constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks\n                                  ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1)\n                                  : 1;\n  constexpr int s_sh_stage = s_tb_groups * s_sh_stride;\n  int s_gl_rd_delta = s_gl_stride;\n\n  // Scale size/strides with act_order\n  constexpr int tb_k = 16 * thread_k_blocks;\n  constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;\n  // constexpr int act_s_row_stride      = 1;\n  // int           act_s_col_stride      = act_s_row_stride * num_groups;\n  constexpr int act_s_max_num_groups = 32;\n  int act_s_col_stride = 1;\n  int act_s_col_warp_stride = act_s_col_stride * 8;\n\n  int tb_n_warps = thread_n_blocks / 4;\n  int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;\n\n  // Zero-points sizes/strides\n  int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4;\n  constexpr int zp_sh_stride = is_zp_float ? 16 * thread_n_blocks / 8 : ((16 * thread_n_blocks) / pack_factor) / 4;\n  constexpr int zp_tb_groups = s_tb_groups;\n  constexpr int zp_sh_stage = has_zp ? 
zp_tb_groups * zp_sh_stride : 0;\n  int zp_gl_rd_delta = zp_gl_stride;\n\n  // Global A read index of current thread.\n  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);\n  a_gl_rd += a_gl_rd_delta_o * slice_row;\n  // Shared write index of current thread.\n  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);\n  // Shared read index.\n  int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) +\n                (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1));\n  a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));\n\n  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;\n  b_gl_rd += b_sh_stride * slice_col;\n  b_gl_rd += b_gl_rd_delta_o * slice_row;\n  auto b_sh_wr = threadIdx.x * b_thread_vecs;\n  auto b_sh_rd = threadIdx.x * b_thread_vecs;\n\n  // For act_order\n  constexpr int k_iter_size = tb_k / b_sh_wr_iters;\n  int slice_k_start = tb_k * slice_row;\n  int slice_k_finish = slice_k_start + tb_k * slice_iters;\n  int slice_k_start_shared_fetch = slice_k_start;\n  int slice_n_offset = act_s_col_tb_stride * slice_col;\n\n  // No act_order\n  int s_gl_rd;\n  if constexpr (!has_act_order) {\n    if constexpr (group_blocks == -1) {\n      s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n    } else {\n      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 
2 : 1) +\n                s_sh_stride * slice_col + threadIdx.x;\n    }\n  }\n  auto s_sh_wr = threadIdx.x;\n  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;\n\n  // Zero-points\n  int zp_gl_rd;\n  if constexpr (has_zp) {\n    if constexpr (group_blocks == -1) {\n      zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;\n    } else {\n      zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x;\n    }\n  }\n  auto zp_sh_wr = threadIdx.x;\n  bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;\n\n  // We use a different scale layout for grouped and column-wise quantization as\n  // we scale a `half2` tile in column-major layout in the former and in\n  // row-major in the latter case.\n  int s_sh_rd;\n  if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) {\n    auto warp_id = threadIdx.x / 32;\n    int n_warps = thread_n_blocks / 4;\n    int warp_row = warp_id / n_warps;\n\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;\n    s_sh_rd = s_sh_rd * 2 + warp_row % 2;\n\n  } else if constexpr (group_blocks != -1)\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;\n  else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop)))\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8;\n  else\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;\n\n  // Zero-points have the same read layout as the scales\n  // (without column-wise case)\n  constexpr int num_col_threads = 8;\n  constexpr int num_row_threads = 4;\n  constexpr int num_ints_per_thread = 8 / pack_factor;\n  int zp_sh_rd;\n  if constexpr (has_zp) {\n    if constexpr (is_zp_float) {\n      if constexpr (group_blocks != -1) {\n        zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;\n      }\n    } else {\n      zp_sh_rd = 
num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +\n                 num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);\n    }\n  }\n\n  // Precompute which thread should not read memory in which iterations; this is\n  // needed if there are more threads than required for a certain tilesize or\n  // when the batchsize is not a multiple of 16.\n  bool a_sh_wr_pred[a_sh_wr_iters];\n#pragma unroll\n  for (int i = 0; i < a_sh_wr_iters; i++)\n    a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;\n\n  // To ensure that writing and reading A tiles to/from shared memory, the\n  // latter in fragment format, is fully bank conflict free, we need to use a\n  // rather fancy XOR-based layout. The key here is that neither reads nor\n  // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the\n  // same shared memory banks. Further, it seems (based on NSight-Compute) that\n  // each warp must also write a consecutive memory segment?\n  auto transform_a = [&](int i) {\n    int row = i / a_gl_rd_delta_o;\n    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8);\n  };\n  // Since the computation of this remapping is non-trivial and, due to our main\n  // loop unrolls, all shared memory accesses are static, we simply precompute\n  // both transformed reads and writes.\n  int a_sh_wr_trans[a_sh_wr_iters];\n#pragma unroll\n  for (int i = 0; i < a_sh_wr_iters; i++)\n    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);\n  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];\n#pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++) {\n#pragma unroll\n    for (int j = 0; j < thread_m_blocks; j++)\n      a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);\n  }\n\n  // Since B-accesses have non-constant stride they have to be computed at\n  // runtime; we break dependencies between subsequent accesses within a tile by\n  // maintaining multiple 
pointers (we have enough registers), a tiny\n  // optimization.\n  const int4* B_ptr[b_sh_wr_iters];\n#pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++)\n    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;\n\n  extern __shared__ int4 sh[];\n  // Shared memory storage for global fetch pipelines.\n  constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks;\n  constexpr int sh_b_size = stages * b_sh_stage;\n  int4* sh_b = sh;\n  int4* sh_red = sh;\n  int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);\n  int4* sh_zp = sh_g_idx + (stages * g_idx_stage);\n  constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage);\n  int4* sh_s = sh_zp + (stages * zp_sh_stage);\n  // shared memory reused by reduction should be smaller than\n  // shared memory used by weight.\n  static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage);\n  int4* sh_a = sh_s + sh_s_size;\n  // constexpr int shm_size_used =\n  //     stages * (g_idx_stage + zp_sh_stage) + sh_s_size +\n  //     (sh_red_size > sh_b_size ? 
sh_red_size : sh_b_size);\n\n  // Register storage for double buffer of shared memory reads.\n  FragA frag_a[2][thread_m_blocks];\n  I4 frag_b_quant[2][b_thread_vecs];\n  FragC frag_c[thread_m_blocks][4][2];\n  FragS frag_s[2][4];                    // No act-order\n  FragS act_frag_s[2][4][4];             // For act-order\n  int frag_qzp[2][num_ints_per_thread];  // Zero-points\n  FragZP frag_zp;                        // Zero-points in fp16\n  FragZP frag_zpf[2];                    // Zero-points in fp16 in HQQ\n\n  // Zero accumulators.\n  auto zero_accums = [&]() {\n#pragma unroll\n    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)\n      reinterpret_cast<float*>(frag_c)[i] = 0;\n  };\n\n  int sh_first_group_id = -1;\n  int sh_num_groups = -1;\n\n  auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) {\n    sh_first_group_id = first_group_id;\n    sh_num_groups = last_group_id - first_group_id + 1;\n\n    if (sh_num_groups > act_s_max_num_groups) {\n      sh_num_groups = act_s_max_num_groups;\n    }\n\n    if (sh_first_group_id + sh_num_groups > num_groups) {\n      sh_num_groups = num_groups - sh_first_group_id;\n    }\n\n    int row_offset = first_group_id * s_gl_stride;\n\n    if (is_async) {\n      for (int i = 0; i < sh_num_groups; i++) {\n        if (threadIdx.x < s_sh_stride) {\n          cp_async4_pred(\n              &sh_s[(i * s_sh_stride) + threadIdx.x],\n              &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]);\n        }\n      }\n    } else {\n      for (int i = 0; i < sh_num_groups; i++) {\n        if (threadIdx.x < s_sh_stride) {\n          sh_s[(i * s_sh_stride) + threadIdx.x] =\n              scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x];\n        }\n      }\n    }\n  };\n  // Asynchronously fetch the next A, B and s tile from global to the next\n  // shared memory pipeline location.\n  auto fetch_to_shared = [&](int pipe, int 
a_off, bool pred = true) {\n    if (pred) {\n      int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n#pragma unroll\n      for (int i = 0; i < a_sh_wr_iters; i++) {\n        cp_async4_pred(\n            &sh_a_stage[a_sh_wr_trans[i]],\n            &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],\n            a_sh_wr_pred[i]);\n      }\n      int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n#pragma unroll\n      for (int i = 0; i < b_sh_wr_iters; i++) {\n#pragma unroll\n        for (int j = 0; j < b_thread_vecs; j++) {\n          cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);\n        }\n\n        B_ptr[i] += b_gl_rd_delta_o;\n      }\n\n      if constexpr (has_act_order) {\n        // Fetch g_idx thread-block portion\n        int full_pipe = a_off;\n        int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;\n        if (cur_k < prob_k && cur_k < slice_k_finish) {\n          int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n\n          int4 const* cur_g_idx_stage_ptr = reinterpret_cast<int4 const*>(&g_idx[cur_k]);\n\n          if (threadIdx.x < g_idx_stage) {\n            cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]);\n          }\n        }\n      } else {\n        if constexpr (group_blocks != -1) {\n          int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n          if constexpr (group_blocks >= thread_k_blocks) {\n            // Only fetch scales if this tile starts a new group\n            if (pipe % (group_blocks / thread_k_blocks) == 0) {\n              if (s_sh_wr_pred) {\n                cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);\n              }\n              s_gl_rd += s_gl_rd_delta;\n            }\n          } else {\n            for (int i = 0; i < s_tb_groups; i++) {\n              if (s_sh_wr_pred) {\n                cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]);\n              }\n              s_gl_rd += s_gl_rd_delta;\n            }\n    
      }\n        }\n\n        if constexpr (has_zp && group_blocks != -1) {\n          int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;\n\n          if constexpr (group_blocks >= thread_k_blocks) {\n            // Only fetch zero-points if this tile starts a new group\n            if (pipe % (group_blocks / thread_k_blocks) == 0) {\n              if (zp_sh_wr_pred) {\n                cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);\n              }\n              zp_gl_rd += zp_gl_rd_delta;\n            }\n          } else {\n            for (int i = 0; i < zp_tb_groups; i++) {\n              if (zp_sh_wr_pred) {\n                cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]);\n              }\n              zp_gl_rd += zp_gl_rd_delta;\n            }\n          }\n        }\n      }\n    }\n    // Insert a fence even when we are winding down the pipeline to ensure that\n    // waiting is also correct at this point.\n    cp_async_fence();\n  };\n\n  auto fetch_col_zp_to_shared = [&]() {\n    if (zp_sh_wr_pred) {\n      cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);\n    }\n  };\n\n  auto fetch_col_scale_to_shared = [&]() {\n    if (s_sh_wr_pred) {\n      cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n    }\n  };\n\n  // Wait until the next thread tile has been loaded to shared memory.\n  auto wait_for_stage = [&]() {\n    // We only have `stages - 2` active fetches since we are double buffering\n    // and can only issue the next fetch when it is guaranteed that the previous\n    // shared memory load is fully complete (as it may otherwise be\n    // overwritten).\n    cp_async_wait<stages - 2>();\n    __syncthreads();\n  };\n\n  // Load the next sub-tile from the current location in the shared memory pipe\n  // into the current register buffer.\n  auto fetch_to_registers = [&](int k, int pipe) {\n    int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n#pragma unroll\n    for (int i = 0; i < thread_m_blocks; i++)\n      ldsm<m_block_size_8 
? 2 : 4, scalar_t>(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);\n    int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n\n#pragma unroll\n    for (int i = 0; i < b_thread_vecs; i++) {\n      frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);\n    }\n  };\n\n  bool is_same_group[stages];\n  int same_group_id[stages];\n\n  auto init_same_group = [&](int pipe) {\n    if constexpr (!has_act_order) {\n      return;\n    }\n\n    int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n    int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n    int group_id_1 = sh_g_idx_int_ptr[0];\n    int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];\n\n    is_same_group[pipe] = group_id_1 == group_id_2;\n    same_group_id[pipe] = group_id_1;\n  };\n\n  auto fetch_scales_to_registers = [&](int k, int full_pipe) {\n    int pipe = full_pipe % stages;\n\n    if constexpr (!has_act_order) {\n      // No act-order case\n      if constexpr (group_blocks == -1) {\n        // load only when starting a new slice\n        if (k == 0 && full_pipe == 0) {\n          reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd];\n          reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n        }\n      } else if constexpr (group_blocks != -1) {\n        if constexpr (group_blocks >= thread_k_blocks) {\n          if (k % b_sh_wr_iters == 0) {\n            int4* sh_s_stage =\n                sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));\n            reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];\n          } else {\n            reinterpret_cast<int4*>(&frag_s[1])[0] = reinterpret_cast<int4*>(&frag_s[0])[0];\n          }\n        } else {\n          auto warp_id = threadIdx.x / 32;\n          int n_warps = thread_n_blocks / 4;\n\n          int warp_row = warp_id / n_warps;\n\n          int cur_k = warp_row * 16;\n          cur_k += 
k_iter_size * (k % b_sh_wr_iters);\n\n          int k_blocks = cur_k / 16;\n          int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 2 : 1));\n\n          int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n          if constexpr (w_type_id != host::kFE2M1f.id()) {\n            reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];\n          } else {\n            reinterpret_cast<int2*>(&frag_s[k % 2])[0] =\n                reinterpret_cast<int2*>(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];\n          }\n        }\n      }\n\n      return;\n    }\n\n    // Act-order case\n\n    // Determine K of the \"current\" thread-block\n    int cur_k = slice_k_start + tb_k * full_pipe;\n    if (cur_k >= prob_k || cur_k >= slice_k_finish) {\n      return;\n    }\n\n    // Reset (to current thread-block) since we read g_idx portion from the\n    // shared memory\n    cur_k = 0;\n\n    // Progress to current iteration\n    cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n    // Determine \"position\" inside the thread-block (based on warp and\n    // thread-id)\n    auto warp_id = threadIdx.x / 32;\n    int n_warps = thread_n_blocks / 4;  // Each warp processes 4 16-size tiles over N\n\n    int warp_row = warp_id / n_warps;\n    int warp_col = warp_id % n_warps;\n\n    cur_k += warp_row * 16;\n\n    auto th_id = threadIdx.x % 32;\n    cur_k += (th_id % 4) * 2;  // Due to tensor-core layout for fp16 B matrix\n\n    int s_col_shift =\n        /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride;\n\n    if (is_same_group[pipe]) {\n      if (k % 2 == 0) {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n            sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift];\n      } else {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n            *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));\n      }\n\n  
    for (int i = 1; i < 4; i++) {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));\n      }\n      return;\n    }\n\n    int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n    int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n    constexpr int k_frag_offsets[4] = {0, 1, 8, 9};  // Tensor core offsets per thread\n\n#pragma unroll\n    for (int i = 0; i < 4; i++) {\n      int actual_k = cur_k + k_frag_offsets[i];\n\n      int group_id = sh_g_idx_int_ptr[actual_k];\n      int rel_group_id = group_id - sh_first_group_id;\n\n      *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift];\n    }\n  };\n\n  auto fetch_zp_to_registers = [&](int k, int full_pipe) {\n    // This code does not handle group_blocks == 0,\n    // which signifies act_order.\n    // has_zp implies AWQ, which doesn't have act_order,\n    static_assert(!has_zp || group_blocks != 0);\n\n    if constexpr (has_zp && !is_zp_float) {\n      int pipe = full_pipe % stages;\n\n      if constexpr (group_blocks == -1) {\n        // load only when starting a new slice\n        if (k == 0 && full_pipe == 0) {\n#pragma unroll\n          for (int i = 0; i < num_ints_per_thread; i++) {\n            frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];\n          }\n        }\n\n      } else if constexpr (group_blocks >= thread_k_blocks) {\n        if (k % b_sh_wr_iters == 0) {\n          int4* sh_zp_stage =\n              sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));\n#pragma unroll\n          for (int i = 0; i < num_ints_per_thread; i++) {\n            frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];\n          }\n        }\n      } else {\n        auto warp_id = threadIdx.x / 32;\n        int n_warps = thread_n_blocks / 4;\n\n        int warp_row = warp_id / n_warps;\n\n        
int cur_k = warp_row * 16;\n        cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n        int k_blocks = cur_k / 16;\n        int cur_group_id = 0;\n\n        // Suppress bogus and persistent divide-by-zero warning\n#pragma nv_diagnostic push\n#pragma nv_diag_suppress divide_by_zero\n        cur_group_id = k_blocks / group_blocks;\n#pragma nv_diagnostic pop\n\n        int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;\n\n        sh_zp_stage += cur_group_id * zp_sh_stride;\n\n#pragma unroll\n        for (int i = 0; i < num_ints_per_thread; i++) {\n          frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];\n        }\n      }\n    }\n\n    else if constexpr (has_zp && is_zp_float) {\n      int pipe = full_pipe % stages;\n\n      if constexpr (group_blocks != -1) {\n        if constexpr (group_blocks >= thread_k_blocks) {\n          if (k % b_sh_wr_iters == 0) {\n            int4* sh_zp_stage =\n                sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));\n            reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];\n          }\n        } else {\n          auto warp_id = threadIdx.x / 32;\n          int n_warps = thread_n_blocks / 4;\n\n          int warp_row = warp_id / n_warps;\n\n          int cur_k = warp_row * 16;\n          cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n          int k_blocks = cur_k / 16;\n          // Suppress bogus and persistent divide-by-zero warning\n#pragma nv_diagnostic push\n#pragma nv_diag_suppress divide_by_zero\n          int cur_group_id = k_blocks / group_blocks;\n#pragma nv_diagnostic pop\n\n          int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;\n\n          reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride];\n        }\n      }\n    }\n  };\n\n  auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {\n    dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);\n  };\n\n  
// Execute the actual tensor core matmul of a sub-tile.\n  bool is_first_matmul_in_slice = true;\n  auto matmul = [&](int k) {\n    int k2 = k % 2;\n    const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) ||\n                           (group_blocks == -1 && is_first_matmul_in_slice);\n    if constexpr (has_zp && !is_zp_float) {\n      if (is_new_zp) {\n        if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;\n        FragB frag_zp_0;\n        FragB frag_zp_1;\n        int zp_quant_0, zp_quant_1;\n\n        if constexpr (w_type.size_bits() == 4) {\n          zp_quant_0 = frag_qzp[k2][0];\n          zp_quant_1 = zp_quant_0 >> 8;\n        } else {\n          static_assert(w_type.size_bits() == 8);\n          zp_quant_0 = frag_qzp[k2][0];\n          zp_quant_1 = frag_qzp[k2][1];\n        }\n\n        dequant_data(zp_quant_0, reinterpret_cast<scalar_t2*>(&frag_zp));\n        dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);\n      }\n    }\n    if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {\n      if (is_new_zp) {\n        reinterpret_cast<int4*>(&frag_zp)[0] = reinterpret_cast<int4*>(&frag_zpf[k2])[0];\n      }\n    }\n\n    if constexpr (w_type == host::kFE2M1f) {\n      int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];\n      int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];\n\n      dequant_fp8_scales<scalar_t2>(s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));\n      dequant_fp8_scales<scalar_t2>(s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);\n    }\n\n// We have the m dimension as the inner loop in order to encourage overlapping\n// dequantization and matmul operations.\n#pragma unroll\n    for (int j = 0; j < 4; j++) {\n      FragB frag_b0;\n      FragB frag_b1;\n      int b_quant_0, b_quant_1;\n\n      if constexpr (w_type_id == host::kFE2M1f.id()) {\n        b_quant_1 = frag_b_quant[k2][0][j];\n        b_quant_0 = b_quant_1 << 8;\n      } 
else if constexpr (w_type.size_bits() == 4) {\n        b_quant_0 = frag_b_quant[k2][0][j];\n        b_quant_1 = b_quant_0 >> 8;\n      } else {\n        static_assert(w_type.size_bits() == 8);\n        int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k2]);\n        b_quant_0 = frag_b_quant_ptr[j * 2 + 0];\n        b_quant_1 = frag_b_quant_ptr[j * 2 + 1];\n      }\n\n      dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));\n      dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));\n\n      if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {\n        sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);\n        sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);\n      }\n\n      // Apply scale to frag_b0\n      if constexpr (has_act_order) {\n        static_assert(group_blocks != -1);\n        scale4<scalar_t>(\n            frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);\n        scale4<scalar_t>(\n            frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);\n      } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) {\n        int idx = (threadIdx.x / 4) % 2;\n        scalar_t2 s2 = Dtype::nums2num2(\n            reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],\n            reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 1])[idx]);\n        if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);\n        scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);\n        scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);\n      } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {\n        if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));\n        scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);\n        scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);\n      } else if constexpr 
(group_blocks != -1) {\n        scale<scalar_t>(frag_b0, frag_s[k2][j], 0);\n        scale<scalar_t>(frag_b1, frag_s[k2][j], 1);\n      }\n\n#pragma unroll\n      for (int i = 0; i < thread_m_blocks; i++) {\n        if constexpr (m_block_size_8) {\n          mma_trans<scalar_t>(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]);\n        } else {\n          mma<scalar_t>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);\n          mma<scalar_t>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);\n        }\n      }\n    }\n  };\n\n  // Since we slice across the k dimension of a tile in order to increase the\n  // number of warps while keeping the n dimension of a tile reasonable, we have\n  // multiple warps that accumulate their partial sums of the same output\n  // location, which we have to reduce over in the end. We do this in shared memory.\n  auto thread_block_reduce = [&]() {\n    constexpr int red_off = threads / b_sh_stride_threads / 2;\n    if (red_off >= 1) {\n      auto red_idx = threadIdx.x / b_sh_stride_threads;\n      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;\n      constexpr int red_sh_delta = b_sh_stride_threads;\n      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads);\n\n      // Parallel logarithmic shared memory reduction. We make sure to avoid any\n      // unnecessary read or write iterations, e.g., for two warps we write only\n      // once by warp 1 and read only once by warp 0.\n\n#pragma unroll\n      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {\n#pragma unroll\n        for (int i = red_off; i > 0; i /= 2) {\n          if (i <= red_idx && red_idx < 2 * i) {\n#pragma unroll\n            for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 
2 : 1)) {\n              int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);\n              if (i < red_off) {\n                float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * j + red_sh_rd]);\n                float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);\n#pragma unroll\n                for (int k = 0; k < 4; k++)\n                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];\n              }\n              sh_red[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];\n            }\n          }\n          __syncthreads();\n        }\n        if (red_idx == 0) {\n#pragma unroll\n          for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) {\n            float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);\n#pragma unroll\n            for (int j = 0; j < 4; j++)\n              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];\n          }\n        }\n        __syncthreads();\n      }\n    }\n  };\n\n  // Since multiple threadblocks may process parts of the same column slice, we\n  // finally have to globally reduce over the results. As the striped\n  // partitioning minimizes the number of such reductions and our outputs are\n  // usually rather small, we perform this reduction serially in L2 cache.\n  auto global_reduce_fp16 = [&](bool first = false, bool last = false) {\n    // We are very careful here to reduce directly in the output buffer to\n    // maximize L2 cache utilization in this step. 
To do this, we write out\n    // results in FP16 (but still reduce with FP32 compute).\n    constexpr int active_threads = 32 * thread_n_blocks / 4;\n    if (threadIdx.x < active_threads) {\n      int c_gl_stride = prob_n / 8;\n      int c_gl_wr_delta_o = 8 * c_gl_stride;\n      int c_gl_wr_delta_i = 4 * (active_threads / 32);\n      int c_gl_wr;\n      if constexpr (m_block_size_8) {\n        c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8;\n        c_gl_wr += (2 * thread_n_blocks) * slice_col;\n      } else {\n        c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4;\n        c_gl_wr += (2 * thread_n_blocks) * slice_col;\n      }\n      constexpr int c_sh_wr_delta = active_threads;\n      auto c_sh_wr = threadIdx.x;\n\n      int row = (threadIdx.x % 32) / 4;\n\n      if (!first) {\n// Interestingly, doing direct global accesses here really seems to mess up\n// the compiler and lead to slowdowns, hence we also use async-copies even\n// though these fetches are not actually asynchronous.\n#pragma unroll\n        for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) {\n          if constexpr (m_block_size_8) {\n            cp_async4_pred(\n                &sh_red[c_sh_wr + c_sh_wr_delta * i],\n                &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i],\n                (threadIdx.x % 4) * 2 + i < prob_m);\n          } else {\n            cp_async4_pred(\n                &sh_red[c_sh_wr + c_sh_wr_delta * i],\n                &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],\n                i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);\n          }\n        }\n        cp_async_fence();\n        cp_async_wait<0>();\n      }\n\n#pragma unroll\n      for (int i = 0; i < (m_block_size_8 ? 
2 : thread_m_blocks * 4); i++) {\n        bool mask = (!m_block_size_8) && (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) ||\n                    (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m);\n        if (mask) {\n          if (!first) {\n            int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];\n#pragma unroll\n            for (int j = 0; j < 2 * 4; j++) {\n              int delta = 0;\n              if constexpr (m_block_size_8) {\n                delta = j % 2 == 1 ? -2 : 0;\n              }\n              reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] +=\n                  Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);\n            }\n          }\n          if (!last) {\n            int4 c;\n#pragma unroll\n            for (int j = 0; j < 2 * 4; j++) {\n              int delta = 0;\n              if constexpr (m_block_size_8) {\n                delta = j % 2 == 1 ? -2 : 0;\n              }\n              reinterpret_cast<scalar_t*>(&c)[j] =\n                  Dtype::float2num(reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]);\n            }\n            if constexpr (m_block_size_8)\n              C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c;\n            else\n              C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c;\n          }\n        }\n      }\n    }\n  };\n\n  // Globally reduce over threadblocks that compute the same column block.\n  // We use a tmp C buffer to reduce in full fp32 precision.\n  auto global_reduce_fp32 = [&](bool first = false, bool last = false) {\n    constexpr int tb_m = thread_m_blocks * 16;\n    constexpr int tb_n = thread_n_blocks * 16;\n\n    constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;\n\n    constexpr int active_threads = 32 * thread_n_blocks / 4;\n    bool is_th_active = threadIdx.x < active_threads;\n\n    constexpr int num_floats = thread_m_blocks * 
4 * 2 * 4;\n    constexpr int th_size = num_floats * sizeof(float) / 16;\n\n    int c_cur_offset = locks_off * c_size;\n\n    if (!is_th_active) {\n      return;\n    }\n\n    if (!first) {\n      float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);\n#pragma unroll\n      for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) {\n        sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x];\n\n        float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);\n#pragma unroll\n        for (int f = 0; f < 4; f++) {\n          frag_c_ptr[k * 4 + f] += sh_c_ptr[f];\n        }\n      }\n    }\n\n    if (!last) {\n      int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);\n#pragma unroll\n      for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) {\n        C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];\n      }\n    }\n  };\n\n  // Write out the reduced final result in the correct layout. We only actually\n  // reshuffle matrix fragments in this step; the reduction above is performed\n  // in fragment layout.\n  auto write_result = [&]() {\n    int c_gl_stride = prob_n / 8;\n    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;\n    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));\n    constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));\n\n    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));\n    c_gl_wr += (2 * thread_n_blocks) * slice_col;\n    int c_sh_wr;\n    if constexpr (m_block_size_8) {\n      c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4;\n      c_sh_wr += 64 * (threadIdx.x / 32);\n    } else {\n      c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;\n      c_sh_wr += 32 * (threadIdx.x / 32);\n    }\n\n    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));\n\n    
int c_gl_wr_end = c_gl_stride * prob_m;\n    // We first reorder in shared memory to guarantee the most efficient final\n    // global write patterns\n    auto write = [&](int idx, float c0, float c1, FragS& s) {\n      scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));\n\n      // For per-column quantization we finally apply the scale here (only for\n      // 4-bit)\n      if constexpr (\n          !has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) {\n        res = __hmul2(res, s[0]);\n      }\n\n      if constexpr (w_type == host::kFE2M1f) {\n        res = __hmul2(res, global_scale);\n      }\n\n      if constexpr (m_block_size_8) {\n        ((scalar_t*)sh_red)[idx] = res.x;\n        ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;\n      } else {\n        ((scalar_t2*)sh_red)[idx] = res;\n      }\n    };\n\n    if (threadIdx.x / 32 < thread_n_blocks / 4) {\n#pragma unroll\n      for (int i = 0; i < thread_m_blocks; i++) {\n#pragma unroll\n        for (int j = 0; j < 4; j++) {\n          if constexpr (m_block_size_8) {\n            int wr = c_sh_wr + 16 * j;\n            write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);\n            write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 1]);\n          } else {\n            int wr = c_sh_wr + 8 * j;\n            write(\n                wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);\n            write(\n                wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);\n            write(\n                wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);\n            write(\n                wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);\n          
}\n        }\n        c_sh_wr += 16 * (4 * c_sh_stride);\n      }\n    }\n    __syncthreads();\n\n#pragma unroll\n    for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {\n      if (c_gl_wr < c_gl_wr_end) {\n        if (use_atomic_add && slice_count > 1) {\n          scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[c_gl_wr]);\n          scalar_t2* sh_red_half2 = reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);\n#pragma unroll\n          for (int a = 0; a < 4; a++) {\n            atomicAdd(&C_half2[a], sh_red_half2[a]);\n          }\n        } else {\n          C[c_gl_wr] = sh_red[c_sh_rd];\n        }\n        c_gl_wr += c_gl_wr_delta;\n        c_sh_rd += c_sh_rd_delta;\n      }\n    }\n    __syncthreads();\n  };\n\n  // Start global fetch and register load pipelines.\n  auto start_pipes = [&]() {\n\n#pragma unroll\n    for (int i = 0; i < stages - 1; i++) {\n      if (has_act_order && i == 0) {\n        int last_g_idx = slice_k_start + stages * tb_k * 2;\n        if (last_g_idx >= prob_k) {\n          last_g_idx = prob_k - 1;\n        }\n        fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);\n      }\n\n      if constexpr (has_zp && !is_zp_float && group_blocks == -1) {\n        if (i == 0) {\n          fetch_col_zp_to_shared();\n          if constexpr (!dequant_skip_flop) {\n            fetch_col_scale_to_shared();\n          }\n        }\n      }\n      fetch_to_shared(i, i, i < slice_iters);\n    }\n\n    zero_accums();\n    wait_for_stage();\n    init_same_group(0);\n    fetch_to_registers(0, 0);\n    fetch_scales_to_registers(0, 0);\n    fetch_zp_to_registers(0, 0);\n    a_gl_rd += a_gl_rd_delta_o * (stages - 1);\n    if constexpr (has_act_order) {\n      slice_k_start_shared_fetch += tb_k * (stages - 1);\n    }\n  };\n  if (slice_iters) {\n    start_pipes();\n  }\n\n  // Main loop.\n  while (slice_iters) {\n    // We unroll over both the global fetch and the register load pipeline 
to\n    // ensure all shared memory accesses are static. Note that both pipelines\n    // have even length meaning that the next iteration will always start at\n    // index 0.\n\n#pragma unroll\n    for (int pipe = 0; pipe < stages;) {\n#pragma unroll\n      for (int k = 0; k < b_sh_wr_iters; k++) {\n        fetch_to_registers(k + 1, pipe % stages);\n        fetch_scales_to_registers(k + 1, pipe);\n        fetch_zp_to_registers(k + 1, pipe);\n        if (k == b_sh_wr_iters - 2) {\n          fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);\n          pipe++;\n          wait_for_stage();\n          init_same_group(pipe % stages);\n        }\n        matmul(k);\n      }\n      slice_iters--;\n      if (slice_iters == 0) {\n        break;\n      }\n    }\n\n    a_gl_rd += a_gl_rd_delta_o * stages;\n\n    if constexpr (has_act_order) {\n      slice_k_start += tb_k * stages;\n\n      if (slice_k_start < prob_k) {\n        slice_k_start_shared_fetch += tb_k * stages;\n        int first_group_id = g_idx[slice_k_start];\n        int last_g_idx = slice_k_start + stages * tb_k * 2;\n        if (last_g_idx >= prob_k) {\n          last_g_idx = prob_k - 1;\n        }\n        int last_group_id = g_idx[last_g_idx];\n        if (last_group_id >= sh_first_group_id + sh_num_groups) {\n          fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);\n          __syncthreads();\n        }\n      }\n    }\n\n    // Process results and, if necessary, proceed to the next column slice.\n    // While this pattern may not be the most readable, other ways of writing\n    // the loop seemed to lead to noticeably worse performance after compilation.\n    if (slice_iters == 0) {\n      cp_async_wait<0>();\n      bool last = slice_idx == slice_count - 1;\n      // For per-column scales, we only fetch them here in the final step before\n      // write-out\n      if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) {\n 
       if (w_type.size_bits() == 8 || (last || use_atomic_add)) {\n          if (s_sh_wr_pred) {\n            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n          }\n          cp_async_fence();\n        }\n      }\n\n      thread_block_reduce();\n      if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) {\n        if (w_type.size_bits() == 8 || (last || use_atomic_add)) {\n          cp_async_wait<0>();\n          __syncthreads();\n          if (threadIdx.x / 32 < thread_n_blocks / 4) {\n            reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n            reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n            if constexpr (m_block_size_8) {\n              int idx = (threadIdx.x / 4) % 2;\n              scalar_t2* frag_s_half2 = reinterpret_cast<scalar_t2*>(frag_s);\n#pragma unroll\n              for (int i = 0; i < 8; i++) {\n                frag_s_half2[i] = Dtype::num2num2(reinterpret_cast<scalar_t*>(&frag_s_half2[i])[idx]);\n              }\n            }\n          }\n        }\n      }\n\n      // For 8-bit channelwise, we apply the scale before the global reduction\n      // that converts the fp32 results to fp16 (so that we avoid possible\n      // overflow in fp16)\n      if constexpr (\n          !has_act_order && group_blocks == -1 && w_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) {\n        if (threadIdx.x / 32 < thread_n_blocks / 4) {\n#pragma unroll\n          for (int i = 0; i < thread_m_blocks; i++) {\n#pragma unroll\n            for (int j = 0; j < 4; j++) {\n              scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]);\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 
1 : 0)]);\n\n              if constexpr (!m_block_size_8) {\n                scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]);\n                scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]);\n              }\n            }\n          }\n        }\n      }\n\n      if (slice_count > 1 && !use_atomic_add) {\n        // only globally reduce if there is more than one block in a slice\n        barrier_acquire(&locks[locks_off], slice_idx);\n        if (use_fp32_reduce) {\n          global_reduce_fp32(slice_idx == 0, last);\n        } else {\n          global_reduce_fp16(slice_idx == 0, last);\n        }\n        barrier_release(&locks[locks_off], last);\n      }\n      if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]);\n      if (last || use_atomic_add)\n        // only the last block in a slice actually writes the result\n        write_result();\n      slice_row = 0;\n      slice_col_par++;\n      slice_col++;\n      is_first_matmul_in_slice = true;\n      init_slice();\n\n      if (slice_iters) {\n        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);\n#pragma unroll\n        for (int i = 0; i < b_sh_wr_iters; i++)\n          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;\n        if (slice_col == 0) {\n#pragma unroll\n          for (int i = 0; i < b_sh_wr_iters; i++)\n            B_ptr[i] -= b_gl_stride;\n        }\n\n        // Update slice k/n for scales loading\n        if constexpr (has_act_order) {\n          slice_k_start = tb_k * slice_row;\n          slice_k_finish = slice_k_start + tb_k * slice_iters;\n          slice_k_start_shared_fetch = slice_k_start;\n          slice_n_offset = act_s_col_tb_stride * slice_col;\n\n        } else {\n          s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n          zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;\n  
      }\n\n        start_pipes();\n      }\n    }\n  }\n}\n\n}  // namespace device::marlin\n\n#endif\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/marlin_moe/kernel.h",
    "content": "\n#include <sgl_kernel/scalar_type.hpp>\n\n#include \"../marlin/marlin.cuh\"\n#include \"../marlin/marlin_dtypes.cuh\"\n\n#define MARLIN_KERNEL_PARAMS                                                                                         \\\n  const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp,            \\\n      const int4 *__restrict__ b_bias_ptr, const int4 *__restrict__ scales_ptr,                                      \\\n      const uint16_t *__restrict__ scale2_ptr, const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx,       \\\n      const int32_t *__restrict__ sorted_token_ids_ptr, const int32_t *__restrict__ expert_ids_ptr,                  \\\n      const int32_t *__restrict__ num_tokens_past_padded_ptr, const float *__restrict__ topk_weights_ptr, int top_k, \\\n      bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, int prob_n, int prob_k, int *locks,             \\\n      bool has_bias, bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem\n\nnamespace device::marlin_moe {\ntemplate <\n    typename scalar_t,                   // compute dtype, half or nv_float16\n    const host::ScalarTypeId w_type_id,  // weight ScalarType id\n    const host::ScalarTypeId s_type_id,  // weight scale ScalarType id\n    const int threads,                   // number of threads in a threadblock\n    const int thread_m_blocks,           // number of 16x16 blocks in the m\n                                         // dimension (batchsize) of the\n                                         // threadblock\n    const int thread_n_blocks,           // same for n dimension (output)\n    const int thread_k_blocks,           // same for k dimension (reduction)\n    const bool m_block_size_8,           // whether m_block_size == 8\n                                         // only works when thread_m_blocks == 1\n    const int stages,                    // number of stages for the 
async global->shared\n                                         // fetch pipeline\n    const int group_blocks,              // number of consecutive 16x16 blocks\n                                         // with a separate quantization scale\n    const bool is_zp_float               // is zero point of float16 type?\n    >\n__global__ void Marlin(MARLIN_KERNEL_PARAMS);\n\n}  // namespace device::marlin_moe\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h",
    "content": "/*\n * Modified by Neural Magic\n * Copyright (C) Marlin.2024 Elias Frantar\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/*\n * Adapted from https://github.com/IST-DASLab/marlin\n */\n\n#include <sgl_kernel/scalar_type.hpp>\n\n#include \"../marlin/dequant.h\"\n#include \"../marlin/marlin.cuh\"\n#include \"../marlin/marlin_dtypes.cuh\"\n\n#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)                                        \\\n  static_assert(                                                                         \\\n      std::is_same<scalar_t, half>::value || std::is_same<scalar_t, nv_bfloat16>::value, \\\n      \"only float16 and bfloat16 is supported\");\n\nnamespace device::marlin_moe {\nusing namespace device::marlin;\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n\ntemplate <\n    typename scalar_t,                   // compute dtype, half or nv_float16\n    const host::ScalarTypeId w_type_id,  // weight ScalarType id\n    const int threads,                   // number of threads in a threadblock\n    const int thread_m_blocks,           // number of 16x16 blocks in the m\n                                         // dimension (batchsize) of the\n                                         // threadblock\n    const int thread_n_blocks,           // same for n dimension (output)\n    const int thread_k_blocks,           // same for k dimension (reduction)\n    const bool m_block_size_8,           // whether 
m_block_size == 8\n                                         // only works when thread_m_blocks == 1\n    const int stages,                    // number of stages for the async global->shared\n                                         // fetch pipeline\n    const int group_blocks,              // number of consecutive 16x16 blocks\n                                         // with a separate quantization scale\n    const bool is_zp_float               // is zero point of float16 type?\n    >\n__global__ void Marlin(\n    const int4* __restrict__ A,                              // fp16 input matrix of shape mxk\n    const int4* __restrict__ B,                              // 4bit quantized weight matrix of shape kxn\n    int4* __restrict__ C,                                    // fp16 output buffer of shape mxn\n    int4* __restrict__ C_tmp,                                // fp32 tmp output buffer (for reduce)\n    const int4* __restrict__ scales_ptr,                     // fp16 quantization scales of shape\n                                                             // (k/groupsize)xn\n    const int4* __restrict__ zp_ptr,                         // 4bit packed zero-points of shape\n                                                             // (k/groupsize)x(n/pack_factor)\n    const int* __restrict__ g_idx,                           // int32 group indices of shape k\n    const int32_t* __restrict__ sorted_token_ids_ptr,        // moe sorted_ids\n    const int32_t* __restrict__ expert_ids_ptr,              // moe expert ids\n    const int32_t* __restrict__ num_tokens_past_padded_ptr,  // moe num tokens\n    const float* __restrict__ topk_weights_ptr,              // moe top weights\n    int top_k,                                               // num of experts per token\n    bool mul_topk_weights,                                   // mul topk weights or not\n    bool is_ep,                                              // expert parallelism\n    int num_groups,       
                                   // number of scale groups per output channel\n    int prob_m,                                              // batch dimension m\n    int prob_n,                                              // output dimension n\n    int prob_k,                                              // reduction dimension k\n    int* locks,                                              // extra global storage for barrier synchronization\n    bool use_atomic_add,                                     // whether to use atomic add to reduce\n    bool use_fp32_reduce,                                    // whether to use fp32 global reduce\n    int max_shared_mem) {}\n\n}  // namespace device::marlin_moe\n\n#else\n\n// m16n8k16 tensor core mma instruction with fp16 inputs and fp32\n// output/accumulation.\ntemplate <typename scalar_t>\n__device__ inline void\nmma(const typename ScalarType<scalar_t>::FragA& a_frag,\n    const typename ScalarType<scalar_t>::FragB& frag_b,\n    typename ScalarType<scalar_t>::FragC& frag_c) {\n  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);\n  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);\n  float* c = reinterpret_cast<float*>(&frag_c);\n  if constexpr (std::is_same<scalar_t, half>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]), \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), 
\"r\"(b[0]), \"r\"(b[1]), \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n  } else {\n    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n  }\n}\n\ntemplate <typename scalar_t>\n__device__ inline void mma_trans(\n    const typename ScalarType<scalar_t>::FragA& a_frag,\n    const typename ScalarType<scalar_t>::FragB& frag_b,\n    const typename ScalarType<scalar_t>::FragB& frag_b2,\n    typename ScalarType<scalar_t>::FragC& frag_c) {\n  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);\n  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);\n  const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);\n  float* c = reinterpret_cast<float*>(&frag_c);\n  if constexpr (std::is_same<scalar_t, half>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(b[0]),\n          \"r\"(b2[0]),\n          \"r\"(b[1]),\n          \"r\"(b2[1]),\n          \"r\"(a[0]),\n          \"r\"(a[1]),\n          \"f\"(c[0]),\n          \"f\"(c[1]),\n          \"f\"(c[2]),\n          \"f\"(c[3]));\n  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(b[0]),\n          \"r\"(b2[0]),\n          \"r\"(b[1]),\n          \"r\"(b2[1]),\n          \"r\"(a[0]),\n          \"r\"(a[1]),\n          \"f\"(c[0]),\n          \"f\"(c[1]),\n          \"f\"(c[2]),\n          \"f\"(c[3]));\n  } else {\n    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n  }\n}\n\n// Instruction for loading a full 16x16 matrix fragment of operand A from shared\n// memory, directly in tensor core layout.\ntemplate <int count, typename scalar_t>\n__device__ 
inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a, const void* smem_ptr) {\n  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  if constexpr (count == 4) {\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\\n\"\n                 : \"=r\"(a[0]), \"=r\"(a[1]), \"=r\"(a[2]), \"=r\"(a[3])\n                 : \"r\"(smem));\n  } else if constexpr (count == 2) {\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\\n\" : \"=r\"(a[0]), \"=r\"(a[1]) : \"r\"(smem));\n  } else if constexpr (count == 1) {\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\\n\" : \"=r\"(a[0]) : \"r\"(smem));\n  } else {\n    static_assert(count == 1 || count == 2 || count == 4, \"invalid count\");\n  }\n}\n\n// Multiply dequantized values by the corresponding quantization scale; used\n// only for grouped quantization.\ntemplate <typename scalar_t>\n__device__ inline void\nscale(typename ScalarType<scalar_t>::FragB& frag_b, typename ScalarType<scalar_t>::FragS& frag_s, int i) {\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 s = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);\n  frag_b[0] = __hmul2(frag_b[0], s);\n  frag_b[1] = __hmul2(frag_b[1], s);\n}\n\ntemplate <typename scalar_t>\n__device__ inline void scale_and_sub(typename ScalarType<scalar_t>::FragB& frag_b, scalar_t s, scalar_t zp) {\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 s2 = ScalarType<scalar_t>::num2num2(s);\n  scalar_t2 zp2 = ScalarType<scalar_t>::num2num2(zp);\n  frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2));\n  frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2));\n}\n\ntemplate <typename scalar_t>\n__device__ inline void\nsub_zp(typename ScalarType<scalar_t>::FragB& frag_b, typename ScalarType<scalar_t>::scalar_t2& frag_zp, int i) {\n  using scalar_t2 = typename 
ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 zp = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);\n  frag_b[0] = __hsub2(frag_b[0], zp);\n  frag_b[1] = __hsub2(frag_b[1], zp);\n}\n\n// Same as above, but for act_order (each K is multiplied individually)\ntemplate <typename scalar_t>\n__device__ inline void scale4(\n    typename ScalarType<scalar_t>::FragB& frag_b,\n    typename ScalarType<scalar_t>::FragS& frag_s_1,\n    typename ScalarType<scalar_t>::FragS& frag_s_2,\n    typename ScalarType<scalar_t>::FragS& frag_s_3,\n    typename ScalarType<scalar_t>::FragS& frag_s_4,\n    int i) {\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 s_val_1_2;\n  s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];\n  s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];\n\n  scalar_t2 s_val_3_4;\n  s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];\n  s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];\n\n  frag_b[0] = __hmul2(frag_b[0], s_val_1_2);\n  frag_b[1] = __hmul2(frag_b[1], s_val_3_4);\n}\n\n// Given 2 floats multiply by 2 scales (halves)\ntemplate <typename scalar_t>\n__device__ inline void scale_float(float* c, typename ScalarType<scalar_t>::FragS& s) {\n  scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);\n  c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));\n  c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));\n}\n\n// Wait until barrier reaches `count`, then lock for current threadblock.\n__device__ inline void barrier_acquire(int* lock, int count) {\n  if (threadIdx.x == 0) {\n    int state = -1;\n    do\n      // Guarantee that subsequent writes by this threadblock will be visible\n      // globally.\n      asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\\n\" : \"=r\"(state) : \"l\"(lock));\n    while (state != count);\n  }\n  __syncthreads();\n}\n\n// Release barrier and increment visitation count.\n__device__ inline void barrier_release(int* lock, bool 
reset = false) {\n  __syncthreads();\n  if (threadIdx.x == 0) {\n    if (reset) {\n      lock[0] = 0;\n      return;\n    }\n    int val = 1;\n    // Make sure that all writes since acquiring this barrier are visible\n    // globally, while releasing the barrier.\n    asm volatile(\"fence.acq_rel.gpu;\\n\");\n    asm volatile(\"red.relaxed.gpu.global.add.s32 [%0], %1;\\n\" : : \"l\"(lock), \"r\"(val));\n  }\n}\n\n// Wait until the lock value becomes negative, then atomically add 1 to it.\n__device__ inline void wait_negative_and_add(int* lock) {\n  if (threadIdx.x == 0) {\n    int state = 0;\n    do\n      // Guarantee that subsequent writes by this threadblock will be visible\n      // globally.\n      asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\\n\" : \"=r\"(state) : \"l\"(lock));\n    while (state >= 0);\n    atomicAdd(lock, 1);\n  }\n  __syncthreads();\n}\n\ntemplate <\n    typename scalar_t,                   // compute dtype, half or nv_bfloat16\n    const host::ScalarTypeId w_type_id,  // weight ScalarType id\n    const host::ScalarTypeId s_type_id,  // weight scale ScalarType id\n    const int threads,                   // number of threads in a threadblock\n    const int thread_m_blocks,           // number of 16x16 blocks in the m\n                                         // dimension (batchsize) of the\n                                         // threadblock\n    const int thread_n_blocks,           // same for n dimension (output)\n    const int thread_k_blocks,           // same for k dimension (reduction)\n    const bool m_block_size_8,           // whether m_block_size == 8\n                                         // only works when thread_m_blocks == 1\n    const int stages,                    // number of stages for the async global->shared\n                                         // fetch pipeline\n    const int group_blocks,              // number of consecutive 16x16 blocks\n                                         // with 
a separate quantization scale\n 
   const bool is_zp_float               // is zero point of float16 type?\n    >\n__global__ void Marlin(\n    const int4* __restrict__ A,  // fp16 input matrix of shape mxk\n    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn\n    int4* __restrict__ C,        // fp16 output buffer of shape mxn\n    int4* __restrict__ C_tmp,    // fp32 tmp output buffer (for reduce)\n    const int4* __restrict__ b_bias_ptr,\n    const int4* __restrict__ scales_ptr,                     // fp16 quantization scales of shape\n                                                             // (k/groupsize)xn\n    const uint16_t* __restrict__ scale2_ptr,                 // fp16 global scale (for nvfp4\n                                                             // only)\n    const int4* __restrict__ zp_ptr,                         // 4bit packed zero-points of shape\n                                                             // (k/groupsize)x(n/pack_factor)\n    const int* __restrict__ g_idx,                           // int32 group indices of shape k\n    const int32_t* __restrict__ sorted_token_ids_ptr,        // moe sorted_ids\n    const int32_t* __restrict__ expert_ids_ptr,              // moe expert ids\n    const int32_t* __restrict__ num_tokens_past_padded_ptr,  // moe num tokens\n    const float* __restrict__ topk_weights_ptr,              // moe top weights\n    int top_k,                                               // num of experts per token\n    bool mul_topk_weights,                                   // mul topk weights or not\n    bool is_ep,                                              // expert parallelism\n    int num_groups,                                          // number of scale groups per output channel\n    int prob_m,                                              // batch dimension m\n    int prob_n,                                              // output dimension n\n    int prob_k,                                              // 
reduction dimension k\n    int* locks,                                              // extra global storage for barrier synchronization\n    bool has_bias,\n    bool use_atomic_add,   // whether to use atomic add to reduce\n    bool use_fp32_reduce,  // whether to use fp32 global reduce\n    int max_shared_mem) {\n  // Each threadblock processes one \"stripe\" of the B matrix with (roughly) the\n  // same size, which might involve multiple column \"slices\" (of width 16 *\n  // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM\n  // example:\n  //   0 1 3\n  //   0 2 3\n  //   1 2 4\n  // While this kind of partitioning makes things somewhat more complicated, it\n  // ensures good utilization of all SMs for many kinds of shape and GPU\n  // configurations, while requiring as few slow global cross-threadblock\n  // reductions as possible.\n  using Dtype = ScalarType<scalar_t>;\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  using FragA = typename ScalarType<scalar_t>::FragA;\n  using FragB = typename ScalarType<scalar_t>::FragB;\n  using FragC = typename ScalarType<scalar_t>::FragC;\n  using FragS = typename ScalarType<scalar_t>::FragS;\n  using FragZP = typename ScalarType<scalar_t>::FragZP;\n\n  extern __shared__ int4 sh[];\n  static constexpr auto w_type = host::ScalarType::from_id(w_type_id);\n  static constexpr auto s_type = host::ScalarType::from_id(s_type_id);\n  if constexpr (w_type == host::kFE2M1f) {\n    static_assert(s_type == host::kFE4M3fn && group_blocks == 1 || s_type == host::kFE8M0fnu && group_blocks == 2);\n  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {\n    static_assert(s_type == host::kBFloat16);\n  } else if constexpr (std::is_same<scalar_t, half>::value) {\n    static_assert(s_type == host::kFloat16);\n  }\n\n  constexpr bool has_zp = w_type == host::kU4 || w_type == host::kU8;\n  constexpr bool is_int_type =\n      w_type == host::kU4 || w_type == host::kU8 || w_type == 
host::kU4B8 || w_type == host::kU8B128;\n  // see comments of dequant.h for more details\n  constexpr bool dequant_skip_flop = w_type == host::kFE4M3fn || w_type == host::kFE2M1f && s_type == host::kFE4M3fn ||\n                                     has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||\n                                     has_zp && !is_zp_float && !(w_type == host::kU8);\n\n  scalar_t2 global_scale;\n\n  constexpr bool has_act_order = group_blocks == 0;\n\n  constexpr int pack_factor = 32 / w_type.size_bits();\n  static_assert(thread_m_blocks == 1 || !m_block_size_8);\n  constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);\n  const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;\n  const int scales_expert_stride = prob_n * prob_k / group_size / (w_type == host::kFE2M1f ? 16 : 8);\n  const int zp_expert_stride =\n      is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4);\n  const int b_bias_expert_stride = prob_n / 8;\n\n  // parallel: num valid moe blocks\n  int num_tokens_past_padded = num_tokens_past_padded_ptr[0];\n  int parallel = num_tokens_past_padded / moe_block_size;\n  int num_valid_blocks = parallel;\n  if (is_ep) {\n    for (int i = 0; i < parallel; i++) {\n      if (expert_ids_ptr[i] == -1) num_valid_blocks--;\n    }\n  }\n  int num_invalid_blocks = parallel - num_valid_blocks;\n  parallel = num_valid_blocks;\n\n  int k_tiles = prob_k / 16 / thread_k_blocks;\n  int n_tiles = prob_n / 16 / thread_n_blocks;\n  int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);\n\n  if constexpr (!has_act_order && group_blocks != -1) {\n    if (group_blocks >= thread_k_blocks) {\n      // Ensure that the number of tiles in each stripe is a multiple of the\n      // groupsize; this avoids an annoying special case where a stripe starts\n      // in the middle of group.\n      iters = (group_blocks / thread_k_blocks) * 
div_ceil(iters, (group_blocks / thread_k_blocks));\n    }\n  }\n\n  int slice_row = (iters * blockIdx.x) % k_tiles;\n  int slice_col_par = (iters * blockIdx.x) / k_tiles;\n  int slice_col = slice_col_par;\n  int slice_iters;      // number of threadblock tiles in the current slice\n  int slice_count = 0;  // total number of active threadblocks in the current slice\n  int slice_idx;        // index of threadblock in current slice; numbered bottom to\n                        // top\n\n  int par_id = 0;\n  int block_id = -1;\n  int64_t expert_id = 0;  // use int64 to avoid computation result overflow\n  int old_expert_id = 0;\n  int64_t B_expert_off = 0;\n\n  int4* sh_block_sorted_ids_int4 = sh;\n  int4* sh_rd_block_sorted_ids_int4 = sh_block_sorted_ids_int4 + moe_block_size / 4;\n  int4* sh_block_topk_weights_int4 = sh_rd_block_sorted_ids_int4 + moe_block_size / 4;\n  // sh_block_topk_weights_int4 only need (moe_block_size / 4);\n  // but we pad to align to 256 bytes\n  int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size;\n  int32_t* sh_block_sorted_ids = reinterpret_cast<int*>(sh_block_sorted_ids_int4);\n  int32_t* sh_rd_block_sorted_ids = reinterpret_cast<int*>(sh_rd_block_sorted_ids_int4);\n  scalar_t2* sh_block_topk_weights = reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4);\n\n  int32_t block_num_valid_tokens = 0;\n  int32_t locks_off = 0;\n\n  // We can easily implement parallel problem execution by just remapping\n  // indices and advancing global pointers\n  if (slice_col_par >= n_tiles) {\n    slice_col = slice_col_par % n_tiles;\n    par_id = slice_col_par / n_tiles;\n  }\n  if (parallel * n_tiles >= gridDim.x) {\n    // when parallel * n_tiles >= sms\n    // then there are at most $sms$ conflict tile blocks\n    locks_off = blockIdx.x;\n  } else {\n    locks_off = (iters * blockIdx.x) / k_tiles - 1;\n  }\n\n  // read moe block data given block_id\n  // block_sorted_ids / block_num_valid_tokens / block_topk_weights\n  
auto read_moe_block_data = [&](int block_id) {\n    block_num_valid_tokens = moe_block_size;\n#pragma unroll\n    for (int i = 0; i < moe_block_size / 4; i++) {\n      int4 sorted_token_ids_int4 =\n          reinterpret_cast<const int4*>(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];\n      int* sorted_token_ids = reinterpret_cast<int*>(&sorted_token_ids_int4);\n#pragma unroll\n      for (int j = 0; j < 4; j++) {\n        if (sorted_token_ids[j] >= prob_m * top_k) {\n          block_num_valid_tokens = i * 4 + j;\n          break;\n        }\n      }\n      if (block_num_valid_tokens != moe_block_size) break;\n    }\n\n    __syncthreads();\n    int tid4 = threadIdx.x / 4;\n    if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) {\n      sh_block_sorted_ids_int4[tid4] =\n          reinterpret_cast<const int4*>(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];\n\n#pragma unroll\n      for (int i = 0; i < 4; i++)\n        sh_rd_block_sorted_ids[tid4 * 4 + i] = sh_block_sorted_ids[tid4 * 4 + i] / top_k;\n\n      if (mul_topk_weights) {\n#pragma unroll\n        for (int i = 0; i < 4; i++) {\n          int idx = tid4 * 4 + i;\n          // idx = idx < block_num_valid_tokens ? 
idx : 0;\n          if (idx < block_num_valid_tokens) {\n            if constexpr (w_type == host::kFE2M1f && s_type == host::kFE4M3fn) {\n              sh_block_topk_weights[idx] =\n                  __hmul2(global_scale, Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])));\n            } else {\n              sh_block_topk_weights[idx] =\n                  Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));\n            }\n          }\n        }\n      }\n    }\n    __syncthreads();\n  };\n\n  // when move to next moe block, find the next block_id and expert_id\n  // and then read moe block data\n  auto update_next_moe_block_data = [&]() {\n    if (par_id >= parallel) return;\n\n    old_expert_id = expert_id;\n    if (num_invalid_blocks > 0) {\n      int skip_count = block_id == -1 ? par_id : 0;\n      block_id++;\n      for (int i = block_id; i < num_tokens_past_padded / moe_block_size; i++) {\n        expert_id = expert_ids_ptr[i];\n        if (expert_id != -1) {\n          if (skip_count == 0) {\n            block_id = i;\n            break;\n          };\n          skip_count--;\n        };\n      }\n    } else {\n      block_id = par_id;\n      expert_id = expert_ids_ptr[block_id];\n    }\n\n    if constexpr (w_type == host::kFE2M1f && s_type == host::kFE4M3fn) {\n      uint16_t val = scale2_ptr[expert_id];\n      global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));\n    }\n\n    B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);\n    scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;\n    if constexpr (has_zp) {\n      zp_ptr += (expert_id - old_expert_id) * zp_expert_stride;\n    }\n    if constexpr (has_act_order) {\n      g_idx += (expert_id - old_expert_id) * prob_k;\n    }\n    if (has_bias) {\n      b_bias_ptr += (expert_id - old_expert_id) * b_bias_expert_stride;\n    }\n\n    read_moe_block_data(block_id);\n  };\n\n  // Compute all information about the 
current slice which is required for\n  // synchronization.\n  auto init_slice = [&](bool first_init = false) {\n    slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);\n    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;\n    if (slice_iters == 0) return;\n    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;\n    slice_count = 1;\n    slice_idx = 0;\n    int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);\n    if (col_first <= k_tiles * (slice_col_par + 1)) {\n      int col_off = col_first - k_tiles * slice_col_par;\n      slice_count = div_ceil(k_tiles - col_off, iters);\n      if (col_off > 0) slice_count++;\n      int delta_first = iters * blockIdx.x - col_first;\n      if (delta_first < 0 || (col_off == 0 && delta_first == 0))\n        slice_idx = slice_count - 1;\n      else {\n        slice_idx = slice_count - 1 - delta_first / iters;\n        if (col_off > 0) slice_idx--;\n      }\n    }\n    if (parallel * n_tiles >= gridDim.x) {\n      if (slice_count > 1 && slice_idx == slice_count - 1) {\n        locks_off++;\n      }\n    } else {\n      locks_off++;\n    }\n\n    if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) {\n      constexpr int threads_per_m = 16 * thread_n_blocks / 8;\n      int m_per_thread = div_ceil(block_num_valid_tokens, threads / threads_per_m);\n      for (int i = 0; i < m_per_thread; i++) {\n        int row = threads / threads_per_m * i + threadIdx.x / threads_per_m;\n        if (row < block_num_valid_tokens) {\n          int64_t sorted_row = sh_block_sorted_ids[row];\n          int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m;\n          C[sorted_row * prob_n / 8 + col] = {0, 0, 0, 0};\n        }\n      }\n      // After writing zeros to the output, write a negative value to the lock.\n      // Every SM that processes the same slice waits for\n      // the negative value and then does an atomicAdd of 1 on 
it.\n      // Once all SMs have finished, the lock value is back to 0.\n      __syncthreads();\n      if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count;\n    }\n\n    if (slice_col == n_tiles) {\n      slice_col = 0;\n      par_id++;\n      update_next_moe_block_data();\n    }\n  };\n\n  update_next_moe_block_data();\n  init_slice(true);\n\n  // A sizes/strides\n\n  // stride of the A matrix in global memory\n  int a_gl_stride = prob_k / 8;\n  // stride of an A matrix tile in shared memory\n  constexpr int a_sh_stride = 16 * thread_k_blocks / 8;\n  // delta between subsequent A tiles in global memory\n  constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;\n  // between subsequent accesses within a tile\n  int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);\n  // between shared memory writes\n  constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);\n  // between shared memory tile reads\n  constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));\n  // within a shared memory tile\n  constexpr int a_sh_rd_delta_i = a_sh_stride * 16;\n  // overall size of a tile\n  constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);\n  // number of shared write iterations for a tile\n  constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);\n\n  // B sizes/strides\n  int b_gl_stride = 16 * prob_n / (pack_factor * 4);\n  constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;\n  constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 
1 : 2;\n  constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;\n\n  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;\n  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);\n  constexpr int b_sh_wr_delta = threads * b_thread_vecs;\n  constexpr int b_sh_rd_delta = threads * b_thread_vecs;\n  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;\n  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;\n\n  // Scale sizes/strides without act_order\n  int s_gl_stride = prob_n / 8;\n  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;\n  constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks\n                                  ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1)\n                                  : 1;\n  constexpr int s_sh_stage = s_tb_groups * s_sh_stride;\n  int s_gl_rd_delta = s_gl_stride;\n\n  // Scale size/strides with act_order\n  constexpr int tb_k = 16 * thread_k_blocks;\n  constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;\n  // constexpr int act_s_row_stride      = 1;\n  // int           act_s_col_stride      = act_s_row_stride * num_groups;\n  constexpr int act_s_max_num_groups = 32;\n  int act_s_col_stride = 1;\n  int act_s_col_warp_stride = act_s_col_stride * 8;\n  int tb_n_warps = thread_n_blocks / 4;\n  int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;\n\n  // Zero-points sizes/strides\n  int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4;\n  constexpr int zp_sh_stride = is_zp_float ? 16 * thread_n_blocks / 8 : ((16 * thread_n_blocks) / pack_factor) / 4;\n  constexpr int zp_tb_groups = s_tb_groups;\n  constexpr int zp_sh_stage = has_zp ? 
zp_tb_groups * zp_sh_stride : 0;\n  int zp_gl_rd_delta = zp_gl_stride;\n\n  // Global A read index of current thread.\n  int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o;\n  int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o;\n\n  // Shared write index of current thread.\n  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);\n  // Shared read index.\n  int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) +\n                (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1));\n  a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));\n\n  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;\n  b_gl_rd += b_sh_stride * slice_col;\n  b_gl_rd += b_gl_rd_delta_o * slice_row;\n  auto b_sh_wr = threadIdx.x * b_thread_vecs;\n  auto b_sh_rd = threadIdx.x * b_thread_vecs;\n\n  // For act_order\n  constexpr int k_iter_size = tb_k / b_sh_wr_iters;\n  int slice_k_start = tb_k * slice_row;\n  int slice_k_finish = slice_k_start + tb_k * slice_iters;\n  int slice_k_start_shared_fetch = slice_k_start;\n  int slice_n_offset = act_s_col_tb_stride * slice_col;\n\n  // No act_order\n  int s_gl_rd;\n  if constexpr (!has_act_order) {\n    if constexpr (group_blocks == -1) {\n      s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n    } else {\n      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 
2 : 1) +\n                s_sh_stride * slice_col + threadIdx.x;\n    }\n  }\n  auto s_sh_wr = threadIdx.x;\n  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;\n\n  // Zero-points\n  int zp_gl_rd;\n  if constexpr (has_zp) {\n    if constexpr (group_blocks == -1) {\n      zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;\n    } else {\n      zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x;\n    }\n  }\n  auto zp_sh_wr = threadIdx.x;\n  bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;\n\n  // We use a different scale layout for grouped and column-wise quantization as\n  // we scale a `half2` tile in column-major layout in the former and in\n  // row-major in the latter case.\n  int s_sh_rd;\n  if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) {\n    auto warp_id = threadIdx.x / 32;\n    int n_warps = thread_n_blocks / 4;\n    int warp_row = warp_id / n_warps;\n\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;\n    s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;\n\n  } else if constexpr (group_blocks != -1)\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;\n  else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop)))\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8;\n  else\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;\n\n  int bias_sh_rd;\n  if constexpr (m_block_size_8) {\n    bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8;\n  } else {\n    bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;\n  }\n\n  int bias_sh_wr = threadIdx.x;\n  int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;\n\n  // Zero-points have the same read layout as the scales\n  // (without column-wise case)\n  constexpr int 
num_col_threads = 8;\n  constexpr int num_row_threads = 4;\n  constexpr int num_ints_per_thread = 8 / pack_factor;\n  int zp_sh_rd;\n  if constexpr (has_zp) {\n    if constexpr (is_zp_float) {\n      if constexpr (group_blocks != -1) {\n        zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;\n      }\n    } else {\n      zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +\n                 num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);\n    }\n  }\n\n  // To ensure that writing and reading A tiles to/from shared memory, the\n  // latter in fragment format, is fully bank conflict free, we need to use a\n  // rather fancy XOR-based layout. The key here is that neither reads nor\n  // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the\n  // same shared memory banks. Further, it seems (based on Nsight Compute) that\n  // each warp must also write a consecutive memory segment?\n  auto transform_a = [&](int i) {\n    int row = i / a_gl_rd_delta_o;\n    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8);\n  };\n  // Since the computation of this remapping is non-trivial and, due to our main\n  // loop unrolls, all shared memory accesses are static, we simply precompute\n  // both transformed reads and writes.\n  int a_sh_wr_trans[a_sh_wr_iters];\n#pragma unroll\n  for (int i = 0; i < a_sh_wr_iters; i++)\n    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);\n  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];\n#pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++) {\n#pragma unroll\n    for (int j = 0; j < thread_m_blocks; j++)\n      a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);\n  }\n\n  // Since B-accesses have non-constant stride they have to be computed at\n  // runtime; we break dependencies between subsequent accesses within a tile by\n  // maintaining multiple pointers (we have 
enough registers), a tiny\n  // optimization.\n  const int4* B_ptr[b_sh_wr_iters];\n#pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++)\n    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;\n\n  // Shared memory storage for global fetch pipelines.\n  constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks;\n  constexpr int sh_b_size = stages * b_sh_stage;\n  int4* sh_b = sh_new;\n  int4* sh_red = sh_new;\n\n  constexpr int sh_size_b_red_min = (sh_red_size < sh_b_size ? sh_red_size : sh_b_size);\n  constexpr int sh_size_b_red_max = (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);\n  constexpr int sh_bias_size = (thread_n_blocks * 16 / 8);\n  constexpr int sh_b_red_bias_size =\n      sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size) ? sh_size_b_red_max : (sh_size_b_red_min + sh_bias_size);\n\n  int4* sh_bias = sh_new + sh_size_b_red_min;\n  int4* sh_g_idx = sh_new + sh_b_red_bias_size;\n  int4* sh_zp = sh_g_idx + (stages * g_idx_stage);\n  constexpr int sh_s_size = has_act_order ? 
(act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage);\n  int4* sh_s = sh_zp + (stages * zp_sh_stage);\n  // shared memory reused by reduction should be smaller than\n  // shared memory used by weight.\n  static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage);\n  int4* sh_a = sh_s + sh_s_size;\n  constexpr int shm_size_used = moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size + sh_b_red_bias_size;\n\n  // all remaining shared memory is used to cache A (input)\n  // sh_a_max_row is at least ` stages * 16 * thread_m_blocks `\n  int sh_a_max_row = ((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2);\n\n  // Register storage for double buffer of shared memory reads.\n  FragA frag_a[2][thread_m_blocks];\n  I4 frag_b_quant[2][b_thread_vecs];\n  FragC frag_c[thread_m_blocks][4][2];\n  FragS frag_s[2][4];  // No act-order\n  FragS frag_bias[2][4];\n  FragS act_frag_s[2][4][4];             // For act-order\n  int frag_qzp[2][num_ints_per_thread];  // Zero-points\n  FragZP frag_zp;                        // Zero-points in fp16\n  FragZP frag_zpf[2];                    // Zero-points in fp16 in HQQ\n\n  // Zero accumulators.\n  auto zero_accums = [&]() {\n#pragma unroll\n    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)\n      reinterpret_cast<float*>(frag_c)[i] = 0;\n  };\n\n  int sh_first_group_id = -1;\n  int sh_num_groups = -1;\n\n  auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) {\n    sh_first_group_id = first_group_id;\n    sh_num_groups = last_group_id - first_group_id + 1;\n\n    if (sh_num_groups > act_s_max_num_groups) {\n      sh_num_groups = act_s_max_num_groups;\n    }\n\n    if (sh_first_group_id + sh_num_groups > num_groups) {\n      sh_num_groups = num_groups - sh_first_group_id;\n    }\n\n    int row_offset = first_group_id * s_gl_stride;\n\n    if (is_async) {\n      for (int i = 0; i < sh_num_groups; i++) {\n        if 
(threadIdx.x < s_sh_stride) {\n          cp_async4_pred(\n              &sh_s[(i * s_sh_stride) + threadIdx.x],\n              &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]);\n        }\n      }\n    } else {\n      for (int i = 0; i < sh_num_groups; i++) {\n        if (threadIdx.x < s_sh_stride) {\n          sh_s[(i * s_sh_stride) + threadIdx.x] =\n              scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x];\n        }\n      }\n    }\n  };\n\n  // Asynchronously fetch the next A, B and s tile from global to the next\n  // shared memory pipeline location.\n  bool should_load_a = true;\n  int max_num_stage_groups = ((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages;\n  max_num_stage_groups = max(max_num_stage_groups, 1);\n  auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true, int pipe_a = 0) {\n    if (pred) {\n      if (should_load_a) {\n        int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;\n#pragma unroll\n        for (int i = 0; i < a_sh_wr_iters; i++) {\n          int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row;\n          int64_t sorted_row = 0;\n          if (!m_block_size_8 || row < 8) sorted_row = sh_rd_block_sorted_ids[row];\n          int64_t true_idx = sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off;\n          cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], row < block_num_valid_tokens);\n        }\n      }\n\n      int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n#pragma unroll\n      for (int i = 0; i < b_sh_wr_iters; i++) {\n#pragma unroll\n        for (int j = 0; j < b_thread_vecs; j++) {\n          cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j + B_expert_off);\n        }\n\n        B_ptr[i] += b_gl_rd_delta_o;\n      }\n\n      if constexpr (has_act_order) {\n        // Fetch g_idx thread-block portion\n        int full_pipe = a_off;\n        int cur_k = slice_k_start_shared_fetch + tb_k * 
full_pipe;\n        if (cur_k < prob_k && cur_k < slice_k_finish) {\n          int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n\n          int4 const* cur_g_idx_stage_ptr = reinterpret_cast<int4 const*>(&g_idx[cur_k]);\n\n          if (threadIdx.x < g_idx_stage) {\n            cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]);\n          }\n        }\n      } else {\n        if constexpr (group_blocks != -1) {\n          int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n          if constexpr (group_blocks >= thread_k_blocks) {\n            // Only fetch scales if this tile starts a new group\n            if (pipe % (group_blocks / thread_k_blocks) == 0) {\n              if (s_sh_wr_pred) {\n                cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);\n              }\n              s_gl_rd += s_gl_rd_delta;\n            }\n          } else {\n            for (int i = 0; i < s_tb_groups; i++) {\n              if (s_sh_wr_pred) {\n                cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]);\n              }\n              s_gl_rd += s_gl_rd_delta;\n            }\n          }\n        }\n\n        if constexpr (has_zp && group_blocks != -1) {\n          int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;\n\n          if constexpr (group_blocks >= thread_k_blocks) {\n            // Only fetch zero-points if this tile starts a new group\n            if (pipe % (group_blocks / thread_k_blocks) == 0) {\n              if (zp_sh_wr_pred) {\n                cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);\n              }\n              zp_gl_rd += zp_gl_rd_delta;\n            }\n          } else {\n            for (int i = 0; i < zp_tb_groups; i++) {\n              if (zp_sh_wr_pred) {\n                cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]);\n              }\n              zp_gl_rd += zp_gl_rd_delta;\n            }\n          }\n        }\n      }\n    }\n    // 
Insert a fence even when we are winding down the pipeline to ensure that\n    // waiting is also correct at this point.\n    cp_async_fence();\n  };\n\n  auto fetch_col_zp_to_shared = [&]() {\n    if (zp_sh_wr_pred) {\n      cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);\n    }\n  };\n\n  auto fetch_col_scale_to_shared = [&]() {\n    if (s_sh_wr_pred) {\n      cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n    }\n  };\n\n  // Wait until the next thread tile has been loaded to shared memory.\n  auto wait_for_stage = [&]() {\n    // We only have `stages - 2` active fetches since we are double buffering\n    // and can only issue the next fetch when it is guaranteed that the previous\n    // shared memory load is fully complete (as it may otherwise be\n    // overwritten).\n    cp_async_wait<stages - 2>();\n    __syncthreads();\n  };\n\n  // Load the next sub-tile from the current location in the shared memory pipe\n  // into the current register buffer.\n  auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) {\n    int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;\n#pragma unroll\n    for (int i = 0; i < thread_m_blocks; i++)\n      ldsm<m_block_size_8 ? 
2 : 4, scalar_t>(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);\n    int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n\n#pragma unroll\n    for (int i = 0; i < b_thread_vecs; i++) {\n      frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);\n    }\n  };\n\n  bool is_same_group[stages];\n  int same_group_id[stages];\n\n  auto init_same_group = [&](int pipe) {\n    if constexpr (!has_act_order) {\n      return;\n    }\n\n    int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n    int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n    int group_id_1 = sh_g_idx_int_ptr[0];\n    int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];\n\n    is_same_group[pipe] = group_id_1 == group_id_2;\n    same_group_id[pipe] = group_id_1;\n  };\n\n  auto fetch_scales_to_registers = [&](int k, int full_pipe) {\n    int pipe = full_pipe % stages;\n\n    if constexpr (!has_act_order) {\n      // No act-order case\n      if constexpr (group_blocks == -1) {\n        // load only when starting a new slice\n        if (k == 0 && full_pipe == 0) {\n          reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd];\n          reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n        }\n      } else if constexpr (group_blocks != -1) {\n        if constexpr (group_blocks >= thread_k_blocks) {\n          if (k % b_sh_wr_iters == 0) {\n            int4* sh_s_stage =\n                sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));\n            reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];\n          } else {\n            reinterpret_cast<int4*>(&frag_s[1])[0] = reinterpret_cast<int4*>(&frag_s[0])[0];\n          }\n        } else {\n          auto warp_id = threadIdx.x / 32;\n          int n_warps = thread_n_blocks / 4;\n\n          int warp_row = warp_id / n_warps;\n\n          int cur_k = warp_row * 16;\n          cur_k += k_iter_size 
* (k % b_sh_wr_iters);\n\n          int k_blocks = cur_k / 16;\n          int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 2 : 1));\n\n          int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n          if constexpr (w_type_id != host::kFE2M1f.id()) {\n            reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];\n          } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {\n            reinterpret_cast<int2*>(&frag_s[k % 2])[0] =\n                reinterpret_cast<int2*>(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];\n          } else {\n            reinterpret_cast<int2*>(&frag_s[k % 2])[0] =\n                reinterpret_cast<int2*>(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + k % 2];\n          }\n        }\n      }\n\n      return;\n    }\n\n    // Act-order case\n\n    // Determine K of the \"current\" thread-block\n    int cur_k = slice_k_start + tb_k * full_pipe;\n    if (cur_k >= prob_k || cur_k >= slice_k_finish) {\n      return;\n    }\n\n    // Reset (to current thread-block) since we read g_idx portion from the\n    // shared memory\n    cur_k = 0;\n\n    // Progress to current iteration\n    cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n    // Determine \"position\" inside the thread-block (based on warp and\n    // thread-id)\n    auto warp_id = threadIdx.x / 32;\n    int n_warps = thread_n_blocks / 4;  // Each warp processes 4 16-size tiles over N\n\n    int warp_row = warp_id / n_warps;\n    int warp_col = warp_id % n_warps;\n\n    cur_k += warp_row * 16;\n\n    auto th_id = threadIdx.x % 32;\n    cur_k += (th_id % 4) * 2;  // Due to tensor-core layout for fp16 B matrix\n\n    int s_col_shift =\n        /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride;\n\n    if (is_same_group[pipe]) {\n      if (k % 2 == 0) {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n            
sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift];\n      } else {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n            *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));\n      }\n\n      for (int i = 1; i < 4; i++) {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));\n      }\n      return;\n    }\n\n    int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n    int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n    constexpr int k_frag_offsets[4] = {0, 1, 8, 9};  // Tensor core offsets per thread\n\n#pragma unroll\n    for (int i = 0; i < 4; i++) {\n      int actual_k = cur_k + k_frag_offsets[i];\n\n      int group_id = sh_g_idx_int_ptr[actual_k];\n      int rel_group_id = group_id - sh_first_group_id;\n\n      *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift];\n    }\n  };\n\n  auto fetch_zp_to_registers = [&](int k, int full_pipe) {\n    // This code does not handle group_blocks == 0,\n    // which signifies act_order.\n    // has_zp implies AWQ, which doesn't have act_order.\n    static_assert(!has_zp || group_blocks != 0);\n\n    if constexpr (has_zp && !is_zp_float) {\n      int pipe = full_pipe % stages;\n\n      if constexpr (group_blocks == -1) {\n        // load only when starting a new slice\n        if (k == 0 && full_pipe == 0) {\n#pragma unroll\n          for (int i = 0; i < num_ints_per_thread; i++) {\n            frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];\n          }\n        }\n\n      } else if constexpr (group_blocks >= thread_k_blocks) {\n        if (k % b_sh_wr_iters == 0) {\n          int4* sh_zp_stage =\n              sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));\n#pragma unroll\n          for (int i = 0; i < num_ints_per_thread; i++) {\n            frag_qzp[k 
% 2][i] = (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];\n          }\n        }\n      } else {\n        auto warp_id = threadIdx.x / 32;\n        int n_warps = thread_n_blocks / 4;\n\n        int warp_row = warp_id / n_warps;\n\n        int cur_k = warp_row * 16;\n        cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n        int k_blocks = cur_k / 16;\n        int cur_group_id = 0;\n\n        // Suppress bogus and persistent divide-by-zero warning\n#pragma nv_diagnostic push\n#pragma nv_diag_suppress divide_by_zero\n        cur_group_id = k_blocks / group_blocks;\n#pragma nv_diagnostic pop\n\n        int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;\n\n        sh_zp_stage += cur_group_id * zp_sh_stride;\n\n#pragma unroll\n        for (int i = 0; i < num_ints_per_thread; i++) {\n          frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];\n        }\n      }\n    }\n\n    else if constexpr (has_zp && is_zp_float) {\n      int pipe = full_pipe % stages;\n\n      if constexpr (group_blocks != -1) {\n        if constexpr (group_blocks >= thread_k_blocks) {\n          if (k % b_sh_wr_iters == 0) {\n            int4* sh_zp_stage =\n                sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));\n            reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];\n          }\n        } else {\n          auto warp_id = threadIdx.x / 32;\n          int n_warps = thread_n_blocks / 4;\n\n          int warp_row = warp_id / n_warps;\n\n          int cur_k = warp_row * 16;\n          cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n          int k_blocks = cur_k / 16;\n          // Suppress bogus and persistent divide-by-zero warning\n#pragma nv_diagnostic push\n#pragma nv_diag_suppress divide_by_zero\n          int cur_group_id = k_blocks / group_blocks;\n#pragma nv_diagnostic pop\n\n          int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;\n\n          
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride];\n        }\n      }\n    }\n  };\n\n  auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {\n    dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);\n  };\n\n  // Execute the actual tensor core matmul of a sub-tile.\n  bool is_first_matmul_in_slice = true;\n  auto matmul = [&](int k) {\n    int k2 = k % 2;\n    const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) ||\n                           (group_blocks == -1 && is_first_matmul_in_slice);\n    if constexpr (has_zp && !is_zp_float) {\n      if (is_new_zp) {\n        if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;\n        int zp_quant_0, zp_quant_1;\n\n        if constexpr (w_type.size_bits() == 4) {\n          zp_quant_0 = frag_qzp[k2][0];\n          zp_quant_1 = zp_quant_0 >> 8;\n        } else {\n          static_assert(w_type.size_bits() == 8);\n          zp_quant_0 = frag_qzp[k2][0];\n          zp_quant_1 = frag_qzp[k2][1];\n        }\n\n        dequant_data(zp_quant_0, reinterpret_cast<scalar_t2*>(&frag_zp));\n        dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);\n      }\n    }\n    if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {\n      if (is_new_zp) {\n        reinterpret_cast<int4*>(&frag_zp)[0] = reinterpret_cast<int4*>(&frag_zpf[k2])[0];\n      }\n    }\n\n    // Commented out FP4/FP8 scale dequantization since we don't generate\n    // kFE2M1f kernels to reduce compilation time\n    // if constexpr (w_type == host::kFE2M1f) {\n    //   int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];\n    //   int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];\n    //\n    //   dequant_fp8_scales<scalar_t2, s_type_id>(\n    //       s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));\n    //   dequant_fp8_scales<scalar_t2, s_type_id>(\n    //       s_quant_1, 
reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);\n    // }\n\n// We have the m dimension as the inner loop in order to encourage overlapping\n// dequantization and matmul operations.\n#pragma unroll\n    for (int j = 0; j < 4; j++) {\n      FragB frag_b0;\n      FragB frag_b1;\n      int b_quant_0, b_quant_1;\n\n      if constexpr (w_type_id == host::kFE2M1f.id()) {\n        b_quant_1 = frag_b_quant[k2][0][j];\n        b_quant_0 = b_quant_1 << 8;\n      } else if constexpr (w_type.size_bits() == 4) {\n        b_quant_0 = frag_b_quant[k2][0][j];\n        b_quant_1 = b_quant_0 >> 8;\n      } else {\n        static_assert(w_type.size_bits() == 8);\n        int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k2]);\n        b_quant_0 = frag_b_quant_ptr[j * 2 + 0];\n        b_quant_1 = frag_b_quant_ptr[j * 2 + 1];\n      }\n\n      dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));\n      dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));\n\n      if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {\n        sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);\n        sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);\n      }\n\n      // Apply scale to frag_b0\n      if constexpr (has_act_order) {\n        static_assert(group_blocks != -1);\n        scale4<scalar_t>(\n            frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);\n        scale4<scalar_t>(\n            frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);\n      } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) {\n        int idx = (threadIdx.x / 4) % 2;\n        scalar_t2 s2 = Dtype::nums2num2(\n            reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],\n            reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 1])[idx]);\n        if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);\n        
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);\n        scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);\n      } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {\n        if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));\n        scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);\n        scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);\n      } else if constexpr (group_blocks != -1) {\n        scale<scalar_t>(frag_b0, frag_s[k2][j], 0);\n        scale<scalar_t>(frag_b1, frag_s[k2][j], 1);\n      }\n\n#pragma unroll\n      for (int i = 0; i < thread_m_blocks; i++) {\n        if constexpr (m_block_size_8) {\n          mma_trans<scalar_t>(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]);\n        } else {\n          mma<scalar_t>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);\n          mma<scalar_t>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);\n        }\n      }\n    }\n  };\n\n  // Since we slice across the k dimension of a tile in order to increase the\n  // number of warps while keeping the n dimension of a tile reasonable, we have\n  // multiple warps that accumulate their partial sums of the same output\n  // location, which we have to reduce over in the end. We do this in shared memory.\n  auto thread_block_reduce = [&]() {\n    constexpr int red_off = threads / b_sh_stride_threads / 2;\n    if (red_off >= 1) {\n      auto red_idx = threadIdx.x / b_sh_stride_threads;\n      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;\n      constexpr int red_sh_delta = b_sh_stride_threads;\n      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads);\n\n      // Parallel logarithmic shared memory reduction. 
We make sure to avoid any\n      // unnecessary read or write iterations, e.g., for two warps we write only\n      // once by warp 1 and read only once by warp 0.\n\n#pragma unroll\n      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {\n#pragma unroll\n        for (int i = red_off; i > 0; i /= 2) {\n          if (i <= red_idx && red_idx < 2 * i) {\n#pragma unroll\n            for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) {\n              int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);\n              if (i < red_off) {\n                float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * j + red_sh_rd]);\n                float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);\n#pragma unroll\n                for (int k = 0; k < 4; k++)\n                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];\n              }\n              sh_red[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];\n            }\n          }\n          __syncthreads();\n        }\n        if (red_idx == 0) {\n#pragma unroll\n          for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) {\n            float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);\n#pragma unroll\n            for (int j = 0; j < 4; j++)\n              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];\n          }\n        }\n        __syncthreads();\n      }\n    }\n  };\n\n  // Since multiple threadblocks may process parts of the same column slice, we\n  // finally have to globally reduce over the results. 
As the striped\n  // partitioning minimizes the number of such reductions and our outputs are\n  // usually rather small, we perform this reduction serially in L2 cache.\n  auto global_reduce_fp16 = [&](bool first = false, bool last = false) {\n    // We are very careful here to reduce directly in the output buffer to\n    // maximize L2 cache utilization in this step. To do this, we write out\n    // results in FP16 (but still reduce with FP32 compute).\n    constexpr int active_threads = 32 * thread_n_blocks / 4;\n    bool is_th_active = threadIdx.x < active_threads;\n    if (!is_th_active) {\n      return;\n    }\n\n    int c_gl_stride = prob_n / 8;\n    int c_gl_wr_delta_o = 8 * c_gl_stride;\n    int c_gl_wr_delta_i = 4 * (active_threads / 32);\n    int c_gl_wr;\n    if constexpr (m_block_size_8) {\n      c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8;\n      c_gl_wr += (2 * thread_n_blocks) * slice_col;\n    } else {\n      c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4;\n      c_gl_wr += (2 * thread_n_blocks) * slice_col;\n    }\n    constexpr int c_sh_wr_delta = active_threads;\n    int c_sh_wr = threadIdx.x;\n\n    if (!first) {\n\n#pragma unroll\n      for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) {\n        int c_idx;\n        if constexpr (m_block_size_8)\n          c_idx = c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i;\n        else\n          c_idx = c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);\n        if (c_idx / c_gl_stride < block_num_valid_tokens) {\n          int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride];\n          int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;\n          sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx];\n        }\n      }\n    }\n\n#pragma unroll\n    for (int i = 0; i < (m_block_size_8 ? 
2 : thread_m_blocks * 4); i++) {\n      if (!first) {\n        int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];\n#pragma unroll\n        for (int j = 0; j < 2 * 4; j++) {\n          int delta = 0;\n          if constexpr (m_block_size_8) {\n            delta = j % 2 == 1 ? -2 : 0;\n          }\n          reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] +=\n              Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);\n        }\n      }\n      if (!last) {\n        int4 c;\n#pragma unroll\n        for (int j = 0; j < 2 * 4; j++) {\n          int delta = 0;\n          if constexpr (m_block_size_8) {\n            delta = j % 2 == 1 ? -2 : 0;\n          }\n          reinterpret_cast<scalar_t*>(&c)[j] =\n              Dtype::float2num(reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]);\n        }\n\n        int c_idx;\n        if constexpr (m_block_size_8)\n          c_idx = c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i;\n        else\n          c_idx = c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);\n        if (c_idx / c_gl_stride < block_num_valid_tokens) {\n          int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride];\n          int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;\n          C[true_idx] = c;\n        }\n      }\n    }\n  };\n\n  // Globally reduce over threadblocks that compute the same column block.\n  // We use a tmp C buffer to reduce in full fp32 precision.\n  auto global_reduce_fp32 = [&](bool first = false, bool last = false) {\n    constexpr int tb_m = thread_m_blocks * 16;\n    constexpr int tb_n = thread_n_blocks * 16;\n\n    constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;\n\n    constexpr int active_threads = 32 * thread_n_blocks / 4;\n    bool is_th_active = threadIdx.x < active_threads;\n\n    constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;\n    constexpr int th_size = num_floats * 
sizeof(float) / 16;\n\n    int c_cur_offset = locks_off * c_size;\n\n    if (!is_th_active) {\n      return;\n    }\n\n    if (!first) {\n      float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);\n#pragma unroll\n      for (int k = 0; k < th_size; k++) {\n        if constexpr (m_block_size_8) {\n          if (k % 2) continue;\n        } else {\n          if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) continue;\n        }\n\n        sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x];\n\n        float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);\n#pragma unroll\n        for (int f = 0; f < 4; f++) {\n          frag_c_ptr[k * 4 + f] += sh_c_ptr[f];\n        }\n      }\n    }\n\n    if (!last) {\n      int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);\n#pragma unroll\n      for (int k = 0; k < th_size; k++) {\n        if constexpr (m_block_size_8) {\n          if (k % 2) continue;\n        } else {\n          if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) continue;\n        }\n\n        C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];\n      }\n    }\n  };\n\n  // Write out the reduce final result in the correct layout. 
We only actually\n  // reshuffle matrix fragments in this step, the reduction above is performed\n  // in fragment layout.\n  auto write_result = [&](bool last) {\n    int c_gl_stride = prob_n / 8;\n    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;\n    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));\n    constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));\n\n    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));\n    c_gl_wr += (2 * thread_n_blocks) * slice_col;\n    int c_sh_wr;\n    if constexpr (m_block_size_8) {\n      c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4;\n      c_sh_wr += 64 * (threadIdx.x / 32);\n    } else {\n      c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;\n      c_sh_wr += 32 * (threadIdx.x / 32);\n    }\n\n    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));\n\n    // We first reorder in shared memory to guarantee the most efficient final\n    // global write patterns\n    auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {\n      scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));\n\n      // For per-column quantization we finally apply the scale here (only for\n      // 4-bit)\n      if constexpr (\n          !has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) {\n        scalar_t2 tmp_scale = s[0];\n        if constexpr (m_block_size_8) {\n          tmp_scale = Dtype::num2num2(reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);\n        }\n        res = __hmul2(res, tmp_scale);\n      }\n\n      if constexpr (w_type == host::kFE2M1f && s_type == host::kFE4M3fn) {\n        if (!mul_topk_weights) {\n          res = __hmul2(res, global_scale);\n        }\n      }\n      if (has_bias && last) {\n 
       scalar_t2 tmp_bias = b_bias[0];\n        if constexpr (m_block_size_8) {\n          tmp_bias = Dtype::num2num2(reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);\n        }\n        res = __hadd2(res, tmp_bias);\n      }\n\n      if constexpr (m_block_size_8) {\n        ((scalar_t*)sh_red)[idx] = res.x;\n        ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;\n      } else {\n        ((scalar_t2*)sh_red)[idx] = res;\n      }\n    };\n\n    if (threadIdx.x / 32 < thread_n_blocks / 4) {\n#pragma unroll\n      for (int i = 0; i < thread_m_blocks; i++) {\n#pragma unroll\n        for (int j = 0; j < 4; j++) {\n          if constexpr (m_block_size_8) {\n            int wr = c_sh_wr + 16 * j;\n            write(\n                wr,\n                frag_c[i][j][0][0],\n                frag_c[i][j][0][1],\n                frag_s[j / 2][2 * (j % 2) + 0],\n                frag_bias[j / 2][2 * (j % 2) + 0]);\n            write(\n                wr + 8,\n                frag_c[i][j][0][2],\n                frag_c[i][j][0][3],\n                frag_s[j / 2][2 * (j % 2) + 1],\n                frag_bias[j / 2][2 * (j % 2) + 1]);\n          } else {\n            int wr = c_sh_wr + 8 * j;\n            write(\n                wr + (4 * c_sh_stride) * 0 + 0,\n                frag_c[i][j][0][0],\n                frag_c[i][j][0][1],\n                frag_s[j / 2][2 * (j % 2) + 0],\n                frag_bias[j / 2][2 * (j % 2) + 0]);\n            write(\n                wr + (4 * c_sh_stride) * 8 + 0,\n                frag_c[i][j][0][2],\n                frag_c[i][j][0][3],\n                frag_s[j / 2][2 * (j % 2) + 0],\n                frag_bias[j / 2][2 * (j % 2) + 0]);\n            write(\n                wr + (4 * c_sh_stride) * 0 + 4,\n                frag_c[i][j][1][0],\n                frag_c[i][j][1][1],\n                frag_s[j / 2][2 * (j % 2) + 1],\n                frag_bias[j / 2][2 * (j % 2) + 1]);\n            write(\n                wr + (4 * 
c_sh_stride) * 8 + 4,\n                frag_c[i][j][1][2],\n                frag_c[i][j][1][3],\n                frag_s[j / 2][2 * (j % 2) + 1],\n                frag_bias[j / 2][2 * (j % 2) + 1]);\n          }\n        }\n        c_sh_wr += 16 * (4 * c_sh_stride);\n      }\n    }\n    __syncthreads();\n\n#pragma unroll\n    for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {\n      int row = c_gl_wr / c_gl_stride;\n      if (row < block_num_valid_tokens) {\n        int64_t sorted_row = sh_block_sorted_ids[row];\n        int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride;\n        scalar_t2 topk_weight_score;\n        if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row];\n        if (use_atomic_add && slice_count > 1 || mul_topk_weights) {\n          scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[true_idx]);\n          scalar_t2* sh_red_half2 = reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);\n#pragma unroll\n          for (int a = 0; a < 4; a++) {\n            scalar_t2 res = sh_red_half2[a];\n            if (mul_topk_weights) {\n              res = __hmul2(res, topk_weight_score);\n            }\n\n            if (use_atomic_add && slice_count > 1) {\n              atomicAdd(&C_half2[a], res);\n            } else {\n              C_half2[a] = res;\n            };\n          }\n        } else {\n          C[true_idx] = sh_red[c_sh_rd];\n        }\n        c_gl_wr += c_gl_wr_delta;\n        c_sh_rd += c_sh_rd_delta;\n      }\n    }\n    __syncthreads();\n  };\n\n  // Start global fetch and register load pipelines.\n  auto start_pipes = [&]() {\n\n#pragma unroll\n    for (int i = 0; i < stages - 1; i++) {\n      if (has_act_order && i == 0) {\n        int last_g_idx = slice_k_start + stages * tb_k * 2;\n        if (last_g_idx >= prob_k) {\n          last_g_idx = prob_k - 1;\n        }\n        fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);\n      }\n\n     
 if constexpr (has_zp && !is_zp_float && group_blocks == -1) {\n        if (i == 0) {\n          fetch_col_zp_to_shared();\n          if constexpr (!dequant_skip_flop) {\n            fetch_col_scale_to_shared();\n          }\n        }\n      }\n      fetch_to_shared(i, i, i < slice_iters, i);\n    }\n\n    zero_accums();\n    wait_for_stage();\n    init_same_group(0);\n    fetch_to_registers(0, 0);\n    fetch_scales_to_registers(0, 0);\n    fetch_zp_to_registers(0, 0);\n    a_gl_rd_col += a_gl_rd_delta_o * (stages - 1);\n    if constexpr (has_act_order) {\n      slice_k_start_shared_fetch += tb_k * (stages - 1);\n    }\n  };\n  if (slice_iters) {\n    start_pipes();\n  }\n\n  // Main loop.\n  while (slice_iters) {\n    // We unroll over both the global fetch and the register load pipeline to\n    // ensure all shared memory accesses are static. Note that both pipelines\n    // have even length meaning that the next iteration will always start at\n    // index 0.\n\n    for (int stage_group_id = 0; stage_group_id < max_num_stage_groups; stage_group_id++) {\n#pragma unroll\n      for (int pipe = 0; pipe < stages;) {\n#pragma unroll\n        for (int k = 0; k < b_sh_wr_iters; k++) {\n          int idx = (pipe >= stages && stage_group_id == max_num_stage_groups - 1) ? (pipe - stages)\n                                                                                   : (pipe + stage_group_id * stages);\n          fetch_to_registers(k + 1, pipe % stages, idx);\n          fetch_scales_to_registers(k + 1, pipe);\n          fetch_zp_to_registers(k + 1, pipe);\n          if (k == b_sh_wr_iters - 2) {\n            int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1)\n                          ? 
(pipe - 1)\n                          : (pipe + (stage_group_id + 1) * stages - 1);\n            fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages, idx);\n            pipe++;\n            wait_for_stage();\n            init_same_group(pipe % stages);\n          }\n          matmul(k);\n        }\n        slice_iters--;\n        if (slice_iters == 0) {\n          break;\n        }\n      }\n\n      a_gl_rd_col += a_gl_rd_delta_o * stages;\n\n      if constexpr (has_act_order) {\n        slice_k_start += tb_k * stages;\n\n        if (slice_k_start < prob_k) {\n          slice_k_start_shared_fetch += tb_k * stages;\n          int first_group_id = g_idx[slice_k_start];\n          int last_g_idx = slice_k_start + stages * tb_k * 2;\n          if (last_g_idx >= prob_k) {\n            last_g_idx = prob_k - 1;\n          }\n          int last_group_id = g_idx[last_g_idx];\n          if (last_group_id >= sh_first_group_id + sh_num_groups) {\n            fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);\n            __syncthreads();\n          }\n        }\n      }\n      if (slice_iters == 0) {\n        break;\n      }\n    }\n\n    // Process results and, if necessary, proceed to the next column slice.\n    // While this pattern may not be the most readable, other ways of writing\n    // the loop resulted in noticeably worse performance after compilation.\n    if (slice_iters == 0) {\n      cp_async_wait<0>();\n      bool last = slice_idx == slice_count - 1;\n      // For per-column scales, we only fetch them here in the final step before\n      // write-out\n      if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) {\n        if (w_type.size_bits() == 8 || (last || use_atomic_add)) {\n          if (s_sh_wr_pred) {\n            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n          }\n          cp_async_fence();\n        }\n      }\n\n      thread_block_reduce();\n\n      if 
(has_bias && last) {\n        __syncthreads();\n        cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd], threadIdx.x < 16 * thread_n_blocks / 8);\n        cp_async_fence();\n      }\n\n      if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) {\n        if (w_type.size_bits() == 8 || (last || use_atomic_add)) {\n          cp_async_wait<0>();\n          __syncthreads();\n          if (threadIdx.x / 32 < thread_n_blocks / 4) {\n            reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n            reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n            if constexpr (m_block_size_8) {\n              int idx = (threadIdx.x / 4) % 2;\n              scalar_t2* frag_s_half2 = reinterpret_cast<scalar_t2*>(frag_s);\n#pragma unroll\n              for (int i = 0; i < 8; i++) {\n                frag_s_half2[i] = Dtype::num2num2(reinterpret_cast<scalar_t*>(&frag_s_half2[i])[idx]);\n              }\n            }\n          }\n        }\n      }\n\n      // For 8-bit channelwise, we apply the scale before the global reduction\n      // that converts the fp32 results to fp16 (so that we avoid possible\n      // overflow in fp16)\n      if constexpr (\n          !has_act_order && group_blocks == -1 && w_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) {\n        if (threadIdx.x / 32 < thread_n_blocks / 4) {\n#pragma unroll\n          for (int i = 0; i < thread_m_blocks; i++) {\n#pragma unroll\n            for (int j = 0; j < 4; j++) {\n              scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]);\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 
1 : 0)]);\n\n              if constexpr (!m_block_size_8) {\n                scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]);\n                scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]);\n              }\n            }\n          }\n        }\n      }\n\n      if (slice_count > 1 && !use_atomic_add) {\n        // only globally reduce if there is more than one block in a slice\n        barrier_acquire(&locks[locks_off], slice_idx);\n        if (use_fp32_reduce) {\n          global_reduce_fp32(slice_idx == 0, last);\n        } else {\n          global_reduce_fp16(slice_idx == 0, last);\n        }\n        barrier_release(&locks[locks_off], last);\n      }\n\n      if (has_bias && last) {\n        cp_async_wait<0>();\n        __syncthreads();\n        reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];\n        reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];\n        __syncthreads();\n      }\n\n      if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]);\n      if (last || use_atomic_add)\n        // only the last block in a slice actually writes the result\n        write_result(last);\n      int old_slice_row = slice_row;\n      slice_row = 0;\n      slice_col_par++;\n      slice_col++;\n      is_first_matmul_in_slice = true;\n      init_slice();\n\n      // Should we load the A matrix in the next slice?\n      // `slice_col == 0`: when moving to a new MoE block\n      // `old_slice_row > 0`:\n      //    when the last slice does not start from k_index == 0\n      //    (only happens when it is the first slice of a threadblock)\n      // `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`:\n      //    when the required shared memory size is larger than\n      //    the remaining shared memory\n      if (slice_col == 0 || old_slice_row || prob_k > thread_k_blocks * 16 * stages * 
max_num_stage_groups) {\n        should_load_a = true;\n      } else {\n        should_load_a = false;\n      }\n\n      if (slice_iters) {\n        a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o);\n#pragma unroll\n        for (int i = 0; i < b_sh_wr_iters; i++)\n          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;\n        if (slice_col == 0) {\n#pragma unroll\n          for (int i = 0; i < b_sh_wr_iters; i++)\n            B_ptr[i] -= b_gl_stride;\n        }\n\n        bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;\n        // Update slice k/n for scales loading\n        if constexpr (has_act_order) {\n          slice_k_start = tb_k * slice_row;\n          slice_k_finish = slice_k_start + tb_k * slice_iters;\n          slice_k_start_shared_fetch = slice_k_start;\n          slice_n_offset = act_s_col_tb_stride * slice_col;\n        } else {\n          s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n          zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;\n        }\n        start_pipes();\n      }\n    }\n  }\n}\n\n}  // namespace device::marlin_moe\n\n#endif\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh",
    "content": "/*\n * Modified by Neural Magic\n * Copyright (C) Marlin.2024 Elias Frantar\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/*\n * Adapted from https://github.com/IST-DASLab/marlin\n */\n\n#pragma once\n\n#include <sgl_kernel/tensor.h>\n\n#include <sgl_kernel/scalar_type.hpp>\n\n#include \"kernel.h\"\n#include \"marlin_template.h\"\n\nnamespace device::marlin_moe {\n\n__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};\n\nusing MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n\ntemplate <int moe_block_size>\n__global__ void permute_cols_kernel(\n    int4 const* __restrict__ a_int4_ptr,\n    int const* __restrict__ perm_int_ptr,\n    int4* __restrict__ out_int4_ptr,\n    const int32_t* __restrict__ sorted_token_ids_ptr,\n    const int32_t* __restrict__ expert_ids_ptr,\n    const int32_t* __restrict__ num_tokens_past_padded_ptr,\n    int size_m,\n    int size_k,\n    int top_k) {};\n\n#else\n\n// For a given \"a\" of size [M,K] performs a permutation of the K columns based\n// on the given \"perm\" indices.\ntemplate <int moe_block_size>\n__global__ void permute_cols_kernel(\n    int4 const* __restrict__ a_int4_ptr,\n    int const* __restrict__ perm_int_ptr,\n    int4* __restrict__ out_int4_ptr,\n    const int32_t* __restrict__ sorted_token_ids_ptr,\n    const int32_t* __restrict__ expert_ids_ptr,\n    const int32_t* __restrict__ num_tokens_past_padded_ptr,\n    int 
size_m,\n    int size_k,\n    int top_k) {\n  int num_tokens_past_padded = num_tokens_past_padded_ptr[0];\n  int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);\n  int32_t block_sorted_ids[moe_block_size];\n  int block_num_valid_tokens = 0;\n  int64_t old_expert_id = 0;\n  int64_t expert_id = 0;\n  int row_stride = size_k * sizeof(half) / 16;\n\n  auto read_moe_block_data = [&](int block_id) {\n    block_num_valid_tokens = moe_block_size;\n    int4* tmp_block_sorted_ids = reinterpret_cast<int4*>(block_sorted_ids);\n    for (int i = 0; i < moe_block_size / 4; i++) {\n      tmp_block_sorted_ids[i] = ((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];\n    }\n    for (int i = 0; i < moe_block_size; i++) {\n      if (block_sorted_ids[i] >= size_m * top_k) {\n        block_num_valid_tokens = i;\n        break;\n      };\n    }\n  };\n\n  auto permute_row = [&](int row) {\n    int iters = size_k / default_threads;\n    int rest = size_k % default_threads;\n\n    int in_offset = (row / top_k) * row_stride;\n    int out_offset = row * row_stride;\n\n    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + in_offset);\n    half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);\n\n    int base_k = 0;\n\n    for (int i = 0; i < iters; i++) {\n      auto cur_k = base_k + threadIdx.x;\n      int src_pos = perm_int_ptr[cur_k];\n\n      out_half[cur_k] = a_row_half[src_pos];\n\n      base_k += default_threads;\n    }\n\n    if (rest) {\n      if (threadIdx.x < rest) {\n        auto cur_k = base_k + threadIdx.x;\n        int src_pos = perm_int_ptr[cur_k];\n\n        out_half[cur_k] = a_row_half[src_pos];\n      }\n    }\n  };\n\n  for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {\n    old_expert_id = expert_id;\n    int tmp_expert_id = expert_ids_ptr[index];\n    if (tmp_expert_id == -1) continue;\n    expert_id = tmp_expert_id;\n    perm_int_ptr += (expert_id - old_expert_id) * size_k;\n    
read_moe_block_data(index);\n\n    for (int i = 0; i < block_num_valid_tokens; i++)\n      permute_row(block_sorted_ids[i]);\n  }\n}\n\ntypedef struct {\n  int thread_k;\n  int thread_n;\n  int num_threads;\n} thread_config_t;\n\nthread_config_t small_batch_thread_configs[] = {\n    // Ordered by priority\n\n    // thread_k, thread_n, num_threads\n    {128, 128, 256},\n    {64, 128, 128}};\n\nthread_config_t large_batch_thread_configs[] = {\n    // Ordered by priority\n\n    // thread_k, thread_n, num_threads\n    {64, 256, 256},\n    {64, 128, 128}};\n\ntypedef struct {\n  int blocks_per_sm;\n  thread_config_t tb_cfg;\n} exec_config_t;\n\nint get_scales_cache_size(\n    thread_config_t const& th_config,\n    int prob_m,\n    int prob_n,\n    int prob_k,\n    int num_bits,\n    int group_size,\n    bool has_act_order,\n    bool is_k_full) {\n  bool cache_scales_chunk = has_act_order && !is_k_full;\n\n  int tb_n = th_config.thread_n;\n  int tb_k = th_config.thread_k;\n\n  // Get max scale groups per thread-block\n  int tb_groups;\n  if (group_size == -1) {\n    tb_groups = 1;\n  } else if (group_size == 0) {\n    tb_groups = div_ceil(tb_k, 32);  // Worst case is 32 group size\n  } else {\n    tb_groups = div_ceil(tb_k, group_size);\n  }\n\n  if (cache_scales_chunk) {\n    int load_groups = tb_groups * pipe_stages * 2;  // Chunk size is 2x pipeline over dim K\n    load_groups = max(load_groups, 32);             // We load at least 32 scale groups\n    return load_groups * tb_n * 2;\n  } else {\n    int tb_scales = tb_groups * tb_n * 2;\n\n    return tb_scales * pipe_stages;\n  }\n}\n\nint get_kernel_cache_size(\n    thread_config_t const& th_config,\n    bool m_block_size_8,\n    int thread_m_blocks,\n    int prob_m,\n    int prob_n,\n    int prob_k,\n    int num_bits,\n    int group_size,\n    bool has_act_order,\n    bool is_k_full,\n    int has_zp,\n    int is_zp_float) {\n  int pack_factor = 32 / num_bits;\n\n  // Get B size\n  int tb_k = th_config.thread_k;\n  
int tb_n = th_config.thread_n;\n  int tb_m = thread_m_blocks * 16;\n\n  // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights\n  // each of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)\n  int sh_block_meta_size = tb_m * 4;\n  int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;\n  int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;\n  int sh_red_size = tb_m * (tb_n + 8) * 2;\n  int sh_bias_size = tb_n * 2;\n  int tmp_size = (sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;\n  tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);\n\n  int sh_s_size =\n      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full);\n  int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;\n  int sh_zp_size = 0;\n  if (has_zp) {\n    if (is_zp_float)\n      sh_zp_size = sh_s_size;\n    else if (num_bits == 4)\n      sh_zp_size = sh_s_size / 4;\n    else if (num_bits == 8)\n      sh_zp_size = sh_s_size / 2;\n  }\n\n  int total_size = tmp_size + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size + sh_block_meta_size;\n\n  return total_size;\n}\n\nbool is_valid_config(\n    thread_config_t const& th_config,\n    bool m_block_size_8,\n    int thread_m_blocks,\n    int prob_m,\n    int prob_n,\n    int prob_k,\n    int num_bits,\n    int group_size,\n    bool has_act_order,\n    bool is_k_full,\n    int has_zp,\n    int is_zp_float,\n    int max_shared_mem) {\n  // Sanity\n  if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) {\n    return false;\n  }\n\n  // Verify K/N are divisible by thread K/N\n  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {\n    return false;\n  }\n\n  // Verify min for thread K/N\n  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {\n    return false;\n  }\n\n  // num_threads must be at least 128 (= 4 warps)\n  if (th_config.num_threads < 128) 
{\n    return false;\n  }\n\n  // Check that pipeline fits into cache\n  int cache_size = get_kernel_cache_size(\n      th_config,\n      m_block_size_8,\n      thread_m_blocks,\n      prob_m,\n      prob_n,\n      prob_k,\n      num_bits,\n      group_size,\n      has_act_order,\n      is_k_full,\n      has_zp,\n      is_zp_float);\n  return cache_size + 512 <= max_shared_mem;\n}\n\n#define _GET_IF(                                                                                                       \\\n    W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \\\n  else if (                                                                                                            \\\n      q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS &&                  \\\n      thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS &&        \\\n      num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) {                                                      \\\n    constexpr auto S_TYPE = W_TYPE == host::kFE2M1f                                                                    \\\n                                ? (GROUP_BLOCKS == 1 ? host::kFE4M3fn : host::kFE8M0fnu)                               \\\n                                : (std::is_same<scalar_t, half>::value ? 
host::kFloat16 : host::kBFloat16);            \\\n    kernel = Marlin<                                                                                                   \\\n        scalar_t,                                                                                                      \\\n        W_TYPE.id(),                                                                                                   \\\n        S_TYPE.id(),                                                                                                   \\\n        NUM_THREADS,                                                                                                   \\\n        THREAD_M_BLOCKS,                                                                                               \\\n        THREAD_N_BLOCKS,                                                                                               \\\n        THREAD_K_BLOCKS,                                                                                               \\\n        M_BLOCK_SIZE_8,                                                                                                \\\n        pipe_stages,                                                                                                   \\\n        GROUP_BLOCKS,                                                                                                  \\\n        IS_ZP_FLOAT>;                                                                                                  \\\n  }\n\n// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)\n//         this is the most common cases\n// BIGGROUP: cases for big group size (group_blocks in [-1, 8])\n// FZP: cases for float-zero-point (is_zp_float = true)\n// ACT: cases for act order case (group_blocks == 0)\n// NVFP4: cases for nvfp4(e2m1) (group_blocks == 1)\n// MXFP4: cases for mxfp4(e2m1) (group_blocks == 2)\n#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, 
NUM_THREADS)       \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false)   \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false)   \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false)   \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)\n\n#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)     \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)  \\\n                                                                        \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)  \\\n                                                                        \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)\n\n#define COMMON_GET_IF(W_TYPE)            \\\n  COMMON_GET_IF_M1(W_TYPE, 8, 8, 256)    \\\n  COMMON_GET_IF_M1(W_TYPE, 8, 4, 128)    \\\n  COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \\\n  COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)\n\n#define 
BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)     \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false)   \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)\n\n#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)   \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)  \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)\n\n#define BIGGROUP_GET_IF(W_TYPE)            \\\n  BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256)    \\\n  BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128)    \\\n  BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \\\n  BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)\n\n#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)      \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)\n\n#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)     \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)\n\n#define NVFP4_GET_IF(W_TYPE)            \\\n  NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256)    \\\n  NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128)    \\\n  NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \\\n  NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128)\n\n#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)      \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, 
NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)\n\n#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)     \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)\n\n#define MXFP4_GET_IF(W_TYPE)            \\\n  MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256)    \\\n  MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128)    \\\n  MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \\\n  MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128)\n\n// We currently have 4-bit models only with group_blocks == 4\n#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)       \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)\n\n#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)      \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)\n\n#define FZP_GET_IF(W_TYPE)            \\\n  FZP_GET_IF_M1(W_TYPE, 8, 8, 256)    \\\n  FZP_GET_IF_M1(W_TYPE, 8, 4, 128)    \\\n  FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \\\n  FZP_GET_IF_M234(W_TYPE, 8, 4, 128)\n\n// Act-order cases (group_blocks == 0)\n#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)        \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)\n\n#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)       \\\n  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \\\n  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)\n\n#define 
ACT_GET_IF(W_TYPE)            \\\n  ACT_GET_IF_M1(W_TYPE, 8, 8, 256)    \\\n  ACT_GET_IF_M1(W_TYPE, 8, 4, 128)    \\\n  ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \\\n  ACT_GET_IF_M234(W_TYPE, 8, 4, 128)\n\ntemplate <typename scalar_t>\nMarlinFuncPtr get_marlin_kernel(\n    const host::ScalarType q_type,\n    int thread_m_blocks,\n    int thread_n_blocks,\n    int thread_k_blocks,\n    bool m_block_size_8,\n    bool has_act_order,\n    bool has_zp,\n    int group_blocks,\n    int num_threads,\n    bool is_zp_float) {\n  int num_bits = q_type.size_bits();\n  auto kernel = MarlinDefault;\n  if (false) {\n  }\n\n  COMMON_GET_IF(host::kU4)\n  COMMON_GET_IF(host::kU4B8)\n  COMMON_GET_IF(host::kU8B128)\n\n  NVFP4_GET_IF(host::kFE2M1f)\n\n  BIGGROUP_GET_IF(host::kFE4M3fn)\n\n  ACT_GET_IF(host::kU4B8)\n  ACT_GET_IF(host::kU8B128)\n  if (std::is_same<scalar_t, nv_bfloat16>::value) {\n    if (false) {\n    }\n    MXFP4_GET_IF(host::kFE2M1f)\n  }\n\n  return kernel;\n}\n\ntemplate <typename scalar_t>\nexec_config_t determine_exec_config(\n    const host::ScalarType& q_type,\n    int prob_m,\n    int prob_n,\n    int prob_k,\n    int thread_m_blocks,\n    bool m_block_size_8,\n    int num_bits,\n    int group_size,\n    bool has_act_order,\n    bool is_k_full,\n    bool has_zp,\n    bool is_zp_float,\n    int max_shared_mem) {\n  exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};\n  thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs;\n  int thread_configs_size = thread_m_blocks > 1 ? 
sizeof(large_batch_thread_configs) / sizeof(thread_config_t)\n                                                : sizeof(small_batch_thread_configs) / sizeof(thread_config_t);\n\n  int count = 0;\n  constexpr int device_max_reg_size = 255 * 1024;\n  for (int i = 0; i < thread_configs_size; i++) {\n    thread_config_t th_config = thread_configs[i];\n\n    if (!is_valid_config(\n            th_config,\n            m_block_size_8,\n            thread_m_blocks,\n            prob_m,\n            prob_n,\n            prob_k,\n            num_bits,\n            group_size,\n            has_act_order,\n            is_k_full,\n            has_zp,\n            is_zp_float,\n            max_shared_mem)) {\n      continue;\n    }\n\n    int cache_size = get_kernel_cache_size(\n        th_config,\n        m_block_size_8,\n        thread_m_blocks,\n        prob_m,\n        prob_n,\n        prob_k,\n        num_bits,\n        group_size,\n        has_act_order,\n        is_k_full,\n        has_zp,\n        is_zp_float);\n\n    int group_blocks = 0;\n    if (!has_act_order) {\n      group_blocks = group_size == -1 ? 
-1 : (group_size / 16);\n    }\n\n    auto kernel = get_marlin_kernel<scalar_t>(\n        q_type,\n        thread_m_blocks,\n        th_config.thread_n / 16,\n        th_config.thread_k / 16,\n        m_block_size_8,\n        has_act_order,\n        has_zp,\n        group_blocks,\n        th_config.num_threads,\n        is_zp_float);\n\n    if (kernel == MarlinDefault) continue;\n\n    if (thread_m_blocks > 1) {\n      exec_cfg = {1, th_config};\n      break;\n    } else {\n      cudaFuncAttributes attr;\n      cudaFuncGetAttributes(&attr, kernel);\n      int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;\n      int allow_count = min(device_max_reg_size / reg_size, max_shared_mem / (cache_size + 1024));\n      allow_count = max(min(allow_count, 4), 1);\n      if (allow_count > count) {\n        count = allow_count;\n        exec_cfg = {count, th_config};\n      };\n    }\n  }\n\n  return exec_cfg;\n}\n\ntemplate <typename scalar_t>\nvoid marlin_mm(\n    const void* A,\n    const void* B,\n    void* C,\n    void* C_tmp,\n    void* b_bias,\n    void* s,\n    void* s2,\n    void* zp,\n    void* g_idx,\n    void* perm,\n    void* a_tmp,\n    void* sorted_token_ids,\n    void* expert_ids,\n    void* num_tokens_past_padded,\n    void* topk_weights,\n    int moe_block_size,\n    int top_k,\n    bool mul_topk_weights,\n    bool is_ep,\n    int prob_m,\n    int prob_n,\n    int prob_k,\n    void* workspace,\n    host::ScalarType const& q_type,\n    bool has_bias,\n    bool has_act_order,\n    bool is_k_full,\n    bool has_zp,\n    int num_groups,\n    int group_size,\n    int dev,\n    cudaStream_t stream,\n    int thread_k,\n    int thread_n,\n    int sms,\n    bool use_atomic_add,\n    bool use_fp32_reduce,\n    bool is_zp_float) {\n  int thread_m_blocks = div_ceil(moe_block_size, 16);\n  bool m_block_size_8 = moe_block_size == 8;\n\n  if (has_zp) {\n    host::RuntimeCheck(\n        q_type == host::kU4 || q_type == host::kU8, \"q_type must be u4 or u8 when 
has_zp = True. Got = \", q_type.str());\n  } else {\n    host::RuntimeCheck(\n        q_type == host::kU4B8 || q_type == host::kU8B128 || q_type == host::kFE4M3fn || q_type == host::kFE2M1f,\n        \"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when \"\n        \"has_zp = False. Got = \",\n        q_type.str());\n  }\n\n  host::RuntimeCheck(\n      prob_m > 0 && prob_n > 0 && prob_k > 0, \"Invalid MNK = [\", prob_m, \", \", prob_n, \", \", prob_k, \"]\");\n\n  int group_blocks = 0;\n  if (has_act_order) {\n    if (is_k_full) {\n      host::RuntimeCheck(group_size != -1);\n      group_blocks = group_size / 16;\n      host::RuntimeCheck(\n          prob_k % group_blocks == 0, \"prob_k = \", prob_k, \" is not divisible by group_blocks = \", group_blocks);\n    } else {\n      host::RuntimeCheck(group_size == 0);\n      group_blocks = 0;\n    }\n  } else {\n    if (group_size == -1) {\n      group_blocks = -1;\n    } else {\n      group_blocks = group_size / 16;\n      host::RuntimeCheck(\n          prob_k % group_blocks == 0, \"prob_k = \", prob_k, \" is not divisible by group_blocks = \", group_blocks);\n    }\n  }\n\n  int num_bits = q_type.size_bits();\n  const int4* A_ptr = (const int4*)A;\n  const int4* B_ptr = (const int4*)B;\n  int4* C_ptr = (int4*)C;\n  int4* C_tmp_ptr = (int4*)C_tmp;\n  const int4* bias_ptr = (const int4*)b_bias;\n  const int4* s_ptr = (const int4*)s;\n  const uint16_t* s2_ptr = (const uint16_t*)s2;\n  const int4* zp_ptr = (const int4*)zp;\n  const int* g_idx_ptr = (const int*)g_idx;\n  const int* perm_ptr = (const int*)perm;\n  int4* a_tmp_ptr = (int4*)a_tmp;\n  const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids;\n  const int32_t* expert_ids_ptr = (const int32_t*)expert_ids;\n  const int32_t* num_tokens_past_padded_ptr = (const int32_t*)num_tokens_past_padded;\n  const float* topk_weights_ptr = (const float*)topk_weights;\n  int* locks = (int*)workspace;\n\n  if (has_act_order) {\n    // Permute A 
columns\n    auto perm_kernel = permute_cols_kernel<8>;\n    if (moe_block_size == 8) {\n    } else if (moe_block_size == 16)\n      perm_kernel = permute_cols_kernel<16>;\n    else if (moe_block_size == 32)\n      perm_kernel = permute_cols_kernel<32>;\n    else if (moe_block_size == 48)\n      perm_kernel = permute_cols_kernel<48>;\n    else if (moe_block_size == 64)\n      perm_kernel = permute_cols_kernel<64>;\n    else\n      host::Panic(\"unsupported moe_block_size \", moe_block_size);\n\n    // clang-format off\n    perm_kernel<<<sms, default_threads, 0, stream>>>(\n        A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr,\n        num_tokens_past_padded_ptr, prob_m, prob_k, top_k);\n    // clang-format on\n    A_ptr = a_tmp_ptr;\n    prob_m = prob_m * top_k;\n    top_k = 1;\n\n    // If we have a full K, then we can run the non-act-order version of Marlin\n    // (since the weight rows are reordered by increasing group ids, and by\n    // having a full K, we have full original groups)\n    if (is_k_full) has_act_order = false;\n  }\n\n  int max_shared_mem = 0;\n  host::RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev));\n  host::RuntimeCheck(max_shared_mem > 0);\n\n  // Set thread config\n  exec_config_t exec_cfg;\n  thread_config_t thread_tfg;\n  if (thread_k != -1 && thread_n != -1) {\n    thread_tfg = thread_config_t{thread_k, thread_n, default_threads};\n    exec_cfg = exec_config_t{1, thread_tfg};\n    host::RuntimeCheck(prob_n % thread_n == 0, \"prob_n = \", prob_n, \" is not divisible by thread_n = \", thread_n);\n    host::RuntimeCheck(prob_k % thread_k == 0, \"prob_k = \", prob_k, \" is not divisible by thread_k = \", thread_k);\n  } else {\n    // Auto config\n    exec_cfg = determine_exec_config<scalar_t>(\n        q_type,\n        prob_m,\n        prob_n,\n        prob_k,\n        thread_m_blocks,\n        m_block_size_8,\n        num_bits,\n        group_size,\n        
has_act_order,\n        is_k_full,\n        has_zp,\n        is_zp_float,\n        max_shared_mem);\n    thread_tfg = exec_cfg.tb_cfg;\n  }\n\n  int num_threads = thread_tfg.num_threads;\n  thread_k = thread_tfg.thread_k;\n  thread_n = thread_tfg.thread_n;\n  int blocks = sms * exec_cfg.blocks_per_sm;\n  if (exec_cfg.blocks_per_sm > 1) max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024;\n\n  int thread_k_blocks = thread_k / 16;\n  int thread_n_blocks = thread_n / 16;\n\n  host::RuntimeCheck(\n      is_valid_config(\n          thread_tfg,\n          m_block_size_8,\n          thread_m_blocks,\n          prob_m,\n          prob_n,\n          prob_k,\n          num_bits,\n          group_size,\n          has_act_order,\n          is_k_full,\n          has_zp,\n          is_zp_float,\n          max_shared_mem),\n      \"Invalid thread config: thread_m_blocks = \",\n      thread_m_blocks,\n      \", thread_k = \",\n      thread_tfg.thread_k,\n      \", thread_n = \",\n      thread_tfg.thread_n,\n      \", num_threads = \",\n      thread_tfg.num_threads,\n      \" for MKN = [\",\n      prob_m,\n      \", \",\n      prob_k,\n      \", \",\n      prob_n,\n      \"] and num_bits = \",\n      num_bits,\n      \", group_size = \",\n      group_size,\n      \", has_act_order = \",\n      has_act_order,\n      \", is_k_full = \",\n      is_k_full,\n      \", has_zp = \",\n      has_zp,\n      \", is_zp_float = \",\n      is_zp_float,\n      \", max_shared_mem = \",\n      max_shared_mem);\n\n  auto kernel = get_marlin_kernel<scalar_t>(\n      q_type,\n      thread_m_blocks,\n      thread_n_blocks,\n      thread_k_blocks,\n      m_block_size_8,\n      has_act_order,\n      has_zp,\n      group_blocks,\n      num_threads,\n      is_zp_float);\n\n  if (kernel == MarlinDefault) {\n    host::Panic(\n        \"Unsupported shapes: MNK = [\",\n        prob_m,\n        \", \",\n        prob_n,\n        \", \",\n        prob_k,\n        \"]\",\n        \", has_act_order = 
\",\n        has_act_order,\n        \", num_groups = \",\n        num_groups,\n        \", group_size = \",\n        group_size,\n        \", thread_m_blocks = \",\n        thread_m_blocks,\n        \", thread_n_blocks = \",\n        thread_n_blocks,\n        \", thread_k_blocks = \",\n        thread_k_blocks,\n        \", num_bits = \",\n        num_bits);\n  }\n\n  host::RuntimeDeviceCheck(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem));\n  // clang-format off\n  kernel<<<blocks, num_threads, max_shared_mem, stream>>>(\n      A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,\n      sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,\n      topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,\n      prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem);\n  // clang-format on\n}\n\n#endif\n\n}  // namespace device::marlin_moe\n\ntemplate <typename scalar_t>\nvoid moe_wna16_marlin_gemm(\n    tvm::ffi::TensorView a,\n    tvm::ffi::TensorView c,\n    tvm::ffi::TensorView b_q_weight,\n    tvm::ffi::TensorView b_bias,\n    tvm::ffi::TensorView b_scales,\n    tvm::ffi::TensorView global_scale,\n    tvm::ffi::TensorView b_zeros,\n    tvm::ffi::TensorView g_idx,\n    tvm::ffi::TensorView perm,\n    tvm::ffi::TensorView workspace,\n    tvm::ffi::TensorView sorted_token_ids,\n    tvm::ffi::TensorView expert_ids,\n    tvm::ffi::TensorView num_tokens_post_padded,\n    tvm::ffi::TensorView topk_weights,\n    tvm::ffi::TensorView a_tmp,\n    tvm::ffi::TensorView c_tmp,\n    int64_t moe_block_size,\n    int64_t top_k,\n    bool mul_topk_weights,\n    bool is_ep,\n    int64_t b_q_type_id,\n    int64_t size_m,\n    int64_t size_n,\n    int64_t size_k,\n    bool has_act_order,\n    bool has_bias,\n    bool is_k_full,\n    bool has_zp,\n    int64_t num_groups,\n    int64_t group_size,\n    bool use_atomic_add,\n    bool use_fp32_reduce,\n    bool 
is_zp_float) {\n  using namespace host;\n\n  ScalarType const b_q_type = ScalarType::from_id(b_q_type_id);\n  int pack_factor = 32 / b_q_type.size_bits();\n\n  if (moe_block_size != 8) {\n    RuntimeCheck(moe_block_size % 16 == 0, \"unsupported moe_block_size=\", moe_block_size);\n    RuntimeCheck(moe_block_size >= 16 && moe_block_size <= 64, \"unsupported moe_block_size=\", moe_block_size);\n  }\n\n  // Verify A\n  RuntimeCheck(a.size(0) == size_m, \"Shape mismatch: a.size(0) = \", a.size(0), \", size_m = \", size_m);\n  RuntimeCheck(a.size(1) == size_k, \"Shape mismatch: a.size(1) = \", a.size(1), \", size_k = \", size_k);\n\n  // Verify B\n  RuntimeCheck(\n      size_k % device::marlin::tile_size == 0,\n      \"size_k = \",\n      size_k,\n      \" is not divisible by tile_size = \",\n      device::marlin::tile_size);\n  RuntimeCheck(\n      (size_k / device::marlin::tile_size) == b_q_weight.size(1),\n      \"Shape mismatch: b_q_weight.size(1) = \",\n      b_q_weight.size(1),\n      \", size_k = \",\n      size_k,\n      \", tile_size = \",\n      device::marlin::tile_size);\n  RuntimeCheck(\n      b_q_weight.size(2) % device::marlin::tile_size == 0,\n      \"b_q_weight.size(2) = \",\n      b_q_weight.size(2),\n      \" is not divisible by tile_size = \",\n      device::marlin::tile_size);\n  int64_t actual_size_n = (b_q_weight.size(2) / device::marlin::tile_size) * pack_factor;\n  RuntimeCheck(size_n == actual_size_n, \"size_n = \", size_n, \", actual_size_n = \", actual_size_n);\n\n  // Verify device and strides\n  auto device = SymbolicDevice{};\n  device.set_options<kDLCUDA>();\n  TensorMatcher({-1, -1}).with_dtype<scalar_t>().with_device(device).verify(a);\n\n  device.verify(b_q_weight.device());\n  RuntimeCheck(b_q_weight.is_contiguous(), \"b_q_weight is not contiguous\");\n\n  device.verify(b_scales.device());\n  RuntimeCheck(b_scales.is_contiguous(), \"b_scales is not contiguous\");\n\n  // thread_k, thread_n, sms\n  int thread_k = -1;\n  int thread_n = 
-1;\n  int sms = -1;\n  DLDevice dl_device = device.unwrap();\n  int dev = dl_device.device_id;\n  cudaStream_t stream = LaunchKernel::resolve_device(dl_device);\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev));\n\n  // Verify c (allocation done in Python)\n  device.verify(c.device());\n  RuntimeCheck(c.is_contiguous(), \"c is not contiguous\");\n  RuntimeCheck(\n      c.size(0) == size_m * top_k, \"Shape mismatch: c.size(0) = \", c.size(0), \", size_m * topk = \", size_m * top_k);\n  RuntimeCheck(c.size(1) == size_n, \"Shape mismatch: c.size(1) = \", c.size(1), \", size_n = \", size_n);\n\n  // Alloc c_tmp: SKIP, done in Python\n\n  // Detect groupsize: b_scales rank and dims\n  RuntimeCheck(b_scales.dim() == 3, \"b_scales rank = \", b_scales.dim(), \" is not 3\");\n  RuntimeCheck(b_scales.size(2) == size_n, \"b_scales dim 2 = \", b_scales.size(2), \" is not size_n = \", size_n);\n  RuntimeCheck(\n      b_scales.size(1) == num_groups, \"b_scales dim 1 = \", b_scales.size(1), \" is not num_groups = \", num_groups);\n\n  // Validate g_idx, perm (Optional unwrap done in Python; empty tensors when absent)\n  if (g_idx.size(g_idx.dim() - 1) > 0 && perm.size(perm.dim() - 1) > 0) {\n    device.verify(g_idx.device());\n    RuntimeCheck(g_idx.is_contiguous(), \"g_idx is not contiguous\");\n    device.verify(perm.device());\n    RuntimeCheck(perm.is_contiguous(), \"perm is not contiguous\");\n\n    int64_t g_idx_last = g_idx.size(g_idx.dim() - 1);\n    int64_t perm_last = perm.size(perm.dim() - 1);\n    RuntimeCheck(\n        (g_idx_last == 0 && perm_last == 0) || (g_idx_last == size_k && perm_last == size_k),\n        \"Unexpected g_idx.size(-1) = \",\n        g_idx_last,\n        \" and perm.size(-1) = \",\n        perm_last,\n        \", where size_k = \",\n        size_k);\n  }\n  // has_act_order derivation: SKIP (passed as param)\n\n  // Verify group_size consistency\n  if (has_act_order) {\n    // SKIP: a_tmp allocation done in 
Python\n    if (is_k_full) {\n      RuntimeCheck(num_groups > 1, \"For act_order, num_groups must be > 1\");\n      RuntimeCheck(size_k % num_groups == 0, \"size_k = \", size_k, \", is not divisible by num_groups = \", num_groups);\n    }\n  } else {\n    if (num_groups > 1) {\n      RuntimeCheck(\n          size_k % num_groups == 0, \"size_k = \", size_k, \", is not divisible by b_scales.size(1) = \", num_groups);\n    }\n  }\n\n  // Verify global_scale (Optional unwrap done in Python)\n  int64_t global_scale_size = global_scale.size(0);\n  if (global_scale_size > 0) {\n    RuntimeCheck(b_q_type == kFE2M1f && group_size == 16, \"global_scale can only be used for nvfp4 format.\");\n  } else {\n    RuntimeCheck(\n        !(b_q_type == kFE2M1f && group_size == 16), \"the global_scale parameter must be passed for nvfp4 format.\");\n  }\n\n  // Verify b_bias (Optional unwrap done in Python)\n  if (has_bias) {\n    device.verify(b_bias.device());\n    RuntimeCheck(b_bias.is_contiguous(), \"b_bias is not contiguous\");\n    RuntimeCheck(b_bias.size(1) == size_n, \"b_bias.size(1) != size_n\");\n    RuntimeCheck(b_bias.stride(1) == 1, \"b_bias.stride(1) != 1\");\n  }\n\n  // b_zeros Optional unwrap + has_zp derivation: SKIP (done in Python)\n\n  // Verify b_q_type vs has_zp\n  if (has_zp) {\n    device.verify(b_zeros.device());\n    RuntimeCheck(b_zeros.is_contiguous(), \"b_zeros is not contiguous\");\n    RuntimeCheck(\n        b_q_type == kU4 || b_q_type == kU8, \"b_q_type must be u4 or u8 when has_zp = True. Got = \", b_q_type.str());\n  } else {\n    RuntimeCheck(\n        b_q_type == kU4B8 || b_q_type == kU8B128 || b_q_type == kFE4M3fn || b_q_type == kFE2M1f,\n        \"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or \"\n        \"float4_e2m1f when \"\n        \"has_zp = False. 
Got = \",\n        b_q_type.str());\n  }\n\n  if (has_zp && is_zp_float) {\n    RuntimeCheck(\n        std::is_same<scalar_t, fp16_t>::value,\n        \"Computation type must be float16 (half) when using float zero \"\n        \"points.\");\n  }\n\n  // Verify b_zeros\n  if (has_zp) {\n    RuntimeCheck(b_zeros.dim() == 3, \"b_zeros rank = \", b_zeros.dim(), \" is not 3\");\n    if (is_zp_float) {\n      RuntimeCheck(b_zeros.size(2) == size_n, \"b_zeros dim 2 = \", b_zeros.size(2), \" is not size_n = \", size_n);\n      RuntimeCheck(\n          num_groups == b_zeros.size(1), \"b_zeros dim 1 = \", b_zeros.size(1), \" is not num_groups = \", num_groups);\n      RuntimeCheck(num_groups != -1, \"num_groups must be != -1\");\n    } else {\n      RuntimeCheck(\n          b_zeros.size(1) == num_groups, \"b_zeros dim 1 = \", b_zeros.size(1), \" is not num_groups = \", num_groups);\n      RuntimeCheck(\n          b_zeros.size(2) == size_n / pack_factor,\n          \"b_zeros dim 2 = \",\n          b_zeros.size(2),\n          \" is not size_n / pack_factor = \",\n          size_n / pack_factor);\n    }\n  }\n\n  // Verify workspace size\n  RuntimeCheck(\n      size_n % device::marlin::min_thread_n == 0,\n      \"size_n = \",\n      size_n,\n      \", is not divisible by min_thread_n = \",\n      device::marlin::min_thread_n);\n\n  int64_t max_n_tiles = size_n / device::marlin::min_thread_n;\n  int64_t min_workspace_size =\n      std::min(max_n_tiles * (sorted_token_ids.size(0) / moe_block_size), static_cast<int64_t>(sms) * 4);\n  RuntimeCheck(\n      workspace.size(0) >= min_workspace_size,\n      \"workspace.numel = \",\n      workspace.size(0),\n      \" is below min_workspace_size = \",\n      min_workspace_size);\n\n  // Early return for zero-size M (moved after all validation)\n  if (size_m == 0) return;\n\n  device::marlin_moe::marlin_mm<scalar_t>(\n      a.data_ptr(),\n      b_q_weight.data_ptr(),\n      c.data_ptr(),\n      c_tmp.data_ptr(),\n      b_bias.data_ptr(),\n 
     b_scales.data_ptr(),\n      global_scale.data_ptr(),\n      b_zeros.data_ptr(),\n      g_idx.data_ptr(),\n      perm.data_ptr(),\n      a_tmp.data_ptr(),\n      sorted_token_ids.data_ptr(),\n      expert_ids.data_ptr(),\n      num_tokens_post_padded.data_ptr(),\n      topk_weights.data_ptr(),\n      static_cast<int>(moe_block_size),\n      static_cast<int>(top_k),\n      mul_topk_weights,\n      is_ep,\n      static_cast<int>(size_m),\n      static_cast<int>(size_n),\n      static_cast<int>(size_k),\n      workspace.data_ptr(),\n      b_q_type,\n      has_bias,\n      has_act_order,\n      is_k_full,\n      has_zp,\n      static_cast<int>(num_groups),\n      static_cast<int>(group_size),\n      dev,\n      stream,\n      thread_k,\n      thread_n,\n      sms,\n      use_atomic_add,\n      use_fp32_reduce,\n      is_zp_float);\n}\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_expert_quant.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/runtime.cuh>\n#include <sgl_kernel/type.cuh>\n#include <sgl_kernel/utils.cuh>\n\n#include \"nvfp4_quant.cuh\"\n#include <cuda_runtime.h>\n#include <cuda_runtime_api.h>\n\nusing namespace host;\n\n// Quantizes the provided PackedVec into the uint32_t output\ntemplate <class Type, bool UE8M0_SF = false>\nSGL_DEVICE uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)\n  // Get absolute maximum values among the local 8 values.\n  auto localMax = __habs2(vec.elts[0]);\n\n// Local maximum value.\n#pragma unroll\n  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {\n    localMax = __hmax2(localMax, __habs2(vec.elts[i]));\n  }\n\n  // Get the absolute maximum among all 16 values (two threads).\n  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);\n  // Get the final absolute maximum values.\n  float vecMax = float(__hmax(localMax.x, localMax.y));\n\n  // Get the SF (max value of the vector / max value of e2m1).\n  // maximum value of e2m1 = 6.0.\n  // TODO: use half as compute data type.\n  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));\n  // 8 bits representation of the SF.\n  uint8_t fp8SFVal;\n  // Write the SF to global memory (STG.8).\n  if constexpr (UE8M0_SF) {\n    // Extract the 8 exponent bits from float32.\n    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.\n    uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;\n    fp8SFVal = tmp & 0xff;\n    // Convert back to fp32.\n    reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;\n  } else {\n    // Here SFValue is always positive, so E4M3 is the same as UE4M3.\n    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);\n    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;\n    // Convert back to fp32.\n    SFValue = float(tmp);\n  }\n  // Get the output 
scale.\n  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *\n  //                       reciprocal(SFScaleVal))\n  float outputScale =\n      SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;\n\n  if (SFout) {\n    // Write the SF to global memory (STG.8).\n    *SFout = fp8SFVal;\n  }\n\n  // Convert the input to float.\n  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];\n\n#pragma unroll\n  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {\n    fp2Vals[i] = device::cast<float2>(vec.elts[i]);\n    fp2Vals[i].x *= outputScale;\n    fp2Vals[i].y *= outputScale;\n  }\n\n  // Convert to e2m1 values.\n  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);\n\n  // Write the e2m1 values to global memory.\n  return e2m1Vec;\n#else\n  return 0;\n#endif\n}\n\nSGL_DEVICE float silu(const float& val) {\n  return val / (1.0f + __expf(-val));\n}\n\ntemplate <class Type>\nSGL_DEVICE void silu_and_mul(PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {\n  float2 x[CVT_FP4_ELTS_PER_THREAD / 2];\n  float2 y[CVT_FP4_ELTS_PER_THREAD / 2];\n\n#pragma unroll\n  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {\n    x[i] = device::cast<float2>(x_vec.elts[i]);\n    y[i] = device::cast<float2>(y_vec.elts[i]);\n    x[i].x = silu(x[i].x) * y[i].x;\n    x[i].y = silu(x[i].y) * y[i].y;\n    x_vec.elts[i] = device::cast<packed_t<Type>>(x[i]);\n  }\n}\n\n// Use UE4M3 by default.\ntemplate <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>\n__global__ void\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)\n__launch_bounds__(512, 4) cvt_fp16_to_fp4(\n#else\ncvt_fp16_to_fp4(\n#endif\n    int32_t numRows,\n    int32_t numCols,\n    Type const* in,\n    float const* SFScale,\n    uint32_t* out,\n    uint32_t* SFout,\n    uint32_t* input_offset_by_experts,\n    uint32_t* output_scale_offset_by_experts,\n    int32_t* mask,\n    int n_experts,\n    bool low_latency) {\n#if defined(__CUDA_ARCH__) && 
(__CUDA_ARCH__ >= 1000)\n  using PackedVec = PackedVec<Type>;\n  static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);\n  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, \"Vec size is not matched.\");\n\n  // Input tensor row/col loops.\n  int tid = blockIdx.x * blockDim.x + threadIdx.x;\n  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;\n  // TODO(kaixih@nvidia): For now, we assume mask is used together with\n  // silu_and_mul. Maybe we want a more general behavior of mask later. In the\n  // silu case, the input last dim doubles.\n  bool use_mask = mask != nullptr;\n  int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow;\n\n  // Grid-stride loop: each iteration processes one PackedVec (CVT_FP4_ELTS_PER_THREAD elements)\n  for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {\n    // Calculate which row and column this global thread should process\n    int rowIdx = globalIdx / colsPerRow;\n    int colIdx = globalIdx % colsPerRow;\n\n    // Find index within the experts using different strategies based on expert\n    // count\n    int rowIdx_in_expert = 0;\n    int expert_idx = 0;\n\n    if constexpr (SMALL_NUM_EXPERTS) {\n      for (int i = 0; i < n_experts; i++) {\n        uint32_t current_offset = __ldca(&input_offset_by_experts[i]);\n        uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]);\n        if (rowIdx >= current_offset && rowIdx < next_offset) {\n          rowIdx_in_expert = rowIdx - current_offset;\n          expert_idx = i;\n          break;\n        }\n      }\n    } else {\n      // Load input offsets into registers first, then do the computation.\n      // Local array size set to 17 because of register limit.\n      uint32_t local_offsets[17];\n      for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) {\n        *reinterpret_cast<int4*>(local_offsets) =\n            __ldca(reinterpret_cast<const 
int4*>(&input_offset_by_experts[chunk_start]));\n        *reinterpret_cast<int4*>(local_offsets + 4) =\n            __ldca(reinterpret_cast<const int4*>(&input_offset_by_experts[chunk_start + 4]));\n        *reinterpret_cast<int4*>(local_offsets + 8) =\n            __ldca(reinterpret_cast<const int4*>(&input_offset_by_experts[chunk_start + 8]));\n        *reinterpret_cast<int4*>(local_offsets + 12) =\n            __ldca(reinterpret_cast<const int4*>(&input_offset_by_experts[chunk_start + 12]));\n        local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]);\n\n// Check against the 16 loaded offsets\n#pragma unroll\n        for (int i = 0; i < 16; i++) {\n          if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {\n            rowIdx_in_expert = rowIdx - local_offsets[i];\n            expert_idx = chunk_start + i;\n            break;\n          }\n        }\n      }\n    }\n\n    // Early exit when using masks.\n    if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {\n      continue;\n    }\n\n    int64_t inOffset = rowIdx * actualColsPerRow + colIdx;\n    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];\n    if (use_mask) {\n      PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];\n      silu_and_mul(in_vec, in_vec_mul);\n    }\n\n    // Get the output tensor offset.\n    // Same as inOffset because 8 elements are packed into one uint32_t.\n    int64_t outOffset = rowIdx * colsPerRow + colIdx;\n    auto& out_pos = out[outOffset];\n\n    // Get the global scaling factor, which will be applied to the SF.\n    // Note SFScale is the same as next GEMM's alpha, which is\n    // (448.f / (Alpha_A / 6.f)).\n    float const SFScaleVal = SFScale == nullptr ? 
1.0f : SFScale[expert_idx];\n\n    int factor = CVT_FP4_SF_VEC_SIZE * 4;\n    // The actual output_scales dim is computed from the padded numCols.\n    int32_t numCols_padded = (numCols + factor - 1) / factor * factor;\n    int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;\n    uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;\n\n    auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(\n        rowIdx_in_expert, colIdx, numCols, SFout_in_expert);\n\n    out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);\n  }\n#endif\n}\n\n// Use UE4M3 by default.\ntemplate <class Type, bool UE8M0_SF = false>\n__global__ void\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)\n__launch_bounds__(512, 4) cvt_fp16_to_fp4_expert(\n#else\ncvt_fp16_to_fp4_expert(\n#endif\n    int32_t numRows,\n    int32_t numCols,\n    Type const* in,\n    float const* SFScale,\n    uint32_t* out,\n    uint32_t* SFout,\n    int32_t* mask,\n    bool use_silu_and_mul,\n    int n_experts) {\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)\n  using PackedVec = PackedVec<Type>;\n  static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);\n  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, \"Vec size is not matched.\");\n\n  // Input tensor row/col loops.\n  int tid = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride = (gridDim.x * blockDim.x) / n_experts;\n  int remainder = (gridDim.x * blockDim.x) % n_experts;\n  int expert_idx;\n  int tid_in_expert;\n  int actual_stride;\n  if (remainder > 0) {\n    int bound = remainder * (stride + 1);\n    if (tid < bound) {\n      expert_idx = tid / (stride + 1);\n      tid_in_expert = tid % (stride + 1);\n      actual_stride = stride + 1;\n    } else {\n      expert_idx = remainder + (tid - bound) / stride;\n      tid_in_expert = (tid - bound) % stride;\n      actual_stride = 
stride;\n    }\n  } else {\n    expert_idx = tid / stride;\n    tid_in_expert = tid % stride;\n    actual_stride = stride;\n  }\n  int m = numRows / n_experts;\n  int padded_m = (m + (128 - 1)) / 128 * 128;\n\n  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;\n  // TODO(kaixih@nvidia): For now, we assume mask is used together with\n  // silu_and_mal. Maybe we want a more general behavior of mask later. In the\n  // silu case, the input last dim doubles.\n  bool use_mask = mask != nullptr;\n  int actualColsPerRow = use_silu_and_mul ? colsPerRow * 2 : colsPerRow;\n\n  // Each global thread processes one element\n  for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow; globalIdx < (expert_idx + 1) * m * colsPerRow;\n       globalIdx += actual_stride) {\n    // Calculate which row and column this global thread should process\n    int rowIdx = globalIdx / colsPerRow;\n    int colIdx = globalIdx % colsPerRow;\n\n    // Find index within the experts\n    int rowIdx_in_expert = rowIdx - expert_idx * m;\n\n    // Early exit when using masks.\n    if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {\n      break;\n    }\n\n    int64_t inOffset = rowIdx * actualColsPerRow + colIdx;\n    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];\n    if (use_silu_and_mul) {\n      PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];\n      silu_and_mul(in_vec, in_vec_mul);\n    }\n\n    // Get the output tensor offset.\n    // Same as inOffset because 8 elements are packed into one uint32_t.\n    int64_t outOffset = rowIdx * colsPerRow + colIdx;\n    auto& out_pos = out[outOffset];\n\n    // Get the global scaling factor, which will be applied to the SF.\n    // Note SFScale is the same as next GEMM's alpha, which is\n    // (448.f / (Alpha_A / 6.f)).\n    float const SFScaleVal = SFScale == nullptr ? 
1.0f : SFScale[expert_idx];\n\n    int factor = CVT_FP4_SF_VEC_SIZE * 4;\n    // The actual output_scales dim is computed from the padded numCols.\n    int32_t numCols_padded = (numCols + factor - 1) / factor * factor;\n    int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;\n    uint32_t* SFout_in_expert = SFout + expert_idx * padded_m * numCols_SFout;\n\n    auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(\n        rowIdx_in_expert, colIdx, numCols, SFout_in_expert);\n\n    out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);\n  }\n#endif\n}\n\n// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)\ntemplate <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>\n__global__ void\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)\n__launch_bounds__(1024, 4) cvt_fp16_to_fp4(\n#else\ncvt_fp16_to_fp4(\n#endif\n    int32_t numRows,\n    int32_t numCols,\n    Type const* in,\n    float const* SFScale,\n    uint32_t* out,\n    uint32_t* SFout,\n    uint32_t* input_offset_by_experts,\n    uint32_t* output_scale_offset_by_experts,\n    int32_t* mask,\n    int n_experts) {\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)\n  using PackedVec = PackedVec<Type>;\n  static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);\n  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, \"Vec size is not matched.\");\n  extern __shared__ uint32_t shared_input_offsets[];\n\n  // Load input offsets into shared memory.\n  // If n_experts is larger than 4, use vectorized int4 to save instructions.\n  // If n_experts is smaller than 4, read directly.\n  if constexpr (SMALL_NUM_EXPERTS) {\n    for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) {\n      shared_input_offsets[i] = input_offset_by_experts[i];\n    }\n  } else {\n    for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {\n      
*reinterpret_cast<int4*>(&shared_input_offsets[i]) = *reinterpret_cast<const int4*>(&input_offset_by_experts[i]);\n    }\n    if (threadIdx.x == 0) {\n      shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];\n    }\n  }\n\n  __syncthreads();\n\n  int tid = blockIdx.x * blockDim.x + threadIdx.x;\n  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;\n  bool use_mask = mask != nullptr;\n  int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow;\n\n  // Each global thread processes one element\n  for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {\n    // Calculate which row and column this global thread should process\n    int rowIdx = globalIdx / colsPerRow;\n    int colIdx = globalIdx % colsPerRow;\n\n    // Find expert using binary search for better performance with large m_topk\n    int rowIdx_in_expert = 0;\n    int expert_idx = 0;\n\n    // Binary search through experts using shared memory\n    int left = 0, right = n_experts - 1;\n    while (left <= right) {\n      int mid = (left + right) / 2;\n      // Get offsets: shared_input_offsets[i] corresponds to\n      // input_offset_by_experts[i]\n      uint32_t mid_offset = shared_input_offsets[mid];\n      uint32_t next_offset = shared_input_offsets[mid + 1];\n\n      if (rowIdx >= mid_offset && rowIdx < next_offset) {\n        rowIdx_in_expert = rowIdx - mid_offset;\n        expert_idx = mid;\n        break;\n      } else if (rowIdx < mid_offset) {\n        right = mid - 1;\n      } else {\n        left = mid + 1;\n      }\n    }\n\n    if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {\n      continue;\n    }\n\n    int64_t inOffset = rowIdx * actualColsPerRow + colIdx;\n\n    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];\n    if (use_mask) {\n      PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];\n      silu_and_mul(in_vec, in_vec_mul);\n    }\n\n    int64_t outOffset = 
rowIdx * colsPerRow + colIdx;\n    auto& out_pos = out[outOffset];\n\n    float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];\n\n    int factor = CVT_FP4_SF_VEC_SIZE * 4;\n    int32_t numCols_padded = (numCols + factor - 1) / factor * factor;\n    int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;\n    uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;\n\n    auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(\n        rowIdx_in_expert, colIdx, numCols, SFout_in_expert);\n\n    out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);\n  }\n#endif\n}\n\ntemplate <typename T>\nvoid quant_impl(\n    void* output,\n    void* output_scale,\n    void* input,\n    void* input_global_scale,\n    void* input_offset_by_experts,\n    void* output_scale_offset_by_experts,\n    void* mask,\n    bool use_silu_and_mul,\n    int m_topk,\n    int k,\n    int n_experts,\n    cudaStream_t stream) {\n  // TODO: this multiProcessorCount should be cached.\n  int device;\n  cudaGetDevice(&device);\n  int multiProcessorCount;\n  cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);\n\n  // Grid, Block size.\n  // Each thread converts 8 values.\n  int const workSizePerRow = k / ELTS_PER_THREAD;\n  int const totalWorkSize = m_topk * workSizePerRow;\n  dim3 block(std::min(workSizePerRow, 512));\n  // Get number of blocks per SM (assume we can fully utilize the SM).\n  int const numBlocksPerSM = 2048 / block.x;\n  dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x), multiProcessorCount * numBlocksPerSM));\n  while (grid.x <= multiProcessorCount && block.x > 64) {\n    grid.x *= 2;\n    block.x = (block.x + 1) / 2;\n  }\n\n  // TODO(kaixih@nvidia): Should relax this to allow any grid size.\n  if (mask != nullptr) {\n    grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;\n    
cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(\n        m_topk,\n        k,\n        reinterpret_cast<T*>(input),\n        reinterpret_cast<float*>(input_global_scale),\n        reinterpret_cast<uint32_t*>(output),\n        reinterpret_cast<uint32_t*>(output_scale),\n        reinterpret_cast<int32_t*>(mask),\n        use_silu_and_mul,\n        n_experts);\n    return;\n  }\n\n  int const blockRepeat = (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);\n  if (blockRepeat > 1) {\n    size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);\n    if (n_experts >= 4) {\n      cvt_fp16_to_fp4<T, false, false><<<grid, block, shared_mem_size, stream>>>(\n          m_topk,\n          k,\n          reinterpret_cast<T*>(input),\n          reinterpret_cast<float*>(input_global_scale),\n          reinterpret_cast<uint32_t*>(output),\n          reinterpret_cast<uint32_t*>(output_scale),\n          reinterpret_cast<uint32_t*>(input_offset_by_experts),\n          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),\n          reinterpret_cast<int32_t*>(mask),\n          n_experts);\n    } else {\n      cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(\n          m_topk,\n          k,\n          reinterpret_cast<T*>(input),\n          reinterpret_cast<float*>(input_global_scale),\n          reinterpret_cast<uint32_t*>(output),\n          reinterpret_cast<uint32_t*>(output_scale),\n          reinterpret_cast<uint32_t*>(input_offset_by_experts),\n          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),\n          reinterpret_cast<int32_t*>(mask),\n          n_experts);\n    }\n  } else {\n    if (n_experts >= 16) {\n      cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>(\n          m_topk,\n          k,\n          reinterpret_cast<T*>(input),\n          reinterpret_cast<float*>(input_global_scale),\n          reinterpret_cast<uint32_t*>(output),\n          
reinterpret_cast<uint32_t*>(output_scale),\n          reinterpret_cast<uint32_t*>(input_offset_by_experts),\n          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),\n          reinterpret_cast<int32_t*>(mask),\n          n_experts,\n          /* bool low_latency */ true);\n    } else {\n      cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>(\n          m_topk,\n          k,\n          reinterpret_cast<T*>(input),\n          reinterpret_cast<float*>(input_global_scale),\n          reinterpret_cast<uint32_t*>(output),\n          reinterpret_cast<uint32_t*>(output_scale),\n          reinterpret_cast<uint32_t*>(input_offset_by_experts),\n          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),\n          reinterpret_cast<int32_t*>(mask),\n          n_experts,\n          /* bool low_latency */ true);\n    }\n  }\n}\n\ninline int getSMVersion(int device_id) {\n  int sm_major = 0;\n  int sm_minor = 0;\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device_id));\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device_id));\n  return sm_major * 10 + sm_minor;\n}\n\nvoid scaled_fp4_experts_quant_sm100a(\n    tvm::ffi::TensorView output,\n    tvm::ffi::TensorView output_scale,\n    tvm::ffi::TensorView input,\n    tvm::ffi::TensorView input_global_scale,\n    tvm::ffi::TensorView input_offset_by_experts,\n    tvm::ffi::TensorView output_scale_offset_by_experts) {\n  auto MTopK = SymbolicSize{\"m_topk\"};\n  auto K = SymbolicSize{\"k\"};\n  auto OutputCols = SymbolicSize{\"output_cols\"};\n  auto OutputScaleRows = SymbolicSize{\"output_scale_rows\"};\n  auto OutputScaleCols = SymbolicSize{\"output_scale_cols\"};\n  auto NExperts = SymbolicSize{\"n_experts\"};\n  auto OffsetSize = SymbolicSize{\"offset_size\"};\n  auto device = SymbolicDevice{};\n\n  TensorMatcher({MTopK, K})  //\n      .with_dtype<fp16_t, bf16_t>()\n      .template 
with_device<kDLCUDA>(device)\n      .verify(input);\n  TensorMatcher({MTopK, OutputCols})  //\n      .with_dtype<uint8_t>()\n      .with_device(device)\n      .verify(output);\n  TensorMatcher({OutputScaleRows, OutputScaleCols})  //\n      .with_dtype<int32_t>()\n      .with_device(device)\n      .verify(output_scale);\n  TensorMatcher({NExperts})  //\n      .with_dtype<float>()\n      .with_device(device)\n      .verify(input_global_scale);\n  TensorMatcher({OffsetSize})  //\n      .with_dtype<int32_t>()\n      .with_device(device)\n      .verify(input_offset_by_experts)\n      .verify(output_scale_offset_by_experts);\n\n  const int device_id = input.device().device_id;\n  RuntimeCheck(getSMVersion(device_id) >= 100, \"fp4_quant is only supported on sm100+\");\n\n  const int BLOCK_SIZE = 16;\n  const auto m_topk = static_cast<int>(MTopK.unwrap());\n  const auto k = static_cast<int>(K.unwrap());\n  RuntimeCheck(k % BLOCK_SIZE == 0, \"k must be a multiple of 16\");\n  const auto n_experts = static_cast<int>(NExperts.unwrap());\n  const auto offset_size = static_cast<int>(OffsetSize.unwrap());\n  RuntimeCheck(offset_size == n_experts + 1, \"input/output offset size mismatch\");\n  RuntimeCheck(static_cast<int>(OutputCols.unwrap()) == k / 2, \"output second dim mismatch\");\n  const int scales_k = k / BLOCK_SIZE;\n  const int padded_k = (scales_k + 3) / 4 * 4;\n  RuntimeCheck(static_cast<int>(OutputScaleCols.unwrap()) * 4 == padded_k, \"output_scale second dim mismatch\");\n\n  const cudaStream_t stream = LaunchKernel::resolve_device(input.device());\n  if (host::is_type<fp16_t>(input.dtype())) {\n    quant_impl<half>(\n        output.data_ptr(),\n        output_scale.data_ptr(),\n        input.data_ptr(),\n        input_global_scale.data_ptr(),\n        input_offset_by_experts.data_ptr(),\n        output_scale_offset_by_experts.data_ptr(),\n        nullptr,  // mask\n        false,    // use_silu_and_mul\n        m_topk,\n        k,\n        n_experts,\n        
stream);\n  } else {\n    quant_impl<__nv_bfloat16>(\n        output.data_ptr(),\n        output_scale.data_ptr(),\n        input.data_ptr(),\n        input_global_scale.data_ptr(),\n        input_offset_by_experts.data_ptr(),\n        output_scale_offset_by_experts.data_ptr(),\n        nullptr,  // mask\n        false,    // use_silu_and_mul\n        m_topk,\n        k,\n        n_experts,\n        stream);\n  }\n}\n\nvoid silu_and_mul_scaled_fp4_experts_quant_sm100a(\n    tvm::ffi::TensorView output,\n    tvm::ffi::TensorView output_scale,\n    tvm::ffi::TensorView input,\n    tvm::ffi::TensorView input_global_scale,\n    tvm::ffi::TensorView mask,\n    bool use_silu_and_mul) {\n  auto MTopK = SymbolicSize{\"m_topk\"};\n  auto KBy2 = SymbolicSize{\"k_by_2\"};\n  auto OutputCols = SymbolicSize{\"output_cols\"};\n  auto OutputScaleRows = SymbolicSize{\"output_scale_rows\"};\n  auto OutputScaleCols = SymbolicSize{\"output_scale_cols\"};\n  auto NExperts = SymbolicSize{\"n_experts\"};\n  auto device = SymbolicDevice{};\n\n  TensorMatcher({MTopK, KBy2})  //\n      .with_dtype<fp16_t, bf16_t>()\n      .template with_device<kDLCUDA>(device)\n      .verify(input);\n  TensorMatcher({MTopK, OutputCols})  //\n      .with_dtype<uint8_t>()\n      .with_device(device)\n      .verify(output);\n  TensorMatcher({OutputScaleRows, OutputScaleCols})  //\n      .with_dtype<int32_t>()\n      .with_device(device)\n      .verify(output_scale);\n  TensorMatcher({NExperts})  //\n      .with_dtype<float>()\n      .with_device(device)\n      .verify(input_global_scale);\n  TensorMatcher({NExperts})  //\n      .with_dtype<int32_t>()\n      .with_device(device)\n      .verify(mask);\n\n  const int device_id = input.device().device_id;\n  RuntimeCheck(getSMVersion(device_id) >= 100, \"fp4_quant is only supported on sm100+\");\n\n  const int BLOCK_SIZE = 16;\n  const auto m_topk = static_cast<int>(MTopK.unwrap());\n  const auto k_by_2 = static_cast<int>(KBy2.unwrap());\n  int k = k_by_2;\n  if 
(use_silu_and_mul) {\n    RuntimeCheck(k_by_2 % 2 == 0, \"input last dim must be even when use_silu_and_mul is set\");\n    k = k_by_2 / 2;\n  }\n  RuntimeCheck(k % BLOCK_SIZE == 0, \"k must be a multiple of 16\");\n  const auto n_experts = static_cast<int>(NExperts.unwrap());\n  RuntimeCheck(static_cast<int>(OutputCols.unwrap()) == k / 2, \"output second dim mismatch\");\n  const int scales_k = k / BLOCK_SIZE;\n  const int padded_k = (scales_k + 3) / 4 * 4;\n  RuntimeCheck(static_cast<int>(OutputScaleCols.unwrap()) * 4 == padded_k, \"output_scale second dim mismatch\");\n\n  const cudaStream_t stream = LaunchKernel::resolve_device(input.device());\n  if (host::is_type<fp16_t>(input.dtype())) {\n    quant_impl<half>(\n        output.data_ptr(),\n        output_scale.data_ptr(),\n        input.data_ptr(),\n        input_global_scale.data_ptr(),\n        nullptr,  // input_offset_by_experts\n        nullptr,  // output_scale_offset_by_experts\n        mask.data_ptr(),\n        use_silu_and_mul,\n        m_topk,\n        k,\n        n_experts,\n        stream);\n  } else {\n    quant_impl<__nv_bfloat16>(\n        output.data_ptr(),\n        output_scale.data_ptr(),\n        input.data_ptr(),\n        input_global_scale.data_ptr(),\n        nullptr,  // input_offset_by_experts\n        nullptr,  // output_scale_offset_by_experts\n        mask.data_ptr(),\n        use_silu_and_mul,\n        m_topk,\n        k,\n        n_experts,\n        stream);\n  }\n}\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant.cuh",
    "content": "/* Copyright 2025 SGLang Team. All Rights Reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n==============================================================================*/\n\n#include <sgl_kernel/type.cuh>\n#include <sgl_kernel/utils.cuh>\n\n#include <cutlass/arch/config.h>\n\n#include <cuda.h>\n#include <cuda_fp8.h>\n\n#define ELTS_PER_THREAD 8\n\nconstexpr int CVT_FP4_ELTS_PER_THREAD = 8;\nconstexpr int CVT_FP4_SF_VEC_SIZE = 16;\n\n// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).\nSGL_DEVICE uint32_t fp32_vec_to_e2m1(float (&array)[8]) {\n  // PTX instructions used here requires >= sm100f.\n#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || CUTLASS_ARCH_MMA_SM120A_ENABLED || \\\n    (defined(__CUDA_ARCH_FAMILY_SPECIFIC__) && (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000))\n  uint32_t val;\n  asm volatile(\n      \"{\\n\"\n      \".reg .b8 byte0;\\n\"\n      \".reg .b8 byte1;\\n\"\n      \".reg .b8 byte2;\\n\"\n      \".reg .b8 byte3;\\n\"\n      \"cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\\n\"\n      \"cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\\n\"\n      \"cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\\n\"\n      \"cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\\n\"\n      \"mov.b32 %0, {byte0, byte1, byte2, byte3};\\n\"\n      \"}\"\n      : \"=r\"(val)\n      : \"f\"(array[0]),\n        \"f\"(array[1]),\n        \"f\"(array[2]),\n        \"f\"(array[3]),\n        \"f\"(array[4]),\n        
\"f\"(array[5]),\n        \"f\"(array[6]),\n        \"f\"(array[7]));\n  return val;\n#else\n  printf(\"fp32_vec_to_e2m1 is not supported on this architecture\\n\");\n  __trap();\n  return 0;\n#endif\n}\n\n// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).\nSGL_DEVICE uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {\n  // PTX instructions used here requires >= sm100f.\n#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || CUTLASS_ARCH_MMA_SM120A_ENABLED || \\\n    (defined(__CUDA_ARCH_FAMILY_SPECIFIC__) && (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000))\n  uint32_t val;\n  asm volatile(\n      \"{\\n\"\n      \".reg .b8 byte0;\\n\"\n      \".reg .b8 byte1;\\n\"\n      \".reg .b8 byte2;\\n\"\n      \".reg .b8 byte3;\\n\"\n      \"cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\\n\"\n      \"cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\\n\"\n      \"cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\\n\"\n      \"cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\\n\"\n      \"mov.b32 %0, {byte0, byte1, byte2, byte3};\\n\"\n      \"}\"\n      : \"=r\"(val)\n      : \"f\"(array[0].x),\n        \"f\"(array[0].y),\n        \"f\"(array[1].x),\n        \"f\"(array[1].y),\n        \"f\"(array[2].x),\n        \"f\"(array[2].y),\n        \"f\"(array[3].x),\n        \"f\"(array[3].y));\n  return val;\n#else\n  printf(\"fp32_vec_to_e2m1 is not supported on this architecture\\n\");\n  __trap();\n  return 0;\n#endif\n}\n\n// Fast reciprocal.\nSGL_DEVICE float reciprocal_approximate_ftz(float a) {\n  float b;\n  asm volatile(\"rcp.approx.ftz.f32 %0, %1;\\n\" : \"=f\"(b) : \"f\"(a));\n  return b;\n}\n\ntemplate <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>\nSGL_DEVICE uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) {\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)\n  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2);\n\n  // One pair of threads write one SF to global 
memory.\n  // TODO: stage through smem for packed STG.32\n  // is it better than STG.8 from 4 threads?\n  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {\n    // SF vector index (16 elements share one SF in the K dimension).\n    int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;\n    int32_t mIdx = rowIdx;\n\n    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]\n    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]\n\n    int32_t mTileIdx = mIdx / (32 * 4);\n    // SF vector size 16.\n    int factor = CVT_FP4_SF_VEC_SIZE * 4;\n    int32_t numKTiles = (numCols + factor - 1) / factor;\n    int64_t mTileStride = numKTiles * 32 * 4 * 4;\n\n    int32_t kTileIdx = (kIdx / 4);\n    int64_t kTileStride = 32 * 4 * 4;\n\n    // M tile layout [32, 4] is column-major.\n    int32_t outerMIdx = (mIdx % 32);\n    int64_t outerMStride = 4 * 4;\n\n    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;\n    int64_t innerMStride = 4;\n\n    int32_t innerKIdx = (kIdx % 4);\n    int64_t innerKStride = 1;\n\n    // Compute the global offset.\n    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +\n                       innerMIdx * innerMStride + innerKIdx * innerKStride;\n\n    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;\n  }\n#endif\n  return nullptr;\n}\n\n// Define a 16-byte packed data type.\ntemplate <class Type>\nstruct PackedVec {\n  packed_t<Type> elts[4];\n};\n\ntemplate <>\nstruct PackedVec<__nv_fp8_e4m3> {\n  __nv_fp8x2_e4m3 elts[8];\n};\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant_entry.cuh",
    "content": "/* Copyright 2025 SGLang Team. All Rights Reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n==============================================================================*/\n\n#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\nvoid scaled_fp4_quant_sm100a_sm120a(\n    tvm::ffi::TensorView output,\n    tvm::ffi::TensorView input,\n    tvm::ffi::TensorView output_sf,\n    tvm::ffi::TensorView input_sf);\n\nvoid scaled_fp4_experts_quant_sm100a(\n    tvm::ffi::TensorView output,\n    tvm::ffi::TensorView output_scale,\n    tvm::ffi::TensorView input,\n    tvm::ffi::TensorView input_global_scale,\n    tvm::ffi::TensorView input_offset_by_experts,\n    tvm::ffi::TensorView output_scale_offset_by_experts);\n\nvoid silu_and_mul_scaled_fp4_experts_quant_sm100a(\n    tvm::ffi::TensorView output,\n    tvm::ffi::TensorView output_scale,\n    tvm::ffi::TensorView input,\n    tvm::ffi::TensorView input_global_scale,\n    tvm::ffi::TensorView mask,\n    bool use_silu_and_mul);\n\nvoid scaled_fp4_quant(\n    tvm::ffi::TensorView output,\n    tvm::ffi::TensorView input,\n    tvm::ffi::TensorView output_sf,\n    tvm::ffi::TensorView input_sf) {\n  scaled_fp4_quant_sm100a_sm120a(output, input, output_sf, input_sf);\n}\n\nvoid scaled_fp4_experts_quant(\n    tvm::ffi::TensorView output,\n    tvm::ffi::TensorView output_scale,\n    tvm::ffi::TensorView input,\n    tvm::ffi::TensorView input_global_scale,\n    tvm::ffi::TensorView input_offset_by_experts,\n    
tvm::ffi::TensorView output_scale_offset_by_experts) {\n  scaled_fp4_experts_quant_sm100a(\n      output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts);\n}\n\nvoid silu_and_mul_scaled_fp4_experts_quant(\n    tvm::ffi::TensorView output,\n    tvm::ffi::TensorView output_scale,\n    tvm::ffi::TensorView input,\n    tvm::ffi::TensorView input_global_scale,\n    tvm::ffi::TensorView mask,\n    bool use_silu_and_mul) {\n  silu_and_mul_scaled_fp4_experts_quant_sm100a(output, output_scale, input, input_global_scale, mask, use_silu_and_mul);\n}\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant_kernels.cuh",
    "content": "/* Copyright 2025 SGLang Team. All Rights Reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n==============================================================================*/\n\n#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/runtime.cuh>\n#include <sgl_kernel/utils.cuh>\n\n#include \"nvfp4_quant.cuh\"\n#include <cuda_runtime.h>\n#include <cuda_runtime_api.h>\n\nusing namespace host;\n\n// Quantizes the provided PackedVec into the uint32_t output\ntemplate <class Type, bool UE8M0_SF = false>\nSGL_DEVICE uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)\n  // Get absolute maximum values among the local 8 values.\n  auto localMax = __habs2(vec.elts[0]);\n\n// Local maximum value.\n#pragma unroll\n  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {\n    localMax = __hmax2(localMax, __habs2(vec.elts[i]));\n  }\n\n  // Get the absolute maximum among all 16 values (two threads).\n  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);\n  // Get the final absolute maximum values.\n  float vecMax = float(__hmax(localMax.x, localMax.y));\n\n  // Get the SF (max value of the vector / max value of e2m1).\n  // maximum value of e2m1 = 6.0.\n  // TODO: use half as compute data type.\n  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));\n  // 8 bits representation of the SF.\n  uint8_t 
fp8SFVal;\n  // Write the SF to global memory (STG.8).\n  if constexpr (UE8M0_SF) {\n    __nv_fp8_e8m0 tmp;\n    tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);\n    SFValue = static_cast<float>(tmp);\n    fp8SFVal = tmp.__x;\n  } else {\n    // Here SFValue is always positive, so E4M3 is the same as UE4M3.\n    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);\n    fp8SFVal = tmp.__x;\n    SFValue = static_cast<float>(tmp);\n  }\n  // Get the output scale.\n  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *\n  //                       reciprocal(SFScaleVal))\n  float outputScale =\n      SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;\n\n  if (SFout) {\n    // Write the SF to global memory (STG.8).\n    *SFout = fp8SFVal;\n  }\n\n  // Convert the input to float.\n  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];\n\n#pragma unroll\n  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {\n    if constexpr (std::is_same_v<Type, half>) {\n      fp2Vals[i] = __half22float2(vec.elts[i]);\n    } else {\n      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);\n    }\n    fp2Vals[i].x *= outputScale;\n    fp2Vals[i].y *= outputScale;\n  }\n\n  // Convert to e2m1 values.\n  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);\n\n  // Write the e2m1 values to global memory.\n  return e2m1Vec;\n#else\n  return 0;\n#endif\n}\n\n// Use UE4M3 by default.\ntemplate <class Type, bool UE8M0_SF = false>\n__global__ void\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)\n__launch_bounds__(512, 4) cvt_fp16_to_fp4(\n#else\ncvt_fp16_to_fp4(\n#endif\n    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out, uint32_t* SFout) {\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)\n  using PackedVec = PackedVec<Type>;\n  static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);\n  static_assert(sizeof(PackedVec) == sizeof(Type) 
* CVT_FP4_ELTS_PER_THREAD, \"Vec size is not matched.\");\n\n  // Get the global scaling factor, which will be applied to the SF.\n  // Note SFScale is the same as next GEMM's alpha, which is\n  // (448.f / (Alpha_A / 6.f)).\n  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];\n\n  // Input tensor row/col loops.\n  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {\n    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) {\n      int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;\n      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];\n      // Get the output tensor offset.\n      // Same as inOffset because 8 elements are packed into one uint32_t.\n      int64_t outOffset = inOffset;\n      auto& out_pos = out[outOffset];\n\n      auto sf_out =\n          cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(rowIdx, colIdx, numCols, SFout);\n\n      out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);\n    }\n  }\n#endif\n}\n\ntemplate <typename T>\nvoid invokeFP4Quantization(\n    int m,\n    int n,\n    T const* input,\n    float const* SFScale,\n    int64_t* output,\n    int32_t* SFOuput,\n    bool useUE8M0,\n    int multiProcessorCount,\n    cudaStream_t stream) {\n  // Grid, Block size.\n  // Each thread converts 8 values.\n  dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));\n  // Get number of blocks per SM (assume we can fully utilize the SM).\n  int const numBlocksPerSM = 2048 / block.x;\n  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));\n\n  // Launch the cvt kernel.\n  if (useUE8M0) {\n    cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(\n        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));\n  } else {\n    cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(\n        m, n, input, SFScale, 
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));\n  }\n}\n\n// Instantiate the function.\ntemplate void invokeFP4Quantization(\n    int m,\n    int n,\n    half const* input,\n    float const* SFScale,\n    int64_t* output,\n    int32_t* SFOuput,\n    bool useUE8M0,\n    int multiProcessorCount,\n    cudaStream_t stream);\n\ntemplate void invokeFP4Quantization(\n    int m,\n    int n,\n    __nv_bfloat16 const* input,\n    float const* SFScale,\n    int64_t* output,\n    int32_t* SFOuput,\n    bool useUE8M0,\n    int multiProcessorCount,\n    cudaStream_t stream);\n\ninline int getSMVersion(int device_id) {\n  int sm_major = 0;\n  int sm_minor = 0;\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device_id));\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device_id));\n  return sm_major * 10 + sm_minor;\n}\n\nvoid scaled_fp4_quant_sm100a_sm120a(\n    tvm::ffi::TensorView output,\n    tvm::ffi::TensorView input,\n    tvm::ffi::TensorView output_sf,\n    tvm::ffi::TensorView input_sf) {\n  RuntimeCheck(input.device().device_type == kDLCUDA, \"input must be a CUDA tensor\");\n  RuntimeCheck(output.device() == input.device(), \"output and input must be on same device\");\n  RuntimeCheck(output_sf.device() == input.device(), \"output_sf and input must be on same device\");\n  RuntimeCheck(input_sf.device() == input.device(), \"input_sf and input must be on same device\");\n  RuntimeCheck(input.dim() == 2, \"input must be a 2D tensor\");\n  RuntimeCheck(output.dim() == 2, \"output must be a 2D tensor\");\n  RuntimeCheck(output_sf.dim() == 2, \"output_sf must be a 2D tensor\");\n  RuntimeCheck(input_sf.numel() == 1, \"input_sf must have exactly one element\");\n  RuntimeCheck(host::is_type<uint8_t>(output.dtype()), \"output must be uint8\");\n  RuntimeCheck(host::is_type<int32_t>(output_sf.dtype()), \"output_sf must be int32\");\n  
RuntimeCheck(host::is_type<float>(input_sf.dtype()), \"input_sf must be float32\");\n  RuntimeCheck(\n      host::is_type<fp16_t>(input.dtype()) || host::is_type<bf16_t>(input.dtype()), \"input dtype must be fp16 or bf16\");\n\n  const int device_id = input.device().device_id;\n  const auto sm_version = getSMVersion(device_id);\n  RuntimeCheck(sm_version >= 100, \"fp4_quant is only supported on sm100+\");\n\n  const int32_t m = static_cast<int32_t>(input.size(0));\n  const int32_t n = static_cast<int32_t>(input.size(1));\n\n  RuntimeCheck(output.size(0) == m, \"output row size mismatch\");\n  RuntimeCheck(output.size(1) == n / 2, \"output column size mismatch\");\n  RuntimeCheck(n % 16 == 0, \"The N dimension must be a multiple of 16.\");\n\n  const int multiProcessorCount = static_cast<int>(runtime::get_sm_count(device_id));\n\n  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());\n  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());\n  auto output_ptr = static_cast<int64_t*>(output.data_ptr());\n  const cudaStream_t stream = LaunchKernel::resolve_device(input.device());\n\n  constexpr bool useUE8M0 = false;\n  if (host::is_type<fp16_t>(input.dtype())) {\n    auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());\n    invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, useUE8M0, multiProcessorCount, stream);\n  } else {\n    auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());\n    invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, useUE8M0, multiProcessorCount, stream);\n  }\n}\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_entry.cuh",
    "content": "/* Copyright 2025 SGLang Team. All Rights Reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n==============================================================================*/\n\n#include <sgl_kernel/tensor.h>\n\nvoid cutlass_scaled_fp4_mm_sm100a_sm120a(\n    tvm::ffi::TensorView D,\n    tvm::ffi::TensorView A,\n    tvm::ffi::TensorView B,\n    tvm::ffi::TensorView A_sf,\n    tvm::ffi::TensorView B_sf,\n    tvm::ffi::TensorView alpha);\n\nvoid cutlass_scaled_fp4_mm(\n    tvm::ffi::TensorView D,\n    tvm::ffi::TensorView A,\n    tvm::ffi::TensorView B,\n    tvm::ffi::TensorView A_sf,\n    tvm::ffi::TensorView B_sf,\n    tvm::ffi::TensorView alpha) {\n  cutlass_scaled_fp4_mm_sm100a_sm120a(D, A, B, A_sf, B_sf, alpha);\n}\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh",
    "content": "/* Copyright 2025 SGLang Team. All Rights Reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n==============================================================================*/\n\n#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/runtime.cuh>\n#include <sgl_kernel/utils.cuh>\n\n#include <cstddef>\n#include <cstdint>\n#include <cuda_runtime.h>\n#include <unordered_map>\n\nusing namespace host;\n\n// clang-format off\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/gemm/collective/collective_builder.hpp\"\n#include \"cutlass/epilogue/collective/collective_builder.hpp\"\n#include \"cutlass/gemm/device/gemm_universal_adapter.h\"\n#include \"cutlass/gemm/kernel/gemm_universal.hpp\"\n#include \"cutlass/util/packed_stride.hpp\"\n// clang-format on\n\n/**\n * Helper function for checking CUTLASS errors\n */\n#define CUTLASS_CHECK(status)                                                        \\\n  {                                                                                  \\\n    cutlass::Status error = status;                                                  \\\n    RuntimeCheck(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \\\n  }\n\nusing namespace cute;\n\n// Helper function for next power of 2\ninline uint32_t next_pow_2(uint32_t x) {\n  if (x == 0) return 1;\n  x--;\n  x |= x >> 1;\n  x |= x >> 2;\n  x |= x >> 4;\n  x |= x >> 8;\n  x |= x >> 16;\n  return x + 1;\n}\n\nstruct WorkspaceKey {\n  int 
device_id;\n  uintptr_t stream;\n  auto operator==(const WorkspaceKey&) const -> bool = default;\n};\n\nstruct WorkspaceKeyHash {\n  auto operator()(const WorkspaceKey& key) const -> size_t {\n    size_t h1 = std::hash<int>{}(key.device_id);\n    size_t h2 = std::hash<uintptr_t>{}(key.stream);\n    return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));\n  }\n};\n\nstruct WorkspaceState {\n  void* ptr = nullptr;\n  size_t bytes = 0;\n};\n\ninline auto get_cached_workspace(size_t required_bytes, int device_id, cudaStream_t stream) -> void* {\n  if (required_bytes == 0) {\n    return nullptr;\n  }\n\n  thread_local std::unordered_map<WorkspaceKey, WorkspaceState, WorkspaceKeyHash> cache;\n  WorkspaceKey key{device_id, reinterpret_cast<uintptr_t>(stream)};\n  auto& ws = cache[key];\n\n  if (ws.ptr != nullptr && ws.bytes >= required_bytes) {\n    return ws.ptr;\n  }\n\n  RuntimeDeviceCheck(cudaSetDevice(device_id));\n  if (ws.ptr != nullptr) {\n    RuntimeDeviceCheck(cudaFreeAsync(ws.ptr, stream));\n    ws.ptr = nullptr;\n    ws.bytes = 0;\n  }\n  RuntimeDeviceCheck(cudaMallocAsync(&ws.ptr, required_bytes, stream));\n  ws.bytes = required_bytes;\n  return ws.ptr;\n}\n\n#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || \\\n    defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)\n// Config(half_t/bfloat16_t) for M <= 128\ntemplate <typename T>\nstruct KernelConfigM128 {\n  using OutputType = T;\n  using MmaTileShape = Shape<_128, _256, _256>;\n  using ClusterShape = Shape<int, int, _1>;\n  using EpilogueTile = Shape<_128, _64>;  // Avoid register spilling\n  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;\n  using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;\n  const static dim3 preferred_cluster;\n  const static dim3 fallback_cluster;\n};\ntemplate <typename T>\nconst dim3 KernelConfigM128<T>::preferred_cluster(1, 4, 1);\ntemplate <typename T>\nconst dim3 
KernelConfigM128<T>::fallback_cluster(1, 2, 1);\n\n// Config(half_t/bfloat16_t) for M <= 256\ntemplate <typename T>\nstruct KernelConfigM256 {\n  using OutputType = T;\n  using MmaTileShape = Shape<_256, _256, _256>;\n  using ClusterShape = Shape<int, int, _1>;\n  using EpilogueTile = Shape<_128, _64>;  // Avoid register spilling\n  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;\n  using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;\n  const static dim3 preferred_cluster;\n  const static dim3 fallback_cluster;\n};\ntemplate <typename T>\nconst dim3 KernelConfigM256<T>::preferred_cluster(2, 4, 1);\ntemplate <typename T>\nconst dim3 KernelConfigM256<T>::fallback_cluster(2, 1, 1);\n\n// Default config(half_t/bfloat16_t) for M > 256\ntemplate <typename T>\nstruct KernelConfigDefault {\n  using OutputType = T;\n  using MmaTileShape = Shape<_256, _256, _256>;\n  using ClusterShape = Shape<int, int, _1>;\n  using EpilogueTile = Shape<_128, _64>;  // Avoid register spilling\n  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;\n  using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;\n  const static dim3 preferred_cluster;\n  const static dim3 fallback_cluster;\n};\ntemplate <typename T>\nconst dim3 KernelConfigDefault<T>::preferred_cluster(4, 4, 1);\ntemplate <typename T>\nconst dim3 KernelConfigDefault<T>::fallback_cluster(2, 1, 1);\n\nstruct KernelConfigFp32 {\n  using OutputType = float;\n  using MmaTileShape = Shape<_128, _128, _256>;\n  using ClusterShape = Shape<int, int, _1>;\n  using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;\n  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;\n  using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;\n  const static dim3 preferred_cluster;\n  const static dim3 fallback_cluster;\n};\nconst dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1);\nconst dim3 
KernelConfigFp32::fallback_cluster = dim3(1, 2, 1);\n\n// SM120 specific configurations\nstruct sm120_fp4_config_M256 {\n  using ClusterShape = Shape<_1, _1, _1>;\n  using MmaTileShape = Shape<_128, _128, _128>;\n  using PerSmTileShape_MNK = Shape<_128, _128, _128>;\n};\n\nstruct sm120_fp4_config_default {\n  using ClusterShape = Shape<_1, _1, _1>;\n  using MmaTileShape = Shape<_256, _128, _128>;\n  using PerSmTileShape_MNK = Shape<_256, _128, _128>;\n};\n\ntemplate <typename KernelConfig>\nstruct Fp4GemmSm100 {\n  using Config = KernelConfig;  // For generating args\n  using OutputType = typename KernelConfig::OutputType;\n  // A matrix configuration\n  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;\n  using LayoutATag = cutlass::layout::RowMajor;\n  static constexpr int AlignmentA = 32;\n\n  // B matrix configuration\n  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;\n  using LayoutBTag = cutlass::layout::ColumnMajor;\n  static constexpr int AlignmentB = 32;\n\n  // C/D matrix configuration\n  using ElementD = OutputType;\n  using ElementC = OutputType;\n  using LayoutCTag = cutlass::layout::RowMajor;\n  using LayoutDTag = cutlass::layout::RowMajor;\n  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;\n  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;\n  // Kernel functional config\n  using ElementAccumulator = float;\n  using ArchTag = cutlass::arch::Sm100;\n  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;\n\n  // Kernel Perf config\n  using MmaTileShape = typename KernelConfig::MmaTileShape;\n  using ClusterShape = typename KernelConfig::ClusterShape;\n  using EpilogueTile = typename KernelConfig::EpilogueTile;\n  using EpilogueSchedule = typename KernelConfig::EpilogueSchedule;\n  using MainloopSchedule = typename KernelConfig::MainloopSchedule;\n\n  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<\n      ArchTag,\n    
  OperatorClass,\n      MmaTileShape,\n      ClusterShape,\n      EpilogueTile,\n      ElementAccumulator,\n      ElementAccumulator,\n      void,\n      LayoutCTag,\n      AlignmentC,\n      ElementD,\n      LayoutDTag,\n      AlignmentD,\n      EpilogueSchedule,\n      cutlass::epilogue::fusion::LinearCombination<ElementD, float, void, float>>::CollectiveOp;\n\n  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<\n      ArchTag,\n      OperatorClass,\n      ElementA,\n      LayoutATag,\n      AlignmentA,\n      ElementB,\n      LayoutBTag,\n      AlignmentB,\n      ElementAccumulator,\n      MmaTileShape,\n      ClusterShape,\n      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(\n          sizeof(typename CollectiveEpilogue::SharedStorage))>,\n      MainloopSchedule>::CollectiveOp;\n\n  using GemmKernel =\n      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;\n  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;\n  using StrideA = typename Gemm::GemmKernel::StrideA;\n  using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{}));\n  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;\n  using StrideB = typename Gemm::GemmKernel::StrideB;\n  using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{}));\n  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;\n  using StrideC = typename Gemm::GemmKernel::StrideC;\n  using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{}));\n  using StrideD = typename Gemm::GemmKernel::StrideD;\n  using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));\n};\n\n// SM120 specific GEMM template\ntemplate <typename Config, typename OutType>\nstruct Fp4GemmSm120 {\n  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;\n  using LayoutATag = cutlass::layout::RowMajor;\n  static 
constexpr int AlignmentA = 32;\n\n  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;\n  using LayoutBTag = cutlass::layout::ColumnMajor;\n  static constexpr int AlignmentB = 32;\n\n  using ElementD = OutType;\n  using ElementC = OutType;\n  using LayoutCTag = cutlass::layout::RowMajor;\n  using LayoutDTag = cutlass::layout::RowMajor;\n  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;\n  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;\n\n  using ElementAccumulator = float;\n  using ArchTag = cutlass::arch::Sm120;\n  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;\n\n  using MmaTileShape = typename Config::MmaTileShape;\n  using ClusterShape = typename Config::ClusterShape;\n  using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;\n\n  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<\n      ArchTag,\n      OperatorClass,\n      PerSmTileShape_MNK,\n      ClusterShape,\n      cutlass::epilogue::collective::EpilogueTileAuto,\n      ElementAccumulator,\n      ElementAccumulator,\n      ElementC,\n      LayoutCTag,\n      AlignmentC,\n      ElementD,\n      LayoutDTag,\n      AlignmentD,\n      cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;\n\n  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<\n      ArchTag,\n      OperatorClass,\n      ElementA,\n      LayoutATag,\n      AlignmentA,\n      ElementB,\n      LayoutBTag,\n      AlignmentB,\n      ElementAccumulator,\n      MmaTileShape,\n      ClusterShape,\n      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(\n          sizeof(typename CollectiveEpilogue::SharedStorage))>,\n      cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;\n\n  using GemmKernel =\n      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;\n\n  using Gemm = 
cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;\n};\n\ntemplate <typename T>\ntypename T::Gemm::Arguments args_from_options(\n    tvm::ffi::TensorView D,\n    tvm::ffi::TensorView A,\n    tvm::ffi::TensorView B,\n    tvm::ffi::TensorView A_sf,\n    tvm::ffi::TensorView B_sf,\n    tvm::ffi::TensorView alpha,\n    int64_t M,\n    int64_t N,\n    int64_t K) {\n  using ElementA = typename T::Gemm::ElementA;\n  using ElementB = typename T::Gemm::ElementB;\n  using ElementSFA = cutlass::float_ue4m3_t;\n  using ElementSFB = cutlass::float_ue4m3_t;\n  using ElementD = typename T::Gemm::ElementD;\n  using ElementCompute = float;\n  using StrideA = typename T::StrideA;\n  using StrideB = typename T::StrideB;\n  using StrideD = typename T::StrideD;\n  using Sm1xxBlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;\n\n  int m = static_cast<int>(M);\n  int n = static_cast<int>(N);\n  int k = static_cast<int>(K);\n  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});\n  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});\n  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});\n\n  auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));\n  auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));\n\n  typename T::Gemm::Arguments arguments{\n      cutlass::gemm::GemmUniversalMode::kGemm,\n      {m, n, k, 1},\n      {// Mainloop arguments\n       static_cast<ElementA const*>(A.data_ptr()),\n       stride_A,\n       static_cast<ElementB const*>(B.data_ptr()),\n       stride_B,\n       static_cast<ElementSFA const*>(A_sf.data_ptr()),\n       layout_SFA,\n       static_cast<ElementSFB const*>(B_sf.data_ptr()),\n       layout_SFB},\n      {     // Epilogue arguments\n       {},  // epilogue.thread\n       nullptr,\n       stride_D,\n       static_cast<ElementD*>(D.data_ptr()),\n       stride_D}};\n  auto& 
fusion_args = arguments.epilogue.thread;\n  fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());\n  using KernelConfig = typename T::Config;\n  arguments.hw_info.cluster_shape = KernelConfig::preferred_cluster;\n  arguments.hw_info.cluster_shape_fallback = KernelConfig::fallback_cluster;\n  return arguments;\n}\n\ntemplate <typename T>\nvoid runGemm(\n    tvm::ffi::TensorView D,\n    tvm::ffi::TensorView A,\n    tvm::ffi::TensorView B,\n    tvm::ffi::TensorView A_sf,\n    tvm::ffi::TensorView B_sf,\n    tvm::ffi::TensorView alpha,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    cudaStream_t stream) {\n  typename T::Gemm gemm;\n  auto arguments = args_from_options<T>(D, A, B, A_sf, B_sf, alpha, m, n, k);\n\n  size_t workspace_size = T::Gemm::get_workspace_size(arguments);\n  int device_id = A.device().device_id;\n  void* workspace = get_cached_workspace(workspace_size, device_id, stream);\n\n  CUTLASS_CHECK(gemm.can_implement(arguments));\n\n  CUTLASS_CHECK(gemm.initialize(arguments, workspace, stream));\n\n  CUTLASS_CHECK(gemm.run(arguments, workspace, stream));\n}\n\n// SM120 specific args_from_options function\ntemplate <typename Gemm>\ntypename Gemm::Arguments args_from_options_sm120(\n    tvm::ffi::TensorView D,\n    tvm::ffi::TensorView A,\n    tvm::ffi::TensorView B,\n    tvm::ffi::TensorView A_sf,\n    tvm::ffi::TensorView B_sf,\n    tvm::ffi::TensorView alpha,\n    int M,\n    int N,\n    int K) {\n  using ElementA = typename Gemm::ElementA;\n  using ElementB = typename Gemm::ElementB;\n  using ElementD = typename Gemm::ElementD;\n  using ElementSFA = cutlass::float_ue4m3_t;\n  using ElementSFB = cutlass::float_ue4m3_t;\n  using ElementCompute = float;\n\n  using StrideA = typename Gemm::GemmKernel::StrideA;\n  using StrideB = typename Gemm::GemmKernel::StrideB;\n  using StrideC = typename Gemm::GemmKernel::StrideC;\n  using StrideD = typename Gemm::GemmKernel::StrideD;\n\n  using Sm1xxBlkScaledConfig = typename 
Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;\n\n  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});\n  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});\n  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});\n\n  auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));\n  auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));\n\n  typename Gemm::Arguments arguments{\n      cutlass::gemm::GemmUniversalMode::kGemm,\n      {M, N, K, 1},\n      {static_cast<ElementA const*>(A.data_ptr()),\n       stride_A,\n       static_cast<ElementB const*>(B.data_ptr()),\n       stride_B,\n       static_cast<ElementSFA const*>(A_sf.data_ptr()),\n       layout_SFA,\n       static_cast<ElementSFB const*>(B_sf.data_ptr()),\n       layout_SFB},\n      {{}, static_cast<ElementD const*>(D.data_ptr()), stride_D, static_cast<ElementD*>(D.data_ptr()), stride_D}};\n  auto& fusion_args = arguments.epilogue.thread;\n  fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());\n\n  return arguments;\n}\n\n// SM120 specific runGemm function\ntemplate <typename Gemm>\nvoid runGemmSm120(\n    tvm::ffi::TensorView D,\n    tvm::ffi::TensorView A,\n    tvm::ffi::TensorView B,\n    tvm::ffi::TensorView A_sf,\n    tvm::ffi::TensorView B_sf,\n    tvm::ffi::TensorView alpha,\n    int M,\n    int N,\n    int K,\n    cudaStream_t stream) {\n  Gemm gemm;\n\n  auto arguments = args_from_options_sm120<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);\n\n  size_t workspace_size = Gemm::get_workspace_size(arguments);\n  int device_id = A.device().device_id;\n  void* workspace = get_cached_workspace(workspace_size, device_id, stream);\n\n  CUTLASS_CHECK(gemm.can_implement(arguments));\n\n  CUTLASS_CHECK(gemm.initialize(arguments, workspace, stream));\n\n  CUTLASS_CHECK(gemm.run(arguments, workspace, stream));\n}\n\n// Dispatch function to select 
appropriate config based on M\ntemplate <typename OutType>\nvoid cutlassFp4GemmDispatch(\n    tvm::ffi::TensorView D,\n    tvm::ffi::TensorView A,\n    tvm::ffi::TensorView B,\n    tvm::ffi::TensorView A_sf,\n    tvm::ffi::TensorView B_sf,\n    tvm::ffi::TensorView alpha,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    cudaStream_t stream) {\n  if (m <= 128) {\n    // m in [1, 128]\n    runGemm<Fp4GemmSm100<KernelConfigM128<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);\n  } else if (m <= 256) {\n    // m in (128, 256]\n    runGemm<Fp4GemmSm100<KernelConfigM256<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);\n  } else {\n    // m in (256, inf)\n    runGemm<Fp4GemmSm100<KernelConfigDefault<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);\n  }\n}\n\n// Dispatch function to select appropriate config based on M\ntemplate <>\nvoid cutlassFp4GemmDispatch<float>(\n    tvm::ffi::TensorView D,\n    tvm::ffi::TensorView A,\n    tvm::ffi::TensorView B,\n    tvm::ffi::TensorView A_sf,\n    tvm::ffi::TensorView B_sf,\n    tvm::ffi::TensorView alpha,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    cudaStream_t stream) {\n  runGemm<Fp4GemmSm100<KernelConfigFp32>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);\n}\n\n// SM120 specific dispatch functions\nvoid cutlass_fp4_bf16_gemm_dispatch_sm120(\n    tvm::ffi::TensorView D,\n    tvm::ffi::TensorView A,\n    tvm::ffi::TensorView B,\n    tvm::ffi::TensorView A_sf,\n    tvm::ffi::TensorView B_sf,\n    tvm::ffi::TensorView alpha,\n    int m,\n    int n,\n    int k,\n    cudaStream_t stream) {\n  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));\n  if (mp2 <= 256) {\n    runGemmSm120<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(\n        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);\n  } else {\n    runGemmSm120<Fp4GemmSm120<sm120_fp4_config_default, cutlass::bfloat16_t>::Gemm>(\n        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);\n  }\n}\n\nvoid 
cutlass_fp4_f16_gemm_dispatch_sm120(\n    tvm::ffi::TensorView D,\n    tvm::ffi::TensorView A,\n    tvm::ffi::TensorView B,\n    tvm::ffi::TensorView A_sf,\n    tvm::ffi::TensorView B_sf,\n    tvm::ffi::TensorView alpha,\n    int m,\n    int n,\n    int k,\n    cudaStream_t stream) {\n  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));\n  if (mp2 <= 256) {\n    runGemmSm120<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(\n        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);\n  } else {\n    runGemmSm120<Fp4GemmSm120<sm120_fp4_config_default, cutlass::half_t>::Gemm>(\n        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);\n  }\n}\n\n#else\ntemplate <typename T>\nvoid cutlassFp4GemmDispatch(\n    tvm::ffi::TensorView D,\n    tvm::ffi::TensorView A,\n    tvm::ffi::TensorView B,\n    tvm::ffi::TensorView A_sf,\n    tvm::ffi::TensorView B_sf,\n    tvm::ffi::TensorView alpha,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    cudaStream_t stream) {\n  RuntimeCheck(\n      false,\n      \"Unsupported CUTLASS version. 
Set VLLM_CUTLASS_SRC_DIR to \"\n      \"a CUTLASS 3.8 source directory to enable support.\");\n}\n#endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) ||\n        // defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)\n\ninline int getSMVersion(int device_id) {\n  int sm_major = 0;\n  int sm_minor = 0;\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device_id));\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device_id));\n  return sm_major * 10 + sm_minor;\n}\n\nvoid cutlass_scaled_fp4_mm_sm100a_sm120a(\n    tvm::ffi::TensorView D,\n    tvm::ffi::TensorView A,\n    tvm::ffi::TensorView B,\n    tvm::ffi::TensorView A_sf,\n    tvm::ffi::TensorView B_sf,\n    tvm::ffi::TensorView alpha) {\n  RuntimeCheck(A.device().device_type == kDLCUDA, \"a must be a CUDA tensor\");\n  RuntimeCheck(B.device().device_type == kDLCUDA, \"b must be a CUDA tensor\");\n  RuntimeCheck(A_sf.device().device_type == kDLCUDA, \"scale_a must be a CUDA tensor\");\n  RuntimeCheck(B_sf.device().device_type == kDLCUDA, \"scale_b must be a CUDA tensor\");\n  RuntimeCheck(alpha.device().device_type == kDLCUDA, \"alpha must be a CUDA tensor\");\n  RuntimeCheck(D.device().device_type == kDLCUDA, \"out must be a CUDA tensor\");\n\n  RuntimeCheck(A.device() == B.device(), \"a and b must be on same device\");\n  RuntimeCheck(A.device() == A_sf.device(), \"a and scale_a must be on same device\");\n  RuntimeCheck(A.device() == B_sf.device(), \"a and scale_b must be on same device\");\n  RuntimeCheck(A.device() == alpha.device(), \"a and alpha must be on same device\");\n  RuntimeCheck(A.device() == D.device(), \"a and out must be on same device\");\n\n  RuntimeCheck(A.is_contiguous(), \"a must be contiguous\");\n  RuntimeCheck(B.is_contiguous(), \"b must be contiguous\");\n  RuntimeCheck(A_sf.is_contiguous(), \"scale_a must be contiguous\");\n  RuntimeCheck(B_sf.is_contiguous(), 
\"scale_b must be contiguous\");\n  RuntimeCheck(alpha.is_contiguous(), \"alpha must be contiguous\");\n  RuntimeCheck(D.is_contiguous(), \"out must be contiguous\");\n\n  RuntimeCheck(host::is_type<uint8_t>(A.dtype()), \"a must be uint8\");\n  RuntimeCheck(host::is_type<uint8_t>(B.dtype()), \"b must be uint8\");\n  RuntimeCheck(host::is_type<fp8_e4m3_t>(A_sf.dtype()), \"scale_a must be float8_e4m3fn\");\n  RuntimeCheck(host::is_type<fp8_e4m3_t>(B_sf.dtype()), \"scale_b must be float8_e4m3fn\");\n  RuntimeCheck(host::is_type<float>(alpha.dtype()), \"alpha must be float32\");\n\n  RuntimeCheck(A.dim() == 2, \"a must be a matrix\");\n  RuntimeCheck(B.dim() == 2, \"b must be a matrix\");\n  RuntimeCheck(A_sf.dim() == 2, \"scale_a must be a matrix\");\n  RuntimeCheck(B_sf.dim() == 2, \"scale_b must be a matrix\");\n  RuntimeCheck(alpha.numel() == 1, \"alpha must have exactly one element\");\n\n  RuntimeCheck(\n      A.size(1) == B.size(1),\n      \"a and b shapes cannot be multiplied (\",\n      A.size(0),\n      \"x\",\n      A.size(1),\n      \" and \",\n      B.size(0),\n      \"x\",\n      B.size(1),\n      \")\");\n\n  const auto m = static_cast<int64_t>(A.size(0));\n  const auto n = static_cast<int64_t>(B.size(0));\n  const auto k = static_cast<int64_t>(A.size(1) * 2);\n\n  RuntimeCheck(D.dim() == 2, \"out must be 2D\");\n  RuntimeCheck(D.size(0) == m, \"out first dim must equal m\");\n  RuntimeCheck(D.size(1) == n, \"out second dim must equal n\");\n\n  constexpr int alignment = 32;\n  RuntimeCheck(k % alignment == 0, \"Expected k to be divisible by \", alignment, \", but got k: \", k);\n  RuntimeCheck(n % alignment == 0, \"Expected n to be divisible by \", alignment, \", but got n: \", n);\n\n  auto round_up = [](int64_t x, int64_t y) { return (x + y - 1) / y * y; };\n  const int64_t rounded_m = round_up(m, 128);\n  const int64_t rounded_n = round_up(n, 128);\n  const int64_t rounded_k = round_up(k / 16, 4);\n\n  RuntimeCheck(\n      A_sf.size(1) == 
B_sf.size(1),\n      \"scale_a and scale_b shapes cannot be multiplied (\",\n      A_sf.size(0),\n      \"x\",\n      A_sf.size(1),\n      \" and \",\n      B_sf.size(0),\n      \"x\",\n      B_sf.size(1),\n      \")\");\n  RuntimeCheck(\n      A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,\n      \"scale_a must be padded/swizzled to shape (\",\n      rounded_m,\n      \"x\",\n      rounded_k,\n      \"), got (\",\n      A_sf.size(0),\n      \"x\",\n      A_sf.size(1),\n      \")\");\n  RuntimeCheck(\n      B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,\n      \"scale_b must be padded/swizzled to shape (\",\n      rounded_n,\n      \"x\",\n      rounded_k,\n      \"), got (\",\n      B_sf.size(0),\n      \"x\",\n      B_sf.size(1),\n      \")\");\n\n  const cudaStream_t stream = LaunchKernel::resolve_device(A.device());\n  const int sm_version = getSMVersion(A.device().device_id);\n\n  if (sm_version >= 120) {\n    if (host::is_type<fp16_t>(D.dtype())) {\n      cutlass_fp4_f16_gemm_dispatch_sm120(\n          D, A, B, A_sf, B_sf, alpha, static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), stream);\n    } else if (host::is_type<bf16_t>(D.dtype())) {\n      cutlass_fp4_bf16_gemm_dispatch_sm120(\n          D, A, B, A_sf, B_sf, alpha, static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), stream);\n    } else {\n      Panic(\"Unsupported output data type of nvfp4 mm sm120\");\n    }\n  } else {\n    if (host::is_type<fp16_t>(D.dtype())) {\n      cutlassFp4GemmDispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);\n    } else if (host::is_type<bf16_t>(D.dtype())) {\n      cutlassFp4GemmDispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);\n    } else if (host::is_type<float>(D.dtype())) {\n      cutlassFp4GemmDispatch<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);\n    } else {\n      Panic(\"Unsupported output data type of nvfp4 mm\");\n    }\n  }\n}\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/per_tensor_quant_fp8.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/atomic.cuh>\n#include <sgl_kernel/cta.cuh>\n#include <sgl_kernel/math.cuh>\n#include <sgl_kernel/tile.cuh>\n#include <sgl_kernel/utils.cuh>\n#include <sgl_kernel/vec.cuh>\n#include <sgl_kernel/warp.cuh>\n\n#include <cstddef>\n#include <cstdint>\n\nnamespace {\n\nconstexpr size_t kBlockSize = 256;\n\n// each warp will handle 512B data\ntemplate <typename T>\n__global__ void\nper_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) {\n  using namespace device;\n  constexpr uint32_t VEC_SIZE = 16 / sizeof(T);\n\n  const int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;\n\n  float max_value = 0.0f;\n  if (gid * VEC_SIZE + VEC_SIZE <= num_elements) {\n    using vec_t = AlignedVector<T, VEC_SIZE>;\n    const auto gmem_in = tile::Memory<vec_t>::thread();\n    const auto input_vec = gmem_in.load(input, gid);\n#pragma unroll\n    for (uint32_t i = 0; i < VEC_SIZE; ++i) {\n      const float value = static_cast<float>(input_vec[i]);\n      max_value = math::max(max_value, math::abs(value));\n    }\n  } else if (gid * VEC_SIZE < num_elements) {\n    [[unlikely]];  // poorly aligned case, do not optimize\n    const auto remainder = num_elements - gid * VEC_SIZE;\n    for (uint32_t i = 0; i < remainder; ++i) {\n      const float value = static_cast<float>(input[gid * VEC_SIZE + i]);\n      max_value = math::max(max_value, math::abs(value));\n    }\n  }\n\n  // reduce within block and then atomic reduce between blocks\n  __shared__ float smem[kWarpThreads];\n  cta::reduce_max(max_value, smem);\n  if (threadIdx.x == 0) {\n    const auto max_value = smem[0];\n    atomic::max(output_s, max_value / math::FP8_E4M3_MAX);\n  }\n}\n\n[[maybe_unused]]\nSGL_DEVICE float fp8_e4m3_clip(float val) {\n  namespace math = device::math;\n  return math::max(math::min(val, math::FP8_E4M3_MAX), -math::FP8_E4M3_MAX);\n}\n\ntemplate <typename 
T, typename DST_DTYPE>\n__global__ void per_tensor_quant_fp8_kernel(\n    const T* __restrict__ input,\n    DST_DTYPE* __restrict__ output,\n    const float* __restrict__ scale,\n    const int64_t num_elements) {\n  using namespace device;\n  constexpr uint32_t VEC_SIZE = 16 / sizeof(T);\n\n  const int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;\n  const float scale_val = 1.0f / (*scale);\n\n  if (gid * VEC_SIZE + VEC_SIZE <= num_elements) {\n    using input_vec_t = AlignedVector<T, VEC_SIZE>;\n    using output_vec_t = AlignedVector<DST_DTYPE, VEC_SIZE>;\n    const auto gmem_in = tile::Memory<input_vec_t>::thread();\n    const auto gmem_out = tile::Memory<output_vec_t>::thread();\n    const auto input_vec = gmem_in.load(input, gid);\n    output_vec_t output_vec;\n#pragma unroll\n    for (uint32_t i = 0; i < VEC_SIZE; ++i) {\n      const float value = fp8_e4m3_clip(static_cast<float>(input_vec[i]) * scale_val);\n      output_vec[i] = static_cast<DST_DTYPE>(value);\n    }\n    gmem_out.store(output, output_vec, gid);\n  } else if (gid * VEC_SIZE < num_elements) {\n    [[unlikely]];  // poorly aligned case, do not optimize\n    const auto remainder = num_elements - gid * VEC_SIZE;\n    for (uint32_t i = 0; i < remainder; ++i) {\n      const float value = fp8_e4m3_clip(static_cast<float>(input[gid * VEC_SIZE + i]) * scale_val);\n      output[gid * VEC_SIZE + i] = static_cast<DST_DTYPE>(value);\n    }\n  }\n}\n\ntemplate <bool kIsStatic, typename DType>\nvoid per_tensor_quant_fp8(tvm::ffi::TensorView input, tvm::ffi::TensorView output_q, tvm::ffi::TensorView output_s) {\n  using namespace host;\n\n  auto device = SymbolicDevice{};\n  auto N = SymbolicSize{\"num_elements\"};\n  device.set_options<kDLCUDA>();\n\n  TensorMatcher({N})  //\n      .with_dtype<DType>()\n      .with_device(device)\n      .verify(input);\n  TensorMatcher({N})  //\n      .with_dtype<fp8_e4m3_t>()\n      .with_device(device)\n      .verify(output_q);\n  TensorMatcher({1})  //\n      
.with_dtype<float>()\n      .with_device(device)\n      .verify(output_s);\n\n  const auto num_elements = N.unwrap();\n\n  constexpr size_t kElementsPerBlock = kBlockSize * (16 / sizeof(DType));\n  const uint32_t num_blocks = div_ceil(num_elements, kElementsPerBlock);\n\n  if constexpr (!kIsStatic) {\n    LaunchKernel(num_blocks, kBlockSize, device.unwrap())(\n        per_tensor_absmax_kernel<DType>,\n        static_cast<const DType*>(input.data_ptr()),\n        static_cast<float*>(output_s.data_ptr()),\n        static_cast<int64_t>(num_elements));\n  }\n\n  LaunchKernel(num_blocks, kBlockSize, device.unwrap())(\n      per_tensor_quant_fp8_kernel<DType, fp8_e4m3_t>,\n      static_cast<const DType*>(input.data_ptr()),\n      static_cast<fp8_e4m3_t*>(output_q.data_ptr()),\n      static_cast<const float*>(output_s.data_ptr()),\n      static_cast<int64_t>(num_elements));\n}\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/gemm/per_token_group_quant_8bit.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/atomic.cuh>\n#include <sgl_kernel/cta.cuh>\n#include <sgl_kernel/math.cuh>\n#include <sgl_kernel/tile.cuh>\n#include <sgl_kernel/utils.cuh>\n#include <sgl_kernel/vec.cuh>\n#include <sgl_kernel/warp.cuh>\n\n#include <cstddef>\n#include <cstdint>\n\nnamespace {\n\nconstexpr int kThreadsPerGroup = 16;\n\n__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {\n  unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;\n  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));\n  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));\n  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));\n  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));\n  return val;\n}\n\ntemplate <bool kScaleUE8M0>\nusing scale_packed_t_t = std::conditional_t<kScaleUE8M0, uint32_t, float>;\n\ntemplate <bool kScaleUE8M0>\nusing scale_element_t_t = std::conditional_t<kScaleUE8M0, uint8_t, float>;\n\ntemplate <typename T, typename DST_DTYPE, bool kIsColumnMajor, bool kScaleUE8M0>\n__global__ void per_token_group_quant_8bit_kernel(\n    const T* __restrict__ input,\n    DST_DTYPE* __restrict__ output_q,\n    scale_packed_t_t<kScaleUE8M0>* __restrict__ output_s,\n    const int group_size,\n    const int num_groups,\n    const int groups_per_block,\n    const float eps,\n    const float min_8bit,\n    const float max_8bit,\n    const int num_groups_per_row = 0,\n    const int scale_stride = 0) {\n  using namespace device;\n  namespace math = device::math;\n\n  (void)num_groups;\n\n  const int local_group_id = static_cast<int>(threadIdx.x / kThreadsPerGroup);\n  const int lane_id = threadIdx.x % kThreadsPerGroup;\n\n  const int64_t block_group_id = blockIdx.x * groups_per_block;\n  const int64_t global_group_id = block_group_id + local_group_id;\n  const int64_t block_group_offset = global_group_id * group_size;\n\n  float local_absmax = eps;\n\n  using scale_packed_t = 
scale_packed_t_t<kScaleUE8M0>;\n  using scale_element_t = scale_element_t_t<kScaleUE8M0>;\n  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);\n\n  const T* group_input = input + block_group_offset;\n  DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;\n  scale_element_t* scale_output = nullptr;\n\n  if constexpr (kIsColumnMajor) {\n    constexpr int kElemsPerPack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));\n    const int row_idx = global_group_id / num_groups_per_row;\n    const int col_idx_unpacked = global_group_id % num_groups_per_row;\n    const int col_idx = col_idx_unpacked / kElemsPerPack;\n    const int pack_idx = col_idx_unpacked % kElemsPerPack;\n    scale_output = reinterpret_cast<scale_element_t*>(output_s) +\n                   (col_idx * scale_stride * kElemsPerPack + row_idx * kElemsPerPack + pack_idx);\n  } else {\n    static_assert(!kScaleUE8M0);\n    scale_output = output_s + global_group_id;\n  }\n\n  constexpr uint32_t kVecSize = 16 / sizeof(T);\n  using vec_t = AlignedVector<T, kVecSize>;\n  const auto gmem_in = tile::Memory<vec_t>::thread();\n\n  const int32_t num_vec_elems = group_size / kVecSize;\n\n  for (int32_t i = lane_id; i < num_vec_elems; i += kThreadsPerGroup) {\n    const vec_t input_vec = gmem_in.load(group_input, i);\n\n#pragma unroll\n    for (uint32_t j = 0; j < kVecSize; ++j) {\n      const float val = static_cast<float>(input_vec[j]);\n      local_absmax = math::max(local_absmax, math::abs(val));\n    }\n  }\n\n  local_absmax = GroupReduceMax(local_absmax, lane_id);\n\n  float y_s = local_absmax / max_8bit;\n  if constexpr (kScaleUE8M0) {\n    y_s = exp2f(ceilf(log2f(math::max(y_s, 1e-10f))));\n  }\n\n  scale_element_t y_s_quant;\n  if constexpr (kScaleUE8M0) {\n    y_s_quant = static_cast<uint8_t>(((int)log2f(y_s)) + 127);\n  } else {\n    y_s_quant = y_s;\n  }\n\n  if (lane_id == 0) {\n    *scale_output = y_s_quant;\n  }\n\n  for (int32_t i = 
lane_id; i < num_vec_elems; i += kThreadsPerGroup) {\n    const vec_t input_vec = gmem_in.load(group_input, i);\n\n#pragma unroll\n    for (uint32_t j = 0; j < kVecSize; ++j) {\n      const float val = static_cast<float>(input_vec[j]);\n      const float q_val = math::min(math::max(val / y_s, min_8bit), max_8bit);\n      group_output[i * kVecSize + j] = DST_DTYPE(q_val);\n    }\n  }\n}\n\ninline int compute_groups_per_block(int64_t num_groups) {\n  if (num_groups % 16 == 0) return 16;\n  if (num_groups % 8 == 0) return 8;\n  if (num_groups % 4 == 0) return 4;\n  if (num_groups % 2 == 0) return 2;\n  return 1;\n}\n\ntemplate <typename DType, typename OutType>\nvoid per_token_group_quant_8bit(\n    tvm::ffi::TensorView input,\n    tvm::ffi::TensorView output_q,\n    tvm::ffi::TensorView output_s,\n    int64_t group_size,\n    double eps,\n    double min_8bit,\n    double max_8bit,\n    bool scale_ue8m0) {\n  using namespace host;\n\n  auto device = SymbolicDevice{};\n  auto M = SymbolicSize{\"num_tokens\"};\n  auto K = SymbolicSize{\"hidden_dim\"};\n  device.set_options<kDLCUDA>();\n\n  TensorMatcher({M, K}).with_dtype<DType>().with_device(device).verify(input);\n  TensorMatcher({M, K}).with_dtype<OutType>().with_device(device).verify(output_q);\n\n  const auto num_tokens = M.unwrap();\n  const auto hidden_dim = K.unwrap();\n\n  const int64_t num_groups_per_row = hidden_dim / group_size;\n  const int64_t num_groups = num_tokens * num_groups_per_row;\n\n  const int groups_per_block = compute_groups_per_block(num_groups);\n  const int num_blocks = num_groups / groups_per_block;\n  const int num_threads = groups_per_block * kThreadsPerGroup;\n  const bool is_column_major = output_s.stride(0) < output_s.stride(1);\n  const int scale_stride = output_s.stride(1);\n\n  const float feps = static_cast<float>(eps);\n  const float fmin8 = static_cast<float>(min_8bit);\n  const float fmax8 = static_cast<float>(max_8bit);\n\n  if (is_column_major) {\n    if (scale_ue8m0) {\n      
LaunchKernel(num_blocks, num_threads, input.device())(\n          per_token_group_quant_8bit_kernel<DType, OutType, true, true>,\n          static_cast<const DType*>(input.data_ptr()),\n          static_cast<OutType*>(output_q.data_ptr()),\n          static_cast<uint32_t*>(output_s.data_ptr()),\n          static_cast<int>(group_size),\n          static_cast<int>(num_groups),\n          static_cast<int>(groups_per_block),\n          feps,\n          fmin8,\n          fmax8,\n          static_cast<int>(num_groups_per_row),\n          scale_stride);\n    } else {\n      LaunchKernel(num_blocks, num_threads, input.device())(\n          per_token_group_quant_8bit_kernel<DType, OutType, true, false>,\n          static_cast<const DType*>(input.data_ptr()),\n          static_cast<OutType*>(output_q.data_ptr()),\n          static_cast<float*>(output_s.data_ptr()),\n          static_cast<int>(group_size),\n          static_cast<int>(num_groups),\n          static_cast<int>(groups_per_block),\n          feps,\n          fmin8,\n          fmax8,\n          static_cast<int>(num_groups_per_row),\n          scale_stride);\n    }\n  } else {\n    LaunchKernel(num_blocks, num_threads, input.device())(\n        per_token_group_quant_8bit_kernel<DType, OutType, false, false>,\n        static_cast<const DType*>(input.data_ptr()),\n        static_cast<OutType*>(output_q.data_ptr()),\n        static_cast<float*>(output_s.data_ptr()),\n        static_cast<int>(group_size),\n        static_cast<int>(num_groups),\n        static_cast<int>(groups_per_block),\n        feps,\n        fmin8,\n        fmax8,\n        0,\n        0);\n  }\n}\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/hicache.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/utils.cuh>\n#include <sgl_kernel/vec.cuh>\n\n#include <dlpack/dlpack.h>\n\n#include <algorithm>\n#include <cstdint>\n#include <type_traits>\n\nnamespace device {\n\nnamespace details {\n\ntemplate <int kUnit>\ninline constexpr auto get_mem_package() {\n  if constexpr (kUnit == 16) {\n    return uint4{};\n  } else if constexpr (kUnit == 8) {\n    return uint2{};\n  } else if constexpr (kUnit == 4) {\n    return uint1{};\n  } else {\n    static_assert(kUnit == 16 || kUnit == 8 || kUnit == 4, \"Unsupported memory package size\");\n  }\n}\n\ntemplate <int kUnit>\nusing PackageType = decltype(get_mem_package<kUnit>());\n\nSGL_DEVICE uint1 load_nc(const uint1* __restrict__ src) {\n  uint32_t tmp;\n  asm volatile(\"ld.global.L1::no_allocate.b32 %0,[%1];\" : \"=r\"(tmp) : \"l\"(src));\n  return uint1{tmp};\n}\n\nSGL_DEVICE uint2 load_nc(const uint2* __restrict__ src) {\n  uint32_t tmp0, tmp1;\n  asm volatile(\"ld.global.L1::no_allocate.v2.b32 {%0,%1},[%2];\" : \"=r\"(tmp0), \"=r\"(tmp1) : \"l\"(src));\n  return uint2{tmp0, tmp1};\n}\n\nSGL_DEVICE uint4 load_nc(const uint4* __restrict__ src) {\n  uint32_t tmp0, tmp1, tmp2, tmp3;\n  asm volatile(\"ld.global.L1::no_allocate.v4.b32 {%0,%1,%2,%3},[%4];\"\n               : \"=r\"(tmp0), \"=r\"(tmp1), \"=r\"(tmp2), \"=r\"(tmp3)\n               : \"l\"(src));\n  return uint4{tmp0, tmp1, tmp2, tmp3};\n}\n\nSGL_DEVICE void store_nc(uint1* __restrict__ dst, const uint1& value) {\n  uint32_t tmp = value.x;\n  asm volatile(\"st.global.L1::no_allocate.b32 [%0],%1;\" ::\"l\"(dst), \"r\"(tmp));\n}\n\nSGL_DEVICE void store_nc(uint2* __restrict__ dst, const uint2& value) {\n  uint32_t tmp0 = value.x;\n  uint32_t tmp1 = value.y;\n  asm volatile(\"st.global.L1::no_allocate.v2.b32 [%0],{%1,%2};\" ::\"l\"(dst), \"r\"(tmp0), \"r\"(tmp1));\n}\n\nSGL_DEVICE void store_nc(uint4* __restrict__ dst, const uint4& value) {\n  uint32_t tmp0 = value.x;\n  
uint32_t tmp1 = value.y;\n  uint32_t tmp2 = value.z;\n  uint32_t tmp3 = value.w;\n  asm volatile(\n      \"st.global.L1::no_allocate.v4.b32 [%0],{%1,%2,%3,%4};\" ::\"l\"(dst), \"r\"(tmp0), \"r\"(tmp1), \"r\"(tmp2), \"r\"(tmp3));\n}\n\n}  // namespace details\n\ntemplate <int64_t kBytes, uint32_t kNumThreads>\nSGL_DEVICE auto load_vec(const void* __restrict__ src) {\n  static_assert(kBytes % 128 == 0, \"kBytes must be multiple of 128 bytes\");\n  static_assert(128 % kNumThreads == 0, \"kNumThreads must divide 128 bytes\");\n  constexpr uint32_t kLoopCount = kBytes / 128;\n  using Package = details::PackageType<128 / kNumThreads>;\n  using Storage = AlignedStorage<Package, kLoopCount>;\n\n  const auto src_packed = static_cast<const Package*>(src);\n  const auto lane_id = threadIdx.x % kNumThreads;\n  Storage vec;\n\n#pragma unroll kLoopCount\n  for (uint32_t i = 0; i < kLoopCount; ++i) {\n    const auto j = i * kNumThreads + lane_id;\n    vec.data[i] = details::load_nc(&src_packed[j]);\n  }\n\n  return vec;\n}\n\ntemplate <int64_t kBytes, uint32_t kNumThreads, typename Storage>\nSGL_DEVICE void store_vec(void* __restrict__ dst, const Storage& vec) {\n  using Package = std::decay_t<decltype(vec.data[0])>;\n  constexpr uint32_t kBytesPerLoop = sizeof(Package) * kNumThreads;\n  constexpr uint32_t kLoopCount = kBytes / kBytesPerLoop;\n  static_assert(kBytes % kBytesPerLoop == 0, \"Invalid Storage configuration\");\n\n  const auto dst_packed = static_cast<Package*>(dst);\n  const auto lane_id = threadIdx.x % kNumThreads;\n\n#pragma unroll kLoopCount\n  for (uint32_t i = 0; i < kLoopCount; ++i) {\n    const auto j = i * kNumThreads + lane_id;\n    details::store_nc(&dst_packed[j], vec.data[i]);\n  }\n}\n\n}  // namespace device\n\nnamespace {\n\n#define SGL_HICACHE_KERNEL __global__ __launch_bounds__(kBlockSize, 1)\n\nstruct HicacheKernelParams {\n  void* __restrict__ k_cache_dst;\n  void* __restrict__ v_cache_dst;\n  const void* __restrict__ indices_dst;\n  void* 
__restrict__ k_cache_src;\n  void* __restrict__ v_cache_src;\n  const void* __restrict__ indices_src;\n  int64_t kv_cache_src_stride;\n  int64_t kv_cache_dst_stride;\n  uint32_t length;\n  uint32_t num_layers = 0;  // only used in all_layer transfer\n};\n\ntemplate <typename T, int64_t kElementSize, uint32_t kUnroll, uint32_t kBlockQuota, uint32_t kBlockSize>\nSGL_HICACHE_KERNEL void hicache_transfer_per_layer(const __grid_constant__ HicacheKernelParams params) {\n  using namespace device;\n  static_assert(kBlockSize % kWarpThreads == 0);\n  static_assert(kWarpThreads % kUnroll == 0);\n\n  constexpr uint32_t kNumThreads = kWarpThreads / kUnroll;\n  constexpr uint32_t kWorkersPerBlock = kBlockSize / kNumThreads;\n  constexpr uint32_t kNumWorkers = kWorkersPerBlock * kBlockQuota;\n\n  const auto& [\n    k_cache_dst, v_cache_dst, indices_dst, // dst\n    k_cache_src, v_cache_src, indices_src, // src\n    kv_cache_src_stride, kv_cache_dst_stride, length, _ // metadata\n  ] = params;\n\n  const uint32_t work_id = blockIdx.x * kWorkersPerBlock + threadIdx.x / kNumThreads;\n  for (uint32_t i = work_id; i < length; i += kNumWorkers) {\n    const auto pos_src = static_cast<const T*>(indices_src)[i];\n    const auto pos_dst = static_cast<const T*>(indices_dst)[i];\n    const auto src_k = pointer::offset(k_cache_src, pos_src * kv_cache_src_stride);\n    const auto dst_k = pointer::offset(k_cache_dst, pos_dst * kv_cache_dst_stride);\n    const auto src_v = pointer::offset(v_cache_src, pos_src * kv_cache_src_stride);\n    const auto dst_v = pointer::offset(v_cache_dst, pos_dst * kv_cache_dst_stride);\n    const auto vec_k = load_vec<kElementSize, kNumThreads>(src_k);\n    const auto vec_v = load_vec<kElementSize, kNumThreads>(src_v);\n    store_vec<kElementSize, kNumThreads>(dst_k, vec_k);\n    store_vec<kElementSize, kNumThreads>(dst_v, vec_v);\n  }\n}\n\ntemplate <typename T, int64_t kElementSize, uint32_t kUnroll, uint32_t kBlockQuota, uint32_t 
kBlockSize>\nSGL_HICACHE_KERNEL void hicache_transfer_all_layer(const __grid_constant__ HicacheKernelParams params) {\n  using namespace device;\n  using src_ptr_t = const void*;\n  using dst_ptr_t = void*;\n\n  static_assert(kBlockSize % kWarpThreads == 0);\n  static_assert(kWarpThreads % kUnroll == 0);\n\n  constexpr uint32_t kNumThreads = kWarpThreads / kUnroll;\n  constexpr uint32_t kWorkersPerBlock = kBlockSize / kNumThreads;\n  constexpr uint32_t kNumWorkers = kWorkersPerBlock * kBlockQuota;\n\n  const auto& [\n    k_ptr_dst, v_ptr_dst, indices_dst, // dst\n    k_ptr_src, v_ptr_src, indices_src, // src\n    kv_cache_src_stride, kv_cache_dst_stride, length, num_layers // metadata\n  ] = params;\n\n  const uint32_t work_id = blockIdx.x * kWorkersPerBlock + threadIdx.x / kNumThreads;\n  for (uint32_t i = work_id; i < length; i += kNumWorkers) {\n    const auto pos_src = static_cast<const T*>(indices_src)[i];\n    const auto pos_dst = static_cast<const T*>(indices_dst)[i];\n    for (uint32_t layer = 0; layer < num_layers; ++layer) {\n      const auto k_cache_src = static_cast<const src_ptr_t*>(k_ptr_src)[layer];\n      const auto v_cache_src = static_cast<const src_ptr_t*>(v_ptr_src)[layer];\n      const auto k_cache_dst = static_cast<const dst_ptr_t*>(k_ptr_dst)[layer];\n      const auto v_cache_dst = static_cast<const dst_ptr_t*>(v_ptr_dst)[layer];\n      const auto src_k = pointer::offset(k_cache_src, pos_src * kv_cache_src_stride);\n      const auto dst_k = pointer::offset(k_cache_dst, pos_dst * kv_cache_dst_stride);\n      const auto src_v = pointer::offset(v_cache_src, pos_src * kv_cache_src_stride);\n      const auto dst_v = pointer::offset(v_cache_dst, pos_dst * kv_cache_dst_stride);\n      const auto vec_k = load_vec<kElementSize, kNumThreads>(src_k);\n      const auto vec_v = load_vec<kElementSize, kNumThreads>(src_v);\n      store_vec<kElementSize, kNumThreads>(dst_k, vec_k);\n      store_vec<kElementSize, kNumThreads>(dst_v, vec_v);\n    }\n  
}\n}\n\ntemplate <int64_t kElementSize, uint32_t kUnroll, uint32_t kBlockQuota, uint32_t kBlockSize>\nstruct HiCacheKernel {\n  template <typename T>\n  static constexpr auto kernel_one = hicache_transfer_per_layer<T, kElementSize, kUnroll, kBlockQuota, kBlockSize>;\n  template <typename T>\n  static constexpr auto kernel_all = hicache_transfer_all_layer<T, kElementSize, kUnroll, kBlockQuota, kBlockSize>;\n\n  static void run_one(\n      const tvm::ffi::TensorView k_cache_dst,\n      const tvm::ffi::TensorView v_cache_dst,\n      const tvm::ffi::TensorView indices_dst,\n      const tvm::ffi::TensorView k_cache_src,\n      const tvm::ffi::TensorView v_cache_src,\n      const tvm::ffi::TensorView indices_src) {\n    using namespace host;\n\n    auto D = SymbolicSize{\"head dimension\"};\n    auto N = SymbolicSize{\"src kv stride\"};\n    auto M = SymbolicSize{\"dst kv stride\"};\n    auto L = SymbolicSize{\"indices length\"};\n    auto cache_dtype = SymbolicDType{};\n    auto indices_dtype = SymbolicDType{};\n    auto indices_device = SymbolicDevice{};\n\n    TensorMatcher({-1, D})  //\n        .with_strides({N, 1})\n        .with_dtype(cache_dtype)\n        .with_device<kDLCUDA, kDLCUDAHost, kDLCPU>()\n        .verify(k_cache_src)\n        .verify(v_cache_src);\n    TensorMatcher({-1, D})  //\n        .with_strides({M, 1})\n        .with_dtype(cache_dtype)\n        .with_device<kDLCUDA, kDLCUDAHost, kDLCPU>()\n        .verify(k_cache_dst)\n        .verify(v_cache_dst);\n    TensorMatcher({L})  //\n        .with_dtype<int32_t, int64_t>(indices_dtype)\n        .with_device<kDLCUDA>(indices_device)\n        .verify(indices_src)\n        .verify(indices_dst);\n\n    // verify dimension match\n    const auto dtype_size = dtype_bytes(cache_dtype.unwrap());\n    const auto element_bytes = D.unwrap() * dtype_size;\n    RuntimeCheck(kElementSize == element_bytes, \"HicacheKernel: cache dimension mismatch.\");\n\n    const auto k_cache_dst_ptr = k_cache_dst.data_ptr();\n    
const auto v_cache_dst_ptr = v_cache_dst.data_ptr();\n    const auto k_cache_src_ptr = k_cache_src.data_ptr();\n    const auto v_cache_src_ptr = v_cache_src.data_ptr();\n    const auto indices_dst_ptr = indices_dst.data_ptr();\n    const auto indices_src_ptr = indices_src.data_ptr();\n    const auto length = static_cast<uint32_t>(L.unwrap());\n    const auto kv_cache_src_stride = static_cast<int64_t>(N.unwrap() * dtype_size);\n    const auto kv_cache_dst_stride = static_cast<int64_t>(M.unwrap() * dtype_size);\n    const auto use_int32 = indices_dtype.unwrap().bits == 32;\n    const auto device = indices_device.unwrap();\n\n    constexpr auto kWorkersPerBlock = kBlockSize / (device::kWarpThreads / kUnroll);\n    const auto num_blocks = std::min(div_ceil(length, kWorkersPerBlock), kBlockQuota);\n    const auto params = HicacheKernelParams{\n        .k_cache_dst = k_cache_dst_ptr,\n        .v_cache_dst = v_cache_dst_ptr,\n        .indices_dst = indices_dst_ptr,\n        .k_cache_src = k_cache_src_ptr,\n        .v_cache_src = v_cache_src_ptr,\n        .indices_src = indices_src_ptr,\n        .kv_cache_src_stride = kv_cache_src_stride,\n        .kv_cache_dst_stride = kv_cache_dst_stride,\n        .length = length,\n    };\n    const auto kernel = use_int32 ? 
kernel_one<int32_t> : kernel_one<int64_t>;\n    LaunchKernel(num_blocks, kBlockSize, device)(kernel, params);\n  }\n\n  static void run_all(\n      const tvm::ffi::TensorView k_ptr_dst,\n      const tvm::ffi::TensorView v_ptr_dst,\n      const tvm::ffi::TensorView indices_dst,\n      const tvm::ffi::TensorView k_ptr_src,\n      const tvm::ffi::TensorView v_ptr_src,\n      const tvm::ffi::TensorView indices_src,\n      const int64_t kv_src_stride_bytes,\n      const int64_t kv_dst_stride_bytes) {\n    using namespace host;\n\n    auto N = SymbolicSize{\"num_layers\"};\n    auto L = SymbolicSize{\"indices length\"};\n    auto dtype_ = SymbolicDType{};\n    auto device_ = SymbolicDevice{};\n\n    TensorMatcher({N})  //\n        .with_dtype<uint64_t>()\n        .with_device<kDLCUDA>(device_)\n        .verify(k_ptr_src)\n        .verify(v_ptr_src)\n        .verify(k_ptr_dst)\n        .verify(v_ptr_dst);\n    TensorMatcher({L})  //\n        .with_dtype<int32_t, int64_t>(dtype_)\n        .with_device<kDLCUDA>(device_)\n        .verify(indices_src)\n        .verify(indices_dst);\n\n    // collect raw pointers and launch metadata\n    const auto k_cache_dst_ptr = k_ptr_dst.data_ptr();\n    const auto v_cache_dst_ptr = v_ptr_dst.data_ptr();\n    const auto k_cache_src_ptr = k_ptr_src.data_ptr();\n    const auto v_cache_src_ptr = v_ptr_src.data_ptr();\n    const auto indices_dst_ptr = indices_dst.data_ptr();\n    const auto indices_src_ptr = indices_src.data_ptr();\n    const auto length = static_cast<uint32_t>(L.unwrap());\n    const auto use_int32 = dtype_.unwrap().bits == 32;\n    const auto device = device_.unwrap();\n\n    constexpr auto kWorkersPerBlock = kBlockSize / (device::kWarpThreads / kUnroll);\n    const auto num_blocks = std::min(div_ceil(length, kWorkersPerBlock), kBlockQuota);\n    const auto params = HicacheKernelParams{\n        .k_cache_dst = k_cache_dst_ptr,\n        .v_cache_dst = v_cache_dst_ptr,\n        .indices_dst = indices_dst_ptr,\n        .k_cache_src = 
k_cache_src_ptr,\n        .v_cache_src = v_cache_src_ptr,\n        .indices_src = indices_src_ptr,\n        .kv_cache_src_stride = kv_src_stride_bytes,\n        .kv_cache_dst_stride = kv_dst_stride_bytes,\n        .length = length,\n        .num_layers = static_cast<uint32_t>(N.unwrap()),\n    };\n    const auto kernel = use_int32 ? kernel_all<int32_t> : kernel_all<int64_t>;\n    LaunchKernel(num_blocks, kBlockSize, device)(kernel, params);\n  }\n};\n\n#undef SGL_HICACHE_KERNEL\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu",
    "content": "// Temporarily adapted from https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu, will\n// optimize in future refactor\n\n#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/utils.cuh>\n\n#include <cub/cub.cuh>\n#include <tvm/ffi/container/tensor.h>\n\n#include <algorithm>\n\n#ifndef WARP_SIZE\n#define WARP_SIZE 32\n#endif\n\n#define CEILDIV(x, y) (((x) + (y) - 1) / (y))\n\nnamespace moe {\n\ntemplate <typename scalar_t>\nSGL_DEVICE void _moe_align_block_size(\n    const scalar_t* __restrict__ topk_ids,\n    int32_t* __restrict__ sorted_token_ids,\n    int32_t* __restrict__ expert_ids,\n    int32_t* __restrict__ total_tokens_post_pad,\n    int32_t* __restrict__ expert_map,\n    int32_t num_experts,\n    int32_t padded_num_experts,\n    int32_t experts_per_warp,\n    int32_t block_size,\n    size_t numel,\n    int32_t* __restrict__ cumsum,\n    int32_t max_num_tokens_padded,\n    int32_t max_num_m_blocks,\n    int32_t model_offset,\n    int32_t inactive_expert_id,\n    int32_t topk_num,\n    int32_t* token_mask,\n    bool has_expert_map) {\n  extern __shared__ int32_t shared_counts[];\n\n  // Compute input buffer offsets. 
Typically these will all be 0, except when\n  // using Multi LoRA.\n  int sorted_token_ids_offset = max_num_tokens_padded * model_offset;\n  int expert_ids_offset = max_num_m_blocks * model_offset;\n  int cumsum_offset = (num_experts + 1) * model_offset;\n\n  // Use separate threadblocks to fill sorted_token_ids.\n  // This is safe since the current kernel does not use sorted_token_ids.\n  if (blockIdx.x % 2) {\n    // Initialize sorted_token_ids with numel\n    for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {\n      sorted_token_ids[sorted_token_ids_offset + it] = static_cast<int32_t>(numel);\n    }\n    return;\n  }\n\n  const int warp_id = threadIdx.x / WARP_SIZE;\n  const int my_expert_start = warp_id * experts_per_warp;\n\n  for (int i = 0; i < experts_per_warp; ++i) {\n    if (my_expert_start + i < padded_num_experts) {\n      shared_counts[warp_id * experts_per_warp + i] = 0;\n    }\n  }\n\n  __syncthreads();\n\n  const size_t tid = threadIdx.x;\n  const size_t stride = blockDim.x;\n\n  for (size_t i = tid; i < numel; i += stride) {\n    int expert_id = topk_ids[i];\n    if (expert_id < 0 || expert_id >= num_experts) {\n      continue;\n    }\n    if (has_expert_map) {\n      expert_id = expert_map[expert_id];\n      if (expert_id < 0 || expert_id >= num_experts) continue;\n    }\n    int warp_idx = expert_id / experts_per_warp;\n    int expert_offset = expert_id % experts_per_warp;\n    int mask = token_mask == nullptr ? 
1 : token_mask[i / topk_num];\n    atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], mask);\n  }\n\n  __syncthreads();\n\n  // Compute prefix sum over token counts per expert\n  using BlockScan = cub::BlockScan<int32_t, 1024>;\n  __shared__ typename BlockScan::TempStorage temp_storage;\n\n  int expert_count = 0;\n  int expert_id = threadIdx.x;\n  if (expert_id < num_experts) {\n    int warp_idx = expert_id / experts_per_warp;\n    int expert_offset = expert_id % experts_per_warp;\n    expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];\n    expert_count = CEILDIV(expert_count, block_size) * block_size;\n  }\n\n  int cumsum_val;\n  BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val);\n  if (expert_id <= num_experts) {\n    cumsum[cumsum_offset + expert_id] = cumsum_val;\n  }\n\n  if (expert_id == num_experts) {\n    total_tokens_post_pad[model_offset] = cumsum_val;\n  }\n\n  __syncthreads();\n\n  if (threadIdx.x < num_experts) {\n    for (int i = cumsum[cumsum_offset + threadIdx.x]; i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) {\n      expert_ids[expert_ids_offset + i / block_size] = threadIdx.x;\n    }\n  }\n\n  // Fill remaining expert_ids with inactive_expert_id\n  const size_t fill_start_idx = cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x;\n  for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) {\n    expert_ids[expert_ids_offset + i] = inactive_expert_id;\n  }\n}\n\ntemplate <typename scalar_t, int32_t fill_threads>\nSGL_DEVICE void _moe_align_block_size_small_batch_expert(\n    const scalar_t* __restrict__ topk_ids,\n    int32_t* __restrict__ sorted_token_ids,\n    int32_t* __restrict__ expert_ids,\n    int32_t* __restrict__ total_tokens_post_pad,\n    int32_t* __restrict__ expert_map,\n    int32_t num_experts,\n    int32_t block_size,\n    size_t numel,\n    int32_t max_num_tokens_padded,\n    int32_t max_num_m_blocks,\n  
  int32_t inactive_expert_id,\n    int32_t model_offset,\n    int32_t topk_num,\n    int32_t* token_mask,\n    bool has_expert_map) {\n  // Compute input buffer offsets. Typically these will all be 0, except when\n  // using Multi LoRA.\n  int sorted_token_ids_offset = max_num_tokens_padded * model_offset;\n  int expert_ids_offset = max_num_m_blocks * model_offset;\n\n  // Use an additional group of threads to fill sorted_token_ids.\n  // Since the current kernel will use sorted_token_ids afterward,\n  // we fill sorted_token_ids within the same threadblock to make\n  // synchronization easier.\n  if (threadIdx.x < fill_threads) {\n    // Initialize sorted_token_ids with numel\n    for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += fill_threads) {\n      sorted_token_ids[sorted_token_ids_offset + it] = static_cast<int32_t>(numel);\n    }\n    // Three __syncthreads() corresponding to the other threads\n    __syncthreads();\n    __syncthreads();\n    __syncthreads();\n    return;\n  }\n\n  const size_t tid = threadIdx.x - fill_threads;\n  const size_t stride = blockDim.x - fill_threads;\n\n  extern __shared__ int32_t shared_mem[];\n  int32_t* cumsum = shared_mem;\n  int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);\n\n  for (int i = 0; i < num_experts; ++i) {\n    tokens_cnts[(tid + 1) * num_experts + i] = 0;\n  }\n\n  for (size_t i = tid; i < numel; i += stride) {\n    int32_t expert_id = topk_ids[i];\n    if (expert_id < 0 || expert_id >= num_experts) continue;\n    if (has_expert_map) {\n      expert_id = expert_map[expert_id];\n      if (expert_id < 0 || expert_id >= num_experts) continue;\n    }\n    int mask = token_mask == nullptr ? 
1 : token_mask[i / topk_num];\n    tokens_cnts[(tid + 1) * num_experts + expert_id] += mask;\n  }\n\n  __syncthreads();\n\n  if (tid < num_experts) {\n    tokens_cnts[tid] = 0;\n    for (int i = 1; i <= stride; ++i) {\n      tokens_cnts[i * num_experts + tid] += tokens_cnts[(i - 1) * num_experts + tid];\n    }\n  }\n\n  __syncthreads();\n\n  if (tid == 0) {\n    cumsum[0] = 0;\n    for (int i = 1; i <= num_experts; ++i) {\n      cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * block_size;\n    }\n    total_tokens_post_pad[model_offset] = static_cast<int32_t>(cumsum[num_experts]);\n  }\n\n  __syncthreads();\n\n  if (tid < num_experts) {\n    for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) {\n      expert_ids[expert_ids_offset + i / block_size] = tid;\n    }\n  }\n\n  // Fill remaining expert_ids with inactive_expert_id\n  const size_t fill_start_idx = cumsum[num_experts] / block_size + tid;\n  for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) {\n    expert_ids[expert_ids_offset + i] = inactive_expert_id;\n  }\n\n  for (size_t i = tid; i < numel; i += stride) {\n    int32_t expert_id = topk_ids[i];\n    if (expert_id < 0 || expert_id >= num_experts) continue;\n    if (has_expert_map) {\n      expert_id = expert_map[expert_id];\n      if (expert_id < 0 || expert_id >= num_experts) continue;\n    }\n    int32_t rank_post_pad = tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id];\n\n    if (token_mask == nullptr || token_mask[i / topk_num]) {\n      sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i;\n      ++tokens_cnts[tid * num_experts + expert_id];\n    }\n  }\n}\n\ntemplate <typename scalar_t>\nSGL_DEVICE void _count_and_sort_expert_tokens(\n    const scalar_t* __restrict__ topk_ids,\n    int32_t* __restrict__ sorted_token_ids,\n    int32_t* __restrict__ cumsum_buffer,\n    int32_t* __restrict__ expert_map,\n    size_t numel,\n    int32_t num_experts,\n    int32_t 
max_num_tokens_padded,\n    int32_t* __restrict__ token_mask,\n    int32_t model_offset,\n    int32_t topk_num,\n    bool has_expert_map) {\n  const size_t tid = blockIdx.y * blockDim.x + threadIdx.x;\n  const size_t stride = blockDim.x * gridDim.y;\n\n  for (size_t i = tid; i < numel; i += stride) {\n    int32_t expert_id = topk_ids[i];\n    if (expert_id >= num_experts) {\n      continue;\n    }\n\n    if (has_expert_map) {\n      expert_id = expert_map[expert_id];\n      // filter invalid experts\n      if (expert_id == -1) continue;\n    }\n\n    if (token_mask == nullptr || token_mask[i / topk_num]) {\n      int32_t rank_post_pad = atomicAdd(&cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1);\n      sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = i;\n    }\n  }\n}\n\ntemplate <typename scalar_t>\n__global__ void moe_lora_align_block_size_kernel(\n    scalar_t* __restrict__ topk_ids,\n    int32_t* __restrict__ seg_indptr,\n    int32_t* __restrict__ req_to_lora,\n    int num_reqs,\n    int64_t block_size,\n    int32_t* __restrict__ expert_map,\n    int num_experts,\n    int max_loras,\n    size_t numel,\n    int max_num_tokens_padded,\n    int max_num_m_blocks,\n    int32_t* __restrict__ sorted_token_ids,\n    int32_t* __restrict__ expert_ids,\n    int32_t topk_num,\n    int32_t* total_tokens_post_pad,\n    int32_t* adapter_enabled,\n    int32_t* __restrict__ cumsum,\n    int32_t experts_per_warp,\n    int32_t padded_num_experts,\n    int32_t* lora_ids,\n    int32_t* __restrict__ token_mask,\n    bool has_expert_map) {\n  int lora_idx = blockIdx.x / 2;\n  int lora_id = lora_ids[lora_idx];\n  if (lora_id == -1 || adapter_enabled[lora_id] == 0) {\n    return;\n  }\n\n  int num_tokens = numel / topk_num;\n  int lora_offset = lora_id * num_tokens;\n\n  if (blockIdx.x % 2 == 0) {\n    // 1. 
Parallel Clear (Reset mask to 0)\n    for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {\n      token_mask[lora_offset + i] = 0;\n    }\n\n    if (threadIdx.x == 0) {\n      total_tokens_post_pad[lora_id] = 0;\n    }\n\n    __syncthreads();\n\n    // 2. Segment-based Fill\n    for (int r = 0; r < num_reqs; ++r) {\n      if (req_to_lora[r] == lora_id) {\n        int start = seg_indptr[r];\n        int end = seg_indptr[r + 1];\n        for (int i = start + threadIdx.x; i < end; i += blockDim.x) {\n          token_mask[lora_offset + i] = 1;\n        }\n      }\n    }\n\n    __syncthreads();\n  }\n\n  _moe_align_block_size(\n      topk_ids,\n      sorted_token_ids,\n      expert_ids,\n      total_tokens_post_pad,\n      expert_map,\n      num_experts,\n      padded_num_experts,\n      experts_per_warp,\n      block_size,\n      numel,\n      cumsum,\n      max_num_tokens_padded,\n      max_num_m_blocks,\n      lora_id,\n      -1,  // inactive_expert_id padding\n      topk_num,\n      &token_mask[(lora_id * num_tokens)],\n      has_expert_map);\n}\n\ntemplate <typename scalar_t>\n__global__ void lora_count_and_sort_expert_tokens_kernel(\n    const scalar_t* __restrict__ topk_ids,\n    int32_t* __restrict__ sorted_token_ids,\n    int32_t* __restrict__ cumsum_buffer,\n    int32_t* __restrict__ expert_map,\n    size_t numel,\n    int32_t num_experts,\n    int32_t max_num_tokens_padded,\n    int32_t topk_num,\n    int32_t* token_mask,\n    int32_t* lora_ids,\n    int32_t* adapter_enabled,\n    bool has_expert_map) {\n  int lora_idx = blockIdx.x;\n  int lora_id = lora_ids[lora_idx];\n  if (lora_id == -1 || adapter_enabled[lora_id] == 0) {\n    return;\n  }\n\n  int num_tokens = numel / topk_num;\n\n  _count_and_sort_expert_tokens(\n      topk_ids,\n      sorted_token_ids,\n      cumsum_buffer,\n      expert_map,\n      numel,\n      num_experts,\n      max_num_tokens_padded,\n      &token_mask[(lora_id * num_tokens)],\n      lora_id,\n      topk_num,\n      
has_expert_map);\n}\n\ntemplate <typename scalar_t, int32_t fill_threads>\n__global__ void moe_lora_align_block_size_small_batch_expert_kernel(\n    scalar_t* __restrict__ topk_ids,\n    int32_t* __restrict__ seg_indptr,\n    int32_t* __restrict__ req_to_lora,\n    int num_reqs,\n    int64_t block_size,\n    int32_t* __restrict__ expert_map,\n    int num_experts,\n    int max_loras,\n    size_t numel,\n    int max_num_tokens_padded,\n    int max_num_m_blocks,\n    int32_t* __restrict__ sorted_token_ids,\n    int32_t* __restrict__ expert_ids,\n    int topk_num,\n    int32_t* total_tokens_post_pad,\n    int32_t* adapter_enabled,\n    int32_t* lora_ids,\n    int32_t* token_mask,\n    bool has_expert_map) {\n  int lora_idx = blockIdx.x;\n  int lora_id = lora_ids[lora_idx];\n  if (lora_id == -1 || adapter_enabled[lora_id] == 0) {\n    return;\n  }\n\n  int num_tokens = numel / topk_num;\n  int lora_offset = lora_id * num_tokens;\n\n  // 1. Parallel Clear (Reset mask to 0)\n  // All threads help clear the mask for this adapter\n  for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {\n    token_mask[lora_offset + i] = 0;\n  }\n\n  // Initialize output counter\n  if (threadIdx.x == 0) {\n    total_tokens_post_pad[lora_id] = 0;\n  }\n\n  __syncthreads();\n\n  // 2. Segment-based Fill\n  // Iterate over requests. 
If a request matches this LoRA, fill its range.\n  for (int r = 0; r < num_reqs; ++r) {\n    if (req_to_lora[r] == lora_id) {\n      int start = seg_indptr[r];\n      int end = seg_indptr[r + 1];\n\n      // Parallel Fill: All threads help mark this segment as \"1\"\n      for (int i = start + threadIdx.x; i < end; i += blockDim.x) {\n        token_mask[lora_offset + i] = 1;\n      }\n    }\n  }\n\n  __syncthreads();\n\n  _moe_align_block_size_small_batch_expert<scalar_t, fill_threads>(\n      topk_ids,\n      sorted_token_ids,\n      expert_ids,\n      total_tokens_post_pad,\n      expert_map,\n      num_experts,\n      block_size,\n      numel,\n      max_num_tokens_padded,\n      max_num_m_blocks,\n      -1,  // inactive_expert_id padding\n      lora_id,\n      topk_num,\n      &token_mask[(lora_id * num_tokens)],\n      has_expert_map);\n}\n\n}  // namespace moe\n\nnamespace {\n\ntemplate <typename scalar_t>\nstruct MoeLoraAlignBlockSizeKernel {\n  static void\n  run(tvm::ffi::TensorView topk_ids,\n      tvm::ffi::TensorView seg_indptr,\n      tvm::ffi::TensorView req_to_lora,\n      int64_t num_experts,\n      int64_t block_size,\n      int64_t max_loras,\n      int64_t max_num_tokens_padded,\n      int64_t max_num_m_blocks,\n      tvm::ffi::TensorView sorted_token_ids,\n      tvm::ffi::TensorView expert_ids,\n      tvm::ffi::TensorView num_tokens_post_pad,\n      tvm::ffi::TensorView adapter_enabled,\n      tvm::ffi::TensorView lora_ids,\n      tvm::ffi::Optional<tvm::ffi::TensorView> maybe_expert_map,\n      tvm::ffi::TensorView cumsum_buffer,\n      tvm::ffi::TensorView token_mask) {\n    using namespace host;\n\n    const int topk_num = topk_ids.size(1);\n\n    RuntimeCheck(block_size > 0, \"block_size should be greater than 0. 
\");\n\n    int device_max_shared_mem;\n    auto device = topk_ids.device();\n    int dev_id = device.device_id;\n    RuntimeDeviceCheck(cudaDeviceGetAttribute(&device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_id));\n    const cudaStream_t stream = LaunchKernel::resolve_device(device);\n\n    int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;\n\n    // BlockScan uses 1024 threads and assigns one thread per expert.\n    RuntimeCheck(padded_num_experts < 1024, \"padded_num_experts must be less than 1024\");\n\n    int32_t* token_mask_ptr = static_cast<int32_t*>(token_mask.data_ptr());\n\n    bool has_expert_map = maybe_expert_map.has_value();\n    int32_t* expert_map_ptr = nullptr;\n    if (has_expert_map) {\n      expert_map_ptr = static_cast<int32_t*>(maybe_expert_map.value().data_ptr());\n    }\n    int num_reqs = seg_indptr.size(0) - 1;\n\n    bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64);\n\n    if (small_batch_expert_mode) {\n      const int32_t num_thread = std::max((int32_t)num_experts, 128);\n      const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) + (num_experts + 1) * sizeof(int32_t);\n      RuntimeCheck(shared_mem <= device_max_shared_mem, \"Shared memory usage exceeds device limit.\");\n\n      // threadIdx.x >= fill_threads: counting experts and aligning\n      // threadIdx.x < fill_threads: filling sorted_token_ids\n      constexpr int32_t fill_threads = 256;\n\n      dim3 blockDim(num_thread + fill_threads);\n      auto kernel = moe::moe_lora_align_block_size_small_batch_expert_kernel<scalar_t, fill_threads>;\n      RuntimeDeviceCheck(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem));\n\n      LaunchKernel(dim3(max_loras), blockDim, stream, shared_mem)(\n          kernel,\n          static_cast<scalar_t*>(topk_ids.data_ptr()),\n          
static_cast<int32_t*>(seg_indptr.data_ptr()),\n          static_cast<int32_t*>(req_to_lora.data_ptr()),\n          num_reqs,\n          block_size,\n          expert_map_ptr,\n          num_experts,\n          max_loras,\n          topk_ids.numel(),\n          max_num_tokens_padded,\n          max_num_m_blocks,\n          static_cast<int32_t*>(sorted_token_ids.data_ptr()),\n          static_cast<int32_t*>(expert_ids.data_ptr()),\n          topk_num,\n          static_cast<int32_t*>(num_tokens_post_pad.data_ptr()),\n          static_cast<int32_t*>(adapter_enabled.data_ptr()),\n          static_cast<int32_t*>(lora_ids.data_ptr()),\n          token_mask_ptr,\n          has_expert_map);\n\n    } else {\n      int num_thread = 1024;\n      dim3 blockDim(num_thread);\n      size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE);\n\n      size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t);\n\n      auto align_kernel = moe::moe_lora_align_block_size_kernel<scalar_t>;\n\n      // launch two threadblocks for each lora\n      // blockIdx.x % 2 == 0: counting experts and aligning\n      // blockIdx.x % 2 == 1: filling sorted_token_ids\n      LaunchKernel(dim3(max_loras * 2), blockDim, stream, shared_mem_size)(\n          align_kernel,\n          static_cast<scalar_t*>(topk_ids.data_ptr()),\n          static_cast<int32_t*>(seg_indptr.data_ptr()),\n          static_cast<int32_t*>(req_to_lora.data_ptr()),\n          num_reqs,\n          block_size,\n          expert_map_ptr,\n          num_experts,\n          max_loras,\n          topk_ids.numel(),\n          max_num_tokens_padded,\n          max_num_m_blocks,\n          static_cast<int32_t*>(sorted_token_ids.data_ptr()),\n          static_cast<int32_t*>(expert_ids.data_ptr()),\n          topk_num,\n          static_cast<int32_t*>(num_tokens_post_pad.data_ptr()),\n          static_cast<int32_t*>(adapter_enabled.data_ptr()),\n          static_cast<int32_t*>(cumsum_buffer.data_ptr()),\n          WARP_SIZE,\n        
  padded_num_experts,\n          static_cast<int32_t*>(lora_ids.data_ptr()),\n          token_mask_ptr,\n          has_expert_map);\n\n      const int block_threads = std::min(256, (int)num_thread);\n      const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;\n\n      const int max_blocks = 65535;\n      const int actual_blocks = std::min(num_blocks, max_blocks);\n\n      dim3 gridDims(max_loras, actual_blocks);\n      auto sort_kernel = moe::lora_count_and_sort_expert_tokens_kernel<scalar_t>;\n\n      LaunchKernel(gridDims, dim3(block_threads), stream)(\n          sort_kernel,\n          static_cast<scalar_t*>(topk_ids.data_ptr()),\n          static_cast<int32_t*>(sorted_token_ids.data_ptr()),\n          static_cast<int32_t*>(cumsum_buffer.data_ptr()),\n          expert_map_ptr,\n          topk_ids.numel(),\n          num_experts,\n          max_num_tokens_padded,\n          topk_num,\n          token_mask_ptr,\n          static_cast<int32_t*>(lora_ids.data_ptr()),\n          static_cast<int32_t*>(adapter_enabled.data_ptr()),\n          has_expert_map);\n    }\n  }\n};\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/moe/nvfp4_blockwise_moe.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/runtime.cuh>\n#include <sgl_kernel/utils.cuh>\n\n#include <cutlass/arch/arch.h>\n#include <cutlass/cutlass.h>\n\n#include \"cute/tensor.hpp\"\n#include \"cutlass/epilogue/collective/collective_builder.hpp\"\n#include \"cutlass/epilogue/collective/default_epilogue.hpp\"\n#include \"cutlass/epilogue/thread/linear_combination.h\"\n#include \"cutlass/gemm/collective/collective_builder.hpp\"\n#include \"cutlass/gemm/device/gemm_universal_adapter.h\"\n#include \"cutlass/gemm/dispatch_policy.hpp\"\n#include \"cutlass/gemm/group_array_problem_shape.hpp\"\n#include \"cutlass/gemm/kernel/gemm_universal.hpp\"\n#include \"cutlass/tensor_ref.h\"\n#include \"cutlass/util/command_line.h\"\n#include \"cutlass/util/distribution.h\"\n#include \"cutlass/util/host_tensor.h\"\n#include \"cutlass/util/packed_stride.hpp\"\n#include \"cutlass/util/reference/device/gemm.h\"\n#include \"cutlass/util/reference/device/tensor_compare.h\"\n#include \"cutlass/util/reference/host/gett.hpp\"\n#include \"cutlass/util/reference/host/tensor_compare.h\"\n#include \"cutlass/util/reference/host/tensor_fill.h\"\n#include \"cutlass/util/reference/host/tensor_norm.h\"\n#include \"cutlass/util/tensor_view_io.h\"\n#include <algorithm>\n#include <cassert>\n#include <cstdint>\n#include <limits>\n#include <unordered_map>\n\nusing namespace host;\nusing namespace cute;\n\nstruct WorkspaceKey {\n  int device_id;\n  uintptr_t stream;\n  auto operator==(const WorkspaceKey&) const -> bool = default;\n};\n\nstruct WorkspaceKeyHash {\n  auto operator()(const WorkspaceKey& key) const -> size_t {\n    size_t h1 = std::hash<int>{}(key.device_id);\n    size_t h2 = std::hash<uintptr_t>{}(key.stream);\n    return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));\n  }\n};\n\nstruct WorkspaceState {\n  void* ptr = nullptr;\n  size_t bytes = 0;\n};\n\ninline auto get_cached_workspace(size_t required_bytes, int 
device_id, cudaStream_t stream) -> void* {\n  if (required_bytes == 0) {\n    return nullptr;\n  }\n\n  thread_local std::unordered_map<WorkspaceKey, WorkspaceState, WorkspaceKeyHash> cache;\n  WorkspaceKey key{device_id, reinterpret_cast<uintptr_t>(stream)};\n  auto& ws = cache[key];\n\n  if (ws.ptr != nullptr && ws.bytes >= required_bytes) {\n    return ws.ptr;\n  }\n\n  RuntimeDeviceCheck(cudaSetDevice(device_id));\n  if (ws.ptr != nullptr) {\n    RuntimeDeviceCheck(cudaFreeAsync(ws.ptr, stream));\n    ws.ptr = nullptr;\n    ws.bytes = 0;\n  }\n  RuntimeDeviceCheck(cudaMallocAsync(&ws.ptr, required_bytes, stream));\n  ws.bytes = required_bytes;\n  return ws.ptr;\n}\n\ninline int getSMVersion(int device_id) {\n  int sm_major = 0;\n  int sm_minor = 0;\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device_id));\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device_id));\n  return sm_major * 10 + sm_minor;\n}\n\ntemplate <\n    typename ElementAB,\n    typename ElementC,\n    typename ElementSF,\n    typename ElementAccumulator,\n    typename LayoutSFA,\n    typename LayoutSFB,\n    typename ScaleConfig>\n__global__ void __get_group_gemm_starts(\n    ElementAB** a_offsets,\n    ElementAB** b_offsets,\n    ElementC** out_offsets,\n    ElementSF** a_scales_offsets,\n    ElementSF** b_scales_offsets,\n    ElementAccumulator** alpha_offsets,\n    LayoutSFA* layout_sfa_base_as_int,\n    LayoutSFB* layout_sfb_base_as_int,\n    ElementAB* a_base_as_int,\n    ElementAB* b_base_as_int,\n    ElementC* out_base_as_int,\n    ElementSF* a_scales_base_as_int,\n    ElementSF* b_scales_base_as_int,\n    ElementAccumulator* alphas_base_as_int,\n    const int32_t* expert_offsets,\n    const int32_t* sf_offsets,\n    const int32_t* problem_sizes_as_shapes,\n    const int K,\n    const int N) {\n  int64_t expert_id = threadIdx.x;\n  if (expert_id >= gridDim.x * blockDim.x) {\n    return;\n  }\n  // 
Originally int32_t but upcasting to int64_t to avoid overflow\n  // during offset calculations\n  int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);\n  int64_t sf_offset = static_cast<int64_t>(sf_offsets[expert_id]);\n  // Number of FP4 elements along K that share one block scale factor.\n  int64_t group_size = 16;\n  int64_t m = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3]);\n  int64_t n = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 1]);\n  int64_t k = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 2]);\n  assert((m >= 0 && n == N && k == K && k % 2 == 0) && \"unexpected problem sizes\");\n\n  int64_t half_k = static_cast<int64_t>(k / 2);\n  int64_t group_k = static_cast<int64_t>(k / group_size);\n  // Shape of A as uint8/byte = [M, K // 2]\n  // Shape of B as uint8/byte = [E, N, K // 2]\n  a_offsets[expert_id] = a_base_as_int + expert_offset * half_k;\n\n  b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k;\n  // Shape of C = [M, N]\n  out_offsets[expert_id] = out_base_as_int + expert_offset * n;\n  // Shape of a_scale = [sum(sf_sizes), K // group_size]\n  a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k;\n\n  assert((reinterpret_cast<uintptr_t>(a_scales_offsets[expert_id]) % 128) == 0 && \"TMA requires 128-byte alignment\");\n\n  // Shape of B scale = [E, N, K // group_size]\n  b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k;\n  assert((reinterpret_cast<uintptr_t>(b_scales_offsets[expert_id]) % 128) == 0 && \"TMA requires 128-byte alignment\");\n  // Shape of alpha = [E]\n  alpha_offsets[expert_id] = alphas_base_as_int + expert_id;\n\n  LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;\n  LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;\n\n  *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(\n      cute::make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));\n  *layout_sfb_ptr = 
ScaleConfig::tile_atom_to_shape_SFB(\n      cute::make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));\n}\n\n#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(                                                            \\\n    ELEMENT_AB_TYPE, SF_TYPE, TYPE_CHECK, C_TYPE, LayoutSFA, LayoutSFB, ScaleConfig)                    \\\n  else if (TYPE_CHECK) {                                                                                \\\n    __get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, LayoutSFA, LayoutSFB, ScaleConfig> \\\n        <<<1, num_experts, 0, stream>>>(                                                                \\\n            static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()),                                        \\\n            static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()),                                        \\\n            static_cast<C_TYPE**>(out_starts.data_ptr()),                                               \\\n            static_cast<SF_TYPE**>(a_scales_starts.data_ptr()),                                         \\\n            static_cast<SF_TYPE**>(b_scales_starts.data_ptr()),                                         \\\n            static_cast<float**>(alpha_starts.data_ptr()),                                              \\\n            reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),                                        \\\n            reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()),                                        \\\n            static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()),                                        \\\n            static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()),                                        \\\n            static_cast<C_TYPE*>(out_tensors.data_ptr()),                                               \\\n            static_cast<SF_TYPE*>(a_scales.data_ptr()),                                                 \\\n            
static_cast<SF_TYPE*>(b_scales.data_ptr()),                                                 \\\n            static_cast<float*>(alphas.data_ptr()),                                                     \\\n            static_cast<int32_t*>(expert_offsets.data_ptr()),                                           \\\n            static_cast<int32_t*>(sf_offsets.data_ptr()),                                               \\\n            static_cast<int32_t*>(problem_sizes.data_ptr()),                                            \\\n            K,                                                                                          \\\n            N);                                                                                         \\\n  }\n\ntemplate <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>\nvoid run_get_group_gemm_starts(\n    const tvm::ffi::TensorView a_starts,\n    const tvm::ffi::TensorView b_starts,\n    const tvm::ffi::TensorView out_starts,\n    const tvm::ffi::TensorView a_scales_starts,\n    const tvm::ffi::TensorView b_scales_starts,\n    const tvm::ffi::TensorView alpha_starts,\n    const tvm::ffi::TensorView layout_sfa,\n    const tvm::ffi::TensorView layout_sfb,\n    /*these are used for their base addresses*/\n    tvm::ffi::TensorView const& a_tensors,\n    tvm::ffi::TensorView const& b_tensors,\n    tvm::ffi::TensorView const& out_tensors,\n    tvm::ffi::TensorView const& a_scales,\n    tvm::ffi::TensorView const& b_scales,\n    tvm::ffi::TensorView const& alphas,\n    tvm::ffi::TensorView const& expert_offsets,\n    tvm::ffi::TensorView const& sf_offsets,\n    tvm::ffi::TensorView const& problem_sizes,\n    int M,\n    int N,\n    int K) {\n  int num_experts = static_cast<int>(expert_offsets.size(0));\n  auto stream = LaunchKernel::resolve_device(a_tensors.device());\n\n  RuntimeCheck(out_tensors.size(1) == N, \"Output tensor shape doesn't match expected shape\");\n  RuntimeCheck(\n      K / 2 == b_tensors.size(2),\n      
\"b_tensors(dim = 2) and a_tensors(dim = 1) trailing\"\n      \" dimension must match\");\n  if (false) {\n  }\n  // Macro args: (ELEMENT_AB_TYPE, SF_TYPE, TYPE_CHECK, C_TYPE, LayoutSFA, LayoutSFB, ScaleConfig)\n  __CALL_GET_STARTS_KERNEL_BLOCKSCALE(\n      cutlass::float_e2m1_t,\n      cutlass::float_ue4m3_t,\n      host::is_type<bf16_t>(out_tensors.dtype()),\n      cutlass::bfloat16_t,\n      LayoutSFA,\n      LayoutSFB,\n      ScaleConfig)\n  __CALL_GET_STARTS_KERNEL_BLOCKSCALE(\n      cutlass::float_e2m1_t,\n      cutlass::float_ue4m3_t,\n      host::is_type<fp16_t>(out_tensors.dtype()),\n      cutlass::half_t,\n      LayoutSFA,\n      LayoutSFB,\n      ScaleConfig)\n  else {\n    Panic(\"Invalid output type (must be float16 or bfloat16)\");\n  }\n}\n\nvoid run_fp4_blockwise_scaled_group_mm_sm120(\n    tvm::ffi::TensorView output,\n    const tvm::ffi::TensorView a,\n    const tvm::ffi::TensorView b,\n    const tvm::ffi::TensorView a_blockscale,\n    const tvm::ffi::TensorView b_blockscales,\n    const tvm::ffi::TensorView alphas,\n    const tvm::ffi::TensorView ab_strides,\n    const tvm::ffi::TensorView c_strides,\n    const tvm::ffi::TensorView problem_sizes,\n    const tvm::ffi::TensorView expert_offsets,\n    const tvm::ffi::TensorView sf_offsets,\n    const tvm::ffi::TensorView a_ptrs,\n    const tvm::ffi::TensorView b_ptrs,\n    const tvm::ffi::TensorView out_ptrs,\n    const tvm::ffi::TensorView a_scales_ptrs,\n    const tvm::ffi::TensorView b_scales_ptrs,\n    const tvm::ffi::TensorView alpha_ptrs,\n    const tvm::ffi::TensorView layout_sfa,\n    const tvm::ffi::TensorView layout_sfb,\n    int M,\n    int N,\n    int K) {\n  using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;\n  using ElementType = cutlass::float_e2m1_t;\n  using ElementSFType = cutlass::float_ue4m3_t;\n  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;\n  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;\n\n  using ElementC = 
cutlass::bfloat16_t;\n  using ElementD = cutlass::bfloat16_t;\n  using ElementAccumulator = float;\n  // Layout definitions\n  using LayoutA = cutlass::layout::RowMajor;\n  using LayoutB = cutlass::layout::ColumnMajor;\n  using LayoutC = cutlass::layout::RowMajor;\n  using LayoutD = cutlass::layout::RowMajor;\n\n  // Alignment constraints\n  static constexpr int AlignmentA = 32;\n  static constexpr int AlignmentB = 32;\n  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;\n  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;\n\n  // Architecture definitions\n  using ArchTag = cutlass::arch::Sm120;\n  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;\n  using StageCountType = cutlass::gemm::collective::StageCountAuto;  // Stage count maximized based on the tile size\n  using ThreadBlockShape = Shape<_128, _128, _128>;\n\n  using ClusterShape = Shape<_1, _1, _1>;\n\n  using FusionOperation =\n      cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;\n\n  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<\n      ArchTag,\n      OperatorClass,\n      ThreadBlockShape,\n      ClusterShape,\n      cutlass::epilogue::collective::EpilogueTileAuto,\n      ElementAccumulator,\n      ElementAccumulator,\n      ElementC,\n      LayoutC*,\n      AlignmentC,\n      ElementD,\n      LayoutC*,\n      AlignmentD,\n      cutlass::epilogue::collective::EpilogueScheduleAuto,\n      FusionOperation>::CollectiveOp;\n\n  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<\n      ArchTag,\n      OperatorClass,\n      ElementA,\n      LayoutA*,\n      AlignmentA,\n      ElementB,\n      LayoutB*,\n      AlignmentB,\n      ElementAccumulator,\n      ThreadBlockShape,\n      ClusterShape,\n      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(\n          sizeof(typename CollectiveEpilogue::SharedStorage))>,\n    
  cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong>::CollectiveOp;\n\n  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;\n\n  using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;\n  using Gemm = Gemm1SM;\n  using StrideA = typename Gemm::GemmKernel::InternalStrideA;\n  using StrideB = typename Gemm::GemmKernel::InternalStrideB;\n  using StrideC = typename Gemm::GemmKernel::InternalStrideC;\n  using StrideD = typename Gemm::GemmKernel::InternalStrideD;\n\n  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;\n  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;\n  using ScaleConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;\n\n  using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;\n  int num_experts = static_cast<int>(expert_offsets.size(0));\n\n  run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(\n      a_ptrs,\n      b_ptrs,\n      out_ptrs,\n      a_scales_ptrs,\n      b_scales_ptrs,\n      alpha_ptrs,\n      layout_sfa,\n      layout_sfb,\n      a,\n      b,\n      output,\n      a_blockscale,\n      b_blockscales,\n      alphas,\n      expert_offsets,\n      sf_offsets,\n      problem_sizes,\n      M,\n      N,\n      K);\n\n  // Create an instance of the GEMM\n  Gemm gemm_op;\n\n  // Initialize problem_sizes_as_shapes correctly\n  UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());\n\n  // Set the Scheduler info\n  cutlass::KernelHardwareInfo hw_info;\n\n  using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;\n  typename Gemm::GemmKernel::TileSchedulerArguments scheduler;\n  scheduler.raster_order = RasterOrderOptions::AlongM;\n  hw_info.device_id = a.device().device_id;\n  static std::unordered_map<int, int> cached_sm_counts;\n  if 
(cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {\n    cached_sm_counts[hw_info.device_id] =\n        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);\n  }\n  hw_info.sm_count = std::min(cached_sm_counts[hw_info.device_id], std::numeric_limits<int>::max());\n\n  // Mainloop Arguments\n  typename GemmKernel::MainloopArguments mainloop_args{\n      static_cast<const ElementType**>(a_ptrs.data_ptr()),\n      static_cast<StrideA*>(ab_strides.data_ptr()),\n      static_cast<const ElementType**>(b_ptrs.data_ptr()),\n      static_cast<StrideB*>(ab_strides.data_ptr()),\n      static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),\n      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),\n      static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),\n      reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};\n\n  // Epilogue Arguments\n  typename GemmKernel::EpilogueArguments epilogue_args{\n      {},  // epilogue.thread\n      nullptr,\n      static_cast<StrideC*>(c_strides.data_ptr()),\n      static_cast<ElementD**>(out_ptrs.data_ptr()),\n      static_cast<StrideC*>(c_strides.data_ptr())};\n  auto& fusion_args = epilogue_args.thread;\n  fusion_args.alpha_ptr_array = reinterpret_cast<float**>(alpha_ptrs.data_ptr());\n  fusion_args.dAlpha = {_0{}, _0{}, 1};\n  fusion_args.beta = 0.0f;\n\n  // Gemm Arguments\n  typename GemmKernel::Arguments args{\n      cutlass::gemm::GemmUniversalMode::kGrouped,\n      {num_experts, problem_sizes_as_shapes, nullptr},\n      mainloop_args,\n      epilogue_args,\n      hw_info,\n      scheduler};\n\n  size_t workspace_size = Gemm::get_workspace_size(args);\n  const cudaStream_t stream = LaunchKernel::resolve_device(a.device());\n  void* workspace = get_cached_workspace(workspace_size, hw_info.device_id, stream);\n\n  auto can_implement_status = gemm_op.can_implement(args);\n  RuntimeCheck(\n      can_implement_status == cutlass::Status::kSuccess,\n      \"Failed to 
implement GEMM: \",\n      cutlassGetStatusString(can_implement_status));\n\n  // Run the GEMM\n  auto status = gemm_op.initialize(args, workspace);\n  RuntimeCheck(status == cutlass::Status::kSuccess, \"Failed to initialize GEMM: \", cutlassGetStatusString(status));\n\n  status = gemm_op.run(args, workspace, stream);\n  RuntimeCheck(status == cutlass::Status::kSuccess, \"Failed to run GEMM: \", cutlassGetStatusString(status));\n}\n\ntemplate <typename OutType>\nvoid run_fp4_blockwise_scaled_group_mm_sm100(\n    tvm::ffi::TensorView output,\n    const tvm::ffi::TensorView a,\n    const tvm::ffi::TensorView b,\n    const tvm::ffi::TensorView a_blockscale,\n    const tvm::ffi::TensorView b_blockscales,\n    const tvm::ffi::TensorView alphas,\n    const tvm::ffi::TensorView ab_strides,\n    const tvm::ffi::TensorView c_strides,\n    const tvm::ffi::TensorView problem_sizes,\n    const tvm::ffi::TensorView expert_offsets,\n    const tvm::ffi::TensorView sf_offsets,\n    const tvm::ffi::TensorView a_ptrs,\n    const tvm::ffi::TensorView b_ptrs,\n    const tvm::ffi::TensorView out_ptrs,\n    const tvm::ffi::TensorView a_scales_ptrs,\n    const tvm::ffi::TensorView b_scales_ptrs,\n    const tvm::ffi::TensorView alpha_ptrs,\n    const tvm::ffi::TensorView layout_sfa,\n    const tvm::ffi::TensorView layout_sfb,\n    int M,\n    int N,\n    int K) {\n  using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;\n  using ElementType = cutlass::float_e2m1_t;\n  using ElementSFType = cutlass::float_ue4m3_t;\n  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;\n  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;\n\n  using ElementC = OutType;\n  using ElementD = ElementC;\n  using ElementAccumulator = float;\n  // Layout definitions\n  using LayoutA = cutlass::layout::RowMajor;\n  using LayoutB = cutlass::layout::ColumnMajor;\n  using LayoutC = cutlass::layout::RowMajor;\n  using LayoutD = LayoutC;\n\n  // Alignment 
constraints\n  static constexpr int AlignmentA = 32;\n  static constexpr int AlignmentB = 32;\n  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;\n  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;\n\n  // Architecture definitions\n  using ArchTag = cutlass::arch::Sm100;\n  using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp;             // Epilogue Operator class tag\n  using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;  // Mainloop Operator class tag\n  using StageCountType = cutlass::gemm::collective::StageCountAuto;         // Stage count maximized based\n                                                                            // on the tile size\n\n  using ClusterShape = Shape<_1, _1, _1>;\n  struct MMA1SMConfig {\n    using MmaTileShape = Shape<_128, _128, _128>;\n    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100;  // Kernel to launch\n    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;           // Epilogue to launch\n  };\n\n  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<\n      ArchTag,\n      EpilogueOperatorClass,\n      typename MMA1SMConfig::MmaTileShape,\n      ClusterShape,\n      Shape<_128, _64>,\n      ElementAccumulator,\n      ElementAccumulator,\n      ElementC,\n      LayoutC*,\n      AlignmentC,\n      ElementD,\n      LayoutC*,\n      AlignmentD,\n      typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp;\n\n  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<\n      ArchTag,\n      MainloopOperatorClass,\n      ElementA,\n      LayoutA*,\n      AlignmentA,\n      ElementB,\n      LayoutB*,\n      AlignmentB,\n      ElementAccumulator,\n      typename MMA1SMConfig::MmaTileShape,\n      ClusterShape,\n      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(\n          sizeof(typename 
CollectiveEpilogue::SharedStorage))>,\n      typename MMA1SMConfig::KernelSchedule>::CollectiveOp;\n\n  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;\n\n  using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;\n  using Gemm = Gemm1SM;\n  using StrideA = typename Gemm::GemmKernel::InternalStrideA;\n  using StrideB = typename Gemm::GemmKernel::InternalStrideB;\n  using StrideC = typename Gemm::GemmKernel::InternalStrideC;\n  using StrideD = typename Gemm::GemmKernel::InternalStrideD;\n\n  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;\n  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;\n  using ScaleConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;\n\n  using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;\n  int num_experts = static_cast<int>(expert_offsets.size(0));\n\n  run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(\n      a_ptrs,\n      b_ptrs,\n      out_ptrs,\n      a_scales_ptrs,\n      b_scales_ptrs,\n      alpha_ptrs,\n      layout_sfa,\n      layout_sfb,\n      a,\n      b,\n      output,\n      a_blockscale,\n      b_blockscales,\n      alphas,\n      expert_offsets,\n      sf_offsets,\n      problem_sizes,\n      M,\n      N,\n      K);\n\n  // Create an instance of the GEMM\n  Gemm gemm_op;\n\n  // Initialize problem_sizes_as_shapes correctly\n  UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());\n\n  // Set the Scheduler info\n  cutlass::KernelHardwareInfo hw_info;\n  using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<\n      typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;\n  typename Gemm::GemmKernel::TileSchedulerArguments scheduler;\n  scheduler.raster_order = RasterOrderOptions::AlongM;\n  
hw_info.device_id = a.device().device_id;\n  static std::unordered_map<int, int> cached_sm_counts;\n  if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {\n    cached_sm_counts[hw_info.device_id] =\n        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);\n  }\n  hw_info.sm_count = std::min(cached_sm_counts[hw_info.device_id], std::numeric_limits<int>::max());\n\n  // Mainloop Arguments\n  typename GemmKernel::MainloopArguments mainloop_args{\n      static_cast<const ElementType**>(a_ptrs.data_ptr()),\n      static_cast<StrideA*>(ab_strides.data_ptr()),\n      static_cast<const ElementType**>(b_ptrs.data_ptr()),\n      static_cast<StrideB*>(ab_strides.data_ptr()),\n      static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),\n      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),\n      static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),\n      reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};\n\n  // Epilogue Arguments\n  typename GemmKernel::EpilogueArguments epilogue_args{\n      {},  // epilogue.thread\n      nullptr,\n      static_cast<StrideC*>(c_strides.data_ptr()),\n      static_cast<ElementD**>(out_ptrs.data_ptr()),\n      static_cast<StrideC*>(c_strides.data_ptr())};\n  auto& fusion_args = epilogue_args.thread;\n  fusion_args.alpha_ptr_array = reinterpret_cast<float**>(alpha_ptrs.data_ptr());\n  fusion_args.dAlpha = {_0{}, _0{}, 1};\n\n  // Gemm Arguments\n  typename GemmKernel::Arguments args{\n      cutlass::gemm::GemmUniversalMode::kGrouped,\n      {num_experts, problem_sizes_as_shapes, nullptr},\n      mainloop_args,\n      epilogue_args,\n      hw_info,\n      scheduler};\n\n  size_t workspace_size = Gemm::get_workspace_size(args);\n  const cudaStream_t stream = LaunchKernel::resolve_device(a.device());\n  void* workspace = get_cached_workspace(workspace_size, hw_info.device_id, stream);\n\n  auto can_implement_status = gemm_op.can_implement(args);\n  RuntimeCheck(\n      
can_implement_status == cutlass::Status::kSuccess,\n      \"Failed to implement GEMM: \",\n      cutlassGetStatusString(can_implement_status));\n\n  // Run the GEMM\n  auto status = gemm_op.initialize(args, workspace);\n  RuntimeCheck(status == cutlass::Status::kSuccess, \"Failed to initialize GEMM: \", cutlassGetStatusString(status));\n\n  status = gemm_op.run(args, workspace, stream);\n  RuntimeCheck(status == cutlass::Status::kSuccess, \"Failed to run GEMM: \", cutlassGetStatusString(status));\n}\n\nvoid cutlass_fp4_group_mm_sm100a_sm120a(\n    tvm::ffi::TensorView output,\n    const tvm::ffi::TensorView a,\n    const tvm::ffi::TensorView b,\n    const tvm::ffi::TensorView a_blockscale,\n    const tvm::ffi::TensorView b_blockscales,\n    const tvm::ffi::TensorView alphas,\n    const tvm::ffi::TensorView ab_strides,\n    const tvm::ffi::TensorView c_strides,\n    const tvm::ffi::TensorView problem_sizes,\n    const tvm::ffi::TensorView expert_offsets,\n    const tvm::ffi::TensorView sf_offsets,\n    const tvm::ffi::TensorView a_ptrs,\n    const tvm::ffi::TensorView b_ptrs,\n    const tvm::ffi::TensorView out_ptrs,\n    const tvm::ffi::TensorView a_scales_ptrs,\n    const tvm::ffi::TensorView b_scales_ptrs,\n    const tvm::ffi::TensorView alpha_ptrs,\n    const tvm::ffi::TensorView layout_sfa,\n    const tvm::ffi::TensorView layout_sfb) {\n  auto check_cuda_contig = [](const tvm::ffi::TensorView t, const char* name) {\n    RuntimeCheck(t.device().device_type == kDLCUDA, name, \" must be a CUDA tensor\");\n    RuntimeCheck(t.is_contiguous(), name, \" must be contiguous\");\n  };\n\n  check_cuda_contig(output, \"output\");\n  check_cuda_contig(a, \"a\");\n  check_cuda_contig(b, \"b\");\n  check_cuda_contig(a_blockscale, \"a_blockscale\");\n  check_cuda_contig(b_blockscales, \"b_blockscales\");\n  check_cuda_contig(alphas, \"alphas\");\n  check_cuda_contig(ab_strides, \"ab_strides\");\n  check_cuda_contig(c_strides, \"c_strides\");\n  check_cuda_contig(problem_sizes, 
\"problem_sizes\");\n  check_cuda_contig(expert_offsets, \"expert_offsets\");\n  check_cuda_contig(sf_offsets, \"sf_offsets\");\n  check_cuda_contig(a_ptrs, \"a_ptrs\");\n  check_cuda_contig(b_ptrs, \"b_ptrs\");\n  check_cuda_contig(out_ptrs, \"out_ptrs\");\n  check_cuda_contig(a_scales_ptrs, \"a_scales_ptrs\");\n  check_cuda_contig(b_scales_ptrs, \"b_scales_ptrs\");\n  check_cuda_contig(alpha_ptrs, \"alpha_ptrs\");\n  check_cuda_contig(layout_sfa, \"layout_sfa\");\n  check_cuda_contig(layout_sfb, \"layout_sfb\");\n\n  RuntimeCheck(\n      output.device() == a.device() && a.device() == b.device() && a.device() == a_blockscale.device() &&\n          a.device() == b_blockscales.device() && a.device() == alphas.device() && a.device() == ab_strides.device() &&\n          a.device() == c_strides.device() && a.device() == problem_sizes.device() &&\n          a.device() == expert_offsets.device() && a.device() == sf_offsets.device() && a.device() == a_ptrs.device() &&\n          a.device() == b_ptrs.device() && a.device() == out_ptrs.device() && a.device() == a_scales_ptrs.device() &&\n          a.device() == b_scales_ptrs.device() && a.device() == alpha_ptrs.device() &&\n          a.device() == layout_sfa.device() && a.device() == layout_sfb.device(),\n      \"all tensors must be on the same device\");\n\n  RuntimeCheck(host::is_type<uint8_t>(a.dtype()), \"a must be uint8\");\n  RuntimeCheck(host::is_type<uint8_t>(b.dtype()), \"b must be uint8\");\n  RuntimeCheck(host::is_type<fp8_e4m3_t>(a_blockscale.dtype()), \"a_blockscale must be float8_e4m3fn\");\n  RuntimeCheck(host::is_type<fp8_e4m3_t>(b_blockscales.dtype()), \"b_blockscales must be float8_e4m3fn\");\n  RuntimeCheck(host::is_type<float>(alphas.dtype()), \"alphas must be float32\");\n  RuntimeCheck(host::is_type<int64_t>(ab_strides.dtype()), \"ab_strides must be int64\");\n  RuntimeCheck(host::is_type<int64_t>(c_strides.dtype()), \"c_strides must be int64\");\n  
RuntimeCheck(host::is_type<int32_t>(problem_sizes.dtype()), \"problem_sizes must be int32\");\n  RuntimeCheck(host::is_type<int32_t>(expert_offsets.dtype()), \"expert_offsets must be int32\");\n  RuntimeCheck(host::is_type<int32_t>(sf_offsets.dtype()), \"sf_offsets must be int32\");\n  RuntimeCheck(host::is_type<int64_t>(a_ptrs.dtype()), \"a_ptrs must be int64\");\n  RuntimeCheck(host::is_type<int64_t>(b_ptrs.dtype()), \"b_ptrs must be int64\");\n  RuntimeCheck(host::is_type<int64_t>(out_ptrs.dtype()), \"out_ptrs must be int64\");\n  RuntimeCheck(host::is_type<int64_t>(a_scales_ptrs.dtype()), \"a_scales_ptrs must be int64\");\n  RuntimeCheck(host::is_type<int64_t>(b_scales_ptrs.dtype()), \"b_scales_ptrs must be int64\");\n  RuntimeCheck(host::is_type<int64_t>(alpha_ptrs.dtype()), \"alpha_ptrs must be int64\");\n  RuntimeCheck(host::is_type<int64_t>(layout_sfa.dtype()), \"layout_sfa must be int64\");\n  RuntimeCheck(host::is_type<int64_t>(layout_sfb.dtype()), \"layout_sfb must be int64\");\n  RuntimeCheck(\n      host::is_type<bf16_t>(output.dtype()) || host::is_type<fp16_t>(output.dtype()),\n      \"output must be bfloat16 or float16\");\n\n  RuntimeCheck(a.dim() == 2, \"a must be 2D\");\n  RuntimeCheck(b.dim() == 3, \"b must be 3D\");\n  RuntimeCheck(a_blockscale.dim() == 2, \"a_blockscale must be 2D\");\n  RuntimeCheck(b_blockscales.dim() == 3, \"b_blockscales must be 3D\");\n  RuntimeCheck(alphas.dim() == 1, \"alphas must be 1D\");\n  RuntimeCheck(ab_strides.dim() == 1, \"ab_strides must be 1D\");\n  RuntimeCheck(c_strides.dim() == 1, \"c_strides must be 1D\");\n  RuntimeCheck(problem_sizes.dim() == 2, \"problem_sizes must be 2D\");\n  RuntimeCheck(expert_offsets.dim() == 1, \"expert_offsets must be 1D\");\n  RuntimeCheck(sf_offsets.dim() == 1, \"sf_offsets must be 1D\");\n  RuntimeCheck(a_ptrs.dim() == 1, \"a_ptrs must be 1D\");\n  RuntimeCheck(b_ptrs.dim() == 1, \"b_ptrs must be 1D\");\n  RuntimeCheck(out_ptrs.dim() == 1, \"out_ptrs must be 1D\");\n  
RuntimeCheck(a_scales_ptrs.dim() == 1, \"a_scales_ptrs must be 1D\");\n  RuntimeCheck(b_scales_ptrs.dim() == 1, \"b_scales_ptrs must be 1D\");\n  RuntimeCheck(alpha_ptrs.dim() == 1, \"alpha_ptrs must be 1D\");\n  RuntimeCheck(layout_sfa.dim() == 2, \"layout_sfa must be 2D\");\n  RuntimeCheck(layout_sfb.dim() == 2, \"layout_sfb must be 2D\");\n  RuntimeCheck(problem_sizes.size(1) == 3, \"problem_sizes must have shape (num_experts, 3)\");\n\n  const int num_experts = static_cast<int>(expert_offsets.size(0));\n  RuntimeCheck(problem_sizes.size(0) == num_experts, \"problem_sizes size mismatch with expert_offsets\");\n  RuntimeCheck(sf_offsets.size(0) == num_experts, \"sf_offsets size mismatch with expert_offsets\");\n  RuntimeCheck(alphas.size(0) == num_experts, \"alphas size mismatch with expert_offsets\");\n  RuntimeCheck(ab_strides.size(0) == num_experts, \"ab_strides size mismatch with expert_offsets\");\n  RuntimeCheck(c_strides.size(0) == num_experts, \"c_strides size mismatch with expert_offsets\");\n  RuntimeCheck(a_ptrs.size(0) == num_experts, \"a_ptrs size mismatch with expert_offsets\");\n  RuntimeCheck(b_ptrs.size(0) == num_experts, \"b_ptrs size mismatch with expert_offsets\");\n  RuntimeCheck(out_ptrs.size(0) == num_experts, \"out_ptrs size mismatch with expert_offsets\");\n  RuntimeCheck(a_scales_ptrs.size(0) == num_experts, \"a_scales_ptrs size mismatch with expert_offsets\");\n  RuntimeCheck(b_scales_ptrs.size(0) == num_experts, \"b_scales_ptrs size mismatch with expert_offsets\");\n  RuntimeCheck(alpha_ptrs.size(0) == num_experts, \"alpha_ptrs size mismatch with expert_offsets\");\n  RuntimeCheck(layout_sfa.size(0) == num_experts && layout_sfa.size(1) == 5, \"layout_sfa must be [num_experts, 5]\");\n  RuntimeCheck(layout_sfb.size(0) == num_experts && layout_sfb.size(1) == 5, \"layout_sfb must be [num_experts, 5]\");\n\n  int M = static_cast<int>(a.size(0));\n  int N = static_cast<int>(b.size(1));\n  int K = static_cast<int>(2 * b.size(2));\n  
RuntimeCheck(output.dim() == 2, \"output must be 2D\");\n  RuntimeCheck(output.size(0) == M && output.size(1) == N, \"output shape mismatch\");\n\n  auto sm_version = getSMVersion(a.device().device_id);\n  if (sm_version == 100 || sm_version == 103) {\n    if (host::is_type<bf16_t>(output.dtype())) {\n      run_fp4_blockwise_scaled_group_mm_sm100<cutlass::bfloat16_t>(\n          output,\n          a,\n          b,\n          a_blockscale,\n          b_blockscales,\n          alphas,\n          ab_strides,\n          c_strides,\n          problem_sizes,\n          expert_offsets,\n          sf_offsets,\n          a_ptrs,\n          b_ptrs,\n          out_ptrs,\n          a_scales_ptrs,\n          b_scales_ptrs,\n          alpha_ptrs,\n          layout_sfa,\n          layout_sfb,\n          M,\n          N,\n          K);\n    } else {\n      run_fp4_blockwise_scaled_group_mm_sm100<cutlass::half_t>(\n          output,\n          a,\n          b,\n          a_blockscale,\n          b_blockscales,\n          alphas,\n          ab_strides,\n          c_strides,\n          problem_sizes,\n          expert_offsets,\n          sf_offsets,\n          a_ptrs,\n          b_ptrs,\n          out_ptrs,\n          a_scales_ptrs,\n          b_scales_ptrs,\n          alpha_ptrs,\n          layout_sfa,\n          layout_sfb,\n          M,\n          N,\n          K);\n    }\n  } else if (sm_version >= 120) {\n    if (host::is_type<bf16_t>(output.dtype())) {\n      run_fp4_blockwise_scaled_group_mm_sm120(\n          output,\n          a,\n          b,\n          a_blockscale,\n          b_blockscales,\n          alphas,\n          ab_strides,\n          c_strides,\n          problem_sizes,\n          expert_offsets,\n          sf_offsets,\n          a_ptrs,\n          b_ptrs,\n          out_ptrs,\n          a_scales_ptrs,\n          b_scales_ptrs,\n          alpha_ptrs,\n          layout_sfa,\n          layout_sfb,\n          M,\n          N,\n          K);\n    } else {\n      
Panic(\"SM120 path currently supports only bfloat16 output\");\n    }\n  } else {\n    RuntimeCheck(false, \"Unsupported SM version: \", sm_version);\n  }\n}\n\nvoid cutlass_fp4_group_mm(\n    tvm::ffi::TensorView output,\n    const tvm::ffi::TensorView a,\n    const tvm::ffi::TensorView b,\n    const tvm::ffi::TensorView a_blockscale,\n    const tvm::ffi::TensorView b_blockscales,\n    const tvm::ffi::TensorView alphas,\n    const tvm::ffi::TensorView ab_strides,\n    const tvm::ffi::TensorView c_strides,\n    const tvm::ffi::TensorView problem_sizes,\n    const tvm::ffi::TensorView expert_offsets,\n    const tvm::ffi::TensorView sf_offsets,\n    const tvm::ffi::TensorView a_ptrs,\n    const tvm::ffi::TensorView b_ptrs,\n    const tvm::ffi::TensorView out_ptrs,\n    const tvm::ffi::TensorView a_scales_ptrs,\n    const tvm::ffi::TensorView b_scales_ptrs,\n    const tvm::ffi::TensorView alpha_ptrs,\n    const tvm::ffi::TensorView layout_sfa,\n    const tvm::ffi::TensorView layout_sfb) {\n  cutlass_fp4_group_mm_sm100a_sm120a(\n      output,\n      a,\n      b,\n      a_blockscale,\n      b_blockscales,\n      alphas,\n      ab_strides,\n      c_strides,\n      problem_sizes,\n      expert_offsets,\n      sf_offsets,\n      a_ptrs,\n      b_ptrs,\n      out_ptrs,\n      a_scales_ptrs,\n      b_scales_ptrs,\n      alpha_ptrs,\n      layout_sfa,\n      layout_sfb);\n}\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/ngram_embedding.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/utils.cuh>\n\n#include <dlpack/dlpack.h>\n\n#include <algorithm>\n#include <concepts>\n#include <cstddef>\n#include <cstdint>\n#include <type_traits>\n\nnamespace device::ngram_embedding {\n\n__global__ void ComputeNGramIdsKernel(\n    int batch_size,\n    int ne_n,\n    int ne_k,\n    int* ne_weights,                      // [ne_n-1,ne_k,ne_n]\n    int* ne_mods,                         // [ne_n-1,ne_k]\n    int* exclusive_ne_embeder_size_sums,  // [(ne_n-1)*ne_k]\n    int* tokens,                          // [token_num]\n    int* exclusive_req_len_sums,          // [batch_size+1]\n    int* ne_token_table,                  // [max_running_reqs, max_context_len]\n    int max_context_len,                  // max_context_len\n    long* row_indices,                    // [batch_size]\n    int* column_starts,                   // [batch_size]\n    int* n_gram_ids                       // [token_num,ne_n-1,ne_k]\n) {\n  // Determine which n, k, and request this block handles.\n  // blockIdx.x = config_id * batch_size + req_id\n  /**\n  Example: [req0, req1, req2] with n=3, k=2\n  n       k       req_id      blockIdx.x  config_id (combination of n and k)\n  2       1       0           0           0\n  2       1       1           1           0\n  2       1       2           2           0\n  2       2       0           3           1\n  2       2       1           4           1\n  2       2       2           5           1\n  3       1       0           6           2\n  3       1       1           7           2\n  3       1       2           8           2\n  3       2       0           9           3\n  3       2       1           10          3\n  3       2       2           11          3\n  */\n  const int req_id = blockIdx.x % batch_size;\n  const int config_id = (blockIdx.x - req_id) / batch_size;\n  // n and k here are offset from their physical meanings: n = real_n - 2, k = real_k - 1.\n  // This offset exists because n 
and k are used as indices into ne_weights and ne_mods.\n  const int k = config_id % ne_k;\n  const int n = (config_id - config_id % ne_k) / ne_k;\n  // ne_weights has shape [ne_n-1, ne_k, ne_n]; last dim is token distance, so compute base index first\n  const int ne_weight_base_idx = n * ne_k * ne_n + k * ne_n;\n  // ne_mods has shape [ne_n-1, ne_k]\n  const int ne_mod = ne_mods[n * ne_k + k];\n  // stride loop\n  for (int i = exclusive_req_len_sums[req_id] + threadIdx.x; i < exclusive_req_len_sums[req_id + 1]; i += blockDim.x) {\n    uint64_t n_gram_id = 0;\n    // Token offset within the current request\n    int current_token_offset = i - exclusive_req_len_sums[req_id];\n    // Start index of this request in the token table; tokens before this belong to other requests\n    int req_token_table_index = row_indices[req_id] * max_context_len;\n    // Position of the current token in the token table\n    int current_token_table_index = req_token_table_index + column_starts[req_id] + current_token_offset;\n    for (int j = 0; j < n + 2; j++) {\n      if (current_token_table_index - j < req_token_table_index) {\n        // Out of this request's range, stop computing n_gram_id\n        break;\n      }\n      if (ne_token_table[current_token_table_index - j] < 0) {\n        // Token was marked as ignored during write\n        break;\n      }\n      const uint64_t term =\n          (uint64_t)ne_token_table[current_token_table_index - j] * (uint64_t)ne_weights[ne_weight_base_idx + j];\n      n_gram_id += term % ne_mod;\n    }\n    n_gram_id %= ne_mod;\n    n_gram_id += exclusive_ne_embeder_size_sums[n * ne_k + k];\n    // [token_num, ne_n-1, ne_k]\n    n_gram_ids[i * (ne_n - 1) * ne_k + n * ne_k + k] = (int)(n_gram_id);\n  }\n}\n\n__global__ void UpdateTokenTableKernel(\n    int batch_size,\n    int* tokens,           // [token_num]\n    int* ne_token_table,   // [max_running_reqs, max_context_len]\n    int max_context_len,   // max_context_len\n    long* row_indices,     
// [batch_size]\n    int* column_starts,    // [batch_size]\n    int* req_lens,         // [batch_size]\n    int ignore_token_num,  // number of tokens to ignore\n    int* ignore_tokens     // [ignore_token_num]\n) {\n  // Each block processes one request.\n  const int req_id = blockIdx.x % batch_size;\n  int start = 0;\n  int end = 0;\n  for (int i = 0; i < req_id; i++) {\n    start += req_lens[i];\n  }\n  end = start + req_lens[req_id];\n  // stride loop\n  for (int i = start + threadIdx.x; i < end; i += blockDim.x) {\n    // Token offset within the current request\n    int current_token_offset = i - start;\n    // Start index of this request in the token table\n    int req_token_table_index = row_indices[req_id] * max_context_len;\n    // Position of the current token in the token table\n    int current_token_table_index = req_token_table_index + column_starts[req_id] + current_token_offset;\n    ne_token_table[current_token_table_index] = tokens[i];\n    for (int j = 0; j < ignore_token_num; j++) {\n      if (ignore_tokens[j] == tokens[i]) {\n        ne_token_table[current_token_table_index] = -tokens[i];\n        break;\n      }\n    }\n  }\n}\n\n}  // namespace device::ngram_embedding\n\nnamespace {\n\nstruct NgramEmbeddingKernel {\n  static void compute_n_gram_ids(\n      const int64_t ne_n,\n      const int64_t ne_k,\n      const tvm::ffi::TensorView ne_weights,\n      const tvm::ffi::TensorView ne_mods,\n      const tvm::ffi::TensorView exclusive_ne_embeder_size_sums,\n      const tvm::ffi::TensorView tokens,\n      const tvm::ffi::TensorView exclusive_req_len_sums,\n      const tvm::ffi::TensorView ne_token_table,\n      const tvm::ffi::TensorView row_indices,\n      const tvm::ffi::TensorView column_starts,\n      const tvm::ffi::TensorView n_gram_ids) {\n    using namespace host;\n\n    auto device_ = SymbolicDevice{};\n\n    // Verify tensor shapes and types using -1 (kAnySize) for dynamic dimensions\n    TensorMatcher({-1, -1, -1})  // [ne_n-1, ne_k, 
ne_n]\n        .with_dtype<int32_t>()\n        .with_device<kDLCUDA>(device_)\n        .verify(ne_weights);\n\n    TensorMatcher({-1, -1})  // [ne_n-1, ne_k]\n        .with_dtype<int32_t>()\n        .with_device<kDLCUDA>()\n        .verify(ne_mods);\n\n    TensorMatcher({-1})  // [(ne_n-1)*ne_k + 1]\n        .with_dtype<int32_t>()\n        .with_device<kDLCUDA>()\n        .verify(exclusive_ne_embeder_size_sums);\n\n    TensorMatcher({-1})  // [token_num]\n        .with_dtype<int32_t>()\n        .with_device<kDLCUDA>()\n        .verify(tokens);\n\n    TensorMatcher({-1})  // [batch_size+1]\n        .with_dtype<int32_t>()\n        .with_device<kDLCUDA>()\n        .verify(exclusive_req_len_sums);\n\n    TensorMatcher({-1, -1})  // [max_running_reqs, max_context_len]\n        .with_dtype<int32_t>()\n        .with_device<kDLCUDA>()\n        .verify(ne_token_table);\n\n    TensorMatcher({-1})  // [batch_size]\n        .with_dtype<int64_t>()\n        .with_device<kDLCUDA>()\n        .verify(row_indices);\n\n    TensorMatcher({-1})  // [batch_size]\n        .with_dtype<int32_t>()\n        .with_device<kDLCUDA>()\n        .verify(column_starts);\n\n    TensorMatcher({-1, -1})  // [token_num, (ne_n-1)*ne_k]\n        .with_dtype<int32_t>()\n        .with_device<kDLCUDA>()\n        .verify(n_gram_ids);\n\n    const int batch_size = static_cast<int>(exclusive_req_len_sums.size(0) - 1);\n    const int max_context_len = static_cast<int>(ne_token_table.size(1));\n    const auto stream = LaunchKernel::resolve_device(device_.unwrap());\n\n    constexpr int BLOCK_THREADS = 256;\n    const int num_configs = (static_cast<int>(ne_n) - 1) * static_cast<int>(ne_k);\n    const int grid_size = num_configs * batch_size;\n\n    LaunchKernel(grid_size, BLOCK_THREADS, stream)(\n        device::ngram_embedding::ComputeNGramIdsKernel,\n        batch_size,\n        static_cast<int>(ne_n),\n        static_cast<int>(ne_k),\n        static_cast<int*>(ne_weights.data_ptr()),\n        
static_cast<int*>(ne_mods.data_ptr()),\n        static_cast<int*>(exclusive_ne_embeder_size_sums.data_ptr()),\n        static_cast<int*>(tokens.data_ptr()),\n        static_cast<int*>(exclusive_req_len_sums.data_ptr()),\n        static_cast<int*>(ne_token_table.data_ptr()),\n        max_context_len,\n        static_cast<long*>(row_indices.data_ptr()),\n        static_cast<int*>(column_starts.data_ptr()),\n        static_cast<int*>(n_gram_ids.data_ptr()));\n  }\n\n  static void update_token_table(\n      const tvm::ffi::TensorView tokens,\n      const tvm::ffi::TensorView ne_token_table,\n      const tvm::ffi::TensorView row_indices,\n      const tvm::ffi::TensorView column_starts,\n      const tvm::ffi::TensorView req_lens,\n      const tvm::ffi::TensorView ignore_tokens) {\n    using namespace host;\n\n    auto device_ = SymbolicDevice{};\n\n    // Verify tensor shapes and types using -1 (kAnySize) for dynamic dimensions\n    TensorMatcher({-1})  // [token_num]\n        .with_dtype<int32_t>()\n        .with_device<kDLCUDA>(device_)\n        .verify(tokens);\n\n    TensorMatcher({-1, -1})  // [max_running_reqs, max_context_len]\n        .with_dtype<int32_t>()\n        .with_device<kDLCUDA>()\n        .verify(ne_token_table);\n\n    TensorMatcher({-1})  // [batch_size]\n        .with_dtype<int64_t>()\n        .with_device<kDLCUDA>()\n        .verify(row_indices);\n\n    TensorMatcher({-1})  // [batch_size]\n        .with_dtype<int32_t>()\n        .with_device<kDLCUDA>()\n        .verify(column_starts);\n\n    TensorMatcher({-1})  // [batch_size]\n        .with_dtype<int32_t>()\n        .with_device<kDLCUDA>()\n        .verify(req_lens);\n\n    // ignore_tokens can be empty or have values\n    void* ignore_tokens_ptr = ignore_tokens.data_ptr();\n    const bool has_ignore_tokens = ignore_tokens_ptr != nullptr && ignore_tokens.numel() > 0;\n    if (has_ignore_tokens) {\n      TensorMatcher({-1})  // [ignore_token_num]\n          .with_dtype<int32_t>()\n          
.with_device<kDLCUDA>()\n          .verify(ignore_tokens);\n    }\n\n    const int batch_size = static_cast<int>(req_lens.size(0));\n    if (batch_size <= 0) {\n      return;\n    }\n\n    const int max_context_len = static_cast<int>(ne_token_table.size(1));\n    const auto stream = LaunchKernel::resolve_device(device_.unwrap());\n\n    constexpr int BLOCK_THREADS = 256;\n    const int grid_size = batch_size;\n\n    int ignore_token_num = 0;\n    int* ignore_tokens_typed_ptr = nullptr;\n    if (has_ignore_tokens) {\n      ignore_token_num = static_cast<int>(ignore_tokens.numel());\n      ignore_tokens_typed_ptr = static_cast<int*>(ignore_tokens_ptr);\n    }\n\n    LaunchKernel(grid_size, BLOCK_THREADS, stream)(\n        device::ngram_embedding::UpdateTokenTableKernel,\n        batch_size,\n        static_cast<int*>(tokens.data_ptr()),\n        static_cast<int*>(ne_token_table.data_ptr()),\n        max_context_len,\n        static_cast<long*>(row_indices.data_ptr()),\n        static_cast<int*>(column_starts.data_ptr()),\n        static_cast<int*>(req_lens.data_ptr()),\n        ignore_token_num,\n        ignore_tokens_typed_ptr);\n  }\n};\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/csrc/nsa/fused_store_index_cache.cuh",
    "content": "#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/utils.h>\n\n#include <sgl_kernel/math.cuh>\n#include <sgl_kernel/type.cuh>\n#include <sgl_kernel/utils.cuh>\n#include <sgl_kernel/vec.cuh>\n#include <sgl_kernel/warp.cuh>\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/container/tensor.h>\n\n#include <bit>\n#include <cstdint>\n#include <cuda_fp8.h>\n\nnamespace {\n\nstruct FusedStoreCacheParam {\n  const void* __restrict__ input;\n  void* __restrict__ cache;\n  const void* __restrict__ indices;\n  uint32_t num_tokens;\n};\n\n[[maybe_unused]]\nSGL_DEVICE float fp8_e4m3_clip(float val) {\n  namespace math = device::math;\n  return math::max(math::min(val, math::FP8_E4M3_MAX), -math::FP8_E4M3_MAX);\n}\n\n[[maybe_unused]]\nSGL_DEVICE fp8x2_e4m3_t pack_fp8(float x, float y) {\n  return fp8x2_e4m3_t{fp32x2_t{fp8_e4m3_clip(x), fp8_e4m3_clip(y)}};\n}\n\ntemplate <typename KeyT, typename IndicesT, uint32_t kPageBits, bool kUsePDL>\n__global__ void fused_store_indexer_cache(const __grid_constant__ FusedStoreCacheParam param) {\n  using namespace device;\n\n  /// NOTE: 132 = 128 + 4\n  constexpr int64_t kPageBytes = 132 << kPageBits;\n\n  // each warp handles 128 elements, each block handles multiple rows\n  const auto& [input, cache, indices, num_tokens] = param;\n  const auto global_tid = blockIdx.x * blockDim.x + threadIdx.x;\n  const auto global_wid = global_tid / 32;\n  const auto lane_id = threadIdx.x % 32;\n\n  if (global_wid >= num_tokens) return;\n\n  PDLWaitPrimary<kUsePDL>();  // wait for primary kernel\n\n  // prefetch the index\n  const auto index = static_cast<const IndicesT*>(indices)[global_wid];\n  // always load the value from input (don't store if invalid)\n  using KeyT2 = packed_t<KeyT>;\n  using InStorage = AlignedVector<KeyT2, 2>;\n  using OutStorage = AlignedVector<fp8x2_e4m3_t, 2>;\n  const auto elems = static_cast<const InStorage*>(input)[global_tid];\n  const auto [x0, x1] = cast<fp32x2_t>(elems[0]);\n  const auto [y0, y1] = 
cast<fp32x2_t>(elems[1]);\n  const auto local_max = fmaxf(fmaxf(fabs(x0), fabs(x1)), fmaxf(fabs(y0), fabs(y1)));\n  const auto abs_max = warp::reduce_max(local_max);\n  // use normal fp32 scale\n  const auto scale = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX;\n  const auto inv_scale = 1.0f / scale;\n  const int32_t page = index >> kPageBits;\n  const int32_t offset = index & ((1 << kPageBits) - 1);\n  const auto page_ptr = pointer::offset(cache, page * kPageBytes);\n  const auto value_ptr = pointer::offset(page_ptr, offset * 128);\n  const auto scale_ptr = pointer::offset(page_ptr, 128 << kPageBits, offset * 4);\n  OutStorage result;\n  result[0] = pack_fp8(x0 * inv_scale, x1 * inv_scale);\n  result[1] = pack_fp8(y0 * inv_scale, y1 * inv_scale);\n  static_cast<OutStorage*>(value_ptr)[lane_id] = result;\n  static_cast<float*>(scale_ptr)[0] = scale;\n\n  PDLTriggerSecondary<kUsePDL>();  // launch secondary kernel\n}\n\ntemplate <typename KeyT, typename IndicesT, uint32_t kPageSize, bool kUsePDL>\nstruct FusedStoreCacheIndexerKernel {\n  static constexpr int32_t kLogSize = std::countr_zero(kPageSize);\n  /// NOTE: 132 = 128 + 4 (128 represent K and 4 represent scale)\n  static constexpr int64_t kPageBytes = 132 * kPageSize;\n  static constexpr auto kernel = fused_store_indexer_cache<KeyT, IndicesT, kLogSize, kUsePDL>;\n\n  static_assert(std::has_single_bit(kPageSize), \"kPageSize must be a power of 2\");\n  static_assert(1 << kLogSize == kPageSize);\n\n  static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView cache, tvm::ffi::TensorView indices) {\n    using namespace host;\n\n    auto N = SymbolicSize{\"num_tokens\"};\n    auto device_ = SymbolicDevice{};\n    device_.set_options<kDLCUDA>();\n    TensorMatcher({N, 128})  // input\n        .with_dtype<KeyT>()\n        .with_device(device_)\n        .verify(input);\n    TensorMatcher({-1, -1})  // cache\n        .with_strides({kPageBytes, 1})\n        .with_dtype<uint8_t>()\n        .with_device(device_)\n     
   .verify(cache);\n    TensorMatcher({N})  // indices\n        .with_dtype<IndicesT>()\n        .with_device(device_)\n        .verify(indices);\n    const auto num_tokens = static_cast<uint32_t>(N.unwrap());\n    const auto params = FusedStoreCacheParam{\n        .input = input.data_ptr(),\n        .cache = cache.data_ptr(),\n        .indices = indices.data_ptr(),\n        .num_tokens = num_tokens,\n    };\n    const auto kBlockSize = 128;\n    const auto num_blocks = div_ceil(num_tokens * 32, kBlockSize);\n    LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params);\n  }\n};\n\n}  // namespace\n"
  },
  {
    "path": "python/sglang/jit_kernel/cutedsl_gdn.py",
    "content": "\"\"\"CuTe DSL Fused Sigmoid Gating Delta Rule Kernel for GDN Decode.\"\"\"\n\nimport logging\nfrom typing import Dict, Optional, Tuple\n\nimport cuda.bindings.driver as cuda\nimport cutlass\nimport cutlass.cute as cute\nimport torch\nfrom cutlass.cute.nvgpu import cpasync\nfrom cutlass.cute.runtime import from_dlpack\n\nlogger = logging.getLogger(__name__)\n\n_compiled_kernels: Dict[Tuple, object] = {}\n_cu_seqlens_cache: Dict[Tuple, torch.Tensor] = {}\nTILE_K = 128\nTILE_V = 32\nTILE_V_PADDED = 36\nTILE_V_SMALL = 16\nTILE_V_SMALL_PADDED = 20\nNUM_STAGES = 2\nNUM_THREADS = 128\nNUM_BLOCKS_PER_STATE_SMALL = 8\nNUM_THREADS_LARGE = 256\nNUM_WARPS_LARGE = 8\nV_PER_WARP = 4\nROWS_PER_ITER = 8\nNUM_K_ITERS = TILE_K // ROWS_PER_ITER\nSMALL_BATCH_THRESHOLD = 32\n\n\ndef _define_kernels():\n    \"\"\"Define CuTe DSL kernels for normal and varlen decode modes.\"\"\"\n\n    NUM_WARPS_SMALL = 4\n    V_PER_WARP_SMALL = TILE_V_SMALL // NUM_WARPS_SMALL\n    ROWS_PER_ITER_SMALL = 32 // V_PER_WARP_SMALL\n    NUM_K_ITERS_SMALL = TILE_K // ROWS_PER_ITER_SMALL\n\n    @cute.kernel\n    def gdn_kernel_small_batch(\n        tiled_copy_load: cute.TiledCopy,\n        h0_source: cute.Tensor,\n        smem_layout_staged: cute.Layout,\n        num_v_tiles: cutlass.Constexpr[int],\n        q: cute.Tensor,\n        k: cute.Tensor,\n        v: cute.Tensor,\n        a: cute.Tensor,\n        b: cute.Tensor,\n        A_log: cute.Tensor,\n        dt_bias: cute.Tensor,\n        o: cute.Tensor,\n        h0_indices: cute.Tensor,\n        softplus_beta: cutlass.Constexpr[float],\n        softplus_threshold: cutlass.Constexpr[float],\n        scale: cutlass.Constexpr[float],\n        H: cutlass.Constexpr[int],\n        HV: cutlass.Constexpr[int],\n        use_qk_l2norm: cutlass.Constexpr[bool],\n    ):\n        \"\"\"Small batch kernel for (N, 1, ...) 
format.\"\"\"\n        tidx, _, _ = cute.arch.thread_idx()\n        in_warp_tid = tidx % 32\n        warp_idx = cute.arch.warp_idx()\n        warp_idx = cute.arch.make_warp_uniform(warp_idx)\n        block_idx, _, _ = cute.arch.block_idx()\n\n        batch_idx = block_idx // NUM_BLOCKS_PER_STATE_SMALL\n        batch_inner = block_idx % NUM_BLOCKS_PER_STATE_SMALL\n        num_v_tiles_per_block = num_v_tiles // NUM_BLOCKS_PER_STATE_SMALL\n        start_v_tile = batch_inner * num_v_tiles_per_block\n\n        i_n = batch_idx // HV\n        i_hv = batch_idx % HV\n        i_h = i_hv // (HV // H)\n\n        pool_idx = h0_indices[i_n]\n\n        if pool_idx >= 0:\n            k_local = in_warp_tid // V_PER_WARP_SMALL\n            v_local = in_warp_tid % V_PER_WARP_SMALL\n            v_base = warp_idx * V_PER_WARP_SMALL\n            v_idx = v_base + v_local\n\n            smem = cutlass.utils.SmemAllocator()\n            sData = smem.allocate_tensor(cutlass.Float32, smem_layout_staged, 128)\n            smem_o_layout = cute.make_layout((TILE_V_SMALL,), stride=(1,))\n            smem_o = smem.allocate_tensor(cutlass.Float32, smem_o_layout, 128)\n            smem_k_layout = cute.make_layout((TILE_K,), stride=(1,))\n            smem_q_layout = cute.make_layout((TILE_K,), stride=(1,))\n            sK = smem.allocate_tensor(cutlass.Float32, smem_k_layout, 128)\n            sQ = smem.allocate_tensor(cutlass.Float32, smem_q_layout, 128)\n\n            if tidx < TILE_K:\n                sK[tidx] = cutlass.Float32(k[i_n, 0, i_h, tidx])\n                sQ[tidx] = cutlass.Float32(q[i_n, 0, i_h, tidx])\n\n            gSrc_batch = h0_source[(pool_idx, i_hv, None, None)]\n            gSrc = cute.local_tile(gSrc_batch, (TILE_K, TILE_V_SMALL), (0, None))\n            thr_copy_load = tiled_copy_load.get_slice(tidx)\n\n            prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles_per_block)\n            for v_tile_offset in range(prefetch_count):\n                v_tile = 
start_v_tile + v_tile_offset\n                stage = v_tile_offset % NUM_STAGES\n                gSrc_tile = gSrc[(None, None, v_tile)]\n                sData_stage = sData[(None, None, stage)]\n                thr_gSrc = thr_copy_load.partition_S(gSrc_tile)\n                thr_sData = thr_copy_load.partition_D(sData_stage)\n                cute.copy(tiled_copy_load, thr_gSrc, thr_sData)\n                cute.arch.cp_async_commit_group()\n\n            r_A_log = cutlass.Float32(A_log[i_hv])\n            r_dt_bias = cutlass.Float32(dt_bias[i_hv])\n            r_a = cutlass.Float32(a[i_n, 0, i_hv])\n            r_b = cutlass.Float32(b[i_n, 0, i_hv])\n\n            r_g = 0.0\n            r_beta = 0.0\n            if in_warp_tid == 0:\n                x = r_a + r_dt_bias\n                beta_x = softplus_beta * x\n                softplus_x = 0.0\n                if beta_x <= softplus_threshold:\n                    exp_beta_x = cute.exp(beta_x)\n                    log_input = cutlass.Float32(1.0 + exp_beta_x)\n                    log_result = cutlass.Float32(cute.log(log_input))\n                    softplus_x = cutlass.Float32(\n                        (cutlass.Float32(1.0) / softplus_beta) * log_result\n                    )\n                else:\n                    softplus_x = x\n                r_g_value = -cute.exp(r_A_log) * softplus_x\n                r_beta = 1.0 / (1.0 + cute.exp(-r_b))\n                r_g = cute.exp(r_g_value)\n\n            r_g = cute.arch.shuffle_sync(r_g, 0)\n            r_beta = cute.arch.shuffle_sync(r_beta, 0)\n\n            cute.arch.barrier()\n\n            if use_qk_l2norm:\n                sum_q_partial = 0.0\n                sum_k_partial = 0.0\n                if tidx < TILE_K:\n                    q_val = sQ[tidx]\n                    k_val = sK[tidx]\n                    sum_q_partial = q_val * q_val\n                    sum_k_partial = k_val * k_val\n\n                for offset in [16, 8, 4, 2, 1]:\n                   
 sum_q_partial += cute.arch.shuffle_sync_bfly(\n                        sum_q_partial, offset=offset, mask=-1, mask_and_clamp=31\n                    )\n                    sum_k_partial += cute.arch.shuffle_sync_bfly(\n                        sum_k_partial, offset=offset, mask=-1, mask_and_clamp=31\n                    )\n\n                if in_warp_tid == 0:\n                    smem_o[warp_idx] = sum_q_partial\n                    smem_o[warp_idx + 4] = sum_k_partial\n                cute.arch.barrier()\n\n                inv_norm_q = 0.0\n                inv_norm_k = 0.0\n                if warp_idx == 0:\n                    local_sum_q = 0.0\n                    local_sum_k = 0.0\n                    if in_warp_tid < NUM_WARPS_SMALL:\n                        local_sum_q = smem_o[in_warp_tid]\n                        local_sum_k = smem_o[in_warp_tid + 4]\n                    for offset in [2, 1]:\n                        local_sum_q += cute.arch.shuffle_sync_bfly(\n                            local_sum_q, offset=offset, mask=-1, mask_and_clamp=31\n                        )\n                        local_sum_k += cute.arch.shuffle_sync_bfly(\n                            local_sum_k, offset=offset, mask=-1, mask_and_clamp=31\n                        )\n                    if in_warp_tid == 0:\n                        smem_o[0] = cute.rsqrt(local_sum_q + 1e-6)\n                        smem_o[1] = cute.rsqrt(local_sum_k + 1e-6)\n                cute.arch.barrier()\n\n                inv_norm_q = smem_o[0]\n                inv_norm_k = smem_o[1]\n\n                if tidx < TILE_K:\n                    sK[tidx] = sK[tidx] * inv_norm_k\n                    sQ[tidx] = sQ[tidx] * scale * inv_norm_q\n                cute.arch.barrier()\n            else:\n                if tidx < TILE_K:\n                    sQ[tidx] = sQ[tidx] * scale\n                cute.arch.barrier()\n\n            for v_tile_offset in range(num_v_tiles_per_block):\n                v_tile = 
start_v_tile + v_tile_offset\n                stage = v_tile_offset % NUM_STAGES\n\n                cute.arch.cp_async_wait_group(0)\n                cute.arch.barrier()\n\n                next_v_tile_offset = v_tile_offset + prefetch_count\n                if next_v_tile_offset < num_v_tiles_per_block:\n                    next_v_tile = start_v_tile + next_v_tile_offset\n                    next_stage = next_v_tile_offset % NUM_STAGES\n                    gSrc_next = gSrc[(None, None, next_v_tile)]\n                    sData_next = sData[(None, None, next_stage)]\n                    thr_gSrc = thr_copy_load.partition_S(gSrc_next)\n                    thr_sData = thr_copy_load.partition_D(sData_next)\n                    cute.copy(tiled_copy_load, thr_gSrc, thr_sData)\n                    cute.arch.cp_async_commit_group()\n\n                v_global = v_tile * TILE_V_SMALL + v_idx\n                r_v = cutlass.Float32(v[i_n, 0, i_hv, v_global])\n\n                sum_hk = 0.0\n                for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):\n                    k_base = k_iter * ROWS_PER_ITER_SMALL\n                    k_idx = k_base + k_local\n                    h_val = sData[(k_idx, v_idx, stage)] * r_g\n                    r_k_val = sK[k_idx]\n                    sum_hk += h_val * r_k_val\n\n                for offset in [4, 2, 1]:\n                    sum_hk += cute.arch.shuffle_sync_bfly(\n                        sum_hk,\n                        offset=offset * V_PER_WARP_SMALL,\n                        mask=-1,\n                        mask_and_clamp=31,\n                    )\n\n                v_new = (r_v - sum_hk) * r_beta\n                v_new = cute.arch.shuffle_sync(v_new, v_local)\n\n                sum_hq = 0.0\n                for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):\n                    k_base = k_iter * ROWS_PER_ITER_SMALL\n                    k_idx = k_base + k_local\n                    h_old = 
sData[(k_idx, v_idx, stage)] * r_g\n                    r_k_val = sK[k_idx]\n                    r_q_val = sQ[k_idx]\n                    h_new = h_old + r_k_val * v_new\n                    sData[(k_idx, v_idx, stage)] = h_new\n                    sum_hq += h_new * r_q_val\n\n                for offset in [4, 2, 1]:\n                    sum_hq += cute.arch.shuffle_sync_bfly(\n                        sum_hq,\n                        offset=offset * V_PER_WARP_SMALL,\n                        mask=-1,\n                        mask_and_clamp=31,\n                    )\n\n                if k_local == 0:\n                    v_global_out = v_tile * TILE_V_SMALL + v_idx\n                    o[(i_n, 0, i_hv, v_global_out)] = cutlass.BFloat16(sum_hq)\n\n                cute.arch.barrier()\n\n                for k_iter in range(NUM_K_ITERS_SMALL):\n                    flat_idx = tidx + k_iter * 128\n                    k_write = flat_idx // TILE_V_SMALL\n                    v_write = flat_idx % TILE_V_SMALL\n                    if k_write < TILE_K:\n                        h_val = sData[(k_write, v_write, stage)]\n                        v_global_write = v_tile * TILE_V_SMALL + v_write\n                        h0_source[(pool_idx, i_hv, k_write, v_global_write)] = h_val\n\n                cute.arch.barrier()\n\n    @cute.kernel\n    def gdn_kernel_small_batch_varlen(\n        tiled_copy_load: cute.TiledCopy,\n        h0_source: cute.Tensor,\n        smem_layout_staged: cute.Layout,\n        num_v_tiles: cutlass.Constexpr[int],\n        q: cute.Tensor,\n        k: cute.Tensor,\n        v: cute.Tensor,\n        a: cute.Tensor,\n        b: cute.Tensor,\n        A_log: cute.Tensor,\n        dt_bias: cute.Tensor,\n        o: cute.Tensor,\n        h0_indices: cute.Tensor,\n        softplus_beta: cutlass.Constexpr[float],\n        softplus_threshold: cutlass.Constexpr[float],\n        scale: cutlass.Constexpr[float],\n        H: cutlass.Constexpr[int],\n        HV: 
cutlass.Constexpr[int],\n        use_qk_l2norm: cutlass.Constexpr[bool],\n    ):\n        \"\"\"Small batch kernel for varlen decode (1, N, ...) format.\"\"\"\n        tidx, _, _ = cute.arch.thread_idx()\n        in_warp_tid = tidx % 32\n        warp_idx = cute.arch.warp_idx()\n        warp_idx = cute.arch.make_warp_uniform(warp_idx)\n        block_idx, _, _ = cute.arch.block_idx()\n\n        batch_idx = block_idx // NUM_BLOCKS_PER_STATE_SMALL\n        batch_inner = block_idx % NUM_BLOCKS_PER_STATE_SMALL\n        num_v_tiles_per_block = num_v_tiles // NUM_BLOCKS_PER_STATE_SMALL\n        start_v_tile = batch_inner * num_v_tiles_per_block\n\n        i_n = batch_idx // HV\n        i_hv = batch_idx % HV\n        i_h = i_hv // (HV // H)\n\n        pool_idx = h0_indices[i_n]\n\n        if pool_idx >= 0:\n            k_local = in_warp_tid // V_PER_WARP_SMALL\n            v_local = in_warp_tid % V_PER_WARP_SMALL\n            v_base = warp_idx * V_PER_WARP_SMALL\n            v_idx = v_base + v_local\n\n            smem = cutlass.utils.SmemAllocator()\n            sData = smem.allocate_tensor(cutlass.Float32, smem_layout_staged, 128)\n            smem_o_layout = cute.make_layout((TILE_V_SMALL,), stride=(1,))\n            smem_o = smem.allocate_tensor(cutlass.Float32, smem_o_layout, 128)\n            smem_k_layout = cute.make_layout((TILE_K,), stride=(1,))\n            smem_q_layout = cute.make_layout((TILE_K,), stride=(1,))\n            sK = smem.allocate_tensor(cutlass.Float32, smem_k_layout, 128)\n            sQ = smem.allocate_tensor(cutlass.Float32, smem_q_layout, 128)\n\n            if tidx < TILE_K:\n                sK[tidx] = cutlass.Float32(k[0, i_n, i_h, tidx])\n                sQ[tidx] = cutlass.Float32(q[0, i_n, i_h, tidx])\n\n            gSrc_batch = h0_source[(pool_idx, i_hv, None, None)]\n            gSrc = cute.local_tile(gSrc_batch, (TILE_K, TILE_V_SMALL), (0, None))\n            thr_copy_load = tiled_copy_load.get_slice(tidx)\n\n            prefetch_count = 
cutlass.min(NUM_STAGES - 1, num_v_tiles_per_block)\n            for v_tile_offset in range(prefetch_count):\n                v_tile = start_v_tile + v_tile_offset\n                stage = v_tile_offset % NUM_STAGES\n                gSrc_tile = gSrc[(None, None, v_tile)]\n                sData_stage = sData[(None, None, stage)]\n                thr_gSrc = thr_copy_load.partition_S(gSrc_tile)\n                thr_sData = thr_copy_load.partition_D(sData_stage)\n                cute.copy(tiled_copy_load, thr_gSrc, thr_sData)\n                cute.arch.cp_async_commit_group()\n\n            r_A_log = cutlass.Float32(A_log[i_hv])\n            r_dt_bias = cutlass.Float32(dt_bias[i_hv])\n            r_a = cutlass.Float32(a[i_n, i_hv])\n            r_b = cutlass.Float32(b[i_n, i_hv])\n\n            r_g = 0.0\n            r_beta = 0.0\n            if in_warp_tid == 0:\n                x = r_a + r_dt_bias\n                beta_x = softplus_beta * x\n                softplus_x = 0.0\n                if beta_x <= softplus_threshold:\n                    exp_beta_x = cute.exp(beta_x)\n                    log_input = cutlass.Float32(1.0 + exp_beta_x)\n                    log_result = cutlass.Float32(cute.log(log_input))\n                    softplus_x = cutlass.Float32(\n                        (cutlass.Float32(1.0) / softplus_beta) * log_result\n                    )\n                else:\n                    softplus_x = x\n                r_g_value = -cute.exp(r_A_log) * softplus_x\n                r_beta = 1.0 / (1.0 + cute.exp(-r_b))\n                r_g = cute.exp(r_g_value)\n\n            r_g = cute.arch.shuffle_sync(r_g, 0)\n            r_beta = cute.arch.shuffle_sync(r_beta, 0)\n\n            cute.arch.barrier()\n\n            if use_qk_l2norm:\n                sum_q_partial = 0.0\n                sum_k_partial = 0.0\n                if tidx < TILE_K:\n                    q_val = sQ[tidx]\n                    k_val = sK[tidx]\n                    sum_q_partial = q_val * 
q_val\n                    sum_k_partial = k_val * k_val\n\n                for offset in [16, 8, 4, 2, 1]:\n                    sum_q_partial += cute.arch.shuffle_sync_bfly(\n                        sum_q_partial, offset=offset, mask=-1, mask_and_clamp=31\n                    )\n                    sum_k_partial += cute.arch.shuffle_sync_bfly(\n                        sum_k_partial, offset=offset, mask=-1, mask_and_clamp=31\n                    )\n\n                if in_warp_tid == 0:\n                    smem_o[warp_idx] = sum_q_partial\n                    smem_o[warp_idx + 4] = sum_k_partial\n                cute.arch.barrier()\n\n                inv_norm_q = 0.0\n                inv_norm_k = 0.0\n                if warp_idx == 0:\n                    local_sum_q = 0.0\n                    local_sum_k = 0.0\n                    if in_warp_tid < NUM_WARPS_SMALL:\n                        local_sum_q = smem_o[in_warp_tid]\n                        local_sum_k = smem_o[in_warp_tid + 4]\n                    for offset in [2, 1]:\n                        local_sum_q += cute.arch.shuffle_sync_bfly(\n                            local_sum_q, offset=offset, mask=-1, mask_and_clamp=31\n                        )\n                        local_sum_k += cute.arch.shuffle_sync_bfly(\n                            local_sum_k, offset=offset, mask=-1, mask_and_clamp=31\n                        )\n                    if in_warp_tid == 0:\n                        smem_o[0] = cute.rsqrt(local_sum_q + 1e-6)\n                        smem_o[1] = cute.rsqrt(local_sum_k + 1e-6)\n                cute.arch.barrier()\n\n                inv_norm_q = smem_o[0]\n                inv_norm_k = smem_o[1]\n\n                if tidx < TILE_K:\n                    sK[tidx] = sK[tidx] * inv_norm_k\n                    sQ[tidx] = sQ[tidx] * scale * inv_norm_q\n                cute.arch.barrier()\n            else:\n                if tidx < TILE_K:\n                    sQ[tidx] = sQ[tidx] * scale\n     
           cute.arch.barrier()\n\n            for v_tile_offset in range(num_v_tiles_per_block):\n                v_tile = start_v_tile + v_tile_offset\n                stage = v_tile_offset % NUM_STAGES\n\n                cute.arch.cp_async_wait_group(0)\n                cute.arch.barrier()\n\n                next_v_tile_offset = v_tile_offset + prefetch_count\n                if next_v_tile_offset < num_v_tiles_per_block:\n                    next_v_tile = start_v_tile + next_v_tile_offset\n                    next_stage = next_v_tile_offset % NUM_STAGES\n                    gSrc_next = gSrc[(None, None, next_v_tile)]\n                    sData_next = sData[(None, None, next_stage)]\n                    thr_gSrc = thr_copy_load.partition_S(gSrc_next)\n                    thr_sData = thr_copy_load.partition_D(sData_next)\n                    cute.copy(tiled_copy_load, thr_gSrc, thr_sData)\n                    cute.arch.cp_async_commit_group()\n\n                v_global = v_tile * TILE_V_SMALL + v_idx\n                r_v = cutlass.Float32(v[0, i_n, i_hv, v_global])\n\n                sum_hk = 0.0\n                for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):\n                    k_base = k_iter * ROWS_PER_ITER_SMALL\n                    k_idx = k_base + k_local\n                    h_val = sData[(k_idx, v_idx, stage)] * r_g\n                    r_k_val = sK[k_idx]\n                    sum_hk += h_val * r_k_val\n\n                for offset in [4, 2, 1]:\n                    sum_hk += cute.arch.shuffle_sync_bfly(\n                        sum_hk,\n                        offset=offset * V_PER_WARP_SMALL,\n                        mask=-1,\n                        mask_and_clamp=31,\n                    )\n\n                v_new = (r_v - sum_hk) * r_beta\n                v_new = cute.arch.shuffle_sync(v_new, v_local)\n\n                sum_hq = 0.0\n                for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):\n              
      k_base = k_iter * ROWS_PER_ITER_SMALL\n                    k_idx = k_base + k_local\n                    h_old = sData[(k_idx, v_idx, stage)] * r_g\n                    r_k_val = sK[k_idx]\n                    r_q_val = sQ[k_idx]\n                    h_new = h_old + r_k_val * v_new\n                    sData[(k_idx, v_idx, stage)] = h_new\n                    sum_hq += h_new * r_q_val\n\n                for offset in [4, 2, 1]:\n                    sum_hq += cute.arch.shuffle_sync_bfly(\n                        sum_hq,\n                        offset=offset * V_PER_WARP_SMALL,\n                        mask=-1,\n                        mask_and_clamp=31,\n                    )\n\n                if k_local == 0:\n                    v_global_out = v_tile * TILE_V_SMALL + v_idx\n                    o[(0, i_n, i_hv, v_global_out)] = cutlass.BFloat16(sum_hq)\n\n                cute.arch.barrier()\n\n                for k_iter in range(NUM_K_ITERS_SMALL):\n                    flat_idx = tidx + k_iter * 128\n                    k_write = flat_idx // TILE_V_SMALL\n                    v_write = flat_idx % TILE_V_SMALL\n                    if k_write < TILE_K:\n                        h_val = sData[(k_write, v_write, stage)]\n                        v_global_write = v_tile * TILE_V_SMALL + v_write\n                        h0_source[(pool_idx, i_hv, k_write, v_global_write)] = h_val\n\n                cute.arch.barrier()\n\n    @cute.kernel\n    def gdn_kernel_large_batch(\n        tiled_copy_load: cute.TiledCopy,\n        h0_source: cute.Tensor,\n        smem_layout_staged: cute.Layout,\n        num_v_tiles: cutlass.Constexpr[int],\n        q: cute.Tensor,\n        k: cute.Tensor,\n        v: cute.Tensor,\n        a: cute.Tensor,\n        b: cute.Tensor,\n        A_log: cute.Tensor,\n        dt_bias: cute.Tensor,\n        o: cute.Tensor,\n        h0_indices: cute.Tensor,\n        softplus_beta: cutlass.Constexpr[float],\n        softplus_threshold: 
cutlass.Constexpr[float],\n        scale: cutlass.Constexpr[float],\n        H: cutlass.Constexpr[int],\n        HV: cutlass.Constexpr[int],\n        use_qk_l2norm: cutlass.Constexpr[bool],\n    ):\n        \"\"\"Large batch kernel for (N, 1, ...) format.\"\"\"\n        tidx, _, _ = cute.arch.thread_idx()\n        in_warp_tid = tidx % 32\n        warp_idx = cute.arch.warp_idx()\n        warp_idx = cute.arch.make_warp_uniform(warp_idx)\n        batch_idx, _, _ = cute.arch.block_idx()\n        i_n = batch_idx // HV\n        i_hv = batch_idx % HV\n        i_h = i_hv // (HV // H)\n\n        pool_idx = h0_indices[i_n]\n\n        if pool_idx >= 0:\n            k_local = in_warp_tid // V_PER_WARP\n            v_local = in_warp_tid % V_PER_WARP\n            v_base = warp_idx * V_PER_WARP\n            v_idx = v_base + v_local\n\n            smem = cutlass.utils.SmemAllocator()\n            sData = smem.allocate_tensor(cutlass.Float32, smem_layout_staged, 128)\n            smem_o_layout = cute.make_layout((TILE_V,), stride=(1,))\n            smem_o = smem.allocate_tensor(cutlass.Float32, smem_o_layout, 128)\n            smem_k_layout = cute.make_layout((TILE_K,), stride=(1,))\n            smem_q_layout = cute.make_layout((TILE_K,), stride=(1,))\n            sK = smem.allocate_tensor(cutlass.Float32, smem_k_layout, 128)\n            sQ = smem.allocate_tensor(cutlass.Float32, smem_q_layout, 128)\n\n            if tidx < TILE_K:\n                sK[tidx] = cutlass.Float32(k[i_n, 0, i_h, tidx])\n                sQ[tidx] = cutlass.Float32(q[i_n, 0, i_h, tidx])\n\n            gSrc_batch = h0_source[(pool_idx, i_hv, None, None)]\n            gSrc = cute.local_tile(gSrc_batch, (TILE_K, TILE_V), (0, None))\n            thr_copy_load = tiled_copy_load.get_slice(tidx)\n\n            prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles)\n            for v_tile in range(prefetch_count):\n                stage = v_tile % NUM_STAGES\n                gSrc_tile = gSrc[(None, None, 
v_tile)]\n                sData_stage = sData[(None, None, stage)]\n                thr_gSrc = thr_copy_load.partition_S(gSrc_tile)\n                thr_sData = thr_copy_load.partition_D(sData_stage)\n                cute.copy(tiled_copy_load, thr_gSrc, thr_sData)\n                cute.arch.cp_async_commit_group()\n\n            r_A_log = cutlass.Float32(A_log[i_hv])\n            r_dt_bias = cutlass.Float32(dt_bias[i_hv])\n            r_a = cutlass.Float32(a[i_n, 0, i_hv])\n            r_b = cutlass.Float32(b[i_n, 0, i_hv])\n\n            r_g = 0.0\n            r_beta = 0.0\n            if in_warp_tid == 0:\n                x = r_a + r_dt_bias\n                beta_x = softplus_beta * x\n                softplus_x = 0.0\n                if beta_x <= softplus_threshold:\n                    exp_beta_x = cute.exp(beta_x)\n                    log_input = cutlass.Float32(1.0 + exp_beta_x)\n                    log_result = cutlass.Float32(cute.log(log_input))\n                    softplus_x = cutlass.Float32(\n                        (cutlass.Float32(1.0) / softplus_beta) * log_result\n                    )\n                else:\n                    softplus_x = x\n                r_g_value = -cute.exp(r_A_log) * softplus_x\n                r_beta = 1.0 / (1.0 + cute.exp(-r_b))\n                r_g = cute.exp(r_g_value)\n\n            r_g = cute.arch.shuffle_sync(r_g, 0)\n            r_beta = cute.arch.shuffle_sync(r_beta, 0)\n\n            cute.arch.barrier()\n\n            if use_qk_l2norm:\n                sum_q_partial = 0.0\n                sum_k_partial = 0.0\n                if tidx < TILE_K:\n                    q_val = sQ[tidx]\n                    k_val = sK[tidx]\n                    sum_q_partial = q_val * q_val\n                    sum_k_partial = k_val * k_val\n\n                for offset in [16, 8, 4, 2, 1]:\n                    sum_q_partial += cute.arch.shuffle_sync_bfly(\n                        sum_q_partial, offset=offset, mask=-1, 
mask_and_clamp=31\n                    )\n                    sum_k_partial += cute.arch.shuffle_sync_bfly(\n                        sum_k_partial, offset=offset, mask=-1, mask_and_clamp=31\n                    )\n\n                if in_warp_tid == 0:\n                    smem_o[warp_idx] = sum_q_partial\n                    smem_o[warp_idx + 8] = sum_k_partial\n                cute.arch.barrier()\n\n                inv_norm_q = 0.0\n                inv_norm_k = 0.0\n                if warp_idx == 0:\n                    local_sum_q = 0.0\n                    local_sum_k = 0.0\n                    if in_warp_tid < NUM_WARPS_LARGE:\n                        local_sum_q = smem_o[in_warp_tid]\n                        local_sum_k = smem_o[in_warp_tid + 8]\n                    for offset in [4, 2, 1]:\n                        local_sum_q += cute.arch.shuffle_sync_bfly(\n                            local_sum_q, offset=offset, mask=-1, mask_and_clamp=31\n                        )\n                        local_sum_k += cute.arch.shuffle_sync_bfly(\n                            local_sum_k, offset=offset, mask=-1, mask_and_clamp=31\n                        )\n                    if in_warp_tid == 0:\n                        smem_o[0] = cute.rsqrt(local_sum_q + 1e-6)\n                        smem_o[1] = cute.rsqrt(local_sum_k + 1e-6)\n                cute.arch.barrier()\n\n                inv_norm_q = smem_o[0]\n                inv_norm_k = smem_o[1]\n\n                if tidx < TILE_K:\n                    sK[tidx] = sK[tidx] * inv_norm_k\n                    sQ[tidx] = sQ[tidx] * scale * inv_norm_q\n                cute.arch.barrier()\n            else:\n                if tidx < TILE_K:\n                    sQ[tidx] = sQ[tidx] * scale\n                cute.arch.barrier()\n\n            for v_tile in range(num_v_tiles):\n                stage = v_tile % NUM_STAGES\n\n                cute.arch.cp_async_wait_group(0)\n                cute.arch.barrier()\n\n                
next_v_tile = v_tile + prefetch_count\n                if next_v_tile < num_v_tiles:\n                    next_stage = next_v_tile % NUM_STAGES\n                    gSrc_next = gSrc[(None, None, next_v_tile)]\n                    sData_next = sData[(None, None, next_stage)]\n                    thr_gSrc = thr_copy_load.partition_S(gSrc_next)\n                    thr_sData = thr_copy_load.partition_D(sData_next)\n                    cute.copy(tiled_copy_load, thr_gSrc, thr_sData)\n                    cute.arch.cp_async_commit_group()\n\n                v_global = v_tile * TILE_V + v_idx\n                r_v = cutlass.Float32(v[i_n, 0, i_hv, v_global])\n\n                sum_hk = 0.0\n                for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):\n                    k_base = k_iter * ROWS_PER_ITER\n                    k_idx = k_base + k_local\n                    h_val = sData[(k_idx, v_idx, stage)] * r_g\n                    r_k_val = sK[k_idx]\n                    sum_hk += h_val * r_k_val\n\n                for offset in [4, 2, 1]:\n                    sum_hk += cute.arch.shuffle_sync_bfly(\n                        sum_hk, offset=offset * V_PER_WARP, mask=-1, mask_and_clamp=31\n                    )\n\n                v_new = (r_v - sum_hk) * r_beta\n                v_new = cute.arch.shuffle_sync(v_new, v_local)\n\n                sum_hq = 0.0\n                for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):\n                    k_base = k_iter * ROWS_PER_ITER\n                    k_idx = k_base + k_local\n                    h_old = sData[(k_idx, v_idx, stage)] * r_g\n                    r_k_val = sK[k_idx]\n                    r_q_val = sQ[k_idx]\n                    h_new = h_old + r_k_val * v_new\n                    sData[(k_idx, v_idx, stage)] = h_new\n                    sum_hq += h_new * r_q_val\n\n                for offset in [4, 2, 1]:\n                    sum_hq += cute.arch.shuffle_sync_bfly(\n                        sum_hq, 
offset=offset * V_PER_WARP, mask=-1, mask_and_clamp=31\n                    )\n\n                if k_local == 0:\n                    v_global_out = v_tile * TILE_V + v_idx\n                    o[(i_n, 0, i_hv, v_global_out)] = cutlass.BFloat16(sum_hq)\n\n                cute.arch.barrier()\n\n                for k_iter in range(NUM_K_ITERS):\n                    flat_idx = tidx + k_iter * 256\n                    k_write = flat_idx // TILE_V\n                    v_write = flat_idx % TILE_V\n                    if k_write < TILE_K:\n                        h_val = sData[(k_write, v_write, stage)]\n                        v_global_write = v_tile * TILE_V + v_write\n                        h0_source[(pool_idx, i_hv, k_write, v_global_write)] = h_val\n\n                cute.arch.barrier()\n\n    @cute.kernel\n    def gdn_kernel_large_batch_varlen(\n        tiled_copy_load: cute.TiledCopy,\n        h0_source: cute.Tensor,\n        smem_layout_staged: cute.Layout,\n        num_v_tiles: cutlass.Constexpr[int],\n        q: cute.Tensor,\n        k: cute.Tensor,\n        v: cute.Tensor,\n        a: cute.Tensor,\n        b: cute.Tensor,\n        A_log: cute.Tensor,\n        dt_bias: cute.Tensor,\n        o: cute.Tensor,\n        h0_indices: cute.Tensor,\n        softplus_beta: cutlass.Constexpr[float],\n        softplus_threshold: cutlass.Constexpr[float],\n        scale: cutlass.Constexpr[float],\n        H: cutlass.Constexpr[int],\n        HV: cutlass.Constexpr[int],\n        use_qk_l2norm: cutlass.Constexpr[bool],\n    ):\n        \"\"\"Large batch kernel for varlen decode (1, N, ...) 
format.\"\"\"\n        tidx, _, _ = cute.arch.thread_idx()\n        in_warp_tid = tidx % 32\n        warp_idx = cute.arch.warp_idx()\n        warp_idx = cute.arch.make_warp_uniform(warp_idx)\n        batch_idx, _, _ = cute.arch.block_idx()\n        i_n = batch_idx // HV\n        i_hv = batch_idx % HV\n        i_h = i_hv // (HV // H)\n\n        pool_idx = h0_indices[i_n]\n\n        if pool_idx >= 0:\n            k_local = in_warp_tid // V_PER_WARP\n            v_local = in_warp_tid % V_PER_WARP\n            v_base = warp_idx * V_PER_WARP\n            v_idx = v_base + v_local\n\n            smem = cutlass.utils.SmemAllocator()\n            sData = smem.allocate_tensor(cutlass.Float32, smem_layout_staged, 128)\n            smem_o_layout = cute.make_layout((TILE_V,), stride=(1,))\n            smem_o = smem.allocate_tensor(cutlass.Float32, smem_o_layout, 128)\n            smem_k_layout = cute.make_layout((TILE_K,), stride=(1,))\n            smem_q_layout = cute.make_layout((TILE_K,), stride=(1,))\n            sK = smem.allocate_tensor(cutlass.Float32, smem_k_layout, 128)\n            sQ = smem.allocate_tensor(cutlass.Float32, smem_q_layout, 128)\n\n            if tidx < TILE_K:\n                sK[tidx] = cutlass.Float32(k[0, i_n, i_h, tidx])\n                sQ[tidx] = cutlass.Float32(q[0, i_n, i_h, tidx])\n\n            gSrc_batch = h0_source[(pool_idx, i_hv, None, None)]\n            gSrc = cute.local_tile(gSrc_batch, (TILE_K, TILE_V), (0, None))\n            thr_copy_load = tiled_copy_load.get_slice(tidx)\n\n            prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles)\n            for v_tile in range(prefetch_count):\n                stage = v_tile % NUM_STAGES\n                gSrc_tile = gSrc[(None, None, v_tile)]\n                sData_stage = sData[(None, None, stage)]\n                thr_gSrc = thr_copy_load.partition_S(gSrc_tile)\n                thr_sData = thr_copy_load.partition_D(sData_stage)\n                cute.copy(tiled_copy_load, thr_gSrc, 
thr_sData)\n                cute.arch.cp_async_commit_group()\n\n            r_A_log = cutlass.Float32(A_log[i_hv])\n            r_dt_bias = cutlass.Float32(dt_bias[i_hv])\n            r_a = cutlass.Float32(a[i_n, i_hv])\n            r_b = cutlass.Float32(b[i_n, i_hv])\n\n            r_g = 0.0\n            r_beta = 0.0\n            if in_warp_tid == 0:\n                x = r_a + r_dt_bias\n                beta_x = softplus_beta * x\n                softplus_x = 0.0\n                if beta_x <= softplus_threshold:\n                    exp_beta_x = cute.exp(beta_x)\n                    log_input = cutlass.Float32(1.0 + exp_beta_x)\n                    log_result = cutlass.Float32(cute.log(log_input))\n                    softplus_x = cutlass.Float32(\n                        (cutlass.Float32(1.0) / softplus_beta) * log_result\n                    )\n                else:\n                    softplus_x = x\n                r_g_value = -cute.exp(r_A_log) * softplus_x\n                r_beta = 1.0 / (1.0 + cute.exp(-r_b))\n                r_g = cute.exp(r_g_value)\n\n            r_g = cute.arch.shuffle_sync(r_g, 0)\n            r_beta = cute.arch.shuffle_sync(r_beta, 0)\n\n            cute.arch.barrier()\n\n            if use_qk_l2norm:\n                sum_q_partial = 0.0\n                sum_k_partial = 0.0\n                if tidx < TILE_K:\n                    q_val = sQ[tidx]\n                    k_val = sK[tidx]\n                    sum_q_partial = q_val * q_val\n                    sum_k_partial = k_val * k_val\n\n                for offset in [16, 8, 4, 2, 1]:\n                    sum_q_partial += cute.arch.shuffle_sync_bfly(\n                        sum_q_partial, offset=offset, mask=-1, mask_and_clamp=31\n                    )\n                    sum_k_partial += cute.arch.shuffle_sync_bfly(\n                        sum_k_partial, offset=offset, mask=-1, mask_and_clamp=31\n                    )\n\n                if in_warp_tid == 0:\n                    
smem_o[warp_idx] = sum_q_partial\n                    smem_o[warp_idx + 8] = sum_k_partial\n                cute.arch.barrier()\n\n                inv_norm_q = 0.0\n                inv_norm_k = 0.0\n                if warp_idx == 0:\n                    local_sum_q = 0.0\n                    local_sum_k = 0.0\n                    if in_warp_tid < NUM_WARPS_LARGE:\n                        local_sum_q = smem_o[in_warp_tid]\n                        local_sum_k = smem_o[in_warp_tid + 8]\n                    for offset in [4, 2, 1]:\n                        local_sum_q += cute.arch.shuffle_sync_bfly(\n                            local_sum_q, offset=offset, mask=-1, mask_and_clamp=31\n                        )\n                        local_sum_k += cute.arch.shuffle_sync_bfly(\n                            local_sum_k, offset=offset, mask=-1, mask_and_clamp=31\n                        )\n                    if in_warp_tid == 0:\n                        smem_o[0] = cute.rsqrt(local_sum_q + 1e-6)\n                        smem_o[1] = cute.rsqrt(local_sum_k + 1e-6)\n                cute.arch.barrier()\n\n                inv_norm_q = smem_o[0]\n                inv_norm_k = smem_o[1]\n\n                if tidx < TILE_K:\n                    sK[tidx] = sK[tidx] * inv_norm_k\n                    sQ[tidx] = sQ[tidx] * scale * inv_norm_q\n                cute.arch.barrier()\n            else:\n                if tidx < TILE_K:\n                    sQ[tidx] = sQ[tidx] * scale\n                cute.arch.barrier()\n\n            for v_tile in range(num_v_tiles):\n                stage = v_tile % NUM_STAGES\n\n                cute.arch.cp_async_wait_group(0)\n                cute.arch.barrier()\n\n                next_v_tile = v_tile + prefetch_count\n                if next_v_tile < num_v_tiles:\n                    next_stage = next_v_tile % NUM_STAGES\n                    gSrc_next = gSrc[(None, None, next_v_tile)]\n                    sData_next = sData[(None, None, next_stage)]\n 
                   thr_gSrc = thr_copy_load.partition_S(gSrc_next)\n                    thr_sData = thr_copy_load.partition_D(sData_next)\n                    cute.copy(tiled_copy_load, thr_gSrc, thr_sData)\n                    cute.arch.cp_async_commit_group()\n\n                v_global = v_tile * TILE_V + v_idx\n                r_v = cutlass.Float32(v[0, i_n, i_hv, v_global])\n\n                sum_hk = 0.0\n                for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):\n                    k_base = k_iter * ROWS_PER_ITER\n                    k_idx = k_base + k_local\n                    h_val = sData[(k_idx, v_idx, stage)] * r_g\n                    r_k_val = sK[k_idx]\n                    sum_hk += h_val * r_k_val\n\n                for offset in [4, 2, 1]:\n                    sum_hk += cute.arch.shuffle_sync_bfly(\n                        sum_hk, offset=offset * V_PER_WARP, mask=-1, mask_and_clamp=31\n                    )\n\n                v_new = (r_v - sum_hk) * r_beta\n                v_new = cute.arch.shuffle_sync(v_new, v_local)\n\n                sum_hq = 0.0\n                for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):\n                    k_base = k_iter * ROWS_PER_ITER\n                    k_idx = k_base + k_local\n                    h_old = sData[(k_idx, v_idx, stage)] * r_g\n                    r_k_val = sK[k_idx]\n                    r_q_val = sQ[k_idx]\n                    h_new = h_old + r_k_val * v_new\n                    sData[(k_idx, v_idx, stage)] = h_new\n                    sum_hq += h_new * r_q_val\n\n                for offset in [4, 2, 1]:\n                    sum_hq += cute.arch.shuffle_sync_bfly(\n                        sum_hq, offset=offset * V_PER_WARP, mask=-1, mask_and_clamp=31\n                    )\n\n                if k_local == 0:\n                    v_global_out = v_tile * TILE_V + v_idx\n                    o[(0, i_n, i_hv, v_global_out)] = cutlass.BFloat16(sum_hq)\n\n                
cute.arch.barrier()\n\n                for k_iter in range(NUM_K_ITERS):\n                    flat_idx = tidx + k_iter * 256\n                    k_write = flat_idx // TILE_V\n                    v_write = flat_idx % TILE_V\n                    if k_write < TILE_K:\n                        h_val = sData[(k_write, v_write, stage)]\n                        v_global_write = v_tile * TILE_V + v_write\n                        h0_source[(pool_idx, i_hv, k_write, v_global_write)] = h_val\n\n                cute.arch.barrier()\n\n    return (\n        gdn_kernel_small_batch,\n        gdn_kernel_small_batch_varlen,\n        gdn_kernel_large_batch,\n        gdn_kernel_large_batch_varlen,\n    )\n\n\ndef _create_jit_functions():\n    \"\"\"Create JIT-compiled launcher functions for all kernel variants.\"\"\"\n\n    gdn_small, gdn_small_varlen, gdn_large, gdn_large_varlen = _define_kernels()\n\n    @cute.jit\n    def run_small_batch(\n        cu_seqlens: cute.Tensor,\n        q: cute.Tensor,\n        k: cute.Tensor,\n        v: cute.Tensor,\n        a: cute.Tensor,\n        b: cute.Tensor,\n        A_log: cute.Tensor,\n        dt_bias: cute.Tensor,\n        h0_source: cute.Tensor,\n        h0_indices: cute.Tensor,\n        o: cute.Tensor,\n        softplus_beta: cutlass.Constexpr[float],\n        softplus_threshold: cutlass.Constexpr[float],\n        scale: cutlass.Constexpr[float],\n        B: cutlass.Constexpr[int],\n        T: cutlass.Constexpr[int],\n        H: cutlass.Constexpr[int],\n        HV: cutlass.Constexpr[int],\n        K: cutlass.Constexpr[int],\n        V: cutlass.Constexpr[int],\n        use_initial_state: cutlass.Constexpr[bool],\n        use_qk_l2norm: cutlass.Constexpr[bool],\n        stream: cuda.CUstream,\n    ):\n        pool_size, hv_dim, k_dim, v_dim = h0_source.layout.shape\n        n_indices = h0_indices.layout.shape[0]\n        batch_size = n_indices * hv_dim\n\n        copy_atom = cute.make_copy_atom(\n            
cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),\n            cutlass.Float32,\n            num_bits_per_copy=128,\n        )\n        num_v_tiles_small = cute.ceil_div(v_dim, TILE_V_SMALL)\n        smem_layout_small = cute.make_layout(\n            (TILE_K, TILE_V_SMALL, NUM_STAGES),\n            stride=(TILE_V_SMALL_PADDED, 1, TILE_K * TILE_V_SMALL_PADDED),\n        )\n        thread_layout_small = cute.make_layout((32, 4), stride=(4, 1))\n        val_layout_small = cute.make_layout((1, 4))\n        tiled_copy_load_small = cute.make_tiled_copy_tv(\n            copy_atom, thread_layout_small, val_layout_small\n        )\n        smem_bytes_small = (\n            4 * TILE_K * TILE_V_SMALL_PADDED * NUM_STAGES\n            + 4 * TILE_V_SMALL\n            + 4 * TILE_K * 2\n            + 64\n        )\n\n        gdn_small(\n            tiled_copy_load_small,\n            h0_source,\n            smem_layout_small,\n            num_v_tiles_small,\n            q,\n            k,\n            v,\n            a,\n            b,\n            A_log,\n            dt_bias,\n            o,\n            h0_indices,\n            softplus_beta,\n            softplus_threshold,\n            scale,\n            H,\n            HV,\n            use_qk_l2norm,\n        ).launch(\n            grid=(batch_size * NUM_BLOCKS_PER_STATE_SMALL, 1, 1),\n            block=[NUM_THREADS, 1, 1],\n            smem=smem_bytes_small,\n            stream=stream,\n        )\n\n    @cute.jit\n    def run_small_batch_varlen(\n        cu_seqlens: cute.Tensor,\n        q: cute.Tensor,\n        k: cute.Tensor,\n        v: cute.Tensor,\n        a: cute.Tensor,\n        b: cute.Tensor,\n        A_log: cute.Tensor,\n        dt_bias: cute.Tensor,\n        h0_source: cute.Tensor,\n        h0_indices: cute.Tensor,\n        o: cute.Tensor,\n        softplus_beta: cutlass.Constexpr[float],\n        softplus_threshold: cutlass.Constexpr[float],\n        scale: cutlass.Constexpr[float],\n        B: 
cutlass.Constexpr[int],\n        T: cutlass.Constexpr[int],\n        H: cutlass.Constexpr[int],\n        HV: cutlass.Constexpr[int],\n        K: cutlass.Constexpr[int],\n        V: cutlass.Constexpr[int],\n        use_initial_state: cutlass.Constexpr[bool],\n        use_qk_l2norm: cutlass.Constexpr[bool],\n        stream: cuda.CUstream,\n    ):\n        pool_size, hv_dim, k_dim, v_dim = h0_source.layout.shape\n        n_indices = h0_indices.layout.shape[0]\n        batch_size = n_indices * hv_dim\n\n        copy_atom = cute.make_copy_atom(\n            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),\n            cutlass.Float32,\n            num_bits_per_copy=128,\n        )\n        num_v_tiles_small = cute.ceil_div(v_dim, TILE_V_SMALL)\n        smem_layout_small = cute.make_layout(\n            (TILE_K, TILE_V_SMALL, NUM_STAGES),\n            stride=(TILE_V_SMALL_PADDED, 1, TILE_K * TILE_V_SMALL_PADDED),\n        )\n        thread_layout_small = cute.make_layout((32, 4), stride=(4, 1))\n        val_layout_small = cute.make_layout((1, 4))\n        tiled_copy_load_small = cute.make_tiled_copy_tv(\n            copy_atom, thread_layout_small, val_layout_small\n        )\n        smem_bytes_small = (\n            4 * TILE_K * TILE_V_SMALL_PADDED * NUM_STAGES\n            + 4 * TILE_V_SMALL\n            + 4 * TILE_K * 2\n            + 64\n        )\n\n        gdn_small_varlen(\n            tiled_copy_load_small,\n            h0_source,\n            smem_layout_small,\n            num_v_tiles_small,\n            q,\n            k,\n            v,\n            a,\n            b,\n            A_log,\n            dt_bias,\n            o,\n            h0_indices,\n            softplus_beta,\n            softplus_threshold,\n            scale,\n            H,\n            HV,\n            use_qk_l2norm,\n        ).launch(\n            grid=(batch_size * NUM_BLOCKS_PER_STATE_SMALL, 1, 1),\n            block=[NUM_THREADS, 1, 1],\n            
smem=smem_bytes_small,\n            stream=stream,\n        )\n\n    @cute.jit\n    def run_large_batch(\n        cu_seqlens: cute.Tensor,\n        q: cute.Tensor,\n        k: cute.Tensor,\n        v: cute.Tensor,\n        a: cute.Tensor,\n        b: cute.Tensor,\n        A_log: cute.Tensor,\n        dt_bias: cute.Tensor,\n        h0_source: cute.Tensor,\n        h0_indices: cute.Tensor,\n        o: cute.Tensor,\n        softplus_beta: cutlass.Constexpr[float],\n        softplus_threshold: cutlass.Constexpr[float],\n        scale: cutlass.Constexpr[float],\n        B: cutlass.Constexpr[int],\n        T: cutlass.Constexpr[int],\n        H: cutlass.Constexpr[int],\n        HV: cutlass.Constexpr[int],\n        K: cutlass.Constexpr[int],\n        V: cutlass.Constexpr[int],\n        use_initial_state: cutlass.Constexpr[bool],\n        use_qk_l2norm: cutlass.Constexpr[bool],\n        stream: cuda.CUstream,\n    ):\n        pool_size, hv_dim, k_dim, v_dim = h0_source.layout.shape\n        n_indices = h0_indices.layout.shape[0]\n        batch_size = n_indices * hv_dim\n\n        copy_atom = cute.make_copy_atom(\n            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),\n            cutlass.Float32,\n            num_bits_per_copy=128,\n        )\n        num_v_tiles = cute.ceil_div(v_dim, TILE_V)\n        base_smem_layout = cute.make_layout(\n            (TILE_K, TILE_V, NUM_STAGES),\n            stride=(TILE_V_PADDED, 1, TILE_K * TILE_V_PADDED),\n        )\n        thread_layout = cute.make_layout((32, 8), stride=(8, 1))\n        val_layout = cute.make_layout((1, 4))\n        tiled_copy_load = cute.make_tiled_copy_tv(copy_atom, thread_layout, val_layout)\n        smem_bytes = (\n            4 * TILE_K * TILE_V_PADDED * NUM_STAGES + 4 * TILE_V + 4 * TILE_K * 2 + 64\n        )\n\n        gdn_large(\n            tiled_copy_load,\n            h0_source,\n            base_smem_layout,\n            num_v_tiles,\n            q,\n            k,\n            v,\n      
      a,\n            b,\n            A_log,\n            dt_bias,\n            o,\n            h0_indices,\n            softplus_beta,\n            softplus_threshold,\n            scale,\n            H,\n            HV,\n            use_qk_l2norm,\n        ).launch(\n            grid=(batch_size, 1, 1),\n            block=[NUM_THREADS_LARGE, 1, 1],\n            smem=smem_bytes,\n            stream=stream,\n        )\n\n    @cute.jit\n    def run_large_batch_varlen(\n        cu_seqlens: cute.Tensor,\n        q: cute.Tensor,\n        k: cute.Tensor,\n        v: cute.Tensor,\n        a: cute.Tensor,\n        b: cute.Tensor,\n        A_log: cute.Tensor,\n        dt_bias: cute.Tensor,\n        h0_source: cute.Tensor,\n        h0_indices: cute.Tensor,\n        o: cute.Tensor,\n        softplus_beta: cutlass.Constexpr[float],\n        softplus_threshold: cutlass.Constexpr[float],\n        scale: cutlass.Constexpr[float],\n        B: cutlass.Constexpr[int],\n        T: cutlass.Constexpr[int],\n        H: cutlass.Constexpr[int],\n        HV: cutlass.Constexpr[int],\n        K: cutlass.Constexpr[int],\n        V: cutlass.Constexpr[int],\n        use_initial_state: cutlass.Constexpr[bool],\n        use_qk_l2norm: cutlass.Constexpr[bool],\n        stream: cuda.CUstream,\n    ):\n        pool_size, hv_dim, k_dim, v_dim = h0_source.layout.shape\n        n_indices = h0_indices.layout.shape[0]\n        batch_size = n_indices * hv_dim\n\n        copy_atom = cute.make_copy_atom(\n            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),\n            cutlass.Float32,\n            num_bits_per_copy=128,\n        )\n        num_v_tiles = cute.ceil_div(v_dim, TILE_V)\n        base_smem_layout = cute.make_layout(\n            (TILE_K, TILE_V, NUM_STAGES),\n            stride=(TILE_V_PADDED, 1, TILE_K * TILE_V_PADDED),\n        )\n        thread_layout = cute.make_layout((32, 8), stride=(8, 1))\n        val_layout = cute.make_layout((1, 4))\n        tiled_copy_load = 
cute.make_tiled_copy_tv(copy_atom, thread_layout, val_layout)\n        smem_bytes = (\n            4 * TILE_K * TILE_V_PADDED * NUM_STAGES + 4 * TILE_V + 4 * TILE_K * 2 + 64\n        )\n\n        gdn_large_varlen(\n            tiled_copy_load,\n            h0_source,\n            base_smem_layout,\n            num_v_tiles,\n            q,\n            k,\n            v,\n            a,\n            b,\n            A_log,\n            dt_bias,\n            o,\n            h0_indices,\n            softplus_beta,\n            softplus_threshold,\n            scale,\n            H,\n            HV,\n            use_qk_l2norm,\n        ).launch(\n            grid=(batch_size, 1, 1),\n            block=[NUM_THREADS_LARGE, 1, 1],\n            smem=smem_bytes,\n            stream=stream,\n        )\n\n    return (\n        run_small_batch,\n        run_small_batch_varlen,\n        run_large_batch,\n        run_large_batch_varlen,\n    )\n\n\n_jit_functions = None\n\n\ndef _get_jit_functions():\n    global _jit_functions\n    if _jit_functions is None:\n        _jit_functions = _create_jit_functions()\n    return _jit_functions\n\n\ndef _get_compiled_kernel(N, H, HV, K, V, pool_size, use_small_batch, is_varlen_decode):\n    \"\"\"Get or compile the kernel for given dimensions.\"\"\"\n    global _compiled_kernels\n\n    key = (N, H, HV, K, V, pool_size, use_small_batch, is_varlen_decode)\n    if key in _compiled_kernels:\n        return _compiled_kernels[key]\n\n    cu_seqlens = torch.zeros(N + 1, dtype=torch.int32, device=\"cuda\")\n\n    if is_varlen_decode:\n        q = torch.zeros(1, N, H, K, dtype=torch.bfloat16, device=\"cuda\")\n        k = torch.zeros(1, N, H, K, dtype=torch.bfloat16, device=\"cuda\")\n        v = torch.zeros(1, N, HV, V, dtype=torch.bfloat16, device=\"cuda\")\n        a = torch.zeros(N, HV, dtype=torch.bfloat16, device=\"cuda\")\n        b = torch.zeros(N, HV, dtype=torch.bfloat16, device=\"cuda\")\n        o = torch.zeros(1, N, HV, V, 
dtype=torch.bfloat16, device=\"cuda\")\n    else:\n        q = torch.zeros(N, 1, H, K, dtype=torch.bfloat16, device=\"cuda\")\n        k = torch.zeros(N, 1, H, K, dtype=torch.bfloat16, device=\"cuda\")\n        v = torch.zeros(N, 1, HV, V, dtype=torch.bfloat16, device=\"cuda\")\n        a = torch.zeros(N, 1, HV, dtype=torch.bfloat16, device=\"cuda\")\n        b = torch.zeros(N, 1, HV, dtype=torch.bfloat16, device=\"cuda\")\n        o = torch.zeros(N, 1, HV, V, dtype=torch.bfloat16, device=\"cuda\")\n\n    A_log = torch.zeros(HV, dtype=torch.float32, device=\"cuda\")\n    dt_bias = torch.zeros(HV, dtype=torch.bfloat16, device=\"cuda\")\n    h0_source = torch.zeros(pool_size, HV, K, V, dtype=torch.float32, device=\"cuda\")\n    h0_indices = torch.zeros(N, dtype=torch.int32, device=\"cuda\")\n\n    cu_seqlens_tensor = from_dlpack(cu_seqlens, assumed_align=16)\n    q_tensor = from_dlpack(q, assumed_align=16)\n    k_tensor = from_dlpack(k, assumed_align=16)\n    v_tensor = from_dlpack(v, assumed_align=16)\n    a_tensor = from_dlpack(a, assumed_align=16)\n    b_tensor = from_dlpack(b, assumed_align=16)\n    A_log_tensor = from_dlpack(A_log, assumed_align=16)\n    dt_bias_tensor = from_dlpack(dt_bias, assumed_align=16)\n    h0_source_tensor = from_dlpack(h0_source, assumed_align=16)\n    h0_indices_tensor = from_dlpack(h0_indices, assumed_align=16)\n    o_tensor = from_dlpack(o, assumed_align=16)\n\n    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)\n\n    run_small, run_small_varlen, run_large, run_large_varlen = _get_jit_functions()\n\n    if use_small_batch:\n        kernel_func = run_small_varlen if is_varlen_decode else run_small\n    else:\n        kernel_func = run_large_varlen if is_varlen_decode else run_large\n\n    scale = K**-0.5\n    softplus_beta = 1.0\n    softplus_threshold = 20.0\n\n    B_compile = 1 if is_varlen_decode else N\n    T_compile = N if is_varlen_decode else 1\n\n    compiled_kernel = cute.compile(\n        kernel_func,\n      
  cu_seqlens_tensor,\n        q_tensor,\n        k_tensor,\n        v_tensor,\n        a_tensor,\n        b_tensor,\n        A_log_tensor,\n        dt_bias_tensor,\n        h0_source_tensor,\n        h0_indices_tensor,\n        o_tensor,\n        softplus_beta=softplus_beta,\n        softplus_threshold=softplus_threshold,\n        scale=scale,\n        B=B_compile,\n        T=T_compile,\n        H=H,\n        K=K,\n        V=V,\n        HV=HV,\n        use_initial_state=True,\n        use_qk_l2norm=True,\n        stream=stream,\n    )\n\n    _compiled_kernels[key] = compiled_kernel\n    logger.info(\n        f\"CuTe DSL GDN kernel compiled: N={N}, H={H}, HV={HV}, K={K}, V={V}, pool_size={pool_size}, small_batch={use_small_batch}, varlen={is_varlen_decode}\"\n    )\n\n    return compiled_kernel\n\n\ndef cutedsl_fused_sigmoid_gating_delta_rule_update(\n    A_log: torch.Tensor,\n    dt_bias: torch.Tensor,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    a: torch.Tensor,\n    b: torch.Tensor,\n    initial_state_source: torch.Tensor,\n    initial_state_indices: torch.Tensor,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    scale: Optional[float] = None,\n    use_qk_l2norm_in_kernel: bool = True,\n    softplus_beta: float = 1.0,\n    softplus_threshold: float = 20.0,\n) -> torch.Tensor:\n    \"\"\"CuTe DSL implementation of fused sigmoid gating delta rule update.\"\"\"\n\n    B_q, T_q, H, K = q.shape\n    HV = v.shape[2]\n    V = v.shape[3]\n    N = initial_state_indices.shape[0]\n\n    is_varlen_decode = B_q == 1 and T_q == N and N > 1\n    if scale is None:\n        scale = K**-0.5\n\n    use_small_batch = N < SMALL_BATCH_THRESHOLD\n\n    if initial_state_source.dim() == 1:\n        pool_size = initial_state_source.numel() // (HV * K * V)\n        h0_source = initial_state_source.view(pool_size, HV, K, V)\n    elif initial_state_source.dim() == 4:\n        pool_size = initial_state_source.shape[0]\n        h0_source = initial_state_source\n    
else:\n        raise ValueError(\n            f\"Unexpected initial_state_source shape: {initial_state_source.shape}\"\n        )\n\n    if is_varlen_decode:\n        if a.dim() == 3:\n            a = a.squeeze(0)\n        if b.dim() == 3:\n            b = b.squeeze(0)\n        o = q.new_empty(1, N, HV, V, dtype=torch.bfloat16)\n    else:\n        if a.dim() == 2:\n            a = a.unsqueeze(1)\n        if b.dim() == 2:\n            b = b.unsqueeze(1)\n        o = q.new_empty(N, 1, HV, V, dtype=torch.bfloat16)\n\n    q, k, v = [t.contiguous() for t in (q, k, v)]\n\n    global _cu_seqlens_cache\n    if cu_seqlens is not None:\n        cu_seqlens_to_use = cu_seqlens\n    else:\n        cache_key = (N, str(q.device))\n        if cache_key not in _cu_seqlens_cache:\n            _cu_seqlens_cache[cache_key] = torch.arange(\n                N + 1, dtype=torch.int32, device=q.device\n            )\n        cu_seqlens_to_use = _cu_seqlens_cache[cache_key]\n\n    cu_seqlens_tensor = from_dlpack(\n        cu_seqlens_to_use.detach(), assumed_align=16\n    ).mark_layout_dynamic(leading_dim=0)\n    q_tensor = from_dlpack(q.detach(), assumed_align=16).mark_layout_dynamic(\n        leading_dim=q.ndim - 1\n    )\n    k_tensor = from_dlpack(k.detach(), assumed_align=16).mark_layout_dynamic(\n        leading_dim=k.ndim - 1\n    )\n    v_tensor = from_dlpack(v.detach(), assumed_align=16).mark_layout_dynamic(\n        leading_dim=v.ndim - 1\n    )\n    a_tensor = from_dlpack(a.detach(), assumed_align=16).mark_layout_dynamic(\n        leading_dim=a.ndim - 1\n    )\n    b_tensor = from_dlpack(b.detach(), assumed_align=16).mark_layout_dynamic(\n        leading_dim=b.ndim - 1\n    )\n    A_log_tensor = from_dlpack(A_log.detach(), assumed_align=16).mark_layout_dynamic(\n        leading_dim=0\n    )\n    dt_bias_tensor = from_dlpack(\n        dt_bias.detach(), assumed_align=16\n    ).mark_layout_dynamic(leading_dim=0)\n    h0_source_tensor = from_dlpack(\n        h0_source.detach(), 
assumed_align=16\n    ).mark_layout_dynamic(leading_dim=h0_source.ndim - 1)\n    h0_indices_tensor = from_dlpack(\n        initial_state_indices.detach(), assumed_align=16\n    ).mark_layout_dynamic(leading_dim=0)\n    o_tensor = from_dlpack(o.detach(), assumed_align=16).mark_layout_dynamic(\n        leading_dim=o.ndim - 1\n    )\n\n    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)\n\n    # NOTE: `scale`, `softplus_beta`, `softplus_threshold` and\n    # `use_qk_l2norm_in_kernel` are not forwarded here. They are compile-time\n    # constants in `_get_compiled_kernel` (scale=K**-0.5, softplus_beta=1.0,\n    # softplus_threshold=20.0, use_qk_l2norm=True), and the kernel cache key\n    # does not include them.\n    compiled_kernel = _get_compiled_kernel(\n        N, H, HV, K, V, pool_size, use_small_batch, is_varlen_decode\n    )\n\n    compiled_kernel(\n        cu_seqlens_tensor,\n        q_tensor,\n        k_tensor,\n        v_tensor,\n        a_tensor,\n        b_tensor,\n        A_log_tensor,\n        dt_bias_tensor,\n        h0_source_tensor,\n        h0_indices_tensor,\n        o_tensor,\n        stream,\n    )\n\n    return o\n"
  },
  {
    "path": "python/sglang/jit_kernel/diffusion/cutedsl/common/norm_fusion.py",
    "content": "from typing import Optional, Tuple, Union\n\nimport cutlass\nimport cutlass.cute as cute\nimport torch\nfrom einops import rearrange\n\nfrom sglang.jit_kernel.diffusion.cutedsl.common.reduce import (\n    cta_reduce_sum,\n    warp_reduce_sum,\n)\n\n\n@cute.jit\ndef apply_norm_cta(\n    norm_type: cutlass.Constexpr,\n    num_warps: cutlass.Constexpr,\n    tidx: cutlass.Int32,\n    tXrX: cute.Tensor,\n    tWrW: Optional[cute.Tensor],\n    tBrB: Optional[cute.Tensor],\n    D: Union[cutlass.Int32, cutlass.Constexpr],\n    eps: Union[cutlass.Float32, cutlass.Constexpr],\n) -> cute.Tensor:\n    if cutlass.const_expr(norm_type == \"rms\"):\n        return apply_rmsnorm_cta(num_warps, tidx, tXrX, tWrW, D, eps)\n    else:\n        return apply_layernorm_cta(num_warps, tidx, tXrX, tWrW, tBrB, D, eps)\n\n\n@cute.jit\ndef apply_rmsnorm_cta(\n    num_warps: Union[cutlass.Int32, cutlass.Constexpr],\n    tidx: cutlass.Int32,\n    tXrX: cute.Tensor,\n    tWrW: Optional[cute.Tensor],\n    D: Union[cutlass.Int32, cutlass.Constexpr],\n    eps: Union[cutlass.Float32, cutlass.Constexpr],\n) -> cute.Tensor:\n    \"\"\"\n    RMSNorm:\n      y[i] = x[i] / sqrt(sum(x ^ 2) / D + eps) * w[i]\n    \"\"\"\n    val = cute.Float32(0.0)\n    for idx in range(cute.size(tXrX)):\n        # Accumulate in FP32 to improve numerical precision.\n        x_fp32 = tXrX[idx].to(cutlass.Float32)\n        val += x_fp32 * x_fp32\n    val = warp_reduce_sum(val)\n    acc_sq = cta_reduce_sum(val, num_warps, tidx)\n    factor = cute.rsqrt(acc_sq / D + eps)\n    tNrN = cute.make_fragment_like(tXrX)\n    if cutlass.const_expr(isinstance(tWrW, cute.Tensor)):\n        tNrN.store((tXrX.load() * factor * tWrW.load()).to(tNrN.element_type))\n    else:\n        tNrN.store((tXrX.load() * factor).to(tNrN.element_type))\n    return tNrN\n\n\n@cute.jit\ndef apply_layernorm_cta(\n    num_warps: Union[cutlass.Int32, cutlass.Constexpr],\n    tidx: cutlass.Int32,\n    tXrX: cute.Tensor,\n    tWrW: 
Optional[cute.Tensor],\n    tBrB: Optional[cute.Tensor],\n    D: Union[cutlass.Int32, cutlass.Constexpr],\n    eps: Union[cutlass.Float32, cutlass.Constexpr],\n) -> cute.Tensor:\n    \"\"\"\n    LayerNorm:\n        mean = sum(x) / D\n        var  = sum((x - mean) ^ 2) / D\n        y[i] = (x[i] - mean) / sqrt(var + eps) * w[i] + b[i]\n    \"\"\"\n    # Reduce mean\n    val = cute.Float32(0.0)\n    for idx in range(cute.size(tXrX)):\n        # Accumulate in FP32 to improve numerical precision.\n        val += tXrX[idx].to(cutlass.Float32)\n    val = warp_reduce_sum(val)\n    val = cta_reduce_sum(val, num_warps, tidx)\n    mean = val / D\n    # Reduce variance\n    val = cute.Float32(0.0)\n    for idx in range(cute.size(tXrX)):\n        # Accumulate in FP32 to improve numerical precision.\n        x_fp32 = tXrX[idx].to(cutlass.Float32)\n        val += (x_fp32 - mean) * (x_fp32 - mean)\n    val = warp_reduce_sum(val)\n    val = cta_reduce_sum(val, num_warps, tidx)\n    factor = cute.rsqrt(val / D + eps)\n    # Normalize, applying whichever of weight / bias is provided.\n    tNrN = cute.make_fragment_like(tXrX)\n    if cutlass.const_expr(\n        isinstance(tWrW, cute.Tensor) and isinstance(tBrB, cute.Tensor)\n    ):\n        tNrN.store(\n            ((tXrX.load() - mean) * factor * tWrW.load() + tBrB.load()).to(\n                tNrN.element_type\n            )\n        )\n    elif cutlass.const_expr(isinstance(tWrW, cute.Tensor)):\n        tNrN.store(\n            ((tXrX.load() - mean) * factor * tWrW.load()).to(tNrN.element_type)\n        )\n    elif cutlass.const_expr(isinstance(tBrB, cute.Tensor)):\n        tNrN.store(\n            ((tXrX.load() - mean) * factor + tBrB.load()).to(tNrN.element_type)\n        )\n    else:\n        tNrN.store(((tXrX.load() - mean) * factor).to(tNrN.element_type))\n    return tNrN\n\n\n################################################################################\n# BSFD Indexing\n################################################################################\n# In diffusion norm-fusion kernels, we compute `norm(x) + y`, where\n# `x` has shape [B, S, D] and `y` may come in various broadcastable forms:\n#   [1], [D], [1, D], [1, 1, D], [B, D], [B, 1, D], [B, S, D], or [B, F, 1, D].\n#\n# For a given (batch_id, seq_id), the index mapping for `y` falls into 3 cases:\n#   1) Scalar broadcast [1]:\n#        (batch_id, seq_id, 
*) -> (0)\n#   2) Frame-based BSFD broadcast [B, F, 1, D]:\n#        frame_len = S // F\n#        frame_id = seq_id // frame_len\n#        (batch_id, seq_id, *) -> (batch_id, frame_id, *)\n#   3) All other cases:\n#        `y` is broadcast to [B, S, D] (via view/expand, no materialization),\n#        and indexed as (batch_id, seq_id, *).\n#\n# This helper normalizes `y` into a BSFD-compatible view so that kernel\n# indexing logic remains simple and uniform.\n################################################################################\n\n\ndef broadcast_tensor_for_bsfd(\n    tensor: Union[Optional[torch.Tensor], int],\n    B: int,\n    S: int,\n    D: int,\n) -> Union[Optional[torch.Tensor], int]:\n    \"\"\"\n    Broadcast `tensor` to (B, S, D) without a memory copy (view/expand only)\n    for the following shapes:\n    - [D], [1, D], [1, 1, D], [B, D], [B, 1, D], [B, S, D].\n\n    A scalar [1] tensor and a 4-D [B, F, 1, D] tensor are returned unchanged,\n    since the CuTe kernel indexes them specially. Non-tensor values are also\n    returned unchanged.\n    \"\"\"\n\n    # Pass non-tensor values (e.g. None or an int) through unchanged.\n    if not isinstance(tensor, torch.Tensor):\n        return tensor\n\n    if tensor.ndim == 1:\n        # Scalar [1] is preserved as-is and handled specially in the CuTe kernel.\n        if tensor.numel() == 1:\n            return tensor\n        return rearrange(tensor, \"d -> 1 1 d\").expand(B, S, D)\n    if tensor.ndim == 2:\n        return rearrange(tensor, \"b d -> b 1 d\").expand(B, S, D)\n    if tensor.ndim == 3:\n        return tensor.expand(B, S, D)\n    if tensor.ndim == 4:\n        return tensor\n    raise ValueError(f\"BSFD broadcast: unsupported tensor ndim: {tensor.ndim}.\")\n\n\n@cute.jit\ndef tensor_slice_for_bsfd(\n    mV: cute.Tensor,\n    thr_copy: cute.ThrCopy,\n    batch_id: cutlass.Int32,\n    seq_id: cutlass.Int32,\n    S: Union[cutlass.Int32, cutlass.Constexpr],\n    D: Union[cutlass.Int32, cutlass.Constexpr],\n) -> Tuple[cute.Tensor, cute.Tensor]:\n    \"\"\"\n    Slice a BSFD-compatible tensor into a per-thread gmem tile and rmem fragment.\n\n    Given a logical (batch_id, seq_id), this helper selects the corresponding\n    D-length slice from `mV` and prepares it for 
vectorized copy.\n    \"\"\"\n    gV: cute.Tensor\n    if cutlass.const_expr(cute.is_static(mV.layout) and cute.size(mV.layout) == 1):\n        # Build a ((1,1),(1,)) layout so it can broadcast-align with the\n        # regular rmem fragment shape ((4,1),(k,)).\n        layout = cute.make_layout(shape=((1, 1), (1,)))\n        tVgV = cute.make_tensor(mV.iterator, layout)\n        tVrV = cute.make_rmem_tensor(layout, mV.element_type)\n        return tVgV, tVrV\n\n    # Use `local_tile` instead of direct indexing to preserve gmem base pointer\n    # alignment required for vectorized loads.\n    if cutlass.const_expr(len(mV.shape) == 1):\n        gV = mV\n    elif cutlass.const_expr(len(mV.shape) == 3):\n        gV = cute.local_tile(mV, tiler=(1, 1, D), coord=(batch_id, seq_id, 0))\n        gV = gV[0, 0, None]\n    elif cutlass.const_expr(len(mV.shape) == 4):\n        # Compute frame length at runtime (instead of compile time) to avoid\n        # specializing kernels on the frame dimension.\n        frame_len = S // mV.shape[1]\n        frame_id = seq_id // frame_len\n        gV = cute.local_tile(mV, tiler=(1, 1, 1, D), coord=(batch_id, frame_id, 0, 0))\n        gV = gV[0, 0, 0, None]\n    else:\n        raise NotImplementedError(f\"BSFD slice: unsupported shape {mV.shape}.\")\n    tVgV = thr_copy.partition_S(gV)\n    tVrV = cute.make_fragment_like(tVgV, tVgV.element_type)\n    return tVgV, tVrV\n"
  },
  {
    "path": "python/sglang/jit_kernel/diffusion/cutedsl/common/reduce.py",
    "content": "import math\n\nimport cutlass\nimport cutlass.cute as cute\n\n\n@cute.jit\ndef warp_reduce_sum(val: cute.Numeric, reduce_size: int = 32) -> cute.Numeric:\n    iters = int(math.log2(reduce_size))\n    for i in range(iters):\n        val = val + cute.arch.shuffle_sync_down(val, offset=1 << (iters - i - 1))\n    return val\n\n\n@cute.jit\ndef cta_reduce_sum(\n    val: cute.Numeric, num_warps: cutlass.Constexpr, tidx: cutlass.Int32\n) -> cute.Numeric:\n    smem = cutlass.utils.SmemAllocator()\n    acc = smem.allocate_tensor(cutlass.Float32, num_warps + 1)\n    warp_id = tidx >> 5\n    lane_id = tidx & 31\n    if lane_id == 0:\n        acc[warp_id] = val\n    cute.arch.sync_threads()\n    if warp_id == 0:\n        val = acc[lane_id] if lane_id < num_warps else cutlass.Float32(0)\n        val = warp_reduce_sum(val)\n        if lane_id == 0:\n            acc[num_warps] = val\n    cute.arch.sync_threads()\n    val = acc[num_warps]\n    return val\n"
  },
  {
    "path": "python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py",
    "content": "from typing import Optional, Tuple, Union\n\nimport cuda.bindings.driver as cuda\nimport cutlass\nimport cutlass.cute as cute\nimport torch\n\nfrom sglang.jit_kernel.diffusion.cutedsl.common.norm_fusion import (\n    apply_norm_cta,\n    broadcast_tensor_for_bsfd,\n    tensor_slice_for_bsfd,\n)\nfrom sglang.jit_kernel.diffusion.cutedsl.utils import TORCH_TO_CUTE_DTYPE, WARP_SIZE\n\n_COMPILE_CACHE = {}\n\n\ndef to_cute_arg(\n    t,\n    *,\n    assume_aligned: Optional[int] = 32,\n    use_32bit_stride: bool = False,\n    enable_tvm_ffi: bool = True,\n):\n    \"\"\"\n    Convert a Python value into a CuTeDSL value.\n    \"\"\"\n    if isinstance(t, torch.Tensor):\n        return cute.runtime.from_dlpack(\n            t,\n            assumed_align=assume_aligned,\n            use_32bit_stride=use_32bit_stride,\n            enable_tvm_ffi=enable_tvm_ffi,\n        )\n    if isinstance(t, int):\n        return cutlass.Int32(t)\n    if isinstance(t, float):\n        return cutlass.Float32(t)\n    return t\n\n\ndef to_fake_cute_args(t: torch.Tensor):\n    if isinstance(t, torch.Tensor):\n        # Only keep the last dim as compile-time value to maximum compiled kernel reuse\n        # e.g. 
(1,2,1536):(3027,1536,1) -> (?,?,1536):(?,?,1)\n        D = t.shape[-1]\n        dtype = TORCH_TO_CUTE_DTYPE[t.dtype]\n        shape = (*(cute.sym_int() for _ in range(t.ndim - 1)), D)\n        stride = (*(cute.sym_int(divisibility=D) for _ in range(t.ndim - 1)), 1)\n        fake_t = cute.runtime.make_fake_tensor(\n            dtype, shape, stride, memspace=cute.AddressSpace.gmem, assumed_align=32\n        )\n        return fake_t\n    return to_cute_arg(t)\n\n\nclass ScaleResidualNormScaleShift:\n    @classmethod\n    def make_hash_key(cls, *inputs):\n        \"\"\"\n        Compile-time values:\n          - D: hidden dimension (size of the last dimension)\n          - norm_type: layer norm or RMS norm\n          - tensor dtype\n          - tensor rank (i.e., tensor.ndim)\n\n        Runtime values:\n          - all other inputs\n\n        This hash key defines the compile-time specialization boundary for\n        ScaleResidualNormScaleShift kernels.\n        \"\"\"\n\n        def _sig(val):\n            if isinstance(val, torch.Tensor):\n                return (val.dtype, val.ndim, val.shape[-1])\n            return val\n\n        return tuple(_sig(val) for val in inputs)\n\n    def __init__(self, D: int, norm_type: str):\n        self.D = D\n        self.norm_type = norm_type  # \"layer\" or \"rms\"\n        self.num_warps = self.D // 256  # num of warps per cta\n        self.num_threads = self.num_warps * WARP_SIZE  # num of threads per cta\n\n    @cute.jit\n    def __call__(\n        self,\n        mY,\n        mResOut,\n        mRes,\n        mX,\n        mGate,\n        mWeight,\n        mBias,\n        mScale,\n        mShift,\n        eps: cutlass.Float32 = cutlass.Float32(1e-5),\n        stream: cuda.CUstream = cuda.CUstream(cuda.CUstream_flags.CU_STREAM_DEFAULT),\n    ):\n        # Tensor shapes\n        B, S, _ = mX.shape  # (batch, seq_len, hidden_dim)\n        # Vectorized copy configuration\n        num_vectorized = 8  # maximum num of elem per copy\n 
       atom_copy = cute.make_copy_atom(\n            cute.nvgpu.CopyUniversalOp(),\n            mX.element_type,\n            num_bits_per_copy=128,\n        )\n        # Thread/value layouts for tiled copy\n        t_layout = cute.make_layout(self.num_threads)  # thread layout within a CTA\n        v_layout = cute.make_layout(num_vectorized)  # per-thread vector layout\n        tiled_copy = cute.make_tiled_copy_tv(atom_copy, t_layout, v_layout)\n\n        self.kernel(\n            mY,\n            mResOut,\n            mRes,\n            mX,\n            mGate,\n            mWeight,\n            mBias,\n            mScale,\n            mShift,\n            tiled_copy,\n            eps,\n        ).launch(\n            grid=[B * S, 1, 1],\n            block=[self.num_threads, 1, 1],\n            stream=stream,\n        )\n\n    @cute.kernel\n    def kernel(\n        self,\n        mY,\n        mResOut,\n        mRes,\n        mX,\n        mGate,\n        mWeight,\n        mBias,\n        mScale,\n        mShift,\n        tiled_copy: cute.TiledCopy,\n        eps: cutlass.Float32,\n    ):\n        _, S, _ = mX.shape\n        tidx, _, _ = cute.arch.thread_idx()  # thread index\n        bid, _, _ = cute.arch.block_idx()  # cta index\n        bidx = cutlass.Int32(bid // S)  # batch index\n        bidy = cutlass.Int32(bid % S)  # seq_len index\n        thr_copy = tiled_copy.get_slice(tidx)\n\n        @cute.jit\n        def slice_if(mV):\n            if cutlass.const_expr(isinstance(mV, cute.Tensor)):\n                return tensor_slice_for_bsfd(mV, thr_copy, bidx, bidy, S, self.D)\n            return mV, mV\n\n        @cute.jit\n        def copy_if(src, dst):\n            if cutlass.const_expr(\n                isinstance(src, cute.Tensor) and isinstance(dst, cute.Tensor)\n            ):\n                cute.autovec_copy(src, dst)  # LDG.128\n\n        @cute.jit\n        def norm(x, weight, bias):\n            return apply_norm_cta(\n                self.norm_type, 
self.num_warps, tidx, x, weight, bias, self.D, eps\n            )\n\n        # Slice: retrieve the per-thread data slices for both global memory (gmem)\n        # and register memory (rmem). The layouts are:\n        # - ((4,2),(1)):((1,4),(0)) for fp32\n        # - ((8,1),(1)):((1,0),(0)) for fp16/bf16\n        tRgR, tRrR = slice_if(mRes)  # residual\n        tXgX, tXrX = slice_if(mX)  # x\n        tGgG, tGrG = slice_if(mGate)  # gate\n        tROgRO, tROrRO = slice_if(mResOut)  # residual_out\n        tWgW, tWrW = slice_if(mWeight)  # weight\n        tBgB, tBrB = slice_if(mBias)  # bias\n        tSCgSC, tSCrSC = slice_if(mScale)  # scale\n        tSHgSH, tSHrSH = slice_if(mShift)  # shift\n        tYgY, tYrY = slice_if(mY)  # y\n        # Load: load tensor from global memory to registers\n        copy_if(tRgR, tRrR)  # gmem -> rmem\n        copy_if(tXgX, tXrX)  # gmem -> rmem\n        copy_if(tGgG, tGrG)  # gmem -> rmem\n        copy_if(tWgW, tWrW)  # gmem -> rmem\n        copy_if(tBgB, tBrB)  # gmem -> rmem\n\n        # For norm_scale_shift, output:\n        # - y = norm(x, weight, bias) * (1 + scale) + shift\n        # For scale_residual_norm_scale_shift, output:\n        # - residual_out = residual + gate * x\n        # - y = norm(residual_out, weight, bias) * (1 + scale) + shift\n        # Compute: value = <gate> * x\n        value = tXrX.load()\n        if cutlass.const_expr(isinstance(tGrG, cute.Tensor)):\n            value = tGrG.load() * value\n        # Compute: value = value + <residual>\n        if cutlass.const_expr(isinstance(tRrR, cute.Tensor)):\n            value = value + tRrR.load()\n        # Store: residual_out\n        if cutlass.const_expr(isinstance(tROrRO, cute.Tensor)):\n            tROrRO.store(value.to(tROrRO.element_type))\n            copy_if(tROrRO, tROgRO)  # rmem -> gmem\n        # Compute: value = norm(value) * <weight> + <bias>\n        tNrN = cute.make_rmem_tensor_like(tXrX, tXrX.element_type)\n        
tNrN.store(value.to(tNrN.element_type))\n        tNrN = norm(tNrN, tWrW, tBrB)\n        # Compute: value = value * (1 + <scale>) + <shift>\n        value = tNrN.load()\n        copy_if(tSCgSC, tSCrSC)  # gmem -> rmem\n        copy_if(tSHgSH, tSHrSH)  # gmem -> rmem\n        if cutlass.const_expr(isinstance(tSCrSC, cute.Tensor)):\n            value = value * (1 + tSCrSC.load())\n        if cutlass.const_expr(isinstance(tSHrSH, cute.Tensor)):\n            value = value + tSHrSH.load()\n        # Store: y\n        tYrY.store(value.to(tYrY.element_type))\n        copy_if(tYrY, tYgY)  # rmem -> gmem\n\n\ndef validate_x(t: torch.Tensor, B: int, S: int, D: int):\n    if t.dtype not in (torch.float16, torch.bfloat16, torch.float32):\n        raise ValueError(f\"Validate failed: unsupported dtype: {t.dtype}\")\n    if t.shape != (B, S, D):\n        raise ValueError(f\"Validate failed: unsupported tensor shape: {t.shape}.\")\n    if t.stride()[-1] != 1:\n        raise ValueError(f\"Validate failed: not contiguous on dim D.\")\n\n\ndef validate_weight_bias(t: Optional[torch.Tensor], B: int, S: int, D: int):\n    if t is None:\n        return\n    if t.dtype not in (torch.float16, torch.bfloat16, torch.float32):\n        raise ValueError(f\"Validate failed: unsupported dtype: {t.dtype}\")\n    if t.shape != (D,):\n        raise ValueError(f\"Validate failed: unsupported tensor shape: {t.shape}.\")\n    if t.stride()[-1] != 1:\n        raise ValueError(f\"Validate failed: not contiguous on dim D.\")\n\n\ndef validate_scale_shift(t: torch.Tensor, B: int, S: int, D: int):\n    if t.dtype not in (torch.float16, torch.bfloat16, torch.float32):\n        raise ValueError(f\"Validate failed: unsupported dtype: {t.dtype}\")\n    failed = False\n    if t.ndim == 1 and (t.shape[0] not in (1, D)):\n        failed = True\n    elif t.ndim == 2 and ((t.shape[0] not in (1, B)) or t.shape[1] != D):\n        failed = True\n    elif t.ndim == 3 and (\n        (t.shape[0] not in (1, B)) or 
(t.shape[1] not in (1, S) or t.shape[2] != D)\n    ):\n        failed = True\n    elif t.ndim == 4:\n        F = t.shape[1]\n        if t.shape[0] != B or t.shape[2] != 1 or t.shape[3] != D:\n            failed = True\n        elif F == 0 or S % F != 0:\n            # The kernel derives frame_len = S // F, so S must split evenly into F frames.\n            raise ValueError(f\"Validate failed: S({S}) must be divisible by F({F}).\")\n    if failed:\n        raise ValueError(f\"Validate failed: unsupported tensor shape: {t.shape}.\")\n    if t.stride()[-1] != 1:\n        raise ValueError(f\"Validate failed: not contiguous on dim D.\")\n\n\ndef validate_gate(t: Union[torch.Tensor, int], B: int, S: int, D: int):\n    if not isinstance(t, torch.Tensor):\n        return\n    validate_scale_shift(t, B, S, D)\n\n\n@torch.library.custom_op(\"sglang::fused_norm_scale_shift\", mutates_args=())\ndef fused_norm_scale_shift(\n    x: torch.Tensor,\n    weight: Optional[torch.Tensor],\n    bias: Optional[torch.Tensor],\n    scale: torch.Tensor,\n    shift: torch.Tensor,\n    norm_type: str,\n    eps: float = 1e-5,\n) -> torch.Tensor:\n    \"\"\"\n    Fuse: norm(x) * (1 + scale) + shift\n      where norm is either layernorm or rmsnorm.\n\n    Expects:\n      - x: [B, S, D]\n      - weight/bias: None, [D]\n      - scale/shift: [1], [D], [1/B, D], [1/B, 1/S, D] or [B, F, 1, D]\n      - norm_type: str, \"layer\" or \"rms\"\n      - eps: Optional[float], default: 1e-5\n\n    D must be a multiple of 256 and <= 8192 to enable LDG.128 vectorized loads per\n    thread and avoid predicated loads (e.g., bounds checks such as `index < D`).\n    \"\"\"\n    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)\n    # Tensor Validation\n    BSD = x.shape\n    validate_x(x, *BSD)\n    validate_weight_bias(weight, *BSD)\n    validate_weight_bias(bias, *BSD)\n    validate_scale_shift(scale, *BSD)\n    validate_scale_shift(shift, *BSD)\n\n    if norm_type == \"layer\" or norm_type == \"rms\":\n        D = x.shape[-1]\n        if D % 256 != 0 or D > 8192:\n            raise ValueError(\n                
f\"D={D} not supported, must be multiple of 256 and <= 8192\"\n            )\n        y = torch.empty_like(x)  # create output tensor\n        scale = broadcast_tensor_for_bsfd(scale, *x.shape)  # handle various shapes\n        shift = broadcast_tensor_for_bsfd(shift, *x.shape)  # handle various shapes\n        # Use scalar placeholders for None tensors as a workaround, since the CuTe DSL\n        # TVM-FFI backend does not support None parameters. scalar values do not result\n        # in code generation and have no impact on runtime performance.\n        weight = 1 if weight is None else weight\n        bias = 0 if bias is None else bias\n        ResOut, Residual, Gate = 0, 0, 1\n        torch_tensors = [y, ResOut, Residual, x, Gate, weight, bias, scale, shift]\n        # Compile cache\n        hash_key = ScaleResidualNormScaleShift.make_hash_key(norm_type, *torch_tensors)\n        compiled_fn = _COMPILE_CACHE.get(hash_key)\n        if compiled_fn is None:\n            kernel = ScaleResidualNormScaleShift(D, norm_type)\n            fake_sig_args = [to_fake_cute_args(t) for t in torch_tensors]\n            compiled_fn = cute.compile(\n                kernel, *fake_sig_args, options=\"--enable-tvm-ffi\"\n            )\n            _COMPILE_CACHE[hash_key] = compiled_fn\n        # Execute\n        compiled_fn(*torch_tensors, eps, stream)\n        return y\n    else:\n        raise ValueError(f'norm_type must be one of \"layer\" and \"rms\"')\n\n\n@fused_norm_scale_shift.register_fake\ndef _fused_norm_scale_shift_fake(x, weight, bias, scale, shift, norm_type, eps=1e-5):\n    y = x.new_empty(x.shape)\n    return y\n\n\n@torch.library.custom_op(\n    \"sglang::fused_scale_residual_norm_scale_shift\", mutates_args=()\n)\ndef fused_scale_residual_norm_scale_shift(\n    residual: torch.Tensor,\n    x: torch.Tensor,\n    gate: Optional[torch.Tensor],  # Union[Optional[torch.Tensor], int] indeed\n    weight: Optional[torch.Tensor],\n    bias: Optional[torch.Tensor],\n    
scale: torch.Tensor,\n    shift: torch.Tensor,\n    norm_type: str,\n    eps: float = 1e-5,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Fuse: norm(residual + gate * x) * (1 + scale) + shift\n      where norm is either layernorm or rmsnorm.\n\n    Expects:\n      - residual, x: [B, S, D]\n      - gate: None, [1], [D], [1/B, D], [1/B, 1/S, D] or [B, F, 1, D]\n      - weight/bias: None, [D]\n      - scale/shift: [1], [D], [1/B, D], [1/B, 1/S, D] or [B, F, 1, D]\n      - norm_type: str, \"layer\" or \"rms\"\n      - eps: Optional[float], default: 1e-5\n\n    D must be a multiple of 256 and <= 8192 to enable LDG.128 vectorized loads per\n    thread and avoid predicated loads (e.g., bounds checks such as `index < D`).\n    \"\"\"\n    # Tensor Validation\n    BSD = x.shape\n    validate_x(x, *BSD)\n    validate_x(residual, *BSD)\n    validate_gate(gate, *BSD)\n    validate_weight_bias(weight, *BSD)\n    validate_weight_bias(bias, *BSD)\n    validate_scale_shift(scale, *BSD)\n    validate_scale_shift(shift, *BSD)\n    if norm_type == \"layer\" or norm_type == \"rms\":\n        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)\n        D = x.shape[-1]\n        if D % 256 != 0 or D > 8192:\n            raise ValueError(\n                f\"D={D} not supported, must be multiple of 256 and <= 8192\"\n            )\n        y = torch.empty_like(x)  # create output tensor\n        resi_out = torch.empty_like(x)  # create residual output tensor\n        gate = broadcast_tensor_for_bsfd(gate, *x.shape)  # handle various shapes\n        scale = broadcast_tensor_for_bsfd(scale, *x.shape)  # handle various shapes\n        shift = broadcast_tensor_for_bsfd(shift, *x.shape)  # handle various shapes\n        # Use scalar placeholders for None tensors as a workaround, since the CuTe DSL\n        # TVM-FFI backend does not support None parameters.
scalar values do not result\n        # in code generation and have no impact on runtime performance.\n        gate = 1 if gate is None else gate\n        weight = 1 if weight is None else weight\n        bias = 0 if bias is None else bias\n        torch_tensors = [y, resi_out, residual, x, gate, weight, bias, scale, shift]\n        # Compile cache\n        hash_key = ScaleResidualNormScaleShift.make_hash_key(norm_type, *torch_tensors)\n        compiled_fn = _COMPILE_CACHE.get(hash_key)\n        if compiled_fn is None:\n            kernel = ScaleResidualNormScaleShift(D, norm_type)\n            fake_sig_args = [to_fake_cute_args(t) for t in torch_tensors]\n            compiled_fn = cute.compile(\n                kernel, *fake_sig_args, options=\"--enable-tvm-ffi\"\n            )\n            _COMPILE_CACHE[hash_key] = compiled_fn\n        # Execute\n        compiled_fn(*torch_tensors, eps, stream)\n        return y, resi_out\n    else:\n        raise ValueError(f'norm_type must be one of \"layer\" and \"rms\"')\n\n\n@fused_scale_residual_norm_scale_shift.register_fake\ndef _fused_scale_residual_norm_scale_shift_fake(\n    residual, x, gate, weight, bias, scale, shift, norm_type, eps=1e-5\n):\n    y = x.new_empty(x.shape)\n    residual_out = x.new_empty(x.shape)\n    return y, residual_out\n"
  },
  {
    "path": "python/sglang/jit_kernel/diffusion/cutedsl/utils.py",
    "content": "import cutlass\nimport torch\n\nWARP_SIZE = 32\n\nTORCH_TO_CUTE_DTYPE = {\n    torch.float16: cutlass.Float16,\n    torch.bfloat16: cutlass.BFloat16,\n    torch.float32: cutlass.Float32,\n}\n"
  },
  {
    "path": "python/sglang/jit_kernel/diffusion/triton/mps_fallback.py",
    "content": "\"\"\"MPS (Apple Silicon) fallbacks for Triton diffusion kernels.\n\nTriton is not available on macOS / Metal, so these pure-PyTorch (and\noptionally MLX-accelerated) implementations replace the Triton kernels\nat import time when ``current_platform.is_mps()`` is True.\n\nMLX acceleration (opt-in via ``SGLANG_USE_MLX=1``):\n    Norm ops use ``mx.fast.rms_norm`` / ``mx.fast.layer_norm`` — single fused\n    Metal kernels that are 1.4x–2.9x faster than the multi-step PyTorch MPS\n    decomposition for medium-to-large tensors.\n\"\"\"\n\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom sglang.srt.environ import envs\n\n# MLX acceleration – opt-in via SGLANG_USE_MLX=1\n_MLX_AVAILABLE = False\ntry:\n    import mlx.core as mx\n\n    _MLX_AVAILABLE = True\nexcept ImportError:\n    pass\n\n_USE_MLX = envs.SGLANG_USE_MLX.get() and _MLX_AVAILABLE\n\n# Dtype mapping for torch <-> MLX tensor bridge\n_TORCH_TO_MLX_DTYPE = (\n    {\n        torch.float32: mx.float32,\n        torch.float16: mx.float16,\n        torch.bfloat16: mx.bfloat16,\n    }\n    if _MLX_AVAILABLE\n    else {}\n)\n\n_MLX_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_TO_MLX_DTYPE.items()}\n\n\ndef _torch_to_mlx(tensor: torch.Tensor) -> \"mx.array\":\n    \"\"\"Convert a PyTorch tensor to an MLX array (via numpy on CPU).\"\"\"\n    t = tensor.cpu().detach()\n    if t.dtype == torch.bfloat16:\n        return mx.array(t.float().numpy(), dtype=mx.bfloat16)\n    return mx.array(t.numpy())\n\n\ndef _mlx_to_torch(array: \"mx.array\", device: torch.device) -> torch.Tensor:\n    \"\"\"Convert an MLX array to a PyTorch tensor (zero-copy via memoryview).\"\"\"\n    torch_dtype = _MLX_TO_TORCH_DTYPE.get(array.dtype, torch.float32)\n    array = mx.contiguous(array)\n    mx.eval(array)\n    tensor = torch.frombuffer(memoryview(array), dtype=torch_dtype).reshape(array.shape)\n    if device.type == \"mps\":\n        tensor = tensor.to(device)\n    return tensor\n\n\ndef 
fuse_scale_shift_kernel_native(\n    x: torch.Tensor,\n    scale: torch.Tensor,\n    shift: torch.Tensor,\n    scale_constant: float = 1.0,\n    block_l: int = 128,\n    block_c: int = 128,\n):\n    \"\"\"Native fallback for fuse_scale_shift_kernel with scale_constant support.\"\"\"\n    B, L, C = x.shape\n\n    def _expand(t: torch.Tensor) -> torch.Tensor:\n        if t.dim() == 4:\n            # [B, F, 1, C] -> [B, L, C]\n            num_frames = t.shape[1]\n            frame_seqlen = L // num_frames\n            return (\n                t.squeeze(2)\n                .unsqueeze(2)\n                .expand(-1, -1, frame_seqlen, -1)\n                .reshape(B, L, C)\n            )\n        elif t.dim() == 2:\n            # [B, C] -> [B, 1, C]\n            return t.unsqueeze(1)\n        return t\n\n    scale = _expand(scale)\n    shift = _expand(shift)\n\n    return x * (scale_constant + scale) + shift\n\n\ndef fuse_scale_shift_gate_select01_kernel_native(\n    x: torch.Tensor,\n    scale0: torch.Tensor,\n    shift0: torch.Tensor,\n    gate0: torch.Tensor,\n    scale1: torch.Tensor,\n    shift1: torch.Tensor,\n    gate1: torch.Tensor,\n    index: torch.Tensor,\n    block_l: int = 128,\n    block_c: int = 128,\n):\n    \"\"\"Native fallback for fuse_scale_shift_gate_select01_kernel.\"\"\"\n    idx = index.unsqueeze(-1).bool()\n    scale = torch.where(idx, scale1.unsqueeze(1), scale0.unsqueeze(1))\n    shift = torch.where(idx, shift1.unsqueeze(1), shift0.unsqueeze(1))\n    gate = torch.where(idx, gate1.unsqueeze(1), gate0.unsqueeze(1))\n    y = x * (1 + scale) + shift\n    return y, gate\n\n\ndef apply_rotary_embedding_native(\n    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False\n) -> torch.Tensor:\n    \"\"\"Native fallback for rotary embedding (shared with NPU implementation).\"\"\"\n    cos = cos.unsqueeze(-2).to(x.dtype)\n    sin = sin.unsqueeze(-2).to(x.dtype)\n    x1 = x[..., ::2]\n    x2 = x[..., 1::2]\n    o1 = x1 * cos - x2 
* sin\n    o2 = x2 * cos + x1 * sin\n    return torch.stack((o1, o2), dim=-1).flatten(-2)\n\n\ndef norm_infer_native(\n    x: Tensor,\n    weight: Optional[Tensor],\n    bias: Optional[Tensor],\n    eps: float,\n    is_rms_norm: bool = False,\n    out: Optional[Tensor] = None,\n) -> Tensor:\n    \"\"\"Native fallback for norm_infer (layer norm / rms norm inference).\"\"\"\n    orig_dtype = x.dtype\n    x = x.contiguous().float()\n    if is_rms_norm:\n        variance = x.pow(2).mean(dim=-1, keepdim=True)\n        x_hat = x * torch.rsqrt(variance + eps)\n    else:\n        mean = x.mean(dim=-1, keepdim=True)\n        variance = (x - mean).pow(2).mean(dim=-1, keepdim=True)\n        x_hat = (x - mean) * torch.rsqrt(variance + eps)\n    if weight is not None:\n        x_hat = x_hat * weight.float()\n    if bias is not None:\n        x_hat = x_hat + bias.float()\n    result = x_hat.to(orig_dtype)\n    if out is not None:\n        out.copy_(result)\n        return out\n    return result\n\n\ndef triton_one_pass_rms_norm_native(\n    x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6\n) -> torch.Tensor:\n    \"\"\"Native fallback for triton_one_pass_rms_norm.\"\"\"\n    shape = x.shape\n    orig_dtype = x.dtype\n    x = x.contiguous().float()\n    variance = x.pow(2).mean(dim=-1, keepdim=True)\n    x_hat = x * torch.rsqrt(variance + eps)\n    return (x_hat * w.float()).to(orig_dtype).view(shape)\n\n\ndef rms_norm_fn_native(\n    x,\n    weight,\n    bias,\n    residual=None,\n    x1=None,\n    weight1=None,\n    bias1=None,\n    eps=1e-6,\n    dropout_p=0.0,\n    rowscale=None,\n    prenorm=False,\n    residual_in_fp32=False,\n    zero_centered_weight=False,\n    return_dropout_mask=False,\n    out_dtype=None,\n    out=None,\n    residual_out=None,\n):\n    \"\"\"Native fallback for rms_norm_fn (inference only, no dropout/x1 support).\"\"\"\n    x_shape_og = x.shape\n    orig_dtype = x.dtype\n    x = x.reshape(-1, x.shape[-1]).float()\n    if residual is not None:\n      
  residual = residual.reshape(-1, residual.shape[-1]).float()\n        x = x + residual\n        residual_out_val = x.to(torch.float32 if residual_in_fp32 else orig_dtype)\n    else:\n        residual_out_val = None\n    variance = x.pow(2).mean(dim=-1, keepdim=True)\n    x_hat = x * torch.rsqrt(variance + eps)\n    if weight is not None:\n        w = weight.float()\n        if zero_centered_weight:\n            w = w + 1.0\n        x_hat = x_hat * w\n    if bias is not None:\n        x_hat = x_hat + bias.float()\n    final_dtype = out_dtype if out_dtype is not None else orig_dtype\n    y = x_hat.to(final_dtype).reshape(x_shape_og)\n    if residual is not None and residual_out_val is not None:\n        return y, residual_out_val.reshape(x_shape_og)\n    return y\n\n\n# MLX-accelerated norm ops (1.4x–2.9x faster than torch native on MPS)\n# Uses mx.fast.rms_norm / mx.fast.layer_norm — single fused Metal kernels\n# instead of 7+ separate PyTorch MPS kernel launches.\n\nif _USE_MLX:\n\n    def norm_infer_native(  # noqa: F811\n        x: Tensor,\n        weight: Optional[Tensor],\n        bias: Optional[Tensor],\n        eps: float,\n        is_rms_norm: bool = False,\n        out: Optional[Tensor] = None,\n    ) -> Tensor:\n        \"\"\"MLX-accelerated norm_infer (layer norm / rms norm inference).\"\"\"\n        device = x.device\n        orig_dtype = x.dtype\n        x_mx = _torch_to_mlx(x)\n        if is_rms_norm:\n            w_mx = (\n                _torch_to_mlx(weight) if weight is not None else mx.ones(x_mx.shape[-1])\n            )\n            result_mx = mx.fast.rms_norm(x_mx, w_mx, eps)\n        else:\n            w_mx = _torch_to_mlx(weight) if weight is not None else None\n            b_mx = _torch_to_mlx(bias) if bias is not None else None\n            result_mx = mx.fast.layer_norm(x_mx, w_mx, b_mx, eps)\n        result = _mlx_to_torch(result_mx, device).to(orig_dtype)\n        if out is not None:\n            out.copy_(result)\n            return 
out\n        return result\n\n    def triton_one_pass_rms_norm_native(  # noqa: F811\n        x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6\n    ) -> torch.Tensor:\n        \"\"\"MLX-accelerated triton_one_pass_rms_norm.\"\"\"\n        shape = x.shape\n        device = x.device\n        orig_dtype = x.dtype\n        x_mx = _torch_to_mlx(x.reshape(-1, x.shape[-1]))\n        w_mx = _torch_to_mlx(w)\n        result_mx = mx.fast.rms_norm(x_mx, w_mx, eps)\n        return _mlx_to_torch(result_mx, device).to(orig_dtype).view(shape)\n\n    def rms_norm_fn_native(  # noqa: F811\n        x,\n        weight,\n        bias,\n        residual=None,\n        x1=None,\n        weight1=None,\n        bias1=None,\n        eps=1e-6,\n        dropout_p=0.0,\n        rowscale=None,\n        prenorm=False,\n        residual_in_fp32=False,\n        zero_centered_weight=False,\n        return_dropout_mask=False,\n        out_dtype=None,\n        out=None,\n        residual_out=None,\n    ):\n        \"\"\"MLX-accelerated rms_norm_fn (inference only, no dropout/x1 support).\"\"\"\n        x_shape_og = x.shape\n        device = x.device\n        orig_dtype = x.dtype\n        x_flat = x.reshape(-1, x.shape[-1])\n        if residual is not None:\n            residual = residual.reshape(-1, residual.shape[-1]).float()\n            x_flat = x_flat.float() + residual\n            residual_out_val = x_flat.to(\n                torch.float32 if residual_in_fp32 else orig_dtype\n            )\n        else:\n            residual_out_val = None\n        if weight is not None and zero_centered_weight:\n            w = weight.float() + 1.0\n        else:\n            w = weight\n        x_mx = _torch_to_mlx(x_flat)\n        w_mx = _torch_to_mlx(w) if w is not None else mx.ones(x_mx.shape[-1])\n        result_mx = mx.fast.rms_norm(x_mx, w_mx, eps)\n        x_hat = _mlx_to_torch(result_mx, device)\n        if bias is not None:\n            x_hat = x_hat + bias.to(x_hat.device, x_hat.dtype)\n      
  final_dtype = out_dtype if out_dtype is not None else orig_dtype\n        y = x_hat.to(final_dtype).reshape(x_shape_og)\n        if residual is not None and residual_out_val is not None:\n            return y, residual_out_val.reshape(x_shape_og)\n        return y\n"
  },
  {
    "path": "python/sglang/jit_kernel/diffusion/triton/norm.py",
    "content": "from typing import Optional\n\nimport torch\nimport triton  # type: ignore\nimport triton.language as tl  # type: ignore\nfrom torch import Tensor\n\n\n# RMSNorm-fp32\ndef maybe_contiguous_lastdim(x):\n    return x.contiguous() if x is not None and x.stride(-1) != 1 else x\n\n\ndef maybe_contiguous(x):\n    return x.contiguous() if x is not None else None\n\n\ndef triton_autotune_configs():\n    # Return configs with a valid warp count for the current device\n    configs = []\n    # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024\n    max_threads_per_block = 1024\n    # Default to warp size 32 if not defined by device\n    warp_size = getattr(\n        torch.get_device_module().get_device_properties(\n            torch.get_device_module().current_device()\n        ),\n        \"warp_size\",\n        32,\n    )\n    if warp_size is None:\n        warp_size = 32\n    # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit\n    return [\n        triton.Config({}, num_warps=warp_count)\n        for warp_count in [1, 2, 4, 8, 16, 32]\n        if warp_count * warp_size <= max_threads_per_block\n    ]\n    # return [triton.Config({}, num_warps=8)]\n\n\n# Copied from flash-attn\n@triton.autotune(\n    configs=triton_autotune_configs(),\n    key=[\n        \"N\",\n        \"HAS_RESIDUAL\",\n        \"STORE_RESIDUAL_OUT\",\n        \"IS_RMS_NORM\",\n        \"HAS_BIAS\",\n        \"HAS_WEIGHT\",\n        \"HAS_X1\",\n        \"HAS_W1\",\n        \"HAS_B1\",\n    ],\n)\n# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel\n# @triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n# @triton.heuristics({\"HAS_RESIDUAL\": lambda args: args[\"RESIDUAL\"] is not None})\n# @triton.heuristics({\"HAS_X1\": lambda args: args[\"X1\"] is not None})\n# @triton.heuristics({\"HAS_W1\": lambda args: args[\"W1\"] is not None})\n# 
@triton.heuristics({\"HAS_B1\": lambda args: args[\"B1\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n    X,  # pointer to the input\n    Y,  # pointer to the output\n    W,  # pointer to the weights\n    B,  # pointer to the biases\n    RESIDUAL,  # pointer to the residual\n    X1,\n    W1,\n    B1,\n    Y1,\n    RESIDUAL_OUT,  # pointer to the residual\n    ROWSCALE,\n    SEEDS,  # Dropout seeds for each row\n    DROPOUT_MASK,\n    DROPOUT_MASK1,\n    Mean,  # pointer to the mean\n    Rstd,  # pointer to the 1/std\n    stride_x_row,  # how much to increase the pointer when moving by 1 row\n    stride_y_row,\n    stride_res_row,\n    stride_res_out_row,\n    stride_x1_row,\n    stride_y1_row,\n    M,  # number of rows in X\n    N,  # number of columns in X\n    eps,  # epsilon to avoid division by zero\n    dropout_p,  # Dropout probability\n    zero_centered_weight,  # If true, add 1.0 to the weight\n    IS_RMS_NORM: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    HAS_RESIDUAL: tl.constexpr,\n    STORE_RESIDUAL_OUT: tl.constexpr,\n    HAS_WEIGHT: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n    HAS_DROPOUT: tl.constexpr,\n    STORE_DROPOUT_MASK: tl.constexpr,\n    HAS_ROWSCALE: tl.constexpr,\n    HAS_X1: tl.constexpr,\n    HAS_W1: tl.constexpr,\n    HAS_B1: tl.constexpr,\n):\n    # Map the program id to the row of X and Y it should compute.\n    row = tl.program_id(0)\n    X += row * stride_x_row\n    Y += row * stride_y_row\n    if HAS_RESIDUAL:\n        RESIDUAL += row * stride_res_row\n    if STORE_RESIDUAL_OUT:\n        RESIDUAL_OUT += row * stride_res_out_row\n    if HAS_X1:\n        X1 += row * stride_x1_row\n    if HAS_W1:\n        Y1 += row * stride_y1_row\n    # Compute mean and variance\n    cols = tl.arange(0, BLOCK_N)\n    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n    if HAS_ROWSCALE:\n        rowscale = tl.load(ROWSCALE + row).to(tl.float32)\n        x *= rowscale\n    if HAS_DROPOUT:\n        # Compute dropout 
mask\n        # 7 rounds is good enough, and reduces register pressure\n        keep_mask = (\n            tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n        )\n        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)\n        if STORE_DROPOUT_MASK:\n            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)\n    if HAS_X1:\n        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)\n        if HAS_ROWSCALE:\n            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)\n            x1 *= rowscale\n        if HAS_DROPOUT:\n            # Compute dropout mask\n            # 7 rounds is good enough, and reduces register pressure\n            keep_mask = (\n                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)\n                > dropout_p\n            )\n            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)\n            if STORE_DROPOUT_MASK:\n                tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N)\n        x += x1\n    if HAS_RESIDUAL:\n        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n        x += residual\n    if STORE_RESIDUAL_OUT:\n        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n    if not IS_RMS_NORM:\n        mean = tl.sum(x, axis=0) / N\n        tl.store(Mean + row, mean)\n        xbar = tl.where(cols < N, x - mean, 0.0)\n        var = tl.sum(xbar * xbar, axis=0) / N\n    else:\n        xbar = tl.where(cols < N, x, 0.0)\n        var = tl.sum(xbar * xbar, axis=0) / N\n    rstd = 1 / tl.sqrt(var + eps)\n    tl.store(Rstd + row, rstd)\n    # Normalize and apply linear transformation\n    mask = cols < N\n    if HAS_WEIGHT:\n        w = tl.load(W + cols, mask=mask).to(tl.float32)\n        if zero_centered_weight:\n            w += 1.0\n    if HAS_BIAS:\n        b = tl.load(B + cols, mask=mask).to(tl.float32)\n    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n    if 
HAS_WEIGHT:\n        y = x_hat * w + b if HAS_BIAS else x_hat * w\n    else:\n        y = x_hat + b if HAS_BIAS else x_hat\n    # Write output\n    tl.store(Y + cols, y, mask=mask)\n    if HAS_W1:\n        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)\n        if zero_centered_weight:\n            w1 += 1.0\n        if HAS_B1:\n            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)\n        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1\n        tl.store(Y1 + cols, y1, mask=mask)\n\n\ndef _layer_norm_fwd(\n    x: Tensor,\n    weight: Tensor,\n    bias: Tensor,\n    eps: float,\n    residual: Optional[Tensor] = None,\n    x1: Optional[Tensor] = None,\n    weight1: Optional[Tensor] = None,\n    bias1: Optional[Tensor] = None,\n    dropout_p: float = 0.0,\n    rowscale: Optional[Tensor] = None,\n    out_dtype: Optional[torch.dtype] = None,\n    residual_dtype: Optional[torch.dtype] = None,\n    zero_centered_weight: bool = False,\n    is_rms_norm: bool = False,\n    return_dropout_mask: bool = False,\n    out: Optional[Tensor] = None,\n    residual_out: Optional[Tensor] = None,\n) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):\n    # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library\n    # and torch.compile unhappy. 
Also allocate memory for out and residual_out if they are None\n    # so that _layer_norm_fwd_impl doesn't have to return them.\n    if out is None:\n        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n    if residual is not None:\n        residual_dtype = residual.dtype\n    if residual_out is None and (\n        residual is not None\n        or (residual_dtype is not None and residual_dtype != x.dtype)\n        or dropout_p > 0.0\n        or rowscale is not None\n        or x1 is not None\n    ):\n        residual_out = torch.empty_like(\n            x, dtype=residual_dtype if residual_dtype is not None else x.dtype\n        )\n    else:\n        residual_out = None\n    y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl(\n        x,\n        weight,\n        bias,\n        eps,\n        out,\n        residual=residual,\n        x1=x1,\n        weight1=weight1,\n        bias1=bias1,\n        dropout_p=dropout_p,\n        rowscale=rowscale,\n        zero_centered_weight=zero_centered_weight,\n        is_rms_norm=is_rms_norm,\n        return_dropout_mask=return_dropout_mask,\n        residual_out=residual_out,\n    )\n    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0\n    if residual_out is None:\n        residual_out = x\n    return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1\n\n\n# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema\n# since we're returning a tuple of tensors\ndef _layer_norm_fwd_impl(\n    x: Tensor,\n    weight: Optional[Tensor],\n    bias: Tensor,\n    eps: float,\n    out: Tensor,\n    residual: Optional[Tensor] = None,\n    x1: Optional[Tensor] = None,\n    weight1: Optional[Tensor] = None,\n    bias1: Optional[Tensor] = None,\n    dropout_p: float = 0.0,\n    rowscale: Optional[Tensor] = None,\n    zero_centered_weight: bool = False,\n    is_rms_norm: bool = False,\n 
   return_dropout_mask: bool = False,\n    residual_out: Optional[Tensor] = None,\n) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):\n    M, N = x.shape\n    assert x.stride(-1) == 1\n    if residual is not None:\n        assert residual.stride(-1) == 1\n        assert residual.shape == (M, N)\n    if weight is not None:\n        assert weight.shape == (N,)\n        assert weight.stride(-1) == 1\n    if bias is not None:\n        assert bias.stride(-1) == 1\n        assert bias.shape == (N,)\n    if x1 is not None:\n        assert x1.shape == x.shape\n        assert rowscale is None\n        assert x1.stride(-1) == 1\n    if weight1 is not None:\n        assert weight1.shape == (N,)\n        assert weight1.stride(-1) == 1\n    if bias1 is not None:\n        assert bias1.shape == (N,)\n        assert bias1.stride(-1) == 1\n    if rowscale is not None:\n        assert rowscale.is_contiguous()\n        assert rowscale.shape == (M,)\n    assert out.shape == x.shape\n    assert out.stride(-1) == 1\n    if residual_out is not None:\n        assert residual_out.shape == x.shape\n        assert residual_out.stride(-1) == 1\n    if weight1 is not None:\n        y1 = torch.empty_like(out)\n        assert y1.stride(-1) == 1\n    else:\n        y1 = None\n    mean = (\n        torch.empty((M,), dtype=torch.float32, device=x.device)\n        if not is_rms_norm\n        else None\n    )\n    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)\n    if dropout_p > 0.0:\n        seeds = torch.randint(\n            2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64\n        )\n    else:\n        seeds = None\n    if return_dropout_mask and dropout_p > 0.0:\n        dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool)\n        if x1 is not None:\n            dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool)\n        else:\n            dropout_mask1 = None\n    else:\n        dropout_mask, dropout_mask1 = None, 
None\n    # Less than 64KB per feature: enqueue fused kernel\n    MAX_FUSED_SIZE = 65536 // x.element_size()\n    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n    if N > BLOCK_N:\n        raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n    with torch.get_device_module().device(x.device.index):\n        torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)](\n            x,\n            out,\n            weight if weight is not None else x,  # unused when HAS_WEIGHT == False\n            bias,\n            residual,\n            x1,\n            weight1,\n            bias1,\n            y1,\n            residual_out,\n            rowscale,\n            seeds,\n            dropout_mask,\n            dropout_mask1,\n            mean,\n            rstd,\n            x.stride(0),\n            out.stride(0),\n            residual.stride(0) if residual is not None else 0,\n            residual_out.stride(0) if residual_out is not None else 0,\n            x1.stride(0) if x1 is not None else 0,\n            y1.stride(0) if y1 is not None else 0,\n            M,\n            N,\n            eps,\n            dropout_p,\n            # Passing bool make torch inductor very unhappy since it then tries to compare to int_max\n            int(zero_centered_weight),\n            is_rms_norm,\n            BLOCK_N,\n            residual is not None,\n            residual_out is not None,\n            weight is not None,\n            bias is not None,\n            dropout_p > 0.0,\n            dropout_mask is not None,\n            rowscale is not None,\n            HAS_X1=x1 is not None,\n            HAS_W1=weight1 is not None,\n            HAS_B1=bias1 is not None,\n        )\n    return y1, mean, rstd, seeds, dropout_mask, dropout_mask1\n\n\nclass LayerNormFn:\n\n    @staticmethod\n    def forward(\n        x,\n        weight,\n        bias,\n        residual=None,\n        x1=None,\n        weight1=None,\n        bias1=None,\n     
   eps=1e-6,\n        dropout_p=0.0,\n        rowscale=None,\n        prenorm=False,\n        residual_in_fp32=False,\n        zero_centered_weight=False,\n        is_rms_norm=False,\n        return_dropout_mask=False,\n        out_dtype=None,\n        out=None,\n        residual_out=None,\n    ):\n        x_shape_og = x.shape\n        # reshape input data into 2D tensor\n        x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))\n        if residual is not None:\n            assert residual.shape == x_shape_og\n            residual = maybe_contiguous_lastdim(\n                residual.reshape(-1, residual.shape[-1])\n            )\n        if x1 is not None:\n            assert x1.shape == x_shape_og\n            assert rowscale is None, \"rowscale is not supported with parallel LayerNorm\"\n            x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1]))\n        # weight can be None when elementwise_affine=False for LayerNorm\n        if weight is not None:\n            weight = weight.contiguous()\n        bias = maybe_contiguous(bias)\n        weight1 = maybe_contiguous(weight1)\n        bias1 = maybe_contiguous(bias1)\n        if rowscale is not None:\n            rowscale = rowscale.reshape(-1).contiguous()\n        residual_dtype = (\n            residual.dtype\n            if residual is not None\n            else (torch.float32 if residual_in_fp32 else None)\n        )\n        if out is not None:\n            out = out.reshape(-1, out.shape[-1])\n        if residual_out is not None:\n            residual_out = residual_out.reshape(-1, residual_out.shape[-1])\n        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (\n            _layer_norm_fwd(\n                x,\n                weight,\n                bias,\n                eps,\n                residual,\n                x1,\n                weight1,\n                bias1,\n                dropout_p=dropout_p,\n                rowscale=rowscale,\n                
out_dtype=out_dtype,\n                residual_dtype=residual_dtype,\n                zero_centered_weight=zero_centered_weight,\n                is_rms_norm=is_rms_norm,\n                return_dropout_mask=return_dropout_mask,\n                out=out,\n                residual_out=residual_out,\n            )\n        )\n        y = y.reshape(x_shape_og)\n        if residual is not None:\n            residual_out = residual_out.reshape(x_shape_og)\n            return y, residual_out\n        return y\n\n\ndef layer_norm_fn(\n    x,\n    weight,\n    bias,\n    residual=None,\n    x1=None,\n    weight1=None,\n    bias1=None,\n    eps=1e-6,\n    dropout_p=0.0,\n    rowscale=None,\n    prenorm=False,\n    residual_in_fp32=False,\n    zero_centered_weight=False,\n    is_rms_norm=False,\n    return_dropout_mask=False,\n    out_dtype=None,\n    out=None,\n    residual_out=None,\n):\n    return LayerNormFn.forward(\n        x,\n        weight,\n        bias,\n        residual,\n        x1,\n        weight1,\n        bias1,\n        eps,\n        dropout_p,\n        rowscale,\n        prenorm,\n        residual_in_fp32,\n        zero_centered_weight,\n        is_rms_norm,\n        return_dropout_mask,\n        out_dtype,\n        out,\n        residual_out,\n    )\n\n\n@triton.jit\ndef _norm_infer_kernel(\n    X,\n    Y,\n    W,\n    B,\n    stride_x_row,\n    stride_y_row,\n    M,\n    N,\n    eps,\n    IS_RMS_NORM: tl.constexpr,\n    HAS_WEIGHT: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    row = tl.program_id(0)\n    X += row * stride_x_row\n    Y += row * stride_y_row\n    if HAS_WEIGHT:\n        W += 0\n    if HAS_BIAS:\n        B += 0\n    cols = tl.arange(0, BLOCK_N)\n    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n    if not IS_RMS_NORM:\n        mean = tl.sum(x, axis=0) / N\n        xbar = tl.where(cols < N, x - mean, 0.0)\n        var = tl.sum(xbar * xbar, axis=0) / N\n    else:\n        xbar = 
tl.where(cols < N, x, 0.0)\n        var = tl.sum(xbar * xbar, axis=0) / N\n    rstd = 1 / tl.sqrt(var + eps)\n    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n    if HAS_WEIGHT:\n        w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32)\n        y = x_hat * w\n    else:\n        y = x_hat\n    if HAS_BIAS:\n        b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32)\n        y += b\n    tl.store(Y + cols, y, mask=cols < N)\n\n\ndef norm_infer(\n    x: Tensor,\n    weight: Optional[Tensor],\n    bias: Optional[Tensor],\n    eps: float,\n    is_rms_norm: bool = False,\n    out: Optional[Tensor] = None,\n):\n    M, N = x.shape\n    x = x.contiguous()\n    if weight is not None:\n        assert weight.shape == (N,)\n        assert weight.stride(-1) == 1\n    if bias is not None:\n        assert bias.shape == (N,)\n        assert bias.stride(-1) == 1\n    if out is None:\n        out = torch.empty_like(x)\n    MAX_FUSED_SIZE = 65536 // x.element_size()\n    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n    if N > BLOCK_N:\n        raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n    num_warps = min(max(BLOCK_N // 256, 1), 8)\n    _norm_infer_kernel[(M,)](\n        x,\n        out,\n        weight if weight is not None else x,  # dummy when HAS_WEIGHT=False\n        bias if bias is not None else x,  # dummy when HAS_BIAS=False\n        x.stride(0),\n        out.stride(0),\n        M,\n        N,\n        eps,\n        IS_RMS_NORM=is_rms_norm,\n        HAS_WEIGHT=weight is not None,\n        HAS_BIAS=bias is not None,\n        BLOCK_N=BLOCK_N,\n        num_warps=num_warps,\n    )\n    return out\n\n\ndef rms_norm_fn(\n    x,\n    weight,\n    bias,\n    residual=None,\n    x1=None,\n    weight1=None,\n    bias1=None,\n    eps=1e-6,\n    dropout_p=0.0,\n    rowscale=None,\n    prenorm=False,\n    residual_in_fp32=False,\n    zero_centered_weight=False,\n    return_dropout_mask=False,\n    
out_dtype=None,\n    out=None,\n    residual_out=None,\n):\n    return LayerNormFn.forward(\n        x,\n        weight,\n        bias,\n        residual,\n        x1,\n        weight1,\n        bias1,\n        eps,\n        dropout_p,\n        rowscale,\n        prenorm,\n        residual_in_fp32,\n        zero_centered_weight,\n        True,\n        return_dropout_mask,\n        out_dtype,\n        out,\n        residual_out,\n    )\n\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\n\nif current_platform.is_mps():\n    from .mps_fallback import norm_infer_native, rms_norm_fn_native\n\n    norm_infer = norm_infer_native\n    rms_norm_fn = rms_norm_fn_native\n"
  },
  {
    "path": "python/sglang/jit_kernel/diffusion/triton/npu_fallback.py",
    "content": "import torch\n\n\n# TODO: remove this when triton ascend bug is fixed\ndef fuse_scale_shift_native(\n    x: torch.Tensor,\n    scale: torch.Tensor,\n    shift: torch.Tensor,\n    block_l: int = 128,\n    block_c: int = 128,\n):\n    return x * (1 + scale) + shift\n\n\n# TODO: remove this when triton ascend bug is fixed\ndef apply_rotary_embedding_native(\n    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False\n) -> torch.Tensor:\n    cos = cos.unsqueeze(-2).to(x.dtype)\n    sin = sin.unsqueeze(-2).to(x.dtype)\n    x1 = x[..., ::2]\n    x2 = x[..., 1::2]\n    o1 = x1 * cos - x2 * sin\n    o2 = x2 * cos + x1 * sin\n    return torch.stack((o1, o2), dim=-1).flatten(-2)\n"
  },
  {
    "path": "python/sglang/jit_kernel/diffusion/triton/rmsnorm_onepass.py",
    "content": "import torch\nimport triton  # type: ignore\nimport triton.language as tl  # type: ignore\n\n\n# Adapted from https://github.com/ModelTC/LightX2V/blob/main/lightx2v/common/ops/norm/triton_ops.py#L905-L956\n@triton.jit\ndef _rms_norm_tiled_onepass(\n    y_ptr,\n    x_ptr,\n    w_ptr,\n    SEQ: tl.constexpr,\n    DIM: tl.constexpr,\n    EPS: tl.constexpr,\n    BLOCK_SIZE_SEQ: tl.constexpr,\n    BLOCK_SIZE_DIM: tl.constexpr,\n):\n    seq_blk_id = tl.program_id(0)\n    seq_id = seq_blk_id * BLOCK_SIZE_SEQ\n\n    seq_offset = seq_id + tl.arange(0, BLOCK_SIZE_SEQ)[:, None]\n    s_mask = seq_offset < SEQ\n    d_offset = tl.arange(0, BLOCK_SIZE_DIM)[None, :]\n    d_mask = d_offset < DIM\n    y_blk = y_ptr + seq_offset * DIM + d_offset\n    x_blk = x_ptr + seq_offset * DIM + d_offset\n    mask = s_mask & d_mask\n\n    x = tl.load(x_blk, mask=mask, other=0.0).to(tl.float32)\n    mean_square = tl.sum(x * x, axis=1, keep_dims=True) / DIM\n    rstd = tl.math.rsqrt(mean_square + EPS)\n    w = tl.load(w_ptr + d_offset, mask=d_mask)\n    tl.store(y_blk, x * rstd * w, mask=mask)\n\n\ndef triton_one_pass_rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):\n    shape = x.shape\n    x = x.contiguous()\n    y = torch.empty_like(x)\n    x_view = x.reshape(-1, shape[-1])\n    y_view = y.reshape(-1, shape[-1])\n    S, D = x_view.shape\n\n    BLOCK_SIZE_SEQ = min(16, triton.next_power_of_2(max(1, S // 512)))\n    grid = (triton.cdiv(S, BLOCK_SIZE_SEQ),)\n\n    with torch.get_device_module().device(x.device):\n        torch.library.wrap_triton(_rms_norm_tiled_onepass)[grid](\n            y_view,\n            x_view,\n            w,\n            S,\n            D,\n            eps,\n            BLOCK_SIZE_DIM=triton.next_power_of_2(D),\n            BLOCK_SIZE_SEQ=BLOCK_SIZE_SEQ,\n        )\n    return y\n\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\n\nif current_platform.is_mps():\n    from .mps_fallback import 
triton_one_pass_rms_norm_native\n\n    triton_one_pass_rms_norm = triton_one_pass_rms_norm_native\n"
  },
  {
    "path": "python/sglang/jit_kernel/diffusion/triton/rotary.py",
    "content": "import torch\nimport triton  # type: ignore\nimport triton.language as tl  # type: ignore\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({\"BLOCK_HS_HALF\": 32}, num_warps=2),\n        triton.Config({\"BLOCK_HS_HALF\": 64}, num_warps=4),\n        triton.Config({\"BLOCK_HS_HALF\": 128}, num_warps=4),\n        triton.Config({\"BLOCK_HS_HALF\": 256}, num_warps=8),\n    ],\n    key=[\"head_size\", \"interleaved\"],\n)\n@triton.jit\ndef _rotary_embedding_kernel(\n    output_ptr,\n    x_ptr,\n    cos_ptr,\n    sin_ptr,\n    num_heads,\n    head_size,\n    num_tokens,\n    stride_x_row,\n    stride_cos_row,\n    stride_sin_row,\n    interleaved: tl.constexpr,\n    BLOCK_HS_HALF: tl.constexpr,\n):\n    row_idx = tl.program_id(0)\n    token_idx = (row_idx // num_heads) % num_tokens\n\n    x_row_ptr = x_ptr + row_idx * stride_x_row\n    cos_row_ptr = cos_ptr + token_idx * stride_cos_row\n    sin_row_ptr = sin_ptr + token_idx * stride_sin_row\n    output_row_ptr = output_ptr + row_idx * stride_x_row\n\n    # half size for x1 and x2\n    head_size_half = head_size // 2\n\n    for block_start in range(0, head_size_half, BLOCK_HS_HALF):\n        offsets_half = block_start + tl.arange(0, BLOCK_HS_HALF)\n        mask = offsets_half < head_size_half\n\n        cos_vals = tl.load(cos_row_ptr + offsets_half, mask=mask, other=0.0)\n        sin_vals = tl.load(sin_row_ptr + offsets_half, mask=mask, other=0.0)\n\n        offsets_x1 = 2 * offsets_half\n        offsets_x2 = 2 * offsets_half + 1\n\n        x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask, other=0.0)\n        x2_vals = tl.load(x_row_ptr + offsets_x2, mask=mask, other=0.0)\n\n        x1_fp32 = x1_vals.to(tl.float32)\n        x2_fp32 = x2_vals.to(tl.float32)\n        cos_fp32 = cos_vals.to(tl.float32)\n        sin_fp32 = sin_vals.to(tl.float32)\n        o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32)\n        
o2_vals = tl.fma(x1_fp32, sin_fp32, x2_fp32 * cos_fp32)\n\n        tl.store(output_row_ptr + offsets_x1, o1_vals.to(x1_vals.dtype), mask=mask)\n        tl.store(output_row_ptr + offsets_x2, o2_vals.to(x2_vals.dtype), mask=mask)\n\n\ndef apply_rotary_embedding(\n    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False\n) -> torch.Tensor:\n    output = torch.empty_like(x)\n\n    if x.dim() > 3:\n        bsz, num_tokens, num_heads, head_size = x.shape\n    else:\n        num_tokens, num_heads, head_size = x.shape\n        bsz = 1\n\n    assert head_size % 2 == 0, \"head_size must be divisible by 2\"\n\n    x_reshaped = x.view(-1, head_size)\n    output_reshaped = output.view(-1, head_size)\n\n    # one program per (batch, token, head) row of x_reshaped\n    grid = (bsz * num_tokens * num_heads,)\n\n    if interleaved and cos.shape[-1] == head_size:\n        cos = cos[..., ::2].contiguous()\n        sin = sin[..., ::2].contiguous()\n    else:\n        cos = cos.contiguous()\n        sin = sin.contiguous()\n\n    _rotary_embedding_kernel[grid](\n        output_reshaped,\n        x_reshaped,\n        cos,\n        sin,\n        num_heads,\n        head_size,\n        num_tokens,\n        x_reshaped.stride(0),\n        cos.stride(0),\n        sin.stride(0),\n        interleaved,\n    )\n\n    return output\n\n\nif current_platform.is_npu():\n    from .npu_fallback import apply_rotary_embedding_native\n\n    apply_rotary_embedding = apply_rotary_embedding_native\n\nif current_platform.is_mps():\n    from .mps_fallback import apply_rotary_embedding_native\n\n    apply_rotary_embedding = apply_rotary_embedding_native\n"
  },
  {
    "path": "python/sglang/jit_kernel/diffusion/triton/scale_shift.py",
    "content": "import torch\nimport triton  # type: ignore\nimport triton.language as tl  # type: ignore\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\n\n\n@triton.jit\ndef _fused_layernorm_scale_shift_gate_select01_kernel(\n    output_ptr,\n    gate_out_ptr,\n    x_ptr,\n    weight_ptr,\n    bias_ptr,\n    scale0_ptr,\n    shift0_ptr,\n    gate0_ptr,\n    scale1_ptr,\n    shift1_ptr,\n    gate1_ptr,\n    index_ptr,\n    inner_dim,\n    seq_len,\n    stride_x_row,\n    stride_out_row,\n    stride_go_row,\n    stride_w,\n    stride_b,\n    stride_s0_b,\n    stride_s0_c,\n    stride_sh0_b,\n    stride_sh0_c,\n    stride_g0_b,\n    stride_g0_c,\n    stride_s1_b,\n    stride_s1_c,\n    stride_sh1_b,\n    stride_sh1_c,\n    stride_g1_b,\n    stride_g1_c,\n    stride_i_b,\n    stride_i_l,\n    eps,\n    HAS_WEIGHT: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    row = tl.program_id(0)\n    cols = tl.arange(0, BLOCK_N)\n    mask = cols < inner_dim\n\n    x_row_ptr = x_ptr + row * stride_x_row\n    out_row_ptr = output_ptr + row * stride_out_row\n    gate_row_ptr = gate_out_ptr + row * stride_go_row\n\n    x = tl.load(x_row_ptr + cols, mask=mask, other=0.0).to(tl.float32)\n    mean = tl.sum(x, axis=0) / inner_dim\n    xbar = tl.where(mask, x - mean, 0.0)\n    var = tl.sum(xbar * xbar, axis=0) / inner_dim\n    rstd = tl.rsqrt(var + eps)\n    x_hat = (x - mean) * rstd\n\n    if HAS_WEIGHT:\n        w = tl.load(weight_ptr + cols * stride_w, mask=mask, other=1.0).to(tl.float32)\n        x_hat = x_hat * w\n    if HAS_BIAS:\n        b = tl.load(bias_ptr + cols * stride_b, mask=mask, other=0.0).to(tl.float32)\n        x_hat = x_hat + b\n\n    batch_idx = row // seq_len\n    seq_idx = row % seq_len\n    idx = tl.load(index_ptr + batch_idx * stride_i_b + seq_idx * stride_i_l).to(tl.int1)\n\n    scale0 = tl.load(\n        scale0_ptr + batch_idx * stride_s0_b + cols * stride_s0_c,\n        mask=mask,\n        other=0.0,\n    
).to(tl.float32)\n    shift0 = tl.load(\n        shift0_ptr + batch_idx * stride_sh0_b + cols * stride_sh0_c,\n        mask=mask,\n        other=0.0,\n    ).to(tl.float32)\n    gate0 = tl.load(\n        gate0_ptr + batch_idx * stride_g0_b + cols * stride_g0_c,\n        mask=mask,\n        other=0.0,\n    )\n\n    scale1 = tl.load(\n        scale1_ptr + batch_idx * stride_s1_b + cols * stride_s1_c,\n        mask=mask,\n        other=0.0,\n    ).to(tl.float32)\n    shift1 = tl.load(\n        shift1_ptr + batch_idx * stride_sh1_b + cols * stride_sh1_c,\n        mask=mask,\n        other=0.0,\n    ).to(tl.float32)\n    gate1 = tl.load(\n        gate1_ptr + batch_idx * stride_g1_b + cols * stride_g1_c,\n        mask=mask,\n        other=0.0,\n    )\n\n    scale = tl.where(idx, scale1, scale0)\n    shift = tl.where(idx, shift1, shift0)\n    gate = tl.where(idx, gate1, gate0)\n    y = x_hat * (1.0 + scale) + shift\n\n    tl.store(out_row_ptr + cols, y, mask=mask)\n    tl.store(gate_row_ptr + cols, gate, mask=mask)\n\n\n@triton.jit\ndef _fused_residual_layernorm_scale_shift_gate_select01_kernel(\n    output_ptr,\n    residual_out_ptr,\n    gate_out_ptr,\n    x_ptr,\n    residual_ptr,\n    residual_gate_ptr,\n    weight_ptr,\n    bias_ptr,\n    scale0_ptr,\n    shift0_ptr,\n    gate0_ptr,\n    scale1_ptr,\n    shift1_ptr,\n    gate1_ptr,\n    index_ptr,\n    inner_dim,\n    seq_len,\n    stride_x_row,\n    stride_res_row,\n    stride_rg_row,\n    stride_out_row,\n    stride_res_out_row,\n    stride_go_row,\n    stride_w,\n    stride_b,\n    stride_s0_b,\n    stride_s0_c,\n    stride_sh0_b,\n    stride_sh0_c,\n    stride_g0_b,\n    stride_g0_c,\n    stride_s1_b,\n    stride_s1_c,\n    stride_sh1_b,\n    stride_sh1_c,\n    stride_g1_b,\n    stride_g1_c,\n    stride_i_b,\n    stride_i_l,\n    eps,\n    HAS_WEIGHT: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    row = tl.program_id(0)\n    cols = tl.arange(0, BLOCK_N)\n    mask = cols < 
inner_dim\n\n    x_row_ptr = x_ptr + row * stride_x_row\n    res_row_ptr = residual_ptr + row * stride_res_row\n    rg_row_ptr = residual_gate_ptr + row * stride_rg_row\n    out_row_ptr = output_ptr + row * stride_out_row\n    res_out_row_ptr = residual_out_ptr + row * stride_res_out_row\n    gate_row_ptr = gate_out_ptr + row * stride_go_row\n\n    x = tl.load(x_row_ptr + cols, mask=mask, other=0.0).to(tl.float32)\n    residual = tl.load(res_row_ptr + cols, mask=mask, other=0.0).to(tl.float32)\n    residual_gate = tl.load(rg_row_ptr + cols, mask=mask, other=0.0).to(tl.float32)\n    residual_out = residual + residual_gate * x\n    tl.store(res_out_row_ptr + cols, residual_out, mask=mask)\n\n    mean = tl.sum(residual_out, axis=0) / inner_dim\n    xbar = tl.where(mask, residual_out - mean, 0.0)\n    var = tl.sum(xbar * xbar, axis=0) / inner_dim\n    rstd = tl.rsqrt(var + eps)\n    x_hat = (residual_out - mean) * rstd\n\n    if HAS_WEIGHT:\n        w = tl.load(weight_ptr + cols * stride_w, mask=mask, other=1.0).to(tl.float32)\n        x_hat = x_hat * w\n    if HAS_BIAS:\n        b = tl.load(bias_ptr + cols * stride_b, mask=mask, other=0.0).to(tl.float32)\n        x_hat = x_hat + b\n\n    batch_idx = row // seq_len\n    seq_idx = row % seq_len\n    idx = tl.load(index_ptr + batch_idx * stride_i_b + seq_idx * stride_i_l).to(tl.int1)\n\n    scale0 = tl.load(\n        scale0_ptr + batch_idx * stride_s0_b + cols * stride_s0_c,\n        mask=mask,\n        other=0.0,\n    ).to(tl.float32)\n    shift0 = tl.load(\n        shift0_ptr + batch_idx * stride_sh0_b + cols * stride_sh0_c,\n        mask=mask,\n        other=0.0,\n    ).to(tl.float32)\n    gate0 = tl.load(\n        gate0_ptr + batch_idx * stride_g0_b + cols * stride_g0_c,\n        mask=mask,\n        other=0.0,\n    )\n\n    scale1 = tl.load(\n        scale1_ptr + batch_idx * stride_s1_b + cols * stride_s1_c,\n        mask=mask,\n        other=0.0,\n    ).to(tl.float32)\n    shift1 = tl.load(\n        shift1_ptr + 
batch_idx * stride_sh1_b + cols * stride_sh1_c,\n        mask=mask,\n        other=0.0,\n    ).to(tl.float32)\n    gate1 = tl.load(\n        gate1_ptr + batch_idx * stride_g1_b + cols * stride_g1_c,\n        mask=mask,\n        other=0.0,\n    )\n\n    scale = tl.where(idx, scale1, scale0)\n    shift = tl.where(idx, shift1, shift0)\n    gate = tl.where(idx, gate1, gate0)\n    y = x_hat * (1.0 + scale) + shift\n\n    tl.store(out_row_ptr + cols, y, mask=mask)\n    tl.store(gate_row_ptr + cols, gate, mask=mask)\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({\"BLOCK_N\": 64}, num_warps=2),\n        triton.Config({\"BLOCK_N\": 128}, num_warps=4),\n        triton.Config({\"BLOCK_N\": 256}, num_warps=4),\n        triton.Config({\"BLOCK_N\": 512}, num_warps=4),\n        triton.Config({\"BLOCK_N\": 1024}, num_warps=8),\n    ],\n    key=[\"inner_dim\"],\n)\n@triton.jit\ndef _fused_scale_shift_4d_kernel(\n    output_ptr,\n    normalized_ptr,\n    scale_ptr,\n    shift_ptr,\n    scale_constant: tl.constexpr,  # scale_constant is either 0 or 1.\n    rows,\n    inner_dim,\n    seq_len,\n    num_frames,\n    frame_seqlen,\n    BLOCK_N: tl.constexpr,\n):\n    pid_row = tl.program_id(0)\n    pid_col = tl.program_id(1)\n\n    col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)\n    mask = col_offsets < inner_dim\n\n    # Pointers for normalized and output\n    row_base = pid_row * inner_dim\n    norm_ptrs = normalized_ptr + row_base + col_offsets\n    out_ptrs = output_ptr + row_base + col_offsets\n\n    # Pointers for scale (per-frame) and shift (per-token)\n    b_idx = pid_row // seq_len\n    t_idx = pid_row % seq_len\n    frame_idx_in_batch = t_idx // frame_seqlen\n\n    scale_row_idx = b_idx * num_frames + frame_idx_in_batch\n    scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets\n    # shift is per-token [B*L, C], indexed by pid_row directly\n    shift_ptrs = shift_ptr + pid_row * inner_dim + col_offsets\n\n    normalized = tl.load(norm_ptrs, 
mask=mask, other=0.0)\n    scale = tl.load(scale_ptrs, mask=mask, other=0.0)\n    shift = tl.load(shift_ptrs, mask=mask, other=0.0)\n\n    scale_const_tensor = tl.full([BLOCK_N], scale_constant, dtype=scale.dtype)\n    output = normalized * (scale_const_tensor + scale) + shift\n\n    tl.store(out_ptrs, output, mask=mask)\n\n\n@triton.jit\ndef fuse_scale_shift_kernel_blc_opt(\n    x_ptr,\n    shift_ptr,\n    scale_ptr,\n    scale_constant: tl.constexpr,  # scale_constant is either 0 or 1.\n    y_ptr,\n    B,\n    L,\n    C,\n    stride_x_b,\n    stride_x_l,\n    stride_x_c,\n    stride_s_b,\n    stride_s_l,\n    stride_s_c,\n    stride_sc_b,\n    stride_sc_l,\n    stride_sc_c,\n    SCALE_IS_SCALAR: tl.constexpr,\n    SHIFT_IS_SCALAR: tl.constexpr,\n    BLOCK_L: tl.constexpr,\n    BLOCK_C: tl.constexpr,\n):\n    pid_l = tl.program_id(0)\n    pid_c = tl.program_id(1)\n    pid_b = tl.program_id(2)\n\n    l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)\n    c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)\n\n    mask_l = l_offsets < L\n    mask_c = c_offsets < C\n    mask = mask_l[:, None] & mask_c[None, :]\n\n    x_off = (\n        pid_b * stride_x_b\n        + l_offsets[:, None] * stride_x_l\n        + c_offsets[None, :] * stride_x_c\n    )\n    x = tl.load(x_ptr + x_off, mask=mask, other=0)\n\n    if SHIFT_IS_SCALAR:\n        shift_val = tl.load(shift_ptr)\n        shift = tl.full((BLOCK_L, BLOCK_C), shift_val, dtype=shift_val.dtype)\n    else:\n        s_off = (\n            pid_b * stride_s_b\n            + l_offsets[:, None] * stride_s_l\n            + c_offsets[None, :] * stride_s_c\n        )\n        shift = tl.load(shift_ptr + s_off, mask=mask, other=0)\n\n    if SCALE_IS_SCALAR:\n        scale_val = tl.load(scale_ptr)\n        scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype)\n    else:\n        sc_off = (\n            pid_b * stride_sc_b\n            + l_offsets[:, None] * stride_sc_l\n            + c_offsets[None, :] * 
stride_sc_c\n        )\n        scale = tl.load(scale_ptr + sc_off, mask=mask, other=0)\n\n    y = x * (scale_constant + scale) + shift\n    tl.store(y_ptr + x_off, y, mask=mask)\n\n\n@triton.jit\ndef fuse_scale_shift_gate_select01_kernel_blc_opt(\n    x_ptr,\n    shift0_ptr,\n    scale0_ptr,\n    gate0_ptr,\n    shift1_ptr,\n    scale1_ptr,\n    gate1_ptr,\n    index_ptr,\n    y_ptr,\n    gate_out_ptr,\n    B,\n    L,\n    C,\n    stride_x_b,\n    stride_x_l,\n    stride_x_c,\n    stride_s0_b,\n    stride_s0_c,\n    stride_sc0_b,\n    stride_sc0_c,\n    stride_g0_b,\n    stride_g0_c,\n    stride_s1_b,\n    stride_s1_c,\n    stride_sc1_b,\n    stride_sc1_c,\n    stride_g1_b,\n    stride_g1_c,\n    stride_i_b,\n    stride_i_l,\n    stride_go_b,\n    stride_go_l,\n    stride_go_c,\n    BLOCK_L: tl.constexpr,\n    BLOCK_C: tl.constexpr,\n):\n    pid_l = tl.program_id(0)\n    pid_c = tl.program_id(1)\n    pid_b = tl.program_id(2)\n\n    l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)\n    c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)\n\n    mask_l = l_offsets < L\n    mask_c = c_offsets < C\n    mask = mask_l[:, None] & mask_c[None, :]\n\n    x_off = (\n        pid_b * stride_x_b\n        + l_offsets[:, None] * stride_x_l\n        + c_offsets[None, :] * stride_x_c\n    )\n    x = tl.load(x_ptr + x_off, mask=mask, other=0)\n\n    idx_off = pid_b * stride_i_b + l_offsets * stride_i_l\n    idx = tl.load(index_ptr + idx_off, mask=mask_l, other=0).to(tl.int1)[:, None]\n\n    s0_off = pid_b * stride_s0_b + c_offsets[None, :] * stride_s0_c\n    sc0_off = pid_b * stride_sc0_b + c_offsets[None, :] * stride_sc0_c\n    g0_off = pid_b * stride_g0_b + c_offsets[None, :] * stride_g0_c\n    s1_off = pid_b * stride_s1_b + c_offsets[None, :] * stride_s1_c\n    sc1_off = pid_b * stride_sc1_b + c_offsets[None, :] * stride_sc1_c\n    g1_off = pid_b * stride_g1_b + c_offsets[None, :] * stride_g1_c\n\n    shift0 = tl.load(shift0_ptr + s0_off, mask=mask_c[None, :], other=0)\n    
scale0 = tl.load(scale0_ptr + sc0_off, mask=mask_c[None, :], other=0)\n    gate0 = tl.load(gate0_ptr + g0_off, mask=mask_c[None, :], other=0)\n    shift1 = tl.load(shift1_ptr + s1_off, mask=mask_c[None, :], other=0)\n    scale1 = tl.load(scale1_ptr + sc1_off, mask=mask_c[None, :], other=0)\n    gate1 = tl.load(gate1_ptr + g1_off, mask=mask_c[None, :], other=0)\n\n    shift = tl.where(idx, shift1, shift0)\n    scale = tl.where(idx, scale1, scale0)\n    gate = tl.where(idx, gate1, gate0)\n\n    y = x * (1 + scale) + shift\n    tl.store(y_ptr + x_off, y, mask=mask)\n\n    go_off = (\n        pid_b * stride_go_b\n        + l_offsets[:, None] * stride_go_l\n        + c_offsets[None, :] * stride_go_c\n    )\n    tl.store(gate_out_ptr + go_off, gate, mask=mask)\n\n\ndef fuse_scale_shift_kernel(\n    x: torch.Tensor,\n    scale: torch.Tensor,\n    shift: torch.Tensor,\n    scale_constant: float = 1.0,\n    block_l: int = 128,\n    block_c: int = 128,\n):\n    assert x.is_cuda and scale.is_cuda\n    assert x.is_contiguous()\n\n    B, L, C = x.shape\n    output = torch.empty_like(x)\n\n    if scale.dim() == 4:\n        # scale: [B, F, 1, C] (per-frame); shift: [B, L, C] (per-token)\n        rows = B * L\n        x_2d = x.view(rows, C)\n        output_2d = output.view(rows, C)\n        grid = lambda META: (rows, triton.cdiv(C, META[\"BLOCK_N\"]))\n        num_frames = scale.shape[1]\n        assert (\n            L % num_frames == 0\n        ), \"seq_len must be divisible by num_frames for 4D scale/shift\"\n        frame_seqlen = L // num_frames\n\n        # Compact scale [B, F, 1, C] -> [B*F, C] (per-frame)\n        scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous()\n        # shift is per-token [B, L, C] -> [B*L, C]\n        shift_reshaped = shift.reshape(rows, C).contiguous()\n\n        _fused_scale_shift_4d_kernel[grid](\n            output_2d,\n            x_2d,\n            scale_reshaped,\n            shift_reshaped,\n            scale_constant,\n            rows,\n            C,\n            
L,\n            num_frames,\n            frame_seqlen,\n        )\n    else:\n        # 2D: [B, C] or [1, C]  -> treat as [B, 1, C] and broadcast over L\n        # 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C])\n        # Also support scalar (0D or 1-element)\n        if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1):\n            scale_blc = scale.reshape(1)\n        elif scale.dim() == 2:\n            scale_blc = scale[:, None, :]\n        elif scale.dim() == 3:\n            scale_blc = scale\n        else:\n            raise ValueError(\"scale must be 0D/1D(1)/2D/3D or 4D\")\n\n        if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1):\n            shift_blc = shift.reshape(1)\n        elif shift.dim() == 2:\n            shift_blc = shift[:, None, :]\n        elif shift.dim() == 3:\n            shift_blc = shift\n        else:\n            # broadcast later via expand if possible\n            shift_blc = shift\n\n        need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1\n        need_shift_scalar = shift_blc.dim() == 1 and shift_blc.numel() == 1\n\n        if not need_scale_scalar:\n            scale_exp = scale_blc.expand(B, L, C)\n            s_sb, s_sl, s_sc = scale_exp.stride()\n        else:\n            s_sb = s_sl = s_sc = 0\n\n        if not need_shift_scalar:\n            shift_exp = shift_blc.expand(B, L, C)\n            sh_sb, sh_sl, sh_sc = shift_exp.stride()\n        else:\n            sh_sb = sh_sl = sh_sc = 0\n\n        # Copy fast-path: with zero scalar scale and shift, y == x only when scale_constant == 1.\n        # .item() waits for the value to reach the host; an unsynchronized non_blocking copy may be read early.\n        if need_scale_scalar and need_shift_scalar and scale_constant == 1:\n            if not (scale_blc.any().item() or shift_blc.any().item()):\n                output.copy_(x)\n                return output\n\n        grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B)\n        fuse_scale_shift_kernel_blc_opt[grid](\n   
         x,\n            shift_blc if need_shift_scalar else shift_exp,\n            scale_blc if need_scale_scalar else scale_exp,\n            scale_constant,\n            output,\n            B,\n            L,\n            C,\n            x.stride(0),\n            x.stride(1),\n            x.stride(2),\n            sh_sb,\n            sh_sl,\n            sh_sc,\n            s_sb,\n            s_sl,\n            s_sc,\n            SCALE_IS_SCALAR=need_scale_scalar,\n            SHIFT_IS_SCALAR=need_shift_scalar,\n            BLOCK_L=block_l,\n            BLOCK_C=block_c,\n            num_warps=4,\n            num_stages=2,\n        )\n    return output\n\n\ndef fuse_scale_shift_gate_select01_kernel(\n    x: torch.Tensor,\n    scale0: torch.Tensor,\n    shift0: torch.Tensor,\n    gate0: torch.Tensor,\n    scale1: torch.Tensor,\n    shift1: torch.Tensor,\n    gate1: torch.Tensor,\n    index: torch.Tensor,\n    block_l: int = 128,\n    block_c: int = 128,\n):\n    assert x.is_contiguous()\n    B, L, C = x.shape\n    output = torch.empty_like(x)\n    gate_out = torch.empty_like(x)\n\n    if (\n        scale0.dim() != 2\n        or shift0.dim() != 2\n        or gate0.dim() != 2\n        or scale1.dim() != 2\n        or shift1.dim() != 2\n        or gate1.dim() != 2\n    ):\n        raise ValueError(\"scale0/shift0/gate0/scale1/shift1/gate1 must be 2D [B, C]\")\n    if index.dim() != 2:\n        raise ValueError(\"index must be 2D [B, L]\")\n\n    grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B)\n    fuse_scale_shift_gate_select01_kernel_blc_opt[grid](\n        x,\n        shift0,\n        scale0,\n        gate0,\n        shift1,\n        scale1,\n        gate1,\n        index,\n        output,\n        gate_out,\n        B,\n        L,\n        C,\n        x.stride(0),\n        x.stride(1),\n        x.stride(2),\n        shift0.stride(0),\n        shift0.stride(1),\n        scale0.stride(0),\n        scale0.stride(1),\n        gate0.stride(0),\n        
gate0.stride(1),\n        shift1.stride(0),\n        shift1.stride(1),\n        scale1.stride(0),\n        scale1.stride(1),\n        gate1.stride(0),\n        gate1.stride(1),\n        index.stride(0),\n        index.stride(1),\n        gate_out.stride(0),\n        gate_out.stride(1),\n        gate_out.stride(2),\n        BLOCK_L=block_l,\n        BLOCK_C=block_c,\n        num_warps=4,\n        num_stages=2,\n    )\n    return output, gate_out\n\n\ndef fuse_layernorm_scale_shift_gate_select01_kernel(\n    x: torch.Tensor,\n    weight: torch.Tensor | None,\n    bias: torch.Tensor | None,\n    scale0: torch.Tensor,\n    shift0: torch.Tensor,\n    gate0: torch.Tensor,\n    scale1: torch.Tensor,\n    shift1: torch.Tensor,\n    gate1: torch.Tensor,\n    index: torch.Tensor,\n    eps: float,\n):\n    assert x.is_cuda\n    assert x.is_contiguous()\n    B, L, C = x.shape\n    output = torch.empty_like(x)\n    gate_out = torch.empty_like(x)\n\n    if (\n        scale0.dim() != 2\n        or shift0.dim() != 2\n        or gate0.dim() != 2\n        or scale1.dim() != 2\n        or shift1.dim() != 2\n        or gate1.dim() != 2\n    ):\n        raise ValueError(\"scale0/shift0/gate0/scale1/shift1/gate1 must be 2D [B, C]\")\n    if index.dim() != 2:\n        raise ValueError(\"index must be 2D [B, L]\")\n    if weight is not None and (weight.dim() != 1 or weight.shape[0] != C):\n        raise ValueError(\"weight must be 1D [C]\")\n    if bias is not None and (bias.dim() != 1 or bias.shape[0] != C):\n        raise ValueError(\"bias must be 1D [C]\")\n\n    x_2d = x.view(B * L, C)\n    output_2d = output.view(B * L, C)\n    gate_out_2d = gate_out.view(B * L, C)\n    weight = weight.contiguous() if weight is not None else x_2d\n    bias = bias.contiguous() if bias is not None else x_2d\n\n    MAX_FUSED_SIZE = 65536 // x_2d.element_size()\n    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(C))\n    if C > BLOCK_N:\n        raise RuntimeError(\"This layer norm doesn't support 
feature dim >= 64KB.\")\n\n    grid = (B * L,)\n    _fused_layernorm_scale_shift_gate_select01_kernel[grid](\n        output_2d,\n        gate_out_2d,\n        x_2d,\n        weight,\n        bias,\n        scale0.contiguous(),\n        shift0.contiguous(),\n        gate0.contiguous(),\n        scale1.contiguous(),\n        shift1.contiguous(),\n        gate1.contiguous(),\n        index.contiguous(),\n        C,\n        L,\n        x_2d.stride(0),\n        output_2d.stride(0),\n        gate_out_2d.stride(0),\n        weight.stride(0) if weight.dim() == 1 else 0,\n        bias.stride(0) if bias.dim() == 1 else 0,\n        scale0.stride(0),\n        scale0.stride(1),\n        shift0.stride(0),\n        shift0.stride(1),\n        gate0.stride(0),\n        gate0.stride(1),\n        scale1.stride(0),\n        scale1.stride(1),\n        shift1.stride(0),\n        shift1.stride(1),\n        gate1.stride(0),\n        gate1.stride(1),\n        index.stride(0),\n        index.stride(1),\n        eps,\n        HAS_WEIGHT=weight is not x_2d,\n        HAS_BIAS=bias is not x_2d,\n        BLOCK_N=BLOCK_N,\n    )\n    return output, gate_out\n\n\ndef fuse_residual_layernorm_scale_shift_gate_select01_kernel(\n    x: torch.Tensor,\n    residual: torch.Tensor,\n    residual_gate: torch.Tensor,\n    weight: torch.Tensor | None,\n    bias: torch.Tensor | None,\n    scale0: torch.Tensor,\n    shift0: torch.Tensor,\n    gate0: torch.Tensor,\n    scale1: torch.Tensor,\n    shift1: torch.Tensor,\n    gate1: torch.Tensor,\n    index: torch.Tensor,\n    eps: float,\n):\n    assert x.is_cuda\n    assert x.is_contiguous()\n    assert residual.is_contiguous()\n    assert residual_gate.is_contiguous()\n    B, L, C = x.shape\n    output = torch.empty_like(x)\n    residual_out = torch.empty_like(x)\n    gate_out = torch.empty_like(x)\n\n    if residual.shape != x.shape:\n        raise ValueError(\"residual must have the same shape as x\")\n    if residual_gate.shape != x.shape:\n        raise 
ValueError(\"residual_gate must have the same shape as x\")\n    if (\n        scale0.dim() != 2\n        or shift0.dim() != 2\n        or gate0.dim() != 2\n        or scale1.dim() != 2\n        or shift1.dim() != 2\n        or gate1.dim() != 2\n    ):\n        raise ValueError(\"scale0/shift0/gate0/scale1/shift1/gate1 must be 2D [B, C]\")\n    if index.dim() != 2:\n        raise ValueError(\"index must be 2D [B, L]\")\n    if weight is not None and (weight.dim() != 1 or weight.shape[0] != C):\n        raise ValueError(\"weight must be 1D [C]\")\n    if bias is not None and (bias.dim() != 1 or bias.shape[0] != C):\n        raise ValueError(\"bias must be 1D [C]\")\n\n    x_2d = x.view(B * L, C)\n    residual_2d = residual.view(B * L, C)\n    residual_gate_2d = residual_gate.view(B * L, C)\n    output_2d = output.view(B * L, C)\n    residual_out_2d = residual_out.view(B * L, C)\n    gate_out_2d = gate_out.view(B * L, C)\n    weight = weight.contiguous() if weight is not None else x_2d\n    bias = bias.contiguous() if bias is not None else x_2d\n\n    MAX_FUSED_SIZE = 65536 // x_2d.element_size()\n    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(C))\n    if C > BLOCK_N:\n        raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n\n    grid = (B * L,)\n    _fused_residual_layernorm_scale_shift_gate_select01_kernel[grid](\n        output_2d,\n        residual_out_2d,\n        gate_out_2d,\n        x_2d,\n        residual_2d,\n        residual_gate_2d,\n        weight,\n        bias,\n        scale0.contiguous(),\n        shift0.contiguous(),\n        gate0.contiguous(),\n        scale1.contiguous(),\n        shift1.contiguous(),\n        gate1.contiguous(),\n        index.contiguous(),\n        C,\n        L,\n        x_2d.stride(0),\n        residual_2d.stride(0),\n        residual_gate_2d.stride(0),\n        output_2d.stride(0),\n        residual_out_2d.stride(0),\n        gate_out_2d.stride(0),\n        weight.stride(0) if 
weight.dim() == 1 else 0,\n        bias.stride(0) if bias.dim() == 1 else 0,\n        scale0.stride(0),\n        scale0.stride(1),\n        shift0.stride(0),\n        shift0.stride(1),\n        gate0.stride(0),\n        gate0.stride(1),\n        scale1.stride(0),\n        scale1.stride(1),\n        shift1.stride(0),\n        shift1.stride(1),\n        gate1.stride(0),\n        gate1.stride(1),\n        index.stride(0),\n        index.stride(1),\n        eps,\n        HAS_WEIGHT=weight is not x_2d,\n        HAS_BIAS=bias is not x_2d,\n        BLOCK_N=BLOCK_N,\n    )\n    return output, residual_out, gate_out\n\n\nif current_platform.is_npu():\n    from .npu_fallback import fuse_scale_shift_native\n\n    fuse_scale_shift_kernel = fuse_scale_shift_native\n\nif current_platform.is_mps():\n    from .mps_fallback import (\n        fuse_scale_shift_gate_select01_kernel_native,\n        fuse_scale_shift_kernel_native,\n    )\n\n    fuse_scale_shift_kernel = fuse_scale_shift_kernel_native\n    fuse_scale_shift_gate_select01_kernel = fuse_scale_shift_gate_select01_kernel_native\n"
  },
  {
    "path": "python/sglang/jit_kernel/flash_attention_v4.py",
    "content": "from __future__ import annotations\n\nfrom typing import Callable, Optional, Tuple, Union\n\nimport torch\n\ntry:\n    from flash_attn.cute import flash_attn_varlen_func as _flash_attn_varlen_func\nexcept Exception as _e:  # pragma: no cover\n    _flash_attn_varlen_func = None\n    _flash_attn_import_error = _e\nelse:\n    _flash_attn_import_error = None\n\n\ndef _maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:\n    return x.contiguous() if x is not None and x.stride(-1) != 1 else x\n\n\ndef flash_attn_varlen_func(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k: Optional[torch.Tensor] = None,\n    seqused_q: Optional[torch.Tensor] = None,\n    seqused_k: Optional[torch.Tensor] = None,\n    max_seqlen_q: Optional[int] = None,\n    max_seqlen_k: Optional[int] = None,\n    page_table: Optional[torch.Tensor] = None,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    softcap: Optional[float] = None,\n    window_size: Tuple[Optional[int], Optional[int]] = (-1, -1),\n    learnable_sink: Optional[torch.Tensor] = None,\n    sinks: Optional[torch.Tensor] = None,\n    num_splits: int = 1,\n    pack_gqa: Optional[bool] = None,\n    score_mod: Optional[Callable] = None,\n    aux_tensors: Optional[list] = None,\n    return_softmax_lse: bool = False,\n    **_: object,\n):\n    if _flash_attn_varlen_func is None:  # pragma: no cover\n        raise ImportError(\n            \"Vendored FlashAttention CUTE is not available (cannot import \"\n            \"flash_attn.cute). 
Please check your source tree.\"\n        ) from _flash_attn_import_error\n\n    q, k, v = [_maybe_contiguous(t) for t in (q, k, v)]\n    cu_seqlens_q, cu_seqlens_k = [\n        _maybe_contiguous(t) for t in (cu_seqlens_q, cu_seqlens_k)\n    ]\n    seqused_q, seqused_k = [_maybe_contiguous(t) for t in (seqused_q, seqused_k)]\n    page_table = _maybe_contiguous(page_table)\n\n    if learnable_sink is None and sinks is not None:\n        learnable_sink = sinks\n\n    if window_size == (-1, -1):\n        window_size = (None, None)\n\n    result = _flash_attn_varlen_func(\n        q=q,\n        k=k,\n        v=v,\n        cu_seqlens_q=cu_seqlens_q,\n        cu_seqlens_k=cu_seqlens_k,\n        seqused_q=seqused_q,\n        seqused_k=seqused_k,\n        max_seqlen_q=max_seqlen_q,\n        max_seqlen_k=max_seqlen_k,\n        page_table=page_table,\n        softmax_scale=softmax_scale,\n        causal=causal,\n        softcap=softcap,\n        window_size=window_size,\n        learnable_sink=learnable_sink,\n        num_splits=num_splits,\n        pack_gqa=pack_gqa,\n        score_mod=score_mod,\n        aux_tensors=aux_tensors,\n    )\n\n    if return_softmax_lse:\n        return result\n    if isinstance(result, tuple):\n        return result[0]\n    return result\n\n\ndef flash_attn_with_kvcache(\n    q: torch.Tensor,\n    k_cache: torch.Tensor,\n    v_cache: torch.Tensor,\n    k: Optional[torch.Tensor] = None,\n    v: Optional[torch.Tensor] = None,\n    qv: Optional[torch.Tensor] = None,\n    rotary_cos: Optional[torch.Tensor] = None,\n    rotary_sin: Optional[torch.Tensor] = None,\n    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,\n    cache_batch_idx: Optional[torch.Tensor] = None,\n    cache_leftpad: Optional[torch.Tensor] = None,\n    page_table: Optional[torch.Tensor] = None,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k_new: Optional[torch.Tensor] = None,\n    max_seqlen_q: Optional[int] = None,\n    rotary_seqlens: 
Optional[torch.Tensor] = None,\n    q_descale: Optional[torch.Tensor] = None,\n    k_descale: Optional[torch.Tensor] = None,\n    v_descale: Optional[torch.Tensor] = None,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    window_size: Tuple[int, int] = (-1, -1),\n    attention_chunk: Optional[int] = None,\n    softcap: float = 0.0,\n    rotary_interleaved: bool = True,\n    scheduler_metadata=None,\n    num_splits: int = 0,\n    pack_gqa: Optional[bool] = None,\n    sm_margin: int = 0,\n    sinks: Optional[torch.Tensor] = None,\n    score_mod: Optional[Callable] = None,\n    aux_tensors: Optional[list] = None,\n    return_softmax_lse: bool = False,\n    **_: object,\n):\n    if k is not None or v is not None or qv is not None:\n        raise NotImplementedError(\"FA4 does not support updating KV cache in-place.\")\n    if rotary_cos is not None or rotary_sin is not None or rotary_seqlens is not None:\n        raise NotImplementedError(\"FA4 path does not support rotary embedding.\")\n    if cache_batch_idx is not None or cache_leftpad is not None:\n        raise NotImplementedError(\n            \"FA4 path does not support non-consecutive batch indices or left padding.\"\n        )\n    if q_descale is not None or k_descale is not None or v_descale is not None:\n        raise NotImplementedError(\"FA4 path does not support descale.\")\n\n    if isinstance(cache_seqlens, int):\n        cache_seqlens = torch.full(\n            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device\n        )\n\n    result = flash_attn_varlen_func(\n        q=q,\n        k=k_cache,\n        v=v_cache,\n        cu_seqlens_q=cu_seqlens_q,\n        seqused_k=cache_seqlens,\n        max_seqlen_q=max_seqlen_q,\n        page_table=page_table,\n        softmax_scale=softmax_scale,\n        causal=causal,\n        softcap=softcap if softcap != 0.0 else None,\n        window_size=window_size,\n        num_splits=num_splits if num_splits != 0 else 
1,\n        pack_gqa=pack_gqa,\n        learnable_sink=sinks,\n        score_mod=score_mod,\n        aux_tensors=aux_tensors,\n        return_softmax_lse=True,\n    )\n\n    if return_softmax_lse:\n        return result\n    if isinstance(result, tuple):\n        return result[0]\n    return result\n"
  },
  {
    "path": "python/sglang/jit_kernel/fused_metadata_copy.py",
    "content": "\"\"\"\nFused metadata copy kernel for NSA backend CUDA graph replay.\n\nThis module provides JIT-compiled CUDA kernels for fusing multiple tensor\ncopy operations into single kernel launches, reducing kernel launch overhead\nand improving CUDA graph replay performance.\n\nThe kernels are compiled on-demand using TVM FFI and cached for subsequent use.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nfrom typing import Optional\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\nlogger = logging.getLogger(__name__)\n\n\n# ============================================================================\n# JIT Module Compilation\n# ============================================================================\n\n\n@cache_once\ndef _jit_fused_metadata_copy_module(\n    forward_mode: int, has_real_page_table: bool, has_flashmla: bool\n):\n    \"\"\"Compile JIT module for single-backend fused metadata copy.\n\n    Args:\n        forward_mode: 0=DECODE, 1=TARGET_VERIFY, 2=DRAFT_EXTEND\n        has_real_page_table: Whether real_page_table tensors are used\n        has_flashmla: Whether FlashMLA metadata tensors are used\n    \"\"\"\n    args = make_cpp_args(forward_mode, has_real_page_table, has_flashmla)\n    try:\n        return load_jit(\n            \"fused_metadata_copy\",\n            *args,\n            cuda_files=[\"elementwise/fused_metadata_copy.cuh\"],\n            cuda_wrappers=[\n                (\n                    \"fused_metadata_copy\",\n                    f\"FusedMetadataCopyKernel<{args}>::run\",\n                )\n            ],\n        )\n    except Exception as e:\n        logger.error(\n            f\"Failed to compile JIT fused metadata copy kernel \"\n            f\"(forward_mode={forward_mode}, has_real_page_table={has_real_page_table}, \"\n            f\"has_flashmla={has_flashmla}): {e}\"\n        )\n        raise\n\n\n@cache_once\ndef 
_jit_fused_metadata_copy_multi_module(\n    has_real_page_table: bool, has_flashmla: bool\n):\n    \"\"\"Compile JIT module for multi-backend fused metadata copy (DECODE mode only).\n\n    Args:\n        has_real_page_table: Whether real_page_table tensors are used\n        has_flashmla: Whether FlashMLA metadata tensors are used\n    \"\"\"\n    args = make_cpp_args(has_real_page_table, has_flashmla)\n    try:\n        return load_jit(\n            \"fused_metadata_copy_multi\",\n            *args,\n            cuda_files=[\"elementwise/fused_metadata_copy.cuh\"],\n            cuda_wrappers=[\n                (\n                    \"fused_metadata_copy_multi\",\n                    f\"FusedMetadataCopyMultiKernel<{args}>::run\",\n                )\n            ],\n        )\n    except Exception as e:\n        logger.error(\n            f\"Failed to compile JIT fused metadata copy multi kernel \"\n            f\"(has_real_page_table={has_real_page_table}, has_flashmla={has_flashmla}): {e}\"\n        )\n        raise\n\n\n# ============================================================================\n# Public API\n# ============================================================================\n\n\ndef fused_metadata_copy_cuda(\n    cache_seqlens_src: torch.Tensor,\n    cu_seqlens_k_src: torch.Tensor,\n    page_indices_src: torch.Tensor,\n    nsa_cache_seqlens_src: torch.Tensor,\n    seqlens_expanded_src: Optional[torch.Tensor],\n    nsa_cu_seqlens_k_src: torch.Tensor,\n    real_page_table_src: Optional[torch.Tensor],\n    flashmla_num_splits_src: Optional[torch.Tensor],\n    flashmla_metadata_src: Optional[torch.Tensor],\n    cache_seqlens_dst: torch.Tensor,\n    cu_seqlens_k_dst: torch.Tensor,\n    page_table_1_dst: torch.Tensor,\n    nsa_cache_seqlens_dst: torch.Tensor,\n    seqlens_expanded_dst: Optional[torch.Tensor],\n    nsa_cu_seqlens_k_dst: torch.Tensor,\n    real_page_table_dst: Optional[torch.Tensor],\n    flashmla_num_splits_dst: 
Optional[torch.Tensor],\n    flashmla_metadata_dst: Optional[torch.Tensor],\n    forward_mode: int,\n    bs: int,\n    max_len: int,\n    max_seqlen_k: int,\n    seqlens_expanded_size: int,\n) -> None:\n    \"\"\"\n    Fused metadata copy kernel for NSA backend CUDA graph replay.\n\n    This function fuses multiple tensor copy operations into a single kernel launch,\n    reducing kernel launch overhead and improving performance.\n\n    Args:\n        cache_seqlens_src: Source cache sequence lengths [bs]\n        cu_seqlens_k_src: Source cumulative sequence lengths [bs+1]\n        page_indices_src: Source page indices [rows, max_len]\n        nsa_cache_seqlens_src: Source NSA cache sequence lengths [size]\n        seqlens_expanded_src: Optional source expanded sequence lengths [size] (required for TARGET_VERIFY/DRAFT_EXTEND)\n        nsa_cu_seqlens_k_src: Source NSA cumulative sequence lengths [size+1]\n        real_page_table_src: Optional source real page table [rows, cols]\n        flashmla_num_splits_src: Optional source FlashMLA num_splits [size+1]\n        flashmla_metadata_src: Optional source FlashMLA metadata tensor\n        cache_seqlens_dst: Destination cache sequence lengths [bs]\n        cu_seqlens_k_dst: Destination cumulative sequence lengths [bs+1]\n        page_table_1_dst: Destination page table [rows, stride]\n        nsa_cache_seqlens_dst: Destination NSA cache sequence lengths [size]\n        seqlens_expanded_dst: Optional destination expanded sequence lengths [size] (required for TARGET_VERIFY/DRAFT_EXTEND)\n        nsa_cu_seqlens_k_dst: Destination NSA cumulative sequence lengths [size+1]\n        real_page_table_dst: Optional destination real page table [rows, cols]\n        flashmla_num_splits_dst: Optional destination FlashMLA num_splits [size+1]\n        flashmla_metadata_dst: Optional destination FlashMLA metadata tensor\n        forward_mode: Forward mode (0=DECODE, 1=TARGET_VERIFY, 2=DRAFT_EXTEND)\n        bs: Batch size\n        
max_len: Maximum length for decode/draft_extend mode\n        max_seqlen_k: Maximum sequence length for target_verify mode\n        seqlens_expanded_size: Size of expanded sequence lengths\n    \"\"\"\n    # Determine template parameters for kernel specialization\n    has_real_page_table = real_page_table_src is not None\n    has_flashmla = flashmla_num_splits_src is not None\n\n    # Get JIT-compiled module for this configuration (cached after first use)\n    module = _jit_fused_metadata_copy_module(\n        forward_mode, has_real_page_table, has_flashmla\n    )\n\n    # Ensure all required source tensors are contiguous (required for kernel's linear indexing)\n    # This matches the CHECK_INPUT checks in the verified sgl-kernel implementation\n    cache_seqlens_src = cache_seqlens_src.contiguous()\n    cu_seqlens_k_src = cu_seqlens_k_src.contiguous()\n    page_indices_src = page_indices_src.contiguous()\n    nsa_cache_seqlens_src = nsa_cache_seqlens_src.contiguous()\n    if seqlens_expanded_src is not None:\n        seqlens_expanded_src = seqlens_expanded_src.contiguous()\n    nsa_cu_seqlens_k_src = nsa_cu_seqlens_k_src.contiguous()\n\n    # Call JIT-compiled kernel (None values are passed as Optional with no value)\n    module.fused_metadata_copy(\n        cache_seqlens_src,\n        cu_seqlens_k_src,\n        page_indices_src,\n        nsa_cache_seqlens_src,\n        seqlens_expanded_src,\n        nsa_cu_seqlens_k_src,\n        real_page_table_src,\n        flashmla_num_splits_src,\n        flashmla_metadata_src,\n        cache_seqlens_dst,\n        cu_seqlens_k_dst,\n        page_table_1_dst,\n        nsa_cache_seqlens_dst,\n        seqlens_expanded_dst,\n        nsa_cu_seqlens_k_dst,\n        real_page_table_dst,\n        flashmla_num_splits_dst,\n        flashmla_metadata_dst,\n        bs,\n        max_len,\n        max_seqlen_k,\n        seqlens_expanded_size,\n    )\n\n\ndef fused_metadata_copy_multi_cuda(\n    cache_seqlens_src: torch.Tensor,\n    
cu_seqlens_k_src: torch.Tensor,\n    page_indices_src: torch.Tensor,\n    nsa_cache_seqlens_src: torch.Tensor,\n    nsa_cu_seqlens_k_src: torch.Tensor,\n    real_page_table_src: Optional[torch.Tensor],\n    flashmla_num_splits_src: Optional[torch.Tensor],\n    flashmla_metadata_src: Optional[torch.Tensor],\n    cache_seqlens_dst0: torch.Tensor,\n    cu_seqlens_k_dst0: torch.Tensor,\n    page_table_1_dst0: torch.Tensor,\n    nsa_cache_seqlens_dst0: torch.Tensor,\n    nsa_cu_seqlens_k_dst0: torch.Tensor,\n    real_page_table_dst0: Optional[torch.Tensor],\n    flashmla_num_splits_dst0: Optional[torch.Tensor],\n    flashmla_metadata_dst0: Optional[torch.Tensor],\n    cache_seqlens_dst1: torch.Tensor,\n    cu_seqlens_k_dst1: torch.Tensor,\n    page_table_1_dst1: torch.Tensor,\n    nsa_cache_seqlens_dst1: torch.Tensor,\n    nsa_cu_seqlens_k_dst1: torch.Tensor,\n    real_page_table_dst1: Optional[torch.Tensor],\n    flashmla_num_splits_dst1: Optional[torch.Tensor],\n    flashmla_metadata_dst1: Optional[torch.Tensor],\n    cache_seqlens_dst2: torch.Tensor,\n    cu_seqlens_k_dst2: torch.Tensor,\n    page_table_1_dst2: torch.Tensor,\n    nsa_cache_seqlens_dst2: torch.Tensor,\n    nsa_cu_seqlens_k_dst2: torch.Tensor,\n    real_page_table_dst2: Optional[torch.Tensor],\n    flashmla_num_splits_dst2: Optional[torch.Tensor],\n    flashmla_metadata_dst2: Optional[torch.Tensor],\n    bs: int,\n    max_len: int,\n    seqlens_expanded_size: int,\n) -> None:\n    \"\"\"\n    Multi-backend fused metadata copy kernel for NSA backend CUDA graph replay.\n\n    This function copies metadata from one source to THREE destinations in a single\n    kernel launch, eliminating the overhead of 3 separate kernel calls. 
Currently\n    only supports DECODE mode, which is the most common case.\n\n    Args:\n        cache_seqlens_src: Source cache sequence lengths [bs]\n        cu_seqlens_k_src: Source cumulative sequence lengths [bs+1]\n        page_indices_src: Source page indices [bs, max_len]\n        nsa_cache_seqlens_src: Source NSA cache sequence lengths [bs]\n        nsa_cu_seqlens_k_src: Source NSA cumulative sequence lengths [bs+1]\n        real_page_table_src: Optional source real page table [bs, cols]\n        flashmla_num_splits_src: Optional source FlashMLA num_splits [bs+1]\n        flashmla_metadata_src: Optional source FlashMLA metadata tensor\n        cache_seqlens_dst0-2: Destination cache sequence lengths for backends 0-2\n        cu_seqlens_k_dst0-2: Destination cumulative sequence lengths for backends 0-2\n        page_table_1_dst0-2: Destination page tables for backends 0-2\n        nsa_cache_seqlens_dst0-2: Destination NSA cache sequence lengths for backends 0-2\n        nsa_cu_seqlens_k_dst0-2: Destination NSA cumulative sequence lengths for backends 0-2\n        real_page_table_dst0-2: Optional destination real page tables for backends 0-2\n        flashmla_num_splits_dst0-2: Optional destination FlashMLA num_splits for backends 0-2\n        flashmla_metadata_dst0-2: Optional destination FlashMLA metadata tensors for backends 0-2\n        bs: Batch size\n        max_len: Maximum length for decode mode\n        seqlens_expanded_size: Size of expanded sequence lengths\n    \"\"\"\n    # Determine template parameters for kernel specialization\n    has_real_page_table = real_page_table_src is not None\n    has_flashmla = flashmla_num_splits_src is not None\n\n    # Get JIT-compiled module for this configuration (cached after first use)\n    module = _jit_fused_metadata_copy_multi_module(has_real_page_table, has_flashmla)\n\n    # Ensure all source tensors are contiguous (required for kernel's linear indexing)\n    # This matches the CHECK_INPUT checks in the 
verified sgl-kernel implementation\n    cache_seqlens_src = cache_seqlens_src.contiguous()\n    cu_seqlens_k_src = cu_seqlens_k_src.contiguous()\n    page_indices_src = page_indices_src.contiguous()\n    nsa_cache_seqlens_src = nsa_cache_seqlens_src.contiguous()\n    nsa_cu_seqlens_k_src = nsa_cu_seqlens_k_src.contiguous()\n\n    # Call JIT-compiled kernel (None values are passed as Optional with no value)\n    module.fused_metadata_copy_multi(\n        cache_seqlens_src,\n        cu_seqlens_k_src,\n        page_indices_src,\n        nsa_cache_seqlens_src,\n        nsa_cu_seqlens_k_src,\n        real_page_table_src,\n        flashmla_num_splits_src,\n        flashmla_metadata_src,\n        cache_seqlens_dst0,\n        cu_seqlens_k_dst0,\n        page_table_1_dst0,\n        nsa_cache_seqlens_dst0,\n        nsa_cu_seqlens_k_dst0,\n        real_page_table_dst0,\n        flashmla_num_splits_dst0,\n        flashmla_metadata_dst0,\n        cache_seqlens_dst1,\n        cu_seqlens_k_dst1,\n        page_table_1_dst1,\n        nsa_cache_seqlens_dst1,\n        nsa_cu_seqlens_k_dst1,\n        real_page_table_dst1,\n        flashmla_num_splits_dst1,\n        flashmla_metadata_dst1,\n        cache_seqlens_dst2,\n        cu_seqlens_k_dst2,\n        page_table_1_dst2,\n        nsa_cache_seqlens_dst2,\n        nsa_cu_seqlens_k_dst2,\n        real_page_table_dst2,\n        flashmla_num_splits_dst2,\n        flashmla_metadata_dst2,\n        bs,\n        max_len,\n        seqlens_expanded_size,\n    )\n"
  },
  {
    "path": "python/sglang/jit_kernel/fused_store_index_cache.py",
    "content": "\"\"\"\nThis module provides JIT-compiled CUDA kernels for fusing multiple tensor\ncopy operations into single kernel launches, reducing kernel launch overhead\nand improving CUDA graph replay performance.\n\nThe kernels are compiled on-demand using TVM FFI and cached for subsequent use.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import (\n    cache_once,\n    is_arch_support_pdl,\n    load_jit,\n    make_cpp_args,\n)\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\nlogger = logging.getLogger(__name__)\n\n\n@cache_once\ndef _jit_nsa_fused_store_module(\n    key_dtype: torch.dtype, indices_dtype: torch.dtype, page_size: int\n) -> Module:\n    \"\"\"\n    Build a JIT module that exposes:\n      module.fused_store_index_k_cache(input_bf16, index_k_with_scale_u8, loc_i64)\n    \"\"\"\n    args = make_cpp_args(key_dtype, indices_dtype, page_size, is_arch_support_pdl())\n    return load_jit(\n        \"fused_store_index_k_cache\",\n        *args,\n        cuda_files=[\"nsa/fused_store_index_cache.cuh\"],\n        cuda_wrappers=[\n            (\n                \"fused_store_index_k_cache\",\n                # - Float  = bf16_t (sgl_kernel/type.cuh)\n                # - IndicesT = int64_t (out_cache_loc is int64 in SGLang SetKAndS)\n                # - kPageSize = 64 (CUDA NSA)\n                f\"FusedStoreCacheIndexerKernel<{args}>::run\",\n            )\n        ],\n    )\n\n\n@cache_once\ndef can_use_nsa_fused_store(\n    key_dtype: torch.dtype, indices_dtype: torch.dtype, page_size: int\n) -> bool:\n    logger = logging.getLogger(__name__)\n    try:\n        _jit_nsa_fused_store_module(key_dtype, indices_dtype, page_size)\n        return True\n    except Exception as e:\n        logger.warning(f\"Failed to load nsa fused store JIT kernel: {e}\")\n        return False\n\n\ndef fused_store_index_k_cache(\n    key: torch.Tensor,\n    
index_k_with_scale: torch.Tensor,\n    out_cache_loc: torch.Tensor,\n    page_size: int = 64,\n) -> None:\n    \"\"\"\n    Fused: quantize bf16 key (N,128) -> fp8 + fp32 scale and write into NSATokenToKVPool.index_k_with_scale_buffer.\n\n    key:            (num_tokens, 128) bf16 (or reshapeable to it)\n    index_k_with_scale:  (num_pages, 64*(128+4)) uint8\n    out_cache_loc:       (num_tokens,) int64 token indices in TokenToKVPool\n    \"\"\"\n    assert key.is_cuda\n    assert index_k_with_scale.is_cuda\n    assert out_cache_loc.is_cuda\n\n    # 1) normalize shapes\n    if key.dim() != 2:\n        key = key.view(-1, key.shape[-1])\n    assert key.shape[1] == 128, f\"expected key last-dim=128, got {key.shape}\"\n\n    # 2) dtypes\n    assert key.dtype == torch.bfloat16, f\"{key.dtype=}\"\n    assert index_k_with_scale.dtype == torch.uint8, f\"{index_k_with_scale.dtype=}\"\n    assert out_cache_loc.dtype == torch.int64, f\"{out_cache_loc.dtype=}\"\n\n    # 3) contiguity\n    if not key.is_contiguous():\n        key = key.contiguous()\n    if not out_cache_loc.is_contiguous():\n        out_cache_loc = out_cache_loc.contiguous()\n    if not index_k_with_scale.is_contiguous():\n        index_k_with_scale = index_k_with_scale.contiguous()\n\n    module = _jit_nsa_fused_store_module(key.dtype, out_cache_loc.dtype, page_size)\n    module.fused_store_index_k_cache(key, index_k_with_scale, out_cache_loc)\n"
  },
  {
    "path": "python/sglang/jit_kernel/gptq_marlin.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Optional\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\nif TYPE_CHECKING:\n    from sgl_kernel.scalar_type import ScalarType\n    from tvm_ffi.module import Module\n\n# Constants matching device::marlin:: in marlin.cuh\n_MAX_THREAD_N = 256\n\n\n@cache_once\ndef _jit_gptq_marlin_module(dtype: torch.dtype) -> Module:\n    args = make_cpp_args(dtype)\n    return load_jit(\n        \"gptq_marlin\",\n        *args,\n        cuda_files=[\"gemm/marlin/gptq_marlin.cuh\"],\n        cuda_wrappers=[(\"gptq_marlin_gemm\", f\"gptq_marlin_gemm<{args}>\")],\n    )\n\n\ndef _or_empty(\n    t: Optional[torch.Tensor], device: torch.device, dtype: torch.dtype\n) -> torch.Tensor:\n    return t if t is not None else torch.empty(0, device=device, dtype=dtype)\n\n\ndef gptq_marlin_gemm(\n    a: torch.Tensor,\n    c: Optional[torch.Tensor],\n    b_q_weight: torch.Tensor,\n    b_scales: torch.Tensor,\n    global_scale: Optional[torch.Tensor],\n    b_zeros: Optional[torch.Tensor],\n    g_idx: Optional[torch.Tensor],\n    perm: Optional[torch.Tensor],\n    workspace: torch.Tensor,\n    b_q_type: ScalarType,\n    size_m: int,\n    size_n: int,\n    size_k: int,\n    is_k_full: bool = True,\n    use_atomic_add: bool = False,\n    use_fp32_reduce: bool = False,\n    is_zp_float: bool = False,\n) -> torch.Tensor:\n    device = a.device\n\n    # Allocate output if not provided\n    if c is None:\n        c = torch.empty((size_m, size_n), dtype=a.dtype, device=device)\n\n    # Early return for zero-size M\n    if size_m == 0:\n        return c\n\n    # Determine activation ordering\n    has_act_order = (\n        g_idx is not None\n        and perm is not None\n        and g_idx.numel() > 0\n        and perm.numel() > 0\n    )\n\n    # Allocate c_tmp for fp32 reduce\n    if use_fp32_reduce:\n        sms = 
torch.cuda.get_device_properties(device).multi_processor_count\n        max_m_block = min(((size_m + 15) // 16) * 16, 64)\n        c_tmp = torch.empty(\n            sms * max_m_block * _MAX_THREAD_N,\n            dtype=torch.float32,\n            device=device,\n        )\n    else:\n        c_tmp = torch.empty(0, dtype=torch.float32, device=device)\n\n    # Allocate a_tmp for act_order column permutation\n    if has_act_order:\n        a_tmp = torch.empty((size_m, size_k), dtype=a.dtype, device=device)\n    else:\n        a_tmp = torch.empty(0, dtype=a.dtype, device=device)\n\n    # Convert Optional tensors to empty tensors\n    global_scale_t = _or_empty(global_scale, device, a.dtype)\n    b_zeros_t = _or_empty(b_zeros, device, torch.int32)\n    g_idx_t = _or_empty(g_idx, device, torch.int32)\n    perm_t = _or_empty(perm, device, torch.int32)\n\n    module = _jit_gptq_marlin_module(a.dtype)\n    module.gptq_marlin_gemm(\n        a,\n        b_q_weight,\n        b_scales,\n        global_scale_t,\n        b_zeros_t,\n        g_idx_t,\n        perm_t,\n        c,\n        c_tmp,\n        a_tmp,\n        workspace,\n        b_q_type.id,\n        is_k_full,\n        use_atomic_add,\n        use_fp32_reduce,\n        is_zp_float,\n    )\n\n    return c\n"
  },
  {
    "path": "python/sglang/jit_kernel/gptq_marlin_repack.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n# Constants matching device::marlin:: in marlin.cuh\n_TILE_SIZE = 16\n\n\n@cache_once\ndef _jit_gptq_marlin_repack_module() -> Module:\n    return load_jit(\n        \"gptq_marlin_repack\",\n        cuda_files=[\"gemm/marlin/gptq_marlin_repack.cuh\"],\n        cuda_wrappers=[(\"gptq_marlin_repack\", \"gptq_marlin_repack\")],\n    )\n\n\ndef gptq_marlin_repack(\n    b_q_weight: torch.Tensor,\n    perm: torch.Tensor,\n    size_k: int,\n    size_n: int,\n    num_bits: int,\n) -> torch.Tensor:\n    pack_factor = 32 // num_bits\n\n    # Allocate output tensor\n    out = torch.empty(\n        (size_k // _TILE_SIZE, size_n * _TILE_SIZE // pack_factor),\n        dtype=b_q_weight.dtype,\n        device=b_q_weight.device,\n    )\n\n    module = _jit_gptq_marlin_repack_module()\n    module.gptq_marlin_repack(b_q_weight, perm, out, size_k, size_n, num_bits)\n    return out\n"
  },
  {
    "path": "python/sglang/jit_kernel/hadamard.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Callable\n\nimport torch\n\nfrom sglang.jit_kernel.utils import KERNEL_PATH, cache_once, load_jit, make_cpp_args\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_hadamard_module(dtype: torch.dtype) -> Module:\n    args = make_cpp_args(dtype)\n    hadamard_include_dir = (KERNEL_PATH / \"csrc\" / \"fast-hadamard-transform\").resolve()\n    return load_jit(\n        \"hadamard\",\n        *args,\n        cuda_files=[\"fast-hadamard-transform/hadamard_jit.cuh\"],\n        cuda_wrappers=[\n            (\"hadamard_transform\", f\"HadamardKernel<{args}>::run\"),\n            (\"hadamard_transform_12n\", f\"Hadamard12NKernel<{args}>::run\"),\n            (\"hadamard_transform_20n\", f\"Hadamard20NKernel<{args}>::run\"),\n            (\"hadamard_transform_28n\", f\"Hadamard28NKernel<{args}>::run\"),\n            (\"hadamard_transform_40n\", f\"Hadamard40NKernel<{args}>::run\"),\n        ],\n        extra_include_paths=[str(hadamard_include_dir)],\n    )\n\n\ndef _hadamard_transform_impl(\n    x: torch.Tensor,\n    scale: float,\n    pad_multiple: int,\n    kernel_fn: Callable,\n) -> torch.Tensor:\n    if not x.is_cuda:\n        raise RuntimeError(f\"{kernel_fn.__name__} only supports CUDA tensors\")\n\n    shapes_og = x.size()\n    dim_og = x.size(-1)\n    x = x.reshape(-1, dim_og)\n    if x.stride(-1) != 1:\n        x = x.contiguous()\n\n    needs_pad = dim_og % pad_multiple != 0\n    if needs_pad:\n        x = torch.nn.functional.pad(x, (0, pad_multiple - dim_og % pad_multiple))\n\n    out = torch.empty_like(x)\n    kernel_fn(x, out, scale)\n\n    if needs_pad:\n        out = out[:, :dim_og]\n    return out.reshape(shapes_og)\n\n\ndef hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:\n    module = _jit_hadamard_module(x.dtype)\n    return _hadamard_transform_impl(x, scale, 8, module.hadamard_transform)\n\n\ndef 
hadamard_transform_12n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:\n    module = _jit_hadamard_module(x.dtype)\n    return _hadamard_transform_impl(x, scale, 4 * 12, module.hadamard_transform_12n)\n\n\ndef hadamard_transform_20n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:\n    module = _jit_hadamard_module(x.dtype)\n    return _hadamard_transform_impl(x, scale, 4 * 20, module.hadamard_transform_20n)\n\n\ndef hadamard_transform_28n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:\n    module = _jit_hadamard_module(x.dtype)\n    return _hadamard_transform_impl(x, scale, 4 * 28, module.hadamard_transform_28n)\n\n\ndef hadamard_transform_40n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:\n    module = _jit_hadamard_module(x.dtype)\n    return _hadamard_transform_impl(x, scale, 4 * 40, module.hadamard_transform_40n)\n"
  },
  {
    "path": "python/sglang/jit_kernel/hicache.py",
    "content": "from __future__ import annotations\n\nimport logging\nfrom typing import TYPE_CHECKING\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\nif TYPE_CHECKING:\n    import torch\n    from tvm_ffi.module import Module\n\nDEFAULT_BLOCK_QUOTA = 2\n\n\n@cache_once\ndef _jit_hicache_module(*, element_size: int, unroll: int, block_quota: int) -> Module:\n    args = make_cpp_args(\n        element_size,\n        unroll,\n        block_quota,\n        1024,  # num_threads, can be tuned for performance\n    )\n    return load_jit(\n        \"hicache\",\n        *args,\n        cuda_files=[\"hicache.cuh\"],\n        cuda_wrappers=[\n            (\"launch_one\", f\"&HiCacheKernel<{args}>::run_one\"),\n            (\"launch_all\", f\"&HiCacheKernel<{args}>::run_all\"),\n        ],\n    )\n\n\ndef can_use_hicache_jit_kernel(\n    *,\n    element_size: int,\n    unroll: int | None = None,  # can be tuned for performance\n    block_quota: int | None = None,  # can be tuned for less interference\n) -> bool:\n    logger = logging.getLogger(__name__)\n    if element_size % 128 != 0:\n        logger.warning(f\"Unsupported {element_size = } for JIT HiCache kernel\")\n        return False\n    try:\n        unroll = unroll or _default_unroll(element_size)\n        block_quota = block_quota or DEFAULT_BLOCK_QUOTA\n        _jit_hicache_module(\n            element_size=element_size,\n            unroll=unroll,\n            block_quota=block_quota,\n        )\n        return True\n    except Exception as e:\n        logger.warning(f\"Failed to load JIT HiCache kernel: {e}\")\n        return False\n\n\ndef _default_unroll(element_size: int) -> int:\n    if element_size <= 512:\n        return 4\n\n    if element_size <= 1024:\n        return 2\n\n    # fallback: no unroll\n    return 1\n\n\ndef transfer_hicache_one_layer(\n    k_cache_dst: torch.Tensor,\n    v_cache_dst: torch.Tensor,\n    indices_dst: torch.Tensor,\n    k_cache_src: torch.Tensor,\n   
 v_cache_src: torch.Tensor,\n    indices_src: torch.Tensor,\n    *,\n    element_dim: int | None = None,\n    unroll: int | None = None,  # can be tuned for performance\n    block_quota: int | None = None,  # can be tuned for less interference\n) -> None:\n    element_dim = element_dim or k_cache_dst.size(-1)\n    k_cache_src = k_cache_src.view(-1, element_dim)\n    v_cache_src = v_cache_src.view(-1, element_dim)\n    k_cache_dst = k_cache_dst.view(-1, element_dim)\n    v_cache_dst = v_cache_dst.view(-1, element_dim)\n    element_size = element_dim * k_cache_dst.element_size()\n    block_quota = block_quota or DEFAULT_BLOCK_QUOTA\n    unroll = unroll or _default_unroll(element_size)\n    module = _jit_hicache_module(\n        element_size=element_size,\n        unroll=unroll,\n        block_quota=block_quota,\n    )\n    module.launch_one(\n        k_cache_dst,\n        v_cache_dst,\n        indices_dst,\n        k_cache_src,\n        v_cache_src,\n        indices_src,\n    )\n\n\ndef transfer_hicache_all_layer(\n    k_ptr_dst: torch.Tensor,\n    v_ptr_dst: torch.Tensor,\n    indices_dst: torch.Tensor,\n    k_ptr_src: torch.Tensor,\n    v_ptr_src: torch.Tensor,\n    indices_src: torch.Tensor,\n    *,\n    kv_cache_src_stride_bytes: int,\n    kv_cache_dst_stride_bytes: int,\n    element_size: int | None = None,\n    unroll: int | None = None,  # can be tuned for performance\n    block_quota: int | None = None,  # can be tuned for less interference\n) -> None:\n    if element_size is None:  # assume both contiguous\n        assert kv_cache_dst_stride_bytes == kv_cache_src_stride_bytes\n        element_size = kv_cache_dst_stride_bytes\n\n    block_quota = block_quota or DEFAULT_BLOCK_QUOTA\n    unroll = unroll or _default_unroll(element_size)\n    module = _jit_hicache_module(\n        element_size=element_size,\n        unroll=unroll,\n        block_quota=block_quota,\n    )\n    module.launch_all(\n        k_ptr_dst,\n        v_ptr_dst,\n        indices_dst,\n       
 k_ptr_src,\n        v_ptr_src,\n        indices_src,\n        kv_cache_src_stride_bytes,\n        kv_cache_dst_stride_bytes,\n    )\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/atomic.cuh",
    "content": "/// \\file atomic.cuh\n/// \\brief Device-side atomic operations.\n\n#pragma once\n#include <sgl_kernel/utils.cuh>\n\nnamespace device::atomic {\n\n/**\n * \\brief Atomically computes the maximum of `*addr` and `value`, storing the\n *        result in `*addr`.\n * \\param addr Pointer to the value in global/shared memory to be updated.\n * \\param value The value to compare against.\n * \\return The old value at `*addr` before the update.\n * \\note On CUDA, this uses `atomicMax`/`atomicMin` on the reinterpreted\n *       integer representation. On ROCm, a CAS loop is used as a fallback.\n */\nSGL_DEVICE float max(float* addr, float value) {\n#ifndef USE_ROCM\n  float old;\n  old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))\n                     : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));\n  return old;\n#else\n  int* addr_as_i = (int*)addr;\n  int old = *addr_as_i, assumed;\n  do {\n    assumed = old;\n    old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed))));\n  } while (assumed != old);\n  return __int_as_float(old);\n#endif\n}\n\n}  // namespace device::atomic\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/cta.cuh",
    "content": "/// \\file cta.cuh\n/// \\brief CTA (Cooperative Thread Array / thread-block) level primitives.\n\n#pragma once\n#include <sgl_kernel/math.cuh>\n#include <sgl_kernel/utils.cuh>\n#include <sgl_kernel/warp.cuh>\n\nnamespace device::cta {\n\n/**\n * \\brief Compute the maximum of `value` across all threads in the CTA.\n *\n * Uses a two-level reduction: first within each warp via `warp::reduce_max`,\n * then across warps using shared memory. The final result is stored in\n * `smem[0]`.\n *\n * \\tparam T Numeric type (must be supported by `warp::reduce_max`).\n * \\param value Per-thread input value.\n * \\param smem Shared memory buffer (must have at least `blockDim.x / 32`\n *             elements).\n * \\param min_value Identity element for max (default 0.0f).\n * \\note This function does NOT issue a trailing `__syncthreads()`.\n *       Callers must synchronize before reading `smem[0]`.\n */\ntemplate <typename T>\nSGL_DEVICE void reduce_max(T value, float* smem, float min_value = 0.0f) {\n  const uint32_t warp_id = threadIdx.x / kWarpThreads;\n  smem[warp_id] = warp::reduce_max(value);\n  __syncthreads();\n  if (warp_id == 0) {\n    const auto tx = threadIdx.x;\n    const auto local_value = tx * kWarpThreads < blockDim.x ? smem[tx] : min_value;\n    const auto max_value = warp::reduce_max(local_value);\n    smem[0] = max_value;\n  }\n  // no extra sync; it is caller's responsibility to sync if needed\n}\n\n}  // namespace device::cta\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/impl/norm.cuh",
    "content": "#pragma once\n#include <sgl_kernel/math.cuh>\n#include <sgl_kernel/type.cuh>\n#include <sgl_kernel/utils.cuh>\n#include <sgl_kernel/vec.cuh>\n#include <sgl_kernel/warp.cuh>\n\n#include <cstdint>\n#include <type_traits>\n\nnamespace host::norm {\n\n/**\n * \\brief Check if the given configuration is supported.\n * \\tparam T Element type (only fp16_t/bf16_t is supported)\n * \\tparam kDim Dimension size (usually hidden size)\n */\ntemplate <typename T, int64_t kDim>\ninline constexpr bool is_config_supported() {\n  if (!std::is_same_v<T, fp16_t> && !std::is_same_v<T, bf16_t>) return false;\n  if (kDim <= 256) {\n    return (kDim == 64 || kDim == 128 || kDim == 256);\n  } else {\n    return (kDim % 256 == 0 && kDim <= 8192);\n  }\n}\n\n/**\n * \\brief Determine whether to use cta norm based on dimension size.\n * TL;DR: use warp norm for dim <= 256, cta norm otherwise.\n * \\tparam T Element type (fp16_t or bf16_t)\n * \\tparam kDim Dimension size (usually hidden size)\n * \\note This function assumes that the configuration is supported.\n * \\see `is_config_supported`\n */\ntemplate <typename T, int64_t kDim>\ninline constexpr bool should_use_cta() {\n  static_assert(is_config_supported<T, kDim>(), \"Unsupported norm configuration\");\n  return kDim > 256;\n}\n\n/**\n * \\brief Get the number of threads per CTA for cta norm.\n * \\tparam T Element type (fp16_t or bf16_t)\n * \\tparam kDim Dimension size (usually hidden size)\n * \\return Number of threads per CTA\n */\ntemplate <typename T, int64_t kDim>\ninline constexpr uint32_t get_cta_threads() {\n  static_assert(should_use_cta<T, kDim>());\n  return (kDim / 256) * device::kWarpThreads;\n}\n\n}  // namespace host::norm\n\nnamespace device::norm {\n\nnamespace details {\n\ntemplate <int64_t kDim, bool kUseCTA, typename PackedFloat, std::size_t N>\nSGL_DEVICE AlignedVector<PackedFloat, N> apply_norm_impl(\n    const AlignedVector<PackedFloat, N> input,\n    const AlignedVector<PackedFloat, N> 
weight,\n    const float eps,\n    [[maybe_unused]] float* smem_buffer,\n    [[maybe_unused]] uint32_t num_warps) {\n  float sum_of_squares = 0.0f;\n\n#pragma unroll\n  for (auto i = 0u; i < N; ++i) {\n    const auto fp32_input = cast<fp32x2_t>(input[i]);\n    sum_of_squares += fp32_input.x * fp32_input.x;\n    sum_of_squares += fp32_input.y * fp32_input.y;\n  }\n\n  sum_of_squares = warp::reduce_sum(sum_of_squares);\n  float norm_factor;\n  if constexpr (kUseCTA) {\n    // need to synchronize across the cta\n    const auto warp_id = threadIdx.x / kWarpThreads;\n    smem_buffer[warp_id] = sum_of_squares;\n    __syncthreads();\n    // use the first warp to reduce; slot 32 broadcasts the result to all threads\n    if (warp_id == 0) {\n      const auto tx = threadIdx.x;\n      const auto local_sum = tx < num_warps ? smem_buffer[tx] : 0.0f;\n      sum_of_squares = warp::reduce_sum(local_sum);\n      smem_buffer[32] = math::rsqrt(sum_of_squares / kDim + eps);\n    }\n    __syncthreads();\n    norm_factor = smem_buffer[32];\n  } else {\n    norm_factor = math::rsqrt(sum_of_squares / kDim + eps);\n  }\n\n  AlignedVector<PackedFloat, N> output;\n\n#pragma unroll\n  for (auto i = 0u; i < N; ++i) {\n    const auto fp32_input = cast<fp32x2_t>(input[i]);\n    const auto fp32_weight = cast<fp32x2_t>(weight[i]);\n    output[i] = cast<PackedFloat, fp32x2_t>({\n        fp32_input.x * norm_factor * fp32_weight.x,\n        fp32_input.y * norm_factor * fp32_weight.y,\n    });\n  }\n\n  return output;\n}\n\n}  // namespace details\n\n/**\n * \\brief Apply RMS norm using the warp-level implementation.\n * \\tparam kDim Dimension size\n * \\tparam T Storage vector type (`StorageType<fp16_t, kDim>` or `StorageType<bf16_t, kDim>`)\n * \\param input Input vector\n * \\param weight Weight vector\n * \\param eps Epsilon value for numerical stability\n * \\return Normalized output vector\n */\ntemplate <int64_t kDim, typename T>\nSGL_DEVICE T apply_norm_warp(const T& input, const T& weight, float eps) {\n  static_assert(kDim <= 256, \"Warp norm only supports dim <= 256\");\n  return details::apply_norm_impl<kDim, false>(input, weight, eps, nullptr, 0);\n}\n\n/**\n * \\brief Apply RMS norm using the CTA-level implementation.\n * \\tparam kDim Dimension size\n * \\tparam T Storage vector type (`StorageType<fp16_t, kDim>` or `StorageType<bf16_t, kDim>`)\n * \\param input Input vector\n * \\param weight Weight vector\n * \\param eps Epsilon value for numerical stability\n * \\param smem Shared memory buffer with at least `kSmemBufferSize` floats\n * \\param num_warps Number of warps in the CTA\n * \\return Normalized output vector\n */\ntemplate <int64_t kDim, typename T>\nSGL_DEVICE T apply_norm_cta(\n    const T& input, const T& weight, float eps, float* smem, uint32_t num_warps = blockDim.x / kWarpThreads) {\n  static_assert(kDim > 256, \"CTA norm only supports dim > 256\");\n  return details::apply_norm_impl<kDim, true>(input, weight, eps, smem, num_warps);\n}\n\n/**\n * \\brief Storage type for norm operation.\n * For warp norm, the storage size depends on kDim.\n * For cta norm, the storage size is fixed to 16B.\n * Pairs of 16-bit input floats are packed into 32-bit types\n * for faster CUDA core operations.\n *\n * \\tparam T Element type (fp16_t or bf16_t)\n * \\tparam kDim Dimension size\n */\ntemplate <typename T, int64_t kDim>\nusing StorageType = std::conditional_t<                    // storage type\n    (kDim > 256),                                          // whether to use cta norm\n    AlignedVector<packed_t<T>, 4>,                         // cta norm storage, fixed to 16B\n    AlignedVector<packed_t<T>, kDim / (2 * kWarpThreads)>  // warp norm storage\n    >;\n\n/**\n * \\brief Minimum shared memory buffer size for cta norm, in `float` elements:\n * up to 32 per-warp partial sums plus one slot for the broadcast norm factor.\n */\ninline constexpr uint32_t kSmemBufferSize = 33;\n\n}  // namespace device::norm\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/math.cuh",
    "content": "/// \\file math.cuh\n/// \\brief Device-side math helper functions and constants.\n///\n/// Provides type-generic wrappers around CUDA math intrinsics by\n/// dispatching through `dtype_trait<T>`. All functions are forced-inline\n/// device functions.\n\n#pragma once\n#include <sgl_kernel/type.cuh>\n\n#include <cmath>\n\nnamespace device::math {\n\n/// \\brief Constant: log2(e)\ninline constexpr float log2e = 1.44269504088896340736f;\n/// \\brief Constant: ln(2)\ninline constexpr float loge2 = 0.693147180559945309417f;\n/// \\brief Maximum representable value for FP8 E4M3 format.\ninline constexpr float FP8_E4M3_MAX = 448.0f;\nstatic_assert(log2e * loge2 == 1.0f, \"log2e * loge2 must be 1\");\n\n/// \\brief Returns the larger of `a` and `b`.\ntemplate <typename T>\nSGL_DEVICE T max(T a, T b) {\n  return dtype_trait<T>::max(a, b);\n}\n\n/// \\brief Returns the smaller of `a` and `b`.\ntemplate <typename T>\nSGL_DEVICE T min(T a, T b) {\n  return dtype_trait<T>::min(a, b);\n}\n\n/// \\brief Returns the absolute value of `a`.\ntemplate <typename T>\nSGL_DEVICE T abs(T a) {\n  return dtype_trait<T>::abs(a);\n}\n\n/// \\brief Returns the square root of `a`.\ntemplate <typename T>\nSGL_DEVICE T sqrt(T a) {\n  return dtype_trait<T>::sqrt(a);\n}\n\n/// \\brief Returns the reciprocal square root of `a` (i.e. 1 / sqrt(a)).\ntemplate <typename T>\nSGL_DEVICE T rsqrt(T a) {\n  return dtype_trait<T>::rsqrt(a);\n}\n\n/// \\brief Returns e^a.\ntemplate <typename T>\nSGL_DEVICE T exp(T a) {\n  return dtype_trait<T>::exp(a);\n}\n\n/// \\brief Returns sin(a).\ntemplate <typename T>\nSGL_DEVICE T sin(T a) {\n  return dtype_trait<T>::sin(a);\n}\n\n/// \\brief Returns cos(a).\ntemplate <typename T>\nSGL_DEVICE T cos(T a) {\n  return dtype_trait<T>::cos(a);\n}\n\n}  // namespace device::math\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/runtime.cuh",
    "content": "/// \\file runtime.cuh\n/// \\brief Host-side CUDA runtime query helpers.\n///\n/// Thin wrappers around CUDA occupancy and device-property APIs with\n/// automatic error checking via `RuntimeDeviceCheck`.\n\n#pragma once\n\n#include <sgl_kernel/utils.cuh>\n\n#include <cstddef>\n#include <cstdint>\n#include <cuda_runtime.h>\n\nnamespace host::runtime {\n\n// Return the maximum number of active blocks per SM for the given kernel\ntemplate <typename T>\ninline auto get_blocks_per_sm(T&& kernel, int32_t block_dim, std::size_t dynamic_smem = 0) -> uint32_t {\n  int num_blocks_per_sm = 0;\n  RuntimeDeviceCheck(\n      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, block_dim, dynamic_smem));\n  return static_cast<uint32_t>(num_blocks_per_sm);\n}\n\n// Return the number of SMs for the given device\ninline auto get_sm_count(int device_id) -> uint32_t {\n  int sm_count;\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id));\n  return static_cast<uint32_t>(sm_count);\n}\n\n// Return the Major compute capability for the given device\ninline auto get_cc_major(int device_id) -> int {\n  int cc_major;\n  RuntimeDeviceCheck(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device_id));\n  return cc_major;\n}\n\n// Return the runtime version\ninline auto get_runtime_version() -> int {\n  int runtime_version;\n  RuntimeDeviceCheck(cudaRuntimeGetVersion(&runtime_version));\n  return runtime_version;\n}\n\n// Return the maximum dynamic shared memory per block for the given kernel\ntemplate <typename T>\ninline auto get_available_dynamic_smem_per_block(T&& kernel, int num_blocks, int block_size) -> std::size_t {\n  std::size_t smem_size;\n  RuntimeDeviceCheck(cudaOccupancyAvailableDynamicSMemPerBlock(&smem_size, kernel, num_blocks, block_size));\n  return smem_size;\n}\n\n}  // namespace host::runtime\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/scalar_type.hpp",
    "content": "#pragma once\n\n#include <cassert>\n#include <stdexcept>\n#ifndef __CUDACC__\n#include <variant>\n#endif\n\nnamespace host {\n\n//\n//  ScalarType can represent a wide range of floating point and integer types,\n//  in particular it can be used to represent sub-byte data types (something\n//  that torch.dtype currently does not support).\n//\n//  The type definitions on the Python side can be found in: vllm/scalar_type.py\n//  these type definitions should be kept up to date with any Python API changes\n//  here.\n//\nclass ScalarType {\n public:\n  enum NanRepr : uint8_t {\n    NAN_NONE = 0,                // nans are not supported\n    NAN_IEEE_754 = 1,            // nans are: exp all 1s, mantissa not all 0s\n    NAN_EXTD_RANGE_MAX_MIN = 2,  // nans are: exp all 1s, mantissa all 1s\n\n    NAN_REPR_ID_MAX\n  };\n\n  constexpr ScalarType(\n      uint8_t exponent,\n      uint8_t mantissa,\n      bool signed_,\n      int32_t bias,\n      bool finite_values_only = false,\n      NanRepr nan_repr = NAN_IEEE_754)\n      : exponent(exponent),\n        mantissa(mantissa),\n        signed_(signed_),\n        bias(bias),\n        finite_values_only(finite_values_only),\n        nan_repr(nan_repr) {};\n\n  static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {\n    return ScalarType(0, size_bits - 1, true, bias);\n  }\n\n  static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {\n    return ScalarType(0, size_bits, false, bias);\n  }\n\n  // IEEE 754 compliant floating point type\n  static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) {\n    assert(mantissa > 0 && exponent > 0);\n    return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);\n  }\n\n  // IEEE 754 non-compliant floating point type\n  static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) {\n    assert(nan_repr < NAN_REPR_ID_MAX);\n    assert(mantissa > 0 && exponent > 
0);\n    assert(nan_repr != NAN_IEEE_754);\n    return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr);\n  }\n\n  uint8_t const exponent;  // size of the exponent field (0 for integer types)\n  uint8_t const mantissa;  // size of the mantissa field (size of the integer\n                           // excluding the sign bit for integer types)\n  bool const signed_;      // flag if the type supports negative numbers (i.e. has a\n                           // sign bit)\n  int32_t const bias;      // stored values equal value + bias,\n                           // used for quantized type\n\n  // Extra Floating point info\n  bool const finite_values_only;  // i.e. no +/-inf if true\n  NanRepr const nan_repr;         // how NaNs are represented\n                                  // (not applicable for integer types)\n\n  using Id = int64_t;\n\n private:\n  // Field size in id\n  template <typename T_>\n  static constexpr size_t member_id_field_width() {\n    using T = std::decay_t<T_>;\n    return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;\n  }\n\n  template <typename Fn, typename Init, typename Member, typename... Rest>\n  static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... 
rest) {\n    auto new_val = f(val, member);\n    if constexpr (sizeof...(rest) > 0) {\n      return reduce_members_helper(f, new_val, rest...);\n    } else {\n      return new_val;\n    };\n  }\n\n  template <typename Fn, typename Init>\n  constexpr auto reduce_members(Fn f, Init init) const {\n    // Should be in constructor order for `from_id`\n    return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr);\n  };\n\n  template <typename Fn, typename Init>\n  static constexpr auto reduce_member_types(Fn f, Init init) {\n    constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);\n    return dummy_type.reduce_members(f, init);\n  };\n\n  static constexpr auto id_size_bits() {\n    return reduce_member_types(\n        [](int acc, auto member) -> int { return acc + member_id_field_width<decltype(member)>(); }, 0);\n  }\n\n public:\n  // Unique id for this scalar type that can be computed at compile time, used\n  //  for C++17 template specialization. This is not needed once we migrate to\n  //  C++20 and can pass literal classes as template parameters.\n  constexpr Id id() const {\n    static_assert(id_size_bits() <= sizeof(Id) * 8, \"ScalarType id is too large to be stored\");\n\n    auto or_and_advance = [](std::pair<Id, uint32_t> result, auto member) -> std::pair<Id, uint32_t> {\n      auto [id, bit_offset] = result;\n      auto constexpr bits = member_id_field_width<decltype(member)>();\n      return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits};\n    };\n    return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;\n  }\n\n  // Create a ScalarType from an id, for C++17 template specialization. This\n  //  is not needed once we migrate to C++20 and can pass literal\n  //  classes as template parameters.\n  static constexpr ScalarType from_id(Id id) {\n    auto extract_and_advance = [id](auto result, auto member) {\n      using T = decltype(member);\n      
auto [tuple, bit_offset] = result;\n      auto constexpr bits = member_id_field_width<T>();\n      auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1));\n      auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));\n      return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};\n    };\n\n    auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair<std::tuple<>, int>{});\n    return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args);\n  }\n\n  constexpr int64_t size_bits() const {\n    return mantissa + exponent + is_signed();\n  }\n  constexpr bool is_signed() const {\n    return signed_;\n  }\n  constexpr bool is_integer() const {\n    return exponent == 0;\n  }\n  constexpr bool is_floating_point() const {\n    return exponent > 0;\n  }\n  constexpr bool is_ieee_754() const {\n    return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754;\n  }\n  constexpr bool has_nans() const {\n    return is_floating_point() && nan_repr != NAN_NONE;\n  }\n  constexpr bool has_infs() const {\n    return is_floating_point() && finite_values_only == false;\n  }\n  constexpr bool has_bias() const {\n    return bias != 0;\n  }\n\n#ifndef __CUDACC__\n private:\n  double _floating_point_max() const {\n    assert(mantissa <= 52 && exponent <= 11);\n\n    uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;\n    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {\n      max_mantissa -= 1;\n    }\n\n    uint64_t max_exponent = (uint64_t(1) << exponent) - 2;\n    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {\n      assert(exponent < 11);\n      max_exponent += 1;\n    }\n\n    // adjust the exponent to match that of a double\n    //  for now we assume the exponent bias is the standard 2^(e-1) -1, (where e\n    //  is the exponent bits), there is some precedent for non-standard biases,\n    //  example `float8_e4m3b11fnuz` here: 
https://github.com/jax-ml/ml_dtypes\n    //  but to avoid premature over complication we are just assuming the\n    //  standard exponent bias until there is a need to support non-standard\n    //  biases\n    uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;\n    uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1;  // double e = 11\n\n    uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double;\n\n    // shift the mantissa into the position for a double and\n    // the exponent\n    uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);\n\n    return *reinterpret_cast<double*>(&double_raw);\n  }\n\n  constexpr std::variant<int64_t, double> _raw_max() const {\n    if (is_floating_point()) {\n      return {_floating_point_max()};\n    } else {\n      assert(size_bits() < 64 || (size_bits() == 64 && is_signed()));\n      return {(int64_t(1) << mantissa) - 1};\n    }\n  }\n\n  constexpr std::variant<int64_t, double> _raw_min() const {\n    if (is_floating_point()) {\n      assert(is_signed());\n      constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);\n\n      double max = _floating_point_max();\n      uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);\n      uint64_t min_raw = max_raw | sign_bit_double;\n      return {*reinterpret_cast<double*>(&min_raw)};\n    } else {\n      assert(!is_signed() || size_bits() <= 64);\n      if (is_signed()) {\n        // set the top bit to 1 (i.e. 
INT64_MIN) and the rest to 0\n        // then perform an arithmetic shift right to set all the bits above\n        // (size_bits() - 1) to 1\n        return {INT64_MIN >> (64 - size_bits())};\n      } else {\n        return {int64_t(0)};\n      }\n    }\n  }\n\n public:\n  // Max representable value for this scalar type.\n  // (accounting for bias if there is one)\n  constexpr std::variant<int64_t, double> max() const {\n    return std::visit([this](auto x) -> std::variant<int64_t, double> { return {x - bias}; }, _raw_max());\n  }\n\n  // Min representable value for this scalar type.\n  // (accounting for bias if there is one)\n  constexpr std::variant<int64_t, double> min() const {\n    return std::visit([this](auto x) -> std::variant<int64_t, double> { return {x - bias}; }, _raw_min());\n  }\n#endif  // __CUDACC__\n\n public:\n  std::string str() const {\n    /* naming generally follows: https://github.com/jax-ml/ml_dtypes\n     * for floating point types (leading f) the scheme is:\n     *  `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`\n     *  flags:\n     *  - no-flags: means it follows IEEE 754 conventions\n     *  - f: means finite values only (no infinities)\n     *  - n: means nans are supported (non-standard encoding)\n     * for integer types the scheme is:\n     *  `[u]int<size_bits>[b<bias>]`\n     *  - if bias is not present, it is zero\n     */\n    if (is_floating_point()) {\n      auto ret =\n          \"float\" + std::to_string(size_bits()) + \"_e\" + std::to_string(exponent) + \"m\" + std::to_string(mantissa);\n      if (!is_ieee_754()) {\n        if (finite_values_only) {\n          ret += \"f\";\n        }\n        if (nan_repr != NAN_NONE) {\n          ret += \"n\";\n        }\n      }\n      return ret;\n    } else {\n      auto ret = ((is_signed()) ? 
\"int\" : \"uint\") + std::to_string(size_bits());\n      if (has_bias()) {\n        ret += \"b\" + std::to_string(bias);\n      }\n      return ret;\n    }\n  }\n\n  constexpr bool operator==(ScalarType const& other) const {\n    return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ &&\n           finite_values_only == other.finite_values_only && nan_repr == other.nan_repr;\n  }\n};\n\nusing ScalarTypeId = ScalarType::Id;\n\n// \"rust style\" names generally following:\n//   https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70\nstatic inline constexpr auto kS4 = ScalarType::int_(4);\nstatic inline constexpr auto kU4 = ScalarType::uint(4);\nstatic inline constexpr auto kU4B8 = ScalarType::uint(4, 8);\nstatic inline constexpr auto kS8 = ScalarType::int_(8);\nstatic inline constexpr auto kU8 = ScalarType::uint(8);\nstatic inline constexpr auto kU8B128 = ScalarType::uint(8, 128);\n\nstatic inline constexpr auto kFE2M1f = ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);\nstatic inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);\nstatic inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);\nstatic inline constexpr auto kFE8M0fnu = ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);\nstatic inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);\nstatic inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);\nstatic inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);\n\n// Fixed width style names, generally following:\n//  https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57\nstatic inline constexpr auto kInt4 = kS4;\nstatic inline constexpr auto kUint4 = kU4;\nstatic inline constexpr auto kUint4b8 = kU4B8;\nstatic inline constexpr auto 
kInt8 = kS8;\nstatic inline constexpr auto kUint8 = kU8;\nstatic inline constexpr auto kUint8b128 = kU8B128;\n\nstatic inline constexpr auto kFloat4_e2m1f = kFE2M1f;\nstatic inline constexpr auto kFloat6_e3m2f = kFE3M2f;\nstatic inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;\nstatic inline constexpr auto kFloat8_e5m2 = kFE5M2;\nstatic inline constexpr auto kFloat16_e8m7 = kFE8M7;\nstatic inline constexpr auto kFloat16_e5m10 = kFE5M10;\n\n// colloquial names\nstatic inline constexpr auto kHalf = kFE5M10;\nstatic inline constexpr auto kFloat16 = kHalf;\nstatic inline constexpr auto kBFloat16 = kFE8M7;\n\nstatic inline constexpr auto kFloat16Id = kFloat16.id();\n}  // namespace host\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/source_location.h",
    "content": "/// \\file source_location.h\n/// \\brief Portable `source_location` wrapper.\n///\n/// Uses `std::source_location` when available (C++20), otherwise falls\n/// back to a minimal stub that returns empty/zero values.\n\n#pragma once\n#include <version>\n\n/// NOTE: fallback to a minimal source_location implementation\n#if defined(__cpp_lib_source_location)\n#include <source_location>\n\nusing source_location_t = std::source_location;\n\n#else\n\nstruct source_location_fallback {\n public:\n  static constexpr source_location_fallback current() noexcept {\n    return source_location_fallback{};\n  }\n  constexpr source_location_fallback() noexcept = default;\n  constexpr unsigned line() const noexcept {\n    return 0;\n  }\n  constexpr unsigned column() const noexcept {\n    return 0;\n  }\n  constexpr const char* file_name() const noexcept {\n    return \"\";\n  }\n  constexpr const char* function_name() const noexcept {\n    return \"\";\n  }\n};\n\nusing source_location_t = source_location_fallback;\n\n#endif\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/tensor.h",
    "content": "/// \\file tensor.h\n/// \\brief Tensor validation and symbolic matching utilities.\n///\n/// Provides the `TensorMatcher` fluent API for validating tensor shapes,\n/// strides, dtypes, and devices at kernel entry points, along with\n/// `SymbolicSize`, `SymbolicDType`, and `SymbolicDevice` for capturing\n/// and cross-checking tensor metadata across multiple tensors.\n///\n/// See the \"Tensor Checking\" section in the JIT kernel dev guide for\n/// usage examples.\n\n#pragma once\n#include <sgl_kernel/utils.h>\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/container/tensor.h>\n#include <tvm/ffi/dtype.h>\n\n#include <algorithm>\n#include <array>\n#include <concepts>\n#include <cstddef>\n#include <cstdint>\n#include <initializer_list>\n#include <optional>\n#include <ranges>\n#include <span>\n#include <sstream>\n#include <string>\n#include <string_view>\n#include <type_traits>\n#include <utility>\n\n#ifdef __CUDACC__\n#include <sgl_kernel/utils.cuh>\n#endif\n\nnamespace host {\n\nnamespace details {\n\ninline constexpr auto kAnyDeviceID = -1;\ninline constexpr auto kAnySize = static_cast<int64_t>(-1);\ninline constexpr auto kNullSize = static_cast<int64_t>(-1);\ninline constexpr auto kNullDType = static_cast<DLDataTypeCode>(18u);\ninline constexpr auto kNullDevice = static_cast<DLDeviceType>(-1);\n\nstruct SizeRef;\nstruct DTypeRef;\nstruct DeviceRef;\n\ntemplate <typename T>\nstruct _dtype_trait {};\n\ntemplate <std::integral T>\nstruct _dtype_trait<T> {\n  inline static constexpr DLDataType value = {\n      .code = std::is_signed_v<T> ? 
DLDataTypeCode::kDLInt : DLDataTypeCode::kDLUInt,\n      .bits = static_cast<std::uint8_t>(sizeof(T) * 8),\n      .lanes = 1};\n};\n\ntemplate <std::floating_point T>\nstruct _dtype_trait<T> {\n  inline static constexpr DLDataType value = {\n      .code = DLDataTypeCode::kDLFloat, .bits = static_cast<std::uint8_t>(sizeof(T) * 8), .lanes = 1};\n};\n\n#ifdef __CUDACC__\ntemplate <>\nstruct _dtype_trait<fp16_t> {\n  inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLFloat, .bits = 16, .lanes = 1};\n};\ntemplate <>\nstruct _dtype_trait<bf16_t> {\n  inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLBfloat, .bits = 16, .lanes = 1};\n};\ntemplate <>\nstruct _dtype_trait<fp8_e4m3_t> {\n  inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLFloat8_e4m3fn, .bits = 8, .lanes = 1};\n};\n#endif\n\ntemplate <DLDeviceType Code>\nstruct _device_trait {\n  inline static constexpr DLDevice value = {.device_type = Code, .device_id = kAnyDeviceID};\n};\n\ntemplate <typename... Ts>\ninline constexpr auto kDTypeList = std::array<DLDataType, sizeof...(Ts)>{_dtype_trait<Ts>::value...};\n\ntemplate <DLDeviceType... 
Codes>\ninline constexpr auto kDeviceList = std::array<DLDevice, sizeof...(Codes)>{_device_trait<Codes>::value...};\n\ntemplate <typename T>\nstruct PrintAbleSpan {\n  explicit PrintAbleSpan(std::span<const T> data) : data(data) {}\n  std::span<const T> data;\n};\n\n// printable name for each DLDeviceType, indexed by its numeric code\ninline constexpr auto kDeviceStringMap = [] {\n  constexpr auto map = std::array<std::pair<DLDeviceType, const char*>, 16>{\n      std::pair{DLDeviceType::kDLCPU, \"cpu\"},\n      std::pair{DLDeviceType::kDLCUDA, \"cuda\"},\n      std::pair{DLDeviceType::kDLCUDAHost, \"cuda_host\"},\n      std::pair{DLDeviceType::kDLOpenCL, \"opencl\"},\n      std::pair{DLDeviceType::kDLVulkan, \"vulkan\"},\n      std::pair{DLDeviceType::kDLMetal, \"metal\"},\n      std::pair{DLDeviceType::kDLVPI, \"vpi\"},\n      std::pair{DLDeviceType::kDLROCM, \"rocm\"},\n      std::pair{DLDeviceType::kDLROCMHost, \"rocm_host\"},\n      std::pair{DLDeviceType::kDLExtDev, \"ext_dev\"},\n      std::pair{DLDeviceType::kDLCUDAManaged, \"cuda_managed\"},\n      std::pair{DLDeviceType::kDLOneAPI, \"oneapi\"},\n      std::pair{DLDeviceType::kDLWebGPU, \"webgpu\"},\n      std::pair{DLDeviceType::kDLHexagon, \"hexagon\"},\n      std::pair{DLDeviceType::kDLMAIA, \"maia\"},\n      std::pair{DLDeviceType::kDLTrn, \"trn\"},\n  };\n  constexpr auto max_type = stdr::max(map | stdv::keys);\n  auto result = std::array<std::string_view, max_type + 1>{};\n  for (const auto& [code, name] : map) {\n    result[static_cast<std::size_t>(code)] = name;\n  }\n  return result;\n}();\n\nstruct PrintableDevice {\n  DLDevice device;\n};\n\ninline auto& operator<<(std::ostream& os, DLDevice device) {\n  const auto& mapping = kDeviceStringMap;\n  const auto entry = static_cast<std::size_t>(device.device_type);\n  RuntimeCheck(entry < mapping.size());\n  const auto name = mapping[entry];\n  RuntimeCheck(!name.empty(), \"Unknown device: \", int(device.device_type));\n  os << name;\n  if (device.device_id 
!= kAnyDeviceID && device.device_type != DLDeviceType::kDLCPU) {\n    os << \":\" << device.device_id;\n  }\n  return os;\n}\n\ninline auto& operator<<(std::ostream& os, PrintableDevice pd) {\n  return os << pd.device;\n}\n\ntemplate <typename T>\ninline auto& operator<<(std::ostream& os, PrintAbleSpan<T> span) {\n  os << \"[\";\n  for (const auto i : irange(span.data.size())) {\n    if (i > 0) {\n      os << \", \";\n    }\n    os << span.data[i];\n  }\n  os << \"]\";\n  return os;\n}\n\n}  // namespace details\n\n/// \\brief Check whether `dtype` matches the DLDataType for C++ type `T`.\ntemplate <typename T>\ninline bool is_type(DLDataType dtype) {\n  return dtype == details::_dtype_trait<T>::value;\n}\n\n/**\n * \\brief A symbolic dimension size that can be bound once and\n *        verified across multiple tensors.\n *\n * Create with an optional annotation string for error messages:\n * \\code\n *   auto N = SymbolicSize{\"num_tokens\"};\n * \\endcode\n *\n * Call `verify()` during tensor matching to either bind the first\n * observed value or check subsequent values match. Call `unwrap()`\n * to retrieve the bound value (panics if unset).\n */\nstruct SymbolicSize {\n public:\n  SymbolicSize(std::string_view annotation = {}) : m_value(details::kNullSize), m_annotation(annotation) {}\n  SymbolicSize(const SymbolicSize&) = delete;\n  SymbolicSize& operator=(const SymbolicSize&) = delete;\n\n  auto get_name() const -> std::string_view {\n    return m_annotation;\n  }\n\n  auto set_value(int64_t value) -> void {\n    RuntimeCheck(!this->has_value(), \"Size value already set\");\n    m_value = value;\n  }\n\n  auto has_value() const -> bool {\n    return m_value != details::kNullSize;\n  }\n\n  auto get_value() const -> std::optional<int64_t> {\n    return this->has_value() ? 
std::optional{m_value} : std::nullopt;\n  }\n\n  auto unwrap(DebugInfo info = {}) const -> int64_t {\n    RuntimeCheck(info, this->has_value(), \"Size value is not set\");\n    return m_value;\n  }\n\n  auto verify(int64_t value, const char* prefix, int64_t dim) -> void {\n    if (this->has_value()) {\n      if (m_value != value) {\n        [[unlikely]];\n        Panic(\"Size mismatch for \", m_name_str(prefix, dim), \": expected \", m_value, \" but got \", value);\n      }\n    } else {\n      this->set_value(value);\n    }\n  }\n\n  auto value_or_name(const char* prefix, int64_t dim) const -> std::string {\n    if (const auto value = this->get_value()) {\n      return std::to_string(*value);\n    } else {\n      return m_name_str(prefix, dim);\n    }\n  }\n\n private:\n  auto m_name_str(const char* prefix, int64_t dim) const -> std::string {\n    std::ostringstream os;\n    os << prefix << '#' << dim;\n    if (!m_annotation.empty()) os << \"('\" << m_annotation << \"')\";\n    return std::move(os).str();\n  }\n\n  std::int64_t m_value;\n  std::string_view m_annotation;\n};\n\ninline auto operator==(DLDevice lhs, DLDevice rhs) -> bool {\n  return lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id;\n}\n\n/**\n * \\brief A symbolic data type that can be constrained and verified.\n *\n * Optionally restrict allowed types via `set_options<fp16_t, bf16_t>()`.\n * Use `verify()` to bind/check the dtype, and `unwrap()` to retrieve it.\n */\nstruct SymbolicDType {\n public:\n  SymbolicDType() : m_value({details::kNullDType, 0, 0}) {}\n  SymbolicDType(const SymbolicDType&) = delete;\n  SymbolicDType& operator=(const SymbolicDType&) = delete;\n\n  auto set_value(DLDataType value) -> void {\n    RuntimeCheck(!this->has_value(), \"Dtype value already set\");\n    RuntimeCheck(\n        m_check(value), \"Dtype value [\", value, \"] not in the allowed options: \", details::PrintAbleSpan{m_options});\n    m_value = value;\n  }\n\n  auto has_value() const -> 
bool {\n    return m_value.code != details::kNullDType;\n  }\n\n  auto get_value() const -> std::optional<DLDataType> {\n    return this->has_value() ? std::optional{m_value} : std::nullopt;\n  }\n\n  auto unwrap(DebugInfo info = {}) const -> DLDataType {\n    RuntimeCheck(info, this->has_value(), \"Dtype value is not set\");\n    return m_value;\n  }\n\n  auto set_options(std::span<const DLDataType> options) -> void {\n    m_options = options;\n  }\n\n  template <typename... Ts>\n  auto set_options() -> void {\n    m_options = details::kDTypeList<Ts...>;\n  }\n\n  auto verify(DLDataType dtype) -> void {\n    if (this->has_value()) {\n      RuntimeCheck(m_value == dtype, \"DType mismatch: expected \", m_value, \" but got \", dtype);\n    } else {\n      this->set_value(dtype);\n    }\n  }\n\n  template <typename T>\n  auto is_type() const -> bool {\n    return ::host::is_type<T>(m_value);\n  }\n\n private:\n  auto m_check(DLDataType value) const -> bool {\n    return stdr::empty(m_options) || (stdr::find(m_options, value) != stdr::end(m_options));\n  }\n\n  std::span<const DLDataType> m_options;\n  DLDataType m_value;\n};\n\n/**\n * \\brief A symbolic device that can be constrained and verified.\n *\n * Optionally restrict allowed device types via\n * `set_options<kDLCUDA, kDLCPU>()`. 
The device id can be wildcarded.\n */\nstruct SymbolicDevice {\n public:\n  SymbolicDevice() : m_value({details::kNullDevice, details::kAnyDeviceID}) {}\n  SymbolicDevice(const SymbolicDevice&) = delete;\n  SymbolicDevice& operator=(const SymbolicDevice&) = delete;\n\n  auto set_value(DLDevice value) -> void {\n    RuntimeCheck(!this->has_value(), \"Device value already set\");\n    RuntimeCheck(\n        m_check(value),\n        \"Device value [\",\n        details::PrintableDevice{value},\n        \"] not in the allowed options: \",\n        details::PrintAbleSpan{m_options});\n    m_value = value;\n  }\n\n  auto has_value() const -> bool {\n    return m_value.device_type != details::kNullDevice;\n  }\n\n  auto get_value() const -> std::optional<DLDevice> {\n    return this->has_value() ? std::optional{m_value} : std::nullopt;\n  }\n\n  auto unwrap(DebugInfo info = {}) const -> DLDevice {\n    RuntimeCheck(info, this->has_value(), \"Device value is not set\");\n    return m_value;\n  }\n\n  auto set_options(std::span<const DLDevice> options) -> void {\n    m_options = options;\n  }\n\n  template <DLDeviceType... 
Codes>\n  auto set_options() -> void {\n    m_options = details::kDeviceList<Codes...>;\n  }\n\n  auto verify(DLDevice device) -> void {\n    if (this->has_value()) {\n      RuntimeCheck(\n          m_value == device,\n          \"Device mismatch: expected \",\n          details::PrintableDevice{m_value},\n          \" but got \",\n          details::PrintableDevice{device});\n    } else {\n      this->set_value(device);\n    }\n  }\n\n private:\n  auto m_check(DLDevice value) const -> bool {\n    return stdr::empty(m_options) || (stdr::any_of(m_options, [value](const DLDevice& opt) {\n             // device type must exactly match\n             if (opt.device_type != value.device_type) return false;\n             // device id can be wildcarded\n             return opt.device_id == details::kAnyDeviceID || opt.device_id == value.device_id;\n           }));\n  }\n\n  std::span<const DLDevice> m_options;\n  DLDevice m_value;\n};\n\nnamespace details {\n\ntemplate <typename T>\nstruct BaseRef {\n public:\n  BaseRef(const BaseRef&) = delete;\n  BaseRef& operator=(const BaseRef&) = delete;\n\n  auto operator->() const -> T* {\n    return m_ref;\n  }\n  auto operator*() const -> T& {\n    return *m_ref;\n  }\n  auto rebind(T& other) -> void {\n    m_ref = &other;\n  }\n\n  explicit BaseRef() : m_ref(&m_cache), m_cache() {}\n  BaseRef(T& size) : m_ref(&size), m_cache() {}\n\n private:\n  T* m_ref;\n  T m_cache;\n};\n\nstruct SizeRef : BaseRef<SymbolicSize> {\n  using BaseRef::BaseRef;\n  SizeRef(int64_t value) {\n    if (value != kAnySize) {\n      (**this).set_value(value);\n    } else {\n      // otherwise, we can match any size\n    }\n  }\n};\n\nstruct DTypeRef : BaseRef<SymbolicDType> {\n  using BaseRef::BaseRef;\n  DTypeRef(DLDataType options) {\n    (**this).set_value(options);\n  }\n  DTypeRef(std::initializer_list<DLDataType> options) {\n    (**this).set_options(options);\n  }\n  DTypeRef(std::span<const DLDataType> options) {\n    
(**this).set_options(options);\n  }\n};\n\nstruct DeviceRef : BaseRef<SymbolicDevice> {\n  using BaseRef::BaseRef;\n  DeviceRef(DLDevice options) {\n    (**this).set_value(options);\n  }\n  DeviceRef(std::initializer_list<DLDevice> options) {\n    (**this).set_options(options);\n  }\n  DeviceRef(std::span<const DLDevice> options) {\n    (**this).set_options(options);\n  }\n};\n\n}  // namespace details\n\n/**\n * \\brief Fluent API for validating tensor shape, strides, dtype, and device.\n *\n * Construct with the expected shape (using `SymbolicSize` or literal\n * integers), chain `.with_strides()`, `.with_dtype<...>()`, and\n * `.with_device<...>()`, then call `.verify(tensor)`.\n *\n * Example:\n * \\code\n *   auto N = SymbolicSize{\"N\"};\n *   TensorMatcher({N, 128})\n *       .with_dtype<fp16_t, bf16_t>()\n *       .with_device<kDLCUDA>()\n *       .verify(input_tensor);\n * \\endcode\n *\n * \\note `TensorMatcher` is non-copyable and holds spans into the braced shape and\n *       stride lists, which only live until the end of the full expression. Use it\n *       as a single temporary expression; do not store it in a variable.\n */\nstruct TensorMatcher {\n private:\n  using SizeRef = details::SizeRef;\n  using DTypeRef = details::DTypeRef;\n  using DeviceRef = details::DeviceRef;\n\n public:\n  TensorMatcher(const TensorMatcher&) = delete;\n  TensorMatcher& operator=(const TensorMatcher&) = delete;\n\n  explicit TensorMatcher(std::initializer_list<SizeRef> shape) : m_shape(shape), m_strides(), m_dtype() {}\n\n  auto with_strides(std::initializer_list<SizeRef> strides) && -> TensorMatcher&& {\n    // no partial update allowed\n    RuntimeCheck(m_strides.size() == 0, \"Strides already specified\");\n    RuntimeCheck(m_shape.size() == strides.size(), \"Strides size must match shape size\");\n    m_strides = strides;\n    return std::move(*this);\n  }\n\n  template <typename... Ts>\n  auto with_dtype(DTypeRef&& dtype) && -> TensorMatcher&& {\n    m_init_dtype();\n    m_dtype.rebind(*dtype);\n    m_dtype->set_options<Ts...>();\n    return std::move(*this);\n  }\n\n  template <typename... 
Ts>\n  auto with_dtype() && -> TensorMatcher&& {\n    static_assert(sizeof...(Ts) > 0, \"At least one dtype option must be specified\");\n    m_init_dtype();\n    m_dtype->set_options<Ts...>();\n    return std::move(*this);\n  }\n\n  template <DLDeviceType... Codes>\n  auto with_device(DeviceRef&& device) && -> TensorMatcher&& {\n    m_init_device();\n    m_device.rebind(*device);\n    m_device->set_options<Codes...>();\n    return std::move(*this);\n  }\n\n  template <DLDeviceType... Codes>\n  auto with_device() && -> TensorMatcher&& {\n    static_assert(sizeof...(Codes) > 0, \"At least one device option must be specified\");\n    m_init_device();\n    m_device->set_options<Codes...>();\n    return std::move(*this);\n  }\n\n  // once we start verification, we cannot modify anymore\n  auto verify(tvm::ffi::TensorView view, DebugInfo info = {}) const&& -> const TensorMatcher&& {\n    try {\n      m_verify_impl(view);\n    } catch (PanicError& e) {\n      auto oss = std::ostringstream{};\n      oss << \"Tensor match failed for \";\n      s_print_tensor(oss, view);\n      oss << \" at \" << info.file_name() << \":\" << info.line() << \"\\n- Root cause: \" << e.root_cause();\n      throw PanicError(std::move(oss).str());\n    }\n    return std::move(*this);\n  }\n\n private:\n  static auto s_print_tensor(std::ostringstream& oss, tvm::ffi::TensorView view) -> void {\n    oss << \"Tensor<\";\n    int64_t dim = 0;\n    for (const auto& size : view.shape()) {\n      if (dim++ > 0) oss << \", \";\n      oss << size;\n    }\n    oss << \">[strides=<\";\n    dim = 0;\n    for (const auto& stride : view.strides()) {\n      if (dim++ > 0) {\n        oss << \", \";\n      }\n      oss << stride;\n    }\n    oss << \">, dtype=\" << view.dtype();\n    oss << \", device=\" << details::PrintableDevice{view.device()} << \"]\";\n  }\n\n  auto m_verify_impl(tvm::ffi::TensorView view) const -> void {\n    const auto dim = static_cast<std::size_t>(view.dim());\n    RuntimeCheck(dim == 
m_shape.size(), \"Tensor dimension mismatch: expected \", m_shape.size(), \" but got \", dim);\n    for (const auto i : irange(dim)) {\n      m_shape[i]->verify(view.size(i), \"shape\", i);\n    }\n    if (m_has_strides()) {\n      for (const auto i : irange(dim)) {\n        // a size-1 dimension may carry any stride, so skip the stride check for it\n        // when the expected stride is already bound\n        if (view.size(i) != 1 || !m_strides[i]->has_value()) {\n          m_strides[i]->verify(view.stride(i), \"stride\", i);\n        }\n      }\n    } else {\n      RuntimeCheck(view.is_contiguous(), \"Tensor is not contiguous as expected\");\n    }\n    // always verify dtype and device: a symbol shared with other matchers may\n    // already be bound, in which case this checks that the values match\n    m_dtype->verify(view.dtype());\n    m_device->verify(view.device());\n  }\n\n  auto m_init_dtype() -> void {\n    RuntimeCheck(!m_has_dtype, \"DType already specified\");\n    m_has_dtype = true;\n  }\n\n  auto m_init_device() -> void {\n    RuntimeCheck(!m_has_device, \"Device already specified\");\n    m_has_device = true;\n  }\n\n  auto m_has_strides() const -> bool {\n    return !m_strides.empty();\n  }\n\n  std::span<const SizeRef> m_shape;\n  std::span<const SizeRef> m_strides;\n  DTypeRef m_dtype;\n  DeviceRef m_device;\n  bool m_has_dtype = false;\n  bool m_has_device = false;\n};\n\n}  // namespace host\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/tile.cuh",
    "content": "/// \\file tile.cuh\n/// \\brief Tiled memory access helpers for coalesced global memory I/O.\n///\n/// `tile::Memory<T>` represents a contiguous memory region where multiple\n/// threads cooperatively load/store elements. The three factory methods\n/// determine the thread group:\n/// - `thread()` - single thread (no tiling).\n/// - `warp()`   - all threads in a warp cooperate.\n/// - `cta()`    - all threads in the CTA cooperate.\n\n#pragma once\n#include <sgl_kernel/utils.cuh>\n\n#include <cstdint>\n\nnamespace device::tile {\n\n/**\n * \\brief Represents a contiguous memory region for cooperative tiled access.\n *\n * Each instance is parameterized by an element type `T` and bound to a\n * specific thread id (`tid`) within a group of `tsize` threads.\n *\n * \\tparam T The storage element type (e.g. `AlignedVector<packed_t<float>, 4>`).\n */\ntemplate <typename T>\nstruct Memory {\n public:\n  SGL_DEVICE constexpr Memory(uint32_t tid, uint32_t tsize) : tid(tid), tsize(tsize) {}\n  /// \\brief Create a Memory accessor for a single thread (no cooperation).\n  SGL_DEVICE static constexpr Memory thread() {\n    return Memory{0, 1};\n  }\n  /// \\brief Create a Memory accessor distributed across warp threads.\n  SGL_DEVICE static Memory warp(int warp_threads = kWarpThreads) {\n    return Memory{static_cast<uint32_t>(threadIdx.x % warp_threads), static_cast<uint32_t>(warp_threads)};\n  }\n  /// \\brief Create a Memory accessor distributed across all CTA threads.\n  SGL_DEVICE static Memory cta(int cta_threads = blockDim.x) {\n    return Memory{static_cast<uint32_t>(threadIdx.x), static_cast<uint32_t>(cta_threads)};\n  }\n  /// \\brief Load one element from `ptr` at the position assigned to this thread.\n  /// \\param ptr  Base pointer (cast to `const T*`).\n  /// \\param offset  Optional tile offset (multiplied by `tsize`).\n  SGL_DEVICE T load(const void* ptr, int64_t offset = 0) const {\n    return static_cast<const T*>(ptr)[tid + offset * tsize];\n 
 }\n  /// \\brief Store one element to `ptr` at the position assigned to this thread.\n  SGL_DEVICE void store(void* ptr, T val, int64_t offset = 0) const {\n    static_cast<T*>(ptr)[tid + offset * tsize] = val;\n  }\n  /// \\brief Check whether this thread's element index is within bounds.\n  SGL_DEVICE bool in_bound(int64_t element_count, int64_t offset = 0) const {\n    return tid + offset * tsize < element_count;\n  }\n\n private:\n  uint32_t tid;\n  uint32_t tsize;\n};\n\n}  // namespace device::tile\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/type.cuh",
    "content": "/// \\file type.cuh\n/// \\brief Dtype trait system for CUDA scalar/packed types.\n///\n/// `dtype_trait<T>` provides per-type metadata: packed type alias,\n/// conversion functions (`from`), and unary/binary math operations.\n/// Use `device::cast<To>(from_value)` for type conversion on device.\n///\n/// Registered types:\n/// | Scalar    | Packed (x2)  | Notes                         |\n/// |-----------|-------------|-------------------------------|\n/// | `fp32_t`  | `fp32x2_t`  | Full math ops (abs,sqrt,...) |\n/// | `fp16_t`  | `fp16x2_t`  | Conversion only             |\n/// | `bf16_t`  | `bf16x2_t`  | Conversion only             |\n/// | `fp32x2_t`| `fp32x4_t`  | Packed float2 <-> half2/bf162 |\n\n#pragma once\n#include <sgl_kernel/utils.cuh>\n\ntemplate <typename T>\nstruct dtype_trait {};\n\n#define SGL_REGISTER_DTYPE_TRAIT(TYPE, PACK2, ...)  \\\n  template <>                                       \\\n  struct dtype_trait<TYPE> {                        \\\n    using self_t = TYPE;                            \\\n    using packed_t = PACK2;                         \\\n    template <typename S>                           \\\n    SGL_DEVICE static self_t from(const S& value) { \\\n      return static_cast<TYPE>(value);              \\\n    }                                               \\\n    __VA_ARGS__                                     \\\n  }\n\n#define SGL_REGISTER_TYPE_END static_assert(true)\n\n#define SGL_REGISTER_FROM_FUNCTION(FROM, FN)     \\\n  SGL_DEVICE static self_t from(const FROM& x) { \\\n    return FN(x);                                \\\n  }                                              \\\n  static_assert(true)\n\n#define SGL_REGISTER_UNARY_FUNCTION(NAME, FN)      \\\n  SGL_DEVICE static self_t NAME(const self_t& x) { \\\n    return FN(x);                                  \\\n  }                                                \\\n  static_assert(true)\n\n#define SGL_REGISTER_BINARY_FUNCTION(NAME, FN)                      
\\\n  SGL_DEVICE static self_t NAME(const self_t& x, const self_t& y) { \\\n    return FN(x, y);                                                \\\n  }                                                                 \\\n  static_assert(true)\n\nSGL_REGISTER_DTYPE_TRAIT(\n    fp32_t, fp32x2_t, SGL_REGISTER_TYPE_END;  //\n    SGL_REGISTER_FROM_FUNCTION(fp16_t, __half2float);\n    SGL_REGISTER_FROM_FUNCTION(bf16_t, __bfloat162float);\n    SGL_REGISTER_UNARY_FUNCTION(abs, fabsf);\n    SGL_REGISTER_UNARY_FUNCTION(sqrt, sqrtf);\n    SGL_REGISTER_UNARY_FUNCTION(rsqrt, rsqrtf);\n    SGL_REGISTER_UNARY_FUNCTION(exp, expf);\n    SGL_REGISTER_UNARY_FUNCTION(sin, sinf);\n    SGL_REGISTER_UNARY_FUNCTION(cos, cosf);\n    SGL_REGISTER_BINARY_FUNCTION(max, fmaxf);\n    SGL_REGISTER_BINARY_FUNCTION(min, fminf););\nSGL_REGISTER_DTYPE_TRAIT(fp16_t, fp16x2_t);\nSGL_REGISTER_DTYPE_TRAIT(bf16_t, bf16x2_t);\n\n/// TODO: Add ROCM implementation\nSGL_REGISTER_DTYPE_TRAIT(\n    fp32x2_t, fp32x4_t, SGL_REGISTER_TYPE_END; SGL_REGISTER_FROM_FUNCTION(fp16x2_t, __half22float2);\n    SGL_REGISTER_FROM_FUNCTION(bf16x2_t, __bfloat1622float2););\n\nSGL_REGISTER_DTYPE_TRAIT(\n    fp16x2_t, void, SGL_REGISTER_TYPE_END; SGL_REGISTER_FROM_FUNCTION(fp32x2_t, __float22half2_rn););\n\nSGL_REGISTER_DTYPE_TRAIT(\n    bf16x2_t, void, SGL_REGISTER_TYPE_END; SGL_REGISTER_FROM_FUNCTION(fp32x2_t, __float22bfloat162_rn););\n\n#undef SGL_REGISTER_DTYPE_TRAIT\n#undef SGL_REGISTER_FROM_FUNCTION\n\n/// \\brief Alias: the packed (x2) type for `T`.\ntemplate <typename T>\nusing packed_t = typename dtype_trait<T>::packed_t;\n\nnamespace device {\n\n/**\n * \\brief Cast a value from type `From` to type `To` on device.\n *\n * Dispatches through `dtype_trait<To>::from()`, which uses the appropriate\n * CUDA intrinsic (e.g. `__half2float`, `__float22half2_rn`).\n */\ntemplate <typename To, typename From>\nSGL_DEVICE To cast(const From& value) {\n  return dtype_trait<To>::from(value);\n}\n\n}  // namespace device\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/utils.cuh",
    "content": "/// \\file utils.cuh\n/// \\brief Core CUDA/device utilities: type aliases, PDL helpers,\n///        typed pointer access, kernel launch wrapper, and error checking.\n///\n/// This header is included (directly or transitively) by nearly every\n/// JIT kernel. It provides:\n/// - Scalar/packed type aliases (`fp16_t`, `bf16_t`, `fp8_e4m3_t`, ...).\n/// - `SGL_DEVICE` macro (forced-inline device function qualifier).\n/// - `kWarpThreads` constant (32).\n/// - PDL (Programmatic Dependent Launch) helpers for Hopper (sm_90+).\n/// - Typed `load_as` / `store_as` for void-pointer access.\n/// - `pointer::offset` for safe void-pointer arithmetic.\n/// - `host::LaunchKernel` - kernel launcher with optional PDL.\n/// - `host::RuntimeDeviceCheck` - CUDA error checking.\n\n#pragma once\n\n#include <sgl_kernel/utils.h>\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/extra/c_env_api.h>\n\n#include <concepts>\n#include <cstddef>\n#include <type_traits>\n#ifndef USE_ROCM\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda_fp8.h>\n#include <cuda_runtime.h>\n#else\n#include <hip/hip_bf16.h>\n#include <hip/hip_fp16.h>\n#include <hip/hip_runtime.h>\n#ifndef __grid_constant__\n#define __grid_constant__\n#endif\nusing cudaError_t = hipError_t;\nusing cudaStream_t = hipStream_t;\nusing cudaLaunchConfig_t = hipLaunchConfig_t;\nusing cudaLaunchAttribute = hipLaunchAttribute;\ninline constexpr auto cudaSuccess = hipSuccess;\n#define cudaStreamPerThread hipStreamPerThread\n#define cudaGetErrorString hipGetErrorString\n#define cudaGetLastError hipGetLastError\n#define cudaLaunchKernel hipLaunchKernel\n#endif\n\n#ifndef USE_ROCM\nusing fp32_t = float;\nusing fp16_t = __half;\nusing bf16_t = __nv_bfloat16;\nusing fp8_e4m3_t = __nv_fp8_e4m3;\nusing fp8_e5m2_t = __nv_fp8_e5m2;\n\nusing fp32x2_t = float2;\nusing fp16x2_t = __half2;\nusing bf16x2_t = __nv_bfloat162;\nusing fp8x2_e4m3_t = __nv_fp8x2_e4m3;\nusing fp8x2_e5m2_t = __nv_fp8x2_e5m2;\n\nusing fp32x4_t = 
 float4;\n#else\nusing fp32_t = float;\nusing fp16_t = __half;\nusing bf16_t = __hip_bfloat16;\nusing fp8_e4m3_t = uint8_t;\nusing fp8_e5m2_t = uint8_t;\nusing fp32x2_t = float2;\nusing fp16x2_t = half2;\nusing bf16x2_t = __hip_bfloat162;\nusing fp8x2_e4m3_t = uint16_t;\nusing fp8x2_e5m2_t = uint16_t;\nusing fp32x4_t = float4;\n#endif\n\n/*\n * LDG Support\n */\n#ifndef USE_ROCM\n#define SGLANG_LDG(arg) __ldg(arg)\n#else\n#define SGLANG_LDG(arg) *(arg)\n#endif\n\nnamespace device {\n\n/// \\brief Macro: forced-inline device function qualifier.\n#define SGL_DEVICE __forceinline__ __device__\n\n// Architecture detection: SGL_CUDA_ARCH is injected by load_jit() and is\n// available in both host and device compilation passes, whereas __CUDA_ARCH__\n// is only defined by nvcc during the device pass.\n#if !defined(USE_ROCM)\n#if !defined(SGL_CUDA_ARCH)\n#error \"SGL_CUDA_ARCH is not defined. JIT compilation must inject -DSGL_CUDA_ARCH via load_jit().\"\n#endif\n#if defined(__CUDA_ARCH__)\nstatic_assert(\n    __CUDA_ARCH__ == SGL_CUDA_ARCH, \"SGL_CUDA_ARCH mismatch: injected arch flag does not match device target\");\n#endif\n#define SGL_ARCH_HOPPER_OR_GREATER (SGL_CUDA_ARCH >= 900)\n#define SGL_ARCH_BLACKWELL_OR_GREATER ((SGL_CUDA_ARCH >= 1000) && (CUDA_VERSION >= 12090))\n#else  // USE_ROCM\n#define SGL_ARCH_HOPPER_OR_GREATER 0\n#define SGL_ARCH_BLACKWELL_OR_GREATER 0\n#endif\n\n/// \\brief Number of threads per warp (32 on NVIDIA GPUs).\ninline constexpr auto kWarpThreads = 32u;\n/// \\brief Full warp active mask (all 32 lanes).\ninline constexpr auto kFullMask = 0xffffffffu;\n\n/**\n * \\brief PDL (Programmatic Dependent Launch): wait for the primary kernel.\n *\n * On Hopper (sm_90+), inserts a `griddepcontrol.wait` instruction to\n * synchronize with a preceding kernel in the same stream.
On older\n * architectures or ROCm this is a no-op.\n */\ntemplate <bool kUsePDL>\nSGL_DEVICE void PDLWaitPrimary() {\n#if SGL_ARCH_HOPPER_OR_GREATER\n  if constexpr (kUsePDL) {\n    asm volatile(\"griddepcontrol.wait;\" ::: \"memory\");\n  }\n#endif\n}\n\n/**\n * \\brief PDL: trigger dependent (secondary) kernel launch.\n *\n * On Hopper (sm_90+), inserts a `griddepcontrol.launch_dependents`\n * instruction. On older architectures or ROCm this is a no-op.\n */\ntemplate <bool kUsePDL>\nSGL_DEVICE void PDLTriggerSecondary() {\n#if SGL_ARCH_HOPPER_OR_GREATER\n  if constexpr (kUsePDL) {\n    asm volatile(\"griddepcontrol.launch_dependents;\" :::);\n  }\n#endif\n}\n\n/**\n * \\brief Load data with the specified type and offset from a void pointer.\n * \\tparam T The type to load.\n * \\param ptr The base pointer.\n * \\param offset The offset in number of elements of type T.\n */\ntemplate <typename T>\nSGL_DEVICE T load_as(const void* ptr, int64_t offset = 0) {\n  return static_cast<const T*>(ptr)[offset];\n}\n\n/**\n * \\brief Store data with the specified type and offset to a void pointer.\n * \\tparam T The type to store.\n * \\param ptr The base pointer.\n * \\param val The value to store.\n * \\param offset The offset in number of elements of type T.\n * \\note we use type_identity_t to force the caller to explicitly specify\n * the template parameter `T`, which can avoid accidentally using the wrong type.\n */\ntemplate <typename T>\nSGL_DEVICE void store_as(void* ptr, std::type_identity_t<T> val, int64_t offset = 0) {\n  static_cast<T*>(ptr)[offset] = val;\n}\n\n/// \\brief Safe void-pointer arithmetic (byte-level by default).\nnamespace pointer {\n\n// we only allow void * pointer arithmetic for safety\n\ntemplate <typename T = char, std::integral... U>\nSGL_DEVICE auto offset(void* ptr, U... offset) -> void* {\n  return static_cast<T*>(ptr) + (... + offset);\n}\n\ntemplate <typename T = char, std::integral... U>\nSGL_DEVICE auto offset(const void* ptr, U... 
offset) -> const void* {\n  return static_cast<const T*>(ptr) + (... + offset);\n}\n\n}  // namespace pointer\n\n}  // namespace device\n\nnamespace host {\n\n/**\n * \\brief Check the CUDA error code and panic with location info on failure.\n */\ninline void RuntimeDeviceCheck(::cudaError_t error, DebugInfo location = {}) {\n  if (error != ::cudaSuccess) {\n    [[unlikely]];\n    ::host::panic(location, \"CUDA error: \", ::cudaGetErrorString(error));\n  }\n}\n\n/// \\brief Check the last CUDA error (calls `cudaGetLastError`).\ninline void RuntimeDeviceCheck(DebugInfo location = {}) {\n  return RuntimeDeviceCheck(::cudaGetLastError(), location);\n}\n\n/**\n * \\brief Kernel launcher with automatic stream resolution and PDL support.\n *\n * Usage:\n * \\code\n *   host::LaunchKernel(grid, block, device)\n *       .enable_pdl(true)\n *       (my_kernel, arg1, arg2);\n * \\endcode\n *\n * The constructor resolves the CUDA stream from a `DLDevice` (via\n * `TVMFFIEnvGetStream`) or accepts a raw `cudaStream_t`. 
The call\n * operator launches the kernel and checks for errors.\n */\nstruct LaunchKernel {\n public:\n  explicit LaunchKernel(\n      dim3 grid_dim,\n      dim3 block_dim,\n      DLDevice device,\n      std::size_t dynamic_shared_mem_bytes = 0,\n      DebugInfo location = {}) noexcept\n      : m_config(s_make_config(grid_dim, block_dim, resolve_device(device), dynamic_shared_mem_bytes)),\n        m_location(location) {}\n\n  explicit LaunchKernel(\n      dim3 grid_dim,\n      dim3 block_dim,\n      cudaStream_t stream,\n      std::size_t dynamic_shared_mem_bytes = 0,\n      DebugInfo location = {}) noexcept\n      : m_config(s_make_config(grid_dim, block_dim, stream, dynamic_shared_mem_bytes)), m_location(location) {}\n\n  LaunchKernel(const LaunchKernel&) = delete;\n  LaunchKernel& operator=(const LaunchKernel&) = delete;\n\n  static auto resolve_device(DLDevice device) -> cudaStream_t {\n    return static_cast<cudaStream_t>(::TVMFFIEnvGetStream(device.device_type, device.device_id));\n  }\n\n  auto enable_pdl(bool enabled = true) -> LaunchKernel& {\n#ifdef USE_ROCM\n    (void)enabled;\n    m_config.numAttrs = 0;\n#else\n    if (enabled) {\n      m_attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;\n      m_attrs[0].val.programmaticStreamSerializationAllowed = true;\n      m_config.numAttrs = 1;\n      m_config.attrs = m_attrs;\n    } else {\n      m_config.numAttrs = 0;\n    }\n#endif\n    return *this;\n  }\n\n  template <typename T, typename... Args>\n  auto operator()(T&& kernel, Args&&... 
args) const -> void {\n#ifdef USE_ROCM\n    hipLaunchKernelGGL(\n        std::forward<T>(kernel),\n        m_config.gridDim,\n        m_config.blockDim,\n        m_config.dynamicSmemBytes,\n        m_config.stream,\n        std::forward<Args>(args)...);\n    RuntimeDeviceCheck(m_location);\n#else\n    RuntimeDeviceCheck(::cudaLaunchKernelEx(&m_config, kernel, std::forward<Args>(args)...), m_location);\n#endif\n  }\n\n private:\n  static auto s_make_config(  // Make a config for kernel launch\n      dim3 grid_dim,\n      dim3 block_dim,\n      cudaStream_t stream,\n      std::size_t smem) -> cudaLaunchConfig_t {\n    auto config = ::cudaLaunchConfig_t{};\n    config.gridDim = grid_dim;\n    config.blockDim = block_dim;\n    config.dynamicSmemBytes = smem;\n    config.stream = stream;\n    config.numAttrs = 0;\n    return config;\n  }\n\n  cudaLaunchConfig_t m_config;\n  const DebugInfo m_location;\n  cudaLaunchAttribute m_attrs[1];\n};\n\n}  // namespace host\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/utils.h",
    "content": "/// \\file utils.h\n/// \\brief Host-side C++ utilities used by JIT kernel wrappers.\n///\n/// Provides:\n/// - `DebugInfo` - wraps `std::source_location` for error reporting.\n/// - `RuntimeCheck` - runtime assertion with formatted error messages.\n/// - `Panic` - unconditional abort with formatted error messages.\n/// - `pointer::offset` - safe void-pointer arithmetic (host side).\n/// - `div_ceil` - integer ceiling division.\n/// - `dtype_bytes` - byte width of a `DLDataType`.\n/// - `irange` - Python-style integer range for range-for loops.\n\n#pragma once\n\n// ref: https://forums.developer.nvidia.com/t/c-20s-source-location-compilation-error-when-using-nvcc-12-1/258026/3\n#ifdef __CUDACC__\n#include <cuda.h>\n#if CUDA_VERSION <= 12010\n\n#pragma push_macro(\"__cpp_consteval\")\n#pragma push_macro(\"_NODISCARD\")\n#pragma push_macro(\"__builtin_LINE\")\n\n#pragma clang diagnostic push\n#pragma clang diagnostic ignored \"-Wbuiltin-macro-redefined\"\n#define __cpp_consteval 201811L\n#pragma clang diagnostic pop\n\n#ifdef _NODISCARD\n#undef _NODISCARD\n#define _NODISCARD\n#endif\n\n#define consteval constexpr\n\n#include \"source_location.h\"\n\n#undef consteval\n#pragma pop_macro(\"__cpp_consteval\")\n#pragma pop_macro(\"_NODISCARD\")\n#else  // __CUDACC__ && CUDA_VERSION > 12010\n#include \"source_location.h\"\n#endif\n#else  // no __CUDACC__\n#include \"source_location.h\"\n#endif\n\n#include <dlpack/dlpack.h>\n\n#include <concepts>\n#include <cstddef>\n#include <ostream>\n#include <ranges>\n#include <sstream>\n#include <utility>\n\nnamespace host {\n\ntemplate <typename>\ninline constexpr bool dependent_false_v = false;\n\n/// \\brief Source-location wrapper for debug/error messages.\nstruct DebugInfo : public source_location_t {\n  DebugInfo(source_location_t loc = source_location_t::current()) : source_location_t(loc) {}\n};\n\n/// \\brief Exception type thrown by `RuntimeCheck` and `Panic`.\nstruct PanicError : public std::runtime_error {\n 
public:\n  explicit PanicError(std::string msg) : runtime_error(msg), m_message(std::move(msg)) {}\n  auto root_cause() const -> std::string_view {\n    const auto str = std::string_view{m_message};\n    const auto pos = str.find(\": \");\n    return pos == std::string_view::npos ? str : str.substr(pos + 2);\n  }\n\n private:\n  std::string m_message;\n};\n\n/// \\brief Unconditionally abort with a formatted error message.\ntemplate <typename... Args>\n[[noreturn]]\ninline auto panic(DebugInfo location, Args&&... args) -> void {\n  std::ostringstream os;\n  os << \"Runtime check failed at \" << location.file_name() << \":\" << location.line();\n  if constexpr (sizeof...(args) > 0) {\n    os << \": \";\n    (os << ... << std::forward<Args>(args));\n  } else {\n    os << \" in \" << location.function_name();\n  }\n  throw PanicError(std::move(os).str());\n}\n\n/**\n * \\brief Runtime assertion: panics with a formatted message when `condition`\n *        is false. Extra `args` are streamed to the error message.\n *\n * Example:\n * \\code\n *   RuntimeCheck(n > 0, \"n must be positive, got \", n);\n * \\endcode\n */\ntemplate <typename... Args>\nstruct RuntimeCheck {\n  template <typename Cond>\n  explicit RuntimeCheck(Cond&& condition, Args&&... args, DebugInfo location = {}) {\n    if (condition) return;\n    [[unlikely]] ::host::panic(location, std::forward<Args>(args)...);\n  }\n  template <typename Cond>\n  explicit RuntimeCheck(DebugInfo location, Cond&& condition, Args&&... args) {\n    if (condition) return;\n    [[unlikely]] ::host::panic(location, std::forward<Args>(args)...);\n  }\n};\n\ntemplate <typename... Args>\nstruct Panic {\n  explicit Panic(Args&&... args, DebugInfo location = {}) {\n    ::host::panic(location, std::forward<Args>(args)...);\n  }\n  explicit Panic(DebugInfo location, Args&&... 
args) {\n    ::host::panic(location, std::forward<Args>(args)...);\n  }\n  [[noreturn]] ~Panic() {\n    std::terminate();\n  }\n};\n\ntemplate <typename Cond, typename... Args>\nexplicit RuntimeCheck(Cond&&, Args&&...) -> RuntimeCheck<Args...>;\n\ntemplate <typename Cond, typename... Args>\nexplicit RuntimeCheck(DebugInfo, Cond&&, Args&&...) -> RuntimeCheck<Args...>;\n\ntemplate <typename... Args>\nexplicit Panic(Args&&...) -> Panic<Args...>;\n\ntemplate <typename... Args>\nexplicit Panic(DebugInfo, Args&&...) -> Panic<Args...>;\n\nnamespace pointer {\n\n// we only allow void * pointer arithmetic for safety\n\ntemplate <typename T = char, std::integral... U>\ninline auto offset(void* ptr, U... offset) -> void* {\n  return static_cast<T*>(ptr) + (... + offset);\n}\n\ntemplate <typename T = char, std::integral... U>\ninline auto offset(const void* ptr, U... offset) -> const void* {\n  return static_cast<const T*>(ptr) + (... + offset);\n}\n\n}  // namespace pointer\n\n/// \\brief Integer ceiling division: ceil(a / b).\ntemplate <std::integral T, std::integral U>\ninline constexpr auto div_ceil(T a, U b) {\n  return (a + b - 1) / b;\n}\n\n/// \\brief Returns the byte width of a DLPack data type.\ninline auto dtype_bytes(DLDataType dtype) -> std::size_t {\n  return static_cast<std::size_t>(dtype.bits / 8);\n}\n\nnamespace stdr = std::ranges;\nnamespace stdv = stdr::views;\n\n/// \\brief Python-style integer range: `irange(n)` -> `[0, n)`.\ntemplate <std::integral T>\ninline auto irange(T end) {\n  return stdv::iota(static_cast<T>(0), end);\n}\n\n/// \\brief Python-style integer range: `irange(start, end)` -> `[start, end)`.\ntemplate <std::integral T>\ninline auto irange(T start, T end) {\n  return stdv::iota(start, end);\n}\n\n}  // namespace host\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/vec.cuh",
    "content": "/// \\file vec.cuh\n/// \\brief Aligned vector types for coalesced global memory access.\n///\n/// `AlignedVector<T, N>` wraps `N` elements of type `T` in a naturally\n/// aligned struct so that the compiler emits wide (vectorized) load/store\n/// instructions (e.g. `LDG.128`). The maximum supported vector width is\n/// 256 bits (32 bytes), matching CUDA's widest vector load.\n\n#pragma once\n#include <sgl_kernel/utils.cuh>\n\n#include <cstddef>\n#include <cstdint>\n\nnamespace device {\n\nnamespace details {\n\n/// \\brief Maps byte-width to the corresponding unsigned integer type.\ntemplate <std::size_t N>\nstruct uint_trait {};\n\ntemplate <>\nstruct uint_trait<1> {\n  using type = uint8_t;\n};\n\ntemplate <>\nstruct uint_trait<2> {\n  using type = uint16_t;\n};\n\ntemplate <>\nstruct uint_trait<4> {\n  using type = uint32_t;\n};\n\ntemplate <>\nstruct uint_trait<8> {\n  using type = uint64_t;\n};\n\n/// \\brief Alias: maps `sizeof(T)` to matching unsigned int type.\ntemplate <typename T>\nusing sized_int = typename uint_trait<sizeof(T)>::type;\n\n}  // namespace details\n\n/// \\brief Raw aligned storage for `N` elements of type `T`.\ntemplate <typename T, std::size_t N>\nstruct alignas(sizeof(T) * N) AlignedStorage {\n  T data[N];\n};\n\n/**\n * \\brief Aligned vector for vectorized memory access on GPU.\n *\n * Stores `N` elements of type `T` with natural alignment so that a single\n * `load`/`store` call compiles to a wide memory transaction.\n *\n * \\tparam T Element type (e.g. `fp16_t`, `bf16_t`, `float`).\n * \\tparam N Number of elements. 
Must be a power of two and\n *           `sizeof(T) * N <= 32` (256 bits).\n *\n * Example:\n * \\code\n *   AlignedVector<fp16_t, 8> vec;  // 16 bytes, 128-bit aligned\n *   vec.load(input_ptr, tid);      // vectorized load\n *   vec[0] = vec[0] + 1;\n *   vec.store(output_ptr, tid);    // vectorized store\n * \\endcode\n */\ntemplate <typename T, std::size_t N>\nstruct AlignedVector {\n private:\n  /// NOTE: N must be a power of two and sizeof(T) * N <= 32 bytes (256 bits)\n  static_assert((N > 0 && (N & (N - 1)) == 0) && sizeof(T) * N <= 32, \"CUDA only supports at most 256-bit vector op\");\n  using element_t = typename details::sized_int<T>;\n  using storage_t = AlignedStorage<element_t, N>;\n\n public:\n  /// \\brief Vectorized load from `ptr` at the given element `offset`.\n  template <typename U>\n  SGL_DEVICE void load(const U* ptr, std::size_t offset = 0) {\n    static_assert(std::is_same_v<U, T> || std::is_same_v<U, void>);\n    m_storage = reinterpret_cast<const storage_t*>(ptr)[offset];\n  }\n  /// \\brief Vectorized store to `ptr` at the given element `offset`.\n  template <typename U>\n  SGL_DEVICE void store(U* ptr, std::size_t offset = 0) const {\n    static_assert(std::is_same_v<U, T> || std::is_same_v<U, void>);\n    reinterpret_cast<storage_t*>(ptr)[offset] = m_storage;\n  }\n  /// \\brief Fill all N elements with the same `value`.\n  SGL_DEVICE void fill(T value) {\n    const auto store_value = *reinterpret_cast<element_t*>(&value);\n#pragma unroll\n    for (std::size_t i = 0; i < N; ++i) {\n      m_storage.data[i] = store_value;\n    }\n  }\n\n  SGL_DEVICE auto operator[](std::size_t idx) -> T& {\n    return reinterpret_cast<T*>(&m_storage)[idx];\n  }\n  SGL_DEVICE auto operator[](std::size_t idx) const -> T {\n    return reinterpret_cast<const T*>(&m_storage)[idx];\n  }\n  SGL_DEVICE auto data() -> T* {\n    return reinterpret_cast<T*>(&m_storage);\n  }\n  SGL_DEVICE auto data() const -> const T* {\n    return reinterpret_cast<const 
T*>(&m_storage);\n  }\n\n private:\n  storage_t m_storage;\n};\n\n}  // namespace device\n"
  },
  {
    "path": "python/sglang/jit_kernel/include/sgl_kernel/warp.cuh",
    "content": "/// \\file warp.cuh\n/// \\brief Warp-level reduction primitives using `__shfl_xor_sync`.\n\n#pragma once\n#include <sgl_kernel/math.cuh>\n\nnamespace device::warp {\n\n/// \\brief Full 32-thread active mask.\nstatic constexpr uint32_t kFullMask = 0xffffffffu;\n\n/**\n * \\brief Warp-level sum reduction.\n *\n * Computes the sum of `value` across all active lanes specified by\n * `active_mask` using butterfly (XOR) shuffles. The result is\n * broadcast to all participating lanes.\n *\n * \\tparam T Numeric type (e.g. float).\n * \\param value Per-lane input value.\n * \\param active_mask Bitmask of participating lanes (default: all 32).\n * \\return The sum across all active lanes.\n */\ntemplate <typename T>\nSGL_DEVICE T reduce_sum(T value, uint32_t active_mask = kFullMask) {\n#pragma unroll\n  for (int mask = 16; mask > 0; mask >>= 1)\n    value = value + __shfl_xor_sync(active_mask, value, mask, 32);\n  return value;\n}\n\n/**\n * \\brief Warp-level max reduction.\n *\n * Computes the maximum of `value` across all active lanes using\n * butterfly shuffles. The result is broadcast to all participating\n * lanes.\n *\n * \\tparam T Numeric type (must be supported by `math::max`).\n * \\param value Per-lane input value.\n * \\param active_mask Bitmask of participating lanes (default: all 32).\n * \\return The maximum across all active lanes.\n */\ntemplate <typename T>\nSGL_DEVICE T reduce_max(T value, uint32_t active_mask = kFullMask) {\n#pragma unroll\n  for (int mask = 16; mask > 0; mask >>= 1)\n    value = math::max(value, __shfl_xor_sync(active_mask, value, mask, 32));\n  return value;\n}\n\n}  // namespace device::warp\n"
  },
  {
    "path": "python/sglang/jit_kernel/kvcache.py",
    "content": "from __future__ import annotations\n\nimport logging\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import (\n    cache_once,\n    is_arch_support_pdl,\n    load_jit,\n    make_cpp_args,\n)\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_kvcache_module(row_bytes: int) -> Module:\n    args = make_cpp_args(row_bytes, is_arch_support_pdl())\n    return load_jit(\n        \"kvcache\",\n        *args,\n        cuda_files=[\"elementwise/kvcache.cuh\"],\n        cuda_wrappers=[(\"store_cache\", f\"StoreKVCacheKernel<{args}>::run\")],\n    )\n\n\n@cache_once\ndef can_use_store_cache(size: int) -> bool:\n    logger = logging.getLogger(__name__)\n    if size % 4 != 0:\n        logger.warning(\n            f\"Unsupported row_bytes={size} for JIT KV-Cache kernel:\"\n            \" must be multiple of 4\"\n        )\n        return False\n    try:\n        _jit_kvcache_module(size)\n        return True\n    except Exception as e:\n        logger.warning(\n            f\"Failed to load JIT KV-Cache kernel \" f\"with row_bytes={size}: {e}\"\n        )\n        return False\n\n\ndef store_cache(\n    k: torch.Tensor,\n    v: torch.Tensor,\n    k_cache: torch.Tensor,\n    v_cache: torch.Tensor,\n    indices: torch.Tensor,\n    *,\n    row_bytes: int = 0,\n    num_split: int = 0,  # can be tuned for performance\n) -> None:\n    \"\"\"Store key and value tensors into KV cache at specified indices.\n\n    Args:\n        k (torch.Tensor): Key tensor of shape (batch_size, H * D).\n        v (torch.Tensor): Value tensor of shape (batch_size, H * D).\n        k_cache (torch.Tensor): Key cache tensor of shape (num_pages, H * D).\n        v_cache (torch.Tensor): Value cache tensor of shape (num_pages, H * D).\n        indices (torch.Tensor): Indices tensor of shape (batch_size,).\n    \"\"\"\n    row_bytes = row_bytes or k.shape[-1] * k.element_size()\n    module = _jit_kvcache_module(row_bytes)\n    
if num_split <= 0:\n        if row_bytes % 2048 == 0:\n            num_split = 4\n        elif row_bytes % 1024 == 0:\n            num_split = 2\n        else:\n            num_split = 1\n    module.store_cache(\n        k,\n        v,\n        k_cache,\n        v_cache,\n        indices,\n        num_split,\n    )\n"
  },
  {
    "path": "python/sglang/jit_kernel/moe_lora_align.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Optional\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_moe_align_module(dtype: torch.dtype) -> Module:\n    args = make_cpp_args(dtype)\n    return load_jit(\n        \"moe_lora_align_block_size\",\n        *args,\n        cuda_files=[\"lora/moe_lora_align_kernel.cu\"],\n        cuda_wrappers=[\n            (\"moe_lora_align_block_size\", f\"MoeLoraAlignBlockSizeKernel<{args}>::run\"),\n        ],\n    )\n\n\ndef moe_lora_align_block_size(\n    topk_ids: torch.Tensor,\n    seg_indptr: torch.Tensor,\n    req_to_lora: torch.Tensor,\n    num_experts: int,\n    block_size: int,\n    max_loras: int,\n    max_num_tokens_padded: int,\n    max_num_m_blocks: int,\n    sorted_token_ids: torch.Tensor,\n    expert_ids: torch.Tensor,\n    num_tokens_post_pad: torch.Tensor,\n    adapter_enabled: torch.Tensor,\n    lora_ids: torch.Tensor,\n    maybe_expert_map: Optional[torch.Tensor] = None,\n) -> None:\n    module = _jit_moe_align_module(topk_ids.dtype)\n\n    cumsum_buffer = torch.zeros(\n        max_loras * (num_experts + 1), dtype=torch.int32, device=topk_ids.device\n    )\n    token_mask = torch.empty(\n        (max_loras * topk_ids.shape[0],), dtype=torch.int32, device=topk_ids.device\n    )\n\n    module.moe_lora_align_block_size(\n        topk_ids,\n        seg_indptr,\n        req_to_lora,\n        num_experts,\n        block_size,\n        max_loras,\n        max_num_tokens_padded,\n        max_num_m_blocks,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_pad,\n        adapter_enabled,\n        lora_ids,\n        maybe_expert_map,\n        cumsum_buffer,\n        token_mask,\n    )\n"
  },
  {
    "path": "python/sglang/jit_kernel/moe_wna16_marlin.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Optional\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\nif TYPE_CHECKING:\n    from sgl_kernel.scalar_type import ScalarType\n    from tvm_ffi.module import Module\n\n# Constants matching device::marlin_moe:: in marlin.cuh\n_MAX_THREAD_N = 256\n\n\n@cache_once\ndef _jit_moe_wna16_marlin_module(dtype: torch.dtype) -> Module:\n    args = make_cpp_args(dtype)\n    return load_jit(\n        \"moe_wna16_marlin\",\n        *args,\n        cuda_files=[\"gemm/marlin_moe/moe_wna16_marlin.cuh\"],\n        cuda_wrappers=[\n            (\n                \"moe_wna16_marlin_gemm\",\n                f\"moe_wna16_marlin_gemm<{args}>\",\n            )\n        ],\n    )\n\n\ndef _or_empty(\n    t: Optional[torch.Tensor], device: torch.device, dtype: torch.dtype\n) -> torch.Tensor:\n    return t if t is not None else torch.empty(0, device=device, dtype=dtype)\n\n\ndef moe_wna16_marlin_gemm(\n    a: torch.Tensor,\n    c_or_none: Optional[torch.Tensor],\n    b_q_weight: torch.Tensor,\n    b_bias_or_none: Optional[torch.Tensor],\n    b_scales: torch.Tensor,\n    global_scale_or_none: Optional[torch.Tensor],\n    b_zeros_or_none: Optional[torch.Tensor],\n    g_idx_or_none: Optional[torch.Tensor],\n    perm_or_none: Optional[torch.Tensor],\n    workspace: torch.Tensor,\n    sorted_token_ids: torch.Tensor,\n    expert_ids: torch.Tensor,\n    num_tokens_post_padded: torch.Tensor,\n    topk_weights: torch.Tensor,\n    moe_block_size: int,\n    top_k: int,\n    mul_topk_weights: bool,\n    is_ep: bool,\n    b_q_type: ScalarType,\n    size_m: int,\n    size_n: int,\n    size_k: int,\n    is_k_full: bool = True,\n    use_atomic_add: bool = False,\n    use_fp32_reduce: bool = False,\n    is_zp_float: bool = False,\n) -> torch.Tensor:\n    device = a.device\n\n    # Allocate output if not provided\n    if c_or_none is not None:\n        c = c_or_none\n    
else:\n        c = torch.empty((size_m * top_k, size_n), dtype=a.dtype, device=device)\n\n    # Early return for zero-size M\n    if size_m == 0:\n        return c\n\n    # Determine activation ordering\n    has_act_order = (\n        g_idx_or_none is not None\n        and perm_or_none is not None\n        and g_idx_or_none.numel() > 0\n        and perm_or_none.numel() > 0\n        and g_idx_or_none.size(-1) > 0\n        and perm_or_none.size(-1) > 0\n    )\n\n    # Determine has_zp\n    has_zp = b_zeros_or_none is not None and b_zeros_or_none.numel() > 0\n\n    # Determine has_bias\n    has_bias = b_bias_or_none is not None\n\n    # Derive num_groups and group_size from b_scales\n    num_groups = b_scales.size(1)\n    if has_act_order:\n        if is_k_full:\n            group_size = size_k // num_groups\n        else:\n            group_size = 0\n    else:\n        if num_groups > 1:\n            group_size = size_k // num_groups\n        else:\n            group_size = -1\n\n    # Allocate a_tmp for act_order column permutation\n    if has_act_order:\n        a_tmp = torch.empty((size_m * top_k, size_k), dtype=a.dtype, device=device)\n    else:\n        a_tmp = torch.empty(0, dtype=a.dtype, device=device)\n\n    # Allocate c_tmp for fp32 reduce\n    if use_fp32_reduce and not use_atomic_add:\n        sms = torch.cuda.get_device_properties(device).multi_processor_count\n        # max num of threadblocks is sms * 4\n        max_c_tmp_size = min(\n            size_n * sorted_token_ids.size(0),\n            sms * 4 * moe_block_size * _MAX_THREAD_N,\n        )\n        if moe_block_size == 8:\n            max_c_tmp_size *= 2\n        c_tmp = torch.empty(max_c_tmp_size, dtype=torch.float32, device=device)\n    else:\n        c_tmp = torch.empty(0, dtype=torch.float32, device=device)\n\n    # Convert Optional tensors to empty tensors\n    g_idx_t = _or_empty(g_idx_or_none, device, torch.int32)\n    perm_t = _or_empty(perm_or_none, device, torch.int32)\n    b_zeros_t = 
_or_empty(b_zeros_or_none, device, a.dtype)\n    b_bias_t = _or_empty(b_bias_or_none, device, a.dtype)\n    global_scale_t = _or_empty(global_scale_or_none, device, a.dtype)\n\n    module = _jit_moe_wna16_marlin_module(a.dtype)\n    module.moe_wna16_marlin_gemm(\n        a,\n        c,\n        b_q_weight,\n        b_bias_t,\n        b_scales,\n        global_scale_t,\n        b_zeros_t,\n        g_idx_t,\n        perm_t,\n        workspace,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_padded,\n        topk_weights,\n        a_tmp,\n        c_tmp,\n        moe_block_size,\n        top_k,\n        mul_topk_weights,\n        is_ep,\n        b_q_type.id,\n        size_m,\n        size_n,\n        size_k,\n        has_act_order,\n        has_bias,\n        is_k_full,\n        has_zp,\n        num_groups,\n        group_size,\n        use_atomic_add,\n        use_fp32_reduce,\n        is_zp_float,\n    )\n\n    return c\n"
  },
  {
    "path": "python/sglang/jit_kernel/ngram_embedding.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit\n\nif TYPE_CHECKING:\n    import torch\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_ngram_embedding_module() -> Module:\n    return load_jit(\n        \"ngram_embedding\",\n        cuda_files=[\"ngram_embedding.cuh\"],\n        cuda_wrappers=[\n            (\"compute_n_gram_ids\", \"&NgramEmbeddingKernel::compute_n_gram_ids\"),\n            (\"update_token_table\", \"&NgramEmbeddingKernel::update_token_table\"),\n        ],\n    )\n\n\ndef compute_n_gram_ids(\n    ne_n: int,\n    ne_k: int,\n    ne_weights: torch.Tensor,\n    ne_mods: torch.Tensor,\n    exclusive_ne_embedder_size_sums: torch.Tensor,\n    tokens: torch.Tensor,\n    exclusive_req_len_sums: torch.Tensor,\n    ne_token_table: torch.Tensor,\n    row_indices: torch.Tensor,\n    column_starts: torch.Tensor,\n    n_gram_ids: torch.Tensor,\n) -> None:\n    \"\"\"\n    Compute n-gram IDs for embedding.\n\n    Args:\n        ne_n: n value for n-gram\n        ne_k: k value for n-gram configurations\n        ne_weights: weights tensor with shape [ne_n-1, ne_k, ne_n]\n        ne_mods: mods tensor with shape [ne_n-1, ne_k]\n        exclusive_ne_embedder_size_sums: exclusive sum of embedder sizes\n        tokens: input token ids\n        exclusive_req_len_sums: exclusive sum of request lengths\n        ne_token_table: token table for all requests\n        row_indices: row indices for each request\n        column_starts: column start positions for each request\n        n_gram_ids: output tensor for n-gram ids\n    \"\"\"\n    module = _jit_ngram_embedding_module()\n    module.compute_n_gram_ids(\n        ne_n,\n        ne_k,\n        ne_weights,\n        ne_mods,\n        exclusive_ne_embedder_size_sums,\n        tokens,\n        exclusive_req_len_sums,\n        ne_token_table,\n        row_indices,\n        column_starts,\n        n_gram_ids,\n    
)\n\n\ndef update_token_table(\n    tokens: torch.Tensor,\n    ne_token_table: torch.Tensor,\n    row_indices: torch.Tensor,\n    column_starts: torch.Tensor,\n    req_lens: torch.Tensor,\n    ignore_tokens: torch.Tensor | None = None,\n) -> None:\n    \"\"\"\n    Update the token table with new tokens.\n\n    Args:\n        tokens: input token ids\n        ne_token_table: token table for all requests\n        row_indices: row indices for each request\n        column_starts: column start positions for each request\n        req_lens: request lengths\n        ignore_tokens: tokens to be ignored (marked as negative in table)\n    \"\"\"\n    module = _jit_ngram_embedding_module()\n    if ignore_tokens is None:\n        # Create an empty tensor for ignore_tokens\n        ignore_tokens = tokens.new_empty(0, dtype=tokens.dtype)\n    module.update_token_table(\n        tokens,\n        ne_token_table,\n        row_indices,\n        column_starts,\n        req_lens,\n        ignore_tokens,\n    )\n"
  },
  {
    "path": "python/sglang/jit_kernel/norm.py",
    "content": "from __future__ import annotations\n\nimport logging\nfrom typing import TYPE_CHECKING, Optional\n\nimport torch\n\nfrom sglang.jit_kernel.utils import (\n    cache_once,\n    is_arch_support_pdl,\n    load_jit,\n    make_cpp_args,\n)\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_qknorm_module(head_dim: int, dtype: torch.dtype) -> Module:\n    args = make_cpp_args(head_dim, is_arch_support_pdl(), dtype)\n    return load_jit(\n        \"qknorm\",\n        *args,\n        cuda_files=[\"elementwise/qknorm.cuh\"],\n        cuda_wrappers=[(\"qknorm\", f\"QKNormKernel<{args}>::run\")],\n    )\n\n\n@cache_once\ndef _jit_rmsnorm_module(hidden_size: int, dtype: torch.dtype) -> Module:\n    args = make_cpp_args(hidden_size, is_arch_support_pdl(), dtype)\n    return load_jit(\n        \"rmsnorm\",\n        *args,\n        cuda_files=[\"elementwise/rmsnorm.cuh\"],\n        cuda_wrappers=[(\"rmsnorm\", f\"RMSNormKernel<{args}>::run\")],\n    )\n\n\n@cache_once\ndef _jit_fused_add_rmsnorm_module(dtype: torch.dtype) -> Module:\n    args = make_cpp_args(dtype)\n    return load_jit(\n        \"fused_add_rmsnorm\",\n        *args,\n        cuda_files=[\"elementwise/fused_add_rmsnorm.cuh\"],\n        cuda_wrappers=[(\"fused_add_rmsnorm\", f\"FusedAddRMSNormKernel<{args}>::run\")],\n    )\n\n\n@cache_once\ndef _jit_qknorm_across_heads_module(dtype: torch.dtype) -> Module:\n    args = make_cpp_args(dtype)\n    return load_jit(\n        \"qknorm_across_heads\",\n        *args,\n        cuda_files=[\"elementwise/qknorm_across_heads.cuh\"],\n        cuda_wrappers=[\n            (\"qknorm_across_heads\", f\"QKNormAcrossHeadsKernel<{args}>::run\")\n        ],\n    )\n\n\n@cache_once\ndef can_use_fused_inplace_qknorm(head_dim: int, dtype: torch.dtype) -> bool:\n    logger = logging.getLogger(__name__)\n    if head_dim not in [64, 128, 256, 512, 1024]:\n        logger.warning(f\"Unsupported head_dim={head_dim} for JIT QK-Norm 
kernel\")\n        return False\n    try:\n        _jit_qknorm_module(head_dim, dtype)\n        return True\n    except Exception as e:\n        logger.warning(f\"Failed to load JIT QK-Norm kernel: {e}\")\n        return False\n\n\ndef fused_inplace_qknorm(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n    eps: float = 1e-6,\n    *,\n    head_dim: int = 0,\n) -> None:\n    head_dim = head_dim or q.size(-1)\n    module = _jit_qknorm_module(head_dim, q.dtype)\n    module.qknorm(q, k, q_weight, k_weight, eps)\n\n\ndef rmsnorm(\n    input: torch.Tensor,\n    weight: torch.Tensor,\n    output: Optional[torch.Tensor] = None,\n    eps: float = 1e-6,\n) -> None:\n    output = output if output is not None else input\n    hidden_size = input.size(-1)\n    module = _jit_rmsnorm_module(hidden_size, input.dtype)\n    module.rmsnorm(input, weight, output, eps)\n\n\ndef fused_add_rmsnorm(\n    input: torch.Tensor,\n    residual: torch.Tensor,\n    weight: torch.Tensor,\n    eps: float = 1e-6,\n) -> None:\n    module = _jit_fused_add_rmsnorm_module(input.dtype)\n    module.fused_add_rmsnorm(input, residual, weight, eps)\n\n\ndef fused_inplace_qknorm_across_heads(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n    eps: float = 1e-6,\n) -> None:\n    \"\"\"\n    Fused inplace QK normalization across all heads.\n\n    Args:\n        q: Query tensor of shape [batch_size, num_heads * head_dim]\n        k: Key tensor of shape [batch_size, num_heads * head_dim]\n        q_weight: Query weight tensor of shape [num_heads * head_dim]\n        k_weight: Key weight tensor of shape [num_heads * head_dim]\n        eps: Epsilon for numerical stability\n    \"\"\"\n    module = _jit_qknorm_across_heads_module(q.dtype)\n    module.qknorm_across_heads(q, k, q_weight, k_weight, eps)\n"
  },
  {
    "path": "python/sglang/jit_kernel/nvfp4.py",
    "content": "from __future__ import annotations\n\nimport importlib.util\nimport os\nimport pathlib\nfrom contextlib import contextmanager\nfrom typing import TYPE_CHECKING, Optional, Tuple\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit\nfrom sglang.srt.utils.custom_op import register_custom_op\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n_FLOAT4_E2M1_MAX = 6.0\n_FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max\n\n\ndef _find_package_root(package: str) -> Optional[pathlib.Path]:\n    spec = importlib.util.find_spec(package)\n    if spec is None or spec.origin is None:\n        return None\n    return pathlib.Path(spec.origin).resolve().parent\n\n\ndef _resolve_cutlass_include_paths() -> list[str]:\n    include_paths: list[str] = []\n\n    flashinfer_root = _find_package_root(\"flashinfer\")\n    if flashinfer_root is not None:\n        candidates = [\n            flashinfer_root / \"data\" / \"cutlass\" / \"include\",\n            flashinfer_root / \"data\" / \"cutlass\" / \"tools\" / \"util\" / \"include\",\n        ]\n        for path in candidates:\n            if path.exists():\n                include_paths.append(str(path))\n\n    deep_gemm_root = _find_package_root(\"deep_gemm\")\n    if deep_gemm_root is not None:\n        candidate = deep_gemm_root / \"include\"\n        if candidate.exists():\n            include_paths.append(str(candidate))\n\n    # De-duplicate while preserving order.\n    unique_paths = []\n    seen = set()\n    for path in include_paths:\n        if path in seen:\n            continue\n        seen.add(path)\n        unique_paths.append(path)\n    return unique_paths\n\n\ndef _nvfp4_cuda_flags() -> list[str]:\n    return [\n        \"-DNDEBUG\",\n        \"-DFLASHINFER_ENABLE_F16\",\n        \"-DCUTE_USE_PACKED_TUPLE=1\",\n        \"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1\",\n        \"-DCUTLASS_VERSIONS_GENERATED\",\n        \"-DCUTLASS_TEST_LEVEL=0\",\n        
\"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1\",\n        \"-DCUTLASS_DEBUG_TRACE_LEVEL=0\",\n        \"--expt-extended-lambda\",\n    ]\n\n\ndef _get_nvfp4_cuda_arch_list() -> str:\n    if not torch.cuda.is_available():\n        raise RuntimeError(\"NVFP4 JIT kernels require CUDA.\")\n    major, minor = torch.cuda.get_device_capability()\n    if major < 10:\n        raise RuntimeError(\n            f\"NVFP4 JIT kernels require compute capability >= 10.0, got {major}.{minor}.\"\n        )\n    # NVFP4 kernels use architecture-family-specific instructions and must be\n    # compiled for `sm_*a` targets (e.g. sm_100a), not plain sm_100.\n    # JIT compilation targets only the current device, unlike AOT fat-binaries;\n    # adding extra architectures here would clash with the single SGL_CUDA_ARCH\n    # value injected by load_jit().\n    return f\"{major}.{minor}a\"\n\n\n@contextmanager\ndef _nvfp4_arch_env():\n    key = \"TVM_FFI_CUDA_ARCH_LIST\"\n    old_val = os.environ.get(key)\n    os.environ[key] = _get_nvfp4_cuda_arch_list()\n    try:\n        yield\n    finally:\n        if old_val is None:\n            os.environ.pop(key, None)\n        else:\n            os.environ[key] = old_val\n\n\n@cache_once\ndef _jit_nvfp4_quant_module() -> Module:\n    extra_include_paths = _resolve_cutlass_include_paths()\n    if not extra_include_paths:\n        raise RuntimeError(\n            \"Cannot find CUTLASS headers required for NVFP4 JIT quantization. 
\"\n            \"Please install flashinfer or deep_gemm with CUTLASS headers.\"\n        )\n\n    with _nvfp4_arch_env():\n        return load_jit(\n            \"nvfp4_quant\",\n            cuda_files=[\n                \"gemm/nvfp4/nvfp4_quant_kernels.cuh\",\n            ],\n            cuda_wrappers=[\n                (\"scaled_fp4_quant\", \"scaled_fp4_quant_sm100a_sm120a\"),\n            ],\n            extra_include_paths=extra_include_paths,\n            extra_cuda_cflags=_nvfp4_cuda_flags(),\n        )\n\n\n@cache_once\ndef _jit_nvfp4_expert_quant_module() -> Module:\n    extra_include_paths = _resolve_cutlass_include_paths()\n    if not extra_include_paths:\n        raise RuntimeError(\n            \"Cannot find CUTLASS headers required for NVFP4 JIT expert quantization. \"\n            \"Please install flashinfer or deep_gemm with CUTLASS headers.\"\n        )\n\n    with _nvfp4_arch_env():\n        return load_jit(\n            \"nvfp4_expert_quant\",\n            cuda_files=[\n                \"gemm/nvfp4/nvfp4_expert_quant.cuh\",\n            ],\n            cuda_wrappers=[\n                (\"scaled_fp4_experts_quant\", \"scaled_fp4_experts_quant_sm100a\"),\n                (\n                    \"silu_and_mul_scaled_fp4_experts_quant\",\n                    \"silu_and_mul_scaled_fp4_experts_quant_sm100a\",\n                ),\n            ],\n            extra_include_paths=extra_include_paths,\n            extra_cuda_cflags=_nvfp4_cuda_flags(),\n        )\n\n\n@cache_once\ndef _jit_nvfp4_scaled_mm_module() -> Module:\n    extra_include_paths = _resolve_cutlass_include_paths()\n    if not extra_include_paths:\n        raise RuntimeError(\n            \"Cannot find CUTLASS headers required for NVFP4 JIT GEMM. 
\"\n            \"Please install flashinfer or deep_gemm with CUTLASS headers.\"\n        )\n\n    with _nvfp4_arch_env():\n        return load_jit(\n            \"nvfp4_scaled_mm\",\n            cuda_files=[\n                \"gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh\",\n                \"gemm/nvfp4/nvfp4_scaled_mm_entry.cuh\",\n            ],\n            cuda_wrappers=[(\"cutlass_scaled_fp4_mm\", \"cutlass_scaled_fp4_mm\")],\n            extra_include_paths=extra_include_paths,\n            extra_cuda_cflags=_nvfp4_cuda_flags(),\n        )\n\n\n@cache_once\ndef _jit_nvfp4_blockwise_moe_module() -> Module:\n    extra_include_paths = _resolve_cutlass_include_paths()\n    if not extra_include_paths:\n        raise RuntimeError(\n            \"Cannot find CUTLASS headers required for NVFP4 JIT MoE grouped GEMM. \"\n            \"Please install flashinfer or deep_gemm with CUTLASS headers.\"\n        )\n\n    with _nvfp4_arch_env():\n        return load_jit(\n            \"nvfp4_blockwise_moe\",\n            cuda_files=[\n                \"moe/nvfp4_blockwise_moe.cuh\",\n            ],\n            cuda_wrappers=[\n                (\"cutlass_fp4_group_mm\", \"cutlass_fp4_group_mm_sm100a_sm120a\")\n            ],\n            extra_include_paths=extra_include_paths,\n            extra_cuda_cflags=_nvfp4_cuda_flags(),\n        )\n\n\ndef cutlass_scaled_fp4_mm(\n    a: torch.Tensor,\n    b: torch.Tensor,\n    block_scale_a: torch.Tensor,\n    block_scale_b: torch.Tensor,\n    alpha: torch.Tensor,\n    out_dtype: torch.dtype,\n) -> torch.Tensor:\n    assert a.ndim == 2 and b.ndim == 2\n    m, n = a.shape[0], b.shape[0]\n    out = torch.empty((m, n), dtype=out_dtype, device=a.device)\n    module = _jit_nvfp4_scaled_mm_module()\n    module.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b, alpha)\n    return out\n\n\ndef cutlass_fp4_group_mm(\n    a_fp4: torch.Tensor,\n    b_fp4: torch.Tensor,\n    a_blockscale: torch.Tensor,\n    b_blockscale: torch.Tensor,\n    
alphas: torch.Tensor,\n    out_dtype: torch.dtype,\n    params: dict[str, torch.Tensor],\n) -> torch.Tensor:\n    m_topk = a_fp4.shape[0]\n    n = b_fp4.shape[1]\n    output = torch.empty((m_topk, n), device=a_fp4.device, dtype=out_dtype)\n    num_experts = int(params[\"expert_offsets\"].numel())\n    device = a_fp4.device\n\n    # Backward compatibility: older callers may not pass scratch tensors.\n    a_ptrs = params.get(\n        \"a_ptrs\", torch.empty((num_experts,), dtype=torch.int64, device=device)\n    )\n    b_ptrs = params.get(\n        \"b_ptrs\", torch.empty((num_experts,), dtype=torch.int64, device=device)\n    )\n    out_ptrs = params.get(\n        \"out_ptrs\", torch.empty((num_experts,), dtype=torch.int64, device=device)\n    )\n    a_scales_ptrs = params.get(\n        \"a_scales_ptrs\", torch.empty((num_experts,), dtype=torch.int64, device=device)\n    )\n    b_scales_ptrs = params.get(\n        \"b_scales_ptrs\", torch.empty((num_experts,), dtype=torch.int64, device=device)\n    )\n    alpha_ptrs = params.get(\n        \"alpha_ptrs\", torch.empty((num_experts,), dtype=torch.int64, device=device)\n    )\n    layout_sfa = params.get(\n        \"layout_sfa\", torch.empty((num_experts, 5), dtype=torch.int64, device=device)\n    )\n    layout_sfb = params.get(\n        \"layout_sfb\", torch.empty((num_experts, 5), dtype=torch.int64, device=device)\n    )\n\n    _cutlass_fp4_group_mm_custom_op(\n        output,\n        a_fp4,\n        b_fp4,\n        a_blockscale,\n        b_blockscale,\n        alphas,\n        params[\"ab_strides\"],\n        params[\"c_strides\"],\n        params[\"problem_sizes\"],\n        params[\"expert_offsets\"],\n        params[\"blockscale_offsets\"],\n        a_ptrs,\n        b_ptrs,\n        out_ptrs,\n        a_scales_ptrs,\n        b_scales_ptrs,\n        alpha_ptrs,\n        layout_sfa,\n        layout_sfb,\n    )\n    return output\n\n\n@register_custom_op(\n    op_name=\"scaled_fp4_quant\",\n    
mutates_args=[\"output\", \"output_scale\"],\n)\ndef _scaled_fp4_quant_custom_op(\n    input: torch.Tensor,\n    output: torch.Tensor,\n    output_scale: torch.Tensor,\n    input_global_scale: torch.Tensor,\n) -> None:\n    module = _jit_nvfp4_quant_module()\n    module.scaled_fp4_quant(output, input, output_scale, input_global_scale)\n\n\ndef scaled_fp4_quant(\n    input: torch.Tensor, input_global_scale: torch.Tensor\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Quantize input tensor to FP4 and return packed FP4 tensor + swizzled scales.\"\"\"\n    assert input.ndim >= 1, f\"input.ndim needs to be >= 1, but got {input.ndim}.\"\n    other_dims = 1 if input.ndim == 1 else -1\n    input = input.reshape(other_dims, input.shape[-1])\n    m, n = input.shape\n    block_size = 16\n    device = input.device\n\n    assert n % block_size == 0, f\"last dim has to be multiple of 16, but got {n}.\"\n    assert input.dtype in (\n        torch.float16,\n        torch.bfloat16,\n    ), f\"input.dtype needs to be fp16 or bf16 but got {input.dtype}.\"\n\n    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)\n\n    rounded_m = ((m + 128 - 1) // 128) * 128\n    scale_n = n // block_size\n    rounded_n = ((scale_n + 4 - 1) // 4) * 4\n    if rounded_n > scale_n:\n        output_scale = torch.zeros(\n            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32\n        )\n    else:\n        output_scale = torch.empty(\n            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32\n        )\n\n    _scaled_fp4_quant_custom_op(input, output, output_scale, input_global_scale)\n    output_scale = output_scale.view(torch.float8_e4m3fn)\n    return output, output_scale\n\n\ndef _shuffle_rows_torch(\n    input_tensor: torch.Tensor,\n    dst2src_map: torch.Tensor,\n    output_tensor_shape: tuple[int, int],\n) -> torch.Tensor:\n    # Keep compatibility when sgl-kernel is slimmed and shuffle_rows may not be present.\n    output = 
input_tensor.index_select(0, dst2src_map.to(dtype=torch.int64))\n    return output.view(output_tensor_shape)\n\n\n@register_custom_op(\n    op_name=\"scaled_fp4_experts_quant\",\n    mutates_args=[\"output\", \"output_scales\"],\n)\ndef _scaled_fp4_experts_quant_custom_op(\n    output: torch.Tensor,\n    output_scales: torch.Tensor,\n    input_tensor: torch.Tensor,\n    input_global_scale: torch.Tensor,\n    expert_offsets: torch.Tensor,\n    blockscale_offsets: torch.Tensor,\n) -> None:\n    module = _jit_nvfp4_expert_quant_module()\n    module.scaled_fp4_experts_quant(\n        output,\n        output_scales,\n        input_tensor,\n        input_global_scale,\n        expert_offsets,\n        blockscale_offsets,\n    )\n\n\ndef scaled_fp4_experts_quant(\n    input_tensor: torch.Tensor,\n    input_global_scale: torch.Tensor,\n    expert_offsets: torch.Tensor,\n    blockscale_offsets: torch.Tensor,\n    topk: int,\n    expert_map: Optional[torch.Tensor] = None,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Quantize packed MoE activations to NVFP4.\"\"\"\n    assert (\n        input_tensor.ndim == 2\n    ), f\"input.ndim needs to be == 2, but got {input_tensor.ndim}.\"\n    if expert_map is not None:\n        m, k = input_tensor.shape\n        output_tensor_shape = (m * topk, k)\n        input_tensor = _shuffle_rows_torch(\n            input_tensor, expert_map, output_tensor_shape\n        )\n\n    m_numtopk, k = input_tensor.shape\n    max_tokens_per_expert = int(os.environ.get(\"MODELOPT_MAX_TOKENS_PER_EXPERT\", 65536))\n    assert m_numtopk <= max_tokens_per_expert * topk, (\n        f\"m_numtopk must be <= MAX_TOKENS_PER_EXPERT({max_tokens_per_expert}) * topk\"\n        f\" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. 
Use\"\n        \" MODELOPT_MAX_TOKENS_PER_EXPERT to set this value.\"\n    )\n    scales_k = k // 16\n    # output_scales is int32-packed FP8 scales, so second dim is in int32 units.\n    padded_k_in_int32 = (scales_k + 3) // 4\n\n    output = torch.empty(\n        m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8\n    )\n    if padded_k_in_int32 * 4 > scales_k:\n        output_scales = torch.zeros(\n            max_tokens_per_expert * topk,\n            padded_k_in_int32,\n            dtype=torch.int32,\n            device=input_tensor.device,\n        )\n    else:\n        output_scales = torch.empty(\n            max_tokens_per_expert * topk,\n            padded_k_in_int32,\n            dtype=torch.int32,\n            device=input_tensor.device,\n        )\n\n    _scaled_fp4_experts_quant_custom_op(\n        output,\n        output_scales,\n        input_tensor,\n        input_global_scale,\n        expert_offsets,\n        blockscale_offsets,\n    )\n    output_scales = output_scales.view(torch.float8_e4m3fn)\n    return output, output_scales\n\n\n@register_custom_op(\n    op_name=\"scaled_fp4_grouped_quant\",\n    mutates_args=[\"output\", \"output_scales\"],\n)\ndef _scaled_fp4_grouped_quant_custom_op(\n    input_tensor: torch.Tensor,\n    output: torch.Tensor,\n    output_scales: torch.Tensor,\n    input_global_scale: torch.Tensor,\n    mask: torch.Tensor,\n) -> None:\n    l, m, k = input_tensor.shape\n    del l, m\n    module = _jit_nvfp4_expert_quant_module()\n    module.silu_and_mul_scaled_fp4_experts_quant(\n        output.view(-1, k // 2),\n        output_scales.view(-1, output_scales.shape[-1]),\n        input_tensor.view(-1, k),\n        input_global_scale,\n        mask,\n        False,\n    )\n\n\ndef scaled_fp4_grouped_quant(\n    input_tensor: torch.Tensor,\n    input_global_scale: torch.Tensor,\n    mask: torch.Tensor,\n):\n    \"\"\"Quantize grouped GEMM inputs to FP4 and return logical (m, k//2, l).\"\"\"\n    device = 
input_tensor.device\n    l, m, k = input_tensor.shape\n    sf_vec_size = 16\n    assert k % sf_vec_size == 0, f\"k must be multiple of 16, but got {k}.\"\n\n    scale_k = k // sf_vec_size\n    padded_k = (scale_k + (4 - 1)) // 4 * 4\n    padded_k_int32 = padded_k // 4\n    padded_m = (m + (128 - 1)) // 128 * 128\n    output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)\n    output_scales = torch.empty(\n        l, padded_m, padded_k_int32, device=device, dtype=torch.int32\n    )\n\n    _scaled_fp4_grouped_quant_custom_op(\n        input_tensor,\n        output,\n        output_scales,\n        input_global_scale,\n        mask,\n    )\n\n    output = output.permute(1, 2, 0)\n    output_scales = output_scales.view(torch.float8_e4m3fn).view(\n        l, padded_m // 128, padded_k // 4, 32, 4, 4\n    )\n    output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)\n    return output, output_scales\n\n\n@register_custom_op(\n    op_name=\"silu_and_mul_scaled_fp4_grouped_quant\",\n    mutates_args=[\"output\", \"output_scales\"],\n)\ndef _silu_and_mul_scaled_fp4_grouped_quant_custom_op(\n    input_tensor: torch.Tensor,\n    output: torch.Tensor,\n    output_scales: torch.Tensor,\n    input_global_scale: torch.Tensor,\n    mask: torch.Tensor,\n) -> None:\n    l, m, k_by_2 = input_tensor.shape\n    del l, m\n    module = _jit_nvfp4_expert_quant_module()\n    module.silu_and_mul_scaled_fp4_experts_quant(\n        output.view(-1, output.shape[-1]),\n        output_scales.view(-1, output_scales.shape[-1]),\n        input_tensor.view(-1, k_by_2),\n        input_global_scale,\n        mask,\n        True,\n    )\n\n\ndef silu_and_mul_scaled_fp4_grouped_quant(\n    input_tensor: torch.Tensor,\n    input_global_scale: torch.Tensor,\n    mask: torch.Tensor,\n):\n    \"\"\"Apply SiLU-and-mul then quantize grouped GEMM inputs to FP4.\"\"\"\n    device = input_tensor.device\n    l, m, k_by_2 = input_tensor.shape\n    k = k_by_2 // 2\n    sf_vec_size = 16\n    assert k % 
sf_vec_size == 0, f\"k must be multiple of 16, but got {k}.\"\n\n    scale_k = k // sf_vec_size\n    padded_k = (scale_k + (4 - 1)) // 4 * 4\n    padded_k_int32 = padded_k // 4\n    padded_m = (m + (128 - 1)) // 128 * 128\n    output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)\n    output_scales = torch.empty(\n        l, padded_m, padded_k_int32, device=device, dtype=torch.int32\n    )\n\n    _silu_and_mul_scaled_fp4_grouped_quant_custom_op(\n        input_tensor,\n        output,\n        output_scales,\n        input_global_scale,\n        mask,\n    )\n\n    output = output.permute(1, 2, 0)\n    output_scales = output_scales.view(torch.float8_e4m3fn).view(\n        l, padded_m // 128, padded_k // 4, 32, 4, 4\n    )\n    output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)\n    return output, output_scales\n\n\n@register_custom_op(\n    op_name=\"cutlass_fp4_group_mm\",\n    mutates_args=[\n        \"output\",\n        \"a_ptrs\",\n        \"b_ptrs\",\n        \"out_ptrs\",\n        \"a_scales_ptrs\",\n        \"b_scales_ptrs\",\n        \"alpha_ptrs\",\n        \"layout_sfa\",\n        \"layout_sfb\",\n    ],\n)\ndef _cutlass_fp4_group_mm_custom_op(\n    output: torch.Tensor,\n    a_fp4: torch.Tensor,\n    b_fp4: torch.Tensor,\n    a_blockscale: torch.Tensor,\n    b_blockscale: torch.Tensor,\n    alphas: torch.Tensor,\n    ab_strides: torch.Tensor,\n    c_strides: torch.Tensor,\n    problem_sizes: torch.Tensor,\n    expert_offsets: torch.Tensor,\n    blockscale_offsets: torch.Tensor,\n    a_ptrs: torch.Tensor,\n    b_ptrs: torch.Tensor,\n    out_ptrs: torch.Tensor,\n    a_scales_ptrs: torch.Tensor,\n    b_scales_ptrs: torch.Tensor,\n    alpha_ptrs: torch.Tensor,\n    layout_sfa: torch.Tensor,\n    layout_sfb: torch.Tensor,\n) -> None:\n    module = _jit_nvfp4_blockwise_moe_module()\n    module.cutlass_fp4_group_mm(\n        output,\n        a_fp4,\n        b_fp4,\n        a_blockscale,\n        b_blockscale,\n        alphas,\n        
ab_strides,\n        c_strides,\n        problem_sizes,\n        expert_offsets,\n        blockscale_offsets,\n        a_ptrs,\n        b_ptrs,\n        out_ptrs,\n        a_scales_ptrs,\n        b_scales_ptrs,\n        alpha_ptrs,\n        layout_sfa,\n        layout_sfb,\n    )\n\n\ndef suggest_nvfp4_global_scale(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"Utility for tests/benchmarks: return global scale used by NVFP4 quantization.\"\"\"\n    tensor_amax = torch.abs(x).max().to(torch.float32)\n    return _FLOAT8_E4M3_MAX * _FLOAT4_E2M1_MAX / tensor_amax\n"
  },
  {
    "path": "python/sglang/jit_kernel/per_tensor_quant_fp8.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\nfrom sglang.srt.utils.custom_op import register_custom_op\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_per_tensor_quant_fp8_module(is_static: bool, dtype: torch.dtype) -> Module:\n    args = make_cpp_args(is_static, dtype)\n    return load_jit(\n        \"per_tensor_quant_fp8\",\n        *args,\n        cuda_files=[\"gemm/per_tensor_quant_fp8.cuh\"],\n        cuda_wrappers=[(\"per_tensor_quant_fp8\", f\"per_tensor_quant_fp8<{args}>\")],\n    )\n\n\n@register_custom_op(\n    op_name=\"per_tensor_quant_fp8\",\n    mutates_args=[\"output_q\", \"output_s\"],\n)\ndef per_tensor_quant_fp8(\n    input: torch.Tensor,\n    output_q: torch.Tensor,\n    output_s: torch.Tensor,\n    is_static: bool = False,\n) -> None:\n    \"\"\"\n    Per-tensor quantization to FP8 format.\n\n    Args:\n        input: Input tensor to quantize (float, half, or bfloat16)\n        output_q: Output quantized tensor (fp8_e4m3)\n        output_s: Output scale tensor (float scalar or 1D tensor with 1 element)\n        is_static: If True, assumes scale is pre-computed and skips absmax computation\n    \"\"\"\n    module = _jit_per_tensor_quant_fp8_module(is_static, input.dtype)\n    module.per_tensor_quant_fp8(input.view(-1), output_q.view(-1), output_s.view(-1))\n"
  },
  {
    "path": "python/sglang/jit_kernel/per_token_group_quant_8bit.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\nfrom sglang.srt.utils.custom_op import register_custom_op\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\nfrom sglang.jit_kernel.utils import CPP_DTYPE_MAP as OUTPUT_DTYPE_MAP\n\n\n@cache_once\ndef _jit_per_token_group_quant_8bit_module(\n    dtype: torch.dtype, output_type: torch.dtype\n) -> Module:\n    input_args = make_cpp_args(dtype)\n    out_cpp = OUTPUT_DTYPE_MAP[output_type]\n    return load_jit(\n        \"per_token_group_quant_8bit\",\n        cuda_files=[\"gemm/per_token_group_quant_8bit.cuh\"],\n        cuda_wrappers=[\n            (\n                \"per_token_group_quant_8bit\",\n                f\"per_token_group_quant_8bit<{input_args}, {out_cpp}>\",\n            )\n        ],\n    )\n\n\n@register_custom_op(\n    op_name=\"per_token_group_quant_8bit\",\n    mutates_args=[\"output_q\", \"output_s\"],\n)\ndef _per_token_group_quant_8bit_custom_op(\n    input: torch.Tensor,\n    output_q: torch.Tensor,\n    output_s: torch.Tensor,\n    group_size: int,\n    eps: float,\n    fp8_min: float,\n    fp8_max: float,\n    scale_ue8m0: bool = False,\n) -> None:\n    \"\"\"\n    Per-token-group quantization to 8-bit format.\n\n    Args:\n        input: Input tensor to quantize (float, half, or bfloat16).\n        output_q: Output quantized tensor (e.g., fp8_e4m3 or int8).\n        output_s: Output scale tensor.\n        group_size: The size of the group for quantization.\n        eps: A small value to avoid division by zero.\n        fp8_min: The minimum value of the 8-bit data type.\n        fp8_max: The maximum value of the 8-bit data type.\n        scale_ue8m0: Whether to use UE8M0 format for scales.\n    \"\"\"\n    module = _jit_per_token_group_quant_8bit_module(input.dtype, output_q.dtype)\n    module.per_token_group_quant_8bit(\n        input,\n        
output_q,\n        output_s,\n        group_size,\n        eps,\n        fp8_min,\n        fp8_max,\n        scale_ue8m0,\n    )\n    return None\n\n\ndef per_token_group_quant_8bit(\n    input: torch.Tensor,\n    output_q: torch.Tensor,\n    output_s: torch.Tensor,\n    group_size: int,\n    eps: float,\n    fp8_min: float,\n    fp8_max: float,\n    scale_ue8m0: bool = False,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    _per_token_group_quant_8bit_custom_op(\n        input=input,\n        output_q=output_q,\n        output_s=output_s,\n        group_size=group_size,\n        eps=eps,\n        fp8_min=fp8_min,\n        fp8_max=fp8_max,\n        scale_ue8m0=scale_ue8m0,\n    )\n    return output_q, output_s\n"
  },
  {
    "path": "python/sglang/jit_kernel/rope.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Optional\n\nimport torch\n\nfrom sglang.jit_kernel.utils import (\n    cache_once,\n    is_arch_support_pdl,\n    load_jit,\n    make_cpp_args,\n)\nfrom sglang.srt.utils.custom_op import register_custom_op\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_rotary_embedding_module() -> Module:\n    return load_jit(\n        \"rotary_embedding\",\n        cuda_files=[\"elementwise/pos_enc.cuh\"],\n        cuda_wrappers=[(\"rotary_embedding\", \"RotaryEmbeddingKernel::run\")],\n    )\n\n\n@cache_once\ndef _jit_fused_rope_module(is_neox: bool, rope_dim: int, dtype: torch.dtype) -> Module:\n    args = make_cpp_args(is_neox, rope_dim, is_arch_support_pdl(), dtype)\n    return load_jit(\n        \"fused_rope\",\n        *args,\n        cuda_files=[\"elementwise/rope.cuh\"],\n        cuda_wrappers=[\n            (\"run_rope\", f\"FusedRopeKernel<{args}>::run\"),\n            (\"run_rope_store\", f\"FusedRopeKernel<{args}>::run_fused\"),\n        ],\n    )\n\n\n@register_custom_op(\n    op_name=\"rotary_embedding_with_key\",\n    mutates_args=[\"query\", \"key\"],\n)\ndef rotary_embedding_with_key(\n    positions: torch.Tensor,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    head_size: int,\n    cos_sin_cache: torch.Tensor,\n    is_neox: bool = True,\n) -> None:\n    module = _jit_rotary_embedding_module()\n    module.rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)\n\n\n@register_custom_op(\n    op_name=\"rotary_embedding_without_key\",\n    mutates_args=[\"query\"],\n)\ndef rotary_embedding_without_key(\n    positions: torch.Tensor,\n    query: torch.Tensor,\n    head_size: int,\n    cos_sin_cache: torch.Tensor,\n    is_neox: bool = True,\n) -> None:\n    module = _jit_rotary_embedding_module()\n    module.rotary_embedding(positions, query, None, head_size, cos_sin_cache, 
is_neox)\n\n\ndef rotary_embedding(\n    positions: torch.Tensor,\n    query: torch.Tensor,\n    key: Optional[torch.Tensor],\n    head_size: int,\n    cos_sin_cache: torch.Tensor,\n    is_neox: bool = True,\n):\n    if key is None:\n        rotary_embedding_without_key(\n            positions, query, head_size, cos_sin_cache, is_neox\n        )\n    else:\n        rotary_embedding_with_key(\n            positions, query, key, head_size, cos_sin_cache, is_neox\n        )\n    return query, key\n\n\n@dataclass\nclass FusedSetKVBufferArg:\n    \"\"\"\n    value : Optional[torch.Tensor]\n        Value tensor, shape: ``(nnz, num_v_heads * head_size)``.\n    k_buffer : Optional[torch.Tensor]\n        Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``.\n    v_buffer : Optional[torch.Tensor]\n        Buffer for values, shape: ``(nnz, num_v_heads * head_size)``.\n    cache_loc : Optional[torch.Tensor]\n        Cache location tensor, used for indexing kv cache.\n    \"\"\"\n\n    value: torch.Tensor\n    k_buffer: torch.Tensor\n    v_buffer: torch.Tensor\n    cache_loc: torch.Tensor\n\n\n@register_custom_op(mutates_args=[\"q\", \"k\"])\ndef apply_rope_inplace(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    cos_sin_cache: torch.Tensor,\n    positions: torch.Tensor,\n    *,\n    is_neox: bool,\n    rope_dim: int = 0,\n) -> None:\n    \"\"\"\n    Fused inplace rotary position embedding for query and key tensors.\n\n    Args:\n        q: Query tensor of shape [num_tokens, num_qo_heads, rope_dim].\n        k: Key tensor of shape [num_tokens, num_kv_heads, rope_dim].\n        cos_sin_cache: Cosine/sine cache of shape [max_position, rope_dim],\n            where the first half along dim=-1 is cos and the second half is sin.\n            Must be float32.\n        positions: Position indices of shape [num_tokens], int32 or int64.\n        is_neox: Whether to use GPT-NeoX style (True) or GPT-J interleaved style (False).\n        rope_dim: Rotary embedding dimension. 
Defaults to cos_sin_cache.size(-1).\n    \"\"\"\n    rope_dim = rope_dim or cos_sin_cache.size(-1)\n    module = _jit_fused_rope_module(is_neox, rope_dim, q.dtype)\n    module.run_rope(q, k, cos_sin_cache, positions)\n\n\n@register_custom_op(mutates_args=[\"q\", \"k\", \"k_cache\", \"v_cache\"])\ndef apply_rope_inplace_with_kvcache(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    k_cache: torch.Tensor,\n    v_cache: torch.Tensor,\n    cos_sin_cache: torch.Tensor,\n    positions: torch.Tensor,\n    out_loc: torch.Tensor,\n    *,\n    is_neox: bool,\n    rope_dim: int = 0,\n) -> None:\n    \"\"\"\n    Fused inplace RoPE + KV cache store.\n\n    Applies rotary position embedding to q and k inplace. The rotated k is also\n    stored in k_cache. The original v is also stored in v_cache.\n\n    Args:\n        q: Query tensor of shape [num_tokens, num_qo_heads, head_dim].\n        k: Key tensor of shape [num_tokens, num_kv_heads, head_dim].\n        v: Value tensor of shape [num_tokens, num_kv_heads, head_dim].\n        k_cache: Key cache of shape [cache_size, num_kv_heads * head_dim].\n        v_cache: Value cache of shape [cache_size, num_kv_heads * head_dim].\n        cos_sin_cache: Cosine/sine cache of shape [max_position, rope_dim], float32.\n        positions: Position indices of shape [num_tokens], int32 or int64.\n        out_loc: Cache write locations of shape [num_tokens], same dtype as positions.\n        is_neox: Whether to use GPT-NeoX style (True) or GPT-J interleaved (False).\n        rope_dim: Rotary embedding dimension. 
Defaults to cos_sin_cache.size(-1).\n    \"\"\"\n    rope_dim = rope_dim or cos_sin_cache.size(-1)\n    v = v.view_as(k)\n    module = _jit_fused_rope_module(is_neox, rope_dim, q.dtype)\n    module.run_rope_store(q, k, v, k_cache, v_cache, cos_sin_cache, positions, out_loc)\n\n\n# NOTE: this name is intentionally set as the old kernel in `sgl_kernel`\ndef apply_rope_with_cos_sin_cache_inplace(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    cos_sin_cache: torch.Tensor,\n    positions: torch.Tensor,\n    *,\n    is_neox: bool,\n    rope_dim: int = 0,\n    fused_args: Optional[FusedSetKVBufferArg] = None,\n) -> None:\n    \"\"\"\n    Apply RoPE to q and k inplace, with optional fused kv cache store.\n\n    If `fused_args` is provided, it will perform fused RoPE and KV cache store.\n    Otherwise, it will only apply RoPE inplace.\n\n    Args:\n        q: Query tensor of shape [num_tokens, num_qo_heads, head_dim].\n        k: Key tensor of shape [num_tokens, num_kv_heads, head_dim].\n        cos_sin_cache: Cosine/sine cache of shape [max_position, rope_dim], float32.\n        positions: Position indices of shape [num_tokens], int32 or int64.\n        is_neox: Whether to use GPT-NeoX style (True) or GPT-J interleaved (False).\n        rope_dim: Rotary embedding dimension. Defaults to cos_sin_cache.size(-1).\n        fused_args: Optional arguments for fused RoPE + KV cache store. If None,\n            only RoPE will be applied inplace without touching kv cache.\n    \"\"\"\n    if fused_args is not None:\n        apply_rope_inplace_with_kvcache(\n            q,\n            k,\n            fused_args.value,\n            fused_args.k_buffer,\n            fused_args.v_buffer,\n            cos_sin_cache,\n            positions,\n            fused_args.cache_loc,\n            is_neox=is_neox,\n            rope_dim=rope_dim,\n        )\n    else:\n        apply_rope_inplace(\n            q, k, cos_sin_cache, positions, is_neox=is_neox, rope_dim=rope_dim\n        )\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_add_constant.py",
    "content": "import pytest\nimport torch\n\nfrom sglang.jit_kernel.add_constant import add_constant\n\n\n@pytest.mark.parametrize(\"size\", [1, 2, 127, 128, 1024, 1025])\n@pytest.mark.parametrize(\"constant\", [0, 1, 7, 1024, -3])\ndef test_add_constant(size: int, constant: int) -> None:\n    src = torch.arange(0, size, dtype=torch.int32, device=\"cuda\")\n    dst = add_constant(src, constant)\n    assert torch.all(dst == src + constant)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_awq_dequantize.py",
    "content": "import itertools\n\nimport pytest\nimport torch\n\nfrom sglang.jit_kernel.awq_dequantize import awq_dequantize as jit_awq_dequantize\n\ntry:\n    from sgl_kernel import awq_dequantize as aot_awq_dequantize\n\n    AOT_AVAILABLE = True\nexcept ImportError:\n    AOT_AVAILABLE = False\n\n\ndef reverse_awq_order(t: torch.Tensor):\n    bits = 4\n    AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]\n    reverse_order_tensor = torch.arange(\n        t.shape[-1],\n        dtype=torch.int32,\n        device=t.device,\n    )\n    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)\n    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]\n    reverse_order_tensor = reverse_order_tensor.view(-1)\n\n    t = t[:, reverse_order_tensor] & 0xF\n    return t\n\n\n# qweights - [R     , C // 8], int32\n# scales   - [R // G, C     ], float16\n# zeros    - [R // G, C // 8], int32\ndef awq_dequantize_torch(\n    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int\n) -> torch.Tensor:\n    if group_size == -1:\n        group_size = qweight.shape[0]\n\n    bits = 4\n    shifts = torch.arange(0, 32, bits, device=qzeros.device)\n\n    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(\n        torch.int8\n    )\n\n    iweights = iweights.view(iweights.shape[0], -1)\n\n    zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(\n        torch.int8\n    )\n    zeros = zeros.view(qzeros.shape[0], -1)\n    zeros = reverse_awq_order(zeros)\n\n    iweights = reverse_awq_order(iweights)\n\n    iweights = torch.bitwise_and(iweights, (2**bits) - 1)\n    zeros = torch.bitwise_and(zeros, (2**bits) - 1)\n\n    scales = scales.repeat_interleave(group_size, dim=0)\n    zeros = zeros.repeat_interleave(group_size, dim=0)\n    return (iweights - zeros) * scales\n\n\n@pytest.mark.parametrize(\n    \"qweight_row,qweight_col,is_bf16_act\",\n    list(\n        itertools.product(\n        
    [128, 256, 512, 1024, 3584],\n            [16, 32, 64, 128, 448],\n            [True, False],\n        )\n    ),\n)\ndef test_awq_dequantize_jit_vs_torch(\n    qweight_row: int, qweight_col: int, is_bf16_act: bool\n):\n    device = torch.device(\"cuda\")\n    qweight = torch.randint(\n        0,\n        torch.iinfo(torch.int32).max,\n        (qweight_row, qweight_col),\n        dtype=torch.int32,\n        device=device,\n    )\n    group_size = qweight_row\n    scales_row = qweight_row // group_size\n    scales_col = qweight_col * 8\n\n    if is_bf16_act:\n        scales = torch.rand(scales_row, scales_col, dtype=torch.bfloat16, device=device)\n    else:\n        scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)\n\n    qzeros = torch.randint(\n        0,\n        torch.iinfo(torch.int32).max,\n        (scales_row, qweight_col),\n        dtype=torch.int32,\n        device=device,\n    )\n\n    # Run both implementations\n    torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)\n    jit_out = jit_awq_dequantize(qweight, scales, qzeros)\n\n    # Compare results (approximate due to different computation paths)\n    torch.testing.assert_close(\n        torch_out.to(torch.float32), jit_out.to(torch.float32), rtol=1e-3, atol=1e-5\n    )\n\n\n@pytest.mark.parametrize(\n    \"qweight_row,qweight_col,is_bf16_act\",\n    list(\n        itertools.product(\n            [128, 256, 512, 1024, 3584],\n            [16, 32, 64, 128, 448],\n            [True, False],\n        )\n    ),\n)\ndef test_awq_dequantize_jit_vs_aot(\n    qweight_row: int, qweight_col: int, is_bf16_act: bool\n):\n    if not AOT_AVAILABLE:\n        pytest.skip(\"sgl_kernel AOT not available\")\n\n    device = torch.device(\"cuda\")\n    qweight = torch.randint(\n        0,\n        torch.iinfo(torch.int32).max,\n        (qweight_row, qweight_col),\n        dtype=torch.int32,\n        device=device,\n    )\n    group_size = qweight_row\n    scales_row = 
qweight_row // group_size\n    scales_col = qweight_col * 8\n\n    if is_bf16_act:\n        scales = torch.rand(scales_row, scales_col, dtype=torch.bfloat16, device=device)\n    else:\n        scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)\n\n    qzeros = torch.randint(\n        0,\n        torch.iinfo(torch.int32).max,\n        (scales_row, qweight_col),\n        dtype=torch.int32,\n        device=device,\n    )\n\n    # Run both implementations\n    aot_out = aot_awq_dequantize(qweight, scales, qzeros)\n    jit_out = jit_awq_dequantize(qweight, scales, qzeros)\n\n    # Bitwise equality\n    torch.testing.assert_close(jit_out, aot_out, rtol=0, atol=0)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_awq_marlin_moe_repack.py",
    "content": "import numpy as np\nimport pytest\nimport torch\nfrom sgl_kernel.scalar_type import scalar_types\n\nfrom sglang.jit_kernel.awq_marlin_repack import (\n    awq_marlin_moe_repack as jit_awq_marlin_moe_repack,\n)\nfrom sglang.srt.layers.quantization.utils import pack_cols, quantize_weights\n\n\ndef _has_aot_awq_marlin_moe_repack() -> bool:\n    return hasattr(torch.ops.sgl_kernel, \"awq_marlin_moe_repack\") and hasattr(\n        torch.ops.sgl_kernel.awq_marlin_moe_repack, \"default\"\n    )\n\n\nAOT_AVAILABLE = _has_aot_awq_marlin_moe_repack()\n\n\ndef awq_pack(\n    q_w: torch.Tensor,\n    num_bits: int,\n    size_k: int,\n    size_n: int,\n):\n    assert q_w.shape == (size_k, size_n)\n\n    if num_bits == 4:\n        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])\n    elif num_bits == 8:\n        interleave = np.array([0, 2, 1, 3])\n    else:\n        raise Exception(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()\n    q_w = q_w.reshape((-1, size_n)).contiguous()\n\n    return pack_cols(q_w, num_bits, size_k, size_n)\n\n\n@pytest.mark.parametrize(\"num_bits\", [4])\n@pytest.mark.parametrize(\"num_experts\", [2, 4, 8])\n@pytest.mark.parametrize(\"k_tiles,n_tiles\", [(1, 1), (2, 2), (4, 4)])\n@pytest.mark.parametrize(\"group_size\", [16, 32])\ndef test_awq_marlin_moe_repack_jit_vs_aot(\n    num_bits, num_experts, k_tiles, n_tiles, group_size\n):\n    if not AOT_AVAILABLE:\n        pytest.skip(\"sgl_kernel AOT not available\")\n\n    tile_k, tile_n = 16, 64\n    size_k = k_tiles * tile_k\n    size_n = n_tiles * tile_n\n    pack_factor = 32 // num_bits\n\n    # Create per-expert AWQ-packed weights\n    b_q_weight = torch.empty(\n        (num_experts, size_k, size_n // pack_factor),\n        dtype=torch.int32,\n        device=\"cuda\",\n    )\n    for e in range(num_experts):\n        b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device=\"cuda\")\n        w_ref, 
q_w, s, zp = quantize_weights(\n            b_weight, scalar_types.uint4, group_size, zero_points=True\n        )\n        b_q_weight[e] = awq_pack(q_w, num_bits, size_k, size_n)\n\n    perm = torch.empty((num_experts, 0), dtype=torch.int32, device=\"cuda\")\n\n    out_jit = jit_awq_marlin_moe_repack(b_q_weight, perm, size_k, size_n, num_bits)\n    out_aot = torch.ops.sgl_kernel.awq_marlin_moe_repack.default(\n        b_q_weight, perm, size_k, size_n, num_bits\n    )\n\n    torch.cuda.synchronize()\n\n    # Bitwise equality\n    torch.testing.assert_close(out_jit, out_aot, rtol=0, atol=0)\n\n\n@pytest.mark.parametrize(\"num_bits\", [4])\n@pytest.mark.parametrize(\"num_experts\", [2, 4])\n@pytest.mark.parametrize(\"k_tiles,n_tiles\", [(1, 1), (2, 2)])\n@pytest.mark.parametrize(\"group_size\", [16, 32])\ndef test_awq_marlin_moe_repack_shape(\n    num_bits, num_experts, k_tiles, n_tiles, group_size\n):\n    tile_k, tile_n = 16, 64\n    size_k = k_tiles * tile_k\n    size_n = n_tiles * tile_n\n    pack_factor = 32 // num_bits\n\n    # Create per-expert AWQ-packed weights\n    b_q_weight = torch.empty(\n        (num_experts, size_k, size_n // pack_factor),\n        dtype=torch.int32,\n        device=\"cuda\",\n    )\n    for e in range(num_experts):\n        b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device=\"cuda\")\n        w_ref, q_w, s, zp = quantize_weights(\n            b_weight, scalar_types.uint4, group_size, zero_points=True\n        )\n        b_q_weight[e] = awq_pack(q_w, num_bits, size_k, size_n)\n\n    perm = torch.empty((num_experts, 0), dtype=torch.int32, device=\"cuda\")\n\n    out = jit_awq_marlin_moe_repack(b_q_weight, perm, size_k, size_n, num_bits)\n    torch.cuda.synchronize()\n\n    assert out.is_cuda and out.dtype == torch.int32\n    expected_shape = (num_experts, size_k // 16, size_n * (num_bits // 2))\n    assert list(out.shape) == list(expected_shape)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", 
\"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_awq_marlin_repack.py",
    "content": "import numpy as np\nimport pytest\nimport torch\nfrom sgl_kernel.scalar_type import scalar_types\n\nfrom sglang.jit_kernel.awq_marlin_repack import (\n    awq_marlin_repack as jit_awq_marlin_repack,\n)\nfrom sglang.srt.layers.quantization.utils import pack_cols, quantize_weights\nfrom sglang.test.test_marlin_utils import get_weight_perm, marlin_weights\n\n\ndef _has_aot_awq_marlin_repack() -> bool:\n    return hasattr(torch.ops.sgl_kernel, \"awq_marlin_repack\") and hasattr(\n        torch.ops.sgl_kernel.awq_marlin_repack, \"default\"\n    )\n\n\nAOT_AVAILABLE = _has_aot_awq_marlin_repack()\n\n\ndef awq_pack(\n    q_w: torch.Tensor,\n    num_bits: int,\n    size_k: int,\n    size_n: int,\n):\n    assert q_w.shape == (size_k, size_n)\n\n    if num_bits == 4:\n        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])\n    elif num_bits == 8:\n        interleave = np.array([0, 2, 1, 3])\n    else:\n        raise Exception(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()\n    q_w = q_w.reshape((-1, size_n)).contiguous()\n\n    return pack_cols(q_w, num_bits, size_k, size_n)\n\n\n@pytest.mark.parametrize(\"num_bits\", [4, 8])\n@pytest.mark.parametrize(\"k_tiles,n_tiles\", [(1, 1), (2, 2), (4, 4)])\n@pytest.mark.parametrize(\"group_size\", [16, 32])\ndef test_awq_marlin_repack_jit_vs_aot(num_bits, k_tiles, n_tiles, group_size):\n    if not AOT_AVAILABLE:\n        pytest.skip(\"sgl_kernel AOT not available\")\n\n    tile_k, tile_n = 16, 64\n    size_k = k_tiles * tile_k\n    size_n = n_tiles * tile_n\n\n    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device=\"cuda\")\n\n    w_ref, q_w, s, zp = quantize_weights(\n        b_weight, scalar_types.uint4, group_size, zero_points=True\n    )\n\n    q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)\n\n    out_jit = jit_awq_marlin_repack(q_w_awq, size_k, size_n, num_bits)\n    out_aot = 
torch.ops.sgl_kernel.awq_marlin_repack.default(\n        q_w_awq, size_k, size_n, num_bits\n    )\n\n    torch.cuda.synchronize()\n\n    # Bitwise equality\n    torch.testing.assert_close(out_jit, out_aot, rtol=0, atol=0)\n\n\n@pytest.mark.parametrize(\"num_bits\", [4, 8])\n@pytest.mark.parametrize(\"k_tiles,n_tiles\", [(1, 1), (2, 2)])\n@pytest.mark.parametrize(\"group_size\", [16, 32])\ndef test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size):\n    tile_k, tile_n = 16, 64\n    size_k = k_tiles * tile_k\n    size_n = n_tiles * tile_n\n    pack_factor = 32 // num_bits\n\n    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device=\"cuda\")\n\n    w_ref, q_w, s, zp = quantize_weights(\n        b_weight, scalar_types.uint4, group_size, zero_points=True\n    )\n\n    q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)\n\n    weight_perm = get_weight_perm(num_bits)\n    q_w_marlin = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)\n\n    out_gpu = jit_awq_marlin_repack(q_w_awq, size_k, size_n, num_bits)\n    assert out_gpu.is_cuda and out_gpu.dtype == torch.int32\n\n    expected_cols = size_n * tile_k // pack_factor\n    assert list(out_gpu.shape) == [size_k // tile_k, expected_cols]\n\n    torch.cuda.synchronize()\n\n    torch.testing.assert_close(out_gpu, q_w_marlin)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_concat_mla.py",
    "content": "import itertools\n\nimport pytest\nimport torch\nimport triton\n\n\ndef torch_concat_mla_k(\n    k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor\n) -> None:\n    \"\"\"Reference PyTorch implementation for concat_mla_k.\"\"\"\n    # k_nope: [num_tokens, num_heads, nope_head_dim]\n    # k_rope: [num_tokens, 1, rope_head_dim]\n    # k: [num_tokens, num_heads, nope_head_dim + rope_head_dim]\n    nope_head_dim = k_nope.shape[-1]\n    k[:, :, :nope_head_dim] = k_nope\n    # Broadcast k_rope across all heads\n    k[:, :, nope_head_dim:] = k_rope.expand(-1, k.shape[1], -1)\n\n\ndef torch_concat_mla_absorb_q(\n    a: torch.Tensor, b: torch.Tensor, out: torch.Tensor\n) -> None:\n    \"\"\"Reference PyTorch implementation for concat_mla_absorb_q.\"\"\"\n    # a: [dim_0, dim_1, a_last_dim]\n    # b: [dim_0, dim_1, b_last_dim]\n    # out: [dim_0, dim_1, a_last_dim + b_last_dim]\n    a_last_dim = a.shape[-1]\n    out[:, :, :a_last_dim] = a\n    out[:, :, a_last_dim:] = b\n\n\ndef sgl_kernel_concat_mla_k(\n    k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor\n) -> None:\n    \"\"\"AOT compiled sgl_kernel implementation.\"\"\"\n    from sgl_kernel import concat_mla_k\n\n    concat_mla_k(k, k_nope, k_rope)\n\n\ndef sgl_kernel_concat_mla_absorb_q(\n    a: torch.Tensor, b: torch.Tensor, out: torch.Tensor\n) -> None:\n    \"\"\"AOT compiled sgl_kernel implementation.\"\"\"\n    from sgl_kernel import concat_mla_absorb_q\n\n    result = concat_mla_absorb_q(a, b)  # AOT returns output\n    out.copy_(result)  # Copy to provided tensor for comparison\n\n\ndef jit_concat_mla_k(\n    k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor\n) -> None:\n    \"\"\"JIT compiled implementation.\"\"\"\n    from sglang.jit_kernel.concat_mla import concat_mla_k\n\n    concat_mla_k(k, k_nope, k_rope)\n\n\ndef jit_concat_mla_absorb_q(\n    a: torch.Tensor, b: torch.Tensor, out: torch.Tensor\n) -> None:\n    \"\"\"JIT compiled implementation - wrapper for 
test compatibility.\"\"\"\n    from sglang.jit_kernel.concat_mla import concat_mla_absorb_q\n\n    result = concat_mla_absorb_q(a, b)\n    out.copy_(result)\n\n\n# Constants matching the kernel\nNUM_LOCAL_HEADS = 128\nQK_NOPE_HEAD_DIM = 128\nQK_ROPE_HEAD_DIM = 64\nK_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM\n\nA_LAST_DIM = 512\nB_LAST_DIM = 64\nOUT_LAST_DIM = A_LAST_DIM + B_LAST_DIM\n\nDEVICE = \"cuda\"\nDTYPE = torch.bfloat16\n\n# Test configurations\nNUM_TOKENS_LIST = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]\n\n\n@pytest.mark.parametrize(\"num_tokens\", NUM_TOKENS_LIST)\ndef test_concat_mla_k_jit_vs_torch(num_tokens: int) -> None:\n    \"\"\"Test JIT kernel against PyTorch reference.\"\"\"\n    k_jit = torch.empty(\n        num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE\n    )\n    k_torch = torch.empty(\n        num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE\n    )\n\n    k_nope = torch.randn(\n        num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM, device=DEVICE, dtype=DTYPE\n    )\n    k_rope = torch.randn(num_tokens, 1, QK_ROPE_HEAD_DIM, device=DEVICE, dtype=DTYPE)\n\n    torch_concat_mla_k(k_torch, k_nope, k_rope)\n    jit_concat_mla_k(k_jit, k_nope, k_rope)\n\n    triton.testing.assert_close(k_jit, k_torch, atol=0, rtol=0)\n\n\n@pytest.mark.parametrize(\"num_tokens\", NUM_TOKENS_LIST)\ndef test_concat_mla_k_jit_vs_aot(num_tokens: int) -> None:\n    \"\"\"Test JIT kernel against AOT kernel for bitwise equivalence.\"\"\"\n    k_jit = torch.empty(\n        num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE\n    )\n    k_aot = torch.empty(\n        num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE\n    )\n\n    k_nope = torch.randn(\n        num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM, device=DEVICE, dtype=DTYPE\n    )\n    k_rope = torch.randn(num_tokens, 1, QK_ROPE_HEAD_DIM, device=DEVICE, dtype=DTYPE)\n\n    sgl_kernel_concat_mla_k(k_aot, k_nope, k_rope)\n    
jit_concat_mla_k(k_jit, k_nope, k_rope)\n\n    triton.testing.assert_close(k_jit, k_aot, atol=0, rtol=0)\n\n\nDIM_0_LIST = [1, 2, 4, 8, 16, 32]\nDIM_1_LIST = [1, 2, 4, 8, 16, 128]\n\n\n@pytest.mark.parametrize(\n    \"dim_0,dim_1\",\n    list(itertools.product(DIM_0_LIST, DIM_1_LIST)),\n)\ndef test_concat_mla_absorb_q_jit_vs_torch(dim_0: int, dim_1: int) -> None:\n    \"\"\"Test JIT kernel against PyTorch reference.\"\"\"\n    a = torch.randn(dim_0, dim_1, A_LAST_DIM, device=DEVICE, dtype=DTYPE)\n    b = torch.randn(dim_0, dim_1, B_LAST_DIM, device=DEVICE, dtype=DTYPE)\n    out_jit = torch.empty(dim_0, dim_1, OUT_LAST_DIM, device=DEVICE, dtype=DTYPE)\n    out_torch = torch.empty(dim_0, dim_1, OUT_LAST_DIM, device=DEVICE, dtype=DTYPE)\n\n    torch_concat_mla_absorb_q(a, b, out_torch)\n    jit_concat_mla_absorb_q(a, b, out_jit)\n\n    triton.testing.assert_close(out_jit, out_torch, atol=0, rtol=0)\n\n\n@pytest.mark.parametrize(\n    \"dim_0,dim_1\",\n    list(itertools.product(DIM_0_LIST, DIM_1_LIST)),\n)\ndef test_concat_mla_absorb_q_jit_vs_aot(dim_0: int, dim_1: int) -> None:\n    \"\"\"Test JIT kernel against AOT kernel for bitwise equivalence.\"\"\"\n    a = torch.randn(dim_0, dim_1, A_LAST_DIM, device=DEVICE, dtype=DTYPE)\n    b = torch.randn(dim_0, dim_1, B_LAST_DIM, device=DEVICE, dtype=DTYPE)\n    out_jit = torch.empty(dim_0, dim_1, OUT_LAST_DIM, device=DEVICE, dtype=DTYPE)\n    out_aot = torch.empty(dim_0, dim_1, OUT_LAST_DIM, device=DEVICE, dtype=DTYPE)\n\n    sgl_kernel_concat_mla_absorb_q(a, b, out_aot)\n    jit_concat_mla_absorb_q(a, b, out_jit)\n\n    triton.testing.assert_close(out_jit, out_aot, atol=0, rtol=0)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_cutedsl_gdn.py",
    "content": "\"\"\"Tests for CuTe DSL fused sigmoid gating delta rule kernel (GDN).\"\"\"\n\nimport numpy as np\nimport pytest\nimport torch\n\ntry:\n    import cuda.bindings.driver as cuda_driver\n    import cutlass  # noqa: F401\n    from cutlass.cute.runtime import from_dlpack\n\n    from sglang.jit_kernel import cutedsl_gdn\n\n    CUTEDSL_AVAILABLE = True\nexcept ImportError:\n    CUTEDSL_AVAILABLE = False\n    cutedsl_gdn = None\n\ntry:\n    from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (\n        fused_sigmoid_gating_delta_rule_update,\n    )\n\n    TRITON_AVAILABLE = True\nexcept ImportError:\n    TRITON_AVAILABLE = False\n\n\ndef run_triton_kernel(A_log, dt_bias, q, k, v, a, b, initial_state, indices, scale):\n    return fused_sigmoid_gating_delta_rule_update(\n        A_log=A_log,\n        a=a,\n        dt_bias=dt_bias,\n        softplus_beta=1.0,\n        softplus_threshold=20.0,\n        q=q,\n        k=k,\n        v=v,\n        b=b,\n        initial_state_source=initial_state,\n        initial_state_indices=indices,\n        scale=scale,\n        use_qk_l2norm_in_kernel=True,\n        cu_seqlens=None,\n    )\n\n\n@pytest.mark.skipif(not CUTEDSL_AVAILABLE, reason=\"CuTe DSL not available\")\n@pytest.mark.skipif(not TRITON_AVAILABLE, reason=\"Triton kernel not available\")\n@pytest.mark.skip(\n    reason=(\n        \"Temporary CI workaround: CuTe DSL GDN precision is currently unstable \"\n        \"against the Triton reference and needs follow-up investigation.\"\n    )\n)\n@pytest.mark.parametrize(\"B\", [16, 128])\ndef test_cutedsl_gdn_precision(B: int):\n    \"\"\"Test precision of CuTe DSL GDN kernel against Triton reference.\"\"\"\n    torch.manual_seed(2025)\n    T, H, K, V, HV = 1, 16, 128, 128, 32\n    scale = K**-0.5\n\n    A_log = torch.randn(HV, dtype=torch.float32, device=\"cuda\")\n    dt_bias = torch.randn(HV, dtype=torch.bfloat16, device=\"cuda\")\n    a = torch.randn(B, T, HV, dtype=torch.bfloat16, 
device=\"cuda\")\n    b = torch.randn(B, T, HV, dtype=torch.bfloat16, device=\"cuda\")\n    q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=\"cuda\")\n    k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=\"cuda\")\n    v = torch.randn(B, T, HV, V, dtype=torch.bfloat16, device=\"cuda\")\n    indices = torch.arange(B, dtype=torch.int32, device=\"cuda\")\n    state_cutedsl = torch.randn(B, HV, K, V, dtype=torch.float32, device=\"cuda\")\n    state_triton = state_cutedsl.clone().reshape(-1).contiguous()\n\n    # Warmup compilation\n    _ = cutedsl_gdn.cutedsl_fused_sigmoid_gating_delta_rule_update(\n        A_log, dt_bias, q, k, v, a, b, state_cutedsl.clone(), indices, scale=scale\n    )\n    torch.cuda.synchronize()\n\n    # Fresh state for actual test\n    state_cutedsl = torch.randn(B, HV, K, V, dtype=torch.float32, device=\"cuda\")\n    state_triton = state_cutedsl.clone().reshape(-1).contiguous()\n\n    out_cutedsl = cutedsl_gdn.cutedsl_fused_sigmoid_gating_delta_rule_update(\n        A_log, dt_bias, q, k, v, a, b, state_cutedsl, indices, scale=scale\n    )\n    out_triton = run_triton_kernel(\n        A_log, dt_bias, q, k, v, a, b, state_triton, indices, scale\n    )\n\n    # Check precision: diff > 0.1 must be < 1% of elements\n    abs_diff = (out_triton.float() - out_cutedsl.float()).abs()\n    max_diff = abs_diff.max().item()\n    mean_diff = abs_diff.mean().item()\n    fail_rate = (abs_diff > 0.1).float().mean().item() * 100\n    has_nan = torch.isnan(out_cutedsl).any() or torch.isinf(out_cutedsl).any()\n\n    kernel_type = \"SmallBatch\" if B < 32 else \"LargeBatch\"\n    print(\n        f\"\\n  B={B} ({kernel_type}): max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}, fail_rate={fail_rate:.2f}%\"\n    )\n\n    assert not has_nan, \"Output contains NaN/Inf\"\n    assert fail_rate < 1.0, f\"Fail rate {fail_rate:.2f}% >= 1%\"\n\n\n@pytest.mark.skipif(\n    True,\n    reason=\"Skip the performance test because the speedup ratio is highly 
unstable in the CI environment. \",\n)\n@pytest.mark.skipif(not CUTEDSL_AVAILABLE, reason=\"CuTe DSL not available\")\n@pytest.mark.skipif(not TRITON_AVAILABLE, reason=\"Triton kernel not available\")\n@pytest.mark.parametrize(\"B\", [1, 128])\ndef test_cutedsl_gdn_performance(B: int):\n    \"\"\"Benchmark CuTe DSL GDN kernel against Triton reference.\"\"\"\n    torch.manual_seed(2025)\n    T, H, K, V, HV = 1, 16, 128, 128, 32\n    N = B\n    scale = K**-0.5\n    is_varlen = True\n    warmup, bench_iters, run_iters = 10, 100, 10\n\n    A_log = torch.randn(HV, dtype=torch.float32, device=\"cuda\")\n    dt_bias = torch.randn(HV, dtype=torch.bfloat16, device=\"cuda\")\n    indices = torch.arange(N, dtype=torch.int32, device=\"cuda\")\n    state_cutedsl = torch.randn(N, HV, K, V, dtype=torch.float32, device=\"cuda\")\n    state_triton = state_cutedsl.reshape(-1).contiguous()\n    cu_seqlens = torch.zeros(N + 1, dtype=torch.int32, device=\"cuda\")\n    o_cutedsl = torch.zeros(1, N, HV, V, dtype=torch.bfloat16, device=\"cuda\")\n\n    # Prepare tensors for multiple runs\n    q_list, k_list, v_list, a_list, b_list = [], [], [], [], []\n    q_tensor_list, k_tensor_list, v_tensor_list, a_tensor_list, b_tensor_list = (\n        [],\n        [],\n        [],\n        [],\n        [],\n    )\n    q_triton, k_triton, v_triton, a_triton, b_triton = [], [], [], [], []\n\n    for ri in range(run_iters):\n        torch.manual_seed(2025 + ri)\n        q_i = torch.randn(1, N, H, K, dtype=torch.bfloat16, device=\"cuda\")\n        k_i = torch.randn(1, N, H, K, dtype=torch.bfloat16, device=\"cuda\")\n        v_i = torch.randn(1, N, HV, V, dtype=torch.bfloat16, device=\"cuda\")\n        a_i = torch.randn(N, HV, dtype=torch.bfloat16, device=\"cuda\")\n        b_i = torch.randn(N, HV, dtype=torch.bfloat16, device=\"cuda\")\n\n        q_list.append(q_i)\n        k_list.append(k_i)\n        v_list.append(v_i)\n        a_list.append(a_i)\n        b_list.append(b_i)\n        
q_tensor_list.append(from_dlpack(q_i, assumed_align=16))\n        k_tensor_list.append(from_dlpack(k_i, assumed_align=16))\n        v_tensor_list.append(from_dlpack(v_i, assumed_align=16))\n        a_tensor_list.append(from_dlpack(a_i, assumed_align=16))\n        b_tensor_list.append(from_dlpack(b_i, assumed_align=16))\n        q_triton.append(q_i.transpose(0, 1).contiguous())\n        k_triton.append(k_i.transpose(0, 1).contiguous())\n        v_triton.append(v_i.transpose(0, 1).contiguous())\n        a_triton.append(a_i.unsqueeze(1).contiguous())\n        b_triton.append(b_i.unsqueeze(1).contiguous())\n\n    A_log_t = from_dlpack(A_log, assumed_align=16)\n    dt_bias_t = from_dlpack(dt_bias, assumed_align=16)\n    h0_t = from_dlpack(state_cutedsl, assumed_align=16)\n    idx_t = from_dlpack(indices, assumed_align=16)\n    o_t = from_dlpack(o_cutedsl, assumed_align=16)\n    cu_t = from_dlpack(cu_seqlens, assumed_align=16)\n\n    torch_stream = torch.cuda.Stream()\n    stream = cuda_driver.CUstream(torch_stream.cuda_stream)\n\n    # Compile kernels\n    compiled = cutedsl_gdn._get_compiled_kernel(N, H, HV, K, V, N, N < 32, is_varlen)\n    torch.cuda.synchronize()\n\n    for ri in range(run_iters):\n        _ = run_triton_kernel(\n            A_log,\n            dt_bias,\n            q_triton[ri],\n            k_triton[ri],\n            v_triton[ri],\n            a_triton[ri],\n            b_triton[ri],\n            state_triton,\n            indices,\n            scale,\n        )\n    torch.cuda.synchronize()\n\n    def run_cutedsl():\n        for ri in range(run_iters):\n            compiled(\n                cu_t,\n                q_tensor_list[ri],\n                k_tensor_list[ri],\n                v_tensor_list[ri],\n                a_tensor_list[ri],\n                b_tensor_list[ri],\n                A_log_t,\n                dt_bias_t,\n                h0_t,\n                idx_t,\n                o_t,\n                stream,\n            )\n\n    def 
run_triton():\n        for ri in range(run_iters):\n            _ = run_triton_kernel(\n                A_log,\n                dt_bias,\n                q_triton[ri],\n                k_triton[ri],\n                v_triton[ri],\n                a_triton[ri],\n                b_triton[ri],\n                state_triton,\n                indices,\n                scale,\n            )\n\n    # Warmup\n    with torch.cuda.stream(torch_stream):\n        run_cutedsl()\n    torch.cuda.synchronize()\n    run_triton()\n    torch.cuda.synchronize()\n\n    # Capture CUDA graphs\n    graph_triton = torch.cuda.CUDAGraph()\n    graph_cutedsl = torch.cuda.CUDAGraph()\n    try:\n        with torch.cuda.graph(graph_triton):\n            run_triton()\n        with torch.cuda.graph(graph_cutedsl, stream=torch_stream):\n            run_cutedsl()\n        torch.cuda.synchronize()\n    except Exception:\n        graph_triton = graph_cutedsl = None\n\n    # Warmup with graphs\n    for _ in range(warmup):\n        if graph_cutedsl:\n            graph_cutedsl.replay()\n        else:\n            with torch.cuda.stream(torch_stream):\n                run_cutedsl()\n        torch.cuda.synchronize()\n\n        if graph_triton:\n            graph_triton.replay()\n        else:\n            run_triton()\n        torch.cuda.synchronize()\n\n    # Benchmark\n    triton_times, cutedsl_times = [], []\n    for _ in range(bench_iters):\n        start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(\n            enable_timing=True\n        )\n        start.record()\n        if graph_triton:\n            graph_triton.replay()\n        else:\n            run_triton()\n        end.record()\n        torch.cuda.synchronize()\n        triton_times.append(start.elapsed_time(end))\n\n        start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(\n            enable_timing=True\n        )\n        with torch.cuda.stream(torch_stream):\n            start.record()\n            if 
graph_cutedsl:\n                graph_cutedsl.replay()\n            else:\n                run_cutedsl()\n            end.record()\n        torch.cuda.synchronize()\n        cutedsl_times.append(start.elapsed_time(end))\n\n    triton_mean = np.mean(triton_times) / run_iters * 1000\n    triton_std = np.std(triton_times) / run_iters * 1000\n    cutedsl_mean = np.mean(cutedsl_times) / run_iters * 1000\n    cutedsl_std = np.std(cutedsl_times) / run_iters * 1000\n    speedup = triton_mean / cutedsl_mean\n\n    kernel_type = \"SmallBatch\" if B < 32 else \"LargeBatch\"\n    print(\n        f\"\\n  B={B} ({kernel_type}): Triton={triton_mean:.2f}±{triton_std:.2f}μs, CuTeDSL={cutedsl_mean:.2f}±{cutedsl_std:.2f}μs, speedup={speedup:.2f}x\"\n    )\n\n    min_speedup = 1.0 if B < 32 else 1.15\n    assert speedup >= min_speedup, f\"Speedup {speedup:.2f}x < {min_speedup}x for B={B}\"\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_flash_attention_4.py",
    "content": "# Adapted from https://github.com/Dao-AILab/flash-attention/blob/8ecf128f683266735ba68e3c106ff67a2611886e/tests/cute/test_flash_attn.py\n\n# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n\nimport itertools\nimport math\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\n\nfrom sglang.jit_kernel.flash_attention_v4 import flash_attn_varlen_func\n\n# Skip this test on Hopper machine\nskip_condition = torch.cuda.get_device_capability() < (10, 0)\n\n\ndef apply_rotary_emb(\n    x: torch.Tensor,\n    cos: torch.Tensor,\n    sin: torch.Tensor,\n    seqlen_offsets: torch.Tensor | int | None = 0,\n    interleaved: bool = False,\n) -> torch.Tensor:\n    rotary_dim = cos.shape[-1] * 2\n    x_rot = x[..., :rotary_dim]\n    x_pass = x[..., rotary_dim:]\n\n    cos = cos.to(dtype=x.dtype)\n    sin = sin.to(dtype=x.dtype)\n\n    if x_rot.dim() < 2:\n        raise ValueError(f\"apply_rotary_emb expects x.dim() >= 2, got {x_rot.dim()}\")\n\n    b = x_rot.shape[0]\n    s = x_rot.shape[1]\n\n    if seqlen_offsets is None:\n        seqlen_offsets = 0\n\n    if isinstance(seqlen_offsets, int):\n        positions = (\n            torch.arange(s, device=x_rot.device, dtype=torch.long) + seqlen_offsets\n        )\n        cos_s = cos.index_select(0, positions)\n        sin_s = sin.index_select(0, positions)\n        cos_s = cos_s.unsqueeze(0).expand(b, -1, -1)\n        sin_s = sin_s.unsqueeze(0).expand(b, -1, -1)\n    else:\n        if seqlen_offsets.dim() != 1 or seqlen_offsets.shape[0] != b:\n            raise ValueError(\n                \"apply_rotary_emb expects seqlen_offsets to be int or shape [batch]\"\n            )\n        positions = torch.arange(s, device=x_rot.device, dtype=torch.long).unsqueeze(\n            0\n        ) + seqlen_offsets.to(dtype=torch.long).unsqueeze(1)\n        cos_s = cos.index_select(0, positions.reshape(-1)).view(b, s, -1)\n        
sin_s = sin.index_select(0, positions.reshape(-1)).view(b, s, -1)\n\n    x_rot = x_rot.reshape(b, s, -1, rotary_dim)\n    cos_s = cos_s.unsqueeze(2)\n    sin_s = sin_s.unsqueeze(2)\n\n    if interleaved:\n        x1 = x_rot[..., ::2]\n        x2 = x_rot[..., 1::2]\n        o1 = x1 * cos_s - x2 * sin_s\n        o2 = x2 * cos_s + x1 * sin_s\n        x_rot = torch.stack((o1, o2), dim=-1).flatten(-2)\n    else:\n        x1, x2 = torch.chunk(x_rot, 2, dim=-1)\n        o1 = x1 * cos_s - x2 * sin_s\n        o2 = x2 * cos_s + x1 * sin_s\n        x_rot = torch.cat((o1, o2), dim=-1)\n\n    x_rot = x_rot.reshape_as(x[..., :rotary_dim])\n    return torch.cat((x_rot, x_pass), dim=-1)\n\n\ndef unpad_input(hidden_states, attention_mask, unused_mask=None):\n    \"\"\"\n    Arguments:\n        hidden_states: (batch, seqlen, ...)\n        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.\n        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.\n    Return:\n        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.\n        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.\n        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.\n        max_seqlen_in_batch: int\n        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.\n    \"\"\"\n    all_masks = (\n        (attention_mask + unused_mask) if unused_mask is not None else attention_mask\n    )\n    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)\n    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    # TD [2022-03-04] We don't want to 
index with a bool mask, because PyTorch will expand the\n    # bool mask, then call nonzero to get the indices, then index with those. The indices are dim\n    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to\n    # index with integer indices.\n    return (\n        rearrange(hidden_states, \"b s ... -> (b s) ...\")[indices],\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n        used_seqlens_in_batch,\n    )\n\n\ndef pad_input(hidden_states, indices, batch, seqlen):\n    \"\"\"\n    Arguments:\n        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.\n        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.\n        batch: int, batch size for the padded sequence.\n        seqlen: int, maximum sequence length for the padded sequence.\n    Return:\n        hidden_states: (batch, seqlen, ...)\n    \"\"\"\n    dim = hidden_states.shape[1:]\n    output = torch.zeros(\n        (batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype\n    )\n    output[indices] = hidden_states\n    return rearrange(output, \"(b s) ... 
-> b s ...\", b=batch)\n\n\ndef generate_random_padding_mask(\n    max_seqlen, batch_size, device, mode=\"random\", zero_lengths=False\n):\n    assert mode in [\"full\", \"random\", \"third\"]\n    if mode == \"full\":\n        lengths = torch.full(\n            (batch_size, 1), max_seqlen, device=device, dtype=torch.int32\n        )\n    elif mode == \"random\":\n        lengths = torch.randint(\n            max(0 if zero_lengths else 1, max_seqlen - 20),\n            max_seqlen + 1,\n            (batch_size, 1),\n            device=device,\n        )\n    elif mode == \"third\":\n        lengths = torch.randint(\n            max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device\n        )\n    else:\n        # This should never happen due to the assertion above, but for linter\n        lengths = torch.full(\n            (batch_size, 1), max_seqlen, device=device, dtype=torch.int32\n        )\n\n    if zero_lengths:\n        # Generate zero-lengths every 5 batches and the last batch.\n        for i in range(batch_size):\n            if i % 5 == 0:\n                lengths[i] = 0\n        lengths[-1] = 0\n    padding_mask = (\n        repeat(torch.arange(max_seqlen, device=device), \"s -> b s\", b=batch_size)\n        < lengths\n    )\n    return padding_mask\n\n\ndef generate_qkv(\n    q,\n    k,\n    v,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    qv=None,\n    kvpacked=False,\n    qkvpacked=False,\n    query_unused_mask=None,\n    key_unused_mask=None,\n):\n    \"\"\"\n    Arguments:\n        q: (batch_size, seqlen_q, nheads, d)\n        k: (batch_size, seqlen_k, nheads_k, d)\n        v: (batch_size, seqlen_k, nheads_k, d_v)\n        query_padding_mask: (batch_size, seqlen), bool\n        key_padding_mask: (batch_size, seqlen), bool\n    \"\"\"\n    assert not (kvpacked and qkvpacked)\n    batch_size, seqlen_q, nheads, d = q.shape\n    d_v = v.shape[-1]\n    _, seqlen_k, nheads_k, _ = k.shape\n    assert k.shape == (batch_size, 
seqlen_k, nheads_k, d)\n    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)\n    if query_unused_mask is not None or key_unused_mask is not None:\n        assert not kvpacked\n        assert not qkvpacked\n\n    if query_padding_mask is not None:\n        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(\n            q, query_padding_mask, query_unused_mask\n        )\n        output_pad_fn = lambda output_unpad: pad_input(\n            output_unpad, indices_q, batch_size, seqlen_q\n        )\n        qv_unpad = (\n            rearrange(qv, \"b s ... -> (b s) ...\")[indices_q] if qv is not None else None\n        )\n    else:\n        q_unpad = rearrange(q, \"b s h d -> (b s) h d\")\n        cu_seqlens_q = torch.arange(\n            0,\n            (batch_size + 1) * seqlen_q,\n            step=seqlen_q,\n            dtype=torch.int32,\n            device=q_unpad.device,\n        )\n        seqused_q = None\n        max_seqlen_q = seqlen_q\n        output_pad_fn = lambda output_unpad: rearrange(\n            output_unpad, \"(b s) h d -> b s h d\", b=batch_size\n        )\n        qv_unpad = rearrange(qv, \"b s ... 
-> (b s) ...\") if qv is not None else None\n\n    if key_padding_mask is not None:\n        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(\n            k, key_padding_mask, key_unused_mask\n        )\n        v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask)\n    else:\n        k_unpad = rearrange(k, \"b s h d -> (b s) h d\")\n        v_unpad = rearrange(v, \"b s h d -> (b s) h d\")\n        cu_seqlens_k = torch.arange(\n            0,\n            (batch_size + 1) * seqlen_k,\n            step=seqlen_k,\n            dtype=torch.int32,\n            device=k_unpad.device,\n        )\n        seqused_k = None\n        max_seqlen_k = seqlen_k\n\n    if qkvpacked:\n        assert (query_padding_mask == key_padding_mask).all()\n        assert nheads == nheads_k\n        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)\n        qkv = torch.stack([q, k, v], dim=2)\n        if query_padding_mask is not None:\n            dqkv_pad_fn = lambda dqkv_unpad: pad_input(\n                dqkv_unpad, indices_q, batch_size, seqlen_q\n            )\n        else:\n            dqkv_pad_fn = lambda dqkv_unpad: rearrange(\n                dqkv_unpad, \"(b s) t h d -> b s t h d\", b=batch_size\n            )\n        return (\n            qkv_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            max_seqlen_q,\n            qkv.detach().requires_grad_(),\n            output_pad_fn,\n            dqkv_pad_fn,\n        )\n    elif kvpacked:\n        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)\n        kv = torch.stack([k, v], dim=2)\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dkv_pad_fn = lambda dkv_unpad: pad_input(\n                dkv_unpad, indices_k, batch_size, seqlen_k\n            )\n        else:\n            dkv_pad_fn = lambda dkv_unpad: rearrange(\n                dkv_unpad, \"(b s) t h d -> b s t h d\", b=batch_size\n            )\n        return (\n      
      q_unpad.detach().requires_grad_(),\n            kv_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            kv.detach().requires_grad_(),\n            output_pad_fn,\n            dq_pad_fn,\n            dkv_pad_fn,\n        )\n    else:\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dk_pad_fn = lambda dk_unpad: pad_input(\n                dk_unpad, indices_k, batch_size, seqlen_k\n            )\n        else:\n            dk_pad_fn = lambda dk_unpad: rearrange(\n                dk_unpad, \"(b s) h d -> b s h d\", b=batch_size\n            )\n        return (\n            q_unpad.detach().requires_grad_(),\n            k_unpad.detach().requires_grad_(),\n            v_unpad.detach().requires_grad_(),\n            qv_unpad.detach() if qv is not None else None,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n            seqused_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            k.detach().requires_grad_(),\n            v.detach().requires_grad_(),\n            qv.detach() if qv is not None else None,\n            output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        )\n\n\ndef construct_local_mask(\n    seqlen_q,\n    seqlen_k,\n    window_size=(None, None),\n    sink_token_length=0,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    key_leftpad=None,\n    device=None,\n):\n    row_idx = rearrange(\n        torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\"\n    )\n    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n    if key_leftpad is not None:\n        key_leftpad = rearrange(key_leftpad, \"b -> b 1 1 1\")\n        col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n        col_idx = torch.where(col_idx >= 
key_leftpad, col_idx - key_leftpad, 2**32)\n    sk = (\n        seqlen_k\n        if key_padding_mask is None\n        else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sq = (\n        seqlen_q\n        if query_padding_mask is None\n        else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    if window_size[0] is None:\n        return col_idx > row_idx + sk - sq + window_size[1]\n    else:\n        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk\n        return torch.logical_or(\n            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),\n            torch.logical_and(\n                col_idx < row_idx + sk - sq - window_size[0],\n                col_idx >= sink_token_length,\n            ),\n        )\n\n\ndef construct_chunk_mask(\n    seqlen_q,\n    seqlen_k,\n    attention_chunk,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    key_leftpad=None,\n    device=None,\n):\n    row_idx = rearrange(\n        torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\"\n    )\n    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n    if key_leftpad is not None:\n        key_leftpad = rearrange(key_leftpad, \"b -> b 1 1 1\")\n        col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)\n    sk = (\n        seqlen_k\n        if key_padding_mask is None\n        else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sq = (\n        seqlen_q\n        if query_padding_mask is None\n        else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk\n    # Subtract remainder instead of divide and then multiply to take care of negative values\n    col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk\n    return 
torch.logical_or(\n        col_idx < col_limit_left_chunk,\n        col_idx >= col_limit_left_chunk + attention_chunk,\n    )\n\n\ndef attention_ref(\n    q,\n    k,\n    v,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    key_leftpad=None,\n    attn_bias=None,\n    dropout_p=0.0,\n    dropout_mask=None,\n    causal=False,\n    qv=None,\n    q_descale=None,\n    k_descale=None,\n    v_descale=None,\n    window_size=(None, None),\n    attention_chunk=0,\n    sink_token_length=0,\n    learnable_sink=None,\n    softcap=0.0,\n    upcast=True,\n    reorder_ops=False,\n    intermediate_dtype=None,\n):\n    \"\"\"\n    Arguments:\n        q: (batch_size, seqlen_q, nheads, head_dim)\n        k: (batch_size, seqlen_k, nheads, head_dim)\n        v: (batch_size, seqlen_k, nheads, head_dim_v)\n        qv: (batch_size, seqlen_q, nheads, head_dim_v)\n        query_padding_mask: (batch_size, seqlen_q)\n        key_padding_mask: (batch_size, seqlen_k)\n        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)\n        dropout_p: float\n        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)\n        causal: whether to apply causal masking\n        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast\n            output back to fp16/bf16.\n        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)\n            without changing the math. 
This is to estimate the numerical error from operation\n            reordering.\n    Output:\n        output: (batch_size, seqlen_q, nheads, head_dim_v)\n        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout\n    \"\"\"\n    if causal:\n        window_size = (window_size[0], 0)\n    dtype_og = q.dtype\n    if upcast:\n        q, k, v = q.float(), k.float(), v.float()\n        qv = qv.float() if qv is not None else None\n    if q_descale is not None:\n        q_descale = repeat(q_descale, \"b h -> b 1 (h g) 1\", g=q.shape[2] // k.shape[2])\n        q = (q.float() * q_descale).to(q.dtype)\n        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None\n    if k_descale is not None:\n        k = (k.float() * rearrange(k_descale, \"b h -> b 1 h 1\")).to(dtype=k.dtype)\n    if v_descale is not None:\n        v = (v.float() * rearrange(v_descale, \"b h -> b 1 h 1\")).to(dtype=v.dtype)\n    seqlen_q, seqlen_k = q.shape[1], k.shape[1]\n    k = repeat(k, \"b s h d -> b s (h g) d\", g=q.shape[2] // k.shape[2])\n    v = repeat(v, \"b s h d -> b s (h g) d\", g=q.shape[2] // v.shape[2])\n    d = q.shape[-1]\n    dv = v.shape[-1]\n    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)\n    if not reorder_ops:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q * softmax_scale, k)\n    else:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q, k * softmax_scale)\n    if qv is not None:\n        scores = scores + torch.einsum(\"bthd,bshd->bhts\", qv * softmax_scale, v)\n    if softcap > 0:\n        scores = torch.tanh(scores / softcap) * softcap\n    if key_padding_mask is not None:\n        scores.masked_fill_(\n            rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\")\n        )\n    local_mask = None\n    if window_size[0] is not None or window_size[1] is not None:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            
sink_token_length,\n            query_padding_mask,\n            key_padding_mask,\n            key_leftpad=key_leftpad,\n            device=q.device,\n        )\n    if attention_chunk > 0:\n        chunk_mask = construct_chunk_mask(\n            seqlen_q,\n            seqlen_k,\n            attention_chunk,\n            query_padding_mask,\n            key_padding_mask,\n            key_leftpad=key_leftpad,\n            device=q.device,\n        )\n        local_mask = (\n            torch.logical_or(local_mask, chunk_mask)\n            if local_mask is not None\n            else chunk_mask\n        )\n    if local_mask is not None:\n        scores.masked_fill_(local_mask, float(\"-inf\"))\n    if attn_bias is not None:\n        scores = scores + attn_bias\n    if learnable_sink is None:\n        attention = torch.softmax(scores, dim=-1).to(v.dtype)\n    else:\n        scores_fp32 = scores.to(torch.float32)\n        logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)\n        learnable_sink = rearrange(learnable_sink, \"h -> h 1 1\")\n        logits_or_sinks_max = torch.maximum(learnable_sink, logits_max)\n        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)\n        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(\n            learnable_sink - logits_or_sinks_max\n        )\n        attention = (unnormalized_scores / normalizer).to(v.dtype)\n    # We want to mask here so that the attention matrix doesn't have any NaNs\n    # Otherwise we'll get NaN in dV\n    if query_padding_mask is not None:\n        attention = attention.masked_fill(\n            rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), 0.0\n        )\n    # Without this we might get NaN in dv\n    if key_padding_mask is not None:\n        attention = attention.masked_fill(\n            rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), 0.0\n        )\n    # Some rows might be completely masked out so we fill them with zero instead of NaN\n    if 
local_mask is not None:\n        attention = attention.masked_fill(\n            torch.all(local_mask, dim=-1, keepdim=True), 0.0\n        )\n    dropout_scaling = 1.0 / (1 - dropout_p)\n    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling\n    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)\n    if dropout_mask is not None:\n        attention_drop = attention.masked_fill(~dropout_mask, 0.0)\n    else:\n        attention_drop = attention\n    if intermediate_dtype is not None:\n        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)\n    output = torch.einsum(\"bhts,bshd->bthd\", attention_drop, v * dropout_scaling)\n    if query_padding_mask is not None:\n        output.masked_fill_(rearrange(~query_padding_mask, \"b s -> b s 1 1\"), 0.0)\n    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)\n\n\n@pytest.mark.skipif(\n    skip_condition, reason=\"FA4 Requires compute capability of 10 or above.\"\n)\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mqa\"])\n@pytest.mark.parametrize(\"has_learnable_sink\", [False, True])\n# @pytest.mark.parametrize(\"has_learnable_sink\", [False])\n# @pytest.mark.parametrize(\"has_qv\", [False, True])\n@pytest.mark.parametrize(\"has_qv\", [False])\n# @pytest.mark.parametrize(\"deterministic\", [False, True])\n@pytest.mark.parametrize(\"deterministic\", [False])\n# @pytest.mark.parametrize(\"softcap\", [0.0, 15.0])\n@pytest.mark.parametrize(\"softcap\", [0.0])\n# @pytest.mark.parametrize(\"local\", [False, True])\n@pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [False])\n# @pytest.mark.parametrize(\"add_unused_qkv\", [False, 
True])\n@pytest.mark.parametrize(\"add_unused_qkv\", [False])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [64, 96, 128])\n@pytest.mark.parametrize(\"d\", [64, 128])\n# @pytest.mark.parametrize(\"d\", [192])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        # (1, 1),\n        # (1, 3),\n        # (2, 1),\n        (511, 1),\n        (3, 513),\n        (64, 128),\n        (128, 128),\n        (256, 256),\n        # (113, 203),\n        # (128, 217),\n        # (113, 211),\n        # (108, 256),\n        # (256, 512),\n        (307, 256),\n        (640, 128),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (2048, 2048),\n    ],\n)\ndef test_flash_attn_varlen_output(\n    seqlen_q,\n    seqlen_k,\n    d,\n    add_unused_qkv,\n    causal,\n    local,\n    softcap,\n    deterministic,\n    has_qv,\n    has_learnable_sink,\n    mha_type,\n    dtype,\n):\n    if (\n        causal or local\n    ):  # Right now we only support causal attention with seqlen_k == seqlen_q\n        seqlen_k = seqlen_q\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))\n    batch_size = 49 if seqlen_q <= 1024 else 7\n    nheads = 6\n    # batch_size = 1\n    # nheads = 1\n    nheads_kv = nheads if mha_type == \"mha\" else (3 if mha_type == \"gqa\" else 1)\n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])\n    if dtype == torch.float8_e4m3fn:\n        dv_vals = [d]\n    # 
attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0]\n    attention_chunk_vals = [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n        q_ref = torch.randn(\n            batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref\n        )\n        if softcap > 0.0:\n            # Ensure the values of qk are at least within softcap range.\n            q_ref = (q_ref * softcap / 4).detach().requires_grad_()\n        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()\n        k_ref = (\n            torch.randn(\n                batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref\n            )\n            .to(dtype)\n            .to(dtype_ref)\n            .requires_grad_()\n        )\n        v_ref = (\n            torch.randn(\n                batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref\n            )\n            .to(dtype)\n            .to(dtype_ref)\n            .requires_grad_()\n        )\n        if has_qv:\n            qv_ref = (\n                torch.randn(\n                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n        else:\n            qv_ref = None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (\n            (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()\n        )\n        if has_learnable_sink:\n            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)\n        else:\n            learnable_sink = None\n        if dtype == torch.float8_e4m3fn:\n            q_descale, k_descale, v_descale = [\n                torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)\n                * 2\n                for _ in range(3)\n            ]\n        else:\n           
 q_descale, k_descale, v_descale = None, None, None\n        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]\n        qv = qv_ref.detach() if has_qv else None\n        query_padding_mask = generate_random_padding_mask(\n            seqlen_q, batch_size, device, mode=\"random\", zero_lengths=False\n        )\n        # TODO: test zero_lengths\n        key_padding_mask = generate_random_padding_mask(\n            # seqlen_k, batch_size, device, mode=\"random\", zero_lengths=True\n            seqlen_k,\n            batch_size,\n            device,\n            mode=\"random\",\n            zero_lengths=False,\n        )\n\n        def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):\n            if add_unused:\n                another_mask = generate_random_padding_mask(max_seq_len, bs, device)\n                attn_mask = torch.logical_and(padding_mask, another_mask)\n                unused_mask = torch.logical_xor(\n                    torch.logical_or(padding_mask, another_mask), attn_mask\n                )\n            else:\n                attn_mask = padding_mask\n                unused_mask = None\n            return attn_mask, unused_mask\n\n        query_padding_mask, query_unused_mask = _gen_unused_masks(\n            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device\n        )\n        # query_padding_mask[:] = True\n        # query_unused_mask = None\n        key_padding_mask, key_unused_mask = _gen_unused_masks(\n            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device\n        )\n\n        if causal or local:\n            key_padding_mask = query_padding_mask\n\n        result = generate_qkv(\n            q,\n            k,\n            v,\n            query_padding_mask,\n            key_padding_mask,\n            qv=qv,\n            kvpacked=False,\n            query_unused_mask=query_unused_mask,\n            key_unused_mask=key_unused_mask,\n        )\n        (\n      
      q_unpad,  # 0\n            k_unpad,  # 1\n            v_unpad,  # 2\n            qv_unpad,  # 3\n            cu_seqlens_q,  # 4\n            cu_seqlens_k,  # 5\n            seqused_q,  # 6\n            seqused_k,  # 7\n            max_seqlen_q,  # 8\n            max_seqlen_k,  # 9\n            q,  # 10\n            k,  # 11\n            v,  # 12\n            qv,  # 13\n            output_pad_fn,  # 14\n            dq_pad_fn,  # 15\n            dk_pad_fn,  # 16\n        ) = result\n        q_unpad, k_unpad, v_unpad = [\n            x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)\n        ]\n        out_ref, attn_ref = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale,\n            k_descale=k_descale,\n            v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            learnable_sink=learnable_sink,\n            softcap=softcap,\n        )\n        out_pt, attn_pt = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale,\n            k_descale=k_descale,\n            v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            learnable_sink=learnable_sink,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,\n        )\n\n        print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n        print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n        if query_unused_mask is not None:\n            q_zero_masking = 
rearrange(query_unused_mask, \"b s -> b s 1 1\")\n\n        # Numerical error if we just do any arithmetic on out_ref\n        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()\n        rtol = 2 if softcap == 0.0 else 3\n\n        pack_gqa_vals = [False, True, None]\n        # num_splits_vals = [1, 3]\n        num_splits_vals = [1]\n        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):\n            out_unpad, lse = flash_attn_varlen_func(\n                q_unpad,\n                k_unpad,\n                v_unpad,\n                cu_seqlens_q=cu_seqlens_q,\n                cu_seqlens_k=cu_seqlens_k,\n                # max_seqlen_q and max_seqlen_k not needed for FA4\n                seqused_q=seqused_q,\n                seqused_k=seqused_k,\n                causal=causal,\n                window_size=window_size,\n                softcap=softcap,\n                sinks=learnable_sink,  # FA4 uses learnable_sink, not sinks\n                pack_gqa=pack_gqa,\n                return_softmax_lse=True,\n            )\n            out = output_pad_fn(out_unpad)\n            if query_unused_mask is not None:\n                out.masked_fill_(q_zero_masking, 0.0)\n            print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n            print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n            # if not causal:\n            #     print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}\")\n            # breakpoint()\n\n            # Check that FlashAttention's numerical error is at most 3x the numerical error\n            # of a Pytorch implementation.\n            assert (out - out_ref).abs().max().item() <= rtol * (\n                out_pt - out_ref\n            ).abs().max().item() + fwd_atol\n\n        if (\n            dtype != torch.float8_e4m3fn\n            and not has_qv\n            and not dv > 256\n            and not attention_chunk != 0\n            and dv == d\n     
       and not has_learnable_sink\n            and False\n        ):\n            g_unpad = torch.randn_like(out_unpad)\n            do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)\n            # import flash_attn_3_cuda\n            # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen(\n            #     g_unpad,\n            #     q_unpad,\n            #     k_unpad,\n            #     v_unpad,\n            #     out_unpad,\n            #     lse,\n            #     None,\n            #     None,\n            #     None,\n            #     cu_seqlens_q,\n            #     cu_seqlens_k,\n            #     None, None,\n            #     max_seqlen_q,\n            #     max_seqlen_k,\n            #     d ** (-0.5),\n            #     causal,\n            #     window_size[0], window_size[1],\n            #     softcap,\n            #     deterministic,\n            #     0,  # sm_margin\n            # )\n            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(\n                out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad\n            )\n            dq = dq_pad_fn(dq_unpad)\n            dk = dk_pad_fn(dk_unpad)\n            dv = dk_pad_fn(dv_unpad)\n            if key_unused_mask is not None:\n                k_zero_masking = rearrange(key_unused_mask, \"b s -> b s 1 1\")\n                dk.masked_fill_(k_zero_masking, 0.0)\n                dv.masked_fill_(k_zero_masking, 0.0)\n            if query_unused_mask is not None:\n                dq.masked_fill_(q_zero_masking, 0.0)\n            # print(f\"dO_O max diff: {(softmax_d - do_o).abs().max().item()}\")\n            # assert (softmax_d - do_o).abs().max().item() <= 1e-5\n            # assert dq_accum.abs().max().item() == 0.0\n            g = output_pad_fn(g_unpad)\n\n            # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()\n            # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), 
float(\"-inf\"))\n            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())\n            # P = torch.softmax(qk, -1)\n            # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1))\n            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())\n            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())\n            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())\n\n            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            dq_ref, dk_ref, dv_ref = torch.autograd.grad(\n                out_ref, (q_ref, k_ref, v_ref), g\n            )\n            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)\n            print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n            print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n            print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n            print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n            print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n            print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n            print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n            print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n            print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n            print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n            print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n            print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n            # breakpoint()\n            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dq - dq_ref).abs().max().item() <= rtol * (\n                dq_pt - dq_ref\n            ).abs().max().item() + dq_atol\n            dk_atol = 2 * 
(dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dk - dk_ref).abs().max().item() <= rtol * (\n                dk_pt - dk_ref\n            ).abs().max().item() + dk_atol\n            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dv - dv_ref).abs().max().item() <= rtol * (\n                dv_pt - dv_ref\n            ).abs().max().item() + dv_atol\n\n\n@pytest.mark.skipif(\n    skip_condition, reason=\"FA4 Requires compute capability of 10 or above.\"\n)\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n@pytest.mark.parametrize(\"has_learnable_sink\", [False, True])\n# @pytest.mark.parametrize(\"has_learnable_sink\", [False])\n# @pytest.mark.parametrize(\"new_kv\", [False, True])\n@pytest.mark.parametrize(\"new_kv\", [False])\n# @pytest.mark.parametrize(\"local\", [False, True])\n@pytest.mark.parametrize(\"local\", [False])\n# @pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"causal\", [True])\n# @pytest.mark.parametrize(\"seqlen_new_eq_seqlen_q\", [True, False])\n@pytest.mark.parametrize(\"seqlen_new_eq_seqlen_q\", [False])\n# @pytest.mark.parametrize(\"has_rotary_seqlens\", [False, True])\n@pytest.mark.parametrize(\"has_rotary_seqlens\", [False])\n# @pytest.mark.parametrize(\"rotary_interleaved\", [False, True])\n@pytest.mark.parametrize(\"rotary_interleaved\", [True])\n# @pytest.mark.parametrize(\"rotary_fraction\", [0.0, 0.5, 1.0])\n@pytest.mark.parametrize(\"rotary_fraction\", [0.0])\n# @pytest.mark.parametrize(\"page_size\", [None] + ([1, 4, 128]))\n# 
@pytest.mark.parametrize(\"page_size\", [None, 128])\n@pytest.mark.parametrize(\"page_size\", [128])\n# @pytest.mark.parametrize(\"has_leftpad\", [False, True])\n@pytest.mark.parametrize(\"has_leftpad\", [False])\n# @pytest.mark.parametrize(\"has_batch_idx\", [False, True])\n@pytest.mark.parametrize(\"has_batch_idx\", [False])\n# @pytest.mark.parametrize(\"varlen_q\", [False, True])\n@pytest.mark.parametrize(\"varlen_q\", [False])\n# @pytest.mark.parametrize(\"d\", [32, 59, 64, 80, 128, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\"d\", [64])\n# @pytest.mark.parametrize(\"d\", [192])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 128),\n        (1, 339),\n        (3, 1024),\n        (64, 800),\n        (64, 256),\n        (3, 799),\n        (64, 2048),\n        (16, 20000),\n        # # (1, 128 * 1024),\n        # # (16, 128 * 1024),\n        # (128, 128),\n        # (256, 512),  # To test appending KV with more than 1 block\n        # (2048, 3577),  # Enough tile to test persistent scheduler\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\ndef test_flash_attn_kvcache(\n    seqlen_q,\n    seqlen_k,\n    d,\n    varlen_q,\n    has_batch_idx,\n    has_leftpad,\n    page_size,\n    rotary_fraction,\n    rotary_interleaved,\n    has_rotary_seqlens,\n    seqlen_new_eq_seqlen_q,\n    causal,\n    local,\n    new_kv,\n    has_learnable_sink,\n    mha_type,\n    dtype,\n):\n    if page_size is not None and seqlen_k % page_size != 0:\n        pytest.skip()\n    if seqlen_q > seqlen_k and new_kv:\n        pytest.skip()\n    if not new_kv and rotary_fraction > 0.0:\n        pytest.skip()\n    if rotary_fraction == 0.0 and has_rotary_seqlens:\n        pytest.skip()\n    device = \"cuda\"\n    # set seed\n    
torch.random.manual_seed(0)\n    batch_size = 5\n    # batch_size = 1\n    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2\n    nheads = 6\n    # nheads = 1\n    # rotary_dim must be a multiple of 16, and must be <= d\n    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16\n    nheads_k = nheads if mha_type == \"mha\" else (1 if mha_type == \"mqa\" else 3)\n    assert nheads % nheads_k == 0\n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    dv_vals = [d]\n    if dtype == torch.float8_e4m3fn:\n        dv_vals = [d]\n    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) else [0]\n    attention_chunk_vals = [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n        # has_qv = d == 64 and dv >= 256\n        has_qv = False\n        q = (\n            torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)\n            .to(dtype)\n            .to(dtype_ref)\n        )\n        if has_qv:\n            qv = (\n                torch.randn(\n                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n        else:\n            qv = None\n        if varlen_q:\n            query_padding_mask = generate_random_padding_mask(\n                seqlen_q, batch_size, device, mode=\"random\"\n            )\n            q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(\n                q, query_padding_mask\n            )\n            output_pad_fn = lambda output_unpad: pad_input(\n                output_unpad, indices_q, batch_size, seqlen_q\n            )\n            qv_unpad = (\n                rearrange(qv, \"b s ... 
-> (b s) ...\")[indices_q] if has_qv else None\n            )\n        else:\n            query_padding_mask = None\n            q_unpad = q\n            qv_unpad = qv\n            cu_seqlens_q, max_seqlen_q = None, None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (\n            (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()\n        )\n        if has_learnable_sink:\n            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)\n        else:\n            learnable_sink = None\n\n        seqlen_new = (\n            seqlen_q\n            if seqlen_new_eq_seqlen_q\n            else torch.randint(1, seqlen_q + 1, (1,)).item()\n        )\n        cu_seqlens_k_new = None\n        key_new_padding_mask = None\n        if new_kv:\n            k = (\n                torch.randn(\n                    batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n            v = (\n                torch.randn(\n                    batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n            if varlen_q:  # k & v are also varlen\n                key_new_padding_mask = generate_random_padding_mask(\n                    seqlen_new, batch_size, device, mode=\"random\"\n                )\n                k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(\n                    k, key_new_padding_mask\n                )\n                v_unpad, *rest = unpad_input(v, key_new_padding_mask)\n            else:\n                k_unpad, v_unpad = k, v\n        else:\n            k, v, k_unpad, v_unpad = None, None, None, None\n        if page_size is None:\n            k_cache = (\n                torch.randn(\n                    batch_size_cache,\n  
                  seqlen_k,\n                    nheads_k,\n                    d,\n                    device=device,\n                    dtype=dtype_ref,\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n            v_cache = (\n                torch.randn(\n                    batch_size_cache,\n                    seqlen_k,\n                    nheads_k,\n                    dv,\n                    device=device,\n                    dtype=dtype_ref,\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n            page_table = None\n            num_blocks = None\n        else:\n            (\n                k_cache,\n                v_cache,\n                page_table,\n                k_cache_paged,\n                v_cache_paged,\n                num_blocks,\n            ) = _generate_block_kvcache(\n                seqlen_k,\n                page_size,\n                batch_size_cache,\n                nheads_k,\n                d,\n                dv,\n                device,\n                dtype,\n                dtype_ref,\n            )\n        cache_seqlens = torch.randint(\n            0 if new_kv else 1,\n            # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough\n            (\n                (\n                    seqlen_k\n                    - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new)\n                    + 1\n                )\n                if new_kv\n                else (seqlen_k + 1)\n            ),\n            (batch_size,),\n            dtype=torch.int32,\n            device=device,\n        )\n        if has_leftpad:\n            cache_leftpad = torch.cat(\n                [\n                    (\n                        torch.randint(\n                            0,\n                            cache_seqlens[i].item(),\n                            (1,),\n                    
        dtype=torch.int32,\n                            device=device,\n                        )\n                        if cache_seqlens[i].item() > 0\n                        else torch.zeros(1, dtype=torch.int32, device=device)\n                    )\n                    for i in range(batch_size)\n                ]\n            )\n        else:\n            cache_leftpad = None\n        if has_batch_idx:\n            cache_batch_idx = torch.randperm(\n                batch_size_cache, dtype=torch.int32, device=device\n            )[:batch_size]\n        else:\n            cache_batch_idx = None\n        arange = rearrange(torch.arange(seqlen_k, device=device), \"s -> 1 s\")\n        cache_seqlens_expanded = rearrange(cache_seqlens, \"b -> b 1\")\n        if not new_kv:\n            key_padding_mask = arange < cache_seqlens_expanded\n        else:\n            k_new_seqlens = (\n                key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new\n            )\n            key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens\n        if has_leftpad:\n            key_padding_mask = torch.logical_and(\n                key_padding_mask,\n                arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k),\n            )\n        # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)\n        rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2\n        if rotary_dim > 0:\n            angle = (\n                torch.rand(\n                    seqlen_k if page_size is None else num_blocks * page_size,\n                    rotary_dim // 2,\n                    device=device,\n                )\n                * 2\n                * math.pi\n            )\n            cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)\n            sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)\n            if causal or local:\n                q_ro = 
apply_rotary_emb(\n                    q,\n                    cos,\n                    sin,\n                    seqlen_offsets=rotary_seqlens,\n                    interleaved=rotary_interleaved,\n                )\n            else:\n                q_ro = rearrange(\n                    apply_rotary_emb(\n                        rearrange(q, \"b s h d -> b 1 (s h) d\"),\n                        cos,\n                        sin,\n                        seqlen_offsets=rotary_seqlens,\n                        interleaved=rotary_interleaved,\n                    ),\n                    \"b 1 (s h) d -> b s h d\",\n                    s=seqlen_q,\n                )\n            # q_ro = q\n            k_ro = apply_rotary_emb(\n                k,\n                cos,\n                sin,\n                seqlen_offsets=rotary_seqlens,\n                interleaved=rotary_interleaved,\n            )\n        else:\n            cos, sin = None, None\n            q_ro, k_ro = q, k\n        # k_cache[:, 64:] = -1\n        k_cache_ref = (\n            k_cache if not has_batch_idx else k_cache[cache_batch_idx]\n        ).clone()\n        v_cache_ref = (\n            v_cache if not has_batch_idx else v_cache[cache_batch_idx]\n        ).clone()\n        if new_kv:\n            update_mask = torch.logical_and(\n                cache_seqlens_expanded <= arange,\n                arange < cache_seqlens_expanded + k_new_seqlens,\n            )\n            k_to_update = rearrange(k_ro, \"b s ... -> (b s) ...\")\n            v_to_update = rearrange(v, \"b s ... 
-> (b s) ...\")\n            if varlen_q:\n                k_to_update = k_to_update[indices_k]\n                v_to_update = v_to_update[indices_k]\n            k_cache_ref[update_mask] = k_to_update\n            v_cache_ref[update_mask] = v_to_update\n        k_cache_rep = repeat(\n            k_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k\n        )\n        v_cache_rep = repeat(\n            v_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k\n        )\n        out_ref, _ = attention_ref(\n            q_ro,\n            k_cache_rep,\n            v_cache_rep,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv,\n            window_size=window_size,\n            learnable_sink=learnable_sink,\n            attention_chunk=attention_chunk,\n            key_leftpad=cache_leftpad,\n        )\n        out_pt, _ = attention_ref(\n            q_ro,\n            k_cache_rep,\n            v_cache_rep,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv,\n            window_size=window_size,\n            learnable_sink=learnable_sink,\n            attention_chunk=attention_chunk,\n            upcast=False,\n            reorder_ops=True,\n            key_leftpad=cache_leftpad,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,\n        )\n        q = q.to(dtype)\n        q_unpad = q_unpad.to(dtype) if varlen_q else None\n        k_cache = k_cache.to(dtype)\n        v_cache = v_cache.to(dtype)\n        k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None\n        v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None\n        k = k.to(dtype) if k is not None else None\n        v = v.to(dtype) if v is not None else None\n        k_unpad = k_unpad.to(dtype) if k_unpad is not None else None\n        v_unpad = v_unpad.to(dtype) if v_unpad is not None else None\n     
   qv = qv.to(dtype) if qv is not None else None\n        qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None\n        cos = cos.to(dtype) if cos is not None else None\n        sin = sin.to(dtype) if sin is not None else None\n        k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone()\n        v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone()\n        # num_splits_vals = [1, 0]\n        num_splits_vals = [1]\n        # precompute_metadata_vals = [False, True]\n        precompute_metadata_vals = [False]\n        for num_splits, precompute_metadata in itertools.product(\n            num_splits_vals, precompute_metadata_vals\n        ):\n            # if precompute_metadata:\n            #     scheduler_metadata = get_scheduler_metadata(\n            #         batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,\n            #         cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q,\n            #         cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad,\n            #         max_seqlen_k_new=seqlen_new, page_size=page_size,\n            #         causal=causal, window_size=window_size, attention_chunk=attention_chunk,\n            #         num_splits=num_splits\n            #     )\n            # else:\n            #     scheduler_metadata = None\n            scheduler_metadata = None\n            # Repeat to test metadata reuse\n            for _ in range(1 if not precompute_metadata else 2):\n                if page_size is None:\n                    k_cache.copy_(k_cache_saved)\n                    v_cache.copy_(v_cache_saved)\n                else:\n                    k_cache_paged.copy_(k_cache_saved)\n                    v_cache_paged.copy_(v_cache_saved)\n                # For FA4, use flash_attn_varlen_func directly instead of flash_attn_with_kvcache\n                # This matches the pattern from the original FA4 
test\n                out, lse = flash_attn_varlen_func(\n                    q if not varlen_q else q_unpad,\n                    k_cache if page_size is None else k_cache_paged,\n                    v_cache if page_size is None else v_cache_paged,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k=None,  # FA4 doesn't use cu_seqlens_k for KV cache\n                    # max_seqlen_q and max_seqlen_k not needed for FA4\n                    seqused_k=cache_seqlens,  # Use cache_seqlens as seqused_k\n                    page_table=page_table,\n                    causal=causal,\n                    window_size=window_size,\n                    sinks=learnable_sink,  # FA4 uses learnable_sink, not sinks\n                    softcap=0.0,\n                    pack_gqa=None,\n                    return_softmax_lse=True,\n                )\n                if varlen_q:\n                    out = output_pad_fn(out)\n                # out = flash_attn_with_kvcache(\n                #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size\n                # )\n                # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)\n                # qk = torch.einsum(\"bqhd,bkhd->bhqk\", q, k_cache_ref)\n                # m = qk.amax(-1, keepdim=True)\n                # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n                # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)\n                # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)\n                # probs = torch.softmax(qk, dim=-1)\n                print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n                print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n                print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n                print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n                # 
breakpoint()\n\n                # Check that FlashAttention's numerical error is at most twice the numerical error\n                # of a Pytorch implementation.\n                if new_kv:\n                    if page_size is None:\n                        k_cache_select = (\n                            k_cache.to(dtype_ref)\n                            if not has_batch_idx\n                            else k_cache.to(dtype_ref)[cache_batch_idx]\n                        )\n                        v_cache_select = (\n                            v_cache.to(dtype_ref)\n                            if not has_batch_idx\n                            else v_cache.to(dtype_ref)[cache_batch_idx]\n                        )\n                    else:\n                        k_cache_select = rearrange(\n                            k_cache_paged.to(dtype_ref)[\n                                (\n                                    page_table\n                                    if not has_batch_idx\n                                    else page_table[cache_batch_idx]\n                                ).flatten()\n                            ],\n                            \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n                            b=batch_size,\n                        )[:, :seqlen_k].to(dtype_ref)\n                        v_cache_select = rearrange(\n                            v_cache_paged.to(dtype_ref)[\n                                (\n                                    page_table\n                                    if not has_batch_idx\n                                    else page_table[cache_batch_idx]\n                                ).flatten()\n                            ],\n                            \"(b nblocks) block_size ... 
-> b (nblocks block_size) ...\",\n                            b=batch_size,\n                        )[:, :seqlen_k].to(dtype_ref)\n                    k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)\n                    v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)\n                    if dtype is not torch.float8_e4m3fn:\n                        assert torch.equal(v_cache_select, v_cache_ref)\n                    else:\n                        assert torch.allclose(\n                            v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3\n                        )\n                    # breakpoint()\n                    # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:\n                    if rotary_dim == 0:\n                        assert torch.equal(k_cache_select, k_cache_ref)\n                    else:\n                        # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):\n                        #     breakpoint()\n                        if dtype is not torch.float8_e4m3fn:\n                            assert torch.allclose(\n                                k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3\n                            )\n                        else:\n                            assert torch.allclose(\n                                k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1\n                            )\n                mult = 4 if dtype == torch.float8_e4m3fn else 2\n                assert (out - out_ref).abs().max().item() <= mult * (\n                    out_pt - out_ref\n                ).abs().max().item() + 1e-5\n                mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5\n                assert (out - out_ref).abs().mean().item() <= mult_mean * (\n                    out_pt - out_ref\n                ).abs().mean().item()\n\n\ndef _generate_block_kvcache(\n    seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref\n):\n    num_blocks = 
math.ceil(seqlen_k / page_size) * batch_size * 3\n    k_cache_paged = (\n        torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref)\n        .to(dtype)\n        .to(dtype_ref)\n    )\n    v_cache_paged = (\n        torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref)\n        .to(dtype)\n        .to(dtype_ref)\n    )\n    page_table = rearrange(\n        torch.randperm(num_blocks, dtype=torch.int32, device=device),\n        \"(b nblocks) -> b nblocks\",\n        b=batch_size,\n    )\n    k_cache = rearrange(\n        k_cache_paged[page_table.flatten()],\n        \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    v_cache = rearrange(\n        v_cache_paged[page_table.flatten()],\n        \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_fused_add_rmsnorm.py",
    "content": "import itertools\n\nimport pytest\nimport torch\n\nfrom sglang.jit_kernel.utils import get_ci_test_range\n\n\ndef sglang_jit_fused_add_rmsnorm(\n    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float\n) -> None:\n    from sglang.jit_kernel.norm import fused_add_rmsnorm\n\n    fused_add_rmsnorm(input, residual, weight, eps)\n\n\ndef flashinfer_fused_add_rmsnorm(\n    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float\n) -> None:\n    from flashinfer.norm import fused_add_rmsnorm\n\n    fused_add_rmsnorm(input, residual, weight, eps=eps)\n\n\nBS_LIST = [2**n for n in range(0, 14)]\nBS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)]\nBS_LIST = get_ci_test_range(BS_LIST, [1, 9, 256, 4109])\nHIDDEN_SIZE_LIST = get_ci_test_range(\n    [512, 1024, 1536, 2048, 3072, 4096, 5120, 6144, 7168, 8192],\n    [512, 2048, 8192],\n)\nDEVICE = \"cuda\"\nDTYPE = torch.bfloat16\n\n\n@pytest.mark.parametrize(\n    \"batch_size,hidden_size\", list(itertools.product(BS_LIST, HIDDEN_SIZE_LIST))\n)\ndef test_fused_add_rmsnorm(batch_size: int, hidden_size: int) -> None:\n    input = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=DTYPE)\n    residual = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=DTYPE)\n    weight = torch.randn(hidden_size, device=DEVICE, dtype=DTYPE)\n\n    input_sglang = input.clone()\n    residual_sglang = residual.clone()\n    input_flashinfer = input.clone()\n    residual_flashinfer = residual.clone()\n    sglang_jit_fused_add_rmsnorm(\n        input_sglang, residual_sglang, weight, torch.finfo(torch.bfloat16).eps\n    )\n    flashinfer_fused_add_rmsnorm(\n        input_flashinfer, residual_flashinfer, weight, torch.finfo(torch.bfloat16).eps\n    )\n    torch.testing.assert_close(input_sglang, input_flashinfer, atol=1e-2, rtol=1e-2)\n    torch.testing.assert_close(\n        residual_sglang, residual_flashinfer, atol=1e-2, rtol=1e-2\n    )\n\n\nif __name__ == \"__main__\":\n  
  pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_fused_metadata_copy.py",
    "content": "\"\"\"\nComprehensive tests for JIT-compiled fused metadata copy kernels.\n\nThis test suite verifies:\n1. Single-backend fused kernel (fused_metadata_copy_cuda) - all forward modes\n2. Multi-backend fused kernel (fused_metadata_copy_multi_cuda) - 3 backends at once\n3. Correctness against reference implementations\n4. Performance benchmarks and speedup measurements\n\"\"\"\n\nimport time\n\nimport pytest\nimport torch\n\n# =============================================================================\n# Helper Functions\n# =============================================================================\n\n\ndef create_test_metadata(\n    bs: int,\n    max_len: int,\n    max_seqlen_k: int,\n    seqlens_expanded_size: int,\n    has_real_page_table: bool = False,\n    has_flashmla: bool = False,\n    device: str = \"cuda\",\n):\n    \"\"\"Create test metadata tensors matching NSA backend structure.\"\"\"\n    # Basic tensors (always present)\n    cache_seqlens_src = torch.randint(\n        1, max_len, (bs,), dtype=torch.int32, device=device\n    )\n    cu_seqlens_k_src = torch.zeros(bs + 1, dtype=torch.int32, device=device)\n    cu_seqlens_k_src[1:] = torch.cumsum(cache_seqlens_src, dim=0)\n\n    page_indices_src = torch.randint(\n        0, 1000, (bs, max_len), dtype=torch.int32, device=device\n    )\n    nsa_cache_seqlens_src = torch.randint(\n        1, max_len, (seqlens_expanded_size,), dtype=torch.int32, device=device\n    )\n    seqlens_expanded_src = torch.randint(\n        1, max_seqlen_k, (seqlens_expanded_size,), dtype=torch.int32, device=device\n    )\n    nsa_cu_seqlens_k_src = torch.zeros(\n        seqlens_expanded_size + 1, dtype=torch.int32, device=device\n    )\n    nsa_cu_seqlens_k_src[1:] = torch.cumsum(nsa_cache_seqlens_src, dim=0)\n\n    # Destination tensors\n    cache_seqlens_dst = torch.zeros(bs, dtype=torch.int32, device=device)\n    cu_seqlens_k_dst = torch.zeros(bs + 1, dtype=torch.int32, device=device)\n    page_table_1_dst = 
torch.zeros((bs, max_len + 16), dtype=torch.int32, device=device)\n    nsa_cache_seqlens_dst = torch.zeros(\n        seqlens_expanded_size, dtype=torch.int32, device=device\n    )\n    nsa_seqlens_expanded_dst = torch.zeros(\n        seqlens_expanded_size, dtype=torch.int32, device=device\n    )\n    nsa_cu_seqlens_k_dst = torch.zeros(\n        seqlens_expanded_size + 1, dtype=torch.int32, device=device\n    )\n\n    # Optional tensors\n    real_page_table_src = None\n    real_page_table_dst = None\n    if has_real_page_table:\n        real_page_table_cols = max_len // 2\n        real_page_table_src = torch.randint(\n            0, 1000, (bs, real_page_table_cols), dtype=torch.int32, device=device\n        )\n        real_page_table_dst = torch.zeros(\n            (bs, real_page_table_cols + 8), dtype=torch.int32, device=device\n        )\n\n    flashmla_num_splits_src = None\n    flashmla_num_splits_dst = None\n    flashmla_metadata_src = None\n    flashmla_metadata_dst = None\n    if has_flashmla:\n        flashmla_num_splits_src = torch.randint(\n            1, 10, (seqlens_expanded_size + 1,), dtype=torch.int32, device=device\n        )\n        flashmla_num_splits_dst = torch.zeros(\n            seqlens_expanded_size + 1, dtype=torch.int32, device=device\n        )\n        # FlashMLA metadata is typically (num_sm_parts, TileSchedulerMetaDataSize)\n        # For testing, we use a simplified size\n        flashmla_metadata_size = 128\n        flashmla_metadata_src = torch.randint(\n            0, 100, (flashmla_metadata_size,), dtype=torch.int32, device=device\n        )\n        flashmla_metadata_dst = torch.zeros(\n            flashmla_metadata_size, dtype=torch.int32, device=device\n        )\n\n    return {\n        \"src\": {\n            \"cache_seqlens\": cache_seqlens_src,\n            \"cu_seqlens_k\": cu_seqlens_k_src,\n            \"page_indices\": page_indices_src,\n            \"nsa_cache_seqlens\": nsa_cache_seqlens_src,\n            
\"seqlens_expanded\": seqlens_expanded_src,\n            \"nsa_cu_seqlens_k\": nsa_cu_seqlens_k_src,\n            \"real_page_table\": real_page_table_src,\n            \"flashmla_num_splits\": flashmla_num_splits_src,\n            \"flashmla_metadata\": flashmla_metadata_src,\n        },\n        \"dst\": {\n            \"cache_seqlens\": cache_seqlens_dst,\n            \"cu_seqlens_k\": cu_seqlens_k_dst,\n            \"page_table_1\": page_table_1_dst,\n            \"nsa_cache_seqlens\": nsa_cache_seqlens_dst,\n            \"nsa_seqlens_expanded\": nsa_seqlens_expanded_dst,\n            \"nsa_cu_seqlens_k\": nsa_cu_seqlens_k_dst,\n            \"real_page_table\": real_page_table_dst,\n            \"flashmla_num_splits\": flashmla_num_splits_dst,\n            \"flashmla_metadata\": flashmla_metadata_dst,\n        },\n    }\n\n\ndef reference_copy_decode(src, dst, max_len):\n    \"\"\"Reference implementation: individual .copy_() for DECODE mode.\"\"\"\n    bs = src[\"cache_seqlens\"].shape[0]\n    dst[\"cache_seqlens\"].copy_(src[\"cache_seqlens\"])\n    dst[\"cu_seqlens_k\"][1:].copy_(src[\"cu_seqlens_k\"][1:])\n    dst[\"page_table_1\"][:, :max_len].copy_(src[\"page_indices\"])\n    dst[\"nsa_cache_seqlens\"].copy_(src[\"nsa_cache_seqlens\"])\n    dst[\"nsa_cu_seqlens_k\"][1 : bs + 1].copy_(src[\"nsa_cu_seqlens_k\"][1 : bs + 1])\n\n    if src[\"real_page_table\"] is not None:\n        rows, cols = src[\"real_page_table\"].shape\n        dst[\"real_page_table\"][:rows, :cols].copy_(src[\"real_page_table\"])\n\n    if src[\"flashmla_num_splits\"] is not None:\n        flashmla_size = bs + 1\n        dst[\"flashmla_num_splits\"][:flashmla_size].copy_(\n            src[\"flashmla_num_splits\"][:flashmla_size]\n        )\n\n    if src[\"flashmla_metadata\"] is not None:\n        dst[\"flashmla_metadata\"].copy_(src[\"flashmla_metadata\"])\n\n\ndef reference_copy_target_verify(src, dst, max_seqlen_k, seqlens_expanded_size):\n    \"\"\"Reference implementation: 
individual .copy_() for TARGET_VERIFY mode.\"\"\"\n    bs = src[\"cache_seqlens\"].shape[0]\n    dst[\"cache_seqlens\"].copy_(src[\"cache_seqlens\"])\n    dst[\"cu_seqlens_k\"][1:].copy_(src[\"cu_seqlens_k\"][1:])\n\n    rows, cols = src[\"page_indices\"].shape\n    dst[\"page_table_1\"][:rows, :cols].copy_(src[\"page_indices\"])\n    dst[\"nsa_seqlens_expanded\"][:seqlens_expanded_size].copy_(src[\"seqlens_expanded\"])\n    dst[\"nsa_cache_seqlens\"][:seqlens_expanded_size].copy_(src[\"nsa_cache_seqlens\"])\n    dst[\"nsa_cu_seqlens_k\"][1 : seqlens_expanded_size + 1].copy_(\n        src[\"nsa_cu_seqlens_k\"][1 : seqlens_expanded_size + 1]\n    )\n\n    if src[\"real_page_table\"] is not None:\n        rows, cols = src[\"real_page_table\"].shape\n        dst[\"real_page_table\"][:rows, :cols].copy_(src[\"real_page_table\"])\n\n    if src[\"flashmla_num_splits\"] is not None:\n        flashmla_size = seqlens_expanded_size + 1\n        dst[\"flashmla_num_splits\"][:flashmla_size].copy_(\n            src[\"flashmla_num_splits\"][:flashmla_size]\n        )\n\n    if src[\"flashmla_metadata\"] is not None:\n        dst[\"flashmla_metadata\"].copy_(src[\"flashmla_metadata\"])\n\n\ndef reference_copy_draft_extend(src, dst, max_seqlen_k, seqlens_expanded_size):\n    \"\"\"Reference implementation: individual .copy_() for DRAFT_EXTEND mode.\"\"\"\n    bs = src[\"cache_seqlens\"].shape[0]\n    dst[\"cache_seqlens\"].copy_(src[\"cache_seqlens\"])\n    dst[\"cu_seqlens_k\"][1:].copy_(src[\"cu_seqlens_k\"][1:])\n\n    rows, cols = src[\"page_indices\"].shape\n    dst[\"page_table_1\"][:rows, :cols].copy_(src[\"page_indices\"])\n    dst[\"nsa_seqlens_expanded\"][:seqlens_expanded_size].copy_(src[\"seqlens_expanded\"])\n    dst[\"nsa_cache_seqlens\"][:seqlens_expanded_size].copy_(src[\"nsa_cache_seqlens\"])\n    dst[\"nsa_cu_seqlens_k\"][1 : seqlens_expanded_size + 1].copy_(\n        src[\"nsa_cu_seqlens_k\"][1 : seqlens_expanded_size + 1]\n    )\n\n    if 
src[\"real_page_table\"] is not None:\n        rows, cols = src[\"real_page_table\"].shape\n        dst[\"real_page_table\"][:rows, :cols].copy_(src[\"real_page_table\"])\n\n    if src[\"flashmla_num_splits\"] is not None:\n        flashmla_size = seqlens_expanded_size + 1\n        dst[\"flashmla_num_splits\"][:flashmla_size].copy_(\n            src[\"flashmla_num_splits\"][:flashmla_size]\n        )\n\n    if src[\"flashmla_metadata\"] is not None:\n        dst[\"flashmla_metadata\"].copy_(src[\"flashmla_metadata\"])\n\n\n# =============================================================================\n# Single-Backend Kernel Tests\n# =============================================================================\n\n\ndef test_fused_metadata_copy_dtype_validation():\n    \"\"\"Test that dtype validation rejects non-int32 tensors.\"\"\"\n    if not torch.cuda.is_available():\n        pytest.skip(\"CUDA not available\")\n\n    from sglang.jit_kernel.fused_metadata_copy import fused_metadata_copy_cuda\n\n    bs = 2\n    max_len = 128\n    max_seqlen_k = 256\n    seqlens_expanded_size = bs\n    device = \"cuda\"\n\n    # Create tensors with WRONG dtype (int64 instead of int32)\n    cache_seqlens_src_wrong = torch.randint(\n        1, max_len, (bs,), dtype=torch.int64, device=device\n    )\n    cu_seqlens_k_src = torch.zeros(bs + 1, dtype=torch.int32, device=device)\n    page_indices_src = torch.randint(\n        0, 1000, (bs, max_len), dtype=torch.int32, device=device\n    )\n    nsa_cache_seqlens_src = torch.randint(\n        1, max_len, (seqlens_expanded_size,), dtype=torch.int32, device=device\n    )\n    seqlens_expanded_src = torch.randint(\n        1, max_seqlen_k, (seqlens_expanded_size,), dtype=torch.int32, device=device\n    )\n    nsa_cu_seqlens_k_src = torch.zeros(\n        seqlens_expanded_size + 1, dtype=torch.int32, device=device\n    )\n\n    # Destination tensors (correct dtype)\n    cache_seqlens_dst = torch.zeros(bs, dtype=torch.int32, device=device)\n  
  cu_seqlens_k_dst = torch.zeros(bs + 1, dtype=torch.int32, device=device)\n    page_table_1_dst = torch.zeros((bs, max_len + 16), dtype=torch.int32, device=device)\n    nsa_cache_seqlens_dst = torch.zeros(\n        seqlens_expanded_size, dtype=torch.int32, device=device\n    )\n    nsa_seqlens_expanded_dst = torch.zeros(\n        seqlens_expanded_size, dtype=torch.int32, device=device\n    )\n    nsa_cu_seqlens_k_dst = torch.zeros(\n        seqlens_expanded_size + 1, dtype=torch.int32, device=device\n    )\n\n    # Test 1: Wrong dtype for source tensor should raise RuntimeError\n    with pytest.raises(RuntimeError, match=\"must have dtype int32\"):\n        fused_metadata_copy_cuda(\n            cache_seqlens_src_wrong,  # Wrong dtype: int64\n            cu_seqlens_k_src,\n            page_indices_src,\n            nsa_cache_seqlens_src,\n            seqlens_expanded_src,\n            nsa_cu_seqlens_k_src,\n            None,  # real_page_table_src\n            None,  # flashmla_num_splits_src\n            None,  # flashmla_metadata_src\n            cache_seqlens_dst,\n            cu_seqlens_k_dst,\n            page_table_1_dst,\n            nsa_cache_seqlens_dst,\n            nsa_seqlens_expanded_dst,\n            nsa_cu_seqlens_k_dst,\n            None,  # real_page_table_dst\n            None,  # flashmla_num_splits_dst\n            None,  # flashmla_metadata_dst\n            0,  # forward_mode\n            bs,\n            max_len,\n            max_seqlen_k,\n            seqlens_expanded_size,\n        )\n\n    # Test 2: Wrong dtype for destination tensor should also raise RuntimeError\n    cache_seqlens_src = torch.randint(\n        1, max_len, (bs,), dtype=torch.int32, device=device\n    )\n    cache_seqlens_dst_wrong = torch.zeros(bs, dtype=torch.int64, device=device)\n\n    with pytest.raises(RuntimeError, match=\"must have dtype int32\"):\n        fused_metadata_copy_cuda(\n            cache_seqlens_src,\n            cu_seqlens_k_src,\n            
page_indices_src,\n            nsa_cache_seqlens_src,\n            seqlens_expanded_src,\n            nsa_cu_seqlens_k_src,\n            None,\n            None,\n            None,\n            cache_seqlens_dst_wrong,  # Wrong dtype: int64\n            cu_seqlens_k_dst,\n            page_table_1_dst,\n            nsa_cache_seqlens_dst,\n            nsa_seqlens_expanded_dst,\n            nsa_cu_seqlens_k_dst,\n            None,\n            None,\n            None,\n            0,\n            bs,\n            max_len,\n            max_seqlen_k,\n            seqlens_expanded_size,\n        )\n\n\n@pytest.mark.parametrize(\"bs\", [1, 2, 4, 8])\n@pytest.mark.parametrize(\n    \"forward_mode\", [0]\n)  # DECODE mode only (other modes not fully tested yet)\n@pytest.mark.parametrize(\"has_real_page_table\", [False, True])\n@pytest.mark.parametrize(\"has_flashmla\", [False, True])\ndef test_fused_metadata_copy(bs, forward_mode, has_real_page_table, has_flashmla):\n    \"\"\"Test fused metadata copy kernel against reference implementation.\"\"\"\n    if not torch.cuda.is_available():\n        pytest.skip(\"CUDA not available\")\n\n    from sglang.jit_kernel.fused_metadata_copy import fused_metadata_copy_cuda\n\n    max_len = 128\n    max_seqlen_k = 256\n    seqlens_expanded_size = bs if forward_mode == 0 else bs * 2\n\n    # Create test data\n    data = create_test_metadata(\n        bs=bs,\n        max_len=max_len,\n        max_seqlen_k=max_seqlen_k,\n        seqlens_expanded_size=seqlens_expanded_size,\n        has_real_page_table=has_real_page_table,\n        has_flashmla=has_flashmla,\n    )\n\n    # Create separate destination tensors for reference and fused kernel\n    dst_ref = {k: v.clone() if v is not None else None for k, v in data[\"dst\"].items()}\n    dst_fused = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst\"].items()\n    }\n\n    # Run reference implementation\n    if forward_mode == 0:  # DECODE\n        
reference_copy_decode(data[\"src\"], dst_ref, max_len)\n    elif forward_mode == 1:  # TARGET_VERIFY\n        reference_copy_target_verify(\n            data[\"src\"], dst_ref, max_seqlen_k, seqlens_expanded_size\n        )\n    else:  # DRAFT_EXTEND\n        reference_copy_draft_extend(\n            data[\"src\"], dst_ref, max_seqlen_k, seqlens_expanded_size\n        )\n\n    # Run fused kernel\n    fused_metadata_copy_cuda(\n        data[\"src\"][\"cache_seqlens\"],\n        data[\"src\"][\"cu_seqlens_k\"],\n        data[\"src\"][\"page_indices\"],\n        data[\"src\"][\"nsa_cache_seqlens\"],\n        data[\"src\"][\"seqlens_expanded\"],\n        data[\"src\"][\"nsa_cu_seqlens_k\"],\n        data[\"src\"][\"real_page_table\"],\n        data[\"src\"][\"flashmla_num_splits\"],\n        data[\"src\"][\"flashmla_metadata\"],\n        dst_fused[\"cache_seqlens\"],\n        dst_fused[\"cu_seqlens_k\"],\n        dst_fused[\"page_table_1\"],\n        dst_fused[\"nsa_cache_seqlens\"],\n        dst_fused[\"nsa_seqlens_expanded\"],\n        dst_fused[\"nsa_cu_seqlens_k\"],\n        dst_fused[\"real_page_table\"],\n        dst_fused[\"flashmla_num_splits\"],\n        dst_fused[\"flashmla_metadata\"],\n        forward_mode,\n        bs,\n        max_len,\n        max_seqlen_k,\n        seqlens_expanded_size,\n    )\n\n    # Compare results\n    assert torch.equal(\n        dst_ref[\"cache_seqlens\"], dst_fused[\"cache_seqlens\"]\n    ), \"cache_seqlens mismatch\"\n    assert torch.equal(\n        dst_ref[\"cu_seqlens_k\"], dst_fused[\"cu_seqlens_k\"]\n    ), \"cu_seqlens_k mismatch\"\n    assert torch.equal(\n        dst_ref[\"page_table_1\"], dst_fused[\"page_table_1\"]\n    ), \"page_table_1 mismatch\"\n    assert torch.equal(\n        dst_ref[\"nsa_cache_seqlens\"], dst_fused[\"nsa_cache_seqlens\"]\n    ), \"nsa_cache_seqlens mismatch\"\n    assert torch.equal(\n        dst_ref[\"nsa_seqlens_expanded\"], dst_fused[\"nsa_seqlens_expanded\"]\n    ), \"nsa_seqlens_expanded 
mismatch\"\n    assert torch.equal(\n        dst_ref[\"nsa_cu_seqlens_k\"], dst_fused[\"nsa_cu_seqlens_k\"]\n    ), \"nsa_cu_seqlens_k mismatch\"\n\n    if has_real_page_table:\n        assert torch.equal(\n            dst_ref[\"real_page_table\"], dst_fused[\"real_page_table\"]\n        ), \"real_page_table mismatch\"\n\n    if has_flashmla:\n        assert torch.equal(\n            dst_ref[\"flashmla_num_splits\"], dst_fused[\"flashmla_num_splits\"]\n        ), \"flashmla_num_splits mismatch\"\n        assert torch.equal(\n            dst_ref[\"flashmla_metadata\"], dst_fused[\"flashmla_metadata\"]\n        ), \"flashmla_metadata mismatch\"\n\n\n@pytest.mark.parametrize(\"bs\", [16, 32])\ndef test_fused_metadata_copy_large_batch(bs):\n    \"\"\"Test with larger batch sizes.\"\"\"\n    if not torch.cuda.is_available():\n        pytest.skip(\"CUDA not available\")\n\n    from sglang.jit_kernel.fused_metadata_copy import fused_metadata_copy_cuda\n\n    forward_mode = 0  # DECODE\n    max_len = 128\n    max_seqlen_k = 256\n    seqlens_expanded_size = bs\n\n    data = create_test_metadata(\n        bs=bs,\n        max_len=max_len,\n        max_seqlen_k=max_seqlen_k,\n        seqlens_expanded_size=seqlens_expanded_size,\n        has_real_page_table=True,\n        has_flashmla=True,\n    )\n\n    dst_ref = {k: v.clone() if v is not None else None for k, v in data[\"dst\"].items()}\n    dst_fused = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst\"].items()\n    }\n\n    reference_copy_decode(data[\"src\"], dst_ref, max_len)\n\n    fused_metadata_copy_cuda(\n        data[\"src\"][\"cache_seqlens\"],\n        data[\"src\"][\"cu_seqlens_k\"],\n        data[\"src\"][\"page_indices\"],\n        data[\"src\"][\"nsa_cache_seqlens\"],\n        data[\"src\"][\"seqlens_expanded\"],\n        data[\"src\"][\"nsa_cu_seqlens_k\"],\n        data[\"src\"][\"real_page_table\"],\n        data[\"src\"][\"flashmla_num_splits\"],\n        
data[\"src\"][\"flashmla_metadata\"],\n        dst_fused[\"cache_seqlens\"],\n        dst_fused[\"cu_seqlens_k\"],\n        dst_fused[\"page_table_1\"],\n        dst_fused[\"nsa_cache_seqlens\"],\n        dst_fused[\"nsa_seqlens_expanded\"],\n        dst_fused[\"nsa_cu_seqlens_k\"],\n        dst_fused[\"real_page_table\"],\n        dst_fused[\"flashmla_num_splits\"],\n        dst_fused[\"flashmla_metadata\"],\n        forward_mode,\n        bs,\n        max_len,\n        max_seqlen_k,\n        seqlens_expanded_size,\n    )\n\n    # Verify all tensors match\n    for key in dst_ref:\n        if dst_ref[key] is not None:\n            assert torch.equal(dst_ref[key], dst_fused[key]), f\"{key} mismatch\"\n\n\n# =============================================================================\n# Multi-Backend Kernel Tests\n# =============================================================================\n\n\ndef create_test_metadata_multi(\n    bs: int,\n    max_len: int,\n    seqlens_expanded_size: int,\n    has_real_page_table: bool = False,\n    has_flashmla: bool = False,\n    device: str = \"cuda\",\n):\n    \"\"\"Create test metadata tensors for multi-backend testing.\"\"\"\n    # Source tensors (precomputed metadata)\n    cache_seqlens_src = torch.randint(\n        1, max_len, (bs,), dtype=torch.int32, device=device\n    )\n    cu_seqlens_k_src = torch.zeros(bs + 1, dtype=torch.int32, device=device)\n    cu_seqlens_k_src[1:] = torch.cumsum(cache_seqlens_src, dim=0)\n\n    page_indices_src = torch.randint(\n        0, 1000, (bs, max_len), dtype=torch.int32, device=device\n    )\n    nsa_cache_seqlens_src = torch.randint(\n        1, max_len, (seqlens_expanded_size,), dtype=torch.int32, device=device\n    )\n    nsa_cu_seqlens_k_src = torch.zeros(\n        seqlens_expanded_size + 1, dtype=torch.int32, device=device\n    )\n    nsa_cu_seqlens_k_src[1:] = torch.cumsum(nsa_cache_seqlens_src, dim=0)\n\n    # Optional tensors\n    real_page_table_src = None\n    if 
has_real_page_table:\n        real_page_table_cols = max_len // 2\n        real_page_table_src = torch.randint(\n            0, 1000, (bs, real_page_table_cols), dtype=torch.int32, device=device\n        )\n\n    flashmla_num_splits_src = None\n    flashmla_metadata_src = None\n    if has_flashmla:\n        flashmla_num_splits_src = torch.randint(\n            1, 10, (seqlens_expanded_size + 1,), dtype=torch.int32, device=device\n        )\n        flashmla_metadata_size = 128\n        flashmla_metadata_src = torch.randint(\n            0, 100, (flashmla_metadata_size,), dtype=torch.int32, device=device\n        )\n\n    # Create destination tensors for 3 backends\n    def create_dst_tensors():\n        cache_seqlens_dst = torch.zeros(bs, dtype=torch.int32, device=device)\n        cu_seqlens_k_dst = torch.zeros(bs + 1, dtype=torch.int32, device=device)\n        page_table_1_dst = torch.zeros(\n            (bs, max_len + 16), dtype=torch.int32, device=device\n        )\n        nsa_cache_seqlens_dst = torch.zeros(\n            seqlens_expanded_size, dtype=torch.int32, device=device\n        )\n        nsa_cu_seqlens_k_dst = torch.zeros(\n            seqlens_expanded_size + 1, dtype=torch.int32, device=device\n        )\n\n        real_page_table_dst = None\n        if has_real_page_table:\n            real_page_table_cols = max_len // 2\n            real_page_table_dst = torch.zeros(\n                (bs, real_page_table_cols + 8), dtype=torch.int32, device=device\n            )\n\n        flashmla_num_splits_dst = None\n        flashmla_metadata_dst = None\n        if has_flashmla:\n            flashmla_num_splits_dst = torch.zeros(\n                seqlens_expanded_size + 1, dtype=torch.int32, device=device\n            )\n            flashmla_metadata_size = 128\n            flashmla_metadata_dst = torch.zeros(\n                flashmla_metadata_size, dtype=torch.int32, device=device\n            )\n\n        return {\n            \"cache_seqlens_int32\": 
cache_seqlens_dst,\n            \"cu_seqlens_k\": cu_seqlens_k_dst,\n            \"page_table_1\": page_table_1_dst,\n            \"nsa_cache_seqlens_int32\": nsa_cache_seqlens_dst,\n            \"nsa_cu_seqlens_k\": nsa_cu_seqlens_k_dst,\n            \"real_page_table\": real_page_table_dst,\n            \"flashmla_num_splits\": flashmla_num_splits_dst,\n            \"flashmla_metadata\": flashmla_metadata_dst,\n        }\n\n    return {\n        \"src\": {\n            \"cache_seqlens\": cache_seqlens_src,\n            \"cu_seqlens_k\": cu_seqlens_k_src,\n            \"page_indices\": page_indices_src,\n            \"nsa_cache_seqlens\": nsa_cache_seqlens_src,\n            \"nsa_cu_seqlens_k\": nsa_cu_seqlens_k_src,\n            \"real_page_table\": real_page_table_src,\n            \"flashmla_num_splits\": flashmla_num_splits_src,\n            \"flashmla_metadata\": flashmla_metadata_src,\n        },\n        \"dst0\": create_dst_tensors(),\n        \"dst1\": create_dst_tensors(),\n        \"dst2\": create_dst_tensors(),\n    }\n\n\ndef reference_copy_for_loop(src, dst_list, bs, max_len):\n    \"\"\"Reference implementation: for-loop calling copy for each backend.\"\"\"\n    for dst in dst_list:\n        # Simulate what init_forward_metadata_replay_cuda_graph_from_precomputed does\n        dst[\"cache_seqlens_int32\"].copy_(src[\"cache_seqlens\"])\n        dst[\"cu_seqlens_k\"][1:].copy_(src[\"cu_seqlens_k\"][1:])\n        dst[\"page_table_1\"][:, :max_len].copy_(src[\"page_indices\"])\n        dst[\"nsa_cache_seqlens_int32\"].copy_(src[\"nsa_cache_seqlens\"])\n        dst[\"nsa_cu_seqlens_k\"][1 : bs + 1].copy_(src[\"nsa_cu_seqlens_k\"][1 : bs + 1])\n\n        if src[\"real_page_table\"] is not None:\n            rows, cols = src[\"real_page_table\"].shape\n            dst[\"real_page_table\"][:rows, :cols].copy_(src[\"real_page_table\"])\n\n        if src[\"flashmla_num_splits\"] is not None:\n            flashmla_size = bs + 1\n            
dst[\"flashmla_num_splits\"][:flashmla_size].copy_(\n                src[\"flashmla_num_splits\"][:flashmla_size]\n            )\n\n        if src[\"flashmla_metadata\"] is not None:\n            dst[\"flashmla_metadata\"].copy_(src[\"flashmla_metadata\"])\n\n\ndef test_fused_metadata_copy_multi_dtype_validation():\n    \"\"\"Test that dtype validation rejects non-int32 tensors for multi-backend kernel.\"\"\"\n    if not torch.cuda.is_available():\n        pytest.skip(\"CUDA not available\")\n\n    from sglang.jit_kernel.fused_metadata_copy import fused_metadata_copy_multi_cuda\n\n    bs = 2\n    max_len = 128\n    seqlens_expanded_size = bs\n    device = \"cuda\"\n\n    # Create source tensors - one with WRONG dtype\n    cache_seqlens_src_wrong = torch.randint(\n        1, max_len, (bs,), dtype=torch.int64, device=device  # Wrong dtype!\n    )\n    cu_seqlens_k_src = torch.zeros(bs + 1, dtype=torch.int32, device=device)\n    page_indices_src = torch.randint(\n        0, 1000, (bs, max_len), dtype=torch.int32, device=device\n    )\n    nsa_cache_seqlens_src = torch.randint(\n        1, max_len, (seqlens_expanded_size,), dtype=torch.int32, device=device\n    )\n    nsa_cu_seqlens_k_src = torch.zeros(\n        seqlens_expanded_size + 1, dtype=torch.int32, device=device\n    )\n\n    # Create destination tensors for 3 backends (all correct dtype)\n    def create_dst():\n        return {\n            \"cache_seqlens\": torch.zeros(bs, dtype=torch.int32, device=device),\n            \"cu_seqlens_k\": torch.zeros(bs + 1, dtype=torch.int32, device=device),\n            \"page_table_1\": torch.zeros(\n                (bs, max_len + 16), dtype=torch.int32, device=device\n            ),\n            \"nsa_cache_seqlens\": torch.zeros(\n                seqlens_expanded_size, dtype=torch.int32, device=device\n            ),\n            \"nsa_cu_seqlens_k\": torch.zeros(\n                seqlens_expanded_size + 1, dtype=torch.int32, device=device\n            ),\n        }\n\n 
   dst0 = create_dst()\n    dst1 = create_dst()\n    dst2 = create_dst()\n\n    # Test: Wrong dtype for source tensor should raise RuntimeError\n    with pytest.raises(RuntimeError, match=\"must have dtype int32\"):\n        fused_metadata_copy_multi_cuda(\n            cache_seqlens_src_wrong,  # Wrong dtype: int64\n            cu_seqlens_k_src,\n            page_indices_src,\n            nsa_cache_seqlens_src,\n            nsa_cu_seqlens_k_src,\n            None,  # real_page_table_src\n            None,  # flashmla_num_splits_src\n            None,  # flashmla_metadata_src\n            # Backend 0\n            dst0[\"cache_seqlens\"],\n            dst0[\"cu_seqlens_k\"],\n            dst0[\"page_table_1\"],\n            dst0[\"nsa_cache_seqlens\"],\n            dst0[\"nsa_cu_seqlens_k\"],\n            None,\n            None,\n            None,\n            # Backend 1\n            dst1[\"cache_seqlens\"],\n            dst1[\"cu_seqlens_k\"],\n            dst1[\"page_table_1\"],\n            dst1[\"nsa_cache_seqlens\"],\n            dst1[\"nsa_cu_seqlens_k\"],\n            None,\n            None,\n            None,\n            # Backend 2\n            dst2[\"cache_seqlens\"],\n            dst2[\"cu_seqlens_k\"],\n            dst2[\"page_table_1\"],\n            dst2[\"nsa_cache_seqlens\"],\n            dst2[\"nsa_cu_seqlens_k\"],\n            None,\n            None,\n            None,\n            # Parameters\n            bs,\n            max_len,\n            seqlens_expanded_size,\n        )\n\n\n@pytest.mark.parametrize(\"bs\", [1, 2, 4, 8, 16])\n@pytest.mark.parametrize(\"has_real_page_table\", [False, True])\n@pytest.mark.parametrize(\"has_flashmla\", [False, True])\ndef test_fused_metadata_copy_multi(bs, has_real_page_table, has_flashmla):\n    \"\"\"Test fused multi-backend metadata copy kernel against for-loop version.\"\"\"\n    if not torch.cuda.is_available():\n        pytest.skip(\"CUDA not available\")\n\n    from 
sglang.jit_kernel.fused_metadata_copy import fused_metadata_copy_multi_cuda\n\n    max_len = 128\n    seqlens_expanded_size = bs\n\n    # Create test data\n    data = create_test_metadata_multi(\n        bs=bs,\n        max_len=max_len,\n        seqlens_expanded_size=seqlens_expanded_size,\n        has_real_page_table=has_real_page_table,\n        has_flashmla=has_flashmla,\n    )\n\n    # Create separate destination tensors for reference (for-loop) and fused kernel\n    dst_ref_0 = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst0\"].items()\n    }\n    dst_ref_1 = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst1\"].items()\n    }\n    dst_ref_2 = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst2\"].items()\n    }\n\n    dst_fused_0 = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst0\"].items()\n    }\n    dst_fused_1 = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst1\"].items()\n    }\n    dst_fused_2 = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst2\"].items()\n    }\n\n    # Run reference implementation (for-loop)\n    torch.cuda.synchronize()\n    loop_start = time.perf_counter()\n    reference_copy_for_loop(data[\"src\"], [dst_ref_0, dst_ref_1, dst_ref_2], bs, max_len)\n    torch.cuda.synchronize()\n    loop_end = time.perf_counter()\n    loop_time = loop_end - loop_start\n\n    # Run fused kernel\n    torch.cuda.synchronize()\n    fused_start = time.perf_counter()\n    fused_metadata_copy_multi_cuda(\n        # Source tensors\n        data[\"src\"][\"cache_seqlens\"],\n        data[\"src\"][\"cu_seqlens_k\"],\n        data[\"src\"][\"page_indices\"],\n        data[\"src\"][\"nsa_cache_seqlens\"],\n        data[\"src\"][\"nsa_cu_seqlens_k\"],\n        data[\"src\"][\"real_page_table\"],\n        data[\"src\"][\"flashmla_num_splits\"],\n        data[\"src\"][\"flashmla_metadata\"],\n        # Destination 
tensors for backend 0\n        dst_fused_0[\"cache_seqlens_int32\"],\n        dst_fused_0[\"cu_seqlens_k\"],\n        dst_fused_0[\"page_table_1\"],\n        dst_fused_0[\"nsa_cache_seqlens_int32\"],\n        dst_fused_0[\"nsa_cu_seqlens_k\"],\n        dst_fused_0[\"real_page_table\"],\n        dst_fused_0[\"flashmla_num_splits\"],\n        dst_fused_0[\"flashmla_metadata\"],\n        # Destination tensors for backend 1\n        dst_fused_1[\"cache_seqlens_int32\"],\n        dst_fused_1[\"cu_seqlens_k\"],\n        dst_fused_1[\"page_table_1\"],\n        dst_fused_1[\"nsa_cache_seqlens_int32\"],\n        dst_fused_1[\"nsa_cu_seqlens_k\"],\n        dst_fused_1[\"real_page_table\"],\n        dst_fused_1[\"flashmla_num_splits\"],\n        dst_fused_1[\"flashmla_metadata\"],\n        # Destination tensors for backend 2\n        dst_fused_2[\"cache_seqlens_int32\"],\n        dst_fused_2[\"cu_seqlens_k\"],\n        dst_fused_2[\"page_table_1\"],\n        dst_fused_2[\"nsa_cache_seqlens_int32\"],\n        dst_fused_2[\"nsa_cu_seqlens_k\"],\n        dst_fused_2[\"real_page_table\"],\n        dst_fused_2[\"flashmla_num_splits\"],\n        dst_fused_2[\"flashmla_metadata\"],\n        # Parameters\n        bs,\n        max_len,\n        seqlens_expanded_size,\n    )\n    torch.cuda.synchronize()\n    fused_end = time.perf_counter()\n    fused_time = fused_end - fused_start\n\n    # Compare results for all 3 backends\n    speedup = loop_time / fused_time if fused_time > 0 else 0\n    print(\n        f\"\\n[VERIFY] bs={bs}, real_page_table={has_real_page_table}, flashmla={has_flashmla}\"\n    )\n    print(\n        f\"[VERIFY] Fused time: {fused_time*1000:.3f}ms, Loop time: {loop_time*1000:.3f}ms, Speedup: {speedup:.2f}x\"\n    )\n\n    max_diff = 0.0\n    all_match = True\n\n    for backend_idx, (dst_ref, dst_fused) in enumerate(\n        [\n            (dst_ref_0, dst_fused_0),\n            (dst_ref_1, dst_fused_1),\n            (dst_ref_2, dst_fused_2),\n        ]\n    ):\n   
     for key in [\n            \"cache_seqlens_int32\",\n            \"cu_seqlens_k\",\n            \"page_table_1\",\n            \"nsa_cache_seqlens_int32\",\n            \"nsa_cu_seqlens_k\",\n        ]:\n            if not torch.equal(dst_ref[key], dst_fused[key]):\n                diff = (\n                    (dst_ref[key].float() - dst_fused[key].float()).abs().max().item()\n                )\n                max_diff = max(max_diff, diff)\n                all_match = False\n                print(\n                    f\"[ERROR] Backend {backend_idx} {key}: MISMATCH! Max diff: {diff}\"\n                )\n\n        if has_real_page_table and dst_ref[\"real_page_table\"] is not None:\n            if not torch.equal(\n                dst_ref[\"real_page_table\"], dst_fused[\"real_page_table\"]\n            ):\n                diff = (\n                    (\n                        dst_ref[\"real_page_table\"].float()\n                        - dst_fused[\"real_page_table\"].float()\n                    )\n                    .abs()\n                    .max()\n                    .item()\n                )\n                max_diff = max(max_diff, diff)\n                all_match = False\n                print(\n                    f\"[ERROR] Backend {backend_idx} real_page_table: MISMATCH! 
Max diff: {diff}\"\n                )\n\n        if has_flashmla:\n            if dst_ref[\"flashmla_num_splits\"] is not None and not torch.equal(\n                dst_ref[\"flashmla_num_splits\"], dst_fused[\"flashmla_num_splits\"]\n            ):\n                diff = (\n                    (\n                        dst_ref[\"flashmla_num_splits\"].float()\n                        - dst_fused[\"flashmla_num_splits\"].float()\n                    )\n                    .abs()\n                    .max()\n                    .item()\n                )\n                max_diff = max(max_diff, diff)\n                all_match = False\n                print(\n                    f\"[ERROR] Backend {backend_idx} flashmla_num_splits: MISMATCH! Max diff: {diff}\"\n                )\n\n            if dst_ref[\"flashmla_metadata\"] is not None and not torch.equal(\n                dst_ref[\"flashmla_metadata\"], dst_fused[\"flashmla_metadata\"]\n            ):\n                diff = (\n                    (\n                        dst_ref[\"flashmla_metadata\"].float()\n                        - dst_fused[\"flashmla_metadata\"].float()\n                    )\n                    .abs()\n                    .max()\n                    .item()\n                )\n                max_diff = max(max_diff, diff)\n                all_match = False\n                print(\n                    f\"[ERROR] Backend {backend_idx} flashmla_metadata: MISMATCH! Max diff: {diff}\"\n                )\n\n    if not all_match:\n        error_msg = (\n            f\"Fused metadata copy verification FAILED! \"\n            f\"Maximum difference: {max_diff}. 
\"\n            f\"The fused kernel produces different results than the for-loop version.\"\n        )\n        print(f\"[ERROR] {error_msg}\")\n        raise AssertionError(error_msg)\n\n    print(f\"[VERIFY] Verification PASSED - all tensors match!\")\n\n\n@pytest.mark.parametrize(\"bs\", [32, 64])\ndef test_fused_metadata_copy_multi_large_batch(bs):\n    \"\"\"Test with larger batch sizes and timing comparison.\"\"\"\n    if not torch.cuda.is_available():\n        pytest.skip(\"CUDA not available\")\n\n    from sglang.jit_kernel.fused_metadata_copy import fused_metadata_copy_multi_cuda\n\n    max_len = 128\n    seqlens_expanded_size = bs\n\n    data = create_test_metadata_multi(\n        bs=bs,\n        max_len=max_len,\n        seqlens_expanded_size=seqlens_expanded_size,\n        has_real_page_table=True,\n        has_flashmla=True,\n    )\n\n    dst_ref_0 = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst0\"].items()\n    }\n    dst_ref_1 = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst1\"].items()\n    }\n    dst_ref_2 = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst2\"].items()\n    }\n\n    dst_fused_0 = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst0\"].items()\n    }\n    dst_fused_1 = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst1\"].items()\n    }\n    dst_fused_2 = {\n        k: v.clone() if v is not None else None for k, v in data[\"dst2\"].items()\n    }\n\n    # Warmup\n    for _ in range(5):\n        reference_copy_for_loop(\n            data[\"src\"], [dst_ref_0, dst_ref_1, dst_ref_2], bs, max_len\n        )\n        fused_metadata_copy_multi_cuda(\n            data[\"src\"][\"cache_seqlens\"],\n            data[\"src\"][\"cu_seqlens_k\"],\n            data[\"src\"][\"page_indices\"],\n            data[\"src\"][\"nsa_cache_seqlens\"],\n            data[\"src\"][\"nsa_cu_seqlens_k\"],\n            
data[\"src\"][\"real_page_table\"],\n            data[\"src\"][\"flashmla_num_splits\"],\n            data[\"src\"][\"flashmla_metadata\"],\n            dst_fused_0[\"cache_seqlens_int32\"],\n            dst_fused_0[\"cu_seqlens_k\"],\n            dst_fused_0[\"page_table_1\"],\n            dst_fused_0[\"nsa_cache_seqlens_int32\"],\n            dst_fused_0[\"nsa_cu_seqlens_k\"],\n            dst_fused_0[\"real_page_table\"],\n            dst_fused_0[\"flashmla_num_splits\"],\n            dst_fused_0[\"flashmla_metadata\"],\n            dst_fused_1[\"cache_seqlens_int32\"],\n            dst_fused_1[\"cu_seqlens_k\"],\n            dst_fused_1[\"page_table_1\"],\n            dst_fused_1[\"nsa_cache_seqlens_int32\"],\n            dst_fused_1[\"nsa_cu_seqlens_k\"],\n            dst_fused_1[\"real_page_table\"],\n            dst_fused_1[\"flashmla_num_splits\"],\n            dst_fused_1[\"flashmla_metadata\"],\n            dst_fused_2[\"cache_seqlens_int32\"],\n            dst_fused_2[\"cu_seqlens_k\"],\n            dst_fused_2[\"page_table_1\"],\n            dst_fused_2[\"nsa_cache_seqlens_int32\"],\n            dst_fused_2[\"nsa_cu_seqlens_k\"],\n            dst_fused_2[\"real_page_table\"],\n            dst_fused_2[\"flashmla_num_splits\"],\n            dst_fused_2[\"flashmla_metadata\"],\n            bs,\n            max_len,\n            seqlens_expanded_size,\n        )\n    torch.cuda.synchronize()\n\n    # Actual timing\n    torch.cuda.synchronize()\n    loop_start = time.perf_counter()\n    reference_copy_for_loop(data[\"src\"], [dst_ref_0, dst_ref_1, dst_ref_2], bs, max_len)\n    torch.cuda.synchronize()\n    loop_time = time.perf_counter() - loop_start\n\n    torch.cuda.synchronize()\n    fused_start = time.perf_counter()\n    fused_metadata_copy_multi_cuda(\n        data[\"src\"][\"cache_seqlens\"],\n        data[\"src\"][\"cu_seqlens_k\"],\n        data[\"src\"][\"page_indices\"],\n        data[\"src\"][\"nsa_cache_seqlens\"],\n        
data[\"src\"][\"nsa_cu_seqlens_k\"],\n        data[\"src\"][\"real_page_table\"],\n        data[\"src\"][\"flashmla_num_splits\"],\n        data[\"src\"][\"flashmla_metadata\"],\n        dst_fused_0[\"cache_seqlens_int32\"],\n        dst_fused_0[\"cu_seqlens_k\"],\n        dst_fused_0[\"page_table_1\"],\n        dst_fused_0[\"nsa_cache_seqlens_int32\"],\n        dst_fused_0[\"nsa_cu_seqlens_k\"],\n        dst_fused_0[\"real_page_table\"],\n        dst_fused_0[\"flashmla_num_splits\"],\n        dst_fused_0[\"flashmla_metadata\"],\n        dst_fused_1[\"cache_seqlens_int32\"],\n        dst_fused_1[\"cu_seqlens_k\"],\n        dst_fused_1[\"page_table_1\"],\n        dst_fused_1[\"nsa_cache_seqlens_int32\"],\n        dst_fused_1[\"nsa_cu_seqlens_k\"],\n        dst_fused_1[\"real_page_table\"],\n        dst_fused_1[\"flashmla_num_splits\"],\n        dst_fused_1[\"flashmla_metadata\"],\n        dst_fused_2[\"cache_seqlens_int32\"],\n        dst_fused_2[\"cu_seqlens_k\"],\n        dst_fused_2[\"page_table_1\"],\n        dst_fused_2[\"nsa_cache_seqlens_int32\"],\n        dst_fused_2[\"nsa_cu_seqlens_k\"],\n        dst_fused_2[\"real_page_table\"],\n        dst_fused_2[\"flashmla_num_splits\"],\n        dst_fused_2[\"flashmla_metadata\"],\n        bs,\n        max_len,\n        seqlens_expanded_size,\n    )\n    torch.cuda.synchronize()\n    fused_time = time.perf_counter() - fused_start\n\n    speedup = loop_time / fused_time if fused_time > 0 else 0\n    print(\n        f\"\\n[PERF] Large batch (bs={bs}): Fused={fused_time*1000:.3f}ms, Loop={loop_time*1000:.3f}ms, Speedup={speedup:.2f}x\"\n    )\n\n    # Verify correctness\n    for backend_idx, (dst_ref, dst_fused) in enumerate(\n        [\n            (dst_ref_0, dst_fused_0),\n            (dst_ref_1, dst_fused_1),\n            (dst_ref_2, dst_fused_2),\n        ]\n    ):\n        for key in dst_ref:\n            if dst_ref[key] is not None and dst_fused[key] is not None:\n                assert torch.equal(\n             
       dst_ref[key], dst_fused[key]\n                ), f\"Backend {backend_idx} {key} mismatch\"\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_fused_norm_scale_shift.py",
    "content": "from typing import Optional, Tuple\n\nimport pytest\nimport torch\nfrom einops import rearrange\nfrom torch import Tensor\n\nfrom sglang.jit_kernel.diffusion.cutedsl.scale_residual_norm_scale_shift import (\n    fused_norm_scale_shift,\n    fused_scale_residual_norm_scale_shift,\n)\n\nDEVICE = \"cuda\"\nSHAPE_MAP = {\n    \"1\": lambda B, S, F, D: (1,),\n    \"D\": lambda B, S, F, D: (D,),\n    \"1D\": lambda B, S, F, D: (1, D),\n    \"BD\": lambda B, S, F, D: (B, D),\n    \"11D\": lambda B, S, F, D: (1, 1, D),\n    \"B1D\": lambda B, S, F, D: (B, 1, D),\n    \"1SD\": lambda B, S, F, D: (1, S, D),\n    \"BSD\": lambda B, S, F, D: (B, S, D),\n    \"BF1D\": lambda B, S, F, D: (B, F, 1, D),\n}\nSHAPES = [\n    # (B, S, F, D)\n    (1, 115200, 1, 3072),  # Hunyuan\n    (1, 32760, 1, 1536),  # Wan\n    (1, 6, 1, 3072),  # Qwen\n    (1, 1024, 8, 3072),\n    (4, 512, 16, 3072),\n]\nDTYPES = [torch.float16, torch.bfloat16, torch.float32]\nNORM_TYPES = [\"layer\", \"rms\"]\nAFFINE_MODES = [\"D\", \"NAT\"]\nINDEX_MODES = [\"BSD\", \"1\", \"1SD\", \"BD\", \"B1D\", \"D\", \"1D\", \"11D\", \"BF1D\"]\n\n\ndef _tol(dtype: torch.dtype):\n    return 1e-5 if dtype == torch.float32 else 5e-2\n\n\n@pytest.fixture(autouse=True)\ndef cuda_setup():\n    if not torch.cuda.is_available():\n        pytest.skip(\"CUDA required\")\n    torch.cuda.manual_seed(0)\n\n\ndef _apply_scale_shift(y: Tensor, scale: Tensor, shift: Tensor) -> Tensor:\n    if scale.ndim == 4:\n        num_frame = scale.shape[1]\n        return rearrange(\n            rearrange(y, \"b (f l) d -> b f l d\", f=num_frame) * (1 + scale) + shift,\n            \"b f l d -> b (f l) d\",\n        )\n    else:\n        scale = rearrange(scale, \"b d -> b 1 d\") if scale.ndim == 2 else scale\n        shift = rearrange(shift, \"b d -> b 1 d\") if shift.ndim == 2 else shift\n        return y * (1 + scale) + shift\n\n\ndef fused_norm_scale_shift_ref(\n    x: Tensor,\n    weight: Optional[Tensor],\n    bias: 
Optional[Tensor],\n    scale: Tensor,\n    shift: Tensor,\n    norm_type: str,\n    eps: float,\n) -> Tensor:\n    original_dtype = x.dtype\n    x, weight, bias, scale, shift = (\n        v.float() if v is not None else v for v in [x, weight, bias, scale, shift]\n    )\n    if norm_type == \"layer\":\n        norm = torch.layer_norm(x, x.shape[-1:], eps=eps, weight=weight, bias=bias)\n    else:\n        norm = torch.rms_norm(x, x.shape[-1:], eps=eps, weight=weight)\n    return _apply_scale_shift(norm, scale, shift).to(original_dtype)\n\n\ndef fused_scale_residual_norm_scale_shift_ref(\n    residual: Tensor,\n    x: Tensor,\n    gate: Optional[Tensor] | int,\n    weight: Optional[Tensor],\n    bias: Optional[Tensor],\n    scale: Tensor,\n    shift: Tensor,\n    norm_type: str,\n    eps: float,\n):\n    original_dtype = x.dtype\n    residual, x, gate, weight, bias, scale, shift = (\n        v.float() if isinstance(v, Tensor) else v\n        for v in [residual, x, gate, weight, bias, scale, shift]\n    )\n    if isinstance(gate, int):\n        x = residual + gate * x\n    else:\n        if gate.ndim == 4:\n            num_frame = gate.shape[1]\n            x_fld = rearrange(x, \"b (f l) d -> b f l d\", f=num_frame)\n            x = residual + rearrange(x_fld * gate, \"b f l d -> b (f l) d\")\n        else:\n            gate = rearrange(gate, \"b d -> b 1 d\") if gate.ndim == 2 else gate\n            x = residual + gate * x\n    if norm_type == \"layer\":\n        norm = torch.layer_norm(x, x.shape[-1:], eps=eps, weight=weight, bias=bias)\n    else:\n        norm = torch.rms_norm(x, x.shape[-1:], eps=eps, weight=weight)\n    y_ref = _apply_scale_shift(norm, scale, shift)\n    return y_ref.to(original_dtype), x.to(original_dtype)\n\n\ndef _make_tensor(index_mode: str, shape: Tuple, dtype: torch.dtype):\n    if index_mode == \"NAT\":\n        return None\n    return torch.randn(*SHAPE_MAP[index_mode](*shape), device=DEVICE, dtype=dtype)\n\n\n@torch.no_grad()\ndef 
run_norm_scale_shift(\n    shape=SHAPES[0],\n    dtype=DTYPES[0],\n    affine_dtype=DTYPES[0],\n    scale_dtype=DTYPES[0],\n    shift_dtype=DTYPES[0],\n    norm_type=NORM_TYPES[0],\n    affine_mode=AFFINE_MODES[0],\n    scale_mode=\"BSD\",\n    shift_mode=\"BSD\",\n    eps=1e-5,\n):\n    x = _make_tensor(\"BSD\", shape, dtype)\n    weight = _make_tensor(affine_mode, shape, affine_dtype)\n    bias = _make_tensor(affine_mode, shape, affine_dtype)\n    scale = _make_tensor(scale_mode, shape, scale_dtype)\n    shift = _make_tensor(shift_mode, shape, shift_dtype)\n    y_dev = fused_norm_scale_shift(x, weight, bias, scale, shift, norm_type, eps)\n    y_ref = fused_norm_scale_shift_ref(x, weight, bias, scale, shift, norm_type, eps)\n    torch.testing.assert_close(y_dev, y_ref, atol=_tol(dtype), rtol=_tol(dtype))\n\n\n@torch.no_grad()\ndef run_scale_resi_norm_scale_shift(\n    shape=SHAPES[0],\n    dtype=DTYPES[0],\n    affine_dtype=DTYPES[0],\n    scale_dtype=DTYPES[0],\n    shift_dtype=DTYPES[0],\n    norm_type=NORM_TYPES[0],\n    affine_mode=AFFINE_MODES[0],\n    gate_mode=\"B1D\",\n    scale_mode=\"BSD\",\n    shift_mode=\"BSD\",\n    eps=1e-5,\n):\n    residual = _make_tensor(\"BSD\", shape, dtype)\n    x = _make_tensor(\"BSD\", shape, dtype)\n    gate = _make_tensor(gate_mode, shape, dtype)\n    weight = _make_tensor(affine_mode, shape, affine_dtype)\n    bias = _make_tensor(affine_mode, shape, affine_dtype)\n    scale = _make_tensor(scale_mode, shape, scale_dtype)\n    shift = _make_tensor(shift_mode, shape, shift_dtype)\n    y_dev, res_dev = fused_scale_residual_norm_scale_shift(\n        residual, x, gate, weight, bias, scale, shift, norm_type, eps\n    )\n    y_ref, res_ref = fused_scale_residual_norm_scale_shift_ref(\n        residual, x, gate, weight, bias, scale, shift, norm_type, eps\n    )\n    torch.testing.assert_close(y_dev, y_ref, atol=_tol(dtype), rtol=_tol(dtype))\n    torch.testing.assert_close(res_dev, res_ref, atol=_tol(dtype), 
rtol=_tol(dtype))\n\n\n@pytest.mark.parametrize(\"norm_type\", NORM_TYPES)\nclass TestFusedNormScaleShift:\n    @pytest.mark.parametrize(\"shape\", SHAPES)\n    @pytest.mark.parametrize(\"dtype\", DTYPES)\n    def test_shape_dtype(self, shape, dtype, norm_type):\n        run_norm_scale_shift(shape=shape, dtype=dtype, norm_type=norm_type)\n\n    @pytest.mark.parametrize(\"dtype\", DTYPES)\n    def test_dtype_0(self, dtype, norm_type):\n        run_norm_scale_shift(affine_dtype=dtype, norm_type=norm_type)\n\n    @pytest.mark.parametrize(\"dtype\", DTYPES)\n    def test_dtype_1(self, dtype, norm_type):\n        run_norm_scale_shift(scale_dtype=dtype, shift_dtype=dtype, norm_type=norm_type)\n\n    @pytest.mark.parametrize(\"affine_mode\", AFFINE_MODES)\n    def test_normtype_affine(self, affine_mode, norm_type):\n        run_norm_scale_shift(affine_mode=affine_mode, norm_type=norm_type)\n\n    @pytest.mark.parametrize(\"index_mode\", INDEX_MODES)\n    def test_index_mode(self, index_mode, norm_type):\n        run_norm_scale_shift(\n            scale_mode=index_mode, shift_mode=index_mode, norm_type=norm_type\n        )\n\n\n@pytest.mark.parametrize(\"norm_type\", NORM_TYPES)\nclass TestFusedScaleResidualNormScaleShift:\n    @pytest.mark.parametrize(\"shape\", SHAPES)\n    @pytest.mark.parametrize(\"dtype\", DTYPES)\n    def test_shape_dtype(self, shape, dtype, norm_type):\n        run_scale_resi_norm_scale_shift(shape=shape, dtype=dtype, norm_type=norm_type)\n\n    @pytest.mark.parametrize(\"dtype\", DTYPES)\n    def test_dtype_0(self, dtype, norm_type):\n        run_scale_resi_norm_scale_shift(affine_dtype=dtype, norm_type=norm_type)\n\n    @pytest.mark.parametrize(\"dtype\", DTYPES)\n    def test_dtype_1(self, dtype, norm_type):\n        run_scale_resi_norm_scale_shift(\n            scale_dtype=dtype, shift_dtype=dtype, norm_type=norm_type\n        )\n\n    @pytest.mark.parametrize(\"affine_mode\", AFFINE_MODES)\n    def test_normtype_affine(self, affine_mode, 
norm_type):\n        run_scale_resi_norm_scale_shift(affine_mode=affine_mode, norm_type=norm_type)\n\n    @pytest.mark.parametrize(\"index_mode\", INDEX_MODES)\n    def test_scale_shift_index_mode(self, index_mode, norm_type):\n        run_scale_resi_norm_scale_shift(\n            scale_mode=index_mode, shift_mode=index_mode, norm_type=norm_type\n        )\n\n    @pytest.mark.parametrize(\"index_mode\", INDEX_MODES)\n    def test_gate_index_mode(self, index_mode, norm_type):\n        run_scale_resi_norm_scale_shift(gate_mode=index_mode, norm_type=norm_type)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_fused_store_index_cache.py",
    "content": "\"\"\"\nTest for fused_store_index_k_cache kernel.\n\nDesign Notes:\n  1. torch.cuda.synchronize() needed after TVM FFI kernel call.\n  2. _split_buffer used buf[:, :vb].reshape(-1) which COPIES data for\n     non-contiguous slices → reference buffer stayed all-zeros.\n     Fix: use flat byte-offset indexing.\n  3. act_quant may use a different quantization scheme → generous tolerance.\n  4. FP8 E4M3 1-ULP rounding differences between CUDA hardware cast\n     (__nv_fp8_e4m3) and PyTorch .to(float8_e4m3fn) at tie-break points.\n     Adjacent FP8 representable values at the high end differ by up to 32\n     in float space (e.g. 288, 320, 352, ..., 448).\n     Need to compare dequantized values with FP8-appropriate tolerance.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Optional, Tuple\n\nimport pytest\nimport torch\n\ntry:\n    from sglang.jit_kernel.fused_store_index_cache import (\n        can_use_nsa_fused_store,\n        fused_store_index_k_cache,\n    )\n\n    HAS_FUSED = True\nexcept ImportError:\n    HAS_FUSED = False\n\ntry:\n    from sglang.srt.utils import is_hip\n\n    _is_hip = is_hip()\nexcept ImportError:\n    _is_hip = False\n\ntry:\n    from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz\n\n    _is_fp8_fnuz = is_fp8_fnuz()\nexcept ImportError:\n    _is_fp8_fnuz = False\n\nPAGE_SIZE = 64\nHEAD_DIM = 128\nFP8_E4M3_MAX = 448.0\nFP8_DTYPE = torch.float8_e4m3fn\nBYTES_PER_TOKEN = 128 + 4  # 128 fp8 bytes + 4 scale bytes\nBYTES_PER_PAGE = PAGE_SIZE * BYTES_PER_TOKEN\n\n\ndef _skip_if_unavailable(page_size: int = PAGE_SIZE):\n    if not torch.cuda.is_available():\n        pytest.skip(\"CUDA required\")\n    if _is_hip:\n        pytest.skip(\"Fused store kernel is CUDA-specific\")\n    if _is_fp8_fnuz:\n        pytest.skip(\"Fused store path disabled for FP8 FNUZ\")\n    if not hasattr(torch, \"float8_e4m3fn\"):\n        pytest.skip(\"torch.float8_e4m3fn not available\")\n    if not HAS_FUSED:\n        
pytest.skip(\"fused_store_index_cache not importable\")\n    if not can_use_nsa_fused_store(torch.bfloat16, torch.int64, page_size):\n        pytest.skip(\"JIT kernel unavailable / failed to compile\")\n\n\ndef _num_pages(loc: torch.Tensor, page_size: int, extra: int = 1) -> int:\n    return int(loc.max().item()) // page_size + 1 + extra\n\n\ndef _make_buffer(num_pages: int, page_size: int = PAGE_SIZE) -> torch.Tensor:\n    return torch.zeros(\n        (num_pages, page_size * BYTES_PER_TOKEN),\n        dtype=torch.uint8,\n        device=\"cuda\",\n    )\n\n\ndef _read_token_from_buffer(\n    buf: torch.Tensor,\n    token_idx: int,\n    page_size: int = PAGE_SIZE,\n) -> Tuple[torch.Tensor, float]:\n    \"\"\"\n    Read a single token's fp8 values and scale from the paged buffer\n    using flat byte offsets.\n    \"\"\"\n    page = token_idx // page_size\n    offset = token_idx % page_size\n    page_bytes = page_size * BYTES_PER_TOKEN\n\n    buf_flat = buf.reshape(-1)\n\n    val_start = page * page_bytes + offset * 128\n    fp8_bytes = buf_flat[val_start : val_start + 128]\n    fp8_vals = fp8_bytes.view(FP8_DTYPE).float()\n\n    scale_start = page * page_bytes + 128 * page_size + offset * 4\n    scale_bytes = buf_flat[scale_start : scale_start + 4]\n    scale = scale_bytes.view(torch.float32).item()\n\n    return fp8_vals, scale\n\n\ndef _write_token_to_buffer(\n    buf: torch.Tensor,\n    token_idx: int,\n    fp8_data: torch.Tensor,\n    scale: float,\n    page_size: int = PAGE_SIZE,\n) -> None:\n    \"\"\"\n    Write a single token's fp8 values and scale into the paged buffer\n    using flat byte offsets on buf.reshape(-1) (which is a true view\n    since buf is contiguous).\n    \"\"\"\n    page = token_idx // page_size\n    offset = token_idx % page_size\n    page_bytes = page_size * BYTES_PER_TOKEN\n\n    buf_flat = buf.reshape(-1)\n\n    val_start = page * page_bytes + offset * 128\n    buf_flat[val_start : val_start + 128] = fp8_data.view(torch.uint8)\n\n    
scale_start = page * page_bytes + 128 * page_size + offset * 4\n    scale_t = torch.tensor([scale], dtype=torch.float32, device=buf.device)\n    buf_flat[scale_start : scale_start + 4] = scale_t.view(torch.uint8)\n\n\ndef _gather_tokens(\n    buf: torch.Tensor,\n    loc: torch.Tensor,\n    page_size: int = PAGE_SIZE,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    N = loc.shape[0]\n    fp8_f32 = torch.empty((N, HEAD_DIM), dtype=torch.float32, device=buf.device)\n    scales = torch.empty((N,), dtype=torch.float32, device=buf.device)\n    for i in range(N):\n        idx = int(loc[i].item())\n        vals, s = _read_token_from_buffer(buf, idx, page_size)\n        fp8_f32[i] = vals\n        scales[i] = s\n    return fp8_f32, scales\n\n\n# Reference kernel\ndef _reference_quantize_and_store(\n    key_bf16: torch.Tensor,\n    loc: torch.Tensor,\n    num_pages: int,\n    page_size: int = PAGE_SIZE,\n) -> torch.Tensor:\n    \"\"\"\n    Reference kernel of the fused kernel's quantization:\n      abs_max = max(|row|)\n      scale   = max(1e-4, abs_max) / 448\n      fp8_val = clip(val / scale, -448, 448) -> cast to fp8\n    \"\"\"\n    N = key_bf16.shape[0]\n    key_f32 = key_bf16.float()\n    buf = _make_buffer(num_pages, page_size)\n\n    for i in range(N):\n        row = key_f32[i]\n        abs_max = row.abs().max().item()\n        scale = max(1e-4, abs_max) / FP8_E4M3_MAX\n        inv_scale = 1.0 / scale\n        quantized = (row * inv_scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)\n        quantized_fp8 = quantized.to(FP8_DTYPE)\n\n        idx = int(loc[i].item())\n        _write_token_to_buffer(buf, idx, quantized_fp8, scale, page_size)\n\n    return buf\n\n\ndef _import_act_quant():\n    try:\n        from sglang.srt.layers.attention.nsa.triton_kernel import act_quant\n\n        return act_quant\n    except Exception:\n        return None\n\n\ndef _ref_store_via_act_quant(\n    key_bf16: torch.Tensor,\n    loc: torch.Tensor,\n    num_pages: int,\n    page_size: int = 
PAGE_SIZE,\n    block_size: int = 128,\n    scale_fmt: Optional[str] = None,\n) -> Optional[torch.Tensor]:\n    act_quant = _import_act_quant()\n    if act_quant is None:\n        return None\n\n    try:\n        k_fp8, k_scale = act_quant(key_bf16, block_size, scale_fmt)\n    except TypeError:\n        k_fp8, k_scale = act_quant(key_bf16, block_size)\n\n    if k_fp8.dim() == 3 and k_fp8.shape[1] == 1:\n        k_fp8 = k_fp8.squeeze(1)\n    if k_scale is not None and k_scale.dim() == 3 and k_scale.shape[1] == 1:\n        k_scale = k_scale.squeeze(1)\n    k_scale = k_scale.view(-1).float()\n\n    buf = _make_buffer(num_pages, page_size)\n    N = key_bf16.shape[0]\n    for i in range(N):\n        idx = int(loc[i].item())\n        _write_token_to_buffer(\n            buf, idx, k_fp8[i].to(FP8_DTYPE), k_scale[i].item(), page_size\n        )\n    return buf\n\n\n# TEST 1: Fused kernel vs. its own algorithm (pure-Python reference)\n#\n# NOTE on FP8 rounding:\n#   CUDA hardware fp8 cast (__nv_fp8_e4m3) and PyTorch .to(float8_e4m3fn)\n#   may round differently at tie-break points.  This causes up to 1-ULP\n#   differences in the FP8 codes.  In FP8 E4M3, adjacent representable\n#   values at the high end differ by up to 32 in float space (e.g.\n#   288 vs 320).  
After dequantization (fp8_float * scale), the error\n#   from 1-ULP is: scale * ulp ≈ (abs_max/448) * 32 ≈ 0.07 * abs_max.\n#   For randn inputs (abs_max ≈ 3-4), this is about 0.2-0.3.\n#\n#   We therefore compare dequantized values with tolerances that\n#   accommodate 1-ULP FP8 rounding, NOT byte-exact fp8 codes.\n@pytest.mark.parametrize(\n    \"num_tokens,base_index\",\n    [(1, 0), (32, 0), (64, 0), (128, 64), (257, 65), (512, 0)],\n)\ndef test_fused_kernel_matches_own_algorithm(num_tokens: int, base_index: int):\n    \"\"\"Compare fused CUDA kernel against a pure-Python implementation\n    of the *same* quantization formula.\"\"\"\n    _skip_if_unavailable()\n    device = torch.device(\"cuda\")\n\n    key = torch.randn((num_tokens, HEAD_DIM), device=device, dtype=torch.bfloat16)\n    loc = (\n        base_index + torch.randperm(num_tokens, device=device, dtype=torch.int64)\n    ).contiguous()\n    num_pages = _num_pages(loc, PAGE_SIZE)\n\n    # Reference kernel\n    ref_buf = _reference_quantize_and_store(key, loc, num_pages)\n\n    # Fused kernel\n    out_buf = _make_buffer(num_pages)\n    fused_store_index_k_cache(key, out_buf, loc, page_size=PAGE_SIZE)\n    torch.cuda.synchronize()\n\n    out_f, out_s = _gather_tokens(out_buf, loc)\n    ref_f, ref_s = _gather_tokens(ref_buf, loc)\n\n    # 1) Scales must match tightly (same f32 formula, no rounding ambiguity)\n    torch.testing.assert_close(out_s, ref_s, rtol=1e-5, atol=1e-7)\n\n    # 2) Most FP8 codes should match; allow rare 1-ULP differences.\n    #    1-ULP at FP8 E4M3 high end = 32 in float space.\n    mismatch = out_f != ref_f\n    mismatch_frac = mismatch.float().mean().item()\n    assert mismatch_frac < 0.01, (\n        f\"Too many FP8 code mismatches: {mismatch_frac:.2%} \"\n        f\"(expected < 1% from rounding tie-breaks)\"\n    )\n\n    # 3) Where codes differ, the difference should be exactly 1 ULP.\n    #    In FP8 E4M3: if the float-cast value is V, the adjacent value\n    #    differs by 
~V * 0.1 (relative) at most.\n    if mismatch.any():\n        diff = (out_f[mismatch] - ref_f[mismatch]).abs()\n        rel_diff = diff / ref_f[mismatch].abs().clamp(min=1e-6)\n        # 1-ULP relative difference for E4M3 is at most ~12.5% (2^-3)\n        assert rel_diff.max().item() <= 0.15, (\n            f\"FP8 code difference exceeds 1-ULP: max relative diff = \"\n            f\"{rel_diff.max().item():.4f}\"\n        )\n\n    # 4) Dequantized values should be close.\n    #    Max error from 1-ULP: scale * fp8_ulp ≈ (abs_max/448) * 32\n    #    For randn abs_max ≈ 3-4: max_err ≈ 0.21 - 0.29\n    out_deq = out_f * out_s.unsqueeze(-1)\n    ref_deq = ref_f * ref_s.unsqueeze(-1)\n    torch.testing.assert_close(out_deq, ref_deq, rtol=0.15, atol=0.5)\n\n\n# TEST 2: Cross-check against act_quant\n@pytest.mark.parametrize(\"scale_fmt\", [None, \"fp32\"])\ndef test_fused_kernel_vs_act_quant_semantic(scale_fmt: Optional[str]):\n    \"\"\"Both fused kernel and act_quant should approximately reconstruct\n    the original bf16 values.\"\"\"\n    _skip_if_unavailable()\n    device = torch.device(\"cuda\")\n\n    num_tokens = 257\n    base_index = 65\n    key = torch.randn((num_tokens, HEAD_DIM), device=device, dtype=torch.bfloat16)\n    loc = (\n        base_index + torch.randperm(num_tokens, device=device, dtype=torch.int64)\n    ).contiguous()\n    num_pages = _num_pages(loc, PAGE_SIZE)\n\n    ref_buf = _ref_store_via_act_quant(key, loc, num_pages, scale_fmt=scale_fmt)\n    if ref_buf is None:\n        pytest.skip(\"act_quant not available\")\n\n    out_buf = _make_buffer(num_pages)\n    fused_store_index_k_cache(key, out_buf, loc, page_size=PAGE_SIZE)\n    torch.cuda.synchronize()\n\n    out_f, out_s = _gather_tokens(out_buf, loc)\n    ref_f, ref_s = _gather_tokens(ref_buf, loc)\n\n    out_deq = out_f * out_s.unsqueeze(-1)\n    ref_deq = ref_f * ref_s.unsqueeze(-1)\n    orig_f32 = key.float()\n\n    # Fused kernel should reconstruct original within FP8 precision\n    
torch.testing.assert_close(\n        out_deq,\n        orig_f32,\n        rtol=0.15,\n        atol=5e-2,\n        msg=\"Fused kernel dequantized values don't approximate original\",\n    )\n\n    # act_quant may use a very different scale policy.\n    try:\n        torch.testing.assert_close(\n            ref_deq,\n            orig_f32,\n            rtol=0.25,\n            atol=0.5,\n            msg=\"act_quant dequantized values don't approximate original\",\n        )\n    except AssertionError:\n        nonzero_frac = (ref_deq.abs() > 1e-6).float().mean().item()\n        if nonzero_frac < 0.5:\n            pytest.fail(\n                f\"act_quant output looks mostly zero ({nonzero_frac:.1%} nonzero).\"\n            )\n        else:\n            pytest.skip(\n                f\"act_quant uses a very different quantization scheme \"\n                f\"(scale_fmt={scale_fmt}). Fused kernel validated independently.\"\n            )\n\n    torch.testing.assert_close(\n        out_deq,\n        ref_deq,\n        rtol=0.3,\n        atol=0.5,\n        msg=\"Fused and act_quant dequantized values diverge too much\",\n    )\n\n\n# TEST 3: Roundtrip reconstruction\n@pytest.mark.parametrize(\"num_tokens\", [1, 64, 257])\ndef test_roundtrip_reconstruction(num_tokens: int):\n    _skip_if_unavailable()\n    device = torch.device(\"cuda\")\n\n    key = torch.randn((num_tokens, HEAD_DIM), device=device, dtype=torch.bfloat16)\n    loc = torch.arange(num_tokens, device=device, dtype=torch.int64)\n    num_pages = _num_pages(loc, PAGE_SIZE)\n\n    buf = _make_buffer(num_pages)\n    fused_store_index_k_cache(key, buf, loc, page_size=PAGE_SIZE)\n    torch.cuda.synchronize()\n\n    fp8_f32, scales = _gather_tokens(buf, loc)\n    reconstructed = fp8_f32 * scales.unsqueeze(-1)\n    original = key.float()\n\n    torch.testing.assert_close(reconstructed, original, rtol=0.15, atol=5e-2)\n\n    per_row_energy = reconstructed.abs().sum(dim=-1)\n    orig_energy = 
original.abs().sum(dim=-1)\n    mask = orig_energy > 0.1\n    assert (\n        per_row_energy[mask] > 0.01\n    ).all(), \"Some tokens have zero reconstruction — kernel may not be writing output\"\n\n\n# TEST 4: Boundary conditions\ndef test_single_token():\n    _skip_if_unavailable()\n    device = torch.device(\"cuda\")\n\n    key = torch.randn((1, HEAD_DIM), device=device, dtype=torch.bfloat16)\n    loc = torch.tensor([0], device=device, dtype=torch.int64)\n\n    buf = _make_buffer(1)\n    fused_store_index_k_cache(key, buf, loc, page_size=PAGE_SIZE)\n    torch.cuda.synchronize()\n\n    fp8_f32, scales = _gather_tokens(buf, loc)\n    reconstructed = fp8_f32 * scales.unsqueeze(-1)\n    torch.testing.assert_close(reconstructed, key.float(), rtol=0.15, atol=5e-2)\n\n\n# TEST 5: Zero input conditions\ndef test_zero_input():\n    _skip_if_unavailable()\n    device = torch.device(\"cuda\")\n\n    key = torch.zeros((4, HEAD_DIM), device=device, dtype=torch.bfloat16)\n    loc = torch.arange(4, device=device, dtype=torch.int64)\n\n    buf = _make_buffer(1)\n    fused_store_index_k_cache(key, buf, loc, page_size=PAGE_SIZE)\n    torch.cuda.synchronize()\n\n    fp8_f32, scales = _gather_tokens(buf, loc)\n\n    expected_scale = 1e-4 / FP8_E4M3_MAX\n    torch.testing.assert_close(\n        scales,\n        torch.full_like(scales, expected_scale),\n        rtol=1e-5,\n        atol=1e-10,\n    )\n    assert (fp8_f32 == 0).all()\n\n\n# TEST 6: Sanity check — verify reference itself writes non-zero data\ndef test_reference_writes_nonzero():\n    _skip_if_unavailable()\n    device = torch.device(\"cuda\")\n\n    key = torch.randn((8, HEAD_DIM), device=device, dtype=torch.bfloat16)\n    loc = torch.arange(8, device=device, dtype=torch.int64)\n\n    buf = _reference_quantize_and_store(key, loc, num_pages=1)\n\n    fp8_f32, scales = _gather_tokens(buf, loc)\n    deq = fp8_f32 * scales.unsqueeze(-1)\n\n    assert deq.abs().sum().item() > 0, \"Reference buffer is all zeros — error!\"\n 
   torch.testing.assert_close(deq, key.float(), rtol=0.15, atol=5e-2)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_fused_verify_triton_gdn.py",
    "content": "\"\"\"Tests for fused sigmoid gating delta rule MTP kernel (GDN target_verify).\n\nCompares the fused kernel `fused_sigmoid_gating_delta_rule_update` against\nthe reference two-step implementation:\n    1. g, beta = fused_gdn_gating(A_log, a, b, dt_bias)\n    2. o = fused_recurrent_gated_delta_rule_update(q, k, v, g, beta, ...)\n\"\"\"\n\nimport pytest\nimport torch\n\ntry:\n    from sglang.srt.layers.attention.fla.fused_gdn_gating import fused_gdn_gating\n    from sglang.srt.layers.attention.fla.fused_recurrent import (\n        fused_recurrent_gated_delta_rule_update,\n    )\n    from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (\n        fused_sigmoid_gating_delta_rule_update,\n    )\n\n    KERNELS_AVAILABLE = True\nexcept ImportError:\n    KERNELS_AVAILABLE = False\n\n\ndef _make_tensors(N, T, H, HV, K, V, device=\"cuda\", seed=2025):\n    \"\"\"Create input tensors for GDN target_verify.\"\"\"\n    torch.manual_seed(seed)\n    A_log = torch.randn(HV, dtype=torch.float32, device=device)\n    dt_bias = torch.randn(HV, dtype=torch.bfloat16, device=device)\n    a = torch.randn(1, N * T, HV, dtype=torch.bfloat16, device=device)\n    b = torch.randn(1, N * T, HV, dtype=torch.bfloat16, device=device)\n    q = torch.randn(1, N * T, H, K, dtype=torch.bfloat16, device=device)\n    k = torch.randn(1, N * T, H, K, dtype=torch.bfloat16, device=device)\n    v = torch.randn(1, N * T, HV, V, dtype=torch.bfloat16, device=device)\n    indices = torch.arange(N, dtype=torch.int32, device=device)\n    initial_state = torch.randn(N, HV, K, V, dtype=torch.float, device=device)\n    cu_seqlens = torch.arange(0, N * T + 1, T, dtype=torch.int32, device=device)\n    return A_log, dt_bias, a, b, q, k, v, initial_state, indices, cu_seqlens\n\n\ndef run_reference(\n    A_log,\n    dt_bias,\n    q,\n    k,\n    v,\n    a,\n    b,\n    initial_state_source,\n    initial_state_indices,\n    cu_seqlens,\n    disable_state_update=True,\n    
intermediate_states_buffer=None,\n    intermediate_state_indices=None,\n    cache_steps=None,\n    retrieve_parent_token=None,\n):\n    \"\"\"Reference: fused_gdn_gating + fused_recurrent_gated_delta_rule_update.\"\"\"\n    # fused_gdn_gating expects 2D [seq_len, HV]\n    a_2d = a.view(-1, a.shape[-1])\n    b_2d = b.view(-1, b.shape[-1])\n    g, beta = fused_gdn_gating(A_log, a_2d, b_2d, dt_bias)\n    # fused_recurrent expects 3D [B, T, HV]\n    g = g.view(a.shape)\n    beta = beta.view(b.shape)\n\n    # fused_recurrent requires intermediate_state_indices when cu_seqlens is used\n    if cu_seqlens is not None and intermediate_state_indices is None:\n        N = len(cu_seqlens) - 1\n        intermediate_state_indices = torch.arange(N, dtype=torch.int32, device=q.device)\n\n    return fused_recurrent_gated_delta_rule_update(\n        q=q,\n        k=k,\n        v=v,\n        g=g,\n        beta=beta,\n        initial_state_source=initial_state_source,\n        initial_state_indices=initial_state_indices,\n        cu_seqlens=cu_seqlens,\n        use_qk_l2norm_in_kernel=True,\n        disable_state_update=disable_state_update,\n        intermediate_states_buffer=intermediate_states_buffer,\n        intermediate_state_indices=intermediate_state_indices,\n        cache_steps=cache_steps,\n        retrieve_parent_token=retrieve_parent_token,\n    )\n\n\ndef run_fused_mtp(\n    A_log,\n    dt_bias,\n    q,\n    k,\n    v,\n    a,\n    b,\n    initial_state_source,\n    initial_state_indices,\n    cu_seqlens,\n    disable_state_update=True,\n    intermediate_states_buffer=None,\n    intermediate_state_indices=None,\n    cache_steps=None,\n    retrieve_parent_token=None,\n):\n    \"\"\"Fused: fused_sigmoid_gating_delta_rule_update.\"\"\"\n    return fused_sigmoid_gating_delta_rule_update(\n        A_log=A_log,\n        dt_bias=dt_bias,\n        q=q,\n        k=k,\n        v=v,\n        a=a,\n        b=b,\n        initial_state_source=initial_state_source,\n        
initial_state_indices=initial_state_indices,\n        cu_seqlens=cu_seqlens,\n        use_qk_l2norm_in_kernel=True,\n        softplus_beta=1.0,\n        softplus_threshold=20.0,\n        is_kda=False,\n        disable_state_update=disable_state_update,\n        intermediate_states_buffer=intermediate_states_buffer,\n        intermediate_state_indices=intermediate_state_indices,\n        cache_steps=cache_steps,\n        retrieve_parent_token=retrieve_parent_token,\n    )\n\n\n@pytest.mark.skipif(not KERNELS_AVAILABLE, reason=\"Kernel not available\")\n@pytest.mark.parametrize(\"N\", [1, 8, 16])\n@pytest.mark.parametrize(\"T\", [1, 4, 8])\ndef test_fused_gdn_mtp_precision(N: int, T: int):\n    \"\"\"Compare fused MTP output against reference.\"\"\"\n    H, HV, K, V = 16, 32, 128, 128\n\n    A_log, dt_bias, a, b, q, k, v, state, indices, cu_seqlens = _make_tensors(\n        N, T, H, HV, K, V\n    )\n\n    state_ref = state.clone()\n    state_fused = state.clone()\n\n    out_ref = run_reference(\n        A_log,\n        dt_bias,\n        q,\n        k,\n        v,\n        a,\n        b,\n        state_ref,\n        indices,\n        cu_seqlens,\n        disable_state_update=True,\n    )\n    out_fused = run_fused_mtp(\n        A_log,\n        dt_bias,\n        q,\n        k,\n        v,\n        a,\n        b,\n        state_fused,\n        indices,\n        cu_seqlens,\n        disable_state_update=True,\n    )\n\n    torch.testing.assert_close(out_ref, out_fused, rtol=1e-2, atol=1e-2)\n\n\n@pytest.mark.skipif(not KERNELS_AVAILABLE, reason=\"Kernels not available\")\n@pytest.mark.parametrize(\"N\", [1, 16, 128])\ndef test_mtp_single_step_decode(N: int):\n    \"\"\"Verify MTP kernel matches reference for T=1 (decode scenario).\"\"\"\n    T = 1\n    H, HV, K, V = 16, 32, 128, 128\n\n    A_log, dt_bias, a, b, q, k, v, state, indices, cu_seqlens = _make_tensors(\n        N, T, H, HV, K, V\n    )\n\n    state_ref = state.clone()\n    state_fused = state.clone()\n\n    
out_ref = run_reference(\n        A_log,\n        dt_bias,\n        q,\n        k,\n        v,\n        a,\n        b,\n        state_ref,\n        indices,\n        cu_seqlens,\n        disable_state_update=False,\n    )\n    out_fused = run_fused_mtp(\n        A_log,\n        dt_bias,\n        q,\n        k,\n        v,\n        a,\n        b,\n        state_fused,\n        indices,\n        cu_seqlens,\n        disable_state_update=False,\n    )\n\n    torch.testing.assert_close(out_ref, out_fused, rtol=1e-2, atol=1e-2)\n\n    # Also verify states match after update\n    state_diff = (state_ref.float() - state_fused.float()).abs()\n    state_max_diff = state_diff.max().item()\n    state_fail_rate = (state_diff > 0.1).float().mean().item() * 100\n    print(\n        f\"  single_step state N={N}: max_diff={state_max_diff:.2e}, \"\n        f\"fail_rate={state_fail_rate:.2f}%\"\n    )\n    assert state_fail_rate < 0.01, f\"State mismatch: fail_rate={state_fail_rate:.2f}%\"\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_gptq_marlin.py",
    "content": "import pytest\nimport torch\nfrom sgl_kernel.scalar_type import scalar_types\n\nfrom sglang.jit_kernel.gptq_marlin import gptq_marlin_gemm\nfrom sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace\nfrom sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize\n\nMNK_FACTORS = [\n    (1, 1, 1),\n    (1, 4, 8),\n    (13, 17, 67),\n    (257, 13, 11),\n]\n\n\n@pytest.mark.parametrize(\"k_chunk\", [128])\n@pytest.mark.parametrize(\"n_chunk\", [64, 256])\n@pytest.mark.parametrize(\"quant_type\", [scalar_types.uint4, scalar_types.uint4b8])\n@pytest.mark.parametrize(\"group_size\", [-1, 128])\n@pytest.mark.parametrize(\"mnk_factors\", MNK_FACTORS)\n@pytest.mark.parametrize(\"act_order\", [False, True])\ndef test_gptq_marlin_gemm(\n    k_chunk,\n    n_chunk,\n    quant_type,\n    group_size,\n    mnk_factors,\n    act_order,\n):\n    m_factor, n_factor, k_factor = mnk_factors\n    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]\n\n    size_m = m_factor\n    size_k = k_chunk * k_factor\n    size_n = n_chunk * n_factor\n\n    if act_order:\n        if group_size == -1:\n            return\n        if group_size == size_k:\n            return\n        if has_zp:\n            return\n\n    if size_k % group_size != 0:\n        return\n\n    a_input = torch.randn((size_m, size_k), dtype=torch.float16, device=\"cuda\")\n    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device=\"cuda\")\n\n    if has_zp:\n        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(\n            b_weight, quant_type, group_size\n        )\n        g_idx = None\n        sort_indices = None\n        marlin_s2 = None\n    else:\n        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(\n            b_weight, quant_type, group_size, act_order\n        )\n        marlin_zp = None\n        marlin_s2 = None\n\n    workspace = marlin_make_workspace(w_ref.device)\n\n    output = 
gptq_marlin_gemm(\n        a_input,\n        None,\n        marlin_q_w,\n        marlin_s,\n        marlin_s2,\n        marlin_zp,\n        g_idx,\n        sort_indices,\n        workspace,\n        quant_type,\n        a_input.shape[0],\n        b_weight.shape[1],\n        a_input.shape[1],\n        is_k_full=True,\n        use_atomic_add=False,\n        use_fp32_reduce=False,\n        is_zp_float=False,\n    )\n\n    output_ref = torch.matmul(a_input, w_ref)\n    torch.cuda.synchronize()\n\n    # JIT kernel should approximately match torch.matmul: compare the mean\n    # absolute error against the mean magnitude of the reference output.\n    rel_mean_diff = torch.mean(torch.abs(output - output_ref)) / torch.mean(\n        torch.abs(output_ref)\n    )\n    assert rel_mean_diff < 0.04\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_gptq_marlin_repack.py",
    "content": "import pytest\nimport torch\nfrom sgl_kernel.scalar_type import scalar_types\n\nfrom sglang.jit_kernel.gptq_marlin_repack import gptq_marlin_repack\nfrom sglang.srt.layers.quantization.utils import (\n    gptq_quantize_weights,\n    pack_rows,\n    sort_weights,\n)\nfrom sglang.test.test_marlin_utils import get_weight_perm, marlin_weights\n\nMARLIN_K_CHUNKS = [128]\nMARLIN_N_CHUNKS = [64, 256]\n\nMNK_FACTORS = [\n    (1, 1, 1),\n    (1, 4, 8),\n    (1, 7, 5),\n    (13, 17, 67),\n    (26, 37, 13),\n    (67, 13, 11),\n    (257, 13, 11),\n    (658, 13, 11),\n]\n\n\n@pytest.mark.parametrize(\"k_chunk\", MARLIN_K_CHUNKS)\n@pytest.mark.parametrize(\"n_chunk\", MARLIN_N_CHUNKS)\n@pytest.mark.parametrize(\"quant_type\", [scalar_types.uint4b8])\n@pytest.mark.parametrize(\"group_size\", [-1, 32, 64, 128])\n@pytest.mark.parametrize(\"act_order\", [False, True])\n@pytest.mark.parametrize(\"mnk_factors\", MNK_FACTORS)\ndef test_gptq_marlin_repack(\n    k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors\n):\n    m_factor, n_factor, k_factor = mnk_factors\n\n    size_k = k_chunk * k_factor\n    size_n = n_chunk * n_factor\n\n    # Filter act_order\n    if act_order:\n        if group_size == -1:\n            return\n        if group_size == size_k:\n            return\n\n    # Normalize group_size\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    if size_k % group_size != 0:\n        pytest.skip(\"size_k must be divisible by group_size\")\n\n    # Create input\n    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device=\"cuda\")\n\n    # Quantize (and apply act_order if provided)\n    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(\n        b_weight, quant_type, group_size, act_order\n    )\n\n    q_w_gptq = pack_rows(q_w, quant_type.size_bits, size_k, size_n)\n\n    # For act_order, sort the \"weights\" and \"g_idx\" so that group ids are\n    # increasing\n    sort_indices = 
torch.empty(0, dtype=torch.int, device=b_weight.device)\n    if act_order:\n        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)\n\n    marlin_layout_perm = get_weight_perm(quant_type.size_bits)\n    q_w_marlin_ref = marlin_weights(\n        q_w, size_k, size_n, quant_type.size_bits, marlin_layout_perm\n    )\n\n    # Run JIT repack kernel\n    jit_output = gptq_marlin_repack(\n        q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits\n    )\n\n    torch.cuda.synchronize()\n\n    # JIT should match the reference (computed from CPU marlin_weights)\n    torch.testing.assert_close(jit_output, q_w_marlin_ref)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_hadamard_jit.py",
    "content": "import math\n\nimport numpy as np\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom scipy.linalg import hadamard\n\nfrom sglang.jit_kernel.hadamard import (\n    hadamard_transform,\n    hadamard_transform_12n,\n    hadamard_transform_20n,\n    hadamard_transform_28n,\n    hadamard_transform_40n,\n)\n\n# Exact M×N Hadamard matrices (±1 entries) copied from\n# python/sglang/jit_kernel/csrc/fast-hadamard-transform/code_gen.py.\n# These are non-power-of-2 Hadamard matrices constructed via Paley/Williamson methods.\n# \"+\" = +1, \"-\" = -1.  Used by the _12n/_20n/_28n/_40n kernel variants.\n\n_HAD_12_STR = \"\"\"\n+-++++++++++\n--+-+-+-+-+-\n+++-++----++\n+---+--+-++-\n+++++-++----\n+-+---+--+-+\n++--+++-++--\n+--++---+--+\n++----+++-++\n+--+-++---+-\n++++----+++-\n+-+--+-++---\n\"\"\"\n\n_HAD_20_STR = \"\"\"\n+----+----++--++-++-\n-+----+---+++---+-++\n--+----+---+++-+-+-+\n---+----+---+++++-+-\n----+----++--++-++-+\n-+++++-----+--+++--+\n+-+++-+---+-+--+++--\n++-++--+---+-+--+++-\n+++-+---+---+-+--+++\n++++-----++--+-+--++\n--++-+-++-+-----++++\n---++-+-++-+---+-+++\n+---++-+-+--+--++-++\n++---++-+----+-+++-+\n-++---++-+----+++++-\n-+--+--++-+----+----\n+-+-----++-+----+---\n-+-+-+---+--+----+--\n--+-+++------+----+-\n+--+--++------+----+\n\"\"\"\n\n_HAD_28_STR = 
\"\"\"\n+------++----++-+--+-+--++--\n-+-----+++-----+-+--+-+--++-\n--+-----+++---+-+-+----+--++\n---+-----+++---+-+-+-+--+--+\n----+-----+++---+-+-+++--+--\n-----+-----++++--+-+--++--+-\n------++----++-+--+-+--++--+\n--++++-+-------++--+++-+--+-\n---++++-+-----+-++--+-+-+--+\n+---+++--+----++-++--+-+-+--\n++---++---+----++-++--+-+-+-\n+++---+----+----++-++--+-+-+\n++++--------+-+--++-++--+-+-\n-++++--------+++--++--+--+-+\n-+-++-++--++--+--------++++-\n+-+-++--+--++--+--------++++\n-+-+-++--+--++--+----+---+++\n+-+-+-++--+--+---+---++---++\n++-+-+-++--+------+--+++---+\n-++-+-+-++--+------+-++++---\n+-++-+---++--+------+-++++--\n-++--++-+-++-+++----++------\n+-++--++-+-++-+++-----+-----\n++-++---+-+-++-+++-----+----\n-++-++-+-+-+-+--+++-----+---\n--++-++++-+-+----+++-----+--\n+--++-+-++-+-+----+++-----+-\n++--++-+-++-+-+----++------+\n\"\"\"\n\n_HAD_40_STR = \"\"\"\n+-------------------+-------------------\n++-++----+-+-++++--+++-++----+-+-++++--+\n+++-++----+-+-++++--+++-++----+-+-++++--\n+-++-++----+-+-++++-+-++-++----+-+-++++-\n+--++-++----+-+-+++++--++-++----+-+-++++\n++--++-++----+-+-+++++--++-++----+-+-+++\n+++--++-++----+-+-+++++--++-++----+-+-++\n++++--++-++----+-+-+++++--++-++----+-+-+\n+++++--++-++----+-+-+++++--++-++----+-+-\n+-++++--++-++----+-++-++++--++-++----+-+\n++-++++--++-++----+-++-++++--++-++----+-\n+-+-++++--++-++----++-+-++++--++-++----+\n++-+-++++--++-++----++-+-++++--++-++----\n+-+-+-++++--++-++---+-+-+-++++--++-++---\n+--+-+-++++--++-++--+--+-+-++++--++-++--\n+---+-+-++++--++-++-+---+-+-++++--++-++-\n+----+-+-++++--++-+++----+-+-++++--++-++\n++----+-+-++++--++-+++----+-+-++++--++-+\n+++----+-+-++++--++-+++----+-+-++++--++-\n+-++----+-+-++++--+++-++----+-+-++++--++\n+--------------------+++++++++++++++++++\n++-++----+-+-++++--+--+--++++-+-+----++-\n+++-++----+-+-++++-----+--++++-+-+----++\n+-++-++----+-+-++++--+--+--++++-+-+----+\n+--++-++----+-+-++++-++--+--++++-+-+----\n++--++-++----+-+-+++--++--+--++++-+-+---\n+++--++-++----+-+-++---++--+
--++++-+-+--\n++++--++-++----+-+-+----++--+--++++-+-+-\n+++++--++-++----+-+------++--+--++++-+-+\n+-++++--++-++----+-+-+----++--+--++++-+-\n++-++++--++-++----+---+----++--+--++++-+\n+-+-++++--++-++----+-+-+----++--+--++++-\n++-+-++++--++-++------+-+----++--+--++++\n+-+-+-++++--++-++----+-+-+----++--+--+++\n+--+-+-++++--++-++---++-+-+----++--+--++\n+---+-+-++++--++-++--+++-+-+----++--+--+\n+----+-+-++++--++-++-++++-+-+----++--+--\n++----+-+-++++--++-+--++++-+-+----++--+-\n+++----+-+-++++--++----++++-+-+----++--+\n+-++----+-+-++++--++-+--++++-+-+----++--\n\"\"\"\n\n\ndef _parse_hadamard_str(s):\n    \"\"\"Parse a ±1 string matrix definition into a numpy array.\"\"\"\n    s = s.strip().replace(\"+\", \"1\").replace(\"-\", \"-1\").split()\n    return np.stack(\n        [np.fromstring(\" \".join(s[i]), dtype=np.int32, sep=\" \") for i in range(len(s))]\n    )\n\n\n# Parsed M×M special Hadamard matrices, keyed by M (the \"multiple\").\n# Copied from python/sglang/jit_kernel/csrc/fast-hadamard-transform/code_gen.py\n# (had_12_paley, had_20_will, had_28_will, had_40_tpal)\n_SPECIAL_MATRICES = {\n    12: _parse_hadamard_str(_HAD_12_STR),\n    20: _parse_hadamard_str(_HAD_20_STR),\n    28: _parse_hadamard_str(_HAD_28_STR),\n    40: _parse_hadamard_str(_HAD_40_STR),\n}\n\n\ndef hadamard_transform_ref(x, scale=1.0):\n    \"\"\"Reference impl for the general (power-of-2) hadamard_transform.\n\n    Pads dim to the next power of 2, multiplies by the full H matrix\n    via F.linear, then truncates back to the original dim.\n    \"\"\"\n    x_shape = x.shape\n    dim = x.shape[-1]\n    x = x.reshape(-1, dim)\n    log_dim = math.ceil(math.log2(dim)) if dim > 0 else 0\n    dim_padded = 2**log_dim if dim > 0 else 1\n    if dim != dim_padded:\n        x = F.pad(x, (0, dim_padded - dim))\n    H = torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device)\n    out = F.linear(x, H)\n    out = out * scale\n    return out[..., :dim].reshape(*x_shape)\n\n\ndef 
hadamard_transform_mn_ref(x, multiple, scale=1.0):\n    \"\"\"Reference impl for the M×N hadamard variants (_12n, _20n, _28n, _40n).\n\n    The kernel computes (H_M ⊗ H_N) · x via two steps:\n      1) H_N (power-of-2 Hadamard) along the N dimension\n      2) H_M (special ±1 matrix) along the M dimension\n    where dim = M * N, M = `multiple`, N = power of 2.\n    \"\"\"\n    x_shape = x.shape\n    dim = x.shape[-1]\n    x = x.reshape(-1, dim)\n\n    # The kernel requires dim % (4*M) == 0 (for vectorized memory access).\n    # See python/sglang/jit_kernel/hadamard.py: pad_multiple = 4 * 12 / 4 * 20 / etc.\n    pad_multiple = 4 * multiple\n    if dim % pad_multiple != 0:\n        pad_size = pad_multiple - dim % pad_multiple\n        x = F.pad(x, (0, pad_size))\n        dim_padded = dim + pad_size\n    else:\n        dim_padded = dim\n\n    # N = dim_padded / M, must be a power of 2\n    n = dim_padded // multiple\n    log_n = int(math.log2(n))\n    assert 2**log_n == n, f\"n={n} is not a power of 2\"\n\n    batch = x.shape[0]\n    x = x.reshape(batch, multiple, n)  # (batch, M, N)\n\n    # Step 1: apply H_N (standard power-of-2 Hadamard) along the N dimension\n    H_n = torch.tensor(hadamard(n, dtype=float), dtype=x.dtype, device=x.device)\n    x = torch.einsum(\"bmn,kn->bmk\", x, H_n)\n\n    # Step 2: apply H_M (special ±1 matrix) along the M dimension\n    H_m = torch.tensor(\n        _SPECIAL_MATRICES[multiple].astype(float), dtype=x.dtype, device=x.device\n    )\n    x = torch.einsum(\"bmn,km->bkn\", x, H_m)\n\n    x = x.reshape(batch, -1) * scale\n    return x[..., : x_shape[-1]].reshape(*x_shape)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\n    \"dim\",\n    # Power-of-2 dims from sgl-kernel/tests/test_hadamard.py (old AOT test)\n    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768],\n)\ndef test_hadamard_transform(dim, dtype):\n    device = \"cuda\"\n\n    # 
Tolerances from sgl-kernel/tests/test_hadamard.py (old AOT test)\n    if dtype == torch.float32:\n        rtol, atol = 3e-4, 3e-3\n    elif dtype == torch.bfloat16:\n        rtol, atol = 1e-2, 5e-2\n    else:  # float16\n        rtol, atol = 3e-3, 5e-3\n\n    torch.random.manual_seed(0)\n    batch_size = 15\n\n    x = torch.randn(batch_size, dim, device=device, dtype=dtype)\n    scale = 1.0 / math.sqrt(dim)\n\n    out = hadamard_transform(x, scale=scale)\n    # Compute reference in float32 from a detached copy to avoid precision loss\n    out_ref = hadamard_transform_ref(x.detach().clone().float(), scale=scale)\n\n    torch.testing.assert_close(out.float(), out_ref, rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\n    \"dim\",\n    # Non-power-of-2 dims to test the padding path\n    # (137 from sgl-kernel/tests/test_hadamard.py, 500/1000 added for coverage)\n    [137, 500, 1000],\n)\ndef test_hadamard_transform_non_power_of_two(dim, dtype):\n    device = \"cuda\"\n\n    if dtype == torch.float32:\n        rtol, atol = 3e-4, 3e-3\n    elif dtype == torch.bfloat16:\n        rtol, atol = 1e-2, 5e-2\n    else:\n        rtol, atol = 3e-3, 5e-3\n\n    torch.random.manual_seed(42)\n    batch_size = 15\n\n    x = torch.randn(batch_size, dim, device=device, dtype=dtype)\n    scale = 1.0 / math.sqrt(dim)\n\n    out = hadamard_transform(x, scale=scale)\n    out_ref = hadamard_transform_ref(x.detach().clone().float(), scale=scale)\n\n    torch.testing.assert_close(out.float(), out_ref, rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\ndef test_hadamard_transform_3d_input(dtype):\n    device = \"cuda\"\n\n    if dtype == torch.bfloat16:\n        rtol, atol = 1e-2, 5e-2\n    else:\n        rtol, atol = 3e-3, 5e-3\n\n    torch.random.manual_seed(0)\n\n    x = torch.randn(4, 8, 256, device=device, dtype=dtype)\n    scale = 1.0 / 
math.sqrt(256)\n\n    out = hadamard_transform(x, scale=scale)\n    assert out.shape == x.shape\n\n    out_ref = hadamard_transform_ref(x.detach().clone().float(), scale=scale)\n    torch.testing.assert_close(out.float(), out_ref, rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\ndef test_hadamard_transform_scale_one(dtype):\n    device = \"cuda\"\n\n    if dtype == torch.bfloat16:\n        rtol, atol = 1e-2, 5e-2\n    else:\n        rtol, atol = 3e-3, 5e-3\n\n    torch.random.manual_seed(0)\n\n    x = torch.randn(8, 64, device=device, dtype=dtype)\n\n    out = hadamard_transform(x, scale=1.0)\n    out_ref = hadamard_transform_ref(x.detach().clone().float(), scale=1.0)\n\n    torch.testing.assert_close(out.float(), out_ref, rtol=rtol, atol=atol)\n\n\n# Test dimensions for M×N variants: dim = M * N where N = 2^k.\n# M = 12/20/28/40 are the non-power-of-2 Hadamard sizes registered in\n# python/sglang/jit_kernel/hadamard.py (Hadamard12NKernel, ..., Hadamard40NKernel).\n# range(2,9) gives N = 4,8,...,256 so dims cover a practical range.\n_12N_DIMS = [12 * (2**k) for k in range(2, 9)]  # 48, 96, ... , 3072\n_20N_DIMS = [20 * (2**k) for k in range(2, 9)]  # 80, 160, ... , 5120\n_28N_DIMS = [28 * (2**k) for k in range(2, 9)]  # 112, 224, ... , 7168\n_40N_DIMS = [40 * (2**k) for k in range(2, 9)]  # 160, 320, ... 
, 10240\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"dim\", _12N_DIMS)\ndef test_hadamard_transform_12n(dim, dtype):\n    device = \"cuda\"\n\n    if dtype == torch.float32:\n        rtol, atol = 3e-4, 3e-3\n    elif dtype == torch.bfloat16:\n        rtol, atol = 1e-2, 5e-2\n    else:\n        rtol, atol = 3e-3, 5e-3\n\n    torch.random.manual_seed(0)\n    batch_size = 15\n\n    x = torch.randn(batch_size, dim, device=device, dtype=dtype)\n    scale = 1.0 / math.sqrt(dim)\n\n    out = hadamard_transform_12n(x, scale=scale)\n    out_ref = hadamard_transform_mn_ref(x.detach().clone().float(), 12, scale=scale)\n\n    torch.testing.assert_close(out.float(), out_ref, rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"dim\", _20N_DIMS)\ndef test_hadamard_transform_20n(dim, dtype):\n    device = \"cuda\"\n\n    if dtype == torch.float32:\n        rtol, atol = 3e-4, 3e-3\n    elif dtype == torch.bfloat16:\n        rtol, atol = 1e-2, 5e-2\n    else:\n        rtol, atol = 3e-3, 5e-3\n\n    torch.random.manual_seed(0)\n    batch_size = 15\n\n    x = torch.randn(batch_size, dim, device=device, dtype=dtype)\n    scale = 1.0 / math.sqrt(dim)\n\n    out = hadamard_transform_20n(x, scale=scale)\n    out_ref = hadamard_transform_mn_ref(x.detach().clone().float(), 20, scale=scale)\n\n    torch.testing.assert_close(out.float(), out_ref, rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"dim\", _28N_DIMS)\ndef test_hadamard_transform_28n(dim, dtype):\n    device = \"cuda\"\n\n    if dtype == torch.float32:\n        rtol, atol = 3e-4, 3e-3\n    elif dtype == torch.bfloat16:\n        rtol, atol = 1e-2, 5e-2\n    else:\n        rtol, atol = 3e-3, 5e-3\n\n    torch.random.manual_seed(0)\n    batch_size = 15\n\n    x = torch.randn(batch_size, 
dim, device=device, dtype=dtype)\n    scale = 1.0 / math.sqrt(dim)\n\n    out = hadamard_transform_28n(x, scale=scale)\n    out_ref = hadamard_transform_mn_ref(x.detach().clone().float(), 28, scale=scale)\n\n    torch.testing.assert_close(out.float(), out_ref, rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"dim\", _40N_DIMS)\ndef test_hadamard_transform_40n(dim, dtype):\n    device = \"cuda\"\n\n    if dtype == torch.float32:\n        rtol, atol = 3e-4, 3e-3\n    elif dtype == torch.bfloat16:\n        rtol, atol = 1e-2, 5e-2\n    else:\n        rtol, atol = 3e-3, 5e-3\n\n    torch.random.manual_seed(0)\n    batch_size = 15\n\n    x = torch.randn(batch_size, dim, device=device, dtype=dtype)\n    scale = 1.0 / math.sqrt(dim)\n\n    out = hadamard_transform_40n(x, scale=scale)\n    out_ref = hadamard_transform_mn_ref(x.detach().clone().float(), 40, scale=scale)\n\n    torch.testing.assert_close(out.float(), out_ref, rtol=rtol, atol=atol)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_moe_lora_align_block_size.py",
    "content": "# Temporarily adapted from https://github.com/vllm-project/vllm/blob/main/tests/lora/test_moe_lora_align_sum.py, will optimize in future refactor\nimport random\n\nimport pytest\nimport torch\n\n# ---------------------------------------------------------\n# IMPORT PREBUILT KERNEL\n# ---------------------------------------------------------\nfrom sglang.jit_kernel.moe_lora_align import moe_lora_align_block_size\nfrom sglang.test.ci.ci_register import register_cuda_ci\n\nregister_cuda_ci(est_time=80, suite=\"stage-b-test-large-1-gpu\")\n\n\ndef round_up(x, base):\n    return ((x + base - 1) // base) * base\n\n\ndef CEILDIV(x, y):\n    return (x + y - 1) // y\n\n\ndef sample_data(num_experts, max_loras, num_tokens, topk_num):\n    # 1. Generate TopK IDs (Flattened tokens)\n    topk_ids = torch.zeros((num_tokens, topk_num), dtype=torch.int32)\n    for i in range(num_tokens):\n        pool = list(range(num_experts))\n        random.shuffle(pool)\n        for j in range(topk_num):\n            topk_ids[i, j] = pool[j]\n\n    # 2. Generate Random Requests (Segments)\n    # We split num_tokens into random chunks to simulate a batch of requests\n    remaining_tokens = num_tokens\n    seg_lens = []\n    while remaining_tokens > 0:\n        # Random length between 1 and remaining\n        length = random.randint(1, min(32, remaining_tokens))\n        if remaining_tokens - length < 0:\n            length = remaining_tokens\n        seg_lens.append(length)\n        remaining_tokens -= length\n\n    # Ensure we cover the full range exactly (cleanup last segment)\n    if sum(seg_lens) < num_tokens:\n        seg_lens.append(num_tokens - sum(seg_lens))\n\n    # 3. Build seg_indptr [0, len1, len1+len2, ...]\n    seg_indptr = torch.cumsum(\n        torch.tensor([0] + seg_lens, dtype=torch.int32), dim=0\n    ).to(dtype=torch.int32)\n\n    # 4. 
Assign a LoRA ID to each Request\n    num_reqs = len(seg_lens)\n    req_to_lora = torch.randint(0, max_loras, (num_reqs,), dtype=torch.int32)\n\n    return (topk_ids.to(\"cuda\"), seg_indptr.to(\"cuda\"), req_to_lora.to(\"cuda\"))\n\n\n@pytest.mark.parametrize(\"num_tokens\", [100, 200, 1024, 4096])\n@pytest.mark.parametrize(\"topk_num\", [6])\n@pytest.mark.parametrize(\"num_experts\", [64, 128, 256, 512])\n@pytest.mark.parametrize(\"max_loras\", [2, 32])\n@pytest.mark.parametrize(\"block_size\", [16])\ndef test_moe_lora_align_block_size(\n    num_tokens, topk_num, num_experts, max_loras, block_size\n):\n    # sample data\n    random.seed(1)\n    torch.manual_seed(1)\n\n    if not torch.cuda.is_available():\n        pytest.skip(\"CUDA is not available, skipping moe_lora_align_block_size test.\")\n    # UPDATED: Get the new 3-step mapping tensors\n    topk_ids, seg_indptr, req_to_lora = sample_data(\n        num_experts, max_loras, num_tokens, topk_num\n    )\n\n    # compute paddings\n    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)\n    max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)\n    max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)\n\n    # init output tensors\n    sorted_token_ids = torch.full(\n        (max_loras * max_num_tokens_padded,),\n        topk_ids.numel(),\n        dtype=torch.int32,\n        device=\"cuda\",\n    )\n    expert_ids = torch.full(\n        (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device=\"cuda\"\n    )\n    num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device=\"cuda\")\n    adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device=\"cuda\")\n    lora_ids = torch.arange(max_loras, dtype=torch.int32, device=\"cuda\")\n\n    # UPDATED: Call kernel with new signature\n    moe_lora_align_block_size(\n        topk_ids,\n        seg_indptr,  # Arg 2: Pointers\n        req_to_lora,  # Arg 3: Request Map\n        
num_experts,\n        block_size,\n        max_loras,\n        max_num_tokens_padded,\n        max_num_m_blocks,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_pad,\n        adapter_enabled,\n        lora_ids,\n        None,\n    )\n\n    # verify values\n    expert_ids = expert_ids.view(max_loras, -1)\n    sorted_token_ids = sorted_token_ids.view(max_loras, -1, block_size)\n\n    # Reconstruct token-level ownership for verification logic\n    # We expand req_to_lora back to [num_tokens] on CPU just to check correctness\n    # This proves the kernel (which used the compressed format) produced the right result\n    cpu_seg_indptr = seg_indptr.cpu()\n    cpu_req_to_lora = req_to_lora.cpu()\n    token_ownership = torch.zeros(num_tokens, dtype=torch.int32)\n\n    for r in range(len(cpu_req_to_lora)):\n        start = cpu_seg_indptr[r]\n        end = cpu_seg_indptr[r + 1]\n        token_ownership[start:end] = cpu_req_to_lora[r]\n\n    token_ownership = token_ownership.to(\"cuda\")\n\n    for lora_idx in range(max_loras):\n        # Count how many tokens actually belong to this LoRA\n        expected_count = (token_ownership == lora_idx).sum().item()\n\n        # Verify the kernel processed a reasonable number of tokens (sanity check)\n        # Note: num_tokens_post_pad includes padding, so it might be larger than expected_count\n        assert num_tokens_post_pad[lora_idx].item() >= expected_count * topk_num\n\n        for token_idx in range(sorted_token_ids.size(1)):\n            block = sorted_token_ids[lora_idx][token_idx]\n            # Valid indices are those less than total numel\n            indices = block[block != topk_ids.numel()]\n\n            if indices.numel() > 0:\n                # 1. Verify routing: Does the token actually route to this expert?\n                expert_id = expert_ids[lora_idx][token_idx]\n                assert torch.all(topk_ids.view(-1)[indices] == expert_id)\n\n                # 2. 
Verify ownership: Did the kernel grab the correct tokens for this LoRA?\n                # The indices in 'sorted_token_ids' point to the flattened [token, topk] array.\n                # We divide by topk_num to get the original token index.\n                original_token_indices = indices // topk_num\n\n                # Check that all tokens in this block truly belong to 'lora_idx'\n                actual_owners = token_ownership[original_token_indices]\n                assert torch.all(\n                    actual_owners == lora_idx\n                ), f\"Kernel put tokens from LoRA {actual_owners} into block for LoRA {lora_idx}\"\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py",
    "content": "import itertools\n\nimport pytest\nimport torch\nfrom sgl_kernel.scalar_type import scalar_types\n\nfrom sglang.jit_kernel.moe_wna16_marlin import moe_wna16_marlin_gemm\nfrom sglang.srt.layers.moe.fused_moe_triton import moe_align_block_size\nfrom sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize\n\n\ndef _has_aot_moe_wna16_marlin_gemm() -> bool:\n    return hasattr(torch.ops.sgl_kernel, \"moe_wna16_marlin_gemm\") and hasattr(\n        torch.ops.sgl_kernel.moe_wna16_marlin_gemm, \"default\"\n    )\n\n\nAOT_AVAILABLE = _has_aot_moe_wna16_marlin_gemm()\n\n\ndef stack_and_dev(tensors: list[torch.Tensor]):\n    dev = tensors[0].device\n    return torch.stack(tensors, dim=0).to(dev)\n\n\ndef _get_scalar_type(num_bits: int, has_zp: bool):\n    if has_zp:\n        assert num_bits == 4\n        return scalar_types.uint4\n    else:\n        return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128\n\n\ndef _setup_moe_weights(e, n, k, quant_type, group_size, act_order, dtype):\n    \"\"\"Set up quantized MoE weights for a single gate (e experts, output n, input k).\"\"\"\n    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]\n\n    w = torch.randn((e, n, k), device=\"cuda\", dtype=dtype) / 20\n\n    w_ref_l = []\n    qweight_l = []\n    scales_l = []\n    zeros_l = []\n    g_idx_l = []\n    sort_indices_l = []\n\n    for i in range(e):\n        if has_zp:\n            w_ref, qweight, scales, zeros = awq_marlin_quantize(\n                w[i].transpose(1, 0), quant_type, group_size\n            )\n            w_ref_l.append(w_ref.T)\n            qweight_l.append(qweight)\n            scales_l.append(scales)\n            zeros_l.append(zeros)\n        else:\n            test_perm = torch.randperm(k)\n            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(\n                w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm\n            )\n            
w_ref_l.append(w_ref.T)\n            qweight_l.append(qweight)\n            scales_l.append(scales)\n            g_idx_l.append(g_idx)\n            sort_indices_l.append(sort_indices)\n\n    w_ref = stack_and_dev(w_ref_l)\n    qweight = stack_and_dev(qweight_l).contiguous()\n    scales = stack_and_dev(scales_l)\n    g_idx = stack_and_dev(g_idx_l) if g_idx_l else None\n    sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None\n    zeros = stack_and_dev(zeros_l) if zeros_l else None\n\n    return w_ref, qweight, scales, zeros, g_idx, sort_indices\n\n\ndef _run_single_gemm(\n    fn,\n    a,\n    c,\n    qweight,\n    scales,\n    zeros,\n    g_idx,\n    sort_indices,\n    workspace,\n    sorted_token_ids,\n    expert_ids,\n    num_tokens_post_padded,\n    topk_weights,\n    quant_type,\n    block_size_m,\n    topk,\n    size_m,\n    size_n,\n    size_k,\n    mul_topk_weights,\n    is_k_full,\n    use_atomic_add,\n):\n    return fn(\n        a,\n        c,\n        qweight,\n        None,  # b_bias\n        scales,\n        None,  # global_scale\n        zeros,\n        g_idx,\n        sort_indices,\n        workspace,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_padded,\n        topk_weights,\n        moe_block_size=block_size_m,\n        top_k=topk,\n        mul_topk_weights=mul_topk_weights,\n        is_ep=False,\n        b_q_type=quant_type,\n        size_m=size_m,\n        size_n=size_n,\n        size_k=size_k,\n        is_k_full=is_k_full,\n        use_atomic_add=use_atomic_add,\n        use_fp32_reduce=True,\n        is_zp_float=False,\n    )\n\n\ndef _run_single_gemm_aot(\n    a,\n    c,\n    qweight,\n    scales,\n    zeros,\n    g_idx,\n    sort_indices,\n    workspace,\n    sorted_token_ids,\n    expert_ids,\n    num_tokens_post_padded,\n    topk_weights,\n    quant_type,\n    block_size_m,\n    topk,\n    size_m,\n    size_n,\n    size_k,\n    mul_topk_weights,\n    is_k_full,\n    use_atomic_add,\n):\n    
return torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(\n        a,\n        c,\n        qweight,\n        None,  # b_bias\n        scales,\n        None,  # global_scale\n        zeros,\n        g_idx,\n        sort_indices,\n        workspace,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_padded,\n        topk_weights,\n        moe_block_size=block_size_m,\n        top_k=topk,\n        mul_topk_weights=mul_topk_weights,\n        is_ep=False,\n        b_q_type_id=quant_type.id,\n        size_m=size_m,\n        size_n=size_n,\n        size_k=size_k,\n        is_k_full=is_k_full,\n        use_atomic_add=use_atomic_add,\n        use_fp32_reduce=True,\n        is_zp_float=False,\n    )\n\n\ndef generate_test_cases():\n    m_list = [1, 123]\n    n_list = [128, 1024]\n    k_list = [256]\n    e_list = [4]\n    topk_list = [2]\n    dtype_list = [torch.float16, torch.bfloat16]\n    group_size_list = [128]\n    act_order_list = [False, True]\n    quant_type_list = [scalar_types.uint4, scalar_types.uint4b8]\n\n    all_combinations = itertools.product(\n        m_list,\n        n_list,\n        k_list,\n        e_list,\n        topk_list,\n        dtype_list,\n        group_size_list,\n        act_order_list,\n        quant_type_list,\n    )\n\n    def is_valid(m, n, k, e, topk, dtype, group_size, act_order, quant_type):\n        has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]\n        if act_order:\n            if group_size == -1 or group_size == k:\n                return False\n            if has_zp:\n                return False\n        if group_size > 0 and k % group_size != 0:\n            return False\n        return True\n\n    return [case for case in all_combinations if is_valid(*case)]\n\n\nTEST_CASES = generate_test_cases()\n\n\n@pytest.mark.parametrize(\n    \"m,n,k,e,topk,dtype,group_size,act_order,quant_type\",\n    TEST_CASES,\n    ids=[\n        f\"m{c[0]}_n{c[1]}_k{c[2]}_e{c[3]}_t{c[4]}_{c[5].__name__ if 
hasattr(c[5], '__name__') else str(c[5]).split('.')[-1]}_g{c[6]}_act{c[7]}_{c[8]}\"\n        for c in TEST_CASES\n    ],\n)\ndef test_moe_wna16_marlin_gemm(\n    m, n, k, e, topk, dtype, group_size, act_order, quant_type\n):\n    if not AOT_AVAILABLE:\n        pytest.skip(\"sgl_kernel moe_wna16_marlin_gemm AOT op not available\")\n\n    torch.manual_seed(0)\n\n    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]\n\n    a = torch.randn((m, k), device=\"cuda\", dtype=dtype) / 10\n\n    # Set up quantized weights for first gemm (gate_up: output 2*n, input k)\n    w_ref1, qweight1, scales1, zeros1, g_idx1, sort_indices1 = _setup_moe_weights(\n        e, 2 * n, k, quant_type, group_size, act_order, dtype\n    )\n\n    # Compute block_size_m\n    for block_size_m in [8, 16, 32, 48, 64]:\n        if m * topk / e / block_size_m < 0.9:\n            break\n\n    # Align tokens\n    score = torch.randn((m, e), device=\"cuda\", dtype=dtype)\n    score_softmax = torch.softmax(score, dim=-1, dtype=torch.float32)\n    topk_weights, topk_ids = torch.topk(score_softmax, topk)\n\n    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(\n        topk_ids, block_size_m, e\n    )\n\n    # Workspace\n    sms = torch.cuda.get_device_properties(\"cuda\").multi_processor_count\n    max_workspace_size = (max(2 * n, k) // 64) * (\n        sorted_token_ids.size(0) // block_size_m\n    )\n    max_workspace_size = min(max_workspace_size, sms * 4)\n    workspace = torch.zeros(\n        max_workspace_size, dtype=torch.int, device=\"cuda\", requires_grad=False\n    )\n\n    use_atomic_add = (\n        dtype == torch.half or torch.cuda.get_device_capability(\"cuda\")[0] >= 9\n    )\n\n    scalar_type = _get_scalar_type(4, has_zp)\n\n    # --- Run JIT kernel ---\n    c_jit = torch.empty((m * topk, 2 * n), dtype=dtype, device=\"cuda\")\n    c_jit = _run_single_gemm(\n        moe_wna16_marlin_gemm,\n        a,\n        c_jit,\n        qweight1,\n        
scales1,\n        zeros1,\n        g_idx1,\n        sort_indices1,\n        workspace,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_padded,\n        topk_weights,\n        scalar_type,\n        block_size_m,\n        topk,\n        m,\n        2 * n,\n        k,\n        False,\n        True,\n        use_atomic_add,\n    )\n\n    torch.cuda.synchronize()\n\n    # --- Check bitwise equality with AOT kernel ---\n    c_aot = torch.empty((m * topk, 2 * n), dtype=dtype, device=\"cuda\")\n    c_aot = _run_single_gemm_aot(\n        a,\n        c_aot,\n        qweight1,\n        scales1,\n        zeros1,\n        g_idx1,\n        sort_indices1,\n        workspace,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_padded,\n        topk_weights,\n        scalar_type,\n        block_size_m,\n        topk,\n        m,\n        2 * n,\n        k,\n        False,\n        True,\n        use_atomic_add,\n    )\n    torch.cuda.synchronize()\n    torch.testing.assert_close(c_jit, c_aot, rtol=0, atol=0)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_norm_jit.py",
    "content": "# Adapted from sgl-kernel/tests/test_norm.py\n\nimport pytest\nimport torch\n\n# JIT rmsnorm: fp16/bf16 only; hidden_size must be a multiple of 256, > 256, and <=8192\nRMSNORM_HIDDEN_SIZES = [512, 1024, 3072, 3584, 4096, 8192]\n\n# JIT fused_add_rmsnorm: fp16/bf16 only; hidden_size % 8 == 0, <=8192\nFUSED_ADD_RMSNORM_HIDDEN_SIZES = [1024, 3072, 3584, 4096, 8192]\n\nBS_LIST = [1, 19, 99, 989]\n\n\ndef _jit_rmsnorm(input, weight, output, eps):\n    from sglang.jit_kernel.norm import rmsnorm\n\n    rmsnorm(input, weight, output=output, eps=eps)\n\n\ndef _fi_rmsnorm(input, weight, out, eps):\n    from flashinfer.norm import rmsnorm\n\n    rmsnorm(input, weight, out=out, eps=eps)\n\n\ndef _jit_fused_add_rmsnorm(input, residual, weight, eps):\n    from sglang.jit_kernel.norm import fused_add_rmsnorm\n\n    fused_add_rmsnorm(input, residual, weight, eps)\n\n\ndef _fi_fused_add_rmsnorm(input, residual, weight, eps):\n    from flashinfer.norm import fused_add_rmsnorm\n\n    fused_add_rmsnorm(input, residual, weight, eps=eps)\n\n\n@pytest.mark.parametrize(\"batch_size\", BS_LIST)\n@pytest.mark.parametrize(\"hidden_size\", RMSNORM_HIDDEN_SIZES)\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"specify_out\", [True, False])\ndef test_rmsnorm_jit(batch_size, hidden_size, dtype, specify_out):\n    eps = 1e-6\n    x = torch.randn(batch_size, hidden_size, device=\"cuda\", dtype=dtype)\n    w = torch.randn(hidden_size, device=\"cuda\", dtype=dtype)\n\n    # flashinfer reference\n    x_ref = x.clone()\n    _fi_rmsnorm(x_ref, w, out=x_ref, eps=eps)\n\n    if specify_out:\n        y = torch.empty_like(x)\n        _jit_rmsnorm(x, w, output=y, eps=eps)\n    else:\n        y = x.clone()\n        _jit_rmsnorm(y, w, output=y, eps=eps)\n\n    torch.testing.assert_close(y, x_ref, rtol=1e-2, atol=1e-2)\n\n\n@pytest.mark.parametrize(\"batch_size\", BS_LIST)\n@pytest.mark.parametrize(\"hidden_size\", 
FUSED_ADD_RMSNORM_HIDDEN_SIZES)\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\ndef test_fused_add_rmsnorm_jit(batch_size, hidden_size, dtype):\n    eps = 1e-6\n    x = torch.randn(batch_size, hidden_size, dtype=dtype, device=\"cuda\")\n    residual = torch.randn_like(x)\n    weight = torch.randn(hidden_size, dtype=dtype, device=\"cuda\")\n\n    # flashinfer reference\n    x_ref = x.clone()\n    r_ref = residual.clone()\n    _fi_fused_add_rmsnorm(x_ref, r_ref, weight, eps=eps)\n\n    x_jit = x.clone()\n    r_jit = residual.clone()\n    _jit_fused_add_rmsnorm(x_jit, r_jit, weight, eps)\n\n    torch.testing.assert_close(x_jit, x_ref, rtol=1e-2, atol=1e-2)\n    torch.testing.assert_close(r_jit, r_ref, rtol=1e-2, atol=1e-2)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_nvfp4_blockwise_moe.py",
    "content": "import pytest\nimport torch\n\nfrom sglang.jit_kernel.nvfp4 import (\n    cutlass_fp4_group_mm,\n    scaled_fp4_experts_quant,\n    scaled_fp4_quant,\n)\n\nFLOAT4_E2M1_MAX = 6.0\nFLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max\n\n\ndef _nvfp4_supported() -> bool:\n    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)\n\n\ndef _round_up(x: int, y: int) -> int:\n    return ((x + y - 1) // y) * y\n\n\ndef _build_expert_offsets(\n    m_per_expert: list[int], device: torch.device\n) -> torch.Tensor:\n    offsets = [0]\n    for m in m_per_expert:\n        offsets.append(offsets[-1] + m)\n    return torch.tensor(offsets, dtype=torch.int32, device=device)\n\n\ndef _build_blockscale_offsets(\n    m_per_expert: list[int], device: torch.device\n) -> torch.Tensor:\n    offsets = [0]\n    for m in m_per_expert:\n        offsets.append(offsets[-1] + _round_up(m, 128))\n    return torch.tensor(offsets, dtype=torch.int32, device=device)\n\n\n@pytest.mark.skipif(\n    not _nvfp4_supported(), reason=\"NVFP4 requires compute capability >= 10.0\"\n)\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\ndef test_nvfp4_blockwise_moe_grouped_mm(dtype: torch.dtype) -> None:\n    torch.manual_seed(0)\n    device = torch.device(\"cuda\")\n\n    num_experts = 4\n    m_per_expert = [33, 17, 48, 29]\n    n = 256\n    k = 128\n\n    expert_offsets_full = _build_expert_offsets(m_per_expert, device)\n    blockscale_offsets_full = _build_blockscale_offsets(m_per_expert, device)\n\n    total_m = int(expert_offsets_full[-1].item())\n    a = torch.randn((total_m, k), device=device, dtype=dtype) * 0.1\n    b = torch.randn((num_experts, n, k), device=device, dtype=dtype) * 0.1\n\n    a_global_scale = torch.empty((num_experts,), device=device, dtype=torch.float32)\n    for i in range(num_experts):\n        start = int(expert_offsets_full[i].item())\n        end = int(expert_offsets_full[i + 1].item())\n        amax = 
a[start:end].abs().max().to(torch.float32)\n        a_global_scale[i] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax\n\n    b_global_scale = torch.empty((num_experts,), device=device, dtype=torch.float32)\n    for i in range(num_experts):\n        bmax = b[i].abs().max().to(torch.float32)\n        b_global_scale[i] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / bmax\n\n    a_fp4, a_blockscale = scaled_fp4_experts_quant(\n        a,\n        a_global_scale,\n        expert_offsets_full,\n        blockscale_offsets_full,\n        topk=1,\n    )\n\n    b_fp4 = torch.empty((num_experts, n, k // 2), device=device, dtype=torch.uint8)\n    b_blockscale = torch.empty(\n        (num_experts, _round_up(n, 128), _round_up(k // 16, 4)),\n        device=device,\n        dtype=torch.float8_e4m3fn,\n    )\n    for i in range(num_experts):\n        b_fp4_i, b_scale_i = scaled_fp4_quant(b[i], b_global_scale[i])\n        b_fp4[i].copy_(b_fp4_i)\n        b_blockscale[i].copy_(b_scale_i)\n\n    alphas = (1.0 / (a_global_scale * b_global_scale)).to(torch.float32)\n\n    params = {\n        \"ab_strides\": torch.full((num_experts,), k, dtype=torch.int64, device=device),\n        \"c_strides\": torch.full((num_experts,), n, dtype=torch.int64, device=device),\n        \"problem_sizes\": torch.tensor(\n            [[m, n, k] for m in m_per_expert], dtype=torch.int32, device=device\n        ),\n        \"expert_offsets\": expert_offsets_full[:-1].contiguous(),\n        \"blockscale_offsets\": blockscale_offsets_full[:-1].contiguous(),\n        \"a_ptrs\": torch.empty((num_experts,), dtype=torch.int64, device=device),\n        \"b_ptrs\": torch.empty((num_experts,), dtype=torch.int64, device=device),\n        \"out_ptrs\": torch.empty((num_experts,), dtype=torch.int64, device=device),\n        \"a_scales_ptrs\": torch.empty((num_experts,), dtype=torch.int64, device=device),\n        \"b_scales_ptrs\": torch.empty((num_experts,), dtype=torch.int64, device=device),\n        \"alpha_ptrs\": 
torch.empty((num_experts,), dtype=torch.int64, device=device),\n        \"layout_sfa\": torch.empty((num_experts, 5), dtype=torch.int64, device=device),\n        \"layout_sfb\": torch.empty((num_experts, 5), dtype=torch.int64, device=device),\n    }\n\n    out = cutlass_fp4_group_mm(\n        a_fp4,\n        b_fp4,\n        a_blockscale,\n        b_blockscale,\n        alphas,\n        dtype,\n        params,\n    )\n\n    ref = torch.empty((total_m, n), device=device, dtype=dtype)\n    for i in range(num_experts):\n        start = int(expert_offsets_full[i].item())\n        end = int(expert_offsets_full[i + 1].item())\n        ref[start:end] = torch.matmul(a[start:end], b[i].t())\n\n    torch.testing.assert_close(out, ref, atol=1e-1, rtol=1e-1)\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_nvfp4_gemm.py",
    "content": "import pytest\nimport torch\n\nfrom sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm, scaled_fp4_quant\n\n\ndef _nvfp4_supported() -> bool:\n    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)\n\n\nDTYPES = [torch.float16, torch.bfloat16]\nSHAPES = [\n    (128, 128, 64),\n    (128, 128, 128),\n    (256, 128, 64),\n    (128, 256, 128),\n    (150, 128, 64),\n]\n\nFLOAT4_E2M1_MAX = 6.0\nFLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max\n\nK_E2M1_TO_FLOAT = [\n    0.0,\n    0.5,\n    1.0,\n    1.5,\n    2.0,\n    3.0,\n    4.0,\n    6.0,\n]\n\n\ndef e2m1_to_fp32(int4_value: int) -> float:\n    sign_bit = int4_value & 0x8\n    int4_abs_value = int4_value & 0x7\n    float_result = K_E2M1_TO_FLOAT[int4_abs_value]\n    return -float_result if sign_bit else float_result\n\n\ndef break_fp4_bytes(a: torch.Tensor) -> torch.Tensor:\n    assert a.dtype == torch.uint8\n    m, n = a.shape\n    a = a.flatten()\n    high_half_byte = (a & 0xF0) >> 4\n    low_half_byte = a & 0x0F\n    f_h = torch.tensor([e2m1_to_fp32(x) for x in high_half_byte], device=a.device)\n    f_l = torch.tensor([e2m1_to_fp32(x) for x in low_half_byte], device=a.device)\n    return torch.stack((f_l, f_h), dim=-1).reshape(m, n * 2)\n\n\ndef convert_swizzled_to_linear(\n    a_sf_swizzled: torch.Tensor, m: int, k: int, block_size: int\n) -> torch.Tensor:\n    sf_m, sf_k = a_sf_swizzled.shape\n    del sf_m, sf_k\n    m_tiles = (m + 128 - 1) // 128\n    f = block_size * 4\n    k_tiles = (k + f - 1) // f\n    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))\n    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))\n    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)\n    return out[0:m, 0 : k // block_size]\n\n\ndef dequantize_to_dtype(\n    tensor_fp4: torch.Tensor,\n    tensor_sf: torch.Tensor,\n    global_scale: torch.Tensor,\n    block_size: int = 16,\n) -> torch.Tensor:\n    assert tensor_fp4.dtype == torch.uint8\n    m, packed_k 
= tensor_fp4.shape\n    k = packed_k * 2\n    tensor_f32 = break_fp4_bytes(tensor_fp4)\n    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)\n    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)\n    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)\n    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale\n    return (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)\n\n\ndef get_ref_results(\n    a_fp4: torch.Tensor,\n    b_fp4: torch.Tensor,\n    a_sf: torch.Tensor,\n    b_sf: torch.Tensor,\n    a_global_scale: torch.Tensor,\n    b_global_scale: torch.Tensor,\n    block_size: int,\n) -> torch.Tensor:\n    a_in_dtype = dequantize_to_dtype(a_fp4, a_sf, a_global_scale, block_size=block_size)\n    b_in_dtype = dequantize_to_dtype(b_fp4, b_sf, b_global_scale, block_size=block_size)\n    return torch.matmul(a_in_dtype, b_in_dtype.t())\n\n\n@pytest.mark.skipif(\n    not _nvfp4_supported(), reason=\"NVFP4 requires compute capability >= 10.0\"\n)\n@pytest.mark.parametrize(\"dtype\", DTYPES)\n@pytest.mark.parametrize(\"shape\", SHAPES)\ndef test_nvfp4_gemm(dtype: torch.dtype, shape: tuple[int, int, int]) -> None:\n    m, n, packed_k = shape\n    k = packed_k * 2\n    block_size = 16\n\n    a_dtype = torch.randn((m, k), dtype=dtype, device=\"cuda\")\n    b_dtype = torch.randn((n, k), dtype=dtype, device=\"cuda\")\n\n    a_global_scale = (\n        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)\n    ).to(torch.float32)\n    b_global_scale = (\n        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)\n    ).to(torch.float32)\n\n    alpha = 1.0 / (a_global_scale * b_global_scale)\n\n    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)\n    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)\n\n    expected_out = get_ref_results(\n        a_fp4,\n        b_fp4,\n        a_scale_interleaved,\n        b_scale_interleaved,\n        
a_global_scale,\n        b_global_scale,\n        block_size,\n    )\n\n    out = cutlass_scaled_fp4_mm(\n        a_fp4,\n        b_fp4,\n        a_scale_interleaved,\n        b_scale_interleaved,\n        alpha,\n        dtype,\n    )\n\n    torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_nvfp4_quant.py",
    "content": "import pytest\nimport torch\n\nfrom sglang.jit_kernel.nvfp4 import (\n    scaled_fp4_grouped_quant,\n    scaled_fp4_quant,\n    silu_and_mul_scaled_fp4_grouped_quant,\n)\n\ntry:\n    from sgl_kernel import silu_and_mul as _sgl_silu_and_mul\nexcept Exception:\n    _sgl_silu_and_mul = None\n\n\ndef _nvfp4_supported() -> bool:\n    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)\n\n\ndef _silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:\n    if _sgl_silu_and_mul is not None:\n        return _sgl_silu_and_mul(x)\n    k = x.shape[-1] // 2\n    return torch.nn.functional.silu(x[:, :, :k]) * x[:, :, k:]\n\n\nDTYPES = [torch.float16, torch.bfloat16]\nSHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]\nPAD_SHAPES = [\n    (90, 64),\n    (150, 64),\n    (128, 48),\n    (128, 80),\n]\n\nFLOAT4_E2M1_MAX = 6.0\nFLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max\nBLOCK_SIZE = 16\n\nE2M1_TO_FLOAT32 = [\n    0.0,\n    0.5,\n    1.0,\n    1.5,\n    2.0,\n    3.0,\n    4.0,\n    6.0,\n    0.0,\n    -0.5,\n    -1.0,\n    -1.5,\n    -2.0,\n    -3.0,\n    -4.0,\n    -6.0,\n]\n\n\ndef cast_from_fp4(x: torch.Tensor, m: int, n: int) -> torch.Tensor:\n    v_2nd = (x & 0xF).to(torch.long)\n    v_1st = ((x >> 4) & 0xF).to(torch.long)\n    c = torch.stack((v_2nd, v_1st), dim=-1).flatten()\n    lut = torch.tensor(E2M1_TO_FLOAT32, device=x.device, dtype=torch.float32)\n    return lut[c].reshape(m, n)\n\n\ndef cast_to_fp4(x: torch.Tensor) -> torch.Tensor:\n    sign = torch.sign(x)\n    x = torch.abs(x)\n    x[(x >= 0.0) & (x <= 0.25)] = 0.0\n    x[(x > 0.25) & (x < 0.75)] = 0.5\n    x[(x >= 0.75) & (x <= 1.25)] = 1.0\n    x[(x > 1.25) & (x < 1.75)] = 1.5\n    x[(x >= 1.75) & (x <= 2.5)] = 2.0\n    x[(x > 2.5) & (x < 3.5)] = 3.0\n    x[(x >= 3.5) & (x <= 5.0)] = 4.0\n    x[x > 5.0] = 6.0\n    return x * sign\n\n\ndef get_reciprocal(x):\n    if isinstance(x, torch.Tensor):\n        return torch.where(x == 0, torch.tensor(0.0, 
dtype=x.dtype), 1.0 / x)\n    return 0.0 if x == 0 else 1.0 / x\n\n\ndef ref_nvfp4_quant(x: torch.Tensor, global_scale: torch.Tensor):\n    assert global_scale.dtype == torch.float32\n    assert x.ndim == 2\n    m, n = x.shape\n    x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))\n    vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)\n    scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))\n    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)\n    output_scale = get_reciprocal(scale * get_reciprocal(global_scale))\n\n    scaled_x = x.to(torch.float32) * output_scale\n    clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)\n    return cast_to_fp4(clipped_x), scale.squeeze(-1)\n\n\ndef recover_swizzled_scales(scale: torch.Tensor, m: int, n: int) -> torch.Tensor:\n    rounded_m = ((m + 128 - 1) // 128) * 128\n    scale_n = n // BLOCK_SIZE\n    rounded_n = ((scale_n + 4 - 1) // 4) * 4\n    tmp = torch.reshape(scale, (1, rounded_m // 128, rounded_n // 4, 32, 4, 4))\n    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))\n    result = torch.reshape(tmp, (rounded_m, rounded_n)).to(torch.float32)\n    return result[:m, :scale_n]\n\n\n@pytest.mark.skipif(\n    not _nvfp4_supported(), reason=\"NVFP4 requires compute capability >= 10.0\"\n)\n@pytest.mark.parametrize(\"dtype\", DTYPES)\n@pytest.mark.parametrize(\"shape\", SHAPES)\ndef test_quantize_to_fp4(dtype: torch.dtype, shape: tuple[int, int]) -> None:\n    torch.manual_seed(42)\n    m, n = shape\n\n    x = torch.randn((m, n), dtype=dtype, device=\"cuda\")\n    tensor_amax = torch.abs(x).max().to(torch.float32)\n    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax\n    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)\n\n    out, out_scale = scaled_fp4_quant(x, global_scale)\n    scale_ans = recover_swizzled_scales(out_scale, m, n)\n    out_ans = cast_from_fp4(out, m, n)\n\n    torch.testing.assert_close(out_ans, out_ref)\n    
torch.testing.assert_close(scale_ans, scale_ref)\n\n\n@pytest.mark.skipif(\n    not _nvfp4_supported(), reason=\"NVFP4 requires compute capability >= 10.0\"\n)\n@pytest.mark.parametrize(\"shape\", PAD_SHAPES)\ndef test_quantize_to_fp4_padded(shape: tuple[int, int]) -> None:\n    torch.manual_seed(42)\n    m, n = shape\n    x = torch.randn((m, n), dtype=torch.float16, device=\"cuda\")\n\n    tensor_amax = torch.abs(x).max().to(torch.float32)\n    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax\n    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)\n\n    out, out_scale = scaled_fp4_quant(x, global_scale)\n    scale_ans = recover_swizzled_scales(out_scale, m, n)\n    out_ans = cast_from_fp4(out, m, n)\n\n    torch.testing.assert_close(out_ans, out_ref)\n    torch.testing.assert_close(scale_ans, scale_ref)\n\n\n@pytest.mark.skipif(\n    not _nvfp4_supported(), reason=\"NVFP4 requires compute capability >= 10.0\"\n)\n@pytest.mark.parametrize(\"shape\", [(2, 128, 512), (2, 100, 128)])\ndef test_quantize_to_fp4_grouped(shape: tuple[int, int, int]) -> None:\n    torch.manual_seed(42)\n    l, m, k = shape\n\n    x = torch.randn((l, m, k), dtype=torch.bfloat16, device=\"cuda\")\n    mask = torch.randint(1, max(2, m // 2), (l,), dtype=torch.int32, device=\"cuda\")\n    tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32)\n    x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax\n\n    output, output_scales = scaled_fp4_grouped_quant(x, x_sf_global, mask)\n    output = output.permute(2, 0, 1)\n    padded_m = ((m + 128 - 1) // 128) * 128\n    output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)\n\n    for i in range(l):\n        a_fp4, a_scale_interleaved = scaled_fp4_quant(x[i], x_sf_global[i])\n        torch.testing.assert_close(a_fp4[: mask[i]], output[i][: mask[i]])\n        scale_ref = recover_swizzled_scales(a_scale_interleaved, m, k)\n        scale_ans = recover_swizzled_scales(output_scales[i], m, k)\n        
torch.testing.assert_close(scale_ref[: mask[i]], scale_ans[: mask[i]])\n\n\n@pytest.mark.skipif(\n    not _nvfp4_supported(), reason=\"NVFP4 requires compute capability >= 10.0\"\n)\n@pytest.mark.parametrize(\"shape\", [(4, 96, 256), (8, 128, 512)])\ndef test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int, int]) -> None:\n    torch.manual_seed(42)\n    l, m, k = shape\n\n    x = torch.randn((l, m, k * 2), dtype=torch.bfloat16, device=\"cuda\")\n    mask = torch.randint(1, max(2, m // 2), (l,), dtype=torch.int32, device=\"cuda\")\n\n    ref_y = _silu_and_mul_reference(x)\n\n    tensor_amax = ref_y.abs().amax(dim=(1, 2)).to(torch.float32)\n    y_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax\n\n    ref_output, ref_output_scales = scaled_fp4_grouped_quant(ref_y, y_sf_global, mask)\n    output, output_scales = silu_and_mul_scaled_fp4_grouped_quant(x, y_sf_global, mask)\n\n    output = output.permute(2, 0, 1)\n    ref_output = ref_output.permute(2, 0, 1)\n\n    padded_m = ((m + 128 - 1) // 128) * 128\n    output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)\n    ref_output_scales = ref_output_scales.permute(5, 2, 4, 0, 1, 3).view(\n        l, padded_m, -1\n    )\n\n    for i in range(l):\n        torch.testing.assert_close(ref_output[i, : mask[i]], output[i, : mask[i]])\n        scale_ref = recover_swizzled_scales(ref_output_scales[i], m, k)\n        scale_ans = recover_swizzled_scales(output_scales[i], m, k)\n        torch.testing.assert_close(scale_ref[: mask[i]], scale_ans[: mask[i]])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_per_tensor_quant_fp8.py",
    "content": "import itertools\nfrom typing import Optional, Tuple\n\nimport pytest\nimport torch\n\nfrom sglang.jit_kernel.per_tensor_quant_fp8 import per_tensor_quant_fp8\n\ntry:\n    from sglang.srt.utils import is_hip\n\n    _is_hip = is_hip()\nexcept ImportError:\n    _is_hip = False\n\nfp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn\n\n\ndef sglang_scaled_fp8_quant(\n    input: torch.Tensor,\n    scale: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    fp8_type_: torch.dtype = torch.float8_e4m3fn\n    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)\n    is_static = True\n    if scale is None:\n        scale = torch.zeros(1, device=input.device, dtype=torch.float32)\n        is_static = False\n    per_tensor_quant_fp8(input, output, scale, is_static)\n\n    return output, scale\n\n\ndef torch_scaled_fp8_quant(tensor, inv_scale):\n    finfo = torch.finfo(torch.float8_e4m3fn)\n    scale = inv_scale.reciprocal()\n    qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)\n    qweight = qweight.to(torch.float8_e4m3fn)\n    return qweight\n\n\n@pytest.mark.parametrize(\n    \"num_tokens,hidden_dim\",\n    list(itertools.product([128, 256, 512], [512, 2048, 4096])),\n)\ndef test_jit_per_tensor_quant_compare_implementations(\n    num_tokens: int,\n    hidden_dim: int,\n):\n    device = torch.device(\"cuda\")\n    x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)\n\n    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)\n    torch_out = torch_scaled_fp8_quant(x, sglang_scale)\n\n    torch.testing.assert_close(\n        sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3\n    )\n\n\n@pytest.mark.parametrize(\"shape\", [(4, 8, 64), (2, 16, 128), (19260817, 1, 1)])\ndef test_jit_per_tensor_quant_supports_3d(shape):\n    device = torch.device(\"cuda\")\n    x = torch.rand(shape, dtype=torch.bfloat16, device=device)\n    out = 
torch.empty_like(x, device=x.device, dtype=fp8_type_)\n    scale = torch.zeros(1, device=x.device, dtype=torch.float32)\n\n    per_tensor_quant_fp8(x, out, scale, is_static=False)\n\n    x_2d = x.flatten(0, -2)\n    out_ref_2d = torch_scaled_fp8_quant(x_2d, scale)\n    out_ref = out_ref_2d.reshape(shape)\n\n    torch.testing.assert_close(out.float(), out_ref.float(), rtol=1e-3, atol=1e-3)\n\n    scale = torch.rand(1, dtype=torch.float32, device=device)\n    sglang_out, _ = sglang_scaled_fp8_quant(x, scale)\n    torch_out = torch_scaled_fp8_quant(x, scale)\n\n    torch.testing.assert_close(\n        sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3\n    )\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_per_token_group_quant_8bit.py",
    "content": "import itertools\n\nimport pytest\nimport torch\n\nfrom sglang.srt.utils import is_hip\n\n_is_hip = is_hip()\nfp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn\n\nfrom sgl_kernel.test_utils import (\n    assert_all_close_or_tiny_diff,\n    create_per_token_group_quant_test_data,\n)\n\nfrom sglang.jit_kernel.per_token_group_quant_8bit import (\n    per_token_group_quant_8bit as sglang_per_token_group_quant_8bit,\n)\nfrom sglang.srt.layers.quantization.fp8_kernel import (\n    create_per_token_group_quant_fp8_output_scale,\n)\nfrom sglang.srt.layers.quantization.fp8_kernel import (\n    per_token_group_quant_8bit as triton_per_token_group_quant_8bit,\n)\n\nconfigs = list(\n    itertools.product(\n        [1, 4, 16, 64, 127, 128, 512, 1024, 4096, 8192],  # num_tokens\n        [128, 256, 384, 512, 1024, 1536, 1664, 2048, 4096, 7168, 16384],  # hidden_dim\n        [16, 32, 64, 128],  # group_size\n        [None],  # num_ranks\n        [fp8_type_],  # dtype\n        [\n            dict(\n                column_major_scales=False,\n                scale_tma_aligned=False,\n                scale_ue8m0=False,\n                fuse_silu_and_mul=False,\n                masked_layout_mode=None,\n            ),\n            dict(\n                column_major_scales=True,\n                scale_tma_aligned=False,\n                scale_ue8m0=False,\n                fuse_silu_and_mul=False,\n                masked_layout_mode=None,\n            ),\n            dict(\n                column_major_scales=True,\n                scale_tma_aligned=True,\n                scale_ue8m0=False,\n                fuse_silu_and_mul=False,\n                masked_layout_mode=None,\n            ),\n            dict(\n                column_major_scales=True,\n                scale_tma_aligned=True,\n                scale_ue8m0=True,\n                fuse_silu_and_mul=False,\n                masked_layout_mode=None,\n            ),\n        ],\n    )\n) + 
list(\n    itertools.product(\n        [1, 4, 1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],\n        [2048],\n        [128],\n        [8, 16, 32, 48],\n        [fp8_type_],\n        [\n            dict(\n                column_major_scales=True,\n                scale_tma_aligned=True,\n                scale_ue8m0=True,\n                fuse_silu_and_mul=True,\n                masked_layout_mode=None,\n            ),\n            dict(\n                column_major_scales=True,\n                scale_tma_aligned=True,\n                scale_ue8m0=True,\n                fuse_silu_and_mul=True,\n                masked_layout_mode=\"balanced\",\n            ),\n            dict(\n                column_major_scales=True,\n                scale_tma_aligned=True,\n                scale_ue8m0=True,\n                fuse_silu_and_mul=True,\n                masked_layout_mode=\"imbalanced\",\n            ),\n            dict(\n                column_major_scales=True,\n                scale_tma_aligned=True,\n                scale_ue8m0=True,\n                fuse_silu_and_mul=True,\n                masked_layout_mode=\"extreme\",\n            ),\n        ],\n    )\n)\n\n\n@pytest.mark.parametrize(\n    \"num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags\", configs\n)\ndef test_per_token_group_quant_with_column_major(\n    num_tokens,\n    hidden_dim,\n    group_size,\n    num_ranks,\n    dst_dtype,\n    flags,\n):\n    arch_major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())\n    if flags[\"scale_ue8m0\"] and (arch_major <= 9):\n        pytest.skip(\"Only Blackwell need ue8m0 fusion\")\n        return\n\n    if (flags[\"scale_ue8m0\"] and (group_size != 128)) or (\n        (dst_dtype == torch.int8) and flags[\"column_major_scales\"]\n    ):\n        pytest.skip()\n        return\n\n    x, masked_m = create_per_token_group_quant_test_data(\n        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags\n    )\n\n    
execute_kwargs = dict(\n        x=x,\n        masked_m=masked_m,\n        group_size=group_size,\n        eps=1e-10,\n        dst_dtype=dst_dtype,\n        **{k: v for k, v in flags.items() if k not in [\"masked_layout_mode\"]},\n    )\n\n    def _postprocess(x_q, x_s):\n        if masked_m is not None:\n            print(f\"Mask tokens after {masked_m} to be zero\")\n            for i in range(len(masked_m)):\n                x_q[i, masked_m[i] :, :] = 0\n                x_s[i, masked_m[i] :, :] = 0\n        return x_q, x_s\n\n    x_q_triton, x_s_triton = _postprocess(\n        *triton_per_token_group_quant_8bit(**execute_kwargs)\n    )\n\n    fuse_silu_and_mul = False\n    out_shape = (*x.shape[:-1], x.shape[-1] // (2 if fuse_silu_and_mul else 1))\n\n    fp8_dtype = torch.float8_e4m3fn\n    fp8_max = torch.finfo(fp8_dtype).max\n    fp8_min = -fp8_max\n    x_q = torch.empty(out_shape, device=x.device, dtype=fp8_dtype)\n    x_s = create_per_token_group_quant_fp8_output_scale(\n        x_shape=out_shape,\n        device=x.device,\n        group_size=group_size,\n        column_major_scales=False,\n        scale_tma_aligned=False,\n        scale_ue8m0=False,\n    )\n\n    execute_kwargs = dict(\n        input=x,\n        output_q=x_q,\n        output_s=x_s,\n        group_size=group_size,\n        eps=1e-10,\n        fp8_max=fp8_max,\n        fp8_min=fp8_min,\n    )\n    x_q_sglang, x_s_sglang = _postprocess(\n        *sglang_per_token_group_quant_8bit(**execute_kwargs)\n    )\n\n    try:\n        assert_all_close_or_tiny_diff(x_q_triton, x_q_sglang)\n        torch.testing.assert_close(\n            x_s_triton.contiguous(),\n            x_s_sglang.contiguous(),\n            rtol=1e-3,\n            atol=1e-5,\n            msg=lambda message: message + f\" {x_s_triton=} {x_s_sglang=}\",\n        )\n    except AssertionError:\n        print(\n            f\"{x.shape=} {x_q_triton.shape=} {x_s_triton.shape=} {x_q_sglang.shape=} {x_s_sglang.shape=}\"\n        )\n        
print(f\"{x=}\")\n        print(f\"{masked_m=}\")\n        print(f\"{x_q_triton=}\")\n        print(f\"{x_s_triton=}\")\n        print(f\"{x_q_sglang=}\")\n        print(f\"{x_s_sglang=}\")\n\n        raise\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_pos_enc.py",
    "content": "import time\nfrom typing import Optional, Tuple, Union\n\nimport pytest\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.jit_kernel.rope import rotary_embedding\n\n\n@triton.jit\ndef burn_kernel(out_ptr, iters: tl.constexpr):\n    pid = tl.program_id(0)\n    x = tl.full((), pid + 1, dtype=tl.uint32)\n\n    a = tl.full((), 1664525, dtype=tl.uint32)\n    c = tl.full((), 1013904223, dtype=tl.uint32)\n    sh = tl.full((), 13, dtype=tl.uint32)\n\n    for _ in range(iters):\n        x = x * a + c\n        x = x ^ (x >> sh)\n\n    if pid == 0:\n        tl.store(out_ptr, x)\n\n\ndef triton_burn(ms: float, grid=(256,)):\n    iters = int(ms * 20000)\n    out = torch.empty((), device=\"cuda\", dtype=torch.uint32)\n    burn_kernel[grid](out, iters=iters)\n    return out\n\n\ndef create_test_inputs(\n    head_size, batch_size, seq_len, device, dtype, num_q_heads, num_kv_heads\n):\n    \"\"\"Create test inputs.\"\"\"\n    total_tokens = batch_size * seq_len\n\n    query = torch.randn(\n        batch_size, seq_len, num_q_heads, head_size, dtype=dtype, device=device\n    )\n    key = torch.randn(\n        batch_size, seq_len, num_kv_heads, head_size, dtype=dtype, device=device\n    )\n\n    pos_ids = torch.randint(\n        0, min(seq_len * 2, 100), (total_tokens,), dtype=torch.long, device=device\n    )\n\n    query = query.view(total_tokens, num_q_heads, head_size)\n    key = key.view(total_tokens, num_kv_heads, head_size)\n\n    return query, key, pos_ids\n\n\ndef create_cos_sin_cache(rotary_dim, max_position_embeddings, base, dtype, device):\n    \"\"\"Create cos/sin cache for rotary embedding.\"\"\"\n    max_pos = max_position_embeddings\n    extended_max_pos = max(max_pos, 100)\n    cos_sin_cache = torch.zeros(\n        extended_max_pos, rotary_dim, dtype=dtype, device=device\n    )\n\n    inv_freq = 1.0 / (\n        base\n        ** (\n            torch.arange(0, rotary_dim, 2, dtype=torch.float32, device=device)\n            / 
rotary_dim\n        )\n    )\n    t = torch.arange(extended_max_pos, dtype=torch.float32, device=device)\n    freqs = torch.outer(t, inv_freq)\n    cos_cache = torch.cos(freqs).to(dtype)\n    sin_cache = torch.sin(freqs).to(dtype)\n\n    cos_sin_cache[:, : rotary_dim // 2] = cos_cache\n    cos_sin_cache[:, rotary_dim // 2 :] = sin_cache\n\n    return cos_sin_cache\n\n\n# vLLM torch native\ndef _apply_rotary_emb(\n    x: torch.Tensor,\n    cos: torch.Tensor,\n    sin: torch.Tensor,\n    is_neox_style: bool,\n) -> torch.Tensor:\n    \"\"\"\n    Args:\n        x: [num_tokens, num_heads, head_size]\n        cos: [num_tokens, head_size // 2]\n        sin: [num_tokens, head_size // 2]\n        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary\n            positional embeddings.\n    \"\"\"\n    cos = cos.unsqueeze(-2).to(x.dtype)\n    sin = sin.unsqueeze(-2).to(x.dtype)\n    if is_neox_style:\n        x1, x2 = torch.chunk(x, 2, dim=-1)\n    else:\n        x1 = x[..., ::2]\n        x2 = x[..., 1::2]\n    o1 = x1 * cos - x2 * sin\n    o2 = x2 * cos + x1 * sin\n    if is_neox_style:\n        return torch.cat((o1, o2), dim=-1)\n    else:\n        return torch.stack((o1, o2), dim=-1).flatten(-2)\n\n\nclass RotaryEmbedding(torch.nn.Module):\n    # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py\n    def __init__(\n        self,\n        head_size: int,\n        rotary_dim: int,\n        max_position_embeddings: int,\n        base: int,\n        is_neox_style: bool,\n        dtype: torch.dtype,\n    ) -> None:\n        super().__init__()\n        self.head_size = head_size\n        self.rotary_dim = rotary_dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        self.is_neox_style = is_neox_style\n        self.dtype = dtype\n\n        cache = self._compute_cos_sin_cache()\n        self.cos_sin_cache: torch.Tensor\n        
self.register_buffer(\"cos_sin_cache\", cache, persistent=False)\n\n    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:\n        inv_freq = 1.0 / (\n            base\n            ** (\n                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim\n            )\n        )\n        return inv_freq\n\n    def _compute_cos_sin_cache(self) -> torch.Tensor:\n        \"\"\"Compute the cos and sin cache.\"\"\"\n        inv_freq = self._compute_inv_freq(self.base)\n        t = torch.arange(self.max_position_embeddings, dtype=torch.float)\n\n        freqs = torch.einsum(\"i,j -> ij\", t, inv_freq)\n        cos = freqs.cos()\n        sin = freqs.sin()\n        cache = torch.cat((cos, sin), dim=-1)\n        return cache\n\n    def forward_native(\n        self,\n        positions: torch.Tensor,\n        query: torch.Tensor,\n        key: Optional[torch.Tensor] = None,\n        offsets: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"A PyTorch-native implementation of forward().\"\"\"\n\n        if offsets is not None:\n            positions = positions + offsets\n\n        positions = positions.flatten()\n        num_tokens = positions.shape[0]\n        cos_sin = self.cos_sin_cache.index_select(0, positions)\n\n        cos, sin = cos_sin.chunk(2, dim=-1)\n\n        query_shape = query.shape\n        query = query.view(num_tokens, -1, self.head_size)\n        query_rot = query[..., : self.rotary_dim]\n        query_pass = query[..., self.rotary_dim :]\n        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)\n        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)\n\n        # Modification: convert to the correct dtype\n        query = query.to(self.dtype)\n\n        if key is not None:\n            key_shape = key.shape\n            key = key.view(num_tokens, -1, self.head_size)\n            key_rot = key[..., : self.rotary_dim]\n           
 key_pass = key[..., self.rotary_dim :]\n            key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)\n            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)\n\n            key = key.to(self.dtype)\n\n        return query, key\n\n\ndef get_torch_rotary_embedding(\n    head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device\n):\n    \"\"\"Initialize Torch Native RotaryEmbedding based on vLLM implementation.\"\"\"\n    return RotaryEmbedding(\n        head_size=head_size,\n        rotary_dim=rotary_dim,\n        max_position_embeddings=max_position_embeddings,\n        base=base,\n        is_neox_style=is_neox_style,\n        dtype=dtype,\n    ).to(device)\n\n\ndef get_sgl_rotary_embedding(\n    head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device\n):\n    \"\"\"Initialize SglKernelRotaryEmbedding.\"\"\"\n    try:\n        from sgl_kernel.testing.rotary_embedding import SglKernelRotaryEmbedding\n    except ImportError:\n        pytest.skip(\n            \"SglKernelRotaryEmbedding is not available. 
Test case can be removed.\"\n        )\n\n    return SglKernelRotaryEmbedding(\n        head_size=head_size,\n        rotary_dim=rotary_dim,\n        max_position_embeddings=max_position_embeddings,\n        base=base,\n        is_neox_style=is_neox_style,\n        dtype=dtype,\n    ).to(device)\n\n\ndef compare_results(jit_out, sgl_out, dtype):\n    \"\"\"Compare results between JIT and SGL implementations.\"\"\"\n    if jit_out is None:\n        assert sgl_out is None\n        return\n\n    assert sgl_out is not None\n\n    # Check for NaN values\n    assert not torch.isnan(jit_out).any(), \"NaN in JIT results\"\n    assert not torch.isnan(sgl_out).any(), \"NaN in SGL results\"\n\n    # Compare results\n    atol = 1e-2 if dtype != torch.float32 else 1e-5\n    rtol = 1e-2 if dtype != torch.float32 else 1e-5\n\n    torch.testing.assert_close(jit_out, sgl_out, atol=atol, rtol=rtol)\n\n\n@pytest.mark.parametrize(\n    \"head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads\",\n    [\n        # GPT-OSS cases\n        *[\n            (64, 64, 4096, 8000, True, torch.bfloat16, \"cuda\", bs, sl, 8, 8)\n            for bs, sl in [(1, 1), (32, 1), (128, 1), (512, 1), (2, 512), (4, 4096)]\n        ],\n        # Other cases\n        (64, 64, 32, 8000, True, torch.bfloat16, \"cuda\", 32, 32, 1, 1),\n        (256, 128, 4096, 10000, True, torch.bfloat16, \"cuda\", 2, 512, 4, 2),\n        (512, 128, 311, 10000, True, torch.bfloat16, \"cuda\", 3, 39, 4, 2),\n        (128, 128, 2048, 10000, False, torch.bfloat16, \"cuda\", 2, 512, 32, 8),\n        (128, 128, 2048, 10000, False, torch.bfloat16, \"cuda\", 2, 512, 16, 4),\n        (512, 128, 311, 10000, False, torch.bfloat16, \"cuda\", 3, 39, 4, 2),\n        (64, 64, 32, 8000, True, torch.float32, \"cuda\", 32, 32, 1, 1),\n        (256, 128, 4096, 10000, True, torch.float32, \"cuda\", 2, 512, 4, 2),\n        (512, 128, 311, 10000, True, torch.float32, 
\"cuda\", 3, 39, 4, 2),\n        (128, 128, 2048, 10000, False, torch.float32, \"cuda\", 2, 512, 32, 8),\n        (128, 128, 2048, 10000, False, torch.float32, \"cuda\", 2, 512, 16, 4),\n        (512, 128, 311, 10000, False, torch.float32, \"cuda\", 3, 39, 4, 2),\n        # Additional test cases for different head sizes and dtypes\n        (64, 32, 1024, 10000, True, torch.float16, \"cuda\", 16, 64, 8, 4),\n        (128, 64, 2048, 10000, True, torch.float16, \"cuda\", 8, 128, 16, 8),\n        (256, 128, 4096, 10000, True, torch.float16, \"cuda\", 4, 256, 8, 4),\n    ],\n)\n@pytest.mark.parametrize(\n    \"key_is_none\",\n    [True, False],\n)\ndef test_correctness(\n    head_size,\n    rotary_dim,\n    max_position_embeddings,\n    base,\n    is_neox_style,\n    dtype,\n    device,\n    batch_size,\n    seq_len,\n    num_q_heads,\n    num_kv_heads,\n    key_is_none,\n):\n    \"\"\"Test correctness of JIT rotary embedding implementation.\"\"\"\n    # Create inputs and caches\n    query, key, pos_ids = create_test_inputs(\n        head_size, batch_size, seq_len, device, dtype, num_q_heads, num_kv_heads\n    )\n    cos_sin_cache = create_cos_sin_cache(\n        rotary_dim, max_position_embeddings, base, dtype, device\n    )\n\n    # Initialize torch kernel\n    torch_rotary_emb = get_torch_rotary_embedding(\n        head_size,\n        rotary_dim,\n        max_position_embeddings,\n        base,\n        is_neox_style,\n        dtype,\n        device,\n    )\n    torch_rotary_emb.cos_sin_cache = cos_sin_cache\n    r = torch.randn_like(query)\n\n    # Apply rotary embeddings\n    query_jit, key_jit = query.clone(), key.clone()\n    query_torch, key_torch = query.clone(), key.clone()\n    stream_jit = torch.get_device_module(\"cuda\").Stream()\n    stream_kernel = torch.get_device_module(\"cuda\").Stream()\n\n    if key_is_none:\n        key_jit = None\n        key_torch = None\n    triton_burn(100.0, grid=(1024,))\n\n    r_jit, r_torch = r.clone(), r.clone()\n    
torch.cuda.synchronize()\n\n    with torch.cuda.stream(stream_jit):\n        # Test if rotary_embedding runs on stream_jit\n        triton_burn(100.0, grid=(1024,))\n        query_jit = query_jit + r_jit\n        query_jit_out, key_jit_out = rotary_embedding(\n            positions=pos_ids,\n            query=query_jit,\n            key=key_jit,\n            head_size=head_size,\n            cos_sin_cache=cos_sin_cache,\n            is_neox=is_neox_style,\n        )\n\n    with torch.cuda.stream(stream_kernel):\n        triton_burn(100.0, grid=(1024,))\n        query_torch = query_torch + r_torch\n        query_torch_out, key_torch_out = torch_rotary_emb.forward_native(\n            positions=pos_ids, query=query_torch, key=key_torch\n        )\n\n    torch.cuda.synchronize()\n    compare_results(query_jit_out, query_torch_out, dtype)\n    compare_results(key_jit_out, key_torch_out, dtype)\n\n\n@pytest.mark.parametrize(\n    \"head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads\",\n    [\n        # Small scale\n        (64, 64, 4096, 8000, True, torch.bfloat16, \"cuda\", 1, 1, 8, 8),\n        (64, 64, 4096, 8000, True, torch.bfloat16, \"cuda\", 4, 16, 8, 8),\n        # Medium scale\n        (64, 64, 4096, 8000, True, torch.bfloat16, \"cuda\", 8, 64, 8, 8),\n        (64, 64, 4096, 8000, True, torch.bfloat16, \"cuda\", 16, 128, 8, 8),\n        # Large scale\n        (64, 64, 4096, 8000, True, torch.bfloat16, \"cuda\", 32, 512, 8, 8),\n        (64, 64, 4096, 8000, True, torch.bfloat16, \"cuda\", 64, 1024, 8, 8),\n    ],\n)\ndef test_performance(\n    head_size: int,\n    rotary_dim: int,\n    max_position_embeddings: int,\n    base: int,\n    is_neox_style,\n    dtype,\n    device,\n    batch_size,\n    seq_len,\n    num_q_heads,\n    num_kv_heads,\n):\n    \"\"\"Performance test comparing JIT and SGL implementations with accuracy validation.\"\"\"\n    # Create inputs and caches\n    
query, key, pos_ids = create_test_inputs(\n        head_size, batch_size, seq_len, device, dtype, num_q_heads, num_kv_heads\n    )\n    cos_sin_cache = create_cos_sin_cache(\n        rotary_dim, max_position_embeddings, base, dtype, device\n    )\n\n    # Initialize SGL kernel\n    sgl_rotary_emb = get_sgl_rotary_embedding(\n        head_size,\n        rotary_dim,\n        max_position_embeddings,\n        base,\n        is_neox_style,\n        dtype,\n        device,\n    )\n    sgl_rotary_emb.cos_sin_cache = cos_sin_cache\n\n    warmup = 3\n\n    # Warmup runs\n    for _ in range(warmup):\n        query_warm, key_warm = query.clone(), key.clone()\n        rotary_embedding(\n            positions=pos_ids,\n            query=query_warm,\n            key=key_warm,\n            head_size=head_size,\n            cos_sin_cache=cos_sin_cache,\n            is_neox=is_neox_style,\n        )\n\n        query_sgl_warm, key_sgl_warm = query.clone(), key.clone()\n        sgl_rotary_emb.forward_cuda(\n            positions=pos_ids, query=query_sgl_warm, key=key_sgl_warm\n        )\n\n    iteration = 100\n\n    # Time JIT implementation\n    torch.cuda.synchronize()\n    start_time = time.time()\n    for _ in range(iteration):\n        query_jit, key_jit = query.clone(), key.clone()\n        rotary_embedding(\n            positions=pos_ids,\n            query=query_jit,\n            key=key_jit,\n            head_size=head_size,\n            cos_sin_cache=cos_sin_cache,\n            is_neox=is_neox_style,\n        )\n    torch.cuda.synchronize()\n    jit_time = (time.time() - start_time) / iteration\n\n    # Time SGL implementation\n    torch.cuda.synchronize()\n    start_time = time.time()\n    for _ in range(iteration):\n        query_sgl, key_sgl = query.clone(), key.clone()\n        sgl_rotary_emb.forward_cuda(positions=pos_ids, query=query_sgl, key=key_sgl)\n    torch.cuda.synchronize()\n    sgl_time = (time.time() - start_time) / iteration\n\n    # Accuracy validation 
during performance test\n    # Run one more time to get outputs for comparison\n    query_jit_final, key_jit_final = query.clone(), key.clone()\n    query_sgl_final, key_sgl_final = query.clone(), key.clone()\n\n    query_jit_out, key_jit_out = rotary_embedding(\n        positions=pos_ids,\n        query=query_jit_final,\n        key=key_jit_final,\n        head_size=head_size,\n        cos_sin_cache=cos_sin_cache,\n        is_neox=is_neox_style,\n    )\n\n    query_sgl_out, key_sgl_out = sgl_rotary_emb.forward_cuda(\n        positions=pos_ids, query=query_sgl_final, key=key_sgl_final\n    )\n\n    # Validate accuracy\n    compare_results(query_jit_out, query_sgl_out, dtype)\n    compare_results(key_jit_out, key_sgl_out, dtype)\n\n    # Print results\n    total_tokens = batch_size * seq_len\n    print(\n        f\"\\nPerformance Test - Batch={batch_size}, SeqLen={seq_len}, Tokens={total_tokens}\"\n    )\n    print(f\"JIT: {jit_time*1000:.9f}ms, SGL: {sgl_time*1000:.9f}ms\")\n    if sgl_time > 0:\n        speedup = sgl_time / jit_time if jit_time > 0 else float(\"inf\")\n        print(f\"Speedup (SGL/JIT): {speedup:.2f}x\")\n\n    assert jit_time >= 0 and sgl_time >= 0\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_qknorm.py",
    "content": "import itertools\n\nimport pytest\nimport torch\nimport triton\n\nfrom sglang.jit_kernel.utils import get_ci_test_range\n\n\ndef sglang_aot_qknorm(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n) -> None:\n    from sgl_kernel import rmsnorm\n\n    head_dim = q.shape[-1]\n    q = q.view(-1, head_dim)\n    k = k.view(-1, head_dim)\n    rmsnorm(q, q_weight, out=q)\n    rmsnorm(k, k_weight, out=k)\n\n\ndef sglang_jit_qknorm(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n) -> None:\n    from sglang.jit_kernel.norm import fused_inplace_qknorm\n\n    fused_inplace_qknorm(q, k, q_weight, k_weight)\n\n\ndef flashinfer_qknorm(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n) -> None:\n    from flashinfer.norm import rmsnorm\n\n    rmsnorm(q, q_weight, out=q)\n    rmsnorm(k, k_weight, out=k)\n\n\n@torch.compile()\ndef torch_impl_qknorm(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n    eps: float = 1e-6,\n) -> None:\n    q_mean = q.float().pow(2).mean(dim=-1, keepdim=True)\n    k_mean = k.float().pow(2).mean(dim=-1, keepdim=True)\n    q_norm = (q_mean + eps).rsqrt()\n    k_norm = (k_mean + eps).rsqrt()\n    q.copy_(q.float() * q_norm * q_weight.float())\n    k.copy_(k.float() * k_norm * k_weight.float())\n\n\nBS_LIST = [2**n for n in range(0, 14)]\nBS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)]\nBS_LIST = get_ci_test_range(BS_LIST, [1, 9, 256, 4109])\nN_K_LIST = get_ci_test_range([2, 4], [2, 4])\nN_Q_LIST = get_ci_test_range([8, 16], [8, 16])\nHEAD_DIM_LIST = get_ci_test_range([64, 128, 256, 512, 1024], [64, 256, 1024])\nDEVICE = \"cuda\"\nDTYPE = torch.bfloat16\n\n# NOTE(dark): sgl_kernel use flashinfer template, which is bitwise identical to flashinfer impl.\n# However, sgl-jit-kernel, flashinfer, torch_impl, may have small 
numerical differences.\n# so we allow a small rel/abs tolerance in correctness test.\n\n\n@pytest.mark.parametrize(\n    \"batch_size,n_k,n_q,head_dim\",\n    list(itertools.product(BS_LIST, N_K_LIST, N_Q_LIST, HEAD_DIM_LIST)),\n)\ndef test_qknorm(batch_size: int, n_k: int, n_q: int, head_dim: int) -> None:\n    q = torch.randn(batch_size, n_q, head_dim, device=DEVICE, dtype=DTYPE)\n    k = torch.randn(batch_size, n_k, head_dim, device=DEVICE, dtype=DTYPE)\n    q_weight = torch.randn(head_dim, device=DEVICE, dtype=DTYPE)\n    k_weight = torch.randn(head_dim, device=DEVICE, dtype=DTYPE)\n    q_k_aot = (q.clone(), k.clone())\n    q_k_jit = (q.clone(), k.clone())\n    sglang_aot_qknorm(q_k_aot[0], q_k_aot[1], q_weight, k_weight)\n    sglang_jit_qknorm(q_k_jit[0], q_k_jit[1], q_weight, k_weight)\n    triton.testing.assert_close(q_k_aot[0], q_k_jit[0], atol=1e-2, rtol=1e-2)\n    triton.testing.assert_close(q_k_aot[1], q_k_jit[1], atol=1e-2, rtol=1e-2)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_qknorm_across_heads.py",
    "content": "import itertools\n\nimport pytest\nimport torch\nimport triton\n\nfrom sglang.jit_kernel.utils import get_ci_test_range\n\n\ndef sglang_jit_qknorm_across_heads(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n) -> None:\n    from sglang.jit_kernel.norm import fused_inplace_qknorm_across_heads\n\n    fused_inplace_qknorm_across_heads(q, k, q_weight, k_weight)\n\n\ndef sglang_aot_qknorm_across_heads(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n) -> None:\n    from sgl_kernel import rmsnorm\n\n    rmsnorm(q, q_weight, out=q)\n    rmsnorm(k, k_weight, out=k)\n\n\n@torch.compile()\ndef torch_impl_qknorm_across_heads(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_weight: torch.Tensor,\n    k_weight: torch.Tensor,\n    eps: float = 1e-6,\n) -> None:\n    q_mean = q.float().pow(2).mean(dim=-1, keepdim=True)\n    k_mean = k.float().pow(2).mean(dim=-1, keepdim=True)\n    q_norm = (q_mean + eps).rsqrt()\n    k_norm = (k_mean + eps).rsqrt()\n    q.copy_(q.float() * q_norm * q_weight.float())\n    k.copy_(k.float() * k_norm * k_weight.float())\n\n\nBS_LIST = [2**n for n in range(0, 14)]\nBS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)]\nBS_LIST = get_ci_test_range(BS_LIST, [1, 9, 256, 4109])\nHIDDEN_DIM_LIST = get_ci_test_range([512, 1024, 2048, 4096], [512, 2048, 4096])\nDEVICE = \"cuda\"\nDTYPE = torch.bfloat16\n\n\n@pytest.mark.parametrize(\n    \"batch_size,hidden_dim\",\n    list(itertools.product(BS_LIST, HIDDEN_DIM_LIST)),\n)\ndef test_qknorm_across_heads(batch_size: int, hidden_dim: int) -> None:\n    q = torch.randn(batch_size, hidden_dim, device=DEVICE, dtype=DTYPE)\n    k = torch.randn(batch_size, hidden_dim, device=DEVICE, dtype=DTYPE)\n    q_weight = torch.randn(hidden_dim, device=DEVICE, dtype=DTYPE)\n    k_weight = torch.randn(hidden_dim, device=DEVICE, dtype=DTYPE)\n\n    q_k_jit = (q.clone(), k.clone())\n    q_k_aot = 
(q.clone(), k.clone())\n\n    sglang_jit_qknorm_across_heads(q_k_jit[0], q_k_jit[1], q_weight, k_weight)\n    sglang_aot_qknorm_across_heads(q_k_aot[0], q_k_aot[1], q_weight, k_weight)\n\n    triton.testing.assert_close(q_k_jit[0], q_k_aot[0], atol=1e-2, rtol=1e-2)\n    triton.testing.assert_close(q_k_jit[1], q_k_aot[1], atol=1e-2, rtol=1e-2)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_qwen_image_modulation.py",
    "content": "import pytest\nimport torch\nimport triton\n\nfrom sglang.jit_kernel.diffusion.triton.norm import norm_infer\nfrom sglang.jit_kernel.diffusion.triton.scale_shift import (\n    fuse_layernorm_scale_shift_gate_select01_kernel,\n    fuse_residual_layernorm_scale_shift_gate_select01_kernel,\n    fuse_scale_shift_gate_select01_kernel,\n)\nfrom sglang.jit_kernel.utils import get_ci_test_range\n\nDEVICE = \"cuda\"\nDTYPES = get_ci_test_range(\n    [torch.float16, torch.bfloat16, torch.float32], [torch.float16, torch.bfloat16]\n)\nBATCH_SIZES = get_ci_test_range([1, 2, 4], [1, 2])\nSEQ_LENS = get_ci_test_range([6, 33, 128, 257], [6, 128])\nHIDDEN_SIZES = get_ci_test_range([512, 1024, 1536, 3072], [512, 3072])\nEPS = 1e-6\n\n\ndef _tol(dtype: torch.dtype) -> tuple[float, float]:\n    if dtype == torch.float32:\n        return 1e-5, 1e-5\n    return 5e-2, 5e-2\n\n\ndef _make_modulation_tensors(batch_size: int, hidden_size: int, dtype: torch.dtype):\n    scale0 = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype)\n    shift0 = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype)\n    gate0 = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype)\n    scale1 = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype)\n    shift1 = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype)\n    gate1 = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype)\n    return scale0, shift0, gate0, scale1, shift1, gate1\n\n\ndef _baseline_select01_modulation(\n    x: torch.Tensor,\n    weight: torch.Tensor | None,\n    bias: torch.Tensor | None,\n    scale0: torch.Tensor,\n    shift0: torch.Tensor,\n    gate0: torch.Tensor,\n    scale1: torch.Tensor,\n    shift1: torch.Tensor,\n    gate1: torch.Tensor,\n    index: torch.Tensor,\n    eps: float,\n):\n    normalized = norm_infer(\n        x.view(-1, x.shape[-1]),\n        weight,\n        bias,\n        eps=eps,\n        is_rms_norm=False,\n    ).view_as(x)\n    
output, gate_out = fuse_scale_shift_gate_select01_kernel(\n        normalized,\n        scale0=scale0,\n        shift0=shift0,\n        gate0=gate0,\n        scale1=scale1,\n        shift1=shift1,\n        gate1=gate1,\n        index=index,\n    )\n    return output, gate_out\n\n\ndef _baseline_residual_select01_modulation(\n    x: torch.Tensor,\n    residual: torch.Tensor,\n    residual_gate: torch.Tensor,\n    weight: torch.Tensor | None,\n    bias: torch.Tensor | None,\n    scale0: torch.Tensor,\n    shift0: torch.Tensor,\n    gate0: torch.Tensor,\n    scale1: torch.Tensor,\n    shift1: torch.Tensor,\n    gate1: torch.Tensor,\n    index: torch.Tensor,\n    eps: float,\n):\n    residual_out = residual + residual_gate * x\n    normalized = norm_infer(\n        residual_out.view(-1, residual_out.shape[-1]),\n        weight,\n        bias,\n        eps=eps,\n        is_rms_norm=False,\n    ).view_as(residual_out)\n    output, gate_out = fuse_scale_shift_gate_select01_kernel(\n        normalized,\n        scale0=scale0,\n        shift0=shift0,\n        gate0=gate0,\n        scale1=scale1,\n        shift1=shift1,\n        gate1=gate1,\n        index=index,\n    )\n    return output, residual_out, gate_out\n\n\n@pytest.fixture(autouse=True)\ndef cuda_setup():\n    if not torch.cuda.is_available():\n        pytest.skip(\"CUDA required\")\n    torch.cuda.manual_seed(0)\n\n\n@pytest.mark.parametrize(\"dtype\", DTYPES)\n@pytest.mark.parametrize(\"batch_size\", BATCH_SIZES)\n@pytest.mark.parametrize(\"seq_len\", SEQ_LENS)\n@pytest.mark.parametrize(\"hidden_size\", HIDDEN_SIZES)\ndef test_fused_layernorm_scale_shift_gate_select01(\n    dtype, batch_size, seq_len, hidden_size\n):\n    x = torch.randn(batch_size, seq_len, hidden_size, device=DEVICE, dtype=dtype)\n    weight = torch.randn(hidden_size, device=DEVICE, dtype=dtype)\n    bias = torch.randn(hidden_size, device=DEVICE, dtype=dtype)\n    index = torch.randint(0, 2, (batch_size, seq_len), device=DEVICE, 
dtype=torch.int32)\n    scale0, shift0, gate0, scale1, shift1, gate1 = _make_modulation_tensors(\n        batch_size, hidden_size, dtype\n    )\n\n    out_ref, gate_ref = _baseline_select01_modulation(\n        x,\n        weight,\n        bias,\n        scale0,\n        shift0,\n        gate0,\n        scale1,\n        shift1,\n        gate1,\n        index,\n        EPS,\n    )\n    out_fused, gate_fused = fuse_layernorm_scale_shift_gate_select01_kernel(\n        x.contiguous(),\n        weight=weight,\n        bias=bias,\n        scale0=scale0,\n        shift0=shift0,\n        gate0=gate0,\n        scale1=scale1,\n        shift1=shift1,\n        gate1=gate1,\n        index=index,\n        eps=EPS,\n    )\n\n    atol, rtol = _tol(dtype)\n    triton.testing.assert_close(out_ref, out_fused, atol=atol, rtol=rtol)\n    triton.testing.assert_close(gate_ref, gate_fused, atol=atol, rtol=rtol)\n\n\n@pytest.mark.parametrize(\"dtype\", DTYPES)\n@pytest.mark.parametrize(\"batch_size\", BATCH_SIZES)\n@pytest.mark.parametrize(\"seq_len\", SEQ_LENS)\n@pytest.mark.parametrize(\"hidden_size\", HIDDEN_SIZES)\ndef test_fused_residual_layernorm_scale_shift_gate_select01(\n    dtype, batch_size, seq_len, hidden_size\n):\n    x = torch.randn(batch_size, seq_len, hidden_size, device=DEVICE, dtype=dtype)\n    residual = torch.randn_like(x)\n    residual_gate = torch.randn_like(x)\n    weight = torch.randn(hidden_size, device=DEVICE, dtype=dtype)\n    bias = torch.randn(hidden_size, device=DEVICE, dtype=dtype)\n    index = torch.randint(0, 2, (batch_size, seq_len), device=DEVICE, dtype=torch.int32)\n    scale0, shift0, gate0, scale1, shift1, gate1 = _make_modulation_tensors(\n        batch_size, hidden_size, dtype\n    )\n\n    out_ref, residual_ref, gate_ref = _baseline_residual_select01_modulation(\n        x,\n        residual,\n        residual_gate,\n        weight,\n        bias,\n        scale0,\n        shift0,\n        gate0,\n        scale1,\n        shift1,\n        gate1,\n  
      index,\n        EPS,\n    )\n    out_fused, residual_fused, gate_fused = (\n        fuse_residual_layernorm_scale_shift_gate_select01_kernel(\n            x.contiguous(),\n            residual=residual.contiguous(),\n            residual_gate=residual_gate.contiguous(),\n            weight=weight,\n            bias=bias,\n            scale0=scale0,\n            shift0=shift0,\n            gate0=gate0,\n            scale1=scale1,\n            shift1=shift1,\n            gate1=gate1,\n            index=index,\n            eps=EPS,\n        )\n    )\n\n    atol, rtol = _tol(dtype)\n    triton.testing.assert_close(out_ref, out_fused, atol=atol, rtol=rtol)\n    triton.testing.assert_close(residual_ref, residual_fused, atol=atol, rtol=rtol)\n    triton.testing.assert_close(gate_ref, gate_fused, atol=atol, rtol=rtol)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_renorm.py",
    "content": "# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/test_sampling.py\n# and /sgl-workspace/sglang/sgl-kernel/tests/test_sampling.py\n\nimport pytest\nimport sgl_kernel\nimport torch\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 99, 989])\n@pytest.mark.parametrize(\"vocab_size\", [111, 32000, 128256])\n@pytest.mark.parametrize(\"k\", [10, 100, 500])\ndef test_top_k_renorm_probs(batch_size, vocab_size, k):\n    \"\"\"Test top_k_renorm_probs kernel for correctness.\n\n    This test validates that the kernel correctly:\n    1. Identifies the top-k probabilities\n    2. Masks out non-top-k values\n    3. Renormalizes the remaining probabilities to sum to 1\n    \"\"\"\n    if k > vocab_size:\n        pytest.skip(\"k should be less than vocab_size\")\n    torch.manual_seed(42)\n    pre_norm_prob = torch.rand(batch_size, vocab_size, device=\"cuda:0\")\n    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)\n    sorted_prob, _ = torch.sort(normalized_prob, descending=True)\n    pivot = sorted_prob[:, k - 1]\n    mask = (normalized_prob >= pivot.unsqueeze(-1)).int()\n    renorm_prob_ground_truth = normalized_prob.clone()\n    renorm_prob_ground_truth[mask == 0] = 0\n    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(\n        dim=-1, keepdim=True\n    )\n\n    renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k)\n    for i in range(batch_size):\n        torch.testing.assert_close(\n            renorm_prob_ground_truth[i],\n            renorm_prob[i],\n            rtol=1e-3,\n            atol=1e-3,\n        )\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 99, 989])\n@pytest.mark.parametrize(\"vocab_size\", [111, 32000, 128256])\n@pytest.mark.parametrize(\"p\", [0.1, 0.5, 0.9])\ndef test_top_p_renorm_probs(batch_size, vocab_size, p):\n    \"\"\"Test top_p_renorm_probs kernel for correctness.\n\n    This test validates that the kernel correctly:\n    1. 
Computes the cumulative probability distribution\n    2. Identifies tokens in the top-p threshold\n    3. Masks out tokens outside the threshold\n    4. Renormalizes the remaining probabilities to sum to 1\n    \"\"\"\n    torch.manual_seed(42)\n    pre_norm_prob = torch.rand(batch_size, vocab_size, device=\"cuda:0\")\n    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)\n    sorted_prob, indices = torch.sort(normalized_prob, descending=False)\n    cdf = torch.cumsum(sorted_prob, dim=-1)\n    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device=\"cuda:0\")\n    mask.scatter_add_(1, indices, (cdf >= (1 - p)).int())\n    renorm_prob_ground_truth = normalized_prob.clone()\n    renorm_prob_ground_truth[mask == 0] = 0\n    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(\n        dim=-1, keepdim=True\n    )\n\n    renorm_prob = sgl_kernel.top_p_renorm_prob(normalized_prob, p)\n    torch.testing.assert_close(\n        renorm_prob_ground_truth,\n        renorm_prob,\n        rtol=1e-3,\n        atol=1e-3,\n    )\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 99, 989])\n@pytest.mark.parametrize(\"vocab_size\", [111, 32000, 128256])\n@pytest.mark.parametrize(\"k\", [10, 100, 500])\n@pytest.mark.parametrize(\"neginf_input\", [False, True])\ndef test_top_k_mask_logits(batch_size, vocab_size, k, neginf_input):\n    \"\"\"Test top_k_mask_logits kernel for correctness.\n\n    This test validates that the kernel correctly:\n    1. Identifies the top-k logits\n    2. Masks non-top-k values to -inf\n    3. Preserves the top-k values\n    4. 
Handles negative infinity inputs gracefully\n\n    The test verifies correctness by comparing softmax(top_k_mask_logits(logits))\n    with top_k_renorm_prob(probs), which should be equivalent.\n    \"\"\"\n    if k > vocab_size:\n        pytest.skip(\"k should be less than vocab_size\")\n    torch.manual_seed(42)\n    logits = torch.randn(batch_size, vocab_size, device=\"cuda:0\") * 5\n    if neginf_input:\n        # Randomly assign some logits to -inf to test edge cases\n        num_neginf = torch.randint(1, vocab_size * batch_size, (1,)).item()\n        idxs = torch.randperm(batch_size * vocab_size, device=\"cuda:0\")[:num_neginf]\n        logits[idxs // vocab_size, idxs % vocab_size] = -float(\"inf\")\n\n    probs = torch.softmax(logits, dim=-1)\n    masked_logits = sgl_kernel.top_k_mask_logits(logits, k)\n    renormed_probs = torch.softmax(masked_logits, dim=-1)\n    renormed_probs_ref = sgl_kernel.top_k_renorm_prob(probs, k)\n\n    torch.testing.assert_close(\n        renormed_probs,\n        renormed_probs_ref,\n        rtol=1e-3,\n        atol=1e-3,\n    )\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_rmsnorm.py",
    "content": "import itertools\n\nimport pytest\nimport torch\nimport triton\n\nfrom sglang.jit_kernel.utils import get_ci_test_range\n\n\ndef sglang_jit_rmsnorm(input: torch.Tensor, weight: torch.Tensor) -> None:\n    from sglang.jit_kernel.norm import rmsnorm\n\n    rmsnorm(input, weight, output=input)\n\n\ndef flashinfer_rmsnorm(input: torch.Tensor, weight: torch.Tensor) -> None:\n    from flashinfer.norm import rmsnorm\n\n    rmsnorm(input, weight, out=input)\n\n\nBS_LIST = [2**n for n in range(0, 14)]\nBS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)]\nBS_LIST = get_ci_test_range(BS_LIST, [1, 9, 256, 4109])\nHIDDEN_SIZE_LIST = get_ci_test_range(\n    [512, 1024, 1536, 2048, 3072, 4096, 5120, 6144, 7168, 8192],\n    [512, 2048, 8192],\n)\nDEVICE = \"cuda\"\nDTYPE = torch.bfloat16\n\n\n@pytest.mark.parametrize(\n    \"batch_size,hidden_size\", list(itertools.product(BS_LIST, HIDDEN_SIZE_LIST))\n)\ndef test_rmsnorm(batch_size: int, hidden_size: int) -> None:\n    input = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=DTYPE)\n    weight = torch.randn(hidden_size, device=DEVICE, dtype=DTYPE)\n    input_sglang = input.clone()\n    input_flashinfer = input.clone()\n    sglang_jit_rmsnorm(input_sglang, weight)\n    flashinfer_rmsnorm(input_flashinfer, weight)\n    triton.testing.assert_close(input_sglang, input_flashinfer, atol=1e-2, rtol=1e-2)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_rope.py",
    "content": "import pytest\nimport torch\nimport triton\n\nfrom sglang.jit_kernel.utils import get_ci_test_range\n\nDEVICE = \"cuda\"\nDTYPE = torch.bfloat16\nMAX_SEQ_LEN = 131072  # common seq length\nROPE_BASE = 10000.0\nCACHE_SIZE = 1024 * 128\n\n\ndef create_cos_sin_cache(\n    rotary_dim: int,\n    max_position: int = MAX_SEQ_LEN,\n    base: float = ROPE_BASE,\n) -> torch.Tensor:\n    \"\"\"Create cos/sin cache compatible with SGLang layout: [max_pos, rotary_dim].\"\"\"\n    inv_freq = 1.0 / (\n        base\n        ** (\n            torch.arange(0, rotary_dim, 2, dtype=torch.float32, device=DEVICE)\n            / rotary_dim\n        )\n    )\n    t = torch.arange(max_position, dtype=torch.float32, device=DEVICE)\n    freqs = torch.einsum(\"i,j->ij\", t, inv_freq)\n    cos = freqs.cos()\n    sin = freqs.sin()\n    cache = torch.cat((cos, sin), dim=-1)  # [max_pos, rotary_dim]\n    return cache\n\n\n# ---------------------------------------------------------------------------\n# Implementation wrappers\n# ---------------------------------------------------------------------------\n\n\ndef sglang_jit_rope(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    cos_sin_cache: torch.Tensor,\n    positions: torch.Tensor,\n    is_neox: bool,\n) -> None:\n    from sglang.jit_kernel.rope import apply_rope_inplace\n\n    apply_rope_inplace(q, k, cos_sin_cache, positions, is_neox=is_neox)\n\n\ndef flashinfer_rope(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    cos_sin_cache: torch.Tensor,\n    positions: torch.Tensor,\n    is_neox: bool,\n) -> None:\n    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace\n\n    head_size = q.shape[-1]\n    # flashinfer expects [nnz, num_heads * head_size]\n    q_2d = q.view(q.shape[0], -1)\n    k_2d = k.view(k.shape[0], -1)\n    apply_rope_with_cos_sin_cache_inplace(\n        positions=positions,\n        query=q_2d,\n        key=k_2d,\n        head_size=head_size,\n        cos_sin_cache=cos_sin_cache,\n        
is_neox=is_neox,\n    )\n\n\ndef torch_impl_rope(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    cos_sin_cache: torch.Tensor,\n    positions: torch.Tensor,\n    is_neox: bool,\n) -> None:\n    # TODO: implement a pure-PyTorch reference for extra coverage\n    pass\n\n\n# ---------------------------------------------------------------------------\n# Test parameters\n# ---------------------------------------------------------------------------\n\nBS_LIST = [2**x for x in range(12)]\nBS_LIST += [x + 1 for x in BS_LIST]  # odd sizes to stress non-aligned paths\nBS_LIST = get_ci_test_range(BS_LIST, [1, 129, 2048, 2049])\nNUM_KV_HEADS_LIST = get_ci_test_range([1, 2, 8], [1, 8])\nGQA_RATIO = get_ci_test_range([1, 4, 8], [1, 8])\nROPE_DIM_LIST = get_ci_test_range([64, 128, 256, 512], [64, 256])\nIS_NEOX_LIST = [False, True]\nDTYPE_LIST = get_ci_test_range(\n    [torch.bfloat16, torch.float16], [torch.bfloat16, torch.float16]\n)\nPARTIAL_ROPE_DIM_LIST = get_ci_test_range([64, 80, 96, 128], [64, 96])\nHEAD_DIM_LIST = get_ci_test_range([64, 128, 256], [64, 256])\n\n\n@pytest.mark.parametrize(\"batch_size\", BS_LIST)\n@pytest.mark.parametrize(\"gqa_ratio\", GQA_RATIO)\n@pytest.mark.parametrize(\"num_kv_heads\", NUM_KV_HEADS_LIST)\n@pytest.mark.parametrize(\"rope_dim\", ROPE_DIM_LIST)\n@pytest.mark.parametrize(\"is_neox\", IS_NEOX_LIST)\n@pytest.mark.parametrize(\"dtype\", DTYPE_LIST)\ndef test_rope(\n    batch_size: int,\n    gqa_ratio: int,\n    num_kv_heads: int,\n    rope_dim: int,\n    is_neox: bool,\n    dtype: torch.dtype,\n) -> None:\n    num_qo_heads = num_kv_heads * gqa_ratio\n    q = torch.randn(batch_size, num_qo_heads, rope_dim, device=DEVICE, dtype=dtype)\n    k = torch.randn(batch_size, num_kv_heads, rope_dim, device=DEVICE, dtype=dtype)\n    positions = torch.randint(\n        0, MAX_SEQ_LEN, (batch_size,), device=DEVICE, dtype=torch.int64\n    )\n    cos_sin_cache = create_cos_sin_cache(rope_dim)\n\n    q_fi, k_fi = q.clone(), k.clone()\n    q_jit, k_jit = 
q.clone(), k.clone()\n\n    flashinfer_rope(q_fi, k_fi, cos_sin_cache, positions, is_neox)\n    sglang_jit_rope(q_jit, k_jit, cos_sin_cache, positions, is_neox)\n\n    atol = rtol = 1e-2\n    triton.testing.assert_close(q_fi, q_jit, atol=atol, rtol=rtol)\n    triton.testing.assert_close(k_fi, k_jit, atol=atol, rtol=rtol)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.int32, torch.int64])\ndef test_rope_position_dtypes(dtype: torch.dtype) -> None:\n    \"\"\"Ensure both int32 and int64 position tensors work correctly.\"\"\"\n    batch_size, num_qo_heads, num_kv_heads, rope_dim = 16384, 16, 2, 128\n    is_neox = True\n\n    q = torch.randn(batch_size, num_qo_heads, rope_dim, device=DEVICE, dtype=DTYPE)\n    k = torch.randn(batch_size, num_kv_heads, rope_dim, device=DEVICE, dtype=DTYPE)\n    positions = torch.randint(0, MAX_SEQ_LEN, (batch_size,), device=DEVICE, dtype=dtype)\n    cos_sin_cache = create_cos_sin_cache(rope_dim)\n\n    q_fi, k_fi = q.clone(), k.clone()\n    q_jit, k_jit = q.clone(), k.clone()\n\n    flashinfer_rope(q_fi, k_fi, cos_sin_cache, positions.long(), is_neox)\n    sglang_jit_rope(q_jit, k_jit, cos_sin_cache, positions, is_neox)\n\n    atol = rtol = 1e-2\n    triton.testing.assert_close(q_fi, q_jit, atol=atol, rtol=rtol)\n    triton.testing.assert_close(k_fi, k_jit, atol=atol, rtol=rtol)\n\n\n@pytest.mark.parametrize(\"batch_size\", BS_LIST)\n@pytest.mark.parametrize(\"is_neox\", IS_NEOX_LIST)\n@pytest.mark.parametrize(\"rope_dim\", PARTIAL_ROPE_DIM_LIST)\n@pytest.mark.parametrize(\"head_dim\", HEAD_DIM_LIST)\ndef test_partial_rope(batch_size: int, is_neox: bool, rope_dim: int, head_dim: int):\n    if head_dim < rope_dim:\n        pytest.skip(\"Invalid config: head_dim must be >= rope_dim.\")\n    num_qo_heads, num_kv_heads = 8, 2\n\n    q = torch.randn(batch_size, num_qo_heads, head_dim, device=DEVICE, dtype=DTYPE)\n    k = torch.randn(batch_size, num_kv_heads, head_dim, device=DEVICE, dtype=DTYPE)\n    positions = torch.randint(0, MAX_SEQ_LEN, 
(batch_size,), device=DEVICE)\n    cos_sin_cache = create_cos_sin_cache(rope_dim)\n\n    q_fi, k_fi = q.clone(), k.clone()\n    q_jit, k_jit = q.clone(), k.clone()\n    rope = ..., slice(rope_dim)  # NOTE: flashinfer by default apply to first rope_dim\n\n    flashinfer_rope(q_fi, k_fi, cos_sin_cache, positions.long(), is_neox)\n    sglang_jit_rope(q_jit[rope], k_jit[rope], cos_sin_cache, positions, is_neox)\n\n    atol = rtol = 1e-2\n    triton.testing.assert_close(q_fi, q_jit, atol=atol, rtol=rtol)\n    triton.testing.assert_close(k_fi, k_jit, atol=atol, rtol=rtol)\n\n\n@pytest.mark.parametrize(\"batch_size\", BS_LIST)\n@pytest.mark.parametrize(\"gqa_ratio\", GQA_RATIO)\n@pytest.mark.parametrize(\"num_kv_heads\", NUM_KV_HEADS_LIST)\n@pytest.mark.parametrize(\"rope_dim\", ROPE_DIM_LIST)\n@pytest.mark.parametrize(\"is_neox\", IS_NEOX_LIST)\ndef test_fused_rope_store(\n    batch_size: int,\n    gqa_ratio: int,\n    num_kv_heads: int,\n    rope_dim: int,\n    is_neox: bool,\n) -> None:\n    \"\"\"Test fused RoPE + KV cache store against separate RoPE + manual store.\"\"\"\n    from sglang.jit_kernel.rope import apply_rope_inplace_with_kvcache\n\n    num_qo_heads = num_kv_heads * gqa_ratio\n    dtype = DTYPE\n\n    q = torch.randn(batch_size, num_qo_heads, rope_dim, device=DEVICE, dtype=dtype)\n    k = torch.randn(batch_size, num_kv_heads, rope_dim, device=DEVICE, dtype=dtype)\n    v = torch.randn(batch_size, num_kv_heads, rope_dim, device=DEVICE, dtype=dtype)\n    positions = torch.randint(\n        0, MAX_SEQ_LEN, (batch_size,), device=DEVICE, dtype=torch.int64\n    )\n    out_loc = torch.randperm(CACHE_SIZE, device=DEVICE, dtype=torch.int64)[:batch_size]\n    cos_sin_cache = create_cos_sin_cache(rope_dim)\n\n    row_size = num_kv_heads * rope_dim\n    k_cache_ref = torch.zeros(CACHE_SIZE, row_size, device=DEVICE, dtype=dtype)\n    v_cache_ref = torch.zeros(CACHE_SIZE, row_size, device=DEVICE, dtype=dtype)\n    k_cache_fused = torch.zeros(CACHE_SIZE, row_size, 
device=DEVICE, dtype=dtype)\n    v_cache_fused = torch.zeros(CACHE_SIZE, row_size, device=DEVICE, dtype=dtype)\n\n    # --- reference: separate RoPE then manual scatter ---\n    q_ref, k_ref = q.clone(), k.clone()\n    flashinfer_rope(q_ref, k_ref, cos_sin_cache, positions, is_neox)\n    k_cache_ref[out_loc] = k_ref.view(batch_size, -1)\n    v_cache_ref[out_loc] = v.view(batch_size, -1)\n\n    # --- fused kernel ---\n    q_fused, k_fused = q.clone(), k.clone()\n    v_fused = v.clone()\n    apply_rope_inplace_with_kvcache(\n        q_fused,\n        k_fused,\n        v_fused,\n        k_cache_fused,\n        v_cache_fused,\n        cos_sin_cache,\n        positions,\n        out_loc,\n        is_neox=is_neox,\n    )\n\n    atol = rtol = 1e-2\n    # q should match RoPE-only result\n    triton.testing.assert_close(q_ref, q_fused, atol=atol, rtol=rtol)\n    # k_cache should contain the rotated k\n    triton.testing.assert_close(\n        k_cache_ref[out_loc], k_cache_fused[out_loc], atol=atol, rtol=rtol\n    )\n    # v_cache should be an exact copy\n    assert torch.all(v_cache_ref[out_loc] == v_cache_fused[out_loc]), \"v_cache mismatch\"\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_store_cache.py",
    "content": "import itertools\n\nimport pytest\nimport torch\n\nfrom sglang.jit_kernel.kvcache import can_use_store_cache, store_cache\nfrom sglang.jit_kernel.utils import get_ci_test_range\n\nBS_LIST = [2**n for n in range(0, 15)]\nBS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)]\nBS_LIST = get_ci_test_range(BS_LIST, [1, 9, 256, 16399])\nHIDDEN_DIMS = get_ci_test_range(\n    [64, 128, 256, 512, 1024, 96, 98, 100], [64, 512, 1024, 98]\n)\nCACHE_SIZE = 1024 * 1024\nDTYPE = torch.bfloat16\nDEVICE = \"cuda\"\n\n\n@pytest.mark.parametrize(\n    \"batch_size,element_dim\",\n    list(itertools.product(BS_LIST, HIDDEN_DIMS)),\n)\ndef test_store_cache(batch_size: int, element_dim: int) -> None:\n    k = torch.randn((batch_size, element_dim), dtype=DTYPE, device=DEVICE)\n    v = torch.randn((batch_size, element_dim), dtype=DTYPE, device=DEVICE)\n    k_cache = torch.randn((CACHE_SIZE, element_dim), dtype=DTYPE, device=DEVICE)\n    v_cache = torch.randn((CACHE_SIZE, element_dim), dtype=DTYPE, device=DEVICE)\n    indices = torch.randperm(CACHE_SIZE, device=DEVICE)[:batch_size]\n\n    # JIT store cache (sglang.jit_kernel.kvcache)\n    store_cache(k, v, k_cache, v_cache, indices)\n\n    assert torch.all(k_cache[indices] == k)\n    assert torch.all(v_cache[indices] == v)\n\n\n# Smaller subset for targeted tests below\nREPR_BS = get_ci_test_range([1, 7, 128], [1, 128])\nREPR_DIMS = get_ci_test_range([64, 128, 512, 1024, 96], [64, 1024, 96])\nSMALL_CACHE = 4096\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float32])\n@pytest.mark.parametrize(\n    \"batch_size,element_dim\",\n    list(itertools.product(REPR_BS, REPR_DIMS)),\n)\ndef test_store_cache_dtypes(\n    batch_size: int, element_dim: int, dtype: torch.dtype\n) -> None:\n    k = torch.randn((batch_size, element_dim), dtype=dtype, device=DEVICE)\n    v = torch.randn((batch_size, element_dim), dtype=dtype, device=DEVICE)\n    k_cache = torch.randn((SMALL_CACHE, element_dim), dtype=dtype, device=DEVICE)\n    v_cache 
= torch.randn((SMALL_CACHE, element_dim), dtype=dtype, device=DEVICE)\n    indices = torch.randperm(SMALL_CACHE, device=DEVICE)[:batch_size]\n\n    store_cache(k, v, k_cache, v_cache, indices)\n\n    assert torch.all(k_cache[indices] == k)\n    assert torch.all(v_cache[indices] == v)\n\n\n@pytest.mark.parametrize(\n    \"batch_size,element_dim\",\n    list(itertools.product(REPR_BS, REPR_DIMS)),\n)\ndef test_store_cache_int32_indices(batch_size: int, element_dim: int) -> None:\n    k = torch.randn((batch_size, element_dim), dtype=DTYPE, device=DEVICE)\n    v = torch.randn((batch_size, element_dim), dtype=DTYPE, device=DEVICE)\n    k_cache = torch.randn((SMALL_CACHE, element_dim), dtype=DTYPE, device=DEVICE)\n    v_cache = torch.randn((SMALL_CACHE, element_dim), dtype=DTYPE, device=DEVICE)\n    # int32 indices exercise a different CUDA template instantiation than default int64\n    indices = torch.randperm(SMALL_CACHE, device=DEVICE)[:batch_size].to(torch.int32)\n\n    store_cache(k, v, k_cache, v_cache, indices)\n\n    assert torch.all(k_cache[indices.long()] == k)\n    assert torch.all(v_cache[indices.long()] == v)\n\n\ndef _valid_num_splits(element_dim: int, dtype: torch.dtype) -> list:\n    \"\"\"Return the list of valid num_split values for a given element_dim/dtype.\"\"\"\n    row_bytes = element_dim * dtype.itemsize\n    splits = [1]\n    if row_bytes % (2 * 128) == 0:\n        splits.append(2)\n    if row_bytes % (4 * 128) == 0:\n        splits.append(4)\n    return splits\n\n\n_NUM_SPLIT_CASES = [\n    (_dim, _ns, _dtype)\n    for _dtype in [torch.float16, torch.bfloat16, torch.float32]\n    for _dim in REPR_DIMS\n    for _ns in _valid_num_splits(_dim, _dtype)\n]\n\n\n@pytest.mark.parametrize(\"element_dim,num_split,dtype\", _NUM_SPLIT_CASES)\ndef test_store_cache_num_split(\n    element_dim: int, num_split: int, dtype: torch.dtype\n) -> None:\n    batch_size = 128\n    k = torch.randn((batch_size, element_dim), dtype=dtype, device=DEVICE)\n    v = 
torch.randn((batch_size, element_dim), dtype=dtype, device=DEVICE)\n    k_cache = torch.randn((SMALL_CACHE, element_dim), dtype=dtype, device=DEVICE)\n    v_cache = torch.randn((SMALL_CACHE, element_dim), dtype=dtype, device=DEVICE)\n    indices = torch.randperm(SMALL_CACHE, device=DEVICE)[:batch_size]\n\n    # Verify each num_split kernel path (1, 2, 4) produces correct results\n    store_cache(k, v, k_cache, v_cache, indices, num_split=num_split)\n\n    assert torch.all(k_cache[indices] == k)\n    assert torch.all(v_cache[indices] == v)\n\n\ndef test_can_use_store_cache() -> None:\n    assert can_use_store_cache(128)\n    assert can_use_store_cache(256)\n    assert can_use_store_cache(1024)\n    assert can_use_store_cache(2048)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/tests/test_timestep_embedding.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nimport torch\n\ntry:\n    import tabulate\nexcept Exception:\n    tabulate = None\n\nfrom sglang.jit_kernel.timestep_embedding import (\n    timestep_embedding as timestep_embedding_cuda,\n)\nfrom sglang.jit_kernel.utils import get_ci_test_range\n\nCORRECTNESS_BATCH_SIZES = get_ci_test_range(\n    [1, 2, 8, 128, 256, 512, 1536, 2048, 4096, 11008, 16384],\n    [1, 128, 2048, 16384],\n)\nCORRECTNESS_DIMS = get_ci_test_range(\n    [32, 128, 256, 512, 1536, 2048, 4096, 8192],\n    [32, 512, 8192],\n)\nDIFFUSERS_BATCH_SIZES = get_ci_test_range(\n    [1, 2, 8, 128, 256, 512, 1536, 2048, 16384],\n    [1, 512, 16384],\n)\nDIFFUSERS_DIMS = get_ci_test_range([32, 256, 512, 1536, 8192], [32, 512, 8192])\nDTYPES = get_ci_test_range(\n    [torch.float16, torch.bfloat16, torch.float32],\n    [torch.float16, torch.bfloat16],\n)\nSCALES = get_ci_test_range([1, 0.01], [1, 0.01])\n\n\ndef get_timestep_embedding_reference(\n    timesteps: torch.Tensor,\n    dim: int,\n    *,\n    flip_sin_to_cos: bool = False,\n    downscale_freq_shift: float = 1,\n    scale: float = 1,\n    max_period: int = 10000,\n):\n    assert len(timesteps.shape) == 1, \"Timesteps should be a 1d-array\"\n\n    timesteps = timesteps.to(torch.float32)\n    half_dim = dim // 2\n    exponent = -torch.log(\n        torch.tensor(max_period, dtype=torch.float32, device=timesteps.device)\n    ) * torch.arange(\n        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device\n    )\n    exponent = exponent / (half_dim - downscale_freq_shift)\n\n    emb = torch.exp(exponent)\n    emb = timesteps[:, None].float() * emb[None, :]\n\n    emb = scale * emb\n\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)\n    if flip_sin_to_cos:\n        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)\n    if dim % 2 == 1:\n        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))\n    return 
emb\n\n\n@pytest.mark.parametrize(\"batch_size\", CORRECTNESS_BATCH_SIZES)\n@pytest.mark.parametrize(\"dim\", CORRECTNESS_DIMS)\n@pytest.mark.parametrize(\"dtype\", DTYPES)\ndef test_timestep_embedding_correctness_with_sgld(batch_size, dim, dtype):\n    device = \"cuda\"\n    t = torch.randint(low=0, high=1000, size=(batch_size,), device=device).to(dtype)\n    torch_output = get_timestep_embedding_reference(\n        t, dim, flip_sin_to_cos=True, downscale_freq_shift=0\n    )\n    cuda_output = timestep_embedding_cuda(\n        t, dim, flip_sin_to_cos=True, downscale_freq_shift=0\n    )\n    torch.testing.assert_close(torch_output, cuda_output, atol=1e-3, rtol=1e-3)\n\n\n@pytest.mark.parametrize(\"batch_size\", DIFFUSERS_BATCH_SIZES)\n@pytest.mark.parametrize(\"dim\", DIFFUSERS_DIMS)\n@pytest.mark.parametrize(\"dtype\", DTYPES)\n@pytest.mark.parametrize(\"flip_sin_to_cos\", [False, True])\n@pytest.mark.parametrize(\"downscale_freq_shift\", [0, 1])\n@pytest.mark.parametrize(\"scale\", SCALES)\ndef test_timestep_embedding_correctness_with_diffusers(\n    batch_size, dim, flip_sin_to_cos, downscale_freq_shift, scale, dtype\n):\n    device = \"cuda\"\n    t = torch.randint(low=0, high=1000, size=(batch_size,), device=device).to(dtype)\n    torch_output = get_timestep_embedding_reference(\n        t,\n        dim,\n        flip_sin_to_cos=flip_sin_to_cos,\n        downscale_freq_shift=downscale_freq_shift,\n        scale=scale,\n        max_period=10000,\n    )\n    cuda_output = timestep_embedding_cuda(\n        t,\n        dim,\n        flip_sin_to_cos=flip_sin_to_cos,\n        downscale_freq_shift=downscale_freq_shift,\n        scale=scale,\n        max_period=10000,\n    )\n    torch.testing.assert_close(torch_output, cuda_output, atol=1e-3, rtol=1e-3)\n\n\ndef test_timestep_embedding_perf():\n    if os.environ.get(\"SGLANG_RUN_JIT_KERNEL_PERF_TESTS\") != \"1\":\n        pytest.skip(\"Perf test disabled by default\")\n    if tabulate is None:\n        
pytest.skip(\"Optional dependency 'tabulate' is not installed\")\n\n    NUM_BATCH = [1, 2, 8, 63, 256, 512, 613, 1024, 1536]\n    NUM_DIM = [32, 64, 128, 256, 512, 1024, 2048, 4096]\n\n    def perf_kernel_fn(kernel_fn: callable, *args, **kwargs):\n        warmup_times = 4\n        repeat_times = 20\n        start = torch.cuda.Event(enable_timing=True)\n        end = torch.cuda.Event(enable_timing=True)\n\n        for _ in range(warmup_times):\n            output_fn = kernel_fn(*args, **kwargs)\n        torch.cuda.synchronize()\n\n        start.record()\n        for _ in range(repeat_times):\n            output_fn = kernel_fn(*args, **kwargs)\n        end.record()\n        end.synchronize()\n        return start.elapsed_time(end) / repeat_times\n\n    device = \"cuda\"\n    results = []\n\n    cuda_speedups = []\n    for B in NUM_BATCH:\n        for dim in NUM_DIM:\n            t = torch.linspace(0, max(100000, B), steps=B, device=device).to(\n                torch.float32\n            )\n            time_torch = perf_kernel_fn(get_timestep_embedding_reference, t, dim)\n            time_cuda = perf_kernel_fn(timestep_embedding_cuda, t, dim)\n            speedup_cuda = time_torch / time_cuda\n\n            results.append(\n                {\n                    \"Batch Size\": B,\n                    \"Dimension\": dim,\n                    \"Torch Time (ms)\": time_torch,\n                    \"CUDA Time (ms)\": time_cuda,\n                    \"Speedup (CUDA)\": speedup_cuda,\n                }\n            )\n            cuda_speedups.append(speedup_cuda)\n\n    print(\"=== Timestep Embedding Benchmark Results ===\")\n    print(\n        tabulate.tabulate(\n            results,\n            headers=\"keys\",\n            tablefmt=\"fancy_grid\",\n            floatfmt=(\".0f\", \".0f\", \".6f\", \".6f\", \".5f\"),\n        )\n    )\n    print(f\"Average Speedup(cuda): {np.mean(cuda_speedups):.4f}\")\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, 
\"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/jit_kernel/timestep_embedding.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_timestep_embedding_module(dtype: torch.dtype) -> Module:\n    args = make_cpp_args(dtype)\n    return load_jit(\n        \"timestep_embedding\",\n        *args,\n        cuda_files=[\"diffusion/timestep_embedding.cuh\"],\n        cuda_wrappers=[(\"timestep_embedding\", f\"timestep_embedding<{args}>\")],\n    )\n\n\ndef timestep_embedding(\n    t: torch.Tensor,\n    dim: int,\n    flip_sin_to_cos: bool = False,\n    downscale_freq_shift: float = 0.0,\n    scale: float = 1,\n    max_period: int = 10000,\n    dtype: torch.dtype = torch.float32,\n) -> torch.Tensor:\n    if t.dtype not in (torch.float16, torch.bfloat16, torch.float32):\n        t = t.to(dtype)\n    output = torch.empty((t.shape[0], dim), dtype=torch.float32, device=t.device)\n    module = _jit_timestep_embedding_module(t.dtype)\n    module.timestep_embedding(\n        t,\n        output,\n        dim,\n        flip_sin_to_cos,\n        float(downscale_freq_shift),\n        float(scale),\n        int(max_period),\n    )\n    return output\n"
  },
  {
    "path": "python/sglang/jit_kernel/utils.py",
    "content": "from __future__ import annotations\n\nimport functools\nimport os\nimport pathlib\nfrom typing import TYPE_CHECKING, Any, Callable, List, Tuple, TypeAlias, TypeVar, Union\n\nimport torch\n\nif TYPE_CHECKING:\n    from tvm_ffi import Module\n\nF = TypeVar(\"F\", bound=Callable[..., Any])\n_FULL_TEST_ENV_VAR = \"SGLANG_JIT_KERNEL_RUN_FULL_TESTS\"\n\n\ndef is_in_ci() -> bool:\n    ci_env_vars = (\"SGLANG_IS_IN_CI\", \"CI\", \"GITHUB_ACTIONS\")\n    return any(os.getenv(env_var, \"false\").lower() == \"true\" for env_var in ci_env_vars)\n\n\ndef should_run_full_tests() -> bool:\n    return os.getenv(_FULL_TEST_ENV_VAR, \"false\").lower() == \"true\"\n\n\ndef get_ci_test_range(full_range: List[Any], ci_range: List[Any]) -> List[Any]:\n    if should_run_full_tests():\n        return full_range\n    return ci_range if is_in_ci() else full_range\n\n\ndef cache_once(fn: F) -> F:\n    \"\"\"\n    NOTE: `functools.lru_cache` is not compatible with `torch.compile`\n    So we manually implement a simple cache_once decorator to replace it.\n    \"\"\"\n    result_map = {}\n\n    @functools.wraps(fn)\n    def wrapper(*args, **kwargs):\n        key = (args, tuple(sorted(kwargs.items(), key=lambda x: x[0])))\n        if key not in result_map:\n            result_map[key] = fn(*args, **kwargs)\n        return result_map[key]\n\n    return wrapper  # type: ignore\n\n\ndef _make_wrapper(tup: Tuple[str, str]) -> str:\n    export_name, kernel_name = tup\n    return f\"TVM_FFI_DLL_EXPORT_TYPED_FUNC({export_name}, ({kernel_name}));\"\n\n\n@cache_once\ndef _resolve_kernel_path() -> pathlib.Path:\n    cur_dir = pathlib.Path(__file__).parent.resolve()\n\n    # first, try this directory structure\n    def _environment_install():\n        candidate = cur_dir.resolve()\n        if (candidate / \"include\").exists() and (candidate / \"csrc\").exists():\n            return candidate\n        return None\n\n    def _package_install():\n        # TODO: support find path by package\n 
       return None\n\n    path = _environment_install() or _package_install()\n    if path is None:\n        raise RuntimeError(\"Cannot find sglang.jit_kernel path\")\n    return path\n\n\nKERNEL_PATH = _resolve_kernel_path()\nDEFAULT_INCLUDE = [str(KERNEL_PATH / \"include\")]\nDEFAULT_CFLAGS = [\"-std=c++20\", \"-O3\"]\nDEFAULT_CUDA_CFLAGS = [\"-std=c++20\", \"-O3\", \"--expt-relaxed-constexpr\"]\nDEFAULT_HIP_CFLAGS = [\n    flag for flag in DEFAULT_CUDA_CFLAGS if flag != \"--expt-relaxed-constexpr\"\n]\nDEFAULT_LDFLAGS = []\nCPP_TEMPLATE_TYPE: TypeAlias = Union[int, float, bool, torch.dtype]\n\n\nclass CPPArgList(list[str]):\n    def __str__(self) -> str:\n        return \", \".join(self)\n\n\nCPP_DTYPE_MAP = {\n    torch.float: \"fp32_t\",\n    torch.float16: \"fp16_t\",\n    torch.float8_e4m3fn: \"fp8_e4m3_t\",\n    torch.bfloat16: \"bf16_t\",\n    torch.int8: \"int8_t\",\n    torch.int32: \"int32_t\",\n    torch.int64: \"int64_t\",\n}\n\n\n# AMD/ROCm note:\n@cache_once\ndef is_hip_runtime() -> bool:\n    return bool(torch.version.hip)\n\n\ndef make_cpp_args(*args: CPP_TEMPLATE_TYPE) -> CPPArgList:\n    def _convert(arg: CPP_TEMPLATE_TYPE) -> str:\n        if isinstance(arg, bool):\n            return \"true\" if arg else \"false\"\n        if isinstance(arg, (int, float)):\n            return str(arg)\n        if isinstance(arg, torch.dtype):\n            return CPP_DTYPE_MAP[arg]\n        raise TypeError(f\"Unsupported argument type for cpp template: {type(arg)}\")\n\n    return CPPArgList(_convert(arg) for arg in args)\n\n\ndef load_jit(\n    *args: str,\n    cpp_files: List[str] | None = None,\n    cuda_files: List[str] | None = None,\n    cpp_wrappers: List[Tuple[str, str]] | None = None,\n    cuda_wrappers: List[Tuple[str, str]] | None = None,\n    extra_cflags: List[str] | None = None,\n    extra_cuda_cflags: List[str] | None = None,\n    extra_ldflags: List[str] | None = None,\n    extra_include_paths: List[str] | None = None,\n    build_directory: str 
| None = None,\n) -> Module:\n    \"\"\"\n    Load a JIT module from C++/CUDA source files.\n    We define a wrapper as a tuple of (export_name, kernel_name),\n    where `export_name` is the name used to call the kernel from Python,\n    and `kernel_name` is the name of the kernel class in C++/CUDA source.\n\n    :param args: Unique marker of the JIT module. Must be distinct for different kernels.\n    :type args: str\n    :param cpp_files: A list of C++ source files.\n    :type cpp_files: List[str] | None\n    :param cuda_files: A list of CUDA source files.\n    :type cuda_files: List[str] | None\n    :param cpp_wrappers: A list of C++ wrappers, defining the export name and kernel name.\n    :type cpp_wrappers: List[Tuple[str, str]] | None\n    :param cuda_wrappers: A list of CUDA wrappers, defining the export name and kernel name.\n    :type cuda_wrappers: List[Tuple[str, str]] | None\n    :param extra_cflags: Extra C++ compiler flags.\n    :type extra_cflags: List[str] | None\n    :param extra_cuda_cflags: Extra CUDA compiler flags.\n    :type extra_cuda_cflags: List[str] | None\n    :param extra_ldflags: Extra linker flags.\n    :type extra_ldflags: List[str] | None\n    :param extra_include_paths: Extra include paths.\n    :type extra_include_paths: List[str] | None\n    :param build_directory: The build directory for JIT compilation.\n    :type build_directory: str | None\n    :return: A just-in-time (JIT) compiled module.\n    :rtype: Module\n    \"\"\"\n\n    from tvm_ffi.cpp import load_inline\n\n    cpp_files = cpp_files or []\n    cuda_files = cuda_files or []\n    cpp_wrappers = cpp_wrappers or []\n    cuda_wrappers = cuda_wrappers or []\n    extra_cflags = extra_cflags or []\n    extra_cuda_cflags = extra_cuda_cflags or []\n    extra_ldflags = extra_ldflags or []\n    extra_include_paths = extra_include_paths or []\n\n    # include cpp files\n    cpp_paths = [(KERNEL_PATH / \"csrc\" / f).resolve() for f in cpp_files]\n    cpp_sources = [f'#include \"{path}\"' 
for path in cpp_paths]\n    cpp_sources += [_make_wrapper(tup) for tup in cpp_wrappers]\n\n    # include cuda files\n    cuda_paths = [(KERNEL_PATH / \"csrc\" / f).resolve() for f in cuda_files]\n    cuda_sources = [f'#include \"{path}\"' for path in cuda_paths]\n    cuda_sources += [_make_wrapper(tup) for tup in cuda_wrappers]\n\n    # Set TVM_FFI_CUDA_ARCH_LIST to the current device's arch only if the user has not already set it.\n    env_key = \"TVM_FFI_CUDA_ARCH_LIST\"\n    env_existed = env_key in os.environ\n    selected_cuda_cflags = DEFAULT_CUDA_CFLAGS\n    if is_hip_runtime():\n        selected_cuda_cflags = DEFAULT_HIP_CFLAGS\n        extra_cuda_cflags = [\"-DUSE_ROCM\"] + extra_cuda_cflags\n    else:\n        extra_cuda_cflags = [\n            f\"-DSGL_CUDA_ARCH={_get_cuda_arch_value()}\"\n        ] + extra_cuda_cflags\n    if not env_existed:\n        os.environ[env_key] = _get_cuda_arch_list()\n    try:\n        return load_inline(\n            \"sgl_kernel_jit_\" + \"_\".join(str(arg) for arg in args),\n            cpp_sources=cpp_sources,\n            cuda_sources=cuda_sources,\n            extra_cflags=DEFAULT_CFLAGS + extra_cflags,\n            extra_cuda_cflags=selected_cuda_cflags + extra_cuda_cflags,\n            extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags,\n            extra_include_paths=DEFAULT_INCLUDE + extra_include_paths,\n            build_directory=build_directory,\n        )\n    finally:\n        # Restore the environment: unset TVM_FFI_CUDA_ARCH_LIST if we set it above.\n        if not env_existed:\n            del os.environ[env_key]\n\n\n@cache_once\ndef is_arch_support_pdl() -> bool:\n    \"\"\"Return True if the current device supports PDL (SM90+).\"\"\"\n    device = torch.cuda.current_device()\n    return torch.cuda.get_device_capability(device)[0] >= 9\n\n\n@cache_once\ndef _get_cuda_arch_value() -> int:\n    \"\"\"Get CUDA arch value for -DSGL_CUDA_ARCH (e.g. 
900 for SM 9.0).\"\"\"\n    device = torch.cuda.current_device()\n    major, minor = torch.cuda.get_device_capability(device)\n    return major * 100 + minor * 10\n\n\n@cache_once\ndef _get_cuda_arch_list() -> str:\n    \"\"\"Get the correct CUDA architecture string for TVM_FFI_CUDA_ARCH_LIST.\"\"\"\n    device = torch.cuda.current_device()\n    major, minor = torch.cuda.get_device_capability(device)\n    return f\"{major}.{minor}\"\n"
  },
  {
    "path": "python/sglang/lang/api.py",
    "content": "\"\"\"Public APIs of the language.\"\"\"\n\nimport re\nfrom typing import Callable, List, Optional, Union\n\nfrom sglang.global_config import global_config\nfrom sglang.lang.backend.base_backend import BaseBackend\nfrom sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized\nfrom sglang.lang.ir import (\n    SglExpr,\n    SglExprList,\n    SglFunction,\n    SglGen,\n    SglImage,\n    SglRoleBegin,\n    SglRoleEnd,\n    SglSelect,\n    SglSeparateReasoning,\n    SglVideo,\n)\n\n\ndef function(\n    func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None\n):\n    if func:\n        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)\n\n    def decorator(func):\n        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)\n\n    return decorator\n\n\ndef Runtime(*args, **kwargs):\n    # Avoid importing unnecessary dependency\n    from sglang.lang.backend.runtime_endpoint import Runtime\n\n    return Runtime(*args, **kwargs)\n\n\ndef Engine(*args, **kwargs):\n    # Avoid importing unnecessary dependency\n    from sglang.srt.entrypoints.engine import Engine\n\n    return Engine(*args, **kwargs)\n\n\ndef set_default_backend(backend: BaseBackend):\n    global_config.default_backend = backend\n\n\ndef flush_cache(backend: Optional[BaseBackend] = None):\n    backend = backend or global_config.default_backend\n    if backend is None:\n        return False\n\n    # If backend is Runtime\n    if hasattr(backend, \"endpoint\"):\n        backend = backend.endpoint\n    return backend.flush_cache()\n\n\ndef get_server_info(backend: Optional[BaseBackend] = None):\n    backend = backend or global_config.default_backend\n    if backend is None:\n        return None\n\n    # If backend is Runtime\n    if hasattr(backend, \"endpoint\"):\n        backend = backend.endpoint\n    return backend.get_server_info()\n\n\ndef gen(\n    name: Optional[str] = None,\n    max_tokens: Optional[int] = None,\n    
min_tokens: Optional[int] = None,\n    n: Optional[int] = None,\n    stop: Optional[Union[str, List[str]]] = None,\n    stop_token_ids: Optional[List[int]] = None,\n    stop_regex: Optional[Union[str, List[str]]] = None,\n    temperature: Optional[float] = None,\n    top_p: Optional[float] = None,\n    top_k: Optional[int] = None,\n    min_p: Optional[float] = None,\n    frequency_penalty: Optional[float] = None,\n    presence_penalty: Optional[float] = None,\n    ignore_eos: Optional[bool] = None,\n    return_logprob: Optional[bool] = None,\n    logprob_start_len: Optional[int] = None,\n    top_logprobs_num: Optional[int] = None,\n    return_text_in_logprobs: Optional[bool] = None,\n    dtype: Optional[Union[type, str]] = None,\n    choices: Optional[List[str]] = None,\n    choices_method: Optional[ChoicesSamplingMethod] = None,\n    regex: Optional[str] = None,\n    json_schema: Optional[str] = None,\n):\n    \"\"\"Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md\"\"\"\n\n    if choices:\n        return SglSelect(\n            name,\n            choices,\n            0.0 if temperature is None else temperature,\n            token_length_normalized if choices_method is None else choices_method,\n        )\n\n    # Validate the regex early; re.compile raises re.error on an invalid pattern.\n    if regex is not None:\n        re.compile(regex)\n\n    return SglGen(\n        name,\n        max_tokens,\n        min_tokens,\n        n,\n        stop,\n        stop_token_ids,\n        stop_regex,\n        temperature,\n        top_p,\n        top_k,\n        min_p,\n        frequency_penalty,\n        presence_penalty,\n        ignore_eos,\n        return_logprob,\n        logprob_start_len,\n        top_logprobs_num,\n        return_text_in_logprobs,\n        dtype,\n        regex,\n        json_schema,\n    )\n\n\ndef gen_int(\n    name: Optional[str] = None,\n    max_tokens: Optional[int] = None,\n    n: 
Optional[int] = None,\n    stop: Optional[Union[str, List[str]]] = None,\n    stop_token_ids: Optional[List[int]] = None,\n    stop_regex: Optional[Union[str, List[str]]] = None,\n    temperature: Optional[float] = None,\n    top_p: Optional[float] = None,\n    top_k: Optional[int] = None,\n    min_p: Optional[float] = None,\n    frequency_penalty: Optional[float] = None,\n    presence_penalty: Optional[float] = None,\n    ignore_eos: Optional[bool] = None,\n    return_logprob: Optional[bool] = None,\n    logprob_start_len: Optional[int] = None,\n    top_logprobs_num: Optional[int] = None,\n    return_text_in_logprobs: Optional[bool] = None,\n):\n    return SglGen(\n        name,\n        max_tokens,\n        None,\n        n,\n        stop,\n        stop_token_ids,\n        stop_regex,\n        temperature,\n        top_p,\n        top_k,\n        min_p,\n        frequency_penalty,\n        presence_penalty,\n        ignore_eos,\n        return_logprob,\n        logprob_start_len,\n        top_logprobs_num,\n        return_text_in_logprobs,\n        int,\n        None,\n    )\n\n\ndef gen_string(\n    name: Optional[str] = None,\n    max_tokens: Optional[int] = None,\n    n: Optional[int] = None,\n    stop: Optional[Union[str, List[str]]] = None,\n    stop_token_ids: Optional[List[int]] = None,\n    stop_regex: Optional[Union[str, List[str]]] = None,\n    temperature: Optional[float] = None,\n    top_p: Optional[float] = None,\n    top_k: Optional[int] = None,\n    min_p: Optional[float] = None,\n    frequency_penalty: Optional[float] = None,\n    presence_penalty: Optional[float] = None,\n    ignore_eos: Optional[bool] = None,\n    return_logprob: Optional[bool] = None,\n    logprob_start_len: Optional[int] = None,\n    top_logprobs_num: Optional[int] = None,\n    return_text_in_logprobs: Optional[bool] = None,\n):\n    return SglGen(\n        name,\n        max_tokens,\n        None,\n        n,\n        stop,\n        stop_token_ids,\n        stop_regex,\n      
  temperature,\n        top_p,\n        top_k,\n        min_p,\n        frequency_penalty,\n        presence_penalty,\n        ignore_eos,\n        return_logprob,\n        logprob_start_len,\n        top_logprobs_num,\n        return_text_in_logprobs,\n        str,\n        None,\n    )\n\n\ndef image(expr: SglExpr):\n    return SglImage(expr)\n\n\ndef video(path: str, num_frames: int):\n    return SglVideo(path, num_frames)\n\n\ndef select(\n    name: Optional[str] = None,\n    choices: Optional[List[str]] = None,\n    temperature: float = 0.0,\n    choices_method: ChoicesSamplingMethod = token_length_normalized,\n):\n    assert choices is not None\n    return SglSelect(name, choices, temperature, choices_method)\n\n\ndef _role_common(name: str, expr: Optional[SglExpr] = None):\n    if expr is None:\n        return SglExprList([SglRoleBegin(name), SglRoleEnd(name)])\n    else:\n        return SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])\n\n\ndef system(expr: Optional[SglExpr] = None):\n    return _role_common(\"system\", expr)\n\n\ndef user(expr: Optional[SglExpr] = None):\n    return _role_common(\"user\", expr)\n\n\ndef assistant(expr: Optional[SglExpr] = None):\n    return _role_common(\"assistant\", expr)\n\n\ndef system_begin():\n    return SglRoleBegin(\"system\")\n\n\ndef system_end():\n    return SglRoleEnd(\"system\")\n\n\ndef user_begin():\n    return SglRoleBegin(\"user\")\n\n\ndef user_end():\n    return SglRoleEnd(\"user\")\n\n\ndef assistant_begin():\n    return SglRoleBegin(\"assistant\")\n\n\ndef assistant_end():\n    return SglRoleEnd(\"assistant\")\n\n\ndef separate_reasoning(\n    expr: Optional[SglExpr] = None, model_type: Optional[str] = None\n):\n    return SglExprList([expr, SglSeparateReasoning(model_type, expr=expr)])\n"
  },
  {
    "path": "python/sglang/lang/backend/anthropic.py",
    "content": "from sglang.lang.backend.base_backend import BaseBackend\nfrom sglang.lang.chat_template import get_chat_template\nfrom sglang.lang.interpreter import StreamExecutor\nfrom sglang.lang.ir import SglSamplingParams\n\ntry:\n    import anthropic\nexcept ImportError as e:\n    anthropic = e\n\n\nclass Anthropic(BaseBackend):\n    def __init__(self, model_name, *args, **kwargs):\n        super().__init__()\n\n        if isinstance(anthropic, Exception):\n            raise anthropic\n\n        self.model_name = model_name\n        self.chat_template = get_chat_template(\"claude\")\n        self.client = anthropic.Anthropic(*args, **kwargs)\n\n    def get_chat_template(self):\n        return self.chat_template\n\n    def generate(\n        self,\n        s: StreamExecutor,\n        sampling_params: SglSamplingParams,\n    ):\n        if s.messages_:\n            # Copy so that pop(0) below does not mutate the executor's message history.\n            messages = list(s.messages_)\n        else:\n            messages = [{\"role\": \"user\", \"content\": s.text_}]\n\n        if messages and messages[0][\"role\"] == \"system\":\n            system = messages.pop(0)[\"content\"]\n        else:\n            system = \"\"\n\n        ret = self.client.messages.create(\n            model=self.model_name,\n            system=system,\n            messages=messages,\n            **sampling_params.to_anthropic_kwargs(),\n        )\n        comp = ret.content[0].text\n\n        return comp, {}\n\n    def generate_stream(\n        self,\n        s: StreamExecutor,\n        sampling_params: SglSamplingParams,\n    ):\n        if s.messages_:\n            # Copy so that pop(0) below does not mutate the executor's message history.\n            messages = list(s.messages_)\n        else:\n            messages = [{\"role\": \"user\", \"content\": s.text_}]\n\n        if messages and messages[0][\"role\"] == \"system\":\n            system = messages.pop(0)[\"content\"]\n        else:\n            system = \"\"\n\n        with self.client.messages.stream(\n            model=self.model_name,\n            system=system,\n            messages=messages,\n            
**sampling_params.to_anthropic_kwargs(),\n        ) as stream:\n            for text in stream.text_stream:\n                yield text, {}\n"
  },
  {
    "path": "python/sglang/lang/backend/base_backend.py",
    "content": "from typing import List, Optional, Union\n\nfrom sglang.lang.chat_template import get_chat_template\nfrom sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod\nfrom sglang.lang.interpreter import StreamExecutor\nfrom sglang.lang.ir import SglSamplingParams\n\n\nclass BaseBackend:\n    def __init__(self) -> None:\n        self.support_concate_and_append = False\n        self.chat_template = get_chat_template(\"default\")\n\n    def get_model_name(self):\n        raise NotImplementedError()\n\n    def get_chat_template(self):\n        return self.chat_template\n\n    def cache_prefix(self, prefix_str: str):\n        pass\n\n    def uncache_prefix(self, rid: str):\n        pass\n\n    def end_request(self, rid: Union[str, List[str]]):\n        pass\n\n    def begin_program(self, s: StreamExecutor):\n        pass\n\n    def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):\n        pass\n\n    def commit_lazy_operations(self, s: StreamExecutor):\n        pass\n\n    def fork_program(\n        self,\n        src: StreamExecutor,\n        dst: List[StreamExecutor],\n        position_ids_offset: Optional[List[int]] = None,\n    ):\n        pass\n\n    def fill_image(self, s: StreamExecutor):\n        pass\n\n    def generate(\n        self,\n        s: StreamExecutor,\n        sampling_params: SglSamplingParams,\n    ):\n        raise NotImplementedError()\n\n    def generate_stream(\n        self,\n        s: StreamExecutor,\n        sampling_params: SglSamplingParams,\n    ):\n        raise NotImplementedError()\n\n    def select(\n        self,\n        s: StreamExecutor,\n        choices: List[str],\n        temperature: float,\n        choices_method: Optional[ChoicesSamplingMethod] = None,\n    ) -> ChoicesDecision:\n        raise NotImplementedError()\n\n    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):\n        raise NotImplementedError()\n\n    def shutdown(self):\n        pass\n\n    def 
flush_cache(self):\n        pass\n\n    def get_server_info(self):\n        pass\n"
  },
  {
    "path": "python/sglang/lang/backend/litellm.py",
    "content": "from typing import Mapping, Optional\n\nfrom sglang.lang.backend.base_backend import BaseBackend\nfrom sglang.lang.chat_template import get_chat_template_by_model_path\nfrom sglang.lang.interpreter import StreamExecutor\nfrom sglang.lang.ir import SglSamplingParams\n\ntry:\n    import litellm\nexcept ImportError as e:\n    litellm = e\n    litellm.num_retries = 1\n\n\nclass LiteLLM(BaseBackend):\n    def __init__(\n        self,\n        model_name,\n        chat_template=None,\n        api_key=None,\n        organization: Optional[str] = None,\n        base_url: Optional[str] = None,\n        timeout: Optional[float] = 600,\n        max_retries: Optional[int] = litellm.num_retries,\n        default_headers: Optional[Mapping[str, str]] = None,\n    ):\n        super().__init__()\n\n        if isinstance(litellm, Exception):\n            raise litellm\n\n        self.model_name = model_name\n\n        self.chat_template = chat_template or get_chat_template_by_model_path(\n            model_name\n        )\n\n        self.client_params = {\n            \"api_key\": api_key,\n            \"organization\": organization,\n            \"base_url\": base_url,\n            \"timeout\": timeout,\n            \"max_retries\": max_retries,\n            \"default_headers\": default_headers,\n        }\n\n    def get_chat_template(self):\n        return self.chat_template\n\n    def generate(\n        self,\n        s: StreamExecutor,\n        sampling_params: SglSamplingParams,\n    ):\n        if s.messages_:\n            messages = s.messages_\n        else:\n            messages = [{\"role\": \"user\", \"content\": s.text_}]\n\n        ret = litellm.completion(\n            model=self.model_name,\n            messages=messages,\n            **self.client_params,\n            **sampling_params.to_litellm_kwargs(),\n        )\n        comp = ret.choices[0].message.content\n\n        return comp, {}\n\n    def generate_stream(\n        self,\n        s: 
StreamExecutor,\n        sampling_params: SglSamplingParams,\n    ):\n        if s.messages_:\n            messages = s.messages_\n        else:\n            messages = [{\"role\": \"user\", \"content\": s.text_}]\n\n        ret = litellm.completion(\n            model=self.model_name,\n            messages=messages,\n            stream=True,\n            **self.client_params,\n            **sampling_params.to_litellm_kwargs(),\n        )\n        for chunk in ret:\n            text = chunk.choices[0].delta.content\n            if text is not None:\n                yield text, {}\n"
  },
  {
    "path": "python/sglang/lang/backend/openai.py",
    "content": "import dataclasses\nimport logging\nimport time\nimport warnings\nfrom typing import List, Optional, Union\n\nimport numpy as np\n\nfrom sglang.lang.backend.base_backend import BaseBackend\nfrom sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path\nfrom sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod\nfrom sglang.lang.interpreter import StreamExecutor\nfrom sglang.lang.ir import SglSamplingParams\n\ntry:\n    import openai\n    import tiktoken\nexcept ImportError as e:\n    openai = tiktoken = e\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef create_logit_bias_int(tokenizer):\n    \"\"\"Get logit bias for integer numbers.\"\"\"\n    int_token_ids = []\n\n    tokens = tokenizer._mergeable_ranks\n    for token, token_id in tokens.items():\n        s = tokenizer.decode([token_id])\n        if all([c.isdigit() for c in s]) or s in [\" \"]:\n            int_token_ids.append(token_id)\n            if len(int_token_ids) >= 300:  # OpenAI API limit\n                break\n    special_tokens = tokenizer._special_tokens\n    mask = {t: 100 for t in int_token_ids[:299]}\n    mask[special_tokens[\"<|endoftext|>\"]] = 100\n    return mask\n\n\nINSTRUCT_MODEL_NAMES = [\n    \"gpt-3.5-turbo-instruct\",\n]\n\n\n@dataclasses.dataclass\nclass TokenUsage:\n    prompt_tokens: int\n    completion_tokens: int\n\n    def reset(self):\n        self.prompt_tokens = self.completion_tokens = 0\n\n\nclass OpenAI(BaseBackend):\n    def __init__(\n        self,\n        model_name: str,\n        is_chat_model: Optional[bool] = None,\n        chat_template: Optional[ChatTemplate] = None,\n        is_azure: bool = False,\n        *args,\n        **kwargs,\n    ):\n        super().__init__()\n\n        if isinstance(openai, Exception):\n            raise openai\n\n        if is_azure:\n            self.client = openai.AzureOpenAI(*args, **kwargs)\n        else:\n            self.client = openai.OpenAI(*args, **kwargs)\n\n        
self.model_name = model_name\n        try:\n            self.tokenizer = tiktoken.encoding_for_model(model_name)\n        except KeyError:\n            self.tokenizer = tiktoken.get_encoding(\"cl100k_base\")\n        self.logit_bias_int = create_logit_bias_int(self.tokenizer)\n\n        self.chat_template = chat_template or get_chat_template_by_model_path(\n            model_name\n        )\n\n        if is_chat_model is not None:\n            self.is_chat_model = is_chat_model\n        else:\n            if model_name in INSTRUCT_MODEL_NAMES:\n                self.is_chat_model = False\n            else:\n                self.is_chat_model = True\n\n        self.chat_prefix = self.chat_template.role_prefix_and_suffix[\"assistant\"][0]\n\n        # Usage\n        self.token_usage = TokenUsage(0, 0)\n\n        # API speculative execution\n        # TODO(ying): This does not support multi-threading (run_batch)\n        self.spec_kwargs = {}\n        self.spec_format = []\n        self.spec_max_num_tries = 3\n\n    def get_chat_template(self):\n        return self.chat_template\n\n    def _prepare_spec_execution(\n        self,\n        sampling_params: SglSamplingParams,\n        num_api_spec_tokens: int,\n        spec_var_name: str,\n    ):\n        if \"max_tokens\" not in self.spec_kwargs:\n            self.spec_kwargs[\"max_tokens\"] = num_api_spec_tokens\n        else:\n            assert self.spec_kwargs[\"max_tokens\"] == num_api_spec_tokens\n\n        params = sampling_params.to_openai_kwargs()\n        for key, value in params.items():\n            if key in [\"stop\"]:\n                continue\n            if key in [\"max_tokens\"]:\n                warnings.warn(\n                    \"The parameter max_tokens will be overwritten by speculated number of tokens.\"\n                )\n                continue\n            if key not in self.spec_kwargs:\n                self.spec_kwargs[key] = value\n            else:\n                assert (\n            
        value == self.spec_kwargs[key]\n                ), \"sampling parameters must be consistent when api speculative execution is enabled.\"\n        self.spec_format.append(\n            {\"text\": \"\", \"stop\": params[\"stop\"], \"name\": spec_var_name}\n        )\n        return \"\", {}\n\n    def generate(\n        self,\n        s: StreamExecutor,\n        sampling_params: SglSamplingParams,\n        spec_var_name: Optional[str] = None,\n    ):\n        if sampling_params.dtype is None:\n            if self.is_chat_model:\n                if s.num_api_spec_tokens is None:\n                    if not s.text_.endswith(self.chat_prefix):\n                        raise RuntimeError(\n                            \"This use case is not supported if api speculative execution is off. \"\n                            \"For OpenAI chat models, sgl.gen must be right after sgl.assistant. \"\n                            \"Example of adding api speculative execution: @function(num_api_spec_tokens=128).\"\n                        )\n                    prompt = s.messages_\n                else:\n                    return self._prepare_spec_execution(\n                        sampling_params, s.num_api_spec_tokens, spec_var_name\n                    )\n            else:\n                prompt = s.text_\n\n            kwargs = sampling_params.to_openai_kwargs()\n            if (\n                self.model_name.startswith(\"o1\")\n                or self.model_name.startswith(\"o3\")\n                or \"o1\" in self.model_name\n            ):\n                kwargs.pop(\"max_tokens\", None)\n            else:\n                kwargs.pop(\"max_completion_tokens\", None)\n\n            comp = openai_completion(\n                client=self.client,\n                token_usage=self.token_usage,\n                is_chat=self.is_chat_model,\n                model=self.model_name,\n                prompt=prompt,\n                **kwargs,\n            )\n            # Keep the 
returned list (or string) as is.\n        elif sampling_params.dtype in [str, \"str\", \"string\"]:\n            assert (\n                not self.is_chat_model\n            ), \"constrained type not supported on chat model\"\n            kwargs = sampling_params.to_openai_kwargs()\n            kwargs.pop(\"stop\")\n            comp = openai_completion(\n                client=self.client,\n                token_usage=self.token_usage,\n                is_chat=self.is_chat_model,\n                model=self.model_name,\n                prompt=s.text_ + '\"',\n                stop='\"',\n                **kwargs,\n            )\n            # Wrap each element in quotes if we have a list.\n            if isinstance(comp, list):\n                comp = ['\"' + x + '\"' for x in comp]\n            else:\n                comp = '\"' + comp + '\"'\n        elif sampling_params.dtype in [int, \"int\"]:\n            assert (\n                not self.is_chat_model\n            ), \"constrained type not supported on chat model\"\n            kwargs = sampling_params.to_openai_kwargs()\n            kwargs.pop(\"stop\")\n            comp = openai_completion(\n                client=self.client,\n                token_usage=self.token_usage,\n                is_chat=self.is_chat_model,\n                model=self.model_name,\n                prompt=s.text_,\n                logit_bias=self.logit_bias_int,\n                stop=[\" \"],\n                **kwargs,\n            )\n            # Leave as a list if that's what is returned.\n        else:\n            raise ValueError(f\"Unknown dtype: {sampling_params.dtype}\")\n\n        return comp, {}\n\n    def spec_fill(self, value: str):\n        assert self.is_chat_model\n        self.spec_format.append({\"text\": value, \"stop\": None, \"name\": None})\n\n    def spec_pattern_match(self, comp):\n        for i, term in enumerate(self.spec_format):\n            text = term[\"text\"]\n            if text != \"\":\n           
     if comp.startswith(text):\n                    comp = comp[len(text) :]\n                else:\n                    return False\n            else:\n                pos = comp.find(term[\"stop\"])\n                if pos != -1:\n                    term[\"text\"] = comp[:pos]\n                    comp = comp[pos:]\n                else:\n                    if i == len(self.spec_format) - 1:\n                        term[\"text\"] = comp\n                    else:\n                        return False\n        return True\n\n    def role_end_generate(\n        self,\n        s: StreamExecutor,\n    ):\n        if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):\n            return\n\n        comp = \"\"\n        if not all(x[\"name\"] is None for x in self.spec_format):\n            # TODO(ying): throw errors or warnings\n            for i in range(self.spec_max_num_tries):\n                comp = openai_completion(\n                    client=self.client,\n                    token_usage=self.token_usage,\n                    is_chat=self.is_chat_model,\n                    model=self.model_name,\n                    prompt=s.messages_,\n                    **self.spec_kwargs,\n                )\n                # Use a string for pattern matching.\n                comp_for_match = comp[0] if isinstance(comp, list) else comp\n                if self.spec_pattern_match(comp_for_match):\n                    break\n\n        for term in self.spec_format:\n            s.text_ += term[\"text\"]\n            name = term[\"name\"]\n            if name is not None:\n                s.variables[name] = term[\"text\"]\n                s.meta_info[name] = {}\n                s.variable_event[name].set()\n\n        self.spec_kwargs = {}\n        self.spec_format = []\n\n    def generate_stream(\n        self,\n        s: StreamExecutor,\n        sampling_params: SglSamplingParams,\n    ):\n        if sampling_params.dtype is None:\n            if 
self.is_chat_model:\n                if not s.text_.endswith(self.chat_prefix):\n                    raise RuntimeError(\n                        \"This use case is not supported. \"\n                        \"For OpenAI chat models, sgl.gen must be right after sgl.assistant\"\n                    )\n                prompt = s.messages_\n            else:\n                prompt = s.text_\n\n            kwargs = sampling_params.to_openai_kwargs()\n            generator = openai_completion_stream(\n                client=self.client,\n                token_usage=self.token_usage,\n                is_chat=self.is_chat_model,\n                model=self.model_name,\n                prompt=prompt,\n                **kwargs,\n            )\n            return generator\n        else:\n            raise ValueError(f\"Unknown dtype: {sampling_params.dtype}\")\n\n    def select(\n        self,\n        s: StreamExecutor,\n        choices: List[str],\n        temperature: float,\n        choices_method: ChoicesSamplingMethod,\n    ) -> ChoicesDecision:\n        \"\"\"Note: `choices_method` is not used by the OpenAI backend.\"\"\"\n        if self.is_chat_model:\n            raise NotImplementedError(\n                \"select/choices is not supported for chat models. 
\"\n                \"Please try to use a non-chat model such as gpt-3.5-turbo-instruct\"\n            )\n\n        n_choices = len(choices)\n        token_ids = [self.tokenizer.encode(x) for x in choices]\n        scores = [0] * n_choices\n        valid = [len(x) > 0 for x in token_ids]\n        prompt_tokens = self.tokenizer.encode(s.text_)\n\n        max_len = max([len(x) for x in token_ids])\n        for step in range(max_len):\n            # Build logit bias\n            logit_bias = {}\n            for i in range(n_choices):\n                if valid[i]:\n                    logit_bias[token_ids[i][step]] = 100\n\n            # Call API\n            ret = self.client.completions.create(\n                model=self.model_name,\n                prompt=prompt_tokens,\n                logit_bias=logit_bias,\n                max_tokens=1,\n                temperature=temperature,\n            )\n            ret_str = ret.choices[0].text\n            ret_token = self.tokenizer.encode(ret_str)[0]\n            self.token_usage.prompt_tokens += ret.usage.prompt_tokens\n            self.token_usage.completion_tokens += ret.usage.completion_tokens\n\n            # TODO:\n            # 1. return logits as the scores\n            # 2. compute logits of the full choice\n            # 3. 
consider chunk-based decoding\n\n            # Update valid\n            hit = False\n            for i in range(n_choices):\n                if valid[i]:\n                    if step == len(token_ids[i]) - 1:\n                        valid[i] = False\n\n                    if ret_token == token_ids[i][step]:\n                        scores[i] += 1\n                        hit = True\n                    else:\n                        valid[i] = False\n            assert hit\n\n            if np.sum(valid) <= 1:\n                break\n\n            prompt_tokens.append(ret_token)\n\n        return ChoicesDecision(\n            decision=choices[np.argmax(scores)],\n            meta_info={\"scores\": scores},\n        )\n\n\ndef openai_completion(\n    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs\n) -> Union[str, List[str]]:\n    # if \"ebnf\" is in kwargs, warn and remove\n    if \"ebnf\" in kwargs:\n        warnings.warn(\"EBNF is not officially supported by OpenAI endpoints. 
Ignoring.\")\n        del kwargs[\"ebnf\"]\n\n    for attempt in range(retries):\n        try:\n            if is_chat:\n                if \"stop\" in kwargs and kwargs[\"stop\"] is None:\n                    kwargs.pop(\"stop\")\n                ret = client.chat.completions.create(messages=prompt, **kwargs)\n                if len(ret.choices) == 1:\n                    comp = ret.choices[0].message.content\n                else:\n                    comp = [c.message.content for c in ret.choices]\n            else:\n                ret = client.completions.create(prompt=prompt, **kwargs)\n                if isinstance(prompt, (list, tuple)):\n                    comp = [c.text for c in ret.choices]\n                else:\n                    comp = ret.choices[0].text\n                    if len(ret.choices) > 1:\n                        comp = [c.text for c in ret.choices]\n\n            token_usage.prompt_tokens += ret.usage.prompt_tokens\n            token_usage.completion_tokens += ret.usage.completion_tokens\n            break\n        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:\n            logger.error(f\"OpenAI Error: {e}. Waiting 5 seconds...\")\n            time.sleep(5)\n            if attempt == retries - 1:\n                raise e\n        except Exception as e:\n            logger.error(f\"RuntimeError {e}.\")\n            raise e\n\n    return comp\n\n\ndef openai_completion_stream(\n    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs\n):\n    # if \"ebnf\" is in kwargs, warn and remove\n    if \"ebnf\" in kwargs:\n        warnings.warn(\"EBNF is not officially supported by OpenAI endpoints. 
Ignoring.\")\n        del kwargs[\"ebnf\"]\n\n    for attempt in range(retries):\n        try:\n            if is_chat:\n                if \"stop\" in kwargs and kwargs[\"stop\"] is None:\n                    kwargs.pop(\"stop\")\n                generator = client.chat.completions.create(\n                    messages=prompt,\n                    stream=True,\n                    stream_options={\"include_usage\": True},\n                    **kwargs,\n                )\n                for ret in generator:\n                    if len(ret.choices) == 0:\n                        continue\n                    try:\n                        content = ret.choices[0].delta.content\n                    except IndexError:\n                        content = None\n                    yield content or \"\", {}\n            else:\n                generator = client.completions.create(\n                    prompt=prompt,\n                    stream=True,\n                    stream_options={\"include_usage\": True},\n                    **kwargs,\n                )\n                for ret in generator:\n                    if len(ret.choices) == 0:\n                        continue\n                    content = ret.choices[0].text\n                    yield content or \"\", {}\n\n            token_usage.prompt_tokens += ret.usage.prompt_tokens\n            token_usage.completion_tokens += ret.usage.completion_tokens\n            break\n        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:\n            logger.error(f\"OpenAI Error: {e}. Waiting 5 seconds...\")\n            time.sleep(5)\n            if attempt == retries - 1:\n                raise e\n        except Exception as e:\n            logger.error(f\"RuntimeError {e}.\")\n            raise e\n"
  },
  {
    "path": "python/sglang/lang/backend/runtime_endpoint.py",
    "content": "import atexit\nimport json\nimport multiprocessing\nimport time\nimport warnings\nfrom typing import Dict, List, Optional, Union\n\nimport aiohttp\nimport requests\n\nfrom sglang.global_config import global_config\nfrom sglang.lang.backend.base_backend import BaseBackend\nfrom sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path\nfrom sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod\nfrom sglang.lang.interpreter import StreamExecutor\nfrom sglang.lang.ir import (\n    REGEX_BOOL,\n    REGEX_FLOAT,\n    REGEX_INT,\n    REGEX_STR,\n    SglSamplingParams,\n)\nfrom sglang.utils import http_request\n\n\nclass RuntimeEndpoint(BaseBackend):\n    def __init__(\n        self,\n        base_url: str,\n        api_key: Optional[str] = None,\n        verify: Optional[str] = None,\n        chat_template_name: Optional[str] = None,\n    ):\n        super().__init__()\n        self.support_concate_and_append = True\n\n        self.base_url = base_url\n        self.api_key = api_key\n        self.verify = verify\n\n        res = http_request(\n            self.base_url + \"/get_model_info\",\n            api_key=self.api_key,\n            verify=self.verify,\n        )\n        self._assert_success(res)\n        self.model_info = res.json()\n\n        if chat_template_name:\n            self.chat_template = get_chat_template(chat_template_name)\n        else:\n            self.chat_template = get_chat_template_by_model_path(\n                self.model_info[\"model_path\"]\n            )\n\n    def get_model_name(self):\n        return self.model_info[\"model_path\"]\n\n    def flush_cache(self):\n        res = http_request(\n            self.base_url + \"/flush_cache\",\n            api_key=self.api_key,\n            verify=self.verify,\n            method=\"POST\",\n        )\n        self._assert_success(res)\n\n    def get_server_info(self):\n        res = http_request(\n            self.base_url + 
\"/get_server_info\",\n            api_key=self.api_key,\n            verify=self.verify,\n        )\n        self._assert_success(res)\n        return res.json()\n\n    def get_chat_template(self):\n        return self.chat_template\n\n    def cache_prefix(self, prefix_str: str):\n        res = http_request(\n            self.base_url + \"/generate\",\n            json={\"text\": prefix_str, \"sampling_params\": {\"max_new_tokens\": 0}},\n            api_key=self.api_key,\n            verify=self.verify,\n        )\n        self._assert_success(res)\n\n    def start_profile(self):\n        res = http_request(\n            self.base_url + \"/start_profile\",\n            api_key=self.api_key,\n            verify=self.verify,\n        )\n        self._assert_success(res)\n\n    def stop_profile(self):\n        res = http_request(\n            self.base_url + \"/stop_profile\",\n            api_key=self.api_key,\n            verify=self.verify,\n        )\n        self._assert_success(res)\n\n    def commit_lazy_operations(self, s: StreamExecutor):\n        data = {\"text\": s.text_, \"sampling_params\": {\"max_new_tokens\": 0}}\n        self._add_images(s, data)\n        res = http_request(\n            self.base_url + \"/generate\",\n            json=data,\n            api_key=self.api_key,\n            verify=self.verify,\n        )\n        self._assert_success(res)\n\n    def fill_image(self, s: StreamExecutor):\n        data = {\"text\": s.text_, \"sampling_params\": {\"max_new_tokens\": 0}}\n        self._add_images(s, data)\n        res = http_request(\n            self.base_url + \"/generate\",\n            json=data,\n            api_key=self.api_key,\n            verify=self.verify,\n        )\n        self._assert_success(res)\n\n    def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):\n        if sampling_params.dtype is None:\n            return\n\n        if sampling_params.stop == ():\n            sampling_params.stop = []\n\n        
dtype_regex = None\n        if sampling_params.dtype in [\"int\", int]:\n\n            dtype_regex = REGEX_INT\n            sampling_params.stop.extend([\" \", \"\\n\"])\n        elif sampling_params.dtype in [\"float\", float]:\n\n            dtype_regex = REGEX_FLOAT\n            sampling_params.stop.extend([\" \", \"\\n\"])\n        elif sampling_params.dtype in [\"str\", str]:\n\n            dtype_regex = REGEX_STR\n        elif sampling_params.dtype in [\"bool\", bool]:\n\n            dtype_regex = REGEX_BOOL\n        else:\n            raise RuntimeError(f\"Invalid dtype: {sampling_params.dtype}\")\n\n        if dtype_regex is not None and sampling_params.regex is not None:\n            warnings.warn(\n                f\"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}\"\n            )\n\n        sampling_params.regex = dtype_regex\n\n    def generate(\n        self,\n        s: StreamExecutor,\n        sampling_params: SglSamplingParams,\n    ):\n        self._handle_dtype_to_regex(sampling_params)\n        data = {\n            \"text\": s.text_,\n            \"sampling_params\": {\n                \"skip_special_tokens\": global_config.skip_special_tokens_in_output,\n                \"spaces_between_special_tokens\": global_config.spaces_between_special_tokens_in_out,\n                **sampling_params.to_srt_kwargs(),\n            },\n        }\n\n        for item in [\n            \"return_logprob\",\n            \"logprob_start_len\",\n            \"top_logprobs_num\",\n            \"return_text_in_logprobs\",\n        ]:\n            value = getattr(sampling_params, item, None)\n            if value is not None:\n                data[item] = value\n\n        self._add_images(s, data)\n\n        res = http_request(\n            self.base_url + \"/generate\",\n            json=data,\n            api_key=self.api_key,\n            verify=self.verify,\n        )\n        
self._assert_success(res)\n\n        obj = res.json()\n        comp = obj[\"text\"]\n        return comp, obj[\"meta_info\"]\n\n    def generate_stream(\n        self,\n        s: StreamExecutor,\n        sampling_params: SglSamplingParams,\n    ):\n        self._handle_dtype_to_regex(sampling_params)\n\n        data = {\n            \"text\": s.text_,\n            \"sampling_params\": {\n                \"skip_special_tokens\": global_config.skip_special_tokens_in_output,\n                \"spaces_between_special_tokens\": global_config.spaces_between_special_tokens_in_out,\n                **sampling_params.to_srt_kwargs(),\n            },\n        }\n\n        for item in [\n            \"return_logprob\",\n            \"logprob_start_len\",\n            \"top_logprobs_num\",\n            \"return_text_in_logprobs\",\n        ]:\n            value = getattr(sampling_params, item, None)\n            if value is not None:\n                data[item] = value\n\n        data[\"stream\"] = True\n        self._add_images(s, data)\n\n        res = http_request(\n            self.base_url + \"/generate\",\n            json=data,\n            stream=True,\n            api_key=self.api_key,\n            verify=self.verify,\n        )\n        self._assert_success(res)\n        pos = 0\n\n        for chunk in res.iter_lines(decode_unicode=False):\n            chunk = chunk.decode(\"utf-8\")\n            if chunk and chunk.startswith(\"data:\"):\n                if chunk == \"data: [DONE]\":\n                    break\n                data = json.loads(chunk[5:].strip(\"\\n\"))\n                chunk_text = data[\"text\"][pos:]\n                meta_info = data[\"meta_info\"]\n                pos += len(chunk_text)\n                yield chunk_text, meta_info\n\n    def select(\n        self,\n        s: StreamExecutor,\n        choices: List[str],\n        temperature: float,\n        choices_method: ChoicesSamplingMethod,\n    ) -> ChoicesDecision:\n        assert 
temperature <= 1e-5\n\n        # Cache common prefix\n        data = {\"text\": s.text_, \"sampling_params\": {\"max_new_tokens\": 0}}\n        obj = self._generate_http_request(s, data)\n        prompt_len = obj[\"meta_info\"][\"prompt_tokens\"]\n        logprob_start_len = max(prompt_len - 2, 0)  # For token healing\n\n        # Compute logprob\n        data = {\n            \"text\": [s.text_ + c for c in choices],\n            \"sampling_params\": {\n                \"max_new_tokens\": 0,\n                \"temperature\": 0,\n            },\n            \"return_logprob\": True,\n            \"return_text_in_logprobs\": True,\n            \"logprob_start_len\": logprob_start_len,\n        }\n        obj = self._generate_http_request(s, data)\n\n        input_token_logprobs = [r[\"meta_info\"][\"input_token_logprobs\"] for r in obj]\n        output_token_logprobs = [r[\"meta_info\"][\"output_token_logprobs\"] for r in obj]\n        normalized_prompt_logprobs = [\n            compute_normalized_prompt_logprobs(r[\"meta_info\"][\"input_token_logprobs\"])\n            for r in obj\n        ]\n\n        # Remove extra token if no token healing occurred\n        for i in range(len(input_token_logprobs)):\n            healed_token_str = input_token_logprobs[i][0][-1]\n            if s.text_.endswith(healed_token_str):\n                healed_token_logprob = input_token_logprobs[i][0][0]\n                normalized_prompt_logprobs[i] = (\n                    normalized_prompt_logprobs[i] * len(input_token_logprobs[i])\n                    - healed_token_logprob\n                ) / (len(input_token_logprobs[i]) - 1)\n                input_token_logprobs[i] = input_token_logprobs[i][1:]\n\n        # Compute unconditional logprobs if required\n        if choices_method.requires_unconditional_logprobs:\n            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]\n            data = {\n                \"input_ids\": input_ids,\n                
\"sampling_params\": {\"max_new_tokens\": 0},\n                \"return_logprob\": True,\n            }\n            obj = self._generate_http_request(s, data)\n            unconditional_token_logprobs = [\n                r[\"meta_info\"][\"input_token_logprobs\"] for r in obj\n            ]\n        else:\n            unconditional_token_logprobs = None\n\n        return choices_method(\n            choices=choices,\n            normalized_prompt_logprobs=normalized_prompt_logprobs,\n            input_token_logprobs=input_token_logprobs,\n            output_token_logprobs=output_token_logprobs,\n            unconditional_token_logprobs=unconditional_token_logprobs,\n        )\n\n    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):\n        res = http_request(\n            self.base_url + \"/concate_and_append_request\",\n            json={\"src_rids\": src_rids, \"dst_rid\": dst_rid},\n            api_key=self.api_key,\n            verify=self.verify,\n        )\n        self._assert_success(res)\n\n    def _generate_http_request(self, s: StreamExecutor, data):\n        self._add_images(s, data)\n        res = http_request(\n            self.base_url + \"/generate\",\n            json=data,\n            api_key=self.api_key,\n            verify=self.verify,\n        )\n        self._assert_success(res)\n        return res.json()\n\n    def _add_images(self, s: StreamExecutor, data):\n        if s.images_:\n            assert len(s.images_) == 1, \"Only support one image.\"\n            data[\"image_data\"] = s.images_[0][1]\n\n    def _assert_success(self, res):\n        if res.status_code != 200:\n            try:\n                content = res.json()\n            except json.JSONDecodeError:\n                content = res.text\n            raise RuntimeError(content)\n\n\ndef compute_normalized_prompt_logprobs(input_logprobs):\n    values = [x[0] for x in input_logprobs if x[0]]\n    return sum(values) / len(values)\n\n\nclass Runtime:\n    
\"\"\"\n    A wrapper for the HTTP server.\n    This is used for launching the server in a python program without\n    using the command line interface.\n\n    It is mainly used for the frontend language.\n    You should use the Engine class if you want to do normal offline processing without the frontend language.\n    \"\"\"\n\n    def __init__(\n        self,\n        log_level: str = \"error\",\n        launch_timeout: float = 300.0,\n        *args,\n        **kwargs,\n    ):\n        \"\"\"See the arguments in server_args.py::ServerArgs\n\n        Args:\n            log_level: Log level for the server.\n            launch_timeout: Timeout in seconds for waiting for the server to start.\n            *args: Additional arguments passed to ServerArgs.\n            **kwargs: Additional keyword arguments passed to ServerArgs.\n        \"\"\"\n        # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run\n        # client code without installing SRT server and its dependencies if they want.\n        from sglang.srt.entrypoints.http_server import launch_server\n        from sglang.srt.server_args import ServerArgs\n        from sglang.srt.utils.network import is_port_available\n\n        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)\n\n        # Pre-allocate ports\n        for port in range(self.server_args.port, 40000):\n            if is_port_available(port):\n                break\n        self.server_args.port = port\n\n        self.url = self.server_args.url()\n        self.generate_url = self.url + \"/generate\"\n\n        # NOTE: We store pid instead of proc to fix some issues during __del__\n        self.pid = None\n\n        ctx = multiprocessing.get_context(\"spawn\")\n        proc = ctx.Process(\n            target=launch_server,\n            args=(self.server_args,),\n        )\n        proc.start()\n        self.pid = proc.pid\n\n        # Before the Python program terminates, call shutdown implicitly. 
Therefore, users don't have to explicitly call .shutdown()\n        atexit.register(self.shutdown)\n\n        # Wait for server to be ready by polling /health_generate\n        start_time = time.time()\n        with requests.Session() as session:\n            while time.time() - start_time < launch_timeout:\n                try:\n                    response = session.get(f\"{self.url}/health_generate\")\n                    if response.status_code == 200:\n                        break\n                except requests.RequestException:\n                    pass\n\n                if not proc.is_alive():\n                    self.shutdown()\n                    raise RuntimeError(\n                        \"Initialization failed. Please see the error messages above.\"\n                    )\n\n                time.sleep(2)\n            else:\n                self.shutdown()\n                raise TimeoutError(\"Server failed to start within the timeout period.\")\n\n        self.endpoint = RuntimeEndpoint(self.url)\n\n    def shutdown(self):\n        from sglang.srt.utils import kill_process_tree\n\n        if self.pid is not None:\n            kill_process_tree(self.pid)\n            self.pid = None\n\n    def start_profile(self):\n        self.endpoint.start_profile()\n\n    def stop_profile(self):\n        self.endpoint.stop_profile()\n\n    def cache_prefix(self, prefix: str):\n        self.endpoint.cache_prefix(prefix)\n\n    def get_tokenizer(self):\n        from sglang.srt.utils.hf_transformers_utils import get_tokenizer\n\n        return get_tokenizer(\n            self.server_args.tokenizer_path,\n            tokenizer_mode=self.server_args.tokenizer_mode,\n            trust_remote_code=self.server_args.trust_remote_code,\n            revision=self.server_args.revision,\n        )\n\n    async def async_generate(\n        self,\n        prompt: str,\n        sampling_params: Optional[Dict] = None,\n    ):\n        if self.server_args.skip_tokenizer_init:\n 
           json_data = {\n                \"input_ids\": prompt,\n                \"sampling_params\": sampling_params,\n                \"stream\": True,\n            }\n        else:\n            json_data = {\n                \"text\": prompt,\n                \"sampling_params\": sampling_params,\n                \"stream\": True,\n            }\n        pos = 0\n\n        timeout = aiohttp.ClientTimeout(total=3 * 3600)\n        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:\n            async with session.post(self.generate_url, json=json_data) as response:\n                async for chunk, _ in response.content.iter_chunks():\n                    chunk = chunk.decode(\"utf-8\")\n                    if chunk and chunk.startswith(\"data:\"):\n                        if chunk == \"data: [DONE]\\n\\n\":\n                            break\n                        data = json.loads(chunk[5:].strip(\"\\n\"))\n                        if \"text\" in data:\n                            cur = data[\"text\"][pos:]\n                            if cur:\n                                yield cur\n                            pos += len(cur)\n                        else:\n                            yield data\n\n    add_request = async_generate\n\n    def generate(\n        self,\n        prompt: Union[str, List[str]],\n        sampling_params: Optional[Dict] = None,\n        return_logprob: Optional[Union[List[bool], bool]] = False,\n        logprob_start_len: Optional[Union[List[int], int]] = None,\n        top_logprobs_num: Optional[Union[List[int], int]] = None,\n        lora_path: Optional[List[Optional[str]]] = None,\n    ):\n        json_data = {\n            \"text\": prompt,\n            \"sampling_params\": sampling_params,\n            \"return_logprob\": return_logprob,\n            \"logprob_start_len\": logprob_start_len,\n            \"top_logprobs_num\": top_logprobs_num,\n            \"lora_path\": lora_path,\n        }\n       
 assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)\n        response = requests.post(\n            self.url + \"/generate\",\n            json=json_data,\n        )\n        return json.dumps(response.json())\n\n    def encode(\n        self,\n        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],\n    ):\n        json_data = {\"text\": prompt}\n        response = requests.post(self.url + \"/encode\", json=json_data)\n        return json.dumps(response.json())\n\n    async def get_server_info(self):\n        async with aiohttp.ClientSession() as session:\n            async with session.get(f\"{self.url}/get_server_info\") as response:\n                if response.status == 200:\n                    return await response.json()\n                else:\n                    error_data = await response.json()\n                    raise RuntimeError(\n                        f\"Failed to get server info. {error_data['error']['message']}\"\n                    )\n\n    def __del__(self):\n        self.shutdown()\n"
  },
  {
    "path": "python/sglang/lang/backend/vertexai.py",
    "content": "import os\nimport warnings\n\nfrom sglang.lang.backend.base_backend import BaseBackend\nfrom sglang.lang.chat_template import get_chat_template\nfrom sglang.lang.interpreter import StreamExecutor\nfrom sglang.lang.ir import SglSamplingParams\n\ntry:\n    import vertexai\n    from vertexai.preview.generative_models import (\n        GenerationConfig,\n        GenerativeModel,\n        Image,\n    )\nexcept ImportError as e:\n    GenerativeModel = e\n\n\nclass VertexAI(BaseBackend):\n    def __init__(self, model_name, safety_settings=None):\n        super().__init__()\n\n        if isinstance(GenerativeModel, Exception):\n            raise GenerativeModel\n\n        project_id = os.environ[\"GCP_PROJECT_ID\"]\n        location = os.environ.get(\"GCP_LOCATION\")\n        vertexai.init(project=project_id, location=location)\n\n        self.model_name = model_name\n        self.chat_template = get_chat_template(\"default\")\n        self.safety_settings = safety_settings\n\n    def get_chat_template(self):\n        return self.chat_template\n\n    def generate(\n        self,\n        s: StreamExecutor,\n        sampling_params: SglSamplingParams,\n    ):\n        if s.messages_:\n            prompt = self.messages_to_vertexai_input(s.messages_)\n        else:\n            # single-turn\n            prompt = (\n                self.text_to_vertexai_input(s.text_, s.cur_images)\n                if s.cur_images\n                else s.text_\n            )\n        ret = GenerativeModel(self.model_name).generate_content(\n            prompt,\n            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),\n            safety_settings=self.safety_settings,\n        )\n\n        comp = ret.text\n\n        return comp, {}\n\n    def generate_stream(\n        self,\n        s: StreamExecutor,\n        sampling_params: SglSamplingParams,\n    ):\n        if s.messages_:\n            prompt = self.messages_to_vertexai_input(s.messages_)\n  
      else:\n            # single-turn\n            prompt = (\n                self.text_to_vertexai_input(s.text_, s.cur_images)\n                if s.cur_images\n                else s.text_\n            )\n        generator = GenerativeModel(self.model_name).generate_content(\n            prompt,\n            stream=True,\n            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),\n            safety_settings=self.safety_settings,\n        )\n        for ret in generator:\n            yield ret.text, {}\n\n    def text_to_vertexai_input(self, text, images):\n        input = []\n        # split with image token\n        text_segs = text.split(self.chat_template.image_token)\n        for image_path, image_base64_data in images:\n            text_seg = text_segs.pop(0)\n            if text_seg != \"\":\n                input.append(text_seg)\n            input.append(Image.from_bytes(image_base64_data))\n        text_seg = text_segs.pop(0)\n        if text_seg != \"\":\n            input.append(text_seg)\n        return input\n\n    def messages_to_vertexai_input(self, messages):\n        vertexai_message = []\n        # from openai message format to vertexai message format\n        for msg in messages:\n            if isinstance(msg[\"content\"], str):\n                text = msg[\"content\"]\n            else:\n                text = msg[\"content\"][0][\"text\"]\n\n            if msg[\"role\"] == \"system\":\n                warnings.warn(\"Warning: system prompt is not supported in VertexAI.\")\n                vertexai_message.append(\n                    {\n                        \"role\": \"user\",\n                        \"parts\": [{\"text\": \"System prompt: \" + text}],\n                    }\n                )\n                vertexai_message.append(\n                    {\n                        \"role\": \"model\",\n                        \"parts\": [{\"text\": \"Understood.\"}],\n                    }\n              
  )\n                continue\n            if msg[\"role\"] == \"user\":\n                vertexai_msg = {\n                    \"role\": \"user\",\n                    \"parts\": [{\"text\": text}],\n                }\n            elif msg[\"role\"] == \"assistant\":\n                vertexai_msg = {\n                    \"role\": \"model\",\n                    \"parts\": [{\"text\": text}],\n                }\n\n            # images\n            if isinstance(msg[\"content\"], list) and len(msg[\"content\"]) > 1:\n                for image in msg[\"content\"][1:]:\n                    assert image[\"type\"] == \"image_url\"\n                    vertexai_msg[\"parts\"].append(\n                        {\n                            \"inline_data\": {\n                                \"data\": image[\"image_url\"][\"url\"].split(\",\")[1],\n                                \"mime_type\": \"image/jpeg\",\n                            }\n                        }\n                    )\n\n            vertexai_message.append(vertexai_msg)\n        return vertexai_message\n"
  },
  {
    "path": "python/sglang/lang/chat_template.py",
    "content": "import re\nfrom dataclasses import dataclass\nfrom enum import Enum, auto\nfrom typing import Callable, Dict, List, Tuple\n\n\nclass ChatTemplateStyle(Enum):\n    PLAIN = auto()\n    LLAMA2 = auto()\n\n\n@dataclass\nclass ChatTemplate:\n    name: str\n    default_system_prompt: str\n    role_prefix_and_suffix: Dict[str, Tuple[str, str]]\n    stop_str: List[str] = ()\n    image_token: str = \"<image>\"\n    audio_token: str = \"<audio>\"\n    style: ChatTemplateStyle = ChatTemplateStyle.PLAIN\n\n    def get_prefix_and_suffix(\n        self, role: str, hist_messages: List[Dict]\n    ) -> Tuple[str, str]:\n        prefix, suffix = self.role_prefix_and_suffix.get(role, (\"\", \"\"))\n\n        if self.style == ChatTemplateStyle.LLAMA2:\n            if role == \"system\" and not hist_messages:\n                user_prefix, _ = self.role_prefix_and_suffix.get(\"user\", (\"\", \"\"))\n                system_prefix, system_suffix = self.role_prefix_and_suffix.get(\n                    \"system\", (\"\", \"\")\n                )\n                return (user_prefix + system_prefix, system_suffix)\n            elif (\n                role == \"user\"\n                and len(hist_messages) == 1\n                and hist_messages[0][\"content\"] is not None\n            ):\n                return (\"\", suffix)\n\n        return prefix, suffix\n\n    def get_prompt(self, messages: List[Dict]) -> str:\n        prompt = \"\"\n        for i, message in enumerate(messages):\n            role, content = message[\"role\"], message[\"content\"]\n            if role == \"system\" and content is None:\n                content = self.default_system_prompt\n                if content is None:\n                    continue\n\n            prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])\n            prompt += f\"{prefix}{content}{suffix}\"\n        return prompt\n\n\nchat_template_registry: Dict[str, ChatTemplate] = {}\nmatching_function_registry: 
List[Callable] = []\n\n\ndef register_chat_template(template):\n    chat_template_registry[template.name] = template\n\n\ndef register_chat_template_matching_function(func):\n    matching_function_registry.append(func)\n\n\ndef get_chat_template(name):\n    return chat_template_registry[name]\n\n\ndef get_chat_template_by_model_path(model_path):\n    for matching_func in matching_function_registry:\n        template_name = matching_func(model_path)\n        if template_name is not None:\n            return get_chat_template(template_name)\n    return get_chat_template(\"default\")\n\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"default\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\"SYSTEM:\", \"\\n\"),\n            \"user\": (\"USER:\", \"\\n\"),\n            \"assistant\": (\"ASSISTANT:\", \"\\n\"),\n        },\n    )\n)\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"claude\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\"\", \"\"),\n            \"user\": (\"\\n\\nHuman: \", \"\"),\n            \"assistant\": (\"\\n\\nAssistant:\", \"\"),\n        },\n    )\n)\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"chatml\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\"<|im_start|>system\\n\", \"<|im_end|>\\n\"),\n            \"user\": (\"<|im_start|>user\\n\", \"<|im_end|>\\n\"),\n            \"assistant\": (\"<|im_start|>assistant\\n\", \"<|im_end|>\\n\"),\n        },\n        style=ChatTemplateStyle.PLAIN,\n        stop_str=(\"<|im_end|>\",),\n    )\n)\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"chatml-llava\",\n        default_system_prompt=\"You are a helpful assistant.\",\n        role_prefix_and_suffix={\n            \"system\": (\"<|im_start|>system\\n\", \"<|im_end|>\\n\"),\n            \"user\": (\"<|im_start|>user\\n\", 
\"<|im_end|>\\n\"),\n            \"assistant\": (\"<|im_start|>assistant\\n\", \"<|im_end|>\\n\"),\n        },\n        style=ChatTemplateStyle.PLAIN,\n        stop_str=(\"<|im_end|>\",),\n        image_token=\"<image>\\n\",\n    )\n)\n\n# There is a default system prompt for Qwen\n# Reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1\n# The chat template is: \"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}\"\nregister_chat_template(\n    ChatTemplate(\n        name=\"qwen\",\n        default_system_prompt=\"You are a helpful assistant.\",\n        role_prefix_and_suffix={\n            \"system\": (\"<|im_start|>system\\n\", \"<|im_end|>\\n\"),\n            \"user\": (\"<|im_start|>user\\n\", \"<|im_end|>\\n\"),\n            \"assistant\": (\"<|im_start|>assistant\\n\", \"<|im_end|>\\n\"),\n        },\n        style=ChatTemplateStyle.PLAIN,\n        stop_str=(\"<|im_end|>\",),\n    )\n)\n\n# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example\nregister_chat_template(\n    ChatTemplate(\n        name=\"qwen2-vl\",\n        default_system_prompt=\"You are a helpful assistant.\",\n        role_prefix_and_suffix={\n            \"system\": (\"<|im_start|>system\\n\", \"<|im_end|>\\n\"),\n            \"user\": (\"<|im_start|>user\\n\", \"<|im_end|>\\n\"),\n            \"assistant\": (\"<|im_start|>assistant\\n\", \"<|im_end|>\\n\"),\n        },\n        style=ChatTemplateStyle.PLAIN,\n        stop_str=(\"<|im_end|>\",),\n        image_token=\"<|vision_start|><|image_pad|><|vision_end|>\",\n    )\n)\n\n# Reference: 
https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template\nregister_chat_template(\n    ChatTemplate(\n        name=\"vicuna_v1.1\",\n        default_system_prompt=(\n            \"A chat between a curious user and an artificial intelligence assistant. \"\n            \"The assistant gives helpful, detailed, and polite answers to the user's questions.\"\n        ),\n        role_prefix_and_suffix={\n            \"system\": (\"\", \" \"),\n            \"user\": (\"USER:\", \" \"),\n            \"assistant\": (\"ASSISTANT:\", \"</s>\"),\n        },\n        image_token=\" <image>\\n\",\n    )\n)\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"llama-2-chat\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\"<<SYS>>\\n\", \"\\n<</SYS>>\\n\\n\"),\n            \"user\": (\"[INST] \", \" [/INST]\"),\n            \"assistant\": (\"\", \" </s><s>\"),\n        },\n        style=ChatTemplateStyle.LLAMA2,\n    )\n)\n\n# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json\nregister_chat_template(\n    ChatTemplate(\n        name=\"mistral\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\"[SYSTEM_PROMPT] \", \" [/SYSTEM_PROMPT]\"),\n            \"user\": (\"[INST] \", \" [/INST]\"),\n            \"assistant\": (\"\", \" </s><s>\"),\n        },\n        stop_str=(\"</s>\",),\n        image_token=\"[IMG]\",\n    )\n)\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"llama-3-instruct\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\n                \"<|start_header_id|>system<|end_header_id|>\\n\\n\",\n                \"<|eot_id|>\",\n            ),\n            \"user\": (\n                \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n                \"<|eot_id|>\",\n            ),\n            
\"assistant\": (\n                \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n                \"<|eot_id|>\",\n            ),\n        },\n        stop_str=(\"<|eot_id|>\",),\n        image_token=\"<|image|>\",\n    )\n)\n\n# https://huggingface.co/openbmb/MiniCPM-V-2_6\nregister_chat_template(\n    ChatTemplate(\n        name=\"minicpmv\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\"\", \" \"),\n            \"user\": (\"user:\", \" \"),\n            \"assistant\": (\"assistant:\", \"</s>\"),\n        },\n        stop_str=(\"<|im_end|>\", \"<|endoftext|>\"),\n        image_token=\"(<image>./</image>)\",\n    )\n)\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"janus-pro\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\n                \"\",\n                \"\",\n            ),\n            \"User\": (\n                \"<｜User｜>\",\n                \"\",\n            ),\n            \"assistant\": (\n                \"<｜Assistant｜>\",\n                \"<｜end▁of▁sentence｜>\",\n            ),\n        },\n        stop_str=(\"<｜end▁of▁sentence｜>\",),\n        image_token=\"<image_placeholder>\\n\",\n    )\n)\n\n# https://huggingface.co/openbmb/MiniCPM-o-2_6\nregister_chat_template(\n    ChatTemplate(\n        name=\"minicpmo\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\"\", \" \"),\n            \"user\": (\"user:\", \" \"),\n            \"assistant\": (\"assistant:\", \"</s>\"),\n        },\n        stop_str=(\"<|im_end|>\", \"<|endoftext|>\"),\n        image_token=\"(<image>./</image>)\",\n        audio_token=\"(<audio>./</audio>)\",\n    )\n)\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"janus\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\n                \"\",\n                \"\",\n            ),\n      
      \"user\": (\n                \"<｜User｜>\",\n                \"\",\n            ),\n            \"assistant\": (\n                \"<｜Assistant｜>\",\n                \"<｜end▁of▁sentence｜>\",\n            ),\n        },\n        stop_str=(\"<｜end▁of▁sentence｜>\",),\n        image_token=\"<image_placeholder>\\n\",\n    )\n)\n\n# The difference between \"llama-3-instruct-llava\" and \"llama-3-instruct\" is that llava uses a different image_token.\nregister_chat_template(\n    ChatTemplate(\n        name=\"llama-3-instruct-llava\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\n                \"<|start_header_id|>system<|end_header_id|>\\n\\n\",\n                \"<|eot_id|>\",\n            ),\n            \"user\": (\n                \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n                \"<|eot_id|>\",\n            ),\n            \"assistant\": (\n                \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n                \"<|eot_id|>\",\n            ),\n        },\n        stop_str=(\"<|eot_id|>\",),\n        image_token=\"<image>\\n\",\n    )\n)\n\n# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json\nregister_chat_template(\n    ChatTemplate(\n        name=\"llama-4\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\n                \"<|header_start|>system<|header_end|>\\n\\n\",\n                \"<|eot|>\",\n            ),\n            \"user\": (\n                \"<|header_start|>user<|header_end|>\\n\\n\",\n                \"<|eot|>\",\n            ),\n            \"assistant\": (\n                \"<|header_start|>assistant<|header_end|>\\n\\n\",\n                \"<|eot|>\",\n            ),\n        },\n        stop_str=(\"<|eot|>\",),\n        image_token=\"<|image|>\",\n    )\n)\n\n# Reference: 
https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1\nregister_chat_template(\n    ChatTemplate(\n        name=\"yi-1.5\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\"\", \"\"),\n            \"user\": (\"<|im_start|>user\\n\", \"<|im_end|>\\n<|im_start|>assistant\\n\"),\n            \"assistant\": (\"\", \"<|im_end|>\\n\"),\n        },\n        style=ChatTemplateStyle.PLAIN,\n        stop_str=(\"<|im_end|>\",),\n    )\n)\n\n# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava\nregister_chat_template(\n    ChatTemplate(\n        name=\"yi-vl\",\n        default_system_prompt=(\n            \"This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers.\"\n            \"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像，并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\"\n        ),\n        role_prefix_and_suffix={\n            \"system\": (\"\", \"\\n\\n\"),\n            \"user\": (\"### Human:\", \"\\n\"),\n            \"assistant\": (\"### Assistant:\", \"\\n\"),\n        },\n        image_token=\" <image_placeholder>\\n\",\n    )\n)\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"gemma-it\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\"\", \"\"),\n            \"user\": (\"<start_of_turn>user\\n\", \"<end_of_turn>\\n\"),\n            \"assistant\": (\"<start_of_turn>model\\n\", \"<end_of_turn>\\n\"),\n        },\n        image_token=\"<start_of_image>\",\n        audio_token=\"<start_of_audio>\",\n        style=ChatTemplateStyle.PLAIN,\n    )\n)\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"dbrx-instruct\",\n        default_system_prompt=\"You are DBRX, created by Databricks. 
You were last updated in December 2023. You answer questions based on information available up to that point.\\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\\nYou assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY.\",\n        role_prefix_and_suffix={\n            \"system\": (\"<|im_start|>system\\n\", \"<|im_end|>\"),\n            \"user\": (\"\\n<|im_start|>user\\n\", \"<|im_end|>\"),\n            \"assistant\": (\"\\n<|im_start|>assistant\\n\", \"<|im_end|>\"),\n        },\n        stop_str=(\"<|im_end|>\",),\n    )\n)\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"c4ai-command-r\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\n                \"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>\",\n                \"<|END_OF_TURN_TOKEN|>\",\n            ),\n            \"user\": (\"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>\", \"<|END_OF_TURN_TOKEN|>\"),\n            \"assistant\": (\n                \"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>\",\n                \"<|END_OF_TURN_TOKEN|>\",\n            ),\n        },\n        style=ChatTemplateStyle.PLAIN,\n    )\n)\n\n# Adapted from 
https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py\nregister_chat_template(\n    ChatTemplate(\n        name=\"internvl-2-5\",\n        default_system_prompt=\"你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。\",\n        role_prefix_and_suffix={\n            \"system\": (\"<|im_start|>system\\n\", \"<|im_end|>\\n\"),\n            \"user\": (\"<|im_start|>user\\n\", \"<|im_end|>\\n\"),\n            \"assistant\": (\"<|im_start|>assistant\\n\", \"<|im_end|>\\n\"),\n        },\n        stop_str=[\"<|im_end|>\", \"<|action_end|>\"],\n    )\n)\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"interns1\",\n        default_system_prompt=\"You are an AI assistant whose name is Intern-S1 (书生大模型).\\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室).  It is designed to be helpful, honest, and harmless.\\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. 
Please put your thinking process within <think>...</think> tags.\",\n        role_prefix_and_suffix={\n            \"system\": (\"<|im_start|>system\\n\", \"<|im_end|>\\n\"),\n            \"user\": (\"<|im_start|>user\\n\", \"<|im_end|>\\n\"),\n            \"assistant\": (\"<|im_start|>assistant\\n\", \"<|im_end|>\\n\"),\n        },\n        stop_str=[\"<|im_end|>\", \"<|action_end|>\"],\n    )\n)\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"granite-3-instruct\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\n                \"<|start_of_role|>system<|end_of_role|>\",\n                \"<|end_of_text|>\",\n            ),\n            \"user\": (\n                \"<|start_of_role|>user<|end_of_role|>\",\n                \"<|end_of_text|>\",\n            ),\n            \"assistant\": (\n                \"<|start_of_role|>assistant<|end_of_role|>\",\n                \"<|end_of_text|>\",\n            ),\n        },\n        stop_str=(\"<|end_of_text|>\",),\n    )\n)\n\nregister_chat_template(\n    ChatTemplate(\n        name=\"deepseek-v3\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\n                \"\",\n                \"\",\n            ),\n            \"user\": (\n                \"<｜User｜>\",\n                \"\",\n            ),\n            \"assistant\": (\n                \"<｜Assistant｜>\",\n                \"<｜end▁of▁sentence｜>\",\n            ),\n        },\n        stop_str=(\"<｜end▁of▁sentence｜>\",),\n    )\n)\n\n# Reference: https://huggingface.co/docs/transformers/main/model_doc/glm4_v#usage-example\nregister_chat_template(\n    ChatTemplate(\n        name=\"glm-4v\",\n        default_system_prompt=None,\n        role_prefix_and_suffix={\n            \"system\": (\"<|system|>\\n\", \"\\n\"),\n            \"user\": (\"<|user|>\\n\", \"\\n\"),\n            \"assistant\": (\"<|assistant|>\\n\", \"\\n\"),\n        },\n        
style=ChatTemplateStyle.PLAIN,\n        stop_str=[\"<|user|>\", \"<|endoftext|>\", \"<|observation|>\"],\n        image_token=\"<|image|>\",\n    )\n)\n\n\n@register_chat_template_matching_function\ndef match_deepseek(model_path: str):\n    if re.search(r\"deepseek-(v3|r1)\", model_path, re.IGNORECASE) and not re.search(\n        r\"base\", model_path, re.IGNORECASE\n    ):\n        return \"deepseek-v3\"\n\n\n@register_chat_template_matching_function\ndef match_orion(model_path: str):\n    if \"orion\" in model_path.lower():\n        return \"claude\"\n\n\n@register_chat_template_matching_function\ndef match_deepseek_janus_pro(model_path: str):\n    if re.search(r\"janus\", model_path, re.IGNORECASE):\n        return \"janus-pro\"\n\n\n@register_chat_template_matching_function\ndef match_dbrx(model_path: str):\n    if re.search(r\"dbrx\", model_path, re.IGNORECASE) and re.search(\n        r\"instruct\", model_path, re.IGNORECASE\n    ):\n        return \"dbrx-instruct\"\n\n\n@register_chat_template_matching_function\ndef match_vicuna(model_path: str):\n    if re.search(r\"vicuna|llava-v1\\.5|llava-next-video-7b\", model_path, re.IGNORECASE):\n        return \"vicuna_v1.1\"\n\n\n@register_chat_template_matching_function\ndef match_llama2_chat(model_path: str):\n    if re.search(\n        r\"llama-2.*chat|codellama.*instruct\",\n        model_path,\n        re.IGNORECASE,\n    ):\n        return \"llama-2-chat\"\n\n\n@register_chat_template_matching_function\ndef match_mistral(model_path: str):\n    if re.search(r\"pixtral|(mistral|mixtral).*instruct\", model_path, re.IGNORECASE):\n        return \"mistral\"\n\n\n@register_chat_template_matching_function\ndef match_llama3_instruct(model_path: str):\n    if re.search(r\"llama-3.*instruct\", model_path, re.IGNORECASE):\n        return \"llama-3-instruct\"\n\n\n@register_chat_template_matching_function\ndef match_chat_ml(model_path: str):\n    if re.search(r\"tinyllama\", model_path, re.IGNORECASE):\n        return 
\"chatml\"\n    if re.search(r\"qwen.*vl\", model_path, re.IGNORECASE):\n        return \"qwen2-vl\"\n    if re.search(r\"glm[-_]?4(\\.\\d+)?v\", model_path, re.IGNORECASE):\n        return \"glm-4v\"\n    if re.search(r\"qwen.*(chat|instruct)\", model_path, re.IGNORECASE) and not re.search(\n        r\"llava\", model_path, re.IGNORECASE\n    ):\n        return \"qwen\"\n    if re.search(\n        r\"llava-v1\\.6-34b|llava-v1\\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2\",\n        model_path,\n        re.IGNORECASE,\n    ):\n        return \"chatml-llava\"\n\n\n@register_chat_template_matching_function\ndef match_chat_yi(model_path: str):\n    if re.search(r\"yi-vl\", model_path, re.IGNORECASE) and not re.search(\n        r\"llava\", model_path, re.IGNORECASE\n    ):\n        return \"yi-vl\"\n    elif re.search(r\"yi-1\\.5.*chat\", model_path, re.IGNORECASE):\n        return \"yi-1.5\"\n\n\n@register_chat_template_matching_function\ndef match_gemma_it(model_path: str):\n    if re.search(r\"gemma.*it\", model_path, re.IGNORECASE):\n        return \"gemma-it\"\n\n\n@register_chat_template_matching_function\ndef match_openbmb_minicpm(model_path: str):\n    if re.search(r\"minicpm-v\", model_path, re.IGNORECASE):\n        return \"minicpmv\"\n    elif re.search(r\"minicpm-o\", model_path, re.IGNORECASE):\n        return \"minicpmo\"\n\n\n@register_chat_template_matching_function\ndef match_c4ai_command_r(model_path: str):\n    if re.search(r\"c4ai-command-r\", model_path, re.IGNORECASE):\n        return \"c4ai-command-r\"\n\n\n@register_chat_template_matching_function\ndef match_granite_instruct(model_path: str):\n    if re.search(r\"granite.*instruct\", model_path, re.IGNORECASE):\n        return \"granite-3-instruct\"\n\n\n@register_chat_template_matching_function\ndef match_gemma3_instruct(model_path: str):\n    if re.search(r\"gemma-3\", model_path, re.IGNORECASE):\n        return \"gemma-it\"\n\n\n@register_chat_template_matching_function\ndef 
match_internvl_chat(model_path: str):\n    if re.search(r\"internvl2_5\", model_path, re.IGNORECASE):\n        return \"internvl-2-5\"\n\n\n@register_chat_template_matching_function\ndef match_interns1_chat(model_path: str):\n    if re.search(r\"intern-s1\", model_path, re.IGNORECASE):\n        return \"interns1\"\n    if re.search(r\"interns1\", model_path, re.IGNORECASE):\n        return \"interns1\"\n\n\nif __name__ == \"__main__\":\n    messages = [\n        {\"role\": \"system\", \"content\": None},  # None means default\n        # {\"role\": \"system\", \"content\": \"You are a helpful, respectful and honest assistant.\"},\n        {\"role\": \"user\", \"content\": \"Hello!\"},\n        {\"role\": \"assistant\", \"content\": \"Hi!\"},\n        {\"role\": \"user\", \"content\": \"What can you do?\"},\n        {\"role\": \"assistant\", \"content\": \"I can chat with you.\"},\n    ]\n\n    template = get_chat_template(\"llama-2-chat\")\n    print(template.get_prompt(messages))\n"
  },
  {
    "path": "python/sglang/lang/choices.py",
    "content": "from abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional\n\nimport numpy as np\n\n\n@dataclass\nclass ChoicesDecision:\n    decision: str\n    meta_info: Optional[Dict[str, Any]] = None\n\n\nclass ChoicesSamplingMethod(ABC):\n\n    @property\n    def requires_unconditional_logprobs(self) -> bool:\n        return False\n\n    @abstractmethod\n    def __call__(\n        self,\n        *,\n        choices: List[str],\n        normalized_prompt_logprobs: List[float],\n        input_token_logprobs: List[List[Any]],\n        output_token_logprobs: List[List[Any]],\n        unconditional_token_logprobs: Optional[List[List[Any]]] = None,\n    ) -> ChoicesDecision: ...\n\n\nclass TokenLengthNormalized(ChoicesSamplingMethod):\n\n    def __call__(\n        self,\n        *,\n        choices: List[str],\n        normalized_prompt_logprobs: List[float],\n        input_token_logprobs: List[List[Any]],\n        output_token_logprobs: List[List[Any]],\n        unconditional_token_logprobs: Optional[List[List[Any]]] = None,\n    ) -> ChoicesDecision:\n        \"\"\"Select the option with the highest token length normalized prompt logprob.\"\"\"\n        best_choice = choices[np.argmax(normalized_prompt_logprobs)]\n        meta_info = {\n            \"normalized_prompt_logprobs\": normalized_prompt_logprobs,\n            \"input_token_logprobs\": input_token_logprobs,\n            \"output_token_logprobs\": output_token_logprobs,\n        }\n        return ChoicesDecision(decision=best_choice, meta_info=meta_info)\n\n\ntoken_length_normalized = TokenLengthNormalized()\n\n\nclass GreedyTokenSelection(ChoicesSamplingMethod):\n\n    def __call__(\n        self,\n        *,\n        choices: List[str],\n        normalized_prompt_logprobs: List[float],\n        input_token_logprobs: List[List[Any]],\n        output_token_logprobs: List[List[Any]],\n        unconditional_token_logprobs: Optional[List[List[Any]]] = 
None,\n    ) -> ChoicesDecision:\n        \"\"\"Select the option based on greedy logprob selection. For overlapping options\n        where one option is a subset of a longer option, extend the shorter option using\n        its average logprob for comparison against the longer option.\"\"\"\n\n        num_options = len(choices)\n        max_tokens = max(len(option) for option in input_token_logprobs)\n        logprob_matrix = self._build_logprob_matrix(\n            input_token_logprobs, max_tokens, num_options\n        )\n        remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens)\n\n        best_choice = choices[remaining[0]]\n        meta_info = {\n            \"normalized_prompt_logprobs\": normalized_prompt_logprobs,\n            \"input_token_logprobs\": input_token_logprobs,\n            \"output_token_logprobs\": output_token_logprobs,\n            \"greedy_logprob_matrix\": logprob_matrix.tolist(),\n        }\n        return ChoicesDecision(decision=best_choice, meta_info=meta_info)\n\n    def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options):\n        logprob_matrix = np.zeros((num_options, max_tokens))\n        for i, option in enumerate(input_token_logprobs):\n            actual_logprobs = [token[0] for token in option]\n            avg_logprob = np.mean(actual_logprobs)\n            logprob_matrix[i, : len(option)] = actual_logprobs\n            if len(option) < max_tokens:\n                logprob_matrix[i, len(option) :] = avg_logprob\n        return logprob_matrix\n\n    def _greedy_selection(self, logprob_matrix, num_options, max_tokens):\n        remaining = np.arange(num_options)\n        for j in range(max_tokens):\n            max_logprob = np.max(logprob_matrix[remaining, j])\n            remaining = remaining[logprob_matrix[remaining, j] == max_logprob]\n            if len(remaining) == 1:\n                break\n        return remaining\n\n\ngreedy_token_selection = 
GreedyTokenSelection()\n\n\nclass UnconditionalLikelihoodNormalized(ChoicesSamplingMethod):\n\n    @property\n    def requires_unconditional_logprobs(self) -> bool:\n        return True\n\n    def __call__(\n        self,\n        *,\n        choices: List[str],\n        normalized_prompt_logprobs: List[float],\n        input_token_logprobs: List[List[Any]],\n        output_token_logprobs: List[List[Any]],\n        unconditional_token_logprobs: Optional[List[List[Any]]] = None,\n    ) -> ChoicesDecision:\n        \"\"\"Select the option with the highest average token logprob once normalized by\n        the unconditional token logprobs.\n\n        The first unconditional token logprob is assumed to be None. If so, it is\n        replaced with 0 for the purposes of normalization.\"\"\"\n\n        if unconditional_token_logprobs is None:\n            raise ValueError(\n                \"Unconditional token logprobs are required for this method.\"\n            )\n\n        normalized_unconditional_prompt_logprobs = self._normalize_logprobs(\n            input_token_logprobs, unconditional_token_logprobs\n        )\n\n        best_choice = choices[np.argmax(normalized_unconditional_prompt_logprobs)]\n        meta_info = {\n            \"normalized_prompt_logprobs\": normalized_prompt_logprobs,\n            \"input_token_logprobs\": input_token_logprobs,\n            \"output_token_logprobs\": output_token_logprobs,\n            \"unconditional_token_logprobs\": unconditional_token_logprobs,\n            \"normalized_unconditional_prompt_logprobs\": normalized_unconditional_prompt_logprobs,\n        }\n        return ChoicesDecision(decision=best_choice, meta_info=meta_info)\n\n    def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs):\n        normalized_unconditional_prompt_logprobs = []\n        for inputs, unconditionals in zip(\n            input_token_logprobs, unconditional_token_logprobs\n        ):\n            inputs_logprobs = 
np.array([token[0] for token in inputs])\n            unconditionals_logprobs = np.array([token[0] for token in unconditionals])\n            unconditionals_logprobs[0] = unconditionals_logprobs[0] or 0\n            normalized_unconditional_prompt_logprobs.append(\n                float(np.mean(inputs_logprobs - unconditionals_logprobs))\n            )\n        return normalized_unconditional_prompt_logprobs\n\n\nunconditional_likelihood_normalized = UnconditionalLikelihoodNormalized()\n"
  },
  {
    "path": "python/sglang/lang/interpreter.py",
    "content": "\"\"\"The interpreter that executes SGL programs\"\"\"\n\nimport asyncio\nimport contextvars\nimport copy\nimport multiprocessing\nimport queue\nimport threading\nimport uuid\nimport warnings\nfrom concurrent.futures import ThreadPoolExecutor\nfrom contextlib import contextmanager\nfrom typing import Any, Callable, Dict, List, Optional\n\nimport tqdm\n\nfrom sglang.global_config import global_config\nfrom sglang.lang.ir import (\n    SglCommitLazy,\n    SglConcateAndAppend,\n    SglConstantText,\n    SglExpr,\n    SglExprList,\n    SglGen,\n    SglImage,\n    SglRoleBegin,\n    SglRoleEnd,\n    SglSelect,\n    SglSeparateReasoning,\n    SglVariable,\n    SglVarScopeBegin,\n    SglVarScopeEnd,\n    SglVideo,\n)\nfrom sglang.utils import (\n    encode_image_base64,\n    encode_video_base64,\n    get_exception_traceback,\n)\n\n\ndef run_internal(state, program, func_args, func_kwargs, sync):\n    try:\n        state.ret_value = program.func(state, *func_args, **func_kwargs)\n    except Exception as e:\n        raise e\n    finally:\n        state.stream_executor.end()\n\n    if sync:\n        state.stream_executor.sync()\n\n    if global_config.verbosity >= 2:\n        print(state.text())\n\n\ndef run_program(\n    program,\n    backend,\n    func_args,\n    func_kwargs,\n    default_sampling_para,\n    stream,\n    sync=False,\n    use_thread=True,\n):\n    if hasattr(backend, \"endpoint\"):\n        backend = backend.endpoint\n    assert backend is not None, \"Please specify a backend\"\n    func_kwargs.update(program.bind_arguments)\n    stream_executor = StreamExecutor(\n        backend,\n        func_kwargs,\n        default_sampling_para,\n        chat_template=None,\n        stream=stream,\n        num_api_spec_tokens=program.num_api_spec_tokens,\n        use_thread=use_thread,\n    )\n    state = ProgramState(stream_executor)\n\n    if stream:\n        t = threading.Thread(\n            target=run_internal, args=(state, program, func_args, 
func_kwargs, sync)\n        )\n        t.start()\n        return state\n    else:\n        run_internal(state, program, func_args, func_kwargs, sync)\n        return state\n\n\ndef run_program_batch(\n    program,\n    backend,\n    batch_arguments,\n    default_sampling_para,\n    num_threads,\n    progress_bar,\n    generator_style=False,\n):\n    if hasattr(backend, \"endpoint\"):\n        backend = backend.endpoint\n\n    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.\n    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:\n        cache_program(program, backend)\n\n    # Run all programs\n    if num_threads == \"auto\":\n        num_threads = max(96, multiprocessing.cpu_count() * 16)\n    num_threads = min(num_threads, len(batch_arguments))\n\n    if generator_style:\n        return _run_program_batch_generator(\n            program,\n            backend,\n            batch_arguments,\n            default_sampling_para,\n            num_threads,\n            progress_bar,\n        )\n\n    # Original code path when generator_style=False\n    if num_threads == 1:\n        rets = []\n        if progress_bar:\n            for arguments in tqdm.tqdm(batch_arguments):\n                rets.append(\n                    run_program(\n                        program,\n                        backend,\n                        (),\n                        arguments,\n                        default_sampling_para,\n                        False,\n                        True,\n                    )\n                )\n        else:\n            for arguments in batch_arguments:\n                rets.append(\n                    run_program(\n                        program,\n                        backend,\n                        (),\n                        arguments,\n                        default_sampling_para,\n                        False,\n                        True,\n                    
)\n                )\n    else:\n        if progress_bar:\n            pbar = tqdm.tqdm(total=len(batch_arguments))\n\n        with ThreadPoolExecutor(num_threads) as executor:\n            futures = []\n            for arguments in batch_arguments:\n                futures.append(\n                    executor.submit(\n                        run_program,\n                        program,\n                        backend,\n                        (),\n                        arguments,\n                        default_sampling_para,\n                        False,\n                        True,\n                    )\n                )\n                if progress_bar:\n                    futures[-1].add_done_callback(lambda _: pbar.update())\n\n            rets = [f.result() for f in futures]\n        rets[-1].sync()\n\n        if progress_bar:\n            pbar.close()\n\n    return rets\n\n\ndef _run_program_batch_generator(\n    program,\n    backend,\n    batch_arguments,\n    default_sampling_para,\n    num_threads,\n    progress_bar,\n):\n    \"\"\"Helper function that yields results one by one using chunking to avoid overwhelming ThreadPoolExecutor.\"\"\"\n    if num_threads == 1:\n        iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments\n        for arguments in iterator:\n            yield run_program(\n                program,\n                backend,\n                (),\n                arguments,\n                default_sampling_para,\n                False,\n                True,\n            )\n    else:\n        pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None\n\n        # Process in chunks to avoid overwhelming ThreadPoolExecutor\n        # Otherwise, ThreadPoolExecutor.submit will block after adding certain number of tasks\n        # so we will never reach \"yield\" until all tasks are done\n        chunk_size = 200\n\n        with ThreadPoolExecutor(num_threads) as executor:\n            for 
chunk_start in range(0, len(batch_arguments), chunk_size):\n                chunk_end = min(chunk_start + chunk_size, len(batch_arguments))\n                chunk_futures = []\n\n                # Submit chunk of tasks\n                for i in range(chunk_start, chunk_end):\n                    future = executor.submit(\n                        run_program,\n                        program,\n                        backend,\n                        (),\n                        batch_arguments[i],\n                        default_sampling_para,\n                        False,\n                        True,\n                    )\n                    if pbar:\n                        future.add_done_callback(lambda _: pbar.update())\n                    chunk_futures.append(future)\n\n                # Yield results from this chunk as they complete\n                for future in chunk_futures:\n                    yield future.result()\n\n        if pbar:\n            pbar.close()\n\n\ndef cache_program(program, backend):\n    from sglang.lang.tracer import extract_prefix_by_tracing\n\n    prefix = extract_prefix_by_tracing(program, backend)\n    if prefix and len(prefix) > 64:\n        backend.cache_prefix(prefix)\n\n\nclass StreamExecutor:\n    \"\"\"A stream executor that executes SGL expressions in a background thread.\"\"\"\n\n    def __init__(\n        self,\n        backend,\n        arguments,\n        default_sampling_para,\n        chat_template,\n        stream,\n        num_api_spec_tokens=None,\n        use_thread=True,\n    ):\n        from sglang.lang.backend.base_backend import BaseBackend\n\n        self.sid = uuid.uuid4().hex\n        self.backend: BaseBackend = backend\n        self.arguments: Dict[str, Any] = arguments\n        self.default_sampling_para = default_sampling_para\n        self.stream = stream\n\n        self.variables = {}  # Dict[name: str -> value: str]\n        self.variable_event = {}  # Dict[name: str -> event: 
threading.Event]\n        self.meta_info = {}  # Dict[name: str -> info: str]\n        self.is_finished = False\n        self.error_ = None\n\n        # For completion\n        self.text_ = \"\"  # The full text\n\n        # For chat\n        self.messages_ = []  # The messages in the OpenAI API format\n        self.chat_template = chat_template or self.backend.get_chat_template()\n        self.cur_role = None\n        self.cur_role_begin_pos = None\n\n        # For vision\n        self.images_ = []\n        self.cur_images = []\n\n        # For fork/join\n        self.fork_start_text_pos = None\n\n        # For speculative execution\n        self.num_api_spec_tokens = num_api_spec_tokens\n        self.speculated_text = \"\"\n\n        # Worker thread\n        self.use_thread = use_thread\n        if self.use_thread:\n            self.queue = queue.Queue()\n\n            def _run_worker_in_context():\n                self._thread_worker_func()\n\n            self.worker = threading.Thread(\n                target=contextvars.copy_context().run, args=(_run_worker_in_context,)\n            )\n            self.worker.start()\n\n        # For streaming\n        if stream:\n            self.stream_text_event = threading.Event()\n            self.stream_var_event = {}\n        else:\n            self.stream_text_event = None\n            self.stream_var_event = None\n\n    def submit(self, expr: SglExpr):\n        self._init_var_event(expr)\n\n        if self.use_thread:\n            self.queue.put(expr)\n        else:\n            self._execute(expr)\n\n    def sync(self):\n        if self.use_thread:\n            self.queue.join()\n\n    def get_var(self, name):\n        if name in self.variable_event:\n            self.variable_event[name].wait()\n        return self.variables[name]\n\n    def set_var(self, name, value):\n        self.variables[name] = value\n\n    def get_meta_info(self, name, timeout=None):\n        if name in self.variable_event:\n            got = 
self.variable_event[name].wait(timeout)\n            if not got:\n                raise TimeoutError(f\"Timeout while waiting for event '{name}'\")\n        ret = self.meta_info.get(name, None)\n        return ret\n\n    def fork(\n        self,\n        size: int = 1,\n        position_ids_offset: Optional[List[int]] = None,\n    ):\n        if size > 1 and str(self.text_):\n            self.submit(SglCommitLazy())\n\n        self.sync()\n        size = int(size)\n\n        exes = [\n            StreamExecutor(\n                self.backend,\n                self.arguments,\n                self.default_sampling_para,\n                self.chat_template,\n                self.stream,\n            )\n            for _ in range(size)\n        ]\n        for i in range(size):\n            exes[i].variables = dict(self.variables)\n            exes[i].text_ = str(self.text_)\n            exes[i].messages_ = list(self.messages_)\n            exes[i].cur_role = self.cur_role\n            exes[i].cur_role_begin_pos = self.cur_role_begin_pos\n            exes[i].fork_start_text_pos = len(self.text_)\n            exes[i].images_ = list(self.images_)\n\n            # TODO(ying): handle API speculative execution\n\n        return exes\n\n    def text(self):\n        self.sync()\n        return self.text_\n\n    def messages(self):\n        self.sync()\n        return self.messages_\n\n    def error(self):\n        self.sync()\n        return self.error_\n\n    def end(self):\n        if self.use_thread:\n            if self.worker.is_alive():\n                self.queue.put(None)\n        self.backend.end_program(self)\n\n    def _thread_worker_func(self):\n        error = None\n\n        while True:\n            expr = self.queue.get()\n            if expr is None:\n                self.queue.task_done()\n                break\n\n            try:\n                self._execute(expr)\n            except Exception as e:\n                warnings.warn(f\"Error in 
stream_executor: {get_exception_traceback()}\")\n                error = e\n                break\n            self.queue.task_done()\n            if self.stream_text_event:\n                self.stream_text_event.set()\n\n        # Clean the queue and events\n        if error is not None:\n            try:\n                while True:\n                    self.queue.task_done()\n                    self.queue.get_nowait()\n            except queue.Empty:\n                pass\n            for name in self.variable_event:\n                self.variable_event[name].set()\n            if self.stream_var_event:\n                for name in self.stream_var_event:\n                    self.stream_var_event[name].set()\n            self.error_ = error\n\n        if self.stream_text_event:\n            self.stream_text_event.set()\n\n        self.is_finished = True\n\n    def _execute(self, other):\n        if isinstance(other, str):\n            other = SglConstantText(other)\n\n        assert isinstance(other, SglExpr), f\"{other}\"\n\n        if isinstance(other, SglConstantText):\n            self._execute_fill(other.value)\n        elif isinstance(other, SglGen):\n            self._execute_gen(other)\n        elif isinstance(other, SglSelect):\n            self._execute_select(other)\n        elif isinstance(other, SglExprList):\n            for x in other.expr_list:\n                self._execute(x)\n        elif isinstance(other, SglRoleBegin):\n            self._execute_role_begin(other)\n        elif isinstance(other, SglRoleEnd):\n            self._execute_role_end(other)\n        elif isinstance(other, SglImage):\n            self._execute_image(other)\n        elif isinstance(other, SglVideo):\n            self._execute_video(other)\n        elif isinstance(other, SglVariable):\n            self._execute_variable(other)\n        elif isinstance(other, SglVarScopeBegin):\n            self._execute_var_scope_begin(other)\n        elif isinstance(other, 
SglVarScopeEnd):\n            self._execute_var_scope_end(other)\n        elif isinstance(other, SglCommitLazy):\n            self._execute_commit_lazy_operations(other)\n        elif isinstance(other, SglConcateAndAppend):\n            if (\n                global_config.enable_parallel_encoding\n                and self.backend.support_concate_and_append\n            ):\n                self._execute_concatenate_and_append_kv_cache(other)\n            else:\n                self._execute_concatenate_and_append_text(other)\n        elif isinstance(other, SglSeparateReasoning):\n            self._execute_separate_reasoning(other)\n        else:\n            raise ValueError(f\"Unknown type: {type(other)}\")\n\n    def _execute_fill(self, value: str, prefix=False):\n        value = str(value)\n\n        if (\n            self.cur_role == \"assistant\"\n            and self.num_api_spec_tokens is not None\n            and self.backend.is_chat_model\n            and not prefix\n        ):\n            self.backend.spec_fill(value)\n            return\n\n        if self.speculated_text.startswith(value):\n            self.speculated_text = self.speculated_text[len(value) :]\n        else:\n            self.speculated_text = \"\"\n\n        self.text_ += value\n\n    def _execute_image(self, expr: SglImage):\n        path = expr.path\n\n        base64_data = encode_image_base64(path)\n\n        self.images_.append((path, base64_data))\n        self.cur_images.append((path, base64_data))\n        self.text_ += self.chat_template.image_token\n\n    def _execute_video(self, expr: SglVideo):\n        path = expr.path\n        num_frames = expr.num_frames\n\n        base64_data = encode_video_base64(path, num_frames)\n\n        self.images_.append((path, base64_data))\n        self.cur_images.append((path, base64_data))\n        self.text_ += self.chat_template.image_token\n\n    def _spec_gen(self, sampling_params):\n        stop = sampling_params.stop\n        
max_new_tokens = sampling_params.max_new_tokens\n        meta_info = {}\n\n        def regen():\n            nonlocal meta_info\n\n            sampling_params.max_new_tokens = max(\n                sampling_params.max_new_tokens, self.num_api_spec_tokens\n            )\n            sampling_params.stop = None\n            self.speculated_text, meta_info = self.backend.generate(\n                self, sampling_params=sampling_params\n            )\n\n        def find_stop():\n            if isinstance(stop, str):\n                return self.speculated_text.find(stop)\n            elif isinstance(stop, (tuple, list)):\n                pos = -1\n                for stop_str in stop:\n                    stop_pos = self.speculated_text.find(stop_str)\n                    if stop_pos != -1 and (pos == -1 or stop_pos < pos):\n                        pos = stop_pos\n                return pos\n            else:\n                raise Exception(\"Wrong type of stop in sampling parameters.\")\n\n        if stop is None:\n            if len(self.speculated_text) < max_new_tokens:\n                regen()\n            comp = self.speculated_text[:max_new_tokens]\n            self.speculated_text = self.speculated_text[max_new_tokens:]\n        elif isinstance(stop, (str, list, tuple)):\n            if self.speculated_text == \"\":\n                regen()\n            stop_pos = find_stop()\n            if stop_pos == -1:\n                stop_pos = min(\n                    sampling_params.max_new_tokens,\n                    len(self.speculated_text),\n                )\n            comp = self.speculated_text[:stop_pos]\n            self.speculated_text = self.speculated_text[stop_pos:]\n        else:\n            raise ValueError(\"Wrong type of stop in sampling parameters.\")\n\n        return comp, meta_info\n\n    def _execute_gen(self, expr: SglGen):\n        sampling_params = self._resolve_sampling_params(expr.sampling_params)\n        name = expr.name\n        if 
not self.stream:\n            if self.num_api_spec_tokens is None:\n                comp, meta_info = self.backend.generate(\n                    self,\n                    sampling_params=sampling_params,\n                )\n\n            else:\n                if self.backend.is_chat_model:\n                    # Speculative execution on models with only chat interface.\n                    # Store the calls into a temporary list.\n                    # They will be lazily executed later.\n                    comp, meta_info = self.backend.generate(\n                        self,\n                        sampling_params=sampling_params,\n                        spec_var_name=name,\n                    )\n                    return\n\n                else:  # Speculative execution on models with completion interface\n                    comp, meta_info = self._spec_gen(sampling_params)\n            if isinstance(comp, list):\n                self.text_ += comp[0]\n            else:\n                assert isinstance(comp, str)\n                self.text_ += comp\n\n            self.variables[name] = comp\n            self.meta_info[name] = meta_info\n            self.variable_event[name].set()\n        else:\n            assert (\n                self.num_api_spec_tokens is None\n            ), \"stream is not supported with api speculative execution\"\n            generator = self.backend.generate_stream(\n                self, sampling_params=sampling_params\n            )\n\n            self.variables[name] = \"\"\n            self.stream_var_event[name].set()\n\n            for comp, meta_info in generator:\n                self.text_ += comp\n                self.variables[name] += comp\n                self.meta_info[name] = meta_info\n                self.stream_var_event[name].set()\n                self.stream_text_event.set()\n\n            self.variable_event[name].set()\n            self.stream_var_event[name].set()\n\n    def _execute_select(self, 
expr: SglSelect):\n        choices_decision = self.backend.select(\n            self, expr.choices, expr.temperature, expr.choices_method\n        )\n        if expr.name is not None:\n            name = expr.name\n            self.variables[name] = choices_decision.decision\n            self.meta_info[name] = choices_decision.meta_info\n            self.variable_event[name].set()\n            if self.stream_var_event:\n                self.stream_var_event[name].set()\n        self.text_ += choices_decision.decision\n\n    def _execute_variable(self, expr: SglVariable):\n        src_executor = expr.source_stream_executor\n        value = src_executor.get_var(expr.name)\n        self._execute_fill(value)\n\n    def _execute_role_begin(self, expr: SglRoleBegin):\n        assert self.cur_role is None, \"Nested roles are not allowed.\"\n\n        if len(self.messages_) == 0 and expr.role != \"system\":\n            # Insert the default system message\n            default_system = self.chat_template.default_system_prompt\n            if default_system:\n                self._execute_role_begin(SglRoleBegin(\"system\"))\n                self._execute_fill(default_system)\n                self._execute_role_end(SglRoleEnd(\"system\"))\n\n        self.cur_role = expr.role\n\n        prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)\n\n        self._execute_fill(prefix, prefix=True)\n        self.cur_role_begin_pos = len(self.text_)\n\n    def _execute_role_end(self, expr: SglRoleEnd):\n        if (\n            self.cur_role == \"assistant\"\n            and self.num_api_spec_tokens is not None\n            and self.backend.is_chat_model\n        ):\n            # Execute the stored lazy generation calls\n            self.backend.role_end_generate(self)\n        self.cur_role = None\n\n        new_text = self.text_[self.cur_role_begin_pos :].lstrip()\n\n        _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)\n  
      self._execute_fill(suffix)\n\n        if self.cur_images:\n            # OpenAI vision API format\n            last_msg = {\n                \"role\": expr.role,\n                \"content\": [{\"type\": \"text\", \"text\": new_text}],\n            }\n            for image_path, image_base64_data in self.cur_images:\n                last_msg[\"content\"].append(\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": f\"data:image/jpeg;base64,{image_base64_data}\"\n                        },\n                    }\n                )\n            self.messages_.append(last_msg)\n            self.cur_images = []\n        else:\n            # OpenAI chat API format\n            self.messages_.append({\"role\": expr.role, \"content\": new_text})\n\n    def _execute_var_scope_begin(self, expr: SglVarScopeBegin):\n        self.variables[expr.name] = int(len(self.text_))\n\n    def _execute_var_scope_end(self, expr: SglVarScopeEnd):\n        self.variables[expr.name] = self.text_[self.variables[expr.name] :]\n        self.variable_event[expr.name].set()\n\n    def _execute_commit_lazy_operations(self, expr: SglCommitLazy):\n        self.backend.commit_lazy_operations(self)\n\n    def _execute_concatenate_and_append_text(self, expr: SglConcateAndAppend):\n        new_text = \"\"\n        for s in expr.states:\n            exe = s.stream_executor\n            exe.sync()\n            new_text += exe.text_[exe.fork_start_text_pos :]\n\n        self._execute_fill(new_text)\n\n    def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend):\n        self_len = len(self.text_)\n\n        for i, s in enumerate(expr.states):\n            exe = s.stream_executor\n            exe.submit(SglCommitLazy())\n\n        for i, s in enumerate(expr.states):\n            exe = s.stream_executor\n            exe.sync()\n            assert exe.fork_start_text_pos == 
self_len\n            self.text_ += exe.text_[exe.fork_start_text_pos :]\n\n        src_rids = [state.stream_executor.sid for state in expr.states]\n        self.backend.concatenate_and_append(src_rids, self.sid)\n\n    def _execute_separate_reasoning(self, expr: SglSeparateReasoning):\n        if self.stream:\n            # separate reasoning for stream is not supported\n            return\n\n        if (\n            self.cur_role == \"assistant\"\n            and self.num_api_spec_tokens is not None\n            and self.backend.is_chat_model\n        ):\n            # Execute the stored lazy generation calls\n            self.backend.role_end_generate(self)\n\n        from sglang.srt.parser.reasoning_parser import ReasoningParser\n\n        reasoning_parser = ReasoningParser(expr.model_type)\n        other = expr.expr\n        if not other:\n            return\n        elif isinstance(other, SglGen) or isinstance(other, SglSelect):\n            cur_text = self.get_var(other.name)\n            reasoning, normal_text = reasoning_parser.parse_non_stream(cur_text)\n            reasoning_name = expr.process_name_for_reasoning(other.name)\n            self.set_var(other.name, normal_text)\n            self.set_var(reasoning_name, reasoning)\n            # the variable is ready to be used\n            self.variable_event[reasoning_name].set()\n            self.text_ = self.text_[: self.cur_role_begin_pos] + normal_text\n        elif isinstance(other, SglExprList):\n            for x in other.expr_list:\n                self._execute_separate_reasoning(\n                    SglSeparateReasoning(expr.model_type, x)\n                )\n\n    def _init_var_event(self, expr):\n        if isinstance(\n            expr, (SglGen, SglSelect, SglVarScopeBegin, SglSeparateReasoning)\n        ):\n            self.variable_event[expr.name] = threading.Event()\n            if self.stream:\n                self.stream_var_event[expr.name] = threading.Event()\n        elif 
isinstance(expr, SglExprList):\n            for e in expr.expr_list:\n                self._init_var_event(e)\n\n    def _resolve_sampling_params(self, sampling_params):\n        \"\"\"\n        Construct sampling param based on default + override values\n\n        The default values of sampling are populated in `default_sampling_para` via sgl.function.run(...sampling_args)\n        , and `sampling_params` contains the override values from sgl.gen().\n\n        Here we use default_sampling_para as the base and override the values if they exist in `sampling_params`.\n        It also extends the stop tokens based on the chat template.\n        \"\"\"\n\n        # deepcopy is required because the dict has lists inside\n        clone = copy.deepcopy(self.default_sampling_para)\n\n        for item in [\n            \"max_new_tokens\",\n            \"min_new_tokens\",\n            \"n\",\n            \"stop\",\n            \"stop_token_ids\",\n            \"stop_regex\",\n            \"temperature\",\n            \"top_p\",\n            \"top_k\",\n            \"min_p\",\n            \"frequency_penalty\",\n            \"presence_penalty\",\n            \"ignore_eos\",\n            \"return_logprob\",\n            \"logprob_start_len\",\n            \"top_logprobs_num\",\n            \"return_text_in_logprobs\",\n            \"dtype\",\n            \"regex\",\n            \"json_schema\",\n        ]:\n            value = getattr(sampling_params, item, None)\n            if value is not None:\n                setattr(clone, item, value)\n\n        if self.chat_template.stop_str:\n            if clone.stop == ():\n                clone.stop = []\n            elif isinstance(clone.stop, str):\n                clone.stop = [clone.stop]\n            clone.stop += self.chat_template.stop_str\n\n        return clone\n\n    def __del__(self):\n        self.end()\n\n\nclass ProgramState:\n    \"\"\"The state of an SGL program.\"\"\"\n\n    def __init__(self, stream_executor: 
StreamExecutor):\n        self.stream_executor = stream_executor\n\n    def _role_common(self, name: str, expr: Optional[SglExpr] = None):\n        if expr is not None:\n            role_expr = SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])\n            self.stream_executor.submit(role_expr)\n            return role_expr\n        else:\n\n            @contextmanager\n            def role_scope():\n                self.stream_executor.submit(SglRoleBegin(name))\n                yield\n                self.stream_executor.submit(SglRoleEnd(name))\n\n            return role_scope()\n\n    def system(self, expr: Optional[SglExpr] = None):\n        return self._role_common(\"system\", expr)\n\n    def user(self, expr: Optional[SglExpr] = None):\n        return self._role_common(\"user\", expr)\n\n    def assistant(self, expr: Optional[SglExpr] = None):\n        return self._role_common(\"assistant\", expr)\n\n    @contextmanager\n    def var_scope(self, name: str):\n        self.stream_executor.submit(SglVarScopeBegin(name))\n        yield\n        self.stream_executor.submit(SglVarScopeEnd(name))\n\n    def fork(\n        self,\n        size: int = 1,\n        position_ids_offset: Optional[List[int]] = None,\n    ):\n        stream_executors = self.stream_executor.fork(size, position_ids_offset)\n        states = [ProgramState(x) for x in stream_executors]\n        state_group = ProgramStateGroup(states, self)\n        return state_group\n\n    @contextmanager\n    def copy(self, position_ids_offset: Optional[List[int]] = None):\n        state_group = self.fork(1, position_ids_offset)\n        try:\n            yield state_group[0]\n        finally:\n            state_group.join()\n\n    def text(self):\n        return self.stream_executor.text()\n\n    def messages(self):\n        return self.stream_executor.messages()\n\n    def sync(self):\n        return self.stream_executor.sync()\n\n    def error(self):\n        return self.stream_executor.error()\n\n  
  def text_iter(self, var_name: Optional[str] = None):\n        if self.stream_executor.stream:\n            prev = 0\n            if var_name is None:\n                event = self.stream_executor.stream_text_event\n                while True:\n                    event.wait()\n                    event.clear()\n                    out = str(self.stream_executor.text_[prev:])\n                    prev += len(out)\n                    if out:\n                        yield out\n                    if self.stream_executor.is_finished:\n                        break\n            else:\n                event = None\n                while not event:\n                    if var_name in self.stream_executor.stream_var_event:\n                        event = self.stream_executor.stream_var_event[var_name]\n                    if self.stream_executor.is_finished:\n                        yield \"\"\n                        return\n\n                while True:\n                    event.wait()\n                    event.clear()\n                    out = str(self.stream_executor.variables[var_name][prev:])\n                    prev += len(out)\n                    if out:\n                        yield out\n                    if self.stream_executor.variable_event[var_name].is_set():\n                        break\n        else:\n            if var_name is None:\n                yield self.text()\n            else:\n                yield self.get_var(var_name)\n\n    async def text_async_iter(\n        self, var_name: Optional[str] = None, return_meta_data: bool = False\n    ):\n        loop = asyncio.get_running_loop()\n\n        if self.stream_executor.stream:\n            prev = 0\n            if var_name is None:\n                event = self.stream_executor.stream_text_event\n                while True:\n                    await loop.run_in_executor(None, event.wait)\n                    event.clear()\n                    out = 
str(self.stream_executor.text_[prev:])\n                    prev += len(out)\n                    if out:\n                        yield out\n                    if self.stream_executor.is_finished:\n                        break\n            else:\n                event = None\n                while not event:\n                    if var_name in self.stream_executor.stream_var_event:\n                        event = self.stream_executor.stream_var_event[var_name]\n                    if self.stream_executor.is_finished:\n                        yield \"\"\n                        return\n\n                while True:\n                    await loop.run_in_executor(None, event.wait)\n                    event.clear()\n                    out = str(self.stream_executor.variables[var_name][prev:])\n                    prev += len(out)\n                    if out:\n                        if return_meta_data:\n                            yield out, self.stream_executor.meta_info[var_name]\n                        else:\n                            yield out\n                    if self.stream_executor.variable_event[var_name].is_set():\n                        break\n        else:\n            if var_name is None:\n                yield self.text()\n            else:\n                yield self.get_var(var_name)\n\n    def get_var(self, name):\n        return self.stream_executor.get_var(name)\n\n    def set_var(self, name, value):\n        return self.stream_executor.set_var(name, value)\n\n    def get_meta_info(self, name):\n        return self.stream_executor.get_meta_info(name)\n\n    def __iadd__(self, other):\n        if other is None:\n            raise ValueError(\"Tried to append None to state.\")\n        self.stream_executor.submit(other)\n        return self\n\n    def __getitem__(self, name):\n        return self.get_var(name)\n\n    def __setitem__(self, name, value):\n        self.set_var(name, value)\n\n    def __contains__(self, name):\n        return 
name in self.stream_executor.variables\n\n    def __del__(self):\n        self.stream_executor.end()\n\n    def __repr__(self) -> str:\n        return f\"ProgramState({self.text()})\"\n\n\nclass ProgramStateGroup:\n    def __init__(\n        self, states: List[ProgramState], src_state: Optional[ProgramState] = None\n    ):\n        self.states = states\n        self.src_state = src_state\n\n    def join(self, mode: str = \"gather_variable\"):\n        if mode == \"gather_variable\":\n            # Copy variables back\n            src_vars = self.src_state.stream_executor.variables\n            src_var_set = set(src_vars.keys())\n            for child_state in self.states:\n                child_state.stream_executor.sync()\n                child_vars = child_state.stream_executor.variables\n                new_vars = set(child_vars.keys()) - src_var_set\n\n                for k in new_vars:\n                    if k in src_vars:\n                        src_vars[k].append(child_vars[k])\n                    else:\n                        src_vars[k] = [child_vars[k]]\n        elif mode == \"concate_and_append\":\n            # Concatenate and append KV cache\n            self.src_state += SglConcateAndAppend(self.states)\n            # Need a sync here. 
Otherwise, `states` can be deleted.\n            self.src_state.stream_executor.sync()\n        else:\n            raise ValueError(f\"Invalid join mode: {mode}\")\n\n        for s in self.states:\n            s.stream_executor.end()\n\n    def __getitem__(self, i: int):\n        return self.states[i]\n\n    def __setitem__(self, i: int, value):\n        assert self.states[i] == value\n\n    def __iadd__(self, other):\n        if isinstance(other, Callable):\n            # lambda function\n            for i in range(len(self.states)):\n                self.states[i] += other(i)\n        elif isinstance(other, SglExpr):\n            for i in range(len(self.states)):\n                self.states[i] += other\n        elif isinstance(other, (list, tuple)):\n            for i in range(len(self.states)):\n                self.states[i] += other[i]\n        else:\n            raise ValueError(f\"Invalid value: {other}\")\n\n        return self\n"
  },
  {
    "path": "python/sglang/lang/ir.py",
    "content": "\"\"\"The intermediate representation.\"\"\"\n\nimport dataclasses\nimport inspect\nimport warnings\nfrom typing import List, Optional, Union\n\nfrom sglang.global_config import global_config\nfrom sglang.lang.choices import ChoicesSamplingMethod\n\nREGEX_INT = r\"[-+]?[0-9]+[ \\n]*\"\nREGEX_FLOAT = r\"[-+]?[0-9]*\\.?[0-9]+[ \\n]*\"\nREGEX_BOOL = r\"(True|False)\"\nREGEX_STR = r\"\\\"[\\w\\d\\s]*\\\"\"  # bugs with regex r\"\\\".*\\\"\" in interegular pkg\n\n\n@dataclasses.dataclass\nclass SglSamplingParams:\n    max_new_tokens: int = 128\n    min_new_tokens: int = 0\n    n: int = 1\n    stop: Union[str, List[str]] = ()\n    stop_token_ids: Optional[List[int]] = ()\n    stop_regex: Optional[Union[str, List[str]]] = ()\n    temperature: float = 1.0\n    top_p: float = 1.0\n    top_k: int = -1  # -1 means disable\n    min_p: float = 0.0\n    frequency_penalty: float = 0.0\n    presence_penalty: float = 0.0\n    ignore_eos: bool = False\n    return_logprob: Optional[bool] = None\n    logprob_start_len: Optional[int] = None\n    top_logprobs_num: Optional[int] = None\n    return_text_in_logprobs: Optional[bool] = None\n    json_schema: Optional[str] = None\n\n    # for constrained generation, not included in to_xxx_kwargs\n    dtype: Optional[str] = None\n    regex: Optional[str] = None\n\n    def clone(self):\n        return SglSamplingParams(\n            self.max_new_tokens,\n            self.min_new_tokens,\n            self.n,\n            self.stop,\n            self.stop_token_ids,\n            self.stop_regex,\n            self.temperature,\n            self.top_p,\n            self.top_k,\n            self.min_p,\n            self.frequency_penalty,\n            self.presence_penalty,\n            self.ignore_eos,\n            self.return_logprob,\n            self.logprob_start_len,\n            self.top_logprobs_num,\n            self.return_text_in_logprobs,\n            self.json_schema,\n        )\n\n    def 
to_openai_kwargs(self):\n        # OpenAI does not support top_k, so we drop it here\n        if self.regex is not None:\n            warnings.warn(\"Regular expression is not supported in the OpenAI backend.\")\n        return {\n            \"max_tokens\": self.max_new_tokens,\n            \"max_completion_tokens\": self.max_new_tokens,\n            \"n\": self.n,\n            \"stop\": self.stop or None,\n            \"temperature\": self.temperature,\n            \"top_p\": self.top_p,\n            \"frequency_penalty\": self.frequency_penalty,\n            \"presence_penalty\": self.presence_penalty,\n        }\n\n    def to_vertexai_kwargs(self):\n        if self.regex is not None:\n            warnings.warn(\n                \"Regular expression is not supported in the VertexAI backend.\"\n            )\n        return {\n            \"candidate_count\": 1,\n            \"max_output_tokens\": self.max_new_tokens,\n            \"stop_sequences\": self.stop,\n            \"temperature\": self.temperature,\n            \"top_p\": self.top_p,\n            \"top_k\": self.top_k if self.top_k > 0 else None,\n        }\n\n    def to_anthropic_kwargs(self):\n        # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here\n        if self.regex is not None:\n            warnings.warn(\n                \"Regular expression is not supported in the Anthropic backend.\"\n            )\n        return {\n            \"max_tokens\": self.max_new_tokens,\n            \"stop_sequences\": (\n                self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]\n            ),\n            \"temperature\": self.temperature,\n            \"top_p\": self.top_p,\n            \"top_k\": self.top_k,\n        }\n\n    def to_litellm_kwargs(self):\n        if self.regex is not None:\n            warnings.warn(\"Regular expression is not supported in the LiteLLM backend.\")\n        return {\n            \"max_tokens\": 
self.max_new_tokens,\n            \"stop\": self.stop or None,\n            \"temperature\": self.temperature,\n            \"top_p\": self.top_p,\n            \"frequency_penalty\": self.frequency_penalty,\n            \"presence_penalty\": self.presence_penalty,\n        }\n\n    def to_srt_kwargs(self):\n        return {\n            \"max_new_tokens\": self.max_new_tokens,\n            \"min_new_tokens\": self.min_new_tokens,\n            \"n\": self.n,\n            \"stop\": self.stop,\n            \"stop_token_ids\": self.stop_token_ids,\n            \"stop_regex\": self.stop_regex,\n            \"temperature\": self.temperature,\n            \"top_p\": self.top_p,\n            \"top_k\": self.top_k,\n            \"min_p\": self.min_p,\n            \"frequency_penalty\": self.frequency_penalty,\n            \"presence_penalty\": self.presence_penalty,\n            \"ignore_eos\": self.ignore_eos,\n            \"regex\": self.regex,\n            \"json_schema\": self.json_schema,\n        }\n\n\nclass SglFunction:\n    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):\n        self.func = func\n        self.num_api_spec_tokens = num_api_spec_tokens\n        self.bind_arguments = bind_arguments or {}\n        self.pin_prefix_rid = None\n\n        # Parse arguments\n        argspec = inspect.getfullargspec(func)\n        assert argspec.args[0] == \"s\", 'The first argument must be \"s\"'\n        self.arg_names = argspec.args[1:]\n        self.arg_defaults = argspec.defaults if argspec.defaults is not None else []\n\n    def bind(self, **kwargs):\n        assert all(key in self.arg_names for key in kwargs)\n\n        new_bind_dict = {**self.bind_arguments, **kwargs}\n        return SglFunction(self.func, bind_arguments=new_bind_dict)\n\n    def run(\n        self,\n        *args,\n        max_new_tokens: int = 128,\n        n: int = 1,\n        stop: Optional[Union[str, List[str]]] = None,\n        stop_token_ids: Optional[List[int]] = 
None,\n        stop_regex: Optional[Union[str, List[str]]] = None,\n        temperature: float = 1.0,\n        top_p: float = 1.0,\n        top_k: int = -1,\n        min_p: float = 0.0,\n        frequency_penalty: float = 0.0,\n        presence_penalty: float = 0.0,\n        ignore_eos: bool = False,\n        return_logprob: Optional[bool] = None,\n        logprob_start_len: Optional[int] = None,\n        top_logprobs_num: Optional[int] = None,\n        return_text_in_logprobs: Optional[bool] = None,\n        stream: bool = False,\n        backend=None,\n        use_thread: bool = True,\n        **kwargs,\n    ):\n        from sglang.lang.interpreter import run_program\n\n        # avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/\n        if stop is None:\n            stop = []\n        if stop_token_ids is None:\n            stop_token_ids = []\n        if stop_regex is None:\n            stop_regex = []\n\n        default_sampling_para = SglSamplingParams(\n            max_new_tokens=max_new_tokens,\n            n=n,\n            stop=stop,\n            stop_token_ids=stop_token_ids,\n            stop_regex=stop_regex,\n            temperature=temperature,\n            top_p=top_p,\n            top_k=top_k,\n            min_p=min_p,\n            frequency_penalty=frequency_penalty,\n            presence_penalty=presence_penalty,\n            ignore_eos=ignore_eos,\n            return_logprob=return_logprob,\n            logprob_start_len=logprob_start_len,\n            top_logprobs_num=top_logprobs_num,\n            return_text_in_logprobs=return_text_in_logprobs,\n        )\n        backend = backend or global_config.default_backend\n        return run_program(\n            self,\n            backend,\n            args,\n            kwargs,\n            default_sampling_para,\n            stream,\n            use_thread=use_thread,\n        )\n\n    def run_batch(\n        self,\n        batch_kwargs,\n        *,\n        
max_new_tokens: int = 128,\n        n: int = 1,\n        stop: Optional[Union[str, List[str]]] = None,\n        stop_token_ids: Optional[List[int]] = None,\n        stop_regex: Optional[Union[str, List[str]]] = None,\n        temperature: float = 1.0,\n        top_p: float = 1.0,\n        top_k: int = -1,\n        min_p: float = 0.0,\n        frequency_penalty: float = 0.0,\n        presence_penalty: float = 0.0,\n        ignore_eos: bool = False,\n        return_logprob: Optional[bool] = None,\n        logprob_start_len: Optional[int] = None,\n        top_logprobs_num: Optional[int] = None,\n        return_text_in_logprobs: Optional[bool] = None,\n        backend=None,\n        num_threads: Union[str, int] = \"auto\",\n        progress_bar: bool = False,\n        generator_style: bool = False,\n    ):\n        from sglang.lang.interpreter import run_program_batch\n\n        if stop is None:\n            stop = []\n        if stop_token_ids is None:\n            stop_token_ids = []\n        if stop_regex is None:\n            stop_regex = []\n\n        assert isinstance(batch_kwargs, (list, tuple))\n        if len(batch_kwargs) == 0:\n            return []\n        if not isinstance(batch_kwargs[0], dict):\n            num_programs = len(batch_kwargs)\n            # change the list of argument values to dict of arg_name -> arg_value\n            batch_kwargs = [\n                {self.arg_names[i]: v for i, v in enumerate(arg_values)}\n                for arg_values in batch_kwargs\n                if isinstance(arg_values, (list, tuple))\n                and len(self.arg_names) - len(self.arg_defaults)\n                <= len(arg_values)\n                <= len(self.arg_names)\n            ]\n            # Ensure to raise an exception if the number of arguments mismatch\n            if len(batch_kwargs) != num_programs:\n                raise Exception(\"Given arguments mismatch the SGL function signature\")\n\n        default_sampling_para = SglSamplingParams(\n  
          max_new_tokens=max_new_tokens,\n            n=n,\n            stop=stop,\n            stop_token_ids=stop_token_ids,\n            stop_regex=stop_regex,\n            temperature=temperature,\n            top_p=top_p,\n            top_k=top_k,\n            min_p=min_p,\n            frequency_penalty=frequency_penalty,\n            presence_penalty=presence_penalty,\n            ignore_eos=ignore_eos,\n            return_logprob=return_logprob,\n            logprob_start_len=logprob_start_len,\n            top_logprobs_num=top_logprobs_num,\n            return_text_in_logprobs=return_text_in_logprobs,\n        )\n        backend = backend or global_config.default_backend\n        return run_program_batch(\n            self,\n            backend,\n            batch_kwargs,\n            default_sampling_para,\n            num_threads,\n            progress_bar,\n            generator_style=generator_style,\n        )\n\n    def trace(self, *, backend=None, **kwargs):\n        from sglang.lang.tracer import trace_program\n\n        backend = backend or global_config.default_backend\n        return trace_program(self, kwargs, backend)\n\n    def cache(self, backend=None):\n        from sglang.lang.interpreter import cache_program\n\n        backend = backend or global_config.default_backend\n        return cache_program(self, backend)\n\n    def __call__(self, *args, **kwargs):\n        from sglang.lang.tracer import TracingScope\n\n        tracing_scope = TracingScope.get_current_scope()\n        if tracing_scope is None:\n            return self.run(*args, **kwargs)\n        else:\n            kwargs[\"backend\"] = tracing_scope.tracer_state.backend\n            return self.trace(*args, **kwargs)\n\n\nclass SglExpr:\n    node_ct = 0\n\n    def __init__(self):\n        self.node_id = SglExpr.node_ct\n        self.prev_node = None\n        self.pid = None\n        SglExpr.node_ct += 1\n\n    def __add__(self, other):\n        if isinstance(other, str):\n        
    other = SglConstantText(other)\n        assert isinstance(other, SglExpr)\n\n        return self.concatenate_ir(self, other)\n\n    def __radd__(self, other):\n        if isinstance(other, str):\n            other = SglConstantText(other)\n        assert isinstance(other, SglExpr), f\"{other}\"\n\n        return self.concatenate_ir(other, self)\n\n    def concatenate_ir(self, a, b):\n        if isinstance(a, SglExprList):\n            if isinstance(b, SglExprList):\n                return SglExprList(a.expr_list + b.expr_list)\n            else:\n                return SglExprList(a.expr_list + [b])\n        elif isinstance(b, SglExprList):\n            return SglExprList([a] + b.expr_list)\n\n        return SglExprList([a, b])\n\n    def print_graph_dfs(self):\n        ret = [\"\"]\n        visited = set()\n\n        def dfs_print(x):\n            if x is None or x in visited:\n                return\n            visited.add(x)\n\n            # Print dependency\n            if x.prev_node is not None:\n                dfs_print(x.prev_node)\n\n            if isinstance(x, SglExprList):\n                for y in x.expr_list:\n                    dfs_print(y)\n            # elif isinstance(x, SglRole):\n            #    dfs_print(x.expr)\n            elif isinstance(x, SglVariable):\n                dfs_print(x.source)\n\n            # Print the node itself\n            if isinstance(x, (SglFork, SglGetForkItem)):\n                ret[0] += f\"%{x.node_id} = {x}\\n\"\n            else:\n                if x.prev_node is not None:\n                    ret[0] += (\n                        f\"%{x.node_id} = %{x.prev_node.node_id} + \" + str(x) + \"\\n\"\n                    )\n                else:\n                    ret[0] += f\"%{x.node_id} = \" + str(x) + \"\\n\"\n\n        dfs_print(self)\n        return ret[0]\n\n\nclass SglExprList(SglExpr):\n    def __init__(self, expr_list: List[SglExpr]):\n        super().__init__()\n        self.expr_list = 
expr_list\n\n    def __repr__(self):\n        return f\"ExprList({self.expr_list})\"\n\n\nclass SglArgument(SglExpr):\n    def __init__(self, name: str, value: str):\n        super().__init__()\n        self.name = name\n        self.value = value\n\n    def __repr__(self):\n        return f\"Argument(name={self.name}, value={repr(self.value)})\"\n\n    def __len__(self):\n        return len(self.value)\n\n    def __getitem__(self, i):\n        return self.value[i]\n\n    def __int__(self):\n        return self.value\n\n    def __bool__(self):\n        return self.value\n\n    def __format__(self, *args):\n        raise TypeError(\n            \"Cannot put argument inside a f-string. \"\n            \"This is not compatible with the tracer. \"\n        )\n\n\nclass SglImage(SglExpr):\n    def __init__(self, path: str):\n        self.path = path\n\n    def __repr__(self) -> str:\n        return f\"SglImage({self.path})\"\n\n\nclass SglVideo(SglExpr):\n    def __init__(self, path: str, num_frames: int):\n        self.path = path\n        self.num_frames = num_frames\n\n    def __repr__(self) -> str:\n        return f\"SglVideo({self.path}, {self.num_frames})\"\n\n\nclass SglGen(SglExpr):\n    def __init__(\n        self,\n        name: Optional[str] = None,\n        max_new_tokens: Optional[int] = None,\n        min_new_tokens: Optional[int] = None,\n        n: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stop_token_ids: Optional[List[int]] = None,\n        stop_regex: Optional[Union[str, List[str]]] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        top_k: Optional[int] = None,\n        min_p: Optional[float] = None,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        ignore_eos: Optional[bool] = None,\n        return_logprob: Optional[bool] = None,\n        logprob_start_len: Optional[int] = None,\n        
top_logprobs_num: Optional[int] = None,\n        return_text_in_logprobs: Optional[bool] = None,\n        dtype: Optional[type] = None,\n        regex: Optional[str] = None,\n        json_schema: Optional[str] = None,\n    ):\n        \"\"\"Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md\"\"\"\n        super().__init__()\n        self.name = name\n        self.sampling_params = SglSamplingParams(\n            max_new_tokens=max_new_tokens,\n            min_new_tokens=min_new_tokens,\n            n=n,\n            stop=stop,\n            stop_regex=stop_regex,\n            stop_token_ids=stop_token_ids,\n            temperature=temperature,\n            top_p=top_p,\n            top_k=top_k,\n            min_p=min_p,\n            frequency_penalty=frequency_penalty,\n            presence_penalty=presence_penalty,\n            ignore_eos=ignore_eos,\n            return_logprob=return_logprob,\n            logprob_start_len=logprob_start_len,\n            top_logprobs_num=top_logprobs_num,\n            return_text_in_logprobs=return_text_in_logprobs,\n            dtype=dtype,\n            regex=regex,\n            json_schema=json_schema,\n        )\n\n    def __repr__(self):\n        return f\"Gen('{self.name}')\"\n\n\nclass SglConstantText(SglExpr):\n    def __init__(self, value: str):\n        super().__init__()\n        self.value = value\n\n    def __repr__(self):\n        return f\"Constant({repr(self.value)})\"\n\n\nclass SglRoleBegin(SglExpr):\n    def __init__(self, role: str):\n        super().__init__()\n        self.role = role\n\n    def __repr__(self):\n        return f\"RoleBegin({self.role})\"\n\n\nclass SglRoleEnd(SglExpr):\n    def __init__(self, role: str):\n        super().__init__()\n        self.role = role\n\n    def __repr__(self):\n        return f\"RoleEnd({self.role})\"\n\n\nclass SglSelect(SglExpr):\n\n    def __init__(\n        self,\n        name: str,\n        choices: List[str],\n        
temperature: float,\n        choices_method: ChoicesSamplingMethod,\n    ):\n        super().__init__()\n        self.name = name\n        self.choices = choices\n        self.temperature = temperature\n        self.choices_method = choices_method\n\n    def __repr__(self):\n        return f\"Select({self.name}, choices={self.choices}, choices_method={self.choices_method})\"\n\n\nclass SglFork(SglExpr):\n    def __init__(self, number: int, position_ids_offset=None):\n        super().__init__()\n        self.number = number\n        self.position_ids_offset = position_ids_offset\n\n    def __repr__(self):\n        return (\n            f\"Fork(%{self.prev_node.node_id}, number={self.number}, \"\n            f\"position_ids_offset={self.position_ids_offset})\"\n        )\n\n\nclass SglGetForkItem(SglExpr):\n    def __init__(self, index: int):\n        super().__init__()\n        self.index = index\n\n    def __repr__(self):\n        return f\"GetForkItem(%{self.prev_node.node_id}, index={self.index})\"\n\n\nclass SglVariable(SglExpr):\n    def __init__(self, name: str, source):\n        super().__init__()\n        self.name = name\n        self.source = source\n\n    def __repr__(self):\n        return f\"Variable('{self.name}', source=%{self.source.node_id})\"\n\n\nclass SglVarScopeBegin(SglExpr):\n    def __init__(self, name: str):\n        super().__init__()\n        self.name = name\n\n    def __repr__(self):\n        return f\"VarScopeBegin('{self.name}')\"\n\n\nclass SglVarScopeEnd(SglExpr):\n    def __init__(self, name: str):\n        super().__init__()\n        self.name = name\n\n    def __repr__(self):\n        return f\"VarScopeEnd('{self.name}')\"\n\n\nclass SglConcateAndAppend(SglExpr):\n    def __init__(self, states):\n        super().__init__()\n        self.states = states\n\n    def __repr__(self):\n        return f\"ConcatenateAndAppend('{self.states}')\"\n\n\nclass SglCommitLazy(SglExpr):\n    def __init__(self):\n        super().__init__()\n\n    
def __repr__(self):\n        return \"CommitLazy()\"\n\n\nclass SglSeparateReasoning(SglExpr):\n    def __init__(self, model_type: str, expr: SglExpr):\n        super().__init__()\n        self.model_type = model_type\n\n        self.expr = expr\n        self.name = None\n        self._process_expr(expr)\n\n    def process_name_for_reasoning(self, name):\n        if not name:\n            raise ValueError(\"name must be provided\")\n        return f\"{name}_reasoning_content\"\n\n    def _process_expr(self, expr):\n        if isinstance(expr, SglGen):\n            self.name = self.process_name_for_reasoning(expr.name)\n        elif isinstance(expr, SglSelect):\n            self.name = self.process_name_for_reasoning(expr.name)\n        elif isinstance(expr, SglExprList):\n            for x in expr.expr_list:\n                self._process_expr(x)\n\n    def __repr__(self):\n        return f\"SeparateReasoning(model_type={self.model_type}, name={self.name})\"\n"
  },
  {
    "path": "python/sglang/lang/tracer.py",
    "content": "\"\"\"Tracing a program.\"\"\"\n\nimport uuid\nfrom typing import Any, Dict, List, Optional\n\nfrom sglang.lang.backend.base_backend import BaseBackend\nfrom sglang.lang.interpreter import ProgramState, ProgramStateGroup\nfrom sglang.lang.ir import (\n    SglArgument,\n    SglConstantText,\n    SglExpr,\n    SglExprList,\n    SglFork,\n    SglGen,\n    SglGetForkItem,\n    SglRoleBegin,\n    SglRoleEnd,\n    SglSelect,\n    SglVariable,\n    SglVarScopeBegin,\n    SglVarScopeEnd,\n)\n\n\nclass StopTracing(Exception):\n    pass\n\n\ndef extract_prefix_by_tracing(program, backend):\n    # Create dummy arguments\n    dummy_arguments = {name: SglArgument(name, None) for name in program.arg_names}\n    arguments = dummy_arguments\n    arguments.update(program.bind_arguments)\n\n    # Trace\n    tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)\n    try:\n        with TracingScope(tracer):\n            tracer.ret_value = program.func(tracer, **arguments)\n    except (StopTracing, TypeError, AttributeError):\n        # Some exceptions may not be caught\n        pass\n\n    # Run and cache prefix\n    prefix = \"\"\n    for expr in tracer.flatten_nodes():\n        if isinstance(expr, SglConstantText):\n            prefix += expr.value\n        else:\n            break\n    return prefix\n\n\ndef trace_program(program, arguments, backend):\n    # Create dummy backend\n    if backend is None:\n        backend = BaseBackend()\n\n    # Create dummy arguments\n    dummy_arguments = {\n        name: SglArgument(name, None)\n        for name in program.arg_names\n        if name not in arguments\n    }\n    arguments.update(dummy_arguments)\n    arguments.update(program.bind_arguments)\n\n    # Trace\n    tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)\n    with TracingScope(tracer):\n        tracer.ret_value = program.func(tracer, **arguments)\n    return tracer\n\n\nclass TracerProgramState(ProgramState):\n    def 
__init__(self, backend, arguments, only_trace_prefix):\n        self.pid = uuid.uuid4().hex\n        self.backend = backend\n        self.arguments: Dict[str, Any] = arguments\n        self.only_trace_prefix = only_trace_prefix\n\n        if hasattr(backend, \"endpoint\"):\n            self.backend = backend.endpoint\n\n        self.nodes = []\n        self.last_node = None\n        self.variables = {}\n        self.ret_value = None\n\n        # For completion\n\n        # For chat\n        self.messages_ = []\n        self.cur_role = None\n        self.chat_template = self.backend.get_chat_template()\n\n        # For multi states\n        self.child_states = []\n\n        cur_scope = TracingScope.get_current_scope()\n        if cur_scope is not None:\n            cur_scope.add_child_state(self)\n\n    ##################################\n    ########### Public API ###########\n    ##################################\n\n    def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):\n        assert size >= 1\n\n        if self.only_trace_prefix:\n            raise StopTracing()\n\n        fork_node = SglFork(size)\n        fork_node.prev_node = self.last_node\n\n        states = [\n            TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)\n            for _ in range(size)\n        ]\n\n        for i in range(size):\n            node = SglGetForkItem(i)\n            node.prev_node = fork_node\n            states[i].last_node = node\n            states[i].variables = dict(self.variables)\n            states[i].messages_ = list(self.messages_)\n            states[i].cur_role = self.cur_role\n            states[i].chat_template = self.chat_template\n\n        state_group = ProgramStateGroup(states, self)\n\n        return state_group\n\n    ##################################\n    ########## Internal API ##########\n    ##################################\n\n    def _append_node(self, other: SglExpr):\n        
self.nodes.append(other)\n        other.prev_node = self.last_node\n        self.last_node = other\n\n    def _execute(self, other: SglExpr):\n        if isinstance(other, str):\n            other = SglConstantText(other)\n\n        other.pid = self.pid\n\n        if isinstance(other, SglConstantText):\n            self._execute_fill(other)\n        elif isinstance(other, SglGen):\n            self._execute_gen(other)\n        elif isinstance(other, SglSelect):\n            self._execute_select(other)\n        elif isinstance(other, SglExprList):\n            for x in other.expr_list:\n                self._execute(x)\n        elif isinstance(other, SglRoleBegin):\n            self._execute_role_begin(other)\n        elif isinstance(other, SglRoleEnd):\n            self._execute_role_end(other)\n        elif isinstance(other, SglVarScopeBegin):\n            self._execute_var_scope_begin(other)\n        elif isinstance(other, SglVarScopeEnd):\n            self._execute_var_scope_end(other)\n        else:\n            if self.only_trace_prefix:\n                raise StopTracing()\n            else:\n                self._append_node(other)\n\n        return self\n\n    def __iadd__(self, other):\n        self._execute(other)\n        return self\n\n    def _execute_fill(self, expr: SglConstantText):\n        if isinstance(expr, str):\n            expr = SglConstantText(expr)\n        self._append_node(expr)\n\n    def _execute_gen(self, expr: SglGen):\n        name = expr.name if expr.name is not None else \"gen_\" + str(len(self.variables))\n        new_node = SglVariable(name, source=expr)\n        self.variables[name] = new_node\n        self._append_node(expr)\n\n    def _execute_select(self, expr: SglSelect):\n        name = (\n            expr.name if expr.name is not None else \"select_\" + str(len(self.variables))\n        )\n        new_node = SglVariable(name, source=expr)\n        self.variables[name] = new_node\n        self._append_node(expr)\n\n    def 
_execute_role_begin(self, expr: SglRoleBegin):\n        assert self.cur_role is None, \"Nested roles are not allowed.\"\n\n        if len(self.messages_) == 0 and expr.role != \"system\":\n            # Insert default system message\n            default_system = self.chat_template.default_system_prompt\n            if default_system:\n                self._execute_role_begin(SglRoleBegin(\"system\"))\n                self._execute_fill(default_system)\n                self._execute_role_end(SglRoleEnd(\"system\"))\n\n        self.cur_role = expr.role\n\n        prefix, suffix = self.chat_template.get_prefix_and_suffix(\n            expr.role, self.messages_\n        )\n\n        self._execute_fill(prefix)\n\n    def _execute_role_end(self, expr: SglRoleEnd):\n        prefix, suffix = self.chat_template.get_prefix_and_suffix(\n            expr.role, self.messages_\n        )\n\n        self._execute_fill(suffix)\n\n        self.messages_.append({\"role\": expr.role, \"content\": \"\"})\n\n        self.cur_role = None\n\n    def _execute_var_scope_end(self, expr: SglVarScopeEnd):\n        new_node = SglVariable(expr.name, source=self.last_node)\n        self.variables[expr.name] = new_node\n\n    def get_var(self, name):\n        ret = self.arguments.get(name, None)\n        if ret is not None:\n            return ret\n\n        v = self.variables[name]\n        return SglVariable(v.name, v.source)\n\n    def flatten_nodes(self):\n        def traverse(cur):\n            if isinstance(cur, SglExprList):\n                for child in cur.expr_list:\n                    traverse(child)\n            else:\n                ret.append(cur)\n\n        ret = []\n        for x in self.nodes:\n            traverse(x)\n        return ret\n\n    def __del__(self):\n        pass\n\n\nclass TracingScope:\n    cur_scope = None\n\n    def __init__(self, tracer_state: TracerProgramState):\n        self.tracer_state = tracer_state\n        self.last_scope = TracingScope.cur_scope\n\n  
  def __enter__(self):\n        TracingScope.cur_scope = self\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        TracingScope.cur_scope = self.last_scope\n\n    @staticmethod\n    def get_current_scope():\n        return TracingScope.cur_scope\n\n    def add_child_state(self, state: TracerProgramState):\n        cur_scope = self\n        while cur_scope is not None:\n            cur_scope.tracer_state.child_states.append(state)\n            cur_scope = cur_scope.last_scope\n"
  },
  {
    "path": "python/sglang/launch_server.py",
    "content": "\"\"\"Launch the inference server.\"\"\"\n\nimport asyncio\nimport os\nimport sys\nimport warnings\n\nfrom sglang.srt.server_args import prepare_server_args\nfrom sglang.srt.utils import kill_process_tree\nfrom sglang.srt.utils.common import suppress_noisy_warnings\n\nsuppress_noisy_warnings()\n\n\ndef run_server(server_args):\n    \"\"\"Run the server based on server_args.grpc_mode and server_args.encoder_only.\"\"\"\n    if server_args.encoder_only:\n        # For encoder disaggregation\n        if server_args.grpc_mode:\n            from sglang.srt.disaggregation.encode_grpc_server import (\n                serve_grpc_encoder,\n            )\n\n            asyncio.run(serve_grpc_encoder(server_args))\n        else:\n            from sglang.srt.disaggregation.encode_server import launch_server\n\n            launch_server(server_args)\n    elif server_args.grpc_mode:\n        from sglang.srt.entrypoints.grpc_server import serve_grpc\n\n        asyncio.run(serve_grpc(server_args))\n    elif server_args.use_ray:\n        try:\n            from sglang.srt.ray.http_server import launch_server\n        except ImportError:\n            raise ImportError(\n                \"Ray is required for --use-ray mode. 
\"\n                \"Install it with: pip install 'sglang[ray]'\"\n            )\n\n        launch_server(server_args)\n    else:\n        # Default mode: HTTP mode.\n        from sglang.srt.entrypoints.http_server import launch_server\n\n        launch_server(server_args)\n\n\nif __name__ == \"__main__\":\n    warnings.warn(\n        \"'python -m sglang.launch_server' is still supported, but \"\n        \"'sglang serve' is the recommended entrypoint.\\n\"\n        \"  Example: sglang serve --model-path <model> [options]\",\n        UserWarning,\n        stacklevel=1,\n    )\n\n    server_args = prepare_server_args(sys.argv[1:])\n\n    try:\n        run_server(server_args)\n    finally:\n        kill_process_tree(os.getpid(), include_parent=False)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/CLAUDE.md",
    "content": "# CLAUDE.md — sglang-diffusion (multimodal_gen)\n\n## What is this?\n\nSGLang's diffusion/multimodal generation subsystem. Separate from the LLM runtime (`srt`). Supports 20+ image/video diffusion models (Wan, FLUX, HunyuanVideo, LTX, Qwen-Image, etc.) with distributed inference, LoRA, and multiple attention backends.\n\n## Quick Start\n\n```bash\n# One-shot generation\nsglang generate --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers --prompt \"A curious raccoon\" --save-output\n\n# Start server\nsglang serve --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers --num-gpus 4\n\n# Python API\nfrom sglang import DiffGenerator\ngen = DiffGenerator.from_pretrained(\"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\")\nresult = gen.generate(sampling_params_kwargs={\"prompt\": \"A curious raccoon\"})\n```\n\n## Architecture\n\n```\nCLI / Python API / HTTP Server (FastAPI + OpenAI-compatible)\n    ↓ ZMQ\nScheduler (request queue, batching, dispatch)\n    ↓ multiprocessing pipes\nGPU Worker(s) → ComposedPipeline (stages: TextEncode → Denoise → Decode)\n```\n\n### Key Directories\n\n```\nruntime/\n├── entrypoints/        # CLI (generate/serve), HTTP server, Python API (DiffGenerator)\n├── managers/           # scheduler.py, gpu_worker.py\n├── pipelines_core/     # ComposedPipelineBase, stages/, schedule_batch.py (Req/OutputBatch)\n├── pipelines/          # Model-specific pipelines (wan, flux, hunyuan, ltx, qwen_image, ...)\n├── models/             # encoders/, dits/, vaes/, schedulers/\n├── layers/             # attention/, lora/, quantization/\n├── loader/             # Model loading, weight utils\n├── server_args.py      # ServerArgs (all CLI/config params)\n└── distributed/        # Multi-GPU (TP, SP: ulysses/ring)\nconfigs/\n├── pipeline_configs/    # Per-model pipeline configs\n├── sample/             # SamplingParams\n└── models/             # DiT, VAE, Encoder configs\n```\n\n### Key Classes\n\n| Class | Location | Purpose |\n|-------|----------|---------|\n| `DiffGenerator` | 
`runtime/entrypoints/diffusion_generator.py` | Python API entry point |\n| `ComposedPipelineBase` | `runtime/pipelines_core/composed_pipeline_base.py` | Pipeline orchestrator (stages) |\n| `Scheduler` | `runtime/managers/scheduler.py` | ZMQ event loop, request dispatch |\n| `GPUWorker` | `runtime/managers/gpu_worker.py` | GPU inference worker |\n| `Req` / `OutputBatch` | `runtime/pipelines_core/schedule_batch.py` | Request/output containers |\n| `ServerArgs` | `runtime/server_args.py` | All config params |\n| `SamplingParams` | `configs/sample/sampling_params.py` | Generation params |\n| `PipelineConfig` | `configs/pipeline_configs/base.py` | Model structure config |\n\n### Key Functions\n\n| Function | Module | Purpose |\n|----------|--------|---------|\n| `build_pipeline()` | `runtime/pipelines_core/__init__.py` | Instantiate pipeline from model_path |\n| `get_model_info()` | `registry.py` | Resolve pipeline + config classes |\n| `launch_server()` | `runtime/launch_server.py` | Start multi-process server |\n\n## Adding a New Model\n\n1. Create pipeline in `runtime/pipelines/` extending `ComposedPipelineBase`\n2. Define stages via `create_pipeline_stages()` (TextEncoding → Denoising → Decoding)\n3. Add config in `configs/pipeline_configs/`\n4. Register in `registry.py` via `register_configs()`\n\n## Multi-GPU\n\n```bash\n# Sequence parallelism (video frames across GPUs)\nsglang serve --model-path ... --num-gpus 4 --ulysses-degree 2 --ring-degree 2\n\n# Tensor parallelism (model layers across GPUs)\nsglang serve --model-path ... 
--num-gpus 2 --tp-size 2\n```\n\n## Testing\n\n```bash\n# Tests live in test/ subdirectory\npython -m pytest python/sglang/multimodal_gen/test/\n\n# No need to pre-download models — auto-downloaded at runtime\n# Dependencies assumed already installed via `pip install -e \"python[diffusion]\"`\n```\n\n## Performance Tuning\n\nFor questions about optimal performance, fastest commands, VRAM reduction, or best flag combinations for a given model/GPU setup, **read the [diffusion-optimal-perf skill](skills/diffusion-optimal-perf/SKILL.md)**. It contains a complete table of all lossless and lossy optimization flags with trade-offs, quick recipes, and tips.\n\n### Perf Measurement\n\nLook for `Pixel data generated successfully in xxxx seconds` in console output. With warmup enabled, use the line containing `warmup excluded` for accurate timing.\n\n## Env Vars\n\nDefined in `envs.py` (300+ vars). Key ones:\n- `SGLANG_DIFFUSION_ATTENTION_BACKEND` — attention backend override\n- `SGLANG_CACHE_DIT_ENABLED` — enable Cache-DiT acceleration\n- `SGLANG_CLOUD_STORAGE_TYPE` — cloud output storage (s3, etc.)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/SKILL.md",
    "content": "---\nname: diffusion-kernel\ndescription: Index for SGLang Diffusion kernel development skills.\n---\n\n# Diffusion Kernel Skills\n\n## Rule: Follow User Kernel Language Preference\n\nIf the user explicitly states a preference for **Triton** or **CUDA**, follow that preference when implementing and optimizing kernels (even if the other option could work). Do not “pick for convenience”.\n\n## Directory Layout\n\n```\npython/sglang/multimodal_gen/.claude/skills/diffusion-kernel/\n├── SKILL.md\n├── add-triton-kernel.md\n├── add-cuda-kernel.md\n├── diffusion-benchmark-and-profile.md\n├── nsight-profiler.md\n├── use-efficient-diffusion-kernels.md\n├── references/\n│   ├── kernel-templates.md          # Copy-paste CUDA kernel templates (sglang JIT style)\n│   ├── troubleshooting.md           # Build/perf/integration issues & fixes\n│   ├── h100-optimization-guide.md   # H100 (sm_90) deep dive\n│   ├── a100-optimization-guide.md   # A100 (sm_80) deep dive\n│   └── t4-optimization-guide.md     # T4 (sm_75, FP16 only) deep dive\n└── scripts/\n    ├── diffusion_skill_env.py       # Preflight helper: repo root, write check, idle GPU selection\n    ├── bench_diffusion_rmsnorm.py   # RMSNorm micro-benchmark vs PyTorch\n    └── bench_diffusion_denoise.py   # End-to-end denoise benchmark (sglang generate)\n```\n\n## Index\n\nBefore running any benchmark, profiler, or kernel-validation command, use\n`scripts/diffusion_skill_env.py` to derive the repo root from `sglang.__file__`,\nverify the repo is writable, export `FLASHINFER_DISABLE_VERSION_CHECK=1`, and\nchoose idle GPU(s) for the run.\n\n- [scripts/diffusion_skill_env.py](scripts/diffusion_skill_env.py)\n\n  Shared preflight helper for all diffusion skill commands. 
Use it to print the repo root, create benchmark/profile output directories, and choose idle GPUs before running `sglang generate`, torch profiler, nsys, or ncu.\n\n- [add-triton-kernel.md](./add-triton-kernel.md)\n\n  Step-by-step guide for adding a new Triton kernel to SGLang Diffusion's `jit_kernel/diffusion/triton/` module, including authoring, autotune, `torch.compile` compatibility, integration, and tests. Use for fused elementwise ops, norm variants, RoPE variants, or when NPU/CPU fallback is needed.\n\n- [add-cuda-kernel.md](./add-cuda-kernel.md)\n\n  Step-by-step guide for adding a JIT CUDA kernel. CUDA source goes in `jit_kernel/csrc/diffusion/<op>.cuh`; Python wrapper at `jit_kernel/diffusion/<op>.py`. Uses SGLang's JIT compilation system (`load_jit`, `cache_once`) and internal abstractions (`TensorMatcher`, `device::AlignedVector`, `host::LaunchKernel`, `device::warp::reduce_sum`). Use for bandwidth-bound reductions (RMSNorm, LayerNorm) or ops needing fine-grained vectorization and shared memory control. Adapted from [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels).\n\n- [use-efficient-diffusion-kernels.md](./use-efficient-diffusion-kernels.md)\n\n  Practical guidance for using SGLang Diffusion fused kernels and fast CUDA paths, including constraints, fallbacks, and where the fused ops are wired into the runtime.\n\n- [diffusion-benchmark-and-profile.md](./diffusion-benchmark-and-profile.md)\n\n  Denoise-stage benchmark and profiling guide for SGLang Diffusion models. Three profiling levels: Level 1 (torch.profiler — kernel time ranking), Level 2 (nsys — category breakdown), Level 3 (ncu — per-kernel bandwidth/occupancy/roofline analysis). 
**ncu is critical for kernel optimization** — always use it when writing or tuning custom kernels to verify hardware saturation.\n\n- [nsight-profiler.md](./nsight-profiler.md)\n\n  Advanced profiling skill for NVIDIA Nsight Systems / Nsight Compute: collecting traces, reading reports, and interpreting kernel-level performance metrics.\n\n## References (GPU optimization guides, templates, troubleshooting)\n\nLoaded by `add-cuda-kernel.md`. Adapted from [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels).\n\n- [references/kernel-templates.md](references/kernel-templates.md) — copy-paste ready sglang JIT CUDA templates: element-wise (SiLU), row-reduction (RMSNorm), fused AdaLN, Python wrapper, test, benchmark\n- [references/troubleshooting.md](references/troubleshooting.md) — build errors, performance issues, torch.compile compatibility, kernel injection pitfalls\n- [references/h100-optimization-guide.md](references/h100-optimization-guide.md) — H100 (sm_90): AlignedVector benchmarks, warp reductions, occupancy, TMA, PDL\n- [references/a100-optimization-guide.md](references/a100-optimization-guide.md) — A100 (sm_80): cp.async, TF32, 2:4 sparsity, H100→A100 migration checklist\n- [references/t4-optimization-guide.md](references/t4-optimization-guide.md) — T4 (sm_75): FP16 only, 320 GB/s bandwidth, 64 KB shared mem, 16 GB memory management\n\n## Scripts (runnable benchmarks)\n\n- [scripts/diffusion_skill_env.py](scripts/diffusion_skill_env.py) — preflight helper: repo root discovery via `sglang.__file__`, write-access probe, benchmark/profile output directories, idle GPU selection\n- [scripts/bench_diffusion_rmsnorm.py](scripts/bench_diffusion_rmsnorm.py) — RMSNorm micro-benchmark: JIT CUDA vs PyTorch, correctness check, bandwidth efficiency analysis\n- [scripts/bench_diffusion_denoise.py](scripts/bench_diffusion_denoise.py) — end-to-end denoise benchmark via `sglang generate`, baseline vs custom kernels 
comparison table\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/add-cuda-kernel.md",
    "content": "---\nname: add-cuda-kernel\ndescription: Step-by-step guide for adding a new JIT CUDA kernel to SGLang Diffusion. CUDA source files go in jit_kernel/csrc/diffusion/<op>.cuh; Python wrapper at jit_kernel/diffusion/<op>.py. Use when implementing optimized CUDA kernels for diffusion model operators (RMSNorm, RoPE, AdaLN, GEGLU, etc.) on NVIDIA GPUs (H100, A100). Covers kernel authoring with sglang abstractions, JIT compilation, Python wrapper, integration into the denoise stage, and benchmarking. Adapted from https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels.\n---\n\n# Adding a CUDA Kernel to SGLang Diffusion (JIT Style)\n\n> **Origin**: This skill is adapted from the [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels), rewritten to follow SGLang's JIT compilation system and internal abstractions.\n>\n> **Run environment first**: before compiling, benchmarking, or profiling any kernel from this guide, use `scripts/diffusion_skill_env.py` (or the setup block in `diffusion-benchmark-and-profile.md`) to `cd` to the repo root resolved from `sglang.__file__`, verify write access, export `FLASHINFER_DISABLE_VERSION_CHECK=1`, and pick an idle GPU.\n>\n> **Extended references** (in this directory's `references/` and `scripts/`):\n> - [references/kernel-templates.md](references/kernel-templates.md) — copy-paste ready templates for element-wise, row-reduction (RMSNorm), fused AdaLN\n> - [references/troubleshooting.md](references/troubleshooting.md) — build errors, perf issues, integration pitfalls\n> - [references/h100-optimization-guide.md](references/h100-optimization-guide.md) — H100 (sm_90) deep dive\n> - [references/a100-optimization-guide.md](references/a100-optimization-guide.md) — A100 (sm_80) deep dive\n> - [references/t4-optimization-guide.md](references/t4-optimization-guide.md) — T4 (sm_75, FP16 only) deep dive\n> - 
[scripts/bench_diffusion_rmsnorm.py](scripts/bench_diffusion_rmsnorm.py) — RMSNorm micro-benchmark vs PyTorch\n> - [scripts/bench_diffusion_denoise.py](scripts/bench_diffusion_denoise.py) — end-to-end denoise benchmark with/without kernels\n\n## When to Use CUDA vs Triton\n\n| Scenario | Use |\n|----------|-----|\n| Fused elementwise / norm variants / RoPE | **Triton** (`add-triton-kernel.md`) — faster iteration |\n| Bandwidth-bound reduction (RMSNorm, LayerNorm) requiring max vectorization | **CUDA** — full control over `__nv_bfloat162` / `float4` vectorization |\n| Attention pattern or tile-based ops needing shared memory tuning | **CUDA** — warp-level primitives, shared memory layout |\n| Prototype or NPU/CPU fallback needed | **Triton** — portable across backends |\n\nFor most diffusion-model elementwise ops, **start with Triton**. Switch to CUDA when profiling shows Triton can't reach hardware bandwidth limits.\n\n## Directory Layout\n\n```\npython/sglang/jit_kernel/\n├── csrc/\n│   ├── diffusion/               # JIT CUDA source files for diffusion kernels (this skill)\n│   │   ├── timestep_embedding.cuh   # existing example\n│   │   ├── rmsnorm.cuh              # NEW: add here\n│   │   └── adaln.cuh                # NEW: add here\n│   └── elementwise/             # shared JIT CUDA csrc (non-diffusion)\n├── diffusion/\n│   ├── triton/                  # Triton kernels (scale_shift, norm, rope, ...)\n│   ├── cutedsl/                 # CuTe DSL kernels\n│   └── rmsnorm.py               # NEW: CUDA JIT Python wrapper (add here)\n├── timestep_embedding.py        # existing CUDA diffusion kernel Python wrapper (legacy)\n```\n\nNew diffusion CUDA kernel source files go into `python/sglang/jit_kernel/csrc/diffusion/<op_name>.cuh`.\nThe Python wrapper goes at `python/sglang/jit_kernel/diffusion/<op_name>.py`\n(inside `diffusion/`, alongside the `triton/` and `cutedsl/` subdirectories).\n\n---\n\n## SGLang Kernel Abstractions (Required)\n\nAlways use these — do **not** 
use raw CUDA primitives directly.\n\n```cpp\n#include <sgl_kernel/tensor.h>    // TensorMatcher, SymbolicSize, SymbolicDevice\n#include <sgl_kernel/type.cuh>    // fp16_t, bf16_t, fp32_t, dtype_trait, packed_t\n#include <sgl_kernel/utils.h>     // RuntimeCheck, div_ceil\n#include <sgl_kernel/utils.cuh>   // LaunchKernel, SGL_DEVICE, type aliases\n#include <sgl_kernel/vec.cuh>     // AlignedVector<T, N> — 128-bit vector loads\n#include <sgl_kernel/warp.cuh>    // warp::reduce_sum, warp::reduce_max\n#include <sgl_kernel/math.cuh>    // device::math::rsqrt, sqrt, ...\n#include <sgl_kernel/tile.cuh>    // tile::Memory (strided access pattern)\n```\n\nKey types: `fp16_t` = `__half`, `bf16_t` = `__nv_bfloat16`, `fp32_t` = `float`.\nPacked variants: `fp16x2_t`, `bf16x2_t`. Use `packed_t<T>` for the 2-element alias.\n\n---\n\n## Step 1: Write the CUDA Kernel\n\nCreate `python/sglang/jit_kernel/csrc/diffusion/rmsnorm.cuh` (using RMSNorm as the example).\n\n### 1a. Vectorized RMSNorm Kernel\n\n```cpp\n#include <sgl_kernel/tensor.h>   // For TensorMatcher, SymbolicSize, SymbolicDevice\n#include <sgl_kernel/type.cuh>   // For fp16_t, bf16_t, fp32_t\n#include <sgl_kernel/utils.h>    // For RuntimeCheck\n#include <sgl_kernel/utils.cuh>  // For LaunchKernel\n#include <sgl_kernel/vec.cuh>    // For device::AlignedVector\n#include <sgl_kernel/warp.cuh>   // For device::warp::reduce_sum\n#include <sgl_kernel/math.cuh>   // For device::math::rsqrt\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/container/tensor.h>\n\n#include <algorithm>  // For std::min\n#include <cstdint>    // For uint32_t\n\nnamespace {\n\n// ---------------------------------------------------------------\n// RMSNorm kernel: y = x / rms(x) * weight\n// T      = fp16_t | bf16_t | fp32_t\n// kVecN  = vectorized elements per load (8 for fp16/bf16, 4 for fp32)\n// ---------------------------------------------------------------\ntemplate <typename T, int kVecN>\n__global__ void rmsnorm_kernel(\n    T* __restrict__ dst,\n    const T* __restrict__ src,\n    const T* __restrict__ weight,        // may be nullptr if no affine weight\n    uint32_t hidden_size,\n    uint32_t n_vecs,                     // hidden_size / kVecN\n    float eps)\n{\n    using vec_t = device::AlignedVector<T, kVecN>;\n\n    const uint32_t row = blockIdx.x;\n    const T* row_src = src 
+ row * hidden_size;\n    T*       row_dst = dst + row * hidden_size;\n\n    // --- Pass 1: accumulate sum of squares (vectorized) ---\n    float sum_sq = 0.f;\n    for (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) {\n        vec_t v;\n        v.load(row_src, vi);\n        #pragma unroll\n        for (int i = 0; i < kVecN; ++i) {\n            float val = static_cast<float>(v[i]);\n            sum_sq += val * val;\n        }\n    }\n\n    // --- Warp reduction ---\n    sum_sq = device::warp::reduce_sum<float>(sum_sq);\n\n    // --- Block reduction via shared memory ---\n    __shared__ float smem[32];\n    if (threadIdx.x % 32 == 0) {\n        smem[threadIdx.x / 32] = sum_sq;\n    }\n    __syncthreads();\n    if (threadIdx.x < 32) {\n        sum_sq = (threadIdx.x < blockDim.x / 32) ? smem[threadIdx.x] : 0.f;\n        sum_sq = device::warp::reduce_sum<float>(sum_sq);\n        if (threadIdx.x == 0) {\n            smem[0] = sum_sq;   // publish the block-wide total\n        }\n    }\n    __syncthreads();\n    // Broadcast: threads outside warp 0 still hold only their own warp's partial sum\n    sum_sq = smem[0];\n\n    const float rms_inv = device::math::rsqrt<float>(sum_sq / static_cast<float>(hidden_size) + eps);\n\n    // --- Pass 2: normalize + apply weight (vectorized) ---\n    for (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) {\n        vec_t v_in, v_w, v_out;\n        v_in.load(row_src, vi);\n        if (weight != nullptr) {\n            v_w.load(weight, vi);\n        }\n        #pragma unroll\n        for (int i = 0; i < kVecN; ++i) {\n            float val = static_cast<float>(v_in[i]) * rms_inv;\n            if (weight != nullptr) {\n                val *= static_cast<float>(v_w[i]);\n            }\n            v_out[i] = static_cast<T>(val);\n        }\n        v_out.store(row_dst, vi);\n    }\n}\n\n// ---------------------------------------------------------------\n// Launcher\n// ---------------------------------------------------------------\ntemplate <typename T>\nvoid rmsnorm(\n    tvm::ffi::TensorView dst,\n    tvm::ffi::TensorView src,\n    tvm::ffi::TensorView weight,          // pass empty / nullptr for no-weight case\n    float eps)\n{\n  
  using namespace host;\n\n    // Validate\n    SymbolicSize B{\"batch_tokens\"}, H{\"hidden_size\"};\n    SymbolicDevice device;\n    device.set_options<kDLCUDA>();\n\n    TensorMatcher({B, H})\n        .with_dtype<T>()\n        .with_device(device)\n        .verify(dst)\n        .verify(src);\n\n    const uint32_t num_rows   = static_cast<uint32_t>(B.unwrap());\n    const uint32_t hidden     = static_cast<uint32_t>(H.unwrap());\n    const DLDevice dev        = device.unwrap();\n\n    constexpr int kVecN    = 16 / sizeof(T);   // 128-bit vector: 8×fp16/bf16, 4×fp32\n    RuntimeCheck(hidden % kVecN == 0,\n        \"rmsnorm: hidden_size must be divisible by vector width \", kVecN, \", got \", hidden);\n    const uint32_t n_vecs  = hidden / kVecN;\n\n    // Thread count: enough warps to cover n_vecs, max 512 threads\n    uint32_t threads = std::min(n_vecs, 512u);\n    threads = (threads + 31) / 32 * 32;   // round up to warp boundary\n\n    // Optional affine weight: when present it must be [hidden_size] with the same dtype/device\n    if (weight.data_ptr() != nullptr) {\n        TensorMatcher({H}).with_dtype<T>().with_device(device).verify(weight);\n    }\n\n    const T* w_ptr = (weight.data_ptr() != nullptr)\n        ? 
static_cast<const T*>(weight.data_ptr()) : nullptr;\n\n    LaunchKernel(num_rows, threads, dev)(\n        rmsnorm_kernel<T, kVecN>,\n        static_cast<T*>(dst.data_ptr()),\n        static_cast<const T*>(src.data_ptr()),\n        w_ptr,\n        hidden,\n        n_vecs,\n        eps);\n}\n\n}  // namespace\n```\n\n---\n\n## Step 2: Python Wrapper\n\nCreate `python/sglang/jit_kernel/diffusion/rmsnorm.py`:\n\n```python\nfrom __future__ import annotations\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_rmsnorm_module(dtype: torch.dtype) -> Module:\n    args = make_cpp_args(dtype)\n    return load_jit(\n        \"diffusion_rmsnorm\",\n        *args,\n        cuda_files=[\"diffusion/rmsnorm.cuh\"],    # relative to csrc/\n        cuda_wrappers=[(\"rmsnorm\", f\"rmsnorm<{args}>\")],\n        extra_cuda_cflags=[\"-O3\", \"--use_fast_math\"],\n    )\n\n\ndef diffusion_rmsnorm(\n    src: torch.Tensor,\n    weight: torch.Tensor | None = None,\n    eps: float = 1e-6,\n    out: torch.Tensor | None = None,\n) -> torch.Tensor:\n    \"\"\"\n    RMSNorm for diffusion DiT layers.\n\n    y = x / rms(x) * weight   (weight=None → no affine scaling)\n\n    Accepts any shape [..., hidden_size]; leading dims are flattened into rows.\n    Supported dtypes: float16, bfloat16, float32.\n    hidden_size must be divisible by 8 (fp16/bf16) or 4 (fp32).\n    \"\"\"\n    assert src.is_cuda, \"src must be a CUDA tensor\"\n    assert src.dtype in (torch.float16, torch.bfloat16, torch.float32)\n\n    # The launcher matches a 2D [num_rows, hidden_size] tensor and indexes rows as\n    # row * hidden_size, so flatten leading dims (e.g. DiT [B, L, C]) into a contiguous view.\n    hidden = src.shape[-1]\n    src_2d = src.reshape(-1, hidden).contiguous()\n    out_2d = torch.empty_like(src_2d) if out is None else out.view(-1, hidden)\n\n    # Pass a zero-sized tensor when weight is absent (launcher checks data_ptr == nullptr)\n    w = weight if weight is not None else torch.empty(0, dtype=src.dtype, device=src.device)\n\n    module = _jit_rmsnorm_module(src.dtype)\n    module.rmsnorm(out_2d, src_2d, w, eps)\n    return out_2d.view(src.shape)\n```\n\n**Key rules for the wrapper:**\n- Use `cache_once` — never 
`functools.lru_cache` (breaks `torch.compile`)\n- First arg(s) to `load_jit` form the unique build cache key\n- `cuda_files` are relative to `python/sglang/jit_kernel/csrc/`\n- `cuda_wrappers`: `(python_name, cpp_template_instantiation)`\n\n---\n\n## Step 3: Integrate into Denoising Stage\n\nThe kernel replaces a slow operator inside the DiT forward pass. Find the module to patch in:\n\n```\npython/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py\npython/sglang/multimodal_gen/runtime/models/dits/<model>.py\n```\n\n**Pattern — monkey-patch the DiT block's RMSNorm:**\n\n```python\nimport torch\n\nfrom sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm\n\n\ndef _patch_rmsnorm(model: torch.nn.Module) -> None:\n    for module in model.modules():\n        cls_name = type(module).__name__\n        # Substring match covers RMSNorm, LlamaRMSNorm, and diffusers variants\n        if \"RMSNorm\" in cls_name:\n            eps = getattr(module, \"eps\", getattr(module, \"variance_epsilon\", 1e-6))\n            has_weight = getattr(module, \"weight\", None) is not None\n\n            if has_weight:\n                def _make_fwd(mod, epsilon):\n                    # Read mod.weight at call time so reloaded weights are picked up\n                    def forward(x):\n                        return diffusion_rmsnorm(x, weight=mod.weight, eps=epsilon)\n                    return forward\n                module.forward = _make_fwd(module, eps)\n            else:\n                def _make_fwd_noweight(epsilon):\n                    def forward(x):\n                        return diffusion_rmsnorm(x, weight=None, eps=epsilon)\n                    return forward\n                module.forward = _make_fwd_noweight(eps)\n```\n\n**Critical:** inject kernels **before** `torch.compile` and before any CPU offload is enabled.\n\n---\n\n## Step 4: Key Kernel Patterns Reference\n\n### Diffusion-Specific Operators\n\n| Operator | Kernel Pattern | Notes |\n|----------|---------------|-------|\n| **RMSNorm** | 2-pass row reduction + vectorized normalize | Weight may be `None` 
(`elementwise_affine=False`) |\n| **AdaLN modulation** | `y = norm(x) * (1 + scale) + shift` | Fuse norm + scale + shift in one pass |\n| **RoPE 3D** | Read `(t, h, w)` cos/sin tables, apply to `(q, k)` | Layout: `[batch, t*h*w, heads, head_dim]` |\n| **GEGLU** | Split last dim into `(x, gate)` → `x * gelu(gate)` | Input `[B, L, 2*H]` → output `[B, L, H]` |\n| **SiLU gate** | `out = a * sigmoid(a)` (i.e. `silu(a)`), computed in one fused pass | Avoid separate elementwise ops |\n\n### Vectorized Memory Access\n\n```cpp\n// BF16: 8 elements × 2 bytes = 16 bytes per vector load (AlignedVector<bf16_t, 8>)\n// FP16: 8 elements × 2 bytes = 16 bytes (AlignedVector<fp16_t, 8>)\n// FP32: 4 elements × 4 bytes = 16 bytes (AlignedVector<fp32_t, 4>)\nconstexpr int kVecN = 16 / sizeof(T);\nusing vec_t = device::AlignedVector<T, kVecN>;\n```\n\n### Warp / Block Reductions\n\n```cpp\n// Warp reduction (within 32 threads)\nfloat result = device::warp::reduce_sum<float>(partial);\n\n// Block reduction via shared memory (see rmsnorm example above)\n__shared__ float smem[32];\n// ... 
write warp-leaders into smem, sync, reduce in warp 0, then broadcast the total via smem\n```\n\n### Thread Configuration\n\n```cpp\n// Element-wise (RoPE, GEGLU, SiLU): simple 1D grid\nconstexpr uint32_t kBlock = 256;\nuint32_t grid = host::div_ceil(total_elements, kBlock);\nLaunchKernel(grid, kBlock, dev)(kernel, ...);\n\n// Row reduction (RMSNorm, LayerNorm): one block per row\nuint32_t threads = std::min(hidden_size / kVecN, 512u);\nthreads = (threads + 31) / 32 * 32;\nLaunchKernel(num_rows, threads, dev)(kernel, ...);\n```\n\n---\n\n## Step 5: GPU Architecture Targets\n\n| GPU | Compute Cap | Memory BW | BF16 | Key Note |\n|-----|------------|-----------|------|----------|\n| H100 (SXM) | sm_90 | 3.35 TB/s | Yes | Primary target; 132 SMs, up to 228 KB shared mem/SM |\n| A100 | sm_80 | 2.0 TB/s  | Yes | 108 SMs, 164 KB shared mem/SM |\n| T4   | sm_75 | 320 GB/s  | **No** | FP16 only; no native BF16 arithmetic |\n\nIf a kernel requires SM90+ features (e.g., TMA, wgmma), raise a clear error:\n\n```python\nimport torch\n\nif torch.cuda.get_device_capability()[0] < 9:\n    raise RuntimeError(\"This kernel requires SM90 (H100/Hopper) or later\")\n```\n\n**Grid sizing for H100** (132 SMs): for grids of many short-lived blocks, prefer a block count that is a multiple of 132 × (resident blocks per SM) so the last wave does not leave SMs idle.\n\n---\n\n## Step 6: Tests\n\nCreate `python/sglang/jit_kernel/tests/test_diffusion_rmsnorm.py`:\n\n```python\nimport pytest\nimport torch\nfrom sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float32])\n@pytest.mark.parametrize(\"shape\", [(1, 2048), (4, 3072), (16, 4096)])\n@pytest.mark.parametrize(\"has_weight\", [True, False])\ndef test_rmsnorm_correctness(dtype, shape, has_weight):\n    batch, hidden = shape\n    src = torch.randn(batch, hidden, dtype=dtype, device=\"cuda\")\n    weight = torch.randn(hidden, dtype=dtype, device=\"cuda\") if has_weight else None\n\n    out_jit = diffusion_rmsnorm(src, weight=weight, eps=1e-6)\n\n    # Reference: torch.nn.functional\n    ref = torch.nn.functional.rms_norm(\n        
src.float(), (hidden,), weight.float() if weight is not None else None, eps=1e-6\n    ).to(dtype)\n\n    tol = {\"rtol\": 1e-2, \"atol\": 1e-2} if dtype != torch.float32 else {\"rtol\": 1e-5, \"atol\": 1e-6}\n    torch.testing.assert_close(out_jit, ref, **tol)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n```\n\n---\n\n## Step 7: Benchmark\n\nCreate `python/sglang/jit_kernel/benchmark/bench_diffusion_rmsnorm.py`:\n\n```python\nimport torch\nimport triton.testing\n\nfrom sglang.jit_kernel.benchmark.utils import DEFAULT_DEVICE, DEFAULT_DTYPE, run_benchmark\nfrom sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm\n\nSHAPES = [(4096, 2048), (4096, 3072), (4096, 4096)]\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"hidden\"],\n        x_vals=[s[1] for s in SHAPES],\n        line_arg=\"provider\",\n        line_vals=[\"jit_cuda\", \"torch\"],\n        line_names=[\"SGLang JIT CUDA\", \"PyTorch rms_norm\"],\n        styles=[(\"blue\", \"-\"), (\"red\", \"--\")],\n        ylabel=\"us\",\n        plot_name=\"diffusion-rmsnorm\",\n        args={},\n    )\n)\ndef benchmark(hidden: int, provider: str):\n    src = torch.randn(4096, hidden, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE)\n    w   = torch.ones(hidden, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE)\n\n    if provider == \"jit_cuda\":\n        fn = lambda: diffusion_rmsnorm(src, weight=w, eps=1e-6)\n    else:\n        fn = lambda: torch.nn.functional.rms_norm(src, (hidden,), w, eps=1e-6)\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    benchmark.run(print_data=True)\n```\n\n---\n\n## Step 8: Profile with Nsight Compute (required)\n\nAfter correctness + benchmarking, you must collect **Nsight Compute (ncu)** data to validate:\n\n- Whether the kernel reaches reasonable bandwidth/throughput (avoid false positives where it is “faster” but under-utilizes hardware)\n- Whether there are clear occupancy / register / shared 
memory limiters\n\nUse the canonical docs in this directory (do not duplicate CLI details across multiple skills):\n\n- `diffusion-benchmark-and-profile.md` → Step 3.5 (ncu workflow, including CUDA graph profiling)\n- `nsight-profiler.md` (metrics interpretation: bandwidth / occupancy / roofline / stall reasons)\n\n---\n\n## Common Pitfalls\n\n| Issue | Fix |\n|-------|-----|\n| `RMSNorm weight is None` | Use `type(module).__name__` check; pass `None` weight explicitly |\n| `isinstance(m, torch.nn.RMSNorm)` misses diffusers variants | Use `\"RMSNorm\" in type(m).__name__` |\n| Kernel patched after `torch.compile` | Inject **before** any compile call |\n| Kernel patched after `enable_model_cpu_offload()` | Inject **before** CPU offload |\n| `hidden_size` not divisible by `kVecN` | Add `RuntimeCheck(hidden % kVecN == 0, ...)` in launcher |\n| `torch.compile` fails with custom CUDA kernel | Register as `@torch.library.custom_op` or use Triton instead |\n| T4 GPU with BF16 kernel | Gate on compute capability; T4 is `sm_75`, no native BF16 |\n\n---\n\n## Summary of Files\n\n```\npython/sglang/jit_kernel/csrc/diffusion/\n└── rmsnorm.cuh                                  # NEW: JIT CUDA kernel source\n\npython/sglang/jit_kernel/diffusion/\n└── rmsnorm.py                                   # NEW: Python wrapper + load_jit\n\npython/sglang/jit_kernel/tests/\n└── test_diffusion_rmsnorm.py                    # NEW: correctness tests\n\npython/sglang/jit_kernel/benchmark/\n└── bench_diffusion_rmsnorm.py                   # NEW: benchmark\n```\n\n---\n\n## References\n\n### This Skill's Extended Docs (references/ and scripts/)\n\n| File | Contents |\n|------|----------|\n| [references/kernel-templates.md](references/kernel-templates.md) | Copy-paste templates: element-wise, RMSNorm, AdaLN, Python wrapper, test, benchmark |\n| [references/troubleshooting.md](references/troubleshooting.md) | Build errors, perf issues, torch.compile compatibility, debugging checklist |\n| 
[references/h100-optimization-guide.md](references/h100-optimization-guide.md) | H100 (sm_90): memory hierarchy, warp reductions, occupancy, vectorization benchmarks |\n| [references/a100-optimization-guide.md](references/a100-optimization-guide.md) | A100 (sm_80): cp.async, TF32, 2:4 sparsity, H100→A100 migration checklist |\n| [references/t4-optimization-guide.md](references/t4-optimization-guide.md) | T4 (sm_75): FP16 only, low bandwidth, tile size limits, memory constraints |\n| [scripts/bench_diffusion_rmsnorm.py](scripts/bench_diffusion_rmsnorm.py) | Micro-benchmark: JIT CUDA RMSNorm vs PyTorch, correctness check, bandwidth analysis |\n| [scripts/bench_diffusion_denoise.py](scripts/bench_diffusion_denoise.py) | End-to-end: `sglang generate` baseline vs custom kernels, comparison table |\n\n### SGLang Internals\n\n- **JIT system**: `add-jit-kernel` skill (`sglang/.claude/skills/add-jit-kernel/SKILL.md`)\n- **JIT utils**: `python/sglang/jit_kernel/utils.py` — `cache_once`, `load_jit`, `make_cpp_args`\n- **Abstractions**: `python/sglang/jit_kernel/include/sgl_kernel/` — `tensor.h`, `utils.cuh`, `vec.cuh`, `warp.cuh`, `math.cuh`, `tile.cuh`\n- **Real csrc examples**: `python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh`, `python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh`\n\n### Other Diffusion Kernel Skills (this directory)\n\n- **Triton alternative**: `add-triton-kernel.md` — prefer Triton unless bandwidth analysis shows CUDA needed\n- **Existing fused kernels**: `use-efficient-diffusion-kernels.md` — check here first before writing new kernels\n- **Profiling**: `diffusion-benchmark-and-profile.md` — workflow to identify bottleneck before implementing\n- **Nsight Compute deep dive**: `nsight-profiler.md` — full guide: occupancy analysis, roofline model, warp efficiency, kernel comparison\n\n### External\n\n- [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels) — original source adapted for this 
skill\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/add-triton-kernel.md",
    "content": "---\nname: add-triton-kernel\ndescription: Step-by-step guide for adding a new Triton kernel to SGLang Diffusion's jit_kernel module. Use when implementing fused elementwise ops, norm variants, RoPE variants, or any other lightweight GPU kernel for diffusion models using Triton JIT. Covers kernel authoring, autotune, torch.compile compatibility, layer integration, and tests.\n---\n\n# Adding a Triton Kernel to SGLang Diffusion\n\nThis guide walks through adding a Triton kernel to `python/sglang/jit_kernel/diffusion/triton/`.\nWe use a fused elementwise operation as the running example: `y = x * (1 + scale) + shift` (AdaLN modulation).\n\nBefore compiling, benchmarking, or profiling any Triton kernel from this guide, use `scripts/diffusion_skill_env.py` (or the setup block in `diffusion-benchmark-and-profile.md`) to `cd` to the repo root resolved from `sglang.__file__`, verify write access, export `FLASHINFER_DISABLE_VERSION_CHECK=1`, and choose an idle GPU.\n\n---\n\n## Directory Layout\n\n```\npython/sglang/jit_kernel/diffusion/\n├── triton/\n│   ├── scale_shift.py          # AdaLN scale/shift fused kernels\n│   ├── norm.py                 # LayerNorm / RMSNorm fused kernels\n│   ├── rmsnorm_onepass.py      # One-pass RMSNorm for small hidden size\n│   └── rotary.py               # RoPE kernel\n└── cutedsl/\n    └── ...                     # CuTe DSL kernels (see use-efficient-diffusion-kernels.md)\n```\n\nNew Triton kernels go into `triton/<op_name>.py`.\n\n---\n\n## Step 1: Write the Triton Kernel\n\nCreate `python/sglang/jit_kernel/diffusion/triton/<op_name>.py`.\n\n### 1a. Imports\n\n```python\nimport torch\nimport triton          # type: ignore\nimport triton.language as tl  # type: ignore\n```\n\nAlways use `# type: ignore` on triton imports — the stubs are incomplete.\n\n### 1b. 
The `@triton.jit` Kernel Function\n\nFollow the naming convention `_<op_name>_kernel` (private, underscore prefix).\n\n```python\n@triton.autotune(\n    configs=[\n        triton.Config({\"BLOCK_C\": 64},  num_warps=2),\n        triton.Config({\"BLOCK_C\": 128}, num_warps=4),\n        triton.Config({\"BLOCK_C\": 256}, num_warps=4),\n        triton.Config({\"BLOCK_C\": 512}, num_warps=8),\n    ],\n    key=[\"C\"],   # re-tune when hidden dim changes\n)\n@triton.jit\ndef _fused_scale_shift_kernel(\n    # Pointers — always pass raw tensors; Triton takes .data_ptr() internally\n    x_ptr,\n    scale_ptr,\n    shift_ptr,\n    y_ptr,\n    # Dimensions\n    B,        # batch size\n    L,        # sequence length\n    C,        # hidden / channel dim\n    # Strides — pass every stride separately; do NOT assume contiguous\n    stride_xb, stride_xl, stride_xc,\n    stride_sb, stride_sc,\n    stride_yb, stride_yl, stride_yc,\n    # Compile-time constants (tl.constexpr)\n    BLOCK_C: tl.constexpr,\n):\n    # Grid: (L, B), one program per (token, batch) row\n    pid_l = tl.program_id(0)\n    pid_b = tl.program_id(1)\n\n    x_row = pid_b * stride_xb + pid_l * stride_xl\n    y_row = pid_b * stride_yb + pid_l * stride_yl\n    s_row = pid_b * stride_sb\n\n    # Walk the channel dim in BLOCK_C chunks; without this loop,\n    # channels at index >= BLOCK_C would never be written when C > BLOCK_C\n    for c_start in range(0, C, BLOCK_C):\n        c_offs = c_start + tl.arange(0, BLOCK_C)\n        mask   = c_offs < C\n\n        x     = tl.load(x_ptr     + x_row + c_offs * stride_xc, mask=mask, other=0.0)\n        scale = tl.load(scale_ptr + s_row + c_offs * stride_sc, mask=mask, other=0.0)\n        shift = tl.load(shift_ptr + s_row + c_offs * stride_sc, mask=mask, other=0.0)\n\n        y = x * (1.0 + scale) + shift\n        tl.store(y_ptr + y_row + c_offs * stride_yc, y, mask=mask)\n```\n\n**Rules:**\n- Pass tensors directly for pointer arguments. Triton converts each tensor to its device pointer when the kernel is launched via `kernel[grid](...)`.\n- Pass every stride as a separate scalar — never assume a tensor is contiguous inside the kernel.\n- Use `tl.constexpr` for block sizes and boolean flags (`HAS_RESIDUAL`, `IS_RMS_NORM`, 
etc.).\n- Use `mask=mask, other=0.0` on every `tl.load` to avoid out-of-bounds reads.\n- Compute in `tl.float32` when precision matters (`x.to(tl.float32)`), then cast back to the output dtype before `tl.store`.\n- Use `tl.fma(a, b, c)` (`a*b + c`) for fused multiply-add: it rounds once instead of twice and maps to a single FMA instruction.\n\n### 1c. `@triton.autotune` Guidelines\n\n| `key` entry | When to include |\n|-------------|-----------------|\n| `\"C\"` / `\"hidden_dim\"` | Always — block tile size depends on C |\n| `\"IS_RMS_NORM\"` | When the kernel has a `constexpr` boolean flag that changes code paths |\n| `\"HAS_RESIDUAL\"` | Same — constexpr path branching |\n| Shape / batch / seq | Usually NOT — autotune cost outweighs benefit |\n\nKeep configs in ascending `BLOCK_C` order with matching `num_warps` (`num_warps` × 32 threads per block must not exceed 1024).\n\n### 1d. `torch.compile` Compatibility\n\nWhen the kernel is called inside a `torch.compile`-d region, wrap the launch with `torch.library.wrap_triton`:\n\n```python\nwith torch.get_device_module().device(x.device):\n    torch.library.wrap_triton(_fused_scale_shift_kernel)[grid](\n        x, scale, shift, y,\n        B, L, C,\n        x.stride(0), x.stride(1), x.stride(2),\n        scale.stride(0), scale.stride(1),\n        y.stride(0), y.stride(1), y.stride(2),\n    )\n```\n\nUse `wrap_triton` when the kernel is called from a layer that runs under `torch.compile`.\nSkip it for utility kernels called only at Python graph boundaries.\n\n---\n\n## Step 2: Write the Python Launcher\n\nThe launcher is a regular Python function (public, no underscore) in the same file.\n\n```python\ndef fused_scale_shift(\n    x: torch.Tensor,\n    scale: torch.Tensor,\n    shift: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"\n    Fused AdaLN modulation: y = x * (1 + scale) + shift.\n\n    Args:\n        x:     [B, L, C], CUDA, contiguous\n        scale: [B, C],    CUDA\n        shift: [B, C],    CUDA (same shape as scale)\n\n    Returns:\n        y: same shape 
and dtype as x\n    \"\"\"\n    # --- Precondition checks ---\n    assert x.is_cuda,           \"x must be on CUDA\"\n    assert x.is_contiguous(),   \"x must be contiguous\"\n    assert scale.is_cuda and shift.is_cuda\n    assert x.ndim == 3,         f\"x must be 3D [B, L, C], got {x.shape}\"\n    assert scale.shape == shift.shape\n    # The kernel reuses scale's strides for shift, so their layouts must match\n    assert scale.stride() == shift.stride(), \"scale and shift must share strides\"\n    B, L, C = x.shape\n\n    # Allocate output\n    y = torch.empty_like(x)\n\n    # Grid: one program per (token, batch) row\n    grid = (L, B)\n\n    _fused_scale_shift_kernel[grid](\n        x, scale, shift, y,\n        B, L, C,\n        x.stride(0),     x.stride(1),     x.stride(2),\n        scale.stride(0), scale.stride(1),\n        y.stride(0),     y.stride(1),     y.stride(2),\n    )\n    return y\n```\n\n**Rules:**\n- Validate CUDA placement and shape/dtype **before** launching — use `assert` with a helpful message.\n- Call `.contiguous()` on inputs that the kernel requires contiguous **before** the launch, not inside it.\n- Allocate the output with `torch.empty_like(x)` — never reuse input buffers unless the op is explicitly in-place.\n- The `grid` is either a fixed tuple or, when block sizes are autotuned, a callable that receives the selected config as `META`:\n\n```python\n# Static grid (block size fixed)\ngrid = (triton.cdiv(L, BLOCK_L), triton.cdiv(C, BLOCK_C), B)\n\n# Dynamic grid (block sizes come from autotune)\ngrid = lambda META: (triton.cdiv(L, META[\"BLOCK_L\"]), triton.cdiv(C, META[\"BLOCK_C\"]), B)\n```\n\n### Handling Non-Contiguous Inputs\n\nNever call `.contiguous()` silently — it copies data. Instead, pass strides to the kernel and let it handle arbitrary layouts. 
Only call `.contiguous()` when the kernel genuinely requires it (e.g., before reshaping with `.view()`):\n\n```python\n# OK: the 2D view trick needs a contiguous buffer, so copy only when required\nif not x.is_contiguous():\n    x = x.contiguous()              # explicit copy, made only when needed\nx_2d = x.view(B * L, C)             # .view() raises on incompatible strides\n```\n\n---\n\n## Step 3: Integrate into the Layer\n\nCall the new kernel from the appropriate layer file in\n`python/sglang/multimodal_gen/runtime/layers/` (typically `layernorm.py` or `elementwise.py`).\n\n```python\n# In layernorm.py or elementwise.py\nimport torch\n\ndef apply_scale_shift(x, scale, shift):\n    if x.is_cuda:\n        from sglang.jit_kernel.diffusion.triton.my_op import fused_scale_shift\n        return fused_scale_shift(x, scale, shift)\n    # Pure-PyTorch fallback for non-CUDA execution; unsqueeze(1) broadcasts\n    # scale/shift from [B, C] over the sequence dim of x [B, L, C]\n    return x * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)\n```\n\n**Rules:**\n- Gate on `x.is_cuda` — the Triton kernel only runs on CUDA; the fallback handles everything else.\n- The launcher raises `AssertionError` on invalid inputs (wrong shape, CPU tensor, etc.) — do **not** silently catch these. 
Let them propagate so bugs are visible during development.\n- Add `logger.warning_once(...)` only when falling back due to a **known hardware limitation** (e.g., unsupported SM compute capability), not for wrong-input errors.\n\n---\n\n## Step 4: Write Tests\n\nCreate `python/sglang/jit_kernel/tests/test_<op_name>.py`.\n\n```python\nimport pytest\nimport torch\n\nfrom sglang.jit_kernel.diffusion.triton.my_op import fused_scale_shift\n\n\ndef _ref_fused_scale_shift(x, scale, shift):\n    \"\"\"PyTorch reference implementation.\"\"\"\n    # Broadcast scale/shift from [B, C] to [B, L, C]\n    return x * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)\n\n\n@pytest.fixture(autouse=True)\ndef require_cuda():\n    if not torch.cuda.is_available():\n        pytest.skip(\"CUDA required\")\n\n\n@pytest.mark.parametrize(\"B,L,C\", [\n    (1, 6,    3072),   # Qwen (small batch)\n    (1, 1024, 1536),   # Wan\n    (2, 512,  3072),   # typical training shape\n    (1, 1,    256),    # edge: L=1\n])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float32])\ndef test_fused_scale_shift_correctness(B, L, C, dtype):\n    torch.manual_seed(0)\n    x     = torch.randn(B, L, C, dtype=dtype, device=\"cuda\")\n    scale = torch.randn(B, C,    dtype=dtype, device=\"cuda\") * 0.1\n    shift = torch.randn(B, C,    dtype=dtype, device=\"cuda\") * 0.1\n\n    out = fused_scale_shift(x, scale, shift)\n    ref = _ref_fused_scale_shift(x.float(), scale.float(), shift.float()).to(dtype)\n\n    atol = 1e-5 if dtype == torch.float32 else 1e-2\n    torch.testing.assert_close(out, ref, atol=atol, rtol=atol,\n                                msg=f\"Mismatch at B={B} L={L} C={C} dtype={dtype}\")\n\n\ndef test_fused_scale_shift_non_cuda_raises():\n    x     = torch.randn(1, 4, 64)\n    scale = torch.randn(1, 64)\n    shift = torch.randn(1, 64)\n    with pytest.raises(AssertionError, match=\"CUDA\"):\n        fused_scale_shift(x, scale, shift)\n\n\ndef 
test_fused_scale_shift_output_dtype_preserved():\n    x     = torch.randn(1, 8, 128, dtype=torch.bfloat16, device=\"cuda\")\n    scale = torch.randn(1, 128, dtype=torch.bfloat16, device=\"cuda\")\n    shift = torch.zeros(1, 128, dtype=torch.bfloat16, device=\"cuda\")\n    out   = fused_scale_shift(x, scale, shift)\n    assert out.dtype == torch.bfloat16\n    assert out.shape == x.shape\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\"])\n```\n\nRun:\n\n```bash\npytest python/sglang/jit_kernel/tests/test_<op_name>.py -v\n```\n\n**Test coverage requirements:**\n1. Reference comparison against pure-PyTorch for all supported dtypes (fp16, bf16, fp32).\n2. Edge shapes: `L=1`, `C` not a multiple of the largest BLOCK_C, large `B`.\n3. Error cases: CPU tensor, wrong shape.\n4. Output dtype and shape preservation.\n\n---\n\n## Step 5: Add a Benchmark (required)\n\nCreate `python/sglang/jit_kernel/benchmark/bench_<op_name>.py`.\n\n```python\nimport torch\nimport triton.testing\n\nfrom sglang.jit_kernel.diffusion.triton.my_op import fused_scale_shift\n\n\nSHAPES = [\n    # (B, L, C)  — representative diffusion shapes\n    (1, 6,    3072),   # Qwen image\n    (1, 1024, 1536),   # Wan video\n    (1, 4096, 3072),   # FLUX double-stream\n]\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"B\", \"L\", \"C\"],\n        x_vals=SHAPES,\n        line_arg=\"provider\",\n        line_vals=[\"triton\", \"torch\"],\n        line_names=[\"Triton Fused\", \"PyTorch\"],\n        styles=[(\"blue\", \"-\"), (\"red\", \"--\")],\n        ylabel=\"µs (median)\",\n        plot_name=\"fused-scale-shift\",\n        args={},\n    )\n)\ndef benchmark(B, L, C, provider):\n    dtype = torch.bfloat16\n    x     = torch.randn(B, L, C, dtype=dtype, device=\"cuda\")\n    scale = torch.randn(B, C,    dtype=dtype, device=\"cuda\")\n    shift = torch.randn(B, C,    dtype=dtype, device=\"cuda\")\n\n    if provider == \"triton\":\n        fn = lambda: 
fused_scale_shift(x, scale, shift)\n    else:\n        fn = lambda: x * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)\n\n    ms, *_ = triton.testing.do_bench_cudagraph(fn, quantiles=[0.5, 0.2, 0.8])\n    return ms * 1000  # µs\n\n\nif __name__ == \"__main__\":\n    benchmark.run(print_data=True)\n```\n\nRun:\n\n```bash\npython python/sglang/jit_kernel/benchmark/bench_<op_name>.py\n```\n\n---\n\n## Step 6: Profile with Nsight Compute (required for optimization work)\n\nAfter correctness tests, you must use **ncu (Nsight Compute)** to validate hardware efficiency (bandwidth/throughput/occupancy/bottleneck type).\n\nTo avoid duplicating ncu CLI details across multiple skills, this skill does not repeat command flags. Follow the canonical docs:\n\n- `diffusion-benchmark-and-profile.md` → Step 3.5 (ncu workflow, including CUDA graph profiling)\n- `nsight-profiler.md` (metrics interpretation: bandwidth / occupancy / roofline / warp stalls)\n\n---\n\n## Common Patterns Reference\n\n### Pattern 1: Autotune over a 2D tile (L × C)\n\nUsed in `scale_shift.py` (`fuse_scale_shift_kernel_blc_opt`):\n\n```python\n@triton.jit\ndef _kernel(..., BLOCK_L: tl.constexpr, BLOCK_C: tl.constexpr):\n    pid_l = tl.program_id(0)\n    pid_c = tl.program_id(1)\n    pid_b = tl.program_id(2)\n    l_offs = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)\n    c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)\n    mask = (l_offs[:, None] < L) & (c_offs[None, :] < C)\n    ...\n\n# Launch:\ngrid = (triton.cdiv(L, BLOCK_L), triton.cdiv(C, BLOCK_C), B)\n_kernel[grid](..., BLOCK_L=block_l, BLOCK_C=block_c, num_warps=4, num_stages=2)\n```\n\n### Pattern 2: One-pass RMSNorm for small hidden size\n\nUsed in `rmsnorm_onepass.py`:\n\n```python\n@triton.jit\ndef _rms_norm_tiled_onepass(y_ptr, x_ptr, w_ptr,\n                              SEQ: tl.constexpr, DIM: tl.constexpr, EPS: tl.constexpr,\n                              BLOCK_SIZE_SEQ: tl.constexpr, BLOCK_SIZE_DIM: tl.constexpr):\n    seq_blk_id = 
tl.program_id(0)\n    seq_id     = seq_blk_id * BLOCK_SIZE_SEQ\n    seq_offset = seq_id + tl.arange(0, BLOCK_SIZE_SEQ)[:, None]\n    d_offset   = tl.arange(0, BLOCK_SIZE_DIM)[None, :]\n    ...\n    x = tl.load(x_ptr + seq_offset * DIM + d_offset, mask=..., other=0.0).to(tl.float32)\n    mean_sq = tl.sum(x * x, axis=1, keep_dims=True) / DIM\n    rstd    = tl.math.rsqrt(mean_sq + EPS)\n    tl.store(y_ptr + ..., x * rstd * w, mask=...)\n\n# Launch with wrap_triton for torch.compile compat:\nwith torch.get_device_module().device(x.device):\n    torch.library.wrap_triton(_rms_norm_tiled_onepass)[grid](\n        y_view, x_view, w,\n        S, D, eps,\n        BLOCK_SIZE_DIM=triton.next_power_of_2(D),\n        BLOCK_SIZE_SEQ=BLOCK_SIZE_SEQ,\n    )\n```\n\n### Pattern 3: `tl.constexpr` boolean flags for conditional paths\n\nUsed in `norm.py` and `scale_shift.py`:\n\n```python\n@triton.jit\ndef _kernel(...,\n            IS_RMS_NORM:   tl.constexpr,\n            HAS_RESIDUAL:  tl.constexpr,\n            SCALE_IS_SCALAR: tl.constexpr):\n    ...\n    if IS_RMS_NORM:\n        var = tl.sum(x * x, axis=0) / N\n    else:\n        mean = tl.sum(x, axis=0) / N\n        # Zero the masked-off lanes so they do not contribute (0 - mean)^2\n        diff = tl.where(mask, x - mean, 0.0)\n        var  = tl.sum(diff * diff, axis=0) / N\n\n    if HAS_RESIDUAL:\n        x = x + tl.load(residual_ptr + ...)\n\n    if SCALE_IS_SCALAR:\n        scale_val = tl.load(scale_ptr)\n        scale = tl.full([BLOCK_N], scale_val, dtype=scale_val.dtype)\n    else:\n        scale = tl.load(scale_ptr + col_offsets, mask=mask, other=0.0)\n```\n\nInclude these booleans in the autotune `key` so each code path gets its own tuned config. Every distinct `constexpr` value is already compiled as a separate specialization.\n\n### Pattern 4: Computing in fp32, storing in original dtype\n\nAlways up-cast to `tl.float32` for reductions and math, then down-cast before storing:\n\n```python\nx_f32     = x.to(tl.float32)\nscale_f32 = scale.to(tl.float32)\nshift_f32 = shift.to(tl.float32)\ny_f32     = x_f32 * (1.0 + scale_f32) + shift_f32\n# Cast back to the input element type before storing\ntl.store(y_ptr + offsets, y_f32.to(x.dtype), mask=mask)\n```\n\n---\n\n## Checklist Before Submitting\n\n### 
Prerequisites\n- [ ] `ncu --version` prints a valid Nsight Compute version (required for Step 6 profiling)\n\n### Implementation\n- [ ] Kernel file at `python/sglang/jit_kernel/diffusion/triton/<op_name>.py`\n- [ ] All pointer arguments passed with separate stride scalars\n- [ ] Every `tl.load` uses `mask=` and `other=`\n- [ ] Autotune `key` includes all `constexpr` flags that change code paths\n- [ ] `torch.library.wrap_triton` used if kernel runs inside `torch.compile` region\n- [ ] PyTorch fallback path in the layer integration (see Step 3)\n\n### Validation\n- [ ] Tests pass: `pytest python/sglang/jit_kernel/tests/test_<op_name>.py -v`\n- [ ] Benchmark runs: `python python/sglang/jit_kernel/benchmark/bench_<op_name>.py`\n- [ ] **Correctness verified**: Triton output matches PyTorch reference within tolerance\n- [ ] Nsight Compute profile collected (`ncu --set full`); achieved occupancy ≥ 50% and memory throughput ≥ 70% of peak (or bottleneck documented)\n\n---\n\n## Summary of Files Created/Modified\n\n```\npython/sglang/jit_kernel/diffusion/triton/<op_name>.py      # NEW: Triton kernel + launcher\npython/sglang/jit_kernel/tests/test_<op_name>.py            # NEW: correctness tests\npython/sglang/jit_kernel/benchmark/bench_<op_name>.py       # NEW: performance benchmark\npython/sglang/multimodal_gen/runtime/layers/layernorm.py    # MODIFIED: integrate into layer\n  (or elementwise.py, depending on op type)\n```\n\n## References\n\n- `python/sglang/jit_kernel/diffusion/triton/scale_shift.py` — 2D tile pattern, scalar broadcast, 4D shape handling\n- `python/sglang/jit_kernel/diffusion/triton/rmsnorm_onepass.py` — `wrap_triton`, tiled one-pass reduction\n- `python/sglang/jit_kernel/diffusion/triton/norm.py` — complex autotune with many `constexpr` flags\n- `python/sglang/jit_kernel/diffusion/triton/rotary.py` — per-head grid, interleaved RoPE\n- `nsight-profiler.md` — full Nsight Compute guide: occupancy analysis, roofline model, warp efficiency, kernel 
comparison\n- `diffusion-benchmark-and-profile.md` — how to verify the kernel's impact on denoise latency\n- `use-efficient-diffusion-kernels.md` — overview of existing fused kernel entry points\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/diffusion-benchmark-and-profile.md",
    "content": "---\nname: diffusion-benchmark-and-profile\ndescription: Denoise-stage benchmark and per-layer kernel profiling guide for SGLang Diffusion models. Use when measuring denoising latency, profiling DiT kernel breakdown with torch.profiler or nsys+gputrc2graph.py, investigating performance bottlenecks, or optimizing with custom Triton/CUDA kernels. Always verify output correctness before and after any optimization.\n---\n\n# SGLang Diffusion Benchmark and Profile Guide\n\n**Primary Metric: Denoise Latency**\nThe denoising loop latency — total DiT forward pass time across all inference steps — is the dominant cost (>80% of end-to-end) and the **sole optimization target** for kernel work. End-to-end latency is recorded as a secondary check only.\n\n> **Correctness First**: Faster but incorrect output is not an improvement. Always compare generated images/videos against a reference baseline before and after any change.\n\n---\n\n## Prerequisites\n\n```bash\nENV_PY=python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/diffusion_skill_env.py\nROOT=$(python3 \"$ENV_PY\" print-root)\ncd \"$ROOT\"\npython3 \"$ENV_PY\" check-write-access >/dev/null\n\nexport FLASHINFER_DISABLE_VERSION_CHECK=1\nexport CUDA_VISIBLE_DEVICES=$(python3 \"$ENV_PY\" print-idle-gpus --count 1)\n\nASSET_DIR=$(python3 \"$ENV_PY\" print-assets-dir --mkdir)\nBENCH_DIR=$(python3 \"$ENV_PY\" print-output-dir --kind benchmarks --mkdir)\nPROFILE_DIR=$(python3 \"$ENV_PY\" print-output-dir --kind profiles --mkdir)\nNCU_DIR=$(python3 \"$ENV_PY\" print-output-dir --kind ncu --mkdir)\nexport PROFILE_DIR\n\ncheck() {\n  local label=\"$1\"\n  shift\n  \"$@\" &>/dev/null && echo \"[OK]  $label\" || echo \"[MISS] $label\"\n}\n\ncheck \"sglang\" python3 -c \"import sglang\"\ncheck \"torch+CUDA\" python3 -c \"import torch; assert torch.cuda.is_available()\"\ncheck \"torch.profiler\" python3 -c \"import torch.profiler\"\ncheck \"nsys (Level 2)\" which nsys\ncheck \"ncu (Level 3)\" which 
ncu\ncheck \"pandas\" python3 -c \"import pandas\"\ncheck \"plotly\" python3 -c \"import plotly\"\n```\n\n- **Minimum for benchmarking**: `sglang`, `torch` with CUDA.\n- **Level 1 profiling**: `torch.profiler` (bundled with torch).\n- **Level 2 profiling**: `nsys`, `pandas`, `plotly` + `gputrc2graph.py` from the sglang repo.\n- **Level 3 profiling**: `ncu` (Nsight Compute).\n\nAll commands below assume you are inside the configured diffusion container shell, already `cd`'d to the repo root derived from `sglang.__file__`, with `FLASHINFER_DISABLE_VERSION_CHECK=1` exported. Re-run `print-idle-gpus` before each perf command if GPU availability may have changed. Use at most 4 GPUs per benchmark command.\n\nDownload input images required by some models:\n```bash\nwget -O \"${ASSET_DIR}/cat.png\" \\\n  https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png\nwget -O \"${ASSET_DIR}/astronaut.jpg\" \\\n  https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg\nwget -O \"${ASSET_DIR}/mova_single_person.jpg\" \\\n  https://github.com/OpenMOSS/MOVA/raw/main/assets/single_person.jpg\n```\n\n---\n\n## Benchmark Commands\n\nAll commands include `--warmup` and `--enable-torch-compile` so the numbers reflect production performance. Add `--perf-dump-path <file>.json` for machine-readable output.\n\n### Perf dump & before/after compare\n\nFor every benchmark run, always write a perf dump JSON:\n\n```bash\nsglang generate ... --warmup --perf-dump-path \"${BENCH_DIR}/<result>.json\"\n```\n\nBefore/after comparison (outputs a Markdown table suitable for PR descriptions):\n\n```bash\n# Baseline (on main branch or before changes)\nsglang generate ... --warmup --perf-dump-path \"${BENCH_DIR}/baseline.json\"\n\n# New (after changes)\nsglang generate ... 
--warmup --perf-dump-path \"${BENCH_DIR}/new.json\"\n\npython3 python/sglang/multimodal_gen/benchmarks/compare_perf.py \\\n  \"${BENCH_DIR}/baseline.json\" \"${BENCH_DIR}/new.json\"\n```\n\n### Qwen-Image-2512 (1024×1024, 50 steps)\n```bash\nsglang generate \\\n  --model-path=Qwen/Qwen-Image-2512 \\\n  --prompt=\"A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k\" \\\n  '--negative-prompt= ' \\\n  --width=1024 --height=1024 --num-inference-steps=50 --guidance-scale=4.0 \\\n  --seed=42 --save-output --enable-torch-compile --warmup \\\n  --dit-cpu-offload false --text-encoder-cpu-offload false\n```\n\n### Qwen-Image-Edit-2511 (image editing, 1024×1024, 50 steps)\n```bash\nsglang generate \\\n  --model-path=Qwen/Qwen-Image-Edit-2511 \\\n  '--prompt=Transform into anime style' '--negative-prompt= ' \\\n  --image-path=\"${ASSET_DIR}/cat.png\" \\\n  --width=1024 --height=1024 --num-inference-steps=50 --guidance-scale=4.0 \\\n  --seed=42 --save-output --enable-torch-compile --warmup \\\n  --dit-cpu-offload false --text-encoder-cpu-offload false\n```\n\n### FLUX.1-dev (1024×1024, 50 steps)\n```bash\nsglang generate \\\n  --model-path=black-forest-labs/FLUX.1-dev \\\n  --prompt=\"A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k\" \\\n  --width=1024 --height=1024 --num-inference-steps=50 --guidance-scale=4.0 \\\n  --seed=42 --save-output --enable-torch-compile --warmup\n```\n\n### FLUX.2-dev (1024×1024)\n```bash\nsglang generate \\\n  --model-path black-forest-labs/FLUX.2-dev \\\n  --prompt \"A Logo With Bold Large Text: SGL Diffusion\" \\\n  --width=1024 --height=1024 \\\n  --dit-layerwise-offload false --enable-torch-compile --warmup \\\n  --dit-cpu-offload false --text-encoder-cpu-offload true --vae-cpu-offload false\n```\n\n### Z-Image-Turbo (1024×1024, 9 steps)\n```bash\nsglang generate \\\n  --model-path=Tongyi-MAI/Z-Image-Turbo \\\n  --prompt='A fantasy landscape with 
mountains and a river, detailed, vibrant colors' \\\n  --width=1024 --height=1024 --num-inference-steps=9 --guidance-scale=0.0 \\\n  --seed=42 --save-output --enable-torch-compile --warmup \\\n  --dit-cpu-offload false --text-encoder-cpu-offload false\n```\n\n### Wan2.2-T2V-A14B 720P (4 GPUs, 81 frames, 2 steps)\n```bash\n# Select four idle GPUs first:\n# export CUDA_VISIBLE_DEVICES=$(python3 \"$ENV_PY\" print-idle-gpus --count 4)\nsglang generate \\\n  --model-path=Wan-AI/Wan2.2-T2V-A14B-Diffusers \\\n  --prompt=\"A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon.\" \\\n  --negative-prompt=\" \" --720p --num-inference-steps=2 --num-frames=81 \\\n  --guidance-scale=5.0 --seed=42 --save-output \\\n  --num-gpus=4 --ulysses-degree=4 \\\n  --text-encoder-cpu-offload --pin-cpu-memory \\\n  --warmup --enable-torch-compile\n```\n\n### Wan2.2-TI2V-5B 720P (single GPU, 81 frames, 50 steps)\n```bash\nsglang generate \\\n  --model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers \\\n  --prompt \"An astronaut hatching from an egg, on the surface of the moon...\" \\\n  --negative-prompt \"Bright tones, overexposed, static, blurred details...\" \\\n  --image-path=\"${ASSET_DIR}/astronaut.jpg\" \\\n  --num-frames 81 --720p --num-inference-steps 50 --guidance-scale 5.0 \\\n  --seed 42 --save-output \\\n  --dit-layerwise-offload false --dit-cpu-offload false \\\n  --vae-cpu-offload false --text-encoder-cpu-offload false \\\n  --enable-torch-compile --warmup\n```\n\n### HunyuanVideo (848×480, 65 frames, 30 steps)\n```bash\nsglang generate \\\n  --model-path=hunyuanvideo-community/HunyuanVideo \\\n  --text-encoder-cpu-offload --pin-cpu-memory \\\n  --prompt=\"A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. 
The kitchen is cozy, with sunlight streaming through the window.\" \\\n  --save-output --num-frames=65 --width=848 --height=480 \\\n  --num-inference-steps=30 \\\n  --warmup --enable-torch-compile\n```\n\n### MOVA-720p (4 GPUs, 193 frames, 2 steps)\n```bash\n# Select four idle GPUs first:\n# export CUDA_VISIBLE_DEVICES=$(python3 \"$ENV_PY\" print-idle-gpus --count 4)\nsglang generate \\\n  --model-path=OpenMOSS-Team/MOVA-720p \\\n  --prompt=\"A man in a blue blazer and glasses speaks in a formal indoor setting, framed by wooden furniture and a filled bookshelf. Quiet room acoustics underscore his measured tone as he delivers his remarks. At one point, he says, \\\"I would also believe that this advance in AI recently wasn’t unexpected.\\\"\" \\\n  --image-path=\"${ASSET_DIR}/mova_single_person.jpg\" \\\n  --adjust-frames=false \\\n  --num-gpus=4 --ring-degree=1 --ulysses-degree=4 \\\n  --num-frames=193 --fps=24 \\\n  --num-inference-steps=2 \\\n  --enable-torch-compile --save-output --warmup\n```\n\n**Key metrics** (all models): denoise latency ★, end-to-end latency, peak GPU memory.\n\n---\n\n## Performance Bottleneck Workflow\n\n### Step 1: Identify the Slow DiT Operation\n\nAdd `--log-level=info` and observe:\n- **Denoise loop latency** ★ — primary target\n- Per-step DiT latency — denoise ÷ steps\n\n### Step 2: Profile with torch.profiler (Level 1)\n\n**Compile-safety rule for fused or rewritten kernels**\n- Any new kernel must be checked for `torch.compile` graph breaks before trusting its benchmark result.\n- If a direct Python/library call triggers tracing issues, wrap it as a custom op first.\n- For external libraries, use `register_custom_op_from_extern(...)`.\n- For SGLang JIT kernels, use `@register_custom_op(...)` and keep the JIT/module loading inside the custom op body.\n- Re-run `torch._dynamo.explain` on representative shapes and verify the optimized path still gets `graph_count=1` and 
`graph_break_count=0`.\n\n```bash\nSGLANG_TORCH_PROFILER_DIR=\"${PROFILE_DIR}/torch\" \\\nsglang generate \\\n  --model-path=black-forest-labs/FLUX.1-dev \\\n  --prompt=\"A futuristic cyberpunk city at night\" \\\n  --width=1024 --height=1024 --num-inference-steps=50 \\\n  --seed=42 --enable-torch-compile --warmup \\\n  --profile --num-profiled-timesteps 3\n```\n\nParse the trace without a browser:\n```python\nimport gzip, json, collections, glob, os\n\n# SGLANG_TORCH_PROFILER_DIR was only set for the generate command above,\n# so fall back to ${PROFILE_DIR}/torch (PROFILE_DIR is exported in Prerequisites)\nlog_dir = os.environ.get(\"SGLANG_TORCH_PROFILER_DIR\", os.path.join(os.environ.get(\"PROFILE_DIR\", \".\"), \"torch\"))\ntrace_path = sorted(glob.glob(f\"{log_dir}/*.trace.json.gz\"), key=os.path.getmtime, reverse=True)[0]\nwith gzip.open(trace_path, \"rb\") as f:\n    data = json.loads(f.read())\n\ncuda_ops = collections.defaultdict(lambda: {\"total_us\": 0, \"count\": 0})\nfor e in data.get(\"traceEvents\", []):\n    if e.get(\"cat\") in (\"kernel\", \"gpu_memcpy\") and \"dur\" in e:\n        cuda_ops[e.get(\"name\",\"unknown\")][\"total_us\"] += e[\"dur\"]\n        cuda_ops[e.get(\"name\",\"unknown\")][\"count\"] += 1\n\nprint(f\"{'Kernel':<80} {'Total(ms)':>10} {'Count':>6}\")\nfor name, s in sorted(cuda_ops.items(), key=lambda x: -x[1][\"total_us\"])[:30]:\n    print(f\"{name:<80} {s['total_us']/1000:>10.3f} {s['count']:>6}\")\n```\n\nAdd `record_function` scopes in the DiT block for per-layer attribution:\n```python\nwith torch.profiler.record_function(f\"dit_block_{idx}.attn\"):\n    x = self.attn(x)\nwith torch.profiler.record_function(f\"dit_block_{idx}.norm\"):\n    x = self.norm(x)\n```\n\n**Expected dominant kernels per DiT sub-component:**\n\n| Sub-component | Expected kernel |\n|--------------|-----------------|\n| QKV / output / MLP projections | `cutlass_gemm` / `ampere_*_gemm` |\n| Attention | `flash_attn_fwd` / `fmha_*` (FA3/FA4) |\n| AdaLN modulation | `fuse_scale_shift_kernel` |\n| RMSNorm / LayerNorm | `sgl_kernel_rmsnorm` / Triton norm |\n| SiLU gate | `vectorized_elementwise_kernel` |\n| RoPE | `apply_rotary_embedding` (Triton) |\n| QK Norm | 
`fused_inplace_qknorm` (JIT) |\n\n### Step 3: Deep CUDA Kernel Breakdown (Level 2 — nsys)\n\n```bash\n# Pass A — collect nsys trace (skip warmup with --delay)\nnsys profile -t cuda -o \"${PROFILE_DIR}/flux_dev\" -f true \\\n  --trace-fork-before-exec=true --delay 120 --duration 60 \\\n  sglang generate \\\n    --model-path=black-forest-labs/FLUX.1-dev \\\n    --prompt=\"A futuristic cyberpunk city at night\" \\\n    --width=1024 --height=1024 --num-inference-steps=50 \\\n    --seed=42 --enable-torch-compile --warmup\n\n# Pass B — measure wall-clock time without profiling (same prompt and settings as Pass A)\ntime sglang generate \\\n  --model-path=black-forest-labs/FLUX.1-dev \\\n  --prompt=\"A futuristic cyberpunk city at night\" \\\n  --width=1024 --height=1024 --num-inference-steps=50 --seed=42 \\\n  --enable-torch-compile --warmup\n# Record the elapsed seconds from Pass B and substitute that number for ELAPSED_SEC below\n```\n\nCreate classification JSON at `examples/profiler/nsys_profile_tools/sglang_diffusion_engine_model.json`:\n```json\n{\n  \"sglang\": {\n    \"diffusion\": {\n      \"gemm|nvjet|cutlass\": \"gemm\",\n      \"flash|fmha|fwd_flash\": \"attn\",\n      \"fuse_scale_shift|scale_shift_gate\": \"adaln_modulation\",\n      \"_norm_|Norm|rmsnorm|fused_add_rmsnorm\": \"norm\",\n      \"rotary|rope\": \"rope\",\n      \"act_and_mul|silu|gelu\": \"activation\",\n      \"ncclDevKernel|all_gather|all_reduce\": \"nccl_comm\",\n      \"triton\": \"triton_kernel\",\n      \"CUDA mem\": \"non-gpu-H_D_memops\",\n      \".*\": \"misc\"\n    }\n  }\n}\n```\n\nRun analysis:\n```bash\ncd \"$ROOT/examples/profiler/nsys_profile_tools\"\npython3 gputrc2graph.py \\\n  --in_file \"${PROFILE_DIR}/flux_dev.nsys-rep,sglang,diffusion,ELAPSED_SEC\" \\\n  --out_dir \"${PROFILE_DIR}/analysis\" \\\n  --title \"FLUX.1-dev denoise kernel breakdown\"\n\n# Read results\npython3 - << 'EOF'\nimport os\nimport pandas as pd\ndf = pd.read_csv(f\"{os.environ['PROFILE_DIR']}/analysis/result.csv\")\nsummary = df.groupby(\"Category\")[\"Elapsed Time (sec)\"].sum().sort_values(ascending=False)\ntotal = summary.sum()\nfor cat, sec in summary.items():\n  
  print(f\"{cat:<30} {sec:>8.3f}s  ({sec/total*100:>5.1f}%)\")\nEOF\n```\n\n**What the category breakdown tells you:**\n\n| Category high | Investigation |\n|--------------|---------------|\n| `gemm` dominant | Check tensor parallelism; QKV/MLP bottleneck |\n| `attn` dominant | Verify FA3/FA4 is active |\n| `adaln_modulation` high | Verify fused `fuse_scale_shift_kernel` is used |\n| `norm` high | Verify `sgl_kernel_rmsnorm` / CuTe DSL path; check D alignment |\n| `nccl_comm` high | Multi-GPU: tune Ulysses degree |\n| `triton_kernel` high | Identify which Triton kernel; consider CUDA replacement |\n| `non-gpu-H_D_memops` high | Accidental CPU offload or `.cpu()` calls mid-denoising |\n| `CPU(non-GPU)` high | Python dispatch overhead / torch.compile graph breaks |\n\n### Step 3.5: Per-Kernel Deep Analysis (Level 3 — ncu)\n\n**CRITICAL**: `ncu` (Nsight Compute) is the essential tool for kernel-level optimization. While nsys and torch.profiler tell you **which** kernels are slow, only ncu tells you **why** — memory bandwidth utilization, compute throughput, occupancy limiters, warp stall reasons, and roofline position. **Always use ncu when optimizing or writing custom kernels.**\n\n#### When to use ncu\n\n- After writing a new Triton or CUDA kernel — verify it saturates hardware bandwidth\n- When a kernel shows up as a top bottleneck in Level 1/2 profiling\n- When comparing your fused kernel vs PyTorch baseline or torch.compile output\n- When tuning Triton autotune configs (block sizes, num_warps)\n\n#### Basic ncu workflow\n\n```bash\n# 1. Profile a specific kernel by name (skip warmup launches, collect 3 invocations)\nncu --kernel-name \"_fused_gated_residual_add_kernel\" \\\n    --launch-skip 10 --launch-count 3 \\\n    --set full \\\n    -o \"${NCU_DIR}/gated_residual\" \\\n    sglang generate \\\n      --model-path=black-forest-labs/FLUX.1-dev \\\n      --prompt=\"test\" --width=1024 --height=1024 \\\n      --num-inference-steps=5 --seed=42\n\n# 2. 
Profile all kernels in a short run (use few steps to limit time)\nncu --launch-skip 50 --launch-count 200 \\\n    --set full \\\n    -o \"${NCU_DIR}/all_kernels\" \\\n    sglang generate \\\n      --model-path=black-forest-labs/FLUX.1-dev \\\n      --prompt=\"test\" --width=1024 --height=1024 \\\n      --num-inference-steps=3 --seed=42\n\n# 3. For CUDA graph mode, keep --graph-profiling=node on the ncu side.\n# Note: `--enable-piecewise-cuda-graph` is a server flag, not a valid\n# `sglang generate` flag, so do not append it here.\nncu --graph-profiling node \\\n    --kernel-name \"_fused_gated_residual_add_kernel\" \\\n    --launch-skip 5 --launch-count 3 \\\n    --set full \\\n    -o \"${NCU_DIR}/gated_residual_cudagraph\" \\\n    sglang generate \\\n      --model-path=black-forest-labs/FLUX.1-dev \\\n      --prompt=\"test\" --width=1024 --height=1024 \\\n      --num-inference-steps=5 --seed=42\n```\n\n#### Reading ncu results (CLI, no GUI needed)\n\n```bash\n# Summary of all profiled kernels\nncu --import \"${NCU_DIR}/gated_residual.ncu-rep\" --page raw --csv 2>/dev/null | head -50\n\n# Key metrics to extract:\nncu --import \"${NCU_DIR}/gated_residual.ncu-rep\" \\\n    --page details --csv 2>/dev/null | python3 -c \"\nimport csv, sys\nreader = csv.DictReader(sys.stdin)\nkey_metrics = {\n    'gpu__time_duration.avg': 'Duration',\n    'sm__throughput.avg.pct_of_peak_sustained_elapsed': 'Compute (SM) Throughput',\n    'dram__throughput.avg.pct_of_peak_sustained_elapsed': 'DRAM Throughput',\n    'l1tex__throughput.avg.pct_of_peak_sustained_elapsed': 'L1/TEX Cache Throughput',\n    'sm__warps_active.avg.pct_of_peak_sustained_active': 'Achieved Occupancy',\n    'launch__occupancy_limit_registers': 'Block Limit Registers',\n    'launch__occupancy_limit_shared_mem': 'Block Limit Shared Mem',\n}\nfor row in reader:\n    name = row.get('Metric Name', '')\n    if any(alias in name or metric in name for metric, alias in key_metrics.items()):\n        print(f'{name:<60} 
{row.get(\\\"Metric Value\\\",\\\"\\\")}')\n\"\n```\n\n#### Interpreting ncu results for kernel optimization\n\n| Metric | Good | Action if bad |\n|--------|------|--------------|\n| DRAM throughput > 80% peak | Memory-bound, near optimal | Already saturating HBM — fuse with adjacent ops to reduce total memory traffic |\n| DRAM throughput < 50% peak | Not saturating memory bandwidth | Check coalescing, increase vector width, tune BLOCK sizes |\n| SM throughput > 60% peak | Compute-bound, near optimal | Reduce arithmetic, use faster instructions (e.g., FMA) |\n| SM throughput < 30% peak | Underutilized compute | Increase occupancy, reduce warp stalls, check instruction mix |\n| Achieved occupancy > 50% | Acceptable for most kernels | — |\n| Achieved occupancy < 25% | Too few active warps | Reduce register pressure or shared memory; increase block size |\n\n#### Comparing before/after with ncu\n\n```bash\n# Profile baseline kernel\nncu --kernel-name \"vectorized_elementwise_kernel\" \\\n    --launch-skip 10 --launch-count 3 --set full \\\n    -o \"${NCU_DIR}/baseline\" ./program\n\n# Profile optimized kernel\nncu --kernel-name \"_fused_gated_residual_add_kernel\" \\\n    --launch-skip 10 --launch-count 3 --set full \\\n    -o \"${NCU_DIR}/optimized\" ./program\n\n# Compare key metrics\nfor report in baseline optimized; do\n  echo \"=== $report ===\"\n  ncu --import \"${NCU_DIR}/${report}.ncu-rep\" \\\n      --page details --csv 2>/dev/null | grep -E \"time_duration|throughput.*pct|occupancy\"\ndone\n```\n\n**Decision rule after ncu analysis:**\n- Kernel already at >80% DRAM bandwidth → fuse with neighbors to reduce total traffic\n- Kernel at <50% DRAM bandwidth → tune block sizes, fix coalescing, increase vectorization\n- Kernel compute-bound (SM util high, DRAM low) → reduce FLOPs or switch to a faster algorithm\n- Low occupancy → reduce registers (simplify kernel) or increase block size in autotune configs\n\n### Step 4: Apply Kernel Optimization\n\nAfter 
pinpointing the slow op, choose the right tool:\n\n| Scenario | Skill to use |\n|----------|-------------|\n| New fused elementwise, norm variant, RoPE variant | **`add-triton-kernel.md`** — Triton JIT, faster iteration, NPU fallback |\n| Bandwidth-bound reduction (RMSNorm) needing max vectorization | **`add-cuda-kernel.md`** — CUDA JIT with `AlignedVector`, warp reductions |\n| Attention or tile-based op needing shared memory tuning | **`add-cuda-kernel.md`** — full control over CUDA primitives |\n| Slow op already covered by existing fused kernel | **`use-efficient-diffusion-kernels.md`** — check constraints & enable |\n\n**Quick decision rule**: start with Triton. Switch to CUDA JIT only when profiling shows Triton can't saturate hardware bandwidth.\n\nBoth kernel types use SGLang's JIT compilation:\n- **Triton**: `python/sglang/jit_kernel/diffusion/triton/<op>.py`\n- **CUDA JIT**: `python/sglang/jit_kernel/csrc/diffusion/<op>.cuh` + wrapper `python/sglang/jit_kernel/diffusion/<op>.py`\n\n### Step 5: torch.compile Coverage\n\n```bash\nTORCH_COMPILE_DEBUG=1 sglang generate ...\n```\n- Dynamic shape changes trigger recompilation → fix resolution and frame count when benchmarking\n- `tensor.item()` in conditional branches causes graph breaks → rewrite as tensor ops\n\n### Step 6: Multi-GPU Efficiency (Wan2.2-T2V-A14B / MOVA)\n\n- Verify `--ulysses-degree` evenly divides `--num-gpus`\n- Keep the command shape fixed when comparing kernels; for quick checks, reduce only `--num-inference-steps`\n- If a run OOMs or jitters because of host contention, first confirm there are no leaked scheduler processes on the chosen GPU set\n\n---\n\n## Optimization Workflow Summary\n\n```\n0. BASELINE\n   sglang generate --seed=42 --save-output → save reference images/videos\n       ↓\n1. BENCHMARK\n   Run benchmark commands above → record denoise latency baseline\n       ↓\n2. 
LEVEL 1 PROFILE (torch.profiler)\n   --profile --num-profiled-timesteps 3\n   → parse .trace.json.gz → rank ops by CUDA time\n   → identify slow DiT layer (norm / attn / mlp / rope / adaln)\n       ↓\n3. LEVEL 2 PROFILE (nsys + gputrc2graph.py)\n   → result.csv category breakdown (gemm / attn / adaln / norm / triton / cpu)\n   → confirm where GPU time is concentrated\n       ↓\n4. LEVEL 3 PROFILE (ncu — per-kernel deep analysis) ★ CRITICAL\n   → ncu --set full on target kernel(s)\n   → extract DRAM bandwidth util, SM throughput, achieved occupancy\n   → determine if kernel is memory-bound, compute-bound, or latency-bound\n   → for CUDA graph: use --graph-profiling node\n       ↓\n5. KERNEL OPTIMIZATION\n   Existing fused kernel?  → use-efficient-diffusion-kernels.md\n   New Triton kernel?      → add-triton-kernel.md\n   New CUDA JIT kernel?    → add-cuda-kernel.md\n   After writing kernel    → ncu again to verify bandwidth/occupancy ★\n       ↓\n6. VERIFY CORRECTNESS\n   sglang generate --seed=42 --save-output → diff against reference\n   If output differs beyond tolerance → reject optimization\n       ↓\n7. RE-BENCHMARK\n   Verify denoise latency improvement; no regression on other models\n```\n\n---\n\n## Checklist Before Merging\n\n### Correctness (must pass first)\n- [ ] Reference outputs collected with `--seed=42 --save-output` **before** any change\n- [ ] After change: regenerate with identical args and compare\n- [ ] No visible quality degradation in generated images / videos\n- [ ] Correctness verified on all benchmark models\n\n### Performance (only after correctness passes)\n- [ ] All benchmark models executed; denoise latency ★, end-to-end, peak memory recorded\n- [ ] No regression in denoise latency vs. previous baseline (±2% tolerance)\n- [ ] New kernel shows measurable improvement on at least 2 models\n- [ ] No new torch.compile graph breaks introduced\n- [ ] Results reproducible with all offloads disabled and fixed `--seed=42`\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/nsight-profiler.md",
    "content": "---\nname: nsight-profiler\ndescription: Expert skill for NVIDIA Nsight Systems and Nsight Compute profiling tools. Configure profiling sessions, analyze kernel reports, interpret occupancy metrics, roofline model data, memory bandwidth bottlenecks, and warp execution efficiency.\nallowed-tools: Bash(*) Read Write Edit Glob Grep WebFetch\nmetadata:\n  author: babysitter-sdk\n  version: \"1.0.0\"\n  category: performance-profiling\n  backlog-id: SK-002\n  source: \"Adapted from https://github.com/lobehub/lobehub (.agents/skills/nsight-profiler)\"\n---\n\n> **Source**: This skill is adapted from the [lobehub/lobehub](https://github.com/lobehub/lobehub) open-source repository (`.agents/skills/nsight-profiler`). Original author: `babysitter-sdk`.\n\n# nsight-profiler\n\nYou are **nsight-profiler** - a specialized skill for NVIDIA Nsight Systems and Nsight Compute profiling tools. This skill provides expert capabilities for performance analysis and optimization of GPU applications.\n\n## Overview\n\nThis skill enables AI-powered GPU profiling operations including:\n- Configure and execute Nsight Systems profiling sessions\n- Analyze Nsight Compute kernel reports\n- Interpret occupancy metrics and SM utilization\n- Parse and visualize roofline model data\n- Identify memory bandwidth bottlenecks\n- Analyze warp execution efficiency\n- Generate optimization recommendations from profiler data\n- Compare kernel performance across different configurations\n\n## Prerequisites\n\n- NVIDIA Nsight Systems 2023.1+\n- NVIDIA Nsight Compute 2023.1+\n- CUDA Toolkit 11.0+\n- GPU with compute capability 7.0+ (for full profiling features)\n\n## Capabilities\n\n### 1. 
Nsight Systems Profiling\n\nSystem-wide performance analysis:\n\n```bash\n# Basic system profile\nnsys profile -o report ./cuda_program\n\n# Profile with CUDA API tracing\nnsys profile -t cuda,nvtx,osrt -o report ./cuda_program\n\n# Capture GPU metrics\nnsys profile --gpu-metrics-device=all -o report ./cuda_program\n\n# Profile specific duration\nnsys profile -d 10 -o report ./cuda_program\n\n# Export to multiple formats (one type per command)\nnsys export -t sqlite report.nsys-rep\nnsys export -t json report.nsys-rep\n\n# Generate summary statistics\nnsys stats report.nsys-rep\n```\n\n### 2. Nsight Compute Profiling\n\nDetailed kernel analysis:\n\n```bash\n# Profile all kernels\nncu -o profile ./cuda_program\n\n# Profile specific kernel\nncu --kernel-name myKernel -o profile ./cuda_program\n\n# Full metric collection\nncu --set full -o profile ./cuda_program\n\n# Roofline analysis\nncu --set roofline -o profile ./cuda_program\n\n# Memory analysis\nncu --section MemoryWorkloadAnalysis -o profile ./cuda_program\n\n# Inspect a saved report from the CLI. Baseline comparison is a GUI feature;\n# for a CLI comparison, export both reports to CSV as shown under Kernel Comparison.\nncu --import baseline.ncu-rep --page details\n```\n\n### 3. Occupancy Analysis\n\nAnalyze and optimize occupancy:\n\n```bash\n# Collect occupancy metrics\nncu --section Occupancy -o occupancy ./cuda_program\n\n# Key metrics to analyze:\n# - Achieved Occupancy\n# - Theoretical Occupancy\n# - Block Limit (registers, shared memory, warps)\n# - Occupancy Limiter\n```\n\n```cuda\n// Query theoretical occupancy in code (device 0, no dynamic shared memory)\ncudaDeviceProp deviceProp;\ncudaGetDeviceProperties(&deviceProp, 0);\nsize_t sharedMemSize = 0;\nint numBlocks;\nint blockSize = 256;\ncudaOccupancyMaxActiveBlocksPerMultiprocessor(\n    &numBlocks, myKernel, blockSize, sharedMemSize);\n\nfloat occupancy = (numBlocks * blockSize) /\n    (float)deviceProp.maxThreadsPerMultiProcessor;\nprintf(\"Theoretical Occupancy: %.2f%%\\n\", occupancy * 100);\n```\n\n### 4. 
Roofline Model Analysis\n\nPerformance bound analysis:\n\n```bash\n# Generate roofline data\nncu --set roofline -o roofline ./cuda_program\n\n# Key metrics:\n# - Achieved FLOP/s\n# - Achieved Memory Bandwidth\n# - Arithmetic Intensity (FLOP/byte)\n# - Ridge Point\n```\n\nInterpretation guide (ridge point = peak FLOP/s ÷ peak memory bandwidth):\n- Arithmetic intensity below the ridge point (under the sloped bandwidth roof): memory bound\n- Arithmetic intensity above the ridge point (under the flat compute roof): compute bound\n- Far below both roofs: latency bound; check occupancy and warp stall reasons\n- Close to the applicable roof: near-optimal utilization\n\n### 5. Memory Bandwidth Analysis\n\nIdentify memory bottlenecks:\n\n```bash\n# Memory analysis sections\nncu --section MemoryWorkloadAnalysis \\\n    --section MemoryWorkloadAnalysis_Chart \\\n    --section MemoryWorkloadAnalysis_Tables \\\n    -o memory ./cuda_program\n```\n\nKey metrics:\n- Global Load/Store Throughput\n- L1/L2 Cache Hit Rate\n- Shared Memory Bandwidth\n- Memory Transactions per Request\n\n### 6. Warp Execution Analysis\n\nAnalyze warp efficiency:\n\n```bash\n# Warp state analysis\nncu --section WarpStateStatistics -o warp ./cuda_program\n\n# Scheduler statistics\nncu --section SchedulerStatistics -o scheduler ./cuda_program\n```\n\nKey metrics:\n- Warp Cycles Per Issued Instruction\n- Eligible Warps Per Active Cycle\n- Active Warps Per Scheduler\n- Stall Reasons (memory, sync, execution)\n\n### 7. 
Kernel Comparison\n\nCompare kernel variants:\n\n```bash\n# Step 1: Profile baseline\nncu --set full -o baseline ./program_v1\n\n# Step 2: Profile optimized version\nncu --set full -o optimized ./program_v2\n\n# Step 3: Export both profiles to CSV, then compare with Python (no GUI needed)\n# Note: --import can only be specified once; --page diff is not a valid page value.\nncu --import baseline.ncu-rep --page details --csv > baseline_details.csv\nncu --import optimized.ncu-rep --page details --csv > optimized_details.csv\n\npython3 -c \"\nimport csv\ndef load(p):\n    return {r.get('Metric Name',''): r.get('Metric Value','')\n            for r in csv.DictReader(open(p))}\nb = load('baseline_details.csv')\no = load('optimized_details.csv')\nfor k in sorted(set(b) | set(o)):\n    bv, ov = b.get(k,''), o.get(k,'')\n    if bv != ov:\n        print(f'{k[:55]:<55} {bv} -> {ov}')\n\"\n```\n\n### 8. Performance Recommendations\n\nAutomated analysis:\n\n```bash\n# Get optimization recommendations\nncu --section SpeedOfLight \\\n    --section SpeedOfLight_RooflineChart \\\n    -o speedoflight ./cuda_program\n\n# Export with recommendations\nncu --import profile.ncu-rep --page details --csv > details.csv\n```\n\n## Common Profiling Workflows\n\n### Workflow 1: Initial Performance Assessment\n\n```bash\n# Step 1: System overview\nnsys profile -t cuda -o system_overview ./program\nnsys stats system_overview.nsys-rep\n\n# Step 2: Identify hot kernels\nncu --launch-skip 10 --launch-count 5 -o hot_kernels ./program\n\n# Step 3: Deep dive on bottleneck kernel\nncu --kernel-name hotKernel --set full -o detailed ./program\n```\n\n### Workflow 2: Memory Optimization\n\n```bash\n# Analyze memory access patterns\nncu --section SourceCounters \\\n    --section MemoryWorkloadAnalysis \\\n    --kernel-name targetKernel \\\n    -o memory_analysis ./program\n\n# Check for coalescing issues\nncu --metrics 
l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum,\\\nl1tex__t_requests_pipe_lsu_mem_global_op_ld.sum \\\n    -o coalescing ./program\n```\n\n### Workflow 3: Occupancy Optimization\n\n```bash\n# Profile with occupancy focus\nncu --section Occupancy \\\n    --section LaunchStatistics \\\n    -o occupancy ./program\n```\n\n**Interpreting occupancy limiters** (from the `Occupancy` section report):\n\n| Limiter shown | Fix |\n|---------------|-----|\n| `Registers` | Reduce register pressure: use fewer live variables, or cap registers with `__launch_bounds__` / `--maxrregcount` (CUDA) or a `maxnreg` hint (Triton) |\n| `Shared Memory` | Decrease the shared memory allocation per block, e.g. store 32-bit instead of 64-bit elements |\n| `Block Size` | Increase threads per block; ensure block size is a multiple of warp size (32) |\n| `Warp Limit` | Already at theoretical max for this SM; no action needed |\n\n> **For Triton kernels**: block sizes are controlled via `@triton.autotune` configs, not CLI flags. To test occupancy at different block sizes, add or modify the `triton.Config({\"BLOCK_C\": N}, num_warps=W)` entries in the autotune list and re-run. Do **not** pass `--block-size` as a CLI argument — the Triton benchmark script does not accept it.\n\n## Dependencies\n\n- Nsight Systems 2023.1+\n- Nsight Compute 2023.1+\n- CUDA Toolkit 11.0+\n\n## Constraints\n\n- Collecting GPU performance counters may require root/admin privileges, depending on the driver's profiling permission setting\n- Some metrics are only available on specific GPU architectures\n- Profiling adds overhead; results may differ from production\n- Nsight Compute serializes kernel launches and replays each profiled kernel to collect metrics, so profiled runs are much slower than normal runs\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/a100-optimization-guide.md",
    "content": "# A100 GPU Optimization Guide — SGLang Diffusion JIT Kernels\n\nDeep dive into A100-specific optimizations for diffusion model CUDA kernels in SGLang's JIT system.\n\n> **Adapted from**: [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels)\n\n---\n\n## A100 Ampere Architecture Overview\n\n| Component | A100 40GB | A100 80GB | Notes |\n|-----------|-----------|-----------|-------|\n| Compute Capability | sm_80 | sm_80 | Use `\"-arch=sm_80\"` in `extra_cuda_cflags` |\n| SMs | 108 | 108 | Grid: aim for multiples of 108 |\n| Shared Memory | 164 KB/SM | 164 KB/SM | Configurable: 48/96/164 KB |\n| L2 Cache | 40 MB | 40 MB | Less than H100 (50 MB) |\n| Memory Bandwidth | 1.55 TB/s | 2.0 TB/s | HBM2e |\n| Max Threads/SM | 2048 | 2048 | Same as H100 |\n| Tensor Cores | 3rd gen | 3rd gen | FP16, BF16, TF32, INT8, INT4 |\n\n### A100 vs H100 Comparison\n\n| Feature | A100 | H100 | Impact on JIT Kernels |\n|---------|------|------|-----------------------|\n| Memory BW | 2.0 TB/s | 3.35 TB/s | H100 ~67% faster for memory-bound ops |\n| SMs | 108 | 132 | Adjust persistent kernel grid sizing |\n| Shared Mem/SM | 164 KB | 192 KB | Reduce max tile sizes on A100 |\n| L2 Cache | 40 MB | 50 MB | Attention tile reuse still works well |\n| TMA | No | Yes | Can't use `cp.async.bulk` on A100 |\n| FP8 | No | Yes | Use FP16/BF16 only on A100 |\n\n---\n\n## Memory Access Optimization\n\nSame coalescing and vectorization rules as H100; lower bandwidth makes them even more critical.\n\n### `AlignedVector` Vectorization (same pattern as H100)\n\n```cpp\n#include <sgl_kernel/vec.cuh>\n\nconstexpr int kVecN = 16 / sizeof(T);   // 8 for bf16/fp16, 4 for fp32\nusing vec_t = device::AlignedVector<T, kVecN>;\n\nvec_t v;\nv.load(src, vi);\n// ... 
process elements ...\nv.store(dst, vi);\n```\n\n**Expected A100 performance (BF16 RMSNorm):**\n\n| Implementation | A100 (ms) | H100 (ms) | A100 Speedup |\n|:---|:---:|:---:|:---:|\n| Scalar loads | ~0.10 | 0.065 | 1.00x |\n| `AlignedVector<bf16_t, 8>` | ~0.03 | 0.019 | ~3x |\n\n**Target bandwidth**: 30–40% of A100's 2.0 TB/s = 600–800 GB/s.\n\n### Shared Memory Configuration\n\n```cpp\n// A100: 164 KB shared memory per SM, at most 163 KB (opt-in) per block\ncudaFuncSetAttribute(\n    your_kernel,\n    cudaFuncAttributeMaxDynamicSharedMemorySize,\n    163 * 1024  // 163 KB: max dynamic shared memory per block on A100\n);\n```\n\nAttention tile sizes for A100:\n\n```\nBLOCK_SIZE_M = 128  (Q block)\nBLOCK_SIZE_N = 64   (K,V block)\nTile = 128×64×2 = 16 KB (FP16) — fits in 164 KB shared mem\n```\n\n---\n\n## Occupancy Tuning\n\n**Grid sizing for A100 (108 SMs):**\n\n```cpp\n#include <sgl_kernel/runtime.cuh>\n\n// Cap blocks to SM × occupancy (same pattern as H100)\nstatic const uint32_t max_occ = host::runtime::get_blocks_per_sm(kernel, kBlockSize);\nstatic const uint32_t num_sm  = host::runtime::get_sm_count(device.unwrap().device_id);\nconst uint32_t num_blocks = std::min(num_sm * max_occ, host::div_ceil(n, kBlockSize));\n```\n\n**Recommended block sizes (same as H100):**\n\n| Kernel Type | Threads/Block | Notes |\n|-------------|---------------|-------|\n| Element-wise | 256 | High occupancy |\n| Row reduction | 512 | Full reduction per row |\n| Tiled/attention | 256 | Balance shared mem |\n\n---\n\n## A100-Specific Features\n\n### Async Memory Copy (sm_80)\n\nA100 introduced `cp.async` for overlapping compute and memory. Use this in custom kernels for prefetching:\n\n```cuda\n#include <cuda_pipeline.h>\n\n#if __CUDA_ARCH__ >= 800\n// Async copy from global to shared (A100+); bytes must be 4, 8, or 16\n__pipeline_memcpy_async(smem_ptr, global_ptr, bytes);\n__pipeline_commit();\n__pipeline_wait_prior(0);\n#endif\n```\n\n### TF32 Mode (A100 specific)\n\nEnables FP32-range with FP16-like throughput for GEMM. 
Enable in Python:\n\n```python\nimport torch\n\n# Enable TF32 for matmuls (A100+)\ntorch.backends.cuda.matmul.allow_tf32 = True\ntorch.backends.cudnn.allow_tf32 = True\n```\n\nWith these flags set, cuBLAS and cuDNN use TF32 tensor cores for FP32 GEMMs and convolutions; no kernel changes are needed.\n\n### Structural Sparsity (2:4)\n\nA100 tensor cores support 50% structured sparsity:\n\n```python\nfrom torch.sparse import to_sparse_semi_structured\n# dense_weight: fp16/bf16 CUDA tensor already pruned to a 2:4 pattern\nsparse_weight = to_sparse_semi_structured(dense_weight)\n# Up to ~2x GEMM speedup for matmul with sparse weight\n```\n\n---\n\n## JIT Compilation for A100\n\n```python\nreturn load_jit(\n    \"my_kernel\",\n    *args,\n    cuda_files=[\"diffusion/my_kernel.cuh\"],\n    cuda_wrappers=[(\"my_kernel\", f\"my_kernel<{args}>\")],\n    extra_cuda_cflags=[\n        \"-O3\",\n        \"--use_fast_math\",\n        \"-arch=sm_80\",   # A100 only; omit for multi-arch\n    ],\n)\n```\n\n**Multi-arch (A100 + H100):**\n\n```python\nextra_cuda_cflags=[\n    \"-O3\", \"--use_fast_math\",\n    \"-gencode=arch=compute_80,code=sm_80\",   # A100\n    \"-gencode=arch=compute_90,code=sm_90\",   # H100\n]\n```\n\nRuntime arch guard (in Python wrapper):\n\n```python\nimport torch\n\ncap = torch.cuda.get_device_capability()\nif cap < (8, 0):\n    raise RuntimeError(f\"This kernel requires sm_80 (A100) or later, got sm_{cap[0]}{cap[1]}\")\n```\n\n---\n\n## H100 → A100 Migration Checklist\n\nWhen porting an H100-optimized kernel to A100:\n\n| Item | H100 | A100 | Change Required |\n|------|------|------|-----------------|\n| Shared memory | 192 KB | 164 KB/SM (163 KB per block) | Reduce `cudaFuncSetAttribute` size |\n| Grid sizing | ×132 SMs | ×108 SMs | `get_sm_count()` handles automatically |\n| TMA bulk copy | Available | **Not available** | Remove `cp.async.bulk`; use standard `__pipeline_memcpy_async` |\n| FP8 | Available | **Not available** | Fall back to FP16/BF16 |\n| PDL | Supported | **Not available** (SM90+ only) | Launch with `.enable_pdl(false)` on sm_80 |\n| Warp shuffles | Same | Same | No changes |\n| `AlignedVector` | Same | Same | No changes |\n\n**Conditional 
compilation:**\n\n```cuda\n#if __CUDA_ARCH__ >= 900\n    // H100-only: TMA, FP8, thread block clusters\n    #define USE_TMA 1\n#elif __CUDA_ARCH__ >= 800\n    // A100: cp.async, TF32, 2:4 sparsity\n    #define USE_ASYNC_COPY 1\n#endif\n```\n\n---\n\n## Precision Notes\n\n| Type | Available on A100 | Notes |\n|------|-------------------|-------|\n| FP16 | Yes | Good, watch overflow in attention |\n| BF16 | Yes | Preferred for training and inference |\n| TF32 | Yes (A100 specific) | Auto for FP32 GEMMs |\n| FP8 | **No** | H100 only |\n\n---\n\n## Performance Profiling\n\n### NVIDIA Nsight Systems (nsys)\n\n```bash\nnsys profile -o a100_profile python scripts/bench_diffusion_rmsnorm.py\n\n# Key metrics to watch:\n# - Kernel duration\n# - Memory transfer time\n# - GPU idle time\n# - Stream utilization\n```\n\n### NVIDIA Nsight Compute (ncu)\n\n```bash\n# Full metrics\nncu --set full -o a100_metrics.ncu-rep \\\n  python scripts/bench_diffusion_rmsnorm.py\n\n# Specific metrics for bandwidth / occupancy checks\nncu --metrics sm__throughput.avg.pct_of_peak_sustained_elapsed,\\\ndram__throughput.avg.pct_of_peak_sustained_elapsed \\\n  python scripts/bench_diffusion_rmsnorm.py\n\n# Key metrics for A100 diffusion kernels:\n# - Achieved occupancy        (sm__warps_active.avg.pct_of_peak_sustained_active)\n# - Memory throughput         (dram__throughput.avg.pct_of_peak_sustained_elapsed)\n#   → Target: 30–40% of 2.0 TB/s (600–800 GB/s) for vectorized kernels\n# - Compute throughput        (sm__throughput.avg.pct_of_peak_sustained_elapsed)\n# - Warp stall reasons        (smsp__warp_issue_stalled_*.avg.pct_of_peak_sustained_active)\n# - Kernel time               (gpu__time_duration.avg)\n```\n\n### Common A100 Performance Issues\n\n1. **Memory bound below target**: `dram__throughput` < 30%\n   - Fix: Use `AlignedVector<bf16_t, 8>` (128-bit vector loads)\n\n2. **Low occupancy**: Grid too small for 108 SMs\n   - Fix: Use `runtime::get_sm_count()` persistent kernel pattern\n\n3. 
**No TF32 for FP32 GEMMs**: `torch.backends.cuda.matmul.allow_tf32` not set\n   - Fix: `torch.backends.cuda.matmul.allow_tf32 = True`\n\n---\n\n## Best Practices Summary (A100)\n\n1. **Bandwidth**: Even more critical than H100 — profile with `ncu` first\n2. **Vectorization**: `AlignedVector<bf16_t, 8>` gives ~3x over scalar\n3. **TF32**: Enable for any FP32 matmul workload that tolerates a 10-bit mantissa\n4. **Shared memory**: Cap dynamic shared memory at 163 KB per block (164 KB per SM); opt in with `cudaFuncSetAttribute`\n5. **Grid sizing**: Multiples of 108 SMs via `runtime::get_sm_count`\n6. **cp.async**: Use for prefetching in tiled kernels\n7. **Multi-arch**: Build for both `sm_80` and `sm_90` to support both GPUs\n8. **Same abstractions**: `AlignedVector`, `TensorMatcher`, `LaunchKernel` work identically\n\n## Reference Benchmark Results (A100 80GB, BF16)\n\n| Kernel | Shape | A100 (ms) | H100 (ms) | H100 Speedup |\n|--------|-------|-----------|-----------|--------------|\n| RMSNorm | [2, 1024, 2048] | ~0.08 | 0.054 | 1.5x |\n| GEGLU | [2, 1024, 4096] | ~0.05 | 0.030 | 1.7x |\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/h100-optimization-guide.md",
    "content": "# H100 GPU Optimization Guide — SGLang Diffusion JIT Kernels\n\nDeep dive into H100-specific optimizations for diffusion model CUDA kernels, written for SGLang's JIT kernel system.\n\n> **Adapted from**: [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels)\n\n---\n\n## H100 Hopper Architecture Overview\n\n| Component | Specification | Optimization Implication |\n|-----------|---------------|--------------------------|\n| Compute Capability | sm_90 | Use `extra_cuda_cflags=[\"-arch=sm_90\"]` in `load_jit` |\n| SMs | 132 | Grid: aim for multiples of 132 |\n| Shared Memory | 192 KB/SM | Configurable: 96/144/192 KB |\n| L2 Cache | 50 MB | Tile K,V of attention to fit in L2 |\n| Memory Bandwidth | 3.35 TB/s | BF16 vectorized: achieves ~38% (~1.27 TB/s) |\n| Max Threads/SM | 2048 | Max 16 blocks of 128 threads per SM |\n| Warp Size | 32 | All reductions use `warp::reduce_sum` |\n| Registers | 64K 32-bit/SM | 255 per thread max |\n\n### New Hopper Features (sm_90+)\n\n1. **Thread Block Clusters** — groups cooperating via Distributed Shared Memory\n2. **TMA (Tensor Memory Accelerator)** — hardware-accelerated bulk copies\n3. **FP8 support** — native 8-bit floating point in tensor cores\n4. 
**PDL (Programmatic Dependent Launch)** — enable with `.enable_pdl(true)` in `LaunchKernel`\n\nGate sm_90+ features with a runtime check before calling `load_jit`:\n\n```python\nimport torch\n\nif torch.cuda.get_device_capability()[0] < 9:\n    raise RuntimeError(\"This kernel requires H100 (sm_90+)\")\n```\n\n---\n\n## Memory Hierarchy Optimization\n\n### Coalesced Global Memory Access\n\n```cpp\n// GOOD: consecutive threads read consecutive addresses → one contiguous span per warp\n// (64 bytes for fp16, 128 bytes for fp32)\nuint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;\nfp16_t val = src[idx];\n\n// BAD: strided access → multiple transactions, lower effective bandwidth\nuint32_t strided_idx = threadIdx.x * stride;  // avoid stride > 1\nfp16_t strided_val = src[strided_idx];\n```\n\n**Transaction sizes**: 32 bytes minimum, 128 bytes optimal (full warp, FP32).\n\n### Vectorized Memory Access with `AlignedVector`\n\nSGLang's `AlignedVector<T, N>` provides 128-bit (16-byte) vector loads. Always use this instead of raw pointer reinterprets.\n\n```cpp\n#include <sgl_kernel/vec.cuh>\n\n// 16 bytes per load: 8×bf16_t, 8×fp16_t, or 4×fp32_t\nconstexpr int kVecN = 16 / sizeof(T);\nusing vec_t = device::AlignedVector<T, kVecN>;\n\n// Load\nvec_t v;\nv.load(src, vi);           // loads src[vi * kVecN .. vi * kVecN + kVecN - 1]\n\n// Process\n#pragma unroll\nfor (int i = 0; i < kVecN; ++i) {\n    float val = static_cast<float>(v[i]);\n    // ... 
compute ...\n    v[i] = static_cast<T>(result);\n}\n\n// Store\nv.store(dst, vi);\n```\n\n**RMSNorm benchmark (H100 80GB, BF16):**\n\n| Implementation | Time (ms) | Speedup |\n|:---|:---:|:---:|\n| Scalar loads | 0.065 | 1.00x |\n| `AlignedVector<bf16_t, 8>` | 0.019 | **3.37x** |\n\nBandwidth achieved: **~38% of 3.35 TB/s** = 1.27 TB/s.\n\n### L2 Cache Utilization (50 MB)\n\nFor attention, tile K and V so they stay in L2 while Q iterates:\n\n```\nBLOCK_SIZE_M = 128  (Q block)\nBLOCK_SIZE_N = 64   (K,V block)\nWith head_dim=64: tile = 128×64×2 = 16 KB (FP16), multiple tiles fit in L2\n```\n\n### Shared Memory Configuration\n\nRequest max shared memory for attention kernels:\n\n```cpp\n// In launcher (after selecting kernel function pointer):\ncudaFuncSetAttribute(\n    your_kernel,\n    cudaFuncAttributeMaxDynamicSharedMemorySize,\n    192 * 1024  // 192 KB max on H100\n);\n```\n\nShared memory has 32 banks (4 bytes/bank). Avoid conflicts with padding:\n\n```cpp\n__shared__ float data[32][33];  // 33 instead of 32 → no bank conflict\n```\n\n---\n\n## Warp & CTA Reductions (SGLang Abstractions)\n\nUse `sgl_kernel/warp.cuh` and `sgl_kernel/cta.cuh` — never raw `__shfl_xor_sync`.\n\n```cpp\n#include <sgl_kernel/warp.cuh>\n#include <sgl_kernel/cta.cuh>\n\n// Warp-level sum (uses __shfl_xor_sync internally)\nfloat result = device::warp::reduce_sum<float>(partial);\n\n// Warp-level max\nfloat mx = device::warp::reduce_max<float>(val);\n\n// CTA-wide max via shared memory\n__shared__ float smem[32];\ndevice::cta::reduce_max<float>(val, smem, -1e38f);\n// smem[0] holds the result after __syncthreads()\n```\n\n**Block reduction pattern for RMSNorm:**\n\n```cpp\n// 1. Warp reduction\nsum_sq = device::warp::reduce_sum<float>(sum_sq);\n\n// 2. Write warp leaders to smem\n__shared__ float smem_r[32];\nif (threadIdx.x % 32 == 0) smem_r[threadIdx.x / 32] = sum_sq;\n__syncthreads();\n\n// 3. 
Final warp reduction over warp leaders\nif (threadIdx.x < 32) {\n    sum_sq = (threadIdx.x < blockDim.x / 32) ? smem_r[threadIdx.x] : 0.f;\n    sum_sq = device::warp::reduce_sum<float>(sum_sq);\n}\n__syncthreads();\n// The block-wide sum is valid only in warp 0; broadcast it via shared memory if other warps need it\n```\n\n---\n\n## Occupancy Tuning\n\n```\nOccupancy = Active Warps per SM / Max Warps per SM (64)\n\nLimiting factors on H100 (each bounds resident blocks per SM):\n  1. Registers: 65536 / (threads_per_block × regs_per_thread)\n  2. Shared Memory: 192 KB / smem_per_block\n  3. Threads: 2048 / threads_per_block\n```\n\n**Recommended block sizes:**\n\n| Kernel Type | Threads/Block | Warps | Reasoning |\n|-------------|---------------|-------|-----------|\n| Element-wise (RoPE, GEGLU) | 256 | 8 | High occupancy, simple |\n| Row reduction (RMSNorm, LayerNorm) | 256–512 | 8–16 | Enough threads for full reduction |\n| Tiled (attention) | 256 | 8 | Balance shared mem and registers |\n\n**Persistent kernel pattern** (cap grid to SM × occupancy):\n\n```cpp\n#include <sgl_kernel/runtime.cuh>\n\nstatic const uint32_t max_occ = host::runtime::get_blocks_per_sm(kernel, kBlockSize);\nstatic const uint32_t num_sm  = host::runtime::get_sm_count(device.unwrap().device_id);\nconst uint32_t num_blocks = std::min(num_sm * max_occ, host::div_ceil(n, kBlockSize));\nhost::LaunchKernel(num_blocks, kBlockSize, device.unwrap())(kernel, params);\n```\n\n---\n\n## Precision and Numerical Stability\n\n| Type | Exponent Bits | Mantissa Bits | Range | Use Case |\n|------|--------------|---------------|-------|----------|\n| FP16 | 5 | 10 | ±65504 | Inference; attention score overflow risk |\n| BF16 | 8 | 7 | ±3.39×10³⁸ | Training/inference preferred; safer for attn |\n| FP32 | 8 | 23 | ±3.40×10³⁸ | Accumulation only |\n\n**Mixed precision pattern** (always accumulate in FP32):\n\n```cpp\n// Input via AlignedVector\nvec_t v;\nv.load(src, vi);\nfloat acc = 0.f;\n#pragma unroll\nfor (int i = 0; i < kVecN; ++i) {\n    float val = static_cast<float>(v[i]);  // promote to FP32\n    acc += val * val;\n}\n// Output (in a second loop over i, once fp32_result has been computed)\nv[i] = 
static_cast<T>(fp32_result);       // demote back\n```\n\n---\n\n## Diffusion-Specific Patterns\n\n### DiT Block Operators\n\n| Operator | Pattern | Key Constraint |\n|----------|---------|----------------|\n| **RMSNorm** | 2-pass row reduction | weight may be `None` |\n| **AdaLN** | `norm(x) * (1 + scale) + shift` | fuse norm+scale+shift |\n| **RoPE 3D** | `[B, t*h*w, heads, head_dim]` | layout: `seq = t*h*w` |\n| **GEGLU** | `gelu(gate) * value`, input `[B,L,2H]` | don't use for LTX-Video (uses GELU) |\n| **SiLU gate** | `x * sigmoid(x)` | fuse with MLP linear |\n\n### Online Softmax (for custom attention)\n\n```cuda\n// Numerically stable without materializing the full [seq×seq] score matrix\nfloat row_max = -INFINITY, row_sum = 0.f;   // out_acc starts at 0\nfor each K block:\n    compute local_scores\n    new_max = max(row_max, max(local_scores))\n    rescale = exp(row_max - new_max)         // 0 on the first block\n    p = exp(local_scores - new_max)          // unnormalized probabilities\n    row_sum = row_sum * rescale + sum(p)\n    out_acc = out_acc * rescale + p @ V_block\n    row_max = new_max\nout = out_acc / row_sum                      // normalize once at the end\n```\n\n---\n\n## Profiling and Debugging\n\n### NVIDIA Nsight Systems (nsys)\n\nSystem-wide profiling to see kernel durations, memory transfers, and GPU idle time:\n\n```bash\nnsys profile -o profile_report python scripts/bench_diffusion_rmsnorm.py\n\n# Key metrics to watch:\n# - Kernel duration\n# - Memory transfer time\n# - GPU idle time\n# - Stream utilization\n```\n\nFor end-to-end denoise profiling via `sglang generate`, see `diffusion-benchmark-and-profile.md` (Level 2: nsys + gputrc2graph.py).\n\n### NVIDIA Nsight Compute (ncu)\n\nDetailed per-kernel analysis for tuning individual JIT CUDA kernels:\n\n```bash\n# Full metrics — use when you need everything (slow)\nncu --set full -o metrics.ncu-rep \\\n  python scripts/bench_diffusion_rmsnorm.py\n\n# Specific metrics — use for targeted bandwidth / occupancy checks\nncu --metrics sm__throughput.avg.pct_of_peak_sustained_elapsed,\\\ndram__throughput.avg.pct_of_peak_sustained_elapsed \\\n  python 
scripts/bench_diffusion_rmsnorm.py\n\n# Key metrics for diffusion JIT kernels:\n# - Achieved occupancy        (sm__warps_active.avg.pct_of_peak_sustained_active)\n# - Memory throughput         (dram__throughput.avg.pct_of_peak_sustained_elapsed)\n# - Compute throughput        (sm__throughput.avg.pct_of_peak_sustained_elapsed)\n# - Warp stall reasons        (smsp__warp_issue_stalled_*.avg.pct_of_peak_sustained_active)\n# - L1 cache hit rate         (l1tex__t_sector_hit_rate.pct)\n```\n\n### Common Performance Issues\n\n1. **Low occupancy**: Too many registers or shared memory per block\n   - Check: `--ptxas-options=-v` in `extra_cuda_cflags` to see register count\n   - Fix: Cap registers per thread with `--maxrregcount=N`; use a smaller block size\n\n2. **Memory bound, low bandwidth**: Achieved DRAM bandwidth < 30% of peak (3.35 TB/s on H100 SXM)\n   - Check: `dram__throughput.avg.pct_of_peak_sustained_elapsed`\n   - Fix: Switch to `AlignedVector<T, 16/sizeof(T)>` for 128-bit vector loads\n\n3. **Shared memory bank conflicts**: `l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum` or `l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum` is high\n   - Fix: Add padding — `__shared__ float data[32][33]`\n\n4. **Warp divergence**: Conditional branches splitting warps\n   - Check: `smsp__warp_issue_stalled_branch.avg.pct_of_peak_sustained_active`\n   - Fix: Restructure so elements with identical branches are in the same warp\n\n5. **Too many small kernels**: High kernel launch overhead\n   - Fix: Fuse operations (e.g., norm + scale + shift → AdaLN in one kernel)\n\n---\n\n## JIT Compilation Notes\n\nSGLang's JIT compiles kernels on first use via `load_jit`. 
For H100-specific flags:\n\n```python\nreturn load_jit(\n    \"my_kernel\",\n    *args,\n    cuda_files=[\"diffusion/my_kernel.cuh\"],\n    cuda_wrappers=[(\"my_kernel\", f\"my_kernel<{args}>\")],\n    extra_cuda_cflags=[\n        \"-O3\",\n        \"--use_fast_math\",\n        \"-arch=sm_90\",          # H100 only; omit for multi-arch\n        \"--ptxas-options=-v\",   # Remove after tuning\n    ],\n)\n```\n\nFor multi-arch (H100 + A100):\n\n```python\nextra_cuda_cflags=[\n    \"-O3\",\n    \"--use_fast_math\",\n    \"-gencode=arch=compute_80,code=sm_80\",   # A100\n    \"-gencode=arch=compute_90,code=sm_90\",   # H100\n]\n```\n\n---\n\n## Best Practices Summary\n\n1. **Memory access**: Coalesce global loads and stores; align to 128-byte boundaries\n2. **Vectorization**: Use `AlignedVector<T, 16/sizeof(T)>` for all element-wise loads/stores\n3. **Reductions**: Use `warp::reduce_sum/max`, then shared memory pattern above\n4. **Precision**: BF16 for I/O, FP32 for accumulation; use `static_cast<float>`\n5. **Block size**: 256 threads default; 512 for reductions; tune with `runtime::get_blocks_per_sm`\n6. **Grid sizing**: Multiples of the SM count (132 on H100 SXM); use the persistent kernel pattern for large N\n7. **Shared memory**: Add padding (`[32][33]`) to avoid bank conflicts\n8. **Profile**: Run `ncu` before claiming a speedup; check dram throughput %\n9. **Fuse**: Combine norm + scale + shift into a single pass to reduce memory traffic\n10. **Abstractions**: Always use `TensorMatcher`, `AlignedVector`, `LaunchKernel` — never raw CUDA\n\n## Reference Benchmark Results (H100 80GB, BF16)\n\n| Kernel | Shape | Time (ms) |\n|--------|-------|-----------|\n| RMSNorm | [2, 1024, 2048] | 0.054 |\n| GEGLU | [2, 1024, 4096] → [2, 1024, 2048] | 0.030 |\n| RoPE 3D | [2, 480, 8, 64] | 1.670 |\n| RMSNorm vectorized | [1, 1024, 2048] | 0.019 |\n| RMSNorm vectorized | [4, 4096, 3072] | 0.157 |\n\n> See `kernel-templates.md` for copy-paste ready sglang JIT kernel implementations.\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/kernel-templates.md",
    "content": "# CUDA Kernel Templates — SGLang Diffusion JIT Style\n\nCopy-paste ready templates for JIT CUDA kernels in `python/sglang/jit_kernel/csrc/diffusion/`.\nAll templates use SGLang's internal abstractions; no raw CUDA headers needed.\n\n> **Adapted from**: [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels)\n\n---\n\n## Prerequisite: Standard Includes\n\nEvery kernel file in `csrc/diffusion/` starts with:\n\n```cpp\n#include <sgl_kernel/tensor.h>    // TensorMatcher, SymbolicSize, SymbolicDevice\n#include <sgl_kernel/type.cuh>    // fp16_t, bf16_t, fp32_t, dtype_trait, packed_t\n#include <sgl_kernel/utils.h>     // RuntimeCheck, Panic, div_ceil\n#include <sgl_kernel/utils.cuh>   // LaunchKernel, SGL_DEVICE, type aliases\n#include <sgl_kernel/vec.cuh>     // AlignedVector<T, N>\n#include <sgl_kernel/warp.cuh>    // warp::reduce_sum, warp::reduce_max\n#include <sgl_kernel/math.cuh>    // device::math::rsqrt, sqrt, ...\n#include <sgl_kernel/tile.cuh>    // tile::Memory (strided access pattern)\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/container/tensor.h>\n```\n\n**Key type aliases** (from `utils.cuh`):\n- `fp16_t` = `__half`, `fp16x2_t` = `__half2`\n- `bf16_t` = `__nv_bfloat16`, `bf16x2_t` = `__nv_bfloat162`\n- `fp32_t` = `float`, `fp32x2_t` = `float2`\n- `SGL_DEVICE` = `__forceinline__ __device__`\n\n---\n\n## Template 1: Element-wise Operation\n\nUse for ops that process elements independently: RoPE, SiLU, GEGLU, scale+bias.\n\n### `.cuh` file: `csrc/diffusion/silu_gate.cuh`\n\n```cpp\n#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/type.cuh>\n#include <sgl_kernel/utils.h>\n#include <sgl_kernel/utils.cuh>\n#include <sgl_kernel/vec.cuh>\n#include <sgl_kernel/math.cuh>\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/container/tensor.h>\n\nnamespace {\n\n// SiLU gate: out[i] = x[i] * sigmoid(x[i])\n// Input layout: [B, L, hidden]\ntemplate <typename T, int kVecN>\n__global__ void 
silu_gate_kernel(\n    T* __restrict__ dst,\n    const T* __restrict__ src,\n    uint32_t n_vecs,\n    uint32_t n_remainder,\n    uint32_t n_total)\n{\n    using vec_t = device::AlignedVector<T, kVecN>;\n\n    const uint32_t stride = blockDim.x * gridDim.x;\n\n    // --- vectorized body ---\n    for (uint32_t vi = blockIdx.x * blockDim.x + threadIdx.x; vi < n_vecs; vi += stride) {\n        vec_t v;\n        v.load(src, vi);\n        #pragma unroll\n        for (int i = 0; i < kVecN; ++i) {\n            float val = static_cast<float>(v[i]);\n            float sig = 1.f / (1.f + device::math::exp<float>(-val));\n            v[i] = static_cast<T>(val * sig);\n        }\n        v.store(dst, vi);\n    }\n\n    // --- scalar tail (for sizes not divisible by kVecN) ---\n    const uint32_t base = n_vecs * kVecN;\n    for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n_remainder; i += stride) {\n        float val = static_cast<float>(src[base + i]);\n        float sig = 1.f / (1.f + device::math::exp<float>(-val));\n        dst[base + i] = static_cast<T>(val * sig);\n    }\n}\n\ntemplate <typename T>\nvoid silu_gate(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) {\n    using namespace host;\n\n    SymbolicSize N{\"num_elements\"};\n    SymbolicDevice device;\n    device.set_options<kDLCUDA>();\n\n    TensorMatcher({N})\n        .with_dtype<T>()\n        .with_device(device)\n        .verify(dst)\n        .verify(src);\n\n    const uint32_t n      = static_cast<uint32_t>(N.unwrap());\n    const DLDevice dev    = device.unwrap();\n    RuntimeCheck(n > 0, \"silu_gate: num_elements must be > 0\");\n\n    constexpr int kVecN      = 16 / sizeof(T);   // 128-bit vector load\n    const uint32_t n_vecs    = n / kVecN;\n    const uint32_t n_rem     = n % kVecN;\n\n    constexpr uint32_t kBlock = 256;\n    const uint32_t grid       = div_ceil(std::max(n_vecs, n_rem), kBlock);\n\n    LaunchKernel(grid, kBlock, dev)(\n        silu_gate_kernel<T, kVecN>,\n        
static_cast<T*>(dst.data_ptr()),\n        static_cast<const T*>(src.data_ptr()),\n        n_vecs, n_rem, n);\n}\n\n}  // namespace\n```\n\n### Python wrapper: `diffusion/silu_gate.py`\n\n```python\nfrom __future__ import annotations\nimport torch\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\n@cache_once\ndef _jit_silu_gate_module(dtype: torch.dtype):\n    args = make_cpp_args(dtype)\n    return load_jit(\n        \"diffusion_silu_gate\",\n        *args,\n        cuda_files=[\"diffusion/silu_gate.cuh\"],\n        cuda_wrappers=[(\"silu_gate\", f\"silu_gate<{args}>\")],\n        extra_cuda_cflags=[\"-O3\", \"--use_fast_math\"],\n    )\n\ndef diffusion_silu_gate(src: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor:\n    assert src.is_cuda and src.dtype in (torch.float16, torch.bfloat16, torch.float32)\n    src = src.contiguous()  # the kernel indexes a dense buffer\n    if out is None:\n        out = torch.empty_like(src)\n    module = _jit_silu_gate_module(src.dtype)\n    # The launcher matches 1-D tensors (TensorMatcher({N})), so pass flat views\n    module.silu_gate(out.view(-1), src.view(-1))\n    return out\n```\n\n---\n\n## Template 2: Row-wise Reduction (RMSNorm / LayerNorm)\n\nUse for ops that reduce across the last dimension of each row.\n\n### `.cuh` file: `csrc/diffusion/rmsnorm.cuh`\n\n```cpp\n#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/type.cuh>\n#include <sgl_kernel/utils.h>\n#include <sgl_kernel/utils.cuh>\n#include <sgl_kernel/vec.cuh>\n#include <sgl_kernel/warp.cuh>\n#include <sgl_kernel/math.cuh>\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/container/tensor.h>\n\nnamespace {\n\n// RMSNorm: y = x / rms(x) * weight\n// One block per row; vectorized loads/stores; warp + shared-mem reduction\ntemplate <typename T, int kVecN>\n__global__ void rmsnorm_kernel(\n    T* __restrict__ dst,\n    const T* __restrict__ src,\n    const T* __restrict__ weight,   // nullptr if no affine weight\n    uint32_t hidden,\n    uint32_t n_vecs,\n    float eps)\n{\n    using vec_t = device::AlignedVector<T, kVecN>;\n\n    const uint32_t row     = blockIdx.x;\n    const T* row_src 
      = src + row * hidden;\n    T*       row_dst       = dst + row * hidden;\n\n    // Pass 1: sum of squares\n    float sum_sq = 0.f;\n    for (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) {\n        vec_t v;\n        v.load(row_src, vi);\n        #pragma unroll\n        for (int i = 0; i < kVecN; ++i) {\n            float val = static_cast<float>(v[i]);\n            sum_sq += val * val;\n        }\n    }\n\n    // Warp + block reduction\n    sum_sq = device::warp::reduce_sum<float>(sum_sq);\n    __shared__ float smem[32];\n    if (threadIdx.x % 32 == 0) smem[threadIdx.x / 32] = sum_sq;\n    __syncthreads();\n    if (threadIdx.x < 32) {\n        sum_sq = (threadIdx.x < blockDim.x / 32) ? smem[threadIdx.x] : 0.f;\n        sum_sq = device::warp::reduce_sum<float>(sum_sq);\n        if (threadIdx.x == 0) smem[0] = sum_sq;  // publish the block-wide total\n    }\n    __syncthreads();\n    sum_sq = smem[0];  // broadcast: every thread now holds the full row sum\n\n    const float rms_inv = device::math::rsqrt<float>(sum_sq / static_cast<float>(hidden) + eps);\n\n    // Pass 2: normalize + optional weight\n    for (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) {\n        vec_t v_in, v_out;\n        v_in.load(row_src, vi);\n        if (weight != nullptr) {\n            vec_t v_w;\n            v_w.load(weight, vi);\n            #pragma unroll\n            for (int i = 0; i < kVecN; ++i)\n                v_out[i] = static_cast<T>(static_cast<float>(v_in[i]) * rms_inv\n                                         * static_cast<float>(v_w[i]));\n        } else {\n            #pragma unroll\n            for (int i = 0; i < kVecN; ++i)\n                v_out[i] = static_cast<T>(static_cast<float>(v_in[i]) * rms_inv);\n        }\n        v_out.store(row_dst, vi);\n    }\n}\n\ntemplate <typename T>\nvoid rmsnorm(\n    tvm::ffi::TensorView dst,\n    tvm::ffi::TensorView src,\n    tvm::ffi::TensorView weight,   // data_ptr == nullptr → no weight\n    float eps)\n{\n    using namespace host;\n\n    SymbolicSize B{\"batch_tokens\"}, H{\"hidden_size\"};\n    SymbolicDevice device;\n    device.set_options<kDLCUDA>();\n\n 
   TensorMatcher({B, H})\n        .with_dtype<T>()\n        .with_device(device)\n        .verify(dst)\n        .verify(src);\n\n    const uint32_t num_rows = static_cast<uint32_t>(B.unwrap());\n    const uint32_t hidden   = static_cast<uint32_t>(H.unwrap());\n    const DLDevice dev      = device.unwrap();\n\n    constexpr int kVecN = 16 / sizeof(T);\n    RuntimeCheck(hidden % kVecN == 0,\n        \"rmsnorm: hidden_size (\", hidden, \") must be divisible by \", kVecN);\n    const uint32_t n_vecs = hidden / kVecN;\n\n    uint32_t threads = std::min(n_vecs, 512u);\n    threads = (threads + 31) / 32 * 32;\n\n    const T* w_ptr = (weight.data_ptr() != nullptr)\n        ? static_cast<const T*>(weight.data_ptr()) : nullptr;\n\n    LaunchKernel(num_rows, threads, dev)(\n        rmsnorm_kernel<T, kVecN>,\n        static_cast<T*>(dst.data_ptr()),\n        static_cast<const T*>(src.data_ptr()),\n        w_ptr, hidden, n_vecs, eps);\n}\n\n}  // namespace\n```\n\n---\n\n## Template 3: Fused Row-Reduction + Element-wise (AdaLN)\n\nCombines RMSNorm + AdaLN modulation into one pass: `y = norm(x) * (1 + scale) + shift`.\n\n### `.cuh` file: `csrc/diffusion/adaln.cuh`\n\n```cpp\n#include <sgl_kernel/tensor.h>\n#include <sgl_kernel/type.cuh>\n#include <sgl_kernel/utils.h>\n#include <sgl_kernel/utils.cuh>\n#include <sgl_kernel/vec.cuh>\n#include <sgl_kernel/warp.cuh>\n#include <sgl_kernel/math.cuh>\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/container/tensor.h>\n\nnamespace {\n\n// AdaLN: y = norm(x) * (1 + scale) + shift\n// scale, shift: [batch, hidden] (one per row)\ntemplate <typename T, int kVecN>\n__global__ void adaln_kernel(\n    T* __restrict__ dst,\n    const T* __restrict__ src,\n    const T* __restrict__ weight,\n    const T* __restrict__ scale,\n    const T* __restrict__ shift,\n    uint32_t hidden,\n    uint32_t n_vecs,\n    float eps)\n{\n    using vec_t = device::AlignedVector<T, kVecN>;\n\n    const uint32_t row     = blockIdx.x;\n    const T* row_src       = src 
  + row * hidden;\n    const T* row_scale     = scale + row * hidden;\n    const T* row_shift     = shift + row * hidden;\n    T*       row_dst       = dst   + row * hidden;\n\n    // Pass 1: compute RMS\n    float sum_sq = 0.f;\n    for (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) {\n        vec_t v;\n        v.load(row_src, vi);\n        #pragma unroll\n        for (int i = 0; i < kVecN; ++i) {\n            float val = static_cast<float>(v[i]);\n            sum_sq += val * val;\n        }\n    }\n    sum_sq = device::warp::reduce_sum<float>(sum_sq);\n    __shared__ float smem[32];\n    if (threadIdx.x % 32 == 0) smem[threadIdx.x / 32] = sum_sq;\n    __syncthreads();\n    if (threadIdx.x < 32) {\n        sum_sq = (threadIdx.x < blockDim.x / 32) ? smem[threadIdx.x] : 0.f;\n        sum_sq = device::warp::reduce_sum<float>(sum_sq);\n        if (threadIdx.x == 0) smem[0] = sum_sq;  // publish the block-wide total\n    }\n    __syncthreads();\n    sum_sq = smem[0];  // broadcast: every thread now holds the full row sum\n    const float rms_inv = device::math::rsqrt<float>(sum_sq / static_cast<float>(hidden) + eps);\n\n    // Pass 2: normalize + modulate\n    for (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) {\n        vec_t v_in, v_w, v_sc, v_sh, v_out;\n        v_in.load(row_src, vi);\n        v_w.load(weight,   vi);\n        v_sc.load(row_scale, vi);\n        v_sh.load(row_shift, vi);\n        #pragma unroll\n        for (int i = 0; i < kVecN; ++i) {\n            float x  = static_cast<float>(v_in[i]) * rms_inv * static_cast<float>(v_w[i]);\n            float sc = static_cast<float>(v_sc[i]);\n            float sh = static_cast<float>(v_sh[i]);\n            v_out[i] = static_cast<T>(x * (1.f + sc) + sh);\n        }\n        v_out.store(row_dst, vi);\n    }\n}\n\ntemplate <typename T>\nvoid adaln(\n    tvm::ffi::TensorView dst,\n    tvm::ffi::TensorView src,\n    tvm::ffi::TensorView weight,\n    tvm::ffi::TensorView scale,\n    tvm::ffi::TensorView shift,\n    float eps)\n{\n    using namespace host;\n\n    SymbolicSize B{\"batch_tokens\"}, H{\"hidden_size\"};\n    SymbolicDevice device;\n    
device.set_options<kDLCUDA>();\n\n    TensorMatcher({B, H})\n        .with_dtype<T>()\n        .with_device(device)\n        .verify(dst).verify(src).verify(scale).verify(shift);\n\n    // weight is a per-channel [hidden] vector: the kernel indexes it without a row offset\n    TensorMatcher({H})\n        .with_dtype<T>()\n        .with_device(device)\n        .verify(weight);\n\n    const uint32_t num_rows = static_cast<uint32_t>(B.unwrap());\n    const uint32_t hidden   = static_cast<uint32_t>(H.unwrap());\n    const DLDevice dev      = device.unwrap();\n\n    constexpr int kVecN = 16 / sizeof(T);\n    RuntimeCheck(hidden % kVecN == 0, \"adaln: hidden_size must be divisible by \", kVecN);\n    const uint32_t n_vecs = hidden / kVecN;\n\n    uint32_t threads = std::min(n_vecs, 512u);\n    threads = (threads + 31) / 32 * 32;\n\n    LaunchKernel(num_rows, threads, dev)(\n        adaln_kernel<T, kVecN>,\n        static_cast<T*>(dst.data_ptr()),\n        static_cast<const T*>(src.data_ptr()),\n        static_cast<const T*>(weight.data_ptr()),\n        static_cast<const T*>(scale.data_ptr()),\n        static_cast<const T*>(shift.data_ptr()),\n        hidden, n_vecs, eps);\n}\n\n}  // namespace\n```\n\n---\n\n## Template 4: Python Wrapper (generic pattern)\n\nFile location: `python/sglang/jit_kernel/diffusion/<op>.py`\n\n```python\nfrom __future__ import annotations\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\nif TYPE_CHECKING:\n    from tvm_ffi.module import Module\n\n\n@cache_once\ndef _jit_module(dtype: torch.dtype) -> Module:\n    \"\"\"Cache key: dtype (and any other template params you need).\"\"\"\n    args = make_cpp_args(dtype)\n    return load_jit(\n        \"diffusion_your_op\",           # unique build cache key\n        *args,\n        cuda_files=[\"diffusion/your_op.cuh\"],  # relative to csrc/\n        cuda_wrappers=[(\"your_op\", f\"your_op<{args}>\")],\n        extra_cuda_cflags=[\"-O3\", \"--use_fast_math\"],\n    )\n\n\ndef diffusion_your_op(\n    src: torch.Tensor,\n    out: torch.Tensor | None = None,\n) -> torch.Tensor:\n    \"\"\"\n    Your op 
description.\n\n    Supported dtypes: float16, bfloat16, float32.\n    \"\"\"\n    assert src.is_cuda, \"src must be a CUDA tensor\"\n    assert src.dtype in (torch.float16, torch.bfloat16, torch.float32), (\n        f\"Unsupported dtype {src.dtype}\"\n    )\n    if out is None:\n        out = torch.empty_like(src)\n\n    module = _jit_module(src.dtype)\n    module.your_op(out, src)\n    return out\n```\n\n**`make_cpp_args` conversion table:**\n\n| `torch.dtype` | C++ type |\n|---------------|----------|\n| `torch.float16` | `fp16_t` |\n| `torch.bfloat16` | `bf16_t` |\n| `torch.float32` | `fp32_t` |\n\n---\n\n## Template 5: Correctness Test\n\n```python\n# python/sglang/jit_kernel/tests/test_diffusion_<op>.py\nimport pytest\nimport torch\nfrom sglang.jit_kernel.diffusion.<op> import diffusion_<op>\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float32])\n@pytest.mark.parametrize(\"shape\", [(1, 2048), (4, 3072), (16, 4096)])\ndef test_<op>_correctness(dtype, shape):\n    src = torch.randn(*shape, dtype=dtype, device=\"cuda\")\n\n    out_jit = diffusion_<op>(src)\n    ref     = reference_<op>(src.float()).to(dtype)  # reference in fp32\n\n    tol = {\"rtol\": 1e-2, \"atol\": 1e-2} if dtype != torch.float32 else {\"rtol\": 1e-5, \"atol\": 1e-6}\n    torch.testing.assert_close(out_jit, ref, **tol)\n\n\ndef test_<op>_out_param():\n    src = torch.randn(1024, 2048, dtype=torch.bfloat16, device=\"cuda\")\n    out = torch.empty_like(src)\n    result = diffusion_<op>(src, out=out)\n    assert result is out\n\n\ndef test_<op>_cpu_error():\n    src = torch.randn(128, dtype=torch.float16)  # CPU tensor\n    with pytest.raises(AssertionError):\n        diffusion_<op>(src)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n```\n\n---\n\n## Template 6: Benchmark\n\n```python\n# python/sglang/jit_kernel/benchmark/bench_diffusion_<op>.py\nimport torch\nimport triton.testing\n\nfrom sglang.jit_kernel.benchmark.utils 
import DEFAULT_DEVICE, DEFAULT_DTYPE, run_benchmark\nfrom sglang.jit_kernel.diffusion.<op> import diffusion_<op>\n\nSHAPES = [(4096, 2048), (4096, 3072), (4096, 4096)]\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=[\"hidden\"],\n        x_vals=[s[1] for s in SHAPES],\n        line_arg=\"provider\",\n        line_vals=[\"jit_cuda\", \"torch\"],\n        line_names=[\"SGLang JIT CUDA\", \"PyTorch\"],\n        styles=[(\"blue\", \"-\"), (\"red\", \"--\")],\n        ylabel=\"us\",\n        plot_name=\"diffusion-<op>\",\n        args={},\n    )\n)\ndef benchmark(hidden: int, provider: str):\n    src = torch.randn(4096, hidden, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE)\n\n    if provider == \"jit_cuda\":\n        fn = lambda: diffusion_<op>(src)\n    else:\n        fn = lambda: reference_<op>(src)  # torch baseline\n\n    return run_benchmark(fn)\n\n\nif __name__ == \"__main__\":\n    benchmark.run(print_data=True)\n```\n\n---\n\n## Summary of New Files per Kernel\n\n```\npython/sglang/jit_kernel/csrc/diffusion/\n└── <op>.cuh                               # CUDA kernel + launcher\n\npython/sglang/jit_kernel/diffusion/\n└── <op>.py                                # Python wrapper (load_jit + cache_once)\n\npython/sglang/jit_kernel/tests/\n└── test_diffusion_<op>.py                 # correctness tests\n\npython/sglang/jit_kernel/benchmark/\n└── bench_diffusion_<op>.py                # triton.testing benchmark\n```\n\n> See `scripts/bench_diffusion_rmsnorm.py` and `scripts/bench_diffusion_denoise.py` for full runnable examples.\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/t4-optimization-guide.md",
    "content": "# T4 GPU Optimization Guide — SGLang Diffusion JIT Kernels\n\nT4 is a Turing architecture GPU (GCP n1+T4, AWS g4dn) commonly used for cloud inference.\nIts key constraint for diffusion kernels: **no BF16 support** — FP16 only.\n\n> **Adapted from**: [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels)\n\n---\n\n## T4 Turing Architecture Overview\n\n| Component | T4 | A100 | H100 |\n|-----------|-----|------|------|\n| Compute Capability | sm_75 | sm_80 | sm_90 |\n| SMs | 40 | 108 | 132 |\n| Shared Memory/SM | **64 KB** | 164 KB | 228 KB |\n| L2 Cache | 4 MB | 40 MB | 50 MB |\n| Memory Bandwidth | **320 GB/s** | 1.6–2.0 TB/s | 3.35 TB/s |\n| Memory | 16 GB GDDR6 | 40–80 GB HBM2e | 80 GB HBM3 |\n| Max Threads/SM | **1024** | 2048 | 2048 |\n| BF16 Support | **No** | Yes | Yes |\n\n### Critical T4 Constraints\n\n1. **No BFloat16** — must use FP16 everywhere\n2. **320 GB/s bandwidth** — ~10x lower than H100; vectorization is critical\n3. **16 GB memory** — limits model size; use offloading\n4. **64 KB shared memory/SM** — smaller attention tiles\n5. **Max 1024 threads/SM** — half of A100/H100; affects occupancy calculations\n\n---\n\n## No BF16: Always Use FP16\n\nThis is the most impactful constraint. **Never use `bf16_t` or `__nv_bfloat16` on T4.**\n\n**Python wrapper guard:**\n\n```python\nimport torch\nfrom sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args\n\n@cache_once\ndef _jit_rmsnorm_module(dtype: torch.dtype):\n    # T4 (sm_75) does not support BF16\n    cap = torch.cuda.get_device_capability()\n    if cap < (8, 0) and dtype == torch.bfloat16:\n        raise RuntimeError(\n            f\"T4 (sm_75) does not support BF16. Use torch.float16 instead. 
\"\n            f\"Got dtype={dtype}\"\n        )\n    args = make_cpp_args(dtype)\n    return load_jit(\n        \"diffusion_rmsnorm\",\n        *args,\n        cuda_files=[\"diffusion/rmsnorm.cuh\"],\n        cuda_wrappers=[(\"rmsnorm\", f\"rmsnorm<{args}>\")],\n    )\n```\n\n**Conditional type in kernel:**\n\n```cuda\n// __CUDA_ARCH__ is defined only during the device compilation pass, so use\n// DefaultHalf only inside device code, never in host-visible signatures\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n    // A100/H100: BF16 available\n    using DefaultHalf = bf16_t;\n#else\n    // T4/Turing: FP16 only\n    using DefaultHalf = fp16_t;\n#endif\n```\n\n**Runtime detection helper:**\n\n```python\ndef get_diffusion_dtype() -> torch.dtype:\n    \"\"\"Return the appropriate half-precision dtype for the current GPU.\"\"\"\n    cap = torch.cuda.get_device_capability()\n    if cap >= (8, 0):\n        return torch.bfloat16   # A100/H100: prefer BF16\n    else:\n        return torch.float16    # T4/older: FP16 only\n```\n\n---\n\n## Memory Access Optimization\n\nWith only 320 GB/s, **vectorization is more critical on T4 than on A100/H100**.\n\n### `AlignedVector` (same abstraction, FP16 only)\n\n```cpp\n#include <sgl_kernel/vec.cuh>\n\n// On T4, T must be fp16_t or fp32_t (NOT bf16_t)\nconstexpr int kVecN = 16 / sizeof(T);   // 8 for fp16, 4 for fp32\nusing vec_t = device::AlignedVector<T, kVecN>;\n```\n\n**Target bandwidth**: 40–50% of T4's 320 GB/s = 128–160 GB/s.\n\n### Increase Arithmetic Intensity\n\nWith low bandwidth, fusing ops saves more on T4 than on H100:\n\n```cpp\n// BAD on T4: separate passes → 2× memory traffic\noutput1[i] = input[i] * scale;       // pass 1\noutput2[i] = output1[i] + bias;      // pass 2\n\n// GOOD: fuse → single memory read, single write (the extra ReLU adds no memory traffic)\nfloat val = static_cast<float>(v[i]);\nval = val * scale + bias;\nval = device::math::max<float>(val, 0.f);  // ReLU\nv[i] = static_cast<T>(val);\n```\n\n### Expected T4 Performance\n\n| Kernel | T4 (ms) | A100 (ms) | H100 (ms) | T4 vs H100 |\n|--------|---------|-----------|-----------|------------|\n| RMSNorm [2, 1024, 2048] | ~0.5 | ~0.08 | 0.054 | ~9x slower 
|\n| GEGLU [2, 1024, 4096] | ~0.3 | ~0.05 | 0.030 | ~10x slower |\n\n---\n\n## Shared Memory Configuration\n\nT4 max: **64 KB/SM**. Use smaller tiles than on A100/H100.\n\n```cpp\n// T4: opt in to the 64 KB dynamic shared memory maximum (the default per-block limit is 48 KB)\ncudaFuncSetAttribute(\n    your_kernel,\n    cudaFuncAttributeMaxDynamicSharedMemorySize,\n    64 * 1024\n);\n```\n\n**Attention tile sizes for T4** (halved vs H100):\n\n```\nH100/A100: BLOCK_SIZE_M = 128, BLOCK_SIZE_N = 64\nT4:        BLOCK_SIZE_M =  64, BLOCK_SIZE_N = 32   ← reduced for 64 KB limit\n```\n\n---\n\n## Occupancy Tuning\n\nT4 max: **1024 threads/SM** (vs 2048 on A100/H100). This halves the number of resident blocks per SM for a given block size (e.g., at most 4 blocks of 256 threads on T4 vs 8 on A100/H100).\n\n**Block sizes for T4:**\n\n| Kernel Type | Threads/Block | Notes |\n|-------------|---------------|-------|\n| Element-wise | 256 | Same as H100 |\n| Row reduction | 256–512 | Avoid > 512 to fit multiple blocks/SM |\n| Tiled/attention | 128–256 | Small tiles due to 64 KB shared mem |\n\n**Grid sizing for T4 (40 SMs)** — `runtime::get_sm_count` handles this automatically:\n\n```cpp\n// get_sm_count() returns 40 on T4, 108 on A100, 132 on H100\nconst uint32_t num_sm = host::runtime::get_sm_count(device.unwrap().device_id);\n```\n\n---\n\n## Numerical Stability with FP16\n\nFP16 has a smaller dynamic range (±65504) vs BF16 (±3.39×10³⁸). 
Watch for overflow in attention:\n\n```cuda\n// Scale attention scores to prevent FP16 overflow\nfloat scale_factor = 1.0f / sqrtf(static_cast<float>(head_dim));\n// For very long sequences on T4, may need additional scaling:\n// if (score * scale_factor > 65000.f) { /* clamp */ }\n```\n\nAlways accumulate in FP32:\n\n```cpp\nfloat acc = 0.f;   // FP32 accumulation\nfor (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) {\n    vec_t v;\n    v.load(src, vi);\n    #pragma unroll\n    for (int i = 0; i < kVecN; ++i) {\n        float val = static_cast<float>(v[i]);   // fp16 → fp32\n        acc += val * val;\n    }\n}\n```\n\n---\n\n## Memory Management for 16 GB\n\nT4's 16 GB requires careful planning for large diffusion models.\n\n**sglang generate flags for T4:**\n\n```bash\n# Enable CPU offloading to fit within 16 GB:\n#   --dit-cpu-offload        moves DiT weights to CPU\n#   --width / --height       reduced resolution\n#   --num-inference-steps    fewer denoising steps\n# (comments cannot follow a trailing backslash, so they are kept above the command)\nsglang generate \\\n  --model-path=black-forest-labs/FLUX.1-dev \\\n  --dit-cpu-offload true \\\n  --text-encoder-cpu-offload true \\\n  --vae-cpu-offload true \\\n  --width=512 --height=512 \\\n  --num-inference-steps=20 \\\n  --seed=42\n```\n\n**Resolution recommendations for T4:**\n\n| Model | H100/A100 | T4 |\n|-------|-----------|-----|\n| FLUX.1-dev | 1024×1024 | 512×512 |\n| Wan2.2-TI2V-5B | 720P | 480P |\n| FLUX.2-dev | 1024×1024 | 512×512 |\n\n---\n\n## JIT Compilation for T4\n\n```python\nreturn load_jit(\n    \"my_kernel\",\n    *args,\n    cuda_files=[\"diffusion/my_kernel.cuh\"],\n    cuda_wrappers=[(\"my_kernel\", f\"my_kernel<{args}>\")],\n    extra_cuda_cflags=[\n        \"-O3\",\n        \"--use_fast_math\",\n        \"-arch=sm_75\",   # T4 only; omit for multi-arch\n    ],\n)\n```\n\n**Multi-arch (T4 + A100 + H100):**\n\n```python\nextra_cuda_cflags=[\n    \"-O3\", \"--use_fast_math\",\n    \"-gencode=arch=compute_75,code=sm_75\",   # T4\n    \"-gencode=arch=compute_80,code=sm_80\",   # A100\n    \"-gencode=arch=compute_90,code=sm_90\",   # 
H100\n]\n```\n\n---\n\n## H100/A100 → T4 Migration Checklist\n\n| Item | H100/A100 | T4 | Action |\n|------|-----------|-----|--------|\n| BF16 | Available | **Not available** | Replace `bf16_t` with `fp16_t`; guard in Python wrapper |\n| Shared memory | 164–228 KB | **64 KB** | Halve tile sizes |\n| Grid sizing | ×108/132 SMs | ×40 SMs | `get_sm_count()` auto-handles |\n| Max threads/SM | 2048 | **1024** | Don't exceed 512 threads/block |\n| Memory | 40–80 GB | **16 GB** | Enable CPU offloading |\n| cp.async | Available | **Not available** (`cp.async` requires sm_80+) | Remove async copy patterns; use plain vectorized loads |\n| `AlignedVector` | Same | Same | No changes |\n| `warp::reduce_sum` | Same | Same | No changes |\n\n---\n\n## Performance Profiling\n\n### NVIDIA Nsight Systems (nsys)\n\n```bash\nnsys profile -o t4_profile python scripts/bench_diffusion_rmsnorm.py\n\n# Key metrics to watch:\n# - Kernel duration\n# - Memory transfer time\n# - GPU idle time\n# - Stream utilization\n```\n\n### NVIDIA Nsight Compute (ncu)\n\n```bash\n# Full metrics\nncu --set full -o t4_metrics.ncu-rep \\\n  python scripts/bench_diffusion_rmsnorm.py\n\n# Specific metrics — T4 is memory-bound; focus on dram throughput\nncu --metrics sm__throughput.avg.pct_of_peak_sustained_elapsed,\\\ndram__throughput.avg.pct_of_peak_sustained_elapsed \\\n  python scripts/bench_diffusion_rmsnorm.py\n\n# Key metrics for T4 diffusion kernels:\n# - Memory throughput     (dram__throughput.avg.pct_of_peak_sustained_elapsed)\n#   → Target: 40–50% of 320 GB/s (128–160 GB/s) for vectorized kernels\n# - SM utilization        (sm__throughput.avg.pct_of_peak_sustained_elapsed)\n#   → Keep high; with only 40 SMs, each idle SM costs ~2.5% of the GPU\n# - Achieved occupancy    (sm__warps_active.avg.pct_of_peak_sustained_active)\n#   → Max 1024 threads/SM on T4 — block size ≤ 512 for decent occupancy\n# - Warp stall reasons    (smsp__warp_issue_stalled_*.avg.pct_of_peak_sustained_active)\n```\n\n### Common T4 Bottlenecks\n\n1. 
**Memory Bandwidth** — 320 GB/s is the primary limit; if `dram__throughput` < 40% → use `AlignedVector`\n2. **Limited Memory** — 16 GB; enable `--dit-cpu-offload`/`--vae-cpu-offload` as needed\n3. **No BF16** — guard in Python wrapper; FP16 overflow risk in long-sequence attention\n4. **Smaller tiles** — 64 KB shared memory; reduce `BLOCK_SIZE_M/N` vs H100\n\n---\n\n## Best Practices Summary (T4)\n\n1. **No BF16**: Guard in Python wrapper, raise clear error\n2. **Vectorization**: Even more critical at 320 GB/s — always use `AlignedVector`\n3. **Tile sizes**: 64 KB shared memory limit → halve BLOCK_SIZE vs H100\n4. **Block size**: Max 512 threads/block for decent occupancy (max 1024 threads/SM)\n5. **Grid sizing**: 40 SMs — `runtime::get_sm_count()` auto-handles\n6. **FP32 accumulation**: Always accumulate in FP32 to avoid FP16 overflow\n7. **Memory**: Plan for 16 GB; use `--dit-cpu-offload`/`--vae-cpu-offload` as needed\n8. **Fuse more**: Low bandwidth makes kernel fusion more impactful than on H100\n9. **Multi-arch build**: Always build for `sm_75,sm_80,sm_90` together\n\n## T4 Cloud Instance Quick Reference\n\n| Provider | Instance | Notes |\n|----------|----------|-------|\n| GCP | n1-standard-4 + T4 | Most common inference setup |\n| AWS | g4dn.xlarge | 1× T4, 16 GB |\n| AWS | g4dn.12xlarge | 4× T4, 64 GB total |\n| Azure | NC4as T4 v3 | 1× T4 |\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/troubleshooting.md",
    "content": "# Troubleshooting Guide — SGLang Diffusion JIT CUDA Kernels\n\nCommon issues and solutions when writing and integrating JIT CUDA kernels for SGLang Diffusion.\n\n> **Adapted from**: [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels)\n\n---\n\n## Build / Compile Issues\n\n### 1. JIT compilation fails: \"No such file or directory\"\n\n**Problem:** `load_jit` cannot find your `.cuh` file.\n\n```\nFileNotFoundError: .../jit_kernel/csrc/diffusion/your_op.cuh not found\n```\n\n**Fix:** Ensure the file is under `python/sglang/jit_kernel/csrc/diffusion/`. The path passed to `cuda_files` is relative to `csrc/`:\n\n```python\n# CORRECT — file lives at csrc/diffusion/your_op.cuh\nload_jit(..., cuda_files=[\"diffusion/your_op.cuh\"])\n# resolves to: python/sglang/jit_kernel/csrc/diffusion/your_op.cuh\n\n# ALSO CORRECT — absolute path (pathlib replaces the csrc/ prefix)\nload_jit(..., cuda_files=[\"/full/absolute/path/to/your_op.cuh\"])\n```\n\n### 2. Type conversion errors (FP16/BF16)\n\n**Problem:** Implicit FP16/BF16 conversion fails when the build defines `-D__CUDA_NO_HALF_CONVERSIONS__` and `-D__CUDA_NO_BFLOAT16_CONVERSIONS__`, as PyTorch extension builds do (alongside `-D__CUDA_NO_HALF_OPERATORS__`):\n\n```\nerror: no suitable conversion function from \"__half\" to \"float\" exists\n```\n\n**Fix:** SGLang's `static_cast<float>` works because `fp16_t` and `bf16_t` are typedef'd with proper conversion operators. Always use explicit casts:\n\n```cpp\n// CORRECT — explicit cast\nfloat val = static_cast<float>(v[i]);   // fp16_t / bf16_t → float\nv[i] = static_cast<T>(fp32_result);    // float → T\n\n// WRONG — implicit conversion (disabled by PyTorch build flags)\nfloat val = v[i];           // compile error\nv[i] = fp32_result;         // compile error\n```\n\nIf you need the raw intrinsics for packed types:\n\n```cpp\n// bf16x2_t → two floats\nbf16x2_t packed = ...;\nfloat v0 = __bfloat162float(packed.x);\nfloat v1 = __bfloat162float(packed.y);\n```\n\n### 3. 
Template instantiation explodes / slow first compile\n\n**Problem:** Many template argument combinations make the first JIT compile very slow.\n\n**Fix:** Reduce the number of template argument combinations. Move compile-time constants to runtime parameters when specializing on them does not noticeably affect performance:\n\n```cpp\n// Fewer template args = fewer instantiations\ntemplate <typename T>  // only dtype varies; block_size is a runtime argument\nvoid my_op(tvm::ffi::TensorView dst, tvm::ffi::TensorView src, int block_size);\n```\n\n### 4. SM check: kernel requires sm_90 but device is sm_80\n\n**Problem:** Kernel uses H100-only (sm_90) features but runs on an A100 (sm_80).\n\n**Fix:** Add a Python guard before calling `load_jit`:\n\n```python\nimport torch\n\ncap = torch.cuda.get_device_capability()\nif cap[0] < 9:\n    raise RuntimeError(\n        f\"This kernel requires H100 (sm_90+). \"\n        f\"Got compute capability {cap[0]}.{cap[1]}. \"\n        f\"Use the Triton fallback instead: diffusion_triton_<op>()\"\n    )\n```\n\n---\n\n## Performance Issues\n\n### 5. Kernel is slower than Triton / PyTorch baseline\n\n**Steps to diagnose:**\n\n1. Check dtype: are you using `bf16_t` on T4? (T4 is sm_75 and has no native BF16 arithmetic, so BF16 math silently falls back to slow emulation.)\n2. Check vectorization: is `hidden_size` divisible by `kVecN = 16/sizeof(T)` (8 for bf16, 4 for fp32)?\n3. Profile with `ncu` (`-o` writes a binary `.ncu-rep` report, not a CSV file):\n   ```bash\n   ncu --set full -o metrics.ncu-rep \\\n     python -c \"from sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm; ...\"\n   ```\n   Look at `dram__throughput.avg.pct_of_peak_sustained_elapsed`: if it is below 30%, check memory coalescing.\n\n4. Check occupancy: add `--ptxas-options=-v` to `extra_cuda_cflags` to see register usage.\n\n### 6. 
Shared memory bank conflicts\n\n**Problem:** `ncu` reports high `l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum` / `l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum`.\n\n**Fix:** Add padding to shared memory arrays:\n\n```cpp\n// Conflict: when each thread accesses one column element (data[threadIdx.x][c]),\n// the 32-float row stride maps every thread to the same bank\n__shared__ float data[32][32];\n\n// Fixed with padding: a 33-float row stride spreads a column across all 32 banks\n__shared__ float data[32][33];  // 33 instead of 32\n```\n\n### 7.
Low occupancy from too many registers\n\n**Problem:** `nvcc --ptxas-options=-v` shows high register count; occupancy < 25%.\n\n**Fix:** Add `--maxrregcount=N` to limit registers:\n\n```python\nextra_cuda_cflags=[\"-O3\", \"--use_fast_math\", \"--maxrregcount=64\"]\n```\n\nReduces registers per thread at the cost of possible register spilling to local memory.\n\n---\n\n## Integration Issues\n\n### 8. RMSNorm weight is None (`elementwise_affine=False`)\n\n**Problem:**\n```\nAttributeError: 'NoneType' object has no attribute 'data_ptr'\n```\n\n**Root Cause:** DiT transformer blocks often use `RMSNorm(dim, elementwise_affine=False)` — no learnable weight.\n\n**Fix in Python wrapper:** pass an empty tensor when weight is absent; the kernel launcher checks `data_ptr == nullptr`:\n\n```python\nw = weight if weight is not None else torch.empty(0, dtype=src.dtype, device=src.device)\nmodule.rmsnorm(out, src, w, eps)\n```\n\n**Fix in `.cuh` launcher:**\n\n```cpp\nconst T* w_ptr = (weight.data_ptr() != nullptr)\n    ? static_cast<const T*>(weight.data_ptr()) : nullptr;\n// ... pass w_ptr to kernel ...\n```\n\n**Fix in module patching:**\n\n```python\nhas_weight = hasattr(module, \"weight\") and module.weight is not None\nif has_weight:\n    def _fwd(mod, eps):\n        def forward(x): return diffusion_rmsnorm(x, weight=mod.weight, eps=eps)\n        return forward\n    module.forward = _fwd(module, module.eps)\nelse:\n    def _fwd_noweight(eps):\n        def forward(x): return diffusion_rmsnorm(x, weight=None, eps=eps)\n        return forward\n    module.forward = _fwd_noweight(module.eps)\n```\n\n### 9. 
`isinstance(module, torch.nn.RMSNorm)` misses diffusion variants\n\n**Problem:** Patching doesn't apply because diffusers / sglang diffusion models define their own `RMSNorm` class that is **not** a subclass of `torch.nn.RMSNorm`.\n\n**Fix:** Match by class name string:\n\n```python\n# WRONG — misses diffusers/sglang RMSNorm\nif isinstance(module, torch.nn.RMSNorm):\n\n# CORRECT — catches all variants\nif type(module).__name__ == \"RMSNorm\":\n# or for broader matching:\nif \"RMSNorm\" in type(module).__name__:\n```\n\n### 10. Kernel patching doesn't persist after CPU offloading\n\n**Problem:** After calling `pipe.enable_model_cpu_offload()`, patched modules revert.\n\n**Fix:** Always inject **after** moving to CUDA, **before** enabling any offloading:\n\n```python\npipe = load_pipeline(...)\npipe.to(\"cuda\")                  # 1. Move to CUDA\ninject_optimized_kernels(pipe)   # 2. Patch modules\npipe.enable_model_cpu_offload()  # 3. Now safe to enable offloading\n```\n\n### 11. Kernel patched after `torch.compile`\n\n**Problem:** Module is already compiled; patching its `forward` after compilation has no effect.\n\n**Fix:** Apply patches **before** any `torch.compile` call:\n\n```python\ninject_optimized_kernels(pipe)          # FIRST: patch\npipe.transformer = torch.compile(...)   # SECOND: compile\n```\n\n---\n\n## `torch.compile` Compatibility\n\n### 12. Custom CUDA kernel causes graph break\n\n**Problem:**\n```\ntorch._dynamo.exc.Unsupported: Attempted to call function marked as skipped\n```\nor:\n```\ntorch._dynamo.exc.TorchRuntimeError: Cannot access data pointer of Tensor (FakeTensor)\n```\n\n**Root Cause:** `torch.compile` traces with \"fake tensors\" that have no real data. 
Any kernel that calls `.data_ptr()` during tracing fails.\n\n**Options:**\n\n**Option A (simplest):** Use Triton kernels instead of CUDA JIT kernels on the `torch.compile` path:\n```python\n# Triton kernels are torch.compile compatible\nfrom sglang.jit_kernel.diffusion.triton.norm import fused_rmsnorm\n```\n\n**Option B:** Register as a `@torch.library.custom_op` (advanced):\n```python\nimport torch\n\n@torch.library.custom_op(\"diffusion_jit::rmsnorm\", mutates_args={\"out\"})\ndef _rmsnorm_op(out: torch.Tensor, src: torch.Tensor,\n                weight: torch.Tensor, eps: float) -> None:\n    # _jit_rmsnorm_module is a placeholder for your helper that returns the load_jit module for this dtype\n    module = _jit_rmsnorm_module(src.dtype)\n    module.rmsnorm(out, src, weight, eps)\n\n@_rmsnorm_op.register_fake\ndef _(out, src, weight, eps):\n    pass  # no shape changes; output already allocated in 'out'\n```\n\n**Performance trade-off:**\n\n| Approach | Speedup (denoise) | torch.compile | Notes |\n|----------|-------------------|---------------|-------|\n| CUDA JIT kernel | best | Yes (via `torch.library.custom_op`) | Performance-optimal regardless of whether `torch.compile` is enabled; use `custom_op` + `register_fake` for compile compatibility |\n| Triton kernel | good | Yes | Use when you need faster iteration/portability, or when you do not have a well-tuned CUDA kernel yet |\n| Triton + compile | good | Yes | Use for end-to-end `torch.compile` integration convenience; typically slower than a well-tuned CUDA kernel |\n\n### 13. Unstable benchmark results from JIT timing\n\n**Problem:** The first call is slow because it triggers JIT compilation, and timing is noisy.\n\n**Fix:** Use `triton.testing.do_bench_cudagraph` / `run_benchmark` for CUDA-graph-based timing (plain `triton.testing.do_bench` does not use CUDA graphs). 
Always make one untimed call first so compilation is excluded from the measurement:\n\n```python\n# Pre-compile by running once before timing\ndiffusion_rmsnorm(dummy_src, weight=dummy_w, eps=1e-6)\ntorch.cuda.synchronize()\n# Now time\nresult = run_benchmark(lambda: diffusion_rmsnorm(src, weight=w, eps=1e-6))\n```\n\n---\n\n## Debugging Checklist\n\n```bash\n# 1.
Verify CUDA device and compute capability\npython -c \"import torch; print(torch.cuda.get_device_name(), torch.cuda.get_device_capability())\"\n\n# 2. Force synchronous CUDA execution to get real error location\nCUDA_LAUNCH_BLOCKING=1 python scripts/bench_diffusion_rmsnorm.py\n\n# 3. Run memory sanitizer to catch illegal accesses\ncompute-sanitizer --tool memcheck python scripts/bench_diffusion_rmsnorm.py\n\n# 4. Check register and shared memory usage\n# Add to extra_cuda_cflags: \"--ptxas-options=-v\"\n\n# 5a. Kernel-level profiling — full metrics\nncu --set full -o metrics.ncu-rep \\\n  python scripts/bench_diffusion_rmsnorm.py\n\n# 5b. Kernel-level profiling — targeted bandwidth + occupancy check\nncu --metrics sm__throughput.avg.pct_of_peak_sustained_elapsed,\\\ndram__throughput.avg.pct_of_peak_sustained_elapsed \\\n  python scripts/bench_diffusion_rmsnorm.py\n\n# Key metrics to interpret:\n# - sm__throughput  : compute utilization % of peak\n# - dram__throughput: memory bandwidth % of peak (target ≥ 30% on H100/A100)\n# - smsp__warp_issue_stalled_*: warp stall breakdown (memory_dependency / math_pipe)\n\n# 6. System-level profiling (per-op breakdown inside sglang generate)\nnsys profile -o denoise_profile \\\n  sglang generate --model-path=black-forest-labs/FLUX.1-dev \\\n    --width=1024 --height=1024 --num-inference-steps=50 \\\n    --seed=42 --enable-torch-compile --warmup\n\n# 7. Verify a patched module produces correct output\npython - << 'EOF'\nimport torch\nfrom sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm\n\nx = torch.randn(4, 2048, dtype=torch.bfloat16, device=\"cuda\")\nw = torch.ones(2048, dtype=torch.bfloat16, device=\"cuda\")\n\nout_jit = diffusion_rmsnorm(x, weight=w, eps=1e-6)\nout_ref = torch.nn.functional.rms_norm(x.float(), (2048,), w.float(), eps=1e-6).to(torch.bfloat16)\n\nmax_diff = (out_jit - out_ref).abs().max().item()\nprint(f\"Max diff: {max_diff:.2e} ({'PASS' if max_diff < 0.02 else 'FAIL'})\")\nEOF\n```\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_denoise.py",
    "content": "\"\"\"\nEnd-to-end denoise-stage benchmark for SGLang Diffusion with/without custom JIT CUDA kernels.\n\nMeasures denoise latency (primary metric ★) and peak GPU memory.\nAll model configs are kept in exact sync with diffusion-benchmark-and-profile.md.\n\nAdapted from: https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels\n\nUsage:\n    # Baseline — single model\n    cd /path/to/sglang\n    python3 python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_denoise.py --model flux\n\n    # With custom JIT CUDA kernels\n    python3 python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_denoise.py --model flux --custom-kernels\n\n    # Side-by-side comparison\n    python3 python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_denoise.py --model flux --compare\n\n    # All 10 models, comparison\n    python3 python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_denoise.py --all --compare\n\nInput images required for image-guided models:\n    ASSET_DIR=$(python3 python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/diffusion_skill_env.py print-assets-dir --mkdir)\n    wget -O \"${ASSET_DIR}/cat.png\" \\\n      https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png\n    wget -O \"${ASSET_DIR}/astronaut.jpg\" \\\n      https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg\n    wget -O \"${ASSET_DIR}/mova_single_person.jpg\" \\\n      https://github.com/OpenMOSS/MOVA/raw/main/assets/single_person.jpg\n\"\"\"\n\nimport argparse\nimport json\nimport os\nimport subprocess\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Optional\n\nSCRIPT_DIR = Path(__file__).resolve().parent\nif str(SCRIPT_DIR) not in sys.path:\n    sys.path.insert(0, str(SCRIPT_DIR))\n\nfrom diffusion_skill_env import (\n    ensure_dir,\n  
  get_assets_dir,\n    get_output_dir,\n    get_repo_root,\n    pick_idle_gpus,\n)\n\nREPO_ROOT = get_repo_root()\nASSET_DIR = ensure_dir(get_assets_dir(REPO_ROOT))\n\n# ---------------------------------------------------------------------------\n# Model configs — kept in exact sync with diffusion-benchmark-and-profile.md\n# Each entry produces the same `sglang generate` command as shown in that doc.\n# ---------------------------------------------------------------------------\nMODELS = {\n    # 1. Qwen/Qwen-Image-2512 — Text-to-Image, 1024×1024, 50 steps\n    \"qwen\": {\n        \"path\": \"Qwen/Qwen-Image-2512\",\n        \"prompt\": \"A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k\",\n        \"negative_prompt\": \" \",\n        \"extra_args\": [\n            \"--width=1024\",\n            \"--height=1024\",\n            \"--num-inference-steps=50\",\n            \"--guidance-scale=4.0\",\n            \"--dit-cpu-offload\",\n            \"false\",\n            \"--text-encoder-cpu-offload\",\n            \"false\",\n        ],\n    },\n    # 2. Qwen/Qwen-Image-Edit-2511 — Image Editing, 1024×1024, 50 steps\n    # Requires: <repo>/inputs/diffusion_benchmark/figs/cat.png\n    \"qwen-edit\": {\n        \"path\": \"Qwen/Qwen-Image-Edit-2511\",\n        \"prompt\": \"Transform into anime style\",\n        \"negative_prompt\": \" \",\n        \"image_path\": str(ASSET_DIR / \"cat.png\"),\n        \"extra_args\": [\n            \"--width=1024\",\n            \"--height=1024\",\n            \"--num-inference-steps=50\",\n            \"--guidance-scale=4.0\",\n            \"--dit-cpu-offload\",\n            \"false\",\n            \"--text-encoder-cpu-offload\",\n            \"false\",\n        ],\n    },\n    # 3. 
black-forest-labs/FLUX.1-dev — Text-to-Image, 1024×1024, 50 steps\n    \"flux\": {\n        \"path\": \"black-forest-labs/FLUX.1-dev\",\n        \"prompt\": \"A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k\",\n        \"extra_args\": [\n            \"--width=1024\",\n            \"--height=1024\",\n            \"--num-inference-steps=50\",\n            \"--guidance-scale=4.0\",\n        ],\n    },\n    # 4. black-forest-labs/FLUX.2-dev — Text-to-Image, 1024×1024\n    \"flux2\": {\n        \"path\": \"black-forest-labs/FLUX.2-dev\",\n        \"prompt\": \"A Logo With Bold Large Text: SGL Diffusion\",\n        \"extra_args\": [\n            \"--width=1024\",\n            \"--height=1024\",\n            \"--dit-layerwise-offload\",\n            \"false\",\n            \"--dit-cpu-offload\",\n            \"false\",\n            \"--text-encoder-cpu-offload\",\n            \"true\",\n            \"--vae-cpu-offload\",\n            \"false\",\n        ],\n    },\n    # 5. Tongyi-MAI/Z-Image-Turbo — Turbo Text-to-Image, 1024×1024, 9 steps\n    \"zimage\": {\n        \"path\": \"Tongyi-MAI/Z-Image-Turbo\",\n        \"prompt\": \"A fantasy landscape with mountains and a river, detailed, vibrant colors\",\n        \"extra_args\": [\n            \"--width=1024\",\n            \"--height=1024\",\n            \"--num-inference-steps=9\",\n            \"--guidance-scale=0.0\",\n            \"--dit-cpu-offload\",\n            \"false\",\n            \"--text-encoder-cpu-offload\",\n            \"false\",\n        ],\n    },\n    # 6. Wan-AI/Wan2.2-T2V-A14B-Diffusers — Text-to-Video, 720P, 4 GPUs, 81 frames, 2 steps\n    \"wan-t2v\": {\n        \"path\": \"Wan-AI/Wan2.2-T2V-A14B-Diffusers\",\n        \"prompt\": \"A cat and a dog baking a cake together in a kitchen. 
The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon.\",\n        \"negative_prompt\": \" \",\n        \"extra_args\": [\n            \"--720p\",\n            \"--num-inference-steps=2\",\n            \"--num-frames=81\",\n            \"--guidance-scale=5.0\",\n            \"--num-gpus=4\",\n            \"--ulysses-degree=4\",\n            \"--text-encoder-cpu-offload\",\n            \"--pin-cpu-memory\",\n        ],\n    },\n    # 7. Wan-AI/Wan2.2-TI2V-5B-Diffusers — Text-Image-to-Video, 720P, 1 GPU, 81 frames, 50 steps\n    # Requires: <repo>/inputs/diffusion_benchmark/figs/astronaut.jpg\n    \"wan-ti2v\": {\n        \"path\": \"Wan-AI/Wan2.2-TI2V-5B-Diffusers\",\n        \"prompt\": \"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot.\",\n        \"negative_prompt\": \"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\",\n        \"image_path\": str(ASSET_DIR / \"astronaut.jpg\"),\n        \"extra_args\": [\n            \"--num-frames\",\n            \"81\",\n            \"--720p\",\n            \"--num-inference-steps\",\n            \"50\",\n            \"--guidance-scale\",\n            \"5.0\",\n            \"--dit-layerwise-offload\",\n            \"false\",\n            \"--dit-cpu-offload\",\n            \"false\",\n            \"--vae-cpu-offload\",\n            \"false\",\n            \"--text-encoder-cpu-offload\",\n            \"false\",\n        ],\n    },\n    # 8. 
hunyuanvideo-community/HunyuanVideo — Text-to-Video, 848×480, 65 frames, 30 steps\n    \"hunyuanvideo\": {\n        \"path\": \"hunyuanvideo-community/HunyuanVideo\",\n        \"prompt\": \"A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window.\",\n        \"extra_args\": [\n            \"--text-encoder-cpu-offload\",\n            \"--pin-cpu-memory\",\n            \"--num-frames=65\",\n            \"--width=848\",\n            \"--height=480\",\n            \"--num-inference-steps=30\",\n        ],\n    },\n    # 9. OpenMOSS-Team/MOVA-720p — Image-to-Video, 4 GPUs, 193 frames, 2 steps\n    # Requires: <repo>/inputs/diffusion_benchmark/figs/mova_single_person.jpg\n    \"mova-720p\": {\n        \"path\": \"OpenMOSS-Team/MOVA-720p\",\n        \"prompt\": 'A man in a blue blazer and glasses speaks in a formal indoor setting, framed by wooden furniture and a filled bookshelf. Quiet room acoustics underscore his measured tone as he delivers his remarks. At one point, he says, \"I would also believe that this advance in AI recently was not unexpected.\"',\n        \"image_path\": str(ASSET_DIR / \"mova_single_person.jpg\"),\n        \"extra_args\": [\n            \"--adjust-frames=false\",\n            \"--num-gpus=4\",\n            \"--ring-degree=1\",\n            \"--ulysses-degree=4\",\n            \"--num-frames=193\",\n            \"--fps=24\",\n            \"--num-inference-steps=2\",\n        ],\n    },\n    # 10. 
BestWishYsh/Helios-Base — Text-to-Video, 640×384, 33 frames\n    \"helios\": {\n        \"path\": \"BestWishYsh/Helios-Base\",\n        \"prompt\": \"A curious raccoon\",\n        \"extra_args\": [\n            \"--width=640\",\n            \"--height=384\",\n            \"--num-frames=33\",\n            \"--dit-layerwise-offload\",\n            \"false\",\n            \"--dit-cpu-offload\",\n            \"false\",\n            \"--text-encoder-cpu-offload\",\n            \"false\",\n            \"--vae-cpu-offload\",\n            \"false\",\n        ],\n    },\n}\n\n\ndef required_gpus_for_model(model_key: str) -> int:\n    if model_key == \"wan-t2v\":\n        return 4\n    if model_key == \"mova-720p\":\n        return 4\n    return 1\n\n\ndef build_sglang_cmd(\n    model_key: str,\n    use_custom_kernels: bool,\n    perf_dump_path: Optional[str] = None,\n    warmup: bool = True,\n    torch_compile: bool = True,\n    seed: int = 42,\n    save_output: bool = True,\n) -> list[str]:\n    \"\"\"\n    Build the `sglang generate` command for the given model.\n    Matches the commands in diffusion-benchmark-and-profile.md exactly.\n    \"\"\"\n    cfg = MODELS[model_key]\n\n    cmd = [\n        \"sglang\",\n        \"generate\",\n        f\"--model-path={cfg['path']}\",\n        f\"--prompt={cfg['prompt']}\",\n        \"--log-level=info\",\n    ]\n\n    if seed is not None:\n        cmd.append(f\"--seed={seed}\")\n\n    if \"negative_prompt\" in cfg:\n        cmd.append(f\"--negative-prompt={cfg['negative_prompt']}\")\n\n    if \"image_path\" in cfg:\n        cmd.append(f\"--image-path={cfg['image_path']}\")\n\n    cmd.extend(cfg[\"extra_args\"])\n\n    if save_output:\n        cmd.append(\"--save-output\")\n    if warmup:\n        cmd.append(\"--warmup\")\n    if torch_compile:\n        cmd.append(\"--enable-torch-compile\")\n    if perf_dump_path:\n        cmd.extend([\"--perf-dump-path\", perf_dump_path])\n\n    return cmd\n\n\ndef run_benchmark_once(\n    
model_key: str,\n    use_custom_kernels: bool,\n    output_dir: Path,\n    warmup: bool = True,\n) -> dict:\n    \"\"\"Run a single benchmark pass and return results dict.\"\"\"\n    label = \"custom\" if use_custom_kernels else \"baseline\"\n    perf_path = output_dir / f\"{model_key}_{label}.json\"\n\n    cmd = build_sglang_cmd(\n        model_key,\n        use_custom_kernels=use_custom_kernels,\n        perf_dump_path=str(perf_path),\n        warmup=warmup,\n    )\n\n    env = os.environ.copy()\n    env.setdefault(\"FLASHINFER_DISABLE_VERSION_CHECK\", \"1\")\n    if not env.get(\"CUDA_VISIBLE_DEVICES\"):\n        env[\"CUDA_VISIBLE_DEVICES\"] = \",\".join(\n            str(index) for index in pick_idle_gpus(required_gpus_for_model(model_key))\n        )\n    if use_custom_kernels:\n        # NOTE: This env var is a convention for user-implemented kernel injection\n        # logic. SGLang runtime does not read it by default — you must add handling\n        # in your denoising stage or model code to check this var and apply patches.\n        env[\"SGLANG_DIFFUSION_CUSTOM_CUDA_KERNELS\"] = \"1\"\n\n    print(f\"\\n{'=' * 64}\")\n    print(f\"[{label.upper()}] {model_key}\")\n    print(f\"  CUDA_VISIBLE_DEVICES={env.get('CUDA_VISIBLE_DEVICES', '<unset>')}\")\n    print(\"  \" + \" \\\\\\n  \".join(cmd))\n    print()\n\n    t0 = time.time()\n    result = subprocess.run(cmd, env=env, text=True)\n    elapsed = time.time() - t0\n\n    if result.returncode != 0:\n        print(f\"  ERROR: exit code {result.returncode}\")\n        return {\"model\": model_key, \"label\": label, \"error\": True, \"elapsed_s\": elapsed}\n\n    metrics = {\"model\": model_key, \"label\": label, \"elapsed_s\": elapsed, \"error\": False}\n    if perf_path.exists():\n        try:\n            with open(perf_path) as f:\n                perf = json.load(f)\n\n            # e2e latency: total_duration_ms (set by PerformanceLogger.dump_benchmark_report)\n            total_ms = 
perf.get(\"total_duration_ms\")\n            metrics[\"e2e_latency_s\"] = (\n                float(total_ms) / 1000.0 if total_ms is not None else None\n            )\n\n            # denoise latency: look in \"steps\" list for the \"DenoisingStage\" entry\n            # steps = [{\"name\": \"DenoisingStage\", \"duration_ms\": 1234.5}, ...]\n            denoise_latency_s = None\n            for step in perf.get(\"steps\", []):\n                if (\n                    step.get(\"name\") == \"DenoisingStage\"\n                    and step.get(\"duration_ms\") is not None\n                ):\n                    denoise_latency_s = float(step[\"duration_ms\"]) / 1000.0\n                    break\n\n            # fallback: sum all per-step durations from denoise_steps_ms\n            # denoise_steps_ms = [{\"step\": 0, \"duration_ms\": 100.5}, ...]\n            if denoise_latency_s is None:\n                denoise_steps = perf.get(\"denoise_steps_ms\", [])\n                if denoise_steps:\n                    denoise_latency_s = (\n                        sum(s.get(\"duration_ms\", 0.0) for s in denoise_steps) / 1000.0\n                    )\n            metrics[\"denoise_latency_s\"] = denoise_latency_s\n\n            # peak memory: max peak_reserved_mb across all memory checkpoints (→ GB)\n            # memory_checkpoints = {\"after_DenoisingStage\": {\"peak_reserved_mb\": 12288.0, ...}}\n            peak_memory_gb = None\n            for snapshot in perf.get(\"memory_checkpoints\", {}).values():\n                peak_mb = snapshot.get(\"peak_reserved_mb\")\n                if peak_mb is not None:\n                    candidate = float(peak_mb) / 1024.0\n                    if peak_memory_gb is None or candidate > peak_memory_gb:\n                        peak_memory_gb = candidate\n            metrics[\"peak_memory_gb\"] = peak_memory_gb\n\n        except Exception as e:\n            print(f\"  Warning: could not parse perf dump: {e}\")\n\n    return 
metrics\n\n\ndef print_results_table(results: list[dict]):\n    \"\"\"Print baseline vs custom kernel comparison table.\"\"\"\n    print()\n    print(\"=\" * 80)\n    print(\"BENCHMARK RESULTS — Denoise Latency (primary metric ★)\")\n    print(\"(Models and params match diffusion-benchmark-and-profile.md)\")\n    print(\"=\" * 80)\n\n    by_model: dict[str, dict] = {}\n    for r in results:\n        by_model.setdefault(r[\"model\"], {})[r[\"label\"]] = r\n\n    print(\n        f\"{'Model':<16} {'Baseline(s)':>12} {'Custom(s)':>10} {'Speedup':>9} {'Peak Mem(GB)':>14}\"\n    )\n    print(\"-\" * 64)\n\n    for model_key in MODELS:  # preserve order\n        if model_key not in by_model:\n            continue\n        runs = by_model[model_key]\n        base = runs.get(\"baseline\", {})\n        custom = runs.get(\"custom\", {})\n\n        base_lat = base.get(\"denoise_latency_s\")\n        custom_lat = custom.get(\"denoise_latency_s\")\n        peak_mem = base.get(\"peak_memory_gb\") or custom.get(\"peak_memory_gb\")\n\n        speedup = f\"{base_lat / custom_lat:.2f}x\" if base_lat and custom_lat else \"n/a\"\n        base_s = f\"{base_lat:.2f}\" if base_lat else \"n/a\"\n        custom_s = f\"{custom_lat:.2f}\" if custom_lat else \"n/a\"\n        mem_s = f\"{peak_mem:.1f}\" if isinstance(peak_mem, float) else \"n/a\"\n\n        print(f\"{model_key:<16} {base_s:>12} {custom_s:>10} {speedup:>9} {mem_s:>14}\")\n\n    print(\"-\" * 64)\n    print()\n    print(\"★ Denoise latency = total DiT forward pass time across all inference steps.\")\n    print(\n        \"  See diffusion-benchmark-and-profile.md for full Level 1/2 profiling workflow.\"\n    )\n\n\ndef inject_kernels_example():\n    \"\"\"\n    Show the kernel injection pattern used when SGLANG_DIFFUSION_CUSTOM_CUDA_KERNELS=1.\n    After implementing add-cuda-kernel.md, this logic lives in denoising.py or\n    the model's transformer.py — NOT in this script.\n\n    Call patch_rmsnorm(dit_model) BEFORE 
torch.compile and BEFORE any CPU offloading.\n    \"\"\"\n    import torch.nn as nn\n\n    try:\n        from sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm\n    except ImportError:\n        print(\n            \"diffusion.rmsnorm JIT kernel not available. \"\n            \"Implement add-cuda-kernel.md first.\"\n        )\n        return\n\n    def patch_rmsnorm(model: nn.Module, verbose: bool = False) -> int:\n        \"\"\"Monkey-patch all RMSNorm variants to use the JIT CUDA kernel.\"\"\"\n        patched = 0\n        for name, module in model.named_modules():\n            if \"RMSNorm\" not in type(module).__name__:\n                continue\n            eps = getattr(module, \"eps\", getattr(module, \"variance_epsilon\", 1e-6))\n            has_weight = hasattr(module, \"weight\") and module.weight is not None\n\n            if has_weight:\n\n                def _make(mod, ep):\n                    def fwd(x):\n                        return diffusion_rmsnorm(x, weight=mod.weight, eps=ep)\n\n                    return fwd\n\n                module.forward = _make(module, eps)\n            else:\n\n                def _make_no_w(ep):\n                    def fwd(x):\n                        return diffusion_rmsnorm(x, weight=None, eps=ep)\n\n                    return fwd\n\n                module.forward = _make_no_w(eps)\n\n            patched += 1\n            if verbose:\n                print(f\"  Patched: {name} (weight={has_weight})\")\n        return patched\n\n    return patch_rmsnorm\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"SGLang Diffusion denoise benchmark — baseline vs JIT CUDA kernels\"\n    )\n    parser.add_argument(\n        \"--model\",\n        choices=list(MODELS.keys()),\n        help=\"Model to benchmark (default: flux)\",\n    )\n    parser.add_argument(\"--all\", action=\"store_true\", help=\"Benchmark all 10 models\")\n    parser.add_argument(\n        \"--custom-kernels\",\n        
action=\"store_true\",\n        help=\"Run with custom JIT CUDA kernels (SGLANG_DIFFUSION_CUSTOM_CUDA_KERNELS=1)\",\n    )\n    parser.add_argument(\n        \"--no-custom-kernels\",\n        action=\"store_true\",\n        help=\"Run baseline (no custom kernels)\",\n    )\n    parser.add_argument(\n        \"--compare\",\n        action=\"store_true\",\n        help=\"Run both baseline and custom, print comparison table\",\n    )\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        default=str(get_output_dir(\"benchmarks\", REPO_ROOT)),\n        help=\"Directory for perf dump JSON files\",\n    )\n    parser.add_argument(\"--no-warmup\", action=\"store_true\", help=\"Skip warmup\")\n    parser.add_argument(\n        \"--show-injection-example\",\n        action=\"store_true\",\n        help=\"Print kernel injection pattern and exit\",\n    )\n\n    args = parser.parse_args()\n\n    if args.show_injection_example:\n        patch_fn = inject_kernels_example()\n        if patch_fn:\n            print(\n                \"patch_rmsnorm function defined. 
\"\n                \"Call it on the DiT model before torch.compile and CPU offloading.\"\n            )\n        return\n\n    output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n    warmup = not args.no_warmup\n\n    models_to_run = list(MODELS.keys()) if args.all else [args.model or \"flux\"]\n    results = []\n\n    for model_key in models_to_run:\n        if args.compare:\n            results.append(run_benchmark_once(model_key, False, output_dir, warmup))\n            results.append(run_benchmark_once(model_key, True, output_dir, warmup))\n        elif args.custom_kernels:\n            results.append(run_benchmark_once(model_key, True, output_dir, warmup))\n        else:\n            results.append(run_benchmark_once(model_key, False, output_dir, warmup))\n\n    if results:\n        print_results_table(results)\n\n    print(f\"Perf dump JSONs → {output_dir}\")\n    print(\n        \"Compare across runs: follow diffusion-benchmark-and-profile.md → Perf dump & before/after compare.\"\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_rmsnorm.py",
    "content": "\"\"\"\nMicro-benchmark for the SGLang Diffusion JIT CUDA RMSNorm kernel.\n\nCompares:\n  1. SGLang JIT CUDA kernel (diffusion_rmsnorm)\n  2. PyTorch baseline (torch.nn.functional.rms_norm)\n\nAdapted from: https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels\n\nUsage:\n    cd /path/to/sglang\n    python3 python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_rmsnorm.py\n\nRequirements:\n    # Run inside the configured SGLang diffusion container shell.\n    # This script auto-selects an idle GPU when CUDA_VISIBLE_DEVICES is unset.\n\"\"\"\n\nimport os\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Tuple\n\nSCRIPT_DIR = Path(__file__).resolve().parent\nif str(SCRIPT_DIR) not in sys.path:\n    sys.path.insert(0, str(SCRIPT_DIR))\n\nfrom diffusion_skill_env import configure_runtime_env\n\nconfigure_runtime_env(required_gpus=1)\n\nimport torch\n\n# ---------------------------------------------------------------------------\n# Import the JIT CUDA kernel.\n# When you implement add-cuda-kernel.md, the file will be at:\n#   python/sglang/jit_kernel/diffusion/rmsnorm.py\n# ---------------------------------------------------------------------------\ntry:\n    from sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm\n\n    JIT_AVAILABLE = True\nexcept ImportError:\n    JIT_AVAILABLE = False\n    print(\n        \"WARNING: diffusion.rmsnorm JIT kernel not available. 
\"\n        \"Run after implementing add-cuda-kernel.md.\"\n    )\n\n\ndef pytorch_rmsnorm(\n    x: torch.Tensor,\n    weight: torch.Tensor | None = None,\n    eps: float = 1e-6,\n) -> torch.Tensor:\n    \"\"\"Reference PyTorch implementation of RMSNorm.\"\"\"\n    hidden = x.shape[-1]\n    return torch.nn.functional.rms_norm(\n        x.float(), (hidden,), weight.float() if weight is not None else None, eps=eps\n    ).to(x.dtype)\n\n\ndef benchmark_kernel(\n    func,\n    args,\n    warmup: int = 20,\n    iterations: int = 100,\n) -> Tuple[float, float]:\n    \"\"\"Benchmark a kernel function. Returns (avg_ms, min_ms).\"\"\"\n    for _ in range(warmup):\n        func(*args)\n    torch.cuda.synchronize()\n\n    times = []\n    for _ in range(iterations):\n        torch.cuda.synchronize()\n        t0 = time.perf_counter()\n        func(*args)\n        torch.cuda.synchronize()\n        times.append((time.perf_counter() - t0) * 1000)\n\n    return sum(times) / len(times), min(times)\n\n\ndef run_benchmark():\n    print(\"=\" * 72)\n    print(\"SGLang Diffusion RMSNorm Micro-Benchmark: JIT CUDA vs PyTorch\")\n    print(\"=\" * 72)\n    print(f\"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>')}\")\n    print(f\"Device: {torch.cuda.get_device_name(0)}\")\n    cap = torch.cuda.get_device_capability()\n    print(f\"Compute Capability: sm_{cap[0]}{cap[1]}\")\n    print()\n\n    if not JIT_AVAILABLE:\n        print(\"Skipping JIT kernel benchmark (kernel not available).\")\n        return\n\n    # Determine dtype: T4 (sm_75) has no BF16\n    dtype = torch.bfloat16 if cap >= (8, 0) else torch.float16\n    print(f\"Dtype: {dtype}\")\n    print()\n\n    # Typical DiT hidden sizes for sglang diffusion models:\n    #   FLUX.1-dev: hidden=3072\n    #   Qwen-Image: hidden=2048\n    #   Wan2.2:     hidden=4096\n    configs = [\n        # (batch_tokens, hidden_size, has_weight)\n        (1024, 2048, True),  # Qwen-Image: 1 sample × 1024 tokens\n        (4096, 
2048, True),  # Qwen-Image: larger batch\n        (1024, 3072, True),  # FLUX: 1 sample × 1024 tokens\n        (4096, 3072, True),  # FLUX: larger\n        (4096, 4096, True),  # Wan2.2\n        (4096, 2048, False),  # no-weight (elementwise_affine=False)\n        (16384, 3072, True),  # long sequence\n    ]\n\n    print(\n        f\"{'Config':<32} {'JIT(ms)':>10} {'PyTorch(ms)':>12} {'Speedup':>9} {'Weight'}\"\n    )\n    print(\"-\" * 72)\n\n    total_speedup = 0\n    n = 0\n\n    for batch_tokens, hidden, has_weight in configs:\n        x = torch.randn(batch_tokens, hidden, dtype=dtype, device=\"cuda\")\n        weight = torch.ones(hidden, dtype=dtype, device=\"cuda\") if has_weight else None\n\n        jit_avg, _ = benchmark_kernel(\n            diffusion_rmsnorm, (x, weight, 1e-6), warmup=20, iterations=100\n        )\n        pt_avg, _ = benchmark_kernel(\n            pytorch_rmsnorm, (x, weight, 1e-6), warmup=20, iterations=100\n        )\n\n        speedup = pt_avg / jit_avg\n        total_speedup += speedup\n        n += 1\n\n        w_str = \"yes\" if has_weight else \"no \"\n        cfg = f\"[{batch_tokens}×{hidden}]\"\n        print(f\"{cfg:<32} {jit_avg:>10.3f} {pt_avg:>12.3f} {speedup:>8.2f}x  {w_str}\")\n\n    print(\"-\" * 72)\n    print(f\"{'Average Speedup':>56} {total_speedup / n:.2f}x\")\n    print()\n\n    # -----------------------------------------------------------------------\n    # Correctness check\n    # -----------------------------------------------------------------------\n    print(f\"Correctness Check ({dtype}, max abs diff tolerance 0.02):\")\n    x = torch.randn(4096, 3072, dtype=dtype, device=\"cuda\")\n    weight = torch.ones(3072, dtype=dtype, device=\"cuda\")\n\n    out_jit = diffusion_rmsnorm(x, weight=weight, eps=1e-6)\n    out_ref = pytorch_rmsnorm(x, weight=weight, eps=1e-6)\n\n    max_diff = (out_jit - out_ref).abs().max().item()\n    rel_diff = ((out_jit - out_ref).abs() / (out_ref.abs() + 1e-8)).max().item()\n    passed = max_diff < 
0.02\n\n    print(f\"  Max absolute diff: {max_diff:.2e}\")\n    print(f\"  Max relative diff: {rel_diff:.2e}\")\n    print(f\"  Correctness: {'PASS ✓' if passed else 'FAIL ✗'}\")\n    print()\n\n    # -----------------------------------------------------------------------\n    # Memory bandwidth analysis\n    # -----------------------------------------------------------------------\n    print(\"Memory Bandwidth Analysis:\")\n    bt, hid = 4096, 3072\n    x = torch.randn(bt, hid, dtype=dtype, device=\"cuda\")\n    weight = torch.ones(hid, dtype=dtype, device=\"cuda\")\n\n    bytes_per_elem = dtype.itemsize\n    total_bytes = (\n        bt * hid + hid + bt * hid\n    ) * bytes_per_elem  # read x + read w + write out\n    jit_avg, _ = benchmark_kernel(diffusion_rmsnorm, (x, weight, 1e-6))\n\n    bandwidth_gbps = (total_bytes / 1e9) / (jit_avg / 1000)\n    theoretical_bw = {\n        (9, 0): 3350,  # H100: 3.35 TB/s\n        (8, 0): 2000,  # A100 80GB\n    }.get(\n        cap, 320\n    )  # T4: 320 GB/s\n    efficiency = bandwidth_gbps / theoretical_bw * 100\n\n    print(f\"  Shape: [{bt} × {hid}]  dtype: {dtype}\")\n    print(f\"  Total data: {total_bytes / 1e6:.1f} MB\")\n    print(f\"  Achieved: {bandwidth_gbps:.1f} GB/s\")\n    print(f\"  Theoretical ({torch.cuda.get_device_name(0)}): {theoretical_bw} GB/s\")\n    print(f\"  Bandwidth efficiency: {efficiency:.1f}%\")\n    print()\n    print(\"Target: ≥ 30% efficiency (H100/A100), ≥ 40% (T4)\")\n\n\nif __name__ == \"__main__\":\n    if not torch.cuda.is_available():\n        print(\"CUDA not available.\")\n    else:\n        run_benchmark()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/diffusion_skill_env.py",
    "content": "from __future__ import annotations\n\nimport argparse\nimport csv\nimport os\nimport subprocess\nfrom pathlib import Path\n\nOUTPUT_DIR_NAMES = {\n    \"benchmarks\": Path(\"outputs/diffusion_benchmarks\"),\n    \"profiles\": Path(\"outputs/diffusion_profiles\"),\n    \"ncu\": Path(\"outputs/ncu_reports\"),\n}\n\n\ndef get_repo_root() -> Path:\n    import sglang\n\n    return Path(sglang.__file__).resolve().parents[2]\n\n\ndef get_assets_dir(repo_root: Path | None = None) -> Path:\n    root = repo_root or get_repo_root()\n    return root / \"inputs\" / \"diffusion_benchmark\" / \"figs\"\n\n\ndef get_output_dir(name: str, repo_root: Path | None = None) -> Path:\n    if name not in OUTPUT_DIR_NAMES:\n        raise KeyError(f\"Unknown output dir name: {name}\")\n    root = repo_root or get_repo_root()\n    return root / OUTPUT_DIR_NAMES[name]\n\n\ndef ensure_dir(path: Path) -> Path:\n    path.mkdir(parents=True, exist_ok=True)\n    return path\n\n\ndef check_write_access(repo_root: Path | None = None) -> Path:\n    root = repo_root or get_repo_root()\n    probe_dir = ensure_dir(root / \".cache\" / \"diffusion_skill_write_test\")\n    probe_file = probe_dir / \"probe.txt\"\n    probe_file.write_text(\"ok\", encoding=\"utf-8\")\n    return probe_file\n\n\ndef _run_nvidia_smi(query: str) -> list[list[str]]:\n    command = [\n        \"nvidia-smi\",\n        f\"--query-{query}\",\n        \"--format=csv,noheader,nounits\",\n    ]\n    result = subprocess.run(command, check=True, capture_output=True, text=True)\n    rows: list[list[str]] = []\n    for raw_line in result.stdout.splitlines():\n        line = raw_line.strip()\n        if not line:\n            continue\n        rows.append([field.strip() for field in csv.reader([line]).__next__()])\n    return rows\n\n\ndef get_gpu_inventory() -> list[dict[str, int | str]]:\n    rows = _run_nvidia_smi(\"gpu=index,uuid,memory.used,memory.total,utilization.gpu\")\n    inventory = []\n    for index, uuid, 
memory_used, memory_total, utilization_gpu in rows:\n        inventory.append(\n            {\n                \"index\": int(index),\n                \"uuid\": uuid,\n                \"memory_used_mib\": int(memory_used),\n                \"memory_total_mib\": int(memory_total),\n                \"utilization_gpu_pct\": int(utilization_gpu),\n            }\n        )\n    return inventory\n\n\ndef get_busy_gpu_uuids() -> set[str]:\n    rows = _run_nvidia_smi(\"compute-apps=gpu_uuid,pid,process_name,used_gpu_memory\")\n    return {gpu_uuid for gpu_uuid, *_ in rows}\n\n\ndef pick_idle_gpus(\n    required_gpus: int,\n    max_memory_used_mib: int = 32,\n    max_utilization_gpu_pct: int = 5,\n) -> list[int]:\n    inventory = get_gpu_inventory()\n    busy_uuids = get_busy_gpu_uuids()\n\n    idle = [\n        int(gpu[\"index\"])\n        for gpu in inventory\n        if gpu[\"uuid\"] not in busy_uuids\n        and int(gpu[\"memory_used_mib\"]) <= max_memory_used_mib\n        and int(gpu[\"utilization_gpu_pct\"]) <= max_utilization_gpu_pct\n    ]\n    if len(idle) < required_gpus:\n        raise RuntimeError(\n            \"Not enough idle GPUs. 
\"\n            f\"required={required_gpus}, idle={idle}, inventory={inventory}, busy={sorted(busy_uuids)}\"\n        )\n    return idle[:required_gpus]\n\n\ndef configure_runtime_env(required_gpus: int = 1) -> str | None:\n    os.environ.setdefault(\"FLASHINFER_DISABLE_VERSION_CHECK\", \"1\")\n    if os.environ.get(\"CUDA_VISIBLE_DEVICES\"):\n        return None\n    selected = \",\".join(str(index) for index in pick_idle_gpus(required_gpus))\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = selected\n    return selected\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(\n        description=\"Resolve SGLang diffusion skill paths and idle GPUs.\"\n    )\n    parser.add_argument(\n        \"command\",\n        choices=[\n            \"print-root\",\n            \"print-assets-dir\",\n            \"print-output-dir\",\n            \"print-idle-gpus\",\n            \"check-write-access\",\n        ],\n    )\n    parser.add_argument(\n        \"--kind\",\n        choices=sorted(OUTPUT_DIR_NAMES),\n        help=\"Output directory kind for print-output-dir.\",\n    )\n    parser.add_argument(\n        \"--count\",\n        type=int,\n        default=1,\n        help=\"Number of idle GPUs to print.\",\n    )\n    parser.add_argument(\n        \"--mkdir\",\n        action=\"store_true\",\n        help=\"Create the requested directory before printing it.\",\n    )\n    args = parser.parse_args()\n\n    if args.command == \"print-root\":\n        print(get_repo_root())\n        return\n    if args.command == \"print-assets-dir\":\n        path = get_assets_dir()\n        if args.mkdir:\n            ensure_dir(path)\n        print(path)\n        return\n    if args.command == \"print-output-dir\":\n        if not args.kind:\n            raise SystemExit(\"--kind is required for print-output-dir\")\n        path = get_output_dir(args.kind)\n        if args.mkdir:\n            ensure_dir(path)\n        print(path)\n        return\n    if args.command == 
\"print-idle-gpus\":\n        print(\",\".join(str(index) for index in pick_idle_gpus(args.count)))\n        return\n    if args.command == \"check-write-access\":\n        print(check_write_access())\n        return\n    raise SystemExit(f\"Unhandled command: {args.command}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/use-efficient-diffusion-kernels.md",
    "content": "---\nname: use-efficient-diffusion-kernels\ndescription: Guidance for using SGLang Diffusion fused kernels and fast CUDA paths. Use when mapping fusion patterns in diffusion inference, choosing fused ops or attention backends, handling RoPE/QK norm performance pitfalls, or integrating new diffusion models with kernel-aware constraints.\n---\n\n# Use Efficient Diffusion Kernels\n\n**Overview**\nThis skill focuses on SGLang Diffusion (`sglang.multimodal_gen`) kernel fusion patterns and fast CUDA paths. Prefer existing fused ops (Triton, CuTe DSL, sgl-kernel). Make constraints and fallbacks explicit.\n\n**Key Files**\n- `python/sglang/multimodal_gen/runtime/layers/layernorm.py`\n- `python/sglang/multimodal_gen/runtime/layers/elementwise.py`\n- `python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py`\n- `python/sglang/jit_kernel/diffusion/triton/scale_shift.py`\n- `python/sglang/jit_kernel/diffusion/triton/norm.py`\n- `python/sglang/jit_kernel/diffusion/triton/rmsnorm_onepass.py`\n- `python/sglang/jit_kernel/diffusion/triton/rotary.py`\n- `python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py`\n- `python/sglang/jit_kernel/norm.py`\n- `python/sglang/multimodal_gen/runtime/platforms/cuda.py`\n- `python/sglang/multimodal_gen/runtime/layers/attention/selector.py`\n- `docs/diffusion/performance/attention_backends.md` (repo root)\n\n**Core Fusion Patterns**\n\n1. Scale/Shift elementwise fusion (AdaLN modulation)\n- Kernels: `fuse_scale_shift_kernel`, `fuse_scale_shift_gate_select01_kernel`\n- Locations: `elementwise.py`, `layernorm.py`, `qwen_image.py`, `triton/scale_shift.py`\n- Use cases: `x * (1 + scale) + shift` and `a * (k + b) + c`\n- Constraints: `x` must be CUDA and contiguous. `scale/shift` support 0D/1D/2D/3D/4D broadcast. 4D `[B, F, 1, C]` requires `L % F == 0`, where `x` has shape `[B, L, C]`.\n- NPU fallback: `scale_shift.py` swaps to `npu_fallback` native path.\n\n2. 
Norm + Scale/Shift fusion (CuTe DSL)\n- Kernels: `fused_norm_scale_shift`, `fused_scale_residual_norm_scale_shift`\n- Locations: `layernorm.py`, `cutedsl/scale_residual_norm_scale_shift.py`\n- Use cases:\n  - `y = norm(x) * (1 + scale) + shift`\n  - `y = norm(residual + gate * x) * (1 + scale) + shift`\n- Constraints: `D % 256 == 0` and `D <= 8192`. `x/residual/gate/scale/shift` must pass shape and stride validation. Dtypes limited to fp16/bf16/fp32.\n- Behavior: CuTe DSL compilation cached by `(dtype, ndim, D, norm_type)`. `None` tensors replaced by scalar placeholders. If constraints fail, `layernorm.py` warns and falls back to native PyTorch.\n\n3. Triton LayerNorm/RMSNorm fusion\n- Kernels: `rms_norm_fn`, `layer_norm_fn`, `norm_infer`\n- Locations: `triton/norm.py`, `layernorm.py`\n- Use cases: fp32 RMSNorm with residual/dropout/rowscale/x1 branches, and inference-friendly `norm_infer`.\n- Constraints: last dim must be contiguous, and `N * element_size < 64KB`.\n\n4. Triton one-pass RMSNorm (small hidden size fast path)\n- Kernel: `triton_one_pass_rms_norm`\n- Locations: `triton/rmsnorm_onepass.py`, `layernorm.py`\n- Use case: `hidden_size <= 128` in `RMSNorm.forward_cuda`.\n\n5. Triton RoPE fusion\n- Kernel: `apply_rotary_embedding`\n- Locations: `triton/rotary.py`, `rotary_embedding/utils.py`\n- Use case: GPT-J style RoPE when not Neox.\n- Constraints: `head_size` must be even.\n- NPU fallback: `npu_fallback.apply_rotary_embedding_native`.\n\n**Faster CUDA Kernel Usage Points**\n\n1. sgl-kernel RMSNorm and fused add RMSNorm\n- Location: `layernorm.py`\n- Behavior: CUDA uses `sgl_kernel.fused_add_rmsnorm` and `sgl_kernel.rmsnorm`. `hidden_size <= 128` uses Triton one-pass. ROCm falls back to native.\n\n2. Attention backend selection (FlashAttention, Sage, SDPA)\n- Locations: `platforms/cuda.py`, `attention/selector.py`, `docs/diffusion/performance/attention_backends.md`\n- Behavior: CUDA prefers FlashAttention (FA3/FA4) when supported, otherwise Torch SDPA. 
Force via `--attention-backend` or `global_force_attn_backend`.\n\n3. FlashInfer RoPE (Q/K inplace)\n- Location: `rotary_embedding/utils.py`\n- Behavior: `flashinfer.rope.apply_rope_with_cos_sin_cache_inplace` when available, otherwise Triton RoPE fallback.\n\n**QK Norm Optimization**\n\n- Entry point: `apply_qk_norm` in `layernorm.py`.\n- Fast path: JIT fused inplace QK norm from `python/sglang/jit_kernel/norm.py` via `fused_inplace_qknorm`.\n- Preconditions for fused path:\n  - CUDA only.\n  - `allow_inplace=True` and `q_eps == k_eps`.\n  - `can_use_fused_inplace_qknorm(head_dim, dtype)` returns true.\n  - Supported head dims: `64, 128, 256, 512, 1024`.\n- Behavior: Fused path operates on `q` and `k` in place after reshaping to `[B, -1, head_dim]`. If preconditions fail, fall back to per-tensor RMSNorm.\n\n**Common Entry Points in Diffusion Models**\n- AdaLN modulation: `LayerNormScaleShift`, `RMSNormScaleShift`, `ScaleResidual*` in `layernorm.py`.\n- Qwen-Image gating: `fuse_scale_shift_gate_select01_kernel` in `qwen_image.py`.\n- QK norm: `apply_qk_norm` used in `flux.py`, `flux_2.py`, `qwen_image.py`, `zimage.py`, `wanvideo.py`, `ltx_2.py`, `hunyuanvideo.py`.\n- RoPE: `_apply_rotary_emb` prefers Triton; Q/K RoPE prefers FlashInfer when present.\n\n**Constraints and Fallbacks**\n- `scale_shift` Triton requires CUDA + contiguous `x`. NPU swaps to native.\n- CuTe DSL fused norms require `D % 256 == 0` and `D <= 8192`.\n- Triton norm kernels error on feature size >= 64KB.\n- FlashAttention requires fp16/bf16 and SM80+; otherwise SDPA.\n\n**Integration Checklist for New Models**\n\n1. Reuse `LayerNormScaleShift` or `ScaleResidual*` modules instead of re-implementing fusion logic.\n2. Keep tensors contiguous and satisfy D alignment (`% 256`) and size (`<= 8192`) for CuTe fused paths.\n3. Use `fuse_scale_shift_kernel` for AdaLN modulation and keep a PyTorch fallback.\n4. Use `apply_qk_norm` and ensure head_dim is in the supported list for fused QK norm.\n5. 
If using FlashInfer RoPE, avoid `pack qkv` and ensure Q/K are contiguous.\n6. For attention, follow `selector.py` priority; override with CLI only if needed.\n\n**When Extending or Modifying Kernels**\n- Add `torch.library.custom_op` and `register_fake` for compile and meta support.\n- Keep CuTe compile cache keys consistent with the existing `(dtype, ndim, D, norm_type)` key.\n- Avoid implicit broadcasts that force hidden `contiguous()` copies.\n- Preserve NPU and ROCm fallback paths.\n- **Always verify with ncu** (`ncu --set full`) that the kernel achieves adequate memory bandwidth utilization (>70% of peak for bandwidth-bound ops) and occupancy (>50%). See `diffusion-benchmark-and-profile.md` Step 3.5 for the ncu workflow.\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/diffusion-optimal-perf/SKILL.md",
    "content": "---\nname: diffusion-optimal-perf\ndescription: Guide for achieving optimal performance with SGLang-Diffusion. Covers all perf-related CLI flags, env vars, and best practices for lossless and lossy speedup.\n---\n\n# SGLang-Diffusion: Optimal Performance Guide\n\nUse this guide when a user asks how to speed up diffusion inference, reduce latency, lower VRAM usage, or tune SGLang-Diffusion for production.\n\nBefore running any `sglang generate` command below inside the diffusion container:\n- derive the repo root from `python3 -c \"import os, sglang; print(os.path.abspath(os.path.join(os.path.dirname(sglang.__file__), '..', '..')))\"` and `cd` there\n- export `FLASHINFER_DISABLE_VERSION_CHECK=1`\n- verify the repo is writable if you expect perf dumps or outputs\n- choose idle GPU(s) first; reuse `diffusion-kernel/scripts/diffusion_skill_env.py` when doing perf work\n\nReference: [SGLang-Diffusion Advanced Optimizations Blog](https://lmsys.org/blog/2026-02-16-sglang-diffusion-advanced-optimizations/)\n\n---\n\n## Section 1: Lossless Optimizations\n\nThese options **do not** affect output quality. The generated images/videos are numerically identical (or within floating-point rounding) to the baseline.\n\n| Option | CLI Flag / Env Var | What It Does | Speedup | Limitations / Notes |\n|---|---|---|---|---|\n| **torch.compile** | `--enable-torch-compile` | Applies `torch.compile` to the DiT forward pass, fusing ops and reducing kernel launch overhead. | ~1.2–1.5x on denoising | First request is slow (compilation). May cause minor precision drifts due to [PyTorch issue #145213](https://github.com/pytorch/pytorch/issues/145213). Pair with `--warmup` for best results. |\n| **Warmup** | `--warmup` | Runs dummy forward passes to warm up CUDA caches, JIT, and `torch.compile`. Eliminates cold-start penalty. | Removes first-request latency spike | Adds startup time. Without `--warmup-resolutions`, warmup happens on first request. 
|\n| **Warmup Resolutions** | `--warmup-resolutions 256x256 720x720` | Pre-compiles and warms up specific resolutions at server startup (instead of lazily on first request). | Faster first request per resolution | Each resolution adds to startup time. Serving mode only; useful when you know your target resolutions in advance. |\n| **Multi-GPU (SP)** | `--num-gpus N --ulysses-degree N` | Sequence parallelism across GPUs. Shards sequence tokens (not frames) to minimize padding. | Near-linear scaling with N GPUs | Requires NCCL; inter-GPU bandwidth matters. `ulysses_degree * ring_degree = sp_degree`. |\n| **CFG Parallel** | `--enable-cfg-parallel` | Runs conditional and unconditional CFG branches in parallel across GPUs. **For CFG models with multi-GPU, always prefer `--enable-cfg-parallel` + Ulysses over pure Ulysses** — it is generally faster at the same GPU count due to better compute-to-communication ratio and elimination of sequential branch execution. | Typically faster than pure SP for CFG models | Requires `num_gpus >= 2`. Halves the Ulysses group size (e.g. 8 GPU → two 4-GPU groups). Only for models that use CFG. |\n| **Layerwise Offload** | `--dit-layerwise-offload` | Async layer-by-layer H2D prefetch with compute overlap. Only ~2 DiT layers reside on GPU at a time, dramatically reducing VRAM. For **video models** (where per-layer compute >> H2D transfer), the memcpy is completely hidden behind computation — **zero-cost offload** that saves VRAM without speed penalty ([PR #15511](https://github.com/sgl-project/sglang/pull/15511)). | Saves VRAM (40 GB → ~11 GB for Wan A14B); zero or near-zero speed cost for video models | Enabled by default for Wan/MOVA video models. Incompatible with Cache-DiT. For **image models** or highly parallelized setups (many GPUs, small per-GPU compute), the copy stream may not be fully hidden and can cause slowdown. 
|\n| **Offload Prefetch Size** | `--dit-offload-prefetch-size F` | Fine-grained control over layerwise offload: how many layers to prefetch ahead. `0.0` = 1 layer (min VRAM), `0.1` = 10% of layers, `≥1` = absolute layer count. | Tune for cases where default offload has copy stream interference (e.g. image models). 0.05–0.1 is a good starting point. | Values ≥ 0.5 approach no-offload VRAM with worse performance. See [PR #17693](https://github.com/sgl-project/sglang/pull/17693) for benchmarks on image models. |\n| **FSDP Inference** | `--use-fsdp-inference` | Uses PyTorch FSDP to shard model weights across GPUs with prefetch. Low latency, low VRAM. | Reduces per-GPU VRAM | Mutually exclusive with `--dit-layerwise-offload`. More overhead than SP on high-bandwidth interconnects. |\n| **CPU Offload (components)** | `--text-encoder-cpu-offload`, `--image-encoder-cpu-offload`, `--vae-cpu-offload`, `--dit-cpu-offload` | Offloads specific pipeline components to CPU when not in use. | Reduces peak VRAM | Adds H2D transfer latency when the component is needed. Auto-enabled for low-VRAM GPUs (<30 GB). **Tip:** after the first request completes, the console prints a peak VRAM analysis with suggestions on which offload flags can be safely disabled — look for the `\"Components that could stay resident\"` log line. |\n| **Pin CPU Memory** | `--pin-cpu-memory` | Uses pinned (page-locked) memory for CPU offload transfers. | Faster H2D transfers | Slightly higher host memory usage. Enabled by default; disable only as workaround for CUDA errors. |\n| **Attention Backend (lossless)** | `--attention-backend fa` | Selects lossless attention kernel: `fa` (FlashAttention 2/3/4), `torch_sdpa`. FA is the fastest lossless option. | FA >> SDPA for long sequences | FA requires compatible GPU (Ampere+). `fa3`/`fa4` are aliased to `fa`. Ring attention only works with `fa` or `sage_attn`. 
|\n| **Parallel Folding** | *(automatic when SP > 1)* | Reuses the SP process group as TP for the T5 text encoder, so text encoding is parallelized \"for free\". | Faster text encoding on multi-GPU | Automatic; no user action needed. Only applies to T5-based pipelines. |\n\n---\n\n## Section 2: Lossy Optimizations\n\nThese options **trade output quality** for speed or VRAM savings. Results will differ from the baseline.\n\n| Option | CLI Flag / Env Var | What It Does | Speedup | Quality Impact / Limitations |\n|---|---|---|---|---|\n| **Approximate Attention** | `--attention-backend sage_attn` / `sage_attn_3` / `sliding_tile_attn` / `video_sparse_attn` / `sparse_video_gen_2_attn` / `vmoba_attn` / `sla_attn` / `sage_sla_attn` | Replaces exact attention with approximate or sparse variants. `sage_attn`: INT8/FP8 quantized Q·K; `sliding_tile_attn`: spatial-temporal tile skipping; others: model-specific sparse patterns. | ~1.5–2x on attention (varies by backend) | Quality degradation varies by backend and model. `sage_attn` is the most general; sparse backends (`sliding_tile_attn`, `video_sparse_attn`, etc.) are video-model-specific and may require config files (e.g. `--mask-strategy-file-path` for STA). Requires corresponding packages installed. |\n| **Cache-DiT** | `SGLANG_CACHE_DIT_ENABLED=true` + `--cache-dit-config <path>` | Caches intermediate residuals across denoising steps and skips redundant computations via a Selective Computation Mask (SCM). | ~1.5–2x on supported models | Quality depends on SCM config. Incompatible with `--dit-layerwise-offload`. Requires correct per-model config YAML. |\n| **Quantized Models (Nunchaku / SVDQuant)** | `--enable-svdquant --transformer-weights-path <path>` + optional `--quantization-precision int4\\|nvfp4`, `--quantization-rank 32` | W4A4-style quantization via [Nunchaku](https://nunchaku.tech). Reduces DiT weight memory by ~4x. Precision/rank can be auto-inferred from weight filename or set explicitly. 
| ~1.5–2x compute speedup | Lossy quantization; quality depends on rank and precision. Requires pre-quantized weights. Ampere (SM8x) or SM12x only (no Hopper SM90). Higher rank = better quality but more memory. |\n| **Pre-quantized Weights** | `--transformer-weights-path <path>` | Load any pre-quantized transformer weights (FP8, INT8, etc.) from a single `.safetensors` file, a directory, or a HuggingFace repo ID. | ~1.3–1.5x compute (dtype dependent) | Requires pre-converted weights (e.g. via `tools/convert_hf_to_fp8.py` for FP8). Quality slightly worse than BF16; varies by quantization format. |\n| **Component Precision Override** | `--dit-precision fp16`, `--vae-precision fp16\\|bf16` | On-the-fly dtype conversion for individual components. E.g. convert a BF16 model to FP16 at load time, or run VAE in BF16 instead of FP32. | Reduces memory; FP16 can be faster on some GPUs | May affect numerical stability. VAE is FP32 by default for accuracy; lowering it is lossy. DiT defaults to BF16. |\n| **Fewer Inference Steps** | `--num-inference-steps N` (sampling param) | Reduces the number of denoising steps. Fewer steps = faster. | Linear speedup | Quality degrades with too few steps. Model-dependent optimal range. |\n\n---\n\n## Quick Recipes\n\n### Maximum speed, video model, multi-GPU, lossless (Wan A14B, 8 GPUs)\n\n```bash\nsglang generate --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \\\n  --num-gpus 8 --enable-cfg-parallel --ulysses-degree 4 \\\n  --enable-torch-compile --warmup \\\n  --text-encoder-cpu-offload true \\\n  --prompt \"...\" --save-output\n```\n\nNote: `--dit-layerwise-offload` is enabled by default for Wan/MOVA video models and is zero-cost (H2D fully overlapped with compute). 
No need to disable it.\n\n### Maximum speed, image model, single GPU, lossless\n\n```bash\nsglang generate --model-path <IMAGE_MODEL> \\\n  --enable-torch-compile --warmup \\\n  --dit-layerwise-offload false \\\n  --prompt \"...\" --save-output\n```\n\nNote: for image models, per-layer compute is smaller, so layerwise offload may not fully hide H2D transfer. Disable it if VRAM allows.\n\n### Low VRAM, decent speed (single GPU)\n\n```bash\nsglang generate --model-path <MODEL> \\\n  --enable-torch-compile --warmup \\\n  --dit-layerwise-offload --dit-offload-prefetch-size 0.1 \\\n  --text-encoder-cpu-offload true --vae-cpu-offload true \\\n  --prompt \"...\" --save-output\n```\n\n### Maximum speed, lossy (SageAttention + Cache-DiT)\n\n```bash\nSGLANG_CACHE_DIT_ENABLED=true sglang generate --model-path <MODEL> \\\n  --attention-backend sage_attn \\\n  --cache-dit-config <config.yaml> \\\n  --enable-torch-compile --warmup \\\n  --dit-layerwise-offload false \\\n  --prompt \"...\" --save-output\n```\n\n---\n\n## Tips\n\n- **Benchmarking**: always use `--warmup` and look for the line ending with `(with warmup excluded)` for accurate timing.\n- **Perf dump**: use `--perf-dump-path result.json` to save structured metrics, then compare with `python python/sglang/multimodal_gen/benchmarks/compare_perf.py baseline.json result.json`.\n- **Offload tuning**: after the first request, the runtime logs peak GPU memory and which components could stay resident. Use this to decide which `--*-cpu-offload` flags to disable.\n- **Backend selection**: `--backend sglang` (default, auto-detected) enables all native optimizations (fused kernels, SP, etc.). `--backend diffusers` falls back to vanilla Diffusers pipelines but supports `--cache-dit-config` and diffusers attention backends.\n"
  },
  {
    "path": "python/sglang/multimodal_gen/.claude/skills/support-new-model/SKILL.md",
    "content": "---\nname: add-new-diffusion-model\ndescription: Step-by-step guide for adding a new diffusion model to SGLang. Covers the recommended Hybrid Monolithic pipeline pattern (BeforeDenoisingStage), as well as when to use the Modular Composition Style. Includes pipeline config, model components, registration, and testing.\n---\n\n# Tutorial: Adding a New Diffusion Model to SGLang\n\nThis tutorial walks through adding support for a new diffusion model. SGLang Diffusion supports two pipeline styles; choose the one that best fits your model.\n\n## Two Pipeline Styles\n\n### Style A: Hybrid Monolithic Pipeline (Recommended)\n\nThe recommended default for most new models. Uses a three-stage structure:\n\n```\nBeforeDenoisingStage (model-specific)  -->  DenoisingStage (standard)  -->  DecodingStage (standard)\n```\n\n- **BeforeDenoisingStage**: A single, model-specific stage that consolidates all pre-processing logic: input validation, text encoding, image encoding, latent preparation, timestep setup. This stage is unique per model.\n- **DenoisingStage**: Framework-standard stage for the denoising loop (DiT/UNet forward passes). Shared across models.\n- **DecodingStage**: Framework-standard stage for VAE decoding. Shared across models.\n\n**Why recommended?** Modern diffusion models have highly heterogeneous pre-processing requirements (different text encoders, different latent formats, different conditioning mechanisms). The Hybrid approach keeps pre-processing isolated per model, avoids fragile shared stages with excessive conditional logic, and lets developers port Diffusers reference code quickly.\n\n### Style B: Modular Composition Style\n\nUses the framework's fine-grained standard stages (`TextEncodingStage`, `LatentPreparationStage`, `TimestepPreparationStage`, etc.) 
to build the pipeline by composition.\n\nThis style is appropriate when:\n- **The new model's pre-processing can largely reuse existing stages** — e.g., a model that uses standard CLIP/T5 text encoding + standard latent preparation with minimal customization. In this case, `add_standard_t2i_stages()` or `add_standard_ti2i_stages()` may be all you need.\n- **A model-specific optimization needs to be extracted as a standalone stage** — e.g., a specialized encoding or conditioning step that benefits from being a separate stage for profiling, parallelism control, or reuse across multiple pipeline variants.\n\nSee existing Modular examples: `QwenImagePipeline` (uses `add_standard_t2i_stages`), `FluxPipeline`, `WanPipeline`.\n\n### How to Choose\n\n| Situation | Recommended Style |\n|-----------|-------------------|\n| Model has unique/complex pre-processing (VLM captioning, AR token generation, custom latent packing, etc.) | **Hybrid** — consolidate into a BeforeDenoisingStage |\n| Model fits neatly into standard text-to-image or text+image-to-image pattern | **Modular** — use `add_standard_t2i_stages()` / `add_standard_ti2i_stages()` |\n| Porting a Diffusers pipeline with many custom steps | **Hybrid** — copy the `__call__` logic into a single stage |\n| Adding a variant of an existing model that shares most logic | **Modular** — reuse existing stages, customize via PipelineConfig callbacks |\n| A specific pre-processing step needs special parallelism or profiling isolation | **Modular** — extract that step as a dedicated stage |\n\n**Key principle (both styles)**: The stage(s) before `DenoisingStage` must produce a `Req` batch object with all the standard tensor fields that `DenoisingStage` expects (latents, timesteps, prompt_embeds, etc.). 
As long as this contract is met, the pipeline remains composable regardless of which style you use.\n\n---\n\n## Key Files and Directories\n\n| Purpose | Path |\n|---------|------|\n| Pipeline classes | `python/sglang/multimodal_gen/runtime/pipelines/` |\n| Model-specific stages | `python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/` |\n| PipelineStage base class | `python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py` |\n| Pipeline base class | `python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py` |\n| Standard stages (Denoising, Decoding) | `python/sglang/multimodal_gen/runtime/pipelines_core/stages/` |\n| Pipeline configs | `python/sglang/multimodal_gen/configs/pipeline_configs/` |\n| Sampling params | `python/sglang/multimodal_gen/configs/sample/` |\n| DiT model implementations | `python/sglang/multimodal_gen/runtime/models/dits/` |\n| VAE implementations | `python/sglang/multimodal_gen/runtime/models/vaes/` |\n| Encoder implementations | `python/sglang/multimodal_gen/runtime/models/encoders/` |\n| Scheduler implementations | `python/sglang/multimodal_gen/runtime/models/schedulers/` |\n| Model/VAE/DiT configs | `python/sglang/multimodal_gen/configs/models/dits/`, `vaes/`, `encoders/` |\n| Central registry | `python/sglang/multimodal_gen/registry.py` |\n\n---\n\n## Step-by-Step Implementation\n\n### Step 1: Obtain and Study the Reference Implementation\n\n**Before writing any code, ask the user to provide the model's original implementation or Diffusers pipeline code.** You need the actual source code to work from — do not guess or assume the model's architecture. 
If the user has not provided it, request:\n- The model's Diffusers pipeline source (e.g., the `pipeline_*.py` file from the `diffusers` library or HuggingFace repo)\n- Or the model's official reference implementation (e.g., from the model author's GitHub repo)\n- Or the HuggingFace model ID so you can look up `model_index.json` and the associated pipeline class\n\nOnce you have the reference code, study it thoroughly:\n\n1. Find the model's `model_index.json` to identify required modules (text_encoder, vae, transformer, scheduler, etc.)\n2. Read the Diffusers pipeline's `__call__` method end-to-end. Identify:\n   - How text prompts are encoded\n   - How latents are prepared (shape, dtype, scaling)\n   - How timesteps/sigmas are computed\n   - What conditioning kwargs the DiT/UNet expects\n   - How the denoising loop works (classifier-free guidance, etc.)\n   - How VAE decoding is done (scaling factors, tiling, etc.)\n\n### Step 2: Evaluate Reuse of Existing Pipelines and Stages\n\n**Before creating any new files, check whether an existing pipeline or stage can be reused or extended.** Only create new pipelines/stages when the existing ones would require extensive modifications or when no similar implementation exists.\n\nSpecifically:\n1. **Compare the new model's architecture against existing pipelines** (Flux, Wan, Qwen-Image, GLM-Image, HunyuanVideo, LTX, etc.). If the new model shares most of its structure with an existing one (e.g., same text encoders, similar latent format, compatible denoising loop), prefer:\n   - Adding a new config variant to the existing pipeline rather than creating a new pipeline class\n   - Reusing the existing `BeforeDenoisingStage` with minor parameter differences\n   - Using `add_standard_t2i_stages()` / `add_standard_ti2i_stages()` / `add_standard_ti2v_stages()` if the model fits standard patterns\n2. **Check existing stages** in `runtime/pipelines_core/stages/` and `stages/model_specific_stages/`. 
If an existing stage handles 80%+ of what the new model needs, extend it rather than duplicating it.\n3. **Check existing model components** — many models share VAEs (e.g., `AutoencoderKL`), text encoders (CLIP, T5), and schedulers. Reuse these directly instead of re-implementing.\n\n**Rule of thumb**: Only create a new file when the existing implementation would need substantial structural changes to accommodate the new model, or when no architecturally similar implementation exists.\n\n### Step 3: Implement Model Components\n\nAdapt or implement the model's core components in the appropriate directories.\n\n**DiT/Transformer** (`runtime/models/dits/{model_name}.py`):\n\n```python\n# python/sglang/multimodal_gen/runtime/models/dits/my_model.py\n\nimport torch\nimport torch.nn as nn\n\nfrom sglang.multimodal_gen.runtime.layers.layernorm import (\n    LayerNormScaleShift,\n    RMSNormScaleShift,\n)\nfrom sglang.multimodal_gen.runtime.layers.attention.selector import (\n    get_attn_backend,\n)\n\n\nclass MyModelTransformer2DModel(nn.Module):\n    \"\"\"DiT model for MyModel.\n\n    Adapt from the Diffusers/reference implementation. Key points:\n    - Use SGLang's fused LayerNorm/RMSNorm ops (see use-efficient-diffusion-kernels skill)\n    - Use SGLang's attention backend selector\n    - Keep the same parameter naming as Diffusers for weight loading compatibility\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        # ... model layers ...\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        timestep: torch.Tensor,\n        # ... model-specific kwargs ...\n    ) -> torch.Tensor:\n        # ... forward pass ...\n        return output\n```\n\n**Tensor Parallel (TP) and Sequence Parallel (SP)**: For multi-GPU deployment, it is recommended to add TP/SP support to the DiT model. This can be done incrementally after the single-GPU implementation is verified. 
Reference existing implementations and adapt to your model's architecture:\n\n- **Wan model** (`runtime/models/dits/wanvideo.py`) — Full TP + SP reference:\n  - TP: Uses `ColumnParallelLinear` for Q/K/V projections, `RowParallelLinear` for output projections, attention heads divided by `tp_size`\n  - SP: Sequence dimension sharding via `get_sp_world_size()`, padding for alignment, `sequence_model_parallel_all_gather` for aggregation\n  - Cross-attention skips SP (`skip_sequence_parallel=is_cross_attention`)\n- **Qwen-Image model** (`runtime/models/dits/qwen_image.py`) — SP + USPAttention reference:\n  - SP: Uses `USPAttention` (Ulysses + Ring Attention), configured via `--ulysses-degree` / `--ring-degree`\n  - TP: Uses `MergedColumnParallelLinear` for QKV (with Nunchaku quantization), `ReplicatedLinear` otherwise\n\n**Important**: These are references only — each model has its own architecture and parallelism requirements. Consider:\n- How attention heads can be divided across TP ranks\n- Whether the model's sequence dimension is naturally shardable for SP\n- Which linear layers benefit from column/row parallel sharding vs. replication\n- Whether cross-attention or other special modules need SP exclusion\n\nKey imports for distributed support:\n```python\nfrom sglang.multimodal_gen.runtime.distributed import (\n    divide,\n    get_sp_group,\n    get_sp_world_size,\n    get_tp_world_size,\n    sequence_model_parallel_all_gather,\n)\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    ColumnParallelLinear,\n    RowParallelLinear,\n    ReplicatedLinear,\n)\n```\n\n**VAE** (`runtime/models/vaes/{model_name}.py`): Implement if the model uses a non-standard VAE. 
Many models reuse existing VAEs.\n\n**Encoders** (`runtime/models/encoders/{model_name}.py`): Implement if the model uses custom text/image encoders.\n\n**Schedulers** (`runtime/models/schedulers/{scheduler_name}.py`): Implement if the model requires a custom scheduler not available in Diffusers.\n\n### Step 4: Create Model Configs\n\n**DiT Config** (`configs/models/dits/{model_name}.py`):\n\n```python\n# python/sglang/multimodal_gen/configs/models/dits/mymodel.py\n\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTConfig\n\n\n@dataclass\nclass MyModelDitConfig(DiTConfig):\n    arch_config: dict = field(default_factory=lambda: {\n        \"in_channels\": 16,\n        \"num_layers\": 24,\n        \"patch_size\": 2,\n        # ... model-specific architecture params ...\n    })\n```\n\n**VAE Config** (`configs/models/vaes/{model_name}.py`):\n\n```python\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.vaes.base import VAEConfig\n\n\n@dataclass\nclass MyModelVAEConfig(VAEConfig):\n    vae_scale_factor: int = 8\n    # ... VAE-specific params ...\n```\n\n**Sampling Params** (`configs/sample/{model_name}.py`):\n\n```python\nfrom dataclasses import dataclass\n\nfrom sglang.multimodal_gen.configs.sample.base import SamplingParams\n\n\n@dataclass\nclass MyModelSamplingParams(SamplingParams):\n    num_inference_steps: int = 50\n    guidance_scale: float = 7.5\n    height: int = 1024\n    width: int = 1024\n    # ... 
model-specific defaults ...\n```\n\n### Step 5: Create PipelineConfig\n\nThe `PipelineConfig` holds static model configuration and defines callback methods used by the standard `DenoisingStage` and `DecodingStage`.\n\n```python\n# python/sglang/multimodal_gen/configs/pipeline_configs/my_model.py\n\nfrom dataclasses import dataclass, field\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTConfig\nfrom sglang.multimodal_gen.configs.models.vaes.base import VAEConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    ImagePipelineConfig,      # for image generation\n    ModelTaskType,\n    # SpatialImagePipelineConfig,  # alternative base\n    # VideoPipelineConfig,         # for video generation\n)\nfrom sglang.multimodal_gen.configs.models.dits.mymodel import MyModelDitConfig\nfrom sglang.multimodal_gen.configs.models.vaes.mymodel import MyModelVAEConfig\n\n\n@dataclass\nclass MyModelPipelineConfig(ImagePipelineConfig):\n    \"\"\"Pipeline config for MyModel.\n\n    This config provides callbacks that the standard DenoisingStage and\n    DecodingStage use during execution. 
The BeforeDenoisingStage handles\n    all model-specific pre-processing independently.\n    \"\"\"\n\n    task_type: ModelTaskType = ModelTaskType.T2I\n    vae_precision: str = \"bf16\"\n    should_use_guidance: bool = True\n    vae_tiling: bool = False\n    enable_autocast: bool = False\n\n    dit_config: DiTConfig = field(default_factory=MyModelDitConfig)\n    vae_config: VAEConfig = field(default_factory=MyModelVAEConfig)\n\n    # --- Callbacks used by DenoisingStage ---\n\n    def get_freqs_cis(self, batch, device, rotary_emb, dtype):\n        \"\"\"Prepare rotary position embeddings for the DiT.\"\"\"\n        # Model-specific RoPE computation\n        ...\n        return freqs_cis\n\n    def prepare_pos_cond_kwargs(self, batch, latent_model_input, t, **kwargs):\n        \"\"\"Build positive conditioning kwargs for each denoising step.\"\"\"\n        return {\n            \"hidden_states\": latent_model_input,\n            \"encoder_hidden_states\": batch.prompt_embeds[0],\n            \"timestep\": t,\n            
# ... model-specific kwargs ...\n        }\n\n    def prepare_neg_cond_kwargs(self, batch, latent_model_input, t, **kwargs):\n        \"\"\"Build negative conditioning kwargs for CFG.\"\"\"\n        return {\n            \"hidden_states\": latent_model_input,\n            \"encoder_hidden_states\": batch.negative_prompt_embeds[0],\n            \"timestep\": t,\n            # ... model-specific kwargs ...\n        }\n\n    # --- Callbacks used by DecodingStage ---\n\n    def get_decode_scale_and_shift(self):\n        \"\"\"Return (scale, shift) for latent denormalization before VAE decode.\"\"\"\n        return self.vae_config.latents_std, self.vae_config.latents_mean\n\n    def post_denoising_loop(self, latents, batch):\n        \"\"\"Optional post-processing after the denoising loop finishes.\"\"\"\n        return latents.to(torch.bfloat16)\n\n    def post_decoding(self, frames, server_args):\n        \"\"\"Optional post-processing after VAE decoding.\"\"\"\n        return frames\n```\n\n**Important**: The `prepare_pos_cond_kwargs` / `prepare_neg_cond_kwargs` methods define what the DiT receives at each denoising step. These must match the DiT's `forward()` signature.\n\n### Step 6: Implement the BeforeDenoisingStage (Core Step)\n\nThis is the heart of the Hybrid pattern. 
Create a single stage that handles ALL pre-processing.\n\n```python\n# python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/my_model.py\n\nimport torch\nfrom typing import List, Optional, Union\n\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass MyModelBeforeDenoisingStage(PipelineStage):\n    \"\"\"Monolithic pre-processing stage for MyModel.\n\n    Consolidates all logic before the denoising loop:\n    - Input validation\n    - Text/image encoding\n    - Latent preparation\n    - Timestep/sigma computation\n\n    This stage produces a Req batch with all fields required by\n    the standard DenoisingStage.\n    \"\"\"\n\n    def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):\n        super().__init__()\n        self.vae = vae\n        self.text_encoder = text_encoder\n        self.tokenizer = tokenizer\n        self.transformer = transformer\n        self.scheduler = scheduler\n        # ... other initialization (image processors, scale factors, etc.) ...\n\n    # --- Internal helper methods ---\n    # Copy/adapt directly from the Diffusers reference pipeline.\n    # These are private to this stage; no need to make them reusable.\n\n    def _encode_prompt(self, prompt, device, dtype):\n        \"\"\"Encode text prompt into embeddings.\"\"\"\n        # ... model-specific text encoding logic ...\n        return prompt_embeds, negative_prompt_embeds\n\n    def _prepare_latents(self, batch_size, height, width, dtype, device, generator):\n        \"\"\"Create initial noisy latents.\"\"\"\n        # ... 
model-specific latent preparation ...\n        return latents\n\n    def _prepare_timesteps(self, num_inference_steps, device):\n        \"\"\"Compute the timestep/sigma schedule.\"\"\"\n        # ... model-specific timestep computation ...\n        return timesteps, sigmas\n\n    # --- Main forward method ---\n\n    @torch.no_grad()\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        \"\"\"Execute all pre-processing and populate batch for DenoisingStage.\n\n        This method mirrors the first half of a Diffusers pipeline __call__,\n        up to (but not including) the denoising loop.\n        \"\"\"\n        device = get_local_torch_device()\n        dtype = torch.bfloat16\n        generator = torch.Generator(device=device).manual_seed(batch.seed)\n\n        # 1. Encode prompt\n        prompt_embeds, negative_prompt_embeds = self._encode_prompt(\n            batch.prompt, device, dtype\n        )\n\n        # 2. Prepare latents\n        latents = self._prepare_latents(\n            batch_size=1,\n            height=batch.height,\n            width=batch.width,\n            dtype=dtype,\n            device=device,\n            generator=generator,\n        )\n\n        # 3. Prepare timesteps\n        timesteps, sigmas = self._prepare_timesteps(\n            batch.num_inference_steps, device\n        )\n\n        # 4. 
Populate batch with everything DenoisingStage needs\n        batch.prompt_embeds = [prompt_embeds]\n        batch.negative_prompt_embeds = [negative_prompt_embeds]\n        batch.latents = latents\n        batch.timesteps = timesteps\n        batch.num_inference_steps = len(timesteps)\n        batch.sigmas = sigmas  # must be a Python list; call .tolist() on numpy arrays/tensors\n        batch.generator = generator\n        batch.raw_latent_shape = latents.shape\n        # batch.height / batch.width already hold the requested size; reassign them\n        # here only if pre-processing changes the output resolution.\n\n        return batch\n```\n\n**Key fields that `DenoisingStage` expects on the batch** (set these in your `forward`):\n\n| Field | Type | Description |\n|-------|------|-------------|\n| `batch.latents` | `torch.Tensor` | Initial noisy latent tensor |\n| `batch.timesteps` | `torch.Tensor` | Timestep schedule |\n| `batch.num_inference_steps` | `int` | Number of denoising steps |\n| `batch.sigmas` | `list[float]` | Sigma schedule (as a list, not numpy) |\n| `batch.prompt_embeds` | `list[torch.Tensor]` | Positive prompt embeddings (wrapped in list) |\n| `batch.negative_prompt_embeds` | `list[torch.Tensor]` | Negative prompt embeddings (wrapped in list) |\n| `batch.generator` | `torch.Generator` | RNG generator for reproducibility |\n| `batch.raw_latent_shape` | `tuple` | Original latent shape before any packing |\n| `batch.height` / `batch.width` | `int` | Output dimensions |\n\n### Step 7: Define the Pipeline Class\n\nThe pipeline class is minimal -- it just wires the stages together.\n\n```python\n# python/sglang/multimodal_gen/runtime/pipelines/my_model.py\n\nfrom sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import 
DenoisingStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.model_specific_stages.my_model import (\n    MyModelBeforeDenoisingStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args 
import ServerArgs\n\n\nclass MyModelPipeline(LoRAPipeline, ComposedPipelineBase):\n    pipeline_name = \"MyModelPipeline\"  # Must match model_index.json _class_name\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"tokenizer\",\n        \"vae\",\n        \"transformer\",\n        \"scheduler\",\n        # ... list all modules from model_index.json ...\n    ]\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        # 1. Monolithic pre-processing (model-specific)\n        self.add_stage(\n            MyModelBeforeDenoisingStage(\n                vae=self.get_module(\"vae\"),\n                text_encoder=self.get_module(\"text_encoder\"),\n                tokenizer=self.get_module(\"tokenizer\"),\n                transformer=self.get_module(\"transformer\"),\n                scheduler=self.get_module(\"scheduler\"),\n            ),\n        )\n\n        # 2. Standard denoising loop (framework-provided)\n        self.add_stage(\n            DenoisingStage(\n                transformer=self.get_module(\"transformer\"),\n                scheduler=self.get_module(\"scheduler\"),\n            ),\n        )\n\n        # 3. 
Standard VAE decoding (framework-provided)\n        self.add_standard_decoding_stage()\n\n\n# REQUIRED: This is how the registry discovers the pipeline\nEntryClass = [MyModelPipeline]\n```\n\n### Step 8: Register the Model\n\nIn `python/sglang/multimodal_gen/registry.py`, register your configs:\n\n```python\nregister_configs(\n    model_family=\"my_model\",\n    sampling_param_cls=MyModelSamplingParams,\n    pipeline_config_cls=MyModelPipelineConfig,\n    hf_model_paths=[\n        \"org/my-model-name\",  # HuggingFace model ID(s)\n    ],\n)\n```\n\nThe `EntryClass` in your pipeline file is automatically discovered by the registry's `_discover_and_register_pipelines()` function -- no additional registration needed for the pipeline class itself.\n\n### Step 9: Verify Output Quality\n\nAfter implementation, **you must verify that the generated output is not noise**. A noisy or garbled output image/video is the most common sign of an incorrect implementation. Common causes include:\n\n- Incorrect latent scale/shift factors (`get_decode_scale_and_shift` returning wrong values)\n- Wrong timestep/sigma schedule (order, dtype, or value range)\n- Mismatched conditioning kwargs (fields not matching the DiT's `forward()` signature)\n- Incorrect VAE decoder configuration (wrong `vae_scale_factor`, missing denormalization)\n- Rotary embedding style mismatch (`is_neox_style` set incorrectly)\n- Wrong prompt embedding format (missing list wrapping, wrong encoder output selection)\n\n**If the output is noise, the implementation is incorrect — do not ship it.** Debug by:\n1. Comparing intermediate tensor values (latents, prompt_embeds, timesteps) against the Diffusers reference pipeline\n2. Running the Diffusers pipeline and SGLang pipeline side-by-side with the same seed\n3. 
Checking each stage's output shape and value range independently\n\n## Reference Implementations\n\n### Hybrid Style (recommended for most new models)\n\n| Model | Pipeline | BeforeDenoisingStage | PipelineConfig |\n|-------|----------|---------------------|----------------|\n| GLM-Image | `runtime/pipelines/glm_image.py` | `stages/model_specific_stages/glm_image.py` | `configs/pipeline_configs/glm_image.py` |\n| Qwen-Image-Layered | `runtime/pipelines/qwen_image.py` (`QwenImageLayeredPipeline`) | `stages/model_specific_stages/qwen_image_layered.py` | `configs/pipeline_configs/qwen_image.py` (`QwenImageLayeredPipelineConfig`) |\n\n### Modular Style (when standard stages fit well)\n\n| Model | Pipeline | Notes |\n|-------|----------|-------|\n| Qwen-Image (T2I) | `runtime/pipelines/qwen_image.py` | Uses `add_standard_t2i_stages()` — standard text encoding + latent prep fits this model |\n| Qwen-Image-Edit | `runtime/pipelines/qwen_image.py` | Uses `add_standard_ti2i_stages()` — standard image-to-image flow |\n| Flux | `runtime/pipelines/flux.py` | Uses `add_standard_t2i_stages()` with custom `prepare_mu` |\n| Wan | `runtime/pipelines/wan_pipeline.py` | Uses `add_standard_ti2v_stages()` |\n\n---\n\n## Checklist\n\nBefore submitting, verify:\n\n**Common (both styles):**\n- [ ] **Pipeline file** exists at `runtime/pipelines/{model_name}.py` with `EntryClass`\n- [ ] **PipelineConfig** at `configs/pipeline_configs/{model_name}.py`\n- [ ] **SamplingParams** at `configs/sample/{model_name}.py`\n- [ ] **DiT model** at `runtime/models/dits/{model_name}.py`\n- [ ] **DiT config** at `configs/models/dits/{model_name}.py`\n- [ ] **VAE** — reuse existing (e.g., `AutoencoderKL`) or create new at `runtime/models/vaes/`\n- [ ] **VAE config** — reuse existing or create new at `configs/models/vaes/{model_name}.py`\n- [ ] **Registry entry** in `registry.py` via `register_configs()`\n- [ ] `pipeline_name` matches Diffusers `model_index.json` `_class_name`\n- [ ] 
`_required_config_modules` lists all modules from `model_index.json`\n- [ ] `PipelineConfig` callbacks (`prepare_pos_cond_kwargs`, `get_freqs_cis`, etc.) match DiT's `forward()` signature\n- [ ] Latent scale/shift factors are correctly configured\n- [ ] Use fused kernels where possible (see `use-efficient-diffusion-kernels` skill)\n- [ ] Weight names match Diffusers for automatic loading\n- [ ] **TP/SP support** considered for DiT model (recommended; reference `wanvideo.py` for TP+SP, `qwen_image.py` for USPAttention)\n- [ ] **Output quality verified** — generated images/videos are not noise; compared against Diffusers reference output\n\n**Hybrid style only:**\n- [ ] **BeforeDenoisingStage** at `stages/model_specific_stages/{model_name}.py`\n- [ ] `BeforeDenoisingStage.forward()` populates all fields needed by `DenoisingStage`\n\n## Common Pitfalls\n\n1. **`batch.sigmas` must be a Python list**, not a numpy array. Use `.tolist()` to convert.\n2. **`batch.prompt_embeds` is a list of tensors** (one per encoder), not a single tensor. Wrap with `[tensor]`.\n3. **Don't forget `batch.raw_latent_shape`** -- `DecodingStage` uses it to unpack latents.\n4. **Rotary embedding style matters**: `is_neox_style=True` = split-half rotation, `is_neox_style=False` = interleaved. Check the reference model carefully.\n5. **VAE precision**: Many VAEs need fp32 or bf16 for numerical stability. Set `vae_precision` in the PipelineConfig accordingly.\n6. **Avoid forcing model-specific logic into shared stages**: If your model's pre-processing doesn't naturally fit the existing standard stages, prefer the Hybrid pattern with a dedicated BeforeDenoisingStage rather than adding conditional branches to shared stages.\n\n## After Implementation: Tests and Performance Data\n\nOnce the model is working and output quality is verified, **ask the user** whether they would like to:\n\n1. **Add tests** — Create unit tests and/or integration tests for the new model. 
Tests should cover:\n   - Pipeline construction and stage wiring\n   - Single-GPU inference producing non-noise output\n   - Multi-GPU inference (TP/SP) if supported\n   - See the `write-sglang-test` skill for test conventions and placement guidelines\n\n2. **Generate performance data** — Run benchmarks and collect perf metrics:\n   - Single-GPU latency and throughput (look for `Pixel data generated successfully in xxxx seconds` in console output; use the `warmup excluded` line for accurate timing)\n   - Multi-GPU scaling (TP/SP) throughput comparison\n   - Use `python/sglang/multimodal_gen/benchmarks/bench_serving.py` for serving benchmarks\n\nDo not skip this step — always ask the user before proceeding, as test and benchmark requirements vary per model.\n"
  },
  {
    "path": "python/sglang/multimodal_gen/README.md",
    "content": "<div align=\"center\"  style=\"display:block; margin:auto;\">\n<img src=https://github.com/lm-sys/lm-sys.github.io/releases/download/test/sgl-diffusion-logo.png width=\"80%\"/>\n</div>\n\n**SGLang diffusion is an inference framework for accelerated image/video generation.**\n\nSGLang diffusion features an end-to-end unified pipeline for accelerating diffusion models. It is designed to be modular and extensible, allowing users to easily add new models and optimizations.\n\n## Key Features\n\nSGLang Diffusion has the following features:\n  - Broad model support: Wan series, FastWan series, Hunyuan, Qwen-Image, Qwen-Image-Edit, Flux, Z-Image, GLM-Image\n  - Fast inference speed: empowered by highly optimized kernels from sgl-kernel and an efficient scheduler loop\n  - Ease of use: OpenAI-compatible API, CLI, and Python SDK support\n  - Multi-platform support:\n    - NVIDIA GPUs (H100, H200, A100, B200, 4090)\n    - AMD GPUs (MI300X, MI325X)\n    - Ascend NPU (A2, A3)\n    - Apple Silicon (M-series via MPS)\n    - Moore Threads GPUs (MTT S5000)\n\n### AMD/ROCm Support\n\nSGLang Diffusion supports AMD Instinct GPUs through ROCm. On AMD platforms, we use the Triton attention backend and leverage AITER kernels for optimized layernorm and other operations. See the [installation guide](https://github.com/sgl-project/sglang/tree/main/docs/diffusion/installation.md) for setup instructions.\n\n### Moore Threads/MUSA Support\n\nSGLang Diffusion supports Moore Threads GPUs (MTGPU) through the MUSA software stack. On MUSA platforms, we use the Torch SDPA backend for attention. See the [installation guide](https://github.com/sgl-project/sglang/tree/main/docs/diffusion/installation.md) for setup instructions.\n\n### Apple MPS Support\n\nSGLang Diffusion supports Apple Silicon (M-series) via the MPS backend. Since Triton is Linux-only, all Triton kernels are replaced with PyTorch-native fallbacks on MPS. 
Norm operations can be optionally accelerated with MLX fused Metal kernels (`SGLANG_USE_MLX=1`). See the [installation guide](https://github.com/sgl-project/sglang/tree/main/docs/diffusion/installation.md) for setup instructions.\n\n## Getting Started\n\n```bash\nuv pip install 'sglang[diffusion]' --prerelease=allow\n```\n\nFor more installation methods (e.g., PyPI, uv, Docker, ROCm/AMD, MUSA/Moore Threads), check [install.md](https://github.com/sgl-project/sglang/tree/main/docs/diffusion/installation.md).\n\n## Inference\n\nHere's a minimal example to generate a video using the default settings:\n\n```python\nfrom sglang.multimodal_gen import DiffGenerator\n\ndef main():\n    # Create a DiffGenerator from a pre-trained model\n    generator = DiffGenerator.from_pretrained(\n        model_path=\"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\",\n        num_gpus=1,  # Adjust based on your hardware\n    )\n\n    # Generate the video\n    video = generator.generate(\n        sampling_params_kwargs=dict(\n            prompt=\"A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest.\",\n            return_frames=True,  # Also return frames from this call (defaults to False)\n            output_path=\"my_videos/\",  # Controls where videos are saved\n            save_output=True,\n        )\n    )\n\nif __name__ == '__main__':\n    main()\n```\n\nOr, more simply, with the CLI:\n\n```bash\nsglang generate --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \\\n    --text-encoder-cpu-offload --pin-cpu-memory \\\n    --prompt \"A curious raccoon\" \\\n    --save-output\n```\n\n### LoRA support\n\nApply LoRA adapters via `--lora-path`:\n\n```bash\nsglang generate \\\n  --model-path Qwen/Qwen-Image-Edit-2511 \\\n  --lora-path prithivMLmods/Qwen-Image-Edit-2511-Anime \\\n  --prompt \"Transform into anime.\" \\\n  --image-path \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png\" \\\n  --save-output\n```\n\nFor 
more usage examples (e.g., OpenAI-compatible API, server mode), check [cli.md](https://github.com/sgl-project/sglang/tree/main/docs/diffusion/cli.md).\n\n## Contributing\n\nAll contributions are welcome. The contribution guide is available [here](https://github.com/sgl-project/sglang/tree/main/docs/diffusion/contributing.md).\n\n## Acknowledgement\n\nWe learned from and reused code from the following projects:\n\n- [FastVideo](https://github.com/hao-ai-lab/FastVideo.git). The major components of this repo are based on a fork of FastVideo taken on Sept. 24, 2025.\n- [xDiT](https://github.com/xdit-project/xDiT). We used its parallelism library.\n- [diffusers](https://github.com/huggingface/diffusers). We adopted its pipeline design.\n"
  },
  {
    "path": "python/sglang/multimodal_gen/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\nfrom sglang.multimodal_gen.configs.pipeline_configs import PipelineConfig\nfrom sglang.multimodal_gen.configs.sample import SamplingParams\nfrom sglang.multimodal_gen.runtime.entrypoints.diffusion_generator import DiffGenerator\n\n__all__ = [\"DiffGenerator\", \"PipelineConfig\", \"SamplingParams\"]\n\n# Trigger multimodal CI tests\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/README.md",
    "content": "# ComfyUI SGLDiffusion Plugin\n\nA ComfyUI plugin for integrating with the SGLang Diffusion server, supporting image and video generation.\n\n## Installation\n\n1. **Install SGLang**: Follow the [Installation Guide](../../docs/install.md) to install `sglang[diffusion]`.\n2. **Install Plugin**: Copy this entire directory (`ComfyUI_SGLDiffusion`) to your ComfyUI `custom_nodes/` folder.\n3. **Restart ComfyUI**: Restart ComfyUI to load the plugin.\n\n## Usage\n\nThe plugin supports two modes of operation: **Server Mode** (via HTTP API) and **Integrated Mode** (tight integration with ComfyUI).\n\n### Supported Models\n- **Z-Image**: High-speed image generation models (e.g., `Z-Image-Turbo`)\n- **FLUX**: State-of-the-art text-to-image models (e.g., `FLUX.1-dev`)\n- **Qwen-Image**: Multi-modal image generation models (e.g., `Qwen-Image`, `Qwen-Image-2512`). *Note: Image editing support is currently experimental and may have some issues.*\n\n### Mode 1: Server Mode (HTTP API)\nConnect to a standalone SGLang Diffusion server.\n\n1. **Start SGLang Diffusion Server**: Ensure the server is running and accessible.\n2. **Connect to Server**: Use the `SGLDiffusion Server Model` node to connect (default: `http://localhost:3000/v1`).\n3. **Generate Content**:\n   - `SGLDiffusion Generate Image`: For text-to-image and image editing.\n   - `SGLDiffusion Generate Video`: For text-to-video and image-to-video.\n4. **LoRA Support**: Use `SGLDiffusion Server Set LoRA` and `SGLDiffusion Server Unset LoRA`.\n\n### Mode 2: Integrated Mode (Tight Integration)\nLeverage SGLang's high-performance sampling directly within ComfyUI while using ComfyUI's front-end nodes (CLIP, VAE, etc.).\n\n1. **Load Model**: Use the `SGLDiffusion UNET Loader` node to load your diffusion model.\n2. **Configure Options**: Use the `SGLDiffusion Options` node to set runtime parameters like `num_gpus`, `tp_size`, `model_type`, or `enable_torch_compile`.\n3. 
**Sample**: Connect the loaded model to standard ComfyUI samplers. SGLang will handle the sampling process efficiently.\n4. **LoRA Support**: Use the `SGLDiffusion LoRA Loader` for native LoRA integration.\n\n## Example Workflows\n\nReference workflow files are provided in the `workflows/` directory:\n\n- **`flux_sgld_sp.json`**: Multi-GPU (Sequence Parallelism) workflow for FLUX models. High-performance inference across multiple cards.\n- **`qwen_image_sgld.json`**: Qwen-Image generation with LoRA support. Optimized for multi-modal image tasks.\n- **`z-image_sgld.json`**: High-speed image generation using Z-Image.\n- **`sgld_text2img.json`**: Server-mode text-to-image generation with LoRA support.\n- **`sgld_image2video.json`**: Server-mode image-to-video generation.\n\nFor other workflows supporting the models, you can easily use SGLang by replacing the official `UNET Loader` node with the `SGLDUNETLoader` node. Similarly, for LoRA support, replace the official LoRA loader with the `SGLDiffusion LoRA Loader`.\n\nTo use these workflows:\n1. Open ComfyUI.\n2. Load the workflow JSON file from the `workflows/` directory.\n3. Adjust the parameters and model paths as needed.\n4. Run the workflow.\n\n\n## Current Implementation\n\nThis plugin provides a high-performance backend for diffusion models in ComfyUI. By leveraging SGLang's optimized kernels and parallelization techniques (Tensor Parallelism, TeaCache, etc.), it significantly accelerates the sampling process, especially for large models like FLUX.\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/__init__.py",
    "content": "\"\"\"\nComfyUI SGLang Diffusion nodes package.\n\"\"\"\n\ntry:\n    from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS\n\n    __all__ = [\"NODE_CLASS_MAPPINGS\", \"NODE_DISPLAY_NAME_MAPPINGS\"]\nexcept ImportError:\n    # ComfyUI dependencies not available (e.g., in test environment)\n    NODE_CLASS_MAPPINGS = {}\n    NODE_DISPLAY_NAME_MAPPINGS = {}\n    __all__ = [\"NODE_CLASS_MAPPINGS\", \"NODE_DISPLAY_NAME_MAPPINGS\"]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/__init__.py",
    "content": "\"\"\"\nCore components for SGLang Diffusion ComfyUI integration.\nProvides generator, model patcher, and server API client.\n\"\"\"\n\nfrom .generator import SGLDiffusionGenerator\nfrom .model_patcher import SGLDModelPatcher\nfrom .server_api import SGLDiffusionServerAPI\n\n__all__ = [\n    \"SGLDiffusionGenerator\",\n    \"SGLDModelPatcher\",\n    \"SGLDiffusionServerAPI\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/generator.py",
    "content": "\"\"\"\nGenerator for SGLang Diffusion ComfyUI integration.\n\"\"\"\n\nimport logging\nimport os\n\nimport psutil\nfrom comfy import model_detection, model_management\nfrom comfy.utils import (\n    calculate_parameters,\n    load_torch_file,\n    state_dict_prefix_replace,\n    unet_to_diffusers,\n)\n\nlogger = logging.getLogger(__name__)\n\ntry:\n    from sglang.multimodal_gen import DiffGenerator\nexcept ImportError:\n    logger.error(\n        \"Error: sglang.multimodal_gen is not installed. Please install it using 'pip install sglang[diffusion]'\"\n    )\n\nfrom ..executors import (\n    FluxExecutor,\n    QwenImageEditExecutor,\n    QwenImageExecutor,\n    ZImageExecutor,\n)\nfrom .model_patcher import SGLDModelPatcher\n\n\nclass SGLDiffusionGenerator:\n    \"\"\"Generator for SGLang Diffusion models in ComfyUI.\"\"\"\n\n    def __init__(self):\n        self.model_path = None\n        self.generator = None\n        self.executor = None\n        self.last_options = None\n\n        self.pipeline_class_dict = {\n            \"flux\": \"ComfyUIFluxPipeline\",\n            \"lumina2\": \"ComfyUIZImagePipeline\",  # zimage\n            \"qwen_image\": \"ComfyUIQwenImagePipeline\",\n            \"qwen_image_edit\": \"ComfyUIQwenImageEditPipeline\",\n        }\n        self.executor_class_dict = {\n            \"flux\": FluxExecutor,\n            \"lumina2\": ZImageExecutor,\n            \"qwen_image\": QwenImageExecutor,\n            \"qwen_image_edit\": QwenImageEditExecutor,\n        }\n\n    def __del__(self):\n        self.close_generator()\n\n    def init_generator(\n        self, model_path: str, pipeline_class_name: str, kwargs: dict = None\n    ):\n        \"\"\"Initialize the diffusion generator.\"\"\"\n        if self.generator is not None:\n            return self.generator\n        if kwargs is None:\n            kwargs = {}\n        # Set comfyui_mode for ComfyUI integration\n        kwargs[\"comfyui_mode\"] = True\n        
self.generator = DiffGenerator.from_pretrained(\n            model_path=model_path,\n            pipeline_class_name=pipeline_class_name,\n            **kwargs,\n        )\n        return self.generator\n\n    def kill_generator(self):\n        \"\"\"Kill worker processes manually because generator shutdown cannot terminate them.\"\"\"\n        current_pid = os.getpid()\n        worker_processes = []\n        for proc in psutil.process_iter([\"pid\", \"name\", \"cmdline\"]):\n            try:\n                # Look for sglang-diffusionWorker processes\n                if proc.info[\"cmdline\"]:\n                    cmdline = \" \".join(proc.info[\"cmdline\"])\n                    if \"sgl_diffusion::\" in cmdline:\n                        if proc.info[\"pid\"] != current_pid:\n                            worker_processes.append(proc)\n            except (psutil.NoSuchProcess, psutil.AccessDenied):\n                continue\n\n        if worker_processes:\n            logger.info(\n                f\"Found {len(worker_processes)} worker processes to terminate...\"\n            )\n            for proc in worker_processes:\n                try:\n                    logger.info(\n                        f\"Terminating worker process {proc.info['pid']}: {proc.info['name']}\"\n                    )\n                    proc.terminate()\n                    proc.wait(timeout=5)\n                except psutil.TimeoutExpired:\n                    logger.warning(\n                        f\"Process {proc.info['pid']} did not terminate, forcing kill...\"\n                    )\n                    try:\n                        proc.kill()\n                        proc.wait(timeout=2)\n                    except (psutil.NoSuchProcess, psutil.TimeoutExpired):\n                        pass\n                except (psutil.NoSuchProcess, psutil.AccessDenied):\n                    pass\n\n    def close_generator(self):\n        \"\"\"Close and cleanup the generator and all 
associated resources.\"\"\"\n        if self.generator is not None:\n            self.generator.shutdown()\n            self.kill_generator()\n            # Clear other references\n            self.last_options = None\n            self.model_path = None\n            self.generator = None\n            self.executor = None\n\n    def get_comfyui_model(self, model_path: str, model_options: dict = None):\n        \"\"\"Get ComfyUI model from model path.\"\"\"\n        if model_options is None:\n            model_options = {}\n        dtype = model_options.get(\"dtype\", None)\n        # Allow loading unets from checkpoint files\n        sd = load_torch_file(model_path)\n        diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)\n        temp_sd = state_dict_prefix_replace(\n            sd, {diffusion_model_prefix: \"\"}, filter_keys=True\n        )\n        if len(temp_sd) > 0:\n            sd = temp_sd\n\n        parameters = calculate_parameters(sd)\n        load_device = model_management.get_torch_device()\n\n        model_detect_config = model_detection.detect_unet_config(sd, \"\")\n        model_type = model_detect_config.get(\"image_model\", None)\n        if model_type is None or model_type not in self.pipeline_class_dict:\n            raise ValueError(f\"Unsupported model type: {model_type}\")\n        model_config = model_detection.model_config_from_unet(sd, \"\")\n\n        if model_config is not None:\n            new_sd = sd\n        else:\n            new_sd = model_detection.convert_diffusers_mmdit(sd, \"\")\n            if new_sd is not None:  # diffusers mmdit\n                model_config = model_detection.model_config_from_unet(new_sd, \"\")\n                if model_config is None:\n                    return None\n            else:  # diffusers unet\n                model_config = model_detection.model_config_from_diffusers_unet(sd)\n                if model_config is None:\n                    return None\n\n                
diffusers_keys = unet_to_diffusers(model_config.unet_config)\n                new_sd = {}\n                for k in diffusers_keys:\n                    if k in sd:\n                        new_sd[diffusers_keys[k]] = sd.pop(k)\n        if dtype is None:\n            unet_dtype = model_management.unet_dtype(\n                model_params=parameters,\n                supported_dtypes=model_config.supported_inference_dtypes,\n            )\n        else:\n            unet_dtype = dtype\n\n        manual_cast_dtype = model_management.unet_manual_cast(\n            unet_dtype, load_device, model_config.supported_inference_dtypes\n        )\n        model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)\n        model_config.custom_operations = model_options.get(\"custom_operations\", None)\n        model_config.unet_config[\"disable_unet_model_creation\"] = True\n        comfyui_model = model_config.get_model({})\n        return comfyui_model, model_config, model_type\n\n    def load_model(\n        self, model_path: str, model_options: dict = None, sgld_options: dict = None\n    ):\n        \"\"\"Load model and return model patcher.\"\"\"\n        # Work on a private copy so that popping \"model_type\" and adding\n        # \"comfyui_mode\" below never mutate the caller's dict or the cached options.\n        sgld_options = dict(sgld_options) if sgld_options else {}\n        gather_options = {\n            \"model_path\": model_path,\n            \"model_options\": model_options,\n            \"sgld_options\": dict(sgld_options),\n        }\n        if (\n            self.last_options is not None\n            and self.last_options == gather_options\n            and self.generator is not None\n            and getattr(self, \"model_patcher\", None) is not None\n        ):\n            # Same options as last time: reuse the cached patcher (callers expect a\n            # ModelPatcher, not the raw DiffGenerator).\n            return self.model_patcher\n        else:\n            self.close_generator()\n\n        self.last_options = gather_options\n        self.model_path = model_path\n\n        comfyui_model, model_config, model_type = self.get_comfyui_model(\n            model_path, model_options\n        )\n        if model_type is None or model_type not in self.pipeline_class_dict:\n            raise ValueError(f\"Unsupported model type: {model_type}\")\n\n        set_model_type = sgld_options.pop(\"model_type\", None)\n        if set_model_type is not None and set_model_type in self.pipeline_class_dict:\n            model_type = set_model_type\n\n        pipeline_class_name = self.pipeline_class_dict[model_type]\n        self.generator = self.init_generator(\n            model_path, pipeline_class_name, sgld_options\n        )\n\n        executor_class = self.executor_class_dict[model_type]\n        self.executor = executor_class(\n            self.generator, model_path, comfyui_model, model_config\n        )\n        comfyui_model.diffusion_model = self.executor\n\n        load_device = model_management.get_torch_device()\n        offload_device = model_management.unet_offload_device()\n\n        self.model_patcher = SGLDModelPatcher(\n            comfyui_model, load_device, offload_device, model_type=model_type\n        )\n        return self.model_patcher\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/model_patcher.py",
    "content": "\"\"\"\nModel patcher for SGLang Diffusion ComfyUI integration.\n\"\"\"\n\nimport copy\n\nfrom comfy.model_patcher import ModelPatcher\n\n\nclass SGLDModelPatcher(ModelPatcher):\n    \"\"\"Model patcher for SGLang Diffusion models in ComfyUI.\"\"\"\n\n    def __init__(\n        self,\n        model,\n        load_device,\n        offload_device,\n        size=0,\n        weight_inplace_update=False,\n        model_type=None,\n    ):\n        super().__init__(\n            model, load_device, offload_device, size, weight_inplace_update\n        )\n        self.lora_cache = {}\n        self.model_type = model_type\n        self.model_size_dict = {\n            \"flux\": 27 * 1024 * 1024 * 1024,\n            \"lumina2\": 8 * 1024 * 1024 * 1024,\n        }\n\n    def clone(self):\n        \"\"\"Clone the model patcher.\"\"\"\n        n = SGLDModelPatcher(\n            self.model,\n            self.load_device,\n            self.offload_device,\n            self.size,\n            weight_inplace_update=self.weight_inplace_update,\n            model_type=self.model_type,  # keep model_type so model_size() stays correct\n        )\n        n.patches = {}\n        for k in self.patches:\n            n.patches[k] = self.patches[k][:]\n        n.patches_uuid = self.patches_uuid\n\n        n.object_patches = self.object_patches.copy()\n        n.model_options = copy.deepcopy(self.model_options)\n        n.backup = self.backup\n        n.object_patches_backup = self.object_patches_backup\n        n.lora_cache = copy.copy(self.lora_cache)\n        return n\n\n    def model_size(self):\n        \"\"\"Get the approximate model size in bytes (hardcoded per model type; 0 if unknown).\"\"\"\n        if self.model_type in self.model_size_dict:\n            return self.model_size_dict[self.model_type]\n        else:\n            return 0\n\n    def load(\n        self,\n        device_to=None,\n        lowvram_model_memory=0,\n        force_patch_weights=False,\n        full_load=False,\n    ):\n        \"\"\"Load model (no-op for SGLang Diffusion).\"\"\"\n        pass\n\n    def patch_model(\n        self,\n        device_to=None,\n        lowvram_model_memory=0,\n        load_weights=True,\n        force_patch_weights=False,\n    ):\n        \"\"\"Patch model (no-op for SGLang Diffusion).\"\"\"\n        pass\n\n    def unpatch_model(self, device_to=None, unpatch_weights=True):\n        \"\"\"Unpatch model (no-op for SGLang Diffusion).\"\"\"\n        pass\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/server_api.py",
    "content": "\"\"\"\nSGLang Diffusion Server API client.\nProvides a low-level interface for interacting with SGLang Diffusion HTTP server.\n\"\"\"\n\nimport base64\nimport io\nimport os\nimport time\nfrom typing import Any, Dict, Optional\n\nimport requests\nfrom PIL import Image\n\n\nclass SGLDiffusionServerAPI:\n    \"\"\"Client for SGLang Diffusion HTTP server API.\"\"\"\n\n    def __init__(self, base_url: str, api_key: str = \"sk-proj-1234567890\"):\n        \"\"\"\n        Initialize the API client.\n\n        Args:\n            base_url: Base URL of the SGLang Diffusion server (e.g., \"http://localhost:30010/v1\")\n            api_key: API key for authentication (default: \"sk-proj-1234567890\")\n        \"\"\"\n        # Ensure base_url doesn't end with /v1 if it's already there\n        if base_url.endswith(\"/v1\"):\n            self.base_url = base_url\n        elif base_url.endswith(\"/v1/\"):\n            self.base_url = base_url.rstrip(\"/\")\n        else:\n            self.base_url = f\"{base_url.rstrip('/')}/v1\"\n\n        self.api_key = api_key\n        self.headers = {\n            \"Content-Type\": \"application/json\",\n            \"Authorization\": f\"Bearer {api_key}\",\n        }\n\n    def get_model_info(self) -> Dict[str, Any]:\n        \"\"\"\n        Get information about the model served by this server.\n\n        Returns:\n            Dictionary containing model information including:\n            - model_path: Path to the model\n            - task_type: Type of task (e.g., \"T2V\", \"I2I\")\n            - pipeline_name: Name of the pipeline\n            - num_gpus: Number of GPUs\n            - dit_precision: DiT model precision\n            - vae_precision: VAE model precision\n        \"\"\"\n        try:\n            # Remove /v1 from base_url for /models endpoint\n            models_url = self.base_url.removesuffix(\"/v1\") + \"/models\"\n            response = requests.get(models_url, headers=self.headers, timeout=30)\n      
      response.raise_for_status()\n            return response.json()\n        except requests.exceptions.RequestException as e:\n            raise RuntimeError(f\"Failed to get model info: {str(e)}\")\n\n    def generate_image(\n        self,\n        prompt: str,\n        image_path: Optional[str] = None,\n        mask_path: Optional[str] = None,\n        size: Optional[str] = None,\n        width: Optional[int] = None,\n        height: Optional[int] = None,\n        n: int = 1,\n        negative_prompt: Optional[str] = None,\n        guidance_scale: Optional[float] = None,\n        num_inference_steps: Optional[int] = None,\n        seed: Optional[int] = None,\n        enable_teacache: bool = False,\n        response_format: str = \"b64_json\",\n        quality: Optional[str] = \"auto\",\n        style: Optional[str] = \"vivid\",\n        background: Optional[str] = \"auto\",\n        output_format: Optional[str] = None,\n        generator_device: Optional[str] = \"cuda\",\n    ) -> Dict[str, Any]:\n        \"\"\"\n        Generate or edit an image using SGLang Diffusion API.\n        If image_path is provided, calls the edit endpoint; otherwise calls the generation endpoint.\n\n        Args:\n            prompt: Text prompt for image generation/editing\n            image_path: Optional path to input image file for editing. 
If provided, uses edit API.\n            mask_path: Optional path to mask image file (only used when image_path is provided)\n            size: Image size in format \"WIDTHxHEIGHT\" (e.g., \"1024x1024\")\n            width: Image width (used if size is not provided)\n            height: Image height (used if size is not provided)\n            n: Number of images to generate (1-10)\n            negative_prompt: Negative prompt to avoid certain elements\n            guidance_scale: Classifier-free guidance scale\n            num_inference_steps: Number of denoising steps\n            seed: Random seed for reproducible generation\n            enable_teacache: Enable TEA cache acceleration\n            response_format: Response format (\"b64_json\" or \"url\")\n            quality: Image quality (\"auto\", \"standard\", \"hd\") - only for generation\n            style: Image style (\"vivid\" or \"natural\") - only for generation\n            background: Background type (\"auto\", \"transparent\", \"opaque\")\n            output_format: Output format (\"png\", \"jpeg\", \"webp\")\n            generator_device: Device for random generator (\"cuda\" or \"cpu\")\n\n        Returns:\n            Dictionary containing the API response with generated/edited image data\n        \"\"\"\n        if not prompt:\n            raise ValueError(\"Prompt cannot be empty\")\n\n        # Determine size\n        if size is None:\n            if width is not None and height is not None:\n                size = f\"{width}x{height}\"\n            else:\n                size = \"1024x1024\"\n\n        # Build common parameters\n        common_params = self._build_image_common_params(\n            prompt=prompt,\n            size=size,\n            n=n,\n            response_format=response_format,\n            negative_prompt=negative_prompt,\n            guidance_scale=guidance_scale,\n            num_inference_steps=num_inference_steps,\n            seed=seed,\n            
enable_teacache=enable_teacache,\n            background=background,\n            output_format=output_format,\n            generator_device=generator_device,\n        )\n\n        # If image_path is provided, use edit endpoint\n        if image_path:\n            if not os.path.exists(image_path):\n                raise FileNotFoundError(f\"Image file not found: {image_path}\")\n\n            # Prepare multipart form data for edit\n            files: Dict[str, Any] = {}\n            data = common_params.copy()\n\n            # Add image file\n            files[\"image\"] = (\n                os.path.basename(image_path),\n                open(image_path, \"rb\"),\n                self._get_content_type(image_path),\n            )\n\n            # Add mask file if provided\n            if mask_path:\n                if not os.path.exists(mask_path):\n                    raise FileNotFoundError(f\"Mask file not found: {mask_path}\")\n                files[\"mask\"] = (\n                    os.path.basename(mask_path),\n                    open(mask_path, \"rb\"),\n                    self._get_content_type(mask_path),\n                )\n\n            # Prepare headers for multipart form data\n            headers = {\n                \"Authorization\": f\"Bearer {self.api_key}\",\n            }\n\n            try:\n                response = requests.post(\n                    f\"{self.base_url}/images/edits\",\n                    files=files,\n                    data=data,\n                    headers=headers,\n                    timeout=300,  # 5 minutes timeout for generation\n                )\n                response.raise_for_status()\n                return response.json()\n            except requests.exceptions.RequestException as e:\n                raise RuntimeError(f\"Failed to edit image: {str(e)}\")\n            finally:\n                # Close file handles\n                for file_tuple in files.values():\n                    if 
isinstance(file_tuple, tuple) and len(file_tuple) > 1:\n                        file_tuple[1].close()\n        else:\n            # Use generation endpoint - add generation-specific parameters\n            payload = common_params.copy()\n            if quality:\n                payload[\"quality\"] = quality\n            if style:\n                payload[\"style\"] = style\n\n            try:\n                response = requests.post(\n                    f\"{self.base_url}/images/generations\",\n                    json=payload,\n                    headers=self.headers,\n                    timeout=300,  # 5 minutes timeout for generation\n                )\n                response.raise_for_status()\n                return response.json()\n            except requests.exceptions.RequestException as e:\n                raise RuntimeError(f\"Failed to generate image: {str(e)}\")\n\n    def generate_video(\n        self,\n        prompt: str,\n        size: Optional[str] = None,\n        width: Optional[int] = None,\n        height: Optional[int] = None,\n        seconds: Optional[int] = 4,\n        fps: Optional[int] = None,\n        num_frames: Optional[int] = None,\n        negative_prompt: Optional[str] = None,\n        guidance_scale: Optional[float] = None,\n        num_inference_steps: Optional[int] = None,\n        seed: Optional[int] = None,\n        enable_teacache: bool = False,\n        generator_device: Optional[str] = \"cuda\",\n        input_reference: Optional[str] = None,\n        output_path: Optional[str] = None,\n    ) -> Dict[str, Any]:\n        \"\"\"\n        Generate a video using SGLang Diffusion API and wait for completion.\n\n        Args:\n            prompt: Text prompt for video generation\n            size: Video size in format \"WIDTHxHEIGHT\" (e.g., \"1280x720\")\n            width: Video width (used if size is not provided)\n            height: Video height (used if size is not provided)\n            seconds: Duration of the video 
in seconds\n            fps: Frames per second\n            num_frames: Number of frames (overrides seconds * fps if provided)\n            negative_prompt: Negative prompt to avoid certain elements\n            guidance_scale: Classifier-free guidance scale\n            num_inference_steps: Number of denoising steps\n            seed: Random seed for reproducible generation\n            enable_teacache: Enable TEA cache acceleration\n            generator_device: Device for random generator (\"cuda\" or \"cpu\")\n            input_reference: Path to input reference image for image-to-video\n\n        Returns:\n            Dictionary containing completed video job information with file_path\n        \"\"\"\n        if not prompt:\n            raise ValueError(\"Prompt cannot be empty\")\n\n        # Determine size\n        if size is None:\n            if width is not None and height is not None:\n                size = f\"{width}x{height}\"\n            else:\n                size = \"720x1280\"\n\n        # Prepare request payload\n        payload: Dict[str, Any] = {\n            \"prompt\": prompt,\n            \"size\": size,\n        }\n\n        # Add optional parameters\n        if seconds is not None:\n            payload[\"seconds\"] = seconds\n        if fps is not None:\n            payload[\"fps\"] = fps\n        if num_frames is not None:\n            payload[\"num_frames\"] = num_frames\n        if negative_prompt:\n            payload[\"negative_prompt\"] = negative_prompt\n        if guidance_scale is not None:\n            payload[\"guidance_scale\"] = guidance_scale\n        if num_inference_steps is not None:\n            payload[\"num_inference_steps\"] = num_inference_steps\n        if seed is not None and seed >= 0:\n            payload[\"seed\"] = seed\n        if enable_teacache:\n            payload[\"enable_teacache\"] = True\n        if generator_device:\n            payload[\"generator_device\"] = generator_device\n        if 
input_reference:\n            payload[\"input_reference\"] = input_reference\n        if output_path:\n            payload[\"output_path\"] = output_path\n\n        try:\n            # Create video generation job\n            response = requests.post(\n                f\"{self.base_url}/videos\",\n                json=payload,\n                headers=self.headers,\n                timeout=30,\n            )\n            response.raise_for_status()\n            video_job = response.json()\n            video_id = video_job.get(\"id\")\n\n            # Wait for completion with fixed polling\n            poll_interval = 5  # 5 seconds\n            max_wait_time = 3600  # 1 hour\n            max_consecutive_errors = 5\n            consecutive_errors = 0\n            start_time = time.time()\n\n            while time.time() - start_time < max_wait_time:\n                try:\n                    status_response = requests.get(\n                        f\"{self.base_url}/videos/{video_id}\",\n                        headers=self.headers,\n                        timeout=30,\n                    )\n                    status_response.raise_for_status()\n                    status = status_response.json()\n\n                    # Reset error counter on successful request\n                    consecutive_errors = 0\n\n                    if status.get(\"status\") == \"completed\":\n                        return status\n                    elif status.get(\"status\") == \"failed\":\n                        error = status.get(\"error\", {})\n                        error_msg = (\n                            error.get(\"message\", \"Unknown error\")\n                            if error\n                            else \"Unknown error\"\n                        )\n                        raise RuntimeError(f\"Video generation failed: {error_msg}\")\n                except requests.exceptions.ConnectionError as e:\n                    # Connection errors - likely server is 
down\n                    consecutive_errors += 1\n                    if consecutive_errors >= max_consecutive_errors:\n                        raise RuntimeError(\n                            f\"Lost connection to server after {consecutive_errors} consecutive errors. \"\n                            f\"Server may be unavailable: {str(e)}\"\n                        )\n                except requests.exceptions.RequestException as e:\n                    # Other network errors - continue polling but track errors\n                    consecutive_errors += 1\n                    if consecutive_errors >= max_consecutive_errors:\n                        raise RuntimeError(\n                            f\"Network error after {consecutive_errors} consecutive failures: {str(e)}\"\n                        )\n\n                time.sleep(poll_interval)\n\n            raise TimeoutError(\n                f\"Video generation timed out after {max_wait_time} seconds\"\n            )\n        except requests.exceptions.RequestException as e:\n            raise RuntimeError(f\"Failed to generate video: {str(e)}\")\n\n    def _build_image_common_params(\n        self,\n        prompt: str,\n        size: str,\n        n: int,\n        response_format: str,\n        negative_prompt: Optional[str] = None,\n        guidance_scale: Optional[float] = None,\n        num_inference_steps: Optional[int] = None,\n        seed: Optional[int] = None,\n        enable_teacache: bool = False,\n        background: Optional[str] = None,\n        output_format: Optional[str] = None,\n        generator_device: Optional[str] = None,\n    ) -> Dict[str, Any]:\n        \"\"\"\n        Build common parameters for both image generation and editing.\n\n        Returns:\n            Dictionary containing common parameters\n        \"\"\"\n        params: Dict[str, Any] = {\n            \"prompt\": prompt,\n            \"size\": size,\n            \"n\": max(1, min(n, 10)),\n            \"response_format\": 
response_format,\n        }\n\n        # Add optional parameters\n        if negative_prompt:\n            params[\"negative_prompt\"] = negative_prompt\n        if guidance_scale is not None:\n            params[\"guidance_scale\"] = guidance_scale\n        if num_inference_steps is not None:\n            params[\"num_inference_steps\"] = num_inference_steps\n        if seed is not None and seed >= 0:\n            params[\"seed\"] = seed\n        if enable_teacache:\n            params[\"enable_teacache\"] = True\n        if background:\n            params[\"background\"] = background\n        if output_format:\n            params[\"output_format\"] = output_format\n        if generator_device:\n            params[\"generator_device\"] = generator_device\n\n        return params\n\n    def _get_content_type(self, file_path: str) -> str:\n        \"\"\"Get content type based on file extension.\"\"\"\n        ext = os.path.splitext(file_path)[1].lower()\n        content_types = {\n            \".png\": \"image/png\",\n            \".jpg\": \"image/jpeg\",\n            \".jpeg\": \"image/jpeg\",\n            \".webp\": \"image/webp\",\n        }\n        return content_types.get(ext, \"image/png\")\n\n    def decode_image_from_response(\n        self, response_data: Dict[str, Any], index: int = 0\n    ) -> Image.Image:\n        \"\"\"\n        Decode base64 image from API response.\n\n        Args:\n            response_data: API response dictionary\n            index: Index of the image in the response (default: 0)\n\n        Returns:\n            PIL Image object\n        \"\"\"\n        if \"data\" not in response_data or not response_data[\"data\"]:\n            raise ValueError(\"No image data in response\")\n\n        if index >= len(response_data[\"data\"]):\n            raise IndexError(f\"Image index {index} out of range\")\n\n        image_data = response_data[\"data\"][index]\n        if \"b64_json\" not in image_data or not image_data[\"b64_json\"]:\n     
       raise ValueError(\"No base64 image data found\")\n\n        image_bytes = base64.b64decode(image_data[\"b64_json\"])\n        image = Image.open(io.BytesIO(image_bytes))\n\n        # Convert to RGB if needed\n        if image.mode != \"RGB\":\n            image = image.convert(\"RGB\")\n\n        return image\n\n    def set_lora(\n        self,\n        lora_nickname: str,\n        lora_path: Optional[str] = None,\n        target: str = \"all\",\n    ) -> Dict[str, Any]:\n        \"\"\"\n        Set a LoRA adapter for the specified transformer(s).\n\n        Args:\n            lora_nickname: The nickname of the adapter (required).\n            lora_path: Path to the LoRA adapter (local path or HF repo id).\n                      Required for the first load; optional if re-activating a cached nickname.\n            target: Which transformer(s) to apply the LoRA to. One of:\n                - \"all\": Apply to all transformers (default)\n                - \"transformer\": Apply only to the primary transformer (high noise for Wan2.2)\n                - \"transformer_2\": Apply only to transformer_2 (low noise for Wan2.2)\n                - \"critic\": Apply only to the critic model\n\n        Returns:\n            Dictionary containing the API response with status and message\n        \"\"\"\n        if not lora_nickname:\n            raise ValueError(\"lora_nickname cannot be empty\")\n\n        # Prepare request payload\n        payload: Dict[str, Any] = {\n            \"lora_nickname\": lora_nickname,\n            \"target\": target,\n        }\n\n        # Add optional lora_path if provided\n        if lora_path:\n            payload[\"lora_path\"] = lora_path\n\n        try:\n            response = requests.post(\n                f\"{self.base_url}/set_lora\",\n                json=payload,\n                headers=self.headers,\n                timeout=30,\n            )\n            response.raise_for_status()\n            return response.json()\n        
except requests.exceptions.RequestException as e:\n            raise RuntimeError(f\"Failed to set LoRA adapter: {str(e)}\") from e\n\n    def unset_lora(\n        self,\n        target: str = \"all\",\n    ) -> Dict[str, Any]:\n        \"\"\"\n        Unset (unmerge) LoRA weights from the base model.\n\n        Args:\n            target: Which transformer(s) to unmerge the LoRA from. Accepts the\n                same values as `set_lora` (default: \"all\").\n\n        Returns:\n            Dictionary containing the API response with status and message\n        \"\"\"\n        # Prepare request payload\n        payload: Dict[str, Any] = {\n            \"target\": target,\n        }\n\n        try:\n            response = requests.post(\n                f\"{self.base_url}/unmerge_lora_weights\",\n                json=payload,\n                headers=self.headers,\n                timeout=30,\n            )\n            response.raise_for_status()\n            return response.json()\n        except requests.exceptions.RequestException as e:\n            raise RuntimeError(f\"Failed to unset LoRA adapter: {str(e)}\") from e\n\n\nif __name__ == \"__main__\":\n    api = SGLDiffusionServerAPI(\n        base_url=\"http://localhost:30010/v1\", api_key=\"sk-proj-1234567890\"\n    )\n    model_info = api.get_model_info()\n    print(model_info)\n    if model_info.get(\"task_type\") in (\"T2V\", \"I2V\"):\n        print(\n            api.generate_video(\n                prompt=\"A calico cat playing a piano on stage\",\n                num_inference_steps=1,\n                size=\"480x480\",\n            )\n        )\n    else:\n        print(\n            api.generate_image(\n                prompt=\"A calico cat playing a piano on stage\", size=\"1024x1024\"\n            )\n        )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/__init__.py",
    "content": "\"\"\"\nComfyUI SGLang Diffusion executors package.\nProvides executor classes for different model types.\n\"\"\"\n\nfrom .base import SGLDiffusionExecutor\nfrom .flux import FluxExecutor\nfrom .qwen_image import QwenImageEditExecutor, QwenImageExecutor\nfrom .zimage import ZImageExecutor\n\n__all__ = [\n    \"SGLDiffusionExecutor\",\n    \"FluxExecutor\",\n    \"ZImageExecutor\",\n    \"QwenImageExecutor\",\n    \"QwenImageEditExecutor\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/base.py",
    "content": "\"\"\"\nBase executor class for SGLang Diffusion ComfyUI integration.\n\"\"\"\n\nimport torch\n\n\nclass SGLDiffusionExecutor(torch.nn.Module):\n    \"\"\"Base executor class for SGLang Diffusion models in ComfyUI.\"\"\"\n\n    def __init__(self, generator, model_path, model, config):\n        super(SGLDiffusionExecutor, self).__init__()\n        self.generator = generator\n        self.model_path = model_path\n        self.model = model\n        self.dtype = config.unet_config[\"dtype\"]\n        self.config = config\n        self.loras = []\n\n    @staticmethod\n    def should_suppress_logs(timestep):\n        \"\"\"Determine if logs should be suppressed based on timestep value.\"\"\"\n        if torch.is_tensor(timestep):\n            return bool((timestep < 1.0).item())\n        return bool(timestep < 1.0)\n\n    def set_lora(self, lora_nickname=None, lora_path=None, strength=None, target=None):\n        \"\"\"Set LoRA adapter using SGLang Diffusion API.\"\"\"\n        # Skip when no nickname is given (None, or an empty str/list).\n        if lora_nickname:\n            self.generator.set_lora(\n                lora_nickname=lora_nickname,\n                lora_path=lora_path,\n                strength=strength,\n                target=target,\n            )\n\n    def _unpack_latents(self, latents, height, width, channels):\n        \"\"\"Unpack [B, (H/2)*(W/2), C*4] 2x2-patch tokens back to [B, C, H, W] latents.\"\"\"\n        batch_size = latents.shape[0]\n        latents = latents.view(batch_size, height // 2, width // 2, channels, 2, 2)\n        latents = latents.permute(0, 3, 1, 4, 2, 5)\n        latents = latents.reshape(batch_size, channels, height, width)\n\n        return latents\n\n    def _pack_latents(self, latents):\n        \"\"\"Pack [B, C, H, W] latents into [B, (H/2)*(W/2), C*4] 2x2-patch tokens.\"\"\"\n        batch_size, num_channels_latents, height, width = latents.shape\n        latents = latents.view(\n            batch_size, num_channels_latents, height // 2, 2, width // 2, 2\n        )\n        latents = latents.permute(0, 2, 4, 1, 3, 5)
\n        latents = latents.reshape(\n            batch_size, (height // 2) * (width // 2), num_channels_latents * 4\n        )\n        return latents\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/flux.py",
    "content": "\"\"\"\nFlux executor for SGLang Diffusion ComfyUI integration.\n\"\"\"\n\nimport torch\n\ntry:\n    from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\n    from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request\nexcept ImportError:\n    print(\n        \"Error: sglang.multimodal_gen is not installed. Please install it using 'pip install sglang[diffusion]'\"\n    )\n\nfrom .base import SGLDiffusionExecutor\n\n\nclass FluxExecutor(SGLDiffusionExecutor):\n    \"\"\"Executor for Flux models in ComfyUI.\"\"\"\n\n    def __init__(self, generator, model_path, model, config):\n        super().__init__(generator, model_path, model, config)\n\n    def forward(self, x, timestep, context, y=None, guidance=None, **kwargs):\n        \"\"\"Forward pass for Flux model.\"\"\"\n        hidden_states = self._pack_latents(x)\n        timesteps = timestep * 1000.0\n        encoder_hidden_states = context\n        pooled_projections = y\n        # guidance is optional (defaults to None), so only scale it when provided\n        guidance = guidance * 1000.0 if guidance is not None else None\n\n        B, C, H, W = x.shape\n        height = H * 8\n        width = W * 8\n        # Create SamplingParams\n        sampling_params = SamplingParams.from_user_sampling_params_args(\n            self.model_path,\n            server_args=self.generator.server_args,\n            prompt=\" \",\n            guidance_scale=3.5,  # Flux typically uses embedded_cfg_scale=3.5\n            height=height,\n            width=width,\n            num_frames=1,\n            num_inference_steps=1,\n            save_output=False,\n            suppress_logs=self.should_suppress_logs(timestep),\n        )\n\n        # Prepare request (converts SamplingParams to Req)\n        req = prepare_request(\n            server_args=self.generator.server_args,\n            sampling_params=sampling_params,\n        )\n        req.latents = hidden_states  # Set as [B, S, D] format directly\n        req.timesteps = timesteps  # ComfyUI's timesteps parameter\n        req.prompt_embeds = [pooled_projections, encoder_hidden_states]  # [CLIP, T5]
\n        req.raw_latent_shape = torch.tensor(hidden_states.shape, dtype=torch.long)\n\n        # Set pooled_projections (required by Flux)\n        req.pooled_embeds = [pooled_projections]  # List format as per Req definition\n        req.do_classifier_free_guidance = False\n        req.generator = [\n            torch.Generator(\"cuda\") for _ in range(req.num_outputs_per_prompt)\n        ]\n\n        # Send request to scheduler\n        output_batch = self.generator._send_to_scheduler_and_wait_for_response([req])\n        noise_pred = output_batch.noise_pred\n        return self._unpack_latents(noise_pred, H, W, C).to(x.device)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/qwen_image.py",
    "content": "\"\"\"\nQwenImage executor for SGLang Diffusion ComfyUI integration.\n\"\"\"\n\nimport torch\n\ntry:\n    from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\n    from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request\nexcept ImportError:\n    print(\n        \"Error: sglang.multimodal_gen is not installed. Please install it using 'pip install sglang[diffusion]'\"\n    )\n\nimport comfy.ldm.common_dit\n\nfrom .base import SGLDiffusionExecutor\n\n\nclass QwenImageExecutor(SGLDiffusionExecutor):\n    \"\"\"Executor for QwenImage models in ComfyUI.\"\"\"\n\n    def __init__(self, generator, model_path, model, config):\n        super().__init__(generator, model_path, model, config)\n        self.patch_size = 2\n\n    def _pack_latents(self, x):\n        \"\"\"Process hidden states for QwenImage model.\"\"\"\n        bs, c, t, h, w = x.shape\n        patch_size = self.patch_size\n        latents = comfy.ldm.common_dit.pad_to_patch_size(\n            x, (1, self.patch_size, self.patch_size)\n        )\n        orig_shape = latents.shape\n        latents = latents.view(\n            orig_shape[0],\n            orig_shape[1],\n            orig_shape[-3],\n            orig_shape[-2] // 2,\n            2,\n            orig_shape[-1] // 2,\n            2,\n        )\n        latents = latents.permute(0, 2, 3, 5, 1, 4, 6)\n        latents = latents.reshape(\n            orig_shape[0],\n            orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2),\n            orig_shape[1] * 4,\n        )\n        return latents, orig_shape\n\n    def _unpack_latents(self, latents, num_embeds, orig_shape, x):\n        \"\"\"Unpack hidden states from packed format to standard format.\"\"\"\n        latents = latents[:, :num_embeds].view(\n            orig_shape[0],\n            orig_shape[-3],\n            orig_shape[-2] // 2,\n            orig_shape[-1] // 2,\n            orig_shape[1],\n            2,\n            
2,\n        )\n        latents = latents.permute(0, 4, 1, 2, 5, 3, 6)\n        latents = latents.reshape(orig_shape)[:, :, :, : x.shape[-2], : x.shape[-1]]\n        return latents\n\n    def forward(self, x, timestep, context, **kwargs):\n        \"\"\"Forward pass for QwenImage model.\"\"\"\n        latents, orig_shape = self._pack_latents(x)\n        num_embeds = latents.shape[1]\n        height = orig_shape[-2] * 8\n        width = orig_shape[-1] * 8\n\n        sampling_params = SamplingParams.from_user_sampling_params_args(\n            self.model_path,\n            server_args=self.generator.server_args,\n            prompt=\" \",\n            guidance_scale=1.0,\n            height=height,\n            width=width,\n            num_frames=1,\n            num_inference_steps=1,\n            save_output=False,\n            suppress_logs=self.should_suppress_logs(timestep),\n        )\n\n        # Prepare request (converts SamplingParams to Req)\n        req = prepare_request(\n            server_args=self.generator.server_args,\n            sampling_params=sampling_params,\n        )\n        # Set ComfyUI-specific inputs directly on the Req object\n        req.latents = latents\n        req.timesteps = timestep * 1000.0\n        req.prompt_embeds = [context]\n        req.raw_latent_shape = torch.tensor(latents.shape, dtype=torch.long)\n        req.do_classifier_free_guidance = False\n        req.generator = [\n            torch.Generator(\"cuda\") for _ in range(req.num_outputs_per_prompt)\n        ]\n\n        output_batch = self.generator._send_to_scheduler_and_wait_for_response([req])\n        noise_pred = output_batch.noise_pred\n\n        return self._unpack_latents(noise_pred, num_embeds, orig_shape, x)\n\n\nclass QwenImageEditExecutor(QwenImageExecutor):\n    \"\"\"Executor for QwenImageEdit models in ComfyUI.\"\"\"\n\n    def __init__(self, generator, model_path, model, config):\n        super().__init__(generator, model_path, model, config)\n\n    def 
forward(\n        self,\n        x,\n        timestep,\n        context,\n        attention_mask=None,\n        ref_latents=None,\n        additional_t_cond=None,\n        transformer_options={},\n        **kwargs\n    ):\n        \"\"\"Forward pass for QwenImageEdit model.\"\"\"\n        latents, orig_shape = self._pack_latents(x)\n        num_embeds = latents.shape[1]\n        height = orig_shape[-2] * 8\n        width = orig_shape[-1] * 8\n\n        # Prepare vae_image_sizes for the condition image (ref_latents)\n        vae_image_sizes = []\n        pack_ref_latents = None\n\n        # TODO: SGLang Diffusion does not yet support multiple condition images; only ref_latents[0] is used for now.\n        if ref_latents is not None and len(ref_latents) > 0:\n            pack_ref_latents, orig_ref_shape = self._pack_latents(ref_latents[0])\n            vae_image_sizes = [(orig_ref_shape[-1], orig_ref_shape[-2])]\n\n        sampling_params = SamplingParams.from_user_sampling_params_args(\n            self.model_path,\n            server_args=self.generator.server_args,\n            prompt=\" \",\n            guidance_scale=1.0,\n            image_path=\"\",\n            height=height,\n            width=width,\n            num_frames=1,\n            num_inference_steps=1,\n            save_output=False,\n            suppress_logs=self.should_suppress_logs(timestep),\n        )\n\n        # Prepare request (converts SamplingParams to Req)\n        req = prepare_request(\n            server_args=self.generator.server_args,\n            sampling_params=sampling_params,\n        )\n        # Set ComfyUI-specific inputs directly on the Req object\n        req.latents = latents\n        req.image_latent = pack_ref_latents\n        req.timesteps = timestep * 1000.0\n        req.vae_image_sizes = vae_image_sizes\n        req.prompt_embeds = [context]\n        req.raw_latent_shape = torch.tensor(latents.shape, dtype=torch.long)\n        req.do_classifier_free_guidance = False\n        req.generator = [
\n            torch.Generator(\"cuda\") for _ in range(req.num_outputs_per_prompt)\n        ]\n\n        output_batch = self.generator._send_to_scheduler_and_wait_for_response([req])\n        noise_pred = output_batch.noise_pred\n\n        return self._unpack_latents(noise_pred, num_embeds, orig_shape, x)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/zimage.py",
    "content": "\"\"\"\nZImage executor for SGLang Diffusion ComfyUI integration.\n\"\"\"\n\nimport torch\n\ntry:\n    from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\n    from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request\nexcept ImportError:\n    print(\n        \"Error: sglang.multimodal_gen is not installed. Please install it using 'pip install sglang[diffusion]'\"\n    )\n\nfrom .base import SGLDiffusionExecutor\n\n\nclass ZImageExecutor(SGLDiffusionExecutor):\n    \"\"\"Executor for ZImage models in ComfyUI.\"\"\"\n\n    def __init__(self, generator, model_path, model, config):\n        super().__init__(generator, model_path, model, config)\n\n    def forward(self, x, timesteps, context, **kwargs):\n        \"\"\"Forward pass for ZImage model.\"\"\"\n        B, C, H, W = x.shape\n        height = H * 8\n        width = W * 8\n        sampling_params = SamplingParams.from_user_sampling_params_args(\n            self.model_path,\n            server_args=self.generator.server_args,\n            prompt=\" \",\n            guidance_scale=1.0,\n            height=height,\n            width=width,\n            num_frames=1,  # For images\n            num_inference_steps=1,  # Single step for ComfyUI\n            save_output=False,\n            suppress_logs=self.should_suppress_logs(timesteps),\n        )\n\n        # Prepare request (converts SamplingParams to Req)\n        req = prepare_request(\n            server_args=self.generator.server_args,\n            sampling_params=sampling_params,\n        )\n        latents = x.unsqueeze(2)\n        context = context.squeeze(0)\n        # Set ComfyUI-specific inputs directly on the Req object\n        req.latents = latents  # ComfyUI's x parameter\n        req.timesteps = timesteps * 1000.0  # ComfyUI's timesteps parameter\n        req.prompt_embeds = [\n            context\n        ]  # ComfyUI's context parameter (must be List[Tensor])\n        
req.raw_latent_shape = torch.tensor(latents.shape, dtype=torch.long)\n        req.do_classifier_free_guidance = False\n        req.generator = [\n            torch.Generator(\"cuda\") for _ in range(req.num_outputs_per_prompt)\n        ]\n\n        output_batch = self.generator._send_to_scheduler_and_wait_for_response([req])\n        noise_pred = output_batch.noise_pred\n\n        return noise_pred.permute(1, 0, 2, 3).to(x.device)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/nodes.py",
    "content": "\"\"\"\nComfyUI nodes for SGLang Diffusion integration.\nProvides nodes for connecting to SGLang Diffusion server and generating images/videos.\n\"\"\"\n\nimport os\nimport uuid\n\nimport folder_paths\nimport torch\n\nfrom .core import SGLDiffusionGenerator, SGLDiffusionServerAPI\nfrom .utils import (\n    convert_b64_to_tensor_image,\n    convert_video_to_comfy_video,\n    get_image_path,\n    is_empty_image,\n)\n\n\nclass SGLDOptions:\n    @classmethod\n    def INPUT_TYPES(cls):\n        return {\n            \"required\": {},\n            \"optional\": {\n                \"model_type\": (\n                    [\"auto-detect\", \"qwen_image\", \"qwen_image_edit\", \"flux\", \"lumina2\"],\n                    {\"default\": \"auto-detect\"},\n                ),\n                \"enable_torch_compile\": (\n                    \"BOOLEAN\",\n                    {\"default\": False},\n                ),\n                \"num_gpus\": (\"INT\", {\"default\": 1, \"min\": 1, \"step\": 1}),\n                \"tp_size\": (\"INT\", {\"default\": -1, \"min\": -1, \"step\": 1}),\n                \"sp_degree\": (\"INT\", {\"default\": -1, \"min\": -1, \"step\": 1}),\n                \"ulysses_degree\": (\n                    \"INT\",\n                    {\n                        \"default\": -1,\n                        \"min\": -1,\n                        \"step\": 1,\n                    },\n                ),\n                \"ring_degree\": (\n                    \"INT\",\n                    {\n                        \"default\": -1,\n                        \"min\": -1,\n                        \"step\": 1,\n                    },\n                ),\n                \"dp_size\": (\"INT\", {\"default\": 1, \"min\": 1, \"step\": 1}),\n                \"dp_degree\": (\"INT\", {\"default\": 1, \"min\": 1, \"step\": 1}),\n                \"enable_cfg_parallel\": (\n                    \"BOOLEAN\",\n                    {\"default\": False},\n              
  ),\n                \"attention_backend\": (\n                    \"STRING\",\n                    {\"default\": \"\"},\n                ),\n            },\n        }\n\n    RETURN_TYPES = (\"SGLD_OPTIONS\",)\n    RETURN_NAMES = (\"sgld_options\",)\n    FUNCTION = \"create_options\"\n    CATEGORY = \"SGLDiffusion\"\n\n    def create_options(\n        self,\n        model_type: str = \"auto-detect\",\n        enable_torch_compile: bool = False,\n        num_gpus: int = 1,\n        tp_size: int = -1,\n        sp_degree: int = -1,\n        ulysses_degree: int = -1,\n        ring_degree: int = -1,\n        dp_size: int = 1,\n        dp_degree: int = 1,\n        enable_cfg_parallel: bool = False,\n        attention_backend: str = \"\",\n    ):\n        \"\"\"\n        Build a dictionary of SGLang Diffusion runtime options.\n        \"\"\"\n        # Convert -1 to None for optional parameters (matching ServerArgs defaults)\n        ulysses_degree = None if ulysses_degree == -1 else ulysses_degree\n        ring_degree = None if ring_degree == -1 else ring_degree\n        attention_backend = None if attention_backend == \"\" else attention_backend\n\n        options = {\n            \"model_type\": model_type,\n            \"enable_torch_compile\": enable_torch_compile,\n            \"num_gpus\": num_gpus,\n            \"tp_size\": tp_size,\n            \"sp_degree\": sp_degree,\n            \"ulysses_degree\": ulysses_degree,\n            \"ring_degree\": ring_degree,\n            \"dp_size\": dp_size,\n            \"dp_degree\": dp_degree,\n            \"enable_cfg_parallel\": enable_cfg_parallel,\n            \"attention_backend\": attention_backend,\n        }\n\n        # Strip None to keep payload clean\n        options = {k: v for k, v in options.items() if v is not None}\n        return (options,)\n\n\nclass SGLDLoraLoader:\n    @classmethod\n    def INPUT_TYPES(cls):\n        return {\n            \"required\": {\n                \"model\": (\"MODEL\",),\n       
         \"lora_name\": (folder_paths.get_filename_list(\"loras\"),),\n                \"strength_model\": (\n                    \"FLOAT\",\n                    {\"default\": 1.0, \"min\": 0, \"max\": 10, \"step\": 0.01},\n                ),\n                \"nickname\": (\"STRING\", {\"default\": \"\"}),\n                \"target\": (\n                    [\"all\", \"transformer\", \"transformer_2\", \"critic\"],\n                    {\"default\": \"all\"},\n                ),\n            },\n        }\n\n    RETURN_TYPES = (\"MODEL\",)\n    FUNCTION = \"load_lora\"\n\n    CATEGORY = \"SGLDiffusion\"\n\n    def load_lora(\n        self, model, lora_name, strength_model=1.0, nickname=\"\", target=\"all\"\n    ):\n        \"\"\"Load LoRA adapter using SGLang Diffusion API.\"\"\"\n        lora_path = folder_paths.get_full_path(\"loras\", lora_name)\n        assert model is not None\n        bi = model.clone()\n        nickname = nickname if nickname != \"\" else str(\"lora\" + str(uuid.uuid4()))\n        # set lora in the model\n        bi.patches[nickname] = (lora_path, strength_model, target)\n\n        # prepare input for the SGLang Diffusion API\n        lora_input = {\n            \"lora_nickname\": [],\n            \"lora_path\": [],\n            \"strength\": [],\n            \"target\": [],\n        }\n        for nickname, lora_info in bi.patches.items():\n            lora_input[\"lora_nickname\"].append(nickname)\n            lora_input[\"lora_path\"].append(lora_info[0])\n            lora_input[\"strength\"].append(lora_info[1])\n            lora_input[\"target\"].append(lora_info[2])\n\n        # call the SGLang Diffusion API\n        model.model.diffusion_model.set_lora(**lora_input)\n        return (model,)\n\n\nclass SGLDUNETLoader:\n    def __init__(self):\n        self.generator = SGLDiffusionGenerator()\n\n    @classmethod\n    def INPUT_TYPES(s):\n        return {\n            \"required\": {\n                \"unet_name\": 
(folder_paths.get_filename_list(\"diffusion_models\"),),\n                \"weight_dtype\": ([\"default\", \"fp8_e4m3fn\", \"fp8_e5m2\"],),\n            },\n            \"optional\": {\n                \"sgld_options\": (\"SGLD_OPTIONS\",),\n            },\n        }\n\n    RETURN_TYPES = (\"MODEL\",)\n    FUNCTION = \"load_unet\"\n\n    CATEGORY = \"SGLDiffusion\"\n\n    def load_unet(self, unet_name, weight_dtype, sgld_options: dict = None):\n        model_options = {}\n        if weight_dtype == \"fp8_e4m3fn\":\n            model_options[\"dtype\"] = torch.float8_e4m3fn\n        elif weight_dtype == \"fp8_e5m2\":\n            model_options[\"dtype\"] = torch.float8_e5m2\n\n        unet_path = folder_paths.get_full_path(\"diffusion_models\", unet_name)\n\n        model = self.generator.load_model(\n            unet_path, model_options=model_options, sgld_options=sgld_options\n        )\n        return (model,)\n\n\nclass SGLDiffusionServerModel:\n    \"\"\"Node to load and manage SGLang Diffusion server connection.\"\"\"\n\n    @classmethod\n    def INPUT_TYPES(cls):\n        return {\n            \"required\": {\n                \"base_url\": (\n                    \"STRING\",\n                    {\n                        \"default\": \"http://localhost:3000/v1\",\n                        \"multiline\": False,\n                    },\n                ),\n                \"api_key\": (\n                    \"STRING\",\n                    {\n                        \"default\": \"sk-proj-1234567890\",\n                        \"multiline\": False,\n                    },\n                ),\n            }\n        }\n\n    RETURN_TYPES = (\"SGLD_CLIENT\", \"STRING\")\n    RETURN_NAMES = (\"sgld_client\", \"model_info\")\n    FUNCTION = \"load_server\"\n    CATEGORY = \"SGLDiffusion\"\n\n    def load_server(self, base_url: str, api_key: str):\n        \"\"\"Initialize OpenAI client for SGLang Diffusion server.\"\"\"\n        client = 
SGLDiffusionServerAPI(base_url=base_url, api_key=api_key)\n        try:\n            model_info = client.get_model_info()\n            # Format model_info as a readable string\n            info_lines = [\"=== SGLDiffusion Model Info ===\"]\n            for key, value in model_info.items():\n                info_lines.append(f\"{key}: {value}\")\n            model_info_str = \"\\n\".join(info_lines)\n        except Exception as e:\n            model_info_str = f\"Failed to get model info: {str(e)}\"\n        return (client, model_info_str)\n\n\nclass SGLDiffusionGenerateImage:\n    \"\"\"Node to generate images using SGLang Diffusion.\"\"\"\n\n    @classmethod\n    def INPUT_TYPES(cls):\n        return {\n            \"required\": {\n                \"sgld_client\": (\"SGLD_CLIENT\",),\n                \"positive_prompt\": (\n                    \"STRING\",\n                    {\n                        \"default\": \"\",\n                        \"tooltip\": \"Text prompt for image generation\",\n                    },\n                ),\n            },\n            \"optional\": {\n                \"negative_prompt\": (\n                    \"STRING\",\n                    {\n                        \"default\": \"\",\n                        \"tooltip\": \"Negative prompt to avoid certain elements\",\n                    },\n                ),\n                \"image\": (\n                    \"IMAGE\",\n                    {\n                        \"default\": None,\n                        \"tooltip\": \"input image to use for editing\",\n                    },\n                ),\n                \"seed\": (\n                    \"INT\",\n                    {\n                        \"default\": 1024,\n                        \"min\": -1,\n                        \"max\": 2**32 - 1,\n                    },\n                ),\n                \"steps\": (\n                    \"INT\",\n                    {\n                        \"default\": 6,\n     
                   \"min\": 1,\n                        \"max\": 100,\n                        \"step\": 1,\n                    },\n                ),\n                \"cfg\": (\n                    \"FLOAT\",\n                    {\n                        \"default\": 7.0,\n                        \"min\": 1.0,\n                        \"max\": 20.0,\n                        \"step\": 0.1,\n                    },\n                ),\n                \"width\": (\n                    \"INT\",\n                    {\n                        \"default\": 1024,\n                        \"min\": 256,\n                        \"max\": 4096,\n                        \"step\": 64,\n                    },\n                ),\n                \"height\": (\n                    \"INT\",\n                    {\n                        \"default\": 1024,\n                        \"min\": 256,\n                        \"max\": 4096,\n                        \"step\": 64,\n                    },\n                ),\n                \"enable_teacache\": (\n                    \"BOOLEAN\",\n                    {\n                        \"default\": False,\n                    },\n                ),\n            },\n        }\n\n    RETURN_TYPES = (\"IMAGE\",)\n    RETURN_NAMES = (\"image\",)\n    FUNCTION = \"generate_image\"\n    CATEGORY = \"SGLDiffusion\"\n    OUTPUT_NODE = False\n\n    def generate_image(\n        self,\n        sgld_client: SGLDiffusionServerAPI,\n        positive_prompt: str,\n        negative_prompt: str = \"\",\n        image: torch.Tensor = None,\n        seed: int = 1024,\n        steps: int = 6,\n        cfg: float = 7.0,\n        width: int = 1024,\n        height: int = 1024,\n        enable_teacache: bool = False,\n    ):\n        \"\"\"Generate image using SGLang Diffusion API.\"\"\"\n        if not positive_prompt:\n            raise ValueError(\"Prompt cannot be empty\")\n\n        size = f\"{width}x{height}\"\n\n        # Prepare request 
parameters\n        request_params = {\n            \"prompt\": positive_prompt,\n            \"size\": size,\n            \"response_format\": \"b64_json\",\n        }\n\n        # Add optional parameters if provided\n        if negative_prompt:\n            request_params[\"negative_prompt\"] = negative_prompt\n        if cfg is not None:\n            request_params[\"guidance_scale\"] = cfg\n        if steps is not None:\n            request_params[\"num_inference_steps\"] = steps\n        if seed is not None and seed >= 0:\n            request_params[\"seed\"] = seed\n        if enable_teacache:\n            request_params[\"enable_teacache\"] = True\n        if image is not None:\n            # If the image is empty, use the size of the image to generate the image\n            if is_empty_image(image):\n                width, height = image.shape[2], image.shape[1]\n                size = f\"{width}x{height}\"\n                request_params[\"size\"] = size\n            else:\n                request_params[\"image_path\"] = get_image_path(image)\n\n        # Call API\n        try:\n            response = sgld_client.generate_image(**request_params)\n        except Exception as e:\n            raise RuntimeError(f\"Failed to generate image: {str(e)}\")\n\n        # Decode base64 image\n        if not response[\"data\"] or not response[\"data\"][0][\"b64_json\"]:\n            raise RuntimeError(\"No image data in response\")\n        image_data = response[\"data\"][0][\"b64_json\"]\n        image = convert_b64_to_tensor_image(image_data)\n\n        return (image,)\n\n\nclass SGLDiffusionGenerateVideo:\n    \"\"\"Node to generate videos using SGLang Diffusion.\"\"\"\n\n    @classmethod\n    def INPUT_TYPES(cls):\n        return {\n            \"required\": {\n                \"sgld_client\": (\"SGLD_CLIENT\",),\n                \"positive_prompt\": (\n                    \"STRING\",\n                    {\n                        \"default\": \"\",\n            
            \"tooltip\": \"Text prompt for video generation\",\n                    },\n                ),\n            },\n            \"optional\": {\n                \"negative_prompt\": (\n                    \"STRING\",\n                    {\n                        \"default\": \"\",\n                        \"tooltip\": \"Negative prompt to avoid certain elements\",\n                    },\n                ),\n                \"image\": (\n                    \"IMAGE\",\n                    {\n                        \"default\": None,\n                        \"tooltip\": \"input image to use for image-to-video\",\n                    },\n                ),\n                \"seed\": (\n                    \"INT\",\n                    {\n                        \"default\": 1024,\n                        \"min\": -1,\n                        \"max\": 2**32 - 1,\n                    },\n                ),\n                \"steps\": (\n                    \"INT\",\n                    {\n                        \"default\": 6,\n                        \"min\": 1,\n                        \"max\": 100,\n                        \"step\": 1,\n                    },\n                ),\n                \"cfg\": (\n                    \"FLOAT\",\n                    {\n                        \"default\": 7.0,\n                        \"min\": 1.0,\n                        \"max\": 20.0,\n                        \"step\": 0.1,\n                    },\n                ),\n                \"width\": (\n                    \"INT\",\n                    {\n                        \"default\": 1280,\n                        \"min\": 256,\n                        \"max\": 4096,\n                        \"step\": 1,\n                    },\n                ),\n                \"height\": (\n                    \"INT\",\n                    {\n                        \"default\": 720,\n                        \"min\": 256,\n                        \"max\": 4096,\n      
                  \"step\": 1,\n                    },\n                ),\n                \"num_frames\": (\n                    \"INT\",\n                    {\n                        \"default\": 120,\n                        \"min\": 1,\n                        \"max\": 1000,\n                        \"step\": 1,\n                    },\n                ),\n                \"fps\": (\n                    \"INT\",\n                    {\n                        \"default\": 24,\n                        \"min\": 1,\n                        \"max\": 60,\n                        \"step\": 1,\n                    },\n                ),\n                \"seconds\": (\n                    \"INT\",\n                    {\n                        \"default\": 5,\n                        \"min\": 1,\n                        \"max\": 60,\n                        \"step\": 1,\n                    },\n                ),\n                \"enable_teacache\": (\n                    \"BOOLEAN\",\n                    {\n                        \"default\": False,\n                    },\n                ),\n            },\n        }\n\n    RETURN_TYPES = (\"VIDEO\", \"STRING\")\n    RETURN_NAMES = (\"video\", \"video_path\")\n    FUNCTION = \"generate_video\"\n    CATEGORY = \"SGLDiffusion\"\n    OUTPUT_NODE = False\n\n    def generate_video(\n        self,\n        sgld_client: SGLDiffusionServerAPI,\n        positive_prompt: str,\n        negative_prompt: str = \"\",\n        image: torch.Tensor = None,\n        seed: int = 1024,\n        steps: int = 6,\n        cfg: float = 7.0,\n        width: int = 1280,\n        height: int = 720,\n        num_frames: int = 120,\n        fps: int = 24,\n        seconds: int = 5,\n        enable_teacache: bool = False,\n    ):\n        \"\"\"Generate video using SGLang Diffusion API.\"\"\"\n        if not positive_prompt:\n            raise ValueError(\"Prompt cannot be empty\")\n\n        size = f\"{width}x{height}\"\n        
output_dir = folder_paths.get_temp_directory()\n\n        # Prepare request parameters\n        request_params = {\n            \"prompt\": positive_prompt,\n            \"size\": size,\n            \"seconds\": seconds,\n            \"fps\": fps,\n            \"output_path\": output_dir,\n        }\n\n        # Add optional parameters if provided\n        if negative_prompt:\n            request_params[\"negative_prompt\"] = negative_prompt\n        if cfg is not None:\n            request_params[\"guidance_scale\"] = cfg\n        if steps is not None:\n            request_params[\"num_inference_steps\"] = steps\n        if seed is not None and seed >= 0:\n            request_params[\"seed\"] = seed\n        if enable_teacache:\n            request_params[\"enable_teacache\"] = True\n        if num_frames is not None:\n            request_params[\"num_frames\"] = num_frames\n        if image is not None:\n            # If the image is empty, use the size of the image to generate the video\n            if is_empty_image(image):\n                width, height = image.shape[2], image.shape[1]\n                size = f\"{width}x{height}\"\n                request_params[\"size\"] = size\n            else:\n                request_params[\"input_reference\"] = get_image_path(image)\n\n        # Call API\n        try:\n            response = sgld_client.generate_video(**request_params)\n            video_path = response.get(\"file_path\", \"\")\n            video = convert_video_to_comfy_video(video_path, height, width)\n        except Exception as e:\n            raise RuntimeError(f\"Failed to generate video: {str(e)}\")\n\n        return (video, video_path)\n\n\nclass SGLDiffusionServerSetLora:\n    \"\"\"Node to set LoRA adapter for SGLang Diffusion server.\"\"\"\n\n    @classmethod\n    def INPUT_TYPES(cls):\n        return {\n            \"required\": {\n                \"sgld_client\": (\"SGLD_CLIENT\",),\n                \"lora_name\": (\n                    
\"STRING\",\n                    {\n                        \"default\": \"\",\n                        \"tooltip\": \"The name of the LoRA adapter\",\n                    },\n                ),\n            },\n            \"optional\": {\n                \"lora_nickname\": (\n                    \"STRING\",\n                    {\n                        \"default\": \"\",\n                        \"tooltip\": \"The nickname of the LoRA adapter\",\n                    },\n                ),\n                \"target\": (\n                    [\n                        \"all\",\n                        \"transformer\",\n                        \"transformer_2\",\n                        \"critic\",\n                    ],\n                    {\n                        \"default\": \"all\",\n                        \"tooltip\": \"Which transformer(s) to apply the LoRA to\",\n                    },\n                ),\n            },\n        }\n\n    RETURN_TYPES = (\"SGLD_CLIENT\",)\n    RETURN_NAMES = (\"sgld_client\",)\n    FUNCTION = \"set_lora\"\n    CATEGORY = \"SGLDiffusion\"\n    OUTPUT_NODE = False\n\n    def set_lora(\n        self,\n        sgld_client: SGLDiffusionServerAPI,\n        lora_name: str = \"\",\n        lora_nickname: str = \"\",\n        target: str = \"all\",\n    ):\n        \"\"\"Set LoRA adapter using SGLang Diffusion API.\"\"\"\n        if lora_nickname == \"\":\n            lora_nickname = os.path.splitext(lora_name)[0]\n\n        # Prepare request parameters\n        request_params = {\n            \"lora_nickname\": lora_nickname,\n            \"lora_path\": lora_name,\n            \"target\": target,\n        }\n\n        # Call API\n        try:\n            response = sgld_client.set_lora(**request_params)\n            return (sgld_client,)\n        except Exception as e:\n            raise RuntimeError(f\"Failed to set LoRA adapter: {str(e)}\")\n\n\nclass SGLDiffusionServerUnsetLora:\n    \"\"\"Node to unset LoRA adapter for 
SGLang Diffusion server.\"\"\"\n\n    @classmethod\n    def INPUT_TYPES(cls):\n        return {\n            \"required\": {\n                \"sgld_client\": (\"SGLD_CLIENT\",),\n            },\n            \"optional\": {\n                \"target\": (\n                    [\n                        \"all\",\n                        \"transformer\",\n                        \"transformer_2\",\n                        \"critic\",\n                    ],\n                    {\n                        \"default\": \"all\",\n                        \"tooltip\": \"Which transformer(s) to unset the LoRA from\",\n                    },\n                ),\n            },\n        }\n\n    RETURN_TYPES = (\"SGLD_CLIENT\",)\n    RETURN_NAMES = (\"sgld_client\",)\n    FUNCTION = \"unset_lora\"\n    CATEGORY = \"SGLDiffusion\"\n    OUTPUT_NODE = False\n\n    def unset_lora(\n        self,\n        sgld_client: SGLDiffusionServerAPI,\n        target: str = \"all\",\n    ):\n        \"\"\"Unset LoRA adapter using SGLang Diffusion API.\"\"\"\n        try:\n            response = sgld_client.unset_lora(target=target)\n            return (sgld_client,)\n        except Exception as e:\n            raise RuntimeError(f\"Failed to unset LoRA adapter: {str(e)}\")\n\n\n# Register nodes\nNODE_CLASS_MAPPINGS = {\n    \"SGLDiffusionServerModel\": SGLDiffusionServerModel,\n    \"SGLDiffusionGenerateImage\": SGLDiffusionGenerateImage,\n    \"SGLDiffusionGenerateVideo\": SGLDiffusionGenerateVideo,\n    \"SGLDiffusionServerSetLora\": SGLDiffusionServerSetLora,\n    \"SGLDiffusionServerUnsetLora\": SGLDiffusionServerUnsetLora,\n    \"SGLDUNETLoader\": SGLDUNETLoader,\n    \"SGLDOptions\": SGLDOptions,\n    \"SGLDLoraLoader\": SGLDLoraLoader,\n}\n\nNODE_DISPLAY_NAME_MAPPINGS = {\n    \"SGLDiffusionServerModel\": \"SGLDiffusion Server Model\",\n    \"SGLDiffusionGenerateImage\": \"SGLDiffusion Generate Image\",\n    \"SGLDiffusionGenerateVideo\": \"SGLDiffusion Generate Video\",\n    
\"SGLDiffusionServerSetLora\": \"SGLDiffusion Server Set LoRA\",\n    \"SGLDiffusionServerUnsetLora\": \"SGLDiffusion Server Unset LoRA\",\n    \"SGLDUNETLoader\": \"SGLDiffusion UNET Loader\",\n    \"SGLDOptions\": \"SGLDiffusion Options\",\n    \"SGLDLoraLoader\": \"SGLDiffusion LoRA Loader\",\n}\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/README.md",
    "content": "# ComfyUI SGLDiffusion Pipeline Tests\n\nThis directory contains tests for each ComfyUI pipeline integration.\n\n## Test Files\n\n- `test_zimage_pipeline.py` - Tests for ComfyUIZImagePipeline\n- `test_flux_pipeline.py` - Tests for ComfyUIFluxPipeline\n- `test_qwen_image_pipeline.py` - Tests for ComfyUIQwenImagePipeline\n- `test_qwen_image_edit_pipeline.py` - Tests for ComfyUIQwenImageEditPipeline (I2I/edit mode)\n\n## Running Tests\n\n### Run all tests\n\n```bash\npytest python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/ -v -s\n```\n\n### Run a specific test file\n\n```bash\npytest python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_zimage_pipeline.py -v -s\n```\n\n## Environment Variables\n\nYou can configure model paths via environment variables. Model paths support two formats:\n- **Safetensors file**: Path to a single `.safetensors` file (e.g., `/path/to/model.safetensors`)\n- **Diffusers format**: HuggingFace model ID or local diffusers directory (e.g., `Tongyi-MAI/Z-Image-Turbo`)\n\nEnvironment variables:\n- `SGLANG_TEST_ZIMAGE_MODEL_PATH` - Path to ZImage model (default: `Tongyi-MAI/Z-Image-Turbo`)\n- `SGLANG_TEST_FLUX_MODEL_PATH` - Path to Flux model (default: `black-forest-labs/FLUX.1-dev`)\n- `SGLANG_TEST_QWEN_IMAGE_MODEL_PATH` - Path to QwenImage model (default: `Qwen/Qwen-Image`)\n- `SGLANG_TEST_QWEN_IMAGE_EDIT_MODEL_PATH` - Path to QwenImageEdit model (default: `Qwen/Qwen-Image-Edit-2511`)\n\nExamples:\n\n```bash\n# Using HuggingFace model ID (diffusers format)\nexport SGLANG_TEST_ZIMAGE_MODEL_PATH=\"Tongyi-MAI/Z-Image-Turbo\"\npytest python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_zimage_pipeline.py -v -s\n\n# Using safetensors file\nexport SGLANG_TEST_ZIMAGE_MODEL_PATH=\"/path/to/z_image_turbo_bf16.safetensors\"\npytest python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_zimage_pipeline.py -v -s\n```\n\n## Test Structure\n\nEach test file follows a similar structure:\n\n1. 
**Setup**: Creates a `DiffGenerator` with the appropriate pipeline class\n2. **Input Preparation**: Creates dummy tensors for latents, timesteps, and embeddings\n3. **Request Preparation**: Uses `prepare_request` to convert `SamplingParams` to `Req`\n4. **ComfyUI Inputs**: Sets ComfyUI-specific inputs directly on the `Req` object\n5. **Execution**: Sends the request to the scheduler and waits for the response\n6. **Validation**: Checks that `noise_pred` is present in the `OutputBatch` and is a bfloat16 CUDA tensor\n\n## Notes\n\n- These tests use `comfyui_mode=True` to enable ComfyUI-specific behavior\n- Tests use pre-processed inputs (latents, timesteps, embeddings), as ComfyUI would provide them\n- The tests verify that `noise_pred` can be retrieved from the `OutputBatch` after processing\n- All tests use constant dummy tensors (ones or zeros) for simplicity; in production, these would be actual model outputs\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/__init__.py",
    "content": "\"\"\"\nTest suite for ComfyUI SGLDiffusion pipelines.\n\nThis package contains tests for each ComfyUI pipeline integration:\n- ZImagePipeline\n- FluxPipeline\n- QwenImagePipeline\n- QwenImageEditPipeline\n\"\"\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_flux_pipeline.py",
    "content": "\"\"\"Test for ComfyUIFluxPipeline with pass-through scheduler.\"\"\"\n\nimport os\n\nimport pytest\nimport torch\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\nfrom sglang.multimodal_gen.runtime.entrypoints.diffusion_generator import DiffGenerator\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request\n\n\ndef test_comfyui_flux_pipeline_direct() -> None:\n    \"\"\"Test ComfyUIFluxPipeline with custom inputs.\"\"\"\n    model_path = os.environ.get(\n        \"SGLANG_TEST_FLUX_MODEL_PATH\",\n        \"black-forest-labs/FLUX.1-dev\",  # Supports both safetensors file and diffusers format\n    )\n\n    generator = DiffGenerator.from_pretrained(\n        model_path=model_path,\n        pipeline_class_name=\"ComfyUIFluxPipeline\",\n        num_gpus=2,\n        comfyui_mode=True,\n    )\n\n    batch_size = 1\n    hidden_states_seq_len = 3600\n    hidden_states_dim = 64\n    height = 1280\n    width = 720\n\n    encoder_seq_len = 512\n    encoder_dim = 4096\n    pooled_dim = 768\n\n    hidden_states = torch.ones(\n        batch_size,\n        hidden_states_seq_len,\n        hidden_states_dim,\n        device=\"cuda\",\n        dtype=torch.bfloat16,\n    )\n\n    encoder_hidden_states = torch.ones(\n        batch_size,\n        encoder_seq_len,\n        encoder_dim,\n        device=\"cuda\",\n        dtype=torch.bfloat16,\n    )\n\n    pooled_projections = torch.ones(\n        batch_size,\n        pooled_dim,\n        device=\"cuda\",\n        dtype=torch.bfloat16,\n    )\n\n    timesteps = torch.tensor([1000], dtype=torch.long, device=\"cuda\")\n\n    sampling_params = SamplingParams.from_user_sampling_params_args(\n        generator.server_args.model_path,\n        server_args=generator.server_args,\n        prompt=\"a beautiful girl\",\n        height=height,\n        width=width,\n        num_frames=1,\n        num_inference_steps=1,\n        save_output=True,\n        
return_trajectory_latents=True,\n    )\n\n    req = prepare_request(\n        server_args=generator.server_args,\n        sampling_params=sampling_params,\n    )\n\n    req.latents = hidden_states\n    req.timesteps = timesteps\n    req.raw_latent_shape = torch.tensor(hidden_states.shape, dtype=torch.long)\n\n    clip_dim = 768\n    dummy_clip_embedding = torch.zeros(\n        batch_size,\n        77,\n        clip_dim,\n        device=\"cuda\",\n        dtype=torch.bfloat16,\n    )\n    req.prompt_embeds = [pooled_projections, encoder_hidden_states]\n\n    if req.guidance_scale > 1.0:\n        dummy_neg_clip_embedding = torch.zeros(\n            batch_size,\n            77,\n            clip_dim,\n            device=\"cuda\",\n            dtype=torch.bfloat16,\n        )\n        negative_encoder_hidden_states = torch.ones(\n            batch_size,\n            encoder_seq_len,\n            encoder_dim,\n            device=\"cuda\",\n            dtype=torch.bfloat16,\n        )\n        req.negative_prompt_embeds = [\n            dummy_neg_clip_embedding,\n            negative_encoder_hidden_states,\n        ]\n    else:\n        req.negative_prompt_embeds = None\n\n    req.pooled_embeds = [pooled_projections]\n    req.neg_pooled_embeds = []\n\n    if (\n        req.guidance_scale > 1.0\n        and req.negative_prompt_embeds is not None\n        and len(req.negative_prompt_embeds) > 0\n    ):\n        req.do_classifier_free_guidance = True\n    else:\n        req.do_classifier_free_guidance = False\n\n    if req.seed is not None:\n        generator_device = req.generator_device\n        device_str = \"cuda\" if generator_device == \"cuda\" else \"cpu\"\n        req.generator = [\n            torch.Generator(device_str).manual_seed(req.seed + i)\n            for i in range(req.num_outputs_per_prompt)\n        ]\n    else:\n        req.generator = [\n            torch.Generator(\"cuda\") for _ in range(req.num_outputs_per_prompt)\n        ]\n\n    output_batch = 
generator._send_to_scheduler_and_wait_for_response([req])\n    noise_pred = output_batch.noise_pred\n\n    assert noise_pred is not None, \"noise_pred should not be None in OutputBatch\"\n    assert isinstance(noise_pred, torch.Tensor), \"noise_pred should be a torch.Tensor\"\n    assert (\n        noise_pred.device.type == \"cuda\"\n    ), f\"noise_pred should be on cuda, got {noise_pred.device}\"\n    assert (\n        noise_pred.dtype == torch.bfloat16\n    ), f\"noise_pred should be bfloat16, got {noise_pred.dtype}\"\n\n    print(f\"✓ Successfully retrieved noise_pred from OutputBatch!\")\n    print(f\"  noise_pred shape: {noise_pred.shape}\")\n    print(f\"  noise_pred dtype: {noise_pred.dtype}\")\n    print(f\"  noise_pred device: {noise_pred.device}\")\n\n    latents = output_batch.output if output_batch.output is not None else req.latents\n    assert latents is not None, \"latents should not be None\"\n    print(f\"latents.shape: {latents.shape}\")\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\"])\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_qwen_image_edit_pipeline.py",
    "content": "\"\"\"Test for ComfyUIQwenImageEditPipeline with pass-through scheduler (I2I/edit mode).\"\"\"\n\nimport os\n\nimport pytest\nimport torch\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\nfrom sglang.multimodal_gen.runtime.entrypoints.diffusion_generator import DiffGenerator\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request\n\n\ndef test_comfyui_qwen_image_edit_pipeline_direct() -> None:\n    \"\"\"Test ComfyUIQwenImageEditPipeline with edit mode (I2I) and custom inputs.\"\"\"\n    model_path = os.environ.get(\n        \"SGLANG_TEST_QWEN_IMAGE_EDIT_MODEL_PATH\",\n        \"Qwen/Qwen-Image-Edit-2511\",  # Supports both safetensors file and diffusers format\n    )\n\n    generator = DiffGenerator.from_pretrained(\n        model_path=model_path,\n        pipeline_class_name=\"ComfyUIQwenImageEditPipeline\",\n        num_gpus=1,\n        comfyui_mode=True,\n        dit_layerwise_offload=False,\n    )\n\n    batch_size = 1\n    noisy_image_seq_len = 3600\n    hidden_states_dim = 64\n    condition_image_seq_len = 6889\n    condition_image_dim = 64\n    encoder_seq_len = 45\n    encoder_dim = 3584\n    height = 720\n    width = 1280\n\n    vae_scale_factor = 8\n    condition_height_latent = 1328 // vae_scale_factor\n    condition_width_latent = 1328 // vae_scale_factor\n\n    noisy_image_latents = torch.ones(\n        batch_size,\n        noisy_image_seq_len,\n        hidden_states_dim,\n        device=\"cuda\",\n        dtype=torch.bfloat16,\n    )\n\n    condition_image_latents = torch.ones(\n        batch_size,\n        condition_image_seq_len,\n        condition_image_dim,\n        device=\"cuda\",\n        dtype=torch.bfloat16,\n    )\n\n    encoder_hidden_states = torch.ones(\n        batch_size,\n        encoder_seq_len,\n        encoder_dim,\n        device=\"cuda\",\n        dtype=torch.bfloat16,\n    )\n\n    timesteps = torch.tensor([1000], dtype=torch.long, device=\"cuda\")\n\n    
sampling_params = SamplingParams.from_user_sampling_params_args(\n        generator.server_args.model_path,\n        server_args=generator.server_args,\n        prompt=\" \",\n        guidance_scale=1.0,\n        height=height,\n        width=width,\n        image_path=\"\",\n        num_frames=1,\n        num_inference_steps=1,\n        seed=42,\n        save_output=False,\n        return_frames=False,\n    )\n\n    req = prepare_request(\n        server_args=generator.server_args,\n        sampling_params=sampling_params,\n    )\n\n    req.latents = noisy_image_latents\n    req.image_latent = condition_image_latents\n    req.timesteps = timesteps\n    req.prompt_embeds = [encoder_hidden_states]\n    req.negative_prompt_embeds = None\n    req.vae_image_sizes = [(condition_width_latent, condition_height_latent)]\n    req.raw_latent_shape = torch.tensor(noisy_image_latents.shape, dtype=torch.long)\n\n    if req.guidance_scale > 1.0 and req.negative_prompt_embeds is not None:\n        req.do_classifier_free_guidance = True\n    else:\n        req.do_classifier_free_guidance = False\n\n    if req.seed is not None:\n        generator_device = req.generator_device\n        device_str = \"cpu\" if generator_device == \"cpu\" else \"cuda\"\n        req.generator = [\n            torch.Generator(device_str).manual_seed(req.seed + i)\n            for i in range(req.num_outputs_per_prompt)\n        ]\n    else:\n        req.generator = [\n            torch.Generator(\"cuda\") for _ in range(req.num_outputs_per_prompt)\n        ]\n\n    output_batch = generator._send_to_scheduler_and_wait_for_response([req])\n    noise_pred = output_batch.noise_pred\n\n    assert noise_pred is not None, \"noise_pred should not be None in OutputBatch\"\n    assert isinstance(noise_pred, torch.Tensor), \"noise_pred should be a torch.Tensor\"\n    assert (\n        noise_pred.device.type == \"cuda\"\n    ), f\"noise_pred should be on cuda, got {noise_pred.device}\"\n    assert (\n        
noise_pred.dtype == torch.bfloat16\n    ), f\"noise_pred should be bfloat16, got {noise_pred.dtype}\"\n\n    print(f\"✓ Successfully retrieved noise_pred from OutputBatch (Edit Mode)!\")\n    print(f\"  noise_pred shape: {noise_pred.shape}\")\n    print(f\"  noise_pred dtype: {noise_pred.dtype}\")\n    print(f\"  noise_pred device: {noise_pred.device}\")\n\n    latents = output_batch.output if output_batch.output is not None else req.latents\n    assert latents is not None, \"latents should not be None\"\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\"])\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_qwen_image_pipeline.py",
    "content": "\"\"\"Test for ComfyUIQwenImagePipeline with pass-through scheduler.\"\"\"\n\nimport os\n\nimport pytest\nimport torch\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\nfrom sglang.multimodal_gen.runtime.entrypoints.diffusion_generator import DiffGenerator\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request\n\n\ndef test_comfyui_qwen_image_pipeline_direct() -> None:\n    \"\"\"Test ComfyUIQwenImagePipeline with custom inputs.\"\"\"\n    model_path = os.environ.get(\n        \"SGLANG_TEST_QWEN_IMAGE_MODEL_PATH\",\n        \"Qwen/Qwen-Image\",  # Supports both safetensors file and diffusers format\n    )\n\n    generator = DiffGenerator.from_pretrained(\n        model_path=model_path,\n        pipeline_class_name=\"ComfyUIQwenImagePipeline\",\n        num_gpus=2,\n        comfyui_mode=True,\n        dit_layerwise_offload=False,\n    )\n\n    batch_size = 1\n    hidden_states_seq_len = 6889\n    hidden_states_dim = 64\n    encoder_seq_len = 45\n    encoder_dim = 3584\n    height = 1328\n    width = 1328\n    dtype = torch.bfloat16\n\n    hidden_states = torch.ones(\n        batch_size,\n        hidden_states_seq_len,\n        hidden_states_dim,\n        device=\"cuda\",\n        dtype=dtype,\n    )\n\n    encoder_hidden_states = torch.ones(\n        batch_size,\n        encoder_seq_len,\n        encoder_dim,\n        device=\"cuda\",\n        dtype=torch.bfloat16,\n    )\n\n    timesteps = torch.tensor([1000], dtype=torch.long, device=\"cuda\")\n\n    sampling_params = SamplingParams.from_user_sampling_params_args(\n        generator.server_args.model_path,\n        server_args=generator.server_args,\n        prompt=\" \",\n        guidance_scale=3.0,\n        height=height,\n        width=width,\n        num_frames=1,\n        num_inference_steps=1,\n        seed=42,\n        save_output=False,\n        return_frames=False,\n    )\n\n    req = prepare_request(\n        
server_args=generator.server_args,\n        sampling_params=sampling_params,\n    )\n\n    req.latents = hidden_states\n    req.timesteps = timesteps\n    req.prompt_embeds = [encoder_hidden_states]\n    req.negative_prompt_embeds = [encoder_hidden_states]\n    req.raw_latent_shape = torch.tensor(hidden_states.shape, dtype=torch.long)\n\n    if req.guidance_scale > 1.0 and req.negative_prompt_embeds is not None:\n        req.do_classifier_free_guidance = True\n    else:\n        req.do_classifier_free_guidance = False\n\n    if req.seed is not None:\n        generator_device = req.generator_device\n        device_str = \"cpu\" if generator_device == \"cpu\" else \"cuda\"\n        req.generator = [\n            torch.Generator(device_str).manual_seed(req.seed + i)\n            for i in range(req.num_outputs_per_prompt)\n        ]\n    else:\n        req.generator = [\n            torch.Generator(\"cuda\") for _ in range(req.num_outputs_per_prompt)\n        ]\n\n    output_batch = generator._send_to_scheduler_and_wait_for_response([req])\n    noise_pred = output_batch.noise_pred\n\n    assert noise_pred is not None, \"noise_pred should not be None in OutputBatch\"\n    assert isinstance(noise_pred, torch.Tensor), \"noise_pred should be a torch.Tensor\"\n    assert (\n        noise_pred.device.type == \"cuda\"\n    ), f\"noise_pred should be on cuda, got {noise_pred.device}\"\n    assert (\n        noise_pred.dtype == torch.bfloat16\n    ), f\"noise_pred should be bfloat16, got {noise_pred.dtype}\"\n\n    print(f\"✓ Successfully retrieved noise_pred from OutputBatch!\")\n    print(f\"  noise_pred shape: {noise_pred.shape}\")\n    print(f\"  noise_pred dtype: {noise_pred.dtype}\")\n    print(f\"  noise_pred device: {noise_pred.device}\")\n\n    latents = output_batch.output if output_batch.output is not None else req.latents\n    assert latents is not None, \"latents should not be None\"\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\"])\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_zimage_pipeline.py",
    "content": "\"\"\"Test for ComfyUIZImagePipeline with pass-through scheduler.\"\"\"\n\nimport os\n\nimport pytest\nimport torch\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\nfrom sglang.multimodal_gen.runtime.entrypoints.diffusion_generator import DiffGenerator\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request\n\n\ndef test_comfyui_zimage_pipeline_direct() -> None:\n    \"\"\"Test ComfyUIZImagePipeline with custom inputs.\"\"\"\n    model_path = os.environ.get(\n        \"SGLANG_TEST_ZIMAGE_MODEL_PATH\",\n        \"Tongyi-MAI/Z-Image-Turbo\",  # Supports both safetensors file and diffusers format\n    )\n\n    generator = DiffGenerator.from_pretrained(\n        model_path=model_path,\n        pipeline_class_name=\"ComfyUIZImagePipeline\",\n        num_gpus=1,\n        sp_degree=1,\n        comfyui_mode=True,\n    )\n\n    batch_size = 1\n    num_channels = 16\n    num_frames = 1\n    height = 720\n    width = 1280\n    latent_height = height // 8\n    latent_width = width // 8\n\n    latents = torch.ones(\n        batch_size,\n        num_channels,\n        num_frames,\n        latent_height,\n        latent_width,\n        device=\"cuda\",\n        dtype=torch.bfloat16,\n    )\n\n    timesteps = torch.tensor([1000], dtype=torch.long, device=\"cuda\")\n\n    context_seq_len = 19\n    context_dim = 2560\n    context = torch.ones(\n        context_seq_len,\n        context_dim,\n        device=\"cuda\",\n        dtype=torch.bfloat16,\n    )\n\n    sampling_params = SamplingParams.from_user_sampling_params_args(\n        generator.server_args.model_path,\n        server_args=generator.server_args,\n        prompt=\"a beautiful girl\",\n        guidance_scale=1.0,\n        height=height,\n        width=width,\n        num_frames=1,\n        num_inference_steps=1,\n        seed=42,\n        save_output=False,\n        return_frames=False,\n    )\n\n    req = prepare_request(\n        
server_args=generator.server_args,\n        sampling_params=sampling_params,\n    )\n\n    req.latents = latents\n    req.timesteps = timesteps\n    req.prompt_embeds = [context]\n    req.negative_prompt_embeds = None\n    req.raw_latent_shape = torch.tensor(latents.shape, dtype=torch.long)\n\n    if req.guidance_scale > 1.0 and req.negative_prompt_embeds is not None:\n        req.do_classifier_free_guidance = True\n    else:\n        req.do_classifier_free_guidance = False\n\n    if req.seed is not None:\n        generator_device = req.generator_device\n        device_str = \"cpu\" if generator_device == \"cpu\" else \"cuda\"\n        req.generator = [\n            torch.Generator(device_str).manual_seed(req.seed + i)\n            for i in range(req.num_outputs_per_prompt)\n        ]\n    else:\n        req.generator = [\n            torch.Generator(\"cuda\") for _ in range(req.num_outputs_per_prompt)\n        ]\n\n    output_batch = generator._send_to_scheduler_and_wait_for_response([req])\n    noise_pred = output_batch.noise_pred\n\n    assert noise_pred is not None, \"noise_pred should not be None in OutputBatch\"\n    assert isinstance(noise_pred, torch.Tensor), \"noise_pred should be a torch.Tensor\"\n    assert (\n        noise_pred.device.type == \"cuda\"\n    ), f\"noise_pred should be on cuda, got {noise_pred.device}\"\n    assert (\n        noise_pred.dtype == torch.bfloat16\n    ), f\"noise_pred should be bfloat16, got {noise_pred.dtype}\"\n\n    print(f\"✓ Successfully retrieved noise_pred from OutputBatch!\")\n    print(f\"  noise_pred shape: {noise_pred.shape}\")\n    print(f\"  noise_pred dtype: {noise_pred.dtype}\")\n    print(f\"  noise_pred device: {noise_pred.device}\")\n\n    latents = output_batch.output if output_batch.output is not None else req.latents\n    assert latents is not None, \"latents should not be None\"\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\"])\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/utils.py",
    "content": "import base64\nimport io\nimport os\nimport shutil\nimport time\nimport uuid\n\nimport folder_paths\nimport numpy as np\nimport torch\nfrom comfy_api.input import VideoInput\nfrom PIL import Image\n\n\ndef _ensure_dir(path: str) -> None:\n    os.makedirs(path, exist_ok=True)\n\n\ndef _to_numpy_image(image: torch.Tensor) -> np.ndarray:\n    \"\"\"Convert ComfyUI image tensor to uint8 numpy array (H, W, C).\"\"\"\n    if image.dim() == 4:\n        image = image[0]\n    if image.dim() == 3 and image.shape[0] in (1, 3, 4):\n        image = image.permute(1, 2, 0)\n    elif image.dim() == 2:\n        image = image.unsqueeze(-1)\n    np_img = image.detach().cpu().numpy()\n    np_img = np.clip(np_img, 0.0, 1.0)\n    np_img = (np_img * 255).astype(np.uint8)\n    if np_img.shape[-1] == 1:\n        np_img = np.repeat(np_img, 3, axis=-1)\n    return np_img\n\n\ndef _to_hwc_tensor(image: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert ComfyUI image tensor to HWC format (normalized [0, 1]).\"\"\"\n    img = image.clone()\n    if img.dim() == 4:\n        img = img[0]\n    if img.dim() == 3 and img.shape[0] in (1, 3, 4):\n        img = img.permute(1, 2, 0)\n    elif img.dim() == 2:\n        img = img.unsqueeze(-1)\n\n    img = torch.clamp(img, 0.0, 1.0)\n    if img.shape[-1] == 1:\n        img = img.repeat(1, 1, 3)\n\n    return img\n\n\ndef is_empty_image(image: torch.Tensor, tolerance: float = 1e-6) -> bool:\n    \"\"\"\n    Check if the input image is an empty (solid-color) image, like ComfyUI's empty image.\n\n    Args:\n        image: ComfyUI image tensor (BHWC; BCHW, CHW and HWC also work)\n        tolerance: Tolerance for floating point comparison (default: 1e-6)\n\n    Returns:\n        True if the image is empty (all pixels have the same color), False otherwise\n    \"\"\"\n    if image is None:\n        return True\n\n    # Convert to HWC format\n    img_hwc = _to_hwc_tensor(image)\n\n    # Get the first pixel's RGB values\n    first_pixel = img_hwc[0, 0, 
:]\n\n    h, w, c = img_hwc.shape\n    pixels = img_hwc.reshape(-1, c)\n\n    diff = torch.abs(pixels - first_pixel)\n    max_diff = torch.max(diff)\n\n    return max_diff.item() <= tolerance\n\n\ndef get_image_path(image: torch.Tensor) -> str:\n    \"\"\"\n    Save tensor image to ComfyUI temp directory as PNG and return the path.\n    \"\"\"\n    temp_dir = folder_paths.get_temp_directory()\n\n    # Build file name\n    ts = time.strftime(\"%Y%m%d-%H%M%S\")\n    unique = uuid.uuid4().hex[:8]\n    file_name = f\"sgl_output_{ts}_{unique}.png\"\n    file_path = os.path.join(temp_dir, file_name)\n\n    # Save image\n    np_img = _to_numpy_image(image)\n    img = Image.fromarray(np_img)\n    img.save(file_path, format=\"PNG\")\n\n    return file_path\n\n\ndef convert_b64_to_tensor_image(b64_image: str) -> torch.Tensor:\n    \"\"\"\n    Convert base64 encoded image to ComfyUI IMAGE format (torch.Tensor).\n\n    Args:\n        b64_image: Base64 encoded image string\n\n    Returns:\n        torch.Tensor with shape [batch_size, height, width, channels] (BHWC format),\n        values normalized to [0, 1] range, RGB format (3 channels)\n    \"\"\"\n    # Decode base64\n    image_bytes = base64.b64decode(b64_image)\n\n    # Open image and convert to RGB\n    pil_image = Image.open(io.BytesIO(image_bytes))\n    if pil_image.mode != \"RGB\":\n        pil_image = pil_image.convert(\"RGB\")\n\n    # Convert to numpy array and normalize to [0, 1]\n    image_array = np.array(pil_image).astype(np.float32) / 255.0\n\n    # Add batch dimension: [height, width, channels] -> [1, height, width, channels]\n    image_array = image_array[np.newaxis, ...]\n\n    # Convert to torch.Tensor\n    tensor_image = torch.from_numpy(image_array)\n\n    return tensor_image\n\n\nclass SGLDVideoInput(VideoInput):\n    def __init__(self, video_path: str, height: int, width: int):\n        super().__init__()\n\n        self.video_path = video_path\n        self.height = height\n        self.width = 
width\n\n    def get_dimensions(self) -> tuple[int, int]:\n        \"\"\"\n        Returns the dimensions of the video input.\n\n        Returns:\n            Tuple of (width, height)\n        \"\"\"\n        return self.width, self.height\n\n    def get_components(self):\n        \"\"\"\n        Returns the components of the video input.\n        This is required by the VideoInput abstract base class.\n        \"\"\"\n        return [self.video_path]\n\n    def save_to(self, path: str, format=None, codec=None, metadata=None):\n        \"\"\"\n        Save the video input to a file by copying the source video to `path`.\n        The format, codec, and metadata arguments are currently ignored, and nothing\n        is written if the source video does not exist.\n        \"\"\"\n        save_path = path\n        # Copy video file from video_path to save_path\n        if os.path.exists(self.video_path):\n            # Ensure destination directory exists\n            save_dir = os.path.dirname(save_path)\n            if save_dir:\n                os.makedirs(save_dir, exist_ok=True)\n            shutil.copy2(self.video_path, save_path)\n\n\ndef convert_video_to_comfy_video(\n    video_path: str, height: int, width: int\n) -> VideoInput:\n    \"\"\"\n    Convert video to ComfyUI VIDEO format (VideoInput).\n    \"\"\"\n    video_input = SGLDVideoInput(video_path, height, width)\n    return video_input\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/flux_sgld_sp.json",
    "content": "{\n    \"8\": {\n      \"inputs\": {\n        \"samples\": [\n          \"40\",\n          0\n        ],\n        \"vae\": [\n          \"10\",\n          0\n        ]\n      },\n      \"class_type\": \"VAEDecode\",\n      \"_meta\": {\n        \"title\": \"VAE Decode\"\n      }\n    },\n    \"10\": {\n      \"inputs\": {\n        \"vae_name\": \"ae.safetensors\"\n      },\n      \"class_type\": \"VAELoader\",\n      \"_meta\": {\n        \"title\": \"Load VAE\"\n      }\n    },\n    \"11\": {\n      \"inputs\": {\n        \"clip_name1\": \"t5xxl_fp16.safetensors\",\n        \"clip_name2\": \"clip_l.safetensors\",\n        \"type\": \"flux\",\n        \"device\": \"default\"\n      },\n      \"class_type\": \"DualCLIPLoader\",\n      \"_meta\": {\n        \"title\": \"DualCLIPLoader\"\n      }\n    },\n    \"17\": {\n      \"inputs\": {\n        \"scheduler\": \"normal\",\n        \"steps\": 25,\n        \"denoise\": 1,\n        \"model\": [\n          \"46\",\n          0\n        ]\n      },\n      \"class_type\": \"BasicScheduler\",\n      \"_meta\": {\n        \"title\": \"BasicScheduler\"\n      }\n    },\n    \"38\": {\n      \"inputs\": {\n        \"model\": [\n          \"46\",\n          0\n        ],\n        \"conditioning\": [\n          \"42\",\n          0\n        ]\n      },\n      \"class_type\": \"BasicGuider\",\n      \"_meta\": {\n        \"title\": \"BasicGuider\"\n      }\n    },\n    \"39\": {\n      \"inputs\": {\n        \"filename_prefix\": \"ComfyUI\",\n        \"images\": [\n          \"8\",\n          0\n        ]\n      },\n      \"class_type\": \"SaveImage\",\n      \"_meta\": {\n        \"title\": \"Save Image\"\n      }\n    },\n    \"40\": {\n      \"inputs\": {\n        \"noise\": [\n          \"45\",\n          0\n        ],\n        \"guider\": [\n          \"38\",\n          0\n        ],\n        \"sampler\": [\n          \"47\",\n          0\n        ],\n        \"sigmas\": [\n          \"17\",\n          0\n  
      ],\n        \"latent_image\": [\n          \"44\",\n          0\n        ]\n      },\n      \"class_type\": \"SamplerCustomAdvanced\",\n      \"_meta\": {\n        \"title\": \"SamplerCustomAdvanced\"\n      }\n    },\n    \"42\": {\n      \"inputs\": {\n        \"guidance\": 3.5,\n        \"conditioning\": [\n          \"43\",\n          0\n        ]\n      },\n      \"class_type\": \"FluxGuidance\",\n      \"_meta\": {\n        \"title\": \"FluxGuidance\"\n      }\n    },\n    \"43\": {\n      \"inputs\": {\n        \"text\": \"beautiful photography of a gonger haired artist with Lots of Colorful coloursplashes in face and pn her hands, she is natural, having her hair in a casual bun, looking happily into camera, cinematic,\",\n        \"speak_and_recognation\": {\n          \"__value__\": [\n            false,\n            true\n          ]\n        },\n        \"clip\": [\n          \"11\",\n          0\n        ]\n      },\n      \"class_type\": \"CLIPTextEncode\",\n      \"_meta\": {\n        \"title\": \"CLIP Text Encode (Prompt)\"\n      }\n    },\n    \"44\": {\n      \"inputs\": {\n        \"width\": 1024,\n        \"height\": 1024,\n        \"batch_size\": 1\n      },\n      \"class_type\": \"EmptySD3LatentImage\",\n      \"_meta\": {\n        \"title\": \"EmptySD3LatentImage\"\n      }\n    },\n    \"45\": {\n      \"inputs\": {\n        \"noise_seed\": 747172083610812\n      },\n      \"class_type\": \"RandomNoise\",\n      \"_meta\": {\n        \"title\": \"RandomNoise\"\n      }\n    },\n    \"46\": {\n      \"inputs\": {\n        \"max_shift\": 1.15,\n        \"base_shift\": 0.5,\n        \"width\": 1024,\n        \"height\": 1024,\n        \"model\": [\n          \"51\",\n          0\n        ]\n      },\n      \"class_type\": \"ModelSamplingFlux\",\n      \"_meta\": {\n        \"title\": \"ModelSamplingFlux\"\n      }\n    },\n    \"47\": {\n      \"inputs\": {\n        \"sampler_name\": \"euler\"\n      },\n      \"class_type\": 
\"KSamplerSelect\",\n      \"_meta\": {\n        \"title\": \"KSamplerSelect\"\n      }\n    },\n    \"51\": {\n      \"inputs\": {\n        \"unet_name\": \"flux1-dev.safetensors\",\n        \"weight_dtype\": \"default\",\n        \"sgld_options\": [\n          \"52\",\n          0\n        ]\n      },\n      \"class_type\": \"SGLDUNETLoader\",\n      \"_meta\": {\n        \"title\": \"SGLDiffusion UNET Loader\"\n      }\n    },\n    \"52\": {\n      \"inputs\": {\n        \"model_type\": \"auto-detect\",\n        \"enable_torch_compile\": false,\n        \"num_gpus\": 2,\n        \"tp_size\": -1,\n        \"sp_degree\": -1,\n        \"ulysses_degree\": -1,\n        \"ring_degree\": -1,\n        \"dp_size\": 1,\n        \"dp_degree\": 1,\n        \"enable_cfg_parallel\": false,\n        \"attention_backend\": \"\",\n        \"cache_strategy\": \"none\"\n      },\n      \"class_type\": \"SGLDOptions\",\n      \"_meta\": {\n        \"title\": \"SGLDiffusion Options\"\n      }\n    }\n  }\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/qwen_image_sgld.json",
    "content": "{\n    \"3\": {\n      \"inputs\": {\n        \"seed\": 808633539418610,\n        \"steps\": 4,\n        \"cfg\": 1,\n        \"sampler_name\": \"euler\",\n        \"scheduler\": \"simple\",\n        \"denoise\": 1,\n        \"model\": [\n          \"66\",\n          0\n        ],\n        \"positive\": [\n          \"6\",\n          0\n        ],\n        \"negative\": [\n          \"7\",\n          0\n        ],\n        \"latent_image\": [\n          \"58\",\n          0\n        ]\n      },\n      \"class_type\": \"KSampler\",\n      \"_meta\": {\n        \"title\": \"KSampler\"\n      }\n    },\n    \"6\": {\n      \"inputs\": {\n        \"text\": \"\\\"A vibrant, warm neon-lit street scene in Hong Kong at the afternoon, with a mix of colorful Chinese and English signs glowing brightly. The atmosphere is lively, cinematic, and rain-washed with reflections on the pavement. The colors are vivid, full of pink, blue, red, and green hues. Crowded buildings with overlapping neon signs. 1980s Hong Kong style. 
Signs include:\\n\\\"龍鳳冰室\\\" \\\"金華燒臘\\\" \\\"HAPPY HAIR\\\" \\\"鴻運茶餐廳\\\" \\\"EASY BAR\\\" \\\"永發魚蛋粉\\\" \\\"添記粥麵\\\" \\\"SUNSHINE MOTEL\\\" \\\"美都餐室\\\" \\\"富記糖水\\\" \\\"太平館\\\" \\\"雅芳髮型屋\\\" \\\"STAR KTV\\\" \\\"銀河娛樂城\\\" \\\"百樂門舞廳\\\" \\\"BUBBLE CAFE\\\" \\\"萬豪麻雀館\\\" \\\"CITY LIGHTS BAR\\\" \\\"瑞祥香燭莊\\\" \\\"文記文具\\\" \\\"GOLDEN JADE HOTEL\\\" \\\"LOVELY BEAUTY\\\" \\\"合興百貨\\\" \\\"興旺電器\\\" And the background is warm yellow street and with all stores' lights on.\",\n        \"speak_and_recognation\": {\n          \"__value__\": [\n            false,\n            true\n          ]\n        },\n        \"clip\": [\n          \"38\",\n          0\n        ]\n      },\n      \"class_type\": \"CLIPTextEncode\",\n      \"_meta\": {\n        \"title\": \"CLIP Text Encode (Positive Prompt)\"\n      }\n    },\n    \"7\": {\n      \"inputs\": {\n        \"text\": \"\",\n        \"speak_and_recognation\": {\n          \"__value__\": [\n            false,\n            true\n          ]\n        },\n        \"clip\": [\n          \"38\",\n          0\n        ]\n      },\n      \"class_type\": \"CLIPTextEncode\",\n      \"_meta\": {\n        \"title\": \"CLIP Text Encode (Negative Prompt)\"\n      }\n    },\n    \"8\": {\n      \"inputs\": {\n        \"samples\": [\n          \"3\",\n          0\n        ],\n        \"vae\": [\n          \"39\",\n          0\n        ]\n      },\n      \"class_type\": \"VAEDecode\",\n      \"_meta\": {\n        \"title\": \"VAE Decode\"\n      }\n    },\n    \"38\": {\n      \"inputs\": {\n        \"clip_name\": \"qwen_2.5_vl_7b_fp8_scaled.safetensors\",\n        \"type\": \"qwen_image\",\n        \"device\": \"default\"\n      },\n      \"class_type\": \"CLIPLoader\",\n      \"_meta\": {\n        \"title\": \"Load CLIP\"\n      }\n    },\n    \"39\": {\n      \"inputs\": {\n        \"vae_name\": \"qwen_image_vae.safetensors\"\n      },\n      \"class_type\": \"VAELoader\",\n      \"_meta\": {\n        \"title\": \"Load VAE\"\n      }\n    
},\n    \"58\": {\n      \"inputs\": {\n        \"width\": 1328,\n        \"height\": 1328,\n        \"batch_size\": 1\n      },\n      \"class_type\": \"EmptySD3LatentImage\",\n      \"_meta\": {\n        \"title\": \"EmptySD3LatentImage\"\n      }\n    },\n    \"60\": {\n      \"inputs\": {\n        \"filename_prefix\": \"ComfyUI\"\n      },\n      \"class_type\": \"SaveImage\",\n      \"_meta\": {\n        \"title\": \"Save Image\"\n      }\n    },\n    \"66\": {\n      \"inputs\": {\n        \"shift\": 3.1000000000000005,\n        \"model\": [\n          \"78\",\n          0\n        ]\n      },\n      \"class_type\": \"ModelSamplingAuraFlow\",\n      \"_meta\": {\n        \"title\": \"ModelSamplingAuraFlow\"\n      }\n    },\n    \"77\": {\n      \"inputs\": {\n        \"unet_name\": \"qwen_image_2512_bf16.safetensors\",\n        \"weight_dtype\": \"default\"\n      },\n      \"class_type\": \"SGLDUNETLoader\",\n      \"_meta\": {\n        \"title\": \"SGLDiffusion UNET Loader\"\n      }\n    },\n    \"78\": {\n      \"inputs\": {\n        \"lora_name\": \"Qwen-Image-2512-Lightning-4steps-V1.0-bf16.safetensors\",\n        \"strength_model\": 1,\n        \"nickname\": \"\",\n        \"target\": \"all\",\n        \"model\": [\n          \"77\",\n          0\n        ]\n      },\n      \"class_type\": \"SGLDLoraLoader\",\n      \"_meta\": {\n        \"title\": \"SGLDiffusion LoRA Loader\"\n      }\n    }\n  }\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/sgld_image2video.json",
    "content": "{\n    \"1\": {\n      \"inputs\": {\n        \"base_url\": \"http://localhost:3000/v1\",\n        \"api_key\": \"sk-proj-1234567890\"\n      },\n      \"class_type\": \"SGLDiffusionServerModel\",\n      \"_meta\": {\n        \"title\": \"SGLDiffusion Server Model\"\n      }\n    },\n    \"3\": {\n      \"inputs\": {\n        \"prompt\": \"The girl turn the body and spin around in place.\",\n        \"main\": \"none\",\n        \"lighting\": \"none\",\n        \"speak_and_recognation\": {\n          \"__value__\": [\n            false,\n            true\n          ]\n        }\n      },\n      \"class_type\": \"easy prompt\",\n      \"_meta\": {\n        \"title\": \"Prompt\"\n      }\n    },\n    \"4\": {\n      \"inputs\": {\n        \"text\": \"\",\n        \"anything\": [\n          \"1\",\n          1\n        ]\n      },\n      \"class_type\": \"easy showAnything\",\n      \"_meta\": {\n        \"title\": \"Show Any\"\n      }\n    },\n    \"15\": {\n      \"inputs\": {\n        \"positive_prompt\": [\n          \"3\",\n          0\n        ],\n        \"negative_prompt\": \"\",\n        \"seed\": 2435791308,\n        \"steps\": 50,\n        \"cfg\": 4,\n        \"width\": 704,\n        \"height\": 1280,\n        \"num_frames\": 16,\n        \"fps\": 16,\n        \"seconds\": 1,\n        \"enable_teacache\": false,\n        \"sgld_client\": [\n          \"1\",\n          0\n        ],\n        \"image\": [\n          \"17\",\n          0\n        ]\n      },\n      \"class_type\": \"SGLDiffusionGenerateVideo\",\n      \"_meta\": {\n        \"title\": \"SGLDiffusion Generate Video\"\n      }\n    },\n    \"16\": {\n      \"inputs\": {\n        \"filename_prefix\": \"video/ComfyUI\",\n        \"format\": \"auto\",\n        \"codec\": \"auto\",\n        \"video-preview\": \"\",\n        \"video\": [\n          \"15\",\n          0\n        ]\n      },\n      \"class_type\": \"SaveVideo\",\n      \"_meta\": {\n        \"title\": \"save video\"\n   
   }\n    },\n    \"17\": {\n      \"inputs\": {\n        \"image\": \"tmpe_w0bd_0.jpg\"\n      },\n      \"class_type\": \"LoadImage\",\n      \"_meta\": {\n        \"title\": \"load image\"\n      }\n    }\n  }\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/sgld_text2img.json",
    "content": "{\n    \"1\": {\n      \"inputs\": {\n        \"base_url\": \"http://localhost:3000/v1\",\n        \"api_key\": \"sk-proj-1234567890\"\n      },\n      \"class_type\": \"SGLDiffusionServerModel\",\n      \"_meta\": {\n        \"title\": \"SGLDiffusion Server Model\"\n      }\n    },\n    \"3\": {\n      \"inputs\": {\n        \"prompt\": \"a bicycle, illustration in the style of SMPL, thick black lines on a white background\",\n        \"main\": \"none\",\n        \"lighting\": \"none\",\n        \"speak_and_recognation\": {\n          \"__value__\": [\n            false,\n            true\n          ]\n        }\n      },\n      \"class_type\": \"easy prompt\",\n      \"_meta\": {\n        \"title\": \"Prompt\"\n      }\n    },\n    \"4\": {\n      \"inputs\": {\n        \"text\": \"\",\n        \"anything\": [\n          \"1\",\n          1\n        ]\n      },\n      \"class_type\": \"easy showAnything\",\n      \"_meta\": {\n        \"title\": \"Show Any\"\n      }\n    },\n    \"5\": {\n      \"inputs\": {\n        \"filename_prefix\": \"ComfyUI\",\n        \"images\": [\n          \"6\",\n          0\n        ]\n      },\n      \"class_type\": \"SaveImage\",\n      \"_meta\": {\n        \"title\": \"save image\"\n      }\n    },\n    \"6\": {\n      \"inputs\": {\n        \"positive_prompt\": [\n          \"3\",\n          0\n        ],\n        \"negative_prompt\": \"\",\n        \"seed\": 4215918563,\n        \"steps\": 50,\n        \"cfg\": 4,\n        \"width\": 512,\n        \"height\": 512,\n        \"enable_teacache\": false,\n        \"sgld_client\": [\n          \"11\",\n          0\n        ],\n        \"image\": [\n          \"14\",\n          0\n        ]\n      },\n      \"class_type\": \"SGLDiffusionGenerateImage\",\n      \"_meta\": {\n        \"title\": \"SGLDiffusion Generate Image\"\n      }\n    },\n    \"11\": {\n      \"inputs\": {\n        \"lora_name\": \"dvyio/flux-lora-simple-illustration\",\n        \"lora_nickname\": 
\"\",\n        \"target\": \"all\",\n        \"sgld_client\": [\n          \"1\",\n          0\n        ]\n      },\n      \"class_type\": \"SGLDiffusionSetLora\",\n      \"_meta\": {\n        \"title\": \"SGLDiffusion Set LoRA\"\n      }\n    },\n    \"14\": {\n      \"inputs\": {\n        \"width\": 512,\n        \"height\": 512,\n        \"batch_size\": 1,\n        \"color\": 0\n      },\n      \"class_type\": \"EmptyImage\",\n      \"_meta\": {\n        \"title\": \"empty image\"\n      }\n    }\n  }\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/z-image_sgld.json",
    "content": "{\n    \"3\": {\n      \"inputs\": {\n        \"seed\": 3338398,\n        \"steps\": 9,\n        \"cfg\": 1,\n        \"sampler_name\": \"euler\",\n        \"scheduler\": \"simple\",\n        \"denoise\": 1,\n        \"model\": [\n          \"28\",\n          0\n        ],\n        \"positive\": [\n          \"6\",\n          0\n        ],\n        \"negative\": [\n          \"7\",\n          0\n        ],\n        \"latent_image\": [\n          \"13\",\n          0\n        ]\n      },\n      \"class_type\": \"KSampler\",\n      \"_meta\": {\n        \"title\": \"KSampler\"\n      }\n    },\n    \"6\": {\n      \"inputs\": {\n        \"text\": \"cute anime style girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black gold leaf pattern dress and a white apron, it is a postcard held by a hand in front of a beautiful realistic city at sunset and there is cursive writing that says \\\"ZImage, Now in ComfyUI\\\"\",\n        \"speak_and_recognation\": {\n          \"__value__\": [\n            false,\n            true\n          ]\n        },\n        \"clip\": [\n          \"18\",\n          0\n        ]\n      },\n      \"class_type\": \"CLIPTextEncode\",\n      \"_meta\": {\n        \"title\": \"CLIP Text Encode (Positive Prompt)\"\n      }\n    },\n    \"7\": {\n      \"inputs\": {\n        \"text\": \"blurry ugly bad\",\n        \"speak_and_recognation\": {\n          \"__value__\": [\n            false,\n            true\n          ]\n        },\n        \"clip\": [\n          \"18\",\n          0\n        ]\n      },\n      \"class_type\": \"CLIPTextEncode\",\n      \"_meta\": {\n        \"title\": \"CLIP Text Encode (Negative Prompt)\"\n      }\n    },\n    \"8\": {\n      \"inputs\": {\n        \"samples\": [\n          \"3\",\n          0\n        ],\n        \"vae\": [\n          \"17\",\n          0\n        ]\n      },\n      \"class_type\": \"VAEDecode\",\n      
\"_meta\": {\n        \"title\": \"VAE Decode\"\n      }\n    },\n    \"9\": {\n      \"inputs\": {\n        \"filename_prefix\": \"ComfyUI\",\n        \"images\": [\n          \"8\",\n          0\n        ]\n      },\n      \"class_type\": \"SaveImage\",\n      \"_meta\": {\n        \"title\": \"Save Image\"\n      }\n    },\n    \"13\": {\n      \"inputs\": {\n        \"width\": 1024,\n        \"height\": 1024,\n        \"batch_size\": 1\n      },\n      \"class_type\": \"EmptySD3LatentImage\",\n      \"_meta\": {\n        \"title\": \"EmptySD3LatentImage\"\n      }\n    },\n    \"17\": {\n      \"inputs\": {\n        \"vae_name\": \"ae.safetensors\"\n      },\n      \"class_type\": \"VAELoader\",\n      \"_meta\": {\n        \"title\": \"VAE Loader\"\n      }\n    },\n    \"18\": {\n      \"inputs\": {\n        \"clip_name\": \"qwen_3_4b.safetensors\",\n        \"type\": \"lumina2\",\n        \"device\": \"default\"\n      },\n      \"class_type\": \"CLIPLoader\",\n      \"_meta\": {\n        \"title\": \"CLIP Loader\"\n      }\n    },\n    \"28\": {\n      \"inputs\": {\n        \"unet_name\": \"z_image_turbo_bf16.safetensors\",\n        \"weight_dtype\": \"default\"\n      },\n      \"class_type\": \"SGLDUNETLoader\",\n      \"_meta\": {\n        \"title\": \"SGLDiffusion UNET Loader\"\n      }\n    }\n}\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/webui/README.md",
    "content": "# SGLang Diffusion WebUI User Guide\n\nSGLang Diffusion WebUI provides an intuitive Gradio-based interface for image and video generation, supporting parameter\ntuning and real-time previews.\n\n## Prerequisites\n\nThe WebUI runs on Gradio. To get started, install Gradio first:\n\n```bash\npip install gradio==6.1.0\n```\n\n## Launch WebUI Service\n\nSGLang Diffusion includes an integrated WebUI. Add the `--webui` flag when starting the service, and use `--webui-port` to choose the port it listens on.\n\n### Launch Text-to-Image Service\n\n```bash\nsglang serve --model-path black-forest-labs/FLUX.1-dev --num-gpus 1 --webui --webui-port 2333\n```\n\n### Launch Text-to-Video Service\n\n```bash\nsglang serve --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers --num-gpus 1 --webui --webui-port 2333\n```\n\n### Launch Image-to-Image Service\n\n```bash\nsglang serve --model-path Qwen/Qwen-Image-Edit-2511 --num-gpus 1 --webui --webui-port 2333\n```\n\n### Launch Image-to-Video Service\n\n```bash\nsglang serve --model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers --num-gpus 1 --webui --webui-port 2333\n```\n\n## Port Forwarding\n\nIf the WebUI service runs on a remote machine, use **SSH port forwarding** to access it securely from\nyour local machine.\n\nIn most cases, your IDE (such as VS Code or Cursor) can handle this automatically; check its remote development\nor port forwarding features. Otherwise, run the following command manually, replacing `${WEBUI_PORT}` with the value passed to `--webui-port` (2333 in the examples above):\n\n```bash\nssh -L ${WEBUI_PORT}:localhost:${WEBUI_PORT} user_name@machine_name\n```\n\nLearn more about port forwarding: [Port Forwarding](https://en.wikipedia.org/wiki/Port_forwarding).\n\n## Interface Instructions\n\nOnce launched, open `http://localhost:${WEBUI_PORT}` in your browser. The UI shows the loaded model path and task name. We'd appreciate any feedback you'd like to share.\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/webui/__init__.py",
    "content": "from .main import run_sgl_diffusion_webui\n\n__all__ = [\"run_sgl_diffusion_webui\"]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/apps/webui/main.py",
    "content": "import argparse\nimport os\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import (\n    DataType,\n    SamplingParams,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import (\n    post_process_sample,\n    prepare_request,\n)\nfrom sglang.multimodal_gen.runtime.scheduler_client import sync_scheduler_client\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.srt.environ import envs\n\nlogger = init_logger(__name__)\n\n\ndef add_webui_args(parser: argparse.ArgumentParser):\n    \"\"\"Add the server and sampling arguments used by the WebUI.\"\"\"\n    parser = ServerArgs.add_cli_args(parser)\n    parser = SamplingParams.add_cli_args(parser)\n    return parser\n\n\ndef run_sgl_diffusion_webui(server_args: ServerArgs):\n    # Import gradio lazily so environments without it (e.g. CI) do not crash at import time\n\n    import gradio as gr\n\n    def resolve_model_repo_id(model_path: str) -> str:\n        from pathlib import Path\n\n        from huggingface_hub.utils import HFValidationError, validate_repo_id\n\n        try:\n            validate_repo_id(model_path)\n            return model_path\n        except HFValidationError:\n            pass\n\n        p = Path(model_path).expanduser()\n        parts = p.parts\n\n        if len(parts) < 2:\n            raise ValueError(f\"Invalid model_path: {model_path}\")\n\n        candidate = f\"{parts[-2]}/{parts[-1]}\"\n        validate_repo_id(candidate)  # let it raise if invalid\n        return candidate\n\n    repo_id = resolve_model_repo_id(server_args.model_path)\n    if envs.SGLANG_USE_MODELSCOPE.get():\n        from modelscope.hub.api import HubApi\n\n        api = HubApi()\n        model_info_obj = api.model_info(repo_id)\n        task_name = model_info_obj.tasks[0][\"Name\"].replace(\"-synthesis\", \"\")\n    else:\n        from huggingface_hub import model_info\n\n        task_name = model_info(repo_id).pipeline_tag\n\n    # init 
client\n    sync_scheduler_client.initialize(server_args)\n\n    if task_name in (\"text-to-video\", \"image-to-video\", \"video-to-video\"):\n        task_type = \"video\"\n    elif task_name in [\"text-to-image\", \"image-to-image\"]:\n        task_type = \"image\"\n    else:\n        raise ValueError(\n            f\"The task name {task_name} of model {server_args.model_path} is not a valid task name. Please check the model path.\"\n        )\n    video_visible_only = task_type == \"video\"\n    image_visible_only = task_type == \"image\"\n\n    # server_args will be reused in gradio_generate function\n    def gradio_generate(\n        prompt,\n        negative_prompt,\n        reference_image_paths_str,\n        seed,\n        num_frames,\n        frames_per_second,\n        width,\n        height,\n        num_inference_steps,\n        guidance_scale,\n        enable_teacache,\n    ):\n        \"\"\"\n        NOTE: The inputs and outputs of a Gradio button callback must map to Gradio components,\n        so server_args is captured from the enclosing scope instead of being passed as a parameter.\n        Returns (np.ndarray, None) for image tasks or (None, video_path) for video tasks.\n        \"\"\"\n        if reference_image_paths_str:\n            if \"，\" in reference_image_paths_str:\n                logger.warning(\n                    f\"Full-width commas (，) found in reference image paths; replacing them with ASCII commas (,). Got: {reference_image_paths_str}\"\n                )\n                reference_image_paths_str = reference_image_paths_str.replace(\"，\", \",\")\n            image_path = [path.strip() for path in reference_image_paths_str.split(\",\")]\n        else:\n            image_path = None\n\n        sampling_params_kwargs = dict(\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            image_path=image_path,\n            seed=seed,\n            num_frames=num_frames,\n            
fps=frames_per_second,\n            width=width,\n            height=height,\n            guidance_scale=guidance_scale,\n            num_inference_steps=num_inference_steps,\n            enable_teacache=enable_teacache,\n            return_file_paths_only=False,\n        )\n        sampling_params = SamplingParams.from_user_sampling_params_args(\n            server_args.model_path,\n            server_args=server_args,\n            **sampling_params_kwargs,\n        )\n        batch = prepare_request(\n            server_args=server_args,\n            sampling_params=sampling_params,\n        )\n        result = sync_scheduler_client.forward([batch])\n        save_file_path = str(os.path.join(batch.output_path, batch.output_file_name))\n        if result.output is None:\n            sampling_params_str = \"\\n\".join(\n                [f\"{key}: {value}\" for key, value in sampling_params_kwargs.items()]\n            )\n            no_output_msg = f\"No output was generated by the client. Sampling params:\\n{sampling_params_str}\"\n\n            if batch.data_type == DataType.VIDEO:\n                if os.path.exists(save_file_path):\n                    logger.warning(no_output_msg)\n                    return None, save_file_path\n                else:\n                    no_output_msg += f\"\\nThe expected output file was also not found at: {save_file_path}\"\n                    raise ValueError(no_output_msg)\n            else:\n                raise ValueError(no_output_msg)\n\n        frames = post_process_sample(\n            result.output[0],\n            batch.data_type,\n            batch.fps,\n            batch.save_output,\n            save_file_path,\n        )\n        if batch.data_type == DataType.VIDEO:\n            # Gradio's Video component needs a file path to display the video\n            return None, save_file_path\n        else:\n            return frames[0], None\n\n    with gr.Blocks() as demo:\n        gr.Markdown(\"# 🚀 SGLang Diffusion 
Application\")\n        with gr.Row():\n            launched_model_box = gr.Textbox(label=\"Model\", value=server_args.model_path)\n            task_name_box = gr.Textbox(label=\"Task name\", value=task_name)\n\n        with gr.Row():\n            with gr.Column(scale=4):\n                prompt = gr.Textbox(label=\"Prompt\", value=\"A curious raccoon\")\n                negative_prompt = gr.Textbox(\n                    label=\"Negative_prompt\",\n                    value=\"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\",\n                )\n            with gr.Column(scale=1):\n                seed = gr.Number(label=\"seed\", precision=0, value=1234)\n                run_btn = gr.Button(\"Generate\", variant=\"primary\", size=\"lg\")\n\n        with gr.Row():\n            with gr.Column():\n                width = gr.Number(label=\"width\", precision=0, value=720)\n                height = gr.Number(label=\"height\", precision=0, value=480)\n                num_inference_steps = gr.Slider(\n                    minimum=0, maximum=50, value=20, step=1, label=\"num_inference_steps\"\n                )\n                guidance_scale = gr.Slider(\n                    minimum=0.0, maximum=10, value=5, step=0.01, label=\"guidance_scale\"\n                )\n                num_frames = gr.Slider(\n                    minimum=1,\n                    maximum=181,\n                    value=81,\n                    step=1,\n                    label=\"num_frames\",\n                    visible=video_visible_only,\n                )\n                frames_per_second = gr.Slider(\n                    minimum=4,\n                    
maximum=60,\n                    value=16,\n                    step=1,\n                    label=\"frames_per_second\",\n                    visible=video_visible_only,\n                )\n                reference_image_paths_str = gr.Textbox(\n                    label=\"reference images\",\n                    placeholder=\"Examples: 'image1.png, image2.png' or 'https://example.com/image1.png, https://example.com/image2.png'\",\n                )\n                enable_teacache = gr.Checkbox(label=\"enable_teacache\", value=False)\n\n            with gr.Column():\n                image_out = gr.Image(\n                    label=\"Generated Image\", visible=image_visible_only, format=\"png\"\n                )\n                video_out = gr.Video(\n                    label=\"Generated Video\", visible=video_visible_only\n                )\n\n        run_btn.click(\n            fn=gradio_generate,\n            inputs=[\n                prompt,\n                negative_prompt,\n                reference_image_paths_str,\n                seed,\n                num_frames,\n                frames_per_second,\n                width,\n                height,\n                num_inference_steps,\n                guidance_scale,\n                enable_teacache,\n            ],\n            outputs=[image_out, video_out],\n        )\n\n        _, local_url, _ = demo.launch(\n            server_port=server_args.webui_port,\n            quiet=True,\n            prevent_thread_lock=True,\n            show_error=True,\n        )\n\n        # print banner\n        delimiter = \"=\" * 80\n        url = local_url or f\"http://localhost:{server_args.webui_port}\"\n        print(f\"\"\"\n{delimiter}\n\\033[1mSGLang Diffusion WebUI available at:\\033[0m \\033[1;4;92m{url}\\033[0m\n{delimiter}\n\"\"\")\n\n        demo.block_thread()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/benchmarks/bench_offline_throughput.py",
    "content": "\"\"\"\nBenchmark offline throughput for multimodal generation models (Image/Video Generation).\n\nThis script benchmarks generation throughput without running a server, using low-level APIs.\nIt provides detailed metrics on throughput, latency, and resource utilization.\n\n# Usage Examples\n\n## Text-to-Video with VBench dataset\npython -m sglang.multimodal_gen.benchmarks.bench_offline_throughput \\\\\n    --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \\\\\n    --dataset vbench \\\\\n    --num-prompts 20 \\\\\n    --batch-size 1 \\\\\n    --width 512 --height 512 --num-frames 16\n\n## Random dataset for stress testing\npython -m sglang.multimodal_gen.benchmarks.bench_offline_throughput \\\\\n    --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \\\\\n    --dataset random \\\\\n    --num-prompts 100 \\\\\n    --batch-size 1 \\\\\n    --num-inference-steps 20 \\\\\n    --output-file results.json\n\"\"\"\n\nimport argparse\nimport dataclasses\nimport json\nimport time\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Tuple\n\nimport torch\nfrom tqdm import tqdm\n\nfrom sglang.multimodal_gen.benchmarks.datasets import RandomDataset, VBenchDataset\nfrom sglang.multimodal_gen.runtime.entrypoints.diffusion_generator import DiffGenerator\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs, set_global_server_args\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import (\n    configure_logger,\n    init_logger,\n)\nfrom sglang.multimodal_gen.test.test_utils import print_divider, print_value_formatted\n\nlogger = init_logger(__name__)\n\n\n@dataclass\nclass BatchOutput:\n    \"\"\"Container for batch generation results.\"\"\"\n\n    latency: float = 0.0\n    latency_per_sample: float = 0.0\n    num_samples: int = 0\n    total_frames: int = 0\n    peak_memory_mb: float = 0.0\n    success: bool = False\n    error: str = \"\"\n\n\n@dataclass\nclass BenchArgs:\n    \"\"\"Benchmark configuration for multimodal 
generation.\"\"\"\n\n    # Diffusion Model Configuration\n    num_inference_steps: int = 20\n    guidance_scale: float = 7.5\n    seed: int = 42\n    disable_safety_checker: bool = False\n\n    # Output Configuration\n    width: int = 32\n    height: int = 32\n    num_frames: int = 1\n    fps: int = 24\n\n    # Dataset & Benchmark\n    dataset: str = \"random\"\n    dataset_path: str = \"\"\n    task_name: str = \"unknown\"\n    num_prompts: int = 10\n    batch_size: int = 1\n\n    # Benchmark Execution\n    skip_warmup: bool = False\n    output_file: str = \"\"\n    disable_tqdm: bool = False\n\n    @staticmethod\n    def add_cli_args(parser: argparse.ArgumentParser):\n        \"\"\"Add benchmark-specific CLI arguments.\"\"\"\n        # Diffusion Model Configuration\n        parser.add_argument(\n            \"--num-inference-steps\",\n            type=int,\n            default=20,\n            help=\"Number of denoising steps\",\n        )\n        parser.add_argument(\n            \"--guidance-scale\",\n            type=float,\n            default=7.5,\n            help=\"Classifier-free guidance scale\",\n        )\n        parser.add_argument(\"--seed\", type=int, default=42, help=\"Random seed\")\n        parser.add_argument(\n            \"--disable-safety-checker\",\n            action=\"store_true\",\n            help=\"Disable NSFW detection\",\n        )\n\n        # Output Configuration\n        parser.add_argument(\"--width\", type=int, default=32, help=\"Image/video width\")\n        parser.add_argument(\"--height\", type=int, default=32, help=\"Image/video height\")\n        parser.add_argument(\n            \"--num-frames\", type=int, default=1, help=\"Number of frames for video\"\n        )\n        parser.add_argument(\"--fps\", type=int, default=24, help=\"FPS for video\")\n\n        # Dataset & Benchmark\n        parser.add_argument(\n            \"--dataset\",\n            type=str,\n            default=\"random\",\n            
choices=[\"vbench\", \"random\"],\n            help=\"Dataset to use\",\n        )\n        parser.add_argument(\n            \"--dataset-path\",\n            type=str,\n            default=\"\",\n            help=\"Path to dataset (prompts file or image directory)\",\n        )\n        parser.add_argument(\n            \"--task-name\",\n            type=str,\n            default=\"unknown\",\n            help=\"Task name for benchmark identification\",\n        )\n        parser.add_argument(\n            \"--num-prompts\",\n            type=int,\n            default=10,\n            help=\"Total number of prompts to benchmark\",\n        )\n        parser.add_argument(\n            \"--batch-size\",\n            type=int,\n            default=1,\n            help=\"Batch size per generation call (currently only bs=1 is supported)\",\n        )\n\n        # Benchmark Execution\n        parser.add_argument(\n            \"--skip-warmup\", action=\"store_true\", help=\"Skip warmup batch\"\n        )\n        parser.add_argument(\n            \"--output-file\",\n            type=str,\n            default=\"\",\n            help=\"Output JSON file for results (append mode)\",\n        )\n        parser.add_argument(\n            \"--disable-tqdm\",\n            action=\"store_true\",\n            help=\"Disable progress bar\",\n        )\n\n    @classmethod\n    def from_cli_args(cls, args: argparse.Namespace):\n        \"\"\"Create BenchArgs from parsed CLI arguments.\"\"\"\n        attrs = [attr.name for attr in dataclasses.fields(cls)]\n        return cls(**{attr: getattr(args, attr) for attr in attrs})\n\n\ndef initialize_engine(server_args: ServerArgs) -> DiffGenerator:\n    \"\"\"Initialize diffusion pipeline engine.\"\"\"\n    logger.info(\"Initializing engine...\")\n    engine = DiffGenerator.from_server_args(server_args, local_mode=True)\n    logger.info(\"Engine initialized successfully\")\n    return engine\n\n\ndef generate_batch(\n    engine: 
DiffGenerator,\n    bench_args: BenchArgs,\n    prompts: List[str],\n    user_sampling_params: Dict[str, Any],\n) -> BatchOutput:\n    \"\"\"Generate batch of images/videos synchronously.\"\"\"\n    output = BatchOutput()\n    start_time = time.perf_counter()\n\n    torch.cuda.reset_peak_memory_stats()\n\n    for prompt in prompts:\n        try:\n            sampling_params_kwargs = dict(user_sampling_params)\n            sampling_params_kwargs[\"prompt\"] = prompt\n            result = engine.generate(sampling_params_kwargs=sampling_params_kwargs)\n\n            if result is not None:\n                if isinstance(result, list):\n                    output.total_frames += len(result)\n                else:\n                    output.total_frames += 1\n            output.num_samples += 1\n        except Exception as e:\n            logger.error(f\"Generation failed for prompt '{prompt[:50]}...': {e}\")\n            output.error = str(e)\n\n    output.latency = time.perf_counter() - start_time\n    output.latency_per_sample = output.latency / len(prompts) if prompts else 0.0\n    output.success = output.num_samples > 0\n    output.peak_memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)\n\n    logger.debug(\n        f\"Batch generated: {output.num_samples}/{len(prompts)} samples in {output.latency:.2f}s\"\n    )\n\n    return output\n\n\ndef calculate_metrics(\n    outputs: List[BatchOutput],\n    total_duration: float,\n    resolution: Tuple[int, int, int],\n    num_requests: int,\n) -> Dict[str, Any]:\n    \"\"\"Calculate generation-specific throughput metrics.\"\"\"\n    successful = [o for o in outputs if o.success]\n    num_success = sum(o.num_samples for o in successful)\n    total_frames = sum(o.total_frames for o in successful)\n    peak_memory = max((o.peak_memory_mb for o in outputs), default=0)\n\n    width, height, frames = resolution\n    pixels_per_sample = width * height * frames\n    total_pixels = num_success * pixels_per_sample\n\n    
metrics = {\n        \"num_requests\": num_requests,\n        \"successful_requests\": num_success,\n        \"failed_requests\": num_requests - num_success,\n        \"total_duration_seconds\": total_duration,\n        \"total_frames_generated\": total_frames,\n        \"total_pixels_generated\": total_pixels,\n        \"images_per_second\": num_success / total_duration if total_duration > 0 else 0,\n        \"frames_per_second\": total_frames / total_duration if total_duration > 0 else 0,\n        \"megapixels_per_second\": (\n            total_pixels / (total_duration * 1e6) if total_duration > 0 else 0\n        ),\n        \"requests_per_second\": (\n            num_success / total_duration if total_duration > 0 else 0\n        ),\n        \"latency_per_request_seconds\": (\n            total_duration / num_success if num_success > 0 else 0\n        ),\n        \"peak_memory_mb\": peak_memory,\n    }\n\n    return metrics\n\n\ndef throughput_test(\n    server_args: ServerArgs,\n    bench_args: BenchArgs,\n) -> Dict[str, Any]:\n    \"\"\"Main throughput benchmark function.\"\"\"\n    configure_logger(server_args=server_args)\n    logger.info(\"Starting offline throughput benchmark...\")\n\n    engine = initialize_engine(server_args)\n\n    logger.info(f\"Loading {bench_args.dataset} dataset...\")\n    if bench_args.dataset == \"vbench\":\n        bench_args.task_name = engine.server_args.pipeline_config.task_type\n        dataset = VBenchDataset(bench_args)\n    elif bench_args.dataset == \"random\":\n        dataset = RandomDataset(bench_args)\n    else:\n        raise ValueError(f\"Unknown dataset: {bench_args.dataset}\")\n\n    sampling_params = {\n        \"guidance_scale\": bench_args.guidance_scale,\n        \"num_inference_steps\": bench_args.num_inference_steps,\n        \"height\": bench_args.height,\n        \"width\": bench_args.width,\n        \"num_frames\": bench_args.num_frames,\n        \"seed\": bench_args.seed,\n    }\n    if 
bench_args.disable_safety_checker:\n        sampling_params[\"safety_checker\"] = None\n\n    if not bench_args.skip_warmup:\n        logger.info(\"Running warmup batch...\")\n        warmup_count = min(bench_args.batch_size, len(dataset))\n        warmup_prompts = [dataset[i].prompt for i in range(warmup_count)]\n        generate_batch(engine, bench_args, warmup_prompts, sampling_params)\n\n    logger.info(f\"Running benchmark with {bench_args.num_prompts} prompts...\")\n    outputs: List[BatchOutput] = []\n    total_count = min(bench_args.num_prompts, len(dataset))\n    all_prompts = [dataset[i].prompt for i in range(total_count)]\n\n    start_time = time.perf_counter()\n\n    num_batches = (total_count + bench_args.batch_size - 1) // bench_args.batch_size\n    pbar = tqdm(\n        total=num_batches,\n        disable=bench_args.disable_tqdm,\n        desc=\"Benchmark\",\n    )\n\n    for batch_start in range(0, total_count, bench_args.batch_size):\n        batch_end = min(batch_start + bench_args.batch_size, total_count)\n        batch_prompts = all_prompts[batch_start:batch_end]\n\n        batch_output = generate_batch(\n            engine, bench_args, batch_prompts, sampling_params\n        )\n        outputs.append(batch_output)\n\n        pbar.update(1)\n\n    pbar.close()\n    total_duration = time.perf_counter() - start_time\n\n    resolution = (bench_args.width, bench_args.height, bench_args.num_frames)\n    metrics = calculate_metrics(\n        outputs,\n        total_duration,\n        resolution=resolution,\n        num_requests=total_count,\n    )\n\n    display_results(\n        metrics,\n        bench_args,\n        model_path=server_args.model_path,\n    )\n\n    if bench_args.output_file:\n        save_results(metrics, bench_args, server_args)\n\n    return metrics\n\n\ndef display_results(\n    metrics: Dict[str, Any],\n    bench_args: BenchArgs,\n    model_path: str,\n):\n    \"\"\"Display benchmark results in console.\"\"\"\n    print(\n        
\"\\n{s:{c}^{n}}\".format(s=\" Offline Throughput Benchmark Result \", n=110, c=\"=\")\n    )\n    print_value_formatted(\"Model:\", model_path)\n    print_value_formatted(\"Dataset:\", bench_args.dataset)\n    print_value_formatted(\n        \"Resolution:\",\n        f\"{bench_args.width}x{bench_args.height}x{bench_args.num_frames}\",\n    )\n    print_value_formatted(\"Num Inference Steps:\", bench_args.num_inference_steps)\n    print_divider(75)\n    print_value_formatted(\"Total Requests:\", metrics[\"num_requests\"])\n    print_value_formatted(\"Successful Requests:\", metrics[\"successful_requests\"])\n    print_value_formatted(\"Failed Requests:\", metrics[\"failed_requests\"])\n    print_value_formatted(\n        \"Total Duration (seconds):\", metrics[\"total_duration_seconds\"]\n    )\n    print_divider(75)\n    print_value_formatted(\"Frames Generated:\", metrics[\"total_frames_generated\"])\n    print_value_formatted(\n        \"Megapixels Generated:\", metrics[\"total_pixels_generated\"] / 1e6\n    )\n    print_divider(75)\n    print_value_formatted(\n        \"Frame Throughput (frames/sec):\", metrics[\"frames_per_second\"]\n    )\n    print_value_formatted(\"MP Throughput (MP/sec):\", metrics[\"megapixels_per_second\"])\n    print_value_formatted(\"Requests Per Second:\", metrics[\"requests_per_second\"])\n    print_value_formatted(\n        \"Latency Per Request (sec):\", metrics[\"latency_per_request_seconds\"]\n    )\n    print_value_formatted(\"Peak Memory (MB):\", metrics[\"peak_memory_mb\"])\n    print_divider(110, \"=\")\n\n\ndef save_results(\n    metrics: Dict[str, Any],\n    bench_args: BenchArgs,\n    server_args: ServerArgs,\n):\n    \"\"\"Save benchmark results to JSON file.\"\"\"\n    result = {\n        \"metadata\": {\n            \"timestamp\": time.strftime(\"%Y-%m-%dT%H:%M:%S\"),\n            \"model_path\": server_args.model_path,\n            \"task_type\": bench_args.task_name,\n            \"backend\": \"engine\",\n        },\n  
      \"configuration\": {\n            \"num_inference_steps\": bench_args.num_inference_steps,\n            \"guidance_scale\": bench_args.guidance_scale,\n            \"seed\": bench_args.seed,\n            \"batch_size\": bench_args.batch_size,\n            \"num_prompts\": bench_args.num_prompts,\n            \"resolution\": f\"{bench_args.width}x{bench_args.height}x{bench_args.num_frames}\",\n            \"dataset\": bench_args.dataset,\n        },\n        \"results\": metrics,\n    }\n\n    with open(bench_args.output_file, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n    logger.info(f\"Results saved to {bench_args.output_file}\")\n\n\ndef main():\n    \"\"\"Main entry point.\"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"Offline throughput benchmark for multimodal generation models\"\n    )\n\n    ServerArgs.add_cli_args(parser)\n    BenchArgs.add_cli_args(parser)\n\n    args = parser.parse_args()\n\n    server_args = ServerArgs.from_cli_args(args)\n    bench_args = BenchArgs.from_cli_args(args)\n\n    set_global_server_args(server_args)\n\n    result = throughput_test(server_args, bench_args)\n\n    return result\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/benchmarks/bench_serving.py",
    "content": "\"\"\"\nBenchmark online serving for diffusion models (Image/Video Generation).\n\n\nUsage:\n    # launch a server and benchmark on it\n\n    # T2V or T2I or any other multimodal generation model\n    sglang serve --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers --num-gpus 1 --port 1231\n\n    # benchmark it and make sure the port is the same as the server's port\n    python3 -m sglang.multimodal_gen.benchmarks.bench_serving --dataset vbench --num-prompts 20 --port 1231\n\n    # benchmark with SLO metrics enabled\n    python3 -m sglang.multimodal_gen.benchmarks.bench_serving --dataset vbench --num-prompts 20 --port 1231 --slo --slo-scale 3.0 --warmup-requests 2\n\"\"\"\n\nimport argparse\nimport asyncio\nimport json\nimport os\nimport time\nfrom dataclasses import replace\nfrom typing import Any, Dict, List, Optional\n\nimport aiohttp\nimport numpy as np\nimport requests\nfrom tqdm.asyncio import tqdm\n\nfrom sglang.multimodal_gen.benchmarks.datasets import (\n    RandomDataset,\n    RequestFuncInput,\n    RequestFuncOutput,\n    VBenchDataset,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import (\n    configure_logger,\n    init_logger,\n)\nfrom sglang.multimodal_gen.test.test_utils import print_divider, print_value_formatted\n\nlogger = init_logger(__name__)\n\n# Patch size used for computing area units (e.g. 
in latent diffusion models).\nPATCH_SIZE = 16\nPATCH_AREA = PATCH_SIZE * PATCH_SIZE\n\n\ndef _compute_scale_factor(req: RequestFuncInput, args) -> Optional[float]:\n    \"\"\"Computes the composite scale factor (area × frames × steps) for a request.\"\"\"\n    width = req.width or args.width\n    height = req.height or args.height\n    if None in (width, height):\n        return None\n    frames = req.num_frames or args.num_frames\n    steps = req.num_inference_steps or args.num_inference_steps\n\n    frame_scale = frames if isinstance(frames, int) and frames > 0 else 1\n    step_scale = steps if isinstance(steps, int) and steps > 0 else 1\n\n    area_units = max((float(width) * float(height)) / float(PATCH_AREA), 1.0)\n    return area_units * float(frame_scale) * float(step_scale)\n\n\ndef _compute_expected_latency_ms_from_base(\n    req: RequestFuncInput, args, base_time_ms: Optional[float]\n) -> Optional[float]:\n    \"\"\"Scales latency linearly by pixel area, frame count, and inference steps.\"\"\"\n    if base_time_ms is None:\n        return None\n    scale = _compute_scale_factor(req, args)\n    if scale is None:\n        return None\n    return float(base_time_ms) * scale\n\n\ndef _infer_slo_base_time_ms_from_warmups(\n    warmup_pairs: List[tuple], args\n) -> Optional[float]:\n    \"\"\"Derives median base latency from successful warmup runs.\"\"\"\n    candidates_ms: List[float] = []\n    for req, out in warmup_pairs:\n        if not out.success or out.latency <= 0:\n            logger.warning(\n                f\"Skipping warmup result: success={out.success}, latency={out.latency:.3f}\"\n            )\n            continue\n\n        scale = _compute_scale_factor(req, args)\n        if scale is None or scale <= 0:\n            continue\n\n        candidates_ms.append((out.latency * 1000.0) / scale)\n\n    return float(np.median(candidates_ms)) if candidates_ms else None\n\n\ndef _populate_slo_ms_from_warmups(\n    requests_list: List[RequestFuncInput], 
warmup_pairs: List[tuple], args\n) -> List[RequestFuncInput]:\n    \"\"\"Assigns estimated SLO targets to requests lacking them.\"\"\"\n    if not any(req.slo_ms is None for req in requests_list):\n        return requests_list\n\n    base_time_ms = _infer_slo_base_time_ms_from_warmups(warmup_pairs, args)\n    if base_time_ms is None:\n        return requests_list\n\n    slo_scale = float(getattr(args, \"slo_scale\", 3.0))\n    if slo_scale <= 0:\n        raise ValueError(f\"slo_scale must be positive, got {slo_scale}.\")\n\n    updated: List[RequestFuncInput] = []\n    for req in requests_list:\n        if req.slo_ms is not None:\n            updated.append(req)\n            continue\n        expected_ms = _compute_expected_latency_ms_from_base(req, args, base_time_ms)\n        if expected_ms is not None:\n            # Create a new RequestFuncInput with updated slo_ms\n            updated.append(replace(req, slo_ms=expected_ms * slo_scale))\n        else:\n            updated.append(req)\n\n    return updated\n\n\nasync def async_request_image_sglang(\n    input: RequestFuncInput,\n    session: aiohttp.ClientSession,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    output = RequestFuncOutput()\n    output.start_time = time.perf_counter()\n\n    # Check if we need to use multipart (for image edits with input images)\n    if input.image_paths and len(input.image_paths) > 0:\n        # Use multipart/form-data for image edits\n        data = aiohttp.FormData()\n        data.add_field(\"model\", input.model)\n        data.add_field(\"prompt\", input.prompt)\n        data.add_field(\"response_format\", \"b64_json\")\n\n        if input.width and input.height:\n            data.add_field(\"size\", f\"{input.width}x{input.height}\")\n\n        # Merge extra parameters\n        for key, value in input.extra_body.items():\n            data.add_field(key, str(value))\n\n        # Add image file(s)\n        for idx, img_path in enumerate(input.image_paths):\n   
         if os.path.exists(img_path):\n                data.add_field(\n                    \"image\",\n                    open(img_path, \"rb\"),\n                    filename=os.path.basename(img_path),\n                    content_type=\"application/octet-stream\",\n                )\n            else:\n                output.error = f\"Image file not found: {img_path}\"\n                output.success = False\n                if pbar:\n                    pbar.update(1)\n                return output\n\n        try:\n            async with session.post(input.api_url, data=data) as response:\n                if response.status == 200:\n                    resp_json = await response.json()\n                    output.response_body = resp_json\n                    output.success = True\n                    if \"peak_memory_mb\" in resp_json:\n                        output.peak_memory_mb = resp_json[\"peak_memory_mb\"]\n                else:\n                    output.error = f\"HTTP {response.status}: {await response.text()}\"\n                    output.success = False\n        except Exception as e:\n            output.error = str(e)\n            output.success = False\n    else:\n        # Use JSON for text-to-image generation\n        payload = {\n            \"model\": input.model,\n            \"prompt\": input.prompt,\n            \"n\": 1,\n            \"response_format\": \"b64_json\",\n        }\n\n        if input.width and input.height:\n            payload[\"size\"] = f\"{input.width}x{input.height}\"\n\n        # Merge extra parameters\n        payload.update(input.extra_body)\n\n        try:\n            async with session.post(input.api_url, json=payload) as response:\n                if response.status == 200:\n                    resp_json = await response.json()\n                    output.response_body = resp_json\n                    output.success = True\n                    if \"peak_memory_mb\" in resp_json:\n                        
output.peak_memory_mb = resp_json[\"peak_memory_mb\"]\n                else:\n                    output.error = f\"HTTP {response.status}: {await response.text()}\"\n                    output.success = False\n        except Exception as e:\n            output.error = str(e)\n            output.success = False\n\n    output.latency = time.perf_counter() - output.start_time\n\n    # Check SLO if defined\n    if input.slo_ms is not None and output.success:\n        output.slo_achieved = (output.latency * 1000.0) <= input.slo_ms\n\n    if pbar:\n        pbar.update(1)\n    return output\n\n\nasync def async_request_video_sglang(\n    input: RequestFuncInput,\n    session: aiohttp.ClientSession,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    output = RequestFuncOutput()\n    output.start_time = time.perf_counter()\n\n    # 1. Submit Job\n    job_id = None\n    # Check if we need to upload images (Multipart) or just send JSON\n    if input.image_paths and len(input.image_paths) > 0:\n        # Use multipart/form-data\n        data = aiohttp.FormData()\n        data.add_field(\"model\", input.model)\n        data.add_field(\"prompt\", input.prompt)\n\n        if input.width and input.height:\n            data.add_field(\"size\", f\"{input.width}x{input.height}\")\n\n        # Add extra body fields to form data if possible, or assume simple key-values\n        # Note: Nested dicts in extra_body might need JSON serialization if API expects it stringified\n        if input.extra_body:\n            data.add_field(\"extra_body\", json.dumps(input.extra_body))\n\n        # Explicitly add fps/num_frames if they are not in extra_body (bench_serving logic overrides)\n        if input.num_frames:\n            data.add_field(\"num_frames\", str(input.num_frames))\n        if input.fps:\n            data.add_field(\"fps\", str(input.fps))\n\n        # Add image file\n        # Currently only support single image upload as 'input_reference' per API spec\n        
img_path = input.image_paths[0]\n        if os.path.exists(img_path):\n            data.add_field(\n                \"input_reference\",\n                open(img_path, \"rb\"),\n                filename=os.path.basename(img_path),\n                content_type=\"application/octet-stream\",\n            )\n        else:\n            output.error = f\"Image file not found: {img_path}\"\n            output.success = False\n            if pbar:\n                pbar.update(1)\n            return output\n\n        try:\n            async with session.post(input.api_url, data=data) as response:\n                if response.status == 200:\n                    resp_json = await response.json()\n                    job_id = resp_json.get(\"id\")\n                else:\n                    output.error = (\n                        f\"Submit failed HTTP {response.status}: {await response.text()}\"\n                    )\n                    output.success = False\n                    if pbar:\n                        pbar.update(1)\n                    return output\n        except Exception as e:\n            output.error = f\"Submit exception: {str(e)}\"\n            output.success = False\n            if pbar:\n                pbar.update(1)\n            return output\n\n    else:\n        # Use JSON\n        payload: Dict[str, Any] = {\n            \"model\": input.model,\n            \"prompt\": input.prompt,\n        }\n        if input.width and input.height:\n            payload[\"size\"] = f\"{input.width}x{input.height}\"\n        if input.num_frames:\n            payload[\"num_frames\"] = input.num_frames\n        if input.fps:\n            payload[\"fps\"] = input.fps\n\n        payload.update(input.extra_body)\n\n        try:\n            async with session.post(input.api_url, json=payload) as response:\n                if response.status == 200:\n                    resp_json = await response.json()\n                    job_id = resp_json.get(\"id\")\n          
      else:\n                    output.error = (\n                        f\"Submit failed HTTP {response.status}: {await response.text()}\"\n                    )\n                    output.success = False\n                    if pbar:\n                        pbar.update(1)\n                    return output\n        except Exception as e:\n            output.error = f\"Submit exception: {str(e)}\"\n            output.success = False\n            if pbar:\n                pbar.update(1)\n            return output\n\n    if not job_id:\n        output.error = \"No job_id returned\"\n        output.success = False\n        if pbar:\n            pbar.update(1)\n        return output\n\n    # 2. Poll for completion\n    # Assuming the API returns a 'status' field.\n    # We construct the check URL. Assuming api_url is like .../v1/videos\n    # The check url should be .../v1/videos/{id}\n    check_url = f\"{input.api_url}/{job_id}\"\n\n    while True:\n        try:\n            async with session.get(check_url) as response:\n                if response.status == 200:\n                    status_data = await response.json()\n                    status = status_data.get(\"status\")\n                    if status == \"completed\":\n                        output.success = True\n                        output.response_body = status_data\n                        if \"peak_memory_mb\" in status_data:\n                            output.peak_memory_mb = status_data[\"peak_memory_mb\"]\n                        break\n                    elif status == \"failed\":\n                        output.success = False\n                        output.error = f\"Job failed: {status_data.get('error')}\"\n                        break\n                    else:\n                        # queued or processing\n                        await asyncio.sleep(1.0)\n                else:\n                    output.success = False\n                    output.error = (\n                        
f\"Poll failed HTTP {response.status}: {await response.text()}\"\n                    )\n                    break\n        except Exception as e:\n            output.success = False\n            output.error = f\"Poll exception: {str(e)}\"\n            break\n\n    output.latency = time.perf_counter() - output.start_time\n\n    # Check SLO if defined\n    if input.slo_ms is not None and output.success:\n        output.slo_achieved = (output.latency * 1000.0) <= input.slo_ms\n\n    if pbar:\n        pbar.update(1)\n    return output\n\n\ndef calculate_metrics(\n    outputs: List[RequestFuncOutput],\n    total_duration: float,\n    requests_list: List[RequestFuncInput],\n    args,\n    slo_enabled: bool,\n):\n    success_outputs = [o for o in outputs if o.success]\n    error_outputs = [o for o in outputs if not o.success]\n\n    num_success = len(success_outputs)\n    latencies = [o.latency for o in success_outputs]\n    peak_memories = [o.peak_memory_mb for o in success_outputs if o.peak_memory_mb > 0]\n\n    metrics = {\n        \"duration\": total_duration,\n        \"completed_requests\": num_success,\n        \"failed_requests\": len(error_outputs),\n        \"throughput_qps\": num_success / total_duration if total_duration > 0 else 0,\n        \"latency_mean\": np.mean(latencies) if latencies else 0,\n        \"latency_median\": np.median(latencies) if latencies else 0,\n        \"latency_p99\": np.percentile(latencies, 99) if latencies else 0,\n        \"latency_p50\": np.percentile(latencies, 50) if latencies else 0,\n        \"peak_memory_mb_max\": max(peak_memories) if peak_memories else 0,\n        \"peak_memory_mb_mean\": np.mean(peak_memories) if peak_memories else 0,\n        \"peak_memory_mb_median\": np.median(peak_memories) if peak_memories else 0,\n    }\n\n    if slo_enabled:\n        slo_defined_total = 0\n        slo_met_success = 0\n\n        for req, out in zip(requests_list, outputs):\n            if req.slo_ms is None:\n                
continue\n            slo_defined_total += 1\n            if out.slo_achieved:\n                slo_met_success += 1\n\n        slo_attain_all = (\n            (slo_met_success / slo_defined_total) if slo_defined_total > 0 else 0.0\n        )\n\n        metrics.update(\n            {\n                \"slo_attainment_rate\": slo_attain_all,\n                \"slo_met_success\": slo_met_success,\n                \"slo_scale\": getattr(args, \"slo_scale\", 3.0),\n            }\n        )\n\n    return metrics\n\n\ndef wait_for_service(base_url: str, timeout: int = 1200) -> None:\n    logger.info(f\"Waiting for service at {base_url}...\")\n    start_time = time.time()\n    while True:\n        try:\n            # Try /health endpoint first\n            resp = requests.get(f\"{base_url}/health\", timeout=1)\n            if resp.status_code == 200:\n                logger.info(\"Service is ready.\")\n                break\n        except requests.exceptions.RequestException:\n            pass\n\n        if time.time() - start_time > timeout:\n            raise TimeoutError(\n                f\"Service at {base_url} did not start within {timeout} seconds.\"\n            )\n\n        time.sleep(1)\n\n\nasync def benchmark(args):\n    from huggingface_hub import model_info\n\n    # Construct base_url if not provided\n    if args.base_url is None:\n        args.base_url = f\"http://{args.host}:{args.port}\"\n\n    # Wait for service\n    wait_for_service(args.base_url)\n\n    # Fetch model info\n    try:\n        resp = requests.get(f\"{args.base_url}/v1/model_info\", timeout=5)\n        if resp.status_code == 200:\n            info = resp.json()\n            if \"model_path\" in info and info[\"model_path\"]:\n                args.model = info[\"model_path\"]\n                logger.info(f\"Updated model name from server: {args.model}\")\n    except Exception as e:\n        logger.info(f\"Failed to fetch model info: {e}. 
Using default: {args.model}\")\n\n    valid_tasks = (\n        \"text-to-video\",\n        \"image-to-video\",\n        \"video-to-video\",\n        \"text-to-image\",\n        \"image-to-image\",\n    )\n\n    # Resolve task_name with priority: args.task > local config > HF pipeline_tag\n    if args.task:\n        task_name = args.task\n        logger.info(f\"Using task from --task: {task_name}\")\n    elif os.path.exists(args.model):\n        config_path = os.path.join(args.model, \"config.json\")\n        if os.path.exists(config_path):\n            with open(config_path, \"r\") as f:\n                config = json.load(f)\n            task_name = config.get(\"pipeline_tag\", \"text-to-image\")\n            logger.info(f\"Inferred task from local config.json: {task_name}\")\n        else:\n            task_name = \"text-to-image\"\n            logger.info(f\"No config.json found, defaulting task to: {task_name}\")\n    else:\n        task_name = model_info(args.model).pipeline_tag\n        logger.info(f\"Inferred task from HuggingFace pipeline_tag: {task_name}\")\n\n    if task_name not in valid_tasks:\n        raise ValueError(\n            f\"Task '{task_name}' is not a valid multimodal generation task. 
\"\n            f\"Use --task to specify one of: {', '.join(valid_tasks)}\"\n        )\n\n    if task_name in (\"text-to-video\", \"image-to-video\", \"video-to-video\"):\n        api_url = f\"{args.base_url}/v1/videos\"\n        request_func = async_request_video_sglang\n    else:  # text-to-image or image-to-image\n        api_url = (\n            f\"{args.base_url}/v1/images/edits\"\n            if task_name == \"image-to-image\"\n            else f\"{args.base_url}/v1/images/generations\"\n        )\n        request_func = async_request_image_sglang\n\n    setattr(args, \"task_name\", task_name)\n\n    if args.dataset == \"vbench\":\n        dataset = VBenchDataset(args, api_url, args.model)\n    elif args.dataset == \"random\":\n        dataset = RandomDataset(args, api_url, args.model)\n    else:\n        raise ValueError(f\"Unknown dataset: {args.dataset}\")\n\n    logger.info(f\"Loading requests...\")\n    requests_list = dataset.get_requests()\n    logger.info(f\"Prepared {len(requests_list)} requests from {args.dataset} dataset.\")\n\n    # Limit concurrency\n    if args.max_concurrency is not None:\n        semaphore = asyncio.Semaphore(args.max_concurrency)\n    else:\n        semaphore = None\n\n    async def limited_request_func(req, session, pbar):\n        if semaphore:\n            async with semaphore:\n                return await request_func(req, session, pbar)\n        else:\n            return await request_func(req, session, pbar)\n\n    async with aiohttp.ClientSession() as session:\n        # Run warmup requests\n        warmup_pairs: List[tuple] = []\n        if args.warmup_requests and requests_list:\n            # The server always overrides warmup requests to use\n            # num_inference_steps=1 (see Req.set_as_warmup), so we match\n            # that here to keep the benchmark's SLO estimation consistent.\n            warmup_steps = 1\n            logger.info(\n                f\"Running {args.warmup_requests} warmup request(s) 
with \"\n                f\"num_inference_steps={warmup_steps}...\"\n            )\n            for i in range(args.warmup_requests):\n                warm_req = requests_list[i % len(requests_list)]\n                warm_req = replace(\n                    warm_req,\n                    num_inference_steps=warmup_steps,\n                )\n                warm_out = await limited_request_func(warm_req, session, None)\n                warmup_pairs.append((warm_req, warm_out))\n                logger.info(\n                    f\"Warmup {i+1}/{args.warmup_requests}: \"\n                    f\"latency={warm_out.latency:.2f}s, success={warm_out.success}\"\n                )\n\n        # Populate SLO values from warmups if enabled\n        if args.slo:\n            requests_list = _populate_slo_ms_from_warmups(\n                requests_list=requests_list, warmup_pairs=warmup_pairs, args=args\n            )\n\n        # Run benchmark\n        pbar = tqdm(total=len(requests_list), disable=args.disable_tqdm)\n        start_time = time.perf_counter()\n        tasks = []\n        for req in requests_list:\n            if args.request_rate != float(\"inf\"):\n                # Poisson process: inter-arrival times follow exponential distribution\n                interval = np.random.exponential(1.0 / args.request_rate)\n                await asyncio.sleep(interval)\n\n            task = asyncio.create_task(limited_request_func(req, session, pbar))\n            tasks.append(task)\n\n        outputs = await asyncio.gather(*tasks)\n        total_duration = time.perf_counter() - start_time\n\n        pbar.close()\n\n    # Calculate metrics\n    metrics = calculate_metrics(outputs, total_duration, requests_list, args, args.slo)\n\n    print(\"\\n{s:{c}^{n}}\".format(s=\" Serving Benchmark Result \", n=60, c=\"=\"))\n\n    # Section 1: Configuration\n    print_value_formatted(\"Task:\", task_name)\n    print_value_formatted(\"Model:\", args.model)\n    
print_value_formatted(\"Dataset:\", args.dataset)\n\n    # Section 2: Execution & Traffic\n    print_divider(50)\n    print_value_formatted(\"Benchmark duration (s):\", metrics[\"duration\"])\n    print_value_formatted(\"Request rate:\", str(args.request_rate))\n    print_value_formatted(\n        \"Max request concurrency:\",\n        str(args.max_concurrency) if args.max_concurrency else \"not set\",\n    )\n    print_value_formatted(\n        \"Successful requests:\",\n        f\"{metrics['completed_requests']}/{len(requests_list)}\",\n    )\n\n    # Section 3: Performance Metrics\n    print_divider(50)\n\n    print_value_formatted(\"Request throughput (req/s):\", metrics[\"throughput_qps\"])\n\n    print_value_formatted(\"Latency Mean (s):\", metrics[\"latency_mean\"])\n    print_value_formatted(\"Latency Median (s):\", metrics[\"latency_median\"])\n    print_value_formatted(\"Latency P99 (s):\", metrics[\"latency_p99\"])\n\n    if metrics[\"peak_memory_mb_max\"] > 0:\n        print_divider(50)\n        print_value_formatted(\"Peak Memory Max (MB):\", metrics[\"peak_memory_mb_max\"])\n        print_value_formatted(\"Peak Memory Mean (MB):\", metrics[\"peak_memory_mb_mean\"])\n        print_value_formatted(\n            \"Peak Memory Median (MB):\", metrics[\"peak_memory_mb_median\"]\n        )\n\n    if args.slo and \"slo_attainment_rate\" in metrics:\n        print_divider(50)\n        print(\n            \"{:<40} {:<15.2%}\".format(\n                \"SLO Attainment Rate:\", metrics[\"slo_attainment_rate\"]\n            )\n        )\n        print(\"{:<40} {:<15}\".format(\"SLO Met (Success):\", metrics[\"slo_met_success\"]))\n        print(\"{:<40} {:<15.2f}\".format(\"SLO Scale:\", metrics[\"slo_scale\"]))\n\n    print_divider(60)\n\n    if args.output_file:\n        with open(args.output_file, \"w\") as f:\n            json.dump(metrics, f, indent=2)\n        print(f\"Metrics saved to {args.output_file}\")\n\n\nif __name__ == \"__main__\":\n    parser = 
argparse.ArgumentParser(\n        description=\"Benchmark serving for diffusion models.\"\n    )\n    parser.add_argument(\n        \"--backend\",\n        type=str,\n        default=None,\n        help=\"DEPRECATED: --backend is deprecated and will be ignored. The task is inferred from --model; use --task to override it.\",\n    )\n    parser.add_argument(\n        \"--base-url\",\n        type=str,\n        default=None,\n        help=\"Base URL of the server (e.g., http://localhost:30000). Overrides host/port.\",\n    )\n    parser.add_argument(\"--host\", type=str, default=\"localhost\", help=\"Server host.\")\n    parser.add_argument(\"--port\", type=int, default=30000, help=\"Server port.\")\n    parser.add_argument(\"--model\", type=str, default=\"default\", help=\"Model name.\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"vbench\",\n        choices=[\"vbench\", \"random\"],\n        help=\"Dataset to use.\",\n    )\n    parser.add_argument(\n        \"--task\",\n        type=str,\n        choices=[\n            \"text-to-video\",\n            \"image-to-video\",\n            \"text-to-image\",\n            \"image-to-image\",\n            \"video-to-video\",\n        ],\n        default=None,\n        help=\"Task to benchmark. If not set, it is inferred from --model: from the pipeline_tag in config.json for a local path (default: text-to-image), otherwise from the HuggingFace pipeline_tag.\",\n    )\n    parser.add_argument(\n        \"--dataset-path\",\n        type=str,\n        default=None,\n        help=\"Path to local dataset file (optional).\",\n    )\n    parser.add_argument(\n        \"--num-prompts\", type=int, default=10, help=\"Number of prompts to benchmark.\"\n    )\n    parser.add_argument(\n        \"--max-concurrency\",\n        type=int,\n        default=1,\n        help=\"Maximum number of concurrent requests (default: 1).
 This can be used \"\n        \"to help simulate an environment where a higher-level component \"\n        \"is enforcing a maximum number of concurrent requests. While the \"\n        \"--request-rate argument controls the rate at which requests are \"\n        \"initiated, this argument controls how many are actually allowed \"\n        \"to execute at a time. This means that when both are used, the \"\n        \"actual request rate may be lower than the one specified with --request-rate \"\n        \"if the server cannot process requests fast enough to keep up.\",\n    )\n    parser.add_argument(\n        \"--request-rate\",\n        type=float,\n        default=float(\"inf\"),\n        help=\"Number of requests per second. If this is inf, all requests are sent at time 0. \"\n        \"Otherwise, a Poisson process is used to synthesize the request arrival times. Default is inf.\",\n    )\n    parser.add_argument(\"--width\", type=int, default=None, help=\"Image/Video width.\")\n    parser.add_argument(\"--height\", type=int, default=None, help=\"Image/Video height.\")\n    parser.add_argument(\n        \"--num-frames\", type=int, default=None, help=\"Number of frames (for video).\"\n    )\n    parser.add_argument(\"--fps\", type=int, default=None, help=\"FPS (for video).\")\n    parser.add_argument(\n        \"--output-file\", type=str, default=None, help=\"Output JSON file for metrics.\"\n    )\n    parser.add_argument(\n        \"--disable-tqdm\", action=\"store_true\", help=\"Disable progress bar.\"\n    )\n    parser.add_argument(\n        \"--log-level\",\n        type=str,\n        default=\"INFO\",\n        choices=[\"DEBUG\", \"INFO\", \"WARNING\", \"ERROR\"],\n        help=\"Log level.\",\n    )\n    parser.add_argument(\n        \"--slo\",\n        action=\"store_true\",\n        help=\"Enable SLO calculation.
 Uses trace-provided slo_ms if available, otherwise infers it from warmups.\",\n    )\n    parser.add_argument(\n        \"--slo-scale\",\n        type=float,\n        default=3.0,\n        help=\"SLO target multiplier: slo_ms = estimated_exec_time_ms * slo_scale (default: 3).\",\n    )\n    parser.add_argument(\n        \"--warmup-requests\",\n        type=int,\n        default=1,\n        help=\"Number of warmup requests to run before measurement.\",\n    )\n    parser.add_argument(\n        \"--num-inference-steps\",\n        type=int,\n        default=None,\n        help=\"Number of inference steps for diffusion models.\",\n    )\n\n    args = parser.parse_args()\n\n    configure_logger(args)\n\n    asyncio.run(benchmark(args))\n"
  },
  {
    "path": "python/sglang/multimodal_gen/benchmarks/compare_perf.py",
    "content": "import argparse\nimport json\nimport os\nimport re\nfrom datetime import datetime\nfrom typing import Any, Dict, List, Tuple\n\n\ndef calculate_diff(base: float, new: float) -> Tuple[float, float]:\n    \"\"\"Returns (diff, diff_percent).\"\"\"\n    diff = new - base\n    if base == 0:\n        percent = 0.0\n    else:\n        percent = (diff / base) * 100\n    return diff, percent\n\n\ndef calculate_upper_bound(baseline: float, rel_tol: float, min_abs_tol: float) -> float:\n    \"\"\"Calculates the upper bound for performance regression check.\"\"\"\n    rel_limit = baseline * (1 + rel_tol)\n    abs_limit = baseline + min_abs_tol\n    return max(rel_limit, abs_limit)\n\n\ndef calculate_lower_bound(baseline: float, rel_tol: float, min_abs_tol: float) -> float:\n    \"\"\"Calculates the lower bound for performance improvement check.\"\"\"\n    rel_lower = baseline * (1 - rel_tol)\n    abs_lower = baseline - min_abs_tol\n    return min(rel_lower, abs_lower)\n\n\ndef get_perf_status_emoji(\n    baseline: float,\n    new: float,\n    rel_tol: float = 0.1,\n    min_abs_tol: float = 120.0,\n) -> str:\n    \"\"\"\n    Determines the status emoji based on performance difference.\n\n    Logic:\n      Upper bound (Slower): max(baseline * (1 + rel_tol), baseline + min_abs_tol)\n      Lower bound (Faster): min(baseline * (1 - rel_tol), baseline - min_abs_tol)\n    \"\"\"\n    upper_bound = calculate_upper_bound(baseline, rel_tol, min_abs_tol)\n    lower_bound = calculate_lower_bound(baseline, rel_tol, min_abs_tol)\n\n    if new > upper_bound:\n        return \"🔴\"\n    elif new < lower_bound:\n        return \"🟢\"\n    else:\n        return \"⚪️\"\n\n\ndef consolidate_steps(\n    steps_list: List[Dict[str, Any]],\n) -> Tuple[Dict[str, float], List[str], Dict[str, int]]:\n    \"\"\"\n    Aggregates specific repeating steps (like denoising_step_*) into groups.\n    Returns:\n        - aggregated_durations: {name: duration_ms}\n        - ordered_names: list of 
names in execution order\n        - counts: {name: count_of_steps_aggregated}\n    \"\"\"\n    durations = {}\n    counts = {}\n    ordered_names = []\n    seen_names = set()\n\n    # Regex for steps to group\n    # Group \"denoising_step_0\", \"denoising_step_1\" -> \"Denoising Loop\"\n    denoise_pattern = re.compile(r\"^denoising_step_(\\d+)$\")\n    denoising_group_name = \"Denoising Loop\"\n\n    for step in steps_list:\n        name = step.get(\"name\", \"unknown\")\n        dur = step.get(\"duration_ms\", 0.0)\n\n        match = denoise_pattern.match(name)\n        if match:\n            key = denoising_group_name\n            if key not in durations:\n                durations[key] = 0.0\n                counts[key] = 0\n                if key not in seen_names:\n                    ordered_names.append(key)\n                    seen_names.add(key)\n            durations[key] += dur\n            counts[key] += 1\n        else:\n            # Standard stage (preserve order)\n            if name not in durations:\n                durations[name] = 0.0\n                counts[name] = 0\n                if name not in seen_names:\n                    ordered_names.append(name)\n                    seen_names.add(name)\n            durations[name] += dur\n            counts[name] += 1\n\n    return durations, ordered_names, counts\n\n\ndef _load_benchmark_file(file_path: str) -> Dict[str, Any]:\n    \"\"\"Loads a benchmark JSON file.\"\"\"\n    with open(file_path, \"r\", encoding=\"utf-8\") as f:\n        return json.load(f)\n\n\ndef _get_status_emoji_from_diff_percent(diff_pct):\n    if diff_pct < -2.0:\n        return \"✅\"\n    elif diff_pct > 2.0:\n        return \"❌\"\n    else:\n        return \"⚪️\"\n\n\ndef _print_single_comparison_report(\n    others_data, base_e2e, combined_order, base_durations, others_processed, base_counts\n):\n    new_data = others_data[0]\n    new_e2e = new_data.get(\"total_duration_ms\", 0)\n    diff_ms, diff_pct = 
calculate_diff(base_e2e, new_e2e)\n    status = _get_status_emoji_from_diff_percent(diff_pct)\n\n    print(\"#### 1. High-level Summary\")\n    print(\"| Metric | Baseline | New | Diff | Status |\")\n    print(\"| :--- | :--- | :--- | :--- | :--- |\")\n    print(\n        f\"| **E2E Latency** | {base_e2e:.2f} ms | {new_e2e:.2f} ms | **{diff_ms:+.2f} ms ({diff_pct:+.1f}%)** | {status} |\"\n    )\n    print(\n        f\"| **Throughput** | {1000 / base_e2e if base_e2e else 0:.2f} req/s | {1000 / new_e2e if new_e2e else 0:.2f} req/s | - | - |\"\n    )\n    print(\"\\n\")\n\n    print(\"#### 2. Stage Breakdown\")\n    print(\"| Stage Name | Baseline (ms) | New (ms) | Diff (ms) | Diff (%) | Status |\")\n    print(\"| :--- | :--- | :--- | :--- | :--- | :--- |\")\n\n    new_durations, _, new_counts = others_processed[0]\n\n    for stage in combined_order:\n        b_val = base_durations.get(stage, 0.0)\n        n_val = new_durations.get(stage, 0.0)\n        b_count = base_counts.get(stage, 1)\n        n_count = new_counts.get(stage, 1)\n\n        s_diff, s_pct = calculate_diff(b_val, n_val)\n\n        count_str = \"\"\n        if stage == \"Denoising Loop\":\n            count_str = (\n                f\" ({n_count} steps)\"\n                if n_count == b_count\n                else f\" ({b_count}->{n_count} steps)\"\n            )\n\n        status_emoji = get_perf_status_emoji(b_val, n_val)\n        print(\n            f\"| {stage}{count_str} | {b_val:.2f} | {n_val:.2f} | {s_diff:+.2f} | {s_pct:+.1f}% | {status_emoji} |\"\n        )\n\n\ndef _print_multi_comparison_report(\n    base_e2e,\n    others_data,\n    other_labels,\n    combined_order,\n    base_durations,\n    others_processed,\n):\n    print(\"#### 1. 
High-level Summary\")\n    header = \"| Metric | Baseline | \" + \" | \".join(other_labels) + \" |\"\n    sep = \"| :--- | :--- | \" + \" | \".join([\":---\"] * len(other_labels)) + \" |\"\n    print(header)\n    print(sep)\n\n    # E2E Row\n    row_e2e = f\"| **E2E Latency** | {base_e2e:.2f} ms |\"\n    for i, d in enumerate(others_data):\n        val = d.get(\"total_duration_ms\", 0)\n        diff_ms, diff_pct = calculate_diff(base_e2e, val)\n\n        status = _get_status_emoji_from_diff_percent(diff_pct)\n\n        row_e2e += f\" {val:.2f} ms ({diff_pct:+.1f}%) {status} |\"\n    print(row_e2e)\n    print(\"\\n\")\n\n    print(\"#### 2. Stage Breakdown\")\n    # Header: Stage | Baseline | Label1 | Label2 ...\n    header = \"| Stage Name | Baseline | \" + \" | \".join(other_labels) + \" |\"\n    sep = \"| :--- | :--- | \" + \" | \".join([\":---\"] * len(other_labels)) + \" |\"\n    print(header)\n    print(sep)\n\n    for stage in combined_order:\n        b_val = base_durations.get(stage, 0.0)\n        row_str = f\"| {stage} | {b_val:.2f} |\"\n\n        for i, (n_durations, _, n_counts) in enumerate(others_processed):\n            n_val = n_durations.get(stage, 0.0)\n            _, s_pct = calculate_diff(b_val, n_val)\n            status_emoji = get_perf_status_emoji(b_val, n_val)\n\n            row_str += f\" {n_val:.2f} ({s_pct:+.1f}%) {status_emoji} |\"\n        print(row_str)\n\n\ndef compare_benchmarks(file_paths: List[str], output_format: str = \"markdown\"):\n    \"\"\"\n    Compares benchmark JSON files and prints a report.\n    First file is baseline, others will be compared against it.\n    \"\"\"\n    if len(file_paths) < 2:\n        print(\"Error: Need at least 2 files to compare.\")\n        return\n\n    try:\n        data_list = [_load_benchmark_file(f) for f in file_paths]\n    except Exception as e:\n        print(f\"Error loading benchmark files: {e}\")\n        return\n\n    base_data = data_list[0]\n    others_data = data_list[1:]\n\n    # Use 
filenames as labels if multiple comparisons, else just \"New\"\n    other_labels = [os.path.basename(p) for p in file_paths[1:]]\n\n    base_e2e = base_data.get(\"total_duration_ms\", 0)\n\n    base_durations, base_order, base_counts = consolidate_steps(\n        base_data.get(\"steps\", [])\n    )\n\n    others_processed = []\n    for d in others_data:\n        dur, order, counts = consolidate_steps(d.get(\"steps\", []))\n        others_processed.append((dur, order, counts))\n\n    combined_order = []\n    # Collect all unique stages maintaining order from newest to baseline\n    for _, order, _ in reversed(others_processed):\n        for name in order:\n            if name not in combined_order:\n                combined_order.append(name)\n    for name in base_order:\n        if name not in combined_order:\n            combined_order.append(name)\n\n    if output_format == \"markdown\":\n        print(\"### Performance Comparison Report\\n\")\n\n        if len(others_data) == 1:\n            _print_single_comparison_report(\n                others_data,\n                base_e2e,\n                combined_order,\n                base_durations,\n                others_processed,\n                base_counts,\n            )\n        else:\n            _print_multi_comparison_report(\n                base_e2e,\n                others_data,\n                other_labels,\n                combined_order,\n                base_durations,\n                others_processed,\n            )\n\n        print(\"\\n\")\n        # Metadata\n        print(\"<details>\")\n        print(\"<summary>Metadata</summary>\\n\")\n        print(f\"- Baseline Commit: `{base_data.get('commit_hash', 'N/A')}`\")\n        for i, d in enumerate(others_data):\n            label = \"New\" if len(others_data) == 1 else other_labels[i]\n            print(f\"- {label} Commit: `{d.get('commit_hash', 'N/A')}`\")\n        print(f\"- Timestamp: {datetime.now().isoformat()}\")\n        
print(\"</details>\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Compare sglang-diffusion performance JSON files.\"\n    )\n    parser.add_argument(\n        \"files\",\n        nargs=\"+\",\n        help=\"List of JSON files. First is baseline, others are compared against it.\",\n    )\n    args = parser.parse_args()\n\n    compare_benchmarks(args.files)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/benchmarks/datasets.py",
    "content": "import glob\nimport json\nimport os\nimport re\nimport subprocess\nimport uuid\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass, field\nfrom typing import Any, Dict, List, Optional\n\nimport requests\nfrom PIL import Image\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n@dataclass\nclass RequestFuncInput:\n    prompt: str\n    api_url: str = \"\"\n    model: str = \"\"\n    width: Optional[int] = None\n    height: Optional[int] = None\n    num_frames: Optional[int] = None\n    fps: Optional[int] = None\n    extra_body: Dict[str, Any] = field(default_factory=dict)\n    image_paths: Optional[List[str]] = None\n    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))\n    slo_ms: Optional[float] = None\n    num_inference_steps: Optional[int] = None\n\n\n@dataclass\nclass RequestFuncOutput:\n    success: bool = False\n    latency: float = 0.0\n    error: str = \"\"\n    start_time: float = 0.0\n    response_body: Dict[str, Any] = field(default_factory=dict)\n    peak_memory_mb: float = 0.0\n    slo_achieved: Optional[bool] = None\n\n\ndef is_dir_not_empty(path: str) -> bool:\n    return os.path.isdir(path) and bool(os.listdir(path))\n\n\nclass BaseDataset(ABC):\n    def __init__(self, args, api_url: str = \"\", model: str = \"\"):\n        self.args = args\n        self.api_url = api_url\n        self.model = model\n        self.items: List[Dict[str, Any]] = []\n\n    @abstractmethod\n    def __len__(self) -> int:\n        pass\n\n    @abstractmethod\n    def __getitem__(self, idx: int) -> RequestFuncInput:\n        pass\n\n    def get_requests(self) -> List[RequestFuncInput]:\n        return [self[i] for i in range(len(self))]\n\n\nclass VBenchDataset(BaseDataset):\n    \"\"\"\n    Dataset loader for VBench prompts.\n    Supports t2v, i2v.\n    \"\"\"\n\n    T2V_PROMPT_URL = 
\"https://raw.githubusercontent.com/Vchitect/VBench/master/prompts/prompts_per_dimension/subject_consistency.txt\"\n    I2V_DOWNLOAD_SCRIPT_URL = \"https://raw.githubusercontent.com/Vchitect/VBench/master/vbench2_beta_i2v/download_data.sh\"\n\n    def __init__(self, args, api_url: str = \"\", model: str = \"\"):\n        super().__init__(args, api_url, model)\n        self.cache_dir = os.path.join(os.path.expanduser(\"~\"), \".cache\", \"sglang\")\n        self.items = self._load_data()\n\n    def _load_data(self) -> List[Dict[str, Any]]:\n        if self.args.task_name in (\"text-to-video\", \"text-to-image\", \"video-to-video\"):\n            return self._load_t2v_prompts()\n        elif self.args.task_name in (\"image-to-video\", \"image-to-image\"):\n            return self._load_i2v_data()\n        else:\n            raise ValueError(\n                f\"Illegal task name is found in VBenchDataset {self.args.task_name}\"\n            )\n\n    def _download_file(self, url: str, dest_path: str) -> None:\n        \"\"\"Download a file from URL to destination path.\"\"\"\n        os.makedirs(os.path.dirname(dest_path), exist_ok=True)\n        resp = requests.get(url)\n        resp.raise_for_status()\n        with open(dest_path, \"w\") as f:\n            f.write(resp.text)\n\n    def _load_t2v_prompts(self) -> List[Dict[str, Any]]:\n        path = self.args.dataset_path\n\n        if not path:\n            path = os.path.join(self.cache_dir, \"vbench_subject_consistency.txt\")\n            if not os.path.exists(path):\n                logger.info(f\"Downloading VBench T2V prompts to {path}...\")\n                try:\n                    self._download_file(self.T2V_PROMPT_URL, path)\n                except Exception as e:\n                    logger.info(f\"Failed to download VBench prompts: {e}\")\n                    return [{\"prompt\": \"A cat sitting on a bench\"}] * 50\n\n        prompts = []\n        with open(path, \"r\") as f:\n            for line in 
f:\n                line = line.strip()\n                if line:\n                    prompts.append({\"prompt\": line})\n\n        return self._resize_data(prompts)\n\n    def _auto_download_i2v_dataset(self) -> Optional[str]:\n        \"\"\"Auto-download VBench I2V dataset and return the dataset directory.\"\"\"\n        vbench_i2v_dir = os.path.join(self.cache_dir, \"vbench_i2v\", \"vbench2_beta_i2v\")\n        info_json_path = os.path.join(vbench_i2v_dir, \"data\", \"i2v-bench-info.json\")\n        crop_dir = os.path.join(vbench_i2v_dir, \"data\", \"crop\")\n        origin_dir = os.path.join(vbench_i2v_dir, \"data\", \"origin\")\n\n        if (\n            os.path.exists(info_json_path)\n            and is_dir_not_empty(crop_dir)\n            and is_dir_not_empty(origin_dir)\n        ):\n            return vbench_i2v_dir\n\n        logger.info(f\"Downloading VBench I2V dataset to {vbench_i2v_dir}...\")\n        try:\n            cache_root = os.path.join(self.cache_dir, \"vbench_i2v\")\n            script_path = os.path.join(cache_root, \"download_data.sh\")\n\n            self._download_file(self.I2V_DOWNLOAD_SCRIPT_URL, script_path)\n            os.chmod(script_path, 0o755)\n\n            logger.info(\"Executing download_data.sh (this may take a while)...\")\n\n            result = subprocess.run(\n                [\"bash\", script_path],\n                cwd=cache_root,\n                capture_output=True,\n                text=True,\n            )\n            if result.returncode != 0:\n                raise RuntimeError(f\"Download script failed: {result.stderr}\")\n            missing_packages = re.findall(r\"(\\S+): command not found\", result.stderr)\n            if missing_packages:\n                missing_packages = list(set(missing_packages))\n                package_list = \", \".join(f\"'{cmd}'\" for cmd in missing_packages)\n                raise RuntimeError(\n                    f\"Download script failed because the following commands are 
not installed: {package_list}.\\n\"\n                    \"Please install them (e.g., on Ubuntu: `sudo apt install ...`) and try again.\"\n                )\n            logger.info(\n                f\"Successfully downloaded VBench I2V dataset to {vbench_i2v_dir}\"\n            )\n        except Exception as e:\n            logger.info(f\"Failed to download VBench I2V dataset: {e}\")\n            logger.info(\"Please manually download following instructions at:\")\n            logger.info(\n                \"https://github.com/Vchitect/VBench/tree/master/vbench2_beta_i2v#22-download\"\n            )\n            return None\n\n        return vbench_i2v_dir if os.path.exists(info_json_path) else None\n\n    def _load_from_i2v_json(self, json_path: str) -> List[Dict[str, Any]]:\n        \"\"\"Load I2V data from i2v-bench-info.json format.\"\"\"\n        with open(json_path, \"r\") as f:\n            items = json.load(f)\n\n        base_dir = os.path.dirname(\n            os.path.dirname(json_path)\n        )  # Go up to vbench2_beta_i2v\n        origin_dir = os.path.join(base_dir, \"data\", \"origin\")\n\n        data = []\n        for item in items:\n            img_path = os.path.join(origin_dir, item.get(\"file_name\", \"\"))\n            if os.path.exists(img_path):\n                data.append({\"prompt\": item.get(\"caption\", \"\"), \"image_path\": img_path})\n            else:\n                logger.warning(f\"Image not found: {img_path}\")\n\n        logger.info(f\"Loaded {len(data)} I2V samples from VBench I2V dataset\")\n        return data\n\n    def _scan_directory_for_images(self, path: str) -> List[Dict[str, Any]]:\n        \"\"\"Scan directory for image files.\"\"\"\n        exts = [\"*.jpg\", \"*.jpeg\", \"*.png\", \"*.webp\"]\n        files = []\n\n        for ext in exts:\n            files.extend(glob.glob(os.path.join(path, ext)))\n            files.extend(glob.glob(os.path.join(path, ext.upper())))\n\n            origin_dir = 
os.path.join(path, \"data\", \"origin\")\n            if os.path.exists(origin_dir):\n                files.extend(glob.glob(os.path.join(origin_dir, ext)))\n                files.extend(glob.glob(os.path.join(origin_dir, ext.upper())))\n\n        return [\n            {\"prompt\": os.path.splitext(os.path.basename(f))[0], \"image_path\": f}\n            for f in files\n        ]\n\n    def _create_dummy_data(self) -> List[Dict[str, Any]]:\n        \"\"\"Create dummy data with a placeholder image in cache directory.\"\"\"\n        logger.info(\"No I2V data found. Using dummy placeholders.\")\n\n        dummy_image = os.path.join(self.cache_dir, \"dummy_image.jpg\")\n        if not os.path.exists(dummy_image):\n            os.makedirs(self.cache_dir, exist_ok=True)\n            img = Image.new(\"RGB\", (100, 100), color=\"red\")\n            img.save(dummy_image)\n            logger.info(f\"Created dummy image at {dummy_image}\")\n\n        return [{\"prompt\": \"A moving cat\", \"image_path\": dummy_image}] * 10\n\n    def _load_i2v_data(self) -> List[Dict[str, Any]]:\n        \"\"\"Load I2V data from VBench I2V dataset or user-provided path.\"\"\"\n        path = self.args.dataset_path\n        if not path:\n            path = self._auto_download_i2v_dataset()\n            if not path:\n                return self._resize_data(self._create_dummy_data())\n\n        info_json_candidates = [\n            os.path.join(path, \"data\", \"i2v-bench-info.json\"),\n            path if path.endswith(\".json\") else None,\n        ]\n\n        for json_path in info_json_candidates:\n            if json_path and os.path.exists(json_path):\n                try:\n                    return self._resize_data(self._load_from_i2v_json(json_path))\n                except Exception as e:\n                    logger.info(f\"Failed to load {json_path}: {e}\")\n\n        if os.path.isdir(path):\n            data = self._scan_directory_for_images(path)\n            if data:\n            
    return self._resize_data(data)\n\n        return self._resize_data(self._create_dummy_data())\n\n    def _resize_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:\n        \"\"\"Resize data to match num_prompts.\"\"\"\n        if not self.args.num_prompts:\n            return data\n\n        if len(data) < self.args.num_prompts:\n            factor = (self.args.num_prompts // len(data)) + 1\n            data = data * factor\n\n        return data[: self.args.num_prompts]\n\n    def __len__(self) -> int:\n        return len(self.items)\n\n    def __getitem__(self, idx: int) -> RequestFuncInput:\n        item = self.items[idx]\n        return RequestFuncInput(\n            prompt=item.get(\"prompt\", \"\"),\n            api_url=self.api_url,\n            model=self.model,\n            width=self.args.width,\n            height=self.args.height,\n            num_frames=self.args.num_frames,\n            fps=self.args.fps,\n            image_paths=[item[\"image_path\"]] if \"image_path\" in item else None,\n        )\n\n\nclass RandomDataset(BaseDataset):\n    def __init__(self, args, api_url: str = \"\", model: str = \"\"):\n        super().__init__(args, api_url, model)\n        self.num_prompts = args.num_prompts or 100\n\n    def __len__(self) -> int:\n        return self.num_prompts\n\n    def __getitem__(self, idx: int) -> RequestFuncInput:\n        return RequestFuncInput(\n            prompt=f\"Random prompt {idx} for benchmarking diffusion models\",\n            api_url=self.api_url,\n            model=self.model,\n            width=self.args.width,\n            height=self.args.height,\n            num_frames=self.args.num_frames,\n            fps=self.args.fps,\n        )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# Configs for pipelines, and pipeline modules (in models folder)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_448_832.json",
    "content": "{\n    \"temporal_chunk_size\": 2,\n    \"temporal_topk\": 2,\n    \"spatial_chunk_size\": [4, 13],\n    \"spatial_topk\": 6,\n    \"st_chunk_size\": [4, 4, 13],\n    \"st_topk\": 18,\n    \"moba_select_mode\": \"topk\",\n    \"moba_threshold\": 0.25,\n    \"moba_threshold_type\": \"query_head\",\n    \"first_full_layer\": 0,\n    \"first_full_step\": 12,\n    \"temporal_layer\": 1,\n    \"spatial_layer\": 1,\n    \"st_layer\": 1\n}\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_480_832.json",
    "content": "{\n    \"temporal_chunk_size\": 2,\n    \"temporal_topk\": 3,\n    \"spatial_chunk_size\": [3, 4],\n    \"spatial_topk\": 20,\n    \"st_chunk_size\": [4, 6, 4],\n    \"st_topk\": 15,\n    \"moba_select_mode\": \"threshold\",\n    \"moba_threshold\": 0.25,\n    \"moba_threshold_type\": \"query_head\",\n    \"first_full_layer\": 0,\n    \"first_full_step\": 12,\n    \"temporal_layer\": 1,\n    \"spatial_layer\": 1,\n    \"st_layer\": 1\n}\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nfrom sglang.multimodal_gen.configs.models.base import ModelConfig\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTConfig\nfrom sglang.multimodal_gen.configs.models.encoders.base import EncoderConfig\nfrom sglang.multimodal_gen.configs.models.vaes.base import VAEConfig\n\n__all__ = [\"ModelConfig\", \"VAEConfig\", \"DiTConfig\", \"EncoderConfig\"]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/adapter/base.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\nfrom typing import Any\n\nfrom sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\n\n\n@dataclass\nclass AdapterArchConfig(ArchConfig):\n    _fsdp_shard_conditions: list = field(default_factory=list)\n    _compile_conditions: list = field(default_factory=list)\n\n    # convert weights name from HF-format to SGLang-dit-format\n    param_names_mapping: dict = field(default_factory=dict)\n\n    # Reverse mapping for saving checkpoints: custom -> hf\n    reverse_param_names_mapping: dict = field(default_factory=dict)\n    _supported_attention_backends: set[AttentionBackendEnum] = field(\n        default_factory=lambda: {\n            AttentionBackendEnum.SLIDING_TILE_ATTN,\n            AttentionBackendEnum.SAGE_ATTN,\n            AttentionBackendEnum.FA,\n            AttentionBackendEnum.AITER,\n            AttentionBackendEnum.AITER_SAGE,\n            AttentionBackendEnum.TORCH_SDPA,\n            AttentionBackendEnum.VIDEO_SPARSE_ATTN,\n            AttentionBackendEnum.VMOBA_ATTN,\n            AttentionBackendEnum.SAGE_ATTN_3,\n        }\n    )\n\n    hidden_size: int = 0\n    num_attention_heads: int = 0\n    num_channels_latents: int = 0\n    exclude_lora_layers: list[str] = field(default_factory=list)\n    boundary_ratio: float | None = None\n\n    def __post_init__(self) -> None:\n        if not self._compile_conditions:\n            self._compile_conditions = self._fsdp_shard_conditions.copy()\n\n\n@dataclass\nclass AdapterConfig(ModelConfig):\n    arch_config: AdapterArchConfig = field(default_factory=AdapterArchConfig)\n\n    # sglang-diffusion Adapter-specific parameters\n    prefix: str = \"\"\n\n    @staticmethod\n    def add_cli_args(parser: Any, prefix: str = \"dit-config\") -> Any:\n        \"\"\"Add CLI arguments for AdapterConfig fields\"\"\"\n        
parser.add_argument(\n            f\"--{prefix}.prefix\",\n            type=str,\n            dest=f\"{prefix.replace('-', '_')}.prefix\",\n            default=AdapterConfig.prefix,\n            help=\"Prefix for the Adapter\",\n        )\n\n        return parser\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/adapter/ltx_2_connector.py",
    "content": "from dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.adapter.base import (\n    AdapterArchConfig,\n    AdapterConfig,\n)\n\n\n@dataclass\nclass LTX2ConnectorArchConfig(AdapterArchConfig):\n    audio_connector_attention_head_dim: int = 128\n    audio_connector_num_attention_heads: int = 30\n    audio_connector_num_layers: int = 2\n    audio_connector_num_learnable_registers: int = 128\n    caption_channels: int = 3840\n    causal_temporal_positioning: bool = False\n    connector_rope_base_seq_len: int = 4096\n    rope_double_precision: bool = True\n    rope_theta: float = 10000.0\n    rope_type: str = \"split\"\n    text_proj_in_factor: int = 49\n    video_connector_attention_head_dim: int = 128\n    video_connector_num_attention_heads: int = 30\n    video_connector_num_layers: int = 2\n    video_connector_num_learnable_registers: int = 128\n\n\n@dataclass\nclass LTX2ConnectorConfig(AdapterConfig):\n\n    arch_config: AdapterArchConfig = field(default_factory=LTX2ConnectorArchConfig)\n\n    prefix: str = \"LTX2\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/base.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field, fields\nfrom typing import Any, Dict\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n# 1. ArchConfig contains all fields from diffuser's/transformer's config.json (i.e. all fields related to the architecture of the model)\n# 2. ArchConfig should be inherited & overridden by each model arch_config\n# 3. Any field in ArchConfig is fixed upon initialization, and should be hidden away from users\n@dataclass\nclass ArchConfig:\n    stacked_params_mapping: list[tuple[str, str, str]] = field(\n        default_factory=list\n    )  # mapping from huggingface weight names to custom names\n    extra_attrs: Dict[str, Any] = field(default_factory=dict)\n\n    def __getattr__(self, name: str):\n        d = object.__getattribute__(self, \"__dict__\")\n        extras = d.get(\"extra_attrs\")\n        if extras is not None and name in extras:\n            return extras[name]\n        raise AttributeError(\n            f\"'{self.__class__.__name__}' object has no attribute '{name}'\"\n        )\n\n    def __setattr__(self, key, value):\n        if key in type(self).__dataclass_fields__:\n            object.__setattr__(self, key, value)\n        else:\n            d = object.__getattribute__(self, \"__dict__\")\n            extras = d.get(\"extra_attrs\")\n            if extras is None:\n                extras = {}\n                d[\"extra_attrs\"] = extras\n            extras[key] = value\n\n\n@dataclass\nclass ModelConfig:\n    # Every model config parameter can be categorized into either ArchConfig or everything else\n    # Diffuser/Transformer parameters\n    arch_config: ArchConfig = field(default_factory=ArchConfig)\n\n    # sglang-diffusion-specific parameters here\n    # i.e. 
STA, quantization, teacache\n\n    def __getattr__(self, name):\n        # Only called if 'name' is not found in ModelConfig directly\n        if hasattr(self.arch_config, name):\n            return getattr(self.arch_config, name)\n        raise AttributeError(\n            f\"'{type(self).__name__}' object has no attribute '{name}'\"\n        )\n\n    def __getstate__(self):\n        # Return a shallow copy of the instance attributes to pickle;\n        # no attributes are currently excluded\n        state = self.__dict__.copy()\n        return state\n\n    def __setstate__(self, state):\n        # Restore instance attributes from the unpickled state\n        self.__dict__.update(state)\n\n    # This should be used only when loading from transformers/diffusers\n    def update_model_arch(self, source_model_dict: dict[str, Any]) -> None:\n        \"\"\"\n        Update arch_config with source_model_dict\n        \"\"\"\n        arch_config = self.arch_config\n\n        for key, value in source_model_dict.items():\n            setattr(arch_config, key, value)\n\n        if hasattr(arch_config, \"__post_init__\"):\n            arch_config.__post_init__()\n\n    def update_model_config(self, source_model_dict: dict[str, Any]) -> None:\n        assert (\n            \"arch_config\" not in source_model_dict\n        ), \"Source model config shouldn't contain arch_config.\"\n\n        valid_fields = {f.name for f in fields(self)}\n\n        for key, value in source_model_dict.items():\n            if key in valid_fields:\n                setattr(self, key, value)\n            else:\n                logger.warning(\n                    \"%s does not contain field '%s'!\", type(self).__name__, key\n                )\n                raise AttributeError(f\"Invalid field: {key}\")\n\n        if hasattr(self, \"__post_init__\"):\n            self.__post_init__()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/bridges/__init__.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nfrom sglang.multimodal_gen.configs.models.bridges.mova_dual_tower import (\n    MOVADualTowerConfig,\n)\n\n__all__ = [\"MOVADualTowerConfig\"]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/bridges/mova_dual_tower.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"Configuration for MOVA dual tower bridge model.\"\"\"\n\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig\n\n\ndef _is_conditioner_block(name: str, module) -> bool:\n    \"\"\"Check if module is a ConditionalCrossAttentionBlock.\"\"\"\n    return \"ConditionalCrossAttentionBlock\" in type(module).__name__\n\n\n@dataclass\nclass MOVADualTowerArchConfig(DiTArchConfig):\n    _fsdp_shard_conditions: list = field(\n        default_factory=lambda: [_is_conditioner_block]\n    )\n\n    # Model architecture parameters\n    visual_layers: int = 40\n    audio_layers: int = 30\n    visual_hidden_dim: int = 5120\n    audio_hidden_dim: int = 1536\n    audio_fps: float = 50.0\n    head_dim: int = 128\n    interaction_strategy: str = \"full\"\n    apply_cross_rope: bool = True\n    apply_first_frame_bias_in_rope: bool = False\n    trainable_condition_scale: bool = False\n    pooled_adaln: bool = False\n    eps: float = 1e-6\n\n    def __post_init__(self):\n        super().__post_init__()\n        self.hidden_size = self.visual_hidden_dim\n        self.num_attention_heads = self.visual_hidden_dim // self.head_dim\n\n\n@dataclass\nclass MOVADualTowerConfig(DiTConfig):\n    arch_config: DiTArchConfig = field(default_factory=MOVADualTowerArchConfig)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nfrom sglang.multimodal_gen.configs.models.dits.helios import HeliosConfig\nfrom sglang.multimodal_gen.configs.models.dits.hunyuan3d import Hunyuan3DDiTConfig\nfrom sglang.multimodal_gen.configs.models.dits.hunyuanvideo import HunyuanVideoConfig\nfrom sglang.multimodal_gen.configs.models.dits.mova_audio import MOVAAudioConfig\nfrom sglang.multimodal_gen.configs.models.dits.mova_video import MOVAVideoConfig\nfrom sglang.multimodal_gen.configs.models.dits.wanvideo import WanVideoConfig\n\n__all__ = [\n    \"HeliosConfig\",\n    \"HunyuanVideoConfig\",\n    \"WanVideoConfig\",\n    \"Hunyuan3DDiTConfig\",\n    \"MOVAAudioConfig\",\n    \"MOVAVideoConfig\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/base.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\nfrom typing import Any\n\nfrom sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig\nfrom sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\n\n\n@dataclass\nclass DiTArchConfig(ArchConfig):\n    _fsdp_shard_conditions: list = field(default_factory=list)\n    _compile_conditions: list = field(default_factory=list)\n\n    # convert weights name from HF-format to SGLang-dit-format\n    param_names_mapping: dict = field(default_factory=dict)\n\n    # convert weights name from misc-format to HF-format\n    # usually applicable if the LoRA is trained with official repo implementation\n    lora_param_names_mapping: dict = field(default_factory=dict)\n\n    # Reverse mapping for saving checkpoints: custom -> hf\n    reverse_param_names_mapping: dict = field(default_factory=dict)\n    _supported_attention_backends: set[AttentionBackendEnum] = field(\n        default_factory=lambda: {\n            AttentionBackendEnum.SLIDING_TILE_ATTN,\n            AttentionBackendEnum.SAGE_ATTN,\n            AttentionBackendEnum.FA,\n            AttentionBackendEnum.AITER,\n            AttentionBackendEnum.AITER_SAGE,\n            AttentionBackendEnum.TORCH_SDPA,\n            AttentionBackendEnum.VIDEO_SPARSE_ATTN,\n            AttentionBackendEnum.SPARSE_VIDEO_GEN_2_ATTN,\n            AttentionBackendEnum.VMOBA_ATTN,\n            AttentionBackendEnum.SAGE_ATTN_3,\n        }\n    )\n\n    hidden_size: int = 0\n    num_attention_heads: int = 0\n    num_channels_latents: int = 0\n    exclude_lora_layers: list[str] = field(default_factory=list)\n    boundary_ratio: float | None = None\n\n    def __post_init__(self) -> None:\n        if not self._compile_conditions:\n            self._compile_conditions = 
self._fsdp_shard_conditions.copy()\n\n\n@dataclass\nclass DiTConfig(ModelConfig):\n    arch_config: DiTArchConfig = field(default_factory=DiTArchConfig)\n\n    # sglang-diffusion DiT-specific parameters\n    prefix: str = \"\"\n    quant_config: QuantizationConfig | None = None\n\n    @staticmethod\n    def add_cli_args(parser: Any, prefix: str = \"dit-config\") -> Any:\n        \"\"\"Add CLI arguments for DiTConfig fields\"\"\"\n        parser.add_argument(\n            f\"--{prefix}.prefix\",\n            type=str,\n            dest=f\"{prefix.replace('-', '_')}.prefix\",\n            default=DiTConfig.prefix,\n            help=\"Prefix for the DiT model\",\n        )\n\n        parser.add_argument(\n            f\"--{prefix}.quant-config\",\n            type=str,\n            dest=f\"{prefix.replace('-', '_')}.quant_config\",\n            default=None,\n            help=\"Quantization configuration for the DiT model\",\n        )\n\n        return parser\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/flux.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\nfrom typing import Tuple\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig\n\n\n@dataclass\nclass FluxArchConfig(DiTArchConfig):\n    patch_size: int = 1\n    in_channels: int = 64\n    out_channels: int | None = None\n    num_layers: int = 19\n    num_single_layers: int = 38\n    attention_head_dim: int = 128\n    num_attention_heads: int = 24\n    joint_attention_dim: int = 4096\n    pooled_projection_dim: int = 768\n    guidance_embeds: bool = False\n    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)\n\n    stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list)\n\n    exclude_lora_layers: list[str] = field(\n        default_factory=lambda: [\n            \"time_guidance_embed.timestep_embedder.linear_1\",\n            \"time_guidance_embed.timestep_embedder.linear_2\",\n            \"time_guidance_embed.guidance_embedder.linear_1\",\n            \"time_guidance_embed.guidance_embedder.linear_2\",\n        ]\n    )\n\n    # nunchaku checkpoint uses different weight names; map to sglang flux layout\n    param_names_mapping: dict = field(\n        default_factory=lambda: {\n            # HF diffusers format\n            r\"^transformer\\.(\\w*)\\.(.*)$\": r\"\\1.\\2\",\n            # transformer_blocks nunchaku format (raw export - before internal conversion)\n            r\"^transformer_blocks\\.(\\d+)\\.mlp_fc1\\.(.*)$\": r\"transformer_blocks.\\1.ff.net.0.proj.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.mlp_fc2\\.(.*)$\": r\"transformer_blocks.\\1.ff.net.2.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.mlp_context_fc1\\.(.*)$\": r\"transformer_blocks.\\1.ff_context.net.0.proj.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.mlp_context_fc2\\.(.*)$\": r\"transformer_blocks.\\1.ff_context.net.2.\\2\",\n          
  r\"^transformer_blocks\\.(\\d+)\\.qkv_proj\\.(.*)$\": r\"transformer_blocks.\\1.attn.to_qkv.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.qkv_proj_context\\.(.*)$\": r\"transformer_blocks.\\1.attn.to_added_qkv.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.out_proj\\.(.*)$\": r\"transformer_blocks.\\1.attn.to_out.0.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.out_proj_context\\.(.*)$\": r\"transformer_blocks.\\1.attn.to_add_out.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.norm_q\\.(.*)$\": r\"transformer_blocks.\\1.attn.norm_q.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.norm_k\\.(.*)$\": r\"transformer_blocks.\\1.attn.norm_k.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.norm_added_q\\.(.*)$\": r\"transformer_blocks.\\1.attn.norm_added_q.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.norm_added_k\\.(.*)$\": r\"transformer_blocks.\\1.attn.norm_added_k.\\2\",\n            # transformer_blocks nunchaku format (already converted with convert_flux_state_dict)\n            r\"^transformer_blocks\\.(\\d+)\\.attn\\.add_qkv_proj\\.(.*)$\": r\"transformer_blocks.\\1.attn.to_added_qkv.\\2\",\n            # single_transformer_blocks nunchaku format (raw export - before internal conversion)\n            r\"^single_transformer_blocks\\.(\\d+)\\.qkv_proj\\.(.*)$\": r\"single_transformer_blocks.\\1.attn.to_qkv.\\2\",\n            r\"^single_transformer_blocks\\.(\\d+)\\.out_proj\\.(.*)$\": r\"single_transformer_blocks.\\1.attn.to_out.0.\\2\",\n            r\"^single_transformer_blocks\\.(\\d+)\\.norm_q\\.(.*)$\": r\"single_transformer_blocks.\\1.attn.norm_q.\\2\",\n            r\"^single_transformer_blocks\\.(\\d+)\\.norm_k\\.(.*)$\": r\"single_transformer_blocks.\\1.attn.norm_k.\\2\",\n            # nunchaku quantization parameter name conversions (apply to all blocks)\n            r\"^(.*)\\.smooth_orig$\": r\"\\1.smooth_factor_orig\",\n            r\"^(.*)\\.smooth$\": r\"\\1.smooth_factor\",\n            
r\"^(.*)\\.lora_down$\": r\"\\1.proj_down\",\n            r\"^(.*)\\.lora_up$\": r\"\\1.proj_up\",\n        }\n    )\n\n    def __post_init__(self):\n        super().__post_init__()\n        self.out_channels = self.out_channels or self.in_channels\n        self.hidden_size = self.num_attention_heads * self.attention_head_dim\n        self.num_channels_latents = self.out_channels\n\n\n@dataclass\nclass FluxConfig(DiTConfig):\n\n    arch_config: DiTArchConfig = field(default_factory=FluxArchConfig)\n\n    prefix: str = \"Flux\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/glmimage.py",
    "content": "from dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig\n\n\n@dataclass\nclass GlmImageArchConfig(DiTArchConfig):\n    patch_size: int = 2\n    in_channels: int = 16\n    out_channels: int | None = 16\n    num_layers: int = 30\n    attention_head_dim: int = 128\n    num_attention_heads: int = 32\n    condition_dim: int = 256\n    prior_vq_quantizer_codebook_size: int = 16384\n    text_embed_dim: int = 1472\n    time_embed_dim: int = 512\n\n    stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list)\n\n    param_names_mapping: dict = field(\n        default_factory=lambda: {\n            # LoRA mappings\n            r\"^(transformer_blocks\\.\\d+\\.attn\\..*\\.lora_[AB])\\.default$\": r\"\\1\",\n        }\n    )\n\n    def __post_init__(self):\n        super().__post_init__()\n        self.out_channels = self.out_channels or self.in_channels\n        self.hidden_size = self.num_attention_heads * self.attention_head_dim\n        self.num_channels_latents = self.out_channels\n\n\n@dataclass\nclass GlmImageDitConfig(DiTConfig):\n    arch_config: DiTArchConfig = field(default_factory=GlmImageArchConfig)\n\n    prefix: str = \"glmimage\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/helios.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig\n\n\ndef is_blocks(n: str, m) -> bool:\n    return \"blocks\" in n and str.isdigit(n.split(\".\")[-1])\n\n\n@dataclass\nclass HeliosArchConfig(DiTArchConfig):\n    _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks])\n\n    param_names_mapping: dict = field(\n        default_factory=lambda: {\n            # Patch embeddings\n            r\"^patch_embedding\\.(.*)$\": r\"patch_embedding.proj.\\1\",\n            # Condition embedder: text\n            r\"^condition_embedder\\.text_embedder\\.linear_1\\.(.*)$\": r\"condition_embedder.text_embedder.fc_in.\\1\",\n            r\"^condition_embedder\\.text_embedder\\.linear_2\\.(.*)$\": r\"condition_embedder.text_embedder.fc_out.\\1\",\n            # Condition embedder: time\n            r\"^condition_embedder\\.time_embedder\\.linear_1\\.(.*)$\": r\"condition_embedder.time_embedder.mlp.fc_in.\\1\",\n            r\"^condition_embedder\\.time_embedder\\.linear_2\\.(.*)$\": r\"condition_embedder.time_embedder.mlp.fc_out.\\1\",\n            r\"^condition_embedder\\.time_proj\\.(.*)$\": r\"condition_embedder.time_modulation.linear.\\1\",\n            # Blocks: self-attention (keep attn1. prefix, drop .0. from to_out)\n            r\"^blocks\\.(\\d+)\\.attn1\\.to_out\\.0\\.(.*)$\": r\"blocks.\\1.attn1.to_out.\\2\",\n            # Blocks: cross-attention output (drop .0. 
from to_out)\n            r\"^blocks\\.(\\d+)\\.attn2\\.to_out\\.0\\.(.*)$\": r\"blocks.\\1.attn2.to_out.\\2\",\n            # Blocks: feed-forward\n            r\"^blocks\\.(\\d+)\\.ffn\\.net\\.0\\.proj\\.(.*)$\": r\"blocks.\\1.ffn.fc_in.\\2\",\n            r\"^blocks\\.(\\d+)\\.ffn\\.net\\.2\\.(.*)$\": r\"blocks.\\1.ffn.fc_out.\\2\",\n            # Blocks: cross-attn residual norm\n            r\"^blocks\\.(\\d+)\\.norm2\\.(.*)$\": r\"blocks.\\1.self_attn_residual_norm.\\2\",\n        }\n    )\n\n    reverse_param_names_mapping: dict = field(default_factory=lambda: {})\n\n    lora_param_names_mapping: dict = field(default_factory=lambda: {})\n\n    patch_size: tuple[int, int, int] = (1, 2, 2)\n    text_len: int = 226\n    num_attention_heads: int = 40\n    attention_head_dim: int = 128\n    in_channels: int = 16\n    out_channels: int = 16\n    text_dim: int = 4096\n    freq_dim: int = 256\n    ffn_dim: int = 13824\n    num_layers: int = 40\n    cross_attn_norm: bool = True\n    qk_norm: str = \"rms_norm_across_heads\"\n    eps: float = 1e-6\n    added_kv_proj_dim: int | None = None\n    rope_max_seq_len: int = 1024\n    pos_embed_seq_len: int | None = None\n    exclude_lora_layers: list[str] = field(default_factory=lambda: [\"embedder\"])\n\n    # Helios-specific\n    rope_dim: tuple[int, int, int] = (44, 42, 42)\n    rope_theta: float = 10000.0\n    guidance_cross_attn: bool = True\n    zero_history_timestep: bool = True\n    has_multi_term_memory_patch: bool = True\n    is_amplify_history: bool = False\n    history_scale_mode: str = \"per_head\"\n\n    def __post_init__(self):\n        super().__post_init__()\n        self.out_channels = self.out_channels or self.in_channels\n        self.hidden_size = self.num_attention_heads * self.attention_head_dim\n        self.num_channels_latents = self.out_channels\n\n\n@dataclass\nclass HeliosConfig(DiTConfig):\n    arch_config: DiTArchConfig = field(default_factory=HeliosArchConfig)\n\n    prefix: str = \"Helios\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/hunyuan3d.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig\n\n\n@dataclass\nclass Hunyuan3DDiTArchConfig(DiTArchConfig):\n    \"\"\"Architecture config for Hunyuan3D DiT (Flux-style for Hunyuan3D-2.0).\"\"\"\n\n    param_names_mapping: dict = field(\n        default_factory=lambda: {\n            r\"(.*)\\.img_mlp\\.0\\.(.*)$\": r\"\\1.img_mlp.fc_in.\\2\",\n            r\"(.*)\\.img_mlp\\.2\\.(.*)$\": r\"\\1.img_mlp.fc_out.\\2\",\n            r\"(.*)\\.txt_mlp\\.0\\.(.*)$\": r\"\\1.txt_mlp.fc_in.\\2\",\n            r\"(.*)\\.txt_mlp\\.2\\.(.*)$\": r\"\\1.txt_mlp.fc_out.\\2\",\n        }\n    )\n\n    in_channels: int = 64\n    hidden_size: int = 1024\n    num_attention_heads: int = 16\n    num_layers: int = 16\n    num_single_layers: int = 32\n    mlp_ratio: float = 4.0\n    context_in_dim: int = 1536\n    axes_dim: tuple[int, ...] = (64,)\n    theta: int = 10000\n    qkv_bias: bool = True\n    guidance_embed: bool = False\n    time_factor: float = 1000.0\n\n    def __post_init__(self) -> None:\n        if self.num_channels_latents == 0:\n            self.num_channels_latents = self.in_channels\n        super().__post_init__()\n\n\n@dataclass\nclass Hunyuan3DDiTConfig(DiTConfig):\n    \"\"\"DiT configuration for Hunyuan3D shape generation (Flux-style).\"\"\"\n\n    arch_config: Hunyuan3DDiTArchConfig = field(default_factory=Hunyuan3DDiTArchConfig)\n    subfolder: str = \"hunyuan3d-dit-v2-0\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/hunyuanvideo.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig\n\n\ndef is_double_block(n: str, m) -> bool:\n    return \"double\" in n and str.isdigit(n.split(\".\")[-1])\n\n\ndef is_single_block(n: str, m) -> bool:\n    return \"single\" in n and str.isdigit(n.split(\".\")[-1])\n\n\ndef is_refiner_block(n: str, m) -> bool:\n    return \"refiner\" in n and str.isdigit(n.split(\".\")[-1])\n\n\ndef is_txt_in(n: str, m) -> bool:\n    return n.split(\".\")[-1] == \"txt_in\"\n\n\n@dataclass\nclass HunyuanVideoArchConfig(DiTArchConfig):\n    _fsdp_shard_conditions: list = field(\n        default_factory=lambda: [is_double_block, is_single_block, is_refiner_block]\n    )\n\n    _compile_conditions: list = field(\n        default_factory=lambda: [is_double_block, is_single_block, is_txt_in]\n    )\n\n    param_names_mapping: dict = field(\n        default_factory=lambda: {\n            # 1. 
context_embedder.time_text_embed submodules (specific rules, applied first):\n            r\"^context_embedder\\.time_text_embed\\.timestep_embedder\\.linear_1\\.(.*)$\": r\"txt_in.t_embedder.mlp.fc_in.\\1\",\n            r\"^context_embedder\\.time_text_embed\\.timestep_embedder\\.linear_2\\.(.*)$\": r\"txt_in.t_embedder.mlp.fc_out.\\1\",\n            r\"^context_embedder\\.proj_in\\.(.*)$\": r\"txt_in.input_embedder.\\1\",\n            r\"^context_embedder\\.time_text_embed\\.text_embedder\\.linear_1\\.(.*)$\": r\"txt_in.c_embedder.fc_in.\\1\",\n            r\"^context_embedder\\.time_text_embed\\.text_embedder\\.linear_2\\.(.*)$\": r\"txt_in.c_embedder.fc_out.\\1\",\n            r\"^context_embedder\\.token_refiner\\.refiner_blocks\\.(\\d+)\\.norm1\\.(.*)$\": r\"txt_in.refiner_blocks.\\1.norm1.\\2\",\n            r\"^context_embedder\\.token_refiner\\.refiner_blocks\\.(\\d+)\\.norm2\\.(.*)$\": r\"txt_in.refiner_blocks.\\1.norm2.\\2\",\n            r\"^context_embedder\\.token_refiner\\.refiner_blocks\\.(\\d+)\\.attn\\.to_q\\.(.*)$\": (\n                r\"txt_in.refiner_blocks.\\1.self_attn_qkv.\\2\",\n                0,\n                3,\n            ),\n            r\"^context_embedder\\.token_refiner\\.refiner_blocks\\.(\\d+)\\.attn\\.to_k\\.(.*)$\": (\n                r\"txt_in.refiner_blocks.\\1.self_attn_qkv.\\2\",\n                1,\n                3,\n            ),\n            r\"^context_embedder\\.token_refiner\\.refiner_blocks\\.(\\d+)\\.attn\\.to_v\\.(.*)$\": (\n                r\"txt_in.refiner_blocks.\\1.self_attn_qkv.\\2\",\n                2,\n                3,\n            ),\n            r\"^context_embedder\\.token_refiner\\.refiner_blocks\\.(\\d+)\\.attn\\.to_out\\.0\\.(.*)$\": r\"txt_in.refiner_blocks.\\1.self_attn_proj.\\2\",\n            r\"^context_embedder\\.token_refiner\\.refiner_blocks\\.(\\d+)\\.ff\\.net\\.0(?:\\.proj)?\\.(.*)$\": r\"txt_in.refiner_blocks.\\1.mlp.fc_in.\\2\",\n            
r\"^context_embedder\\.token_refiner\\.refiner_blocks\\.(\\d+)\\.ff\\.net\\.2(?:\\.proj)?\\.(.*)$\": r\"txt_in.refiner_blocks.\\1.mlp.fc_out.\\2\",\n            r\"^context_embedder\\.token_refiner\\.refiner_blocks\\.(\\d+)\\.norm_out\\.linear\\.(.*)$\": r\"txt_in.refiner_blocks.\\1.adaLN_modulation.linear.\\2\",\n            # 3. x_embedder mapping:\n            r\"^x_embedder\\.proj\\.(.*)$\": r\"img_in.proj.\\1\",\n            # 4. Top-level time_text_embed mappings:\n            r\"^time_text_embed\\.timestep_embedder\\.linear_1\\.(.*)$\": r\"time_in.mlp.fc_in.\\1\",\n            r\"^time_text_embed\\.timestep_embedder\\.linear_2\\.(.*)$\": r\"time_in.mlp.fc_out.\\1\",\n            r\"^time_text_embed\\.guidance_embedder\\.linear_1\\.(.*)$\": r\"guidance_in.mlp.fc_in.\\1\",\n            r\"^time_text_embed\\.guidance_embedder\\.linear_2\\.(.*)$\": r\"guidance_in.mlp.fc_out.\\1\",\n            r\"^time_text_embed\\.text_embedder\\.linear_1\\.(.*)$\": r\"vector_in.fc_in.\\1\",\n            r\"^time_text_embed\\.text_embedder\\.linear_2\\.(.*)$\": r\"vector_in.fc_out.\\1\",\n            # 5. 
transformer_blocks mapping:\n            r\"^transformer_blocks\\.(\\d+)\\.norm1\\.linear\\.(.*)$\": r\"double_blocks.\\1.img_mod.linear.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.norm1_context\\.linear\\.(.*)$\": r\"double_blocks.\\1.txt_mod.linear.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.attn\\.norm_q\\.(.*)$\": r\"double_blocks.\\1.img_attn_q_norm.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.attn\\.norm_k\\.(.*)$\": r\"double_blocks.\\1.img_attn_k_norm.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.attn\\.to_q\\.(.*)$\": (\n                r\"double_blocks.\\1.img_attn_qkv.\\2\",\n                0,\n                3,\n            ),\n            r\"^transformer_blocks\\.(\\d+)\\.attn\\.to_k\\.(.*)$\": (\n                r\"double_blocks.\\1.img_attn_qkv.\\2\",\n                1,\n                3,\n            ),\n            r\"^transformer_blocks\\.(\\d+)\\.attn\\.to_v\\.(.*)$\": (\n                r\"double_blocks.\\1.img_attn_qkv.\\2\",\n                2,\n                3,\n            ),\n            r\"^transformer_blocks\\.(\\d+)\\.attn\\.add_q_proj\\.(.*)$\": (\n                r\"double_blocks.\\1.txt_attn_qkv.\\2\",\n                0,\n                3,\n            ),\n            r\"^transformer_blocks\\.(\\d+)\\.attn\\.add_k_proj\\.(.*)$\": (\n                r\"double_blocks.\\1.txt_attn_qkv.\\2\",\n                1,\n                3,\n            ),\n            r\"^transformer_blocks\\.(\\d+)\\.attn\\.add_v_proj\\.(.*)$\": (\n                r\"double_blocks.\\1.txt_attn_qkv.\\2\",\n                2,\n                3,\n            ),\n            r\"^transformer_blocks\\.(\\d+)\\.attn\\.to_out\\.0\\.(.*)$\": r\"double_blocks.\\1.img_attn_proj.\\2\",\n            # attn.to_add_out is the text-stream output projection; it maps to txt_attn_proj.\n            r\"^transformer_blocks\\.(\\d+)\\.attn\\.to_add_out\\.(.*)$\": r\"double_blocks.\\1.txt_attn_proj.\\2\",\n            
r\"^transformer_blocks\\.(\\d+)\\.attn\\.norm_added_q\\.(.*)$\": r\"double_blocks.\\1.txt_attn_q_norm.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.attn\\.norm_added_k\\.(.*)$\": r\"double_blocks.\\1.txt_attn_k_norm.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.ff\\.net\\.0(?:\\.proj)?\\.(.*)$\": r\"double_blocks.\\1.img_mlp.fc_in.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.ff\\.net\\.2(?:\\.proj)?\\.(.*)$\": r\"double_blocks.\\1.img_mlp.fc_out.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.ff_context\\.net\\.0(?:\\.proj)?\\.(.*)$\": r\"double_blocks.\\1.txt_mlp.fc_in.\\2\",\n            r\"^transformer_blocks\\.(\\d+)\\.ff_context\\.net\\.2(?:\\.proj)?\\.(.*)$\": r\"double_blocks.\\1.txt_mlp.fc_out.\\2\",\n            # 6. single_transformer_blocks mapping:\n            r\"^single_transformer_blocks\\.(\\d+)\\.attn\\.norm_q\\.(.*)$\": r\"single_blocks.\\1.q_norm.\\2\",\n            r\"^single_transformer_blocks\\.(\\d+)\\.attn\\.norm_k\\.(.*)$\": r\"single_blocks.\\1.k_norm.\\2\",\n            r\"^single_transformer_blocks\\.(\\d+)\\.attn\\.to_q\\.(.*)$\": (\n                r\"single_blocks.\\1.linear1.\\2\",\n                0,\n                4,\n            ),\n            r\"^single_transformer_blocks\\.(\\d+)\\.attn\\.to_k\\.(.*)$\": (\n                r\"single_blocks.\\1.linear1.\\2\",\n                1,\n                4,\n            ),\n            r\"^single_transformer_blocks\\.(\\d+)\\.attn\\.to_v\\.(.*)$\": (\n                r\"single_blocks.\\1.linear1.\\2\",\n                2,\n                4,\n            ),\n            r\"^single_transformer_blocks\\.(\\d+)\\.proj_mlp\\.(.*)$\": (\n                r\"single_blocks.\\1.linear1.\\2\",\n                3,\n                4,\n            ),\n            # proj_out projects the concatenated attention and MLP outputs back to the hidden size (linear2).\n            r\"^single_transformer_blocks\\.(\\d+)\\.proj_out\\.(.*)$\": r\"single_blocks.\\1.linear2.\\2\",\n          
  r\"^single_transformer_blocks\\.(\\d+)\\.norm\\.linear\\.(.*)$\": r\"single_blocks.\\1.modulation.linear.\\2\",\n            # 7. Final layers mapping:\n            r\"^norm_out\\.linear\\.(.*)$\": r\"final_layer.adaLN_modulation.linear.\\1\",\n            r\"^proj_out\\.(.*)$\": r\"final_layer.linear.\\1\",\n        }\n    )\n\n    reverse_param_names_mapping: dict = field(default_factory=lambda: {})\n\n    patch_size: int = 2\n    patch_size_t: int = 1\n    in_channels: int = 16\n    out_channels: int = 16\n    num_attention_heads: int = 24\n    attention_head_dim: int = 128\n    mlp_ratio: float = 4.0\n    num_layers: int = 20\n    num_single_layers: int = 40\n    num_refiner_layers: int = 2\n    rope_axes_dim: tuple[int, int, int] = (16, 56, 56)\n    guidance_embeds: bool = False\n    dtype: torch.dtype | None = None\n    text_embed_dim: int = 4096\n    pooled_projection_dim: int = 768\n    rope_theta: int = 256\n    qk_norm: str = \"rms_norm\"\n    exclude_lora_layers: list[str] = field(\n        default_factory=lambda: [\"img_in\", \"txt_in\", \"time_in\", \"vector_in\"]\n    )\n\n    def __post_init__(self):\n        super().__post_init__()\n        self.hidden_size: int = self.attention_head_dim * self.num_attention_heads\n        self.num_channels_latents: int = self.in_channels\n\n\n@dataclass\nclass HunyuanVideoConfig(DiTConfig):\n    arch_config: DiTArchConfig = field(default_factory=HunyuanVideoArchConfig)\n\n    prefix: str = \"Hunyuan\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/ltx_2.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\nfrom enum import Enum\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig\n\n\nclass LTXModelType(Enum):\n    \"\"\"\n    Model type enum mirroring upstream `LTXModelType`.\n\n    Upstream reference:\n      - `LTX-2/packages/ltx-core/src/ltx_core/model/transformer/model.py::LTXModelType`\n    \"\"\"\n\n    AudioVideo = \"ltx av model\"\n    VideoOnly = \"ltx video only model\"\n    AudioOnly = \"ltx audio only model\"\n\n    def is_video_enabled(self) -> bool:\n        return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly)\n\n    def is_audio_enabled(self) -> bool:\n        return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly)\n\n\nclass LTX2RopeType(str, Enum):\n    \"\"\"\n    Minimal RoPE type enum mirroring LTX-2 upstream `LTXRopeType`.\n\n    Upstream reference:\n      - `LTX-2/packages/ltx-core/src/ltx_core/model/transformer/rope.py::LTXRopeType`\n    \"\"\"\n\n    INTERLEAVED = \"interleaved\"\n    SPLIT = \"split\"\n\n\nclass LTX2AttentionFunction(str, Enum):\n    \"\"\"\n    Placeholder enum for upstream `AttentionFunction.DEFAULT`.\n\n    Upstream reference:\n      - `LTX-2/packages/ltx-core/src/ltx_core/model/transformer/attention.py`\n    \"\"\"\n\n    DEFAULT = \"default\"\n\n\ndef is_blocks(n: str, m) -> bool:\n    return \"blocks\" in n and str.isdigit(n.split(\".\")[-1])\n\n\n@dataclass\nclass LTX2ArchConfig(DiTArchConfig):\n    \"\"\"Architecture configuration for LTX-2 Video Transformer.\"\"\"\n\n    _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks])\n\n    param_names_mapping: dict = field(\n        default_factory=lambda: {\n            # Parameter name mappings from HuggingFace checkpoint keys to SGLang module names.\n            # We use upstream variable names (patchify_proj, adaln_single) but HF uses different keys.\n            #\n            # HF key -> SGLang key 
(upstream naming)\n            r\"^proj_in\\.(.*)$\": r\"patchify_proj.\\1\",\n            r\"^time_embed\\.(.*)$\": r\"adaln_single.\\1\",\n            r\"^audio_proj_in\\.(.*)$\": r\"audio_patchify_proj.\\1\",\n            r\"^audio_time_embed\\.(.*)$\": r\"audio_adaln_single.\\1\",\n            # FeedForward\n            r\"(.*)ff\\.net\\.0\\.proj\\.(.*)$\": r\"\\1ff.proj_in.\\2\",\n            r\"(.*)ff\\.net\\.2\\.(.*)$\": r\"\\1ff.proj_out.\\2\",\n            # Attention Norms\n            r\"(.*)\\.norm_q\\.(.*)$\": r\"\\1.q_norm.\\2\",\n            r\"(.*)\\.norm_k\\.(.*)$\": r\"\\1.k_norm.\\2\",\n            # Scale Shift Tables (Global)\n            r\"^av_cross_attn_video_scale_shift\\.(.*)$\": r\"av_ca_video_scale_shift_adaln_single.\\1\",\n            r\"^av_cross_attn_audio_scale_shift\\.(.*)$\": r\"av_ca_audio_scale_shift_adaln_single.\\1\",\n            r\"^av_cross_attn_video_a2v_gate\\.(.*)$\": r\"av_ca_a2v_gate_adaln_single.\\1\",\n            r\"^av_cross_attn_audio_v2a_gate\\.(.*)$\": r\"av_ca_v2a_gate_adaln_single.\\1\",\n            # Scale Shift Tables (Block Level)\n            # HF: scale_shift_table_a2v_ca_video -> SGLang: video_a2v_cross_attn_scale_shift_table\n            r\"(.*)scale_shift_table_a2v_ca_video\": r\"\\1video_a2v_cross_attn_scale_shift_table\",\n            r\"(.*)scale_shift_table_a2v_ca_audio\": r\"\\1audio_a2v_cross_attn_scale_shift_table\",\n        }\n    )\n\n    reverse_param_names_mapping: dict = field(\n        default_factory=lambda: {\n            # Reverse mapping: SGLang module names -> HF checkpoint keys (for saving).\n            r\"^patchify_proj\\.(.*)$\": r\"proj_in.\\1\",\n            r\"^adaln_single\\.(.*)$\": r\"time_embed.\\1\",\n            r\"^audio_patchify_proj\\.(.*)$\": r\"audio_proj_in.\\1\",\n            r\"^audio_adaln_single\\.(.*)$\": r\"audio_time_embed.\\1\",\n            # FeedForward\n            r\"(.*)ff\\.proj_in\\.(.*)$\": r\"\\1ff.net.0.proj.\\2\",\n            
r\"(.*)ff\\.proj_out\\.(.*)$\": r\"\\1ff.net.2.\\2\",\n            # Attention Norms\n            r\"(.*)\\.q_norm\\.(.*)$\": r\"\\1.norm_q.\\2\",\n            r\"(.*)\\.k_norm\\.(.*)$\": r\"\\1.norm_k.\\2\",\n            # Scale Shift Tables (Global)\n            r\"^av_ca_video_scale_shift_adaln_single\\.(.*)$\": r\"av_cross_attn_video_scale_shift.\\1\",\n            r\"^av_ca_audio_scale_shift_adaln_single\\.(.*)$\": r\"av_cross_attn_audio_scale_shift.\\1\",\n            r\"^av_ca_a2v_gate_adaln_single\\.(.*)$\": r\"av_cross_attn_video_a2v_gate.\\1\",\n            r\"^av_ca_v2a_gate_adaln_single\\.(.*)$\": r\"av_cross_attn_audio_v2a_gate.\\1\",\n            # Scale Shift Tables (Block Level)\n            # SGLang: video_a2v_cross_attn_scale_shift_table -> HF: scale_shift_table_a2v_ca_video\n            r\"(.*)video_a2v_cross_attn_scale_shift_table\": r\"\\1scale_shift_table_a2v_ca_video\",\n            r\"(.*)audio_a2v_cross_attn_scale_shift_table\": r\"\\1scale_shift_table_a2v_ca_audio\",\n        }\n    )\n\n    lora_param_names_mapping: dict = field(\n        default_factory=lambda: {\n            # LoRA parameter name mappings from official repo format to HF format.\n            # This is applied before param_names_mapping when loading LoRA adapters.\n            # Will be populated if LoRA adapters use different naming conventions.\n        }\n    )\n\n    # Model type and attention configuration\n    model_type: LTXModelType = LTXModelType.AudioVideo\n    attention_type: LTX2AttentionFunction = LTX2AttentionFunction.DEFAULT\n    rope_type: LTX2RopeType = LTX2RopeType.INTERLEAVED\n    double_precision_rope: bool = False\n\n    # Video parameters\n    num_attention_heads: int = 32\n    attention_head_dim: int = 128\n    in_channels: int = 128\n    out_channels: int = 128\n    num_layers: int = 48\n    cross_attention_dim: int = 4096\n    norm_eps: float = 1e-6\n    caption_channels: int = 3840\n    positional_embedding_theta: float = 10000.0\n    
positional_embedding_max_pos: list[int] | None = None\n    timestep_scale_multiplier: int = 1000\n    use_middle_indices_grid: bool = True\n\n    # Audio parameters\n    audio_num_attention_heads: int = 32\n    audio_attention_head_dim: int = 64\n    audio_in_channels: int = 128\n    audio_out_channels: int = 128\n    audio_cross_attention_dim: int = 2048\n    audio_positional_embedding_max_pos: list[int] | None = None\n    av_ca_timestep_scale_multiplier: int = 1\n\n    # SGLang-specific parameters\n    patch_size: tuple[int, int, int] = (1, 2, 2)\n    text_len: int = 512\n\n    def __post_init__(self):\n        super().__post_init__()\n        # Video derived values\n        self.hidden_size = self.num_attention_heads * self.attention_head_dim\n        self.num_channels_latents = self.out_channels\n        if self.positional_embedding_max_pos is None:\n            self.positional_embedding_max_pos = [20, 2048, 2048]\n\n        # Audio derived values\n        self.audio_hidden_size = (\n            self.audio_num_attention_heads * self.audio_attention_head_dim\n        )\n        if self.audio_positional_embedding_max_pos is None:\n            self.audio_positional_embedding_max_pos = [20]\n\n\n@dataclass\nclass LTX2Config(DiTConfig):\n    \"\"\"Configuration for LTX-2 Video Transformer.\"\"\"\n\n    arch_config: LTX2ArchConfig = field(default_factory=LTX2ArchConfig)\n\n    prefix: str = \"ltx2\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/mova_audio.py",
    "content": "# Copied and adapted from: mossVG/mova/diffusion/models/wan_audio_dit.py\n# SPDX-License-Identifier: Apache-2.0\n\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig\n\n\ndef _is_blocks(n: str, m) -> bool:\n    return \"blocks\" in n and str.isdigit(n.split(\".\")[-1])\n\n\n@dataclass\nclass MOVAAudioArchConfig(DiTArchConfig):\n    _fsdp_shard_conditions: list = field(default_factory=lambda: [_is_blocks])\n\n    param_names_mapping: dict = field(\n        default_factory=lambda: {\n            r\"^blocks\\.(\\d+)\\.ffn\\.0\\.(.*)$\": r\"blocks.\\1.ffn.fc_in.\\2\",\n            r\"^blocks\\.(\\d+)\\.ffn\\.2\\.(.*)$\": r\"blocks.\\1.ffn.fc_out.\\2\",\n            r\"^blocks\\.(\\d+)\\.norm3\\.(.*)$\": r\"blocks.\\1.self_attn_norm.\\2\",\n            r\"^text_embedding\\.0\\.(.*)$\": r\"text_embedding.fc_in.\\1\",\n            r\"^text_embedding\\.2\\.(.*)$\": r\"text_embedding.fc_out.\\1\",\n            r\"^time_embedding\\.0\\.(.*)$\": r\"time_embedding.fc_in.\\1\",\n            r\"^time_embedding\\.2\\.(.*)$\": r\"time_embedding.fc_out.\\1\",\n            r\"^img_emb\\.proj\\.1\\.(.*)$\": r\"img_emb.fc_in.\\1\",\n            r\"^img_emb\\.proj\\.3\\.(.*)$\": r\"img_emb.fc_out.\\1\",\n        }\n    )\n    reverse_param_names_mapping: dict = field(default_factory=dict)\n    lora_param_names_mapping: dict = field(default_factory=dict)\n\n    dim: int = 1536\n    in_dim: int = 128\n    ffn_dim: int = 6144\n    out_dim: int = 128\n    text_dim: int = 4096\n    freq_dim: int = 256\n    eps: float = 1e-6\n    patch_size: tuple[int, int, int] = (1, 2, 2)\n    num_heads: int = 12\n    num_layers: int = 30\n    has_image_input: bool = False\n    has_image_pos_emb: bool = False\n    has_ref_conv: bool = False\n    add_control_adapter: bool = False\n    in_dim_control_adapter: int = 24\n    separated_timestep: bool = False\n    require_vae_embedding: bool = False\n    
require_clip_embedding: bool = False\n    fuse_vae_embedding_in_latents: bool = False\n    vae_type: str = \"dac\"\n\n    def __post_init__(self):\n        super().__post_init__()\n        self.hidden_size = self.dim\n        self.num_attention_heads = self.num_heads\n        self.num_channels_latents = self.out_dim\n        assert (\n            not self.has_image_input\n        ), \"has_image_input must be False; it's a config from Diffsynth Studio, which means the model uses CLIP for image encoding (we don't).\"\n\n\n@dataclass\nclass MOVAAudioConfig(DiTConfig):\n    arch_config: DiTArchConfig = field(default_factory=MOVAAudioArchConfig)\n    prefix: str = \"mova_audio\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/mova_video.py",
    "content": "# Copied and adapted from: mossVG/mova/diffusion/models/wan_video_dit.py\n# SPDX-License-Identifier: Apache-2.0\n\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig\n\n\ndef _is_blocks(n: str, m) -> bool:\n    return \"blocks\" in n and str.isdigit(n.split(\".\")[-1])\n\n\n@dataclass\nclass MOVAVideoArchConfig(DiTArchConfig):\n    _fsdp_shard_conditions: list = field(default_factory=lambda: [_is_blocks])\n\n    param_names_mapping: dict = field(\n        default_factory=lambda: {\n            r\"^blocks\\.(\\d+)\\.ffn\\.0\\.(.*)$\": r\"blocks.\\1.ffn.fc_in.\\2\",\n            r\"^blocks\\.(\\d+)\\.ffn\\.2\\.(.*)$\": r\"blocks.\\1.ffn.fc_out.\\2\",\n            r\"^blocks\\.(\\d+)\\.norm3\\.(.*)$\": r\"blocks.\\1.self_attn_norm.\\2\",\n            r\"^text_embedding\\.0\\.(.*)$\": r\"text_embedding.fc_in.\\1\",\n            r\"^text_embedding\\.2\\.(.*)$\": r\"text_embedding.fc_out.\\1\",\n            r\"^time_embedding\\.0\\.(.*)$\": r\"time_embedding.fc_in.\\1\",\n            r\"^time_embedding\\.2\\.(.*)$\": r\"time_embedding.fc_out.\\1\",\n            r\"^img_emb\\.proj\\.1\\.(.*)$\": r\"img_emb.fc_in.\\1\",\n            r\"^img_emb\\.proj\\.3\\.(.*)$\": r\"img_emb.fc_out.\\1\",\n        }\n    )\n    reverse_param_names_mapping: dict = field(default_factory=dict)\n    lora_param_names_mapping: dict = field(default_factory=dict)\n\n    dim: int = 5120\n    in_dim: int = 16\n    ffn_dim: int = 13824\n    out_dim: int = 16\n    text_dim: int = 4096\n    freq_dim: int = 256\n    eps: float = 1e-6\n    patch_size: tuple[int, int, int] = (1, 2, 2)\n    num_heads: int = 40\n    num_layers: int = 40\n    has_image_input: bool = False\n    has_image_pos_emb: bool = False\n    has_ref_conv: bool = False\n    add_control_adapter: bool = False\n    in_dim_control_adapter: int = 24\n    separated_timestep: bool = False\n    require_vae_embedding: bool = True\n    
require_clip_embedding: bool = True\n    fuse_vae_embedding_in_latents: bool = False\n\n    def __post_init__(self):\n        super().__post_init__()\n        self.hidden_size = self.dim\n        self.num_attention_heads = self.num_heads\n        self.num_channels_latents = self.out_dim\n        assert (\n            not self.has_image_input\n        ), \"has_image_input must be False; it's a config from Diffsynth Studio, which means the model uses CLIP for image encoding (we don't).\"\n\n\n@dataclass\nclass MOVAVideoConfig(DiTConfig):\n    arch_config: DiTArchConfig = field(default_factory=MOVAVideoArchConfig)\n    prefix: str = \"mova_video\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/qwenimage.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\nfrom typing import Tuple\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig\n\n\n@dataclass\nclass QwenImageArchConfig(DiTArchConfig):\n    patch_size: int = 1\n    in_channels: int = 64\n    out_channels: int | None = None\n    num_layers: int = 19\n    num_single_layers: int = 38\n    attention_head_dim: int = 128\n    num_attention_heads: int = 24\n    joint_attention_dim: int = 4096\n    pooled_projection_dim: int = 768\n    guidance_embeds: bool = False\n    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)\n    zero_cond_t: bool = False\n\n    stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list)\n\n    param_names_mapping: dict = field(\n        default_factory=lambda: {\n            # LoRA mappings\n            r\"^(transformer_blocks\\.\\d+\\.attn\\..*\\.lora_[AB])\\.default$\": r\"\\1\",\n            # SVDquant mappings\n            r\"(.*)\\.add_qkv_proj\\.(.+)$\": r\"\\1.to_added_qkv.\\2\",\n            r\"(transformer_blocks\\.\\d+\\.(img_mlp|txt_mlp)\\..*\\.(smooth_factor_orig|wcscales))$\": r\"\\1\",\n            r\".*\\.wtscale$\": r\"\",\n        }\n    )\n\n    def __post_init__(self):\n        super().__post_init__()\n        self.out_channels = self.out_channels or self.in_channels\n        self.hidden_size = self.num_attention_heads * self.attention_head_dim\n        self.num_channels_latents = self.out_channels\n\n\n@dataclass\nclass QwenImageEditPlus_2511_ArchConfig(QwenImageArchConfig):\n    zero_cond_t: bool = True\n\n\n@dataclass\nclass QwenImageDitConfig(DiTConfig):\n    arch_config: DiTArchConfig = field(default_factory=QwenImageArchConfig)\n\n    prefix: str = \"qwenimage\"\n\n\n@dataclass\nclass QwenImageEditPlus_2511_DitConfig(DiTConfig):\n    arch_config: DiTArchConfig = field(\n        
default_factory=QwenImageEditPlus_2511_ArchConfig\n    )\n\n    prefix: str = \"qwenimageedit\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/sana.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n#\n# Architecture and model configuration for SANA DiT (Diffusion Transformer).\n#\n# SANA uses a linear-attention-based transformer that replaces standard\n# quadratic self-attention with ReLU-based linear attention, enabling\n# efficient high-resolution image synthesis. Cross-attention (standard SDPA)\n# is used for text conditioning via Gemma2 embeddings.\n#\n# Defaults below correspond to the SANA-1.6B / 1024px variant.\n# For 4.8B, override num_layers=36, num_attention_heads=64, etc.\n#\n# Reference: https://arxiv.org/abs/2410.10629\n\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig\n\n\n@dataclass\nclass SanaArchConfig(DiTArchConfig):\n    patch_size: int = 1\n    in_channels: int = 32\n    out_channels: int = 32\n    num_layers: int = 20\n    attention_head_dim: int = 32\n    num_attention_heads: int = 70\n    num_cross_attention_heads: int = 20\n    cross_attention_head_dim: int = 112\n    cross_attention_dim: int = 2240\n    caption_channels: int = 2304\n\n    mlp_ratio: float = 2.5\n    # \"rms_norm_across_heads\" applies RMSNorm over the full (num_heads * head_dim)\n    # q/k vector rather than normalizing each head separately.\n    qk_norm: str = \"rms_norm_across_heads\"\n    norm_elementwise_affine: bool = False\n    norm_eps: float = 1e-6\n    sample_size: int = 32\n    guidance_embeds: bool = False\n\n    param_names_mapping: dict = field(\n        default_factory=lambda: {\n            r\"^transformer\\.(.*)$\": r\"\\1\",\n        }\n    )\n\n    def __post_init__(self):\n        super().__post_init__()\n        self.hidden_size = self.num_attention_heads * self.attention_head_dim\n        self.num_channels_latents = self.out_channels\n\n\n@dataclass\nclass SanaConfig(DiTConfig):\n    arch_config: DiTArchConfig = field(default_factory=SanaArchConfig)\n    prefix: str = \"Sana\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/wanvideo.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig\n\n\ndef is_blocks(n: str, m) -> bool:\n    return \"blocks\" in n and str.isdigit(n.split(\".\")[-1])\n\n\n@dataclass\nclass WanVideoArchConfig(DiTArchConfig):\n    _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks])\n\n    param_names_mapping: dict = field(\n        default_factory=lambda: {\n            r\"^patch_embedding\\.(.*)$\": r\"patch_embedding.proj.\\1\",\n            r\"^condition_embedder\\.text_embedder\\.linear_1\\.(.*)$\": r\"condition_embedder.text_embedder.fc_in.\\1\",\n            r\"^condition_embedder\\.text_embedder\\.linear_2\\.(.*)$\": r\"condition_embedder.text_embedder.fc_out.\\1\",\n            r\"^condition_embedder\\.time_embedder\\.linear_1\\.(.*)$\": r\"condition_embedder.time_embedder.mlp.fc_in.\\1\",\n            r\"^condition_embedder\\.time_embedder\\.linear_2\\.(.*)$\": r\"condition_embedder.time_embedder.mlp.fc_out.\\1\",\n            r\"^condition_embedder\\.time_proj\\.(.*)$\": r\"condition_embedder.time_modulation.linear.\\1\",\n            r\"^condition_embedder\\.image_embedder\\.ff\\.net\\.0\\.proj\\.(.*)$\": r\"condition_embedder.image_embedder.ff.fc_in.\\1\",\n            r\"^condition_embedder\\.image_embedder\\.ff\\.net\\.2\\.(.*)$\": r\"condition_embedder.image_embedder.ff.fc_out.\\1\",\n            r\"^blocks\\.(\\d+)\\.attn1\\.to_q\\.(.*)$\": r\"blocks.\\1.to_q.\\2\",\n            r\"^blocks\\.(\\d+)\\.attn1\\.to_k\\.(.*)$\": r\"blocks.\\1.to_k.\\2\",\n            r\"^blocks\\.(\\d+)\\.attn1\\.to_v\\.(.*)$\": r\"blocks.\\1.to_v.\\2\",\n            r\"^blocks\\.(\\d+)\\.attn1\\.to_out\\.0\\.(.*)$\": r\"blocks.\\1.to_out.\\2\",\n            r\"^blocks\\.(\\d+)\\.attn1\\.norm_q\\.(.*)$\": r\"blocks.\\1.norm_q.\\2\",\n            
r\"^blocks\\.(\\d+)\\.attn1\\.norm_k\\.(.*)$\": r\"blocks.\\1.norm_k.\\2\",\n            r\"^blocks\\.(\\d+)\\.attn1\\.attn_op\\.local_attn\\.proj_l\\.(.*)$\": r\"blocks.\\1.attn1.local_attn.proj_l.\\2\",\n            r\"^blocks\\.(\\d+)\\.attn2\\.to_out\\.0\\.(.*)$\": r\"blocks.\\1.attn2.to_out.\\2\",\n            r\"^blocks\\.(\\d+)\\.ffn\\.net\\.0\\.proj\\.(.*)$\": r\"blocks.\\1.ffn.fc_in.\\2\",\n            r\"^blocks\\.(\\d+)\\.ffn\\.net\\.2\\.(.*)$\": r\"blocks.\\1.ffn.fc_out.\\2\",\n            r\"^blocks\\.(\\d+)\\.norm2\\.(.*)$\": r\"blocks.\\1.self_attn_residual_norm.norm.\\2\",\n        }\n    )\n\n    reverse_param_names_mapping: dict = field(default_factory=lambda: {})\n\n    # Some LoRA adapters use the original official layer names instead of hf layer names,\n    # so apply this before the param_names_mapping\n    lora_param_names_mapping: dict = field(\n        default_factory=lambda: {\n            r\"^blocks\\.(\\d+)\\.self_attn\\.q\\.(.*)$\": r\"blocks.\\1.attn1.to_q.\\2\",\n            r\"^blocks\\.(\\d+)\\.self_attn\\.k\\.(.*)$\": r\"blocks.\\1.attn1.to_k.\\2\",\n            r\"^blocks\\.(\\d+)\\.self_attn\\.v\\.(.*)$\": r\"blocks.\\1.attn1.to_v.\\2\",\n            r\"^blocks\\.(\\d+)\\.self_attn\\.o\\.(.*)$\": r\"blocks.\\1.attn1.to_out.0.\\2\",\n            r\"^blocks\\.(\\d+)\\.cross_attn\\.q\\.(.*)$\": r\"blocks.\\1.attn2.to_q.\\2\",\n            r\"^blocks\\.(\\d+)\\.cross_attn\\.k\\.(.*)$\": r\"blocks.\\1.attn2.to_k.\\2\",\n            r\"^blocks\\.(\\d+)\\.cross_attn\\.v\\.(.*)$\": r\"blocks.\\1.attn2.to_v.\\2\",\n            r\"^blocks\\.(\\d+)\\.cross_attn\\.o\\.(.*)$\": r\"blocks.\\1.attn2.to_out.0.\\2\",\n            r\"^blocks\\.(\\d+)\\.ffn\\.0\\.(.*)$\": r\"blocks.\\1.ffn.fc_in.\\2\",\n            r\"^blocks\\.(\\d+)\\.ffn\\.2\\.(.*)$\": r\"blocks.\\1.ffn.fc_out.\\2\",\n        }\n    )\n\n    patch_size: tuple[int, int, int] = (1, 2, 2)\n    text_len = 512\n    num_attention_heads: int = 40\n    attention_head_dim: int = 128\n    
in_channels: int = 16\n    out_channels: int = 16\n    text_dim: int = 4096\n    freq_dim: int = 256\n    ffn_dim: int = 13824\n    num_layers: int = 40\n    cross_attn_norm: bool = True\n    qk_norm: str = \"rms_norm_across_heads\"\n    eps: float = 1e-6\n    image_dim: int | None = None\n    added_kv_proj_dim: int | None = None\n    rope_max_seq_len: int = 1024\n    pos_embed_seq_len: int | None = None\n    exclude_lora_layers: list[str] = field(default_factory=lambda: [\"embedder\"])\n\n    # Wan MoE\n    boundary_ratio: float | None = None\n\n    # Causal Wan\n    local_attn_size: int = (\n        -1\n    )  # Window size for temporal local attention (-1 indicates global attention)\n    sink_size: int = (\n        0  # Size of the attention sink; the first `sink_size` frames are kept unchanged when rolling the KV cache\n    )\n    num_frames_per_block: int = 3\n    sliding_window_num_frames: int = 21\n    attention_type: str = \"original\"\n    sla_topk: float = 0.1\n\n    def __post_init__(self):\n        super().__post_init__()\n        self.out_channels = self.out_channels or self.in_channels\n        self.hidden_size = self.num_attention_heads * self.attention_head_dim\n        self.num_channels_latents = self.out_channels\n\n\n@dataclass\nclass WanVideoConfig(DiTConfig):\n    arch_config: DiTArchConfig = field(default_factory=WanVideoArchConfig)\n\n    prefix: str = \"Wan\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/dits/zimage.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\nfrom typing import Tuple\n\nfrom sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig\n\n\ndef is_zimage_layer(n: str, m) -> bool:\n    \"\"\"Returns if the module should be sharded for Z-Image model.\"\"\"\n    if \"layers\" in n and str.isdigit(n.split(\".\")[-1]):\n        return True\n    if (\"noise_refiner\" in n or \"context_refiner\" in n) and str.isdigit(\n        n.split(\".\")[-1]\n    ):\n        return True\n    return False\n\n\n@dataclass\nclass ZImageArchConfig(DiTArchConfig):\n    all_patch_size: Tuple[int, ...] = (2,)\n    all_f_patch_size: Tuple[int, ...] = (1,)\n    in_channels: int = 16\n    out_channels: int | None = None\n    dim: int = 3840\n    num_layers: int = 30\n    n_refiner_layers: int = 2\n    num_attention_heads: int = 30\n    n_kv_heads: int = 30\n    norm_eps: float = 1e-5\n    qk_norm: bool = True\n    cap_feat_dim: int = 2560\n    rope_theta: float = 256.0\n    t_scale: float = 1000.0\n    axes_dims: Tuple[int, int, int] = (32, 48, 48)\n    axes_lens: Tuple[int, int, int] = (1024, 512, 512)\n\n    _fsdp_shard_conditions: list = field(default_factory=lambda: [is_zimage_layer])\n\n    stacked_params_mapping: list[tuple[str, str, str]] = field(\n        default_factory=lambda: [\n            # (param_name, shard_name, shard_id)\n            (\".feed_forward.w13\", \".feed_forward.w1\", \"gate\"),\n            (\".feed_forward.w13\", \".feed_forward.w3\", \"up\"),\n        ]\n    )\n\n    param_names_mapping: dict = field(\n        default_factory=lambda: {\n            r\"(.*)\\.feed_forward\\.w1\\.weight$\": (r\"\\1.feed_forward.w13.weight\", 0, 2),\n            r\"(.*)\\.feed_forward\\.w3\\.weight$\": (r\"\\1.feed_forward.w13.weight\", 1, 2),\n            r\"(.*)\\.feed_forward\\.w1\\.(lora_A|lora_B)$\": (\n                
r\"\\1.feed_forward.w13.\\2\",\n                0,\n                2,\n            ),\n            r\"(.*)\\.feed_forward\\.w3\\.(lora_A|lora_B)$\": (\n                r\"\\1.feed_forward.w13.\\2\",\n                1,\n                2,\n            ),\n        }\n    )\n\n    def __post_init__(self):\n        super().__post_init__()\n        self.out_channels = self.out_channels or self.in_channels\n        self.num_channels_latents = self.in_channels\n        self.hidden_size = self.dim\n\n\n@dataclass\nclass ZImageDitConfig(DiTConfig):\n    arch_config: ZImageArchConfig = field(default_factory=ZImageArchConfig)\n\n    prefix: str = \"zimage\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/encoders/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nfrom sglang.multimodal_gen.configs.models.encoders.base import (\n    BaseEncoderOutput,\n    EncoderConfig,\n    ImageEncoderConfig,\n    TextEncoderConfig,\n)\nfrom sglang.multimodal_gen.configs.models.encoders.clip import (\n    CLIPTextConfig,\n    CLIPVisionConfig,\n)\nfrom sglang.multimodal_gen.configs.models.encoders.gemma2 import Gemma2Config\nfrom sglang.multimodal_gen.configs.models.encoders.gemma_3 import Gemma3Config\nfrom sglang.multimodal_gen.configs.models.encoders.llama import LlamaConfig\nfrom sglang.multimodal_gen.configs.models.encoders.qwen3 import Qwen3TextConfig\nfrom sglang.multimodal_gen.configs.models.encoders.t5 import T5Config\n\n__all__ = [\n    \"EncoderConfig\",\n    \"TextEncoderConfig\",\n    \"ImageEncoderConfig\",\n    \"BaseEncoderOutput\",\n    \"CLIPTextConfig\",\n    \"CLIPVisionConfig\",\n    \"LlamaConfig\",\n    \"Qwen3TextConfig\",\n    \"T5Config\",\n    \"Gemma2Config\",\n    \"Gemma3Config\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/encoders/base.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\nfrom typing import Any\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig\nfrom sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\n\n\n@dataclass\nclass EncoderArchConfig(ArchConfig):\n    _fsdp_shard_conditions: list = field(default_factory=lambda: [])\n    architectures: list[str] = field(default_factory=lambda: [])\n    _supported_attention_backends: set[AttentionBackendEnum] = field(\n        default_factory=lambda: {\n            AttentionBackendEnum.FA,\n            AttentionBackendEnum.TORCH_SDPA,\n            AttentionBackendEnum.SAGE_ATTN_3,\n        }\n    )\n    output_hidden_states: bool = False\n    use_return_dict: bool = True\n\n\n@dataclass\nclass TextEncoderArchConfig(EncoderArchConfig):\n    vocab_size: int = 0\n    hidden_size: int = 0\n    num_hidden_layers: int = 0\n    num_attention_heads: int = 0\n    pad_token_id: int = 0\n    eos_token_id: int = 0\n    text_len: int = 0\n    hidden_state_skip_layer: int = 0\n    decoder_start_token_id: int = 0\n    output_past: bool = True\n    scalable_attention: bool = True\n    tie_word_embeddings: bool = False\n    stacked_params_mapping: list[tuple[str, str, str]] = field(\n        default_factory=list\n    )  # mapping from huggingface weight names to custom names\n    tokenizer_kwargs: dict[str, Any] = field(default_factory=dict)\n    _fsdp_shard_conditions: list = field(default_factory=lambda: [])\n\n    def __post_init__(self) -> None:\n        self.tokenizer_kwargs = {\n            \"truncation\": True,\n            \"max_length\": self.text_len,\n            \"return_tensors\": \"pt\",\n        }\n\n\n@dataclass\nclass ImageEncoderArchConfig(EncoderArchConfig):\n    
pass\n\n\n@dataclass\nclass BaseEncoderOutput:\n    last_hidden_state: torch.FloatTensor | None = None\n    pooler_output: torch.FloatTensor | None = None\n    hidden_states: tuple[torch.FloatTensor, ...] | None = None\n    attentions: tuple[torch.FloatTensor, ...] | None = None\n    attention_mask: torch.Tensor | None = None\n\n\n@dataclass\nclass EncoderConfig(ModelConfig):\n    arch_config: ArchConfig = field(default_factory=EncoderArchConfig)\n\n    prefix: str = \"\"\n    quant_config: QuantizationConfig | None = None\n    lora_config: Any | None = None\n\n\n@dataclass\nclass TextEncoderConfig(EncoderConfig):\n    arch_config: ArchConfig = field(default_factory=TextEncoderArchConfig)\n\n    # Use the SP Group of the transformer as the TP Group of T5.\n    parallel_folding: bool = False\n    # \"sp\" or \"ulysses\" or \"ring\"\n    parallel_folding_mode: str = \"sp\"\n\n\n@dataclass\nclass ImageEncoderConfig(EncoderConfig):\n    arch_config: ArchConfig = field(default_factory=ImageEncoderArchConfig)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/encoders/clip.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.encoders.base import (\n    ImageEncoderArchConfig,\n    ImageEncoderConfig,\n    TextEncoderArchConfig,\n    TextEncoderConfig,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\n\n\ndef _is_transformer_layer(n: str, m) -> bool:\n    return \"layers\" in n and str.isdigit(n.split(\".\")[-1])\n\n\ndef _is_embeddings(n: str, m) -> bool:\n    return n.endswith(\"embeddings\")\n\n\n@dataclass\nclass CLIPTextArchConfig(TextEncoderArchConfig):\n    vocab_size: int = 49408\n    hidden_size: int = 512\n    intermediate_size: int = 2048\n    projection_dim: int = 512\n    num_hidden_layers: int = 12\n    num_attention_heads: int = 8\n    max_position_embeddings: int = 77\n    hidden_act: str = \"quick_gelu\"\n    layer_norm_eps: float = 1e-5\n    dropout: float = 0.0\n    attention_dropout: float = 0.0\n    initializer_range: float = 0.02\n    initializer_factor: float = 1.0\n    pad_token_id: int = 1\n    bos_token_id: int = 49406\n    eos_token_id: int = 49407\n    text_len: int = 77\n    _supported_attention_backends: set[AttentionBackendEnum] = field(\n        default_factory=lambda: {\n            AttentionBackendEnum.TORCH_SDPA,  # Force TORCH_SDPA to support attention_mask\n        }\n    )\n    stacked_params_mapping: list[tuple[str, str, str]] = field(\n        default_factory=lambda: [\n            # (param_name, shard_name, shard_id)\n            (\"qkv_proj\", \"q_proj\", \"q\"),\n            (\"qkv_proj\", \"k_proj\", \"k\"),\n            (\"qkv_proj\", \"v_proj\", \"v\"),\n        ]\n    )\n    _fsdp_shard_conditions: list = field(\n        default_factory=lambda: [_is_transformer_layer, _is_embeddings]\n    )\n\n\n@dataclass\nclass CLIPVisionArchConfig(ImageEncoderArchConfig):\n    hidden_size: int = 768\n    
intermediate_size: int = 3072\n    projection_dim: int = 512\n    num_hidden_layers: int = 12\n    num_attention_heads: int = 12\n    num_channels: int = 3\n    image_size: int = 224\n    patch_size: int = 32\n    hidden_act: str = \"quick_gelu\"\n    layer_norm_eps: float = 1e-5\n    dropout: float = 0.0\n    attention_dropout: float = 0.0\n    initializer_range: float = 0.02\n    initializer_factor: float = 1.0\n    stacked_params_mapping: list[tuple[str, str, str]] = field(\n        default_factory=lambda: [\n            # (param_name, shard_name, shard_id)\n            (\"qkv_proj\", \"q_proj\", \"q\"),\n            (\"qkv_proj\", \"k_proj\", \"k\"),\n            (\"qkv_proj\", \"v_proj\", \"v\"),\n        ]\n    )\n\n\n@dataclass\nclass CLIPTextConfig(TextEncoderConfig):\n    arch_config: TextEncoderArchConfig = field(default_factory=CLIPTextArchConfig)\n\n    num_hidden_layers_override: int | None = None\n    require_post_norm: bool | None = None\n    prefix: str = \"clip\"\n\n\n@dataclass\nclass CLIPVisionConfig(ImageEncoderConfig):\n    arch_config: ImageEncoderArchConfig = field(default_factory=CLIPVisionArchConfig)\n\n    num_hidden_layers_override: int | None = None\n    require_post_norm: bool | None = None\n    prefix: str = \"clip\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/encoders/gemma2.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n#\n# Text encoder configuration for Gemma2 2B, used by SANA for text conditioning.\n#\n# SANA uses the hidden states from Gemma2 (not logits) as the conditioning\n# signal for cross-attention in the DiT. The encoder output dimension (2304)\n# is projected to the DiT's inner_dim via caption_projection.\n#\n# Defaults match google/gemma-2-2b-it (the model used in SANA HF checkpoints).\n\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.encoders.base import (\n    TextEncoderArchConfig,\n    TextEncoderConfig,\n)\n\n\ndef _is_transformer_layer(n: str, m) -> bool:\n    return \"layers\" in n and str.isdigit(n.split(\".\")[-1])\n\n\ndef _is_embeddings(n: str, m) -> bool:\n    return n.endswith(\"embed_tokens\")\n\n\ndef _is_final_norm(n: str, m) -> bool:\n    return n.endswith(\"norm\")\n\n\n@dataclass\nclass Gemma2ArchConfig(TextEncoderArchConfig):\n    vocab_size: int = 256000\n    hidden_size: int = 2304\n    intermediate_size: int = 9216\n    num_hidden_layers: int = 26\n    num_attention_heads: int = 8\n    num_key_value_heads: int = 4\n    head_dim: int = 256\n    hidden_act: str = \"gelu_pytorch_tanh\"\n    hidden_activation: str = \"gelu_pytorch_tanh\"\n    max_position_embeddings: int = 8192\n    rms_norm_eps: float = 1e-6\n    use_cache: bool = True\n    pad_token_id: int = 0\n    eos_token_id: int = 1\n    bos_token_id: int = 2\n    tie_word_embeddings: bool = True\n    rope_theta: float = 10000.0\n    attention_bias: bool = False\n    attention_dropout: float = 0.0\n\n    # Gemma2 alternates between global and sliding-window attention\n    # on odd/even layers, respectively.\n    sliding_window: int = 4096\n\n    # query_pre_attn_scalar replaces the standard 1/sqrt(head_dim) scaling.\n    query_pre_attn_scalar: int = 256\n\n    # Softcapping bounds raw attention logits via tanh(logits/cap)*cap.\n    # NOTE: SDPA does not natively support softcapping; the runtime 
model\n    # currently skips this (see Gemma2Attention.forward). Quality impact\n    # is minimal for short text-encoder sequences but should be revisited\n    # for longer context.\n    attn_logit_softcapping: float = 50.0\n    final_logit_softcapping: float = 30.0\n\n    text_len: int = 300\n\n    stacked_params_mapping: list[tuple[str, str, str]] = field(\n        default_factory=lambda: [\n            (\".qkv_proj\", \".q_proj\", \"q\"),\n            (\".qkv_proj\", \".k_proj\", \"k\"),\n            (\".qkv_proj\", \".v_proj\", \"v\"),\n            (\".gate_up_proj\", \".gate_proj\", \"0\"),\n            (\".gate_up_proj\", \".up_proj\", \"1\"),\n        ]\n    )\n    _fsdp_shard_conditions: list = field(\n        default_factory=lambda: [_is_transformer_layer, _is_embeddings, _is_final_norm]\n    )\n\n\n@dataclass\nclass Gemma2Config(TextEncoderConfig):\n    arch_config: TextEncoderArchConfig = field(default_factory=Gemma2ArchConfig)\n    prefix: str = \"gemma_2\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/encoders/gemma_3.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.encoders.base import (\n    TextEncoderArchConfig,\n    TextEncoderConfig,\n)\n\n\ndef _is_transformer_layer(n: str, m) -> bool:\n    return \"layers\" in n and str.isdigit(n.split(\".\")[-1])\n\n\ndef _is_embeddings(n: str, m) -> bool:\n    return n.endswith(\"embed_tokens\")\n\n\ndef _is_final_norm(n: str, m) -> bool:\n    return n.endswith(\"norm\")\n\n\n@dataclass\nclass Gemma3ArchConfig(TextEncoderArchConfig):\n    \"\"\"Minimal Gemma text-encoder config for tokenizer kwargs.\n\n    Note: runtime will load the actual `text_encoder/` module from the model repo\n    (e.g. Gemma3Model) via transformers; this config mainly controls tokenization.\n    \"\"\"\n\n    vocab_size: int = 32000\n    hidden_size: int = 4096\n    intermediate_size: int = 11008\n    num_hidden_layers: int = 32\n    num_attention_heads: int = 32\n    num_key_value_heads: int | None = None\n    hidden_act: str = \"gelu_pytorch_tanh\"\n    max_position_embeddings: int = 2048\n    initializer_range: float = 0.02\n    rms_norm_eps: float = 1e-6\n    use_cache: bool = True\n    pad_token_id: int = 0\n    bos_token_id: int = 1\n    eos_token_id: int = 2\n    pretraining_tp: int = 1\n    tie_word_embeddings: bool = True\n    rope_theta: float = 10000.0\n    rope_scaling: dict | None = None\n    rope_local_base_freq: float = 10000.0\n    sliding_window: int = 4096\n    layer_types: list[str] = field(default_factory=list)\n    query_pre_attn_scalar: int | None = None\n    attention_bias: bool = False\n    attention_dropout: float = 0.0\n    mlp_bias: bool = False\n    head_dim: int | None = None\n    hidden_state_skip_layer: int = 2\n    text_len: int = 1024\n\n    stacked_params_mapping: list[tuple[str, str, str]] = field(\n        default_factory=lambda: [\n            # (param_name, 
shard_name, shard_id)\n            (\".qkv_proj\", \".q_proj\", \"q\"),\n            (\".qkv_proj\", \".k_proj\", \"k\"),\n            (\".qkv_proj\", \".v_proj\", \"v\"),\n            (\".gate_up_proj\", \".gate_proj\", \"0\"),  # type: ignore\n            (\".gate_up_proj\", \".up_proj\", \"1\"),  # type: ignore\n        ]\n    )\n    _fsdp_shard_conditions: list = field(\n        default_factory=lambda: [_is_transformer_layer, _is_embeddings, _is_final_norm]\n    )\n\n\n@dataclass\nclass Gemma3Config(TextEncoderConfig):\n    arch_config: TextEncoderArchConfig = field(default_factory=Gemma3ArchConfig)\n\n    prefix: str = \"gemma_3\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/encoders/llama.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.encoders.base import (\n    TextEncoderArchConfig,\n    TextEncoderConfig,\n)\n\n\ndef _is_transformer_layer(n: str, m) -> bool:\n    return \"layers\" in n and str.isdigit(n.split(\".\")[-1])\n\n\ndef _is_embeddings(n: str, m) -> bool:\n    return n.endswith(\"embed_tokens\")\n\n\ndef _is_final_norm(n: str, m) -> bool:\n    return n.endswith(\"norm\")\n\n\n@dataclass\nclass LlamaArchConfig(TextEncoderArchConfig):\n    vocab_size: int = 32000\n    hidden_size: int = 4096\n    intermediate_size: int = 11008\n    num_hidden_layers: int = 32\n    num_attention_heads: int = 32\n    num_key_value_heads: int | None = None\n    hidden_act: str = \"silu\"\n    max_position_embeddings: int = 2048\n    initializer_range: float = 0.02\n    rms_norm_eps: float = 1e-6\n    use_cache: bool = True\n    pad_token_id: int = 0\n    bos_token_id: int = 1\n    eos_token_id: int = 2\n    pretraining_tp: int = 1\n    tie_word_embeddings: bool = False\n    rope_theta: float = 10000.0\n    rope_scaling: float | None = None\n    attention_bias: bool = False\n    attention_dropout: float = 0.0\n    mlp_bias: bool = False\n    head_dim: int | None = None\n    hidden_state_skip_layer: int = 2\n    text_len: int = 256\n    stacked_params_mapping: list[tuple[str, str, str]] = field(\n        default_factory=lambda: [\n            # (param_name, shard_name, shard_id)\n            (\".qkv_proj\", \".q_proj\", \"q\"),\n            (\".qkv_proj\", \".k_proj\", \"k\"),\n            (\".qkv_proj\", \".v_proj\", \"v\"),\n            (\".gate_up_proj\", \".gate_proj\", 0),  # type: ignore\n            (\".gate_up_proj\", \".up_proj\", 1),  # type: ignore\n        ]\n    )\n    _fsdp_shard_conditions: list = field(\n        default_factory=lambda: [_is_transformer_layer, _is_embeddings, 
_is_final_norm]\n    )\n\n\n@dataclass\nclass LlamaConfig(TextEncoderConfig):\n    arch_config: TextEncoderArchConfig = field(default_factory=LlamaArchConfig)\n\n    prefix: str = \"llama\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/encoders/qwen3.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"Qwen3 text encoder configuration for SGLang diffusion models.\"\"\"\n\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.encoders.base import (\n    TextEncoderArchConfig,\n    TextEncoderConfig,\n)\n\n\ndef _is_transformer_layer(n: str, m) -> bool:\n    return \"layers\" in n and str.isdigit(n.split(\".\")[-1])\n\n\ndef _is_embeddings(n: str, m) -> bool:\n    return n.endswith(\"embed_tokens\")\n\n\ndef _is_final_norm(n: str, m) -> bool:\n    return n.endswith(\"norm\")\n\n\n@dataclass\nclass Qwen3TextArchConfig(TextEncoderArchConfig):\n    \"\"\"Architecture config for Qwen3 text encoder.\n\n    Qwen3 is similar to LLaMA but with QK-Norm (RMSNorm on Q and K before attention).\n    \"\"\"\n\n    vocab_size: int = 151936\n    hidden_size: int = 2560\n    intermediate_size: int = 9728\n    num_hidden_layers: int = 36\n    num_attention_heads: int = 32\n    num_key_value_heads: int = 8\n    hidden_act: str = \"silu\"\n    max_position_embeddings: int = 40960\n    initializer_range: float = 0.02\n    rms_norm_eps: float = 1e-6\n    use_cache: bool = True\n    pad_token_id: int = 151643\n    bos_token_id: int = 151643\n    eos_token_id: int = 151645\n    tie_word_embeddings: bool = True\n    rope_theta: float = 1000000.0\n    rope_scaling: dict | None = None\n    attention_bias: bool = False\n    attention_dropout: float = 0.0\n    mlp_bias: bool = False\n    head_dim: int = 128\n    text_len: int = 512\n    output_hidden_states: bool = True  # Klein needs hidden states from layers 9, 18, 27\n\n    # Stacked params for weight loading with tensor parallelism\n    stacked_params_mapping: list[tuple[str, str, str]] = field(\n        default_factory=lambda: [\n            # (param_name, shard_name, shard_id)\n            (\".qkv_proj\", \".q_proj\", \"q\"),\n            (\".qkv_proj\", \".k_proj\", \"k\"),\n            (\".qkv_proj\", \".v_proj\", \"v\"),\n            
(\".gate_up_proj\", \".gate_proj\", 0),\n            (\".gate_up_proj\", \".up_proj\", 1),\n        ]\n    )\n\n    # FSDP sharding conditions for CPU offload\n    _fsdp_shard_conditions: list = field(\n        default_factory=lambda: [_is_transformer_layer, _is_embeddings, _is_final_norm]\n    )\n\n    def __post_init__(self) -> None:\n        self.tokenizer_kwargs = {\n            \"padding\": \"max_length\",\n            \"truncation\": True,\n            \"max_length\": self.text_len,\n            \"return_tensors\": \"pt\",\n        }\n\n\n@dataclass\nclass Qwen3TextConfig(TextEncoderConfig):\n    \"\"\"Top-level config for Qwen3 text encoder.\"\"\"\n\n    arch_config: TextEncoderArchConfig = field(default_factory=Qwen3TextArchConfig)\n    prefix: str = \"qwen3\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/encoders/qwen_image.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.encoders.base import (\n    TextEncoderArchConfig,\n    TextEncoderConfig,\n)\n\n\ndef _is_transformer_layer(n: str, m) -> bool:\n    return \"layers\" in n and str.isdigit(n.split(\".\")[-1])\n\n\ndef _is_embeddings(n: str, m) -> bool:\n    return n.endswith(\"embed_tokens\")\n\n\ndef _is_final_norm(n: str, m) -> bool:\n    return n.endswith(\"norm\")\n\n\n@dataclass\nclass QwenImageArchConfig(TextEncoderArchConfig):\n    vocab_size: int = 32000\n    hidden_size: int = 4096\n    intermediate_size: int = 11008\n    num_hidden_layers: int = 32\n    num_attention_heads: int = 32\n    num_key_value_heads: int | None = None\n    hidden_act: str = \"silu\"\n    max_position_embeddings: int = 2048\n    initializer_range: float = 0.02\n    rms_norm_eps: float = 1e-6\n    use_cache: bool = True\n    pad_token_id: int = -1\n    eos_token_id: int = 2\n    pretraining_tp: int = 1\n    tie_word_embeddings: bool = False\n    rope_theta: float = 10000.0\n    rope_scaling: float | None = None\n    attention_bias: bool = False\n    attention_dropout: float = 0.0\n    mlp_bias: bool = False\n    head_dim: int | None = None\n    hidden_state_skip_layer: int = 2\n    text_len: int = 512\n\n    stacked_params_mapping: list[tuple[str, str, str]] = field(\n        default_factory=lambda: [\n            # (param_name, shard_name, shard_id)\n            (\".qkv_proj\", \".q_proj\", \"q\"),\n            (\".qkv_proj\", \".k_proj\", \"k\"),\n            (\".qkv_proj\", \".v_proj\", \"v\"),\n            (\".gate_up_proj\", \".gate_proj\", 0),  # type: ignore\n            (\".gate_up_proj\", \".up_proj\", 1),  # type: ignore\n        ]\n    )\n    _fsdp_shard_conditions: list = field(\n        default_factory=lambda: [_is_transformer_layer, _is_embeddings, _is_final_norm]\n    
)\n\n\n@dataclass\nclass Qwen2_5VLConfig(TextEncoderConfig):\n    arch_config: TextEncoderArchConfig = field(default_factory=QwenImageArchConfig)\n    # prefix: str = \"qwen_image\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/encoders/t5.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nimport argparse\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.encoders.base import (\n    TextEncoderArchConfig,\n    TextEncoderConfig,\n)\n\n\ndef _is_transformer_layer(n: str, m) -> bool:\n    return \"block\" in n and str.isdigit(n.split(\".\")[-1])\n\n\ndef _is_embeddings(n: str, m) -> bool:\n    return n.endswith(\"shared\")\n\n\ndef _is_final_layernorm(n: str, m) -> bool:\n    return n.endswith(\"final_layer_norm\")\n\n\n@dataclass\nclass T5ArchConfig(TextEncoderArchConfig):\n    vocab_size: int = 32128\n    d_model: int = 512\n    d_kv: int = 64\n    d_ff: int = 2048\n    num_layers: int = 6\n    num_decoder_layers: int | None = None\n    num_heads: int = 8\n    relative_attention_num_buckets: int = 32\n    relative_attention_max_distance: int = 128\n    dropout_rate: float = 0.1\n    layer_norm_epsilon: float = 1e-6\n    initializer_factor: float = 1.0\n    feed_forward_proj: str = \"relu\"\n    dense_act_fn: str = \"\"\n    is_gated_act: bool = False\n    is_encoder_decoder: bool = True\n    use_cache: bool = True\n    pad_token_id: int = 0\n    eos_token_id: int = 1\n    classifier_dropout: float = 0.0\n    text_len: int = 512\n    stacked_params_mapping: list[tuple[str, str, str]] = field(\n        default_factory=lambda: [\n            # (param_name, shard_name, shard_id)\n            (\".qkv_proj\", \".q\", \"q\"),\n            (\".qkv_proj\", \".k\", \"k\"),\n            (\".qkv_proj\", \".v\", \"v\"),\n        ]\n    )\n    _fsdp_shard_conditions: list = field(\n        default_factory=lambda: [\n            _is_transformer_layer,\n            _is_embeddings,\n            _is_final_layernorm,\n        ]\n    )\n\n    # Referenced from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py\n    def __post_init__(self):\n        
super().__post_init__()\n        act_info = self.feed_forward_proj.split(\"-\")\n        self.dense_act_fn: str = act_info[-1]\n        self.is_gated_act: bool = act_info[0] == \"gated\"\n        if self.feed_forward_proj == \"gated-gelu\":\n            self.dense_act_fn = \"gelu_new\"\n\n        self.tokenizer_kwargs = {\n            \"padding\": \"max_length\",\n            \"truncation\": True,\n            \"max_length\": self.text_len,\n            \"add_special_tokens\": True,\n            \"return_attention_mask\": True,\n            \"return_tensors\": \"pt\",\n        }\n\n\n@dataclass\nclass T5Config(TextEncoderConfig):\n    arch_config: TextEncoderArchConfig = field(default_factory=T5ArchConfig)\n\n    prefix: str = \"t5\"\n    # Use the SP Group of the transformer as the TP Group of T5.\n    parallel_folding: bool = False\n    # \"sp\" or \"ulysses\" or \"ring\"\n    parallel_folding_mode: str = \"sp\"\n\n    @staticmethod\n    def add_cli_args(\n        parser: argparse.ArgumentParser, prefix: str = \"t5-config\"\n    ) -> argparse.ArgumentParser:\n        return parser\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vaes/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nfrom sglang.multimodal_gen.configs.models.vaes.dac import DacVAEConfig\nfrom sglang.multimodal_gen.configs.models.vaes.hunyuan3d import Hunyuan3DVAEConfig\nfrom sglang.multimodal_gen.configs.models.vaes.hunyuanvae import HunyuanVAEConfig\nfrom sglang.multimodal_gen.configs.models.vaes.wanvae import WanVAEConfig\n\n__all__ = [\n    \"DacVAEConfig\",\n    \"HunyuanVAEConfig\",\n    \"WanVAEConfig\",\n    \"Hunyuan3DVAEConfig\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vaes/base.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nimport argparse\nimport dataclasses\nfrom dataclasses import dataclass, field\nfrom typing import Any\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig\nfrom sglang.multimodal_gen.utils import StoreBoolean\n\n\n@dataclass\nclass VAEArchConfig(ArchConfig):\n    scaling_factor: float | torch.Tensor = 0\n\n    temporal_compression_ratio: int = 4\n    # or vae_scale_factor?\n    spatial_compression_ratio: int = 8\n\n\n@dataclass\nclass VAEConfig(ModelConfig):\n    arch_config: VAEArchConfig = field(default_factory=VAEArchConfig)\n\n    # sglang-diffusion VAE-specific parameters\n    load_encoder: bool = True\n    load_decoder: bool = True\n\n    tile_sample_min_height: int = 256\n    tile_sample_min_width: int = 256\n    tile_sample_min_num_frames: int = 16\n    tile_sample_stride_height: int = 192\n    tile_sample_stride_width: int = 192\n    tile_sample_stride_num_frames: int = 12\n    blend_num_frames: int = 0\n\n    use_tiling: bool = True\n    use_temporal_tiling: bool = True\n    use_parallel_tiling: bool = True\n    use_temporal_scaling_frames: bool = True\n\n    def __post_init__(self):\n        self.blend_num_frames = (\n            self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames\n        )\n\n    def post_init(self):\n        pass\n\n    @staticmethod\n    def add_cli_args(parser: Any, prefix: str = \"vae-config\") -> Any:\n        \"\"\"Add CLI arguments for VAEConfig fields\"\"\"\n        parser.add_argument(\n            f\"--{prefix}.load-encoder\",\n            action=StoreBoolean,\n            dest=f\"{prefix.replace('-', '_')}.load_encoder\",\n            default=VAEConfig.load_encoder,\n            help=\"Whether to load the VAE encoder\",\n        )\n        parser.add_argument(\n            f\"--{prefix}.load-decoder\",\n            action=StoreBoolean,\n      
      dest=f\"{prefix.replace('-', '_')}.load_decoder\",\n            default=VAEConfig.load_decoder,\n            help=\"Whether to load the VAE decoder\",\n        )\n        parser.add_argument(\n            f\"--{prefix}.tile-sample-min-height\",\n            type=int,\n            dest=f\"{prefix.replace('-', '_')}.tile_sample_min_height\",\n            default=VAEConfig.tile_sample_min_height,\n            help=\"Minimum height for VAE tile sampling\",\n        )\n        parser.add_argument(\n            f\"--{prefix}.tile-sample-min-width\",\n            type=int,\n            dest=f\"{prefix.replace('-', '_')}.tile_sample_min_width\",\n            default=VAEConfig.tile_sample_min_width,\n            help=\"Minimum width for VAE tile sampling\",\n        )\n        parser.add_argument(\n            f\"--{prefix}.tile-sample-min-num-frames\",\n            type=int,\n            dest=f\"{prefix.replace('-', '_')}.tile_sample_min_num_frames\",\n            default=VAEConfig.tile_sample_min_num_frames,\n            help=\"Minimum number of frames for VAE tile sampling\",\n        )\n        parser.add_argument(\n            f\"--{prefix}.tile-sample-stride-height\",\n            type=int,\n            dest=f\"{prefix.replace('-', '_')}.tile_sample_stride_height\",\n            default=VAEConfig.tile_sample_stride_height,\n            help=\"Stride height for VAE tile sampling\",\n        )\n        parser.add_argument(\n            f\"--{prefix}.tile-sample-stride-width\",\n            type=int,\n            dest=f\"{prefix.replace('-', '_')}.tile_sample_stride_width\",\n            default=VAEConfig.tile_sample_stride_width,\n            help=\"Stride width for VAE tile sampling\",\n        )\n        parser.add_argument(\n            f\"--{prefix}.tile-sample-stride-num-frames\",\n            type=int,\n            dest=f\"{prefix.replace('-', '_')}.tile_sample_stride_num_frames\",\n            default=VAEConfig.tile_sample_stride_num_frames,\n            
help=\"Stride number of frames for VAE tile sampling\",\n        )\n        parser.add_argument(\n            f\"--{prefix}.blend-num-frames\",\n            type=int,\n            dest=f\"{prefix.replace('-', '_')}.blend_num_frames\",\n            default=VAEConfig.blend_num_frames,\n            help=\"Number of frames to blend for VAE tile sampling\",\n        )\n        parser.add_argument(\n            f\"--{prefix}.use-tiling\",\n            action=StoreBoolean,\n            dest=f\"{prefix.replace('-', '_')}.use_tiling\",\n            default=VAEConfig.use_tiling,\n            help=\"Whether to use tiling for VAE\",\n        )\n        parser.add_argument(\n            f\"--{prefix}.use-temporal-tiling\",\n            action=StoreBoolean,\n            dest=f\"{prefix.replace('-', '_')}.use_temporal_tiling\",\n            default=VAEConfig.use_temporal_tiling,\n            help=\"Whether to use temporal tiling for VAE\",\n        )\n        parser.add_argument(\n            f\"--{prefix}.use-parallel-tiling\",\n            action=StoreBoolean,\n            dest=f\"{prefix.replace('-', '_')}.use_parallel_tiling\",\n            default=VAEConfig.use_parallel_tiling,\n            help=\"Whether to use parallel tiling for VAE\",\n        )\n\n        return parser\n\n    def get_vae_scale_factor(self):\n        return 2 ** (len(self.arch_config.block_out_channels) - 1)\n\n    def encode_sample_mode(self):\n        return \"argmax\"\n\n    @classmethod\n    def from_cli_args(cls, args: argparse.Namespace) -> \"VAEConfig\":\n        kwargs = {}\n        for attr in dataclasses.fields(cls):\n            value = getattr(args, attr.name, None)\n            if value is not None:\n                kwargs[attr.name] = value\n        return cls(**kwargs)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vaes/dac.py",
    "content": "# Copied and adapted from: mossVG/mova/diffusion/models/dac_vae.py\n# SPDX-License-Identifier: Apache-2.0\n\nfrom dataclasses import dataclass, field\nfrom typing import List\n\nfrom sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig\n\n\n@dataclass\nclass DacVAEArchConfig(ArchConfig):\n    codebook_dim: int = 8\n    codebook_size: int = 1024\n    continuous: bool = True\n    decoder_dim: int = 2048\n    decoder_rates: List[int] = field(default_factory=lambda: [8, 5, 4, 3, 2])\n    encoder_dim: int = 128\n    encoder_rates: List[int] = field(default_factory=lambda: [2, 3, 4, 5, 8])\n    hop_length: int = 3840\n    latent_dim: int = 128\n    n_codebooks: int = 9\n    quantizer_dropout: bool = False\n    sample_rate: int = 48000\n\n\n@dataclass\nclass DacVAEConfig(ModelConfig):\n    arch_config: DacVAEArchConfig = field(default_factory=DacVAEArchConfig)\n    load_encoder: bool = True\n    load_decoder: bool = True\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vaes/flux.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig\n\n\n@dataclass\nclass FluxVAEArchConfig(VAEArchConfig):\n    spatial_compression_ratio: int = 1\n\n    base_dim: int = 96\n    decoder_base_dim: int | None = None\n    z_dim: int = 16\n    dim_mult: tuple[int, ...] = (1, 2, 4, 4)\n    num_res_blocks: int = 2\n    attn_scales: tuple[float, ...] = ()\n    temperal_downsample: tuple[bool, ...] = (False, True, True)\n    dropout: float = 0.0\n\n    is_residual: bool = False\n    in_channels: int = 3\n    out_channels: int = 3\n    patch_size: int | None = None\n    scale_factor_temporal: int = 4\n    scale_factor_spatial: int = 8\n    clip_output: bool = True\n\n\n@dataclass\nclass Flux2VAEArchConfig(FluxVAEArchConfig):\n    pass\n\n\n@dataclass\nclass FluxVAEConfig(VAEConfig):\n    arch_config: FluxVAEArchConfig = field(default_factory=FluxVAEArchConfig)\n\n    use_feature_cache: bool = True\n\n    use_tiling: bool = False\n    use_temporal_tiling: bool = False\n    use_parallel_tiling: bool = False\n\n    def __post_init__(self):\n        self.blend_num_frames = (\n            self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames\n        ) * 2\n\n    def post_init(self):\n        # Calculate vae_scale_factor: prefer block_out_channels, fallback to dim_mult or scale_factor_spatial\n        if (\n            hasattr(self.arch_config, \"block_out_channels\")\n            and self.arch_config.block_out_channels\n        ):\n            self.arch_config.vae_scale_factor = 2 ** (\n                len(self.arch_config.block_out_channels) - 1\n            )\n        elif self.arch_config.dim_mult:\n            self.arch_config.vae_scale_factor = 2 ** (\n                len(self.arch_config.dim_mult) - 1\n            )\n        else:\n            self.arch_config.vae_scale_factor = self.arch_config.scale_factor_spatial\n\n        
self.arch_config.spatial_compression_ratio = self.arch_config.vae_scale_factor\n\n\n@dataclass\nclass Flux2VAEConfig(FluxVAEConfig):\n    arch_config: Flux2VAEArchConfig = field(default_factory=Flux2VAEArchConfig)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vaes/glmimage.py",
    "content": "from dataclasses import dataclass, field\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig\n\n\n@dataclass\nclass GlmImageVAEArchConfig(VAEArchConfig):\n    spatial_compression_ratio: int = 1\n\n    base_dim: int = 96\n    decoder_base_dim: int | None = None\n    z_dim: int = 16\n    dim_mult: tuple[int, ...] = (1, 2, 4, 4)\n    num_res_blocks: int = 2\n    attn_scales: tuple[float, ...] = ()\n    temperal_downsample: tuple[bool, ...] = (False, True, True)\n    dropout: float = 0.0\n\n    is_residual: bool = False\n    input_channels: int = 3\n    out_channels: int = 3\n    patch_size: int | None = None\n    scale_factor_temporal: int = 4\n    scale_factor_spatial: int = 8\n    clip_output: bool = True\n\n    scaling_factor: float | torch.Tensor = 0\n\n    latents_mean: tuple[float, ...] | None = None\n    latents_std: tuple[float, ...] | None = None\n    shift_factor: float | None = None\n    latent_channels: int = 16\n    in_channels: int = 16\n\n\n@dataclass\nclass GlmImageVAEConfig(VAEConfig):\n    arch_config: GlmImageVAEArchConfig = field(default_factory=GlmImageVAEArchConfig)\n\n    use_feature_cache: bool = True\n\n    use_tiling: bool = False\n    use_temporal_tiling: bool = False\n    use_parallel_tiling: bool = False\n\n    def get_vae_scale_factor(self):\n        return 2 ** len(self.arch_config.temperal_downsample)\n\n    def __post_init__(self):\n        self.blend_num_frames = (\n            self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames\n        ) * 2\n\n    def post_init(self):\n        self.arch_config.vae_scale_factor = 2 ** (\n            len(self.arch_config.temperal_downsample)\n        )\n        self.arch_config.spatial_compression_ratio = self.arch_config.vae_scale_factor\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vaes/hunyuan3d.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig\n\n\n@dataclass\nclass Hunyuan3DVAEArchConfig(VAEArchConfig):\n    \"\"\"Architecture config for Hunyuan3D VAE.\"\"\"\n\n    latent_shape: tuple[int, ...] = (1024, 64)\n    scale_factor: float = 1.0\n\n\n@dataclass\nclass Hunyuan3DVAEConfig(VAEConfig):\n    \"\"\"VAE configuration for Hunyuan3D.\"\"\"\n\n    arch_config: Hunyuan3DVAEArchConfig = field(default_factory=Hunyuan3DVAEArchConfig)\n    subfolder: str = \"hunyuan3d-dit-v2-0\"\n    load_encoder: bool = False\n    load_decoder: bool = True\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vaes/hunyuanvae.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig\n\n\n@dataclass\nclass HunyuanVAEArchConfig(VAEArchConfig):\n    in_channels: int = 3\n    out_channels: int = 3\n    latent_channels: int = 16\n    down_block_types: tuple[str, ...] = (\n        \"HunyuanVideoDownBlock3D\",\n        \"HunyuanVideoDownBlock3D\",\n        \"HunyuanVideoDownBlock3D\",\n        \"HunyuanVideoDownBlock3D\",\n    )\n    up_block_types: tuple[str, ...] = (\n        \"HunyuanVideoUpBlock3D\",\n        \"HunyuanVideoUpBlock3D\",\n        \"HunyuanVideoUpBlock3D\",\n        \"HunyuanVideoUpBlock3D\",\n    )\n    block_out_channels: tuple[int, ...] = (128, 256, 512, 512)\n    layers_per_block: int = 2\n    act_fn: str = \"silu\"\n    norm_num_groups: int = 32\n    scaling_factor: float = 0.476986\n    spatial_compression_ratio: int = 8\n    temporal_compression_ratio: int = 4\n    mid_block_add_attention: bool = True\n\n    def __post_init__(self):\n        self.spatial_compression_ratio: int = 2 ** (len(self.block_out_channels) - 1)\n\n\n@dataclass\nclass HunyuanVAEConfig(VAEConfig):\n    arch_config: VAEArchConfig = field(default_factory=HunyuanVAEArchConfig)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vaes/ltx_audio.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\nfrom typing import Optional, Tuple\n\nfrom sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig\n\n\n@dataclass\nclass LTXAudioVAEArchConfig(VAEArchConfig):\n    # Architecture params\n    causality_axis: str = \"height\"\n    attn_resolutions: Optional[Tuple[int, ...]] = None\n    base_channels: int = 128\n    latent_channels: int = 8\n    output_channels: int = 2\n    ch_mult: Tuple[int, ...] = (1, 2, 4)\n    num_res_blocks: int = 2\n    norm_type: str = \"pixel\"\n    dropout: float = 0.0\n    mid_block_add_attention: bool = False\n    sample_rate: int = 16000\n    mel_hop_length: int = 160\n    is_causal: bool = True\n    mel_bins: Optional[int] = 64\n    double_z: bool = True\n\n\n@dataclass\nclass LTXAudioVAEConfig(VAEConfig):\n    arch_config: LTXAudioVAEArchConfig = field(default_factory=LTXAudioVAEArchConfig)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vaes/ltx_video.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\nfrom typing import List\n\nfrom sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig\n\n\n@dataclass\nclass LTXVideoVAEArchConfig(VAEArchConfig):\n    # Architecture params\n    in_channels: int = 3\n    latent_channels: int = 128\n    out_channels: int = 3\n    block_out_channels: List[int] = field(\n        default_factory=lambda: [256, 512, 1024, 2048]\n    )\n    down_block_types: List[str] = field(\n        default_factory=lambda: [\n            \"LTX2VideoDownBlock3D\",\n            \"LTX2VideoDownBlock3D\",\n            \"LTX2VideoDownBlock3D\",\n            \"LTX2VideoDownBlock3D\",\n        ]\n    )\n    spatio_temporal_scaling: List[bool] = field(\n        default_factory=lambda: [True, True, True, True]\n    )\n    layers_per_block: List[int] = field(default_factory=lambda: [4, 6, 6, 2, 2])\n    downsample_type: List[str] = field(\n        default_factory=lambda: [\n            \"spatial\",\n            \"temporal\",\n            \"spatiotemporal\",\n            \"spatiotemporal\",\n        ]\n    )\n    patch_size: int = 4\n    patch_size_t: int = 1\n    resnet_norm_eps: float = 1e-6\n    encoder_causal: bool = True\n    encoder_spatial_padding_mode: str = \"zeros\"\n\n    decoder_block_out_channels: List[int] = field(\n        default_factory=lambda: [256, 512, 1024]\n    )\n    decoder_spatio_temporal_scaling: List[bool] = field(\n        default_factory=lambda: [True, True, True]\n    )\n    decoder_layers_per_block: List[int] = field(default_factory=lambda: [5, 5, 5, 5])\n    decoder_causal: bool = False\n    decoder_spatial_padding_mode: str = \"reflect\"\n\n\n@dataclass\nclass LTXVideoVAEConfig(VAEConfig):\n    arch_config: LTXVideoVAEArchConfig = field(default_factory=LTXVideoVAEArchConfig)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vaes/qwenimage.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig\n\n\n@dataclass\nclass QwenImageVAEArchConfig(VAEArchConfig):\n    spatial_compression_ratio: int = 1\n\n    base_dim: int = 96\n    decoder_base_dim: int | None = None\n    z_dim: int = 16\n    dim_mult: tuple[int, ...] = (1, 2, 4, 4)\n    num_res_blocks: int = 2\n    attn_scales: tuple[float, ...] = ()\n    temperal_downsample: tuple[bool, ...] = (False, True, True)\n    dropout: float = 0.0\n\n    is_residual: bool = False\n    input_channels: int = 3\n    out_channels: int = 3\n    patch_size: int | None = None\n    scale_factor_temporal: int = 4\n    scale_factor_spatial: int = 8\n    clip_output: bool = True\n\n\n@dataclass\nclass QwenImageVAEConfig(VAEConfig):\n    arch_config: QwenImageVAEArchConfig = field(default_factory=QwenImageVAEArchConfig)\n\n    use_feature_cache: bool = True\n\n    use_tiling: bool = False\n    use_temporal_tiling: bool = False\n    use_parallel_tiling: bool = False\n\n    def get_vae_scale_factor(self):\n        return 2 ** len(self.arch_config.temperal_downsample)\n\n    def __post_init__(self):\n        self.blend_num_frames = (\n            self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames\n        ) * 2\n\n    def post_init(self):\n        self.arch_config.vae_scale_factor = 2 ** (\n            len(self.arch_config.temperal_downsample)\n        )\n        self.arch_config.spatial_compression_ratio = self.arch_config.vae_scale_factor\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vaes/sana.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n#\n# VAE configuration for SANA's DC-AE (Deep Compression AutoEncoder).\n#\n# DC-AE achieves a 32x spatial compression ratio (vs. 8x for standard SD VAEs),\n# which means a 1024x1024 image becomes 32x32 latents with 32 channels.\n# This aggressive compression is what allows SANA to run efficiently at\n# high resolutions despite having a relatively small DiT.\n#\n# Reference: https://arxiv.org/abs/2405.17811\n\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig\n\n\n@dataclass\nclass SanaVAEArchConfig(VAEArchConfig):\n    spatial_compression_ratio: int = 32\n    # DC-AE uses a different scaling factor than standard VAEs;\n    # this value must match the pretrained checkpoint.\n    scaling_factor: float = 0.41407\n    latent_channels: int = 32\n    in_channels: int = 3\n\n\n@dataclass\nclass SanaVAEConfig(VAEConfig):\n    arch_config: SanaVAEArchConfig = field(default_factory=SanaVAEArchConfig)\n\n    # DC-AE does not currently support tiling in our wrapper.\n    # Enable these once the diffusers AutoencoderDC adds tiling support.\n    use_tiling: bool = False\n    use_temporal_tiling: bool = False\n    use_parallel_tiling: bool = False\n\n    def post_init(self):\n        # Called by VAELoader AFTER update_model_arch() merges the HF config.json\n        # values into arch_config. Must be post_init() (not __post_init__) because\n        # __post_init__ fires at dataclass creation time, before the HF config merge.\n        #\n        # The base VAEConfig.get_vae_scale_factor() derives from block_out_channels,\n        # which DC-AE doesn't have. Set vae_scale_factor directly from the\n        # spatial_compression_ratio (32x for DC-AE).\n        self.arch_config.vae_scale_factor = self.arch_config.spatial_compression_ratio\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vaes/wanvae.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig\n\n\n@dataclass\nclass WanVAEArchConfig(VAEArchConfig):\n    base_dim: int = 96\n    decoder_base_dim: int | None = None\n    z_dim: int = 16\n    dim_mult: tuple[int, ...] = (1, 2, 4, 4)\n    num_res_blocks: int = 2\n    attn_scales: tuple[float, ...] = ()\n    temperal_downsample: tuple[bool, ...] = (False, True, True)\n    dropout: float = 0.0\n    latents_mean: tuple[float, ...] = (\n        -0.7571,\n        -0.7089,\n        -0.9113,\n        0.1075,\n        -0.1745,\n        0.9653,\n        -0.1517,\n        1.5508,\n        0.4134,\n        -0.0715,\n        0.5517,\n        -0.3632,\n        -0.1922,\n        -0.9497,\n        0.2503,\n        -0.2921,\n    )\n    latents_std: tuple[float, ...] = (\n        2.8184,\n        1.4541,\n        2.3275,\n        2.6558,\n        1.2196,\n        1.7708,\n        2.6052,\n        2.0743,\n        3.2687,\n        2.1526,\n        2.8652,\n        1.5579,\n        1.6382,\n        1.1253,\n        2.8251,\n        1.9160,\n    )\n    is_residual: bool = False\n    in_channels: int = 3\n    out_channels: int = 3\n    patch_size: int | None = None\n    scale_factor_temporal: int = 4\n    scale_factor_spatial: int = 8\n    clip_output: bool = True\n\n    def __post_init__(self):\n        self.scaling_factor: torch.Tensor = 1.0 / torch.tensor(self.latents_std).view(\n            1, self.z_dim, 1, 1, 1\n        )\n        self.shift_factor: torch.Tensor = torch.tensor(self.latents_mean).view(\n            1, self.z_dim, 1, 1, 1\n        )\n        self.temporal_compression_ratio = self.scale_factor_temporal\n        self.spatial_compression_ratio = self.scale_factor_spatial\n\n\n@dataclass\nclass WanVAEConfig(VAEConfig):\n    arch_config: WanVAEArchConfig = field(default_factory=WanVAEArchConfig)\n    use_feature_cache: bool = True\n\n    use_tiling: bool = False\n    use_temporal_tiling: bool = False\n    use_parallel_tiling: bool = False\n\n    use_parallel_encode: bool = True\n    use_parallel_decode: bool = True\n\n    def __post_init__(self):\n        self.blend_num_frames = (\n            self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames\n        ) * 2\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vocoder/__init__.py",
    "content": "from sglang.multimodal_gen.configs.models.vocoder.ltx_vocoder import LTXVocoderConfig\n\n__all__ = [\"LTXVocoderConfig\"]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vocoder/base.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nimport argparse\nimport dataclasses\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig\n\n\n@dataclass\nclass VocoderArchConfig(ArchConfig):\n    in_channels: int = 128\n    hidden_channels: int = 1024\n    out_channels: int = 2\n\n\n@dataclass\nclass VocoderConfig(ModelConfig):\n    arch_config: VocoderArchConfig = field(default_factory=VocoderArchConfig)\n\n    @classmethod\n    def from_cli_args(cls, args: argparse.Namespace) -> \"VocoderConfig\":\n        kwargs = {}\n        for attr in dataclasses.fields(cls):\n            value = getattr(args, attr.name, None)\n            if value is not None:\n                kwargs[attr.name] = value\n        return cls(**kwargs)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/models/vocoder/ltx_vocoder.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\nfrom typing import List\n\nfrom sglang.multimodal_gen.configs.models.vocoder.base import (\n    VocoderArchConfig,\n    VocoderConfig,\n)\n\n\n@dataclass\nclass LTXVocoderArchConfig(VocoderArchConfig):\n    # Architecture params\n    in_channels: int = 128\n    hidden_channels: int = 1024\n    out_channels: int = 2\n    upsample_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11])\n    upsample_factors: List[int] = field(default_factory=lambda: [6, 5, 2, 2, 2])\n    resnet_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11])\n    resnet_dilations: List[List[int]] = field(\n        default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]\n    )\n    leaky_relu_negative_slope: float = 0.1\n    sample_rate: int = 24000\n\n\n@dataclass\nclass LTXVocoderConfig(VocoderConfig):\n    arch_config: LTXVocoderArchConfig = field(default_factory=LTXVocoderArchConfig)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    PipelineConfig,\n    SlidingTileAttnConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.diffusers_generic import (\n    DiffusersGenericPipelineConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.flux import (\n    Flux2KleinPipelineConfig,\n    Flux2PipelineConfig,\n    FluxPipelineConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.flux_finetuned import (\n    Flux2FinetunedPipelineConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.helios import (\n    HeliosDistilledConfig,\n    HeliosMidConfig,\n    HeliosT2VConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.hunyuan import (\n    FastHunyuanConfig,\n    HunyuanConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.hunyuan3d import (\n    Hunyuan3D2PipelineConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.ltx_2 import LTX2PipelineConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.mova import MOVAPipelineConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.sana import SanaPipelineConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.wan import (\n    SelfForcingWanT2V480PConfig,\n    WanI2V480PConfig,\n    WanI2V720PConfig,\n    WanT2V480PConfig,\n    WanT2V720PConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.zimage import ZImagePipelineConfig\n\n__all__ = [\n    \"DiffusersGenericPipelineConfig\",\n    \"HeliosDistilledConfig\",\n    \"HeliosMidConfig\",\n    \"HeliosT2VConfig\",\n    \"HunyuanConfig\",\n    \"FastHunyuanConfig\",\n    \"Hunyuan3D2PipelineConfig\",\n    \"FluxPipelineConfig\",\n    \"Flux2PipelineConfig\",\n    \"Flux2KleinPipelineConfig\",\n    \"Flux2FinetunedPipelineConfig\",\n    \"PipelineConfig\",\n    \"SanaPipelineConfig\",\n    \"SlidingTileAttnConfig\",\n    \"MOVAPipelineConfig\",\n    \"WanT2V480PConfig\",\n    \"WanI2V480PConfig\",\n    \"WanT2V720PConfig\",\n    \"WanI2V720PConfig\",\n    \"SelfForcingWanT2V480PConfig\",\n    \"ZImagePipelineConfig\",\n    \"LTX2PipelineConfig\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/base.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nimport json\nimport os\nfrom collections.abc import Callable\nfrom dataclasses import asdict, dataclass, field, fields\nfrom enum import Enum, auto\nfrom typing import Any\n\nimport numpy as np\nimport PIL\nimport torch\nfrom einops import rearrange\n\nfrom sglang.multimodal_gen.configs.models import (\n    DiTConfig,\n    EncoderConfig,\n    ModelConfig,\n    VAEConfig,\n)\nfrom sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput\nfrom sglang.multimodal_gen.configs.models.encoders.t5 import T5Config\nfrom sglang.multimodal_gen.configs.sample.sampling_params import DataType\nfrom sglang.multimodal_gen.configs.utils import update_config_from_args\nfrom sglang.multimodal_gen.runtime.distributed.communication_op import (\n    sequence_model_parallel_all_gather,\n)\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_sp_parallel_rank,\n    get_sp_world_size,\n)\nfrom sglang.multimodal_gen.runtime.models.vision_utils import get_default_height_width\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import (\n    FlexibleArgumentParser,\n    StoreBoolean,\n    shallow_asdict,\n)\n\nlogger = init_logger(__name__)\n\n\n# NOTE: possible duplication with DataType\n# this may focus on the model's original ability\nclass ModelTaskType(Enum):\n    # TODO: check if I2V/TI2V models can work w/wo text\n\n    I2V = auto()  # Image to Video\n    T2V = auto()  # Text to Video\n    TI2V = auto()  # Text and Image to Video\n\n    T2I = auto()  # Text to Image\n    I2I = auto()  # Image to Image\n    TI2I = auto()  # Image to Image or Text-Image to Image\n    I2M = auto()  # Image to Mesh\n\n    def is_image_gen(self) -> bool:\n        return (\n            self == ModelTaskType.T2I\n            or self == ModelTaskType.I2I\n            or self == 
ModelTaskType.TI2I\n        )\n\n    def requires_image_input(self) -> bool:\n        return (\n            self == ModelTaskType.I2V\n            or self == ModelTaskType.I2I\n            or self == ModelTaskType.I2M\n        )\n\n    def accepts_image_input(self) -> bool:\n        return (\n            self == ModelTaskType.I2V\n            or self == ModelTaskType.I2I\n            or self == ModelTaskType.TI2I\n            or self == ModelTaskType.TI2V\n            or self == ModelTaskType.I2M\n        )\n\n    def data_type(self) -> DataType:\n        if self == ModelTaskType.I2M:\n            return DataType.MESH\n        if self.is_image_gen():\n            return DataType.IMAGE\n        else:\n            return DataType.VIDEO\n\n\nclass STA_Mode(str, Enum):\n    \"\"\"STA (Sliding Tile Attention) modes.\"\"\"\n\n    STA_INFERENCE = \"STA_inference\"\n    STA_SEARCHING = \"STA_searching\"\n    STA_TUNING = \"STA_tuning\"\n    STA_TUNING_CFG = \"STA_tuning_cfg\"\n    NONE = None\n\n\ndef preprocess_text(prompt: str) -> str:\n    return prompt\n\n\ndef postprocess_text(output: BaseEncoderOutput, _text_inputs) -> torch.tensor:\n    raise NotImplementedError\n\n\ndef shard_rotary_emb_for_sp(emb):\n    \"\"\"\n    Shard rotary embeddings [S, D] along sequence for SP.\n    If S is not divisible by SP degree, pad by repeating the last row.\n    \"\"\"\n    # Sequence Parallelism: slice image RoPE to local shard if enabled\n    try:\n        from sglang.multimodal_gen.runtime.distributed.parallel_state import (\n            get_sp_parallel_rank,\n            get_sp_world_size,\n        )\n\n        sp_world_size = get_sp_world_size()\n    except Exception:\n        sp_world_size = 1\n    seq_len = emb.shape[0]\n    if seq_len % sp_world_size != 0:\n        pad_len = sp_world_size - (seq_len % sp_world_size)\n        pad = emb[-1:].repeat(pad_len, 1)\n        emb = torch.cat([emb, pad], dim=0)\n    if sp_world_size > 1:\n        try:\n            rank = 
get_sp_parallel_rank()\n        except Exception:\n            rank = 0\n        seq_len = emb.shape[0]\n        local_len = seq_len // sp_world_size\n        start = rank * local_len\n        end = start + local_len\n        emb = emb[start:end]\n        return emb\n    else:\n        return emb\n\n\ndef maybe_unpad_latents(latents, batch):\n    # If SP padding was applied, remove extra tokens before reshaping\n    raw_shape = batch.raw_latent_shape\n    if len(raw_shape) == 3:\n        # Sequence format [B, S, D]: use seq_len directly\n        target_tokens = raw_shape[1]\n    else:\n        # Spatial format [B, C, H, W] or [B, C, T, H, W]: use width * height\n        width, height = raw_shape[-1], raw_shape[-2]\n        target_tokens = width * height\n    if latents.shape[1] > target_tokens:\n        latents = latents[:, :target_tokens, :]\n    return latents\n\n\n# config for a single pipeline\n@dataclass\nclass PipelineConfig:\n    \"\"\"The base configuration class for a generation pipeline.\"\"\"\n\n    task_type: ModelTaskType = ModelTaskType.I2I\n    skip_input_image_preprocess: bool = False\n\n    model_path: str = \"\"\n    pipeline_config_path: str | None = None\n\n    # precision and autocast\n    enable_autocast: bool = True\n\n    # generation parameters\n    # controls the timestep embedding generation\n    should_use_guidance: bool = True\n    embedded_cfg_scale: float = 6.0\n    flow_shift: float | None = None\n    disable_autocast: bool = False\n\n    # Model configuration\n    dit_config: DiTConfig = field(default_factory=DiTConfig)\n    dit_precision: str = \"bf16\"\n\n    # VAE configuration\n    vae_config: VAEConfig = field(default_factory=VAEConfig)\n    vae_precision: str = \"fp32\"\n    vae_tiling: bool = True\n    vae_slicing: bool = False\n    vae_sp: bool = True\n\n    # Image encoder configuration\n    image_encoder_config: EncoderConfig = field(default_factory=EncoderConfig)\n    image_encoder_precision: str = \"fp32\"\n\n    # Text 
encoder configuration\n    DEFAULT_TEXT_ENCODER_PRECISIONS = (\"fp32\",)\n    text_encoder_configs: tuple[EncoderConfig, ...] = field(\n        default_factory=lambda: (EncoderConfig(),)\n    )\n    # See PRECISION_TO_TYPE for detailed mapping\n    text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: (\"fp32\",))\n    text_encoder_extra_args: list[dict] = field(default_factory=lambda: [{}])\n\n    # image encoding\n    image_encoder_extra_args: dict = field(default_factory=lambda: {})\n\n    def postprocess_image(self, image):\n        return image.last_hidden_state\n\n    preprocess_text_funcs: tuple[Callable[[str], str], ...] = field(\n        default_factory=lambda: (preprocess_text,)\n    )\n\n    # get prompt_embeds from encoder output\n    postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.tensor], ...] = (\n        field(default_factory=lambda: (postprocess_text,))\n    )\n\n    # STA (Sliding Tile Attention) parameters\n    mask_strategy_file_path: str | None = None\n    STA_mode: STA_Mode = STA_Mode.STA_INFERENCE\n    skip_time_steps: int = 15\n\n    # DMD parameters\n    dmd_denoising_steps: list[int] | None = field(default=None)\n\n    # Wan2.2 TI2V parameters\n    boundary_ratio: float | None = None\n\n    # Compilation\n    # enable_torch_compile: bool = False\n\n    # calculate the adjust size for condition image\n    # width: original condition image width\n    # height: original condition image height\n    def calculate_condition_image_size(self, image, width, height) -> tuple[int, int]:\n        vae_scale_factor = self.vae_config.arch_config.spatial_compression_ratio\n        height, width = get_default_height_width(image, vae_scale_factor, height, width)\n        return width, height\n\n    ## For timestep preparation stage\n\n    def prepare_sigmas(self, sigmas, num_inference_steps):\n        return sigmas\n\n    ## For ImageVAEEncodingStage\n    def preprocess_condition_image(\n        self, image, 
target_width, target_height, _vae_image_processor\n    ):\n        \"\"\"\n        preprocess the condition image, returns (image, final_image_width, final_image_height)\n        \"\"\"\n        return image.resize(\n            (target_width, target_height), PIL.Image.Resampling.LANCZOS\n        ), (target_width, target_height)\n\n    def prepare_calculated_size(self, image):\n        return self.calculate_condition_image_size(image, image.width, image.height)\n\n    def prepare_image_processor_kwargs(self, batch, neg=False):\n        return {}\n\n    def postprocess_image_latent(self, latent_condition, batch):\n        vae_arch_config = self.vae_config.arch_config\n        spatial_compression_ratio = vae_arch_config.spatial_compression_ratio\n        temporal_compression_ratio = vae_arch_config.temporal_compression_ratio\n        num_frames = batch.num_frames\n        latent_height = batch.height // spatial_compression_ratio\n        latent_width = batch.width // spatial_compression_ratio\n        mask_lat_size = torch.ones(1, 1, num_frames, latent_height, latent_width)\n        mask_lat_size[:, :, 1:] = 0\n        first_frame_mask = mask_lat_size[:, :, 0:1]\n        first_frame_mask = torch.repeat_interleave(\n            first_frame_mask,\n            repeats=temporal_compression_ratio,\n            dim=2,\n        )\n        mask_lat_size = torch.concat(\n            [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2\n        )\n        mask_lat_size = mask_lat_size.view(\n            1,\n            -1,\n            temporal_compression_ratio,\n            latent_height,\n            latent_width,\n        )\n        mask_lat_size = mask_lat_size.transpose(1, 2)\n        mask_lat_size = mask_lat_size.to(latent_condition.device)\n        image_latents = torch.concat([mask_lat_size, latent_condition], dim=1)\n        return image_latents\n\n    def slice_noise_pred(self, noise, latents):\n        return noise\n\n    def adjust_num_frames(self, num_frames):\n 
       return num_frames\n\n    # tokenize the prompt\n    def tokenize_prompt(self, prompt: list[str], tokenizer, tok_kwargs) -> dict:\n        return tokenizer(prompt, **tok_kwargs)\n\n    def prepare_latent_shape(self, batch, batch_size, num_frames):\n        height = batch.height // self.vae_config.arch_config.spatial_compression_ratio\n        width = batch.width // self.vae_config.arch_config.spatial_compression_ratio\n\n        # Calculate latent shape\n        shape = (\n            batch_size,\n            self.dit_config.num_channels_latents,\n            num_frames,\n            height,\n            width,\n        )\n\n        return shape\n\n    def allow_set_num_frames(self):\n        return False\n\n    def get_decode_scale_and_shift(self, device, dtype, vae):\n        vae_arch_config = self.vae_config.arch_config\n        scaling_factor = getattr(vae_arch_config, \"scaling_factor\", None)\n        if scaling_factor is None:\n            scaling_factor = getattr(vae, \"scaling_factor\", None)\n\n        shift_factor = getattr(vae_arch_config, \"shift_factor\", None)\n        if shift_factor is None:\n            shift_factor = getattr(vae, \"shift_factor\", None)\n        return scaling_factor, shift_factor\n\n    # called after latents are prepared\n    def maybe_pack_latents(self, latents, batch_size, batch):\n        return latents\n\n    def maybe_prepare_latent_ids(self, latents):\n        return None\n\n    # called after vae encode\n    def postprocess_vae_encode(self, image_latents, vae):\n        return image_latents\n\n    # called after scale_and_shift, before vae decoding\n    def preprocess_decoding(self, latents, server_args=None, vae=None):\n        return latents\n\n    def gather_latents_for_sp(self, latents):\n        # For video latents [B, C, T_local, H, W], gather along time dim=2\n        latents = sequence_model_parallel_all_gather(latents, dim=2)\n        return latents\n\n    def preprocess_vae_image(self, batch, 
vae_image_processor):\n        pass\n\n    def shard_latents_for_sp(self, batch, latents):\n        # general logic for video models\n        sp_world_size, rank_in_sp_group = get_sp_world_size(), get_sp_parallel_rank()\n        if batch.enable_sequence_shard and sp_world_size > 1:\n            return latents, False\n        if latents.dim() != 5:\n            return latents, False\n        time_dim = latents.shape[2]\n\n        # Pad to next multiple of SP degree if needed\n        if time_dim > 0 and time_dim % sp_world_size != 0:\n            logger.debug(\n                \"Padding latents to next multiple of SP degree, performance is sub-optimal\"\n            )\n            pad_len = sp_world_size - (time_dim % sp_world_size)\n            pad = torch.zeros(\n                (*latents.shape[:2], pad_len, *latents.shape[3:]),\n                dtype=latents.dtype,\n                device=latents.device,\n            )\n            latents = torch.cat([latents, pad], dim=2)\n\n        assert latents.shape[2] % sp_world_size == 0\n        sharded_tensor = rearrange(\n            latents, \"b c (n t) h w -> b c n t h w\", n=sp_world_size\n        ).contiguous()\n        sharded_tensor = sharded_tensor[:, :, rank_in_sp_group, :, :, :]\n        return sharded_tensor, True\n\n    def get_pos_prompt_embeds(self, batch):\n        return batch.prompt_embeds\n\n    def get_neg_prompt_embeds(self, batch):\n        return batch.negative_prompt_embeds\n\n    def post_denoising_loop(self, latents, batch):\n        latents = maybe_unpad_latents(latents, batch)\n        return latents\n\n    def post_decoding(self, frames, server_args):\n        return frames\n\n    def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return {}\n\n    def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return {}\n\n    @staticmethod\n    def add_cli_args(\n        parser: FlexibleArgumentParser, prefix: str = \"\"\n    ) -> 
FlexibleArgumentParser:\n        prefix_with_dot = f\"{prefix}.\" if (prefix.strip() != \"\") else \"\"\n\n        # model_path will be conflicting with the model_path in ServerArgs,\n        # so we add it separately if prefix is not empty\n        if prefix_with_dot != \"\":\n            parser.add_argument(\n                f\"--{prefix_with_dot}model-path\",\n                type=str,\n                dest=f\"{prefix_with_dot.replace('-', '_')}model_path\",\n                default=PipelineConfig.model_path,\n                help=\"Path to the pretrained model\",\n            )\n\n        parser.add_argument(\n            f\"--{prefix_with_dot}pipeline-config-path\",\n            type=str,\n            dest=f\"{prefix_with_dot.replace('-', '_')}pipeline_config_path\",\n            default=PipelineConfig.pipeline_config_path,\n            help=\"Path to the pipeline config\",\n        )\n        parser.add_argument(\n            f\"--{prefix_with_dot}embedded-cfg-scale\",\n            type=float,\n            dest=f\"{prefix_with_dot.replace('-', '_')}embedded_cfg_scale\",\n            default=PipelineConfig.embedded_cfg_scale,\n            help=\"Embedded CFG scale\",\n        )\n        parser.add_argument(\n            f\"--{prefix_with_dot}flow-shift\",\n            type=float,\n            dest=f\"{prefix_with_dot.replace('-', '_')}flow_shift\",\n            default=PipelineConfig.flow_shift,\n            help=\"Flow shift parameter\",\n        )\n        parser.add_argument(\n            f\"--{prefix_with_dot}resolution\",\n            type=int,\n            dest=f\"{prefix_with_dot.replace('-', '_')}resolution\",\n            default=None,\n            help=\"Override the selected pipeline config's resolution setting. 
Only applies to pipelines that define a resolution field.\",\n        )\n\n        # DiT configuration\n        parser.add_argument(\n            f\"--{prefix_with_dot}dit-precision\",\n            type=str,\n            dest=f\"{prefix_with_dot.replace('-', '_')}dit_precision\",\n            default=PipelineConfig.dit_precision,\n            choices=[\"fp32\", \"fp16\", \"bf16\"],\n            help=\"Precision for the DiT model\",\n        )\n\n        # VAE configuration\n        parser.add_argument(\n            f\"--{prefix_with_dot}vae-precision\",\n            type=str,\n            dest=f\"{prefix_with_dot.replace('-', '_')}vae_precision\",\n            default=PipelineConfig.vae_precision,\n            choices=[\"fp32\", \"fp16\", \"bf16\"],\n            help=\"Precision for VAE\",\n        )\n        parser.add_argument(\n            f\"--{prefix_with_dot}vae-tiling\",\n            action=StoreBoolean,\n            dest=f\"{prefix_with_dot.replace('-', '_')}vae_tiling\",\n            default=PipelineConfig.vae_tiling,\n            help=\"Enable VAE tiling\",\n        )\n        parser.add_argument(\n            f\"--{prefix_with_dot}vae-slicing\",\n            action=StoreBoolean,\n            dest=f\"{prefix_with_dot.replace('-', '_')}vae_slicing\",\n            default=PipelineConfig.vae_slicing,\n            help=\"Enable VAE slicing\",\n        )\n        parser.add_argument(\n            f\"--{prefix_with_dot}vae-sp\",\n            action=StoreBoolean,\n            dest=f\"{prefix_with_dot.replace('-', '_')}vae_sp\",\n            help=\"Enable VAE spatial parallelism\",\n        )\n\n        # Text encoder configuration\n        parser.add_argument(\n            f\"--{prefix_with_dot}text-encoder-precisions\",\n            nargs=\"+\",\n            type=str,\n            dest=f\"{prefix_with_dot.replace('-', '_')}text_encoder_precisions\",\n            default=PipelineConfig.DEFAULT_TEXT_ENCODER_PRECISIONS,\n            choices=[\"fp32\", \"fp16\", 
\"bf16\"],\n            help=\"Precision for each text encoder\",\n        )\n\n        # Image encoder configuration\n        parser.add_argument(\n            f\"--{prefix_with_dot}image-encoder-precision\",\n            type=str,\n            dest=f\"{prefix_with_dot.replace('-', '_')}image_encoder_precision\",\n            default=PipelineConfig.image_encoder_precision,\n            choices=[\"fp32\", \"fp16\", \"bf16\"],\n            help=\"Precision for image encoder\",\n        )\n\n        # DMD parameters\n        parser.add_argument(\n            f\"--{prefix_with_dot}dmd-denoising-steps\",\n            type=parse_int_list,\n            default=PipelineConfig.dmd_denoising_steps,\n            help=\"Comma-separated list of denoising steps (e.g., '1000,757,522')\",\n        )\n\n        # Add VAE configuration arguments\n        from sglang.multimodal_gen.configs.models.vaes.base import VAEConfig\n\n        VAEConfig.add_cli_args(parser, prefix=f\"{prefix_with_dot}vae-config\")\n\n        # Add DiT configuration arguments\n        from sglang.multimodal_gen.configs.models.dits.base import DiTConfig\n\n        DiTConfig.add_cli_args(parser, prefix=f\"{prefix_with_dot}dit-config\")\n\n        # Add T5 configuration arguments\n        from sglang.multimodal_gen.configs.models.encoders.t5 import T5Config\n\n        T5Config.add_cli_args(parser, prefix=f\"{prefix_with_dot}t5-config\")\n\n        return parser\n\n    def update_config_from_dict(self, args: dict[str, Any], prefix: str = \"\") -> None:\n        prefix_with_dot = f\"{prefix}.\" if (prefix.strip() != \"\") else \"\"\n        update_config_from_args(self, args, prefix, pop_args=True)\n        update_config_from_args(\n            self.vae_config, args, f\"{prefix_with_dot}vae_config\", pop_args=True\n        )\n        update_config_from_args(\n            self.dit_config, args, f\"{prefix_with_dot}dit_config\", pop_args=True\n        )\n        for text_encoder_config in self.text_encoder_configs:\n 
           if isinstance(text_encoder_config, T5Config):\n                update_config_from_args(\n                    text_encoder_config,\n                    args,\n                    f\"{prefix_with_dot}t5_config\",\n                    pop_args=True,\n                )\n\n    @classmethod\n    def from_kwargs(\n        cls, kwargs: dict[str, Any], config_cli_prefix: str = \"\"\n    ) -> \"PipelineConfig\":\n        \"\"\"\n        Load PipelineConfig from kwargs Dictionary, as part of the ServerArg initialization process\n        kwargs: dictionary of kwargs\n        config_cli_prefix: prefix of CLI arguments for this PipelineConfig instance\n        \"\"\"\n        from sglang.multimodal_gen.registry import get_model_info\n\n        prefix_with_dot = (\n            f\"{config_cli_prefix}.\" if (config_cli_prefix.strip() != \"\") else \"\"\n        )\n        model_path: str | None = kwargs.get(\n            prefix_with_dot + \"model_path\", None\n        ) or kwargs.get(\"model_path\")\n        pipeline_config_or_path: str | PipelineConfig | dict[str, Any] | None = (\n            kwargs.get(prefix_with_dot + \"pipeline_config\", None)\n            or kwargs.get(\"pipeline_config\")\n        )\n        if model_path is None:\n            raise ValueError(\"model_path is required in kwargs\")\n\n        # Check if model_path is a safetensors file and pipeline_class_name is specified\n        pipeline_class_name = kwargs.get(\n            prefix_with_dot + \"pipeline_class_name\"\n        ) or kwargs.get(\"pipeline_class_name\")\n        is_safetensors_file = os.path.isfile(model_path) and model_path.endswith(\n            \".safetensors\"\n        )\n\n        # 1. 
Get the pipeline config class from the registry\n        from sglang.multimodal_gen.configs.pipeline_configs.flux import (\n            Flux2PipelineConfig,\n        )\n        from sglang.multimodal_gen.registry import get_pipeline_config_classes\n\n        # If model_path is a safetensors file and pipeline_class_name is specified,\n        # try to get PipelineConfig from the registry first\n        if is_safetensors_file and pipeline_class_name:\n            config_classes = get_pipeline_config_classes(pipeline_class_name)\n            if config_classes is not None:\n                pipeline_config_cls, _ = config_classes\n                logger.info(\n                    f\"Detected safetensors file with {pipeline_class_name}, \"\n                    f\"using {pipeline_config_cls.__name__} directly without model_index.json\"\n                )\n            else:\n                model_info = get_model_info(\n                    model_path,\n                    backend=kwargs.get(\"backend\"),\n                    model_id=kwargs.get(\"model_id\"),\n                )\n                if model_info is None:\n                    from sglang.multimodal_gen.registry import (\n                        _PIPELINE_CONFIG_REGISTRY,\n                        _discover_and_register_pipelines,\n                    )\n\n                    _discover_and_register_pipelines()\n                    available_pipelines = list(_PIPELINE_CONFIG_REGISTRY.keys())\n                    raise ValueError(\n                        f\"Could not get model info for '{model_path}'. \"\n                        f\"If using a safetensors file, please specify a valid pipeline_class_name. 
\"\n                        f\"Available pipelines with config classes: {available_pipelines}\"\n                    )\n                pipeline_config_cls = model_info.pipeline_config_cls\n        else:\n            model_info = get_model_info(\n                model_path,\n                backend=kwargs.get(\"backend\"),\n                model_id=kwargs.get(\"model_id\"),\n            )\n            if model_info is None:\n                raise ValueError(\n                    f\"Could not get model info for '{model_path}'. \"\n                    f\"If using a safetensors file, please specify pipeline_class_name\"\n                )\n            # 1.5. Adjust pipeline config for fine-tuned VAE if needed\n            pipeline_config_cls = model_info.pipeline_config_cls\n        vae_path = kwargs.get(prefix_with_dot + \"vae_path\") or kwargs.get(\"vae_path\")\n        if vae_path is None:\n            component_paths = kwargs.get(\n                prefix_with_dot + \"component_paths\"\n            ) or kwargs.get(\"component_paths\")\n            if isinstance(component_paths, dict):\n                vae_path = component_paths.get(\"vae\")\n\n        # Check if this is a Flux2 model with fal/FLUX.2-Tiny-AutoEncoder\n        if (\n            isinstance(pipeline_config_cls, type)\n            and issubclass(pipeline_config_cls, Flux2PipelineConfig)\n            and vae_path is not None\n            and \"FLUX.2-Tiny-AutoEncoder\" in vae_path\n        ):\n            from sglang.multimodal_gen.configs.pipeline_configs.flux_finetuned import (\n                Flux2FinetunedPipelineConfig,\n            )\n\n            pipeline_config_cls = Flux2FinetunedPipelineConfig\n\n        pipeline_config = pipeline_config_cls()\n\n        # 2. 
Load PipelineConfig from a json file or a PipelineConfig object if provided\n        if isinstance(pipeline_config_or_path, str):\n            pipeline_config.load_from_json(pipeline_config_or_path)\n            kwargs[prefix_with_dot + \"pipeline_config_path\"] = pipeline_config_or_path\n        elif isinstance(pipeline_config_or_path, PipelineConfig):\n            pipeline_config = pipeline_config_or_path\n        elif isinstance(pipeline_config_or_path, dict):\n            pipeline_config.update_pipeline_config(pipeline_config_or_path)\n\n        # 3. Update PipelineConfig from CLI arguments if provided\n        kwargs[prefix_with_dot + \"model_path\"] = model_path\n        pipeline_config.update_config_from_dict(kwargs, config_cli_prefix)\n        return pipeline_config\n\n    def check_pipeline_config(self) -> None:\n        if self.vae_sp and not self.vae_tiling:\n            raise ValueError(\n                \"Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True.\"\n            )\n\n        if len(self.text_encoder_configs) != len(self.text_encoder_precisions):\n            raise ValueError(\n                f\"Length of text encoder configs ({len(self.text_encoder_configs)}) must be equal to length of text encoder precisions ({len(self.text_encoder_precisions)})\"\n            )\n\n        if len(self.text_encoder_configs) != len(self.preprocess_text_funcs):\n            raise ValueError(\n                f\"Length of text encoder configs ({len(self.text_encoder_configs)}) must be equal to length of text preprocessing functions ({len(self.preprocess_text_funcs)})\"\n            )\n\n        if len(self.preprocess_text_funcs) != len(self.postprocess_text_funcs):\n            raise ValueError(\n                f\"Length of text postprocess functions ({len(self.postprocess_text_funcs)}) must be equal to length of text preprocessing functions ({len(self.preprocess_text_funcs)})\"\n            )\n\n    def dump_to_json(self, 
file_path: str):\n        output_dict = shallow_asdict(self)\n        del_keys = []\n        for key, value in output_dict.items():\n            if isinstance(value, ModelConfig):\n                model_dict = asdict(value)\n                # Model Arch Config should be hidden away from the users\n                model_dict.pop(\"arch_config\")\n                output_dict[key] = model_dict\n            elif isinstance(value, tuple) and all(\n                isinstance(v, ModelConfig) for v in value\n            ):\n                model_dicts = []\n                for v in value:\n                    model_dict = asdict(v)\n                    # Model Arch Config should be hidden away from the users\n                    model_dict.pop(\"arch_config\")\n                    model_dicts.append(model_dict)\n                output_dict[key] = model_dicts\n            elif isinstance(value, tuple) and all(callable(f) for f in value):\n                # Skip dumping functions\n                del_keys.append(key)\n\n        for key in del_keys:\n            output_dict.pop(key, None)\n\n        with open(file_path, \"w\") as f:\n            json.dump(output_dict, f, indent=2)\n\n    def load_from_json(self, file_path: str):\n        with open(file_path) as f:\n            input_pipeline_dict = json.load(f)\n        self.update_pipeline_config(input_pipeline_dict)\n\n    def update_pipeline_config(self, source_pipeline_dict: dict[str, Any]) -> None:\n        for f in fields(self):\n            key = f.name\n            if key in source_pipeline_dict:\n                current_value = getattr(self, key)\n                new_value = source_pipeline_dict[key]\n\n                # If it's a nested ModelConfig, update it recursively\n                if isinstance(current_value, ModelConfig):\n                    current_value.update_model_config(new_value)\n                elif isinstance(current_value, tuple) and all(\n                    isinstance(v, ModelConfig) for v in 
current_value\n                ):\n                    assert len(current_value) == len(\n                        new_value\n                    ), \"Do not add or remove text encoder config objects in the JSON file\"\n                    for target_config, source_config in zip(\n                        current_value, new_value, strict=True\n                    ):\n                        target_config.update_model_config(source_config)\n                else:\n                    setattr(self, key, new_value)\n\n        if hasattr(self, \"__post_init__\"):\n            self.__post_init__()\n\n\n@dataclass\nclass ImagePipelineConfig(PipelineConfig):\n    \"\"\"Base config for image generation pipelines with token-like latents [B, S, D].\"\"\"\n\n    def _prepare_sigmas(self, sigmas, num_inference_steps):\n        sigmas = (\n            np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)\n            if sigmas is None\n            else sigmas\n        )\n        return sigmas\n\n    def shard_latents_for_sp(self, batch, latents):\n        # latents: [B, H * W, C]\n        sp_world_size, rank_in_sp_group = get_sp_world_size(), get_sp_parallel_rank()\n        seq_len = latents.shape[1]\n\n        # TODO: reuse code in PipelineConfig::shard_latents_for_sp\n        # Pad to next multiple of SP degree if needed\n        if seq_len % sp_world_size != 0:\n            pad_len = sp_world_size - (seq_len % sp_world_size)\n            pad = torch.zeros(\n                (*latents.shape[:1], pad_len, *latents.shape[2:]),\n                dtype=latents.dtype,\n                device=latents.device,\n            )\n            latents = torch.cat([latents, pad], dim=1)\n\n        sharded_tensor = rearrange(\n            latents, \"b (n s) d -> b n s d\", n=sp_world_size\n        ).contiguous()\n        sharded_tensor = sharded_tensor[:, rank_in_sp_group, :, :]\n        return sharded_tensor, True\n\n    def gather_latents_for_sp(self, latents):\n        # For
image latents [B, S_local, D], gather along sequence dim=1\n        latents = sequence_model_parallel_all_gather(latents, dim=1)\n        return latents\n\n    def _unpad_and_unpack_latents(self, latents, batch):\n        vae_scale_factor = self.vae_config.arch_config.vae_scale_factor\n        channels = self.dit_config.arch_config.in_channels\n        batch_size = latents.shape[0]\n\n        height = 2 * (int(batch.height) // (vae_scale_factor * 2))\n        width = 2 * (int(batch.width) // (vae_scale_factor * 2))\n\n        latents = maybe_unpad_latents(latents, batch)\n\n        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)\n        latents = latents.permute(0, 3, 1, 4, 2, 5)\n        return latents, batch_size, channels, height, width\n\n\n@dataclass\nclass SpatialImagePipelineConfig(ImagePipelineConfig):\n    \"\"\"Base config for spatial image pipelines (e.g. GLM-Image) with 4D latents (B, C, H', W').\n\n    Overrides shard_latents_for_sp / gather_latents_for_sp to shard along the height dimension\n    so that each SP rank gets (B, C, H'_local, W') instead of using the token-style (B, S, C) path.\n    \"\"\"\n\n    def shard_latents_for_sp(self, batch, latents):\n        # 4D latents (B, C, H', W') -> shard along H' (dim=2); otherwise fall back to base (B, S, C)\n        sp_world_size = get_sp_world_size()\n        if sp_world_size <= 1:\n            return latents, False\n        if latents.dim() != 4:\n            return super().shard_latents_for_sp(batch, latents)\n\n        # (B, C, H', W')\n        _, _, h_lat, w_lat = latents.shape\n        if h_lat % sp_world_size != 0:\n            pad_len = sp_world_size - (h_lat % sp_world_size)\n            pad = torch.zeros(\n                (latents.shape[0], latents.shape[1], pad_len, latents.shape[3]),\n                dtype=latents.dtype,\n                device=latents.device,\n            )\n            latents = torch.cat([latents, pad], dim=2)\n            h_lat = 
latents.shape[2]\n        rank_in_sp_group = get_sp_parallel_rank()\n        chunk_size = h_lat // sp_world_size\n        h0 = rank_in_sp_group * chunk_size\n        h1 = h0 + chunk_size\n        sharded = latents[:, :, h0:h1, :].contiguous()\n        return sharded, True\n\n    def gather_latents_for_sp(self, latents):\n        if get_sp_world_size() <= 1:\n            return latents\n        if latents.dim() != 4:\n            return super().gather_latents_for_sp(latents)\n        # Gather along dim=2 (H') to match shard_latents_for_sp\n        return sequence_model_parallel_all_gather(latents, dim=2)\n\n\n@dataclass\nclass SlidingTileAttnConfig(PipelineConfig):\n    \"\"\"Configuration for sliding tile attention.\"\"\"\n\n    # Override any BaseConfig defaults as needed\n    # Add sliding tile specific parameters\n    window_size: int = 16\n    stride: int = 8\n\n    # You can provide custom defaults for inherited fields\n    height: int = 576\n    width: int = 1024\n\n    # Additional configuration specific to sliding tile attention\n    pad_to_square: bool = False\n    use_overlap_optimization: bool = True\n\n\ndef parse_int_list(value: str) -> list[int]:\n    \"\"\"Parse a comma-separated string of integers into a list.\"\"\"\n    if not value:\n        return []\n    return [int(x.strip()) for x in value.split(\",\")]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/diffusers_generic.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nGeneric pipeline configuration for diffusers backend.\n\nThis module provides a minimal pipeline configuration that works with the diffusers backend.\nSince diffusers handles its own model loading and configuration, this config is intentionally minimal.\n\"\"\"\n\nfrom dataclasses import dataclass, field\nfrom typing import Any\n\nfrom sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    ModelTaskType,\n    PipelineConfig,\n)\n\n\n@dataclass\nclass DiffusersGenericPipelineConfig(PipelineConfig):\n    \"\"\"\n    Generic pipeline configuration for diffusers backend.\n\n    This is a minimal configuration since the diffusers backend handles most\n    configuration internally. It provides sensible defaults for the required fields.\n    \"\"\"\n\n    # default to T2I since it's the most common\n    task_type: ModelTaskType = ModelTaskType.T2I\n\n    dit_precision: str = \"bf16\"\n    vae_precision: str = \"bf16\"\n\n    should_use_guidance: bool = True\n    embedded_cfg_scale: float = 1.0\n    flow_shift: float | None = None\n    disable_autocast: bool = True  # let diffusers handle dtype\n\n    # diffusers handles its own loading\n    dit_config: DiTConfig = field(default_factory=DiTConfig)\n    vae_config: VAEConfig = field(default_factory=VAEConfig)\n    image_encoder_config: EncoderConfig = field(default_factory=EncoderConfig)\n    text_encoder_configs: tuple[EncoderConfig, ...] = field(\n        default_factory=lambda: (EncoderConfig(),)\n    )\n    text_encoder_precisions: tuple[str, ...] 
= field(default_factory=lambda: (\"fp16\",))\n\n    # VAE settings\n    vae_tiling: bool = False  # diffusers handles this\n    vae_slicing: bool = False  # slice VAE decode for lower memory usage\n    vae_sp: bool = False\n\n    # Quantization config for pipeline-level quantization\n    # See: https://huggingface.co/docs/diffusers/main/en/quantization/overview\n    # Use PipelineQuantizationConfig for component-level control:\n    #   from diffusers.quantizers import PipelineQuantizationConfig\n    #   quantization_config = PipelineQuantizationConfig(\n    #       quant_backend=\"bitsandbytes_4bit\",\n    #       quant_kwargs={\"load_in_4bit\": True, \"bnb_4bit_compute_dtype\": torch.bfloat16},\n    #       components_to_quantize=[\"transformer\", \"text_encoder_2\"],\n    #   )\n    quantization_config: Any = None\n\n    def check_pipeline_config(self) -> None:\n        \"\"\"\n        Override to skip most validation since diffusers handles its own config.\n        \"\"\"\n        pass\n\n    def adjust_size(self, width, height, image):\n        \"\"\"\n        Pass through - diffusers handles size adjustments.\n        \"\"\"\n        return width, height\n\n    def adjust_num_frames(self, num_frames):\n        \"\"\"\n        Pass through - diffusers handles frame count.\n        \"\"\"\n        return num_frames\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/flux.py",
    "content": "import math\nfrom dataclasses import dataclass, field\nfrom typing import Callable, List, Optional\n\nimport PIL\nimport torch\nfrom diffusers.image_processor import VaeImageProcessor\n\nfrom sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig\nfrom sglang.multimodal_gen.configs.models.dits.flux import FluxConfig\nfrom sglang.multimodal_gen.configs.models.encoders import (\n    BaseEncoderOutput,\n    CLIPTextConfig,\n    T5Config,\n    TextEncoderConfig,\n)\nfrom sglang.multimodal_gen.configs.models.encoders.base import TextEncoderArchConfig\nfrom sglang.multimodal_gen.configs.models.encoders.qwen3 import Qwen3TextConfig\nfrom sglang.multimodal_gen.configs.models.encoders.qwen_image import (\n    _is_transformer_layer,\n)\nfrom sglang.multimodal_gen.configs.models.vaes.flux import Flux2VAEConfig, FluxVAEConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    ImagePipelineConfig,\n    ModelTaskType,\n    preprocess_text,\n    shard_rotary_emb_for_sp,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.hunyuan import (\n    clip_postprocess_text,\n    clip_preprocess_text,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.qwen_image import _pack_latents\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\n\n\ndef t5_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:\n    return outputs.last_hidden_state\n\n\n@dataclass\nclass FluxPipelineConfig(ImagePipelineConfig):\n    \"\"\"Configuration for the FLUX pipeline.\"\"\"\n\n    embedded_cfg_scale: float = 3.5\n\n    task_type: ModelTaskType = ModelTaskType.T2I\n\n    vae_tiling: bool = False\n\n    vae_sp: bool = False\n\n    dit_config: DiTConfig = field(default_factory=FluxConfig)\n    # VAE\n    vae_config: VAEConfig = field(default_factory=FluxVAEConfig)\n\n    enable_autocast: bool = False\n\n    # Text encoding stage\n    text_encoder_configs: tuple[EncoderConfig, ...] 
= field(\n        default_factory=lambda: (CLIPTextConfig(), T5Config())\n    )\n\n    text_encoder_precisions: tuple[str, ...] = field(\n        default_factory=lambda: (\"bf16\", \"bf16\")\n    )\n\n    preprocess_text_funcs: tuple[Callable[[str], str], ...] = field(\n        default_factory=lambda: (clip_preprocess_text, preprocess_text),\n    )\n\n    postprocess_text_funcs: tuple[Callable[[str], str], ...] = field(\n        default_factory=lambda: (clip_postprocess_text, t5_postprocess_text)\n    )\n\n    text_encoder_extra_args: list[dict] = field(\n        default_factory=lambda: [\n            dict(\n                max_length=77,\n                padding=\"max_length\",\n                truncation=True,\n                return_overflowing_tokens=False,\n                return_length=False,\n            ),\n            None,\n        ]\n    )\n\n    def prepare_sigmas(self, sigmas, num_inference_steps):\n        return self._prepare_sigmas(sigmas, num_inference_steps)\n\n    def prepare_latent_shape(self, batch, batch_size, num_frames):\n        height = 2 * (\n            batch.height // (self.vae_config.arch_config.vae_scale_factor * 2)\n        )\n        width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2))\n        num_channels_latents = self.dit_config.arch_config.in_channels // 4\n        shape = (batch_size, num_channels_latents, height, width)\n        return shape\n\n    def maybe_pack_latents(self, latents, batch_size, batch):\n        height = 2 * (\n            batch.height // (self.vae_config.arch_config.vae_scale_factor * 2)\n        )\n        width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2))\n        num_channels_latents = self.dit_config.arch_config.in_channels // 4\n        # pack latents\n        return _pack_latents(latents, batch_size, num_channels_latents, height, width)\n\n    def get_pos_prompt_embeds(self, batch):\n        return batch.prompt_embeds[1]\n\n    def 
get_neg_prompt_embeds(self, batch):\n        return batch.negative_prompt_embeds[1]\n\n    def _prepare_latent_image_ids(self, original_height, original_width, device):\n        vae_scale_factor = self.vae_config.arch_config.vae_scale_factor\n        height = int(original_height) // (vae_scale_factor * 2)\n        width = int(original_width) // (vae_scale_factor * 2)\n        latent_image_ids = torch.zeros(height, width, 3, device=device)\n        latent_image_ids[..., 1] = (\n            latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]\n        )\n        latent_image_ids[..., 2] = (\n            latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]\n        )\n\n        latent_image_id_height, latent_image_id_width, latent_image_id_channels = (\n            latent_image_ids.shape\n        )\n\n        latent_image_ids = latent_image_ids.reshape(\n            latent_image_id_height * latent_image_id_width, latent_image_id_channels\n        )\n\n        return latent_image_ids\n\n    def get_freqs_cis(self, prompt_embeds, width, height, device, rotary_emb, batch):\n        txt_ids = torch.zeros(prompt_embeds.shape[1], 3, device=device)\n        img_ids = self._prepare_latent_image_ids(\n            original_height=height,\n            original_width=width,\n            device=device,\n        )\n\n        # NOTE(mick): prepare it here, to avoid unnecessary computations\n        img_cos, img_sin = rotary_emb.forward(img_ids)\n        img_cos = shard_rotary_emb_for_sp(img_cos)\n        img_sin = shard_rotary_emb_for_sp(img_sin)\n\n        txt_cos, txt_sin = rotary_emb.forward(txt_ids)\n\n        cos = torch.cat([txt_cos, img_cos], dim=0).to(device=device)\n        sin = torch.cat([txt_sin, img_sin], dim=0).to(device=device)\n        return cos, sin\n\n    def post_denoising_loop(self, latents, batch):\n        # unpack latents for flux\n        (\n            latents,\n            batch_size,\n            channels,\n            
height,\n            width,\n        ) = self._unpad_and_unpack_latents(latents, batch)\n        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)\n        return latents\n\n    def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return {\n            \"freqs_cis\": self.get_freqs_cis(\n                batch.prompt_embeds[1],\n                batch.width,\n                batch.height,\n                device,\n                rotary_emb,\n                batch,\n            ),\n            \"pooled_projections\": (\n                batch.pooled_embeds[0] if batch.pooled_embeds else None\n            ),\n        }\n\n    def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return {\n            \"freqs_cis\": self.get_freqs_cis(\n                batch.negative_prompt_embeds[1],\n                batch.width,\n                batch.height,\n                device,\n                rotary_emb,\n                batch,\n            ),\n            \"pooled_projections\": (\n                batch.neg_pooled_embeds[0] if batch.neg_pooled_embeds else None\n            ),\n        }\n\n\ndef _prepare_latent_ids(\n    latents: torch.Tensor,  # (B, C, H, W)\n):\n    r\"\"\"\n    Generates 4D position coordinates (T, H, W, L) for latent tensors.\n\n    Args:\n        latents (torch.Tensor):\n            Latent tensor of shape (B, C, H, W)\n\n    Returns:\n        torch.Tensor:\n            Position IDs tensor of shape (B, H*W, 4). All batches share the same coordinate structure: T=0,\n            H=[0..H-1], W=[0..W-1], L=0\n    \"\"\"\n\n    batch_size, _, height, width = latents.shape\n\n    t = torch.arange(1)  # [0] - time dimension\n    h = torch.arange(height)\n    w = torch.arange(width)\n    layer = torch.arange(1)  # [0] - layer dimension\n\n    # Create position IDs: (H*W, 4)\n    latent_ids = torch.cartesian_prod(t, h, w, layer)\n\n    # Expand to batch: (B, H*W, 4)\n    latent_ids =
latent_ids.unsqueeze(0).expand(batch_size, -1, -1)\n    return latent_ids\n\n\ndef _unpack_latents_with_ids(\n    x: torch.Tensor, x_ids: torch.Tensor\n) -> torch.Tensor:\n    \"\"\"\n    Use position ids to scatter tokens into place; returns a (B, C, H, W) tensor.\n    \"\"\"\n    x_list = []\n    x_ids = x_ids.to(device=x.device)\n    for data, pos in zip(x, x_ids):\n        _, ch = data.shape  # noqa: F841\n        h_ids = pos[:, 1].to(torch.int64)\n        w_ids = pos[:, 2].to(torch.int64)\n\n        h = torch.max(h_ids) + 1\n        w = torch.max(w_ids) + 1\n\n        flat_ids = h_ids * w + w_ids\n\n        out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)\n        out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)\n\n        # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)\n\n        out = out.view(h, w, ch).permute(2, 0, 1)\n        x_list.append(out)\n\n    return torch.stack(x_list, dim=0)\n\n\ndef _patchify_latents(latents):\n    batch_size, num_channels_latents, height, width = latents.shape\n    latents = latents.view(\n        batch_size, num_channels_latents, height // 2, 2, width // 2, 2\n    )\n    latents = latents.permute(0, 1, 3, 5, 2, 4)\n    latents = latents.reshape(\n        batch_size, num_channels_latents * 4, height // 2, width // 2\n    )\n    return latents\n\n\ndef _unpatchify_latents(latents):\n    batch_size, num_channels_latents, height, width = latents.shape\n    latents = latents.reshape(\n        batch_size, num_channels_latents // (2 * 2), 2, 2, height, width\n    )\n    latents = latents.permute(0, 1, 4, 2, 5, 3)\n    latents = latents.reshape(\n        batch_size, num_channels_latents // (2 * 2), height * 2, width * 2\n    )\n    return latents\n\n\ndef _prepare_text_ids(\n    x: torch.Tensor,  # (B, L, D)\n    t_coord: Optional[torch.Tensor] = None,\n):\n    B, L, _ = x.shape\n    out_ids = []\n\n    for i in range(B):\n        t = torch.arange(1) if t_coord is None else t_coord[i]\n        h =
torch.arange(1)\n        w = torch.arange(1)\n        layer = torch.arange(L)\n\n        coords = torch.cartesian_prod(t, h, w, layer)\n        out_ids.append(coords)\n\n    return torch.stack(out_ids)\n\n\ndef _prepare_image_ids(\n    image_latents: List[torch.Tensor],  # [(1, C, H, W), (1, C, H, W), ...]\n    scale: int = 10,\n):\n    if not isinstance(image_latents, list):\n        raise ValueError(\n            f\"Expected `image_latents` to be a list, got {type(image_latents)}.\"\n        )\n\n    # create time offset for each reference image\n    t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]\n    t_coords = [t.view(-1) for t in t_coords]\n\n    image_latent_ids = []\n    for x, t in zip(image_latents, t_coords):\n        x = x.squeeze(0)\n        _, height, width = x.shape\n\n        x_ids = torch.cartesian_prod(\n            t, torch.arange(height), torch.arange(width), torch.arange(1)\n        )\n        image_latent_ids.append(x_ids)\n\n    image_latent_ids = torch.cat(image_latent_ids, dim=0)\n    image_latent_ids = image_latent_ids.unsqueeze(0)\n\n    return image_latent_ids\n\n\ndef flux2_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:\n    hidden_states_layers: list[int] = [10, 20, 30]\n\n    out = torch.stack([outputs.hidden_states[k] for k in hidden_states_layers], dim=1)\n    batch_size, num_channels, seq_len, hidden_dim = out.shape\n    prompt_embeds = out.permute(0, 2, 1, 3).reshape(\n        batch_size, seq_len, num_channels * hidden_dim\n    )\n\n    return prompt_embeds\n\n\ndef flux2_klein_postprocess_text(\n    outputs: BaseEncoderOutput, _text_inputs\n) -> torch.Tensor:\n    hidden_states_layers: list[int] = [9, 18, 27]\n\n    out = torch.stack([outputs.hidden_states[k] for k in hidden_states_layers], dim=1)\n    batch_size, num_channels, seq_len, hidden_dim = out.shape\n    prompt_embeds = out.permute(0, 2, 1, 3).reshape(\n        batch_size, seq_len, num_channels * hidden_dim\n    
)\n\n    return prompt_embeds\n\n\n@dataclass\nclass Flux2MistralTextArchConfig(TextEncoderArchConfig):\n    stacked_params_mapping: list[tuple[str, str, str]] = field(\n        default_factory=lambda: [\n            # (param_name, shard_name, shard_id)\n            (\"qkv_proj\", \"q_proj\", \"q\"),\n            (\"qkv_proj\", \"k_proj\", \"k\"),\n            (\"qkv_proj\", \"v_proj\", \"v\"),\n        ]\n    )\n    _fsdp_shard_conditions: list = field(\n        default_factory=lambda: [_is_transformer_layer]\n    )\n\n    def __post_init__(self):\n        self.tokenizer_kwargs = {\n            \"padding\": \"max_length\",\n            \"truncation\": True,\n            \"max_length\": 512,\n            \"add_special_tokens\": True,\n            \"return_attention_mask\": True,\n            \"return_tensors\": \"pt\",\n        }\n\n\n@dataclass\nclass Flux2MistralTextConfig(TextEncoderConfig):\n    arch_config: TextEncoderArchConfig = field(\n        default_factory=Flux2MistralTextArchConfig\n    )\n\n\ndef format_text_input(prompts: List[str], system_message: Optional[str] = None):\n    # Remove [IMG] tokens from prompts to avoid Pixtral validation issues\n    # when truncation is enabled. The processor counts [IMG] tokens and fails\n    # if the count changes after truncation.\n    cleaned_txt = [prompt.replace(\"[IMG]\", \"\") for prompt in prompts]\n\n    return [\n        [\n            {\n                \"role\": \"system\",\n                \"content\": [{\"type\": \"text\", \"text\": system_message}],\n            },\n            {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": prompt}]},\n        ]\n        for prompt in cleaned_txt\n    ]\n\n\ndef flux_2_preprocess_text(prompt: str):\n    system_message = \"You are an AI that reasons about image descriptions.
You give structured responses focusing on object relationships, object attribution and actions without speculation.\"\n    return format_text_input([prompt], system_message=system_message)\n\n\n# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents\ndef flux2_pack_latents(latents):\n    batch_size, num_channels, height, width = latents.shape\n    latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)\n\n    return latents\n\n\n@dataclass\nclass Flux2PipelineConfig(FluxPipelineConfig):\n    embedded_cfg_scale: float = 4.0\n\n    task_type: ModelTaskType = ModelTaskType.TI2I\n\n    text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: (\"bf16\",))\n\n    text_encoder_configs: tuple[EncoderConfig, ...] = field(\n        default_factory=lambda: (Flux2MistralTextConfig(),)\n    )\n    preprocess_text_funcs: tuple[Callable[[str], str], ...] = field(\n        default_factory=lambda: (flux_2_preprocess_text,),\n    )\n\n    postprocess_text_funcs: tuple[Callable[[str], str], ...] 
= field(\n        default_factory=lambda: (flux2_postprocess_text,)\n    )\n    vae_config: VAEConfig = field(default_factory=Flux2VAEConfig)\n\n    def tokenize_prompt(self, prompts: list[str], tokenizer, tok_kwargs) -> dict:\n        # flatten to 1-d list\n        prompts = [p for prompt in prompts for p in prompt]\n        inputs = tokenizer.apply_chat_template(\n            prompts,\n            add_generation_prompt=False,\n            tokenize=True,\n            return_dict=True,\n            return_tensors=\"pt\",\n            padding=\"max_length\",\n            truncation=True,\n            # 2048 from official github repo, 512 from diffusers\n            max_length=512,\n        )\n\n        return inputs\n\n    def prepare_latent_shape(self, batch, batch_size, num_frames):\n        height = 2 * (\n            batch.height // (self.vae_config.arch_config.vae_scale_factor * 2)\n        )\n        width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2))\n        num_channels_latents = self.dit_config.arch_config.in_channels\n        shape = (batch_size, num_channels_latents, height // 2, width // 2)\n        return shape\n\n    def get_pos_prompt_embeds(self, batch):\n        return batch.prompt_embeds[0]\n\n    def get_neg_prompt_embeds(self, batch):\n        return batch.negative_prompt_embeds[0]\n\n    def calculate_condition_image_size(\n        self, image, width, height\n    ) -> Optional[tuple[int, int]]:\n        vae_scale_factor = self.vae_config.arch_config.vae_scale_factor\n        multiple_of = vae_scale_factor * 2\n\n        target_area: int = 1024 * 1024\n        if width is not None and height is not None:\n            new_width, new_height = width, height\n            if width * height > target_area:\n                scale = math.sqrt(target_area / (width * height))\n                new_width = int(width * scale)\n                new_height = int(height * scale)\n\n            # Flux requires multiples of (VAE scale 8 
* Patch size 2)\n            new_width = (new_width // multiple_of) * multiple_of\n            new_height = (new_height // multiple_of) * multiple_of\n\n            if new_width != width or new_height != height:\n                return new_width, new_height\n\n        return None\n\n    def preprocess_condition_image(\n        self, image, target_width, target_height, vae_image_processor: VaeImageProcessor\n    ):\n        img = image.resize((target_width, target_height), PIL.Image.Resampling.LANCZOS)\n        image_width, image_height = img.size\n        vae_scale_factor = self.vae_config.arch_config.vae_scale_factor\n        multiple_of = vae_scale_factor * 2\n        image_width = (image_width // multiple_of) * multiple_of\n        image_height = (image_height // multiple_of) * multiple_of\n        img = vae_image_processor.preprocess(\n            img, height=image_height, width=image_width, resize_mode=\"crop\"\n        )\n        return img, (image_width, image_height)\n\n    def postprocess_image_latent(self, latent_condition, batch):\n        batch_size = batch.batch_size\n        # latent: (1, 128, 32, 32)\n        packed = self.maybe_pack_latents(\n            latent_condition, None, batch\n        )  # (1, 1024, 128)\n        packed = packed.squeeze(0)  # (1024, 128) - remove batch dim\n\n        # Concatenate all reference tokens along sequence dimension\n        image_latents = packed.unsqueeze(0)  # (1, N*1024, 128)\n        image_latents = image_latents.repeat(batch_size, 1, 1)\n        return image_latents\n\n    def prepare_condition_image_latent_ids(self, image_latents, batch):\n        image_latent_ids = _prepare_image_ids(image_latents)\n        image_latent_ids = image_latent_ids.repeat(batch.batch_size, 1, 1)\n        batch.condition_image_latent_ids = image_latent_ids.to(get_local_torch_device())\n\n    def get_freqs_cis(self, prompt_embeds, width, height, device, rotary_emb, batch):\n        txt_ids = 
_prepare_text_ids(prompt_embeds).to(device=device)\n\n        img_ids = batch.latent_ids\n        if img_ids.ndim == 3:\n            img_ids = img_ids[0]\n        if txt_ids.ndim == 3:\n            txt_ids = txt_ids[0]\n\n        # NOTE(mick): prepare it here, to avoid unnecessary computations\n        img_cos, img_sin = rotary_emb.forward(img_ids)\n        img_cos = shard_rotary_emb_for_sp(img_cos)\n        img_sin = shard_rotary_emb_for_sp(img_sin)\n\n        if batch.image_latent is not None:\n            cond_ids = batch.condition_image_latent_ids\n            if cond_ids.ndim == 3:\n                cond_ids = cond_ids[0]\n            cond_cos, cond_sin = rotary_emb.forward(cond_ids)\n            cond_cos = shard_rotary_emb_for_sp(cond_cos)\n            cond_sin = shard_rotary_emb_for_sp(cond_sin)\n            img_cos = torch.cat([img_cos, cond_cos], dim=0)\n            img_sin = torch.cat([img_sin, cond_sin], dim=0)\n\n        txt_cos, txt_sin = rotary_emb.forward(txt_ids)\n\n        cos = torch.cat([txt_cos, img_cos], dim=0).to(device=device)\n        sin = torch.cat([txt_sin, img_sin], dim=0).to(device=device)\n        return cos, sin\n\n    def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return {\n            \"freqs_cis\": self.get_freqs_cis(\n                batch.prompt_embeds[0],\n                batch.width,\n                batch.height,\n                device,\n                rotary_emb,\n                batch,\n            )\n        }\n\n    def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return {}\n\n    def maybe_pack_latents(self, latents, batch_size, batch):\n        return flux2_pack_latents(latents)\n\n    def maybe_prepare_latent_ids(self, latents):\n        return _prepare_latent_ids(latents)\n\n    def postprocess_vae_encode(self, image_latents, vae):\n        # patchify\n        image_latents = _patchify_latents(image_latents)\n        return image_latents\n\n    def 
_check_vae_has_bn(self, vae):\n        \"\"\"Check if VAE has bn attribute (cached check to avoid repeated hasattr calls).\"\"\"\n        if not hasattr(self, \"_vae_has_bn_cache\"):\n            self._vae_has_bn_cache = hasattr(vae, \"bn\") and vae.bn is not None\n        return self._vae_has_bn_cache\n\n    def preprocess_decoding(self, latents, server_args=None, vae=None):\n        \"\"\"Preprocess latents before decoding.\n\n        Dynamically adapts based on VAE type:\n        - Standard Flux2 VAE (has bn): needs unpatchify (128 channels -> 32 channels)\n        - Distilled VAE (no bn): keeps patchified latents (128 channels)\n        \"\"\"\n        if vae is not None and self._check_vae_has_bn(vae):\n            return _unpatchify_latents(latents)\n        return latents\n\n    def get_decode_scale_and_shift(self, device, dtype, vae):\n        \"\"\"Get scale and shift for decoding.\n\n        Dynamically adapts based on VAE type:\n        - Standard Flux2 VAE (has bn): uses BatchNorm statistics\n        - Distilled VAE (no bn): uses scaling_factor from config\n        \"\"\"\n        vae_arch_config = self.vae_config.arch_config\n\n        if self._check_vae_has_bn(vae):\n            # Standard Flux2 VAE: use BatchNorm statistics\n            latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(device, dtype)\n            latents_bn_std = torch.sqrt(\n                vae.bn.running_var.view(1, -1, 1, 1) + vae_arch_config.batch_norm_eps\n            ).to(device, dtype)\n            return 1 / latents_bn_std, latents_bn_mean\n\n        # Distilled VAE or unknown: use scaling_factor\n        scaling_factor = (\n            getattr(vae.config, \"scaling_factor\", None)\n            if hasattr(vae, \"config\")\n            else getattr(vae, \"scaling_factor\", None)\n        ) or getattr(vae_arch_config, \"scaling_factor\", 0.13025)\n\n        scale = torch.tensor(scaling_factor, device=device, dtype=dtype).view(\n            1, 1, 1, 1\n        )\n       
 return 1 / scale, None\n\n    def post_denoising_loop(self, latents, batch):\n        latent_ids = batch.latent_ids\n        latents = _unpack_latents_with_ids(latents, latent_ids)\n\n        return latents\n\n    def slice_noise_pred(self, noise, latents):\n        # remove noise over input image\n        noise = noise[:, : latents.size(1) :]\n        return noise\n\n\n@dataclass\nclass Flux2KleinPipelineConfig(Flux2PipelineConfig):\n    # Klein is distilled, so no guidance embeddings\n    should_use_guidance: bool = False\n\n    text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: (\"bf16\",))\n\n    text_encoder_configs: tuple[EncoderConfig, ...] = field(\n        default_factory=lambda: (Qwen3TextConfig(),)\n    )\n\n    preprocess_text_funcs: tuple[Callable[[str], str], ...] = field(\n        default_factory=lambda: (preprocess_text,),\n    )\n\n    postprocess_text_funcs: tuple[Callable[[str], str], ...] = field(\n        default_factory=lambda: (flux2_klein_postprocess_text,)\n    )\n\n    def tokenize_prompt(self, prompts: list[str], tokenizer, tok_kwargs) -> dict:\n        if prompts and isinstance(prompts[0], list):\n            prompts = [p for prompt in prompts for p in prompt]\n\n        def _apply_chat_template(prompt: str) -> str:\n            messages = [{\"role\": \"user\", \"content\": prompt}]\n            try:\n                return tokenizer.apply_chat_template(\n                    messages,\n                    tokenize=False,\n                    add_generation_prompt=True,\n                    enable_thinking=False,\n                )\n            except TypeError:\n                return tokenizer.apply_chat_template(\n                    messages,\n                    tokenize=False,\n                    add_generation_prompt=True,\n                )\n\n        texts = [_apply_chat_template(prompt) for prompt in prompts]\n\n        tok_kwargs = dict(tok_kwargs or {})\n        max_length = 
tok_kwargs.pop(\"max_length\", 512)\n        padding = tok_kwargs.pop(\"padding\", \"max_length\")\n        truncation = tok_kwargs.pop(\"truncation\", True)\n        return_tensors = tok_kwargs.pop(\"return_tensors\", \"pt\")\n\n        return tokenizer(\n            texts,\n            padding=padding,\n            truncation=truncation,\n            max_length=max_length,\n            return_tensors=return_tensors,\n            **tok_kwargs,\n        )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/flux_finetuned.py",
    "content": "\"\"\"\nPipeline configuration for Flux fine-tuned/distilled models.\n\nThis module provides specialized handling for Flux fine-tuned models from HuggingFace,\nsuch as fal/FLUX.2-Tiny-AutoEncoder and other community fine-tuned variants.\n\nKey differences from standard Flux2PipelineConfig:\n- Handles custom VAE architectures loaded via auto_map\n- Supports both patchified (128 channels) and unpatchified (32 channels) latents\n- Dynamically adapts scale/shift based on VAE type\n- Properly handles 5D latents (batch, channels, frames, height, width) for decoding\n\"\"\"\n\nfrom dataclasses import dataclass\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.pipeline_configs.flux import (\n    Flux2PipelineConfig,\n    _unpatchify_latents,\n)\n\n\n@dataclass\nclass Flux2FinetunedPipelineConfig(Flux2PipelineConfig):\n    \"\"\"\n    Pipeline configuration for Flux fine-tuned/distilled models.\n\n    This configuration automatically detects and handles custom VAE architectures\n    (e.g., Flux2TinyAutoEncoder) loaded via HuggingFace's auto_map mechanism.\n\n    Features:\n    - Automatic VAE type detection (standard vs. 
distilled)\n    - Proper handling of patchified/unpatchified latents\n    - Support for custom scaling factors from fine-tuned models\n    - 5D latents support for both single-frame and multi-frame generation\n    \"\"\"\n\n    def preprocess_decoding(\n        self, latents: torch.Tensor, server_args=None, vae=None\n    ) -> torch.Tensor:\n        \"\"\"\n        Preprocess latents before decoding.\n\n        Handles both standard Flux2 VAE and fine-tuned/distilled VAEs:\n        - Standard Flux2 VAE (has bn): needs unpatchify (128 channels -> 32 channels)\n        - Distilled/Finetuned VAE (no bn): keeps patchified latents (128 channels)\n\n        Also handles 5D latents (batch, channels, frames, height, width) by converting\n        them to 4D (batch, channels, height, width): the frame dimension is squeezed\n        for single-frame inputs and folded into the batch dimension otherwise.\n\n        Args:\n            latents: Input latents tensor, can be 4D or 5D\n            server_args: Server arguments (optional, for compatibility)\n            vae: VAE model instance for dynamic type detection\n\n        Returns:\n            Preprocessed latents ready for VAE decoding\n        \"\"\"\n        # Handle 5D latents (batch, channels, frames, height, width)\n        if latents.ndim == 5:\n            batch_size, channels, frames, height, width = latents.shape\n            if frames == 1:\n                latents = latents.squeeze(2)\n            else:\n                latents = latents.permute(0, 2, 1, 3, 4).contiguous()\n                latents = latents.view(batch_size * frames, channels, height, width)\n\n        if vae is not None and self._check_vae_has_bn(vae):\n            latents = _unpatchify_latents(latents)\n        return latents\n\n    def get_decode_scale_and_shift(self, device, dtype, vae):\n        \"\"\"\n        Get scale and shift for decoding.\n\n        Dynamically adapts based on VAE type:\n        - Standard Flux2 VAE (has bn): uses BatchNorm statistics\n        - Distilled/Finetuned VAE (no bn): uses an identity scale (1.0) and no shift\n\n 
       Args:\n            device: Target device for tensors\n            dtype: Target dtype for tensors\n            vae: VAE model instance\n\n        Returns:\n            Tuple of (scaling_factor, shift_factor)\n            - scaling_factor: Tensor or scalar to divide latents by\n            - shift_factor: Tensor or scalar to add to latents (None for distilled VAEs)\n        \"\"\"\n        vae_arch_config = self.vae_config.arch_config\n\n        if self._check_vae_has_bn(vae):\n            # Standard Flux2 VAE: use BatchNorm statistics\n            latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(device, dtype)\n            latents_bn_std = torch.sqrt(\n                vae.bn.running_var.view(1, -1, 1, 1) + vae_arch_config.batch_norm_eps\n            ).to(device, dtype)\n            return 1 / latents_bn_std, latents_bn_mean\n\n        # Distilled/Finetuned VAE: Flux2TinyAutoEncoder doesn't need external scaling\n        scale = torch.tensor(1.0, device=device, dtype=dtype).view(1, 1, 1, 1)\n        return scale, None\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/glm_image.py",
    "content": "from dataclasses import dataclass, field\n\nimport torch\nfrom diffusers.image_processor import VaeImageProcessor\n\nfrom sglang.multimodal_gen.configs.models import DiTConfig, VAEConfig\nfrom sglang.multimodal_gen.configs.models.dits.glmimage import GlmImageDitConfig\nfrom sglang.multimodal_gen.configs.models.encoders.base import EncoderConfig\nfrom sglang.multimodal_gen.configs.models.encoders.t5 import T5Config\nfrom sglang.multimodal_gen.configs.models.vaes.glmimage import GlmImageVAEConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    ModelTaskType,\n    SpatialImagePipelineConfig,\n)\n\n\n@dataclass\nclass GlmImagePipelineConfig(SpatialImagePipelineConfig):\n    \"\"\"Configuration for the GlmImage pipeline.\"\"\"\n\n    vae_precision: str = \"bf16\"\n\n    should_use_guidance: bool = False\n    task_type: ModelTaskType = ModelTaskType.T2I\n\n    vae_tiling: bool = False\n\n    vae_sp: bool = False\n\n    dit_config: DiTConfig = field(default_factory=GlmImageDitConfig)\n    # VAE\n    vae_config: VAEConfig = field(default_factory=GlmImageVAEConfig)\n\n    # GLM-Image uses T5 text encoder; base default is EncoderConfig() which lacks\n    # parallel_folding and causes AttributeError + fallback to native T5 with missing weights.\n    text_encoder_configs: tuple[EncoderConfig, ...] 
= field(\n        default_factory=lambda: (T5Config(),)\n    )\n\n    enable_autocast: bool = False\n\n    def __post_init__(self):\n        self.vae_scale_factor = self.vae_config.get_vae_scale_factor()\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n    def get_freqs_cis(self, batch, device, rotary_emb, dtype):\n        height = batch.height // self.vae_scale_factor\n        width = batch.width // self.vae_scale_factor\n        hidden_states = torch.empty(1, 1, height, width, device=device, dtype=dtype)\n        freqs_cis = rotary_emb(hidden_states)\n        return freqs_cis\n\n    def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return {\n            \"prior_token_id\": batch.prior_token_id,\n            \"prior_token_drop\": batch.prior_token_drop_cond,\n            \"crop_coords\": batch.crop_coords,\n            \"target_size\": batch.target_size,\n            \"kv_caches\": batch.kv_caches,\n            \"kv_caches_mode\": \"read\",\n            \"freqs_cis\": self.get_freqs_cis(batch, device, rotary_emb, dtype),\n        }\n\n    def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return {\n            \"prior_token_id\": batch.prior_token_id,\n            \"prior_token_drop\": batch.prior_token_drop_uncond,\n            \"crop_coords\": batch.crop_coords,\n            \"target_size\": batch.target_size,\n            \"kv_caches\": batch.kv_caches,\n            \"kv_caches_mode\": \"skip\",\n            \"freqs_cis\": self.get_freqs_cis(batch, device, rotary_emb, dtype),\n        }\n\n    def get_decode_scale_and_shift(self, device, dtype, vae):\n        latents_mean = (\n            torch.tensor(self.vae_config.latents_mean)\n            .view(1, self.vae_config.latent_channels, 1, 1)\n            .to(device, dtype)\n        )\n        latents_std = (\n            torch.tensor(self.vae_config.latents_std)\n            .view(1, 
self.vae_config.latent_channels, 1, 1)\n            .to(device, dtype)\n        )\n        return 1.0 / latents_std, latents_mean\n\n    def post_denoising_loop(self, latents, batch):\n        if getattr(batch, \"kv_caches\", None) is not None:\n            batch.kv_caches.clear()\n        return latents.bfloat16()\n\n    def post_decoding(self, frames, server_args):\n        return self.image_processor.postprocess(frames, output_type=\"latent\")\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/helios.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom collections.abc import Callable\nfrom dataclasses import dataclass, field\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig\nfrom sglang.multimodal_gen.configs.models.dits.helios import HeliosConfig\nfrom sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput, T5Config\nfrom sglang.multimodal_gen.configs.models.encoders.t5 import T5ArchConfig\nfrom sglang.multimodal_gen.configs.models.vaes import WanVAEConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    ModelTaskType,\n    PipelineConfig,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n# Helios UMT5 max sequence length (used for both tokenizer and post-processing padding)\n# Matches diffusers HeliosPipeline.__call__ default max_sequence_length=512\nHELIOS_MAX_SEQUENCE_LENGTH = 512\n\n\ndef umt5_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:\n    \"\"\"Post-process UMT5 text encoder outputs, padding to HELIOS_MAX_SEQUENCE_LENGTH tokens.\"\"\"\n    max_seq_len = HELIOS_MAX_SEQUENCE_LENGTH\n    mask: torch.Tensor = outputs.attention_mask\n    hidden_state: torch.Tensor = outputs.last_hidden_state\n    seq_lens = mask.gt(0).sum(dim=1).long()\n    assert torch.isnan(hidden_state).sum() == 0\n    prompt_embeds = [u[:v] for u, v in zip(hidden_state, seq_lens, strict=True)]\n    prompt_embeds_tensor: torch.Tensor = torch.stack(\n        [\n            torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))])\n            for u in prompt_embeds\n        ],\n        dim=0,\n    )\n    return prompt_embeds_tensor\n\n\n@dataclass\nclass HeliosT2VConfig(PipelineConfig):\n    \"\"\"Configuration for the Helios T2V pipeline.\"\"\"\n\n    task_type: ModelTaskType = ModelTaskType.T2V\n\n    # DiT\n    dit_config: DiTConfig = field(default_factory=HeliosConfig)\n\n    # VAE 
(same as Wan)\n    vae_config: VAEConfig = field(default_factory=WanVAEConfig)\n    vae_tiling: bool = False\n    vae_sp: bool = False\n\n    # Denoising stage\n    flow_shift: float | None = 1.0\n\n    # Text encoding stage (UMT5 is T5-compatible)\n    text_encoder_configs: tuple[EncoderConfig, ...] = field(\n        default_factory=lambda: (\n            T5Config(arch_config=T5ArchConfig(text_len=HELIOS_MAX_SEQUENCE_LENGTH)),\n        )\n    )\n    postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor], ...] = (\n        field(default_factory=lambda: (umt5_postprocess_text,))\n    )\n\n    # Precision for each component\n    precision: str = \"bf16\"\n    vae_precision: str = \"fp32\"\n    text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: (\"fp32\",))\n\n    # Helios-specific chunked denoising params\n    num_latent_frames_per_chunk: int = 9\n    history_sizes: list[int] = field(default_factory=lambda: [16, 2, 1])\n    is_cfg_zero_star: bool = False\n    zero_steps: int = 1\n    keep_first_frame: bool = True\n\n    # Stage 2 (Pyramid SR) & Stage 3 (DMD) params\n    is_enable_stage2: bool = False\n    pyramid_num_stages: int = 3\n    pyramid_num_inference_steps_list: list[int] = field(\n        default_factory=lambda: [10, 10, 10]\n    )\n    is_distilled: bool = False\n    is_amplify_first_chunk: bool = False\n    scheduler_type: str = \"unipc\"\n    gamma: float = 1 / 3\n\n    def __post_init__(self):\n        self.vae_config.load_encoder = False\n        self.vae_config.load_decoder = True\n\n\n@dataclass\nclass HeliosMidConfig(HeliosT2VConfig):\n    \"\"\"Configuration for Helios-Mid (Stage 1 + Stage 2 pyramid SR).\"\"\"\n\n    is_enable_stage2: bool = True\n    is_cfg_zero_star: bool = True\n    pyramid_num_inference_steps_list: list[int] = field(\n        default_factory=lambda: [20, 20, 20]\n    )\n\n\n@dataclass\nclass HeliosDistilledConfig(HeliosT2VConfig):\n    \"\"\"Configuration for Helios-Distilled (Stage 1 + 
Stage 2 + Stage 3 DMD).\"\"\"\n\n    is_enable_stage2: bool = True\n    is_distilled: bool = True\n    is_amplify_first_chunk: bool = True\n    scheduler_type: str = \"dmd\"\n    pyramid_num_inference_steps_list: list[int] = field(\n        default_factory=lambda: [10, 10, 10]\n    )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/hunyuan.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom collections.abc import Callable\nfrom dataclasses import dataclass, field\nfrom typing import TypedDict\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig\nfrom sglang.multimodal_gen.configs.models.dits import HunyuanVideoConfig\nfrom sglang.multimodal_gen.configs.models.encoders import (\n    BaseEncoderOutput,\n    CLIPTextConfig,\n    LlamaConfig,\n)\nfrom sglang.multimodal_gen.configs.models.vaes import HunyuanVAEConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    ModelTaskType,\n    PipelineConfig,\n)\n\nPROMPT_TEMPLATE_ENCODE_VIDEO = (\n    \"<|start_header_id|>system<|end_header_id|>\\n\\nDescribe the video by detailing the following aspects: \"\n    \"1. The main content and theme of the video.\"\n    \"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\"\n    \"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\"\n    \"4. background environment, light, style and atmosphere.\"\n    \"5. camera angles, movements, and transitions used in the video:<|eot_id|>\"\n    \"<|start_header_id|>user<|end_header_id|>\\n\\n{}<|eot_id|>\"\n)\n\n\nclass PromptTemplate(TypedDict):\n    template: str\n    crop_start: int\n\n\nprompt_template_video: PromptTemplate = {\n    \"template\": PROMPT_TEMPLATE_ENCODE_VIDEO,\n    \"crop_start\": 95,\n}\n\n\ndef llama_preprocess_text(prompt: str) -> str:\n    return prompt_template_video[\"template\"].format(prompt)\n\n\ndef llama_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:\n    hidden_state_skip_layer = 2\n    assert outputs.hidden_states is not None\n    hidden_states: tuple[torch.Tensor, ...] 
= outputs.hidden_states\n    last_hidden_state: torch.Tensor = hidden_states[-(hidden_state_skip_layer + 1)]\n    crop_start = prompt_template_video.get(\"crop_start\", -1)\n    last_hidden_state = last_hidden_state[:, crop_start:]\n    return last_hidden_state\n\n\ndef clip_preprocess_text(prompt: str) -> str:\n    return prompt\n\n\ndef clip_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:\n    pooler_output: torch.Tensor = outputs.pooler_output\n    return pooler_output\n\n\n@dataclass\nclass HunyuanConfig(PipelineConfig):\n    \"\"\"Base configuration for HunYuan pipeline architecture.\"\"\"\n\n    task_type: ModelTaskType = ModelTaskType.T2V\n\n    # HunyuanConfig-specific parameters with defaults\n    # DiT\n    dit_config: DiTConfig = field(default_factory=HunyuanVideoConfig)\n    # VAE\n    vae_config: VAEConfig = field(default_factory=HunyuanVAEConfig)\n    # Denoising stage\n    embedded_cfg_scale: int = 6\n    flow_shift: int = 7\n\n    # Text encoding stage\n    text_encoder_configs: tuple[EncoderConfig, ...] = field(\n        default_factory=lambda: (LlamaConfig(), CLIPTextConfig())\n    )\n    preprocess_text_funcs: tuple[Callable[[str], str], ...] = field(\n        default_factory=lambda: (llama_preprocess_text, clip_preprocess_text)\n    )\n    postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor], ...] = (\n        field(default_factory=lambda: (llama_postprocess_text, clip_postprocess_text))\n    )\n\n    # Precision for each component\n    dit_precision: str = \"bf16\"\n    vae_precision: str = \"fp16\"\n    text_encoder_precisions: tuple[str, ...] 
= field(\n        default_factory=lambda: (\"fp16\", \"fp16\")\n    )\n\n    def __post_init__(self):\n        self.vae_config.load_encoder = False\n        self.vae_config.load_decoder = True\n\n\n@dataclass\nclass FastHunyuanConfig(HunyuanConfig):\n    \"\"\"Configuration specifically optimized for FastHunyuan weights.\"\"\"\n\n    # Override HunyuanConfig defaults\n    flow_shift: int = 17\n\n    # No need to re-specify guidance_scale or embedded_cfg_scale as they\n    # already have the desired values from HunyuanConfig\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/hunyuan3d.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\nfrom typing import Optional\n\nfrom sglang.multimodal_gen.configs.models import DiTConfig, VAEConfig\nfrom sglang.multimodal_gen.configs.models.dits.hunyuan3d import Hunyuan3DDiTConfig\nfrom sglang.multimodal_gen.configs.models.vaes.hunyuan3d import Hunyuan3DVAEConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    ModelTaskType,\n    PipelineConfig,\n)\n\n\n@dataclass\nclass Hunyuan3D2PipelineConfig(PipelineConfig):\n    \"\"\"Pipeline configuration for Hunyuan3D image-to-mesh generation.\"\"\"\n\n    task_type: ModelTaskType = ModelTaskType.I2M\n\n    # Subfolder paths\n    shape_subfolder: str = \"hunyuan3d-dit-v2-0\"\n    paint_subfolder: str = \"hunyuan3d-paint-v2-0\"\n    delight_subfolder: str = \"hunyuan3d-delight-v2-0\"\n\n    # DiT configuration\n    dit_config: DiTConfig = field(default_factory=Hunyuan3DDiTConfig)\n    dit_precision: str = \"fp16\"\n\n    # VAE configuration\n    vae_config: VAEConfig = field(default_factory=Hunyuan3DVAEConfig)\n    vae_precision: str = \"fp32\"\n\n    # Shape model configuration\n    shape_model_path: Optional[str] = None\n    shape_use_safetensors: bool = True\n    shape_variant: Optional[str] = \"fp16\"\n    shape_num_inference_steps: int = 50\n    guidance_scale: float = 5.0\n    shape_box_v: float = 1.01\n    shape_octree_resolution: int = 384\n    shape_mc_level: float = 0.0\n    shape_mc_algo: Optional[str] = \"mc\"\n    shape_num_chunks: int = 8000\n    shape_output_type: str = \"trimesh\"\n\n    # Delight model configuration\n    delight_enable: bool = True\n    delight_prompt: str = \"\"\n    delight_negative_prompt: str = \"\"\n    delight_strength: float = 1.0\n    delight_num_inference_steps: int = 50\n    delight_guidance_scale: float = 1.0\n    delight_cfg_image: float = 1.5\n\n    # Paint model configuration\n    paint_enable: bool = True\n    paint_num_inference_steps: int = 30\n    
paint_guidance_scale: float = 2.0\n    paint_resolution: int = 512\n    paint_render_size: int = 2048\n    paint_texture_size: int = 2048\n    paint_use_remesh: bool = True\n    paint_save_glb: bool = True\n    paint_turbo_mode: bool = False\n\n    def __post_init__(self):\n        self.vae_config.load_encoder = False\n        self.vae_config.load_decoder = True\n\n    def prepare_latent_shape(self, batch, batch_size, num_frames):\n        latent_shape = self.vae_config.arch_config.latent_shape\n        shape = (batch_size, *latent_shape)\n        return shape\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/ltx_2.py",
    "content": "import dataclasses\nfrom dataclasses import field\nfrom typing import Callable\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models.dits.ltx_2 import LTX2Config\nfrom sglang.multimodal_gen.configs.models.encoders import (\n    BaseEncoderOutput,\n    EncoderConfig,\n)\nfrom sglang.multimodal_gen.configs.models.encoders.gemma_3 import Gemma3Config\nfrom sglang.multimodal_gen.configs.models.vaes.ltx_audio import LTXAudioVAEConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    ModelTaskType,\n    PipelineConfig,\n    preprocess_text,\n)\nfrom sglang.multimodal_gen.runtime.distributed import (\n    get_sp_parallel_rank,\n    get_sp_world_size,\n    sequence_model_parallel_all_gather,\n)\n\n\ndef pack_text_embeds(\n    text_hidden_states: torch.Tensor,\n    sequence_lengths: torch.Tensor,\n    padding_side: str = \"left\",\n    scale_factor: int = 8,\n    eps: float = 1e-6,\n) -> torch.Tensor:\n    \"\"\"\n    Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and\n    per-layer in a masked fashion (only over non-padded positions).\n\n    Args:\n        text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):\n            Per-layer hidden_states from a text encoder (e.g. 
`Gemma3ForConditionalGeneration`).\n        sequence_lengths (`torch.Tensor` of shape `(batch_size,)`):\n            The number of valid (non-padded) tokens for each batch instance.\n        padding_side (`str`, *optional*, defaults to `\"left\"`):\n            Whether the text tokenizer pads on the `\"left\"` or `\"right\"` side.\n        scale_factor (`int`, *optional*, defaults to `8`):\n            Scaling factor to multiply the normalized hidden states by.\n        eps (`float`, *optional*, defaults to `1e-6`):\n            A small positive value for numerical stability when performing normalization.\n\n    Returns:\n        `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:\n            Normalized and flattened text encoder hidden states; padded positions are zeroed.\n    \"\"\"\n    batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape\n    original_dtype = text_hidden_states.dtype\n    device = text_hidden_states.device\n\n    # Create padding mask\n    token_indices = torch.arange(seq_len, device=device).unsqueeze(0)\n    if padding_side == \"right\":\n        mask = token_indices < sequence_lengths[:, None]\n    elif padding_side == \"left\":\n        start_indices = seq_len - sequence_lengths[:, None]\n        mask = token_indices >= start_indices\n    else:\n        raise ValueError(f\"padding_side must be 'left' or 'right', got {padding_side}\")\n    mask = mask[:, :, None, None]  # [batch_size, seq_len, 1, 1]\n\n    masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)\n    num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)\n    masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (\n        num_valid_positions + eps\n    )\n\n    x_min = text_hidden_states.masked_fill(~mask, float(\"inf\")).amin(\n        dim=(1, 2), keepdim=True\n    )\n    x_max = 
text_hidden_states.masked_fill(~mask, float(\"-inf\")).amax(\n        dim=(1, 2), keepdim=True\n    )\n\n    normalized_hidden_states = (text_hidden_states - masked_mean) / (\n        x_max - x_min + eps\n    )\n    normalized_hidden_states = normalized_hidden_states * scale_factor\n\n    normalized_hidden_states = normalized_hidden_states.flatten(2)\n    mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)\n    normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)\n    normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)\n\n    return normalized_hidden_states\n\n\ndef _gemma_postprocess_func(\n    outputs: BaseEncoderOutput, text_inputs: dict\n) -> torch.Tensor:\n    # LTX-2 requires all hidden states concatenated for the connector\n    if hasattr(outputs, \"hidden_states\") and outputs.hidden_states is not None:\n        # outputs.hidden_states is a tuple of tensors\n        # We need to stack them along the last dimension and pack them\n        hidden_states = torch.stack(outputs.hidden_states, dim=-1)\n        attention_mask = text_inputs[\"attention_mask\"]\n        sequence_lengths = attention_mask.sum(dim=-1)\n        # Assuming left padding for Gemma as per Diffusers\n        return pack_text_embeds(hidden_states, sequence_lengths, padding_side=\"left\")\n    else:\n        raise AttributeError(\n            \"Unsupported text encoder output: expected `hidden_states`.\"\n        )\n\n\n@dataclasses.dataclass\nclass LTX2PipelineConfig(PipelineConfig):\n    \"\"\"Configuration for LTX-Video pipeline.\"\"\"\n\n    task_type: ModelTaskType = ModelTaskType.TI2V\n    skip_input_image_preprocess: bool = True\n    dit_config: LTX2Config = field(default_factory=LTX2Config)\n\n    # Model architecture\n    in_channels: int = 128\n    out_channels: int = 128\n    patch_size: int = 1\n    patch_size_t: int = 1\n\n    # Audio VAE configuration\n    audio_vae_config: LTXAudioVAEConfig = 
field(default_factory=LTXAudioVAEConfig)\n    audio_vae_precision: str = \"fp32\"\n    audio_vae_temporal_compression_ratio: int = 4\n    audio_vae_mel_compression_ratio: int = 4\n\n    @property\n    def vae_scale_factor(self):\n        return getattr(self.vae_config.arch_config, \"spatial_compression_ratio\", 32)\n\n    @property\n    def vae_temporal_compression(self):\n        return getattr(self.vae_config.arch_config, \"temporal_compression_ratio\", 8)\n\n    def prepare_latent_shape(self, batch, batch_size, num_frames):\n        \"\"\"Return packed latent shape [B, seq, C] directly.\"\"\"\n        height = batch.height // self.vae_scale_factor\n        width = batch.width // self.vae_scale_factor\n\n        post_patch_num_frames = num_frames // self.patch_size_t\n        post_patch_height = height // self.patch_size\n        post_patch_width = width // self.patch_size\n        seq_len = post_patch_num_frames * post_patch_height * post_patch_width\n\n        num_channels = (\n            self.in_channels * self.patch_size_t * self.patch_size * self.patch_size\n        )\n\n        shape = (batch_size, seq_len, num_channels)\n        return shape\n\n    def prepare_audio_latent_shape(self, batch, batch_size, num_frames):\n        # Adapted from diffusers pipeline prepare_audio_latents\n        duration_s = num_frames / batch.fps\n\n        sample_rate = self.audio_vae_config.arch_config.sample_rate\n        hop_length = self.audio_vae_config.arch_config.mel_hop_length\n        temporal_compression = self.audio_vae_temporal_compression_ratio\n\n        latents_per_second = (\n            float(sample_rate) / float(hop_length) / float(temporal_compression)\n        )\n        latent_length = round(duration_s * latents_per_second)\n\n        num_mel_bins = self.audio_vae_config.arch_config.mel_bins\n        mel_compression_ratio = self.audio_vae_mel_compression_ratio\n        latent_mel_bins = num_mel_bins // mel_compression_ratio\n\n        # Default to 8\n      
  num_channels_latents = self.audio_vae_config.arch_config.latent_channels\n\n        shape = (batch_size, latent_length, num_channels_latents * latent_mel_bins)\n\n        return shape\n\n    # Text encoding stage (Gemma)\n    # LTX-2 needs separate contexts for video/audio streams. We model this as\n    # two logical encoders sharing the same underlying `text_encoder` module.\n    text_encoder_configs: tuple[EncoderConfig, ...] = field(\n        default_factory=lambda: (Gemma3Config(),)\n    )\n    text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: (\"bf16\",))\n    text_encoder_extra_args: list[dict] = field(default_factory=lambda: [{}])\n\n    preprocess_text_funcs: tuple[Callable[[str], str], ...] = field(\n        default_factory=lambda: (preprocess_text,)\n    )\n    postprocess_text_funcs: tuple[\n        Callable[[BaseEncoderOutput, dict], torch.Tensor], ...\n    ] = field(default_factory=lambda: (_gemma_postprocess_func,))\n\n    def prepare_sigmas(self, sigmas, num_inference_steps):\n        if sigmas is None:\n            steps = int(num_inference_steps)\n            if steps <= 0:\n                raise ValueError(f\"num_inference_steps must be positive, got {steps}\")\n            return [1.0 - i / steps for i in range(steps)]\n        return sigmas\n\n    def tokenize_prompt(self, prompt: list[str], tokenizer, tok_kwargs) -> dict:\n        # Adapted from diffusers_pipeline.py _get_gemma_prompt_embeds\n        # But we only need tokenization here, the embedding happens in TextEncodingStage\n\n        # Gemma expects left padding for chat-style prompts\n        tokenizer.padding_side = \"left\"\n        if tokenizer.pad_token is None:\n            tokenizer.pad_token = tokenizer.eos_token\n\n        max_sequence_length = tok_kwargs.get(\n            \"max_length\", 1024\n        )  # Default from diffusers pipeline\n\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            
max_length=max_sequence_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        return text_inputs\n\n    def maybe_pack_latents(self, latents, batch_size, batch):\n        # If already packed (3D shape [B, seq, C]), skip packing\n        if latents.dim() == 3:\n            return latents\n\n        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].\n        # The patch dimensions are then permuted and collapsed into the channel dimension of shape:\n        # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).\n        # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features\n        batch_size, num_channels, num_frames, height, width = latents.shape\n        post_patch_num_frames = num_frames // self.patch_size_t\n        post_patch_height = height // self.patch_size\n        post_patch_width = width // self.patch_size\n        latents = latents.reshape(\n            batch_size,\n            -1,\n            post_patch_num_frames,\n            self.patch_size_t,\n            post_patch_height,\n            self.patch_size,\n            post_patch_width,\n            self.patch_size,\n        )\n        latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)\n        return latents\n\n    def _infer_video_latent_frames_and_tokens_per_frame(\n        self, batch, seq_len: int\n    ) -> tuple[int, int]:\n        \"\"\"Infer latent-frame count and tokens-per-frame for packed token latents [B, S, D].\n\n        Notes:\n        - This assumes `patch_size_t == 1` (no temporal patching).\n        - Tokens are ordered as (frame, height, width) after packing.\n        \"\"\"\n        if int(self.patch_size_t) != 1:\n            raise ValueError(\n                \"LTX-2 SP time-sharding for packed token latents currently requires \"\n                
f\"{self.patch_size_t=}. (Expected 1)\"\n            )\n        if int(seq_len) <= 0:\n            raise ValueError(f\"Expected {seq_len=} > 0 for packed token latents.\")\n        if int(self.vae_scale_factor) <= 0:\n            raise ValueError(f\"Invalid {self.vae_scale_factor=}. Must be > 0.\")\n        if int(self.patch_size) <= 0:\n            raise ValueError(f\"Invalid {self.patch_size=}. Must be > 0.\")\n\n        latent_height = int(batch.height) // int(self.vae_scale_factor)\n        latent_width = int(batch.width) // int(self.vae_scale_factor)\n        if latent_height <= 0 or latent_width <= 0:\n            raise ValueError(\n                \"Invalid latent H/W computed from batch.height/width: \"\n                f\"{batch.height=} {batch.width=} {self.vae_scale_factor=}\"\n            )\n        if (latent_height % int(self.patch_size)) != 0 or (\n            latent_width % int(self.patch_size)\n        ) != 0:\n            raise ValueError(\n                \"Invalid spatial patching for packed token latents. Expected latent H/W \"\n                \"to be divisible by patch_size, got \"\n                f\"{latent_height=} {latent_width=} {self.patch_size=}.\"\n            )\n\n        post_patch_h = latent_height // int(self.patch_size)\n        post_patch_w = latent_width // int(self.patch_size)\n        tokens_per_frame = int(post_patch_h) * int(post_patch_w)\n        if tokens_per_frame <= 0:\n            raise ValueError(\n                f\"Invalid tokens_per_frame={tokens_per_frame} from \"\n                f\"{latent_height=} {latent_width=} {self.patch_size=}\"\n            )\n        if int(seq_len) % int(tokens_per_frame) != 0:\n            raise ValueError(\n                f\"LTX-2 token latents seq_len={seq_len} is not divisible by \"\n                f\"tokens_per_frame={tokens_per_frame}. 
Cannot time-shard for SP.\"\n            )\n        latent_num_frames = int(seq_len) // int(tokens_per_frame)\n        return int(latent_num_frames), int(tokens_per_frame)\n\n    def shard_latents_for_sp(self, batch, latents):\n        \"\"\"Shard LTX-2 packed token latents across SP ranks by latent time (frame) dimension.\"\"\"\n        sp_world_size = get_sp_world_size()\n        if sp_world_size <= 1:\n            return latents, False\n\n        # Default behavior for 5D latents.\n        if isinstance(latents, torch.Tensor) and latents.ndim == 5:\n            return super().shard_latents_for_sp(batch, latents)\n\n        # LTX-2 packed token latents [B, S, D]\n        if not (isinstance(latents, torch.Tensor) and latents.ndim == 3):\n            return latents, False\n\n        sp_rank = get_sp_parallel_rank()\n        seq_len = int(latents.shape[1])\n        latent_frames, tokens_per_frame = (\n            self._infer_video_latent_frames_and_tokens_per_frame(batch, seq_len)\n        )\n\n        # Pad whole frames so `latent_frames` is divisible by `sp_world_size`.\n        pad_frames = (sp_world_size - (latent_frames % sp_world_size)) % sp_world_size\n        if pad_frames:\n            pad_tokens = int(pad_frames) * int(tokens_per_frame)\n            pad = torch.zeros(\n                (latents.shape[0], pad_tokens, latents.shape[2]),\n                device=latents.device,\n                dtype=latents.dtype,\n            )\n            latents = torch.cat([latents, pad], dim=1)\n            latent_frames = int(latent_frames) + int(pad_frames)\n\n        local_frames = int(latent_frames) // int(sp_world_size)\n        start_frame = int(sp_rank) * int(local_frames)\n        start = int(start_frame) * int(tokens_per_frame)\n        end = int(start) + int(local_frames) * int(tokens_per_frame)\n        latents = latents[:, start:end, :]\n\n        # Store SP metadata for denoising (TI2V gating) and model-side RoPE shift.\n        
batch.sp_video_latent_num_frames = int(local_frames)\n        batch.sp_video_start_frame = int(start_frame)\n        batch.sp_video_tokens_per_frame = int(tokens_per_frame)\n\n        return latents, True\n\n    def gather_latents_for_sp(self, latents):\n        \"\"\"Gather latents after SP. For packed token latents [B, S_local, D], gather on dim=1.\"\"\"\n        if get_sp_world_size() <= 1:\n            return latents\n        if isinstance(latents, torch.Tensor) and latents.ndim == 3:\n            return sequence_model_parallel_all_gather(latents.contiguous(), dim=1)\n        return super().gather_latents_for_sp(latents)\n\n    def maybe_pack_audio_latents(self, latents, batch_size, batch):\n        # If already packed (3D shape [B, T, C*F]), skip packing\n        if latents.dim() == 3:\n            return latents\n\n        # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins\n        # We need to pack them if patch_size/patch_size_t are defined for audio (not standard DiT patch size)\n\n        # So for LTX-2 (unless we change patch sizes), we just do:\n        latents = latents.transpose(1, 2).flatten(\n            2, 3\n        )  # [B, C, L, M] --> [B, L, C * M]\n        return latents\n\n    def get_pos_prompt_embeds(self, batch):\n        # LTX-2 returns multiple prompt embed tensors (video/audio contexts).\n        return (\n            batch.prompt_embeds[0]\n            if isinstance(batch.prompt_embeds, list)\n            else batch.prompt_embeds\n        )\n\n    def get_neg_prompt_embeds(self, batch):\n        return (\n            batch.negative_prompt_embeds[0]\n            if isinstance(batch.negative_prompt_embeds, list)\n            else batch.negative_prompt_embeds\n        )\n\n    def get_decode_scale_and_shift(self, device, dtype, vae):\n        latents_mean = getattr(vae, \"latents_mean\", None)\n        latents_std = getattr(vae, \"latents_std\", None)\n\n        scaling_factor = (\n     
       getattr(getattr(vae, \"config\", None), \"scaling_factor\", None)\n            or getattr(vae, \"scaling_factor\", None)\n            or getattr(self.vae_config.arch_config, \"scaling_factor\", None)\n            or 1.0\n        )\n        if isinstance(scaling_factor, (int, float)) and float(scaling_factor) == 0.0:\n            scaling_factor = 1.0\n\n        if isinstance(latents_mean, torch.Tensor) and isinstance(\n            latents_std, torch.Tensor\n        ):\n            latents_mean = latents_mean.to(device=device, dtype=dtype).view(\n                1, -1, 1, 1, 1\n            )\n            latents_std = latents_std.to(device=device, dtype=dtype).view(\n                1, -1, 1, 1, 1\n            )\n            sf = torch.tensor(float(scaling_factor), device=device, dtype=dtype).view(\n                1, 1, 1, 1, 1\n            )\n            return sf / latents_std, latents_mean\n\n        sf = torch.tensor(float(scaling_factor), device=device, dtype=dtype).view(\n            1, 1, 1, 1, 1\n        )\n        return sf, None\n\n    @staticmethod\n    def _unpack_latents(\n        latents: torch.Tensor,\n        num_frames: int,\n        height: int,\n        width: int,\n        patch_size: int = 1,\n        patch_size_t: int = 1,\n    ) -> torch.Tensor:\n        # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)\n        # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. 
This is the inverse operation of\n        # what happens in the `maybe_pack_latents` method.\n        batch_size = latents.size(0)\n        latents = latents.reshape(\n            batch_size,\n            num_frames,\n            height,\n            width,\n            -1,\n            patch_size_t,\n            patch_size,\n            patch_size,\n        )\n        latents = (\n            latents.permute(0, 4, 1, 5, 2, 6, 3, 7)\n            .flatten(6, 7)\n            .flatten(4, 5)\n            .flatten(2, 3)\n        )\n        return latents\n\n    @staticmethod\n    def _denormalize_latents(\n        latents: torch.Tensor,\n        latents_mean: torch.Tensor,\n        latents_std: torch.Tensor,\n        scaling_factor: float = 1.0,\n    ) -> torch.Tensor:\n        # Denormalize latents across the channel dimension [B, C, F, H, W]\n        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(\n            latents.device, latents.dtype\n        )\n        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)\n        latents = latents * latents_std / scaling_factor + latents_mean\n        return latents\n\n    @staticmethod\n    def _denormalize_audio_latents(\n        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor\n    ):\n        latents_mean = latents_mean.to(latents.device, latents.dtype)\n        latents_std = latents_std.to(latents.device, latents.dtype)\n        return (latents * latents_std) + latents_mean\n\n    @staticmethod\n    def _unpack_audio_latents(\n        latents: torch.Tensor,\n        latent_length: int,\n        num_mel_bins: int,\n        patch_size: int | None = None,\n        patch_size_t: int | None = None,\n    ) -> torch.Tensor:\n        # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M],\n        # where L is the latent audio length and M is the number of mel bins.\n        if patch_size is not None and patch_size_t is not 
None:\n            batch_size = latents.size(0)\n            latents = latents.reshape(\n                batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size\n            )\n            latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)\n        else:\n            # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1.\n            latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)\n        return latents\n\n    def _unpad_and_unpack_latents(self, latents, audio_latents, batch, vae, audio_vae):\n        # Calculate latent dimensions\n        # Assuming batch has height, width, num_frames\n        height = batch.height\n        width = batch.width\n        num_frames = batch.num_frames\n\n        # Get compression ratios\n        # Default LTX-2 values if not present in config\n        vae_spatial_compression_ratio = getattr(\n            self.vae_config.arch_config, \"spatial_compression_ratio\", 32\n        )\n        vae_temporal_compression_ratio = getattr(\n            self.vae_config.arch_config, \"temporal_compression_ratio\", 8\n        )\n\n        latent_height = height // vae_spatial_compression_ratio\n        latent_width = width // vae_spatial_compression_ratio\n        latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1\n\n        latents = self._unpack_latents(\n            latents,\n            latent_num_frames,\n            latent_height,\n            latent_width,\n            self.patch_size,\n            self.patch_size_t,\n        )\n\n        sample_rate = self.audio_vae_config.arch_config.sample_rate\n        hop_length = self.audio_vae_config.arch_config.mel_hop_length\n        temporal_compression = self.audio_vae_temporal_compression_ratio\n        duration_s = num_frames / batch.fps\n\n        latents_per_second = (\n            float(sample_rate) / float(hop_length) / float(temporal_compression)\n        )\n        
audio_num_frames = round(duration_s * latents_per_second)\n\n        num_mel_bins = self.audio_vae_config.arch_config.mel_bins\n        mel_compression_ratio = self.audio_vae_mel_compression_ratio\n        latent_mel_bins = num_mel_bins // mel_compression_ratio\n\n        audio_latents_mean = getattr(audio_vae, \"latents_mean\", None)\n        audio_latents_std = getattr(audio_vae, \"latents_std\", None)\n        if (\n            isinstance(audio_latents_mean, torch.Tensor)\n            and isinstance(audio_latents_std, torch.Tensor)\n            and audio_latents_mean.numel() == audio_latents_std.numel()\n        ):\n            audio_latents_mean = audio_latents_mean.to(\n                device=audio_latents.device, dtype=audio_latents.dtype\n            )\n            audio_latents_std = audio_latents_std.to(\n                device=audio_latents.device, dtype=audio_latents.dtype\n            )\n            if audio_latents.ndim == 3:\n                if audio_latents.shape[-1] != audio_latents_mean.numel():\n                    raise ValueError(\n                        f\"audio_latents last dim {audio_latents.shape[-1]} \"\n                        f\"does not match audio_vae stats {audio_latents_mean.numel()}\"\n                    )\n                audio_latents = audio_latents * audio_latents_std.view(\n                    1, 1, -1\n                ) + audio_latents_mean.view(1, 1, -1)\n            elif audio_latents.ndim == 2:\n                if audio_latents.shape[-1] != audio_latents_mean.numel():\n                    raise ValueError(\n                        f\"audio_latents last dim {audio_latents.shape[-1]} \"\n                        f\"does not match audio_vae stats {audio_latents_mean.numel()}\"\n                    )\n                audio_latents = audio_latents * audio_latents_std.view(\n                    1, -1\n                ) + audio_latents_mean.view(1, -1)\n            else:\n                audio_latents = audio_latents * 
audio_latents_std + audio_latents_mean\n\n        audio_latents = self._unpack_audio_latents(\n            audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins\n        )\n\n        return latents, audio_latents\n\n\n@dataclasses.dataclass\nclass LTX2I2VPipelineConfig(LTX2PipelineConfig):\n    task_type: ModelTaskType = ModelTaskType.TI2V\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/mova.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nMOVA pipeline configuration.\n\"\"\"\n\nfrom dataclasses import dataclass, field\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom PIL import Image\n\nfrom sglang.multimodal_gen.configs.models.dits import MOVAAudioConfig, MOVAVideoConfig\nfrom sglang.multimodal_gen.configs.models.encoders import T5Config\nfrom sglang.multimodal_gen.configs.models.vaes import DacVAEConfig, WanVAEConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    ModelTaskType,\n    PipelineConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.wan import t5_postprocess_text\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n@dataclass\nclass MOVAPipelineConfig(PipelineConfig):\n    \"\"\"Configuration for MOVA (text+image -> video+audio) pipelines.\"\"\"\n\n    task_type: ModelTaskType = ModelTaskType.I2V\n\n    # Model configs\n    dit_config: MOVAVideoConfig = field(default_factory=MOVAVideoConfig)\n    audio_dit_config: MOVAAudioConfig = field(default_factory=MOVAAudioConfig)\n\n    # Video VAE (Wan) + Audio VAE (DAC)\n    vae_config: WanVAEConfig = field(default_factory=WanVAEConfig)\n    audio_vae_config: DacVAEConfig = field(default_factory=DacVAEConfig)\n    audio_vae_precision: str = \"fp32\"\n\n    # Text encoder (UMT5 compatible)\n    text_encoder_configs: tuple = field(default_factory=lambda: (T5Config(),))\n    postprocess_text_funcs: tuple = field(\n        default_factory=lambda: (t5_postprocess_text,)\n    )\n\n    # MOVA specific\n    audio_vae_type: str = \"dac\"\n    boundary_ratio: float | None = 0.9\n\n    # temporal alignment: MOVA expects (num_frames - 1) % 4 == 0\n    time_division_factor: int = 4\n    time_division_remainder: int = 1\n\n    def _center_crop_and_resize(\n        self, image: torch.Tensor | Image.Image, target_height: int, target_width: int\n    ) -> torch.Tensor | 
Image.Image:\n        if not isinstance(image, (Image.Image, torch.Tensor)):\n            raise TypeError(f\"Unsupported image type: {type(image)}\")\n        if isinstance(image, Image.Image):\n            image = torch.from_numpy(np.array(image))\n\n        if image.ndim == 2:\n            image = image[..., None]\n\n        if not image.dtype.is_floating_point:\n            image = image.to(torch.float32).div(255.0)\n\n        if image.ndim == 3:\n            if image.shape[0] in (1, 3, 4) and image.shape[-1] not in (1, 3, 4):\n                image = image.unsqueeze(0)\n            else:\n                image = image.permute(2, 0, 1).unsqueeze(0)\n        elif image.ndim == 4:\n            if image.shape[1] not in (1, 3, 4) and image.shape[-1] in (1, 3, 4):\n                image = image.permute(0, 3, 1, 2)\n\n        image_height, image_width = image.shape[-2], image.shape[-1]\n        if image_height == target_height and image_width == target_width:\n            return image\n\n        logger.info(\n            \"Center cropping and resizing image to %dx%d\", target_width, target_height\n        )\n\n        if image_height * target_width < image_width * target_height:\n            cropped_width = (image_height * target_width) // target_height\n            left = (image_width - cropped_width) // 2\n            image = image[..., :, left : left + cropped_width]\n        else:\n            cropped_height = (image_width * target_height) // target_width\n            top = (image_height - cropped_height) // 2\n            image = image[..., top : top + cropped_height, :]\n\n        image = F.interpolate(\n            image,\n            size=(target_height, target_width),\n            mode=\"bilinear\",\n            align_corners=False,\n            antialias=True,\n        )\n        return image\n\n    def adjust_num_frames(self, num_frames: int) -> int:\n        if num_frames is None:\n            return num_frames\n        if num_frames % 
self.time_division_factor != self.time_division_remainder:\n            adjusted = (\n                (num_frames + self.time_division_factor - 1)\n                // self.time_division_factor\n                * self.time_division_factor\n                + self.time_division_remainder\n            )\n            logger.warning(\n                \"`num_frames` (%s) is not compatible with MOVA temporal constraints. \"\n                \"Rounding to %s.\",\n                num_frames,\n                adjusted,\n            )\n            return adjusted\n        return num_frames\n\n    def preprocess_condition_image(\n        self, image, target_width, target_height, _vae_image_processor\n    ):\n        image = self._center_crop_and_resize(image, target_height, target_width)\n        return image, (target_width, target_height)\n\n    def prepare_latent_shape(self, batch, batch_size, num_frames):\n        spatial = self.vae_config.arch_config.spatial_compression_ratio\n        length = (num_frames - 1) // self.time_division_factor + 1\n        shape = (\n            batch_size,\n            self.dit_config.arch_config.out_dim,\n            length,\n            batch.height // spatial,\n            batch.width // spatial,\n        )\n        return shape\n\n    def prepare_audio_latent_shape(self, batch_size, num_samples, audio_vae):\n        latent_T = (num_samples + audio_vae.hop_length - 1) // audio_vae.hop_length\n        return (batch_size, audio_vae.latent_dim, latent_T)\n\n    def normalize_video_latents(self, latents: torch.Tensor, video_vae) -> torch.Tensor:\n        latents_mean = getattr(video_vae.config, \"latents_mean\", None)\n        latents_std = getattr(video_vae.config, \"latents_std\", None)\n        if latents_mean is None or latents_std is None:\n            return latents\n        mean = torch.tensor(\n            latents_mean, device=latents.device, dtype=latents.dtype\n        ).view(1, video_vae.config.z_dim, 1, 1, 1)\n        inv_std = (\n   
         1.0 / torch.tensor(latents_std, device=latents.device, dtype=latents.dtype)\n        ).view(1, video_vae.config.z_dim, 1, 1, 1)\n        return (latents - mean) * inv_std\n\n    def denormalize_video_latents(\n        self, latents: torch.Tensor, video_vae\n    ) -> torch.Tensor:\n        latents_mean = getattr(video_vae.config, \"latents_mean\", None)\n        latents_std = getattr(video_vae.config, \"latents_std\", None)\n        if latents_mean is None or latents_std is None:\n            return latents\n        mean = torch.tensor(\n            latents_mean, device=latents.device, dtype=latents.dtype\n        ).view(1, video_vae.config.z_dim, 1, 1, 1)\n        std = torch.tensor(\n            latents_std, device=latents.device, dtype=latents.dtype\n        ).view(1, video_vae.config.z_dim, 1, 1, 1)\n        return latents * std + mean\n\n\n@dataclass\nclass MOVA360PConfig(MOVAPipelineConfig):\n    \"\"\"Configuration for MOVA 360P (text+image -> video+audio) pipelines.\"\"\"\n\n    max_area: int = 352 * 640\n\n\n@dataclass\nclass MOVA720PConfig(MOVAPipelineConfig):\n    \"\"\"Configuration for MOVA 720P (text+image -> video+audio) pipelines.\"\"\"\n\n    max_area: int = 720 * 1280\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nfrom dataclasses import dataclass, field\nfrom typing import Callable\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig\nfrom sglang.multimodal_gen.configs.models.dits.qwenimage import (\n    QwenImageDitConfig,\n    QwenImageEditPlus_2511_DitConfig,\n)\nfrom sglang.multimodal_gen.configs.models.encoders.qwen_image import Qwen2_5VLConfig\nfrom sglang.multimodal_gen.configs.models.vaes.qwenimage import QwenImageVAEConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    ImagePipelineConfig,\n    ModelTaskType,\n    maybe_unpad_latents,\n    shard_rotary_emb_for_sp,\n)\nfrom sglang.multimodal_gen.runtime.models.vision_utils import resize\nfrom sglang.multimodal_gen.utils import calculate_dimensions\n\n\ndef _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):\n    bool_mask = mask.bool()\n    valid_lengths = bool_mask.sum(dim=1)\n    selected = hidden_states[bool_mask]\n    split_result = torch.split(selected, valid_lengths.tolist(), dim=0)\n\n    return split_result\n\n\ndef qwen_image_preprocess_text(prompt):\n    prompt_template_encode = \"<|im_start|>system\\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\\n<|im_start|>user\\n{}<|im_end|>\\n<|im_start|>assistant\\n\"\n\n    template = prompt_template_encode\n    txt = template.format(prompt)\n    return txt\n\n\ndef qwen_image_postprocess_text(outputs, _text_inputs, drop_idx=34):\n    # Use the last layer's hidden states, shape [B, S, D]\n    hidden_states = outputs.hidden_states[-1]\n    split_hidden_states = _extract_masked_hidden(\n        hidden_states, _text_inputs.attention_mask\n    )\n    # Drop the leading prompt-template tokens from each sample\n    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]\n    max_seq_len = max([e.size(0) for e in split_hidden_states])\n    prompt_embeds = torch.stack(\n        
[\n            torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))])\n            for u in split_hidden_states\n        ]\n    )\n    return prompt_embeds\n\n\ndef _normalize_prompt_list(prompt):\n    return [prompt] if isinstance(prompt, str) else prompt\n\n\ndef _normalize_image_list(images):\n    if images is None:\n        return []\n    return images if isinstance(images, list) else [images]\n\n\ndef _build_qwen_edit_image_prompt(num_images: int) -> str:\n    img_prompt_template = \"Picture {}: <|vision_start|><|image_pad|><|vision_end|>\"\n    return \"\".join(img_prompt_template.format(i + 1) for i in range(num_images))\n\n\ndef _resolve_qwen_edit_per_prompt_images(prompt_list, image_list):\n    if len(prompt_list) <= 1:\n        return [image_list]\n\n    if len(image_list) <= 1:\n        return [list(image_list) for _ in prompt_list]\n\n    if len(image_list) != len(prompt_list):\n        raise ValueError(\n            \"QwenImageEditPlus expects either one shared condition image or \"\n            \"the same number of condition images and prompts.\"\n        )\n\n    return [[image] for image in image_list]\n\n\n# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents\ndef _pack_latents(latents, batch_size, num_channels_latents, height, width):\n    latents = latents.view(\n        batch_size, num_channels_latents, height // 2, 2, width // 2, 2\n    )\n    latents = latents.permute(0, 2, 4, 1, 3, 5)\n    latents = latents.reshape(\n        batch_size, (height // 2) * (width // 2), num_channels_latents * 4\n    )\n\n    return latents\n\n\n@dataclass\nclass QwenImagePipelineConfig(ImagePipelineConfig):\n    \"\"\"Configuration for the QwenImage pipeline.\"\"\"\n\n    should_use_guidance: bool = False\n    task_type: ModelTaskType = ModelTaskType.T2I\n\n    vae_tiling: bool = False\n\n    vae_sp: bool = False\n\n    dit_config: DiTConfig = field(default_factory=QwenImageDitConfig)\n    # VAE\n    vae_config: 
VAEConfig = field(default_factory=QwenImageVAEConfig)\n\n    enable_autocast: bool = False\n\n    # Text encoding stage\n    text_encoder_configs: tuple[EncoderConfig, ...] = field(\n        default_factory=lambda: (Qwen2_5VLConfig(),)\n    )\n\n    text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: (\"bf16\",))\n\n    preprocess_text_funcs: tuple[Callable[[str], str], ...] = field(\n        default_factory=lambda: (qwen_image_preprocess_text,)\n    )\n\n    postprocess_text_funcs: tuple[Callable[..., torch.Tensor], ...] = field(\n        default_factory=lambda: (qwen_image_postprocess_text,)\n    )\n    text_encoder_extra_args: list[dict] = field(\n        default_factory=lambda: [\n            dict(\n                padding=True,\n                truncation=True,\n            ),\n            None,\n        ]\n    )\n\n    def prepare_sigmas(self, sigmas, num_inference_steps):\n        return self._prepare_sigmas(sigmas, num_inference_steps)\n\n    def prepare_image_processor_kwargs(self, batch, neg=False):\n        prompt = batch.prompt if not neg else batch.negative_prompt\n        if prompt:\n            prompt_template_encode = \"<|im_start|>system\\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n\"\n            txt = prompt_template_encode.format(prompt)\n            return dict(text=[txt], padding=True)\n        else:\n            return {}\n\n    def get_vae_scale_factor(self):\n        return self.vae_config.arch_config.vae_scale_factor\n\n    def prepare_latent_shape(self, batch, batch_size, num_frames):\n        vae_scale_factor = self.vae_config.arch_config.vae_scale_factor\n        height = 2 * (batch.height // (vae_scale_factor * 2))\n        width = 2 * (batch.width // (vae_scale_factor * 2))\n        num_channels_latents = self.dit_config.arch_config.in_channels // 4\n        shape = (batch_size, 1, num_channels_latents, height, width)\n        return shape\n\n    def maybe_pack_latents(self, latents, batch_size, batch):\n        height = 2 * (\n            batch.height // (self.vae_config.arch_config.vae_scale_factor * 2)\n        )\n        width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2))\n        num_channels_latents = self.dit_config.arch_config.in_channels // 4\n        # pack latents\n        return _pack_latents(latents, batch_size, num_channels_latents, height, width)\n\n    def get_decode_scale_and_shift(self, device, dtype, vae):\n        vae_arch_config = self.vae_config.arch_config\n        scaling_factor = 1.0 / torch.tensor(\n            vae_arch_config.latents_std, device=device\n        ).view(1, vae_arch_config.z_dim, 1, 1, 1).to(device, dtype)\n        shift_factor = (\n            torch.tensor(vae_arch_config.latents_mean)\n            .view(1, vae_arch_config.z_dim, 1, 1, 1)\n            .to(device, dtype)\n        )\n        return scaling_factor, shift_factor\n\n    @staticmethod\n    def get_freqs_cis(img_shapes, txt_seq_lens, rotary_emb, device, 
dtype):\n        # img_shapes covers the whole image; callers shard the resulting cache for SP\n        img_freqs, txt_freqs = rotary_emb(img_shapes, txt_seq_lens, device=device)\n\n        # flashinfer RoPE expects a float32 cos/sin cache concatenated on the last dim\n        img_cos_half = img_freqs.real.to(dtype=torch.float32).contiguous()\n        img_sin_half = img_freqs.imag.to(dtype=torch.float32).contiguous()\n        txt_cos_half = txt_freqs.real.to(dtype=torch.float32).contiguous()\n        txt_sin_half = txt_freqs.imag.to(dtype=torch.float32).contiguous()\n\n        img_cos_sin_cache = torch.cat([img_cos_half, img_sin_half], dim=-1)\n        txt_cos_sin_cache = torch.cat([txt_cos_half, txt_sin_half], dim=-1)\n        return img_cos_sin_cache, txt_cos_sin_cache\n\n    def _prepare_cond_kwargs(self, batch, prompt_embeds, rotary_emb, device, dtype):\n        batch_size = prompt_embeds[0].shape[0]\n        height = batch.height\n        width = batch.width\n        vae_scale_factor = self.vae_config.arch_config.vae_scale_factor\n\n        img_shapes = [\n            [\n                (\n                    1,\n                    height // vae_scale_factor // 2,\n                    width // vae_scale_factor // 2,\n                )\n            ]\n        ] * batch_size\n        txt_seq_lens = [prompt_embeds[0].shape[1]]\n\n        if rotary_emb is None:\n            return {\n                \"img_shapes\": img_shapes,\n                \"txt_seq_lens\": txt_seq_lens,\n                \"freqs_cis\": None,\n            }\n\n        freqs_cis = self.get_freqs_cis(\n            img_shapes, txt_seq_lens, rotary_emb, device, dtype\n        )\n\n        img_cache, txt_cache = freqs_cis\n        img_cache = shard_rotary_emb_for_sp(img_cache)\n        return {\n            \"txt_seq_lens\": txt_seq_lens,\n            \"freqs_cis\": (img_cache, txt_cache),\n            \"img_shapes\": img_shapes,\n        }\n\n    def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return 
self._prepare_cond_kwargs(\n            batch, batch.prompt_embeds, rotary_emb, device, dtype\n        )\n\n    def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return self._prepare_cond_kwargs(\n            batch, batch.negative_prompt_embeds, rotary_emb, device, dtype\n        )\n\n    def post_denoising_loop(self, latents, batch):\n        # unpack latents for qwen-image\n        (\n            latents,\n            batch_size,\n            channels,\n            height,\n            width,\n        ) = self._unpad_and_unpack_latents(latents, batch)\n        latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)\n        return latents\n\n\n@dataclass\nclass QwenImageEditPipelineConfig(QwenImagePipelineConfig):\n    \"\"\"Configuration for the QwenImageEdit pipeline.\"\"\"\n\n    task_type: ModelTaskType = ModelTaskType.I2I\n\n    def _prepare_edit_cond_kwargs(\n        self, batch, prompt_embeds, rotary_emb, device, dtype\n    ):\n        batch_size = batch.latents.shape[0]\n        assert batch_size == 1\n        height = batch.height\n        width = batch.width\n        image_size = batch.original_condition_image_size\n        edit_width, edit_height, _ = calculate_dimensions(\n            1024 * 1024, image_size[0] / image_size[1]\n        )\n        vae_scale_factor = self.get_vae_scale_factor()\n\n        img_shapes = [\n            [\n                (\n                    1,\n                    height // vae_scale_factor // 2,\n                    width // vae_scale_factor // 2,\n                ),\n                (\n                    1,\n                    edit_height // vae_scale_factor // 2,\n                    edit_width // vae_scale_factor // 2,\n                ),\n            ],\n        ] * batch_size\n        txt_seq_lens = [prompt_embeds[0].shape[1]]\n\n        if rotary_emb is None:\n            return {\n                \"img_shapes\": img_shapes,\n                \"txt_seq_lens\": 
txt_seq_lens,\n                \"freqs_cis\": None,\n            }\n\n        freqs_cis = QwenImagePipelineConfig.get_freqs_cis(\n            img_shapes, txt_seq_lens, rotary_emb, device, dtype\n        )\n\n        # perform sp shard on noisy image tokens\n        noisy_img_seq_len = (\n            1 * (height // vae_scale_factor // 2) * (width // vae_scale_factor // 2)\n        )\n\n        img_cache, txt_cache = freqs_cis\n        noisy_img_cache = shard_rotary_emb_for_sp(img_cache[:noisy_img_seq_len, :])\n        img_cache = torch.cat(\n            [noisy_img_cache, img_cache[noisy_img_seq_len:, :]], dim=0\n        ).to(device=device)\n        return {\n            \"txt_seq_lens\": txt_seq_lens,\n            \"freqs_cis\": (img_cache, txt_cache),\n            \"img_shapes\": img_shapes,\n        }\n\n    def preprocess_condition_image(\n        self, image, target_width, target_height, _vae_image_processor\n    ):\n        return resize(image, target_height, target_width, resize_mode=\"default\"), (\n            target_width,\n            target_height,\n        )\n\n    def postprocess_image_latent(self, latent_condition, batch):\n        batch_size = batch.batch_size\n        if batch_size > latent_condition.shape[0]:\n            if batch_size % latent_condition.shape[0] == 0:\n                # expand init_latents for batch_size\n                additional_image_per_prompt = batch_size // latent_condition.shape[0]\n                image_latents = latent_condition.repeat(\n                    additional_image_per_prompt, 1, 1, 1\n                )\n            else:\n                raise ValueError(\n                    f\"Cannot duplicate `image` of batch size {latent_condition.shape[0]} to {batch_size} text prompts.\"\n                )\n        else:\n            image_latents = latent_condition\n        image_latent_height, image_latent_width = image_latents.shape[3:]\n        num_channels_latents = self.dit_config.arch_config.in_channels // 4\n        
image_latents = _pack_latents(\n            image_latents,\n            batch_size,\n            num_channels_latents,\n            image_latent_height,\n            image_latent_width,\n        )\n\n        return image_latents\n\n    def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return self._prepare_edit_cond_kwargs(\n            batch, batch.prompt_embeds, rotary_emb, device, dtype\n        )\n\n    def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return self._prepare_edit_cond_kwargs(\n            batch, batch.negative_prompt_embeds, rotary_emb, device, dtype\n        )\n\n    def calculate_condition_image_size(self, image, width, height) -> tuple[int, int]:\n        calculated_width, calculated_height, _ = calculate_dimensions(\n            1024 * 1024, width / height\n        )\n        return calculated_width, calculated_height\n\n    def slice_noise_pred(self, noise, latents):\n        # remove noise over input image\n        noise = noise[:, : latents.size(1)]\n        return noise\n\n\nCONDITION_IMAGE_SIZE = 384 * 384\nVAE_IMAGE_SIZE = 1024 * 1024\n\n\n@dataclass\nclass QwenImageEditPlusPipelineConfig(QwenImageEditPipelineConfig):\n    task_type: ModelTaskType = ModelTaskType.I2I\n\n    def _get_condition_image_sizes(self, batch) -> list[tuple[int, int]]:\n        image = batch.condition_image\n        if not isinstance(image, list):\n            image = [image]\n\n        condition_image_sizes = []\n        for img in image:\n            image_width, image_height = img.size\n            edit_width, edit_height, _ = calculate_dimensions(\n                VAE_IMAGE_SIZE, image_width / image_height\n            )\n            condition_image_sizes.append((edit_width, edit_height))\n\n        return condition_image_sizes\n\n    def prepare_image_processor_kwargs(self, batch, neg=False) -> dict:\n        prompt = batch.prompt if not neg else batch.negative_prompt\n        if not prompt:\n            
return {}\n\n        prompt_list = _normalize_prompt_list(prompt)\n        image_list = _normalize_image_list(batch.condition_image)\n        per_prompt_images = _resolve_qwen_edit_per_prompt_images(\n            prompt_list, image_list\n        )\n\n        prompt_template_encode = (\n            \"<|im_start|>system\\nDescribe the key features of the input image \"\n            \"(color, shape, size, texture, objects, background), then explain how \"\n            \"the user's text instruction should alter or modify the image. Generate \"\n            \"a new image that meets the user's requirements while maintaining \"\n            \"consistency with the original input where appropriate.<|im_end|>\\n\"\n            \"<|im_start|>user\\n{}<|im_end|>\\n\"\n            \"<|im_start|>assistant\\n\"\n        )\n        txt = [\n            prompt_template_encode.format(\n                _build_qwen_edit_image_prompt(len(prompt_images)) + prompt_text\n            )\n            for prompt_text, prompt_images in zip(prompt_list, per_prompt_images)\n        ]\n\n        return dict(text=txt, padding=True, per_prompt_images=per_prompt_images)\n\n    def prepare_calculated_size(self, image):\n        return self.calculate_vae_image_size(image, image.width, image.height)\n\n    def resize_condition_image(self, images, target_width, target_height):\n        if not isinstance(images, list):\n            images = [images]\n        new_images = []\n        for img, width, height in zip(images, target_width, target_height):\n            new_images.append(resize(img, height, width, resize_mode=\"default\"))\n        return new_images\n\n    def calculate_condition_image_size(self, image, width, height) -> tuple[int, int]:\n        calculated_width, calculated_height, _ = calculate_dimensions(\n            CONDITION_IMAGE_SIZE, width / height\n        )\n        return calculated_width, calculated_height\n\n    def calculate_vae_image_size(self, image, width, height) -> tuple[int, 
int]:\n        calculated_width, calculated_height, _ = calculate_dimensions(\n            VAE_IMAGE_SIZE, width / height\n        )\n        return calculated_width, calculated_height\n\n    def preprocess_vae_image(self, batch, vae_image_processor):\n        if not isinstance(batch.condition_image, list):\n            batch.condition_image = [batch.condition_image]\n        new_images = []\n        vae_image_sizes = []\n        for img in batch.condition_image:\n            width, height = self.calculate_vae_image_size(img, img.width, img.height)\n            new_images.append(vae_image_processor.preprocess(img, height, width))\n            vae_image_sizes.append((width, height))\n        batch.vae_image = new_images\n        batch.vae_image_sizes = vae_image_sizes\n        return batch\n\n    def _prepare_edit_cond_kwargs(\n        self, batch, prompt_embeds, rotary_emb, device, dtype\n    ):\n        batch_size = batch.latents.shape[0]\n        assert batch_size == 1\n        height = batch.height\n        width = batch.width\n\n        vae_scale_factor = self.get_vae_scale_factor()\n\n        img_shapes = [\n            [\n                (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2),\n                *[\n                    (\n                        1,\n                        vae_height // vae_scale_factor // 2,\n                        vae_width // vae_scale_factor // 2,\n                    )\n                    for vae_width, vae_height in batch.vae_image_sizes\n                ],\n            ],\n        ] * batch_size\n        txt_seq_lens = [prompt_embeds[0].shape[1]]\n\n        freqs_cis = QwenImageEditPlusPipelineConfig.get_freqs_cis(\n            img_shapes, txt_seq_lens, rotary_emb, device, dtype\n        )\n\n        # perform sp shard on noisy image tokens\n        noisy_img_seq_len = (\n            1 * (height // vae_scale_factor // 2) * (width // vae_scale_factor // 2)\n        )\n\n        if isinstance(freqs_cis[0], 
torch.Tensor) and freqs_cis[0].dim() == 2:\n            img_cache, txt_cache = freqs_cis\n            noisy_img_cache = shard_rotary_emb_for_sp(img_cache[:noisy_img_seq_len, :])\n            img_cache = torch.cat(\n                [noisy_img_cache, img_cache[noisy_img_seq_len:, :]], dim=0\n            ).to(device=device)\n            return {\n                \"txt_seq_lens\": txt_seq_lens,\n                \"freqs_cis\": (img_cache, txt_cache),\n                \"img_shapes\": img_shapes,\n            }\n\n        (img_cos, img_sin), (txt_cos, txt_sin) = freqs_cis\n        noisy_img_cos = shard_rotary_emb_for_sp(img_cos[:noisy_img_seq_len, :])\n        noisy_img_sin = shard_rotary_emb_for_sp(img_sin[:noisy_img_seq_len, :])\n\n        # concat back the cos/sin rows of the input image tokens (they are not sp-sharded later)\n        img_cos = torch.cat([noisy_img_cos, img_cos[noisy_img_seq_len:, :]], dim=0).to(\n            device=device\n        )\n        img_sin = torch.cat([noisy_img_sin, img_sin[noisy_img_seq_len:, :]], dim=0).to(\n            device=device\n        )\n\n        return {\n            \"txt_seq_lens\": txt_seq_lens,\n            \"freqs_cis\": ((img_cos, img_sin), (txt_cos, txt_sin)),\n            \"img_shapes\": img_shapes,\n        }\n\n\n@dataclass\nclass QwenImageEditPlus_2511_PipelineConfig(QwenImageEditPlusPipelineConfig):\n    dit_config: DiTConfig = field(default_factory=QwenImageEditPlus_2511_DitConfig)\n\n\n@dataclass\nclass QwenImageLayeredPipelineConfig(QwenImageEditPipelineConfig):\n    resolution: int = 640\n    vae_precision: str = \"bf16\"\n\n    def _prepare_edit_cond_kwargs(\n        self, batch, prompt_embeds, rotary_emb, device, dtype\n    ):\n        batch_size = batch.latents.shape[0]\n        assert batch_size == 1\n        height = batch.height\n        width = batch.width\n\n        vae_scale_factor = self.get_vae_scale_factor()\n\n        img_shapes = batch.img_shapes\n        txt_seq_lens = batch.txt_seq_lens\n\n        freqs_cis = 
QwenImageEditPlusPipelineConfig.get_freqs_cis(\n            img_shapes, txt_seq_lens, rotary_emb, device, dtype\n        )\n\n        # perform sp shard on noisy image tokens\n        noisy_img_seq_len = (\n            1 * (height // vae_scale_factor // 2) * (width // vae_scale_factor // 2)\n        )\n\n        img_cache, txt_cache = freqs_cis\n        noisy_img_cache = shard_rotary_emb_for_sp(img_cache[:noisy_img_seq_len, :])\n        img_cache = torch.cat(\n            [noisy_img_cache, img_cache[noisy_img_seq_len:, :]], dim=0\n        ).to(device=device)\n\n        return {\n            \"txt_seq_lens\": txt_seq_lens,\n            \"img_shapes\": img_shapes,\n            \"freqs_cis\": (img_cache, txt_cache),\n            \"additional_t_cond\": torch.tensor([0], device=device, dtype=torch.long),\n        }\n\n    def _unpad_and_unpack_latents(self, latents, batch):\n        vae_scale_factor = self.vae_config.arch_config.vae_scale_factor\n        channels = self.dit_config.arch_config.in_channels\n        batch_size = latents.shape[0]\n        layers = batch.num_frames\n\n        height = 2 * (int(batch.height) // (vae_scale_factor * 2))\n        width = 2 * (int(batch.width) // (vae_scale_factor * 2))\n\n        latents = maybe_unpad_latents(latents, batch)\n        latents = latents.view(\n            batch_size, layers + 1, height // 2, width // 2, channels // 4, 2, 2\n        )\n        latents = latents.permute(0, 1, 4, 2, 5, 3, 6)\n\n        latents = latents.reshape(\n            batch_size, layers + 1, channels // (2 * 2), height, width\n        )\n        latents = latents.permute(0, 2, 1, 3, 4)  # (b, c, f, h, w)\n        return latents, batch_size, channels, height, width\n\n    def allow_set_num_frames(self):\n        return True\n\n    def post_denoising_loop(self, latents, batch):\n        # unpack latents for qwen-image\n        (\n            latents,\n            batch_size,\n            channels,\n            height,\n            width,\n       
 ) = self._unpad_and_unpack_latents(latents, batch)\n        b, c, f, h, w = latents.shape\n        latents = latents[:, :, 1:]  # remove the first frame as it is the original input image\n        # (b, c, f - 1, h, w) -> (b * (f - 1), c, 1, h, w): one latent per output layer\n        latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)\n        return latents\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/sana.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n#\n# Pipeline configuration for SANA text-to-image generation.\n#\n# SANA produces 4D spatial latents (B, C, H', W') directly — unlike Flux/QwenImage\n# which use packed token-style latents (B, S, D). This means:\n#   - We inherit SpatialImagePipelineConfig (not ImagePipelineConfig)\n#   - prepare_latent_shape returns 4D, not 5D\n#   - post_denoising_loop is a no-op (no un-packing needed)\n#   - shard_latents_for_sp shards along the H' dimension\n#\n# SANA does NOT use rotary position embeddings, so prepare_pos/neg_cond_kwargs\n# return empty dicts (the DiT only needs hidden_states + encoder_hidden_states + timestep).\n#\n# CFG is handled by the denoising stage via guidance_scale in sampling params.\n# should_use_guidance=False means no embedded guidance (no extra guidance token in forward),\n# but negative_prompt + guidance_scale > 1.0 still enables standard classifier-free guidance.\n\nfrom collections.abc import Callable\nfrom dataclasses import dataclass, field\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models import DiTConfig, VAEConfig\nfrom sglang.multimodal_gen.configs.models.dits.sana import SanaConfig\nfrom sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput\nfrom sglang.multimodal_gen.configs.models.encoders.base import EncoderConfig\nfrom sglang.multimodal_gen.configs.models.encoders.gemma2 import Gemma2Config\nfrom sglang.multimodal_gen.configs.models.vaes.sana import SanaVAEConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    ModelTaskType,\n    SpatialImagePipelineConfig,\n    preprocess_text,\n)\n\n\ndef sana_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:\n    # SANA uses the final hidden state from Gemma2 directly as text conditioning.\n    # No intermediate-layer extraction or masking needed (unlike QwenImage/ZImage).\n    return outputs.last_hidden_state\n\n\n@dataclass\nclass 
SanaPipelineConfig(SpatialImagePipelineConfig):\n\n    task_type: ModelTaskType = ModelTaskType.T2I\n\n    # should_use_guidance=False disables *embedded* guidance (timestep-conditioned\n    # guidance token). Standard CFG via guidance_scale is still active.\n    should_use_guidance: bool = False\n    enable_autocast: bool = False\n\n    # DC-AE does not support tiling or SP VAE decode yet.\n    vae_tiling: bool = False\n    vae_sp: bool = False\n    vae_precision: str = \"bf16\"\n\n    dit_config: DiTConfig = field(default_factory=SanaConfig)\n    vae_config: VAEConfig = field(default_factory=SanaVAEConfig)\n\n    # Single text encoder: Gemma2 (unlike Flux, which uses CLIP + T5)\n    text_encoder_configs: tuple[EncoderConfig, ...] = field(\n        default_factory=lambda: (Gemma2Config(),)\n    )\n\n    text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: (\"bf16\",))\n\n    preprocess_text_funcs: tuple[Callable[[str], str], ...] = field(\n        default_factory=lambda: (preprocess_text,),\n    )\n\n    postprocess_text_funcs: tuple[Callable[..., torch.Tensor], ...] 
= field(\n        default_factory=lambda: (sana_postprocess_text,)\n    )\n\n    def prepare_latent_shape(self, batch, batch_size, num_frames):\n        # 4D latent shape: (B, C, H', W') — no temporal dim for T2I.\n        # DC-AE compresses 1024x1024 -> 32x32 with 32 channels.\n        compression = self.vae_config.arch_config.spatial_compression_ratio\n        height = batch.height // compression\n        width = batch.width // compression\n        num_channels = self.dit_config.arch_config.num_channels_latents\n        shape = (batch_size, num_channels, height, width)\n        return shape\n\n    def get_pos_prompt_embeds(self, batch):\n        # Single encoder -> index [0] (Flux uses [1] because T5 is encoder #2)\n        return batch.prompt_embeds[0]\n\n    def get_neg_prompt_embeds(self, batch):\n        return batch.negative_prompt_embeds[0]\n\n    def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        # The batch stores the attention mask as a list of tensors, but diffusers' SanaTransformer\n        # expects a single tensor (sglang's SanaTransformer also accepts a list), so pass [0].\n        out = {}\n        m = batch.prompt_attention_mask\n        if isinstance(m, (list, tuple)):\n            out[\"encoder_attention_mask\"] = m[0] if m else None\n        elif m is not None:\n            out[\"encoder_attention_mask\"] = m\n        return out\n\n    def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        out = {}\n        m = batch.negative_attention_mask\n        if isinstance(m, (list, tuple)):\n            out[\"encoder_attention_mask\"] = m[0] if m else None\n        elif m is not None:\n            out[\"encoder_attention_mask\"] = m\n        return out\n\n    def post_denoising_loop(self, latents, batch):\n        return latents\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/wan.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom collections.abc import Callable\nfrom dataclasses import dataclass, field\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig\nfrom sglang.multimodal_gen.configs.models.dits import WanVideoConfig\nfrom sglang.multimodal_gen.configs.models.encoders import (\n    BaseEncoderOutput,\n    CLIPVisionConfig,\n    T5Config,\n)\nfrom sglang.multimodal_gen.configs.models.vaes import WanVAEConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    ModelTaskType,\n    PipelineConfig,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\ndef t5_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:\n    mask: torch.Tensor = outputs.attention_mask\n    hidden_state: torch.Tensor = outputs.last_hidden_state\n    seq_lens = mask.gt(0).sum(dim=1).long()\n    assert torch.isnan(hidden_state).sum() == 0\n    prompt_embeds = [u[:v] for u, v in zip(hidden_state, seq_lens, strict=True)]\n    prompt_embeds_tensor: torch.Tensor = torch.stack(\n        [\n            torch.cat([u, u.new_zeros(512 - u.size(0), u.size(1))])\n            for u in prompt_embeds\n        ],\n        dim=0,\n    )\n    return prompt_embeds_tensor\n\n\n@dataclass\nclass WanI2VCommonConfig(PipelineConfig):\n    # for all wan i2v pipelines\n    def adjust_num_frames(self, num_frames):\n        vae_scale_factor_temporal = self.vae_config.arch_config.scale_factor_temporal\n        if num_frames % vae_scale_factor_temporal != 1:\n            logger.warning(\n                f\"`num_frames - 1` has to be divisible by {vae_scale_factor_temporal}. 
Rounding to the nearest number.\"\n            )\n            num_frames = (\n                num_frames // vae_scale_factor_temporal * vae_scale_factor_temporal + 1\n            )\n            return num_frames\n        return num_frames\n\n\n@dataclass\nclass WanT2V480PConfig(PipelineConfig):\n    \"\"\"Base configuration for Wan T2V 1.3B pipeline architecture.\"\"\"\n\n    task_type: ModelTaskType = ModelTaskType.T2V\n    # WanConfig-specific parameters with defaults\n    # DiT\n    dit_config: DiTConfig = field(default_factory=WanVideoConfig)\n\n    # VAE\n    vae_config: VAEConfig = field(default_factory=WanVAEConfig)\n    vae_tiling: bool = False\n    vae_sp: bool = False\n\n    # Denoising stage\n    flow_shift: float | None = 3.0\n\n    # Text encoding stage\n    text_encoder_configs: tuple[EncoderConfig, ...] = field(\n        default_factory=lambda: (T5Config(),)\n    )\n    postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor], ...] = (\n        field(default_factory=lambda: (t5_postprocess_text,))\n    )\n\n    # Precision for each component\n    precision: str = \"bf16\"\n    vae_precision: str = \"fp32\"\n    text_encoder_precisions: tuple[str, ...] 
= field(default_factory=lambda: (\"fp32\",))\n\n    # WanConfig-specific added parameters\n\n    def __post_init__(self):\n        self.vae_config.load_encoder = False\n        self.vae_config.load_decoder = True\n\n\n@dataclass\nclass TurboWanT2V480PConfig(WanT2V480PConfig):\n    \"\"\"Base configuration for Wan T2V 1.3B pipeline architecture.\"\"\"\n\n    flow_shift: float | None = 8.0\n    dmd_denoising_steps: list[int] | None = field(\n        default_factory=lambda: [988, 932, 852, 608]\n    )\n\n\n@dataclass\nclass WanT2V720PConfig(WanT2V480PConfig):\n    \"\"\"Base configuration for Wan T2V 14B 720P pipeline architecture.\"\"\"\n\n    # WanConfig-specific parameters with defaults\n\n    # Denoising stage\n    flow_shift: float | None = 5.0\n\n\n@dataclass\nclass WanI2V480PConfig(WanT2V480PConfig, WanI2VCommonConfig):\n    \"\"\"Base configuration for Wan I2V 14B 480P pipeline architecture.\"\"\"\n\n    max_area: int = 480 * 832\n    # WanConfig-specific parameters with defaults\n    task_type: ModelTaskType = ModelTaskType.I2V\n    # Precision for each component\n    image_encoder_config: EncoderConfig = field(default_factory=CLIPVisionConfig)\n    image_encoder_precision: str = \"fp32\"\n\n    image_encoder_extra_args: dict = field(\n        default_factory=lambda: dict(\n            output_hidden_states=True,\n        )\n    )\n\n    def postprocess_image(self, image):\n        return image.hidden_states[-2]\n\n    def __post_init__(self) -> None:\n        self.vae_config.load_encoder = True\n        self.vae_config.load_decoder = True\n\n\n@dataclass\nclass WanI2V720PConfig(WanI2V480PConfig):\n    \"\"\"Base configuration for Wan I2V 14B 720P pipeline architecture.\"\"\"\n\n    max_area: int = 720 * 1280\n    # WanConfig-specific parameters with defaults\n\n    # Denoising stage\n    flow_shift: float | None = 5.0\n\n\n@dataclass\nclass TurboWanI2V720Config(WanI2V720PConfig):\n    flow_shift: float | None = 8.0\n    dmd_denoising_steps: list[int] | None = 
field(\n        default_factory=lambda: [996, 932, 852, 608]\n    )\n    boundary_ratio: float | None = 0.9\n\n    def __post_init__(self) -> None:\n        super().__post_init__()\n        self.dit_config.boundary_ratio = self.boundary_ratio\n\n\n@dataclass\nclass FastWan2_1_T2V_480P_Config(WanT2V480PConfig):\n    \"\"\"Base configuration for FastWan T2V 1.3B 480P pipeline architecture with DMD\"\"\"\n\n    # WanConfig-specific parameters with defaults\n\n    # Denoising stage\n    flow_shift: float | None = 8.0\n    dmd_denoising_steps: list[int] | None = field(\n        default_factory=lambda: [1000, 757, 522]\n    )\n\n\n@dataclass\nclass Wan2_2_TI2V_5B_Config(WanT2V480PConfig, WanI2VCommonConfig):\n    flow_shift: float | None = 5.0\n    task_type: ModelTaskType = ModelTaskType.TI2V\n    expand_timesteps: bool = True\n    # ti2v, 5B\n    vae_stride = (4, 16, 16)\n\n    def prepare_latent_shape(self, batch, batch_size, num_frames):\n        F = num_frames\n        z_dim = self.vae_config.arch_config.z_dim\n        vae_stride = self.vae_stride\n        oh = batch.height\n        ow = batch.width\n        shape = (batch_size, z_dim, F, oh // vae_stride[1], ow // vae_stride[2])\n        return shape\n\n    def __post_init__(self) -> None:\n        self.vae_config.load_encoder = True\n        self.vae_config.load_decoder = True\n        self.dit_config.expand_timesteps = self.expand_timesteps\n\n\n@dataclass\nclass FastWan2_2_TI2V_5B_Config(Wan2_2_TI2V_5B_Config):\n    flow_shift: float | None = 5.0\n    dmd_denoising_steps: list[int] | None = field(\n        default_factory=lambda: [1000, 757, 522]\n    )\n\n\n@dataclass\nclass Wan2_2_T2V_A14B_Config(WanT2V480PConfig):\n    flow_shift: float | None = 12.0\n    boundary_ratio: float | None = 0.875\n\n    def __post_init__(self) -> None:\n        super().__post_init__()\n        self.dit_config.boundary_ratio = self.boundary_ratio\n\n\n@dataclass\nclass Wan2_2_I2V_A14B_Config(WanI2V480PConfig):\n    flow_shift: float | None = 5.0\n    boundary_ratio: float | None = 0.900\n\n    
def __post_init__(self) -> None:\n        super().__post_init__()\n        self.dit_config.boundary_ratio = self.boundary_ratio\n\n\n# =============================================\n# ============= Causal Self-Forcing =============\n# =============================================\n@dataclass\nclass SelfForcingWanT2V480PConfig(WanT2V480PConfig):\n    is_causal: bool = True\n    flow_shift: float | None = 5.0\n    dmd_denoising_steps: list[int] | None = field(\n        default_factory=lambda: [1000, 750, 500, 250]\n    )\n    warp_denoising_step: bool = True\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\nimport math\nfrom dataclasses import dataclass, field\nfrom typing import Callable\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig\nfrom sglang.multimodal_gen.configs.models.dits.zimage import ZImageDitConfig\nfrom sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput\nfrom sglang.multimodal_gen.configs.models.encoders.qwen3 import Qwen3TextConfig\nfrom sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import (\n    ImagePipelineConfig,\n    ModelTaskType,\n)\nfrom sglang.multimodal_gen.runtime.distributed.communication_op import (\n    sequence_model_parallel_all_gather,\n)\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_sp_parallel_rank,\n    get_sp_world_size,\n)\n\n\ndef zimage_preprocess_text(prompt: str):\n    messages = [\n        {\"role\": \"user\", \"content\": prompt},\n    ]\n    return messages\n\n\ndef zimage_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:\n    device = outputs.hidden_states[-2].device\n    prompt_mask = _text_inputs.attention_mask.to(device).bool()\n    return outputs.hidden_states[-2][0][prompt_mask[0]]\n\n\nclass TransformersModelConfig(EncoderConfig):\n    tokenizer_kwargs: dict = field(default_factory=lambda: {})\n\n\n@dataclass\nclass ZImagePipelineConfig(ImagePipelineConfig):\n    should_use_guidance: bool = False\n    task_type: ModelTaskType = ModelTaskType.T2I\n    dit_config: DiTConfig = field(default_factory=ZImageDitConfig)\n    vae_config: VAEConfig = field(default_factory=FluxVAEConfig)\n    text_encoder_configs: tuple[EncoderConfig, ...] = field(\n        default_factory=lambda: (Qwen3TextConfig(),)\n    )\n\n    preprocess_text_funcs: tuple[Callable, ...] 
= field(\n        default_factory=lambda: (zimage_preprocess_text,)\n    )\n    postprocess_text_funcs: tuple[Callable, ...] = field(\n        default_factory=lambda: (zimage_postprocess_text,)\n    )\n\n    SEQ_LEN_MULTIPLE: int = 32\n    PATCH_SIZE: int = 2\n    F_PATCH_SIZE: int = 1\n\n    def tokenize_prompt(self, prompts: list[str], tokenizer, tok_kwargs) -> dict:\n        # flatten to 1-d list\n        inputs = tokenizer.apply_chat_template(\n            prompts,\n            tokenize=True,\n            add_generation_prompt=True,\n            enable_thinking=True,\n            padding=\"max_length\",\n            max_length=512,  # TODO (yhyang201): set max length according to config\n            truncation=True,\n            return_tensors=\"pt\",\n            return_dict=True,\n        )\n        return inputs\n\n    @staticmethod\n    def _ceil_to_multiple(x: int, m: int) -> int:\n        if m <= 0:\n            return x\n        return int(math.ceil(x / m) * m)\n\n    def _build_zimage_sp_plan(self, batch) -> dict:\n        \"\"\"Build a minimal SP plan on batch for zimage (spatial sharding + cap sharding).\"\"\"\n        sp_size = get_sp_world_size()\n        rank = get_sp_parallel_rank()\n\n        raw_latent_shape = getattr(batch, \"raw_latent_shape\", None)\n        if raw_latent_shape is not None and len(raw_latent_shape) >= 5:\n            H = int(raw_latent_shape[3])\n            W = int(raw_latent_shape[4])\n        else:\n            H = int(\n                batch.height // self.vae_config.arch_config.spatial_compression_ratio\n            )\n            W = int(\n                batch.width // self.vae_config.arch_config.spatial_compression_ratio\n            )\n\n        # Rule: shard along the larger spatial dimension (W/H), implemented via optional H/W transpose.\n        # Choose the larger of H and W for sharding, so H_eff = max(H, W).\n        swap_hw = W > H\n        H_eff = W if swap_hw else H\n        W_eff = H if swap_hw else W\n\n   
     # ZImage uses PATCH_SIZE=2 for spatial patchify; shard in token space and convert back to latent rows.\n        H_tok = H_eff // self.PATCH_SIZE\n        W_tok = W_eff // self.PATCH_SIZE\n        H_tok_pad = self._ceil_to_multiple(H_tok, sp_size)\n        H_tok_local = H_tok_pad // sp_size\n        h0_tok = rank * H_tok_local\n\n        # Cap/text sharding: avoid duplicating cap tokens across ranks.\n        cap_len = (\n            int(batch.prompt_embeds[0].size(0))\n            if getattr(batch, \"prompt_embeds\", None)\n            else 0\n        )\n        cap_total = self._ceil_to_multiple(cap_len, self.SEQ_LEN_MULTIPLE * sp_size)\n        cap_local = cap_total // sp_size\n        cap_start = rank * cap_local\n\n        plan = {\n            \"sp_size\": sp_size,\n            \"rank\": rank,\n            \"swap_hw\": swap_hw,\n            \"H\": H,\n            \"W\": W,\n            \"H_eff\": H_eff,\n            \"W_eff\": W_eff,\n            \"H_tok\": H_tok,\n            \"W_tok\": W_tok,\n            \"H_tok_pad\": H_tok_pad,\n            \"H_tok_local\": H_tok_local,\n            \"h0_tok\": h0_tok,\n            \"cap_total\": cap_total,\n            \"cap_local\": cap_local,\n            \"cap_start\": cap_start,\n        }\n        batch._zimage_sp_plan = plan\n        return plan\n\n    def _get_zimage_sp_plan(self, batch) -> dict:\n        plan = getattr(batch, \"_zimage_sp_plan\", None)\n        sp_size = get_sp_world_size()\n        if plan is None or plan.get(\"sp_size\") != sp_size:\n            plan = self._build_zimage_sp_plan(batch)\n        return plan\n\n    def _shard_cap(self, cap: torch.Tensor, plan: dict) -> torch.Tensor:\n        \"\"\"cap: [L, D] -> [cap_local, D], padded by repeating last token.\"\"\"\n        if plan[\"sp_size\"] <= 1:\n            return cap\n        # print(f\"cap shape: {cap.shape}\")  # [L, 2560] for zimage-turbo\n        L = cap.size(0)\n        cap_total = plan[\"cap_total\"]\n        if cap_total > L:\n 
           cap = torch.cat([cap, cap[-1:].repeat(cap_total - L, 1)], dim=0)\n        start = plan[\"cap_start\"]\n        local = plan[\"cap_local\"]\n        return cap[start : start + local]\n\n    def get_pos_prompt_embeds(self, batch):\n        # Keep ZImage model signature: encoder_hidden_states is List[Tensor]\n        if get_sp_world_size() <= 1:\n            return batch.prompt_embeds\n        plan = self._get_zimage_sp_plan(batch)\n        return [self._shard_cap(batch.prompt_embeds[0], plan)]\n\n    def shard_latents_for_sp(self, batch, latents):\n        sp_size = get_sp_world_size()\n        if sp_size <= 1 or latents.dim() != 5:\n            return latents, False\n\n        plan = self._get_zimage_sp_plan(batch)\n\n        # Layout: [B, C, T, H, W]. Always shard on dim=3 by optionally swapping H/W.\n        if plan[\"swap_hw\"]:\n            latents = latents.transpose(3, 4).contiguous()\n\n        # Pad on effective-H so that H_tok is divisible by sp.\n        H_eff = latents.size(3)\n\n        H_tok = H_eff // self.PATCH_SIZE\n        pad_tok = plan[\"H_tok_pad\"] - H_tok\n        pad_lat = pad_tok * self.PATCH_SIZE\n        if pad_lat > 0:\n            pad = latents[:, :, :, -1:, :].repeat(1, 1, 1, pad_lat, 1)\n            latents = torch.cat([latents, pad], dim=3)\n        h0 = plan[\"h0_tok\"] * self.PATCH_SIZE\n        h1 = (plan[\"h0_tok\"] + plan[\"H_tok_local\"]) * self.PATCH_SIZE\n        latents = latents[:, :, :, h0:h1, :]\n\n        batch._zimage_sp_swap_hw = plan[\"swap_hw\"]\n        return latents, True\n\n    def gather_latents_for_sp(self, latents):\n        # Gather on effective-H dim=3 (matches shard_latents_for_sp); swap-back is handled in post_denoising_loop.\n        latents = latents.contiguous()\n        if get_sp_world_size() <= 1 or latents.dim() != 5:\n            return latents\n        return sequence_model_parallel_all_gather(latents, dim=3)\n\n    def post_denoising_loop(self, latents, batch):\n        # Restore swapped 
H/W and crop padded spatial dims before final reshape.\n        if latents.dim() == 5 and getattr(batch, \"_zimage_sp_swap_hw\", False):\n            latents = latents.transpose(3, 4).contiguous()\n        raw_latent_shape = getattr(batch, \"raw_latent_shape\", None)\n        if raw_latent_shape is not None and latents.dim() == 5:\n            latents = latents[:, :, :, : raw_latent_shape[3], : raw_latent_shape[4]]\n\n        bs, channels, num_frames, height, width = latents.shape\n        if raw_latent_shape is not None and num_frames > raw_latent_shape[2]:\n            latents = latents[:, :, : raw_latent_shape[2], :, :]\n            num_frames = raw_latent_shape[2]\n        if num_frames != 1:\n            return latents[:, :, 0, :, :]\n        return latents.view(bs, channels, height, width)\n\n    def get_freqs_cis(self, prompt_embeds, width, height, device, rotary_emb, batch):\n        def create_coordinate_grid(size, start=None, device=None):\n            if start is None:\n                start = (0 for _ in size)\n\n            axes = [\n                torch.arange(x0, x0 + span, dtype=torch.int32, device=device)\n                for x0, span in zip(start, size)\n            ]\n            grids = torch.meshgrid(axes, indexing=\"ij\")\n            return torch.stack(grids, dim=-1)\n\n        sp_size = get_sp_world_size()\n        if sp_size > 1:\n            # SP path: build local-only freqs_cis matching local cap/x.\n            plan = self._get_zimage_sp_plan(batch)\n\n            # cap (local)\n            cap_pos_ids = create_coordinate_grid(\n                size=(plan[\"cap_local\"], 1, 1),\n                start=(1 + plan[\"cap_start\"], 0, 0),\n                device=device,\n            ).flatten(0, 2)\n            cap_freqs_cis = rotary_emb(cap_pos_ids)\n\n            # image (local, effective H-shard). 
Use cap_total for a stable offset across ranks/passes.\n            F_tokens = 1\n            H_tokens_local = plan[\"H_tok_local\"]\n            W_tokens = plan[\"W_tok\"]\n            img_pos_ids = create_coordinate_grid(\n                size=(F_tokens, H_tokens_local, W_tokens),\n                start=(plan[\"cap_total\"] + 1, plan[\"h0_tok\"], 0),\n                device=device,\n            ).flatten(0, 2)\n            img_pad_len = (-img_pos_ids.shape[0]) % self.SEQ_LEN_MULTIPLE\n            if img_pad_len:\n                pad_ids = create_coordinate_grid(\n                    size=(1, 1, 1), start=(0, 0, 0), device=device\n                ).flatten(0, 2)\n                img_pos_ids = torch.cat(\n                    [img_pos_ids, pad_ids.repeat(img_pad_len, 1)], dim=0\n                )\n            x_freqs_cis = rotary_emb(img_pos_ids)\n            return (cap_freqs_cis, x_freqs_cis)\n\n        cap_ori_len = prompt_embeds.size(0)\n        cap_padding_len = (-cap_ori_len) % self.SEQ_LEN_MULTIPLE\n        cap_padded_pos_ids = create_coordinate_grid(\n            size=(cap_ori_len + cap_padding_len, 1, 1),\n            start=(1, 0, 0),\n            device=device,\n        ).flatten(0, 2)\n\n        F = 1\n        H = height // self.vae_config.arch_config.spatial_compression_ratio\n        W = width // self.vae_config.arch_config.spatial_compression_ratio\n\n        pH, pW = self.PATCH_SIZE, self.PATCH_SIZE\n        pF = self.F_PATCH_SIZE\n        F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW\n        image_ori_len = F_tokens * H_tokens * W_tokens\n        image_padding_len = (-image_ori_len) % self.SEQ_LEN_MULTIPLE\n\n        image_ori_pos_ids = create_coordinate_grid(\n            size=(F_tokens, H_tokens, W_tokens),\n            start=(cap_ori_len + cap_padding_len + 1, 0, 0),\n            device=device,\n        ).flatten(0, 2)\n        image_padding_pos_ids = (\n            create_coordinate_grid(\n                size=(1, 1, 1),\n             
   start=(0, 0, 0),\n                device=device,\n            )\n            .flatten(0, 2)\n            .repeat(image_padding_len, 1)\n        )\n        image_padded_pos_ids = torch.cat(\n            [image_ori_pos_ids, image_padding_pos_ids], dim=0\n        )\n        cap_freqs_cis = rotary_emb(cap_padded_pos_ids)\n        x_freqs_cis = rotary_emb(image_padded_pos_ids)\n        return (cap_freqs_cis, x_freqs_cis)\n\n    def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return {\n            \"freqs_cis\": self.get_freqs_cis(\n                batch.prompt_embeds[0],\n                batch.width,\n                batch.height,\n                device,\n                rotary_emb,\n                batch,\n            ),\n        }\n\n    def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):\n        return {\n            \"freqs_cis\": self.get_freqs_cis(\n                batch.prompt_embeds[0],\n                batch.width,\n                batch.height,\n                device,\n                rotary_emb,\n                batch,\n            ),\n        }\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/quantization.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nfrom __future__ import annotations\n\nimport os\nimport re\nfrom dataclasses import dataclass\nfrom typing import Any\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import (\n    is_nunchaku_available,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import StoreBoolean\n\nlogger = init_logger(__name__)\n\n\n@dataclass\nclass NunchakuSVDQuantArgs:\n    \"\"\"CLI-facing configuration for Nunchaku (SVDQuant) inference.\n\n    This is intentionally lightweight and only contains arguments needed to\n    construct `runtime.layers.quantization.nunchaku_config.NunchakuConfig`.\n    \"\"\"\n\n    enable_svdquant: bool = False\n    transformer_weights_path: str | None = None\n    quantization_precision: str | None = None  # \"int4\" or \"nvfp4\"\n    quantization_rank: int | None = None\n    quantization_act_unsigned: bool = False\n\n    def _adjust_config(self) -> None:\n        \"\"\"infer precision and rank from filename if not provided\"\"\"\n        if self.transformer_weights_path and not self.enable_svdquant:\n            filename = os.path.basename(self.transformer_weights_path)\n            if re.search(r\"svdq-(int4|fp4)_r(\\d+)\", filename):\n                self.enable_svdquant = True\n\n        if not self.enable_svdquant or not self.transformer_weights_path:\n            return\n\n        inferred_precision = None\n        inferred_rank = None\n\n        filename = os.path.basename(self.transformer_weights_path)\n        # Expected pattern: svdq-{precision}_r{rank}-...\n        # e.g., svdq-int4_r32-qwen-image.safetensors\n        match = re.search(r\"svdq-(int4|fp4)_r(\\d+)\", filename)\n\n        if match:\n            p_str, r_str = match.groups()\n            inferred_precision = \"nvfp4\" if p_str == \"fp4\" else 
\"int4\"\n            inferred_rank = int(r_str)\n\n        if self.quantization_precision is None:\n            self.quantization_precision = inferred_precision or \"int4\"\n            if inferred_precision:\n                logger.info(\n                    f\"inferred --quantization-precision: {self.quantization_precision} \"\n                    f\"from --transformer-weights-path: {self.transformer_weights_path}\"\n                )\n\n        if self.quantization_rank is None:\n            self.quantization_rank = inferred_rank or 32\n            if inferred_rank:\n                logger.info(\n                    f\"inferred --quantization-rank: {self.quantization_rank} \"\n                    f\"from --transformer-weights-path: {self.transformer_weights_path}\"\n                )\n\n    def validate(self) -> None:\n        # TODO: warn if the served model doesn't support nunchaku\n        self._adjust_config()\n\n        if not self.enable_svdquant:\n            return\n\n        if not current_platform.is_cuda():\n            raise ValueError(\n                \"Nunchaku SVDQuant is only supported on NVIDIA CUDA GPUs \"\n                \"(Ampere SM8x or SM12x).\"\n            )\n\n        device_count = torch.cuda.device_count()\n\n        unsupported: list[str] = []\n        for i in range(device_count):\n            major, minor = torch.cuda.get_device_capability(i)\n            if major == 9:\n                unsupported.append(f\"cuda:{i} (SM{major}{minor}, Hopper)\")\n            elif major not in (8, 12):\n                unsupported.append(f\"cuda:{i} (SM{major}{minor})\")\n\n        if unsupported:\n            raise ValueError(\n                \"Nunchaku SVDQuant is currently only supported on Ampere (SM8x) or SM12x GPUs; \"\n                \"Hopper (SM90) is not supported. \"\n                f\"Unsupported devices: {', '.join(unsupported)}. 
\"\n                \"Disable it with --enable-svdquant false.\"\n            )\n\n        if not self.transformer_weights_path:\n            raise ValueError(\n                \"--enable-svdquant requires --transformer-weights-path to be set\"\n            )\n\n        if not is_nunchaku_available():\n            raise ValueError(\n                \"Nunchaku is enabled, but not installed. Please refer to https://nunchaku.tech/docs/nunchaku/installation/installation.html for detailed installation methods.\"\n            )\n\n        if self.quantization_precision not in (\"int4\", \"nvfp4\"):\n            raise ValueError(\n                f\"Invalid --quantization-precision: {self.quantization_precision}. \"\n                \"Must be one of: int4, nvfp4\"\n            )\n\n        if self.quantization_rank <= 0:\n            raise ValueError(\n                f\"Invalid --quantization-rank: {self.quantization_rank}. Must be > 0\"\n            )\n\n    @staticmethod\n    def add_cli_args(parser) -> None:\n        parser.add_argument(\n            \"--enable-svdquant\",\n            action=StoreBoolean,\n            default=NunchakuSVDQuantArgs.enable_svdquant,\n            help=\"Enable Nunchaku SVDQuant (W4A4-style) inference.\",\n        )\n        parser.add_argument(\n            \"--transformer-weights-path\",\n            type=str,\n            default=NunchakuSVDQuantArgs.transformer_weights_path,\n            help=(\n                \"Path to pre-quantized transformer weights. Can be a single .safetensors \"\n                \"file, a directory, or a HuggingFace repo ID. Used by Nunchaku (SVDQuant) and quantized single-file checkpoints.\"\n            ),\n        )\n        parser.add_argument(\n            \"--quantization-precision\",\n            type=str,\n            default=None,\n            help=\"Quantization precision: int4 or nvfp4. 
If not specified, inferred from model path or defaults to int4.\",\n        )\n        parser.add_argument(\n            \"--quantization-rank\",\n            type=int,\n            default=None,\n            help=\"SVD low-rank dimension (e.g., 32). If not specified, inferred from model path or defaults to 32.\",\n        )\n        parser.add_argument(\n            \"--quantization-act-unsigned\",\n            action=StoreBoolean,\n            default=NunchakuSVDQuantArgs.quantization_act_unsigned,\n            help=\"Use unsigned activation quantization (if supported).\",\n        )\n\n    @classmethod\n    def from_dict(cls, kwargs: dict[str, Any]) -> \"NunchakuSVDQuantArgs\":\n        # Map CLI/config keys to dataclass fields (keep backwards compatibility).\n        path = (\n            kwargs.get(\"transformer_weights_path\")\n            or kwargs.get(\"transformer_quantized_path\")\n            or kwargs.get(\"quantized_model_path\")\n        )\n        return cls(\n            enable_svdquant=bool(kwargs.get(\"enable_svdquant\", cls.enable_svdquant)),\n            transformer_weights_path=path,\n            quantization_precision=kwargs.get(\"quantization_precision\"),\n            quantization_rank=kwargs.get(\"quantization_rank\"),\n            quantization_act_unsigned=bool(\n                kwargs.get(\"quantization_act_unsigned\", cls.quantization_act_unsigned)\n            ),\n        )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nfrom sglang.multimodal_gen.configs.sample.diffusers_generic import (\n    DiffusersGenericSamplingParams,\n)\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\n\n__all__ = [\"SamplingParams\", \"DiffusersGenericSamplingParams\"]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/diffusers_generic.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nGeneric sampling parameters for diffusers backend.\n\nThis module provides generic sampling parameters that work with any diffusers pipeline.\n\"\"\"\n\nfrom dataclasses import dataclass, field\nfrom typing import Any, ClassVar\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import (\n    DataType,\n    SamplingParams,\n)\n\n\n@dataclass\nclass DiffusersGenericSamplingParams(SamplingParams):\n    \"\"\"\n    Generic sampling parameters for diffusers backend.\n\n    These parameters cover the most common options across different diffusers pipelines.\n    The diffusers pipeline will use whichever parameters it supports.\n\n    For pipeline-specific parameters, use `diffusers_kwargs` dict which will be\n    passed directly to the diffusers pipeline call.\n    \"\"\"\n\n    _default_height: ClassVar[int] = 1024\n    _default_width: ClassVar[int] = 1024\n\n    # Override defaults with more conservative values that work across pipelines\n    num_frames: int = 1  # default to image generation\n    height: int = 1024\n    width: int = 1024\n    num_inference_steps: int = 30\n    guidance_scale: float = 7.5\n    negative_prompt: str = \"\"\n\n    # extra kwargs to pass directly to the diffusers pipeline\n    # example: {\"output_type\": \"latent\", \"return_dict\": False}\n    diffusers_kwargs: dict[str, Any] = field(default_factory=dict)\n\n    def __post_init__(self) -> None:\n        if self.num_frames > 1:\n            self.data_type = DataType.VIDEO\n        else:\n            self.data_type = DataType.IMAGE\n\n        super().__post_init__()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/flux.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass\nfrom typing import ClassVar\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\n\n\n@dataclass\nclass FluxSamplingParams(SamplingParams):\n    _default_height: ClassVar[int] = 128 * 8  # default_sample_size * vae_scale_factor\n    _default_width: ClassVar[int] = 128 * 8\n\n    num_frames: int = 1\n    # Denoising stage\n    guidance_scale: float = 1.0\n    negative_prompt: str | None = None\n    num_inference_steps: int = 50\n\n\n@dataclass\nclass Flux2KleinSamplingParams(FluxSamplingParams):\n    # Klein is step-distilled, so default to 4 steps\n    num_inference_steps: int = 4\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/glmimage.py",
    "content": "from dataclasses import dataclass\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\n\n\n@dataclass\nclass GlmImageSamplingParams(SamplingParams):\n    negative_prompt: str = \"\"\n\n    num_frames: int = 1\n    guidance_scale: float = 1.5\n    num_inference_steps: int = 30\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/helios.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\n\n\n@dataclass\nclass HeliosT2VSamplingParams(SamplingParams):\n    # Video parameters\n    height: int = 384\n    width: int = 640\n    num_frames: int = 99\n    fps: int = 24\n\n    # Denoising stage\n    guidance_scale: float = 5.0\n    negative_prompt: str = (\n        \"Bright tones, overexposed, static, blurred details, subtitles, style, \"\n        \"works, paintings, images, static, overall gray, worst quality, low quality, \"\n        \"JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, \"\n        \"poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, \"\n        \"still picture, messy background, three legs, many people in the background, \"\n        \"walking backwards\"\n    )\n    num_inference_steps: int = 50\n\n    # Helios T2V supported resolutions\n    supported_resolutions: list[tuple[int, int]] | None = field(\n        default_factory=lambda: [\n            (640, 384),  # ~5:3\n            (384, 640),  # ~3:5\n            (832, 480),  # ~16:9-ish\n            (480, 832),  # ~9:16-ish\n        ]\n    )\n\n\n@dataclass\nclass HeliosMidSamplingParams(HeliosT2VSamplingParams):\n    \"\"\"Sampling params for Helios-Mid (Stage 2 pyramid SR).\"\"\"\n\n    num_inference_steps: int = 20\n\n\n@dataclass\nclass HeliosDistilledSamplingParams(HeliosT2VSamplingParams):\n    \"\"\"Sampling params for Helios-Distilled (DMD, no CFG needed).\"\"\"\n\n    guidance_scale: float = 1.0\n    num_inference_steps: int = 10\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/hunyuan.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\nfrom sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams\n\n\n@dataclass\nclass HunyuanSamplingParams(SamplingParams):\n    num_inference_steps: int = 50\n\n    num_frames: int = 125\n    height: int = 720\n    width: int = 1280\n    fps: int = 24\n\n    guidance_scale: float = 1.0\n\n    # HunyuanVideo supported resolutions\n    supported_resolutions: list[tuple[int, int]] | None = field(\n        default_factory=lambda: [\n            # 540p resolutions\n            (960, 544),  # 9:16\n            (544, 960),  # 16:9\n            (832, 624),  # 4:3\n            (624, 832),  # 3:4\n            (720, 720),  # 1:1\n            # 720p resolutions (recommended)\n            (1280, 720),  # 9:16\n            (720, 1280),  # 16:9\n            (832, 1104),  # 4:3\n            (1104, 832),  # 3:4\n            (960, 960),  # 1:1\n        ]\n    )\n\n    teacache_params: TeaCacheParams = field(\n        default_factory=lambda: TeaCacheParams(\n            teacache_thresh=0.15,\n            coefficients=[\n                7.33226126e02,\n                -4.01131952e02,\n                6.75869174e01,\n                -3.14987800e00,\n                9.61237896e-02,\n            ],\n        )\n    )\n\n\n@dataclass\nclass FastHunyuanSamplingParam(HunyuanSamplingParams):\n    num_inference_steps: int = 6\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/hunyuan3d.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"Sampling parameters for Hunyuan3D generation.\"\"\"\n\nfrom dataclasses import dataclass\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\n\n\n@dataclass\nclass Hunyuan3DSamplingParams(SamplingParams):\n    \"\"\"Sampling parameters for Hunyuan3D image-to-mesh generation.\"\"\"\n\n    negative_prompt: str = \"\"\n\n    shape_num_inference_steps: int = 50\n    guidance_scale: float = 5.0\n\n    paint_num_inference_steps: int = 30\n    paint_guidance_scale: float = 2.0\n\n    def __post_init__(self):\n        if self.prompt is None:\n            self.prompt = \"\"\n\n        if self.num_inference_steps is None:\n            self.num_inference_steps = self.shape_num_inference_steps\n\n        self.guidance_scale = max(5.0, min(self.guidance_scale, 6.5))\n        super().__post_init__()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/ltx_2.py",
    "content": "import dataclasses\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\n\n\n@dataclasses.dataclass\nclass LTX2SamplingParams(SamplingParams):\n    \"\"\"Sampling parameters for LTX-2.\"\"\"\n\n    # Match the reference defaults used by ltx-pipelines (one-stage).\n    # See: LTX-2/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py\n    seed: int = 10\n\n    # Video parameters\n    height: int = 512\n    width: int = 768\n    num_frames: int = 121\n    fps: int = 24\n\n    # Audio specific\n    generate_audio: bool = True\n\n    # Denoising parameters\n    guidance_scale: float = 4.0\n    num_inference_steps: int = 40\n\n    # Match ltx-pipelines default negative prompt (covers video + audio artifacts).\n    negative_prompt: str = (\n        \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n        \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n        \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n        \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n        \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n        \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n        \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n        \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n        \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n        \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n        \"inconsistent tone, cinematic oversaturation, 
stylized filters, or AI artifacts.\"\n    )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/mova.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\n\n\n@dataclass\nclass MOVASamplingParams(SamplingParams):\n    # Video parameters (MOVA defaults)\n    height: int = 352\n    width: int = 640\n    num_frames: int = 193\n    fps: int = 24\n\n    # Denoising stage\n    guidance_scale: float = 5.0\n    num_inference_steps: int = 50\n    sigma_shift: float = 5.0\n    visual_shift: float = 5.0\n    audio_shift: float = 5.0\n\n    adjust_frames: bool = False\n\n    negative_prompt: str = (\n        \"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，\"\n        \"整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，\"\n        \"画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，\"\n        \"静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\"\n    )\n\n\n@dataclass\nclass MOVA_360P_SamplingParams(MOVASamplingParams):\n    # Video parameters (MOVA 360P)\n    height: int = 352\n    width: int = 640\n\n    # MOVA 360P supported resolutions\n    supported_resolutions: list[tuple[int, int]] = field(\n        default_factory=lambda: [\n            (352, 640),\n            (640, 352),\n        ]\n    )\n\n\n@dataclass\nclass MOVA_720P_SamplingParams(MOVASamplingParams):\n    # Video parameters (MOVA 720P)\n    height: int = 720\n    width: int = 1280\n\n    # MOVA 720P supported resolutions\n    supported_resolutions: list[tuple[int, int]] = field(\n        default_factory=lambda: [\n            (720, 1280),\n            (1280, 720),\n        ]\n    )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/qwenimage.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\n\n\n@dataclass\nclass QwenImageSamplingParams(SamplingParams):\n    negative_prompt: str = \" \"\n    num_frames: int = 1\n    # Denoising stage\n    guidance_scale: float = 4.0\n    num_inference_steps: int = 50\n\n\n@dataclass\nclass QwenImage2512SamplingParams(QwenImageSamplingParams):\n    negative_prompt: str = (\n        \"低分辨率，低画质，肢体畸形，手指畸形，画面过饱和，蜡像感，人脸无细节，过度光滑，画面具有AI感。构图混乱。文字模糊，扭曲。\"\n    )\n\n\n@dataclass\nclass QwenImageEditPlusSamplingParams(QwenImageSamplingParams):\n    # Denoising stage\n    guidance_scale: float = 4.0\n    # true_cfg_scale: float = 4.0\n    num_inference_steps: int = 40\n\n\n@dataclass\nclass QwenImageLayeredSamplingParams(QwenImageSamplingParams):\n    # num_frames: int = 4\n    height: int = 640\n    width: int = 640\n    prompt: str = \" \"\n    negative_prompt: str = \" \"\n\n    guidance_scale: float = 4.0\n    num_inference_steps: int = 50\n    cfg_normalize: bool = True\n    use_en_prompt: bool = True\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/sampling_params.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nimport argparse\nimport dataclasses\nimport hashlib\nimport json\nimport math\nimport os\nimport os.path\nimport re\nimport time\nimport unicodedata\nimport uuid\nfrom dataclasses import dataclass\nfrom enum import Enum, auto\nfrom typing import TYPE_CHECKING, Any, ClassVar\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import StoreBoolean, expand_path_fields\n\nlogger = init_logger(__name__)\n\nif TYPE_CHECKING:\n    from sglang.multimodal_gen.runtime.server_args import ServerArgs\n\n\ndef _json_safe(obj: Any):\n    \"\"\"\n    Recursively convert objects to JSON-serializable forms.\n    - Enums -> their name\n    - Sets/Tuples -> lists\n    - Dicts/Lists -> recursively processed\n    \"\"\"\n    if isinstance(obj, Enum):\n        return obj.name\n    if isinstance(obj, dict):\n        return {k: _json_safe(v) for k, v in obj.items()}\n    if isinstance(obj, (list, tuple, set)):\n        return [_json_safe(v) for v in obj]\n    return obj\n\n\ndef generate_request_id() -> str:\n    return str(uuid.uuid4())\n\n\ndef _sanitize_filename(name: str, replacement: str = \"_\", max_length: int = 150) -> str:\n    \"\"\"Create a filesystem- and ffmpeg-friendly filename.\n\n    - Normalize to ASCII (drop accents and unsupported chars)\n    - Replace spaces with underscores\n    - Replace any char not in [A-Za-z0-9_.-] with replacement\n    - Collapse multiple underscores\n    - Trim leading/trailing dots/underscores and limit length\n    \"\"\"\n    normalized = unicodedata.normalize(\"NFKD\", name)\n    ascii_name = normalized.encode(\"ascii\", \"ignore\").decode(\"ascii\")\n    ascii_name = ascii_name.replace(\" \", \"_\")\n    ascii_name = re.sub(r\"[^A-Za-z0-9._-]\", replacement, ascii_name)\n    ascii_name = re.sub(r\"_+\", \"_\", ascii_name).strip(\"._\")\n    if not ascii_name:\n   
     ascii_name = \"output\"\n    if max_length and len(ascii_name) > max_length:\n        ascii_name = ascii_name[:max_length]\n    return ascii_name\n\n\nclass DataType(Enum):\n    IMAGE = auto()\n    VIDEO = auto()\n    MESH = auto()\n\n    def get_default_extension(self) -> str:\n        if self == DataType.IMAGE:\n            return \"png\"\n        if self == DataType.VIDEO:\n            return \"mp4\"\n        return \"glb\"\n\n\n@dataclass\nclass SamplingParams:\n    \"\"\"\n    Sampling parameters for generation.\n    \"\"\"\n\n    data_type: DataType = DataType.VIDEO\n\n    request_id: str | None = None\n\n    # All fields below are copied from ForwardBatch\n\n    # Image inputs\n    image_path: str | list[str] | None = None\n\n    # Text inputs\n    prompt: str | list[str] | None = None\n    negative_prompt: str = (\n        \"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\"\n    )\n    prompt_path: str | None = None\n    output_path: str | None = None\n    output_file_name: str | None = None\n    output_quality: str | None = \"default\"\n    output_compression: int | None = None\n\n    # Frame interpolation\n    enable_frame_interpolation: bool = False\n    frame_interpolation_exp: int = 1  # 1=2x, 2=4x\n    frame_interpolation_scale: float = 1.0  # RIFE inference scale (0.5 for high-res)\n    frame_interpolation_model_path: str | None = (\n        None  # local dir or HF repo ID with flownet.pkl (default: elfgum/RIFE-4.22.lite)\n    )\n\n    # Upscaling\n    enable_upscaling: bool = False\n    upscaling_model_path: str | None = (\n        None  # local .pth, HF repo ID, or repo_id:filename (default: 
ai-forever/Real-ESRGAN)\n    )\n    upscaling_scale: int = 4\n\n    # Batch info\n    num_outputs_per_prompt: int = 1\n    seed: int = 42\n    generator_device: str = \"cuda\"  # Device for random generator: \"cuda\" or \"cpu\"\n\n    # Original dimensions (before VAE scaling)\n    num_frames: int = 1  # Default for image models\n    num_frames_round_down: bool = (\n        False  # Whether to round down num_frames if it's not divisible by num_gpus\n    )\n\n    # Subclasses can set these to provide model-specific default resolutions.\n    # The base __post_init__ will apply them when height/width are not provided.\n    _default_height: ClassVar[int | None] = None\n    _default_width: ClassVar[int | None] = None\n\n    height: int | None = None\n    width: int | None = None\n    fps: int = 24\n\n    # Resolution validation\n    supported_resolutions: list[tuple[int, int]] | None = (\n        None  # None means all resolutions allowed\n    )\n\n    # Denoising parameters\n    num_inference_steps: int | None = None\n    guidance_scale: float = 1.0\n    guidance_scale_2: float | None = None\n    true_cfg_scale: float | None = None  # for CFG vs guidance distillation (e.g., QwenImage)\n    guidance_rescale: float = 0.0\n    cfg_normalization: float | bool = 0.0\n    boundary_ratio: float | None = None\n\n    # TeaCache parameters\n    enable_teacache: bool = False\n    teacache_params: Any = (\n        None  # TeaCacheParams or WanTeaCacheParams, set by model-specific subclass\n    )\n\n    # Profiling\n    profile: bool = False\n    num_profiled_timesteps: int = 5\n    profile_all_stages: bool = False\n\n    # Debugging\n    debug: bool = False\n    perf_dump_path: str | None = None\n\n    # Misc\n    save_output: bool = True\n    return_frames: bool = False\n    return_trajectory_latents: bool = False  # returns all latents for each timestep\n    return_trajectory_decoded: bool = False  # returns decoded latents for each timestep\n    # if True, disallow user params to override 
subclass-defined protected fields\n    no_override_protected_fields: bool = False\n    # whether to adjust num_frames for multi-GPU friendly splitting (default: True)\n    adjust_frames: bool = True\n    # if True, suppress verbose logging for this request\n    suppress_logs: bool = False\n\n    return_file_paths_only: bool = True\n    enable_sequence_shard: bool | None = None\n\n    def _set_output_file_ext(self):\n        # add extension if needed\n        if not any(\n            self.output_file_name.endswith(ext)\n            for ext in [\".mp4\", \".jpg\", \".png\", \".webp\", \".obj\", \".glb\"]\n        ):\n            self.output_file_name = (\n                f\"{self.output_file_name}.{self.data_type.get_default_extension()}\"\n            )\n\n    def _set_output_file_name(self):\n        # settle output_file_name\n        if (\n            self.output_file_name is None\n            and self.prompt\n            and isinstance(self.prompt, str)\n        ):\n            # generate a random filename\n            # get a hash of current params\n            params_dict = dataclasses.asdict(self)\n            # Avoid recursion\n            params_dict[\"output_file_name\"] = \"\"\n\n            # Convert to a stable JSON string\n            params_str = json.dumps(_json_safe(params_dict), sort_keys=True)\n            # Create a hash\n            hasher = hashlib.sha256()\n            hasher.update(params_str.encode(\"utf-8\"))\n            param_hash = hasher.hexdigest()[:8]\n\n            timestamp = time.strftime(\"%Y%m%d-%H%M%S\")\n            base = f\"{self.prompt[:100]}_{timestamp}_{param_hash}\"\n            self.output_file_name = base\n\n        if self.output_file_name is None:\n            timestamp = time.strftime(\"%Y%m%d-%H%M%S\")\n            self.output_file_name = f\"output_{timestamp}\"\n\n        self.output_file_name = _sanitize_filename(self.output_file_name)\n\n        # Ensure a proper extension is present\n        
self._set_output_file_ext()\n\n    def __post_init__(self) -> None:\n        assert self.num_frames >= 1\n\n        if self.width is None and self._default_width is not None:\n            self.width = self._default_width\n        if self.height is None and self._default_height is not None:\n            self.height = self._default_height\n\n        # Handle output_quality to output_compression conversion\n        if self.output_compression is None and self.output_quality is not None:\n            self.output_compression = self._adjust_output_quality(\n                self.output_quality, self.data_type\n            )\n\n        self._validate()\n\n        # Allow env var to override num_inference_steps (for faster CI testing on AMD)\n        env_steps = os.environ.get(\"SGLANG_TEST_NUM_INFERENCE_STEPS\")\n        if env_steps is not None and self.num_inference_steps is not None:\n            self.num_inference_steps = int(env_steps)\n\n    def _adjust_output_quality(self, output_quality: str, data_type: DataType) -> int:\n        \"\"\"Convert output_quality string to compression level.\"\"\"\n        output_quality_mapper = {\"maximum\": 100, \"high\": 90, \"medium\": 55, \"low\": 35}\n        if output_quality == \"default\":\n            return 50 if data_type == DataType.VIDEO else 75\n        return output_quality_mapper.get(output_quality)\n\n    def _validate(self):\n        \"\"\"\n        Check whether the sampling params are valid on their own, independent of the pipeline.\n        \"\"\"\n        if self.prompt_path and not self.prompt_path.endswith(\".txt\"):\n            raise ValueError(\n                f\"prompt_path must be a txt file, got {self.prompt_path!r}\"\n            )\n\n        # These are always required to be sane regardless of pipeline.\n        if (\n            not isinstance(self.num_outputs_per_prompt, int)\n            or self.num_outputs_per_prompt <= 0\n        ):\n            raise ValueError(\n                f\"num_outputs_per_prompt must be a positive int, got 
{self.num_outputs_per_prompt!r}\"\n            )\n\n        # Used by seconds() and video writer; fps <= 0 is always invalid.\n        if not isinstance(self.fps, int) or self.fps <= 0:\n            raise ValueError(f\"fps must be a positive int, got {self.fps!r}\")\n\n        # num_frames is already asserted in __post_init__, but keep a friendly error here too\n        # (e.g., when validation is triggered from other code paths).\n        if not isinstance(self.num_frames, int) or self.num_frames <= 0:\n            raise ValueError(\n                f\"num_frames must be a positive int, got {self.num_frames!r}\"\n            )\n\n        if self.num_inference_steps is not None:\n            if (\n                not isinstance(self.num_inference_steps, int)\n                or self.num_inference_steps <= 0\n            ):\n                raise ValueError(\n                    f\"num_inference_steps must be a positive int, got {self.num_inference_steps!r}\"\n                )\n\n        # Numeric hyperparams should not be NaN/Inf and should be within basic ranges.\n        # Note: bool is a subclass of int; reject it explicitly to avoid silent surprises.\n        def _finite_non_negative_float(\n            name: str, value: Any, allow_none: bool = True\n        ) -> None:\n            if value is None and allow_none:\n                return\n            if isinstance(value, bool) or not isinstance(value, (int, float)):\n                raise ValueError(f\"{name} must be a number, got {value!r}\")\n            if not math.isfinite(float(value)):\n                raise ValueError(f\"{name} must be finite, got {value!r}\")\n            if float(value) < 0.0:\n                raise ValueError(f\"{name} must be non-negative, got {value!r}\")\n\n        _finite_non_negative_float(\n            \"guidance_scale\", self.guidance_scale, allow_none=True\n        )\n        _finite_non_negative_float(\n            \"guidance_scale_2\", self.guidance_scale_2, 
allow_none=True\n        )\n        _finite_non_negative_float(\n            \"true_cfg_scale\", self.true_cfg_scale, allow_none=True\n        )\n        _finite_non_negative_float(\n            \"guidance_rescale\", self.guidance_rescale, allow_none=False\n        )\n\n        if self.cfg_normalization is None:\n            self.cfg_normalization = 0.0\n        elif isinstance(self.cfg_normalization, bool):\n            self.cfg_normalization = 1.0 if self.cfg_normalization else 0.0\n\n        if self.boundary_ratio is not None:\n            if isinstance(self.boundary_ratio, bool) or not isinstance(\n                self.boundary_ratio, (int, float)\n            ):\n                raise ValueError(\n                    f\"boundary_ratio must be a number, got {self.boundary_ratio!r}\"\n                )\n            if not math.isfinite(float(self.boundary_ratio)):\n                raise ValueError(\n                    f\"boundary_ratio must be finite, got {self.boundary_ratio!r}\"\n                )\n            if not (0.0 <= float(self.boundary_ratio) <= 1.0):\n                raise ValueError(\n                    f\"boundary_ratio must be within [0, 1], got {self.boundary_ratio!r}\"\n                )\n\n    def check_sampling_param(self):\n        # Keep backward-compatibility for old call sites.\n        self._validate()\n\n    def _validate_with_pipeline_config(self, pipeline_config):\n        \"\"\"\n        Check whether the sampling params are compatible with the given pipeline config.\n        \"\"\"\n        if pipeline_config.task_type.requires_image_input():\n            # requires image input\n            if self.image_path is None:\n                raise ValueError(\n                    f\"Served model with task type '{pipeline_config.task_type.name}' requires an 'image_path' input, but none was provided\"\n                )\n\n        if not pipeline_config.task_type.accepts_image_input():\n            # does not support image input\n            if 
self.image_path is not None:\n                raise ValueError(\n                    f\"input_reference is not supported for {pipeline_config.task_type.name} models.\"\n                )\n\n    def _adjust(\n        self,\n        server_args,\n    ):\n        \"\"\"\n        final adjustment, called after merged with user params\n        \"\"\"\n        expand_path_fields(self)\n\n        # TODO: SamplingParams should not rely on ServerArgs\n        pipeline_config = server_args.pipeline_config\n\n        if self.guidance_scale is None:\n            try:\n                from sglang.multimodal_gen.configs.pipeline_configs.hunyuan3d import (\n                    Hunyuan3D2PipelineConfig,\n                )\n\n                if isinstance(pipeline_config, Hunyuan3D2PipelineConfig):\n                    self.guidance_scale = pipeline_config.guidance_scale\n                else:\n                    self.guidance_scale = 1.0\n            except ImportError:\n                self.guidance_scale = 1.0\n\n        self.data_type = server_args.pipeline_config.task_type.data_type()\n\n        if self.output_path is None:\n            if server_args.output_path is not None:\n                self.output_path = server_args.output_path\n                logger.debug(\n                    f\"Overriding output_path with server configuration: {self.output_path}\"\n                )\n            else:\n                self.save_output = False\n\n        # Process negative prompt\n        if self.negative_prompt is not None and not self.negative_prompt.isspace():\n            # avoid stripping default negative prompt: ' ' for qwen-image\n            self.negative_prompt = self.negative_prompt.strip()\n\n        # Validate dimensions\n        if self.num_frames <= 0:\n            raise ValueError(\n                f\"height, width, and num_frames must be positive integers, got \"\n                f\"height={self.height}, width={self.width}, \"\n                
f\"num_frames={self.num_frames}\"\n            )\n\n        # Validate resolution against pipeline-specific supported resolutions\n        if self.height is None and self.width is None:\n            if self.supported_resolutions is not None:\n                self.width, self.height = self.supported_resolutions[0]\n                logger.info(\n                    f\"Resolution unspecified, using default: {self.supported_resolutions[0]}\"\n                )\n\n        if self.height is not None and self.width is not None:\n            if self.supported_resolutions is not None:\n                if (self.width, self.height) not in self.supported_resolutions:\n                    supported_str = \", \".join(\n                        [f\"{w}x{h}\" for w, h in self.supported_resolutions]\n                    )\n                    error_msg = (\n                        f\"Unsupported resolution: {self.width}x{self.height}, output quality may suffer. \"\n                        f\"Supported resolutions: {supported_str}\"\n                    )\n                    logger.warning(error_msg)\n\n        pipeline_name_lower = server_args.pipeline_config.__class__.__name__.lower()\n\n        if (\"wan\" in pipeline_name_lower or \"helios\" in pipeline_name_lower) and (\n            self.enable_sequence_shard is None or self.enable_sequence_shard\n        ):\n            self.enable_sequence_shard = True\n            logger.debug(\"Automatically enabled enable_sequence_shard\")\n        else:\n            self.enable_sequence_shard = False\n\n        if self.enable_sequence_shard:\n            self.adjust_frames = False\n            logger.info(\n                f\"Sequence dimension shard is enabled, disabling frame adjustment for better performance\"\n            )\n\n        if pipeline_config.task_type.is_image_gen():\n            # settle num_frames\n            if not server_args.pipeline_config.allow_set_num_frames():\n                logger.debug(f\"Setting `num_frames` 
to 1 for image generation model\")\n                self.num_frames = 1\n\n        else:\n            # mandatory frame adjusting logic, mod\n            # NOTE: We must apply adjust_num_frames BEFORE the SP alignment logic below.\n            # If we apply it after, adjust_num_frames might modify the frame count\n            # and break the divisibility constraint (alignment) required by num_gpus.\n            original_num_frames = self.num_frames\n            self.num_frames = server_args.pipeline_config.adjust_num_frames(\n                original_num_frames\n            )\n            logger.info(\n                \"Adjusting number of frames from %s to %s based on model\",\n                original_num_frames,\n                self.num_frames,\n            )\n\n            if self.adjust_frames:\n                # Adjust number of frames based on number of GPUs for video task\n                use_temporal_scaling_frames = (\n                    pipeline_config.vae_config.use_temporal_scaling_frames\n                )\n                num_frames = self.num_frames\n                num_gpus = server_args.num_gpus\n                temporal_scale_factor = (\n                    pipeline_config.vae_config.arch_config.temporal_compression_ratio\n                )\n\n                if use_temporal_scaling_frames:\n                    orig_latent_num_frames = (\n                        num_frames - 1\n                    ) // temporal_scale_factor + 1\n                else:\n                    orig_latent_num_frames = num_frames\n\n                if orig_latent_num_frames % server_args.num_gpus != 0:\n                    # Adjust latent frames to be divisible by number of GPUs\n                    if self.num_frames_round_down:\n                        # Ensure we have at least 1 batch per GPU\n                        new_latent_num_frames = (\n                            max(1, (orig_latent_num_frames // num_gpus)) * num_gpus\n                        )\n            
        else:\n                        new_latent_num_frames = (\n                            math.ceil(orig_latent_num_frames / num_gpus) * num_gpus\n                        )\n\n                    if use_temporal_scaling_frames:\n                        # Convert back to number of frames, ensuring num_frames-1 is a multiple of temporal_scale_factor\n                        new_num_frames = (\n                            new_latent_num_frames - 1\n                        ) * temporal_scale_factor + 1\n                    else:\n                        new_num_frames = new_latent_num_frames\n\n                    logger.info(\n                        \"Adjusting number of frames from %s to %s based on number of GPUs (%s)\",\n                        self.num_frames,\n                        new_num_frames,\n                        server_args.num_gpus,\n                    )\n                    self.num_frames = new_num_frames\n\n        if not server_args.comfyui_mode:\n            self._set_output_file_name()\n\n    @classmethod\n    def from_pretrained(cls, model_path: str, **kwargs) -> \"SamplingParams\":\n        from sglang.multimodal_gen.registry import get_model_info\n\n        backend = kwargs.pop(\"backend\", None)\n        model_id = kwargs.pop(\"model_id\", None)\n        model_info = get_model_info(model_path, backend=backend, model_id=model_id)\n        sampling_params: SamplingParams = model_info.sampling_param_cls(**kwargs)\n        return sampling_params\n\n    @staticmethod\n    def from_user_sampling_params_args(\n        model_path: str, server_args: \"ServerArgs\", *args, **kwargs\n    ):\n        try:\n            sampling_params = SamplingParams.from_pretrained(\n                model_path, backend=server_args.backend, model_id=server_args.model_id\n            )\n        except (AttributeError, ValueError) as e:\n            # Handle safetensors files or other cases where model_index.json is not available\n            # Use appropriate 
SamplingParams based on pipeline_class_name from registry\n            if os.path.isfile(model_path) and model_path.endswith(\".safetensors\"):\n                # Determine which sampling params to use based on pipeline_class_name\n                pipeline_class_name = getattr(server_args, \"pipeline_class_name\", None)\n\n                # Try to get SamplingParams from registry\n                from sglang.multimodal_gen.registry import get_pipeline_config_classes\n\n                config_classes = (\n                    get_pipeline_config_classes(pipeline_class_name)\n                    if pipeline_class_name\n                    else None\n                )\n\n                if config_classes is not None:\n                    _, sampling_params_cls = config_classes\n                    try:\n                        sampling_params = sampling_params_cls()\n                        logger.info(\n                            f\"Using {sampling_params_cls.__name__} for {pipeline_class_name} safetensors file (no model_index.json): %s\",\n                            model_path,\n                        )\n                    except Exception as import_error:\n                        logger.warning(\n                            f\"Failed to instantiate {sampling_params_cls.__name__}: {import_error}. 
\"\n                            \"Using default SamplingParams\"\n                        )\n                        sampling_params = SamplingParams()\n                else:\n                    raise ValueError(\n                        f\"Could not get pipeline config classes for {pipeline_class_name}\"\n                    )\n            else:\n                # Re-raise if it's not a safetensors file issue\n                raise\n\n        user_kwargs = dict(kwargs)\n        user_kwargs.pop(\"diffusers_kwargs\", None)\n        user_sampling_params = SamplingParams(*args, **user_kwargs)\n        # TODO: refactor\n        sampling_params._merge_with_user_params(\n            user_sampling_params, explicit_fields=set(user_kwargs.keys())\n        )\n        sampling_params._adjust(server_args)\n\n        sampling_params._validate_with_pipeline_config(server_args.pipeline_config)\n\n        return sampling_params\n\n    def output_size_str(self) -> str:\n        return f\"{self.width}x{self.height}\"\n\n    def seconds(self) -> float:\n        return self.num_frames / self.fps\n\n    @staticmethod\n    def add_cli_args(parser: Any) -> Any:\n        \"\"\"Add CLI arguments for SamplingParam fields\"\"\"\n\n        def add_argument(*name_or_flags, **kwargs):\n            kwargs.setdefault(\"default\", argparse.SUPPRESS)\n            return parser.add_argument(*name_or_flags, **kwargs)\n\n        add_argument(\"--data-type\", type=str, nargs=\"+\")\n        add_argument(\n            \"--num-frames-round-down\",\n            action=\"store_true\",\n        )\n        add_argument(\n            \"--enable-teacache\",\n            action=\"store_true\",\n        )\n\n        # profiling\n        add_argument(\n            \"--profile\",\n            action=\"store_true\",\n            help=\"Enable torch profiler for denoising stage\",\n        )\n        add_argument(\n            \"--num-profiled-timesteps\",\n            type=int,\n            help=\"Number of 
timesteps to profile after warmup\",\n        )\n        add_argument(\n            \"--profile-all-stages\",\n            action=\"store_true\",\n            dest=\"profile_all_stages\",\n            help=\"Used with --profile, profile all pipeline stages\",\n        )\n\n        add_argument(\n            \"--debug\",\n            action=\"store_true\",\n            help=\"\",\n        )\n\n        add_argument(\n            \"--prompt\",\n            type=str,\n            nargs=\"+\",\n            help=\"Text prompt(s) for generation. Use space-separated values for multiple prompts, e.g., --prompt 'prompt 1' 'prompt 2'\",\n        )\n        add_argument(\n            \"--negative-prompt\",\n            type=str,\n            help=\"Negative text prompt for generation\",\n        )\n        add_argument(\n            \"--prompt-path\",\n            type=str,\n            help=\"Path to a text file containing the prompt\",\n        )\n        add_argument(\n            \"--output-file-name\",\n            type=str,\n            help=\"Name of the output file\",\n        )\n        add_argument(\n            \"--output-quality\",\n            type=str,\n            help=\"Output quality setting (default, low, medium, high, maximum)\",\n        )\n        add_argument(\n            \"--output-compression\",\n            type=int,\n            help=\"Output compression level (0-100, higher means better quality but larger file size)\",\n        )\n        add_argument(\n            \"--num-outputs-per-prompt\",\n            type=int,\n            help=\"Number of outputs to generate per prompt\",\n        )\n        add_argument(\n            \"--seed\",\n            type=int,\n            help=\"Random seed for generation\",\n        )\n        add_argument(\n            \"--generator-device\",\n            type=str,\n            choices=[\"cuda\", \"musa\", \"cpu\"],\n            help=\"Device for random generator (cuda, musa or cpu). 
Default: cuda\",\n        )\n        add_argument(\n            \"--num-frames\",\n            type=int,\n            help=\"Number of frames to generate\",\n        )\n        add_argument(\n            \"--height\",\n            type=int,\n            help=\"Height of generated output\",\n        )\n        add_argument(\n            \"--width\",\n            type=int,\n            help=\"Width of generated output\",\n        )\n        # resolution shortcuts\n        add_argument(\n            \"--4k\",\n            action=\"store_true\",\n            dest=\"resolution_4k\",\n            help=\"Set resolution to 4K (3840x2160)\",\n        )\n        add_argument(\n            \"--2k\",\n            action=\"store_true\",\n            dest=\"resolution_2k\",\n            help=\"Set resolution to 2K (2560x1440)\",\n        )\n        add_argument(\n            \"--1080p\",\n            action=\"store_true\",\n            dest=\"resolution_1080p\",\n            help=\"Set resolution to 1080p (1920x1080)\",\n        )\n        add_argument(\n            \"--720p\",\n            action=\"store_true\",\n            dest=\"resolution_720p\",\n            help=\"Set resolution to 720p (1280x720)\",\n        )\n\n        add_argument(\n            \"--fps\",\n            type=int,\n            help=\"Frames per second for saved output\",\n        )\n        add_argument(\n            \"--num-inference-steps\",\n            type=int,\n            help=\"Number of denoising steps\",\n        )\n        add_argument(\n            \"--guidance-scale\",\n            type=float,\n            help=\"Classifier-free guidance scale\",\n        )\n        add_argument(\n            \"--guidance-scale-2\",\n            type=float,\n            dest=\"guidance_scale_2\",\n            help=\"Secondary guidance scale for dual-guidance models (e.g., Wan low-noise expert)\",\n        )\n        add_argument(\n            \"--guidance-rescale\",\n            type=float,\n            
help=\"Guidance rescale factor\",\n        )\n        add_argument(\n            \"--cfg-normalization\",\n            type=float,\n            dest=\"cfg_normalization\",\n            help=(\"CFG renormalization factor (for Z-Image). \"),\n        )\n        add_argument(\n            \"--boundary-ratio\",\n            type=float,\n            help=\"Boundary timestep ratio\",\n        )\n        add_argument(\n            \"--save-output\",\n            action=\"store_true\",\n            help=\"Whether to save the output to disk\",\n        )\n        add_argument(\n            \"--no-save-output\",\n            action=\"store_false\",\n            dest=\"save_output\",\n            help=\"Don't save the output to disk\",\n        )\n        add_argument(\n            \"--return-frames\",\n            action=\"store_true\",\n            help=\"Whether to return the raw frames\",\n        )\n        add_argument(\n            \"--image-path\",\n            type=str,\n            nargs=\"+\",\n            help=(\n                \"Path(s) to input image(s) for image-to-image / image-to-video \"\n                \"generation. 
For multiple images, pass them as space-separated \"\n                \"values, e.g.: \"\n                '--image-path \"img1.png\" \"img2.png\"'\n            ),\n        )\n        add_argument(\n            \"--moba-config-path\",\n            type=str,\n            help=\"Path to a JSON file containing V-MoBA-specific configurations.\",\n        )\n        add_argument(\n            \"--return-trajectory-latents\",\n            action=\"store_true\",\n            help=\"Whether to return the latents at each timestep (trajectory)\",\n        )\n        add_argument(\n            \"--return-trajectory-decoded\",\n            action=\"store_true\",\n            help=\"Whether to return the decoded latents at each timestep (trajectory)\",\n        )\n        add_argument(\n            \"--diffusers-kwargs\",\n            type=str,\n            help=\"JSON string of extra kwargs to pass to diffusers pipeline. \"\n            'Example: \\'{\"output_type\": \"latent\", \"clip_skip\": 2}\\'',\n        )\n        add_argument(\n            \"--no-override-protected-fields\",\n            action=\"store_true\",\n            help=(\n                \"If set, disallow user params to override fields defined in subclasses.\"\n            ),\n        )\n        add_argument(\n            \"--adjust-frames\",\n            action=StoreBoolean,\n            help=(\n                \"Enable/disable adjusting num_frames to evenly split latent frames across GPUs \"\n                \"and satisfy model temporal constraints. If disabled, tokens might be padded for SP. \"\n                \"Default: true. 
Examples: --adjust-frames, --adjust-frames true, --adjust-frames false.\"\n            ),\n        )\n        add_argument(\n            \"--return-file-paths-only\",\n            action=StoreBoolean,\n            help=\"If set, output file will be saved early to get a performance boost, while output tensors will not be returned.\",\n        )\n        add_argument(\n            \"--enable-sequence-shard\",\n            action=StoreBoolean,\n            help=\"Enable sequence dimension shard with sequence parallelism.\",\n        )\n        add_argument(\n            \"--enable-frame-interpolation\",\n            action=\"store_true\",\n            help=\"Enable post-generation frame interpolation using RIFE 4.22.lite.\",\n        )\n        add_argument(\n            \"--frame-interpolation-exp\",\n            type=int,\n            help=\"Frame interpolation exponent: 1=2x, 2=4x (default: 1).\",\n        )\n        add_argument(\n            \"--frame-interpolation-scale\",\n            type=float,\n            help=\"RIFE inference scale factor (default: 1.0; use 0.5 for high-res).\",\n        )\n        add_argument(\n            \"--frame-interpolation-model-path\",\n            type=str,\n            help=\"Local directory or HuggingFace repo ID containing RIFE flownet.pkl weights \"\n            \"(default: elfgum/RIFE-4.22.lite, downloaded automatically). \"\n            \"Only RIFE 4.22.lite architecture is supported; other RIFE versions or \"\n            \"frame interpolation models are not compatible.\",\n        )\n        add_argument(\n            \"--enable-upscaling\",\n            action=\"store_true\",\n            help=\"Enable post-generation upscaling using Real-ESRGAN.\",\n        )\n        add_argument(\n            \"--upscaling-model-path\",\n            type=str,\n            help=\"Local .pth file, HuggingFace repo ID, or repo_id:filename for Real-ESRGAN weights \"\n            \"(default: ai-forever/Real-ESRGAN with RealESRGAN_x4.pth). 
\"\n            \"Only RRDBNet (e.g. RealESRGAN_x4plus) and SRVGGNetCompact (e.g. realesr-animevideov3) \"\n            \"architectures are supported; other super-resolution models are not compatible. \"\n            \"Use 'repo_id:filename' to specify a custom weight file from a HF repo.\",\n        )\n        add_argument(\n            \"--upscaling-scale\",\n            type=int,\n            help=\"Upscaling factor (default: 4).\",\n        )\n        return parser\n\n    @classmethod\n    def get_cli_args(cls, args: argparse.Namespace):\n        # handle resolution shortcuts\n        if hasattr(args, \"resolution_4k\") and args.resolution_4k:\n            args.width = 3840\n            args.height = 2160\n        elif hasattr(args, \"resolution_2k\") and args.resolution_2k:\n            args.width = 2560\n            args.height = 1440\n        elif hasattr(args, \"resolution_1080p\") and args.resolution_1080p:\n            args.width = 1920\n            args.height = 1080\n        elif hasattr(args, \"resolution_720p\") and args.resolution_720p:\n            args.width = 1280\n            args.height = 720\n\n        sampling_params_fields = {attr.name for attr in dataclasses.fields(cls)}\n        args_attrs = set(vars(args).keys())\n        attrs = sampling_params_fields & args_attrs\n        return {\n            attr: getattr(args, attr)\n            for attr in attrs\n            if hasattr(args, attr) and getattr(args, attr) is not None\n        }\n\n    def output_file_path(self):\n        if self.output_path is None:\n            return None\n        return os.path.join(self.output_path, self.output_file_name)\n\n    def _merge_with_user_params(\n        self,\n        user_params: \"SamplingParams\",\n        explicit_fields: set[str] | None = None,\n    ):\n        \"\"\"\n        Merges parameters from a user-provided SamplingParams object.\n\n        Args:\n            explicit_fields: field names explicitly set by the user (e.g. 
from\n                CLI kwargs). These are always treated as user-modified even when\n                their value matches the base-class default.\n        \"\"\"\n        if user_params is None:\n            return\n\n        predefined_fields = set(type(self).__annotations__.keys())\n\n        # global switch: if True, allow overriding protected fields\n        allow_override_protected = not user_params.no_override_protected_fields\n        for field in dataclasses.fields(user_params):\n            field_name = field.name\n            user_value = getattr(user_params, field_name)\n            default_class_value = getattr(SamplingParams, field_name)\n\n            is_user_modified = user_value != default_class_value or (\n                explicit_fields is not None and field_name in explicit_fields\n            )\n            is_protected_field = field_name in predefined_fields\n            if is_user_modified and (\n                allow_override_protected or not is_protected_field\n            ):\n                setattr(self, field_name, user_value)\n        self.__post_init__()\n\n    @property\n    def n_tokens(self) -> int:\n        # Calculate latent sizes\n        if self.height and self.width:\n            latents_size = [\n                (self.num_frames - 1) // 4 + 1,\n                self.height // 8,\n                self.width // 8,\n            ]\n            n_tokens = latents_size[0] * latents_size[1] * latents_size[2]\n        else:\n            n_tokens = -1\n        return n_tokens\n\n\n@dataclass\nclass CacheParams:\n    cache_type: str = \"none\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/sana.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"Sampling parameters for SANA image generation (T2I).\"\"\"\n\nfrom dataclasses import dataclass\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import (\n    DataType,\n    SamplingParams,\n)\n\n\n@dataclass\nclass SanaSamplingParams(SamplingParams):\n    \"\"\"Defaults for SANA 1.5 1024px variant.\n\n    guidance_scale=4.5 enables standard classifier-free guidance.\n    \"\"\"\n\n    data_type: DataType = DataType.IMAGE\n    num_frames: int = 1\n    guidance_scale: float = 4.5\n    num_inference_steps: int = 20\n    height: int = 1024\n    width: int = 1024\n    negative_prompt: str = (\n        \"low quality, low resolution, blurry, overexposed, underexposed, \"\n        \"distorted, deformed, disfigured, bad anatomy, extra limbs, \"\n        \"watermark, text, signature, ugly, noisy, artifacts\"\n    )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/teacache.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import CacheParams\n\n\n@dataclass\nclass TeaCacheParams(CacheParams):\n    cache_type: str = \"teacache\"\n    teacache_thresh: float = 0.0\n    coefficients: list[float] = field(default_factory=list)\n\n\n@dataclass\nclass WanTeaCacheParams(CacheParams):\n    # Unfortunately, TeaCache is very different for Wan than other models\n    cache_type: str = \"teacache\"\n    teacache_thresh: float = 0.0\n    use_ret_steps: bool = True\n    ret_steps_coeffs: list[float] = field(default_factory=list)\n    non_ret_steps_coeffs: list[float] = field(default_factory=list)\n\n    @property\n    def coefficients(self) -> list[float]:\n        if self.use_ret_steps:\n            return self.ret_steps_coeffs\n        else:\n            return self.non_ret_steps_coeffs\n\n    @property\n    def ret_steps(self) -> int:\n        if self.use_ret_steps:\n            return 5 * 2\n        else:\n            return 1 * 2\n\n    def get_cutoff_steps(self, num_inference_steps: int) -> int:\n        if self.use_ret_steps:\n            return num_inference_steps * 2\n        else:\n            return num_inference_steps * 2 - 2\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/wan.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\nfrom sglang.multimodal_gen.configs.sample.teacache import WanTeaCacheParams\n\n\n@dataclass\nclass WanT2V_1_3B_SamplingParams(SamplingParams):\n    # Video parameters\n    height: int = 480\n    width: int = 832\n    num_frames: int = 81\n    fps: int = 16\n\n    # Denoising stage\n    guidance_scale: float = 3.0\n    negative_prompt: str = (\n        \"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\"\n    )\n    num_inference_steps: int = 50\n\n    # Wan T2V 1.3B supported resolutions\n    supported_resolutions: list[tuple[int, int]] | None = field(\n        default_factory=lambda: [\n            (832, 480),  # 16:9\n            (480, 832),  # 9:16\n        ]\n    )\n\n    teacache_params: WanTeaCacheParams = field(\n        default_factory=lambda: WanTeaCacheParams(\n            teacache_thresh=0.08,\n            ret_steps_coeffs=[\n                -5.21862437e04,\n                9.23041404e03,\n                -5.28275948e02,\n                1.36987616e01,\n                -4.99875664e-02,\n            ],\n            non_ret_steps_coeffs=[\n                2.39676752e03,\n                -1.31110545e03,\n                2.01331979e02,\n                -8.29855975e00,\n                1.37887774e-01,\n            ],\n        )\n    )\n\n\n@dataclass\nclass WanT2V_14B_SamplingParams(SamplingParams):\n    # Video parameters\n    height: int = 720\n    width: int = 1280\n    
num_frames: int = 81\n    fps: int = 16\n\n    # Denoising stage\n    guidance_scale: float = 5.0\n    negative_prompt: str = (\n        \"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\"\n    )\n    num_inference_steps: int = 50\n\n    # Wan T2V 14B supported resolutions\n    supported_resolutions: list[tuple[int, int]] | None = field(\n        default_factory=lambda: [\n            (1280, 720),  # 16:9\n            (720, 1280),  # 9:16\n            (832, 480),  # 16:9\n            (480, 832),  # 9:16\n        ]\n    )\n\n    teacache_params: WanTeaCacheParams = field(\n        default_factory=lambda: WanTeaCacheParams(\n            teacache_thresh=0.20,\n            use_ret_steps=False,\n            ret_steps_coeffs=[\n                -3.03318725e05,\n                4.90537029e04,\n                -2.65530556e03,\n                5.87365115e01,\n                -3.15583525e-01,\n            ],\n            non_ret_steps_coeffs=[\n                -5784.54975374,\n                5449.50911966,\n                -1811.16591783,\n                256.27178429,\n                -13.02252404,\n            ],\n        )\n    )\n\n\n@dataclass\nclass WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParams):\n    # Denoising stage\n    guidance_scale: float = 5.0\n    num_inference_steps: int = 50\n    # num_inference_steps: int = 40\n\n    # Wan I2V 480P supported resolutions (override parent)\n    supported_resolutions: list[tuple[int, int]] | None = field(\n        default_factory=lambda: [\n            (832, 480),  # 16:9\n            (480, 832),  # 9:16\n        ]\n    )\n\n    teacache_params: WanTeaCacheParams = field(\n     
   default_factory=lambda: WanTeaCacheParams(\n            teacache_thresh=0.26,\n            ret_steps_coeffs=[\n                -3.03318725e05,\n                4.90537029e04,\n                -2.65530556e03,\n                5.87365115e01,\n                -3.15583525e-01,\n            ],\n            non_ret_steps_coeffs=[\n                -5784.54975374,\n                5449.50911966,\n                -1811.16591783,\n                256.27178429,\n                -13.02252404,\n            ],\n        )\n    )\n\n\n@dataclass\nclass WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParams):\n    # Denoising stage\n    guidance_scale: float = 5.0\n    num_inference_steps: int = 50\n    # num_inference_steps: int = 40\n\n    # Wan I2V 720P supported resolutions (override parent)\n    supported_resolutions: list[tuple[int, int]] | None = field(\n        default_factory=lambda: [\n            (1280, 720),  # 16:9\n            (720, 1280),  # 9:16\n            (832, 480),  # 16:9\n            (480, 832),  # 9:16\n        ]\n    )\n\n    teacache_params: WanTeaCacheParams = field(\n        default_factory=lambda: WanTeaCacheParams(\n            teacache_thresh=0.3,\n            ret_steps_coeffs=[\n                -3.03318725e05,\n                4.90537029e04,\n                -2.65530556e03,\n                5.87365115e01,\n                -3.15583525e-01,\n            ],\n            non_ret_steps_coeffs=[\n                -5784.54975374,\n                5449.50911966,\n                -1811.16591783,\n                256.27178429,\n                -13.02252404,\n            ],\n        )\n    )\n\n\n@dataclass\nclass FastWanT2V480PConfig(WanT2V_1_3B_SamplingParams):\n    # DMD parameters\n    # dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])\n    num_inference_steps: int = 3\n    num_frames: int = 61\n    height: int = 448\n    width: int = 832\n    fps: int = 16\n\n\n# =============================================\n# 
============= Wan2.1 Fun Models =============\n# =============================================\n@dataclass\nclass Wan2_1_Fun_1_3B_InP_SamplingParams(SamplingParams):\n    \"\"\"Sampling parameters for Wan2.1 Fun 1.3B InP model.\"\"\"\n\n    height: int = 480\n    width: int = 832\n    num_frames: int = 81\n    fps: int = 16\n    negative_prompt: str | None = (\n        \"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\"\n    )\n    guidance_scale: float = 6.0\n    num_inference_steps: int = 50\n\n\n# =============================================\n# ============= Wan2.2 TI2V Models =============\n# =============================================\n@dataclass\nclass Wan2_2_Base_SamplingParams(SamplingParams):\n    \"\"\"Base sampling parameters shared by Wan2.2 models.\"\"\"\n\n    negative_prompt: str | None = (\n        \"色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\"\n    )\n\n    # TODO(Wan2.2): TeaCache coefficients need to be calibrated for Wan2.2 by\n    # profiling L1 distances across timesteps. 
Until then, teacache_params is None\n    # and enable_teacache will be accepted but silently no-op.\n    # Consider using Cache-DiT (SGLANG_CACHE_DIT_ENABLED=1) as an alternative.\n\n\n@dataclass\nclass Wan2_2_TI2V_5B_SamplingParam(Wan2_2_Base_SamplingParams):\n    \"\"\"Sampling parameters for Wan2.2 TI2V 5B model.\"\"\"\n\n    height: int = 704\n    width: int = 1280\n    num_frames: int = 121\n    fps: int = 24\n    guidance_scale: float = 5.0\n    num_inference_steps: int = 50\n\n    # Wan2.2 TI2V 5B supported resolutions\n    supported_resolutions: list[tuple[int, int]] | None = field(\n        default_factory=lambda: [\n            (1280, 704),  # 16:9-ish\n            (704, 1280),  # 9:16-ish\n        ]\n    )\n\n\n@dataclass\nclass Wan2_2_T2V_A14B_SamplingParam(Wan2_2_Base_SamplingParams):\n    guidance_scale: float = 4.0  # high_noise\n    guidance_scale_2: float = 3.0  # low_noise\n    num_inference_steps: int = 40\n    fps: int = 16\n\n    num_frames: int = 81\n\n    # Wan2.2 T2V A14B supported resolutions\n    supported_resolutions: list[tuple[int, int]] | None = field(\n        default_factory=lambda: [\n            (1280, 720),  # 16:9\n            (720, 1280),  # 9:16\n            (832, 480),  # 16:9\n            (480, 832),  # 9:16\n        ]\n    )\n\n\n@dataclass\nclass Wan2_2_I2V_A14B_SamplingParam(Wan2_2_Base_SamplingParams):\n    guidance_scale: float = 3.5  # high_noise\n    guidance_scale_2: float = 3.5  # low_noise\n    num_inference_steps: int = 40\n    fps: int = 16\n\n    num_frames: int = 81\n\n    # Wan2.2 I2V A14B supported resolutions\n    supported_resolutions: list[tuple[int, int]] | None = field(\n        default_factory=lambda: [\n            (1280, 720),  # 16:9\n            (720, 1280),  # 9:16\n            (832, 480),  # 16:9\n            (480, 832),  # 9:16\n        ]\n    )\n\n\n@dataclass\nclass Turbo_Wan2_2_I2V_A14B_SamplingParam(Wan2_2_Base_SamplingParams):\n    guidance_scale: float = 3.5  # high_noise\n    
guidance_scale_2: float = 3.5  # low_noise\n    num_inference_steps: int = 4\n    fps: int = 16\n\n\n# =============================================\n# ============= Causal Self-Forcing =============\n# =============================================\n@dataclass\nclass SelfForcingWanT2V480PConfig(WanT2V_1_3B_SamplingParams):\n    pass\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/sample/zimage.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass, field\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\nfrom sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams\n\n\n@dataclass\nclass ZImageTurboSamplingParams(SamplingParams):\n    num_inference_steps: int = 9\n\n    num_frames: int = 1\n    negative_prompt: str | None = None\n    # height: int = 720\n    # width: int = 1280\n    # fps: int = 24\n\n    guidance_scale: float = 0.0\n    cfg_normalization: float | bool = False\n\n    teacache_params: TeaCacheParams = field(\n        default_factory=lambda: TeaCacheParams(\n            teacache_thresh=0.15,\n            coefficients=[\n                7.33226126e02,\n                -4.01131952e02,\n                6.75869174e01,\n                -3.14987800e00,\n                9.61237896e-02,\n            ],\n        )\n    )\n\n\n@dataclass\nclass ZImageSamplingParams(SamplingParams):\n    num_inference_steps: int = 50\n\n    num_frames: int = 1\n    negative_prompt: str = \" \"\n    guidance_scale: float = 5.0\n    cfg_normalization: float | bool = True\n"
  },
  {
    "path": "python/sglang/multimodal_gen/configs/utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nimport argparse\nfrom typing import Any\n\n\ndef update_config_from_args(\n    config: Any, args_dict: dict[str, Any], prefix: str = \"\", pop_args: bool = False\n) -> bool:\n    \"\"\"\n    Update a configuration object from an arguments dictionary.\n\n    Args:\n        config: The configuration object to update.\n        args_dict: Dictionary containing arguments.\n        prefix: Prefix for the configuration parameters in args_dict.\n               If empty, assumes direct attribute mapping without a prefix.\n        pop_args: If True, remove the matched keys from args_dict\n               (model_path is always kept).\n\n    Returns:\n        True if any keys were matched for removal; always False when pop_args is False.\n    \"\"\"\n    # Handle top-level attributes (no prefix)\n    args_not_to_remove = [\n        \"model_path\",\n    ]\n    args_to_remove = []\n    if prefix.strip() == \"\":\n        for key, value in args_dict.items():\n            if hasattr(config, key) and value is not None:\n                if key == \"text_encoder_precisions\" and isinstance(value, list):\n                    setattr(config, key, tuple(value))\n                else:\n                    setattr(config, key, value)\n                if pop_args:\n                    args_to_remove.append(key)\n    else:\n        # Handle nested attributes with prefix\n        prefix_with_dot = f\"{prefix}.\"\n        for key, value in args_dict.items():\n            if key.startswith(prefix_with_dot) and value is not None:\n                attr_name = key[len(prefix_with_dot) :]\n                if hasattr(config, attr_name):\n                    setattr(config, attr_name, value)\n                if pop_args:\n                    args_to_remove.append(key)\n\n    if pop_args:\n        for key in args_to_remove:\n            if key not in args_not_to_remove:\n                args_dict.pop(key)\n\n    return len(args_to_remove) > 0\n\n\ndef clean_cli_args(args: argparse.Namespace) -> dict[str, Any]:\n    \"\"\"\n    Return only the arguments that were explicitly provided by the user.\n    \"\"\"\n    
provided_args = {}\n    for k, v in vars(args).items():\n        if v is not None and hasattr(args, \"_provided\") and k in args._provided:\n            provided_args[k] = v\n\n    return provided_args\n"
  },
  {
    "path": "python/sglang/multimodal_gen/csrc/attn/vmoba_attn/README.md",
    "content": "# Attention Kernel Used in SGLang Diffusion\n\n## VMoBA: Mixture-of-Block Attention for Video Diffusion Models\n\n### Installation\nEnsure that FlashAttention **2.7.1 or higher** is installed, as some interfaces changed in recent releases.\n\n### Usage\n\nYou can use `moba_attn_varlen` in the following ways:\n\n**Install from source:**\n```bash\npython setup.py install\n```\n\n**Import after installation:**\n```python\nfrom vmoba import moba_attn_varlen\n```\n\n**Or import directly from the project root:**\n```python\nfrom csrc.attn.vmoba_attn.vmoba import moba_attn_varlen\n```\n\n### Verify the installation\n\n```bash\npython csrc/attn/vmoba_attn/vmoba/vmoba.py\n```\n"
  },
  {
    "path": "python/sglang/multimodal_gen/csrc/attn/vmoba_attn/setup.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nfrom setuptools import find_packages, setup\n\nPACKAGE_NAME = \"vmoba\"\nVERSION = \"0.0.0\"\nAUTHOR = \"JianzongWu\"\nDESCRIPTION = \"VMoBA: Mixture-of-Block Attention for Video Diffusion Models\"\nURL = \"https://github.com/KwaiVGI/VMoBA\"\n\nsetup(\n    name=PACKAGE_NAME,\n    version=VERSION,\n    author=AUTHOR,\n    description=DESCRIPTION,\n    url=URL,\n    packages=find_packages(),\n    classifiers=[\n        \"Programming Language :: Python :: 3\",\n        \"License :: OSI Approved :: Apache Software License\",\n    ],\n    python_requires=\">=3.12\",\n    install_requires=[\n        \"flash-attn >= 2.7.1\",\n    ],\n)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/csrc/attn/vmoba_attn/tests/test_vmoba_attn.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nimport random\n\nimport pytest\nimport torch\nfrom sglang.multimodal_gen.csrc.attn.vmoba_attn.vmoba import moba_attn_varlen\n\ndef generate_test_data(\n    batch_size, total_seqlen, num_heads, head_dim, dtype, device=\"cuda\"\n):\n    \"\"\"\n    Generates random data for testing the variable-length attention function.\n    \"\"\"\n    torch.manual_seed(42)\n    random.seed(42)\n    torch.cuda.manual_seed_all(42)\n\n    # Generate sequence lengths for each item in the batch\n    if batch_size > 1:\n        # Ensure sequence lengths are reasonably distributed\n        avg_seqlen = total_seqlen // batch_size\n        seqlens = [\n            random.randint(avg_seqlen // 2, avg_seqlen + avg_seqlen // 2)\n            for _ in range(batch_size - 1)\n        ]\n        remaining_len = total_seqlen - sum(seqlens)\n        if remaining_len > 0:\n            seqlens.append(remaining_len)\n        else:  # Adjust if sum exceeds total_seqlen\n            seqlens.append(avg_seqlen)\n            current_sum = sum(seqlens)\n            seqlens[-1] -= current_sum - total_seqlen\n        # Ensure all lengths are positive\n        seqlens = [max(1, s) for s in seqlens]\n        # Final adjustment to match total_seqlen\n        seqlens[-1] += total_seqlen - sum(seqlens)\n\n    else:\n        seqlens = [total_seqlen]\n\n    cu_seqlens = torch.tensor(\n        [0] + list(torch.cumsum(torch.tensor(seqlens), 0)),\n        device=device,\n        dtype=torch.int32,\n    )\n    max_seqlen = max(seqlens) if seqlens else 0\n\n    q = torch.randn(\n        (total_seqlen, num_heads, head_dim),\n        dtype=dtype,\n        device=device,\n        requires_grad=False,\n    )\n    k = torch.randn(\n        (total_seqlen, num_heads, head_dim),\n        dtype=dtype,\n        device=device,\n        requires_grad=False,\n    )\n    v = torch.randn(\n        (total_seqlen, num_heads, head_dim),\n        dtype=dtype,\n        device=device,\n 
       requires_grad=False,\n    )\n\n    return q, k, v, cu_seqlens, max_seqlen\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 2])\n@pytest.mark.parametrize(\"total_seqlen\", [512, 1024])\n@pytest.mark.parametrize(\"num_heads\", [8])\n@pytest.mark.parametrize(\"head_dim\", [64])\n@pytest.mark.parametrize(\"moba_chunk_size\", [64])\n@pytest.mark.parametrize(\"moba_topk\", [2, 4])\n@pytest.mark.parametrize(\"select_mode\", [\"topk\", \"threshold\"])\n@pytest.mark.parametrize(\"threshold_type\", [\"query_head\", \"head_global\", \"overall\"])\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16, torch.bfloat16])\ndef test_moba_attn_varlen_forward(\n    batch_size,\n    total_seqlen,\n    num_heads,\n    head_dim,\n    moba_chunk_size,\n    moba_topk,\n    select_mode,\n    threshold_type,\n    dtype,\n):\n    \"\"\"\n    Tests the forward pass of moba_attn_varlen for basic correctness.\n    It checks output shape, dtype, and for the presence of NaNs/Infs.\n    \"\"\"\n    if dtype == torch.float32:\n        pytest.skip(\"float32 is not supported in flash attention\")\n\n    q, k, v, cu_seqlens, max_seqlen = generate_test_data(\n        batch_size, total_seqlen, num_heads, head_dim, dtype\n    )\n\n    # Ensure chunk size is not larger than the smallest sequence length\n    min_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).min().item()\n    if moba_chunk_size > min_seqlen:\n        pytest.skip(\n            \"moba_chunk_size is larger than the minimum sequence length in the batch\"\n        )\n\n    try:\n        output = moba_attn_varlen(\n            q=q,\n            k=k,\n            v=v,\n            cu_seqlens=cu_seqlens,\n            max_seqlen=max_seqlen,\n            moba_chunk_size=moba_chunk_size,\n            moba_topk=moba_topk,\n            select_mode=select_mode,\n            threshold_type=threshold_type,\n            simsum_threshold=0.5,  # A reasonable default for threshold mode\n        )\n    except Exception as e:\n        
pytest.fail(f\"moba_attn_varlen forward pass failed with exception: {e}\")\n\n    # 1. Check output shape\n    assert (\n        output.shape == q.shape\n    ), f\"Expected output shape {q.shape}, but got {output.shape}\"\n\n    # 2. Check output dtype\n    assert (\n        output.dtype == q.dtype\n    ), f\"Expected output dtype {q.dtype}, but got {output.dtype}\"\n\n    # 3. Check for NaNs or Infs in the output\n    assert torch.all(torch.isfinite(output)), \"Output contains NaN or Inf values\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/__init__.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom .vmoba import moba_attn_varlen, process_moba_input, process_moba_output\n"
  },
  {
    "path": "python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/vmoba.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Adapt from https://github.com/KwaiVGI/VMoBA/blob/main/src/vmoba.py\n\nimport random\nimport time\nfrom typing import Tuple\n\nimport torch\n\ntry:\n    from flash_attn import (  # Use the new flash attention function\n        flash_attn_varlen_func,\n    )\n    from flash_attn.flash_attn_interface import (\n        _flash_attn_varlen_backward,\n        _flash_attn_varlen_forward,\n    )\nexcept ImportError:\n\n    def _unsupported(*args, **kwargs):\n        raise ImportError(\n            \"flash-attn is not installed. Please install it, e.g., `pip install flash-attn`.\"\n        )\n\n    _flash_attn_varlen_forward = _unsupported\n    _flash_attn_varlen_backward = _unsupported\n    flash_attn_varlen_func = _unsupported\n\nfrom functools import lru_cache\n\nfrom einops import rearrange\n\n\n@lru_cache(maxsize=16)\ndef calc_chunks(cu_seqlen, moba_chunk_size):\n    \"\"\"\n    Calculate chunk boundaries.\n\n    For vision tasks we include all chunks (even the last one which might be shorter)\n    so that every chunk can be selected.\n    \"\"\"\n    batch_sizes = cu_seqlen[1:] - cu_seqlen[:-1]\n    batch_num_chunk = (batch_sizes + (moba_chunk_size - 1)) // moba_chunk_size\n    cu_num_chunk = torch.ones(\n        batch_num_chunk.numel() + 1,\n        device=cu_seqlen.device,\n        dtype=batch_num_chunk.dtype,\n    )\n    cu_num_chunk[1:] = batch_num_chunk.cumsum(dim=0)\n    num_chunk = cu_num_chunk[-1]\n    chunk_sizes = torch.full(\n        (num_chunk + 1,), moba_chunk_size, dtype=torch.int32, device=cu_seqlen.device\n    )\n    chunk_sizes[0] = 0\n    batch_last_chunk_size = batch_sizes - (batch_num_chunk - 1) * moba_chunk_size\n    chunk_sizes[cu_num_chunk[1:]] = batch_last_chunk_size\n    cu_chunk = chunk_sizes.cumsum(dim=-1, dtype=torch.int32)\n    chunk_to_batch = torch.zeros(\n        (num_chunk,), dtype=torch.int32, device=cu_seqlen.device\n    )\n    chunk_to_batch[cu_num_chunk[1:-1]] = 1\n    
chunk_to_batch = chunk_to_batch.cumsum(dim=0, dtype=torch.int32)\n\n    # Do not filter out any chunk\n    filtered_chunk_indices = torch.arange(\n        num_chunk, device=cu_seqlen.device, dtype=torch.int32\n    )\n    num_filtered_chunk = num_chunk\n\n    return cu_chunk, filtered_chunk_indices, num_filtered_chunk, chunk_to_batch\n\n\n# --- Threshold Selection Helper Functions ---\n\n\ndef _select_threshold_query_head(\n    gate: torch.Tensor,\n    valid_gate_mask: torch.Tensor,\n    gate_self_chunk_mask: torch.Tensor,\n    simsum_threshold: float,\n) -> torch.Tensor:\n    \"\"\"\n    Selects chunks for each <query, head> pair based on threshold.\n    Normalization and sorting happen along the chunk dimension (dim=0).\n    \"\"\"\n    C, H, S = gate.shape\n    eps = 1e-6\n\n    # LSE‐style normalization per <head, query> (across chunks)\n    gate_masked = torch.where(valid_gate_mask, gate, -torch.inf)  # Use -inf for max\n    gate_min_val = torch.where(valid_gate_mask, gate, torch.inf)  # Use +inf for min\n\n    row_min = gate_min_val.amin(dim=0)  # (H, S)\n    row_max = gate_masked.amax(dim=0)  # (H, S)\n    denom = row_max - row_min\n    denom = torch.where(\n        denom <= eps, torch.ones_like(denom), denom\n    )  # avoid divide‑by‑zero\n\n    gate_norm = (gate - row_min.unsqueeze(0)) / denom.unsqueeze(0)\n    gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0)  # (C, H, S)\n\n    # 1) pull out the self‐chunk’s normalized weight for each <head,seq>\n    self_norm = (gate_norm * gate_self_chunk_mask).sum(dim=0)  # (H, S)\n\n    # 2) compute how much more normalized weight we need beyond self\n    total_norm_sum = gate_norm.sum(dim=0)  # (H, S)\n    remain_ratio = simsum_threshold - self_norm / (total_norm_sum + eps)  # (H, S)\n    remain_ratio = torch.clamp(\n        remain_ratio, min=0.0\n    )  # if already ≥ thresh, no extra needed\n\n    # 3) zero out the self‐chunk in a copy, so we only sort “others”\n    others_norm = gate_norm.clone()\n    
others_norm[gate_self_chunk_mask] = 0.0\n\n    # 4) sort the other chunks by descending norm, per <head,seq>\n    sorted_norm, sorted_idx = torch.sort(\n        others_norm, descending=True, dim=0\n    )  # (C, H, S)\n\n    # 5) cumulative‑sum the sorted norms per <head,seq>\n    cumsum_others = sorted_norm.cumsum(dim=0)  # (C, H, S)\n\n    # 6) for each <head,seq>, find the smallest k where cumsum_ratio ≥ remain_ratio\n    ratio = cumsum_others / (total_norm_sum.unsqueeze(0) + eps)  # (C, H, S)\n    cond = ratio >= remain_ratio.unsqueeze(0)  # (C, H, S) boolean mask\n    any_cond = cond.any(dim=0)  # (H, S)\n    # Find the index of the first True value along dim 0. If none, use C-1.\n    cutoff = torch.where(\n        any_cond,\n        cond.float().argmax(dim=0),\n        torch.full_like(any_cond, fill_value=C - 1),\n    )  # (H, S)\n\n    # 7) build a mask in sorted order up to that cutoff\n    idx_range = torch.arange(C, device=gate.device).view(-1, 1, 1)  # (C, 1, 1)\n    sorted_mask = idx_range <= cutoff.unsqueeze(0)  # (C, H, S)\n\n    # 8) scatter it back to original chunk order\n    others_mask = torch.zeros_like(gate, dtype=torch.bool)\n    others_mask.scatter_(0, sorted_idx, sorted_mask)\n\n    # 9) finally, include every self‐chunk plus all selected others\n    final_gate_mask = valid_gate_mask & (others_mask | gate_self_chunk_mask)\n\n    return final_gate_mask\n\n\ndef _select_threshold_block(\n    gate: torch.Tensor,\n    valid_gate_mask: torch.Tensor,\n    gate_self_chunk_mask: torch.Tensor,\n    simsum_threshold: float,\n) -> torch.Tensor:\n    \"\"\"\n    Selects <query, head> pairs for each block based on threshold.\n    Normalization and sorting happen across the head and sequence dimensions (dim=1, 2).\n    \"\"\"\n    C, H, S = gate.shape\n    HS = H * S\n    eps = 1e-6\n\n    # LSE‐style normalization per block (across heads and queries)\n    gate_masked = torch.where(valid_gate_mask, gate, -torch.inf)  # Use -inf for max\n    gate_min_val = 
torch.where(valid_gate_mask, gate, torch.inf)  # Use +inf for min\n\n    block_max = gate_masked.amax(dim=(1, 2), keepdim=True)  # (C, 1, 1)\n    block_min = gate_min_val.amin(dim=(1, 2), keepdim=True)  # (C, 1, 1)\n    block_denom = block_max - block_min\n    block_denom = torch.where(\n        block_denom <= eps, torch.ones_like(block_denom), block_denom\n    )  # (C, 1, 1)\n\n    gate_norm = (gate - block_min) / block_denom  # (C, H, S)\n    gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0)  # (C, H, S)\n\n    # 1) identify normalized weights of entries that *are* self-chunks (from query perspective)\n    self_norm_entries = gate_norm * gate_self_chunk_mask  # (C, H, S)\n    # Sum these weights *per block*\n    self_norm_sum_per_block = self_norm_entries.sum(dim=(1, 2))  # (C,)\n\n    # 2) compute how much more normalized weight each block needs beyond its self-chunk contributions\n    total_norm_sum_per_block = gate_norm.sum(dim=(1, 2))  # (C,)\n    remain_ratio = simsum_threshold - self_norm_sum_per_block / (\n        total_norm_sum_per_block + eps\n    )  # (C,)\n    remain_ratio = torch.clamp(remain_ratio, min=0.0)  # (C,)\n\n    # 3) zero out the self‐chunk entries in a copy, so we only sort “others”\n    others_norm = gate_norm.clone()\n    others_norm[gate_self_chunk_mask] = 0.0  # Zero out self entries\n\n    # 4) sort the other <head, seq> pairs by descending norm, per block\n    others_flat = others_norm.contiguous().view(C, HS)  # (C, H*S)\n    sorted_others_flat, sorted_indices_flat = torch.sort(\n        others_flat, dim=1, descending=True\n    )  # (C, H*S)\n\n    # 5) cumulative‑sum the sorted norms per block\n    cumsum_others_flat = sorted_others_flat.cumsum(dim=1)  # (C, H*S)\n\n    # 6) for each block, find the smallest k where cumsum_ratio ≥ remain_ratio\n    ratio_flat = cumsum_others_flat / (\n        total_norm_sum_per_block.unsqueeze(1) + eps\n    )  # (C, H*S)\n    cond_flat = ratio_flat >= remain_ratio.unsqueeze(1)  # (C, H*S) 
boolean mask\n    any_cond = cond_flat.any(dim=1)  # (C,)\n    # Find the index of the first True value along dim 1. If none, use HS-1.\n    # dtype=torch.long is required: full_like on the bool any_cond would store True (== 1).\n    cutoff_flat = torch.where(\n        any_cond,\n        cond_flat.float().argmax(dim=1),\n        torch.full_like(any_cond, fill_value=HS - 1, dtype=torch.long),\n    )  # (C,)\n\n    # 7) build a mask in sorted order up to that cutoff per block\n    idx_range_flat = torch.arange(HS, device=gate.device).unsqueeze(0)  # (1, H*S)\n    sorted_mask_flat = idx_range_flat <= cutoff_flat.unsqueeze(1)  # (C, H*S)\n\n    # 8) scatter it back to original <head, seq> order per block\n    others_mask_flat = torch.zeros_like(others_flat, dtype=torch.bool)  # (C, H*S)\n    others_mask_flat.scatter_(1, sorted_indices_flat, sorted_mask_flat)\n    others_mask = others_mask_flat.view(C, H, S)  # (C, H, S)\n\n    # 9) finally, include every self‐chunk entry plus all selected others\n    final_gate_mask = valid_gate_mask & (others_mask | gate_self_chunk_mask)\n\n    return final_gate_mask\n\n\ndef _select_threshold_overall(\n    gate: torch.Tensor,\n    valid_gate_mask: torch.Tensor,\n    gate_self_chunk_mask: torch.Tensor,\n    simsum_threshold: float,\n) -> torch.Tensor:\n    \"\"\"\n    Selects <chunk, query, head> triplets globally based on threshold.\n    Normalization and sorting happen across all valid entries.\n    \"\"\"\n    C, H, S = gate.shape\n    CHS = C * H * S\n    eps = 1e-6\n\n    # Min-max normalization globally across all valid entries\n    gate_masked = torch.where(valid_gate_mask, gate, -torch.inf)  # Use -inf for max\n    gate_min_val = torch.where(valid_gate_mask, gate, torch.inf)  # Use +inf for min\n\n    overall_max = gate_masked.max()  # scalar\n    overall_min = gate_min_val.min()  # scalar\n    overall_denom = overall_max - overall_min\n    overall_denom = torch.where(\n        overall_denom <= eps,\n        torch.tensor(1.0, device=gate.device, dtype=gate.dtype),\n        overall_denom,\n    )\n\n    gate_norm = (gate - overall_min) 
/ overall_denom  # (C, H, S)\n    gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0)  # (C, H, S)\n\n    # 1) identify normalized weights of entries that *are* self-chunks\n    self_norm_entries = gate_norm * gate_self_chunk_mask  # (C, H, S)\n    # Sum these weights globally\n    self_norm_sum_overall = self_norm_entries.sum()  # scalar\n\n    # 2) compute how much more normalized weight is needed globally beyond self-chunk contributions\n    total_norm_sum_overall = gate_norm.sum()  # scalar\n    remain_ratio = simsum_threshold - self_norm_sum_overall / (\n        total_norm_sum_overall + eps\n    )  # scalar\n    remain_ratio = torch.clamp(remain_ratio, min=0.0)  # scalar\n\n    # 3) zero out the self‐chunk entries in a copy, so we only sort “others”\n    others_norm = gate_norm.clone()\n    others_norm[gate_self_chunk_mask] = 0.0  # Zero out self entries\n\n    # 4) sort all other entries by descending norm, globally\n    others_flat = others_norm.flatten()  # (C*H*S,)\n    valid_others_mask_flat = (\n        valid_gate_mask.flatten() & ~gate_self_chunk_mask.flatten()\n    )  # Mask for valid, non-self entries\n\n    # Only sort the valid 'other' entries\n    valid_others_indices = torch.where(valid_others_mask_flat)[0]\n    valid_others_values = others_flat[valid_others_indices]\n\n    sorted_others_values, sort_perm = torch.sort(\n        valid_others_values, descending=True\n    )  # (N_valid_others,)\n    sorted_original_indices = valid_others_indices[\n        sort_perm\n    ]  # Original indices in C*H*S space, sorted by value\n\n    # 5) cumulative‑sum the sorted valid 'other' norms globally\n    cumsum_others_values = sorted_others_values.cumsum(dim=0)  # (N_valid_others,)\n\n    # 6) find the smallest k where cumsum_ratio ≥ remain_ratio globally\n    ratio_values = cumsum_others_values / (\n        total_norm_sum_overall + eps\n    )  # (N_valid_others,)\n    cond_values = ratio_values >= remain_ratio  # (N_valid_others,) boolean mask\n    
any_cond = cond_values.any()  # scalar\n\n    # Find the index of the first True value in the *sorted* list. If none, use all valid others.\n    cutoff_idx_in_sorted = torch.where(\n        any_cond,\n        cond_values.float().argmax(dim=0),\n        torch.tensor(\n            len(sorted_others_values) - 1, device=gate.device, dtype=torch.long\n        ),\n    )\n\n    # 7) build a mask selecting the top-k others based on the cutoff\n    # Select the original indices corresponding to the top entries in the sorted list\n    selected_other_indices = sorted_original_indices[: cutoff_idx_in_sorted + 1]\n\n    # 8) create the mask in the original flat shape\n    others_mask_flat = torch.zeros_like(others_flat, dtype=torch.bool)  # (C*H*S,)\n    if selected_other_indices.numel() > 0:  # Check if any 'other' indices were selected\n        others_mask_flat[selected_other_indices] = True\n    others_mask = others_mask_flat.view(C, H, S)  # (C, H, S)\n\n    # 9) finally, include every self‐chunk entry plus all selected others\n    final_gate_mask = valid_gate_mask & (others_mask | gate_self_chunk_mask)\n\n    return final_gate_mask\n\n\ndef _select_threshold_head_global(\n    gate: torch.Tensor,\n    valid_gate_mask: torch.Tensor,\n    gate_self_chunk_mask: torch.Tensor,\n    simsum_threshold: float,\n) -> torch.Tensor:\n    \"\"\"\n    Selects <chunk, query> globally for each head based on threshold.\n    \"\"\"\n    C, H, S = gate.shape\n    eps = 1e-6\n\n    # 1) Min-max normalization per head (across chunks and sequence dims)\n    gate_masked = torch.where(valid_gate_mask, gate, -torch.inf)\n    gate_min_val = torch.where(valid_gate_mask, gate, torch.inf)\n\n    max_per_head = gate_masked.amax(dim=(0, 2), keepdim=True)  # (1, H, 1)\n    min_per_head = gate_min_val.amin(dim=(0, 2), keepdim=True)  # (1, H, 1)\n    denom = max_per_head - min_per_head\n    denom = torch.where(denom <= eps, torch.ones_like(denom), denom)\n\n    gate_norm = (gate - min_per_head) / denom\n  
  gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0)  # (C, H, S)\n\n    # 2) sum normalized self‐chunk contributions per head\n    self_norm_sum = (gate_norm * gate_self_chunk_mask).sum(dim=(0, 2))  # (H,)\n\n    # 3) total normalized sum per head\n    total_norm_sum = gate_norm.sum(dim=(0, 2))  # (H,)\n\n    # 4) how much more normalized weight needed per head\n    remain_ratio = simsum_threshold - self_norm_sum / (total_norm_sum + eps)  # (H,)\n    remain_ratio = torch.clamp(remain_ratio, min=0.0)\n\n    # 5) zero out self‐chunk entries to focus on \"others\"\n    others_norm = gate_norm.clone()\n    others_norm[gate_self_chunk_mask] = 0.0  # (C, H, S)\n\n    # 6) flatten chunk and sequence dims, per head\n    CS = C * S\n    others_flat = others_norm.permute(1, 0, 2).reshape(H, CS)  # (H, C*S)\n    valid_flat = (\n        (valid_gate_mask & ~gate_self_chunk_mask).permute(1, 0, 2).reshape(H, CS)\n    )  # (H, C*S)\n\n    # 7) vectorized selection of “others” per head\n    masked_flat = torch.where(valid_flat, others_flat, torch.zeros_like(others_flat))\n    sorted_vals, sorted_idx = torch.sort(\n        masked_flat, dim=1, descending=True\n    )  # (H, C*S)\n\n    cumsum_vals = sorted_vals.cumsum(dim=1)  # (H, C*S)\n    ratio_vals = cumsum_vals / (total_norm_sum.unsqueeze(1) + eps)  # (H, C*S)\n    cond = ratio_vals >= remain_ratio.unsqueeze(1)  # (H, C*S)\n\n    has_cutoff = cond.any(dim=1)  # (H,)\n    default = torch.full((H,), CS - 1, device=gate.device, dtype=torch.long)\n    cutoff = torch.where(has_cutoff, cond.float().argmax(dim=1), default)  # (H,)\n\n    idx_range = torch.arange(CS, device=gate.device).unsqueeze(0)  # (1, C*S)\n    sorted_mask = idx_range <= cutoff.unsqueeze(1)  # (H, C*S)\n\n    selected_flat = torch.zeros_like(valid_flat)  # (H, C*S)\n    selected_flat.scatter_(1, sorted_idx, sorted_mask)  # (H, C*S)\n\n    # 8) reshape selection mask back to (C, H, S)\n    others_mask = selected_flat.reshape(H, C, S).permute(1, 0, 2)  # (C, H, 
S)\n\n    # 9) include self‐chunks plus selected others, and obey valid mask\n    final_gate_mask = valid_gate_mask & (gate_self_chunk_mask | others_mask)\n\n    return final_gate_mask\n\n\nclass MixedAttention(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        q,\n        k,\n        v,\n        self_attn_cu_seqlen,\n        moba_q,\n        moba_kv,\n        moba_cu_seqlen_q,\n        moba_cu_seqlen_kv,\n        max_seqlen,\n        moba_chunk_size,\n        moba_q_sh_indices,\n    ):\n        ctx.max_seqlen = max_seqlen\n        ctx.moba_chunk_size = moba_chunk_size\n        ctx.softmax_scale = softmax_scale = q.shape[-1] ** (-0.5)\n\n        # Non-causal self-attention branch\n        # return out, softmax_lse, S_dmask, rng_state\n        self_attn_out_sh, self_attn_lse_hs, _, _ = _flash_attn_varlen_forward(\n            q=q,\n            k=k,\n            v=v,\n            cu_seqlens_q=self_attn_cu_seqlen,\n            cu_seqlens_k=self_attn_cu_seqlen,\n            max_seqlen_q=max_seqlen,\n            max_seqlen_k=max_seqlen,\n            softmax_scale=softmax_scale,\n            causal=False,\n            dropout_p=0.0,\n        )\n        # MOBA attention branch (non-causal)\n        moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward(\n            q=moba_q,\n            k=moba_kv[:, 0],\n            v=moba_kv[:, 1],\n            cu_seqlens_q=moba_cu_seqlen_q,\n            cu_seqlens_k=moba_cu_seqlen_kv,\n            max_seqlen_q=max_seqlen,\n            max_seqlen_k=moba_chunk_size,\n            softmax_scale=softmax_scale,\n            causal=False,\n            dropout_p=0.0,\n        )\n\n        self_attn_lse_sh = self_attn_lse_hs.t().contiguous()\n        moba_attn_lse = moba_attn_lse_hs.t().contiguous()\n\n        output = torch.zeros(\n            (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32\n        )\n        output_2d = output.view(-1, q.shape[2])\n\n        
max_lse_1d = self_attn_lse_sh.view(-1)\n        max_lse_1d = max_lse_1d.index_reduce(\n            0, moba_q_sh_indices, moba_attn_lse.view(-1), \"amax\"\n        )\n        self_attn_lse_sh = self_attn_lse_sh - max_lse_1d.view_as(self_attn_lse_sh)\n        moba_attn_lse = (\n            moba_attn_lse.view(-1)\n            .sub(max_lse_1d.index_select(0, moba_q_sh_indices))\n            .reshape_as(moba_attn_lse)\n        )\n\n        mixed_attn_se_sh = self_attn_lse_sh.exp()\n        moba_attn_se = moba_attn_lse.exp()\n\n        mixed_attn_se_sh.view(-1).index_add_(\n            0, moba_q_sh_indices, moba_attn_se.view(-1)\n        )\n        mixed_attn_lse_sh = mixed_attn_se_sh.log()\n\n        # Combine self-attention output\n        factor = (self_attn_lse_sh - mixed_attn_lse_sh).exp()  # [S, H]\n        self_attn_out_sh = self_attn_out_sh * factor.unsqueeze(-1)\n        output_2d += self_attn_out_sh.reshape_as(output_2d)\n\n        # Combine MOBA attention output\n        mixed_attn_lse = (\n            mixed_attn_lse_sh.view(-1)\n            .index_select(0, moba_q_sh_indices)\n            .view_as(moba_attn_lse)\n        )\n        factor = (moba_attn_lse - mixed_attn_lse).exp()  # [S, H]\n        moba_attn_out = moba_attn_out * factor.unsqueeze(-1)\n        raw_attn_out = moba_attn_out.view(-1, moba_attn_out.shape[-1])\n        output_2d.index_add_(0, moba_q_sh_indices, raw_attn_out)\n        output = output.to(q.dtype)\n        mixed_attn_lse_sh = mixed_attn_lse_sh + max_lse_1d.view_as(mixed_attn_se_sh)\n        ctx.save_for_backward(\n            output,\n            mixed_attn_lse_sh,\n            q,\n            k,\n            v,\n            self_attn_cu_seqlen,\n            moba_q,\n            moba_kv,\n            moba_cu_seqlen_q,\n            moba_cu_seqlen_kv,\n            moba_q_sh_indices,\n        )\n\n        return output\n\n    @staticmethod\n    def backward(ctx, d_output):\n\n        max_seqlen = ctx.max_seqlen\n        moba_chunk_size = 
ctx.moba_chunk_size\n        softmax_scale = ctx.softmax_scale\n\n        (\n            output,\n            mixed_attn_vlse_sh,\n            q,\n            k,\n            v,\n            self_attn_cu_seqlen,\n            moba_q,\n            moba_kv,\n            moba_cu_seqlen_q,\n            moba_cu_seqlen_kv,\n            moba_q_sh_indices,\n        ) = ctx.saved_tensors\n\n        d_output = d_output.contiguous()\n\n        dq = torch.empty_like(q)\n        dk = torch.empty_like(k)\n        dv = torch.empty_like(v)\n        _ = _flash_attn_varlen_backward(\n            dout=d_output,\n            q=q,\n            k=k,\n            v=v,\n            out=output,\n            softmax_lse=mixed_attn_vlse_sh.t().contiguous(),\n            dq=dq,\n            dk=dk,\n            dv=dv,\n            cu_seqlens_q=self_attn_cu_seqlen,\n            cu_seqlens_k=self_attn_cu_seqlen,\n            max_seqlen_q=max_seqlen,\n            max_seqlen_k=max_seqlen,\n            softmax_scale=softmax_scale,\n            causal=False,\n            dropout_p=0.0,\n            softcap=0.0,\n            alibi_slopes=None,\n            deterministic=True,\n            window_size_left=-1,\n            window_size_right=-1,\n        )\n\n        headdim = q.shape[-1]\n        d_moba_output = (\n            d_output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1)\n        )\n        moba_output = (\n            output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1)\n        )\n\n        mixed_attn_vlse = (\n            mixed_attn_vlse_sh.view(-1).index_select(0, moba_q_sh_indices).view(1, -1)\n        )\n\n        dmq = torch.empty_like(moba_q)\n        dmkv = torch.empty_like(moba_kv)\n        _ = _flash_attn_varlen_backward(\n            dout=d_moba_output,\n            q=moba_q,\n            k=moba_kv[:, 0],\n            v=moba_kv[:, 1],\n            out=moba_output,\n            softmax_lse=mixed_attn_vlse,\n            dq=dmq,\n            
dk=dmkv[:, 0],\n            dv=dmkv[:, 1],\n            cu_seqlens_q=moba_cu_seqlen_q,\n            cu_seqlens_k=moba_cu_seqlen_kv,\n            max_seqlen_q=max_seqlen,\n            max_seqlen_k=moba_chunk_size,\n            softmax_scale=softmax_scale,\n            causal=False,\n            dropout_p=0.0,\n            softcap=0.0,\n            alibi_slopes=None,\n            deterministic=True,\n            window_size_left=-1,\n            window_size_right=-1,\n        )\n\n        return dq, dk, dv, None, dmq, dmkv, None, None, None, None, None\n\n\ndef moba_attn_varlen(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens: torch.Tensor,\n    max_seqlen: int,\n    moba_chunk_size: int,\n    moba_topk: int,\n    select_mode: str = \"threshold\",  # \"topk\" or \"threshold\"\n    simsum_threshold: float = 0.25,\n    threshold_type: str = \"query_head\",\n) -> torch.Tensor:\n    \"\"\"\n    Accelerated MOBA attention for vision tasks with log-sum-exp (LSE) output merging.\n\n    This version:\n      - Splits KV into chunks.\n      - Selects relevant KV chunks for each query and head, either by top-k\n        (select_mode='topk'; self-chunk logits are amplified so the self chunk is always\n        kept) or by threshold (select_mode='threshold'; chunks are added in descending\n        gate order until their normalized share reaches simsum_threshold).\n        threshold_type sets the grouping over which the ranking is done.\n      - Attends each query to its own chunk in a dense branch and to the other selected\n        chunks in a MOBA branch, then merges both outputs with a log-sum-exp reduction,\n        which equals softmax attention over the union of the selected chunks.\n    \"\"\"\n    # Stack keys and values.\n    kv = torch.stack((k, v), dim=1)\n    seqlen, num_head, head_dim = q.shape\n\n    # Compute chunk boundaries.\n    cu_chunk, filtered_chunk_indices, num_filtered_chunk, chunk_to_batch = calc_chunks(\n        cu_seqlens, moba_chunk_size\n    )\n\n    self_attn_cu_seqlen = cu_chunk\n\n    # Clamp top-k to the number of available chunks.\n    moba_topk = min(moba_topk, num_filtered_chunk)\n\n    # --- Build filtered KV from chunks ---\n    chunk_starts = cu_chunk[filtered_chunk_indices]  # [num_filtered_chunk]\n    
chunk_ends = cu_chunk[filtered_chunk_indices + 1]  # [num_filtered_chunk]\n    chunk_lengths = chunk_ends - chunk_starts  # [num_filtered_chunk]\n    max_chunk_len = int(chunk_lengths.max().item())\n\n    range_tensor = torch.arange(\n        max_chunk_len, device=kv.device, dtype=chunk_starts.dtype\n    ).unsqueeze(0)\n    indices = chunk_starts.unsqueeze(1) + range_tensor\n    indices = torch.clamp(indices, max=kv.shape[0] - 1)\n    valid_mask = range_tensor < chunk_lengths.unsqueeze(1)\n    gathered = kv[indices.view(-1)].view(\n        num_filtered_chunk, max_chunk_len, *kv.shape[1:]\n    )\n    gathered = gathered * valid_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).type_as(\n        gathered\n    )\n\n    # Compute key_gate_weight over valid tokens.\n    key_values = gathered[\n        :, :, 0\n    ].float()  # [num_filtered_chunk, max_chunk_len, num_head, head_dim]\n    valid_mask_exp = valid_mask.unsqueeze(-1).unsqueeze(-1)\n    key_sum = (key_values * valid_mask_exp).sum(dim=1)\n    divisor = valid_mask.sum(dim=1).unsqueeze(-1).unsqueeze(-1)\n    key_gate_weight = key_sum / divisor  # [num_filtered_chunk, num_head, head_dim]\n\n    # Compute gate logits between key_gate_weight and queries.\n    q_float = q.float()\n    # gate = torch.einsum(\"nhd,shd->nhs\", key_gate_weight, q_float)  # [num_filtered_chunk, num_head, seqlen]\n    gate = torch.bmm(\n        key_gate_weight.permute(1, 0, 2), q_float.permute(1, 0, 2).transpose(1, 2)\n    ).permute(1, 0, 2)\n\n    # Amplify the diagonal (self chunk) contributions.\n    gate_seq_idx = (\n        torch.arange(seqlen, device=q.device, dtype=torch.int32)\n        .unsqueeze(0)\n        .expand(num_filtered_chunk, seqlen)\n    )\n    chunk_start = cu_chunk[filtered_chunk_indices]  # [num_filtered_chunk]\n    chunk_end = cu_chunk[filtered_chunk_indices + 1]  # [num_filtered_chunk]\n    gate_self_chunk_mask = (\n        (\n            (gate_seq_idx >= chunk_start.unsqueeze(1))\n            & (gate_seq_idx < 
chunk_end.unsqueeze(1))\n        )\n        .unsqueeze(1)\n        .expand(-1, num_head, -1)\n    )\n    amplification_factor = 1e9  # Example factor; adjust as needed.\n    origin_gate = gate.clone()\n    gate = gate.clone()\n    if select_mode == \"topk\":\n        gate[gate_self_chunk_mask] += amplification_factor\n\n    # Exclude positions that are outside the valid batch boundaries.\n    batch_starts = cu_seqlens[chunk_to_batch[filtered_chunk_indices]]\n    batch_ends = cu_seqlens[chunk_to_batch[filtered_chunk_indices] + 1]\n    gate_batch_start_mask = gate_seq_idx < batch_starts.unsqueeze(1)\n    gate_batch_end_mask = gate_seq_idx >= batch_ends.unsqueeze(1)\n    gate_inf_mask = gate_batch_start_mask | gate_batch_end_mask\n    gate.masked_fill_(gate_inf_mask.unsqueeze(1), -float(\"inf\"))\n\n    if select_mode == \"topk\":\n        # We amplify self‐chunk in gate already, so self entries will rank highest.\n        valid_gate_mask = gate != -float(\"inf\")\n        if threshold_type == \"query_head\":\n            # === per‐<head,seq> top-k across chunks (original behavior) ===\n            # gate: (C, H, S)\n            _, gate_topk_idx = torch.topk(\n                gate, k=moba_topk, dim=0, largest=True, sorted=False\n            )\n            gate_idx_mask = torch.zeros_like(gate, dtype=torch.bool)\n            gate_idx_mask.scatter_(0, gate_topk_idx, True)\n            gate_mask = valid_gate_mask & gate_idx_mask\n        elif threshold_type == \"overall\":\n            # === global top-k across all (chunk, head, seq) entries ===\n            C, H, S = gate.shape\n            flat_gate = gate.flatten()\n            flat_mask = valid_gate_mask.flatten()\n            flat_gate_masked = torch.where(flat_mask, flat_gate, -float(\"inf\"))\n            # pick topk global entries\n            vals, idx = torch.topk(\n                flat_gate_masked, k=moba_topk * H * S, largest=True, sorted=False\n            )\n            others_mask_flat = 
torch.zeros_like(flat_mask, dtype=torch.bool)\n            others_mask_flat[idx] = True\n            gate_mask = (valid_gate_mask.flatten() & others_mask_flat).view(gate.shape)\n        elif threshold_type == \"head_global\":\n            # per-head top-k across all chunks and sequence positions\n            C, H, S = gate.shape\n            CS = C * S\n            flat_gate = gate.permute(1, 0, 2).reshape(H, CS)\n            flat_valid = valid_gate_mask.permute(1, 0, 2).reshape(H, CS)\n            flat_gate_masked = torch.where(\n                flat_valid, flat_gate, torch.full_like(flat_gate, -float(\"inf\"))\n            )\n            # pick top-k indices per head\n            _, topk_idx = torch.topk(\n                flat_gate_masked, k=moba_topk * S, dim=1, largest=True, sorted=False\n            )\n            gate_idx_flat = torch.zeros_like(flat_valid, dtype=torch.bool)\n            gate_idx_flat.scatter_(1, topk_idx, True)\n            # AND with the valid mask to drop any -inf (out-of-batch) entries topk picked\n            gate_mask = valid_gate_mask & gate_idx_flat.reshape(H, C, S).permute(1, 0, 2)\n        else:\n            raise ValueError(\n                f\"Invalid threshold_type for topk: {threshold_type}. 
\"\n                \"Choose 'query_head', 'overall', or 'head_global'.\"\n            )\n    elif select_mode == \"threshold\":\n        # Delegate to the specific thresholding function\n        valid_gate_mask = gate != -float(\"inf\")  # (num_chunk, num_head, seqlen)\n        if threshold_type == \"query_head\":\n            gate_mask = _select_threshold_query_head(\n                gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold\n            )\n        elif threshold_type == \"block\":\n            gate_mask = _select_threshold_block(\n                gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold\n            )\n        elif threshold_type == \"overall\":\n            gate_mask = _select_threshold_overall(\n                gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold\n            )\n        elif threshold_type == \"head_global\":\n            gate_mask = _select_threshold_head_global(\n                gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold\n            )\n        else:\n            raise ValueError(\n                f\"Invalid threshold_type: {threshold_type}. Choose 'query_head', 'block', 'overall', or 'head_global'.\"\n            )\n    else:\n        raise ValueError(\n            f\"Invalid select_mode: {select_mode}. 
Choose 'topk' or 'threshold'.\"\n        )\n\n    # eliminate self_chunk in MoBA branch\n    gate_mask = gate_mask & ~gate_self_chunk_mask\n    # if gate_mask is all false, perform flash_attn instead\n    if gate_mask.sum() == 0:\n        return flash_attn_varlen_func(\n            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=False\n        )\n\n    # Determine which query positions are selected.\n    # With (head, seq) flattened per chunk, the last tensor returned by nonzero(as_tuple=True)\n    # holds the (h * seqlen + s) index of every selected query, grouped by chunk, then head.\n    moba_q_indices = gate_mask.reshape(gate_mask.shape[0], -1).nonzero(as_tuple=True)[\n        -1\n    ]  # [num_selected]\n    moba_q_sh_indices = (moba_q_indices % seqlen) * num_head + (\n        moba_q_indices // seqlen\n    )\n    moba_q = (\n        rearrange(q, \"s h d -> (h s) d\").index_select(0, moba_q_indices).unsqueeze(1)\n    )\n\n    # Build cumulative sequence lengths for the selected queries.\n    moba_seqlen_q = gate_mask.sum(dim=-1).flatten()\n    q_zero_mask = moba_seqlen_q == 0\n    valid_expert_mask = ~q_zero_mask\n    if q_zero_mask.sum() > 0:\n        moba_seqlen_q = moba_seqlen_q[valid_expert_mask]\n    moba_cu_seqlen_q = torch.cat(\n        (\n            torch.tensor([0], device=q.device, dtype=moba_seqlen_q.dtype),\n            moba_seqlen_q.cumsum(dim=0),\n        ),\n        dim=0,\n    ).to(torch.int32)\n\n    # Rearrange gathered KV for the MOBA branch.\n    experts_tensor = rearrange(gathered, \"nc cl two h d -> (nc h) cl two d\")\n    valid_expert_lengths = (\n        chunk_lengths.unsqueeze(1)\n        .expand(num_filtered_chunk, num_head)\n        .reshape(-1)\n        .to(torch.int32)\n    )\n    if q_zero_mask.sum() > 0:\n        experts_tensor = experts_tensor[valid_expert_mask]\n        valid_expert_lengths = valid_expert_lengths[valid_expert_mask]\n\n    seq_range = torch.arange(\n        experts_tensor.shape[1], device=experts_tensor.device\n    ).unsqueeze(0)\n    mask = seq_range < valid_expert_lengths.unsqueeze(1)\n    
moba_kv = experts_tensor[mask]  # Shape: ((nc h cl_valid) two d)\n    moba_kv = moba_kv.unsqueeze(2)  # Shape: ((nc h cl_valid) two 1 d)\n\n    moba_cu_seqlen_kv = torch.cat(\n        [\n            torch.zeros(1, device=experts_tensor.device, dtype=torch.int32),\n            valid_expert_lengths.cumsum(dim=0),\n        ],\n        dim=0,\n    ).to(torch.int32)\n\n    assert (\n        moba_cu_seqlen_kv.shape == moba_cu_seqlen_q.shape\n    ), f\"Mismatch between moba_cu_seqlen_kv.shape and moba_cu_seqlen_q.shape: {moba_cu_seqlen_kv.shape} vs {moba_cu_seqlen_q.shape}\"\n\n    return MixedAttention.apply(\n        q,\n        k,\n        v,\n        self_attn_cu_seqlen,\n        moba_q,\n        moba_kv,\n        moba_cu_seqlen_q,\n        moba_cu_seqlen_kv,\n        max_seqlen,\n        moba_chunk_size,\n        moba_q_sh_indices,\n    )\n\n\ndef process_moba_input(\n    x,\n    patch_resolution,\n    chunk_size,\n):\n    \"\"\"\n    Process inputs for the attention function.\n\n    Args:\n        x (torch.Tensor): Input tensor with shape [batch_size, num_patches, num_heads, head_dim].\n        patch_resolution (tuple): Tuple containing the patch resolution (t, h, w).\n        chunk_size (int): Size of the chunk. 
May be an int or float (number of frames per chunk), or a tuple/list of length 2 (h, w) or 3 (t, h, w) giving the chunk shape in patches.\n\n    Returns:\n        Tuple[torch.Tensor, int]: The input tensor, reordered so that each chunk is contiguous (unchanged for a scalar chunk_size), and the number of tokens per chunk (moba_chunk_size).\n    \"\"\"\n    if isinstance(chunk_size, float) or isinstance(chunk_size, int):\n        moba_chunk_size = int(chunk_size * patch_resolution[1] * patch_resolution[2])\n    else:\n        assert isinstance(\n            chunk_size, (tuple, list)\n        ), f\"chunk_size should be an int, float, tuple, or list, now it is: {type(chunk_size)}\"\n        if len(chunk_size) == 2:\n            assert (\n                patch_resolution[1] % chunk_size[0] == 0\n                and patch_resolution[2] % chunk_size[1] == 0\n            ), f\"spatial patch_resolution {patch_resolution[1:]} should be divisible by 2d chunk_size {chunk_size}\"\n            nch, ncw = (\n                patch_resolution[1] // chunk_size[0],\n                patch_resolution[2] // chunk_size[1],\n            )\n            x = rearrange(\n                x,\n                \"b (t nch ch ncw cw) n d -> b (nch ncw t ch cw) n d\",\n                t=patch_resolution[0],\n                nch=nch,\n                ncw=ncw,\n                ch=chunk_size[0],\n                cw=chunk_size[1],\n            )\n            moba_chunk_size = patch_resolution[0] * chunk_size[0] * chunk_size[1]\n        elif len(chunk_size) == 3:\n            assert (\n                patch_resolution[0] % chunk_size[0] == 0\n                and patch_resolution[1] % chunk_size[1] == 0\n                and patch_resolution[2] % chunk_size[2] == 0\n            ), f\"patch_resolution {patch_resolution} should be divisible by 3d chunk_size {chunk_size}\"\n            nct, nch, ncw = (\n                patch_resolution[0] // chunk_size[0],\n                patch_resolution[1] // chunk_size[1],\n                patch_resolution[2] // chunk_size[2],\n            )\n            x = rearrange(\n                x,\n                \"b (nct ct nch ch ncw cw) n d -> b (nct nch ncw ct ch cw) n d\",\n                nct=nct,\n          
      nch=nch,\n                ncw=ncw,\n                ct=chunk_size[0],\n                ch=chunk_size[1],\n                cw=chunk_size[2],\n            )\n            moba_chunk_size = chunk_size[0] * chunk_size[1] * chunk_size[2]\n        else:\n            raise ValueError(\n                f\"chunk_size should be an int, or a tuple of length 2 or 3, got length: {len(chunk_size)}\"\n            )\n\n    return x, moba_chunk_size\n\n\ndef process_moba_output(\n    x,\n    patch_resolution,\n    chunk_size,\n):\n    if isinstance(chunk_size, float) or isinstance(chunk_size, int):\n        pass\n    elif len(chunk_size) == 2:\n        x = rearrange(\n            x,\n            \"b (nch ncw t ch cw) n d -> b (t nch ch ncw cw) n d\",\n            nch=patch_resolution[1] // chunk_size[0],\n            ncw=patch_resolution[2] // chunk_size[1],\n            t=patch_resolution[0],\n            ch=chunk_size[0],\n            cw=chunk_size[1],\n        )\n    elif len(chunk_size) == 3:\n        x = rearrange(\n            x,\n            \"b (nct nch ncw ct ch cw) n d -> b (nct ct nch ch ncw cw) n d\",\n            nct=patch_resolution[0] // chunk_size[0],\n            nch=patch_resolution[1] // chunk_size[1],\n            ncw=patch_resolution[2] // chunk_size[2],\n            ct=chunk_size[0],\n            ch=chunk_size[1],\n            cw=chunk_size[2],\n        )\n\n    return x\n\n\n# TEST\ndef generate_data(batch_size, seqlen, num_head, head_dim, dtype):\n    random.seed(0)\n    torch.manual_seed(0)\n    torch.cuda.manual_seed(0)\n    device = torch.cuda.current_device()\n\n    q = torch.randn((batch_size, seqlen, num_head, head_dim), requires_grad=True).to(\n        dtype=dtype, device=\"cuda\"\n    )\n    k = torch.randn((batch_size, seqlen, num_head, head_dim), requires_grad=True).to(\n        dtype=dtype, device=\"cuda\"\n    )\n    v = torch.randn((batch_size, seqlen, num_head, head_dim), requires_grad=True).to(\n        dtype=dtype, device=\"cuda\"\n    )\n 
   print(f\"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}\")\n    cu_seqlens = torch.arange(\n        0, q.shape[0] * q.shape[1] + 1, q.shape[1], dtype=torch.int32, device=\"cuda\"\n    )\n    max_seqlen = q.shape[1]\n    q = rearrange(q, \"b s ... -> (b s) ...\")\n    k = rearrange(k, \"b s ... -> (b s) ...\")\n    v = rearrange(v, \"b s ... -> (b s) ...\")\n\n    return q, k, v, cu_seqlens, max_seqlen\n\n\ndef test_attn_varlen_moba_speed(\n    batch,\n    head,\n    seqlen,\n    head_dim,\n    moba_chunk_size,\n    moba_topk,\n    dtype=torch.bfloat16,\n    select_mode=\"threshold\",\n    simsum_threshold=0.25,\n    threshold_type=\"query_head\",\n):\n    \"\"\"Speed test comparing flash_attn vs moba_attention\"\"\"\n    # Get data\n    q, k, v, cu_seqlen, max_seqlen = generate_data(batch, seqlen, head, head_dim, dtype)\n    print(\n        f\"batch:{batch} head:{head} seqlen:{seqlen} chunk:{moba_chunk_size} topk:{moba_topk} select_mode: {select_mode} simsum_threshold:{simsum_threshold}\"\n    )\n    vo_grad = torch.randn_like(q)\n\n    # Iteration counts\n    warmup_iters = 3\n    perf_test_iters = 10\n\n    # Warmup\n    for _ in range(warmup_iters):\n        o = flash_attn_varlen_func(\n            q, k, v, cu_seqlen, cu_seqlen, max_seqlen, max_seqlen, causal=False\n        )\n        torch.autograd.backward(o, vo_grad)\n\n    torch.cuda.synchronize()\n    start_flash = time.perf_counter()\n    for _ in range(perf_test_iters):\n        o = flash_attn_varlen_func(\n            q, k, v, cu_seqlen, cu_seqlen, max_seqlen, max_seqlen, causal=False\n        )\n        torch.autograd.backward(o, vo_grad)\n\n    torch.cuda.synchronize()\n    time_flash = (time.perf_counter() - start_flash) / perf_test_iters * 1000\n\n    # Warmup\n    for _ in range(warmup_iters):\n        om = moba_attn_varlen(\n            q,\n            k,\n            v,\n            cu_seqlen,\n            max_seqlen,\n            moba_chunk_size=moba_chunk_size,\n            
moba_topk=moba_topk,\n            select_mode=select_mode,\n            simsum_threshold=simsum_threshold,\n            threshold_type=threshold_type,\n        )\n        torch.autograd.backward(om, vo_grad)\n\n    torch.cuda.synchronize()\n    start_moba = time.perf_counter()\n    for _ in range(perf_test_iters):\n        om = moba_attn_varlen(\n            q,\n            k,\n            v,\n            cu_seqlen,\n            max_seqlen,\n            moba_chunk_size=moba_chunk_size,\n            moba_topk=moba_topk,\n            select_mode=select_mode,\n            simsum_threshold=simsum_threshold,\n            threshold_type=threshold_type,\n        )\n        torch.autograd.backward(om, vo_grad)\n\n    torch.cuda.synchronize()\n    time_moba = (time.perf_counter() - start_moba) / perf_test_iters * 1000\n\n    print(f\"Flash: {time_flash:.2f}ms, MoBA: {time_moba:.2f}ms\")\n    print(f\"Speedup:  {time_flash / time_moba:.2f}x\")\n\n\nif __name__ == \"__main__\":\n    \"\"\"\n    CUDA_VISIBLE_DEVICES=1 \\\n    python -u csrc/attn/vmoba_attn/vmoba/vmoba.py\n    \"\"\"\n    test_attn_varlen_moba_speed(\n        batch=1,\n        head=12,\n        seqlen=32760,\n        head_dim=128,\n        moba_chunk_size=32760 // 3 // 6 // 4,\n        moba_topk=3,\n        select_mode=\"threshold\",\n        simsum_threshold=0.3,\n        threshold_type=\"query_head\",\n    )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/__init__.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nCustom CUDA rasterizer for Hunyuan3D texture generation.\n\nThis module provides JIT-compiled CUDA rasterization for fast mesh rendering.\nAdapted from Hunyuan3D-2: https://github.com/Tencent/Hunyuan3D-2\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nfrom typing import List, Tuple\n\nimport torch\n\n_abs_path = os.path.dirname(os.path.abspath(__file__))\n_custom_rasterizer_kernel = None\n\n\ndef _load_custom_rasterizer():\n    \"\"\"JIT compile and load the custom rasterizer kernel.\"\"\"\n    global _custom_rasterizer_kernel\n\n    if _custom_rasterizer_kernel is not None:\n        return _custom_rasterizer_kernel\n\n    from torch.utils.cpp_extension import load\n\n    _custom_rasterizer_kernel = load(\n        name=\"custom_rasterizer_kernel\",\n        sources=[\n            f\"{_abs_path}/rasterizer.cpp\",\n            f\"{_abs_path}/rasterizer_gpu.cu\",\n        ],\n        extra_cflags=[\"-O3\"],\n        extra_cuda_cflags=[\"-O3\", \"--use_fast_math\"],\n        verbose=False,\n    )\n    return _custom_rasterizer_kernel\n\n\ndef rasterize(\n    pos: torch.Tensor,\n    tri: torch.Tensor,\n    resolution: Tuple[int, int],\n    clamp_depth: torch.Tensor = None,\n    use_depth_prior: int = 0,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Rasterize mesh to get face indices and barycentric coordinates.\"\"\"\n    kernel = _load_custom_rasterizer()\n\n    if clamp_depth is None:\n        clamp_depth = torch.zeros(0, device=pos.device)\n\n    # pos should be [N, 4], remove batch dim if present\n    if pos.dim() == 3:\n        pos = pos[0]\n\n    findices, barycentric = kernel.rasterize_image(\n        pos, tri, clamp_depth, resolution[1], resolution[0], 1e-6, use_depth_prior\n    )\n    return findices, barycentric\n\n\ndef interpolate(\n    col: torch.Tensor,\n    findices: torch.Tensor,\n    barycentric: torch.Tensor,\n    tri: torch.Tensor,\n) -> torch.Tensor:\n    
\"\"\"Interpolate vertex attributes using barycentric coordinates.\"\"\"\n    # Handle zero indices (background)\n    f = findices - 1 + (findices == 0)\n    vcol = col[0, tri.long()[f.long()]]\n    result = barycentric.view(*barycentric.shape, 1) * vcol\n    result = torch.sum(result, axis=-2)\n    return result.view(1, *result.shape)\n\n\n__all__ = [\"rasterize\", \"interpolate\"]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer.cpp",
    "content": "// SPDX-License-Identifier: Apache-2.0\n// Adapted from Hunyuan3D-2: https://github.com/Tencent/Hunyuan3D-2\n// Original license: TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n\n#include \"rasterizer.h\"\n\nvoid rasterizeTriangleCPU(int idx, float* vt0, float* vt1, float* vt2, int width, int height, INT64* zbuffer, float* d, float occlusion_truncation) {\n    float x_min = std::min(vt0[0], std::min(vt1[0],vt2[0]));\n    float x_max = std::max(vt0[0], std::max(vt1[0],vt2[0]));\n    float y_min = std::min(vt0[1], std::min(vt1[1],vt2[1]));\n    float y_max = std::max(vt0[1], std::max(vt1[1],vt2[1]));\n\n    for (int px = x_min; px < x_max + 1; ++px) {\n        if (px < 0 || px >= width)\n            continue;\n        for (int py = y_min; py < y_max + 1; ++py) {\n            if (py < 0 || py >= height)\n                continue;\n            float vt[2] = {px + 0.5f, py + 0.5f};\n            float baryCentricCoordinate[3];\n            calculateBarycentricCoordinate(vt0, vt1, vt2, vt, baryCentricCoordinate);\n            if (isBarycentricCoordInBounds(baryCentricCoordinate)) {\n                int pixel = py * width + px;\n                if (zbuffer == 0) {\n                    zbuffer[pixel] = (INT64)(idx + 1);\n                    continue;\n                }\n\n                float depth = baryCentricCoordinate[0] * vt0[2] + baryCentricCoordinate[1] * vt1[2] + baryCentricCoordinate[2] * vt2[2];\n                float depth_thres = 0;\n                if (d) {\n                    depth_thres = d[pixel] * 0.49999f + 0.5f + occlusion_truncation;\n                }\n                \n                int z_quantize = depth * (2<<17);\n                INT64 token = (INT64)z_quantize * MAXINT + (INT64)(idx + 1);\n                if (depth < depth_thres)\n                    continue;\n                zbuffer[pixel] = std::min(zbuffer[pixel], token);\n            }\n        }\n    }\n}\n\nvoid barycentricFromImgcoordCPU(float* V, int* F, int* 
findices, INT64* zbuffer, int width, int height, int num_vertices, int num_faces,\n    float* barycentric_map, int pix)\n{\n    INT64 f = zbuffer[pix] % MAXINT;\n    if (f == (MAXINT-1)) {\n        findices[pix] = 0;\n        barycentric_map[pix * 3] = 0;\n        barycentric_map[pix * 3 + 1] = 0;\n        barycentric_map[pix * 3 + 2] = 0;\n        return;\n    }\n    findices[pix] = f;\n    f -= 1;\n    float barycentric[3] = {0, 0, 0};\n    if (f >= 0) {\n        float vt[2] = {float(pix % width) + 0.5f, float(pix / width) + 0.5f};\n        float* vt0_ptr = V + (F[f * 3] * 4);\n        float* vt1_ptr = V + (F[f * 3 + 1] * 4);\n        float* vt2_ptr = V + (F[f * 3 + 2] * 4);\n\n        float vt0[2] = {(vt0_ptr[0] / vt0_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt0_ptr[1] / vt0_ptr[3]) * (height - 1) + 0.5f};\n        float vt1[2] = {(vt1_ptr[0] / vt1_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt1_ptr[1] / vt1_ptr[3]) * (height - 1) + 0.5f};\n        float vt2[2] = {(vt2_ptr[0] / vt2_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt2_ptr[1] / vt2_ptr[3]) * (height - 1) + 0.5f};\n\n        calculateBarycentricCoordinate(vt0, vt1, vt2, vt, barycentric);\n\n        barycentric[0] = barycentric[0] / vt0_ptr[3];\n        barycentric[1] = barycentric[1] / vt1_ptr[3];\n        barycentric[2] = barycentric[2] / vt2_ptr[3];\n        float w = 1.0f / (barycentric[0] + barycentric[1] + barycentric[2]);\n        barycentric[0] *= w;\n        barycentric[1] *= w;\n        barycentric[2] *= w;\n    }\n    barycentric_map[pix * 3] = barycentric[0];\n    barycentric_map[pix * 3 + 1] = barycentric[1];\n    barycentric_map[pix * 3 + 2] = barycentric[2];\n}\n\nvoid rasterizeImagecoordsKernelCPU(float* V, int* F, float* d, INT64* zbuffer, float occlusion_trunc, int width, int height, int num_vertices, int num_faces, int f)\n{\n    float* vt0_ptr = V + (F[f * 3] * 4);\n    float* vt1_ptr = V + (F[f * 3 + 1] * 4);\n    float* vt2_ptr = V + 
(F[f * 3 + 2] * 4);\n\n    float vt0[3] = {(vt0_ptr[0] / vt0_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt0_ptr[1] / vt0_ptr[3]) * (height - 1) + 0.5f, vt0_ptr[2] / vt0_ptr[3] * 0.49999f + 0.5f};\n    float vt1[3] = {(vt1_ptr[0] / vt1_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt1_ptr[1] / vt1_ptr[3]) * (height - 1) + 0.5f, vt1_ptr[2] / vt1_ptr[3] * 0.49999f + 0.5f};\n    float vt2[3] = {(vt2_ptr[0] / vt2_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt2_ptr[1] / vt2_ptr[3]) * (height - 1) + 0.5f, vt2_ptr[2] / vt2_ptr[3] * 0.49999f + 0.5f};\n\n    rasterizeTriangleCPU(f, vt0, vt1, vt2, width, height, zbuffer, d, occlusion_trunc);\n}\n\nstd::vector<torch::Tensor> rasterize_image_cpu(torch::Tensor V, torch::Tensor F, torch::Tensor D,\n    int width, int height, float occlusion_truncation, int use_depth_prior)\n{\n    int num_faces = F.size(0);\n    int num_vertices = V.size(0);\n    auto options = torch::TensorOptions().dtype(torch::kInt32).requires_grad(false);\n    auto INT64_options = torch::TensorOptions().dtype(torch::kInt64).requires_grad(false);\n    auto findices = torch::zeros({height, width}, options);\n    INT64 maxint = (INT64)MAXINT * (INT64)MAXINT + (MAXINT - 1);\n    auto z_min = torch::ones({height, width}, INT64_options) * (int64_t)maxint;\n\n    if (!use_depth_prior) {\n        for (int i = 0; i < num_faces; ++i) {\n            rasterizeImagecoordsKernelCPU(V.data_ptr<float>(), F.data_ptr<int>(), 0,\n                (INT64*)z_min.data_ptr<int64_t>(), occlusion_truncation, width, height, num_vertices, num_faces, i); \n        }\n    } else {\n        for (int i = 0; i < num_faces; ++i)\n            rasterizeImagecoordsKernelCPU(V.data_ptr<float>(), F.data_ptr<int>(), D.data_ptr<float>(),\n                (INT64*)z_min.data_ptr<int64_t>(), occlusion_truncation, width, height, num_vertices, num_faces, i);\n    }\n\n    auto float_options = torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false);\n 
   auto barycentric = torch::zeros({height, width, 3}, float_options);\n    for (int i = 0; i < width * height; ++i)\n        barycentricFromImgcoordCPU(V.data_ptr<float>(), F.data_ptr<int>(),\n            findices.data_ptr<int>(), (INT64*)z_min.data_ptr<int64_t>(), width, height, num_vertices, num_faces, barycentric.data_ptr<float>(), i);\n\n    return {findices, barycentric};\n}\n\nstd::vector<torch::Tensor> rasterize_image(torch::Tensor V, torch::Tensor F, torch::Tensor D,\n    int width, int height, float occlusion_truncation, int use_depth_prior)\n{\n    int device_id = V.get_device();\n    if (device_id == -1)\n        return rasterize_image_cpu(V, F, D, width, height, occlusion_truncation, use_depth_prior);\n    else\n        return rasterize_image_gpu(V, F, D, width, height, occlusion_truncation, use_depth_prior);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"rasterize_image\", &rasterize_image, \"Custom image rasterization\");\n}\n"
  },
  {
    "path": "python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer.h",
    "content": "// SPDX-License-Identifier: Apache-2.0\n// Adapted from Hunyuan3D-2: https://github.com/Tencent/Hunyuan3D-2\n// Original license: TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n\n#ifndef RASTERIZER_H_\n#define RASTERIZER_H_\n\n#include <torch/extension.h>\n#include <vector>\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#define INT64 unsigned long long\n#define MAXINT 2147483647\n\n__host__ __device__ inline float calculateSignedArea2(float* a, float* b, float* c) {\n    return ((c[0] - a[0]) * (b[1] - a[1]) - (b[0] - a[0]) * (c[1] - a[1]));\n}\n\n__host__ __device__ inline void calculateBarycentricCoordinate(float* a, float* b, float* c, float* p,\n    float* barycentric)\n{\n    float beta_tri = calculateSignedArea2(a, p, c);\n    float gamma_tri = calculateSignedArea2(a, b, p);\n    float area = calculateSignedArea2(a, b, c);\n    if (area == 0) {\n        barycentric[0] = -1.0;\n        barycentric[1] = -1.0;\n        barycentric[2] = -1.0;\n        return;\n    }\n    float tri_inv = 1.0 / area;\n    float beta = beta_tri * tri_inv;\n    float gamma = gamma_tri * tri_inv;\n    float alpha = 1.0 - beta - gamma;\n    barycentric[0] = alpha;\n    barycentric[1] = beta;\n    barycentric[2] = gamma;\n}\n\n__host__ __device__ inline bool isBarycentricCoordInBounds(float* barycentricCoord) {\n    return barycentricCoord[0] >= 0.0 && barycentricCoord[0] <= 1.0 &&\n           barycentricCoord[1] >= 0.0 && barycentricCoord[1] <= 1.0 &&\n           barycentricCoord[2] >= 0.0 && barycentricCoord[2] <= 1.0;\n}\n\nstd::vector<torch::Tensor> rasterize_image_gpu(torch::Tensor V, torch::Tensor F, torch::Tensor D,\n    int width, int height, float occlusion_truncation, int use_depth_prior);\n\n#endif\n"
  },
  {
    "path": "python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer_gpu.cu",
    "content": "// SPDX-License-Identifier: Apache-2.0\n// Adapted from Hunyuan3D-2: https://github.com/Tencent/Hunyuan3D-2\n// Original license: TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n\n#include \"rasterizer.h\"\n\n__device__ void rasterizeTriangleGPU(int idx, float* vt0, float* vt1, float* vt2, int width, int height, INT64* zbuffer, float* d, float occlusion_truncation) {\n    float x_min = std::min(vt0[0], std::min(vt1[0],vt2[0]));\n    float x_max = std::max(vt0[0], std::max(vt1[0],vt2[0]));\n    float y_min = std::min(vt0[1], std::min(vt1[1],vt2[1]));\n    float y_max = std::max(vt0[1], std::max(vt1[1],vt2[1]));\n\n    for (int px = x_min; px < x_max + 1; ++px) {\n        if (px < 0 || px >= width)\n            continue;\n        for (int py = y_min; py < y_max + 1; ++py) {\n            if (py < 0 || py >= height)\n                continue;\n            float vt[2] = {px + 0.5f, py + 0.5f};\n            float baryCentricCoordinate[3];\n            calculateBarycentricCoordinate(vt0, vt1, vt2, vt, baryCentricCoordinate);\n            if (isBarycentricCoordInBounds(baryCentricCoordinate)) {\n                int pixel = py * width + px;\n                if (zbuffer == 0) {\n                    atomicExch(&zbuffer[pixel], (INT64)(idx + 1));\n                    continue;\n                }\n                float depth = baryCentricCoordinate[0] * vt0[2] + baryCentricCoordinate[1] * vt1[2] + baryCentricCoordinate[2] * vt2[2];\n                float depth_thres = 0;\n                if (d) {\n                    depth_thres = d[pixel] * 0.49999f + 0.5f + occlusion_truncation;\n                }\n                \n                int z_quantize = depth * (2<<17);\n                INT64 token = (INT64)z_quantize * MAXINT + (INT64)(idx + 1);\n                if (depth < depth_thres)\n                    continue;\n                atomicMin(&zbuffer[pixel], token);\n            }\n        }\n    }\n}\n\n__global__ void barycentricFromImgcoordGPU(float* V, 
int* F, int* findices, INT64* zbuffer, int width, int height, int num_vertices, int num_faces,\n    float* barycentric_map)\n{\n    int pix = blockIdx.x * blockDim.x + threadIdx.x;\n    if (pix >= width * height)\n        return;\n    INT64 f = zbuffer[pix] % MAXINT;\n    if (f == (MAXINT-1)) {\n        findices[pix] = 0;\n        barycentric_map[pix * 3] = 0;\n        barycentric_map[pix * 3 + 1] = 0;\n        barycentric_map[pix * 3 + 2] = 0;\n        return;\n    }\n    findices[pix] = f;\n    f -= 1;\n    float barycentric[3] = {0, 0, 0};\n    if (f >= 0) {\n        float vt[2] = {float(pix % width) + 0.5f, float(pix / width) + 0.5f};\n        float* vt0_ptr = V + (F[f * 3] * 4);\n        float* vt1_ptr = V + (F[f * 3 + 1] * 4);\n        float* vt2_ptr = V + (F[f * 3 + 2] * 4);\n\n        float vt0[2] = {(vt0_ptr[0] / vt0_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt0_ptr[1] / vt0_ptr[3]) * (height - 1) + 0.5f};\n        float vt1[2] = {(vt1_ptr[0] / vt1_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt1_ptr[1] / vt1_ptr[3]) * (height - 1) + 0.5f};\n        float vt2[2] = {(vt2_ptr[0] / vt2_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt2_ptr[1] / vt2_ptr[3]) * (height - 1) + 0.5f};\n\n        calculateBarycentricCoordinate(vt0, vt1, vt2, vt, barycentric);\n\n        barycentric[0] = barycentric[0] / vt0_ptr[3];\n        barycentric[1] = barycentric[1] / vt1_ptr[3];\n        barycentric[2] = barycentric[2] / vt2_ptr[3];\n        float w = 1.0f / (barycentric[0] + barycentric[1] + barycentric[2]);\n        barycentric[0] *= w;\n        barycentric[1] *= w;\n        barycentric[2] *= w;\n    }\n    barycentric_map[pix * 3] = barycentric[0];\n    barycentric_map[pix * 3 + 1] = barycentric[1];\n    barycentric_map[pix * 3 + 2] = barycentric[2];\n}\n\n__global__ void rasterizeImagecoordsKernelGPU(float* V, int* F, float* d, INT64* zbuffer, float occlusion_trunc, int width, int height, int num_vertices, int num_faces)\n{\n    
int f = blockIdx.x * blockDim.x + threadIdx.x;\n    if (f >= num_faces)\n        return; \n\n    float* vt0_ptr = V + (F[f * 3] * 4);\n    float* vt1_ptr = V + (F[f * 3 + 1] * 4);\n    float* vt2_ptr = V + (F[f * 3 + 2] * 4);\n\n    float vt0[3] = {(vt0_ptr[0] / vt0_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt0_ptr[1] / vt0_ptr[3]) * (height - 1) + 0.5f, vt0_ptr[2] / vt0_ptr[3] * 0.49999f + 0.5f};\n    float vt1[3] = {(vt1_ptr[0] / vt1_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt1_ptr[1] / vt1_ptr[3]) * (height - 1) + 0.5f, vt1_ptr[2] / vt1_ptr[3] * 0.49999f + 0.5f};\n    float vt2[3] = {(vt2_ptr[0] / vt2_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt2_ptr[1] / vt2_ptr[3]) * (height - 1) + 0.5f, vt2_ptr[2] / vt2_ptr[3] * 0.49999f + 0.5f};\n\n    rasterizeTriangleGPU(f, vt0, vt1, vt2, width, height, zbuffer, d, occlusion_trunc);\n}\n\nstd::vector<torch::Tensor> rasterize_image_gpu(torch::Tensor V, torch::Tensor F, torch::Tensor D,\n    int width, int height, float occlusion_truncation, int use_depth_prior)\n{\n    int device_id = V.get_device();\n    cudaSetDevice(device_id);\n    int num_faces = F.size(0);\n    int num_vertices = V.size(0);\n    auto options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, device_id).requires_grad(false);\n    auto INT64_options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA, device_id).requires_grad(false);\n    auto findices = torch::zeros({height, width}, options);\n    INT64 maxint = (INT64)MAXINT * (INT64)MAXINT + (MAXINT - 1);\n    auto z_min = torch::ones({height, width}, INT64_options) * (int64_t)maxint;\n\n    if (!use_depth_prior) {\n        rasterizeImagecoordsKernelGPU<<<(num_faces+255)/256,256,0,at::cuda::getCurrentCUDAStream()>>>(V.data_ptr<float>(), F.data_ptr<int>(), 0,\n            (INT64*)z_min.data_ptr<int64_t>(), occlusion_truncation, width, height, num_vertices, num_faces); \n    } else {\n        
rasterizeImagecoordsKernelGPU<<<(num_faces+255)/256,256,0,at::cuda::getCurrentCUDAStream()>>>(V.data_ptr<float>(), F.data_ptr<int>(), D.data_ptr<float>(),\n            (INT64*)z_min.data_ptr<int64_t>(), occlusion_truncation, width, height, num_vertices, num_faces); \n    }\n\n    auto float_options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, device_id).requires_grad(false);\n    auto barycentric = torch::zeros({height, width, 3}, float_options);\n    barycentricFromImgcoordGPU<<<(width * height + 255)/256, 256, 0, at::cuda::getCurrentCUDAStream()>>>(V.data_ptr<float>(), F.data_ptr<int>(),\n        findices.data_ptr<int>(), (INT64*)z_min.data_ptr<int64_t>(), width, height, num_vertices, num_faces, barycentric.data_ptr<float>());\n\n    return {findices, barycentric};\n}\n"
  },
  {
    "path": "python/sglang/multimodal_gen/csrc/render/mesh_processor/__init__.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nMesh processor C++ extension for texture inpainting.\n\nThis module provides JIT-compiled C++ mesh processing for fast texture inpainting.\nAdapted from Hunyuan3D-2: https://github.com/Tencent/Hunyuan3D-2\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nfrom typing import Tuple\n\nimport numpy as np\n\n_abs_path = os.path.dirname(os.path.abspath(__file__))\n_mesh_processor_kernel = None\n\n\ndef _load_mesh_processor():\n    \"\"\"JIT compile and load the mesh processor kernel.\"\"\"\n    global _mesh_processor_kernel\n\n    if _mesh_processor_kernel is not None:\n        return _mesh_processor_kernel\n\n    from torch.utils.cpp_extension import load\n\n    _mesh_processor_kernel = load(\n        name=\"mesh_processor_kernel\",\n        sources=[\n            f\"{_abs_path}/mesh_processor.cpp\",\n        ],\n        extra_cflags=[\"-O3\"],\n        verbose=False,\n    )\n    return _mesh_processor_kernel\n\n\ndef meshVerticeInpaint(\n    texture: np.ndarray,\n    mask: np.ndarray,\n    vtx_pos: np.ndarray,\n    vtx_uv: np.ndarray,\n    pos_idx: np.ndarray,\n    uv_idx: np.ndarray,\n    method: str = \"smooth\",\n) -> Tuple[np.ndarray, np.ndarray]:\n    \"\"\"Inpaint texture using mesh vertex connectivity.\"\"\"\n    kernel = _load_mesh_processor()\n\n    texture = np.ascontiguousarray(texture, dtype=np.float32)\n    mask = np.ascontiguousarray(mask, dtype=np.uint8)\n    vtx_pos = np.ascontiguousarray(vtx_pos, dtype=np.float32)\n    vtx_uv = np.ascontiguousarray(vtx_uv, dtype=np.float32)\n    pos_idx = np.ascontiguousarray(pos_idx, dtype=np.int32)\n    uv_idx = np.ascontiguousarray(uv_idx, dtype=np.int32)\n\n    return kernel.meshVerticeInpaint(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx, method)\n\n\n__all__ = [\"meshVerticeInpaint\"]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/csrc/render/mesh_processor/mesh_processor.cpp",
    "content": "// SPDX-License-Identifier: Apache-2.0\n// Adapted from Hunyuan3D-2: https://github.com/Tencent/Hunyuan3D-2\n// Original license: TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT\n\n#include <vector>\n#include <queue>\n#include <cmath>\n#include <algorithm>\n#include <torch/extension.h>\n#include <pybind11/pybind11.h>\n#include <pybind11/numpy.h>\n#include <pybind11/stl.h>\n\nnamespace py = pybind11;\nusing namespace std;\n\nstd::pair<py::array_t<float>,\n  py::array_t<uint8_t>>  meshVerticeInpaint_smooth(py::array_t<float> texture,\npy::array_t<uint8_t> mask,\n                 py::array_t<float> vtx_pos, py::array_t<float> vtx_uv, \n                 py::array_t<int> pos_idx, py::array_t<int> uv_idx) {\n    auto texture_buf = texture.request();\n    auto mask_buf = mask.request();\n    auto vtx_pos_buf = vtx_pos.request();\n    auto vtx_uv_buf = vtx_uv.request();\n    auto pos_idx_buf = pos_idx.request();\n    auto uv_idx_buf = uv_idx.request();\n\n    int texture_height = texture_buf.shape[0];\n    int texture_width = texture_buf.shape[1];\n    int texture_channel = texture_buf.shape[2];\n    float* texture_ptr = static_cast<float*>(texture_buf.ptr);\n    uint8_t* mask_ptr = static_cast<uint8_t*>(mask_buf.ptr);\n\n    int vtx_num = vtx_pos_buf.shape[0];\n    float* vtx_pos_ptr = static_cast<float*>(vtx_pos_buf.ptr);\n    float* vtx_uv_ptr = static_cast<float*>(vtx_uv_buf.ptr);\n    int* pos_idx_ptr = static_cast<int*>(pos_idx_buf.ptr);\n    int* uv_idx_ptr = static_cast<int*>(uv_idx_buf.ptr);\n\n    vector<float> vtx_mask(vtx_num, 0.0f);\n    vector<vector<float>> vtx_color(vtx_num, vector<float>(texture_channel, 0.0f));\n    vector<int> uncolored_vtxs;\n\n    vector<vector<int>> G(vtx_num);\n\n    for (int i = 0; i < uv_idx_buf.shape[0]; ++i) {\n        for (int k = 0; k < 3; ++k) {\n            int vtx_uv_idx = uv_idx_ptr[i * 3 + k];\n            int vtx_idx = pos_idx_ptr[i * 3 + k];\n            int uv_v = round(vtx_uv_ptr[vtx_uv_idx * 2] * 
(texture_width - 1));\n            int uv_u = round((1.0 - vtx_uv_ptr[vtx_uv_idx * 2 + 1]) * (texture_height - 1));\n\n            if (mask_ptr[uv_u * texture_width + uv_v] > 0) {\n                vtx_mask[vtx_idx] = 1.0f;\n                for (int c = 0; c < texture_channel; ++c) {\n                    vtx_color[vtx_idx][c] = texture_ptr[(uv_u * texture_width + uv_v) * texture_channel + c];\n                }\n            }else{\n                uncolored_vtxs.push_back(vtx_idx);\n            }\n\n            G[pos_idx_ptr[i * 3 + k]].push_back(pos_idx_ptr[i * 3 + (k + 1) % 3]);\n        }\n    }\n\n    int smooth_count = 2;\n    int last_uncolored_vtx_count = 0;\n    while (smooth_count>0) {\n        int uncolored_vtx_count = 0;\n\n        for (int vtx_idx : uncolored_vtxs) {\n\n            vector<float> sum_color(texture_channel, 0.0f);\n            float total_weight = 0.0f;\n\n            array<float, 3> vtx_0 = {vtx_pos_ptr[vtx_idx * 3],\nvtx_pos_ptr[vtx_idx * 3 + 1], vtx_pos_ptr[vtx_idx * 3 + 2]};\n            for (int connected_idx : G[vtx_idx]) {\n                if (vtx_mask[connected_idx] > 0) {\n                    array<float, 3> vtx1 = {vtx_pos_ptr[connected_idx * 3],\n                    vtx_pos_ptr[connected_idx * 3 + 1], vtx_pos_ptr[connected_idx * 3 + 2]};\n                    float dist_weight = 1.0f / max(sqrt(pow(vtx_0[0] - vtx1[0], 2) + pow(vtx_0[1] - vtx1[1], 2) + \\\n                     pow(vtx_0[2] - vtx1[2], 2)), 1E-4);\n                    dist_weight = dist_weight * dist_weight;\n                    for (int c = 0; c < texture_channel; ++c) {\n                        sum_color[c] += vtx_color[connected_idx][c] * dist_weight;\n                    }\n                    total_weight += dist_weight;\n                }\n            }\n\n            if (total_weight > 0.0f) {\n                for (int c = 0; c < texture_channel; ++c) {\n                    vtx_color[vtx_idx][c] = sum_color[c] / total_weight;\n                }\n              
  vtx_mask[vtx_idx] = 1.0f;\n            } else {\n                uncolored_vtx_count++;\n            }\n            \n        }\n\n        if(last_uncolored_vtx_count==uncolored_vtx_count){\n            smooth_count--;\n        }else{\n            smooth_count++;\n        }\n        last_uncolored_vtx_count = uncolored_vtx_count;\n    }\n\n    py::array_t<float> new_texture(texture_buf.size);\n    py::array_t<uint8_t> new_mask(mask_buf.size);\n\n    auto new_texture_buf = new_texture.request();\n    auto new_mask_buf = new_mask.request();\n\n    float* new_texture_ptr = static_cast<float*>(new_texture_buf.ptr);\n    uint8_t* new_mask_ptr = static_cast<uint8_t*>(new_mask_buf.ptr);\n    std::copy(texture_ptr, texture_ptr + texture_buf.size, new_texture_ptr);\n    std::copy(mask_ptr, mask_ptr + mask_buf.size, new_mask_ptr);\n\n    for (int face_idx = 0; face_idx < uv_idx_buf.shape[0]; ++face_idx) {\n        for (int k = 0; k < 3; ++k) {\n            int vtx_uv_idx = uv_idx_ptr[face_idx * 3 + k];\n            int vtx_idx = pos_idx_ptr[face_idx * 3 + k];\n\n            if (vtx_mask[vtx_idx] == 1.0f) {\n                int uv_v = round(vtx_uv_ptr[vtx_uv_idx * 2] * (texture_width - 1));\n                int uv_u = round((1.0 - vtx_uv_ptr[vtx_uv_idx * 2 + 1]) * (texture_height - 1));\n\n                for (int c = 0; c < texture_channel; ++c) {\n                    new_texture_ptr[(uv_u * texture_width + uv_v) * texture_channel + c] = vtx_color[vtx_idx][c];\n                }\n                new_mask_ptr[uv_u * texture_width + uv_v] = 255;\n            }\n        }\n    }\n\n    new_texture.resize({texture_height, texture_width, texture_channel});\n    new_mask.resize({texture_height, texture_width});\n  return std::make_pair(new_texture, new_mask);\n}\n\n\nstd::pair<py::array_t<float>, py::array_t<uint8_t>> meshVerticeInpaint(py::array_t<float> texture,\n          py::array_t<uint8_t> mask,\n          py::array_t<float> vtx_pos, py::array_t<float> vtx_uv,\n          
py::array_t<int> pos_idx, py::array_t<int> uv_idx, const std::string& method = \"smooth\") {\n    if (method == \"smooth\") {\n        return meshVerticeInpaint_smooth(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx);\n    } else {\n        throw std::invalid_argument(\"Invalid method. Use 'smooth'.\");\n    }\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"meshVerticeInpaint\", &meshVerticeInpaint, \"Mesh-aware texture inpainting\",\n          py::arg(\"texture\"), py::arg(\"mask\"),\n          py::arg(\"vtx_pos\"), py::arg(\"vtx_uv\"),\n          py::arg(\"pos_idx\"), py::arg(\"uv_idx\"),\n          py::arg(\"method\") = \"smooth\");\n}\n"
  },
  {
    "path": "python/sglang/multimodal_gen/docs/quantization.md",
    "content": "# Quantization\n\nThis document introduces the model quantization schemes supported in SGLang and how to use them to reduce memory usage and accelerate inference.\n\n## Nunchaku (SVDQuant)\n\n### Introduction\n\n**SVDQuant** is a Post-Training Quantization (PTQ) technique for diffusion models that quantizes model weights and activations to 4-bit precision (W4A4) while maintaining high visual quality. This method uses Singular Value Decomposition (SVD) to decompose the weight matrix into low-rank components and residuals, effectively absorbing outliers in activations, making 4-bit quantization possible.\n\n**Nunchaku** is a high-performance inference engine that implements SVDQuant, optimized for low-bit neural networks. It is not Quantization-Aware Training (QAT), but directly quantizes pre-trained models.\n\nPaper: [SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models](https://arxiv.org/abs/2411.05007) (ICLR 2025 Spotlight)\n\n### Key Features\n\nSVDQuant significantly reduces memory usage and accelerates inference while maintaining visual quality:\n\n- **Memory Optimization**: Reduces memory usage by **3.6×** compared to BF16 models.\n- **Inference Acceleration**:\n    - **3.0×** faster than the NF4 (W4A16) baseline on desktop/laptop RTX 4090 GPUs.\n    - **8.7×** speedup on laptop RTX 4090 by eliminating CPU offloading compared to 16-bit models.\n    - **3.1×** faster than BF16 and NF4 models on RTX 5090 GPUs with NVFP4.\n\n### Supported Precisions\n\nNunchaku supports two quantization precisions:\n\n- **INT4**: Standard INT4 quantization, supported on NVIDIA GPUs with Compute Capability 7.0+ (RTX 20 series and above).\n- **NVFP4**: FP4 quantization, providing better image quality on newer cards like the RTX 5090.\n\n### Usage\n\n#### 1. 
Install Nunchaku\n\n```bash\npip install nunchaku\n```\n\nFor more installation information, please refer to the [Nunchaku Official Documentation](https://nunchaku.tech/docs/nunchaku/installation/installation.html).\n\n#### 2. Download Quantized Models\n\nNunchaku provides pre-quantized model weights available on Hugging Face:\n\n- [nunchaku-ai/nunchaku-qwen-image](https://huggingface.co/nunchaku-ai/nunchaku-qwen-image)\n- [nunchaku-ai/nunchaku-flux](https://huggingface.co/nunchaku-ai/nunchaku-flux)\n\nTaking Qwen-Image as an example, several quantized models with different configurations are provided:\n\n| Filename | Precision | Rank | Usage |\n|----------|-----------|------|-------|\n| `svdq-int4_r32-qwen-image.safetensors` | INT4 | 32 | Standard Version |\n| `svdq-int4_r128-qwen-image.safetensors` | INT4 | 128 | High-Quality Version |\n| `svdq-fp4_r32-qwen-image.safetensors` | NVFP4 | 32 | RTX 5090 Standard Version |\n| `svdq-fp4_r128-qwen-image.safetensors` | NVFP4 | 128 | RTX 5090 High-Quality Version |\n| `svdq-int4_r32-qwen-image-lightningv1.0-4steps.safetensors` | INT4 | 32 | Lightning 4-Step Version |\n| `svdq-int4_r128-qwen-image-lightningv1.1-8steps.safetensors` | INT4 | 128 | Lightning 8-Step Version |\n\n> **Note**: Higher Rank usually means better image quality, but with slightly increased memory usage and computation.\n\n#### 3. Run Quantized Models\n\nSGLang features **smart auto-detection** for Nunchaku models. 
In most cases, you only need to provide the path to the quantized weights, and the precision and rank will be automatically inferred from the filename.\n\n**Simplified Command (Recommended):**\n\n```bash\nsglang generate \\\n  --model-path Qwen/Qwen-Image \\\n  --prompt \"change the raccoon to a cute cat\" \\\n  --save-output \\\n  --transformer-weights-path /path/to/svdq-int4_r32-qwen-image.safetensors\n```\n\n**Manual Override (If needed):**\n\nIf your filename doesn't follow the standard naming convention, or you want to force specific settings:\n\n- `--enable-svdquant`: Manually enable SVDQuant.\n- `--quantization-precision`: Set to `int4` or `nvfp4`.\n- `--quantization-rank`: Set the SVD rank (e.g., 32, 128).\n- `--quantization-act-unsigned` (Optional): Use unsigned activation quantization.\n\nExample with manual overrides:\n\n```bash\nsglang generate \\\n  --model-path Qwen/Qwen-Image \\\n  --prompt \"a beautiful sunset\" \\\n  --enable-svdquant \\\n  --transformer-weights-path /path/to/custom_model.safetensors \\\n  --quantization-precision int4 \\\n  --quantization-rank 128\n```\n\n#### 4. Configuration Recommendations\n\nChoose the appropriate configuration based on your hardware and requirements:\n\n| Scenario | Recommended Config | Description |\n|----------|-------------------|-------------|\n| Standard Use (20/30/40 Series GPU) | INT4 + Rank 32 | Balanced performance and quality |\n| Quality Focus (Sufficient VRAM) | INT4 + Rank 128 | Better image quality |\n| RTX 5090 Standard Use | NVFP4 + Rank 32 | Utilizes FP4 hardware acceleration |\n| RTX 5090 Quality Focus | NVFP4 + Rank 128 | Best image quality |\n| Fast Prototyping/Preview | Lightning 4-Step Version | Extremely fast generation, slightly reduced quality |\n\n### Notes\n\n1.  
Model Path Correspondence: `--model-path` should point to the original non-quantized model (used to load the config, tokenizer, etc.), while `--transformer-weights-path` points to the quantized weight file, folder, or Hugging Face repo ID.\n\n2.  Auto-Detection Requirements: For auto-detection to work, the filename must contain the pattern `svdq-{precision}_r{rank}` (e.g., `svdq-int4_r32`).\n\n3.  GPU Compatibility:\n    -   INT4: Supports NVIDIA GPUs with Compute Capability 7.0+ (RTX 20 series and above).\n    -   NVFP4: Intended mainly for newer GPUs with native FP4 support, such as the RTX 50 series.\n\n4.  Lightning Models: When using a Lightning version, set `--num-inference-steps` to match it (4 or 8 steps, per the filename).\n\n### Custom Model Quantization\n\nTo quantize your own models, use the [DeepCompressor](https://github.com/mit-han-lab/deepcompressor) tool. For detailed instructions, see the Nunchaku official documentation.\n\n## Quantization\n\n### Usage\n\n#### Option 1: Pre-quantized folder (has `config.json`)\n\nFor quantized checkpoints whose transformer `config.json` already contains a `quantization_config` field (e.g., models converted with `convert_hf_to_fp8.py`), use the component override:\n\n```bash\nsglang generate \\\n  --model-path /path/to/FLUX.1-dev \\\n  --transformer-path /path/to/FLUX.1-dev/transformer-FP8 \\\n  --prompt \"A Logo With Bold Large Text: SGL Diffusion\" \\\n  --save-output\n```\n\nIf you need to convert a model to FP8 format yourself, use the provided conversion script:\n\n```bash\n# convert transformer to FP8 with block quantization\npython -m sglang.multimodal_gen.tools.convert_hf_to_fp8 \\\n  --model-dir /path/to/FLUX.1-dev/transformer \\\n  --save-dir /path/to/FLUX.1-dev/transformer-FP8 \\\n  --strategy block \\\n  --block-size 128 128\n```\n\n#### Option 2: Pre-quantized single-file checkpoint (no `config.json`)\n\nSome providers 
(e.g., [black-forest-labs/FLUX.2-klein-9b-fp8](https://huggingface.co/black-forest-labs/FLUX.2-klein-9b-fp8)) distribute a single `.safetensors` file without a companion `config.json`. Point `--transformer-weights-path` at this file (or its Hugging Face repo ID) and keep `--model-path` pointing at the base model:\n\n```bash\nsglang generate \\\n  --model-path black-forest-labs/FLUX.2-klein-9B \\\n  --transformer-weights-path black-forest-labs/FLUX.2-klein-9b-fp8 \\\n  --prompt \"A Logo With Bold Large Text: SGL Diffusion\" \\\n  --save-output\n```\n\nSGLang-Diffusion automatically reads the `quantization_config` metadata embedded in the safetensors file header, if present. For the quant config to be auto-detected, the file's metadata must contain a JSON-encoded `quantization_config` key with at least a `quant_method` field (e.g., `\"fp8\"`).\n\nNote: this feature is a work in progress.\n"
  },
  {
    "path": "python/sglang/multimodal_gen/envs.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/envs.py\n\nimport logging\nimport os\nfrom typing import TYPE_CHECKING, Any, Callable\n\nfrom sglang.multimodal_gen.runtime.utils.common import get_bool_env_var\n\nlogger = logging.getLogger(__name__)\n\nif TYPE_CHECKING:\n    SGLANG_DIFFUSION_RINGBUFFER_WARNING_INTERVAL: int = 60\n    SGLANG_DIFFUSION_NCCL_SO_PATH: str | None = None\n    LD_LIBRARY_PATH: str | None = None\n    LOCAL_RANK: int = 0\n    CUDA_VISIBLE_DEVICES: str | None = None\n    SGLANG_DIFFUSION_CACHE_ROOT: str = os.path.expanduser(\"~/.cache/sgl_diffusion\")\n    SGLANG_DIFFUSION_CONFIG_ROOT: str = os.path.expanduser(\"~/.config/sgl_diffusion\")\n    SGLANG_DIFFUSION_CONFIGURE_LOGGING: int = 1\n    SGLANG_DIFFUSION_LOGGING_LEVEL: str = \"INFO\"\n    SGLANG_DIFFUSION_LOGGING_PREFIX: str = \"\"\n    SGLANG_DIFFUSION_LOGGING_CONFIG_PATH: str | None = None\n    SGLANG_DIFFUSION_TRACE_FUNCTION: int = 0\n    SGLANG_DIFFUSION_WORKER_MULTIPROC_METHOD: str = \"fork\"\n    SGLANG_DIFFUSION_TARGET_DEVICE: str = \"cuda\"\n    MAX_JOBS: str | None = None\n    NVCC_THREADS: str | None = None\n    CMAKE_BUILD_TYPE: str | None = None\n    VERBOSE: bool = False\n    SGLANG_DIFFUSION_SERVER_DEV_MODE: bool = False\n    SGLANG_DIFFUSION_STAGE_LOGGING: bool = False\n    # cache-dit env vars (primary transformer)\n    SGLANG_CACHE_DIT_ENABLED: bool = False\n    SGLANG_CACHE_DIT_FN: int = 1\n    SGLANG_CACHE_DIT_BN: int = 0\n    SGLANG_CACHE_DIT_WARMUP: int = 4\n    SGLANG_CACHE_DIT_RDT: float = 0.24\n    SGLANG_CACHE_DIT_MC: int = 3\n    SGLANG_CACHE_DIT_TAYLORSEER: bool = False\n    SGLANG_CACHE_DIT_TS_ORDER: int = 1\n    SGLANG_CACHE_DIT_SCM_PRESET: str = \"none\"\n    SGLANG_CACHE_DIT_SCM_COMPUTE_BINS: str | None = None\n    SGLANG_CACHE_DIT_SCM_CACHE_BINS: str | None = None\n    SGLANG_CACHE_DIT_SCM_POLICY: str = 
\"dynamic\"\n    # cache-dit env vars (secondary transformer, e.g., Wan2.2 low-noise expert)\n    SGLANG_CACHE_DIT_SECONDARY_FN: int = 1\n    SGLANG_CACHE_DIT_SECONDARY_BN: int = 0\n    SGLANG_CACHE_DIT_SECONDARY_WARMUP: int = 4\n    SGLANG_CACHE_DIT_SECONDARY_RDT: float = 0.24\n    SGLANG_CACHE_DIT_SECONDARY_MC: int = 3\n    SGLANG_CACHE_DIT_SECONDARY_TAYLORSEER: bool = False\n    SGLANG_CACHE_DIT_SECONDARY_TS_ORDER: int = 1\n    # model loading\n    SGLANG_USE_RUNAI_MODEL_STREAMER: bool = True\n    SGLANG_DIFFUSION_VAE_CHANNELS_LAST_3D: bool = False\n    SGLANG_USE_ROCM_VAE: bool = False\n\n\ndef get_default_cache_root() -> str:\n    return os.getenv(\n        \"XDG_CACHE_HOME\",\n        os.path.join(os.path.expanduser(\"~\"), \".cache\"),\n    )\n\n\ndef get_default_config_root() -> str:\n    return os.getenv(\n        \"XDG_CONFIG_HOME\",\n        os.path.join(os.path.expanduser(\"~\"), \".config\"),\n    )\n\n\ndef maybe_convert_int(value: str | None) -> int | None:\n    return int(value) if value is not None else None\n\n\n# helpers for environment variable definitions\ndef _lazy_str(key: str, default: str | None = None) -> Callable[[], str | None]:\n    return lambda: os.getenv(key, default)\n\n\ndef _lazy_int(key: str, default: str | int | None = None) -> Callable[[], int | None]:\n    def _getter():\n        val = os.getenv(key)\n        if val is None:\n            return int(default) if default is not None else None\n        return int(val)\n\n    return _getter\n\n\ndef _lazy_float(key: str, default: str | float) -> Callable[[], float]:\n    return lambda: float(os.getenv(key, str(default)))\n\n\ndef _lazy_bool(key: str, default: str = \"false\") -> Callable[[], bool]:\n    return lambda: get_bool_env_var(key, default)\n\n\ndef _lazy_bool_any(keys: list[str], default: str = \"false\") -> Callable[[], bool]:\n    def _getter():\n        for key in keys:\n            if get_bool_env_var(key, \"false\"):\n                return True\n        return (\n    
        get_bool_env_var(\"\", default)\n            if not keys\n            else get_bool_env_var(keys[0], default)\n        )\n\n    return _getter\n\n\ndef _lazy_path(\n    key: str, default_func: Callable[[], str] | None = None\n) -> Callable[[], str | None]:\n    def _getter():\n        val = os.getenv(key)\n        if val is None:\n            if default_func is None:\n                return None\n            val = default_func()\n        return os.path.expanduser(val)\n\n    return _getter\n\n\n# The begin-* and end* here are used by the documentation generator\n# to extract the used env vars.\n\n# begin-env-vars-definition\n\nenvironment_variables: dict[str, Callable[[], Any]] = {\n    # ================== Installation Time Env Vars ==================\n    # Target device of sglang-diffusion, supporting [cuda (by default),\n    # rocm, neuron, cpu, openvino]\n    \"SGLANG_DIFFUSION_TARGET_DEVICE\": _lazy_str(\n        \"SGLANG_DIFFUSION_TARGET_DEVICE\", \"cuda\"\n    ),\n    # Maximum number of compilation jobs to run in parallel.\n    # By default this is the number of CPUs\n    \"MAX_JOBS\": _lazy_str(\"MAX_JOBS\"),\n    # Number of threads to use for nvcc\n    # By default this is 1.\n    # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU.\n    \"NVCC_THREADS\": _lazy_str(\"NVCC_THREADS\"),\n    # If set, sgl_diffusion will use precompiled binaries (*.so)\n    \"SGLANG_DIFFUSION_USE_PRECOMPILED\": _lazy_bool_any(\n        [\n            \"SGLANG_DIFFUSION_USE_PRECOMPILED\",\n            \"SGLANG_DIFFUSION_PRECOMPILED_WHEEL_LOCATION\",\n        ]\n    ),\n    # CMake build type\n    # If not set, defaults to \"Debug\" or \"RelWithDebInfo\"\n    # Available options: \"Debug\", \"Release\", \"RelWithDebInfo\"\n    \"CMAKE_BUILD_TYPE\": _lazy_str(\"CMAKE_BUILD_TYPE\"),\n    # If set, sgl_diffusion will print verbose logs during installation\n    \"VERBOSE\": _lazy_bool(\"VERBOSE\"),\n    # Root directory for SGL-diffusion configuration 
files\n    # Defaults to `~/.config/sgl_diffusion` unless `XDG_CONFIG_HOME` is set\n    # Note that this affects not only how sgl_diffusion finds its configuration files\n    # at runtime, but also where sgl_diffusion installs its configuration\n    # files during **installation**.\n    \"SGLANG_DIFFUSION_CONFIG_ROOT\": _lazy_path(\n        \"SGLANG_DIFFUSION_CONFIG_ROOT\",\n        lambda: os.path.join(get_default_config_root(), \"sgl_diffusion\"),\n    ),\n    # ================== Runtime Env Vars ==================\n    # Root directory for SGL-diffusion cache files\n    # Defaults to `~/.cache/sgl_diffusion` unless `XDG_CACHE_HOME` is set\n    \"SGLANG_DIFFUSION_CACHE_ROOT\": _lazy_path(\n        \"SGLANG_DIFFUSION_CACHE_ROOT\",\n        lambda: os.path.join(get_default_cache_root(), \"sgl_diffusion\"),\n    ),\n    # Interval in seconds to log a warning message when the ring buffer is full\n    \"SGLANG_DIFFUSION_RINGBUFFER_WARNING_INTERVAL\": _lazy_int(\n        \"SGLANG_DIFFUSION_RINGBUFFER_WARNING_INTERVAL\", 60\n    ),\n    # Path to the NCCL library file. 
It is needed because nccl>=2.19 brought\n    # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234\n    \"SGLANG_DIFFUSION_NCCL_SO_PATH\": _lazy_str(\"SGLANG_DIFFUSION_NCCL_SO_PATH\"),\n    # when `SGLANG_DIFFUSION_NCCL_SO_PATH` is not set, sgl_diffusion will try to find the nccl\n    # library file in the locations specified by `LD_LIBRARY_PATH`\n    \"LD_LIBRARY_PATH\": _lazy_str(\"LD_LIBRARY_PATH\"),\n    # Internal flag to enable Dynamo fullgraph capture\n    \"SGLANG_DIFFUSION_TEST_DYNAMO_FULLGRAPH_CAPTURE\": _lazy_bool(\n        \"SGLANG_DIFFUSION_TEST_DYNAMO_FULLGRAPH_CAPTURE\", \"1\"\n    ),\n    # local rank of the process in the distributed setting, used to determine\n    # the GPU device id\n    \"LOCAL_RANK\": _lazy_int(\"LOCAL_RANK\", 0),\n    # used to control the visible devices in the distributed setting\n    \"CUDA_VISIBLE_DEVICES\": _lazy_str(\"CUDA_VISIBLE_DEVICES\"),\n    # timeout for each iteration in the engine\n    \"SGLANG_DIFFUSION_ENGINE_ITERATION_TIMEOUT_S\": _lazy_int(\n        \"SGLANG_DIFFUSION_ENGINE_ITERATION_TIMEOUT_S\", 60\n    ),\n    # Logging configuration\n    # If set to 0, sgl_diffusion will not configure logging\n    # If set to 1, sgl_diffusion will configure logging using the default configuration\n    #    or the configuration file specified by SGLANG_DIFFUSION_LOGGING_CONFIG_PATH\n    \"SGLANG_DIFFUSION_CONFIGURE_LOGGING\": _lazy_int(\n        \"SGLANG_DIFFUSION_CONFIGURE_LOGGING\", 1\n    ),\n    \"SGLANG_DIFFUSION_LOGGING_CONFIG_PATH\": _lazy_str(\n        \"SGLANG_DIFFUSION_LOGGING_CONFIG_PATH\"\n    ),\n    # this is used for configuring the default logging level\n    \"SGLANG_DIFFUSION_LOGGING_LEVEL\": _lazy_str(\n        \"SGLANG_DIFFUSION_LOGGING_LEVEL\", \"INFO\"\n    ),\n    # if set, SGLANG_DIFFUSION_LOGGING_PREFIX will be prepended to all log messages\n    \"SGLANG_DIFFUSION_LOGGING_PREFIX\": _lazy_str(\"SGLANG_DIFFUSION_LOGGING_PREFIX\", \"\"),\n    # Trace function calls\n    # If set to 
1, sgl_diffusion will trace function calls\n    # Useful for debugging\n    \"SGLANG_DIFFUSION_TRACE_FUNCTION\": _lazy_int(\"SGLANG_DIFFUSION_TRACE_FUNCTION\", 0),\n    # Path to the attention configuration file. Only used for sliding tile\n    # attention for now.\n    \"SGLANG_DIFFUSION_ATTENTION_CONFIG\": _lazy_path(\n        \"SGLANG_DIFFUSION_ATTENTION_CONFIG\"\n    ),\n    # Optional override to force a specific attention backend (e.g. \"aiter\")\n    \"SGLANG_DIFFUSION_ATTENTION_BACKEND\": _lazy_str(\n        \"SGLANG_DIFFUSION_ATTENTION_BACKEND\"\n    ),\n    # Use dedicated multiprocess context for workers.\n    # Both spawn and fork work\n    \"SGLANG_DIFFUSION_WORKER_MULTIPROC_METHOD\": _lazy_str(\n        \"SGLANG_DIFFUSION_WORKER_MULTIPROC_METHOD\", \"fork\"\n    ),\n    # Enables torch profiler if set. Path to the directory where torch profiler\n    # traces are saved. Note that it must be an absolute path.\n    \"SGLANG_DIFFUSION_TORCH_PROFILER_DIR\": _lazy_path(\n        \"SGLANG_DIFFUSION_TORCH_PROFILER_DIR\"\n    ),\n    # If set, sgl_diffusion will run in development mode, which will enable\n    # some additional endpoints for developing and debugging,\n    # e.g. 
`/reset_prefix_cache`\n    \"SGLANG_DIFFUSION_SERVER_DEV_MODE\": _lazy_bool(\"SGLANG_DIFFUSION_SERVER_DEV_MODE\"),\n    # If set, sgl_diffusion will enable stage logging, which will print the time\n    # taken for each stage\n    \"SGLANG_DIFFUSION_STAGE_LOGGING\": _lazy_bool(\"SGLANG_DIFFUSION_STAGE_LOGGING\"),\n    \"SGLANG_DIFFUSION_VAE_CHANNELS_LAST_3D\": _lazy_bool(\n        \"SGLANG_DIFFUSION_VAE_CHANNELS_LAST_3D\", \"false\"\n    ),\n    # ================== cache-dit Env Vars ==================\n    # Enable cache-dit acceleration for DiT inference\n    \"SGLANG_CACHE_DIT_ENABLED\": _lazy_bool(\"SGLANG_CACHE_DIT_ENABLED\"),\n    # Number of first blocks to always compute (DBCache F parameter)\n    \"SGLANG_CACHE_DIT_FN\": _lazy_int(\"SGLANG_CACHE_DIT_FN\", 1),\n    # Number of last blocks to always compute (DBCache B parameter)\n    \"SGLANG_CACHE_DIT_BN\": _lazy_int(\"SGLANG_CACHE_DIT_BN\", 0),\n    # Warmup steps before caching (DBCache W parameter)\n    \"SGLANG_CACHE_DIT_WARMUP\": _lazy_int(\"SGLANG_CACHE_DIT_WARMUP\", 4),\n    # Residual difference threshold (DBCache R parameter)\n    \"SGLANG_CACHE_DIT_RDT\": _lazy_float(\"SGLANG_CACHE_DIT_RDT\", 0.24),\n    # Maximum continuous cached steps (DBCache MC parameter)\n    \"SGLANG_CACHE_DIT_MC\": _lazy_int(\"SGLANG_CACHE_DIT_MC\", 3),\n    # Enable TaylorSeer calibrator\n    \"SGLANG_CACHE_DIT_TAYLORSEER\": _lazy_bool(\"SGLANG_CACHE_DIT_TAYLORSEER\", \"false\"),\n    # TaylorSeer order (1 or 2)\n    \"SGLANG_CACHE_DIT_TS_ORDER\": _lazy_int(\"SGLANG_CACHE_DIT_TS_ORDER\", 1),\n    # SCM preset: none, slow, medium, fast, ultra\n    \"SGLANG_CACHE_DIT_SCM_PRESET\": _lazy_str(\"SGLANG_CACHE_DIT_SCM_PRESET\", \"none\"),\n    # SCM custom compute bins (e.g., \"8,3,3,2,2\")\n    \"SGLANG_CACHE_DIT_SCM_COMPUTE_BINS\": _lazy_str(\"SGLANG_CACHE_DIT_SCM_COMPUTE_BINS\"),\n    # SCM custom cache bins (e.g., \"1,2,2,2,3\")\n    \"SGLANG_CACHE_DIT_SCM_CACHE_BINS\": _lazy_str(\"SGLANG_CACHE_DIT_SCM_CACHE_BINS\"),\n    # 
SCM policy: dynamic or static\n    \"SGLANG_CACHE_DIT_SCM_POLICY\": _lazy_str(\"SGLANG_CACHE_DIT_SCM_POLICY\", \"dynamic\"),\n    # model loading\n    \"SGLANG_USE_RUNAI_MODEL_STREAMER\": _lazy_bool(\n        \"SGLANG_USE_RUNAI_MODEL_STREAMER\", \"true\"\n    ),\n    # ROCm: use AITer GroupNorm in VAE for improved performance\n    \"SGLANG_USE_ROCM_VAE\": _lazy_bool(\"SGLANG_USE_ROCM_VAE\"),\n}\n\n# Add cache-dit Secondary Transformer Env Vars via programmatic generation to reduce duplication\n_CACHE_DIT_SECONDARY_CONFIGS = [\n    (\"FN\", int, \"1\"),\n    (\"BN\", int, \"0\"),\n    (\"WARMUP\", int, \"4\"),\n    (\"RDT\", float, \"0.24\"),\n    (\"MC\", int, \"3\"),\n    (\"TS_ORDER\", int, \"1\"),\n]\n\n\ndef _create_secondary_getter(suffix, type_func, default_val):\n    primary_key = f\"SGLANG_CACHE_DIT_{suffix}\"\n    secondary_key = f\"SGLANG_CACHE_DIT_SECONDARY_{suffix}\"\n\n    def _getter():\n        val = os.getenv(secondary_key)\n        if val is not None:\n            return type_func(val)\n        return type_func(os.getenv(primary_key, str(default_val)))\n\n    return secondary_key, _getter\n\n\nfor suffix, type_func, default_val in _CACHE_DIT_SECONDARY_CONFIGS:\n    key, getter = _create_secondary_getter(suffix, type_func, default_val)\n    environment_variables[key] = getter\n\n\n# Special handling for boolean secondary var (TaylorSeer)\ndef _secondary_taylorseer_getter():\n    return get_bool_env_var(\n        \"SGLANG_CACHE_DIT_SECONDARY_TAYLORSEER\",\n        default=os.getenv(\"SGLANG_CACHE_DIT_TAYLORSEER\", \"false\"),\n    )\n\n\nenvironment_variables[\"SGLANG_CACHE_DIT_SECONDARY_TAYLORSEER\"] = (\n    _secondary_taylorseer_getter\n)\n\n\n# end-env-vars-definition\ndef __getattr__(name: str):\n    # lazy evaluation of environment variables\n    if name in environment_variables:\n        return environment_variables[name]()\n    raise AttributeError(f\"module {__name__!r} has no attribute {name!r}\")\n\n\ndef __dir__():\n    return 
list(environment_variables.keys())\n"
  },
  {
    "path": "python/sglang/multimodal_gen/registry.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nCentral registry for multimodal models.\n\nThis module provides a centralized registry for multimodal models, including pipelines\nand sampling parameters. It allows for easy registration and retrieval of model\ninformation based on model paths or other identifiers.\n\"\"\"\n\nimport dataclasses\nimport importlib\nimport os\nimport pkgutil\nfrom functools import lru_cache\nfrom typing import (\n    TYPE_CHECKING,\n    Any,\n    Callable,\n    Dict,\n    List,\n    Optional,\n    Tuple,\n    Type,\n    Union,\n)\n\nif TYPE_CHECKING:\n    from sglang.multimodal_gen.runtime.server_args import Backend\n\nfrom sglang.multimodal_gen.configs.pipeline_configs import (\n    FastHunyuanConfig,\n    FluxPipelineConfig,\n    HeliosDistilledConfig,\n    HeliosMidConfig,\n    HeliosT2VConfig,\n    HunyuanConfig,\n    WanI2V480PConfig,\n    WanI2V720PConfig,\n    WanT2V480PConfig,\n    WanT2V720PConfig,\n    ZImagePipelineConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import PipelineConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.flux import (\n    Flux2KleinPipelineConfig,\n    Flux2PipelineConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.glm_image import (\n    GlmImagePipelineConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.hunyuan3d import (\n    Hunyuan3D2PipelineConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.ltx_2 import LTX2PipelineConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.mova import (\n    MOVA360PConfig,\n    MOVA720PConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.qwen_image import (\n    QwenImageEditPipelineConfig,\n    QwenImageEditPlus_2511_PipelineConfig,\n    QwenImageEditPlusPipelineConfig,\n    QwenImageLayeredPipelineConfig,\n    QwenImagePipelineConfig,\n)\nfrom sglang.multimodal_gen.configs.pipeline_configs.sana import SanaPipelineConfig\nfrom 
sglang.multimodal_gen.configs.pipeline_configs.wan import (\n    FastWan2_1_T2V_480P_Config,\n    FastWan2_2_TI2V_5B_Config,\n    TurboWanI2V720Config,\n    TurboWanT2V480PConfig,\n    Wan2_2_I2V_A14B_Config,\n    Wan2_2_T2V_A14B_Config,\n    Wan2_2_TI2V_5B_Config,\n)\nfrom sglang.multimodal_gen.configs.sample.flux import (\n    Flux2KleinSamplingParams,\n    FluxSamplingParams,\n)\nfrom sglang.multimodal_gen.configs.sample.glmimage import GlmImageSamplingParams\nfrom sglang.multimodal_gen.configs.sample.helios import (\n    HeliosDistilledSamplingParams,\n    HeliosMidSamplingParams,\n    HeliosT2VSamplingParams,\n)\nfrom sglang.multimodal_gen.configs.sample.hunyuan import (\n    FastHunyuanSamplingParam,\n    HunyuanSamplingParams,\n)\nfrom sglang.multimodal_gen.configs.sample.hunyuan3d import Hunyuan3DSamplingParams\nfrom sglang.multimodal_gen.configs.sample.ltx_2 import LTX2SamplingParams\nfrom sglang.multimodal_gen.configs.sample.mova import (\n    MOVA_360P_SamplingParams,\n    MOVA_720P_SamplingParams,\n)\nfrom sglang.multimodal_gen.configs.sample.qwenimage import (\n    QwenImage2512SamplingParams,\n    QwenImageEditPlusSamplingParams,\n    QwenImageLayeredSamplingParams,\n    QwenImageSamplingParams,\n)\nfrom sglang.multimodal_gen.configs.sample.sana import SanaSamplingParams\nfrom sglang.multimodal_gen.configs.sample.wan import (\n    FastWanT2V480PConfig,\n    Turbo_Wan2_2_I2V_A14B_SamplingParam,\n    Wan2_1_Fun_1_3B_InP_SamplingParams,\n    Wan2_2_I2V_A14B_SamplingParam,\n    Wan2_2_T2V_A14B_SamplingParam,\n    Wan2_2_TI2V_5B_SamplingParam,\n    WanI2V_14B_480P_SamplingParam,\n    WanI2V_14B_720P_SamplingParam,\n    WanT2V_1_3B_SamplingParams,\n    WanT2V_14B_SamplingParams,\n)\nfrom sglang.multimodal_gen.configs.sample.zimage import (\n    ZImageSamplingParams,\n    ZImageTurboSamplingParams,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom 
sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (\n    maybe_download_model_index,\n    verify_model_config_and_directory,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n# --- Part 1: Pipeline Discovery ---\n\n_PIPELINE_REGISTRY: Dict[str, Type[ComposedPipelineBase]] = {}\n\n# Registry for pipeline configuration classes (for safetensors files without model_index.json)\n# Maps pipeline_class_name -> (PipelineConfig class, SamplingParams class)\n_PIPELINE_CONFIG_REGISTRY: Dict[str, Tuple[Type[PipelineConfig], Type[Any]]] = {}\n\n\ndef _discover_and_register_pipelines():\n    \"\"\"\n    Automatically discover and register all ComposedPipelineBase subclasses.\n    This function scans the 'sglang.multimodal_gen.runtime.pipelines' package,\n    finds modules with an 'EntryClass' attribute, and maps the class's 'pipeline_name'\n    to the class itself in a global registry.\n    \"\"\"\n    if _PIPELINE_REGISTRY:  # run only once\n        return\n\n    package_name = \"sglang.multimodal_gen.runtime.pipelines\"\n    package = importlib.import_module(package_name)\n\n    for _, module_name, ispkg in pkgutil.walk_packages(\n        package.__path__, package.__name__ + \".\"\n    ):\n        if not ispkg:\n            pipeline_module = importlib.import_module(module_name)\n            if hasattr(pipeline_module, \"EntryClass\"):\n                entry_cls = pipeline_module.EntryClass\n                entry_cls_list = (\n                    [entry_cls] if not isinstance(entry_cls, list) else entry_cls\n                )\n\n                for cls in entry_cls_list:\n                    if not issubclass(cls, ComposedPipelineBase):\n                        continue\n                    if cls.pipeline_name in _PIPELINE_REGISTRY:\n                        logger.warning(\n                            f\"Duplicate pipeline name '{cls.pipeline_name}' found. 
Overwriting.\"\n                        )\n                    _PIPELINE_REGISTRY[cls.pipeline_name] = cls\n\n                    # Special handling for ComfyUI Pipelines:\n                    # Auto-register config classes if the pipeline class defines them,\n                    # since ComfyUI loads models from a single weight file (no model_index.json)\n                    if hasattr(cls, \"pipeline_config_cls\") and hasattr(\n                        cls, \"sampling_params_cls\"\n                    ):\n                        _PIPELINE_CONFIG_REGISTRY[cls.pipeline_name] = (\n                            cls.pipeline_config_cls,\n                            cls.sampling_params_cls,\n                        )\n                        logger.debug(\n                            f\"Auto-registered config classes for pipeline '{cls.pipeline_name}': \"\n                            f\"PipelineConfig={cls.pipeline_config_cls.__name__}, \"\n                            f\"SamplingParams={cls.sampling_params_cls.__name__}\"\n                        )\n    logger.debug(\n        f\"Registering pipelines complete, {len(_PIPELINE_REGISTRY)} pipelines registered\"\n    )\n\n\ndef get_pipeline_config_classes(\n    pipeline_class_name: str,\n) -> Tuple[Type[PipelineConfig], Type[Any]] | None:\n    \"\"\"\n    Get the configuration classes for a pipeline.\n    \"\"\"\n    # Ensure pipelines are discovered first\n    _discover_and_register_pipelines()\n    return _PIPELINE_CONFIG_REGISTRY.get(pipeline_class_name)\n\n\n# --- Part 2: Config Registration ---\n@dataclasses.dataclass\nclass ConfigInfo:\n    \"\"\"Encapsulates all configuration information required to register a\n    diffusers model within this framework.\"\"\"\n\n    sampling_param_cls: Any\n    pipeline_config_cls: Type[PipelineConfig]\n\n\n# The central registry mapping a model name to its configuration information\n_CONFIG_REGISTRY: Dict[str, ConfigInfo] = {}\n\n# Mappings from Hugging Face 
model paths to our internal model names\n_MODEL_HF_PATH_TO_NAME: Dict[str, str] = {}\n\n# Detectors to identify model families from paths or class names\n_MODEL_NAME_DETECTORS: List[Tuple[str, Callable[[str], bool]]] = []\n\n\ndef register_configs(\n    sampling_param_cls: Any,\n    pipeline_config_cls: Type[PipelineConfig],\n    hf_model_paths: Optional[List[str]] = None,\n    model_detectors: Optional[List[Callable[[str], bool]]] = None,\n):\n    \"\"\"\n    Registers configuration classes for a new model family.\n    \"\"\"\n    model_id = str(len(_CONFIG_REGISTRY))\n\n    _CONFIG_REGISTRY[model_id] = ConfigInfo(\n        sampling_param_cls=sampling_param_cls,\n        pipeline_config_cls=pipeline_config_cls,\n    )\n    if hf_model_paths:\n        for path in hf_model_paths:\n            if path in _MODEL_HF_PATH_TO_NAME:\n                logger.warning(\n                    f\"Model path '{path}' is already mapped to '{_MODEL_HF_PATH_TO_NAME[path]}' and will be overwritten by '{model_id}'.\"\n                )\n            _MODEL_HF_PATH_TO_NAME[path] = model_id\n\n    if model_detectors:\n        for detector in model_detectors:\n            _MODEL_NAME_DETECTORS.append((model_id, detector))\n\n\ndef get_model_short_name(model_id: str) -> str:\n    if \"/\" in model_id:\n        return model_id.rstrip(\"/\").split(\"/\")[-1]\n    else:\n        return model_id\n\n\n@lru_cache(maxsize=1)\ndef _get_config_info(\n    model_path: str, model_id: Optional[str] = None\n) -> Optional[ConfigInfo]:\n    \"\"\"\n    Gets the ConfigInfo for a given model path using mappings and detectors.\n    \"\"\"\n    all_model_hf_paths = sorted(_MODEL_HF_PATH_TO_NAME.keys(), key=len, reverse=True)\n\n    # 0. 
Explicit model_id override: match by short name\n    if model_id is not None:\n        model_id_lower = model_id.lower()\n        for registered_hf_id in all_model_hf_paths:\n            if get_model_short_name(registered_hf_id).lower() == model_id_lower:\n                logger.debug(\n                    f\"Resolved model via explicit --model-id '{model_id}' → '{registered_hf_id}'.\"\n                )\n                return _CONFIG_REGISTRY.get(_MODEL_HF_PATH_TO_NAME[registered_hf_id])\n        logger.warning(\n            f\"--model-id '{model_id}' did not match any registered model; \"\n            \"falling back to automatic detection.\"\n        )\n\n    # 1. Exact match\n    if model_path in _MODEL_HF_PATH_TO_NAME:\n        model_id = _MODEL_HF_PATH_TO_NAME[model_path]\n        logger.debug(f\"Resolved model path '{model_path}' from exact path match.\")\n        return _CONFIG_REGISTRY.get(model_id)\n\n    # 2. Partial match: find the best (longest) match against all registered model hf paths.\n    model_short_name = get_model_short_name(model_path.lower())\n    for registered_model_hf_id in all_model_hf_paths:\n        registered_model_name = get_model_short_name(registered_model_hf_id.lower())\n\n        if registered_model_name in model_short_name:\n            logger.debug(\n                f\"Resolved model name '{registered_model_hf_id}' from partial path match.\"\n            )\n            model_id = _MODEL_HF_PATH_TO_NAME[registered_model_hf_id]\n            return _CONFIG_REGISTRY.get(model_id)\n\n    # 3. 
Use detectors\n    if os.path.exists(model_path):\n        config = verify_model_config_and_directory(model_path)\n    else:\n        config = maybe_download_model_index(model_path)\n\n    pipeline_name = config.get(\"_class_name\", \"\").lower()\n\n    matched_model_names = []\n    for model_id, detector in _MODEL_NAME_DETECTORS:\n        if detector(model_path.lower()) or detector(pipeline_name):\n            logger.debug(\n                f\"Matched model name '{model_id}' using a registered detector.\"\n            )\n            matched_model_names += [model_id]\n\n    if len(matched_model_names) >= 1:\n        if len(matched_model_names) > 1:\n            logger.warning(\n                f\"Multiple models matched {matched_model_names}; using the first match\"\n            )\n        model_id = matched_model_names[0]\n        return _CONFIG_REGISTRY.get(model_id)\n    else:\n        raise RuntimeError(f\"No model info found for model path: {model_path}\")\n\n\n# --- Part 3: Main Resolver ---\n\n\n@dataclasses.dataclass\nclass ModelInfo:\n    \"\"\"\n    Resolved pipeline class together with its sampling-parameter and\n    pipeline-config classes for a given model.\n    \"\"\"\n\n    pipeline_cls: Type[ComposedPipelineBase]\n    sampling_param_cls: Any\n    pipeline_config_cls: Type[PipelineConfig]\n\n\ndef _get_diffusers_model_info(\n    model_path: Optional[str] = None,\n    model_id: Optional[str] = None,\n) -> ModelInfo:\n    \"\"\"\n    Get model info for diffusers backend.\n\n    Returns a ModelInfo with DiffusersPipeline and generic configs.\n    When model_path is provided and has a registered native config,\n    inherits task_type from it so that validation (e.g. 
accepts_image_input)\n    works correctly even under the diffusers backend.\n    \"\"\"\n    from sglang.multimodal_gen.configs.pipeline_configs.diffusers_generic import (\n        DiffusersGenericPipelineConfig,\n    )\n    from sglang.multimodal_gen.configs.sample.diffusers_generic import (\n        DiffusersGenericSamplingParams,\n    )\n    from sglang.multimodal_gen.runtime.pipelines.diffusers_pipeline import (\n        DiffusersPipeline,\n    )\n\n    sampling_param_cls = DiffusersGenericSamplingParams\n    pipeline_config_cls = DiffusersGenericPipelineConfig\n\n    # If there is a registered native config for this model, inherit its task_type\n    if model_path is not None:\n        config_info = _get_config_info(model_path, model_id=model_id)\n        if config_info is not None:\n            sampling_param_cls = config_info.sampling_param_cls\n            native_task_type = config_info.pipeline_config_cls.task_type\n            if native_task_type != DiffusersGenericPipelineConfig.task_type:\n                pipeline_config_cls = dataclasses.make_dataclass(\n                    \"DiffusersGenericPipelineConfig\",\n                    [\n                        (\n                            \"task_type\",\n                            type(native_task_type),\n                            dataclasses.field(default=native_task_type),\n                        )\n                    ],\n                    bases=(DiffusersGenericPipelineConfig,),\n                )\n                logger.debug(\n                    \"Inherited task_type=%s from native config for diffusers backend\",\n                    native_task_type.name,\n                )\n\n    return ModelInfo(\n        pipeline_cls=DiffusersPipeline,\n        sampling_param_cls=sampling_param_cls,\n        pipeline_config_cls=pipeline_config_cls,\n    )\n\n\n@lru_cache(maxsize=1)\ndef get_model_info(\n    model_path: str,\n    backend: Optional[Union[str, \"Backend\"]] = None,\n    model_id: 
Optional[str] = None,\n) -> Optional[ModelInfo]:\n    \"\"\"\n    Resolves all necessary classes (pipeline, sampling, config) for a given model path.\n\n    This function serves as the main entry point for model resolution. It performs two main tasks:\n    1. Dynamically resolves the pipeline class by reading 'model_index.json' and matching\n       '_class_name' against an auto-discovered registry of pipeline implementations.\n    2. Resolves the associated configuration classes (for sampling and pipeline) using a\n       manually registered mapping based on the model path.\n\n    Args:\n        backend: Backend to use ('auto', 'sglang', 'diffusers'). If None, uses 'auto'.\n\n    \"\"\"\n    # import Backend enum here to avoid circular imports\n    from sglang.multimodal_gen.runtime.server_args import Backend\n\n    # Normalize backend\n    if backend is None:\n        backend = Backend.AUTO\n    elif isinstance(backend, str):\n        backend = Backend.from_string(backend)\n\n    # Handle explicit diffusers backend\n    if backend == Backend.DIFFUSERS:\n        logger.info(\n            \"Using diffusers backend for model '%s' (explicitly requested)\", model_path\n        )\n        return _get_diffusers_model_info(model_path=model_path, model_id=model_id)\n\n    # For AUTO or SGLANG backend, try native implementation first\n    # 1. Discover all available pipeline classes and cache them\n    _discover_and_register_pipelines()\n\n    # Detect quantized models and fallback to diffusers\n    is_quantized = any(q in model_path.lower() for q in [\"-4bit\", \"-awq\", \"-gptq\"])\n    if is_quantized and backend != Backend.DIFFUSERS:\n        logger.info(\n            \"Detected a quantized model format ('%s'). \"\n            \"The native sglang-diffusion engine currently only supports BF16/FP16. 
\"\n            \"Falling back to diffusers backend.\",\n            model_path,\n        )\n        return _get_diffusers_model_info(model_path=model_path, model_id=model_id)\n\n    # 2. Get pipeline class - check non-diffusers models first\n    pipeline_class_name = get_non_diffusers_pipeline_name(model_path)\n    if pipeline_class_name:\n        # Known non-diffusers model, skip model_index.json download\n        logger.debug(\n            f\"Using registered pipeline '{pipeline_class_name}' for non-diffusers model '{model_path}'\"\n        )\n    else:\n        # Try to get from model_index.json\n        try:\n            if os.path.exists(model_path):\n                config = verify_model_config_and_directory(model_path)\n            else:\n                config = maybe_download_model_index(model_path)\n        except Exception as e:\n            logger.error(f\"Could not read model config for '{model_path}': {e}\")\n            if backend == Backend.AUTO:\n                logger.info(\"Falling back to diffusers backend\")\n                return _get_diffusers_model_info(\n                    model_path=model_path, model_id=model_id\n                )\n            return None\n\n        pipeline_class_name = config.get(\"_class_name\")\n        if not pipeline_class_name:\n            logger.error(\n                f\"'_class_name' not found in model_index.json for '{model_path}'\"\n            )\n            if backend == Backend.AUTO:\n                logger.info(\"Falling back to diffusers backend\")\n                return _get_diffusers_model_info(\n                    model_path=model_path, model_id=model_id\n                )\n            return None\n\n    pipeline_cls = _PIPELINE_REGISTRY.get(pipeline_class_name)\n    if not pipeline_cls:\n        if backend == Backend.AUTO:\n            logger.warning(\n                f\"Pipeline class '{pipeline_class_name}' specified in '{model_path}' has no native sglang support. 
\"\n                f\"Falling back to diffusers backend.\"\n            )\n            return _get_diffusers_model_info(model_path=model_path, model_id=model_id)\n        else:\n            logger.error(\n                f\"Pipeline class '{pipeline_class_name}' specified in '{model_path}' is not a registered EntryClass in the framework. \"\n                f\"Available pipelines: {list(_PIPELINE_REGISTRY.keys())}. \"\n                f\"Consider using --backend diffusers to use vanilla diffusers pipeline.\"\n            )\n            return None\n\n    # 3. Get configuration classes (sampling, pipeline config)\n    config_info = _get_config_info(model_path, model_id=model_id)\n    if not config_info:\n        if backend == Backend.AUTO:\n            logger.warning(\n                f\"Could not resolve native configuration for model '{model_path}'. \"\n                f\"Falling back to diffusers backend.\"\n            )\n            return _get_diffusers_model_info(model_path=model_path, model_id=model_id)\n        else:\n            logger.error(\n                f\"Could not resolve configuration for model '{model_path}'. \"\n                \"It is not a registered model path or detected by any registered model family detectors. \"\n                f\"Known model paths: {list(_MODEL_HF_PATH_TO_NAME.keys())}. \"\n                f\"Consider using --backend diffusers to use vanilla diffusers pipeline.\"\n            )\n            return None\n\n    # 4. 
Combine and return the complete model info\n    logger.debug(\"Using native sglang backend for model '%s'\", model_path)\n    model_info = ModelInfo(\n        pipeline_cls=pipeline_cls,\n        sampling_param_cls=config_info.sampling_param_cls,\n        pipeline_config_cls=config_info.pipeline_config_cls,\n    )\n    logger.debug(f\"Found model info: {model_info}\")\n\n    return model_info\n\n\n# Registration of model configs\ndef _register_configs():\n    # LTX-2\n    register_configs(\n        sampling_param_cls=LTX2SamplingParams,\n        pipeline_config_cls=LTX2PipelineConfig,\n        model_detectors=[\n            lambda path: \"ltx\" in path.lower() and \"video\" in path.lower(),\n            lambda path: \"ltx-2\" in path.lower(),\n        ],\n    )\n\n    # Hunyuan\n    register_configs(\n        sampling_param_cls=HunyuanSamplingParams,\n        pipeline_config_cls=HunyuanConfig,\n        hf_model_paths=[\n            \"hunyuanvideo-community/HunyuanVideo\",\n        ],\n        model_detectors=[lambda hf_id: \"hunyuanvideo\" in hf_id.lower()],\n    )\n    register_configs(\n        sampling_param_cls=FastHunyuanSamplingParam,\n        pipeline_config_cls=FastHunyuanConfig,\n        hf_model_paths=[\n            \"FastVideo/FastHunyuan-diffusers\",\n        ],\n    )\n    # Wan\n    register_configs(\n        sampling_param_cls=WanT2V_1_3B_SamplingParams,\n        pipeline_config_cls=WanT2V480PConfig,\n        hf_model_paths=[\n            \"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\",\n        ],\n        model_detectors=[lambda hf_id: \"wanpipeline\" in hf_id.lower()],\n    )\n    register_configs(\n        sampling_param_cls=WanT2V_1_3B_SamplingParams,\n        pipeline_config_cls=TurboWanT2V480PConfig,\n        hf_model_paths=[\n            \"IPostYellow/TurboWan2.1-T2V-1.3B-Diffusers\",\n        ],\n    )\n    register_configs(\n        sampling_param_cls=WanT2V_14B_SamplingParams,\n        pipeline_config_cls=WanT2V720PConfig,\n        hf_model_paths=[\n  
          \"Wan-AI/Wan2.1-T2V-14B-Diffusers\",\n        ],\n    )\n    register_configs(\n        sampling_param_cls=WanT2V_14B_SamplingParams,\n        pipeline_config_cls=TurboWanT2V480PConfig,\n        hf_model_paths=[\n            \"IPostYellow/TurboWan2.1-T2V-14B-Diffusers\",\n            \"IPostYellow/TurboWan2.1-T2V-14B-720P-Diffusers\",\n        ],\n    )\n    register_configs(\n        sampling_param_cls=WanI2V_14B_480P_SamplingParam,\n        pipeline_config_cls=WanI2V480PConfig,\n        hf_model_paths=[\n            \"Wan-AI/Wan2.1-I2V-14B-480P-Diffusers\",\n        ],\n        model_detectors=[lambda hf_id: \"wanimagetovideo\" in hf_id.lower()],\n    )\n    register_configs(\n        sampling_param_cls=WanI2V_14B_720P_SamplingParam,\n        pipeline_config_cls=WanI2V720PConfig,\n        hf_model_paths=[\n            \"Wan-AI/Wan2.1-I2V-14B-720P-Diffusers\",\n        ],\n    )\n    register_configs(\n        sampling_param_cls=Turbo_Wan2_2_I2V_A14B_SamplingParam,\n        pipeline_config_cls=TurboWanI2V720Config,\n        hf_model_paths=[\n            \"IPostYellow/TurboWan2.2-I2V-A14B-Diffusers\",\n        ],\n    )\n    register_configs(\n        sampling_param_cls=Wan2_1_Fun_1_3B_InP_SamplingParams,\n        pipeline_config_cls=WanI2V480PConfig,\n        hf_model_paths=[\n            \"weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers\",\n        ],\n    )\n    register_configs(\n        sampling_param_cls=Wan2_2_TI2V_5B_SamplingParam,\n        pipeline_config_cls=Wan2_2_TI2V_5B_Config,\n        hf_model_paths=[\n            \"Wan-AI/Wan2.2-TI2V-5B-Diffusers\",\n        ],\n    )\n    register_configs(\n        sampling_param_cls=Wan2_2_TI2V_5B_SamplingParam,\n        pipeline_config_cls=FastWan2_2_TI2V_5B_Config,\n        hf_model_paths=[\n            \"FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers\",\n            \"FastVideo/FastWan2.2-TI2V-5B-Diffusers\",\n        ],\n    )\n    register_configs(\n        
sampling_param_cls=Wan2_2_T2V_A14B_SamplingParam,\n        pipeline_config_cls=Wan2_2_T2V_A14B_Config,\n        hf_model_paths=[\"Wan-AI/Wan2.2-T2V-A14B-Diffusers\"],\n    )\n    register_configs(\n        sampling_param_cls=Wan2_2_I2V_A14B_SamplingParam,\n        pipeline_config_cls=Wan2_2_I2V_A14B_Config,\n        hf_model_paths=[\"Wan-AI/Wan2.2-I2V-A14B-Diffusers\"],\n    )\n    register_configs(\n        sampling_param_cls=FastWanT2V480PConfig,\n        pipeline_config_cls=FastWan2_1_T2V_480P_Config,\n        hf_model_paths=[\n            \"FastVideo/FastWan2.1-T2V-1.3B-Diffusers\",\n        ],\n    )\n    # MOVA\n    register_configs(\n        sampling_param_cls=MOVA_360P_SamplingParams,\n        pipeline_config_cls=MOVA360PConfig,\n        model_detectors=[\n            lambda hf_id: \"mova\" in hf_id.lower() and \"360p\" in hf_id.lower()\n        ],\n    )\n    register_configs(\n        sampling_param_cls=MOVA_720P_SamplingParams,\n        pipeline_config_cls=MOVA720PConfig,\n        model_detectors=[\n            lambda hf_id: \"mova\" in hf_id.lower() and \"720p\" in hf_id.lower()\n        ],\n    )\n    # FLUX\n    register_configs(\n        sampling_param_cls=FluxSamplingParams,\n        pipeline_config_cls=FluxPipelineConfig,\n        hf_model_paths=[\n            \"black-forest-labs/FLUX.1-dev\",\n        ],\n        model_detectors=[lambda hf_id: \"flux.1\" in hf_id.lower()],\n    )\n    register_configs(\n        sampling_param_cls=Flux2KleinSamplingParams,\n        pipeline_config_cls=Flux2KleinPipelineConfig,\n        hf_model_paths=[\n            \"black-forest-labs/FLUX.2-klein-4B\",\n            \"black-forest-labs/FLUX.2-klein-9B\",\n        ],\n        model_detectors=[\n            lambda hf_id: \"flux.2-klein\" in hf_id.lower()\n            or \"flux2-klein\" in hf_id.lower()\n        ],\n    )\n    register_configs(\n        sampling_param_cls=FluxSamplingParams,\n        pipeline_config_cls=Flux2PipelineConfig,\n        hf_model_paths=[\n 
           \"black-forest-labs/FLUX.2-dev\",\n        ],\n        model_detectors=[\n            lambda hf_id: \"flux.2\" in hf_id.lower() and \"klein\" not in hf_id.lower()\n        ],\n    )\n    register_configs(\n        sampling_param_cls=ZImageTurboSamplingParams,\n        pipeline_config_cls=ZImagePipelineConfig,\n        hf_model_paths=[\n            \"Tongyi-MAI/Z-Image-Turbo\",\n        ],\n        model_detectors=[lambda hf_id: \"z-image-turbo\" in hf_id.lower()],\n    )\n    register_configs(\n        sampling_param_cls=ZImageSamplingParams,\n        pipeline_config_cls=ZImagePipelineConfig,\n        hf_model_paths=[\n            \"Tongyi-MAI/Z-Image\",\n        ],\n        model_detectors=[\n            lambda hf_id: \"z-image\" in hf_id.lower() and \"turbo\" not in hf_id.lower()\n        ],\n    )\n    # Qwen-Image\n    register_configs(\n        sampling_param_cls=QwenImageSamplingParams,\n        pipeline_config_cls=QwenImagePipelineConfig,\n        hf_model_paths=[\"Qwen/Qwen-Image\"],\n        model_detectors=[\n            lambda hf_id: \"qwen-image\" in hf_id.lower()\n            and \"edit\" not in hf_id.lower()\n            and \"layered\" not in hf_id.lower()\n            and \"2512\" not in hf_id.lower()\n        ],\n    )\n    register_configs(\n        sampling_param_cls=QwenImage2512SamplingParams,\n        pipeline_config_cls=QwenImagePipelineConfig,\n        hf_model_paths=[\"Qwen/Qwen-Image-2512\"],\n        model_detectors=[lambda hf_id: \"qwen-image-2512\" in hf_id.lower()],\n    )\n    register_configs(\n        sampling_param_cls=QwenImageSamplingParams,\n        pipeline_config_cls=QwenImageEditPipelineConfig,\n        hf_model_paths=[\"Qwen/Qwen-Image-Edit\"],\n        model_detectors=[\n            lambda hf_id: \"qwen-image-edit\" in hf_id.lower()\n            and \"2509\" not in hf_id.lower()\n            and \"2511\" not in hf_id.lower()\n        ],\n    )\n\n    register_configs(\n        
sampling_param_cls=QwenImageEditPlusSamplingParams,\n        pipeline_config_cls=QwenImageEditPlusPipelineConfig,\n        hf_model_paths=[\"Qwen/Qwen-Image-Edit-2509\"],\n        model_detectors=[lambda hf_id: \"qwen-image-edit-2509\" in hf_id.lower()],\n    )\n\n    register_configs(\n        sampling_param_cls=QwenImageEditPlusSamplingParams,\n        pipeline_config_cls=QwenImageEditPlus_2511_PipelineConfig,\n        hf_model_paths=[\"Qwen/Qwen-Image-Edit-2511\"],\n        model_detectors=[lambda hf_id: \"qwen-image-edit-2511\" in hf_id.lower()],\n    )\n\n    register_configs(\n        sampling_param_cls=QwenImageLayeredSamplingParams,\n        pipeline_config_cls=QwenImageLayeredPipelineConfig,\n        hf_model_paths=[\"Qwen/Qwen-Image-Layered\"],\n        model_detectors=[lambda hf_id: \"qwen-image-layered\" in hf_id.lower()],\n    )\n\n    register_configs(\n        sampling_param_cls=GlmImageSamplingParams,\n        pipeline_config_cls=GlmImagePipelineConfig,\n        model_detectors=[lambda hf_id: \"glm-image\" in hf_id.lower()],\n    )\n    register_configs(\n        sampling_param_cls=Hunyuan3DSamplingParams,\n        pipeline_config_cls=Hunyuan3D2PipelineConfig,\n        hf_model_paths=[\n            \"tencent/Hunyuan3D-2\",\n        ],\n        model_detectors=[lambda hf_id: \"hunyuan3d\" in hf_id.lower()],\n    )\n\n    # Helios\n    register_configs(\n        sampling_param_cls=HeliosT2VSamplingParams,\n        pipeline_config_cls=HeliosT2VConfig,\n        hf_model_paths=[\n            \"BestWishYsh/Helios-Base\",\n        ],\n        model_detectors=[\n            lambda hf_id: \"helios\" in hf_id.lower()\n            and \"mid\" not in hf_id.lower()\n            and \"distill\" not in hf_id.lower()\n        ],\n    )\n    register_configs(\n        sampling_param_cls=HeliosMidSamplingParams,\n        pipeline_config_cls=HeliosMidConfig,\n        hf_model_paths=[\n            \"BestWishYsh/Helios-Mid\",\n        ],\n    )\n    register_configs(\n  
      sampling_param_cls=HeliosDistilledSamplingParams,\n        pipeline_config_cls=HeliosDistilledConfig,\n        hf_model_paths=[\n            \"BestWishYsh/Helios-Distilled\",\n        ],\n    )\n\n    # SANA\n    register_configs(\n        sampling_param_cls=SanaSamplingParams,\n        pipeline_config_cls=SanaPipelineConfig,\n        hf_model_paths=[\n            \"Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers\",\n            \"Efficient-Large-Model/SANA1.5_4.8B_1024px_diffusers\",\n            \"Efficient-Large-Model/Sana_1600M_1024px_diffusers\",\n            \"Efficient-Large-Model/Sana_600M_1024px_diffusers\",\n            \"Efficient-Large-Model/Sana_1600M_512px_diffusers\",\n            \"Efficient-Large-Model/Sana_600M_512px_diffusers\",\n        ],\n        model_detectors=[lambda hf_id: \"sana\" in hf_id.lower()],\n    )\n\n\n_register_configs()\n\n\n# Known non-diffusers multimodal model patterns\n# Maps pattern -> pipeline_name for models that don't have model_index.json\n_NON_DIFFUSERS_MULTIMODAL_PATTERNS: Dict[str, str] = {\n    \"hunyuan3d\": \"Hunyuan3D2Pipeline\",\n}\n\n\ndef is_known_non_diffusers_multimodal_model(model_path: str) -> bool:\n    model_path_lower = model_path.lower()\n    return any(\n        pattern in model_path_lower for pattern in _NON_DIFFUSERS_MULTIMODAL_PATTERNS\n    )\n\n\ndef get_non_diffusers_pipeline_name(model_path: str) -> Optional[str]:\n    \"\"\"Get the pipeline name for a known non-diffusers model.\"\"\"\n    model_path_lower = model_path.lower()\n    for pattern, pipeline_name in _NON_DIFFUSERS_MULTIMODAL_PATTERNS.items():\n        if pattern in model_path_lower:\n            return pipeline_name\n    return None\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/cache/__init__.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nCache acceleration module for SGLang-diffusion\n\nThis module provides various caching strategies to accelerate\ndiffusion transformer (DiT) inference:\n\n- TeaCache: Temporal similarity-based caching for diffusion models\n- cache-dit integration: Block-level caching with DBCache and TaylorSeer\n\n\"\"\"\n\nfrom sglang.multimodal_gen.runtime.cache.cache_dit_integration import (\n    CacheDitConfig,\n    enable_cache_on_dual_transformer,\n    enable_cache_on_transformer,\n    get_scm_mask,\n)\nfrom sglang.multimodal_gen.runtime.cache.teacache import TeaCacheContext, TeaCacheMixin\n\n__all__ = [\n    # TeaCache (always available)\n    \"TeaCacheContext\",\n    \"TeaCacheMixin\",\n    # cache-dit integration (lazy-loaded, requires cache-dit package)\n    \"CacheDitConfig\",\n    \"enable_cache_on_transformer\",\n    \"enable_cache_on_dual_transformer\",\n    \"get_scm_mask\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/cache/cache_dit_integration.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\ncache-dit integration module for SGLang DiT pipelines.\n\nThis module provides helper functions to enable cache-dit acceleration\non transformer modules in SGLang's modular pipeline architecture.\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import List, Optional\n\nimport torch\nimport torch.distributed as dist\n\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_ring_parallel_world_size,\n    get_tp_world_size,\n    get_ulysses_parallel_world_size,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\nimport cache_dit\nfrom cache_dit import (\n    BlockAdapter,\n    DBCacheConfig,\n    ForwardPattern,\n    ParamsModifier,\n    TaylorSeerCalibratorConfig,\n    steps_mask,\n)\nfrom cache_dit.caching.block_adapters import BlockAdapterRegister\nfrom cache_dit.parallelism import ParallelismBackend, ParallelismConfig\n\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import get_dit_group\n\n_original_similarity = None\n\n\ndef _patch_cache_dit_similarity():\n    from cache_dit.caching.cache_contexts import cache_manager\n\n    global _original_similarity\n    if _original_similarity is not None:\n        return\n\n    _original_similarity = cache_manager.CachedContextManager.similarity\n\n    def patched_similarity(self, t1, t2, *, threshold, parallelized=False, prefix=\"Fn\"):\n        if not parallelized:\n            return _original_similarity(\n                self,\n                t1,\n                t2,\n                threshold=threshold,\n                parallelized=parallelized,\n                prefix=prefix,\n            )\n\n        sp_group = getattr(self, \"_sglang_sp_group\", None)\n        tp_group = getattr(self, \"_sglang_tp_group\", None)\n        tp_sp_group = getattr(self, \"_sglang_tp_sp_group\", None)\n        target_group = tp_sp_group or sp_group or tp_group\n\n  
      if target_group is None:\n            return _original_similarity(\n                self,\n                t1,\n                t2,\n                threshold=threshold,\n                parallelized=parallelized,\n                prefix=prefix,\n            )\n\n        # Adapted from https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/cache_contexts/cache_manager.py#L495-L523\n        condition_thresh = self.get_important_condition_threshold()\n        if condition_thresh > 0.0:\n            raw_diff = (t1 - t2).abs()\n            token_m_df = raw_diff.mean(dim=-1)\n            token_m_t1 = t1.abs().mean(dim=-1)\n            token_diff = token_m_df / token_m_t1\n            condition = token_diff > condition_thresh\n            if condition.sum() > 0:\n                condition = condition.unsqueeze(-1).expand_as(raw_diff)\n                mean_diff = raw_diff[condition].mean()\n                mean_t1 = t1[condition].abs().mean()\n            else:\n                mean_diff = (t1 - t2).abs().mean()\n                mean_t1 = t1.abs().mean()\n        else:\n            mean_diff = (t1 - t2).abs().mean()\n            mean_t1 = t1.abs().mean()\n\n        dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG, group=target_group)\n        dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG, group=target_group)\n\n        diff = (mean_diff / mean_t1).item()\n        self.add_residual_diff(diff)\n        return diff < threshold\n\n    cache_manager.CachedContextManager.similarity = patched_similarity\n\n\ndef _build_parallelism_config(\n    sp_group: Optional[torch.distributed.ProcessGroup],\n    tp_group: Optional[torch.distributed.ProcessGroup],\n):\n    if sp_group is None and tp_group is None:\n        return None\n\n    ulysses_size = None\n    ring_size = None\n    if sp_group is not None:\n        ulysses_size = get_ulysses_parallel_world_size()\n        ring_size = get_ring_parallel_world_size()\n\n    tp_size = None\n    if tp_group is not None:\n  
      tp_size = get_tp_world_size()\n\n    return ParallelismConfig(\n        backend=ParallelismBackend.AUTO,\n        ulysses_size=ulysses_size,\n        ring_size=ring_size,\n        tp_size=tp_size,\n    )\n\n\ndef _mark_transformer_parallelized(transformer, config, sp_group, tp_group):\n    if config is None:\n        return\n\n    transformer._is_parallelized = True\n    transformer._parallelism_config = config\n\n\ndef get_scm_mask(\n    preset: str,\n    num_inference_steps: int,\n    compute_bins: Optional[List[int]] = None,\n    cache_bins: Optional[List[int]] = None,\n) -> Optional[List[int]]:\n    \"\"\"\n    Get SCM mask using cache-dit's steps_mask().\n\n    This is a thin wrapper that delegates to cache-dit's built-in\n    steps_mask() function which handles all presets and scaling logic.\n\n    Args:\n        preset: Preset name (\"none\", \"slow\", \"medium\", \"fast\", \"ultra\").\n        compute_bins: Custom compute bins (overrides preset).\n        cache_bins: Custom cache bins (overrides preset).\n\n    Returns:\n        SCM mask list (1=compute, 0=cache), or None if disabled.\n    \"\"\"\n    if preset == \"none\" and not (compute_bins and cache_bins):\n        return None\n\n    # Use cache-dit's steps_mask() directly\n    mask = steps_mask(\n        compute_bins=compute_bins,\n        cache_bins=cache_bins,\n        total_steps=num_inference_steps,\n        mask_policy=preset if preset != \"none\" else \"medium\",\n    )\n\n    compute_count = sum(mask)\n    cache_count = len(mask) - compute_count\n    logger.info(\n        \"SCM: generated mask with %d compute steps, %d cache steps (preset=%s)\",\n        compute_count,\n        cache_count,\n        preset,\n    )\n\n    return mask\n\n\n@dataclass\nclass CacheDitConfig:\n    \"\"\"Configuration for cache-dit integration.\n\n    Attributes:\n        enabled: Whether to enable cache-dit acceleration.\n        Fn_compute_blocks: Number of first blocks to always compute (DBCache F).\n        
Bn_compute_blocks: Number of last blocks to always compute (DBCache B).\n        max_warmup_steps: Number of warmup steps before caching starts (DBCache W).\n        residual_diff_threshold: Threshold for residual difference (DBCache R).\n        max_continuous_cached_steps: Maximum consecutive cached steps (DBCache MC).\n        enable_taylorseer: Whether to enable TaylorSeer calibrator.\n        taylorseer_order: Order of Taylor expansion (1 or 2).\n        num_inference_steps: Total number of inference steps (required for transformer-only mode).\n        steps_computation_mask: Binary mask for step-level caching (1=compute, 0=cache).\n            Generated by get_scm_mask() (wrapper around cache_dit.steps_mask()).\n        steps_computation_policy: Caching policy for SCM (\"dynamic\" or \"static\").\n    \"\"\"\n\n    enabled: bool = False\n    Fn_compute_blocks: int = 1\n    Bn_compute_blocks: int = 0\n    # Use 4 as the default number of warmup steps instead of cache-dit's 8, so that\n    # DBCache works for few-step distilled models, e.g., Z-Image with 8 steps.\n    max_warmup_steps: int = 4\n    # Use a relatively high residual diff threshold (0.24) by default to allow more\n    # aggressive caching, since the max continuous cached steps limit is already\n    # applied; otherwise, a lower threshold such as 0.12 would be preferable.\n    residual_diff_threshold: float = 0.24\n    max_continuous_cached_steps: int = 3\n    # TaylorSeer is not suitable for few-step distilled models, so we choose\n    # to disable it by default. 
Reference:\n    # - From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers,\n    #   https://arxiv.org/pdf/2503.06923\n    # - FoCa: Forecast then Calibrate: Feature Caching as ODE for Efficient\n    #   Diffusion Transformers, https://arxiv.org/pdf/2508.16211\n    enable_taylorseer: bool = False\n    taylorseer_order: int = 1\n    num_inference_steps: Optional[int] = None\n    # SCM fields (generated by _maybe_enable_cache_dit from env configuration)\n    steps_computation_mask: Optional[List[int]] = None\n    steps_computation_policy: str = \"dynamic\"\n\n\ndef enable_cache_on_transformer(\n    transformer: torch.nn.Module,\n    config: CacheDitConfig,\n    model_name: str = \"transformer\",\n    sp_group: Optional[torch.distributed.ProcessGroup] = None,\n    tp_group: Optional[torch.distributed.ProcessGroup] = None,\n) -> torch.nn.Module:\n    \"\"\"Enable cache-dit on a transformer module by wrapping the module with cache-dit.\n\n    This function enables cache-dit acceleration using the BlockAdapterRegister\n    for pre-registered models.\n\n    Args:\n        model_name: Name of the model for logging purposes.\n        sp_group: Sequence parallel process group (for Ulysses/Ring).\n        tp_group: Tensor parallel process group.\n\n    \"\"\"\n    if not config.enabled:\n        return transformer\n\n    if config.num_inference_steps is None:\n        raise ValueError(\n            \"num_inference_steps is required for transformer-only mode. \"\n            \"Please provide it in CacheDitConfig.\"\n        )\n\n    # Check if the transformer is pre-registered in cache-dit\n    if not BlockAdapterRegister.is_supported(transformer):\n        transformer_cls_name = transformer.__class__.__name__\n        raise ValueError(\n            f\"{transformer_cls_name} is not officially supported by cache-dit. 
\"\n            \"Supported cache-dit DiT families include Flux, QwenImage, HunyuanDiT, \"\n            \"HunyuanVideo, Wan, CogVideoX, Mochi, and others. \"\n            \"Please ensure your transformer belongs to one of these families or \"\n            \"define a custom BlockAdapter.\"\n        )\n\n    # Build cache config (including SCM fields if provided)\n    cache_config = DBCacheConfig(\n        num_inference_steps=config.num_inference_steps,\n        Fn_compute_blocks=config.Fn_compute_blocks,\n        Bn_compute_blocks=config.Bn_compute_blocks,\n        max_warmup_steps=config.max_warmup_steps,\n        residual_diff_threshold=config.residual_diff_threshold,\n        max_continuous_cached_steps=config.max_continuous_cached_steps,\n        # SCM fields\n        steps_computation_mask=config.steps_computation_mask,\n        steps_computation_policy=config.steps_computation_policy,\n    )\n\n    # Build calibrator config if TaylorSeer is enabled\n    calibrator_config = None\n    if config.enable_taylorseer:\n        calibrator_config = TaylorSeerCalibratorConfig(\n            taylorseer_order=config.taylorseer_order,\n        )\n\n    # Enable cache-dit on the transformer\n    logger.info(\n        \"Enabling cache-dit on %s with config: Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, \"\n        \"TaylorSeer=%s (order=%d), steps=%d\",\n        model_name,\n        config.Fn_compute_blocks,\n        config.Bn_compute_blocks,\n        config.max_warmup_steps,\n        config.residual_diff_threshold,\n        config.max_continuous_cached_steps,\n        config.enable_taylorseer,\n        config.taylorseer_order,\n        config.num_inference_steps,\n    )\n\n    # Log SCM configuration if enabled\n    if config.steps_computation_mask:\n        compute_steps = sum(config.steps_computation_mask)\n        cache_steps = len(config.steps_computation_mask) - compute_steps\n        logger.info(\n            \"SCM enabled: %d compute steps, %d cache steps, policy=%s\",\n         
   compute_steps,\n            cache_steps,\n            config.steps_computation_policy,\n        )\n\n    parallelism_config = _build_parallelism_config(sp_group, tp_group)\n    if parallelism_config is not None:\n        _patch_cache_dit_similarity()\n\n    _mark_transformer_parallelized(transformer, parallelism_config, sp_group, tp_group)\n\n    cache_dit.enable_cache(\n        transformer,\n        cache_config=cache_config,\n        calibrator_config=calibrator_config,\n        parallelism_config=None,\n    )\n\n    if parallelism_config is not None:\n        context_manager = getattr(transformer, \"_context_manager\", None)\n        if context_manager is not None:\n            context_manager._sglang_sp_group = sp_group\n            context_manager._sglang_tp_group = tp_group\n            # In mixed TP + SP (Ulysses/Ring) mode, cache-dit decisions must be consistent\n            # across the full TP×SP model-parallel slice. Prefer using SGLang's DIT group\n            # as a conservative superset group; fallback to None.\n            tp_sp_group = None\n            if sp_group is not None and tp_group is not None:\n                tp_sp_group = get_dit_group()\n\n            context_manager._sglang_tp_sp_group = tp_sp_group\n\n    return transformer\n\n\ndef enable_cache_on_dual_transformer(\n    transformer: torch.nn.Module,\n    transformer_2: torch.nn.Module,\n    primary_config: CacheDitConfig,\n    secondary_config: CacheDitConfig,\n    model_name: str = \"wan2.2\",\n    sp_group: Optional[torch.distributed.ProcessGroup] = None,\n    tp_group: Optional[torch.distributed.ProcessGroup] = None,\n) -> tuple[torch.nn.Module, torch.nn.Module]:\n    \"\"\"Enable cache-dit on dual transformers using BlockAdapter.\n\n    For models with two transformers (high-noise expert and low-noise expert),\n    cache-dit requires enabling cache on both simultaneously via BlockAdapter.\n    This cannot be done by calling enable_cache separately on each transformer.\n\n    
Args:\n        primary_config: CacheDitConfig for primary transformer.\n        secondary_config: CacheDitConfig for secondary transformer.\n        sp_group: Sequence parallel process group (for Ulysses/Ring).\n        tp_group: Tensor parallel process group.\n    \"\"\"\n    _supported_dual_transformer_models = [\n        \"wan2.2\",  # Currently, only Wan2.2 will run into dual-transformer case\n    ]\n    if model_name not in _supported_dual_transformer_models:\n        raise ValueError(\n            f\"Dual-transformer cache-dit is only supported for \"\n            f\"{_supported_dual_transformer_models}, got {model_name}.\"\n        )\n\n    if not primary_config.enabled:\n        return transformer, transformer_2\n\n    if primary_config.num_inference_steps is None:\n        raise ValueError(\n            \"num_inference_steps is required for dual-transformer mode. \"\n            \"Please provide it in CacheDitConfig.\"\n        )\n\n    # Build DBCacheConfig for primary transformer\n    primary_cache_config = DBCacheConfig(\n        num_inference_steps=primary_config.num_inference_steps,\n        Fn_compute_blocks=primary_config.Fn_compute_blocks,\n        Bn_compute_blocks=primary_config.Bn_compute_blocks,\n        max_warmup_steps=primary_config.max_warmup_steps,\n        residual_diff_threshold=primary_config.residual_diff_threshold,\n        max_continuous_cached_steps=primary_config.max_continuous_cached_steps,\n        steps_computation_mask=primary_config.steps_computation_mask,\n        steps_computation_policy=primary_config.steps_computation_policy,\n    )\n\n    # Build DBCacheConfig for secondary transformer\n    secondary_cache_config = DBCacheConfig(\n        num_inference_steps=secondary_config.num_inference_steps,\n        Fn_compute_blocks=secondary_config.Fn_compute_blocks,\n        Bn_compute_blocks=secondary_config.Bn_compute_blocks,\n        max_warmup_steps=secondary_config.max_warmup_steps,\n        
residual_diff_threshold=secondary_config.residual_diff_threshold,\n        max_continuous_cached_steps=secondary_config.max_continuous_cached_steps,\n        steps_computation_mask=secondary_config.steps_computation_mask,\n        steps_computation_policy=secondary_config.steps_computation_policy,\n    )\n\n    # Build calibrator configs if TaylorSeer is enabled\n    primary_calibrator = None\n    if primary_config.enable_taylorseer:\n        primary_calibrator = TaylorSeerCalibratorConfig(\n            taylorseer_order=primary_config.taylorseer_order,\n        )\n\n    secondary_calibrator = None\n    if secondary_config.enable_taylorseer:\n        secondary_calibrator = TaylorSeerCalibratorConfig(\n            taylorseer_order=secondary_config.taylorseer_order,\n        )\n\n    # Build ParamsModifier for each transformer\n    primary_modifier = ParamsModifier(\n        cache_config=primary_cache_config,\n        calibrator_config=primary_calibrator,\n    )\n    secondary_modifier = ParamsModifier(\n        cache_config=secondary_cache_config,\n        calibrator_config=secondary_calibrator,\n    )\n\n    # Log configuration\n    logger.info(\n        \"Enabling cache-dit on %s dual transformers with BlockAdapter\",\n        model_name,\n    )\n    logger.info(\n        \"  Primary (transformer): Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, TaylorSeer=%s\",\n        primary_config.Fn_compute_blocks,\n        primary_config.Bn_compute_blocks,\n        primary_config.max_warmup_steps,\n        primary_config.residual_diff_threshold,\n        primary_config.max_continuous_cached_steps,\n        primary_config.enable_taylorseer,\n    )\n    logger.info(\n        \"  Secondary (transformer_2): Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, TaylorSeer=%s\",\n        secondary_config.Fn_compute_blocks,\n        secondary_config.Bn_compute_blocks,\n        secondary_config.max_warmup_steps,\n        secondary_config.residual_diff_threshold,\n        
secondary_config.max_continuous_cached_steps,\n        secondary_config.enable_taylorseer,\n    )\n\n    # Log SCM configuration if enabled\n    if primary_config.steps_computation_mask:\n        compute_steps = sum(primary_config.steps_computation_mask)\n        cache_steps = len(primary_config.steps_computation_mask) - compute_steps\n        logger.info(\n            \"  SCM enabled for primary transformer: %d compute steps, %d cache steps, policy=%s\",\n            compute_steps,\n            cache_steps,\n            primary_config.steps_computation_policy,\n        )\n    if secondary_config.steps_computation_mask:\n        compute_steps = sum(secondary_config.steps_computation_mask)\n        cache_steps = len(secondary_config.steps_computation_mask) - compute_steps\n        logger.info(\n            \"  SCM enabled for secondary transformer: %d compute steps, %d cache steps, policy=%s\",\n            compute_steps,\n            cache_steps,\n            secondary_config.steps_computation_policy,\n        )\n\n    parallelism_config = _build_parallelism_config(sp_group, tp_group)\n    if parallelism_config is not None:\n        _patch_cache_dit_similarity()\n\n    _mark_transformer_parallelized(transformer, parallelism_config, sp_group, tp_group)\n    _mark_transformer_parallelized(\n        transformer_2, parallelism_config, sp_group, tp_group\n    )\n\n    # Get blocks attribute - Wan transformers use 'blocks' attribute\n    transformer_blocks = getattr(transformer, \"blocks\", None)\n    transformer_2_blocks = getattr(transformer_2, \"blocks\", None)\n\n    if transformer_blocks is None or transformer_2_blocks is None:\n        raise ValueError(\n            \"Dual transformers must have 'blocks' attribute for cache-dit. 
\"\n            f\"transformer has blocks: {transformer_blocks is not None}, \"\n            f\"transformer_2 has blocks: {transformer_2_blocks is not None}\"\n        )\n\n    # Enable cache-dit using BlockAdapter for both transformers simultaneously\n    # This is required for Wan2.2 and similar dual-transformer architectures\n    if model_name == \"wan2.2\":\n        # Use Pattern_2 for Wan2.2 dual-transformer. We should check `model_name`\n        # to ensure we only apply this for supported models. Different models\n        # may require different ForwardPattern.\n        cache_dit.enable_cache(\n            BlockAdapter(\n                transformer=[transformer, transformer_2],\n                blocks=[transformer_blocks, transformer_2_blocks],\n                forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2],\n                params_modifiers=[primary_modifier, secondary_modifier],\n                has_separate_cfg=True,\n            ),\n            parallelism_config=None,\n        )\n    else:\n        raise ValueError(\n            f\"Dual-transformer is not implemented for model {model_name} yet.\"\n        )\n\n    if parallelism_config is not None:\n        for t in [transformer, transformer_2]:\n            context_manager = getattr(t, \"_context_manager\", None)\n            if context_manager is not None:\n                context_manager._sglang_sp_group = sp_group\n                context_manager._sglang_tp_group = tp_group\n                tp_sp_group = None\n                if sp_group is not None and tp_group is not None:\n                    try:\n                        tp_sp_group = get_dit_group()\n                    except Exception:\n                        tp_sp_group = None\n                context_manager._sglang_tp_sp_group = tp_sp_group\n\n    return transformer, transformer_2\n\n\ndef refresh_context_on_transformer(\n    transformer: torch.nn.Module,\n    num_inference_steps: int,\n    scm_preset: str | None = 
None,\n    verbose: bool = False,\n) -> None:\n    \"\"\"Refresh cache-dit context for transformer.\"\"\"\n    cache_dit.refresh_context(\n        transformer,\n        cache_config=DBCacheConfig().reset(\n            num_inference_steps=num_inference_steps,\n            steps_computation_mask=cache_dit.steps_mask(\n                mask_policy=scm_preset, total_steps=num_inference_steps\n            ),\n            steps_computation_policy=scm_preset,\n        ),\n        verbose=verbose,\n    )\n    logger.debug(f\"cache-dit refreshed on transformer (steps={num_inference_steps})\")\n\n\ndef refresh_context_on_dual_transformer(\n    transformer: torch.nn.Module,\n    transformer_2: torch.nn.Module,\n    num_high_noise_steps: int,\n    num_low_noise_steps: int,\n    scm_preset: str | None = None,\n    verbose: bool = False,\n) -> None:\n    \"\"\"Refresh cache-dit context for dual transformers.\"\"\"\n    cache_dit.refresh_context(\n        transformer,\n        cache_config=DBCacheConfig().reset(\n            num_inference_steps=num_high_noise_steps,\n            steps_computation_mask=cache_dit.steps_mask(\n                mask_policy=scm_preset, total_steps=num_high_noise_steps\n            ),\n            steps_computation_policy=scm_preset,\n        ),\n        verbose=verbose,\n    )\n    cache_dit.refresh_context(\n        transformer_2,\n        cache_config=DBCacheConfig().reset(\n            num_inference_steps=num_low_noise_steps,\n            steps_computation_mask=cache_dit.steps_mask(\n                mask_policy=scm_preset, total_steps=num_low_noise_steps\n            ),\n            steps_computation_policy=scm_preset,\n        ),\n        verbose=verbose,\n    )\n    logger.debug(\n        f\"cache-dit refreshed on dual transformers (steps={num_high_noise_steps}, {num_low_noise_steps})\"\n    )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/cache/teacache.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nTeaCache: Temporal similarity-based caching for diffusion models.\n\nTeaCache accelerates diffusion inference by selectively skipping redundant\ncomputation when consecutive diffusion steps are similar enough. This is\nachieved by tracking the L1 distance between modulated inputs across timesteps.\n\nKey concepts:\n- Modulated input: The input to transformer blocks after timestep conditioning\n- L1 distance: Measures how different consecutive timesteps are\n- Threshold: When accumulated L1 distance exceeds threshold, force computation\n- CFG support: Separate caches for positive and negative branches\n\nReferences:\n- TeaCache: Accelerating Diffusion Models with Temporal Similarity\n  https://arxiv.org/abs/2411.14324\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Any\n\nimport numpy as np\nimport torch\n\nfrom sglang.multimodal_gen.configs.models import DiTConfig\n\nif TYPE_CHECKING:\n    from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams\n\n\n@dataclass\nclass TeaCacheContext:\n    \"\"\"Common context extracted for TeaCache skip decision.\n\n    This context is populated from the forward_batch and forward_context\n    during each denoising step, providing all information needed to make\n    cache decisions.\n\n    Attributes:\n        current_timestep: Current denoising timestep index (0-indexed).\n        num_inference_steps: Total number of inference steps.\n        do_cfg: Whether classifier-free guidance is enabled.\n        is_cfg_negative: True if currently processing negative CFG branch.\n        teacache_thresh: Threshold for accumulated L1 distance.\n        coefficients: Polynomial coefficients for L1 rescaling.\n        teacache_params: Full TeaCacheParams for model-specific access.\n    \"\"\"\n\n    current_timestep: int\n    num_inference_steps: int\n    do_cfg: bool\n    is_cfg_negative: bool  # For CFG branch selection\n    
teacache_thresh: float\n    coefficients: list[float]\n    teacache_params: \"TeaCacheParams\"  # Full params for model-specific access\n\n\nclass TeaCacheMixin:\n    \"\"\"\n    Mixin class providing TeaCache optimization functionality.\n\n    TeaCache accelerates diffusion inference by selectively skipping redundant\n    computation when consecutive diffusion steps are similar enough.\n\n    This mixin should be inherited by DiT model classes that want to support\n    TeaCache optimization. It provides:\n    - State management for tracking L1 distances\n    - CFG-aware caching (separate caches for positive/negative branches)\n    - Decision logic for when to compute vs. use cache\n\n    Example usage in a DiT model:\n        class MyDiT(TeaCacheMixin, BaseDiT):\n            def __init__(self, config, **kwargs):\n                super().__init__(config, **kwargs)\n                self._init_teacache_state()\n\n            def forward(self, hidden_states, timestep, ...):\n                ctx = self._get_teacache_context()\n                if ctx is not None:\n                    # Compute modulated input (model-specific, e.g., after timestep embedding)\n                    modulated_input = self._compute_modulated_input(hidden_states, timestep)\n                    is_boundary = (ctx.current_timestep == 0 or\n                                   ctx.current_timestep >= ctx.num_inference_steps - 1)\n\n                    should_calc = self._compute_teacache_decision(\n                        modulated_inp=modulated_input,\n                        is_boundary_step=is_boundary,\n                        coefficients=ctx.coefficients,\n                        teacache_thresh=ctx.teacache_thresh,\n                    )\n\n                    if not should_calc:\n                        # Use cached residual (must implement retrieve_cached_states)\n                        return self.retrieve_cached_states(hidden_states)\n\n                # Normal forward pass...\n         
       output = self._transformer_forward(hidden_states, timestep, ...)\n\n                # Cache states for next step\n                if ctx is not None:\n                    self.maybe_cache_states(output, hidden_states)\n\n                return output\n\n    Subclass implementation notes:\n        - `_compute_modulated_input()`: Model-specific method to compute the input\n          after timestep conditioning (used for L1 distance calculation)\n        - `retrieve_cached_states()`: Must be overridden to return cached output\n        - `maybe_cache_states()`: Override to store states for cache retrieval\n\n    Attributes:\n        cnt: Counter for tracking steps.\n        enable_teacache: Whether TeaCache is enabled.\n        previous_modulated_input: Cached modulated input for positive branch.\n        previous_residual: Cached residual for positive branch.\n        accumulated_rel_l1_distance: Accumulated L1 distance for positive branch.\n        is_cfg_negative: Whether currently processing negative CFG branch.\n        _supports_cfg_cache: Whether this model supports CFG cache separation.\n\n    CFG-specific attributes (only when _supports_cfg_cache is True):\n        previous_modulated_input_negative: Cached input for negative branch.\n        previous_residual_negative: Cached residual for negative branch.\n        accumulated_rel_l1_distance_negative: L1 distance for negative branch.\n    \"\"\"\n\n    # Models that support CFG cache separation (wan/hunyuan/zimage)\n    # Models not in this set (flux/qwen) auto-disable TeaCache when CFG is enabled\n    _CFG_SUPPORTED_PREFIXES: set[str] = {\"wan\", \"hunyuan\", \"zimage\"}\n    config: DiTConfig\n\n    def _init_teacache_state(self) -> None:\n        \"\"\"Initialize TeaCache state. 
Call this in subclass __init__.\"\"\"\n        # Common TeaCache state\n        self.cnt = 0\n        self.enable_teacache = True\n        # Flag indicating if this model supports CFG cache separation\n        self._supports_cfg_cache = (\n            self.config.prefix.lower() in self._CFG_SUPPORTED_PREFIXES\n        )\n\n        # Always initialize positive cache fields (used in all modes)\n        self.previous_modulated_input: torch.Tensor | None = None\n        self.previous_residual: torch.Tensor | None = None\n        self.accumulated_rel_l1_distance: float = 0.0\n\n        self.is_cfg_negative = False\n        # CFG-specific fields initialized to None (created when CFG is used)\n        # These are only used when _supports_cfg_cache is True AND do_cfg is True\n        if self._supports_cfg_cache:\n            self.previous_modulated_input_negative: torch.Tensor | None = None\n            self.previous_residual_negative: torch.Tensor | None = None\n            self.accumulated_rel_l1_distance_negative: float = 0.0\n\n    def reset_teacache_state(self) -> None:\n        \"\"\"Reset all TeaCache state at the start of each generation task.\"\"\"\n        self.cnt = 0\n\n        # Primary cache fields (always present)\n        self.previous_modulated_input = None\n        self.previous_residual = None\n        self.accumulated_rel_l1_distance = 0.0\n        self.is_cfg_negative = False\n        self.enable_teacache = True\n        # CFG negative cache fields (always reset, may be unused)\n        if self._supports_cfg_cache:\n            self.previous_modulated_input_negative = None\n            self.previous_residual_negative = None\n            self.accumulated_rel_l1_distance_negative = 0.0\n\n    def _compute_l1_and_decide(\n        self,\n        modulated_inp: torch.Tensor,\n        coefficients: list[float],\n        teacache_thresh: float,\n    ) -> tuple[float, bool]:\n        \"\"\"\n        Compute L1 distance and decide whether to calculate or use 
cache.\n\n        Args:\n            modulated_inp: Current timestep's modulated input.\n            coefficients: Polynomial coefficients for L1 rescaling.\n            teacache_thresh: Threshold for cache decision.\n\n        Returns:\n            Tuple of (new_accumulated_distance, should_calc).\n        \"\"\"\n        prev_modulated_inp = (\n            self.previous_modulated_input_negative\n            if self.is_cfg_negative\n            else self.previous_modulated_input\n        )\n\n        # Defensive check: if previous input is not set, force calculation\n        if prev_modulated_inp is None:\n            return 0.0, True\n\n        # Compute relative L1 distance\n        diff = modulated_inp - prev_modulated_inp\n        rel_l1 = (diff.abs().mean() / prev_modulated_inp.abs().mean()).cpu().item()\n\n        # Apply polynomial rescaling\n        rescale_func = np.poly1d(coefficients)\n\n        accumulated_rel_l1_distance = (\n            self.accumulated_rel_l1_distance_negative\n            if self.is_cfg_negative\n            else self.accumulated_rel_l1_distance\n        )\n        accumulated_rel_l1_distance = accumulated_rel_l1_distance + rescale_func(rel_l1)\n\n        if accumulated_rel_l1_distance >= teacache_thresh:\n            # Threshold exceeded: force compute and reset accumulator\n            return 0.0, True\n        # Cache hit: keep accumulated distance\n        return accumulated_rel_l1_distance, False\n\n    def _compute_teacache_decision(\n        self,\n        modulated_inp: torch.Tensor,\n        is_boundary_step: bool,\n        coefficients: list[float],\n        teacache_thresh: float,\n    ) -> bool:\n        \"\"\"\n        Compute cache decision for TeaCache.\n\n        Args:\n            modulated_inp: Current timestep's modulated input.\n            is_boundary_step: True for boundary timesteps that always compute.\n            coefficients: Polynomial coefficients for L1 rescaling.\n            teacache_thresh: 
Threshold for cache decision.\n\n        Returns:\n            True if forward computation is needed, False to use cache.\n        \"\"\"\n        if not self.enable_teacache:\n            return True\n\n        if is_boundary_step:\n            new_accum, should_calc = 0.0, True\n        else:\n            new_accum, should_calc = self._compute_l1_and_decide(\n                modulated_inp=modulated_inp,\n                coefficients=coefficients,\n                teacache_thresh=teacache_thresh,\n            )\n\n        # Advance baseline and accumulator for the active branch\n        if not self.is_cfg_negative:\n            self.previous_modulated_input = modulated_inp.clone()\n            self.accumulated_rel_l1_distance = new_accum\n        elif self._supports_cfg_cache:\n            self.previous_modulated_input_negative = modulated_inp.clone()\n            self.accumulated_rel_l1_distance_negative = new_accum\n\n        return should_calc\n\n    def _get_teacache_context(self) -> TeaCacheContext | None:\n        \"\"\"\n        Check TeaCache preconditions and extract common context.\n\n        Returns:\n            TeaCacheContext if TeaCache is enabled and properly configured,\n            None if should skip TeaCache logic entirely.\n        \"\"\"\n        from sglang.multimodal_gen.runtime.managers.forward_context import (\n            get_forward_context,\n        )\n\n        forward_context = get_forward_context()\n        forward_batch = forward_context.forward_batch\n\n        # Early return checks\n        if (\n            forward_batch is None\n            or not forward_batch.enable_teacache\n            or forward_batch.teacache_params is None\n        ):\n            return None\n\n        teacache_params = forward_batch.teacache_params\n\n        # Extract common values\n        current_timestep = forward_context.current_timestep\n        num_inference_steps = forward_batch.num_inference_steps\n        do_cfg = 
forward_batch.do_classifier_free_guidance\n        is_cfg_negative = forward_batch.is_cfg_negative\n\n        # Reset at first timestep\n        if current_timestep == 0 and not self.is_cfg_negative:\n            self.reset_teacache_state()\n\n        return TeaCacheContext(\n            current_timestep=current_timestep,\n            num_inference_steps=num_inference_steps,\n            do_cfg=do_cfg,\n            is_cfg_negative=is_cfg_negative,\n            teacache_thresh=teacache_params.teacache_thresh,\n            coefficients=teacache_params.coefficients,\n            teacache_params=teacache_params,\n        )\n\n    def maybe_cache_states(\n        self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor\n    ) -> None:\n        \"\"\"Cache states for later retrieval. Override in subclass if needed.\"\"\"\n        pass\n\n    def should_skip_forward_for_cached_states(self, **kwargs: dict[str, Any]) -> bool:\n        \"\"\"Check if forward can be skipped using cached states.\"\"\"\n        return False\n\n    def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"Retrieve cached states. Must be implemented by subclass.\"\"\"\n        raise NotImplementedError(\"retrieve_cached_states is not implemented\")\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/distributed/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\nfrom functools import lru_cache\n\nfrom sglang.multimodal_gen.configs.models.encoders import TextEncoderConfig\nfrom sglang.multimodal_gen.runtime.distributed.communication_op import *\nfrom sglang.multimodal_gen.runtime.distributed.group_coordinator import (\n    get_local_torch_device,\n)\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    cleanup_dist_env_and_memory,\n    get_dp_group,\n    get_dp_rank,\n    get_dp_world_size,\n    get_sp_group,\n    get_sp_parallel_rank,\n    get_sp_world_size,\n    get_tp_group,\n    get_tp_rank,\n    get_tp_world_size,\n    get_world_group,\n    get_world_rank,\n    get_world_size,\n    init_distributed_environment,\n    initialize_model_parallel,\n    maybe_init_distributed_environment_and_model_parallel,\n    model_parallel_is_initialized,\n)\nfrom sglang.multimodal_gen.runtime.distributed.utils import *\n\n# SPDX-License-Identifier: Apache-2.0\n\n\n__all__ = [\n    # Initialization\n    \"init_distributed_environment\",\n    \"initialize_model_parallel\",\n    \"cleanup_dist_env_and_memory\",\n    \"model_parallel_is_initialized\",\n    \"maybe_init_distributed_environment_and_model_parallel\",\n    # World group\n    \"get_world_group\",\n    \"get_world_rank\",\n    \"get_world_size\",\n    # Data parallel group\n    \"get_dp_group\",\n    \"get_dp_rank\",\n    \"get_dp_world_size\",\n    # Sequence parallel group\n    \"get_sp_group\",\n    \"get_sp_parallel_rank\",\n    \"get_sp_world_size\",\n    # Tensor parallel group\n    \"get_tp_group\",\n    \"get_tp_rank\",\n    \"get_tp_world_size\",\n    # Get torch device\n    \"get_local_torch_device\",\n]\n\n\ndef _get_folding_tp_group(\n    config: TextEncoderConfig,\n) -> torch.distributed.ProcessGroup | None:\n    if config.parallel_folding:\n        if config.parallel_folding_mode == \"sp\":\n            return get_sp_group()\n        elif 
config.parallel_folding_mode == \"ulysses\":\n            return get_sp_group().ulysses_group\n        elif config.parallel_folding_mode == \"ring\":\n            return get_sp_group().ring_group\n    return get_tp_group()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/distributed/communication_op.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/communication_op.py\n\nimport torch\nimport torch.distributed as dist\n\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_cfg_group,\n    get_sp_group,\n    get_tp_group,\n)\n\n\ndef tensor_model_parallel_all_reduce(\n    input_: torch.Tensor, tp_group: dist.ProcessGroup | None = None\n) -> torch.Tensor:\n    \"\"\"All-reduce the input tensor across tensor parallel group.\"\"\"\n    tp_group = tp_group or get_tp_group()\n    return tp_group.all_reduce(input_)\n\n\ndef tensor_model_parallel_all_gather(\n    input_: torch.Tensor, dim: int = -1, tp_group: dist.ProcessGroup | None = None\n) -> torch.Tensor:\n    \"\"\"All-gather the input tensor across tensor parallel group.\"\"\"\n    tp_group = tp_group or get_tp_group()\n    return tp_group.all_gather(input_, dim)\n\n\n# TODO: remove model, make it sequence_parallel\ndef sequence_model_parallel_all_to_all_4D(\n    input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1\n) -> torch.Tensor:\n    \"\"\"All-to-all communication of 4D tensors (e.g. QKV matrices) across sequence parallel group.\"\"\"\n    return get_sp_group().all_to_all_4D(input_, scatter_dim, gather_dim)\n\n\ndef sequence_model_parallel_all_gather(\n    input_: torch.Tensor, dim: int = -1\n) -> torch.Tensor:\n    \"\"\"All-gather the input tensor across sequence parallel group.\"\"\"\n    return get_sp_group().all_gather(input_, dim)\n\n\ndef cfg_model_parallel_all_gather(\n    input_: torch.Tensor, dim: int = -1, separate_tensors: bool = False\n) -> torch.Tensor:\n    \"\"\"All-gather the input tensor across CFG parallel group.\"\"\"\n    return get_cfg_group().all_gather(input_, dim, separate_tensors)\n\n\ndef cfg_model_parallel_all_reduce(\n    input_: torch.Tensor,\n    op: torch._C._distributed_c10d.ReduceOp = torch._C._distributed_c10d.ReduceOp.SUM,\n) -> torch.Tensor:\n    \"\"\"All-reduce the input tensor across CFG parallel group.\"\"\"\n    return get_cfg_group().all_reduce(input_, op=op)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/distributed/device_communicators/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/distributed/device_communicators/base_device_communicator.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/base_device_communicator.py\n\nfrom typing import Any\n\nimport torch\nimport torch.distributed as dist\nfrom torch import Tensor\nfrom torch.distributed import ProcessGroup, ReduceOp\n\n\nclass DistributedAutograd:\n    \"\"\"Collection of autograd functions for distributed operations.\n\n    This class provides custom autograd functions for distributed operations like all_reduce,\n    all_gather, and all_to_all. Each operation is implemented as a static inner class with\n    proper forward and backward implementations.\n    \"\"\"\n\n    class AllReduce(torch.autograd.Function):\n        \"\"\"Differentiable all_reduce operation.\n\n        The gradient of all_reduce is another all_reduce operation since the operation\n        combines values from all ranks equally.\n        \"\"\"\n\n        @staticmethod\n        def forward(\n            ctx: Any,\n            group: ProcessGroup,\n            input_: Tensor,\n            op: dist.ReduceOp | None = None,\n        ) -> Tensor:\n            ctx.group = group\n            ctx.op = op\n            output = input_.clone()\n            dist.all_reduce(output, group=group, op=op)\n            return output\n\n        @staticmethod\n        def backward(ctx: Any, grad_output: Tensor) -> tuple[None, Tensor, None]:\n            grad_output = grad_output.clone()\n            dist.all_reduce(grad_output, group=ctx.group, op=ctx.op)\n            return None, grad_output, None\n\n    class AllGather(torch.autograd.Function):\n        \"\"\"Differentiable all_gather operation.\n\n        The operation gathers tensors from all ranks and concatenates them along a specified dimension.\n        The backward pass uses reduce_scatter to efficiently distribute gradients back to source ranks.\n        
\"\"\"\n\n        @staticmethod\n        def forward(\n            ctx: Any, group: ProcessGroup, input_: Tensor, world_size: int, dim: int\n        ) -> Tensor:\n            ctx.group = group\n            ctx.world_size = world_size\n            ctx.dim = dim\n            ctx.input_shape = input_.shape\n\n            input_size = input_.size()\n            output_size = (input_size[0] * world_size,) + input_size[1:]\n            output_tensor = torch.empty(\n                output_size, dtype=input_.dtype, device=input_.device\n            )\n\n            dist.all_gather_into_tensor(output_tensor, input_, group=group)\n\n            output_tensor = output_tensor.reshape((world_size,) + input_size)\n            output_tensor = output_tensor.movedim(0, dim)\n            output_tensor = output_tensor.reshape(\n                input_size[:dim]\n                + (world_size * input_size[dim],)\n                + input_size[dim + 1 :]\n            )\n            return output_tensor\n\n        @staticmethod\n        def backward(ctx: Any, grad_output: Tensor) -> tuple[None, Tensor, None, None]:\n            # Split the gradient tensor along the gathered dimension\n            dim_size = grad_output.size(ctx.dim) // ctx.world_size\n            grad_chunks = grad_output.reshape(\n                grad_output.shape[: ctx.dim]\n                + (ctx.world_size, dim_size)\n                + grad_output.shape[ctx.dim + 1 :]\n            )\n            grad_chunks = grad_chunks.movedim(ctx.dim, 0)\n\n            # Each rank only needs its corresponding gradient\n            grad_input = torch.empty(\n                ctx.input_shape, dtype=grad_output.dtype, device=grad_output.device\n            )\n            dist.reduce_scatter_tensor(\n                grad_input, grad_chunks.contiguous(), group=ctx.group\n            )\n\n            return None, grad_input, None, None\n\n    class AllToAll4D(torch.autograd.Function):\n        \"\"\"Differentiable all_to_all operation 
specialized for 4D tensors.\n\n        This operation is particularly useful for attention operations where we need to\n        redistribute data across ranks for efficient parallel processing.\n\n        The operation supports two modes:\n        1. scatter_dim=2, gather_dim=1: Used for redistributing attention heads\n        2. scatter_dim=1, gather_dim=2: Used for redistributing sequence dimensions\n        \"\"\"\n\n        @staticmethod\n        def forward(\n            ctx: Any,\n            group: ProcessGroup,\n            input_: Tensor,\n            world_size: int,\n            scatter_dim: int,\n            gather_dim: int,\n        ) -> Tensor:\n            ctx.group = group\n            ctx.world_size = world_size\n            ctx.scatter_dim = scatter_dim\n            ctx.gather_dim = gather_dim\n\n            if world_size == 1:\n                return input_\n\n            assert (\n                input_.dim() == 4\n            ), f\"input must be 4D tensor, got {input_.dim()} and shape {input_.shape}\"\n\n            if scatter_dim == 2 and gather_dim == 1:\n                bs, shard_seqlen, hn, hd = input_.shape\n                seqlen = shard_seqlen * world_size\n                shard_hn = hn // world_size\n\n                input_ = input_.transpose(0, 2).contiguous()  # hn, shard_seqlen, bs, hd\n                output = torch.empty_like(input_)\n\n                dist.all_to_all_single(\n                    output, input_, group=group\n                )  # hn, shard_seqlen, bs, hd\n\n                output = torch.cat(\n                    output.split(shard_hn), dim=1\n                )  # sharded hn, seqlen, bs, hd\n\n                output = output.transpose(\n                    0, 2\n                ).contiguous()  # bs, seqlen, sharded_hn, hd\n\n                return output\n            elif scatter_dim == 1 and gather_dim == 2:\n                bs, seqlen, shard_hn, hd = input_.shape\n                hn = shard_hn * world_size\n      
          shard_seqlen = seqlen // world_size\n\n                input_ = input_.transpose(0, 2).contiguous()  # shard_hn, seqlen, bs, hd\n\n                input_ = (\n                    input_.reshape(shard_hn, world_size, shard_seqlen, bs, hd)\n                    .transpose(0, 1)\n                    .reshape(shard_hn * world_size, shard_seqlen, bs, hd)\n                    .contiguous()\n                )\n\n                output = torch.empty_like(input_)\n\n                dist.all_to_all_single(output, input_, group=group)\n\n                output = output.transpose(\n                    0, 2\n                ).contiguous()  # bs, seqlen, sharded_hn, hd\n\n                return output\n            else:\n                raise RuntimeError(\n                    f\"Invalid scatter_dim={scatter_dim}, gather_dim={gather_dim}. \"\n                    f\"Only (scatter_dim=2, gather_dim=1) and (scatter_dim=1, gather_dim=2) are supported.\"\n                )\n\n        @staticmethod\n        def backward(\n            ctx: Any, grad_output: Tensor\n        ) -> tuple[None, Tensor, None, None, None]:\n            if ctx.world_size == 1:\n                return None, grad_output, None, None, None\n\n            # For backward pass, we swap scatter_dim and gather_dim\n            output = DistributedAutograd.AllToAll4D.apply(\n                ctx.group, grad_output, ctx.world_size, ctx.gather_dim, ctx.scatter_dim\n            )\n            return None, output, None, None, None\n\n\nclass DeviceCommunicatorBase:\n    \"\"\"\n    Base class for device-specific communicator with autograd support.\n    It can use the `cpu_group` to initialize the communicator.\n    If the device has PyTorch integration (PyTorch can recognize its\n    communication backend), the `device_group` will also be given.\n    \"\"\"\n\n    def __init__(\n        self,\n        cpu_group: ProcessGroup,\n        device: torch.device | None = None,\n        device_group: ProcessGroup | None = 
None,\n        unique_name: str = \"\",\n    ):\n        self.device = device or torch.device(\"cpu\")\n        self.cpu_group = cpu_group\n        self.device_group = device_group\n        self.unique_name = unique_name\n        self.rank = dist.get_rank(cpu_group)\n        self.world_size = dist.get_world_size(cpu_group)\n        self.ranks = dist.get_process_group_ranks(cpu_group)\n        self.global_rank = dist.get_rank()\n        self.global_world_size = dist.get_world_size()\n        self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)\n\n    def all_reduce(\n        self, input_: torch.Tensor, op: dist.ReduceOp | None = ReduceOp.SUM\n    ) -> torch.Tensor:\n        \"\"\"Performs an all_reduce operation with gradient support.\"\"\"\n        return DistributedAutograd.AllReduce.apply(self.device_group, input_, op)\n\n    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:\n        \"\"\"Performs an all_gather operation with gradient support.\"\"\"\n        if dim < 0:\n            dim += input_.dim()\n        return DistributedAutograd.AllGather.apply(\n            self.device_group, input_, self.world_size, dim\n        )\n\n    def all_to_all_4D(\n        self, input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1\n    ) -> torch.Tensor:\n        \"\"\"Performs a 4D all-to-all operation with gradient support.\"\"\"\n        return DistributedAutograd.AllToAll4D.apply(\n            self.device_group, input_, self.world_size, scatter_dim, gather_dim\n        )\n\n    def gather(\n        self, input_: torch.Tensor, dst: int = 0, dim: int = -1\n    ) -> torch.Tensor | None:\n        \"\"\"\n        NOTE: We assume that the input tensor is on the same device across\n        all the ranks.\n        NOTE: `dst` is the local rank of the destination rank.\n        \"\"\"\n        world_size = self.world_size\n        assert (\n            -input_.dim() <= dim < input_.dim()\n        ), f\"Invalid dim ({dim}) 
for input tensor with shape {input_.size()}\"\n        if dim < 0:\n            # Convert negative dim to positive.\n            dim += input_.dim()\n\n        # Allocate output tensor.\n        if self.rank_in_group == dst:\n            gather_list = [torch.empty_like(input_) for _ in range(world_size)]\n        else:\n            gather_list = None\n        # Gather.\n        torch.distributed.gather(\n            input_, gather_list, dst=self.ranks[dst], group=self.device_group\n        )\n        if self.rank_in_group == dst:\n            output_tensor = torch.cat(gather_list, dim=dim)\n        else:\n            output_tensor = None\n        return output_tensor\n\n    def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:\n        \"\"\"Sends a tensor to the destination rank synchronously.\n\n        NOTE: `dst` is the local rank of the destination rank.\n        \"\"\"\n        if dst is None:\n            dst = (self.rank_in_group + 1) % self.world_size\n        torch.distributed.send(tensor, self.ranks[dst], self.device_group)\n\n    def recv(\n        self, size: torch.Size, dtype: torch.dtype, src: int | None = None\n    ) -> torch.Tensor:\n        \"\"\"Receives a tensor from the source rank.\n\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        if src is None:\n            src = (self.rank_in_group - 1) % self.world_size\n\n        tensor = torch.empty(size, dtype=dtype, device=self.device)\n        torch.distributed.recv(tensor, self.ranks[src], self.device_group)\n        return tensor\n\n    def destroy(self) -> None:\n        pass\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/distributed/device_communicators/cpu_communicator.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/cpu_communicator.py\n\nimport os\n\nimport torch\nfrom torch.distributed import ProcessGroup\n\nfrom .base_device_communicator import DeviceCommunicatorBase\n\n\nclass CpuCommunicator(DeviceCommunicatorBase):\n\n    def __init__(\n        self,\n        cpu_group: ProcessGroup,\n        device: torch.device | None = None,\n        device_group: ProcessGroup | None = None,\n        unique_name: str = \"\",\n    ):\n        from sglang.multimodal_gen.runtime.platforms import current_platform\n        from sglang.multimodal_gen.runtime.platforms.interface import CpuArchEnum\n\n        super().__init__(cpu_group, device, device_group, unique_name)\n        self.dist_module = torch.distributed\n\n        if (\n            (current_platform.get_cpu_architecture() == CpuArchEnum.X86)\n            and hasattr(torch.ops._C, \"init_shm_manager\")\n            and unique_name.startswith(\"tp\")\n        ):\n            self.dist_module = _CPUSHMDistributed(self)\n\n    def all_reduce(\n        self,\n        input_: torch.Tensor,\n        op: torch.distributed.ReduceOp | None = torch.distributed.ReduceOp.SUM,\n    ) -> torch.Tensor:\n        self.dist_module.all_reduce(input_, group=self.device_group, op=op)\n        return input_\n\n    def gather(\n        self, input_: torch.Tensor, dst: int = 0, dim: int = -1\n    ) -> torch.Tensor | None:\n        \"\"\"\n        NOTE: We assume that the input tensor is on the same device across\n        all the ranks.\n        NOTE: `dst` is the local rank of the destination rank.\n        \"\"\"\n        world_size = self.world_size\n        assert (\n            -input_.dim() <= dim < input_.dim()\n        ), f\"Invalid dim ({dim}) for input tensor with shape {input_.size()}\"\n        if dim < 0:\n            # 
Convert negative dim to positive.\n            dim += input_.dim()\n\n        # Allocate output tensor.\n        if self.rank_in_group == dst:\n            gather_list = [torch.empty_like(input_) for _ in range(world_size)]\n        else:\n            gather_list = None\n\n        # Gather.\n        self.dist_module.gather(\n            input_, gather_list, dst=self.ranks[dst], group=self.device_group\n        )\n\n        if self.rank_in_group == dst:\n            output_tensor = torch.cat(gather_list, dim=dim)\n        else:\n            output_tensor = None\n        return output_tensor\n\n    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:\n        if dim < 0:\n            # Convert negative dim to positive.\n            dim += input_.dim()\n        input_size = input_.size()\n        # NOTE: we have to use concat-style all-gather here;\n        # stack-style all-gather has compatibility issues with\n        # torch.compile. See https://github.com/pytorch/pytorch/issues/138795\n        output_size = (input_size[0] * self.world_size,) + input_size[1:]\n        # Allocate output tensor.\n        output_tensor = torch.empty(\n            output_size, dtype=input_.dtype, device=input_.device\n        )\n        # All-gather.\n        self.dist_module.all_gather_into_tensor(\n            output_tensor, input_, group=self.device_group\n        )\n\n        # Reshape\n        output_tensor = output_tensor.reshape((self.world_size,) + input_size)\n        output_tensor = output_tensor.movedim(0, dim)\n        output_tensor = output_tensor.reshape(\n            input_size[:dim]\n            + (self.world_size * input_size[dim],)\n            + input_size[dim + 1 :]\n        )\n        return output_tensor\n\n\nclass _CPUSHMDistributed:\n\n    def __init__(self, communicator: CpuCommunicator):\n        instance_identifier = os.environ[\"VLLM_DIST_IDENT\"]\n        unique_name = communicator.unique_name\n        instance_identifier = 
f\"{instance_identifier}-{unique_name}\"\n        self.communicator = communicator\n\n        group_ranks = [str(rank) for rank in self.communicator.ranks]\n        shm_group_identifier = f\"[{'-'.join(group_ranks)}]\"\n        self.group_name = f\"{instance_identifier}-{shm_group_identifier}-cpushm\"\n\n        self.handle = self._init_cpu_shm()\n\n    def _init_cpu_shm(self) -> int:\n        handle = torch.ops._C.init_shm_manager(\n            self.group_name,\n            self.communicator.world_size,\n            self.communicator.rank,\n        )\n        torch.distributed.barrier(self.communicator.device_group)\n        torch.ops._C.join_shm_manager(\n            handle,\n            self.group_name,\n        )\n        torch.distributed.barrier(self.communicator.device_group)\n\n        return int(handle)\n\n    def all_reduce(\n        self,\n        input: torch.Tensor,\n        group: ProcessGroup | None = None,\n        op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM,\n    ) -> None:\n        # `op` is accepted because CpuCommunicator.all_reduce passes it;\n        # the shared-memory kernel only performs a sum reduction.\n        assert op == torch.distributed.ReduceOp.SUM, f\"Unsupported op: {op}\"\n        torch.ops._C.shm_allreduce(self.handle, input)\n\n    def gather(\n        self,\n        input: torch.Tensor,\n        gather_list: list[torch.Tensor] | None,\n        dst: int = -1,\n        group: ProcessGroup | None = None,\n    ) -> None:\n        # Note: different from the torch gather, here we use local dst rank.\n        torch.ops._C.shm_gather(\n            self.handle,\n            input,\n            gather_list,\n            torch.distributed.get_group_rank(group, dst),\n        )\n\n    def all_gather_into_tensor(\n        self,\n        output: torch.Tensor,\n        input: torch.Tensor,\n        group: ProcessGroup | None = None,\n    ) -> None:\n        torch.ops._C.shm_all_gather(self.handle, input, output)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/distributed/device_communicators/cuda_communicator.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/cuda_communicator.py\n\nimport torch\nfrom torch.distributed import ProcessGroup\n\nfrom sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator import (\n    DeviceCommunicatorBase,\n)\n\n\nclass CudaCommunicator(DeviceCommunicatorBase):\n\n    def __init__(\n        self,\n        cpu_group: ProcessGroup,\n        device: torch.device | None = None,\n        device_group: ProcessGroup | None = None,\n        unique_name: str = \"\",\n    ):\n        super().__init__(cpu_group, device, device_group, unique_name)\n\n        from sglang.multimodal_gen.runtime.distributed.device_communicators.pynccl import (\n            PyNcclCommunicator,\n        )\n\n        self.pynccl_comm: PyNcclCommunicator | None = None\n        if self.world_size > 1:\n            self.pynccl_comm = PyNcclCommunicator(\n                group=self.cpu_group,\n                device=self.device,\n            )\n\n    def all_reduce(self, input_, op: torch.distributed.ReduceOp | None = None):\n        # `None` means sum; both pynccl and torch.distributed need an explicit op.\n        if op is None:\n            op = torch.distributed.ReduceOp.SUM\n        pynccl_comm = self.pynccl_comm\n        assert pynccl_comm is not None\n        out = pynccl_comm.all_reduce(input_, op=op)\n        if out is None:\n            # fall back to the default all-reduce using PyTorch.\n            # this usually happens during testing.\n            # when we run the model, allreduce only happens for the TP\n            # group, where we always have either custom allreduce or pynccl.\n            out = input_.clone()\n            torch.distributed.all_reduce(out, group=self.device_group, op=op)\n        return out\n\n    def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:\n        \"\"\"Sends a tensor to the destination rank in a non-blocking way\"\"\"\n        \"\"\"NOTE: `dst` is the local 
rank of the destination rank.\"\"\"\n        if dst is None:\n            dst = (self.rank_in_group + 1) % self.world_size\n\n        pynccl_comm = self.pynccl_comm\n        if pynccl_comm is not None and not pynccl_comm.disabled:\n            pynccl_comm.send(tensor, dst)\n        else:\n            torch.distributed.send(tensor, self.ranks[dst], self.device_group)\n\n    def recv(\n        self, size: torch.Size, dtype: torch.dtype, src: int | None = None\n    ) -> torch.Tensor:\n        \"\"\"Receives a tensor from the source rank.\"\"\"\n        \"\"\"NOTE: `src` is the local rank of the source rank.\"\"\"\n        if src is None:\n            src = (self.rank_in_group - 1) % self.world_size\n\n        tensor = torch.empty(size, dtype=dtype, device=self.device)\n        pynccl_comm = self.pynccl_comm\n        if pynccl_comm is not None and not pynccl_comm.disabled:\n            pynccl_comm.recv(tensor, src)\n        else:\n            torch.distributed.recv(tensor, self.ranks[src], self.device_group)\n        return tensor\n\n    def destroy(self) -> None:\n        if self.pynccl_comm is not None:\n            self.pynccl_comm = None\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/pynccl.py\n\n# ===================== import region =====================\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup, ReduceOp\n\nfrom sglang.multimodal_gen.runtime.distributed.device_communicators.pynccl_wrapper import (\n    NCCLLibrary,\n    buffer_type,\n    cudaStream_t,\n    ncclComm_t,\n    ncclDataTypeEnum,\n    ncclRedOpTypeEnum,\n    ncclUniqueId,\n)\nfrom sglang.multimodal_gen.runtime.distributed.utils import StatelessProcessGroup\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import current_stream\n\nlogger = init_logger(__name__)\n\n\nclass PyNcclCommunicator:\n\n    def __init__(\n        self,\n        group: ProcessGroup | StatelessProcessGroup,\n        device: int | str | torch.device,\n        library_path: str | None = None,\n    ):\n        \"\"\"\n        Args:\n            group: the process group to work on. If None, it will use the\n                default process group.\n            device: the device to bind the PyNcclCommunicator to. If None,\n                it will be bound to f\"cuda:{local_rank}\".\n            library_path: the path to the NCCL library. 
If None, it will\n                use the default library path.\n        It is the caller's responsibility to make sure each communicator\n        is bound to a unique device.\n        \"\"\"\n        if not isinstance(group, StatelessProcessGroup):\n            assert dist.is_initialized()\n            assert (\n                dist.get_backend(group) != dist.Backend.NCCL\n            ), \"PyNcclCommunicator should be attached to a non-NCCL group.\"\n            # note: this rank is the rank in the group\n            self.rank = dist.get_rank(group)\n            self.world_size = dist.get_world_size(group)\n        else:\n            self.rank = group.rank\n            self.world_size = group.world_size\n\n        self.group = group\n\n        # if world_size == 1, no need to create communicator\n        if self.world_size == 1:\n            self.available = False\n            self.disabled = True\n            return\n        try:\n            self.nccl = NCCLLibrary(library_path)\n        except Exception:\n            # disable because of missing NCCL library\n            # e.g. 
in a non-GPU environment\n            self.available = False\n            self.disabled = True\n            return\n\n        self.available = True\n        self.disabled = False\n\n        logger.info(\"sglang-diffusion is using nccl==%s\", self.nccl.ncclGetVersion())\n\n        if self.rank == 0:\n            # get the unique id from NCCL\n            self.unique_id = self.nccl.ncclGetUniqueId()\n        else:\n            # construct an empty unique id\n            self.unique_id = ncclUniqueId()\n\n        if not isinstance(group, StatelessProcessGroup):\n            tensor = torch.ByteTensor(list(self.unique_id.internal))\n            ranks = dist.get_process_group_ranks(group)\n            # arg `src` in `broadcast` is the global rank\n            dist.broadcast(tensor, src=ranks[0], group=group)\n            byte_list = tensor.tolist()\n            for i, byte in enumerate(byte_list):\n                self.unique_id.internal[i] = byte\n        else:\n            self.unique_id = group.broadcast_obj(self.unique_id, src=0)\n        if isinstance(device, int):\n            device = torch.device(f\"cuda:{device}\")\n        elif isinstance(device, str):\n            device = torch.device(device)\n        # now `device` is a `torch.device` object\n        assert isinstance(device, torch.device)\n        self.device = device\n        # nccl communicator and stream will use this device\n        # `torch.cuda.device` is a context manager that changes the\n        # current cuda device to the specified one\n        with torch.cuda.device(device):\n            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(\n                self.world_size, self.unique_id, self.rank\n            )\n\n            stream = current_stream()\n            # A small all_reduce for warmup.\n            data = torch.zeros(1, device=device)\n            self.all_reduce(data)\n            if stream is not None:\n                stream.synchronize()\n            del data\n\n    def 
all_reduce(\n        self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None\n    ) -> torch.Tensor:\n        if self.disabled:\n            return None\n        # nccl communicator created on a specific device\n        # will only work on tensors on the same device\n        # otherwise it will cause \"illegal memory access\"\n        assert in_tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {in_tensor.device}\"\n        )\n\n        out_tensor = torch.empty_like(in_tensor)\n\n        if stream is None:\n            stream = current_stream()\n        self.nccl.ncclAllReduce(\n            buffer_type(in_tensor.data_ptr()),\n            buffer_type(out_tensor.data_ptr()),\n            in_tensor.numel(),\n            ncclDataTypeEnum.from_torch(in_tensor.dtype),\n            ncclRedOpTypeEnum.from_torch(op),\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n        return out_tensor\n\n    def all_gather(\n        self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None\n    ):\n        if self.disabled:\n            return\n        # nccl communicator created on a specific device\n        # will only work on tensors on the same device\n        # otherwise it will cause \"illegal memory access\"\n        assert input_tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {input_tensor.device}\"\n        )\n        if stream is None:\n            stream = current_stream()\n        self.nccl.ncclAllGather(\n            buffer_type(input_tensor.data_ptr()),\n            buffer_type(output_tensor.data_ptr()),\n            input_tensor.numel(),\n            ncclDataTypeEnum.from_torch(input_tensor.dtype),\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    def 
reduce_scatter(\n        self,\n        output_tensor: torch.Tensor,\n        input_tensor: torch.Tensor,\n        op: ReduceOp = ReduceOp.SUM,\n        stream=None,\n    ):\n        if self.disabled:\n            return\n        # nccl communicator created on a specific device\n        # will only work on tensors on the same device\n        # otherwise it will cause \"illegal memory access\"\n        assert input_tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {input_tensor.device}\"\n        )\n        if stream is None:\n            stream = current_stream()\n        self.nccl.ncclReduceScatter(\n            buffer_type(input_tensor.data_ptr()),\n            buffer_type(output_tensor.data_ptr()),\n            output_tensor.numel(),\n            ncclDataTypeEnum.from_torch(input_tensor.dtype),\n            ncclRedOpTypeEnum.from_torch(op),\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    def send(self, tensor: torch.Tensor, dst: int, stream=None):\n        if self.disabled:\n            return\n        assert tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {tensor.device}\"\n        )\n        if stream is None:\n            stream = current_stream()\n        self.nccl.ncclSend(\n            buffer_type(tensor.data_ptr()),\n            tensor.numel(),\n            ncclDataTypeEnum.from_torch(tensor.dtype),\n            dst,\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    def recv(self, tensor: torch.Tensor, src: int, stream=None):\n        if self.disabled:\n            return\n        assert tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {tensor.device}\"\n        )\n   
     if stream is None:\n            stream = current_stream()\n        self.nccl.ncclRecv(\n            buffer_type(tensor.data_ptr()),\n            tensor.numel(),\n            ncclDataTypeEnum.from_torch(tensor.dtype),\n            src,\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):\n        if self.disabled:\n            return\n        assert tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {tensor.device}\"\n        )\n        if stream is None:\n            stream = current_stream()\n        if src == self.rank:\n            sendbuff = buffer_type(tensor.data_ptr())\n            # NCCL requires the sender also to have a receive buffer\n            recvbuff = buffer_type(tensor.data_ptr())\n        else:\n            sendbuff = buffer_type()\n            recvbuff = buffer_type(tensor.data_ptr())\n        self.nccl.ncclBroadcast(\n            sendbuff,\n            recvbuff,\n            tensor.numel(),\n            ncclDataTypeEnum.from_torch(tensor.dtype),\n            src,\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl_wrapper.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/pynccl_wrapper.py\n\n# This file is a pure Python wrapper for the NCCL library.\n# The main purpose is to use NCCL combined with CUDA graph.\n# Before writing this script, we tried the following approaches:\n# 1. We tried `cupy`: it calls NCCL correctly, but `cupy` itself\n#  often gets stuck when initializing the NCCL communicator.\n# 2. We tried `torch.distributed`, but `torch.distributed.all_reduce`\n#  may call other CUDA APIs that are not allowed during\n#  CUDA graph capture. For further details, please check\n# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .\n#\n# Another rejected idea is to write a C/C++ binding for NCCL. It is usually\n# doable, but we often encounter issues related to NCCL versions and need\n# to switch between different versions of NCCL. See\n# https://github.com/NVIDIA/nccl/issues/1234 for more details.\n# A C/C++ binding is not flexible enough to handle this. It requires\n# recompilation of the code every time we want to switch between different\n# versions. This current implementation, with a **pure** Python wrapper, is\n# more flexible. 
We can easily switch between different versions of NCCL by\n# changing the environment variable `SGLANG_DIFFUSION_NCCL_SO_PATH`, or the `so_file`\n# variable in the code.\n\n# TODO(will): support SGLANG_DIFFUSION_NCCL_SO_PATH\n\nimport ctypes\nimport platform\nfrom dataclasses import dataclass\nfrom typing import Any\n\nimport torch\nfrom torch.distributed import ReduceOp\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import find_nccl_library\n\nlogger = init_logger(__name__)\n\n# === export types and functions from nccl to Python ===\n# for the original nccl definition, please check\n# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in\n\nncclResult_t = ctypes.c_int\nncclComm_t = ctypes.c_void_p\n\n\nclass ncclUniqueId(ctypes.Structure):\n    _fields_ = [(\"internal\", ctypes.c_byte * 128)]\n\n\ncudaStream_t = ctypes.c_void_p\nbuffer_type = ctypes.c_void_p\n\nncclDataType_t = ctypes.c_int\n\n\nclass ncclDataTypeEnum:\n    ncclInt8 = 0\n    ncclChar = 0\n    ncclUint8 = 1\n    ncclInt32 = 2\n    ncclInt = 2\n    ncclUint32 = 3\n    ncclInt64 = 4\n    ncclUint64 = 5\n    ncclFloat16 = 6\n    ncclHalf = 6\n    ncclFloat32 = 7\n    ncclFloat = 7\n    ncclFloat64 = 8\n    ncclDouble = 8\n    ncclBfloat16 = 9\n    ncclNumTypes = 10\n\n    @classmethod\n    def from_torch(cls, dtype: torch.dtype) -> int:\n        if dtype == torch.int8:\n            return cls.ncclInt8\n        if dtype == torch.uint8:\n            return cls.ncclUint8\n        if dtype == torch.int32:\n            return cls.ncclInt32\n        if dtype == torch.int64:\n            return cls.ncclInt64\n        if dtype == torch.float16:\n            return cls.ncclFloat16\n        if dtype == torch.float32:\n            return cls.ncclFloat32\n        if dtype == torch.float64:\n            return cls.ncclFloat64\n        if dtype == torch.bfloat16:\n            return cls.ncclBfloat16\n        raise ValueError(f\"Unsupported dtype: 
{dtype}\")\n\n\nncclRedOp_t = ctypes.c_int\n\n\nclass ncclRedOpTypeEnum:\n    ncclSum = 0\n    ncclProd = 1\n    ncclMax = 2\n    ncclMin = 3\n    ncclAvg = 4\n    ncclNumOps = 5\n\n    @classmethod\n    def from_torch(cls, op: ReduceOp) -> int:\n        if op == ReduceOp.SUM:\n            return cls.ncclSum\n        if op == ReduceOp.PRODUCT:\n            return cls.ncclProd\n        if op == ReduceOp.MAX:\n            return cls.ncclMax\n        if op == ReduceOp.MIN:\n            return cls.ncclMin\n        if op == ReduceOp.AVG:\n            return cls.ncclAvg\n        raise ValueError(f\"Unsupported op: {op}\")\n\n\n@dataclass\nclass Function:\n    name: str\n    restype: Any\n    argtypes: list[Any]\n\n\nclass NCCLLibrary:\n    exported_functions = [\n        # const char* ncclGetErrorString(ncclResult_t result)\n        Function(\"ncclGetErrorString\", ctypes.c_char_p, [ncclResult_t]),\n        # ncclResult_t  ncclGetVersion(int *version);\n        Function(\"ncclGetVersion\", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]),\n        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);\n        Function(\"ncclGetUniqueId\", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]),\n        # ncclResult_t  ncclCommInitRank(\n        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);\n        # note that ncclComm_t is a pointer type, so the first argument\n        # is a pointer to a pointer\n        Function(\n            \"ncclCommInitRank\",\n            ncclResult_t,\n            [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int],\n        ),\n        # ncclResult_t  ncclAllReduce(\n        #   const void* sendbuff, void* recvbuff, size_t count,\n        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,\n        #   cudaStream_t stream);\n        # note that cudaStream_t is a pointer type, so the last argument\n        # is a pointer\n        Function(\n            \"ncclAllReduce\",\n            ncclResult_t,\n         
   [\n                buffer_type,\n                buffer_type,\n                ctypes.c_size_t,\n                ncclDataType_t,\n                ncclRedOp_t,\n                ncclComm_t,\n                cudaStream_t,\n            ],\n        ),\n        # ncclResult_t  ncclAllGather(\n        #   const void* sendbuff, void* recvbuff, size_t count,\n        #   ncclDataType_t datatype, ncclComm_t comm,\n        #   cudaStream_t stream);\n        # note that cudaStream_t is a pointer type, so the last argument\n        # is a pointer\n        Function(\n            \"ncclAllGather\",\n            ncclResult_t,\n            [\n                buffer_type,\n                buffer_type,\n                ctypes.c_size_t,\n                ncclDataType_t,\n                ncclComm_t,\n                cudaStream_t,\n            ],\n        ),\n        # ncclResult_t  ncclReduceScatter(\n        #   const void* sendbuff, void* recvbuff, size_t count,\n        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,\n        #   cudaStream_t stream);\n        # note that cudaStream_t is a pointer type, so the last argument\n        # is a pointer\n        Function(\n            \"ncclReduceScatter\",\n            ncclResult_t,\n            [\n                buffer_type,\n                buffer_type,\n                ctypes.c_size_t,\n                ncclDataType_t,\n                ncclRedOp_t,\n                ncclComm_t,\n                cudaStream_t,\n            ],\n        ),\n        # ncclResult_t  ncclSend(\n        #   const void* sendbuff, size_t count, ncclDataType_t datatype,\n        #   int dest, ncclComm_t comm, cudaStream_t stream);\n        Function(\n            \"ncclSend\",\n            ncclResult_t,\n            [\n                buffer_type,\n                ctypes.c_size_t,\n                ncclDataType_t,\n                ctypes.c_int,\n                ncclComm_t,\n                cudaStream_t,\n            ],\n        ),\n        # 
ncclResult_t  ncclRecv(\n        #   void* recvbuff, size_t count, ncclDataType_t datatype,\n        #   int src, ncclComm_t comm, cudaStream_t stream);\n        Function(\n            \"ncclRecv\",\n            ncclResult_t,\n            [\n                buffer_type,\n                ctypes.c_size_t,\n                ncclDataType_t,\n                ctypes.c_int,\n                ncclComm_t,\n                cudaStream_t,\n            ],\n        ),\n        # ncclResult_t ncclBroadcast(\n        #   const void* sendbuff, void* recvbuff, size_t count,\n        #   ncclDataType_t datatype, int root, ncclComm_t comm,\n        #   cudaStream_t stream);\n        Function(\n            \"ncclBroadcast\",\n            ncclResult_t,\n            [\n                buffer_type,\n                buffer_type,\n                ctypes.c_size_t,\n                ncclDataType_t,\n                ctypes.c_int,\n                ncclComm_t,\n                cudaStream_t,\n            ],\n        ),\n        # be cautious! 
this is a collective call, it will block until all\n        # processes in the communicator have called this function.\n        # because Python object destruction can happen in random order,\n        # it is better not to call it at all.\n        # ncclResult_t  ncclCommDestroy(ncclComm_t comm);\n        Function(\"ncclCommDestroy\", ncclResult_t, [ncclComm_t]),\n    ]\n\n    # class attribute to store the mapping from the path to the library\n    # to avoid loading the same library multiple times\n    path_to_library_cache: dict[str, Any] = {}\n\n    # class attribute to store the mapping from library path\n    #  to the corresponding dictionary\n    path_to_dict_mapping: dict[str, dict[str, Any]] = {}\n\n    def __init__(self, so_file: str | None = None):\n\n        so_file = so_file or find_nccl_library()\n\n        try:\n            if so_file not in NCCLLibrary.path_to_dict_mapping:\n                lib = ctypes.CDLL(so_file)\n                NCCLLibrary.path_to_library_cache[so_file] = lib\n            self.lib = NCCLLibrary.path_to_library_cache[so_file]\n        except Exception as e:\n            logger.error(\n                \"Failed to load NCCL library from %s. \"\n                \"It is expected if you are not running on NVIDIA/AMD/MTHREADS GPUs. \"\n                \"Otherwise, the nccl library might not exist, be corrupted, \"\n                \"or it does not support the current platform %s. \"\n                \"If you already have the library, please set the \"\n                \"environment variable SGLANG_DIFFUSION_NCCL_SO_PATH\"\n                \" to point to the correct nccl library path.\",\n                so_file,\n                platform.platform(),\n            )\n            raise e\n\n        if so_file not in NCCLLibrary.path_to_dict_mapping:\n            _funcs: dict[str, Any] = {}\n            for func in NCCLLibrary.exported_functions:\n                f = getattr(self.lib, func.name)\n                f.restype = func.restype\n    
            f.argtypes = func.argtypes\n                _funcs[func.name] = f\n            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs\n        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]\n\n    def ncclGetErrorString(self, result: ncclResult_t) -> str:\n        return str(self._funcs[\"ncclGetErrorString\"](result).decode(\"utf-8\"))\n\n    def NCCL_CHECK(self, result: ncclResult_t) -> None:\n        if result != 0:\n            error_str = self.ncclGetErrorString(result)\n            raise RuntimeError(f\"NCCL error: {error_str}\")\n\n    def ncclGetVersion(self) -> str:\n        version = ctypes.c_int()\n        self.NCCL_CHECK(self._funcs[\"ncclGetVersion\"](ctypes.byref(version)))\n        version_str = str(version.value)\n        # something like 21903 --> \"2.19.3\"\n        major = version_str[0].lstrip(\"0\")\n        minor = version_str[1:3].lstrip(\"0\")\n        patch = version_str[3:].lstrip(\"0\")\n        return f\"{major}.{minor}.{patch}\"\n\n    def ncclGetUniqueId(self) -> ncclUniqueId:\n        unique_id = ncclUniqueId()\n        self.NCCL_CHECK(self._funcs[\"ncclGetUniqueId\"](ctypes.byref(unique_id)))\n        return unique_id\n\n    def ncclCommInitRank(\n        self, world_size: int, unique_id: ncclUniqueId, rank: int\n    ) -> ncclComm_t:\n        comm = ncclComm_t()\n        self.NCCL_CHECK(\n            self._funcs[\"ncclCommInitRank\"](\n                ctypes.byref(comm), world_size, unique_id, rank\n            )\n        )\n        return comm\n\n    def ncclAllReduce(\n        self,\n        sendbuff: buffer_type,\n        recvbuff: buffer_type,\n        count: int,\n        datatype: int,\n        op: int,\n        comm: ncclComm_t,\n        stream: cudaStream_t,\n    ) -> None:\n        # `datatype` actually should be `ncclDataType_t`\n        # and `op` should be `ncclRedOp_t`\n        # both are aliases of `ctypes.c_int`\n        # when we pass int to a function, it will be converted to `ctypes.c_int`\n        
# by ctypes automatically\n        self.NCCL_CHECK(\n            self._funcs[\"ncclAllReduce\"](\n                sendbuff, recvbuff, count, datatype, op, comm, stream\n            )\n        )\n\n    def ncclReduceScatter(\n        self,\n        sendbuff: buffer_type,\n        recvbuff: buffer_type,\n        count: int,\n        datatype: int,\n        op: int,\n        comm: ncclComm_t,\n        stream: cudaStream_t,\n    ) -> None:\n        # `datatype` actually should be `ncclDataType_t`\n        # and `op` should be `ncclRedOp_t`\n        # both are aliases of `ctypes.c_int`\n        # when we pass int to a function, it will be converted to `ctypes.c_int`\n        # by ctypes automatically\n        self.NCCL_CHECK(\n            self._funcs[\"ncclReduceScatter\"](\n                sendbuff, recvbuff, count, datatype, op, comm, stream\n            )\n        )\n\n    def ncclAllGather(\n        self,\n        sendbuff: buffer_type,\n        recvbuff: buffer_type,\n        count: int,\n        datatype: int,\n        comm: ncclComm_t,\n        stream: cudaStream_t,\n    ) -> None:\n        # `datatype` actually should be `ncclDataType_t`\n        # which is an alias of `ctypes.c_int`\n        # when we pass int to a function, it will be converted to `ctypes.c_int`\n        # by ctypes automatically\n        self.NCCL_CHECK(\n            self._funcs[\"ncclAllGather\"](\n                sendbuff, recvbuff, count, datatype, comm, stream\n            )\n        )\n\n    def ncclSend(\n        self,\n        sendbuff: buffer_type,\n        count: int,\n        datatype: int,\n        dest: int,\n        comm: ncclComm_t,\n        stream: cudaStream_t,\n    ) -> None:\n        self.NCCL_CHECK(\n            self._funcs[\"ncclSend\"](sendbuff, count, datatype, dest, comm, stream)\n        )\n\n    def ncclRecv(\n        self,\n        recvbuff: buffer_type,\n        count: int,\n        datatype: int,\n        src: int,\n        comm: ncclComm_t,\n        stream: 
cudaStream_t,\n    ) -> None:\n        self.NCCL_CHECK(\n            self._funcs[\"ncclRecv\"](recvbuff, count, datatype, src, comm, stream)\n        )\n\n    def ncclBroadcast(\n        self,\n        sendbuff: buffer_type,\n        recvbuff: buffer_type,\n        count: int,\n        datatype: int,\n        root: int,\n        comm: ncclComm_t,\n        stream: cudaStream_t,\n    ) -> None:\n        self.NCCL_CHECK(\n            self._funcs[\"ncclBroadcast\"](\n                sendbuff, recvbuff, count, datatype, root, comm, stream\n            )\n        )\n\n    def ncclCommDestroy(self, comm: ncclComm_t) -> None:\n        self.NCCL_CHECK(self._funcs[\"ncclCommDestroy\"](comm))\n\n\n__all__ = [\n    \"NCCLLibrary\",\n    \"ncclDataTypeEnum\",\n    \"ncclRedOpTypeEnum\",\n    \"ncclUniqueId\",\n    \"ncclComm_t\",\n    \"cudaStream_t\",\n    \"buffer_type\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/distributed/group_coordinator.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# Copyright 2024 xDiT team.\n# Adapted from\n# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py\n# Copyright 2023 The vLLM team.\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\nimport pickle\nfrom collections import namedtuple\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.distributed\nfrom torch.cuda import synchronize\nfrom torch.distributed import Backend, ProcessGroup\n\nfrom sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator import (\n    DeviceCommunicatorBase,\n)\nfrom sglang.multimodal_gen.runtime.distributed.device_communicators.cpu_communicator import (\n    CpuCommunicator,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import (\n    init_logger,\n    suppress_stdout,\n)\n\ntry:\n    import torch_musa  # noqa: F401\n    from torch_musa.core.device import synchronize\nexcept ModuleNotFoundError:\n    pass\n\nlogger = init_logger(__name__)\n\nTensorMetadata = namedtuple(\"TensorMetadata\", [\"device\", \"dtype\", \"size\"])\n\n\n_group_name_counter: dict[str, int] = {}\n\n\ndef get_local_torch_device() -> torch.device:\n    \"\"\"Return the torch device for the current rank.\"\"\"\n\n    return current_platform.get_local_torch_device()\n\n\ndef _get_unique_name(name: str) -> str:\n    \"\"\"Get a unique name for the group.\n    Example:\n    _get_unique_name(\"tp\") -> \"tp:0\"\n    _get_unique_name(\"tp\") -> \"tp:1\"\n    \"\"\"\n    if name not in _group_name_counter:\n        _group_name_counter[name] = 0\n    newname = f\"{name}:{_group_name_counter[name]}\"\n    _group_name_counter[name] += 1\n    return newname\n\n\ndef _split_tensor_dict(\n    tensor_dict: Dict[str, 
Union[torch.Tensor, Any]], prefix: str = \"\"\n) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:\n    \"\"\"Split the tensor dictionary into two parts:\n    1. A list of (key, value) pairs. If the value is a tensor, it is replaced\n         by its metadata.\n    2. A list of tensors.\n\n    If the Tensor is nested under `tensor_dict[\"key1\"][\"key2\"]`, the key of its\n    metadata will be \"key1%key2\".\n    \"\"\"\n    metadata_list: List[Tuple[str, Any]] = []\n    tensor_list = []\n    for key, value in tensor_dict.items():\n        assert \"%\" not in key, (\n            \"Avoid having '%' in key \"\n            \"as it is used as a separator for nested entries.\"\n        )\n        if isinstance(value, torch.Tensor):\n            # Note: we cannot use `value.device` here,\n            # because it contains not only the device type but also the device\n            # index (e.g. \"cuda:0\"). We only need the device type.\n            # receiving side will set the device index.\n            device = value.device.type\n            metadata_list.append(\n                (\n                    prefix + key,\n                    TensorMetadata(device, value.dtype, value.size()),\n                )\n            )\n            tensor_list.append(value)\n        elif isinstance(value, dict):\n            if len(value) == 0:\n                metadata_list.append((prefix + key, value))\n            inner_metadata_list, inner_tensor_list = _split_tensor_dict(\n                value, prefix + key + \"%\"\n            )\n            metadata_list.extend(inner_metadata_list)\n            tensor_list.extend(inner_tensor_list)\n        else:\n            metadata_list.append((prefix + key, value))\n    return metadata_list, tensor_list\n\n\ndef _update_nested_dict(nested_dict, flattened_key, value):\n    key_splits = flattened_key.split(\"%\")\n    cur_dict = nested_dict\n    for k in key_splits[:-1]:\n        if k not in cur_dict:\n            cur_dict[k] = {}\n        
cur_dict = cur_dict[k]\n    cur_dict[key_splits[-1]] = value\n\n\n@dataclass\nclass GraphCaptureContext:\n    stream: torch.cuda.Stream | None\n\n\nclass GroupCoordinator:\n    \"\"\"\n    PyTorch ProcessGroup wrapper for a group of processes.\n    PyTorch ProcessGroup is bound to one specific communication backend,\n        e.g. NCCL, Gloo, MPI, etc.\n    GroupCoordinator takes charge of all the communication operations among\n        the processes in the group. It can route the communication to\n        a specific implementation (e.g. switch allreduce implementation\n        based on the tensor size and cuda graph mode).\n    \"\"\"\n\n    # available attributes:\n    rank: int  # global rank\n    ranks: List[int]  # global ranks in the group\n    world_size: int  # size of the group\n    # difference between `local_rank` and `rank_in_group`:\n    # if we have a group of size 4 across two nodes:\n    # Process | Node | Rank | Local Rank | Rank in Group\n    #   0     |   0  |  0   |     0      |       0\n    #   1     |   0  |  1   |     1      |       1\n    #   2     |   1  |  2   |     0      |       2\n    #   3     |   1  |  3   |     1      |       3\n    local_rank: int  # local rank in the current node, used to assign devices\n    rank_in_group: int  # rank inside the group\n    cpu_group: ProcessGroup  # group for CPU communication\n    device_group: ProcessGroup  # group for device communication\n    use_device_communicator: bool  # whether to use device communicator\n    device_communicator: DeviceCommunicatorBase  # device communicator\n\n    def __init__(\n        self,\n        group_ranks: List[List[int]],\n        local_rank: int,\n        torch_distributed_backend: Union[str, Backend],\n        use_device_communicator: bool = True,\n        use_message_queue_broadcaster: bool = False,\n        group_name: str | None = None,\n    ):\n        self.unique_name = _get_unique_name(group_name)\n        self.rank = torch.distributed.get_rank()\n        
self.local_rank = local_rank\n        self.device_group = None\n        self.cpu_group = None\n\n        for ranks in group_ranks:\n            device_group = torch.distributed.new_group(\n                ranks, backend=torch_distributed_backend\n            )\n            # a group with `gloo` backend, to allow direct coordination between\n            # processes through the CPU.\n            with suppress_stdout():\n                cpu_group = torch.distributed.new_group(ranks, backend=\"gloo\")\n            if self.rank in ranks:\n                self.ranks = ranks\n                self.world_size = len(ranks)\n                self.rank_in_group = ranks.index(self.rank)\n                self.device_group = device_group\n                self.cpu_group = cpu_group\n\n        assert self.cpu_group is not None, f\"{group_ranks=}, {local_rank=}\"\n        assert self.device_group is not None\n\n        # TODO: fix it for other platforms\n        self.device = get_local_torch_device()\n\n        self.use_device_communicator = use_device_communicator\n\n        self.device_communicator: DeviceCommunicatorBase = None  # type: ignore\n        if use_device_communicator and self.world_size > 1:\n            # Platform-aware device communicator selection\n            if current_platform.is_cuda_alike():\n                from sglang.multimodal_gen.runtime.distributed.device_communicators.cuda_communicator import (\n                    CudaCommunicator,\n                )\n\n                self.device_communicator = CudaCommunicator(\n                    cpu_group=self.cpu_group,\n                    device=self.device,\n                    device_group=self.device_group,\n                    unique_name=self.unique_name,\n                )\n            else:\n                # For MPS and CPU, use the CPU communicator\n                self.device_communicator = CpuCommunicator(\n                    cpu_group=self.cpu_group,\n                    device=self.device,\n        
            device_group=self.device_group,\n                    unique_name=self.unique_name,\n                )\n\n        self.mq_broadcaster = None\n\n        # TODO(will): check if this is needed\n        # self.use_custom_op_call = current_platform.is_cuda_alike()\n        self.use_custom_op_call = False\n\n    @property\n    def first_rank(self):\n        \"\"\"Return the global rank of the first process in the group\"\"\"\n        return self.ranks[0]\n\n    @property\n    def last_rank(self):\n        \"\"\"Return the global rank of the last process in the group\"\"\"\n        return self.ranks[-1]\n\n    @property\n    def is_first_rank(self):\n        \"\"\"Return whether the caller is the first process in the group\"\"\"\n        return self.rank == self.first_rank\n\n    @property\n    def is_last_rank(self):\n        \"\"\"Return whether the caller is the last process in the group\"\"\"\n        return self.rank == self.last_rank\n\n    @property\n    def next_rank(self):\n        \"\"\"Return the global rank of the process that follows the caller\"\"\"\n        rank_in_group = self.rank_in_group\n        world_size = self.world_size\n        return self.ranks[(rank_in_group + 1) % world_size]\n\n    @property\n    def prev_rank(self):\n        \"\"\"Return the global rank of the process that precedes the caller\"\"\"\n        rank_in_group = self.rank_in_group\n        world_size = self.world_size\n        return self.ranks[(rank_in_group - 1) % world_size]\n\n    @property\n    def group_next_rank(self):\n        \"\"\"Return the group rank of the process that follows the caller\"\"\"\n        rank_in_group = self.rank_in_group\n        world_size = self.world_size\n        return (rank_in_group + 1) % world_size\n\n    @property\n    def group_prev_rank(self):\n        \"\"\"Return the group rank of the process that precedes the caller\"\"\"\n        rank_in_group = self.rank_in_group\n        world_size = self.world_size\n        return 
(rank_in_group - 1) % world_size\n\n    @property\n    def skip_rank(self):\n        \"\"\"Return the global rank of the process that skip connects with the caller\"\"\"\n        rank_in_group = self.rank_in_group\n        world_size = self.world_size\n        return self.ranks[(world_size - rank_in_group - 1) % world_size]\n\n    @property\n    def group_skip_rank(self):\n        \"\"\"Return the group rank of the process that skip connects with the caller\"\"\"\n        rank_in_group = self.rank_in_group\n        world_size = self.world_size\n        return (world_size - rank_in_group - 1) % world_size\n\n    @contextmanager\n    def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None):\n        if current_platform.is_cuda_alike():\n            if graph_capture_context is None:\n                stream = torch.cuda.Stream()\n                graph_capture_context = GraphCaptureContext(stream)\n            else:\n                stream = graph_capture_context.stream\n\n            # ensure all initialization operations complete before attempting to\n            # capture the graph on another stream\n            curr_stream = torch.cuda.current_stream()\n            if curr_stream != stream:\n                stream.wait_stream(curr_stream)\n\n            with torch.cuda.stream(stream):\n                yield graph_capture_context\n        else:\n            # For non-CUDA platforms (MPS, CPU), just yield the context without stream management\n            if graph_capture_context is None:\n                # Create a dummy context for non-CUDA platforms\n                graph_capture_context = GraphCaptureContext(None)\n            yield graph_capture_context\n\n    def all_to_all_4D(\n        self, input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1\n    ) -> torch.Tensor:\n        if self.world_size == 1:\n            return input_\n        return self.device_communicator.all_to_all_4D(input_, scatter_dim, gather_dim)\n\n    def 
all_reduce(\n        self,\n        input_: torch.Tensor,\n        op=torch._C._distributed_c10d.ReduceOp.SUM,\n        async_op: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"\n        NOTE: This operation will be applied in-place or out-of-place.\n        Always assume this function modifies its input, but use the return\n        value as the output.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return input_\n        else:\n            torch.distributed.all_reduce(\n                input_, op=op, group=self.device_group, async_op=async_op\n            )\n        return input_\n\n    def all_gather(\n        self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False\n    ) -> Union[torch.Tensor, List[torch.Tensor]]:\n        world_size = self.world_size\n        # Bypass the function if we are using only 1 GPU.\n        if world_size == 1:\n            return input_\n        assert (\n            -input_.dim() <= dim < input_.dim()\n        ), f\"Invalid dim ({dim}) for input tensor with shape {input_.size()}\"\n        if dim < 0:\n            # Convert negative dim to positive.\n            dim += input_.dim()\n        # Allocate output tensor.\n        input_size = list(input_.size())\n        input_size[0] *= world_size\n        output_tensor = torch.empty(\n            input_size, dtype=input_.dtype, device=input_.device\n        )\n        # All-gather.\n        torch.distributed.all_gather_into_tensor(\n            output_tensor, input_, group=self.device_group\n        )\n        if dim != 0:\n            input_size[0] //= world_size\n            output_tensor = output_tensor.reshape(\n                [\n                    world_size,\n                ]\n                + input_size\n            )\n            output_tensor = output_tensor.movedim(0, dim)\n\n        if separate_tensors:\n            tensor_list = [\n                output_tensor.reshape(-1)\n 
               .narrow(0, input_.numel() * i, input_.numel())\n                .view_as(input_)\n                for i in range(world_size)\n            ]\n            return tensor_list\n        else:\n            input_size = list(input_.size())\n            input_size[dim] = input_size[dim] * world_size\n            # Reshape\n            output_tensor = output_tensor.reshape(input_size)\n            return output_tensor\n\n    def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor:\n        \"\"\"\n        NOTE: We assume that the input tensor is on the same device across\n        all the ranks.\n        NOTE: `dst` is the local rank of the destination rank.\n        \"\"\"\n        world_size = self.world_size\n        # Bypass the function if we are using only 1 GPU.\n        if world_size == 1:\n            return input_\n        assert (\n            -input_.dim() <= dim < input_.dim()\n        ), f\"Invalid dim ({dim}) for input tensor with shape {input_.size()}\"\n        if dim < 0:\n            # Convert negative dim to positive.\n            dim += input_.dim()\n        # Allocate output tensor.\n        if self.rank_in_group == dst:\n            gather_list = [torch.empty_like(input_) for _ in range(world_size)]\n        else:\n            gather_list = None\n        # Gather.\n        torch.distributed.gather(\n            input_, gather_list, dst=self.ranks[dst], group=self.device_group\n        )\n        if self.rank_in_group == dst:\n            output_tensor = torch.cat(gather_list, dim=dim)\n        else:\n            output_tensor = None\n        return output_tensor\n\n    def broadcast(self, input_: torch.Tensor, src: int = 0, async_op: bool = False):\n        \"\"\"Broadcast the input tensor.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        # Bypass the function if we are using only 1 GPU.\n        if 
self.world_size == 1:\n            return input_\n        # Broadcast.\n        torch.distributed.broadcast(\n            input_,\n            src=self.ranks[src],\n            group=self.device_group,\n            async_op=async_op,\n        )\n        return input_\n\n    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):\n        \"\"\"Broadcast the input object.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return obj\n        if self.mq_broadcaster is not None:\n            assert src == 0, \"Message queue broadcaster only supports src=0\"\n            return self.mq_broadcaster.broadcast_object(obj)\n        if self.rank_in_group == src:\n            torch.distributed.broadcast_object_list(\n                [obj], src=self.ranks[src], group=self.cpu_group\n            )\n            return obj\n        else:\n            recv = [None]\n            torch.distributed.broadcast_object_list(\n                recv, src=self.ranks[src], group=self.cpu_group\n            )\n            return recv[0]\n\n    def broadcast_object_list(\n        self,\n        obj_list: List[Any],\n        src: int = 0,\n        group: Optional[ProcessGroup] = None,\n    ):\n        \"\"\"Broadcast the input object list.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return obj_list\n        # Broadcast.\n        torch.distributed.broadcast_object_list(\n            obj_list, src=self.ranks[src], group=self.device_group\n        )\n        return obj_list\n\n    def send_object(self, obj: Any, dst: int) -> None:\n        \"\"\"Send the input object to the\n        
destination rank.\"\"\"\n        \"\"\"NOTE: `dst` is the local rank of the destination rank.\"\"\"\n\n        assert dst < self.world_size, f\"Invalid dst rank ({dst})\"\n\n        assert dst != self.rank, (\n            \"Invalid destination rank. Destination rank is the same \"\n            \"as the current rank.\"\n        )\n\n        # Serialize object to tensor and get the size as well\n        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)\n\n        size_tensor = torch.tensor(\n            [object_tensor.numel()], dtype=torch.long, device=\"cpu\"\n        )\n\n        # Send object size\n\n        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)\n\n        # Send object\n        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)\n\n        return None\n\n    def recv_object(self, src: int) -> Any:\n        \"\"\"Receive the input object list from the source rank.\"\"\"\n        \"\"\"NOTE: `src` is the local rank of the source rank.\"\"\"\n\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        assert (\n            src != self.rank\n        ), \"Invalid source rank. 
Source rank is the same as the current rank.\"\n\n        size_tensor = torch.empty(1, dtype=torch.long, device=\"cpu\")\n\n        # Receive object size\n        rank_size = torch.distributed.recv(\n            size_tensor, src=self.ranks[src], group=self.cpu_group\n        )\n\n        # Tensor to receive serialized objects into.\n        object_tensor = torch.empty(  # type: ignore[call-overload]\n            size_tensor.item(),  # type: ignore[arg-type]\n            dtype=torch.uint8,\n            device=\"cpu\",\n        )\n\n        rank_object = torch.distributed.recv(\n            object_tensor, src=self.ranks[src], group=self.cpu_group\n        )\n\n        assert (\n            rank_object == rank_size\n        ), \"Received object sender rank does not match the size sender rank.\"\n\n        obj = pickle.loads(object_tensor.numpy().tobytes())\n\n        return obj\n\n    def broadcast_tensor_dict(\n        self,\n        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,\n        src: int = 0,\n        group: Optional[ProcessGroup] = None,\n        metadata_group: Optional[ProcessGroup] = None,\n    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:\n        \"\"\"Broadcast the input tensor dictionary.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if not torch.distributed.is_initialized() or self.world_size == 1:\n            return tensor_dict\n\n        group = self.device_group\n        metadata_group = self.cpu_group\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n        src = self.ranks[src]\n\n        rank = self.rank\n        if rank == src:\n            metadata_list: List[Tuple[Any, Any]] = []\n            assert isinstance(\n                tensor_dict, dict\n            ), f\"Expecting a dictionary, got {type(tensor_dict)}\"\n            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)\n           
 # `metadata_list` lives in CPU memory.\n            # `broadcast_object_list` has serialization & deserialization,\n            # all happening on CPU. Therefore, we can use the CPU group.\n            # `broadcast_object` expects a group-local rank, but `src` now holds the global rank.\n            self.broadcast_object(metadata_list, src=self.ranks.index(src))\n            async_handles = []\n            for tensor in tensor_list:\n                if tensor.numel() == 0:\n                    # Skip broadcasting empty tensors.\n                    continue\n                if tensor.is_cpu:\n                    # use metadata_group for CPU tensors\n                    handle = torch.distributed.broadcast(\n                        tensor, src=src, group=metadata_group, async_op=True\n                    )\n                else:\n                    # use group for GPU tensors\n                    handle = torch.distributed.broadcast(\n                        tensor, src=src, group=group, async_op=True\n                    )\n                async_handles.append(handle)\n            for async_handle in async_handles:\n                async_handle.wait()\n\n        else:\n            # `broadcast_object` expects a group-local rank, but `src` now holds the global rank.\n            metadata_list = self.broadcast_object(None, src=self.ranks.index(src))\n            tensor_dict = {}\n            async_handles = []\n            for key, value in metadata_list:\n                if isinstance(value, TensorMetadata):\n                    tensor = torch.empty(\n                        value.size, dtype=value.dtype, device=value.device\n                    )\n                    if tensor.numel() == 0:\n                        # Skip broadcasting empty tensors.\n                        _update_nested_dict(tensor_dict, key, tensor)\n                        continue\n                    if tensor.is_cpu:\n                        # use metadata_group for CPU tensors\n                        handle = torch.distributed.broadcast(\n                            tensor, src=src, group=metadata_group, async_op=True\n                        )\n                    else:\n                        # use group for GPU 
tensors\n                        handle = torch.distributed.broadcast(\n                            tensor, src=src, group=group, async_op=True\n                        )\n                    async_handles.append(handle)\n                    _update_nested_dict(tensor_dict, key, tensor)\n                else:\n                    _update_nested_dict(tensor_dict, key, value)\n            for async_handle in async_handles:\n                async_handle.wait()\n        return tensor_dict\n\n    def send_tensor_dict(\n        self,\n        tensor_dict: Dict[str, Union[torch.Tensor, Any]],\n        dst: Optional[int] = None,\n    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:\n        \"\"\"Send the input tensor dictionary.\n        NOTE: `dst` is the local rank of the destination rank.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if not torch.distributed.is_initialized() or self.world_size == 1:\n            return tensor_dict\n\n        group = self.device_group\n        metadata_group = self.cpu_group\n\n        if dst is None:\n            dst = self.group_next_rank\n        assert dst < self.world_size, f\"Invalid dst rank ({dst})\"\n\n        metadata_list: List[Tuple[Any, Any]] = []\n        assert isinstance(\n            tensor_dict, dict\n        ), f\"Expecting a dictionary, got {type(tensor_dict)}\"\n        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)\n        # `metadata_list` lives in CPU memory.\n        # `send_object` has serialization & deserialization,\n        # all happening on CPU.\n        # 
Therefore, we can use the CPU group.\n        self.send_object(metadata_list, dst=dst)\n        for tensor in tensor_list:\n            if tensor.numel() == 0:\n                # Skip sending empty tensors.\n                continue\n            if tensor.is_cpu:\n                # use metadata_group for CPU tensors\n                torch.distributed.send(\n                    tensor, dst=self.ranks[dst], group=metadata_group\n                )\n            else:\n                # use group for GPU tensors\n                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)\n        return None\n\n    def recv_tensor_dict(\n        self, src: Optional[int] = None\n    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:\n        \"\"\"Recv the input tensor dictionary.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if not torch.distributed.is_initialized() or self.world_size == 1:\n            return None\n\n        group = self.device_group\n        metadata_group = self.cpu_group\n\n        if src is None:\n            src = self.group_prev_rank\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        recv_metadata_list = self.recv_object(src=src)\n        tensor_dict: Dict[str, Any] = {}\n        for key, value in recv_metadata_list:\n            if isinstance(value, TensorMetadata):\n                tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)\n                if tensor.numel() == 0:\n                    # Skip broadcasting empty tensors.\n                    _update_nested_dict(tensor_dict, key, tensor)\n                    continue\n                if tensor.is_cpu:\n                    # use metadata_group for CPU tensors\n                    torch.distributed.recv(\n                        tensor, src=self.ranks[src], group=metadata_group\n                    )\n                else:\n                 
   # use group for GPU tensors\n                    torch.distributed.recv(tensor, src=self.ranks[src], group=group)\n                _update_nested_dict(tensor_dict, key, tensor)\n            else:\n                _update_nested_dict(tensor_dict, key, value)\n        return tensor_dict\n\n    def barrier(self):\n        \"\"\"Barrier synchronization among the group.\n        NOTE: don't use `device_group` here! `barrier` in NCCL is\n        terrible because it is internally a broadcast operation with\n        secretly created GPU tensors. It is easy to mess up the current\n        device. Use the CPU group instead.\n        \"\"\"\n        torch.distributed.barrier(group=self.cpu_group)\n\n    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:\n        \"\"\"Sends a tensor to the destination rank in a non-blocking way\"\"\"\n        \"\"\"NOTE: `dst` is the rank_in_group of the destination rank.\"\"\"\n        if dst is None:\n            dst = self.group_next_rank\n\n        torch.distributed.send(\n            tensor,\n            self.ranks[dst],\n            group=(\n                self.device_groups[self.rank_in_group % 2]\n                if self.world_size == 2\n                else self.device_group\n            ),\n        )\n\n    def recv(\n        self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None\n    ) -> torch.Tensor:\n        \"\"\"Receives a tensor from the src rank.\"\"\"\n        \"\"\"NOTE: `src` is the rank_in_group of the source rank.\"\"\"\n        if src is None:\n            src = self.group_prev_rank\n\n        tensor = torch.empty(size, dtype=dtype, device=self.device)\n        torch.distributed.recv(\n            tensor,\n            self.ranks[src],\n            (\n                self.device_groups[(self.rank_in_group + 1) % 2]\n                if self.world_size == 2\n                else self.device_group\n            ),\n        )\n        return tensor\n\n    def destroy(self) -> None:\n  
      if self.device_group is not None:\n            torch.distributed.destroy_process_group(self.device_group)\n            self.device_group = None\n        if self.cpu_group is not None:\n            torch.distributed.destroy_process_group(self.cpu_group)\n            self.cpu_group = None\n        if self.device_communicator is not None:\n            self.device_communicator.destroy()\n        if self.mq_broadcaster is not None:\n            self.mq_broadcaster = None\n\n\nclass PipelineGroupCoordinator(GroupCoordinator):\n    \"\"\"\n    available attributes:\n    rank: int  # global rank\n    ranks: List[int]  # global ranks in the group\n    world_size: int  # size of the group\n    difference between `local_rank` and `rank_in_group`:\n    if we have a group of size 4 across two nodes:\n    Process | Node | Rank | Local Rank | Rank in Group\n      0     |   0  |  0   |     0      |       0\n      1     |   0  |  1   |     1      |       1\n      2     |   1  |  2   |     0      |       2\n      3     |   1  |  3   |     1      |       3\n    local_rank: int  # local rank used to assign devices\n    rank_in_group: int  # rank inside the group\n    cpu_group: ProcessGroup  # group for CPU communication\n    device_group: ProcessGroup  # group for device communication\n    \"\"\"\n\n    def __init__(\n        self,\n        group_ranks: List[List[int]],\n        local_rank: int,\n        torch_distributed_backend: Union[str, Backend],\n        group_name: str | None = None,\n    ):\n        super().__init__(\n            group_ranks=group_ranks,\n            local_rank=local_rank,\n            torch_distributed_backend=torch_distributed_backend,\n            group_name=group_name,\n        )\n        self.rank = torch.distributed.get_rank()\n        self.local_rank = local_rank\n        self.device_group = None\n        self.cpu_group = None\n        self.cpu_groups = []\n        self.device_groups = []\n        if len(group_ranks[0]) > 2 or len(group_ranks[0]) 
== 1:\n            for ranks in group_ranks:\n                device_group = torch.distributed.new_group(\n                    ranks, backend=torch_distributed_backend\n                )\n                # a group with `gloo` backend, to allow direct coordination between\n                # processes through the CPU.\n                with suppress_stdout():\n                    cpu_group = torch.distributed.new_group(ranks, backend=\"gloo\")\n                if self.rank in ranks:\n                    self.ranks = ranks\n                    self.world_size = len(ranks)\n                    self.rank_in_group = ranks.index(self.rank)\n                    self.device_group = device_group\n                    self.cpu_group = cpu_group\n        # when pipeline parallelism is 2, we need to create two groups to avoid\n        #   communication stall.\n        # *_group_0_1 represents the group for communication from device 0 to\n        #   device 1.\n        # *_group_1_0 represents the group for communication from device 1 to\n        #   device 0.\n        elif len(group_ranks[0]) == 2:\n            for ranks in group_ranks:\n                device_group_0_1 = torch.distributed.new_group(\n                    ranks, backend=torch_distributed_backend\n                )\n                device_group_1_0 = torch.distributed.new_group(\n                    ranks, backend=torch_distributed_backend\n                )\n                # a group with `gloo` backend, to allow direct coordination between\n                # processes through the CPU.\n                with suppress_stdout():\n                    cpu_group_0_1 = torch.distributed.new_group(ranks, backend=\"gloo\")\n                    cpu_group_1_0 = torch.distributed.new_group(ranks, backend=\"gloo\")\n                if self.rank in ranks:\n                    self.ranks = ranks\n                    self.world_size = len(ranks)\n                    self.rank_in_group = ranks.index(self.rank)\n                    
self.device_groups = [device_group_0_1, device_group_1_0]\n                    self.cpu_groups = [cpu_group_0_1, cpu_group_1_0]\n                    self.device_group = device_group_0_1\n                    self.cpu_group = cpu_group_0_1\n\n        assert self.cpu_group is not None\n        assert self.device_group is not None\n\n        self.device = current_platform.get_device(local_rank)\n\n        self.recv_buffer_set: bool = False\n        self.recv_tasks_queue: List[Tuple[str, int]] = []\n        self.receiving_tasks: List[Tuple[torch.distributed.Work, str, int]] = []\n        self.dtype: Optional[torch.dtype] = None\n        self.num_pipefusion_patches: Optional[int] = None\n\n        self.recv_shape: Dict[str, Dict[int, torch.Size]] = {}\n        self.send_shape: Dict[str, Dict[int, torch.Size]] = {}\n        self.recv_buffer: Dict[str, Dict[int, torch.Size]] = {}\n\n        self.skip_tensor_recv_buffer_set: bool = False\n        self.recv_skip_tasks_queue: List[Union[int, Tuple[str, int]]] = []\n        self.receiving_skip_tasks: List[Tuple[torch.distributed.Work, str, int]] = []\n        self.skip_tensor_recv_buffer: Optional[\n            Union[List[torch.Tensor], torch.Tensor]\n        ] = None\n        self.skip_device_group = None\n        for ranks in group_ranks:\n            skip_device_group = torch.distributed.new_group(\n                ranks, backend=torch_distributed_backend\n            )\n            if self.rank in ranks:\n                self.skip_device_group = skip_device_group\n        assert self.skip_device_group is not None\n\n    def reset_buffer(self):\n        self.recv_tasks_queue = []\n        self.receiving_tasks = []\n        self.recv_shape = {}\n        self.send_shape = {}\n        self.recv_buffer = {}\n\n        self.recv_skip_tasks_queue = []\n        self.receiving_skip_tasks = []\n        self.skip_tensor_recv_buffer = {}\n\n    def set_config(self, dtype: torch.dtype):\n        self.dtype = dtype\n\n    def 
set_recv_buffer(\n        self,\n        num_pipefusion_patches: int,\n        patches_shape_list: List[List[int]],\n        feature_map_shape: List[int],\n        dtype: torch.dtype,\n    ):\n        assert isinstance(dtype, torch.dtype), \"dtype must be a torch.dtype object\"\n        assert (\n            isinstance(num_pipefusion_patches, int) and num_pipefusion_patches >= 1\n        ), \"num_pipefusion_patches must be greater than or equal to 1\"\n        self.dtype = dtype\n        self.num_pipefusion_patches = num_pipefusion_patches\n        self.recv_buffer = [\n            torch.zeros(*shape, dtype=self.dtype, device=self.device)\n            for shape in patches_shape_list\n        ]\n        self.recv_buffer.append(\n            torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device)\n        )\n        self.recv_buffer_set = True\n\n    def set_extra_tensors_recv_buffer(\n        self,\n        name: str,\n        shape: List[int],\n        num_buffers: int = 1,\n        dtype: torch.dtype = torch.float16,\n    ):\n        self.extra_tensors_recv_buffer[name] = [\n            torch.zeros(*shape, dtype=dtype, device=self.device)\n            for _ in range(num_buffers)\n        ]\n\n    def _check_shape_and_buffer(\n        self,\n        tensor_send_to_next=None,\n        recv_prev=False,\n        name: Optional[str] = None,\n        segment_idx: int = 0,\n    ):\n        send_flag = False\n        name = name or \"latent\"\n        if tensor_send_to_next is not None:\n            shape_list = self.send_shape.get(name, None)\n            if shape_list is None:\n                self.send_shape[name] = {segment_idx: tensor_send_to_next.shape}\n                send_flag = True\n            elif shape_list.get(segment_idx, None) is None:\n                self.send_shape[name][segment_idx] = tensor_send_to_next.shape\n                send_flag = True\n\n        recv_flag = False\n        if recv_prev:\n            shape_list = 
self.recv_shape.get(name, None)\n            if shape_list is None:\n                recv_flag = True\n            elif shape_list.get(segment_idx, None) is None:\n                recv_flag = True\n\n        recv_prev_shape = self._communicate_shapes(\n            tensor_send_to_next=tensor_send_to_next if send_flag else None,\n            recv_prev=recv_flag,\n        )\n\n        if recv_flag:\n            if self.recv_shape.get(name, None) is None:\n                self.recv_shape[name] = {segment_idx: recv_prev_shape}\n            else:\n                self.recv_shape[name][segment_idx] = recv_prev_shape\n\n            if self.recv_buffer.get(name, None) is None:\n                self.recv_buffer[name] = {\n                    segment_idx: torch.zeros(\n                        recv_prev_shape, device=self.device, dtype=self.dtype\n                    )\n                }\n            else:\n                if self.recv_buffer[name].get(segment_idx, None) is not None:\n                    logger.warning(\n                        f\"Recv buffer [name: {name}, segment_idx: {segment_idx}] already exists; updating...\"\n                    )\n                self.recv_buffer[name][segment_idx] = torch.zeros(\n                    recv_prev_shape, device=self.device, dtype=self.dtype\n                )\n\n    def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False):\n        \"\"\"Communicate tensor shapes between stages. 
Used to communicate\n        tensor shapes before the actual tensor communication happens.\n\n        Args:\n            tensor_send_to_next: tensor to send to the next rank (no tensor is\n                                 sent if set to None).\n            recv_prev: whether a tensor should be received from the\n                       previous rank.\n        \"\"\"\n\n        ops = []\n        if recv_prev:\n            recv_prev_dim_tensor = torch.empty(\n                (1), device=self.device, dtype=torch.int64\n            )\n            recv_prev_dim_op = torch.distributed.P2POp(\n                torch.distributed.irecv,\n                recv_prev_dim_tensor,\n                self.prev_rank,\n                self.device_group,\n            )\n            ops.append(recv_prev_dim_op)\n\n        if tensor_send_to_next is not None:\n            send_next_dim_tensor = torch.tensor(\n                tensor_send_to_next.dim(), device=self.device, dtype=torch.int64\n            )\n            send_next_dim_op = torch.distributed.P2POp(\n                torch.distributed.isend,\n                send_next_dim_tensor,\n                self.next_rank,\n                self.device_group,\n            )\n            ops.append(send_next_dim_op)\n\n        if len(ops) > 0:\n            reqs = torch.distributed.batch_isend_irecv(ops)\n            for req in reqs:\n                req.wait()\n\n        # To protect against a race condition when using batch_isend_irecv().\n        # Should be taken out once the bug with batch_isend_irecv is resolved.\n        synchronize()\n\n        ops = []\n        recv_prev_shape_tensor = None\n        if recv_prev:\n            recv_prev_shape_tensor = torch.empty(\n                torch.Size(recv_prev_dim_tensor),\n                device=self.device,\n                dtype=torch.int64,\n            )\n            recv_prev_shape_op = torch.distributed.P2POp(\n                torch.distributed.irecv,\n                recv_prev_shape_tensor,\n     
           self.prev_rank,\n                self.device_group,\n            )\n            ops.append(recv_prev_shape_op)\n\n        if tensor_send_to_next is not None:\n            send_next_shape_tensor = torch.tensor(\n                tensor_send_to_next.size(),\n                device=self.device,\n                dtype=torch.int64,\n            )\n            send_next_shape_op = torch.distributed.P2POp(\n                torch.distributed.isend,\n                send_next_shape_tensor,\n                self.next_rank,\n                self.device_group,\n            )\n            ops.append(send_next_shape_op)\n\n        if len(ops) > 0:\n            reqs = torch.distributed.batch_isend_irecv(ops)\n            for req in reqs:\n                req.wait()\n\n        synchronize()\n\n        recv_prev_shape = [0, 0, 0]\n        if recv_prev_shape_tensor is not None:\n            recv_prev_shape = recv_prev_shape_tensor\n        return torch.Size(recv_prev_shape)\n\n    def pipeline_send(\n        self, tensor: torch.Tensor, name: str = \"latent\", segment_idx: int = -1\n    ) -> None:\n        tensor = tensor.contiguous()\n        self._check_shape_and_buffer(\n            tensor_send_to_next=tensor, name=name, segment_idx=segment_idx\n        )\n        self._pipeline_isend(tensor).wait()\n\n    def pipeline_isend(\n        self, tensor: torch.Tensor, name: str = \"latent\", segment_idx: int = -1\n    ) -> None:\n        tensor = tensor.contiguous()\n        self._check_shape_and_buffer(\n            tensor_send_to_next=tensor, name=name, segment_idx=segment_idx\n        )\n        self._pipeline_isend(tensor)\n\n    def pipeline_recv(self, idx: int = -1, name: str = \"latent\") -> torch.Tensor:\n        name = name or \"latent\"\n        self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx)\n        self._pipeline_irecv(self.recv_buffer[name][idx]).wait()\n        return self.recv_buffer[name][idx]\n\n    def add_pipeline_recv_task(self, 
idx: int = -1, name: str = \"latent\"):\n        name = name or \"latent\"\n        self.recv_tasks_queue.append((name, idx))\n\n    def recv_next(self):\n        if len(self.recv_tasks_queue) == 0:\n            raise ValueError(\"No more tasks to receive\")\n        elif len(self.recv_tasks_queue) > 0:\n            name, idx = self.recv_tasks_queue.pop(0)\n            self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx)\n            self.receiving_tasks.append(\n                (self._pipeline_irecv(self.recv_buffer[name][idx]), name, idx)\n            )\n\n    def get_pipeline_recv_data(\n        self, idx: int = -1, name: str = \"latent\"\n    ) -> torch.Tensor:\n        assert (\n            len(self.receiving_tasks) > 0\n        ), \"No tasks to receive, call add_pipeline_recv_task first\"\n        receiving_task = self.receiving_tasks.pop(0)\n        receiving_task[0].wait()\n        assert (\n            receiving_task[1] == name and receiving_task[2] == idx\n        ), \"Received tensor does not match the requested\"\n        return self.recv_buffer[name][idx]\n\n    def _pipeline_irecv(self, tensor: torch.tensor):\n        return torch.distributed.irecv(\n            tensor,\n            src=self.prev_rank,\n            group=(\n                self.device_groups[(self.rank_in_group + 1) % 2]\n                if self.world_size == 2\n                else self.device_group\n            ),\n        )\n\n    def _pipeline_isend(self, tensor: torch.tensor):\n        return torch.distributed.isend(\n            tensor,\n            dst=self.next_rank,\n            group=(\n                self.device_groups[self.rank_in_group % 2]\n                if self.world_size == 2\n                else self.device_group\n            ),\n        )\n\n    def set_skip_tensor_recv_buffer(\n        self,\n        patches_shape_list: List[List[int]],\n        feature_map_shape: List[int],\n    ):\n        self.skip_tensor_recv_buffer = [\n            
torch.zeros(*shape, dtype=self.dtype, device=self.device)\n            for shape in patches_shape_list\n        ]\n        self.skip_tensor_recv_buffer.append(\n            torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device)\n        )\n        self.skip_tensor_recv_buffer_set = True\n\n    def pipeline_send_skip(self, tensor: torch.Tensor) -> None:\n        tensor = tensor.contiguous()\n        self._pipeline_isend_skip(tensor).wait()\n\n    def pipeline_isend_skip(self, tensor: torch.Tensor) -> None:\n        tensor = tensor.contiguous()\n        self._pipeline_isend_skip(tensor)\n\n    def pipeline_recv_skip(self, idx: int = -1) -> torch.Tensor:\n        self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]).wait()\n        return self.skip_tensor_recv_buffer[idx]\n\n    def add_pipeline_recv_skip_task(self, idx: int = -1):\n        self.recv_skip_tasks_queue.append(idx)\n\n    def get_pipeline_recv_skip_data(self, idx: int = -1) -> torch.Tensor:\n        assert (\n            len(self.receiving_skip_tasks) > 0\n        ), \"No tasks to receive, call add_pipeline_recv_skip_task first\"\n        receiving_skip_task = self.receiving_skip_tasks.pop(0)\n        receiving_skip_task[0].wait()\n        assert (\n            receiving_skip_task[2] == idx\n        ), \"Received tensor does not match the requested\"\n        return self.skip_tensor_recv_buffer[idx]\n\n    def recv_skip_next(self):\n        if len(self.recv_skip_tasks_queue) == 0:\n            raise ValueError(\"No more tasks to receive\")\n        elif len(self.recv_skip_tasks_queue) > 0:\n            task = self.recv_skip_tasks_queue.pop(0)\n            idx = task\n            self.receiving_skip_tasks.append(\n                (\n                    self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]),\n                    None,\n                    idx,\n                )\n            )\n\n    def _pipeline_irecv_skip(self, tensor: torch.tensor):\n        return 
torch.distributed.irecv(\n            tensor, src=self.skip_rank, group=self.skip_device_group\n        )\n\n    def _pipeline_isend_skip(self, tensor: torch.tensor):\n        return torch.distributed.isend(\n            tensor, dst=self.skip_rank, group=self.skip_device_group\n        )\n\n\nclass SequenceParallelGroupCoordinator(GroupCoordinator):\n    def __init__(\n        self,\n        group_ranks: List[List[int]],\n        local_rank: int,\n        torch_distributed_backend: Union[str, Backend],\n        group_name: str | None = None,\n        **kwargs,\n    ):\n        super().__init__(\n            group_ranks=group_ranks,\n            local_rank=local_rank,\n            torch_distributed_backend=torch_distributed_backend,\n            group_name=group_name,\n        )\n        ulysses_group = kwargs.get(\"ulysses_group\", None)\n        ring_group = kwargs.get(\"ring_group\", None)\n        if ulysses_group is None:\n            raise RuntimeError(\n                f\"Please pass argument 'ulysses_group' when calling init func of SequenceParallelGroupCoordinator\"\n            )\n        if ring_group is None:\n            raise RuntimeError(\n                f\"Please pass argument 'ring_group' when calling init func of SequenceParallelGroupCoordinator\"\n            )\n        self.ulysses_group = ulysses_group\n        self.ring_group = ring_group\n\n        self.ulysses_world_size = torch.distributed.get_world_size(self.ulysses_group)\n        self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group)\n        self.ring_world_size = torch.distributed.get_world_size(self.ring_group)\n        self.ring_rank = torch.distributed.get_rank(self.ring_group)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/distributed/parallel_groups.py",
    "content": "# Reference: https://github.com/feifeibear/long-context-attention/blob/main/yunchang/globals.py\n\n\nimport torch\n\n\nclass Singleton:\n    _instance = None\n\n    def __new__(cls, *args, **kwargs):\n        if not cls._instance:\n            cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)\n        return cls._instance\n\n\nclass ProcessGroupSingleton(Singleton):\n    def __init__(self):\n        self.ULYSSES_PG = None\n        self.RING_PG = None\n\n\nPROCESS_GROUP = ProcessGroupSingleton()\n\n\ndef set_seq_parallel_pg_by_sp_groups(\n    sp_ulysses_degree,\n    sp_ring_degree,\n    rank: int,\n    sp_groups: list[list[int]],\n    use_ulysses_low: bool = True,\n):\n    \"\"\"Create Ulysses/Ring process groups inside each SP group.\n\n    This is required when TP>1, because SP groups are not necessarily made of\n    consecutive global ranks (e.g., tp-sp order makes SP ranks strided).\n\n    Args:\n        sp_ulysses_degree: ulysses degree inside SP.\n        sp_ring_degree: ring degree inside SP.\n        rank: global rank of current process.\n        sp_groups: list of global-rank lists for each SP group.\n        use_ulysses_low: keep the same semantics as the original function.\n    \"\"\"\n    sp_degree = sp_ring_degree * sp_ulysses_degree\n    assert sp_degree > 0\n    assert all(\n        len(g) == sp_degree for g in sp_groups\n    ), f\"Each SP group must have size {sp_degree}, got sizes {[len(g) for g in sp_groups]}\"\n\n    ulyssess_pg = None\n    ring_pg = None\n\n    num_ulysses_pgs = sp_ring_degree\n    num_ring_pgs = sp_ulysses_degree\n\n    def _map_indices_to_ranks(ranks: list[int], indices: list[int]) -> list[int]:\n        return [ranks[i] for i in indices]\n\n    # Important: call torch.distributed.new_group in the same order on all ranks.\n    for sp_ranks in sp_groups:\n        if use_ulysses_low:\n            for i in range(num_ulysses_pgs):\n                idx = list(range(i * sp_ulysses_degree, (i + 1) * 
sp_ulysses_degree))\n                ulysses_ranks = _map_indices_to_ranks(sp_ranks, idx)\n                group = torch.distributed.new_group(ulysses_ranks)\n                if rank in ulysses_ranks:\n                    ulyssess_pg = group\n\n            for i in range(num_ring_pgs):\n                idx = list(range(i, sp_degree, num_ring_pgs))\n                ring_ranks = _map_indices_to_ranks(sp_ranks, idx)\n                group = torch.distributed.new_group(ring_ranks)\n                if rank in ring_ranks:\n                    ring_pg = group\n        else:\n            for i in range(num_ring_pgs):\n                idx = list(range(i * sp_ring_degree, (i + 1) * sp_ring_degree))\n                ring_ranks = _map_indices_to_ranks(sp_ranks, idx)\n                group = torch.distributed.new_group(ring_ranks)\n                if rank in ring_ranks:\n                    ring_pg = group\n\n            for i in range(num_ulysses_pgs):\n                idx = list(range(i, sp_degree, num_ulysses_pgs))\n                ulysses_ranks = _map_indices_to_ranks(sp_ranks, idx)\n                group = torch.distributed.new_group(ulysses_ranks)\n                if rank in ulysses_ranks:\n                    ulyssess_pg = group\n\n    PROCESS_GROUP.ULYSSES_PG = ulyssess_pg\n    PROCESS_GROUP.RING_PG = ring_pg\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/distributed/parallel_state.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/parallel_state.py\n# Copyright 2023 The vLLM team.\n# Adapted from\n# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\n# Adapted from\n# Copyright 2024 xDiT team.\n# Adapted from\n# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py\n# Copyright 2023 The vLLM team.\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\n\n\"\"\"sglang-diffusion distributed state.\n\nIt takes over the control of the distributed environment from PyTorch.\nThe typical workflow is:\n\n- call `init_distributed_environment` to initialize the distributed environment.\n- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to\n initialize the model parallel groups.\n\n- any code dealing with the distributed stuff\n\n- call `destroy_model_parallel` to destroy the model parallel groups.\n- call `destroy_distributed_environment` to destroy the distributed environment.\n\nIf you only need to use the distributed environment without model parallelism,\n you can skip the model parallel initialization and destruction steps.\n\"\"\"\n\nimport contextlib\nimport datetime\nimport os\nimport weakref\nfrom collections import namedtuple\nfrom collections.abc import Callable\nfrom contextlib import contextmanager\nfrom multiprocessing import shared_memory\nfrom typing import Any, List, Optional\nfrom unittest.mock import patch\n\nimport torch\nimport torch.distributed\nfrom torch.distributed import ProcessGroup\n\nimport sglang.multimodal_gen.envs as envs\nfrom sglang.multimodal_gen.runtime.distributed.utils import StatelessProcessGroup\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nfrom ..utils.distributed import 
RankGenerator\nfrom .group_coordinator import (\n    GroupCoordinator,\n    PipelineGroupCoordinator,\n    SequenceParallelGroupCoordinator,\n    get_local_torch_device,\n)\n\nlogger = init_logger(__name__)\n\n_WORLD: GroupCoordinator | None = None\n_TP: GroupCoordinator | None = None\n_SP: SequenceParallelGroupCoordinator | None = None\n_PP: PipelineGroupCoordinator | None = None\n_CFG: GroupCoordinator | None = None\n_DP: GroupCoordinator | None = None\n_DIT: ProcessGroup | None = None\n_VAE: ProcessGroup | None = None\n\nTensorMetadata = namedtuple(\"TensorMetadata\", [\"device\", \"dtype\", \"size\"])\n\n\ndef _split_tensor_dict(\n    tensor_dict: dict[str, torch.Tensor | Any],\n) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:\n    \"\"\"Split the tensor dictionary into two parts:\n    1. A list of (key, value) pairs. If the value is a tensor, it is replaced\n         by its metadata.\n    2. A list of tensors.\n    \"\"\"\n    metadata_list: list[tuple[str, Any]] = []\n    tensor_list: list[torch.Tensor] = []\n    for key, value in tensor_dict.items():\n        if isinstance(value, torch.Tensor):\n            # Note: we cannot use `value.device` here,\n            # because it contains not only the device type but also the device\n            # index (e.g. \"cuda:0\"). 
We only need the device type.\n            # receiving side will set the device index.\n            device = value.device.type\n            metadata_list.append(\n                (key, TensorMetadata(device, value.dtype, value.size()))\n            )\n            tensor_list.append(value)\n        else:\n            metadata_list.append((key, value))\n    return metadata_list, tensor_list\n\n\n_groups: dict[str, Callable[[], Optional[\"GroupCoordinator\"]]] = {}\n\n\ndef _register_group(group: \"GroupCoordinator\") -> None:\n    _groups[group.unique_name] = weakref.ref(group)\n\n\ndef all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:\n    assert group_name in _groups, f\"Group {group_name} is not found.\"\n    group = _groups[group_name]()\n    if group is None:\n        raise ValueError(f\"Group {group_name} is destroyed.\")\n    return group._all_reduce_out_place(tensor)\n\n\ndef all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:\n    return torch.empty_like(tensor)\n\n\ndef get_world_group() -> GroupCoordinator:\n    assert _WORLD is not None, \"world group is not initialized\"\n    return _WORLD\n\n\ndef init_world_group(\n    ranks: list[int], local_rank: int, backend: str\n) -> GroupCoordinator:\n    return GroupCoordinator(\n        group_ranks=[ranks],\n        local_rank=local_rank,\n        torch_distributed_backend=backend,\n        use_device_communicator=True,\n        group_name=\"world\",\n    )\n\n\ndef init_parallel_group_coordinator(\n    group_ranks: List[List[int]],\n    local_rank: int,\n    backend: str,\n    parallel_mode: str,\n    **kwargs,\n) -> GroupCoordinator:\n    \"\"\"Return a group coordinator for the given parallel mode.\"\"\"\n    assert parallel_mode in [\n        \"data\",\n        \"pipeline\",\n        \"tensor\",\n        \"sequence\",\n        \"classifier_free_guidance\",\n    ], f\"parallel_mode {parallel_mode} is not supported\"\n    if parallel_mode == \"pipeline\":\n        return 
PipelineGroupCoordinator(\n            group_ranks=group_ranks,\n            local_rank=local_rank,\n            torch_distributed_backend=backend,\n            group_name=\"pp_group\",\n        )\n    elif parallel_mode == \"sequence\":\n        return SequenceParallelGroupCoordinator(\n            group_ranks=group_ranks,\n            local_rank=local_rank,\n            torch_distributed_backend=backend,\n            group_name=\"sp_group\",\n            **kwargs,\n        )\n    else:\n        # fallback to GroupCoordinator\n        return GroupCoordinator(\n            group_ranks=group_ranks,\n            local_rank=local_rank,\n            torch_distributed_backend=backend,\n            group_name=\"cfg_group\",\n        )\n\n\ndef get_tp_group() -> GroupCoordinator:\n    assert _TP is not None, \"tensor model parallel group is not initialized\"\n    return _TP\n\n\ndef init_distributed_environment(\n    world_size: int = 1,\n    rank: int = 0,\n    distributed_init_method: str = \"env://\",\n    local_rank: int = 0,\n    backend: str = \"nccl\",\n    device_id: torch.device | None = None,\n    timeout: int | None = None,\n):\n    # Determine the appropriate backend based on the platform\n    from sglang.multimodal_gen.runtime.platforms import current_platform\n\n    if backend == \"nccl\" and not current_platform.is_cuda_alike():\n        # Use gloo backend for non-CUDA platforms (MPS, CPU)\n        backend = \"gloo\"\n        logger.info(\"Using gloo backend for %s platform\", current_platform.device_name)\n\n    logger.debug(\n        \"world_size=%d rank=%d local_rank=%d \"\n        \"distributed_init_method=%s backend=%s timeout=%s\",\n        world_size,\n        rank,\n        local_rank,\n        distributed_init_method,\n        backend,\n        timeout,\n    )\n    if not torch.distributed.is_initialized():\n        assert distributed_init_method is not None, (\n            \"distributed_init_method must be provided when initializing \"\n           
 \"distributed environment\"\n        )\n\n        # Don't pass device_id on MPS, MUSA, or NPU; MPS and MUSA don't support device indices\n        extra_args = (\n            {}\n            if (\n                current_platform.is_mps()\n                or current_platform.is_musa()\n                or current_platform.is_npu()\n            )\n            else dict(device_id=device_id)\n        )\n\n        if timeout is not None:\n            extra_args[\"timeout\"] = datetime.timedelta(seconds=timeout)\n            logger.info(f\"Setting distributed timeout to {timeout} seconds\")\n\n        torch.distributed.init_process_group(\n            backend=backend,\n            init_method=distributed_init_method,\n            world_size=world_size,\n            rank=rank,\n            **extra_args,\n        )\n\n    # set the local rank\n    # local_rank is not available in torch ProcessGroup,\n    # see https://github.com/pytorch/pytorch/issues/122816\n    if local_rank == -1:\n        # Local rank not set; this usually happens in a single-node\n        # setting, where we can use rank as local rank\n        if distributed_init_method == \"env://\":\n            local_rank = envs.LOCAL_RANK\n        else:\n            local_rank = rank\n    global _WORLD\n    if _WORLD is None:\n        ranks = list(range(torch.distributed.get_world_size()))\n        _WORLD = init_world_group(ranks, local_rank, backend)\n    else:\n        assert (\n            _WORLD.world_size == torch.distributed.get_world_size()\n        ), \"world group already initialized with a different world size\"\n\n\ndef get_sp_group() -> SequenceParallelGroupCoordinator:\n    assert _SP is not None, \"sequence parallel group is not initialized\"\n    return _SP\n\n\ndef get_dp_group() -> GroupCoordinator:\n    assert _DP is not None, \"data parallel group is not initialized\"\n    return _DP\n\n\n# xDiT\ndef initialize_model_parallel(\n    data_parallel_size: int = 1,\n    classifier_free_guidance_degree: int = 
1,\n    sequence_parallel_degree: Optional[int] = None,\n    ulysses_degree: int = 1,\n    ring_degree: int = 1,\n    tensor_parallel_degree: int = 1,\n    pipeline_parallel_degree: int = 1,\n    vae_parallel_size: int = 0,\n    backend: Optional[str] = None,\n) -> None:\n    \"\"\"\n    Initialize model parallel groups.\n\n    Arguments:\n        data_parallel_size: number of GPUs used for data parallelism.\n        classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG).\n        sequence_parallel_degree: number of GPUs used for sequence parallelism. sequence_parallel_degree = ulysses_degree * ring_degree\n        ulysses_degree: number of GPUs used for ulysses sequence parallelism.\n        ring_degree: number of GPUs used for ring sequence parallelism.\n        tensor_parallel_degree: number of GPUs used for tensor parallelism.\n        pipeline_parallel_degree: number of GPUs used for pipeline parallelism.\n        backend: PyTorch distributed backend for collective communication. If None, the current platform's default backend is used.\n\n    Let's say we have a total of 16 GPUs denoted by g0 ... 
g15 and we\n    use 2 GPUs to parallelize the batch dim (dp), 2 GPUs to parallelize\n    the split batch caused by CFG, 2 GPUs to parallelize the sequence (sp),\n    and 2 GPUs for pipeline parallelism (pp).\n\n    dp_degree (2) * cfg_degree (2) * sp_degree (2) * pp_degree (2) = 16.\n\n    The present function will create 8 data-parallel groups,\n    8 CFG groups, 8 pipeline-parallel groups, and\n    8 sequence-parallel groups:\n        8 data-parallel groups:\n            [g0, g8], [g1, g9], [g2, g10], [g3, g11],\n            [g4, g12], [g5, g13], [g6, g14], [g7, g15]\n        8 CFG-parallel groups:\n            [g0, g4], [g1, g5], [g2, g6], [g3, g7],\n            [g8, g12], [g9, g13], [g10, g14], [g11, g15]\n        8 sequence-parallel groups:\n            [g0, g1], [g2, g3], [g4, g5], [g6, g7],\n            [g8, g9], [g10, g11], [g12, g13], [g14, g15]\n        8 pipeline-parallel groups:\n            [g0, g2], [g4, g6], [g8, g10], [g12, g14],\n            [g1, g3], [g5, g7], [g9, g11], [g13, g15]\n    Note that for efficiency, the caller should make sure adjacent ranks\n    are on the same DGX box. For example, if we are using 2 DGX-1 boxes\n    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and\n    ranks 8 to 15 belong to the second box.\n    \"\"\"\n\n    if backend is None:\n        from sglang.multimodal_gen.runtime.platforms import current_platform\n\n        backend = current_platform.get_torch_distributed_backend_str()\n    # Get world size and rank. 
Ensure consistency.\n    assert torch.distributed.is_initialized()\n    world_size: int = torch.distributed.get_world_size()\n    backend = backend or torch.distributed.get_backend(get_world_group().device_group)\n\n    dit_parallel_size = (\n        data_parallel_size\n        * classifier_free_guidance_degree\n        * sequence_parallel_degree\n        * pipeline_parallel_degree\n        * tensor_parallel_degree\n    )\n\n    if world_size < dit_parallel_size:\n        raise RuntimeError(\n            f\"world_size ({world_size}) is less than \"\n            f\"tensor_parallel_degree ({tensor_parallel_degree}) x \"\n            f\"pipeline_parallel_degree ({pipeline_parallel_degree}) x \"\n            f\"sequence_parallel_degree ({sequence_parallel_degree}) x \"\n            f\"classifier_free_guidance_degree \"\n            f\"({classifier_free_guidance_degree}) x \"\n            f\"data_parallel_degree ({data_parallel_size})\"\n        )\n\n    rank_generator: RankGenerator = RankGenerator(\n        tensor_parallel_degree,\n        sequence_parallel_degree,\n        pipeline_parallel_degree,\n        classifier_free_guidance_degree,\n        data_parallel_size,\n        \"tp-sp-pp-cfg-dp\",\n    )\n    global _DP\n    assert _DP is None, \"data parallel group is already initialized\"\n    _DP = init_parallel_group_coordinator(\n        group_ranks=rank_generator.get_ranks(\"dp\"),\n        local_rank=get_world_group().local_rank,\n        backend=backend,\n        parallel_mode=\"data\",\n    )\n\n    global _CFG\n    assert _CFG is None, \"classifier_free_guidance group is already initialized\"\n    _CFG = init_parallel_group_coordinator(\n        group_ranks=rank_generator.get_ranks(\"cfg\"),\n        local_rank=get_world_group().local_rank,\n        backend=backend,\n        parallel_mode=\"classifier_free_guidance\",\n    )\n    global _PP\n    assert _PP is None, \"pipeline model parallel group is already initialized\"\n    _PP = 
init_parallel_group_coordinator(\n        group_ranks=rank_generator.get_ranks(\"pp\"),\n        local_rank=get_world_group().local_rank,\n        backend=backend,\n        parallel_mode=\"pipeline\",\n    )\n\n    global _SP\n    assert _SP is None, \"sequence parallel group is already initialized\"\n\n    try:\n        from .parallel_groups import PROCESS_GROUP as _YC_PROCESS_GROUP\n        from .parallel_groups import (\n            set_seq_parallel_pg_by_sp_groups as _set_seq_parallel_pg_by_sp_groups,\n        )\n    except ImportError:\n        _set_seq_parallel_pg_by_sp_groups = None\n\n        class _DummyProcessGroup:\n            ULYSSES_PG = torch.distributed.group.WORLD\n            RING_PG = torch.distributed.group.WORLD\n\n        PROCESS_GROUP = _DummyProcessGroup()\n    else:\n        # Build SGLang Diffusion SP sub-groups based on the true SP groups. This is\n        # critical when TP>1, because SP groups may be strided in global ranks\n        # (e.g., tp-sp order).\n        sp_groups = rank_generator.get_ranks(\"sp\")\n        _set_seq_parallel_pg_by_sp_groups(\n            sp_ulysses_degree=ulysses_degree,\n            sp_ring_degree=ring_degree,\n            rank=get_world_group().rank,\n            sp_groups=sp_groups,\n        )\n        PROCESS_GROUP = _YC_PROCESS_GROUP\n\n    _SP = init_parallel_group_coordinator(\n        group_ranks=rank_generator.get_ranks(\"sp\"),\n        local_rank=get_world_group().local_rank,\n        backend=backend,\n        parallel_mode=\"sequence\",\n        ulysses_group=PROCESS_GROUP.ULYSSES_PG,\n        ring_group=PROCESS_GROUP.RING_PG,\n    )\n\n    global _TP\n    assert _TP is None, \"Tensor parallel group is already initialized\"\n    _TP = init_parallel_group_coordinator(\n        group_ranks=rank_generator.get_ranks(\"tp\"),\n        local_rank=get_world_group().local_rank,\n        backend=backend,\n        parallel_mode=\"tensor\",\n    )\n\n    if vae_parallel_size > 0:\n        
init_vae_group(dit_parallel_size, vae_parallel_size, backend)\n    init_dit_group(dit_parallel_size, backend)\n\n\ndef get_sp_world_size() -> int:\n    \"\"\"Return world size for the sequence model parallel group.\"\"\"\n    return get_sp_group().world_size\n\n\ndef get_sp_parallel_rank() -> int:\n    \"\"\"Return my rank for the sequence model parallel group.\"\"\"\n    return get_sp_group().rank_in_group\n\n\ndef get_world_size() -> int:\n    \"\"\"Return world size for the world group.\"\"\"\n    return get_world_group().world_size\n\n\ndef get_world_rank() -> int:\n    \"\"\"Return my rank for the world group.\"\"\"\n    return get_world_group().rank\n\n\ndef get_dp_world_size() -> int:\n    \"\"\"Return world size for the data parallel group.\"\"\"\n    return get_dp_group().world_size\n\n\ndef get_dp_rank() -> int:\n    \"\"\"Return my rank for the data parallel group.\"\"\"\n    return get_dp_group().rank_in_group\n\n\ndef maybe_init_distributed_environment_and_model_parallel(\n    tp_size: int,\n    sp_size: int,\n    enable_cfg_parallel: bool,\n    ulysses_degree: int = 1,\n    ring_degree: int = 1,\n    dp_size: int = 1,\n    distributed_init_method: str = \"env://\",\n    dist_timeout: int | None = None,\n):\n    from sglang.multimodal_gen.runtime.platforms import current_platform\n\n    if _WORLD is not None and model_parallel_is_initialized():\n        # make sure the tp and sp sizes are correct\n        assert (\n            get_tp_world_size() == tp_size\n        ), f\"You are trying to initialize model parallel groups with size {tp_size}, but they are already initialized with size {get_tp_world_size()}\"\n        assert (\n            get_sp_world_size() == sp_size\n        ), f\"You are trying to initialize model parallel groups with size {sp_size}, but they are already initialized with size {get_sp_world_size()}\"\n        return\n    local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n    world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n    
rank = int(os.environ.get(\"RANK\", 0))\n    device = get_local_torch_device()\n    logger.info(\n        \"Initializing distributed environment with world_size=%d, device=%s, timeout=%s\",\n        world_size,\n        device,\n        dist_timeout,\n        main_process_only=False,\n    )\n\n    init_distributed_environment(\n        world_size=world_size,\n        rank=rank,\n        local_rank=local_rank,\n        distributed_init_method=distributed_init_method,\n        device_id=device,\n        backend=current_platform.get_torch_distributed_backend_str(),\n        timeout=dist_timeout,\n    )\n    initialize_model_parallel(\n        data_parallel_size=dp_size,\n        classifier_free_guidance_degree=2 if enable_cfg_parallel else 1,\n        tensor_parallel_degree=tp_size,\n        ulysses_degree=ulysses_degree,\n        ring_degree=ring_degree,\n        sequence_parallel_degree=sp_size,\n    )\n\n    # Only set CUDA device if we're on a CUDA platform\n    if current_platform.is_cuda_alike():\n        device = torch.device(f\"cuda:{local_rank}\")\n        torch.cuda.set_device(device)\n    elif current_platform.is_npu():\n        device = torch.device(f\"npu:{local_rank}\")\n        torch.npu.set_device(device)\n\n\ndef model_parallel_is_initialized() -> bool:\n    \"\"\"Check if model parallel groups are initialized.\"\"\"\n    return (\n        _DP is not None\n        and _CFG is not None\n        and _SP is not None\n        and _PP is not None\n        and _TP is not None\n    )\n\n\n_TP_STATE_PATCHED = False\n\n\n@contextmanager\ndef patch_tensor_parallel_group(tp_group: GroupCoordinator):\n    \"\"\"Patch the tp group temporarily until this function ends.\n\n    This method is for draft workers of speculative decoding to run draft model\n    with different tp degree from that of target model workers.\n\n    \"\"\"\n    global _TP_STATE_PATCHED\n    assert not _TP_STATE_PATCHED, \"Should not call when it's already patched\"\n\n    _TP_STATE_PATCHED = 
True\n    old_tp_group = get_tp_group()\n    global _TP\n    _TP = tp_group\n    try:\n        yield\n    finally:\n        # restore the original state\n        _TP_STATE_PATCHED = False\n        _TP = old_tp_group\n\n\ndef get_tp_world_size() -> int:\n    \"\"\"Return world size for the tensor model parallel group.\"\"\"\n    return get_tp_group().world_size\n\n\ndef get_tp_rank() -> int:\n    \"\"\"Return my rank for the tensor model parallel group.\"\"\"\n    return get_tp_group().rank_in_group\n\n\ndef destroy_distributed_environment() -> None:\n    global _WORLD\n    if _WORLD:\n        _WORLD.destroy()\n    _WORLD = None\n    if torch.distributed.is_initialized():\n        torch.distributed.destroy_process_group()\n\n\ndef cleanup_dist_env_and_memory(shutdown_ray: bool = False):\n    destroy_model_parallel()\n    destroy_distributed_environment()\n    with contextlib.suppress(AssertionError):\n        torch.distributed.destroy_process_group()\n    if shutdown_ray:\n        import ray  # Lazy import Ray\n\n        ray.shutdown()\n\n\ndef is_the_same_node_as(\n    pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0\n) -> list[bool]:\n    \"\"\"\n    This is a collective operation that returns whether each rank is in the same node\n    as the source rank. 
It tests whether processes are attached to the same\n    memory system (shared access to shared memory).\n    \"\"\"\n    if isinstance(pg, ProcessGroup):\n        assert (\n            torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL\n        ), \"is_the_same_node_as should be tested with a non-NCCL group.\"\n        # local rank inside the group\n        rank = torch.distributed.get_rank(group=pg)\n        world_size = torch.distributed.get_world_size(group=pg)\n\n        # global ranks of the processes in the group\n        ranks = torch.distributed.get_process_group_ranks(pg)\n    else:\n        rank = pg.rank\n        world_size = pg.world_size\n        ranks = list(range(world_size))\n\n    # local tensor in each process to store the result\n    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)\n\n    magic_message = b\"magic_message\"\n    shm = None\n\n    try:\n        with contextlib.suppress(OSError):\n            if rank == source_rank:\n                # create a shared memory segment\n                shm = shared_memory.SharedMemory(create=True, size=128)\n                shm.buf[: len(magic_message)] = magic_message\n                if isinstance(pg, ProcessGroup):\n                    torch.distributed.broadcast_object_list(\n                        [shm.name], src=ranks[source_rank], group=pg\n                    )\n                else:\n                    pg.broadcast_obj(shm.name, src=source_rank)\n                is_in_the_same_node[rank] = 1\n            else:\n                # try to open the shared memory segment\n                if isinstance(pg, ProcessGroup):\n                    recv = [None]\n                    torch.distributed.broadcast_object_list(\n                        recv, src=ranks[source_rank], group=pg\n                    )\n                    name = recv[0]\n                else:\n                    name = pg.broadcast_obj(None, src=source_rank)\n                # fix to 
https://stackoverflow.com/q/62748654/9191338\n                # Python incorrectly tracks shared memory even if it is not\n                # created by the process. The following patch is a workaround.\n                with patch(\n                    \"multiprocessing.resource_tracker.register\",\n                    lambda *args, **kwargs: None,\n                ):\n                    shm = shared_memory.SharedMemory(name=name)\n                if shm.buf[: len(magic_message)] == magic_message:\n                    is_in_the_same_node[rank] = 1\n    except Exception as e:\n        logger.error(\"Error ignored in is_the_same_node_as: %s\", e)\n    finally:\n        if shm:\n            shm.close()\n\n    if isinstance(pg, ProcessGroup):\n        torch.distributed.barrier(group=pg)\n    else:\n        pg.barrier()\n\n    # clean up the shared memory segment\n    with contextlib.suppress(OSError):\n        if rank == source_rank and shm:\n            shm.unlink()\n\n    if isinstance(pg, ProcessGroup):\n        torch.distributed.all_reduce(is_in_the_same_node, group=pg)\n        aggregated_data = is_in_the_same_node\n    else:\n        aggregated_data = torch.zeros_like(is_in_the_same_node)\n        for i in range(world_size):\n            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)\n            aggregated_data += rank_data\n\n    return [x == 1 for x in aggregated_data.tolist()]\n\n\ndef get_tensor_model_parallel_world_size() -> int:\n    \"\"\"Return world size for the tensor model parallel group.\"\"\"\n    return get_tp_world_size()\n\n\ndef get_tensor_model_parallel_rank() -> int:\n    \"\"\"Return my rank for the tensor model parallel group.\"\"\"\n    return get_tp_rank()\n\n\ndef get_sequence_parallel_world_size() -> int:\n    \"\"\"Return world size for the sequence parallel group.\"\"\"\n    return get_sp_world_size()\n\n\ndef get_sequence_parallel_rank() -> int:\n    \"\"\"Return my rank for the sequence parallel group.\"\"\"\n    return 
get_sp_parallel_rank()\n\n\ndef get_ulysses_parallel_world_size() -> int:\n    return get_sp_group().ulysses_world_size\n\n\ndef get_ulysses_parallel_rank() -> int:\n    return get_sp_group().ulysses_rank\n\n\ndef get_ring_parallel_world_size() -> int:\n    return get_sp_group().ring_world_size\n\n\ndef get_ring_parallel_rank() -> int:\n    return get_sp_group().ring_rank\n\n\n# PP\ndef get_pp_group() -> PipelineGroupCoordinator:\n    assert _PP is not None, \"pipeline model parallel group is not initialized\"\n    return _PP\n\n\ndef get_pipeline_parallel_world_size() -> int:\n    \"\"\"Return world size for the pipeline model parallel group.\"\"\"\n    return get_pp_group().world_size\n\n\ndef get_pipeline_parallel_rank() -> int:\n    \"\"\"Return my rank for the pipeline model parallel group.\"\"\"\n    return get_pp_group().rank_in_group\n\n\ndef is_pipeline_first_stage() -> bool:\n    \"\"\"Return True if in the first pipeline model parallel stage, False otherwise.\"\"\"\n    return get_pipeline_parallel_rank() == 0\n\n\ndef is_pipeline_last_stage() -> bool:\n    \"\"\"Return True if in the last pipeline model parallel stage, False otherwise.\"\"\"\n    return get_pipeline_parallel_rank() == (get_pipeline_parallel_world_size() - 1)\n\n\n# CFG\ndef get_cfg_group() -> GroupCoordinator:\n    assert (\n        _CFG is not None\n    ), \"classifier_free_guidance parallel group is not initialized\"\n    return _CFG\n\n\ndef get_classifier_free_guidance_world_size() -> int:\n    \"\"\"Return world size for the classifier_free_guidance parallel group.\"\"\"\n    return get_cfg_group().world_size\n\n\ndef get_classifier_free_guidance_rank() -> int:\n    \"\"\"Return my rank for the classifier_free_guidance parallel group.\"\"\"\n    return get_cfg_group().rank_in_group\n\n\ndef get_data_parallel_world_size() -> int:\n    \"\"\"Return world size for the data parallel group.\"\"\"\n    return get_dp_world_size()\n\n\ndef get_data_parallel_rank() -> int:\n    \"\"\"Return 
my rank for the data parallel group.\"\"\"\n    return get_dp_rank()\n\n\ndef is_dp_last_group() -> bool:\n    \"\"\"Return True if in the last data parallel group, False otherwise.\"\"\"\n    return (\n        get_sequence_parallel_rank() == (get_sequence_parallel_world_size() - 1)\n        and get_classifier_free_guidance_rank()\n        == (get_classifier_free_guidance_world_size() - 1)\n        and get_pipeline_parallel_rank() == (get_pipeline_parallel_world_size() - 1)\n    )\n\n\ndef get_dit_world_size() -> int:\n    \"\"\"Return world size for the DiT model (excluding VAE).\"\"\"\n    return (\n        get_data_parallel_world_size()\n        * get_classifier_free_guidance_world_size()\n        * get_sequence_parallel_world_size()\n        * get_pipeline_parallel_world_size()\n        * get_tensor_model_parallel_world_size()\n    )\n\n\ndef get_vae_parallel_group() -> ProcessGroup:\n    assert _VAE is not None, \"VAE parallel group is not initialized\"\n    return _VAE\n\n\ndef get_vae_parallel_world_size() -> int:\n    \"\"\"Return world size for the VAE parallel group.\"\"\"\n    return torch.distributed.get_world_size(group=get_vae_parallel_group())\n\n\ndef get_vae_parallel_rank() -> int:\n    \"\"\"Return my rank for the VAE parallel group.\"\"\"\n    return torch.distributed.get_rank(group=get_vae_parallel_group())\n\n\ndef init_dit_group(\n    dit_parallel_size: int,\n    backend: str,\n) -> None:\n    global _DIT\n    assert _DIT is None, \"DIT group is already initialized\"\n    _DIT = torch.distributed.new_group(\n        ranks=list(range(dit_parallel_size)), backend=backend\n    )\n\n\ndef get_dit_group() -> ProcessGroup:\n    assert _DIT is not None, \"DIT group is not initialized\"\n    return _DIT\n\n\ndef init_vae_group(\n    dit_parallel_size: int,\n    vae_parallel_size: int,\n    backend: str,\n):\n    # Initialize VAE group first\n    global _VAE\n    assert _VAE is None, \"VAE parallel group is already initialized\"\n    vae_ranks = 
list(range(dit_parallel_size, dit_parallel_size + vae_parallel_size))\n    _VAE = torch.distributed.new_group(ranks=vae_ranks, backend=backend)\n\n\ndef destroy_model_parallel() -> None:\n    \"\"\"Set the groups to none and destroy them.\"\"\"\n    global _TP, _SP, _DP, _CFG, _PP, _DIT, _VAE\n\n    for group in (_TP, _SP, _DP, _CFG, _PP):\n        if group is not None:\n            group.destroy()\n\n    for group in (_DIT, _VAE):\n        if group is not None:\n            torch.distributed.destroy_process_group(group)\n\n    _TP, _SP, _DP, _CFG, _PP, _DIT, _VAE = (None,) * 7\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/distributed/utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/utils.py\n\n# Copyright 2023 The vLLM team.\n# Adapted from\n# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\nimport dataclasses\nimport pickle\nimport time\nfrom collections import deque\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport torch\nfrom torch.distributed import TCPStore\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\ndef ensure_divisibility(numerator, denominator) -> None:\n    \"\"\"Ensure that numerator is divisible by the denominator.\"\"\"\n    assert numerator % denominator == 0, \"{} is not divisible by {}\".format(\n        numerator, denominator\n    )\n\n\ndef divide(numerator: int, denominator: int) -> int:\n    \"\"\"Ensure that numerator is divisible by the denominator and return\n    the division value.\"\"\"\n    ensure_divisibility(numerator, denominator)\n    return numerator // denominator\n\n\ndef split_tensor_along_last_dim(\n    tensor: torch.Tensor,\n    num_partitions: int,\n    contiguous_split_chunks: bool = False,\n) -> Sequence[torch.Tensor]:\n    \"\"\"Split a tensor along its last dimension.\n\n    Arguments:\n        tensor: input tensor.\n        num_partitions: number of partitions to split the tensor\n        contiguous_split_chunks: If True, make each chunk contiguous\n                                 in memory.\n\n    Returns:\n        A list of Tensors\n    \"\"\"\n    # Get the size and dimension.\n    last_dim = tensor.dim() - 1\n    last_dim_size = divide(tensor.size()[last_dim], num_partitions)\n    # Split.\n    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)\n    # NOTE: torch.split does not create 
contiguous tensors by default.\n    if contiguous_split_chunks:\n        return tuple(chunk.contiguous() for chunk in tensor_list)\n\n    return tuple(tensor_list)\n\n\n@dataclasses.dataclass\nclass StatelessProcessGroup:\n    \"\"\"A dataclass to hold a metadata store, and the rank, world_size of the\n    group. Only use it to communicate metadata between processes.\n    For data-plane communication, create NCCL-related objects.\n    \"\"\"\n\n    rank: int\n    world_size: int\n    store: torch._C._distributed_c10d.Store\n    data_expiration_seconds: int = 3600  # 1 hour\n\n    # dst rank -> counter\n    send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)\n    # src rank -> counter\n    recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)\n    broadcast_send_counter: int = 0\n    broadcast_recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)\n\n    # A deque to store the data entries, with key and timestamp.\n    entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque)\n\n    def __post_init__(self):\n        assert self.rank < self.world_size\n        self.send_dst_counter = {i: 0 for i in range(self.world_size)}\n        self.recv_src_counter = {i: 0 for i in range(self.world_size)}\n        self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}\n\n    def send_obj(self, obj: Any, dst: int):\n        \"\"\"Send an object to a destination rank.\"\"\"\n        self.expire_data()\n        key = f\"send_to/{dst}/{self.send_dst_counter[dst]}\"\n        self.store.set(key, pickle.dumps(obj))\n        self.send_dst_counter[dst] += 1\n        self.entries.append((key, time.perf_counter()))\n\n    def expire_data(self) -> None:\n        \"\"\"Expire data that is older than `data_expiration_seconds` seconds.\"\"\"\n        while self.entries:\n            # check the oldest entry\n            key, timestamp = self.entries[0]\n            if time.perf_counter() - 
timestamp > self.data_expiration_seconds:\n                self.store.delete_key(key)\n                self.entries.popleft()\n            else:\n                break\n\n    def recv_obj(self, src: int) -> Any:\n        \"\"\"Receive an object from a source rank.\"\"\"\n        obj = pickle.loads(\n            self.store.get(f\"send_to/{self.rank}/{self.recv_src_counter[src]}\")\n        )\n        self.recv_src_counter[src] += 1\n        return obj\n\n    def broadcast_obj(self, obj: Any | None, src: int) -> Any:\n        \"\"\"Broadcast an object from a source rank to all other ranks.\n        It does not clean up after all ranks have received the object.\n        Use it for limited times, e.g., for initialization.\n        \"\"\"\n        if self.rank == src:\n            self.expire_data()\n            key = f\"broadcast_from/{src}/\" f\"{self.broadcast_send_counter}\"\n            self.store.set(key, pickle.dumps(obj))\n            self.broadcast_send_counter += 1\n            self.entries.append((key, time.perf_counter()))\n            return obj\n        else:\n            key = f\"broadcast_from/{src}/\" f\"{self.broadcast_recv_src_counter[src]}\"\n            recv_obj = pickle.loads(self.store.get(key))\n            self.broadcast_recv_src_counter[src] += 1\n            return recv_obj\n\n    def all_gather_obj(self, obj: Any) -> list[Any]:\n        \"\"\"All gather an object from all ranks.\"\"\"\n        gathered_objs = []\n        for i in range(self.world_size):\n            if i == self.rank:\n                gathered_objs.append(obj)\n                self.broadcast_obj(obj, src=self.rank)\n            else:\n                recv_obj = self.broadcast_obj(None, src=i)\n                gathered_objs.append(recv_obj)\n        return gathered_objs\n\n    def barrier(self):\n        \"\"\"A barrier to synchronize all ranks.\"\"\"\n        for i in range(self.world_size):\n            if i == self.rank:\n                self.broadcast_obj(None, 
src=self.rank)\n            else:\n                self.broadcast_obj(None, src=i)\n\n    @staticmethod\n    def create(\n        host: str,\n        port: int,\n        rank: int,\n        world_size: int,\n        data_expiration_seconds: int = 3600,\n    ) -> \"StatelessProcessGroup\":\n        \"\"\"A replacement for `torch.distributed.init_process_group` that does not\n        pollute the global state.\n\n        If processes A and B call `torch.distributed.init_process_group`\n        to form a group, and we then want to form another group with processes A, B,\n        C, and D, this is not possible in PyTorch: A and B have already\n        formed a group, and C and D cannot join it. This\n        function is a workaround for this issue.\n\n        `torch.distributed.init_process_group` is a global call, while this function\n        is a stateless call. It returns a `StatelessProcessGroup` object that can be\n        used for exchanging metadata. With this function, processes A and B\n        can call `StatelessProcessGroup.create` to form a group, and then processes A, B,\n        C, and D can call `StatelessProcessGroup.create` to form another group.\n        \"\"\"  # noqa\n        store = TCPStore(\n            host_name=host,\n            port=port,\n            world_size=world_size,\n            is_master=(rank == 0),\n        )\n\n        return StatelessProcessGroup(\n            rank=rank,\n            world_size=world_size,\n            store=store,\n            data_expiration_seconds=data_expiration_seconds,\n        )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import globally_suppress_loggers\n\nglobally_suppress_loggers()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/cli/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/cli/cli_types.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/types.py\n\nimport argparse\n\nfrom sglang.multimodal_gen.utils import FlexibleArgumentParser\n\n\nclass CLISubcommand:\n    \"\"\"Base class for CLI subcommands\"\"\"\n\n    name: str\n\n    def cmd(\n        self, args: argparse.Namespace, unknown_args: list[str] | None = None\n    ) -> None:\n        \"\"\"Execute the command with the given arguments\"\"\"\n        raise NotImplementedError\n\n    def validate(self, args: argparse.Namespace) -> None:\n        \"\"\"Validate the arguments for this command\"\"\"\n        pass\n\n    def subparser_init(\n        self, subparsers: argparse._SubParsersAction\n    ) -> FlexibleArgumentParser:\n        \"\"\"Initialize the subparser for this command\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/serve.py\n\nimport argparse\nimport dataclasses\nimport json\nimport os\nfrom typing import cast\n\nfrom sglang.multimodal_gen import DiffGenerator\nfrom sglang.multimodal_gen.configs.sample.sampling_params import (\n    SamplingParams,\n    generate_request_id,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand\nfrom sglang.multimodal_gen.runtime.entrypoints.cli.utils import (\n    RaiseNotImplementedAction,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import GenerationResult\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import (\n    MemorySnapshot,\n    PerformanceLogger,\n    RequestMetrics,\n)\nfrom sglang.multimodal_gen.utils import FlexibleArgumentParser\n\nlogger = init_logger(__name__)\n\n\ndef add_multimodal_gen_generate_args(parser: argparse.ArgumentParser):\n    \"\"\"Add the arguments for the generate command.\"\"\"\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        default=\"\",\n        required=False,\n        help=\"Read CLI options from a config JSON or YAML file. 
If provided, --model-path and --prompt are optional.\",\n    )\n    parser.add_argument(\n        \"--perf-dump-path\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Path to dump the performance metrics (JSON) for the run.\",\n    )\n\n    parser = ServerArgs.add_cli_args(parser)\n    parser = SamplingParams.add_cli_args(parser)\n\n    parser.add_argument(\n        \"--text-encoder-configs\",\n        action=RaiseNotImplementedAction,\n        help=\"JSON array of text encoder configurations (NOT YET IMPLEMENTED)\",\n    )\n\n    return parser\n\n\ndef maybe_dump_performance(\n    args: argparse.Namespace,\n    server_args,\n    prompt: str,\n    results: GenerationResult | list[GenerationResult] | None,\n):\n    \"\"\"Dump the run's performance metrics to `--perf-dump-path` if it is set.\"\"\"\n    if not (args.perf_dump_path and results):\n        return\n\n    if isinstance(results, list):\n        result = results[0] if results else None\n    else:\n        result = results\n\n    metrics_dict = result.metrics\n    if not (args.perf_dump_path and metrics_dict):\n        return\n\n    metrics = RequestMetrics(request_id=metrics_dict.get(\"request_id\"))\n    metrics.stages = metrics_dict.get(\"stages\", {})\n    metrics.steps = metrics_dict.get(\"steps\", [])\n    metrics.total_duration_ms = metrics_dict.get(\"total_duration_ms\", 0)\n\n    # restore memory snapshots from serialized dict\n    memory_snapshots_dict = metrics_dict.get(\"memory_snapshots\", {})\n    for checkpoint_name, snapshot_dict in memory_snapshots_dict.items():\n        snapshot = MemorySnapshot(\n            allocated_mb=snapshot_dict.get(\"allocated_mb\", 0.0),\n            reserved_mb=snapshot_dict.get(\"reserved_mb\", 0.0),\n            peak_allocated_mb=snapshot_dict.get(\"peak_allocated_mb\", 0.0),\n            peak_reserved_mb=snapshot_dict.get(\"peak_reserved_mb\", 0.0),\n        )\n        metrics.memory_snapshots[checkpoint_name] = snapshot\n\n    
PerformanceLogger.dump_benchmark_report(\n        file_path=args.perf_dump_path,\n        metrics=metrics,\n        meta={\n            \"prompt\": prompt,\n            \"model\": server_args.model_path,\n        },\n        tag=\"cli_generate\",\n    )\n\n\ndef generate_cmd(args: argparse.Namespace, unknown_args: list[str] | None = None):\n    \"\"\"The entry point for the generate command.\"\"\"\n    args.request_id = \"mocked_fake_id_for_offline_generate\"\n\n    server_args = ServerArgs.from_cli_args(args, unknown_args)\n\n    sampling_params_kwargs = SamplingParams.get_cli_args(args)\n    sampling_params_kwargs[\"request_id\"] = generate_request_id()\n\n    # Handle diffusers-specific kwargs passed via CLI\n    if hasattr(args, \"diffusers_kwargs\") and args.diffusers_kwargs:\n        try:\n            sampling_params_kwargs[\"diffusers_kwargs\"] = json.loads(\n                args.diffusers_kwargs\n            )\n            logger.info(\n                \"Parsed diffusers_kwargs: %s\",\n                sampling_params_kwargs[\"diffusers_kwargs\"],\n            )\n        except json.JSONDecodeError as e:\n            logger.error(\"Failed to parse --diffusers-kwargs as JSON: %s\", e)\n            raise ValueError(\n                f\"--diffusers-kwargs must be valid JSON. 
Got: {args.diffusers_kwargs}\"\n            ) from e\n\n    generator = DiffGenerator.from_pretrained(\n        model_path=server_args.model_path, server_args=server_args, local_mode=True\n    )\n\n    results = generator.generate(sampling_params_kwargs=sampling_params_kwargs)\n\n    prompt = sampling_params_kwargs.get(\"prompt\")\n    maybe_dump_performance(args, server_args, prompt, results)\n\n\nclass GenerateSubcommand(CLISubcommand):\n    \"\"\"The `generate` subcommand for the sglang-diffusion CLI\"\"\"\n\n    def __init__(self) -> None:\n        self.name = \"generate\"\n        super().__init__()\n        self.init_arg_names = self._get_init_arg_names()\n        self.generation_arg_names = self._get_generation_arg_names()\n\n    def _get_init_arg_names(self) -> list[str]:\n        \"\"\"Get names of arguments for DiffGenerator initialization\"\"\"\n        return [\"num_gpus\", \"tp_size\", \"sp_size\", \"model_path\"]\n\n    def _get_generation_arg_names(self) -> list[str]:\n        \"\"\"Get names of arguments for generate_video method\"\"\"\n        return [field.name for field in dataclasses.fields(SamplingParams)]\n\n    def cmd(\n        self, args: argparse.Namespace, unknown_args: list[str] | None = None\n    ) -> None:\n        generate_cmd(args, unknown_args)\n\n    def validate(self, args: argparse.Namespace) -> None:\n        \"\"\"Validate the arguments for this command\"\"\"\n        if args.num_gpus is not None and args.num_gpus <= 0:\n            raise ValueError(\"Number of gpus must be positive\")\n\n        if args.config and not os.path.exists(args.config):\n            raise ValueError(f\"Config file not found: {args.config}\")\n\n    def subparser_init(\n        self, subparsers: argparse._SubParsersAction\n    ) -> FlexibleArgumentParser:\n        generate_parser = subparsers.add_parser(\n            \"generate\",\n            help=\"Run inference on a model\",\n            usage=\"sglang generate (--model-path MODEL_PATH_OR_ID 
--prompt PROMPT) | --config CONFIG_FILE [OPTIONS]\",\n        )\n\n        generate_parser = add_multimodal_gen_generate_args(generate_parser)\n\n        return cast(FlexibleArgumentParser, generate_parser)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/cli/main.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/main.py\n\nfrom sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand\nfrom sglang.multimodal_gen.runtime.entrypoints.cli.generate import GenerateSubcommand\nfrom sglang.multimodal_gen.runtime.entrypoints.cli.serve import ServeSubcommand\nfrom sglang.multimodal_gen.utils import FlexibleArgumentParser\n\n\ndef generate_cmd_init() -> list[CLISubcommand]:\n    return [GenerateSubcommand(), ServeSubcommand()]\n\n\ndef cmd_init() -> list[CLISubcommand]:\n    \"\"\"Initialize all commands from separate modules\"\"\"\n    commands = []\n    commands.extend(generate_cmd_init())\n    return commands\n\n\ndef main() -> None:\n    parser = FlexibleArgumentParser(description=\"sglang-diffusion CLI\")\n    parser.add_argument(\"-v\", \"--version\", action=\"version\", version=\"0.1.0\")\n\n    subparsers = parser.add_subparsers(required=False, dest=\"subparser\")\n\n    cmds = {}\n    for cmd in cmd_init():\n        cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd)\n        cmds[cmd.name] = cmd\n    args, unknown_args = parser.parse_known_args()\n    if args.subparser in cmds:\n        cmds[args.subparser].validate(args)\n\n    if hasattr(args, \"dispatch_function\"):\n        args.dispatch_function(args, unknown_args=unknown_args)\n    else:\n        parser.print_help()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/cli/serve.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport os\nfrom typing import cast\n\nfrom sglang.multimodal_gen.apps.webui import run_sgl_diffusion_webui\nfrom sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand\nfrom sglang.multimodal_gen.runtime.launch_server import launch_server\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import FlexibleArgumentParser\n\nlogger = init_logger(__name__)\n\n\ndef add_multimodal_gen_serve_args(parser: argparse.ArgumentParser):\n    \"\"\"Add the arguments for the serve command.\"\"\"\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        default=\"\",\n        required=False,\n        help=\"Read CLI options from a config JSON or YAML file.\",\n    )\n    return ServerArgs.add_cli_args(parser)\n\n\ndef execute_serve_cmd(args: argparse.Namespace, unknown_args: list[str] | None = None):\n    \"\"\"The entry point for the serve command.\"\"\"\n    server_args = ServerArgs.from_cli_args(args, unknown_args)\n    launch_server(server_args)\n\n    if server_args.webui:\n        run_sgl_diffusion_webui(server_args)\n\n\nclass ServeSubcommand(CLISubcommand):\n    \"\"\"The `serve` subcommand for the sglang-diffusion CLI\"\"\"\n\n    def __init__(self) -> None:\n        self.name = \"serve\"\n        super().__init__()\n\n    def cmd(\n        self, args: argparse.Namespace, unknown_args: list[str] | None = None\n    ) -> None:\n        execute_serve_cmd(args, unknown_args)\n\n    def validate(self, args: argparse.Namespace) -> None:\n        \"\"\"Validate the arguments for this command\"\"\"\n        if args.config and not os.path.exists(args.config):\n            raise ValueError(f\"Config file not found: {args.config}\")\n\n    def subparser_init(\n        self, 
subparsers: argparse._SubParsersAction\n    ) -> FlexibleArgumentParser:\n        serve_parser = subparsers.add_parser(\n            \"serve\",\n            help=\"Launch the server and start FastAPI listener.\",\n            usage=\"sglang serve --model-path MODEL_PATH_OR_ID [OPTIONS]\",\n        )\n\n        serve_parser = add_multimodal_gen_serve_args(serve_parser)\n\n        return cast(FlexibleArgumentParser, serve_parser)\n\n\ndef cmd_init() -> list[CLISubcommand]:\n    return [ServeSubcommand()]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/cli/utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport os\nimport shlex\nimport subprocess\nimport sys\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass RaiseNotImplementedAction(argparse.Action):\n\n    def __call__(self, parser, namespace, values, option_string=None):\n        raise NotImplementedError(f\"The {option_string} option is not yet implemented\")\n\n\ndef launch_distributed(\n    num_gpus: int, args: list[str], master_port: int | None = None\n) -> int:\n    \"\"\"\n    Launch a distributed job with the given arguments\n\n    Args:\n        num_gpus: Number of GPUs to use\n        args: Arguments to pass to v1_sgl_diffusion_inference.py (defaults to sys.argv[1:])\n        master_port: Port for the master process (default: random)\n    \"\"\"\n\n    current_env = os.environ.copy()\n    python_executable = sys.executable\n    project_root = os.path.abspath(\n        os.path.join(os.path.dirname(__file__), \"../../../..\")\n    )\n    main_script = os.path.join(\n        project_root, \"sgl_diffusion/sample/v1_sgl_diffusion_inference.py\"\n    )\n\n    cmd = [\n        python_executable,\n        \"-m\",\n        \"torch.distributed.run\",\n        f\"--nproc_per_node={num_gpus}\",\n    ]\n\n    if master_port is not None:\n        cmd.append(f\"--master_port={master_port}\")\n\n    cmd.append(main_script)\n    cmd.extend(args)\n\n    logger.info(\"Running inference with %d GPU(s)\", num_gpus)\n    logger.info(\"Launching command: %s\", shlex.join(cmd))\n\n    current_env[\"PYTHONIOENCODING\"] = \"utf-8\"\n    process = subprocess.Popen(\n        cmd,\n        env=current_env,\n        stdout=subprocess.PIPE,\n        stderr=subprocess.STDOUT,\n        universal_newlines=True,\n        bufsize=1,\n        encoding=\"utf-8\",\n        errors=\"replace\",\n    )\n\n    if process.stdout:\n     
   for line in iter(process.stdout.readline, \"\"):\n            print(line.strip())\n\n    return process.wait()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nDiffGenerator module for sglang-diffusion.\n\nThis module provides a consolidated interface for generating images/videos using\ndiffusion models.\n\"\"\"\n\nimport dataclasses\nimport multiprocessing as mp\nimport os\nimport time\nfrom typing import Any, List, Union\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import (\n    DataType,\n    SamplingParams,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import (\n    GenerationResult,\n    ListLorasReq,\n    MergeLoraWeightsReq,\n    SetLoraReq,\n    ShutdownReq,\n    UnmergeLoraWeightsReq,\n    format_lora_message,\n    prepare_request,\n    save_outputs,\n)\nfrom sglang.multimodal_gen.runtime.launch_server import launch_server\nfrom sglang.multimodal_gen.runtime.pipelines_core import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch\nfrom sglang.multimodal_gen.runtime.scheduler_client import sync_scheduler_client\nfrom sglang.multimodal_gen.runtime.server_args import PortArgs, ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import (\n    GREEN,\n    RESET,\n    init_logger,\n    log_batch_completion,\n    log_generation_timer,\n)\n\nlogger = init_logger(__name__)\n\n# TODO: move to somewhere appropriate\ntry:\n    # Set the start method to 'spawn' to avoid CUDA errors in forked processes.\n    # This must be done at the top level of the module, before any CUDA context\n    # or other processes are initialized.\n    mp.set_start_method(\"spawn\", force=True)\nexcept RuntimeError:\n    # The start method can only be set once per program execution.\n    pass\n\n\nclass DiffGenerator:\n    \"\"\"\n    A unified class for generating images/videos using diffusion models.\n\n    This class provides a simple interface for image/video generation with rich\n    customization options, similar to popular frameworks 
like HF Diffusers.\n    \"\"\"\n\n    def __init__(\n        self,\n        server_args: ServerArgs,\n    ):\n        \"\"\"\n        Initialize the generator.\n\n        Args:\n            server_args: The inference arguments\n        \"\"\"\n        self.server_args = server_args\n        self.port_args = PortArgs.from_server_args(server_args)\n\n        # The executor is now a client to the Scheduler service\n        self.local_scheduler_process: list[mp.Process] | None = None\n        self.owns_scheduler_client: bool = False\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        local_mode: bool = True,\n        **kwargs,\n    ) -> \"DiffGenerator\":\n        \"\"\"\n        Create a DiffGenerator from a pretrained model.\n\n        Priority level: Default pipeline config < User's pipeline config < User's kwargs\n        \"\"\"\n        # If users also provide some kwargs, it will override the ServerArgs and PipelineConfig.\n\n        if (server_args := kwargs.get(\"server_args\", None)) is not None:\n            if isinstance(server_args, ServerArgs):\n                pass\n            elif isinstance(server_args, dict):\n                server_args = ServerArgs.from_kwargs(**server_args)\n        else:\n            server_args = ServerArgs.from_kwargs(**kwargs)\n\n        return cls.from_server_args(server_args, local_mode=local_mode)\n\n    @classmethod\n    def from_server_args(\n        cls, server_args: ServerArgs, local_mode: bool = True\n    ) -> \"DiffGenerator\":\n        \"\"\"\n        Create a DiffGenerator with the specified arguments.\n\n        Args:\n            server_args: The inference arguments\n\n        Returns:\n            The created DiffGenerator\n        \"\"\"\n        instance = cls(\n            server_args=server_args,\n        )\n        logger.info(f\"Local mode: {local_mode}\")\n        if local_mode:\n            instance.local_scheduler_process = instance._start_local_server_if_needed()\n        else:\n         
   # In remote mode, we just need to connect and check.\n            sync_scheduler_client.initialize(server_args)\n            instance._check_remote_scheduler()\n\n        # In both modes, this DiffGenerator instance is responsible for the client's lifecycle.\n        instance.owns_scheduler_client = True\n        return instance\n\n    def _start_local_server_if_needed(\n        self,\n    ) -> list[mp.Process]:\n        \"\"\"Check if a local server is running; if not, start it and return the process handles.\"\"\"\n        # First, we need a client to test the server. Initialize it temporarily.\n        sync_scheduler_client.initialize(self.server_args)\n\n        processes = launch_server(self.server_args, launch_http_server=False)\n\n        return processes\n\n    def _check_remote_scheduler(self):\n        \"\"\"Check if the remote scheduler is accessible.\"\"\"\n        if not sync_scheduler_client.ping():\n            raise ConnectionError(\n                f\"Could not connect to remote scheduler at \"\n                f\"{self.server_args.scheduler_endpoint} with `local mode` as False. 
\"\n                \"Please ensure the server is running.\"\n            )\n        logger.info(\n            f\"Successfully connected to remote scheduler at \"\n            f\"{self.server_args.scheduler_endpoint}.\"\n        )\n\n    @staticmethod\n    def _resolve_image_paths_per_prompt(\n        prompts: list[str], image_paths: str | list[str] | None\n    ) -> list[str | list[str] | None]:\n        if len(prompts) <= 1:\n            return [image_paths]\n\n        if not isinstance(image_paths, list) or len(image_paths) <= 1:\n            return [image_paths for _ in prompts]\n\n        if len(image_paths) != len(prompts):\n            raise ValueError(\n                \"When using multiple prompts with multiple input images, \"\n                \"provide either one shared image or exactly one image per prompt.\"\n            )\n\n        return [[image_path] for image_path in image_paths]\n\n    def generate(\n        self,\n        sampling_params_kwargs: dict | None = None,\n    ) -> GenerationResult | list[GenerationResult] | None:\n        \"\"\"Generate image(s)/video(s) based on the given prompt(s).\n\n        Returns a single GenerationResult for a single prompt, a list for\n        multiple prompts, or None when every request failed.\n        \"\"\"\n        # 1. prepare requests\n        prompts = self._resolve_prompts(sampling_params_kwargs.get(\"prompt\"))\n        user_output_file_name = sampling_params_kwargs.get(\"output_file_name\")\n\n        if len(prompts) > 1 and user_output_file_name is not None:\n            raise ValueError(\n                \"Cannot use multiple prompts with a fixed output_file_name. 
\"\n                \"Either remove --output-file-name or use a single prompt.\"\n            )\n\n        sampling_params_orig = SamplingParams.from_user_sampling_params_args(\n            self.server_args.model_path,\n            server_args=self.server_args,\n            **sampling_params_kwargs,\n        )\n\n        requests: list[Req] = []\n        image_paths_per_prompt = self._resolve_image_paths_per_prompt(\n            prompts, sampling_params_orig.image_path\n        )\n\n        for i, p in enumerate(prompts):\n            sampling_params = dataclasses.replace(\n                sampling_params_orig,\n                prompt=p,\n                output_file_name=user_output_file_name,\n                image_path=image_paths_per_prompt[i],\n            )\n            sampling_params._set_output_file_name()\n            req = prepare_request(\n                server_args=self.server_args,\n                sampling_params=sampling_params,\n            )\n            requests.append(req)\n\n        results: list[GenerationResult] = []\n        total_start_time = time.perf_counter()\n\n        # 2. 
send requests to scheduler one at a time\n        # TODO: send batch when supported\n        for request_idx, req in enumerate(requests):\n            try:\n                with log_generation_timer(\n                    logger, req.prompt, request_idx + 1, len(requests)\n                ) as timer:\n                    output_batch = self._send_to_scheduler_and_wait_for_response([req])\n                    if output_batch.error:\n                        raise Exception(f\"{output_batch.error}\")\n\n                    if (\n                        output_batch.output is None\n                        and output_batch.output_file_paths is None\n                    ):\n                        logger.error(\n                            \"Received empty output from scheduler for prompt %d\",\n                            request_idx + 1,\n                        )\n                        continue\n\n                    common = dict(\n                        prompt=req.prompt,\n                        size=(req.height, req.width, req.num_frames),\n                        generation_time=timer.duration,\n                        peak_memory_mb=output_batch.peak_memory_mb,\n                        metrics=(\n                            output_batch.metrics.to_dict()\n                            if output_batch.metrics\n                            else {}\n                        ),\n                        trajectory_latents=output_batch.trajectory_latents,\n                        trajectory_timesteps=output_batch.trajectory_timesteps,\n                        trajectory_decoded=output_batch.trajectory_decoded,\n                    )\n\n                    if req.save_output and req.return_file_paths_only:\n                        for idx, path in enumerate(output_batch.output_file_paths):\n                            results.append(\n                                GenerationResult(\n                                    **common,\n                                    
prompt_index=idx,\n                                    output_file_path=path,\n                                )\n                            )\n                        continue\n\n                    if req.data_type == DataType.MESH:\n                        for output_idx, sample in enumerate(\n                            output_batch.output_file_paths\n                        ):\n                            results.append(\n                                GenerationResult(\n                                    **common,\n                                    prompt_index=output_idx,\n                                    output_file_path=sample,\n                                )\n                            )\n                        continue\n\n                    samples_out: list[Any] = []\n                    audios_out: list[Any] = []\n                    frames_out: list[Any] = []\n                    num_outputs = len(output_batch.output)\n                    save_outputs(\n                        output_batch.output,\n                        req.data_type,\n                        req.fps,\n                        req.save_output,\n                        lambda idx: req.output_file_path(num_outputs, idx),\n                        audio=output_batch.audio,\n                        audio_sample_rate=output_batch.audio_sample_rate,\n                        samples_out=samples_out,\n                        audios_out=audios_out,\n                        frames_out=frames_out,\n                        output_compression=req.output_compression,\n                        enable_frame_interpolation=req.enable_frame_interpolation,\n                        frame_interpolation_exp=req.frame_interpolation_exp,\n                        frame_interpolation_scale=req.frame_interpolation_scale,\n                        frame_interpolation_model_path=req.frame_interpolation_model_path,\n                        enable_upscaling=req.enable_upscaling,\n                        
upscaling_model_path=req.upscaling_model_path,\n                        upscaling_scale=req.upscaling_scale,\n                    )\n\n                    for idx in range(len(samples_out)):\n                        results.append(\n                            GenerationResult(\n                                **common,\n                                samples=samples_out[idx],\n                                frames=frames_out[idx],\n                                audio=audios_out[idx],\n                                prompt_index=idx,\n                                output_file_path=req.output_file_path(num_outputs, idx),\n                            )\n                        )\n            except Exception as e:\n                logger.error(\n                    \"Generation failed for prompt %d/%d: %s\",\n                    request_idx + 1,\n                    len(requests),\n                    e,\n                    exc_info=True,\n                )\n                continue\n\n        total_gen_time = time.perf_counter() - total_start_time\n        log_batch_completion(logger, len(results), total_gen_time)\n        self._log_summary(results)\n\n        if not results:\n            return None\n        return results[0] if len(results) == 1 else results\n\n    def _resolve_prompts(self, prompt: str | list[str] | None) -> list[str]:\n        \"\"\"Collect prompts from the argument or from a prompt file.\"\"\"\n        if self.server_args.prompt_file_path is not None:\n            path = self.server_args.prompt_file_path\n            if not os.path.exists(path):\n                raise FileNotFoundError(f\"Prompt text file not found: {path}\")\n            with open(path, encoding=\"utf-8\") as f:\n                prompts = [line.strip() for line in f if line.strip()]\n            if not prompts:\n                raise ValueError(f\"No prompts found in file: {path}\")\n            logger.info(\"Found %d prompts in %s\", len(prompts), path)\n            
return prompts\n\n        if prompt is None:\n            return [\" \"]\n        if isinstance(prompt, str):\n            return [prompt]\n        return list(prompt)\n\n    def _log_summary(self, results: list[GenerationResult]) -> None:\n        if not results:\n            return\n        if self.server_args.warmup:\n            total_duration_ms = results[0].metrics.get(\"total_duration_ms\", 0)\n            logger.info(\n                f\"Warmed-up request processed in {GREEN}%.2f{RESET} seconds (with warmup excluded)\",\n                total_duration_ms / 1000.0,\n            )\n\n        peak_memories = [r.peak_memory_mb for r in results if r.peak_memory_mb]\n        if peak_memories:\n            logger.info(\n                f\"Memory usage - Max peak: {max(peak_memories):.2f} MB, \"\n                f\"Avg peak: {sum(peak_memories) / len(peak_memories):.2f} MB\"\n            )\n\n    def _send_to_scheduler_and_wait_for_response(self, batch: list[Req]) -> OutputBatch:\n        \"\"\"\n        Sends a request to the scheduler and waits for a response.\n        \"\"\"\n        return sync_scheduler_client.forward(batch)\n\n    # LoRA\n    def _send_lora_request(self, req: Any, success_msg: str, failure_msg: str):\n        response = sync_scheduler_client.forward(req)\n        if response.error is None:\n            logger.info(success_msg)\n            return response\n        else:\n            error_msg = response.error\n            raise RuntimeError(f\"{failure_msg}: {error_msg}\")\n\n    def set_lora(\n        self,\n        lora_nickname: Union[str, List[str]],\n        lora_path: Union[str, None, List[Union[str, None]]] = None,\n        target: Union[str, List[str]] = \"all\",\n        strength: Union[float, List[float]] = 1.0,\n    ) -> None:\n        \"\"\"\n        Set LoRA adapter(s) for the specified transformer(s).\n        Supports both single LoRA (backward compatible) and multiple LoRA adapters.\n\n        Args:\n            lora_nickname: 
The nickname(s) of the adapter(s). Can be a string or a list of strings.\n            lora_path: Path(s) to the LoRA adapter(s). Can be a string, None, or a list of strings/None.\n            target: Which transformer(s) to apply the LoRA to. Can be a string or a list of strings.\n                Valid values:\n                - \"all\": Apply to all transformers (default)\n                - \"transformer\": Apply only to the primary transformer (high noise for Wan2.2)\n                - \"transformer_2\": Apply only to transformer_2 (low noise for Wan2.2)\n                - \"critic\": Apply only to the critic model\n            strength: LoRA strength(s) for merge, default 1.0. Can be a float or a list of floats.\n        \"\"\"\n        req = SetLoraReq(\n            lora_nickname=lora_nickname,\n            lora_path=lora_path,\n            target=target,\n            strength=strength,\n        )\n        nickname_str, target_str, strength_str = format_lora_message(\n            lora_nickname, target, strength\n        )\n\n        self._send_lora_request(\n            req,\n            f\"Successfully set LoRA adapter(s): {nickname_str} (target: {target_str}, strength: {strength_str})\",\n            \"Failed to set LoRA adapter\",\n        )\n\n    def unmerge_lora_weights(self, target: str = \"all\") -> None:\n        \"\"\"\n        Unmerge LoRA weights from the base model.\n\n        Args:\n            target: Which transformer(s) to unmerge.\n        \"\"\"\n        req = UnmergeLoraWeightsReq(target=target)\n        self._send_lora_request(\n            req,\n            f\"Successfully unmerged LoRA weights (target: {target})\",\n            \"Failed to unmerge LoRA weights\",\n        )\n\n    def merge_lora_weights(self, target: str = \"all\", strength: float = 1.0) -> None:\n        \"\"\"\n        Merge LoRA weights into the base model.\n\n        Args:\n            target: Which transformer(s) to merge.\n            strength: LoRA strength for 
merge, default 1.0.\n        \"\"\"\n        req = MergeLoraWeightsReq(target=target, strength=strength)\n        self._send_lora_request(\n            req,\n            f\"Successfully merged LoRA weights (target: {target}, strength: {strength})\",\n            \"Failed to merge LoRA weights\",\n        )\n\n    def list_loras(self) -> dict:\n        \"\"\"List loaded LoRA adapters and current application status per module.\"\"\"\n        output = self._send_lora_request(\n            req=ListLorasReq(),\n            success_msg=\"Successfully listed LoRA adapters\",\n            failure_msg=\"Failed to list LoRA adapters\",\n        )\n        # _send_lora_request already raises on error, so output.error is always None here\n        return output.output or {}\n\n    def _ensure_lora_state(\n        self,\n        lora_path: str | None,\n        lora_nickname: str | None = None,\n        merge_lora: bool = True,\n    ) -> None:\n        \"\"\"\n        Ensure the LoRA state matches the desired configuration.\n\n        Note: This method does not cache client-side state. 
The server handles\n        idempotent operations, so redundant calls are safe but may have minor overhead.\n        \"\"\"\n        if lora_path is None:\n            # Unmerge all LoRA weights when no lora_path is provided\n            self.unmerge_lora_weights()\n            return\n\n        lora_nickname = lora_nickname or self.server_args.lora_nickname\n\n        # Set the LoRA adapter (server handles idempotent logic)\n        self.set_lora(lora_nickname, lora_path)\n\n        # Merge or unmerge based on the merge_lora flag\n        if merge_lora:\n            self.merge_lora_weights()\n        else:\n            self.unmerge_lora_weights()\n\n    def generate_with_lora(\n        self,\n        prompt: str | list[str] | None = None,\n        sampling_params: SamplingParams | None = None,\n        *,\n        lora_path: str | None = None,\n        lora_nickname: str | None = None,\n        merge_lora: bool = True,\n        **kwargs,\n    ):\n        self._ensure_lora_state(\n            lora_path=lora_path, lora_nickname=lora_nickname, merge_lora=merge_lora\n        )\n        return self.generate(\n            sampling_params_kwargs=dict(\n                prompt=prompt,\n                sampling_params=sampling_params,\n                **kwargs,\n            )\n        )\n\n    def shutdown(self):\n        \"\"\"\n        Shutdown the generator.\n        If in local mode, it also shuts down the scheduler server.\n        \"\"\"\n        # sends the shutdown command to the server\n        if self.local_scheduler_process and self.owns_scheduler_client:\n            try:\n                sync_scheduler_client.forward(ShutdownReq())\n            except Exception:\n                pass\n\n        if self.local_scheduler_process:\n            for process in self.local_scheduler_process:\n                process.join(timeout=10)\n                if process.is_alive():\n                    logger.warning(\n                        f\"Local worker {process.name} did 
not terminate gracefully, forcing.\"\n                    )\n                    process.terminate()\n            self.local_scheduler_process = None\n\n        if self.owns_scheduler_client:\n            sync_scheduler_client.close()\n            self.owns_scheduler_client = False\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        self.shutdown()\n\n    def __del__(self):\n        if self.owns_scheduler_client:\n            logger.warning(\n                \"Generator was garbage collected without being shut down. \"\n                \"Attempting to shut down the local server and client.\"\n            )\n            self.shutdown()\n        elif self.local_scheduler_process:\n            logger.warning(\n                \"Generator was garbage collected without being shut down. \"\n                \"Attempting to shut down the local server.\"\n            )\n            self.shutdown()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/http_server.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nimport asyncio\nimport base64\nimport os\nimport uuid\nfrom contextlib import asynccontextmanager\nfrom typing import TYPE_CHECKING\n\nimport torch\nfrom fastapi import APIRouter, FastAPI, Request\nfrom fastapi.responses import ORJSONResponse\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\nfrom sglang.multimodal_gen.runtime.entrypoints.openai import image_api, video_api\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.protocol import (\n    VertexGenerateReqInput,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.utils import build_sampling_params\nfrom sglang.multimodal_gen.runtime.entrypoints.post_training import weights_api\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import (\n    prepare_request,\n    save_outputs,\n)\nfrom sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.version import __version__\n\nif TYPE_CHECKING:\n    from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\n\nlogger = init_logger(__name__)\n\nDEFAULT_SEED = 1024\nVERTEX_ROUTE = os.environ.get(\"AIP_PREDICT_ROUTE\", \"/vertex_generate\")\n\n\n@asynccontextmanager\nasync def lifespan(app: FastAPI):\n    from sglang.multimodal_gen.runtime.scheduler_client import (\n        async_scheduler_client,\n        run_zeromq_broker,\n    )\n\n    # 1. Initialize the singleton client that connects to the backend Scheduler\n    server_args = app.state.server_args\n    async_scheduler_client.initialize(server_args)\n\n    # 2. 
Start the ZMQ Broker in the background to handle offline requests\n    broker_task = asyncio.create_task(run_zeromq_broker(server_args))\n\n    yield\n\n    # On shutdown\n    logger.info(\"FastAPI app is shutting down...\")\n    broker_task.cancel()\n    async_scheduler_client.close()\n\n\n# Health router\nhealth_router = APIRouter()\n\n\n@health_router.get(\"/health\")\nasync def health():\n    return {\"status\": \"ok\"}\n\n\n@health_router.get(\"/models\", deprecated=True)\nasync def get_models(request: Request):\n    \"\"\"\n    Get information about the model served by this server.\n\n    .. deprecated::\n        Use /v1/models instead for OpenAI-compatible model discovery.\n        This endpoint will be removed in a future version.\n    \"\"\"\n    from sglang.multimodal_gen.registry import get_model_info\n\n    server_args: ServerArgs = request.app.state.server_args\n    model_info = get_model_info(server_args.model_path, model_id=server_args.model_id)\n\n    response = {\n        \"model_path\": server_args.model_path,\n        \"num_gpus\": server_args.num_gpus,\n        \"task_type\": server_args.pipeline_config.task_type.name,\n        \"dit_precision\": server_args.pipeline_config.dit_precision,\n        \"vae_precision\": server_args.pipeline_config.vae_precision,\n    }\n\n    if model_info:\n        response[\"pipeline_name\"] = model_info.pipeline_cls.pipeline_name\n        response[\"pipeline_class\"] = model_info.pipeline_cls.__name__\n\n    return response\n\n\n@health_router.get(\"/server_info\")\nasync def server_info_endpoint(request: Request):\n    \"\"\"Get server information.\n\n    Returns fields compatible with the LLM engine's /server_info so that\n    the model gateway can discover diffusion workers.\n    \"\"\"\n    server_args: ServerArgs = request.app.state.server_args\n\n    return {\n        \"model_path\": server_args.model_path,\n        \"served_model_name\": server_args.model_id or server_args.model_path,\n        \"tp_size\": 
server_args.tp_size,\n        \"dp_size\": server_args.dp_size,\n        \"version\": __version__,\n    }\n\n\n@health_router.get(\"/model_info\")\nasync def model_info_endpoint(request: Request):\n    \"\"\"Get model information.\n\n    Returns fields compatible with the LLM engine's /model_info so that\n    the model gateway can detect capabilities for diffusion workers.\n    \"\"\"\n    from sglang.multimodal_gen.registry import get_model_info\n\n    server_args: ServerArgs = request.app.state.server_args\n    task_type = server_args.pipeline_config.task_type\n\n    try:\n        registry_info = get_model_info(\n            server_args.model_path,\n            backend=server_args.backend,\n            model_id=server_args.model_id,\n        )\n    except Exception:\n        logger.warning(\"Failed to resolve model info from registry\", exc_info=True)\n        registry_info = None\n\n    return {\n        # Fields consumed by the model gateway for worker discovery\n        \"model_path\": server_args.model_path,\n        \"is_generation\": True,\n        \"model_type\": \"diffusion\",\n        \"architectures\": (\n            [registry_info.pipeline_cls.__name__] if registry_info else None\n        ),\n        # Fields matching the LLM engine's /model_info shape\n        \"has_image_understanding\": task_type.accepts_image_input(),\n        \"has_audio_understanding\": False,\n        # Diffusion-specific fields\n        \"task_type\": task_type.name,\n        \"is_image_gen\": task_type.is_image_gen(),\n    }\n\n\n@health_router.get(\"/health_generate\")\nasync def health_generate():\n    # TODO : health generate endpoint\n    return {\"status\": \"ok\"}\n\n\ndef make_serializable(obj):\n    \"\"\"Recursively converts Tensors to None for JSON serialization.\"\"\"\n    if isinstance(obj, torch.Tensor):\n        return None\n    if isinstance(obj, dict):\n        return {k: make_serializable(v) for k, v in obj.items()}\n    if isinstance(obj, list):\n        
return [make_serializable(v) for v in obj]\n    return obj\n\n\ndef encode_video_to_base64(file_path: str):\n    if not os.path.exists(file_path):\n        return None\n    with open(file_path, \"rb\") as f:\n        return base64.b64encode(f.read()).decode(\"utf-8\")\n\n\nasync def forward_to_scheduler(\n    req_obj: \"Req\",\n    sp: SamplingParams,\n):\n    \"\"\"Forwards request to scheduler and processes the result.\"\"\"\n    try:\n        response = await async_scheduler_client.forward(req_obj)\n        if response.output is None and response.output_file_paths is None:\n            raise RuntimeError(\"Model generation returned no output.\")\n\n        if response.output_file_paths:\n            output_file_path = response.output_file_paths[0]\n        else:\n            output_file_path = sp.output_file_path()\n            save_outputs(\n                [response.output[0]],\n                sp.data_type,\n                sp.fps,\n                True,\n                lambda _idx: output_file_path,\n                audio=response.audio,\n                audio_sample_rate=response.audio_sample_rate,\n                enable_frame_interpolation=sp.enable_frame_interpolation,\n                frame_interpolation_exp=sp.frame_interpolation_exp,\n                frame_interpolation_scale=sp.frame_interpolation_scale,\n                frame_interpolation_model_path=sp.frame_interpolation_model_path,\n                enable_upscaling=sp.enable_upscaling,\n                upscaling_model_path=sp.upscaling_model_path,\n                upscaling_scale=sp.upscaling_scale,\n            )\n\n        if hasattr(response, \"model_dump\"):\n            data = response.model_dump()\n        else:\n            data = response if isinstance(response, dict) else vars(response)\n\n        if output_file_path:\n            logger.info(\"Processing output file: %s\", output_file_path)\n            b64_video = encode_video_to_base64(output_file_path)\n\n            if b64_video:\n 
               data[\"output\"] = b64_video\n                data.pop(\"video_data\", None)\n                data.pop(\"video_tensor\", None)\n\n        return make_serializable(data)\n\n    except Exception as e:\n        logger.error(\"Error during generation: %s\", e, exc_info=True)\n        return {\"error\": str(e)}\n\n\nvertex_router = APIRouter()\n\n\n@vertex_router.post(VERTEX_ROUTE)\nasync def vertex_generate(vertex_req: VertexGenerateReqInput):\n    if not vertex_req.instances:\n        return ORJSONResponse({\"predictions\": []})\n\n    server_args = get_global_server_args()\n    params = vertex_req.parameters or {}\n\n    futures = []\n\n    for inst in vertex_req.instances:\n        rid = f\"vertex_{uuid.uuid4()}\"\n\n        sp = build_sampling_params(\n            rid,\n            prompt=inst.get(\"prompt\") or inst.get(\"text\"),\n            image_path=inst.get(\"image\") or inst.get(\"image_url\"),\n            seed=params.get(\"seed\", DEFAULT_SEED),\n            num_frames=params.get(\"num_frames\"),\n            fps=params.get(\"fps\"),\n            width=params.get(\"width\"),\n            height=params.get(\"height\"),\n            guidance_scale=params.get(\"guidance_scale\"),\n            save_output=params.get(\"save_output\"),\n        )\n\n        backend_req = prepare_request(server_args, sampling_params=sp)\n        futures.append(forward_to_scheduler(backend_req, sp))\n\n    results = await asyncio.gather(*futures)\n\n    return ORJSONResponse({\"predictions\": results})\n\n\ndef create_app(server_args: ServerArgs):\n    \"\"\"\n    Create and configure the FastAPI application instance.\n    \"\"\"\n    app = FastAPI(lifespan=lifespan)\n\n    app.include_router(health_router)\n    app.include_router(vertex_router)\n\n    from sglang.multimodal_gen.runtime.entrypoints.openai import common_api, mesh_api\n\n    app.include_router(common_api.router)\n    app.include_router(image_api.router)\n    app.include_router(video_api.router)\n    
app.include_router(mesh_api.router)\n    app.include_router(weights_api.router)\n\n    app.state.server_args = server_args\n    return app\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/openai/common_api.py",
    "content": "import time\nfrom typing import Any, List, Optional, Union\n\nfrom fastapi import APIRouter, Body, HTTPException\nfrom fastapi.responses import ORJSONResponse\nfrom pydantic import BaseModel, Field\n\nfrom sglang.multimodal_gen.registry import get_model_info\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import (\n    ListLorasReq,\n    MergeLoraWeightsReq,\n    SetLoraReq,\n    UnmergeLoraWeightsReq,\n    format_lora_message,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch\nfrom sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client\nfrom sglang.multimodal_gen.runtime.server_args import get_global_server_args\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nrouter = APIRouter(prefix=\"/v1\")\nlogger = init_logger(__name__)\n\n\nclass ModelCard(BaseModel):\n    \"\"\"Model cards.\"\"\"\n\n    id: str\n    object: str = \"model\"\n    created: int = Field(default_factory=lambda: int(time.time()))\n    owned_by: str = \"sglang\"\n    root: Optional[str] = None\n    parent: Optional[str] = None\n    max_model_len: Optional[int] = None\n\n\nclass DiffusionModelCard(ModelCard):\n    \"\"\"Extended ModelCard with diffusion-specific fields.\"\"\"\n\n    num_gpus: Optional[int] = None\n    task_type: Optional[str] = None\n    dit_precision: Optional[str] = None\n    vae_precision: Optional[str] = None\n    pipeline_name: Optional[str] = None\n    pipeline_class: Optional[str] = None\n\n\nasync def _handle_lora_request(req: Any, success_msg: str, failure_msg: str):\n    try:\n        output: OutputBatch = await async_scheduler_client.forward(req)\n        if output.error is None:\n            return {\"status\": \"ok\", \"message\": success_msg}\n        else:\n            error_msg = output.error\n            raise HTTPException(status_code=500, detail=f\"{failure_msg}: {error_msg}\")\n    except Exception as e:\n        if isinstance(e, HTTPException):\n       
     raise\n        logger.error(f\"Error during '{failure_msg}': {e}\", exc_info=True)\n        raise HTTPException(status_code=500, detail=str(e))\n\n\n@router.post(\"/set_lora\")\nasync def set_lora(\n    lora_nickname: Union[str, List[str]] = Body(..., embed=True),\n    lora_path: Optional[Union[str, List[Optional[str]]]] = Body(None, embed=True),\n    target: Union[str, List[str]] = Body(\"all\", embed=True),\n    strength: Union[float, List[float]] = Body(1.0, embed=True),\n):\n    \"\"\"\n    Set LoRA adapter(s) for the specified transformer(s).\n    Supports both single LoRA (backward compatible) and multiple LoRA adapters.\n\n    Args:\n        lora_nickname: The nickname(s) of the adapter(s). Can be a string or a list of strings.\n        lora_path: Path(s) to the LoRA adapter(s) (local path or HF repo id).\n            Can be a string, None, or a list of strings/None. Must match the length of lora_nickname.\n        target: Which transformer(s) to apply the LoRA to. Can be a string or a list of strings.\n            If a list, must match the length of lora_nickname. Valid values:\n            - \"all\": Apply to all transformers (default)\n            - \"transformer\": Apply only to the primary transformer (high noise for Wan2.2)\n            - \"transformer_2\": Apply only to transformer_2 (low noise for Wan2.2)\n            - \"critic\": Apply only to the critic model\n        strength: LoRA strength(s) for merge, default 1.0. Can be a float or a list of floats.\n            If a list, must match the length of lora_nickname. 
Values < 1.0 reduce the effect,\n            values > 1.0 amplify the effect.\n    \"\"\"\n    req = SetLoraReq(\n        lora_nickname=lora_nickname,\n        lora_path=lora_path,\n        target=target,\n        strength=strength,\n    )\n    nickname_str, target_str, strength_str = format_lora_message(\n        lora_nickname, target, strength\n    )\n\n    return await _handle_lora_request(\n        req,\n        f\"Successfully set LoRA adapter(s): {nickname_str} (target: {target_str}, strength: {strength_str})\",\n        \"Failed to set LoRA adapter\",\n    )\n\n\n@router.post(\"/merge_lora_weights\")\nasync def merge_lora_weights(\n    target: str = Body(\"all\", embed=True),\n    strength: float = Body(1.0, embed=True),\n):\n    \"\"\"\n    Merge LoRA weights into the base model.\n\n    Args:\n        target: Which transformer(s) to merge. One of \"all\", \"transformer\",\n                \"transformer_2\", \"critic\".\n        strength: LoRA strength for merge, default 1.0. Values < 1.0 reduce the effect,\n            values > 1.0 amplify the effect.\n    \"\"\"\n    req = MergeLoraWeightsReq(target=target, strength=strength)\n    return await _handle_lora_request(\n        req,\n        f\"Successfully merged LoRA weights (target: {target}, strength: {strength})\",\n        \"Failed to merge LoRA weights\",\n    )\n\n\n@router.post(\"/unmerge_lora_weights\")\nasync def unmerge_lora_weights(\n    target: str = Body(\"all\", embed=True),\n):\n    \"\"\"\n    Unmerge LoRA weights from the base model.\n\n    Args:\n        target: Which transformer(s) to unmerge. 
One of \"all\", \"transformer\",\n                \"transformer_2\", \"critic\".\n    \"\"\"\n    req = UnmergeLoraWeightsReq(target=target)\n    return await _handle_lora_request(\n        req,\n        f\"Successfully unmerged LoRA weights (target: {target})\",\n        \"Failed to unmerge LoRA weights\",\n    )\n\n\n@router.get(\"/model_info\")\nasync def model_info():\n    \"\"\"Get the model information.\"\"\"\n    server_args = get_global_server_args()\n    if not server_args:\n        raise HTTPException(status_code=500, detail=\"Server args not initialized\")\n\n    result = {\n        \"model_path\": server_args.model_path,\n    }\n    return result\n\n\n@router.get(\"/list_loras\")\nasync def list_loras():\n    \"\"\"List loaded LoRA adapters and current application status per module.\"\"\"\n    try:\n        req = ListLorasReq()\n        output: OutputBatch = await async_scheduler_client.forward(req)\n        if output.error is None:\n            return output.output or {}\n        else:\n            raise HTTPException(status_code=500, detail=output.error)\n    except Exception as e:\n        if isinstance(e, HTTPException):\n            raise\n        logger.error(f\"Error during 'list_loras': {e}\", exc_info=True)\n        raise HTTPException(status_code=500, detail=str(e))\n\n\n@router.get(\"/models\", response_class=ORJSONResponse)\nasync def available_models():\n    \"\"\"Show available models. 
OpenAI-compatible endpoint with extended diffusion info.\"\"\"\n    server_args = get_global_server_args()\n    if not server_args:\n        raise HTTPException(status_code=500, detail=\"Server args not initialized\")\n\n    model_info = get_model_info(\n        server_args.model_path,\n        backend=server_args.backend,\n        model_id=server_args.model_id,\n    )\n\n    card_kwargs = {\n        \"id\": server_args.model_path,\n        \"root\": server_args.model_path,\n        # Extended diffusion-specific fields\n        \"num_gpus\": server_args.num_gpus,\n        \"task_type\": server_args.pipeline_config.task_type.name,\n        \"dit_precision\": server_args.pipeline_config.dit_precision,\n        \"vae_precision\": server_args.pipeline_config.vae_precision,\n    }\n\n    if model_info:\n        card_kwargs[\"pipeline_name\"] = model_info.pipeline_cls.pipeline_name\n        card_kwargs[\"pipeline_class\"] = model_info.pipeline_cls.__name__\n\n    model_card = DiffusionModelCard(**card_kwargs)\n\n    # Return dict directly to preserve extended fields (ModelList strips them)\n    return {\"object\": \"list\", \"data\": [model_card.model_dump()]}\n\n\n@router.get(\"/models/{model:path}\", response_class=ORJSONResponse)\nasync def retrieve_model(model: str):\n    \"\"\"Retrieve a model instance. 
OpenAI-compatible endpoint with extended diffusion info.\"\"\"\n    server_args = get_global_server_args()\n    if not server_args:\n        raise HTTPException(status_code=500, detail=\"Server args not initialized\")\n\n    if model != server_args.model_path:\n        return ORJSONResponse(\n            status_code=404,\n            content={\n                \"error\": {\n                    \"message\": f\"The model '{model}' does not exist\",\n                    \"type\": \"invalid_request_error\",\n                    \"param\": \"model\",\n                    \"code\": \"model_not_found\",\n                }\n            },\n        )\n\n    model_info = get_model_info(\n        server_args.model_path,\n        backend=server_args.backend,\n        model_id=server_args.model_id,\n    )\n\n    card_kwargs = {\n        \"id\": model,\n        \"root\": model,\n        \"num_gpus\": server_args.num_gpus,\n        \"task_type\": server_args.pipeline_config.task_type.name,\n        \"dit_precision\": server_args.pipeline_config.dit_precision,\n        \"vae_precision\": server_args.pipeline_config.vae_precision,\n    }\n\n    if model_info:\n        card_kwargs[\"pipeline_name\"] = model_info.pipeline_cls.pipeline_name\n        card_kwargs[\"pipeline_class\"] = model_info.pipeline_cls.__name__\n\n    # Return dict to preserve extended fields\n    return DiffusionModelCard(**card_kwargs).model_dump()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nimport base64\nimport contextlib\nimport os\nimport time\nfrom typing import List, Optional\n\nfrom fastapi import APIRouter, File, Form, HTTPException, Path, Query, UploadFile\nfrom fastapi.responses import FileResponse\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import generate_request_id\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.protocol import (\n    ImageGenerationsRequest,\n    ImageResponse,\n    ImageResponseData,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.storage import cloud_storage\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.stores import IMAGE_STORE\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.utils import (\n    add_common_data_to_response,\n    build_sampling_params,\n    choose_output_image_ext,\n    merge_image_input_list,\n    process_generation_batch,\n    save_image_to_path,\n    temp_dir_if_disabled,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch\nfrom sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client\nfrom sglang.multimodal_gen.runtime.server_args import get_global_server_args\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nrouter = APIRouter(prefix=\"/v1/images\", tags=[\"images\"])\nlogger = init_logger(__name__)\n\n\ndef _read_b64_for_paths(paths: list[str]) -> list[str]:\n    \"\"\"Read and base64-encode each file. 
Must be called before cloud upload deletes them.\"\"\"\n    result = []\n    for path in paths:\n        with open(path, \"rb\") as f:\n            result.append(base64.b64encode(f.read()).decode(\"utf-8\"))\n    return result\n\n\ndef _build_image_response_kwargs(\n    save_file_path_list: list[str],\n    resp_format: str,\n    prompt: str,\n    request_id: str,\n    result: OutputBatch,\n    *,\n    b64_list: list[str] | None = None,\n    cloud_url: str | None = None,\n    fallback_url: str | None = None,\n    is_persistent: bool = True,\n) -> dict:\n    \"\"\"Build ImageResponse data list.\n\n    For b64_json: uses pre-read b64_list (call _read_b64_for_paths first).\n    For url: uses cloud_url or fallback_url.\n    file_path is omitted when is_persistent=False to avoid exposing stale temp paths.\n    \"\"\"\n    ret = None\n    if resp_format == \"b64_json\":\n        if not b64_list:\n            raise ValueError(\"b64_list required for b64_json response_format\")\n        data = [\n            ImageResponseData(\n                b64_json=b64,\n                revised_prompt=prompt,\n                file_path=os.path.abspath(path) if is_persistent else None,\n            )\n            for b64, path in zip(b64_list, save_file_path_list)\n        ]\n        ret = {\"data\": data}\n    elif resp_format == \"url\":\n        url = cloud_url or fallback_url\n        if not url:\n            raise HTTPException(\n                status_code=400,\n                detail=\"response_format='url' requires cloud storage to be configured.\",\n            )\n        ret = {\n            \"data\": [\n                ImageResponseData(\n                    url=url,\n                    revised_prompt=prompt,\n                    file_path=(\n                        os.path.abspath(save_file_path_list[0])\n                        if is_persistent\n                        else None\n                    ),\n                )\n            ],\n        }\n    else:\n        raise 
HTTPException(\n            status_code=400, detail=f\"response_format={resp_format} is not supported\"\n        )\n\n    ret = add_common_data_to_response(ret, request_id=request_id, result=result)\n\n    return ret\n\n\n@router.post(\"/generations\", response_model=ImageResponse)\nasync def generations(\n    request: ImageGenerationsRequest,\n):\n    request_id = generate_request_id()\n    server_args = get_global_server_args()\n    ext = choose_output_image_ext(request.output_format, request.background)\n\n    with temp_dir_if_disabled(server_args.output_path) as output_dir:\n        sampling = build_sampling_params(\n            request_id,\n            prompt=request.prompt,\n            size=request.size,\n            width=request.width,\n            height=request.height,\n            num_outputs_per_prompt=max(1, min(int(request.n or 1), 10)),\n            output_file_name=f\"{request_id}.{ext}\",\n            output_path=output_dir,\n            seed=request.seed,\n            generator_device=request.generator_device,\n            num_inference_steps=request.num_inference_steps,\n            guidance_scale=request.guidance_scale,\n            true_cfg_scale=request.true_cfg_scale,\n            negative_prompt=request.negative_prompt,\n            enable_teacache=request.enable_teacache,\n            output_compression=request.output_compression,\n            output_quality=request.output_quality,\n            enable_upscaling=request.enable_upscaling,\n            upscaling_model_path=request.upscaling_model_path,\n            upscaling_scale=request.upscaling_scale,\n        )\n        batch = prepare_request(\n            server_args=server_args,\n            sampling_params=sampling,\n        )\n        # Add diffusers_kwargs if provided\n        if request.diffusers_kwargs:\n            batch.extra[\"diffusers_kwargs\"] = request.diffusers_kwargs\n\n        save_file_path_list, result = await process_generation_batch(\n            
async_scheduler_client, batch\n        )\n        save_file_path = save_file_path_list[0]\n        resp_format = (request.response_format or \"b64_json\").lower()\n\n        # read b64 before cloud upload may delete the local file\n        b64_list = (\n            _read_b64_for_paths(save_file_path_list)\n            if resp_format == \"b64_json\"\n            else None\n        )\n\n        cloud_url = await cloud_storage.upload_and_cleanup(save_file_path)\n\n        is_persistent = server_args.output_path is not None\n        await IMAGE_STORE.upsert(\n            request_id,\n            {\n                \"id\": request_id,\n                \"created_at\": int(time.time()),\n                \"file_path\": None if cloud_url or not is_persistent else save_file_path,\n                \"url\": cloud_url,\n            },\n        )\n\n        response_kwargs = _build_image_response_kwargs(\n            save_file_path_list,\n            resp_format,\n            request.prompt,\n            request_id,\n            result,\n            b64_list=b64_list,\n            cloud_url=cloud_url,\n            fallback_url=f\"/v1/images/{request_id}/content\" if is_persistent else None,\n            is_persistent=is_persistent,\n        )\n\n    return ImageResponse(**response_kwargs)\n\n\n@router.post(\"/edits\", response_model=ImageResponse)\nasync def edits(\n    image: Optional[List[UploadFile]] = File(None),\n    image_array: Optional[List[UploadFile]] = File(None, alias=\"image[]\"),\n    url: Optional[List[str]] = Form(None),\n    url_array: Optional[List[str]] = Form(None, alias=\"url[]\"),\n    prompt: str = Form(...),\n    mask: Optional[UploadFile] = File(None),\n    model: Optional[str] = Form(None),\n    n: Optional[int] = Form(1),\n    response_format: Optional[str] = Form(None),\n    size: Optional[str] = Form(None),\n    output_format: Optional[str] = Form(None),\n    background: Optional[str] = Form(\"auto\"),\n    seed: Optional[int] = Form(1024),\n    
generator_device: Optional[str] = Form(\"cuda\"),\n    user: Optional[str] = Form(None),\n    negative_prompt: Optional[str] = Form(None),\n    guidance_scale: Optional[float] = Form(None),\n    true_cfg_scale: Optional[float] = Form(None),\n    num_inference_steps: Optional[int] = Form(None),\n    output_quality: Optional[str] = Form(\"default\"),\n    output_compression: Optional[int] = Form(None),\n    enable_teacache: Optional[bool] = Form(False),\n    enable_upscaling: Optional[bool] = Form(False),\n    upscaling_model_path: Optional[str] = Form(None),\n    upscaling_scale: Optional[int] = Form(4),\n    num_frames: int = Form(1),\n):\n    request_id = generate_request_id()\n    server_args = get_global_server_args()\n    # Resolve images from either `image` or `image[]` (OpenAI SDK sends `image[]` when list is provided)\n    images = image or image_array\n    urls = url or url_array\n\n    if (not images or len(images) == 0) and (not urls or len(urls) == 0):\n        raise HTTPException(\n            status_code=422, detail=\"Field 'image' or 'url' is required\"\n        )\n\n    image_list = merge_image_input_list(images, urls)\n\n    with contextlib.ExitStack() as stack:\n        uploads_dir = stack.enter_context(\n            temp_dir_if_disabled(server_args.input_save_path)\n        )\n        output_dir = stack.enter_context(temp_dir_if_disabled(server_args.output_path))\n\n        input_paths = []\n        try:\n            for idx, img in enumerate(image_list):\n                filename = img.filename if hasattr(img, \"filename\") else f\"image_{idx}\"\n                input_path = await save_image_to_path(\n                    img,\n                    os.path.join(uploads_dir, f\"{request_id}_{idx}_{filename}\"),\n                )\n                input_paths.append(input_path)\n        except Exception as e:\n            raise HTTPException(\n                status_code=400,\n                detail=f\"Failed to process image source: {str(e)}\",\n    
        )\n\n        ext = choose_output_image_ext(output_format, background)\n        sampling = build_sampling_params(\n            request_id,\n            prompt=prompt,\n            size=size,\n            num_outputs_per_prompt=max(1, min(int(n or 1), 10)),\n            output_file_name=f\"{request_id}.{ext}\",\n            output_path=output_dir,\n            image_path=input_paths,\n            seed=seed,\n            generator_device=generator_device,\n            negative_prompt=negative_prompt,\n            guidance_scale=guidance_scale,\n            true_cfg_scale=true_cfg_scale,\n            num_inference_steps=num_inference_steps,\n            enable_teacache=enable_teacache,\n            num_frames=num_frames,\n            output_compression=output_compression,\n            output_quality=output_quality,\n            enable_upscaling=enable_upscaling,\n            upscaling_model_path=upscaling_model_path,\n            upscaling_scale=upscaling_scale,\n        )\n        batch = prepare_request(\n            server_args=server_args,\n            sampling_params=sampling,\n        )\n        save_file_path_list, result = await process_generation_batch(\n            async_scheduler_client, batch\n        )\n        save_file_path = save_file_path_list[0]\n        resp_format = (response_format or \"b64_json\").lower()\n\n        # read b64 before cloud upload may delete the local file\n        b64_list = (\n            _read_b64_for_paths(save_file_path_list)\n            if resp_format == \"b64_json\"\n            else None\n        )\n\n        cloud_url = await cloud_storage.upload_and_cleanup(save_file_path)\n\n        is_persistent = server_args.output_path is not None\n        is_input_persistent = server_args.input_save_path is not None\n        await IMAGE_STORE.upsert(\n            request_id,\n            {\n                \"id\": request_id,\n                \"created_at\": int(time.time()),\n                \"file_path\": None if cloud_url 
or not is_persistent else save_file_path,\n                \"url\": cloud_url,\n                \"input_image_paths\": input_paths if is_input_persistent else None,\n                \"num_input_images\": len(input_paths),\n            },\n        )\n\n        response_kwargs = _build_image_response_kwargs(\n            save_file_path_list,\n            resp_format,\n            prompt,\n            request_id,\n            result,\n            b64_list=b64_list,\n            cloud_url=cloud_url,\n            fallback_url=f\"/v1/images/{request_id}/content\" if is_persistent else None,\n            is_persistent=is_persistent,\n        )\n\n    return ImageResponse(**response_kwargs)\n\n\n@router.get(\"/{image_id}/content\")\nasync def download_image_content(\n    image_id: str = Path(...), variant: Optional[str] = Query(None)\n):\n    item = await IMAGE_STORE.get(image_id)\n    if not item:\n        raise HTTPException(status_code=404, detail=\"Image not found\")\n\n    if item.get(\"url\"):\n        raise HTTPException(\n            status_code=400,\n            detail=f\"Image has been uploaded to cloud storage. Please use the cloud URL: {item.get('url')}\",\n        )\n\n    file_path = item.get(\"file_path\")\n    if not file_path:\n        raise HTTPException(\n            status_code=404,\n            detail=\"Image was not persisted on disk (output_path is disabled). Use b64_json response_format or configure cloud storage.\",\n        )\n    if not os.path.exists(file_path):\n        raise HTTPException(status_code=404, detail=\"Image is still being generated\")\n\n    ext = os.path.splitext(file_path)[1].lower()\n    media_type = \"image/jpeg\"\n    if ext == \".png\":\n        media_type = \"image/png\"\n    elif ext == \".webp\":\n        media_type = \"image/webp\"\n\n    return FileResponse(\n        path=file_path, media_type=media_type, filename=os.path.basename(file_path)\n    )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/openai/mesh_api.py",
    "content": "import asyncio\nimport os\nimport time\nfrom typing import Any, Dict, List, Optional\n\nfrom fastapi import (\n    APIRouter,\n    File,\n    Form,\n    HTTPException,\n    Path,\n    Query,\n    Request,\n    UploadFile,\n)\nfrom fastapi.responses import FileResponse\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import (\n    SamplingParams,\n    generate_request_id,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.protocol import (\n    MeshGenerationsRequest,\n    MeshListResponse,\n    MeshResponse,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.storage import cloud_storage\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.stores import MESH_STORE\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.utils import (\n    add_common_data_to_response,\n    merge_image_input_list,\n    process_generation_batch,\n    save_image_to_path,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.server_args import get_global_server_args\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\nrouter = APIRouter(prefix=\"/v1/meshes\", tags=[\"meshes\"])\n\n\ndef _normalize_format(fmt: Optional[str]) -> str:\n    fmt = (fmt or \"glb\").lower()\n    return fmt if fmt in (\"glb\", \"obj\") else \"glb\"\n\n\ndef _build_sampling_params_from_request(\n    request_id: str, req: MeshGenerationsRequest, image_path: Optional[str] = None\n) -> SamplingParams:\n    ext = _normalize_format(req.output_format)\n\n    server_args = get_global_server_args()\n    sampling_kwargs: Dict[str, Any] = {\n        \"request_id\": request_id,\n        \"prompt\": req.prompt,\n        \"num_frames\": 1,\n        \"image_path\": [image_path] if image_path else None,\n        \"save_output\": True,\n        \"output_file_name\": f\"{request_id}.{ext}\",\n   
     \"seed\": req.seed,\n        \"generator_device\": req.generator_device,\n    }\n    if req.num_inference_steps is not None:\n        sampling_kwargs[\"num_inference_steps\"] = req.num_inference_steps\n    if req.guidance_scale is not None:\n        sampling_kwargs[\"guidance_scale\"] = req.guidance_scale\n    if req.negative_prompt is not None:\n        sampling_kwargs[\"negative_prompt\"] = req.negative_prompt\n\n    return SamplingParams.from_user_sampling_params_args(\n        model_path=server_args.model_path,\n        server_args=server_args,\n        **sampling_kwargs,\n    )\n\n\ndef _mesh_job_from_sampling(\n    request_id: str, req: MeshGenerationsRequest, sampling: SamplingParams\n) -> Dict[str, Any]:\n    return {\n        \"id\": request_id,\n        \"object\": \"mesh\",\n        \"model\": req.model or \"\",\n        \"status\": \"queued\",\n        \"progress\": 0,\n        \"created_at\": int(time.time()),\n        \"format\": _normalize_format(req.output_format),\n        \"file_path\": os.path.abspath(sampling.output_file_path()),\n    }\n\n\nasync def _dispatch_job_async(job_id: str, batch: Req) -> None:\n    from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client\n\n    try:\n        save_file_path_list, result = await process_generation_batch(\n            async_scheduler_client, batch\n        )\n        save_file_path = save_file_path_list[0]\n\n        file_size = None\n        if os.path.exists(save_file_path):\n            file_size = os.path.getsize(save_file_path)\n\n        cloud_url = await cloud_storage.upload_and_cleanup(save_file_path)\n\n        update_fields: Dict[str, Any] = {\n            \"status\": \"completed\",\n            \"progress\": 100,\n            \"completed_at\": int(time.time()),\n            \"url\": cloud_url,\n            \"file_path\": save_file_path if not cloud_url else None,\n            \"file_size_bytes\": file_size,\n        }\n        update_fields = 
add_common_data_to_response(\n            update_fields, request_id=job_id, result=result\n        )\n        await MESH_STORE.update_fields(job_id, update_fields)\n    except Exception as e:\n        logger.error(f\"Mesh generation job {job_id} failed: {e}\")\n        await MESH_STORE.update_fields(\n            job_id, {\"status\": \"failed\", \"error\": {\"message\": str(e)}}\n        )\n\n\n@router.post(\"\", response_model=MeshResponse)\nasync def create_mesh(\n    request: Request,\n    image: Optional[List[UploadFile]] = File(None),\n    image_array: Optional[List[UploadFile]] = File(None, alias=\"image[]\"),\n    url: Optional[List[str]] = Form(None),\n    url_array: Optional[List[str]] = Form(None, alias=\"url[]\"),\n    prompt: Optional[str] = Form(\"generate 3d mesh\"),\n    model: Optional[str] = Form(None),\n    seed: Optional[int] = Form(None),\n    generator_device: Optional[str] = Form(\"cuda\"),\n    guidance_scale: Optional[float] = Form(None),\n    num_inference_steps: Optional[int] = Form(None),\n    negative_prompt: Optional[str] = Form(None),\n    output_format: Optional[str] = Form(\"glb\"),\n):\n    content_type = request.headers.get(\"content-type\", \"\").lower()\n    request_id = generate_request_id()\n    server_args = get_global_server_args()\n\n    input_path = None\n\n    if \"multipart/form-data\" in content_type:\n        images = image or image_array\n        urls = url or url_array\n        image_list = merge_image_input_list(images, urls)\n\n        if not image_list:\n            raise HTTPException(\n                status_code=422,\n                detail=\"Field 'image' or 'url' is required for mesh generation\",\n            )\n\n        uploads_dir = os.path.join(\"outputs\", \"uploads\")\n        os.makedirs(uploads_dir, exist_ok=True)\n        img = image_list[0]\n        filename = img.filename if hasattr(img, \"filename\") else \"input_image\"\n        try:\n            input_path = await save_image_to_path(\n                img, os.path.join(uploads_dir, 
f\"{request_id}_{filename}\")\n            )\n        except Exception as e:\n            raise HTTPException(\n                status_code=400, detail=f\"Failed to process image source: {str(e)}\"\n            )\n\n        req = MeshGenerationsRequest(\n            prompt=prompt or \"generate 3d mesh\",\n            model=model,\n            seed=seed,\n            generator_device=generator_device,\n            num_inference_steps=num_inference_steps,\n            negative_prompt=negative_prompt,\n            output_format=output_format,\n            **(\n                {\"guidance_scale\": guidance_scale} if guidance_scale is not None else {}\n            ),\n        )\n    else:\n        try:\n            body = await request.json()\n        except Exception:\n            body = {}\n        try:\n            payload: Dict[str, Any] = dict(body or {})\n\n            if payload.get(\"input_image\"):\n                img_src = payload.pop(\"input_image\")\n                uploads_dir = os.path.join(\"outputs\", \"uploads\")\n                os.makedirs(uploads_dir, exist_ok=True)\n                input_path = await save_image_to_path(\n                    img_src,\n                    os.path.join(uploads_dir, f\"{request_id}_input_image\"),\n                )\n\n            req = MeshGenerationsRequest(**payload)\n        except Exception as e:\n            raise HTTPException(status_code=400, detail=f\"Invalid request body: {e}\")\n\n    if not input_path:\n        raise HTTPException(\n            status_code=422,\n            detail=\"An input image is required for mesh generation\",\n        )\n\n    sampling_params = _build_sampling_params_from_request(request_id, req, input_path)\n    job = _mesh_job_from_sampling(request_id, req, sampling_params)\n    await MESH_STORE.upsert(request_id, job)\n\n    batch = prepare_request(\n        server_args=server_args,\n        sampling_params=sampling_params,\n    )\n\n    
asyncio.create_task(_dispatch_job_async(request_id, batch))\n    return MeshResponse(**job)\n\n\n@router.get(\"\", response_model=MeshListResponse)\nasync def list_meshes(\n    after: Optional[str] = Query(None),\n    limit: Optional[int] = Query(None, ge=1, le=100),\n    order: Optional[str] = Query(\"desc\"),\n):\n    order = (order or \"desc\").lower()\n    if order not in (\"asc\", \"desc\"):\n        order = \"desc\"\n    jobs = await MESH_STORE.list_values()\n\n    reverse = order != \"asc\"\n    jobs.sort(key=lambda j: j.get(\"created_at\", 0), reverse=reverse)\n\n    if after is not None:\n        try:\n            idx = next(i for i, j in enumerate(jobs) if j[\"id\"] == after)\n            jobs = jobs[idx + 1 :]\n        except StopIteration:\n            jobs = []\n\n    if limit is not None:\n        jobs = jobs[:limit]\n    items = [MeshResponse(**j) for j in jobs]\n    return MeshListResponse(data=items)\n\n\n@router.get(\"/{mesh_id}\", response_model=MeshResponse)\nasync def retrieve_mesh(mesh_id: str = Path(...)):\n    job = await MESH_STORE.get(mesh_id)\n    if not job:\n        raise HTTPException(status_code=404, detail=\"Mesh not found\")\n    return MeshResponse(**job)\n\n\n@router.delete(\"/{mesh_id}\", response_model=MeshResponse)\nasync def delete_mesh(mesh_id: str = Path(...)):\n    job = await MESH_STORE.pop(mesh_id)\n    if not job:\n        raise HTTPException(status_code=404, detail=\"Mesh not found\")\n    job[\"status\"] = \"deleted\"\n    return MeshResponse(**job)\n\n\n@router.get(\"/{mesh_id}/content\")\nasync def download_mesh_content(\n    mesh_id: str = Path(...), variant: Optional[str] = Query(None)\n):\n    job = await MESH_STORE.get(mesh_id)\n    if not job:\n        raise HTTPException(status_code=404, detail=\"Mesh not found\")\n\n    if job.get(\"url\"):\n        raise HTTPException(\n            status_code=400,\n            detail=f\"Mesh has been uploaded to cloud storage. 
Please use the cloud URL: {job.get('url')}\",\n        )\n\n    file_path = job.get(\"file_path\")\n    if not file_path or not os.path.exists(file_path):\n        raise HTTPException(status_code=404, detail=\"Generation is still in-progress\")\n\n    ext = os.path.splitext(file_path)[1].lower()\n    media_type = {\n        \".glb\": \"model/gltf-binary\",\n        \".obj\": \"text/plain\",\n    }.get(ext, \"application/octet-stream\")\n\n    return FileResponse(\n        path=file_path, media_type=media_type, filename=os.path.basename(file_path)\n    )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py",
    "content": "import time\nimport uuid\nfrom abc import ABC\nfrom dataclasses import dataclass, field\nfrom typing import Any, Dict, List, Optional, Union\n\nfrom pydantic import BaseModel, Field\n\n\n# Image API protocol models\nclass ImageResponseData(BaseModel):\n    b64_json: Optional[str] = None\n    url: Optional[str] = None\n    revised_prompt: Optional[str] = None\n    file_path: Optional[str] = None\n\n\nclass ImageResponse(BaseModel):\n    id: str\n    created: int = Field(default_factory=lambda: int(time.time()))\n    data: List[ImageResponseData]\n    peak_memory_mb: Optional[float] = None\n    inference_time_s: Optional[float] = None\n\n\nclass ImageGenerationsRequest(BaseModel):\n    prompt: str\n    model: Optional[str] = None\n    n: Optional[int] = 1\n    quality: Optional[str] = \"auto\"\n    response_format: Optional[str] = \"url\"  # url | b64_json\n    size: Optional[str] = \"1024x1024\"  # e.g., 1024x1024\n    style: Optional[str] = \"vivid\"\n    background: Optional[str] = \"auto\"  # transparent | opaque | auto\n    output_format: Optional[str] = None  # png | jpeg | webp\n    user: Optional[str] = None\n    # SGLang extensions\n    width: Optional[int] = None\n    height: Optional[int] = None\n    num_inference_steps: Optional[int] = None\n    guidance_scale: Optional[float] = None\n    true_cfg_scale: Optional[float] = (\n        None  # for CFG vs guidance distillation (e.g., QwenImage)\n    )\n    seed: Optional[int] = 1024\n    generator_device: Optional[str] = \"cuda\"\n    negative_prompt: Optional[str] = None\n    output_quality: Optional[str] = \"default\"\n    output_compression: Optional[int] = None\n    enable_teacache: Optional[bool] = False\n    # Upscaling\n    enable_upscaling: Optional[bool] = False\n    upscaling_model_path: Optional[str] = None\n    upscaling_scale: Optional[int] = 4\n    diffusers_kwargs: Optional[Dict[str, Any]] = None  # kwargs for diffusers backend\n\n\n# Video API protocol models\nclass 
VideoResponse(BaseModel):\n    id: str\n    object: str = \"video\"\n    model: str = \"sora-2\"\n    status: str = \"queued\"\n    progress: int = 0\n    created_at: int = Field(default_factory=lambda: int(time.time()))\n    size: str = \"\"\n    seconds: str = \"4\"\n    quality: str = \"standard\"\n    url: Optional[str] = None\n    remixed_from_video_id: Optional[str] = None\n    completed_at: Optional[int] = None\n    expires_at: Optional[int] = None\n    error: Optional[Dict[str, Any]] = None\n    file_path: Optional[str] = None\n    peak_memory_mb: Optional[float] = None\n    inference_time_s: Optional[float] = None\n\n\nclass VideoGenerationsRequest(BaseModel):\n    prompt: str\n    input_reference: Optional[str] = None\n    reference_url: Optional[str] = None\n    model: Optional[str] = None\n    seconds: Optional[int] = 4\n    size: Optional[str] = \"\"\n    fps: Optional[int] = None\n    num_frames: Optional[int] = None\n    seed: Optional[int] = 1024\n    generator_device: Optional[str] = \"cuda\"\n    # SGLang extensions\n    num_inference_steps: Optional[int] = None\n    guidance_scale: Optional[float] = None\n    guidance_scale_2: Optional[float] = None\n    true_cfg_scale: Optional[float] = (\n        None  # for CFG vs guidance distillation (e.g., QwenImage)\n    )\n    negative_prompt: Optional[str] = None\n    enable_teacache: Optional[bool] = False\n    # Frame interpolation\n    enable_frame_interpolation: Optional[bool] = False\n    frame_interpolation_exp: Optional[int] = 1  # 1=2×, 2=4×\n    frame_interpolation_scale: Optional[float] = 1.0\n    frame_interpolation_model_path: Optional[str] = None\n    # Upscaling\n    enable_upscaling: Optional[bool] = False\n    upscaling_model_path: Optional[str] = None\n    upscaling_scale: Optional[int] = 4\n    output_quality: Optional[str] = \"default\"\n    output_compression: Optional[int] = None\n    output_path: Optional[str] = None\n    diffusers_kwargs: Optional[Dict[str, Any]] = None  # kwargs 
for diffusers backend\n\n\nclass VideoListResponse(BaseModel):\n    data: List[VideoResponse]\n    object: str = \"list\"\n\n\nclass VideoRemixRequest(BaseModel):\n    prompt: str\n\n\n# Mesh API protocol models\nclass MeshResponse(BaseModel):\n    id: str\n    object: str = \"mesh\"\n    model: str = \"\"\n    status: str = \"queued\"\n    progress: int = 0\n    created_at: int = Field(default_factory=lambda: int(time.time()))\n    format: str = \"glb\"\n    url: Optional[str] = None\n    completed_at: Optional[int] = None\n    expires_at: Optional[int] = None\n    error: Optional[Dict[str, Any]] = None\n    file_path: Optional[str] = None\n    file_size_bytes: Optional[int] = None\n    peak_memory_mb: Optional[float] = None\n    inference_time_s: Optional[float] = None\n\n\nclass MeshGenerationsRequest(BaseModel):\n    prompt: str = \"generate 3d mesh\"\n    input_image: Optional[str] = None\n    model: Optional[str] = None\n    seed: Optional[int] = None\n    generator_device: Optional[str] = \"cuda\"\n    num_inference_steps: Optional[int] = None\n    guidance_scale: Optional[float] = None\n    negative_prompt: Optional[str] = None\n    output_format: Optional[str] = \"glb\"\n\n\nclass MeshListResponse(BaseModel):\n    data: List[MeshResponse]\n    object: str = \"list\"\n\n\n@dataclass\nclass BaseReq(ABC):\n    rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)\n    http_worker_ipc: Optional[str] = field(default=None, kw_only=True)\n\n    def regenerate_rid(self):\n        \"\"\"Generate a new request ID and return it.\"\"\"\n        if isinstance(self.rid, list):\n            self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]\n        else:\n            self.rid = uuid.uuid4().hex\n        return self.rid\n\n\n@dataclass\nclass VertexGenerateReqInput(BaseReq):\n    instances: List[dict]\n    parameters: Optional[dict] = None\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/openai/storage.py",
    "content": "import asyncio\nimport os\nfrom typing import Optional\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass CloudStorage:\n    def __init__(self):\n        self.enabled = os.getenv(\"SGLANG_CLOUD_STORAGE_TYPE\", \"\").lower() == \"s3\"\n        if not self.enabled:\n            return\n\n        try:\n            import boto3\n        except ImportError:\n            logger.error(\n                \"boto3 is not installed. Please install it with `pip install boto3` to use cloud storage.\"\n            )\n            self.enabled = False\n            return\n\n        self.bucket_name = os.getenv(\"SGLANG_S3_BUCKET_NAME\")\n        if not self.bucket_name:\n            self.enabled = False\n            return\n\n        endpoint_url = os.getenv(\"SGLANG_S3_ENDPOINT_URL\") or None\n        region_name = os.getenv(\"SGLANG_S3_REGION_NAME\") or None\n\n        self.client = boto3.client(\n            \"s3\",\n            aws_access_key_id=os.getenv(\"SGLANG_S3_ACCESS_KEY_ID\"),\n            aws_secret_access_key=os.getenv(\"SGLANG_S3_SECRET_ACCESS_KEY\"),\n            endpoint_url=endpoint_url,\n            region_name=region_name,\n        )\n        self.endpoint_url = endpoint_url\n        self.region_name = region_name\n\n    def is_enabled(self) -> bool:\n        return self.enabled\n\n    async def upload_file(self, local_path: str, destination_key: str) -> Optional[str]:\n        if not self.is_enabled():\n            return None\n\n        def _sync_upload():\n            \"\"\"Synchronous part of the upload to run in a thread.\"\"\"\n            ext = os.path.splitext(local_path)[1].lower()\n            content_type = {\n                \".png\": \"image/png\",\n                \".jpg\": \"image/jpeg\",\n                \".jpeg\": \"image/jpeg\",\n                \".webp\": \"image/webp\",\n                \".mp4\": \"video/mp4\",\n                \".glb\": 
\"model/gltf-binary\",\n                \".obj\": \"text/plain\",\n            }.get(ext, \"application/octet-stream\")\n\n            # Use the client created once in __init__\n            self.client.upload_file(\n                local_path,\n                self.bucket_name,\n                destination_key,\n                ExtraArgs={\"ContentType\": content_type},\n            )\n\n        try:\n            # Offload the blocking I/O call to a thread executor\n            await asyncio.get_running_loop().run_in_executor(None, _sync_upload)\n        except Exception as e:\n            # If upload fails, log the error and return None for fallback\n            logger.error(f\"Upload failed for {destination_key}: {e}\")\n            return None\n\n        # Simplified URL generation with a default region\n        if self.endpoint_url:\n            url = (\n                f\"{self.endpoint_url.rstrip('/')}/{self.bucket_name}/{destination_key}\"\n            )\n        else:\n            region = self.region_name or \"us-east-1\"\n            url = f\"https://{self.bucket_name}.s3.{region}.amazonaws.com/{destination_key}\"\n\n        logger.info(f\"Uploaded {local_path} to {url}\")\n        return url\n\n    async def upload_and_cleanup(self, file_path: str) -> Optional[str]:\n        \"\"\"Helper to upload a file and delete the local copy if successful.\"\"\"\n        if not self.is_enabled():\n            return None\n\n        key = os.path.basename(file_path)\n        url = await self.upload_file(file_path, key)\n\n        if url:\n            try:\n                # pass if removal fails\n                os.remove(file_path)\n            except OSError as e:\n                logger.warning(f\"Failed to remove temporary file {file_path}: {e}\")\n        return url\n\n\n# Global instance\ncloud_storage = CloudStorage()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/openai/stores.py",
    "content": "import asyncio\nfrom typing import Any, Dict, List, Optional\n\n\nclass AsyncDictStore:\n    \"\"\"A small async-safe in-memory key-value store for dict items.\n\n    This encapsulates the usual pattern of a module-level dict guarded by\n    an asyncio.Lock and provides simple CRUD methods that are safe to call\n    concurrently from FastAPI request handlers and background tasks.\n    \"\"\"\n\n    def __init__(self) -> None:\n        self._items: Dict[str, Dict[str, Any]] = {}\n        self._lock = asyncio.Lock()\n\n    async def upsert(self, key: str, value: Dict[str, Any]) -> None:\n        async with self._lock:\n            self._items[key] = value\n\n    async def update_fields(\n        self, key: str, updates: Dict[str, Any]\n    ) -> Optional[Dict[str, Any]]:\n        async with self._lock:\n            item = self._items.get(key)\n            if item is None:\n                return None\n            item.update(updates)\n            return item\n\n    async def get(self, key: str) -> Optional[Dict[str, Any]]:\n        async with self._lock:\n            return self._items.get(key)\n\n    async def pop(self, key: str) -> Optional[Dict[str, Any]]:\n        async with self._lock:\n            return self._items.pop(key, None)\n\n    async def list_values(self) -> List[Dict[str, Any]]:\n        async with self._lock:\n            return list(self._items.values())\n\n\n# Global stores shared by OpenAI entrypoints\n# [request_id, dict]\nVIDEO_STORE = AsyncDictStore()\nIMAGE_STORE = AsyncDictStore()\nMESH_STORE = AsyncDictStore()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/openai/utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\nimport base64\nimport os\nimport re\nimport shutil\nimport tempfile\nimport time\nfrom contextlib import contextmanager\nfrom typing import Any, Generator, List, Optional, Union\n\nimport httpx\nfrom fastapi import UploadFile\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import (\n    DataType,\n    SamplingParams,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import (\n    ListLorasReq,\n    MergeLoraWeightsReq,\n    SetLoraReq,\n    ShutdownReq,\n    UnmergeLoraWeightsReq,\n    format_lora_message,\n    save_outputs,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch\nfrom sglang.multimodal_gen.runtime.scheduler_client import AsyncSchedulerClient\nfrom sglang.multimodal_gen.runtime.server_args import get_global_server_args\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import (\n    init_logger,\n    log_batch_completion,\n    log_generation_timer,\n)\n\n# re-export LoRA protocol types for backward compatibility\n__all__ = [\n    \"SetLoraReq\",\n    \"MergeLoraWeightsReq\",\n    \"UnmergeLoraWeightsReq\",\n    \"ListLorasReq\",\n    \"ShutdownReq\",\n    \"format_lora_message\",\n]\n\nlogger = init_logger(__name__)\n\nOUTPUT_QUALITY_MAPPER = {\"maximum\": 100, \"high\": 90, \"medium\": 55, \"low\": 35}\nDEFAULT_FPS = 24\nDEFAULT_VIDEO_SECONDS = 4\n\n\n@contextmanager\ndef temp_dir_if_disabled(\n    configured_path: str | None,\n) -> Generator[str, None, None]:\n    \"\"\"Yield *configured_path* when it is set, otherwise create a temporary\n    directory that is automatically removed when the context exits.\"\"\"\n    if configured_path is not None:\n        os.makedirs(configured_path, exist_ok=True)\n        yield configured_path\n    else:\n        tmp = tempfile.mkdtemp(prefix=\"sglang_\")\n        try:\n            yield tmp\n        finally:\n            shutil.rmtree(tmp, ignore_errors=True)\n\n\ndef 
_parse_size(size: str) -> tuple[int, int] | tuple[None, None]:\n    try:\n        parts = size.lower().replace(\" \", \"\").split(\"x\")\n        if len(parts) != 2:\n            raise ValueError\n        w, h = int(parts[0]), int(parts[1])\n        return w, h\n    except Exception:\n        return None, None\n\n\ndef choose_output_image_ext(\n    output_format: Optional[str], background: Optional[str]\n) -> str:\n    fmt = (output_format or \"\").lower()\n    if fmt in {\"png\", \"webp\", \"jpeg\", \"jpg\"}:\n        return \"jpg\" if fmt == \"jpeg\" else fmt\n    if (background or \"auto\").lower() == \"transparent\":\n        return \"png\"\n    return \"jpg\"\n\n\ndef build_sampling_params(request_id: str, **kwargs) -> SamplingParams:\n    \"\"\"Build SamplingParams from request parameters.\n\n    Handles size parsing, output_quality resolution, and None filtering before\n    delegating to SamplingParams.from_user_sampling_params_args. Callers pass\n    only the parameters they have; None values are stripped automatically so\n    that SamplingParams defaults apply.\n    \"\"\"\n    server_args = get_global_server_args()\n\n    # pop HTTP-layer params that aren't SamplingParams fields\n    output_quality = kwargs.pop(\"output_quality\", None)\n\n    has_explicit_compression = kwargs.get(\"output_compression\") is not None\n\n    # parse \"WxH\" size string if provided\n    size = kwargs.pop(\"size\", None)\n    if size:\n        w, h = _parse_size(size)\n        if w is not None:\n            # treat None dimensions as unset so parsed size can fill them\n            if kwargs.get(\"width\") is None:\n                kwargs[\"width\"] = w\n            if kwargs.get(\"height\") is None:\n                kwargs[\"height\"] = h\n\n    # filter out None values to let SamplingParams defaults apply\n    kwargs = {k: v for k, v in kwargs.items() if v is not None}\n    kwargs.setdefault(\"save_output\", True)\n\n    sampling_params = 
SamplingParams.from_user_sampling_params_args(\n        model_path=server_args.model_path,\n        server_args=server_args,\n        request_id=request_id,\n        **kwargs,\n    )\n\n    # resolve output_quality → output_compression with the correct data_type.\n    # SamplingParams.__post_init__ may have resolved with the wrong data_type\n    # (default VIDEO) before _adjust() set the correct one.\n    if not has_explicit_compression and output_quality is not None:\n        resolved = adjust_output_quality(output_quality, sampling_params.data_type)\n        if resolved is not None:\n            sampling_params.output_compression = resolved\n\n    return sampling_params\n\n\nasync def save_image_to_path(image: Union[UploadFile, str], target_path: str) -> str:\n    input_path = await _maybe_url_image(image, target_path)\n    if input_path is None:\n        input_path = await _save_upload_to_path(image, target_path)\n    return input_path\n\n\n# Helpers\nasync def _save_upload_to_path(upload: UploadFile, target_path: str) -> str:\n    os.makedirs(os.path.dirname(target_path), exist_ok=True)\n    content = await upload.read()\n    with open(target_path, \"wb\") as f:\n        f.write(content)\n    return target_path\n\n\nasync def _maybe_url_image(\n    img_url: Union[UploadFile, str], target_path: str\n) -> str | None:\n    # non-string inputs (e.g. UploadFile) are handled by the caller\n    if not isinstance(img_url, str):\n        return None\n\n    if img_url.lower().startswith((\"http://\", \"https://\")):\n        # Download image from URL\n        input_path = await _save_url_image_to_path(img_url, target_path)\n        return input_path\n    elif img_url.startswith(\"data:image\"):\n        # decode a base64 data URL\n        input_path = await _save_base64_image_to_path(img_url, target_path)\n        return input_path\n    else:\n        raise ValueError(\"Unsupported image url format\")\n\n\nasync def _save_url_image_to_path(image_url: str, target_path: str) -> str:\n    \"\"\"Download image from URL and save to target path.\"\"\"\n\n    
os.makedirs(os.path.dirname(target_path), exist_ok=True)\n\n    try:\n        async with httpx.AsyncClient(follow_redirects=True) as client:\n            response = await client.get(image_url, timeout=10.0)\n            response.raise_for_status()\n\n            # Determine file extension from content type or URL after downloading\n            if not os.path.splitext(target_path)[1]:\n                content_type = response.headers.get(\"content-type\", \"\").lower()\n\n                url_path = image_url.split(\"?\")[0]\n                _, url_ext = os.path.splitext(url_path)\n                url_ext = url_ext.lower()\n\n                if url_ext in {\".jpg\", \".jpeg\", \".png\", \".webp\", \".gif\", \".bmp\"}:\n                    ext = \".jpg\" if url_ext == \".jpeg\" else url_ext\n                elif content_type.startswith(\"image/\"):\n                    if \"jpeg\" in content_type or \"jpg\" in content_type:\n                        ext = \".jpg\"\n                    elif \"png\" in content_type:\n                        ext = \".png\"\n                    elif \"webp\" in content_type:\n                        ext = \".webp\"\n                    else:\n                        ext = \".jpg\"  # Default to jpg\n                elif content_type == \"application/octet-stream\":\n                    # for octet-stream, if we couldn't get it from URL, default to jpg\n                    ext = \".jpg\"\n                else:\n                    raise ValueError(\n                        f\"URL does not point to an image. 
Content-Type: {content_type}\"\n                    )\n                target_path = f\"{target_path}{ext}\"\n\n            with open(target_path, \"wb\") as f:\n                f.write(response.content)\n\n            return target_path\n    except Exception as e:\n        raise Exception(f\"Failed to download image from URL: {str(e)}\") from e\n\n\nasync def _save_base64_image_to_path(base64_data: str, target_path: str) -> str:\n    \"\"\"Decode base64 image data and save to target path.\"\"\"\n\n    _B64_FMT_HINT = (\n        \"Failed to decode base64 image. \"\n        \"Expected format: `data:[<media-type>];base64,<data>`\"\n    )\n\n    # split `data:[<media-type>][;base64],<data>` into media type, base64 marker and data\n    pattern = r\"data:(.*?)(;base64)?,(.*)\"\n    match = re.match(pattern, base64_data)\n    if not match:\n        raise ValueError(_B64_FMT_HINT)\n    media_type = match.group(1)\n    is_base64 = match.group(2)\n    if not is_base64:\n        raise ValueError(f\"{_B64_FMT_HINT} (missing ;base64 marker)\")\n    data = match.group(3)\n    if not data:\n        raise ValueError(f\"{_B64_FMT_HINT} (empty data payload)\")\n    # derive the file extension from the media type\n    if media_type.startswith(\"image/\"):\n        ext = media_type.split(\"/\")[-1].lower()\n        if ext == \"jpeg\":\n            ext = \"jpg\"\n    else:\n        ext = \"jpg\"\n    target_path = f\"{target_path}.{ext}\"\n    os.makedirs(os.path.dirname(target_path), exist_ok=True)\n\n    try:\n        image_data = base64.b64decode(data)\n        with open(target_path, \"wb\") as f:\n            f.write(image_data)\n\n        return target_path\n    except Exception as e:\n        raise Exception(f\"Failed to decode base64 image: {str(e)}\") from e\n\n\nasync def process_generation_batch(\n    scheduler_client: AsyncSchedulerClient,\n    batch,\n) -> tuple[list[str], OutputBatch]:\n    total_start_time = time.perf_counter()\n    with log_generation_timer(logger, batch.prompt):\n        result = await 
scheduler_client.forward([batch])\n\n        if result.output is None and result.output_file_paths is None:\n            error_msg = result.error or \"Unknown error\"\n            raise RuntimeError(\n                f\"Model generation returned no output. Error from scheduler: {error_msg}\"\n            )\n\n        if result.output_file_paths:\n            save_file_path_list = result.output_file_paths\n        else:\n            num_outputs = len(result.output)\n            save_file_path_list = save_outputs(\n                result.output,\n                batch.data_type,\n                batch.fps,\n                batch.save_output,\n                lambda idx: str(batch.output_file_path(num_outputs, idx)),\n                audio=result.audio,\n                audio_sample_rate=result.audio_sample_rate,\n                output_compression=batch.output_compression,\n                enable_frame_interpolation=batch.enable_frame_interpolation,\n                frame_interpolation_exp=batch.frame_interpolation_exp,\n                frame_interpolation_scale=batch.frame_interpolation_scale,\n                frame_interpolation_model_path=batch.frame_interpolation_model_path,\n                enable_upscaling=batch.enable_upscaling,\n                upscaling_model_path=batch.upscaling_model_path,\n                upscaling_scale=batch.upscaling_scale,\n            )\n\n    total_time = time.perf_counter() - total_start_time\n    log_batch_completion(logger, 1, total_time)\n\n    if result.peak_memory_mb and result.peak_memory_mb > 0:\n        logger.info(f\"Peak memory usage: {result.peak_memory_mb:.2f} MB\")\n\n    return save_file_path_list, result\n\n\ndef merge_image_input_list(*inputs: Union[List, Any, None]) -> List:\n    \"\"\"\n    Merge multiple image input sources into a single list.\n\n    This function handles both single items and lists of items, merging them\n    into a single flattened list. 
Useful for processing images, URLs, or other\n    multimedia inputs that can come as either single items or lists.\n\n    Args:\n        *inputs: Variable number of inputs, each can be None, single item, or list\n\n    Returns:\n        List: Flattened list of all non-None inputs\n\n    Example:\n        >>> merge_image_input_list([\"img1\", \"img2\"], \"img3\", None)\n        ['img1', 'img2', 'img3']\n    \"\"\"\n    result = []\n    for input_item in inputs:\n        if input_item is not None:\n            if isinstance(input_item, list):\n                result.extend(input_item)\n            else:\n                result.append(input_item)\n    return result\n\n\ndef add_common_data_to_response(\n    response: dict, request_id: str, result: OutputBatch\n) -> dict:\n    if result.peak_memory_mb and result.peak_memory_mb > 0:\n        response[\"peak_memory_mb\"] = result.peak_memory_mb\n\n    if result.metrics and result.metrics.total_duration_s > 0:\n        response[\"inference_time_s\"] = result.metrics.total_duration_s\n\n    response[\"id\"] = request_id\n\n    return response\n\n\ndef adjust_output_quality(\n    output_quality: str, data_type: Optional[DataType] = None\n) -> Optional[int]:\n    # returns None for unknown quality labels so callers can keep their default\n    if output_quality == \"default\":\n        return 50 if data_type == DataType.VIDEO else 75\n    return OUTPUT_QUALITY_MAPPER.get(output_quality, None)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nimport asyncio\nimport json\nimport os\nimport shutil\nimport tempfile\nimport time\nfrom typing import Any, Dict, Optional\n\nfrom fastapi import (\n    APIRouter,\n    File,\n    Form,\n    HTTPException,\n    Path,\n    Query,\n    Request,\n    UploadFile,\n)\nfrom fastapi.responses import FileResponse\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import (\n    SamplingParams,\n    generate_request_id,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.protocol import (\n    VideoGenerationsRequest,\n    VideoListResponse,\n    VideoResponse,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.storage import cloud_storage\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.stores import VIDEO_STORE\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.utils import (\n    DEFAULT_FPS,\n    DEFAULT_VIDEO_SECONDS,\n    add_common_data_to_response,\n    build_sampling_params,\n    merge_image_input_list,\n    process_generation_batch,\n    save_image_to_path,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.server_args import get_global_server_args\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\nrouter = APIRouter(prefix=\"/v1/videos\", tags=[\"videos\"])\n\n\ndef _build_video_sampling_params(request_id: str, request: VideoGenerationsRequest):\n    \"\"\"Resolve video-specific defaults (fps, seconds → num_frames) then\n    delegate to the shared build_sampling_params.\"\"\"\n    seconds = request.seconds if request.seconds is not None else DEFAULT_VIDEO_SECONDS\n    fps = request.fps if request.fps is not None else DEFAULT_FPS\n    num_frames = request.num_frames if request.num_frames is not None else fps * seconds\n\n    return 
build_sampling_params(\n        request_id,\n        prompt=request.prompt,\n        size=request.size,\n        num_frames=num_frames,\n        fps=fps,\n        image_path=request.input_reference,\n        output_file_name=request_id,\n        seed=request.seed,\n        generator_device=request.generator_device,\n        num_inference_steps=request.num_inference_steps,\n        guidance_scale=request.guidance_scale,\n        guidance_scale_2=request.guidance_scale_2,\n        negative_prompt=request.negative_prompt,\n        enable_teacache=request.enable_teacache,\n        enable_frame_interpolation=request.enable_frame_interpolation,\n        frame_interpolation_exp=request.frame_interpolation_exp,\n        frame_interpolation_scale=request.frame_interpolation_scale,\n        frame_interpolation_model_path=request.frame_interpolation_model_path,\n        enable_upscaling=request.enable_upscaling,\n        upscaling_model_path=request.upscaling_model_path,\n        upscaling_scale=request.upscaling_scale,\n        output_path=request.output_path,\n        output_compression=request.output_compression,\n        output_quality=request.output_quality,\n    )\n\n\n# extract metadata which http_server needs to know\ndef _video_job_from_sampling(\n    request_id: str, req: VideoGenerationsRequest, sampling: SamplingParams\n) -> Dict[str, Any]:\n    size_str = f\"{sampling.width}x{sampling.height}\"\n    seconds = int(round((sampling.num_frames or 0) / float(sampling.fps or 24)))\n    return {\n        \"id\": request_id,\n        \"object\": \"video\",\n        \"model\": req.model or \"sora-2\",\n        \"status\": \"queued\",\n        \"progress\": 0,\n        \"created_at\": int(time.time()),\n        \"size\": size_str,\n        \"seconds\": str(seconds),\n        \"quality\": \"standard\",\n        \"file_path\": os.path.abspath(sampling.output_file_path()),\n    }\n\n\nasync def _save_first_input_image(\n    image_sources, request_id: str, uploads_dir: str\n) 
-> str | None:\n    \"\"\"Save the first input image from a list of sources and return its path.\"\"\"\n    image_list = merge_image_input_list(image_sources)\n    if not image_list:\n        return None\n    image = image_list[0]\n\n    os.makedirs(uploads_dir, exist_ok=True)\n\n    filename = image.filename if hasattr(image, \"filename\") else \"url_image\"\n    target_path = os.path.join(uploads_dir, f\"{request_id}_{filename}\")\n    return await save_image_to_path(image, target_path)\n\n\nasync def _dispatch_job_async(\n    job_id: str,\n    batch: Req,\n    *,\n    temp_dirs: list[str] | None = None,\n    output_persistent: bool = True,\n) -> None:\n    from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client\n\n    try:\n        save_file_path_list, result = await process_generation_batch(\n            async_scheduler_client, batch\n        )\n        save_file_path = save_file_path_list[0]\n\n        cloud_url = await cloud_storage.upload_and_cleanup(save_file_path)\n\n        persistent_path = (\n            save_file_path if not cloud_url and output_persistent else None\n        )\n        update_fields = {\n            \"status\": \"completed\",\n            \"progress\": 100,\n            \"completed_at\": int(time.time()),\n            \"url\": cloud_url,\n            \"file_path\": persistent_path,\n        }\n        update_fields = add_common_data_to_response(\n            update_fields, request_id=job_id, result=result\n        )\n        await VIDEO_STORE.update_fields(job_id, update_fields)\n    except Exception as e:\n        logger.error(f\"{e}\")\n        await VIDEO_STORE.update_fields(\n            job_id, {\"status\": \"failed\", \"error\": {\"message\": str(e)}}\n        )\n    finally:\n        for td in temp_dirs or []:\n            shutil.rmtree(td, ignore_errors=True)\n\n\n# TODO: support image to video generation\n@router.post(\"\", response_model=VideoResponse)\nasync def create_video(\n    request: 
Request,\n    # multipart/form-data fields (optional; used only when content-type is multipart)\n    prompt: Optional[str] = Form(None),\n    input_reference: Optional[UploadFile] = File(None),\n    reference_url: Optional[str] = Form(None),\n    model: Optional[str] = Form(None),\n    seconds: Optional[int] = Form(None),\n    size: Optional[str] = Form(None),\n    fps: Optional[int] = Form(None),\n    num_frames: Optional[int] = Form(None),\n    seed: Optional[int] = Form(1024),\n    generator_device: Optional[str] = Form(\"cuda\"),\n    negative_prompt: Optional[str] = Form(None),\n    guidance_scale: Optional[float] = Form(None),\n    num_inference_steps: Optional[int] = Form(None),\n    enable_teacache: Optional[bool] = Form(False),\n    enable_frame_interpolation: Optional[bool] = Form(False),\n    frame_interpolation_exp: Optional[int] = Form(1),\n    frame_interpolation_scale: Optional[float] = Form(1.0),\n    frame_interpolation_model_path: Optional[str] = Form(None),\n    enable_upscaling: Optional[bool] = Form(False),\n    upscaling_model_path: Optional[str] = Form(None),\n    upscaling_scale: Optional[int] = Form(4),\n    output_quality: Optional[str] = Form(\"default\"),\n    output_compression: Optional[int] = Form(None),\n    extra_body: Optional[str] = Form(None),\n):\n    content_type = request.headers.get(\"content-type\", \"\").lower()\n    request_id = generate_request_id()\n\n    server_args = get_global_server_args()\n    task_type = server_args.pipeline_config.task_type\n\n    # Resolve input upload directory (may be a temp dir when saving is disabled)\n    temp_dirs: list[str] = []\n    if server_args.input_save_path is not None:\n        uploads_dir = server_args.input_save_path\n        os.makedirs(uploads_dir, exist_ok=True)\n    else:\n        uploads_dir = tempfile.mkdtemp(prefix=\"sglang_input_\")\n        temp_dirs.append(uploads_dir)\n\n    # Resolve output directory\n    effective_output_path = server_args.output_path\n    
output_persistent = True\n    if \"multipart/form-data\" not in content_type:\n        # JSON body may carry a per-request output_path; checked after parsing below\n        pass\n\n    if \"multipart/form-data\" in content_type:\n        if not prompt:\n            raise HTTPException(status_code=400, detail=\"prompt is required\")\n        # Validate image input based on model task type\n        image_sources = merge_image_input_list(input_reference, reference_url)\n        if task_type.requires_image_input() and not image_sources:\n            raise HTTPException(\n                status_code=400,\n                detail=\"input_reference or reference_url is required for image-to-video generation\",\n            )\n        try:\n            input_path = await _save_first_input_image(\n                image_sources, request_id, uploads_dir\n            )\n        except Exception as e:\n            raise HTTPException(\n                status_code=400, detail=f\"Failed to process image source: {str(e)}\"\n            )\n\n        # Parse extra_body JSON (if provided in multipart form) to get fps/num_frames overrides\n        extra_from_form: Dict[str, Any] = {}\n        if extra_body:\n            try:\n                extra_from_form = json.loads(extra_body)\n            except Exception:\n                extra_from_form = {}\n\n        fps_val = fps if fps is not None else extra_from_form.get(\"fps\")\n        num_frames_val = (\n            num_frames if num_frames is not None else extra_from_form.get(\"num_frames\")\n        )\n\n        req = VideoGenerationsRequest(\n            prompt=prompt,\n            input_reference=input_path,\n            model=model,\n            seconds=seconds if seconds is not None else 4,\n            size=size,\n            fps=fps_val,\n            num_frames=num_frames_val,\n            seed=seed,\n            generator_device=generator_device,\n            negative_prompt=negative_prompt,\n            
num_inference_steps=num_inference_steps,\n            enable_teacache=enable_teacache,\n            enable_frame_interpolation=enable_frame_interpolation,\n            frame_interpolation_exp=frame_interpolation_exp,\n            frame_interpolation_scale=frame_interpolation_scale,\n            frame_interpolation_model_path=frame_interpolation_model_path,\n            enable_upscaling=enable_upscaling,\n            upscaling_model_path=upscaling_model_path,\n            upscaling_scale=upscaling_scale,\n            output_compression=output_compression,\n            output_quality=output_quality,\n            **(\n                {\"guidance_scale\": guidance_scale} if guidance_scale is not None else {}\n            ),\n        )\n    else:\n        try:\n            body = await request.json()\n        except Exception:\n            body = {}\n        try:\n            # If client uses extra_body, merge it into the top-level payload\n            payload: Dict[str, Any] = dict(body or {})\n            extra = payload.pop(\"extra_body\", None)\n            if isinstance(extra, dict):\n                # Shallow-merge: only keys like fps/num_frames are expected\n                payload.update(extra)\n            # openai may turn extra_body to extra_json\n            extra_json = payload.pop(\"extra_json\", None)\n            if isinstance(extra_json, dict):\n                payload.update(extra_json)\n            # Validate image input based on model task type\n            has_image_input = payload.get(\"reference_url\") or payload.get(\n                \"input_reference\"\n            )\n            if task_type.requires_image_input() and not has_image_input:\n                raise HTTPException(\n                    status_code=400,\n                    detail=\"input_reference or reference_url is required for image-to-video generation\",\n                )\n            # for non-multipart/form-data type\n            if payload.get(\"reference_url\"):\n            
    try:\n                    input_path = await _save_first_input_image(\n                        payload.get(\"reference_url\"), request_id, uploads_dir\n                    )\n                except Exception as e:\n                    raise HTTPException(\n                        status_code=400,\n                        detail=f\"Failed to process image source: {str(e)}\",\n                    )\n                payload[\"input_reference\"] = input_path\n            req = VideoGenerationsRequest(**payload)\n        except Exception as e:\n            raise HTTPException(status_code=400, detail=f\"Invalid request body: {e}\")\n\n    # Resolve per-request output_path override\n    effective_output_path = req.output_path or server_args.output_path\n    if effective_output_path is None:\n        output_tmp = tempfile.mkdtemp(prefix=\"sglang_output_\")\n        temp_dirs.append(output_tmp)\n        effective_output_path = output_tmp\n        output_persistent = False\n\n    # Inject resolved output_path so _build_video_sampling_params picks it up\n    req.output_path = effective_output_path\n\n    logger.debug(f\"Server received from create_video endpoint: req={req}\")\n\n    try:\n        sampling_params = _build_video_sampling_params(request_id, req)\n    except (ValueError, TypeError) as e:\n        raise HTTPException(status_code=400, detail=str(e))\n\n    job = _video_job_from_sampling(request_id, req, sampling_params)\n    await VIDEO_STORE.upsert(request_id, job)\n\n    # Build Req for scheduler\n    batch = prepare_request(\n        server_args=server_args,\n        sampling_params=sampling_params,\n    )\n    # Add diffusers_kwargs if provided\n    if req.diffusers_kwargs:\n        batch.extra[\"diffusers_kwargs\"] = req.diffusers_kwargs\n    # Enqueue the job asynchronously and return immediately\n    asyncio.create_task(\n        _dispatch_job_async(\n            request_id,\n            batch,\n            temp_dirs=temp_dirs or None,\n            
output_persistent=output_persistent,\n        )\n    )\n    return VideoResponse(**job)\n\n\n@router.get(\"\", response_model=VideoListResponse)\nasync def list_videos(\n    after: Optional[str] = Query(None),\n    limit: Optional[int] = Query(None, ge=1, le=100),\n    order: Optional[str] = Query(\"desc\"),\n):\n    # Normalize order\n    order = (order or \"desc\").lower()\n    if order not in (\"asc\", \"desc\"):\n        order = \"desc\"\n    jobs = await VIDEO_STORE.list_values()\n\n    reverse = order != \"asc\"\n    jobs.sort(key=lambda j: j.get(\"created_at\", 0), reverse=reverse)\n\n    if after is not None:\n        try:\n            idx = next(i for i, j in enumerate(jobs) if j[\"id\"] == after)\n            jobs = jobs[idx + 1 :]\n        except StopIteration:\n            jobs = []\n\n    if limit is not None:\n        jobs = jobs[:limit]\n    items = [VideoResponse(**j) for j in jobs]\n    return VideoListResponse(data=items)\n\n\n@router.get(\"/{video_id}\", response_model=VideoResponse)\nasync def retrieve_video(video_id: str = Path(...)):\n    job = await VIDEO_STORE.get(video_id)\n    if not job:\n        raise HTTPException(status_code=404, detail=\"Video not found\")\n    return VideoResponse(**job)\n\n\n# TODO: support aborting a job.\n@router.delete(\"/{video_id}\", response_model=VideoResponse)\nasync def delete_video(video_id: str = Path(...)):\n    job = await VIDEO_STORE.pop(video_id)\n    if not job:\n        raise HTTPException(status_code=404, detail=\"Video not found\")\n    # Mark as deleted in response semantics\n    job[\"status\"] = \"deleted\"\n    return VideoResponse(**job)\n\n\n@router.get(\"/{video_id}/content\")\nasync def download_video_content(\n    video_id: str = Path(...), variant: Optional[str] = Query(None)\n):\n    job = await VIDEO_STORE.get(video_id)\n    if not job:\n        raise HTTPException(status_code=404, detail=\"Video not found\")\n\n    if job.get(\"url\"):\n        raise HTTPException(\n            
status_code=400,\n            detail=f\"Video has been uploaded to cloud storage. Please use the cloud URL: {job.get('url')}\",\n        )\n\n    file_path = job.get(\"file_path\")\n    if not file_path or not os.path.exists(file_path):\n        raise HTTPException(status_code=404, detail=\"Generation is still in-progress\")\n\n    media_type = \"video/mp4\"  # default variant\n    return FileResponse(\n        path=file_path, media_type=media_type, filename=os.path.basename(file_path)\n    )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/post_training/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py",
    "content": "\"\"\"Request/response data structures for post-training APIs.\"\"\"\n\nfrom dataclasses import dataclass\n\n\n@dataclass\nclass UpdateWeightFromDiskReqInput:\n    \"\"\"Request to update model weights from disk for diffusion models.\"\"\"\n\n    model_path: str\n    flush_cache: bool = True\n    target_modules: list[str] | None = None\n\n\n@dataclass\nclass GetWeightsChecksumReqInput:\n    \"\"\"Compute SHA-256 checksum of loaded module weights for verification.\"\"\"\n\n    module_names: list[str] | None = None\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py",
    "content": "\"\"\"Weight update API for the diffusion engine.\"\"\"\n\nfrom fastapi import APIRouter, Request\nfrom fastapi.responses import ORJSONResponse\n\nfrom sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import (\n    GetWeightsChecksumReqInput,\n    UpdateWeightFromDiskReqInput,\n)\nfrom sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client\n\nrouter = APIRouter()\n\n\n@router.post(\"/update_weights_from_disk\")\nasync def update_weights_from_disk(request: Request):\n    \"\"\"Update model weights from disk inplace without restarting the server.\"\"\"\n    body = await request.json()\n    model_path = body.get(\"model_path\")\n    if not model_path:\n        return ORJSONResponse(\n            {\"success\": False, \"message\": \"model_path is required\"},\n            status_code=400,\n        )\n\n    req = UpdateWeightFromDiskReqInput(\n        model_path=model_path,\n        flush_cache=body.get(\"flush_cache\", True),\n        target_modules=body.get(\"target_modules\"),\n    )\n\n    try:\n        response = await async_scheduler_client.forward(req)\n    except Exception as e:\n        return ORJSONResponse(\n            {\"success\": False, \"message\": str(e)},\n            status_code=500,\n        )\n\n    result = response.output\n    success = result.get(\"success\", False)\n    message = result.get(\"message\", \"Unknown status\")\n    return ORJSONResponse(\n        {\"success\": success, \"message\": message},\n        status_code=200 if success else 400,\n    )\n\n\n@router.post(\"/get_weights_checksum\")\nasync def get_weights_checksum(request: Request):\n    \"\"\"Return SHA-256 checksum of each requested module's weights.\"\"\"\n    body = await request.json()\n    req = GetWeightsChecksumReqInput(\n        module_names=body.get(\"module_names\"),\n    )\n\n    try:\n        response = await async_scheduler_client.forward(req)\n    except Exception as e:\n        return 
ORJSONResponse({\"error\": str(e)}, status_code=500)\n\n    return ORJSONResponse(response.output, status_code=200)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/entrypoints/utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nDiffGenerator module for sglang-diffusion.\n\nThis module provides a consolidated interface for generating videos using\ndiffusion models.\n\"\"\"\n\nimport os\nimport shutil\nimport subprocess\nimport tempfile\nfrom dataclasses import dataclass, field\nfrom typing import Any, Callable, List, Optional, Sequence, Union\n\nimport imageio\nimport numpy as np\nimport torch\n\ntry:\n    import scipy.io.wavfile as scipy_wavfile\nexcept ImportError:  # pragma: no cover\n    scipy_wavfile = None\n\ntry:\n    import imageio_ffmpeg as _imageio_ffmpeg\nexcept ImportError:  # pragma: no cover\n    _imageio_ffmpeg = None\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import (\n    DataType,\n    SamplingParams,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import CYAN, RESET, init_logger\n\nlogger = init_logger(__name__)\n\n\n@dataclass\nclass SetLoraReq:\n    lora_nickname: Union[str, List[str]]\n    lora_path: Optional[Union[str, List[Optional[str]]]] = None\n    target: Union[str, List[str]] = \"all\"\n    strength: Union[float, List[float]] = 1.0\n\n\n@dataclass\nclass MergeLoraWeightsReq:\n    target: str = \"all\"\n    strength: float = 1.0\n\n\n@dataclass\nclass UnmergeLoraWeightsReq:\n    target: str = \"all\"\n\n\n@dataclass\nclass ListLorasReq:\n    pass\n\n\n@dataclass\nclass ShutdownReq:\n    pass\n\n\ndef format_lora_message(\n    lora_nickname: Union[str, List[str]],\n    target: Union[str, List[str]],\n    strength: Union[float, List[float]],\n) -> tuple[str, str, str]:\n    \"\"\"Format success message for single or multiple LoRAs.\"\"\"\n    if isinstance(lora_nickname, list):\n        nickname_str = \", \".join(lora_nickname)\n        target_str = \", 
\".join(target) if isinstance(target, list) else target\n        strength_str = (\n            \", \".join(f\"{s:.2f}\" for s in strength)\n            if isinstance(strength, list)\n            else f\"{strength:.2f}\"\n        )\n    else:\n        nickname_str = lora_nickname\n        target_str = target if isinstance(target, str) else \", \".join(target)\n        strength_str = (\n            f\"{strength:.2f}\"\n            if isinstance(strength, (int, float))\n            else \", \".join(f\"{s:.2f}\" for s in strength)\n        )\n    return nickname_str, target_str, strength_str\n\n\n@dataclass\nclass GenerationResult:\n    \"\"\"Result of a single generation request from DiffGenerator.\"\"\"\n\n    samples: Any = None\n    frames: Any = None\n    audio: Any = None\n    prompt: str | None = None\n    size: tuple | None = None  # (height, width, num_frames)\n    generation_time: float = 0.0\n    peak_memory_mb: float = 0.0\n    metrics: dict = field(default_factory=dict)\n    trajectory_latents: Any = None\n    trajectory_timesteps: Any = None\n    trajectory_decoded: Any = None\n    prompt_index: int = 0\n    output_file_path: str | None = None\n\n\ndef _normalize_audio_to_numpy(audio: Any) -> np.ndarray | None:\n    \"\"\"Convert audio (torch / numpy) into a float32 numpy array in [-1, 1], best-effort.\"\"\"\n    if audio is None:\n        return None\n    if isinstance(audio, torch.Tensor):\n        audio_np = audio.detach().float().clamp(-1.0, 1.0).cpu().numpy()\n    elif isinstance(audio, np.ndarray):\n        audio_np = audio.astype(np.float32, copy=False)\n        audio_np = np.clip(audio_np, -1.0, 1.0)\n    else:\n        return None\n\n    # 1. Squeeze leading singleton dimensions (Batch, etc.)\n    while audio_np.ndim > 1 and audio_np.shape[0] == 1:\n        audio_np = audio_np.squeeze(0)\n\n    # 2. 
Handle (C, L) -> (L, C)\n    if audio_np.ndim == 2 and audio_np.shape[0] < audio_np.shape[1]:\n        audio_np = audio_np.transpose(1, 0)\n\n    # 3. Final safety check: if still 2D and channels (dim 1) is huge, something is wrong\n    if audio_np.ndim == 2 and audio_np.shape[1] > 256 and audio_np.shape[0] == 1:\n        audio_np = audio_np.flatten()\n\n    return audio_np\n\n\ndef _pick_audio_sample_rate(\n    *,\n    audio_np: np.ndarray,\n    audio_sample_rate: Optional[int],\n    fps: int,\n    num_frames: int,\n) -> int:\n    \"\"\"Pick a plausible sample rate, falling back to inferring from video duration.\"\"\"\n    selected_sr = int(audio_sample_rate) if audio_sample_rate is not None else None\n    if selected_sr is None or not (8000 <= selected_sr <= 192000):\n        selected_sr = 24000\n        try:\n            duration_s = float(num_frames) / float(fps) if fps else 0.0\n            if duration_s > 0:\n                audio_len = (\n                    int(audio_np.shape[0])\n                    if audio_np.ndim == 2\n                    else int(audio_np.shape[-1])\n                )\n                inferred_sr = int(round(float(audio_len) / duration_s))\n                if 8000 <= inferred_sr <= 192000:\n                    selected_sr = inferred_sr\n        except Exception:\n            pass\n    return selected_sr\n\n\ndef _resolve_ffmpeg_exe() -> str:\n    ffmpeg_exe = \"ffmpeg\"\n    ffmpeg_on_path = shutil.which(\"ffmpeg\")\n    if ffmpeg_on_path:\n        ffmpeg_exe = ffmpeg_on_path\n    try:\n        if _imageio_ffmpeg is not None:\n            ffmpeg_exe = _imageio_ffmpeg.get_ffmpeg_exe()\n    except Exception:\n        pass\n\n    ffmpeg_ok = False\n    if ffmpeg_exe:\n        if os.path.isabs(ffmpeg_exe):\n            ffmpeg_ok = os.path.exists(ffmpeg_exe)\n        else:\n            ffmpeg_ok = shutil.which(ffmpeg_exe) is not None\n    if not ffmpeg_ok:\n        raise RuntimeError(\"ffmpeg not found\")\n    return ffmpeg_exe\n\n\ndef 
_mux_audio_np_into_mp4(\n    *,\n    save_file_path: str,\n    audio_np: np.ndarray,\n    sample_rate: int,\n    ffmpeg_exe: str,\n) -> None:\n    merged_path = save_file_path.rsplit(\".\", 1)[0] + \".tmp_mux.mp4\"\n    tmp_wav_path = None\n    try:\n        if scipy_wavfile is None:\n            raise RuntimeError(\n                \"scipy is required to mux audio into mp4 (pip install scipy)\"\n            )\n        with tempfile.NamedTemporaryFile(suffix=\".wav\", delete=False) as f:\n            tmp_wav_path = f.name\n        scipy_wavfile.write(tmp_wav_path, sample_rate, audio_np)\n        subprocess.run(\n            [\n                ffmpeg_exe,\n                \"-y\",\n                \"-i\",\n                save_file_path,\n                \"-i\",\n                tmp_wav_path,\n                \"-c:v\",\n                \"copy\",\n                \"-c:a\",\n                \"aac\",\n                \"-strict\",\n                \"experimental\",\n                merged_path,\n            ],\n            check=True,\n            stdout=subprocess.DEVNULL,\n            stderr=subprocess.DEVNULL,\n        )\n        os.replace(merged_path, save_file_path)\n    finally:\n        if tmp_wav_path:\n            try:\n                os.remove(tmp_wav_path)\n            except OSError:\n                pass\n        if os.path.exists(merged_path):\n            try:\n                os.remove(merged_path)\n            except OSError:\n                pass\n\n\ndef _maybe_mux_audio_into_mp4(\n    *,\n    save_file_path: str,\n    audio: Any,\n    frames: list,\n    fps: int,\n    audio_sample_rate: Optional[int],\n) -> None:\n    \"\"\"Best-effort mux audio into an already-written mp4 at save_file_path.\n\n    Any failure should keep the silent video and only log a warning.\n    \"\"\"\n    audio_np = _normalize_audio_to_numpy(audio)\n    if audio_np is None:\n        return\n    selected_sr = _pick_audio_sample_rate(\n        audio_np=audio_np,\n        
audio_sample_rate=audio_sample_rate,\n        fps=fps,\n        num_frames=len(frames),\n    )\n\n    try:\n        ffmpeg_exe = _resolve_ffmpeg_exe()\n        _mux_audio_np_into_mp4(\n            save_file_path=save_file_path,\n            audio_np=audio_np,\n            sample_rate=selected_sr,\n            ffmpeg_exe=ffmpeg_exe,\n        )\n        logger.info(f\"Merged video saved to {CYAN}{save_file_path}{RESET}\")\n    except Exception as e:\n        logger.warning(\n            \"Failed to mux audio into mp4 (saved silent video): %s\",\n            str(e),\n        )\n\n\ndef prepare_request(\n    server_args: ServerArgs,\n    sampling_params: SamplingParams,\n) -> Req:\n    \"\"\"\n    Create a Req object with sampling_params as a parameter.\n    \"\"\"\n    req = Req(\n        sampling_params=sampling_params,\n        VSA_sparsity=server_args.attention_backend_config.VSA_sparsity,\n    )\n    try:\n        diffusers_kwargs = sampling_params.diffusers_kwargs\n    except AttributeError:\n        diffusers_kwargs = None\n    if diffusers_kwargs:\n        req.extra[\"diffusers_kwargs\"] = diffusers_kwargs\n\n    req.adjust_size(server_args)\n\n    if not isinstance(req.prompt, str):\n        raise TypeError(f\"`prompt` must be a string, but got {type(req.prompt)}\")\n\n    if (req.width is not None and req.width <= 0) or (\n        req.height is not None and req.height <= 0\n    ):\n        raise ValueError(\n            f\"Height and width must be positive, got height={req.height}, width={req.width}\"\n        )\n\n    return req\n\n\ndef attach_audio_to_video_sample(\n    sample: Any,\n    audio: Any,\n    output_idx: int,\n) -> Any:\n    \"\"\"Attach per-sample audio for video outputs when available.\"\"\"\n    if audio is None:\n        return sample\n    if isinstance(audio, torch.Tensor) and audio.ndim >= 2:\n        audio = audio[output_idx] if audio.shape[0] > output_idx else None\n    elif isinstance(audio, np.ndarray) and audio.ndim >= 2:\n        
audio = audio[output_idx] if audio.shape[0] > output_idx else None\n\n    if audio is not None and not (\n        isinstance(sample, (tuple, list)) and len(sample) == 2\n    ):\n        return (sample, audio)\n    return sample\n\n\ndef save_outputs(\n    outputs: Sequence[Any],\n    data_type: DataType,\n    fps: int,\n    save_output: bool,\n    build_output_path: Callable[[int], str],\n    *,\n    audio: Any = None,\n    audio_sample_rate: Optional[int] = None,\n    samples_out: Optional[list[Any]] = None,\n    audios_out: Optional[list[Any]] = None,\n    frames_out: Optional[list[Any]] = None,\n    output_compression: Optional[int] = None,\n    enable_frame_interpolation: bool = False,\n    frame_interpolation_exp: int = 1,\n    frame_interpolation_scale: float = 1.0,\n    frame_interpolation_model_path: Optional[str] = None,\n    enable_upscaling: bool = False,\n    upscaling_model_path: Optional[str] = None,\n    upscaling_scale: int = 4,\n) -> list[str]:\n    \"\"\"Save outputs to files and return the list of file paths.\"\"\"\n    output_paths: list[str] = []\n    for idx, output in enumerate(outputs):\n        save_file_path = build_output_path(idx)\n        sample = output\n        if data_type == DataType.VIDEO:\n            sample = attach_audio_to_video_sample(sample, audio, idx)\n\n        frames = post_process_sample(\n            sample,\n            data_type,\n            fps,\n            save_output,\n            save_file_path,\n            audio_sample_rate=audio_sample_rate,\n            output_compression=output_compression,\n            enable_frame_interpolation=enable_frame_interpolation,\n            frame_interpolation_exp=frame_interpolation_exp,\n            frame_interpolation_scale=frame_interpolation_scale,\n            frame_interpolation_model_path=frame_interpolation_model_path,\n            enable_upscaling=enable_upscaling,\n            upscaling_model_path=upscaling_model_path,\n            upscaling_scale=upscaling_scale,\n  
      )\n\n        if samples_out is not None:\n            samples_out.append(sample)\n        if audios_out is not None:\n            if data_type == DataType.VIDEO:\n                audio_item = audio\n                if isinstance(audio, torch.Tensor) and audio.ndim >= 2:\n                    audio_item = audio[idx] if audio.shape[0] > idx else None\n                elif isinstance(audio, np.ndarray) and audio.ndim >= 2:\n                    audio_item = audio[idx] if audio.shape[0] > idx else None\n                audios_out.append(audio_item)\n            else:\n                audios_out.append(audio)\n        if frames_out is not None:\n            frames_out.append(frames)\n        output_paths.append(save_file_path)\n    return output_paths\n\n\ndef post_process_sample(\n    sample: Any,\n    data_type: DataType,\n    fps: int,\n    save_output: bool = True,\n    save_file_path: Optional[str] = None,\n    audio_sample_rate: Optional[int] = None,\n    output_compression: Optional[int] = None,\n    enable_frame_interpolation: bool = False,\n    frame_interpolation_exp: int = 1,\n    frame_interpolation_scale: float = 1.0,\n    frame_interpolation_model_path: Optional[str] = None,\n    enable_upscaling: bool = False,\n    upscaling_model_path: Optional[str] = None,\n    upscaling_scale: int = 4,\n):\n    \"\"\"\n    Process sample output, optionally interpolate video frames, and save.\n    \"\"\"\n    audio = None\n    if isinstance(sample, (tuple, list)) and len(sample) == 2:\n        sample, audio = sample\n\n    # 1. 
Convert tensor / array to list of uint8 HWC frames\n    frames = None\n    if isinstance(sample, torch.Tensor):\n        if sample.dim() == 3:\n            sample = sample.unsqueeze(1)\n        sample = (sample * 255).clamp(0, 255).to(torch.uint8)\n        videos = sample.permute(1, 2, 3, 0).cpu().numpy()\n        frames = list(videos)\n    else:\n        if not isinstance(sample, np.ndarray):\n            raise TypeError(f\"Unsupported sample type: {type(sample)}\")\n\n        arr = sample\n        if arr.ndim == 3:\n            if arr.shape[-1] in (1, 3, 4):\n                arr = arr[None, ...]\n            else:\n                arr = arr[..., None]\n        if arr.ndim != 4:\n            raise ValueError(f\"Unexpected numpy sample shape: {tuple(arr.shape)}\")\n\n        if arr.shape[-1] not in (1, 3, 4) and arr.shape[0] in (1, 3, 4):\n            t = torch.from_numpy(arr)\n            if t.dim() == 3:\n                t = t.unsqueeze(1)\n            t = (t * 255).clamp(0, 255).to(torch.uint8)\n            videos = t.permute(1, 2, 3, 0).cpu().numpy()\n            frames = list(videos)\n        else:\n            if arr.dtype != np.uint8:\n                arr = (np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)\n            frames = list(arr)\n\n    # 2. Frame interpolation (video only)\n    if enable_frame_interpolation and data_type == DataType.VIDEO and len(frames) > 1:\n        from sglang.multimodal_gen.runtime.postprocess import (\n            interpolate_video_frames,\n        )\n\n        frames, multiplier = interpolate_video_frames(\n            frames,\n            exp=frame_interpolation_exp,\n            scale=frame_interpolation_scale,\n            model_path=frame_interpolation_model_path,\n        )\n        fps = fps * multiplier\n\n    # 3. 
Upscaling (images and videos)\n    if enable_upscaling and frames:\n        from sglang.multimodal_gen.runtime.postprocess import upscale_frames\n\n        frames = upscale_frames(\n            frames,\n            model_path=upscaling_model_path,\n            scale=upscaling_scale,\n        )\n\n    # 4. Save outputs if requested\n    if save_output:\n        if save_file_path:\n            os.makedirs(os.path.dirname(save_file_path), exist_ok=True)\n            if data_type == DataType.VIDEO:\n                quality = (\n                    output_compression / 10 if output_compression is not None else 5\n                )\n                imageio.mimsave(\n                    save_file_path,\n                    frames,\n                    fps=fps,\n                    format=data_type.get_default_extension(),\n                    codec=\"libx264\",\n                    quality=quality,\n                )\n\n                _maybe_mux_audio_into_mp4(\n                    save_file_path=save_file_path,\n                    audio=audio,\n                    frames=frames,\n                    fps=fps,\n                    audio_sample_rate=audio_sample_rate,\n                )\n\n            else:\n                quality = output_compression if output_compression is not None else 75\n                if len(frames) > 1:\n                    for i, image in enumerate(frames):\n                        parts = save_file_path.rsplit(\".\", 1)\n                        if len(parts) == 2:\n                            indexed_path = f\"{parts[0]}_{i}.{parts[1]}\"\n                        else:\n                            indexed_path = f\"{save_file_path}_{i}\"\n                        imageio.imwrite(indexed_path, image, quality=quality)\n                else:\n                    imageio.imwrite(save_file_path, frames[0], quality=quality)\n            logger.info(f\"Output saved to {CYAN}{save_file_path}{RESET}\")\n        else:\n            logger.info(f\"No output 
path provided, output not saved\")\n\n    return frames\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/launch_server.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nimport multiprocessing as mp\nimport os\nimport signal\nimport sys\nimport threading\n\nimport psutil\nimport uvicorn\n\nfrom sglang.multimodal_gen.runtime.entrypoints.http_server import create_app\nfrom sglang.multimodal_gen.runtime.managers.gpu_worker import run_scheduler_process\nfrom sglang.multimodal_gen.runtime.server_args import (\n    ServerArgs,\n    prepare_server_args,\n    set_global_server_args,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import configure_logger, logger\n\n\ndef kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):\n    \"\"\"Kill the process and all its child processes.\"\"\"\n    # Remove sigchld handler to avoid spammy logs.\n    if threading.current_thread() is threading.main_thread():\n        signal.signal(signal.SIGCHLD, signal.SIG_DFL)\n\n    if parent_pid is None:\n        parent_pid = os.getpid()\n        include_parent = False\n\n    try:\n        itself = psutil.Process(parent_pid)\n    except psutil.NoSuchProcess:\n        return\n\n    children = itself.children(recursive=True)\n    for child in children:\n        if child.pid == skip_pid:\n            continue\n        try:\n            child.kill()\n        except psutil.NoSuchProcess:\n            pass\n\n    if include_parent:\n        try:\n            if parent_pid == os.getpid():\n                itself.kill()\n                sys.exit(0)\n\n            itself.kill()\n\n            # Sometimes processes cannot be killed with SIGKILL (e.g., PID=1 launched by Kubernetes),\n            # so we send an additional signal to kill them.\n            itself.send_signal(signal.SIGQUIT)\n        except psutil.NoSuchProcess:\n            pass\n\n\ndef launch_server(server_args: ServerArgs, launch_http_server: bool = True):\n    \"\"\"\n    Args:\n        launch_http_server: False for offline local mode\n    \"\"\"\n    
configure_logger(server_args)\n\n    # Start a new server with multiple worker processes\n    logger.info(\"Starting server...\")\n\n    num_gpus = server_args.num_gpus\n    processes = []\n\n    # Pipes for master to talk to slaves\n    task_pipes_to_slaves_w = []\n    task_pipes_to_slaves_r = []\n    for _ in range(num_gpus - 1):\n        r, w = mp.Pipe(duplex=False)\n        task_pipes_to_slaves_r.append(r)\n        task_pipes_to_slaves_w.append(w)\n\n    # Pipes for slaves to talk to master\n    result_pipes_from_slaves_w = []\n    result_pipes_from_slaves_r = []\n    for _ in range(num_gpus - 1):\n        r, w = mp.Pipe(duplex=False)\n        result_pipes_from_slaves_r.append(r)\n        result_pipes_from_slaves_w.append(w)\n\n    # Launch all worker processes\n    master_port = server_args.master_port or (server_args.master_port + 100)\n    scheduler_pipe_readers = []\n    scheduler_pipe_writers = []\n\n    for i in range(num_gpus):\n        reader, writer = mp.Pipe(duplex=False)\n        scheduler_pipe_writers.append(writer)\n        if i == 0:  # Master worker\n            process = mp.Process(\n                target=run_scheduler_process,\n                args=(\n                    i,  # local_rank\n                    i,  # rank\n                    master_port,\n                    server_args,\n                    writer,\n                    None,  # No task pipe to read from master\n                    None,  # No result pipe to write to master\n                    task_pipes_to_slaves_w,\n                    result_pipes_from_slaves_r,\n                ),\n                name=f\"sglang-diffusionWorker-{i}\",\n                daemon=True,\n            )\n        else:  # Slave workers\n            process = mp.Process(\n                target=run_scheduler_process,\n                args=(\n                    i,  # local_rank\n                    i,  # rank\n                    master_port,\n                    server_args,\n                    
writer,\n                    None,  # No task pipe to read from master\n                    None,  # No result pipe to write to master\n                    task_pipes_to_slaves_r[i - 1],\n                    result_pipes_from_slaves_w[i - 1],\n                ),\n                name=f\"sglang-diffusionWorker-{i}\",\n                daemon=True,\n            )\n        scheduler_pipe_readers.append(reader)\n        process.start()\n        processes.append(process)\n\n    # Wait for all workers to be ready\n    scheduler_infos = []\n    for writer in scheduler_pipe_writers:\n        writer.close()\n\n    # Close unused pipe ends in parent process\n    for p in task_pipes_to_slaves_w:\n        p.close()\n    for p in task_pipes_to_slaves_r:\n        p.close()\n    for p in result_pipes_from_slaves_w:\n        p.close()\n    for p in result_pipes_from_slaves_r:\n        p.close()\n\n    for i, reader in enumerate(scheduler_pipe_readers):\n        try:\n            data = reader.recv()\n        except EOFError:\n            logger.error(\n                f\"Rank {i} scheduler is dead. Please check if there are relevant logs.\"\n            )\n            processes[i].join()\n            logger.error(f\"Exit code: {processes[i].exitcode}\")\n            raise\n\n        if data[\"status\"] != \"ready\":\n            raise RuntimeError(\n                \"Initialization failed. 
Please see the error messages above.\"\n            )\n        scheduler_infos.append(data)\n        reader.close()\n\n    logger.debug(\"All workers are ready\")\n\n    if launch_http_server:\n        logger.info(\"Starting FastAPI server.\")\n        if server_args.webui:\n            logger.info(\"Launch FastAPI server in another process because of webui.\")\n            http_server_process = mp.Process(\n                target=launch_http_server_only,\n                args=(server_args,),\n                name=f\"sglang-diffusion-webui\",\n                daemon=True,\n            )\n            http_server_process.start()\n        else:\n            launch_http_server_only(server_args)\n\n    return processes\n\n\ndef launch_http_server_only(server_args):\n    # set for endpoints to access global_server_args\n    set_global_server_args(server_args)\n    app = create_app(server_args)\n    uvicorn.run(\n        app,\n        use_colors=True,\n        log_level=server_args.log_level,\n        host=server_args.host,\n        port=server_args.port,\n        reload=False,\n    )\n\n\nif __name__ == \"__main__\":\n    server_args = prepare_server_args(sys.argv[1:])\n\n    try:\n        launch_server(server_args)\n    finally:\n        kill_process_tree(os.getpid(), include_parent=False)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/activation.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/activation.py\n\"\"\"Custom activation functions.\"\"\"\n\nimport math\nfrom typing import Any\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\n\n_is_cuda = current_platform.is_cuda()\n_is_hip = current_platform.is_hip()\n_is_npu = current_platform.is_npu()\nif _is_cuda or _is_hip:\n    from sgl_kernel import silu_and_mul\n\nif _is_npu:\n    import torch_npu\n# TODO (will): remove this dependency\nfrom sglang.multimodal_gen.runtime.layers.custom_op import CustomOp\n\n\n@CustomOp.register(\"silu_and_mul\")\nclass SiluAndMul(CustomOp):\n    \"\"\"An activation function for SwiGLU.\n\n    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.\n\n    Shapes:\n        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)\n        return: (num_tokens, d) or (batch_size, seq_len, d)\n    \"\"\"\n\n    def __init__(self) -> None:\n        super().__init__()\n\n    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:\n        d = x.shape[-1] // 2\n        output_shape = x.shape[:-1] + (d,)\n        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)\n        silu_and_mul(x, out)\n        return out\n\n    def forward_native(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"PyTorch-native implementation equivalent to forward().\"\"\"\n        d = x.shape[-1] // 2\n        return F.silu(x[..., :d]) * x[..., d:]\n\n    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:\n        out = torch_npu.npu_swiglu(x)\n        return out\n\n    def forward_musa(self, x: torch.Tensor) -> torch.Tensor:\n        return nn.SwishGLU()(x)\n\n\n@CustomOp.register(\"gelu_and_mul\")\nclass GeluAndMul(CustomOp):\n    \"\"\"An activation 
function for GeGLU.\n\n    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.\n\n    Shapes:\n        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)\n        return: (batch_size, seq_len, d) or (num_tokens, d)\n    \"\"\"\n\n    def __init__(self, approximate: str = \"none\"):\n        super().__init__()\n        self.approximate = approximate\n        if approximate not in (\"none\", \"tanh\"):\n            raise ValueError(f\"Unknown approximate mode: {approximate}\")\n\n    def forward_cuda(self, *args, **kwargs) -> Any:\n        return self.forward_native(*args, **kwargs)\n\n    def forward_native(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"PyTorch-native implementation equivalent to forward().\"\"\"\n        d = x.shape[-1] // 2\n        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]\n\n    def extra_repr(self) -> str:\n        return f\"approximate={repr(self.approximate)}\"\n\n\n@CustomOp.register(\"gelu_new\")\nclass NewGELU(CustomOp):\n\n    def __init__(self):\n        super().__init__()\n\n    def forward_cuda(self, *args, **kwargs) -> Any:\n        return self.forward_native(*args, **kwargs)\n\n    def forward_native(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"PyTorch-native implementation equivalent to forward().\"\"\"\n        c = math.sqrt(2.0 / math.pi)\n        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))\n\n\n@CustomOp.register(\"quick_gelu\")\nclass QuickGELU(CustomOp):\n    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90\n    def __init__(self):\n        super().__init__()\n\n    def forward_cuda(self, *args, **kwargs) -> Any:\n        return self.forward_native(*args, **kwargs)\n\n    def forward_native(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"PyTorch-native implementation equivalent to forward().\"\"\"\n        return x * torch.sigmoid(1.702 * x)\n\n\n_ACTIVATION_REGISTRY = {\n    
\"gelu\": nn.GELU,\n    \"gelu_new\": NewGELU,\n    \"gelu_pytorch_tanh\": lambda: nn.GELU(approximate=\"tanh\"),\n    \"relu\": nn.ReLU,\n    \"silu\": nn.SiLU,\n    \"quick_gelu\": QuickGELU,\n}\n\n\ndef get_act_fn(act_fn_name: str) -> nn.Module:\n    \"\"\"Get an activation function by name.\"\"\"\n    act_fn_name = act_fn_name.lower()\n    if act_fn_name not in _ACTIVATION_REGISTRY:\n        raise ValueError(f\"Activation function {act_fn_name!r} is not supported.\")\n\n    return _ACTIVATION_REGISTRY[act_fn_name]()\n\n\n_ACTIVATION_AND_MUL_REGISTRY = {\n    \"gelu\": GeluAndMul,\n    \"silu\": SiluAndMul,\n}\n\n\ndef get_act_and_mul_fn(act_fn_name: str) -> nn.Module:\n    \"\"\"Get an activation-and-mul (e.g. SiluAndMul) function by name.\"\"\"\n    act_fn_name = act_fn_name.lower()\n    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:\n        raise ValueError(f\"Activation function {act_fn_name!r} is not supported.\")\n\n    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/STA_configuration.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nimport json\nimport os\nfrom collections import defaultdict\nfrom typing import Any\n\nimport numpy as np\n\nfrom sglang.multimodal_gen.utils import dict_to_3d_list\n\n\ndef configure_sta(\n    mode: str = \"STA_searching\",\n    layer_num: int = 40,\n    time_step_num: int = 50,\n    head_num: int = 40,\n    **kwargs,\n) -> list[list[list[Any]]]:\n    \"\"\"\n    Configure Sliding Tile Attention (STA) parameters based on the specified mode.\n\n    Parameters:\n    ----------\n    mode : str\n        The STA mode to use. Options are:\n        - 'STA_searching': Generate a set of mask candidates for initial search\n        - 'STA_tuning': Select best mask strategy based on previously saved results\n        - 'STA_inference': Load and use a previously tuned mask strategy\n    layer_num: int, number of layers\n    time_step_num: int, number of timesteps\n    head_num: int, number of heads\n\n    **kwargs : dict\n        Mode-specific parameters:\n\n        For 'STA_searching':\n        - mask_candidates: list of str, optional, mask candidates to use\n        - mask_selected: list of int, optional, indices of selected masks\n\n        For 'STA_tuning':\n        - mask_search_files_path: str, required, path to mask search results\n        - mask_candidates: list of str, optional, mask candidates to use\n        - mask_selected: list of int, optional, indices of selected masks\n        - skip_time_steps: int, optional, number of time steps to use full attention (default 12)\n        - save_dir: str, optional, directory to save mask strategy (default \"mask_candidates\")\n\n        For 'STA_inference':\n        - load_path: str, optional, path to load mask strategy (default \"mask_candidates/mask_strategy.json\")\n    \"\"\"\n    valid_modes = [\"STA_searching\", \"STA_tuning\", \"STA_inference\", \"STA_tuning_cfg\"]\n    if mode not in 
valid_modes:\n        raise ValueError(f\"Mode must be one of {valid_modes}, got {mode}\")\n\n    if mode == \"STA_searching\":\n        # Get parameters with defaults\n        mask_candidates: list[str] | None = kwargs.get(\"mask_candidates\")\n        if mask_candidates is None:\n            raise ValueError(\"mask_candidates is required for STA_searching mode\")\n        mask_selected: list[int] = kwargs.get(\n            \"mask_selected\", list(range(len(mask_candidates)))\n        )\n\n        # Parse selected masks\n        selected_masks: list[list[int]] = []\n        for index in mask_selected:\n            mask = mask_candidates[index]\n            masks_list = [int(x) for x in mask.split(\",\")]\n            selected_masks.append(masks_list)\n\n        # Create 3D mask structure with dimensions (time_step_num, layer_num)\n        masks_3d: list[list[list[list[int]]]] = []\n        for i in range(time_step_num):  # t dimension = time_step_num\n            row = []\n            for j in range(layer_num):  # l dimension = layer_num\n                row.append(selected_masks)  # Add all masks at each position\n            masks_3d.append(row)\n\n        return masks_3d\n\n    elif mode == \"STA_tuning\":\n        # Get required parameters\n        mask_search_files_path: str | None = kwargs.get(\"mask_search_files_path\")\n        if not mask_search_files_path:\n            raise ValueError(\"mask_search_files_path is required for STA_tuning mode\")\n\n        # Get optional parameters with defaults\n        mask_candidates_tuning: list[str] | None = kwargs.get(\"mask_candidates\")\n        if mask_candidates_tuning is None:\n            raise ValueError(\"mask_candidates is required for STA_tuning mode\")\n        mask_selected_tuning: list[int] = kwargs.get(\n            \"mask_selected\", list(range(len(mask_candidates_tuning)))\n        )\n        skip_time_steps_tuning: int | None = kwargs.get(\"skip_time_steps\")\n        save_dir_tuning: str | None = 
kwargs.get(\"save_dir\", \"mask_candidates\")\n\n        # Parse selected masks\n        selected_masks_tuning: list[list[int]] = []\n        for index in mask_selected_tuning:\n            mask = mask_candidates_tuning[index]\n            masks_list = [int(x) for x in mask.split(\",\")]\n            selected_masks_tuning.append(masks_list)\n\n        # Read JSON results\n        results = read_specific_json_files(mask_search_files_path)\n        averaged_results = average_head_losses(results, selected_masks_tuning)\n\n        # Add full attention mask for specific cases\n        full_attention_mask_tuning: list[int] | None = kwargs.get(\"full_attention_mask\")\n        if full_attention_mask_tuning is not None:\n            selected_masks_tuning.append(full_attention_mask_tuning)\n\n        # Select best mask strategy\n        timesteps_tuning: int = kwargs.get(\"timesteps\", time_step_num)\n        if skip_time_steps_tuning is None:\n            skip_time_steps_tuning = 12\n        mask_strategy, sparsity, strategy_counts = select_best_mask_strategy(\n            averaged_results,\n            selected_masks_tuning,\n            skip_time_steps_tuning,\n            timesteps_tuning,\n            head_num,\n        )\n\n        # Save mask strategy\n        if save_dir_tuning is not None:\n            os.makedirs(save_dir_tuning, exist_ok=True)\n            file_path = os.path.join(\n                save_dir_tuning, f\"mask_strategy_s{skip_time_steps_tuning}.json\"\n            )\n            with open(file_path, \"w\") as f:\n                json.dump(mask_strategy, f, indent=4)\n            print(f\"Successfully saved mask_strategy to {file_path}\")\n\n        # Print sparsity and strategy counts for information\n        print(f\"Overall sparsity: {sparsity:.4f}\")\n        print(\"\\nStrategy usage counts:\")\n        total_heads = time_step_num * layer_num * head_num  # Fixed dimensions\n        for strategy, count in strategy_counts.items():\n            
print(f\"Strategy {strategy}: {count} heads ({count/total_heads*100:.2f}%)\")\n\n        # Convert dictionary to 3D list with fixed dimensions\n        mask_strategy_3d = dict_to_3d_list(\n            mask_strategy, t_max=time_step_num, l_max=layer_num, h_max=head_num\n        )\n\n        return mask_strategy_3d\n    elif mode == \"STA_tuning_cfg\":\n        # Get required parameters for both positive and negative paths\n        mask_search_files_path_pos: str | None = kwargs.get(\n            \"mask_search_files_path_pos\"\n        )\n        mask_search_files_path_neg: str | None = kwargs.get(\n            \"mask_search_files_path_neg\"\n        )\n        save_dir_cfg: str | None = kwargs.get(\"save_dir\")\n\n        if (\n            not mask_search_files_path_pos\n            or not mask_search_files_path_neg\n            or not save_dir_cfg\n        ):\n            raise ValueError(\n                \"mask_search_files_path_pos, mask_search_files_path_neg, and save_dir are required for STA_tuning_cfg mode\"\n            )\n\n        # Get optional parameters with defaults\n        mask_candidates_cfg: list[str] | None = kwargs.get(\"mask_candidates\")\n        if mask_candidates_cfg is None:\n            raise ValueError(\"mask_candidates is required for STA_tuning_cfg mode\")\n        mask_selected_cfg: list[int] = kwargs.get(\n            \"mask_selected\", list(range(len(mask_candidates_cfg)))\n        )\n        skip_time_steps_cfg: int | None = kwargs.get(\"skip_time_steps\")\n\n        # Parse selected masks\n        selected_masks_cfg: list[list[int]] = []\n        for index in mask_selected_cfg:\n            mask = mask_candidates_cfg[index]\n            masks_list = [int(x) for x in mask.split(\",\")]\n            selected_masks_cfg.append(masks_list)\n\n        # Read JSON results for both positive and negative paths\n        pos_results = read_specific_json_files(mask_search_files_path_pos)\n        neg_results = 
read_specific_json_files(mask_search_files_path_neg)\n        # Combine positive and negative results into one list\n        combined_results = pos_results + neg_results\n\n        # Average the combined results\n        averaged_results = average_head_losses(combined_results, selected_masks_cfg)\n\n        # Add full attention mask for specific cases\n        full_attention_mask_cfg: list[int] | None = kwargs.get(\"full_attention_mask\")\n        if full_attention_mask_cfg is not None:\n            selected_masks_cfg.append(full_attention_mask_cfg)\n\n        timesteps_cfg: int = kwargs.get(\"timesteps\", time_step_num)\n        if skip_time_steps_cfg is None:\n            skip_time_steps_cfg = 12\n        # Select best mask strategy using combined results\n        mask_strategy, sparsity, strategy_counts = select_best_mask_strategy(\n            averaged_results,\n            selected_masks_cfg,\n            skip_time_steps_cfg,\n            timesteps_cfg,\n            head_num,\n        )\n\n        # Save mask strategy\n        os.makedirs(save_dir_cfg, exist_ok=True)\n        file_path = os.path.join(\n            save_dir_cfg, f\"mask_strategy_s{skip_time_steps_cfg}.json\"\n        )\n        with open(file_path, \"w\") as f:\n            json.dump(mask_strategy, f, indent=4)\n        print(f\"Successfully saved mask_strategy to {file_path}\")\n\n        # Print sparsity and strategy counts for information\n        print(f\"Overall sparsity: {sparsity:.4f}\")\n        print(\"\\nStrategy usage counts:\")\n        total_heads = time_step_num * layer_num * head_num  # Fixed dimensions\n        for strategy, count in strategy_counts.items():\n            print(f\"Strategy {strategy}: {count} heads ({count/total_heads*100:.2f}%)\")\n\n        # Convert dictionary to 3D list with fixed dimensions\n        mask_strategy_3d = dict_to_3d_list(\n            mask_strategy, t_max=time_step_num, l_max=layer_num, h_max=head_num\n        )\n\n        return 
mask_strategy_3d\n\n    else:  # STA_inference\n        # Get parameters with defaults\n        load_path: str | None = kwargs.get(\n            \"load_path\", \"mask_candidates/mask_strategy.json\"\n        )\n        if load_path is None:\n            raise ValueError(\"load_path is required for STA_inference mode\")\n\n        # Load previously saved mask strategy\n        with open(load_path) as f:\n            mask_strategy = json.load(f)\n\n        # Convert dictionary to 3D list with fixed dimensions\n        mask_strategy_3d = dict_to_3d_list(\n            mask_strategy, t_max=time_step_num, l_max=layer_num, h_max=head_num\n        )\n\n        return mask_strategy_3d\n\n\n# Helper functions\n\n\ndef read_specific_json_files(folder_path: str) -> list[dict[str, Any]]:\n    \"\"\"Read and parse JSON files containing mask search results.\"\"\"\n    json_contents: list[dict[str, Any]] = []\n\n    # List files only in the current directory (no walk)\n    files = os.listdir(folder_path)\n    # Filter files\n    matching_files = [f for f in files if \"mask\" in f and f.endswith(\".json\")]\n    print(f\"Found {len(matching_files)} matching files: {matching_files}\")\n\n    for file_name in matching_files:\n        file_path = os.path.join(folder_path, file_name)\n        with open(file_path) as file:\n            data = json.load(file)\n            json_contents.append(data)\n\n    return json_contents\n\n\ndef average_head_losses(\n    results: list[dict[str, Any]], selected_masks: list[list[int]]\n) -> dict[str, dict[str, np.ndarray]]:\n    \"\"\"Average losses across all prompts for each mask strategy.\"\"\"\n    # Initialize a dictionary to store the averaged results\n    averaged_losses: dict[str, dict[str, np.ndarray]] = {}\n    loss_type = \"L2_loss\"\n    # Get all loss types (e.g., 'L2_loss')\n    averaged_losses[loss_type] = {}\n\n    for mask in selected_masks:\n        mask_str = str(mask)\n        data_shape = 
np.array(results[0][loss_type][mask_str]).shape\n        accumulated_data = np.zeros(data_shape)\n\n        # Sum across all prompts\n        for prompt_result in results:\n            accumulated_data += np.array(prompt_result[loss_type][mask_str])\n\n        # Average by dividing by number of prompts\n        averaged_data = accumulated_data / len(results)\n        averaged_losses[loss_type][mask_str] = averaged_data\n\n    return averaged_losses\n\n\ndef select_best_mask_strategy(\n    averaged_results: dict[str, dict[str, np.ndarray]],\n    selected_masks: list[list[int]],\n    skip_time_steps: int = 12,\n    timesteps: int = 50,\n    head_num: int = 40,\n) -> tuple[dict[str, list[int]], float, dict[str, int]]:\n    \"\"\"Select the best mask strategy for each head based on loss minimization.\"\"\"\n    best_mask_strategy: dict[str, list[int]] = {}\n    loss_type = \"L2_loss\"\n    # Get the shape of time steps and layers\n    layers = len(averaged_results[loss_type][str(selected_masks[0])][0])\n\n    # Counter for sparsity calculation\n    total_tokens = 0  # total number of masked tokens\n    total_length = 0  # total sequence length\n\n    strategy_counts: dict[str, int] = {str(strategy): 0 for strategy in selected_masks}\n    full_attn_strategy = selected_masks[-1]  # Last strategy is full attention\n    print(f\"Strategy {full_attn_strategy}, skip first {skip_time_steps} steps \")\n\n    for t in range(timesteps):\n        for layer_idx in range(layers):\n            for h in range(head_num):\n                if t < skip_time_steps:  # First steps use full attention\n                    strategy = full_attn_strategy\n                else:\n                    # Get losses for this head across all strategies\n                    head_losses = []\n                    for strategy in selected_masks[:-1]:  # Exclude full attention\n                        head_losses.append(\n                            
averaged_results[loss_type][str(strategy)][t][layer_idx][h]\n                        )\n\n                    # Find which strategy gives minimum loss\n                    best_strategy_idx = np.argmin(head_losses)\n                    strategy = selected_masks[best_strategy_idx]\n\n                best_mask_strategy[f\"{t}_{layer_idx}_{h}\"] = strategy\n\n                # Calculate sparsity\n                nums = strategy  # strategy is already a list of numbers\n                total_tokens += (\n                    nums[0] * nums[1] * nums[2]\n                )  # masked tokens for chosen strategy\n                total_length += (\n                    full_attn_strategy[0]\n                    * full_attn_strategy[1]\n                    * full_attn_strategy[2]\n                )\n\n                # Count strategy usage\n                strategy_counts[str(strategy)] += 1\n\n    overall_sparsity = 1 - total_tokens / total_length\n\n    return best_mask_strategy, overall_sparsity, strategy_counts\n\n\ndef save_mask_search_results(\n    mask_search_final_result: list[dict[str, list[float]]],\n    prompt: str,\n    mask_strategies: list[str],\n    output_dir: str = \"output/mask_search_result/\",\n) -> str | None:\n    if not mask_search_final_result:\n        print(\"No mask search results to save\")\n        return None\n\n    # Create result dictionary with defaultdict for nested lists\n    mask_search_dict: dict[str, dict[str, list[list[float]]]] = {\n        \"L2_loss\": defaultdict(list),\n        \"L1_loss\": defaultdict(list),\n    }\n\n    mask_selected = list(range(len(mask_strategies)))\n    selected_masks: list[list[int]] = []\n    for index in mask_selected:\n        mask = mask_strategies[index]\n        masks_list = [int(x) for x in mask.split(\",\")]\n        selected_masks.append(masks_list)\n\n    # Process each mask strategy\n    for i, mask_strategy in enumerate(selected_masks):\n        mask_strategy_str = str(mask_strategy)\n        # 
Process L2 loss\n        step_results: list[list[float]] = []\n        for step_data in mask_search_final_result:\n            if isinstance(step_data, dict) and \"L2_loss\" in step_data:\n                layer_losses = [float(loss) for loss in step_data[\"L2_loss\"]]\n                step_results.append(layer_losses)\n        mask_search_dict[\"L2_loss\"][mask_strategy_str] = step_results\n\n        step_results = []\n        for step_data in mask_search_final_result:\n            if isinstance(step_data, dict) and \"L1_loss\" in step_data:\n                layer_losses = [float(loss) for loss in step_data[\"L1_loss\"]]\n                step_results.append(layer_losses)\n        mask_search_dict[\"L1_loss\"][mask_strategy_str] = step_results\n\n    # Create the output directory if it doesn't exist\n    os.makedirs(output_dir, exist_ok=True)\n\n    # Create a filename based on the first 50 characters of the prompt\n    filename = prompt[:50].replace(\" \", \"_\")\n    filepath = os.path.join(output_dir, f\"mask_search_{filename}.json\")\n\n    # Save the results to a JSON file\n    with open(filepath, \"w\") as f:\n        json.dump(mask_search_dict, f, indent=4)\n\n    print(f\"Successfully saved mask search results to {filepath}\")\n\n    return filepath\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionBackend,\n    AttentionMetadata,\n    AttentionMetadataBuilder,\n)\nfrom sglang.multimodal_gen.runtime.layers.attention.layer import (\n    LocalAttention,\n    UlyssesAttention,\n    UlyssesAttention_VSA,\n    USPAttention,\n)\nfrom sglang.multimodal_gen.runtime.layers.attention.selector import get_attn_backend\nfrom sglang.multimodal_gen.runtime.layers.attention.turbo_layer import MinimalA2AAttnOp\n\n__all__ = [\n    \"USPAttention\",\n    \"LocalAttention\",\n    \"UlyssesAttention\",\n    \"UlyssesAttention_VSA\",\n    \"MinimalA2AAttnOp\",\n    \"AttentionBackend\",\n    \"AttentionMetadata\",\n    \"AttentionMetadataBuilder\",\n    # \"AttentionState\",\n    \"get_attn_backend\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/aiter.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nimport aiter\nimport torch\n\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionBackend,\n    AttentionImpl,\n    AttentionMetadata,\n    AttentionMetadataBuilder,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\n\n\nclass AITerBackend(AttentionBackend):\n    \"\"\"\n    Backend for the AITER (AI Tensor Engine for ROCm) attention implementation.\n    \"\"\"\n\n    @staticmethod\n    def get_enum() -> AttentionBackendEnum:\n        return AttentionBackendEnum.AITER\n\n    @staticmethod\n    def get_impl_cls() -> type[\"AITerImpl\"]:\n        return AITerImpl\n\n    @staticmethod\n    def get_metadata_cls() -> type[\"AttentionMetadata\"]:\n        # AITer backend does not require special metadata.\n        return AttentionMetadata\n\n    @staticmethod\n    def get_builder_cls() -> type[\"AttentionMetadataBuilder\"]:\n        raise NotImplementedError(\"AITer backend does not have a metadata builder.\")\n\n\nclass AITerImpl(AttentionImpl):\n    \"\"\"\n    Implementation of attention using AITER (AI Tensor Engine for ROCm).\n    \"\"\"\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        softmax_scale: float,\n        causal: bool = False,\n        num_kv_heads: int | None = None,\n        prefix: str = \"\",\n        dropout_p: float = 0.0,\n        **extra_impl_args,\n    ) -> None:\n        if num_kv_heads is not None and num_kv_heads != num_heads:\n            raise NotImplementedError(\n                \"AITer backend does not support Grouped Query Attention yet.\"\n            )\n        self.causal = causal\n        self.dropout_p = dropout_p\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_metadata: AttentionMetadata | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n      
  Performs attention using aiter.flash_attn_func.\n\n        Args:\n            query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]\n            key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]\n            value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]\n            attn_metadata: Metadata for the attention operation (unused).\n\n        Returns:\n            Output tensor of shape [batch_size, num_heads, seq_len, head_dim]\n        \"\"\"\n        # aiter.flash_attn_func expects tensors in [B, H, S, D] layout,\n        # which is what ring_attn provides.\n        output, _ = aiter.flash_attn_func(\n            query,\n            key,\n            value,\n            dropout_p=self.dropout_p,\n            causal=self.causal,\n            return_attn_probs=False,\n            return_lse=True,\n        )\n        return output\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/aiter_sage.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionBackend,\n    AttentionImpl,\n    AttentionMetadata,\n    AttentionMetadataBuilder,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\n\n\nclass AITERSageBackend(AttentionBackend):\n\n    @staticmethod\n    def get_enum() -> AttentionBackendEnum:\n        return AttentionBackendEnum.AITER_SAGE\n\n    @staticmethod\n    def get_impl_cls() -> type[\"AITERSageImpl\"]:\n        return AITERSageImpl\n\n    @staticmethod\n    def get_metadata_cls() -> type[\"AttentionMetadata\"]:\n        # AITER Sage backend does not require special metadata.\n        return AttentionMetadata\n\n    @staticmethod\n    def get_builder_cls() -> type[\"AttentionMetadataBuilder\"]:\n        raise NotImplementedError(\n            \"AITER Sage backend does not have a metadata builder.\"\n        )\n\n\nclass AITERSageImpl(AttentionImpl):\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        softmax_scale: float,\n        causal: bool = False,\n        num_kv_heads: int | None = None,\n        prefix: str = \"\",\n        dropout_p: float = 0.0,\n        **extra_impl_args,\n    ) -> None:\n\n        try:\n            from aiter.ops.triton.attention.fav3_sage import fav3_sage_wrapper_func\n\n            self.aiter_sage_attn_fn = fav3_sage_wrapper_func\n        except ImportError:\n            raise ImportError(\n                \"AITER Sage attention is not available, please update AITER version.\"\n            )\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_metadata: AttentionMetadata | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Performs attention using aiter sage backend.\n\n        Args:\n            query: Query tensor of shape [batch_size, 
seq_len, head_num, head_dim]\n            key: Key tensor of shape [batch_size, seq_len, head_num, head_dim]\n            value: Value tensor of shape [batch_size, seq_len, head_num, head_dim]\n            attn_metadata: Metadata for the attention operation (unused).\n\n        Returns:\n            Output tensor of shape [batch_size, seq_len, head_num, head_dim]\n        \"\"\"\n\n        output = self.aiter_sage_attn_fn(query, key, value)\n        return output\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/attention_backend.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/attention/backends/abstract.py\n\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass, fields\nfrom typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar\n\nif TYPE_CHECKING:\n    pass\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\n\n\nclass AttentionBackend(ABC):\n    \"\"\"Abstract class for attention backends.\"\"\"\n\n    # For some attention backends, we allocate an output tensor before\n    # calling the custom op. When piecewise cudagraph is enabled, this\n    # makes sure the output tensor is allocated inside the cudagraph.\n    accept_output_buffer: bool = False\n\n    @staticmethod\n    @abstractmethod\n    def get_enum() -> AttentionBackendEnum:\n        raise NotImplementedError\n\n    @staticmethod\n    @abstractmethod\n    def get_impl_cls() -> type[\"AttentionImpl\"]:\n        raise NotImplementedError\n\n    @staticmethod\n    @abstractmethod\n    def get_metadata_cls() -> type[\"AttentionMetadata\"]:\n        raise NotImplementedError\n\n    # @staticmethod\n    # @abstractmethod\n    # def get_state_cls() -> Type[\"AttentionState\"]:\n    #     raise NotImplementedError\n\n    # @classmethod\n    # def make_metadata(cls, *args, **kwargs) -> \"AttentionMetadata\":\n    #     return cls.get_metadata_cls()(*args, **kwargs)\n\n    @staticmethod\n    @abstractmethod\n    def get_builder_cls() -> type[\"AttentionMetadataBuilder\"]:\n        return None\n\n\n@dataclass\nclass AttentionMetadata:\n    \"\"\"Attention metadata for prefill and decode batched together.\"\"\"\n\n    # Current step of diffusion process\n    current_timestep: int\n\n    def asdict_zerocopy(self, skip_fields: set[str] | None = None) -> dict[str, Any]:\n        \"\"\"Similar to dataclasses.asdict, but avoids 
deepcopying.\"\"\"\n        if skip_fields is None:\n            skip_fields = set()\n        # Note that if we add dataclasses as fields, they will need\n        # similar handling.\n        return {\n            field.name: getattr(self, field.name)\n            for field in fields(self)\n            if field.name not in skip_fields\n        }\n\n\nT = TypeVar(\"T\", bound=AttentionMetadata)\n\n\nclass AttentionMetadataBuilder(ABC, Generic[T]):\n    \"\"\"Abstract class for attention metadata builders.\"\"\"\n\n    @abstractmethod\n    def __init__(self) -> None:\n        \"\"\"Create the builder, remember some configuration and parameters.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def prepare(self) -> None:\n        \"\"\"Prepare for one batch.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def build(\n        self,\n        **kwargs: dict[str, Any],\n    ) -> AttentionMetadata:\n        \"\"\"Build attention metadata with on-device tensors.\"\"\"\n        raise NotImplementedError\n\n\nclass AttentionLayer(Protocol):\n\n    _k_scale: torch.Tensor\n    _v_scale: torch.Tensor\n    _k_scale_float: float\n    _v_scale_float: float\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        kv_cache: torch.Tensor,\n        attn_metadata: AttentionMetadata,\n    ) -> torch.Tensor: ...\n\n\nclass AttentionImpl(ABC, Generic[T]):\n\n    @abstractmethod\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        softmax_scale: float,\n        causal: bool = False,\n        num_kv_heads: int | None = None,\n        prefix: str = \"\",\n        **extra_impl_args,\n    ) -> None:\n        raise NotImplementedError\n\n    def preprocess_qkv(self, qkv: torch.Tensor, attn_metadata: T) -> torch.Tensor:\n        \"\"\"Preprocess QKV tensor before performing attention operation.\n\n        Default implementation returns the 
tensor unchanged.\n        Subclasses can override this to implement custom preprocessing\n        like reshaping, tiling, scaling, or other transformations.\n\n        Called AFTER all_to_all for distributed attention\n\n        \"\"\"\n        return qkv\n\n    def postprocess_output(\n        self,\n        output: torch.Tensor,\n        attn_metadata: T,\n    ) -> torch.Tensor:\n        \"\"\"Postprocess the output tensor after the attention operation.\n\n        Default implementation returns the tensor unchanged.\n        Subclasses can override this to implement custom postprocessing\n        like untiling, scaling, or other transformations.\n\n        Called BEFORE all_to_all for distributed attention\n\n        \"\"\"\n\n        return output\n\n    @abstractmethod\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_metadata: T,\n    ) -> torch.Tensor:\n        raise NotImplementedError\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n# SPDX-License-Identifier: Apache-2.0\nfrom dataclasses import dataclass\nfrom typing import Any, List, Optional, Tuple\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.layers.utils import register_custom_op\nfrom sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context\nfrom sglang.multimodal_gen.runtime.platforms import (\n    AttentionBackendEnum,\n)\n\ntry:\n    from sgl_kernel.flash_attn import flash_attn_varlen_func\n\n    from sglang.jit_kernel.flash_attention_v4 import (\n        flash_attn_varlen_func as flash_attn_varlen_func_fa4,\n    )\n\n    def flash_attn_func(*args, ver: int = 3, **kwargs):\n        if ver == 4:\n            return flash_attn_varlen_func_fa4(*args, **kwargs)\n        return flash_attn_varlen_func(*args, **kwargs)\n\nexcept ImportError as e:\n    raise e\n\n\ndef maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:\n    return x.contiguous() if x is not None and x.stride(-1) != 1 else x\n\n\n# -----------------------------\n# Fake implementations for schema / tracing\n# custom op schema requires FIXED return structure.\n# We provide TWO ops:\n# 1) out-only op: always returns Tensor\n# 2) out+lse op: always returns Tuple[Tensor, Tensor]\n# -----------------------------\ndef flash_attn_varlen_func_fake_out(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k: Optional[torch.Tensor] = None,\n    max_seqlen_q: Optional[int] = None,\n    max_seqlen_k: Optional[int] = None,\n    seqused_q: Optional[torch.Tensor] = None,\n    seqused_k: Optional[torch.Tensor] = None,\n    page_table: Optional[torch.Tensor] = None,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    qv: Optional[torch.Tensor] = None,\n    q_descale: Optional[torch.Tensor] = None,\n    k_descale: Optional[torch.Tensor] = None,\n    v_descale: 
Optional[torch.Tensor] = None,\n    window_size: Optional[List[int]] = None,\n    attention_chunk: int = 0,\n    softcap: float = 0.0,\n    num_splits: int = 1,\n    pack_gqa: Optional[bool] = None,\n    sm_margin: int = 0,\n    return_softmax_lse: bool = False,\n    sinks: Optional[torch.Tensor] = None,\n    ver: int = 4,\n) -> torch.Tensor:\n    assert ver == 4, \"only support flash attention v4\"\n    q, k, v = [maybe_contiguous(t) for t in (q, k, v)]\n    num_head, head_dim = q.shape[-2:]\n    if cu_seqlens_q is None:\n        batch_size, seqlen_q = q.shape[:2]\n    else:\n        batch_size = cu_seqlens_q.shape[0] - 1\n        seqlen_q = None\n    head_dim_v = v.shape[-1]\n\n    if cu_seqlens_q is not None:\n        assert cu_seqlens_q.shape == (\n            batch_size + 1,\n        ), \"cu_seqlens_q must have shape (batch_size + 1,)\"\n        assert cu_seqlens_q.dtype == torch.int32, \"cu_seqlens_q must be int32\"\n        assert cu_seqlens_q.stride(0) == 1, \"cu_seqlens_q must be contiguous\"\n\n    assert q.dtype in [\n        torch.float16,\n        torch.bfloat16,\n    ], \"inputs must be float16 or bfloat16\"\n    assert q.dtype == k.dtype == v.dtype, \"inputs must have the same dtype\"\n    assert head_dim <= 256, \"head_dim must be less than or equal to 256\"\n    alignment = 16 // q.element_size()\n    assert head_dim_v % alignment == 0, f\"head_dim_v must be divisible by {alignment}\"\n\n    q_batch_seqlen_shape = (\n        (batch_size, seqlen_q) if cu_seqlens_q is None else (q.shape[0],)\n    )\n    out = q.new_empty(*q_batch_seqlen_shape, num_head, head_dim_v)\n    return out\n\n\ndef flash_attn_varlen_func_fake_out_lse(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k: Optional[torch.Tensor] = None,\n    max_seqlen_q: Optional[int] = None,\n    max_seqlen_k: Optional[int] = None,\n    seqused_q: Optional[torch.Tensor] = None,\n    seqused_k: 
Optional[torch.Tensor] = None,\n    page_table: Optional[torch.Tensor] = None,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    qv: Optional[torch.Tensor] = None,\n    q_descale: Optional[torch.Tensor] = None,\n    k_descale: Optional[torch.Tensor] = None,\n    v_descale: Optional[torch.Tensor] = None,\n    window_size: Optional[List[int]] = None,\n    attention_chunk: int = 0,\n    softcap: float = 0.0,\n    num_splits: int = 1,\n    pack_gqa: Optional[bool] = None,\n    sm_margin: int = 0,\n    return_softmax_lse: bool = True,\n    sinks: Optional[torch.Tensor] = None,\n    ver: int = 4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    assert ver == 4, \"only support flash attention v4\"\n    q, k, v = [maybe_contiguous(t) for t in (q, k, v)]\n    num_head, head_dim = q.shape[-2:]\n    if cu_seqlens_q is None:\n        batch_size, seqlen_q = q.shape[:2]\n        total_q = batch_size * seqlen_q\n    else:\n        batch_size = cu_seqlens_q.shape[0] - 1\n        seqlen_q = None\n        total_q = q.shape[0]\n    head_dim_v = v.shape[-1]\n\n    if cu_seqlens_q is not None:\n        assert cu_seqlens_q.shape == (\n            batch_size + 1,\n        ), \"cu_seqlens_q must have shape (batch_size + 1,)\"\n        assert cu_seqlens_q.dtype == torch.int32, \"cu_seqlens_q must be int32\"\n        assert cu_seqlens_q.stride(0) == 1, \"cu_seqlens_q must be contiguous\"\n\n    assert q.dtype in [\n        torch.float16,\n        torch.bfloat16,\n    ], \"inputs must be float16 or bfloat16\"\n    assert q.dtype == k.dtype == v.dtype, \"inputs must have the same dtype\"\n    assert head_dim <= 256, \"head_dim must be less than or equal to 256\"\n    alignment = 16 // q.element_size()\n    assert head_dim_v % alignment == 0, f\"head_dim_v must be divisible by {alignment}\"\n\n    q_batch_seqlen_shape = (\n        (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,)\n    )\n    lse_shape = (\n        (batch_size, num_head, seqlen_q)\n        if 
cu_seqlens_q is None\n        else (num_head, total_q)\n    )\n\n    out = q.new_empty(*q_batch_seqlen_shape, num_head, head_dim_v)\n    lse = q.new_empty(lse_shape, dtype=torch.float32)\n    return out, lse\n\n\n# -----------------------------\n# Registered custom ops\n# NOTE: fixed return schemas to avoid:\n# \"Object of type 'Tensor' is not an instance of 'sequence'\"\n# -----------------------------\n@register_custom_op(fake_impl=flash_attn_varlen_func_fake_out)\ndef flash_attn_varlen_func_op(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k: Optional[torch.Tensor] = None,\n    max_seqlen_q: Optional[int] = None,\n    max_seqlen_k: Optional[int] = None,\n    seqused_q: Optional[torch.Tensor] = None,\n    seqused_k: Optional[torch.Tensor] = None,\n    page_table: Optional[torch.Tensor] = None,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    qv: Optional[torch.Tensor] = None,\n    q_descale: Optional[torch.Tensor] = None,\n    k_descale: Optional[torch.Tensor] = None,\n    v_descale: Optional[torch.Tensor] = None,\n    window_size: Optional[List[int]] = None,\n    attention_chunk: int = 0,\n    softcap: float = 0.0,\n    num_splits: int = 1,\n    pack_gqa: Optional[bool] = None,\n    sm_margin: int = 0,\n    return_softmax_lse: bool = False,\n    sinks: Optional[torch.Tensor] = None,\n    ver: int = 4,\n) -> torch.Tensor:\n    if window_size is None:\n        window_size = [-1, -1]\n    if return_softmax_lse:\n        raise ValueError(\n            \"flash_attn_varlen_func_op is out-only op; return_softmax_lse must be False. 
\"\n            \"Use flash_attn_varlen_func_op_lse for (out, lse).\"\n        )\n    return flash_attn_func(\n        q,\n        k,\n        v,\n        cu_seqlens_q=cu_seqlens_q,\n        cu_seqlens_k=cu_seqlens_k,\n        max_seqlen_q=max_seqlen_q,\n        max_seqlen_k=max_seqlen_k,\n        seqused_q=seqused_q,\n        seqused_k=seqused_k,\n        page_table=page_table,\n        softmax_scale=softmax_scale,\n        causal=causal,\n        qv=qv,\n        q_descale=q_descale,\n        k_descale=k_descale,\n        v_descale=v_descale,\n        window_size=tuple(window_size),\n        attention_chunk=attention_chunk,\n        softcap=softcap,\n        num_splits=num_splits,\n        pack_gqa=pack_gqa,\n        sm_margin=sm_margin,\n        return_softmax_lse=False,\n        sinks=sinks,\n        ver=ver,\n    )\n\n\n@register_custom_op(fake_impl=flash_attn_varlen_func_fake_out_lse)\ndef flash_attn_varlen_func_op_lse(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k: Optional[torch.Tensor] = None,\n    max_seqlen_q: Optional[int] = None,\n    max_seqlen_k: Optional[int] = None,\n    seqused_q: Optional[torch.Tensor] = None,\n    seqused_k: Optional[torch.Tensor] = None,\n    page_table: Optional[torch.Tensor] = None,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    qv: Optional[torch.Tensor] = None,\n    q_descale: Optional[torch.Tensor] = None,\n    k_descale: Optional[torch.Tensor] = None,\n    v_descale: Optional[torch.Tensor] = None,\n    window_size: Optional[List[int]] = None,\n    attention_chunk: int = 0,\n    softcap: float = 0.0,\n    num_splits: int = 1,\n    pack_gqa: Optional[bool] = None,\n    sm_margin: int = 0,\n    return_softmax_lse: bool = True,\n    sinks: Optional[torch.Tensor] = None,\n    ver: int = 4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    if window_size is None:\n        window_size = [-1, -1]\n    if not 
return_softmax_lse:\n        raise ValueError(\n            \"flash_attn_varlen_func_op_lse is out+lse op; return_softmax_lse must be True. \"\n            \"Use flash_attn_varlen_func_op for out-only.\"\n        )\n    return flash_attn_func(\n        q,\n        k,\n        v,\n        cu_seqlens_q=cu_seqlens_q,\n        cu_seqlens_k=cu_seqlens_k,\n        max_seqlen_q=max_seqlen_q,\n        max_seqlen_k=max_seqlen_k,\n        seqused_q=seqused_q,\n        seqused_k=seqused_k,\n        page_table=page_table,\n        softmax_scale=softmax_scale,\n        causal=causal,\n        qv=qv,\n        q_descale=q_descale,\n        k_descale=k_descale,\n        v_descale=v_descale,\n        window_size=tuple(window_size),\n        attention_chunk=attention_chunk,\n        softcap=softcap,\n        num_splits=num_splits,\n        pack_gqa=pack_gqa,\n        sm_margin=sm_margin,\n        return_softmax_lse=True,\n        sinks=sinks,\n        ver=ver,\n    )\n\n\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionBackend,\n    AttentionImpl,\n    AttentionMetadata,\n    AttentionMetadataBuilder,\n)\n\nfa_ver = 3\n\n\ndef set_fa_ver(ver: int) -> None:\n    global fa_ver\n    fa_ver = ver\n\n\n@dataclass\nclass FlashAttentionMetadata:\n    # Sequence lengths for the forward batch\n    # Maximum sequence length for query\n    max_seqlen_q: int = 1\n    # Maximum sequence length for key\n    max_seqlen_k: int = 0\n    # Cumulative sequence lengths for query\n    cu_seqlens_q: torch.Tensor = None\n    # Cumulative sequence lengths for key\n    cu_seqlens_k: torch.Tensor = None\n\n\nclass FlashAttentionMetadataBuilder(AttentionMetadataBuilder):\n    def __init__(self) -> None:\n        pass\n\n    def prepare(self) -> None:\n        pass\n\n    def build(  # type: ignore\n        self,\n        raw_latent_shape=list,\n        **kwargs: dict[str, Any],\n    ) -> FlashAttentionMetadata:\n        # TODO: put empty values here to be 
set at first-run, since the q_len calculation can be complicated\n        return FlashAttentionMetadata(max_seqlen_q=None, max_seqlen_k=None)\n\n\nclass FlashAttentionBackend(AttentionBackend):\n    accept_output_buffer: bool = True\n\n    @staticmethod\n    def get_supported_head_sizes() -> list[int]:\n        return [32, 64, 96, 128, 160, 192, 224, 256]\n\n    @staticmethod\n    def get_enum() -> AttentionBackendEnum:\n        return AttentionBackendEnum.FA\n\n    @staticmethod\n    def get_impl_cls() -> type[\"FlashAttentionImpl\"]:\n        return FlashAttentionImpl\n\n    @staticmethod\n    def get_metadata_cls() -> type[\"AttentionMetadata\"]:\n        raise NotImplementedError\n\n    @staticmethod\n    def get_builder_cls() -> type[\"AttentionMetadataBuilder\"]:\n        return FlashAttentionMetadataBuilder\n\n\nclass FlashAttentionImpl(AttentionImpl):\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        causal: bool,\n        softmax_scale: float,\n        num_kv_heads: int | None = None,\n        prefix: str = \"\",\n        **extra_impl_args,\n    ) -> None:\n        self.num_heads = num_heads\n        self.num_kv_heads = num_kv_heads\n        self.head_size = head_size\n        self.causal = causal\n        self.softmax_scale = softmax_scale\n        self.attention_metadata = FlashAttentionMetadata()\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_metadata: AttentionMetadata = None,\n        *,\n        return_softmax_lse: bool = False,\n    ):\n        attn_metadata: FlashAttentionMetadata = get_forward_context().attn_metadata\n        if attn_metadata is not None and attn_metadata.max_seqlen_q is None:\n            attn_metadata.max_seqlen_q = query.shape[1]\n            attn_metadata.max_seqlen_k = key.shape[1]\n            max_seqlen_q = attn_metadata.max_seqlen_q\n            max_seqlen_k = attn_metadata.max_seqlen_k\n     
   else:\n            max_seqlen_q = query.shape[1]\n            max_seqlen_k = key.shape[1]\n\n        # FA version selection:\n        # - fa_ver == 3: call python function (can return Tensor or (Tensor, Tensor) depending on flag)\n        # - fa_ver == 4: call custom ops with FIXED return schema\n        if fa_ver == 3:\n            flash_attn_op = flash_attn_func\n            output = flash_attn_op(\n                q=query,\n                k=key,\n                v=value,\n                cu_seqlens_q=None,\n                cu_seqlens_k=None,\n                max_seqlen_q=max_seqlen_q,\n                max_seqlen_k=max_seqlen_k,\n                softmax_scale=self.softmax_scale,\n                causal=self.causal,\n                return_softmax_lse=return_softmax_lse,\n                ver=fa_ver,\n            )\n            return output\n\n        if fa_ver == 4:\n            if return_softmax_lse:\n                out_tensor, softmax_lse = flash_attn_varlen_func_op_lse(\n                    q=query,\n                    k=key,\n                    v=value,\n                    cu_seqlens_q=None,\n                    cu_seqlens_k=None,\n                    max_seqlen_q=max_seqlen_q,\n                    max_seqlen_k=max_seqlen_k,\n                    softmax_scale=self.softmax_scale,\n                    causal=self.causal,\n                    return_softmax_lse=True,\n                    ver=fa_ver,\n                )\n                return out_tensor, softmax_lse\n            out_tensor = flash_attn_varlen_func_op(\n                q=query,\n                k=key,\n                v=value,\n                cu_seqlens_q=None,\n                cu_seqlens_k=None,\n                max_seqlen_q=max_seqlen_q,\n                max_seqlen_k=max_seqlen_k,\n                softmax_scale=self.softmax_scale,\n                causal=self.causal,\n                return_softmax_lse=False,\n                ver=fa_ver,\n            )\n            return out_tensor\n\n 
       raise ValueError(f\"flash attention version {fa_ver} is not supported.\")\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn_2.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionBackend,\n    AttentionImpl,\n    AttentionMetadata,\n    AttentionMetadataBuilder,\n)\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import (\n    flash_attn_func,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass FlashAttention2Backend(AttentionBackend):\n    accept_output_buffer: bool = True\n\n    @staticmethod\n    def get_supported_head_sizes() -> list[int]:\n        return [32, 64, 96, 128, 160, 192, 224, 256]\n\n    @staticmethod\n    def get_enum() -> AttentionBackendEnum:\n        return AttentionBackendEnum.FA2\n\n    @staticmethod\n    def get_impl_cls() -> type[\"FlashAttention2Impl\"]:\n        return FlashAttention2Impl\n\n    @staticmethod\n    def get_metadata_cls() -> type[\"AttentionMetadata\"]:\n        raise NotImplementedError\n\n    @staticmethod\n    def get_builder_cls() -> type[\"AttentionMetadataBuilder\"]:\n        raise NotImplementedError\n\n\nclass FlashAttention2Impl(AttentionImpl):\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        causal: bool,\n        softmax_scale: float,\n        num_kv_heads: int | None = None,\n        prefix: str = \"\",\n        **extra_impl_args,\n    ) -> None:\n        self.causal = causal\n        self.softmax_scale = softmax_scale\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_metadata: AttentionMetadata,\n    ):\n        output = flash_attn_func(\n            q=query,  # type: ignore[no-untyped-call]\n            k=key,\n            v=value,\n            
cu_seqlens_q=None,\n            cu_seqlens_k=None,\n            max_seqlen_q=None,\n            max_seqlen_k=None,\n            softmax_scale=self.softmax_scale,\n            causal=self.causal,\n        )\n        return output\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\n\nimport torch\nfrom sageattention import sageattn\n\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (  # FlashAttentionMetadata,\n    AttentionBackend,\n    AttentionImpl,\n    AttentionMetadata,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass SageAttentionBackend(AttentionBackend):\n    accept_output_buffer: bool = True\n\n    @staticmethod\n    def get_supported_head_sizes() -> list[int]:\n        return [32, 64, 96, 128, 160, 192, 224, 256]\n\n    @staticmethod\n    def get_enum() -> AttentionBackendEnum:\n        return AttentionBackendEnum.SAGE_ATTN\n\n    @staticmethod\n    def get_impl_cls() -> type[\"SageAttentionImpl\"]:\n        return SageAttentionImpl\n\n\nclass SageAttentionImpl(AttentionImpl):\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        causal: bool,\n        softmax_scale: float,\n        num_kv_heads: int | None = None,\n        prefix: str = \"\",\n        **extra_impl_args,\n    ) -> None:\n        self.causal = causal\n        self.softmax_scale = softmax_scale\n        self.dropout = extra_impl_args.get(\"dropout_p\", 0.0)\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_metadata: AttentionMetadata,\n        *,\n        return_softmax_lse: bool = False,\n    ) -> torch.Tensor:\n        output = sageattn(\n            query,\n            key,\n            value,\n            # since input is (batch_size, seq_len, head_num, head_dim)\n            tensor_layout=\"NHD\",\n            is_causal=self.causal,\n            sm_scale=self.softmax_scale,\n            return_lse=return_softmax_lse,\n        )\n 
       if return_softmax_lse:\n            output, softmax_lse = output\n            return output, softmax_lse\n        return output\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn3.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\nimport torch.nn.functional as F\nfrom sageattn3 import sageattn3_blackwell\n\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionBackend,\n    AttentionImpl,\n    AttentionMetadata,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass SageAttention3Backend(AttentionBackend):\n    accept_output_buffer: bool = True\n\n    @staticmethod\n    def get_supported_head_sizes() -> list[int]:\n        return [64, 128, 256]\n\n    @staticmethod\n    def get_enum() -> AttentionBackendEnum:\n        return AttentionBackendEnum.SAGE_ATTN_3\n\n    @staticmethod\n    def get_impl_cls() -> type[\"SageAttention3Impl\"]:\n        return SageAttention3Impl\n\n    @staticmethod\n    def get_metadata_cls() -> type[\"AttentionMetadata\"]:\n        raise NotImplementedError\n\n\nclass SageAttention3Impl(AttentionImpl):\n    _warned_gqa_fallback_global: bool = False\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        causal: bool,\n        softmax_scale: float,\n        num_kv_heads: int | None = None,\n        prefix: str = \"\",\n        **extra_impl_args,\n    ) -> None:\n        self.causal = causal\n        self.softmax_scale = softmax_scale\n        self.dropout = extra_impl_args.get(\"dropout_p\", 0.0)\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_metadata: AttentionMetadata,\n    ) -> torch.Tensor:\n        query = query.transpose(1, 2)\n        key = key.transpose(1, 2)\n        value = value.transpose(1, 2)\n        # SageAttention3's Blackwell kernel assumes MHA (Hq == Hkv). 
For GQA/MQA\n        # (Hq != Hkv), fall back to torch SDPA which supports GQA.\n        if key.shape[1] != query.shape[1]:\n            if query.shape[1] % key.shape[1] != 0:\n                raise ValueError(\n                    \"GQA/MQA requires query heads to be a multiple of KV heads, \"\n                    f\"got q_heads={query.shape[1]} and kv_heads={key.shape[1]}\"\n                )\n            if not type(self)._warned_gqa_fallback_global:\n                logger.warning(\n                    \"SageAttention3 does not support GQA/MQA (Hq != Hkv); falling back to torch SDPA.\"\n                )\n                type(self)._warned_gqa_fallback_global = True\n            output = F.scaled_dot_product_attention(\n                query,\n                key,\n                value,\n                is_causal=self.causal,\n                dropout_p=self.dropout,\n                scale=self.softmax_scale,\n                enable_gqa=True,\n            )\n        else:\n            output = sageattn3_blackwell(query, key, value, is_causal=self.causal)\n        output = output.transpose(1, 2)\n        return output\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/sdpa.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (  # FlashAttentionMetadata,\n    AttentionBackend,\n    AttentionImpl,\n    AttentionMetadata,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass SDPABackend(AttentionBackend):\n\n    accept_output_buffer: bool = True\n\n    @staticmethod\n    def get_supported_head_sizes() -> list[int]:\n        return [32, 64, 96, 128, 160, 192, 224, 256]\n\n    @staticmethod\n    def get_enum() -> AttentionBackendEnum:\n        return AttentionBackendEnum.TORCH_SDPA\n\n    @staticmethod\n    def get_impl_cls() -> type[\"SDPAImpl\"]:\n        return SDPAImpl\n\n    # @staticmethod\n    # def get_metadata_cls() -> Type[\"AttentionMetadata\"]:\n    #     return FlashAttentionMetadata\n\n\nclass SDPAImpl(AttentionImpl):\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        causal: bool,\n        softmax_scale: float,\n        num_kv_heads: int | None = None,\n        prefix: str = \"\",\n        **extra_impl_args,\n    ) -> None:\n        self.causal = causal\n        self.softmax_scale = softmax_scale\n        self.dropout = extra_impl_args.get(\"dropout_p\", 0.0)\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_metadata: AttentionMetadata,\n    ) -> torch.Tensor:\n        # transpose to bs, heads, seq_len, head_dim\n        query = query.transpose(1, 2)\n        key = key.transpose(1, 2)\n        value = value.transpose(1, 2)\n        attn_kwargs = {\n            \"attn_mask\": None,\n            \"dropout_p\": self.dropout,\n            \"is_causal\": self.causal,\n            
\"scale\": self.softmax_scale,\n        }\n        if query.shape[1] != key.shape[1]:\n            attn_kwargs[\"enable_gqa\"] = True\n        output = torch.nn.functional.scaled_dot_product_attention(\n            query, key, value, **attn_kwargs\n        )\n        output = output.transpose(1, 2)\n        return output\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/sliding_tile_attn.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nimport json\nfrom dataclasses import dataclass\nfrom typing import Any\n\nimport torch\nfrom einops import rearrange\n\nfrom sglang.multimodal_gen.runtime.distributed import get_sp_group\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionBackend,\n    AttentionImpl,\n    AttentionMetadata,\n    AttentionMetadataBuilder,\n)\nfrom sglang.multimodal_gen.runtime.managers.forward_context import (\n    ForwardContext,\n    get_forward_context,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.server_args import get_global_server_args\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import dict_to_3d_list\n\ntry:\n    from st_attn import sliding_tile_attention\n\n    st_attn_backend_available = True\nexcept Exception:\n    st_attn_backend_available = False\n\nlogger = init_logger(__name__)\n\n\nclass RangeDict(dict):\n\n    def __getitem__(self, item: int) -> str:\n        for key in self.keys():\n            if isinstance(key, tuple):\n                low, high = key\n                if low <= item <= high:\n                    return str(super().__getitem__(key))\n            elif key == item:\n                return str(super().__getitem__(key))\n        raise KeyError(f\"seq_len {item} not supported for STA\")\n\n\nclass SlidingTileAttentionBackend(AttentionBackend):\n    accept_output_buffer: bool = True\n\n    @staticmethod\n    def get_supported_head_sizes() -> list[int]:\n        # TODO(will-refactor): check this\n        return [32, 64, 96, 128, 160, 192, 224, 256]\n\n    @staticmethod\n    def get_enum() -> AttentionBackendEnum:\n        return AttentionBackendEnum.SLIDING_TILE_ATTN\n\n    @staticmethod\n    def get_impl_cls() -> type[\"SlidingTileAttentionImpl\"]:\n  
      return SlidingTileAttentionImpl\n\n    @staticmethod\n    def get_metadata_cls() -> type[\"SlidingTileAttentionMetadata\"]:\n        return SlidingTileAttentionMetadata\n\n    @staticmethod\n    def get_builder_cls() -> type[\"SlidingTileAttentionMetadataBuilder\"]:\n        return SlidingTileAttentionMetadataBuilder\n\n\n@dataclass\nclass SlidingTileAttentionMetadata(AttentionMetadata):\n    current_timestep: int\n    STA_param: list[\n        list[Any]\n    ]  # each timestep with one metadata, shape [num_layers, num_heads]\n\n\nclass SlidingTileAttentionMetadataBuilder(AttentionMetadataBuilder):\n\n    def __init__(self):\n        pass\n\n    def prepare(self):\n        pass\n\n    def build(  # type: ignore\n        self,\n        STA_param: list[list[Any]],\n        current_timestep: int,\n        **kwargs: dict[str, Any],\n    ) -> SlidingTileAttentionMetadata:\n        param = STA_param\n        if param is None:\n            return SlidingTileAttentionMetadata(\n                current_timestep=current_timestep, STA_param=[]\n            )\n        return SlidingTileAttentionMetadata(\n            current_timestep=current_timestep, STA_param=param[current_timestep]\n        )\n\n\nclass SlidingTileAttentionImpl(AttentionImpl):\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        causal: bool,\n        softmax_scale: float,\n        num_kv_heads: int | None = None,\n        prefix: str = \"\",\n        **extra_impl_args,\n    ) -> None:\n        if not st_attn_backend_available:\n            raise ValueError(\"st attn not supported\")\n        # TODO(will-refactor): for now this is the mask strategy, but maybe we should\n        # have a more general config for STA?\n        mask_strategy_file_path = (\n            get_global_server_args().attention_backend_config.mask_strategy_file_path\n        )\n        if mask_strategy_file_path is None:\n            raise ValueError(\"SGLANG_DIFFUSION_ATTENTION_CONFIG is 
not set\")\n\n        # TODO(kevin): get mask strategy for different STA modes\n        with open(mask_strategy_file_path) as f:\n            mask_strategy = json.load(f)\n        self.mask_strategy = dict_to_3d_list(mask_strategy)\n\n        self.prefix = prefix\n        sp_group = get_sp_group()\n        self.sp_size = sp_group.world_size\n        # STA config\n        self.STA_base_tile_size = [6, 8, 8]\n        self.dit_seq_shape_mapping = RangeDict(\n            {\n                (115200, 115456): \"30x48x80\",\n                82944: \"36x48x48\",\n                69120: \"18x48x80\",\n            }\n        )\n        self.full_window_mapping = {\n            \"30x48x80\": [5, 6, 10],\n            \"36x48x48\": [6, 6, 6],\n            \"18x48x80\": [3, 6, 10],\n        }\n\n    def tile(self, x: torch.Tensor) -> torch.Tensor:\n        return rearrange(\n            x,\n            \"b (n_t ts_t n_h ts_h n_w ts_w) h d -> b (n_t n_h n_w ts_t ts_h ts_w) h d\",\n            n_t=self.full_window_size[0],\n            n_h=self.full_window_size[1],\n            n_w=self.full_window_size[2],\n            ts_t=self.STA_base_tile_size[0],\n            ts_h=self.STA_base_tile_size[1],\n            ts_w=self.STA_base_tile_size[2],\n        )\n\n    def untile(self, x: torch.Tensor) -> torch.Tensor:\n        x = rearrange(\n            x,\n            \"b (n_t n_h n_w ts_t ts_h ts_w) h d -> b (n_t ts_t n_h ts_h n_w ts_w) h d\",\n            n_t=self.full_window_size[0],\n            n_h=self.full_window_size[1],\n            n_w=self.full_window_size[2],\n            ts_t=self.STA_base_tile_size[0],\n            ts_h=self.STA_base_tile_size[1],\n            ts_w=self.STA_base_tile_size[2],\n        )\n        return x\n\n    def preprocess_qkv(\n        self,\n        qkv: torch.Tensor,\n        attn_metadata: AttentionMetadata,\n    ) -> torch.Tensor:\n        img_sequence_length = qkv.shape[1]\n        self.dit_seq_shape_str = 
self.dit_seq_shape_mapping[img_sequence_length]\n        self.full_window_size = self.full_window_mapping[self.dit_seq_shape_str]\n        self.dit_seq_shape_int = list(map(int, self.dit_seq_shape_str.split(\"x\")))\n        self.img_seq_length = (\n            self.dit_seq_shape_int[0]\n            * self.dit_seq_shape_int[1]\n            * self.dit_seq_shape_int[2]\n        )\n        return self.tile(qkv)\n\n    def postprocess_output(\n        self,\n        output: torch.Tensor,\n        attn_metadata: SlidingTileAttentionMetadata,\n    ) -> torch.Tensor:\n        return self.untile(output)\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        attn_metadata: SlidingTileAttentionMetadata,\n    ) -> torch.Tensor:\n        if self.mask_strategy is None:\n            raise ValueError(\"mask_strategy cannot be None for SlidingTileAttention\")\n        if self.mask_strategy[0] is None:\n            raise ValueError(\"mask_strategy[0] cannot be None for SlidingTileAttention\")\n\n        timestep = attn_metadata.current_timestep\n        forward_context: ForwardContext = get_forward_context()\n        forward_batch = forward_context.forward_batch\n        if forward_batch is None:\n            raise ValueError(\"forward_batch cannot be None\")\n        # pattern:'.double_blocks.0.attn.impl' or '.single_blocks.0.attn.impl'\n        layer_idx = int(self.prefix.split(\".\")[-3])\n        if attn_metadata.STA_param is None or len(attn_metadata.STA_param) <= layer_idx:\n            raise ValueError(\"Invalid STA_param\")\n        STA_param = attn_metadata.STA_param[layer_idx]\n\n        text_length = q.shape[1] - self.img_seq_length\n        has_text = text_length > 0\n\n        query = q.transpose(1, 2).contiguous()\n        key = k.transpose(1, 2).contiguous()\n        value = v.transpose(1, 2).contiguous()\n\n        head_num = query.size(1)\n        sp_group = get_sp_group()\n        current_rank = 
sp_group.rank_in_group\n        start_head = current_rank * head_num\n\n        # searching or tuning mode\n        if len(STA_param) < head_num * sp_group.world_size:\n            sparse_attn_hidden_states_all = []\n            full_mask_window = STA_param[-1]\n            for window_size in STA_param[:-1]:\n                sparse_hidden_states = sliding_tile_attention(\n                    query,\n                    key,\n                    value,\n                    [window_size] * head_num,\n                    text_length,\n                    has_text,\n                    self.dit_seq_shape_str,\n                ).transpose(1, 2)\n                sparse_attn_hidden_states_all.append(sparse_hidden_states)\n\n            hidden_states = sliding_tile_attention(\n                query,\n                key,\n                value,\n                [full_mask_window] * head_num,\n                text_length,\n                has_text,\n                self.dit_seq_shape_str,\n            ).transpose(1, 2)\n\n            attn_L2_loss = []\n            attn_L1_loss = []\n            # average loss across all heads\n            for sparse_attn_hidden_states in sparse_attn_hidden_states_all:\n                # L2 loss\n                attn_L2_loss_ = (\n                    torch.mean(\n                        (sparse_attn_hidden_states.float() - hidden_states.float())\n                        ** 2,\n                        dim=[0, 1, 3],\n                    )\n                    .cpu()\n                    .numpy()\n                )\n                attn_L2_loss_ = [round(float(x), 6) for x in attn_L2_loss_]\n                attn_L2_loss.append(attn_L2_loss_)\n                # L1 loss\n                attn_L1_loss_ = (\n                    torch.mean(\n                        torch.abs(\n                            sparse_attn_hidden_states.float() - hidden_states.float()\n                        ),\n                        dim=[0, 1, 3],\n                    
)\n                    .cpu()\n                    .numpy()\n                )\n                attn_L1_loss_ = [round(float(x), 6) for x in attn_L1_loss_]\n                attn_L1_loss.append(attn_L1_loss_)\n\n            layer_loss_save = {\"L2_loss\": attn_L2_loss, \"L1_loss\": attn_L1_loss}\n\n            if forward_batch.is_cfg_negative:\n                if forward_batch.mask_search_final_result_neg is not None:\n                    forward_batch.mask_search_final_result_neg[timestep].append(\n                        layer_loss_save\n                    )\n            else:\n                if forward_batch.mask_search_final_result_pos is not None:\n                    forward_batch.mask_search_final_result_pos[timestep].append(\n                        layer_loss_save\n                    )\n        else:\n            windows = [STA_param[head_idx + start_head] for head_idx in range(head_num)]\n\n            hidden_states = sliding_tile_attention(\n                query,\n                key,\n                value,\n                windows,\n                text_length,\n                has_text,\n                self.dit_seq_shape_str,\n            ).transpose(1, 2)\n\n        return hidden_states\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/sparse_linear_attn.py",
    "content": "\"\"\"\nCopyright (c) 2025 by SLA team.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\n\nThis implementation is adapted from https://github.com/thu-ml/TurboDiffusion/blob/main/turbodiffusion/SLA/core.py and https://github.com/thu-ml/SLA/blob/main/SageSLA/core.py\nCitation (please cite if you use this code):\n\n@article{zhang2025sla,\n  title={SLA: Beyond Sparsity in Diffusion Transformers via Fine-Tunable Sparse-Linear Attention},\n  author={Jintao Zhang and Haoxu Wang and Kai Jiang and Shuo Yang and Kaiwen Zheng and Haocheng Xi and Ziteng Wang and Hongzhou Zhu and Min Zhao and Ion Stoica and Joseph E. Gonzalez and Jun Zhu and Jianfei Chen},\n  journal={arXiv preprint arXiv:2509.24006},\n  year={2025}\n}\n\"\"\"\n\nfrom collections.abc import Callable\nfrom dataclasses import dataclass\nfrom typing import Any\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport triton\nimport triton.language as tl\n\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionBackend,\n    AttentionImpl,\n    AttentionMetadata,\n    AttentionMetadataBuilder,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n# ==================================SLA Functions===================================\ndef get_block_map(q, k, topk_ratio, BLKQ=64, BLKK=64):\n    arg_k = k - torch.mean(\n        k, dim=-2, keepdim=True\n    )  # smooth-k technique in SageAttention\n    pooled_qblocks = mean_pool(q, BLKQ)\n    pooled_kblocks = mean_pool(arg_k, BLKK)\n    pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2)\n\n    K = pooled_score.shape[-1]\n    topk = min(K, int(topk_ratio * K))\n    lut = torch.topk(pooled_score, topk, dim=-1, sorted=False).indices\n\n    sparse_map = torch.zeros_like(pooled_score, dtype=torch.int8)\n    
sparse_map.scatter_(-1, lut, 1)\n    return sparse_map, lut, topk\n\n\ndef mean_pool(x, BLK):\n    assert x.is_contiguous()\n\n    B, H, L, D = x.shape\n    L_BLOCKS = (L + BLK - 1) // BLK\n    x_mean = torch.empty((B, H, L_BLOCKS, D), device=x.device, dtype=x.dtype)\n\n    grid = (L_BLOCKS, B * H)\n    compress_kernel[grid](x, x_mean, L, D, BLK)\n    return x_mean\n\n\n@triton.jit\ndef compress_kernel(\n    X,\n    XM,\n    L: tl.constexpr,\n    D: tl.constexpr,\n    BLOCK_L: tl.constexpr,\n):\n    idx_l = tl.program_id(0)\n    idx_bh = tl.program_id(1)\n\n    offs_l = idx_l * BLOCK_L + tl.arange(0, BLOCK_L)\n    offs_d = tl.arange(0, D)\n\n    x_offset = idx_bh * L * D\n    xm_offset = idx_bh * ((L + BLOCK_L - 1) // BLOCK_L) * D\n    x = tl.load(\n        X + x_offset + offs_l[:, None] * D + offs_d[None, :], mask=offs_l[:, None] < L\n    )\n\n    nx = min(BLOCK_L, L - idx_l * BLOCK_L)\n    x_mean = tl.sum(x, axis=0, dtype=tl.float32) / nx\n    tl.store(XM + xm_offset + idx_l * D + offs_d, x_mean.to(XM.dtype.element_ty))\n\n\n@triton.jit\ndef _attn_fwd(\n    Q,\n    K,\n    V,\n    qk_scale: tl.constexpr,\n    topk: tl.constexpr,\n    LUT,\n    LSE,\n    OS,\n    L: tl.constexpr,\n    M_BLOCKS: tl.constexpr,\n    D: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    idx_m = tl.program_id(0).to(tl.int64)\n    idx_bh = tl.program_id(1).to(tl.int64)\n\n    qkv_offset = idx_bh * L * D\n    lut_offset = (idx_bh * M_BLOCKS + idx_m) * topk\n    lse_offset = idx_bh * L\n    offs_m = idx_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_n = tl.arange(0, BLOCK_N)\n    offs_d = tl.arange(0, D)\n\n    Q_ptrs = Q + qkv_offset + offs_m[:, None] * D + offs_d[None, :]\n    K_ptrs = K + qkv_offset + offs_n[None, :] * D + offs_d[:, None]\n    V_ptrs = V + qkv_offset + offs_n[:, None] * D + offs_d[None, :]\n    OS_ptrs = OS + qkv_offset + offs_m[:, None] * D + offs_d[None, :]\n    LUT_ptr = LUT + lut_offset\n    LSE_ptrs = LSE + lse_offset + offs_m\n\n    m_i 
= tl.full([BLOCK_M], -float(\"inf\"), dtype=tl.float32)\n    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n    o_s = tl.zeros([BLOCK_M, D], dtype=tl.float32)\n\n    q = tl.load(Q_ptrs, mask=offs_m[:, None] < L)\n    for block_idx in tl.range(topk):\n        idx_n = tl.load(LUT_ptr + block_idx)\n        n_mask = offs_n < L - idx_n * BLOCK_N\n\n        k = tl.load(K_ptrs + idx_n * BLOCK_N * D, mask=n_mask[None, :])\n        qk = tl.dot(q, k) * (qk_scale * 1.4426950408889634)  # = 1 / ln(2)\n        if L - idx_n * BLOCK_N < BLOCK_N:\n            qk = tl.where(n_mask[None, :], qk, float(\"-inf\"))\n\n        v = tl.load(V_ptrs + idx_n * BLOCK_N * D, mask=n_mask[:, None])\n        local_m = tl.max(qk, 1)\n        new_m = tl.maximum(m_i, local_m)\n        qk = qk - new_m[:, None]\n\n        p = tl.math.exp2(qk)\n        l_ij = tl.sum(p, 1)\n        alpha = tl.math.exp2(m_i - new_m)\n        o_s = o_s * alpha[:, None]\n        o_s += tl.dot(p.to(v.dtype), v)\n\n        l_i = l_i * alpha + l_ij\n        m_i = new_m\n\n    o_s = o_s / l_i[:, None]\n    tl.store(OS_ptrs, o_s.to(OS.type.element_ty), mask=offs_m[:, None] < L)\n\n    m_i += tl.math.log2(l_i)\n    tl.store(LSE_ptrs, m_i, mask=offs_m < L)\n\n\ndef _get_cuda_arch(device_index: int) -> str:\n    \"\"\"Get CUDA architecture string for the given device.\"\"\"\n    major, minor = torch.cuda.get_device_capability(device_index)\n    return f\"sm{major}{minor}\"\n\n\n# ==================================SLA Class===================================\nclass SparseLinearAttentionBackend(AttentionBackend):\n    \"\"\"Sparse Linear Attention Backend for efficient attention computation.\"\"\"\n\n    accept_output_buffer: bool = True\n\n    @staticmethod\n    def get_supported_head_sizes() -> list[int]:\n        return [64, 128]\n\n    @staticmethod\n    def get_enum() -> AttentionBackendEnum:\n        return AttentionBackendEnum.SLA_ATTN\n\n    @staticmethod\n    def get_impl_cls() -> type[\"SparseLinearAttentionImpl\"]:\n        
return SparseLinearAttentionImpl\n\n    @staticmethod\n    def get_metadata_cls() -> type[\"SparseLinearAttentionMetadata\"]:\n        return SparseLinearAttentionMetadata\n\n    @staticmethod\n    def get_builder_cls() -> type[\"SparseLinearAttentionMetadataBuilder\"]:\n        return SparseLinearAttentionMetadataBuilder\n\n\n@dataclass\nclass SparseLinearAttentionMetadata(AttentionMetadata):\n    \"\"\"Metadata for Sparse Linear Attention computation.\"\"\"\n\n    # Basic attention parameters\n    current_timestep: int\n\n    # Sparse attention configuration\n    topk_ratio: float = 0.1\n\n\nclass SparseLinearAttentionMetadataBuilder(AttentionMetadataBuilder):\n    \"\"\"Builder for SparseLinearAttentionMetadata.\"\"\"\n\n    def __init__(self) -> None:\n        pass\n\n    def prepare(self) -> None:\n        pass\n\n    def build(\n        self,\n        current_timestep: int,\n        topk_ratio: float = 0.1,\n        **kwargs: dict[str, Any],\n    ) -> SparseLinearAttentionMetadata:\n        return SparseLinearAttentionMetadata(\n            current_timestep=current_timestep,\n            topk_ratio=topk_ratio,\n        )\n\n\nclass SparseLinearAttentionImpl(AttentionImpl, nn.Module):\n    \"\"\"Implementation of sparse linear attention for the backend.\"\"\"\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        causal: bool = False,\n        softmax_scale: float | None = None,\n        num_kv_heads: int | None = None,\n        prefix: str = \"\",\n        # SLA-specific parameters - matched to TurboDiffusion defaults\n        topk_ratio: float = 0.1,  # TurboDiffusion uses topk=0.1\n        feature_map: str = \"softmax\",\n        BLKQ: int = 128,  # TurboDiffusion uses BLKQ=128\n        BLKK: int = 64,  # TurboDiffusion uses BLKK=64\n        use_bf16: bool = True,\n        **extra_impl_args,\n    ) -> None:\n        nn.Module.__init__(self)\n\n        # SLA-specific config\n        self.topk_ratio = topk_ratio\n       
 self.BLKQ = BLKQ\n        self.BLKK = BLKK\n        self.dtype = torch.bfloat16 if use_bf16 else torch.float16\n\n        # Learnable linear projection for combining sparse + linear attention\n        self.proj_l = nn.Linear(head_size, head_size, dtype=torch.float32)\n\n        # Feature map for linear attention\n        # Type annotation for callables\n        self.feature_map_q: Callable[[torch.Tensor], torch.Tensor]\n        self.feature_map_k: Callable[[torch.Tensor], torch.Tensor]\n        if feature_map == \"elu\":\n            self.feature_map_q = lambda x: F.elu(x) + 1\n            self.feature_map_k = lambda x: F.elu(x) + 1\n        elif feature_map == \"relu\":\n            self.feature_map_q = F.relu\n            self.feature_map_k = F.relu\n        elif feature_map == \"softmax\":\n            self.feature_map_q = lambda x: F.softmax(x, dim=-1)\n            self.feature_map_k = lambda x: F.softmax(x, dim=-1)\n        else:\n            raise ValueError(f\"Unknown feature map: {feature_map}\")\n\n        self._init_weights()\n\n    def _init_weights(self) -> None:\n        \"\"\"Initialize projection weights to zero for residual-like behavior.\"\"\"\n        with torch.no_grad():\n            nn.init.zeros_(self.proj_l.weight)\n            nn.init.zeros_(self.proj_l.bias)  # type: ignore[arg-type]\n\n    def _calc_linear_attention_with_torch(self, q, k, v):\n        kv = torch.matmul(k.transpose(-1, -2), v)\n        k_sum = torch.sum(k, dim=-2, keepdim=True)\n        return torch.matmul(q, kv) / (1e-5 + torch.matmul(q, k_sum.transpose(-1, -2)))\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_metadata: SparseLinearAttentionMetadata | None = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass for sparse linear attention.\n\n        Args:\n            query: query tensor of shape (B, L, H, D)\n            key: key tensor of shape (B, L, H, D)\n            value: value
tensor of shape (B, L, H, D)\n            attn_metadata: attention metadata containing configuration\n        Returns:\n            output tensor of shape (B, L, H, D)\n        \"\"\"\n        dtype = query.dtype\n\n        # Transpose from (B, L, H, D) to SLA format (B, H, L, D)\n        query = query.transpose(1, 2).contiguous()\n        key = key.transpose(1, 2).contiguous()\n        value = value.transpose(1, 2).contiguous()\n\n        # Get sparse attention map\n        sparse_map, lut, real_topk = get_block_map(\n            query, key, topk_ratio=self.topk_ratio, BLKQ=self.BLKQ, BLKK=self.BLKK\n        )\n\n        # Convert to computation dtype\n        query = query.to(self.dtype)\n        key = key.to(self.dtype)\n        value = value.to(self.dtype)\n\n        # Sparse attention computation\n        o_s = _attention.apply(\n            query, key, value, sparse_map, lut, real_topk, self.BLKQ, self.BLKK\n        )\n\n        # Apply feature maps\n        query = self.feature_map_q(query).contiguous().to(self.dtype)  # c_q\n        key = self.feature_map_k(key).contiguous().to(self.dtype)  # c_k\n        # Linear attention computation\n        o_l = self._calc_linear_attention_with_torch(query, key, value)\n\n        # Apply projection and combine results\n        with torch.amp.autocast(\"cuda\", dtype=self.dtype):\n            o_l = self.proj_l(o_l)\n\n        # Combine sparse and linear attention\n        output = (o_s + o_l).to(dtype).transpose(1, 2)\n\n        return output\n\n\nclass _attention(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, q, k, v, k_block_id, lut, topk, BLOCK_M, BLOCK_N, qk_scale=None):\n        assert q.is_contiguous() and k.is_contiguous() and v.is_contiguous()\n        assert k_block_id.is_contiguous() and lut.is_contiguous()\n\n        # We recommend the following two settings\n        assert BLOCK_M == 64 or BLOCK_M == 128\n        assert BLOCK_N == 64\n\n        B, H, L, D = q.shape\n        if qk_scale is None:\n            qk_scale
= D**-0.5\n\n        M_BLOCKS = triton.cdiv(L, BLOCK_M)\n\n        o_s = torch.empty_like(v)\n        lse = torch.empty(q.shape[:-1], device=q.device, dtype=torch.float32)\n\n        grid = (M_BLOCKS, B * H)\n        _attn_fwd[grid](\n            q,\n            k,\n            v,\n            qk_scale,\n            topk,\n            lut,\n            lse,\n            o_s,\n            L,\n            M_BLOCKS,\n            D,\n            BLOCK_M,\n            BLOCK_N,\n            num_warps=4 if q.shape[-1] == 64 else 8,\n            num_stages=3,\n        )\n\n        ctx.save_for_backward(q, k, v, k_block_id, lut, lse, o_s)\n        ctx.qk_scale = qk_scale\n        ctx.topk = topk\n        ctx.BLOCK_M = BLOCK_M\n        ctx.BLOCK_N = BLOCK_N\n        return o_s\n\n\n# ==================================SageSLA Class===================================\nSAGESLA_ENABLED = True\ntry:\n    import spas_sage_attn._fused as fused\n    import spas_sage_attn._qattn as qattn\n    from spas_sage_attn.utils import block_map_lut_triton, get_vanilla_qk_quant\nexcept ImportError:\n    SAGESLA_ENABLED = False\n\nSAGE2PP_ENABLED = True\ntry:\n    from spas_sage_attn._qattn import (\n        qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold,\n    )\nexcept ImportError:\n    SAGE2PP_ENABLED = False\n\n\nclass SageSparseLinearAttentionBackend(AttentionBackend):\n    \"\"\"Quantized Sparse-Linear Attention backend using SageAttention kernels.\"\"\"\n\n    accept_output_buffer: bool = True\n\n    @staticmethod\n    def get_supported_head_sizes() -> list[int]:\n        return [64, 128]\n\n    @staticmethod\n    def get_enum() -> AttentionBackendEnum:\n        return AttentionBackendEnum.SAGE_SLA_ATTN\n\n    @staticmethod\n    def get_impl_cls() -> type[\"SageSparseLinearAttentionImpl\"]:\n        return SageSparseLinearAttentionImpl\n\n    @staticmethod\n    def get_metadata_cls() -> type[\"SageSparseLinearAttentionMetadata\"]:\n        return 
SageSparseLinearAttentionMetadata\n\n    @staticmethod\n    def get_builder_cls() -> type[\"SageSparseLinearAttentionMetadataBuilder\"]:\n        return SageSparseLinearAttentionMetadataBuilder\n\n\n@dataclass\nclass SageSparseLinearAttentionMetadata(AttentionMetadata):\n    \"\"\"Metadata for Sage Sparse Linear Attention computation.\"\"\"\n\n    # Basic attention parameters\n    current_timestep: int\n\n    # Sparse attention configuration\n    topk_ratio: float = 0.1\n\n\nclass SageSparseLinearAttentionMetadataBuilder(AttentionMetadataBuilder):\n    \"\"\"Builder for SageSparseLinearAttentionMetadata.\"\"\"\n\n    def __init__(self) -> None:\n        pass\n\n    def prepare(self) -> None:\n        pass\n\n    def build(\n        self,\n        current_timestep: int,\n        topk_ratio: float = 0.1,\n        **kwargs: dict[str, Any],\n    ) -> SageSparseLinearAttentionMetadata:\n        return SageSparseLinearAttentionMetadata(\n            current_timestep=current_timestep,\n            topk_ratio=topk_ratio,\n        )\n\n\nclass SageSparseLinearAttentionImpl(AttentionImpl, nn.Module):\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        causal: bool = False,\n        softmax_scale: float | None = None,\n        num_kv_heads: int | None = None,\n        prefix: str = \"\",\n        topk_ratio: float = 0.5,\n        feature_map: str = \"softmax\",\n        use_bf16: bool = True,\n        **extra_impl_args,\n    ) -> None:\n        nn.Module.__init__(self)\n\n        assert (\n            SAGESLA_ENABLED\n        ), \"Install spas_sage_attn (pip install git+https://github.com/thu-ml/SpargeAttn.git --no-build-isolation) first to enable SageSLA.\"\n\n        self.num_heads = num_heads\n        self.head_size = head_size\n        self.softmax_scale = softmax_scale if softmax_scale else head_size**-0.5\n        self.causal = causal\n        self.prefix = prefix\n\n        self.topk_ratio = topk_ratio\n        self.dtype =
torch.bfloat16 if use_bf16 else torch.float16\n\n        # Learnable linear projection for combining sparse + linear attention\n        self.proj_l = nn.Linear(head_size, head_size, dtype=torch.float32)\n\n        # Feature map for linear attention\n        # Type annotation for callables\n        self.feature_map_q: Callable[[torch.Tensor], torch.Tensor]\n        self.feature_map_k: Callable[[torch.Tensor], torch.Tensor]\n        if feature_map == \"elu\":\n            self.feature_map_q = lambda x: F.elu(x) + 1\n            self.feature_map_k = lambda x: F.elu(x) + 1\n        elif feature_map == \"relu\":\n            self.feature_map_q = F.relu\n            self.feature_map_k = F.relu\n        elif feature_map == \"softmax\":\n            self.feature_map_q = lambda x: F.softmax(x, dim=-1)\n            self.feature_map_k = lambda x: F.softmax(x, dim=-1)\n        else:\n            raise ValueError(f\"Unknown feature map: {feature_map}\")\n\n        self._init_weights()\n\n    def _init_weights(self) -> None:\n        \"\"\"Initialize projection weights to zero for residual-like behavior.\"\"\"\n        with torch.no_grad():\n            nn.init.zeros_(self.proj_l.weight)\n            nn.init.zeros_(self.proj_l.bias)  # type: ignore[arg-type]\n\n    def _calc_linear_attention_with_torch(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n    ):\n        kv = torch.matmul(k.transpose(-1, -2), v)\n        k_sum = torch.sum(k, dim=-2, keepdim=True)\n        return torch.matmul(q, kv) / (1e-5 + torch.matmul(q, k_sum.transpose(-1, -2)))\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_metadata: AttentionMetadata,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass for Sage Sparse Linear attention with quantized kernels.\n        Args:\n            query: query tensor of shape (B, L, H, D)\n            key: key tensor of shape (B, L, H, 
D)\n            value: value tensor of shape (B, L, H, D)\n            attn_metadata: attention metadata containing configuration\n        Returns:\n            output tensor of shape (B, L, H, D)\n        \"\"\"\n        dtype = query.dtype\n\n        # Transpose from (B, L, H, D) to SLA format (B, H, L, D)\n        q = query.transpose(1, 2).contiguous()\n        k = key.transpose(1, 2).contiguous()\n        v = value.transpose(1, 2).contiguous()\n\n        # Determine block sizes based on GPU architecture\n        arch = _get_cuda_arch(q.device.index)\n\n        if arch == \"sm90\":\n            BLKQ = 64\n            BLKK = 128\n        else:\n            BLKQ = 128\n            BLKK = 64\n        # Compute block-sparse attention pattern\n        sparse_map, lut, real_topk = get_block_map(\n            q, k, topk_ratio=self.topk_ratio, BLKQ=BLKQ, BLKK=BLKK\n        )\n\n        # Convert to compute dtype\n        q = q.to(self.dtype)\n        k = k.to(self.dtype)\n        v = v.to(self.dtype)\n\n        ########## SPARGE BEGIN ##########\n        km = k.mean(dim=-2, keepdim=True)\n        headdim = q.size(-1)\n        assert headdim in [\n            64,\n            128,\n        ], \"headdim should be in [64, 128]. 
For other headdim, you can use padding and specify the softmax scale.\"\n\n        # Quantize Q, K to INT8\n        q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, BLKQ, BLKK)\n        lut, valid_block_num = block_map_lut_triton(sparse_map)\n        scale = 1.0 / (headdim**0.5)\n\n        o_s = torch.empty_like(q)\n\n        if arch in (\"sm80\", \"sm86\", \"sm87\"):\n            pvthreshold = torch.full(\n                (q.shape[-3],), 1e6, dtype=torch.float32, device=q.device\n            )\n            v_fp16 = v.to(torch.float16)\n            qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold(\n                q_int8,\n                k_int8,\n                v_fp16,\n                o_s,\n                lut,\n                valid_block_num,\n                pvthreshold,\n                q_scale,\n                k_scale,\n                1,\n                False,\n                1,\n                scale,\n                0,\n            )\n        else:\n            b, h_kv, kv_len, head_dim = v.shape\n            padded_len = (kv_len + 127) // 128 * 128\n            v_transposed_permutted = torch.empty(\n                (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device\n            )\n            fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)\n            v_fp8 = torch.empty(\n                v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device\n            )\n            v_scale = torch.empty(\n                (b, h_kv, head_dim), dtype=torch.float32, device=v.device\n            )\n            fused.scale_fuse_quant_cuda(\n                v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1\n            )\n\n            if arch == \"sm90\":\n                qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_sm90(\n                    q_int8,\n                    k_int8,\n                    v_fp8,\n                    o_s,\n               
     lut,\n                    valid_block_num,\n                    q_scale,\n                    k_scale,\n                    v_scale,\n                    1,\n                    False,\n                    1,\n                    scale,\n                )\n            else:\n                pvthreshold = torch.full(\n                    (q.shape[-3],), 1e6, dtype=torch.float32, device=q.device\n                )\n                if SAGE2PP_ENABLED:\n                    qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(\n                        q_int8,\n                        k_int8,\n                        v_fp8,\n                        o_s,\n                        lut,\n                        valid_block_num,\n                        pvthreshold,\n                        q_scale,\n                        k_scale,\n                        v_scale,\n                        1,\n                        False,\n                        1,\n                        scale,\n                        0,\n                    )\n                else:\n                    qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(\n                        q_int8,\n                        k_int8,\n                        v_fp8,\n                        o_s,\n                        lut,\n                        valid_block_num,\n                        pvthreshold,\n                        q_scale,\n                        k_scale,\n                        v_scale,\n                        1,\n                        False,\n                        1,\n                        scale,\n                        0,\n                    )\n\n        ########## SPARGE END ##########\n\n        # Linear attention with feature maps\n        q_linear = self.feature_map_q(q).contiguous().to(self.dtype)\n        k_linear = self.feature_map_k(k).contiguous().to(self.dtype)\n        o_l = 
self._calc_linear_attention_with_torch(q_linear, k_linear, v)\n\n        # Project linear attention output and combine\n        with torch.amp.autocast(\"cuda\", dtype=self.dtype):\n            o_l = self.proj_l(o_l)\n\n        # Combine sparse and linear outputs\n        output = (o_s + o_l).to(dtype).transpose(1, 2)\n\n        return output\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/sparse_video_gen_2_attn.py",
    "content": "\"\"\"\nSparse Video Gen 2 (SAP) attention backend.\n\nThis is a baseline integration that wires the backend into the\nattention framework.\n\nAdapted from https://github.com/svg-project/Sparse-VideoGen/blob/main/svg/models/wan/attention.py\n\"\"\"\n\nfrom dataclasses import dataclass, field\nfrom typing import Any\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn.attention import SDPBackend, sdpa_kernel\n\ntry:\n    from svg.kernels.triton.permute import (\n        apply_inverse_permutation_triton,\n        permute_tensor_by_labels_triton,\n    )\n    from svg.kmeans_utils import (\n        batch_kmeans_Euclid,\n        dynamic_block_sparse_fwd_flashinfer,\n        identify_dynamic_map,\n    )\n\n    svg2_available = True\nexcept ImportError:\n    svg2_available = False\n\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionBackend,\n    AttentionImpl,\n    AttentionMetadata,\n    AttentionMetadataBuilder,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass SparseVideoGen2AttentionBackend(AttentionBackend):\n\n    accept_output_buffer: bool = True\n\n    @staticmethod\n    def get_supported_head_sizes() -> list[int]:\n        return [64, 128, 256]\n\n    @staticmethod\n    def get_enum() -> AttentionBackendEnum:\n        return AttentionBackendEnum.SPARSE_VIDEO_GEN_2_ATTN\n\n    @staticmethod\n    def get_impl_cls() -> type[\"SparseVideoGen2AttentionImpl\"]:\n        return SparseVideoGen2AttentionImpl\n\n    @staticmethod\n    def get_metadata_cls() -> type[\"SparseVideoGen2AttentionMetadata\"]:\n        return SparseVideoGen2AttentionMetadata\n\n    @staticmethod\n    def get_builder_cls() -> type[\"SparseVideoGen2AttentionMetadataBuilder\"]:\n        return SparseVideoGen2AttentionMetadataBuilder\n\n\n@dataclass\nclass Svg2LayerCache:\n    
# centroids for kmeans clustering\n    q_centroids: torch.Tensor | None = None\n    k_centroids: torch.Tensor | None = None\n    centroids_initialized: bool = False\n\n\n@dataclass\nclass Svg2Cache:\n    layers: dict[int, Svg2LayerCache] = field(default_factory=dict)\n\n    def get_layer(self, layer_idx: int) -> Svg2LayerCache:\n        layer_cache = self.layers.get(layer_idx)\n        if layer_cache is None:\n            layer_cache = Svg2LayerCache()\n            self.layers[layer_idx] = layer_cache\n        return layer_cache\n\n\n@dataclass\nclass SparseVideoGen2AttentionMetadata(AttentionMetadata):\n    current_timestep: int\n    num_q_centroids: int\n    num_k_centroids: int\n    top_p_kmeans: float\n    min_kc_ratio: float\n    kmeans_iter_init: int\n    kmeans_iter_step: int\n    zero_step_kmeans_init: bool\n    first_layers_fp: float\n    first_times_fp: float\n    context_length: int\n    num_frame: int\n    frame_size: int\n    cache: Svg2Cache\n    prompt_length: int | None = None\n    max_seqlen_q: int | None = None\n    max_seqlen_k: int | None = None\n\n\ndef _require_kwarg(kwargs: dict[str, Any], name: str) -> Any:\n    if name not in kwargs:\n        raise ValueError(\n            f\"Missing required argument for SparseVideoGen2Attention: {name}\"\n        )\n    return kwargs[name]\n\n\nclass SparseVideoGen2AttentionMetadataBuilder(AttentionMetadataBuilder):\n\n    def __init__(self) -> None:\n        pass\n\n    def prepare(self) -> None:\n        pass\n\n    def build(  # type: ignore[override]\n        self,\n        current_timestep: int,\n        raw_latent_shape: tuple[int, ...],\n        patch_size: tuple[int, int, int],\n        cache: Svg2Cache,\n        num_q_centroids: int,\n        num_k_centroids: int,\n        top_p_kmeans: float,\n        min_kc_ratio: float,\n        kmeans_iter_init: int,\n        kmeans_iter_step: int,\n        zero_step_kmeans_init: bool,\n        first_layers_fp: float,\n        first_times_fp: float,\n        
context_length: int = 0,\n        prompt_length: int | None = None,\n        **kwargs: dict[str, Any],\n    ) -> SparseVideoGen2AttentionMetadata:\n        raw_shape = tuple(raw_latent_shape)\n        if len(raw_shape) == 5:\n            t, h, w = raw_shape[2:5]\n        elif len(raw_shape) == 3:\n            t, h, w = raw_shape\n        else:\n            raise ValueError(\n                \"raw_latent_shape must be (T, H, W) or (B, C, T, H, W) for SAP attention\"\n            )\n        pt, ph, pw = patch_size\n        if t % pt != 0 or h % ph != 0 or w % pw != 0:\n            raise ValueError(\n                \"raw_latent_shape must be divisible by patch_size for SAP attention\"\n            )\n\n        num_frame = t // pt\n        frame_size = (h // ph) * (w // pw)\n\n        return SparseVideoGen2AttentionMetadata(\n            current_timestep=current_timestep,\n            num_q_centroids=num_q_centroids,\n            num_k_centroids=num_k_centroids,\n            top_p_kmeans=top_p_kmeans,\n            min_kc_ratio=min_kc_ratio,\n            kmeans_iter_init=kmeans_iter_init,\n            kmeans_iter_step=kmeans_iter_step,\n            zero_step_kmeans_init=zero_step_kmeans_init,\n            first_layers_fp=first_layers_fp,\n            first_times_fp=first_times_fp,\n            context_length=context_length,\n            prompt_length=prompt_length,\n            num_frame=num_frame,\n            frame_size=frame_size,\n            cache=cache,\n        )\n\n\nclass SparseVideoGen2AttentionImpl(AttentionImpl):\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        causal: bool,\n        softmax_scale: float,\n        num_kv_heads: int | None = None,\n        prefix: str = \"\",\n        **extra_impl_args,\n    ) -> None:\n        if causal:\n            raise ValueError(\n                \"Sparse Video Gen 2 attention does not support causal attention\"\n            )\n        if not svg2_available:\n            
raise ImportError(\n                \"Sparse Video Gen 2 attention backend requires the svg package to be installed. \"\n                \"Please install it by following the instructions at \"\n                \"https://github.com/svg-project/Sparse-VideoGen\"\n            )\n        self.prefix = prefix\n        self.layer_idx = self._get_layer_idx(prefix)\n\n    def _get_layer_idx(self, prefix: str) -> int:\n        parts = prefix.split(\".\")\n        if len(parts) < 3:\n            raise ValueError(\n                f\"Invalid prefix for SparseVideoGen2AttentionImpl: {prefix}\"\n            )\n        return int(parts[-3])\n\n    def kmeans_init(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        attn_metadata: SparseVideoGen2AttentionMetadata,\n    ):\n        cfg, num_heads, seq_len, dim = query.size()\n        qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(\n            query.reshape(cfg * num_heads, seq_len, dim),\n            n_clusters=attn_metadata.num_q_centroids,\n            max_iters=attn_metadata.kmeans_iter_init,\n        )\n        klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(\n            key.reshape(cfg * num_heads, seq_len, dim),\n            n_clusters=attn_metadata.num_k_centroids,\n            max_iters=attn_metadata.kmeans_iter_init,\n        )\n\n        layer_cache = attn_metadata.cache.get_layer(self.layer_idx)\n        layer_cache.q_centroids = qcentroids\n        layer_cache.k_centroids = kcentroids\n\n        return (\n            qlabels,\n            qcentroids,\n            qcluster_sizes,\n            qiter,\n            klabels,\n            kcentroids,\n            kcluster_sizes,\n            kiter,\n        )\n\n    def kmeans_step(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        attn_metadata: SparseVideoGen2AttentionMetadata,\n    ):\n        cfg, num_heads, seq_len, dim = query.size()\n        layer_cache = 
attn_metadata.cache.get_layer(self.layer_idx)\n        qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(\n            query.reshape(cfg * num_heads, seq_len, dim),\n            n_clusters=attn_metadata.num_q_centroids,\n            max_iters=attn_metadata.kmeans_iter_step,\n            init_centroids=layer_cache.q_centroids,\n        )\n        klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(\n            key.reshape(cfg * num_heads, seq_len, dim),\n            n_clusters=attn_metadata.num_k_centroids,\n            max_iters=attn_metadata.kmeans_iter_step,\n            init_centroids=layer_cache.k_centroids,\n        )\n\n        layer_cache.q_centroids = qcentroids\n        layer_cache.k_centroids = kcentroids\n\n        return (\n            qlabels,\n            qcentroids,\n            qcluster_sizes,\n            qiter,\n            klabels,\n            kcentroids,\n            kcluster_sizes,\n            kiter,\n        )\n\n    def kmeans_clustering(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        attn_metadata: SparseVideoGen2AttentionMetadata,\n    ):\n        layer_cache = attn_metadata.cache.get_layer(self.layer_idx)\n        if not layer_cache.centroids_initialized:\n            (\n                qlabels,\n                qcentroids,\n                qcluster_sizes,\n                qiter,\n                klabels,\n                kcentroids,\n                kcluster_sizes,\n                kiter,\n            ) = self.kmeans_init(query, key, attn_metadata)\n            layer_cache.centroids_initialized = True\n            logger.debug(\n                \"Centroids initialized at layer %s (init iters: %s).\",\n                self.layer_idx,\n                attn_metadata.kmeans_iter_init,\n            )\n        else:\n            (\n                qlabels,\n                qcentroids,\n                qcluster_sizes,\n                qiter,\n                klabels,\n             
   kcentroids,\n                kcluster_sizes,\n                kiter,\n            ) = self.kmeans_step(query, key, attn_metadata)\n\n        return (\n            qlabels,\n            qcentroids,\n            qcluster_sizes,\n            qiter,\n            klabels,\n            kcentroids,\n            kcluster_sizes,\n            kiter,\n        )\n\n    def semantic_aware_permutation(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_metadata: SparseVideoGen2AttentionMetadata,\n    ):\n        cfg, num_heads, seq_len, dim = query.size()\n\n        # 1. Kmeans clustering\n        (\n            qlabels,\n            qcentroids,\n            qcluster_sizes,\n            qiter,\n            klabels,\n            kcentroids,\n            kcluster_sizes,\n            kiter,\n        ) = self.kmeans_clustering(query, key, attn_metadata)\n\n        # 2. Identify dynamic map\n        q_cluster_sizes = qcluster_sizes.view(\n            cfg, num_heads, attn_metadata.num_q_centroids\n        )\n        k_cluster_sizes = kcluster_sizes.view(\n            cfg, num_heads, attn_metadata.num_k_centroids\n        )\n\n        dynamic_map = identify_dynamic_map(\n            qcentroids.view(cfg, num_heads, attn_metadata.num_q_centroids, dim),\n            kcentroids.view(cfg, num_heads, attn_metadata.num_k_centroids, dim),\n            q_cluster_sizes,\n            k_cluster_sizes,\n            attn_metadata.top_p_kmeans,\n            attn_metadata.min_kc_ratio,\n        )\n\n        # 3. 
Permute the query, key, value\n        q_permuted, q_sorted_indices = permute_tensor_by_labels_triton(\n            query, qlabels, dim=2\n        )\n        k_permuted, k_sorted_indices = permute_tensor_by_labels_triton(\n            key, klabels, dim=2\n        )\n        v_permuted, v_sorted_indices = permute_tensor_by_labels_triton(\n            value, klabels, dim=2, sorted_indices=k_sorted_indices\n        )\n\n        return (\n            q_permuted,\n            k_permuted,\n            v_permuted,\n            dynamic_map,\n            q_cluster_sizes,\n            k_cluster_sizes,\n            q_sorted_indices,\n        )\n\n    def _hunyuan_dynamic_map_post_processing(\n        self,\n        q_perm: torch.Tensor,\n        k_perm: torch.Tensor,\n        v_perm: torch.Tensor,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        dyn_map: torch.Tensor,\n        qc_sz_s: torch.Tensor,\n        kc_sz_s: torch.Tensor,\n        q_sorted_indices: torch.Tensor,\n        video_length: int,\n        context_length: int,\n        prompt_length: int,\n        unprompt_length: int,\n    ) -> tuple[\n        torch.Tensor,\n        torch.Tensor,\n        torch.Tensor,\n        torch.Tensor,\n        torch.Tensor,\n        torch.Tensor,\n        torch.Tensor,\n    ]:\n        # Place the permuted video tokens back and keep text tokens at the tail.\n        query[:, :, :-context_length, :] = q_perm\n        key[:, :, :-context_length, :] = k_perm\n        value[:, :, :-context_length, :] = v_perm\n\n        # Add prompt/unprompt clusters to the dynamic map.\n        dyn_map = F.pad(dyn_map, (0, 2, 0, 2), value=0)\n        dyn_map[:, :, -2, :-1] = True\n        dyn_map[:, :, :-1, -2] = True\n        dyn_map[:, :, -1, -1] = True\n\n        qc_sz_s = F.pad(qc_sz_s, (0, 2), value=0)\n        qc_sz_s[:, :, -2] = prompt_length\n        qc_sz_s[:, :, -1] = unprompt_length\n        kc_sz_s = F.pad(kc_sz_s, (0, 2), value=0)\n        
kc_sz_s[:, :, -2] = prompt_length\n        kc_sz_s[:, :, -1] = unprompt_length\n\n        q_sorted_indices = F.pad(q_sorted_indices, (0, context_length), value=0)\n        q_sorted_indices[:, video_length:] = torch.arange(\n            video_length,\n            video_length + context_length,\n            device=q_sorted_indices.device,\n        )\n        return query, key, value, dyn_map, qc_sz_s, kc_sz_s, q_sorted_indices\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_metadata: SparseVideoGen2AttentionMetadata,\n    ) -> torch.Tensor:\n        torch.backends.cuda.preferred_linalg_library(backend=\"magma\")\n        res = None\n        # bshd -> bhsd\n        query = query.transpose(1, 2).contiguous()\n        key = key.transpose(1, 2).contiguous()\n        value = value.transpose(1, 2).contiguous()\n        batch_size, num_heads, seq_len, dim = query.size()\n\n        context_length, num_frame, frame_size = (\n            attn_metadata.context_length,\n            attn_metadata.num_frame,\n            attn_metadata.frame_size,\n        )\n        prompt_length = attn_metadata.prompt_length\n        if prompt_length is None:\n            prompt_length = context_length\n\n        assert (\n            seq_len == context_length + num_frame * frame_size\n        ), f\"Query Shape: {seq_len} is not equivalent to {context_length} + {num_frame} * {frame_size}\"\n\n        # Determine if we use Full Attention to calculate\n        full_attention_flag = False\n\n        if self.layer_idx < attn_metadata.first_layers_fp:\n            full_attention_flag = True\n        if attn_metadata.current_timestep > attn_metadata.first_times_fp:\n            full_attention_flag = True\n\n        if full_attention_flag:\n            if attn_metadata.zero_step_kmeans_init:\n                video_length = attn_metadata.num_frame * attn_metadata.frame_size\n                query_video = query[:, :, 
:video_length, :].contiguous()\n                key_video = key[:, :, :video_length, :].contiguous()\n                self.kmeans_clustering(query_video, key_video, attn_metadata)\n\n            with sdpa_kernel(\n                SDPBackend.CUDNN_ATTENTION\n            ):  # not sure why we need to force cudnn here, but it's faster than flash attention\n                output_hidden_states = torch.nn.functional.scaled_dot_product_attention(\n                    query, key, value, dropout_p=0.0, is_causal=False\n                )\n\n            res = output_hidden_states.reshape(\n                batch_size, num_heads, seq_len, dim\n            ).transpose(1, 2)\n        else:\n            if context_length > 0:\n                video_length = num_frame * frame_size\n                unprompt_length = max(context_length - prompt_length, 0)\n                query_video = query[:, :, :video_length, :].contiguous()\n                key_video = key[:, :, :video_length, :].contiguous()\n                value_video = value[:, :, :video_length, :].contiguous()\n\n                (\n                    q_perm,\n                    k_perm,\n                    v_perm,\n                    dyn_map,\n                    qc_sz_s,\n                    kc_sz_s,\n                    q_sorted_indices,\n                ) = self.semantic_aware_permutation(\n                    query_video, key_video, value_video, attn_metadata\n                )\n                (\n                    q_perm,\n                    k_perm,\n                    v_perm,\n                    dyn_map,\n                    qc_sz_s,\n                    kc_sz_s,\n                    q_sorted_indices,\n                ) = self._hunyuan_dynamic_map_post_processing(\n                    q_perm,\n                    k_perm,\n                    v_perm,\n                    query,\n                    key,\n                    value,\n                    dyn_map,\n                    qc_sz_s,\n                    
kc_sz_s,\n                    q_sorted_indices,\n                    video_length,\n                    context_length,\n                    prompt_length,\n                    unprompt_length,\n                )\n            else:\n                (\n                    q_perm,\n                    k_perm,\n                    v_perm,\n                    dyn_map,\n                    qc_sz_s,\n                    kc_sz_s,\n                    q_sorted_indices,\n                ) = self.semantic_aware_permutation(query, key, value, attn_metadata)\n\n            output_permuted = dynamic_block_sparse_fwd_flashinfer(\n                q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, is_cpu=False\n            )\n\n            attn_output = apply_inverse_permutation_triton(\n                output_permuted, q_sorted_indices, dim=2\n            )\n\n            res = attn_output.reshape(batch_size, num_heads, seq_len, dim).transpose(\n                1, 2\n            )\n\n        torch.backends.cuda.preferred_linalg_library(\n            backend=\"default\"\n        )  # reset to default\n        return res.contiguous()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/video_sparse_attn.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nimport functools\nimport math\nfrom dataclasses import dataclass\n\nimport torch\n\ntry:\n    from vsa import video_sparse_attn\nexcept ImportError:\n    video_sparse_attn = None\n\nfrom typing import Any\n\nfrom sglang.multimodal_gen.runtime.distributed import get_sp_group\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionBackend,\n    AttentionImpl,\n    AttentionMetadata,\n    AttentionMetadataBuilder,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\nVSA_TILE_SIZE = (4, 4, 4)\n\n\n@functools.lru_cache(maxsize=10)\ndef get_tile_partition_indices(\n    dit_seq_shape: tuple[int, int, int],\n    tile_size: tuple[int, int, int],\n    device: torch.device,\n) -> torch.LongTensor:\n    T, H, W = dit_seq_shape\n    ts, hs, ws = tile_size\n    indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)\n    ls = []\n    for t in range(math.ceil(T / ts)):\n        for h in range(math.ceil(H / hs)):\n            for w in range(math.ceil(W / ws)):\n                ls.append(\n                    indices[\n                        t * ts : min(t * ts + ts, T),\n                        h * hs : min(h * hs + hs, H),\n                        w * ws : min(w * ws + ws, W),\n                    ].flatten()\n                )\n    index = torch.cat(ls, dim=0)\n    return index\n\n\n@functools.lru_cache(maxsize=10)\ndef get_reverse_tile_partition_indices(\n    dit_seq_shape: tuple[int, int, int],\n    tile_size: tuple[int, int, int],\n    device: torch.device,\n) -> torch.LongTensor:\n    return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))\n\n\n@functools.lru_cache(maxsize=10)\ndef construct_variable_block_sizes(\n    
dit_seq_shape: tuple[int, int, int],\n    num_tiles: tuple[int, int, int],\n    device: torch.device,\n) -> torch.LongTensor:\n    \"\"\"\n    Compute the number of valid (non‑padded) tokens inside every\n    (ts_t × ts_h × ts_w) tile after padding ‑‑ flattened in the order\n    (t‑tile, h‑tile, w‑tile) that `rearrange` uses.\n\n    Returns\n    -------\n    torch.LongTensor  # shape: [∏ full_window_size]\n    \"\"\"\n    # unpack\n    t, h, w = dit_seq_shape\n    ts_t, ts_h, ts_w = VSA_TILE_SIZE\n    n_t, n_h, n_w = num_tiles\n\n    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:\n        \"\"\"Vector with the size of each tile along one dimension.\"\"\"\n        sizes = torch.full((n_tiles,), tile, dtype=torch.int, device=device)\n        # size of last (possibly partial) tile\n        remainder = dim_len - (n_tiles - 1) * tile\n        sizes[-1] = remainder if remainder > 0 else tile\n        return sizes\n\n    t_sizes = _sizes(t, ts_t, n_t)  # [n_t]\n    h_sizes = _sizes(h, ts_h, n_h)  # [n_h]\n    w_sizes = _sizes(w, ts_w, n_w)  # [n_w]\n\n    # broadcast‑multiply to get voxels per tile, then flatten\n    block_sizes = (\n        t_sizes[:, None, None]  # [n_t, 1,   1]\n        * h_sizes[None, :, None]  # [1,   n_h, 1]\n        * w_sizes[None, None, :]  # [1,   1,   n_w]\n    ).reshape(\n        -1\n    )  # [n_t * n_h * n_w]\n\n    return block_sizes\n\n\n@functools.lru_cache(maxsize=10)\ndef get_non_pad_index(\n    variable_block_sizes: torch.LongTensor,\n    max_block_size: int,\n):\n    n_win = variable_block_sizes.shape[0]\n    device = variable_block_sizes.device\n    starts_pad = torch.arange(n_win, device=device) * max_block_size\n    index_pad = (\n        starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :]\n    )\n    index_mask = (\n        torch.arange(max_block_size, device=device)[None, :]\n        < variable_block_sizes[:, None]\n    )\n    return index_pad[index_mask]\n\n\nclass 
VideoSparseAttentionBackend(AttentionBackend):\n\n    accept_output_buffer: bool = True\n\n    @staticmethod\n    def get_supported_head_sizes() -> list[int]:\n        return [64, 128]\n\n    @staticmethod\n    def get_enum() -> AttentionBackendEnum:\n        return AttentionBackendEnum.VIDEO_SPARSE_ATTN\n\n    @staticmethod\n    def get_impl_cls() -> type[\"VideoSparseAttentionImpl\"]:\n        return VideoSparseAttentionImpl\n\n    @staticmethod\n    def get_metadata_cls() -> type[\"VideoSparseAttentionMetadata\"]:\n        return VideoSparseAttentionMetadata\n\n    @staticmethod\n    def get_builder_cls() -> type[\"VideoSparseAttentionMetadataBuilder\"]:\n        return VideoSparseAttentionMetadataBuilder\n\n\n@dataclass\nclass VideoSparseAttentionMetadata(AttentionMetadata):\n    current_timestep: int\n    dit_seq_shape: list[int]\n    VSA_sparsity: float\n    num_tiles: list[int]\n    total_seq_length: int\n    tile_partition_indices: torch.LongTensor\n    reverse_tile_partition_indices: torch.LongTensor\n    variable_block_sizes: torch.LongTensor\n    non_pad_index: torch.LongTensor\n\n    # adaptation for FastWan2.1-T2V-1.3B-Diffusers\n    # Sequence lengths for the forward batch\n    # Maximum sequence length for query\n    max_seqlen_q: int = 1\n    # Maximum sequence length for key\n    max_seqlen_k: int = 0\n\n\nclass VideoSparseAttentionMetadataBuilder(AttentionMetadataBuilder):\n\n    def __init__(self):\n        pass\n\n    def prepare(self):\n        pass\n\n    def build(  # type: ignore\n        self,\n        current_timestep: int,\n        raw_latent_shape: tuple[int, int, int],\n        patch_size: tuple[int, int, int],\n        VSA_sparsity: float,\n        device: torch.device,\n        **kwargs: dict[str, Any],\n    ) -> VideoSparseAttentionMetadata:\n        dit_seq_shape = (\n            raw_latent_shape[0] // patch_size[0],\n            raw_latent_shape[1] // patch_size[1],\n            raw_latent_shape[2] // 
patch_size[2],\n        )\n\n        num_tiles = (\n            math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),\n            math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),\n            math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]),\n        )\n        total_seq_length = math.prod(dit_seq_shape)\n\n        tile_partition_indices = get_tile_partition_indices(\n            dit_seq_shape, VSA_TILE_SIZE, device\n        )\n        reverse_tile_partition_indices = get_reverse_tile_partition_indices(\n            dit_seq_shape, VSA_TILE_SIZE, device\n        )\n        variable_block_sizes = construct_variable_block_sizes(\n            dit_seq_shape, num_tiles, device\n        )\n        non_pad_index = get_non_pad_index(\n            variable_block_sizes, math.prod(VSA_TILE_SIZE)\n        )\n\n        return VideoSparseAttentionMetadata(\n            current_timestep=current_timestep,\n            dit_seq_shape=dit_seq_shape,  # type: ignore\n            VSA_sparsity=VSA_sparsity,  # type: ignore\n            num_tiles=num_tiles,  # type: ignore\n            total_seq_length=total_seq_length,  # type: ignore\n            tile_partition_indices=tile_partition_indices,  # type: ignore\n            reverse_tile_partition_indices=reverse_tile_partition_indices,\n            variable_block_sizes=variable_block_sizes,\n            non_pad_index=non_pad_index,\n        )\n\n\nclass VideoSparseAttentionImpl(AttentionImpl):\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        causal: bool,\n        softmax_scale: float,\n        num_kv_heads: int | None = None,\n        prefix: str = \"\",\n        **extra_impl_args,\n    ) -> None:\n        self.prefix = prefix\n        sp_group = get_sp_group()\n        self.sp_size = sp_group.world_size\n\n    def tile(\n        self,\n        x: torch.Tensor,\n        num_tiles: list[int],\n        tile_partition_indices: torch.LongTensor,\n        non_pad_index: torch.LongTensor,\n    ) -> 
torch.Tensor:\n        t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]\n        h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]\n        w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]\n\n        x_padded = torch.zeros(\n            (\n                x.shape[0],\n                t_padded_size * h_padded_size * w_padded_size,\n                x.shape[-2],\n                x.shape[-1],\n            ),\n            device=x.device,\n            dtype=x.dtype,\n        )\n        x_padded[:, non_pad_index] = x[:, tile_partition_indices]\n        return x_padded\n\n    def untile(\n        self,\n        x: torch.Tensor,\n        reverse_tile_partition_indices: torch.LongTensor,\n        non_pad_index: torch.LongTensor,\n    ) -> torch.Tensor:\n        x = x[:, non_pad_index][:, reverse_tile_partition_indices]\n        return x\n\n    def preprocess_qkv(\n        self,\n        qkv: torch.Tensor,\n        attn_metadata: VideoSparseAttentionMetadata,\n    ) -> torch.Tensor:\n        return self.tile(\n            qkv,\n            attn_metadata.num_tiles,\n            attn_metadata.tile_partition_indices,\n            attn_metadata.non_pad_index,\n        )\n\n    def postprocess_output(\n        self,\n        output: torch.Tensor,\n        attn_metadata: VideoSparseAttentionMetadata,\n    ) -> torch.Tensor:\n        return self.untile(\n            output,\n            attn_metadata.reverse_tile_partition_indices,\n            attn_metadata.non_pad_index,\n        )\n\n    def forward(  # type: ignore[override]\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        gate_compress: torch.Tensor,\n        attn_metadata: VideoSparseAttentionMetadata,\n    ) -> torch.Tensor:\n        query = query.transpose(1, 2).contiguous()\n        key = key.transpose(1, 2).contiguous()\n        value = value.transpose(1, 2).contiguous()\n        gate_compress = gate_compress.transpose(1, 2).contiguous()\n\n        VSA_sparsity = 
attn_metadata.VSA_sparsity\n\n        cur_topk = math.ceil(\n            (1 - VSA_sparsity)\n            * (attn_metadata.total_seq_length / math.prod(VSA_TILE_SIZE))\n        )\n\n        if video_sparse_attn is None:\n            raise NotImplementedError(\"video_sparse_attn is not installed\")\n        hidden_states = video_sparse_attn(\n            query,\n            key,\n            value,\n            variable_block_sizes=attn_metadata.variable_block_sizes,\n            topk=cur_topk,\n            block_size=VSA_TILE_SIZE,\n            compress_attn_weight=gate_compress,\n        ).transpose(1, 2)\n\n        return hidden_states\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/backends/vmoba.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nimport re\nfrom dataclasses import dataclass\n\nimport torch\nfrom einops import rearrange\nfrom kernel.attn.vmoba_attn.vmoba import (\n    moba_attn_varlen,\n    process_moba_input,\n    process_moba_output,\n)\n\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionBackend,\n    AttentionImpl,\n    AttentionMetadata,\n    AttentionMetadataBuilder,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass VMOBAAttentionBackend(AttentionBackend):\n\n    accept_output_buffer: bool = True\n\n    @staticmethod\n    def get_enum() -> AttentionBackendEnum:\n        return AttentionBackendEnum.VMOBA_ATTN\n\n    @staticmethod\n    def get_impl_cls() -> type[\"VMOBAAttentionImpl\"]:\n        return VMOBAAttentionImpl\n\n    @staticmethod\n    def get_metadata_cls() -> type[\"VideoMobaAttentionMetadata\"]:\n        return VideoMobaAttentionMetadata\n\n    @staticmethod\n    def get_builder_cls() -> type[\"VideoMobaAttentionMetadataBuilder\"]:\n        return VideoMobaAttentionMetadataBuilder\n\n\n@dataclass\nclass VideoMobaAttentionMetadata(AttentionMetadata):\n    current_timestep: int\n\n    temporal_chunk_size: int\n    temporal_topk: int\n    spatial_chunk_size: tuple[int, int]\n    spatial_topk: int\n    st_chunk_size: tuple[int, int, int]\n    st_topk: int\n\n    moba_select_mode: str\n    moba_threshold: float\n    moba_threshold_type: str\n    patch_resolution: list[int]\n\n    first_full_step: int = 12\n    first_full_layer: int = 0\n    # temporal_layer -> spatial_layer -> st_layer\n    temporal_layer: int = 1\n    spatial_layer: int = 1\n    st_layer: int = 1\n\n\ndef pad_input(hidden_states, indices, batch, seqlen):\n    \"\"\"\n    Arguments:\n        
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.\n        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.\n        batch: int, batch size for the padded sequence.\n        seqlen: int, maximum sequence length for the padded sequence.\n    Returns:\n        hidden_states: (batch, seqlen, ...)\n    \"\"\"\n    dim = hidden_states.shape[1:]\n    output = torch.zeros(\n        (batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype\n    )\n    output[indices] = hidden_states\n    return rearrange(output, \"(b s) ... -> b s ...\", b=batch)\n\n\nclass VideoMobaAttentionMetadataBuilder(AttentionMetadataBuilder):\n\n    def __init__(self):\n        pass\n\n    def prepare(self):\n        pass\n\n    def build(  # type: ignore\n        self,\n        current_timestep: int,\n        raw_latent_shape: tuple[int, int, int],\n        patch_size: tuple[int, int, int],\n        temporal_chunk_size: int,\n        temporal_topk: int,\n        spatial_chunk_size: tuple[int, int],\n        spatial_topk: int,\n        st_chunk_size: tuple[int, int, int],\n        st_topk: int,\n        moba_select_mode: str = \"threshold\",\n        moba_threshold: float = 0.25,\n        moba_threshold_type: str = \"query_head\",\n        device: torch.device | None = None,\n        first_full_layer: int = 0,\n        first_full_step: int = 12,\n        temporal_layer: int = 1,\n        spatial_layer: int = 1,\n        st_layer: int = 1,\n        **kwargs,\n    ) -> VideoMobaAttentionMetadata:\n        if device is None:\n            device = torch.device(\"cpu\")\n        assert (\n            raw_latent_shape[0] % patch_size[0] == 0\n            and raw_latent_shape[1] % patch_size[1] == 0\n            and raw_latent_shape[2] % patch_size[2] == 0\n        ), f\"raw_latent_shape {raw_latent_shape} should be divisible by patch_size {patch_size}\"\n        
patch_resolution = [\n            t // pt for t, pt in zip(raw_latent_shape, patch_size, strict=False)\n        ]\n\n        return VideoMobaAttentionMetadata(\n            current_timestep=current_timestep,\n            temporal_chunk_size=temporal_chunk_size,\n            temporal_topk=temporal_topk,\n            spatial_chunk_size=spatial_chunk_size,\n            spatial_topk=spatial_topk,\n            st_chunk_size=st_chunk_size,\n            st_topk=st_topk,\n            moba_select_mode=moba_select_mode,\n            moba_threshold=moba_threshold,\n            moba_threshold_type=moba_threshold_type,\n            patch_resolution=patch_resolution,\n            first_full_layer=first_full_layer,\n            first_full_step=first_full_step,\n            temporal_layer=temporal_layer,\n            spatial_layer=spatial_layer,\n            st_layer=st_layer,\n        )\n\n\nclass VMOBAAttentionImpl(AttentionImpl):\n\n    def __init__(\n        self,\n        num_heads,\n        head_size,\n        softmax_scale,\n        causal=False,\n        num_kv_heads=None,\n        prefix=\"\",\n        **extra_impl_args,\n    ) -> None:\n        self.prefix = prefix\n        self.layer_idx = self._get_layer_idx(prefix)\n\n        self.pad_input = pad_input\n\n    def _get_layer_idx(self, prefix: str) -> int | None:\n        match = re.search(r\"blocks\\.(\\d+)\", prefix)\n        if not match:\n            raise ValueError(f\"Invalid prefix: {prefix}\")\n        return int(match.group(1))\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_metadata: AttentionMetadata,\n    ) -> torch.Tensor:\n        \"\"\"\n        query: [B, L, H, D]\n        key:   [B, L, H, D]\n        value: [B, L, H, D]\n        attn_metadata: AttentionMetadata\n        \"\"\"\n        batch_size, sequence_length, num_heads, head_dim = query.shape\n\n        # select chunk type according to layer idx:\n        
loop_layer_num = (\n            attn_metadata.temporal_layer\n            + attn_metadata.spatial_layer\n            + attn_metadata.st_layer\n        )\n        moba_layer = self.layer_idx - attn_metadata.first_full_layer\n        if moba_layer % loop_layer_num < attn_metadata.temporal_layer:\n            moba_chunk_size = attn_metadata.temporal_chunk_size\n            moba_topk = attn_metadata.temporal_topk\n        elif (\n            moba_layer % loop_layer_num\n            < attn_metadata.temporal_layer + attn_metadata.spatial_layer\n        ):\n            moba_chunk_size = attn_metadata.spatial_chunk_size\n            moba_topk = attn_metadata.spatial_topk\n        else:\n            # moba_layer % loop_layer_num < loop_layer_num always holds, so the\n            # remaining layers of each cycle use spatio-temporal chunks\n            moba_chunk_size = attn_metadata.st_chunk_size\n            moba_topk = attn_metadata.st_topk\n\n        query, chunk_size = process_moba_input(\n            query, attn_metadata.patch_resolution, moba_chunk_size\n        )\n        key, chunk_size = process_moba_input(\n            key, attn_metadata.patch_resolution, moba_chunk_size\n        )\n        value, chunk_size = process_moba_input(\n            value, attn_metadata.patch_resolution, moba_chunk_size\n        )\n        max_seqlen = query.shape[1]\n        indices_q = torch.arange(\n            0, query.shape[0] * query.shape[1], device=query.device\n        )\n        cu_seqlens = torch.arange(\n            0,\n            query.shape[0] * query.shape[1] + 1,\n            query.shape[1],\n            dtype=torch.int32,\n            device=query.device,\n        )\n        query = rearrange(query, \"b s ... -> (b s) ...\")\n        key = rearrange(key, \"b s ... -> (b s) ...\")\n        value = rearrange(value, \"b s ... 
-> (b s) ...\")\n\n        # current_timestep=attn_metadata.current_timestep\n        hidden_states = moba_attn_varlen(\n            query,\n            key,\n            value,\n            cu_seqlens=cu_seqlens,\n            max_seqlen=max_seqlen,\n            moba_chunk_size=chunk_size,\n            moba_topk=moba_topk,\n            select_mode=attn_metadata.moba_select_mode,\n            simsum_threshold=attn_metadata.moba_threshold,\n            threshold_type=attn_metadata.moba_threshold_type,\n        )\n        hidden_states = self.pad_input(\n            hidden_states, indices_q, batch_size, sequence_length\n        )\n        hidden_states = process_moba_output(\n            hidden_states, attn_metadata.patch_resolution, moba_chunk_size\n        )\n\n        return hidden_states\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/layer.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom typing import Type\n\nimport torch\nimport torch.nn as nn\n\nfrom sglang.multimodal_gen.runtime.distributed.communication_op import (\n    sequence_model_parallel_all_gather,\n    sequence_model_parallel_all_to_all_4D,\n)\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_ring_parallel_world_size,\n    get_sequence_parallel_world_size,\n    get_sp_group,\n    get_sp_parallel_rank,\n    get_sp_world_size,\n    get_ulysses_parallel_world_size,\n)\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionImpl,\n)\nfrom sglang.multimodal_gen.runtime.layers.attention.selector import get_attn_backend\nfrom sglang.multimodal_gen.runtime.layers.usp import (\n    _usp_input_all_to_all,\n    _usp_output_all_to_all,\n    ring_attn,\n)\nfrom sglang.multimodal_gen.runtime.managers.forward_context import (\n    ForwardContext,\n    get_forward_context,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.utils import get_compute_dtype\n\n\nclass UlyssesAttention(nn.Module):\n    \"\"\"Ulysses-style SequenceParallelism attention layer.\"\"\"\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        num_kv_heads: int | None = None,\n        softmax_scale: float | None = None,\n        causal: bool = False,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        prefix: str = \"\",\n        **extra_impl_args,\n    ) -> None:\n        super().__init__()\n        if softmax_scale is None:\n            self.softmax_scale = head_size**-0.5\n        else:\n            self.softmax_scale = softmax_scale\n\n        if num_kv_heads is None:\n            num_kv_heads = num_heads\n\n        dtype = get_compute_dtype()\n        attn_backend = get_attn_backend(\n           
 head_size, dtype, supported_attention_backends=supported_attention_backends\n        )\n        impl_cls = attn_backend.get_impl_cls()\n\n        self.attn_impl = impl_cls(\n            num_heads=num_heads,\n            head_size=head_size,\n            causal=causal,\n            softmax_scale=self.softmax_scale,\n            num_kv_heads=num_kv_heads,\n            prefix=f\"{prefix}.impl\",\n            **extra_impl_args,\n        )\n        self.num_heads = num_heads\n        self.head_size = head_size\n        self.num_kv_heads = num_kv_heads\n        self.backend = attn_backend.get_enum()\n        self.dtype = dtype\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        replicated_q: torch.Tensor | None = None,\n        replicated_k: torch.Tensor | None = None,\n        replicated_v: torch.Tensor | None = None,\n    ) -> tuple[torch.Tensor, torch.Tensor | None]:\n        \"\"\"Forward pass for distributed attention.\n\n        Args:\n            q (torch.Tensor): Query tensor [batch_size, seq_len, num_heads, head_dim]\n            k (torch.Tensor): Key tensor [batch_size, seq_len, num_heads, head_dim]\n            v (torch.Tensor): Value tensor [batch_size, seq_len, num_heads, head_dim]\n            replicated_q (Optional[torch.Tensor]): Replicated query tensor, typically for text tokens\n            replicated_k (Optional[torch.Tensor]): Replicated key tensor\n            replicated_v (Optional[torch.Tensor]): Replicated value tensor\n\n        Returns:\n            Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing:\n                - o (torch.Tensor): Output tensor after attention for the main sequence\n                - replicated_o (Optional[torch.Tensor]): Output tensor for replicated tokens, if provided\n        \"\"\"\n        # Check input shapes\n        assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, \"Expected 4D tensors\"\n        batch_size, seq_len, 
num_heads, head_dim = q.shape\n        local_rank = get_sp_parallel_rank()\n        world_size = get_sp_world_size()\n\n        forward_context: ForwardContext = get_forward_context()\n        ctx_attn_metadata = forward_context.attn_metadata\n\n        # Stack QKV\n        qkv = torch.cat([q, k, v], dim=0)  # [3 * batch_size, seq_len, num_heads, head_dim]\n\n        # All-to-all: scatter heads, gather sequence across SP ranks\n        qkv = sequence_model_parallel_all_to_all_4D(qkv, scatter_dim=2, gather_dim=1)\n        # Apply backend-specific preprocess_qkv\n        qkv = self.attn_impl.preprocess_qkv(qkv, ctx_attn_metadata)\n\n        # Concatenate with replicated QKV if provided\n        if replicated_q is not None:\n            assert replicated_k is not None and replicated_v is not None\n            replicated_qkv = torch.cat(\n                [replicated_q, replicated_k, replicated_v], dim=0\n            )  # [3 * batch_size, replicated_seq_len, num_heads, head_dim]\n            heads_per_rank = num_heads // world_size\n            replicated_qkv = replicated_qkv[\n                :, :, local_rank * heads_per_rank : (local_rank + 1) * heads_per_rank\n            ]\n            qkv = torch.cat([qkv, replicated_qkv], dim=1)\n\n        q, k, v = qkv.chunk(3, dim=0)\n\n        output = self.attn_impl.forward(q, k, v, ctx_attn_metadata)\n\n        # Redistribute back if using sequence parallelism\n        replicated_output = None\n        if replicated_q is not None:\n            replicated_output = output[:, seq_len * world_size :]\n            output = output[:, : seq_len * world_size]\n            # TODO: make this asynchronous\n            replicated_output = sequence_model_parallel_all_gather(\n                replicated_output.contiguous(), dim=2\n            )\n        # Apply backend-specific postprocess_output\n        output = self.attn_impl.postprocess_output(output, ctx_attn_metadata)\n\n        output = sequence_model_parallel_all_to_all_4D(\n            output, scatter_dim=1, gather_dim=2\n        
)\n        return output, replicated_output\n\n\nclass UlyssesAttention_VSA(UlyssesAttention):\n    \"\"\"Distributed attention layer with VSA support.\"\"\"\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        replicated_q: torch.Tensor | None = None,\n        replicated_k: torch.Tensor | None = None,\n        replicated_v: torch.Tensor | None = None,\n        gate_compress: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass for distributed attention.\n\n        Args:\n            q (torch.Tensor): Query tensor [batch_size, seq_len, num_heads, head_dim]\n            k (torch.Tensor): Key tensor [batch_size, seq_len, num_heads, head_dim]\n            v (torch.Tensor): Value tensor [batch_size, seq_len, num_heads, head_dim]\n            gate_compress (torch.Tensor): Gate compress tensor [batch_size, seq_len, num_heads, head_dim]\n            replicated_q (Optional[torch.Tensor]): Must be None; replicated (e.g. text) tokens are not yet supported for VSA\n            replicated_k (Optional[torch.Tensor]): Must be None\n            replicated_v (Optional[torch.Tensor]): Must be None\n\n        Returns:\n            torch.Tensor: Output tensor after attention for the main sequence\n        \"\"\"\n        # Replicated (e.g. text) tokens are not yet supported for VSA\n        assert (\n            replicated_q is None and replicated_k is None and replicated_v is None\n        ), \"Replicated QKV is not supported for VSA now\"\n        # Check input shapes\n        assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, \"Expected 4D tensors\"\n\n        forward_context: ForwardContext = get_forward_context()\n        ctx_attn_metadata = forward_context.attn_metadata\n\n        # Stack Q, K, V and the gate along the batch dimension\n        
qkvg = torch.cat(\n            [q, k, v, gate_compress], dim=0\n        )  # [4 * batch_size, seq_len, num_heads, head_dim]\n\n        # All-to-all: scatter heads, gather sequence across SP ranks\n        qkvg = sequence_model_parallel_all_to_all_4D(qkvg, scatter_dim=2, gather_dim=1)\n        # Apply backend-specific preprocess_qkv\n        qkvg = self.attn_impl.preprocess_qkv(qkvg, ctx_attn_metadata)\n\n        q, k, v, gate_compress = qkvg.chunk(4, dim=0)\n        output = self.attn_impl.forward(\n            q, k, v, gate_compress=gate_compress, attn_metadata=ctx_attn_metadata\n        )  # type: ignore[call-arg]\n\n        # Apply backend-specific postprocess_output\n        output = self.attn_impl.postprocess_output(output, ctx_attn_metadata)\n\n        output = sequence_model_parallel_all_to_all_4D(\n            output, scatter_dim=1, gather_dim=2\n        )\n\n        return output\n\n\nclass LocalAttention(nn.Module):\n    \"\"\"Single-device attention layer (no sequence-parallel communication).\"\"\"\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        num_kv_heads: int | None = None,\n        softmax_scale: float | None = None,\n        causal: bool = False,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        **extra_impl_args,\n    ) -> None:\n        super().__init__()\n        if softmax_scale is None:\n            self.softmax_scale = head_size**-0.5\n        else:\n            self.softmax_scale = softmax_scale\n        if num_kv_heads is None:\n            num_kv_heads = num_heads\n\n        dtype = get_compute_dtype()\n        attn_backend = get_attn_backend(\n            head_size, dtype, supported_attention_backends=supported_attention_backends\n        )\n        impl_cls = attn_backend.get_impl_cls()\n        self.attn_impl = impl_cls(\n            num_heads=num_heads,\n            head_size=head_size,\n            softmax_scale=self.softmax_scale,\n            num_kv_heads=num_kv_heads,\n            causal=causal,\n            **extra_impl_args,\n        )\n        self.num_heads = 
num_heads\n        self.head_size = head_size\n        self.num_kv_heads = num_kv_heads\n        self.backend = attn_backend.get_enum()\n        self.dtype = dtype\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"\n        Apply local attention between query, key and value tensors.\n\n        Args:\n            q (torch.Tensor): Query tensor of shape [batch_size, seq_len, num_heads, head_dim]\n            k (torch.Tensor): Key tensor of shape [batch_size, seq_len, num_heads, head_dim]\n            v (torch.Tensor): Value tensor of shape [batch_size, seq_len, num_heads, head_dim]\n\n        Returns:\n            torch.Tensor: Output tensor after local attention\n        \"\"\"\n        # Check input shapes\n        assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, \"Expected 4D tensors\"\n\n        forward_context: ForwardContext = get_forward_context()\n        ctx_attn_metadata = forward_context.attn_metadata\n\n        output = self.attn_impl.forward(q, k, v, attn_metadata=ctx_attn_metadata)\n        return output\n\n\nclass USPAttention(nn.Module):\n    \"\"\"\n    Ulysses Sequence Parallelism with Ring Attention.\n\n    This class implements the USP algorithm, which is a combination of\n    Ulysses-style all-to-all communication for sequence-head dimension sharding\n    and Ring Attention for fine-grained sequence parallelism within subgroups.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        num_kv_heads: int | None = None,\n        softmax_scale: float | None = None,\n        causal: bool = False,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        prefix: str = \"\",\n        dropout_rate: float = 0.0,\n        skip_sequence_parallel: bool = False,\n        **extra_impl_args,\n    ) -> None:\n        \"\"\"\n        Args:\n            skip_sequence_parallel:\n 
             when KV is replicated across all SP ranks (e.g. cross-attention to\n              text/image encoder outputs), the full USP pipeline is redundant:\n              each rank's local Q shard can attend directly to the locally-held\n              full KV without any collective communication.\n        \"\"\"\n        super().__init__()\n        if softmax_scale is None:\n            self.softmax_scale = head_size**-0.5\n        else:\n            self.softmax_scale = softmax_scale\n\n        if num_kv_heads is None:\n            num_kv_heads = num_heads\n\n        dtype = get_compute_dtype()\n        attn_backend = get_attn_backend(\n            head_size, dtype, supported_attention_backends=supported_attention_backends\n        )\n        impl_cls: Type[\"AttentionImpl\"] = attn_backend.get_impl_cls()\n        self.attn_impl = impl_cls(\n            num_heads=num_heads,\n            head_size=head_size,\n            causal=causal,\n            softmax_scale=self.softmax_scale,\n            num_kv_heads=num_kv_heads,\n            prefix=f\"{prefix}.impl\",\n            **extra_impl_args,\n        )\n        self.num_heads = num_heads\n        self.head_size = head_size\n        self.num_kv_heads = num_kv_heads\n        self.backend = attn_backend.get_enum()\n        self.dtype = dtype\n        self.causal = causal\n        self.dropout_p = dropout_rate\n\n        self.skip_sequence_parallel = skip_sequence_parallel\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        num_replicated_prefix: int = 0,\n    ) -> torch.Tensor:\n        \"\"\"\n        Forward pass for USPAttention.\n\n            q, k, v: [B, S_local, H, D]\n            num_replicated_prefix: number of leading tokens in q/k/v that are\n                replicated (identical) across all SP ranks, e.g. text tokens\n                in FLUX joint attention.  
These tokens are excluded from the\n                Ulysses all-to-all so they appear exactly once in the gathered\n                sequence, preserving correct attention weights.\n\n        Note: Replicated tensors are not supported in this implementation.\n        When skip_sequence_parallel=True (set at construction time), all SP\n        communication is bypassed — use this for cross-attention where KV\n        content is replicated across ranks (distinct from replicated_k/v args).\n        \"\"\"\n        forward_context: ForwardContext = get_forward_context()\n        ctx_attn_metadata = forward_context.attn_metadata\n        if self.skip_sequence_parallel or get_sequence_parallel_world_size() == 1:\n            # No sequence parallelism, just run local attention.\n            out = self.attn_impl.forward(q, k, v, ctx_attn_metadata)\n            return out\n\n        sp_size = get_ulysses_parallel_world_size()\n        if sp_size > 1 and num_replicated_prefix > 0:\n            return self._forward_with_replicated_prefix(\n                q, k, v, ctx_attn_metadata, num_replicated_prefix\n            )\n\n        # Ulysses-style All-to-All for sequence/head sharding\n        if sp_size > 1:\n            # -> [B, S, H_local, D]\n            q = _usp_input_all_to_all(q, head_dim=2)\n            k = _usp_input_all_to_all(k, head_dim=2)\n            v = _usp_input_all_to_all(v, head_dim=2)\n\n        # Ring Attention within subgroups or local attention\n        if get_ring_parallel_world_size() > 1:\n            out = ring_attn(\n                q,\n                k,\n                v,\n                attn_impl=self.attn_impl,\n                is_causal=self.causal,\n                dropout_p=self.dropout_p,\n            )\n        else:\n            # -> [B, S, H_local, D]\n            out = self.attn_impl.forward(q, k, v, ctx_attn_metadata)\n\n        # Ulysses-style All-to-All to restore original sharding\n        if sp_size > 1:\n            # -> [B, 
S_local, H, D]\n            out = _usp_output_all_to_all(out, head_dim=2)\n\n        return out\n\n    def _forward_with_replicated_prefix(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        ctx_attn_metadata,\n        num_rep: int,\n    ) -> torch.Tensor:\n        \"\"\"Ulysses attention where the first *num_rep* tokens are replicated\n        across SP ranks (e.g. text tokens) and should NOT be duplicated by the\n        all-to-all.\n\n        Strategy:\n        1. Split q/k/v into replicated prefix and SP-sharded suffix.\n        2. All-to-all only the sharded suffix (gathers sequence, shards heads).\n        3. Locally slice the replicated prefix to the same head shard.\n        4. Concatenate [prefix_h_local, gathered_suffix] and run attention.\n        5. Split output, all-to-all back the suffix, all-gather prefix heads.\n        \"\"\"\n        sp_size = get_ulysses_parallel_world_size()\n        sp_rank = get_sp_parallel_rank()\n\n        q_rep, q_shard = q[:, :num_rep], q[:, num_rep:]\n        k_rep, k_shard = k[:, :num_rep], k[:, num_rep:]\n        v_rep, v_shard = v[:, :num_rep], v[:, num_rep:]\n\n        q_shard = _usp_input_all_to_all(q_shard, head_dim=2)\n        k_shard = _usp_input_all_to_all(k_shard, head_dim=2)\n        v_shard = _usp_input_all_to_all(v_shard, head_dim=2)\n\n        h_local = q_shard.shape[2]\n        h_start = sp_rank * h_local\n        h_end = h_start + h_local\n        q_rep = q_rep[:, :, h_start:h_end, :].contiguous()\n        k_rep = k_rep[:, :, h_start:h_end, :].contiguous()\n        v_rep = v_rep[:, :, h_start:h_end, :].contiguous()\n\n        q = torch.cat([q_rep, q_shard], dim=1)\n        k = torch.cat([k_rep, k_shard], dim=1)\n        v = torch.cat([v_rep, v_shard], dim=1)\n\n        out = self.attn_impl.forward(q, k, v, ctx_attn_metadata)\n\n        out_rep = out[:, :num_rep]\n        out_shard = out[:, num_rep:]\n\n        out_shard = _usp_output_all_to_all(out_shard, 
head_dim=2)\n\n        gathered = [torch.empty_like(out_rep) for _ in range(sp_size)]\n        torch.distributed.all_gather(\n            gathered,\n            out_rep.contiguous(),\n            group=get_sp_group().ulysses_group,\n        )\n        out_rep = torch.cat(gathered, dim=2)\n\n        return torch.cat([out_rep, out_shard], dim=1)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/selector.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/attention/selector.py\n\nimport os\nfrom collections.abc import Generator\nfrom contextlib import contextmanager\nfrom functools import cache\nfrom typing import cast\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionBackend,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.server_args import get_global_server_args\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname\n\nlogger = init_logger(__name__)\n\n\ndef backend_name_to_enum(backend_name: str) -> AttentionBackendEnum | None:\n    \"\"\"\n    Convert a string backend name to an AttentionBackendEnum value.\n\n    Returns:\n    * AttentionBackendEnum: enum value if backend_name is a valid in-tree type\n    * None: if backend_name is not a valid in-tree type (e.g. when an\n            out-of-tree platform is loaded).\n    \"\"\"\n    assert backend_name is not None\n    return (\n        AttentionBackendEnum[backend_name]\n        if backend_name in AttentionBackendEnum.__members__\n        else None\n    )\n\n\ndef get_env_variable_attn_backend() -> AttentionBackendEnum | None:\n    \"\"\"\n    Get the backend override specified by the sglang-diffusion attention\n    backend environment variable, if one is specified.\n\n    Returns:\n\n    * AttentionBackendEnum value if an override is specified\n    * None otherwise\n    \"\"\"\n    backend_name = os.environ.get(STR_BACKEND_ENV_VAR)\n    return None if backend_name is None else backend_name_to_enum(backend_name)\n\n\n# Global state allows a particular choice of backend\n# to be forced, overriding the logic which auto-selects\n# a backend based on system & workload 
configuration\n# (default behavior if this variable is None)\n#\n# THIS SELECTION TAKES PRECEDENCE OVER THE\n# ATTENTION BACKEND ENVIRONMENT VARIABLE (STR_BACKEND_ENV_VAR)\nforced_attn_backend: AttentionBackendEnum | None = None\n\n\ndef global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:\n    \"\"\"\n    Force all attention operations to use a specified backend.\n\n    Passing `None` for the argument re-enables automatic\n    backend selection.\n\n    Arguments:\n\n    * attn_backend: backend selection (None to revert to auto)\n    \"\"\"\n    global forced_attn_backend\n    forced_attn_backend = attn_backend\n\n\ndef get_global_forced_attn_backend() -> AttentionBackendEnum | None:\n    \"\"\"\n    Get the currently-forced choice of attention backend,\n    or None if auto-selection is currently enabled.\n    \"\"\"\n    return forced_attn_backend\n\n\ndef get_attn_backend(\n    head_size: int,\n    dtype: torch.dtype,\n    supported_attention_backends: set[AttentionBackendEnum] | None = None,\n) -> type[AttentionBackend]:\n    if supported_attention_backends is None:\n        be_tuple = tuple()\n    else:\n        # Sort the backend names to ensure a consistent cache key\n        be_tuple = tuple(\n            sorted(list(supported_attention_backends), key=lambda b: b.name)\n        )\n    return _cached_get_attn_backend(head_size, dtype, be_tuple)\n\n\n@cache\ndef _cached_get_attn_backend(\n    head_size: int,\n    dtype: torch.dtype,\n    supported_attention_backends: tuple[AttentionBackendEnum, ...],\n) -> type[AttentionBackend]:\n    # Check whether a particular choice of backend was\n    # previously forced via global_force_attn_backend() or --attention-backend CLI arg.\n    from sglang.multimodal_gen.runtime.platforms import current_platform\n\n    supported_attention_backends = set(supported_attention_backends)\n    selected_backend = None\n    backend_by_global_setting: AttentionBackendEnum | None = (\n        get_global_forced_attn_backend()\n    
)\n    if backend_by_global_setting is not None:\n        selected_backend = backend_by_global_setting\n    else:\n        # Check the server arguments for a backend override\n        server_args = get_global_server_args()\n        if server_args.attention_backend is not None:\n            try:\n                selected_backend = AttentionBackendEnum[\n                    server_args.attention_backend.upper()\n                ]\n\n            except KeyError:\n                raise ValueError(\n                    f\"Invalid attention backend '{server_args.attention_backend}' specified via command line. \"\n                    f\"Available options are: {[e.name.lower() for e in AttentionBackendEnum]}\"\n                )\n\n    # Validate the selected backend against the backends this layer supports\n    if len(supported_attention_backends) == 0:\n        # all attention backends are allowed\n        pass\n    elif selected_backend is None:\n        logger.debug(\"Attention backend not specified\")\n    elif selected_backend not in supported_attention_backends:\n        supported_attention_backends_str = [\n            str(supported_attention_backend)\n            for supported_attention_backend in supported_attention_backends\n        ]\n        logger.debug(\n            f\"Selected attention backend: '{selected_backend}' not in supported attention backends: {supported_attention_backends_str}\"\n        )\n        selected_backend = None\n\n    attention_cls = current_platform.get_attn_backend_cls_str(\n        selected_backend, head_size, dtype\n    )\n    if not attention_cls:\n        raise ValueError(\n            f\"Invalid attention backend for {current_platform.device_name}\"\n        )\n    return cast(type[AttentionBackend], resolve_obj_by_qualname(attention_cls))\n\n\n@contextmanager\ndef global_force_attn_backend_context_manager(\n    attn_backend: AttentionBackendEnum,\n) -> Generator[None, None, None]:\n    \"\"\"\n    Globally force a sglang-diffusion attention backend override 
within a\n    context manager, reverting the global attention backend\n    override to its prior state upon exiting the context\n    manager.\n\n    Arguments:\n    * attn_backend: attention backend to force\n\n    Returns:\n\n    * Generator\n    \"\"\"\n\n    # Save the current state of the global backend override (if any)\n    original_value = get_global_forced_attn_backend()\n\n    # Globally force the new backend override\n    global_force_attn_backend(attn_backend)\n\n    # Yield control back to the enclosed code block\n    try:\n        yield\n    finally:\n        # Revert the original global backend override, if any\n        global_force_attn_backend(original_value)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/attention/turbo_layer.py",
    "content": "# copy and modify from https://github.com/thu-ml/TurboDiffusion/blob/main/turbodiffusion/rcm/utils/a2a_cp.py and https://github.com/thu-ml/TurboDiffusion/blob/main/turbodiffusion/SLA/core.py\n\nfrom typing import Any, Callable, List, Tuple, Type, Union\n\nimport torch\nimport torch.distributed as dist\nfrom einops import rearrange\nfrom torch import Tensor\nfrom torch.distributed import ProcessGroup\nfrom torch.nn import Module\n\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n    AttentionImpl,\n)\nfrom sglang.multimodal_gen.runtime.layers.attention.backends.sparse_linear_attn import (\n    SageSparseLinearAttentionBackend,\n    SparseLinearAttentionBackend,\n)\nfrom sglang.multimodal_gen.runtime.layers.attention.selector import get_attn_backend\nfrom sglang.multimodal_gen.runtime.managers.forward_context import (\n    ForwardContext,\n    get_forward_context,\n)\nfrom sglang.multimodal_gen.runtime.platforms.interface import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import get_compute_dtype\n\nlogger = init_logger(__name__)\n\n\ndef post_all2all(local_seq_2_local_head, seq_world_size):\n    def post_func(input):\n        # b, s, n, h\n        if local_seq_2_local_head:\n            output = rearrange(input, \"w bs seq h d -> bs (w seq) h d\")\n        else:\n            output = rearrange(input, \"w bs s h d -> bs s (w h) d\", w=seq_world_size)\n\n        return output\n\n    return post_func\n\n\ndef single_all_to_all(input, local_seq_2_local_head, group, async_op=False):\n    seq_world_size = dist.get_world_size(group)\n\n    # b, s, n, h\n    if local_seq_2_local_head:\n        bs, local_seq_len, num_total_head, head_dim = input.shape\n        assert (\n            num_total_head % seq_world_size == 0\n        ), f\"Number of heads ({num_total_head}) must be divisible by the sequence parallel size 
({seq_world_size})!\"\n        input_t = rearrange(\n            input,\n            \"bs seq_len (w h) d -> w bs seq_len h d\",\n            w=seq_world_size,\n            h=num_total_head // seq_world_size,\n        ).contiguous()\n        post_all2all_fun = post_all2all(local_seq_2_local_head, seq_world_size)\n    else:\n        bs, global_seq_len, num_local_head, head_dim = input.shape\n        input_t = rearrange(\n            input,\n            \"bs (w s) h d -> w bs s h d\",\n            w=seq_world_size,\n            s=global_seq_len // seq_world_size,\n        ).contiguous()\n        post_all2all_fun = post_all2all(local_seq_2_local_head, seq_world_size)\n\n    output = torch.empty_like(input_t)\n    dist.all_to_all_single(output, input_t, group=group, async_op=async_op)\n\n    res = post_all2all_fun(output)\n    return res\n\n\ndef async_a2a_communicate(\n    a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],\n    cp_size: int,\n    cp_group: ProcessGroup,\n    cp_stream: torch.get_device_module().Stream,\n    local_seq_2_local_head: bool,\n) -> Union[torch.Tensor, List[torch.Tensor]]:\n    \"\"\"\n    A2A communication for context parallelism. 
Best suited for communicating q, k, and v.\n    Modified from NVIDIA Transformer Engine.\n    \"\"\"\n    a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs\n    a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs)\n    a2a_post_fns = [None] * len(a2a_inputs)\n    if local_seq_2_local_head:\n        for i in range(len(a2a_inputs) + 2):\n            if 0 < i < len(a2a_inputs) + 1:\n                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])\n                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(\n                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True\n                )\n                a2a_post_fns[i - 1] = post_all2all(local_seq_2_local_head, cp_size)\n            if i > 1:\n                with torch.get_device_module().stream(cp_stream):\n                    a2a_reqs[i - 2].wait()\n                    a2a_outputs[i - 2] = a2a_post_fns[i - 2](a2a_outputs[i - 2])\n            if i < len(a2a_inputs):\n                a2a_inputs[i] = rearrange(\n                    a2a_inputs[i], \"bs seq_len (w h) d -> w bs seq_len h d\", w=cp_size\n                ).contiguous()\n    else:\n        for i in range(len(a2a_inputs) + 2):\n            if 0 < i < len(a2a_inputs) + 1:\n                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])\n                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(\n                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True\n                )\n                a2a_post_fns[i - 1] = post_all2all(local_seq_2_local_head, cp_size)\n            if i < len(a2a_inputs):\n                a2a_inputs[i] = rearrange(\n                    a2a_inputs[i], \"bs (w s) h d -> w bs s h d\", w=cp_size\n                ).contiguous()\n            if i > 1:\n                with torch.get_device_module().stream(cp_stream):\n                    a2a_reqs[i - 2].wait()\n                    a2a_outputs[i - 2] = 
a2a_post_fns[i - 2](a2a_outputs[i - 2])\n    torch.get_device_module().current_stream().wait_stream(cp_stream)\n    return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs\n\n\nclass _SeqAllToAll(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx: Any, group: dist.ProcessGroup, input: Tensor, local_seq_2_local_head: bool\n    ) -> Tensor:\n        ctx.group = group\n        res = single_all_to_all(input, local_seq_2_local_head, group, False)\n        ctx.local_seq_2_local_head = local_seq_2_local_head\n        return res\n\n    @staticmethod\n    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None]:\n        return (\n            None,\n            _SeqAllToAll.apply(ctx.group, *grad_output, not ctx.local_seq_2_local_head),\n            None,\n        )\n\n\nclass _SeqAllToAllQKV(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx: Any,\n        group: dist.ProcessGroup,\n        q: Tensor,\n        k: Tensor,\n        v: Tensor,\n        cp_size: int,\n        cp_stream: torch.get_device_module().Stream,\n        local_seq_2_local_head: bool,\n    ) -> Tuple[Tensor, Tensor, Tensor]:\n        ctx.group = group\n        ctx.cp_size = cp_size\n        ctx.cp_stream = cp_stream\n        ctx.local_seq_2_local_head = local_seq_2_local_head\n        q, k, v = async_a2a_communicate(\n            [q, k, v], cp_size, group, cp_stream, local_seq_2_local_head\n        )\n        return q, k, v\n\n    @staticmethod\n    def backward(\n        ctx: Any, *grad_output: Tensor\n    ) -> Tuple[None, Tensor, Tensor, Tensor, None, None, None]:\n        q_grad, k_grad, v_grad = _SeqAllToAllQKV.apply(\n            ctx.group,\n            *grad_output,\n            ctx.cp_size,\n            ctx.cp_stream,\n            not ctx.local_seq_2_local_head,\n        )\n        return (None, q_grad, k_grad, v_grad, None, None, None)\n\n\nclass DistributedAttention(torch.nn.Module):\n    \"\"\"Initialization.\n\n    
Arguments:\n        local_attention (Module | Callable): local attention op, called as\n            local_attention(query, key, value, attn_metadata)\n\n    The sequence-parallel process group and stream are attached later via\n    set_context_parallel_group().\n    \"\"\"\n\n    def __init__(self, local_attention: Union[Module, Callable]) -> None:\n        super(DistributedAttention, self).__init__()\n        self.local_attn = local_attention\n        self.pg = None\n        self.stream = None\n\n    def forward(\n        self, query: Tensor, key: Tensor, value: Tensor, ctx_attn_metadata\n    ) -> Tensor:\n        \"\"\"forward\n\n        Arguments:\n            query (Tensor): query input to the layer\n            key (Tensor): key input to the layer\n            value (Tensor): value input to the layer\n            ctx_attn_metadata: attention metadata passed through to local_attention\n\n        Returns:\n            * output (Tensor): context output\n        \"\"\"\n        if self.pg is None:\n            return self.local_attn(query, key, value, ctx_attn_metadata)\n        pg_size = dist.get_world_size(self.pg)\n        if pg_size < 2:\n            return self.local_attn(query, key, value, ctx_attn_metadata)\n\n        query_layer, key_layer, value_layer = _SeqAllToAllQKV.apply(\n            self.pg, query, key, value, pg_size, self.stream, True\n        )\n        context_layer = self.local_attn(\n            query_layer, key_layer, value_layer, ctx_attn_metadata\n        )\n\n        output = _SeqAllToAll.apply(self.pg, context_layer, False)\n        return output\n\n    def set_context_parallel_group(self, group, stream):\n        self.pg = group\n        self.stream = stream\n\n\nclass MinimalA2AAttnOp(DistributedAttention):\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        attention_type: str,\n        topk: float,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        prefix: str = \"\",\n    ):\n        dtype = get_compute_dtype()\n        attn_backend = get_attn_backend(\n            head_size, dtype, supported_attention_backends=supported_attention_backends\n      
  )\n        # Maintained for compatibility purposes; can be removed when CI allows setting attention_backend or when TurboWan supports FA.\n        if attn_backend not in (\n            SparseLinearAttentionBackend,\n            SageSparseLinearAttentionBackend,\n        ):\n            logger.warning_once(\n                \"TurboWan currently supports only `sla_attn` or `sage_sla_attn`; the backend has been set automatically based on attention_type. Please set --attention-backend to `sla_attn` or `sage_sla_attn`.\"\n            )\n            if attention_type == \"sagesla\":\n                attn_backend = SageSparseLinearAttentionBackend\n            else:\n                attn_backend = SparseLinearAttentionBackend\n        impl_cls: Type[\"AttentionImpl\"] = attn_backend.get_impl_cls()\n        local_attn = impl_cls(\n            num_heads=num_heads,\n            head_size=head_size,\n            topk_ratio=topk,\n            prefix=f\"{prefix}.impl\",\n        )\n        super(MinimalA2AAttnOp, self).__init__(local_attn)\n\n    def set_context_parallel_group(self, process_group, ranks, stream):\n        del ranks\n        super().set_context_parallel_group(process_group, stream)\n\n    def forward(\n        self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs\n    ) -> Tensor:\n        forward_context: ForwardContext = get_forward_context()\n        ctx_attn_metadata = forward_context.attn_metadata\n        results = super().forward(query, key, value, ctx_attn_metadata)\n        return rearrange(results, \"b ... h l -> b ... (h l)\")\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/custom_op.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/custom_op.py\n\nfrom collections.abc import Callable\nfrom typing import Any\n\nimport torch.nn as nn\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n_is_cuda = current_platform.is_cuda()\n\n\nclass CustomOp(nn.Module):\n    \"\"\"\n    Base class for custom ops.\n    Dispatches the forward method to the appropriate backend.\n    \"\"\"\n\n    def __init__(self) -> None:\n        super().__init__()\n        self._forward_method = self.dispatch_forward()\n\n    def forward(self, *args, **kwargs) -> Any:\n        return self._forward_method(*args, **kwargs)\n\n    def forward_native(self, *args, **kwargs) -> Any:\n        \"\"\"PyTorch-native implementation of the forward method.\n        This method is optional. If implemented, it can be used with compilers\n        such as torch.compile or PyTorch XLA. 
Also, it can be used for testing\n        purposes.\n        \"\"\"\n        raise NotImplementedError\n\n    def forward_cuda(self, *args, **kwargs) -> Any:\n        raise NotImplementedError\n\n    def forward_hip(self, *args, **kwargs) -> Any:\n        # ROCm kernels follow the CUDA path by default.\n        return self.forward_cuda(*args, **kwargs)\n\n    def forward_cpu(self, *args, **kwargs) -> Any:\n        # By default, we assume that CPU ops are compatible with CUDA ops.\n        return self.forward_cuda(*args, **kwargs)\n\n    def forward_tpu(self, *args, **kwargs) -> Any:\n        # By default, we assume that TPU ops are compatible with the\n        # PyTorch-native implementation.\n        # NOTE(woosuk): This is a placeholder for future extensions.\n        return self.forward_native(*args, **kwargs)\n\n    def forward_musa(self, *args, **kwargs) -> Any:\n        # MUSA kernels follow the CUDA path by default.\n        return self.forward_cuda(*args, **kwargs)\n\n    def forward_oot(self, *args, **kwargs) -> Any:\n        # By default, we assume that OOT ops are compatible with the\n        # PyTorch-native implementation.\n        return self.forward_native(*args, **kwargs)\n\n    def forward_npu(self, *args, **kwargs) -> Any:\n        # By default, we assume that NPU ops are compatible with the\n        # PyTorch-native implementation.\n        return self.forward_native(*args, **kwargs)\n\n    def forward_xpu(self, *args, **kwargs) -> Any:\n        # By default, we assume that XPU ops are compatible with the\n        # PyTorch-native implementation. Without this default,\n        # dispatch_forward() would raise AttributeError on XPU.\n        return self.forward_native(*args, **kwargs)\n\n    def dispatch_forward(self) -> Callable:\n        if _is_cuda:\n            return self.forward_cuda\n        elif current_platform.is_hip():\n            return self.forward_hip\n        elif current_platform.is_npu():\n            return self.forward_npu\n        elif current_platform.is_xpu():\n            return self.forward_xpu\n        elif current_platform.is_musa():\n            return self.forward_musa\n        else:\n            return self.forward_native\n\n    @classmethod\n    def enabled(cls) -> bool:\n        # since we are not using Inductor, we 
always return True\n        return True\n\n    @staticmethod\n    def default_on() -> bool:\n        \"\"\"\n        On by default if level < CompilationLevel.PIECEWISE\n        Specifying 'all' or 'none' in custom_op takes precedence.\n        \"\"\"\n        raise NotImplementedError\n\n    # Dictionary of all custom ops (classes, indexed by registered name).\n    # To check if an op with a name is enabled, call .enabled() on the class.\n    # Examples:\n    # - MyOp.enabled()\n    # - op_registry[\"my_op\"].enabled()\n    op_registry: dict[str, type[\"CustomOp\"]] = {}\n\n    # Decorator to register custom ops.\n    @classmethod\n    def register(cls, name: str) -> Callable:\n\n        def decorator(op_cls):\n            assert name not in cls.op_registry, f\"Duplicate op name: {name}\"\n            op_cls.name = name\n            cls.op_registry[name] = op_cls\n            return op_cls\n\n        return decorator\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/elementwise.py",
    "content": "import torch\n\nfrom sglang.jit_kernel.diffusion.triton.scale_shift import fuse_scale_shift_kernel\nfrom sglang.multimodal_gen.runtime.layers.custom_op import CustomOp\n\n\nclass MulAdd(CustomOp):\n    \"\"\"\n    Fuse elementwise mul and add\n    Input: a, b, c, OptionalInt[k]\n    Output: a * (k + b) + c\n    \"\"\"\n\n    def __init__(self, prefix: str = \"\"):\n        super().__init__()\n\n    def forward_native(\n        self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, k: int = 0\n    ) -> torch.Tensor:\n        # a.shape: [batch_size, seq_len, inner_dim]\n        if b.dim() == 4:\n            # b.shape: [batch_size, num_frames, 1, inner_dim]\n            num_frames = b.shape[1]\n            frame_seqlen = a.shape[1] // num_frames\n            return c + (\n                a.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * (k + b)\n            ).flatten(1, 2)\n        else:\n            # b.shape: [batch_size, 1, inner_dim]\n            return c + a * (k + b)\n\n    def forward_cuda(\n        self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, k: int = 0\n    ):\n        return fuse_scale_shift_kernel(a, b, c, scale_constant=k)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/layernorm.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/layernorm.py\n\"\"\"Custom normalization layers.\"\"\"\n\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\n\n_is_cuda = current_platform.is_cuda()\n_is_npu = current_platform.is_npu()\n_is_musa = current_platform.is_musa()\nif _is_cuda:\n    from sgl_kernel import fused_add_rmsnorm, rmsnorm\n\nif _is_npu:\n    import torch_npu\n\nif _is_musa:\n    from sgl_kernel import fused_add_rmsnorm\n\nfrom sglang.jit_kernel.diffusion.triton.norm import norm_infer, rms_norm_fn\nfrom sglang.jit_kernel.diffusion.triton.rmsnorm_onepass import triton_one_pass_rms_norm\nfrom sglang.jit_kernel.diffusion.triton.scale_shift import fuse_scale_shift_kernel\nfrom sglang.jit_kernel.norm import can_use_fused_inplace_qknorm, fused_inplace_qknorm\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_tensor_model_parallel_rank,\n    get_tensor_model_parallel_world_size,\n    get_tp_group,\n)\nfrom sglang.multimodal_gen.runtime.layers.custom_op import CustomOp\nfrom sglang.multimodal_gen.runtime.utils.common import get_bool_env_var\n\n\n# Copied and adapted from sglang\n@CustomOp.register(\"rms_norm\")\nclass RMSNorm(CustomOp):\n    \"\"\"Root mean square normalization.\n\n    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.\n    Refer to https://arxiv.org/abs/1910.07467\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        eps: float = 1e-6,\n        dtype: torch.dtype = torch.float32,\n        var_hidden_size: Optional[int] = None,\n    ) -> None:\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        
self.variance_epsilon = eps\n        self.hidden_size = hidden_size\n        self.variance_size_override = (\n            None if var_hidden_size == hidden_size else var_hidden_size\n        )\n        if get_bool_env_var(\"SGLANG_ENABLE_DETERMINISTIC_INFERENCE\"):\n            self._forward_method = self.forward_native\n\n    def forward_triton(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):\n        return rms_norm_fn(\n            x, self.weight, bias=None, residual=residual, eps=self.variance_epsilon\n        )\n\n    def forward_cuda(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        shape = x.shape\n        device = x.device\n        x = x.reshape(-1, shape[-1])\n        if residual is not None:\n            residual_shape = residual.shape\n            residual = residual.view(-1, shape[-1])\n\n        if x.dtype == torch.float:\n            # fp32\n            out = self.forward_triton(x, residual)\n            if residual is not None:\n                return out[0].view(shape), out[1].view(residual_shape)\n            out = out.view(shape)\n            return out\n        elif self.variance_size_override is not None:\n            return self.forward_native(x, residual)\n        elif residual is not None:\n            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)\n            return x.view(shape), residual.view(residual_shape)\n        else:\n            if x.shape[-1] <= 128:\n                out = triton_one_pass_rms_norm(\n                    x, self.weight.data, self.variance_epsilon\n                )\n            else:\n                out = rmsnorm(x, self.weight.data, self.variance_epsilon)\n        out = out.view(shape)\n\n        return out\n\n    def forward_native(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, 
Tuple[torch.Tensor, torch.Tensor]]:\n        if not x.is_contiguous():\n            x = x.contiguous()\n        orig_dtype = x.dtype\n        x = x.to(torch.float32)\n        if residual is not None:\n            x = x + residual.to(torch.float32)\n            residual = x.to(orig_dtype)\n\n        hidden_size = x.shape[-1]\n        if hidden_size != self.hidden_size:\n            raise ValueError(\n                \"Expected hidden_size to be \"\n                f\"{self.hidden_size}, but found: {hidden_size}\"\n            )\n\n        if self.variance_size_override is None:\n            x_var = x\n        else:\n            if hidden_size < self.variance_size_override:\n                raise ValueError(\n                    \"Expected hidden_size to be at least \"\n                    f\"{self.variance_size_override}, but found: {hidden_size}\"\n                )\n\n            x_var = x[..., : self.variance_size_override]\n\n        variance = x_var.pow(2).mean(dim=-1, keepdim=True)\n        x = x * torch.rsqrt(variance + self.variance_epsilon)\n        x = (x * self.weight).to(orig_dtype)\n        if residual is None:\n            return x\n        else:\n            return x, residual\n\n    def forward_cpu(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        return self.forward_native(x, residual)\n\n    def forward_npu(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        if residual is not None:\n            out, _, residual_out = torch_npu.npu_add_rms_norm(\n                residual, x, self.weight.data, self.variance_epsilon\n            )\n            return out, residual_out\n        return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]\n\n    def forward_hip(\n        self,\n        x: torch.Tensor,\n      
  residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        # ROCm builds of sgl-kernel do not expose rmsnorm custom ops yet.\n        return self.forward_native(x, residual)\n\n    def _get_weight(self, dtype: torch.dtype) -> torch.Tensor:\n        \"\"\"Return weight matched to *dtype*.\n\n        MUSA kernels require input and weight to share the same dtype,\n        unlike CUDA kernels which may handle mixed dtypes internally.\n        \"\"\"\n        weight = self.weight.data\n        if weight.dtype != dtype:\n            weight = weight.to(dtype=dtype)\n        return weight\n\n    def forward_musa(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        shape = x.shape\n        x = x.reshape(-1, shape[-1])\n        if residual is not None:\n            residual_shape = residual.shape\n            residual = residual.view(-1, shape[-1])\n\n        if self.variance_size_override is not None:\n            return self.forward_native(x, residual)\n        elif residual is not None:\n            # fused_add_rmsnorm requires contiguous inputs.\n            if not x.is_contiguous():\n                x = x.contiguous()\n            if not residual.is_contiguous():\n                residual = residual.contiguous()\n            weight = self._get_weight(x.dtype)\n            fused_add_rmsnorm(x, residual, weight, self.variance_epsilon)\n            return x.view(shape), residual.view(residual_shape)\n        else:\n            weight = self._get_weight(x.dtype)\n            out = F.rms_norm(x, (self.hidden_size,), weight, self.variance_epsilon)\n        out = out.view(shape)\n        return out\n\n    def extra_repr(self) -> str:\n        s = f\"hidden_size={self.weight.data.size(0)}\"\n        s += f\", eps={self.variance_epsilon}\"\n        return s\n\n\n# Copied and adapted from 
sglang\n@CustomOp.register(\"layer_norm\")\nclass LayerNorm(CustomOp):\n    def __init__(\n        self,\n        hidden_size: int,\n        eps=1e-5,\n        bias: bool = True,\n        elementwise_affine=True,\n        device=None,\n        dtype=None,\n    ) -> None:\n        super().__init__()\n        self.eps = eps\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        self.hidden_size = hidden_size\n        if elementwise_affine:\n            self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n            self.bias = (\n                torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n                if bias\n                else None\n            )\n        else:\n            self.register_parameter(\"weight\", None)\n            self.register_parameter(\"bias\", None)\n            # Lazy cache for ones vector (not a registered buffer to avoid FSDP/meta issues)\n            self._weight_fallback_cache = None\n\n    def _get_weight_fallback(self, x: torch.Tensor) -> torch.Tensor:\n        wf = getattr(self, \"_weight_fallback_cache\", None)\n        if (\n            wf is None\n            or wf.device != x.device\n            or wf.dtype != x.dtype\n            or wf.numel() != self.hidden_size\n        ):\n            wf = torch.ones(self.hidden_size, device=x.device, dtype=x.dtype)\n            self._weight_fallback_cache = wf\n        return wf\n\n    def forward_triton(self, x: torch.Tensor):\n        # Fast inference kernel without residual/dropout branches\n        return norm_infer(\n            x.view(-1, self.hidden_size),\n            self.weight,\n            self.bias,\n            eps=self.eps,\n            is_rms_norm=False,\n        ).view(x.shape)\n\n    def forward_cuda(\n        self,\n        x: torch.Tensor,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        shape = x.shape\n        x = x.view(-1, self.hidden_size)\n        return 
self.forward_triton(x).view(shape)\n\n    @torch.compile(backend=\"inductor\", disable=current_platform.is_npu())\n    def forward_native(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        input_dtype = x.dtype\n        mean = x.mean(-1, keepdim=True)\n        variance = (x - mean).pow(2).mean(-1, keepdim=True)\n        x = (x - mean) * torch.rsqrt(variance + self.eps)\n        # weight and bias are None when elementwise_affine=False (no-op)\n        if self.weight is not None:\n            x = self.weight * x\n        if self.bias is not None:\n            x = x + self.bias\n        return x.to(input_dtype)\n\n    def forward_cpu(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        return self.forward_native(x, residual)\n\n    def forward_musa(self, x: torch.Tensor):\n        return F.layer_norm(x, (self.hidden_size,), self.weight, self.bias, self.eps)\n\n    def extra_repr(self) -> str:\n        # LayerNorm stores its epsilon as self.eps, and weight may be None\n        s = f\"hidden_size={self.hidden_size}\"\n        s += f\", eps={self.eps}\"\n        return s\n\n\n# adapted from Diffusers: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/normalization.py\n# NOTE(will): Needed to match behavior of diffusers and wan2.1 even while using\n# FSDP's MixedPrecisionPolicy\nclass FP32LayerNorm(nn.LayerNorm):\n    def forward(self, inputs: torch.Tensor) -> torch.Tensor:\n        origin_dtype = inputs.dtype\n        device = inputs.device\n        return F.layer_norm(\n            inputs.float(),\n            self.normalized_shape,\n            self.weight.float().to(device=device) if self.weight is not None else None,\n            self.bias.float().to(device=device) if self.bias is not None else None,\n            self.eps,\n        
).to(origin_dtype)\n\n\n################################################################################\n# Fused norm kernel\n################################################################################\ndef _ensure_contiguous(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:\n    return tensor.contiguous() if tensor is not None else None\n\n\nclass _ScaleResidualNormScaleShift(CustomOp):\n    \"\"\"\n    Fused kernel that combines:\n    1. residual_out = residual + gate * x\n    2. normed = layernorm(residual_out) or rmsnorm(residual_out)\n    3. out = normed * (1 + scale) + shift\n    compute_dtype is always fp32 for higher precision.\n    \"\"\"\n\n    norm_type: str\n\n    def __init__(\n        self,\n        hidden_size: int,\n        eps: float = 1e-6,\n        elementwise_affine: bool = False,\n        dtype: torch.dtype = torch.float32,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.eps = eps\n        self.dtype = dtype\n        if self.norm_type == \"rms\":\n            self.norm = RMSNorm(hidden_size, eps=eps, dtype=dtype)\n        elif self.norm_type == \"layer\":\n            self.norm = FP32LayerNorm(\n                hidden_size, elementwise_affine=elementwise_affine, eps=eps, dtype=dtype\n            )\n        else:\n            raise NotImplementedError(f\"Norm type {self.norm_type} not implemented\")\n\n    def forward_cuda(\n        self,\n        residual: torch.Tensor,\n        x: torch.Tensor,\n        gate: torch.Tensor | int,\n        shift: torch.Tensor,\n        scale: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        if x.shape[-1] % 256 != 0 and x.shape[-1] <= 8192:\n            import warnings\n\n            warnings.warn(\n                \"FusedScaleResidualNormScaleShift cuda not available, using native fallback\",\n                stacklevel=2,\n            )\n            return self.forward_native(residual, x, gate, shift, scale)\n\n        from 
sglang.jit_kernel.diffusion.cutedsl.scale_residual_norm_scale_shift import (\n            fused_scale_residual_norm_scale_shift,\n        )\n\n        if isinstance(gate, int) and gate != 1:\n            raise ValueError(\n                f\"Only gate value of 1 is supported for int type, but got {gate}\"\n            )\n\n        return fused_scale_residual_norm_scale_shift(\n            residual.contiguous(),\n            x.contiguous(),\n            gate.contiguous() if isinstance(gate, torch.Tensor) else None,\n            _ensure_contiguous(getattr(self.norm, \"weight\", None)),\n            _ensure_contiguous(getattr(self.norm, \"bias\", None)),\n            scale.contiguous(),\n            shift.contiguous(),\n            self.norm_type,\n            self.eps,\n        )\n\n    def forward_hip(self, *args, **kwargs):\n        # ROCm does not support CUDA/CUTLASS-based fused kernels yet,\n        # so we fall back to the native PyTorch implementation.\n        return self.forward_native(*args, **kwargs)\n\n    def forward_musa(self, *args, **kwargs):\n        # MUSA does not support CUDA/CUTLASS-based fused kernels yet,\n        # so we fall back to the native PyTorch implementation.\n        return self.forward_native(*args, **kwargs)\n\n    def forward_native(\n        self,\n        residual: torch.Tensor,\n        x: torch.Tensor,\n        gate: torch.Tensor | int,\n        shift: torch.Tensor,\n        scale: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        # x.shape: [batch_size, seq_len, inner_dim]\n        if isinstance(gate, int):\n            # used by cross-attention, should be 1\n            assert gate == 1\n            residual_output = residual + x\n        elif isinstance(gate, torch.Tensor):\n            if gate.dim() == 4:\n                # gate.shape: [batch_size, num_frames, 1, inner_dim]\n                num_frames = gate.shape[1]\n                frame_seqlen = x.shape[1] // num_frames\n                residual_output 
= residual + (\n                    x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * gate\n                ).flatten(1, 2)\n            else:\n                # gate.shape: [batch_size, 1, inner_dim]\n                residual_output = residual + x * gate\n        else:\n            raise ValueError(f\"Gate type {type(gate)} not supported\")\n        normalized = self.norm(residual_output)\n        modulated = fuse_scale_shift_kernel(normalized, scale, shift)\n        return modulated, residual_output\n\n\nclass ScaleResidualLayerNormScaleShift(_ScaleResidualNormScaleShift):\n    norm_type = \"layer\"\n\n\nclass ScaleResidualRMSNormScaleShift(_ScaleResidualNormScaleShift):\n    norm_type = \"rms\"\n\n\nclass _NormScaleShift(CustomOp):\n    \"\"\"\n    Fused kernel that combines:\n    1. normed = layernorm(x) or rmsnorm(x)\n    2. out = normed * (1 + scale) + shift\n    compute_dtype is always fp32 for higher precision.\n    \"\"\"\n\n    norm_type: str\n\n    def __init__(\n        self,\n        hidden_size: int,\n        eps: float = 1e-6,\n        elementwise_affine: bool = False,\n        dtype: torch.dtype = torch.float32,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.eps = eps\n        if self.norm_type == \"rms\":\n            self.norm = RMSNorm(hidden_size, eps=eps, dtype=dtype)\n        elif self.norm_type == \"layer\":\n            self.norm = FP32LayerNorm(\n                hidden_size, elementwise_affine=elementwise_affine, eps=eps, dtype=dtype\n            )\n        else:\n            raise NotImplementedError(f\"Norm type {self.norm_type} not implemented\")\n\n    def forward_cuda(\n        self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor\n    ) -> torch.Tensor:\n        if x.shape[-1] % 256 != 0 and x.shape[-1] <= 8192:\n            import warnings\n\n            warnings.warn(\n                \"FusedNormScaleShift cuda not available, using native fallback\",\n                stacklevel=2,\n     
       )\n            return self.forward_native(x, shift, scale)\n\n        from sglang.jit_kernel.diffusion.cutedsl.scale_residual_norm_scale_shift import (\n            fused_norm_scale_shift,\n        )\n\n        return fused_norm_scale_shift(\n            x.contiguous(),\n            _ensure_contiguous(getattr(self.norm, \"weight\", None)),\n            _ensure_contiguous(getattr(self.norm, \"bias\", None)),\n            scale.contiguous(),\n            shift.contiguous(),\n            self.norm_type,\n            self.eps,\n        )\n\n    def forward_hip(self, *args, **kwargs):\n        # ROCm does not support CUDA/CUTLASS-based fused kernels yet,\n        # so we fall back to the native PyTorch implementation.\n        return self.forward_native(*args, **kwargs)\n\n    def forward_musa(self, *args, **kwargs):\n        # MUSA does not support CUDA/CUTLASS-based fused kernels yet,\n        # so we fall back to the native PyTorch implementation.\n        return self.forward_native(*args, **kwargs)\n\n    def forward_native(\n        self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor\n    ) -> torch.Tensor:\n        normalized = self.norm(x)\n        modulated = fuse_scale_shift_kernel(normalized, scale, shift)\n        return modulated.to(x.dtype)\n\n\nclass LayerNormScaleShift(_NormScaleShift):\n    norm_type = \"layer\"\n\n\nclass RMSNormScaleShift(_NormScaleShift):\n    norm_type = \"rms\"\n\n\ndef apply_qk_norm(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    q_norm: \"RMSNorm\",\n    k_norm: \"RMSNorm\",\n    head_dim: int,\n    allow_inplace: bool = True,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Apply QK normalization for query and key tensors.\n\n    Uses JIT fused inplace kernel when available, falls back to standard RMSNorm.\n    \"\"\"\n\n    batch_size = q.size(0)\n    q_eps = q_norm.variance_epsilon\n    k_eps = k_norm.variance_epsilon\n    # Only try fused path on CUDA and when it won't introduce implicit copies.\n    if 
(\n        _is_cuda\n        and allow_inplace\n        and (q_eps == k_eps)\n        and can_use_fused_inplace_qknorm(head_dim, q.dtype)\n    ):\n        fused_inplace_qknorm(\n            q=q.view(batch_size, -1, head_dim),\n            k=k.view(batch_size, -1, head_dim),\n            q_weight=q_norm.weight,\n            k_weight=k_norm.weight,\n            head_dim=head_dim,\n            eps=q_eps,\n        )\n        return q, k\n\n    q_shape = q.shape\n    k_shape = k.shape\n    q_out = q_norm(q.view(-1, head_dim)).view(q_shape)\n    k_out = k_norm(k.view(-1, head_dim)).view(k_shape)\n    return q_out, k_out\n\n\ndef tensor_parallel_rms_norm(x: torch.Tensor, norm: \"RMSNorm\") -> torch.Tensor:\n    tp_rank = get_tensor_model_parallel_rank()\n    tp_size = get_tensor_model_parallel_world_size()\n    src_dtype = x.dtype\n    weight = norm.weight.tensor_split(tp_size)[tp_rank].float()\n    x_fp32 = x.float()\n    variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)\n    variance = get_tp_group().all_reduce(\n        variance, op=torch._C._distributed_c10d.ReduceOp.AVG\n    )\n    output = x_fp32 * torch.rsqrt(variance + norm.variance_epsilon) * weight\n    return output.to(dtype=src_dtype)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/linear.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/linear.py\n\nfrom abc import abstractmethod\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom torch.nn.parameter import Parameter\n\nfrom sglang.multimodal_gen.runtime.distributed import (\n    divide,\n    get_tp_group,\n    split_tensor_along_last_dim,\n    tensor_model_parallel_all_gather,\n    tensor_model_parallel_all_reduce,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n    QuantizeMethodBase,\n)\nfrom sglang.multimodal_gen.runtime.layers.utils import get_group_rank, get_group_size\n\n# yapf: disable\nfrom sglang.multimodal_gen.runtime.models.parameter import (\n    BasevLLMParameter,\n    BlockQuantScaleParameter,\n    PackedColumnParameter,\n    PackedvLLMParameter,\n    PerTensorScaleParameter,\n    RowvLLMParameter,\n)\n\n# yapf: enable\nfrom sglang.multimodal_gen.runtime.models.utils import set_weight_attrs\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\nWEIGHT_LOADER_V2_SUPPORTED = [\n    \"CompressedTensorsLinearMethod\",\n    \"AWQMarlinLinearMethod\",\n    \"AWQLinearMethod\",\n    \"GPTQMarlinLinearMethod\",\n    \"Fp8LinearMethod\",\n    \"MarlinLinearMethod\",\n    \"QQQLinearMethod\",\n    \"GPTQMarlin24LinearMethod\",\n    \"TPUInt8LinearMethod\",\n    \"GPTQLinearMethod\",\n    \"FBGEMMFp8LinearMethod\",\n    \"ModelOptFp8LinearMethod\",\n    \"IPEXAWQLinearMethod\",\n    \"IPEXGPTQLinearMethod\",\n    \"HQQMarlinMethod\",\n    \"QuarkLinearMethod\",\n]\n\n\ndef adjust_scalar_to_fused_array(\n    param: torch.Tensor, loaded_weight: torch.Tensor, shard_id: str | int\n) -> tuple[torch.Tensor, 
torch.Tensor]:\n    \"\"\"For fused modules (QKV and MLP) we have an array of length\n    N that holds 1 scale for each \"logical\" matrix. So the param\n    is an array of length N. The loaded_weight corresponds to\n    one of the shards on disk. Here, we slice the param based on\n    the shard_id for loading.\n    \"\"\"\n    qkv_idxs = {\"q\": 0, \"k\": 1, \"v\": 2}\n\n    if isinstance(shard_id, str):\n        shard_id = qkv_idxs[shard_id]\n    elif not isinstance(shard_id, int):\n        raise ValueError(f\"Unknown Shard Id {shard_id}\")\n\n    # AutoFP8 scales do not have a shape\n    # compressed-tensors scales do have a shape\n    if len(loaded_weight.shape) != 0:\n        assert loaded_weight.shape[0] == 1\n        loaded_weight = loaded_weight[0]\n\n    return param[shard_id], loaded_weight\n\n\nclass LinearMethodBase(QuantizeMethodBase):\n    \"\"\"Base class for different (maybe quantized) linear methods.\"\"\"\n\n    @abstractmethod\n    def create_weights(\n        self,\n        layer: torch.nn.Module,\n        input_size_per_partition: int,\n        output_partition_sizes: list[int],\n        input_size: int,\n        output_size: int,\n        params_dtype: torch.dtype,\n        **extra_weight_attrs,\n    ) -> None:\n        \"\"\"Create weights for a linear layer.\n           The weights will be set as attributes of the layer.\n\n        Args:\n            layer: The layer that is using the LinearMethodBase factory.\n            input_size_per_partition: Size of the weight input dim on rank X.\n            output_partition_sizes: Sizes of the output dim of each logical\n                weight on rank X. 
E.g., output_partition_sizes for QKVLinear\n                is a list containing the widths of Wq, Wk, Wv on rank X.\n            input_size: Size of the input dim of the weight across all ranks.\n            output_size: Size of the output dim of the weight across all ranks.\n            params_dtype: Datatype of the parameters.\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def apply(\n        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None\n    ) -> torch.Tensor:\n        \"\"\"Apply the weights in layer to the input tensor.\n        Expects create_weights to have been called before on the layer.\"\"\"\n        raise NotImplementedError\n\n\nclass UnquantizedLinearMethod(LinearMethodBase):\n    \"\"\"Linear method without quantization.\"\"\"\n\n    def create_weights(\n        self,\n        layer: torch.nn.Module,\n        input_size_per_partition: int,\n        output_partition_sizes: list[int],\n        input_size: int,\n        output_size: int,\n        params_dtype: torch.dtype,\n        **extra_weight_attrs,\n    ) -> None:\n        weight = Parameter(\n            torch.empty(\n                sum(output_partition_sizes),\n                input_size_per_partition,\n                dtype=params_dtype,\n            ),\n            requires_grad=False,\n        )\n        set_weight_attrs(weight, {\"input_dim\": 1, \"output_dim\": 0})\n        layer.register_parameter(\"weight\", weight)\n        set_weight_attrs(weight, extra_weight_attrs)\n\n    def apply(\n        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None\n    ) -> torch.Tensor:\n        output = (\n            F.linear(x, layer.weight, bias)\n            if current_platform.is_amp_supported() or bias is None\n            else F.linear(x, layer.weight, bias.to(x.dtype))\n        )  # NOTE: this line assumes that we are using amp when using cuda and is needed to account for the fact that amp isn't supported 
in mps\n        return output\n\n\nclass LinearBase(torch.nn.Module):\n    \"\"\"Base linear layer.\n\n    Args:\n        input_size: input dimension of the linear layer.\n        output_size: output dimension of the linear layer.\n        skip_bias_add: If true, skip adding bias but instead return it.\n        params_dtype: Data type for the parameters.\n        quant_config: Quantization configuration.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        output_size: int,\n        skip_bias_add: bool = False,\n        params_dtype: torch.dtype | None = None,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n\n        # Keep input parameters\n        self.input_size = input_size\n        self.output_size = output_size\n        self.skip_bias_add = skip_bias_add\n        if params_dtype is None:\n            params_dtype = torch.get_default_dtype()\n        self.params_dtype = params_dtype\n        self.quant_config = quant_config\n        self.prefix = prefix\n        if quant_config is None:\n            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()\n        else:\n            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)\n\n    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]:\n        raise NotImplementedError\n\n\nclass ReplicatedLinear(LinearBase):\n    \"\"\"Replicated linear layer.\n\n    Args:\n        input_size: input dimension of the linear layer.\n        output_size: output dimension of the linear layer.\n        bias: If true, add bias.\n        skip_bias_add: If true, skip adding bias but instead return it.\n        params_dtype: Data type for the parameters.\n        quant_config: Quantization configuration.\n        prefix: The name of the layer in the state dict, including all parents\n                        (e.g. 
model.layers.0.qkv_proj)\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        output_size: int,\n        bias: bool = True,\n        skip_bias_add: bool = False,\n        params_dtype: torch.dtype | None = None,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ):\n        super().__init__(\n            input_size,\n            output_size,\n            skip_bias_add,\n            params_dtype,\n            quant_config,\n            prefix=prefix,\n        )\n\n        # All linear layers support a quant method.\n        assert self.quant_method is not None\n        self.quant_method.create_weights(\n            self,\n            self.input_size,\n            [self.output_size],\n            self.input_size,\n            self.output_size,\n            self.params_dtype,\n            weight_loader=self.weight_loader,\n        )\n\n        if bias:\n            self.bias = Parameter(\n                torch.empty(\n                    self.output_size,\n                    dtype=self.params_dtype,\n                )\n            )\n            set_weight_attrs(\n                self.bias,\n                {\n                    \"output_dim\": 0,\n                    \"weight_loader\": self.weight_loader,\n                },\n            )\n        else:\n            self.register_parameter(\"bias\", None)\n\n    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor) -> None:\n        # If the weight on disk does not have a shape, give it one\n        # (such as scales for AutoFP8).\n        if len(loaded_weight.shape) == 0:\n            loaded_weight = loaded_weight.reshape(1)\n\n        assert param.size() == loaded_weight.size(), (\n            f\"Tried to load weights of size {loaded_weight.size()} \"\n            f\"to a parameter of size {param.size()}\"\n        )\n        param.data.copy_(loaded_weight)\n\n    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, 
Parameter | None]:\n        bias = self.bias if not self.skip_bias_add else None\n        assert self.quant_method is not None\n        output = self.quant_method.apply(self, x, bias)\n        output_bias = self.bias if self.skip_bias_add else None\n        return output, output_bias\n\n    def extra_repr(self) -> str:\n        s = f\"in_features={self.input_size}\"\n        s += f\", output_features={self.output_size}\"\n        s += f\", bias={self.bias is not None}\"\n        return s\n\n\nclass ColumnParallelLinear(LinearBase):\n    \"\"\"Linear layer with column parallelism.\n\n    The linear layer is defined as Y = XA + b. A is parallelized along\n    its second dimension as A = [A_1, ..., A_p].\n\n    Args:\n        input_size: first dimension of matrix A.\n        output_size: second dimension of matrix A.\n        bias: If true, add bias.\n        gather_output: If true, call all-gather on output and make Y available\n                       to all GPUs, otherwise, every GPU will have its output\n                       which is Y_i = XA_i\n        skip_bias_add: This was added to enable performance optimizations where\n                       bias can be fused with other element-wise operations. We\n                       skip adding bias but instead return it.\n        params_dtype: Data type for the parameters.\n        quant_config: Quantization configuration.\n        output_sizes: list of output sizes packed into one output, like for QKV\n                       the list would be size 3.\n        prefix: The name of the layer in the state dict, including all parents\n                        (e.g. 
model.layers.0.qkv_proj)\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        output_size: int,\n        bias: bool = True,\n        gather_output: bool = False,\n        skip_bias_add: bool = False,\n        params_dtype: torch.dtype | None = None,\n        quant_config: QuantizationConfig | None = None,\n        output_sizes: list[int] | None = None,\n        prefix: str = \"\",\n        tp_group: dist.ProcessGroup = None,\n    ):\n        # Divide the weight matrix along the last dimension.\n        self.tp_group = tp_group or get_tp_group()\n        self.tp_size = get_group_size(self.tp_group)\n        self.tp_rank = get_group_rank(self.tp_group)\n        self.input_size_per_partition = input_size\n        self.output_size_per_partition = divide(output_size, self.tp_size)\n        self.output_partition_sizes = [self.output_size_per_partition]\n        # If QKV or MergedColumn, use output size of each partition.\n        if hasattr(self, \"output_sizes\"):\n            self.output_partition_sizes = [\n                divide(output_size, self.tp_size) for output_size in self.output_sizes\n            ]\n\n        super().__init__(\n            input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix\n        )\n\n        self.gather_output = gather_output\n\n        if output_sizes is None:\n            output_sizes = [output_size]\n\n        assert self.quant_method is not None\n        self.quant_method.create_weights(\n            layer=self,\n            input_size_per_partition=self.input_size_per_partition,\n            output_partition_sizes=self.output_partition_sizes,\n            input_size=self.input_size,\n            output_size=self.output_size,\n            params_dtype=self.params_dtype,\n            weight_loader=(\n                self.weight_loader_v2\n                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED\n                else self.weight_loader\n            ),\n    
    )\n        if bias:\n            self.bias = Parameter(\n                torch.empty(\n                    self.output_size_per_partition,\n                    dtype=params_dtype,\n                )\n            )\n            set_weight_attrs(\n                self.bias,\n                {\n                    \"output_dim\": 0,\n                    \"weight_loader\": self.weight_loader,\n                },\n            )\n        else:\n            self.register_parameter(\"bias\", None)\n\n    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor) -> None:\n        tp_rank = self.tp_rank\n        output_dim = getattr(param, \"output_dim\", None)\n\n        is_sharded_weight = getattr(param, \"is_sharded_weight\", False)\n        is_sharded_weight = is_sharded_weight\n\n        param_data = param.data\n        if output_dim is not None and not is_sharded_weight:\n            shard_size = param_data.shape[output_dim]\n            start_idx = tp_rank * shard_size\n            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)\n\n        # Special case for loading scales off disk, which often do not\n        # have a shape (such as in the case of AutoFP8).\n        if len(loaded_weight.shape) == 0:\n            loaded_weight = loaded_weight.reshape(1)\n\n        assert param_data.shape == loaded_weight.shape\n        param_data.copy_(loaded_weight)\n\n    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor) -> None:\n        # Special case for loading scales off disk, which often do not\n        # have a shape (such as in the case of AutoFP8).\n        if len(loaded_weight.shape) == 0:\n            assert loaded_weight.numel() == 1\n            loaded_weight = loaded_weight.reshape(1)\n        param.load_column_parallel_weight(loaded_weight=loaded_weight)\n\n    def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]:\n        bias = self.bias if not self.skip_bias_add else 
None\n\n        # Matrix multiply.\n        assert self.quant_method is not None\n        output_parallel = self.quant_method.apply(self, input_, bias)\n        if self.gather_output:\n            # All-gather across the partitions.\n            output = tensor_model_parallel_all_gather(\n                output_parallel, tp_group=self.tp_group\n            )\n        else:\n            output = output_parallel\n        output_bias = self.bias if self.skip_bias_add else None\n        return output, output_bias\n\n    def extra_repr(self) -> str:\n        s = f\"in_features={self.input_size}\"\n        s += f\", output_features={self.output_size_per_partition}\"\n        s += f\", bias={self.bias is not None}\"\n        s += f\", tp_size={self.tp_size}\"\n        s += f\", gather_output={self.gather_output}\"\n        return s\n\n\nclass MergedColumnParallelLinear(ColumnParallelLinear):\n    \"\"\"Packed linear layers with column parallelism.\n\n    Similar to ColumnParallelLinear, but the weight matrix is concatenated\n    along the output dimension. When the weight matrix is loaded, the\n    different partitions are sharded separately.\n\n    Args:\n        input_size: input dimension of the linear layer.\n        output_sizes: list of output dimensions of the linear layer.\n        bias: If true, add bias.\n        gather_output: If true, call all-gather on output and make the output\n                       available to all GPUs, otherwise, every GPU will have\n                       its own output.\n        skip_bias_add: This was added to enable performance optimizations where\n                       bias can be fused with other element-wise operations. We\n                       skip adding bias but instead return it.\n        params_dtype: Data type for the parameters.\n        quant_config: Quantization configuration.\n        prefix: The name of the layer in the state dict, including all parents\n                        (e.g. 
model.layers.0.qkv_proj)\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        output_sizes: list[int],\n        bias: bool = True,\n        gather_output: bool = False,\n        skip_bias_add: bool = False,\n        params_dtype: torch.dtype | None = None,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n        tp_group: dist.ProcessGroup = None,\n    ):\n        super().__init__(\n            input_size=input_size,\n            output_size=sum(output_sizes),\n            bias=bias,\n            gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            params_dtype=params_dtype,\n            quant_config=quant_config,\n            prefix=prefix,\n            tp_group=tp_group,\n        )\n        self.output_sizes = output_sizes\n        assert all(output_size % self.tp_size == 0 for output_size in output_sizes)\n\n    def weight_loader(\n        self,\n        param: Parameter,\n        loaded_weight: torch.Tensor,\n        loaded_shard_id: int | None = None,\n    ) -> None:\n\n        param_data = param.data\n        output_dim = getattr(param, \"output_dim\", None)\n        # Special case for AQLM codebooks.\n        is_metadata = getattr(param, \"is_metadata\", False)\n        # Special case for per-tensor scale to load scalar into fused array.\n        needs_scalar_to_array = getattr(param, \"needs_scalar_to_array\", False)\n\n        if loaded_shard_id is None:\n            # Loaded weight is already fused on disk (mlp).\n            # (e.g., Phi-3's gate_up_proj).\n            if output_dim is None:\n                if needs_scalar_to_array:\n                    param_data, loaded_weight = adjust_scalar_to_fused_array(\n                        param_data, loaded_weight, 0\n                    )\n\n                assert param_data.shape == loaded_weight.shape\n                param_data.copy_(loaded_weight)\n                return\n            current_shard_offset 
= 0\n            shard_offsets: list[tuple[int, int, int]] = []\n            for i, output_size in enumerate(self.output_sizes):\n                shard_offsets.append((i, current_shard_offset, output_size))\n                current_shard_offset += output_size\n            for shard_id, shard_offset, shard_size in shard_offsets:\n                loaded_weight_shard = loaded_weight.narrow(\n                    output_dim, shard_offset, shard_size\n                )\n                self.weight_loader(param, loaded_weight_shard, shard_id)\n            return\n\n        assert loaded_shard_id < len(self.output_sizes)\n        tp_rank = self.tp_rank\n        tp_size = self.tp_size\n        if output_dim is not None:\n            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size\n            shard_size = self.output_sizes[loaded_shard_id] // tp_size\n\n            is_sharded_weight = getattr(param, \"is_sharded_weight\", False)\n            # bitsandbytes loads the weights of the specific portion\n            # no need to narrow\n            is_sharded_weight = is_sharded_weight\n\n            param_data = param_data.narrow(output_dim, shard_offset, shard_size)\n            start_idx = tp_rank * shard_size\n            if not is_sharded_weight:\n                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)\n        # Special case for AQLM codebooks.\n        elif is_metadata:\n            # metadata indicates fixed size concatenated along dim 0\n            shard_size = loaded_weight.shape[0]\n            shard_offset = loaded_shard_id * shard_size\n            param_data = param_data.narrow(0, shard_offset, shard_size)\n\n        # Special case for per-tensor scales in fused case.\n        elif needs_scalar_to_array:\n            param_data, loaded_weight = adjust_scalar_to_fused_array(\n                param_data, loaded_weight, loaded_shard_id\n            )\n\n        else:\n            ignore_warning = getattr(param, 
\"ignore_warning\", False)\n            if not ignore_warning:\n                logger.warning(\n                    \"Loading a weight without `output_dim` attribute in \"\n                    \"MergedColumnParallelLinear, assuming the weight is \"\n                    \"the same for all partitions.\"\n                )\n\n        assert param_data.shape == loaded_weight.shape\n        param_data.copy_(loaded_weight)\n\n    def _load_fused_module_from_checkpoint(\n        self, param: BasevLLMParameter, loaded_weight: torch.Tensor\n    ) -> None:\n        \"\"\"\n        Handle special case for models where MLP layers are already\n        fused on disk. In this case, we have no shard id. This function\n        determines the shard id by splitting these layers and then calls\n        the weight loader using the shard id.\n\n        An example of a model with these fused layers:\n        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct\n        \"\"\"\n\n        current_shard_offset = 0\n        shard_offsets: list[tuple[int, int, int]] = []\n        for i, output_size in enumerate(self.output_sizes):\n            shard_offsets.append((i, current_shard_offset, output_size))\n            current_shard_offset += output_size\n\n        for shard_id, shard_offset, shard_size in shard_offsets:\n            # Special case for Quantization.\n            # If quantized, we need to adjust the offset and size to account\n            # for the packing.\n            if (\n                isinstance(param, PackedColumnParameter | PackedvLLMParameter)\n                and param.packed_dim == param.output_dim\n            ):\n                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(\n                    shard_size=shard_size, shard_offset=shard_offset\n                )\n\n            loaded_weight_shard = loaded_weight.narrow(\n                param.output_dim, shard_offset, shard_size\n            )\n            self.weight_loader_v2(param, 
loaded_weight_shard, shard_id)\n\n    def weight_loader_v2(\n        self,\n        param: BasevLLMParameter,\n        loaded_weight: torch.Tensor,\n        loaded_shard_id: int | None = None,\n    ) -> None:\n        if loaded_shard_id is None:\n            if isinstance(param, PerTensorScaleParameter):\n                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)\n                return\n            elif type(param) in (RowvLLMParameter, BasevLLMParameter):\n                param.load_merged_column_weight(loaded_weight=loaded_weight)\n                return\n            # TODO: @dsikka - move to parameter.py\n            self._load_fused_module_from_checkpoint(param, loaded_weight)\n            return\n\n        assert loaded_shard_id < len(self.output_sizes)\n\n        tp_size = self.tp_size\n\n        if isinstance(param, BlockQuantScaleParameter):\n            raise NotImplementedError(\"FP8 is not implemented yet\")\n            # FIXME(will): add fp8 support\n            # from vllm.model_executor.layers.quantization.fp8 import (\n            #     Fp8LinearMethod, Fp8MoEMethod)\n            # assert self.quant_method is not None\n            # assert isinstance(self.quant_method,\n            #                   (Fp8LinearMethod, Fp8MoEMethod))\n            # weight_block_size = self.quant_method.quant_config.weight_block_size\n            # assert weight_block_size is not None\n            # block_n, _ = weight_block_size[0], weight_block_size[1]\n            # shard_offset = (\n            #     (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //\n            #     block_n) // tp_size\n            # shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //\n            #               block_n // tp_size)\n        else:\n            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size\n            shard_size = self.output_sizes[loaded_shard_id] // tp_size\n\n        
param.load_merged_column_weight(\n            loaded_weight=loaded_weight,\n            shard_id=loaded_shard_id,\n            shard_offset=shard_offset,\n            shard_size=shard_size,\n        )\n\n\nclass QKVParallelLinear(ColumnParallelLinear):\n    \"\"\"Linear layers for the attention's QKV transformation.\n\n    Linear layers for the linear transformation of the query, key, and value\n    vectors in the attention layer. The weight matrix is concatenated along\n    the output dimension. The layer is parallelized along the head dimension.\n    When the number of key/value heads is smaller than the number of query\n    heads (e.g., multi-query/grouped-query attention), the key/value head may\n    be replicated while the query heads are partitioned.\n\n    Args:\n        hidden_size: input hidden state size of the transformer.\n        head_size: size of each attention head.\n        total_num_heads: total number of attention query heads.\n        total_num_kv_heads: total number of attention key/value heads. If\n                            None, assume total_num_kv_heads = total_num_heads.\n        bias: If true, add bias.\n        skip_bias_add: This was added to enable performance optimizations where\n                       bias can be fused with other element-wise operations. We\n                       skip adding bias but instead return it.\n        params_dtype: Data type for the parameters.\n        quant_config: Quantization configuration.\n        prefix: The name of the layer in the state dict, including all parents\n                        (e.g. 
model.layers.0.qkv_proj)\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        head_size: int,\n        total_num_heads: int,\n        total_num_kv_heads: int | None = None,\n        bias: bool = True,\n        skip_bias_add: bool = False,\n        params_dtype: torch.dtype | None = None,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n        tp_group: dist.ProcessGroup = None,\n    ):\n        self.hidden_size = hidden_size\n        self.head_size = head_size\n        self.total_num_heads = total_num_heads\n        if total_num_kv_heads is None:\n            total_num_kv_heads = total_num_heads\n        self.total_num_kv_heads = total_num_kv_heads\n        # Divide the weight matrix along the last dimension.\n        tp_group = tp_group or get_tp_group()\n        tp_size = get_group_size(tp_group)\n        self.num_heads = divide(self.total_num_heads, tp_size)\n        if tp_size >= self.total_num_kv_heads:\n            self.num_kv_heads = 1\n            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)\n        else:\n            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)\n            self.num_kv_head_replicas = 1\n        input_size = self.hidden_size\n        output_size = (\n            (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size\n        )\n        self.output_sizes = [\n            self.num_heads * self.head_size * tp_size,  # q_proj\n            self.num_kv_heads * self.head_size * tp_size,  # k_proj\n            self.num_kv_heads * self.head_size * tp_size,  # v_proj\n        ]\n\n        super().__init__(\n            input_size=input_size,\n            output_size=output_size,\n            bias=bias,\n            gather_output=False,\n            skip_bias_add=skip_bias_add,\n            params_dtype=params_dtype,\n            quant_config=quant_config,\n            prefix=prefix,\n            tp_group=tp_group,\n        
)\n\n    def _get_shard_offset_mapping(self, loaded_shard_id: str) -> int | None:\n        shard_offset_mapping = {\n            \"q\": 0,\n            \"k\": self.num_heads * self.head_size,\n            \"v\": (self.num_heads + self.num_kv_heads) * self.head_size,\n            \"total\": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,\n        }\n        return shard_offset_mapping.get(loaded_shard_id)\n\n    def _get_shard_size_mapping(self, loaded_shard_id: str) -> int | None:\n        shard_size_mapping = {\n            \"q\": self.num_heads * self.head_size,\n            \"k\": self.num_kv_heads * self.head_size,\n            \"v\": self.num_kv_heads * self.head_size,\n        }\n        return shard_size_mapping.get(loaded_shard_id)\n\n    def _load_fused_module_from_checkpoint(\n        self, param: BasevLLMParameter, loaded_weight: torch.Tensor\n    ):\n        \"\"\"\n        Handle special case for models where QKV layers are already\n        fused on disk. In this case, we have no shard id. 
This function\n        determines the shard id by splitting these layers and then calls\n        the weight loader using the shard id.\n\n        An example of a model with these fused layers:\n        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct\n        \"\"\"\n        shard_offsets = [\n            # (shard_id, shard_offset, shard_size)\n            (\"q\", 0, self.total_num_heads * self.head_size),\n            (\n                \"k\",\n                self.total_num_heads * self.head_size,\n                self.total_num_kv_heads * self.head_size,\n            ),\n            (\n                \"v\",\n                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,\n                self.total_num_kv_heads * self.head_size,\n            ),\n        ]\n\n        for shard_id, shard_offset, shard_size in shard_offsets:\n            # Special case for Quantization.\n            # If quantized, we need to adjust the offset and size to account\n            # for the packing.\n            if (\n                isinstance(param, PackedColumnParameter | PackedvLLMParameter)\n                and param.packed_dim == param.output_dim\n            ):\n                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(\n                    shard_size=shard_size, shard_offset=shard_offset\n                )\n\n            loaded_weight_shard = loaded_weight.narrow(\n                param.output_dim, shard_offset, shard_size\n            )\n            self.weight_loader_v2(param, loaded_weight_shard, shard_id)\n\n    def weight_loader_v2(\n        self,\n        param: BasevLLMParameter,\n        loaded_weight: torch.Tensor,\n        loaded_shard_id: str | None = None,\n    ):\n        if loaded_shard_id is None:  # special case for certain models\n            if isinstance(param, PerTensorScaleParameter):\n                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)\n                return\n            elif 
type(param) in (RowvLLMParameter, BasevLLMParameter):\n                param.load_qkv_weight(loaded_weight=loaded_weight)\n                return\n            # TODO: @dsikka - move to parameter.py\n            self._load_fused_module_from_checkpoint(param, loaded_weight)\n            return\n\n        assert loaded_shard_id in [\"q\", \"k\", \"v\"]\n\n        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)\n        shard_size = self._get_shard_size_mapping(loaded_shard_id)\n\n        param.load_qkv_weight(\n            loaded_weight=loaded_weight,\n            num_heads=self.num_kv_head_replicas,\n            shard_id=loaded_shard_id,\n            shard_offset=shard_offset,\n            shard_size=shard_size,\n        )\n\n    def weight_loader(\n        self,\n        param: Parameter,\n        loaded_weight: torch.Tensor,\n        loaded_shard_id: str | None = None,\n    ):\n\n        param_data = param.data\n        output_dim = getattr(param, \"output_dim\", None)\n        # Special case for AQLM codebooks.\n        is_metadata = getattr(param, \"is_metadata\", False)\n\n        # Special case for per-tensor scales in fused case.\n        needs_scalar_to_array = getattr(param, \"needs_scalar_to_array\", False)\n\n        if loaded_shard_id is None:\n            # Loaded weight is already fused on disk (qkv).\n            # (e.g., Phi-3's qkv_proj).\n            if output_dim is None:\n                if needs_scalar_to_array:\n                    param_data, loaded_weight = adjust_scalar_to_fused_array(\n                        param_data, loaded_weight, 0\n                    )\n\n                assert param_data.shape == loaded_weight.shape\n                param_data.copy_(loaded_weight)\n                return\n            shard_offsets = [\n                # (shard_id, shard_offset, shard_size)\n                (\"q\", 0, self.total_num_heads * self.head_size),\n                (\n                    \"k\",\n                    
self.total_num_heads * self.head_size,\n                    self.total_num_kv_heads * self.head_size,\n                ),\n                (\n                    \"v\",\n                    (self.total_num_heads + self.total_num_kv_heads) * self.head_size,\n                    self.total_num_kv_heads * self.head_size,\n                ),\n            ]\n\n            for shard_id, shard_offset, shard_size in shard_offsets:\n\n                loaded_weight_shard = loaded_weight.narrow(\n                    output_dim, shard_offset, shard_size\n                )\n                self.weight_loader(param, loaded_weight_shard, shard_id)\n            return\n\n        tp_rank = self.tp_rank\n        assert loaded_shard_id in [\"q\", \"k\", \"v\"]\n\n        # If output dim is defined, use the default loading process.\n        if output_dim is not None:\n            if loaded_shard_id == \"q\":\n                shard_offset = 0\n                shard_size = self.num_heads * self.head_size\n            elif loaded_shard_id == \"k\":\n                shard_offset = self.num_heads * self.head_size\n                shard_size = self.num_kv_heads * self.head_size\n            elif loaded_shard_id == \"v\":\n                shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size\n                shard_size = self.num_kv_heads * self.head_size\n\n            is_sharded_weight = getattr(param, \"is_sharded_weight\", False)\n            # bitsandbytes loads the weights of the specific portion\n            # no need to narrow\n            is_sharded_weight = is_sharded_weight\n\n            shard_idx = 0\n            param_data = param_data.narrow(output_dim, shard_offset, shard_size)\n            if loaded_shard_id == \"q\":\n                shard_idx = tp_rank\n            else:\n                shard_idx = tp_rank // self.num_kv_head_replicas\n            start_idx = shard_idx * shard_size\n\n            if not is_sharded_weight:\n                loaded_weight = 
loaded_weight.narrow(output_dim, start_idx, shard_size)\n\n        # Special case for AQLM codebooks.\n        elif is_metadata:\n            # metadata indicates fixed size concatenated along dim 0\n            shard_size = loaded_weight.shape[0]\n            shard_index = [\"q\", \"k\", \"v\"].index(loaded_shard_id)\n            param_data = param_data.narrow(0, shard_index * shard_size, shard_size)\n        # Special case for per-tensor scales in fused case.\n        elif needs_scalar_to_array:\n            param_data, loaded_weight = adjust_scalar_to_fused_array(\n                param_data, loaded_weight, loaded_shard_id\n            )\n        else:\n            ignore_warning = getattr(param, \"ignore_warning\", False)\n            if not ignore_warning:\n                logger.warning(\n                    \"Loading a weight without `output_dim` attribute in \"\n                    \"QKVParallelLinear, assume the weight is the same \"\n                    \"for all partitions.\"\n                )\n\n        assert param_data.shape == loaded_weight.shape\n        param_data.copy_(loaded_weight)\n\n\nclass RowParallelLinear(LinearBase):\n    \"\"\"Linear layer with row parallelism.\n\n    The linear layer is defined as Y = XA + b. A is parallelized along\n    its first dimension and X along its second dimension as:\n               -   -\n              | A_1 |\n              | .   |\n          A = | .   |        X = [X_1, ..., X_p]\n              | .   |\n              | A_p |\n               -   -\n    Arguments:\n        input_size: first dimension of matrix A.\n        output_size: second dimension of matrix A.\n        bias: If true, add bias. 
Note that bias is not parallelized.\n        input_is_parallel: If true, we assume that the input is already\n                           split across the GPUs and we do not split\n                           again.\n        skip_bias_add: This was added to enable performance optimization where\n                       bias can be fused with other element-wise operations.\n                       We skip adding bias but instead return it.\n        params_dtype: Data type for the parameters.\n        reduce_results: If true, all-reduce the output across the TP group.\n        quant_config: Quantization configuration.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        output_size: int,\n        bias: bool = True,\n        input_is_parallel: bool = True,\n        skip_bias_add: bool = False,\n        params_dtype: torch.dtype | None = None,\n        reduce_results: bool = True,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n        tp_group: dist.ProcessGroup = None,\n    ):\n        # Divide the weight matrix along the first dimension.\n        self.tp_group = tp_group or get_tp_group()\n        self.tp_rank = get_group_rank(self.tp_group)\n        self.tp_size = get_group_size(self.tp_group)\n        self.input_size_per_partition = divide(input_size, self.tp_size)\n        self.output_size_per_partition = output_size\n        self.output_partition_sizes = [output_size]\n\n        super().__init__(\n            input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix\n        )\n\n        self.input_is_parallel = input_is_parallel\n        self.reduce_results = reduce_results\n\n        assert self.quant_method is not None\n        self.quant_method.create_weights(\n            layer=self,\n            input_size_per_partition=self.input_size_per_partition,\n            output_partition_sizes=self.output_partition_sizes,\n            input_size=self.input_size,\n            output_size=self.output_size,\n            params_dtype=self.params_dtype,\n            
weight_loader=(\n                self.weight_loader_v2\n                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED\n                else self.weight_loader\n            ),\n        )\n        if not reduce_results and (bias and not skip_bias_add):\n            raise ValueError(\n                \"When the results are not reduced, adding bias to the \"\n                \"results can lead to incorrect results\"\n            )\n\n        if bias:\n            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))\n            set_weight_attrs(\n                self.bias,\n                {\n                    \"output_dim\": 0,\n                    \"weight_loader\": self.weight_loader,\n                },\n            )\n        else:\n            self.register_parameter(\"bias\", None)\n\n    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):\n        tp_rank = self.tp_rank\n        input_dim = getattr(param, \"input_dim\", None)\n        is_sharded_weight = getattr(param, \"is_sharded_weight\", False)\n        # bitsandbytes loads the weights of the specific portion\n        # no need to narrow\n        is_sharded_weight = is_sharded_weight\n\n        param_data = param.data\n        if input_dim is not None and not is_sharded_weight:\n            shard_size = param_data.shape[input_dim]\n            start_idx = tp_rank * shard_size\n            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)\n\n        # Special case for loading scales off disk, which often do not\n        # have a shape (such as in the case of AutoFP8).\n        if len(loaded_weight.shape) == 0:\n            loaded_weight = loaded_weight.reshape(1)\n\n        assert param_data.shape == loaded_weight.shape\n        param_data.copy_(loaded_weight)\n\n    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):\n\n        # Special case for loading scales off disk, which often do not\n      
  # have a shape (such as in the case of AutoFP8).\n        if len(loaded_weight.shape) == 0:\n            assert loaded_weight.numel() == 1\n            loaded_weight = loaded_weight.reshape(1)\n\n        param.load_row_parallel_weight(loaded_weight=loaded_weight)\n\n    def forward(self, input_) -> tuple[torch.Tensor, Parameter | None]:\n        if self.input_is_parallel:\n            input_parallel = input_\n        else:\n            tp_rank = self.tp_rank\n            splitted_input = split_tensor_along_last_dim(\n                input_, num_partitions=self.tp_size\n            )\n            input_parallel = splitted_input[tp_rank].contiguous()\n\n        # Matrix multiply.\n        assert self.quant_method is not None\n        # Only fuse bias add into GEMM for rank 0 (this ensures that\n        # bias will not get added more than once in TP>1 case)\n        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias\n        output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)\n        if self.reduce_results and self.tp_size > 1:\n            output = tensor_model_parallel_all_reduce(\n                output_parallel, tp_group=self.tp_group\n            )\n        else:\n            output = output_parallel\n\n        output_bias = self.bias if self.skip_bias_add else None\n\n        return output, output_bias\n\n    def extra_repr(self) -> str:\n        s = f\"input_features={self.input_size_per_partition}\"\n        s += f\", output_features={self.output_size}\"\n        s += f\", bias={self.bias is not None}\"\n        s += f\", tp_size={self.tp_size}\"\n        s += f\", reduce_results={self.reduce_results}\"\n        return s\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/lora/linear.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Code adapted from SGLang https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/lora/layers.py\n\n\nimport torch\nfrom torch import nn\nfrom torch.distributed._composable.fsdp import (\n    CPUOffloadPolicy,\n    OffloadPolicy,\n    fully_shard,\n)\nfrom torch.distributed.tensor import DTensor\n\nfrom sglang.multimodal_gen.runtime.distributed import (\n    get_local_torch_device,\n    get_tp_rank,\n    split_tensor_along_last_dim,\n    tensor_model_parallel_all_gather,\n    tensor_model_parallel_all_reduce,\n)\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    ColumnParallelLinear,\n    LinearBase,\n    MergedColumnParallelLinear,\n    QKVParallelLinear,\n    ReplicatedLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import (\n    VocabParallelEmbedding,\n)\nfrom sglang.multimodal_gen.utils import get_mixed_precision_state\n\ntorch._dynamo.config.recompile_limit = 16\n\n\nclass BaseLayerWithLoRA(nn.Module):\n\n    def __init__(\n        self,\n        base_layer: nn.Module,\n        lora_rank: int | None = None,\n        lora_alpha: int | None = None,\n    ):\n        super().__init__()\n        self.base_layer: nn.Module = base_layer\n\n        self.merged: bool = False\n        # Immutable base-weight snapshot; `to(\"cpu\")` may alias CPU storage.\n        # Use `clone()` so merge updates cannot mutate this backup tensor.\n        self.cpu_weight = base_layer.weight.detach().to(\"cpu\").clone()\n        # indicates adapter weights don't contain this layer\n        # (which shouldn't normally happen, but we want to separate it from the case of erroneous merging)\n        # Default to True to prevent using uninitialized weights; set to False when weights are loaded\n        self.disable_lora: bool = True\n        self.lora_rank = lora_rank\n        self.lora_alpha 
= lora_alpha\n        self.lora_weights_list: list[\n            tuple[torch.nn.Parameter, torch.nn.Parameter, str | None, float]\n        ] = []\n        self.lora_path: str | None = None\n        self.strength: float = 1.0\n\n        self.lora_A = None\n        self.lora_B = None\n\n    @property\n    def weight(self):\n        return self.base_layer.weight\n\n    @property\n    def bias(self):\n        return getattr(self.base_layer, \"bias\", None)\n\n    @torch.compile()\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        lora_A = self.lora_A\n        lora_B = self.lora_B\n        if isinstance(self.lora_B, DTensor):\n            lora_B = self.lora_B.to_local()\n            lora_A = self.lora_A.to_local()\n\n        # TODO: Support multiple LoRA adapters when use not merged mode\n        if not self.merged and not self.disable_lora:\n            lora_A_sliced = self.slice_lora_a_weights(lora_A.to(x, non_blocking=True))\n            lora_B_sliced = self.slice_lora_b_weights(lora_B.to(x, non_blocking=True))\n            delta = x @ lora_A_sliced.T @ lora_B_sliced.T\n            if self.lora_alpha != self.lora_rank:\n                delta = delta * (\n                    self.lora_alpha / self.lora_rank  # type: ignore\n                )  # type: ignore\n            delta = delta * self.strength\n            if delta.dim() > 2:\n                delta = delta.reshape(-1, delta.shape[-1])\n            out, output_bias = self.base_layer(x)\n            return out + delta, output_bias\n        else:\n            out, output_bias = self.base_layer(x)\n            return out, output_bias\n\n    def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:\n        return A\n\n    def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor:\n        return B\n\n    def set_lora_weights(\n        self,\n        A: torch.Tensor,\n        B: torch.Tensor,\n        lora_path: str | None = None,\n        strength: float = 1.0,\n        clear_existing: 
bool = False,\n    ) -> None:\n        \"\"\"\n        Set LoRA weights. Supports multiple LoRA adapters.\n\n        Args:\n            A: LoRA A weight tensor\n            B: LoRA B weight tensor\n            lora_path: Path to the LoRA adapter (for logging)\n            strength: LoRA strength\n            clear_existing: If True, clear existing LoRA weights before adding new one.\n                          If False, append to existing list (for multi-LoRA support).\n        \"\"\"\n        lora_A_param = torch.nn.Parameter(\n            A\n        )  # share storage with weights in the pipeline\n        lora_B_param = torch.nn.Parameter(B)\n\n        if clear_existing:\n            self.lora_weights_list.clear()\n            # Also clear backward compatibility attributes\n            self.lora_A = None\n            self.lora_B = None\n            self.lora_path = None\n            self.strength = 1.0\n\n        # Add to list for multi-LoRA support\n        self.lora_weights_list.append((lora_A_param, lora_B_param, lora_path, strength))\n\n        # Set backward compatibility attributes to point to the last LoRA (for single LoRA case)\n        # This ensures backward compatibility while supporting multiple LoRA\n        self.lora_A = lora_A_param\n        self.lora_B = lora_B_param\n        self.lora_path = lora_path\n        self.strength = strength\n\n        self.disable_lora = False\n        self.merge_lora_weights()\n\n    @torch.no_grad()\n    def _merge_lora_into_data(\n        self,\n        data: torch.Tensor,\n        lora_list: list[\n            tuple[torch.nn.Parameter, torch.nn.Parameter, str | None, float]\n        ],\n    ) -> None:\n        \"\"\"\n        Merge all LoRA adapters into the data tensor in-place.\n\n        Args:\n            data: The base weight tensor to merge LoRA into (modified in-place)\n            lora_list: List of (lora_A, lora_B, lora_path, lora_strength) tuples\n        \"\"\"\n        # Merge all LoRA adapters in 
order\n        for lora_A, lora_B, _, lora_strength in lora_list:\n            lora_delta = self.slice_lora_b_weights(\n                lora_B.to(data)\n            ) @ self.slice_lora_a_weights(lora_A.to(data))\n            # Apply lora_alpha / lora_rank scaling for consistency with forward()\n            if self.lora_alpha is not None and self.lora_rank is not None:\n                if self.lora_alpha != self.lora_rank:\n                    lora_delta = lora_delta * (self.lora_alpha / self.lora_rank)\n            if lora_delta.dim() > 2:\n                lora_delta = lora_delta.reshape(-1, lora_delta.shape[-1])\n            data += lora_strength * lora_delta\n\n    @torch.no_grad()\n    def merge_lora_weights(self, strength: float | None = None) -> None:\n        if strength is not None:\n            self.strength = strength\n\n        if self.disable_lora:\n            return\n\n        if self.merged:\n            self.unmerge_lora_weights()\n\n        # Use lora_weights_list if available, otherwise fall back to single LoRA for backward compatibility\n        lora_list = self.lora_weights_list if self.lora_weights_list else []\n        if not lora_list and self.lora_A is not None and self.lora_B is not None:\n            lora_list = [(self.lora_A, self.lora_B, self.lora_path, self.strength)]\n\n        if not lora_list:\n            raise ValueError(\"LoRA weights not set. 
Please set them first.\")\n\n        if isinstance(self.base_layer.weight, DTensor):\n            mesh = self.base_layer.weight.data.device_mesh\n            unsharded_base_layer = ReplicatedLinear(\n                input_size=self.base_layer.input_size,\n                output_size=self.base_layer.output_size,\n                bias=getattr(self.base_layer, \"bias\", None) is not None,\n                skip_bias_add=self.base_layer.skip_bias_add,\n                params_dtype=self.base_layer.params_dtype,\n                quant_config=self.base_layer.quant_config,\n                prefix=self.base_layer.prefix,\n            )\n            # Using offload param is on CPU, so current_device is for \"CPU -> GPU -> merge -> CPU\"\n            current_device = self.base_layer.weight.data.device\n            data = self.base_layer.weight.data.to(\n                get_local_torch_device()\n            ).full_tensor()\n\n            self._merge_lora_into_data(data, lora_list)\n\n            unsharded_base_layer.weight = nn.Parameter(data.to(current_device))\n            if isinstance(getattr(self.base_layer, \"bias\", None), DTensor):\n                unsharded_base_layer.bias = nn.Parameter(\n                    self.base_layer.bias.to(get_local_torch_device(), non_blocking=True)\n                    .full_tensor()\n                    .to(current_device)\n                )\n\n            offload_policy = (\n                CPUOffloadPolicy() if \"cpu\" in str(current_device) else OffloadPolicy()\n            )\n            mp_policy = get_mixed_precision_state().mp_policy\n\n            self.base_layer = fully_shard(\n                unsharded_base_layer,\n                mesh=mesh,\n                mp_policy=mp_policy,\n                offload_policy=offload_policy,\n            )\n        else:\n            current_device = self.base_layer.weight.data.device\n            data = self.base_layer.weight.data.to(get_local_torch_device())\n\n            
self._merge_lora_into_data(data, lora_list)\n\n            self.base_layer.weight.data = data.to(current_device, non_blocking=True)\n\n        self.merged = True\n\n    @torch.no_grad()\n    # @torch.compile(dynamic=True)\n    def unmerge_lora_weights(self) -> None:\n        if self.disable_lora:\n            return\n\n        if not self.merged:\n            raise ValueError(\n                \"LoRA weights not merged. Please merge them first before unmerging.\"\n            )\n\n        # avoid precision loss\n        if isinstance(self.base_layer.weight, DTensor):\n            device = self.base_layer.weight.data.device\n            old_weight = self.base_layer.weight\n            new_weight_data = self.cpu_weight.to(device, non_blocking=True)\n            self.base_layer.weight = nn.Parameter(new_weight_data)\n            del old_weight\n        else:\n            current_device = self.base_layer.weight.data.device\n            cpu_weight_on_device = self.cpu_weight.to(current_device, non_blocking=True)\n            self.base_layer.weight.data.copy_(cpu_weight_on_device)\n            if (\n                cpu_weight_on_device.data_ptr()\n                != self.base_layer.weight.data.data_ptr()\n            ):\n                del cpu_weight_on_device\n\n        self.merged = False\n\n\nclass VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):\n    \"\"\"\n    Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation).\n\n    Note: The current version does not yet implement the LoRA functionality.\n    This class behaves exactly the same as the base VocabParallelEmbedding.\n    Future versions will integrate LoRA functionality to support efficient parameter fine-tuning.\n    \"\"\"\n\n    def __init__(\n        self,\n        base_layer: VocabParallelEmbedding,\n    ) -> None:\n        super().__init__(base_layer)\n\n    def forward(self, input_: torch.Tensor) -> torch.Tensor:\n        raise NotImplementedError(\n            \"We don't support 
VocabParallelEmbeddingWithLoRA yet.\"\n        )\n\n\nclass ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):\n\n    def __init__(\n        self,\n        base_layer: ColumnParallelLinear,\n        lora_rank: int | None = None,\n        lora_alpha: int | None = None,\n    ) -> None:\n        super().__init__(base_layer, lora_rank, lora_alpha)\n\n    def forward(self, input_: torch.Tensor) -> torch.Tensor:\n        # duplicate the logic in ColumnParallelLinear\n        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None\n        output_parallel = self.base_layer.quant_method.apply(\n            self.base_layer, input_, bias\n        )\n        if self.base_layer.gather_output:\n            output = tensor_model_parallel_all_gather(output_parallel)\n        else:\n            output = output_parallel\n        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None\n        return output, output_bias\n\n    def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:\n        return A\n\n    def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor:\n        tp_rank = get_tp_rank()\n        shard_size = self.base_layer.output_partition_sizes[0]\n        start_idx = tp_rank * shard_size\n        end_idx = (tp_rank + 1) * shard_size\n        B = B[start_idx:end_idx, :]\n        return B\n\n\nclass MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):\n\n    def __init__(\n        self,\n        base_layer: MergedColumnParallelLinear,\n        lora_rank: int | None = None,\n        lora_alpha: int | None = None,\n    ) -> None:\n        super().__init__(base_layer, lora_rank, lora_alpha)\n\n    def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:\n        return A\n\n    def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor:\n        tp_rank = get_tp_rank()\n        # Gate and up have identical partition sizes, so use the first one.\n        shard_size = 
self.base_layer.output_partition_sizes[0]\n        start_idx = tp_rank * shard_size\n        end_idx = (tp_rank + 1) * shard_size\n        return B[:, start_idx:end_idx, :]\n\n\nclass QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):\n\n    def __init__(\n        self,\n        base_layer: QKVParallelLinear,\n        lora_rank: int | None = None,\n        lora_alpha: int | None = None,\n    ) -> None:\n        super().__init__(base_layer, lora_rank, lora_alpha)\n\n    def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:\n        return A\n\n    def slice_lora_b_weights(\n        self, B: list[torch.Tensor]\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        tp_rank = get_tp_rank()\n        B_q, B_kv = B\n        base_layer = self.base_layer\n        q_proj_shard_size = base_layer.q_proj_shard_size\n        kv_proj_shard_size = base_layer.kv_proj_shard_size\n        num_kv_head_replicas = base_layer.num_kv_head_replicas\n\n        q_start_idx = q_proj_shard_size * tp_rank\n        q_end_idx = q_start_idx + q_proj_shard_size\n\n        kv_shard_id = tp_rank // num_kv_head_replicas\n        kv_start_idx = kv_proj_shard_size * kv_shard_id\n        kv_end_idx = kv_start_idx + kv_proj_shard_size\n\n        return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]\n\n\nclass RowParallelLinearWithLoRA(BaseLayerWithLoRA):\n\n    def __init__(\n        self,\n        base_layer: RowParallelLinear,\n        lora_rank: int | None = None,\n        lora_alpha: int | None = None,\n    ) -> None:\n        super().__init__(base_layer, lora_rank, lora_alpha)\n\n    def forward(self, input_: torch.Tensor):\n        # duplicate the logic in RowParallelLinear\n        if self.base_layer.input_is_parallel:\n            input_parallel = input_\n        else:\n            tp_rank = get_tp_rank()\n            splitted_input = split_tensor_along_last_dim(\n                input_, num_partitions=self.base_layer.tp_size\n            )\n            
input_parallel = splitted_input[tp_rank].contiguous()\n        output_parallel = self.base_layer.quant_method.apply(\n            self.base_layer, input_parallel\n        )\n\n        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:\n            output_ = tensor_model_parallel_all_reduce(output_parallel)\n        else:\n            output_ = output_parallel\n\n        if not self.base_layer.skip_bias_add:\n            output = (\n                output_ + self.base_layer.bias\n                if self.base_layer.bias is not None\n                else output_\n            )\n            output_bias = None\n        else:\n            output = output_\n            output_bias = self.base_layer.bias\n        return output, output_bias\n\n    def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:\n        tp_rank = get_tp_rank()\n        shard_size = self.base_layer.input_size_per_partition\n        start_idx = tp_rank * shard_size\n        end_idx = (tp_rank + 1) * shard_size\n        A = A[:, start_idx:end_idx].contiguous()\n        return A\n\n    def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor:\n        return B\n\n\nclass LinearWithLoRA(BaseLayerWithLoRA):\n    \"\"\"\n    Wrapper for standard torch.nn.Linear to support LoRA.\n    Unlike custom LinearBase classes, nn.Linear.forward() returns a single tensor,\n    not a tuple of (output, bias).\n    \"\"\"\n\n    def __init__(\n        self,\n        base_layer: nn.Linear,\n        lora_rank: int | None = None,\n        lora_alpha: int | None = None,\n    ) -> None:\n        super().__init__(base_layer, lora_rank, lora_alpha)\n\n    @torch.compile()\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        lora_A = self.lora_A\n        lora_B = self.lora_B\n        if isinstance(self.lora_B, DTensor):\n            lora_B = self.lora_B.to_local()\n            lora_A = self.lora_A.to_local()\n\n        # TODO: Support multiple LoRA adapters when use not merged mode\n  
      if not self.merged and not self.disable_lora:\n            lora_A_sliced = self.slice_lora_a_weights(lora_A.to(x, non_blocking=True))\n            lora_B_sliced = self.slice_lora_b_weights(lora_B.to(x, non_blocking=True))\n            delta = x @ lora_A_sliced.T @ lora_B_sliced.T\n            if self.lora_alpha != self.lora_rank:\n                delta = delta * (\n                    self.lora_alpha / self.lora_rank  # type: ignore\n                )  # type: ignore\n            delta = delta * self.strength\n            if delta.dim() > 2:\n                delta = delta.reshape(-1, delta.shape[-1])\n            # nn.Linear.forward() returns a single tensor, not a tuple\n            out = self.base_layer(x)\n            return out + delta\n        else:\n            # nn.Linear.forward() returns a single tensor\n            out = self.base_layer(x)\n            return out\n\n\ndef wrap_with_lora_layer(\n    layer: nn.Module,\n    lora_rank: int | None = None,\n    lora_alpha: int | None = None,\n) -> BaseLayerWithLoRA | None:\n    \"\"\"\n    transform the given layer to its corresponding LoRA layer\n    \"\"\"\n    supported_layer_types: dict[\n        type[LinearBase] | type[nn.Linear], type[BaseLayerWithLoRA]\n    ] = {\n        # the order matters\n        # VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,\n        QKVParallelLinear: QKVParallelLinearWithLoRA,\n        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,\n        ColumnParallelLinear: ColumnParallelLinearWithLoRA,\n        RowParallelLinear: RowParallelLinearWithLoRA,\n        ReplicatedLinear: BaseLayerWithLoRA,\n        nn.Linear: LinearWithLoRA,\n    }\n    for src_layer_type, lora_layer_type in supported_layer_types.items():\n        if isinstance(layer, src_layer_type):  # type: ignore[arg-type]\n            ret = lora_layer_type(\n                layer,\n                lora_rank=lora_rank,\n                lora_alpha=lora_alpha,\n            )\n            return 
ret\n    return None\n\n\n# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9\ndef replace_submodule(\n    model: nn.Module, module_name: str, new_module: nn.Module\n) -> nn.Module:\n    \"\"\"Replace a submodule in a model with a new module.\"\"\"\n    parent = model.get_submodule(\".\".join(module_name.split(\".\")[:-1]))\n    target_name = module_name.split(\".\")[-1]\n    setattr(parent, target_name, new_module)\n    return new_module\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/mlp.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nfrom diffusers.models.activations import (\n    GEGLU,\n    GELU,\n    ApproximateGELU,\n    LinearActivation,\n    SwiGLU,\n)\n\nfrom sglang.multimodal_gen.runtime.layers.activation import get_act_fn\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    ColumnParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig\nfrom sglang.srt.utils import add_prefix\n\n\nclass MLP(nn.Module):\n    \"\"\"\n    MLP for DiT blocks, NO gated linear units\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dim: int,\n        mlp_hidden_dim: int,\n        output_dim: int | None = None,\n        bias: bool = True,\n        act_type: str = \"gelu_pytorch_tanh\",\n        dtype: torch.dtype | None = None,\n        prefix: str = \"\",\n        quant_config: QuantizationConfig = None,\n    ):\n        super().__init__()\n        self.fc_in = ColumnParallelLinear(\n            input_dim,\n            mlp_hidden_dim,\n            bias=True,\n            gather_output=False,\n            quant_config=quant_config,\n            prefix=add_prefix(\"0.proj\", prefix),\n        )\n\n        self.act = get_act_fn(act_type)\n        if output_dim is None:\n            output_dim = input_dim\n        self.fc_out = RowParallelLinear(\n            mlp_hidden_dim,\n            output_dim,\n            bias=True,\n            input_is_parallel=True,\n            quant_config=quant_config,\n            prefix=add_prefix(\"2\", prefix),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x, _ = self.fc_in(x)\n        x = self.act(x)\n        x, _ = self.fc_out(x)\n        return x\n\n\nclass FeedForward(nn.Module):\n    r\"\"\"\n    A feed-forward layer.\n\n    Parameters:\n        dim (`int`): 
The number of channels in the input.\n        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.\n        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`): Activation function to be used in feed-forward.\n        inner_dim (`int`, *optional*): The hidden dimension. If not given, defaults to `int(dim * mult)`.\n        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        dim_out: Optional[int] = None,\n        mult: int = 4,\n        activation_fn: str = \"geglu\",\n        inner_dim=None,\n        bias: bool = True,\n    ):\n        super().__init__()\n        if inner_dim is None:\n            inner_dim = int(dim * mult)\n        dim_out = dim_out if dim_out is not None else dim\n\n        if activation_fn == \"gelu\":\n            act_fn = GELU(dim, inner_dim, bias=bias)\n        elif activation_fn == \"gelu-approximate\":\n            act_fn = GELU(dim, inner_dim, approximate=\"tanh\", bias=bias)\n        elif activation_fn == \"geglu\":\n            act_fn = GEGLU(dim, inner_dim, bias=bias)\n        elif activation_fn == \"geglu-approximate\":\n            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)\n        elif activation_fn == \"swiglu\":\n            act_fn = SwiGLU(dim, inner_dim, bias=bias)\n        elif activation_fn == \"linear-silu\":\n            act_fn = LinearActivation(dim, inner_dim, bias=bias, activation=\"silu\")\n        else:\n            raise ValueError(f\"Unsupported activation_fn: {activation_fn}\")\n\n        self.net = nn.ModuleList([])\n        # project in\n        self.net.append(act_fn)\n        # dummy dropout layer so module indices match diffusers-compatible checkpoints\n        self.net.append(nn.Dropout(0.0))\n        # project out\n        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        for module in self.net:\n            hidden_states = module(hidden_states)\n        return 
hidden_states\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nfrom typing import Literal, get_args\n\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization.fp8 import Fp8Config\nfrom sglang.multimodal_gen.runtime.layers.quantization.modelslim import ModelSlimConfig\n\nQuantizationMethods = Literal[\"fp8\", \"modelslim\"]\n\nQUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))\n\n# The customized quantization methods which will be added to this dict.\n_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {\n    \"modelslim\": ModelSlimConfig,\n    \"fp8\": Fp8Config,\n}\n\n\ndef register_quantization_config(quantization: str):\n    \"\"\"Register a customized vllm quantization config.\n\n    When a quantization method is not supported by vllm, you can register a customized\n    quantization config to support it.\n\n    Args:\n        quantization (str): The quantization method name.\n\n\n    \"\"\"  # noqa: E501\n\n    def _wrapper(quant_config_cls):\n        if quantization in QUANTIZATION_METHODS:\n            raise ValueError(\n                f\"The quantization method `{quantization}` already exists.\"\n            )\n        if not issubclass(quant_config_cls, QuantizationConfig):\n            raise ValueError(\n                \"The quantization config must be a subclass of \" \"`QuantizationConfig`.\"\n            )\n        _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls\n        QUANTIZATION_METHODS.append(quantization)\n        return quant_config_cls\n\n    return _wrapper\n\n\ndef get_quantization_config(quantization: str) -> type[QuantizationConfig]:\n    if quantization not in QUANTIZATION_METHODS:\n        raise ValueError(f\"Invalid quantization method: {quantization}\")\n\n    method_to_config: dict[str, type[QuantizationConfig]] = {}\n    # Update the `method_to_config` with customized 
quantization methods.\n    method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)\n\n    return method_to_config[quantization]\n\n\n__all__ = [\n    \"QuantizationMethods\",\n    \"QuantizationConfig\",\n    \"get_quantization_config\",\n    \"QUANTIZATION_METHODS\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/quantization/configs/base_config.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/quantization/base_config.py\n\nimport inspect\nfrom abc import ABC, abstractmethod\nfrom typing import TYPE_CHECKING, Any\n\nimport torch\nfrom torch import nn\n\nif TYPE_CHECKING:\n    from sglang.multimodal_gen.runtime.layers.quantization import QuantizationMethods\nelse:\n    QuantizationMethods = str\n\n\nclass QuantizeMethodBase(ABC):\n    \"\"\"Base class for different quantized methods.\"\"\"\n\n    @abstractmethod\n    def create_weights(\n        self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs\n    ):\n        \"\"\"Create weights for a layer.\n\n        The weights will be set as attributes of the layer.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:\n        \"\"\"Apply the weights in layer to the input tensor.\n\n        Expects create_weights to have been called before on the layer.\"\"\"\n        raise NotImplementedError\n\n    # Not required functions\n    def embedding(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:\n        \"\"\"Gather embeddings in the layer based on indices in the input tensor.\n\n        Expects create_weights to have been called before on the layer.\"\"\"\n        raise NotImplementedError\n\n    def process_weights_after_loading(self, layer: nn.Module) -> None:\n        \"\"\"Process the weight after loading.\n\n        This can be used for example, to transpose weights for computation.\n        \"\"\"\n        return\n\n\ndef method_has_implemented_embedding(method_class: type[QuantizeMethodBase]) -> bool:\n    \"\"\"\n    Not all quant methods have embedding implemented, so we need to check that\n    it exists for our given method. 
We check this by making sure the function\n    has been changed from the base implementation.\n    \"\"\"\n    base_embedding = inspect.getattr_static(QuantizeMethodBase, \"embedding\", None)\n    class_embedding = inspect.getattr_static(method_class, \"embedding\", None)\n\n    return class_embedding is not None and class_embedding is not base_embedding\n\n\nclass QuantizationConfig(ABC):\n    \"\"\"Base class for quantization configs.\"\"\"\n\n    # for quantization frameworks with a separate quantized model provided, e.g. Nunchaku\n    quantized_model_path: str | None = None\n\n    def __init__(self):\n        super().__init__()\n        # mapping is updated by models as they initialize\n        self.packed_modules_mapping: dict[str, list[str]] = dict()\n\n    @abstractmethod\n    def get_name(self) -> QuantizationMethods:\n        \"\"\"Name of the quantization method.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def get_supported_act_dtypes(self) -> list[torch.dtype]:\n        \"\"\"List of supported activation dtypes.\"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    @abstractmethod\n    def get_min_capability(cls) -> int:\n        \"\"\"Minimum GPU capability to support the quantization method.\n\n        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.\n        This requirement is due to the custom CUDA kernels used by the\n        quantization method.\n        \"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    @abstractmethod\n    def get_config_filenames() -> list[str]:\n        \"\"\"List of filenames to search for in the model directory.\"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    @abstractmethod\n    def from_config(cls, config: dict[str, Any]) -> \"QuantizationConfig\":\n        \"\"\"Create a config class from the model's quantization config.\"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    def override_quantization_method(\n        cls, hf_quant_cfg, 
user_quant\n    ) -> QuantizationMethods | None:\n        \"\"\"\n        Detects if this quantization method can support a given checkpoint\n        format by overriding the user-specified quantization method --\n        this method should only be overridden by subclasses in exceptional\n        circumstances.\n        \"\"\"\n        return None\n\n    @staticmethod\n    def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:\n        \"\"\"Get a value from the model's quantization config.\"\"\"\n        for key in keys:\n            if key in config:\n                return config[key]\n        raise ValueError(\n            f\"Cannot find any of {keys} in the model's \" \"quantization config.\"\n        )\n\n    @staticmethod\n    def get_from_keys_or(config: dict[str, Any], keys: list[str], default: Any) -> Any:\n        \"\"\"Get an optional value from the model's quantization config.\"\"\"\n        try:\n            return QuantizationConfig.get_from_keys(config, keys)\n        except ValueError:\n            return default\n\n    @abstractmethod\n    def get_quant_method(\n        self, layer: torch.nn.Module, prefix: str\n    ) -> QuantizeMethodBase | None:\n        \"\"\"Get the quantize method to use for the quantized layer.\n\n        Args:\n            layer: The layer for the quant method.\n            prefix: The full name of the layer in the state dict\n        Returns:\n            The quantize method. None if the given layer doesn't support quant\n            method.\n        \"\"\"\n        raise NotImplementedError\n\n    def get_cache_scale(self, name: str) -> str | None:\n        return None\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/quantization/configs/nunchaku_config.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nimport json\nimport os\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import Any, Optional\n\nimport torch\nfrom safetensors.torch import load_file as safetensors_load_file\nfrom torch import nn\n\nfrom sglang.multimodal_gen.runtime.layers.linear import LinearBase\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nfrom .base_config import QuantizationConfig, QuantizeMethodBase\n\nlogger = init_logger(__name__)\n\n\n@lru_cache(maxsize=1)\ndef is_nunchaku_available() -> bool:\n    try:\n        import nunchaku  # noqa\n\n        logger.debug(\"Nunchaku package detected\")\n        return True\n    except Exception:\n        return False\n\n\n@dataclass\nclass NunchakuConfig(QuantizationConfig):\n    \"\"\"\n    Configuration for Nunchaku (SVDQuant) W4A4-style quantization.\n\n    Attributes:\n        precision: Quantization precision type. Options:\n            - \"int4\": Standard INT4 quantization\n            - \"nvfp4\": FP4 quantization\n        rank: SVD low-rank dimension for absorbing outliers\n        group_size: Quantization group size (automatically set based on precision)\n        act_unsigned: Use unsigned activation quantization\n        transformer_weights_path: Path to pre-quantized transformer weights (.safetensors)\n        model_cls: DiT model class that provides quantization rules via get_nunchaku_quant_rules()\n    \"\"\"\n\n    precision: str = \"int4\"\n    rank: int = 32\n    group_size: Optional[int] = None\n    act_unsigned: bool = False\n    transformer_weights_path: Optional[str] = None\n    model_cls: Optional[type] = None\n\n    @classmethod\n    def get_name(cls) -> str:\n        return \"svdquant\"\n\n    @classmethod\n    def get_supported_act_dtypes(cls) -> list[torch.dtype]:\n        return [torch.bfloat16, torch.float16]\n\n    @classmethod\n    def get_min_capability(cls) -> int:\n        return 70\n\n    
@staticmethod\n    def get_config_filenames() -> list[str]:\n        return [\"quantization_config.json\", \"quant_config.json\"]\n\n    @classmethod\n    def from_config(cls, config: dict[str, Any]) -> \"NunchakuConfig\":\n\n        return cls(\n            precision=config.get(\"precision\", \"int4\"),\n            rank=int(config.get(\"rank\", 32)),\n            group_size=config.get(\"group_size\"),\n            act_unsigned=bool(config.get(\"act_unsigned\", False)),\n            transformer_weights_path=config.get(\"transformer_weights_path\"),\n        )\n\n    def get_quant_method(\n        self, layer: torch.nn.Module, prefix: str\n    ) -> Optional[QuantizeMethodBase]:\n        if not isinstance(layer, LinearBase):\n            return None\n\n        # get quantization rules from model class\n        quant_rules = self._get_quant_rules()\n\n        # priority: skip > awq_w4a16 > svdq_w4a4 > default\n        skip_patterns = quant_rules.get(\"skip\", [])\n        for pattern in skip_patterns:\n            if pattern in prefix.lower():\n                return None\n\n        awq_patterns = quant_rules.get(\"awq_w4a16\", [])\n        for pattern in awq_patterns:\n            if pattern in prefix:\n                from ..nunchaku_linear import NunchakuAWQLinearMethod\n\n                return NunchakuAWQLinearMethod(group_size=64)\n\n        svdq_patterns = quant_rules.get(\"svdq_w4a4\", [])\n        for pattern in svdq_patterns:\n            if pattern in prefix:\n                from ..nunchaku_linear import NunchakuSVDQLinearMethod\n\n                return NunchakuSVDQLinearMethod(\n                    precision=self.precision,\n                    rank=self.rank,\n                    act_unsigned=self.act_unsigned,\n                )\n\n        # default: apply svdq_w4a4 to all remaining linear layers\n        from ..nunchaku_linear import NunchakuSVDQLinearMethod\n\n        return NunchakuSVDQLinearMethod(\n            precision=self.precision,\n          
  rank=self.rank,\n            act_unsigned=self.act_unsigned,\n        )\n\n    def _get_quant_rules(self) -> dict[str, list[str]]:\n        if self.model_cls is not None and hasattr(\n            self.model_cls, \"get_nunchaku_quant_rules\"\n        ):\n            return self.model_cls.get_nunchaku_quant_rules()\n        return {}\n\n    def __post_init__(self):\n        if self.group_size is None:\n            if self.precision == \"nvfp4\":\n                self.group_size = 16\n            elif self.precision == \"int4\":\n                self.group_size = 64\n            else:\n                raise ValueError(\n                    f\"Invalid precision: {self.precision}. Must be 'int4' or 'nvfp4'\"\n                )\n\n        if self.precision not in [\"int4\", \"nvfp4\"]:\n            raise ValueError(\n                f\"Invalid precision: {self.precision}. Must be 'int4' or 'nvfp4'\"\n            )\n\n        if self.rank <= 0:\n            raise ValueError(f\"Rank must be positive, got {self.rank}\")\n\n    @classmethod\n    def from_dict(cls, config_dict: dict) -> \"NunchakuConfig\":\n        \"\"\"Create configuration from dictionary.\"\"\"\n        return cls(**config_dict)\n\n    def to_dict(self) -> dict:\n        \"\"\"Convert configuration to dictionary.\"\"\"\n        return {\n            \"precision\": self.precision,\n            \"rank\": self.rank,\n            \"group_size\": self.group_size,\n            \"act_unsigned\": self.act_unsigned,\n            \"transformer_weights_path\": self.transformer_weights_path,\n        }\n\n    @classmethod\n    def from_pretrained(cls, model_path: str) -> Optional[\"NunchakuConfig\"]:\n        for filename in cls.get_config_filenames():\n            config_path = os.path.join(model_path, filename)\n            if os.path.exists(config_path):\n                with open(config_path, \"r\") as f:\n                    config_dict = json.load(f)\n                if config_dict.get(\"quant_method\") == 
cls.get_name():\n                    return cls.from_config(config_dict)\n        return None\n\n\ndef _patch_native_svdq_linear(\n    module: nn.Module, tensor: Any, svdq_linear_cls: type\n) -> bool:\n    if (\n        isinstance(module, svdq_linear_cls)\n        and getattr(module, \"wtscale\", None) is not None\n    ):\n        module.wtscale = tensor\n        return True\n    return False\n\n\ndef _patch_sglang_svdq_linear(\n    module: nn.Module, tensor: Any, svdq_method_cls: type\n) -> bool:\n    quant_method = getattr(module, \"quant_method\", None)\n    if not isinstance(quant_method, svdq_method_cls):\n        return False\n\n    existing = getattr(module, \"wtscale\", None)\n    if isinstance(existing, nn.Parameter):\n        with torch.no_grad():\n            existing.data.copy_(tensor.to(existing.data.dtype))\n    else:\n        module.wtscale = tensor\n\n    # Keep alpha in sync (kernel reads `layer._nunchaku_alpha`)\n    try:\n        module._nunchaku_alpha = float(tensor.detach().cpu().item())\n    except Exception:\n        module._nunchaku_alpha = None\n    return True\n\n\ndef _patch_sglang_svdq_wcscales(\n    module: nn.Module, tensor: Any, svdq_method_cls: type\n) -> bool:\n    quant_method = getattr(module, \"quant_method\", None)\n    if not isinstance(quant_method, svdq_method_cls):\n        return False\n\n    existing = getattr(module, \"wcscales\", None)\n    if isinstance(existing, nn.Parameter):\n        with torch.no_grad():\n            existing.data.copy_(tensor.to(existing.data.dtype))\n    else:\n        module.wcscales = tensor\n    return True\n\n\ndef _patch_nunchaku_scales(\n    model: nn.Module,\n    safetensors_list: list[str],\n) -> None:\n    \"\"\"Patch transformer module with Nunchaku scale tensors from safetensors weights.\n\n    For NVFP4 checkpoints, correctness depends on `wtscale` and attention\n    `wcscales`. 
The FSDP loader may skip some of these metadata tensors.\n    \"\"\"\n\n    if not safetensors_list:\n        return\n\n    if len(safetensors_list) != 1:\n        logger.warning(\n            \"Nunchaku scale patch expects a single safetensors file, \"\n            \"but got %d files. Skipping.\",\n            len(safetensors_list),\n        )\n        return\n\n    from nunchaku.models.linear import SVDQW4A4Linear  # type: ignore[import]\n\n    state_dict = safetensors_load_file(safetensors_list[0])\n    if state_dict is None:\n        return\n\n    num_wtscale = 0\n    num_wcscales = 0\n\n    from ..nunchaku_linear import NunchakuSVDQLinearMethod\n\n    for name, module in model.named_modules():\n        wt = state_dict.get(f\"{name}.wtscale\")\n        if wt is not None:\n            if _patch_native_svdq_linear(module, wt, SVDQW4A4Linear):\n                num_wtscale += 1\n            elif _patch_sglang_svdq_linear(module, wt, NunchakuSVDQLinearMethod):\n                num_wtscale += 1\n\n        wc = state_dict.get(f\"{name}.wcscales\")\n        if wc is not None:\n            # Some modules may have wcscales as a direct attribute/Parameter.\n            existing = getattr(module, \"wcscales\", None)\n            if isinstance(existing, nn.Parameter):\n                with torch.no_grad():\n                    existing.data.copy_(wc.to(existing.data.dtype))\n                num_wcscales += 1\n            elif existing is not None:\n                setattr(module, \"wcscales\", wc)\n                num_wcscales += 1\n            elif _patch_sglang_svdq_wcscales(module, wc, NunchakuSVDQLinearMethod):\n                num_wcscales += 1\n\n    if num_wtscale > 0:\n        logger.info(\"Patched wtscale for %d layers\", num_wtscale)\n    if num_wcscales > 0:\n        logger.info(\"Patched wcscales for %d layers\", num_wcscales)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/quantization/fp8.py",
    "content": "from __future__ import annotations\n\nimport logging\nfrom typing import TYPE_CHECKING, Any, Dict, List, Optional, Union\n\nimport torch\nfrom torch.nn import Module\nfrom torch.nn.parameter import Parameter\n\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_tensor_model_parallel_world_size,\n)\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    LinearMethodBase,\n    UnquantizedLinearMethod,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n    QuantizeMethodBase,\n)\nfrom sglang.multimodal_gen.runtime.models.parameter import (\n    BlockQuantScaleParameter,\n    ModelWeightParameter,\n    PerTensorScaleParameter,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.common import (\n    cpu_has_amx_support,\n    get_bool_env_var,\n    use_intel_amx_backend,\n)\nfrom sglang.srt.layers.amx_utils import _amx_process_weight_after_loading\nfrom sglang.srt.layers.quantization.fp8_kernel import (\n    is_fp8_fnuz,\n    per_token_group_quant_fp8,\n)\nfrom sglang.srt.layers.quantization.fp8_utils import (\n    apply_fp8_linear,\n    can_auto_enable_marlin_fp8,\n    cutlass_fp8_supported,\n    dispatch_w8a8_block_fp8_linear,\n    input_to_float8,\n    normalize_e4m3fn_to_e4m3fnuz,\n    requant_weight_ue8m0_inplace,\n)\nfrom sglang.srt.layers.quantization.marlin_utils_fp8 import (\n    apply_fp8_marlin_linear,\n    prepare_fp8_layer_for_marlin,\n)\nfrom sglang.srt.layers.quantization.utils import (\n    convert_to_channelwise,\n    is_layer_skipped,\n    requantize_with_max_scale,\n)\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config\n\n_is_hip = current_platform.is_hip()\n_is_cuda = current_platform.is_cuda()\n_is_npu = current_platform.is_npu()\n_is_cpu_amx_available = cpu_has_amx_support()\n_is_cpu = current_platform.is_cpu()\n_is_fp8_fnuz = 
is_fp8_fnuz()\n_use_hip_int4 = get_bool_env_var(\"SGLANG_INT4_WEIGHT\") and _is_hip\n_use_aiter = get_bool_env_var(\"SGLANG_USE_AITER\") and _is_hip\n\nif _use_aiter or _use_hip_int4:\n    pass\n\n\nACTIVATION_SCHEMES = [\"static\", \"dynamic\"]\n\nlogger = logging.getLogger(__name__)\n\n\nclass Fp8Config(QuantizationConfig):\n    \"\"\"Config class for FP8.\"\"\"\n\n    def __init__(\n        self,\n        is_checkpoint_fp8_serialized: bool = False,\n        activation_scheme: str = \"dynamic\",\n        ignored_layers: Optional[List[str]] = None,\n        weight_block_size: List[int] = None,\n    ) -> None:\n        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized\n        if is_checkpoint_fp8_serialized:\n            logger.info(\"Detected fp8 checkpoint.\")\n        if activation_scheme not in ACTIVATION_SCHEMES:\n            raise ValueError(f\"Unsupported activation scheme {activation_scheme}\")\n        self.activation_scheme = activation_scheme\n        self.ignored_layers = ignored_layers or []\n        if weight_block_size is not None:\n            if not is_checkpoint_fp8_serialized:\n                raise ValueError(\n                    f\"The block-wise quantization only supports fp8-serialized checkpoint for now.\"\n                )\n            if len(weight_block_size) != 2:\n                raise ValueError(\n                    f\"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions.\"\n                )\n            if activation_scheme != \"dynamic\":\n                raise ValueError(\n                    f\"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme.\"\n                )\n        self.weight_block_size = weight_block_size\n\n    @classmethod\n    def get_name(cls) -> str:\n        return \"fp8\"\n\n    @classmethod\n    def get_supported_act_dtypes(cls) -> List[torch.dtype]:\n        return 
[torch.bfloat16, torch.half]\n\n    @classmethod\n    def get_min_capability(cls) -> int:\n        return 80\n\n    @classmethod\n    def get_config_filenames(cls) -> List[str]:\n        return []\n\n    @classmethod\n    def from_config(cls, config: Dict[str, Any]) -> Fp8Config:\n        quant_method = cls.get_from_keys(config, [\"quant_method\"])\n        is_checkpoint_fp8_serialized = \"fp8\" in quant_method\n        activation_scheme = cls.get_from_keys(config, [\"activation_scheme\"])\n        ignored_layers = cls.get_from_keys_or(\n            config, [\"ignored_layers\", \"modules_to_not_convert\"], None\n        )\n        if ignored_layers:\n            # hacking ministral\n            ignored_layers = [layer.replace(\"model.\", \"\") for layer in ignored_layers]\n        weight_block_size = cls.get_from_keys_or(config, [\"weight_block_size\"], None)\n        return cls(\n            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,\n            activation_scheme=activation_scheme,\n            ignored_layers=ignored_layers,\n            weight_block_size=weight_block_size,\n        )\n\n    def get_quant_method(\n        self, layer: torch.nn.Module, prefix: str\n    ) -> Optional[QuantizeMethodBase]:\n        from sglang.multimodal_gen.runtime.layers.linear import LinearBase\n\n        if isinstance(layer, LinearBase):\n            if is_layer_skipped(prefix, self.ignored_layers):\n                return UnquantizedLinearMethod()\n            return Fp8LinearMethod(self)\n        return None\n\n    def get_scaled_act_names(self) -> List[str]:\n        return []\n\n\nclass Fp8LinearMethod(LinearMethodBase):\n    \"\"\"Linear method for FP8.\n    Supports loading FP8 checkpoints with static weight scale and\n    dynamic/static activation scale.\n\n    Also supports loading quantized FP16/BF16 model checkpoints with dynamic\n    activation scaling. 
The weight scaling factor will be initialized after\n    the model weights are loaded.\n\n    Limitations:\n    1. Only support per-tensor quantization due to torch._scaled_mm support.\n    2. Only support float8_e4m3fn data type due to the limitation of\n       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)\n\n    Args:\n        quant_config: The quantization config.\n    \"\"\"\n\n    def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]):\n        self.quant_config = quant_config\n        self.cutlass_fp8_supported = cutlass_fp8_supported()\n\n        # For GPUs that lack FP8 hardware support, we can leverage the Marlin\n        # kernel for fast weight-only FP8 quantization\n        self.use_marlin = False\n        if _is_cuda:\n            force_marlin = get_bool_env_var(\"SGLANG_FORCE_FP8_MARLIN\")\n            auto_enable = can_auto_enable_marlin_fp8()\n            self.use_marlin = force_marlin or auto_enable\n\n        self.block_quant = self.quant_config.weight_block_size is not None\n\n        self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()\n\n    def create_weights(\n        self,\n        layer: torch.nn.Module,\n        input_size_per_partition: int,\n        output_partition_sizes: List[int],\n        input_size: int,\n        output_size: int,\n        params_dtype: torch.dtype,\n        **extra_weight_attrs,\n    ):\n        output_size_per_partition = sum(output_partition_sizes)\n        weight_loader = extra_weight_attrs.get(\"weight_loader\")\n\n        tp_size = get_tensor_model_parallel_world_size()\n        if self.block_quant:\n            block_n, block_k = (\n                self.quant_config.weight_block_size[0],\n                self.quant_config.weight_block_size[1],\n            )\n            # Required by row parallel\n            if tp_size > 1 and input_size // input_size_per_partition == tp_size:\n              
  if input_size_per_partition % block_k != 0:\n                    raise ValueError(\n                        f\"Weight input_size_per_partition = \"\n                        f\"{input_size_per_partition} is not divisible by \"\n                        f\"weight quantization block_k = {block_k}.\"\n                    )\n            # Required by column parallel or enabling merged weights\n            if (\n                tp_size > 1 and output_size // output_size_per_partition == tp_size\n            ) or len(output_partition_sizes) > 1:\n                for output_partition_size in output_partition_sizes:\n                    if output_partition_size % block_n != 0:\n                        raise ValueError(\n                            f\"Weight output_partition_size = \"\n                            f\"{output_partition_size} is not divisible by \"\n                            f\"weight quantization block_n = {block_n}.\"\n                        )\n\n        layer.logical_widths = output_partition_sizes\n        layer.input_size_per_partition = input_size_per_partition\n        layer.output_size_per_partition = output_size_per_partition\n        layer.orig_dtype = params_dtype\n\n        # WEIGHT\n        weight_dtype = (\n            torch.float8_e4m3fn\n            if self.quant_config.is_checkpoint_fp8_serialized\n            else params_dtype\n        )\n\n        weight = ModelWeightParameter(\n            data=torch.empty(\n                output_size_per_partition, input_size_per_partition, dtype=weight_dtype\n            ),\n            input_dim=1,\n            output_dim=0,\n            weight_loader=weight_loader,\n        )\n        layer.register_parameter(\"weight\", weight)\n\n        # If checkpoint is serialized fp8, load them.\n        # Otherwise, wait until process_weights_after_loading.\n        if self.quant_config.is_checkpoint_fp8_serialized:\n            # WEIGHT SCALE\n            if self.block_quant:\n                if 
hasattr(self.quant_config, \"activation_scheme\"):\n                    assert self.quant_config.activation_scheme == \"dynamic\"\n                elif hasattr(self.quant_config, \"linear_activation_scheme\"):\n                    assert self.quant_config.linear_activation_scheme == \"dynamic\"\n                scale = BlockQuantScaleParameter(\n                    data=torch.empty(\n                        (output_size_per_partition + block_n - 1) // block_n,\n                        (input_size_per_partition + block_k - 1) // block_k,\n                        dtype=torch.float32,\n                    ),\n                    input_dim=1,\n                    output_dim=0,\n                    weight_loader=weight_loader,\n                )\n                scale.format_ue8m0 = False\n                scale[:] = torch.finfo(torch.float32).min\n                layer.register_parameter(\"weight_scale_inv\", scale)\n            else:\n                scale = PerTensorScaleParameter(\n                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),\n                    weight_loader=weight_loader,\n                )\n                scale[:] = torch.finfo(torch.float32).min\n                layer.register_parameter(\"weight_scale\", scale)\n\n            # INPUT ACTIVATION SCALE\n            if (\n                hasattr(self.quant_config, \"activation_scheme\")\n                and self.quant_config.activation_scheme == \"static\"\n            ) or (\n                hasattr(self.quant_config, \"linear_activation_scheme\")\n                and self.quant_config.linear_activation_scheme == \"static\"\n            ):\n                scale = PerTensorScaleParameter(\n                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),\n                    weight_loader=weight_loader,\n                )\n\n                scale[:] = torch.finfo(torch.float32).min\n                layer.register_parameter(\"input_scale\", scale)\n  
          else:\n                layer.register_parameter(\"input_scale\", None)\n\n    def process_weights_after_loading(self, layer: Module) -> None:\n        if self.block_quant:\n            # If ROCm, normalize the weights and scales to e4m3fnuz\n            if _is_fp8_fnuz:\n                # activation_scheme: dynamic\n                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(\n                    weight=layer.weight,\n                    weight_scale=layer.weight_scale_inv,\n                    input_scale=None,\n                )\n                layer.input_scale = None\n            elif _is_cpu:\n                assert (\n                    _is_cpu_amx_available\n                ), \"Fp8LinearMethod on CPU requires that CPU has AMX support\"\n                _amx_process_weight_after_loading(layer, [\"weight\"])\n                layer.weight_scale_inv = torch.nn.Parameter(\n                    layer.weight_scale_inv.data, requires_grad=False\n                )\n                return\n            else:\n                # For fp8 linear weights run with deepgemm, the weights and scales need to be requantized to ue8m0\n                from sglang.srt.layers.quantization.fp8_utils import (\n                    deepgemm_w8a8_block_fp8_linear_with_fallback,\n                )\n                from sglang.srt.model_loader.utils import (\n                    should_deepgemm_weight_requant_ue8m0,\n                )\n\n                if (\n                    should_deepgemm_weight_requant_ue8m0(\n                        weight_block_size=getattr(\n                            self.quant_config, \"weight_block_size\", None\n                        ),\n                    )\n                    and (\n                        self.w8a8_block_fp8_linear\n                        is deepgemm_w8a8_block_fp8_linear_with_fallback\n                    )\n                    and (not layer.weight_scale_inv.format_ue8m0)\n                ):\n                    
requant_weight_ue8m0_inplace(\n                        layer.weight,\n                        layer.weight_scale_inv,\n                        self.quant_config.weight_block_size,\n                    )\n                    layer.weight_scale_inv.format_ue8m0 = True\n                weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data\n\n            layer.weight.data = weight.data\n            layer.weight_scale_inv.data = weight_scale.data\n        else:\n            layer.weight = Parameter(layer.weight.data, requires_grad=False)\n\n            # If checkpoint not serialized fp8, quantize the weights.\n            if not self.quant_config.is_checkpoint_fp8_serialized:\n                if self.cutlass_fp8_supported or self.use_marlin:\n                    # apply per-channel quantization default as\n                    # cutlass sgl-kernel and marlin only support per-channel scale\n                    qweight, weight_scale = per_token_group_quant_fp8(\n                        layer.weight, layer.weight.shape[-1]\n                    )\n                    weight_scale = weight_scale.t().contiguous()\n                else:\n                    # per-tensor quantization\n                    qweight, weight_scale = input_to_float8(layer.weight)\n\n                # Update the layer with the new values.\n                layer.weight = Parameter(qweight.t(), requires_grad=False)\n                layer.weight_scale = Parameter(weight_scale, requires_grad=False)\n                layer.input_scale = None\n\n            # If checkpoint is fp8, handle that there are N scales for N\n            # shards in a fused module\n            else:\n                layer.weight_scale = Parameter(\n                    layer.weight_scale.data, requires_grad=False\n                )\n                if (\n                    hasattr(self.quant_config, \"activation_scheme\")\n                    and self.quant_config.activation_scheme == \"static\"\n                ) or 
(\n                    hasattr(self.quant_config, \"linear_activation_scheme\")\n                    and self.quant_config.linear_activation_scheme == \"static\"\n                ):\n                    layer.input_scale = Parameter(\n                        layer.input_scale.data, requires_grad=False\n                    )\n\n                # cutlass sgl-kernel and marlin only support per-channel scale\n                if self.cutlass_fp8_supported or self.use_marlin:\n                    weight = layer.weight\n                    weight_scale = convert_to_channelwise(\n                        layer.weight_scale, layer.logical_widths\n                    )\n                else:\n                    # Dequant -> Quant with max scale so we can run per tensor.\n                    weight = layer.weight\n                    weight_scale = layer.weight_scale\n                    # If ROCm, normalize the weights and scales to e4m3fnuz\n                    if _is_fp8_fnuz:\n                        weight, weight_scale, input_scale = (\n                            normalize_e4m3fn_to_e4m3fnuz(\n                                weight=weight,\n                                weight_scale=weight_scale,\n                                input_scale=layer.input_scale,\n                            )\n                        )\n                        if input_scale is not None:\n                            layer.input_scale = Parameter(\n                                input_scale, requires_grad=False\n                            )\n\n                    weight_scale, weight = requantize_with_max_scale(\n                        weight=weight,\n                        weight_scale=weight_scale,\n                        logical_widths=layer.logical_widths,\n                    )\n\n                # Update layer with new values.\n                layer.weight = Parameter(weight.t(), requires_grad=False)\n                layer.weight_scale = Parameter(weight_scale, 
requires_grad=False)\n                if (\n                    hasattr(self.quant_config, \"activation_scheme\")\n                    and self.quant_config.activation_scheme == \"static\"\n                ) or (\n                    hasattr(self.quant_config, \"linear_activation_scheme\")\n                    and self.quant_config.linear_activation_scheme == \"static\"\n                ):\n                    layer.input_scale = Parameter(\n                        layer.input_scale.max(), requires_grad=False\n                    )\n\n        if self.use_marlin:\n            if self.block_quant:\n                layer.weight_block_size = self.quant_config.weight_block_size\n            prepare_fp8_layer_for_marlin(layer, not self.block_quant)\n            # Activations not quantized for marlin.\n            del layer.input_scale\n\n    def apply(\n        self,\n        layer: torch.nn.Module,\n        x: torch.Tensor,\n        bias: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if self.use_marlin:\n            return apply_fp8_marlin_linear(\n                input=x,\n                weight=layer.weight,\n                weight_scale=layer.weight_scale,\n                workspace=layer.workspace,\n                size_n=layer.output_size_per_partition,\n                size_k=layer.input_size_per_partition,\n                bias=bias,\n            )\n\n        if self.block_quant:\n            if use_intel_amx_backend(layer):\n                return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(\n                    x,\n                    layer.weight,\n                    layer.weight_scale_inv,\n                    self.quant_config.weight_block_size,\n                    bias,\n                    x.dtype,\n                    True,  # is_vnni\n                )\n\n            if isinstance(x, tuple):\n                return self.w8a8_block_fp8_linear(\n                    input=x[0],\n                    weight=layer.weight,\n                    
block_size=self.quant_config.weight_block_size,\n                    weight_scale=layer.weight_scale_inv,\n                    input_scale=x[1],\n                    bias=bias,\n                )\n\n            return self.w8a8_block_fp8_linear(\n                input=x,\n                weight=layer.weight,\n                block_size=self.quant_config.weight_block_size,\n                weight_scale=layer.weight_scale_inv,\n                input_scale=None,\n                bias=bias,\n            )\n\n        return apply_fp8_linear(\n            input=x,\n            weight=layer.weight,\n            weight_scale=layer.weight_scale,\n            input_scale=layer.input_scale,\n            bias=bias,\n            cutlass_fp8_supported=self.cutlass_fp8_supported,\n            use_per_token_if_dynamic=False,\n        )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py",
    "content": "from __future__ import annotations\n\nimport logging\nfrom types import MappingProxyType\nfrom typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, cast\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    LinearMethodBase,\n    UnquantizedLinearMethod,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n    QuantizeMethodBase,\n)\nfrom sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer\nfrom sglang.srt.layers.quantization.modelslim.schemes import (\n    ModelSlimW4A4Int4,\n    ModelSlimW8A8Int8,\n)\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.quantization.modelslim.modelslim import ModelSlimConfig\n    from sglang.srt.layers.quantization.modelslim.schemes import (\n        ModelSlimLinearScheme,\n    )\n\nlogger = logging.getLogger(__name__)\n\n\nclass ModelSlimConfig(QuantizationConfig):\n    \"\"\"\n    Config class for ModelSlim Quantization of Diffusion models https://gitcode.com/Ascend/msmodelslim, an NPU-specific quantization type.\n    The quantization method (W8A8, W4A4, etc.) 
will be automatically parsed from the `quant_model_description.json` config.\n\n    ModelSlim for Diffusion models includes support for various quantization schemes, such as:\n    - W4A4 dynamic linear\n    - W8A8 static linear\n    - W8A8 dynamic linear\n    \"\"\"\n\n    def __init__(self, quant_config: Dict[str, Any] = {}):\n        super().__init__()\n        self.quant_description = quant_config\n        ignore = cast(List[str], quant_config.get(\"ignore\", []))\n        self.ignore = ignore\n        packed_modules_mapping = quant_config.get(\"packed_modules_mapping\", {})\n        self.packed_modules_mapping = (\n            packed_modules_mapping if packed_modules_mapping is not None else {}\n        )\n\n    def get_linear_method(self) -> ModelSlimLinearMethod:\n        return ModelSlimLinearMethod(self)\n\n    @classmethod\n    def get_supported_act_dtypes(cls) -> List[torch.dtype]:\n        return [torch.int8, torch.float16, torch.bfloat16]\n\n    @classmethod\n    def get_min_capability(cls) -> int:\n        return 0\n\n    @classmethod\n    def get_name(cls) -> str:\n        return \"modelslim\"\n\n    @classmethod\n    def get_config_filenames(cls) -> List[str]:\n        filenames = [\"quant_model_description.json\"]\n        return filenames\n\n    @classmethod\n    def from_config(cls, config: Dict[str, Any]) -> ModelSlimConfig:\n        return cls(config)\n\n    def get_quant_method(\n        self,\n        layer: torch.nn.Module,\n        prefix: str,\n    ) -> Optional[QuantizeMethodBase]:\n        from sglang.multimodal_gen.runtime.layers.linear import LinearBase\n\n        if isinstance(layer, LinearBase):\n            if should_ignore_layer(\n                prefix,\n                ignore=self.ignore,\n                fused_mapping=self.packed_modules_mapping,\n            ):\n                return UnquantizedLinearMethod()\n            key = \"model\"\n            packed_modules_mapping_subset = self.packed_modules_mapping.get(key, {})\n     
       prefix_in_quant_config = prefix\n            proj_name = prefix.split(\".\")[-1]\n            if proj_name in packed_modules_mapping_subset:\n                prefix_in_quant_config = prefix.replace(\n                    proj_name, packed_modules_mapping_subset[proj_name][0]\n                )\n\n            if self.is_layer_skipped(prefix, packed_modules_mapping_subset):\n                return UnquantizedLinearMethod()\n            scheme = self.get_scheme(layer=layer, layer_name=prefix_in_quant_config)\n            layer.scheme = scheme\n            return ModelSlimLinearMethod(self)\n        else:\n            return None\n\n    def _get_scheme_from_parts(\n        self,\n        layer_name: str,\n    ) -> ModelSlimLinearScheme:\n\n        quant_type = self.quant_description.get(layer_name + \".weight\", \"\")\n        if quant_type == \"W8A8_DYNAMIC\" or quant_type == \"W8A8\":\n            return ModelSlimW8A8Int8(\n                quant_config=self.quant_description, prefix=layer_name\n            )\n        elif quant_type == \"W4A4_DYNAMIC\":\n            return ModelSlimW4A4Int4(\n                quant_config=self.quant_description, prefix=layer_name\n            )\n        raise NotImplementedError(\"No modelslim compatible scheme was found.\")\n\n    def get_scheme(\n        self, layer: torch.nn.Module, layer_name: Optional[str] = None\n    ) -> Optional[ModelSlimLinearScheme]:\n        \"\"\"\n        get_scheme method adjusted for modelslim, taken from\n        python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py\n        \"\"\"\n        scheme = self._get_scheme_from_parts(\n            layer_name=layer_name,\n        )\n\n        # Ascend doesn't support device capability\n        logger.debug(\"Using scheme: %s for %s\", scheme.__class__.__name__, layer_name)\n        return scheme\n\n    def is_layer_skipped(\n        self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})\n    ):\n        
# adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped\n        proj_name = prefix.split(\".\")[-1]\n        if proj_name in fused_mapping:\n            shard_prefixes = [\n                prefix.replace(proj_name, shard_proj_name)\n                for shard_proj_name in fused_mapping[proj_name]\n            ]\n\n            is_skipped = None\n            for shard_prefix in shard_prefixes:\n                is_shard_skipped = (\n                    self.quant_description.get(shard_prefix + \".weight\", \"\") == \"FLOAT\"\n                )\n\n                if is_skipped is None:\n                    is_skipped = is_shard_skipped\n                elif is_shard_skipped != is_skipped:\n                    raise ValueError(\n                        f\"Detected some but not all shards of {prefix} \"\n                        \"are quantized. All shards of fused layers \"\n                        \"must have the same precision.\"\n                    )\n        else:\n            is_skipped = self.quant_description.get(prefix + \".weight\", \"\") == \"FLOAT\"\n\n        assert is_skipped is not None\n        return is_skipped\n\n    def get_scaled_act_names(self) -> List[str]:\n        return []\n\n\nclass ModelSlimLinearMethod(LinearMethodBase):\n\n    def __init__(self, quantization_config: ModelSlimConfig):\n        self.quantization_config = quantization_config\n\n    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:\n        layer.scheme.process_weights_after_loading(layer)\n\n    def create_weights(\n        self,\n        layer: torch.nn.Module,\n        input_size_per_partition: int,\n        output_partition_sizes: List[int],\n        input_size: int,\n        output_size: int,\n        params_dtype: torch.dtype,\n        **extra_weight_attrs,\n    ):\n        \"\"\"\n        Use the ModelSlimLinearScheme associated with each layer to create\n        the necessary parameters for the layer. 
See LinearMethodBase for param\n        details.\n        \"\"\"\n        weight_loader = extra_weight_attrs.get(\"weight_loader\")\n        layer.scheme.create_weights(\n            layer=layer,\n            input_size=input_size,\n            input_size_per_partition=input_size_per_partition,\n            output_partition_sizes=output_partition_sizes,\n            output_size=output_size,\n            params_dtype=params_dtype,\n            weight_loader=weight_loader,\n        )\n\n    def apply(\n        self,\n        layer: torch.nn.Module,\n        x: torch.Tensor,\n        bias: Optional[torch.Tensor] = None,\n    ):\n        \"\"\"\n        Use the output of create_weights and the ModelSlimLinearScheme\n        associated with the layer to apply the forward pass with the\n        layer input. See LinearMethodBase for param details.\n        \"\"\"\n\n        scheme = layer.scheme\n        if scheme is None:\n            raise ValueError(\"A scheme must be defined for each layer\")\n        return scheme.apply_weights(layer, x, bias=bias)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/quantization/nunchaku_linear.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nfrom typing import List, Optional\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn.parameter import Parameter\n\nfrom sglang.multimodal_gen.runtime.layers.linear import LinearMethodBase\nfrom sglang.multimodal_gen.runtime.models.utils import set_weight_attrs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\ntry:\n    from nunchaku.ops.gemm import svdq_gemm_w4a4_cuda\n    from nunchaku.ops.gemv import awq_gemv_w4a16_cuda\n    from nunchaku.ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda\nexcept ImportError:\n    svdq_gemm_w4a4_cuda = None\n    awq_gemv_w4a16_cuda = None\n    svdq_quantize_w4a4_act_fuse_lora_cuda = None\n\n\nclass NunchakuSVDQLinearMethod(LinearMethodBase):\n    def __init__(\n        self,\n        precision: str = \"int4\",\n        rank: int = 32,\n        act_unsigned: bool = False,\n    ):\n        self.precision = precision\n        self.rank = rank\n        self.act_unsigned = act_unsigned\n\n        if precision == \"nvfp4\":\n            self.group_size = 16\n        else:\n            self.group_size = 64\n\n    def create_weights(\n        self,\n        layer: torch.nn.Module,\n        input_size_per_partition: int,\n        output_partition_sizes: List[int],\n        input_size: int,\n        output_size: int,\n        params_dtype: torch.dtype,\n        **extra_weight_attrs,\n    ) -> None:\n        output_size_per_partition = sum(output_partition_sizes)\n\n        qweight = Parameter(\n            torch.empty(\n                output_size_per_partition,\n                input_size_per_partition // 2,\n                dtype=torch.int8,\n            ),\n            requires_grad=False,\n        )\n        set_weight_attrs(qweight, {\"input_dim\": 1, \"output_dim\": 0})\n\n        num_groups = input_size_per_partition // self.group_size\n        if self.precision == \"nvfp4\":\n            scale_dtype = 
torch.float8_e4m3fn\n        else:\n            scale_dtype = params_dtype\n        wscales = Parameter(\n            torch.empty(num_groups, output_size_per_partition, dtype=scale_dtype),\n            requires_grad=False,\n        )\n\n        smooth_factor = Parameter(\n            torch.empty(input_size_per_partition, dtype=params_dtype),\n            requires_grad=False,\n        )\n\n        smooth_factor_orig = Parameter(\n            torch.empty(input_size_per_partition, dtype=params_dtype),\n            requires_grad=False,\n        )\n\n        proj_down = Parameter(\n            torch.empty(input_size_per_partition, self.rank, dtype=params_dtype),\n            requires_grad=False,\n        )\n        proj_up = Parameter(\n            torch.empty(output_size_per_partition, self.rank, dtype=params_dtype),\n            requires_grad=False,\n        )\n\n        if self.precision == \"nvfp4\":\n            wcscales = Parameter(\n                torch.empty(\n                    output_size_per_partition,\n                    dtype=params_dtype,\n                ),\n                requires_grad=False,\n            )\n            wtscale = Parameter(\n                torch.empty(1, dtype=params_dtype),\n                requires_grad=False,\n            )\n        else:\n            wcscales = None\n            wtscale = None\n\n        layer.register_parameter(\"qweight\", qweight)\n        layer.register_parameter(\"wscales\", wscales)\n        layer.register_parameter(\"smooth_factor\", smooth_factor)\n        layer.register_parameter(\"smooth_factor_orig\", smooth_factor_orig)\n        layer.register_parameter(\"proj_down\", proj_down)\n        layer.register_parameter(\"proj_up\", proj_up)\n        if wcscales is not None:\n            layer.register_parameter(\"wcscales\", wcscales)\n        if wtscale is not None:\n            layer.register_parameter(\"wtscale\", wtscale)\n\n        layer.input_size_per_partition = input_size_per_partition\n        
layer.output_size_per_partition = output_size_per_partition\n        layer.precision = self.precision\n        layer.rank = self.rank\n        layer.group_size = self.group_size\n        layer.act_unsigned = self.act_unsigned\n\n        weight_loader = extra_weight_attrs.get(\"weight_loader\")\n        if weight_loader is not None:\n            set_weight_attrs(qweight, {\"weight_loader\": weight_loader})\n            set_weight_attrs(wscales, {\"weight_loader\": weight_loader})\n            set_weight_attrs(smooth_factor, {\"weight_loader\": weight_loader})\n            set_weight_attrs(smooth_factor_orig, {\"weight_loader\": weight_loader})\n            set_weight_attrs(proj_down, {\"weight_loader\": weight_loader})\n            set_weight_attrs(proj_up, {\"weight_loader\": weight_loader})\n            if wcscales is not None:\n                set_weight_attrs(wcscales, {\"weight_loader\": weight_loader})\n            if wtscale is not None:\n                set_weight_attrs(wtscale, {\"weight_loader\": weight_loader})\n\n    def process_weights_after_loading(self, layer: nn.Module) -> None:\n        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)\n        layer.wscales = Parameter(layer.wscales.data, requires_grad=False)\n        layer.smooth_factor = Parameter(layer.smooth_factor.data, requires_grad=False)\n        layer.smooth_factor_orig = Parameter(\n            layer.smooth_factor_orig.data, requires_grad=False\n        )\n        layer.proj_down = Parameter(layer.proj_down.data, requires_grad=False)\n        layer.proj_up = Parameter(layer.proj_up.data, requires_grad=False)\n        if hasattr(layer, \"wcscales\") and layer.wcscales is not None:\n            layer.wcscales = Parameter(layer.wcscales.data, requires_grad=False)\n        if hasattr(layer, \"wtscale\") and layer.wtscale is not None:\n            layer.wtscale = Parameter(layer.wtscale.data, requires_grad=False)\n\n        alpha: float | None = None\n        wtscale = 
getattr(layer, \"wtscale\", None)\n        if wtscale is not None:\n            if isinstance(wtscale, Parameter):\n                wtscale = wtscale.data\n            if isinstance(wtscale, torch.Tensor):\n                alpha = float(wtscale.detach().cpu().item())\n            else:\n                alpha = float(wtscale)\n        layer._nunchaku_alpha = alpha\n\n    def apply(\n        self,\n        layer: torch.nn.Module,\n        x: torch.Tensor,\n        bias: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n\n        orig_shape = x.shape\n        x_2d = x.reshape(-1, orig_shape[-1])\n        quantized_x, ascales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora_cuda(\n            x_2d,\n            lora_down=layer.proj_down,\n            smooth=layer.smooth_factor,\n            fp4=layer.precision == \"nvfp4\",\n            pad_size=256,\n        )\n        out_2d = torch.empty(\n            x_2d.shape[0],\n            layer.output_size_per_partition,\n            dtype=x_2d.dtype,\n            device=x_2d.device,\n        )\n        alpha: float | None = getattr(layer, \"_nunchaku_alpha\", None)\n        wcscales = getattr(layer, \"wcscales\", None)\n\n        svdq_gemm_w4a4_cuda(\n            act=quantized_x,\n            wgt=layer.qweight,\n            out=out_2d,\n            ascales=ascales,\n            wscales=layer.wscales,\n            lora_act_in=lora_act_out,\n            lora_up=layer.proj_up,\n            bias=bias,\n            fp4=layer.precision == \"nvfp4\",\n            alpha=alpha,\n            wcscales=wcscales,\n            act_unsigned=getattr(layer, \"act_unsigned\", False),\n        )\n        out = out_2d.reshape(*orig_shape[:-1], layer.output_size_per_partition)\n        return out\n\n\nclass NunchakuAWQLinearMethod(LinearMethodBase):\n    def __init__(self, group_size: int = 64):\n        self.group_size = group_size\n        self.pack_factor = 8\n\n    def create_weights(\n        self,\n        layer: torch.nn.Module,\n   
     input_size_per_partition: int,\n        output_partition_sizes: List[int],\n        input_size: int,\n        output_size: int,\n        params_dtype: torch.dtype,\n        **extra_weight_attrs,\n    ) -> None:\n        output_size_per_partition = sum(output_partition_sizes)\n\n        qweight = Parameter(\n            torch.empty(\n                output_size_per_partition // 4,\n                input_size_per_partition // 2,\n                dtype=torch.int32,\n            ),\n            requires_grad=False,\n        )\n        set_weight_attrs(qweight, {\"input_dim\": 1, \"output_dim\": 0})\n\n        num_groups = input_size_per_partition // self.group_size\n        wscales = Parameter(\n            torch.empty(num_groups, output_size_per_partition, dtype=params_dtype),\n            requires_grad=False,\n        )\n\n        wzeros = Parameter(\n            torch.empty(num_groups, output_size_per_partition, dtype=params_dtype),\n            requires_grad=False,\n        )\n\n        layer.register_parameter(\"qweight\", qweight)\n        layer.register_parameter(\"wscales\", wscales)\n        layer.register_parameter(\"wzeros\", wzeros)\n\n        layer.input_size_per_partition = input_size_per_partition\n        layer.output_size_per_partition = output_size_per_partition\n        layer.group_size = self.group_size\n        layer.pack_factor = self.pack_factor\n\n        weight_loader = extra_weight_attrs.get(\"weight_loader\")\n        if weight_loader is not None:\n            set_weight_attrs(qweight, {\"weight_loader\": weight_loader})\n            set_weight_attrs(wscales, {\"weight_loader\": weight_loader})\n            set_weight_attrs(wzeros, {\"weight_loader\": weight_loader})\n\n    def process_weights_after_loading(self, layer: nn.Module) -> None:\n        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)\n        layer.wscales = Parameter(layer.wscales.data, requires_grad=False)\n        layer.wzeros = 
Parameter(layer.wzeros.data, requires_grad=False)\n\n    def apply(\n        self,\n        layer: torch.nn.Module,\n        x: torch.Tensor,\n        bias: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n\n        orig_shape = x.shape\n        x_2d = x.reshape(-1, orig_shape[-1])\n\n        in_features = layer.input_size_per_partition\n        out_features = layer.output_size_per_partition\n        out_2d = awq_gemv_w4a16_cuda(\n            in_feats=x_2d,\n            kernel=layer.qweight,\n            scaling_factors=layer.wscales,\n            zeros=layer.wzeros,\n            m=x_2d.shape[0],\n            n=out_features,\n            k=in_features,\n            group_size=layer.group_size,\n        )\n        if bias is not None:\n            view_shape = [1] * (out_2d.ndim - 1) + [-1]\n            out_2d.add_(bias.view(view_shape))\n\n        out = out_2d.reshape(*orig_shape[:-1], out_features)\n        return out\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/rotary_embedding/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/rotary_embedding.py\n\n# Adapted from\n# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py\n# Copyright 2023 The vLLM team.\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Rotary Positional Embeddings — unified public API (drop-in replacement).\"\"\"\n\nfrom .base import RotaryEmbedding\nfrom .factory import get_rope, get_rotary_pos_embed\nfrom .mrope import NDRotaryEmbedding\nfrom .utils import (\n    _apply_rotary_emb,\n    apply_flashinfer_rope_qk_inplace,\n)\n\n__all__ = [\n    # _utils\n    \"_apply_rotary_emb\",\n    \"apply_flashinfer_rope_qk_inplace\",\n    # _base\n    \"RotaryEmbedding\",\n    # _mrope\n    \"NDRotaryEmbedding\",\n    # _factory\n    \"get_rope\",\n    \"get_rotary_pos_embed\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/rotary_embedding/base.py",
    "content": "\"\"\"RotaryEmbedding base class and LinearScalingRotaryEmbedding variant.\"\"\"\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.layers.custom_op import CustomOp\n\nfrom .utils import _apply_rotary_emb\n\n\n@CustomOp.register(\"rotary_embedding\")\nclass RotaryEmbedding(CustomOp):\n    \"\"\"Original rotary positional embedding.\"\"\"\n\n    def __init__(\n        self,\n        head_size: int,\n        rotary_dim: int,\n        max_position_embeddings: int,\n        base: int | float,\n        is_neox_style: bool,\n        dtype: torch.dtype,\n    ) -> None:\n        super().__init__()\n        self.head_size = head_size\n        self.rotary_dim = rotary_dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        self.is_neox_style = is_neox_style\n        self.dtype = dtype\n\n        cache = self._compute_cos_sin_cache()\n        cache = cache.to(dtype)\n        self.cos_sin_cache: torch.Tensor\n        self.register_buffer(\"cos_sin_cache\", cache, persistent=False)\n\n    def _compute_inv_freq(self, base: int | float) -> torch.Tensor:\n        \"\"\"Compute the inverse frequency.\"\"\"\n        # NOTE(woosuk): To exactly match the HF implementation, we need to\n        # use CPU to compute the cache and then move it to GPU. However, we\n        # create the cache on GPU for faster initialization. 
This may cause\n        # a slight numerical difference between the HF implementation and ours.\n        inv_freq = 1.0 / (\n            base\n            ** (\n                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim\n            )\n        )\n        return inv_freq\n\n    def _compute_cos_sin_cache(self) -> torch.Tensor:\n        \"\"\"Compute the cos and sin cache.\"\"\"\n        inv_freq = self._compute_inv_freq(self.base)\n        t = torch.arange(self.max_position_embeddings, dtype=torch.float)\n\n        freqs = torch.einsum(\"i,j -> ij\", t, inv_freq)\n        cos = freqs.cos()\n        sin = freqs.sin()\n        cache = torch.cat((cos, sin), dim=-1)\n        return cache\n\n    def forward_cuda(self, *args, **kwargs):\n        return self.forward_native(*args, **kwargs)\n\n    def forward_native(\n        self,\n        positions: torch.Tensor,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        offsets: torch.Tensor | None = None,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"A PyTorch-native implementation of forward().\"\"\"\n        if offsets is not None:\n            positions = positions + offsets\n        positions = positions.flatten()\n        num_tokens = positions.shape[0]\n        cos_sin = self.cos_sin_cache.index_select(0, positions)\n        cos, sin = cos_sin.chunk(2, dim=-1)\n\n        query_shape = query.shape\n        query = query.view(num_tokens, -1, self.head_size)\n        query_rot = query[..., : self.rotary_dim]\n        query_pass = query[..., self.rotary_dim :]\n        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)\n        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)\n\n        key_shape = key.shape\n        key = key.view(num_tokens, -1, self.head_size)\n        key_rot = key[..., : self.rotary_dim]\n        key_pass = key[..., self.rotary_dim :]\n        key_rot = _apply_rotary_emb(key_rot, cos, sin, 
self.is_neox_style)\n        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)\n        return query, key\n\n    def extra_repr(self) -> str:\n        s = f\"head_size={self.head_size}, rotary_dim={self.rotary_dim}\"\n        s += f\", max_position_embeddings={self.max_position_embeddings}\"\n        s += f\", base={self.base}, is_neox_style={self.is_neox_style}\"\n        return s\n\n\nclass LinearScalingRotaryEmbedding(RotaryEmbedding):\n    def __init__(\n        self,\n        head_size: int,\n        rotary_dim: int,\n        max_position_embeddings: int,\n        base: int | float,\n        is_neox_style: bool,\n        dtype: torch.dtype,\n        scaling_factor: float,\n    ) -> None:\n        self.scaling_factor = float(scaling_factor)\n        super().__init__(\n            head_size=head_size,\n            rotary_dim=rotary_dim,\n            max_position_embeddings=max_position_embeddings,\n            base=base,\n            is_neox_style=is_neox_style,\n            dtype=dtype,\n        )\n\n    def _compute_cos_sin_cache(self) -> torch.Tensor:\n        inv_freq = self._compute_inv_freq(self.base)\n        t = torch.arange(self.max_position_embeddings, dtype=torch.float)\n        t = t / self.scaling_factor\n        freqs = torch.einsum(\"i,j -> ij\", t, inv_freq)\n        cos = freqs.cos()\n        sin = freqs.sin()\n        cache = torch.cat((cos, sin), dim=-1)\n        return cache\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/rotary_embedding/factory.py",
    "content": "\"\"\"get_rope / get_rotary_pos_embed factory functions and module-level caches.\"\"\"\n\nfrom collections import OrderedDict\nfrom typing import Any\n\nimport torch\n\nfrom .base import LinearScalingRotaryEmbedding, RotaryEmbedding\nfrom .mrope import NDRotaryEmbedding, _to_tuple\n\n_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}\n_ND_ROPE_CACHE: \"OrderedDict[tuple, NDRotaryEmbedding]\" = OrderedDict()\n_ROPE_3D_CACHE: \"OrderedDict[tuple, tuple[torch.Tensor, torch.Tensor]]\" = OrderedDict()\n\n\ndef get_rope(\n    head_size: int,\n    rotary_dim: int,\n    max_position: int,\n    base: int | float,\n    is_neox_style: bool = True,\n    rope_scaling: dict[str, Any] | None = None,\n    dtype: torch.dtype | None = None,\n    partial_rotary_factor: float = 1.0,\n) -> RotaryEmbedding:\n    if dtype is None:\n        dtype = torch.get_default_dtype()\n    if rope_scaling is not None:\n        # Transforms every value that is a list into a tuple for caching calls\n        rope_scaling_tuple = {\n            k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()\n        }\n        rope_scaling_args = tuple(rope_scaling_tuple.items())\n    else:\n        rope_scaling_args = None\n    if partial_rotary_factor < 1.0:\n        rotary_dim = int(rotary_dim * partial_rotary_factor)\n    max_position_embeddings = max_position\n    rope_type = None\n    if rope_scaling is not None:\n        rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))\n        if rope_type in (None, \"default\"):\n            rope_scaling = None\n        elif rope_type == \"linear\":\n            factor = float(rope_scaling.get(\"factor\", 1.0))\n            original_max = rope_scaling.get(\"original_max_position_embeddings\", None)\n            if original_max is not None:\n                max_position_embeddings = max(\n                    max_position_embeddings, int(float(original_max) * factor)\n                )\n    key = (\n        
head_size,\n        rotary_dim,\n        max_position_embeddings,\n        base,\n        is_neox_style,\n        rope_scaling_args,\n        dtype,\n    )\n    if key in _ROPE_DICT:\n        return _ROPE_DICT[key]\n\n    if rope_scaling is None:\n        rotary_emb = RotaryEmbedding(\n            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype\n        )\n    else:\n        if rope_type == \"linear\":\n            factor = float(rope_scaling.get(\"factor\", 1.0))\n            rotary_emb = LinearScalingRotaryEmbedding(\n                head_size=head_size,\n                rotary_dim=rotary_dim,\n                max_position_embeddings=max_position_embeddings,\n                base=base,\n                is_neox_style=is_neox_style,\n                dtype=dtype,\n                scaling_factor=factor,\n            )\n        else:\n            raise ValueError(f\"Unknown RoPE scaling {rope_scaling}\")\n    _ROPE_DICT[key] = rotary_emb\n    return rotary_emb\n\n\ndef get_rotary_pos_embed(\n    rope_sizes,\n    hidden_size,\n    heads_num,\n    rope_dim_list,\n    rope_theta,\n    theta_rescale_factor=1.0,\n    interpolation_factor=1.0,\n    shard_dim: int = 0,\n    dtype: torch.dtype = torch.float32,\n    start_frame: int = 0,\n    device: torch.device | str | None = None,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Generate rotary positional embeddings for the given sizes.\n\n    Args:\n        rope_sizes: Tuple of dimensions (t, h, w)\n        hidden_size: Hidden dimension size\n        heads_num: Number of attention heads\n        rope_dim_list: List of dimensions for each axis, or None\n        rope_theta: Base for frequency calculations\n        theta_rescale_factor: Rescale factor for theta. Defaults to 1.0\n        interpolation_factor: Factor to scale positions. Defaults to 1.0\n        shard_dim: Which dimension to shard for sequence parallelism. 
Defaults to 0.\n        dtype: dtype used to compute the frequencies (the outputs are always float32). Defaults to torch.float32.\n        start_frame: Offset added to the positions of the first (time) axis. Defaults to 0.\n        device: Device for the returned tensors; None means CPU.\n\n    Returns:\n        Tuple of (cos, sin) float32 tensors, each of shape [num_tokens, head_dim // 2],\n        where num_tokens covers only this rank's sequence-parallel shard.\n    \"\"\"\n\n    target_ndim = 3\n    head_dim = hidden_size // heads_num\n\n    if rope_dim_list is None:\n        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]\n\n    assert (\n        sum(rope_dim_list) == head_dim\n    ), \"sum(rope_dim_list) must equal the head_dim of the attention layer\"\n\n    # Sequence-parallel sharding is handled inside NDRotaryEmbedding.forward_from_grid.\n\n    # Simple LRU cache (at most 16 entries) keyed by the embedding parameters\n    global _ND_ROPE_CACHE\n    key = (\n        tuple(rope_dim_list),\n        float(rope_theta),\n        (\n            tuple(theta_rescale_factor)\n            if isinstance(theta_rescale_factor, list)\n            else float(theta_rescale_factor)\n        ),\n        (\n            tuple(interpolation_factor)\n            if isinstance(interpolation_factor, list)\n            else float(interpolation_factor)\n        ),\n        dtype,\n    )\n\n    cache_hit = key in _ND_ROPE_CACHE\n    if cache_hit:\n        rope_emb = _ND_ROPE_CACHE.pop(key)\n        _ND_ROPE_CACHE[key] = rope_emb  # move to end (most-recent)\n    else:\n        rope_emb = NDRotaryEmbedding(\n            rope_dim_list=rope_dim_list,\n            rope_theta=rope_theta,\n            theta_rescale_factor=theta_rescale_factor,\n            interpolation_factor=interpolation_factor,\n            dtype=dtype,\n        )\n        _ND_ROPE_CACHE[key] = rope_emb\n        if len(_ND_ROPE_CACHE) > 16:\n            # pop least-recently-used\n            _ND_ROPE_CACHE.pop(next(iter(_ND_ROPE_CACHE)))\n\n    freqs_cos, freqs_sin = rope_emb.forward_from_grid(\n        grid_size=_to_tuple(rope_sizes, dim=3),\n        shard_dim=shard_dim,\n        start_frame=start_frame,\n        device=device,\n    )\n    return freqs_cos, freqs_sin\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/rotary_embedding/mrope.py",
    "content": "\"\"\"1D and N-D rotary embedding helpers: get_1d_rotary_pos_embed, OneDRotaryEmbedding, NDRotaryEmbedding.\"\"\"\n\nimport functools\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_group\n\n\ndef _to_tuple(x: int | tuple[int, ...], dim: int = 2) -> tuple[int, ...]:\n    if isinstance(x, int):\n        return (x,) * dim\n    elif len(x) == dim:\n        return x\n    else:\n        raise ValueError(f\"Expected length {dim} or int, but got {x}\")\n\n\ndef get_1d_rotary_pos_embed(\n    dim: int,\n    pos: torch.FloatTensor | int,\n    theta: float = 10000.0,\n    theta_rescale_factor: float = 1.0,\n    interpolation_factor: float = 1.0,\n    dtype: torch.dtype = torch.float32,\n    device: torch.device | str | None = None,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Precompute the frequency tensor for complex exponential (cis) with given dimensions.\n    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)\n\n    This function calculates a frequency tensor with complex exponential using the given dimension 'dim'\n    and the positions 'pos'. The 'theta' parameter is the base of the frequency computation.\n\n    Args:\n        dim (int): Dimension of the frequency tensor.\n        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar\n            (a scalar S is expanded to positions 0..S-1).\n        theta (float, optional): Base for the frequency computation. Defaults to 10000.0.\n        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.\n        interpolation_factor (float, optional): Factor to scale positions. Defaults to 1.0.\n        dtype (torch.dtype, optional): dtype used to compute the frequencies. Defaults to torch.float32.\n        device (torch.device or str, optional): Device on which to build the tensors. Defaults to None.\n\n    Returns:\n        freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. 
[S, D/2]\n    \"\"\"\n    if isinstance(pos, int):\n        pos = torch.arange(pos, dtype=dtype, device=device)\n    elif (\n        isinstance(pos, torch.Tensor)\n        and device is not None\n        and pos.device != torch.device(device)\n    ):\n        # Ensure positions are on the requested device to avoid implicit CPU ops.\n        pos = pos.to(device)\n\n    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning\n    # has some connection to NTK literature\n    if theta_rescale_factor != 1.0:\n        theta *= theta_rescale_factor ** (dim / (dim - 2))\n\n    freqs = 1.0 / (\n        theta\n        ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].to(dtype) / dim).to(\n            device=device\n        )\n    )  # [D/2]\n    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]\n    freqs_cos = freqs.cos()  # [S, D/2]\n    freqs_sin = freqs.sin()  # [S, D/2]\n    return freqs_cos, freqs_sin\n\n\nclass OneDRotaryEmbedding(torch.nn.Module):\n    \"\"\"1D rotary positional embedding with caching.\"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        theta: float = 10000.0,\n        theta_rescale_factor: float = 1.0,\n        interpolation_factor: float = 1.0,\n        dtype: torch.dtype = torch.float32,\n        use_real: bool = False,\n        repeat_interleave_real: bool = False,\n    ):\n        super().__init__()\n        assert dim % 2 == 0\n        self.dim = dim\n        self.theta = theta\n        self.theta_rescale_factor = theta_rescale_factor\n        self.interpolation_factor = interpolation_factor\n        # dtype of freqs\n        self.dtype = dtype\n        self.use_real = use_real\n        self.repeat_interleave_real = repeat_interleave_real\n\n    def build_freqs(self, device, theta: float | None = None):\n        # Inverse frequencies theta^(-2i/dim); `theta` defaults to self.theta and may be\n        # passed in already rescaled (see build_freqs_outer).\n        if theta is None:\n            theta = self.theta\n        freqs = 1.0 / (\n            theta\n            ** (\n                torch.arange(0, self.dim, 2, dtype=self.dtype, device=device)[\n                    : (self.dim 
// 2)\n                ]\n                / self.dim\n            ).to(device=device)\n        )\n        return freqs\n\n    def build_freqs_outer(self, pos: torch.Tensor, device):\n        theta = self.theta\n        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning\n        # has some connection to NTK literature\n        if self.theta_rescale_factor != 1.0:\n            theta *= self.theta_rescale_factor ** (self.dim / (self.dim - 2))\n\n        # Build the base frequencies from the (possibly rescaled) theta.\n        freqs = self.build_freqs(device, theta)\n\n        freqs = torch.outer(pos * self.interpolation_factor, freqs)\n        freqs_cos = freqs.cos()\n        freqs_sin = freqs.sin()\n\n        if self.use_real and self.repeat_interleave_real:\n            freqs_cos = freqs_cos.repeat_interleave(2, dim=1)\n            freqs_sin = freqs_sin.repeat_interleave(2, dim=1)\n\n        return freqs_cos.float(), freqs_sin.float()\n\n    @functools.lru_cache(maxsize=16)\n    def forward_from_grid(\n        self, seq_len: int, start_pos: int, device_str: str\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        device = torch.device(device_str)\n        pos = torch.arange(\n            start_pos, start_pos + seq_len, dtype=self.dtype, device=device\n        )\n\n        freqs_cos, freqs_sin = self.build_freqs_outer(pos, device)\n        return freqs_cos, freqs_sin\n\n    def forward(self, pos: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Calculates 1D rotary embeddings for the given positions.\n\n        This method converts the input tensor to a hashable representation\n        and calls a cached helper method to perform the computation.\n        \"\"\"\n        pos_tuple = tuple(pos.tolist())\n        device_str = str(pos.device)\n        return self._forward_cached(pos_tuple, device_str)\n\n    @functools.lru_cache(maxsize=16)\n    def _forward_cached(\n        self, pos_tuple: tuple, device_str: str\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n      
  \"\"\"\n        The core implementation that computes 1D rotary embeddings.\n        This method is wrapped by an LRU cache.\n        \"\"\"\n        device = torch.device(device_str)\n        pos = torch.as_tensor(pos_tuple, dtype=self.dtype, device=device)\n        freqs_cos, freqs_sin = self.build_freqs_outer(pos, device)\n        return freqs_cos, freqs_sin\n\n\nclass NDRotaryEmbedding(torch.nn.Module):\n    \"\"\"N-dimensional rotary positional embedding.\"\"\"\n\n    def __init__(\n        self,\n        rope_dim_list: list[int],\n        rope_theta: float,\n        theta_rescale_factor: float | list[float] = 1.0,\n        interpolation_factor: float | list[float] = 1.0,\n        use_real: bool = False,\n        repeat_interleave_real: bool = False,\n        dtype: torch.dtype = torch.float32,\n    ):\n        super().__init__()\n        self.rope_dim_list = rope_dim_list\n        self.ndim = len(rope_dim_list)\n        self.rope_theta = rope_theta\n        # dtype of freqs\n        # does not control the output dtype\n        self.dtype = dtype\n\n        if isinstance(theta_rescale_factor, (int, float)):\n            self.theta_rescale_factor = [theta_rescale_factor] * self.ndim\n        elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:\n            self.theta_rescale_factor = [theta_rescale_factor[0]] * self.ndim\n        else:\n            self.theta_rescale_factor = theta_rescale_factor\n        assert (\n            len(self.theta_rescale_factor) == self.ndim\n        ), \"len(theta_rescale_factor) should equal to len(rope_dim_list)\"\n\n        if isinstance(interpolation_factor, (int, float)):\n            self.interpolation_factor = [interpolation_factor] * self.ndim\n        elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:\n            self.interpolation_factor = [interpolation_factor[0]] * self.ndim\n        else:\n            self.interpolation_factor = interpolation_factor\n        
assert (\n            len(self.interpolation_factor) == self.ndim\n        ), \"len(interpolation_factor) should equal to len(rope_dim_list)\"\n\n        self.rope_generators: list[OneDRotaryEmbedding] = torch.nn.ModuleList()\n        _config_to_gen_idx: dict[tuple, int] = {}\n        self.dim_idx_to_gen_idx: list[int] = []\n\n        for i in range(self.ndim):\n            dim = self.rope_dim_list[i]\n            rescale = self.theta_rescale_factor[i]\n            interp = self.interpolation_factor[i]\n\n            config_key = (dim, rescale, interp, use_real, repeat_interleave_real)\n            if config_key not in _config_to_gen_idx:\n                generator = OneDRotaryEmbedding(\n                    dim=dim,\n                    theta=self.rope_theta,\n                    theta_rescale_factor=rescale,\n                    interpolation_factor=interp,\n                    dtype=self.dtype,\n                    use_real=use_real,\n                    repeat_interleave_real=repeat_interleave_real,\n                )\n                _config_to_gen_idx[config_key] = len(self.rope_generators)\n                self.rope_generators.append(generator)\n\n            gen_idx = _config_to_gen_idx[config_key]\n            self.dim_idx_to_gen_idx.append(gen_idx)\n\n    def forward(self, positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Calculates n-d rotary embeddings for given absolute positions.\n\n        Args:\n            positions (torch.Tensor): A tensor of shape `[num_tokens, ndim]`\n                containing the integer coordinates for each token.\n\n        Returns:\n            A tuple of (cos, sin) tensors.\n        \"\"\"\n        # Caching wrapper: convert tensor to a hashable tuple of tuples.\n        pos_tuple = tuple(map(tuple, positions.tolist()))\n        device_str = str(positions.device)\n        return self._forward_cached(pos_tuple, device_str)\n\n    @functools.lru_cache(maxsize=16)\n    def 
_forward_cached(\n        self, pos_tuple: tuple[tuple[int, ...], ...], device_str: str\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Rebuilds the position tensor from its hashable tuple form and delegates\n        to forward_uncached. This method is wrapped by an LRU cache.\n        \"\"\"\n        device = torch.device(device_str)\n        positions = torch.tensor(pos_tuple, dtype=torch.long, device=device)\n        return self.forward_uncached(pos=positions)\n\n    def forward_uncached(self, pos: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        The core implementation that computes embeddings from a `[num_tokens, ndim]`\n        position tensor. Unlike `forward`, this method is not LRU-cached itself\n        (the per-axis 1D generators still cache their own results).\n        \"\"\"\n        device = pos.device\n\n        # Pre-allocate the final tensors for efficiency.\n        num_tokens = pos.shape[0]\n        first_generator = self.rope_generators[0]\n        if first_generator.use_real and first_generator.repeat_interleave_real:\n            head_dim = sum(self.rope_dim_list)\n        else:\n            head_dim = sum(self.rope_dim_list) // 2\n\n        cos = torch.empty((num_tokens, head_dim), device=device, dtype=self.dtype)\n        sin = torch.empty((num_tokens, head_dim), device=device, dtype=self.dtype)\n\n        col_offset = 0\n        for i in range(self.ndim):\n            # Extract position coordinates for the current dimension for all tokens.\n            pos_i = pos[:, i].to(self.dtype)\n\n            # Get the appropriate 1D generator.\n            gen_idx = self.dim_idx_to_gen_idx[i]\n            generator = self.rope_generators[gen_idx]\n\n            # Calculate 1D embeddings.\n            cos_1d, sin_1d = generator(pos_i)\n\n            slice_width = cos_1d.shape[1]\n            cos[:, col_offset : col_offset + slice_width] = cos_1d\n            sin[:, col_offset : col_offset + slice_width] = sin_1d\n            col_offset += slice_width\n\n        return cos.float(), sin.float()\n\n    def forward_from_grid(\n 
       self,\n        grid_size: tuple[int, ...],\n        shard_dim: int = 0,\n        start_frame: int = 0,\n        device: torch.device | str | None = None,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Handles sp internally\n        \"\"\"\n        # Caching wrapper: use grid parameters directly as the key.\n        # grid_tuple = _to_tuple(grid_size, dim=self.ndim)\n        device_str = str(device) if device is not None else \"cpu\"\n        return self._forward_cached_from_grid(\n            grid_size, shard_dim, start_frame, device_str\n        )\n\n    @functools.lru_cache(maxsize=16)\n    def _forward_cached_from_grid(\n        self,\n        grid_size: tuple[int, ...],\n        shard_dim: int,\n        start_frame: int,\n        device_str: str,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Computes embeddings for a structured grid, using a highly efficient\n        implementation that avoids materializing the full position tensor.\n        This method is wrapped by an LRU cache.\n        \"\"\"\n        device = torch.device(device_str)\n        sp_group = get_sp_group()\n        sp_rank = sp_group.rank_in_group\n        sp_world_size = sp_group.world_size\n\n        sizes = _to_tuple(grid_size, dim=self.ndim)\n        starts = (0,) * self.ndim\n\n        # Apply sequence parallel sharding to the sizes and compute shard offset\n        shard_sizes = list(sizes)\n        shard_offsets = [0] * self.ndim\n        if sp_world_size > 1:\n            assert sizes[shard_dim] % sp_world_size == 0, (\n                f\"Dimension {shard_dim} with size {sizes[shard_dim]} is not divisible \"\n                f\"by sequence parallel world size {sp_world_size}\"\n            )\n            shard_size = sizes[shard_dim] // sp_world_size\n            shard_offsets[shard_dim] = sp_rank * shard_size\n            shard_sizes[shard_dim] = shard_size\n\n        # Pre-allocate outputs on the requested device to avoid CPU ops 
and extra cats\n        num_tokens = 1\n        for s in shard_sizes:\n            num_tokens *= int(s)\n        head_dim_half = sum(self.rope_dim_list) // 2\n        cos = torch.empty((num_tokens, head_dim_half), device=device, dtype=self.dtype)\n        sin = torch.empty((num_tokens, head_dim_half), device=device, dtype=self.dtype)\n\n        # Compute per-axis 1D embeddings once and expand via repeats to [N, d_i/2]\n        col_offset = 0\n        for i in range(self.ndim):\n            dim_i = self.rope_dim_list[i]\n            dim_i_half = dim_i // 2\n            size_i = int(shard_sizes[i])\n\n            # Starting position for this axis, with optional frame offset for time axis (i==0)\n            base_offset = starts[i]\n            if i == 0 and start_frame > 0:\n                base_offset += start_frame\n            if sp_world_size > 1 and i == shard_dim:\n                base_offset += shard_offsets[i]\n\n            gen_idx = self.dim_idx_to_gen_idx[i]\n            generator = self.rope_generators[gen_idx]\n            cos_1d, sin_1d = generator.forward_from_grid(\n                size_i, base_offset, device_str\n            )\n\n            # Expand to [num_tokens, dim_i/2] matching flatten order (last dims vary fastest)\n            repeats_per_entry = 1\n            for j in range(i + 1, self.ndim):\n                repeats_per_entry *= int(shard_sizes[j])\n            tile_count = 1\n            for j in range(0, i):\n                tile_count *= int(shard_sizes[j])\n\n            cos_expanded = cos_1d.repeat_interleave(repeats_per_entry, dim=0)\n            sin_expanded = sin_1d.repeat_interleave(repeats_per_entry, dim=0)\n            if tile_count > 1:\n                cos_expanded = cos_expanded.repeat(tile_count, 1)\n                sin_expanded = sin_expanded.repeat(tile_count, 1)\n\n            cos[:, col_offset : col_offset + dim_i_half] = cos_expanded\n            sin[:, col_offset : col_offset + dim_i_half] = sin_expanded\n            
col_offset += dim_i_half\n\n        return cos.float(), sin.float()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py",
    "content": "\"\"\"Primitive RoPE ops: apply_rotary_emb utilities with FlashInfer and Triton backends.\"\"\"\n\nfrom typing import Optional, Tuple\n\nimport torch\n\nfrom sglang.jit_kernel.diffusion.triton.rotary import apply_rotary_embedding\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.srt.utils.custom_op import register_custom_op_from_extern\n\n_is_cuda = current_platform.is_cuda()\nif _is_cuda:\n    try:\n        from flashinfer.rope import (\n            apply_rope_with_cos_sin_cache_inplace as _flashinfer_apply_rope_inplace,\n        )\n    except Exception:\n        _flashinfer_apply_rope_inplace = None\nelse:\n    _flashinfer_apply_rope_inplace = None\n\nif _flashinfer_apply_rope_inplace is not None:\n    flashinfer_apply_rope_inplace = register_custom_op_from_extern(\n        _flashinfer_apply_rope_inplace,\n        op_name=\"flashinfer_apply_rope_with_cos_sin_cache_inplace\",\n        mutates_args=[\"query\", \"key\"],\n    )\nelse:\n    flashinfer_apply_rope_inplace = None\n\n\ndef _apply_rotary_emb(\n    x: torch.Tensor,\n    cos: torch.Tensor,\n    sin: torch.Tensor,\n    is_neox_style: bool,\n    interleaved: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    Args:\n        x: [num_tokens, num_heads, head_size] or [num_tokens, head_size]\n        cos: [num_tokens, head_size // 2]\n        sin: [num_tokens, head_size // 2]\n        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary\n            positional embeddings.\n        interleaved: Only used on the non-Neox path; forwarded to the Triton\n            `apply_rotary_embedding` kernel.\n    \"\"\"\n    if is_neox_style:\n        # Neox style: rotate the first half of the last dim against the second half.\n        cos = cos.unsqueeze(-2)\n        sin = sin.unsqueeze(-2)\n        x1, x2 = torch.chunk(x, 2, dim=-1)\n        o1 = (x1.float() * cos - x2.float() * sin).type_as(x)\n        o2 = (x2.float() * cos + x1.float() * sin).type_as(x)\n        return torch.cat((o1, o2), 
dim=-1)\n    else:\n        return apply_rotary_embedding(x, cos, sin, interleaved)\n\n\ndef apply_flashinfer_rope_qk_inplace(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    cos_sin_cache: torch.Tensor,\n    *,\n    head_size: Optional[int] = None,\n    is_neox: bool = False,\n    positions: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    if q.dim() != 4 or k.dim() != 4:\n        raise ValueError(\n            f\"Expected q/k to be 4D [bsz, seqlen, nheads, head_size], \"\n            f\"got q:{tuple(q.shape)} k:{tuple(k.shape)}\"\n        )\n    if q.shape != k.shape:\n        raise ValueError(\n            f\"q and k must have the same shape, got {q.shape} vs {k.shape}\"\n        )\n\n    if not (isinstance(cos_sin_cache, torch.Tensor) and cos_sin_cache.dim() == 2):\n        raise ValueError(\"cos_sin_cache must be a 2D torch.Tensor\")\n\n    bsz, seqlen, nheads, d = q.shape\n    if head_size is None:\n        head_size = d\n    if head_size != d:\n        raise ValueError(f\"head_size mismatch: inferred {d}, but head_size={head_size}\")\n\n    if flashinfer_apply_rope_inplace is None:\n        # Triton fallback for AMD/ROCm where FlashInfer is not available\n        import warnings\n\n        warnings.warn(\n            \"FlashInfer not available, using Triton fallback for RoPE\",\n            stacklevel=2,\n        )\n        half_size = cos_sin_cache.shape[-1] // 2\n        if positions is None:\n            cos = cos_sin_cache[:seqlen, :half_size].to(q.dtype)\n            sin = cos_sin_cache[:seqlen, half_size:].to(q.dtype)\n            cos = cos.unsqueeze(0).expand(bsz, -1, -1).reshape(bsz * seqlen, -1)\n            sin = sin.unsqueeze(0).expand(bsz, -1, -1).reshape(bsz * seqlen, -1)\n        else:\n            positions = positions.to(cos_sin_cache.device).view(-1)\n            cos = cos_sin_cache[positions, :half_size].to(q.dtype)\n            sin = cos_sin_cache[positions, half_size:].to(q.dtype)\n        q_flat = 
q.reshape(bsz * seqlen, nheads, d)\n        k_flat = k.reshape(bsz * seqlen, nheads, d)\n        q_rot = apply_rotary_embedding(q_flat, cos, sin, interleaved=not is_neox)\n        k_rot = apply_rotary_embedding(k_flat, cos, sin, interleaved=not is_neox)\n        return q_rot.view(bsz, seqlen, nheads, d), k_rot.view(bsz, seqlen, nheads, d)\n\n    if positions is None:\n        pos_1d = torch.arange(seqlen, device=q.device, dtype=torch.long)\n        positions = pos_1d if bsz == 1 else pos_1d.repeat(bsz)\n    else:\n        if not (\n            isinstance(positions, torch.Tensor)\n            and positions.dtype == torch.long\n            and positions.dim() == 1\n        ):\n            raise ValueError(\"positions must be a 1D torch.long Tensor\")\n        if positions.numel() != bsz * seqlen:\n            raise ValueError(\n                f\"positions length must be bsz*seqlen={bsz*seqlen}, got {positions.numel()}\"\n            )\n\n    q_flat = q.reshape(bsz * seqlen, nheads * d).contiguous()\n    k_flat = k.reshape(bsz * seqlen, nheads * d).contiguous()\n    flashinfer_apply_rope_inplace(\n        positions=positions,\n        query=q_flat,\n        key=k_flat,\n        head_size=d,\n        cos_sin_cache=cos_sin_cache,\n        is_neox=is_neox,\n    )\n    return q_flat.view(bsz, seqlen, nheads, d), k_flat.view(bsz, seqlen, nheads, d)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/usp.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nimport logging\nfrom typing import TYPE_CHECKING\n\nimport torch\nimport torch.distributed._functional_collectives as ft_c\nfrom torch.distributed.tensor.experimental._attention import _cp_options\n\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_sp_group,\n    get_ulysses_parallel_world_size,\n)\nfrom sglang.srt.utils.common import torch_release\n\n_cp_options.enable_load_balance = False\n\nif TYPE_CHECKING:\n    from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n        AttentionImpl,\n    )\n\nlogger = logging.getLogger(__name__)\n\n\ndef _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    When tracing the code, the result tensor is not an AsyncCollectiveTensor,\n    so we cannot call ``wait()``.\n    \"\"\"\n    if isinstance(tensor, ft_c.AsyncCollectiveTensor):\n        return tensor.wait()\n    return tensor\n\n\ndef _usp_all_to_all_single(x: torch.Tensor) -> torch.Tensor:\n    ulysses_pg = get_sp_group().ulysses_group\n    assert ulysses_pg is not None, \"Ulysses process group is not initialized.\"\n    x_shape = x.shape\n    x = x.flatten()\n    x = ft_c.all_to_all_single(\n        x, output_split_sizes=None, input_split_sizes=None, group=ulysses_pg\n    )\n    x = _maybe_wait(x)\n    x = x.reshape(x_shape)\n    return x\n\n\ndef _usp_input_all_to_all(x: torch.Tensor, head_dim: int = 1) -> torch.Tensor:\n    \"\"\"\n    Perform Ulysses-style input all-to-all over the head dimension.\n\n    Default layout expects heads at dim=1 and sequence at dim=2:\n        [b, h, s_local, d] -> [b, h_local, s_global, d]\n\n    If heads are at dim=2 (input is [b, s_local, h, d]), set head_dim=2, and the\n    function returns [b, s_global, h_local, d], preserving the original\n    head/sequence dim ordering.\n\n    Args:\n        x: A 4D tensor with layout [b, *, *, d] where '*' are sequence and heads\n  
      head_dim: Which dimension index corresponds to heads (1 or 2)\n\n    Returns:\n        Tensor with the same dim order as input, with heads sharded and sequence gathered.\n    \"\"\"\n    world_size = get_ulysses_parallel_world_size()\n    if world_size <= 1:\n        return x\n\n    assert x.ndim == 4, f\"x must have 4 dimensions, got {x.ndim}\"\n    assert head_dim in (1, 2), f\"head_dim must be 1 or 2, got {head_dim}\"\n\n    # Move the dimension to be split (h_global) to dim 0 for all_to_all_single\n    if head_dim == 1:\n        b, h_global, s_local, d = x.shape\n        # Shape transition: [b, h_global, s_local, d] -> [h_global, b, s_local, d]\n        permute_order = (1, 0, 2, 3)\n    else:  # head_dim == 2\n        b, s_local, h_global, d = x.shape\n        # Shape transition: [b, s_local, h_global, d] -> [h_global, b, s_local, d]\n        permute_order = (2, 0, 1, 3)\n\n    assert (\n        h_global % world_size == 0\n    ), f\"h_global ({h_global}) must be divisible by world_size ({world_size})\"\n\n    h_local, s_global = h_global // world_size, s_local * world_size\n\n    x = x.permute(permute_order).contiguous()\n    x = _usp_all_to_all_single(x)\n    x = x.reshape(world_size, h_local, b, s_local, d)\n\n    # Reorder dims to place 'world_size' adjacent to 's_local' to merge them into 's_global'\n    if head_dim == 1:\n        # Shape transition: [world_size, h_local, b, s_local, d] -> [b, h_local, world_size, s_local, d]\n        x = x.permute(2, 1, 0, 3, 4).contiguous().reshape(b, h_local, s_global, d)\n    else:  # head_dim == 2\n        # Shape transition: [world_size, h_local, b, s_local, d] -> [b, world_size, s_local, h_local, d]\n        x = x.permute(2, 0, 3, 1, 4).contiguous().reshape(b, s_global, h_local, d)\n\n    return x\n\n\ndef _usp_output_all_to_all(x: torch.Tensor, head_dim: int = 1) -> torch.Tensor:\n    \"\"\"\n    Perform Ulysses-style output all-to-all over the head dimension (inverse of input).\n\n    Default layout expects 
heads at dim=1 and sequence at dim=2:\n        [b, h_local, s, d] -> [b, h, s_local, d]\n\n    If heads are at dim=2 (input is [b, s_global, h // world_size, d]), set head_dim=2,\n    and the function returns [b, s_local, h, d], preserving the original head/sequence\n    dim ordering.\n\n    Args:\n        x: A 4D tensor with layout [b, *, *, d] where '*' are sequence and heads\n        head_dim: Which dimension index corresponds to heads (1 or 2)\n\n    Returns:\n        Tensor with the same dim order as input, with heads gathered and sequence sharded.\n    \"\"\"\n    world_size = get_ulysses_parallel_world_size()\n    if world_size <= 1:\n        return x\n\n    assert x.ndim == 4, f\"x must have 4 dimensions, got {x.ndim}\"\n    assert head_dim in (1, 2), f\"head_dim must be 1 or 2, got {head_dim}\"\n\n    # Move the dimension to be split (s_global) to dim 0 for all_to_all_single\n    if head_dim == 1:\n        b, h_local, s_global, d = x.shape\n        # Shape transition: [b, h_local, s_global, d] -> [s_global, b, h_local, d]\n        permute_order = (2, 0, 1, 3)\n    else:  # head_dim == 2\n        b, s_global, h_local, d = x.shape\n        # Shape transition: [b, s_global, h_local, d] -> [s_global, b, h_local, d]\n        permute_order = (1, 0, 2, 3)\n\n    assert (\n        s_global % world_size == 0\n    ), f\"s_global ({s_global}) must be divisible by world_size ({world_size})\"\n\n    s_local, h_global = s_global // world_size, h_local * world_size\n\n    x = x.permute(permute_order).contiguous()\n    x = _usp_all_to_all_single(x)\n    x = x.reshape(world_size, s_local, b, h_local, d)\n\n    # Reorder dims to place 'world_size' adjacent to 'h_local' to merge them into 'h_global'\n    if head_dim == 1:\n        # Shape transition: [world_size, s_local, b, h_local, d] -> [b, world_size, h_local, s_local, d]\n        x = x.permute(2, 0, 3, 1, 4).contiguous().reshape(b, h_global, s_local, d)\n    else:  # head_dim == 2\n        # Shape transition: 
[world_size, s_local, b, h_local, d] -> [b, s_local, world_size, h_local, d]\n        x = x.permute(2, 1, 0, 3, 4).contiguous().reshape(b, s_local, h_global, d)\n\n    return x\n\n\ndef ring_attn(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attn_impl: \"AttentionImpl\",\n    is_causal: bool = False,\n    dropout_p: float = 0.0,\n):\n    \"\"\"\n    Ring Attention implementation.\n\n    This function implements Ring Attention, a strategy for distributed attention\n    computation that reduces peak memory usage. It accepts a generic attention\n    implementation (`attn_impl`) which is called by the underlying PyTorch\n    distributed attention primitive.\n\n    Args:\n        query, key, value: The input tensors for attention.\n        attn_impl: An instance of an attention implementation backend\n                   (e.g., FlashAttentionImpl) whose `forward` method will be\n                   used as the computational kernel.\n        is_causal: Whether to apply causal masking.\n        dropout_p: Dropout probability.\n    \"\"\"\n    # torch.distributed.tensor.experimental._attention is not a public API\n    # and may change between torch releases.\n    from torch.distributed.tensor.experimental._attention import (\n        _templated_ring_attention,\n    )\n\n    ring_pg = get_sp_group().ring_group\n    assert ring_pg is not None, \"Ring process group is not initialized.\"\n\n    # Ring attention primitives expect tensors in [B, H, S, D] layout.\n    # We permute the inputs here.\n    query = torch.permute(query, [0, 2, 1, 3]).contiguous()\n    key = torch.permute(key, [0, 2, 1, 3]).contiguous()\n    value = torch.permute(value, [0, 2, 1, 3]).contiguous()\n\n    # Create an adapter function that matches the signature expected by\n    # _templated_ring_attention. 
The `attn_impl` already has dropout and\n    # causal settings configured during its initialization.\n\n    # Note: Please be aware that Attention Backend and Ring Attention may require different QKV tensor shapes.\n    # For example, FlashAttention expects the format to be BSHD.\n    def attn_callable_adapter(q, k, v, *args, **kwargs):\n        # We ignore the dropout_p and is_causal passed by _templated_ring_attention\n        # and rely on the pre-configured attn_impl.\n        # The `attn_metadata` is not available here, so we pass None.\n        # This is a limitation we must accept when using this experimental API.\n        q = torch.permute(q, [0, 2, 1, 3])\n        k = torch.permute(k, [0, 2, 1, 3])\n        v = torch.permute(v, [0, 2, 1, 3])\n        # logger.warning(f\"Warning: return_softmax_lse is only supported for FlashAttentionImpl\")\n        output, softmax_lse, *rest = attn_impl.forward(\n            q,\n            k,\n            v,\n            attn_metadata=None,\n            return_softmax_lse=True,\n        )\n        output = torch.permute(output, [0, 2, 1, 3])\n        return output, softmax_lse, *rest\n\n    # Starting from torch 2.6.0, _templated_ring_attention expects an integer\n    # segment_id for the attention function.\n    use_segment_id = torch_release >= (2, 6)\n\n    attn_kwargs = dict(\n        op=attn_callable_adapter,\n        dropout_p=dropout_p,\n        is_causal=is_causal,\n        query=query,\n        key=key,\n        value=value,\n        group=ring_pg,  # https://github.com/pytorch/pytorch/blob/c907c778f42ba2fdaf25b733dd25baf9779c6a12/torch/distributed/tensor/experimental/_context_parallel/_attention.py#L309\n    )\n\n    if use_segment_id:\n        # For torch >= 2.6, segment_id is required. 
The value '1' is a placeholder\n        # as we are not using complex segmentation features.\n        out, *_ = _templated_ring_attention(\n            seq_dim=1,  # segment_id\n            **attn_kwargs,\n        )\n    else:\n        out, *_ = _templated_ring_attention(\n            **attn_kwargs,\n        )\n\n    # Permute the output back to [B, S, H, D] layout.\n    output = torch.permute(out, [0, 2, 1, 3])\n    return output\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/utils.py\n\"\"\"Utility methods for model layers.\"\"\"\n\nimport inspect\nfrom typing import Any, Callable, List, Optional\n\nimport torch\nfrom torch.library import Library\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\n\n\ndef get_group_size(group) -> int:\n    if hasattr(group, \"world_size\"):\n        return group.world_size  # GroupCoordinator\n    elif hasattr(group, \"size\") and callable(getattr(group, \"size\", None)):\n        return group.size()  # ProcessGroup\n    else:\n        raise ValueError(f\"Unsupported group type: {type(group)}\")\n\n\ndef get_group_rank(group) -> int:\n    if hasattr(group, \"rank_in_group\"):\n        return group.rank_in_group  # GroupCoordinator\n    elif hasattr(group, \"rank\") and callable(getattr(group, \"rank\", None)):\n        return group.rank()  # ProcessGroup\n    else:\n        raise ValueError(f\"Unsupported group type: {type(group)}\")\n\n\ndef get_token_bin_counts_and_mask(\n    tokens: torch.Tensor,\n    vocab_size: int,\n    num_seqs: int,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    # Compute the bin counts for the tokens.\n    # vocab_size + 1 for padding.\n    bin_counts = torch.zeros(\n        (num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device\n    )\n    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))\n    bin_counts = bin_counts[:, :vocab_size]\n    mask = bin_counts > 0\n\n    return bin_counts, mask\n\n\nsglang_lib = Library(\"sglang\", \"FRAGMENT\")  # noqa\n\n\ndef direct_register_custom_op(\n    op_name: str,\n    op_func: Callable,\n    mutates_args: List[str],\n    fake_impl: Optional[Callable] = None,\n    target_lib: Optional[Library] = None,\n):\n    \"\"\"\n    `torch.library.custom_op` can have significant overhead 
because it\n    needs to consider complicated dispatching logic. This function\n    directly registers a custom op and dispatches it to the CUDA backend\n    (or to PrivateUse1 on NPU).\n    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5\n    for more details.\n\n    By default, the custom op is registered to the sglang library\n    (`sglang_lib`). If you want to register it to a different library,\n    you can pass the library object to the `target_lib` argument.\n\n    IMPORTANT: the lifetime of the operator is tied to the lifetime of the\n    library object. If you want to bind the operator to a different library,\n    make sure the library object is alive when the operator is used.\n\n    Note: This function will silently skip registration if the operator\n    with the same name is already registered to avoid RuntimeError in\n    multi-engine scenarios (e.g., VERL framework).\n    \"\"\"\n    import torch.library\n\n    my_lib = target_lib or sglang_lib\n\n    # Check if operator is already registered to avoid duplicate registration\n    # This is important for scenarios where multiple SGLang engines run in the same process\n    try:\n        # Try to access the operator to see if it's already registered\n        lib_name = my_lib.m.name if hasattr(my_lib.m, \"name\") else \"sglang\"\n        if hasattr(torch.ops, lib_name) and hasattr(\n            getattr(torch.ops, lib_name), op_name\n        ):\n            # Operator already exists, skip registration\n            return\n    except (AttributeError, RuntimeError):\n        # Operator doesn't exist, proceed with registration\n        pass\n\n    if hasattr(torch.library, \"infer_schema\"):\n        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)\n    else:\n        # for pytorch 2.4\n        import torch._custom_op.impl\n\n        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)\n\n    try:\n        my_lib.define(op_name + schema_str)\n        my_lib.impl(\n            op_name, 
op_func, \"CUDA\" if not current_platform.is_npu() else \"PrivateUse1\"\n        )\n        if fake_impl is not None:\n            my_lib._register_fake(op_name, fake_impl)\n    except RuntimeError as error:\n        if \"Tried to register an operator\" in str(error) and \"multiple times\" in str(\n            error\n        ):\n            # Silently ignore duplicate registration errors\n            # This can happen in multi-engine scenarios\n            pass\n        else:\n            # Re-raise other RuntimeErrors\n            raise error\n    except AttributeError as error:\n        # Always re-raise AttributeError as it indicates missing dependencies\n        raise error\n\n\nclass CustomOpWrapper:\n    def __init__(\n        self,\n        op_name: str,\n        op_func: Callable,\n        mutates_args: List[str],\n        **extra_kwargs,\n    ):\n        self.op_name = op_name\n        self.op_func = op_func\n        self.mutates_args = mutates_args\n        self.extra_kwargs = extra_kwargs\n        self._impl: Optional[Callable] = None\n\n    def __call__(self, *args, **kwargs):\n        return self.real_impl(*args, **kwargs)\n\n    @property\n    def real_impl(self) -> Callable:\n        if self._impl is None:\n            if not hasattr(torch.ops.sglang, self.op_name):\n\n                # NOTE(dark): if torch.compile fails here, mark the decorator as eager;\n                # lazy registration does not work with torch.compile\n                direct_register_custom_op(\n                    op_name=self.op_name,\n                    op_func=self.op_func,\n                    mutates_args=self.mutates_args,\n                    fake_impl=self.fake_impl,\n                )\n            self._impl = getattr(torch.ops.sglang, self.op_name)\n            assert self._impl is not None\n        return self._impl\n\n    @property\n    def fake_impl(self) -> Callable:\n        if \"fake_impl\" in self.extra_kwargs:\n            return 
self.extra_kwargs[\"fake_impl\"]\n        assert \"out_shape\" in self.extra_kwargs\n        signature = inspect.signature(self.op_func)\n        out_shape = self.extra_kwargs[\"out_shape\"]\n\n        # check out_shape in signature\n\n        def fake_impl(*args, **kwargs):\n            if out_shape is None:\n                return None\n            bound = signature.bind(*args, **kwargs)\n            bound.apply_defaults()\n            try:\n                return torch.empty_like(\n                    bound.args[out_shape]\n                    if isinstance(out_shape, int)\n                    else bound.arguments[out_shape]\n                )\n            except (IndexError, KeyError):\n                raise RuntimeError(\n                    f\"Cannot find output argument at position `{out_shape}` for \"\n                    f\"custom operator `{self.op_name}` with signature `{signature}`.\"\n                )\n\n        return fake_impl\n\n\n# Real implementation\ndef register_custom_op(\n    fn: Optional[Callable] = None,\n    *,\n    op_name: Optional[str] = None,\n    mutates_args: Optional[List[str]] = None,\n    eager: bool = True,\n    **extra_kwargs,\n) -> Any:\n    \"\"\"\n    A decorator to register a custom operator.\n\n    Example usage:\n    ```python\n    # inplace operator, out_shape is None by default\n    @register_custom_op(mutates_args=[\"x\"])\n    def add_1_(x: torch.Tensor) -> None:\n        x.add_(1)\n\n    # operator with output, out_shape indicates the position of output\n    @register_custom_op(mutates_args=[\"x\"], out_shape=0)\n    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n        return x.add_(y)\n    ```\n\n    :param fn: The function to be registered as a custom operator.\n               If None, return a decorator.\n    :type fn: Callable\n    :param op_name: The name of the operator. 
If None, the function name is used.\n    :type op_name: Optional[str]\n    :param mutates_args: A list of argument names that are mutated in-place.\n    :type mutates_args: List[str]\n    :param out_shape: The position (int for positional, str for keyword) of the output-shape tensor.\n                      It is used to generate a fake implementation for torch.compile compatibility.\n                      If the operator is inplace and has no output, set to None.\n    :type out_shape: Optional[Union[int, str]]\n    :param fake_impl: A fake implementation for the operator.\n                      Only one of `out_shape` or `fake_impl` should be provided.\n    :type fake_impl: Optional[Callable]\n    :param eager: Whether to register the operator eagerly.\n                  If False, the registration will be deferred until the first call.\n                  If you encounter issues with torch.compile, try setting eager=True.\n                  Currently, to avoid misuse, we set eager=True by default.\n    :type eager: bool\n    :return: The registered JIT custom operator, or a decorator.\n             NOTE: with eager=False, the actual registration is deferred until the first call of the function.\n    :rtype: Callable\n    \"\"\"\n    extra_kwarg_keys = set(extra_kwargs.keys())\n    expected_kwarg_keys = set({\"out_shape\", \"fake_impl\"})\n    assert (\n        expected_kwarg_keys >= extra_kwarg_keys\n    ), f\"Unexpected extra kwargs: {extra_kwarg_keys - expected_kwarg_keys}\"\n\n    has_out_shape = \"out_shape\" in extra_kwargs\n    has_fake_impl = \"fake_impl\" in extra_kwargs\n    assert not (\n        has_out_shape and has_fake_impl\n    ), \"Only one of `out_shape` or `fake_impl` should be provided.\"\n    # Assume inplace if neither out_shape nor fake_impl is provided\n    if not (has_out_shape or has_fake_impl):\n        extra_kwargs[\"out_shape\"] = None\n\n    def decorator(op_func: Callable) -> Callable:\n        wrapper = CustomOpWrapper(\n            op_name=op_name or 
op_func.__name__,\n            op_func=op_func,\n            mutates_args=mutates_args or [],\n            **extra_kwargs,\n        )\n        return wrapper.real_impl if eager else wrapper\n\n    if fn is not None:\n        return decorator(fn)\n    return decorator\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/visual_embedding.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nimport math\n\nimport torch\nimport torch.nn as nn\nfrom diffusers.models.embeddings import (\n    CombinedTimestepGuidanceTextProjEmbeddings as _CombinedTimestepGuidanceTextProjEmbeddings,\n)\nfrom diffusers.models.embeddings import (\n    CombinedTimestepTextProjEmbeddings as _CombinedTimestepTextProjEmbeddings,\n)\nfrom diffusers.models.embeddings import (\n    PixArtAlphaTextProjection,\n    TimestepEmbedding,\n)\nfrom diffusers.models.embeddings import Timesteps as _Timesteps\nfrom diffusers.models.embeddings import (\n    get_timestep_embedding as timestep_embedding_diffusers,\n)\n\nfrom sglang.jit_kernel.timestep_embedding import (\n    timestep_embedding as timestep_embedding_cuda,\n)\nfrom sglang.multimodal_gen.runtime.layers.activation import get_act_fn\nfrom sglang.multimodal_gen.runtime.layers.linear import ColumnParallelLinear\nfrom sglang.multimodal_gen.runtime.layers.mlp import MLP\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\n\n_is_cuda = current_platform.is_cuda()\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"2D Image to Patch Embedding\n\n    Image to Patch Embedding using Conv2d\n\n    A convolution based approach to patchifying a 2D image w/ embedding projection.\n\n    Based on the impl in https://github.com/google-research/vision_transformer\n\n    Hacked together by / Copyright 2020 Ross Wightman\n\n    Remove the _assert function in forward function to be compatible with multi-resolution images.\n    \"\"\"\n\n    def __init__(\n        self,\n        patch_size=16,\n        in_chans=3,\n        embed_dim=768,\n        norm_layer=None,\n        flatten=True,\n        bias=True,\n        dtype=None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        # Convert patch_size to 2-tuple\n        if isinstance(patch_size, list | tuple):\n            if len(patch_size) == 1:\n    
            patch_size = (patch_size[0], patch_size[0])\n        else:\n            patch_size = (patch_size, patch_size)\n\n        self.patch_size = patch_size\n        self.flatten = flatten\n\n        self.proj = nn.Conv3d(\n            in_chans,\n            embed_dim,\n            kernel_size=patch_size,\n            stride=patch_size,\n            bias=bias,\n            dtype=dtype,\n        )\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x):\n        x = self.proj(x)\n        if self.flatten:\n            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC\n        x = self.norm(x)\n        return x\n\n\nclass Timesteps(_Timesteps):\n    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:\n        if _is_cuda:\n            return timestep_embedding_cuda(\n                timesteps,\n                self.num_channels,\n                flip_sin_to_cos=self.flip_sin_to_cos,\n                downscale_freq_shift=self.downscale_freq_shift,\n                scale=self.scale,\n            )\n        else:\n            return timestep_embedding_diffusers(\n                timesteps,\n                self.num_channels,\n                flip_sin_to_cos=self.flip_sin_to_cos,\n                downscale_freq_shift=self.downscale_freq_shift,\n                scale=self.scale,\n            )\n\n\nclass CombinedTimestepGuidanceTextProjEmbeddings(\n    _CombinedTimestepGuidanceTextProjEmbeddings\n):\n    def __init__(self, embedding_dim, pooled_projection_dim):\n        nn.Module.__init__(self)\n\n        # use sgld op\n        self.time_proj = Timesteps(\n            num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0\n        )\n        # use diffusers op\n        self.timestep_embedder = TimestepEmbedding(\n            in_channels=256, time_embed_dim=embedding_dim\n        )\n        self.guidance_embedder = TimestepEmbedding(\n            in_channels=256, time_embed_dim=embedding_dim\n        )\n     
   self.text_embedder = PixArtAlphaTextProjection(\n            pooled_projection_dim, embedding_dim, act_fn=\"silu\"\n        )\n\n\nclass CombinedTimestepTextProjEmbeddings(_CombinedTimestepTextProjEmbeddings):\n    def __init__(self, embedding_dim, pooled_projection_dim):\n        nn.Module.__init__(self)\n\n        # use sgld op\n        self.time_proj = Timesteps(\n            num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0\n        )\n        # use diffusers op\n        self.timestep_embedder = TimestepEmbedding(\n            in_channels=256, time_embed_dim=embedding_dim\n        )\n        self.text_embedder = PixArtAlphaTextProjection(\n            pooled_projection_dim, embedding_dim, act_fn=\"silu\"\n        )\n\n\nclass TimestepEmbedder(nn.Module):\n    \"\"\"\n    Embeds scalar timesteps into vector representations.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size,\n        act_layer=\"silu\",\n        frequency_embedding_size=256,\n        max_period=10000,\n        dtype=None,\n        freq_dtype=torch.float32,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.frequency_embedding_size = frequency_embedding_size\n        self.max_period = max_period\n\n        self.mlp = MLP(\n            frequency_embedding_size,\n            hidden_size,\n            hidden_size,\n            act_type=act_layer,\n            dtype=dtype,\n        )\n        self.freq_dtype = freq_dtype\n\n    def forward(\n        self, t: torch.Tensor, timestep_seq_len: int | None = None\n    ) -> torch.Tensor:\n        t_freq = timestep_embedding(\n            t, self.frequency_embedding_size, self.max_period, dtype=self.freq_dtype\n        ).to(self.mlp.fc_in.weight.dtype)\n        if timestep_seq_len is not None:\n            assert (\n                t_freq.shape[0] % timestep_seq_len == 0\n            ), \"timestep length is not divisible by timestep_seq_len\"\n            batch_size = t_freq.shape[0] // 
timestep_seq_len\n            t_freq = t_freq.unflatten(0, (batch_size, timestep_seq_len))\n        # t_freq = t_freq.to(self.mlp.fc_in.weight.dtype)\n        t_emb = self.mlp(t_freq)\n        return t_emb\n\n\ndef timestep_embedding(\n    t: torch.Tensor,\n    dim: int,\n    max_period: int = 10000,\n    dtype: torch.dtype = torch.float32,\n) -> torch.Tensor:\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n\n    Args:\n        t: Tensor of shape [B] with timesteps\n        dim: Embedding dimension\n        max_period: Controls the minimum frequency of the embeddings\n\n    Returns:\n        Tensor of shape [B, dim] with embeddings\n    \"\"\"\n    half = dim // 2\n    freqs = torch.exp(\n        -math.log(max_period)\n        * torch.arange(start=0, end=half, dtype=dtype, device=t.device)\n        / half\n    )\n    args = t[:, None].float() * freqs[None]\n    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n    if dim % 2:\n        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n    return embedding\n\n\nclass ModulateProjection(nn.Module):\n    \"\"\"Modulation layer for DiT blocks.\"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        factor: int = 2,\n        act_layer: str = \"silu\",\n        dtype: torch.dtype | None = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.factor = factor\n        self.hidden_size = hidden_size\n        self.linear = ColumnParallelLinear(\n            hidden_size,\n            hidden_size * factor,\n            bias=True,\n            gather_output=True,\n            params_dtype=dtype,\n        )\n        self.act = get_act_fn(act_layer)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.act(x)\n        x, _ = self.linear(x)\n        return x\n\n\ndef unpatchify(x, t, h, w, patch_size, channels) -> torch.Tensor:\n    \"\"\"\n    Convert patched representation back to image space.\n\n    
Args:\n        x: Tensor of shape [B, T*H*W, C*P_t*P_h*P_w]\n        t, h, w: Temporal and spatial dimensions\n\n    Returns:\n        Unpatchified tensor of shape [B, C, T*P_t, H*P_h, W*P_w]\n    \"\"\"\n    assert x.ndim == 3, f\"x.ndim: {x.ndim}\"\n    assert len(patch_size) == 3, f\"patch_size: {patch_size}\"\n    assert t * h * w == x.shape[1], f\"t * h * w: {t * h * w}, x.shape[1]: {x.shape[1]}\"\n    c = channels\n    pt, ph, pw = patch_size\n\n    x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))\n    x = torch.einsum(\"nthwcopq->nctohpwq\", x)\n    imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))\n\n    return imgs\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nfrom collections.abc import Sequence\nfrom dataclasses import dataclass\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom torch.nn.parameter import Parameter, UninitializedParameter\n\nfrom sglang.multimodal_gen.runtime.distributed import (\n    divide,\n    get_tp_group,\n    tensor_model_parallel_all_reduce,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n    QuantizeMethodBase,\n    method_has_implemented_embedding,\n)\nfrom sglang.multimodal_gen.runtime.layers.utils import get_group_rank, get_group_size\nfrom sglang.multimodal_gen.runtime.models.parameter import BasevLLMParameter\nfrom sglang.multimodal_gen.runtime.models.utils import set_weight_attrs\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\n\nDEFAULT_VOCAB_PADDING_SIZE = 64\n\n\nclass UnquantizedEmbeddingMethod(QuantizeMethodBase):\n    \"\"\"Unquantized method for embeddings.\"\"\"\n\n    def create_weights(\n        self,\n        layer: torch.nn.Module,\n        input_size_per_partition: int,\n        output_partition_sizes: list[int],\n        input_size: int,\n        output_size: int,\n        params_dtype: torch.dtype,\n        **extra_weight_attrs,\n    ):\n        \"\"\"Create weights for embedding layer.\"\"\"\n\n        weight = Parameter(\n            torch.empty(\n                sum(output_partition_sizes),\n                input_size_per_partition,\n                dtype=params_dtype,\n            ),\n            requires_grad=False,\n        )\n        set_weight_attrs(weight, {\"input_dim\": 1, \"output_dim\": 0})\n        layer.register_parameter(\"weight\", weight)\n        set_weight_attrs(weight, extra_weight_attrs)\n\n    def apply(\n        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None\n    ) -> 
torch.Tensor:\n        return F.linear(x, layer.weight, bias)\n\n    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:\n        return F.embedding(input_, layer.weight)\n\n\ndef pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:\n    \"\"\"Round the vocab size up to the nearest multiple of pad_to.\"\"\"\n    return ((vocab_size + pad_to - 1) // pad_to) * pad_to\n\n\ndef vocab_range_from_per_partition_vocab_size(\n    per_partition_vocab_size: int, rank: int, offset: int = 0\n) -> Sequence[int]:\n    index_f = rank * per_partition_vocab_size\n    index_l = index_f + per_partition_vocab_size\n    return index_f + offset, index_l + offset\n\n\ndef vocab_range_from_global_vocab_size(\n    global_vocab_size: int, rank: int, world_size: int, offset: int = 0\n) -> Sequence[int]:\n    per_partition_vocab_size = divide(global_vocab_size, world_size)\n    return vocab_range_from_per_partition_vocab_size(\n        per_partition_vocab_size, rank, offset=offset\n    )\n\n\n@dataclass\nclass VocabParallelEmbeddingShardIndices:\n    \"\"\"Indices for a shard of a vocab parallel embedding.\"\"\"\n\n    padded_org_vocab_start_index: int\n    padded_org_vocab_end_index: int\n    padded_added_vocab_start_index: int\n    padded_added_vocab_end_index: int\n\n    org_vocab_start_index: int\n    org_vocab_end_index: int\n    added_vocab_start_index: int\n    added_vocab_end_index: int\n\n    @property\n    def num_org_elements(self) -> int:\n        return self.org_vocab_end_index - self.org_vocab_start_index\n\n    @property\n    def num_added_elements(self) -> int:\n        return self.added_vocab_end_index - self.added_vocab_start_index\n\n    @property\n    def num_org_elements_padded(self) -> int:\n        return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index\n\n    @property\n    def num_added_elements_padded(self) -> int:\n        return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index\n\n    
@property\n    def num_org_vocab_padding(self) -> int:\n        return self.num_org_elements_padded - self.num_org_elements\n\n    @property\n    def num_added_vocab_padding(self) -> int:\n        return self.num_added_elements_padded - self.num_added_elements\n\n    @property\n    def num_elements_padded(self) -> int:\n        return self.num_org_elements_padded + self.num_added_elements_padded\n\n    def __post_init__(self):\n        # sanity checks\n        assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index\n        assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index\n\n        assert self.org_vocab_start_index <= self.org_vocab_end_index\n        assert self.added_vocab_start_index <= self.added_vocab_end_index\n\n        assert self.org_vocab_start_index <= self.padded_org_vocab_start_index\n        assert self.added_vocab_start_index <= self.padded_added_vocab_start_index\n        assert self.org_vocab_end_index <= self.padded_org_vocab_end_index\n        assert self.added_vocab_end_index <= self.padded_added_vocab_end_index\n\n        assert self.num_org_elements <= self.num_org_elements_padded\n        assert self.num_added_elements <= self.num_added_elements_padded\n\n\n@torch.compile(\n    dynamic=True,\n    backend=current_platform.simple_compile_backend,\n    disable=current_platform.is_npu(),\n)\ndef get_masked_input_and_mask(\n    input_: torch.Tensor,\n    org_vocab_start_index: int,\n    org_vocab_end_index: int,\n    num_org_vocab_padding: int,\n    added_vocab_start_index: int,\n    added_vocab_end_index: int,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    # torch.compile will fuse all of the pointwise ops below\n    # into a single kernel, making it very fast\n    org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)\n    added_vocab_mask = (input_ >= added_vocab_start_index) & (\n        input_ < added_vocab_end_index\n    )\n    added_offset = (\n        
added_vocab_start_index\n        - (org_vocab_end_index - org_vocab_start_index)\n        - num_org_vocab_padding\n    )\n    valid_offset = (org_vocab_start_index * org_vocab_mask) + (\n        added_offset * added_vocab_mask\n    )\n    vocab_mask = org_vocab_mask | added_vocab_mask\n    input_ = vocab_mask * (input_ - valid_offset)\n    return input_, ~vocab_mask\n\n\nclass VocabParallelEmbedding(torch.nn.Module):\n    \"\"\"Embedding parallelized in the vocabulary dimension.\n\n    Adapted from torch.nn.Embedding. Note that we pad the vocabulary size to\n    make sure it is divisible by the number of model parallel GPUs.\n\n    In order to support various loading methods, we ensure that LoRA-added\n    embeddings are always at the end of TP-sharded tensors. In other words,\n    we shard base embeddings and LoRA embeddings separately (both padded),\n    and place them in the same tensor.\n    For example, take an original vocab size of 1010, an added vocab size\n    of 16, and padding to 64. The total vocab size with padding is then 1088\n    (we first pad 1010 to 1024, add 16, and then pad to 1088).\n    The tensor layout then looks like the following:\n    TP1, rank 0 (no sharding):\n                            |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|\n    corresponding token_id: |  0  |  1  | ... | 1009 |  -1  | ... |  -1  | 1010 | ... | 1025 |  -1  | ... |  -1  |\n                     index: |  0  |  1  | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |\n\n    TP2, rank 0:\n                            |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|\n    corresponding token_id: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 1010 | ... | 1025 |  -1  | ... |  -1 |\n                     index: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 512  | ... | 527  |  528 | ... 
| 543 |\n    TP2, rank 1:\n                            |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|\n    corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1  | ...  | -1  |  -1  | ... |  -1  | -1  | ... |   -1 |\n                     index: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 512  | ... | 519  | 520 | ... |  543 |\n\n    Args:\n        num_embeddings: vocabulary size.\n        embedding_dim: size of hidden state.\n        params_dtype: type of the parameters.\n        org_num_embeddings: original vocabulary size (without LoRA).\n        padding_size: padding size for the vocabulary.\n        quant_config: quant config for the layer\n        prefix: full name of the layer in the state dict\n    \"\"\"  # noqa: E501\n\n    def __init__(\n        self,\n        num_embeddings: int,\n        embedding_dim: int,\n        params_dtype: torch.dtype | None = None,\n        org_num_embeddings: int | None = None,\n        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n        tp_group: dist.ProcessGroup = None,\n    ):\n        super().__init__()\n\n        # Keep the input dimensions.\n        tp_group = tp_group or get_tp_group()\n        tp_rank = get_group_rank(tp_group)\n        self.tp_size = get_group_size(tp_group)\n        self.tp_group = tp_group\n        self.num_embeddings = num_embeddings\n        self.padding_size = padding_size\n        self.org_vocab_size = org_num_embeddings or num_embeddings\n        num_added_embeddings = num_embeddings - self.org_vocab_size\n        self.org_vocab_size_padded = pad_vocab_size(\n            self.org_vocab_size, self.padding_size\n        )\n        self.num_embeddings_padded = pad_vocab_size(\n            self.org_vocab_size_padded + num_added_embeddings, self.padding_size\n        )\n        assert self.org_vocab_size_padded <= self.num_embeddings_padded\n\n        
self.shard_indices = self._get_indices(\n            self.num_embeddings_padded,\n            self.org_vocab_size_padded,\n            self.num_embeddings,\n            self.org_vocab_size,\n            tp_rank,\n            self.tp_size,\n        )\n        self.embedding_dim = embedding_dim\n\n        quant_method = None\n        if quant_config is not None:\n            quant_method = quant_config.get_quant_method(self, prefix=prefix)\n        if quant_method is None:\n            quant_method = UnquantizedEmbeddingMethod()\n\n        # If we are making an embedding layer, then our quantization linear\n        # method must implement the embedding operation. If we are another\n        # layer type like ParallelLMHead, this is not important.\n        is_embedding_layer = type(self) is VocabParallelEmbedding\n        quant_method_implements_embedding = method_has_implemented_embedding(\n            type(quant_method)\n        )\n        if is_embedding_layer and not quant_method_implements_embedding:\n            raise NotImplementedError(\n                f\"The class {type(quant_method).__name__} must implement \"\n                \"the 'embedding' method, see UnquantizedEmbeddingMethod.\"\n            )\n\n        self.quant_method: QuantizeMethodBase = quant_method\n\n        if params_dtype is None:\n            params_dtype = torch.get_default_dtype()\n        # Divide the weight matrix along the vocabulary dimension.\n        self.num_added_embeddings = self.num_embeddings - self.org_vocab_size\n        self.num_embeddings_per_partition = divide(\n            self.num_embeddings_padded, self.tp_size\n        )\n        assert (\n            self.shard_indices.num_elements_padded == self.num_embeddings_per_partition\n        )\n        self.num_org_embeddings_per_partition = (\n            self.shard_indices.org_vocab_end_index\n            - self.shard_indices.org_vocab_start_index\n        )\n        self.num_added_embeddings_per_partition = (\n  
          self.shard_indices.added_vocab_end_index\n            - self.shard_indices.added_vocab_start_index\n        )\n\n        self.quant_method.create_weights(\n            self,\n            self.embedding_dim,\n            [self.num_embeddings_per_partition],\n            self.embedding_dim,\n            self.num_embeddings_padded,\n            params_dtype=params_dtype,\n            weight_loader=self.weight_loader,\n        )\n\n    @classmethod\n    def _get_indices(\n        cls,\n        vocab_size_padded: int,\n        org_vocab_size_padded: int,\n        vocab_size: int,\n        org_vocab_size: int,\n        tp_rank: int,\n        tp_size: int,\n    ) -> VocabParallelEmbeddingShardIndices:\n        \"\"\"Get start and end indices for vocab parallel embedding, following the\n        layout outlined in the class docstring, based on the given tp_rank and\n        tp_size.\"\"\"\n        num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded\n        padded_org_vocab_start_index, padded_org_vocab_end_index = (\n            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size)\n        )\n        padded_added_vocab_start_index, padded_added_vocab_end_index = (\n            vocab_range_from_global_vocab_size(\n                num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size\n            )\n        )\n        # remove padding\n        org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)\n        org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)\n        added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)\n        added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)\n        return VocabParallelEmbeddingShardIndices(\n            padded_org_vocab_start_index,\n            padded_org_vocab_end_index,\n            padded_added_vocab_start_index,\n            padded_added_vocab_end_index,\n            
org_vocab_start_index,\n            org_vocab_end_index,\n            added_vocab_start_index,\n            added_vocab_end_index,\n        )\n\n    def get_sharded_to_full_mapping(self) -> list[int] | None:\n        \"\"\"Get a mapping that can be used to reindex the gathered\n        logits for sampling.\n\n        During sampling, we gather logits from all ranks. The relationship\n        of index->token_id will follow the same format as outlined in the class\n        docstring. However, after the gather, we want to reindex the final\n        logits tensor to map index->token_id one-to-one (the index is always\n        equal to the token_id it corresponds to). The indices returned by this\n        method allow us to do that.\n        \"\"\"\n        if self.tp_size < 2:\n            return None\n\n        base_embeddings: list[int] = []\n        added_embeddings: list[int] = []\n        padding: list[int] = []\n        for tp_rank in range(self.tp_size):\n            shard_indices = self._get_indices(\n                self.num_embeddings_padded,\n                self.org_vocab_size_padded,\n                self.num_embeddings,\n                self.org_vocab_size,\n                tp_rank,\n                self.tp_size,\n            )\n            range_start = self.num_embeddings_per_partition * tp_rank\n            range_end = self.num_embeddings_per_partition * (tp_rank + 1)\n            base_embeddings.extend(\n                range(range_start, range_start + shard_indices.num_org_elements)\n            )\n            padding.extend(\n                range(\n                    range_start + shard_indices.num_org_elements,\n                    range_start + shard_indices.num_org_elements_padded,\n                )\n            )\n            added_embeddings.extend(\n                range(\n                    range_start + shard_indices.num_org_elements_padded,\n                    range_start\n                    + shard_indices.num_org_elements_padded\n      
              + shard_indices.num_added_elements,\n                )\n            )\n            padding.extend(\n                range(\n                    range_start\n                    + shard_indices.num_org_elements_padded\n                    + shard_indices.num_added_elements,\n                    range_start\n                    + shard_indices.num_org_elements_padded\n                    + shard_indices.num_added_elements_padded,\n                )\n            )\n            assert (\n                range_start\n                + shard_indices.num_org_elements_padded\n                + shard_indices.num_added_elements_padded\n                == range_end\n            )\n        ret = base_embeddings + added_embeddings + padding\n        assert len(ret) == self.num_embeddings_padded\n        return ret\n\n    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):\n        output_dim = getattr(param, \"output_dim\", None)\n        packed_dim = getattr(param, \"packed_dim\", None)\n\n        # If the parameter is a gguf weight, then load it directly.\n        if getattr(param, \"is_gguf_weight_type\", None):\n            param.data.copy_(loaded_weight)\n            param.weight_type = loaded_weight.item()\n            return\n        elif isinstance(param, UninitializedParameter):\n            shape = list(loaded_weight.shape)\n            if output_dim is not None:\n                shape[output_dim] = self.num_embeddings_per_partition\n            param.materialize(tuple(shape), dtype=loaded_weight.dtype)\n\n        # If parameter does not have output dim, then it should\n        # be copied onto all gpus (e.g. 
g_idx for act_order gptq).\n        if output_dim is None:\n            assert param.data.shape == loaded_weight.shape\n            param.data.copy_(loaded_weight)\n            return\n\n        # Shard indexes for loading the weight\n        start_idx = self.shard_indices.org_vocab_start_index\n        shard_size = self.shard_indices.org_vocab_end_index - start_idx\n\n        # If the param is packed along the same dim we shard on, we\n        # need to adjust the offsets of the loaded weight by pack_factor.\n        if packed_dim is not None and packed_dim == output_dim:\n            packed_factor = (\n                param.packed_factor\n                if isinstance(param, BasevLLMParameter)\n                else param.pack_factor\n            )\n            assert loaded_weight.shape[output_dim] == (\n                self.org_vocab_size // packed_factor\n            )\n            start_idx = start_idx // packed_factor\n            shard_size = shard_size // packed_factor\n        else:\n            assert loaded_weight.shape[output_dim] == self.org_vocab_size\n\n        # Copy the data. 
Select chunk corresponding to current shard.\n        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)\n\n        param[: loaded_weight.shape[0]].data.copy_(loaded_weight)\n        param[loaded_weight.shape[0] :].data.fill_(0)\n\n    def forward(self, input_):\n        if self.tp_size > 1:\n            # Build the mask.\n            masked_input, input_mask = get_masked_input_and_mask(\n                input_,\n                self.shard_indices.org_vocab_start_index,\n                self.shard_indices.org_vocab_end_index,\n                self.shard_indices.num_org_vocab_padding,\n                self.shard_indices.added_vocab_start_index,\n                self.shard_indices.added_vocab_end_index,\n            )\n        else:\n            masked_input = input_\n        # Get the embeddings.\n        output_parallel = self.quant_method.embedding(self, masked_input.long())\n        # Mask the output embedding.\n        if self.tp_size > 1:\n            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)\n        # Reduce across all the model parallel GPUs.\n        output = tensor_model_parallel_all_reduce(\n            output_parallel, tp_group=self.tp_group\n        )\n        return output\n\n    def extra_repr(self) -> str:\n        s = f\"num_embeddings={self.num_embeddings_per_partition}\"\n        s += f\", embedding_dim={self.embedding_dim}\"\n        s += f\", org_vocab_size={self.org_vocab_size}\"\n        s += f\", num_embeddings_padded={self.num_embeddings_padded}\"\n        s += f\", tp_size={self.tp_size}\"\n        return s\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/component_loaders/adapter_loader.py",
    "content": "from safetensors.torch import load_file as safetensors_load_file\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import (\n    ComponentLoader,\n)\nfrom sglang.multimodal_gen.runtime.loader.utils import (\n    _list_safetensors_files,\n    set_default_torch_dtype,\n    skip_init_modules,\n)\nfrom sglang.multimodal_gen.runtime.models.registry import ModelRegistry\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (\n    get_diffusers_component_config,\n)\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\n\n\nclass AdapterLoader(ComponentLoader):\n    \"\"\"Loader for small adapter-style modules (e.g., LTX-2 connectors).\n\n    This loader intentionally avoids FSDP sharding and just:\n    1) Instantiates the module from `config.json`.\n    2) Loads a single safetensors state_dict.\n    \"\"\"\n\n    component_names = [\"connectors\"]\n    expected_library = \"diffusers\"\n\n    def load_customized(\n        self, component_model_path: str, server_args: ServerArgs, *args\n    ):\n        config = get_diffusers_component_config(component_path=component_model_path)\n\n        cls_name = config.pop(\"_class_name\", None)\n        if cls_name is None:\n            raise ValueError(\n                \"Model config does not contain a _class_name attribute. 
\"\n                \"Only diffusers format is supported.\"\n            )\n\n        config.pop(\"_diffusers_version\", None)\n        config.pop(\"_name_or_path\", None)\n\n        server_args.model_paths[\"connectors\"] = component_model_path\n\n        model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)\n\n        target_device = get_local_torch_device()\n        default_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision]\n\n        from types import SimpleNamespace\n\n        with set_default_torch_dtype(default_dtype), skip_init_modules():\n            connector_cfg = SimpleNamespace(**config)\n            model = model_cls(connector_cfg).to(\n                device=target_device, dtype=default_dtype\n            )\n\n        safetensors_list = _list_safetensors_files(component_model_path)\n        if not safetensors_list:\n            raise ValueError(f\"No safetensors files found in {component_model_path}\")\n        if len(safetensors_list) != 1:\n            raise ValueError(\n                f\"Found {len(safetensors_list)} safetensors files in {component_model_path}, expected 1\"\n            )\n\n        loaded = safetensors_load_file(safetensors_list[0])\n        model.load_state_dict(loaded, strict=False)\n\n        return model\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/component_loaders/bridge_loader.py",
    "content": "from copy import deepcopy\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import (\n    ComponentLoader,\n)\nfrom sglang.multimodal_gen.runtime.loader.fsdp_load import maybe_load_fsdp_model\nfrom sglang.multimodal_gen.runtime.loader.utils import _list_safetensors_files\nfrom sglang.multimodal_gen.runtime.models.registry import ModelRegistry\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (\n    get_diffusers_component_config,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\n\nlogger = init_logger(__name__)\n\n\nclass BridgeLoader(ComponentLoader):\n    \"\"\"Loader for MOVA dual tower bridge with FSDP support.\"\"\"\n\n    pipeline_bridge_config_attr: str = \"bridge_config\"\n\n    component_names = [\"dual_tower_bridge\"]\n    expected_library = \"diffusers\"\n\n    def load_customized(\n        self, component_model_path: str, server_args: ServerArgs, component_name: str\n    ):\n        config = get_diffusers_component_config(component_path=component_model_path)\n        hf_config = deepcopy(config)\n        class_name = config.pop(\"_class_name\", None)\n        if class_name is None:\n            raise ValueError(\n                \"Model config does not contain a _class_name attribute. 
\"\n                \"Only diffusers format is supported.\"\n            )\n        server_args.model_paths[component_name] = component_model_path\n\n        # Try to get bridge config from pipeline config, fallback to creating one\n        bridge_config = getattr(\n            server_args.pipeline_config, self.pipeline_bridge_config_attr, None\n        )\n        if bridge_config is not None:\n            bridge_config.update_model_arch(config)\n        else:\n            # Create a minimal config from hf_config\n            from sglang.multimodal_gen.configs.models.bridges.mova_dual_tower import (\n                MOVADualTowerConfig,\n            )\n\n            bridge_config = MOVADualTowerConfig()\n            bridge_config.update_model_arch(config)\n\n        model_cls, _ = ModelRegistry.resolve_model_cls(class_name)\n\n        # Find all safetensors files\n        safetensors_list = _list_safetensors_files(component_model_path)\n        if not safetensors_list:\n            raise ValueError(f\"No safetensors files found in {component_model_path}\")\n\n        default_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision]\n\n        logger.info(\n            \"Loading %s from %s safetensors files, default_dtype: %s\",\n            class_name,\n            len(safetensors_list),\n            default_dtype,\n        )\n\n        # Check if FSDP loading is available\n        if (\n            server_args.hsdp_shard_dim is not None\n            and hasattr(model_cls, \"_fsdp_shard_conditions\")\n            and model_cls._fsdp_shard_conditions\n        ):\n            # Load with FSDP support\n            model = maybe_load_fsdp_model(\n                model_cls=model_cls,\n                init_params={\"config\": bridge_config, \"hf_config\": hf_config},\n                weight_dir_list=safetensors_list,\n                device=get_local_torch_device(),\n                hsdp_replicate_dim=server_args.hsdp_replicate_dim,\n                
hsdp_shard_dim=server_args.hsdp_shard_dim,\n                cpu_offload=server_args.dit_cpu_offload,\n                pin_cpu_memory=server_args.pin_cpu_memory,\n                fsdp_inference=server_args.use_fsdp_inference,\n                param_dtype=default_dtype,\n                reduce_dtype=torch.float32,\n                output_dtype=None,\n                strict=False,\n            )\n        else:\n            # Fallback to simple loading (for non-FSDP or legacy models)\n            model = model_cls.from_pretrained(\n                component_model_path, torch_dtype=default_dtype\n            )\n            model = model.to(device=get_local_torch_device(), dtype=default_dtype)\n\n        total_params = sum(p.numel() for p in model.parameters())\n        logger.info(\"Loaded bridge model with %.2fM parameters\", total_params / 1e6)\n\n        return model\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/component_loaders/component_loader.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nimport importlib\nimport os\nimport pkgutil\nimport traceback\nfrom abc import ABC\nfrom typing import Any, Type\n\nimport torch\nfrom diffusers import AutoModel\nfrom torch import nn\nfrom transformers import AutoImageProcessor, AutoProcessor, AutoTokenizer\n\nfrom sglang.multimodal_gen.configs.models import ModelConfig\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.loader.utils import (\n    _normalize_component_type,\n    component_name_to_loader_cls,\n    get_memory_usage_of_component,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import get_hf_config\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass ComponentLoader(ABC):\n    \"\"\"Base class for loading a specific type of model component.\"\"\"\n\n    # the list of possible name of the component in model_index.json, e.g., scheduler\n    component_names: list[str] = []\n\n    # diffusers or transformers\n    expected_library: str = \"\"\n\n    _loaders_registered = False\n\n    def __init_subclass__(cls, **kwargs):\n        \"\"\"\n        register loaders, called when subclass is imported\n        \"\"\"\n        super().__init_subclass__(**kwargs)\n        for component_name in cls.component_names:\n            component_name_to_loader_cls[component_name] = cls\n\n    def __init__(self, device=None) -> None:\n        self.device = device\n\n    def should_offload(\n        self, server_args: ServerArgs, model_config: ModelConfig | None = None\n    ):\n        # not offload by default\n        return False\n\n    def target_device(self, should_offload):\n        if should_offload:\n            
return (\n                torch.device(\"mps\")\n                if current_platform.is_mps()\n                else torch.device(\"cpu\")\n            )\n        else:\n            return get_local_torch_device()\n\n    def load(\n        self,\n        component_model_path: str,\n        server_args: ServerArgs,\n        component_name: str,\n        transformers_or_diffusers: str,\n    ) -> tuple[AutoModel, float]:\n        \"\"\"\n        Template method that standardizes logging around the core load implementation.\n        Loading methods are tried in this order:\n            1. load the customized component\n            2. load the native diffusers/transformers component\n        If both methods fail, an error is raised.\n        \"\"\"\n        gpu_mem_before_loading = current_platform.get_available_gpu_memory()\n        logger.info(\n            \"Loading %s from %s. avail mem: %.2f GB\",\n            component_name,\n            component_model_path,\n            gpu_mem_before_loading,\n        )\n        try:\n            component = self.load_customized(\n                component_model_path, server_args, component_name\n            )\n            source = \"sgl-diffusion\"\n        except Exception as e:\n            if \"Unsupported model architecture\" in str(e):\n                logger.info(\n                    f\"Component: {component_name} doesn't have a customized version yet, using native version\"\n                )\n            else:\n                traceback.print_exc()\n                logger.error(\n                    f\"Error while loading customized {component_name}, falling back to native version\"\n                )\n            # fallback to native version\n            component = self.load_native(\n                component_model_path, server_args, transformers_or_diffusers\n            )\n            should_offload = self.should_offload(server_args)\n            target_device = self.target_device(should_offload)\n            
component = component.to(device=target_device)\n            source = \"native\"\n            logger.warning(\n                \"Native component %s: %s is loaded, performance may be sub-optimal\",\n                component_name,\n                component.__class__.__name__,\n            )\n\n        if component is None:\n            logger.error(\"Load %s failed\", component_name)\n            consumed = 0.0\n        else:\n            if isinstance(component, nn.Module):\n                component = component.eval()\n            current_gpu_mem = current_platform.get_available_gpu_memory()\n            model_size = get_memory_usage_of_component(component) or \"NA\"\n            consumed = gpu_mem_before_loading - current_gpu_mem\n            logger.info(\n                f\"Loaded %s: %s ({source} version). model size: %s GB, consumed GPU mem: %.2f GB, avail GPU mem: %.2f GB\",\n                component_name,\n                component.__class__.__name__,\n                model_size,\n                consumed,\n                current_gpu_mem,\n            )\n        return component, consumed\n\n    def load_native(\n        self,\n        component_model_path: str,\n        server_args: ServerArgs,\n        transformers_or_diffusers: str,\n    ) -> AutoModel:\n        \"\"\"\n        Load the component using the native library (transformers/diffusers).\n        \"\"\"\n        if transformers_or_diffusers == \"transformers\":\n            from transformers import AutoModel\n\n            config = get_hf_config(\n                component_model_path,\n                trust_remote_code=server_args.trust_remote_code,\n                revision=server_args.revision,\n            )\n            return AutoModel.from_pretrained(\n                component_model_path,\n                config=config,\n                trust_remote_code=server_args.trust_remote_code,\n                revision=server_args.revision,\n            )\n        elif transformers_or_diffusers 
== \"diffusers\":\n            from diffusers import AutoModel\n\n            return AutoModel.from_pretrained(\n                component_model_path,\n                revision=server_args.revision,\n                trust_remote_code=server_args.trust_remote_code,\n            )\n        else:\n            raise ValueError(f\"Unsupported library: {transformers_or_diffusers}\")\n\n    def load_customized(\n        self, component_model_path: str, server_args: ServerArgs, component_name: str\n    ):\n        \"\"\"\n        Load the customized version component, implemented and optimized in SGL-diffusion\n        \"\"\"\n        raise NotImplementedError(\n            f\"load_customized not implemented for {self.__class__.__name__}\"\n        )\n\n    @classmethod\n    def _ensure_loaders_registered(cls):\n        \"\"\"\n        avoid multiple registration\n        \"\"\"\n        if cls._loaders_registered:\n            return\n\n        package_dir = os.path.dirname(__file__)\n        package_name = (\n            __package__\n            or \"sglang.multimodal_gen.runtime.loader.component_loaders.component_loaders\"\n        )\n\n        for _, name, _ in pkgutil.iter_modules([package_dir]):\n            # skip importing self to avoid circular dependency issues\n            if name == \"component_loader\":\n                continue\n            try:\n                importlib.import_module(f\".{name}\", package=package_name)\n            except ImportError as e:\n                logger.warning(f\"Failed to import loader component {name}: {e}\")\n\n        cls._loaders_registered = True\n\n    @classmethod\n    def for_component_type(\n        cls, component_name: str, transformers_or_diffusers: str\n    ) -> \"ComponentLoader\":\n        \"\"\"\n        Factory method to create a component loader for a specific component type.\n\n        Args:\n            component_name: Type of component (e.g., \"vae\", \"text_encoder\", \"transformer\", \"scheduler\")\n        
    transformers_or_diffusers: Whether the component is from transformers or diffusers\n        \"\"\"\n        cls._ensure_loaders_registered()\n\n        # Map of component types to their loader classes and expected library\n        component_name = _normalize_component_type(component_name)\n\n        # NOTE(FlamingoPg): special for LTX-2 models\n        if component_name == \"vocoder\" or component_name == \"connectors\":\n            transformers_or_diffusers = \"diffusers\"\n\n        # NOTE(CloudRipple): special for MOVA models\n        # TODO(CloudRipple): remove most of these special cases after unifying the loading logic\n        if component_name in [\n            \"audio_vae\",\n            \"audio_dit\",\n            \"dual_tower_bridge\",\n            \"video_dit\",\n        ]:\n            transformers_or_diffusers = \"diffusers\"\n\n        if (\n            component_name == \"scheduler\"\n            and transformers_or_diffusers == \"mova.diffusion.schedulers.flow_match_pair\"\n        ):\n            transformers_or_diffusers = \"diffusers\"\n\n        if component_name in component_name_to_loader_cls:\n            loader_cls: Type[ComponentLoader] = component_name_to_loader_cls[\n                component_name\n            ]\n            expected_library = loader_cls.expected_library\n            # Assert that the library matches what's expected for this component type\n            assert (\n                transformers_or_diffusers == expected_library\n            ), f\"{component_name} must be loaded from {expected_library}, got {transformers_or_diffusers}\"\n            return loader_cls()\n\n        # For unknown component types, use a generic loader\n        logger.warning(\n            \"No specific loader found for component type: %s. 
Using generic loader.\",\n            component_name,\n        )\n        return GenericComponentLoader(transformers_or_diffusers)\n\n\nclass ImageProcessorLoader(ComponentLoader):\n    \"\"\"Loader for image processor.\"\"\"\n\n    component_names = [\"image_processor\"]\n    expected_library = \"transformers\"\n\n    def load_customized(\n        self, component_model_path: str, server_args: ServerArgs, component_name: str\n    ) -> Any:\n        return AutoImageProcessor.from_pretrained(component_model_path, use_fast=True)\n\n\nclass AutoProcessorLoader(ComponentLoader):\n    \"\"\"Loader for auto processor.\"\"\"\n\n    component_names = [\"processor\"]\n    expected_library = \"transformers\"\n\n    def load_customized(\n        self, component_model_path: str, server_args: ServerArgs, component_name: str\n    ) -> Any:\n        return AutoProcessor.from_pretrained(component_model_path)\n\n\nclass TokenizerLoader(ComponentLoader):\n    \"\"\"Loader for tokenizers.\"\"\"\n\n    component_names = [\"tokenizer\"]\n    expected_library = \"transformers\"\n\n    def load_customized(\n        self, component_model_path: str, server_args: ServerArgs, component_name: str\n    ) -> Any:\n        return AutoTokenizer.from_pretrained(\n            component_model_path,\n            padding_side=\"right\",\n        )\n\n\nclass GenericComponentLoader(ComponentLoader):\n    \"\"\"Generic loader for components that don't have a specific loader.\"\"\"\n\n    def __init__(self, library=\"transformers\") -> None:\n        super().__init__()\n        self.library = library\n\n\nclass PipelineComponentLoader:\n    \"\"\"\n    Utility class for loading the components in a pipeline.\n    \"\"\"\n\n    @staticmethod\n    def load_component(\n        component_name: str,\n        component_model_path: str,\n        transformers_or_diffusers: str,\n        server_args: ServerArgs,\n    ):\n        
\"\"\"\n        Load a pipeline component.\n\n        Args:\n            
component_name: Name of the component (e.g., \"vae\", \"text_encoder\", \"transformer\", \"scheduler\")\n            component_model_path: Path to the component model\n            transformers_or_diffusers: Whether the component is from transformers or diffusers\n\n        \"\"\"\n\n        # Get the appropriate loader for this component type\n        loader = ComponentLoader.for_component_type(\n            component_name, transformers_or_diffusers\n        )\n\n        try:\n            # Load the component\n            return loader.load(\n                component_model_path,\n                server_args,\n                component_name,\n                transformers_or_diffusers,\n            )\n        except Exception as e:\n            logger.error(\n                f\"Error while loading component: {component_name}, {component_model_path=}\"\n            )\n            raise e\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/component_loaders/image_encoder_loader.py",
    "content": "from sglang.multimodal_gen.configs.models import ModelConfig\nfrom sglang.multimodal_gen.runtime.loader.component_loaders.text_encoder_loader import (\n    TextEncoderLoader,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (\n    get_diffusers_component_config,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass ImageEncoderLoader(TextEncoderLoader):\n    component_names = [\"image_encoder\"]\n    expected_library = \"transformers\"\n\n    def should_offload(self, server_args, model_config: ModelConfig | None = None):\n        should_offload = server_args.image_encoder_cpu_offload\n        if not should_offload:\n            return False\n        # _fsdp_shard_conditions is in arch_config, not directly on model_config\n        arch_config = (\n            getattr(model_config, \"arch_config\", model_config) if model_config else None\n        )\n        fsdp_shard_conditions = (\n            getattr(arch_config, \"_fsdp_shard_conditions\", []) if arch_config else []\n        )\n        use_cpu_offload = should_offload and len(fsdp_shard_conditions) > 0\n        return use_cpu_offload\n\n    def load_customized(\n        self, component_model_path: str, server_args: ServerArgs, *args\n    ):\n        \"\"\"Load the text encoders based on the model path, and inference args.\"\"\"\n        # model_config: PretrainedConfig = get_hf_config(\n        #     model=model_path,\n        #     trust_remote_code=server_args.trust_remote_code,\n        #     revision=server_args.revision,\n        #     model_override_args=None,\n        # )\n        model_config = get_diffusers_component_config(\n            component_path=component_model_path\n        )\n\n        encoder_config = server_args.pipeline_config.image_encoder_config\n        encoder_config.update_model_arch(model_config)\n\n        # Always 
start with local device; load_model will adjust for offload if needed\n        # TODO(will): add support for other dtypes\n        return self.load_model(\n            component_model_path,\n            encoder_config,\n            server_args,\n            server_args.pipeline_config.image_encoder_precision,\n            cpu_offload_flag=server_args.image_encoder_cpu_offload,\n        )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/component_loaders/scheduler_loader.py",
    "content": "from sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import (\n    ComponentLoader,\n)\nfrom sglang.multimodal_gen.runtime.models.registry import ModelRegistry\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (\n    get_diffusers_component_config,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass SchedulerLoader(ComponentLoader):\n    \"\"\"Loader for scheduler.\"\"\"\n\n    component_names = [\"scheduler\"]\n    expected_library = \"diffusers\"\n\n    def load_customized(\n        self, component_model_path: str, server_args: ServerArgs, *args\n    ):\n        \"\"\"Load the scheduler based on the model path, and inference args.\"\"\"\n        config = get_diffusers_component_config(component_path=component_model_path)\n\n        class_name = config.pop(\"_class_name\")\n        assert (\n            class_name is not None\n        ), \"Model config does not contain a _class_name attribute. Only diffusers format is supported.\"\n\n        scheduler_cls, _ = ModelRegistry.resolve_model_cls(class_name)\n\n        scheduler = scheduler_cls(**config)\n        if server_args.pipeline_config.flow_shift is not None:\n            scheduler.set_shift(server_args.pipeline_config.flow_shift)\n\n        return scheduler\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py",
    "content": "import dataclasses\nimport glob\nimport os\nfrom collections.abc import Generator, Iterable\nfrom typing import Generator, Iterable, cast\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nfrom torch import nn\nfrom torch.distributed import init_device_mesh\nfrom transformers.utils import SAFE_WEIGHTS_INDEX_NAME\n\nfrom sglang.multimodal_gen.configs.models import EncoderConfig, ModelConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.qwen_image import (\n    QwenImageEditPipelineConfig,\n)\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import (\n    ComponentLoader,\n)\nfrom sglang.multimodal_gen.runtime.loader.fsdp_load import shard_model\nfrom sglang.multimodal_gen.runtime.loader.utils import (\n    set_default_torch_dtype,\n    skip_init_modules,\n)\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import (\n    filter_duplicate_safetensors_files,\n    filter_files_not_needed_for_inference,\n    pt_weights_iterator,\n    safetensors_weights_iterator,\n)\nfrom sglang.multimodal_gen.runtime.models.registry import ModelRegistry\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (\n    get_config,\n    get_diffusers_component_config,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\nfrom sglang.srt.environ import envs\n\nlogger = init_logger(__name__)\n\n\nclass TextEncoderLoader(ComponentLoader):\n    \"\"\"Loader for text encoders.\"\"\"\n\n    component_names = [\"text_encoder\"]\n    expected_library = \"transformers\"\n\n    @dataclasses.dataclass\n    class Source:\n        \"\"\"A source for weights.\"\"\"\n\n        model_or_path: str\n        \"\"\"The model ID 
or path.\"\"\"\n\n        prefix: str = \"\"\n        \"\"\"A prefix to prepend to all weights.\"\"\"\n\n        fall_back_to_pt: bool = True\n        \"\"\"Whether .pt weights can be used.\"\"\"\n\n        allow_patterns_overrides: list[str] | None = None\n        \"\"\"If defined, weights will load exclusively using these patterns.\"\"\"\n\n    def should_offload(self, server_args, model_config: ModelConfig | None = None):\n        should_offload = server_args.text_encoder_cpu_offload\n        if not should_offload:\n            return False\n        # _fsdp_shard_conditions is in arch_config, not directly on model_config\n        arch_config = (\n            getattr(model_config, \"arch_config\", model_config) if model_config else None\n        )\n        fsdp_shard_conditions = (\n            getattr(arch_config, \"_fsdp_shard_conditions\", []) if arch_config else []\n        )\n        use_cpu_offload = should_offload and len(fsdp_shard_conditions) > 0\n        return use_cpu_offload\n\n    def _prepare_weights(\n        self,\n        model_name_or_path: str,\n        fall_back_to_pt: bool,\n        allow_patterns_overrides: list[str] | None,\n    ) -> tuple[str, list[str], bool]:\n        \"\"\"Prepare weights for the model.\n\n        If the model is not local, it will be downloaded.\"\"\"\n        # model_name_or_path = (self._maybe_download_from_modelscope(\n        #     model_name_or_path, revision) or model_name_or_path)\n\n        is_local = os.path.isdir(model_name_or_path)\n        assert is_local, \"Model path must be a local directory\"\n\n        use_safetensors = False\n        index_file = SAFE_WEIGHTS_INDEX_NAME\n        allow_patterns = [\"*.safetensors\", \"*.bin\"]\n\n        if fall_back_to_pt:\n            allow_patterns += [\"*.pt\"]\n\n        if allow_patterns_overrides is not None:\n            allow_patterns = allow_patterns_overrides\n\n        hf_folder = model_name_or_path\n\n        hf_weights_files: list[str] = []\n        for 
pattern in allow_patterns:\n            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))\n            if len(hf_weights_files) > 0:\n                if pattern == \"*.safetensors\":\n                    use_safetensors = True\n                break\n\n        if use_safetensors:\n            hf_weights_files = filter_duplicate_safetensors_files(\n                hf_weights_files, hf_folder, index_file\n            )\n        else:\n            hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)\n\n        if len(hf_weights_files) == 0:\n            raise RuntimeError(\n                f\"Cannot find any model weights with `{model_name_or_path}`\"\n            )\n\n        if envs.SGLANG_SORT_WEIGHT_FILES.get():\n            hf_weights_files.sort()\n\n        return hf_folder, hf_weights_files, use_safetensors\n\n    def _get_weights_iterator(\n        self, source: \"Source\", to_cpu: bool\n    ) -> Generator[tuple[str, torch.Tensor], None, None]:\n        \"\"\"get an iterator for the model weights based on the load format.\"\"\"\n        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(\n            source.model_or_path,\n            source.fall_back_to_pt,\n            source.allow_patterns_overrides,\n        )\n        if use_safetensors:\n            weights_iterator = safetensors_weights_iterator(\n                hf_weights_files, to_cpu=to_cpu\n            )\n        else:\n            weights_iterator = pt_weights_iterator(hf_weights_files, to_cpu=to_cpu)\n\n        # apply the prefix.\n        return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)\n\n    def _get_all_weights(\n        self,\n        model: nn.Module,\n        model_path: str,\n        to_cpu: bool,\n    ) -> Generator[tuple[str, torch.Tensor], None, None]:\n        primary_weights = TextEncoderLoader.Source(\n            model_path,\n            prefix=\"\",\n            fall_back_to_pt=getattr(model, 
\"fall_back_to_pt_during_load\", True),\n            allow_patterns_overrides=getattr(model, \"allow_patterns_overrides\", None),\n        )\n        yield from self._get_weights_iterator(primary_weights, to_cpu)\n\n        secondary_weights = cast(\n            Iterable[TextEncoderLoader.Source],\n            getattr(model, \"secondary_weights\", ()),\n        )\n        for source in secondary_weights:\n            yield from self._get_weights_iterator(source, to_cpu)\n\n    def load_customized(\n        self, component_model_path: str, server_args: ServerArgs, component_name: str\n    ):\n        \"\"\"Load the text encoders based on the model path, and inference args.\"\"\"\n        diffusers_pretrained_config = get_config(\n            component_model_path, trust_remote_code=True\n        )\n        model_config = get_diffusers_component_config(\n            component_path=component_model_path\n        )\n\n        def is_not_first_encoder(module_name):\n            return \"2\" in module_name\n\n        # TODO(mick): had to throw an exception for different text-encoder arch\n        if not is_not_first_encoder(component_name):\n            encoder_config = server_args.pipeline_config.text_encoder_configs[0]\n            encoder_config.update_model_arch(model_config)\n            for key, value in diffusers_pretrained_config.__dict__.items():\n                setattr(encoder_config.arch_config, key, value)\n            encoder_dtype = server_args.pipeline_config.text_encoder_precisions[0]\n        else:\n            assert len(server_args.pipeline_config.text_encoder_configs) == 2\n            encoder_config = server_args.pipeline_config.text_encoder_configs[1]\n            encoder_config.update_model_arch(model_config)\n            encoder_dtype = server_args.pipeline_config.text_encoder_precisions[1]\n        # TODO(will): add support for other dtypes\n        return self.load_model(\n            component_model_path,\n            encoder_config,\n           
 server_args,\n            encoder_dtype,\n        )\n\n    def load_model(\n        self,\n        model_path: str,\n        model_config: EncoderConfig,\n        server_args: ServerArgs,\n        dtype: str = \"fp16\",\n        cpu_offload_flag: bool | None = None,\n    ):\n        # Determine CPU offload behavior and target device\n\n        local_torch_device = get_local_torch_device()\n        should_offload = self.should_offload(server_args, model_config)\n\n        if should_offload and not current_platform.is_mps():\n            model_device = torch.device(\"cpu\")\n        else:\n            model_device = local_torch_device\n\n        with set_default_torch_dtype(PRECISION_TO_TYPE[dtype]):\n            with model_device, skip_init_modules():\n                architectures = getattr(model_config, \"architectures\", [])\n                model_cls, _ = ModelRegistry.resolve_model_cls(architectures)\n                enable_image_understanding = (\n                    True\n                    if isinstance(\n                        server_args.pipeline_config, QwenImageEditPipelineConfig\n                    )\n                    else False\n                )\n                model_config.enable_image_understanding = enable_image_understanding\n                model = model_cls(model_config)\n\n            weights_to_load = {name for name, _ in model.named_parameters()}\n            loaded_weights = model.load_weights(\n                self._get_all_weights(model, model_path, to_cpu=should_offload)\n            )\n\n            # Explicitly move model to target device after loading weights\n            if not should_offload:\n                model = model.to(local_torch_device)\n\n            if should_offload:\n                # Disable FSDP for MPS as it's not compatible\n                if current_platform.is_mps():\n                    logger.info(\n                        \"Disabling FSDP sharding for MPS platform as it's not compatible\"\n              
      )\n                    model = model.to(local_torch_device)\n                else:\n                    mesh = init_device_mesh(\n                        current_platform.device_type,\n                        mesh_shape=(1, dist.get_world_size()),\n                        mesh_dim_names=(\"offload\", \"replicate\"),\n                    )\n                    shard_model(\n                        model,\n                        cpu_offload=True,\n                        reshard_after_forward=True,\n                        mesh=mesh[\"offload\"],\n                        fsdp_shard_conditions=model_config.arch_config._fsdp_shard_conditions\n                        or getattr(model, \"_fsdp_shard_conditions\", None),\n                        pin_cpu_memory=server_args.pin_cpu_memory,\n                    )\n            else:\n                model = model.to(local_torch_device)\n            # We only enable strict check for non-quantized models\n            # that have loaded weights tracking currently.\n            # if loaded_weights is not None:\n            weights_not_loaded = weights_to_load - loaded_weights\n            if weights_not_loaded:\n                # NOTE:\n                # If we silently continue with uninitialized weights, the text encoder can\n                # produce NaNs/garbage embeddings that later fail stage verification in a\n                # hard-to-debug way (e.g., `prompt_embeds` fails the NaN check).\n                #\n                # We allow a small set of known-optional parameters to be missing, but\n                # default to strict behavior for the rest.\n                allowed_missing_patterns = (\n                    getattr(model, \"_allowed_missing_weights_patterns\", []) or []\n                )\n                unexpected_missing = {\n                    n\n                    for n in weights_not_loaded\n                    if not any(pat in n for pat in allowed_missing_patterns)\n                }\n           
     if unexpected_missing:\n                    raise ValueError(\n                        \"Following text encoder weights were not initialized from checkpoint: \"\n                        f\"{sorted(unexpected_missing)}. \"\n                        \"This usually indicates a checkpoint/model-arch mismatch or a broken \"\n                        \"weight-name mapping. If these are truly optional, set \"\n                        \"`model._allowed_missing_weights_patterns` to whitelist patterns.\"\n                    )\n                logger.warning(\n                    \"Following (allowed) text encoder weights were not initialized from \"\n                    \"checkpoint: %s (allowed patterns: %s)\",\n                    sorted(weights_not_loaded),\n                    allowed_missing_patterns,\n                )\n\n        return model\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py",
    "content": "import json\nimport logging\nimport os\nfrom typing import Any, Dict, List, Optional\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import (\n    NunchakuConfig,\n    _patch_nunchaku_scales,\n)\nfrom sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import (\n    ComponentLoader,\n)\nfrom sglang.multimodal_gen.runtime.loader.fsdp_load import maybe_load_fsdp_model\nfrom sglang.multimodal_gen.runtime.loader.utils import (\n    _list_safetensors_files,\n    _normalize_component_type,\n)\nfrom sglang.multimodal_gen.runtime.models.registry import ModelRegistry\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (\n    get_diffusers_component_config,\n    maybe_download_model,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import get_log_level, init_logger\nfrom sglang.multimodal_gen.runtime.utils.quantization_utils import (\n    get_metadata_from_safetensors_file,\n    get_quant_config,\n    get_quant_config_from_safetensors_metadata,\n)\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\nfrom sglang.srt.utils import is_npu\n\n_is_npu = is_npu()\n\nlogger = init_logger(__name__)\n\n\nclass TransformerLoader(ComponentLoader):\n    \"\"\"Shared loader for (video/audio) DiT transformers.\"\"\"\n\n    component_names = [\"transformer\", \"audio_dit\", \"video_dit\"]\n    expected_library = \"diffusers\"\n\n    def get_list_of_safetensors_to_load(\n        self, server_args: ServerArgs, component_model_path: str\n    ) -> list[str]:\n        \"\"\"\n        get list of safetensors to load.\n\n        If --transformer-weights-path is provided, load weights from that path\n        instead of the base model's component directory.\n        \"\"\"\n        quantized_path = server_args.transformer_weights_path\n\n    
    if quantized_path:\n            quantized_path = maybe_download_model(quantized_path)\n            logger.info(\"using quantized transformer weights from: %s\", quantized_path)\n            if os.path.isfile(quantized_path) and quantized_path.endswith(\n                \".safetensors\"\n            ):\n                safetensors_list = [quantized_path]\n            else:\n                safetensors_list = _list_safetensors_files(quantized_path)\n        else:\n            safetensors_list = _list_safetensors_files(component_model_path)\n\n        if not safetensors_list:\n            raise ValueError(\n                f\"no safetensors files found in \"\n                f\"{quantized_path or component_model_path}\"\n            )\n\n        return safetensors_list\n\n    def _resolve_quant_config(\n        self,\n        hf_config: Dict[str, List[str]],\n        server_args: ServerArgs,\n        safetensors_list: list[str],\n        component_model_path: str,\n    ) -> Optional[dict]:\n        # priority: model config.json → safetensors metadata → nunchaku config\n        quant_config = get_quant_config(hf_config, component_model_path)\n        if quant_config is None and server_args.transformer_weights_path:\n            # try to read quantization_config from the safetensors metadata header\n            for safetensors_file in safetensors_list:\n                quant_config = get_quant_config_from_safetensors_metadata(\n                    safetensors_file\n                )\n                if quant_config:\n                    break\n        return quant_config\n\n    def _resolve_target_param_dtype(\n        self,\n        quant_config: Optional[dict],\n        nunchaku_config: Optional[NunchakuConfig],\n        model_cls,\n        server_args: ServerArgs,\n    ) -> Optional[torch.dtype]:\n        if quant_config is not None or nunchaku_config is not None:\n            # TODO: improve the condition\n            # respect dtype from checkpoint\n            
param_dtype = None\n        else:\n            param_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision]\n\n        if nunchaku_config is not None:\n            nunchaku_config.model_cls = model_cls\n            # verify that the nunchaku checkpoint matches the selected model class\n            original_dit_cls_name = json.loads(\n                get_metadata_from_safetensors_file(\n                    nunchaku_config.transformer_weights_path\n                ).get(\"config\")\n            )[\"_class_name\"]\n            specified_dit_cls_name = str(model_cls.__name__)\n            if original_dit_cls_name != specified_dit_cls_name:\n                raise Exception(\n                    f\"Class name of DiT specified in nunchaku transformer_weights_path: {original_dit_cls_name} does not match that of specified DiT name: {specified_dit_cls_name}\"\n                )\n\n        return param_dtype\n\n    def load_customized(\n        self, component_model_path: str, server_args: ServerArgs, component_name: str\n    ):\n        \"\"\"Load the transformer based on the model path, and inference args.\"\"\"\n        # 1. hf config\n        config = get_diffusers_component_config(component_path=component_model_path)\n\n        # 2. quant config\n        safetensors_list = self.get_list_of_safetensors_to_load(\n            server_args, component_model_path\n        )\n\n        quant_config = self._resolve_quant_config(\n            config, server_args, safetensors_list, component_model_path\n        )\n\n        # 3. 
dit config\n        # Config from Diffusers supersedes sgl_diffusion's model config\n        component_name = _normalize_component_type(component_name)\n        server_args.model_paths[component_name] = component_model_path\n        if component_name in (\"transformer\", \"video_dit\"):\n            pipeline_dit_config_attr = \"dit_config\"\n        elif component_name in (\"audio_dit\",):\n            pipeline_dit_config_attr = \"audio_dit_config\"\n        else:\n            raise ValueError(f\"Invalid module name: {component_name}\")\n        dit_config = getattr(server_args.pipeline_config, pipeline_dit_config_attr)\n        dit_config.update_model_arch(config)\n\n        cls_name = config.pop(\"_class_name\")\n        model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)\n\n        nunchaku_config = server_args.nunchaku_config\n        param_dtype = self._resolve_target_param_dtype(\n            quant_config, nunchaku_config, model_cls, server_args\n        )\n\n        logger.info(\n            \"Loading %s from %s safetensors file(s) %s, param_dtype: %s\",\n            cls_name,\n            len(safetensors_list),\n            f\": {safetensors_list}\" if get_log_level() == logging.DEBUG else \"\",\n            param_dtype,\n        )\n\n        # prepare init_param\n        init_params: dict[str, Any] = {\n            \"config\": dit_config,\n            \"hf_config\": config,\n            \"quant_config\": (quant_config if quant_config else nunchaku_config),\n        }\n        if (\n            init_params[\"quant_config\"] is None\n            and server_args.transformer_weights_path is not None\n        ):\n            logger.warning(\n                f\"transformer_weights_path provided, but quantization config not resolved, which is unexpected and likely to cause errors\"\n            )\n        else:\n            logger.debug(\"quantization config: %s\", init_params[\"quant_config\"])\n\n        # Load the model using FSDP loader\n        model = 
maybe_load_fsdp_model(\n            model_cls=model_cls,\n            init_params=init_params,\n            weight_dir_list=safetensors_list,\n            device=get_local_torch_device(),\n            hsdp_replicate_dim=server_args.hsdp_replicate_dim,\n            hsdp_shard_dim=server_args.hsdp_shard_dim,\n            cpu_offload=server_args.dit_cpu_offload,\n            pin_cpu_memory=server_args.pin_cpu_memory,\n            fsdp_inference=server_args.use_fsdp_inference,\n            # TODO(will): make these configurable\n            param_dtype=param_dtype,\n            reduce_dtype=torch.float32,\n            output_dtype=None,\n            strict=False,\n        )\n\n        if nunchaku_config is not None:\n            _patch_nunchaku_scales(model, safetensors_list)\n\n        total_params = sum(p.numel() for p in model.parameters())\n        logger.info(\"Loaded model with %.2fB parameters\", total_params / 1e9)\n\n        # account for the existence of mixed-precision models (e.g., nunchaku)\n        if next(model.parameters()).dtype != param_dtype and param_dtype:\n            logger.warning(\n                f\"Model dtype does not match expected param dtype, {next(model.parameters()).dtype} vs {param_dtype}\"\n            )\n\n        return model\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/component_loaders/vae_loader.py",
    "content": "import importlib.util\nimport os\n\nimport torch\nimport torch.nn as nn\nfrom safetensors.torch import load_file as safetensors_load_file\n\nfrom sglang.multimodal_gen import envs\nfrom sglang.multimodal_gen.configs.models import ModelConfig\nfrom sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import (\n    ComponentLoader,\n)\nfrom sglang.multimodal_gen.runtime.loader.utils import (\n    _list_safetensors_files,\n    set_default_torch_dtype,\n    skip_init_modules,\n)\nfrom sglang.multimodal_gen.runtime.models.registry import ModelRegistry\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (\n    get_diffusers_component_config,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\n\nlogger = init_logger(__name__)\n\n\ndef _convert_conv3d_weights_to_channels_last_3d(module: nn.Module) -> int:\n    \"\"\"\n    Convert Conv3d weights to channels_last_3d (NDHWC) memory format.\n    Returns the number of Conv3d modules converted.\n    \"\"\"\n    if not hasattr(torch, \"channels_last_3d\"):\n        return 0\n    num_converted = 0\n    for m in module.modules():\n        if isinstance(m, nn.Conv3d):\n            try:\n                m.weight.data = m.weight.data.to(memory_format=torch.channels_last_3d)\n                num_converted += 1\n            except Exception:\n                # Best-effort; skip unsupported cases.\n                continue\n    return num_converted\n\n\nclass VAELoader(ComponentLoader):\n    \"\"\"Shared loader for (video/audio) VAE modules.\"\"\"\n\n    component_names = [\"vae\", \"audio_vae\", \"video_vae\"]\n    expected_library = \"diffusers\"\n\n    def should_offload(\n        self, server_args: ServerArgs, model_config: ModelConfig | None = None\n    ):\n        
return server_args.vae_cpu_offload\n\n    def load_customized(\n        self, component_model_path: str, server_args: ServerArgs, component_name: str\n    ):\n        \"\"\"Load the VAE based on the model path, and inference args.\"\"\"\n        config = get_diffusers_component_config(component_path=component_model_path)\n        class_name = config.pop(\"_class_name\", None)\n        assert (\n            class_name is not None\n        ), \"Model config does not contain a _class_name attribute. Only diffusers format is supported.\"\n\n        server_args.model_paths[component_name] = component_model_path\n\n        if component_name in (\"vae\", \"video_vae\"):\n            pipeline_vae_config_attr = \"vae_config\"\n            pipeline_vae_precision = \"vae_precision\"\n        elif component_name in (\"audio_vae\",):\n            pipeline_vae_config_attr = \"audio_vae_config\"\n            pipeline_vae_precision = \"audio_vae_precision\"\n        else:\n            raise ValueError(\n                f\"Unsupported module name for VAE loader: {component_name}\"\n            )\n        vae_config = getattr(server_args.pipeline_config, pipeline_vae_config_attr)\n        vae_precision = getattr(server_args.pipeline_config, pipeline_vae_precision)\n        vae_config.update_model_arch(config)\n        if hasattr(vae_config, \"post_init\"):\n            # NOTE: some post init logics are only available after updated with config\n            vae_config.post_init()\n\n        should_offload = self.should_offload(server_args)\n        target_device = self.target_device(should_offload)\n\n        # Check for auto_map first (custom VAE classes)\n        auto_map = config.get(\"auto_map\", {})\n        auto_model_map = auto_map.get(\"AutoModel\")\n        if auto_model_map:\n            module_path, cls_name = auto_model_map.rsplit(\".\", 1)\n            custom_module_file = os.path.join(component_model_path, f\"{module_path}.py\")\n            spec = 
importlib.util.spec_from_file_location(\"_custom\", custom_module_file)\n            custom_module = importlib.util.module_from_spec(spec)\n            spec.loader.exec_module(custom_module)\n            vae_cls = getattr(custom_module, cls_name)\n            vae_dtype = PRECISION_TO_TYPE[vae_precision]\n            with set_default_torch_dtype(vae_dtype):\n                vae = vae_cls.from_pretrained(\n                    component_model_path,\n                    revision=server_args.revision,\n                    trust_remote_code=server_args.trust_remote_code,\n                )\n            vae = vae.to(device=target_device, dtype=vae_dtype)\n            if (\n                component_name in (\"vae\", \"video_vae\")\n                and torch.cuda.is_available()\n                and getattr(envs, \"SGLANG_DIFFUSION_VAE_CHANNELS_LAST_3D\", False)\n            ):\n                n = _convert_conv3d_weights_to_channels_last_3d(vae)\n                if n > 0:\n                    logger.info(\n                        \"VAE: converted %d Conv3d weights to channels_last_3d\", n\n                    )\n            vae = current_platform.optimize_vae(vae)\n            return vae\n\n        # Load from ModelRegistry (standard VAE classes)\n        with (\n            set_default_torch_dtype(PRECISION_TO_TYPE[vae_precision]),\n            skip_init_modules(),\n        ):\n            vae_cls, _ = ModelRegistry.resolve_model_cls(class_name)\n            vae = vae_cls(vae_config).to(target_device)\n\n        safetensors_list = _list_safetensors_files(component_model_path)\n        assert (\n            len(safetensors_list) >= 1\n        ), f\"Found no safetensors files in {component_model_path}\"\n        loaded = {}\n        for sf_path in safetensors_list:\n            loaded.update(safetensors_load_file(sf_path))\n        vae.load_state_dict(loaded, strict=False)\n\n        state_keys = set(vae.state_dict().keys())\n        loaded_keys = set(loaded.keys())\n       
 missing_keys = sorted(state_keys - loaded_keys)\n        unexpected_keys = sorted(loaded_keys - state_keys)\n        if missing_keys:\n            logger.warning(\"VAE missing keys: %s\", missing_keys)\n        if unexpected_keys:\n            logger.warning(\"VAE unexpected keys: %s\", unexpected_keys)\n\n        if (\n            component_name in (\"vae\", \"video_vae\")\n            and torch.cuda.is_available()\n            and getattr(envs, \"SGLANG_DIFFUSION_VAE_CHANNELS_LAST_3D\", False)\n        ):\n            n = _convert_conv3d_weights_to_channels_last_3d(vae)\n            if n > 0:\n                logger.info(\"VAE: converted %d Conv3d weights to channels_last_3d\", n)\n\n        vae = current_platform.optimize_vae(vae)\n        return vae\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/component_loaders/vl_encoder_loader.py",
    "content": "from typing import Any\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import (\n    ComponentLoader,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import get_hf_config\n\n\nclass VisionLanguageEncoderLoader(ComponentLoader):\n    \"\"\"Loader for vision language encoder (typically Causal LM or Vision2Seq).\"\"\"\n\n    component_names = [\"vision_language_encoder\"]\n    expected_library = \"transformers\"\n\n    def load_customized(\n        self,\n        component_model_path: str,\n        server_args: ServerArgs,\n        transformers_or_diffusers: str = \"vision_language_encoder\",\n    ) -> Any:\n        if transformers_or_diffusers == \"vision_language_encoder\":\n            from transformers import GlmImageForConditionalGeneration\n\n            config = get_hf_config(\n                component_model_path,\n                trust_remote_code=server_args.trust_remote_code,\n                revision=server_args.revision,\n            )\n            model = GlmImageForConditionalGeneration.from_pretrained(\n                component_model_path,\n                config=config,\n                trust_remote_code=server_args.trust_remote_code,\n                revision=server_args.revision,\n            ).to(get_local_torch_device())\n            return model\n        else:\n            raise ValueError(\n                f\"Unsupported library for VisionLanguageEncoder: {transformers_or_diffusers}\"\n            )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/component_loaders/vocoder_loader.py",
    "content": "from safetensors.torch import load_file as safetensors_load_file\n\nfrom sglang.multimodal_gen.configs.models import ModelConfig\nfrom sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import (\n    ComponentLoader,\n)\nfrom sglang.multimodal_gen.runtime.loader.utils import (\n    _list_safetensors_files,\n    set_default_torch_dtype,\n    skip_init_modules,\n)\nfrom sglang.multimodal_gen.runtime.models.registry import ModelRegistry\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (\n    get_diffusers_component_config,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\n\nlogger = init_logger(__name__)\n\n\nclass VocoderLoader(ComponentLoader):\n    component_names = [\"vocoder\"]\n    expected_library = \"diffusers\"\n\n    def should_offload(\n        self, server_args: ServerArgs, model_config: ModelConfig | None = None\n    ):\n        return server_args.vae_cpu_offload\n\n    def load_customized(\n        self, component_model_path: str, server_args: ServerArgs, component_name: str\n    ):\n        config = get_diffusers_component_config(component_path=component_model_path)\n        class_name = config.pop(\"_class_name\", None)\n        assert (\n            class_name is not None\n        ), \"Model config does not contain a _class_name attribute. 
Only diffusers format is supported.\"\n\n        server_args.model_paths[component_name] = component_model_path\n\n        from sglang.multimodal_gen.configs.models.vocoder.ltx_vocoder import (\n            LTXVocoderConfig,\n        )\n\n        vocoder_config = LTXVocoderConfig()\n        vocoder_config.update_model_arch(config)\n\n        try:\n            vocoder_precision = server_args.pipeline_config.audio_vae_precision\n        except AttributeError:\n            vocoder_precision = \"fp32\"\n        vocoder_dtype = PRECISION_TO_TYPE[vocoder_precision]\n\n        should_offload = self.should_offload(server_args)\n        target_device = self.target_device(should_offload)\n\n        with set_default_torch_dtype(vocoder_dtype), skip_init_modules():\n            vocoder_cls, _ = ModelRegistry.resolve_model_cls(class_name)\n            vocoder = vocoder_cls(vocoder_config).to(target_device)\n\n        safetensors_list = _list_safetensors_files(component_model_path)\n        assert (\n            len(safetensors_list) == 1\n        ), f\"Found {len(safetensors_list)} safetensors files in {component_model_path}\"\n        loaded = safetensors_load_file(safetensors_list[0])\n        incompatible = vocoder.load_state_dict(loaded, strict=False)\n        missing_keys = []\n        unexpected_keys = []\n        try:\n            missing_keys = incompatible.missing_keys\n            unexpected_keys = incompatible.unexpected_keys\n        except AttributeError:\n            # Best-effort fallback in case older torch returns a tuple-like.\n            try:\n                missing_keys = incompatible[0]\n                unexpected_keys = incompatible[1]\n            except Exception:\n                pass\n\n        if missing_keys or unexpected_keys:\n            logger.warning(\n                \"Loaded vocoder with missing_keys=%d unexpected_keys=%d\",\n                len(missing_keys),\n                len(unexpected_keys),\n            )\n        return vocoder\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/fsdp_load.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\n# Adapted from torchtune\n# Copyright 2024 The TorchTune Authors.\n# Copyright 2025 The sglang-diffusion Authors.\n\nfrom collections.abc import Callable, Generator\nfrom itertools import chain\nfrom typing import Any\n\nimport torch\nfrom torch import nn\nfrom torch.distributed import DeviceMesh, init_device_mesh\nfrom torch.distributed._tensor import distribute_tensor\nfrom torch.distributed.fsdp import (\n    CPUOffloadPolicy,\n    FSDPModule,\n    MixedPrecisionPolicy,\n    fully_shard,\n)\nfrom torch.nn.modules.module import _IncompatibleKeys\n\nfrom sglang.multimodal_gen.runtime.layers.linear import UnquantizedLinearMethod\nfrom sglang.multimodal_gen.runtime.loader.utils import (\n    get_param_names_mapping,\n    hf_to_custom_state_dict,\n    set_default_torch_dtype,\n)\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import (\n    safetensors_weights_iterator,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import set_mixed_precision_policy\nfrom sglang.srt.utils import is_npu\n\n_is_npu = is_npu()\n\nlogger = init_logger(__name__)\n\n\ndef _make_param_like(\n    actual_param: torch.nn.Parameter, tensor: torch.Tensor\n) -> torch.nn.Parameter:\n    cls = actual_param.__class__\n    # nn.Parameter defaults to requires_grad=True, which is illegal for non-floating/complex dtypes (e.g., int8/FP8\n    # quantized weights).\n    try:\n        new_param = cls.__new__(cls, tensor, requires_grad=False)\n    except TypeError:\n        new_param = cls.__new__(cls, tensor)\n    new_param.__dict__.update(actual_param.__dict__)\n    new_param.requires_grad = False\n    return new_param\n\n\n# TODO(PY): add compile option\ndef maybe_load_fsdp_model(\n    model_cls: type[nn.Module],\n    init_params: dict[str, 
Any],\n    weight_dir_list: list[str],\n    device: torch.device,\n    hsdp_replicate_dim: int,\n    hsdp_shard_dim: int,\n    param_dtype: torch.dtype,\n    reduce_dtype: torch.dtype,\n    cpu_offload: bool = False,\n    fsdp_inference: bool = False,\n    output_dtype: torch.dtype | None = None,\n    pin_cpu_memory: bool = True,\n    strict: bool = True,\n) -> torch.nn.Module:\n    \"\"\"Load a model with optional FSDP (Fully Sharded Data Parallel) support.\n\n    Args:\n        param_dtype: Data type for model parameters, also used for:\n            - Model initialization context (set_default_torch_dtype)\n            - FSDP mixed precision policy\n            - Weight loading and casting\n        reduce_dtype: Data type for gradient reduction in FSDP mixed precision.\n        strict: If True, enforce strict state dict loading (all keys must match).\n    \"\"\"\n    # NOTE(will): cast_forward_inputs=True shouldn't be needed as we are\n    # manually casting the inputs to the model\n    default_torch_dtype = param_dtype if param_dtype else torch.bfloat16\n    mp_policy = MixedPrecisionPolicy(\n        default_torch_dtype, reduce_dtype, output_dtype, cast_forward_inputs=False\n    )\n\n    set_mixed_precision_policy(\n        param_dtype=default_torch_dtype,\n        reduce_dtype=reduce_dtype,\n        output_dtype=output_dtype,\n        mp_policy=mp_policy,\n    )\n\n    with set_default_torch_dtype(default_torch_dtype), torch.device(\"meta\"):\n        model = model_cls(**init_params)\n\n    # Check if we should use FSDP\n    use_fsdp = fsdp_inference\n\n    # Disable FSDP for MPS as it's not compatible\n    if current_platform.is_mps():\n        use_fsdp = False\n        logger.info(\"Disabling FSDP for MPS platform as it's not compatible\")\n\n    if use_fsdp:\n        world_size = hsdp_replicate_dim * hsdp_shard_dim\n        if not fsdp_inference:\n            hsdp_replicate_dim = world_size\n            hsdp_shard_dim = 1\n\n        device_mesh = 
init_device_mesh(\n            current_platform.device_type,\n            # (Replicate(), Shard(dim=0))\n            mesh_shape=(hsdp_replicate_dim, hsdp_shard_dim),\n            mesh_dim_names=(\"replicate\", \"shard\"),\n        )\n        shard_model(\n            model,\n            cpu_offload=cpu_offload,\n            reshard_after_forward=True,\n            mp_policy=mp_policy,\n            mesh=device_mesh,\n            fsdp_shard_conditions=model._fsdp_shard_conditions,\n            pin_cpu_memory=pin_cpu_memory,\n        )\n\n    weight_iterator = safetensors_weights_iterator(weight_dir_list)\n    param_names_mapping_fn = get_param_names_mapping(model.param_names_mapping)\n    load_model_from_full_model_state_dict(\n        model,\n        weight_iterator,\n        device,\n        param_dtype,\n        strict=strict,\n        cpu_offload=cpu_offload,\n        param_names_mapping=param_names_mapping_fn,\n    )\n\n    for _, module in model.named_modules():\n        quant_method = getattr(module, \"quant_method\", None)\n        if quant_method is not None and hasattr(\n            quant_method, \"process_weights_after_loading\"\n        ):\n            if _is_npu and not isinstance(quant_method, UnquantizedLinearMethod):\n                # Activate the NZ format for storing weights,\n                # which is a specific optimization for Ascend NPU\n                torch.npu.config.allow_internal_format = True\n            quant_method.process_weights_after_loading(module)\n            if _is_npu:\n                torch.npu.empty_cache()\n\n    for n, p in chain(model.named_parameters(), model.named_buffers()):\n        if p.is_meta:\n            raise RuntimeError(f\"Unexpected param or buffer {n} on meta device.\")\n        # Avoid unintended computation graph accumulation during inference\n        if isinstance(p, torch.nn.Parameter):\n            p.requires_grad = False\n    return model\n\n\ndef shard_model(\n    model,\n    *,\n    cpu_offload: 
bool,\n    reshard_after_forward: bool = True,\n    mp_policy: MixedPrecisionPolicy | None = MixedPrecisionPolicy(),  # noqa\n    mesh: DeviceMesh | None = None,\n    fsdp_shard_conditions: list[Callable[[str, nn.Module], bool]] = [],  # noqa\n    pin_cpu_memory: bool = True,\n) -> None:\n    \"\"\"\n    Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.\n\n    This method will over the model's named modules from the bottom-up and apply shard modules\n    based on whether they meet any of the criteria from shard_conditions.\n\n    Args:\n        model (TransformerDecoder): Model to shard with FSDP.\n        cpu_offload (bool): If set to True, FSDP will offload parameters, gradients, and optimizer\n            states to CPU.\n        reshard_after_forward (bool): Whether to reshard parameters and buffers after\n            the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy\n            from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.\n        mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding under multiple parallelism.\n            Default to None.\n        fsdp_shard_conditions (List[Callable[[str, nn.Module], bool]]): A list of functions to determine\n            which modules to shard with FSDP.\n        pin_cpu_memory (bool): If set to True, FSDP will pin the CPU memory of the offloaded parameters.\n\n    \"\"\"\n    if fsdp_shard_conditions is None or len(fsdp_shard_conditions) == 0:\n        logger.warning(\n            \"The FSDP shard condition list is empty or None. 
No modules will be sharded in %s\",\n            type(model).__name__,\n        )\n        return\n\n    fsdp_kwargs = {\n        \"reshard_after_forward\": reshard_after_forward,\n        \"mesh\": mesh,\n        \"mp_policy\": mp_policy,\n    }\n    if cpu_offload:\n        fsdp_kwargs[\"offload_policy\"] = CPUOffloadPolicy(pin_memory=pin_cpu_memory)\n\n    # iterating in reverse to start with\n    # lowest-level modules first\n    num_layers_sharded = 0\n    # TODO(will): don't reshard after forward for the last layer to save on the\n    # all-gather that will immediately happen Shard the model with FSDP,\n    for n, m in reversed(list(model.named_modules())):\n        if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):  # type: ignore\n            fully_shard(m, **fsdp_kwargs)\n            num_layers_sharded += 1\n\n    if num_layers_sharded == 0:\n        raise ValueError(\n            \"No layer modules were sharded. Please check if shard conditions are working as expected.\"\n        )\n\n    # Finally shard the entire model to account for any stragglers\n    fully_shard(model, **fsdp_kwargs)\n\n\n# TODO(PY): device mesh for cfg parallel\ndef load_model_from_full_model_state_dict(\n    model: FSDPModule | torch.nn.Module,\n    full_sd_iterator: Generator[tuple[str, torch.Tensor], None, None],\n    device: torch.device,\n    param_dtype: torch.dtype | None,\n    strict: bool = False,\n    cpu_offload: bool = False,\n    param_names_mapping: Callable[[str], tuple[str, Any, Any]] | None = None,\n) -> _IncompatibleKeys:\n    \"\"\"\n    Converting full state dict into a sharded state dict\n    and loading it into FSDP model (if training) or normal huggingface model\n    Args:\n        model (Union[FSDPModule, torch.nn.Module]): Model to generate fully qualified names for cpu_state_dict\n        full_sd_iterator (Generator): an iterator yielding (param_name, tensor) pairs\n        device (torch.device): device used to move full state 
dict tensors\n        param_dtype (torch.dtype): dtype used to move full state dict tensors. If none, respect original dtype from checkpoint\n        strict (bool): flag to check if to load the model in strict mode\n        cpu_offload (bool): flag to check if FSDP offload is enabled\n        param_names_mapping (Optional[Callable[[str], str]]): a function that maps full param name to sharded param name\n    Returns:\n        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:\n            * **missing_keys** is a list of str containing the missing keys\n            * **unexpected_keys** is a list of str containing the unexpected keys\n\n    \"\"\"\n    meta_sd = model.state_dict()\n    param_dict = dict(model.named_parameters())\n\n    # map names from checkpoint to customized names\n    custom_param_sd, reverse_param_names_mapping = hf_to_custom_state_dict(\n        full_sd_iterator, param_names_mapping\n    )  # type: ignore\n\n    is_fsdp_model = isinstance(model, FSDPModule) or any(\n        hasattr(p, \"device_mesh\") for p in meta_sd.values()\n    )\n\n    # sort parameter names to ensure all ranks process parameters in the same order\n    sorted_param_names = sorted(custom_param_sd.keys())\n\n    sharded_sd = {}\n    skipped_checkpoint_keys: list[str] = []\n\n    # shard from loaded state_dict, custom_param_sd -> sharded_sd\n    for target_param_name in sorted_param_names:\n        full_tensor = custom_param_sd[target_param_name]\n        meta_sharded_param = meta_sd.get(target_param_name)\n\n        if meta_sharded_param is None:\n            # For FSDP models, ensure all ranks process parameters consistently\n            if strict or is_fsdp_model:\n                raise ValueError(\n                    f\"Parameter {target_param_name} not found in custom model state dict. 
The hf to custom mapping may be incorrect.\"\n                )\n            else:\n                skipped_checkpoint_keys.append(target_param_name)\n                continue\n\n        # use meta param dtype so quantized params (e.g. FP8) keep their dtype;\n        # for non-quantized models meta dtype equals param_dtype anyway\n        if meta_sharded_param is None:\n            # for nunchaku, some scales are patched later\n            target_dtype = full_tensor.dtype\n        else:\n            target_dtype = meta_sharded_param.dtype\n\n        if not hasattr(meta_sharded_param, \"device_mesh\"):\n            full_tensor = full_tensor.to(device=device, dtype=target_dtype)\n            actual_param = param_dict.get(target_param_name)\n            weight_loader = (\n                getattr(actual_param, \"weight_loader\", None)\n                if actual_param is not None\n                else None\n            )\n            if weight_loader is not None:\n                assert actual_param is not None\n                sharded_tensor = torch.empty_like(\n                    meta_sharded_param, device=device, dtype=target_dtype\n                )\n                # Preserve requires_grad flag to avoid errors with non-floating dtypes\n                requires_grad = getattr(meta_sharded_param, \"requires_grad\", False)\n                temp_param = _make_param_like(actual_param, sharded_tensor)\n                if not (\n                    sharded_tensor.is_floating_point() or sharded_tensor.is_complex()\n                ):\n                    requires_grad = False\n                temp_param.requires_grad = requires_grad\n                weight_loader(temp_param, full_tensor)\n                sharded_tensor = temp_param.data\n            else:\n                # In cases where parts of the model aren't sharded, some parameters will be plain tensors\n                sharded_tensor = full_tensor\n\n            # Important: `cpu_offload` is intended for 
FSDP-managed parameter movement.\n            # If a parameter is not sharded into a DTensor (i.e., no `device_mesh`), FSDP\n            # will NOT manage it. Offloading it here would leave CPU parameters that\n            # later participate in GPU kernels (e.g., conv/embedding), causing device/dtype\n            # mismatches like \"Input type (CUDABFloat16Type) and weight type (CPUBFloat16Type)\".\n            #\n            # Therefore:\n            # - For non-FSDP models, keep the historical behavior (allow CPU offload).\n            # - For FSDP models, do NOT offload non-sharded parameters here.\n            if cpu_offload and not is_fsdp_model:\n                sharded_tensor = sharded_tensor.cpu()\n        else:\n            full_tensor = full_tensor.to(device=device, dtype=target_dtype)\n            sharded_tensor = distribute_tensor(\n                full_tensor,\n                meta_sharded_param.device_mesh,\n                meta_sharded_param.placements,\n            )\n            if cpu_offload:\n                sharded_tensor = sharded_tensor.to(\"cpu\")\n\n        requires_grad = False\n        sharded_sd[target_param_name] = nn.Parameter(\n            sharded_tensor, requires_grad=requires_grad\n        )\n\n    model.reverse_param_names_mapping = reverse_param_names_mapping\n\n    if skipped_checkpoint_keys:\n        logger.warning(\n            \"Checkpoint keys not loaded (no matching model parameter) %s\",\n            (\n                skipped_checkpoint_keys[:20]\n                if len(skipped_checkpoint_keys) > 20\n                else skipped_checkpoint_keys\n            ),\n        )\n        if len(skipped_checkpoint_keys) > 20:\n            logger.warning(\n                \"... 
and %d more skipped keys.\",\n                len(skipped_checkpoint_keys) - 20,\n            )\n\n    # parameters in nn.Module that doesn't exist in safetensor files\n    unused_keys = set(meta_sd.keys()) - set(sharded_sd.keys())\n    if unused_keys:\n        logger.warning(\"Found unloaded parameters in meta state dict: %s\", unused_keys)\n\n    # for nunchaku; norm_q/norm_k for SANA QK normalization layers\n    ALLOWED_NEW_PARAM_PATTERNS = [\n        \"gate_compress\",\n        \"wcscales\",\n        \"wtscale\",\n        \"bias\",\n        \"norm_q\",\n        \"norm_k\",\n    ]\n    for new_param_name in unused_keys:\n        if not any(pattern in new_param_name for pattern in ALLOWED_NEW_PARAM_PATTERNS):\n            logger.error(\n                \"Unsupported new parameter: %s. Allowed patterns: %s\",\n                new_param_name,\n                ALLOWED_NEW_PARAM_PATTERNS,\n            )\n            raise ValueError(\n                f\"New parameter '{new_param_name}' is not supported. 
\"\n                f\"Currently only parameters containing {ALLOWED_NEW_PARAM_PATTERNS} are allowed.\"\n            )\n\n        meta_sharded_param = meta_sd.get(new_param_name)\n        meta_sharded_param_dtype = meta_sharded_param.dtype\n\n        if any(\n            p in new_param_name for p in (\"wcscales\", \"wtscale\", \"norm_q\", \"norm_k\")\n        ):\n            init_like = torch.ones_like\n        else:\n            init_like = torch.zeros_like\n\n        if not hasattr(meta_sharded_param, \"device_mesh\"):\n            sharded_tensor = init_like(\n                meta_sharded_param, device=device, dtype=meta_sharded_param_dtype\n            )\n            if cpu_offload and not is_fsdp_model:\n                sharded_tensor = sharded_tensor.cpu()\n        else:\n            full_tensor = init_like(\n                meta_sharded_param, device=device, dtype=meta_sharded_param_dtype\n            )\n            sharded_tensor = distribute_tensor(\n                full_tensor,\n                meta_sharded_param.device_mesh,\n                meta_sharded_param.placements,\n            )\n            if cpu_offload:\n                sharded_tensor = sharded_tensor.cpu()\n        sharded_sd[new_param_name] = nn.Parameter(sharded_tensor)\n\n    # choose `assign=True` since we cannot call `copy_` on meta tensor\n    return model.load_state_dict(sharded_sd, strict=strict, assign=True)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Utilities for selecting and loading models.\"\"\"\n\nimport contextlib\nimport glob\nimport os\nimport re\nfrom collections import defaultdict\nfrom collections.abc import Callable, Iterator\nfrom typing import Any, Dict, Type\n\nimport torch\nfrom torch import nn\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n@contextlib.contextmanager\ndef set_default_torch_dtype(dtype: torch.dtype):\n    \"\"\"Sets the default torch dtype to the given dtype.\"\"\"\n    old_dtype = torch.get_default_dtype()\n    torch.set_default_dtype(dtype)\n    try:\n        yield\n    finally:\n        torch.set_default_dtype(old_dtype)\n\n\ndef get_param_names_mapping(\n    mapping_dict: dict[str, str | tuple[str, int, int]],\n) -> Callable[[str], tuple[str, Any, Any]]:\n    \"\"\"\n    Creates a mapping function that transforms parameter names using regex patterns.\n\n    Args:\n        mapping_dict (Dict[str, str]): Dictionary mapping regex patterns to replacement patterns\n\n    Returns:\n        Callable[[str], str]: A function that maps parameter names from source to target format\n    \"\"\"\n\n    def mapping_fn(name: str) -> tuple[str, Any, Any]:\n        # support chained conversions, e.g.:\n        # transformer.xxx.lora_down -> xxx.lora_down -> xxx.proj_down\n        merge_index = None\n        total_split_params = None\n        max_steps = max(8, len(mapping_dict) * 2)\n        applied_patterns: set[str] = set()\n        visited_names: set[str] = {name}\n\n        for _ in range(max_steps):\n            transformed = False\n            for pattern, replacement in mapping_dict.items():\n                # avoid re-applying the same rule on its own output\n                if pattern in applied_patterns:\n                    continue\n                if re.match(pattern, name) is None:\n 
                   continue\n\n                curr_merge_index = None\n                curr_total_split_params = None\n                if isinstance(replacement, tuple):\n                    curr_merge_index = replacement[1]\n                    curr_total_split_params = replacement[2]\n                    replacement = replacement[0]\n\n                new_name = re.sub(pattern, replacement, name)\n\n                if new_name != name:\n                    if curr_merge_index is not None:\n                        merge_index = curr_merge_index\n                        total_split_params = curr_total_split_params\n\n                    name = new_name\n                    applied_patterns.add(pattern)\n                    if name in visited_names:\n                        transformed = False\n                        break\n                    visited_names.add(name)\n                    transformed = True\n                    break\n\n            if not transformed:\n                break\n\n        return name, merge_index, total_split_params\n\n    return mapping_fn\n\n\ndef hf_to_custom_state_dict(\n    hf_param_sd: dict[str, torch.Tensor] | Iterator[tuple[str, torch.Tensor]],\n    param_names_mapping: Callable[[str], tuple[str, Any, Any]],\n) -> tuple[dict[str, torch.Tensor], dict[str, tuple[str, Any, Any]]]:\n    \"\"\"\n    Converts a Hugging Face parameter state dictionary to a custom parameter state dictionary.\n\n    Args:\n        hf_param_sd (Dict[str, torch.Tensor]): The Hugging Face parameter state dictionary\n        param_names_mapping (Callable[[str], tuple[str, Any, Any]]): A function that maps parameter names from source to target format\n\n    Returns:\n        custom_param_sd (Dict[str, torch.Tensor]): The custom formatted parameter state dict\n        reverse_param_names_mapping (Dict[str, Tuple[str, Any, Any]]): Maps back from custom to hf\n    \"\"\"\n    custom_param_sd = {}\n    to_merge_params = defaultdict(dict)  # type: ignore\n    
reverse_param_names_mapping = {}\n    if isinstance(hf_param_sd, dict):\n        hf_param_sd = hf_param_sd.items()  # type: ignore\n    for source_param_name, full_tensor in hf_param_sd:  # type: ignore\n        target_param_name, merge_index, num_params_to_merge = param_names_mapping(\n            source_param_name\n        )\n        if target_param_name == \"\" or target_param_name is None:  # type: ignore[comparison-overlap]\n            continue\n        reverse_param_names_mapping[target_param_name] = (\n            source_param_name,\n            merge_index,\n            num_params_to_merge,\n        )\n        if merge_index is not None:\n            to_merge_params[target_param_name][merge_index] = full_tensor\n            if len(to_merge_params[target_param_name]) == num_params_to_merge:\n                # cat at output dim according to the merge_index order\n                sorted_tensors = [\n                    to_merge_params[target_param_name][i]\n                    for i in range(num_params_to_merge)\n                ]\n                full_tensor = torch.cat(sorted_tensors, dim=0)\n                del to_merge_params[target_param_name]\n            else:\n                continue\n        custom_param_sd[target_param_name] = full_tensor\n    return custom_param_sd, reverse_param_names_mapping\n\n\nclass skip_init_modules:\n    def __enter__(self):\n        # Save originals\n        self._orig_reset = {}\n        for cls in (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d):\n            self._orig_reset[cls] = cls.reset_parameters\n            cls.reset_parameters = lambda self: None  # skip init\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        # restore originals\n        for cls, orig in self._orig_reset.items():\n            cls.reset_parameters = orig\n\n\ndef _normalize_component_type(module_type: str) -> str:\n    \"\"\"Normalize module types like 'text_encoder_2' -> 'text_encoder'.\"\"\"\n    if module_type.endswith(\"_2\"):\n 
       return module_type[:-2]\n    return module_type\n\n\ndef _clean_hf_config_inplace(model_config: dict) -> None:\n    \"\"\"Remove common extraneous HF fields if present.\"\"\"\n    for key in (\n        \"_name_or_path\",\n        \"transformers_version\",\n        \"model_type\",\n        \"tokenizer_class\",\n        \"torch_dtype\",\n    ):\n        model_config.pop(key, None)\n\n\ndef _list_safetensors_files(model_path: str) -> list[str]:\n    \"\"\"List .safetensors files directly under a directory (non-recursive), sorted.\"\"\"\n    return sorted(glob.glob(os.path.join(str(model_path), \"*.safetensors\")))\n\n\nBYTES_PER_GB = 1024**3\n\n\ndef get_memory_usage_of_component(module) -> float | None:\n    \"\"\"\n    Return the memory footprint of ``module`` in GB, rounded to 2 decimal places.\n\n    Returns None if ``module`` is not an nn.Module.\n    \"\"\"\n    if not isinstance(module, nn.Module):\n        return None\n    if hasattr(module, \"get_memory_footprint\"):\n        usage = module.get_memory_footprint() / BYTES_PER_GB\n    else:\n        # no get_memory_footprint(): sum parameter and buffer sizes manually\n        param_size = sum(p.numel() * p.element_size() for p in module.parameters())\n        buffer_size = sum(b.numel() * b.element_size() for b in module.buffers())\n\n        total_size_bytes = param_size + buffer_size\n        usage = total_size_bytes / BYTES_PER_GB\n\n    return round(usage, 2)\n\n\n# component name -> ComponentLoader class\ncomponent_name_to_loader_cls: Dict[str, Type[Any]] = {}\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/weight_utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/model_loader/weight_utils.py\n\"\"\"Utilities for downloading, loading, initializing and verifying model weights.\"\"\"\n\nimport hashlib\nimport json\nimport os\nimport tempfile\nfrom collections.abc import Generator, Iterable\nfrom pathlib import Path\n\nimport filelock\nimport torch\nfrom safetensors.torch import safe_open\nfrom torch.distributed.tensor import DTensor\nfrom tqdm.auto import tqdm\n\ntry:\n    from runai_model_streamer import SafetensorsStreamer\n\n    HAS_RUNAI_MODEL_STREAMER = True\nexcept ImportError:\n    HAS_RUNAI_MODEL_STREAMER = False\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n# use system-level temp directory for file locks, so that multiple users\n# can share the same lock without error.\n# lock files in the temp directory will be automatically deleted when the\n# system reboots, so users will not complain about annoying lock files\ntemp_dir = tempfile.gettempdir()\n\n\nclass DisabledTqdm(tqdm):\n\n    def __init__(self, *args, **kwargs):\n        kwargs[\"disable\"] = True\n        super().__init__(*args, **kwargs)\n\n\ndef get_lock(model_name_or_path: str | Path, cache_dir: str | None = None):\n    lock_dir = cache_dir or temp_dir\n    model_name_or_path = str(model_name_or_path)\n    # the lock file is created inside lock_dir, so that directory must exist\n    os.makedirs(lock_dir, exist_ok=True)\n    model_name = model_name_or_path.replace(\"/\", \"-\")\n    hash_name = hashlib.sha256(model_name.encode()).hexdigest()\n    # add hash to avoid conflict with old users' lock files\n    lock_file_name = hash_name + model_name + \".lock\"\n    # mode 0o666 is required for the filelock to be shared across users\n    lock = 
filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)\n    return lock\n\n\n# For models like Mistral-7B-v0.3, there are both sharded\n# safetensors files and a consolidated safetensors file.\n# Passing both of these to the weight loader functionality breaks.\n# So, we use the index_file to\n# look up which safetensors files should be used.\ndef filter_duplicate_safetensors_files(\n    hf_weights_files: list[str], hf_folder: str, index_file: str\n) -> list[str]:\n    # model.safetensors.index.json is a mapping from keys in the\n    # torch state_dict to safetensors file holding that weight.\n    index_file_name = os.path.join(hf_folder, index_file)\n    if not os.path.isfile(index_file_name):\n        return hf_weights_files\n\n    # Iterate through the weight_map (weight_name: safetensors files)\n    # to identify weights that we should use.\n    with open(index_file_name) as f:\n        weight_map = json.load(f)[\"weight_map\"]\n    weight_files_in_index = set()\n    for weight_name in weight_map:\n        weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name]))\n    # Filter out any fields that are not found in the index file.\n    hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index]\n    return hf_weights_files\n\n\ndef filter_files_not_needed_for_inference(hf_weights_files: list[str]) -> list[str]:\n    \"\"\"\n    Exclude files that are not needed for inference.\n\n    See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233\n    \"\"\"\n    blacklist = [\n        \"training_args.bin\",\n        \"optimizer.bin\",\n        \"optimizer.pt\",\n        \"scheduler.pt\",\n        \"scaler.pt\",\n    ]\n    hf_weights_files = [\n        f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist)\n    ]\n    return hf_weights_files\n\n\n# explicitly use pure text format, with a newline at the end\n# this makes it impossible to see the animation in the 
progress bar\n# but will avoid messing up with ray or multiprocessing, which wraps\n# each line of output with some prefix.\n_BAR_FORMAT = \"{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\\n\"  # noqa: E501\n\n\ndef _validate_safetensors_file(file_path: str) -> bool:\n    \"\"\"\n    Validate that a safetensors file is readable and not corrupted.\n\n    Args:\n        file_path: Path to the safetensors file\n\n    Returns:\n        True if file is valid, False if corrupted\n    \"\"\"\n    try:\n        with safe_open(file_path, framework=\"pt\", device=\"cpu\") as f:\n            _ = list(f.keys())\n        return True\n    except Exception as e:\n        logger.error(\n            \"Corrupted safetensors file detected: %s - %s: %s\",\n            file_path,\n            type(e).__name__,\n            str(e),\n        )\n        return False\n\n\ndef safetensors_weights_iterator(\n    hf_weights_files: list[str],\n    to_cpu: bool = True,\n    use_runai_model_streamer: bool = HAS_RUNAI_MODEL_STREAMER,\n) -> Generator[tuple[str, torch.Tensor], None, None]:\n    \"\"\"Iterate over the weights in the model safetensor files.\"\"\"\n    enable_tqdm = (\n        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0\n    )\n    device = \"cpu\" if to_cpu else str(get_local_torch_device())\n\n    # Validate files before loading\n    corrupted_files = [\n        st_file\n        for st_file in hf_weights_files\n        if not _validate_safetensors_file(st_file)\n    ]\n\n    if corrupted_files:\n        # Delete corrupted files (both symlink and blob if applicable)\n        for file_path in corrupted_files:\n            try:\n                if os.path.islink(file_path):\n                    blob_path = os.path.realpath(file_path)\n                    os.remove(file_path)\n                    logger.info(\n                        \"Removed corrupted symlink: %s\", os.path.basename(file_path)\n            
        )\n                    if os.path.exists(blob_path):\n                        os.remove(blob_path)\n                        logger.info(\n                            \"Removed corrupted blob: %s\", os.path.basename(blob_path)\n                        )\n                elif os.path.isfile(file_path):\n                    os.remove(file_path)\n                    logger.info(\n                        \"Removed corrupted file: %s\", os.path.basename(file_path)\n                    )\n            except Exception as e:\n                logger.warning(\"Failed to remove corrupted file %s: %s\", file_path, e)\n\n        raise RuntimeError(\n            f\"Found {len(corrupted_files)} corrupted safetensors file(s). \"\n            f\"Files have been removed: {[os.path.basename(f) for f in corrupted_files]}. \"\n            \"Please retry - the files will be re-downloaded automatically.\"\n        )\n\n    if use_runai_model_streamer:\n        with SafetensorsStreamer() as streamer:\n            streamer.stream_files(hf_weights_files)\n            for name, tensor in streamer.get_tensors():\n                if to_cpu:\n                    yield name, tensor.clone().detach()\n                else:\n                    yield name, tensor.to(device)\n    else:\n        for st_file in tqdm(\n            hf_weights_files,\n            desc=\"Loading safetensors checkpoint shards\",\n            disable=not enable_tqdm,\n            bar_format=_BAR_FORMAT,\n        ):\n            with safe_open(st_file, framework=\"pt\", device=device) as f:\n                for name in f.keys():  # noqa: SIM118\n                    param = f.get_tensor(name)\n                    yield name, param\n\n\ndef _load_pt_file(bin_file: str, device: str) -> dict:\n    \"\"\"Load a PyTorch checkpoint file, handling legacy tar format.\n\n    PyTorch 2.6 changed the default of weights_only from False to True.\n    Legacy tar format files cannot be loaded with weights_only=True.\n    This 
function tries weights_only=True first, then falls back to False\n    for legacy tar format files from trusted sources (HuggingFace Hub).\n    \"\"\"\n    try:\n        return torch.load(bin_file, map_location=device, weights_only=True)\n    except RuntimeError as e:\n        if \"legacy .tar format\" in str(e):\n            logger.warning(\n                \"Loading %s with weights_only=False (legacy tar format)\",\n                os.path.basename(bin_file),\n            )\n            return torch.load(bin_file, map_location=device, weights_only=False)\n        raise\n\n\ndef pt_weights_iterator(\n    hf_weights_files: list[str],\n    to_cpu: bool = True,\n) -> Generator[tuple[str, torch.Tensor], None, None]:\n    \"\"\"Iterate over the weights in the model bin/pt files.\"\"\"\n    device = \"cpu\" if to_cpu else str(get_local_torch_device())\n    enable_tqdm = (\n        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0\n    )\n    for bin_file in tqdm(\n        hf_weights_files,\n        desc=\"Loading pt checkpoint shards\",\n        disable=not enable_tqdm,\n        bar_format=_BAR_FORMAT,\n    ):\n        state = _load_pt_file(bin_file, device)\n        yield from state.items()\n        del state\n\n\ndef default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:\n    \"\"\"Default weight loader.\"\"\"\n    try:\n        if param.numel() == 1 and loaded_weight.numel() == 1:\n            # Sometimes scalar values aren't considered tensors with shapes\n            # so if both param and loaded_weight are a scalar,\n            # \"broadcast\" instead of copy\n            param.data.fill_(loaded_weight.item())\n        else:\n            assert param.size() == loaded_weight.size(), (\n                f\"Attempted to load weight ({loaded_weight.size()}) \"\n                f\"into parameter ({param.size()})\"\n            )\n\n            param.data.copy_(loaded_weight)\n    except Exception:\n        # NOTE: This 
exception is added for the purpose of setting breakpoint to\n        # debug weight loading issues.\n        raise\n\n\ndef maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:\n    \"\"\"Remap the name of FP8 k/v_scale parameters.\n\n    This function handles the remapping of FP8 k/v_scale parameter names.\n    It detects if the given name ends with a suffix and attempts to remap\n    it to the expected name format in the model. If the remapped name is not\n    found in the params_dict, a warning is printed and None is returned.\n\n    Args:\n        name (str): The original loaded checkpoint parameter name.\n        params_dict (dict): Dictionary containing the model's named parameters.\n\n    Returns:\n        str: The remapped parameter name if successful, or the original name\n             if no remapping is needed.\n        None: If the remapped name is not found in params_dict.\n    \"\"\"\n    if name.endswith(\".kv_scale\"):\n        logger.warning_once(\n            \"DEPRECATED. Found kv_scale in the checkpoint. \"\n            \"This format is deprecated in favor of separate k_scale and \"\n            \"v_scale tensors and will be removed in a future release. \"\n            \"Functionally, we will remap kv_scale to k_scale and duplicate \"\n            \"k_scale to v_scale\"\n        )\n        # NOTE: we remap the deprecated kv_scale to k_scale\n        remapped_name = name.replace(\".kv_scale\", \".attn.k_scale\")\n        if remapped_name not in params_dict:\n            logger.warning_once(\n                f\"Found kv_scale in the checkpoint (e.g. {name}), \"\n                \"but not found the expected name in the model \"\n                f\"(e.g. {remapped_name}). 
kv_scale is \"\n                \"not loaded.\"\n            )\n            return None\n        return remapped_name\n\n    possible_scale_names = [\".k_scale\", \".v_scale\"]\n    modelopt_scale_names = [\".self_attn.k_proj.k_scale\", \".self_attn.v_proj.v_scale\"]\n    for scale_name in possible_scale_names:\n        if name.endswith(scale_name):\n            if any(mo_scale_name in name for mo_scale_name in modelopt_scale_names):\n                remapped_name = name.replace(\n                    f\".self_attn.{scale_name[1]}_proj{scale_name}\",\n                    f\".self_attn.attn{scale_name}\",\n                )\n            else:\n                remapped_name = name.replace(scale_name, f\".attn{scale_name}\")\n            if remapped_name not in params_dict:\n                logger.warning_once(\n                    f\"Found {scale_name} in the checkpoint (e.g. {name}), \"\n                    \"but not found the expected name in the model \"\n                    f\"(e.g. {remapped_name}). {scale_name} is \"\n                    \"not loaded.\"\n                )\n                return None\n            return remapped_name\n\n    # If there were no matches, return the untouched param name\n    return name\n\n\ndef compute_weights_checksum(\n    named_params: Iterable[tuple[str, torch.Tensor]],\n) -> str:\n    \"\"\"Compute a SHA-256 checksum for a set of (name, tensor) pairs.\n\n    Used to verify the correctness of weight refitting. 
After a refit,\n    compare the checksum of the in-GPU model weights against the checksum\n    of the on-disk tensors or the tensors in the training engine.\n    \"\"\"\n    hasher = hashlib.sha256()\n    for name, tensor in sorted(named_params, key=lambda x: x[0]):\n        hasher.update(name.encode())\n        t = tensor.detach()\n        # DTensor doesn't support .numpy(); extract the local tensor.\n        if isinstance(t, DTensor):\n            t = t._local_tensor\n        hasher.update(t.cpu().contiguous().reshape(-1).view(torch.uint8).numpy().data)\n    return hasher.hexdigest()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/loader/weights_updater.py",
    "content": "\"\"\"\nIn-place weight updates for diffusion pipeline modules.\n\nThis module provides WeightsUpdater, which swaps model weights at runtime\nwithout restarting the server.  It is the diffusion-engine counterpart of the\nLLM engine's ModelRunner.update_weights_from_disk.\n\nDetailed usage of higher level API can be found in\n\n/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py\n\nKey design decisions:\n\n- All-or-nothing with rollback: modules are updated sequentially.  If\n  any module fails (shape mismatch, corrupted file, etc.), every module\n  that was already updated is rolled back by reloading its weights from\n  pipeline.model_path (the last successfully-loaded checkpoint).  On\n  success, pipeline.model_path is updated to the new model_path so\n  that future rollbacks target the latest good checkpoint, not the\n  originally-launched model.\n\n- Rollback failures propagate: if rollback itself fails, the exception is\n  not caught so the caller knows the model is in an inconsistent state.\n  This matches the LLM engine behaviour.\n\n- Offload-aware: the diffusion LayerwiseOffloadManager replaces GPU\n  parameters with torch.empty((1,)) placeholders while real weights live\n  in consolidated pinned CPU buffers.  A naive param.data.copy_() would\n  fail with a shape mismatch.  
Instead, the updater dynamically detects\n  active offload managers and writes new weights directly into their CPU\n  buffers via update_cpu_weights(), bypassing the placeholders entirely.\n  For any layer that happens to be prefetched on GPU at update time, the\n  live GPU tensor is also updated so the change takes effect immediately.\n  This requires no extra GPU memory and does not disturb the offload state.\n\n- DTensor-aware: parameters that have been distributed via\n  torch.distributed.tensor are updated through distribute_tensor\n  so that each shard is correctly placed on the right device mesh.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport gc\nfrom pathlib import Path\n\nimport torch\nfrom torch.distributed.tensor import DTensor, distribute_tensor\n\nfrom sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin\nfrom sglang.multimodal_gen.runtime.loader.utils import (\n    _list_safetensors_files,\n)\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import (\n    safetensors_weights_iterator,\n)\nfrom sglang.multimodal_gen.runtime.pipelines.diffusers_pipeline import DiffusersPipeline\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\ndef get_updatable_modules(pipeline) -> dict[str, torch.nn.Module]:\n    \"\"\"Return updatable nn.Module components for the given pipeline.\n\n    Works with both the native ComposedPipelineBase backend and the\n    DiffusersPipeline wrapper.\n    \"\"\"\n    if isinstance(pipeline, DiffusersPipeline):\n        diffusers_pipe = pipeline.get_module(\"diffusers_pipeline\")\n        if diffusers_pipe is not None and diffusers_pipe.components is not None:\n            raw = diffusers_pipe.components\n        else:\n            raw = {}\n    else:\n        raw = 
pipeline.modules\n    return {n: m for n, m in raw.items() if isinstance(m, torch.nn.Module)}\n\n\ndef _get_weights_iter(weights_dir: str):\n    \"\"\"Return a (name, tensor) iterator over safetensors in weights_dir.\"\"\"\n    safetensors_files = _list_safetensors_files(weights_dir)\n    if not safetensors_files:\n        raise FileNotFoundError(f\"No safetensors files found in {weights_dir}\")\n    return safetensors_weights_iterator(safetensors_files)\n\n\ndef _validate_weight_files(\n    local_model_path: str,\n    modules_to_update: list[tuple[str, torch.nn.Module]],\n) -> tuple[dict[str, str], list[str]]:\n    \"\"\"Check that every module has a weights directory with safetensors files.\n\n    Returns:\n        (weights_map, missing) where weights_map maps module name to its\n        weights directory and missing lists modules without weight files.\n    \"\"\"\n    weights_map: dict[str, str] = {}\n    missing: list[str] = []\n    for module_name, _ in modules_to_update:\n        weights_dir = Path(local_model_path) / module_name\n        if weights_dir.exists() and _list_safetensors_files(str(weights_dir)):\n            weights_map[module_name] = str(weights_dir)\n        else:\n            missing.append(module_name)\n    return weights_map, missing\n\n\ndef _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None:\n    \"\"\"Load weights into a module, handling offload-managed parameters.\n\n    For offloaded modules, updates CPU buffers directly via\n    update_cpu_weights(); non-offloaded parameters use in-place copy.\n    \"\"\"\n    offload_managers: list = []\n    if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers:\n        offload_managers = [m for m in module.layerwise_offload_managers if m.enabled]\n\n    if offload_managers:\n        weight_dict = dict(weights_iter)\n        offloaded_names: set[str] = set()\n        for manager in offload_managers:\n            
offloaded_names.update(manager.update_cpu_weights(weight_dict))\n        remaining = ((n, w) for n, w in weight_dict.items() if n not in offloaded_names)\n        load_weights_into_model(remaining, dict(module.named_parameters()))\n    else:\n        load_weights_into_model(weights_iter, dict(module.named_parameters()))\n\n\ndef load_weights_into_model(weights_iter, model_params: dict) -> None:\n    \"\"\"Copy weights from weights_iter into model_params in-place.\"\"\"\n    for name, loaded_weight in weights_iter:\n        if name not in model_params:\n            continue\n        param = model_params[name]\n        if param.shape != loaded_weight.shape:\n            raise ValueError(\n                f\"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}\"\n            )\n        if isinstance(param, DTensor):\n            distributed_weight = distribute_tensor(\n                loaded_weight.to(param.dtype),\n                param.device_mesh,\n                param.placements,\n            )\n            param._local_tensor.copy_(distributed_weight._local_tensor)\n        else:\n            param.data.copy_(loaded_weight.to(param.dtype))\n\n\nclass WeightsUpdater:\n    \"\"\"In-place weight updates for diffusion pipeline modules.\n\n    Args:\n        pipeline: A ComposedPipelineBase (or DiffusersPipeline) instance\n            whose modules will be updated.  
The pipeline's model_path\n            attribute is used for rollback on failure.\n    \"\"\"\n\n    def __init__(self, pipeline):\n        self.pipeline = pipeline\n\n    def update_weights_from_disk(\n        self,\n        model_path: str,\n        flush_cache: bool = True,\n        target_modules: list[str] | None = None,\n    ) -> tuple[bool, str]:\n        \"\"\"Update model weights from disk without restarting the server.\"\"\"\n        logger.info(f\"Updating weights from disk: {model_path}\")\n\n        try:\n            modules_to_update = self._collect_modules(target_modules)\n        except ValueError as e:\n            logger.error(str(e))\n            return False, str(e)\n\n        if not modules_to_update:\n            error_msg = (\n                f\"No matching modules found for update. \"\n                f\"Requested: {target_modules}. \"\n                f\"Available nn.Module(s): {list(get_updatable_modules(self.pipeline).keys())}\"\n            )\n            logger.error(error_msg)\n            return False, error_msg\n\n        try:\n            local_model_path = maybe_download_model(model_path)\n        except Exception as e:\n            return False, f\"Failed to download model: {e}\"\n\n        weights_map, missing = _validate_weight_files(\n            local_model_path, modules_to_update\n        )\n        if missing:\n            error_msg = (\n                f\"Cannot update weights: missing weight files for modules: {missing}. 
\"\n                f\"No partial updates allowed.\"\n            )\n            logger.error(error_msg)\n            return False, error_msg\n\n        logger.info(\n            f\"Updating {len(weights_map)} modules: \"\n            + \", \".join(f\"{n} <- {p}\" for n, p in weights_map.items())\n        )\n\n        success, message = self._apply_weights(modules_to_update, weights_map)\n\n        gc.collect()\n        torch.cuda.empty_cache()\n\n        if success and flush_cache:\n            for _, module in modules_to_update:\n                if isinstance(module, TeaCacheMixin):\n                    module.reset_teacache_state()\n\n        logger.info(message)\n        return success, message\n\n    def _collect_modules(\n        self, target_modules: list[str] | None\n    ) -> list[tuple[str, torch.nn.Module]]:\n        \"\"\"Resolve target_modules to (name, module) pairs.\n\n        Raises:\n            ValueError: If target_modules contains names not found in the pipeline.\n        \"\"\"\n        components = get_updatable_modules(self.pipeline)\n\n        if target_modules is None:\n            names = list(components.keys())\n        else:\n            unknown = [n for n in target_modules if n not in components]\n            if unknown:\n                raise ValueError(\n                    f\"Module(s) requested for update not found in pipeline: {unknown}. 
\"\n                    f\"Available Module(s): {list(components.keys())}\"\n                )\n            names = target_modules\n\n        return [(name, components[name]) for name in names]\n\n    def _apply_weights(\n        self,\n        modules_to_update: list[tuple[str, torch.nn.Module]],\n        weights_map: dict[str, str],\n    ) -> tuple[bool, str]:\n        \"\"\"Load weights into each module; rollback on first failure.\"\"\"\n        updated_modules: list[str] = []\n\n        for module_name, module in modules_to_update:\n            try:\n                weights_iter = _get_weights_iter(weights_map[module_name])\n                _load_weights_into_module(module, weights_iter)\n                updated_modules.append(module_name)\n            except Exception as e:\n                rollback_list = updated_modules + [module_name]\n                logger.error(\n                    f\"Weight update failed for module '{module_name}': {e}. \"\n                    f\"Rolling back {len(rollback_list)} module(s) \"\n                    f\"(including partially-loaded '{module_name}'): \"\n                    f\"{rollback_list}.\",\n                    exc_info=True,\n                )\n                self._rollback(rollback_list)\n                return False, (\n                    f\"Failed to update module '{module_name}': {e}. 
\"\n                    f\"All modules rolled back to original weights.\"\n                )\n\n        names = \", \".join(updated_modules)\n        return True, f\"Updated {len(updated_modules)} modules ({names}).\"\n\n    def _rollback(self, updated_modules: list[str]) -> None:\n        \"\"\"Restore updated_modules to original weights.\n\n        If rollback itself fails the exception propagates so the caller\n        knows the model is in an inconsistent state.\n        \"\"\"\n        if not updated_modules:\n            return\n        original_path = maybe_download_model(self.pipeline.model_path)\n        for name in updated_modules:\n            module = self.pipeline.get_module(name)\n            if module is None:\n                continue\n            weights_dir = Path(original_path) / name\n            if not weights_dir.exists():\n                continue\n            weights_iter = _get_weights_iter(str(weights_dir))\n            _load_weights_into_module(module, weights_iter)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/managers/forward_context.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/forward_context.py\nimport time\nfrom collections import defaultdict\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Optional, Type\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nif TYPE_CHECKING:\n    from sglang.multimodal_gen.runtime.layers.attention import AttentionMetadata\n    from sglang.multimodal_gen.runtime.pipelines_core import Req\n\nlogger = init_logger(__name__)\n\n# TODO(will): check if this is needed\n# track_batchsize: bool = envs.SGLANG_DIFFUSION_LOG_BATCHSIZE_INTERVAL >= 0\ntrack_batchsize: bool = False\nlast_logging_time: float = 0\nforward_start_time: float = 0\n# batchsize_logging_interval: float = envs.SGLANG_DIFFUSION_LOG_BATCHSIZE_INTERVAL\nbatchsize_logging_interval: float = 1000\nbatchsize_forward_time: defaultdict = defaultdict(list)\n\n\n@dataclass\nclass ForwardContext:\n    current_timestep: int\n    # TODO(will): check this arg\n    # copy from vllm_config.compilation_config.static_forward_context\n    # attn_layers: Dict[str, Any]\n    # TODO: extend to support per-layer dynamic forward context\n    attn_metadata: \"AttentionMetadata\"  # set dynamically for each forward pass\n    forward_batch: Optional[\"Req\"] = None\n    attention_backend_cls: Optional[Type] = None\n\n    def set_attn_backend_cls(self, attention_backend_cls: Type):\n        if self.attention_backend_cls:\n            if self.attention_backend_cls != attention_backend_cls:\n                raise RuntimeError(\n                    f\"Different attention backend types detected in the same context, previous: {self.attention_backend_cls}, new: {attention_backend_cls}\"\n                )\n        else:\n            self.attention_backend_cls = 
attention_backend_cls\n\n\n_forward_context: Optional[\"ForwardContext\"] = None\n\n\ndef get_forward_context() -> \"ForwardContext\":\n    \"\"\"Get the current forward context.\"\"\"\n    assert _forward_context is not None, (\n        \"Forward context is not set. \"\n        \"Please use `set_forward_context` to set the forward context.\"\n    )\n    return _forward_context\n\n\n# TODO(will): finalize the interface\n@contextmanager\ndef set_forward_context(\n    current_timestep, attn_metadata, forward_batch: Optional[\"Req\"] = None\n):\n    \"\"\"A context manager that stores the current forward context,\n    can be attention metadata, etc.\n    Here we can inject common logic for every model forward pass.\n    \"\"\"\n    global forward_start_time\n    need_to_track_batchsize = track_batchsize and attn_metadata is not None\n    if need_to_track_batchsize:\n        forward_start_time = time.perf_counter()\n    global _forward_context\n    prev_context = _forward_context\n    _forward_context = ForwardContext(\n        current_timestep=current_timestep,\n        attn_metadata=attn_metadata,\n        forward_batch=forward_batch,\n    )\n\n    try:\n        yield\n    finally:\n        global last_logging_time, batchsize_logging_interval\n        if need_to_track_batchsize:\n            if hasattr(attn_metadata, \"num_prefill_tokens\"):\n                # for v0 attention backends\n                batchsize = (\n                    attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens\n                )\n            else:\n                # for v1 attention backends\n                batchsize = attn_metadata.num_input_tokens\n            now = time.perf_counter()\n            # time measurement is in milliseconds\n            batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000)\n            if now - last_logging_time > batchsize_logging_interval:\n                last_logging_time = now\n                forward_stats = []\n    
            for bs, times in batchsize_forward_time.items():\n                    if len(times) <= 1:\n                        # can be cudagraph / profiling run\n                        continue\n                    medium = torch.quantile(torch.tensor(times), q=0.5).item()\n                    medium = round(medium, 2)\n                    forward_stats.append((bs, len(times), medium))\n                forward_stats.sort(key=lambda x: x[1], reverse=True)\n                if forward_stats:\n                    logger.info(\n                        (\n                            \"Batchsize forward time stats \"\n                            \"(batchsize, count, median_time(ms)): %s\"\n                        ),\n                        forward_stats,\n                    )\n        _forward_context = prev_context\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/managers/gpu_worker.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nimport gc\nimport multiprocessing as mp\nimport os\nimport time\nfrom typing import List, Union\n\nimport torch\nfrom setproctitle import setproctitle\n\nfrom sglang.multimodal_gen import envs\nfrom sglang.multimodal_gen.runtime.distributed import (\n    get_sp_group,\n    get_tp_rank,\n    get_tp_world_size,\n    maybe_init_distributed_environment_and_model_parallel,\n    model_parallel_is_initialized,\n)\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_cfg_group,\n    get_classifier_free_guidance_rank,\n    get_classifier_free_guidance_world_size,\n    get_ring_parallel_rank,\n    get_ring_parallel_world_size,\n    get_tp_group,\n    get_ulysses_parallel_rank,\n    get_ulysses_parallel_world_size,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import save_outputs\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import compute_weights_checksum\nfrom sglang.multimodal_gen.runtime.loader.weights_updater import (\n    WeightsUpdater,\n    get_updatable_modules,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core import (\n    ComposedPipelineBase,\n    LoRAPipeline,\n    Req,\n    build_pipeline,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import PortArgs, ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.common import set_cuda_arch, set_musa_arch\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import (\n    OffloadableDiTMixin,\n    iter_materialized_weights,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import (\n    configure_logger,\n    globally_suppress_loggers,\n    init_logger,\n)\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import (\n    PerformanceLogger,\n    
capture_memory_snapshot,\n)\nfrom sglang.srt.utils.network import NetworkAddress\n\nlogger = init_logger(__name__)\n\n\nclass GPUWorker:\n    \"\"\"\n    A worker that executes the model on a single GPU.\n    \"\"\"\n\n    def __init__(\n        self,\n        local_rank: int,\n        rank: int,\n        master_port: int,\n        server_args: ServerArgs,\n    ):\n        self.local_rank = local_rank\n        self.rank = rank\n        self.master_port = master_port\n        # FIXME: should we use tcp as distributed init method?\n        self.server_args = server_args\n        self.pipeline: ComposedPipelineBase = None\n\n        self.init_device_and_model()\n        self.sp_group = get_sp_group()\n        self.sp_cpu_group = self.sp_group.cpu_group\n        self.tp_group = get_tp_group()\n        self.tp_cpu_group = self.tp_group.cpu_group\n\n        self.cfg_group = get_cfg_group()\n        self.cfg_cpu_group = self.cfg_group.cpu_group\n\n    def init_device_and_model(self) -> None:\n        \"\"\"Initialize the device and load the model.\"\"\"\n        torch.get_device_module().set_device(self.local_rank)\n        # Set environment variables for distributed initialization\n        os.environ[\"MASTER_ADDR\"] = \"localhost\"\n        os.environ[\"MASTER_PORT\"] = str(self.master_port)\n        os.environ[\"LOCAL_RANK\"] = str(self.local_rank)\n        os.environ[\"RANK\"] = str(self.rank)\n        os.environ[\"WORLD_SIZE\"] = str(self.server_args.num_gpus)\n        # initialize the distributed environment\n        maybe_init_distributed_environment_and_model_parallel(\n            tp_size=self.server_args.tp_size,\n            enable_cfg_parallel=self.server_args.enable_cfg_parallel,\n            ulysses_degree=self.server_args.ulysses_degree,\n            ring_degree=self.server_args.ring_degree,\n            sp_size=self.server_args.sp_degree,\n            dp_size=self.server_args.dp_size,\n            distributed_init_method=NetworkAddress(\n                
\"127.0.0.1\", self.master_port\n            ).to_tcp(),\n            dist_timeout=self.server_args.dist_timeout,\n        )\n\n        # set proc title\n        if model_parallel_is_initialized():\n            suffix = \"\"\n            if get_tp_world_size() != 1:\n                tp_rank = get_tp_rank()\n                suffix += f\"_TP{tp_rank}\"\n            if get_ulysses_parallel_world_size() != 1:\n                u_rank = get_ulysses_parallel_rank()\n                suffix += f\"_U{u_rank}\"\n            if get_ring_parallel_world_size() != 1:\n                r_rank = get_ring_parallel_rank()\n                suffix += f\"_R{r_rank}\"\n            if get_classifier_free_guidance_world_size() != 1:\n                c_rank = get_classifier_free_guidance_rank()\n                suffix += f\"_C{c_rank}\"\n            setproctitle(f\"sgl_diffusion::scheduler{suffix}\")\n        else:\n            setproctitle(f\"sgl_diffusion::scheduler_{self.local_rank}\")\n\n        self.pipeline = build_pipeline(self.server_args)\n\n        # apply layerwise offload after lora is applied while building LoRAPipeline\n        # otherwise empty offloaded weights could fail lora converting\n        if self.server_args.dit_layerwise_offload:\n            # enable layerwise offload if possible\n            for module_name in [\n                \"transformer\",\n                \"transformer_2\",\n                \"video_dit\",\n                \"video_dit_2\",\n                \"audio_dit\",\n            ]:\n                dit = self.pipeline.get_module(module_name)\n                if dit:\n                    if isinstance(dit, OffloadableDiTMixin):\n                        dit.configure_layerwise_offload(self.server_args)\n                    else:\n                        logger.info(\n                            f\"Module {type(dit).__name__} does not support layerwise offload. 
Skipping.\"\n                        )\n\n        logger.info(\n            f\"Worker {self.rank}: Initialized device, model, and distributed environment.\"\n        )\n\n    def do_mem_analysis(self, output_batch: OutputBatch):\n        final_snapshot = capture_memory_snapshot()\n        if output_batch.metrics:\n            output_batch.metrics.record_memory_snapshot(\"mem_analysis\", final_snapshot)\n\n        # for details on max_memory_reserved: https://docs.pytorch.org/docs/stable/generated/torch.cuda.memory.max_memory_reserved.html\n        peak_reserved_bytes = torch.get_device_module().max_memory_reserved()\n        peak_allocated_bytes = torch.get_device_module().max_memory_allocated()\n\n        output_batch.peak_memory_mb = peak_reserved_bytes / (1024**2)\n        peak_reserved_gb = peak_reserved_bytes / (1024**3)\n        peak_allocated_gb = peak_allocated_bytes / (1024**3)\n\n        remaining_gpu_mem_gb = (\n            current_platform.get_device_total_memory() / (1024**3) - peak_reserved_gb\n        )\n        can_stay_resident = self.get_can_stay_resident_components(remaining_gpu_mem_gb)\n        suggested_args = set()\n        component_to_arg = {\n            \"vae\": \"--vae-cpu-offload\",\n            \"text_encoder\": \"--text-encoder-cpu-offload\",\n            \"text_encoder_2\": \"--text-encoder-cpu-offload\",\n            \"image_encoder\": \"--image-encoder-cpu-offload\",\n        }\n\n        for component in can_stay_resident:\n            if component == \"transformer\":\n                if self.server_args.dit_layerwise_offload:\n                    suggested_args.add(\"--dit-layerwise-offload\")\n                elif self.server_args.dit_cpu_offload:\n                    suggested_args.add(\"--dit-cpu-offload\")\n            elif component in component_to_arg:\n                suggested_args.add(component_to_arg[component])\n\n        suggested_args_str = (\n            \", \".join(sorted(suggested_args)) if suggested_args else 
\"None\"\n        )\n\n        pool_overhead_gb = peak_reserved_gb - peak_allocated_gb\n\n        logger.info(\n            f\"Peak GPU memory: {peak_reserved_gb:.2f} GB, \"\n            f\"Peak allocated: {peak_allocated_gb:.2f} GB, \"\n            f\"Memory pool overhead: {pool_overhead_gb:.2f} GB ({pool_overhead_gb / peak_reserved_gb * 100:.1f}%), \"\n            f\"Remaining GPU memory at peak: {remaining_gpu_mem_gb:.2f} GB. \"\n            f\"Components that could stay resident (based on the last request workload): {can_stay_resident}. \"\n            f\"Related offload server args to disable: {suggested_args_str}\"\n        )\n\n    def execute_forward(self, batch: List[Req]) -> OutputBatch:\n        \"\"\"\n        Execute a forward pass.\n        \"\"\"\n        assert self.pipeline is not None\n        req = batch[0]\n        output_batch = None\n        try:\n            if self.rank == 0:\n                torch.get_device_module().reset_peak_memory_stats()\n\n            start_time = time.monotonic()\n\n            # capture memory baseline before forward\n            if self.rank == 0 and req.metrics:\n                baseline_snapshot = capture_memory_snapshot()\n                req.metrics.record_memory_snapshot(\"before_forward\", baseline_snapshot)\n\n            req.log(server_args=self.server_args)\n            result = self.pipeline.forward(req, self.server_args)\n\n            if isinstance(result, Req):\n                output_batch = OutputBatch(\n                    output=result.output,\n                    audio=getattr(result, \"audio\", None),\n                    audio_sample_rate=getattr(result, \"audio_sample_rate\", None),\n                    metrics=result.metrics,\n                    trajectory_timesteps=getattr(result, \"trajectory_timesteps\", None),\n                    trajectory_latents=getattr(result, \"trajectory_latents\", None),\n                    noise_pred=getattr(result, \"noise_pred\", None),\n                    
trajectory_decoded=getattr(result, \"trajectory_decoded\", None),\n                )\n            else:\n                output_batch = result\n\n            # capture memory after forward (peak)\n            if self.rank == 0 and output_batch.metrics:\n                peak_snapshot = capture_memory_snapshot()\n                output_batch.metrics.record_memory_snapshot(\n                    \"after_forward\", peak_snapshot\n                )\n\n            if self.rank == 0 and not req.suppress_logs:\n                self.do_mem_analysis(output_batch)\n\n            duration_ms = (time.monotonic() - start_time) * 1000\n            output_batch.metrics.total_duration_ms = duration_ms\n\n            # Save output to file and return file path only if requested. Avoid the serialization\n            # and deserialization overhead between scheduler_client and gpu_worker.\n            if req.save_output and req.return_file_paths_only:\n                if self.rank == 0 and output_batch.output is not None:\n                    output_paths = save_outputs(\n                        output_batch.output,\n                        req.data_type,\n                        req.fps,\n                        True,\n                        lambda idx: req.output_file_path(len(output_batch.output), idx),\n                        audio=output_batch.audio,\n                        audio_sample_rate=output_batch.audio_sample_rate,\n                        output_compression=req.output_compression,\n                        enable_frame_interpolation=req.enable_frame_interpolation,\n                        frame_interpolation_exp=req.frame_interpolation_exp,\n                        frame_interpolation_scale=req.frame_interpolation_scale,\n                        frame_interpolation_model_path=req.frame_interpolation_model_path,\n                        enable_upscaling=req.enable_upscaling,\n                        upscaling_model_path=req.upscaling_model_path,\n                        
upscaling_scale=req.upscaling_scale,\n                    )\n                    output_batch.output_file_paths = output_paths\n\n                # No rank needs to hold on to generated tensors once the file-path\n                # response has been materialized on rank 0\n                output_batch.output = None\n                output_batch.audio = None\n                output_batch.audio_sample_rate = None\n\n                if torch.cuda.is_initialized():\n                    torch.cuda.empty_cache()\n\n            # TODO: extract to avoid duplication\n            if req.perf_dump_path is not None or envs.SGLANG_DIFFUSION_STAGE_LOGGING:\n                # Avoid logging warmup perf records that share the same request_id.\n                if not req.is_warmup:\n                    PerformanceLogger.log_request_summary(metrics=output_batch.metrics)\n        except Exception as e:\n            logger.error(\n                f\"Error executing request {req.request_id}: {e}\", exc_info=True\n            )\n            if isinstance(e, _oom_exceptions()):\n                logger.warning(OOM_MSG)\n            if output_batch is None:\n                output_batch = OutputBatch()\n            output_batch.error = f\"Error executing request {req.request_id}: {e}\"\n        return output_batch\n\n    def get_can_stay_resident_components(\n        self, remaining_gpu_mem_gb: float\n    ) -> List[str]:\n        \"\"\"\n        Calculate which components can stay resident on GPU without being offloaded.\n        \"\"\"\n        can_stay_resident = []\n        if not self.pipeline:\n            return can_stay_resident\n\n        # Map memory_usage keys to server_args offload flags\n        # If the flag is False, the component is ALREADY resident, so we don't suggest it.\n        # If the flag is True, it is currently offloaded, so it's a candidate to \"stay resident\".\n        offload_flags = {\n            \"transformer\": self.server_args.dit_cpu_offload\n            
or self.server_args.dit_layerwise_offload,\n            \"vae\": self.server_args.vae_cpu_offload,\n            \"text_encoder\": self.server_args.text_encoder_cpu_offload,\n            \"text_encoder_2\": self.server_args.text_encoder_cpu_offload,\n            \"image_encoder\": self.server_args.image_encoder_cpu_offload,\n        }\n\n        for name, usage in self.pipeline.memory_usages.items():\n            # Only consider components that are currently configured to be offloaded\n            is_offload_configured = offload_flags.get(name, False)\n            if not is_offload_configured:\n                continue\n\n            if usage <= remaining_gpu_mem_gb:\n                can_stay_resident.append(name)\n                remaining_gpu_mem_gb -= usage\n\n        return can_stay_resident\n\n    def set_lora(\n        self,\n        lora_nickname: Union[str, List[str]],\n        lora_path: Union[str, None, List[Union[str, None]]] = None,\n        target: Union[str, List[str]] = \"all\",\n        strength: Union[float, List[float]] = 1.0,\n    ) -> OutputBatch:\n        \"\"\"\n        Set the LoRA adapter(s) for the pipeline.\n        Supports both single LoRA (backward compatible) and multiple LoRA adapters.\n\n        Args:\n            lora_nickname: The nickname(s) of the adapter(s). Can be a string or a list of strings.\n            lora_path: Path(s) to the LoRA adapter(s). Can be a string, None, or a list of strings/None.\n            target: Which transformer(s) to apply the LoRA to. Can be a string or a list of strings.\n            strength: LoRA strength(s) for merge, default 1.0. 
Can be a float or a list of floats.\n        \"\"\"\n        if not isinstance(self.pipeline, LoRAPipeline):\n            return OutputBatch(error=\"Lora is not enabled\")\n        self.pipeline.set_lora(lora_nickname, lora_path, target, strength)\n        return OutputBatch()\n\n    def merge_lora_weights(\n        self, target: str = \"all\", strength: float = 1.0\n    ) -> OutputBatch:\n        \"\"\"\n        Merge LoRA weights.\n\n        Args:\n            target: Which transformer(s) to merge.\n            strength: LoRA strength for merge, default 1.0.\n        \"\"\"\n        if not isinstance(self.pipeline, LoRAPipeline):\n            return OutputBatch(error=\"Lora is not enabled\")\n        self.pipeline.merge_lora_weights(target, strength)\n        return OutputBatch()\n\n    def unmerge_lora_weights(self, target: str = \"all\") -> OutputBatch:\n        \"\"\"\n        Unmerge LoRA weights.\n\n        Args:\n            target: Which transformer(s) to unmerge.\n        \"\"\"\n        if not isinstance(self.pipeline, LoRAPipeline):\n            return OutputBatch(error=\"Lora is not enabled\")\n        self.pipeline.unmerge_lora_weights(target)\n        return OutputBatch()\n\n    def list_loras(self) -> OutputBatch:\n        \"\"\"\n        List loaded LoRA adapters and current application status per module.\n        \"\"\"\n        from sglang.multimodal_gen.runtime.pipelines_core.lora_pipeline import (\n            LoRAPipeline,\n        )\n\n        if not isinstance(self.pipeline, LoRAPipeline):\n            return OutputBatch(error=\"Lora is not enabled\")\n        status = self.pipeline.get_lora_status()\n        return OutputBatch(output=status)\n\n    def update_weights_from_disk(\n        self,\n        model_path: str,\n        flush_cache: bool = True,\n        target_modules: list[str] | None = None,\n    ) -> tuple[bool, str]:\n        \"\"\"Update model weights from disk inplace without restarting the server.\"\"\"\n        if not 
self.pipeline:\n            return False, \"Pipeline is not initialized\"\n\n        updater = WeightsUpdater(self.pipeline)\n        success, message = updater.update_weights_from_disk(\n            model_path,\n            flush_cache=flush_cache,\n            target_modules=target_modules,\n        )\n        if success:\n            self.server_args.model_path = model_path\n            self.pipeline.model_path = model_path\n        return success, message\n\n    def get_weights_checksum(\n        self, module_names: list[str] | None = None\n    ) -> dict[str, str]:\n        \"\"\"Compute SHA-256 checksum of each module's weights.\"\"\"\n        if not self.pipeline:\n            return {\"error\": \"Pipeline is not initialized\"}\n\n        all_modules = get_updatable_modules(self.pipeline)\n        names = module_names if module_names is not None else list(all_modules.keys())\n\n        checksums: dict[str, str] = {}\n        for name in names:\n            module = all_modules.get(name)\n            if module is None:\n                checksums[name] = \"not_found\"\n                continue\n            checksums[name] = compute_weights_checksum(\n                iter_materialized_weights(module)\n            )\n        return checksums\n\n\nOOM_MSG = f\"\"\"\nOOM detected. Possible solutions:\n  - If the OOM occurs during loading:\n    1. Enable CPU offload for memory-intensive components, or use `--dit-layerwise-offload` for DiT\n  - If the OOM occurs during runtime:\n    1. Enable SP and/or TP (in a multi-GPU setup)\n    2. Reduce the number of output tokens by lowering resolution or decreasing `--num-frames`\n    3. Opt for a sparse-attention backend\n    4. Enable FSDP by `--use-fsdp-inference` (in a multi-GPU setup)\n    5. Enable quantization (e.g. 
nunchaku)\n  Or, open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n\"\"\"\n\n\ndef _oom_exceptions():\n    # torch.OutOfMemoryError exists only in some PyTorch builds\n    types = [torch.cuda.OutOfMemoryError]\n    if hasattr(torch, \"OutOfMemoryError\"):\n        types.append(torch.OutOfMemoryError)\n    return tuple(types)\n\n\ndef run_scheduler_process(\n    local_rank: int,\n    rank: int,\n    master_port: int,\n    server_args: ServerArgs,\n    pipe_writer: mp.connection.Connection,\n    # For all workers: pipe to receive tasks from rank 0\n    task_pipe_r: mp.connection.Connection,\n    # For slave workers: pipe to send results back to rank 0\n    result_pipe_w: mp.connection.Connection | None,\n    # For rank 0 worker only: pipes to send tasks to slaves\n    task_pipes_to_slaves: list[mp.connection.Connection] | None = None,\n    # For rank 0 worker only: pipes to receive results from slaves\n    result_pipes_from_slaves: list[mp.connection.Connection] | None = None,\n) -> None:\n    \"\"\"\n    The entry point for the worker process.\n    Rank 0 acts as the master, handling ZMQ requests and coordinating slaves.\n    Ranks > 0 act as slaves, waiting for tasks from the master.\n    \"\"\"\n    configure_logger(server_args)\n    globally_suppress_loggers()\n    if current_platform.is_cuda():\n        set_cuda_arch()\n    elif current_platform.is_musa():\n        set_musa_arch()\n\n    port_args = PortArgs.from_server_args(server_args)\n\n    # start the scheduler event loop\n    assert task_pipes_to_slaves is not None\n    assert result_pipes_from_slaves is not None\n    from sglang.multimodal_gen.runtime.managers.scheduler import Scheduler\n\n    try:\n        scheduler = Scheduler(\n            server_args,\n            gpu_id=rank,\n            port_args=port_args,\n            task_pipes_to_slaves=task_pipes_to_slaves,\n            result_pipes_from_slaves=result_pipes_from_slaves,\n        )\n        logger.info(f\"Worker 
{rank}: Scheduler loop started.\")\n        pipe_writer.send(\n            {\n                \"status\": \"ready\",\n            }\n        )\n        scheduler.event_loop()\n    except _oom_exceptions() as _e:\n        logger.warning(OOM_MSG)\n        raise\n    finally:\n        # Clean up resources to speed up shutdown\n        if \"scheduler\" in locals():\n            del scheduler\n        gc.collect()\n        if torch.cuda.is_initialized():\n            torch.cuda.empty_cache()\n        if torch.distributed.is_available() and torch.distributed.is_initialized():\n            torch.distributed.destroy_process_group()\n        logger.info(f\"Worker {rank}: Shutdown complete.\")\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/managers/scheduler.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nimport asyncio\nimport os\nimport pickle\nfrom collections import deque\nfrom typing import Any, List\n\nimport zmq\n\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import ModelTaskType\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.utils import (\n    _parse_size,\n    save_image_to_path,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import (\n    GetWeightsChecksumReqInput,\n    UpdateWeightFromDiskReqInput,\n)\nfrom sglang.multimodal_gen.runtime.entrypoints.utils import (\n    ListLorasReq,\n    MergeLoraWeightsReq,\n    SetLoraReq,\n    ShutdownReq,\n    UnmergeLoraWeightsReq,\n)\nfrom sglang.multimodal_gen.runtime.managers.gpu_worker import GPUWorker\nfrom sglang.multimodal_gen.runtime.pipelines_core import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch\nfrom sglang.multimodal_gen.runtime.server_args import (\n    PortArgs,\n    ServerArgs,\n    set_global_server_args,\n)\nfrom sglang.multimodal_gen.runtime.utils.common import get_zmq_socket\nfrom sglang.multimodal_gen.runtime.utils.distributed import broadcast_pyobj\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import GREEN, RESET, init_logger\n\nlogger = init_logger(__name__)\n\nMINIMUM_PICTURE_BASE64_FOR_WARMUP = \"data:image/jpg;base64,iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAYAAABzenr0AAAACXBIWXMAAA7EAAAOxAGVKw4bAAAAbUlEQVRYhe3VsQ2AMAxE0Y/lIgNQULD/OqyCMgCihCKSG4yRuKuiNH6JLsoEbMACOGBcua9HOR7Y6w6swBwMy0qLTpkeI77qdEBpBFAHBBDAGH8WrwJKI4AAegUCfAKgEgpQDvh3CR3oQCuav58qlAw73kKCSgAAAABJRU5ErkJggg==\"\n\n\nclass Scheduler:\n    \"\"\"\n    Runs the main event loop for the rank 0 worker.\n    It listens for external requests via ZMQ and coordinates with other workers.\n    This class does NOT manage worker processes.\n    \"\"\"\n\n    def __init__(\n        self,\n        server_args: ServerArgs,\n    
    gpu_id: int,\n        port_args: PortArgs,\n        task_pipes_to_slaves: list = None,\n        result_pipes_from_slaves: list = None,\n    ):\n        self.server_args = server_args\n        self.port_args = port_args\n\n        set_global_server_args(server_args=server_args)\n\n        # Inter-process Communication\n        self.context = zmq.Context(io_threads=2)\n        endpoint = server_args.scheduler_endpoint\n        if gpu_id == 0:\n            # router allocates identity (envelope) for each connection\n            self.receiver, actual_endpoint = get_zmq_socket(\n                self.context, zmq.ROUTER, endpoint, True\n            )\n            logger.info(f\"Scheduler bind at endpoint: {actual_endpoint}\")\n        else:\n            self.receiver = None\n\n        worker = GPUWorker(\n            local_rank=gpu_id,\n            master_port=port_args.master_port,\n            rank=gpu_id,\n            server_args=server_args,\n        )\n        self.worker = worker\n        self.task_pipes_to_slaves = task_pipes_to_slaves\n        self.result_pipes_from_slaves = result_pipes_from_slaves\n        self.gpu_id = gpu_id\n        self._running = True\n\n        self.request_handlers = {\n            SetLoraReq: self._handle_set_lora,\n            MergeLoraWeightsReq: self._handle_merge_lora,\n            UnmergeLoraWeightsReq: self._handle_unmerge_lora,\n            Req: self._handle_generation,\n            List[Req]: self._handle_generation,\n            ListLorasReq: self._handle_list_loras,\n            ShutdownReq: self._handle_shutdown,\n            UpdateWeightFromDiskReqInput: self._handle_update_weights_from_disk,\n            GetWeightsChecksumReqInput: self._handle_get_weights_checksum,\n        }\n\n        # FIFO, new reqs are appended\n        self.waiting_queue: deque[tuple[bytes, Req]] = deque()\n\n        # whether we've sent the necessary warmup reqs\n        self.warmed_up = False\n        # warmup progress tracking\n        
self._warmup_total = 0\n        self._warmup_processed = 0\n\n        self.prepare_server_warmup_reqs()\n\n        # Maximum consecutive errors before terminating the event loop\n        self._max_consecutive_errors = 3\n        self._consecutive_error_count = 0\n\n    def _handle_set_lora(self, reqs: List[Any]) -> OutputBatch:\n        # TODO: return set status\n        # TODO: return with SetLoRAResponse or something more appropriate\n        req = reqs[0]\n        return self.worker.set_lora(\n            req.lora_nickname, req.lora_path, req.target, req.strength\n        )\n\n    def _handle_merge_lora(self, reqs: List[Any]):\n        req = reqs[0]\n        return self.worker.merge_lora_weights(req.target, req.strength)\n\n    def _handle_unmerge_lora(self, reqs: List[Any]) -> OutputBatch:\n        req = reqs[0]\n        return self.worker.unmerge_lora_weights(req.target)\n\n    def _handle_list_loras(self, _reqs: List[Any]) -> OutputBatch:\n        return self.worker.list_loras()\n\n    def _handle_shutdown(self, _reqs: List[Any]) -> OutputBatch:\n        self._running = False\n        return OutputBatch()\n\n    def _handle_update_weights_from_disk(self, reqs: List[Any]) -> OutputBatch:\n        \"\"\"Handle update_weights_from_disk request for RL workflows.\"\"\"\n        req = reqs[0]\n        success, message = self.worker.update_weights_from_disk(\n            model_path=req.model_path,\n            flush_cache=req.flush_cache,\n            target_modules=req.target_modules,\n        )\n        return OutputBatch(\n            output={\"success\": success, \"message\": message},\n            error=None if success else message,\n        )\n\n    def _handle_get_weights_checksum(self, reqs: List[Any]) -> OutputBatch:\n        \"\"\"Handle get_weights_checksum request.\"\"\"\n        req = reqs[0]\n        checksums = self.worker.get_weights_checksum(module_names=req.module_names)\n        return OutputBatch(output=checksums)\n\n    def 
_handle_generation(self, reqs: List[Req]):\n        warmup_reqs = [req for req in reqs if req.is_warmup]\n        if warmup_reqs:\n            self._warmup_processed += len(warmup_reqs)\n            if self._warmup_total > 0:\n                logger.info(\n                    f\"Processing warmup req... ({self._warmup_processed}/{self._warmup_total})\"\n                )\n            else:\n                logger.info(\"Processing warmup req...\")\n        return self.worker.execute_forward(reqs)\n\n    def return_result(\n        self,\n        output_batch: OutputBatch,\n        identity: bytes | None = None,\n        is_warmup: bool = False,\n    ):\n        \"\"\"\n        replies to client, only on rank 0\n        \"\"\"\n        if not is_warmup and self.receiver is not None and identity is not None:\n            self.receiver.send_multipart([identity, b\"\", pickle.dumps(output_batch)])\n\n    def get_next_batch_to_run(self) -> list[tuple[bytes, Req]] | None:\n        \"\"\"pull a req from waiting_queue\"\"\"\n        if not self.waiting_queue:\n            return None\n\n        # pop the first (earliest)\n        item = self.waiting_queue.popleft()\n\n        return [item]\n\n    def prepare_server_warmup_reqs(self):\n        if (\n            self.server_args.warmup\n            and not self.warmed_up\n            and self.server_args.warmup_resolutions is not None\n        ):\n            # insert warmup reqs constructed with each warmup-resolution\n            self._warmup_total = len(self.server_args.warmup_resolutions)\n            self._warmup_processed = 0\n\n            for resolution in self.server_args.warmup_resolutions:\n                width, height = _parse_size(resolution)\n                task_type = self.server_args.pipeline_config.task_type\n\n                if task_type in (\n                    ModelTaskType.I2I,\n                    ModelTaskType.TI2I,\n                    ModelTaskType.I2V,\n                    ModelTaskType.TI2V,\n  
              ):\n                    uploads_dir = os.path.join(\"outputs\", \"uploads\")\n                    os.makedirs(uploads_dir, exist_ok=True)\n                    input_path = asyncio.run(\n                        save_image_to_path(\n                            MINIMUM_PICTURE_BASE64_FOR_WARMUP,\n                            os.path.join(uploads_dir, \"warmup_image.jpg\"),\n                        )\n                    )\n                    req = Req(\n                        data_type=task_type.data_type(),\n                        width=width,\n                        height=height,\n                        prompt=\"\",\n                        negative_prompt=\"\",\n                        image_path=[input_path],\n                    )\n                else:\n                    req = Req(\n                        data_type=task_type.data_type(),\n                        width=width,\n                        height=height,\n                        prompt=\"\",\n                    )\n                req.set_as_warmup(self.server_args.warmup_steps)\n                self.waiting_queue.append((None, req))\n            # if server is warmed-up, set this flag to avoid req-based warmup\n            self.warmed_up = True\n\n    def process_received_reqs_with_req_based_warmup(\n        self, recv_reqs: List[tuple[bytes, Any]]\n    ) -> List[tuple[bytes, Any]]:\n        if (\n            self.warmed_up\n            or not self.server_args.warmup\n            or not recv_reqs\n            or self.server_args.warmup_resolutions is not None\n        ):\n            return recv_reqs\n\n        # handle server req-based warmup by inserting an identical req to the beginning of the waiting queue\n        # only the very first req through server's lifetime will be warmed up\n        identity, req = recv_reqs[0]\n        if isinstance(req, Req):\n            warmup_req = req.copy_as_warmup(self.server_args.warmup_steps)\n            recv_reqs.insert(0, (identity, 
warmup_req))\n            self._warmup_total = 1\n            self._warmup_processed = 0\n            self.warmed_up = True\n        return recv_reqs\n\n    def recv_reqs(self) -> List[tuple[bytes, Any]]:\n        \"\"\"\n        For non-main schedulers, reqs are broadcasted from main using broadcast_pyobj\n        \"\"\"\n        if self.receiver is not None:\n            try:\n                try:\n                    # Accept valid REQ envelopes only, ignore malformed/probe frames.\n                    parts = self.receiver.recv_multipart(zmq.NOBLOCK)\n                    identity, payload = parts[0], parts[-1]\n\n                    # Ignore malformed probes or non-pickle data\n                    recv_reqs = pickle.loads(payload) if len(parts) > 2 else []\n                except (zmq.Again, pickle.UnpicklingError, IndexError, EOFError):\n                    recv_reqs = []\n            except zmq.ZMQError:\n                # re-raise or handle appropriately to let the outer loop continue\n                raise\n\n            if recv_reqs:\n                # Ensure recv_reqs is a list\n                if not isinstance(recv_reqs, list):\n                    recv_reqs = [recv_reqs]\n\n                # Pack with identity for rank 0\n                recv_reqs = [(identity, req) for req in recv_reqs]\n        else:\n            recv_reqs = None\n\n        # TODO: fix this condition\n        if self.server_args.sp_degree != 1:\n            recv_reqs = broadcast_pyobj(\n                recv_reqs,\n                self.worker.sp_group.rank,\n                self.worker.sp_cpu_group,\n                src=self.worker.sp_group.ranks[0],\n            )\n\n        if self.server_args.enable_cfg_parallel:\n            recv_reqs = broadcast_pyobj(\n                recv_reqs,\n                self.worker.cfg_group.rank,\n                self.worker.cfg_cpu_group,\n                src=self.worker.cfg_group.ranks[0],\n            )\n\n        if self.server_args.tp_size > 1:\n  
          recv_reqs = broadcast_pyobj(\n                recv_reqs,\n                self.worker.tp_group.rank,\n                self.worker.tp_cpu_group,\n                src=self.worker.tp_group.ranks[0],\n            )\n\n        assert recv_reqs is not None\n\n        return recv_reqs\n\n    def event_loop(self) -> None:\n        \"\"\"\n        The main event loop that listens for ZMQ requests.\n        Handles abortion\n        \"\"\"\n\n        logger.debug(\n            f\"Rank 0 scheduler listening on tcp://*:{self.server_args.scheduler_port}\"\n        )\n\n        while self._running:\n            # 1: receive requests\n            try:\n                new_reqs = self.recv_reqs()\n                new_reqs = self.process_received_reqs_with_req_based_warmup(new_reqs)\n                self.waiting_queue.extend(new_reqs)\n                # Reset error count on success\n                self._consecutive_error_count = 0\n            except Exception as e:\n                self._consecutive_error_count += 1\n                logger.error(\n                    f\"Error receiving requests in scheduler event loop \"\n                    f\"(attempt {self._consecutive_error_count}/{self._max_consecutive_errors}): {e}\",\n                    exc_info=True,\n                )\n                if self._consecutive_error_count >= self._max_consecutive_errors:\n                    logger.error(\n                        f\"Maximum consecutive errors ({self._max_consecutive_errors}) reached. \"\n                        \"Terminating scheduler event loop.\"\n                    )\n                    raise RuntimeError(\n                        f\"Scheduler terminated after {self._max_consecutive_errors} \"\n                        f\"consecutive errors. 
Last error: {e}\"\n                    ) from e\n                continue\n\n            # 2: execute, make sure a reply is always sent\n            items = self.get_next_batch_to_run()\n            if not items:\n                continue\n\n            identities = [item[0] for item in items]\n            reqs = [item[1] for item in items]\n\n            try:\n                processed_req = reqs[0]\n                handler = self.request_handlers.get(type(processed_req))\n                if handler:\n                    output_batch = handler(reqs)\n                else:\n                    output_batch = OutputBatch(\n                        error=f\"Unknown request type: {type(processed_req)}\"\n                    )\n            except Exception as e:\n                logger.error(\n                    f\"Error executing request in scheduler event loop: {e}\",\n                    exc_info=True,\n                )\n                # Reply with the error so the client is never left waiting\n                output_batch = OutputBatch(error=str(e))\n\n            # 3. 
return results\n            try:\n                # log warmup info\n                is_warmup = (\n                    processed_req.is_warmup if isinstance(processed_req, Req) else False\n                )\n                if is_warmup:\n                    if output_batch.error is None:\n                        if self._warmup_total > 0:\n                            logger.info(\n                                f\"Warmup req ({self._warmup_processed}/{self._warmup_total}) processed in {GREEN}%.2f{RESET} seconds\",\n                                output_batch.metrics.total_duration_s,\n                            )\n                        else:\n                            logger.info(\n                                f\"Warmup req processed in {GREEN}%.2f{RESET} seconds\",\n                                output_batch.metrics.total_duration_s,\n                            )\n                    else:\n                        if self._warmup_total > 0:\n                            logger.info(\n                                f\"Warmup req ({self._warmup_processed}/{self._warmup_total}) processing failed\"\n                            )\n                        else:\n                            logger.info(f\"Warmup req processing failed\")\n\n                # TODO: Support sending back to multiple identities if batched\n                self.return_result(output_batch, identities[0], is_warmup=is_warmup)\n            except zmq.ZMQError as e:\n                # Reply failed; log and keep loop alive to accept future requests\n                logger.error(f\"ZMQ error sending reply: {e}\")\n                continue\n\n        if self.receiver is not None:\n            self.receiver.close()\n        self.context.destroy(linger=0)\n\n    def _broadcast_task(self, payload: dict[str, Any]) -> None:\n        \"\"\"Broadcast a task to all slave worker processes.\"\"\"\n        method = payload[\"method\"]\n        kwargs = {k: v for k, v in payload.items() if k != 
\"method\"}\n        task = {\"method\": method, \"kwargs\": kwargs}\n        for pipe in self.task_pipes_to_slaves:\n            pipe.send(task)\n\n    def _collect_slave_results(self) -> List[dict[str, Any]]:\n        \"\"\"Collect results from all slave worker processes.\"\"\"\n        results = []\n        for pipe in self.result_pipes_from_slaves:\n            results.append(pipe.recv())\n        return results\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/adapter/ltx_2_connector.py",
    "content": "from typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom diffusers.models.attention import FeedForward\n\nfrom sglang.multimodal_gen.configs.models.adapter.ltx_2_connector import (\n    LTX2ConnectorConfig,\n)\nfrom sglang.multimodal_gen.runtime.layers.attention import USPAttention\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\n\n\ndef apply_interleaved_rotary_emb(\n    x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]\n) -> torch.Tensor:\n    cos, sin = freqs\n    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, C // 2]\n    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)\n    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)\n    return out\n\n\ndef apply_split_rotary_emb(\n    x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]\n) -> torch.Tensor:\n    cos, sin = freqs\n\n    x_dtype = x.dtype\n    needs_reshape = False\n    if x.ndim != 4 and cos.ndim == 4:\n        # cos is (#b, h, t, r) -> reshape x to (b, h, t, dim_per_head)\n        # The cos/sin batch dim may only be broadcastable, so take batch size from x\n        b = x.shape[0]\n        _, h, t, _ = cos.shape\n        x = x.reshape(b, t, h, -1).transpose(1, 2)\n        needs_reshape = True\n\n    # Split last dim (2*r) into (d=2, r)\n    last = x.shape[-1]\n    if last % 2 != 0:\n        raise ValueError(\n            f\"Expected x.shape[-1] to be even for split rotary, got {last}.\"\n        )\n    r = last // 2\n\n    # (..., 2, r)\n    split_x = x.reshape(*x.shape[:-1], 2, r)\n    first_x = split_x[..., :1, :]  # (..., 1, r)\n    second_x = split_x[..., 1:, :]  # (..., 1, r)\n\n    cos_u = cos.unsqueeze(-2)  # broadcast to (..., 1, r) against (..., 2, r)\n    sin_u = sin.unsqueeze(-2)\n\n    out = split_x * cos_u\n    first_out = out[..., :1, :]\n    second_out = out[..., 1:, :]\n\n    first_out.addcmul_(-sin_u, second_x)\n    
second_out.addcmul_(sin_u, first_x)\n\n    out = out.reshape(*out.shape[:-2], last)\n\n    if needs_reshape:\n        out = out.transpose(1, 2).reshape(b, t, -1)\n\n    out = out.to(dtype=x_dtype)\n    return out\n\n\nclass LTX2Attention(torch.nn.Module):\n    r\"\"\"\n    Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key\n    RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention.\n    \"\"\"\n\n    def __init__(\n        self,\n        query_dim: int,\n        heads: int = 8,\n        kv_heads: int = 8,\n        dim_head: int = 64,\n        dropout: float = 0.0,\n        bias: bool = True,\n        cross_attention_dim: Optional[int] = None,\n        out_bias: bool = True,\n        qk_norm: str = \"rms_norm_across_heads\",\n        norm_eps: float = 1e-6,\n        norm_elementwise_affine: bool = True,\n        rope_type: str = \"interleaved\",\n        processor=None,\n    ):\n        super().__init__()\n        if qk_norm != \"rms_norm_across_heads\":\n            raise NotImplementedError(\n                \"Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.\"\n            )\n\n        self.head_dim = dim_head\n        self.inner_dim = dim_head * heads\n        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads\n        self.query_dim = query_dim\n        self.cross_attention_dim = (\n            cross_attention_dim if cross_attention_dim is not None else query_dim\n        )\n        self.use_bias = bias\n        self.dropout = dropout\n        self.out_dim = query_dim\n        self.heads = heads\n        self.rope_type = rope_type\n\n        self.norm_q = torch.nn.RMSNorm(\n            dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine\n        )\n        self.norm_k = torch.nn.RMSNorm(\n            dim_head * kv_heads,\n            eps=norm_eps,\n            
elementwise_affine=norm_elementwise_affine,\n        )\n        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)\n        self.to_k = torch.nn.Linear(\n            self.cross_attention_dim, self.inner_kv_dim, bias=bias\n        )\n        self.to_v = torch.nn.Linear(\n            self.cross_attention_dim, self.inner_kv_dim, bias=bias\n        )\n        self.to_out = torch.nn.ModuleList([])\n        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))\n        self.to_out.append(torch.nn.Dropout(dropout))\n\n        # Scaled dot product attention\n        self.attn = USPAttention(\n            num_heads=heads,\n            head_size=self.head_dim,\n            dropout_rate=0,\n            softmax_scale=None,\n            causal=False,\n            supported_attention_backends={\n                AttentionBackendEnum.FA,\n                AttentionBackendEnum.AITER,\n                AttentionBackendEnum.TORCH_SDPA,\n                AttentionBackendEnum.SAGE_ATTN,\n                AttentionBackendEnum.SAGE_ATTN_3,\n            },\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        batch_size, sequence_length, _ = (\n            hidden_states.shape\n            if encoder_hidden_states is None\n            else encoder_hidden_states.shape\n        )\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n\n        query = self.to_q(hidden_states)\n        key = self.to_k(encoder_hidden_states)\n        value = self.to_v(encoder_hidden_states)\n\n        query = self.norm_q(query)\n        key = self.norm_k(key)\n\n        if query_rotary_emb is not 
None:\n            if self.rope_type == \"interleaved\":\n                query = apply_interleaved_rotary_emb(query, query_rotary_emb)\n                key = apply_interleaved_rotary_emb(\n                    key,\n                    key_rotary_emb if key_rotary_emb is not None else query_rotary_emb,\n                )\n            elif self.rope_type == \"split\":\n                query = apply_split_rotary_emb(query, query_rotary_emb)\n                key = apply_split_rotary_emb(\n                    key,\n                    key_rotary_emb if key_rotary_emb is not None else query_rotary_emb,\n                )\n\n        query = query.unflatten(2, (self.heads, -1))\n        key = key.unflatten(2, (self.heads, -1))\n        value = value.unflatten(2, (self.heads, -1))\n\n        hidden_states = self.attn(\n            query,\n            key,\n            value,\n        )\n        hidden_states = hidden_states.flatten(2, 3)\n        hidden_states = hidden_states.to(query.dtype)\n\n        hidden_states = self.to_out[0](hidden_states)\n        hidden_states = self.to_out[1](hidden_states)\n        return hidden_states\n\n\nclass LTX2RotaryPosEmbed1d(nn.Module):\n    \"\"\"\n    1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        base_seq_len: int = 4096,\n        theta: float = 10000.0,\n        double_precision: bool = True,\n        rope_type: str = \"interleaved\",\n        num_attention_heads: int = 32,\n    ):\n        super().__init__()\n        if rope_type not in [\"interleaved\", \"split\"]:\n            raise ValueError(\n                f\"{rope_type=} not supported. 
Choose between 'interleaved' and 'split'.\"\n            )\n\n        self.dim = dim\n        self.base_seq_len = base_seq_len\n        self.theta = theta\n        self.double_precision = double_precision\n        self.rope_type = rope_type\n        self.num_attention_heads = num_attention_heads\n\n    def forward(\n        self,\n        batch_size: int,\n        pos: int,\n        device: Union[str, torch.device],\n        dtype: Optional[torch.dtype] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        # 1. Get 1D position ids\n        grid_1d = torch.arange(pos, dtype=torch.float32, device=device)\n        # Get fractional indices relative to self.base_seq_len\n        grid_1d = grid_1d / self.base_seq_len\n        grid = grid_1d.unsqueeze(0).repeat(batch_size, 1)  # [batch_size, seq_len]\n\n        # 2. Calculate 1D RoPE frequencies\n        num_rope_elems = 2  # 1 (because 1D) * 2 (for cos, sin) = 2\n        freqs_dtype = torch.float64 if self.double_precision else torch.float32\n        pow_indices = torch.pow(\n            self.theta,\n            torch.linspace(\n                start=0.0,\n                end=1.0,\n                steps=self.dim // num_rope_elems,\n                dtype=freqs_dtype,\n                device=device,\n            ),\n        )\n        freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)\n\n        # 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape\n        # (self.dim // 2,).\n        freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs  # [B, seq_len, self.dim // 2]\n\n        # 4. 
Get real, interleaved (cos, sin) frequencies, padded to self.dim\n        if self.rope_type == \"interleaved\":\n            cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)\n            sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)\n\n            if self.dim % num_rope_elems != 0:\n                cos_padding = torch.ones_like(\n                    cos_freqs[:, :, : self.dim % num_rope_elems]\n                )\n                sin_padding = torch.zeros_like(\n                    sin_freqs[:, :, : self.dim % num_rope_elems]\n                )\n                cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)\n                sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)\n\n        elif self.rope_type == \"split\":\n            expected_freqs = self.dim // 2\n            current_freqs = freqs.shape[-1]\n            pad_size = expected_freqs - current_freqs\n            cos_freq = freqs.cos()\n            sin_freq = freqs.sin()\n\n            if pad_size != 0:\n                cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])\n                sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])\n\n                cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)\n                sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)\n\n            # Reshape freqs to be compatible with multi-head attention\n            b = cos_freq.shape[0]\n            t = cos_freq.shape[1]\n\n            cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)\n            sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1)\n\n            cos_freqs = torch.swapaxes(cos_freq, 1, 2)  # (B,H,T,D//2)\n            sin_freqs = torch.swapaxes(sin_freq, 1, 2)  # (B,H,T,D//2)\n\n        if dtype is not None:\n            cos_freqs = cos_freqs.to(dtype)\n            sin_freqs = sin_freqs.to(dtype)\n        return cos_freqs, sin_freqs\n\n\nclass LTX2TransformerBlock1d(nn.Module):\n    def __init__(\n        self,\n 
       dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        activation_fn: str = \"gelu-approximate\",\n        eps: float = 1e-6,\n        rope_type: str = \"interleaved\",\n    ):\n        super().__init__()\n\n        self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)\n        self.attn1 = LTX2Attention(\n            query_dim=dim,\n            heads=num_attention_heads,\n            kv_heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            rope_type=rope_type,\n        )\n\n        self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)\n        self.ff = FeedForward(dim, activation_fn=activation_fn)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        rotary_emb: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        norm_hidden_states = self.norm1(hidden_states)\n        attn_hidden_states = self.attn1(\n            norm_hidden_states,\n            attention_mask=attention_mask,\n            query_rotary_emb=rotary_emb,\n        )\n        hidden_states = hidden_states + attn_hidden_states\n\n        norm_hidden_states = self.norm2(hidden_states)\n        ff_hidden_states = self.ff(norm_hidden_states)\n        hidden_states = hidden_states + ff_hidden_states\n\n        return hidden_states\n\n\nclass LTX2ConnectorTransformer1d(nn.Module):\n    \"\"\"\n    A 1D sequence transformer for modalities such as text.\n    In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    def __init__(\n        self,\n        num_attention_heads: int = 30,\n        attention_head_dim: int = 128,\n        num_layers: int = 2,\n        num_learnable_registers: int | None = 128,\n        rope_base_seq_len: int = 4096,\n        rope_theta: float = 10000.0,\n        
rope_double_precision: bool = True,\n        eps: float = 1e-6,\n        causal_temporal_positioning: bool = False,\n        rope_type: str = \"interleaved\",\n    ):\n        super().__init__()\n        self.num_attention_heads = num_attention_heads\n        self.inner_dim = num_attention_heads * attention_head_dim\n        self.causal_temporal_positioning = causal_temporal_positioning\n\n        self.num_learnable_registers = num_learnable_registers\n        self.learnable_registers = None\n        if num_learnable_registers is not None:\n            init_registers = (\n                torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0\n            )\n            self.learnable_registers = torch.nn.Parameter(init_registers)\n\n        self.rope = LTX2RotaryPosEmbed1d(\n            self.inner_dim,\n            base_seq_len=rope_base_seq_len,\n            theta=rope_theta,\n            double_precision=rope_double_precision,\n            rope_type=rope_type,\n            num_attention_heads=num_attention_heads,\n        )\n\n        self.transformer_blocks = torch.nn.ModuleList(\n            [\n                LTX2TransformerBlock1d(\n                    dim=self.inner_dim,\n                    num_attention_heads=num_attention_heads,\n                    attention_head_dim=attention_head_dim,\n                    rope_type=rope_type,\n                )\n                for _ in range(num_layers)\n            ]\n        )\n\n        self.norm_out = torch.nn.RMSNorm(\n            self.inner_dim, eps=eps, elementwise_affine=False\n        )\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        attn_mask_binarize_threshold: float = -9000.0,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        # hidden_states shape: [batch_size, seq_len, hidden_dim]\n        # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, 
seq_len]\n        batch_size, seq_len, _ = hidden_states.shape\n\n        # 1. Replace padding with learned registers, if using\n        if self.learnable_registers is not None:\n            if seq_len % self.num_learnable_registers != 0:\n                raise ValueError(\n                    f\"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number\"\n                    f\" of learnable registers {self.num_learnable_registers}\"\n                )\n\n            num_register_repeats = seq_len // self.num_learnable_registers\n            registers = torch.tile(\n                self.learnable_registers, (num_register_repeats, 1)\n            )  # [seq_len, inner_dim]\n\n            binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()\n            if binary_attn_mask.ndim == 4:\n                binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(\n                    1\n                )  # [B, 1, 1, L] --> [B, L]\n\n            hidden_states_non_padded = [\n                hidden_states[i, binary_attn_mask[i].bool(), :]\n                for i in range(batch_size)\n            ]\n            valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]\n            pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]\n            padded_hidden_states = [\n                F.pad(x, pad=(0, 0, 0, p), value=0)\n                for x, p in zip(hidden_states_non_padded, pad_lengths)\n            ]\n            padded_hidden_states = torch.cat(\n                [x.unsqueeze(0) for x in padded_hidden_states], dim=0\n            )  # [B, L, D]\n\n            flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(\n                -1\n            )  # [B, L, 1]\n            hidden_states = (\n                flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers\n            )\n\n            # Overwrite attention_mask with an all-zeros mask if using registers.\n   
         attention_mask = torch.zeros_like(attention_mask)\n\n        # 2. Calculate 1D RoPE positional embeddings\n        rotary_emb = self.rope(\n            batch_size, seq_len, device=hidden_states.device, dtype=hidden_states.dtype\n        )\n\n        # 3. Run 1D transformer blocks\n        for block in self.transformer_blocks:\n            if torch.is_grad_enabled() and self.gradient_checkpointing:\n                hidden_states = self._gradient_checkpointing_func(\n                    block, hidden_states, attention_mask, rotary_emb\n                )\n            else:\n                hidden_states = block(\n                    hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb\n                )\n\n        hidden_states = self.norm_out(hidden_states)\n\n        return hidden_states, attention_mask\n\n\nclass LTX2TextConnectors(nn.Module):\n    \"\"\"\n    Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio\n    streams.\n    \"\"\"\n\n    def __init__(\n        self,\n        config: LTX2ConnectorConfig,\n    ):\n        super().__init__()\n        caption_channels = config.caption_channels\n        text_proj_in_factor = config.text_proj_in_factor\n        video_connector_num_attention_heads = config.video_connector_num_attention_heads\n        video_connector_attention_head_dim = config.video_connector_attention_head_dim\n        video_connector_num_layers = config.video_connector_num_layers\n        video_connector_num_learnable_registers = (\n            config.video_connector_num_learnable_registers\n        )\n        audio_connector_num_attention_heads = config.audio_connector_num_attention_heads\n        audio_connector_attention_head_dim = config.audio_connector_attention_head_dim\n        audio_connector_num_layers = config.audio_connector_num_layers\n        audio_connector_num_learnable_registers = (\n            config.audio_connector_num_learnable_registers\n    
    )\n        connector_rope_base_seq_len = config.connector_rope_base_seq_len\n        rope_theta = config.rope_theta\n        rope_double_precision = config.rope_double_precision\n        causal_temporal_positioning = config.causal_temporal_positioning\n        rope_type = config.rope_type\n\n        self.text_proj_in = nn.Linear(\n            caption_channels * text_proj_in_factor, caption_channels, bias=False\n        )\n        self.video_connector = LTX2ConnectorTransformer1d(\n            num_attention_heads=video_connector_num_attention_heads,\n            attention_head_dim=video_connector_attention_head_dim,\n            num_layers=video_connector_num_layers,\n            num_learnable_registers=video_connector_num_learnable_registers,\n            rope_base_seq_len=connector_rope_base_seq_len,\n            rope_theta=rope_theta,\n            rope_double_precision=rope_double_precision,\n            causal_temporal_positioning=causal_temporal_positioning,\n            rope_type=rope_type,\n        )\n        self.audio_connector = LTX2ConnectorTransformer1d(\n            num_attention_heads=audio_connector_num_attention_heads,\n            attention_head_dim=audio_connector_attention_head_dim,\n            num_layers=audio_connector_num_layers,\n            num_learnable_registers=audio_connector_num_learnable_registers,\n            rope_base_seq_len=connector_rope_base_seq_len,\n            rope_theta=rope_theta,\n            rope_double_precision=rope_double_precision,\n            causal_temporal_positioning=causal_temporal_positioning,\n            rope_type=rope_type,\n        )\n\n    def forward(\n        self,\n        text_encoder_hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n        additive_mask: bool = False,\n    ):\n        # Convert to additive attention mask, if necessary\n        if not additive_mask:\n            text_dtype = text_encoder_hidden_states.dtype\n            attention_mask = (attention_mask - 
1).reshape(\n                attention_mask.shape[0], 1, -1, attention_mask.shape[-1]\n            )\n            attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max\n\n        # Ensure input dtype matches the layer's weight dtype\n        if text_encoder_hidden_states.dtype != self.text_proj_in.weight.dtype:\n            text_encoder_hidden_states = text_encoder_hidden_states.to(\n                self.text_proj_in.weight.dtype\n            )\n\n        # Ensure sequence length is divisible by num_learnable_registers (128)\n        seq_len = text_encoder_hidden_states.shape[1]\n        num_learnable_registers = self.video_connector.num_learnable_registers\n        if (\n            num_learnable_registers is not None\n            and seq_len % num_learnable_registers != 0\n        ):\n            pad_len = num_learnable_registers - (seq_len % num_learnable_registers)\n            text_encoder_hidden_states = F.pad(\n                text_encoder_hidden_states, (0, 0, 0, pad_len), value=0.0\n            )\n\n            if attention_mask.shape[-1] == seq_len:\n                # Pad with a large negative value to mask out the new tokens\n                attention_mask = F.pad(attention_mask, (0, pad_len), value=-1000000.0)\n\n        text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)\n\n        video_text_embedding, new_attn_mask = self.video_connector(\n            text_encoder_hidden_states, attention_mask\n        )\n\n        attn_mask = (new_attn_mask < 1e-6).to(torch.int64)\n        attn_mask = attn_mask.reshape(\n            video_text_embedding.shape[0], video_text_embedding.shape[1], 1\n        )\n        video_text_embedding = video_text_embedding * attn_mask\n        new_attn_mask = attn_mask.squeeze(-1)\n\n        audio_text_embedding, _ = self.audio_connector(\n            text_encoder_hidden_states, attention_mask\n        )\n\n        return video_text_embedding, audio_text_embedding, 
new_attn_mask\n\n\nEntryClass = LTX2TextConnectors\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/bridges/__init__.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nfrom sglang.multimodal_gen.runtime.models.bridges.mova_dual_tower import (\n    DualTowerConditionalBridge,\n)\n\n__all__ = [\"DualTowerConditionalBridge\"]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/bridges/mova_dual_tower.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Copied and adapted from: mossVG/mova/diffusion/models/interactionv2.py\n\n\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\n\nfrom sglang.multimodal_gen.configs.models.bridges.mova_dual_tower import (\n    MOVADualTowerConfig,\n)\nfrom sglang.multimodal_gen.runtime.distributed import get_tp_world_size\nfrom sglang.multimodal_gen.runtime.layers.attention import USPAttention\nfrom sglang.multimodal_gen.runtime.layers.layernorm import (\n    RMSNorm,\n    tensor_parallel_rms_norm,\n)\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    ColumnParallelLinear,\n    ReplicatedLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.rotary_embedding import (\n    apply_flashinfer_rope_qk_inplace,\n)\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n@torch.no_grad()\ndef compute_rope_cos_sin(\n    position_ids: torch.Tensor,\n    head_dim: int,\n    base: float = 10000.0,\n    device: Optional[torch.device] = None,\n    dtype: Optional[torch.dtype] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute RoPE cos/sin embeddings for given position IDs.\n\n    This is a functional implementation that doesn't require storing buffers,\n    making it compatible with FSDP meta device initialization.\n\n    Args:\n        position_ids: Position IDs tensor [B, L] or [1, L]\n        head_dim: Dimension of each attention head\n        base: RoPE base frequency (default: 10000.0)\n        device: Target device\n        dtype: Output dtype\n\n    Returns:\n        (cos, sin): Each with shape [B, L, head_dim]\n    \"\"\"\n    device = device or 
position_ids.device\n    dtype = dtype or torch.float32\n\n    # Compute inverse frequencies\n    inv_freq = 1.0 / (\n        base\n        ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim)\n    )\n\n    # Expand for batched matmul: [B, head_dim/2, 1] @ [B, 1, L] -> [B, head_dim/2, L]\n    inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)\n    position_ids_expanded = position_ids[:, None, :].float()\n\n    # Compute frequencies: [B, head_dim/2, L] -> [B, L, head_dim/2]\n    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n\n    # Double the frequencies for full head_dim: [B, L, head_dim]\n    emb = torch.cat((freqs, freqs), dim=-1)\n\n    cos = emb.cos().to(dtype=dtype)\n    sin = emb.sin().to(dtype=dtype)\n\n    return cos, sin\n\n\nclass PerFrameAttentionPooling(nn.Module):\n    \"\"\"Per-frame multi-head attention pooling.\n\n    Takes a flattened input sequence [B, L, D] (L = T * H * W) and its grid size (T, H, W),\n    then performs single-query attention pooling over the H*W tokens of each time frame.\n    Output shape: [B, T, D].\n    \"\"\"\n\n    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):\n        super().__init__()\n        assert dim % num_heads == 0, \"dim must be divisible by num_heads\"\n        self.dim = dim\n        self.num_heads = num_heads\n\n        self.probe = nn.Parameter(torch.randn(1, 1, dim))\n        nn.init.normal_(self.probe, std=0.02)\n\n        self.attention = nn.MultiheadAttention(\n            embed_dim=dim, num_heads=num_heads, batch_first=True\n        )\n        self.layernorm = nn.LayerNorm(dim, eps=eps)\n\n    def forward(self, x: torch.Tensor, grid_size: Tuple[int, int, int]) -> torch.Tensor:\n        \"\"\"Forward pass.\n\n        Args:\n            x: Input tensor of shape [B, L, D], where L = T * H * W.\n            grid_size: Tuple of (T, H, W).\n\n        Returns:\n            Pooled tensor of shape [B, T, D].\n        
\"\"\"\n        B, L, D = x.shape\n        T, H, W = grid_size\n        assert (\n            D == self.dim\n        ), f\"Input dimension D={D} does not match module dim={self.dim}\"\n        assert L == T * H * W, f\"Flattened length L={L} does not match T*H*W={T*H*W}\"\n\n        S = H * W\n        x_bt_s_d = x.view(B, T, S, D).contiguous().view(B * T, S, D)  # [B*T, S, D]\n        probe = self.probe.expand(B * T, -1, -1)  # [B*T, 1, D]\n\n        pooled_bt_1_d = self.attention(probe, x_bt_s_d, x_bt_s_d, need_weights=False)[0]\n        pooled_bt_d = pooled_bt_1_d.squeeze(1)  # [B*T, D]\n\n        pooled = pooled_bt_d.view(B, T, D)\n        pooled = self.layernorm(pooled)\n        return pooled\n\n\nclass CrossModalInteractionController:\n    \"\"\"Strategy class to control dual-tower interaction.\n\n    Manages the interaction mapping between Visual DiT (e.g., 30 layers)\n    and Audio DiT (e.g., 30 layers).\n    \"\"\"\n\n    def __init__(self, visual_layers: int = 30, audio_layers: int = 30):\n        self.visual_layers = visual_layers\n        self.audio_layers = audio_layers\n        self.min_layers = min(visual_layers, audio_layers)\n\n    def get_interaction_layers(\n        self, strategy: str = \"shallow_focus\"\n    ) -> Dict[str, List[Tuple[int, int]]]:\n        \"\"\"Gets the mapping relationship of interaction layers.\"\"\"\n        if strategy == \"shallow_focus\":\n            num_interact = min(10, self.min_layers // 3)\n            interact_layers = list(range(0, num_interact))\n        elif strategy == \"distributed\":\n            step = 3\n            interact_layers = list(range(0, self.min_layers, step))\n        elif strategy == \"progressive\":\n            shallow = list(range(0, min(8, self.min_layers)))\n            if self.min_layers > 8:\n                deep = list(range(8, self.min_layers, 3))\n                interact_layers = shallow + deep\n            else:\n                interact_layers = shallow\n        elif strategy == 
\"custom\":\n            interact_layers = [0, 2, 4, 6, 8, 12, 16, 20]\n            interact_layers = [i for i in interact_layers if i < self.min_layers]\n        elif strategy == \"full\":\n            interact_layers = list(range(0, self.min_layers))\n        else:\n            raise ValueError(f\"Unknown interaction strategy: {strategy}\")\n\n        mapping = {\n            \"v2a\": [(i, i) for i in interact_layers],\n            \"a2v\": [(i, i) for i in interact_layers],\n        }\n        return mapping\n\n    def should_interact(\n        self, layer_idx: int, direction: str, interaction_mapping: Dict\n    ) -> bool:\n        \"\"\"Determines if the specified layer needs to interact.\"\"\"\n        if direction not in interaction_mapping:\n            return False\n        return any(src == layer_idx for src, _ in interaction_mapping[direction])\n\n\nclass ConditionalCrossAttention(nn.Module):\n    \"\"\"\n    Cross-modal attention for dual-tower bridge with Tensor Parallel support.\n\n    This module handles attention between video and audio hidden states,\n    which have different sequence lengths.\n    \"\"\"\n\n    def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6):\n        super().__init__()\n        self.q_dim = dim\n        self.kv_dim = kv_dim\n        self.num_heads = num_heads\n        self.head_dim = self.q_dim // num_heads\n\n        self.tp_size = get_tp_world_size()\n        if self.num_heads % self.tp_size != 0:\n            raise ValueError(\n                f\"num_heads ({self.num_heads}) must be divisible by tp_size ({self.tp_size}).\"\n            )\n        self.num_heads_per_rank = self.num_heads // self.tp_size\n\n        # TP strategy: shard Q/K/V over heads (column-parallel), then row-parallel output.\n        self.q = ColumnParallelLinear(dim, dim, bias=True, gather_output=False)\n        self.k = ColumnParallelLinear(kv_dim, dim, bias=True, gather_output=False)\n        self.v = 
ColumnParallelLinear(kv_dim, dim, bias=True, gather_output=False)\n        self.o = RowParallelLinear(dim, dim, bias=True, input_is_parallel=True)\n        self.norm_q = RMSNorm(dim, eps=eps)\n        self.norm_k = RMSNorm(dim, eps=eps)\n\n        self.attn = USPAttention(\n            num_heads=self.num_heads_per_rank,\n            head_size=self.head_dim,\n            causal=False,\n            softmax_scale=None,\n            # is_cross_attention=True,\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        y: torch.Tensor,\n        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    ):\n        ctx = y\n        q, _ = self.q(x)\n        k, _ = self.k(ctx)\n        v, _ = self.v(ctx)\n\n        # RMSNorm over sharded hidden dimension\n        if self.tp_size > 1:\n            q = tensor_parallel_rms_norm(q, self.norm_q)\n            k = tensor_parallel_rms_norm(k, self.norm_k)\n        else:\n            q = self.norm_q(q)\n            k = self.norm_k(k)\n\n        if x_freqs is not None:\n            x_cos, x_sin = x_freqs\n            q_view = rearrange(q, \"b l (h d) -> b l h d\", d=self.head_dim)\n            x_cos = x_cos.to(q_view.dtype).to(q_view.device).squeeze(0)\n            x_sin = x_sin.to(q_view.dtype).to(q_view.device).squeeze(0)\n            # FlashInfer expects cos_sin_cache with shape [seqlen, head_dim],\n            # where the first half is cos and the second half is sin, each with\n            # head_dim//2 elements. 
Since compute_rope_cos_sin duplicates the\n            # frequencies (cat((freqs, freqs))), we only take the first half.\n            half_dim = self.head_dim // 2\n            cos_sin_cache = torch.cat(\n                [\n                    x_cos[:, :half_dim].to(dtype=torch.float32).contiguous(),\n                    x_sin[:, :half_dim].to(dtype=torch.float32).contiguous(),\n                ],\n                dim=-1,\n            )\n            q_view, _ = apply_flashinfer_rope_qk_inplace(\n                q_view, q_view.clone(), cos_sin_cache, is_neox=True\n            )\n            q = rearrange(q_view, \"b l h d -> b l (h d)\")\n\n        if y_freqs is not None:\n            y_cos, y_sin = y_freqs\n            k_view = rearrange(k, \"b l (h d) -> b l h d\", d=self.head_dim)\n            y_cos = y_cos.to(k_view.dtype).to(k_view.device).squeeze(0)\n            y_sin = y_sin.to(k_view.dtype).to(k_view.device).squeeze(0)\n            # FlashInfer expects cos_sin_cache with shape [seqlen, head_dim],\n            # where the first half is cos and the second half is sin, each with\n            # head_dim//2 elements. 
Since compute_rope_cos_sin duplicates the\n            # frequencies (cat((freqs, freqs))), we only take the first half.\n            half_dim = self.head_dim // 2\n            cos_sin_cache = torch.cat(\n                [\n                    y_cos[:, :half_dim].to(dtype=torch.float32).contiguous(),\n                    y_sin[:, :half_dim].to(dtype=torch.float32).contiguous(),\n                ],\n                dim=-1,\n            )\n            k_view, _ = apply_flashinfer_rope_qk_inplace(\n                k_view, k_view.clone(), cos_sin_cache, is_neox=True\n            )\n            k = rearrange(k_view, \"b l h d -> b l (h d)\")\n\n        q = rearrange(q, \"b l (h d) -> b l h d\", h=self.num_heads_per_rank)\n        k = rearrange(k, \"b l (h d) -> b l h d\", h=self.num_heads_per_rank)\n        v = rearrange(v, \"b l (h d) -> b l h d\", h=self.num_heads_per_rank)\n\n        x = self.attn(q, k, v)\n        x = rearrange(x, \"b l h d -> b l (h d)\")\n        x, _ = self.o(x)\n        return x\n\n\nclass AdaLayerNorm(nn.Module):\n    \"\"\"\n    Norm layer modified to incorporate timestep embeddings.\n    \"\"\"\n\n    def __init__(\n        self,\n        embedding_dim: int,\n        num_embeddings: Optional[int] = None,\n        output_dim: Optional[int] = None,\n        norm_elementwise_affine: bool = False,\n        norm_eps: float = 1e-5,\n        chunk_dim: int = 0,\n    ):\n        super().__init__()\n\n        self.chunk_dim = chunk_dim\n        output_dim = output_dim or embedding_dim * 2\n\n        if num_embeddings is not None:\n            self.emb = nn.Embedding(num_embeddings, embedding_dim)\n        else:\n            self.emb = None\n\n        self.silu = nn.SiLU()\n        self.linear = ReplicatedLinear(embedding_dim, output_dim)\n        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        timestep: Optional[torch.Tensor] = None,\n        temb: 
Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if self.emb is not None:\n            temb = self.emb(timestep)\n\n        temb, _ = self.linear(self.silu(temb))\n\n        if self.chunk_dim == 2:\n            scale, shift = temb.chunk(2, dim=2)\n        elif self.chunk_dim == 1:\n            shift, scale = temb.chunk(2, dim=1)\n            shift = shift[:, None, :]\n            scale = scale[:, None, :]\n        else:\n            scale, shift = temb.chunk(2, dim=0)\n\n        x = self.norm(x) * (1 + scale) + shift\n        return x\n\n\nclass ConditionalCrossAttentionBlock(nn.Module):\n    \"\"\"A wrapper block for ConditionalCrossAttention that applies LayerNorm to the condition input y.\"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        kv_dim: int,\n        num_heads: int,\n        eps: float = 1e-6,\n        pooled_adaln: bool = False,\n    ):\n        super().__init__()\n        self.y_norm = nn.LayerNorm(kv_dim, eps=eps)\n        self.inner = ConditionalCrossAttention(\n            dim=dim, kv_dim=kv_dim, num_heads=num_heads, eps=eps\n        )\n        self.pooled_adaln = pooled_adaln\n        if pooled_adaln:\n            self.per_frame_pooling = PerFrameAttentionPooling(\n                kv_dim, num_heads=num_heads, eps=eps\n            )\n            self.adaln = AdaLayerNorm(kv_dim, output_dim=dim * 2, chunk_dim=2)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        y: torch.Tensor,\n        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        video_grid_size: Optional[Tuple[int, int, int]] = None,\n    ) -> torch.Tensor:\n        if self.pooled_adaln:\n            assert video_grid_size is not None, \"video_grid_size cannot be None\"\n            pooled_y = self.per_frame_pooling(y, video_grid_size)\n            if pooled_y.shape[1] != x.shape[1]:\n                pooled_y = F.interpolate(\n                    
pooled_y.permute(0, 2, 1),\n                    size=x.shape[1],\n                    mode=\"linear\",\n                    align_corners=False,\n                ).permute(0, 2, 1)\n            x = self.adaln(x, temb=pooled_y)\n        y = self.y_norm(y)\n        return self.inner(x=x, y=y, x_freqs=x_freqs, y_freqs=y_freqs)\n\n\nclass DualTowerConditionalBridge(\n    CachableDiT,\n    OffloadableDiTMixin,\n):\n    \"\"\"Dual-tower conditional bridge module v2 (SGLang optimized version).\n\n    Architecture:\n    1. Audio latents -> Audio DiT -> Audio hidden states [B, L_a, 1536].\n    2. Visual latents -> Visual DiT -> Visual hidden states [B, L_v, 5120].\n    3. Bidirectional cross-attention between the hidden states of the two DiTs.\n    \"\"\"\n\n    _fsdp_shard_conditions = MOVADualTowerConfig()._fsdp_shard_conditions\n    _compile_conditions = MOVADualTowerConfig()._compile_conditions\n    _supported_attention_backends = MOVADualTowerConfig()._supported_attention_backends\n    param_names_mapping = MOVADualTowerConfig().param_names_mapping\n    reverse_param_names_mapping = MOVADualTowerConfig().reverse_param_names_mapping\n    lora_param_names_mapping = MOVADualTowerConfig().lora_param_names_mapping\n\n    def __init__(\n        self,\n        config: MOVADualTowerConfig | None = None,\n        hf_config: dict[str, Any] | None = None,\n        # Fallback parameters for from_pretrained compatibility\n        visual_layers: int = 40,\n        audio_layers: int = 30,\n        visual_hidden_dim: int = 5120,\n        audio_hidden_dim: int = 1536,\n        audio_fps: float = 50.0,\n        head_dim: int = 128,\n        interaction_strategy: str = \"full\",\n        apply_cross_rope: bool = True,\n        apply_first_frame_bias_in_rope: bool = False,\n        trainable_condition_scale: bool = False,\n        pooled_adaln: bool = False,\n    ):\n        super().__init__(config=config, hf_config=hf_config)\n\n        # Use config if provided, 
otherwise use individual parameters\n        if config is not None:\n            visual_layers = config.visual_layers\n            audio_layers = config.audio_layers\n            visual_hidden_dim = config.visual_hidden_dim\n            audio_hidden_dim = config.audio_hidden_dim\n            audio_fps = config.audio_fps\n            head_dim = config.head_dim\n            interaction_strategy = config.interaction_strategy\n            apply_cross_rope = config.apply_cross_rope\n            apply_first_frame_bias_in_rope = config.apply_first_frame_bias_in_rope\n            trainable_condition_scale = config.trainable_condition_scale\n            pooled_adaln = config.pooled_adaln\n\n        self.visual_hidden_dim = visual_hidden_dim\n        self.audio_hidden_dim = audio_hidden_dim\n        self.audio_fps = audio_fps\n        self.head_dim = head_dim\n        self.apply_cross_rope = apply_cross_rope\n        self.apply_first_frame_bias_in_rope = apply_first_frame_bias_in_rope\n        self.trainable_condition_scale = trainable_condition_scale\n        self.pooled_adaln = pooled_adaln\n\n        if self.trainable_condition_scale:\n            self.condition_scale = nn.Parameter(\n                torch.tensor([1.0], dtype=torch.float32)\n            )\n        else:\n            self.condition_scale = 1.0\n\n        self.controller = CrossModalInteractionController(visual_layers, audio_layers)\n        self.interaction_mapping = self.controller.get_interaction_layers(\n            interaction_strategy\n        )\n\n        # Cross-modal attention modules - interaction at DiT hidden states level\n        self.audio_to_video_conditioners = nn.ModuleDict()\n        self.video_to_audio_conditioners = nn.ModuleDict()\n\n        self.rope_base = 10000.0  # Hardcoded RoPE base frequency, 
adapted from the original MOVA implementation.\n\n        # Audio DiT hidden states conditioning Video DiT\n        for v_layer, _ in self.interaction_mapping[\"a2v\"]:\n            self.audio_to_video_conditioners[str(v_layer)] = (\n                ConditionalCrossAttentionBlock(\n                    dim=visual_hidden_dim,\n                    kv_dim=audio_hidden_dim,\n                    num_heads=visual_hidden_dim // head_dim,\n                    pooled_adaln=False,\n                )\n            )\n\n        # Visual DiT hidden states conditioning Audio DiT\n        for a_layer, _ in self.interaction_mapping[\"v2a\"]:\n            self.video_to_audio_conditioners[str(a_layer)] = (\n                ConditionalCrossAttentionBlock(\n                    dim=audio_hidden_dim,\n                    kv_dim=visual_hidden_dim,\n                    num_heads=audio_hidden_dim // head_dim,\n                    pooled_adaln=self.pooled_adaln,\n                )\n            )\n\n        # Required attributes for CachableDiT/BaseDiT\n        self.hidden_size = visual_hidden_dim\n        self.num_attention_heads = visual_hidden_dim // head_dim\n        self.num_channels_latents = (\n            visual_hidden_dim  # Bridge doesn't output latents, but required by BaseDiT\n        )\n        self.layer_names = [\n            \"audio_to_video_conditioners\",\n            \"video_to_audio_conditioners\",\n        ]\n        self.__post_init__()\n\n    @torch.no_grad()\n    def build_aligned_freqs(\n        self,\n        video_fps: float,\n        grid_size: Tuple[int, int, int],\n        audio_steps: int,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"Generates aligned RoPE (cos, sin) based on video FPS, grid size, and audio length.\n\n        Uses functional RoPE computation to avoid FSDP meta device issues.\n\n        Args:\n          
  video_fps: FPS of the video.\n            grid_size: Tuple of (f_v, h, w).\n            audio_steps: Length of the audio sequence.\n            device: Target device.\n            dtype: Output dtype.\n\n        Returns:\n            A tuple of ((cos_v, sin_v), (cos_a, sin_a)).\n        \"\"\"\n        f_v, h, w = grid_size\n        L_v = f_v * h * w\n        L_a = int(audio_steps)\n\n        device = device or next(self.parameters()).device\n        dtype = dtype or torch.float32\n\n        # Audio positions: 0, 1, 2, ..., L_a-1\n        audio_pos = torch.arange(L_a, device=device, dtype=torch.float32).unsqueeze(0)\n\n        # Video positions: Align video frames to audio step units\n        if self.apply_first_frame_bias_in_rope:\n            video_effective_fps = float(video_fps) / 4.0\n            if f_v > 0:\n                t_starts = torch.zeros((f_v,), device=device, dtype=torch.float32)\n                if f_v > 1:\n                    t_starts[1:] = (1.0 / float(video_fps)) + torch.arange(\n                        f_v - 1, device=device, dtype=torch.float32\n                    ) * (1.0 / video_effective_fps)\n            else:\n                t_starts = torch.zeros((0,), device=device, dtype=torch.float32)\n            video_pos_per_frame = t_starts * float(self.audio_fps)\n        else:\n            scale = float(self.audio_fps) / float(video_fps / 4.0)\n            video_pos_per_frame = (\n                torch.arange(f_v, device=device, dtype=torch.float32) * scale\n            )\n\n        video_pos = video_pos_per_frame.repeat_interleave(h * w).unsqueeze(0)\n\n        # Use functional RoPE to compute cos/sin\n        cos_v, sin_v = compute_rope_cos_sin(\n            video_pos, self.head_dim, base=self.rope_base, device=device, dtype=dtype\n        )\n        cos_a, sin_a = compute_rope_cos_sin(\n            audio_pos, self.head_dim, base=self.rope_base, device=device, dtype=dtype\n        )\n\n        return (cos_v, sin_v), (cos_a, sin_a)\n\n    
def should_interact(self, layer_idx: int, direction: str) -> bool:\n        return self.controller.should_interact(\n            layer_idx, direction, self.interaction_mapping\n        )\n\n    def apply_conditional_control(\n        self,\n        layer_idx: int,\n        direction: str,\n        primary_hidden_states: torch.Tensor,\n        condition_hidden_states: torch.Tensor,\n        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        condition_scale: Optional[float] = None,\n        video_grid_size: Optional[Tuple[int, int, int]] = None,\n    ) -> torch.Tensor:\n        \"\"\"Applies conditional control at the DiT hidden states level.\"\"\"\n        if not self.controller.should_interact(\n            layer_idx, direction, self.interaction_mapping\n        ):\n            return primary_hidden_states\n\n        if direction == \"a2v\":\n            conditioner = self.audio_to_video_conditioners[str(layer_idx)]\n        elif direction == \"v2a\":\n            conditioner = self.video_to_audio_conditioners[str(layer_idx)]\n        else:\n            raise ValueError(f\"Invalid direction: {direction}\")\n\n        conditioned_features = conditioner(\n            x=primary_hidden_states,\n            y=condition_hidden_states,\n            x_freqs=x_freqs,\n            y_freqs=y_freqs,\n            video_grid_size=video_grid_size,\n        )\n\n        if self.trainable_condition_scale and condition_scale is not None:\n            logger.warning(\n                \"The current model has a trainable condition_scale, but condition_scale \"\n                \"was passed externally. 
Ignoring the trainable condition_scale and \"\n                \"using the external condition_scale=%s.\",\n                condition_scale,\n            )\n\n        scale = condition_scale if condition_scale is not None else self.condition_scale\n\n        primary_hidden_states = primary_hidden_states + conditioned_features * scale\n\n        return primary_hidden_states\n\n    def forward(\n        self,\n        layer_idx: int,\n        visual_hidden_states: torch.Tensor,\n        audio_hidden_states: torch.Tensor,\n        *,\n        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        a2v_condition_scale: Optional[float] = None,\n        v2a_condition_scale: Optional[float] = None,\n        condition_scale: Optional[float] = None,\n        video_grid_size: Optional[Tuple[int, int, int]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Performs bidirectional conditional control for both visual and audio towers.\"\"\"\n        visual_conditioned = self.apply_conditional_control(\n            layer_idx=layer_idx,\n            direction=\"a2v\",\n            primary_hidden_states=visual_hidden_states,\n            condition_hidden_states=audio_hidden_states,\n            x_freqs=x_freqs,\n            y_freqs=y_freqs,\n            condition_scale=(\n                a2v_condition_scale\n                if a2v_condition_scale is not None\n                else condition_scale\n            ),\n            video_grid_size=video_grid_size,\n        )\n\n        audio_conditioned = self.apply_conditional_control(\n            layer_idx=layer_idx,\n            direction=\"v2a\",\n            primary_hidden_states=audio_hidden_states,\n            condition_hidden_states=visual_hidden_states,\n            x_freqs=y_freqs,\n            y_freqs=x_freqs,\n            condition_scale=(\n                v2a_condition_scale\n                if v2a_condition_scale is not 
None\n                else condition_scale\n            ),\n            video_grid_size=video_grid_size,\n        )\n\n        return visual_conditioned, audio_conditioned\n\n\nEntryClass = DualTowerConditionalBridge\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/base.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom abc import ABC, abstractmethod\nfrom typing import Any\n\nimport torch\nfrom torch import nn\n\nfrom sglang.multimodal_gen.configs.models import DiTConfig\n\n# NOTE: TeaCacheContext and TeaCacheMixin have been moved to\n# sglang.multimodal_gen.runtime.cache.teacache\n# For backwards compatibility, re-export from the new location\nfrom sglang.multimodal_gen.runtime.cache.teacache import TeaCacheContext  # noqa: F401\nfrom sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\n\n\n# TODO\nclass BaseDiT(nn.Module, ABC):\n    _fsdp_shard_conditions: list = []\n    _compile_conditions: list = []\n    param_names_mapping: dict\n    reverse_param_names_mapping: dict\n    hidden_size: int\n    num_attention_heads: int\n    num_channels_latents: int\n    # always supports torch_sdpa\n    _supported_attention_backends: set[AttentionBackendEnum] = (\n        DiTConfig()._supported_attention_backends\n    )\n\n    def __init_subclass__(cls) -> None:\n        required_class_attrs = [\n            \"_fsdp_shard_conditions\",\n            \"param_names_mapping\",\n            \"_compile_conditions\",\n        ]\n        super().__init_subclass__()\n        for attr in required_class_attrs:\n            if not hasattr(cls, attr):\n                raise AttributeError(\n                    f\"Subclasses of BaseDiT must define '{attr}' class variable\"\n                )\n\n    def __init__(self, config: DiTConfig, hf_config: dict[str, Any], **kwargs) -> None:\n        super().__init__()\n        self.config = config\n        self.hf_config = hf_config\n        if not self.supported_attention_backends:\n            raise ValueError(\n                f\"Subclass {self.__class__.__name__} must define _supported_attention_backends\"\n            )\n\n    @abstractmethod\n   
 def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor | list[torch.Tensor],\n        timestep: torch.LongTensor,\n        encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None,\n        guidance=None,\n        **kwargs,\n    ) -> torch.Tensor:\n        pass\n\n    def __post_init__(self) -> None:\n        required_attrs = [\"hidden_size\", \"num_attention_heads\", \"num_channels_latents\"]\n        for attr in required_attrs:\n            if not hasattr(self, attr):\n                raise AttributeError(\n                    f\"Subclasses of BaseDiT must define '{attr}' instance variable\"\n                )\n\n    @property\n    def supported_attention_backends(self) -> set[AttentionBackendEnum]:\n        return self._supported_attention_backends\n\n    @property\n    def device(self) -> torch.device:\n        \"\"\"Get the device of the model.\"\"\"\n        return next(self.parameters()).device\n\n\nclass CachableDiT(TeaCacheMixin, BaseDiT):\n    \"\"\"\n    An intermediate base class that adds TeaCache optimization functionality to DiT models.\n\n    Inherits TeaCacheMixin for cache logic and BaseDiT for core DiT functionality.\n    \"\"\"\n\n    # These are required class attributes that should be overridden by concrete implementations\n    _fsdp_shard_conditions = []\n    param_names_mapping = {}\n    reverse_param_names_mapping = {}\n    lora_param_names_mapping: dict = {}\n    # Ensure these instance attributes are properly defined in subclasses\n    hidden_size: int\n    num_attention_heads: int\n    num_channels_latents: int\n    # always supports torch_sdpa\n    _supported_attention_backends: set[AttentionBackendEnum] = (\n        DiTConfig()._supported_attention_backends\n    )\n\n    def __init__(self, config: DiTConfig, **kwargs) -> None:\n        super().__init__(config, **kwargs)\n        self._init_teacache_state()\n\n    @classmethod\n    def 
get_nunchaku_quant_rules(cls) -> dict[str, dict[str, Any]]:\n        \"\"\"\n        Get quantization rules for Nunchaku quantization.\n\n        Returns a dict mapping layer name patterns to quantization configs:\n        {\n            \"skip\": [list of patterns to skip quantization],\n            \"svdq_w4a4\": [list of patterns for SVDQ W4A4],\n            \"awq_w4a16\": [list of patterns for AWQ W4A16],\n        }\n        \"\"\"\n        return {}\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nimport math\nfrom typing import Any\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn.attention.flex_attention import (\n    BlockMask,\n    create_block_mask,\n    flex_attention,\n)\n\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\n\n# wan 1.3B model has a weird channel / head configurations and require max-autotune to work with flexattention\n# see https://github.com/pytorch/pytorch/issues/133254\n# change to default for other models\nflex_attention = torch.compile(\n    flex_attention, dynamic=False, mode=\"max-autotune-no-cudagraphs\"\n)\nimport torch.distributed as dist\n\nfrom sglang.multimodal_gen.configs.models.dits import WanVideoConfig\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size\nfrom sglang.multimodal_gen.runtime.layers.attention import LocalAttention\nfrom sglang.multimodal_gen.runtime.layers.elementwise import MulAdd\nfrom sglang.multimodal_gen.runtime.layers.layernorm import (\n    FP32LayerNorm,\n    LayerNormScaleShift,\n    RMSNorm,\n    ScaleResidualLayerNormScaleShift,\n)\nfrom sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear\nfrom sglang.multimodal_gen.runtime.layers.mlp import MLP\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n)\nfrom sglang.multimodal_gen.runtime.layers.rotary_embedding import (\n    _apply_rotary_emb,\n    get_rotary_pos_embed,\n)\nfrom sglang.multimodal_gen.runtime.layers.visual_embedding import PatchEmbed\nfrom sglang.multimodal_gen.runtime.models.dits.base import BaseDiT\nfrom sglang.multimodal_gen.runtime.models.dits.wanvideo import (\n    WanT2VCrossAttention,\n    WanTimeTextImageEmbedding,\n)\nfrom sglang.multimodal_gen.runtime.platforms import (\n    AttentionBackendEnum,\n    current_platform,\n)\nfrom 
sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass CausalWanSelfAttention(nn.Module):\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        local_attn_size: int = -1,\n        sink_size: int = 0,\n        qk_norm=True,\n        eps=1e-6,\n        parallel_attention=False,\n    ) -> None:\n        assert dim % num_heads == 0\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.local_attn_size = local_attn_size\n        self.sink_size = sink_size\n        self.qk_norm = qk_norm\n        self.eps = eps\n        self.parallel_attention = parallel_attention\n        self.max_attention_size = (\n            32760 if local_attn_size == -1 else local_attn_size * 1560\n        )\n\n        # Scaled dot product attention\n        self.attn = LocalAttention(\n            num_heads=num_heads,\n            head_size=self.head_dim,\n            dropout_rate=0,\n            softmax_scale=None,\n            causal=False,\n            supported_attention_backends=(\n                AttentionBackendEnum.FA,\n                AttentionBackendEnum.AITER,\n                AttentionBackendEnum.TORCH_SDPA,\n            ),\n        )\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        freqs_cis: tuple[torch.Tensor, torch.Tensor],\n        block_mask: BlockMask,\n        kv_cache: dict | None = None,\n        current_start: int = 0,\n        cache_start: int | None = None,\n    ):\n        r\"\"\"\n        Args:\n            q(Tensor): Query, shape [B, L, num_heads, C / num_heads]\n            k(Tensor): Key, shape [B, L, num_heads, C / num_heads]\n            v(Tensor): Value, shape [B, L, num_heads, C / num_heads]\n            freqs_cis(tuple[Tensor, Tensor]): RoPE (cos, sin) applied to q and k\n            block_mask(BlockMask): Flex-attention block mask, used when kv_cache is None\n            kv_cache(dict, optional): Key/value cache; if None, attention runs over the full padded sequence\n            current_start(int): Start position of the current tokens\n            cache_start(int, optional): Cache start position; defaults to current_start\n        \"\"\"\n        if cache_start is None:\n            
cache_start = current_start\n\n        cos, sin = freqs_cis\n        roped_query = _apply_rotary_emb(q, cos, sin, is_neox_style=False).type_as(v)\n        roped_key = _apply_rotary_emb(k, cos, sin, is_neox_style=False).type_as(v)\n\n        if kv_cache is None:\n            # Padding for flex attention\n            padded_length = math.ceil(q.shape[1] / 128) * 128 - q.shape[1]\n            padded_roped_query = torch.cat(\n                [\n                    roped_query,\n                    torch.zeros(\n                        [q.shape[0], padded_length, q.shape[2], q.shape[3]],\n                        device=q.device,\n                        dtype=v.dtype,\n                    ),\n                ],\n                dim=1,\n            )\n\n            padded_roped_key = torch.cat(\n                [\n                    roped_key,\n                    torch.zeros(\n                        [k.shape[0], padded_length, k.shape[2], k.shape[3]],\n                        device=k.device,\n                        dtype=v.dtype,\n                    ),\n                ],\n                dim=1,\n            )\n\n            padded_v = torch.cat(\n                [\n                    v,\n                    torch.zeros(\n                        [v.shape[0], padded_length, v.shape[2], v.shape[3]],\n                        device=v.device,\n                        dtype=v.dtype,\n                    ),\n                ],\n                dim=1,\n            )\n\n            # Drop the padded positions by slicing to the original length;\n            # a `:-padded_length` slice would be empty when padded_length == 0\n            x = flex_attention(\n                query=padded_roped_query.transpose(2, 1),\n                key=padded_roped_key.transpose(2, 1),\n                value=padded_v.transpose(2, 1),\n                block_mask=block_mask,\n            )[:, :, : q.shape[1]].transpose(2, 1)\n        else:\n            frame_seqlen = q.shape[1]\n            current_end = current_start + roped_query.shape[1]\n            sink_tokens = self.sink_size * frame_seqlen\n            # If we are using local attention 
and the current KV cache size is larger than the local attention size, we need to truncate the KV cache\n            kv_cache_size = kv_cache[\"k\"].shape[1]\n            num_new_tokens = roped_query.shape[1]\n            if (\n                self.local_attn_size != -1\n                and (current_end > kv_cache[\"global_end_index\"].item())\n                and (\n                    num_new_tokens + kv_cache[\"local_end_index\"].item() > kv_cache_size\n                )\n            ):\n                # Calculate the number of new tokens added in this step\n                # Shift existing cache content left to discard oldest tokens\n                # Clone the source slice to avoid overlapping memory error\n                num_evicted_tokens = (\n                    num_new_tokens + kv_cache[\"local_end_index\"].item() - kv_cache_size\n                )\n                num_rolled_tokens = (\n                    kv_cache[\"local_end_index\"].item()\n                    - num_evicted_tokens\n                    - sink_tokens\n                )\n                kv_cache[\"k\"][\n                    :, sink_tokens : sink_tokens + num_rolled_tokens\n                ] = kv_cache[\"k\"][\n                    :,\n                    sink_tokens\n                    + num_evicted_tokens : sink_tokens\n                    + num_evicted_tokens\n                    + num_rolled_tokens,\n                ].clone()\n                kv_cache[\"v\"][\n                    :, sink_tokens : sink_tokens + num_rolled_tokens\n                ] = kv_cache[\"v\"][\n                    :,\n                    sink_tokens\n                    + num_evicted_tokens : sink_tokens\n                    + num_evicted_tokens\n                    + num_rolled_tokens,\n                ].clone()\n                # Insert the new keys/values at the end\n                local_end_index = (\n                    kv_cache[\"local_end_index\"].item()\n                    + current_end\n               
     - kv_cache[\"global_end_index\"].item()\n                    - num_evicted_tokens\n                )\n                local_start_index = local_end_index - num_new_tokens\n                kv_cache[\"k\"][:, local_start_index:local_end_index] = roped_key\n                kv_cache[\"v\"][:, local_start_index:local_end_index] = v\n            else:\n                # Assign new keys/values directly up to current_end\n                local_end_index = (\n                    kv_cache[\"local_end_index\"].item()\n                    + current_end\n                    - kv_cache[\"global_end_index\"].item()\n                )\n                local_start_index = local_end_index - num_new_tokens\n                kv_cache[\"k\"] = kv_cache[\"k\"].detach()\n                kv_cache[\"v\"] = kv_cache[\"v\"].detach()\n                # logger.info(\"kv_cache['k'] is in comp graph: %s\", kv_cache[\"k\"].requires_grad or kv_cache[\"k\"].grad_fn is not None)\n                kv_cache[\"k\"][:, local_start_index:local_end_index] = roped_key\n                kv_cache[\"v\"][:, local_start_index:local_end_index] = v\n            x = self.attn(\n                roped_query,\n                kv_cache[\"k\"][\n                    :,\n                    max(0, local_end_index - self.max_attention_size) : local_end_index,\n                ],\n                kv_cache[\"v\"][\n                    :,\n                    max(0, local_end_index - self.max_attention_size) : local_end_index,\n                ],\n            )\n            kv_cache[\"global_end_index\"].fill_(current_end)\n            kv_cache[\"local_end_index\"].fill_(local_end_index)\n\n        return x\n\n\nclass CausalWanTransformerBlock(nn.Module):\n\n    def __init__(\n        self,\n        dim: int,\n        ffn_dim: int,\n        num_heads: int,\n        local_attn_size: int = -1,\n        sink_size: int = 0,\n        qk_norm: str = \"rms_norm_across_heads\",\n        cross_attn_norm: bool = False,\n        
eps: float = 1e-6,\n        added_kv_proj_dim: int | None = None,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        prefix: str = \"\",\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__()\n\n        # 1. Self-attention\n        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)\n        self.to_q = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config)\n        self.to_k = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config)\n        self.to_v = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config)\n\n        self.to_out = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config)\n        self.attn1 = CausalWanSelfAttention(\n            dim,\n            num_heads,\n            local_attn_size=local_attn_size,\n            sink_size=sink_size,\n            qk_norm=qk_norm,\n            eps=eps,\n        )\n        self.hidden_dim = dim\n        self.num_attention_heads = num_heads\n        self.local_attn_size = local_attn_size\n        dim_head = dim // num_heads\n        if qk_norm == \"rms_norm\":\n            self.norm_q = RMSNorm(dim_head, eps=eps)\n            self.norm_k = RMSNorm(dim_head, eps=eps)\n        elif qk_norm == \"rms_norm_across_heads\":\n            # LTX applies qk norm across all heads\n            self.norm_q = RMSNorm(dim, eps=eps)\n            self.norm_k = RMSNorm(dim, eps=eps)\n        else:\n            raise ValueError(f\"QK Norm type {qk_norm!r} not supported\")\n        assert cross_attn_norm is True\n        self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift(\n            dim, eps=eps, elementwise_affine=True, dtype=torch.float32\n        )\n\n        # 2. 
Cross-attention\n        # Only T2V for now\n        cross_attn_backends = {\n            b for b in supported_attention_backends if not b.is_sparse\n        }\n        self.attn2 = WanT2VCrossAttention(\n            dim,\n            num_heads,\n            qk_norm=qk_norm,\n            eps=eps,\n            supported_attention_backends=cross_attn_backends,\n            quant_config=quant_config,\n        )\n        self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift(\n            dim, eps=eps, elementwise_affine=False, dtype=torch.float32\n        )\n\n        # 3. Feed-forward\n        self.ffn = MLP(\n            dim, ffn_dim, act_type=\"gelu_pytorch_tanh\", quant_config=quant_config\n        )\n        self.mlp_residual = MulAdd()\n\n        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        temb: torch.Tensor,\n        freqs_cis: tuple[torch.Tensor, torch.Tensor],\n        block_mask: BlockMask,\n        kv_cache: dict | None = None,\n        crossattn_cache: dict | None = None,\n        current_start: int = 0,\n        cache_start: int | None = None,\n    ) -> torch.Tensor:\n        # hidden_states.shape: [batch_size, seq_length, inner_dim]\n        # temb.shape: [batch_size, num_frames, 6, inner_dim]\n        if hidden_states.dim() == 4:\n            hidden_states = hidden_states.squeeze(1)\n        num_frames = temb.shape[1]\n        frame_seqlen = hidden_states.shape[1] // num_frames\n        bs, seq_length, _ = hidden_states.shape\n        orig_dtype = hidden_states.dtype\n        # assert orig_dtype != torch.float32\n        e = self.scale_shift_table + temb.float()\n        # e.shape: [batch_size, num_frames, 6, inner_dim]\n        assert e.shape == (bs, num_frames, 6, self.hidden_dim)\n        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(\n            6, 
dim=2\n        )\n        # *_msa.shape: [batch_size, num_frames, 1, inner_dim]\n        assert shift_msa.dtype == torch.float32\n\n        # 1. Self-attention\n        norm_hidden_states = (\n            (\n                self.norm1(hidden_states.float()).unflatten(\n                    dim=1, sizes=(num_frames, frame_seqlen)\n                )\n                * (1 + scale_msa)\n                + shift_msa\n            )\n            .flatten(1, 2)\n            .to(orig_dtype)\n        )\n        query, _ = self.to_q(norm_hidden_states)\n        key, _ = self.to_k(norm_hidden_states)\n        value, _ = self.to_v(norm_hidden_states)\n\n        if self.norm_q is not None:\n            query = self.norm_q(query)\n        if self.norm_k is not None:\n            key = self.norm_k(key)\n\n        query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))\n        key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))\n        value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))\n\n        attn_output = self.attn1(\n            query,\n            key,\n            value,\n            freqs_cis,\n            block_mask,\n            kv_cache,\n            current_start,\n            cache_start,\n        )\n        attn_output = attn_output.flatten(2)\n        attn_output, _ = self.to_out(attn_output)\n        attn_output = attn_output.squeeze(1)\n\n        null_shift = null_scale = torch.zeros(\n            (1,), device=hidden_states.device, dtype=hidden_states.dtype\n        )\n        norm_hidden_states, hidden_states = self.self_attn_residual_norm(\n            hidden_states, attn_output, gate_msa, null_shift, null_scale\n        )\n        norm_hidden_states, hidden_states = norm_hidden_states.to(\n            orig_dtype\n        ), hidden_states.to(orig_dtype)\n\n        # 2. 
Cross-attention\n        attn_output = self.attn2(\n            norm_hidden_states,\n            context=encoder_hidden_states,\n            context_lens=None,\n            crossattn_cache=crossattn_cache,\n        )\n        norm_hidden_states, hidden_states = self.cross_attn_residual_norm(\n            hidden_states, attn_output, 1, c_shift_msa, c_scale_msa\n        )\n        norm_hidden_states, hidden_states = norm_hidden_states.to(\n            orig_dtype\n        ), hidden_states.to(orig_dtype)\n\n        # 3. Feed-forward\n        ff_output = self.ffn(norm_hidden_states)\n        hidden_states = self.mlp_residual(ff_output, c_gate_msa, hidden_states)\n        hidden_states = hidden_states.to(orig_dtype)\n\n        return hidden_states\n\n\nclass CausalWanTransformer3DModel(BaseDiT, OffloadableDiTMixin):\n    _fsdp_shard_conditions = WanVideoConfig()._fsdp_shard_conditions\n    _compile_conditions = WanVideoConfig()._compile_conditions\n    _supported_attention_backends = WanVideoConfig()._supported_attention_backends\n    param_names_mapping = WanVideoConfig().param_names_mapping\n    reverse_param_names_mapping = WanVideoConfig().reverse_param_names_mapping\n    lora_param_names_mapping = WanVideoConfig().lora_param_names_mapping\n\n    def __init__(\n        self,\n        config: WanVideoConfig,\n        hf_config: dict[str, Any],\n        quant_config: QuantizationConfig | None = None,\n    ) -> None:\n        super().__init__(config=config, hf_config=hf_config)\n\n        inner_dim = config.num_attention_heads * config.attention_head_dim\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.attention_head_dim = config.attention_head_dim\n        self.in_channels = config.in_channels\n        self.out_channels = config.out_channels\n        self.num_channels_latents = config.num_channels_latents\n        self.patch_size = config.patch_size\n        self.text_len = config.text_len\n     
   self.local_attn_size = config.local_attn_size\n\n        # 1. Patch & position embedding\n        self.patch_embedding = PatchEmbed(\n            in_chans=config.in_channels,\n            embed_dim=inner_dim,\n            patch_size=config.patch_size,\n            flatten=False,\n        )\n\n        # 2. Condition embeddings\n        self.condition_embedder = WanTimeTextImageEmbedding(\n            dim=inner_dim,\n            time_freq_dim=config.freq_dim,\n            text_embed_dim=config.text_dim,\n            image_embed_dim=config.image_dim,\n        )\n\n        # 3. Transformer blocks\n        self.blocks = nn.ModuleList(\n            [\n                CausalWanTransformerBlock(\n                    inner_dim,\n                    config.ffn_dim,\n                    config.num_attention_heads,\n                    config.local_attn_size,\n                    config.sink_size,\n                    config.qk_norm,\n                    config.cross_attn_norm,\n                    config.eps,\n                    config.added_kv_proj_dim,\n                    self._supported_attention_backends,\n                    prefix=f\"{config.prefix}.blocks.{i}\",\n                    quant_config=quant_config,\n                )\n                for i in range(config.num_layers)\n            ]\n        )\n\n        # 4. 
Output norm & projection\n        self.norm_out = LayerNormScaleShift(\n            inner_dim,\n            eps=config.eps,\n            elementwise_affine=False,\n            dtype=torch.float32,\n        )\n        self.proj_out = nn.Linear(\n            inner_dim, config.out_channels * math.prod(config.patch_size)\n        )\n        self.scale_shift_table = nn.Parameter(\n            torch.randn(1, 2, inner_dim) / inner_dim**0.5\n        )\n\n        self.gradient_checkpointing = False\n\n        # Causal-specific\n        self.block_mask = None\n        self.num_frame_per_block = config.arch_config.num_frames_per_block\n        assert self.num_frame_per_block <= 3\n        self.independent_first_frame = False\n\n        self.__post_init__()\n\n        self.layer_names = [\n            \"blocks\",\n        ]\n\n    @staticmethod\n    def _prepare_blockwise_causal_attn_mask(\n        device: torch.device | str,\n        num_frames: int = 21,\n        frame_seqlen: int = 1560,\n        num_frame_per_block=1,\n        local_attn_size=-1,\n    ) -> BlockMask:\n        \"\"\"\n        we will divide the token sequence into the following format\n        [1 latent frame] [1 latent frame] ... 
[1 latent frame]\n        We use flexattention to construct the attention mask\n        \"\"\"\n        total_length = num_frames * frame_seqlen\n\n        # we do right padding to get to a multiple of 128\n        padded_length = math.ceil(total_length / 128) * 128 - total_length\n\n        ends = torch.zeros(\n            total_length + padded_length, device=device, dtype=torch.long\n        )\n\n        # Block-wise causal mask will attend to all elements that are before the end of the current chunk\n        frame_indices = torch.arange(\n            start=0,\n            end=total_length,\n            step=frame_seqlen * num_frame_per_block,\n            device=device,\n        )\n\n        for tmp in frame_indices:\n            ends[tmp : tmp + frame_seqlen * num_frame_per_block] = (\n                tmp + frame_seqlen * num_frame_per_block\n            )\n\n        def attention_mask(b, h, q_idx, kv_idx):\n            if local_attn_size == -1:\n                return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)\n            else:\n                return (\n                    (kv_idx < ends[q_idx])\n                    & (kv_idx >= (ends[q_idx] - local_attn_size * frame_seqlen))\n                ) | (q_idx == kv_idx)\n            # return ((kv_idx < total_length) & (q_idx < total_length))  | (q_idx == kv_idx) # bidirectional mask\n\n        block_mask = create_block_mask(\n            attention_mask,\n            B=None,\n            H=None,\n            Q_LEN=total_length + padded_length,\n            KV_LEN=total_length + padded_length,\n            _compile=False,\n            device=device,\n        )\n\n        if not dist.is_initialized() or dist.get_rank() == 0:\n            print(\n                f\" cache a block wise causal mask with block size of {num_frame_per_block} frames\"\n            )\n            print(block_mask)\n\n        # import imageio\n        # import numpy as np\n        # from torch.nn.attention.flex_attention import 
create_mask\n\n        # mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length +\n        #                    padded_length, KV_LEN=total_length + padded_length, device=device)\n        # import cv2\n        # mask = cv2.resize(mask[0, 0].cpu().float().numpy(), (1024, 1024))\n        # imageio.imwrite(\"mask_%d.jpg\" % (0), np.uint8(255. * mask))\n\n        return block_mask\n\n    def _forward_inference(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor | list[torch.Tensor],\n        timestep: torch.LongTensor,\n        encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None,\n        kv_cache: dict = None,\n        crossattn_cache: dict = None,\n        current_start: int = 0,\n        cache_start: int = 0,\n        start_frame: int = 0,\n        **kwargs,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Run the diffusion model with kv caching.\n        See Algorithm 2 of CausVid paper https://arxiv.org/abs/2412.07772 for details.\n        This function will be run for num_frame times.\n        Process the latent frames one by one (1560 tokens each)\n        \"\"\"\n\n        orig_dtype = hidden_states.dtype\n        if not isinstance(encoder_hidden_states, torch.Tensor):\n            encoder_hidden_states = encoder_hidden_states[0]\n        if (\n            isinstance(encoder_hidden_states_image, list)\n            and len(encoder_hidden_states_image) > 0\n        ):\n            encoder_hidden_states_image = encoder_hidden_states_image[0]\n        else:\n            encoder_hidden_states_image = None\n\n        batch_size, num_channels, num_frames, height, width = hidden_states.shape\n        p_t, p_h, p_w = self.patch_size\n        post_patch_num_frames = num_frames // p_t\n        post_patch_height = height // p_h\n        post_patch_width = width // p_w\n\n        # Get rotary embeddings\n        d = self.hidden_size // self.num_attention_heads\n        rope_dim_list = [d 
- 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]\n        freqs_cos, freqs_sin = get_rotary_pos_embed(\n            (\n                post_patch_num_frames * get_sp_world_size(),\n                post_patch_height,\n                post_patch_width,\n            ),\n            self.hidden_size,\n            self.num_attention_heads,\n            rope_dim_list,\n            dtype=(\n                torch.float32\n                if current_platform.is_mps() or current_platform.is_musa()\n                else torch.float64\n            ),\n            rope_theta=10000,\n            start_frame=start_frame,  # Assume that start_frame is 0 when kv_cache is None\n        )\n        freqs_cos = freqs_cos.to(hidden_states.device)\n        freqs_sin = freqs_sin.to(hidden_states.device)\n        freqs_cis = (\n            (freqs_cos.float(), freqs_sin.float()) if freqs_cos is not None else None\n        )\n\n        hidden_states = self.patch_embedding(hidden_states)\n        hidden_states = hidden_states.flatten(2).transpose(1, 2)\n\n        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = (\n            self.condition_embedder(\n                timestep.flatten(), encoder_hidden_states, encoder_hidden_states_image\n            )\n        )\n        timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)).unflatten(\n            dim=0, sizes=timestep.shape\n        )\n\n        if encoder_hidden_states_image is not None:\n            encoder_hidden_states = torch.concat(\n                [encoder_hidden_states_image, encoder_hidden_states], dim=1\n            )\n\n        encoder_hidden_states = (\n            encoder_hidden_states.to(orig_dtype)\n            if current_platform.is_mps()\n            else encoder_hidden_states\n        )  # cast to orig_dtype for MPS\n\n        assert encoder_hidden_states.dtype == orig_dtype\n\n        # 4. 
Transformer blocks\n        for block_index, block in enumerate(self.blocks):\n            if torch.is_grad_enabled() and self.gradient_checkpointing:\n                causal_kwargs = {\n                    \"kv_cache\": kv_cache[block_index],\n                    \"current_start\": current_start,\n                    \"cache_start\": cache_start,\n                    \"block_mask\": self.block_mask,\n                }\n                hidden_states = self._gradient_checkpointing_func(\n                    block,\n                    hidden_states,\n                    encoder_hidden_states,\n                    timestep_proj,\n                    freqs_cis,\n                    **causal_kwargs,\n                )\n            else:\n                causal_kwargs = {\n                    \"kv_cache\": kv_cache[block_index],\n                    \"crossattn_cache\": crossattn_cache[block_index],\n                    \"current_start\": current_start,\n                    \"cache_start\": cache_start,\n                    \"block_mask\": self.block_mask,\n                }\n                hidden_states = block(\n                    hidden_states,\n                    encoder_hidden_states,\n                    timestep_proj,\n                    freqs_cis,\n                    **causal_kwargs,\n                )\n\n        # 5. 
Output norm, projection & unpatchify\n        temb = temb.unflatten(dim=0, sizes=timestep.shape).unsqueeze(2)\n        shift, scale = (self.scale_shift_table.unsqueeze(1) + temb).chunk(2, dim=2)\n        hidden_states = self.norm_out(hidden_states, shift, scale)\n        hidden_states = self.proj_out(hidden_states)\n\n        hidden_states = hidden_states.reshape(\n            batch_size,\n            post_patch_num_frames,\n            post_patch_height,\n            post_patch_width,\n            p_t,\n            p_h,\n            p_w,\n            -1,\n        )\n        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)\n        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)\n\n        return output\n\n    def _forward_train(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor | list[torch.Tensor],\n        timestep: torch.LongTensor,\n        encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None,\n        start_frame: int = 0,\n        **kwargs,\n    ) -> torch.Tensor:\n\n        orig_dtype = hidden_states.dtype\n        if not isinstance(encoder_hidden_states, torch.Tensor):\n            encoder_hidden_states = encoder_hidden_states[0]\n        if (\n            isinstance(encoder_hidden_states_image, list)\n            and len(encoder_hidden_states_image) > 0\n        ):\n            encoder_hidden_states_image = encoder_hidden_states_image[0]\n        else:\n            encoder_hidden_states_image = None\n\n        batch_size, num_channels, num_frames, height, width = hidden_states.shape\n        p_t, p_h, p_w = self.patch_size\n        post_patch_num_frames = num_frames // p_t\n        post_patch_height = height // p_h\n        post_patch_width = width // p_w\n\n        # Get rotary embeddings\n        d = self.hidden_size // self.num_attention_heads\n        rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]\n        freqs_cos, freqs_sin = 
get_rotary_pos_embed(\n            (\n                post_patch_num_frames * get_sp_world_size(),\n                post_patch_height,\n                post_patch_width,\n            ),\n            self.hidden_size,\n            self.num_attention_heads,\n            rope_dim_list,\n            dtype=(\n                torch.float32\n                if current_platform.is_mps() or current_platform.is_musa()\n                else torch.float64\n            ),\n            rope_theta=10000,\n            start_frame=start_frame,\n        )\n        freqs_cos = freqs_cos.to(hidden_states.device)\n        freqs_sin = freqs_sin.to(hidden_states.device)\n        freqs_cis = (\n            (freqs_cos.float(), freqs_sin.float()) if freqs_cos is not None else None\n        )\n\n        # Construct blockwise causal attn mask\n        if self.block_mask is None:\n            self.block_mask = self._prepare_blockwise_causal_attn_mask(\n                device=hidden_states.device,\n                num_frames=num_frames,\n                frame_seqlen=post_patch_height * post_patch_width,\n                num_frame_per_block=self.num_frame_per_block,\n                local_attn_size=self.local_attn_size,\n            )\n\n        hidden_states = self.patch_embedding(hidden_states)\n        hidden_states = hidden_states.flatten(2).transpose(1, 2)\n\n        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = (\n            self.condition_embedder(\n                timestep.flatten(), encoder_hidden_states, encoder_hidden_states_image\n            )\n        )\n        timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)).unflatten(\n            dim=0, sizes=timestep.shape\n        )\n\n        if encoder_hidden_states_image is not None:\n            encoder_hidden_states = torch.concat(\n                [encoder_hidden_states_image, encoder_hidden_states], dim=1\n            )\n\n        encoder_hidden_states = (\n            
encoder_hidden_states.to(orig_dtype)\n            if current_platform.is_mps()\n            else encoder_hidden_states\n        )  # cast to orig_dtype for MPS\n\n        assert encoder_hidden_states.dtype == orig_dtype\n\n        # 4. Transformer blocks\n        if torch.is_grad_enabled() and self.gradient_checkpointing:\n            for block in self.blocks:\n                hidden_states = self._gradient_checkpointing_func(\n                    block,\n                    hidden_states,\n                    encoder_hidden_states,\n                    timestep_proj,\n                    freqs_cis,\n                    block_mask=self.block_mask,\n                )\n        else:\n            for block in self.blocks:\n                hidden_states = block(\n                    hidden_states,\n                    encoder_hidden_states,\n                    timestep_proj,\n                    freqs_cis,\n                    block_mask=self.block_mask,\n                )\n\n        # 5. Output norm, projection & unpatchify\n        temb = temb.unflatten(dim=0, sizes=timestep.shape).unsqueeze(2)\n        shift, scale = (self.scale_shift_table.unsqueeze(1) + temb).chunk(2, dim=2)\n        hidden_states = self.norm_out(hidden_states, shift, scale)\n        hidden_states = self.proj_out(hidden_states)\n\n        hidden_states = hidden_states.reshape(\n            batch_size,\n            post_patch_num_frames,\n            post_patch_height,\n            post_patch_width,\n            p_t,\n            p_h,\n            p_w,\n            -1,\n        )\n        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)\n        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)\n\n        return output\n\n    def forward(self, *args, **kwargs):\n        if kwargs.get(\"kv_cache\") is not None:\n            return self._forward_inference(*args, **kwargs)\n        else:\n            return self._forward_train(*args, **kwargs)\n\n\nEntryClass = 
CausalWanTransformer3DModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/flux.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nfrom diffusers.models.attention import AttentionModuleMixin\nfrom diffusers.models.modeling_outputs import Transformer2DModelOutput\nfrom diffusers.models.normalization import (\n    AdaLayerNormContinuous,\n    AdaLayerNormZero,\n    AdaLayerNormZeroSingle,\n)\nfrom torch.nn import LayerNorm as LayerNorm\n\nfrom sglang.multimodal_gen.configs.models.dits.flux import FluxConfig\nfrom sglang.multimodal_gen.runtime.layers.attention import USPAttention\nfrom sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm, apply_qk_norm\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    ColumnParallelLinear,\n    MergedColumnParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.mlp import FeedForward\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import (\n    NunchakuConfig,\n    is_nunchaku_available,\n)\nfrom sglang.multimodal_gen.runtime.layers.rotary_embedding import (\n    NDRotaryEmbedding,\n    apply_flashinfer_rope_qk_inplace,\n)\nfrom 
sglang.multimodal_gen.runtime.layers.visual_embedding import (\n    CombinedTimestepGuidanceTextProjEmbeddings,\n    CombinedTimestepTextProjEmbeddings,\n)\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)  # pylint: disable=invalid-name\n\ntry:\n    from nunchaku.models.attention import NunchakuFeedForward  # type: ignore[import]\n    from nunchaku.models.normalization import (  # type: ignore[import]\n        NunchakuAdaLayerNormZero,\n        NunchakuAdaLayerNormZeroSingle,\n    )\n    from nunchaku.ops.gemm import (\n        svdq_gemm_w4a4_cuda as _svdq_gemm_w4a4,  # type: ignore[import]\n    )\n    from nunchaku.ops.quantize import (\n        svdq_quantize_w4a4_act_fuse_lora_cuda as _svdq_quantize_w4a4,  # type: ignore[import]\n    )\n\n    _nunchaku_fused_ops_available = True\nexcept Exception:\n    NunchakuFeedForward = None\n    NunchakuAdaLayerNormZero = None\n    NunchakuAdaLayerNormZeroSingle = None\n    _svdq_gemm_w4a4 = None\n    _svdq_quantize_w4a4 = None\n    _nunchaku_fused_ops_available = False\n\n\ndef _fused_gelu_mlp(\n    x: torch.Tensor,\n    fc1,\n    fc2,\n    pad_size: int = 256,\n) -> torch.Tensor:\n    \"\"\"\n    Fused GELU MLP matching nunchaku's fused_gelu_mlp kernel path.\n\n    nunchaku's single-block MLP checkpoint is calibrated for the fused path where:\n      1. fc1 GEMM + GELU + 0.171875 shift + unsigned re-quantization + fc2.lora_down\n         are all done in a single fused kernel call\n      2. 
fc2 GEMM then receives unsigned INT4 activations (act_unsigned=True)\n\n    Using the sequential path (fc1 → GELU → fc2 with symmetric quantization) is\n    fundamentally incompatible with these wscales, causing visually wrong outputs.\n    \"\"\"\n    batch_size, seq_len, channels = x.shape\n    x_2d = x.view(batch_size * seq_len, channels)\n\n    quantized_x, ascales, lora_act = _svdq_quantize_w4a4(\n        x_2d,\n        lora_down=fc1.proj_down,\n        smooth=fc1.smooth_factor,\n        fp4=fc1.precision == \"nvfp4\",\n        pad_size=pad_size,\n    )\n\n    batch_size_pad = (batch_size * seq_len + pad_size - 1) // pad_size * pad_size\n    is_fp4 = fc2.precision == \"nvfp4\"\n\n    qout_act = torch.empty(\n        batch_size_pad,\n        fc1.output_size_per_partition // 2,\n        dtype=torch.uint8,\n        device=x_2d.device,\n    )\n    if is_fp4:\n        qout_ascales = torch.empty(\n            fc1.output_size_per_partition // 16,\n            batch_size_pad,\n            dtype=torch.float8_e4m3fn,\n            device=x_2d.device,\n        )\n    else:\n        qout_ascales = torch.empty(\n            fc1.output_size_per_partition // 64,\n            batch_size_pad,\n            dtype=x_2d.dtype,\n            device=x_2d.device,\n        )\n    qout_lora_act = torch.empty(\n        batch_size_pad, fc2.proj_down.shape[1], dtype=torch.float32, device=x_2d.device\n    )\n\n    # fused: fc1 GEMM + GELU + shift + unsigned quantize + fc2.lora_down\n    _svdq_gemm_w4a4(\n        act=quantized_x,\n        wgt=fc1.qweight,\n        qout=qout_act,\n        ascales=ascales,\n        wscales=fc1.wscales,\n        oscales=qout_ascales,\n        lora_act_in=lora_act,\n        lora_up=fc1.proj_up,\n        lora_down=fc2.proj_down,\n        lora_act_out=qout_lora_act,\n        bias=fc1.bias,\n        smooth_factor=fc2.smooth_factor,\n        fp4=is_fp4,\n        alpha=getattr(fc1, \"_nunchaku_alpha\", None),\n        wcscales=getattr(fc1, \"wcscales\", None),\n    
)\n\n    output = torch.empty(\n        batch_size * seq_len,\n        fc2.output_size_per_partition,\n        dtype=x_2d.dtype,\n        device=x_2d.device,\n    )\n    # fc2 GEMM with unsigned INT4 activations (fused kernel shifted by 0.171875)\n    _svdq_gemm_w4a4(\n        act=qout_act,\n        wgt=fc2.qweight,\n        out=output,\n        ascales=qout_ascales,\n        wscales=fc2.wscales,\n        lora_act_in=qout_lora_act,\n        lora_up=fc2.proj_up,\n        bias=fc2.bias,\n        fp4=is_fp4,\n        alpha=getattr(fc2, \"_nunchaku_alpha\", None),\n        wcscales=getattr(fc2, \"wcscales\", None),\n        act_unsigned=True,\n    )\n\n    return output.view(batch_size, seq_len, -1)\n\n\ndef _get_qkv_projections(\n    attn: \"FluxAttention\", hidden_states, encoder_hidden_states=None\n):\n    if getattr(attn, \"use_fused_qkv\", False):\n        qkv, _ = attn.to_qkv(hidden_states)\n        query, key, value = [x.contiguous() for x in qkv.chunk(3, dim=-1)]\n    else:\n        query, _ = attn.to_q(hidden_states)\n        key, _ = attn.to_k(hidden_states)\n        value, _ = attn.to_v(hidden_states)\n\n    encoder_query = encoder_key = encoder_value = None\n    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:\n        if getattr(attn, \"use_fused_added_qkv\", False):\n            added_qkv, _ = attn.to_added_qkv(encoder_hidden_states)\n            encoder_query, encoder_key, encoder_value = [\n                x.contiguous() for x in added_qkv.chunk(3, dim=-1)\n            ]\n        else:\n            encoder_query, _ = attn.add_q_proj(encoder_hidden_states)\n            encoder_key, _ = attn.add_k_proj(encoder_hidden_states)\n            encoder_value, _ = attn.add_v_proj(encoder_hidden_states)\n\n    return query, key, value, encoder_query, encoder_key, encoder_value\n\n\nclass FluxAttention(torch.nn.Module, AttentionModuleMixin):\n    def __init__(\n        self,\n        query_dim: int,\n        num_heads: int = 8,\n        
dim_head: int = 64,\n        dropout: float = 0.0,\n        bias: bool = False,\n        added_kv_proj_dim: Optional[int] = None,\n        added_proj_bias: Optional[bool] = True,\n        out_bias: bool = True,\n        eps: float = 1e-5,\n        out_dim: int = None,\n        context_pre_only: Optional[bool] = None,\n        pre_only: bool = False,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n\n        self.head_dim = dim_head\n        self.inner_dim = out_dim if out_dim is not None else dim_head * num_heads\n        self.query_dim = query_dim\n        self.use_bias = bias\n        self.dropout = dropout\n        self.out_dim = out_dim if out_dim is not None else query_dim\n        self.context_pre_only = context_pre_only\n        self.pre_only = pre_only\n        self.heads = out_dim // dim_head if out_dim is not None else num_heads\n        self.added_kv_proj_dim = added_kv_proj_dim\n        self.added_proj_bias = added_proj_bias\n\n        self.use_fused_qkv = isinstance(quant_config, NunchakuConfig)\n        self.use_fused_added_qkv = isinstance(quant_config, NunchakuConfig)\n\n        self.norm_q = RMSNorm(dim_head, eps=eps)\n        self.norm_k = RMSNorm(dim_head, eps=eps)\n\n        if self.use_fused_qkv:\n            self.to_qkv = MergedColumnParallelLinear(\n                query_dim,\n                [self.inner_dim] * 3,\n                bias=bias,\n                gather_output=True,\n                quant_config=quant_config,\n                prefix=f\"{prefix}.to_qkv\" if prefix else \"to_qkv\",\n            )\n        else:\n            self.to_q = ColumnParallelLinear(\n                query_dim,\n                self.inner_dim,\n                bias=bias,\n                gather_output=True,\n                quant_config=quant_config,\n            )\n            self.to_k = ColumnParallelLinear(\n                query_dim,\n                self.inner_dim,\n            
    bias=bias,\n                gather_output=True,\n                quant_config=quant_config,\n            )\n            self.to_v = ColumnParallelLinear(\n                query_dim,\n                self.inner_dim,\n                bias=bias,\n                gather_output=True,\n                quant_config=quant_config,\n            )\n        if not self.pre_only:\n            self.to_out = torch.nn.ModuleList([])\n            self.to_out.append(\n                ColumnParallelLinear(\n                    self.inner_dim,\n                    self.out_dim,\n                    bias=out_bias,\n                    gather_output=True,\n                    quant_config=quant_config,\n                    prefix=f\"{prefix}.to_out.0\" if prefix else \"\",\n                )\n            )\n            if dropout != 0.0:\n                self.to_out.append(torch.nn.Dropout(dropout))\n\n        if added_kv_proj_dim is not None:\n            self.norm_added_q = RMSNorm(dim_head, eps=eps)\n            self.norm_added_k = RMSNorm(dim_head, eps=eps)\n            if self.use_fused_added_qkv:\n                self.to_added_qkv = MergedColumnParallelLinear(\n                    added_kv_proj_dim,\n                    [self.inner_dim] * 3,\n                    bias=added_proj_bias,\n                    gather_output=True,\n                    quant_config=quant_config,\n                    prefix=f\"{prefix}.to_added_qkv\" if prefix else \"to_added_qkv\",\n                )\n            else:\n                self.add_q_proj = ColumnParallelLinear(\n                    added_kv_proj_dim,\n                    self.inner_dim,\n                    bias=added_proj_bias,\n                    gather_output=True,\n                    quant_config=quant_config,\n                )\n                self.add_k_proj = ColumnParallelLinear(\n                    added_kv_proj_dim,\n                    self.inner_dim,\n                    bias=added_proj_bias,\n                    
gather_output=True,\n                    quant_config=quant_config,\n                )\n                self.add_v_proj = ColumnParallelLinear(\n                    added_kv_proj_dim,\n                    self.inner_dim,\n                    bias=added_proj_bias,\n                    gather_output=True,\n                    quant_config=quant_config,\n                )\n            self.to_add_out = ColumnParallelLinear(\n                self.inner_dim,\n                query_dim,\n                bias=out_bias,\n                gather_output=True,\n                quant_config=quant_config,\n                prefix=f\"{prefix}.to_add_out\" if prefix else \"\",\n            )\n\n        self.attn = USPAttention(\n            num_heads=num_heads,\n            head_size=self.head_dim,\n            dropout_rate=0,\n            softmax_scale=None,\n            causal=False,\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        freqs_cis=None,\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        query, key, value, encoder_query, encoder_key, encoder_value = (\n            _get_qkv_projections(self, x, encoder_hidden_states)\n        )\n\n        query = query.unflatten(-1, (self.heads, -1))\n        key = key.unflatten(-1, (self.heads, -1))\n        value = value.unflatten(-1, (self.heads, -1))\n        query, key = apply_qk_norm(\n            q=query,\n            k=key,\n            q_norm=self.norm_q,\n            k_norm=self.norm_k,\n            head_dim=self.head_dim,\n            allow_inplace=True,\n        )\n\n        if self.added_kv_proj_dim is not None:\n            encoder_query = encoder_query.unflatten(-1, (self.heads, -1))\n            encoder_key = encoder_key.unflatten(-1, (self.heads, -1))\n            encoder_value = encoder_value.unflatten(-1, (self.heads, -1))\n\n            encoder_query, encoder_key = apply_qk_norm(\n                
q=encoder_query,\n                k=encoder_key,\n                q_norm=self.norm_added_q,\n                k_norm=self.norm_added_k,\n                head_dim=self.head_dim,\n                allow_inplace=True,\n            )\n\n            bsz, seq_len, _, _ = query.shape\n            query = torch.cat([encoder_query, query], dim=1)\n            key = torch.cat([encoder_key, key], dim=1)\n            value = torch.cat([encoder_value, value], dim=1)\n\n        if freqs_cis is not None:\n            cos, sin = freqs_cis\n            cos_sin_cache = torch.cat(\n                [\n                    cos.to(dtype=torch.float32).contiguous(),\n                    sin.to(dtype=torch.float32).contiguous(),\n                ],\n                dim=-1,\n            )\n            query, key = apply_flashinfer_rope_qk_inplace(\n                query, key, cos_sin_cache, is_neox=False\n            )\n\n        x = self.attn(query, key, value)\n        x = x.flatten(2, 3)\n        x = x.to(query.dtype)\n\n        if encoder_hidden_states is not None:\n            encoder_hidden_states, x = x.split_with_sizes(\n                [\n                    encoder_hidden_states.shape[1],\n                    x.shape[1] - encoder_hidden_states.shape[1],\n                ],\n                dim=1,\n            )\n            if not self.pre_only:\n                x, _ = self.to_out[0](x)\n                if len(self.to_out) == 2:\n                    x = self.to_out[1](x)\n            encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states)\n\n            return x, encoder_hidden_states\n        else:\n            if not self.pre_only:\n                x, _ = self.to_out[0](x)\n                if len(self.to_out) == 2:\n                    x = self.to_out[1](x)\n            return x\n\n\nclass FluxSingleTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        mlp_ratio: 
float = 4.0,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.mlp_hidden_dim = int(dim * mlp_ratio)\n        self.use_nunchaku_structure = isinstance(quant_config, NunchakuConfig)\n\n        self.norm = AdaLayerNormZeroSingle(dim)\n\n        if self.use_nunchaku_structure:\n            self.mlp_fc1 = ColumnParallelLinear(\n                dim,\n                self.mlp_hidden_dim,\n                bias=True,\n                gather_output=True,\n                quant_config=quant_config,\n                prefix=f\"{prefix}.mlp_fc1\" if prefix else \"mlp_fc1\",\n            )\n            self.act_mlp = nn.GELU(approximate=\"tanh\")\n            self.mlp_fc2 = ColumnParallelLinear(\n                self.mlp_hidden_dim,\n                dim,\n                bias=True,\n                gather_output=True,\n                quant_config=quant_config,\n                prefix=f\"{prefix}.mlp_fc2\" if prefix else \"mlp_fc2\",\n            )\n\n            self.attn = FluxAttention(\n                query_dim=dim,\n                dim_head=attention_head_dim,\n                num_heads=num_attention_heads,\n                out_dim=dim,\n                bias=True,\n                eps=1e-6,\n                pre_only=False,\n                quant_config=quant_config,\n                prefix=f\"{prefix}.attn\" if prefix else \"attn\",\n            )\n            if is_nunchaku_available():\n                self.norm = NunchakuAdaLayerNormZeroSingle(self.norm, scale_shift=0)\n        else:\n            self.proj_mlp = ColumnParallelLinear(\n                dim,\n                self.mlp_hidden_dim,\n                bias=True,\n                gather_output=True,\n                quant_config=quant_config,\n            )\n            self.act_mlp = nn.GELU(approximate=\"tanh\")\n            self.proj_out = ColumnParallelLinear(\n                dim + self.mlp_hidden_dim,\n                
dim,\n                bias=True,\n                gather_output=True,\n                quant_config=quant_config,\n            )\n            self.attn = FluxAttention(\n                query_dim=dim,\n                dim_head=attention_head_dim,\n                num_heads=num_attention_heads,\n                out_dim=dim,\n                bias=True,\n                eps=1e-6,\n                pre_only=True,\n                quant_config=quant_config,\n            )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        temb: torch.Tensor,\n        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        text_seq_len = encoder_hidden_states.shape[1]\n        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)\n\n        residual = hidden_states\n        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)\n        joint_attention_kwargs = joint_attention_kwargs or {}\n\n        if self.use_nunchaku_structure:\n            if _nunchaku_fused_ops_available:\n                mlp_hidden_states = _fused_gelu_mlp(\n                    norm_hidden_states, self.mlp_fc1, self.mlp_fc2\n                )\n            else:\n                mlp_out, _ = self.mlp_fc1(norm_hidden_states)\n                mlp_hidden_states = self.act_mlp(mlp_out)\n                mlp_hidden_states, _ = self.mlp_fc2(mlp_hidden_states)\n\n            attn_output = self.attn(\n                x=norm_hidden_states,\n                freqs_cis=freqs_cis,\n                **joint_attention_kwargs,\n            )\n            if isinstance(attn_output, tuple):\n                attn_output = attn_output[0]\n\n            hidden_states = attn_output + mlp_hidden_states\n            gate = gate.unsqueeze(1)\n            hidden_states = gate * hidden_states\n            hidden_states 
= residual + hidden_states\n        else:\n            proj_hidden_states, _ = self.proj_mlp(norm_hidden_states)\n            mlp_hidden_states = self.act_mlp(proj_hidden_states)\n\n            attn_output = self.attn(\n                x=norm_hidden_states,\n                freqs_cis=freqs_cis,\n                **joint_attention_kwargs,\n            )\n\n            hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)\n            gate = gate.unsqueeze(1)\n            proj_out, _ = self.proj_out(hidden_states)\n            hidden_states = gate * proj_out\n            hidden_states = residual + hidden_states\n\n        if hidden_states.dtype == torch.float16:\n            hidden_states = hidden_states.clip(-65504, 65504)\n\n        encoder_hidden_states, hidden_states = (\n            hidden_states[:, :text_seq_len],\n            hidden_states[:, text_seq_len:],\n        )\n        return encoder_hidden_states, hidden_states\n\n\nclass FluxTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        qk_norm: str = \"rms_norm\",\n        eps: float = 1e-6,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n\n        self.norm1 = AdaLayerNormZero(dim)\n        self.norm1_context = AdaLayerNormZero(dim)\n\n        self.attn = FluxAttention(\n            query_dim=dim,\n            added_kv_proj_dim=dim,\n            dim_head=attention_head_dim,\n            num_heads=num_attention_heads,\n            out_dim=dim,\n            context_pre_only=False,\n            bias=True,\n            eps=eps,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.attn\" if prefix else \"attn\",\n        )\n\n        self.norm2 = LayerNorm(dim, eps=1e-6, elementwise_affine=False)\n        self.norm2_context = LayerNorm(dim, eps=1e-6, elementwise_affine=False)\n\n        
nunchaku_enabled = (\n            quant_config is not None\n            and hasattr(quant_config, \"get_name\")\n            and quant_config.get_name() == \"svdquant\"\n            and is_nunchaku_available()\n        )\n        self.use_nunchaku_structure = nunchaku_enabled\n        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn=\"gelu-approximate\")\n        self.ff_context = FeedForward(\n            dim=dim,\n            dim_out=dim,\n            activation_fn=\"gelu-approximate\",\n        )\n        if nunchaku_enabled:\n            nunchaku_kwargs = {\n                \"precision\": quant_config.precision,\n                \"rank\": quant_config.rank,\n                \"act_unsigned\": quant_config.act_unsigned,\n            }\n            self.ff = NunchakuFeedForward(self.ff, **nunchaku_kwargs)\n            self.ff_context = NunchakuFeedForward(self.ff_context, **nunchaku_kwargs)\n            self.norm1 = NunchakuAdaLayerNormZero(self.norm1, scale_shift=0)\n            self.norm1_context = NunchakuAdaLayerNormZero(\n                self.norm1_context, scale_shift=0\n            )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        temb: torch.Tensor,\n        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(\n            hidden_states, emb=temb\n        )\n\n        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (\n            self.norm1_context(encoder_hidden_states, emb=temb)\n        )\n\n        joint_attention_kwargs = joint_attention_kwargs or {}\n        # Attention.\n        attention_outputs = self.attn(\n            x=norm_hidden_states,\n            encoder_hidden_states=norm_encoder_hidden_states,\n            
freqs_cis=freqs_cis,\n            **joint_attention_kwargs,\n        )\n\n        if len(attention_outputs) == 2:\n            attn_output, context_attn_output = attention_outputs\n        elif len(attention_outputs) == 3:\n            attn_output, context_attn_output, ip_attn_output = attention_outputs\n\n        # Process attention outputs for the `hidden_states`.\n        attn_output = gate_msa.unsqueeze(1) * attn_output\n        hidden_states = hidden_states + attn_output\n        norm_hidden_states = self.norm2(hidden_states)\n        if self.use_nunchaku_structure:\n            norm_hidden_states = (\n                norm_hidden_states * scale_mlp[:, None] + shift_mlp[:, None]\n            )\n        else:\n            norm_hidden_states = (\n                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]\n            )\n\n        ff_output = self.ff(norm_hidden_states)\n        ff_output = gate_mlp.unsqueeze(1) * ff_output\n\n        hidden_states = hidden_states + ff_output\n\n        if len(attention_outputs) == 3:\n            hidden_states = hidden_states + ip_attn_output\n        # Process attention outputs for the `encoder_hidden_states`.\n        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output\n        encoder_hidden_states = encoder_hidden_states + context_attn_output\n\n        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)\n        if self.use_nunchaku_structure:\n            norm_encoder_hidden_states = (\n                norm_encoder_hidden_states * c_scale_mlp[:, None] + c_shift_mlp[:, None]\n            )\n        else:\n            norm_encoder_hidden_states = (\n                norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])\n                + c_shift_mlp[:, None]\n            )\n\n        context_ff_output = self.ff_context(norm_encoder_hidden_states)\n        encoder_hidden_states = (\n            encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output\n      
  )\n        if encoder_hidden_states.dtype == torch.float16:\n            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)\n\n        return encoder_hidden_states, hidden_states\n\n\nclass FluxPosEmbed(nn.Module):\n    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11\n    def __init__(self, theta: int, axes_dim: List[int]):\n        super().__init__()\n        self.rope = NDRotaryEmbedding(\n            rope_dim_list=axes_dim,\n            rope_theta=theta,\n            use_real=False,\n            repeat_interleave_real=False,\n            dtype=(\n                torch.float32\n                if current_platform.is_mps() or current_platform.is_musa()\n                else torch.float64\n            ),\n        )\n\n    def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        pos = ids.float()\n        # TODO: potential error: flux use n_axes = ids.shape[-1]\n        # see: https://github.com/huggingface/diffusers/blob/17c0e79dbdf53fb6705e9c09cc1a854b84c39249/src/diffusers/models/transformers/transformer_flux.py#L509\n        freqs_cos, freqs_sin = self.rope.forward_uncached(pos=pos)\n        return freqs_cos.contiguous().float(), freqs_sin.contiguous().float()\n\n\nclass FluxTransformer2DModel(CachableDiT, OffloadableDiTMixin):\n    \"\"\"\n    The Transformer model introduced in Flux.\n\n    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/\n    \"\"\"\n\n    param_names_mapping = FluxConfig().arch_config.param_names_mapping\n\n    @classmethod\n    def get_nunchaku_quant_rules(cls) -> dict[str, list[str]]:\n        return {\n            \"skip\": [\n                \"norm\",\n                \"embed\",\n                \"rotary\",\n                \"pos_embed\",\n            ],\n            \"svdq_w4a4\": [\n                \"attn.to_qkv\",\n                \"attn.to_out\",\n                
\"attn.add_qkv_proj\",\n                \"attn.to_added_qkv\",\n                \"attn.to_add_out\",\n                \"img_mlp\",\n                \"txt_mlp\",\n                \"attention.to_qkv\",\n                \"attention.to_out\",\n                \"proj_mlp\",\n                \"proj_out\",\n                \"mlp_fc1\",\n                \"mlp_fc2\",\n                \"ff.net\",\n                \"ff_context.net\",\n            ],\n            \"awq_w4a16\": [\n                \"img_mod\",\n                \"txt_mod\",\n            ],\n        }\n\n    def __init__(\n        self,\n        config: FluxConfig,\n        hf_config: dict[str, Any],\n        quant_config: Optional[QuantizationConfig] = None,\n    ) -> None:\n        super().__init__(config=config, hf_config=hf_config)\n        self.config = config.arch_config\n\n        self.out_channels = (\n            getattr(self.config, \"out_channels\", None) or self.config.in_channels\n        )\n        self.inner_dim = (\n            self.config.num_attention_heads * self.config.attention_head_dim\n        )\n\n        self.rotary_emb = FluxPosEmbed(theta=10000, axes_dim=self.config.axes_dims_rope)\n\n        text_time_guidance_cls = (\n            CombinedTimestepGuidanceTextProjEmbeddings\n            if self.config.guidance_embeds\n            else CombinedTimestepTextProjEmbeddings\n        )\n        self.time_text_embed = text_time_guidance_cls(\n            embedding_dim=self.inner_dim,\n            pooled_projection_dim=self.config.pooled_projection_dim,\n        )\n\n        self.context_embedder = ColumnParallelLinear(\n            self.config.joint_attention_dim,\n            self.inner_dim,\n            bias=True,\n            gather_output=True,\n        )\n        self.x_embedder = ColumnParallelLinear(\n            self.config.in_channels, self.inner_dim, bias=True, gather_output=True\n        )\n        self.transformer_blocks = nn.ModuleList(\n            [\n                
FluxTransformerBlock(\n                    dim=self.inner_dim,\n                    num_attention_heads=self.config.num_attention_heads,\n                    attention_head_dim=self.config.attention_head_dim,\n                    quant_config=quant_config,\n                    prefix=f\"transformer_blocks.{i}\",\n                )\n                for i in range(self.config.num_layers)\n            ]\n        )\n\n        self.single_transformer_blocks = nn.ModuleList(\n            [\n                FluxSingleTransformerBlock(\n                    dim=self.inner_dim,\n                    num_attention_heads=self.config.num_attention_heads,\n                    attention_head_dim=self.config.attention_head_dim,\n                    quant_config=quant_config,\n                    prefix=f\"single_transformer_blocks.{i}\",\n                )\n                for i in range(self.config.num_single_layers)\n            ]\n        )\n\n        self.norm_out = AdaLayerNormContinuous(\n            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6\n        )\n        self.proj_out = ColumnParallelLinear(\n            self.inner_dim,\n            self.config.patch_size * self.config.patch_size * self.out_channels,\n            bias=True,\n            gather_output=True,\n        )\n\n        self.layer_names = [\n            \"transformer_blocks\",\n            \"single_transformer_blocks\",\n        ]\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor = None,\n        pooled_projections: torch.Tensor = None,\n        timestep: torch.LongTensor = None,\n        guidance: torch.Tensor = None,\n        freqs_cis: torch.Tensor = None,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Union[torch.Tensor, Transformer2DModelOutput]:\n        \"\"\"\n        The [`FluxTransformer2DModel`] forward method.\n\n        Args:\n            hidden_states (`torch.Tensor` of shape 
`(batch_size, image_sequence_length, in_channels)`):\n                Input `hidden_states`.\n            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):\n                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.\n            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected\n                from the embeddings of input conditions.\n            timestep ( `torch.LongTensor`):\n                Used to indicate denoising step.\n            guidance (`torch.Tensor`):\n                Guidance embeddings.\n            joint_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n\n        \"\"\"\n        if (\n            joint_attention_kwargs is not None\n            and joint_attention_kwargs.get(\"scale\", None) is not None\n        ):\n            logger.warning(\n                \"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective.\"\n            )\n        hidden_states, _ = self.x_embedder(hidden_states)\n\n        # Only pass guidance to time_text_embed if the model supports it\n        if self.config.guidance_embeds and guidance is not None:\n            temb = self.time_text_embed(timestep, guidance, pooled_projections)\n        else:\n            temb = self.time_text_embed(timestep, pooled_projections)\n\n        encoder_hidden_states, _ = self.context_embedder(encoder_hidden_states)\n\n        if (\n            joint_attention_kwargs is not None\n            and \"ip_adapter_image_embeds\" in joint_attention_kwargs\n        ):\n            ip_adapter_image_embeds = 
joint_attention_kwargs.pop(\n                \"ip_adapter_image_embeds\"\n            )\n            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)\n            joint_attention_kwargs.update({\"ip_hidden_states\": ip_hidden_states})\n\n        for block in self.transformer_blocks:\n            encoder_hidden_states, hidden_states = block(\n                hidden_states=hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                temb=temb,\n                freqs_cis=freqs_cis,\n                joint_attention_kwargs=joint_attention_kwargs,\n            )\n        for block in self.single_transformer_blocks:\n            encoder_hidden_states, hidden_states = block(\n                hidden_states=hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                temb=temb,\n                freqs_cis=freqs_cis,\n                joint_attention_kwargs=joint_attention_kwargs,\n            )\n\n        hidden_states = self.norm_out(hidden_states, temb)\n\n        output, _ = self.proj_out(hidden_states)\n\n        return output\n\n\nEntryClass = FluxTransformer2DModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/flux_2.py",
    "content": "# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nfrom diffusers.models.attention import AttentionModuleMixin\nfrom diffusers.models.embeddings import TimestepEmbedding, Timesteps\nfrom diffusers.models.normalization import AdaLayerNormContinuous\n\nfrom sglang.multimodal_gen.configs.models.dits.flux import FluxConfig\nfrom sglang.multimodal_gen.runtime.layers.attention import USPAttention\nfrom sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm, apply_qk_norm\nfrom sglang.multimodal_gen.runtime.layers.linear import ColumnParallelLinear\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n)\nfrom sglang.multimodal_gen.runtime.layers.rotary_embedding import (\n    NDRotaryEmbedding,\n    apply_flashinfer_rope_qk_inplace,\n)\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef _get_qkv_projections(\n    attn: \"Flux2Attention\", hidden_states, 
encoder_hidden_states=None\n):\n    query, _ = attn.to_q(hidden_states)\n    key, _ = attn.to_k(hidden_states)\n    value, _ = attn.to_v(hidden_states)\n\n    encoder_query = encoder_key = encoder_value = None\n    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:\n        encoder_query, _ = attn.add_q_proj(encoder_hidden_states)\n        encoder_key, _ = attn.add_k_proj(encoder_hidden_states)\n        encoder_value, _ = attn.add_v_proj(encoder_hidden_states)\n\n    return query, key, value, encoder_query, encoder_key, encoder_value\n\n\nclass Flux2SwiGLU(nn.Module):\n    \"\"\"\n    Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection\n    layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.gate_fn = nn.SiLU()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x1, x2 = x.chunk(2, dim=-1)\n        x = self.gate_fn(x1) * x2\n        return x\n\n\nclass Flux2FeedForward(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        dim_out: Optional[int] = None,\n        mult: float = 3.0,\n        inner_dim: Optional[int] = None,\n        bias: bool = False,\n        quant_config: Optional[QuantizationConfig] = None,\n    ):\n        super().__init__()\n        if inner_dim is None:\n            inner_dim = int(dim * mult)\n        dim_out = dim_out or dim\n\n        # Flux2SwiGLU will reduce the dimension by half\n        self.linear_in = ColumnParallelLinear(\n            dim, inner_dim * 2, bias=bias, gather_output=True, quant_config=quant_config\n        )\n        self.act_fn = Flux2SwiGLU()\n        self.linear_out = ColumnParallelLinear(\n            inner_dim, dim_out, bias=bias, gather_output=True, quant_config=quant_config\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x, _ = 
self.linear_in(x)\n        x = self.act_fn(x)\n        x, _ = self.linear_out(x)\n        return x\n\n\nclass Flux2Attention(torch.nn.Module, AttentionModuleMixin):\n    def __init__(\n        self,\n        query_dim: int,\n        num_heads: int = 8,\n        dim_head: int = 64,\n        dropout: float = 0.0,\n        bias: bool = False,\n        added_kv_proj_dim: Optional[int] = None,\n        added_proj_bias: Optional[bool] = True,\n        out_bias: bool = True,\n        eps: float = 1e-5,\n        out_dim: int = None,\n        elementwise_affine: bool = True,\n        quant_config: Optional[QuantizationConfig] = None,\n    ):\n        super().__init__()\n\n        self.head_dim = dim_head\n        self.inner_dim = out_dim if out_dim is not None else dim_head * num_heads\n        self.query_dim = query_dim\n        self.out_dim = out_dim if out_dim is not None else query_dim\n        self.heads = out_dim // dim_head if out_dim is not None else num_heads\n\n        self.use_bias = bias\n        self.dropout = dropout\n\n        self.added_kv_proj_dim = added_kv_proj_dim\n        self.added_proj_bias = added_proj_bias\n\n        self.to_q = ColumnParallelLinear(\n            query_dim,\n            self.inner_dim,\n            bias=bias,\n            gather_output=True,\n            quant_config=quant_config,\n        )\n        self.to_k = ColumnParallelLinear(\n            query_dim,\n            self.inner_dim,\n            bias=bias,\n            gather_output=True,\n            quant_config=quant_config,\n        )\n        self.to_v = ColumnParallelLinear(\n            query_dim,\n            self.inner_dim,\n            bias=bias,\n            gather_output=True,\n            quant_config=quant_config,\n        )\n\n        # QK Norm\n        self.norm_q = RMSNorm(dim_head, eps=eps)\n        self.norm_k = RMSNorm(dim_head, eps=eps)\n\n        self.to_out = torch.nn.ModuleList([])\n        self.to_out.append(\n            ColumnParallelLinear(\n           
     self.inner_dim,\n                self.out_dim,\n                bias=out_bias,\n                gather_output=True,\n                quant_config=quant_config,\n            )\n        )\n        self.to_out.append(torch.nn.Dropout(dropout))\n\n        if added_kv_proj_dim is not None:\n            self.norm_added_q = RMSNorm(dim_head, eps=eps)\n            self.norm_added_k = RMSNorm(dim_head, eps=eps)\n            self.add_q_proj = ColumnParallelLinear(\n                added_kv_proj_dim,\n                self.inner_dim,\n                bias=added_proj_bias,\n                gather_output=True,\n                quant_config=quant_config,\n            )\n            self.add_k_proj = ColumnParallelLinear(\n                added_kv_proj_dim,\n                self.inner_dim,\n                bias=added_proj_bias,\n                gather_output=True,\n                quant_config=quant_config,\n            )\n            self.add_v_proj = ColumnParallelLinear(\n                added_kv_proj_dim,\n                self.inner_dim,\n                bias=added_proj_bias,\n                gather_output=True,\n                quant_config=quant_config,\n            )\n            self.to_add_out = ColumnParallelLinear(\n                self.inner_dim,\n                query_dim,\n                bias=out_bias,\n                gather_output=True,\n                quant_config=quant_config,\n            )\n\n        self.attn = USPAttention(\n            num_heads=num_heads,\n            head_size=self.head_dim,\n            dropout_rate=0,\n            softmax_scale=None,\n            causal=False,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        query, key, value, encoder_query, encoder_key, encoder_value = (\n            _get_qkv_projections(self, hidden_states, 
encoder_hidden_states)\n        )\n\n        query = query.unflatten(-1, (self.heads, -1))\n        key = key.unflatten(-1, (self.heads, -1))\n        value = value.unflatten(-1, (self.heads, -1))\n\n        query, key = apply_qk_norm(\n            q=query,\n            k=key,\n            q_norm=self.norm_q,\n            k_norm=self.norm_k,\n            head_dim=self.head_dim,\n            allow_inplace=True,\n        )\n\n        if self.added_kv_proj_dim is not None:\n            encoder_query = encoder_query.unflatten(-1, (self.heads, -1))\n            encoder_key = encoder_key.unflatten(-1, (self.heads, -1))\n            encoder_value = encoder_value.unflatten(-1, (self.heads, -1))\n\n            encoder_query, encoder_key = apply_qk_norm(\n                q=encoder_query,\n                k=encoder_key,\n                q_norm=self.norm_added_q,\n                k_norm=self.norm_added_k,\n                head_dim=self.head_dim,\n                allow_inplace=True,\n            )\n\n            query = torch.cat([encoder_query, query], dim=1)\n            key = torch.cat([encoder_key, key], dim=1)\n            value = torch.cat([encoder_value, value], dim=1)\n\n        if freqs_cis is not None:\n            cos, sin = freqs_cis\n            cos_sin_cache = torch.cat(\n                [\n                    cos.to(dtype=torch.float32).contiguous(),\n                    sin.to(dtype=torch.float32).contiguous(),\n                ],\n                dim=-1,\n            )\n            query, key = apply_flashinfer_rope_qk_inplace(\n                query, key, cos_sin_cache, is_neox=False\n            )\n\n        num_rep = (\n            encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0\n        )\n        hidden_states = self.attn(query, key, value, num_replicated_prefix=num_rep)\n\n        hidden_states = hidden_states.flatten(2, 3)\n        hidden_states = hidden_states.to(query.dtype)\n\n        if encoder_hidden_states is not None:\n  
          encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(\n                [\n                    encoder_hidden_states.shape[1],\n                    hidden_states.shape[1] - encoder_hidden_states.shape[1],\n                ],\n                dim=1,\n            )\n            encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states)\n\n        hidden_states, _ = self.to_out[0](hidden_states)\n        hidden_states = self.to_out[1](hidden_states)\n\n        if encoder_hidden_states is not None:\n            return hidden_states, encoder_hidden_states\n        else:\n            return hidden_states\n\n\nclass Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):\n    \"\"\"\n    Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.\n\n    This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)\n    input projections, and the attention output projections are fused to the FF output projections. 
See the [ViT-22B\n    paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.\n    \"\"\"\n\n    # Does not support QKV fusion as the QKV projections are always fused\n    _supports_qkv_fusion = False\n\n    def __init__(\n        self,\n        query_dim: int,\n        num_heads: int = 8,\n        dim_head: int = 64,\n        dropout: float = 0.0,\n        bias: bool = False,\n        out_bias: bool = True,\n        eps: float = 1e-5,\n        out_dim: int = None,\n        elementwise_affine: bool = True,\n        mlp_ratio: float = 4.0,\n        mlp_mult_factor: int = 2,\n        quant_config: Optional[QuantizationConfig] = None,\n    ):\n        super().__init__()\n\n        self.head_dim = dim_head\n        self.inner_dim = out_dim if out_dim is not None else dim_head * num_heads\n        self.query_dim = query_dim\n        self.out_dim = out_dim if out_dim is not None else query_dim\n        self.heads = out_dim // dim_head if out_dim is not None else num_heads\n\n        self.use_bias = bias\n        self.dropout = dropout\n\n        self.mlp_ratio = mlp_ratio\n        self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)\n        self.mlp_mult_factor = mlp_mult_factor\n\n        # Fused QKV projections + MLP input projection\n        self.to_qkv_mlp_proj = ColumnParallelLinear(\n            self.query_dim,\n            self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,\n            bias=bias,\n            gather_output=True,\n            quant_config=quant_config,\n        )\n        self.mlp_act_fn = Flux2SwiGLU()\n\n        # QK Norm\n        self.norm_q = RMSNorm(dim_head, eps=eps)\n        self.norm_k = RMSNorm(dim_head, eps=eps)\n\n        # Fused attention output projection + MLP output projection\n        self.to_out = ColumnParallelLinear(\n            self.inner_dim + self.mlp_hidden_dim,\n            self.out_dim,\n            bias=out_bias,\n            gather_output=True,\n            
quant_config=quant_config,\n        )\n\n        self.attn = USPAttention(\n            num_heads=num_heads,\n            head_size=self.head_dim,\n            dropout_rate=0,\n            softmax_scale=None,\n            causal=False,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        num_replicated_prefix: int = 0,\n        **kwargs,\n    ) -> torch.Tensor:\n        # Parallel in (QKV + MLP in) projection\n        hidden_states, _ = self.to_qkv_mlp_proj(hidden_states)\n        qkv, mlp_hidden_states = torch.split(\n            hidden_states,\n            [3 * self.inner_dim, self.mlp_hidden_dim * self.mlp_mult_factor],\n            dim=-1,\n        )\n\n        # Handle the attention logic\n        query, key, value = qkv.chunk(3, dim=-1)\n\n        query = query.unflatten(-1, (self.heads, -1))\n        key = key.unflatten(-1, (self.heads, -1))\n        value = value.unflatten(-1, (self.heads, -1))\n\n        query = self.norm_q(query)\n        key = self.norm_k(key)\n\n        if freqs_cis is not None:\n            cos, sin = freqs_cis\n            cos_sin_cache = torch.cat(\n                [\n                    cos.to(dtype=torch.float32).contiguous(),\n                    sin.to(dtype=torch.float32).contiguous(),\n                ],\n                dim=-1,\n            )\n            query, key = apply_flashinfer_rope_qk_inplace(\n                query, key, cos_sin_cache, is_neox=False\n            )\n        hidden_states = self.attn(\n            query, key, value, num_replicated_prefix=num_replicated_prefix\n        )\n        hidden_states = hidden_states.flatten(2, 3)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # Handle the feedforward (FF) logic\n        mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states)\n\n        # Concatenate and parallel output 
projection\n        hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)\n        hidden_states, _ = self.to_out(hidden_states)\n\n        return hidden_states\n\n\nclass Flux2SingleTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        mlp_ratio: float = 3.0,\n        eps: float = 1e-6,\n        bias: bool = False,\n        quant_config: Optional[QuantizationConfig] = None,\n    ):\n        super().__init__()\n\n        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n\n        # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this\n        # is often called a \"parallel\" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)\n        # for a visual depiction of this type of transformer block.\n        self.attn = Flux2ParallelSelfAttention(\n            query_dim=dim,\n            dim_head=attention_head_dim,\n            num_heads=num_attention_heads,\n            out_dim=dim,\n            bias=bias,\n            out_bias=bias,\n            eps=eps,\n            mlp_ratio=mlp_ratio,\n            mlp_mult_factor=2,\n            quant_config=quant_config,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor],\n        temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],\n        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n        split_hidden_states: bool = False,\n        text_seq_len: Optional[int] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already\n        # concatenated\n        if encoder_hidden_states is not None:\n            text_seq_len = 
encoder_hidden_states.shape[1]\n            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)\n\n        mod_shift, mod_scale, mod_gate = temb_mod_params\n\n        norm_hidden_states = self.norm(hidden_states)\n        norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift\n\n        joint_attention_kwargs = joint_attention_kwargs or {}\n        attn_output = self.attn(\n            hidden_states=norm_hidden_states,\n            freqs_cis=freqs_cis,\n            num_replicated_prefix=text_seq_len or 0,\n            **joint_attention_kwargs,\n        )\n\n        hidden_states = hidden_states + mod_gate * attn_output\n        if hidden_states.dtype == torch.float16:\n            hidden_states = hidden_states.clip(-65504, 65504)\n\n        if split_hidden_states:\n            encoder_hidden_states, hidden_states = (\n                hidden_states[:, :text_seq_len],\n                hidden_states[:, text_seq_len:],\n            )\n            return encoder_hidden_states, hidden_states\n        else:\n            return hidden_states\n\n\nclass Flux2TransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        mlp_ratio: float = 3.0,\n        eps: float = 1e-6,\n        bias: bool = False,\n        quant_config: Optional[QuantizationConfig] = None,\n    ):\n        super().__init__()\n        self.mlp_hidden_dim = int(dim * mlp_ratio)\n\n        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n        self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n\n        self.attn = Flux2Attention(\n            query_dim=dim,\n            added_kv_proj_dim=dim,\n            dim_head=attention_head_dim,\n            num_heads=num_attention_heads,\n            out_dim=dim,\n            bias=bias,\n            added_proj_bias=bias,\n            out_bias=bias,\n            eps=eps,\n            
quant_config=quant_config,\n        )\n\n        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n        self.ff = Flux2FeedForward(\n            dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias, quant_config=quant_config\n        )\n\n        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n        self.ff_context = Flux2FeedForward(\n            dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias, quant_config=quant_config\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        temb_mod_params_img: Tuple[\n            Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...\n        ],\n        temb_mod_params_txt: Tuple[\n            Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...\n        ],\n        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        joint_attention_kwargs = joint_attention_kwargs or {}\n\n        # Modulation parameters shape: [1, 1, self.dim]\n        (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = (\n            temb_mod_params_img\n        )\n        (c_shift_msa, c_scale_msa, c_gate_msa), (\n            c_shift_mlp,\n            c_scale_mlp,\n            c_gate_mlp,\n        ) = temb_mod_params_txt\n\n        # Img stream\n        norm_hidden_states = self.norm1(hidden_states)\n        norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa\n\n        # Conditioning txt stream\n        norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)\n        norm_encoder_hidden_states = (\n            1 + c_scale_msa\n        ) * norm_encoder_hidden_states + c_shift_msa\n\n        # Attention on concatenated img + txt stream\n        attention_outputs = self.attn(\n            hidden_states=norm_hidden_states,\n            
encoder_hidden_states=norm_encoder_hidden_states,\n            freqs_cis=freqs_cis,\n            **joint_attention_kwargs,\n        )\n\n        attn_output, context_attn_output = attention_outputs\n\n        # Process attention outputs for the image stream (`hidden_states`).\n        attn_output = gate_msa * attn_output\n        hidden_states = hidden_states + attn_output\n\n        norm_hidden_states = self.norm2(hidden_states)\n        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp\n\n        ff_output = self.ff(norm_hidden_states)\n        hidden_states = hidden_states + gate_mlp * ff_output\n\n        # Process attention outputs for the text stream (`encoder_hidden_states`).\n        context_attn_output = c_gate_msa * context_attn_output\n        encoder_hidden_states = encoder_hidden_states + context_attn_output\n\n        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)\n        norm_encoder_hidden_states = (\n            norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp\n        )\n\n        context_ff_output = self.ff_context(norm_encoder_hidden_states)\n        encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output\n        if encoder_hidden_states.dtype == torch.float16:\n            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)\n\n        return encoder_hidden_states, hidden_states\n\n\nclass Flux2TimestepGuidanceEmbeddings(nn.Module):\n    def __init__(\n        self,\n        in_channels: int = 256,\n        embedding_dim: int = 6144,\n        bias: bool = False,\n        guidance_embeds: bool = True,\n    ):\n        super().__init__()\n\n        self.time_proj = Timesteps(\n            num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0\n        )\n        self.timestep_embedder = TimestepEmbedding(\n            in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias\n        )\n\n        if guidance_embeds:\n   
         self.guidance_embedder = TimestepEmbedding(\n                in_channels=in_channels,\n                time_embed_dim=embedding_dim,\n                sample_proj_bias=bias,\n            )\n        else:\n            self.guidance_embedder = None\n\n    def forward(\n        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor] = None\n    ) -> torch.Tensor:\n        timesteps_proj = self.time_proj(timestep)\n        timesteps_emb = self.timestep_embedder(\n            timesteps_proj.to(timestep.dtype)\n        )  # (N, D)\n\n        if guidance is not None and self.guidance_embedder is not None:\n            guidance_proj = self.time_proj(guidance)\n            guidance_emb = self.guidance_embedder(\n                guidance_proj.to(guidance.dtype)\n            )  # (N, D)\n            time_guidance_emb = timesteps_emb + guidance_emb\n            return time_guidance_emb\n        else:\n            return timesteps_emb\n\n\nclass Flux2Modulation(nn.Module):\n    def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):\n        super().__init__()\n        self.mod_param_sets = mod_param_sets\n\n        self.linear = ColumnParallelLinear(\n            dim, dim * 3 * self.mod_param_sets, bias=bias, gather_output=True\n        )\n        self.act_fn = nn.SiLU()\n\n    def forward(\n        self, temb: torch.Tensor\n    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:\n        mod = self.act_fn(temb)\n        mod, _ = self.linear(mod)\n\n        if mod.ndim == 2:\n            mod = mod.unsqueeze(1)\n        mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)\n        # Return tuple of 3-tuples of modulation params shift/scale/gate\n        return tuple(\n            mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets)\n        )\n\n\nclass Flux2PosEmbed(nn.Module):\n    def __init__(self, theta: int, axes_dim: List[int]):\n        super().__init__()\n        self.rope = NDRotaryEmbedding(\n   
         rope_dim_list=axes_dim,\n            rope_theta=theta,\n            use_real=False,\n            repeat_interleave_real=False,\n            dtype=(\n                torch.float32\n                if current_platform.is_mps() or current_platform.is_musa()\n                else torch.float64\n            ),\n        )\n\n    def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        pos = ids.float()\n        # TODO: potential error: flux use n_axes = ids.shape[-1]\n        # see: https://github.com/huggingface/diffusers/blob/17c0e79dbdf53fb6705e9c09cc1a854b84c39249/src/diffusers/models/transformers/transformer_flux.py#L509\n        freqs_cos, freqs_sin = self.rope.forward_uncached(pos=pos)\n        return freqs_cos.contiguous().float(), freqs_sin.contiguous().float()\n\n\nclass Flux2Transformer2DModel(CachableDiT, OffloadableDiTMixin):\n    \"\"\"\n    The Transformer model introduced in Flux 2.\n\n    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/\n\n    \"\"\"\n\n    param_names_mapping = FluxConfig().arch_config.param_names_mapping\n\n    def __init__(\n        self,\n        config: FluxConfig,\n        hf_config: dict[str, Any],\n        quant_config: Optional[QuantizationConfig] = None,\n    ):\n        super().__init__(config=config, hf_config=hf_config)\n        patch_size: int = config.patch_size\n        in_channels: int = config.in_channels\n        out_channels: Optional[int] = config.out_channels\n        num_layers: int = config.num_layers\n        num_single_layers: int = config.num_single_layers\n        attention_head_dim: int = config.attention_head_dim\n        num_attention_heads: int = config.num_attention_heads\n        joint_attention_dim: int = config.joint_attention_dim\n        timestep_guidance_channels: int = config.timestep_guidance_channels\n        mlp_ratio: float = config.mlp_ratio\n        axes_dims_rope: Tuple[int, ...] 
= config.axes_dims_rope\n        rope_theta: int = config.rope_theta\n        eps: float = config.eps\n        guidance_embeds: bool = getattr(config, \"guidance_embeds\", True)\n        self.out_channels = out_channels or in_channels\n        self.inner_dim = num_attention_heads * attention_head_dim\n        self.guidance_embeds = guidance_embeds\n\n        # 1. Sinusoidal positional embedding for RoPE on image and text tokens\n        self.rotary_emb = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)\n\n        # 2. Combined timestep + guidance embedding\n        self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(\n            in_channels=timestep_guidance_channels,\n            embedding_dim=self.inner_dim,\n            bias=False,\n            guidance_embeds=guidance_embeds,\n        )\n\n        # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)\n        # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks\n        self.double_stream_modulation_img = Flux2Modulation(\n            self.inner_dim, mod_param_sets=2, bias=False\n        )\n        self.double_stream_modulation_txt = Flux2Modulation(\n            self.inner_dim, mod_param_sets=2, bias=False\n        )\n        # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream\n        self.single_stream_modulation = Flux2Modulation(\n            self.inner_dim, mod_param_sets=1, bias=False\n        )\n\n        # 4. Input projections\n        self.x_embedder = ColumnParallelLinear(\n            in_channels, self.inner_dim, bias=False, gather_output=True\n        )\n        self.context_embedder = ColumnParallelLinear(\n            joint_attention_dim, self.inner_dim, bias=False, gather_output=True\n        )\n\n        # 5. 
Double Stream Transformer Blocks\n        self.transformer_blocks = nn.ModuleList(\n            [\n                Flux2TransformerBlock(\n                    dim=self.inner_dim,\n                    num_attention_heads=num_attention_heads,\n                    attention_head_dim=attention_head_dim,\n                    mlp_ratio=mlp_ratio,\n                    eps=eps,\n                    bias=False,\n                    quant_config=quant_config,\n                )\n                for _ in range(num_layers)\n            ]\n        )\n\n        # 6. Single Stream Transformer Blocks\n        self.single_transformer_blocks = nn.ModuleList(\n            [\n                Flux2SingleTransformerBlock(\n                    dim=self.inner_dim,\n                    num_attention_heads=num_attention_heads,\n                    attention_head_dim=attention_head_dim,\n                    mlp_ratio=mlp_ratio,\n                    eps=eps,\n                    bias=False,\n                    quant_config=quant_config,\n                )\n                for _ in range(num_single_layers)\n            ]\n        )\n\n        # 7. 
Output layers\n        self.norm_out = AdaLayerNormContinuous(\n            self.inner_dim,\n            self.inner_dim,\n            elementwise_affine=False,\n            eps=eps,\n            bias=False,\n        )\n        self.proj_out = ColumnParallelLinear(\n            self.inner_dim,\n            patch_size * patch_size * self.out_channels,\n            bias=False,\n            gather_output=True,\n        )\n\n        self.layer_names = [\"transformer_blocks\", \"single_transformer_blocks\"]\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor = None,\n        timestep: torch.LongTensor = None,\n        guidance: torch.Tensor = None,\n        freqs_cis: torch.Tensor = None,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        The [`Flux2Transformer2DModel`] forward method.\n\n        Args:\n            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):\n                Input `hidden_states`.\n            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):\n                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.\n            timestep (`torch.LongTensor`):\n                Used to indicate the denoising step.\n            joint_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n\n        \"\"\"\n        # 0. 
Handle input arguments\n        if joint_attention_kwargs is not None:\n            joint_attention_kwargs = joint_attention_kwargs.copy()\n            lora_scale = joint_attention_kwargs.pop(\"scale\", 1.0)\n        else:\n            lora_scale = 1.0\n\n        num_txt_tokens = encoder_hidden_states.shape[1]\n\n        # 1. Calculate timestep embedding and modulation parameters\n        timestep = timestep.to(hidden_states.dtype)\n        if guidance is not None:\n            guidance = guidance.to(hidden_states.dtype)\n\n        temb = self.time_guidance_embed(timestep, guidance)\n\n        double_stream_mod_img = self.double_stream_modulation_img(temb)\n        double_stream_mod_txt = self.double_stream_modulation_txt(temb)\n        single_stream_mod = self.single_stream_modulation(temb)[0]\n\n        # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)\n        hidden_states, _ = self.x_embedder(hidden_states)\n        encoder_hidden_states, _ = self.context_embedder(encoder_hidden_states)\n\n        # 3. Calculate RoPE embeddings from image and text tokens\n        # NOTE: the below logic means that we can't support batched inference with images of different resolutions or\n        # text prompts of different lengths. Is this a use case we want to support?\n        # 4. 
Double Stream Transformer Blocks\n        for index_block, block in enumerate(self.transformer_blocks):\n            encoder_hidden_states, hidden_states = block(\n                hidden_states=hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                temb_mod_params_img=double_stream_mod_img,\n                temb_mod_params_txt=double_stream_mod_txt,\n                freqs_cis=freqs_cis,\n                joint_attention_kwargs=joint_attention_kwargs,\n            )\n        # Concatenate text and image streams for single-block inference\n        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)\n\n        # 5. Single Stream Transformer Blocks\n        for index_block, block in enumerate(self.single_transformer_blocks):\n            hidden_states = block(\n                hidden_states=hidden_states,\n                encoder_hidden_states=None,\n                temb_mod_params=single_stream_mod,\n                freqs_cis=freqs_cis,\n                joint_attention_kwargs=joint_attention_kwargs,\n                text_seq_len=num_txt_tokens,\n            )\n        # Remove text tokens from concatenated stream\n        hidden_states = hidden_states[:, num_txt_tokens:, ...]\n\n        # 6. Output layers\n        hidden_states = self.norm_out(hidden_states, temb)\n        output, _ = self.proj_out(hidden_states)\n\n        return output\n\n\nEntryClass = Flux2Transformer2DModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/glm_image.py",
    "content": "# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sglang.multimodal_gen.configs.models.dits.glmimage import GlmImageDitConfig\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_sp_parallel_rank,\n    get_sp_world_size,\n)\nfrom sglang.multimodal_gen.runtime.layers.attention import USPAttention\nfrom sglang.multimodal_gen.runtime.layers.layernorm import (\n    ScaleResidualLayerNormScaleShift,\n)\nfrom sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear\nfrom sglang.multimodal_gen.runtime.layers.mlp import FeedForward\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n)\nfrom sglang.multimodal_gen.runtime.layers.rotary_embedding import (\n    _apply_rotary_emb,\n    apply_flashinfer_rope_qk_inplace,\n)\nfrom sglang.multimodal_gen.runtime.layers.visual_embedding import Timesteps\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.platforms import (\n    AttentionBackendEnum,\n    current_platform,\n)\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils 
import init_logger\n\nlogger = init_logger(__name__)\n\n_is_cuda = current_platform.is_cuda()\n\n\nclass GlmImageLayerKVCache:\n    \"\"\"KV cache for GlmImage model.\"\"\"\n\n    def __init__(self):\n        self.k_cache = None\n        self.v_cache = None\n        self.mode: Optional[str] = None  # \"write\", \"read\", \"skip\"\n\n    def store(self, k: torch.Tensor, v: torch.Tensor):\n        if self.k_cache is None:\n            self.k_cache = k\n            self.v_cache = v\n        else:\n            self.k_cache = torch.cat([self.k_cache, k], dim=2)\n            self.v_cache = torch.cat([self.v_cache, v], dim=2)\n\n    def get(self):\n        return self.k_cache, self.v_cache\n\n    def clear(self):\n        self.k_cache = None\n        self.v_cache = None\n        self.mode = None\n\n\nclass GlmImageKVCache:\n    \"\"\"Container for all layers' KV caches.\"\"\"\n\n    def __init__(self, num_layers: int):\n        self.num_layers = num_layers\n        self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)]\n\n    def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache:\n        return self.caches[layer_idx]\n\n    def set_mode(self, mode: Optional[str]):\n        if mode is not None and mode not in [\"write\", \"read\", \"skip\"]:\n            raise ValueError(\n                f\"Invalid mode: {mode}, must be one of 'write', 'read', 'skip'\"\n            )\n        for cache in self.caches:\n            cache.mode = mode\n\n    def clear(self):\n        for cache in self.caches:\n            cache.clear()\n\n\nclass GlmImageTimestepEmbedding(nn.Module):\n    \"\"\"\n    Replacement for diffusers TimestepEmbedding using ReplicatedLinear.\n    Structure: linear_1 -> act(silu) -> linear_2\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        time_embed_dim: int,\n        act_fn: str = \"silu\",\n        out_dim: int = None,\n    ):\n        super().__init__()\n        if out_dim is None:\n            out_dim = 
time_embed_dim\n        self.linear_1 = ReplicatedLinear(in_channels, time_embed_dim, bias=True)\n        if act_fn == \"silu\":\n            self.act = nn.SiLU()\n        elif act_fn == \"gelu\":\n            self.act = nn.GELU(approximate=\"tanh\")\n        else:\n            self.act = nn.SiLU()\n        self.linear_2 = ReplicatedLinear(time_embed_dim, out_dim, bias=True)\n\n    def forward(self, sample: torch.Tensor) -> torch.Tensor:\n        sample, _ = self.linear_1(sample)\n        sample = self.act(sample)\n        sample, _ = self.linear_2(sample)\n        return sample\n\n\nclass GlmImageTextProjection(nn.Module):\n    \"\"\"\n    Replacement for diffusers PixArtAlphaTextProjection using ReplicatedLinear.\n    Structure: linear_1 -> act_1 -> linear_2\n    \"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        hidden_size: int,\n        out_features: int = None,\n        act_fn: str = \"silu\",\n    ):\n        super().__init__()\n        if out_features is None:\n            out_features = hidden_size\n        self.linear_1 = ReplicatedLinear(in_features, hidden_size, bias=True)\n        if act_fn == \"silu\":\n            self.act_1 = nn.SiLU()\n        elif act_fn == \"gelu_tanh\":\n            self.act_1 = nn.GELU(approximate=\"tanh\")\n        else:\n            self.act_1 = nn.SiLU()\n        self.linear_2 = ReplicatedLinear(hidden_size, out_features, bias=True)\n\n    def forward(self, caption: torch.Tensor) -> torch.Tensor:\n        hidden_states, _ = self.linear_1(caption)\n        hidden_states = self.act_1(hidden_states)\n        hidden_states, _ = self.linear_2(hidden_states)\n        return hidden_states\n\n\nclass GlmImageCombinedTimestepSizeEmbeddings(nn.Module):\n    def __init__(\n        self,\n        embedding_dim: int,\n        condition_dim: int,\n        pooled_projection_dim: int,\n        timesteps_dim: int = 256,\n    ):\n        super().__init__()\n\n        self.time_proj = Timesteps(\n            
num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0\n        )\n        self.condition_proj = Timesteps(\n            num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0\n        )\n        self.timestep_embedder = GlmImageTimestepEmbedding(\n            in_channels=timesteps_dim, time_embed_dim=embedding_dim\n        )\n        self.condition_embedder = GlmImageTextProjection(\n            pooled_projection_dim, embedding_dim, act_fn=\"silu\"\n        )\n\n    def forward(\n        self,\n        timestep: torch.Tensor,\n        target_size: torch.Tensor,\n        crop_coords: torch.Tensor,\n        hidden_dtype: torch.dtype,\n    ) -> torch.Tensor:\n        timesteps_proj = self.time_proj(timestep)\n\n        crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(\n            crop_coords.size(0), -1\n        )\n        target_size_proj = self.condition_proj(target_size.flatten()).view(\n            target_size.size(0), -1\n        )\n\n        # (B, 2 * condition_dim)\n        condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1)\n\n        timesteps_emb = self.timestep_embedder(\n            timesteps_proj.to(dtype=hidden_dtype)\n        )  # (B, embedding_dim)\n        condition_emb = self.condition_embedder(\n            condition_proj.to(dtype=hidden_dtype)\n        )  # (B, embedding_dim)\n\n        conditioning = timesteps_emb + condition_emb\n        return conditioning\n\n\nclass GlmImageImageProjector(nn.Module):\n    def __init__(\n        self,\n        in_channels: int = 16,\n        hidden_size: int = 2560,\n        patch_size: int = 2,\n    ):\n        super().__init__()\n        self.patch_size = patch_size\n\n        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        batch_size, channel, height, width = hidden_states.shape\n        post_patch_height = height // self.patch_size\n       
 post_patch_width = width // self.patch_size\n\n        hidden_states = hidden_states.reshape(\n            batch_size,\n            channel,\n            post_patch_height,\n            self.patch_size,\n            post_patch_width,\n            self.patch_size,\n        )\n        hidden_states = (\n            hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)\n        )\n        hidden_states = self.proj(hidden_states)\n\n        return hidden_states\n\n\nclass GlmImageAdaLayerNormZero(nn.Module):\n    def __init__(self, embedding_dim: int, dim: int) -> None:\n        super().__init__()\n\n        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)\n        self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)\n        self.linear = ReplicatedLinear(embedding_dim, 12 * dim, bias=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        temb: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        dtype = hidden_states.dtype\n        norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)\n        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(\n            dtype=dtype\n        )\n\n        emb, _ = self.linear(temb)\n        (\n            shift_msa,\n            c_shift_msa,\n            scale_msa,\n            c_scale_msa,\n            gate_msa,\n            c_gate_msa,\n            shift_mlp,\n            c_shift_mlp,\n            scale_mlp,\n            c_scale_mlp,\n            gate_mlp,\n            c_gate_mlp,\n        ) = emb.chunk(12, dim=1)\n\n        hidden_states = norm_hidden_states * (\n            1 + scale_msa.unsqueeze(1)\n        ) + shift_msa.unsqueeze(1)\n        encoder_hidden_states = norm_encoder_hidden_states * (\n            1 + c_scale_msa.unsqueeze(1)\n        ) + c_shift_msa.unsqueeze(1)\n\n        return (\n            hidden_states,\n            gate_msa,\n       
     shift_mlp,\n            scale_mlp,\n            gate_mlp,\n            encoder_hidden_states,\n            c_gate_msa,\n            c_shift_mlp,\n            c_scale_mlp,\n            c_gate_mlp,\n        )\n\n\nclass GlmImageAttention(torch.nn.Module):\n    def __init__(\n        self,\n        query_dim,\n        heads,\n        dim_head,\n        out_dim,\n        bias,\n        qk_norm,\n        elementwise_affine,\n        eps,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        prefix: str = \"\",\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__()\n\n        self.k_cache = None\n        self.v_cache = None\n\n        self.heads = out_dim // dim_head if out_dim is not None else heads\n        self.dim_head = dim_head\n        self.inner_dim = out_dim if out_dim is not None else dim_head * heads\n        self.inner_kv_dim = self.inner_dim\n        self.out_dim = out_dim if out_dim is not None else query_dim\n\n        self.num_kv_heads = self.dim_head // self.inner_kv_dim\n\n        self.to_q = ReplicatedLinear(\n            query_dim, self.inner_dim, bias=bias, quant_config=quant_config\n        )\n        self.to_k = ReplicatedLinear(\n            query_dim, self.inner_kv_dim, bias=bias, quant_config=quant_config\n        )\n        self.to_v = ReplicatedLinear(\n            query_dim, self.inner_kv_dim, bias=bias, quant_config=quant_config\n        )\n\n        # (dropout omitted)\n        self.to_out = nn.ModuleList(\n            [\n                ReplicatedLinear(\n                    self.inner_dim, self.out_dim, bias=True, quant_config=quant_config\n                )\n            ]\n        )\n\n        if qk_norm is None:\n            self.norm_q = None\n            self.norm_k = None\n        elif qk_norm == \"layer_norm\":\n            self.norm_q = nn.LayerNorm(\n                dim_head, eps=eps, elementwise_affine=elementwise_affine\n            )\n            
self.norm_k = nn.LayerNorm(\n                dim_head, eps=eps, elementwise_affine=elementwise_affine\n            )\n        else:\n            raise ValueError(\n                f\"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'.\"\n            )\n\n        self.attn = USPAttention(\n            num_heads=self.heads,\n            head_size=dim_head,\n            num_kv_heads=self.num_kv_heads,\n            dropout_rate=0,\n            softmax_scale=None,\n            causal=False,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        kv_cache: Optional[GlmImageLayerKVCache] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        dtype = encoder_hidden_states.dtype\n\n        batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape\n        batch_size, image_seq_length, embed_dim = hidden_states.shape\n        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)\n\n        # 1. QKV projections\n        query, _ = self.to_q(hidden_states)\n        key, _ = self.to_k(hidden_states)\n        value, _ = self.to_v(hidden_states)\n\n        query = query.unflatten(2, (self.heads, -1))\n        key = key.unflatten(2, (self.heads, -1))\n        value = value.unflatten(2, (self.heads, -1))\n\n        # 2. QK normalization\n        if self.norm_q is not None:\n            query = self.norm_q(query).to(dtype=dtype)\n        if self.norm_k is not None:\n            key = self.norm_k(key).to(dtype=dtype)\n\n        # 3. 
Rotational positional embeddings applied to latent stream\n        if image_rotary_emb is not None:\n            cos, sin = image_rotary_emb\n\n            if _is_cuda and cos.dim() == 2:\n                q_img = query[:, text_seq_length:, :, :]\n                k_img = key[:, text_seq_length:, :, :]\n                cos_sin_cache = torch.cat(\n                    [\n                        cos.to(dtype=torch.float32).contiguous(),\n                        sin.to(dtype=torch.float32).contiguous(),\n                    ],\n                    dim=-1,\n                )\n                # apply_flashinfer_rope_qk_inplace is an in-place kernel and q_img/k_img are views of query/key, so no copy-back is needed\n                q_out, k_out = apply_flashinfer_rope_qk_inplace(\n                    q_img, k_img, cos_sin_cache, is_neox=True\n                )\n            else:\n                query[:, text_seq_length:, :, :] = _apply_rotary_emb(\n                    query[:, text_seq_length:, :, :], cos, sin, is_neox_style=True\n                )\n                key[:, text_seq_length:, :, :] = _apply_rotary_emb(\n                    key[:, text_seq_length:, :, :], cos, sin, is_neox_style=True\n                )\n\n        if kv_cache is not None:\n            if kv_cache.mode == \"write\":\n                kv_cache.store(key, value)\n            elif kv_cache.mode == \"read\":\n                k_cache, v_cache = kv_cache.get()\n                key = torch.cat([k_cache, key], dim=1) if k_cache is not None else key\n                value = (\n                    torch.cat([v_cache, value], dim=1) if v_cache is not None else value\n                )\n            elif kv_cache.mode == \"skip\":\n                pass\n\n        # 4.
Attention\n        if attention_mask is not None:\n            text_attn_mask = attention_mask\n            assert (\n                text_attn_mask.dim() == 2\n            ), \"the shape of text_attn_mask should be (batch_size, text_seq_length)\"\n            text_attn_mask = text_attn_mask.float().to(query.device)\n            mix_attn_mask = torch.ones(\n                (batch_size, text_seq_length + image_seq_length), device=query.device\n            )\n            mix_attn_mask[:, :text_seq_length] = text_attn_mask\n            mix_attn_mask = mix_attn_mask.unsqueeze(2)\n            attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)\n            attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)\n        hidden_states = self.attn(query, key, value)\n        hidden_states = hidden_states.flatten(2, 3)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # 5. Output projection\n        hidden_states, _ = self.to_out[0](hidden_states)\n        # hidden_states = self.to_out[1](hidden_states)         # (dropout omitted)\n\n        encoder_hidden_states, hidden_states = hidden_states.split(\n            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1\n        )\n        return hidden_states, encoder_hidden_states\n\n\nclass GlmImageTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int = 2560,\n        num_attention_heads: int = 64,\n        attention_head_dim: int = 40,\n        time_embed_dim: int = 512,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        prefix: str = \"\",\n        quant_config: QuantizationConfig | None = None,\n    ) -> None:\n        super().__init__()\n\n        # 1. 
Attention\n        self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim)\n\n        self.attn1 = GlmImageAttention(\n            query_dim=dim,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            out_dim=dim,\n            bias=True,\n            qk_norm=\"layer_norm\",\n            elementwise_affine=False,\n            eps=1e-5,\n            supported_attention_backends=supported_attention_backends,\n            prefix=f\"{prefix}.attn1\",\n            quant_config=quant_config,\n        )\n\n        # 2. Feedforward\n        self.norm2 = ScaleResidualLayerNormScaleShift(\n            dim, eps=1e-5, elementwise_affine=False\n        )\n        self.norm2_context = ScaleResidualLayerNormScaleShift(\n            dim, eps=1e-5, elementwise_affine=False\n        )\n        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn=\"gelu-approximate\")\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        temb: Optional[torch.Tensor] = None,\n        image_rotary_emb: Optional[\n            Union[\n                Tuple[torch.Tensor, torch.Tensor],\n                List[Tuple[torch.Tensor, torch.Tensor]],\n            ]\n        ] = None,\n        attention_mask: Optional[Dict[str, torch.Tensor]] = None,\n        attention_kwargs: Optional[Dict[str, Any]] = None,\n        kv_cache: Optional[GlmImageLayerKVCache] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        # 1. Timestep conditioning\n        (\n            norm_hidden_states,\n            gate_msa,\n            shift_mlp,\n            scale_mlp,\n            gate_mlp,\n            norm_encoder_hidden_states,\n            c_gate_msa,\n            c_shift_mlp,\n            c_scale_mlp,\n            c_gate_mlp,\n        ) = self.norm1(hidden_states, encoder_hidden_states, temb)\n\n        # 2. 
Attention\n        if attention_kwargs is None:\n            attention_kwargs = {}\n\n        attn_hidden_states, attn_encoder_hidden_states = self.attn1(\n            hidden_states=norm_hidden_states,\n            encoder_hidden_states=norm_encoder_hidden_states,\n            image_rotary_emb=image_rotary_emb,\n            attention_mask=attention_mask,\n            kv_cache=kv_cache,\n            **attention_kwargs,\n        )\n\n        # 3. Feedforward (fused residual + norm + scale/shift)\n        norm_hidden_states, hidden_states = self.norm2(\n            hidden_states,\n            attn_hidden_states,\n            gate_msa.unsqueeze(1),\n            shift_mlp.unsqueeze(1),\n            scale_mlp.unsqueeze(1),\n        )\n        norm_encoder_hidden_states, encoder_hidden_states = self.norm2_context(\n            encoder_hidden_states,\n            attn_encoder_hidden_states,\n            c_gate_msa.unsqueeze(1),\n            c_shift_mlp.unsqueeze(1),\n            c_scale_mlp.unsqueeze(1),\n        )\n\n        ff_output = self.ff(norm_hidden_states)\n        ff_output_context = self.ff(norm_encoder_hidden_states)\n        hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)\n        encoder_hidden_states = (\n            encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)\n        )\n\n        return hidden_states, encoder_hidden_states\n\n\nclass GlmImageRotaryPosEmbed(nn.Module):\n    def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None:\n        super().__init__()\n\n        self.dim = dim\n        self.patch_size = patch_size\n        self.theta = theta\n\n    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        batch_size, num_channels, height, width = hidden_states.shape\n        height, width = height // self.patch_size, width // self.patch_size\n        device = hidden_states.device\n\n        dim_h, dim_w = self.dim // 2, self.dim // 2\n        h_inv_freq 
= 1.0 / (\n            self.theta\n            ** (\n                torch.arange(0, dim_h, 2, dtype=torch.float32, device=device)[\n                    : (dim_h // 2)\n                ].float()\n                / dim_h\n            )\n        )\n        w_inv_freq = 1.0 / (\n            self.theta\n            ** (\n                torch.arange(0, dim_w, 2, dtype=torch.float32, device=device)[\n                    : (dim_w // 2)\n                ].float()\n                / dim_w\n            )\n        )\n        h_seq = torch.arange(height, device=device)\n        w_seq = torch.arange(width, device=device)\n        freqs_h = torch.outer(h_seq, h_inv_freq)\n        freqs_w = torch.outer(w_seq, w_inv_freq)\n\n        # Create position matrices for height and width\n        # [height, 1, dim//4] and [1, width, dim//4]\n        freqs_h = freqs_h.unsqueeze(1)\n        freqs_w = freqs_w.unsqueeze(0)\n        # Broadcast freqs_h and freqs_w to [height, width, dim//4]\n        freqs_h = freqs_h.expand(height, width, -1)\n        freqs_w = freqs_w.expand(height, width, -1)\n\n        # Concatenate along last dimension to get [height, width, dim//2]\n        freqs = torch.cat([freqs_h, freqs_w], dim=-1)\n        freqs = freqs.reshape(height * width, -1)  # [height * width, dim//2]\n        return (freqs.cos(), freqs.sin())\n\n\nclass GlmImageAdaLayerNormContinuous(nn.Module):\n    \"\"\"\n    GlmImage-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. 
Matches Megatron: **no activation** before the\n    Linear on conditioning embedding.\n    \"\"\"\n\n    def __init__(\n        self,\n        embedding_dim: int,\n        conditioning_embedding_dim: int,\n        elementwise_affine: bool = True,\n        eps: float = 1e-5,\n        bias: bool = True,\n        norm_type: str = \"layer_norm\",\n    ):\n        super().__init__()\n        self.linear = nn.Linear(\n            conditioning_embedding_dim, embedding_dim * 2, bias=bias\n        )\n        if norm_type == \"layer_norm\":\n            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)\n            # For now, don’t replace this with sglang’s LayerNorm\n            # because the model doesn’t have this parameter and it will break model loading\n        elif norm_type == \"rms_norm\":\n            self.norm = nn.RMSNorm(embedding_dim, eps, elementwise_affine)\n        else:\n            raise ValueError(f\"unknown norm_type {norm_type}\")\n\n    def forward(\n        self, x: torch.Tensor, conditioning_embedding: torch.Tensor\n    ) -> torch.Tensor:\n        # *** NO SiLU here ***\n        emb = self.linear(conditioning_embedding.to(x.dtype))\n        scale, shift = torch.chunk(emb, 2, dim=1)\n        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]\n        return x\n\n\nclass GlmImageTransformer2DModel(CachableDiT, OffloadableDiTMixin):\n    r\"\"\"\n    Args:\n        patch_size (`int`, defaults to `2`):\n            The size of the patches to use in the patch embedding layer.\n        in_channels (`int`, defaults to `16`):\n            The number of channels in the input.\n        num_layers (`int`, defaults to `30`):\n            The number of layers of Transformer blocks to use.\n        attention_head_dim (`int`, defaults to `40`):\n            The number of channels in each head.\n        num_attention_heads (`int`, defaults to `64`):\n            The number of heads to use for multi-head attention.\n        
out_channels (`int`, defaults to `16`):\n            The number of channels in the output.\n        text_embed_dim (`int`, defaults to `1472`):\n            Input dimension of text embeddings from the text encoder.\n        time_embed_dim (`int`, defaults to `512`):\n            Output dimension of timestep embeddings.\n        condition_dim (`int`, defaults to `256`):\n            The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,\n            crop_coords).\n        pos_embed_max_size (`int`, defaults to `128`):\n            The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added\n            to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128\n            means that the maximum supported height and width for image generation is `128 * vae_scale_factor *\n            patch_size => 128 * 8 * 2 => 2048`.\n        sample_size (`int`, defaults to `128`):\n            The base resolution of input latents. 
If height/width is not provided during generation, this value is used\n            to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`\n    \"\"\"\n\n    def __init__(\n        self,\n        config: GlmImageDitConfig,\n        hf_config: dict[str, Any],\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__(config=config, hf_config=hf_config)\n\n        self.config_data = config  # Store config\n        arch_config = config.arch_config\n\n        self.in_channels = arch_config.in_channels\n        self.out_channels = arch_config.out_channels\n        self.patch_size = arch_config.patch_size\n        self.num_layers = arch_config.num_layers\n        self.attention_head_dim = arch_config.attention_head_dim\n        self.num_attention_heads = arch_config.num_attention_heads\n        self.text_embed_dim = arch_config.text_embed_dim\n        self.time_embed_dim = arch_config.time_embed_dim\n\n        # GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords\n        # Each of these are sincos embeddings of shape 2 * condition_dim\n        pooled_projection_dim = 2 * 2 * arch_config.condition_dim\n        inner_dim = arch_config.num_attention_heads * arch_config.attention_head_dim\n\n        # 1. RoPE\n        self.rotary_emb = GlmImageRotaryPosEmbed(\n            arch_config.attention_head_dim, arch_config.patch_size, theta=10000.0\n        )\n\n        # 2. 
Patch & Text-timestep embedding\n        self.image_projector = GlmImageImageProjector(\n            arch_config.in_channels, inner_dim, arch_config.patch_size\n        )\n        self.glyph_projector = FeedForward(\n            arch_config.text_embed_dim,\n            inner_dim,\n            inner_dim=inner_dim,\n            activation_fn=\"gelu\",\n        )\n        self.prior_token_embedding = nn.Embedding(\n            arch_config.prior_vq_quantizer_codebook_size, inner_dim\n        )\n        self.prior_projector = FeedForward(\n            inner_dim, inner_dim, inner_dim=inner_dim, activation_fn=\"linear-silu\"\n        )\n\n        self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings(\n            embedding_dim=arch_config.time_embed_dim,\n            condition_dim=arch_config.condition_dim,\n            pooled_projection_dim=pooled_projection_dim,\n            timesteps_dim=arch_config.time_embed_dim,\n        )\n\n        # 3. Transformer blocks\n        self._supported_attention_backends = arch_config._supported_attention_backends\n        self.transformer_blocks = nn.ModuleList(\n            [\n                GlmImageTransformerBlock(\n                    inner_dim,\n                    arch_config.num_attention_heads,\n                    arch_config.attention_head_dim,\n                    arch_config.time_embed_dim,\n                    supported_attention_backends=self._supported_attention_backends,\n                    prefix=f\"transformer_blocks.{i}\",\n                    quant_config=quant_config,\n                )\n                for i in range(arch_config.num_layers)\n            ]\n        )\n\n        # 4. 
Output projection\n        self.norm_out = GlmImageAdaLayerNormContinuous(\n            inner_dim, arch_config.time_embed_dim, elementwise_affine=False\n        )\n        self.proj_out = nn.Linear(\n            inner_dim,\n            arch_config.patch_size * arch_config.patch_size * arch_config.out_channels,\n            bias=True,\n        )\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        prior_token_id: torch.Tensor,\n        prior_token_drop: torch.Tensor,\n        timestep: torch.LongTensor,\n        target_size: torch.Tensor,\n        crop_coords: torch.Tensor,\n        attention_kwargs: Optional[Dict[str, Any]] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        kv_caches: Optional[GlmImageKVCache] = None,\n        kv_caches_mode: Optional[str] = None,\n        freqs_cis: Optional[\n            Union[\n                Tuple[torch.Tensor, torch.Tensor],\n                List[Tuple[torch.Tensor, torch.Tensor]],\n            ]\n        ] = None,\n        ###\n        guidance: torch.Tensor = None,  # TODO: this should probably be removed\n    ) -> Tuple[torch.Tensor]:\n        if kv_caches is not None:\n            kv_caches.set_mode(kv_caches_mode)\n\n        batch_size, num_channels, height, width = hidden_states.shape\n\n        timestep -= 1.0\n\n        if isinstance(encoder_hidden_states, list):\n            encoder_hidden_states = encoder_hidden_states[0]\n\n        # 1. RoPE\n        image_rotary_emb = freqs_cis\n        if image_rotary_emb is None:\n            image_rotary_emb = self.rotary_emb(hidden_states)\n        # 2. 
Patch & Timestep embeddings\n        p = self.config.patch_size\n        post_patch_height = height // p\n        post_patch_width = width // p\n\n        hidden_states = self.image_projector(hidden_states)\n        encoder_hidden_states = self.glyph_projector(encoder_hidden_states)\n        prior_embedding = self.prior_token_embedding(prior_token_id)\n        prior_embedding[prior_token_drop] *= 0.0\n        prior_hidden_states = self.prior_projector(prior_embedding)\n        # SP: when latents are H-sharded, hidden_states has fewer patches than prior_hidden_states.\n        # Shard prior_hidden_states along seq dim to match (prior is row-major, same as latent patches).\n        if (\n            get_sp_world_size() > 1\n            and prior_hidden_states.shape[1] != hidden_states.shape[1]\n        ):\n            rank = get_sp_parallel_rank()\n            sp_world_size = get_sp_world_size()\n            chunk = prior_hidden_states.shape[1] // sp_world_size\n            prior_hidden_states = prior_hidden_states[\n                :, rank * chunk : (rank + 1) * chunk, :\n            ]\n        hidden_states = hidden_states + prior_hidden_states\n\n        temb = self.time_condition_embed(\n            timestep, target_size, crop_coords, hidden_states.dtype\n        )\n        temb = F.silu(temb)\n\n        # 3. Transformer blocks\n        for idx, block in enumerate(self.transformer_blocks):\n            hidden_states, encoder_hidden_states = block(\n                hidden_states,\n                encoder_hidden_states,\n                temb,\n                image_rotary_emb,\n                attention_mask,\n                attention_kwargs,\n                kv_cache=kv_caches[idx] if kv_caches is not None else None,\n            )\n\n        # 4. Output norm & projection\n        hidden_states = self.norm_out(hidden_states, temb)\n        hidden_states = self.proj_out(hidden_states)\n\n        # 5. 
Unpatchify\n        hidden_states = hidden_states.reshape(\n            batch_size, post_patch_height, post_patch_width, -1, p, p\n        )\n        output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)\n\n        # Cast to float32 (.float()) as in the reference pipeline:\n        # https://github.com/zRzRzRzRzRzRzR/diffusers/blob/6cfc83b4abc5b083fef56a18ec4700f48ba3aaba/src/diffusers/pipelines/glm_image/pipeline_glm_image.py#L737\n        return output.float()\n\n\nEntryClass = GlmImageTransformer2DModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/helios.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Adapted from Helios diffusers transformer:\n# https://github.com/BestWishYsh/Helios\n\"\"\"\nHelios Transformer 3D model for video generation.\n\nImplements the HeliosTransformer3DModel with multi-term memory patches,\n3D rotary position embeddings, and per-block scale-shift modulation.\n\"\"\"\n\nimport math\nfrom functools import lru_cache\nfrom typing import Any\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sglang.multimodal_gen.configs.models.dits.helios import HeliosConfig\nfrom sglang.multimodal_gen.runtime.distributed import (\n    divide,\n    get_sp_world_size,\n    get_tp_world_size,\n)\nfrom sglang.multimodal_gen.runtime.distributed.communication_op import (\n    sequence_model_parallel_all_gather,\n)\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_group\nfrom sglang.multimodal_gen.runtime.layers.attention import USPAttention\nfrom sglang.multimodal_gen.runtime.layers.layernorm import (\n    FP32LayerNorm,\n    RMSNorm,\n    tensor_parallel_rms_norm,\n)\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    ColumnParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.mlp import MLP\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n)\nfrom sglang.multimodal_gen.runtime.layers.visual_embedding import (\n    ModulateProjection,\n    PatchEmbed,\n    TimestepEmbedder,\n)\nfrom sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n# ---------------------------------------------------------------------------\n# Utility functions\n# 
---------------------------------------------------------------------------\n\n\ndef pad_for_3d_conv(x, kernel_size):\n    \"\"\"Pad input to make it divisible by kernel_size using replicate mode.\"\"\"\n    b, c, t, h, w = x.shape\n    pt, ph, pw = kernel_size\n    pad_t = (pt - (t % pt)) % pt\n    pad_h = (ph - (h % ph)) % ph\n    pad_w = (pw - (w % pw)) % pw\n    return F.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode=\"replicate\")\n\n\ndef center_down_sample_3d(x, kernel_size):\n    \"\"\"Average pooling for 3D downsampling.\"\"\"\n    return F.avg_pool3d(x, kernel_size, stride=kernel_size)\n\n\ndef apply_rotary_emb_transposed(hidden_states, freqs_cis):\n    \"\"\"Apply rotary positional embeddings with transposed cos/sin format.\"\"\"\n    x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)\n    cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)\n    out = torch.empty_like(hidden_states)\n    out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2]\n    out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2]\n    return out.type_as(hidden_states)\n\n\n# ---------------------------------------------------------------------------\n# Output norm\n# ---------------------------------------------------------------------------\n\n\nclass HeliosOutputNorm(nn.Module):\n    def __init__(self, dim: int, eps: float = 1e-6):\n        super().__init__()\n        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)\n        self.norm = FP32LayerNorm(dim, eps, elementwise_affine=False)\n\n    def forward(self, hidden_states, temb, original_context_length):\n        temb = temb[:, -original_context_length:, :]\n        shift, scale = (\n            self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)\n        ).chunk(2, dim=2)\n        shift = shift.squeeze(2).to(hidden_states.device)\n        scale = scale.squeeze(2).to(hidden_states.device)\n        hidden_states = hidden_states[:, -original_context_length:, :]\n        
hidden_states = (\n            self.norm(hidden_states.float()) * (1 + scale) + shift\n        ).type_as(hidden_states)\n        return hidden_states\n\n\n# ---------------------------------------------------------------------------\n# Rotary Positional Embedding (3D)\n# ---------------------------------------------------------------------------\n\n\nclass HeliosRotaryPosEmbed(nn.Module):\n    \"\"\"3D rotary position embeddings for (time, height, width).\"\"\"\n\n    def __init__(self, rope_dim, theta):\n        super().__init__()\n        self.DT, self.DY, self.DX = rope_dim\n        self.theta = theta\n        # Store as plain attributes (not buffers) to avoid meta-device issues\n        # during FSDP loading. They'll be re-created on the correct device in forward.\n        self._freqs_base_t = None\n        self._freqs_base_y = None\n        self._freqs_base_x = None\n\n    def _get_freqs_base(self, dim):\n        return 1.0 / (\n            self.theta\n            ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)\n        )\n\n    def _ensure_freqs_base(self, device):\n        \"\"\"Lazily create frequency bases on the correct device.\"\"\"\n        if self._freqs_base_t is None or self._freqs_base_t.device != device:\n            self._freqs_base_t = self._get_freqs_base(self.DT).to(device)\n            self._freqs_base_y = self._get_freqs_base(self.DY).to(device)\n            self._freqs_base_x = self._get_freqs_base(self.DX).to(device)\n\n    @torch.no_grad()\n    def get_frequency_batched(self, freqs_base, pos):\n        freqs = torch.einsum(\"d,bthw->dbthw\", freqs_base, pos)\n        freqs = freqs.repeat_interleave(2, dim=0)\n        return freqs.cos(), freqs.sin()\n\n    @torch.no_grad()\n    @lru_cache(maxsize=32)\n    def _get_spatial_meshgrid(self, height, width, device_str):\n        device = torch.device(device_str)\n        grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)\n        grid_x_coords = 
torch.arange(width, device=device, dtype=torch.float32)\n        grid_y, grid_x = torch.meshgrid(grid_y_coords, grid_x_coords, indexing=\"ij\")\n        return grid_y, grid_x\n\n    @torch.no_grad()\n    def forward(self, frame_indices, height, width, device):\n        self._ensure_freqs_base(device)\n        batch_size = frame_indices.shape[0]\n        num_frames = frame_indices.shape[1]\n\n        frame_indices = frame_indices.to(device=device, dtype=torch.float32)\n        grid_y, grid_x = self._get_spatial_meshgrid(height, width, str(device))\n\n        grid_t = frame_indices[:, :, None, None].expand(\n            batch_size, num_frames, height, width\n        )\n        grid_y_batch = grid_y[None, None, :, :].expand(batch_size, num_frames, -1, -1)\n        grid_x_batch = grid_x[None, None, :, :].expand(batch_size, num_frames, -1, -1)\n\n        freqs_cos_t, freqs_sin_t = self.get_frequency_batched(\n            self._freqs_base_t, grid_t\n        )\n        freqs_cos_y, freqs_sin_y = self.get_frequency_batched(\n            self._freqs_base_y, grid_y_batch\n        )\n        freqs_cos_x, freqs_sin_x = self.get_frequency_batched(\n            self._freqs_base_x, grid_x_batch\n        )\n\n        result = torch.cat(\n            [\n                freqs_cos_t,\n                freqs_cos_y,\n                freqs_cos_x,\n                freqs_sin_t,\n                freqs_sin_y,\n                freqs_sin_x,\n            ],\n            dim=0,\n        )\n        return result.permute(1, 0, 2, 3, 4)\n\n\n# ---------------------------------------------------------------------------\n# Condition Embedder\n# ---------------------------------------------------------------------------\n\n\nclass HeliosTimeTextEmbedding(nn.Module):\n    \"\"\"Condition embedder combining timestep and text embeddings.\"\"\"\n\n    def __init__(self, dim, time_freq_dim, time_proj_dim, text_embed_dim):\n        super().__init__()\n        self.time_embedder = TimestepEmbedder(\n         
   dim, frequency_embedding_size=time_freq_dim, act_layer=\"silu\"\n        )\n        self.time_modulation = ModulateProjection(dim, factor=6, act_layer=\"silu\")\n        self.text_embedder = MLP(\n            text_embed_dim, dim, dim, bias=True, act_type=\"gelu_pytorch_tanh\"\n        )\n\n    def forward(\n        self, timestep, encoder_hidden_states, is_return_encoder_hidden_states=True\n    ):\n        temb = self.time_embedder(timestep)\n        timestep_proj = self.time_modulation(temb)\n\n        if encoder_hidden_states is not None and is_return_encoder_hidden_states:\n            encoder_hidden_states = self.text_embedder(encoder_hidden_states)\n\n        return temb, timestep_proj, encoder_hidden_states\n\n\n# ---------------------------------------------------------------------------\n# Self-Attention for Helios\n# ---------------------------------------------------------------------------\n\n\nclass HeliosSelfAttention(nn.Module):\n    \"\"\"Self-attention with RMSNorm Q/K, optional history key amplification.\"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        eps: float = 1e-6,\n        is_amplify_history: bool = False,\n        history_scale_mode: str = \"per_head\",\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        tp_size = get_tp_world_size()\n        self.local_num_heads = divide(num_heads, tp_size)\n\n        self.to_q = ColumnParallelLinear(\n            dim, dim, bias=True, gather_output=False, quant_config=quant_config\n        )\n        self.to_k = ColumnParallelLinear(\n            dim, dim, bias=True, gather_output=False, quant_config=quant_config\n        )\n        self.to_v = ColumnParallelLinear(\n            dim, dim, bias=True, gather_output=False, quant_config=quant_config\n        )\n        self.to_out = RowParallelLinear(\n           
 dim, dim, bias=True, reduce_results=True, quant_config=quant_config\n        )\n        self.norm_q = RMSNorm(dim, eps=eps)\n        self.norm_k = RMSNorm(dim, eps=eps)\n        self.tp_rmsnorm = tp_size > 1\n\n        self.attn = USPAttention(\n            num_heads=self.local_num_heads,\n            head_size=self.head_dim,\n            causal=False,\n            is_cross_attention=False,\n        )\n\n        self.is_amplify_history = is_amplify_history\n        if is_amplify_history:\n            if history_scale_mode == \"scalar\":\n                self.history_key_scale = nn.Parameter(torch.ones(1))\n            elif history_scale_mode == \"per_head\":\n                self.history_key_scale = nn.Parameter(torch.ones(num_heads))\n            else:\n                raise ValueError(f\"Unknown history_scale_mode: {history_scale_mode}\")\n            self.history_scale_mode = history_scale_mode\n            self.max_scale = 10.0\n\n    def forward(self, hidden_states, rotary_emb=None, original_context_length=None):\n        q, _ = self.to_q(hidden_states)\n        k, _ = self.to_k(hidden_states)\n        v, _ = self.to_v(hidden_states)\n\n        if self.tp_rmsnorm:\n            q = tensor_parallel_rms_norm(q, self.norm_q)\n            k = tensor_parallel_rms_norm(k, self.norm_k)\n        else:\n            q = self.norm_q(q)\n            k = self.norm_k(k)\n\n        q = q.unflatten(2, (self.local_num_heads, self.head_dim))\n        k = k.unflatten(2, (self.local_num_heads, self.head_dim))\n        v = v.unflatten(2, (self.local_num_heads, self.head_dim))\n\n        if rotary_emb is not None:\n            q = apply_rotary_emb_transposed(q, rotary_emb)\n            k = apply_rotary_emb_transposed(k, rotary_emb)\n\n        history_seq_len = (\n            hidden_states.shape[1] - original_context_length\n            if original_context_length is not None\n            else 0\n        )\n\n        if self.is_amplify_history and original_context_length is not 
None:\n            if history_seq_len > 0:\n                scale_key = 1.0 + torch.sigmoid(self.history_key_scale) * (\n                    self.max_scale - 1.0\n                )\n                if self.history_scale_mode == \"per_head\":\n                    scale_key = scale_key.view(1, 1, -1, 1)\n                k = torch.cat(\n                    [k[:, :history_seq_len] * scale_key, k[:, history_seq_len:]],\n                    dim=1,\n                )\n\n        x = self.attn(q, k, v, num_replicated_prefix=history_seq_len)\n        x = x.flatten(2)\n        x, _ = self.to_out(x)\n        return x\n\n\n# ---------------------------------------------------------------------------\n# Cross-Attention for Helios\n# ---------------------------------------------------------------------------\n\n\nclass HeliosCrossAttention(nn.Module):\n    \"\"\"Cross-attention with RMSNorm Q/K normalization.\"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        eps: float = 1e-6,\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        tp_size = get_tp_world_size()\n        self.local_num_heads = divide(num_heads, tp_size)\n\n        self.to_q = ColumnParallelLinear(\n            dim, dim, bias=True, gather_output=False, quant_config=quant_config\n        )\n        self.to_k = ColumnParallelLinear(\n            dim, dim, bias=True, gather_output=False, quant_config=quant_config\n        )\n        self.to_v = ColumnParallelLinear(\n            dim, dim, bias=True, gather_output=False, quant_config=quant_config\n        )\n        self.to_out = RowParallelLinear(\n            dim, dim, bias=True, reduce_results=True, quant_config=quant_config\n        )\n        self.norm_q = RMSNorm(dim, eps=eps)\n        self.norm_k = RMSNorm(dim, eps=eps)\n        self.tp_rmsnorm = tp_size > 1\n\n        
self.attn = USPAttention(\n            num_heads=self.local_num_heads,\n            head_size=self.head_dim,\n            causal=False,\n            skip_sequence_parallel=True,\n        )\n\n    def forward(self, hidden_states, encoder_hidden_states):\n        q, _ = self.to_q(hidden_states)\n        k, _ = self.to_k(encoder_hidden_states)\n        v, _ = self.to_v(encoder_hidden_states)\n\n        if self.tp_rmsnorm:\n            q = tensor_parallel_rms_norm(q, self.norm_q)\n            k = tensor_parallel_rms_norm(k, self.norm_k)\n        else:\n            q = self.norm_q(q)\n            k = self.norm_k(k)\n\n        q = q.unflatten(2, (self.local_num_heads, self.head_dim))\n        k = k.unflatten(2, (self.local_num_heads, self.head_dim))\n        v = v.unflatten(2, (self.local_num_heads, self.head_dim))\n\n        x = self.attn(q, k, v)\n        x = x.flatten(2)\n        x, _ = self.to_out(x)\n        return x\n\n\n# ---------------------------------------------------------------------------\n# Transformer Block\n# ---------------------------------------------------------------------------\n\n\nclass HeliosTransformerBlock(nn.Module):\n    \"\"\"\n    Single transformer block with self-attention, cross-attention, FFN,\n    and scale-shift modulation from timestep embeddings.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        ffn_dim: int,\n        num_heads: int,\n        cross_attn_norm: bool = True,\n        eps: float = 1e-6,\n        guidance_cross_attn: bool = True,\n        is_amplify_history: bool = False,\n        history_scale_mode: str = \"per_head\",\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__()\n\n        # 1. 
Self-attention\n        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)\n        self.attn1 = HeliosSelfAttention(\n            dim=dim,\n            num_heads=num_heads,\n            eps=eps,\n            is_amplify_history=is_amplify_history,\n            history_scale_mode=history_scale_mode,\n            quant_config=quant_config,\n        )\n\n        # 2. Cross-attention\n        self.attn2 = HeliosCrossAttention(\n            dim=dim,\n            num_heads=num_heads,\n            eps=eps,\n            quant_config=quant_config,\n        )\n        self.self_attn_residual_norm = (\n            FP32LayerNorm(dim, eps, elementwise_affine=True)\n            if cross_attn_norm\n            else nn.Identity()\n        )\n\n        # 3. Feed-forward\n        self.ffn = MLP(\n            dim, ffn_dim, act_type=\"gelu_pytorch_tanh\", quant_config=quant_config\n        )\n        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)\n\n        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)\n\n        # 4. 
Guidance cross-attention flag\n        self.guidance_cross_attn = guidance_cross_attn\n\n    def forward(\n        self,\n        hidden_states,\n        encoder_hidden_states,\n        temb,\n        rotary_emb,\n        original_context_length=None,\n    ):\n        if temb.ndim == 4:\n            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (\n                self.scale_shift_table.unsqueeze(0) + temb.float()\n            ).chunk(6, dim=2)\n            shift_msa = shift_msa.squeeze(2)\n            scale_msa = scale_msa.squeeze(2)\n            gate_msa = gate_msa.squeeze(2)\n            c_shift_msa = c_shift_msa.squeeze(2)\n            c_scale_msa = c_scale_msa.squeeze(2)\n            c_gate_msa = c_gate_msa.squeeze(2)\n        else:\n            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (\n                self.scale_shift_table + temb.float()\n            ).chunk(6, dim=1)\n\n        # 1. Self-attention\n        norm_hidden_states = (\n            self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa\n        ).type_as(hidden_states)\n        attn_output = self.attn1(\n            norm_hidden_states, rotary_emb, original_context_length\n        )\n        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(\n            hidden_states\n        )\n\n        # 2. 
Cross-attention\n        if self.guidance_cross_attn:\n            history_seq_len = hidden_states.shape[1] - original_context_length\n            history_hidden_states, current_hidden_states = torch.split(\n                hidden_states, [history_seq_len, original_context_length], dim=1\n            )\n            norm_hidden_states = self.self_attn_residual_norm(\n                current_hidden_states.float()\n            ).type_as(current_hidden_states)\n            attn_output = self.attn2(norm_hidden_states, encoder_hidden_states)\n            current_hidden_states = current_hidden_states + attn_output\n            hidden_states = torch.cat(\n                [history_hidden_states, current_hidden_states], dim=1\n            )\n        else:\n            norm_hidden_states = self.self_attn_residual_norm(\n                hidden_states.float()\n            ).type_as(hidden_states)\n            attn_output = self.attn2(norm_hidden_states, encoder_hidden_states)\n            hidden_states = hidden_states + attn_output\n\n        # 3. 
Feed-forward\n        norm_hidden_states = (\n            self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa\n        ).type_as(hidden_states)\n        ff_output = self.ffn(norm_hidden_states)\n        hidden_states = (\n            hidden_states.float() + ff_output.float() * c_gate_msa\n        ).type_as(hidden_states)\n\n        return hidden_states\n\n\n# ---------------------------------------------------------------------------\n# Main model\n# ---------------------------------------------------------------------------\n\n\nclass HeliosTransformer3DModel(CachableDiT, OffloadableDiTMixin):\n    \"\"\"\n    Helios Transformer 3D model for video generation.\n\n    Implements multi-scale history patches, 3D RoPE, and chunked denoising\n    with zero_history_timestep and guidance_cross_attn.\n    \"\"\"\n\n    _fsdp_shard_conditions = HeliosConfig()._fsdp_shard_conditions\n    _compile_conditions = HeliosConfig()._compile_conditions\n    _supported_attention_backends = HeliosConfig()._supported_attention_backends\n    param_names_mapping = HeliosConfig().param_names_mapping\n    reverse_param_names_mapping = HeliosConfig().reverse_param_names_mapping\n    lora_param_names_mapping = HeliosConfig().lora_param_names_mapping\n\n    def __init__(\n        self,\n        config: HeliosConfig,\n        hf_config: dict[str, Any],\n        quant_config: QuantizationConfig | None = None,\n    ) -> None:\n        super().__init__(config=config, hf_config=hf_config)\n\n        inner_dim = config.num_attention_heads * config.attention_head_dim\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.in_channels = config.in_channels\n        self.out_channels = config.out_channels\n        self.num_channels_latents = config.num_channels_latents\n        self.patch_size = config.patch_size\n        self.text_len = config.text_len\n        self.inner_dim = inner_dim\n\n        # Helios-specific 
config\n        self.zero_history_timestep = config.zero_history_timestep\n        self.has_multi_term_memory_patch = config.has_multi_term_memory_patch\n        self.guidance_cross_attn = config.guidance_cross_attn\n\n        # 1. Patch & position embedding\n        self.patch_embedding = PatchEmbed(\n            in_chans=config.in_channels,\n            embed_dim=inner_dim,\n            patch_size=config.patch_size,\n            flatten=False,\n        )\n\n        # 2. Rotary position embeddings\n        self.rope = HeliosRotaryPosEmbed(\n            rope_dim=config.rope_dim, theta=config.rope_theta\n        )\n\n        # 3. Multi-term memory patches\n        if self.has_multi_term_memory_patch:\n            self.patch_short = nn.Conv3d(\n                config.in_channels,\n                inner_dim,\n                kernel_size=config.patch_size,\n                stride=config.patch_size,\n            )\n            self.patch_mid = nn.Conv3d(\n                config.in_channels,\n                inner_dim,\n                kernel_size=tuple(2 * p for p in config.patch_size),\n                stride=tuple(2 * p for p in config.patch_size),\n            )\n            self.patch_long = nn.Conv3d(\n                config.in_channels,\n                inner_dim,\n                kernel_size=tuple(4 * p for p in config.patch_size),\n                stride=tuple(4 * p for p in config.patch_size),\n            )\n\n        # 4. Condition embeddings\n        self.condition_embedder = HeliosTimeTextEmbedding(\n            dim=inner_dim,\n            time_freq_dim=config.freq_dim,\n            time_proj_dim=inner_dim * 6,\n            text_embed_dim=config.text_dim,\n        )\n\n        # 5. 
Transformer blocks\n        self.blocks = nn.ModuleList(\n            [\n                HeliosTransformerBlock(\n                    dim=inner_dim,\n                    ffn_dim=config.ffn_dim,\n                    num_heads=config.num_attention_heads,\n                    cross_attn_norm=config.cross_attn_norm,\n                    eps=config.eps,\n                    guidance_cross_attn=config.guidance_cross_attn,\n                    is_amplify_history=config.is_amplify_history,\n                    history_scale_mode=config.history_scale_mode,\n                    quant_config=quant_config,\n                )\n                for _ in range(config.num_layers)\n            ]\n        )\n\n        # 6. Output norm & projection\n        self.norm_out = HeliosOutputNorm(inner_dim, config.eps)\n        self.proj_out = ColumnParallelLinear(\n            inner_dim,\n            config.out_channels * math.prod(config.patch_size),\n            bias=True,\n            gather_output=True,\n            quant_config=quant_config,\n        )\n\n        self.cnt = 0\n        self.__post_init__()\n        self.layer_names = [\"blocks\"]\n        self.sp_size = get_sp_world_size()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor | list[torch.Tensor],\n        timestep: torch.LongTensor,\n        # Stage 1 history inputs\n        indices_hidden_states=None,\n        indices_latents_history_short=None,\n        indices_latents_history_mid=None,\n        indices_latents_history_long=None,\n        latents_history_short=None,\n        latents_history_mid=None,\n        latents_history_long=None,\n        **kwargs,\n    ) -> torch.Tensor:\n        orig_dtype = hidden_states.dtype\n        if not isinstance(encoder_hidden_states, torch.Tensor):\n            encoder_hidden_states = encoder_hidden_states[0]\n\n        # Check if sequence parallelism is enabled\n        forward_batch = 
get_forward_context().forward_batch\n        if forward_batch is not None:\n            sequence_shard_enabled = (\n                forward_batch.enable_sequence_shard and self.sp_size > 1\n            )\n        else:\n            sequence_shard_enabled = False\n\n        batch_size = hidden_states.shape[0]\n        p_t, p_h, p_w = self.patch_size\n\n        # 1. Patch embed the noisy latents\n        hidden_states = self.patch_embedding(hidden_states)\n        _, _, post_patch_num_frames, post_patch_height, post_patch_width = (\n            hidden_states.shape\n        )\n\n        if indices_hidden_states is None:\n            indices_hidden_states = (\n                torch.arange(0, post_patch_num_frames)\n                .unsqueeze(0)\n                .expand(batch_size, -1)\n            )\n\n        hidden_states = hidden_states.flatten(2).transpose(1, 2)\n\n        # 2. Compute rotary embeddings\n        rotary_emb = self.rope(\n            frame_indices=indices_hidden_states,\n            height=post_patch_height,\n            width=post_patch_width,\n            device=hidden_states.device,\n        )\n        rotary_emb = rotary_emb.flatten(2).transpose(1, 2)\n        original_context_length = hidden_states.shape[1]\n\n        # Sequence parallelism: shard current tokens and RoPE across SP ranks\n        seq_shard_pad = 0\n        if sequence_shard_enabled:\n            sp_rank = get_sp_group().rank_in_group\n            seq_len = hidden_states.shape[1]\n            if seq_len % self.sp_size != 0:\n                seq_shard_pad = self.sp_size - (seq_len % self.sp_size)\n                hs_pad = torch.zeros(\n                    batch_size,\n                    seq_shard_pad,\n                    hidden_states.shape[2],\n                    dtype=hidden_states.dtype,\n                    device=hidden_states.device,\n                )\n                re_pad = torch.zeros(\n                    batch_size,\n                    seq_shard_pad,\n              
      rotary_emb.shape[2],\n                    dtype=rotary_emb.dtype,\n                    device=rotary_emb.device,\n                )\n                hidden_states = torch.cat([hidden_states, hs_pad], dim=1)\n                rotary_emb = torch.cat([rotary_emb, re_pad], dim=1)\n            local_seq_len = hidden_states.shape[1] // self.sp_size\n            hidden_states = hidden_states.view(\n                batch_size, self.sp_size, local_seq_len, -1\n            )[:, sp_rank, :, :].contiguous()\n            rotary_emb = rotary_emb.view(batch_size, self.sp_size, local_seq_len, -1)[\n                :, sp_rank, :, :\n            ].contiguous()\n            effective_context_length = local_seq_len\n        else:\n            effective_context_length = original_context_length\n\n        # 3. Process short history\n        if (\n            latents_history_short is not None\n            and indices_latents_history_short is not None\n        ):\n            latents_history_short = latents_history_short.to(hidden_states)\n            latents_history_short = self.patch_short(latents_history_short)\n            _, _, _, H1, W1 = latents_history_short.shape\n            latents_history_short = latents_history_short.flatten(2).transpose(1, 2)\n\n            rotary_emb_history_short = self.rope(\n                frame_indices=indices_latents_history_short,\n                height=H1,\n                width=W1,\n                device=latents_history_short.device,\n            )\n            rotary_emb_history_short = rotary_emb_history_short.flatten(2).transpose(\n                1, 2\n            )\n            hidden_states = torch.cat([latents_history_short, hidden_states], dim=1)\n            rotary_emb = torch.cat([rotary_emb_history_short, rotary_emb], dim=1)\n\n        # 4. 
Process mid history\n        if latents_history_mid is not None and indices_latents_history_mid is not None:\n            latents_history_mid = latents_history_mid.to(hidden_states)\n            latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4))\n            latents_history_mid = self.patch_mid(latents_history_mid)\n            latents_history_mid = latents_history_mid.flatten(2).transpose(1, 2)\n\n            rotary_emb_history_mid = self.rope(\n                frame_indices=indices_latents_history_mid,\n                height=H1,\n                width=W1,\n                device=latents_history_mid.device,\n            )\n            rotary_emb_history_mid = pad_for_3d_conv(rotary_emb_history_mid, (2, 2, 2))\n            rotary_emb_history_mid = center_down_sample_3d(\n                rotary_emb_history_mid, (2, 2, 2)\n            )\n            rotary_emb_history_mid = rotary_emb_history_mid.flatten(2).transpose(1, 2)\n\n            hidden_states = torch.cat([latents_history_mid, hidden_states], dim=1)\n            rotary_emb = torch.cat([rotary_emb_history_mid, rotary_emb], dim=1)\n\n        # 5. 
Process long history\n        if (\n            latents_history_long is not None\n            and indices_latents_history_long is not None\n        ):\n            latents_history_long = latents_history_long.to(hidden_states)\n            latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8))\n            latents_history_long = self.patch_long(latents_history_long)\n            latents_history_long = latents_history_long.flatten(2).transpose(1, 2)\n\n            rotary_emb_history_long = self.rope(\n                frame_indices=indices_latents_history_long,\n                height=H1,\n                width=W1,\n                device=latents_history_long.device,\n            )\n            rotary_emb_history_long = pad_for_3d_conv(\n                rotary_emb_history_long, (4, 4, 4)\n            )\n            rotary_emb_history_long = center_down_sample_3d(\n                rotary_emb_history_long, (4, 4, 4)\n            )\n            rotary_emb_history_long = rotary_emb_history_long.flatten(2).transpose(1, 2)\n\n            hidden_states = torch.cat([latents_history_long, hidden_states], dim=1)\n            rotary_emb = torch.cat([rotary_emb_history_long, rotary_emb], dim=1)\n\n        history_context_length = hidden_states.shape[1] - effective_context_length\n\n        # 6. 
Compute condition embeddings\n        if indices_hidden_states is not None and self.zero_history_timestep:\n            timestep_t0 = torch.zeros(\n                (1,), dtype=timestep.dtype, device=timestep.device\n            )\n            temb_t0, timestep_proj_t0, _ = self.condition_embedder(\n                timestep_t0,\n                encoder_hidden_states,\n                is_return_encoder_hidden_states=False,\n            )\n            temb_t0 = temb_t0.unsqueeze(1).expand(\n                batch_size, history_context_length, -1\n            )\n            timestep_proj_t0 = (\n                timestep_proj_t0.unflatten(-1, (6, -1))\n                .view(1, 6, 1, -1)\n                .expand(batch_size, -1, history_context_length, -1)\n            )\n\n        temb, timestep_proj, encoder_hidden_states = self.condition_embedder(\n            timestep, encoder_hidden_states\n        )\n        timestep_proj = timestep_proj.unflatten(-1, (6, -1))\n\n        if indices_hidden_states is not None and not self.zero_history_timestep:\n            main_repeat_size = hidden_states.shape[1]\n        else:\n            main_repeat_size = effective_context_length\n        temb = temb.view(batch_size, 1, -1).expand(batch_size, main_repeat_size, -1)\n        timestep_proj = timestep_proj.view(batch_size, 6, 1, -1).expand(\n            batch_size, 6, main_repeat_size, -1\n        )\n\n        if indices_hidden_states is not None and self.zero_history_timestep:\n            temb = torch.cat([temb_t0, temb], dim=1)\n            timestep_proj = torch.cat([timestep_proj_t0, timestep_proj], dim=2)\n\n        if timestep_proj.ndim == 4:\n            timestep_proj = timestep_proj.permute(0, 2, 1, 3)\n\n        # 7. 
Transformer blocks\n        hidden_states = hidden_states.contiguous()\n        encoder_hidden_states = encoder_hidden_states.contiguous()\n        rotary_emb = rotary_emb.contiguous()\n\n        for block in self.blocks:\n            hidden_states = block(\n                hidden_states,\n                encoder_hidden_states,\n                timestep_proj,\n                rotary_emb,\n                effective_context_length,\n            )\n\n        self.cnt += 1\n\n        # SP: all-gather current tokens before output\n        if sequence_shard_enabled:\n            current_tokens = hidden_states[:, -local_seq_len:, :].contiguous()\n            current_tokens = sequence_model_parallel_all_gather(current_tokens, dim=1)\n            if seq_shard_pad > 0:\n                current_tokens = current_tokens[:, :original_context_length, :]\n            hidden_states = current_tokens\n            # Re-create temb for norm_out (all current tokens share same timestep)\n            temb = temb[:, :1, :].expand(batch_size, original_context_length, -1)\n\n        # 8. Output norm & projection\n        hidden_states = self.norm_out(hidden_states, temb, original_context_length)\n        hidden_states, _ = self.proj_out(hidden_states)\n\n        # 9. Unpatchify\n        hidden_states = hidden_states.reshape(\n            batch_size,\n            post_patch_num_frames,\n            post_patch_height,\n            post_patch_width,\n            p_t,\n            p_h,\n            p_w,\n            -1,\n        )\n        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)\n        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)\n\n        return output\n\n\nEntryClass = HeliosTransformer3DModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/hunyuan3d.py",
    "content": "# Copied and adapted from: https://github.com/Tencent-Hunyuan/Hunyuan3D-2\nfrom __future__ import annotations\n\nimport math\nfrom dataclasses import dataclass\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\n\nfrom sglang.multimodal_gen.configs.models.dits.hunyuan3d import (\n    Hunyuan3DDiTArchConfig,\n    Hunyuan3DDiTConfig,\n)\nfrom sglang.multimodal_gen.runtime.distributed import divide\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import get_tp_world_size\nfrom sglang.multimodal_gen.runtime.layers.attention import LocalAttention\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    MergedColumnParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.mlp import MLP\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass MixedRowParallelLinear(RowParallelLinear):\n    \"\"\"RowParallel for inputs concatenated from multiple separately-sharded sources.\"\"\"\n\n    def __init__(self, input_sizes: list[int], output_size: int, **kwargs):\n        self.input_sizes = input_sizes\n        super().__init__(sum(input_sizes), output_size, **kwargs)\n\n    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):\n        input_dim = getattr(param, \"input_dim\", None)\n        if input_dim is not None:\n            shards = []\n            offset = 0\n            for sz in self.input_sizes:\n                part = loaded_weight.narrow(input_dim, offset, sz)\n                per_rank = sz // self.tp_size\n                shard = part.narrow(input_dim, self.tp_rank * per_rank, per_rank)\n            
    shards.append(shard)\n                offset += sz\n            param.data.copy_(torch.cat(shards, dim=input_dim))\n        else:\n            param.data.copy_(loaded_weight)\n\n\ndef _flux_timestep_embedding(\n    t: torch.Tensor, dim, max_period=10000, time_factor: float = 1000.0\n):\n    \"\"\"Create sinusoidal timestep embeddings for Flux-style model.\"\"\"\n    t = time_factor * t\n    half = dim // 2\n    freqs = torch.exp(\n        -math.log(max_period)\n        * torch.arange(start=0, end=half, dtype=torch.float32)\n        / half\n    ).to(t.device)\n\n    args = t[:, None].float() * freqs[None]\n    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n    if dim % 2:\n        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n    if torch.is_floating_point(t):\n        embedding = embedding.to(t)\n    return embedding\n\n\nclass _FluxGELU(nn.Module):\n    def __init__(self, approximate=\"tanh\"):\n        super().__init__()\n        self.approximate = approximate\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return F.gelu(x, approximate=self.approximate)\n\n\nclass _FluxMLPEmbedder(nn.Module):\n    def __init__(self, in_dim: int, hidden_dim: int):\n        super().__init__()\n        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)\n        self.silu = nn.SiLU()\n        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.out_layer(self.silu(self.in_layer(x)))\n\n\nclass _FluxRMSNorm(nn.Module):\n    def __init__(self, dim: int):\n        super().__init__()\n        self.scale = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x: torch.Tensor):\n        x_dtype = x.dtype\n        x = x.float()\n        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)\n        return (x * rrms).to(dtype=x_dtype) * self.scale\n\n\nclass _FluxQKNorm(nn.Module):\n    def __init__(self, 
dim: int):\n        super().__init__()\n        self.query_norm = _FluxRMSNorm(dim)\n        self.key_norm = _FluxRMSNorm(dim)\n\n    def forward(\n        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        q = self.query_norm(q)\n        k = self.key_norm(k)\n        return q.to(v), k.to(v)\n\n\nclass _FluxSelfAttention(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int = 8,\n        qkv_bias: bool = False,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n    ):\n        super().__init__()\n        tp_size = get_tp_world_size()\n        self.num_heads = num_heads\n        self.local_num_heads = divide(num_heads, tp_size)\n        self.head_dim = dim // num_heads\n\n        self.qkv = MergedColumnParallelLinear(\n            dim, [dim, dim, dim], bias=qkv_bias, gather_output=False\n        )\n        self.norm = _FluxQKNorm(self.head_dim)\n        self.proj = RowParallelLinear(dim, dim, bias=True, input_is_parallel=True)\n\n        if supported_attention_backends is None:\n            supported_attention_backends = {\n                AttentionBackendEnum.FA,\n                AttentionBackendEnum.TORCH_SDPA,\n            }\n        self.local_attn = LocalAttention(\n            num_heads=self.local_num_heads,\n            head_size=self.head_dim,\n            causal=False,\n            supported_attention_backends=supported_attention_backends,\n        )\n\n    def forward(self, x: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:\n        qkv, _ = self.qkv(x)\n        B, L, _ = qkv.shape\n        qkv = qkv.view(B, L, 3, self.local_num_heads, self.head_dim)\n        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v_for_norm = v.transpose(1, 2)\n        q, k = self.norm(q, k, v_for_norm)\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        x = 
self.local_attn(q, k, v)\n        x = x.flatten(2)\n        x, _ = self.proj(x)\n        return x\n\n\n@dataclass\nclass _FluxModulationOut:\n    shift: torch.Tensor\n    scale: torch.Tensor\n    gate: torch.Tensor\n\n\nclass _FluxModulation(nn.Module):\n    def __init__(self, dim: int, double: bool):\n        super().__init__()\n        self.is_double = double\n        self.multiplier = 6 if double else 3\n        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)\n\n    def forward(\n        self, vec: torch.Tensor\n    ) -> Tuple[_FluxModulationOut, Optional[_FluxModulationOut]]:\n        out = self.lin(F.silu(vec))[:, None, :]\n        out = out.chunk(self.multiplier, dim=-1)\n\n        return (\n            _FluxModulationOut(*out[:3]),\n            _FluxModulationOut(*out[3:]) if self.is_double else None,\n        )\n\n\nclass _FluxDoubleStreamBlock(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int,\n        num_heads: int,\n        mlp_ratio: float,\n        qkv_bias: bool = False,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n    ):\n        super().__init__()\n        mlp_hidden_dim = int(hidden_size * mlp_ratio)\n        tp_size = get_tp_world_size()\n        self.num_heads = num_heads\n        self.local_num_heads = divide(num_heads, tp_size)\n        self.hidden_size = hidden_size\n        self.head_dim = hidden_size // num_heads\n        self.img_mod = _FluxModulation(hidden_size, double=True)\n        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.img_attn = _FluxSelfAttention(\n            dim=hidden_size,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            supported_attention_backends=supported_attention_backends,\n        )\n\n        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.img_mlp = MLP(hidden_size, mlp_hidden_dim, act_type=\"gelu_pytorch_tanh\")\n\n        
self.txt_mod = _FluxModulation(hidden_size, double=True)\n        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.txt_attn = _FluxSelfAttention(\n            dim=hidden_size,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            supported_attention_backends=supported_attention_backends,\n        )\n\n        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.txt_mlp = MLP(hidden_size, mlp_hidden_dim, act_type=\"gelu_pytorch_tanh\")\n\n        if supported_attention_backends is None:\n            supported_attention_backends = {\n                AttentionBackendEnum.FA,\n                AttentionBackendEnum.TORCH_SDPA,\n            }\n        self.local_attn_joint = LocalAttention(\n            num_heads=self.local_num_heads,\n            head_size=self.head_dim,\n            causal=False,\n            supported_attention_backends=supported_attention_backends,\n        )\n\n    def forward(\n        self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n\n        img_mod1, img_mod2 = self.img_mod(vec)\n        txt_mod1, txt_mod2 = self.txt_mod(vec)\n\n        img_modulated = self.img_norm1(img)\n        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift\n\n        B, img_L, _ = img_modulated.shape\n        img_qkv, _ = self.img_attn.qkv(img_modulated)\n        img_qkv = img_qkv.view(B, img_L, 3, self.local_num_heads, self.head_dim)\n        img_q, img_k, img_v = img_qkv[:, :, 0], img_qkv[:, :, 1], img_qkv[:, :, 2]\n        img_q_t = img_q.transpose(1, 2)\n        img_k_t = img_k.transpose(1, 2)\n        img_v_t = img_v.transpose(1, 2)\n        img_q_t, img_k_t = self.img_attn.norm(img_q_t, img_k_t, img_v_t)\n        img_q = img_q_t.transpose(1, 2)\n        img_k = img_k_t.transpose(1, 2)\n\n        txt_modulated = self.txt_norm1(txt)\n        txt_modulated = (1 + 
txt_mod1.scale) * txt_modulated + txt_mod1.shift\n        txt_L = txt_modulated.shape[1]\n        txt_qkv, _ = self.txt_attn.qkv(txt_modulated)\n        txt_qkv = txt_qkv.view(B, txt_L, 3, self.local_num_heads, self.head_dim)\n        txt_q, txt_k, txt_v = txt_qkv[:, :, 0], txt_qkv[:, :, 1], txt_qkv[:, :, 2]\n        txt_q_t = txt_q.transpose(1, 2)\n        txt_k_t = txt_k.transpose(1, 2)\n        txt_v_t = txt_v.transpose(1, 2)\n        txt_q_t, txt_k_t = self.txt_attn.norm(txt_q_t, txt_k_t, txt_v_t)\n        txt_q = txt_q_t.transpose(1, 2)\n        txt_k = txt_k_t.transpose(1, 2)\n\n        q = torch.cat((txt_q, img_q), dim=1)\n        k = torch.cat((txt_k, img_k), dim=1)\n        v = torch.cat((txt_v, img_v), dim=1)\n\n        attn = self.local_attn_joint(q, k, v)\n        attn = attn.flatten(2)\n\n        txt_attn, img_attn = attn[:, :txt_L], attn[:, txt_L:]\n\n        img_proj, _ = self.img_attn.proj(img_attn)\n        img = img + img_mod1.gate * img_proj\n        img = img + img_mod2.gate * self.img_mlp(\n            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift\n        )\n\n        txt_proj, _ = self.txt_attn.proj(txt_attn)\n        txt = txt + txt_mod1.gate * txt_proj\n        txt = txt + txt_mod2.gate * self.txt_mlp(\n            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift\n        )\n        return img, txt\n\n\nclass _FluxSingleStreamBlock(nn.Module):\n    \"\"\"\n    A DiT block with parallel linear layers as described in\n    https://arxiv.org/abs/2302.05442 and adapted modulation interface.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        num_heads: int,\n        mlp_ratio: float = 4.0,\n        qk_scale: Optional[float] = None,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n    ):\n        super().__init__()\n\n        tp_size = get_tp_world_size()\n        self.hidden_dim = hidden_size\n        self.num_heads = num_heads\n        self.local_num_heads 
= divide(num_heads, tp_size)\n        self.head_dim = hidden_size // num_heads\n        self.tp_size = tp_size\n\n        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)\n        self.linear1 = MergedColumnParallelLinear(\n            hidden_size,\n            [hidden_size, hidden_size, hidden_size, self.mlp_hidden_dim],\n            bias=True,\n            gather_output=False,\n        )\n        self.linear2 = MixedRowParallelLinear(\n            [hidden_size, self.mlp_hidden_dim],\n            hidden_size,\n            bias=True,\n            input_is_parallel=True,\n        )\n\n        self.norm = _FluxQKNorm(self.head_dim)\n\n        self.hidden_size = hidden_size\n        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n\n        self.mlp_act = _FluxGELU(approximate=\"tanh\")\n        self.modulation = _FluxModulation(hidden_size, double=False)\n\n        if supported_attention_backends is None:\n            supported_attention_backends = {\n                AttentionBackendEnum.FA,\n                AttentionBackendEnum.TORCH_SDPA,\n            }\n        self.local_attn = LocalAttention(\n            num_heads=self.local_num_heads,\n            head_size=self.head_dim,\n            causal=False,\n            supported_attention_backends=supported_attention_backends,\n        )\n\n    def forward(\n        self, x: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor\n    ) -> torch.Tensor:\n        mod, _ = self.modulation(vec)\n\n        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift\n        linear1_out, _ = self.linear1(x_mod)\n        local_qkv_dim = 3 * self.head_dim * self.local_num_heads\n        local_mlp_dim = self.mlp_hidden_dim // self.tp_size\n        qkv, mlp = torch.split(linear1_out, [local_qkv_dim, local_mlp_dim], dim=-1)\n\n        B, L, _ = qkv.shape\n        qkv = qkv.view(B, L, 3, self.local_num_heads, self.head_dim)\n        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]\n        q_t = 
q.transpose(1, 2)\n        k_t = k.transpose(1, 2)\n        v_t = v.transpose(1, 2)\n        q_t, k_t = self.norm(q_t, k_t, v_t)\n        q = q_t.transpose(1, 2)\n        k = k_t.transpose(1, 2)\n\n        attn = self.local_attn(q, k, v)\n        attn = attn.flatten(2)\n\n        output, _ = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))\n        return x + mod.gate * output\n\n\nclass _FluxLastLayer(nn.Module):\n    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):\n        super().__init__()\n        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.linear = nn.Linear(\n            hidden_size, patch_size * patch_size * out_channels, bias=True\n        )\n        self.adaLN_modulation = nn.Sequential(\n            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)\n        )\n\n    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:\n        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)\n        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]\n        x = self.linear(x)\n        return x\n\n\nclass Hunyuan3D2DiT(CachableDiT, OffloadableDiTMixin):\n    \"\"\"Hunyuan3D DiT model (Flux-style architecture for Hunyuan3D-2.0).\"\"\"\n\n    _aliases = [\"hy3dgen.shapegen.models.Hunyuan3DDiT\"]\n\n    param_names_mapping = Hunyuan3DDiTConfig().param_names_mapping\n\n    @classmethod\n    def build_config_from_params(cls, params: dict) -> Hunyuan3DDiTConfig:\n        \"\"\"Build a DiTConfig from YAML-style parameter dict.\"\"\"\n        field_mapping = {\n            \"num_heads\": \"num_attention_heads\",\n            \"depth\": \"num_layers\",\n            \"depth_single_blocks\": \"num_single_layers\",\n        }\n        arch_kwargs = {}\n        for k, v in params.items():\n            if k in (\"ckpt_path\", \"supported_attention_backends\"):\n                continue\n            mapped = field_mapping.get(k, k)\n            if k 
== \"axes_dim\" and isinstance(v, list):\n                v = tuple(v)\n            arch_kwargs[mapped] = v\n        return Hunyuan3DDiTConfig(arch_config=Hunyuan3DDiTArchConfig(**arch_kwargs))\n\n    def __init__(\n        self,\n        config: Hunyuan3DDiTConfig,\n        hf_config: dict | None = None,\n        **kwargs,\n    ):\n        super().__init__(config=config, hf_config=hf_config or {}, **kwargs)\n        arch = config.arch_config\n\n        in_channels = arch.in_channels\n        context_in_dim = arch.context_in_dim\n        hidden_size = arch.hidden_size\n        mlp_ratio = arch.mlp_ratio\n        num_heads = arch.num_attention_heads\n        depth = arch.num_layers\n        depth_single_blocks = arch.num_single_layers\n        axes_dim = list(arch.axes_dim)\n        theta = arch.theta\n        qkv_bias = arch.qkv_bias\n        time_factor = arch.time_factor\n        guidance_embed = arch.guidance_embed\n        supported_attention_backends = arch._supported_attention_backends\n\n        self.in_channels = in_channels\n        self.context_in_dim = context_in_dim\n        self.hidden_size = hidden_size\n        self.mlp_ratio = mlp_ratio\n        self.num_heads = num_heads\n        self.num_attention_heads = num_heads\n        self.depth = depth\n        self.depth_single_blocks = depth_single_blocks\n        self.axes_dim = axes_dim\n        self.theta = theta\n        self.qkv_bias = qkv_bias\n        self.time_factor = time_factor\n        self.out_channels = self.in_channels\n        self.num_channels_latents = self.in_channels\n        self.guidance_embed = guidance_embed\n\n        if hidden_size % num_heads != 0:\n            raise ValueError(\n                f\"Hidden size {hidden_size} must be divisible by num_heads {num_heads}\"\n            )\n        pe_dim = hidden_size // num_heads\n        if sum(axes_dim) != pe_dim:\n            raise ValueError(f\"Got {axes_dim} but expected positional dim {pe_dim}\")\n        self.latent_in = 
nn.Linear(self.in_channels, self.hidden_size, bias=True)\n        self.time_in = _FluxMLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)\n        self.cond_in = nn.Linear(context_in_dim, self.hidden_size)\n        self.guidance_in = (\n            _FluxMLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)\n            if guidance_embed\n            else nn.Identity()\n        )\n\n        self.double_blocks = nn.ModuleList(\n            [\n                _FluxDoubleStreamBlock(\n                    self.hidden_size,\n                    self.num_heads,\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    supported_attention_backends=supported_attention_backends,\n                )\n                for _ in range(depth)\n            ]\n        )\n\n        self.single_blocks = nn.ModuleList(\n            [\n                _FluxSingleStreamBlock(\n                    self.hidden_size,\n                    self.num_heads,\n                    mlp_ratio=mlp_ratio,\n                    supported_attention_backends=supported_attention_backends,\n                )\n                for _ in range(depth_single_blocks)\n            ]\n        )\n\n        self.final_layer = _FluxLastLayer(self.hidden_size, 1, self.out_channels)\n\n        # OffloadableDiTMixin\n        self.layer_names = [\"double_blocks\", \"single_blocks\"]\n\n    def forward(\n        self,\n        x,\n        t,\n        contexts,\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass for denoising.\"\"\"\n\n        cond = contexts[\"main\"]\n\n        latent = self.latent_in(x)\n\n        t_emb = _flux_timestep_embedding(t, 256, self.time_factor).to(\n            dtype=latent.dtype\n        )\n\n        vec = self.time_in(t_emb)\n\n        if self.guidance_embed:\n            guidance = kwargs.get(\"guidance\", None)\n            if guidance is None:\n                raise ValueError(\n                    \"Didn't get guidance 
strength for guidance distilled model.\"\n                )\n            vec = vec + self.guidance_in(\n                _flux_timestep_embedding(guidance, 256, self.time_factor)\n            )\n\n        cond = self.cond_in(cond)\n\n        pe = None\n\n        # Double blocks\n        for i, block in enumerate(self.double_blocks):\n            latent, cond = block(img=latent, txt=cond, vec=vec, pe=pe)\n        latent = torch.cat((cond, latent), 1)\n\n        # Single blocks\n        for i, block in enumerate(self.single_blocks):\n            latent = block(latent, vec=vec, pe=pe)\n\n        latent = latent[:, cond.shape[1] :, ...]\n        latent = self.final_layer(latent, vec)\n        return latent\n\n\nimport copy\nimport json\nimport os as _os\n\nfrom diffusers.models import UNet2DConditionModel\nfrom diffusers.models.attention_processor import Attention as DiffusersAttention\nfrom diffusers.models.transformers.transformer_2d import BasicTransformerBlock\n\n\ndef _chunked_feed_forward(\n    ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int\n):\n    \"\"\"Feed forward with chunking to save memory.\"\"\"\n    if hidden_states.shape[chunk_dim] % chunk_size != 0:\n        raise ValueError(\n            f\"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]}\"\n            f\" has to be divisible by chunk size: {chunk_size}.\"\n            f\" Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.\"\n        )\n\n    num_chunks = hidden_states.shape[chunk_dim] // chunk_size\n    ff_output = torch.cat(\n        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],\n        dim=chunk_dim,\n    )\n    return ff_output\n\n\nclass SGLangAttentionWrapper(torch.nn.Module):\n    \"\"\"Drop-in replacement for DiffusersAttention that uses sglang's attention backend.\"\"\"\n\n    _SUPPORTED_BACKENDS = {AttentionBackendEnum.FA, AttentionBackendEnum.TORCH_SDPA}\n\n    def 
__init__(\n        self,\n        query_dim: int,\n        heads: int = 8,\n        dim_head: int = 64,\n        dropout: float = 0.0,\n        bias: bool = False,\n        cross_attention_dim: int | None = None,\n        out_bias: bool = True,\n    ) -> None:\n        super().__init__()\n        self.inner_dim = dim_head * heads\n        self.heads = heads\n        self.dim_head = dim_head\n        self.query_dim = query_dim\n        cross_attention_dim = cross_attention_dim or query_dim\n\n        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)\n        self.to_k = nn.Linear(cross_attention_dim, self.inner_dim, bias=bias)\n        self.to_v = nn.Linear(cross_attention_dim, self.inner_dim, bias=bias)\n        self.to_out = nn.ModuleList(\n            [nn.Linear(self.inner_dim, query_dim, bias=out_bias), nn.Dropout(dropout)]\n        )\n\n        from sglang.multimodal_gen.runtime.layers.attention.selector import (\n            get_attn_backend,\n        )\n\n        attn_backend = get_attn_backend(\n            dim_head, torch.float16, self._SUPPORTED_BACKENDS\n        )\n        impl_cls = attn_backend.get_impl_cls()\n        self.attn_impl = impl_cls(\n            num_heads=heads,\n            head_size=dim_head,\n            softmax_scale=dim_head**-0.5,\n            num_kv_heads=heads,\n            causal=False,\n        )\n        self._attn_backend_name = attn_backend.get_enum().name\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n\n        B, N_q, _ = hidden_states.shape\n        _, N_kv, _ = encoder_hidden_states.shape\n\n        q = self.to_q(hidden_states).view(B, N_q, self.heads, self.dim_head)\n        k = self.to_k(encoder_hidden_states).view(B, N_kv, self.heads, 
self.dim_head)\n        v = self.to_v(encoder_hidden_states).view(B, N_kv, self.heads, self.dim_head)\n\n        from sglang.multimodal_gen.runtime.managers.forward_context import (\n            get_forward_context,\n        )\n\n        ctx = get_forward_context()\n        out = self.attn_impl.forward(q, k, v, attn_metadata=ctx.attn_metadata)\n        out = out.reshape(B, N_q, self.inner_dim)\n\n        out = self.to_out[0](out)\n        out = self.to_out[1](out)\n        return out\n\n\nclass Basic2p5DTransformerBlock(torch.nn.Module):\n    \"\"\"2.5D Transformer block with Multiview Attention (MVA) and Reference View Attention (RVA).\"\"\"\n\n    def __init__(\n        self,\n        transformer: BasicTransformerBlock,\n        layer_name: str,\n        use_ma: bool = True,\n        use_ra: bool = True,\n        is_turbo: bool = False,\n        use_sglang_attn: bool = True,\n    ) -> None:\n        super().__init__()\n        self.transformer = transformer\n        self.layer_name = layer_name\n        self.use_ma = use_ma\n        self.use_ra = use_ra\n        self.is_turbo = is_turbo\n        self.use_sglang_attn = use_sglang_attn and not is_turbo\n\n        attn_cls = (\n            SGLangAttentionWrapper if self.use_sglang_attn else DiffusersAttention\n        )\n        attn_kwargs = dict(\n            query_dim=self.dim,\n            heads=self.num_attention_heads,\n            dim_head=self.attention_head_dim,\n            dropout=self.dropout,\n            bias=self.attention_bias,\n            cross_attention_dim=None,\n            upcast_attention=self.attn1.upcast_attention,\n            out_bias=True,\n        )\n        if self.use_sglang_attn:\n            attn_kwargs.pop(\"upcast_attention\")\n\n        if self.use_ma:\n            self.attn_multiview = attn_cls(**attn_kwargs)\n\n        if self.use_ra:\n            self.attn_refview = attn_cls(**attn_kwargs)\n\n        if self.is_turbo:\n            self._initialize_attn_weights()\n\n    def 
_initialize_attn_weights(self):\n        \"\"\"Initialize attention weights for turbo mode.\"\"\"\n        if self.use_ma:\n            self.attn_multiview.load_state_dict(self.attn1.state_dict())\n            with torch.no_grad():\n                for layer in self.attn_multiview.to_out:\n                    for param in layer.parameters():\n                        param.zero_()\n        if self.use_ra:\n            self.attn_refview.load_state_dict(self.attn1.state_dict())\n            with torch.no_grad():\n                for layer in self.attn_refview.to_out:\n                    for param in layer.parameters():\n                        param.zero_()\n\n    def __getattr__(self, name: str):\n        try:\n            return super().__getattr__(name)\n        except AttributeError:\n            return getattr(self.transformer, name)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        timestep: Optional[torch.LongTensor] = None,\n        cross_attention_kwargs: dict = None,\n        class_labels: Optional[torch.LongTensor] = None,\n        added_cond_kwargs: Optional[dict] = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass with MVA and RVA support.\"\"\"\n        batch_size = hidden_states.shape[0]\n\n        cross_attention_kwargs = (\n            cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}\n        )\n        num_in_batch = cross_attention_kwargs.pop(\"num_in_batch\", 1)\n        mode = cross_attention_kwargs.pop(\"mode\", None)\n\n        if not self.is_turbo:\n            mva_scale = cross_attention_kwargs.pop(\"mva_scale\", 1.0)\n            ref_scale = cross_attention_kwargs.pop(\"ref_scale\", 1.0)\n        else:\n            position_attn_mask = cross_attention_kwargs.pop(\"position_attn_mask\", None)\n    
        position_voxel_indices = cross_attention_kwargs.pop(\n                \"position_voxel_indices\", None\n            )\n            mva_scale = 1.0\n            ref_scale = 1.0\n\n        condition_embed_dict = cross_attention_kwargs.pop(\"condition_embed_dict\", None)\n\n        # Normalization\n        if self.norm_type == \"ada_norm\":\n            norm_hidden_states = self.norm1(hidden_states, timestep)\n        elif self.norm_type == \"ada_norm_zero\":\n            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(\n                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype\n            )\n        elif self.norm_type in [\"layer_norm\", \"layer_norm_i2vgen\"]:\n            norm_hidden_states = self.norm1(hidden_states)\n        elif self.norm_type == \"ada_norm_continuous\":\n            norm_hidden_states = self.norm1(\n                hidden_states, added_cond_kwargs[\"pooled_text_emb\"]\n            )\n        elif self.norm_type == \"ada_norm_single\":\n            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (\n                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)\n            ).chunk(6, dim=1)\n            norm_hidden_states = self.norm1(hidden_states)\n            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa\n        else:\n            raise ValueError(\"Incorrect norm used\")\n\n        if self.pos_embed is not None:\n            norm_hidden_states = self.pos_embed(norm_hidden_states)\n\n        # Prepare GLIGEN inputs\n        cross_attention_kwargs = (\n            cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}\n        )\n        gligen_kwargs = cross_attention_kwargs.pop(\"gligen\", None)\n\n        # Self-attention\n        attn_output = self.attn1(\n            norm_hidden_states,\n            encoder_hidden_states=(\n                encoder_hidden_states if self.only_cross_attention 
else None\n            ),\n            attention_mask=attention_mask,\n            **cross_attention_kwargs,\n        )\n\n        if self.norm_type == \"ada_norm_zero\":\n            attn_output = gate_msa.unsqueeze(1) * attn_output\n        elif self.norm_type == \"ada_norm_single\":\n            attn_output = gate_msa * attn_output\n\n        hidden_states = attn_output + hidden_states\n        if hidden_states.ndim == 4:\n            hidden_states = hidden_states.squeeze(1)\n\n        # Reference Attention - Write mode\n        if mode is not None and \"w\" in mode:\n            condition_embed_dict[self.layer_name] = rearrange(\n                norm_hidden_states, \"(b n) l c -> b (n l) c\", n=num_in_batch\n            )\n\n        # Reference Attention - Read mode\n        if mode is not None and \"r\" in mode and self.use_ra:\n            condition_embed = (\n                condition_embed_dict[self.layer_name]\n                .unsqueeze(1)\n                .repeat(1, num_in_batch, 1, 1)\n            )\n            condition_embed = rearrange(condition_embed, \"b n l c -> (b n) l c\")\n\n            attn_output = self.attn_refview(\n                norm_hidden_states,\n                encoder_hidden_states=condition_embed,\n                attention_mask=None,\n                **cross_attention_kwargs,\n            )\n\n            # ref_scale is 1.0 in turbo mode; always define ref_scale_timing\n            ref_scale_timing = ref_scale\n            if isinstance(ref_scale, torch.Tensor):\n                ref_scale_timing = (\n                    ref_scale.unsqueeze(1).repeat(1, num_in_batch).view(-1)\n                )\n                for _ in range(attn_output.ndim - 1):\n                    ref_scale_timing = ref_scale_timing.unsqueeze(-1)\n\n            hidden_states = ref_scale_timing * attn_output + hidden_states\n\n            if hidden_states.ndim == 4:\n                hidden_states = hidden_states.squeeze(1)\n\n        # Multiview Attention\n        if 
num_in_batch > 1 and self.use_ma:\n            multivew_hidden_states = rearrange(\n                norm_hidden_states, \"(b n) l c -> b (n l) c\", n=num_in_batch\n            )\n\n            if self.is_turbo:\n                position_mask = None\n                if position_attn_mask is not None:\n                    if multivew_hidden_states.shape[1] in position_attn_mask:\n                        position_mask = position_attn_mask[\n                            multivew_hidden_states.shape[1]\n                        ]\n                position_indices = None\n                if position_voxel_indices is not None:\n                    if multivew_hidden_states.shape[1] in position_voxel_indices:\n                        position_indices = position_voxel_indices[\n                            multivew_hidden_states.shape[1]\n                        ]\n                attn_output = self.attn_multiview(\n                    multivew_hidden_states,\n                    encoder_hidden_states=multivew_hidden_states,\n                    attention_mask=position_mask,\n                    position_indices=position_indices,\n                    **cross_attention_kwargs,\n                )\n            else:\n                attn_output = self.attn_multiview(\n                    multivew_hidden_states,\n                    encoder_hidden_states=multivew_hidden_states,\n                    **cross_attention_kwargs,\n                )\n\n            attn_output = rearrange(\n                attn_output, \"b (n l) c -> (b n) l c\", n=num_in_batch\n            )\n\n            hidden_states = mva_scale * attn_output + hidden_states\n            if hidden_states.ndim == 4:\n                hidden_states = hidden_states.squeeze(1)\n\n        # GLIGEN Control\n        if gligen_kwargs is not None:\n            hidden_states = self.fuser(hidden_states, gligen_kwargs[\"objs\"])\n\n        # Cross-Attention\n        if self.attn2 is not None:\n            if self.norm_type == 
\"ada_norm\":\n                norm_hidden_states = self.norm2(hidden_states, timestep)\n            elif self.norm_type in [\"ada_norm_zero\", \"layer_norm\", \"layer_norm_i2vgen\"]:\n                norm_hidden_states = self.norm2(hidden_states)\n            elif self.norm_type == \"ada_norm_single\":\n                norm_hidden_states = hidden_states\n            elif self.norm_type == \"ada_norm_continuous\":\n                norm_hidden_states = self.norm2(\n                    hidden_states, added_cond_kwargs[\"pooled_text_emb\"]\n                )\n            else:\n                raise ValueError(\"Incorrect norm\")\n\n            if self.pos_embed is not None and self.norm_type != \"ada_norm_single\":\n                norm_hidden_states = self.pos_embed(norm_hidden_states)\n\n            attn_output = self.attn2(\n                norm_hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=encoder_attention_mask,\n                **cross_attention_kwargs,\n            )\n\n            hidden_states = attn_output + hidden_states\n\n        # Feed-forward\n        if self.norm_type == \"ada_norm_continuous\":\n            norm_hidden_states = self.norm3(\n                hidden_states, added_cond_kwargs[\"pooled_text_emb\"]\n            )\n        elif not self.norm_type == \"ada_norm_single\":\n            norm_hidden_states = self.norm3(hidden_states)\n\n        if self.norm_type == \"ada_norm_zero\":\n            norm_hidden_states = (\n                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]\n            )\n\n        if self.norm_type == \"ada_norm_single\":\n            norm_hidden_states = self.norm2(hidden_states)\n            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp\n\n        if self._chunk_size is not None:\n            ff_output = _chunked_feed_forward(\n                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size\n      
      )\n        else:\n            ff_output = self.ff(norm_hidden_states)\n\n        if self.norm_type == \"ada_norm_zero\":\n            ff_output = gate_mlp.unsqueeze(1) * ff_output\n        elif self.norm_type == \"ada_norm_single\":\n            ff_output = gate_mlp * ff_output\n\n        hidden_states = ff_output + hidden_states\n        if hidden_states.ndim == 4:\n            hidden_states = hidden_states.squeeze(1)\n\n        return hidden_states\n\n\n@torch.no_grad()\ndef compute_voxel_grid_mask(position: torch.Tensor, grid_resolution: int = 8):\n    \"\"\"Compute voxel grid mask for position-aware attention.\"\"\"\n    position = position.half()\n    B, N, _, H, W = position.shape\n    assert H % grid_resolution == 0 and W % grid_resolution == 0\n\n    valid_mask = (position != 1).all(dim=2, keepdim=True)\n    valid_mask = valid_mask.expand_as(position)\n    position[valid_mask == False] = 0\n\n    position = rearrange(\n        position,\n        \"b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w\",\n        num_h=grid_resolution,\n        num_w=grid_resolution,\n    )\n    valid_mask = rearrange(\n        valid_mask,\n        \"b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w\",\n        num_h=grid_resolution,\n        num_w=grid_resolution,\n    )\n\n    grid_position = position.sum(dim=(-2, -1))\n    count_masked = valid_mask.sum(dim=(-2, -1))\n\n    grid_position = grid_position / count_masked.clamp(min=1)\n    grid_position[count_masked < 5] = 0\n\n    grid_position = grid_position.permute(0, 1, 4, 2, 3)\n    grid_position = rearrange(grid_position, \"b n c h w -> b n (h w) c\")\n\n    grid_position_expanded_1 = grid_position.unsqueeze(2).unsqueeze(4)\n    grid_position_expanded_2 = grid_position.unsqueeze(1).unsqueeze(3)\n\n    distances = torch.norm(grid_position_expanded_1 - grid_position_expanded_2, dim=-1)\n\n    weights = distances\n    grid_distance = 1.73 / grid_resolution\n\n    weights = 
weights < grid_distance\n\n    return weights\n\n\ndef compute_multi_resolution_mask(\n    position_maps: torch.Tensor, grid_resolutions: List[int] = [32, 16, 8]\n) -> dict:\n    \"\"\"Compute multi-resolution position attention masks.\"\"\"\n    position_attn_mask = {}\n    with torch.no_grad():\n        for grid_resolution in grid_resolutions:\n            position_mask = compute_voxel_grid_mask(position_maps, grid_resolution)\n            position_mask = rearrange(\n                position_mask, \"b ni nj li lj -> b (ni li) (nj lj)\"\n            )\n            position_attn_mask[position_mask.shape[1]] = position_mask\n    return position_attn_mask\n\n\n@torch.no_grad()\ndef compute_discrete_voxel_indice(\n    position: torch.Tensor, grid_resolution: int = 8, voxel_resolution: int = 128\n):\n    \"\"\"Compute discrete voxel indices for position encoding.\"\"\"\n    position = position.half()\n    B, N, _, H, W = position.shape\n    assert H % grid_resolution == 0 and W % grid_resolution == 0\n\n    valid_mask = (position != 1).all(dim=2, keepdim=True)\n    valid_mask = valid_mask.expand_as(position)\n    position[valid_mask == False] = 0\n\n    position = rearrange(\n        position,\n        \"b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w\",\n        num_h=grid_resolution,\n        num_w=grid_resolution,\n    )\n    valid_mask = rearrange(\n        valid_mask,\n        \"b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w\",\n        num_h=grid_resolution,\n        num_w=grid_resolution,\n    )\n\n    grid_position = position.sum(dim=(-2, -1))\n    count_masked = valid_mask.sum(dim=(-2, -1))\n\n    grid_position = grid_position / count_masked.clamp(min=1)\n    grid_position[count_masked < 5] = 0\n\n    grid_position = grid_position.permute(0, 1, 4, 2, 3).clamp(0, 1)\n    voxel_indices = grid_position * (voxel_resolution - 1)\n    voxel_indices = torch.round(voxel_indices).long()\n    return voxel_indices\n\n\ndef 
compute_multi_resolution_discrete_voxel_indice(\n    position_maps: torch.Tensor,\n    grid_resolutions: List[int] = [64, 32, 16, 8],\n    voxel_resolutions: List[int] = [512, 256, 128, 64],\n) -> dict:\n    \"\"\"Compute multi-resolution discrete voxel indices.\"\"\"\n    voxel_indices = {}\n    with torch.no_grad():\n        for grid_resolution, voxel_resolution in zip(\n            grid_resolutions, voxel_resolutions\n        ):\n            voxel_indice = compute_discrete_voxel_indice(\n                position_maps, grid_resolution, voxel_resolution\n            )\n            voxel_indice = rearrange(voxel_indice, \"b n c h w -> b (n h w) c\")\n            voxel_indices[voxel_indice.shape[1]] = {\n                \"voxel_indices\": voxel_indice,\n                \"voxel_resolution\": voxel_resolution,\n            }\n    return voxel_indices\n\n\nclass UNet2p5DConditionModel(torch.nn.Module):\n    \"\"\"2.5D UNet for multi-view texture generation.\"\"\"\n\n    def __init__(self, unet: UNet2DConditionModel) -> None:\n        super().__init__()\n        self.unet = unet\n\n        self.use_ma = True\n        self.use_ra = True\n        self.use_camera_embedding = True\n        self.use_dual_stream = True\n        self.is_turbo = False\n\n        if self.use_dual_stream:\n            self.unet_dual = copy.deepcopy(unet)\n            self.init_attention(self.unet_dual)\n        self.init_attention(\n            self.unet, use_ma=self.use_ma, use_ra=self.use_ra, is_turbo=self.is_turbo\n        )\n        self.init_condition()\n        self.init_camera_embedding()\n\n    @staticmethod\n    def from_pretrained(pretrained_model_name_or_path: str, **kwargs):\n        \"\"\"Load a pretrained UNet2p5DConditionModel.\"\"\"\n        torch_dtype = kwargs.pop(\"dtype\", kwargs.pop(\"torch_dtype\", torch.float32))\n        config_path = _os.path.join(pretrained_model_name_or_path, \"config.json\")\n        unet_ckpt_path = _os.path.join(\n            
pretrained_model_name_or_path, \"diffusion_pytorch_model.bin\"\n        )\n\n        with open(config_path, \"r\", encoding=\"utf-8\") as file:\n            config = json.load(file)\n\n        unet = UNet2DConditionModel(**config)\n        unet = UNet2p5DConditionModel(unet)\n        unet_ckpt = torch.load(unet_ckpt_path, map_location=\"cpu\", weights_only=True)\n        unet.load_state_dict(unet_ckpt, strict=True)\n        unet = unet.to(torch_dtype)\n        return unet\n\n    def init_condition(self):\n        \"\"\"Initialize condition-related modules.\"\"\"\n        self.unet.conv_in = torch.nn.Conv2d(\n            12,  # 4 (latent) + 4 (normal) + 4 (position)\n            self.unet.conv_in.out_channels,\n            kernel_size=self.unet.conv_in.kernel_size,\n            stride=self.unet.conv_in.stride,\n            padding=self.unet.conv_in.padding,\n            dilation=self.unet.conv_in.dilation,\n            groups=self.unet.conv_in.groups,\n            bias=self.unet.conv_in.bias is not None,\n        )\n\n        self.unet.learned_text_clip_gen = nn.Parameter(torch.randn(1, 77, 1024))\n        self.unet.learned_text_clip_ref = nn.Parameter(torch.randn(1, 77, 1024))\n\n    def init_camera_embedding(self):\n        \"\"\"Initialize camera embedding module.\"\"\"\n        if self.use_camera_embedding:\n            time_embed_dim = 1280\n            self.max_num_ref_image = 5\n            self.max_num_gen_image = 12 * 3 + 4 * 2\n            self.unet.class_embedding = nn.Embedding(\n                self.max_num_ref_image + self.max_num_gen_image, time_embed_dim\n            )\n\n    def init_attention(\n        self,\n        unet: UNet2DConditionModel,\n        use_ma: bool = False,\n        use_ra: bool = False,\n        is_turbo: bool = False,\n        use_sglang_attn: bool = True,\n    ):\n        \"\"\"Initialize attention blocks with MVA and RVA support.\"\"\"\n        block_kwargs = dict(\n            use_ma=use_ma,\n            use_ra=use_ra,\n      
      is_turbo=is_turbo,\n            use_sglang_attn=use_sglang_attn,\n        )\n\n        # Down blocks\n        for down_block_i, down_block in enumerate(unet.down_blocks):\n            if (\n                hasattr(down_block, \"has_cross_attention\")\n                and down_block.has_cross_attention\n            ):\n                for attn_i, attn in enumerate(down_block.attentions):\n                    for transformer_i, transformer in enumerate(\n                        attn.transformer_blocks\n                    ):\n                        if isinstance(transformer, BasicTransformerBlock):\n                            attn.transformer_blocks[transformer_i] = (\n                                Basic2p5DTransformerBlock(\n                                    transformer,\n                                    f\"down_{down_block_i}_{attn_i}_{transformer_i}\",\n                                    **block_kwargs,\n                                )\n                            )\n\n        # Mid block\n        if (\n            hasattr(unet.mid_block, \"has_cross_attention\")\n            and unet.mid_block.has_cross_attention\n        ):\n            for attn_i, attn in enumerate(unet.mid_block.attentions):\n                for transformer_i, transformer in enumerate(attn.transformer_blocks):\n                    if isinstance(transformer, BasicTransformerBlock):\n                        attn.transformer_blocks[transformer_i] = (\n                            Basic2p5DTransformerBlock(\n                                transformer,\n                                f\"mid_{attn_i}_{transformer_i}\",\n                                **block_kwargs,\n                            )\n                        )\n\n        # Up blocks\n        for up_block_i, up_block in enumerate(unet.up_blocks):\n            if (\n                hasattr(up_block, \"has_cross_attention\")\n                and up_block.has_cross_attention\n            ):\n                for attn_i, 
attn in enumerate(up_block.attentions):\n                    for transformer_i, transformer in enumerate(\n                        attn.transformer_blocks\n                    ):\n                        if isinstance(transformer, BasicTransformerBlock):\n                            attn.transformer_blocks[transformer_i] = (\n                                Basic2p5DTransformerBlock(\n                                    transformer,\n                                    f\"up_{up_block_i}_{attn_i}_{transformer_i}\",\n                                    **block_kwargs,\n                                )\n                            )\n\n        if use_sglang_attn and (use_ma or use_ra):\n            backend = \"unknown\"\n            for block in self._iter_2p5d_blocks(unet):\n                for attr in (\"attn_multiview\", \"attn_refview\"):\n                    wrapper = getattr(block, attr, None)\n                    if isinstance(wrapper, SGLangAttentionWrapper):\n                        backend = wrapper._attn_backend_name\n                        break\n                if backend != \"unknown\":\n                    break\n            count = sum(1 for _ in self._iter_2p5d_blocks(unet))\n            logger.info(\n                \"Initialized %d Basic2p5DTransformerBlocks with sglang %s attention\",\n                count,\n                backend,\n            )\n\n    @staticmethod\n    def _iter_2p5d_blocks(unet):\n        \"\"\"Yield all Basic2p5DTransformerBlock instances in a UNet.\"\"\"\n        for block_group in (unet.down_blocks, [unet.mid_block], unet.up_blocks):\n            for block in block_group:\n                if not hasattr(block, \"attentions\"):\n                    continue\n                for attn in block.attentions:\n                    for tb in attn.transformer_blocks:\n                        if isinstance(tb, Basic2p5DTransformerBlock):\n                            yield tb\n\n    def __getattr__(self, name: str):\n        try:\n 
           return super().__getattr__(name)\n        except AttributeError:\n            return getattr(self.unet, name)\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        timestep: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        *args,\n        down_intrablock_additional_residuals=None,\n        down_block_res_samples=None,\n        mid_block_res_sample=None,\n        **cached_condition,\n    ):\n        \"\"\"Forward pass for multi-view texture generation.\"\"\"\n        B, N_gen, _, H, W = sample.shape\n        assert H == W\n\n        if self.use_camera_embedding:\n            camera_info_gen = (\n                cached_condition[\"camera_info_gen\"] + self.max_num_ref_image\n            )\n            camera_info_gen = rearrange(camera_info_gen, \"b n -> (b n)\")\n        else:\n            camera_info_gen = None\n\n        # Concatenate latents with normal and position maps\n        sample = [sample]\n        if \"normal_imgs\" in cached_condition:\n            sample.append(cached_condition[\"normal_imgs\"])\n        if \"position_imgs\" in cached_condition:\n            sample.append(cached_condition[\"position_imgs\"])\n        sample = torch.cat(sample, dim=2)\n\n        sample = rearrange(sample, \"b n c h w -> (b n) c h w\")\n\n        encoder_hidden_states_gen = encoder_hidden_states.unsqueeze(1).repeat(\n            1, N_gen, 1, 1\n        )\n        encoder_hidden_states_gen = rearrange(\n            encoder_hidden_states_gen, \"b n l c -> (b n) l c\"\n        )\n\n        # Process reference images for RVA\n        if self.use_ra:\n            if \"condition_embed_dict\" in cached_condition:\n                condition_embed_dict = cached_condition[\"condition_embed_dict\"]\n            else:\n                condition_embed_dict = {}\n                ref_latents = cached_condition[\"ref_latents\"]\n                N_ref = ref_latents.shape[1]\n\n                if self.use_camera_embedding:\n              
      camera_info_ref = cached_condition[\"camera_info_ref\"]\n                    camera_info_ref = rearrange(camera_info_ref, \"b n -> (b n)\")\n                else:\n                    camera_info_ref = None\n\n                ref_latents = rearrange(ref_latents, \"b n c h w -> (b n) c h w\")\n\n                encoder_hidden_states_ref = self.unet.learned_text_clip_ref.unsqueeze(\n                    1\n                ).repeat(B, N_ref, 1, 1)\n                encoder_hidden_states_ref = rearrange(\n                    encoder_hidden_states_ref, \"b n l c -> (b n) l c\"\n                )\n\n                noisy_ref_latents = ref_latents\n                timestep_ref = 0\n\n                if self.use_dual_stream:\n                    unet_ref = self.unet_dual\n                else:\n                    unet_ref = self.unet\n\n                unet_ref(\n                    noisy_ref_latents,\n                    timestep_ref,\n                    encoder_hidden_states=encoder_hidden_states_ref,\n                    class_labels=camera_info_ref,\n                    return_dict=False,\n                    cross_attention_kwargs={\n                        \"mode\": \"w\",\n                        \"num_in_batch\": N_ref,\n                        \"condition_embed_dict\": condition_embed_dict,\n                    },\n                )\n                cached_condition[\"condition_embed_dict\"] = condition_embed_dict\n        else:\n            condition_embed_dict = None\n\n        mva_scale = cached_condition.get(\"mva_scale\", 1.0)\n        ref_scale = cached_condition.get(\"ref_scale\", 1.0)\n\n        if self.is_turbo:\n            position_attn_mask = cached_condition.get(\"position_attn_mask\", None)\n            position_voxel_indices = cached_condition.get(\n                \"position_voxel_indices\", None\n            )\n            cross_attention_kwargs_ = {\n                \"mode\": \"r\",\n                \"num_in_batch\": N_gen,\n                
\"condition_embed_dict\": condition_embed_dict,\n                \"position_attn_mask\": position_attn_mask,\n                \"position_voxel_indices\": position_voxel_indices,\n                \"mva_scale\": mva_scale,\n                \"ref_scale\": ref_scale,\n            }\n        else:\n            cross_attention_kwargs_ = {\n                \"mode\": \"r\",\n                \"num_in_batch\": N_gen,\n                \"condition_embed_dict\": condition_embed_dict,\n                \"mva_scale\": mva_scale,\n                \"ref_scale\": ref_scale,\n            }\n\n        return self.unet(\n            sample,\n            timestep,\n            encoder_hidden_states_gen,\n            *args,\n            class_labels=camera_info_gen,\n            down_intrablock_additional_residuals=(\n                [\n                    s.to(dtype=self.unet.dtype)\n                    for s in down_intrablock_additional_residuals\n                ]\n                if down_intrablock_additional_residuals is not None\n                else None\n            ),\n            down_block_additional_residuals=(\n                [s.to(dtype=self.unet.dtype) for s in down_block_res_samples]\n                if down_block_res_samples is not None\n                else None\n            ),\n            mid_block_additional_residual=(\n                mid_block_res_sample.to(dtype=self.unet.dtype)\n                if mid_block_res_sample is not None\n                else None\n            ),\n            return_dict=False,\n            cross_attention_kwargs=cross_attention_kwargs_,\n        )\n\n\n# Entry class for model registry\nEntryClass = [Hunyuan3D2DiT, UNet2p5DConditionModel]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nfrom typing import Any\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom sglang.multimodal_gen.configs.models.dits import HunyuanVideoConfig\nfrom sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size\nfrom sglang.multimodal_gen.runtime.layers.attention import (\n    LocalAttention,\n    UlyssesAttention,\n)\nfrom sglang.multimodal_gen.runtime.layers.elementwise import MulAdd\nfrom sglang.multimodal_gen.runtime.layers.layernorm import (\n    LayerNormScaleShift,\n    RMSNorm,\n    ScaleResidualLayerNormScaleShift,\n)\nfrom sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear\nfrom sglang.multimodal_gen.runtime.layers.mlp import MLP\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n)\nfrom sglang.multimodal_gen.runtime.layers.rotary_embedding import (\n    _apply_rotary_emb,\n    get_rotary_pos_embed,\n)\nfrom sglang.multimodal_gen.runtime.layers.visual_embedding import (\n    ModulateProjection,\n    PatchEmbed,\n    TimestepEmbedder,\n    unpatchify,\n)\nfrom sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.models.utils import modulate\nfrom sglang.multimodal_gen.runtime.platforms import (\n    AttentionBackendEnum,\n    current_platform,\n)\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\n\n\nclass MMDoubleStreamBlock(nn.Module):\n    \"\"\"\n    A multimodal DiT block with separate modulation for text and image/video,\n    using distributed attention and linear layers.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        
num_attention_heads: int,\n        mlp_ratio: float,\n        dtype: torch.dtype | None = None,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        prefix: str = \"\",\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__()\n\n        self.deterministic = False\n        self.num_attention_heads = num_attention_heads\n        head_dim = hidden_size // num_attention_heads\n        mlp_hidden_dim = int(hidden_size * mlp_ratio)\n\n        # Image modulation components\n        self.img_mod = ModulateProjection(\n            hidden_size,\n            factor=6,\n            act_layer=\"silu\",\n            dtype=dtype,\n            prefix=f\"{prefix}.img_mod\",\n        )\n\n        # Fused operations for image stream\n        self.img_attn_norm = LayerNormScaleShift(\n            hidden_size, elementwise_affine=False, dtype=dtype\n        )\n        self.img_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift(\n            hidden_size, elementwise_affine=False, dtype=dtype\n        )\n        self.img_mlp_residual = MulAdd()\n\n        # Image attention components\n        self.img_attn_qkv = ReplicatedLinear(\n            hidden_size,\n            hidden_size * 3,\n            bias=True,\n            params_dtype=dtype,\n            prefix=f\"{prefix}.img_attn_qkv\",\n            quant_config=quant_config,\n        )\n\n        self.img_attn_q_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype)\n        self.img_attn_k_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype)\n\n        self.img_attn_proj = ReplicatedLinear(\n            hidden_size,\n            hidden_size,\n            bias=True,\n            params_dtype=dtype,\n            prefix=f\"{prefix}.img_attn_proj\",\n            quant_config=quant_config,\n        )\n\n        self.img_mlp = MLP(\n            hidden_size,\n            mlp_hidden_dim,\n            bias=True,\n            dtype=dtype,\n            
prefix=f\"{prefix}.img_mlp\",\n            quant_config=quant_config,\n        )\n\n        # Text modulation components\n        self.txt_mod = ModulateProjection(\n            hidden_size,\n            factor=6,\n            act_layer=\"silu\",\n            dtype=dtype,\n            prefix=f\"{prefix}.txt_mod\",\n        )\n\n        # Fused operations for text stream\n        self.txt_attn_norm = LayerNormScaleShift(\n            hidden_size, elementwise_affine=False, dtype=dtype\n        )\n        self.txt_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift(\n            hidden_size, elementwise_affine=False, dtype=dtype\n        )\n        self.txt_mlp_residual = MulAdd()\n\n        # Text attention components\n        self.txt_attn_qkv = ReplicatedLinear(\n            hidden_size,\n            hidden_size * 3,\n            bias=True,\n            params_dtype=dtype,\n            quant_config=quant_config,\n        )\n\n        # QK norm layers for text\n        self.txt_attn_q_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype)\n        self.txt_attn_k_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype)\n\n        self.txt_attn_proj = ReplicatedLinear(\n            hidden_size,\n            hidden_size,\n            bias=True,\n            params_dtype=dtype,\n            quant_config=quant_config,\n        )\n\n        self.txt_mlp = MLP(\n            hidden_size,\n            mlp_hidden_dim,\n            bias=True,\n            dtype=dtype,\n            quant_config=quant_config,\n        )\n\n        # Use UlyssesAttention to replace Distributed attention\n        self.attn = UlyssesAttention(\n            num_heads=num_attention_heads,\n            head_size=head_dim,\n            causal=False,\n            supported_attention_backends=supported_attention_backends,\n            prefix=f\"{prefix}.attn\",\n        )\n\n    def forward(\n        self,\n        img: torch.Tensor,\n        txt: torch.Tensor,\n        vec: torch.Tensor,\n        freqs_cis: 
tuple,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        # Process modulation vectors\n        img_mod_outputs = self.img_mod(vec)\n        (\n            img_attn_shift,\n            img_attn_scale,\n            img_attn_gate,\n            img_mlp_shift,\n            img_mlp_scale,\n            img_mlp_gate,\n        ) = torch.chunk(img_mod_outputs, 6, dim=-1)\n\n        txt_mod_outputs = self.txt_mod(vec)\n        (\n            txt_attn_shift,\n            txt_attn_scale,\n            txt_attn_gate,\n            txt_mlp_shift,\n            txt_mlp_scale,\n            txt_mlp_gate,\n        ) = torch.chunk(txt_mod_outputs, 6, dim=-1)\n\n        # Prepare image for attention using fused operation\n        img_attn_input = self.img_attn_norm(img, img_attn_shift, img_attn_scale)\n        # Get QKV for image\n        img_qkv, _ = self.img_attn_qkv(img_attn_input)\n        batch_size, image_seq_len = img_qkv.shape[0], img_qkv.shape[1]\n\n        # Split QKV\n        img_qkv = img_qkv.view(\n            batch_size, image_seq_len, 3, self.num_attention_heads, -1\n        )\n        img_q, img_k, img_v = img_qkv[:, :, 0], img_qkv[:, :, 1], img_qkv[:, :, 2]\n\n        # Apply QK-Norm if needed\n\n        img_q = self.img_attn_q_norm(img_q.contiguous()).to(img_v)\n        img_k = self.img_attn_k_norm(img_k.contiguous()).to(img_v)\n        # Apply rotary embeddings\n        cos, sin = freqs_cis\n        img_q, img_k = _apply_rotary_emb(\n            img_q, cos, sin, is_neox_style=False\n        ), _apply_rotary_emb(img_k, cos, sin, is_neox_style=False)\n        # Prepare text for attention using fused operation\n        txt_attn_input = self.txt_attn_norm(txt, txt_attn_shift, txt_attn_scale)\n\n        # Get QKV for text\n        txt_qkv, _ = self.txt_attn_qkv(txt_attn_input)\n        batch_size, text_seq_len = txt_qkv.shape[0], txt_qkv.shape[1]\n\n        # Split QKV\n        txt_qkv = txt_qkv.view(\n            batch_size, text_seq_len, 3, self.num_attention_heads, 
-1\n        )\n        txt_q, txt_k, txt_v = txt_qkv[:, :, 0], txt_qkv[:, :, 1], txt_qkv[:, :, 2]\n\n        # Apply QK-Norm if needed\n        txt_q = self.txt_attn_q_norm(txt_q.contiguous()).to(txt_q.dtype)\n        txt_k = self.txt_attn_k_norm(txt_k.contiguous()).to(txt_k.dtype)\n\n        # Run distributed attention\n        img_attn, txt_attn = self.attn(img_q, img_k, img_v, txt_q, txt_k, txt_v)\n        img_attn_out, _ = self.img_attn_proj(\n            img_attn.view(batch_size, image_seq_len, -1)\n        )\n        # Use fused operation for residual connection, normalization, and modulation\n        img_mlp_input, img_residual = self.img_attn_residual_mlp_norm(\n            img, img_attn_out, img_attn_gate, img_mlp_shift, img_mlp_scale\n        )\n\n        # Process image MLP\n        img_mlp_out = self.img_mlp(img_mlp_input)\n        img = self.img_mlp_residual(img_mlp_out, img_mlp_gate, img_residual)\n\n        # Process text attention output\n        txt_attn_out, _ = self.txt_attn_proj(\n            txt_attn.reshape(batch_size, text_seq_len, -1)\n        )\n\n        # Use fused operation for residual connection, normalization, and modulation\n        txt_mlp_input, txt_residual = self.txt_attn_residual_mlp_norm(\n            txt, txt_attn_out, txt_attn_gate, txt_mlp_shift, txt_mlp_scale\n        )\n\n        # Process text MLP\n        txt_mlp_out = self.txt_mlp(txt_mlp_input)\n        txt = self.txt_mlp_residual(txt_mlp_out, txt_mlp_gate, txt_residual)\n\n        return img, txt\n\n\nclass MMSingleStreamBlock(nn.Module):\n    \"\"\"\n    A DiT block with parallel linear layers using distributed attention\n    and tensor parallelism.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        num_attention_heads: int,\n        mlp_ratio: float = 4.0,\n        dtype: torch.dtype | None = None,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        prefix: str = \"\",\n        quant_config: 
QuantizationConfig | None = None,\n    ):\n        super().__init__()\n\n        self.deterministic = False\n        self.hidden_size = hidden_size\n        self.num_attention_heads = num_attention_heads\n        head_dim = hidden_size // num_attention_heads\n        mlp_hidden_dim = int(hidden_size * mlp_ratio)\n        self.mlp_hidden_dim = mlp_hidden_dim\n\n        # Combined QKV and MLP input projection\n        self.linear1 = ReplicatedLinear(\n            hidden_size,\n            hidden_size * 3 + mlp_hidden_dim,\n            bias=True,\n            params_dtype=dtype,\n            prefix=f\"{prefix}.linear1\",\n            quant_config=quant_config,\n        )\n\n        # Combined projection and MLP output\n        self.linear2 = ReplicatedLinear(\n            hidden_size + mlp_hidden_dim,\n            hidden_size,\n            bias=True,\n            params_dtype=dtype,\n            prefix=f\"{prefix}.linear2\",\n            quant_config=quant_config,\n        )\n\n        # QK norm layers\n        self.q_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype)\n        self.k_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype)\n\n        # Fused operations with better naming\n        self.input_norm_scale_shift = LayerNormScaleShift(\n            hidden_size,\n            eps=1e-6,\n            elementwise_affine=False,\n            dtype=dtype,\n        )\n        self.output_residual = MulAdd()\n\n        # Activation function\n        self.mlp_act = nn.GELU(approximate=\"tanh\")\n\n        # Modulation\n        self.modulation = ModulateProjection(\n            hidden_size,\n            factor=3,\n            act_layer=\"silu\",\n            dtype=dtype,\n            prefix=f\"{prefix}.modulation\",\n        )\n\n        # Use UlyssesAttention to replace Distributed attention\n        self.attn = UlyssesAttention(\n            num_heads=num_attention_heads,\n            head_size=head_dim,\n            causal=False,\n            
supported_attention_backends=supported_attention_backends,\n            prefix=f\"{prefix}.attn\",\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        vec: torch.Tensor,\n        txt_len: int,\n        freqs_cis: tuple[torch.Tensor, torch.Tensor],\n    ) -> torch.Tensor:\n        # Process modulation\n        mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)\n\n        # Apply pre-norm and modulation using fused operation\n        x_mod = self.input_norm_scale_shift(x, mod_shift, mod_scale)\n\n        # Get combined projections\n        linear1_out, _ = self.linear1(x_mod)\n\n        # Split into QKV and MLP parts\n        qkv, mlp = torch.split(\n            linear1_out, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1\n        )\n\n        # Process QKV\n        batch_size, seq_len = qkv.shape[0], qkv.shape[1]\n        qkv = qkv.view(batch_size, seq_len, 3, self.num_attention_heads, -1)\n        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]\n\n        # Apply QK-Norm\n        q = self.q_norm(q.contiguous()).to(v.dtype)\n        k = self.k_norm(k.contiguous()).to(v.dtype)\n\n        # Split into image and text parts\n        img_q, txt_q = q[:, :-txt_len], q[:, -txt_len:]\n        img_k, txt_k = k[:, :-txt_len], k[:, -txt_len:]\n        img_v, txt_v = v[:, :-txt_len], v[:, -txt_len:]\n        # Apply rotary embeddings to image parts\n        cos, sin = freqs_cis\n        img_q, img_k = _apply_rotary_emb(\n            img_q, cos, sin, is_neox_style=False\n        ), _apply_rotary_emb(img_k, cos, sin, is_neox_style=False)\n\n        # Run distributed attention\n        img_attn_output, txt_attn_output = self.attn(\n            img_q, img_k, img_v, txt_q, txt_k, txt_v\n        )\n        attn_output = torch.cat((img_attn_output, txt_attn_output), dim=1).view(\n            batch_size, seq_len, -1\n        )\n        # Process MLP activation\n        mlp_output = self.mlp_act(mlp)\n\n        # Combine attention 
and MLP outputs\n        combined = torch.cat((attn_output, mlp_output), dim=-1)\n\n        # Final projection\n        output, _ = self.linear2(combined)\n\n        # Apply residual connection with gating using fused operation\n        return self.output_residual(output, mod_gate, x)\n\n\nclass HunyuanVideoTransformer3DModel(CachableDiT, OffloadableDiTMixin):\n    \"\"\"\n    HunyuanVideo Transformer backbone adapted for distributed training.\n\n    This implementation uses distributed attention and linear layers for efficient\n    parallel processing across multiple GPUs.\n\n    Based on the architecture from:\n    - Flux.1: https://github.com/black-forest-labs/flux\n    - MMDiT: http://arxiv.org/abs/2403.03206\n    \"\"\"\n\n    # PY: we make the input args the same as HF config\n\n    # shard single stream, double stream blocks, and refiner_blocks\n    _fsdp_shard_conditions = HunyuanVideoConfig()._fsdp_shard_conditions\n    _compile_conditions = HunyuanVideoConfig()._compile_conditions\n    _supported_attention_backends = HunyuanVideoConfig()._supported_attention_backends\n    param_names_mapping = HunyuanVideoConfig().param_names_mapping\n    reverse_param_names_mapping = HunyuanVideoConfig().reverse_param_names_mapping\n    lora_param_names_mapping = HunyuanVideoConfig().lora_param_names_mapping\n\n    def __init__(\n        self,\n        config: HunyuanVideoConfig,\n        hf_config: dict[str, Any],\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__(config=config, hf_config=hf_config)\n\n        self.patch_size = [config.patch_size_t, config.patch_size, config.patch_size]\n        self.in_channels = config.in_channels\n        self.num_channels_latents = config.num_channels_latents\n        self.out_channels = (\n            config.in_channels if config.out_channels is None else config.out_channels\n        )\n        self.unpatchify_channels = self.out_channels\n        self.guidance_embeds = 
config.guidance_embeds\n        self.rope_dim_list = list(config.rope_axes_dim)\n        self.rope_theta = config.rope_theta\n        self.text_states_dim = config.text_embed_dim\n        self.text_states_dim_2 = config.pooled_projection_dim\n        # TODO(will): hack?\n        self.dtype = config.dtype\n\n        pe_dim = config.hidden_size // config.num_attention_heads\n        if sum(config.rope_axes_dim) != pe_dim:\n            raise ValueError(\n                f\"Got {config.rope_axes_dim} but expected positional dim {pe_dim}\"\n            )\n\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.num_channels_latents = config.num_channels_latents\n\n        # Image projection\n        self.img_in = PatchEmbed(\n            self.patch_size,\n            self.in_channels,\n            self.hidden_size,\n            dtype=config.dtype,\n            prefix=f\"{config.prefix}.img_in\",\n        )\n\n        self.txt_in = SingleTokenRefiner(\n            self.text_states_dim,\n            config.hidden_size,\n            config.num_attention_heads,\n            depth=config.num_refiner_layers,\n            dtype=config.dtype,\n            prefix=f\"{config.prefix}.txt_in\",\n        )\n\n        # Time modulation\n        self.time_in = TimestepEmbedder(\n            self.hidden_size,\n            act_layer=\"silu\",\n            dtype=config.dtype,\n            prefix=f\"{config.prefix}.time_in\",\n        )\n\n        # Text modulation\n        self.vector_in = MLP(\n            self.text_states_dim_2,\n            self.hidden_size,\n            self.hidden_size,\n            act_type=\"silu\",\n            dtype=config.dtype,\n            prefix=f\"{config.prefix}.vector_in\",\n        )\n\n        # Guidance modulation\n        self.guidance_in = (\n            TimestepEmbedder(\n                self.hidden_size,\n                act_layer=\"silu\",\n                dtype=config.dtype,\n    
            prefix=f\"{config.prefix}.guidance_in\",\n            )\n            if self.guidance_embeds\n            else None\n        )\n\n        # Double blocks\n        self.double_blocks = nn.ModuleList(\n            [\n                MMDoubleStreamBlock(\n                    config.hidden_size,\n                    config.num_attention_heads,\n                    mlp_ratio=config.mlp_ratio,\n                    dtype=config.dtype,\n                    supported_attention_backends=self._supported_attention_backends,\n                    prefix=f\"{config.prefix}.double_blocks.{i}\",\n                    quant_config=quant_config,\n                )\n                for i in range(config.num_layers)\n            ]\n        )\n\n        # Single blocks\n        self.single_blocks = nn.ModuleList(\n            [\n                MMSingleStreamBlock(\n                    config.hidden_size,\n                    config.num_attention_heads,\n                    mlp_ratio=config.mlp_ratio,\n                    dtype=config.dtype,\n                    supported_attention_backends=self._supported_attention_backends,\n                    prefix=f\"{config.prefix}.single_blocks.{i + config.num_layers}\",\n                    quant_config=quant_config,\n                )\n                for i in range(config.num_single_layers)\n            ]\n        )\n\n        self.final_layer = FinalLayer(\n            config.hidden_size,\n            self.patch_size,\n            self.out_channels,\n            dtype=config.dtype,\n            prefix=f\"{config.prefix}.final_layer\",\n        )\n\n        self.__post_init__()\n\n        self.layer_names = [\"double_blocks\", \"single_blocks\"]\n\n    # TODO: change the input the FORWARD_BATCH Dict\n    # TODO: change output to a dict\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor | list[torch.Tensor],\n        timestep: torch.LongTensor,\n        
encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None,\n        guidance=None,\n        **kwargs,\n    ):\n        \"\"\"\n        Forward pass of the HunyuanVideo transformer.\n\n        Args:\n            hidden_states: Input image/video latents [B, C, T, H, W]\n            encoder_hidden_states: Text embeddings [B, L, D]\n            timestep: Diffusion timestep\n            guidance: Embedded guidance value fed to `guidance_in` (used only when `guidance_embeds` is set)\n\n        Returns:\n            Predicted latents [B, C_out, T, H, W] after unpatchify\n        \"\"\"\n        forward_context = get_forward_context()\n        forward_batch = forward_context.forward_batch\n        enable_teacache = forward_batch is not None and forward_batch.enable_teacache\n\n        if guidance is None:\n            guidance = torch.tensor(\n                [6016.0], device=hidden_states.device, dtype=hidden_states.dtype\n            )\n\n        img = x = hidden_states\n        t = timestep\n\n        # Split text embeddings - first token is global, rest are per-token\n        if isinstance(encoder_hidden_states, torch.Tensor):\n            txt = encoder_hidden_states[:, 1:]\n            text_states_2 = encoder_hidden_states[:, 0, : self.text_states_dim_2]\n        else:\n            txt = encoder_hidden_states[0]\n            text_states_2 = encoder_hidden_states[1]\n\n        # Get spatial dimensions\n        _, _, ot, oh, ow = x.shape  # codespell:ignore\n        tt, th, tw = (\n            ot // self.patch_size[0],  # codespell:ignore\n            oh // self.patch_size[1],\n            ow // self.patch_size[2],\n        )\n\n        # Get rotary embeddings\n        freqs_cos, freqs_sin = get_rotary_pos_embed(\n            (tt * get_sp_world_size(), th, tw),\n            self.hidden_size,\n            self.num_attention_heads,\n            self.rope_dim_list,\n            self.rope_theta,\n        )\n        freqs_cos = freqs_cos.to(x.device)\n        freqs_sin = freqs_sin.to(x.device)\n        # Prepare modulation vectors\n        vec = 
self.time_in(t)\n\n        # Add text modulation\n        vec = vec + self.vector_in(text_states_2)\n\n        # Add guidance modulation if needed\n        if self.guidance_in and guidance is not None:\n            vec = vec + self.guidance_in(guidance)\n        # Embed image and text\n        img = self.img_in(img)\n        txt = self.txt_in(txt, t)\n        txt_seq_len = txt.shape[1]\n        img_seq_len = img.shape[1]\n\n        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None\n\n        should_skip_forward = self.should_skip_forward_for_cached_states(\n            img=img, vec=vec\n        )\n\n        if should_skip_forward:\n            img = self.retrieve_cached_states(img)\n        else:\n            if enable_teacache:\n                original_img = img.clone()\n\n            # Process through double stream blocks\n            for index, block in enumerate(self.double_blocks):\n                double_block_args = [img, txt, vec, freqs_cis]\n                img, txt = block(*double_block_args)\n            # Merge txt and img to pass through single stream blocks\n            x = torch.cat((img, txt), 1)\n\n            # Process through single stream blocks\n            if len(self.single_blocks) > 0:\n                for index, block in enumerate(self.single_blocks):\n                    single_block_args = [\n                        x,\n                        vec,\n                        txt_seq_len,\n                        freqs_cis,\n                    ]\n                    x = block(*single_block_args)\n\n            # Extract image features\n            img = x[:, :img_seq_len, ...]\n\n            if enable_teacache:\n                self.maybe_cache_states(img, original_img)\n\n        # Final layer processing\n        img = self.final_layer(img, vec)\n        # Unpatchify to get original shape\n        img = unpatchify(img, tt, th, tw, self.patch_size, self.out_channels)\n\n        return img\n\n    def 
maybe_cache_states(\n        self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor\n    ) -> None:\n        self.previous_residual = hidden_states - original_hidden_states\n\n    def should_skip_forward_for_cached_states(self, **kwargs) -> bool:\n\n        forward_context = get_forward_context()\n        forward_batch = forward_context.forward_batch\n        if forward_batch is None:\n            return False\n        current_timestep = forward_context.current_timestep\n        enable_teacache = forward_batch.enable_teacache\n\n        if not enable_teacache:\n            return False\n        raise NotImplementedError(\"teacache is not supported yet for HunyuanVideo\")\n\n        teacache_params = forward_batch.teacache_params\n        assert teacache_params is not None, \"teacache_params is not initialized\"\n        assert isinstance(\n            teacache_params, TeaCacheParams\n        ), \"teacache_params is not a TeaCacheParams\"\n        num_inference_steps = forward_batch.num_inference_steps\n        teache_thresh = teacache_params.teacache_thresh\n\n        coefficients = teacache_params.coefficients\n\n        if current_timestep == 0:\n            self.cnt = 0\n\n        inp = kwargs[\"img\"].clone()\n        vec_ = kwargs[\"vec\"].clone()\n        # convert to DTensor\n        vec_ = torch.distributed.tensor.DTensor.from_local(\n            vec_,\n            torch.distributed.DeviceMesh(\n                current_platform.device_type,\n                list(range(get_sp_world_size())),\n                mesh_dim_names=(\"dp\",),\n            ),\n            [torch.distributed.tensor.Replicate()],\n        )\n\n        inp = torch.distributed.tensor.DTensor.from_local(\n            inp,\n            torch.distributed.DeviceMesh(\n                current_platform.device_type,\n                list(range(get_sp_world_size())),\n                mesh_dim_names=(\"dp\",),\n            ),\n            [torch.distributed.tensor.Replicate()],\n  
      )\n\n        # txt_ = kwargs[\"txt\"].clone()\n\n        # inp = img.clone()\n        # vec_ = vec.clone()\n        # txt_ = txt.clone()\n        (\n            img_mod1_shift,\n            img_mod1_scale,\n            img_mod1_gate,\n            img_mod2_shift,\n            img_mod2_scale,\n            img_mod2_gate,\n        ) = (\n            self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)\n        )\n        normed_inp = self.double_blocks[0].img_attn_norm.norm(inp)\n        modulated_inp = modulate(normed_inp, shift=img_mod1_shift, scale=img_mod1_scale)\n        if self.cnt == 0 or self.cnt == num_inference_steps - 1:\n            should_calc = True\n            self.accumulated_rel_l1_distance = 0\n        else:\n            coefficients = [\n                7.33226126e02,\n                -4.01131952e02,\n                6.75869174e01,\n                -3.14987800e00,\n                9.61237896e-02,\n            ]\n            rescale_func = np.poly1d(coefficients)\n            assert (\n                self.previous_modulated_input is not None\n            ), \"previous_modulated_input is not initialized\"\n            self.accumulated_rel_l1_distance += rescale_func(\n                (\n                    (modulated_inp - self.previous_modulated_input).abs().mean()\n                    / self.previous_modulated_input.abs().mean()\n                )\n                .cpu()\n                .item()\n            )\n            if self.accumulated_rel_l1_distance < teache_thresh:\n                should_calc = False\n            else:\n                should_calc = True\n                self.accumulated_rel_l1_distance = 0\n        self.previous_modulated_input = modulated_inp\n        self.cnt += 1\n\n        return not should_calc\n\n    def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        return hidden_states + self.previous_residual\n\n\nclass SingleTokenRefiner(nn.Module):\n    \"\"\"\n    A token refiner 
that processes text embeddings with attention to improve\n    their representation for cross-attention with image features.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        hidden_size,\n        num_attention_heads,\n        depth=2,\n        qkv_bias=True,\n        dtype=None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n\n        # Input projection\n        self.input_embedder = ReplicatedLinear(\n            in_channels,\n            hidden_size,\n            bias=True,\n            params_dtype=dtype,\n            prefix=f\"{prefix}.input_embedder\",\n        )\n\n        # Timestep embedding\n        self.t_embedder = TimestepEmbedder(\n            hidden_size, act_layer=\"silu\", dtype=dtype, prefix=f\"{prefix}.t_embedder\"\n        )\n\n        # Context embedding\n        self.c_embedder = MLP(\n            in_channels,\n            hidden_size,\n            hidden_size,\n            act_type=\"silu\",\n            dtype=dtype,\n            prefix=f\"{prefix}.c_embedder\",\n        )\n\n        # Refiner blocks\n        self.refiner_blocks = nn.ModuleList(\n            [\n                IndividualTokenRefinerBlock(\n                    hidden_size,\n                    num_attention_heads,\n                    qkv_bias=qkv_bias,\n                    dtype=dtype,\n                    prefix=f\"{prefix}.refiner_blocks.{i}\",\n                )\n                for i in range(depth)\n            ]\n        )\n\n    def forward(self, x, t):\n        # Get timestep embeddings\n        timestep_aware_representations = self.t_embedder(t)\n\n        # Get context-aware representations\n\n        context_aware_representations = torch.mean(x, dim=1)\n\n        context_aware_representations = self.c_embedder(context_aware_representations)\n        c = timestep_aware_representations + context_aware_representations\n        # Project input\n        x, _ = self.input_embedder(x)\n        # Process through 
refiner blocks\n        for block in self.refiner_blocks:\n            x = block(x, c)\n        return x\n\n\nclass IndividualTokenRefinerBlock(nn.Module):\n    \"\"\"\n    A transformer block for refining individual tokens with self-attention.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size,\n        num_attention_heads,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        dtype=None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.num_attention_heads = num_attention_heads\n        mlp_hidden_dim = int(hidden_size * mlp_ratio)\n\n        # Normalization and attention\n        self.norm1 = nn.LayerNorm(\n            hidden_size, eps=1e-6, elementwise_affine=True, dtype=dtype\n        )\n\n        self.self_attn_qkv = ReplicatedLinear(\n            hidden_size,\n            hidden_size * 3,\n            bias=qkv_bias,\n            params_dtype=dtype,\n            prefix=f\"{prefix}.self_attn_qkv\",\n        )\n\n        self.self_attn_proj = ReplicatedLinear(\n            hidden_size,\n            hidden_size,\n            bias=qkv_bias,\n            params_dtype=dtype,\n            prefix=f\"{prefix}.self_attn_proj\",\n        )\n\n        # MLP\n        self.norm2 = nn.LayerNorm(\n            hidden_size, eps=1e-6, elementwise_affine=True, dtype=dtype\n        )\n        self.mlp = MLP(\n            hidden_size,\n            mlp_hidden_dim,\n            bias=True,\n            act_type=\"silu\",\n            dtype=dtype,\n            prefix=f\"{prefix}.mlp\",\n        )\n\n        # Modulation\n        self.adaLN_modulation = ModulateProjection(\n            hidden_size,\n            factor=2,\n            act_layer=\"silu\",\n            dtype=dtype,\n            prefix=f\"{prefix}.adaLN_modulation\",\n        )\n\n        # Scaled dot product attention\n        self.attn = LocalAttention(\n            num_heads=num_attention_heads,\n            head_size=hidden_size // num_attention_heads,\n         
   # TODO: remove hardcode; remove STA\n            supported_attention_backends=(\n                AttentionBackendEnum.FA,\n                AttentionBackendEnum.AITER,\n                AttentionBackendEnum.TORCH_SDPA,\n            ),\n        )\n\n    def forward(self, x, c):\n        # Get modulation parameters\n        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=-1)\n        # Self-attention\n        norm_x = self.norm1(x)\n        qkv, _ = self.self_attn_qkv(norm_x)\n\n        batch_size, seq_len = qkv.shape[0], qkv.shape[1]\n        qkv = qkv.view(batch_size, seq_len, 3, self.num_attention_heads, -1)\n        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]\n\n        # Run scaled dot product attention\n        attn_output = self.attn(q, k, v)  # [B, L, H, D]\n        attn_output = attn_output.reshape(batch_size, seq_len, -1)  # [B, L, H*D]\n\n        # Project and apply residual connection with gating\n        attn_out, _ = self.self_attn_proj(attn_output)\n        x = x + attn_out * gate_msa.unsqueeze(1)\n\n        # MLP\n        mlp_out = self.mlp(self.norm2(x))\n        x = x + mlp_out * gate_mlp.unsqueeze(1)\n\n        return x\n\n\nclass FinalLayer(nn.Module):\n    \"\"\"\n    The final layer of DiT that projects features to pixel space.\n    \"\"\"\n\n    def __init__(\n        self, hidden_size, patch_size, out_channels, dtype=None, prefix: str = \"\"\n    ) -> None:\n        super().__init__()\n\n        # Normalization\n        self.norm_final = nn.LayerNorm(\n            hidden_size, eps=1e-6, elementwise_affine=False, dtype=dtype\n        )\n\n        output_dim = patch_size[0] * patch_size[1] * patch_size[2] * out_channels\n\n        self.linear = ReplicatedLinear(\n            hidden_size,\n            output_dim,\n            bias=True,\n            params_dtype=dtype,\n            prefix=f\"{prefix}.linear\",\n        )\n\n        # Modulation\n        self.adaLN_modulation = ModulateProjection(\n            hidden_size,\n     
       factor=2,\n            act_layer=\"silu\",\n            dtype=dtype,\n            prefix=f\"{prefix}.adaLN_modulation\",\n        )\n\n    def forward(self, x, c):\n        # Note: unlike the other modulation layers here, the HF checkpoint orders\n        # the final-layer modulation output as (scale, shift), not (shift, scale).\n        scale, shift = self.adaLN_modulation(c).chunk(2, dim=-1)\n        x = self.norm_final(x) * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)\n        x, _ = self.linear(x)\n        return x\n\n\nEntryClass = HunyuanVideoTransformer3DModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py",
    "content": "# Copied and adapted from LTX-2 and WanVideo implementations.\n#\n# SPDX-License-Identifier: Apache-2.0\n\nfrom __future__ import annotations\n\nfrom typing import Any, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sglang.multimodal_gen.configs.models.dits.ltx_2 import LTX2ArchConfig, LTX2Config\nfrom sglang.multimodal_gen.runtime.distributed import (\n    get_sp_parallel_rank,\n    get_sp_world_size,\n    get_tp_rank,\n    get_tp_world_size,\n    model_parallel_is_initialized,\n)\nfrom sglang.multimodal_gen.runtime.distributed.communication_op import (\n    tensor_model_parallel_all_reduce,\n)\nfrom sglang.multimodal_gen.runtime.layers.attention import USPAttention\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    ColumnParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n)\nfrom sglang.multimodal_gen.runtime.layers.visual_embedding import timestep_embedding\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\ndef apply_interleaved_rotary_emb(\n    x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]\n) -> torch.Tensor:\n    cos, sin = freqs\n    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)\n    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)\n    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)\n\n\ndef apply_split_rotary_emb(\n    x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]\n) -> torch.Tensor:\n    cos, sin = freqs\n    x_dtype = x.dtype\n    needs_reshape = False\n    if x.ndim != 4 and cos.ndim == 4:\n        b = x.shape[0]\n        
_, h, t, _ = cos.shape\n        x = x.reshape(b, t, h, -1).swapaxes(1, 2)\n        needs_reshape = True\n\n    last = x.shape[-1]\n    if last % 2 != 0:\n        raise ValueError(\n            f\"Expected x.shape[-1] to be even for split rotary, got {last}.\"\n        )\n    r = last // 2\n\n    split_x = x.reshape(*x.shape[:-1], 2, r)\n    first_x = split_x[..., :1, :]\n    second_x = split_x[..., 1:, :]\n\n    cos_u = cos.unsqueeze(-2)\n    sin_u = sin.unsqueeze(-2)\n\n    out = split_x * cos_u\n    first_out = out[..., :1, :]\n    second_out = out[..., 1:, :]\n    first_out.addcmul_(-sin_u, second_x)\n    second_out.addcmul_(sin_u, first_x)\n\n    out = out.reshape(*out.shape[:-2], last)\n    if needs_reshape:\n        out = out.swapaxes(1, 2).reshape(b, t, -1)\n    return out.to(dtype=x_dtype)\n\n\n# ==============================================================================\n# Layers and Embeddings\n# ==============================================================================\n\n\nclass LTX2AudioVideoRotaryPosEmbed(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        patch_size: int = 1,\n        patch_size_t: int = 1,\n        base_num_frames: int = 20,\n        base_height: int = 2048,\n        base_width: int = 2048,\n        sampling_rate: int = 16000,\n        hop_length: int = 160,\n        scale_factors: Tuple[int, ...] = (8, 32, 32),\n        theta: float = 10000.0,\n        causal_offset: int = 1,\n        modality: str = \"video\",\n        double_precision: bool = True,\n        rope_type: str = \"interleaved\",\n        num_attention_heads: int = 32,\n    ) -> None:\n        super().__init__()\n        self.dim = int(dim)\n        self.patch_size = int(patch_size)\n        self.patch_size_t = int(patch_size_t)\n\n        if rope_type not in [\"interleaved\", \"split\"]:\n            raise ValueError(\n                f\"{rope_type=} not supported. 
Choose between 'interleaved' and 'split'.\"\n            )\n        self.rope_type = rope_type\n\n        self.base_num_frames = int(base_num_frames)\n        self.num_attention_heads = int(num_attention_heads)\n\n        self.base_height = int(base_height)\n        self.base_width = int(base_width)\n\n        self.sampling_rate = int(sampling_rate)\n        self.hop_length = int(hop_length)\n        self.audio_latents_per_second = (\n            float(self.sampling_rate) / float(self.hop_length) / float(scale_factors[0])\n        )\n\n        self.scale_factors = tuple(int(x) for x in scale_factors)\n        self.theta = float(theta)\n        self.causal_offset = int(causal_offset)\n\n        self.modality = modality\n        self.coords_dtype = torch.bfloat16 if modality == \"video\" else torch.float32\n        if self.modality not in [\"video\", \"audio\"]:\n            raise ValueError(\n                f\"Modality {modality} is not supported. Supported modalities are `video` and `audio`.\"\n            )\n        self.double_precision = bool(double_precision)\n\n    def prepare_video_coords(\n        self,\n        batch_size: int,\n        num_frames: int,\n        height: int,\n        width: int,\n        device: torch.device,\n        fps: float = 24.0,\n        *,\n        start_frame: int = 0,\n    ) -> torch.Tensor:\n        grid_f = torch.arange(\n            start=int(start_frame),\n            end=int(num_frames) + int(start_frame),\n            step=self.patch_size_t,\n            dtype=torch.float32,\n            device=device,\n        )\n        grid_h = torch.arange(\n            start=0,\n            end=height,\n            step=self.patch_size,\n            dtype=torch.float32,\n            device=device,\n        )\n        grid_w = torch.arange(\n            start=0,\n            end=width,\n            step=self.patch_size,\n            dtype=torch.float32,\n            device=device,\n        )\n        grid = torch.meshgrid(grid_f, 
grid_h, grid_w, indexing=\"ij\")\n        grid = torch.stack(grid, dim=0)\n\n        patch_size = (self.patch_size_t, self.patch_size, self.patch_size)\n        patch_size_delta = torch.tensor(\n            patch_size, dtype=grid.dtype, device=grid.device\n        )\n        patch_ends = grid + patch_size_delta.view(3, 1, 1, 1)\n\n        latent_coords = torch.stack([grid, patch_ends], dim=-1)\n        latent_coords = latent_coords.flatten(1, 3)\n        latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1)\n\n        scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device)\n        broadcast_shape = [1] * latent_coords.ndim\n        broadcast_shape[1] = -1\n        pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape)\n        pixel_coords[:, 0, ...] = (\n            pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]\n        ).clamp(min=0)\n        pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps\n        return pixel_coords\n\n    def prepare_audio_coords(\n        self,\n        batch_size: int,\n        num_frames: int,\n        device: torch.device,\n        *,\n        start_frame: int = 0,\n    ) -> torch.Tensor:\n        grid_f = torch.arange(\n            start=int(start_frame),\n            end=int(num_frames) + int(start_frame),\n            step=self.patch_size_t,\n            dtype=torch.float32,\n            device=device,\n        )\n\n        audio_scale_factor = self.scale_factors[0]\n        grid_start_mel = grid_f * audio_scale_factor\n        grid_start_mel = (\n            grid_start_mel + self.causal_offset - audio_scale_factor\n        ).clip(min=0)\n        grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate\n\n        grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor\n        grid_end_mel = (grid_end_mel + self.causal_offset - audio_scale_factor).clip(\n            min=0\n        )\n        grid_end_s = grid_end_mel * self.hop_length 
/ self.sampling_rate\n\n        audio_coords = torch.stack([grid_start_s, grid_end_s], dim=-1)\n        audio_coords = audio_coords.unsqueeze(0).expand(batch_size, -1, -1)\n        audio_coords = audio_coords.unsqueeze(1)\n        return audio_coords\n\n    def prepare_coords(self, *args, **kwargs):\n        if self.modality == \"video\":\n            return self.prepare_video_coords(*args, **kwargs)\n        return self.prepare_audio_coords(*args, **kwargs)\n\n    def forward(\n        self, coords: torch.Tensor, device: Optional[Union[str, torch.device]] = None\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        device = device or coords.device\n        num_pos_dims = coords.shape[1]\n\n        coords = coords.to(self.coords_dtype)\n        if coords.ndim == 4:\n            coords_start, coords_end = coords.chunk(2, dim=-1)\n            coords = (coords_start + coords_end) / 2.0\n            coords = coords.squeeze(-1)\n\n        if self.modality == \"video\":\n            max_positions = (self.base_num_frames, self.base_height, self.base_width)\n        else:\n            max_positions = (self.base_num_frames,)\n\n        grid = torch.stack(\n            [coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1\n        ).to(device)\n\n        num_rope_elems = num_pos_dims * 2\n        freqs_dtype = torch.float64 if self.double_precision else torch.float32\n        pow_indices = torch.pow(\n            self.theta,\n            torch.linspace(\n                start=0.0,\n                end=1.0,\n                steps=self.dim // num_rope_elems,\n                dtype=freqs_dtype,\n                device=device,\n            ),\n        )\n        freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)\n\n        freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs\n        freqs = freqs.transpose(-1, -2).flatten(2)\n\n        if self.rope_type == \"interleaved\":\n            cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)\n            
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)\n\n            if self.dim % num_rope_elems != 0:\n                cos_padding = torch.ones_like(\n                    cos_freqs[:, :, : self.dim % num_rope_elems]\n                )\n                sin_padding = torch.zeros_like(\n                    cos_freqs[:, :, : self.dim % num_rope_elems]\n                )\n                cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)\n                sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)\n        else:\n            expected_freqs = self.dim // 2\n            current_freqs = freqs.shape[-1]\n            pad_size = expected_freqs - current_freqs\n            cos_freq = freqs.cos()\n            sin_freq = freqs.sin()\n\n            if pad_size != 0:\n                cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])\n                sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])\n                cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)\n                sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)\n\n            b = cos_freq.shape[0]\n            t = cos_freq.shape[1]\n            cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)\n            sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1)\n            cos_freqs = torch.swapaxes(cos_freq, 1, 2)\n            sin_freqs = torch.swapaxes(sin_freq, 1, 2)\n\n        # Cast to bf16 to match model weights dtype. 
coords_dtype controls\n        # intermediate coordinate precision (fp32 for audio) and differs.\n        return cos_freqs.to(torch.bfloat16), sin_freqs.to(torch.bfloat16)\n\n\ndef rms_norm(x: torch.Tensor, eps: float) -> torch.Tensor:\n    return F.rms_norm(x, normalized_shape=(x.shape[-1],), eps=eps)\n\n\nclass LTX2TextProjection(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_size: int,\n        out_features: int | None = None,\n        act_fn: str = \"gelu_tanh\",\n    ) -> None:\n        super().__init__()\n        if out_features is None:\n            out_features = hidden_size\n\n        self.linear_1 = ColumnParallelLinear(\n            in_features, hidden_size, bias=True, gather_output=True\n        )\n        if act_fn == \"gelu_tanh\":\n            self.act_1 = nn.GELU(approximate=\"tanh\")\n        elif act_fn == \"silu\":\n            self.act_1 = nn.SiLU()\n        else:\n            raise ValueError(f\"Unknown activation function: {act_fn}\")\n\n        self.linear_2 = ColumnParallelLinear(\n            hidden_size, out_features, bias=True, gather_output=True\n        )\n\n    def forward(self, caption: torch.Tensor) -> torch.Tensor:\n        hidden_states, _ = self.linear_1(caption)\n        hidden_states = self.act_1(hidden_states)\n        hidden_states, _ = self.linear_2(hidden_states)\n        return hidden_states\n\n\nclass LTX2TimestepEmbedder(nn.Module):\n    def __init__(self, embedding_dim: int, in_channels: int = 256) -> None:\n        super().__init__()\n        self.linear_1 = ColumnParallelLinear(\n            in_channels, embedding_dim, bias=True, gather_output=True\n        )\n        self.linear_2 = ColumnParallelLinear(\n            embedding_dim, embedding_dim, bias=True, gather_output=True\n        )\n\n    def forward(self, t_emb: torch.Tensor) -> torch.Tensor:\n        x, _ = self.linear_1(t_emb)\n        x = F.silu(x)\n        x, _ = self.linear_2(x)\n        return x\n\n\nclass 
LTX2PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):\n    def __init__(self, embedding_dim: int) -> None:\n        super().__init__()\n        self.timestep_embedder = LTX2TimestepEmbedder(embedding_dim, in_channels=256)\n\n    def forward(\n        self, timestep: torch.Tensor, hidden_dtype: torch.dtype | None = None\n    ) -> torch.Tensor:\n        t = timestep.reshape(-1).to(dtype=torch.float32)\n        t_emb = timestep_embedding(t, dim=256, max_period=10000, dtype=torch.float32)\n        if hidden_dtype is not None:\n            t_emb = t_emb.to(dtype=hidden_dtype)\n        return self.timestep_embedder(t_emb)\n\n\nclass LTX2AdaLayerNormSingle(nn.Module):\n    def __init__(self, embedding_dim: int, embedding_coefficient: int = 6) -> None:\n        super().__init__()\n        self.emb = LTX2PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim)\n        self.silu = nn.SiLU()\n        self.linear = ColumnParallelLinear(\n            embedding_dim,\n            embedding_coefficient * embedding_dim,\n            bias=True,\n            gather_output=True,\n        )\n\n    def forward(\n        self, timestep: torch.Tensor, hidden_dtype: torch.dtype | None = None\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype).to(\n            dtype=self.linear.weight.dtype\n        )\n        out, _ = self.linear(self.silu(embedded_timestep))\n        return out, embedded_timestep\n\n\nclass LTX2TPRMSNormAcrossHeads(nn.Module):\n    def __init__(\n        self, full_hidden_size: int, local_hidden_size: int, eps: float\n    ) -> None:\n        super().__init__()\n        self.full_hidden_size = full_hidden_size\n        self.local_hidden_size = local_hidden_size\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(local_hidden_size))\n\n        tp_rank = get_tp_rank()\n\n        def _weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:\n            shard = 
loaded_weight.narrow(\n                0, tp_rank * local_hidden_size, local_hidden_size\n            )\n            param.data.copy_(shard.to(dtype=param.dtype, device=param.device))\n\n        setattr(self.weight, \"weight_loader\", _weight_loader)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        # Keep track of the original dtype. We compute the statistics in fp32 for\n        # numerical stability, but cast the normalized output back to the input\n        # dtype before applying the weight.\n        orig_dtype = x.dtype\n        if get_tp_world_size() == 1:\n            var = x.float().pow(2).mean(dim=-1, keepdim=True)\n        else:\n            local_sumsq = x.float().pow(2).sum(dim=-1, keepdim=True)\n            global_sumsq = tensor_model_parallel_all_reduce(local_sumsq)\n            var = global_sumsq / float(self.full_hidden_size)\n\n        inv_rms_fp32 = torch.rsqrt(var + self.eps)\n        y = (x.float() * inv_rms_fp32).to(dtype=orig_dtype)\n        return y * self.weight.to(dtype=orig_dtype)\n\n\nclass LTX2Attention(nn.Module):\n    def __init__(\n        self,\n        query_dim: int,\n        context_dim: int | None = None,\n        heads: int = 8,\n        dim_head: int = 64,\n        norm_eps: float = 1e-6,\n        qk_norm: bool = True,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        prefix: str = \"\",\n        quant_config: QuantizationConfig | None = None,\n    ) -> None:\n        super().__init__()\n\n        self.query_dim = int(query_dim)\n        self.context_dim = int(query_dim if context_dim is None else context_dim)\n        self.heads = int(heads)\n        self.dim_head = int(dim_head)\n        self.inner_dim = self.heads * self.dim_head\n        self.norm_eps = float(norm_eps)\n        self.qk_norm = bool(qk_norm)\n\n        tp_size = get_tp_world_size()\n        if tp_size <= 0:\n            raise ValueError(f\"Invalid {tp_size=}. 
Expected tp_size >= 1.\")\n        if self.heads % tp_size != 0:\n            raise ValueError(\n                f\"LTX2Attention requires heads divisible by tp_size, got \"\n                f\"{self.heads=} {tp_size=}.\"\n            )\n        if self.inner_dim % tp_size != 0:\n            # This should follow from heads % tp_size, but keep explicit for clarity.\n            raise ValueError(\n                f\"LTX2Attention requires inner_dim divisible by tp_size, got \"\n                f\"{self.inner_dim=} {tp_size=}.\"\n            )\n        self.local_heads = self.heads // tp_size\n\n        self.to_q = ColumnParallelLinear(\n            self.query_dim,\n            self.inner_dim,\n            bias=True,\n            gather_output=False,\n            quant_config=quant_config,\n        )\n        self.to_k = ColumnParallelLinear(\n            self.context_dim,\n            self.inner_dim,\n            bias=True,\n            gather_output=False,\n            quant_config=quant_config,\n        )\n        self.to_v = ColumnParallelLinear(\n            self.context_dim,\n            self.inner_dim,\n            bias=True,\n            gather_output=False,\n            quant_config=quant_config,\n        )\n\n        self.q_norm: nn.Module | None = None\n        self.k_norm: nn.Module | None = None\n        if self.qk_norm:\n            if tp_size == 1:\n                self.q_norm = torch.nn.RMSNorm(self.inner_dim, eps=self.norm_eps)\n                self.k_norm = torch.nn.RMSNorm(self.inner_dim, eps=self.norm_eps)\n            else:\n                self.q_norm = LTX2TPRMSNormAcrossHeads(\n                    full_hidden_size=self.inner_dim,\n                    local_hidden_size=self.inner_dim // tp_size,\n                    eps=self.norm_eps,\n                )\n                self.k_norm = LTX2TPRMSNormAcrossHeads(\n                    full_hidden_size=self.inner_dim,\n                    local_hidden_size=self.inner_dim // tp_size,\n                  
  eps=self.norm_eps,\n                )\n\n        self.to_out = nn.Sequential(\n            RowParallelLinear(\n                self.inner_dim,\n                self.query_dim,\n                bias=True,\n                input_is_parallel=True,\n                quant_config=quant_config,\n            ),\n            nn.Identity(),\n        )\n\n        self.attn = USPAttention(\n            num_heads=self.local_heads,\n            head_size=self.dim_head,\n            num_kv_heads=self.local_heads,\n            dropout_rate=0,\n            softmax_scale=None,\n            causal=False,\n            supported_attention_backends=supported_attention_backends,\n            prefix=f\"{prefix}.attn\",\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        context: torch.Tensor | None = None,\n        mask: torch.Tensor | None = None,\n        pe: tuple[torch.Tensor, torch.Tensor] | None = None,\n        k_pe: tuple[torch.Tensor, torch.Tensor] | None = None,\n    ) -> torch.Tensor:\n        q, _ = self.to_q(x)\n        context_ = x if context is None else context\n        k, _ = self.to_k(context_)\n        v, _ = self.to_v(context_)\n\n        if self.qk_norm:\n            assert self.q_norm is not None and self.k_norm is not None\n            q = self.q_norm(q)\n            k = self.k_norm(k)\n\n        if pe is not None:\n            cos, sin = pe\n            k_cos, k_sin = pe if k_pe is None else k_pe\n            tp_size = get_tp_world_size()\n            if tp_size > 1:\n                tp_rank = get_tp_rank()\n                cos, sin = self._slice_rope_for_tp(\n                    cos, sin, tp_rank=tp_rank, tp_size=tp_size\n                )\n                k_cos, k_sin = self._slice_rope_for_tp(\n                    k_cos, k_sin, tp_rank=tp_rank, tp_size=tp_size\n                )\n            if cos.dim() == 3:\n                q = apply_interleaved_rotary_emb(q, (cos, sin))\n                k = apply_interleaved_rotary_emb(k, 
(k_cos, k_sin))\n            else:\n                q = apply_split_rotary_emb(q, (cos, sin))\n                k = apply_split_rotary_emb(k, (k_cos, k_sin))\n\n        q = q.view(*q.shape[:-1], self.local_heads, self.dim_head)\n        k = k.view(*k.shape[:-1], self.local_heads, self.dim_head)\n        v = v.view(*v.shape[:-1], self.local_heads, self.dim_head)\n\n        if mask is not None:\n            # Fallback to SDPA for masked attention\n            q_ = q.transpose(1, 2)\n            k_ = k.transpose(1, 2)\n            v_ = v.transpose(1, 2)\n\n            if torch.is_floating_point(mask):\n                m = mask\n                if m.dim() == 2:\n                    m = m[:, None, None, :]\n                elif m.dim() == 3:\n                    m = m[:, None, :, :]\n                sdpa_mask = m.to(dtype=q_.dtype, device=q_.device)\n            else:\n                m = mask.to(dtype=q_.dtype, device=q_.device)\n                if m.dim() == 2:\n                    m = m[:, None, None, :]\n                elif m.dim() == 3:\n                    m = m[:, None, :, :]\n                sdpa_mask = (m - 1.0) * torch.finfo(q_.dtype).max\n\n            out = torch.nn.functional.scaled_dot_product_attention(\n                q_, k_, v_, attn_mask=sdpa_mask, dropout_p=0.0, is_causal=False\n            ).transpose(1, 2)\n        else:\n            out = self.attn(q, k, v)\n\n        out = out.flatten(2)\n        out, _ = self.to_out[0](out)\n        return out\n\n    def _slice_rope_for_tp(\n        self,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        *,\n        tp_rank: int,\n        tp_size: int,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Slice RoPE tensors to the local TP shard.\n\n        - split-rope: cos/sin are shaped [B, H, T, R] (head-major), slice by heads.\n        - interleaved-rope: cos/sin are shaped [B, T, D], where D matches the projected\n          feature dimension and is sharded by TP.\n        \"\"\"\n      
  if cos.ndim == 4:\n            # [B, H, T, R]\n            start = tp_rank * self.local_heads\n            end = start + self.local_heads\n            return cos[:, start:end, :, :], sin[:, start:end, :, :]\n        elif cos.ndim == 3:\n            # [B, T, D]\n            d = cos.shape[-1]\n            if d % tp_size != 0:\n                raise ValueError(\n                    f\"RoPE dim must be divisible by tp_size, got {d=} {tp_size=}.\"\n                )\n            local_d = d // tp_size\n            start = tp_rank * local_d\n            end = start + local_d\n            return cos[:, :, start:end], sin[:, :, start:end]\n        raise ValueError(f\"Unexpected RoPE tensor rank: {cos.ndim}. Expected 3 or 4.\")\n\n\nclass LTX2FeedForward(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        dim_out: int | None = None,\n        mult: int = 4,\n        quant_config: QuantizationConfig | None = None,\n    ) -> None:\n        super().__init__()\n        if dim_out is None:\n            dim_out = dim\n        inner_dim = int(dim * mult)\n\n        self.proj_in = ColumnParallelLinear(\n            dim, inner_dim, bias=True, gather_output=True, quant_config=quant_config\n        )\n        self.act = nn.GELU(approximate=\"tanh\")\n        self.proj_out = ColumnParallelLinear(\n            inner_dim, dim_out, bias=True, gather_output=True, quant_config=quant_config\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x, _ = self.proj_in(x)\n        x = self.act(x)\n        x, _ = self.proj_out(x)\n        return x\n\n\nclass LTX2TransformerBlock(nn.Module):\n    def __init__(\n        self,\n        idx: int,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        cross_attention_dim: int,\n        audio_dim: int,\n        audio_num_attention_heads: int,\n        audio_attention_head_dim: int,\n        audio_cross_attention_dim: int,\n        qk_norm: bool = True,\n        
norm_eps: float = 1e-6,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        prefix: str = \"\",\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__()\n        self.idx = idx\n        self.norm_eps = norm_eps\n\n        # 1. Self-Attention (video and audio)\n        self.attn1 = LTX2Attention(\n            query_dim=dim,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            norm_eps=norm_eps,\n            qk_norm=qk_norm,\n            supported_attention_backends=supported_attention_backends,\n            prefix=f\"{prefix}.attn1\",\n            quant_config=quant_config,\n        )\n        self.audio_attn1 = LTX2Attention(\n            query_dim=audio_dim,\n            heads=audio_num_attention_heads,\n            dim_head=audio_attention_head_dim,\n            norm_eps=norm_eps,\n            qk_norm=qk_norm,\n            supported_attention_backends=supported_attention_backends,\n            prefix=f\"{prefix}.audio_attn1\",\n            quant_config=quant_config,\n        )\n\n        # 2. 
Prompt Cross-Attention\n        self.attn2 = LTX2Attention(\n            query_dim=dim,\n            context_dim=cross_attention_dim,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            norm_eps=norm_eps,\n            qk_norm=qk_norm,\n            supported_attention_backends=supported_attention_backends,\n            prefix=f\"{prefix}.attn2\",\n            quant_config=quant_config,\n        )\n        self.audio_attn2 = LTX2Attention(\n            query_dim=audio_dim,\n            context_dim=audio_cross_attention_dim,\n            heads=audio_num_attention_heads,\n            dim_head=audio_attention_head_dim,\n            norm_eps=norm_eps,\n            qk_norm=qk_norm,\n            supported_attention_backends=supported_attention_backends,\n            prefix=f\"{prefix}.audio_attn2\",\n            quant_config=quant_config,\n        )\n\n        # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention\n        self.audio_to_video_attn = LTX2Attention(\n            query_dim=dim,\n            context_dim=audio_dim,\n            heads=audio_num_attention_heads,\n            dim_head=audio_attention_head_dim,\n            norm_eps=norm_eps,\n            qk_norm=qk_norm,\n            supported_attention_backends=supported_attention_backends,\n            prefix=f\"{prefix}.audio_to_video_attn\",\n            quant_config=quant_config,\n        )\n        self.video_to_audio_attn = LTX2Attention(\n            query_dim=audio_dim,\n            context_dim=dim,\n            heads=audio_num_attention_heads,\n            dim_head=audio_attention_head_dim,\n            norm_eps=norm_eps,\n            qk_norm=qk_norm,\n            supported_attention_backends=supported_attention_backends,\n            prefix=f\"{prefix}.video_to_audio_attn\",\n            quant_config=quant_config,\n        )\n\n        # 4. 
Feedforward layers\n        self.ff = LTX2FeedForward(dim, dim_out=dim, quant_config=quant_config)\n        self.audio_ff = LTX2FeedForward(\n            audio_dim, dim_out=audio_dim, quant_config=quant_config\n        )\n\n        # 5. Modulation Parameters\n        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)\n        self.audio_scale_shift_table = nn.Parameter(\n            torch.randn(6, audio_dim) / audio_dim**0.5\n        )\n        self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))\n        self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(\n            torch.randn(5, audio_dim)\n        )\n\n    def get_ada_values(\n        self,\n        scale_shift_table: torch.Tensor,\n        batch_size: int,\n        timestep: torch.Tensor,\n        indices: slice,\n    ) -> tuple[torch.Tensor, ...]:\n        num_ada_params = int(scale_shift_table.shape[0])\n        ada_values = (\n            scale_shift_table[indices]\n            .unsqueeze(0)\n            .unsqueeze(0)\n            .to(device=timestep.device, dtype=timestep.dtype)\n            + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[\n                :, :, indices, :\n            ]\n        ).unbind(dim=2)\n        return tuple(t.squeeze(2) for t in ada_values)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        audio_hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        audio_encoder_hidden_states: torch.Tensor,\n        temb: torch.Tensor,\n        temb_audio: torch.Tensor,\n        temb_ca_scale_shift: torch.Tensor,\n        temb_ca_audio_scale_shift: torch.Tensor,\n        temb_ca_gate: torch.Tensor,\n        temb_ca_audio_gate: torch.Tensor,\n        video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        ca_video_rotary_emb: Optional[Tuple[torch.Tensor, 
torch.Tensor]] = None,\n        ca_audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        audio_encoder_attention_mask: Optional[torch.Tensor] = None,\n        a2v_cross_attention_mask: Optional[torch.Tensor] = None,\n        v2a_cross_attention_mask: Optional[torch.Tensor] = None,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n\n        batch_size = hidden_states.size(0)\n\n        # 1. Video and Audio Self-Attention\n        vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(\n            self.scale_shift_table, batch_size, temb, slice(0, 3)\n        )\n        norm_hidden_states = (\n            rms_norm(hidden_states, self.norm_eps) * (1 + vscale_msa) + vshift_msa\n        )\n        attn_hidden_states = self.attn1(norm_hidden_states, pe=video_rotary_emb)\n        hidden_states = hidden_states + attn_hidden_states * vgate_msa\n\n        ashift_msa, ascale_msa, agate_msa = self.get_ada_values(\n            self.audio_scale_shift_table, batch_size, temb_audio, slice(0, 3)\n        )\n        norm_audio_hidden_states = (\n            rms_norm(audio_hidden_states, self.norm_eps) * (1 + ascale_msa) + ashift_msa\n        )\n        attn_audio_hidden_states = self.audio_attn1(\n            norm_audio_hidden_states, pe=audio_rotary_emb\n        )\n        audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * agate_msa\n\n        # 2. 
Prompt Cross-Attention\n        norm_hidden_states = rms_norm(hidden_states, self.norm_eps)\n        attn_hidden_states = self.attn2(\n            norm_hidden_states,\n            context=encoder_hidden_states,\n            mask=encoder_attention_mask,\n        )\n        hidden_states = hidden_states + attn_hidden_states\n\n        norm_audio_hidden_states = rms_norm(audio_hidden_states, self.norm_eps)\n        attn_audio_hidden_states = self.audio_attn2(\n            norm_audio_hidden_states,\n            context=audio_encoder_hidden_states,\n            mask=audio_encoder_attention_mask,\n        )\n        audio_hidden_states = audio_hidden_states + attn_audio_hidden_states\n\n        # 3. Audio-to-Video and Video-to-Audio Cross-Attention\n        norm_hidden_states = rms_norm(hidden_states, self.norm_eps)\n        norm_audio_hidden_states = rms_norm(audio_hidden_states, self.norm_eps)\n\n        # Compute combined ada params\n        video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[\n            :4, :\n        ]\n        video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]\n\n        video_ca_scale_shift_table = (\n            video_per_layer_ca_scale_shift[None, None, :, :].to(\n                dtype=temb_ca_scale_shift.dtype, device=temb_ca_scale_shift.device\n            )\n            + temb_ca_scale_shift.reshape(\n                batch_size, temb_ca_scale_shift.shape[1], 4, -1\n            )\n        ).unbind(dim=2)\n        video_ca_gate = (\n            video_per_layer_ca_gate[None, None, :, :].to(\n                dtype=temb_ca_gate.dtype, device=temb_ca_gate.device\n            )\n            + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)\n        ).unbind(dim=2)\n\n        (\n            video_a2v_ca_scale,\n            video_a2v_ca_shift,\n            video_v2a_ca_scale,\n            video_v2a_ca_shift,\n        ) = [t.squeeze(2) for t in video_ca_scale_shift_table]\n        
a2v_gate = video_ca_gate[0].squeeze(2)\n\n        audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[\n            :4, :\n        ]\n        audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]\n\n        audio_ca_scale_shift_table = (\n            audio_per_layer_ca_scale_shift[None, None, :, :].to(\n                dtype=temb_ca_audio_scale_shift.dtype,\n                device=temb_ca_audio_scale_shift.device,\n            )\n            + temb_ca_audio_scale_shift.reshape(\n                batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1\n            )\n        ).unbind(dim=2)\n        audio_ca_gate = (\n            audio_per_layer_ca_gate[None, None, :, :].to(\n                dtype=temb_ca_audio_gate.dtype, device=temb_ca_audio_gate.device\n            )\n            + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)\n        ).unbind(dim=2)\n\n        (\n            audio_a2v_ca_scale,\n            audio_a2v_ca_shift,\n            audio_v2a_ca_scale,\n            audio_v2a_ca_shift,\n        ) = [t.squeeze(2) for t in audio_ca_scale_shift_table]\n        v2a_gate = audio_ca_gate[0].squeeze(2)\n\n        # A2V\n        mod_norm_hidden_states = (\n            norm_hidden_states * (1 + video_a2v_ca_scale) + video_a2v_ca_shift\n        )\n        mod_norm_audio_hidden_states = (\n            norm_audio_hidden_states * (1 + audio_a2v_ca_scale) + audio_a2v_ca_shift\n        )\n\n        a2v_attn_hidden_states = self.audio_to_video_attn(\n            mod_norm_hidden_states,\n            context=mod_norm_audio_hidden_states,\n            pe=ca_video_rotary_emb,\n            k_pe=ca_audio_rotary_emb,\n            mask=a2v_cross_attention_mask,\n        )\n        hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states\n\n        # V2A\n        mod_norm_hidden_states = (\n            norm_hidden_states * (1 + video_v2a_ca_scale) + video_v2a_ca_shift\n        )\n        
mod_norm_audio_hidden_states = (\n            norm_audio_hidden_states * (1 + audio_v2a_ca_scale) + audio_v2a_ca_shift\n        )\n\n        v2a_attn_hidden_states = self.video_to_audio_attn(\n            mod_norm_audio_hidden_states,\n            context=mod_norm_hidden_states,\n            pe=ca_audio_rotary_emb,\n            k_pe=ca_video_rotary_emb,\n            mask=v2a_cross_attention_mask,\n        )\n        audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states\n\n        # 4. Feedforward\n        vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(\n            self.scale_shift_table, batch_size, temb, slice(3, None)\n        )\n        norm_hidden_states = (\n            rms_norm(hidden_states, self.norm_eps) * (1 + vscale_mlp) + vshift_mlp\n        )\n        ff_output = self.ff(norm_hidden_states)\n        hidden_states = hidden_states + ff_output * vgate_mlp\n\n        ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(\n            self.audio_scale_shift_table, batch_size, temb_audio, slice(3, None)\n        )\n        norm_audio_hidden_states = (\n            rms_norm(audio_hidden_states, self.norm_eps) * (1 + ascale_mlp) + ashift_mlp\n        )\n        audio_ff_output = self.audio_ff(norm_audio_hidden_states)\n        audio_hidden_states = audio_hidden_states + audio_ff_output * agate_mlp\n\n        return hidden_states, audio_hidden_states\n\n\nclass LTX2VideoTransformer3DModel(CachableDiT, OffloadableDiTMixin):\n    _fsdp_shard_conditions = LTX2ArchConfig()._fsdp_shard_conditions\n    _compile_conditions = LTX2ArchConfig()._compile_conditions\n    _supported_attention_backends = LTX2ArchConfig()._supported_attention_backends\n    param_names_mapping = LTX2ArchConfig().param_names_mapping\n    reverse_param_names_mapping = LTX2ArchConfig().reverse_param_names_mapping\n    lora_param_names_mapping = LTX2ArchConfig().lora_param_names_mapping\n\n    def _validate_tp_config(self, *, arch: LTX2ArchConfig, tp_size: int) 
-> None:\n        \"\"\"Validate TP-related dimension constraints (fail-fast).\"\"\"\n        if tp_size < 1:\n            raise ValueError(f\"Invalid tp_size={tp_size}. Expected tp_size >= 1.\")\n\n        if self.hidden_size % self.num_attention_heads != 0:\n            raise ValueError(\n                \"video hidden_size must be divisible by num_attention_heads, got \"\n                f\"{self.hidden_size=} {self.num_attention_heads=}.\"\n            )\n        if self.audio_hidden_size % self.audio_num_attention_heads != 0:\n            raise ValueError(\n                \"audio_hidden_size must be divisible by audio_num_attention_heads, got \"\n                f\"{self.audio_hidden_size=} {self.audio_num_attention_heads=}.\"\n            )\n\n        if tp_size == 1:\n            return\n\n        if self.num_attention_heads % tp_size != 0:\n            raise ValueError(\n                \"num_attention_heads must be divisible by tp_size, got \"\n                f\"{self.num_attention_heads=} {tp_size=}.\"\n            )\n        if self.audio_num_attention_heads % tp_size != 0:\n            raise ValueError(\n                \"audio_num_attention_heads must be divisible by tp_size, got \"\n                f\"{self.audio_num_attention_heads=} {tp_size=}.\"\n            )\n        if self.hidden_size % tp_size != 0:\n            raise ValueError(\n                \"hidden_size must be divisible by tp_size for TP-sharded projections, got \"\n                f\"{self.hidden_size=} {tp_size=}.\"\n            )\n        if self.audio_hidden_size % tp_size != 0:\n            raise ValueError(\n                \"audio_hidden_size must be divisible by tp_size for TP-sharded projections, got \"\n                f\"{self.audio_hidden_size=} {tp_size=}.\"\n            )\n        if int(arch.out_channels) % tp_size != 0:\n            raise ValueError(\n                \"out_channels must be divisible by tp_size for TP-sharded output projection, got \"\n                
f\"{arch.out_channels=} {tp_size=}.\"\n            )\n        if int(arch.audio_out_channels) % tp_size != 0:\n            raise ValueError(\n                \"audio_out_channels must be divisible by tp_size for TP-sharded output projection, got \"\n                f\"{arch.audio_out_channels=} {tp_size=}.\"\n            )\n\n    def __init__(\n        self,\n        config: LTX2Config,\n        hf_config: dict[str, Any],\n        quant_config: QuantizationConfig | None = None,\n    ) -> None:\n        super().__init__(config=config, hf_config=hf_config)\n\n        arch = config.arch_config\n        self.hidden_size = arch.hidden_size\n        self.num_attention_heads = arch.num_attention_heads\n        self.audio_hidden_size = arch.audio_hidden_size\n        self.audio_num_attention_heads = arch.audio_num_attention_heads\n        self.norm_eps = arch.norm_eps\n\n        tp_size = get_tp_world_size()\n        self._validate_tp_config(arch=arch, tp_size=tp_size)\n\n        # 1. Patchification input projections\n        # Matches LTX2Config().param_names_mapping\n        self.patchify_proj = ColumnParallelLinear(\n            arch.in_channels,\n            self.hidden_size,\n            bias=True,\n            gather_output=True,\n            quant_config=quant_config,\n        )\n        self.audio_patchify_proj = ColumnParallelLinear(\n            arch.audio_in_channels,\n            self.audio_hidden_size,\n            bias=True,\n            gather_output=True,\n            quant_config=quant_config,\n        )\n\n        # 2. Prompt embeddings\n        self.caption_projection = LTX2TextProjection(\n            in_features=arch.caption_channels, hidden_size=self.hidden_size\n        )\n        self.audio_caption_projection = LTX2TextProjection(\n            in_features=arch.caption_channels, hidden_size=self.audio_hidden_size\n        )\n\n        # 3. 
Timestep Modulation Params and Embedding\n        self.adaln_single = LTX2AdaLayerNormSingle(\n            self.hidden_size, embedding_coefficient=6\n        )\n        self.audio_adaln_single = LTX2AdaLayerNormSingle(\n            self.audio_hidden_size, embedding_coefficient=6\n        )\n\n        # Global Cross Attention Modulation Parameters\n        self.av_ca_video_scale_shift_adaln_single = LTX2AdaLayerNormSingle(\n            self.hidden_size, embedding_coefficient=4\n        )\n        self.av_ca_a2v_gate_adaln_single = LTX2AdaLayerNormSingle(\n            self.hidden_size, embedding_coefficient=1\n        )\n        self.av_ca_audio_scale_shift_adaln_single = LTX2AdaLayerNormSingle(\n            self.audio_hidden_size, embedding_coefficient=4\n        )\n        self.av_ca_v2a_gate_adaln_single = LTX2AdaLayerNormSingle(\n            self.audio_hidden_size, embedding_coefficient=1\n        )\n\n        # Output Layer Scale/Shift Modulation parameters\n        self.scale_shift_table = nn.Parameter(\n            torch.randn(2, self.hidden_size) / self.hidden_size**0.5\n        )\n        self.audio_scale_shift_table = nn.Parameter(\n            torch.randn(2, self.audio_hidden_size) / self.audio_hidden_size**0.5\n        )\n\n        hf_patch_size = int(hf_config.get(\"patch_size\", 1))\n        hf_patch_size_t = int(hf_config.get(\"patch_size_t\", 1))\n        self.patch_size = (hf_patch_size_t, hf_patch_size, hf_patch_size)\n\n        hf_audio_patch_size = int(hf_config.get(\"audio_patch_size\", 1))\n        hf_audio_patch_size_t = int(hf_config.get(\"audio_patch_size_t\", 1))\n\n        rope_type = (\n            arch.rope_type.value\n            if hasattr(arch.rope_type, \"value\")\n            else str(arch.rope_type)\n        )\n        rope_double_precision = bool(\n            hf_config.get(\"rope_double_precision\", arch.double_precision_rope)\n        )\n        causal_offset = int(hf_config.get(\"causal_offset\", 1))\n\n        pos_embed_max_pos 
= int(arch.positional_embedding_max_pos[0])\n        base_height = int(arch.positional_embedding_max_pos[1])\n        base_width = int(arch.positional_embedding_max_pos[2])\n\n        audio_pos_embed_max_pos = int(arch.audio_positional_embedding_max_pos[0])\n\n        self.video_scale_factors = (8, 32, 32)\n        self.audio_scale_factors = (4,)\n\n        self.rope = LTX2AudioVideoRotaryPosEmbed(\n            dim=self.hidden_size,\n            patch_size=hf_patch_size,\n            patch_size_t=hf_patch_size_t,\n            base_num_frames=pos_embed_max_pos,\n            base_height=base_height,\n            base_width=base_width,\n            scale_factors=self.video_scale_factors,\n            theta=float(arch.positional_embedding_theta),\n            causal_offset=causal_offset,\n            modality=\"video\",\n            double_precision=rope_double_precision,\n            rope_type=rope_type,\n            num_attention_heads=self.num_attention_heads,\n        )\n        self.audio_rope = LTX2AudioVideoRotaryPosEmbed(\n            dim=self.audio_hidden_size,\n            patch_size=hf_audio_patch_size,\n            patch_size_t=hf_audio_patch_size_t,\n            base_num_frames=audio_pos_embed_max_pos,\n            sampling_rate=16000,\n            hop_length=160,\n            scale_factors=self.audio_scale_factors,\n            theta=float(arch.positional_embedding_theta),\n            causal_offset=causal_offset,\n            modality=\"audio\",\n            double_precision=rope_double_precision,\n            rope_type=rope_type,\n            num_attention_heads=self.audio_num_attention_heads,\n        )\n\n        cross_attn_pos_embed_max_pos = max(pos_embed_max_pos, audio_pos_embed_max_pos)\n        self.cross_attn_rope = LTX2AudioVideoRotaryPosEmbed(\n            dim=int(arch.audio_cross_attention_dim),\n            patch_size=hf_patch_size,\n            patch_size_t=hf_patch_size_t,\n            base_num_frames=cross_attn_pos_embed_max_pos,\n        
    base_height=base_height,\n            base_width=base_width,\n            theta=float(arch.positional_embedding_theta),\n            causal_offset=causal_offset,\n            modality=\"video\",\n            double_precision=rope_double_precision,\n            rope_type=rope_type,\n            num_attention_heads=self.num_attention_heads,\n        )\n        self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed(\n            dim=int(arch.audio_cross_attention_dim),\n            patch_size=hf_audio_patch_size,\n            patch_size_t=hf_audio_patch_size_t,\n            base_num_frames=cross_attn_pos_embed_max_pos,\n            sampling_rate=16000,\n            hop_length=160,\n            theta=float(arch.positional_embedding_theta),\n            causal_offset=causal_offset,\n            modality=\"audio\",\n            double_precision=rope_double_precision,\n            rope_type=rope_type,\n            num_attention_heads=self.audio_num_attention_heads,\n        )\n\n        self.cross_pe_max_pos = cross_attn_pos_embed_max_pos\n\n        # 5. 
Transformer Blocks\n        self.transformer_blocks = nn.ModuleList(\n            [\n                LTX2TransformerBlock(\n                    idx=idx,\n                    dim=self.hidden_size,\n                    num_attention_heads=self.num_attention_heads,\n                    attention_head_dim=self.hidden_size // self.num_attention_heads,\n                    cross_attention_dim=arch.cross_attention_dim,\n                    audio_dim=self.audio_hidden_size,\n                    audio_num_attention_heads=self.audio_num_attention_heads,\n                    audio_attention_head_dim=self.audio_hidden_size\n                    // self.audio_num_attention_heads,\n                    audio_cross_attention_dim=arch.audio_cross_attention_dim,\n                    norm_eps=self.norm_eps,\n                    qk_norm=True,  # Always True in LTX2\n                    supported_attention_backends=self._supported_attention_backends,\n                    prefix=config.prefix,\n                    quant_config=quant_config,\n                )\n                for idx in range(arch.num_layers)\n            ]\n        )\n\n        # 6. 
Output layers\n        self.norm_out = nn.LayerNorm(\n            self.hidden_size, eps=self.norm_eps, elementwise_affine=False\n        )\n        self.proj_out = ColumnParallelLinear(\n            self.hidden_size,\n            arch.out_channels,\n            bias=True,\n            gather_output=True,\n            quant_config=quant_config,\n        )\n\n        self.audio_norm_out = nn.LayerNorm(\n            self.audio_hidden_size, eps=self.norm_eps, elementwise_affine=False\n        )\n        self.audio_proj_out = ColumnParallelLinear(\n            self.audio_hidden_size,\n            arch.audio_out_channels,\n            bias=True,\n            gather_output=True,\n            quant_config=quant_config,\n        )\n\n        self.out_channels_raw = arch.out_channels // (\n            self.patch_size[0] * self.patch_size[1] * self.patch_size[2]\n        )\n        self.audio_out_channels = arch.audio_out_channels\n        self.timestep_scale_multiplier = arch.timestep_scale_multiplier\n        self.av_ca_timestep_scale_multiplier = arch.av_ca_timestep_scale_multiplier\n\n        self.layer_names = [\"transformer_blocks\"]\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        audio_hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        audio_encoder_hidden_states: torch.Tensor,\n        timestep: torch.LongTensor,\n        audio_timestep: Optional[torch.LongTensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        audio_encoder_attention_mask: Optional[torch.Tensor] = None,\n        num_frames: Optional[int] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        fps: float = 24.0,\n        audio_num_frames: Optional[int] = None,\n        video_coords: Optional[torch.Tensor] = None,\n        audio_coords: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:\n\n        batch_size = 
hidden_states.size(0)\n        audio_timestep = audio_timestep if audio_timestep is not None else timestep\n\n        if num_frames is None or height is None or width is None:\n            raise ValueError(\n                \"num_frames/height/width must be provided for RoPE coordinate generation.\"\n            )\n        if audio_num_frames is None:\n            raise ValueError(\n                \"audio_num_frames must be provided for RoPE coordinate generation.\"\n            )\n\n        if video_coords is None:\n            # Wan-style SP-RoPE: when SP is enabled, each rank runs on its local\n            # time shard but RoPE positions must be offset to global time.\n            #\n            # We assume equal time sharding across SP ranks.\n            if model_parallel_is_initialized():\n                sp_world_size = get_sp_world_size()\n                sp_rank = get_sp_parallel_rank()\n            else:\n                sp_world_size = 1\n                sp_rank = 0\n\n            video_shift = int(sp_rank) * int(num_frames) if sp_world_size > 1 else 0\n            video_coords = self.rope.prepare_video_coords(\n                batch_size=batch_size,\n                num_frames=num_frames,\n                height=height,\n                width=width,\n                device=hidden_states.device,\n                fps=fps,\n                start_frame=video_shift,\n            )\n        if audio_coords is None:\n            audio_coords = self.audio_rope.prepare_audio_coords(\n                batch_size=batch_size,\n                num_frames=audio_num_frames,\n                device=audio_hidden_states.device,\n            )\n\n        video_rotary_emb = self.rope(video_coords, device=hidden_states.device)\n        audio_rotary_emb = self.audio_rope(\n            audio_coords, device=audio_hidden_states.device\n        )\n        ca_video_rotary_emb = self.cross_attn_rope(\n            video_coords[:, 0:1, :], device=hidden_states.device\n        )\n    
    ca_audio_rotary_emb = self.cross_attn_audio_rope(\n            audio_coords[:, 0:1, :], device=audio_hidden_states.device\n        )\n\n        # 2. Patchify input projections\n        hidden_states, _ = self.patchify_proj(hidden_states)\n        audio_hidden_states, _ = self.audio_patchify_proj(audio_hidden_states)\n\n        # 3. Prepare timestep embeddings\n        # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters\n        temb, embedded_timestep = self.adaln_single(\n            timestep.flatten(),\n        )\n        temb = temb.view(batch_size, -1, temb.size(-1))\n        embedded_timestep = embedded_timestep.view(\n            batch_size, -1, embedded_timestep.size(-1)\n        )\n\n        temb_audio, audio_embedded_timestep = self.audio_adaln_single(\n            audio_timestep.flatten()\n        )\n        temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))\n        audio_embedded_timestep = audio_embedded_timestep.view(\n            batch_size, -1, audio_embedded_timestep.size(-1)\n        )\n\n        # 3.2. 
Prepare global modality cross attention modulation parameters\n        ts_ca_mult = (\n            self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier\n        )\n\n        hidden_dtype = hidden_states.dtype\n        temb_ca_scale_shift, _ = self.av_ca_video_scale_shift_adaln_single(\n            timestep.flatten(), hidden_dtype=hidden_dtype\n        )\n        temb_ca_scale_shift = temb_ca_scale_shift.view(\n            batch_size, -1, temb_ca_scale_shift.shape[-1]\n        )\n\n        temb_ca_gate, _ = self.av_ca_a2v_gate_adaln_single(\n            timestep.flatten() * self.av_ca_timestep_scale_multiplier,\n            hidden_dtype=hidden_dtype,\n        )\n        temb_ca_gate = temb_ca_gate.view(batch_size, -1, temb_ca_gate.shape[-1])\n\n        temb_ca_audio_scale_shift, _ = self.av_ca_audio_scale_shift_adaln_single(\n            audio_timestep.flatten(), hidden_dtype=audio_hidden_states.dtype\n        )\n        temb_ca_audio_scale_shift = temb_ca_audio_scale_shift.view(\n            batch_size, -1, temb_ca_audio_scale_shift.shape[-1]\n        )\n\n        temb_ca_audio_gate, _ = self.av_ca_v2a_gate_adaln_single(\n            audio_timestep.flatten() * self.av_ca_timestep_scale_multiplier,\n            hidden_dtype=audio_hidden_states.dtype,\n        )\n        temb_ca_audio_gate = temb_ca_audio_gate.view(\n            batch_size, -1, temb_ca_audio_gate.shape[-1]\n        )\n\n        # 4. Prepare prompt embeddings\n        encoder_hidden_states = self.caption_projection(encoder_hidden_states)\n        audio_encoder_hidden_states = self.audio_caption_projection(\n            audio_encoder_hidden_states\n        )\n\n        # 5. 
Run blocks\n        for block in self.transformer_blocks:\n            hidden_states, audio_hidden_states = block(\n                hidden_states,\n                audio_hidden_states,\n                encoder_hidden_states,\n                audio_encoder_hidden_states,\n                # Keep the first 4 args positional to stay compatible with cache-dit's\n                # LTX2 adapter, which treats `audio_hidden_states` as `encoder_hidden_states`\n                # under ForwardPattern.Pattern_0.\n                temb=temb,\n                temb_audio=temb_audio,\n                temb_ca_scale_shift=temb_ca_scale_shift,\n                temb_ca_audio_scale_shift=temb_ca_audio_scale_shift,\n                temb_ca_gate=temb_ca_gate,\n                temb_ca_audio_gate=temb_ca_audio_gate,\n                video_rotary_emb=video_rotary_emb,\n                audio_rotary_emb=audio_rotary_emb,\n                ca_video_rotary_emb=ca_video_rotary_emb,\n                ca_audio_rotary_emb=ca_audio_rotary_emb,\n                encoder_attention_mask=encoder_attention_mask,\n                audio_encoder_attention_mask=audio_encoder_attention_mask,\n            )\n\n        # 6. 
Output layers\n        # Video\n        scale_shift_values = self.scale_shift_table[None, None].to(\n            device=hidden_states.device, dtype=hidden_states.dtype\n        ) + embedded_timestep[:, :, None].to(dtype=hidden_states.dtype)\n        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]\n        with torch.autocast(device_type=hidden_states.device.type, enabled=False):\n            hidden_states = self.norm_out(hidden_states)\n        hidden_states = hidden_states * (1 + scale) + shift\n        hidden_states, _ = self.proj_out(hidden_states)\n\n        # Audio\n        audio_scale_shift_values = self.audio_scale_shift_table[None, None].to(\n            device=audio_hidden_states.device, dtype=audio_hidden_states.dtype\n        ) + audio_embedded_timestep[:, :, None].to(dtype=audio_hidden_states.dtype)\n        audio_shift, audio_scale = (\n            audio_scale_shift_values[:, :, 0],\n            audio_scale_shift_values[:, :, 1],\n        )\n        with torch.autocast(device_type=audio_hidden_states.device.type, enabled=False):\n            audio_hidden_states = self.audio_norm_out(audio_hidden_states)\n        audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift\n        audio_hidden_states, _ = self.audio_proj_out(audio_hidden_states)\n\n        # Unpatchify if requested (default True for pipeline compatibility)\n        return_latents = kwargs.get(\"return_latents\", True)\n\n        if return_latents:\n            # Unpatchify Video\n            # [B, N, C_out_raw*patch_vol] -> [B, C_out_raw, T, H, W]\n            # Requires num_frames, height, width to be known\n            if num_frames is not None and height is not None and width is not None:\n                p_t, p_h, p_w = self.patch_size\n                post_t, post_h, post_w = num_frames // p_t, height // p_h, width // p_w\n                b = batch_size\n                hidden_states = hidden_states.reshape(\n                    b, post_t, 
post_h, post_w, self.out_channels_raw, p_t, p_h, p_w\n                )\n                hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7).reshape(\n                    b, self.out_channels_raw, num_frames, height, width\n                )\n\n            # Unpatchify Audio\n            # [B, N, C_out] -> [B, C_out, N]\n            if audio_num_frames is not None:\n                # Assumes 1D audio patches of size 1, so unpatchifying reduces to a transpose.\n                audio_hidden_states = audio_hidden_states.permute(0, 2, 1)  # [B, C, T]\n\n        return hidden_states, audio_hidden_states\n\n\n# Backward-compatible alias (older internal name).\nLTXModel = LTX2VideoTransformer3DModel\nEntryClass = LTX2VideoTransformer3DModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py",
    "content": "# Copied and adapted from: mossVG/mova/diffusion/models/wan_audio_dit.py\n# SPDX-License-Identifier: Apache-2.0\n#\n# NOTE: This module reuses common functions from mova_video_dit.py to reduce code duplication.\n# Audio-specific functions (precompute_freqs_cis_1d, legacy_precompute_freqs_cis_1d) are kept here.\n\nimport math\nfrom typing import Any, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom torch.distributed.tensor import DTensor\n\nfrom sglang.multimodal_gen.configs.models.dits.mova_audio import MOVAAudioConfig\nfrom sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear\nfrom sglang.multimodal_gen.runtime.layers.mlp import MLP\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n)\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\n\n# Reuse common functions and classes from mova_video_dit\nfrom .mova_video_dit import DiTBlock, precompute_freqs_cis, sinusoidal_embedding_1d\n\n\n# Audio-specific positional encoding functions\ndef legacy_precompute_freqs_cis_1d(\n    dim: int,\n    end: int = 16384,\n    theta: float = 10000.0,\n    base_tps=4.0,\n    target_tps=44100 / 2048,\n):\n    s = float(base_tps) / float(target_tps)\n    # 1d rope precompute\n    f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta, s)\n    # No positional encoding is applied to the remaining dimensions\n    no_freqs_cis = precompute_freqs_cis(dim // 3, end, theta, s)\n    no_freqs_cis = torch.ones_like(no_freqs_cis)\n    return f_freqs_cis, no_freqs_cis, no_freqs_cis\n\n\ndef precompute_freqs_cis_1d(dim: int, end: int = 16384, theta: float = 10000.0):\n    f_freqs_cis = precompute_freqs_cis(dim, end, theta)\n    return f_freqs_cis.chunk(3, dim=-1)\n\n\nclass Head(nn.Module):\n    def __init__(\n        self, dim: int, out_dim: int, 
patch_size: Tuple[int, int, int], eps: float\n    ):\n        super().__init__()\n        self.dim = dim\n        self.patch_size = patch_size\n        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)\n        self.head = ReplicatedLinear(dim, out_dim * math.prod(patch_size))\n        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)\n\n    def forward(self, x, t_mod):\n        if len(t_mod.shape) == 3:\n            shift, scale = (\n                self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device)\n                + t_mod.unsqueeze(2)\n            ).chunk(2, dim=2)\n            x, _ = self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))\n        else:\n            # NOTE: t_mod was originally [B, C]. This works correctly with broadcasting when B=1, but it won't match [1, 2, C] when B > 1.\n            shift, scale = (\n                self.modulation.to(dtype=t_mod.dtype, device=t_mod.device)\n                + t_mod.unsqueeze(1)\n            ).chunk(2, dim=1)\n            x, _ = self.head(self.norm(x) * (1 + scale) + shift)\n        return x\n\n\nclass Conv1dLocalIsland(nn.Conv1d):\n    \"\"\"Conv1d that runs a purely local convolution on DTensor inputs.\n\n    - Parameters remain DTensors (optimizer consistency is maintained).\n    - When the input is a DTensor, the local tensors of x, weight, and bias are\n      taken via to_local() and a regular convolution is run on them; no\n      redistribution to Replicate is performed.\n    - The output is returned as a plain local tensor, not as a DTensor.\n    - Non-DTensor inputs use the standard nn.Conv1d forward.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def forward(self, input):\n        if isinstance(input, DTensor):\n            x_local = input.to_local()  # type: ignore[attr-defined]\n            w_local = self.weight.to_local()  # type: ignore[attr-defined]\n            b_local = (\n                self.bias.to_local() if self.bias is not None else None  # type: 
ignore[attr-defined]\n            )\n\n            return self._conv_forward(x_local, w_local, b_local)\n        else:\n            return super().forward(input)\n\n\nclass WanAudioModel(CachableDiT, OffloadableDiTMixin):\n    _fsdp_shard_conditions = MOVAAudioConfig()._fsdp_shard_conditions\n    _compile_conditions = MOVAAudioConfig()._compile_conditions\n    _supported_attention_backends = MOVAAudioConfig()._supported_attention_backends\n    param_names_mapping = MOVAAudioConfig().param_names_mapping\n    reverse_param_names_mapping = MOVAAudioConfig().reverse_param_names_mapping\n    lora_param_names_mapping = MOVAAudioConfig().lora_param_names_mapping\n\n    def __init__(\n        self,\n        config: MOVAAudioConfig,\n        hf_config: dict[str, Any],\n        quant_config: QuantizationConfig | None = None,\n    ) -> None:\n        super().__init__(config=config, hf_config=hf_config)\n\n        # Extract parameters from config\n        dim = config.dim\n        in_dim = config.in_dim\n        ffn_dim = config.ffn_dim\n        out_dim = config.out_dim\n        text_dim = config.text_dim\n        freq_dim = config.freq_dim\n        eps = config.eps\n        patch_size = config.patch_size\n        num_heads = config.num_heads\n        num_layers = config.num_layers\n        has_image_pos_emb = config.has_image_pos_emb\n        has_ref_conv = config.has_ref_conv\n        separated_timestep = config.separated_timestep\n        require_vae_embedding = config.require_vae_embedding\n        require_clip_embedding = config.require_clip_embedding\n        fuse_vae_embedding_in_latents = config.fuse_vae_embedding_in_latents\n        vae_type = config.vae_type\n\n        self.dim = dim\n        self.freq_dim = freq_dim\n        self.patch_size = patch_size\n        self.separated_timestep = separated_timestep\n        self.require_vae_embedding = require_vae_embedding\n        self.require_clip_embedding = require_clip_embedding\n        
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents\n        self.vae_type = vae_type\n        # self.patch_embedding = nn.Conv3d(\n        #     in_dim, dim, kernel_size=patch_size, stride=patch_size)\n        self.patch_embedding = Conv1dLocalIsland(\n            in_dim, dim, kernel_size=patch_size, stride=patch_size\n        )\n        self.text_embedding = MLP(\n            text_dim,\n            dim,\n            output_dim=dim,\n            act_type=\"gelu_pytorch_tanh\",\n            quant_config=quant_config,\n        )\n        self.time_embedding = MLP(\n            freq_dim, dim, output_dim=dim, act_type=\"silu\", quant_config=quant_config\n        )\n        # Preserve state_dict keys (time_projection.1.weight/bias).\n        self.time_projection = nn.Sequential(\n            nn.SiLU(), ReplicatedLinear(dim, dim * 6, quant_config=quant_config)\n        )\n        self.blocks = nn.ModuleList(\n            [\n                DiTBlock(dim, num_heads, ffn_dim, eps, quant_config=quant_config)\n                for _ in range(num_layers)\n            ]\n        )\n        self.head = Head(dim, out_dim, patch_size, eps)\n        self.num_heads = num_heads\n        self.freqs = None\n        self.img_pos_emb = None\n        if has_ref_conv:\n            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))\n        self.has_image_pos_emb = has_image_pos_emb\n        self.has_ref_conv = has_ref_conv\n        self.hidden_size = dim\n        self.num_attention_heads = num_heads\n        self.num_channels_latents = out_dim\n        self.layer_names = [\"blocks\"]\n        self.cnt = 0\n        self.teacache_thresh = 0\n        self.coefficients = []\n        self.accumulated_rel_l1_distance = 0\n        self.previous_modulated_input = None\n        self.previous_resiual = None\n        self.previous_e0_even = None\n        self.previous_e0_odd = None\n        self.previous_residual_even = None\n        self.previous_residual_odd = 
None\n        self.is_even = False\n        self.should_calc_even = True\n        self.should_calc_odd = True\n        self.accumulated_rel_l1_distance_even = 0\n        self.accumulated_rel_l1_distance_odd = 0\n        self.__post_init__()\n\n    def _init_freqs(self):\n        if self.freqs is not None:\n            return\n        head_dim = self.dim // self.num_heads\n        if self.vae_type == \"dac\":\n            self.freqs = precompute_freqs_cis_1d(head_dim)\n        else:\n            raise ValueError(f\"Invalid VAE type: {self.vae_type}\")\n\n    def patchify(\n        self,\n        x: torch.Tensor,\n        control_camera_latents_input: Optional[torch.Tensor] = None,\n    ):\n        x = self.patch_embedding(x)\n        grid_size = x.shape[2:]\n        x = rearrange(x, \"b c f -> b f c\").contiguous()\n        return x, grid_size  # x, grid_size: (f)\n\n    def unpatchify(self, x: torch.Tensor, grid_size: tuple[int]):\n        return rearrange(\n            x, \"b f (p c) -> b c (f p)\", f=grid_size[0], p=self.patch_size[0]\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor | list[torch.Tensor],\n        timestep: torch.LongTensor,\n    ) -> torch.Tensor:\n        # MOVA audio uses x/context naming historically.\n        x = hidden_states\n        context = (\n            encoder_hidden_states[0]\n            if isinstance(encoder_hidden_states, list)\n            else encoder_hidden_states\n        )\n\n        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))\n        t_proj, _ = self.time_projection(t)\n        t_mod = t_proj.unflatten(1, (6, self.dim))\n        context = self.text_embedding(context)\n\n        x, (f,) = self.patchify(x)\n\n        freqs = (\n            torch.cat(\n                [\n                    self.freqs[0][:f].view(f, -1).expand(f, -1),\n                    self.freqs[1][:f].view(f, -1).expand(f, -1),\n                   
 self.freqs[2][:f].view(f, -1).expand(f, -1),\n                ],\n                dim=-1,\n            )\n            .reshape(f, 1, -1)\n            .to(x.device)\n        )\n\n        for block in self.blocks:\n            x = block(x, context, t_mod, freqs)\n\n        x = self.head(x, t)\n        x = self.unpatchify(x, (f,))\n        return x\n\n\nEntryClass = WanAudioModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py",
    "content": "# Copied and adapted from: mossVG/mova/diffusion/models/wan_video_dit.py\n# SPDX-License-Identifier: Apache-2.0\n#\n# NOTE: This module shares common functions (sinusoidal_embedding_1d, precompute_freqs_cis, etc.)\n# with wanvideo.py. These functions are kept here for MOVA-specific model architecture,\n# but could be refactored to a common module in the future.\n\nimport math\nfrom typing import Any, Tuple\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom torch.distributed.tensor import DTensor\n\nfrom sglang.multimodal_gen.configs.models.dits.mova_video import MOVAVideoConfig\nfrom sglang.multimodal_gen.runtime.distributed import get_tp_world_size\nfrom sglang.multimodal_gen.runtime.layers.attention import LocalAttention, USPAttention\n\n# Reuse SGLang's optimized RMSNorm instead of torch.nn.RMSNorm or custom SlowRMSNorm\nfrom sglang.multimodal_gen.runtime.layers.layernorm import (\n    LayerNormScaleShift,\n    RMSNorm,\n    ScaleResidualLayerNormScaleShift,\n    tensor_parallel_rms_norm,\n)\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    ColumnParallelLinear,\n    ReplicatedLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.mlp import MLP\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n)\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n# @torch.compile(fullgraph=True)\ndef modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):\n    return x * (1 + scale) + shift\n\n\ndef sinusoidal_embedding_1d(dim, position):\n    sinusoid = torch.outer(\n        position.type(torch.float64),\n        torch.pow(\n            10000,\n            -torch.arange(dim // 2, dtype=torch.float64, 
device=position.device).div(\n                dim // 2\n            ),\n        ),\n    )\n    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)\n    return x.to(position.dtype)\n\n\ndef precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):\n    # 3d rope precompute\n    f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)\n    h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)\n    w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)\n    return f_freqs_cis, h_freqs_cis, w_freqs_cis\n\n\ndef precompute_freqs_cis(\n    dim: int, end: int = 1024, theta: float = 10000.0, s: float = 1.0\n):\n    # 1d rope precompute\n    # Note: s parameter is used for audio-specific scaling (e.g., tps adjustment)\n    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim))\n    pos = torch.arange(end, dtype=torch.float64, device=freqs.device) * s\n    freqs = torch.outer(pos, freqs)\n    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex128 (float64 inputs)\n    return freqs_cis\n\n\ndef rope_apply(x, freqs, num_heads):\n    x = rearrange(x, \"b s (n d) -> b s n d\", n=num_heads)\n    x_out = torch.view_as_complex(\n        x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)\n    )\n    x_out = torch.view_as_real(x_out * freqs).flatten(2)\n    return x_out.to(x.dtype)\n\n\ndef rope_apply_head_dim(x, freqs, head_dim):\n    x = rearrange(x, \"b s (n d) -> b s n d\", d=head_dim)\n    x_out = torch.view_as_complex(\n        x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)\n    )\n    x_out = torch.view_as_real(x_out * freqs).flatten(2)\n    return x_out.to(x.dtype)\n\n\nclass SelfAttention(nn.Module):\n    \"\"\"\n    Self-Attention module for MOVA DiT with Sequence Parallelism support.\n\n    SP is handled at the pipeline level (latents are pre-sharded before DiT forward).\n    USPAttention internally handles 
the all-to-all communication for distributed attention.\n    Input x should already be the local shard [B, S_local, D] when SP is enabled.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        eps: float = 1e-6,\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n\n        self.tp_size = get_tp_world_size()\n        if self.num_heads % self.tp_size != 0:\n            raise ValueError(\n                f\"num_heads ({self.num_heads}) must be divisible by tp_size ({self.tp_size}).\"\n            )\n        self.num_heads_per_rank = self.num_heads // self.tp_size\n\n        # TP strategy: shard Q/K/V over heads (column-parallel), then row-parallel output.\n        self.q = ColumnParallelLinear(\n            dim, dim, bias=True, gather_output=False, quant_config=quant_config\n        )\n        self.k = ColumnParallelLinear(\n            dim, dim, bias=True, gather_output=False, quant_config=quant_config\n        )\n        self.v = ColumnParallelLinear(\n            dim, dim, bias=True, gather_output=False, quant_config=quant_config\n        )\n        self.o = RowParallelLinear(\n            dim, dim, bias=True, input_is_parallel=True, quant_config=quant_config\n        )\n        self.norm_q = RMSNorm(dim, eps=eps)\n        self.norm_k = RMSNorm(dim, eps=eps)\n\n        self.attn = USPAttention(\n            # Local heads per TP rank.\n            num_heads=self.num_heads_per_rank,\n            head_size=self.head_dim,\n            causal=False,\n            softmax_scale=None,\n        )\n\n    def forward(self, x, freqs):\n        \"\"\"\n        Forward pass for self-attention.\n\n        Args:\n            x: Input tensor [B, S_local, D] - already sharded by SP when SP > 1\n            freqs: RoPE frequencies [S_local, 1, head_dim] - should match x's sequence length\n\n      
  Returns:\n            Output tensor [B, S_local, D]\n        \"\"\"\n        if isinstance(freqs, DTensor):\n            freqs = freqs.to_local()\n\n        # Compute Q, K, V on local sequence\n        q, _ = self.q(x)\n        k, _ = self.k(x)\n        v, _ = self.v(x)\n\n        # RMSNorm over sharded hidden dimension.\n        if self.tp_size > 1:\n            q = tensor_parallel_rms_norm(q, self.norm_q)\n            k = tensor_parallel_rms_norm(k, self.norm_k)\n        else:\n            q = self.norm_q(q)\n            k = self.norm_k(k)\n\n        # Apply RoPE\n        q = rope_apply_head_dim(q, freqs, self.head_dim)\n        k = rope_apply_head_dim(k, freqs, self.head_dim)\n\n        # USPAttention expects [B, S_local, H, D] format\n        q = rearrange(q, \"b s (n d) -> b s n d\", n=self.num_heads_per_rank)\n        k = rearrange(k, \"b s (n d) -> b s n d\", n=self.num_heads_per_rank)\n        v = rearrange(v, \"b s (n d) -> b s n d\", n=self.num_heads_per_rank)\n\n        # USPAttention handles SP communication internally\n        out = self.attn(q, k, v)\n        out = rearrange(out, \"b s n d -> b s (n d)\")\n\n        out, _ = self.o(out)\n        return out\n\n\nclass CrossAttention(nn.Module):\n    \"\"\"\n    Cross-Attention module for MOVA DiT.\n\n    Cross-attention does NOT require SP communication because:\n    - Query comes from the main sequence (already sharded by SP)\n    - Key/Value come from context (text embeddings, which are replicated across all ranks)\n\n    Uses LocalAttention instead of USPAttention for efficiency.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        eps: float = 1e-6,\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n\n        self.tp_size = get_tp_world_size()\n        if self.num_heads % self.tp_size != 0:\n            
raise ValueError(\n                f\"num_heads ({self.num_heads}) must be divisible by tp_size ({self.tp_size}).\"\n            )\n        self.num_heads_per_rank = self.num_heads // self.tp_size\n\n        self.q = ColumnParallelLinear(\n            dim, dim, bias=True, gather_output=False, quant_config=quant_config\n        )\n        self.k = ColumnParallelLinear(\n            dim, dim, bias=True, gather_output=False, quant_config=quant_config\n        )\n        self.v = ColumnParallelLinear(\n            dim, dim, bias=True, gather_output=False, quant_config=quant_config\n        )\n        self.o = RowParallelLinear(\n            dim, dim, bias=True, input_is_parallel=True, quant_config=quant_config\n        )\n        self.norm_q = RMSNorm(dim, eps=eps)\n        self.norm_k = RMSNorm(dim, eps=eps)\n\n        # Use LocalAttention for cross-attention (no SP communication needed)\n        self.attn = LocalAttention(\n            num_heads=self.num_heads_per_rank,\n            head_size=self.head_dim,\n            causal=False,\n            softmax_scale=None,\n        )\n\n    def forward(self, x: torch.Tensor, y: torch.Tensor):\n        \"\"\"\n        Forward pass for cross-attention.\n\n        Args:\n            x: Query tensor [B, S_local, D] - the main sequence (sharded by SP)\n            y: Context tensor [B, S_ctx, D] - text/image embeddings (replicated)\n\n        Returns:\n            Output tensor [B, S_local, D]\n        \"\"\"\n        ctx = y\n\n        q, _ = self.q(x)\n        k, _ = self.k(ctx)\n        v, _ = self.v(ctx)\n\n        if self.tp_size > 1:\n            q = tensor_parallel_rms_norm(q, self.norm_q)\n            k = tensor_parallel_rms_norm(k, self.norm_k)\n        else:\n            q = self.norm_q(q)\n            k = self.norm_k(k)\n\n        q = rearrange(q, \"b s (n d) -> b s n d\", n=self.num_heads_per_rank)\n        k = rearrange(k, \"b s (n d) -> b s n d\", n=self.num_heads_per_rank)\n        v = rearrange(v, \"b s (n d) -> 
b s n d\", n=self.num_heads_per_rank)\n        x = self.attn(q, k, v)\n        x = rearrange(x, \"b s n d -> b s (n d)\")\n        x, _ = self.o(x)\n        return x\n\n\nclass MulAdd(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x, gate, residual):\n        return residual + gate * x\n\n\nclass DiTBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        ffn_dim: int,\n        eps: float = 1e-6,\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.ffn_dim = ffn_dim\n\n        self.self_attn = SelfAttention(dim, num_heads, eps, quant_config=quant_config)\n        self.cross_attn = CrossAttention(dim, num_heads, eps, quant_config=quant_config)\n        self.norm1 = LayerNormScaleShift(\n            dim, eps=eps, elementwise_affine=False, dtype=torch.float32\n        )\n        self.self_attn_norm = nn.LayerNorm(dim, eps=eps)\n        # Fused: residual + 1 * cross_attn_out → layernorm + scale/shift\n        # Replaces the old norm2 (LayerNormScaleShift) + residual add for cross-attention\n        self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift(\n            dim, eps=eps, elementwise_affine=False, dtype=torch.float32\n        )\n        self.ffn = MLP(\n            dim,\n            ffn_dim,\n            output_dim=dim,\n            act_type=\"gelu_pytorch_tanh\",\n            quant_config=quant_config,\n        )\n        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)\n        self.mlp_residual = MulAdd()\n\n    def forward(self, x, context, t_mod, freqs):\n        has_seq = len(t_mod.shape) == 4\n        chunk_dim = 2 if has_seq else 1\n        # msa: multi-head self-attention  mlp: multi-layer perceptron\n        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (\n            self.modulation.to(dtype=t_mod.dtype, 
device=t_mod.device) + t_mod\n        ).chunk(6, dim=chunk_dim)\n        if has_seq:\n            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (\n                shift_msa.squeeze(2),\n                scale_msa.squeeze(2),\n                gate_msa.squeeze(2),\n                shift_mlp.squeeze(2),\n                scale_mlp.squeeze(2),\n                gate_mlp.squeeze(2),\n            )\n        orig_dtype = x.dtype\n        # 1. Self-attention, fuse:\n        # - layernorm(x) * (1 + scale_msa) + shift_msa\n        input_x = self.norm1(x, shift_msa, scale_msa)\n        # 2. torch.compile may fuse mlp_residual and self_attn_norm\n        x = self.mlp_residual(self.self_attn(input_x, freqs), gate_msa, x)\n        norm_x = self.self_attn_norm(x)\n        # 3. Cross-attention, fuse:\n        # - x = x + 1 * cross_output\n        # - input_x = layernorm(x) * (1 + scale_mlp) + shift_mlp\n        cross_output = self.cross_attn(norm_x, context)\n        input_x, x = self.cross_attn_residual_norm(\n            x, cross_output, 1, shift_mlp, scale_mlp\n        )\n        # 4. 
Feed-forward\n        x = self.mlp_residual(self.ffn(input_x), gate_mlp, x)\n        x = x.to(orig_dtype)\n        return x\n\n\nclass Head(nn.Module):\n    def __init__(\n        self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float\n    ):\n        super().__init__()\n        self.dim = dim\n        self.patch_size = patch_size\n        self.norm = LayerNormScaleShift(\n            dim, eps=eps, elementwise_affine=False, dtype=torch.float32\n        )\n        # Output dim is small for MOVA; replicate to avoid TP shape coupling.\n        self.head = ReplicatedLinear(dim, out_dim * math.prod(patch_size))\n        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)\n\n    def forward(self, x, t_mod):\n        if len(t_mod.shape) == 3:\n            shift, scale = (\n                self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device)\n                + t_mod.unsqueeze(2)\n            ).chunk(2, dim=2)\n            x, _ = self.head(self.norm(x, shift.squeeze(2), scale.squeeze(2)))\n        else:\n            shift, scale = (\n                self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod\n            ).chunk(2, dim=1)\n            x, _ = self.head(self.norm(x, shift, scale))\n        return x\n\n\nclass Conv3dLocalIsland(nn.Conv3d):\n    \"\"\"\n    Inherits from Conv3d and overrides the forward method.\n\n    Key behaviors:\n    - Parameters are kept as DTensor to maintain optimizer consistency.\n    - When the input is a DTensor, the forward pass converts input, weight, and\n      bias to local tensors with to_local() and performs the convolution locally.\n    - The result is returned as a regular (local) tensor; non-DTensor inputs use\n      the standard Conv3d forward.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def forward(self, input):\n        if isinstance(input, DTensor):\n            # NOTE: DTensor typing stubs are incomplete; at runtime 
DTensor has\n            # to_local() and parameters may also be DTensor.\n            x_local = input.to_local()  # type: ignore[attr-defined]\n            w_local = self.weight.to_local()  # type: ignore[attr-defined]\n            b_local = (\n                self.bias.to_local() if self.bias is not None else None  # type: ignore[attr-defined]\n            )\n\n            return self._conv_forward(x_local, w_local, b_local)\n        else:\n            return super().forward(input)\n\n\nclass WanModel(CachableDiT, OffloadableDiTMixin):\n    _fsdp_shard_conditions = MOVAVideoConfig()._fsdp_shard_conditions\n    _compile_conditions = MOVAVideoConfig()._compile_conditions\n    _supported_attention_backends = MOVAVideoConfig()._supported_attention_backends\n    param_names_mapping = MOVAVideoConfig().param_names_mapping\n    reverse_param_names_mapping = MOVAVideoConfig().reverse_param_names_mapping\n    lora_param_names_mapping = MOVAVideoConfig().lora_param_names_mapping\n\n    def __init__(\n        self,\n        config: MOVAVideoConfig,\n        hf_config: dict[str, Any],\n        quant_config: QuantizationConfig | None = None,\n    ) -> None:\n        super().__init__(config=config, hf_config=hf_config)\n\n        # Extract parameters from config\n        dim = config.dim\n        in_dim = config.in_dim\n        ffn_dim = config.ffn_dim\n        out_dim = config.out_dim\n        text_dim = config.text_dim\n        freq_dim = config.freq_dim\n        eps = config.eps\n        patch_size = config.patch_size\n        num_heads = config.num_heads\n        num_layers = config.num_layers\n        has_image_pos_emb = config.has_image_pos_emb\n        has_ref_conv = config.has_ref_conv\n        separated_timestep = config.separated_timestep\n        require_vae_embedding = config.require_vae_embedding\n        require_clip_embedding = config.require_clip_embedding\n        fuse_vae_embedding_in_latents = config.fuse_vae_embedding_in_latents\n\n        self.dim = dim\n  
      self.freq_dim = freq_dim\n        self.patch_size = patch_size\n        self.separated_timestep = separated_timestep\n        self.require_vae_embedding = require_vae_embedding\n        self.require_clip_embedding = require_clip_embedding\n        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents\n\n        self.patch_embedding = Conv3dLocalIsland(\n            in_dim, dim, kernel_size=patch_size, stride=patch_size\n        )\n        self.text_embedding = MLP(\n            text_dim,\n            dim,\n            output_dim=dim,\n            act_type=\"gelu_pytorch_tanh\",\n            quant_config=quant_config,\n        )\n        self.time_embedding = MLP(\n            freq_dim, dim, output_dim=dim, act_type=\"silu\", quant_config=quant_config\n        )\n        # Preserve state_dict keys (time_projection.1.weight/bias).\n        self.time_projection = nn.Sequential(\n            nn.SiLU(), ReplicatedLinear(dim, dim * 6, quant_config=quant_config)\n        )\n        self.blocks = nn.ModuleList(\n            [\n                DiTBlock(dim, num_heads, ffn_dim, eps, quant_config=quant_config)\n                for _ in range(num_layers)\n            ]\n        )\n        self.head = Head(dim, out_dim, patch_size, eps)\n        self.num_heads = num_heads\n        self.freqs = None\n\n        if has_ref_conv:\n            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))\n        self.has_image_pos_emb = has_image_pos_emb\n        self.has_ref_conv = has_ref_conv\n        self.hidden_size = dim\n        self.num_attention_heads = num_heads\n        self.num_channels_latents = out_dim\n        self.layer_names = [\"blocks\"]\n        self.cnt = 0\n        self.teacache_thresh = 0\n        self.coefficients = []\n        self.accumulated_rel_l1_distance = 0\n        self.previous_modulated_input = None\n        self.previous_resiual = None\n        self.previous_e0_even = None\n        self.previous_e0_odd = None\n        
self.previous_residual_even = None\n        self.previous_residual_odd = None\n        self.is_even = False\n        self.should_calc_even = True\n        self.should_calc_odd = True\n        self.accumulated_rel_l1_distance_even = 0\n        self.accumulated_rel_l1_distance_odd = 0\n        self.__post_init__()\n\n    def _init_freqs(self):\n        if self.freqs is not None:\n            return\n        head_dim = self.dim // self.num_heads\n        self.freqs = precompute_freqs_cis_3d(head_dim)\n\n    def patchify(\n        self, x: torch.Tensor, control_camera_latents_input: torch.Tensor | None = None\n    ):\n        # NOTE(dhyu): avoid slow_conv\n        x = x.contiguous(memory_format=torch.channels_last_3d)\n        x = self.patch_embedding(x)\n        grid_size = x.shape[2:]\n        x = rearrange(x, \"b c f h w -> b (f h w) c\").contiguous()\n        return x, grid_size  # x, grid_size: (f, h, w)\n\n    def unpatchify(self, x: torch.Tensor, grid_size: tuple[int, int, int]):\n        return rearrange(\n            x,\n            \"b (f h w) (x y z c) -> b c (f x) (h y) (w z)\",\n            f=grid_size[0],\n            h=grid_size[1],\n            w=grid_size[2],\n            x=self.patch_size[0],\n            y=self.patch_size[1],\n            z=self.patch_size[2],\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor | list[torch.Tensor],\n        timestep: torch.LongTensor,\n    ) -> torch.Tensor:\n        # MOVA code historically uses x/context/y/clip_feature naming.\n        x = hidden_states\n        context = (\n            encoder_hidden_states[0]\n            if isinstance(encoder_hidden_states, list)\n            else encoder_hidden_states\n        )\n        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))\n        t_proj, _ = self.time_projection(t)\n        t_mod = t_proj.unflatten(1, (6, self.dim))\n        context = 
self.text_embedding(context)\n\n        x, (f, h, w) = self.patchify(x)\n\n        freqs = (\n            torch.cat(\n                [\n                    self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),\n                    self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),\n                    self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),\n                ],\n                dim=-1,\n            )\n            .reshape(f * h * w, 1, -1)\n            .to(x.device)\n        )\n\n        for block in self.blocks:\n            x = block(x, context, t_mod, freqs)\n\n        x = self.head(x, t)\n        x = self.unpatchify(x, (f, h, w))\n        return x\n\n\nEntryClass = WanModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nimport functools\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport diffusers\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom diffusers.models.attention import FeedForward\nfrom diffusers.models.embeddings import TimestepEmbedding, Timesteps\nfrom diffusers.models.modeling_outputs import Transformer2DModelOutput\nfrom diffusers.models.normalization import AdaLayerNormContinuous\n\nfrom sglang.jit_kernel.diffusion.triton.scale_shift import (\n    fuse_layernorm_scale_shift_gate_select01_kernel,\n    fuse_residual_layernorm_scale_shift_gate_select01_kernel,\n)\nfrom sglang.multimodal_gen.configs.models.dits.qwenimage import QwenImageDitConfig\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.layers.attention import USPAttention\nfrom sglang.multimodal_gen.runtime.layers.elementwise import MulAdd\nfrom sglang.multimodal_gen.runtime.layers.layernorm import (\n    LayerNormScaleShift,\n    RMSNorm,\n    ScaleResidualLayerNormScaleShift,\n    apply_qk_norm,\n)\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    MergedColumnParallelLinear,\n    ReplicatedLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import (\n    NunchakuConfig,\n    is_nunchaku_available,\n)\nfrom sglang.multimodal_gen.runtime.layers.rotary_embedding import (\n    apply_flashinfer_rope_qk_inplace,\n)\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils 
import init_logger\n\nlogger = init_logger(__name__)  # pylint: disable=invalid-name\n\ntry:\n    from nunchaku.models.attention import NunchakuFeedForward  # type: ignore[import]\nexcept Exception:\n    NunchakuFeedForward = None\n\n\ndef _get_qkv_projections(\n    attn: \"QwenImageCrossAttention\", hidden_states, encoder_hidden_states=None\n):\n    if attn.use_fused_qkv:\n        img_qkv, _ = attn.to_qkv(hidden_states)\n        img_query, img_key, img_value = [\n            x.contiguous() for x in img_qkv.chunk(3, dim=-1)\n        ]\n    else:\n        img_query, _ = attn.to_q(hidden_states)\n        img_key, _ = attn.to_k(hidden_states)\n        img_value, _ = attn.to_v(hidden_states)\n\n    txt_query = txt_key = txt_value = None\n    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:\n        if attn.use_fused_added_qkv:\n            txt_qkv, _ = attn.to_added_qkv(encoder_hidden_states)\n            txt_query, txt_key, txt_value = [\n                x.contiguous() for x in txt_qkv.chunk(3, dim=-1)\n            ]\n        else:\n            txt_query, _ = attn.add_q_proj(encoder_hidden_states)\n            txt_key, _ = attn.add_k_proj(encoder_hidden_states)\n            txt_value, _ = attn.add_v_proj(encoder_hidden_states)\n\n    return img_query, img_key, img_value, txt_query, txt_key, txt_value\n\n\nclass QwenTimestepProjEmbeddings(nn.Module):\n    def __init__(self, embedding_dim, use_additional_t_cond=False):\n        super().__init__()\n\n        self.time_proj = Timesteps(\n            num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000\n        )\n        self.timestep_embedder = TimestepEmbedding(\n            in_channels=256, time_embed_dim=embedding_dim\n        )\n        self.use_additional_t_cond = use_additional_t_cond\n        if use_additional_t_cond:\n            self.addition_t_embedding = nn.Embedding(2, embedding_dim)\n\n    def forward(self, timestep, hidden_states, addition_t_cond=None):\n    
    timesteps_proj = self.time_proj(timestep)\n        timesteps_emb = self.timestep_embedder(\n            timesteps_proj.to(dtype=hidden_states.dtype)\n        )  # (N, D)\n\n        conditioning = timesteps_emb\n        if self.use_additional_t_cond:\n            if addition_t_cond is None:\n                raise ValueError(\n                    \"When additional_t_cond is True, addition_t_cond must be provided.\"\n                )\n            addition_t_emb = self.addition_t_embedding(addition_t_cond)\n            addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype)\n            conditioning = conditioning + addition_t_emb\n\n        return conditioning\n\n\nclass QwenEmbedRope(nn.Module):\n    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):\n        super().__init__()\n        self.theta = theta\n        self.axes_dim = axes_dim\n        pos_index = torch.arange(4096)\n        neg_index = torch.arange(4096).flip(0) * -1 - 1\n        self.pos_freqs = torch.cat(\n            [\n                self.rope_params(pos_index, self.axes_dim[0], self.theta),\n                self.rope_params(pos_index, self.axes_dim[1], self.theta),\n                self.rope_params(pos_index, self.axes_dim[2], self.theta),\n            ],\n            dim=1,\n        )\n        self.neg_freqs = torch.cat(\n            [\n                self.rope_params(neg_index, self.axes_dim[0], self.theta),\n                self.rope_params(neg_index, self.axes_dim[1], self.theta),\n                self.rope_params(neg_index, self.axes_dim[2], self.theta),\n            ],\n            dim=1,\n        )\n\n        # DO NOT USE REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS TO LOSE THEIR IMAGINARY PART\n        self.scale_rope = scale_rope\n\n    def rope_params(self, index, dim, theta=10000):\n        \"\"\"\n        Args:\n            index: [0, 1, 2, 3] 1D Tensor representing the position index of the token\n        \"\"\"\n        device = index.device\n       
 assert dim % 2 == 0\n        freqs = torch.outer(\n            index,\n            (\n                1.0\n                / torch.pow(\n                    theta,\n                    torch.arange(0, dim, 2, device=device).to(torch.float32).div(dim),\n                )\n            ).to(device=device),\n        )\n        freqs = torch.polar(torch.ones_like(freqs), freqs)\n        return freqs\n\n    def forward(\n        self,\n        video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],\n        txt_seq_lens: List[int],\n        device: torch.device,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Args:\n            video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):\n                A list of 3 integers [frame, height, width] representing the shape of the video.\n            txt_seq_lens (`List[int]`):\n                A list of integers of length batch_size representing the length of each text prompt.\n            device: (`torch.device`):\n                The device on which to perform the RoPE computation.\n        \"\"\"\n        # When models are initialized under a \"meta\" device context (e.g. init_empty_weights),\n        # tensors created during __init__ become meta tensors. Calling .to(...) on a meta tensor\n        # raises \"Cannot copy out of meta tensor\". 
Rebuild the frequencies on the target device\n        # in that case; otherwise move them if just on a different device.\n        if getattr(self.pos_freqs, \"device\", torch.device(\"meta\")).type == \"meta\":\n            pos_index = torch.arange(4096, device=device)\n            neg_index = torch.arange(4096, device=device).flip(0) * -1 - 1\n            self.pos_freqs = torch.cat(\n                [\n                    self.rope_params(pos_index, self.axes_dim[0], self.theta),\n                    self.rope_params(pos_index, self.axes_dim[1], self.theta),\n                    self.rope_params(pos_index, self.axes_dim[2], self.theta),\n                ],\n                dim=1,\n            ).to(device=device)\n            self.neg_freqs = torch.cat(\n                [\n                    self.rope_params(neg_index, self.axes_dim[0], self.theta),\n                    self.rope_params(neg_index, self.axes_dim[1], self.theta),\n                    self.rope_params(neg_index, self.axes_dim[2], self.theta),\n                ],\n                dim=1,\n            ).to(device=device)\n        elif self.pos_freqs.device != device:\n            self.pos_freqs = self.pos_freqs.to(device)\n            self.neg_freqs = self.neg_freqs.to(device)\n\n        if isinstance(video_fhw, list):\n            video_fhw = video_fhw[0]\n        if not isinstance(video_fhw, list):\n            video_fhw = [video_fhw]\n\n        vid_freqs = []\n        max_vid_index = 0\n        for idx, fhw in enumerate(video_fhw):\n            frame, height, width = fhw\n            # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs\n            video_freq = self._compute_video_freqs(frame, height, width, idx)\n            video_freq = video_freq.to(device)\n            vid_freqs.append(video_freq)\n\n            if self.scale_rope:\n                max_vid_index = max(height // 2, width // 2, max_vid_index)\n            else:\n                max_vid_index = max(height, 
width, max_vid_index)\n\n        max_len = max(txt_seq_lens)\n        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]\n        vid_freqs = torch.cat(vid_freqs, dim=0).to(device=device)\n        return vid_freqs, txt_freqs\n\n    @functools.lru_cache(maxsize=128)\n    def _compute_video_freqs(\n        self, frame: int, height: int, width: int, idx: int = 0\n    ) -> torch.Tensor:\n        seq_lens = frame * height * width\n        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n\n        freqs_frame = (\n            freqs_pos[0][idx : idx + frame]\n            .view(frame, 1, 1, -1)\n            .expand(frame, height, width, -1)\n        )\n        if self.scale_rope:\n            freqs_height = torch.cat(\n                [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]],\n                dim=0,\n            )\n            freqs_height = freqs_height.view(1, height, 1, -1).expand(\n                frame, height, width, -1\n            )\n            freqs_width = torch.cat(\n                [freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]],\n                dim=0,\n            )\n            freqs_width = freqs_width.view(1, 1, width, -1).expand(\n                frame, height, width, -1\n            )\n        else:\n            freqs_height = (\n                freqs_pos[1][:height]\n                .view(1, height, 1, -1)\n                .expand(frame, height, width, -1)\n            )\n            freqs_width = (\n                freqs_pos[2][:width]\n                .view(1, 1, width, -1)\n                .expand(frame, height, width, -1)\n            )\n\n        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(\n            seq_lens, -1\n        )\n        return freqs.clone().contiguous()\n\n\nclass QwenEmbedLayer3DRope(nn.Module):\n    def __init__(self, 
theta: int, axes_dim: List[int], scale_rope=False):\n        super().__init__()\n        self.theta = theta\n        self.axes_dim = axes_dim\n        pos_index = torch.arange(4096)\n        neg_index = torch.arange(4096).flip(0) * -1 - 1\n        self.pos_freqs = torch.cat(\n            [\n                self.rope_params(pos_index, self.axes_dim[0], self.theta),\n                self.rope_params(pos_index, self.axes_dim[1], self.theta),\n                self.rope_params(pos_index, self.axes_dim[2], self.theta),\n            ],\n            dim=1,\n        )\n        self.neg_freqs = torch.cat(\n            [\n                self.rope_params(neg_index, self.axes_dim[0], self.theta),\n                self.rope_params(neg_index, self.axes_dim[1], self.theta),\n                self.rope_params(neg_index, self.axes_dim[2], self.theta),\n            ],\n            dim=1,\n        )\n\n        self.scale_rope = scale_rope\n\n    def rope_params(self, index, dim, theta=10000):\n        \"\"\"\n        Args:\n            index: [0, 1, 2, 3] 1D Tensor representing the position index of the token\n        \"\"\"\n        device = index.device\n        assert dim % 2 == 0\n        freqs = torch.outer(\n            index,\n            (\n                1.0\n                / torch.pow(\n                    theta,\n                    torch.arange(0, dim, 2, device=device).to(torch.float32).div(dim),\n                )\n            ).to(device=device),\n        )\n        freqs = torch.polar(torch.ones_like(freqs), freqs)\n        return freqs\n\n    def forward(self, video_fhw, txt_seq_lens, device):\n        \"\"\"\n        Args:\n            video_fhw: [frame, height, width], a list of 3 integers representing the shape of the video.\n            txt_seq_lens: [bs], a list of integers representing the length of each text prompt.\n        \"\"\"\n\n        # When models are initialized under a \"meta\" device context (e.g. 
init_empty_weights),\n        # tensors created during __init__ become meta tensors. Calling .to(...) on a meta tensor\n        # raises \"Cannot copy out of meta tensor\". Rebuild the frequencies on the target device\n        # in that case; otherwise move them if just on a different device.\n        if getattr(self.pos_freqs, \"device\", torch.device(\"meta\")).type == \"meta\":\n            pos_index = torch.arange(4096, device=device)\n            neg_index = torch.arange(4096, device=device).flip(0) * -1 - 1\n            self.pos_freqs = torch.cat(\n                [\n                    self.rope_params(pos_index, self.axes_dim[0], self.theta),\n                    self.rope_params(pos_index, self.axes_dim[1], self.theta),\n                    self.rope_params(pos_index, self.axes_dim[2], self.theta),\n                ],\n                dim=1,\n            ).to(device=device)\n            self.neg_freqs = torch.cat(\n                [\n                    self.rope_params(neg_index, self.axes_dim[0], self.theta),\n                    self.rope_params(neg_index, self.axes_dim[1], self.theta),\n                    self.rope_params(neg_index, self.axes_dim[2], self.theta),\n                ],\n                dim=1,\n            ).to(device=device)\n        elif self.pos_freqs.device != device:\n            self.pos_freqs = self.pos_freqs.to(device)\n            self.neg_freqs = self.neg_freqs.to(device)\n\n        if isinstance(video_fhw, list):\n            video_fhw = video_fhw[0]\n        if not isinstance(video_fhw, list):\n            video_fhw = [video_fhw]\n\n        vid_freqs = []\n        max_vid_index = 0\n        layer_num = len(video_fhw) - 1\n        for idx, fhw in enumerate(video_fhw):\n            frame, height, width = fhw\n            if idx != layer_num:\n                video_freq = self._compute_video_freqs(frame, height, width, idx)\n            else:\n                # For the condition image, we set the layer index to -1\n               
 video_freq = self._compute_condition_freqs(frame, height, width)\n            video_freq = video_freq.to(device)\n            vid_freqs.append(video_freq)\n\n            if self.scale_rope:\n                max_vid_index = max(height // 2, width // 2, max_vid_index)\n            else:\n                max_vid_index = max(height, width, max_vid_index)\n\n        max_vid_index = max(max_vid_index, layer_num)\n        max_len = max(txt_seq_lens)\n        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]\n        vid_freqs = torch.cat(vid_freqs, dim=0)\n\n        return vid_freqs, txt_freqs\n\n    @functools.lru_cache(maxsize=None)\n    def _compute_video_freqs(self, frame, height, width, idx=0):\n        seq_lens = frame * height * width\n        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n\n        freqs_frame = (\n            freqs_pos[0][idx : idx + frame]\n            .view(frame, 1, 1, -1)\n            .expand(frame, height, width, -1)\n        )\n        if self.scale_rope:\n            freqs_height = torch.cat(\n                [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]],\n                dim=0,\n            )\n            freqs_height = freqs_height.view(1, height, 1, -1).expand(\n                frame, height, width, -1\n            )\n            freqs_width = torch.cat(\n                [freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]],\n                dim=0,\n            )\n            freqs_width = freqs_width.view(1, 1, width, -1).expand(\n                frame, height, width, -1\n            )\n        else:\n            freqs_height = (\n                freqs_pos[1][:height]\n                .view(1, height, 1, -1)\n                .expand(frame, height, width, -1)\n            )\n            freqs_width = (\n                freqs_pos[2][:width]\n                .view(1, 1, 
width, -1)\n                .expand(frame, height, width, -1)\n            )\n\n        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(\n            seq_lens, -1\n        )\n        return freqs.clone().contiguous()\n\n    @functools.lru_cache(maxsize=None)\n    def _compute_condition_freqs(self, frame, height, width):\n        seq_lens = frame * height * width\n        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)\n\n        freqs_frame = (\n            freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)\n        )\n        if self.scale_rope:\n            freqs_height = torch.cat(\n                [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]],\n                dim=0,\n            )\n            freqs_height = freqs_height.view(1, height, 1, -1).expand(\n                frame, height, width, -1\n            )\n            freqs_width = torch.cat(\n                [freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]],\n                dim=0,\n            )\n            freqs_width = freqs_width.view(1, 1, width, -1).expand(\n                frame, height, width, -1\n            )\n        else:\n            freqs_height = (\n                freqs_pos[1][:height]\n                .view(1, height, 1, -1)\n                .expand(frame, height, width, -1)\n            )\n            freqs_width = (\n                freqs_pos[2][:width]\n                .view(1, 1, width, -1)\n                .expand(frame, height, width, -1)\n            )\n\n        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(\n            seq_lens, -1\n        )\n        return freqs.clone().contiguous()\n\n\nclass QwenImageCrossAttention(nn.Module):\n    def __init__(\n        self,\n        dim: int,  # query_dim\n        num_heads: int,\n        head_dim: int,\n        
window_size=(-1, -1),\n        added_kv_proj_dim: int = None,\n        out_bias: bool = True,\n        qk_norm=True,  # rmsnorm\n        eps=1e-6,\n        pre_only=False,\n        context_pre_only: bool = False,\n        parallel_attention=False,\n        out_dim: int = None,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ) -> None:\n        assert dim % num_heads == 0\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.window_size = window_size\n        self.qk_norm = qk_norm\n        self.eps = eps\n        self.parallel_attention = parallel_attention\n        self.added_kv_proj_dim = added_kv_proj_dim\n        self.prefix = prefix\n\n        self.use_fused_qkv = isinstance(quant_config, NunchakuConfig)\n\n        self.inner_dim = out_dim if out_dim is not None else head_dim * num_heads\n        self.inner_kv_dim = self.inner_dim\n\n        if self.use_fused_qkv:\n            # Use fused QKV projection for nunchaku quantization\n            self.to_qkv = MergedColumnParallelLinear(\n                dim,\n                [self.inner_dim] * 3,\n                bias=True,\n                quant_config=quant_config,\n                prefix=f\"{prefix}.to_qkv\",\n            )\n        else:\n            # Use separate Q/K/V projections for non-quantized models\n            self.to_q = ReplicatedLinear(dim, self.inner_dim, bias=True)\n            self.to_k = ReplicatedLinear(dim, self.inner_dim, bias=True)\n            self.to_v = ReplicatedLinear(dim, self.inner_dim, bias=True)\n\n        if self.qk_norm:\n            self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()\n            self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()\n\n        if added_kv_proj_dim is not None:\n            self.use_fused_added_qkv = isinstance(quant_config, NunchakuConfig)\n            if 
self.use_fused_added_qkv:\n                self.to_added_qkv = MergedColumnParallelLinear(\n                    added_kv_proj_dim,\n                    [self.inner_dim] * 3,\n                    bias=True,\n                    quant_config=quant_config,\n                    prefix=f\"{prefix}.to_added_qkv\",\n                )\n            else:\n                # Use separate Q/K/V projections for non-quantized models\n                self.add_q_proj = ReplicatedLinear(\n                    added_kv_proj_dim, self.inner_dim, bias=True\n                )\n                self.add_k_proj = ReplicatedLinear(\n                    added_kv_proj_dim, self.inner_dim, bias=True\n                )\n                self.add_v_proj = ReplicatedLinear(\n                    added_kv_proj_dim, self.inner_dim, bias=True\n                )\n\n        if context_pre_only is not None and not context_pre_only:\n            self.to_add_out = ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias)\n        else:\n            self.to_add_out = None\n\n        if not pre_only:\n            self.to_out = nn.ModuleList([])\n            self.to_out.append(\n                ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias)\n            )\n        else:\n            self.to_out = None\n\n        self.norm_added_q = RMSNorm(head_dim, eps=eps)\n        self.norm_added_k = RMSNorm(head_dim, eps=eps)\n\n        # Scaled dot product attention\n        self.attn = USPAttention(\n            num_heads=num_heads,\n            head_size=self.head_dim,\n            dropout_rate=0,\n            softmax_scale=None,\n            causal=False,\n            supported_attention_backends={\n                AttentionBackendEnum.FA,\n                AttentionBackendEnum.AITER,\n                AttentionBackendEnum.AITER_SAGE,\n                AttentionBackendEnum.TORCH_SDPA,\n                AttentionBackendEnum.SAGE_ATTN,\n                AttentionBackendEnum.SAGE_ATTN_3,\n            },\n        
)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        image_rotary_emb: tuple[torch.Tensor, torch.Tensor],\n        **cross_attention_kwargs,\n    ):\n        seq_len_txt = encoder_hidden_states.shape[1]\n\n        img_query, img_key, img_value, txt_query, txt_key, txt_value = (\n            _get_qkv_projections(self, hidden_states, encoder_hidden_states)\n        )\n\n        # Reshape for multi-head attention\n        img_query = img_query.unflatten(-1, (self.num_heads, -1))\n        img_key = img_key.unflatten(-1, (self.num_heads, -1))\n        img_value = img_value.unflatten(-1, (self.num_heads, -1))\n\n        txt_query = txt_query.unflatten(-1, (self.num_heads, -1))\n        txt_key = txt_key.unflatten(-1, (self.num_heads, -1))\n        txt_value = txt_value.unflatten(-1, (self.num_heads, -1))\n\n        # Apply QK normalization\n        if self.qk_norm:\n            img_query, img_key = apply_qk_norm(\n                q=img_query,\n                k=img_key,\n                q_norm=self.norm_q,\n                k_norm=self.norm_k,\n                head_dim=img_query.shape[-1],\n                allow_inplace=True,\n            )\n            txt_query, txt_key = apply_qk_norm(\n                q=txt_query,\n                k=txt_key,\n                q_norm=self.norm_added_q,\n                k_norm=self.norm_added_k,\n                head_dim=txt_query.shape[-1],\n                allow_inplace=True,\n            )\n\n        # Apply RoPE\n        if image_rotary_emb is not None:\n            if not (\n                isinstance(image_rotary_emb[0], torch.Tensor)\n                and image_rotary_emb[0].dim() == 2\n            ):\n                raise RuntimeError(\"image_rotary_emb must be cos_sin_cache tensors\")\n\n            img_cache, txt_cache = image_rotary_emb\n\n            img_query, img_key = apply_flashinfer_rope_qk_inplace(\n                img_query, img_key, 
img_cache, is_neox=False\n            )\n            txt_query, txt_key = apply_flashinfer_rope_qk_inplace(\n                txt_query, txt_key, txt_cache, is_neox=False\n            )\n\n        # Concatenate for joint attention\n        # Order: [text, image]\n        joint_query = torch.cat([txt_query, img_query], dim=1)\n        joint_key = torch.cat([txt_key, img_key], dim=1)\n        joint_value = torch.cat([txt_value, img_value], dim=1)\n\n        # Compute joint attention\n        joint_hidden_states = self.attn(\n            joint_query,\n            joint_key,\n            joint_value,\n        )\n\n        # Reshape back\n        joint_hidden_states = joint_hidden_states.flatten(2, 3)\n        joint_hidden_states = joint_hidden_states.to(joint_query.dtype)\n\n        # Split attention outputs back\n        txt_attn_output = joint_hidden_states[:, :seq_len_txt, :]  # Text part\n        img_attn_output = joint_hidden_states[:, seq_len_txt:, :]  # Image part\n\n        # Apply output projections\n        img_attn_output, _ = self.to_out[0](img_attn_output)\n        if len(self.to_out) > 1:\n            (img_attn_output,) = self.to_out[1](img_attn_output)  # dropout\n\n        txt_attn_output, _ = self.to_add_out(txt_attn_output)\n\n        return img_attn_output, txt_attn_output\n\n\nclass QwenImageTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        qk_norm: str = \"rms_norm\",\n        eps: float = 1e-6,\n        quant_config: Optional[QuantizationConfig] | NunchakuConfig = None,\n        prefix: str = \"\",\n        zero_cond_t: bool = False,\n    ):\n        super().__init__()\n        self.prefix = prefix\n\n        self.dim = dim\n        self.num_attention_heads = num_attention_heads\n        self.attention_head_dim = attention_head_dim\n        self.quant_config = quant_config\n        self.zero_cond_t = zero_cond_t\n\n        mod_quant_config 
= (\n            quant_config\n            if (quant_config is not None and quant_config.get_name() == \"svdquant\")\n            else None\n        )\n        # Image processing modules\n        self.img_mod = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(\n                dim, 6 * dim, bias=True\n            ),  # For scale, shift, gate for norm1 and norm2\n        )\n        self.img_norm1 = LayerNormScaleShift(\n            hidden_size=dim, eps=eps, elementwise_affine=False\n        )\n\n        self.attn = QwenImageCrossAttention(\n            dim=dim,\n            num_heads=num_attention_heads,\n            added_kv_proj_dim=dim,\n            context_pre_only=False,\n            head_dim=attention_head_dim,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.attn\",\n        )\n        self.img_norm2 = ScaleResidualLayerNormScaleShift(\n            dim, eps=eps, elementwise_affine=False\n        )\n\n        # Text processing modules\n        self.txt_mod = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(\n                dim, 6 * dim, bias=True\n            ),  # For scale, shift, gate for norm1 and norm2\n        )\n        self.txt_norm1 = LayerNormScaleShift(\n            hidden_size=dim, eps=eps, elementwise_affine=False\n        )\n        # Text doesn't need separate attention - it's handled by img_attn joint computation\n        self.txt_norm2 = ScaleResidualLayerNormScaleShift(\n            hidden_size=dim, eps=eps, elementwise_affine=False\n        )\n        # Utils\n        self.fuse_mul_add = MulAdd()\n\n        nunchaku_enabled = (\n            quant_config is not None\n            and hasattr(quant_config, \"get_name\")\n            and quant_config.get_name() == \"svdquant\"\n            and is_nunchaku_available()\n        )\n        if nunchaku_enabled:\n            ff_class = diffusers.models.attention.FeedForward\n            self.img_mlp = ff_class(\n                dim=dim,\n              
  dim_out=dim,\n                activation_fn=\"gelu-approximate\",\n            )\n            self.txt_mlp = ff_class(\n                dim=dim,\n                dim_out=dim,\n                activation_fn=\"gelu-approximate\",\n            )\n        else:\n            self.img_mlp = FeedForward(\n                dim=dim,\n                dim_out=dim,\n                activation_fn=\"gelu-approximate\",\n            )\n            self.txt_mlp = FeedForward(\n                dim=dim,\n                dim_out=dim,\n                activation_fn=\"gelu-approximate\",\n            )\n\n        if nunchaku_enabled:\n            nunchaku_kwargs = {\n                \"precision\": quant_config.precision,\n                \"rank\": quant_config.rank,\n                \"act_unsigned\": quant_config.act_unsigned,\n            }\n            self.img_mlp = NunchakuFeedForward(self.img_mlp, **nunchaku_kwargs)\n            self.txt_mlp = NunchakuFeedForward(self.txt_mlp, **nunchaku_kwargs)\n\n    def _modulate(\n        self,\n        x: torch.Tensor,\n        mod_params: torch.Tensor,\n        norm_module: Union[LayerNormScaleShift, ScaleResidualLayerNormScaleShift],\n        index: Optional[torch.Tensor] = None,\n        gate_x: Optional[torch.Tensor] = None,\n        residual_x: Optional[torch.Tensor] = None,\n    ) -> Union[\n        Tuple[torch.Tensor, torch.Tensor],\n        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],\n    ]:\n        # Apply attention gates and add residual (like in Megatron)\n        #   - residual_out = gate_x * x + residual_x\n        # - x = norm(residual_out) * (1 + scale) + shift\n        # TODO: clean code here\n        is_scale_residual = isinstance(norm_module, ScaleResidualLayerNormScaleShift)\n\n        shift, scale, gate = mod_params.chunk(3, dim=-1)\n        if index is not None:\n            actual_batch = x.shape[0]\n            shift0, shift1 = (\n                shift[:actual_batch],\n                shift[actual_batch : 2 * 
actual_batch],\n            )\n            scale0, scale1 = (\n                scale[:actual_batch],\n                scale[actual_batch : 2 * actual_batch],\n            )\n            gate0, gate1 = gate[:actual_batch], gate[actual_batch : 2 * actual_batch]\n            if not x.is_contiguous():\n                x = x.contiguous()\n            if not index.is_contiguous():\n                index = index.contiguous()\n            if is_scale_residual:\n                if not residual_x.is_contiguous():\n                    residual_x = residual_x.contiguous()\n                if not gate_x.is_contiguous():\n                    gate_x = gate_x.contiguous()\n                x, residual_out, gate_result = (\n                    fuse_residual_layernorm_scale_shift_gate_select01_kernel(\n                        x,\n                        residual=residual_x,\n                        residual_gate=gate_x,\n                        weight=getattr(norm_module.norm, \"weight\", None),\n                        bias=getattr(norm_module.norm, \"bias\", None),\n                        scale0=scale0.contiguous(),\n                        shift0=shift0.contiguous(),\n                        gate0=gate0.contiguous(),\n                        scale1=scale1.contiguous(),\n                        shift1=shift1.contiguous(),\n                        gate1=gate1.contiguous(),\n                        index=index,\n                        eps=norm_module.eps,\n                    )\n                )\n                return x, residual_out, gate_result\n            else:\n                x, gate_result = fuse_layernorm_scale_shift_gate_select01_kernel(\n                    x,\n                    weight=getattr(norm_module.norm, \"weight\", None),\n                    bias=getattr(norm_module.norm, \"bias\", None),\n                    scale0=scale0.contiguous(),\n                    shift0=shift0.contiguous(),\n                    gate0=gate0.contiguous(),\n                    
scale1=scale1.contiguous(),\n                    shift1=shift1.contiguous(),\n                    gate1=gate1.contiguous(),\n                    index=index,\n                    eps=norm_module.eps,\n                )\n                return x, gate_result\n        else:\n            shift_result = shift.unsqueeze(1)\n            scale_result = scale.unsqueeze(1)\n            gate_result = gate.unsqueeze(1)\n            if is_scale_residual:\n                modulated, residual_out = norm_module(\n                    residual=residual_x,\n                    x=x,\n                    gate=gate_x,\n                    shift=shift_result,\n                    scale=scale_result,\n                )\n                return modulated, residual_out, gate_result\n            else:\n                modulated = norm_module(x=x, shift=shift_result, scale=scale_result)\n                return modulated, gate_result\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        encoder_hidden_states_mask: torch.Tensor,\n        temb_img_silu: torch.Tensor,\n        temb_txt_silu: torch.Tensor,\n        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n        modulate_index: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        # Get modulation parameters for both streams\n        img_mod_params = self.img_mod[1](temb_img_silu)  # [B, 6*dim]\n        txt_mod_params = self.txt_mod[1](temb_txt_silu)  # [B, 6*dim]\n\n        if (\n            self.quant_config is not None\n            and hasattr(self.quant_config, \"get_name\")\n            and self.quant_config.get_name() == \"svdquant\"\n        ):\n            # With nunchaku (svdquant), the modulation output is interleaved per channel\n            # ([B, dim, 6]); reorder it to [B, 6, dim], flattened back to [B, 6*dim], so\n            # the chunk() calls below 
split it into contiguous shift/scale/gate blocks.\n            img_mod_params = (\n            
    img_mod_params.view(img_mod_params.shape[0], -1, 6)\n                .transpose(1, 2)\n                .reshape(img_mod_params.shape[0], -1)\n            )\n            txt_mod_params = (\n                txt_mod_params.view(txt_mod_params.shape[0], -1, 6)\n                .transpose(1, 2)\n                .reshape(txt_mod_params.shape[0], -1)\n            )\n\n        # Split modulation parameters for norm1 and norm2\n        img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]\n        txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]\n\n        # Process image stream - norm1 + modulation\n        img_modulated, img_gate1 = self._modulate(\n            hidden_states, img_mod1, self.img_norm1, modulate_index\n        )\n        # Process text stream - norm1 + modulation\n        txt_shift1, txt_scale1, txt_gate1_raw = txt_mod1.chunk(3, dim=-1)\n        txt_modulated = self.txt_norm1(\n            encoder_hidden_states, shift=txt_shift1, scale=txt_scale1\n        )\n        txt_gate1 = txt_gate1_raw.unsqueeze(1)\n\n        # Use QwenAttnProcessor2_0 for joint attention computation\n        # This directly implements the DoubleStreamLayerMegatron logic:\n        # 1. Computes QKV for both streams\n        # 2. Applies QK normalization and RoPE\n        # 3. Concatenates and runs joint attention\n        # 4. 
Splits results back to separate streams\n        joint_attention_kwargs = joint_attention_kwargs or {}\n        attn_output = self.attn(\n            # Image stream (will be processed as \"sample\")\n            hidden_states=img_modulated,\n            # Text stream (will be processed as \"context\")\n            encoder_hidden_states=txt_modulated,\n            encoder_hidden_states_mask=encoder_hidden_states_mask,\n            image_rotary_emb=image_rotary_emb,\n            **joint_attention_kwargs,\n        )\n\n        # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided\n        img_attn_output, txt_attn_output = attn_output\n        # Process image stream - norm2 + MLP\n        img_modulated2, hidden_states, img_gate2 = self._modulate(\n            img_attn_output,\n            img_mod2,\n            self.img_norm2,\n            modulate_index,\n            gate_x=img_gate1,\n            residual_x=hidden_states,\n        )\n        img_mlp_output = self.img_mlp(img_modulated2)\n\n        if img_mlp_output.dim() == 2:\n            img_mlp_output = img_mlp_output.unsqueeze(0)\n        hidden_states = self.fuse_mul_add(img_mlp_output, img_gate2, hidden_states)\n\n        # Process text stream - norm2 + MLP\n        txt_shift2, txt_scale2, txt_gate2_raw = txt_mod2.chunk(3, dim=-1)\n        txt_modulated2, encoder_hidden_states = self.txt_norm2(\n            residual=encoder_hidden_states,\n            x=txt_attn_output,\n            gate=txt_gate1,\n            shift=txt_shift2,\n            scale=txt_scale2,\n        )\n        txt_gate2 = txt_gate2_raw.unsqueeze(1)\n        txt_mlp_output = self.txt_mlp(txt_modulated2)\n\n        if txt_mlp_output.dim() == 2:\n            txt_mlp_output = txt_mlp_output.unsqueeze(0)\n        encoder_hidden_states = self.fuse_mul_add(\n            txt_mlp_output, txt_gate2, encoder_hidden_states\n        )\n\n        # Clip to prevent overflow for fp16\n        if 
encoder_hidden_states.dtype == torch.float16:\n            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)\n        if hidden_states.dtype == torch.float16:\n            hidden_states = hidden_states.clip(-65504, 65504)\n\n        return encoder_hidden_states, hidden_states\n\n\ndef to_hashable(obj):\n    if isinstance(obj, list):\n        return tuple(to_hashable(x) for x in obj)\n    return obj\n\n\nclass QwenImageTransformer2DModel(CachableDiT, OffloadableDiTMixin):\n    \"\"\"\n    The Transformer model introduced in Qwen.\n\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n    _no_split_modules = [\"QwenImageTransformerBlock\"]\n    _skip_layerwise_casting_patterns = [\"pos_embed\", \"norm\"]\n    _repeated_blocks = [\"QwenImageTransformerBlock\"]\n\n    param_names_mapping = QwenImageDitConfig().arch_config.param_names_mapping\n\n    @classmethod\n    def get_nunchaku_quant_rules(cls) -> dict[str, list[str]]:\n        return {\n            \"skip\": [\n                \"norm\",\n                \"embed\",\n                \"rotary\",\n                \"pos_embed\",\n            ],\n            \"svdq_w4a4\": [\n                \"attn.to_qkv\",\n                \"attn.to_out\",\n                \"attn.add_qkv_proj\",\n                \"attn.to_add_out\",\n                \"img_mlp\",\n                \"txt_mlp\",\n            ],\n            \"awq_w4a16\": [\n                \"img_mod\",\n                \"txt_mod\",\n            ],\n        }\n\n    def __init__(\n        self,\n        config: QwenImageDitConfig,\n        hf_config: dict[str, Any],\n        quant_config: Optional[QuantizationConfig] = None,\n    ):\n        super().__init__(config=config, hf_config=hf_config)\n        patch_size = config.arch_config.patch_size\n        in_channels = config.arch_config.in_channels\n        out_channels = config.arch_config.out_channels\n        num_layers = config.arch_config.num_layers\n        attention_head_dim = 
config.arch_config.attention_head_dim\n        num_attention_heads = config.arch_config.num_attention_heads\n        joint_attention_dim = config.arch_config.joint_attention_dim\n        axes_dims_rope = config.arch_config.axes_dims_rope\n        self.zero_cond_t = getattr(config.arch_config, \"zero_cond_t\", False)\n        self.out_channels = out_channels or in_channels\n        self.inner_dim = num_attention_heads * attention_head_dim\n\n        self.use_additional_t_cond: bool = getattr(\n            config.arch_config, \"use_additional_t_cond\", False\n        )  # For qwen-image-layered now\n        self.use_layer3d_rope: bool = getattr(\n            config.arch_config, \"use_layer3d_rope\", False\n        )  # For qwen-image-layered now\n\n        if not self.use_layer3d_rope:\n            self.rotary_emb = QwenEmbedRope(\n                theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True\n            )\n        else:\n            self.rotary_emb = QwenEmbedLayer3DRope(\n                theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True\n            )\n\n        self.time_text_embed = QwenTimestepProjEmbeddings(\n            embedding_dim=self.inner_dim,\n            use_additional_t_cond=self.use_additional_t_cond,\n        )\n\n        self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)\n\n        self.img_in = nn.Linear(in_channels, self.inner_dim)\n        self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)\n\n        self.transformer_blocks = nn.ModuleList(\n            [\n                QwenImageTransformerBlock(\n                    dim=self.inner_dim,\n                    num_attention_heads=num_attention_heads,\n                    attention_head_dim=attention_head_dim,\n                    quant_config=quant_config,\n                    prefix=f\"transformer_blocks.{layer_idx}\",\n                    zero_cond_t=self.zero_cond_t,\n                )\n                for layer_idx in range(num_layers)\n            ]\n   
     )\n\n        self.norm_out = AdaLayerNormContinuous(\n            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6\n        )\n        self.proj_out = nn.Linear(\n            self.inner_dim, patch_size * patch_size * self.out_channels, bias=True\n        )\n\n        self.timestep_zero = torch.zeros(\n            (1,), dtype=torch.int, device=get_local_torch_device()\n        )\n\n        self.layer_names = [\"transformer_blocks\"]\n\n    @functools.lru_cache(maxsize=50)\n    def build_modulate_index(self, img_shapes: tuple[int, int, int], device):\n        modulate_index_list = []\n        for sample in img_shapes:\n            first_size = sample[0][0] * sample[0][1] * sample[0][2]\n            total_size = sum(s[0] * s[1] * s[2] for s in sample)\n            idx = (torch.arange(total_size, device=device) >= first_size).int()\n            modulate_index_list.append(idx)\n\n        modulate_index = torch.stack(modulate_index_list)\n        return modulate_index\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor = None,\n        encoder_hidden_states_mask: torch.Tensor = None,\n        timestep: torch.LongTensor = None,\n        img_shapes: Optional[List[Tuple[int, int, int]]] = None,\n        txt_seq_lens: Optional[List[int]] = None,\n        freqs_cis: tuple[torch.Tensor, torch.Tensor] = None,\n        additional_t_cond: Optional[torch.Tensor] = None,\n        guidance: torch.Tensor = None,  # TODO: this should probably be removed\n        attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_block_samples=None,\n        return_dict: bool = True,\n    ) -> Union[torch.Tensor, Transformer2DModelOutput]:\n        \"\"\"\n        The [`QwenImageTransformer2DModel`] forward method.\n\n        Args:\n            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):\n                Input `hidden_states`.\n            
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):\n                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.\n            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):\n                Mask of the input conditions.\n            timestep (`torch.LongTensor`):\n                Used to indicate the denoising step.\n            attention_kwargs (`dict`, *optional*):\n                Extra keyword arguments forwarded to the attention module of every transformer block\n                (as `joint_attention_kwargs`).\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Currently ignored; the output tensor is always returned.\n\n        Returns:\n            `torch.Tensor`: the output of `proj_out`, of shape\n            `(batch_size, image_sequence_length, patch_size * patch_size * out_channels)`.\n        \"\"\"\n        if (\n            attention_kwargs is not None\n            and attention_kwargs.get(\"scale\", None) is not None\n        ):\n            logger.warning(\n                \"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.\"\n            )\n\n        if isinstance(encoder_hidden_states, list):\n            encoder_hidden_states = encoder_hidden_states[0]\n\n        hidden_states = self.img_in(hidden_states)\n\n        timestep = (timestep / 1000).to(hidden_states.dtype)\n\n        if self.zero_cond_t:\n            timestep = torch.cat([timestep, self.timestep_zero], dim=0)\n            device = timestep.device\n            modulate_index = 
self.build_modulate_index(to_hashable(img_shapes), device)\n        else:\n            modulate_index = None\n\n        encoder_hidden_states = self.txt_norm(encoder_hidden_states)\n        encoder_hidden_states = self.txt_in(encoder_hidden_states)\n\n        temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)\n\n        temb_img_silu = F.silu(temb)\n        if self.zero_cond_t:\n            temb_txt = temb.chunk(2, dim=0)[0]\n            temb_txt_silu = temb_img_silu.chunk(2, dim=0)[0]\n        else:\n            temb_txt = temb\n            temb_txt_silu = temb_img_silu\n\n        image_rotary_emb = freqs_cis\n        for index_block, block in enumerate(self.transformer_blocks):\n            encoder_hidden_states, hidden_states = block(\n                hidden_states=hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                encoder_hidden_states_mask=encoder_hidden_states_mask,\n                temb_img_silu=temb_img_silu,\n                temb_txt_silu=temb_txt_silu,\n                image_rotary_emb=image_rotary_emb,\n                joint_attention_kwargs=attention_kwargs,\n                modulate_index=modulate_index,\n            )\n\n            # controlnet residual\n            if controlnet_block_samples is not None:\n                interval_control = len(self.transformer_blocks) / len(\n                    controlnet_block_samples\n                )\n                interval_control = int(np.ceil(interval_control))\n                hidden_states = (\n                    hidden_states\n                    + controlnet_block_samples[index_block // interval_control]\n                )\n        # Use only the image part (hidden_states) from the dual-stream blocks\n        hidden_states = self.norm_out(hidden_states, temb_txt)\n\n        output = self.proj_out(hidden_states)\n        return output\n\n\nEntryClass = QwenImageTransformer2DModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/sana.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding\n\nfrom sglang.multimodal_gen.configs.models.dits.sana import SanaConfig\nfrom sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm\nfrom sglang.multimodal_gen.runtime.layers.visual_embedding import Timesteps\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass SanaCombinedTimestepSizeEmbeddings(nn.Module):\n    def __init__(self, embedding_dim):\n        super().__init__()\n        self.time_proj = Timesteps(\n            num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0\n        )\n        self.timestep_embedder = TimestepEmbedding(\n            in_channels=256, time_embed_dim=embedding_dim\n        )\n\n    def forward(self, timestep, hidden_dtype=None):\n        timesteps_proj = self.time_proj(timestep)\n        if hidden_dtype is not None:\n            timesteps_proj = timesteps_proj.to(dtype=hidden_dtype)\n        timesteps_emb = self.timestep_embedder(timesteps_proj)\n        return timesteps_emb\n\n\nclass SanaAdaLayerNormSingle(nn.Module):\n    def __init__(self, embedding_dim):\n        super().__init__()\n        self.emb = SanaCombinedTimestepSizeEmbeddings(embedding_dim)\n        self.silu = nn.SiLU()\n        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)\n\n    def forward(self, timestep, hidden_dtype=None):\n        embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype)\n        out = self.linear(self.silu(embedded_timestep))\n        return out, embedded_timestep\n\n\nclass SanaModulatedNorm(nn.Module):\n    def __init__(self, dim, eps=1e-6):\n        
super().__init__()\n        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)\n\n    def forward(self, x, temb, scale_shift_table):\n        x = self.norm(x)\n        shift, scale = (scale_shift_table[None] + temb[:, None]).chunk(2, dim=1)\n        x = x * (1 + scale) + shift\n        return x\n\n\nclass GLUMBConv(nn.Module):\n    \"\"\"Gated Linear Unit with Multi-Branch Convolution.\"\"\"\n\n    def __init__(self, in_channels, out_channels, expand_ratio=2.5):\n        super().__init__()\n        hidden_channels = int(expand_ratio * in_channels)\n        self.nonlinearity = nn.SiLU()\n        self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)\n        self.conv_depth = nn.Conv2d(\n            hidden_channels * 2,\n            hidden_channels * 2,\n            3,\n            1,\n            1,\n            groups=hidden_channels * 2,\n        )\n        self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)\n\n    def forward(self, hidden_states):\n        hidden_states = self.conv_inverted(hidden_states)\n        hidden_states = self.nonlinearity(hidden_states)\n        hidden_states = self.conv_depth(hidden_states)\n        hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)\n        hidden_states = hidden_states * self.nonlinearity(gate)\n        hidden_states = self.conv_point(hidden_states)\n        return hidden_states\n\n\nclass SanaLinearAttention(nn.Module):\n    \"\"\"Linear attention with O(N*D^2) complexity instead of O(N^2*D).\"\"\"\n\n    def __init__(self, query_dim, num_heads, head_dim, qk_norm_dim, bias=False):\n        super().__init__()\n        inner_dim = num_heads * head_dim\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)\n        self.to_k = nn.Linear(query_dim, inner_dim, bias=bias)\n        self.to_v = nn.Linear(query_dim, inner_dim, bias=bias)\n        self.to_out = nn.ModuleList(\n  
          [nn.Linear(inner_dim, query_dim, bias=True), nn.Identity()]\n        )\n        self.norm_q = RMSNorm(qk_norm_dim)\n        self.norm_k = RMSNorm(qk_norm_dim)\n\n    def forward(self, hidden_states):\n        B, S, _ = hidden_states.shape\n\n        query = self.to_q(hidden_states)\n        key = self.to_k(hidden_states)\n        value = self.to_v(hidden_states)\n\n        query = self.norm_q(query)\n        key = self.norm_k(key)\n\n        query = query.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)\n        key = key.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)\n        value = value.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)\n        query = F.relu(query)\n        key = F.relu(key)\n\n        kv = torch.matmul(key.transpose(-2, -1), value)  # (B, H, D, D)\n        qkv = torch.matmul(query, kv)  # (B, H, S, D)\n\n        key_sum = key.sum(dim=-2, keepdim=True)  # (B, H, 1, D)\n        normalizer = torch.matmul(query, key_sum.transpose(-2, -1)).clamp(min=1e-6)\n        hidden_states = qkv / normalizer\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(B, S, -1)\n        hidden_states = self.to_out[0](hidden_states)\n        return hidden_states\n\n\nclass SanaCrossAttention(nn.Module):\n    def __init__(self, query_dim, cross_attention_dim, num_heads, head_dim, bias=False):\n        super().__init__()\n        inner_dim = num_heads * head_dim\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)\n        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)\n        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)\n        self.to_out = nn.ModuleList(\n            [nn.Linear(inner_dim, query_dim, bias=True), nn.Identity()]\n        )\n\n        self.norm_q = RMSNorm(inner_dim)\n        self.norm_k = RMSNorm(inner_dim)\n\n    def forward(\n        self, hidden_states, encoder_hidden_states, 
encoder_attention_mask=None\n    ):\n        B, S, _ = hidden_states.shape\n        T = encoder_hidden_states.shape[1]\n\n        query = self.to_q(hidden_states)\n        key = self.to_k(encoder_hidden_states)\n        value = self.to_v(encoder_hidden_states)\n\n        query = self.norm_q(query)\n        key = self.norm_k(key)\n\n        query = query.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)\n        key = key.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n        value = value.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n\n        attn_mask = None\n        if encoder_attention_mask is not None:\n            attn_mask = encoder_attention_mask.bool()\n            attn_mask = attn_mask[:, None, None, :].expand(B, self.num_heads, S, T)\n\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attn_mask\n        )\n        hidden_states = hidden_states.transpose(1, 2).reshape(B, S, -1)\n        hidden_states = self.to_out[0](hidden_states)\n        return hidden_states\n\n\nclass SanaTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim,\n        num_attention_heads,\n        attention_head_dim,\n        num_cross_attention_heads,\n        cross_attention_head_dim,\n        cross_attention_dim,\n        mlp_ratio,\n        norm_eps,\n        attention_bias=False,\n    ):\n        super().__init__()\n\n        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)\n\n        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)\n        self.attn1 = SanaLinearAttention(\n            query_dim=dim,\n            num_heads=num_attention_heads,\n            head_dim=attention_head_dim,\n            qk_norm_dim=num_attention_heads * attention_head_dim,\n            bias=attention_bias,\n        )\n\n        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)\n        self.attn2 = SanaCrossAttention(\n            
query_dim=dim,\n            cross_attention_dim=cross_attention_dim,\n            num_heads=num_cross_attention_heads,\n            head_dim=cross_attention_head_dim,\n            bias=True,\n        )\n\n        self.ff = GLUMBConv(in_channels=dim, out_channels=dim, expand_ratio=mlp_ratio)\n\n    def forward(\n        self,\n        hidden_states,\n        encoder_hidden_states,\n        timestep,\n        height,\n        width,\n        encoder_attention_mask=None,\n    ):\n        batch_size = hidden_states.shape[0]\n\n        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (\n            self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)\n        ).chunk(6, dim=1)\n\n        norm_hidden = self.norm1(hidden_states)\n        norm_hidden = norm_hidden * (1 + scale_msa) + shift_msa\n        attn_output = self.attn1(norm_hidden)\n        hidden_states = hidden_states + gate_msa * attn_output\n\n        attn_output = self.attn2(\n            hidden_states, encoder_hidden_states, encoder_attention_mask\n        )\n        hidden_states = hidden_states + attn_output\n\n        norm_hidden = self.norm2(hidden_states)\n        norm_hidden = norm_hidden * (1 + scale_mlp) + shift_mlp\n        norm_hidden = norm_hidden.unflatten(1, (height, width)).permute(0, 3, 1, 2)\n        ff_output = self.ff(norm_hidden)\n        ff_output = ff_output.flatten(2, 3).permute(0, 2, 1)\n        hidden_states = hidden_states + gate_mlp * ff_output\n\n        return hidden_states\n\n\nclass SanaTransformer2DModel(CachableDiT, OffloadableDiTMixin):\n\n    _fsdp_shard_conditions = [\n        lambda n, m: isinstance(m, SanaTransformerBlock),\n    ]\n    _compile_conditions = [\n        lambda n, m: isinstance(m, SanaTransformerBlock),\n    ]\n    param_names_mapping = SanaConfig().arch_config.param_names_mapping\n    reverse_param_names_mapping = {}\n\n    def __init__(self, config: SanaConfig, hf_config=None, **kwargs):\n        super().__init__(config, 
hf_config=hf_config or {}, **kwargs)\n\n        arch = config.arch_config\n        self.out_channels = arch.out_channels\n        self.patch_size = arch.patch_size\n        self.inner_dim = arch.num_attention_heads * arch.attention_head_dim\n\n        self.hidden_size = self.inner_dim\n        self.num_attention_heads = arch.num_attention_heads\n        self.num_channels_latents = arch.num_channels_latents\n\n        self.patch_embed = nn.ModuleDict(\n            {\n                \"proj\": nn.Conv2d(\n                    arch.in_channels,\n                    self.inner_dim,\n                    kernel_size=arch.patch_size,\n                    stride=arch.patch_size,\n                    bias=True,\n                ),\n            }\n        )\n        self.time_embed = SanaAdaLayerNormSingle(self.inner_dim)\n        self.caption_projection = PixArtAlphaTextProjection(\n            in_features=arch.caption_channels,\n            hidden_size=self.inner_dim,\n        )\n\n        self.caption_norm = RMSNorm(self.inner_dim)\n\n        self.transformer_blocks = nn.ModuleList(\n            [\n                SanaTransformerBlock(\n                    dim=self.inner_dim,\n                    num_attention_heads=arch.num_attention_heads,\n                    attention_head_dim=arch.attention_head_dim,\n                    num_cross_attention_heads=arch.num_cross_attention_heads,\n                    cross_attention_head_dim=arch.cross_attention_head_dim,\n                    cross_attention_dim=arch.cross_attention_dim,\n                    mlp_ratio=arch.mlp_ratio,\n                    norm_eps=arch.norm_eps,\n                    attention_bias=False,\n                )\n                for _ in range(arch.num_layers)\n            ]\n        )\n        self.scale_shift_table = nn.Parameter(\n            torch.randn(2, self.inner_dim) / self.inner_dim**0.5\n        )\n\n        self.norm_out = SanaModulatedNorm(self.inner_dim, eps=arch.norm_eps)\n\n        
self.proj_out = nn.Linear(\n            self.inner_dim,\n            arch.patch_size * arch.patch_size * self.out_channels,\n            bias=True,\n        )\n\n        self.layer_names = [\"transformer_blocks\"]\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor = None,\n        timestep: torch.LongTensor = None,\n        guidance: torch.Tensor = None,\n        encoder_attention_mask: torch.Tensor = None,\n        **kwargs,\n    ) -> torch.Tensor:\n\n        # Input validation - fail fast\n        if encoder_hidden_states is None:\n            raise ValueError(\"SANA forward pass requires encoder_hidden_states\")\n\n        batch_size, channels, height, width = hidden_states.shape\n        p = self.patch_size\n        post_patch_height = height // p\n        post_patch_width = width // p\n\n        hidden_states = self.patch_embed[\"proj\"](hidden_states)\n        hidden_states = hidden_states.flatten(2).transpose(1, 2)\n\n        timestep_emb, embedded_timestep = self.time_embed(\n            timestep, hidden_dtype=hidden_states.dtype\n        )\n\n        if isinstance(encoder_attention_mask, (list, tuple)):\n            encoder_attention_mask = encoder_attention_mask[0]\n\n        encoder_hidden_states = self.caption_projection(encoder_hidden_states)\n        if encoder_hidden_states.shape[0] != batch_size:\n            encoder_hidden_states = encoder_hidden_states.expand(\n                batch_size, -1, -1\n            ).contiguous()\n        encoder_hidden_states = encoder_hidden_states.view(\n            batch_size, -1, hidden_states.shape[-1]\n        )\n        encoder_hidden_states = self.caption_norm(encoder_hidden_states)\n\n        if (\n            encoder_attention_mask is not None\n            and encoder_attention_mask.shape[0] != batch_size\n        ):\n            encoder_attention_mask = encoder_attention_mask.expand(\n                batch_size, -1\n            
).contiguous()\n\n        for block in self.transformer_blocks:\n            hidden_states = block(\n                hidden_states,\n                encoder_hidden_states,\n                timestep_emb,\n                post_patch_height,\n                post_patch_width,\n                encoder_attention_mask=encoder_attention_mask,\n            )\n        hidden_states = self.norm_out(\n            hidden_states, embedded_timestep, self.scale_shift_table\n        )\n        hidden_states = self.proj_out(hidden_states)\n        hidden_states = hidden_states.reshape(\n            batch_size, post_patch_height, post_patch_width, p, p, self.out_channels\n        )\n        hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)\n        hidden_states = hidden_states.reshape(\n            batch_size, self.out_channels, height, width\n        )\n\n        return hidden_states\n\n\nEntryClass = SanaTransformer2DModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nimport math\nfrom functools import lru_cache\nfrom typing import Any\n\nimport torch\nimport torch.nn as nn\n\nfrom sglang.multimodal_gen.configs.models.dits import WanVideoConfig\nfrom sglang.multimodal_gen.configs.sample.wan import WanTeaCacheParams\nfrom sglang.multimodal_gen.runtime.distributed import (\n    divide,\n    get_sp_group,\n    get_sp_world_size,\n    get_tp_world_size,\n    sequence_model_parallel_all_gather,\n)\nfrom sglang.multimodal_gen.runtime.layers.attention import (\n    MinimalA2AAttnOp,\n    UlyssesAttention_VSA,\n    USPAttention,\n)\nfrom sglang.multimodal_gen.runtime.layers.elementwise import MulAdd\nfrom sglang.multimodal_gen.runtime.layers.layernorm import (\n    FP32LayerNorm,\n    LayerNormScaleShift,\n    RMSNorm,\n    ScaleResidualLayerNormScaleShift,\n    tensor_parallel_rms_norm,\n)\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    ColumnParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.mlp import MLP\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n)\nfrom sglang.multimodal_gen.runtime.layers.rotary_embedding import (\n    NDRotaryEmbedding,\n    _apply_rotary_emb,\n    apply_flashinfer_rope_qk_inplace,\n)\nfrom sglang.multimodal_gen.runtime.layers.visual_embedding import (\n    ModulateProjection,\n    PatchEmbed,\n    TimestepEmbedder,\n)\nfrom sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.platforms import (\n    AttentionBackendEnum,\n    current_platform,\n)\nfrom sglang.multimodal_gen.runtime.server_args import get_global_server_args\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom 
sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.srt.utils import add_prefix\n\nlogger = init_logger(__name__)\n_is_cuda = current_platform.is_cuda()\n\n\nclass WanImageEmbedding(torch.nn.Module):\n\n    def __init__(self, in_features: int, out_features: int):\n        super().__init__()\n\n        self.norm1 = FP32LayerNorm(in_features)\n        self.ff = MLP(in_features, in_features, out_features, act_type=\"gelu\")\n        self.norm2 = FP32LayerNorm(out_features)\n\n    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:\n        dtype = encoder_hidden_states_image.dtype\n        hidden_states = self.norm1(encoder_hidden_states_image)\n        hidden_states = self.ff(hidden_states)\n        hidden_states = self.norm2(hidden_states).to(dtype)\n        return hidden_states\n\n\nclass WanTimeTextImageEmbedding(nn.Module):\n\n    def __init__(\n        self,\n        dim: int,\n        time_freq_dim: int,\n        text_embed_dim: int,\n        image_embed_dim: int | None = None,\n    ):\n        super().__init__()\n\n        self.time_embedder = TimestepEmbedder(\n            dim, frequency_embedding_size=time_freq_dim, act_layer=\"silu\"\n        )\n        self.time_modulation = ModulateProjection(dim, factor=6, act_layer=\"silu\")\n        self.text_embedder = MLP(\n            text_embed_dim, dim, dim, bias=True, act_type=\"gelu_pytorch_tanh\"\n        )\n\n        self.image_embedder = None\n        if image_embed_dim is not None:\n            self.image_embedder = WanImageEmbedding(image_embed_dim, dim)\n\n    def forward(\n        self,\n        timestep: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        encoder_hidden_states_image: torch.Tensor | None = None,\n        timestep_seq_len: int | None = None,\n    ):\n        temb = self.time_embedder(timestep, timestep_seq_len)\n        timestep_proj = self.time_modulation(temb)\n\n        encoder_hidden_states = 
self.text_embedder(encoder_hidden_states)\n        if encoder_hidden_states_image is not None:\n            assert self.image_embedder is not None\n            encoder_hidden_states_image = self.image_embedder(\n                encoder_hidden_states_image\n            )\n\n        return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image\n\n\nclass WanSelfAttention(nn.Module):\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        window_size=(-1, -1),\n        qk_norm=True,\n        eps=1e-6,\n        parallel_attention=False,\n        prefix: str = \"\",\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        is_cross_attention: bool = False,\n        quant_config: QuantizationConfig | None = None,\n    ) -> None:\n        assert dim % num_heads == 0\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.window_size = window_size\n        self.qk_norm = qk_norm\n        self.eps = eps\n        self.parallel_attention = parallel_attention\n        tp_size = get_tp_world_size()\n\n        # layers\n        self.to_q = ColumnParallelLinear(\n            dim,\n            dim,\n            gather_output=False,\n            quant_config=quant_config,\n            prefix=add_prefix(\"to_q\", prefix),\n        )\n        self.to_k = ColumnParallelLinear(\n            dim,\n            dim,\n            gather_output=False,\n            quant_config=quant_config,\n            prefix=add_prefix(\"to_k\", prefix),\n        )\n        self.to_v = ColumnParallelLinear(\n            dim,\n            dim,\n            gather_output=False,\n            quant_config=quant_config,\n            prefix=add_prefix(\"to_v\", prefix),\n        )\n        self.to_out = RowParallelLinear(\n            dim,\n            dim,\n            input_is_parallel=True,\n            quant_config=quant_config,\n            
prefix=add_prefix(\"to_out.0\", prefix),\n        )\n        self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()\n        self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()\n        self.tp_rmsnorm = tp_size > 1 and qk_norm\n        self.local_num_heads = divide(num_heads, tp_size)\n\n        # Scaled dot product attention\n        self.attn = USPAttention(\n            num_heads=self.local_num_heads,\n            head_size=self.head_dim,\n            dropout_rate=0,\n            softmax_scale=None,\n            causal=False,\n            supported_attention_backends=supported_attention_backends,\n            skip_sequence_parallel=is_cross_attention,\n            quant_config=quant_config,\n        )\n\n    def forward(self, x: torch.Tensor, context: torch.Tensor, context_lens: int):\n        r\"\"\"\n        Args:\n            x(Tensor): Shape [B, L, num_heads, C / num_heads]\n        \"\"\"\n        pass\n\n\nclass WanT2VCrossAttention(WanSelfAttention):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs, is_cross_attention=True)\n\n    def forward(self, x, context, context_lens):\n        r\"\"\"\n        Args:\n            x(Tensor): Shape [B, L1, C]\n            context(Tensor): Shape [B, L2, C]\n            context_lens(Tensor): Shape [B]\n        \"\"\"\n        q, _ = self.to_q(x)\n        if self.tp_rmsnorm:\n            q = tensor_parallel_rms_norm(q, self.norm_q)\n        else:\n            q = self.norm_q(q)\n        q = q.unflatten(2, (self.local_num_heads, self.head_dim))\n\n        k, _ = self.to_k(context)\n        if self.tp_rmsnorm:\n            k = tensor_parallel_rms_norm(k, self.norm_k)\n        else:\n            k = self.norm_k(k)\n        k = k.unflatten(2, (self.local_num_heads, self.head_dim))\n\n        v, _ = self.to_v(context)\n        v = v.unflatten(2, (self.local_num_heads, self.head_dim))\n\n        # compute attention\n        x = self.attn(q, k, v)\n\n        # 
output\n        x = x.flatten(2)\n        x, _ = self.to_out(x)\n        return x\n\n\nclass WanI2VCrossAttention(WanSelfAttention):\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        window_size=(-1, -1),\n        qk_norm=True,\n        eps=1e-6,\n        prefix: str = \"\",\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        quant_config: QuantizationConfig | None = None,\n    ) -> None:\n        super().__init__(\n            dim,\n            num_heads,\n            window_size,\n            qk_norm,\n            eps,\n            supported_attention_backends=supported_attention_backends,\n            is_cross_attention=True,\n            quant_config=quant_config,\n        )\n\n        self.add_k_proj = ColumnParallelLinear(\n            dim,\n            dim,\n            gather_output=False,\n            quant_config=quant_config,\n            prefix=add_prefix(\"add_k_proj\", prefix),\n        )\n        self.add_v_proj = ColumnParallelLinear(\n            dim,\n            dim,\n            gather_output=False,\n            quant_config=quant_config,\n            prefix=add_prefix(\"add_v_proj\", prefix),\n        )\n        self.norm_added_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()\n\n    def forward(self, x, context, context_lens):\n        r\"\"\"\n        Args:\n            x(Tensor): Shape [B, L1, C]\n            context(Tensor): Shape [B, L2, C]\n            context_lens(Tensor): Shape [B]\n        \"\"\"\n        context_img = context[:, :257]\n        context = context[:, 257:]\n\n        q, _ = self.to_q(x)\n        if self.tp_rmsnorm:\n            q = tensor_parallel_rms_norm(q, self.norm_q)\n        else:\n            q = self.norm_q(q)\n        q = q.unflatten(2, (self.local_num_heads, self.head_dim))\n\n        k, _ = self.to_k(context)\n        if self.tp_rmsnorm:\n            k = tensor_parallel_rms_norm(k, self.norm_k)\n        else:\n            k = 
self.norm_k(k)\n        k = k.unflatten(2, (self.local_num_heads, self.head_dim))\n\n        v, _ = self.to_v(context)\n        v = v.unflatten(2, (self.local_num_heads, self.head_dim))\n\n        k_img, _ = self.add_k_proj(context_img)\n        if self.tp_rmsnorm:\n            k_img = tensor_parallel_rms_norm(k_img, self.norm_added_k)\n        else:\n            k_img = self.norm_added_k(k_img)\n        k_img = k_img.unflatten(2, (self.local_num_heads, self.head_dim))\n\n        v_img, _ = self.add_v_proj(context_img)\n        v_img = v_img.unflatten(2, (self.local_num_heads, self.head_dim))\n\n        img_x = self.attn(q, k_img, v_img)\n        x = self.attn(q, k, v)\n\n        # output\n        x = x.flatten(2)\n        img_x = img_x.flatten(2)\n        x = x + img_x\n        x, _ = self.to_out(x)\n        return x\n\n\nclass WanTransformerBlock(nn.Module):\n\n    def __init__(\n        self,\n        dim: int,\n        ffn_dim: int,\n        num_heads: int,\n        qk_norm: str = \"rms_norm_across_heads\",\n        cross_attn_norm: bool = False,\n        eps: float = 1e-6,\n        added_kv_proj_dim: int | None = None,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        prefix: str = \"\",\n        attention_type: str = \"original\",\n        sla_topk: float = 0.1,\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__()\n\n        # 1. 
Self-attention\n        self.norm1 = LayerNormScaleShift(\n            dim,\n            eps=eps,\n            elementwise_affine=False,\n            dtype=torch.float32,\n        )\n        self.to_q = ColumnParallelLinear(\n            dim,\n            dim,\n            bias=True,\n            gather_output=False,\n            quant_config=quant_config,\n            prefix=add_prefix(\"attn1.to_q\", prefix),\n        )\n        self.to_k = ColumnParallelLinear(\n            dim,\n            dim,\n            bias=True,\n            gather_output=False,\n            quant_config=quant_config,\n            prefix=add_prefix(\"attn1.to_k\", prefix),\n        )\n        self.to_v = ColumnParallelLinear(\n            dim,\n            dim,\n            bias=True,\n            gather_output=False,\n            quant_config=quant_config,\n            prefix=add_prefix(\"attn1.to_v\", prefix),\n        )\n\n        self.to_out = RowParallelLinear(\n            dim,\n            dim,\n            bias=True,\n            reduce_results=True,\n            quant_config=quant_config,\n            prefix=add_prefix(\"attn1.to_out.0\", prefix),\n        )\n        tp_size = get_tp_world_size()\n        self.local_num_heads = divide(num_heads, tp_size)\n        self_attn_backends = supported_attention_backends\n\n        if attention_type in (\"sla\", \"sagesla\"):\n            self.attn1 = MinimalA2AAttnOp(\n                num_heads=self.local_num_heads,\n                head_size=dim // num_heads,\n                attention_type=attention_type,\n                topk=sla_topk,\n                supported_attention_backends={\n                    AttentionBackendEnum.SLA_ATTN,\n                    AttentionBackendEnum.SAGE_SLA_ATTN,\n                },\n                prefix=add_prefix(\"attn1\", prefix),\n            )\n        else:\n            self.attn1 = USPAttention(\n                num_heads=self.local_num_heads,\n                head_size=dim // num_heads,\n         
       causal=False,\n                supported_attention_backends=self_attn_backends,\n                prefix=add_prefix(\"attn1\", prefix),\n                quant_config=quant_config,\n                is_cross_attention=False,\n            )\n\n        self.hidden_dim = dim\n        self.num_attention_heads = num_heads\n        self.dim_head = dim // num_heads\n        if qk_norm == \"rms_norm\":\n            self.norm_q = RMSNorm(self.dim_head, eps=eps)\n            self.norm_k = RMSNorm(self.dim_head, eps=eps)\n        elif qk_norm == \"rms_norm_across_heads\":\n            # LTX applies qk norm across all heads\n            self.norm_q = RMSNorm(dim, eps=eps)\n            self.norm_k = RMSNorm(dim, eps=eps)\n        else:\n            logger.error(\"QK Norm type not supported\")\n            raise Exception\n        assert cross_attn_norm is True\n        self.qk_norm = qk_norm\n        self.tp_rmsnorm = qk_norm == \"rms_norm_across_heads\" and tp_size > 1\n        self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift(\n            dim,\n            eps=eps,\n            elementwise_affine=True,\n            dtype=torch.float32,\n        )\n\n        # 2. 
Cross-attention\n        cross_attn_backends = {\n            b for b in supported_attention_backends if not b.is_sparse\n        }\n        if added_kv_proj_dim is not None:\n            # I2V\n            self.attn2 = WanI2VCrossAttention(\n                dim,\n                num_heads,\n                qk_norm=qk_norm,\n                eps=eps,\n                prefix=add_prefix(\"attn2\", prefix),\n                supported_attention_backends=cross_attn_backends,\n                quant_config=quant_config,\n            )\n        else:\n            # T2V\n            self.attn2 = WanT2VCrossAttention(\n                dim,\n                num_heads,\n                qk_norm=qk_norm,\n                eps=eps,\n                prefix=add_prefix(\"attn2\", prefix),\n                supported_attention_backends=cross_attn_backends,\n                quant_config=quant_config,\n            )\n        self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift(\n            dim,\n            eps=eps,\n            elementwise_affine=False,\n            dtype=torch.float32,\n        )\n\n        # 3. 
Feed-forward\n        self.ffn = MLP(\n            dim,\n            ffn_dim,\n            act_type=\"gelu_pytorch_tanh\",\n            prefix=add_prefix(\"ffn.net\", prefix),\n            quant_config=quant_config,\n        )\n        self.mlp_residual = MulAdd()\n\n        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        temb: torch.Tensor,\n        freqs_cis: tuple[torch.Tensor, torch.Tensor],\n    ) -> torch.Tensor:\n        if hidden_states.dim() == 4:\n            hidden_states = hidden_states.squeeze(1)\n        bs, seq_length, _ = hidden_states.shape\n        orig_dtype = hidden_states.dtype\n        if temb.dim() == 4:\n            # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)\n            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (\n                self.scale_shift_table.unsqueeze(0) + temb.float()\n            ).chunk(6, dim=2)\n            # batch_size, seq_len, 1, inner_dim\n            shift_msa = shift_msa.squeeze(2)\n            scale_msa = scale_msa.squeeze(2)\n            gate_msa = gate_msa.squeeze(2)\n            c_shift_msa = c_shift_msa.squeeze(2)\n            c_scale_msa = c_scale_msa.squeeze(2)\n            c_gate_msa = c_gate_msa.squeeze(2)\n        else:\n            # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)\n            e = self.scale_shift_table + temb.float()\n            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (\n                e.chunk(6, dim=1)\n            )\n\n        assert shift_msa.dtype == torch.float32\n\n        # 1. 
Self-attention\n        norm_hidden_states = self.norm1(hidden_states, shift_msa, scale_msa)\n        query, _ = self.to_q(norm_hidden_states)\n        key, _ = self.to_k(norm_hidden_states)\n        value, _ = self.to_v(norm_hidden_states)\n\n        if self.norm_q is not None:\n            if self.tp_rmsnorm:\n                query = tensor_parallel_rms_norm(query, self.norm_q)\n            else:\n                query = self.norm_q(query)\n        if self.norm_k is not None:\n            if self.tp_rmsnorm:\n                key = tensor_parallel_rms_norm(key, self.norm_k)\n            else:\n                key = self.norm_k(key)\n        query = query.squeeze(1).unflatten(2, (self.local_num_heads, self.dim_head))\n        key = key.squeeze(1).unflatten(2, (self.local_num_heads, self.dim_head))\n        value = value.squeeze(1).unflatten(2, (self.local_num_heads, self.dim_head))\n\n        # Apply rotary embeddings\n        cos, sin = freqs_cis\n        if _is_cuda and query.shape == key.shape:\n            cos_sin_cache = torch.cat(\n                [\n                    cos.to(dtype=torch.float32).contiguous(),\n                    sin.to(dtype=torch.float32).contiguous(),\n                ],\n                dim=-1,\n            )\n            query, key = apply_flashinfer_rope_qk_inplace(\n                query, key, cos_sin_cache, is_neox=False\n            )\n        else:\n            query, key = _apply_rotary_emb(\n                query, cos, sin, is_neox_style=False\n            ), _apply_rotary_emb(key, cos, sin, is_neox_style=False)\n        attn_output = self.attn1(query, key, value)\n        attn_output = attn_output.flatten(2)\n        attn_output, _ = self.to_out(attn_output)\n        attn_output = attn_output.squeeze(1)\n\n        null_shift = null_scale = torch.zeros(\n            (1,), device=hidden_states.device, dtype=hidden_states.dtype\n        )\n        norm_hidden_states, hidden_states = self.self_attn_residual_norm(\n            
hidden_states, attn_output, gate_msa, null_shift, null_scale\n        )\n        norm_hidden_states, hidden_states = norm_hidden_states.to(\n            orig_dtype\n        ), hidden_states.to(orig_dtype)\n\n        # 2. Cross-attention\n        attn_output = self.attn2(\n            norm_hidden_states, context=encoder_hidden_states, context_lens=None\n        )\n        norm_hidden_states, hidden_states = self.cross_attn_residual_norm(\n            hidden_states, attn_output, 1, c_shift_msa, c_scale_msa\n        )\n        norm_hidden_states, hidden_states = norm_hidden_states.to(\n            orig_dtype\n        ), hidden_states.to(orig_dtype)\n\n        # 3. Feed-forward\n        ff_output = self.ffn(norm_hidden_states)\n        hidden_states = self.mlp_residual(ff_output, c_gate_msa, hidden_states)\n        hidden_states = hidden_states.to(orig_dtype)\n\n        return hidden_states\n\n\nclass WanTransformerBlock_VSA(nn.Module):\n\n    def __init__(\n        self,\n        dim: int,\n        ffn_dim: int,\n        num_heads: int,\n        qk_norm: str = \"rms_norm_across_heads\",\n        cross_attn_norm: bool = False,\n        eps: float = 1e-6,\n        added_kv_proj_dim: int | None = None,\n        supported_attention_backends: set[AttentionBackendEnum] | None = None,\n        prefix: str = \"\",\n        quant_config: QuantizationConfig | None = None,\n    ):\n        super().__init__()\n\n        # 1. 
Self-attention\n        self.norm1 = LayerNormScaleShift(\n            dim,\n            eps=eps,\n            elementwise_affine=False,\n            dtype=torch.float32,\n        )\n        self.to_q = ColumnParallelLinear(\n            dim,\n            dim,\n            bias=True,\n            gather_output=True,\n            quant_config=quant_config,\n            prefix=add_prefix(\"attn1.to_q\", prefix),\n        )\n        self.to_k = ColumnParallelLinear(\n            dim,\n            dim,\n            bias=True,\n            gather_output=True,\n            quant_config=quant_config,\n            prefix=add_prefix(\"attn1.to_k\", prefix),\n        )\n        self.to_v = ColumnParallelLinear(\n            dim,\n            dim,\n            bias=True,\n            gather_output=True,\n            quant_config=quant_config,\n            prefix=add_prefix(\"attn1.to_v\", prefix),\n        )\n        self.to_gate_compress = ColumnParallelLinear(\n            dim,\n            dim,\n            bias=True,\n            gather_output=True,\n            quant_config=quant_config,\n            prefix=add_prefix(\"attn1.to_gate_compress\", prefix),\n        )\n\n        self.to_out = ColumnParallelLinear(\n            dim,\n            dim,\n            bias=True,\n            gather_output=True,\n            quant_config=quant_config,\n            prefix=add_prefix(\"attn1.to_out.0\", prefix),\n        )\n        self.attn1 = UlyssesAttention_VSA(\n            num_heads=num_heads,\n            head_size=dim // num_heads,\n            causal=False,\n            supported_attention_backends=supported_attention_backends,\n            prefix=add_prefix(\"attn1\", prefix),\n            quant_config=quant_config,\n        )\n        self.hidden_dim = dim\n        self.num_attention_heads = num_heads\n        dim_head = dim // num_heads\n        if qk_norm == \"rms_norm\":\n            self.norm_q = RMSNorm(dim_head, eps=eps)\n            self.norm_k = RMSNorm(dim_head, 
eps=eps)\n        elif qk_norm == \"rms_norm_across_heads\":\n            # LTX applies qk norm across all heads\n            self.norm_q = RMSNorm(dim, eps=eps)\n            self.norm_k = RMSNorm(dim, eps=eps)\n        else:\n            logger.error(\"QK Norm type not supported\")\n            raise Exception\n        assert cross_attn_norm is True\n        self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift(\n            dim,\n            eps=eps,\n            elementwise_affine=True,\n            dtype=torch.float32,\n        )\n\n        # 2. Cross-attention\n        cross_attn_backends = {\n            b for b in supported_attention_backends if not b.is_sparse\n        }\n        if added_kv_proj_dim is not None:\n            # I2V\n            self.attn2 = WanI2VCrossAttention(\n                dim,\n                num_heads,\n                qk_norm=qk_norm,\n                eps=eps,\n                prefix=add_prefix(\"attn2\", prefix),\n                supported_attention_backends=cross_attn_backends,\n                quant_config=quant_config,\n            )\n        else:\n            # T2V\n            self.attn2 = WanT2VCrossAttention(\n                dim,\n                num_heads,\n                qk_norm=qk_norm,\n                eps=eps,\n                prefix=add_prefix(\"attn2\", prefix),\n                supported_attention_backends=cross_attn_backends,\n                quant_config=quant_config,\n            )\n        self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift(\n            dim,\n            eps=eps,\n            elementwise_affine=False,\n            dtype=torch.float32,\n        )\n\n        # 3. 
Feed-forward\n        self.ffn = MLP(\n            dim,\n            ffn_dim,\n            act_type=\"gelu_pytorch_tanh\",\n            prefix=add_prefix(\"ffn.net\", prefix),\n            quant_config=quant_config,\n        )\n        self.mlp_residual = MulAdd()\n\n        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        temb: torch.Tensor,\n        freqs_cis: tuple[torch.Tensor, torch.Tensor],\n    ) -> torch.Tensor:\n        if hidden_states.dim() == 4:\n            hidden_states = hidden_states.squeeze(1)\n        bs, seq_length, _ = hidden_states.shape\n        orig_dtype = hidden_states.dtype\n        # assert orig_dtype != torch.float32\n        e = self.scale_shift_table + temb.float()\n        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(\n            6, dim=1\n        )\n        assert shift_msa.dtype == torch.float32\n\n        # 1. 
Self-attention\n        norm_hidden_states = self.norm1(hidden_states, shift_msa, scale_msa)\n        query, _ = self.to_q(norm_hidden_states)\n        key, _ = self.to_k(norm_hidden_states)\n        value, _ = self.to_v(norm_hidden_states)\n        gate_compress, _ = self.to_gate_compress(norm_hidden_states)\n\n        if self.norm_q is not None:\n            query = self.norm_q(query)\n        if self.norm_k is not None:\n            key = self.norm_k(key)\n\n        query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))\n        key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))\n        value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))\n        gate_compress = gate_compress.squeeze(1).unflatten(\n            2, (self.num_attention_heads, -1)\n        )\n\n        # Apply rotary embeddings\n        cos, sin = freqs_cis\n        if _is_cuda and query.shape == key.shape:\n            cos_sin_cache = torch.cat(\n                [\n                    cos.to(dtype=torch.float32).contiguous(),\n                    sin.to(dtype=torch.float32).contiguous(),\n                ],\n                dim=-1,\n            )\n            query, key = apply_flashinfer_rope_qk_inplace(\n                query, key, cos_sin_cache, is_neox=False\n            )\n        else:\n            query, key = _apply_rotary_emb(\n                query, cos, sin, is_neox_style=False\n            ), _apply_rotary_emb(key, cos, sin, is_neox_style=False)\n\n        attn_output = self.attn1(query, key, value, gate_compress=gate_compress)\n        attn_output = attn_output.flatten(2)\n        attn_output, _ = self.to_out(attn_output)\n        attn_output = attn_output.squeeze(1)\n\n        null_shift = null_scale = torch.zeros((1,), device=hidden_states.device)\n        norm_hidden_states, hidden_states = self.self_attn_residual_norm(\n            hidden_states, attn_output, gate_msa, null_shift, null_scale\n        )\n        norm_hidden_states, 
hidden_states = norm_hidden_states.to(\n            orig_dtype\n        ), hidden_states.to(orig_dtype)\n\n        # 2. Cross-attention\n        attn_output = self.attn2(\n            norm_hidden_states, context=encoder_hidden_states, context_lens=None\n        )\n        norm_hidden_states, hidden_states = self.cross_attn_residual_norm(\n            hidden_states, attn_output, 1, c_shift_msa, c_scale_msa\n        )\n        norm_hidden_states, hidden_states = norm_hidden_states.to(\n            orig_dtype\n        ), hidden_states.to(orig_dtype)\n\n        # 3. Feed-forward\n        ff_output = self.ffn(norm_hidden_states)\n        hidden_states = self.mlp_residual(ff_output, c_gate_msa, hidden_states)\n        hidden_states = hidden_states.to(orig_dtype)\n\n        return hidden_states\n\n\nclass WanTransformer3DModel(CachableDiT, OffloadableDiTMixin):\n    _fsdp_shard_conditions = WanVideoConfig()._fsdp_shard_conditions\n    _compile_conditions = WanVideoConfig()._compile_conditions\n    _supported_attention_backends = WanVideoConfig()._supported_attention_backends\n    param_names_mapping = WanVideoConfig().param_names_mapping\n    reverse_param_names_mapping = WanVideoConfig().reverse_param_names_mapping\n    lora_param_names_mapping = WanVideoConfig().lora_param_names_mapping\n\n    def __init__(\n        self,\n        config: WanVideoConfig,\n        hf_config: dict[str, Any],\n        quant_config: QuantizationConfig | None = None,\n    ) -> None:\n        super().__init__(config=config, hf_config=hf_config)\n\n        inner_dim = config.num_attention_heads * config.attention_head_dim\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.in_channels = config.in_channels\n        self.out_channels = config.out_channels\n        self.num_channels_latents = config.num_channels_latents\n        self.patch_size = config.patch_size\n        self.text_len = config.text_len\n\n        # 1. 
Patch & position embedding\n        self.patch_embedding = PatchEmbed(\n            in_chans=config.in_channels,\n            embed_dim=inner_dim,\n            patch_size=config.patch_size,\n            flatten=False,\n        )\n\n        # 2. Condition embeddings\n        self.condition_embedder = WanTimeTextImageEmbedding(\n            dim=inner_dim,\n            time_freq_dim=config.freq_dim,\n            text_embed_dim=config.text_dim,\n            image_embed_dim=config.image_dim,\n        )\n\n        # 3. Transformer blocks\n        attn_backend = get_global_server_args().attention_backend\n        transformer_block = (\n            WanTransformerBlock_VSA\n            if (attn_backend and attn_backend.lower() == \"video_sparse_attn\")\n            else WanTransformerBlock\n        )\n        self.blocks = nn.ModuleList(\n            [\n                transformer_block(\n                    inner_dim,\n                    config.ffn_dim,\n                    config.num_attention_heads,\n                    config.qk_norm,\n                    config.cross_attn_norm,\n                    config.eps,\n                    config.added_kv_proj_dim,\n                    self._supported_attention_backends\n                    | {AttentionBackendEnum.VIDEO_SPARSE_ATTN},\n                    prefix=f\"blocks.{i}\",\n                    attention_type=config.attention_type,\n                    sla_topk=config.sla_topk,\n                    quant_config=quant_config,\n                )\n                for i in range(config.num_layers)\n            ]\n        )\n\n        # 4. 
Output norm & projection\n        self.norm_out = LayerNormScaleShift(\n            inner_dim,\n            eps=config.eps,\n            elementwise_affine=False,\n            dtype=torch.float32,\n        )\n        self.proj_out = ColumnParallelLinear(\n            inner_dim,\n            config.out_channels * math.prod(config.patch_size),\n            bias=True,\n            gather_output=True,\n            prefix=f\"proj_out\",\n            quant_config=quant_config,\n        )\n        self.scale_shift_table = nn.Parameter(\n            torch.randn(1, 2, inner_dim) / inner_dim**0.5\n        )\n\n        # For type checking\n\n        self.cnt = 0\n        self.__post_init__()\n\n        # misc\n        self.sp_size = get_sp_world_size()\n\n        # Get rotary embeddings\n        d = self.hidden_size // self.num_attention_heads\n        self.rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]\n\n        self.rotary_emb = NDRotaryEmbedding(\n            rope_dim_list=self.rope_dim_list,\n            rope_theta=10000,\n            dtype=(\n                torch.float32\n                if current_platform.is_mps() or current_platform.is_musa()\n                else torch.float64\n            ),\n        )\n\n        self.layer_names = [\"blocks\"]\n\n    @lru_cache(maxsize=1)\n    def _compute_rope_for_sequence_shard(\n        self,\n        local_len: int,\n        rank: int,\n        frame_stride_local: int,\n        width_local: int,\n        device: torch.device,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        token_start = rank * local_len\n        token_indices = torch.arange(\n            token_start,\n            token_start + local_len,\n            device=device,\n            dtype=torch.long,\n        )\n        t_idx = token_indices // frame_stride_local\n        rem = token_indices % frame_stride_local\n        h_idx = rem // width_local\n        w_idx = rem % width_local\n        positions = torch.stack((t_idx, h_idx, w_idx), 
dim=1)\n        return self.rotary_emb.forward_uncached(positions)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor | list[torch.Tensor],\n        timestep: torch.LongTensor,\n        encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None,\n        guidance=None,\n        **kwargs,\n    ) -> torch.Tensor:\n        forward_batch = get_forward_context().forward_batch\n        if forward_batch is not None:\n            sequence_shard_enabled = (\n                forward_batch.enable_sequence_shard and self.sp_size > 1\n            )\n        else:\n            sequence_shard_enabled = False\n        self.enable_teacache = (\n            forward_batch is not None and forward_batch.enable_teacache\n        )\n\n        orig_dtype = hidden_states.dtype\n        if not isinstance(encoder_hidden_states, torch.Tensor):\n            encoder_hidden_states = encoder_hidden_states[0]\n        if (\n            isinstance(encoder_hidden_states_image, list)\n            and len(encoder_hidden_states_image) > 0\n        ):\n            encoder_hidden_states_image = encoder_hidden_states_image[0]\n        else:\n            encoder_hidden_states_image = None\n\n        batch_size, num_channels, num_frames, height, width = hidden_states.shape\n\n        p_t, p_h, p_w = self.patch_size\n        post_patch_num_frames = num_frames // p_t\n        post_patch_height = height // p_h\n        post_patch_width = width // p_w\n\n        if not sequence_shard_enabled:\n            # The rotary embedding layer correctly handles SP offsets internally.\n            freqs_cos, freqs_sin = self.rotary_emb.forward_from_grid(\n                (\n                    post_patch_num_frames * self.sp_size,\n                    post_patch_height,\n                    post_patch_width,\n                ),\n                shard_dim=0,\n                start_frame=0,\n                device=hidden_states.device,\n    
        )\n            assert freqs_cos.dtype == torch.float32\n            assert freqs_cos.device == hidden_states.device\n            freqs_cis = (\n                (freqs_cos.float(), freqs_sin.float())\n                if freqs_cos is not None\n                else None\n            )\n\n        hidden_states = self.patch_embedding(hidden_states)\n        hidden_states = hidden_states.flatten(2).transpose(1, 2)\n\n        # shape is [B, T' * H' * W', C]\n        seq_len_orig = hidden_states.shape[1]\n        seq_shard_pad = 0\n        if sequence_shard_enabled:\n            if seq_len_orig % self.sp_size != 0:\n                seq_shard_pad = self.sp_size - (seq_len_orig % self.sp_size)\n                pad = torch.zeros(\n                    (batch_size, seq_shard_pad, hidden_states.shape[2]),\n                    dtype=hidden_states.dtype,\n                    device=hidden_states.device,\n                )\n                hidden_states = torch.cat([hidden_states, pad], dim=1)\n            sp_rank = get_sp_group().rank_in_group\n            local_seq_len = hidden_states.shape[1] // self.sp_size\n            hidden_states = hidden_states.view(\n                batch_size, self.sp_size, local_seq_len, hidden_states.shape[2]\n            )\n            hidden_states = hidden_states[:, sp_rank, :, :]\n\n            frame_stride = post_patch_height * post_patch_width\n            freqs_cos, freqs_sin = self._compute_rope_for_sequence_shard(\n                local_seq_len,\n                sp_rank,\n                frame_stride,\n                post_patch_width,\n                hidden_states.device,\n            )\n            freqs_cis = (freqs_cos.float(), freqs_sin.float())\n\n        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)\n        if timestep.dim() == 2:\n            # ti2v\n            ts_seq_len = timestep.shape[1]\n            timestep = timestep.flatten()  # batch_size * seq_len\n        else:\n            ts_seq_len = 
None\n\n        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = (\n            self.condition_embedder(\n                timestep,\n                encoder_hidden_states,\n                encoder_hidden_states_image,\n                timestep_seq_len=ts_seq_len,\n            )\n        )\n        if ts_seq_len is not None:\n            # batch_size, seq_len, 6, inner_dim\n            timestep_proj = timestep_proj.unflatten(2, (6, -1))\n        else:\n            # batch_size, 6, inner_dim\n            timestep_proj = timestep_proj.unflatten(1, (6, -1))\n\n        if sequence_shard_enabled and ts_seq_len is not None:\n            if seq_shard_pad > 0:\n                pad = torch.zeros(\n                    (\n                        batch_size,\n                        seq_shard_pad,\n                        timestep_proj.shape[2],\n                        timestep_proj.shape[3],\n                    ),\n                    dtype=timestep_proj.dtype,\n                    device=timestep_proj.device,\n                )\n                timestep_proj = torch.cat([timestep_proj, pad], dim=1)\n            timestep_proj = timestep_proj.view(\n                batch_size,\n                self.sp_size,\n                local_seq_len,\n                timestep_proj.shape[2],\n                timestep_proj.shape[3],\n            )\n            timestep_proj = timestep_proj[:, sp_rank, :, :, :]\n\n        if encoder_hidden_states_image is not None:\n            encoder_hidden_states = torch.concat(\n                [encoder_hidden_states_image, encoder_hidden_states], dim=1\n            )\n\n        encoder_hidden_states = (\n            encoder_hidden_states.to(orig_dtype)\n            if not current_platform.is_amp_supported()\n            else encoder_hidden_states\n        )  # cast to orig_dtype for MPS\n\n        assert encoder_hidden_states.dtype == orig_dtype\n\n        # 4. 
Transformer blocks\n        # if caching is enabled, we might be able to skip the forward pass\n        should_skip_forward = self.should_skip_forward_for_cached_states(\n            timestep_proj=timestep_proj, temb=temb\n        )\n\n        if should_skip_forward:\n            hidden_states = self.retrieve_cached_states(hidden_states)\n        else:\n            # if teacache is enabled, we need to cache the original hidden states\n            if self.enable_teacache:\n                original_hidden_states = hidden_states.clone()\n\n            for block in self.blocks:\n                hidden_states = block(\n                    hidden_states, encoder_hidden_states, timestep_proj, freqs_cis\n                )\n            # if teacache is enabled, we need to cache the original hidden states\n            if self.enable_teacache:\n                self.maybe_cache_states(hidden_states, original_hidden_states)\n        self.cnt += 1\n\n        if sequence_shard_enabled:\n            hidden_states = hidden_states.contiguous()\n            hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1)\n            if seq_shard_pad > 0:\n                hidden_states = hidden_states[:, :seq_len_orig, :]\n\n        # 5. 
Output norm, projection & unpatchify\n        if temb.dim() == 3:\n            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)\n            shift, scale = (\n                self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)\n            ).chunk(2, dim=2)\n            shift = shift.squeeze(2)\n            scale = scale.squeeze(2)\n        else:\n            # batch_size, inner_dim\n            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)\n\n        hidden_states = self.norm_out(hidden_states, shift, scale)\n        hidden_states, _ = self.proj_out(hidden_states)\n\n        hidden_states = hidden_states.reshape(\n            batch_size,\n            post_patch_num_frames,\n            post_patch_height,\n            post_patch_width,\n            p_t,\n            p_h,\n            p_w,\n            -1,\n        )\n        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)\n        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)\n\n        return output\n\n    def maybe_cache_states(\n        self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor\n    ) -> None:\n        \"\"\"Cache residual with CFG positive/negative separation.\"\"\"\n        residual = hidden_states.squeeze(0) - original_hidden_states\n        if not self.is_cfg_negative:\n            self.previous_residual = residual\n        else:\n            self.previous_residual_negative = residual\n\n    def should_skip_forward_for_cached_states(self, **kwargs) -> bool:\n        if not self.enable_teacache:\n            return False\n        ctx = self._get_teacache_context()\n        if ctx is None:\n            return False\n\n        # Wan uses WanTeaCacheParams with additional fields\n        teacache_params = ctx.teacache_params\n        assert isinstance(\n            teacache_params, WanTeaCacheParams\n        ), \"teacache_params is not a WanTeaCacheParams\"\n\n        # Initialize Wan-specific parameters\n        
use_ret_steps = teacache_params.use_ret_steps\n        cutoff_steps = teacache_params.get_cutoff_steps(ctx.num_inference_steps)\n        ret_steps = teacache_params.ret_steps\n\n        # Adjust ret_steps and cutoff_steps for non-CFG mode\n        # (WanTeaCacheParams uses *2 factor assuming CFG)\n        if not ctx.do_cfg:\n            ret_steps = ret_steps // 2\n            cutoff_steps = cutoff_steps // 2\n\n        timestep_proj = kwargs[\"timestep_proj\"]\n        temb = kwargs[\"temb\"]\n        modulated_inp = timestep_proj if use_ret_steps else temb\n\n        self.is_cfg_negative = ctx.is_cfg_negative\n\n        # Wan uses ret_steps/cutoff_steps for boundary detection\n        is_boundary_step = self.cnt < ret_steps or self.cnt >= cutoff_steps\n\n        # Use shared helper to compute cache decision\n        should_calc = self._compute_teacache_decision(\n            modulated_inp=modulated_inp,\n            is_boundary_step=is_boundary_step,\n            coefficients=ctx.coefficients,\n            teacache_thresh=ctx.teacache_thresh,\n        )\n\n        return not should_calc\n\n    def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"Retrieve cached residual with CFG positive/negative separation.\"\"\"\n        if not self.is_cfg_negative:\n            return hidden_states + self.previous_residual\n        else:\n            return hidden_states + self.previous_residual_negative\n\n\nEntryClass = WanTransformer3DModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/dits/zimage.py",
    "content": "import math\nfrom typing import Any, List, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\n\nfrom sglang.multimodal_gen.configs.models.dits.zimage import ZImageDitConfig\nfrom sglang.multimodal_gen.runtime.distributed import get_tp_world_size\nfrom sglang.multimodal_gen.runtime.layers.activation import SiluAndMul\nfrom sglang.multimodal_gen.runtime.layers.attention import USPAttention\nfrom sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm, apply_qk_norm\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    ColumnParallelLinear,\n    MergedColumnParallelLinear,\n    ReplicatedLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (\n    QuantizationConfig,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import (\n    NunchakuConfig,\n    is_nunchaku_available,\n)\nfrom sglang.multimodal_gen.runtime.layers.rotary_embedding import (\n    _apply_rotary_emb,\n    apply_flashinfer_rope_qk_inplace,\n)\nfrom sglang.multimodal_gen.runtime.models.dits.base import CachableDiT\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\ntry:\n    from nunchaku.models.attention import NunchakuFeedForward  # type: ignore[import]\nexcept Exception:\n    NunchakuFeedForward = None\n\nlogger = init_logger(__name__)\n_is_cuda = current_platform.is_cuda()\n\nADALN_EMBED_DIM = 256\nSEQ_MULTI_OF = 32\n\n\nclass SelectFirstElement(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        return x[0]\n\n\nclass TimestepEmbedder(nn.Module):\n    def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):\n        super().__init__()\n        if mid_size is None:\n            mid_size = out_size\n\n        self.mlp = 
nn.ModuleList(\n            [\n                ColumnParallelLinear(\n                    frequency_embedding_size, mid_size, bias=True, gather_output=False\n                ),\n                nn.SiLU(),\n                RowParallelLinear(\n                    mid_size, out_size, bias=True, input_is_parallel=True\n                ),\n            ]\n        )\n\n        self.frequency_embedding_size = frequency_embedding_size\n\n    @staticmethod\n    def timestep_embedding(t, dim, max_period=10000):\n        with torch.amp.autocast(current_platform.device_type, enabled=False):\n            half = dim // 2\n            freqs = torch.exp(\n                -math.log(max_period)\n                * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)\n                / half\n            )\n            args = t[:, None].float() * freqs[None]\n            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n            if dim % 2:\n                embedding = torch.cat(\n                    [embedding, torch.zeros_like(embedding[:, :1])], dim=-1\n                )\n            return embedding\n\n    def forward(self, t):\n        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(\n            self.mlp[0].weight.dtype\n        )\n        t_emb, _ = self.mlp[0](t_freq)\n        t_emb = self.mlp[1](t_emb)\n        t_emb, _ = self.mlp[2](t_emb)\n        return t_emb\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim: int, hidden_dim: int):\n        super().__init__()\n        # Use MergedColumnParallelLinear for gate and up projection (fused)\n        self.w13 = MergedColumnParallelLinear(\n            dim, [hidden_dim, hidden_dim], bias=False, gather_output=False\n        )\n        self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True)\n        self.act = SiluAndMul()\n\n    def forward(self, x):\n        x13, _ = self.w13(x)\n        x = self.act(x13)\n        out, _ = self.w2(x)\n   
     return out\n\n\nclass ZImageAttention(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        num_kv_heads: int,\n        qk_norm: bool = True,\n        eps: float = 1e-6,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.dim = dim\n        self.head_dim = dim // num_heads\n        self.num_heads = num_heads\n        self.num_kv_heads = num_kv_heads\n        self.qk_norm = qk_norm\n\n        tp_size = get_tp_world_size()\n        assert (\n            num_heads % tp_size == 0\n        ), f\"num_heads {num_heads} must be divisible by tp world size {tp_size}\"\n        assert (\n            num_kv_heads % tp_size == 0\n        ), f\"num_kv_heads {num_kv_heads} must be divisible by tp world size {tp_size}\"\n        self.local_num_heads = num_heads // tp_size\n        self.local_num_kv_heads = num_kv_heads // tp_size\n\n        kv_dim = self.head_dim * num_kv_heads\n        self.use_fused_qkv = isinstance(quant_config, NunchakuConfig)\n\n        if self.use_fused_qkv:\n            self.to_qkv = MergedColumnParallelLinear(\n                dim,\n                [dim, kv_dim, kv_dim],\n                bias=False,\n                gather_output=False,\n                quant_config=quant_config,\n                prefix=f\"{prefix}.to_qkv\",\n            )\n        else:\n            self.to_q = ColumnParallelLinear(\n                dim,\n                dim,\n                bias=False,\n                gather_output=False,\n                quant_config=quant_config,\n                prefix=f\"{prefix}.to_q\",\n            )\n            self.to_k = ColumnParallelLinear(\n                dim,\n                kv_dim,\n                bias=False,\n                gather_output=False,\n                quant_config=quant_config,\n                prefix=f\"{prefix}.to_k\",\n            )\n            self.to_v = 
ColumnParallelLinear(\n                dim,\n                kv_dim,\n                bias=False,\n                gather_output=False,\n                quant_config=quant_config,\n                prefix=f\"{prefix}.to_v\",\n            )\n\n        if self.qk_norm:\n            self.norm_q = RMSNorm(self.head_dim, eps=eps)\n            self.norm_k = RMSNorm(self.head_dim, eps=eps)\n        else:\n            self.norm_q = None\n            self.norm_k = None\n\n        self.to_out = nn.ModuleList(\n            [\n                RowParallelLinear(\n                    dim,\n                    dim,\n                    bias=False,\n                    input_is_parallel=True,\n                    quant_config=quant_config,\n                    prefix=f\"{prefix}.to_out.0\",\n                )\n            ]\n        )\n\n        self.attn = USPAttention(\n            num_heads=self.local_num_heads,\n            head_size=self.head_dim,\n            num_kv_heads=self.local_num_kv_heads,\n            dropout_rate=0,\n            softmax_scale=None,\n            causal=False,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    ):\n        if self.use_fused_qkv:\n            qkv, _ = self.to_qkv(hidden_states)\n            q, k, v = qkv.split(\n                [\n                    self.local_num_heads * self.head_dim,\n                    self.local_num_kv_heads * self.head_dim,\n                    self.local_num_kv_heads * self.head_dim,\n                ],\n                dim=-1,\n            )\n            q = q.contiguous()\n            k = k.contiguous()\n            v = v.contiguous()\n        else:\n            q, _ = self.to_q(hidden_states)\n            k, _ = self.to_k(hidden_states)\n            v, _ = self.to_v(hidden_states)\n        q = q.view(*q.shape[:-1], self.local_num_heads, self.head_dim)\n        k = k.view(*k.shape[:-1], 
self.local_num_kv_heads, self.head_dim)\n        v = v.view(*v.shape[:-1], self.local_num_kv_heads, self.head_dim)\n\n        if self.qk_norm:\n            q, k = apply_qk_norm(\n                q=q,\n                k=k,\n                q_norm=self.norm_q,\n                k_norm=self.norm_k,\n                head_dim=self.head_dim,\n                allow_inplace=True,\n            )\n\n        if freqs_cis is not None:\n            cos, sin = freqs_cis\n            if _is_cuda and q.shape == k.shape:\n                cos_sin_cache = torch.cat(\n                    [\n                        cos.to(dtype=torch.float32).contiguous(),\n                        sin.to(dtype=torch.float32).contiguous(),\n                    ],\n                    dim=-1,\n                )\n                q, k = apply_flashinfer_rope_qk_inplace(\n                    q, k, cos_sin_cache, is_neox=False\n                )\n            else:\n                q = _apply_rotary_emb(q, cos, sin, is_neox_style=False)\n                k = _apply_rotary_emb(k, cos, sin, is_neox_style=False)\n\n        hidden_states = self.attn(q, k, v)\n        hidden_states = hidden_states.flatten(2)\n\n        hidden_states, _ = self.to_out[0](hidden_states)\n\n        return hidden_states\n\n\nclass ZImageTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        layer_id: int,\n        dim: int,\n        n_heads: int,\n        n_kv_heads: int,\n        norm_eps: float,\n        qk_norm: bool,\n        modulation=True,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.dim = dim\n        self.head_dim = dim // n_heads\n        self.layer_id = layer_id\n        self.modulation = modulation\n\n        self.attention = ZImageAttention(\n            dim=dim,\n            num_heads=n_heads,\n            num_kv_heads=n_kv_heads,\n            qk_norm=qk_norm,\n            eps=1e-5,\n            
quant_config=quant_config,\n            prefix=f\"{prefix}.attention\",\n        )\n\n        hidden_dim = int(dim / 3 * 8)\n        nunchaku_enabled = (\n            isinstance(quant_config, NunchakuConfig) and is_nunchaku_available()\n        )\n        if nunchaku_enabled:\n            import diffusers\n\n            ff = diffusers.models.attention.FeedForward(\n                dim=dim,\n                dim_out=dim,\n                activation_fn=\"swiglu\",\n                inner_dim=hidden_dim,\n                bias=False,\n            )\n            nunchaku_kwargs = {\n                \"precision\": quant_config.precision,\n                \"rank\": quant_config.rank,\n                \"act_unsigned\": quant_config.act_unsigned,\n            }\n            self.feed_forward = NunchakuFeedForward(ff, **nunchaku_kwargs)\n            # NunchakuFeedForward overrides net[2].act_unsigned=True for int4 (GELU-specific\n            # optimization for non-negative activations). Z-Image uses SwiGLU whose output\n            # can be negative, so we must restore the original act_unsigned value.\n            if hasattr(self.feed_forward, \"net\") and len(self.feed_forward.net) > 2:\n                self.feed_forward.net[2].act_unsigned = quant_config.act_unsigned\n        else:\n            self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim)\n\n        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)\n        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)\n\n        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)\n        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)\n\n        if modulation:\n            self.adaLN_modulation = nn.Sequential(\n                ReplicatedLinear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)\n            )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        freqs_cis: Tuple[torch.Tensor, torch.Tensor],\n        adaln_input: Optional[torch.Tensor] = None,\n    ):\n        if self.modulation:\n            assert 
adaln_input is not None\n            scale_msa_gate, _ = self.adaLN_modulation(adaln_input)\n            scale_msa, gate_msa, scale_mlp, gate_mlp = scale_msa_gate.unsqueeze(\n                1\n            ).chunk(4, dim=2)\n            gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()\n            scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp\n\n            # Attention block\n            attn_out = self.attention(\n                self.attention_norm1(x) * scale_msa,\n                freqs_cis=freqs_cis,\n            )\n            x = x + gate_msa * self.attention_norm2(attn_out)\n\n            # FFN block\n            x = x + gate_mlp * self.ffn_norm2(\n                self.feed_forward(\n                    self.ffn_norm1(x) * scale_mlp,\n                )\n            )\n        else:\n            # Attention block\n            attn_out = self.attention(\n                self.attention_norm1(x),\n                freqs_cis=freqs_cis,\n            )\n            x = x + self.attention_norm2(attn_out)\n\n            # FFN block\n            x = x + self.ffn_norm2(\n                self.feed_forward(\n                    self.ffn_norm1(x),\n                )\n            )\n\n        return x\n\n\nclass FinalLayer(nn.Module):\n    def __init__(self, hidden_size, out_channels):\n        super().__init__()\n        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.linear = ColumnParallelLinear(\n            hidden_size, out_channels, bias=True, gather_output=True\n        )\n\n        self.act = nn.SiLU()\n        self.adaLN_modulation = nn.Sequential(\n            nn.SiLU(),\n            ReplicatedLinear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),\n        )\n\n    def forward(self, x, c):\n        scale, _ = self.adaLN_modulation(c)\n        scale = 1.0 + scale\n        x = self.norm_final(x) * scale.unsqueeze(1)\n        x, _ = self.linear(x)\n        return x\n\n\nclass RopeEmbedder:\n    
def __init__(\n        self,\n        theta: float = 256.0,\n        axes_dims: List[int] = (16, 56, 56),\n        axes_lens: List[int] = (64, 128, 128),\n    ):\n        self.theta = theta\n        self.axes_dims = axes_dims\n        self.axes_lens = axes_lens\n        assert len(axes_dims) == len(\n            axes_lens\n        ), \"axes_dims and axes_lens must have the same length\"\n\n        self.cos_cached = None\n        self.sin_cached = None\n\n    @staticmethod\n    def precompute_freqs(dim: List[int], end: List[int], theta: float = 256.0):\n        with torch.device(\"cpu\"):\n            cos_list = []\n            sin_list = []\n            for i, (d, e) in enumerate(zip(dim, end)):\n                freqs = 1.0 / (\n                    theta\n                    ** (torch.arange(0, d, 2, dtype=torch.float64, device=\"cpu\") / d)\n                )\n                timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)\n                freqs = torch.outer(timestep, freqs).float()\n\n                cos_list.append(torch.cos(freqs))\n                sin_list.append(torch.sin(freqs))\n\n            return cos_list, sin_list\n\n    def __call__(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Args:\n            ids: [batch, len(axes_dims)] or [seq_len, len(axes_dims)]\n        Returns:\n            cos: [batch/seq, head_dim // 2]\n            sin: [batch/seq, head_dim // 2]\n        \"\"\"\n        assert ids.ndim == 2\n        assert ids.shape[-1] == len(self.axes_dims)\n        device = ids.device\n\n        if self.cos_cached is None:\n            self.cos_cached, self.sin_cached = self.precompute_freqs(\n                self.axes_dims, self.axes_lens, theta=self.theta\n            )\n            self.cos_cached = [c.to(device) for c in self.cos_cached]\n            self.sin_cached = [s.to(device) for s in self.sin_cached]\n        else:\n            if self.cos_cached[0].device != device:\n           
     self.cos_cached = [c.to(device) for c in self.cos_cached]\n                self.sin_cached = [s.to(device) for s in self.sin_cached]\n\n        cos_out = []\n        sin_out = []\n        for i in range(len(self.axes_dims)):\n            index = ids[:, i]\n            cos_out.append(self.cos_cached[i][index])\n            sin_out.append(self.sin_cached[i][index])\n\n        return torch.cat(cos_out, dim=-1), torch.cat(sin_out, dim=-1)\n\n\nclass ZImageTransformer2DModel(CachableDiT, OffloadableDiTMixin):\n    _supports_gradient_checkpointing = True\n    _no_split_modules = [\"ZImageTransformerBlock\"]\n    _fsdp_shard_conditions = ZImageDitConfig().arch_config._fsdp_shard_conditions\n\n    param_names_mapping = ZImageDitConfig().arch_config.param_names_mapping\n    reverse_param_names_mapping = (\n        ZImageDitConfig().arch_config.reverse_param_names_mapping\n    )\n\n    @classmethod\n    def get_nunchaku_quant_rules(cls) -> dict[str, list[str]]:\n        return {\n            \"skip\": [\n                \"norm\",\n                \"embed\",\n                \"rotary\",\n                \"pos_embed\",\n            ],\n            \"svdq_w4a4\": [\n                \"attention.to_qkv\",\n                \"attention.to_out\",\n                \"img_mlp\",\n                \"txt_mlp\",\n            ],\n            \"awq_w4a16\": [\n                \"img_mod\",\n                \"txt_mod\",\n            ],\n        }\n\n    def __init__(\n        self,\n        config: ZImageDitConfig,\n        hf_config: dict[str, Any],\n        quant_config: Optional[QuantizationConfig] = None,\n    ) -> None:\n        super().__init__(config=config, hf_config=hf_config)\n\n        self.config_data = config  # Store config\n        arch_config = config.arch_config\n\n        self.in_channels = arch_config.in_channels\n        self.out_channels = arch_config.out_channels\n        self.all_patch_size 
= arch_config.all_patch_size\n        self.all_f_patch_size = arch_config.all_f_patch_size\n        self.dim = arch_config.dim\n        self.n_heads = arch_config.num_attention_heads\n\n        self.rope_theta = arch_config.rope_theta\n        self.t_scale = arch_config.t_scale\n        self.gradient_checkpointing = False\n\n        assert len(self.all_patch_size) == len(self.all_f_patch_size)\n\n        all_x_embedder = {}\n        all_final_layer = {}\n        for patch_idx, (patch_size, f_patch_size) in enumerate(\n            zip(self.all_patch_size, self.all_f_patch_size)\n        ):\n            x_embedder = ColumnParallelLinear(\n                f_patch_size * patch_size * patch_size * self.in_channels,\n                self.dim,\n                bias=True,\n                gather_output=True,\n            )\n            all_x_embedder[f\"{patch_size}-{f_patch_size}\"] = x_embedder\n\n            final_layer = FinalLayer(\n                self.dim, patch_size * patch_size * f_patch_size * self.out_channels\n            )\n            all_final_layer[f\"{patch_size}-{f_patch_size}\"] = final_layer\n\n        self.all_x_embedder = nn.ModuleDict(all_x_embedder)\n        self.all_final_layer = nn.ModuleDict(all_final_layer)\n\n        self.noise_refiner = nn.ModuleList(\n            [\n                ZImageTransformerBlock(\n                    1000 + layer_id,\n                    self.dim,\n                    self.n_heads,\n                    arch_config.n_kv_heads,\n                    arch_config.norm_eps,\n                    arch_config.qk_norm,\n                    modulation=True,\n                    quant_config=quant_config,\n                    prefix=f\"noise_refiner.{layer_id}\",\n                )\n                for layer_id in range(arch_config.n_refiner_layers)\n            ]\n        )\n        self.context_refiner = nn.ModuleList(\n            [\n                ZImageTransformerBlock(\n                    layer_id,\n                    
self.dim,\n                    self.n_heads,\n                    arch_config.n_kv_heads,\n                    arch_config.norm_eps,\n                    arch_config.qk_norm,\n                    modulation=False,\n                    quant_config=quant_config,\n                    prefix=f\"context_refiner.{layer_id}\",\n                )\n                for layer_id in range(arch_config.n_refiner_layers)\n            ]\n        )\n        self.t_embedder = TimestepEmbedder(\n            min(self.dim, ADALN_EMBED_DIM), mid_size=1024\n        )\n\n        self.cap_embedder = nn.Sequential(\n            RMSNorm(arch_config.cap_feat_dim, eps=arch_config.norm_eps),\n            ReplicatedLinear(arch_config.cap_feat_dim, self.dim, bias=True),\n        )\n\n        self.x_pad_token = nn.Parameter(torch.empty((1, self.dim)))\n        self.cap_pad_token = nn.Parameter(torch.empty((1, self.dim)))\n\n        self.layers = nn.ModuleList(\n            [\n                ZImageTransformerBlock(\n                    layer_id,\n                    self.dim,\n                    self.n_heads,\n                    arch_config.n_kv_heads,\n                    arch_config.norm_eps,\n                    arch_config.qk_norm,\n                    quant_config=quant_config,\n                    prefix=f\"layers.{layer_id}\",\n                )\n                for layer_id in range(arch_config.num_layers)\n            ]\n        )\n        head_dim = self.dim // self.n_heads\n        assert head_dim == sum(arch_config.axes_dims)\n        self.axes_dims = arch_config.axes_dims\n        self.axes_lens = arch_config.axes_lens\n\n        self.rotary_emb = RopeEmbedder(\n            theta=self.rope_theta, axes_dims=self.axes_dims, axes_lens=self.axes_lens\n        )\n        self.layer_names = [\"layers\"]\n\n    def unpatchify(\n        self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size\n    ) -> List[torch.Tensor]:\n        pH = pW = patch_size\n        pF = 
f_patch_size\n        bsz = len(x)\n        assert len(size) == bsz\n        for i in range(bsz):\n            F, H, W = size[i]\n            ori_len = (F // pF) * (H // pH) * (W // pW)\n            # \"f h w pf ph pw c -> c (f pf) (h ph) (w pw)\"\n            x[i] = (\n                x[i][:ori_len]\n                .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)\n                .permute(6, 0, 3, 1, 4, 2, 5)\n                .reshape(self.out_channels, F, H, W)\n            )\n        return x\n\n    @staticmethod\n    def create_coordinate_grid(size, start=None, device=None):\n        if start is None:\n            start = (0 for _ in size)\n\n        axes = [\n            torch.arange(x0, x0 + span, dtype=torch.int32, device=device)\n            for x0, span in zip(start, size)\n        ]\n        grids = torch.meshgrid(axes, indexing=\"ij\")\n        return torch.stack(grids, dim=-1)\n\n    def patchify_and_embed(\n        self,\n        all_image: List[torch.Tensor],\n        all_cap_feats: List[torch.Tensor],\n        patch_size: int,\n        f_patch_size: int,\n    ):\n        assert len(all_image) == len(all_cap_feats) == 1\n\n        image = all_image[0]  # C, F, H, W\n        cap_feat = all_cap_feats[0]  # L, D\n        pH = pW = patch_size\n        pF = f_patch_size\n        device = image.device\n\n        all_image_out = []\n        all_image_size = []\n        all_cap_feats_out = []\n\n        # ------------ Process Caption ------------\n        cap_ori_len = cap_feat.size(0)\n        cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF\n\n        # padded feature\n        cap_padded_feat = torch.cat(\n            [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],\n            dim=0,\n        )\n        all_cap_feats_out.append(cap_padded_feat)\n\n        # ------------ Process Image ------------\n        C, F, H, W = image.size()\n        all_image_size.append((F, H, W))\n\n        F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // 
pW\n        image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)\n        # \"c f pf h ph w pw -> (f h w) (pf ph pw c)\"\n        image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(\n            F_tokens * H_tokens * W_tokens, pF * pH * pW * C\n        )\n        image_ori_len = image.size(0)\n        image_padding_len = (-image_ori_len) % SEQ_MULTI_OF\n\n        # padded feature\n        image_padded_feat = torch.cat(\n            [image, image[-1:].repeat(image_padding_len, 1)],\n            dim=0,\n        )\n        all_image_out.append(image_padded_feat)\n\n        return (\n            all_image_out,\n            all_cap_feats_out,\n            all_image_size,\n        )\n\n    def forward(\n        self,\n        hidden_states: List[torch.Tensor],\n        encoder_hidden_states: List[torch.Tensor],\n        timestep,\n        guidance=0,\n        patch_size=2,\n        f_patch_size=1,\n        freqs_cis=None,\n        **kwargs,\n    ):\n        assert patch_size in self.all_patch_size\n        assert f_patch_size in self.all_f_patch_size\n\n        x = hidden_states\n        cap_feats = encoder_hidden_states\n        timestep = 1000.0 - timestep\n        t = timestep\n        bsz = 1\n        device = x[0].device\n        t = self.t_embedder(t)\n        adaln_input = t.type_as(x)\n        (\n            x,\n            cap_feats,\n            x_size,\n        ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)\n\n        x = torch.cat(x, dim=0)\n        x, _ = self.all_x_embedder[f\"{patch_size}-{f_patch_size}\"](x)\n        x_freqs_cis = freqs_cis[1]\n\n        x = x.unsqueeze(0)\n        for layer in self.noise_refiner:\n            x = layer(x, x_freqs_cis, adaln_input)\n\n        cap_feats = torch.cat(cap_feats, dim=0)\n\n        cap_feats, _ = self.cap_embedder(cap_feats)\n\n        cap_freqs_cis = freqs_cis[0]\n\n        cap_feats = cap_feats.unsqueeze(0)\n        for layer in 
self.context_refiner:\n            cap_feats = layer(cap_feats, cap_freqs_cis)\n\n        unified = torch.cat([x, cap_feats], dim=1)\n        unified_freqs_cis = (\n            torch.cat([x_freqs_cis[0], cap_freqs_cis[0]], dim=0),\n            torch.cat([x_freqs_cis[1], cap_freqs_cis[1]], dim=0),\n        )\n\n        for layer in self.layers:\n            unified = layer(unified, unified_freqs_cis, adaln_input)\n\n        unified = self.all_final_layer[f\"{patch_size}-{f_patch_size}\"](\n            unified, adaln_input\n        )\n        unified = list(unified.unbind(dim=0))\n        x = self.unpatchify(unified, x_size, patch_size, f_patch_size)\n\n        return -x[0]\n\n\nEntryClass = ZImageTransformer2DModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/encoders/base.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom abc import ABC, abstractmethod\nfrom dataclasses import field\n\nimport torch\nfrom torch import nn\n\nfrom sglang.multimodal_gen.configs.models.encoders import (\n    BaseEncoderOutput,\n    ImageEncoderConfig,\n    TextEncoderConfig,\n)\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\n\n\nclass TextEncoder(nn.Module, ABC):\n    _fsdp_shard_conditions: list = field(default_factory=lambda: [])\n    _stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list)\n    _supported_attention_backends: set[AttentionBackendEnum] = (\n        TextEncoderConfig()._supported_attention_backends\n    )\n\n    def __init__(self, config: TextEncoderConfig) -> None:\n        super().__init__()\n        self.config = config\n        self._fsdp_shard_conditions = config.arch_config._fsdp_shard_conditions\n        self._stacked_params_mapping = config.arch_config.stacked_params_mapping\n        if not self.supported_attention_backends:\n            raise ValueError(\n                f\"Subclass {self.__class__.__name__} must define _supported_attention_backends\"\n            )\n\n    @abstractmethod\n    def forward(\n        self,\n        input_ids: torch.Tensor | None,\n        position_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        inputs_embeds: torch.Tensor | None = None,\n        output_hidden_states: bool | None = None,\n        **kwargs,\n    ) -> BaseEncoderOutput:\n        pass\n\n    @property\n    def supported_attention_backends(self) -> set[AttentionBackendEnum]:\n        return self._supported_attention_backends\n\n\nclass ImageEncoder(nn.Module, ABC):\n    _supported_attention_backends: set[AttentionBackendEnum] = (\n        ImageEncoderConfig()._supported_attention_backends\n    )\n\n    def __init__(self, config: ImageEncoderConfig) -> None:\n   
     super().__init__()\n        self.config = config\n        if not self.supported_attention_backends:\n            raise ValueError(\n                f\"Subclass {self.__class__.__name__} must define _supported_attention_backends\"\n            )\n\n    @abstractmethod\n    def forward(self, pixel_values: torch.Tensor, **kwargs) -> BaseEncoderOutput:\n        pass\n\n    @property\n    def supported_attention_backends(self) -> set[AttentionBackendEnum]:\n        return self._supported_attention_backends\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/encoders/bert.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# type: ignore\nimport os\n\nimport torch\nimport torch.nn as nn\nfrom transformers import BertModel, BertTokenizer\n\n\nclass HunyuanClip(nn.Module):\n    \"\"\"\n    Hunyuan clip code copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py\n    hunyuan's clip used BertModel and BertTokenizer, so we copy it.\n    \"\"\"\n\n    def __init__(self, model_dir, max_length=77):\n        super().__init__()\n\n        self.max_length = max_length\n        self.tokenizer = BertTokenizer.from_pretrained(\n            os.path.join(model_dir, \"tokenizer\")\n        )\n        self.text_encoder = BertModel.from_pretrained(\n            os.path.join(model_dir, \"clip_text_encoder\")\n        )\n\n    @torch.no_grad\n    def forward(self, prompts, with_mask=True):\n        self.device = next(self.text_encoder.parameters()).device\n        text_inputs = self.tokenizer(\n            prompts,\n            padding=\"max_length\",\n            max_length=self.max_length,\n            truncation=True,\n            return_attention_mask=True,\n            return_tensors=\"pt\",\n        )\n        prompt_embeds = self.text_encoder(\n            text_inputs.input_ids.to(self.device),\n            attention_mask=(\n                text_inputs.attention_mask.to(self.device) if with_mask else None\n            ),\n        )\n        return prompt_embeds.last_hidden_state, prompt_embeds.pooler_output\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/encoders/clip.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/clip.py\n# Adapted from transformers: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py\n\"\"\"Minimal implementation of CLIPVisionModel intended to be only used\nwithin a vision language model.\"\"\"\n\nfrom collections.abc import Iterable\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\n\nfrom sglang.multimodal_gen.configs.models.encoders import (\n    BaseEncoderOutput,\n    CLIPTextConfig,\n    CLIPVisionConfig,\n)\nfrom sglang.multimodal_gen.runtime.distributed import divide, get_tp_world_size\nfrom sglang.multimodal_gen.runtime.layers.activation import get_act_fn\nfrom sglang.multimodal_gen.runtime.layers.attention import LocalAttention\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    ColumnParallelLinear,\n    QKVParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig\n\n# TODO: support quantization\n# from vllm.model_executor.layers.quantization import QuantizationConfig\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader\nfrom sglang.multimodal_gen.runtime.models.encoders.base import ImageEncoder, TextEncoder\nfrom sglang.multimodal_gen.runtime.models.encoders.vision import (\n    resolve_visual_encoder_outputs,\n)\nfrom sglang.multimodal_gen.runtime.platforms import (\n    AttentionBackendEnum,\n    current_platform,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa\nclass CLIPVisionEmbeddings(nn.Module):\n\n    def __init__(self, config: 
CLIPVisionConfig):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n        assert self.image_size % self.patch_size == 0\n\n        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))\n\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            bias=False,\n        )\n\n        self.num_patches = (self.image_size // self.patch_size) ** 2\n        self.num_positions = self.num_patches + 1\n        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)\n        self.register_buffer(\n            \"position_ids\",\n            torch.arange(self.num_positions).expand((1, -1)),\n            persistent=False,\n        )\n\n    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:\n        batch_size = pixel_values.shape[0]\n        target_dtype = self.patch_embedding.weight.dtype\n        patch_embeds = self.patch_embedding(\n            pixel_values.to(dtype=target_dtype)\n        )  # shape = [*, width, grid, grid]\n        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)\n\n        class_embeds = self.class_embedding.expand(batch_size, 1, -1)\n        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)\n        embeddings = embeddings + self.position_embedding(self.position_ids)\n\n        return embeddings\n\n\nclass CLIPTextEmbeddings(nn.Module):\n\n    def __init__(self, config: CLIPTextConfig):\n        super().__init__()\n        self.config = config\n        embed_dim = config.hidden_size\n\n        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)\n        self.position_embedding = nn.Embedding(\n            config.max_position_embeddings, embed_dim\n        )\n\n        # position_ids 
(1, len position emb) is contiguous in memory; persistent=False keeps it out of the state dict\n        self.register_buffer(\n            \"position_ids\",\n            torch.arange(config.max_position_embeddings).expand((1, -1)),\n            persistent=False,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor | None = None,\n        position_ids: torch.LongTensor | None = None,\n        inputs_embeds: torch.FloatTensor | None = None,\n    ) -> torch.Tensor:\n        if input_ids is not None:\n            seq_length = input_ids.shape[-1]\n        elif inputs_embeds is not None:\n            seq_length = inputs_embeds.shape[-2]\n        else:\n            raise ValueError(\"Either input_ids or inputs_embeds must be provided.\")\n\n        max_position_embedding = self.position_embedding.weight.shape[0]\n\n        if seq_length > max_position_embedding:\n            raise ValueError(\n                f\"Sequence length must be less than or equal to max_position_embeddings (got `sequence length`: \"\n                f\"{seq_length} and max_position_embeddings: {max_position_embedding})\"\n            )\n\n        if position_ids is None:\n            position_ids = self.position_ids[:, :seq_length]\n\n        if inputs_embeds is None:\n            inputs_embeds = self.token_embedding(input_ids)\n\n        position_embeddings = self.position_embedding(position_ids)\n        embeddings = inputs_embeds + position_embeddings\n\n        return embeddings\n\n\nclass CLIPAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(\n        self,\n        config: CLIPVisionConfig | CLIPTextConfig,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.embed_dim // self.num_heads\n        if 
self.head_dim * self.num_heads != self.embed_dim:\n            raise ValueError(\n                \"embed_dim must be divisible by num_heads \"\n                f\"(got `embed_dim`: {self.embed_dim} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n            )\n        self.scale = self.head_dim**-0.5\n        self.dropout = config.attention_dropout\n\n        self.qkv_proj = QKVParallelLinear(\n            hidden_size=self.embed_dim,\n            head_size=self.head_dim,\n            total_num_heads=self.num_heads,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.qkv_proj\",\n        )\n\n        self.out_proj = RowParallelLinear(\n            input_size=self.embed_dim,\n            output_size=self.embed_dim,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.out_proj\",\n        )\n\n        self.tp_size = get_tp_world_size()\n        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)\n\n        self.attn = LocalAttention(\n            self.num_heads_per_partition,\n            self.head_dim,\n            self.num_heads_per_partition,\n            softmax_scale=self.scale,\n            causal=True,\n            supported_attention_backends=config._supported_attention_backends,\n        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor | None = None,\n    ):\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n\n        qkv_states, _ = self.qkv_proj(hidden_states)\n        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)\n        # use flash_attn_func\n        query_states = query_states.reshape(\n            query_states.shape[0],\n            query_states.shape[1],\n            
self.num_heads_per_partition,\n            self.head_dim,\n        )\n        key_states = key_states.reshape(\n            key_states.shape[0],\n            key_states.shape[1],\n            self.num_heads_per_partition,\n            self.head_dim,\n        )\n        value_states = value_states.reshape(\n            value_states.shape[0],\n            value_states.shape[1],\n            self.num_heads_per_partition,\n            self.head_dim,\n        )\n\n        if self.attn.backend == AttentionBackendEnum.TORCH_SDPA:\n            query_states = query_states.transpose(1, 2)  # [B, H, S, D]\n            key_states = key_states.transpose(1, 2)\n            value_states = value_states.transpose(1, 2)\n\n            if current_platform.is_rocm() or current_platform.is_musa():\n                # ROCm: Using both is_causal=True and attn_mask causes NaN.\n                # Use is_causal=True alone (padding mask not needed for CLIP\n                # since pooler_output comes from EOS token before padding).\n                # XXX (MUSA): Torch SDPA on MUSA currently does not support\n                # using both `attn_mask` and `is_causal=True` simultaneously.\n                attn_output = torch.nn.functional.scaled_dot_product_attention(\n                    query_states,\n                    key_states,\n                    value_states,\n                    attn_mask=None,\n                    is_causal=True,\n                    scale=self.scale,\n                )\n            else:\n                if attention_mask is not None:\n                    # SDPA requires [B, 1, 1, S] or [B, S, S] format mask\n                    if attention_mask.dim() == 2:\n                        attn_mask = attention_mask[:, None, None, :].to(\n                            dtype=query_states.dtype\n                        )\n                        attn_mask = (1.0 - attn_mask) * torch.finfo(\n                            query_states.dtype\n                        ).min\n          
          else:\n                        attn_mask = attention_mask\n                else:\n                    attn_mask = None\n\n                attn_output = torch.nn.functional.scaled_dot_product_attention(\n                    query_states,\n                    key_states,\n                    value_states,\n                    attn_mask=attn_mask,\n                    is_causal=attention_mask is None,\n                    scale=self.scale,\n                )\n            attn_output = attn_output.transpose(1, 2)\n        else:\n            # Use LocalAttention (doesn't support attention_mask, but maintains compatibility)\n            attn_output = self.attn(query_states, key_states, value_states)\n\n        attn_output = attn_output.reshape(\n            attn_output.shape[0],\n            attn_output.shape[1],\n            self.num_heads_per_partition * self.head_dim,\n        )\n        attn_output, _ = self.out_proj(attn_output)\n\n        return attn_output, None\n\n\nclass CLIPMLP(nn.Module):\n\n    def __init__(\n        self,\n        config: CLIPVisionConfig | CLIPTextConfig,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.config = config\n        self.activation_fn = get_act_fn(config.hidden_act)\n        self.fc1 = ColumnParallelLinear(\n            config.hidden_size,\n            config.intermediate_size,\n            bias=True,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.fc1\",\n        )\n        self.fc2 = RowParallelLinear(\n            config.intermediate_size,\n            config.hidden_size,\n            bias=True,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.fc2\",\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states, _ = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states, _ = 
self.fc2(hidden_states)\n\n        return hidden_states\n\n\nclass CLIPEncoderLayer(nn.Module):\n\n    def __init__(\n        self,\n        config: CLIPTextConfig | CLIPVisionConfig,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.self_attn = CLIPAttention(\n            config,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.self_attn\",\n        )\n        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.mlp = CLIPMLP(config, quant_config=quant_config, prefix=f\"{prefix}.mlp\")\n        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states, _ = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n\n\nclass CLIPEncoder(nn.Module):\n    \"\"\"\n    Transformer encoder consisting of `config.num_hidden_layers` self\n    attention layers. 
Each layer is a [`CLIPEncoderLayer`].\n\n    Args:\n        config: CLIPConfig\n    \"\"\"\n\n    def __init__(\n        self,\n        config: CLIPVisionConfig | CLIPTextConfig,\n        quant_config: QuantizationConfig | None = None,\n        num_hidden_layers_override: int | None = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n\n        self.config = config\n\n        if num_hidden_layers_override is None:\n            num_hidden_layers = config.num_hidden_layers\n        else:\n            num_hidden_layers = num_hidden_layers_override\n        self.layers = nn.ModuleList(\n            [\n                CLIPEncoderLayer(\n                    config=config,\n                    quant_config=quant_config,\n                    prefix=f\"{prefix}.layers.{layer_idx}\",\n                )\n                for layer_idx in range(num_hidden_layers)\n            ]\n        )\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        return_all_hidden_states: bool,\n        attention_mask: torch.Tensor | None = None,\n    ) -> torch.Tensor | list[torch.Tensor]:\n        hidden_states_pool = [inputs_embeds]\n        hidden_states = inputs_embeds\n\n        for idx, encoder_layer in enumerate(self.layers):\n            hidden_states = encoder_layer(\n                hidden_states,\n                attention_mask=attention_mask,\n            )\n            if return_all_hidden_states:\n                hidden_states_pool.append(hidden_states)\n        # If we have multiple feature sample layers, we return all hidden\n        # states in order and grab the ones we need by index.\n        if return_all_hidden_states:\n            return hidden_states_pool\n        return [hidden_states]\n\n\nclass CLIPTextTransformer(nn.Module):\n\n    def __init__(\n        self,\n        config: CLIPTextConfig,\n        quant_config: QuantizationConfig | None = None,\n        num_hidden_layers_override: int | None = None,\n        
prefix: str = \"\",\n    ):\n        super().__init__()\n        self.config = config\n        embed_dim = config.hidden_size\n\n        self.embeddings = CLIPTextEmbeddings(config)\n\n        self.encoder = CLIPEncoder(\n            config,\n            quant_config=quant_config,\n            num_hidden_layers_override=num_hidden_layers_override,\n            prefix=prefix,\n        )\n\n        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)\n\n        # For `pooled_output` computation\n        self.eos_token_id = config.eos_token_id\n\n    def forward(\n        self,\n        input_ids: torch.Tensor | None,\n        position_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        inputs_embeds: torch.Tensor | None = None,\n        output_hidden_states: bool | None = None,\n    ) -> BaseEncoderOutput:\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n\n        if input_ids is None:\n            raise ValueError(\"You have to specify input_ids\")\n\n        input_shape = input_ids.size()\n        input_ids = input_ids.view(-1, input_shape[-1])\n\n        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)\n\n        # CLIP's text model uses causal mask, prepare it here.\n        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324\n        # causal_attention_mask = _create_4d_causal_attention_mask(\n        #     input_shape, hidden_states.dtype, device=hidden_states.device\n        # )\n\n        # # expand attention_mask\n        # if attention_mask is not None and not self._use_flash_attention_2:\n        #     raise NotImplementedError(\"attention_mask is not supported for CLIPTextTransformer\")\n        #     # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        #     attention_mask = 
_prepare_4d_attention_mask(attention_mask, hidden_states.dtype)\n\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n            return_all_hidden_states=output_hidden_states,\n            attention_mask=attention_mask,\n        )\n\n        last_hidden_state = encoder_outputs[-1]\n        last_hidden_state = self.final_layer_norm(last_hidden_state)\n\n        if self.eos_token_id == 2:\n            # The `eos_token_id` was incorrect before PR #24773: Let's keep what has been done here.\n            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added\n            # ------------------------------------------------------------\n            # text_embeds.shape = [batch_size, sequence_length, transformer.width]\n            # take features from the eot embedding (eot_token is the highest number in each sequence)\n            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14\n            pooled_output = last_hidden_state[\n                torch.arange(\n                    last_hidden_state.shape[0], device=last_hidden_state.device\n                ),\n                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(\n                    dim=-1\n                ),\n            ]\n        else:\n            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)\n            pooled_output = last_hidden_state[\n                torch.arange(\n                    last_hidden_state.shape[0], device=last_hidden_state.device\n                ),\n                # We need to get the first position of `eos_token_id` value (`pad_token_ids` might be equal to `eos_token_id`)\n                # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. 
prepared by the tokenizer)\n                (\n                    input_ids.to(dtype=torch.int, device=last_hidden_state.device)\n                    == self.eos_token_id\n                )\n                .int()\n                .argmax(dim=-1),\n            ]\n\n        return BaseEncoderOutput(\n            last_hidden_state=last_hidden_state,\n            pooler_output=pooled_output,\n            hidden_states=encoder_outputs,\n            # attentions=encoder_outputs.attentions,\n        )\n\n\nclass CLIPTextModel(TextEncoder):\n\n    def __init__(\n        self,\n        config: CLIPTextConfig,\n    ) -> None:\n        super().__init__(config)\n        self.text_model = CLIPTextTransformer(\n            config=config, quant_config=config.quant_config, prefix=config.prefix\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor | None,\n        position_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        inputs_embeds: torch.Tensor | None = None,\n        output_hidden_states: bool | None = None,\n        **kwargs,\n    ) -> BaseEncoderOutput:\n\n        outputs: BaseEncoderOutput = self.text_model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            output_hidden_states=output_hidden_states,\n        )\n        return outputs\n\n    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:\n\n        # Define mapping for stacked parameters\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            (\"qkv_proj\", \"q_proj\", \"q\"),\n            (\"qkv_proj\", \"k_proj\", \"k\"),\n            (\"qkv_proj\", \"v_proj\", \"v\"),\n        ]\n        params_dict = dict(self.named_parameters())\n        loaded_params: set[str] = set()\n        for name, loaded_weight in weights:\n            # Handle q_proj, k_proj, v_proj -> qkv_proj mapping\n            for 
param_name, weight_name, shard_id in stacked_params_mapping:\n                if weight_name in name:\n                    # Replace the weight name with the parameter name\n                    model_param_name = name.replace(weight_name, param_name)\n\n                    if model_param_name in params_dict:\n                        param = params_dict[model_param_name]\n                        weight_loader = param.weight_loader\n                        weight_loader(param, loaded_weight, shard_id)\n                        loaded_params.add(model_param_name)\n                    break\n            else:\n                # Use default weight loader for all other parameters\n                if name in params_dict:\n                    param = params_dict[name]\n                    weight_loader = getattr(\n                        param, \"weight_loader\", default_weight_loader\n                    )\n                    weight_loader(param, loaded_weight)\n                    loaded_params.add(name)\n\n        return loaded_params\n\n\nclass CLIPVisionTransformer(nn.Module):\n\n    def __init__(\n        self,\n        config: CLIPVisionConfig,\n        quant_config: QuantizationConfig | None = None,\n        num_hidden_layers_override: int | None = None,\n        require_post_norm: bool | None = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n\n        self.config = config\n        embed_dim = config.hidden_size\n\n        self.embeddings = CLIPVisionEmbeddings(config)\n\n        # NOTE: This typo of \"layrnorm\" is not fixed on purpose to match\n        # the original transformers code and name of the model weights.\n        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)\n\n        self.encoder = CLIPEncoder(\n            config=config,\n            quant_config=quant_config,\n            num_hidden_layers_override=num_hidden_layers_override,\n            prefix=f\"{prefix}.encoder\",\n        )\n\n        
num_hidden_layers = config.num_hidden_layers\n        if len(self.encoder.layers) > config.num_hidden_layers:\n            raise ValueError(\n                f\"The original encoder only has {num_hidden_layers} \"\n                f\"layers, but you requested {len(self.encoder.layers)} layers.\"\n            )\n\n        # If possible, skip post_layernorm to conserve memory\n        if require_post_norm is None:\n            require_post_norm = len(self.encoder.layers) == num_hidden_layers\n\n        if require_post_norm:\n            self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)\n        else:\n            self.post_layernorm = None\n\n    def forward(\n        self,\n        pixel_values: torch.Tensor,\n        output_hidden_states: Optional[bool] = None,\n        feature_sample_layers: list[int] | None = None,\n    ) -> BaseEncoderOutput:\n\n        hidden_states = self.embeddings(pixel_values)\n        hidden_states = self.pre_layrnorm(hidden_states)\n\n        return_all_hidden_states = output_hidden_states or (\n            feature_sample_layers is not None\n        )\n\n        # Produces either the last layer output or all of the hidden states,\n        # depending on whether feature_sample_layers is set\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n            return_all_hidden_states=return_all_hidden_states,\n        )\n\n        if not return_all_hidden_states:\n            encoder_outputs = encoder_outputs[0]\n\n            # Handle post-norm (if applicable) and stack feature layers if needed\n            encoder_outputs = resolve_visual_encoder_outputs(\n                encoder_outputs,\n                feature_sample_layers,\n                self.post_layernorm,\n                self.config.num_hidden_layers,\n            )\n\n        if return_all_hidden_states:\n            return BaseEncoderOutput(hidden_states=encoder_outputs)\n\n        return 
BaseEncoderOutput(last_hidden_state=encoder_outputs)\n\n\nclass CLIPVisionModel(ImageEncoder):\n    config_class = CLIPVisionConfig\n    main_input_name = \"pixel_values\"\n    packed_modules_mapping = {\"qkv_proj\": [\"q_proj\", \"k_proj\", \"v_proj\"]}\n\n    def __init__(self, config: CLIPVisionConfig) -> None:\n        super().__init__(config)\n        self.vision_model = CLIPVisionTransformer(\n            config=config,\n            quant_config=config.quant_config,\n            num_hidden_layers_override=config.num_hidden_layers_override,\n            require_post_norm=config.require_post_norm,\n            prefix=f\"{config.prefix}.vision_model\",\n        )\n\n    def forward(\n        self,\n        pixel_values: torch.Tensor,\n        feature_sample_layers: list[int] | None = None,\n        output_hidden_states: Optional[bool] = None,\n        **kwargs,\n    ) -> BaseEncoderOutput:\n        base_encoder_output = self.vision_model(\n            pixel_values,\n            output_hidden_states=output_hidden_states,\n            feature_sample_layers=feature_sample_layers,\n        )\n\n        return base_encoder_output\n\n    @property\n    def device(self):\n        return next(self.parameters()).device\n\n    # (TODO) Add prefix argument for filtering out weights to be loaded\n    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986\n    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:\n\n        params_dict = dict(self.named_parameters())\n        loaded_params: set[str] = set()\n        layer_count = len(self.vision_model.encoder.layers)\n\n        for name, loaded_weight in weights:\n            if name.startswith(\"visual_projection\"):\n                continue\n            # post_layernorm is not needed in CLIPVisionModel\n            if (\n                name.startswith(\"vision_model.post_layernorm\")\n                and self.vision_model.post_layernorm is None\n            ):\n     
           continue\n\n            # omit layers when num_hidden_layers_override is set\n            if name.startswith(\"vision_model.encoder.layers\"):\n                layer_idx = int(name.split(\".\")[3])\n                if layer_idx >= layer_count:\n                    continue\n\n            for (\n                param_name,\n                weight_name,\n                shard_id,\n            ) in self.config.arch_config.stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n\n                param = params_dict[name]\n                weight_loader = param.weight_loader\n                weight_loader(param, loaded_weight, shard_id)\n                break\n            else:\n                param = params_dict[name]\n                weight_loader = getattr(param, \"weight_loader\", default_weight_loader)\n                weight_loader(param, loaded_weight)\n            loaded_params.add(name)\n        return loaded_params\n\n\nclass BertModel(CLIPTextModel):\n    pass\n\n\nEntryClass = [CLIPTextModel, CLIPVisionModel]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/encoders/gemma2.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n#\n# Gemma2 2B text encoder for SANA.\n#\n# This is a decoder-only language model used as a text encoder: we feed\n# in tokenized text and extract the final hidden states (not logits) as\n# the conditioning signal for SANA's cross-attention layers.\n#\n# Architecture follows google/gemma-2-2b-it:\n#   - 26 layers, alternating global / sliding-window attention\n#   - GQA with 8 query heads, 4 KV heads, head_dim=256\n#   - Pre/post attention + pre/post feedforward LayerNorm (Gemma2-style)\n#   - GeGLU activation (gelu_pytorch_tanh)\n#\n# Adapted from the Gemma3 text model implementation in this codebase.\n\nimport logging\nfrom typing import Any, Iterable\n\nimport torch\nfrom torch import nn\n\nfrom sglang.multimodal_gen.configs.models.encoders.base import BaseEncoderOutput\nfrom sglang.multimodal_gen.configs.models.encoders.gemma2 import Gemma2Config\nfrom sglang.multimodal_gen.runtime.distributed import get_tp_world_size\nfrom sglang.multimodal_gen.runtime.layers.activation import GeluAndMul\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    MergedColumnParallelLinear,\n    QKVParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig\nfrom sglang.multimodal_gen.runtime.layers.rotary_embedding import get_rope\nfrom sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import (\n    VocabParallelEmbedding,\n)\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader\n\nlogger = logging.getLogger(__name__)\n\n\nclass Gemma2RMSNorm(nn.Module):\n    def __init__(self, dim: int, eps: float = 1e-6):\n        super().__init__()\n        self.eps = eps\n        self.weight = nn.Parameter(torch.zeros(dim))\n\n    def _norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        output = self._norm(x.float())\n        output = output * (1.0 
+ self.weight.float())\n        return output.type_as(x)\n\n\nclass Gemma2MLP(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int,\n        intermediate_size: int,\n        hidden_act: str,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.gate_up_proj = MergedColumnParallelLinear(\n            input_size=hidden_size,\n            output_sizes=[intermediate_size] * 2,\n            bias=False,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.gate_up_proj\",\n        )\n        self.down_proj = RowParallelLinear(\n            input_size=intermediate_size,\n            output_size=hidden_size,\n            bias=False,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.down_proj\",\n        )\n        if hidden_act != \"gelu_pytorch_tanh\":\n            raise ValueError(\n                \"Gemma2 uses `gelu_pytorch_tanh` as the hidden activation. 
\"\n                f\"Got: {hidden_act}\"\n            )\n        self.act_fn = GeluAndMul(approximate=\"tanh\")\n\n    def forward(self, x):\n        x, _ = self.gate_up_proj(x)\n        x = self.act_fn(x)\n        x, _ = self.down_proj(x)\n        return x\n\n\nclass Gemma2Attention(nn.Module):\n    def __init__(\n        self,\n        layer_id: int,\n        config: Gemma2Config,\n        hidden_size: int,\n        num_heads: int,\n        num_kv_heads: int,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.layer_id = layer_id\n        self.hidden_size = hidden_size\n        tp_size = get_tp_world_size()\n        self.total_num_heads = num_heads\n        assert self.total_num_heads % tp_size == 0\n        self.num_heads = self.total_num_heads // tp_size\n        self.total_num_kv_heads = num_kv_heads\n        if self.total_num_kv_heads >= tp_size:\n            assert self.total_num_kv_heads % tp_size == 0\n        else:\n            assert tp_size % self.total_num_kv_heads == 0\n        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)\n\n        arch = config.arch_config\n        self.head_dim = arch.head_dim\n        self.q_size = self.num_heads * self.head_dim\n        self.kv_size = self.num_kv_heads * self.head_dim\n        self.scaling = arch.query_pre_attn_scalar**-0.5\n\n        self.qkv_proj = QKVParallelLinear(\n            hidden_size=hidden_size,\n            head_size=self.head_dim,\n            total_num_heads=self.total_num_heads,\n            total_num_kv_heads=self.total_num_kv_heads,\n            bias=arch.attention_bias,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.qkv_proj\",\n        )\n\n        self.o_proj = RowParallelLinear(\n            input_size=self.total_num_heads * self.head_dim,\n            output_size=hidden_size,\n            bias=arch.attention_bias,\n            quant_config=quant_config,\n   
         prefix=f\"{prefix}.o_proj\",\n        )\n\n        # Gemma2 interleaves sliding-window (even layers) and global (odd layers)\n        # attention, matching the HF reference where layer 0 is sliding. This\n        # pattern reduces memory for long sequences while keeping global context\n        # every other layer.\n        self.is_sliding = (layer_id % 2) == 0\n        if self.is_sliding:\n            self.sliding_window = arch.sliding_window\n        else:\n            self.sliding_window = None\n\n        self.rotary_emb = get_rope(\n            self.head_dim,\n            rotary_dim=self.head_dim,\n            max_position=arch.max_position_embeddings,\n            base=arch.rope_theta,\n            is_neox_style=True,\n        )\n\n    def forward(\n        self,\n        positions: torch.Tensor,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        qkv, _ = self.qkv_proj(hidden_states)\n        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)\n\n        batch_size, seq_len, _ = q.shape\n        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)\n        k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)\n        v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)\n\n        q, k = self.rotary_emb(positions, q, k)\n\n        query = q.transpose(1, 2)\n        key = k.transpose(1, 2)\n        value = v.transpose(1, 2)\n\n        attn_mask = torch.zeros(\n            (seq_len, seq_len), device=hidden_states.device, dtype=torch.float32\n        )\n        causal = torch.triu(\n            torch.ones(\n                (seq_len, seq_len), device=hidden_states.device, dtype=torch.bool\n            ),\n            diagonal=1,\n        )\n        attn_mask = attn_mask.masked_fill(causal, float(\"-inf\"))\n        if self.is_sliding and self.sliding_window is not None:\n            idx = torch.arange(seq_len, device=hidden_states.device)\n            # dist[i, j] = i - j: how far key j lies behind query i\n            dist = idx[:, None] - idx[None, :]\n   
         too_far = dist > self.sliding_window\n            attn_mask = attn_mask.masked_fill(too_far, float(\"-inf\"))\n\n        if attention_mask is not None:\n            key_pad = ~attention_mask.to(torch.bool)\n            attn_mask = attn_mask[None, None, :, :].expand(\n                batch_size, 1, seq_len, seq_len\n            )\n            attn_mask = attn_mask.masked_fill(\n                key_pad[:, None, None, :].expand(batch_size, 1, seq_len, seq_len),\n                float(\"-inf\"),\n            )\n\n        attn_kwargs = {\n            \"attn_mask\": attn_mask,\n            \"dropout_p\": 0.0,\n            \"is_causal\": False,\n            \"scale\": self.scaling,\n        }\n        if query.shape[1] != key.shape[1]:\n            attn_kwargs[\"enable_gqa\"] = True\n        attn_output = torch.nn.functional.scaled_dot_product_attention(\n            query, key, value, **attn_kwargs\n        )\n\n        # NOTE: Gemma2 specifies attn_logit_softcapping (tanh(logits/cap)*cap) but\n        # PyTorch's scaled_dot_product_attention does not support it natively.\n        # For short text-encoder sequences (~300 tokens), the quality impact is\n        # negligible. 
A custom attention kernel would be needed for full fidelity.\n\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(\n            batch_size, seq_len, self.num_heads * self.head_dim\n        )\n\n        output, _ = self.o_proj(attn_output)\n        return output\n\n\nclass Gemma2DecoderLayer(nn.Module):\n    def __init__(\n        self,\n        layer_id: int,\n        config: Gemma2Config,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        arch = config.arch_config\n        self.hidden_size = arch.hidden_size\n        self.self_attn = Gemma2Attention(\n            layer_id=layer_id,\n            config=config,\n            hidden_size=self.hidden_size,\n            num_heads=arch.num_attention_heads,\n            num_kv_heads=arch.num_key_value_heads,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.self_attn\",\n        )\n        self.mlp = Gemma2MLP(\n            hidden_size=self.hidden_size,\n            intermediate_size=arch.intermediate_size,\n            hidden_act=arch.hidden_activation,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.mlp\",\n        )\n        self.input_layernorm = Gemma2RMSNorm(self.hidden_size, eps=arch.rms_norm_eps)\n        self.post_attention_layernorm = Gemma2RMSNorm(\n            self.hidden_size, eps=arch.rms_norm_eps\n        )\n        self.pre_feedforward_layernorm = Gemma2RMSNorm(\n            self.hidden_size, eps=arch.rms_norm_eps\n        )\n        self.post_feedforward_layernorm = Gemma2RMSNorm(\n            self.hidden_size, eps=arch.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        positions: torch.Tensor,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        residual = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n        
hidden_states = self.self_attn(positions, hidden_states, attention_mask)\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.pre_feedforward_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = self.post_feedforward_layernorm(hidden_states)\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n\n\nclass Gemma2Model(nn.Module):\n    \"\"\"Gemma2 text encoder model for SANA pipeline.\"\"\"\n\n    _fsdp_shard_conditions = []\n\n    def __init__(self, config: Gemma2Config, **kwargs):\n        super().__init__()\n        self.config = config\n        arch = config.arch_config\n        self.quant_config = None\n\n        self.vocab_size = arch.vocab_size\n        self.embed_tokens = VocabParallelEmbedding(\n            self.vocab_size,\n            arch.hidden_size,\n            org_num_embeddings=arch.vocab_size,\n            quant_config=self.quant_config,\n        )\n        self.embed_scale = arch.hidden_size**0.5\n\n        self.layers = nn.ModuleList(\n            [\n                Gemma2DecoderLayer(\n                    layer_id=i,\n                    config=config,\n                    quant_config=self.quant_config,\n                    prefix=f\"model.layers.{i}\",\n                )\n                for i in range(arch.num_hidden_layers)\n            ]\n        )\n\n        self.norm = Gemma2RMSNorm(arch.hidden_size, eps=arch.rms_norm_eps)\n\n    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:\n        return self.embed_tokens(input_ids) * self.embed_scale\n\n    def forward(\n        self,\n        input_ids: torch.Tensor | None = None,\n        position_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        inputs_embeds: torch.Tensor | None = None,\n        output_hidden_states: bool | 
None = None,\n        **kwargs,\n    ) -> BaseEncoderOutput:\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n                \"You must specify exactly one of input_ids or inputs_embeds\"\n            )\n\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else getattr(self.config.arch_config, \"output_hidden_states\", False)\n        )\n\n        if inputs_embeds is not None:\n            hidden_states = inputs_embeds\n        else:\n            hidden_states = self.get_input_embeddings(input_ids)\n\n        if position_ids is None:\n            position_ids = torch.arange(\n                0, hidden_states.shape[1], device=hidden_states.device\n            ).unsqueeze(0)\n\n        all_hidden_states: tuple[Any, ...] | None = () if output_hidden_states else None\n\n        for layer in self.layers:\n            if all_hidden_states is not None:\n                all_hidden_states += (hidden_states,)\n\n            hidden_states = layer(position_ids, hidden_states, attention_mask)\n\n        hidden_states = self.norm(hidden_states)\n\n        if all_hidden_states is not None:\n            all_hidden_states += (hidden_states,)\n\n        return BaseEncoderOutput(\n            last_hidden_state=hidden_states,\n            hidden_states=all_hidden_states,\n        )\n\n    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:\n        params_dict = dict(self.named_parameters())\n        loaded_params: set[str] = set()\n\n        stacked_params_mapping = getattr(\n            self.config.arch_config, \"stacked_params_mapping\", None\n        )\n        if stacked_params_mapping is None:\n            stacked_params_mapping = [\n                (\".qkv_proj\", \".q_proj\", \"q\"),\n                (\".qkv_proj\", \".k_proj\", \"k\"),\n                (\".qkv_proj\", \".v_proj\", \"v\"),\n                (\".gate_up_proj\", 
\".gate_proj\", \"0\"),\n                (\".gate_up_proj\", \".up_proj\", \"1\"),\n            ]\n\n        for name, loaded_weight in weights:\n            if \"rotary_emb.inv_freq\" in name:\n                continue\n\n            # HF Gemma2Model stores weights as model.layers.X... / model.embed_tokens...\n            # Strip \"model.\" prefix if present to match our naming\n            if name.startswith(\"model.\"):\n                name = name[len(\"model.\") :]\n\n            for param_name, weight_name, shard_id in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n\n                if name not in params_dict:\n                    continue\n\n                param = params_dict[name]\n                weight_loader = param.weight_loader\n                self._load_with_shard_id(weight_loader, param, loaded_weight, shard_id)\n                break\n            else:\n                if name not in params_dict:\n                    continue\n                param = params_dict[name]\n                weight_loader = getattr(param, \"weight_loader\", default_weight_loader)\n                weight_loader(param, loaded_weight)\n\n            loaded_params.add(name)\n        return loaded_params\n\n    @staticmethod\n    def _load_with_shard_id(weight_loader, param, loaded_weight, shard_id):\n        try:\n            weight_loader(param, loaded_weight, shard_id)\n            return\n        except (AssertionError, TypeError):\n            pass\n\n        if isinstance(shard_id, str):\n            mapping = {\"q\": 0, \"k\": 1, \"v\": 2}\n            if shard_id in mapping:\n                weight_loader(param, loaded_weight, mapping[shard_id])\n                return\n            if shard_id.isdigit():\n                weight_loader(param, loaded_weight, int(shard_id))\n                return\n        elif isinstance(shard_id, int):\n            mapping 
= {0: \"q\", 1: \"k\", 2: \"v\"}\n            if shard_id in mapping:\n                weight_loader(param, loaded_weight, mapping[shard_id])\n                return\n\n        raise TypeError(\n            f\"Unsupported shard_id={shard_id!r} for weight_loader={weight_loader}\"\n        )\n\n\nEntryClass = Gemma2Model\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from sglang: python/sglang/srt/models/gemma3_causal.py\n\nimport logging\nfrom functools import partial\nfrom typing import Any, Iterable, Optional, Set, Tuple\n\nimport torch\nfrom torch import nn\n\nfrom sglang.multimodal_gen.configs.models.encoders.base import BaseEncoderOutput\nfrom sglang.multimodal_gen.configs.models.encoders.gemma_3 import Gemma3Config\nfrom sglang.multimodal_gen.runtime.distributed import get_tp_world_size\nfrom sglang.multimodal_gen.runtime.layers.activation import GeluAndMul\nfrom sglang.multimodal_gen.runtime.layers.attention import LocalAttention\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    ColumnParallelLinear,\n    MergedColumnParallelLinear,\n    QKVParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig\nfrom sglang.multimodal_gen.runtime.layers.rotary_embedding import get_rope\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader\nfrom sglang.multimodal_gen.runtime.utils.common import add_prefix\n\nlogger = logging.getLogger(__name__)\n\n\ndef get_attention_sliding_window_size(config):\n    return config.sliding_window - 1\n\n\nclass Gemma3RMSNorm(nn.Module):\n    def __init__(self, dim: int, eps: float = 1e-6):\n        super().__init__()\n        self.eps = eps\n        self.weight = nn.Parameter(torch.zeros(dim))\n\n    def _norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        output = self._norm(x.float())\n        output = output * (1.0 + self.weight.float())\n        return output.type_as(x)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.eps}\"\n\n\nclass Gemma3MLP(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int,\n        intermediate_size: 
int,\n        hidden_act: str,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.gate_up_proj = MergedColumnParallelLinear(\n            input_size=hidden_size,\n            output_sizes=[intermediate_size] * 2,\n            bias=False,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.gate_up_proj\",\n        )\n        self.down_proj = RowParallelLinear(\n            input_size=intermediate_size,\n            output_size=hidden_size,\n            bias=False,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.down_proj\",\n        )\n        if hidden_act != \"gelu_pytorch_tanh\":\n            raise ValueError(\n                \"Gemma3 uses `gelu_pytorch_tanh` as the hidden activation \"\n                \"function. Please set `hidden_activation` to \"\n                \"`gelu_pytorch_tanh`.\"\n            )\n        self.act_fn = GeluAndMul(approximate=\"tanh\")\n\n    def forward(self, x):\n        x, _ = self.gate_up_proj(x)\n        x = self.act_fn(x)\n        x, _ = self.down_proj(x)\n        return x\n\n\ndef _rotate_half(x: torch.Tensor) -> torch.Tensor:\n    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\nclass Gemma3Attention(nn.Module):\n    def __init__(\n        self,\n        layer_id: int,\n        config: Gemma3Config,\n        hidden_size: int,\n        num_heads: int,\n        num_kv_heads: int,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.layer_id = layer_id\n        self.hidden_size = hidden_size\n        tp_size = get_tp_world_size()\n        self.total_num_heads = num_heads\n        assert self.total_num_heads % tp_size == 0\n        self.num_heads = self.total_num_heads // tp_size\n        self.total_num_kv_heads = num_kv_heads\n        if 
self.total_num_kv_heads >= tp_size:\n            assert self.total_num_kv_heads % tp_size == 0\n        else:\n            assert tp_size % self.total_num_kv_heads == 0\n        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)\n\n        self.head_dim = getattr(\n            config.text_config, \"head_dim\", self.hidden_size // self.total_num_heads\n        )\n\n        self.q_size = self.num_heads * self.head_dim\n        self.kv_size = self.num_kv_heads * self.head_dim\n        self.scaling = config.text_config.query_pre_attn_scalar**-0.5\n\n        self.qkv_proj = QKVParallelLinear(\n            hidden_size=hidden_size,\n            head_size=self.head_dim,\n            total_num_heads=self.total_num_heads,\n            total_num_kv_heads=self.total_num_kv_heads,\n            bias=config.text_config.attention_bias,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.qkv_proj\",\n        )\n\n        self.o_proj = RowParallelLinear(\n            input_size=self.total_num_heads * self.head_dim,\n            output_size=hidden_size,\n            bias=config.text_config.attention_bias,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.o_proj\",\n        )\n\n        self.is_sliding = (\n            config.text_config.layer_types[layer_id] == \"sliding_attention\"\n        )\n\n        # Initialize the rotary embedding.\n        if self.is_sliding:\n            # Local attention.\n            self.rope_theta = config.text_config.rope_local_base_freq\n            rope_scaling = None  # Default\n            # sliding window\n            self.sliding_window = get_attention_sliding_window_size(config.text_config)\n            # (left, right) = (window, 0) effectively for causal\n            self.window_size = (self.sliding_window, 0)\n        else:\n            # Global attention.\n            self.rope_theta = config.text_config.rope_theta\n            rope_scaling = config.text_config.rope_scaling\n            
self.sliding_window = None\n            self.window_size = (-1, -1)\n\n        self.rotary_emb = get_rope(\n            self.head_dim,\n            rotary_dim=self.head_dim,\n            max_position=config.text_config.max_position_embeddings,\n            base=self.rope_theta,\n            rope_scaling=rope_scaling,\n            is_neox_style=True,\n        )\n\n        # NOTE(gmixiaojin): The shared RotaryEmbedding above computes inv_freq on\n        # GPU and uses the x1*cos - x2*sin formula, which causes slight\n        # numerical differences vs HuggingFace (see the NOTE in\n        # rotary_embedding.py:_compute_inv_freq).  For HF-exact alignment we\n        # precompute inv_freq on CPU and use rotate_half in self.rotary_emb().\n        freq_indices = (\n            torch.arange(0, self.head_dim, 2, dtype=torch.int64).float() / self.head_dim\n        )\n        inv_freq = 1.0 / (self.rope_theta**freq_indices)\n        if rope_scaling and rope_scaling.get(\"factor\"):\n            inv_freq = inv_freq / float(rope_scaling[\"factor\"])\n        self.register_buffer(\"_hf_inv_freq\", inv_freq, persistent=False)\n\n        # Local Attention not support attention mask, we use global attention instead.\n        # self.attn = LocalAttention(\n        #     self.num_heads,\n        #     self.head_dim,\n        #     self.num_kv_heads,\n        #     softmax_scale=self.scaling,\n        #     causal=True,\n        #     supported_attention_backends=config._supported_attention_backends,\n        #     window_size=self.window_size,\n        # )\n\n        # Gemma3 adds normalization for q and k\n        self.q_norm = Gemma3RMSNorm(\n            dim=self.head_dim, eps=config.text_config.rms_norm_eps\n        )\n        self.k_norm = Gemma3RMSNorm(\n            dim=self.head_dim, eps=config.text_config.rms_norm_eps\n        )\n\n    def rotary_emb(self, positions, q, k):\n        \"\"\"Apply RoPE using HF-exact formula with precomputed inv_freq.\"\"\"\n        
positions_flat = positions.flatten().float()\n        num_tokens = positions_flat.shape[0]\n\n        with torch.autocast(device_type=q.device.type, enabled=False):\n            freqs = torch.outer(positions_flat, self._hf_inv_freq.float())\n            emb = freqs.repeat(1, 2)\n            cos = emb.cos().to(q.dtype).unsqueeze(1)\n            sin = emb.sin().to(q.dtype).unsqueeze(1)\n\n        q = q.reshape(num_tokens, -1, self.head_dim)\n        k = k.reshape(num_tokens, -1, self.head_dim)\n        q = q * cos + _rotate_half(q) * sin\n        k = k * cos + _rotate_half(k) * sin\n        return q, k\n\n    def forward(\n        self,\n        positions: torch.Tensor,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        qkv, _ = self.qkv_proj(hidden_states)\n        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)\n\n        batch_size, seq_len, _ = q.shape\n        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)\n        k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)\n        v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)\n\n        # Apply QK Norm\n        q = self.q_norm(q)\n        k = self.k_norm(k)\n\n        # Apply RoPE\n        q, k = self.rotary_emb(positions, q, k)\n        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)\n        k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)\n\n        # TODO(FlamingoPg): Support LocalAttention\n        query = q.transpose(1, 2)\n        key = k.transpose(1, 2)\n        value = v.transpose(1, 2)\n\n        min_val = torch.finfo(query.dtype).min\n        attn_mask = torch.zeros(\n            (seq_len, seq_len),\n            device=hidden_states.device,\n            dtype=query.dtype,\n        )\n        causal = torch.triu(\n            torch.ones(\n                (seq_len, seq_len), device=hidden_states.device, dtype=torch.bool\n        
    ),\n            diagonal=1,\n        )\n        attn_mask = attn_mask.masked_fill(causal, min_val)\n        if self.is_sliding and self.sliding_window is not None:\n            idx = torch.arange(seq_len, device=hidden_states.device)\n            # dist[i, j] = i - j: how far key j lies behind query i.\n            dist = idx[:, None] - idx[None, :]\n            too_far = dist > self.sliding_window\n            attn_mask = attn_mask.masked_fill(too_far, min_val)\n\n        attn_mask = attn_mask[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)\n        if attention_mask is not None:\n            # Mask out padded key positions.\n            key_pad = ~attention_mask.to(torch.bool)\n            attn_mask = attn_mask.masked_fill(\n                key_pad[:, None, None, :].expand(batch_size, 1, seq_len, seq_len),\n                min_val,\n            )\n\n        attn_kwargs = {\n            \"attn_mask\": attn_mask,\n            \"dropout_p\": 0.0,\n            \"is_causal\": False,\n            \"scale\": self.scaling,\n        }\n        if query.shape[1] != key.shape[1]:\n            attn_kwargs[\"enable_gqa\"] = True\n        attn_output = torch.nn.functional.scaled_dot_product_attention(\n            query, key, value, **attn_kwargs\n        )\n        attn_output = attn_output.transpose(1, 2)\n\n        attn_output = attn_output.reshape(\n            batch_size, seq_len, self.num_heads * self.head_dim\n        )\n\n        output, _ = self.o_proj(attn_output)\n        return output\n\n\nclass Gemma3DecoderLayer(nn.Module):\n    def __init__(\n        self,\n        layer_id: int,\n        config: Gemma3Config,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.hidden_size = config.text_config.hidden_size\n        self.self_attn = Gemma3Attention(\n            layer_id=layer_id,\n            config=config,\n            hidden_size=self.hidden_size,\n            num_heads=config.text_config.num_attention_heads,\n            num_kv_heads=getattr(\n                config.text_config,\n                \"num_key_value_heads\",\n        
        config.text_config.num_attention_heads,\n            ),\n            quant_config=quant_config,\n            prefix=f\"{prefix}.self_attn\",\n        )\n        self.mlp = Gemma3MLP(\n            hidden_size=self.hidden_size,\n            intermediate_size=config.text_config.intermediate_size,\n            hidden_act=config.text_config.hidden_activation,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.mlp\",\n        )\n        self.input_layernorm = Gemma3RMSNorm(\n            config.text_config.hidden_size, eps=config.text_config.rms_norm_eps\n        )\n        self.post_attention_layernorm = Gemma3RMSNorm(\n            config.text_config.hidden_size, eps=config.text_config.rms_norm_eps\n        )\n        self.pre_feedforward_layernorm = Gemma3RMSNorm(\n            config.text_config.hidden_size, eps=config.text_config.rms_norm_eps\n        )\n        self.post_feedforward_layernorm = Gemma3RMSNorm(\n            config.text_config.hidden_size, eps=config.text_config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        positions: torch.Tensor,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor | None,\n        attention_mask: torch.Tensor | None = None,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        # Self Attention\n        # Gemma3 uses \"sandwich norm\":\n        # x = x + norm(attn(norm(x)))\n        # So we treat input hidden_states as the residual base.\n\n        if residual is not None:\n            hidden_states = hidden_states + residual\n            residual = None\n\n        residual_input = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n\n        hidden_states = self.self_attn(\n            positions=positions,\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n        )\n\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = residual_input + hidden_states\n\n        # 
MLP\n        residual_mlp = hidden_states\n        hidden_states = self.pre_feedforward_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = self.post_feedforward_layernorm(hidden_states)\n        hidden_states = residual_mlp + hidden_states\n\n        return hidden_states, None\n\n\nclass Gemma3TextScaledWordEmbedding(nn.Embedding):\n    def __init__(\n        self,\n        num_embeddings: int,\n        embedding_dim: int,\n        padding_idx: int,\n        embed_scale: Optional[float] = 1.0,\n    ):\n        super().__init__(num_embeddings, embedding_dim, padding_idx)\n        self.embed_scale = embed_scale\n\n    def forward(self, input_ids: torch.Tensor):\n        return super().forward(input_ids) * self.embed_scale\n\n\n# --- Siglip Vision Model Implementation ---\n\n\nclass QuickGELU(nn.Module):\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return x * torch.sigmoid(1.702 * x)\n\n\nclass SiglipVisionEmbeddings(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            padding=\"valid\",\n        )\n\n        self.num_patches = (self.image_size // self.patch_size) ** 2\n        self.num_positions = self.num_patches\n        # Use simple Embedding for position embeddings (usually small enough)\n        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)\n        self.register_buffer(\n            \"position_ids\",\n            torch.arange(self.num_positions).expand((1, -1)),\n            persistent=False,\n        )\n\n    def forward(self, pixel_values: torch.Tensor) -> 
torch.Tensor:\n        target_dtype = self.patch_embedding.weight.dtype\n        patch_embeds = self.patch_embedding(\n            pixel_values.to(dtype=target_dtype)\n        )  # shape = [*, width, grid, grid]\n        embeddings = patch_embeds.flatten(2).transpose(1, 2)\n        embeddings = embeddings + self.position_embedding(self.position_ids)\n\n        return embeddings\n\n\nclass SiglipMLP(nn.Module):\n    def __init__(\n        self,\n        config,\n        act_layer: type[nn.Module] = QuickGELU,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.fc1 = ColumnParallelLinear(\n            config.hidden_size,\n            config.intermediate_size,\n            quant_config=quant_config,\n            prefix=add_prefix(\"fc1\", prefix),\n        )\n        self.act = act_layer()\n        self.fc2 = RowParallelLinear(\n            config.intermediate_size,\n            config.hidden_size,\n            quant_config=quant_config,\n            prefix=add_prefix(\"fc2\", prefix),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x_parallel, _ = self.fc1(x)\n        x_parallel = self.act(x_parallel)\n        x, _ = self.fc2(x_parallel)\n        return x\n\n\nclass SiglipAttention(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int,\n        num_heads: int,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.num_heads = num_heads\n        tp_size = get_tp_world_size()\n        self.head_dim = hidden_size // num_heads\n        self.num_heads_per_partition = num_heads // tp_size\n        self.scaling = self.head_dim**-0.5\n\n        self.qkv_proj = QKVParallelLinear(\n            hidden_size=hidden_size,\n            head_size=self.head_dim,\n            total_num_heads=num_heads,\n            
total_num_kv_heads=num_heads,\n            bias=True,\n            quant_config=quant_config,\n            prefix=add_prefix(\"qkv_proj\", prefix),\n        )\n\n        self.out_proj = RowParallelLinear(\n            input_size=hidden_size,\n            output_size=hidden_size,\n            bias=True,\n            quant_config=quant_config,\n            prefix=add_prefix(\"out_proj\", prefix),\n        )\n\n        self.attn = LocalAttention(\n            num_heads=self.num_heads_per_partition,\n            head_size=self.head_dim,\n            num_kv_heads=self.num_heads_per_partition,\n            softmax_scale=self.scaling,\n            causal=False,  # Bidirectional for Vision\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        qkv, _ = self.qkv_proj(hidden_states)\n        q, k, v = qkv.split([self.hidden_size // get_tp_world_size()] * 3, dim=-1)\n\n        batch_size, seq_len, _ = q.shape\n        q = q.view(batch_size, seq_len, self.num_heads_per_partition, self.head_dim)\n        k = k.view(batch_size, seq_len, self.num_heads_per_partition, self.head_dim)\n        v = v.view(batch_size, seq_len, self.num_heads_per_partition, self.head_dim)\n\n        attn_output = self.attn(q, k, v)\n\n        attn_output = attn_output.reshape(\n            batch_size, seq_len, self.hidden_size // get_tp_world_size()\n        )\n\n        output, _ = self.out_proj(attn_output)\n        return output\n\n\nclass SiglipEncoderLayer(nn.Module):\n    def __init__(\n        self,\n        config,\n        act_layer: type[nn.Module] = QuickGELU,\n        norm_layer: type[nn.Module] = None,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        if norm_layer is None:\n            norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)\n        self.layer_norm1 = norm_layer(config.hidden_size)\n        self.layer_norm2 = 
norm_layer(config.hidden_size)\n        self.self_attn = SiglipAttention(\n            hidden_size=config.hidden_size,\n            num_heads=config.num_attention_heads,\n            quant_config=quant_config,\n            prefix=add_prefix(\"self_attn\", prefix),\n        )\n        self.mlp = SiglipMLP(\n            config,\n            act_layer=act_layer,\n            quant_config=quant_config,\n            prefix=add_prefix(\"mlp\", prefix),\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n    ) -> torch.Tensor:\n        residual = hidden_states\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states = self.self_attn(hidden_states)\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n        return hidden_states\n\n\nclass SiglipEncoder(nn.Module):\n    def __init__(\n        self,\n        config,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.config = config\n        num_hidden_layers = config.num_hidden_layers\n        norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)\n        self.layers = nn.ModuleList(\n            [\n                SiglipEncoderLayer(\n                    config=config,\n                    norm_layer=norm_layer,\n                    quant_config=quant_config,\n                    prefix=add_prefix(f\"layers.{layer_idx}\", prefix),\n                )\n                for layer_idx in range(num_hidden_layers)\n            ]\n        )\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n    ) -> torch.Tensor:\n        hidden_states = inputs_embeds\n        for encoder_layer in self.layers:\n            hidden_states = encoder_layer(hidden_states)\n        
return hidden_states\n\n\nclass SiglipVisionTransformer(nn.Module):\n    def __init__(\n        self,\n        config,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.config = config\n        embed_dim = config.hidden_size\n        self.embeddings = SiglipVisionEmbeddings(config)\n        self.encoder = SiglipEncoder(\n            config=config,\n            quant_config=quant_config,\n            prefix=add_prefix(\"encoder\", prefix),\n        )\n        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)\n\n    @property\n    def device(self) -> torch.device:\n        return self.encoder.layers[0].layer_norm1.weight.device\n\n    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.embeddings(pixel_values.to(self.device))\n        last_hidden_state = self.encoder(inputs_embeds=hidden_states)\n        last_hidden_state = self.post_layernorm(last_hidden_state)\n        return last_hidden_state\n\n\nclass SiglipVisionModel(nn.Module):\n    def __init__(\n        self,\n        config,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.vision_model = SiglipVisionTransformer(\n            config, quant_config, prefix=add_prefix(\"vision_model\", prefix)\n        )\n\n    @property\n    def device(self) -> torch.device:\n        return self.vision_model.device\n\n    def forward(self, pixel_values: torch.Tensor):\n        return self.vision_model(pixel_values)\n\n\nclass Gemma3MultiModalProjector(nn.Module):\n    \"\"\"Projector for Gemma3 multimodal.\"\"\"\n\n    def __init__(self, config: Gemma3Config):\n        super().__init__()\n\n        self.mm_input_projection_weight = nn.Parameter(\n            torch.zeros(\n                config.vision_config.hidden_size, config.text_config.hidden_size\n            )\n        
)\n\n        self.mm_soft_emb_norm = Gemma3RMSNorm(\n            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps\n        )\n\n        self.patches_per_image = int(\n            config.vision_config.image_size // config.vision_config.patch_size\n        )\n        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)\n        self.kernel_size = self.patches_per_image // self.tokens_per_side\n        self.avg_pool = nn.AvgPool2d(\n            kernel_size=self.kernel_size, stride=self.kernel_size\n        )\n\n    def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor:\n        batch_size, seq_length, hidden_size = vision_outputs.shape\n\n        # Reshape for pooling\n        reshaped_vision_outputs = vision_outputs.transpose(1, 2)\n        reshaped_vision_outputs = reshaped_vision_outputs.reshape(\n            batch_size, hidden_size, self.patches_per_image, self.patches_per_image\n        )\n        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()\n\n        # Apply pooling\n        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)\n        pooled_vision_outputs = pooled_vision_outputs.flatten(2)\n        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)\n\n        # Apply normalization\n        normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)\n\n        # Project to text embedding space\n        projected_vision_outputs = torch.matmul(\n            normed_vision_outputs, self.mm_input_projection_weight\n        )\n\n        return projected_vision_outputs.type_as(vision_outputs)\n\n\nclass Gemma3TextModel(nn.Module):\n    def __init__(self, config: Gemma3Config):\n        super().__init__()\n        self.config = config\n        # TODO(yinfan.1024) support text encoding model quant later\n        self.quant_config = None\n\n        # Use VocabParallelEmbedding\n        from sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import (\n            
VocabParallelEmbedding,\n        )\n\n        self.vocab_size = config.text_config.vocab_size\n        self.embed_tokens = VocabParallelEmbedding(\n            self.vocab_size,\n            config.text_config.hidden_size,\n            org_num_embeddings=config.text_config.vocab_size,\n            quant_config=self.quant_config,\n        )\n        self.embed_scale = config.text_config.hidden_size**0.5\n\n        self.layers = nn.ModuleList(\n            [\n                Gemma3DecoderLayer(\n                    layer_id=i,\n                    config=config,\n                    quant_config=self.quant_config,\n                    prefix=f\"{config.text_config.prefix}.layers.{i}\",\n                )\n                for i in range(config.text_config.num_hidden_layers)\n            ]\n        )\n\n        self.norm = Gemma3RMSNorm(\n            config.text_config.hidden_size, eps=config.text_config.rms_norm_eps\n        )\n\n    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:\n        out = self.embed_tokens(input_ids)\n        return out * torch.tensor(self.embed_scale, device=out.device, dtype=out.dtype)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor | None,\n        position_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        inputs_embeds: torch.Tensor | None = None,\n        output_hidden_states: bool | None = None,\n        **kwargs,\n    ) -> BaseEncoderOutput:\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n\n        if inputs_embeds is not None:\n            hidden_states = inputs_embeds\n        else:\n            hidden_states = self.get_input_embeddings(input_ids)\n\n        residual = None\n\n        if position_ids is None:\n            position_ids = torch.arange(\n                0, hidden_states.shape[1], 
device=hidden_states.device\n            ).unsqueeze(0)\n\n        all_hidden_states: tuple[Any, ...] | None = () if output_hidden_states else None\n\n        for layer in self.layers:\n            if all_hidden_states is not None:\n                all_hidden_states += (hidden_states,)\n\n            hidden_states, residual = layer(\n                position_ids,\n                hidden_states,\n                residual,\n                attention_mask=attention_mask,\n            )\n\n        hidden_states = self.norm(hidden_states)\n\n        if all_hidden_states is not None:\n            all_hidden_states += (hidden_states,)\n\n        output = BaseEncoderOutput(\n            last_hidden_state=hidden_states,\n            hidden_states=all_hidden_states,\n        )\n\n        return output\n\n    def load_weights(self, weights: Any) -> set[str]:\n        # Copied from LlamaModel.load_weights but adapted\n        params_dict = dict(self.named_parameters())\n        loaded_params: set[str] = set()\n\n        def _load_with_shard_id(\n            weight_loader, param, loaded_weight: torch.Tensor, shard_id\n        ) -> None:\n            \"\"\"Call param.weight_loader with best-effort shard_id normalization.\n\n            Different fused-QKV implementations expect different shard_id types:\n            - Some expect strings: \"q\"/\"k\"/\"v\"\n            - Some expect integer indices: 0/1/2\n            We try the provided shard_id first, then fall back between str/int forms.\n            \"\"\"\n            try:\n                weight_loader(param, loaded_weight, shard_id)\n                return\n            except (AssertionError, TypeError):\n                pass\n\n            # Fall back between common representations.\n            if isinstance(shard_id, str):\n                mapping = {\"q\": 0, \"k\": 1, \"v\": 2}\n                if shard_id in mapping:\n                    weight_loader(param, loaded_weight, mapping[shard_id])\n                    
return\n                if shard_id.isdigit():\n                    weight_loader(param, loaded_weight, int(shard_id))\n                    return\n            elif isinstance(shard_id, int):\n                mapping = {0: \"q\", 1: \"k\", 2: \"v\"}\n                if shard_id in mapping:\n                    weight_loader(param, loaded_weight, mapping[shard_id])\n                    return\n\n            # Re-raise with a clearer message.\n            raise TypeError(\n                f\"Unsupported shard_id={shard_id!r} for weight_loader={weight_loader} \"\n                f\"(param={getattr(param, 'name', '<param>')}).\"\n            )\n\n        stacked_params_mapping = getattr(\n            getattr(self.config, \"arch_config\", object()),\n            \"stacked_params_mapping\",\n            None,\n        )\n        if stacked_params_mapping is None:\n            stacked_params_mapping = [\n                # Fused QKV shards; downstream loaders may want \"q/k/v\" or 0/1/2.\n                (\".qkv_proj\", \".q_proj\", \"q\"),\n                (\".qkv_proj\", \".k_proj\", \"k\"),\n                (\".qkv_proj\", \".v_proj\", \"v\"),\n                (\".gate_up_proj\", \".gate_proj\", 0),\n                (\".gate_up_proj\", \".up_proj\", 1),\n            ]\n\n        for name, loaded_weight in weights:\n            if \"rotary_emb.inv_freq\" in name:\n                continue\n\n            # The config has stacked_params_mapping\n            for (\n                param_name,\n                weight_name,\n                shard_id,\n            ) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n\n                if name not in params_dict:\n                    continue\n\n                param = params_dict[name]\n                weight_loader = param.weight_loader\n                _load_with_shard_id(weight_loader, param, loaded_weight, 
shard_id)\n                break\n            else:\n                if name not in params_dict:\n                    continue\n\n                param = params_dict[name]\n                weight_loader = getattr(param, \"weight_loader\", default_weight_loader)\n                weight_loader(param, loaded_weight)\n\n            loaded_params.add(name)\n        return loaded_params\n\n\nclass Gemma3ForConditionalGeneration(nn.Module):\n    def __init__(\n        self,\n        config: Gemma3Config,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.config = config\n        self.quant_config = quant_config\n        self.text_config = config.text_config\n\n        # Vision Tower\n        self.vision_tower = SiglipVisionModel(\n            config=config.vision_config,\n            quant_config=quant_config,\n            prefix=add_prefix(\"vision_tower\", prefix),\n        )\n\n        # Projector\n        self.multi_modal_projector = Gemma3MultiModalProjector(config)\n\n        # Text Model\n        self.language_model = Gemma3TextModel(config)\n\n    def get_placeholder_mask(\n        self,\n        input_ids: torch.LongTensor,\n        inputs_embeds: torch.FloatTensor,\n        image_features: torch.FloatTensor,\n    ) -> torch.Tensor:\n        image_token_index = int(getattr(self.config, \"image_token_index\", -1))\n        if image_token_index < 0:\n            image_token_index = int(getattr(self.text_config, \"image_token_index\", -1))\n        special_image_mask = input_ids == image_token_index\n        n_image_tokens = int(special_image_mask.sum().item())\n        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)\n        n_image_features = int(image_features.shape[0] * image_features.shape[1])\n        if inputs_embeds[special_image_mask].numel() != image_features.numel():\n            raise ValueError(\n                f\"Image 
features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}\"\n            )\n        return special_image_mask\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor,\n        pixel_values: torch.FloatTensor | None = None,\n        **kwargs,\n    ):\n        vocab_size = int(self.language_model.vocab_size)\n        image_token_index = int(getattr(self.config, \"image_token_index\", -1))\n        if image_token_index < 0:\n            image_token_index = int(getattr(self.text_config, \"image_token_index\", -1))\n\n        if input_ids is not None and image_token_index >= vocab_size:\n            special_image_mask = input_ids == image_token_index\n            llm_input_ids = input_ids.clone()\n            llm_input_ids[special_image_mask] = 0\n        else:\n            llm_input_ids = input_ids\n\n        inputs_embeds = self.language_model.get_input_embeddings(llm_input_ids)\n\n        if pixel_values is not None:\n            if pixel_values.dim() == 5:\n                pixel_values = pixel_values.reshape(\n                    -1,\n                    pixel_values.shape[2],\n                    pixel_values.shape[3],\n                    pixel_values.shape[4],\n                )\n            elif pixel_values.dim() == 3:\n                pixel_values = pixel_values.unsqueeze(0)\n            elif pixel_values.dim() != 4:\n                raise ValueError(f\"Unexpected pixel_values shape: {pixel_values.shape}\")\n\n            vision_outputs = self.vision_tower(pixel_values)\n            image_features = self.multi_modal_projector(vision_outputs)\n            image_features = image_features.to(\n                device=inputs_embeds.device, dtype=inputs_embeds.dtype\n            )\n            special_image_mask = self.get_placeholder_mask(\n                input_ids, inputs_embeds=inputs_embeds, image_features=image_features\n            )\n            inputs_embeds = inputs_embeds.masked_scatter(\n                
special_image_mask, image_features\n            )\n\n        return self.language_model.forward(\n            llm_input_ids, inputs_embeds=inputs_embeds, **kwargs\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:\n        loaded_params: Set[str] = set()\n        params_dict = dict(self.named_parameters())\n\n        def _load_with_shard_id(\n            weight_loader, param, loaded_weight: torch.Tensor, shard_id\n        ) -> None:\n            \"\"\"Call param.weight_loader with best-effort shard_id normalization.\n\n            Different fused-QKV implementations expect different shard_id types:\n            - Some expect strings: \"q\"/\"k\"/\"v\"\n            - Some expect integer indices: 0/1/2\n            We try the provided shard_id first, then fall back between str/int forms.\n            \"\"\"\n            try:\n                weight_loader(param, loaded_weight, shard_id)\n                return\n            except (AssertionError, TypeError):\n                pass\n\n            # Fall back between common representations.\n            if isinstance(shard_id, str):\n                mapping = {\"q\": 0, \"k\": 1, \"v\": 2}\n                if shard_id in mapping:\n                    weight_loader(param, loaded_weight, mapping[shard_id])\n                    return\n                if shard_id.isdigit():\n                    weight_loader(param, loaded_weight, int(shard_id))\n                    return\n            elif isinstance(shard_id, int):\n                mapping = {0: \"q\", 1: \"k\", 2: \"v\"}\n                if shard_id in mapping:\n                    weight_loader(param, loaded_weight, mapping[shard_id])\n                    return\n\n            raise TypeError(\n                f\"Unsupported shard_id={shard_id!r} for weight_loader={weight_loader} \"\n                f\"(param={getattr(param, 'name', '<param>')}).\"\n            )\n\n        # Separate weights\n        
language_model_weights: list[tuple[str, torch.Tensor]] = []\n        other_weights: list[tuple[str, torch.Tensor]] = []\n\n        for name, loaded_weight in weights:\n            # Handle prefix mapping if needed\n            # HF weights might be \"model.vision_tower...\", \"model.language_model...\"\n\n            if \"vision_tower\" in name or \"vision_model\" in name:\n                # Load vision tower weights\n                # Map name to local name\n                local_name = name\n                if \"model.vision_tower\" in name:\n                    local_name = name.replace(\"model.vision_tower\", \"vision_tower\")\n                elif \"vision_tower\" in name:\n                    pass  # already correct prefix if matching self.vision_tower\n                elif local_name.startswith(\"vision_model.\"):\n                    local_name = (\n                        \"vision_tower.vision_model.\"\n                        + local_name[len(\"vision_model.\") :]\n                    )\n\n                # We need to map HF Siglip names to our Siglip implementation\n                # Our Siglip: vision_tower.vision_model.encoder.layers...\n                # HF Siglip: vision_model.encoder.layers...\n\n                # If loading from Gemma3 checkpoint, it usually has \"model.vision_tower.vision_model...\"\n\n                if local_name in params_dict:\n                    param = params_dict[local_name]\n                    weight_loader = getattr(\n                        param, \"weight_loader\", default_weight_loader\n                    )\n                    weight_loader(param, loaded_weight)\n                    loaded_params.add(local_name)\n                else:\n                    qkv_shard_id = None\n                    fused_name = None\n                    if \".self_attn.q_proj.\" in local_name:\n                        fused_name = local_name.replace(\n                            \".self_attn.q_proj.\", \".self_attn.qkv_proj.\"\n       
                 )\n                        qkv_shard_id = \"q\"\n                    elif \".self_attn.k_proj.\" in local_name:\n                        fused_name = local_name.replace(\n                            \".self_attn.k_proj.\", \".self_attn.qkv_proj.\"\n                        )\n                        qkv_shard_id = \"k\"\n                    elif \".self_attn.v_proj.\" in local_name:\n                        fused_name = local_name.replace(\n                            \".self_attn.v_proj.\", \".self_attn.qkv_proj.\"\n                        )\n                        qkv_shard_id = \"v\"\n\n                    if fused_name is not None and fused_name in params_dict:\n                        param = params_dict[fused_name]\n                        weight_loader = getattr(\n                            param, \"weight_loader\", default_weight_loader\n                        )\n                        _load_with_shard_id(\n                            weight_loader, param, loaded_weight, qkv_shard_id\n                        )\n                        loaded_params.add(fused_name)\n                        continue\n\n                    if \".self_attn.proj.\" in local_name:\n                        candidate = local_name.replace(\n                            \".self_attn.proj.\", \".self_attn.out_proj.\"\n                        )\n                        if candidate in params_dict:\n                            param = params_dict[candidate]\n                            weight_loader = getattr(\n                                param, \"weight_loader\", default_weight_loader\n                            )\n                            weight_loader(param, loaded_weight)\n                            loaded_params.add(candidate)\n                            continue\n                    if \".self_attn.out_proj.\" in local_name:\n                        candidate = local_name.replace(\n                            \".self_attn.out_proj.\", 
\".self_attn.proj.\"\n                        )\n                        if candidate in params_dict:\n                            param = params_dict[candidate]\n                            weight_loader = getattr(\n                                param, \"weight_loader\", default_weight_loader\n                            )\n                            weight_loader(param, loaded_weight)\n                            loaded_params.add(candidate)\n                            continue\n\n                    # Try to find match\n                    suffix = local_name.split(\"vision_tower.\")[-1]\n                    # Try adding vision_model\n                    candidate = f\"vision_tower.vision_model.{suffix}\"\n                    if candidate in params_dict:\n                        param = params_dict[candidate]\n                        weight_loader = getattr(\n                            param, \"weight_loader\", default_weight_loader\n                        )\n                        weight_loader(param, loaded_weight)\n                        loaded_params.add(candidate)\n\n            elif \"multi_modal_projector\" in name:\n                local_name = name\n                if \"model.multi_modal_projector\" in name:\n                    local_name = name.replace(\n                        \"model.multi_modal_projector\", \"multi_modal_projector\"\n                    )\n\n                if local_name in params_dict:\n                    param = params_dict[local_name]\n                    weight_loader = getattr(\n                        param, \"weight_loader\", default_weight_loader\n                    )\n                    weight_loader(param, loaded_weight)\n                    loaded_params.add(local_name)\n\n            elif \"language_model\" in name or \"model.language_model\" in name:\n                # Strip prefix for language model\n                # If name is \"model.language_model.model.layers.0...\", we want \"model.layers.0...\" for 
Gemma3ForCausalLM\n                # Gemma3ForCausalLM has .model (Gemma3TextModel) and .lm_head\n\n                # HF: model.language_model.model.layers...\n                # Ours: language_model.model.layers...\n\n                # We pass (name, weight) to language_model.load_weights\n                # We should strip \"model.language_model.\" or \"language_model.\"\n\n                suffix = name\n                if \"model.language_model.\" in name:\n                    suffix = name.replace(\"model.language_model.\", \"\")\n                elif \"language_model.\" in name:\n                    suffix = name.replace(\"language_model.\", \"\")\n                if suffix.startswith(\"model.\"):\n                    suffix = suffix[len(\"model.\") :]\n\n                language_model_weights.append((suffix, loaded_weight))\n\n            else:\n                # Fallback for other weights (maybe direct lm_head if not nested?)\n                other_weights.append((name, loaded_weight))\n\n        if language_model_weights:\n            lm_loaded = self.language_model.load_weights(language_model_weights)\n            loaded_params.update({f\"language_model.{n}\" for n in lm_loaded})\n\n        return loaded_params\n\n    def get_attention_sliding_window_size(self):\n        if self.text_config is not None and hasattr(\n            self.text_config, \"get_attention_sliding_window_size\"\n        ):\n            return self.text_config.get_attention_sliding_window_size()\n        sliding_window = getattr(self.text_config, \"sliding_window\", None)\n        if sliding_window is None:\n            sliding_window = getattr(self.config, \"sliding_window\", None)\n        if sliding_window is None:\n            return None\n        return int(sliding_window) - 1\n\n\nEntryClass = Gemma3ForConditionalGeneration\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/encoders/hunyuan3d.py",
    "content": "# Copied and adapted from: https://github.com/Tencent-Hunyuan/Hunyuan3D-2\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torchvision import transforms\nfrom transformers import (\n    CLIPVisionConfig,\n    CLIPVisionModelWithProjection,\n    Dinov2Config,\n    Dinov2Model,\n)\n\n\ndef get_1d_sincos_pos_embed_from_grid(embed_dim, pos):\n\n    assert embed_dim % 2 == 0\n    omega = np.arange(embed_dim // 2, dtype=np.float64)\n    omega /= embed_dim / 2.0\n    omega = 1.0 / 10000**omega  # (D/2,)\n\n    pos = pos.reshape(-1)  # (M,)\n    out = np.einsum(\"m,d->md\", pos, omega)  # (M, D/2), outer product\n\n    emb_sin = np.sin(out)  # (M, D/2)\n    emb_cos = np.cos(out)  # (M, D/2)\n\n    return np.concatenate([emb_sin, emb_cos], axis=1)\n\n\nclass ImageEncoder(nn.Module):\n    MODEL_CLASS = None\n    MODEL_CONFIG_CLASS = None\n    mean = []\n    std = []\n\n    def __init__(\n        self,\n        version=None,\n        config=None,\n        use_cls_token=True,\n        image_size=224,\n        **kwargs,\n    ):\n        super().__init__()\n\n        if config is None:\n            self.model = self.MODEL_CLASS.from_pretrained(version)\n        else:\n            self.model = self.MODEL_CLASS(self.MODEL_CONFIG_CLASS.from_dict(config))\n        self.model.eval()\n        self.model.requires_grad_(False)\n        self.use_cls_token = use_cls_token\n        self.size = image_size // 14\n        self.num_patches = (image_size // 14) ** 2\n        if self.use_cls_token:\n            self.num_patches += 1\n\n        self.transform = transforms.Compose(\n            [\n                transforms.Resize(\n                    image_size, transforms.InterpolationMode.BILINEAR, antialias=True\n                ),\n                transforms.CenterCrop(image_size),\n                transforms.Normalize(\n                    mean=self.mean,\n                    std=self.std,\n                ),\n            ]\n        )\n\n    def forward(self, 
image, mask=None, value_range=(-1, 1), **kwargs):\n        if value_range is not None:\n            low, high = value_range\n            image = (image - low) / (high - low)\n\n        image = image.to(self.model.device, dtype=self.model.dtype)\n        inputs = self.transform(image)\n        outputs = self.model(inputs)\n\n        last_hidden_state = outputs.last_hidden_state\n        if not self.use_cls_token:\n            last_hidden_state = last_hidden_state[:, 1:, :]\n\n        return last_hidden_state\n\n    def unconditional_embedding(self, batch_size, **kwargs):\n        device = next(self.model.parameters()).device\n        dtype = next(self.model.parameters()).dtype\n        zero = torch.zeros(\n            batch_size,\n            self.num_patches,\n            self.model.config.hidden_size,\n            device=device,\n            dtype=dtype,\n        )\n\n        return zero\n\n\nclass CLIPImageEncoder(ImageEncoder):\n    MODEL_CLASS = CLIPVisionModelWithProjection\n    MODEL_CONFIG_CLASS = CLIPVisionConfig\n    mean = [0.48145466, 0.4578275, 0.40821073]\n    std = [0.26862954, 0.26130258, 0.27577711]\n\n\nclass DinoImageEncoder(ImageEncoder):\n    MODEL_CLASS = Dinov2Model\n    MODEL_CONFIG_CLASS = Dinov2Config\n    mean = [0.485, 0.456, 0.406]\n    std = [0.229, 0.224, 0.225]\n\n\nclass DinoImageEncoderMV(DinoImageEncoder):\n    _aliases = [\n        \"hy3dshape.models.conditioner.DinoImageEncoderMV\",\n    ]\n\n    def __init__(\n        self,\n        version=None,\n        config=None,\n        use_cls_token=True,\n        image_size=224,\n        view_num=4,\n        **kwargs,\n    ):\n        super().__init__(version, config, use_cls_token, image_size, **kwargs)\n        self.view_num = view_num\n        self.num_patches = self.num_patches\n        pos = np.arange(self.view_num, dtype=np.float32)\n        view_embedding = torch.from_numpy(\n            get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)\n        ).float()\n\n  
      view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)\n        self.view_embed = view_embedding.unsqueeze(0)\n\n    def forward(self, image, mask=None, value_range=(-1, 1), view_idxs=None, **kwargs):\n        if value_range is not None:\n            low, high = value_range\n            image = (image - low) / (high - low)\n\n        image = image.to(self.model.device, dtype=self.model.dtype)\n\n        bs, num_views, c, h, w = image.shape\n        image = image.view(bs * num_views, c, h, w)\n\n        inputs = self.transform(image)\n        outputs = self.model(inputs)\n\n        last_hidden_state = outputs.last_hidden_state\n        last_hidden_state = last_hidden_state.view(\n            bs, num_views, last_hidden_state.shape[-2], last_hidden_state.shape[-1]\n        )\n\n        view_embedding = self.view_embed.to(last_hidden_state.dtype).to(\n            last_hidden_state.device\n        )\n        if view_idxs is not None:\n            assert len(view_idxs) == bs\n            view_embeddings = []\n            for i in range(bs):\n                view_idx = view_idxs[i]\n                assert num_views == len(view_idx)\n                view_embeddings.append(self.view_embed[:, view_idx, ...])\n            view_embedding = (\n                torch.cat(view_embeddings, 0)\n                .to(last_hidden_state.dtype)\n                .to(last_hidden_state.device)\n            )\n\n        if num_views != self.view_num:\n            view_embedding = view_embedding[:, :num_views, ...]\n        last_hidden_state = last_hidden_state + view_embedding\n        last_hidden_state = last_hidden_state.view(\n            bs, num_views * last_hidden_state.shape[-2], last_hidden_state.shape[-1]\n        )\n        return last_hidden_state\n\n    def unconditional_embedding(self, batch_size, view_idxs, **kwargs):\n        device = next(self.model.parameters()).device\n        dtype = next(self.model.parameters()).dtype\n        zero = 
torch.zeros(\n            batch_size,\n            self.num_patches * len(view_idxs[0]),\n            self.model.config.hidden_size,\n            device=device,\n            dtype=dtype,\n        )\n        return zero\n\n\ndef build_image_encoder(config):\n    if config[\"type\"] == \"CLIPImageEncoder\":\n        return CLIPImageEncoder(**config[\"kwargs\"])\n    elif config[\"type\"] == \"DinoImageEncoder\":\n        return DinoImageEncoder(**config[\"kwargs\"])\n    elif config[\"type\"] == \"DinoImageEncoderMV\":\n        return DinoImageEncoderMV(**config[\"kwargs\"])\n    else:\n        raise ValueError(f'Unknown image encoder type: {config[\"type\"]}')\n\n\nclass DualImageEncoder(nn.Module):\n    def __init__(\n        self,\n        main_image_encoder,\n        additional_image_encoder,\n    ):\n        super().__init__()\n        self.main_image_encoder = build_image_encoder(main_image_encoder)\n        self.additional_image_encoder = build_image_encoder(additional_image_encoder)\n\n    def forward(self, image, mask=None, **kwargs):\n        outputs = {\n            \"main\": self.main_image_encoder(image, mask=mask, **kwargs),\n            \"additional\": self.additional_image_encoder(image, mask=mask, **kwargs),\n        }\n        return outputs\n\n    def unconditional_embedding(self, batch_size, **kwargs):\n        outputs = {\n            \"main\": self.main_image_encoder.unconditional_embedding(\n                batch_size, **kwargs\n            ),\n            \"additional\": self.additional_image_encoder.unconditional_embedding(\n                batch_size, **kwargs\n            ),\n        }\n        return outputs\n\n\nclass SingleImageEncoder(nn.Module):\n    def __init__(\n        self,\n        main_image_encoder,\n    ):\n        super().__init__()\n        self.main_image_encoder = build_image_encoder(main_image_encoder)\n\n    def forward(self, image, mask=None, **kwargs):\n        outputs = {\n            \"main\": 
self.main_image_encoder(image, mask=mask, **kwargs),\n        }\n        return outputs\n\n    def unconditional_embedding(self, batch_size, **kwargs):\n        outputs = {\n            \"main\": self.main_image_encoder.unconditional_embedding(\n                batch_size, **kwargs\n            ),\n        }\n        return outputs\n\n\n# Entry class for model registry\nEntryClass = [\n    SingleImageEncoder,\n    DualImageEncoder,\n    DinoImageEncoder,\n    DinoImageEncoderMV,\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/encoders/llama.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/llama.py\n\n# Adapted from\n# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py\n# Copyright 2023 The vLLM team.\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Inference-only LLaMA model compatible with HuggingFace weights.\"\"\"\n\nfrom collections.abc import Iterable\nfrom typing import Any\n\nimport torch\nfrom torch import nn\n\n# from ..utils import (extract_layer_index)\nfrom sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput, LlamaConfig\nfrom sglang.multimodal_gen.runtime.distributed import get_tp_world_size\nfrom sglang.multimodal_gen.runtime.layers.activation import SiluAndMul\n\n# from vllm.model_executor.layers.quantization import QuantizationConfig\nfrom sglang.multimodal_gen.runtime.layers.attention import LocalAttention\nfrom sglang.multimodal_gen.runtime.layers.layernorm import 
RMSNorm\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    MergedColumnParallelLinear,\n    QKVParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig\nfrom sglang.multimodal_gen.runtime.layers.rotary_embedding import get_rope\nfrom sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import (\n    VocabParallelEmbedding,\n)\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import (\n    default_weight_loader,\n    maybe_remap_kv_scale_name,\n)\nfrom sglang.multimodal_gen.runtime.models.encoders.base import TextEncoder\n\n\nclass LlamaMLP(nn.Module):\n\n    def __init__(\n        self,\n        hidden_size: int,\n        intermediate_size: int,\n        hidden_act: str,\n        quant_config: QuantizationConfig | None = None,\n        bias: bool = False,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.gate_up_proj = MergedColumnParallelLinear(\n            input_size=hidden_size,\n            output_sizes=[intermediate_size] * 2,\n            # output_size=intermediate_size,\n            bias=bias,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.gate_up_proj\",\n        )\n        self.down_proj = RowParallelLinear(\n            input_size=intermediate_size,\n            output_size=hidden_size,\n            bias=bias,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.down_proj\",\n        )\n        if hidden_act != \"silu\":\n            raise ValueError(\n                f\"Unsupported activation: {hidden_act}. 
\"\n                \"Only silu is supported for now.\"\n            )\n        self.act_fn = SiluAndMul()\n\n    def forward(self, x):\n        x, _ = self.gate_up_proj(x)\n        x = self.act_fn(x)\n        x, _ = self.down_proj(x)\n        return x\n\n\nclass LlamaAttention(nn.Module):\n\n    def __init__(\n        self,\n        config: LlamaConfig,\n        hidden_size: int,\n        num_heads: int,\n        num_kv_heads: int,\n        rope_theta: float = 10000,\n        rope_scaling: dict[str, Any] | None = None,\n        max_position_embeddings: int = 8192,\n        quant_config: QuantizationConfig | None = None,\n        bias: bool = False,\n        bias_o_proj: bool = False,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        # layer_idx = extract_layer_index(prefix)\n        self.hidden_size = hidden_size\n        tp_size = get_tp_world_size()\n        self.total_num_heads = num_heads\n        assert self.total_num_heads % tp_size == 0\n        self.num_heads = self.total_num_heads // tp_size\n        self.total_num_kv_heads = num_kv_heads\n        if self.total_num_kv_heads >= tp_size:\n            # Number of KV heads is greater than TP size, so we partition\n            # the KV heads across multiple tensor parallel GPUs.\n            assert self.total_num_kv_heads % tp_size == 0\n        else:\n            # Number of KV heads is less than TP size, so we replicate\n            # the KV heads across multiple tensor parallel GPUs.\n            assert tp_size % self.total_num_kv_heads == 0\n        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)\n        # MistralConfig has an optional head_dim introduced by Mistral-Nemo\n        self.head_dim = getattr(\n            config, \"head_dim\", self.hidden_size // self.total_num_heads\n        )\n        # Phi models introduced a partial_rotary_factor parameter in the config\n        partial_rotary_factor = getattr(config, \"partial_rotary_factor\", 1)\n        
self.rotary_dim = int(partial_rotary_factor * self.head_dim)\n        self.q_size = self.num_heads * self.head_dim\n        self.kv_size = self.num_kv_heads * self.head_dim\n        self.scaling = self.head_dim**-0.5\n        self.rope_theta = rope_theta\n        self.max_position_embeddings = max_position_embeddings\n\n        self.qkv_proj = QKVParallelLinear(\n            hidden_size=hidden_size,\n            head_size=self.head_dim,\n            total_num_heads=self.total_num_heads,\n            total_num_kv_heads=self.total_num_kv_heads,\n            bias=bias,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.qkv_proj\",\n        )\n\n        self.o_proj = RowParallelLinear(\n            input_size=self.total_num_heads * self.head_dim,\n            output_size=hidden_size,\n            bias=bias_o_proj,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.o_proj\",\n        )\n\n        is_neox_style = True\n        is_gguf = (\n            quant_config\n            and hasattr(quant_config, \"get_name\")\n            and quant_config.get_name() == \"gguf\"\n        )\n        if is_gguf and config.model_type == \"llama\":\n            is_neox_style = False\n\n        self.rotary_emb = get_rope(\n            self.head_dim,\n            rotary_dim=self.rotary_dim,\n            max_position=max_position_embeddings,\n            base=int(rope_theta),\n            rope_scaling=rope_scaling,\n            is_neox_style=is_neox_style,\n        )\n\n        self.attn = LocalAttention(\n            self.num_heads,\n            self.head_dim,\n            self.num_kv_heads,\n            softmax_scale=self.scaling,\n            causal=True,\n            supported_attention_backends=config._supported_attention_backends,\n        )\n\n    def forward(\n        self,\n        positions: torch.Tensor,\n        hidden_states: torch.Tensor,\n    ) -> torch.Tensor:\n        qkv, _ = self.qkv_proj(hidden_states)\n        q, k, v = 
qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)\n        q, k = self.rotary_emb(positions, q, k)\n        # attn_output = self.attn(q, k, v)\n        # use flash_attn_func\n        # TODO (Attn abstraction and backend)\n        # reshape q, k, v to (batch_size, seq_len, num_heads, head_dim)\n        batch_size = q.shape[0]\n        seq_len = q.shape[1]\n        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)\n        k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)\n        v = v.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)\n        # import pdb; pdb.set_trace()\n        # attn_output = flash_attn_varlen_func(q, k, v, softmax_scale=self.scaling, causal=True)\n        attn_output = self.attn(q, k, v)\n        attn_output = attn_output.reshape(\n            batch_size, seq_len, self.num_heads * self.head_dim\n        )\n\n        output, _ = self.o_proj(attn_output)\n        return output\n\n\nclass LlamaDecoderLayer(nn.Module):\n\n    def __init__(\n        self,\n        config: LlamaConfig,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        rope_theta = config.rope_parameters[\"rope_theta\"]\n        rope_scaling = config.rope_parameters\n        if rope_scaling is not None and getattr(\n            config, \"original_max_position_embeddings\", None\n        ):\n            rope_scaling[\"original_max_position_embeddings\"] = (\n                config.original_max_position_embeddings\n            )\n        max_position_embeddings = getattr(config, \"max_position_embeddings\", 8192)\n        # Support abacusai/Smaug-72B-v0.1 with attention_bias\n        # Support internlm/internlm-7b with bias\n        attention_bias = getattr(config, \"attention_bias\", False) or getattr(\n            config, \"bias\", False\n        )\n        bias_o_proj = attention_bias\n      
  # support internlm/internlm3-8b with qkv_bias\n        if hasattr(config, \"qkv_bias\"):\n            attention_bias = config.qkv_bias\n\n        self.self_attn = LlamaAttention(\n            config=config,\n            hidden_size=self.hidden_size,\n            num_heads=config.num_attention_heads,\n            num_kv_heads=getattr(\n                config, \"num_key_value_heads\", config.num_attention_heads\n            ),\n            rope_theta=rope_theta,\n            rope_scaling=rope_scaling,\n            max_position_embeddings=max_position_embeddings,\n            quant_config=quant_config,\n            bias=attention_bias,\n            bias_o_proj=bias_o_proj,\n            prefix=f\"{prefix}.self_attn\",\n        )\n        self.mlp = LlamaMLP(\n            hidden_size=self.hidden_size,\n            intermediate_size=config.intermediate_size,\n            hidden_act=config.hidden_act,\n            quant_config=quant_config,\n            bias=getattr(config, \"mlp_bias\", False),\n            prefix=f\"{prefix}.mlp\",\n        )\n        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        positions: torch.Tensor,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor | None,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        # Self Attention\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n        return 
hidden_states, residual\n\n\nclass LlamaModel(TextEncoder):\n\n    def __init__(\n        self,\n        config: LlamaConfig,\n    ):\n        super().__init__(config)\n\n        self.config = config\n        self.quant_config = self.config.quant_config\n        if config.lora_config is not None:\n            max_loras = 1\n            lora_vocab_size = 1\n            if hasattr(config.lora_config, \"max_loras\"):\n                max_loras = config.lora_config.max_loras\n            if hasattr(config.lora_config, \"lora_extra_vocab_size\"):\n                lora_vocab_size = config.lora_config.lora_extra_vocab_size\n            lora_vocab = lora_vocab_size * max_loras\n        else:\n            lora_vocab = 0\n        self.vocab_size = config.vocab_size + lora_vocab\n        self.org_vocab_size = config.vocab_size\n\n        self.embed_tokens = VocabParallelEmbedding(\n            self.vocab_size,\n            config.hidden_size,\n            org_num_embeddings=config.vocab_size,\n            quant_config=config.quant_config,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                LlamaDecoderLayer(\n                    config=config,\n                    quant_config=config.quant_config,\n                    prefix=f\"{config.prefix}.layers.{i}\",\n                )\n                for i in range(config.num_hidden_layers)\n            ]\n        )\n\n        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:\n        return self.embed_tokens(input_ids)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor | None,\n        position_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        inputs_embeds: torch.Tensor | None = None,\n        output_hidden_states: bool | None = None,\n        **kwargs,\n    ) -> BaseEncoderOutput:\n        output_hidden_states = (\n            
output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        if inputs_embeds is not None:\n            hidden_states = inputs_embeds\n        else:\n            hidden_states = self.get_input_embeddings(input_ids)\n        residual = None\n\n        if position_ids is None:\n            position_ids = torch.arange(\n                0, hidden_states.shape[1], device=hidden_states.device\n            ).unsqueeze(0)\n\n        all_hidden_states: tuple[Any, ...] | None = () if output_hidden_states else None\n        for layer in self.layers:\n            if all_hidden_states is not None:\n                # TODO\n                all_hidden_states += (\n                    (hidden_states,)\n                    if residual is None\n                    else (hidden_states + residual,)\n                )\n            hidden_states, residual = layer(position_ids, hidden_states, residual)\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        # add hidden states from the last decoder layer\n        if all_hidden_states is not None:\n            all_hidden_states += (hidden_states,)\n\n        # TODO(will): maybe unify the output format with other models and use\n        # our own class\n        output = BaseEncoderOutput(\n            last_hidden_state=hidden_states,\n            # past_key_values=past_key_values if use_cache else None,\n            hidden_states=all_hidden_states,\n            # attentions=all_self_attns,\n        )\n\n        return output\n\n    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:\n\n        params_dict = dict(self.named_parameters())\n        loaded_params: set[str] = set()\n        for name, loaded_weight in weights:\n            if \"rotary_emb.inv_freq\" in name:\n                continue\n            if \"rotary_emb.cos_cached\" in name or \"rotary_emb.sin_cached\" in name:\n                # Models trained 
using ColossalAI may include these tensors in\n                # the checkpoint. Skip them.\n                continue\n            # if (self.quant_config is not None and\n            #     (scale_name := self.quant_config.get_cache_scale(name))):\n            #     # Loading kv cache quantization scales\n            #     param = params_dict[scale_name]\n            #     weight_loader = getattr(param, \"weight_loader\",\n            #                             default_weight_loader)\n            #     loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else\n            #                      loaded_weight[0])\n            #     weight_loader(param, loaded_weight)\n            #     loaded_params.add(scale_name)\n            #     continue\n            if \"scale\" in name:\n                # Remapping the name of FP8 kv-scale.\n                kv_scale_name: str | None = maybe_remap_kv_scale_name(name, params_dict)\n                if kv_scale_name is None:\n                    continue\n                else:\n                    name = kv_scale_name\n            for (\n                param_name,\n                weight_name,\n                shard_id,\n            ) in self.config.arch_config.stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                # Skip loading extra bias for GPTQ models.\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n\n                if name not in params_dict:\n                    continue\n\n                param = params_dict[name]\n                weight_loader = param.weight_loader\n                weight_loader(param, loaded_weight, shard_id)\n                break\n            else:\n                # Skip loading extra bias for GPTQ models.\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n\n  
              if name not in params_dict:\n                    continue\n\n                param = params_dict[name]\n                weight_loader = getattr(param, \"weight_loader\", default_weight_loader)\n                weight_loader(param, loaded_weight)\n            loaded_params.add(name)\n        return loaded_params\n\n\nEntryClass = LlamaModel\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/encoders/mistral_3.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc. team. All rights reserved.\n#\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Iterable, Optional, Union\n\nimport torch\nfrom torch import nn\nfrom transformers import Cache, DynamicCache, LlavaConfig, Mistral3Config, MistralConfig\nfrom transformers.integrations.sdpa_attention import sdpa_attention_forward\nfrom transformers.masking_utils import create_causal_mask\nfrom transformers.modeling_outputs import BaseModelOutputWithPast\nfrom transformers.models.mistral3.modeling_mistral3 import (\n    Mistral3CausalLMOutputWithPast,\n    Mistral3ModelOutputWithPast,\n)\nfrom transformers.models.mistral.modeling_mistral import (\n    MistralMLP,\n    MistralRMSNorm,\n    MistralRotaryEmbedding,\n    apply_rotary_pos_emb,\n)\n\nfrom sglang.multimodal_gen.runtime.layers.attention import USPAttention\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).\n    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to\n    (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, 
num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass MistralAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: MistralConfig, layer_idx: int):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        self.num_key_value_groups = (\n            config.num_attention_heads // config.num_key_value_heads\n        )\n\n        self.head_dim = (\n            getattr(config, \"head_dim\", None)\n            or config.hidden_size // config.num_attention_heads\n        )\n        self.scaling = self.head_dim**-0.5\n        self.attention_dropout = config.attention_dropout\n        self.is_causal = True\n        self.q_proj = nn.Linear(\n            config.hidden_size, config.num_attention_heads * self.head_dim, bias=False\n        )\n        self.k_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False\n        )\n        self.v_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False\n        )\n        self.o_proj = nn.Linear(\n            config.num_attention_heads * self.head_dim, config.hidden_size, bias=False\n        )\n        self.num_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.attn = USPAttention(\n            num_heads=self.num_heads,\n            head_size=self.head_dim,\n            dropout_rate=0,\n            softmax_scale=None,\n            
causal=False,\n            supported_attention_backends={\n                AttentionBackendEnum.FA,\n                AttentionBackendEnum.TORCH_SDPA,\n            },\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_values: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n        cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(\n            query_states, key_states, cos, sin\n        )\n\n        if past_key_values is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_values.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        attention_interface = sdpa_attention_forward\n        # The attention weights are not needed; only the attention output is returned.\n        attn_output, _ = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            dropout=0.0,\n            scaling=self.scaling,\n            sliding_window=getattr(\n                self.config, \"sliding_window\", None\n            ),  # main diff with Llama\n            **kwargs,\n        )\n\n        attn_output = attn_output.reshape(*input_shape, 
-1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass MistralDecoderLayer(nn.Module):\n    def __init__(self, config: MistralConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)\n        self.mlp = MistralMLP(config)\n        self.input_layernorm = MistralRMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = MistralRMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[\n            tuple[torch.Tensor, torch.Tensor]\n        ] = None,  # necessary, but kept here for BC\n        **kwargs,\n    ) -> torch.Tensor:\n        residual = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            position_embeddings=position_embeddings,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n        return hidden_states\n\n\nclass 
MistralModel(nn.Module):\n    def __init__(self, config: MistralConfig):\n        super().__init__()\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(\n            config.vocab_size, config.hidden_size, self.padding_idx\n        )\n        self.layers = nn.ModuleList(\n            [\n                MistralDecoderLayer(config, layer_idx)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = MistralRotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n        self.config._attn_implementation = \"sdpa\"\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        output_hidden_states: Optional[bool] = None,\n        **kwargs,\n    ) -> BaseModelOutputWithPast:\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n                \"You must specify exactly one of input_ids or inputs_embeds\"\n            )\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if use_cache and past_key_values is None:\n            past_key_values = DynamicCache(config=self.config)\n\n        if cache_position is None:\n            past_seen_tokens = (\n                past_key_values.get_seq_length() if past_key_values is not None else 0\n            )\n            cache_position = torch.arange(\n                past_seen_tokens,\n                
past_seen_tokens + inputs_embeds.shape[1],\n                device=inputs_embeds.device,\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n        mask_function = create_causal_mask\n        causal_mask = mask_function(\n            config=self.config,\n            input_embeds=inputs_embeds,\n            attention_mask=attention_mask,\n            cache_position=cache_position,\n            past_key_values=past_key_values,\n            position_ids=position_ids,\n        )\n\n        hidden_states = inputs_embeds\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        hidden_states_pool = []\n        for decoder_layer in self.layers[: self.config.num_hidden_layers]:\n            hidden_states = decoder_layer(\n                hidden_states,\n                attention_mask=causal_mask,\n                position_ids=position_ids,\n                past_key_values=past_key_values,\n                use_cache=use_cache,\n                cache_position=cache_position,\n                position_embeddings=position_embeddings,\n                **kwargs,\n            )\n            if output_hidden_states:\n                hidden_states_pool.append(hidden_states)\n\n        hidden_states = self.norm(hidden_states)\n        if output_hidden_states:\n            hidden_states_pool.append(hidden_states)\n\n        return BaseModelOutputWithPast(\n            hidden_states=hidden_states_pool,\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values if use_cache else None,\n        )\n\n\nclass Mistral3Model(nn.Module):\n    _checkpoint_conversion_mapping = {\"language_model.model\": \"language_model\"}\n\n    def __init__(self, config: Mistral3Config):\n        super().__init__()\n        self.language_model = MistralModel(config.text_config)\n        self.config = config\n\n    def get_input_embeddings(self):\n        return 
self.language_model.embed_tokens\n\n    def set_decoder(self, decoder):\n        self.language_model = decoder\n\n    def get_decoder(self):\n        return self.language_model\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        output_hidden_states: Optional[bool] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> Union[tuple, Mistral3ModelOutputWithPast]:\n        output_attentions = False\n        output_hidden_states = True\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n                \"You must specify exactly one of input_ids or inputs_embeds\"\n            )\n\n        if inputs_embeds is None:\n            inputs_embeds = self.get_input_embeddings()(input_ids)\n\n        outputs: BaseModelOutputWithPast = self.language_model(\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=True,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        return Mistral3ModelOutputWithPast(\n            last_hidden_state=outputs.last_hidden_state,\n            past_key_values=outputs.past_key_values,\n            
hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\nclass Mistral3ForConditionalGeneration(nn.Module):\n    _checkpoint_conversion_mapping = {\n        \"^language_model.model\": \"model.language_model\",\n        \"^multi_modal_projector\": \"model.multi_modal_projector\",\n        \"^language_model.lm_head\": \"lm_head\",\n    }\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config: LlavaConfig):\n        super().__init__()\n        self.model = Mistral3Model(config.arch_config)\n\n    def get_input_embeddings(self):\n        return self.model.get_input_embeddings()\n\n    def set_decoder(self, decoder):\n        self.model.set_decoder(decoder)\n\n    def get_decoder(self):\n        return self.model.get_decoder()\n\n    # Make modules available through conditional class for BC\n    @property\n    def language_model(self):\n        return self.model.language_model\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        output_hidden_states: Optional[bool] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        image_sizes: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> Union[tuple, Mistral3CausalLMOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n            config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored\n            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n            This forward pass ignores `labels`: no loss is computed and only `hidden_states` are returned.\n        \"\"\"\n        output_hidden_states = True\n\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_hidden_states=output_hidden_states,\n            return_dict=True,\n            cache_position=cache_position,\n            image_sizes=image_sizes,\n            **kwargs,\n        )\n\n        return Mistral3CausalLMOutputWithPast(\n            hidden_states=outputs.hidden_states,\n        )\n\n    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:\n        # Only the text decoder is loaded: vision-tower, multimodal-projector and LM-head weights are skipped.\n        params_dict = dict(self.named_parameters())\n        loaded_params: set[str] = set()\n        for name, loaded_weight in weights:\n            name_lower = name.lower()\n            if (\n                \"vision\" in name_lower\n                or \"multi\" in name_lower\n                or \"lm_head\" in name_lower\n            ):\n                continue\n            final_name = name.replace(\"language_model.model.\", \"model.language_model.\")\n\n            if final_name in params_dict:\n                param = params_dict[final_name]\n                weight_loader = getattr(param, \"weight_loader\", default_weight_loader)\n                weight_loader(param, loaded_weight)\n                loaded_params.add(final_name)\n            else:\n                logger.warning(f\"Param {name=} {final_name=} from weight is not loaded\")\n\n        return loaded_params\n\n\nEntryClass = Mistral3ForConditionalGeneration\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/encoders/qwen2_5vl.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nfrom transformers import (\n    Cache,\n    DynamicCache,\n    PretrainedConfig,\n    Qwen2_5_VLTextConfig,\n    Qwen2RMSNorm,\n)\nfrom transformers.masking_utils import (\n    create_causal_mask,\n    create_sliding_window_causal_mask,\n)\nfrom transformers.modeling_flash_attention_utils import FlashAttentionKwargs\nfrom transformers.modeling_outputs import BaseModelOutputWithPast\nfrom transformers.utils import TransformersKwargs, is_torchdynamo_compiling\n\nfrom sglang.multimodal_gen.configs.models.encoders.qwen_image import Qwen2_5VLConfig\nfrom sglang.multimodal_gen.runtime.layers.attention import LocalAttention\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    MergedColumnParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader\nfrom sglang.multimodal_gen.runtime.models.encoders.base import TextEncoder\nfrom sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum\nfrom sglang.multimodal_gen.runtime.utils.common import add_prefix\n\n# coding=utf-8\n# Adapted from\n# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py\n# Copyright 2024 The Qwen team.\n# Copyright 2023 The vLLM team.\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. 
It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Inference-only Qwen2-VL model compatible with HuggingFace weights.\"\"\"\nimport logging\nfrom typing import Callable, Iterable, Optional, Tuple, Union\n\ntry:\n    from typing import Unpack  # type: ignore[attr-defined]\nexcept ImportError:\n    # Python 3.10 and below\n    from typing_extensions import Unpack\n\nimport torch\nimport torch.nn as nn\nfrom transformers.activations import ACT2FN\nfrom transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (\n    Qwen2_5_VisionTransformerPretrainedModel,\n    Qwen2_5_VLAttention,\n    Qwen2_5_VLCausalLMOutputWithPast,\n    Qwen2_5_VLModelOutputWithPast,\n    Qwen2_5_VLRotaryEmbedding,\n    Qwen2MLP,\n    apply_multimodal_rotary_pos_emb,\n    eager_attention_forward,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass Qwen2_5_VLAttention(nn.Module):\n    \"\"\"\n    Multi-headed attention from 'Attention Is All You Need' paper. 
Modified to use sliding window attention: Longformer\n    and \"Generating Long Sequences with Sparse Transformers\".\n    \"\"\"\n\n    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        if layer_idx is None:\n            logger.warning(\n                f\"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will \"\n                \"lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` \"\n                \"when creating this class.\"\n            )\n\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.is_causal = True\n        self.attention_dropout = config.attention_dropout\n        self.rope_scaling = config.rope_scaling\n        self.scaling = self.head_dim**-0.5\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {self.num_heads}).\"\n            )\n        self.q_proj = nn.Linear(\n            self.hidden_size, self.num_heads * self.head_dim, bias=True\n        )\n        self.k_proj = nn.Linear(\n            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True\n        )\n        self.v_proj = nn.Linear(\n            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True\n        )\n        self.o_proj = nn.Linear(\n            self.num_heads * self.head_dim, self.hidden_size, bias=False\n        )\n        self.sliding_window = (\n            config.sliding_window\n  
          if config.layer_types[layer_idx] == \"sliding_attention\"\n            else None\n        )\n\n        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)\n        self.attn = LocalAttention(\n            num_heads=self.num_heads,\n            head_size=self.head_dim,\n            num_kv_heads=self.num_key_value_heads,\n            softmax_scale=self.scaling,\n            causal=True,\n            supported_attention_backends=(\n                AttentionBackendEnum.FA,\n                AttentionBackendEnum.TORCH_SDPA,\n            ),\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[\n            tuple[torch.Tensor, torch.Tensor]\n        ] = None,  # necessary, but kept here for BC\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> torch.Tensor:\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n\n        cos, sin = position_embeddings\n        query_states, key_states = apply_multimodal_rotary_pos_emb(\n            query_states, key_states, cos, sin, self.rope_scaling[\"mrope_section\"]\n        )\n\n        if past_key_values is not None:\n            cache_kwargs = {\n                \"sin\": sin,\n      
          \"cos\": cos,\n                \"cache_position\": cache_position,\n            }  # Specific to RoPE models\n            key_states, value_states = past_key_values.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        attention_interface: Callable = eager_attention_forward\n        # if self.config._attn_implementation != \"eager\":\n        # attention_interface = ALL_ATTENTION_FUNCTIONS[\"sdpa\"]\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n        attn_output = self.attn(query_states, key_states, value_states)\n        #\n        # attn_output, attn_weights = attention_interface(\n        #     self,\n        #     query_states,\n        #     key_states,\n        #     value_states,\n        #     attention_mask,\n        #     dropout=0.0 if not self.training else self.attention_dropout,\n        #     scaling=self.scaling,\n        #     sliding_window=self.sliding_window,\n        #     position_ids=position_ids,  # pass positions for FA2\n        #     **kwargs,\n        # )\n\n        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Qwen2_5_VLDecoderLayer(nn.Module):\n    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        if (\n            config.use_sliding_window\n            and config._attn_implementation != \"flash_attention_2\"\n        ):\n            logger.warning(\n                f\"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; \"\n                \"unexpected results may be encountered.\"\n            )\n        self.self_attn = Qwen2_5_VLAttention(config, layer_idx)\n\n        self.mlp = Qwen2MLP(config)\n        
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = Qwen2RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n        self.attention_type = config.layer_types[layer_idx]\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[\n            tuple[torch.Tensor, torch.Tensor]\n        ] = None,  # necessary, but kept here for BC\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, sequence_length)` where padding elements are indicated by 0.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):\n                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,\n                with `head_dim` being the embedding dimension of each attention head.\n            kwargs (`dict`, *optional*):\n                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code\n                into the model\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            position_embeddings=position_embeddings,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n\n\nclass Qwen2_5_VLMLP(nn.Module):\n    def __init__(\n  
      self,\n        in_features: int,\n        hidden_features: int = None,\n        bias: bool = True,\n        hidden_act=\"silu\",\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.gate_up_proj = MergedColumnParallelLinear(\n            input_size=in_features,\n            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]\n            bias=bias,\n            quant_config=quant_config,\n            prefix=add_prefix(\"gate_up_proj\", prefix),\n        )\n        self.down_proj = RowParallelLinear(\n            hidden_features,\n            in_features,\n            bias=bias,\n            quant_config=quant_config,\n            prefix=add_prefix(\"down_proj\", prefix),\n        )\n        self.act = ACT2FN[hidden_act]\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        gate_up, _ = self.gate_up_proj(x)\n        gate, up = gate_up.chunk(2, dim=-1)\n        x = self.act(gate) * up\n        x_down, _ = self.down_proj(x)\n        return x_down\n\n\nclass Qwen2_5_VLTextModel(nn.Module):\n    def __init__(self, config: PretrainedConfig):\n        super().__init__()\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(\n            config.vocab_size, config.hidden_size, self.padding_idx\n        )\n        self.layers = nn.ModuleList(\n            [\n                Qwen2_5_VLDecoderLayer(config, layer_idx)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self._attn_implementation = config._attn_implementation\n        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)\n        self.has_sliding_layers = \"sliding_attention\" in self.config.layer_types\n\n        self.gradient_checkpointing = False\n        # 
Initialize weights and apply final processing\n        # self.post_init()\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Union[tuple, BaseModelOutputWithPast]:\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n                \"You must specify exactly one of input_ids or inputs_embeds\"\n            )\n\n        # torch.jit.trace() doesn't support cache objects in the output\n        if use_cache and past_key_values is None and not torch.jit.is_tracing():\n            past_key_values = DynamicCache(config=self.config)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if cache_position is None:\n            past_seen_tokens = (\n                past_key_values.get_seq_length() if past_key_values is not None else 0\n            )\n            
cache_position = torch.arange(\n                past_seen_tokens,\n                past_seen_tokens + inputs_embeds.shape[1],\n                device=inputs_embeds.device,\n            )\n\n        # the hard coded `3` is for temporal, height and width.\n        if position_ids is None:\n            position_ids = cache_position.view(1, 1, -1).expand(\n                3, inputs_embeds.shape[0], -1\n            )\n        elif position_ids.ndim == 2:\n            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)\n\n        # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions\n        # where each dim indicates visual spatial positions for temporal/height/width grids.\n        # There are two scenarios when FA2-like packed masking might be activated.\n        # 1. User specifically passed packed `position_ids` and no attention mask.\n        #    In this case we expect the user to create correct position ids for all 3 grids\n        #    and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]\n        # 2. User runs forward with no attention mask and no position ids. In this case, position ids\n        #    are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are\n        #    prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass\n        #    text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`\n        if position_ids.ndim == 3 and position_ids.shape[0] == 4:\n            text_position_ids = position_ids[0]\n            position_ids = position_ids[1:]\n        else:\n            text_position_ids = position_ids[0]\n\n        # It may already have been prepared by e.g. 
`generate`\n        if not isinstance(causal_mask_mapping := attention_mask, dict):\n            # Prepare mask arguments\n            mask_kwargs = {\n                \"config\": self.config,\n                \"input_embeds\": inputs_embeds,\n                \"attention_mask\": attention_mask,\n                \"cache_position\": cache_position,\n                \"past_key_values\": past_key_values,\n                \"position_ids\": text_position_ids,\n            }\n            # Create the masks\n            causal_mask_mapping = {\n                \"full_attention\": create_causal_mask(**mask_kwargs),\n            }\n            # The sliding window alternating layers are not always activated depending on the config\n            if self.has_sliding_layers:\n                causal_mask_mapping[\"sliding_attention\"] = (\n                    create_sliding_window_causal_mask(**mask_kwargs)\n                )\n\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            hidden_states = decoder_layer(\n                hidden_states,\n                attention_mask=causal_mask_mapping[decoder_layer.attention_type],\n                position_ids=text_position_ids,\n                past_key_values=past_key_values,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                cache_position=cache_position,\n                position_embeddings=position_embeddings,\n                **kwargs,\n            )\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states 
from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    past_key_values,\n                    all_hidden_states,\n                    all_self_attns,\n                ]\n                if v is not None\n            )\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n\nclass Qwen2_5_VLModel(nn.Module):\n    base_model_prefix = \"\"\n    _checkpoint_conversion_mapping = {\"^model\": \"language_model\"}\n    # Reference: fix gemma3 grad acc #37208\n    accepts_loss_kwargs = False\n    _no_split_modules = [\"Qwen2_5_VLDecoderLayer\", \"Qwen2_5_VLVisionBlock\"]\n\n    def __init__(self, config, enable_image_understanding: bool = False):\n        super().__init__()\n        self.language_model = Qwen2_5_VLTextModel(config.text_config)\n\n        if enable_image_understanding:\n            self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(\n                config.vision_config\n            )\n            self.visual.to(torch.get_default_dtype())\n        self.rope_deltas = None  # cache rope_deltas here\n        self.config = config\n        # Initialize weights and apply final processing\n        # self.post_init()\n\n    def get_input_embeddings(self):\n        return self.language_model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.language_model.embed_tokens = value\n\n    def set_decoder(self, decoder):\n        self.language_model = decoder\n\n    def get_decoder(self):\n        return self.language_model\n\n    def get_rope_index(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        image_grid_thw: 
Optional[torch.LongTensor] = None,\n        video_grid_thw: Optional[torch.LongTensor] = None,\n        second_per_grid_ts: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.\n\n        Explanation:\n            Each embedding sequence contains vision embedding and text embedding or just contains text embedding.\n\n            For a pure-text embedding sequence, the rotary position embedding is no different from that of modern LLMs.\n            Examples:\n                input_ids: [T T T T T], here T is for text.\n                temporal position_ids: [0, 1, 2, 3, 4]\n                height position_ids: [0, 1, 2, 3, 4]\n                width position_ids: [0, 1, 2, 3, 4]\n\n            For a mixed vision-and-text embedding sequence, we calculate 3D rotary position embeddings for the vision part\n            and 1D rotary position embeddings for the text part.\n            Examples:\n                Temporal (Time): 3 patches, representing different segments of the video in time.\n                Height: 2 patches, dividing each frame vertically.\n                Width: 2 patches, dividing each frame horizontally.\n                We also have some important parameters:\n                fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.\n                tokens_per_second: This is a crucial parameter. It dictates how many \"time-steps\" or \"temporal tokens\" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.\n                temporal_patch_size: The number of frames that compose one temporal patch. 
Here, it's 2 frames.\n                interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that consecutive temporal patches will have a difference of 50 in the temporal position IDs.\n                input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.\n                vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]\n                vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]\n                vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]\n                text temporal position_ids: [101, 102, 103, 104, 105]\n                text height position_ids: [101, 102, 103, 104, 105]\n                text width position_ids: [101, 102, 103, 104, 105]\n                Here we calculate the text start position_ids as the max vision position_ids plus 1.\n\n        Args:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n                it.\n            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):\n                The temporal, height and width of feature shape of each image in LLM.\n            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):\n                The temporal, height and width of feature shape of each video in LLM.\n            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):\n                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n        Returns:\n            position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)\n            mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)\n        \"\"\"\n        spatial_merge_size = self.config.vision_config.spatial_merge_size\n        image_token_id = self.config.image_token_id\n        video_token_id = self.config.video_token_id\n        vision_start_token_id = self.config.vision_start_token_id\n        mrope_position_deltas = []\n        if input_ids is not None and (\n            image_grid_thw is not None or video_grid_thw is not None\n        ):\n            total_input_ids = input_ids\n            if attention_mask is None:\n                attention_mask = torch.ones_like(total_input_ids)\n            position_ids = torch.ones(\n                3,\n                input_ids.shape[0],\n                input_ids.shape[1],\n                dtype=input_ids.dtype,\n                device=input_ids.device,\n            )\n            image_index, video_index = 0, 0\n            attention_mask = attention_mask.to(total_input_ids.device)\n            for i, input_ids in enumerate(total_input_ids):\n                input_ids = input_ids[attention_mask[i] == 1]\n                image_nums, video_nums = 0, 0\n                vision_start_indices = torch.argwhere(\n                    input_ids == vision_start_token_id\n                ).squeeze(1)\n                vision_tokens = input_ids[vision_start_indices + 1]\n                image_nums = (vision_tokens == image_token_id).sum()\n                video_nums = (vision_tokens == video_token_id).sum()\n                input_tokens = input_ids.tolist()\n                llm_pos_ids_list: list = []\n                st = 0\n                remain_images, remain_videos = image_nums, video_nums\n                for _ in 
range(image_nums + video_nums):\n                    if image_token_id in input_tokens and remain_images > 0:\n                        ed_image = input_tokens.index(image_token_id, st)\n                    else:\n                        ed_image = len(input_tokens) + 1\n                    if video_token_id in input_tokens and remain_videos > 0:\n                        ed_video = input_tokens.index(video_token_id, st)\n                    else:\n                        ed_video = len(input_tokens) + 1\n                    if ed_image < ed_video:\n                        t, h, w = (\n                            image_grid_thw[image_index][0],\n                            image_grid_thw[image_index][1],\n                            image_grid_thw[image_index][2],\n                        )\n                        second_per_grid_t = 0\n                        image_index += 1\n                        remain_images -= 1\n                        ed = ed_image\n\n                    else:\n                        t, h, w = (\n                            video_grid_thw[video_index][0],\n                            video_grid_thw[video_index][1],\n                            video_grid_thw[video_index][2],\n                        )\n                        if second_per_grid_ts is not None:\n                            second_per_grid_t = second_per_grid_ts[video_index]\n                        else:\n                            second_per_grid_t = 1.0\n                        video_index += 1\n                        remain_videos -= 1\n                        ed = ed_video\n                    llm_grid_t, llm_grid_h, llm_grid_w = (\n                        t.item(),\n                        h.item() // spatial_merge_size,\n                        w.item() // spatial_merge_size,\n                    )\n                    text_len = ed - st\n\n                    st_idx = (\n                        llm_pos_ids_list[-1].max() + 1\n                        if 
len(llm_pos_ids_list) > 0\n                        else 0\n                    )\n                    llm_pos_ids_list.append(\n                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx\n                    )\n\n                    range_tensor = torch.arange(llm_grid_t).view(-1, 1)\n                    expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)\n\n                    ## normalize type, send to device.\n                    second_per_grid_t = torch.as_tensor(\n                        second_per_grid_t,\n                        dtype=range_tensor.dtype,\n                        device=range_tensor.device,\n                    )\n\n                    time_tensor = (\n                        expanded_range\n                        * second_per_grid_t\n                        * self.config.vision_config.tokens_per_second\n                    )\n\n                    time_tensor_long = time_tensor.long()\n                    t_index = time_tensor_long.flatten()\n\n                    h_index = (\n                        torch.arange(llm_grid_h)\n                        .view(1, -1, 1)\n                        .expand(llm_grid_t, -1, llm_grid_w)\n                        .flatten()\n                    )\n                    w_index = (\n                        torch.arange(llm_grid_w)\n                        .view(1, 1, -1)\n                        .expand(llm_grid_t, llm_grid_h, -1)\n                        .flatten()\n                    )\n                    llm_pos_ids_list.append(\n                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx\n                    )\n                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w\n\n                if st < len(input_tokens):\n                    st_idx = (\n                        llm_pos_ids_list[-1].max() + 1\n                        if len(llm_pos_ids_list) > 0\n                        else 0\n                    )\n                    
text_len = len(input_tokens) - st\n                    llm_pos_ids_list.append(\n                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx\n                    )\n\n                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)\n                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(\n                    position_ids.device\n                )\n                mrope_position_deltas.append(\n                    llm_positions.max() + 1 - len(total_input_ids[i])\n                )\n            mrope_position_deltas = torch.tensor(\n                mrope_position_deltas, device=input_ids.device\n            ).unsqueeze(1)\n            return position_ids, mrope_position_deltas\n        else:\n            if attention_mask is not None:\n                position_ids = attention_mask.long().cumsum(-1) - 1\n                position_ids.masked_fill_(attention_mask == 0, 1)\n                position_ids = (\n                    position_ids.unsqueeze(0)\n                    .expand(3, -1, -1)\n                    .to(attention_mask.device)\n                )\n                max_position_ids = position_ids.max(0, keepdim=False)[0].max(\n                    -1, keepdim=True\n                )[0]\n                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]\n            else:\n                position_ids = (\n                    torch.arange(input_ids.shape[1], device=input_ids.device)\n                    .view(1, 1, -1)\n                    .expand(3, input_ids.shape[0], -1)\n                )\n                mrope_position_deltas = torch.zeros(\n                    [input_ids.shape[0], 1],\n                    device=input_ids.device,\n                    dtype=input_ids.dtype,\n                )\n\n            return position_ids, mrope_position_deltas\n\n    def get_video_features(\n        self,\n        pixel_values_videos: torch.FloatTensor,\n        video_grid_thw: 
Optional[torch.LongTensor] = None,\n    ):\n        \"\"\"\n        Encodes videos into continuous embeddings that can be forwarded to the language model.\n\n        Args:\n            pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):\n                The tensors corresponding to the input videos.\n            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):\n                The temporal, height and width of feature shape of each video in LLM.\n        \"\"\"\n        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)\n        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)\n        split_sizes = (\n            video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2\n        ).tolist()\n        video_embeds = torch.split(video_embeds, split_sizes)\n        return video_embeds\n\n    def get_image_features(\n        self,\n        pixel_values: torch.FloatTensor,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        \"\"\"\n        Encodes images into continuous embeddings that can be forwarded to the language model.\n\n        Args:\n            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):\n                The tensors corresponding to the input images.\n            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):\n                The temporal, height and width of feature shape of each image in LLM.\n        \"\"\"\n        pixel_values = pixel_values.type(self.visual.dtype)\n        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)\n        if not isinstance(image_embeds, torch.Tensor):\n            # In transformers v5, the visual encoder returns BaseModelOutputWithPooling.\n            # pooler_output contains the spatially merged embeddings (what we need),\n            # while last_hidden_state contains the raw unmerged 
output.\n            image_embeds = image_embeds.pooler_output\n        split_sizes = (\n            image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2\n        ).tolist()\n        image_embeds = torch.split(image_embeds, split_sizes)\n        return image_embeds\n\n    def get_placeholder_mask(\n        self,\n        input_ids: torch.LongTensor,\n        inputs_embeds: torch.FloatTensor,\n        image_features: torch.FloatTensor = None,\n        video_features: torch.FloatTensor = None,\n    ):\n        \"\"\"\n        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is\n        equal to the length of multimodal features. If the lengths are different, an error is raised.\n        \"\"\"\n        if input_ids is None:\n            special_image_mask = inputs_embeds == self.get_input_embeddings()(\n                torch.tensor(\n                    self.config.image_token_id,\n                    dtype=torch.long,\n                    device=inputs_embeds.device,\n                )\n            )\n            special_image_mask = special_image_mask.all(-1)\n            special_video_mask = inputs_embeds == self.get_input_embeddings()(\n                torch.tensor(\n                    self.config.video_token_id,\n                    dtype=torch.long,\n                    device=inputs_embeds.device,\n                )\n            )\n            special_video_mask = special_video_mask.all(-1)\n        else:\n            special_image_mask = input_ids == self.config.image_token_id\n            special_video_mask = input_ids == self.config.video_token_id\n\n        n_image_tokens = special_image_mask.sum()\n        special_image_mask = (\n            special_image_mask.unsqueeze(-1)\n            .expand_as(inputs_embeds)\n            .to(inputs_embeds.device)\n        )\n        if (\n            image_features is not None\n            and inputs_embeds[special_image_mask].numel() != 
image_features.numel()\n        ):\n            raise ValueError(\n                f\"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}\"\n            )\n\n        n_video_tokens = special_video_mask.sum()\n        special_video_mask = (\n            special_video_mask.unsqueeze(-1)\n            .expand_as(inputs_embeds)\n            .to(inputs_embeds.device)\n        )\n        if (\n            video_features is not None\n            and inputs_embeds[special_video_mask].numel() != video_features.numel()\n        ):\n            raise ValueError(\n                f\"Video features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}\"\n            )\n\n        return special_image_mask, special_video_mask\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        pixel_values: Optional[torch.Tensor] = None,\n        pixel_values_videos: Optional[torch.FloatTensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n        video_grid_thw: Optional[torch.LongTensor] = None,\n        rope_deltas: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        second_per_grid_ts: Optional[torch.Tensor] = None,\n        **kwargs: Unpack[TransformersKwargs],\n    ) -> Union[tuple, Qwen2_5_VLModelOutputWithPast]:\n        r\"\"\"\n        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):\n            The temporal, height and width of feature shape of 
each image in LLM.\n        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):\n            The temporal, height and width of feature shape of each video in LLM.\n        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):\n            The rope index difference between sequence length and multimodal rope.\n        second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):\n            The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.\n        \"\"\"\n\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if inputs_embeds is None:\n            inputs_embeds = self.get_input_embeddings()(input_ids)\n\n        if pixel_values is not None:\n            image_embeds = self.get_image_features(pixel_values, image_grid_thw)\n            image_embeds = torch.cat(image_embeds, dim=0).to(\n                inputs_embeds.device, inputs_embeds.dtype\n            )\n            image_mask, _ = self.get_placeholder_mask(\n                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds\n            )\n            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)\n\n        if pixel_values_videos is not None:\n            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)\n            video_embeds = torch.cat(video_embeds, dim=0).to(\n                inputs_embeds.device, inputs_embeds.dtype\n            )\n            _, video_mask = self.get_placeholder_mask(\n                input_ids, 
inputs_embeds=inputs_embeds, video_features=video_embeds\n            )\n            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)\n\n        if position_ids is None:\n            # Calculate RoPE index once per generation in the pre-fill stage only.\n            # When compiling, we can't check tensor values, so we check only the input length.\n            # It is safe to assume that `length!=1` means we're in pre-fill because compiled\n            # models currently cannot do assisted decoding\n            prefill_compiled_stage = is_torchdynamo_compiling() and (\n                (input_ids is not None and input_ids.shape[1] != 1)\n                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)\n            )\n            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (\n                (cache_position is not None and cache_position[0] == 0)\n                or (past_key_values is None or past_key_values.get_seq_length() == 0)\n            )\n            if (\n                prefill_compiled_stage or prefill_noncompiled_stage\n            ) or self.rope_deltas is None:\n                position_ids, rope_deltas = self.get_rope_index(\n                    input_ids,\n                    image_grid_thw,\n                    video_grid_thw,\n                    second_per_grid_ts=second_per_grid_ts,\n                    attention_mask=attention_mask,\n                )\n                self.rope_deltas = rope_deltas\n            else:\n                batch_size, seq_length, _ = inputs_embeds.shape\n                position_ids = torch.arange(seq_length, device=inputs_embeds.device)\n                position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)\n                if cache_position is not None:\n                    delta = (cache_position[0] + self.rope_deltas).to(\n                        inputs_embeds.device\n                    )\n                else:\n                    delta = 
torch.zeros(\n                        (batch_size, seq_length), device=inputs_embeds.device\n                    )\n                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)\n                position_ids += delta.to(position_ids.device)\n\n        outputs = self.language_model(\n            input_ids=None,\n            position_ids=position_ids,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=True,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        output = Qwen2_5_VLModelOutputWithPast(\n            last_hidden_state=outputs.last_hidden_state,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            rope_deltas=self.rope_deltas,\n        )\n        return output if return_dict else output.to_tuple()\n\n\nclass Qwen2_5_VLForConditionalGeneration(TextEncoder):\n    # BitandBytes specific attributes\n    default_bitsandbytes_target_modules = [\n        \".gate_up_proj.\",\n        \".down_proj.\",\n        \".q_proj.\",\n        \".k_proj.\",\n        \".v_proj.\",\n        \".o_proj.\",\n    ]\n    bitsandbytes_stacked_params_mapping = {\n        # shard_name, weight_name, index\n        \"q_proj\": (\"qkv_proj\", 0),\n        \"k_proj\": (\"qkv_proj\", 1),\n        \"v_proj\": (\"qkv_proj\", 2),\n        \"gate_proj\": (\"gate_up_proj\", 0),\n        \"up_proj\": (\"gate_up_proj\", 1),\n    }\n\n    def __init__(\n        self,\n        config: Qwen2_5VLConfig,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__(config)\n        enable_image_understanding = 
config.enable_image_understanding\n        config = config.arch_config\n        self.model = Qwen2_5_VLModel(\n            config, enable_image_understanding=enable_image_understanding\n        )\n        self.lm_head = nn.Linear(\n            config.text_config.hidden_size, config.text_config.vocab_size, bias=False\n        )\n\n        self.enable_image_understanding = enable_image_understanding\n\n        self.config = config\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    @torch.no_grad()\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        pixel_values: Optional[torch.Tensor] = None,\n        pixel_values_videos: Optional[torch.FloatTensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n        video_grid_thw: Optional[torch.LongTensor] = None,\n        rope_deltas: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        second_per_grid_ts: Optional[torch.Tensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs: Unpack[TransformersKwargs],\n    ):\n        \"\"\"Run forward pass for Qwen2_5-VL.\n\n        Args:\n            input_ids: Flattened (concatenated) input_ids corresponding to a\n                batch.\n            positions: Flattened (concatenated) position ids corresponding to a\n                batch.\n                **NOTE**: If mrope is enabled (default setting for Qwen2-VL\n                opensource models), the shape will be `(3, seq_len)`,\n              
  otherwise it will be `(seq_len,).\n                (Use input_metadata.mrope_positions to replace it)\n        \"\"\"\n        output_attentions = False\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n\n        outputs = self.model(\n            input_ids=input_ids,\n            pixel_values=pixel_values,\n            pixel_values_videos=pixel_values_videos,\n            image_grid_thw=image_grid_thw,\n            video_grid_thw=video_grid_thw,\n            second_per_grid_ts=second_per_grid_ts,\n            position_ids=position_ids,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=True,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        hidden_states = outputs[0]\n\n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n        slice_indices = (\n            slice(-logits_to_keep, None)\n            if isinstance(logits_to_keep, int)\n            else logits_to_keep\n        )\n        logits = self.lm_head(hidden_states[:, slice_indices, :])\n        return Qwen2_5_VLCausalLMOutputWithPast(\n            loss=None,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            rope_deltas=outputs.rope_deltas,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        loaded_params: set[str] = set()\n\n        params_dict = dict(self.named_parameters(remove_duplicate=False))\n        for name, loaded_weight in weights:\n            if 
\"rotary_emb.inv_freq\" in name:\n                continue\n\n            name = name.replace(\"model.\", \"model.language_model.\")\n            if \"visual.\" in name:\n                if not self.enable_image_understanding:\n                    continue\n                name = name.replace(\"visual.\", \"model.visual.\")\n            try:\n                # Skip loading extra bias for GPTQ models.\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n                param = params_dict[name]\n            except KeyError:\n                raise\n\n            weight_loader = getattr(param, \"weight_loader\", default_weight_loader)\n            loaded_weight = loaded_weight.to(param.dtype)\n            weight_loader(param, loaded_weight)\n            loaded_params.add(name)\n        return loaded_params\n\n    def get_embed_and_head(self):\n        return self.model.embed_tokens.weight, self.lm_head.weight\n\n\nEntryClass = Qwen2_5_VLForConditionalGeneration\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/encoders/qwen3.py",
    "content": "from collections.abc import Iterable\nfrom typing import Any\n\nimport torch\nfrom torch import nn\n\nfrom sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput\nfrom sglang.multimodal_gen.configs.models.encoders.qwen3 import Qwen3TextConfig\nfrom sglang.multimodal_gen.runtime.distributed import get_tp_world_size\nfrom sglang.multimodal_gen.runtime.layers.activation import SiluAndMul\nfrom sglang.multimodal_gen.runtime.layers.attention import LocalAttention\nfrom sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    MergedColumnParallelLinear,\n    QKVParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig\nfrom sglang.multimodal_gen.runtime.layers.rotary_embedding import get_rope\nfrom sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import (\n    VocabParallelEmbedding,\n)\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import (\n    default_weight_loader,\n    maybe_remap_kv_scale_name,\n)\nfrom sglang.multimodal_gen.runtime.models.encoders.base import TextEncoder\n\n\nclass Qwen3MLP(nn.Module):\n    \"\"\"Qwen3 MLP with SwiGLU activation and tensor parallelism.\"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        intermediate_size: int,\n        hidden_act: str,\n        quant_config: QuantizationConfig | None = None,\n        bias: bool = False,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.gate_up_proj = MergedColumnParallelLinear(\n            input_size=hidden_size,\n            output_sizes=[intermediate_size] * 2,\n            bias=bias,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.gate_up_proj\",\n        )\n        self.down_proj = RowParallelLinear(\n            input_size=intermediate_size,\n            output_size=hidden_size,\n            bias=bias,\n            
quant_config=quant_config,\n            prefix=f\"{prefix}.down_proj\",\n        )\n        if hidden_act != \"silu\":\n            raise ValueError(\n                f\"Unsupported activation: {hidden_act}. Only silu is supported.\"\n            )\n        self.act_fn = SiluAndMul()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x, _ = self.gate_up_proj(x)\n        x = self.act_fn(x)\n        x, _ = self.down_proj(x)\n        return x\n\n\nclass Qwen3Attention(nn.Module):\n    \"\"\"Qwen3 attention with QK-Norm and tensor parallelism.\n\n    Key difference from LLaMA: RMSNorm is applied to Q and K before attention.\n    \"\"\"\n\n    def __init__(\n        self,\n        config: Qwen3TextConfig,\n        hidden_size: int,\n        num_heads: int,\n        num_kv_heads: int,\n        rope_theta: float = 1000000.0,\n        rope_scaling: dict[str, Any] | None = None,\n        max_position_embeddings: int = 40960,\n        quant_config: QuantizationConfig | None = None,\n        bias: bool = False,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.hidden_size = hidden_size\n        tp_size = get_tp_world_size()\n        self.total_num_heads = num_heads\n        assert self.total_num_heads % tp_size == 0\n        self.num_heads = self.total_num_heads // tp_size\n        self.total_num_kv_heads = num_kv_heads\n        if self.total_num_kv_heads >= tp_size:\n            assert self.total_num_kv_heads % tp_size == 0\n        else:\n            assert tp_size % self.total_num_kv_heads == 0\n        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)\n\n        self.head_dim = getattr(\n            config, \"head_dim\", self.hidden_size // self.total_num_heads\n        )\n        self.rotary_dim = self.head_dim\n        self.q_size = self.num_heads * self.head_dim\n        self.kv_size = self.num_kv_heads * self.head_dim\n        self.scaling = self.head_dim**-0.5\n        self.rope_theta = rope_theta\n   
     self.max_position_embeddings = max_position_embeddings\n\n        # QKV projection with tensor parallelism\n        self.qkv_proj = QKVParallelLinear(\n            hidden_size=hidden_size,\n            head_size=self.head_dim,\n            total_num_heads=self.total_num_heads,\n            total_num_kv_heads=self.total_num_kv_heads,\n            bias=bias,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.qkv_proj\",\n        )\n\n        # Output projection\n        self.o_proj = RowParallelLinear(\n            input_size=self.total_num_heads * self.head_dim,\n            output_size=hidden_size,\n            bias=bias,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.o_proj\",\n        )\n\n        # QK-Norm: Key difference from LLaMA\n        rms_norm_eps = getattr(config, \"rms_norm_eps\", 1e-6)\n        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)\n        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)\n\n        # Rotary embeddings\n        self.rotary_emb = get_rope(\n            self.head_dim,\n            rotary_dim=self.rotary_dim,\n            max_position=max_position_embeddings,\n            base=int(rope_theta),\n            rope_scaling=rope_scaling,\n            is_neox_style=True,\n        )\n\n        # Attention with FlashAttention/SageAttn support\n        self.attn = LocalAttention(\n            self.num_heads,\n            self.head_dim,\n            self.num_kv_heads,\n            softmax_scale=self.scaling,\n            causal=True,\n            supported_attention_backends=config._supported_attention_backends,\n        )\n\n    def forward(\n        self,\n        positions: torch.Tensor,\n        hidden_states: torch.Tensor,\n    ) -> torch.Tensor:\n        # QKV projection\n        qkv, _ = self.qkv_proj(hidden_states)\n        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)\n\n        # Reshape for QK-norm\n        batch_size, seq_len = q.shape[0], 
q.shape[1]\n        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)\n        k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)\n        v = v.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)\n\n        # Apply QK-Norm (key difference from LLaMA)\n        q = self.q_norm(q)\n        k = self.k_norm(k)\n\n        # Reshape back for rotary embeddings\n        q = q.reshape(batch_size, seq_len, -1)\n        k = k.reshape(batch_size, seq_len, -1)\n\n        # Apply rotary embeddings\n        q, k = self.rotary_emb(positions, q, k)\n\n        # Reshape for attention\n        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)\n        k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)\n\n        # Attention\n        attn_output = self.attn(q, k, v)\n        attn_output = attn_output.reshape(batch_size, seq_len, -1)\n\n        # Output projection\n        output, _ = self.o_proj(attn_output)\n        return output\n\n\nclass Qwen3DecoderLayer(nn.Module):\n    \"\"\"Qwen3 transformer decoder layer.\"\"\"\n\n    def __init__(\n        self,\n        config: Qwen3TextConfig,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ) -> None:\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        rope_theta = config.rope_parameters[\"rope_theta\"]\n        rope_scaling = config.rope_parameters\n        max_position_embeddings = getattr(config, \"max_position_embeddings\", 40960)\n        attention_bias = getattr(config, \"attention_bias\", False)\n\n        self.self_attn = Qwen3Attention(\n            config=config,\n            hidden_size=self.hidden_size,\n            num_heads=config.num_attention_heads,\n            num_kv_heads=getattr(\n                config, \"num_key_value_heads\", config.num_attention_heads\n            ),\n            rope_theta=rope_theta,\n            rope_scaling=rope_scaling,\n            
max_position_embeddings=max_position_embeddings,\n            quant_config=quant_config,\n            bias=attention_bias,\n            prefix=f\"{prefix}.self_attn\",\n        )\n        self.mlp = Qwen3MLP(\n            hidden_size=self.hidden_size,\n            intermediate_size=config.intermediate_size,\n            hidden_act=config.hidden_act,\n            quant_config=quant_config,\n            bias=getattr(config, \"mlp_bias\", False),\n            prefix=f\"{prefix}.mlp\",\n        )\n        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        positions: torch.Tensor,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor | None,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        # Self Attention\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)\n\n        # MLP\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n        return hidden_states, residual\n\n\nclass Qwen3ForCausalLM(TextEncoder):\n    \"\"\"Qwen3 causal language model for text encoding in diffusion models.\n\n    Features:\n    - Tensor parallelism support\n    - FlashAttention/SageAttn/SDPA support via LocalAttention\n    - QK-Norm for better training stability\n    - FSDP sharding for CPU offload\n    \"\"\"\n\n    def __init__(self, config: Qwen3TextConfig) -> None:\n        super().__init__(config)\n\n        self.config = config\n        self.quant_config = config.quant_config\n\n        # Embedding layer with tensor parallelism\n   
     if config.lora_config is not None:\n            max_loras = getattr(config.lora_config, \"max_loras\", 1)\n            lora_vocab_size = getattr(config.lora_config, \"lora_extra_vocab_size\", 1)\n            lora_vocab = lora_vocab_size * max_loras\n        else:\n            lora_vocab = 0\n        self.vocab_size = config.vocab_size + lora_vocab\n        self.org_vocab_size = config.vocab_size\n\n        self.embed_tokens = VocabParallelEmbedding(\n            self.vocab_size,\n            config.hidden_size,\n            org_num_embeddings=config.vocab_size,\n            quant_config=config.quant_config,\n        )\n\n        # Transformer layers\n        self.layers = nn.ModuleList(\n            [\n                Qwen3DecoderLayer(\n                    config=config,\n                    quant_config=config.quant_config,\n                    prefix=f\"{config.prefix}.layers.{i}\",\n                )\n                for i in range(config.num_hidden_layers)\n            ]\n        )\n\n        # Final layer norm\n        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:\n        return self.embed_tokens(input_ids)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor | None = None,\n        position_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        inputs_embeds: torch.Tensor | None = None,\n        output_hidden_states: bool | None = None,\n        **kwargs,\n    ) -> BaseEncoderOutput:\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n\n        if inputs_embeds is not None:\n            hidden_states = inputs_embeds\n        else:\n            hidden_states = self.get_input_embeddings(input_ids)\n\n        residual = None\n\n        if position_ids is None:\n            
position_ids = torch.arange(\n                0, hidden_states.shape[1], device=hidden_states.device\n            ).unsqueeze(0)\n\n        all_hidden_states: tuple[Any, ...] | None = () if output_hidden_states else None\n\n        for layer in self.layers:\n            if all_hidden_states is not None:\n                all_hidden_states += (\n                    (hidden_states,)\n                    if residual is None\n                    else (hidden_states + residual,)\n                )\n            hidden_states, residual = layer(position_ids, hidden_states, residual)\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        # Add hidden states from the last decoder layer\n        if all_hidden_states is not None:\n            all_hidden_states += (hidden_states,)\n\n        return BaseEncoderOutput(\n            last_hidden_state=hidden_states,\n            hidden_states=all_hidden_states,\n        )\n\n    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:\n        \"\"\"Load weights with support for tensor parallelism and weight remapping.\"\"\"\n        params_dict = dict(self.named_parameters())\n        loaded_params: set[str] = set()\n\n        for name, loaded_weight in weights:\n            # Strip 'model.' 
prefix from HuggingFace Qwen3 weights\n            if name.startswith(\"model.\"):\n                name = name[6:]  # len(\"model.\") == 6\n\n            # Skip rotary embedding weights\n            if \"rotary_emb.inv_freq\" in name:\n                continue\n            if \"rotary_emb.cos_cached\" in name or \"rotary_emb.sin_cached\" in name:\n                continue\n\n            # Handle KV scale remapping\n            if \"scale\" in name:\n                kv_scale_name: str | None = maybe_remap_kv_scale_name(name, params_dict)\n                if kv_scale_name is None:\n                    continue\n                else:\n                    name = kv_scale_name\n\n            # Handle stacked params mapping (qkv_proj, gate_up_proj)\n            for (\n                param_name,\n                weight_name,\n                shard_id,\n            ) in self.config.arch_config.stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n\n                # Skip loading extra bias for GPTQ models\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n\n                if name not in params_dict:\n                    continue\n\n                param = params_dict[name]\n                weight_loader = param.weight_loader\n                weight_loader(param, loaded_weight, shard_id)\n                break\n            else:\n                # Skip loading extra bias for GPTQ models\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n\n                if name not in params_dict:\n                    continue\n\n                param = params_dict[name]\n                weight_loader = getattr(param, \"weight_loader\", default_weight_loader)\n                weight_loader(param, loaded_weight)\n\n            loaded_params.add(name)\n\n        return 
loaded_params\n\n\nEntryClass = Qwen3ForCausalLM\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/encoders/t5.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from transformers: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/t5/modeling_t5.py\n\n# Derived from T5 implementation posted on HuggingFace; license below:\n#\n# coding=utf-8\n# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch T5 & UMT5 model.\"\"\"\n\nimport math\nfrom collections.abc import Iterable\nfrom dataclasses import dataclass\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput, T5Config\nfrom sglang.multimodal_gen.runtime.distributed import _get_folding_tp_group\nfrom sglang.multimodal_gen.runtime.layers.activation import get_act_fn\nfrom sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm\nfrom sglang.multimodal_gen.runtime.layers.linear import (\n    MergedColumnParallelLinear,\n    QKVParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig\nfrom sglang.multimodal_gen.runtime.layers.utils import get_group_rank, get_group_size\nfrom sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import (\n    VocabParallelEmbedding,\n)\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import 
default_weight_loader\nfrom sglang.multimodal_gen.runtime.models.encoders.base import TextEncoder\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\n\n\nclass AttentionType:\n    \"\"\"\n    Attention type.\n    Use string to be compatible with `torch.compile`.\n    \"\"\"\n\n    # Decoder attention between previous layer Q/K/V\n    DECODER = \"decoder\"\n    # Encoder attention between previous layer Q/K/V for encoder-decoder\n    ENCODER = \"encoder\"\n    # Encoder attention between previous layer Q/K/V\n    ENCODER_ONLY = \"encoder_only\"\n    # Attention between dec. Q and enc. K/V for encoder-decoder\n    ENCODER_DECODER = \"encoder_decoder\"\n\n\n@dataclass\nclass AttentionMetadata:\n    attn_bias: torch.Tensor\n\n\nclass T5DenseActDense(nn.Module):\n\n    def __init__(\n        self, config: T5Config, quant_config: QuantizationConfig | None = None\n    ):\n        super().__init__()\n        tp_group = _get_folding_tp_group(config)\n        self.wi = MergedColumnParallelLinear(\n            config.d_model, [config.d_ff], bias=False, tp_group=tp_group\n        )\n        self.wo = RowParallelLinear(\n            config.d_ff,\n            config.d_model,\n            bias=False,\n            quant_config=quant_config,\n            tp_group=tp_group,\n        )\n        self.act = get_act_fn(config.dense_act_fn)\n\n    def forward(self, hidden_states) -> torch.Tensor:\n        hidden_states, _ = self.wi(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states, _ = self.wo(hidden_states)\n        return hidden_states\n\n\nclass T5DenseGatedActDense(nn.Module):\n\n    def __init__(\n        self, config: T5Config, quant_config: QuantizationConfig | None = None\n    ):\n        super().__init__()\n        tp_group = _get_folding_tp_group(config)\n        self.wi_0 = MergedColumnParallelLinear(\n            config.d_model,\n            [config.d_ff],\n            bias=False,\n            
quant_config=quant_config,\n            tp_group=tp_group,\n        )\n        self.wi_1 = MergedColumnParallelLinear(\n            config.d_model,\n            [config.d_ff],\n            bias=False,\n            quant_config=quant_config,\n            tp_group=tp_group,\n        )\n        # Should not run in fp16 unless mixed-precision is used,\n        # see https://github.com/huggingface/transformers/issues/20287.\n        self.wo = RowParallelLinear(\n            config.d_ff,\n            config.d_model,\n            bias=False,\n            quant_config=quant_config,\n            tp_group=tp_group,\n        )\n        self.act = get_act_fn(config.dense_act_fn)\n\n    def forward(self, hidden_states) -> torch.Tensor:\n        hidden_gelu = self.act(self.wi_0(hidden_states)[0])\n        hidden_linear, _ = self.wi_1(hidden_states)\n        hidden_states = hidden_gelu * hidden_linear\n        hidden_states, _ = self.wo(hidden_states)\n        return hidden_states\n\n\nclass T5LayerFF(nn.Module):\n\n    def __init__(\n        self, config: T5Config, quant_config: QuantizationConfig | None = None\n    ):\n        super().__init__()\n        if config.is_gated_act:\n            self.DenseReluDense = T5DenseGatedActDense(\n                config, quant_config=quant_config\n            )\n        else:\n            self.DenseReluDense = T5DenseActDense(config, quant_config=quant_config)\n\n        self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)\n\n    def forward(self, hidden_states) -> torch.Tensor:\n        forwarded_states = self.layer_norm(hidden_states)\n        forwarded_states = self.DenseReluDense(forwarded_states)\n        hidden_states = hidden_states + forwarded_states\n        return hidden_states\n\n\n# T5 has attn_bias and does not use softmax scaling\nclass T5MultiHeadAttention(nn.Module):\n\n    def __init__(self) -> None:\n        super().__init__()\n\n    def forward(self, q, k, v, attn_bias=None):\n        b, _, n, c = 
q.shape\n        attn = torch.einsum(\"binc,bjnc->bnij\", q, k)\n        if attn_bias is not None:\n            attn += attn_bias\n\n        attn = F.softmax(attn.float(), dim=-1).type_as(attn)\n        x = torch.einsum(\"bnij,bjnc->binc\", attn, v)\n        x = x.reshape(b, -1, n * c)\n        return x\n\n\nclass T5Attention(nn.Module):\n\n    def __init__(\n        self,\n        config: T5Config,\n        attn_type: str,\n        has_relative_attention_bias=False,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.attn_type = attn_type\n        # Cross-attention has no relative pos encoding anyway\n        self.is_decoder = attn_type == AttentionType.DECODER\n        self.has_relative_attention_bias = has_relative_attention_bias\n        self.relative_attention_num_buckets = config.relative_attention_num_buckets\n        self.relative_attention_max_distance = config.relative_attention_max_distance\n        self.d_model = config.d_model\n        self.key_value_proj_dim = config.d_kv\n        self.total_num_heads = self.total_num_kv_heads = config.num_heads\n\n        # Partition heads across multiple tensor parallel GPUs.\n        self.tp_group = _get_folding_tp_group(config)\n        self.tp_world_size = get_group_size(self.tp_group)\n        assert config.num_heads % self.tp_world_size == 0\n        self.n_heads = config.num_heads // self.tp_world_size\n\n        self.inner_dim = self.n_heads * self.key_value_proj_dim\n        # No GQA in t5.\n        # self.n_kv_heads = self.n_heads\n\n        self.qkv_proj = QKVParallelLinear(\n            self.d_model,\n            self.key_value_proj_dim,\n            self.total_num_heads,\n            self.total_num_kv_heads,\n            bias=False,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.qkv_proj\",\n            tp_group=self.tp_group,\n        )\n\n        self.attn = T5MultiHeadAttention()\n\n        if 
self.has_relative_attention_bias:\n            self.relative_attention_bias = VocabParallelEmbedding(\n                self.relative_attention_num_buckets,\n                self.total_num_heads,\n                org_num_embeddings=self.relative_attention_num_buckets,\n                padding_size=self.relative_attention_num_buckets,\n                quant_config=quant_config,\n                tp_group=self.tp_group,\n            )\n        self.o = RowParallelLinear(\n            self.total_num_heads * self.key_value_proj_dim,\n            self.d_model,\n            bias=False,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.o_proj\",\n            tp_group=self.tp_group,\n        )\n\n    @staticmethod\n    def _relative_position_bucket(\n        relative_position, bidirectional=True, num_buckets=32, max_distance=128\n    ) -> torch.Tensor:\n        \"\"\"\n        Adapted from Mesh Tensorflow:\n        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593\n        Translate relative position to a bucket number for relative attention.\n        The relative position is defined as memory_position - query_position,\n        i.e. the distance in tokens from the attending position to the\n        attended-to position. If bidirectional=False, then positive relative\n        positions are invalid. We use smaller buckets for small absolute\n        relative_position and larger buckets for larger absolute\n        relative_positions. All relative positions >=max_distance map to the\n        same bucket. All relative positions <=-max_distance map to the same\n        bucket. 
This should allow for more graceful generalization to longer\n        sequences than the model has been trained on.\n        Args:\n            relative_position: an integer Tensor\n            bidirectional: a boolean - whether the attention is bidirectional\n            num_buckets: an integer\n            max_distance: an integer\n        Returns:\n            a Tensor with the same shape as relative_position, containing integer\n            values in the range [0, num_buckets)\n        \"\"\"  # noqa: E501\n        relative_buckets = 0\n        if bidirectional:\n            num_buckets //= 2\n            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets\n            relative_position = torch.abs(relative_position)\n        else:\n            relative_position = -torch.min(\n                relative_position, torch.zeros_like(relative_position)\n            )\n        # now relative_position is in the range [0, inf)\n\n        # half of the buckets are for exact increments in positions\n        max_exact = num_buckets // 2\n        is_small = relative_position < max_exact\n\n        # The other half of the buckets are for logarithmically bigger bins\n        # in positions up to max_distance\n        relative_position_if_large = max_exact + (\n            torch.log(relative_position.float() / max_exact)\n            / math.log(max_distance / max_exact)\n            * (num_buckets - max_exact)\n        ).to(torch.long)\n        relative_position_if_large = torch.min(\n            relative_position_if_large,\n            torch.full_like(relative_position_if_large, num_buckets - 1),\n        )\n\n        relative_buckets += torch.where(\n            is_small, relative_position, relative_position_if_large\n        )\n        return relative_buckets\n\n    def compute_bias(self, query_length, key_length, device=None) -> torch.Tensor:\n        \"\"\"Compute binned relative position bias\"\"\"\n        if device is None:\n            device = 
self.relative_attention_bias.weight.device\n        context_position = torch.arange(query_length, dtype=torch.long, device=device)[\n            :, None\n        ]\n        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[\n            None, :\n        ]\n        # shape (query_length, key_length)\n        relative_position = memory_position - context_position\n        relative_position_bucket = self._relative_position_bucket(\n            relative_position,  # shape (query_length, key_length)\n            bidirectional=(not self.is_decoder),\n            num_buckets=self.relative_attention_num_buckets,\n            max_distance=self.relative_attention_max_distance,\n        )\n        values = self.relative_attention_bias(\n            relative_position_bucket\n        )  # shape (query_length, key_length, num_heads)\n        x = values.permute([2, 0, 1]).unsqueeze(\n            0\n        )  # shape (1, num_heads, query_length, key_length)\n        return x\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,  # (batch_size, seq_len, d_model)\n        attention_mask: torch.Tensor,\n        attn_metadata: AttentionMetadata | None = None,\n    ) -> torch.Tensor:\n        bs, seq_len, _ = hidden_states.shape\n        num_seqs = bs\n        n, c = (\n            self.n_heads,\n            self.key_value_proj_dim,\n        )\n        qkv, _ = self.qkv_proj(hidden_states)\n        # Projection of 'own' hidden state (self-attention). No GQA here.\n        q, k, v = qkv.split(self.inner_dim, dim=-1)\n        q = q.reshape(bs, seq_len, n, c)\n        k = k.reshape(bs, seq_len, n, c)\n        v = v.reshape(bs, seq_len, n, c)\n\n        assert attn_metadata is not None\n        attn_bias = attn_metadata.attn_bias\n        # Not compatible with CP here (as all encoder-decoder models),\n        # as it assumes a homogeneous batch (prefills or decodes).\n        if self.has_relative_attention_bias:\n            # Self-attention. 
Compute T5 relative positional encoding.\n            # The bias term is computed on longest sequence in batch. Biases\n            # for shorter sequences are slices of the longest.\n            assert self.attn_type == AttentionType.ENCODER\n            attn_bias = self.compute_bias(seq_len, seq_len).repeat(num_seqs, 1, 1, 1)\n            attn_metadata.attn_bias = attn_bias\n        else:\n            # Encoder/Decoder Self-Attention Layer, attn bias already cached.\n            assert attn_bias is not None\n\n        if attention_mask is not None:\n            attention_mask = (\n                attention_mask.view(bs, 1, 1, -1)\n                if attention_mask.ndim == 2\n                else attention_mask.unsqueeze(1)\n            )\n            mask_val = -1e4 if current_platform.is_mps() else torch.finfo(q.dtype).min\n            attn_bias.masked_fill_(attention_mask == 0, mask_val)\n\n        if self.tp_world_size > 1:\n            rank = get_group_rank(self.tp_group)\n            attn_bias = attn_bias[\n                :, rank * self.n_heads : (rank + 1) * self.n_heads, :, :\n            ]\n\n        attn_output = self.attn(q, k, v, attn_bias)\n        output, _ = self.o(attn_output)\n        return output\n\n\nclass T5LayerSelfAttention(nn.Module):\n\n    def __init__(\n        self,\n        config,\n        has_relative_attention_bias=False,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.SelfAttention = T5Attention(\n            config,\n            AttentionType.DECODER if \"decoder\" in prefix else AttentionType.ENCODER,\n            has_relative_attention_bias=has_relative_attention_bias,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.SelfAttention\",\n        )\n        self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        
attention_mask: torch.Tensor,\n        attn_metadata: AttentionMetadata | None = None,\n    ) -> torch.Tensor:\n        normed_hidden_states = self.layer_norm(hidden_states)\n\n        attention_output = self.SelfAttention(\n            hidden_states=normed_hidden_states,\n            attention_mask=attention_mask,\n            attn_metadata=attn_metadata,\n        )\n\n        hidden_states = hidden_states + attention_output\n\n        return hidden_states\n\n\nclass T5LayerCrossAttention(nn.Module):\n\n    def __init__(\n        self, config, quant_config: QuantizationConfig | None = None, prefix: str = \"\"\n    ):\n        super().__init__()\n        self.EncDecAttention = T5Attention(\n            config,\n            AttentionType.ENCODER_DECODER,\n            has_relative_attention_bias=False,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.EncDecAttention\",\n        )\n        self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attn_metadata: AttentionMetadata | None = None,\n    ) -> torch.Tensor:\n        normed_hidden_states = self.layer_norm(hidden_states)\n        attention_output = self.EncDecAttention(\n            hidden_states=normed_hidden_states,\n            attn_metadata=attn_metadata,\n        )\n        hidden_states = hidden_states + attention_output\n        return hidden_states\n\n\nclass T5Block(nn.Module):\n\n    def __init__(\n        self,\n        config: T5Config,\n        is_decoder: bool,\n        has_relative_attention_bias=False,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n        self.is_decoder = is_decoder\n        self.layer = nn.ModuleList()\n        self.layer.append(\n            T5LayerSelfAttention(\n                config,\n                has_relative_attention_bias=has_relative_attention_bias,\n                
quant_config=quant_config,\n                prefix=f\"{prefix}.self_attn\",\n            )\n        )\n\n        if self.is_decoder:\n            self.layer.append(\n                T5LayerCrossAttention(\n                    config, quant_config=quant_config, prefix=f\"{prefix}.cross_attn\"\n                )\n            )\n\n        self.layer.append(T5LayerFF(config, quant_config=quant_config))\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n        attn_metadata: AttentionMetadata | None = None,\n    ) -> torch.Tensor:\n\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                hidden_states.shape[:2], device=hidden_states.device\n            )\n\n        hidden_states = self.layer[0](\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            attn_metadata=attn_metadata,\n        )\n\n        if self.is_decoder:\n            hidden_states = self.layer[1](\n                hidden_states=hidden_states, attn_metadata=attn_metadata\n            )\n\n        # Apply Feed Forward layer\n        hidden_states = self.layer[-1](hidden_states)\n\n        return hidden_states\n\n\nclass T5Stack(nn.Module):\n\n    def __init__(\n        self,\n        config: T5Config,\n        is_decoder: bool,\n        n_layers: int,\n        embed_tokens=None,\n        quant_config: QuantizationConfig | None = None,\n        prefix: str = \"\",\n        is_umt5: bool = False,\n    ):\n        super().__init__()\n        self.embed_tokens = embed_tokens\n        self.is_umt5 = is_umt5\n        if is_umt5:\n            self.block = nn.ModuleList(\n                [\n                    T5Block(\n                        config,\n                        is_decoder=is_decoder,\n                        has_relative_attention_bias=True,\n                        quant_config=quant_config,\n                        prefix=f\"{prefix}.blocks.{i}\",\n       
             )\n                    for i in range(n_layers)\n                ]\n            )\n        else:\n            # Only the first block has relative positional encoding.\n            self.block = nn.ModuleList(\n                [\n                    T5Block(\n                        config,\n                        is_decoder=is_decoder,\n                        has_relative_attention_bias=i == 0,\n                        quant_config=quant_config,\n                        prefix=f\"{prefix}.blocks.{i}\",\n                    )\n                    for i in range(n_layers)\n                ]\n            )\n        self.final_layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        attention_mask: torch.Tensor,\n        attn_metadata: AttentionMetadata,\n    ) -> torch.Tensor:\n        hidden_states = self.embed_tokens(input_ids)\n\n        for idx, block in enumerate(self.block):\n            hidden_states = block(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n                attn_metadata=attn_metadata,\n            )\n\n        hidden_states = self.final_layer_norm(hidden_states)\n        return hidden_states\n\n\nclass T5EncoderModel(TextEncoder):\n\n    def __init__(self, config: T5Config, prefix: str = \"\"):\n        super().__init__(config)\n\n        quant_config = None\n        tp_group = _get_folding_tp_group(config)\n        self.shared = VocabParallelEmbedding(\n            config.vocab_size,\n            config.d_model,\n            org_num_embeddings=config.vocab_size,\n            tp_group=tp_group,\n        )\n\n        self.encoder = T5Stack(\n            config,\n            False,\n            config.num_layers,\n            self.shared,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.encoder\",\n            is_umt5=False,\n        )\n\n    def get_input_embeddings(self):\n   
     return self.shared\n\n    def forward(\n        self,\n        input_ids: torch.Tensor | None,\n        position_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        inputs_embeds: torch.Tensor | None = None,\n        output_hidden_states: bool | None = None,\n        **kwargs,\n    ) -> BaseEncoderOutput:\n        attn_metadata = AttentionMetadata(None)\n        hidden_states = self.encoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            attn_metadata=attn_metadata,\n        )\n\n        return BaseEncoderOutput(last_hidden_state=hidden_states)\n\n    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            (\".qkv_proj\", \".q\", \"q\"),\n            (\".qkv_proj\", \".k\", \"k\"),\n            (\".qkv_proj\", \".v\", \"v\"),\n        ]\n        params_dict = dict(self.named_parameters())\n        loaded_params: set[str] = set()\n        for name, loaded_weight in weights:\n            loaded = False\n            if \"decoder\" in name or \"lm_head\" in name:\n                continue\n            for param_name, weight_name, shard_id in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                # Skip loading extra bias for GPTQ models.\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n\n                if name not in params_dict:\n                    continue\n\n                param = params_dict[name]\n                weight_loader = param.weight_loader\n                weight_loader(param, loaded_weight, shard_id)\n                loaded = True\n                break\n            if not loaded:\n                # Skip loading extra bias for GPTQ models.\n                if 
name.endswith(\".bias\") and name not in params_dict:\n                    continue\n\n                if name not in params_dict:\n                    continue\n\n                param = params_dict[name]\n                weight_loader = getattr(param, \"weight_loader\", default_weight_loader)\n                weight_loader(param, loaded_weight)\n            loaded_params.add(name)\n        return loaded_params\n\n\nclass UMT5EncoderModel(TextEncoder):\n\n    def __init__(self, config: T5Config, prefix: str = \"\"):\n        super().__init__(config)\n\n        quant_config = None\n        tp_group = _get_folding_tp_group(config)\n        self.shared = VocabParallelEmbedding(\n            config.vocab_size,\n            config.d_model,\n            org_num_embeddings=config.vocab_size,\n            tp_group=tp_group,\n        )\n\n        self.encoder = T5Stack(\n            config,\n            False,\n            config.num_layers,\n            self.shared,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.encoder\",\n            is_umt5=True,\n        )\n\n    def get_input_embeddings(self):\n        return self.shared\n\n    def forward(\n        self,\n        input_ids: torch.Tensor | None,\n        position_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        inputs_embeds: torch.Tensor | None = None,\n        output_hidden_states: bool | None = None,\n        **kwargs,\n    ) -> BaseEncoderOutput:\n        attn_metadata = AttentionMetadata(None)\n        hidden_states = self.encoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            attn_metadata=attn_metadata,\n        )\n\n        return BaseEncoderOutput(\n            last_hidden_state=hidden_states,\n            attention_mask=attention_mask,\n        )\n\n    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:\n        params_dict = dict(self.named_parameters())\n        
loaded_params: set[str] = set()\n        for name, loaded_weight in weights:\n            loaded = False\n            if \"decoder\" in name or \"lm_head\" in name:\n                continue\n            for (\n                param_name,\n                weight_name,\n                shard_id,\n            ) in self.config.arch_config.stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                # Skip loading extra bias for GPTQ models.\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n\n                if name not in params_dict:\n                    continue\n\n                param = params_dict[name]\n                weight_loader = param.weight_loader\n                weight_loader(param, loaded_weight, shard_id)\n                loaded = True\n                break\n            if not loaded:\n                # Skip loading extra bias for GPTQ models.\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n\n                if name not in params_dict:\n                    continue\n\n                param = params_dict[name]\n                weight_loader = getattr(param, \"weight_loader\", default_weight_loader)\n                weight_loader(param, loaded_weight)\n            loaded_params.add(name)\n        return loaded_params\n\n\nEntryClass = [T5EncoderModel, UMT5EncoderModel]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/encoders/vision.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/vision.py\n\nfrom abc import ABC, abstractmethod\nfrom typing import Generic, TypeVar\n\nimport torch\nfrom transformers import PretrainedConfig\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n_C = TypeVar(\"_C\", bound=PretrainedConfig)\n\n\nclass VisionEncoderInfo(ABC, Generic[_C]):\n\n    def __init__(self, vision_config: _C) -> None:\n        super().__init__()\n\n        self.vision_config = vision_config\n\n    @abstractmethod\n    def get_num_image_tokens(\n        self,\n        *,\n        image_width: int,\n        image_height: int,\n    ) -> int:\n        raise NotImplementedError\n\n    @abstractmethod\n    def get_max_image_tokens(self) -> int:\n        raise NotImplementedError\n\n    @abstractmethod\n    def get_image_size(self) -> int:\n        raise NotImplementedError\n\n    @abstractmethod\n    def get_patch_size(self) -> int:\n        raise NotImplementedError\n\n    @abstractmethod\n    def get_patch_grid_length(self) -> int:\n        raise NotImplementedError\n\n\ndef resolve_visual_encoder_outputs(\n    encoder_outputs: torch.Tensor | list[torch.Tensor],\n    feature_sample_layers: list[int] | None,\n    post_layer_norm: torch.nn.LayerNorm | None,\n    max_possible_layers: int,\n) -> torch.Tensor:\n    \"\"\"Given the outputs a visual encoder module that may correspond to the\n    output of the last layer, or a list of hidden states to be stacked,\n    handle post normalization and resolve it into a single output tensor.\n\n    Args:\n        encoder_outputs: Output of encoder's last layer or all hidden states.\n        feature_sample_layers: Optional layer indices to grab from the encoder\n            outputs; if provided, encoder outputs must be a list.\n   
     post_layer_norm: Post norm to apply to the output of the encoder.\n        max_possible_layers: Total layers in the fully loaded visual encoder.\n\n    \"\"\"\n    if feature_sample_layers is None:\n        if post_layer_norm is not None:\n            return post_layer_norm(encoder_outputs)\n        return encoder_outputs\n\n    # Get the hidden states corresponding to the layer indices.\n    # Negative values are relative to the full visual encoder,\n    # so offset them depending on how many layers were loaded.\n    # NOTE: this assumes that encoder_outputs is a list containing\n    # the inputs to the visual encoder, followed by the hidden states\n    # of each layer.\n    num_loaded_layers = len(encoder_outputs) - 1\n    offset = max_possible_layers - num_loaded_layers\n    hs_pool = [\n        (\n            encoder_outputs[layer_idx]\n            if layer_idx >= 0\n            else encoder_outputs[layer_idx + offset]\n        )\n        for layer_idx in feature_sample_layers\n    ]\n\n    # Apply post-norm on the final hidden state if we are using it.\n    # encoder_outputs is a list here, so normalize the selected tensor itself.\n    uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)\n    if post_layer_norm is not None and uses_last_layer:\n        hs_pool[-1] = post_layer_norm(hs_pool[-1])\n    return torch.cat(hs_pool, dim=-1)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/parameter.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/parameter.py\n\nfrom collections.abc import Callable\nfrom fractions import Fraction\nfrom typing import Any\n\nimport torch\nfrom torch.nn import Parameter\n\nfrom sglang.multimodal_gen.runtime.distributed import get_tp_rank\nfrom sglang.multimodal_gen.runtime.models.utils import _make_synced_weight_loader\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass BasevLLMParameter(Parameter):\n    \"\"\"\n    Base parameter for vLLM linear layers. Extends the torch.nn.parameter\n    by taking in a linear weight loader. Will copy the loaded weight\n    into the parameter when the provided weight loader is called.\n    \"\"\"\n\n    def __new__(cls, data: torch.Tensor, **kwargs):\n\n        return super().__new__(cls, data=data, requires_grad=False)\n\n    def __init__(self, data: torch.Tensor, weight_loader: Callable):\n        \"\"\"\n        Initialize the BasevLLMParameter\n\n        :param data: torch tensor with the parameter data\n        :param weight_loader: weight loader callable\n\n        :returns: a torch.nn.parameter\n        \"\"\"\n\n        # During weight loading, we often do something like:\n        # narrowed_tensor = param.data.narrow(0, offset, len)\n        # narrowed_tensor.copy_(real_weight)\n        # expecting narrowed_tensor and param.data to share the same storage.\n        # However, on TPUs, narrowed_tensor will lazily propagate to the base\n        # tensor, which is param.data, leading to the redundant memory usage.\n        # This sometimes causes OOM errors during model loading. 
To avoid this,\n        # we sync the param tensor after its weight loader is called.\n        from sglang.multimodal_gen.runtime.platforms import current_platform\n\n        if current_platform.is_tpu():\n            weight_loader = _make_synced_weight_loader(weight_loader)\n\n        self._weight_loader = weight_loader\n\n    @property\n    def weight_loader(self):\n        return self._weight_loader\n\n    def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):\n        cond1 = self.data.ndim == 1 and self.data.numel() == 1\n        cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1\n        return cond1 and cond2\n\n    def _assert_and_load(self, loaded_weight: torch.Tensor) -> None:\n        assert self.data.shape == loaded_weight.shape or self._is_1d_and_scalar(\n            loaded_weight\n        )\n        self.data.copy_(loaded_weight)\n\n    def load_column_parallel_weight(self, loaded_weight: torch.Tensor) -> None:\n        self._assert_and_load(loaded_weight)\n\n    def load_row_parallel_weight(self, loaded_weight: torch.Tensor) -> None:\n        self._assert_and_load(loaded_weight)\n\n    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None:\n        self._assert_and_load(loaded_weight)\n\n    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None:\n        self._assert_and_load(loaded_weight)\n\n\nclass _ColumnvLLMParameter(BasevLLMParameter):\n    \"\"\"\n    Private class defining weight loading functionality\n    (load_merged_column_weight, load_qkv_weight)\n    for parameters being loaded into linear layers with column\n    parallelism. This includes QKV and MLP layers which are\n    not already fused on disk. Requires an output dimension\n    to be defined. 
Called within the weight loader of\n    each of the column parallel linear layers.\n    \"\"\"\n\n    def __init__(self, output_dim: int, **kwargs):\n        self._output_dim = output_dim\n        super().__init__(**kwargs)\n\n    @property\n    def output_dim(self):\n        return self._output_dim\n\n    def load_column_parallel_weight(self, loaded_weight: torch.Tensor) -> None:\n        tp_rank = get_tp_rank()\n        shard_size = self.data.shape[self.output_dim]\n        loaded_weight = loaded_weight.narrow(\n            self.output_dim, tp_rank * shard_size, shard_size\n        )\n        assert self.data.shape == loaded_weight.shape\n        self.data.copy_(loaded_weight)\n\n    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None:\n\n        shard_offset = kwargs.get(\"shard_offset\")\n        shard_size = kwargs.get(\"shard_size\")\n        if shard_offset is None or shard_size is None:\n            raise ValueError(\"shard_offset and shard_size must be provided\")\n        if (\n            isinstance(self, PackedColumnParameter | PackedvLLMParameter)\n            and self.packed_dim == self.output_dim\n        ):\n            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(\n                shard_offset=shard_offset, shard_size=shard_size\n            )\n\n        param_data = self.data\n\n        tp_rank = get_tp_rank()\n        param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)\n        loaded_weight = loaded_weight.narrow(\n            self.output_dim, tp_rank * shard_size, shard_size\n        )\n        assert param_data.shape == loaded_weight.shape\n        param_data.copy_(loaded_weight)\n\n    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None:\n\n        shard_offset = kwargs.get(\"shard_offset\")\n        shard_size = kwargs.get(\"shard_size\")\n        shard_id = kwargs.get(\"shard_id\")\n        num_heads = kwargs.get(\"num_heads\")\n\n        assert 
shard_offset is not None\n        assert shard_size is not None\n        assert shard_id is not None\n        assert num_heads is not None\n\n        if (\n            isinstance(self, PackedColumnParameter | PackedvLLMParameter)\n            and self.output_dim == self.packed_dim\n        ):\n            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(\n                shard_offset=shard_offset, shard_size=shard_size\n            )\n\n        param_data = self.data\n        tp_rank = get_tp_rank()\n        shard_id = tp_rank if shard_id == \"q\" else tp_rank // num_heads\n        param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)\n        loaded_weight = loaded_weight.narrow(\n            self.output_dim, shard_id * shard_size, shard_size\n        )\n\n        assert param_data.shape == loaded_weight.shape\n        param_data.copy_(loaded_weight)\n\n\nclass RowvLLMParameter(BasevLLMParameter):\n    \"\"\"\n    Parameter class defining weight_loading functionality\n    (load_row_parallel_weight) for parameters being loaded\n    into linear layers with row parallel functionality.\n    Requires an input_dim to be defined.\n    \"\"\"\n\n    def __init__(self, input_dim: int, **kwargs):\n        self._input_dim = input_dim\n        super().__init__(**kwargs)\n\n    @property\n    def input_dim(self):\n        return self._input_dim\n\n    def load_row_parallel_weight(self, loaded_weight: torch.Tensor) -> None:\n        tp_rank = get_tp_rank()\n        shard_size = self.data.shape[self.input_dim]\n        loaded_weight = loaded_weight.narrow(\n            self.input_dim, tp_rank * shard_size, shard_size\n        )\n\n        if len(loaded_weight.shape) == 0:\n            loaded_weight = loaded_weight.reshape(1)\n\n        assert self.data.shape == loaded_weight.shape\n        self.data.copy_(loaded_weight)\n\n\nclass ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):\n    \"\"\"\n    Parameter class for linear layer 
weights. Uses both column and\n    row parallelism.\n    \"\"\"\n\n    pass\n\n\nclass GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):\n    \"\"\"\n    Parameter class for weight scales loaded for weights with\n    grouped quantization. Uses both column and row parallelism.\n    \"\"\"\n\n    pass\n\n\nclass ChannelQuantScaleParameter(_ColumnvLLMParameter):\n    \"\"\"\n    Parameter class for weight scales loaded for weights with\n    channel-wise quantization. Equivalent to _ColumnvLLMParameter.\n    \"\"\"\n\n    pass\n\n\nclass PerTensorScaleParameter(BasevLLMParameter):\n    \"\"\"\n    Parameter class for scales where the number of scales is\n    equivalent to the number of logical matrices in fused linear\n    layers (e.g. for QKV, there are 3 scales loaded from disk).\n    This is relevant to weights with per-tensor quantization.\n    Adds functionality to map the scalers to a shard during\n    weight loading.\n\n    Note: additional parameter manipulation may be handled\n    for each quantization config specifically, within\n    process_weights_after_loading\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        self.qkv_idxs = {\"q\": 0, \"k\": 1, \"v\": 2}\n        super().__init__(**kwargs)\n\n    def _shard_id_as_int(self, shard_id: str | int) -> int:\n        if isinstance(shard_id, int):\n            return shard_id\n\n        # if not int, assume shard_id for qkv\n        # map to int and return\n        assert isinstance(shard_id, str)\n        assert shard_id in self.qkv_idxs\n        return self.qkv_idxs[shard_id]\n\n    # For row parallel layers, no sharding needed\n    # load weight into parameter as is\n    def load_row_parallel_weight(self, *args, **kwargs) -> None:\n        super().load_row_parallel_weight(*args, **kwargs)\n\n    def load_merged_column_weight(self, *args, **kwargs) -> None:\n        self._load_into_shard_id(*args, **kwargs)\n\n    def load_qkv_weight(self, *args, **kwargs) -> None:\n        
self._load_into_shard_id(*args, **kwargs)\n\n    def load_column_parallel_weight(self, *args, **kwargs) -> None:\n        super().load_row_parallel_weight(*args, **kwargs)\n\n    def _load_into_shard_id(\n        self, loaded_weight: torch.Tensor, shard_id: str | int, **kwargs\n    ):\n        \"\"\"\n        Slice the parameter data based on the shard id for\n        loading.\n        \"\"\"\n\n        param_data = self.data\n        shard_id = self._shard_id_as_int(shard_id)\n\n        # AutoFP8 scales do not have a shape\n        # compressed-tensors scales do have a shape\n        if len(loaded_weight.shape) != 0:\n            assert loaded_weight.shape[0] == 1\n            loaded_weight = loaded_weight[0]\n\n        param_data = param_data[shard_id]\n        assert param_data.shape == loaded_weight.shape\n        param_data.copy_(loaded_weight)\n\n\nclass PackedColumnParameter(_ColumnvLLMParameter):\n    \"\"\"\n    Parameter for model parameters which are packed on disk\n    and support column parallelism only. 
See PackedvLLMParameter\n    for more details on the packed properties.\n    \"\"\"\n\n    def __init__(self, packed_factor: int | Fraction, packed_dim: int, **kwargs):\n        self._packed_factor = packed_factor\n        self._packed_dim = packed_dim\n        super().__init__(**kwargs)\n\n    @property\n    def packed_dim(self):\n        return self._packed_dim\n\n    @property\n    def packed_factor(self):\n        return self._packed_factor\n\n    def adjust_shard_indexes_for_packing(\n        self, shard_size, shard_offset\n    ) -> tuple[Any, Any]:\n        return _adjust_shard_indexes_for_packing(\n            shard_size=shard_size,\n            shard_offset=shard_offset,\n            packed_factor=self.packed_factor,\n        )\n\n\nclass PackedvLLMParameter(ModelWeightParameter):\n    \"\"\"\n    Parameter for model weights which are packed on disk.\n    Example: GPTQ Marlin weights are int4 or int8, packed into int32.\n    Extends the ModelWeightParameter to take in the\n    packed factor, the packed dimension, and optionally, marlin\n    tile size for marlin kernels. 
Adjusts the shard_size and\n    shard_offset for fused linear layers model weight loading\n    by accounting for packing and optionally, marlin tile size.\n    \"\"\"\n\n    def __init__(self, packed_factor: int | Fraction, packed_dim: int, **kwargs):\n        self._packed_factor = packed_factor\n        self._packed_dim = packed_dim\n        super().__init__(**kwargs)\n\n    @property\n    def packed_dim(self):\n        return self._packed_dim\n\n    @property\n    def packed_factor(self):\n        return self._packed_factor\n\n    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):\n        return _adjust_shard_indexes_for_packing(\n            shard_size=shard_size,\n            shard_offset=shard_offset,\n            packed_factor=self.packed_factor,\n        )\n\n\nclass BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):\n    \"\"\"\n    Parameter class for weight scales loaded for weights with\n    block-wise quantization. Uses both column and row parallelism.\n    \"\"\"\n\n    pass\n\n\ndef permute_param_layout_(\n    param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs\n) -> BasevLLMParameter:\n    \"\"\"\n    Permute a parameter's layout to the specified input and output dimensions,\n    useful for forcing the parameter into a known layout, for example, if I need\n    a packed (quantized) weight matrix to be in the layout\n        {input_dim = 0, output_dim = 1, packed_dim = 0}\n    then I can call:\n        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)\n    to ensure x is in the correct layout (permuting it to the correct layout if\n    required, asserting if it cannot get it to the correct layout)\n    \"\"\"\n\n    curr_input_dim = getattr(param, \"input_dim\", None)\n    curr_output_dim = getattr(param, \"output_dim\", None)\n\n    if curr_input_dim is None or curr_output_dim is None:\n        assert param.data.dim() == 2, (\n            \"permute_param_layout_ only supports 2D 
parameters when either \"\n            \"input_dim or output_dim is not set\"\n        )\n\n    # if one of the dimensions is not set, set it to the opposite of the other\n    #  we can only do this since we asserted the parameter is 2D above\n    if curr_input_dim is None:\n        assert curr_output_dim is not None, \"either input or output dim must be set\"\n        curr_input_dim = (curr_output_dim + 1) % 2\n    if curr_output_dim is None:\n        assert curr_input_dim is not None, \"either input or output dim must be set\"\n        curr_output_dim = (curr_input_dim + 1) % 2\n\n    # create permutation from the current layout to the layout with\n    # self.input_dim at input_dim and self.output_dim at output_dim preserving\n    # other dimensions\n    perm = [\n        i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim]\n    ]\n    perm.insert(input_dim, curr_input_dim)\n    perm.insert(output_dim, curr_output_dim)\n\n    if \"packed_dim\" in kwargs:\n        assert (\n            hasattr(param, \"packed_dim\")\n            and param.packed_dim == perm[kwargs[\"packed_dim\"]]\n        ), \"permute_param_layout_ currently doesn't support repacking\"\n\n    param.data = param.data.permute(*perm)\n    if hasattr(param, \"_input_dim\"):\n        param._input_dim = input_dim\n    if hasattr(param, \"_output_dim\"):\n        param._output_dim = output_dim\n    if \"packed_dim\" in kwargs and hasattr(param, \"_packed_dim\"):\n        param._packed_dim = kwargs[\"packed_dim\"]\n\n    return param\n\n\ndef _adjust_shard_indexes_for_packing(\n    shard_size, shard_offset, packed_factor\n) -> tuple[Any, Any]:\n    shard_size = shard_size // packed_factor\n    shard_offset = shard_offset // packed_factor\n    return shard_size, shard_offset\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/registry.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/registry.py\n\nimport ast\nimport importlib\nimport os\nimport pickle\nimport subprocess\nimport sys\nimport tempfile\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Callable, Set\nfrom dataclasses import dataclass, field\nfrom functools import lru_cache\nfrom typing import NoReturn, TypeVar, cast\n\nimport cloudpickle\nfrom torch import nn\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\nMODELS_PATH = os.path.dirname(__file__)\nCOMPONENT_DIRS = [\n    d\n    for d in os.listdir(MODELS_PATH)\n    if os.path.isdir(os.path.join(MODELS_PATH, d))\n    and not d.startswith(\"__\")\n    and not d.startswith(\".\")\n]\n\n_IMAGE_ENCODER_MODELS: dict[str, tuple] = {\n    # \"HunyuanVideoTransformer3DModel\": (\"image_encoder\", \"hunyuanvideo\", \"HunyuanVideoImageEncoder\"),\n    \"CLIPVisionModelWithProjection\": (\"encoders\", \"clip\", \"CLIPVisionModel\"),\n}\n\n# Global alias mapping: external_path -> canonical_class_name\n_ALIAS_TO_MODEL: dict[str, str] = {}\n\n\ndef _parse_aliases_from_ast(value_node: ast.expr) -> list[str]:\n    \"\"\"Parse _aliases list from AST node.\"\"\"\n    aliases = []\n    if isinstance(value_node, (ast.List, ast.Tuple)):\n        for elt in value_node.elts:\n            if isinstance(elt, ast.Constant) and isinstance(elt.value, str):\n                aliases.append(elt.value)\n    return aliases\n\n\n@lru_cache(maxsize=None)\ndef _discover_and_register_models() -> dict[str, tuple[str, str, str]]:\n    discovered_models = dict(_IMAGE_ENCODER_MODELS)\n\n    # Collect class definitions with their _aliases\n    class_aliases: dict[str, list[str]] = {}\n\n    for component in COMPONENT_DIRS:\n        component_path = os.path.join(MODELS_PATH, 
component)\n        for filename in os.listdir(component_path):\n            if not filename.endswith(\".py\"):\n                continue\n\n            mod_relname = filename[:-3]\n            filepath = os.path.join(component_path, filename)\n            try:\n                with open(filepath, \"r\", encoding=\"utf-8\") as f:\n                    source = f.read()\n                tree = ast.parse(source, filename=filename)\n\n                entry_class_node = None\n                first_class_def = None\n\n                # Collect all class definitions and their _aliases\n                file_class_aliases: dict[str, list[str]] = {}\n                for node in ast.walk(tree):\n                    if isinstance(node, ast.ClassDef):\n                        if first_class_def is None:\n                            first_class_def = node\n                        # Look for _aliases in the class body\n                        for class_body_node in node.body:\n                            if isinstance(class_body_node, ast.Assign):\n                                for target in class_body_node.targets:\n                                    if (\n                                        isinstance(target, ast.Name)\n                                        and target.id == \"_aliases\"\n                                    ):\n                                        aliases = _parse_aliases_from_ast(\n                                            class_body_node.value\n                                        )\n                                        if aliases:\n                                            file_class_aliases[node.name] = aliases\n                    if isinstance(node, ast.Assign):\n                        for target in node.targets:\n                            if (\n                                isinstance(target, ast.Name)\n                                and target.id == \"EntryClass\"\n                            ):\n                               
 entry_class_node = node\n                                break\n\n                if entry_class_node and first_class_def:\n                    model_cls_name_list = []\n                    value_node = entry_class_node.value\n\n                    # EntryClass = ClassName\n                    if isinstance(value_node, ast.Name):\n                        model_cls_name_list.append(value_node.id)\n                    # EntryClass = [\"...\", ClassName, ...]\n                    elif isinstance(value_node, (ast.List, ast.Tuple)):\n                        for elt in value_node.elts:\n                            if isinstance(elt, ast.Constant):\n                                model_cls_name_list.append(elt.value)\n                            elif isinstance(elt, ast.Name):\n                                model_cls_name_list.append(elt.id)\n\n                    if model_cls_name_list:\n                        for model_cls_str in model_cls_name_list:\n                            if model_cls_str in discovered_models:\n                                logger.warning(\n                                    f\"Duplicate architecture found: {model_cls_str}. 
It will be overwritten.\"\n                                )\n                            model_arch = model_cls_str\n                            discovered_models[model_arch] = (\n                                component,\n                                mod_relname,\n                                model_cls_str,\n                            )\n                            # Collect aliases for this class\n                            if model_cls_str in file_class_aliases:\n                                class_aliases[model_cls_str] = file_class_aliases[\n                                    model_cls_str\n                                ]\n\n            except Exception as e:\n                logger.warning(f\"Could not parse {filepath} to find models: {e}\")\n\n    # Build alias -> canonical class name mapping\n    for class_name, aliases in class_aliases.items():\n        for alias in aliases:\n            if alias in _ALIAS_TO_MODEL:\n                logger.warning(\n                    f\"Alias '{alias}' already registered for '{_ALIAS_TO_MODEL[alias]}', \"\n                    f\"will be overwritten by '{class_name}'\"\n                )\n            _ALIAS_TO_MODEL[alias] = class_name\n\n    return discovered_models\n\n\n_SGLANG_DIFFUSION_MODELS = _discover_and_register_models()\n\n_SUBPROCESS_COMMAND = [\n    sys.executable,\n    \"-m\",\n    \"sglang.multimodal_gen.runtime.models.dits.registry\",\n]\n\n_T = TypeVar(\"_T\")\n\n\n@dataclass(frozen=True)\nclass _ModelInfo:\n    architecture: str\n\n    @staticmethod\n    def from_model_cls(model: type[nn.Module]) -> \"_ModelInfo\":\n        return _ModelInfo(\n            architecture=model.__name__,\n        )\n\n\nclass _BaseRegisteredModel(ABC):\n\n    @abstractmethod\n    def inspect_model_cls(self) -> _ModelInfo:\n        raise NotImplementedError\n\n    @abstractmethod\n    def load_model_cls(self) -> type[nn.Module]:\n        raise NotImplementedError\n\n\n@dataclass(frozen=True)\nclass 
_RegisteredModel(_BaseRegisteredModel):\n    \"\"\"\n    Represents a model that has already been imported in the main process.\n    \"\"\"\n\n    interfaces: _ModelInfo\n    model_cls: type[nn.Module]\n\n    @staticmethod\n    def from_model_cls(model_cls: type[nn.Module]):\n        return _RegisteredModel(\n            interfaces=_ModelInfo.from_model_cls(model_cls),\n            model_cls=model_cls,\n        )\n\n    def inspect_model_cls(self) -> _ModelInfo:\n        return self.interfaces\n\n    def load_model_cls(self) -> type[nn.Module]:\n        return self.model_cls\n\n\ndef _run_in_subprocess(fn: Callable[[], _T]) -> _T:\n    # NOTE: We use a temporary directory instead of a temporary file to avoid\n    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file\n    with tempfile.TemporaryDirectory() as tempdir:\n        output_filepath = os.path.join(tempdir, \"registry_output.tmp\")\n\n        # `cloudpickle` allows pickling lambda functions directly\n        input_bytes = cloudpickle.dumps((fn, output_filepath))\n\n        # cannot use `sys.executable __file__` here because the script\n        # contains relative imports\n        returned = subprocess.run(\n            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True\n        )\n\n        # check if the subprocess is successful\n        try:\n            returned.check_returncode()\n        except Exception as e:\n            # wrap raised exception to provide more information\n            raise RuntimeError(\n                f\"Error raised in subprocess:\\n\" f\"{returned.stderr.decode()}\"\n            ) from e\n\n        with open(output_filepath, \"rb\") as f:\n            return cast(_T, pickle.load(f))\n\n\n@dataclass(frozen=True)\nclass _LazyRegisteredModel(_BaseRegisteredModel):\n    \"\"\"\n    Represents a model that has not been imported in the main process.\n    \"\"\"\n\n    module_name: str\n    component_name: str\n    
class_name: str\n\n    # Performed in another process to avoid initializing CUDA\n    def inspect_model_cls(self) -> _ModelInfo:\n        return _run_in_subprocess(\n            lambda: _ModelInfo.from_model_cls(self.load_model_cls())\n        )\n\n    def load_model_cls(self) -> type[nn.Module]:\n        mod = importlib.import_module(self.module_name)\n        return cast(type[nn.Module], getattr(mod, self.class_name))\n\n\n@lru_cache(maxsize=128)\ndef _try_load_model_cls(\n    model_arch: str,\n    model: _BaseRegisteredModel,\n) -> type[nn.Module] | None:\n    from sglang.multimodal_gen.runtime.platforms import current_platform\n\n    current_platform.verify_model_arch(model_arch)\n    try:\n        return model.load_model_cls()\n    except Exception:\n        logger.exception(\"Ignore import error when loading '%s'\", model_arch)\n        return None\n\n\n@lru_cache(maxsize=128)\ndef _try_inspect_model_cls(\n    model_arch: str,\n    model: _BaseRegisteredModel,\n) -> _ModelInfo | None:\n    try:\n        return model.inspect_model_cls()\n    except Exception:\n        logger.exception(\"Error in inspecting model architecture '%s'\", model_arch)\n        return None\n\n\n@dataclass\nclass _ModelRegistry:\n    # Keyed by model_arch\n    registered_models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)\n\n    def get_supported_archs(self) -> Set[str]:\n        return self.registered_models.keys()\n\n    def resolve_by_alias(self, alias: str) -> type[nn.Module] | None:\n        \"\"\"Resolve a model class by its alias (external module path).\"\"\"\n        if alias in _ALIAS_TO_MODEL:\n            canonical_name = _ALIAS_TO_MODEL[alias]\n            return self._try_load_model_cls(canonical_name)\n        return None\n\n    def register_model(\n        self,\n        model_arch: str,\n        model_cls: type[nn.Module] | str,\n    ) -> None:\n        \"\"\"\n        Register an external model to be used in vLLM.\n\n        :code:`model_cls` can be 
either:\n\n        - A :class:`torch.nn.Module` class directly referencing the model.\n        - A string in the format :code:`<module>:<class>` which can be used to\n          lazily import the model. This is useful to avoid initializing CUDA\n          when importing the model and thus the related error\n          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.\n        \"\"\"\n        if model_arch in self.registered_models:\n            logger.warning(\n                \"Model architecture %s is already registered, and will be \"\n                \"overwritten by the new model class %s.\",\n                model_arch,\n                model_cls,\n            )\n\n        if isinstance(model_cls, str):\n            split_str = model_cls.split(\":\")\n            if len(split_str) != 2:\n                msg = \"Expected a string in the format `<module>:<class>`\"\n                raise ValueError(msg)\n\n            model = _LazyRegisteredModel(*split_str)\n        else:\n            model = _RegisteredModel.from_model_cls(model_cls)\n\n        self.registered_models[model_arch] = model\n\n    def _raise_for_unsupported(self, architectures: list[str]) -> NoReturn:\n        all_supported_archs = self.get_supported_archs()\n\n        if any(arch in all_supported_archs for arch in architectures):\n            raise ValueError(\n                f\"Model architectures {architectures} failed \"\n                \"to be inspected. Please check the logs for more details.\"\n            )\n\n        raise ValueError(\n            f\"Model architectures {architectures} are not supported for now. 
\"\n            f\"Supported architectures: {all_supported_archs}\"\n        )\n\n    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:\n        if model_arch not in self.registered_models:\n            return None\n\n        return _try_load_model_cls(model_arch, self.registered_models[model_arch])\n\n    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:\n        if model_arch not in self.registered_models:\n            return None\n\n        return _try_inspect_model_cls(model_arch, self.registered_models[model_arch])\n\n    def _normalize_archs(\n        self,\n        architectures: str | list[str],\n    ) -> list[str]:\n        if isinstance(architectures, str):\n            architectures = [architectures]\n        if not architectures:\n            logger.warning(\"No model architectures are specified\")\n\n        normalized_arch = []\n        for arch in architectures:\n            if arch not in self.registered_models:\n                registered_models = list(self.registered_models.keys())\n                raise Exception(\n                    f\"Unsupported model architecture: {arch}. 
Registered architectures: {registered_models}\"\n                )\n            normalized_arch.append(arch)\n        return normalized_arch\n\n    def inspect_model_cls(\n        self,\n        architectures: str | list[str],\n    ) -> tuple[_ModelInfo, str]:\n        architectures = self._normalize_archs(architectures)\n\n        for arch in architectures:\n            model_info = self._try_inspect_model_cls(arch)\n            if model_info is not None:\n                return (model_info, arch)\n\n        return self._raise_for_unsupported(architectures)\n\n    def resolve_model_cls(\n        self,\n        architectures: str | list[str],\n    ) -> tuple[type[nn.Module], str]:\n        architectures = self._normalize_archs(architectures)\n\n        for arch in architectures:\n            model_cls = self._try_load_model_cls(arch)\n            if model_cls is not None:\n                return (model_cls, arch)\n\n        return self._raise_for_unsupported(architectures)\n\n\nModelRegistry = _ModelRegistry(\n    {\n        model_arch: _LazyRegisteredModel(\n            module_name=f\"sglang.multimodal_gen.runtime.models.{component_name}.{mod_relname}\",\n            component_name=component_name,\n            class_name=cls_name,\n        )\n        for model_arch, (\n            component_name,\n            mod_relname,\n            cls_name,\n        ) in _SGLANG_DIFFUSION_MODELS.items()\n    }\n)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/schedulers/base.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nfrom abc import ABC, abstractmethod\n\nimport torch\n\n\nclass BaseScheduler(ABC):\n    timesteps: torch.Tensor\n    order: int\n    num_train_timesteps: int\n\n    def __init__(self, *args, **kwargs) -> None:\n        # Check if subclass has defined all required properties\n        required_attributes = [\"timesteps\", \"order\", \"num_train_timesteps\"]\n\n        for attr in required_attributes:\n            if not hasattr(self, attr):\n                raise AttributeError(\n                    f\"Subclasses of BaseScheduler must define '{attr}' property\"\n                )\n\n    @abstractmethod\n    def set_shift(self, shift: float) -> None:\n        pass\n\n    @abstractmethod\n    def set_timesteps(self, *args, **kwargs) -> None:\n        pass\n\n    @abstractmethod\n    def scale_model_input(\n        self, sample: torch.Tensor, timestep: int | None = None\n    ) -> torch.Tensor:\n        pass\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/schedulers/flow_match_pair.py",
    "content": "# Copied and adapted from: https://github.com/OpenMOSS/MOVA/tree/main/mova/diffusion/schedulers/flow_match.py and flow_match_pair.py\n# SPDX-License-Identifier: Apache-2.0\n\nfrom __future__ import annotations\n\nimport math\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler\n\n\nclass FlowMatchScheduler(BaseScheduler):\n    def __init__(\n        self,\n        num_inference_steps=100,\n        num_train_timesteps=1000,\n        shift=3.0,\n        sigma_max=1.0,\n        sigma_min=0.003 / 1.002,\n        inverse_timesteps=False,\n        extra_one_step=False,\n        reverse_sigmas=False,\n        exponential_shift=False,\n        exponential_shift_mu=None,\n        shift_terminal=None,\n    ):\n        self.order = 1\n        self.num_train_timesteps = num_train_timesteps\n        self.shift = shift\n        self.sigma_max = sigma_max\n        self.sigma_min = sigma_min\n        self.inverse_timesteps = inverse_timesteps\n        self.extra_one_step = extra_one_step\n        self.reverse_sigmas = reverse_sigmas\n        self.exponential_shift = exponential_shift\n        self.exponential_shift_mu = exponential_shift_mu\n        self.shift_terminal = shift_terminal\n        self.train_timesteps = None\n        self.train_sigmas = None\n        self.set_timesteps(num_train_timesteps)\n        self.set_timesteps(num_inference_steps)\n        BaseScheduler.__init__(self)\n\n    def set_shift(self, shift: float) -> None:\n        self.shift = shift\n\n    def set_timesteps(\n        self,\n        num_inference_steps=100,\n        denoising_strength=1.0,\n        training=False,\n        shift=None,\n        dynamic_shift_len=None,\n    ):\n        if shift is not None:\n            self.shift = shift\n        sigma_start = (\n            self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength\n        )\n        if self.extra_one_step:\n            self.sigmas = torch.linspace(\n         
       sigma_start, self.sigma_min, num_inference_steps + 1\n            )[:-1]\n        else:\n            self.sigmas = torch.linspace(\n                sigma_start, self.sigma_min, num_inference_steps\n            )\n        if self.inverse_timesteps:\n            self.sigmas = torch.flip(self.sigmas, dims=[0])\n        if self.exponential_shift:\n            mu = (\n                self.calculate_shift(dynamic_shift_len)\n                if dynamic_shift_len is not None\n                else self.exponential_shift_mu\n            )\n            self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1))\n        else:\n            self.sigmas = (\n                self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)\n            )\n        if self.shift_terminal is not None:\n            one_minus_z = 1 - self.sigmas\n            scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)\n            self.sigmas = 1 - (one_minus_z / scale_factor)\n        if self.reverse_sigmas:\n            self.sigmas = 1 - self.sigmas\n        self.timesteps = self.sigmas * self.num_train_timesteps\n        # Initialize train_timesteps on first set.\n        if self.train_timesteps is None:\n            self.train_timesteps = self.timesteps\n            self.train_sigmas = self.sigmas\n        if training:\n            x = self.timesteps\n            y = torch.exp(\n                -2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2\n            )\n            y_shifted = y - y.min()\n            bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())\n            self.linear_timesteps_weights = bsmntw_weighing\n            self.training = True\n        else:\n            self.training = False\n\n    def scale_model_input(self, sample: torch.Tensor, timestep: int | None = None):\n        return sample\n\n    def step(self, model_output, timestep, sample, to_final=False, **kwargs):\n        if isinstance(timestep, torch.Tensor):\n   
         timestep = timestep.cpu()\n        timestep_id = torch.argmin((self.timesteps - timestep).abs())\n        sigma = self.sigmas[timestep_id]\n        if to_final or timestep_id + 1 >= len(self.timesteps):\n            sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0\n        else:\n            sigma_ = self.sigmas[timestep_id + 1]\n        prev_sample = sample + model_output * (sigma_ - sigma)\n        return prev_sample\n\n    def return_to_timestep(self, timestep, sample, sample_stablized):\n        if isinstance(timestep, torch.Tensor):\n            timestep = timestep.cpu()\n        timestep_id = torch.argmin((self.timesteps - timestep).abs())\n        sigma = self.sigmas[timestep_id]\n        model_output = (sample - sample_stablized) / sigma\n        return model_output\n\n    def add_noise(self, original_samples, noise, timestep):\n        if isinstance(timestep, torch.Tensor):\n            timestep = timestep.cpu()\n        timestep_id = torch.argmin((self.timesteps - timestep).abs())\n        sigma = self.sigmas[timestep_id]\n        sample = (1 - sigma) * original_samples + sigma * noise\n        return sample\n\n    def training_target(self, sample, noise, timestep):\n        target = noise - sample\n        return target\n\n    def training_weight(self, timestep):\n        timestep_id = torch.argmin(\n            (self.timesteps - timestep.to(self.timesteps.device)).abs()\n        )\n        weights = self.linear_timesteps_weights[timestep_id]\n        return weights\n\n    def calculate_shift(\n        self,\n        image_seq_len,\n        base_seq_len: int = 256,\n        max_seq_len: int = 8192,\n        base_shift: float = 0.5,\n        max_shift: float = 0.9,\n    ):\n        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n        b = base_shift - m * base_seq_len\n        mu = image_seq_len * m + b\n        return mu\n\n\nclass FlowMatchPairScheduler(FlowMatchScheduler):\n    \"\"\"Pairing scheduler built 
on FlowMatchScheduler.\n\n    Provides a convenient pairing interface for timesteps or sigmas.\n\n    Attributes:\n        pair_timesteps: Cached timestep pairs of shape [num_timesteps, 2].\n        pair_sigmas: Cached sigma pairs of shape [num_timesteps, 2].\n    \"\"\"\n\n    def __init__(\n        self,\n        num_inference_steps=100,\n        num_train_timesteps=1000,\n        shift=3.0,\n        sigma_max=1.0,\n        sigma_min=0.003 / 1.002,\n        inverse_timesteps=False,\n        extra_one_step=False,\n        reverse_sigmas=False,\n        exponential_shift=False,\n        exponential_shift_mu=None,\n        shift_terminal=None,\n    ):\n        self._pair_postprocess_fn = None\n        self._pair_postprocess_requires_source = False\n        self.pair_timesteps: torch.Tensor | None = None\n        self.pair_sigmas: torch.Tensor | None = None\n        self.timesteps: torch.Tensor | None = None\n        self.sigmas: torch.Tensor | None = None\n        super().__init__(\n            num_inference_steps=num_inference_steps,\n            num_train_timesteps=num_train_timesteps,\n            shift=shift,\n            sigma_max=sigma_max,\n            sigma_min=sigma_min,\n            inverse_timesteps=inverse_timesteps,\n            extra_one_step=extra_one_step,\n            reverse_sigmas=reverse_sigmas,\n            exponential_shift=exponential_shift,\n            exponential_shift_mu=exponential_shift_mu,\n            shift_terminal=shift_terminal,\n        )\n\n    def set_pair_postprocess(self, fn):\n        \"\"\"Set a postprocess function to customize pairs after construction.\n\n        Args:\n            fn: Callable with signature fn(pairs: torch.Tensor) -> torch.Tensor.\n                The returned tensor must have the same shape as input pairs.\n\n        Raises:\n            TypeError: If fn is not callable or None.\n            RuntimeError: If scheduler is not initialized.\n        \"\"\"\n        if fn is not None and not callable(fn):\n  
          raise TypeError(\"pair_postprocess must be callable or None\")\n        self._pair_postprocess_fn = fn\n        self._pair_postprocess_requires_source = (\n            False if fn is None else bool(getattr(fn, \"_requires_source\", False))\n        )\n        if self.timesteps is None or self.sigmas is None:\n            raise RuntimeError(\"Scheduler not initialized; call set_timesteps() first\")\n        self._refresh_pair_cache()\n\n    def set_pair_postprocess_by_name(self, name: str | None, **kwargs):\n        \"\"\"Configure a postprocess function by name.\n\n        Supported names:\n            - None/\"none\"/\"off\"/\"false\"/\"no\": disable\n            - \"quadratic_perp_bulge_swap\": x2=x+d, y2=x-d, where d=4*amp*s*(1-s), s=t/T\n            - \"v2a_sequential\": assume pairs are (t,t); sample half sequence from column 0\n              with stride 2, then let column 0 follow this sequence first, followed by column 1\n            - \"a2v_sequential\": same as above, but column 1 first then column 0\n            - \"dual_sigma_shift\": use only timestep count; rebuild two columns independently using\n              FlowMatchScheduler sigma transform logic; configurable visual_shift/audio_shift\n\n        Args:\n            name: Postprocess name or None to disable.\n            **kwargs: Extra parameters for the named postprocess. 
For example:\n                - amp: Float amplitude, default 150.0.\n\n        Raises:\n            ValueError: If name is unknown.\n        \"\"\"\n\n        if name is None or str(name).lower() in (\"none\", \"off\", \"false\", \"no\"):\n            self.set_pair_postprocess(None)\n            return\n        if name == \"quadratic_perp_bulge_swap\":\n            amp = float(kwargs.get(\"amp\", 150.0))\n\n            def _quadratic_perp_bulge_swap(pairs: torch.Tensor):\n                if (\n                    not isinstance(pairs, torch.Tensor)\n                    or pairs.ndim != 2\n                    or pairs.shape[1] != 2\n                ):\n                    raise ValueError(\"pairs must be a torch.Tensor of shape [N, 2]\")\n                x = pairs[:, 0]\n                T = float(self.num_train_timesteps)\n                s = x / T\n                d = 4.0 * amp * s * (1.0 - s)\n                x2 = x + d\n                y2 = x - d\n                return torch.stack([x2, y2], dim=1)\n\n            self.set_pair_postprocess(_quadratic_perp_bulge_swap)\n            return\n        if name == \"v2a_sequential\":\n\n            def _v2a(pairs: torch.Tensor):\n                if (\n                    not isinstance(pairs, torch.Tensor)\n                    or pairs.ndim != 2\n                    or pairs.shape[1] != 2\n                ):\n                    raise ValueError(\"pairs must be a torch.Tensor of shape [N, 2]\")\n                N = pairs.shape[0]\n                base = pairs[:, 0]\n                seq_half = base[::2]\n                m = int(seq_half.shape[0])\n                col0 = torch.cat([seq_half, seq_half[-1:].repeat(m)], dim=0)[:N]\n                col1 = torch.cat([seq_half[0:1].repeat(m), seq_half], dim=0)[:N]\n                return torch.stack(\n                    [\n                        col0.to(dtype=pairs.dtype, device=pairs.device),\n                        col1.to(dtype=pairs.dtype, device=pairs.device),\n          
          ],\n                    dim=1,\n                )\n\n            self.set_pair_postprocess(_v2a)\n            return\n        if name == \"a2v_sequential\":\n\n            def _a2v(pairs: torch.Tensor):\n                if (\n                    not isinstance(pairs, torch.Tensor)\n                    or pairs.ndim != 2\n                    or pairs.shape[1] != 2\n                ):\n                    raise ValueError(\"pairs must be a torch.Tensor of shape [N, 2]\")\n                N = pairs.shape[0]\n                base = pairs[:, 0]\n                seq_half = base[::2]\n                m = int(seq_half.shape[0])\n                col0 = torch.cat([seq_half[0:1].repeat(m), seq_half], dim=0)[:N]\n                col1 = torch.cat([seq_half, seq_half[-1:].repeat(m)], dim=0)[:N]\n                return torch.stack(\n                    [\n                        col0.to(dtype=pairs.dtype, device=pairs.device),\n                        col1.to(dtype=pairs.dtype, device=pairs.device),\n                    ],\n                    dim=1,\n                )\n\n            self.set_pair_postprocess(_a2v)\n            return\n        if name == \"v2a\":\n\n            def _v2a_classic(pairs: torch.Tensor):\n                if (\n                    not isinstance(pairs, torch.Tensor)\n                    or pairs.ndim != 2\n                    or pairs.shape[1] != 2\n                ):\n                    raise ValueError(\"pairs must be a torch.Tensor of shape [N, 2]\")\n                zeros = torch.zeros_like(pairs[:, 0])\n                return torch.stack([zeros, pairs[:, 1]], dim=1)\n\n            self.set_pair_postprocess(_v2a_classic)\n            return\n        if name == \"a2v\":\n\n            def _a2v_classic(pairs: torch.Tensor):\n                if (\n                    not isinstance(pairs, torch.Tensor)\n                    or pairs.ndim != 2\n                    or pairs.shape[1] != 2\n                ):\n                    raise 
ValueError(\"pairs must be a torch.Tensor of shape [N, 2]\")\n                zeros = torch.zeros_like(pairs[:, 1])\n                return torch.stack([pairs[:, 0], zeros], dim=1)\n\n            self.set_pair_postprocess(_a2v_classic)\n            return\n        if name == \"dual_sigma_shift\":\n            visual_shift = float(kwargs.get(\"visual_shift\", self.shift))\n            audio_shift = float(kwargs.get(\"audio_shift\", self.shift))\n            visual_denoising_strength = float(\n                kwargs.get(\"visual_denoising_strength\", 1.0)\n            )\n            audio_denoising_strength = float(\n                kwargs.get(\"audio_denoising_strength\", 1.0)\n            )\n            visual_mu = kwargs.get(\n                \"visual_exponential_shift_mu\", self.exponential_shift_mu\n            )\n            audio_mu = kwargs.get(\n                \"audio_exponential_shift_mu\", self.exponential_shift_mu\n            )\n\n            def _dual_sigma_shift(pairs: torch.Tensor, *, source: str):\n                if not isinstance(pairs, torch.Tensor):\n                    raise TypeError(\"pairs must be a torch.Tensor\")\n                if pairs.ndim != 2 or pairs.shape[1] != 2:\n                    raise ValueError(\"pairs must be a torch.Tensor of shape [N, 2]\")\n                if pairs.shape[0] == 0:\n                    raise ValueError(\"pairs length must be greater than 0\")\n                if source not in (\"timesteps\", \"sigmas\"):\n                    raise ValueError(\"source must be 'timesteps' or 'sigmas'\")\n\n                num_steps = pairs.shape[0]\n                device = pairs.device\n                dtype = pairs.dtype\n\n                def _build_column(\n                    shift_value: float, denoising_strength: float, mu_override\n                ):\n                    if shift_value <= 0:\n                        raise ValueError(\"shift must be positive\")\n                    if denoising_strength <= 0:\n        
                raise ValueError(\"denoising_strength must be positive\")\n\n                    sigma_start = (\n                        self.sigma_min\n                        + (self.sigma_max - self.sigma_min) * denoising_strength\n                    )\n                    if self.extra_one_step:\n                        base = torch.linspace(\n                            sigma_start,\n                            self.sigma_min,\n                            num_steps + 1,\n                            device=device,\n                            dtype=dtype,\n                        )[:-1]\n                    else:\n                        base = torch.linspace(\n                            sigma_start,\n                            self.sigma_min,\n                            num_steps,\n                            device=device,\n                            dtype=dtype,\n                        )\n\n                    if self.inverse_timesteps:\n                        base = torch.flip(base, dims=[0])\n\n                    if self.exponential_shift:\n                        mu_value = mu_override\n                        if mu_value is None:\n                            raise RuntimeError(\n                                \"exponential_shift enabled but exponential_shift_mu is missing\"\n                            )\n                        exp_mu = math.exp(float(mu_value))\n                        base = exp_mu / (exp_mu + (1 / base - 1))\n                    else:\n                        base = shift_value * base / (1 + (shift_value - 1) * base)\n\n                    if self.shift_terminal is not None:\n                        one_minus_z = 1 - base\n                        scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)\n                        base = 1 - (one_minus_z / scale_factor)\n\n                    if self.reverse_sigmas:\n                        base = 1 - base\n\n                    if source == \"timesteps\":\n                   
     return base * self.num_train_timesteps\n                    return base\n\n                col0 = _build_column(visual_shift, visual_denoising_strength, visual_mu)\n                col1 = _build_column(audio_shift, audio_denoising_strength, audio_mu)\n                return torch.stack([col0, col1], dim=1)\n\n            _dual_sigma_shift._requires_source = True\n            self.set_pair_postprocess(_dual_sigma_shift)\n            return\n        raise ValueError(f\"Unknown pair_postprocess name: {name}\")\n\n    def _make_pairs_from_vector(self, vec: torch.Tensor) -> torch.Tensor:\n        if vec.ndim != 1:\n            raise ValueError(\"vec must be 1D\")\n        return torch.stack([vec, vec], dim=1)\n\n    def get_pairs(self, source: str = \"timesteps\") -> torch.Tensor:\n        if source == \"timesteps\":\n            if self.pair_timesteps is None:\n                self._refresh_pair_cache()\n            return self.pair_timesteps\n        if source == \"sigmas\":\n            if self.pair_sigmas is None:\n                self._refresh_pair_cache()\n            return self.pair_sigmas\n        raise ValueError(\"source must be 'timesteps' or 'sigmas'\")\n\n    def timestep_to_sigma(self, timestep: torch.Tensor | float) -> torch.Tensor:\n        \"\"\"Return sigma for a scalar timestep via nearest neighbor lookup.\n\n        Args:\n            timestep: Scalar timestep value.\n\n        Returns:\n            Sigma corresponding to the nearest timestep.\n        \"\"\"\n        t_value = float(timestep)\n        t_cpu = torch.tensor(t_value)\n        idx = torch.argmin((self.train_timesteps - t_cpu).abs())\n        return self.train_sigmas[idx]\n\n    def step_from_to(\n        self,\n        model_output: torch.Tensor,\n        timestep_from: torch.Tensor,\n        timestep_to: torch.Tensor | None,\n        sample: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Advance one step using an explicit (from, to) timestep pair.\n\n        The update rule 
is:\n            x_{to} = x_{from} + model_output * (sigma(to) - sigma(from))\n\n        Args:\n            model_output: Predicted model output.\n            timestep_from: Source timestep.\n            timestep_to: Target timestep or None for terminal.\n            sample: Current sample at timestep_from.\n\n        Returns:\n            Updated sample at timestep_to.\n        \"\"\"\n        sigma_from = self.timestep_to_sigma(timestep_from)\n        if timestep_to is None:\n            sigma_to = torch.tensor(\n                1.0 if (self.inverse_timesteps or self.reverse_sigmas) else 0.0,\n                device=sigma_from.device,\n                dtype=sigma_from.dtype,\n            )\n        else:\n            sigma_to = self.timestep_to_sigma(timestep_to)\n        prev_sample = sample + model_output * (sigma_to - sigma_from)\n        return prev_sample\n\n    def _refresh_pair_cache(self) -> None:\n        if self.timesteps is None or self.sigmas is None:\n            raise RuntimeError(\"Scheduler not initialized; call set_timesteps() first\")\n\n        def _apply_postprocess(pairs: torch.Tensor, source: str) -> torch.Tensor:\n            if self._pair_postprocess_fn is None:\n                return pairs\n            if self._pair_postprocess_requires_source:\n                modified = self._pair_postprocess_fn(pairs, source=source)\n            else:\n                modified = self._pair_postprocess_fn(pairs)\n            if not isinstance(modified, torch.Tensor):\n                raise TypeError(\"pair_postprocess must return a torch.Tensor\")\n            if modified.shape != pairs.shape:\n                raise ValueError(\"pair_postprocess must return the same shape as input\")\n            return modified\n\n        base_pairs_timesteps = self._make_pairs_from_vector(self.timesteps)\n        base_pairs_sigmas = self._make_pairs_from_vector(self.sigmas)\n\n        self.pair_timesteps = _apply_postprocess(base_pairs_timesteps, \"timesteps\")\n     
   self.pair_sigmas = _apply_postprocess(base_pairs_sigmas, \"sigmas\")\n\n\nEntryClass = FlowMatchPairScheduler\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/schedulers/hunyuan3d_scheduler.py",
    "content": "# Copied and adapted from: https://github.com/Tencent-Hunyuan/Hunyuan3D-2\nfrom __future__ import annotations\n\nimport math\nfrom dataclasses import dataclass\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.schedulers.scheduling_utils import SchedulerMixin\nfrom diffusers.utils import BaseOutput\n\n\n@dataclass\nclass Hunyuan3DFlowMatchSchedulerOutput(BaseOutput):\n    \"\"\"Output class for the scheduler's step function.\"\"\"\n\n    prev_sample: torch.FloatTensor\n\n\nclass Hunyuan3DFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):\n    \"\"\"Euler discrete scheduler for flow matching.\"\"\"\n\n    # External module path aliases for compatibility with Hunyuan3D configs\n    _aliases = [\n        \"hy3dgen.shapegen.schedulers.FlowMatchEulerDiscreteScheduler\",\n        \"hy3dshape.schedulers.FlowMatchEulerDiscreteScheduler\",\n    ]\n\n    _compatibles = []\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        shift: float = 1.0,\n        use_dynamic_shifting: bool = False,\n    ):\n        timesteps = np.linspace(\n            1, num_train_timesteps, num_train_timesteps, dtype=np.float32\n        ).copy()\n        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)\n\n        sigmas = timesteps / num_train_timesteps\n        if not use_dynamic_shifting:\n            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)\n\n        self.timesteps = sigmas * num_train_timesteps\n        self._step_index = None\n        self._begin_index = None\n\n        self.sigmas = sigmas.to(\"cpu\")\n        self.sigma_min = self.sigmas[-1].item()\n        self.sigma_max = self.sigmas[0].item()\n\n    @property\n    def step_index(self) -> Optional[int]:\n        \"\"\"The index counter for current timestep.\"\"\"\n        return 
self._step_index\n\n    @property\n    def begin_index(self) -> Optional[int]:\n        \"\"\"The index for the first timestep.\"\"\"\n        return self._begin_index\n\n    def set_begin_index(self, begin_index: int = 0):\n        \"\"\"Set the begin index for the scheduler.\n\n        Args:\n            begin_index: The begin index for the scheduler.\n        \"\"\"\n        self._begin_index = begin_index\n\n    def scale_model_input(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Optional[Union[float, torch.FloatTensor]] = None,\n    ) -> torch.FloatTensor:\n        \"\"\"Identity operation for flow matching (no input scaling needed).\"\"\"\n        return sample\n\n    def scale_noise(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[float, torch.FloatTensor],\n        noise: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        \"\"\"Forward process in flow-matching (add noise to sample).\"\"\"\n        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)\n\n        if sample.device.type == \"mps\" and torch.is_floating_point(timestep):\n            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)\n            timestep = timestep.to(sample.device, dtype=torch.float32)\n        else:\n            schedule_timesteps = self.timesteps.to(sample.device)\n            timestep = timestep.to(sample.device)\n\n        if self.begin_index is None:\n            step_indices = [\n                self.index_for_timestep(t, schedule_timesteps) for t in timestep\n            ]\n        elif self.step_index is not None:\n            step_indices = [self.step_index] * timestep.shape[0]\n        else:\n            step_indices = [self.begin_index] * timestep.shape[0]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < len(sample.shape):\n            sigma = sigma.unsqueeze(-1)\n\n        sample = sigma * noise + (1.0 - sigma) * sample\n  
      return sample\n\n    def _sigma_to_t(self, sigma: float) -> float:\n        \"\"\"Convert sigma to timestep.\"\"\"\n        return sigma * self.config.num_train_timesteps\n\n    def time_shift(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:\n        \"\"\"Apply time shift transformation.\"\"\"\n        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)\n\n    def set_timesteps(\n        self,\n        num_inference_steps: int = None,\n        device: Union[str, torch.device] = None,\n        sigmas: Optional[List[float]] = None,\n        mu: Optional[float] = None,\n    ):\n        \"\"\"Set the discrete timesteps for the diffusion chain.\"\"\"\n        if self.config.use_dynamic_shifting and mu is None:\n            raise ValueError(\n                \"Must pass a value for `mu` when `use_dynamic_shifting` is True\"\n            )\n\n        if sigmas is None:\n            self.num_inference_steps = num_inference_steps\n            timesteps = np.linspace(\n                self._sigma_to_t(self.sigma_max),\n                self._sigma_to_t(self.sigma_min),\n                num_inference_steps,\n            )\n            sigmas = timesteps / self.config.num_train_timesteps\n\n        if self.config.use_dynamic_shifting:\n            sigmas = self.time_shift(mu, 1.0, sigmas)\n        else:\n            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)\n\n        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)\n        timesteps = sigmas * self.config.num_train_timesteps\n\n        self.timesteps = timesteps.to(device=device)\n        self.sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])\n\n        self._step_index = None\n        self._begin_index = None\n\n    def index_for_timestep(\n        self, timestep: float, schedule_timesteps: Optional[torch.Tensor] = None\n    ) -> int:\n        \"\"\"Find the index for a given timestep.\"\"\"\n        if 
schedule_timesteps is None:\n            schedule_timesteps = self.timesteps\n\n        indices = (schedule_timesteps == timestep).nonzero()\n        pos = 1 if len(indices) > 1 else 0\n        return indices[pos].item()\n\n    def _init_step_index(self, timestep: Union[float, torch.Tensor]):\n        \"\"\"Initialize step index from timestep.\"\"\"\n        if self.begin_index is None:\n            if isinstance(timestep, torch.Tensor):\n                timestep = timestep.to(self.timesteps.device)\n            self._step_index = self.index_for_timestep(timestep)\n        else:\n            self._step_index = self._begin_index\n\n    def step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: Union[float, torch.FloatTensor],\n        sample: torch.FloatTensor,\n        s_churn: float = 0.0,\n        s_tmin: float = 0.0,\n        s_tmax: float = float(\"inf\"),\n        s_noise: float = 1.0,\n        generator: Optional[torch.Generator] = None,\n        return_dict: bool = True,\n    ) -> Union[Hunyuan3DFlowMatchSchedulerOutput, Tuple]:\n        \"\"\"Predict the sample from the previous timestep.\"\"\"\n        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):\n            raise ValueError(\n                \"Passing integer indices as timesteps is not supported. 
\"\n                \"Pass one of `scheduler.timesteps` as a timestep.\"\n            )\n\n        if self.step_index is None:\n            self._init_step_index(timestep)\n\n        # Upcast to avoid precision issues\n        sample = sample.to(torch.float32)\n\n        sigma = self.sigmas[self.step_index]\n        sigma_next = self.sigmas[self.step_index + 1]\n\n        prev_sample = sample + (sigma_next - sigma) * model_output\n        prev_sample = prev_sample.to(model_output.dtype)\n\n        self._step_index += 1\n\n        if not return_dict:\n            return (prev_sample,)\n\n        return Hunyuan3DFlowMatchSchedulerOutput(prev_sample=prev_sample)\n\n    def __len__(self) -> int:\n        return self.config.num_train_timesteps\n\n\n@dataclass\nclass Hunyuan3DConsistencyFlowMatchSchedulerOutput(BaseOutput):\n    \"\"\"Output for consistency flow matching scheduler.\"\"\"\n\n    prev_sample: torch.FloatTensor\n    pred_original_sample: torch.FloatTensor\n\n\nclass Hunyuan3DConsistencyFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):\n    \"\"\"Consistency Flow Matching Euler Discrete Scheduler.\"\"\"\n\n    # External module path aliases for compatibility with Hunyuan3D configs\n    _aliases = [\n        \"hy3dshape.schedulers.Hunyuan3DConsistencyFlowMatchEulerDiscreteScheduler\",\n    ]\n\n    _compatibles = []\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        pcm_timesteps: int = 50,\n    ):\n        sigmas = np.linspace(0, 1, num_train_timesteps)\n        step_ratio = num_train_timesteps // pcm_timesteps\n\n        euler_timesteps = (np.arange(1, pcm_timesteps) * step_ratio).round().astype(\n            np.int64\n        ) - 1\n        euler_timesteps = np.asarray([0] + euler_timesteps.tolist())\n\n        self.euler_timesteps = euler_timesteps\n        self.sigmas = sigmas[self.euler_timesteps]\n        self.sigmas = 
torch.from_numpy(self.sigmas.copy()).to(dtype=torch.float32)\n        self.timesteps = self.sigmas * num_train_timesteps\n        self._step_index = None\n        self._begin_index = None\n        self.sigmas = self.sigmas.to(\"cpu\")\n\n    @property\n    def step_index(self) -> Optional[int]:\n        return self._step_index\n\n    @property\n    def begin_index(self) -> Optional[int]:\n        return self._begin_index\n\n    def set_begin_index(self, begin_index: int = 0):\n        self._begin_index = begin_index\n\n    def scale_model_input(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Optional[Union[float, torch.FloatTensor]] = None,\n    ) -> torch.FloatTensor:\n        \"\"\"Identity operation for flow matching (no input scaling needed).\"\"\"\n        return sample\n\n    def _sigma_to_t(self, sigma: float) -> float:\n        return sigma * self.config.num_train_timesteps\n\n    def set_timesteps(\n        self,\n        num_inference_steps: int = None,\n        device: Union[str, torch.device] = None,\n        sigmas: Optional[List[float]] = None,\n    ):\n        \"\"\"Set timesteps for inference.\"\"\"\n        self.num_inference_steps = (\n            num_inference_steps if num_inference_steps is not None else len(sigmas)\n        )\n        inference_indices = np.linspace(\n            0, self.config.pcm_timesteps, num=self.num_inference_steps, endpoint=False\n        )\n        inference_indices = np.floor(inference_indices).astype(np.int64)\n        inference_indices = torch.from_numpy(inference_indices).long()\n\n        self.sigmas_ = self.sigmas[inference_indices]\n        timesteps = self.sigmas_ * self.config.num_train_timesteps\n        self.timesteps = timesteps.to(device=device)\n        self.sigmas_ = torch.cat(\n            [self.sigmas_, torch.ones(1, device=self.sigmas_.device)]\n        )\n\n        self._step_index = None\n        self._begin_index = None\n\n    def index_for_timestep(\n        self, timestep: 
float, schedule_timesteps: Optional[torch.Tensor] = None\n    ) -> int:\n        if schedule_timesteps is None:\n            schedule_timesteps = self.timesteps\n        indices = (schedule_timesteps == timestep).nonzero()\n        pos = 1 if len(indices) > 1 else 0\n        return indices[pos].item()\n\n    def _init_step_index(self, timestep: Union[float, torch.Tensor]):\n        if self.begin_index is None:\n            if isinstance(timestep, torch.Tensor):\n                timestep = timestep.to(self.timesteps.device)\n            self._step_index = self.index_for_timestep(timestep)\n        else:\n            self._step_index = self._begin_index\n\n    def step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: Union[float, torch.FloatTensor],\n        sample: torch.FloatTensor,\n        generator: Optional[torch.Generator] = None,\n        return_dict: bool = True,\n    ) -> Union[Hunyuan3DConsistencyFlowMatchSchedulerOutput, Tuple]:\n        \"\"\"Perform one step of the consistency flow matching scheduler.\"\"\"\n        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):\n            raise ValueError(\"Passing integer indices as timesteps is not supported.\")\n\n        if self.step_index is None:\n            self._init_step_index(timestep)\n\n        sample = sample.to(torch.float32)\n\n        sigma = self.sigmas_[self.step_index]\n        sigma_next = self.sigmas_[self.step_index + 1]\n\n        prev_sample = sample + (sigma_next - sigma) * model_output\n        prev_sample = prev_sample.to(model_output.dtype)\n\n        pred_original_sample = sample + (1.0 - sigma) * model_output\n        pred_original_sample = pred_original_sample.to(model_output.dtype)\n\n        self._step_index += 1\n\n        if not return_dict:\n            return (prev_sample,)\n\n        return Hunyuan3DConsistencyFlowMatchSchedulerOutput(\n            prev_sample=prev_sample, pred_original_sample=pred_original_sample\n        )\n\n    
def __len__(self) -> int:\n        return self.config.num_train_timesteps\n\n\n# Entry class for model registry\nEntryClass = [\n    Hunyuan3DFlowMatchEulerDiscreteScheduler,\n    Hunyuan3DConsistencyFlowMatchEulerDiscreteScheduler,\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_comfyui_passthrough.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nPass-through scheduler for ComfyUI integration.\n\nThis scheduler does not modify latents - it simply returns the input sample unchanged.\nThe actual denoising logic is handled by ComfyUI.\n\"\"\"\n\nimport torch\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.schedulers.scheduling_utils import SchedulerMixin\nfrom diffusers.utils import BaseOutput\n\nfrom sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass ComfyUIPassThroughSchedulerOutput(BaseOutput):\n    \"\"\"\n    Output class for the scheduler's `step` function output.\n\n    Args:\n        prev_sample (`torch.FloatTensor`): The input sample unchanged (pass-through).\n    \"\"\"\n\n    prev_sample: torch.FloatTensor\n\n\nclass ComfyUIPassThroughScheduler(BaseScheduler, ConfigMixin, SchedulerMixin):\n    \"\"\"\n    Pass-through scheduler for ComfyUI integration.\n\n    This scheduler does not modify latents. It is used when the denoising logic\n    is handled externally by ComfyUI. 
The scheduler simply returns the input\n    sample unchanged, allowing ComfyUI to manage the denoising process.\n\n    Usage:\n        - num_inference_steps is always 1 (each step is handled separately)\n        - timesteps are provided externally by ComfyUI\n        - step() returns the input sample unchanged\n    \"\"\"\n\n    config_name = \"scheduler_config.json\"\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps=1000,\n        *args,\n        **kwargs,\n    ):\n        self.num_train_timesteps = num_train_timesteps\n        # Initialize timesteps as empty - will be set externally\n        self.timesteps = torch.tensor([], dtype=torch.long)\n        self.shift = 0.0\n        self._step_index = 0  # Track current step index\n        self._begin_index: int | None = None  # For compatibility with DenoisingStage\n\n    def set_timesteps(\n        self,\n        num_inference_steps=1,  # Always 1 for ComfyUI\n        timesteps=None,  # Can be provided externally\n        device=None,\n        **kwargs,\n    ):\n        \"\"\"\n        Set timesteps. 
For ComfyUI, timesteps are provided externally.\n\n        Args:\n            num_inference_steps: Ignored (always 1 for ComfyUI)\n            timesteps: External timesteps provided by ComfyUI\n            device: Device to place timesteps on\n        \"\"\"\n        if timesteps is not None:\n            # Use externally provided timesteps\n            if isinstance(timesteps, torch.Tensor):\n                self.timesteps = timesteps\n            else:\n                self.timesteps = torch.tensor(timesteps, dtype=torch.long)\n            if device is not None:\n                self.timesteps = self.timesteps.to(device)\n        else:\n            # Create a single timestep if none provided\n            if device is None:\n                device = torch.device(\"cpu\")\n            self.timesteps = torch.tensor([0], dtype=torch.long, device=device)\n\n    def step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: torch.FloatTensor | int,\n        sample: torch.FloatTensor,\n        return_dict: bool = False,\n        **kwargs,\n    ) -> tuple | ComfyUIPassThroughSchedulerOutput:\n        \"\"\"\n        Pass-through step: returns the input sample unchanged.\n\n        This scheduler does not modify latents. 
The actual denoising is handled\n        by ComfyUI, so we simply return the input sample as-is.\n\n        Args:\n            model_output: Predicted noise (ignored, but kept for API compatibility)\n            timestep: Current timestep (ignored, but kept for API compatibility)\n            sample: Input latents (returned unchanged)\n            return_dict: Whether to return a dict or tuple\n\n        Returns:\n            The input sample unchanged (prev_sample = sample)\n        \"\"\"\n        # Increment step index for tracking\n        self._step_index += 1\n\n        # Simply return the input sample unchanged\n        prev_sample = sample\n\n        if not return_dict:\n            return (prev_sample,)\n\n        return ComfyUIPassThroughSchedulerOutput(prev_sample=prev_sample)\n\n    def scale_model_input(\n        self, sample: torch.Tensor, timestep: int | None = None\n    ) -> torch.Tensor:\n        \"\"\"\n        Scale model input. For pass-through scheduler, returns input unchanged.\n\n        Args:\n            sample: Input sample\n            timestep: Timestep (ignored)\n\n        Returns:\n            Input sample unchanged\n        \"\"\"\n        return sample\n\n    def set_shift(self, shift: float) -> None:\n        \"\"\"\n        Set shift parameter (no-op for pass-through scheduler).\n\n        Args:\n            shift: Shift value (ignored)\n        \"\"\"\n        self.shift = shift\n\n    def set_begin_index(self, begin_index: int = 0) -> None:\n        \"\"\"\n        Sets the begin index for the scheduler. 
This function should be run from pipeline before the inference.\n\n        Args:\n            begin_index: The begin index for the scheduler.\n        \"\"\"\n        self._begin_index = begin_index\n\n    @property\n    def begin_index(self) -> int | None:\n        \"\"\"\n        The index for the first timestep.\n        \"\"\"\n        return self._begin_index\n\n    @property\n    def step_index(self) -> int:\n        \"\"\"\n        The index counter for current timestep.\n        \"\"\"\n        return self._step_index\n\n    def add_noise(\n        self,\n        original_samples: torch.Tensor,\n        noise: torch.Tensor,\n        timestep: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"\n        Add noise to samples. For pass-through scheduler, returns original samples.\n\n        Args:\n            original_samples: Original clean samples\n            noise: Noise to add (ignored)\n            timestep: Timestep (ignored)\n\n        Returns:\n            Original samples unchanged\n        \"\"\"\n        return original_samples\n\n\nEntryClass = ComfyUIPassThroughScheduler\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_dpm_solver_multistep.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n#\n# DPM-Solver++ multistep scheduler wrapper for SANA.\n#\n# SANA uses DPM-Solver++ (Lu et al., 2022) as its noise scheduler, which\n# is a high-order ODE solver that converges in fewer steps than DDIM.\n# With solver_order=2 and 20 steps, SANA achieves high-quality results.\n#\n# This wrapper delegates all numerical work to diffusers' implementation\n# and only adapts the interface for sglang's denoising stage.\n\nimport torch\nfrom diffusers import (\n    DPMSolverMultistepScheduler as DiffusersDPMSolverMultistepScheduler,\n)\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.schedulers.scheduling_utils import SchedulerMixin\n\nfrom sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler\n\n\nclass DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):\n    \"\"\"DPM-Solver++ multistep scheduler wrapper for sglang's BaseScheduler interface.\"\"\"\n\n    order = 1\n    num_train_timesteps = 1000\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        beta_start: float = 0.0001,\n        beta_end: float = 0.02,\n        beta_schedule: str = \"scaled_linear\",\n        trained_betas=None,\n        solver_order: int = 2,\n        prediction_type: str = \"epsilon\",\n        thresholding: bool = False,\n        dynamic_thresholding_ratio: float = 0.995,\n        sample_max_value: float = 1.0,\n        algorithm_type: str = \"dpmsolver++\",\n        solver_type: str = \"midpoint\",\n        lower_order_final: bool = True,\n        euler_at_final: bool = False,\n        use_karras_sigmas: bool = False,\n        use_lu_lambdas: bool = False,\n        use_exponential_sigmas: bool = False,\n        use_beta_sigmas: bool = False,\n        use_flow_sigmas: bool = False,\n        final_sigmas_type: str = \"zero\",\n        lambda_min_clipped: float = -float(\"inf\"),\n        variance_type: 
str | None = None,\n        timestep_spacing: str = \"linspace\",\n        steps_offset: int = 0,\n        rescale_betas_zero_snr: bool = False,\n        flow_shift: float | None = None,\n        **kwargs,\n    ):\n        self.num_train_timesteps = num_train_timesteps\n        self._inner = DiffusersDPMSolverMultistepScheduler(\n            num_train_timesteps=num_train_timesteps,\n            beta_start=beta_start,\n            beta_end=beta_end,\n            beta_schedule=beta_schedule,\n            trained_betas=trained_betas,\n            solver_order=solver_order,\n            prediction_type=prediction_type,\n            thresholding=thresholding,\n            dynamic_thresholding_ratio=dynamic_thresholding_ratio,\n            sample_max_value=sample_max_value,\n            algorithm_type=algorithm_type,\n            solver_type=solver_type,\n            lower_order_final=lower_order_final,\n            euler_at_final=euler_at_final,\n            use_karras_sigmas=use_karras_sigmas,\n            use_lu_lambdas=use_lu_lambdas,\n            use_exponential_sigmas=use_exponential_sigmas,\n            use_beta_sigmas=use_beta_sigmas,\n            use_flow_sigmas=use_flow_sigmas,\n            flow_shift=flow_shift,\n            final_sigmas_type=final_sigmas_type,\n            lambda_min_clipped=lambda_min_clipped,\n            variance_type=variance_type,\n            timestep_spacing=timestep_spacing,\n            steps_offset=steps_offset,\n            rescale_betas_zero_snr=rescale_betas_zero_snr,\n        )\n        self.timesteps = self._inner.timesteps\n        self.order = solver_order\n        self._flow_shift = flow_shift\n        self._begin_index: int | None = None\n        BaseScheduler.__init__(self)\n\n    def set_shift(self, shift: float) -> None:\n        self._flow_shift = shift\n\n    def set_begin_index(self, begin_index: int = 0) -> None:\n        self._begin_index = begin_index\n\n    @property\n    def begin_index(self) -> int | None:\n     
   return self._begin_index\n\n    def set_timesteps(self, num_inference_steps: int, device=None, **kwargs):\n        self._inner.set_timesteps(num_inference_steps, device=device, **kwargs)\n        self.timesteps = self._inner.timesteps\n\n    def scale_model_input(\n        self, sample: torch.Tensor, timestep: int | None = None\n    ) -> torch.Tensor:\n        return self._inner.scale_model_input(sample, timestep)\n\n    def step(\n        self,\n        model_output: torch.Tensor,\n        timestep: int,\n        sample: torch.Tensor,\n        **kwargs,\n    ):\n        return self._inner.step(model_output, timestep, sample, **kwargs)\n\n    @property\n    def sigmas(self):\n        return getattr(self._inner, \"sigmas\", None)\n\n    @property\n    def init_noise_sigma(self):\n        return self._inner.init_noise_sigma\n\n    def add_noise(\n        self,\n        original_samples: torch.Tensor,\n        noise: torch.Tensor,\n        timesteps: torch.Tensor,\n    ) -> torch.Tensor:\n        return self._inner.add_noise(original_samples, noise, timesteps)\n\n\nEntryClass = DPMSolverMultistepScheduler\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\n# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n#\n# Modified from diffusers==0.29.2\n#\n# ==============================================================================\nimport math\nfrom dataclasses import dataclass\nfrom typing import Any\n\nimport numpy as np\nimport scipy.stats\nimport torch\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.schedulers.scheduling_utils import SchedulerMixin\nfrom diffusers.utils import BaseOutput\n\nfrom sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n@dataclass\nclass FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):\n    \"\"\"\n    Output class for the scheduler's `step` function output.\n\n    Args:\n        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):\n            Computed sample `(x_{t-1})` of previous timestep. 
`prev_sample` should be used as next model input in the\n            denoising loop.\n    \"\"\"\n\n    prev_sample: torch.FloatTensor\n\n\nclass FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):\n    \"\"\"\n    Euler scheduler.\n\n    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic\n    methods the library implements for all schedulers such as loading and saving.\n\n    Args:\n        num_train_timesteps (`int`, defaults to 1000):\n            The number of diffusion steps to train the model.\n        shift (`float`, defaults to 1.0):\n            The shift value for the timestep schedule.\n        use_dynamic_shifting (`bool`, defaults to False):\n            Whether to apply timestep shifting on-the-fly based on the image resolution.\n        base_shift (`float`, defaults to 0.5):\n            Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent\n            with desired output.\n        max_shift (`float`, defaults to 1.15):\n            Value change allowed to latent vectors. 
Increasing `max_shift` encourages more variation and image may be\n            more exaggerated or stylized.\n        base_image_seq_len (`int`, defaults to 256):\n            The base image sequence length.\n        max_image_seq_len (`int`, defaults to 4096):\n            The maximum image sequence length.\n        invert_sigmas (`bool`, defaults to False):\n            Whether to invert the sigmas.\n        shift_terminal (`float`, defaults to None):\n            The end value of the shifted timestep schedule.\n        use_karras_sigmas (`bool`, defaults to False):\n            Whether to use Karras sigmas for step sizes in the noise schedule during sampling.\n        use_exponential_sigmas (`bool`, defaults to False):\n            Whether to use exponential sigmas for step sizes in the noise schedule during sampling.\n        use_beta_sigmas (`bool`, defaults to False):\n            Whether to use beta sigmas for step sizes in the noise schedule during sampling.\n        time_shift_type (`str`, defaults to \"exponential\"):\n            The type of dynamic resolution-dependent timestep shifting to apply. 
Either \"exponential\" or \"linear\".\n        stochastic_sampling (`bool`, defaults to False):\n            Whether to use stochastic sampling.\n    \"\"\"\n\n    _compatibles: list[Any] = []\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        shift: float = 1.0,\n        use_dynamic_shifting: bool = False,\n        base_shift: float | None = 0.5,\n        max_shift: float | None = 1.15,\n        base_image_seq_len: int | None = 256,\n        max_image_seq_len: int | None = 4096,\n        invert_sigmas: bool = False,\n        shift_terminal: float | None = None,\n        use_karras_sigmas: bool | None = False,\n        use_exponential_sigmas: bool | None = False,\n        use_beta_sigmas: bool | None = False,\n        time_shift_type: str = \"exponential\",\n        stochastic_sampling: bool = False,\n    ):\n        if (\n            sum(\n                [\n                    self.config.use_beta_sigmas,\n                    self.config.use_exponential_sigmas,\n                    self.config.use_karras_sigmas,\n                ]\n            )\n            > 1\n        ):\n            raise ValueError(\n                \"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.\"\n            )\n        if time_shift_type not in {\"exponential\", \"linear\"}:\n            raise ValueError(\n                \"`time_shift_type` must either be 'exponential' or 'linear'.\"\n            )\n\n        timesteps = np.linspace(\n            1, num_train_timesteps, num_train_timesteps, dtype=np.float32\n        )[::-1].copy()\n        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)\n\n        sigmas = timesteps / num_train_timesteps\n        if not use_dynamic_shifting:\n            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution\n            sigmas = shift * sigmas / 
(1 + (shift - 1) * sigmas)\n\n        self.timesteps = sigmas * num_train_timesteps\n        self.num_train_timesteps = num_train_timesteps\n\n        self._step_index: int | None = None\n        self._begin_index: int | None = None\n\n        self._shift = shift\n\n        self.sigmas = sigmas.to(\"cpu\")  # to avoid too much CPU/GPU communication\n        self.sigma_min = self.sigmas[-1].item()\n        self.sigma_max = self.sigmas[0].item()\n        BaseScheduler.__init__(self)\n\n    @property\n    def shift(self) -> float:\n        \"\"\"\n        The value used for shifting.\n        \"\"\"\n        return self._shift\n\n    @property\n    def step_index(self) -> int | None:\n        \"\"\"\n        The index counter for the current timestep. It increases by 1 after each scheduler step.\n        \"\"\"\n        return self._step_index\n\n    @property\n    def begin_index(self) -> int | None:\n        \"\"\"\n        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.\n        \"\"\"\n        return self._begin_index\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index\n    def set_begin_index(self, begin_index: int = 0) -> None:\n        \"\"\"\n        Sets the begin index for the scheduler. 
This function should be run from pipeline before the inference.\n\n        Args:\n            begin_index (`int`):\n                The begin index for the scheduler.\n        \"\"\"\n        self._begin_index = begin_index\n\n    def set_shift(self, shift: float) -> None:\n        self._shift = shift\n\n    def scale_noise(\n        self,\n        sample: torch.FloatTensor,\n        timestep: float | torch.FloatTensor,\n        noise: torch.FloatTensor | None = None,\n    ) -> torch.FloatTensor:\n        \"\"\"\n        Forward process in flow-matching\n        \"\"\"\n        # Make sure sigmas and timesteps have the same device and dtype as original_samples\n        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)\n\n        if sample.device.type == \"mps\" and torch.is_floating_point(timestep):\n            # mps does not support float64\n            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)\n            assert isinstance(timestep, torch.Tensor)\n            timestep = timestep.to(sample.device, dtype=torch.float32)\n        else:\n            schedule_timesteps = self.timesteps.to(sample.device)\n            assert isinstance(timestep, torch.Tensor)\n            timestep = timestep.to(sample.device)\n\n        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index\n        if self.begin_index is None:\n            step_indices = [\n                self.index_for_timestep(t, schedule_timesteps) for t in timestep\n            ]\n        elif self.step_index is not None:\n            # add_noise is called after first denoising step (for inpainting)\n            step_indices = [self.step_index] * timestep.shape[0]\n        else:\n            # add noise is called before first denoising step to create initial latent(img2img)\n            step_indices = [self.begin_index] * timestep.shape[0]\n\n        sigma = sigmas[step_indices].flatten()\n        while 
len(sigma.shape) < len(sample.shape):\n            sigma = sigma.unsqueeze(-1)\n\n        sample = sigma * noise + (1.0 - sigma) * sample\n\n        return sample\n\n    def _sigma_to_t(self, sigma: float) -> float:\n        return sigma * self.config.num_train_timesteps\n\n    def time_shift(\n        self, mu: float, sigma: float, t: torch.Tensor | np.ndarray\n    ) -> torch.Tensor | np.ndarray:\n        if self.config.time_shift_type == \"exponential\":\n            return self._time_shift_exponential(mu, sigma, t)\n        elif self.config.time_shift_type == \"linear\":\n            return self._time_shift_linear(mu, sigma, t)\n        else:\n            raise ValueError(f\"Unknown time_shift_type: {self.config.time_shift_type}\")\n\n    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:\n        r\"\"\"\n        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config\n        value.\n\n        Reference:\n        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51\n\n        Args:\n            t (`torch.Tensor`):\n                A tensor of timesteps to be stretched and shifted.\n\n        Returns:\n            `torch.Tensor`:\n                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.\n        \"\"\"\n        one_minus_z = 1 - t\n        scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)\n        stretched_t = 1 - (one_minus_z / scale_factor)\n        return stretched_t\n\n    def set_timesteps(\n        self,\n        num_inference_steps: int | None = None,\n        device: str | torch.device = None,\n        sigmas: list[float] | None = None,\n        mu: float | None = None,\n        timesteps: list[float] | None = None,\n    ) -> None:\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain (to be run before inference).\n\n       
 Args:\n            num_inference_steps (`int`, *optional*):\n                The number of diffusion steps used when generating samples with a pre-trained model.\n            device (`str` or `torch.device`, *optional*):\n                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n            sigmas (`List[float]`, *optional*):\n                Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed\n                automatically.\n            mu (`float`, *optional*):\n                Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep\n                shifting.\n            timesteps (`List[float]`, *optional*):\n                Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed\n                automatically.\n        \"\"\"\n\n        if self.config.use_dynamic_shifting and mu is None:\n            raise ValueError(\n                \"`mu` must be passed when `use_dynamic_shifting` is set to be `True`\"\n            )\n\n        if (\n            sigmas is not None\n            and timesteps is not None\n            and len(sigmas) != len(timesteps)\n        ):\n            raise ValueError(\"`sigmas` and `timesteps` should have the same length\")\n\n        if num_inference_steps is not None:\n            if (sigmas is not None and len(sigmas) != num_inference_steps) or (\n                timesteps is not None and len(timesteps) != num_inference_steps\n            ):\n                raise ValueError(\n                    \"`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided\"\n                )\n        else:\n            if sigmas is not None:\n                num_inference_steps = len(sigmas)\n            elif timesteps is not None:\n                num_inference_steps = len(timesteps)\n            else:\n      
          raise ValueError(\n                    \"Either num_inference_steps, sigmas, or timesteps must be provided\"\n                )\n\n        self.num_inference_steps = num_inference_steps\n\n        # 1. Prepare default sigmas\n        is_timesteps_provided = timesteps is not None\n\n        timesteps_array: np.ndarray | None = None\n        if is_timesteps_provided:\n            assert timesteps is not None\n            timesteps_array = np.array(timesteps).astype(np.float32)\n\n        sigmas_array: np.ndarray\n        if sigmas is None:\n            if timesteps_array is None:\n                timesteps_array = np.linspace(\n                    self._sigma_to_t(self.sigma_max),\n                    self._sigma_to_t(self.sigma_min),\n                    num_inference_steps,\n                )\n            sigmas_array = timesteps_array / self.config.num_train_timesteps\n        else:\n            sigmas_array = np.array(sigmas).astype(np.float32)\n            num_inference_steps = len(sigmas_array)\n\n        # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of\n        #    \"exponential\" or \"linear\" type is applied\n        if self.config.use_dynamic_shifting:\n            assert mu is not None, \"mu cannot be None when use_dynamic_shifting is True\"\n            sigmas_array = self.time_shift(mu, 1.0, sigmas_array)\n        else:\n            sigmas_array = (\n                self.shift * sigmas_array / (1 + (self.shift - 1) * sigmas_array)\n            )\n\n        # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value\n        if self.config.shift_terminal:\n            sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32)\n            sigmas_tensor = self.stretch_shift_to_terminal(sigmas_tensor)\n            sigmas_array = sigmas_tensor.numpy()\n\n        # 4. 
If required, convert sigmas to one of karras, exponential, or beta sigma schedules.\n        #    The _convert_to_* helpers return NumPy arrays rather than tensors, so np.asarray is used below\n        #    instead of calling .numpy() on their result (which would raise AttributeError on an ndarray).\n        if self.config.use_karras_sigmas:\n            sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32)\n            sigmas_tensor = self._convert_to_karras(\n                in_sigmas=sigmas_tensor, num_inference_steps=num_inference_steps\n            )\n            sigmas_array = np.asarray(sigmas_tensor, dtype=np.float32)\n        elif self.config.use_exponential_sigmas:\n            sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32)\n            sigmas_tensor = self._convert_to_exponential(\n                in_sigmas=sigmas_tensor, num_inference_steps=num_inference_steps\n            )\n            sigmas_array = np.asarray(sigmas_tensor, dtype=np.float32)\n        elif self.config.use_beta_sigmas:\n            sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32)\n            sigmas_tensor = self._convert_to_beta(\n                in_sigmas=sigmas_tensor, num_inference_steps=num_inference_steps\n            )\n            sigmas_array = np.asarray(sigmas_tensor, dtype=np.float32)\n\n        # 5. Convert sigmas and timesteps to tensors and move to specified device\n        sigmas_tensor = torch.from_numpy(sigmas_array).to(\n            dtype=torch.float32, device=device\n        )\n        if not is_timesteps_provided:\n            timesteps_tensor = sigmas_tensor * self.config.num_train_timesteps\n        else:\n            assert timesteps_array is not None\n            timesteps_tensor = torch.from_numpy(timesteps_array).to(\n                dtype=torch.float32, device=device\n            )\n\n        # 6. Append the terminal sigma value.\n        #    If a model requires inverted sigma schedule for denoising but timesteps without inversion, the\n        #    `invert_sigmas` flag can be set to `True`. 
This case is only required in Mochi\n        if self.config.invert_sigmas:\n            sigmas_tensor = 1.0 - sigmas_tensor\n            timesteps_tensor = sigmas_tensor * self.config.num_train_timesteps\n            sigmas_tensor = torch.cat(\n                [sigmas_tensor, torch.ones(1, device=sigmas_tensor.device)]\n            )\n        else:\n            sigmas_tensor = torch.cat(\n                [sigmas_tensor, torch.zeros(1, device=sigmas_tensor.device)]\n            )\n\n        self.timesteps = timesteps_tensor\n        self.sigmas = sigmas_tensor\n        self._step_index = None\n        self._begin_index = None\n\n    def index_for_timestep(\n        self,\n        timestep: float | torch.FloatTensor,\n        schedule_timesteps: torch.Tensor | None = None,\n    ) -> int:\n        if schedule_timesteps is None:\n            schedule_timesteps = self.timesteps\n\n        indices = (schedule_timesteps == timestep).nonzero()\n\n        # The sigma index that is taken for the **very** first `step`\n        # is always the second index (or the last index if there is only 1)\n        # This way we can ensure we don't accidentally skip a sigma in\n        # case we start in the middle of the denoising schedule (e.g. 
for image-to-image)\n        pos = 1 if len(indices) > 1 else 0\n\n        return indices[pos].item()\n\n    def _init_step_index(self, timestep: float | torch.FloatTensor) -> None:\n        if self.begin_index is None:\n            if isinstance(timestep, torch.Tensor):\n                timestep = timestep.to(self.timesteps.device)\n            self._step_index = self.index_for_timestep(timestep)\n        else:\n            self._step_index = self._begin_index\n\n    def step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: int | torch.Tensor,\n        sample: torch.FloatTensor,\n        s_churn: float = 0.0,\n        s_tmin: float = 0.0,\n        s_tmax: float = float(\"inf\"),\n        s_noise: float = 1.0,\n        generator: torch.Generator | None = None,\n        per_token_timesteps: torch.Tensor | None = None,\n        return_dict: bool = True,\n    ) -> FlowMatchEulerDiscreteSchedulerOutput | tuple[torch.FloatTensor, ...]:\n        \"\"\"\n        Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the diffusion\n        process from the learned model outputs (most often the predicted noise).\n\n        Args:\n            model_output (`torch.FloatTensor`):\n                The direct output from learned diffusion model.\n            timestep (`int` or `torch.Tensor`):\n                The current discrete timestep in the diffusion chain.\n            sample (`torch.FloatTensor`):\n                A current instance of a sample created by the diffusion process.\n            s_churn (`float`):\n            s_tmin  (`float`):\n            s_tmax  (`float`):\n            s_noise (`float`, defaults to 1.0):\n                Scaling factor for noise added to the sample.\n            generator (`torch.Generator`, *optional*):\n                A random number generator.\n            per_token_timesteps (`torch.Tensor`, *optional*):\n                The timesteps for each token in the sample.\n            return_dict (`bool`):\n                Whether or not to return a\n                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.\n\n        Returns:\n            [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:\n                If return_dict is `True`,\n                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned,\n                otherwise a tuple is returned where the first element is the sample tensor.\n        \"\"\"\n\n        if isinstance(timestep, int | torch.IntTensor | torch.LongTensor):\n            raise ValueError(\n                (\n                    \"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to\"\n                    \" `FlowMatchEulerDiscreteScheduler.step()` is not supported. 
Make sure to pass\"\n                    \" one of the `scheduler.timesteps` as a timestep.\"\n                ),\n            )\n\n        if self.step_index is None:\n            self._init_step_index(timestep)\n\n        # Upcast to avoid precision issues when computing prev_sample\n        sample = sample.to(torch.float32)\n\n        if per_token_timesteps is not None:\n            per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps\n\n            sigmas = self.sigmas[:, None, None]\n            lower_mask = sigmas < per_token_sigmas[None] - 1e-6\n            lower_sigmas = lower_mask * sigmas\n            lower_sigmas, _ = lower_sigmas.max(dim=0)\n\n            current_sigma = per_token_sigmas[..., None]\n            next_sigma = lower_sigmas[..., None]\n            dt = current_sigma - next_sigma\n        else:\n            assert self.step_index is not None, \"step_index should not be None\"\n            sigma_idx = self.step_index\n            sigma = self.sigmas[sigma_idx]\n            sigma_next = self.sigmas[sigma_idx + 1]\n\n            current_sigma = sigma\n            next_sigma = sigma_next\n            dt = sigma_next - sigma\n\n        if self.config.stochastic_sampling:\n            x0 = sample - current_sigma * model_output\n            noise = torch.randn_like(sample)\n            prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise\n        else:\n            prev_sample = sample + dt * model_output\n\n        # upon completion increase step index by one\n        assert self._step_index is not None, \"_step_index should not be None\"\n        self._step_index += 1\n        if per_token_timesteps is None:\n            # Cast sample back to model compatible dtype\n            prev_sample = prev_sample.to(model_output.dtype)\n\n        if isinstance(prev_sample, torch.Tensor | float) and not return_dict:\n            return (prev_sample,)\n\n        return 
FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)\n\n    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras\n    def _convert_to_karras(\n        self, in_sigmas: torch.Tensor, num_inference_steps: int\n    ) -> torch.Tensor:\n        \"\"\"Constructs the noise schedule of Karras et al. (2022).\"\"\"\n\n        # Hack to make sure that other schedulers which copy this function don't break\n        # TODO: Add this logic to the other schedulers\n        if hasattr(self.config, \"sigma_min\"):\n            sigma_min = self.config.sigma_min\n        else:\n            sigma_min = None\n\n        if hasattr(self.config, \"sigma_max\"):\n            sigma_max = self.config.sigma_max\n        else:\n            sigma_max = None\n\n        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()\n        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()\n\n        rho = 7.0  # 7.0 is the value used in the paper\n        ramp = np.linspace(0, 1, num_inference_steps)\n        min_inv_rho = sigma_min ** (1 / rho)\n        max_inv_rho = sigma_max ** (1 / rho)\n        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho\n        return sigmas\n\n    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential\n    def _convert_to_exponential(\n        self, in_sigmas: torch.Tensor, num_inference_steps: int\n    ) -> torch.Tensor:\n        \"\"\"Constructs an exponential noise schedule.\"\"\"\n\n        # Hack to make sure that other schedulers which copy this function don't break\n        # TODO: Add this logic to the other schedulers\n        if hasattr(self.config, \"sigma_min\"):\n            sigma_min = self.config.sigma_min\n        else:\n            sigma_min = None\n\n        if hasattr(self.config, \"sigma_max\"):\n            sigma_max = self.config.sigma_max\n        else:\n            sigma_max = 
None\n\n        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()\n        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()\n\n        sigmas = np.exp(\n            np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)\n        )\n        return sigmas\n\n    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta\n    def _convert_to_beta(\n        self,\n        in_sigmas: torch.Tensor,\n        num_inference_steps: int,\n        alpha: float = 0.6,\n        beta: float = 0.6,\n    ) -> torch.Tensor:\n        \"\"\"From \"Beta Sampling is All You Need\" [arXiv:2407.12173] (Lee et al., 2024)\"\"\"\n\n        # Hack to make sure that other schedulers which copy this function don't break\n        # TODO: Add this logic to the other schedulers\n        if hasattr(self.config, \"sigma_min\"):\n            sigma_min = self.config.sigma_min\n        else:\n            sigma_min = None\n\n        if hasattr(self.config, \"sigma_max\"):\n            sigma_max = self.config.sigma_max\n        else:\n            sigma_max = None\n\n        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()\n        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()\n\n        sigmas = np.array(\n            [\n                sigma_min + (ppf * (sigma_max - sigma_min))\n                for ppf in [\n                    scipy.stats.beta.ppf(timestep, alpha, beta)\n                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)\n                ]\n            ]\n        )\n        return sigmas\n\n    def _time_shift_exponential(\n        self, mu: float, sigma: float, t: torch.Tensor | np.ndarray\n    ) -> torch.Tensor | np.ndarray:\n        if isinstance(t, np.ndarray):\n            return np.exp(mu) / (np.exp(mu) + (1 / t - 1) ** sigma)\n        else:\n            return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** 
sigma)\n\n    def _time_shift_linear(\n        self, mu: float, sigma: float, t: torch.Tensor | np.ndarray\n    ) -> torch.Tensor | np.ndarray:\n        return mu / (mu + (1 / t - 1) ** sigma)\n\n    def add_noise(\n        self,\n        clean_latent: torch.Tensor,\n        noise: torch.Tensor,\n        timestep: torch.IntTensor,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            clean_latent: the clean latent with shape [B, C, H, W],\n                where B is batch_size or batch_size * num_frames\n            noise: the noise with shape [B, C, H, W]\n            timestep: the timestep with shape [1] or [bs * num_frames] or [bs, num_frames]\n\n        Returns:\n            the corrupted latent with shape [B, C, H, W]\n        \"\"\"\n        # If timestep is [bs, num_frames]\n        if timestep.ndim == 2:\n            timestep = timestep.flatten(0, 1)\n            assert timestep.numel() == clean_latent.shape[0]\n        elif timestep.ndim == 1:\n            # If timestep is [1]\n            if timestep.shape[0] == 1:\n                timestep = timestep.expand(clean_latent.shape[0])\n            else:\n                assert timestep.numel() == clean_latent.shape[0]\n        else:\n            raise ValueError(f\"[add_noise] Invalid timestep shape: {timestep.shape}\")\n        # timestep shape should be [B]\n        self.sigmas = self.sigmas.to(noise.device)\n        self.timesteps = self.timesteps.to(noise.device)\n        timestep_id = torch.argmin(\n            (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1\n        )\n        sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)\n        sample = (1 - sigma) * clean_latent + sigma * noise\n        return sample.type_as(noise)\n\n    def scale_model_input(\n        self, sample: torch.Tensor, timestep: int | None = None\n    ) -> torch.Tensor:\n        return sample\n\n    def __len__(self) -> int:\n        return 0\n\n\nEntryClass = FlowMatchEulerDiscreteScheduler\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_unipc_multistep.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py\n# Convert unipc for flow matching\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\n\nimport math\nfrom typing import Any\n\nimport numpy as np\nimport torch\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.schedulers.scheduling_utils import (\n    KarrasDiffusionSchedulers,\n    SchedulerMixin,\n    SchedulerOutput,\n)\nfrom diffusers.utils import deprecate\n\nfrom sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler\n\n\nclass FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):\n    \"\"\"\n    `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.\n\n    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic\n    methods the library implements for all schedulers such as loading and saving.\n\n    Args:\n        num_train_timesteps (`int`, defaults to 1000):\n            The number of diffusion steps to train the model.\n        solver_order (`int`, default `2`):\n            The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`\n            due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for\n            unconditional sampling.\n        prediction_type (`str`, defaults to \"flow_prediction\"):\n            Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts\n            the flow of the diffusion process.\n        thresholding (`bool`, defaults to `False`):\n            Whether to use the \"dynamic thresholding\" method. 
This is unsuitable for latent-space diffusion models such\n            as Stable Diffusion.\n        dynamic_thresholding_ratio (`float`, defaults to 0.995):\n            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.\n        sample_max_value (`float`, defaults to 1.0):\n            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.\n        predict_x0 (`bool`, defaults to `True`):\n            Whether to use the updating algorithm on the predicted x0.\n        solver_type (`str`, default `bh2`):\n            Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`\n            otherwise.\n        lower_order_final (`bool`, default `True`):\n            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can\n            stabilize sampling for steps < 15, especially for steps <= 10.\n        disable_corrector (`tuple`, default `()`):\n            The steps at which to disable the corrector, to mitigate the misalignment between `epsilon_theta(x_t, c)`\n            and `epsilon_theta(x_t^c, c)`, which can hurt convergence for a large guidance scale. The corrector is\n            usually disabled during the first few steps.\n        solver_p (`SchedulerMixin`, default `None`):\n            Any other scheduler. If specified, the algorithm becomes `solver_p + UniC`.\n        use_karras_sigmas (`bool`, *optional*, defaults to `False`):\n            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. 
If `True`,\n            the sigmas are determined according to a sequence of noise levels {σi}.\n        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):\n            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.\n        timestep_spacing (`str`, defaults to `\"linspace\"`):\n            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and\n            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.\n        steps_offset (`int`, defaults to 0):\n            An offset added to the inference steps, as required by some model families.\n        final_sigmas_type (`str`, defaults to `\"zero\"`):\n            The final `sigma` value for the noise schedule during the sampling process. If `\"sigma_min\"`, the final\n            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.\n    \"\"\"\n\n    _compatibles = [e.name for e in KarrasDiffusionSchedulers]\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        solver_order: int = 2,\n        prediction_type: str = \"flow_prediction\",\n        shift: float | None = 1.0,\n        use_dynamic_shifting=False,\n        thresholding: bool = False,\n        dynamic_thresholding_ratio: float = 0.995,\n        sample_max_value: float = 1.0,\n        predict_x0: bool = True,\n        solver_type: str = \"bh2\",\n        lower_order_final: bool = True,\n        disable_corrector: tuple = (),\n        solver_p: SchedulerMixin = None,\n        timestep_spacing: str = \"linspace\",\n        steps_offset: int = 0,\n        final_sigmas_type: str | None = \"zero\",  # \"zero\", \"sigma_min\"\n        **kwargs,\n    ):\n\n        if solver_type not in [\"bh1\", \"bh2\"]:\n            if solver_type in [\"midpoint\", \"heun\", \"logrho\"]:\n                
self.register_to_config(solver_type=\"bh2\")\n            else:\n                raise NotImplementedError(\n                    f\"{solver_type} is not implemented for {self.__class__}\"\n                )\n\n        self.predict_x0 = predict_x0\n        # settable values\n        self.num_inference_steps: int | None = None\n        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[\n            ::-1\n        ].copy()\n        sigmas = 1.0 - alphas\n        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)\n\n        if not use_dynamic_shifting:\n            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution\n            assert shift is not None\n            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore\n\n        self.sigmas = sigmas\n        self.sigma_min = self.sigmas[-1].item()\n        self.sigma_max = self.sigmas[0].item()\n\n        self.timesteps = sigmas * num_train_timesteps\n        self.num_train_timesteps = num_train_timesteps\n\n        self.model_outputs = [None] * solver_order\n        self.timestep_list: list[Any | None] = [None] * solver_order\n        self.lower_order_nums = 0\n        self.disable_corrector = list(disable_corrector)\n        self.solver_p = solver_p\n        self.last_sample = None\n        self._step_index: int | None = None\n        self._begin_index: int | None = None\n\n        BaseScheduler.__init__(self)\n\n    @property\n    def step_index(self):\n        \"\"\"\n        The index counter for the current timestep. It increases by 1 after each scheduler step.\n        \"\"\"\n        return self._step_index\n\n    @property\n    def begin_index(self):\n        \"\"\"\n        The index for the first timestep. 
It should be set from the pipeline with the `set_begin_index` method.\n        \"\"\"\n        return self._begin_index\n\n    def set_shift(self, shift: float) -> None:\n        self.config.shift = shift\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index\n    def set_begin_index(self, begin_index: int = 0):\n        \"\"\"\n        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.\n\n        Args:\n            begin_index (`int`):\n                The begin index for the scheduler.\n        \"\"\"\n        self._begin_index = begin_index\n\n    # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps\n    def set_timesteps(\n        self,\n        num_inference_steps: int | None = None,\n        device: str | torch.device | None = None,\n        sigmas: list[float] | None = None,\n        mu: float | None = None,\n        shift: float | None = None,\n    ):\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain (to be run before inference).\n\n        Args:\n            num_inference_steps (`int`):\n                The number of diffusion steps used when generating samples with a pre-trained model.\n            device (`str` or `torch.device`, *optional*):\n                The device to which the timesteps should be moved. 
If `None`, the timesteps are not moved.\n        \"\"\"\n\n        if self.config.use_dynamic_shifting and mu is None:\n            raise ValueError(\n                \" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`\"\n            )\n\n        if sigmas is None:\n            assert num_inference_steps is not None\n            sigmas = np.linspace(\n                self.sigma_max, self.sigma_min, num_inference_steps + 1\n            ).copy()[\n                :-1\n            ]  # pyright: ignore\n\n        if self.config.use_dynamic_shifting:\n            assert mu is not None\n            sigmas = self.time_shift(mu, 1.0, sigmas)  # pyright: ignore\n        else:\n            if shift is None:\n                shift = self.config.shift\n            assert isinstance(sigmas, np.ndarray)\n            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore\n\n        if self.config.final_sigmas_type == \"sigma_min\":\n            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5\n        elif self.config.final_sigmas_type == \"zero\":\n            sigma_last = 0\n        else:\n            raise ValueError(\n                f\"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}\"\n            )\n\n        timesteps = sigmas * self.config.num_train_timesteps\n        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(\n            np.float32\n        )  # pyright: ignore\n\n        self.sigmas = torch.from_numpy(sigmas).to(device=device)\n        self.timesteps = torch.from_numpy(timesteps).to(\n            device=device, dtype=torch.int64\n        )\n\n        self.num_inference_steps = len(timesteps)\n\n        self.model_outputs = [\n            None,\n        ] * self.config.solver_order\n        self.lower_order_nums = 0\n        self.last_sample = None\n        if self.solver_p:\n            
self.solver_p.set_timesteps(self.num_inference_steps, device=device)\n\n        # add an index counter for schedulers that allow duplicated timesteps\n        self._step_index = None\n        self._begin_index = None\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample\n    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        \"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the\n        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by\n        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing\n        pixels from saturation at each step. We find that dynamic thresholding results in significantly better\n        photorealism as well as better image-text alignment, especially when using very large guidance weights.\"\n\n        https://arxiv.org/abs/2205.11487\n        \"\"\"\n        dtype = sample.dtype\n        batch_size, channels, *remaining_dims = sample.shape\n\n        if dtype not in (torch.float32, torch.float64):\n            sample = (\n                sample.float()\n            )  # upcast for quantile calculation, and clamp not implemented for cpu half\n\n        # Flatten sample for doing quantile calculation along each image\n        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))\n\n        abs_sample = sample.abs()  # \"a certain percentile absolute pixel value\"\n\n        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)\n        s = torch.clamp(\n            s, min=1, max=self.config.sample_max_value\n        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]\n        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0\n        sample = (\n            torch.clamp(sample, -s, s) / s\n        )  # \"we 
threshold xt0 to the range [-s, s] and then divide by s\"\n\n        sample = sample.reshape(batch_size, channels, *remaining_dims)\n        sample = sample.to(dtype)\n\n        return sample\n\n    def _sigma_to_alpha_sigma_t(self, sigma) -> tuple[Any, Any]:\n        return 1 - sigma, sigma\n\n    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps\n    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):\n        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)\n\n    def convert_model_output(\n        self,\n        model_output: torch.Tensor,\n        *args,\n        sample: torch.Tensor = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Convert the model output to the corresponding type the UniPC algorithm needs.\n\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from the learned diffusion model.\n            timestep (`int`):\n                The current discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n\n        Returns:\n            `torch.Tensor`:\n                The converted model output.\n        \"\"\"\n        timestep = args[0] if len(args) > 0 else kwargs.pop(\"timestep\", None)\n        if sample is None:\n            if len(args) > 1:\n                sample = args[1]\n            else:\n                raise ValueError(\"missing `sample` as a required keyword argument\")\n        if timestep is not None:\n            deprecate(\n                \"timesteps\",\n                \"1.0.0\",\n                \"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n\n        if self.predict_x0:\n            if self.config.prediction_type == \"flow_prediction\":\n                sigma_t = self.sigmas[self.step_index]\n 
               x0_pred = sample - sigma_t * model_output\n            else:\n                raise ValueError(\n                    f\"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,\"\n                    \" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler.\"\n                )\n\n            if self.config.thresholding:\n                x0_pred = self._threshold_sample(x0_pred)\n\n            return x0_pred\n        else:\n            if self.config.prediction_type == \"flow_prediction\":\n                sigma_t = self.sigmas[self.step_index]\n                epsilon = sample - (1 - sigma_t) * model_output\n            else:\n                raise ValueError(\n                    f\"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,\"\n                    \" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler.\"\n                )\n\n            if self.config.thresholding:\n                sigma_t = self.sigmas[self.step_index]\n                x0_pred = sample - sigma_t * model_output\n                x0_pred = self._threshold_sample(x0_pred)\n                epsilon = model_output + x0_pred\n\n            return epsilon\n\n    def multistep_uni_p_bh_update(\n        self,\n        model_output: torch.Tensor,\n        *args,\n        sample: torch.Tensor = None,\n        order: int | None = None,  # pyright: ignore\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"\n        One step for the UniP (B(h) version). 
Alternatively, `self.solver_p` is used if it is specified.\n\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from the learned diffusion model at the current timestep.\n            prev_timestep (`int`):\n                The previous discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n            order (`int`):\n                The order of UniP at this timestep (corresponds to the *p* in UniPC-p).\n\n        Returns:\n            `torch.Tensor`:\n                The sample tensor at the previous timestep.\n        \"\"\"\n        prev_timestep = args[0] if len(args) > 0 else kwargs.pop(\"prev_timestep\", None)\n        if sample is None:\n            if len(args) > 1:\n                sample = args[1]\n            else:\n                raise ValueError(\"missing `sample` as a required keyword argument\")\n        if order is None:\n            if len(args) > 2:\n                order = args[2]\n            else:\n                raise ValueError(\"missing `order` as a required keyword argument\")\n        if prev_timestep is not None:\n            deprecate(\n                \"prev_timestep\",\n                \"1.0.0\",\n                \"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n        model_output_list = self.model_outputs\n\n        s0 = self.timestep_list[-1]\n        m0 = model_output_list[-1]\n        x = sample\n\n        if self.solver_p:\n            x_t = self.solver_p.step(model_output, s0, x).prev_sample\n            return x_t\n\n        sigma_t, sigma_s0 = (\n            self.sigmas[self.step_index + 1],\n            self.sigmas[self.step_index],\n        )  # pyright: ignore\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)\n        alpha_s0, sigma_s0 = 
self._sigma_to_alpha_sigma_t(sigma_s0)\n\n        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)\n        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)\n\n        h = lambda_t - lambda_s0\n        device = sample.device\n\n        rks = []\n        D1s: list[Any] | None = []\n        sigmas = self.sigmas.to(device=device)\n        for i in range(1, order):\n            si = self.step_index - i  # pyright: ignore\n            mi = model_output_list[-(i + 1)]\n            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(sigmas[si])\n            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)\n            rk = (lambda_si - lambda_s0) / h\n            rks.append(rk)\n            assert mi is not None\n            D1s.append((mi - m0) / rk)  # pyright: ignore\n\n        if len(rks) > 0:\n            rks = torch.stack(rks)\n            one = torch.ones(1, device=device, dtype=rks.dtype)\n            rks = torch.cat([rks, one])\n        else:\n            rks = torch.ones(1, device=device, dtype=h.dtype)\n\n        R = []\n        b = []\n\n        hh = -h if self.predict_x0 else h\n        h_phi_1 = torch.expm1(hh)  # h\\phi_1(h) = e^h - 1\n        h_phi_k = h_phi_1 / hh - 1\n\n        factorial_i = 1\n\n        if self.config.solver_type == \"bh1\":\n            B_h = hh\n        elif self.config.solver_type == \"bh2\":\n            B_h = torch.expm1(hh)\n        else:\n            raise NotImplementedError()\n\n        for i in range(1, order + 1):\n            R.append(torch.pow(rks, i - 1))\n            b.append(h_phi_k * factorial_i / B_h)\n            factorial_i *= i + 1\n            h_phi_k = h_phi_k / hh - 1 / factorial_i\n\n        R = torch.stack(R)\n        b = torch.stack(b)\n\n        if D1s is not None and len(D1s) > 0:\n            D1s = torch.stack(D1s, dim=1)  # (B, K)\n            # for order 2, we use a simplified version\n            if order == 2:\n                rhos_p = 0.5 * torch.ones(1, dtype=x.dtype, device=device)\n        
    else:\n                assert isinstance(R, torch.Tensor)\n                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)\n        else:\n            D1s = None\n\n        if self.predict_x0:\n            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0\n            if D1s is not None:\n                pred_res = torch.einsum(\n                    \"k,bkc...->bc...\", rhos_p, D1s\n                )  # pyright: ignore\n            else:\n                pred_res = 0\n            x_t = x_t_ - alpha_t * B_h * pred_res\n        else:\n            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0\n            if D1s is not None:\n                pred_res = torch.einsum(\n                    \"k,bkc...->bc...\", rhos_p, D1s\n                )  # pyright: ignore\n            else:\n                pred_res = 0\n            x_t = x_t_ - sigma_t * B_h * pred_res\n\n        x_t = x_t.to(x.dtype)\n        return x_t\n\n    def multistep_uni_c_bh_update(\n        self,\n        this_model_output: torch.Tensor,\n        *args,\n        last_sample: torch.Tensor = None,\n        this_sample: torch.Tensor = None,\n        order: int | None = None,  # pyright: ignore\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"\n        One step for the UniC (B(h) version).\n\n        Args:\n            this_model_output (`torch.Tensor`):\n                The model outputs at `x_t`.\n            this_timestep (`int`):\n                The current timestep `t`.\n            last_sample (`torch.Tensor`):\n                The generated sample before the last predictor `x_{t-1}`.\n            this_sample (`torch.Tensor`):\n                The generated sample after the last predictor `x_{t}`.\n            order (`int`):\n                The `p` of UniC-p at this step. 
The effective order of accuracy should be `order + 1`.\n\n        Returns:\n            `torch.Tensor`:\n                The corrected sample tensor at the current timestep.\n        \"\"\"\n        this_timestep = args[0] if len(args) > 0 else kwargs.pop(\"this_timestep\", None)\n        if last_sample is None:\n            if len(args) > 1:\n                last_sample = args[1]\n            else:\n                raise ValueError(\"missing `last_sample` as a required keyword argument\")\n        if this_sample is None:\n            if len(args) > 2:\n                this_sample = args[2]\n            else:\n                raise ValueError(\"missing `this_sample` as a required keyword argument\")\n        if order is None:\n            if len(args) > 3:\n                order = args[3]\n            else:\n                raise ValueError(\"missing `order` as a required keyword argument\")\n        if this_timestep is not None:\n            deprecate(\n                \"this_timestep\",\n                \"1.0.0\",\n                \"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n\n        model_output_list = self.model_outputs\n\n        m0 = model_output_list[-1]\n        x = last_sample\n        x_t = this_sample\n        model_t = this_model_output\n\n        sigma_t, sigma_s0 = (\n            self.sigmas[self.step_index],\n            self.sigmas[self.step_index - 1],\n        )  # pyright: ignore\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)\n        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)\n\n        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)\n        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)\n\n        h = lambda_t - lambda_s0\n        device = this_sample.device\n\n        # Build rks and D1s fully on device to avoid any host-device sync\n        # Fast paths for small orders (common 
cases: 1 or 2)\n        if order == 1:\n            rks = torch.ones(1, device=device, dtype=h.dtype)\n            D1s = None\n        elif order == 2:\n            # order == 2 -> only one historical point is used\n            si = self.step_index - 2  # i = 1\n            mi = model_output_list[-2]\n            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])\n            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)\n            rk = (lambda_si - lambda_s0) / h  # 0-dim tensor on device\n            # rks = [rk, 1.0] but keep it on device without list->tensor sync\n            rks = torch.stack((rk, torch.ones_like(rk)))\n            assert mi is not None\n            # D1s shape: (B, K=1, C, ...) to match later einsum over K\n            D1s = ((mi - m0) / rk).unsqueeze(1)  # pyright: ignore\n        else:\n            rks_list = []\n            D1s_list = []\n            for i in range(1, order):\n                si = self.step_index - (i + 1)\n                mi = model_output_list[-(i + 1)]\n                alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])\n                lambda_si = torch.log(alpha_si) - torch.log(sigma_si)\n                rk = (lambda_si - lambda_s0) / h\n                rks_list.append(rk)\n                assert mi is not None\n                D1s_list.append((mi - m0) / rk)  # pyright: ignore\n\n            # Append 1.0 as a device tensor to rks\n            rks = torch.stack(rks_list + [torch.ones_like(rks_list[0])])\n            D1s = torch.stack(D1s_list, dim=1) if len(D1s_list) > 0 else None\n\n        R = []\n        b = []\n\n        hh = -h if self.predict_x0 else h\n        h_phi_1 = torch.expm1(hh)  # h\\phi_1(h) = e^h - 1\n        h_phi_k = h_phi_1 / hh - 1\n\n        factorial_i = 1\n\n        if self.config.solver_type == \"bh1\":\n            B_h = hh\n        elif self.config.solver_type == \"bh2\":\n            B_h = torch.expm1(hh)\n        else:\n            raise 
NotImplementedError()\n\n        for i in range(1, order + 1):\n            R.append(torch.pow(rks, i - 1))\n            b.append(h_phi_k * factorial_i / B_h)\n            factorial_i *= i + 1\n            h_phi_k = h_phi_k / hh - 1 / factorial_i\n\n        R = torch.stack(R)\n        # Avoid torch.tensor(list_of_gpu_scalars) which syncs to host\n        b = torch.stack(b)\n\n        # D1s is already prepared above for order==2; remains None for order==1\n\n        # for order 1, we use a simplified version\n        if order == 1:\n            rhos_c = 0.5 * torch.ones(1, dtype=x.dtype, device=device)\n        elif order == 2:\n            # Manually solve the 2x2 linear system to avoid device synchronization from torch.linalg.solve\n            # R = [[1, 1], [rk, 1]], where rk = rks[0]\n            rk = rks[0]\n            det = 1 - rk\n            # Using Cramer's rule to solve for rhos_c = [x0, x1]\n            # x0 = (b0 - b1) / det\n            # x1 = (b1 - rk * b0) / det\n            rhos_c_0 = (b[0] - b[1]) / det\n            rhos_c_1 = (b[1] - rk * b[0]) / det\n            rhos_c = torch.stack([rhos_c_0, rhos_c_1])\n        else:\n            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)\n\n        if self.predict_x0:\n            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0\n            if D1s is not None:\n                corr_res = torch.einsum(\"k,bkc...->bc...\", rhos_c[:-1], D1s)\n            else:\n                corr_res = 0\n            D1_t = model_t - m0\n            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)\n        else:\n            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0\n            if D1s is not None:\n                corr_res = torch.einsum(\"k,bkc...->bc...\", rhos_c[:-1], D1s)\n            else:\n                corr_res = 0\n            D1_t = model_t - m0\n            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)\n        x_t = x_t.to(x.dtype)\n        return x_t\n\n 
   def index_for_timestep(self, timestep, schedule_timesteps=None) -> int:\n        if schedule_timesteps is None:\n            schedule_timesteps = self.timesteps\n\n        indices = (schedule_timesteps == timestep).nonzero()\n\n        # The sigma index that is taken for the **very** first `step`\n        # is always the second index (or the last index if there is only 1)\n        # This way we can ensure we don't accidentally skip a sigma in\n        # case we start in the middle of the denoising schedule (e.g. for image-to-image)\n        pos = 1 if len(indices) > 1 else 0\n        step_index: int = indices[pos].item()\n\n        return step_index\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index\n    def _init_step_index(self, timestep) -> None:\n        \"\"\"\n        Initialize the step_index counter for the scheduler.\n        \"\"\"\n\n        if self.begin_index is None:\n            if isinstance(timestep, torch.Tensor):\n                timestep = timestep.to(self.timesteps.device)\n            self._step_index = self.index_for_timestep(timestep)\n        else:\n            self._step_index = self._begin_index\n\n    def step(\n        self,\n        model_output: torch.Tensor,\n        timestep: int | torch.Tensor,\n        sample: torch.Tensor,\n        return_dict: bool = True,\n        generator=None,\n    ) -> SchedulerOutput | tuple:\n        \"\"\"\n        Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the sample with\n        the multistep UniPC.\n\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from learned diffusion model.\n            timestep (`int`):\n                The current discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n            return_dict (`bool`):\n                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.\n\n        Returns:\n            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:\n                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a\n                tuple is returned where the first element is the sample tensor.\n\n        \"\"\"\n        if self.num_inference_steps is None:\n            raise ValueError(\n                \"Number of inference steps is 'None', you need to call 'set_timesteps' after creating the scheduler\"\n            )\n\n        if self.step_index is None:\n            self._init_step_index(timestep)\n\n        use_corrector = (\n            self.step_index > 0\n            and self.step_index - 1 not in self.disable_corrector\n            and self.last_sample is not None  # pyright: ignore\n        )\n\n        sample = sample.to(model_output.device)\n        model_output_convert = self.convert_model_output(model_output, sample=sample)\n\n        if use_corrector:\n            sample = self.multistep_uni_c_bh_update(\n                this_model_output=model_output_convert,\n                last_sample=self.last_sample,\n                this_sample=sample,\n                order=self.this_order,\n            )\n\n        for i in range(self.config.solver_order - 1):\n            self.model_outputs[i] = self.model_outputs[i + 1]\n            self.timestep_list[i] = self.timestep_list[i + 1]\n\n        
self.model_outputs[-1] = model_output_convert\n        self.timestep_list[-1] = timestep  # pyright: ignore\n\n        if self.config.lower_order_final:\n            this_order = min(\n                self.config.solver_order, len(self.timesteps) - self.step_index\n            )  # pyright: ignore\n        else:\n            this_order = self.config.solver_order\n\n        self.this_order: int = min(\n            this_order, self.lower_order_nums + 1\n        )  # warmup for multistep\n        assert self.this_order > 0\n\n        self.last_sample = sample\n        prev_sample = self.multistep_uni_p_bh_update(\n            model_output=model_output,  # pass the original non-converted model output, in case solver-p is used\n            sample=sample,\n            order=self.this_order,\n        )\n\n        if self.lower_order_nums < self.config.solver_order:\n            self.lower_order_nums += 1\n\n        # upon completion increase step index by one\n        assert self._step_index is not None\n        self._step_index += 1  # pyright: ignore\n\n        if not return_dict:\n            return (prev_sample,)\n\n        return SchedulerOutput(prev_sample=prev_sample)\n\n    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:\n        \"\"\"\n        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the\n        current timestep.\n\n        Args:\n            sample (`torch.Tensor`):\n                The input sample.\n\n        Returns:\n            `torch.Tensor`:\n                A scaled input sample.\n        \"\"\"\n        return sample\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise\n    def add_noise(\n        self,\n        original_samples: torch.Tensor,\n        noise: torch.Tensor,\n        timesteps: torch.IntTensor,\n    ) -> torch.Tensor:\n        # Make sure sigmas and timesteps have the same device and 
dtype as original_samples\n        sigmas = self.sigmas.to(\n            device=original_samples.device, dtype=original_samples.dtype\n        )\n        if original_samples.device.type == \"mps\" and torch.is_floating_point(timesteps):\n            # mps does not support float64\n            schedule_timesteps = self.timesteps.to(\n                original_samples.device, dtype=torch.float32\n            )\n            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)\n        else:\n            schedule_timesteps = self.timesteps.to(original_samples.device)\n            timesteps = timesteps.to(original_samples.device)\n\n        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index\n        if self.begin_index is None:\n            step_indices = [\n                self.index_for_timestep(t, schedule_timesteps) for t in timesteps\n            ]\n        elif self.step_index is not None:\n            # add_noise is called after first denoising step (for inpainting)\n            step_indices = [self.step_index] * timesteps.shape[0]\n        else:\n            # add noise is called before first denoising step to create initial latent(img2img)\n            step_indices = [self.begin_index] * timesteps.shape[0]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < len(original_samples.shape):\n            sigma = sigma.unsqueeze(-1)\n\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)\n        noisy_samples = alpha_t * original_samples + sigma_t * noise\n        return noisy_samples\n\n\nEntryClass = FlowUniPCMultistepScheduler\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_helios.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Adapted from Helios diffusers scheduler:\n# https://github.com/BestWishYsh/Helios\n\"\"\"\nHelios scheduler implementing flow-matching with UniPC/Euler solvers.\n\nFor Phase 1 T2V (stages=1), this simplifies to standard flow-matching\nwith dynamic shifting and UniPC multistep solver.\n\"\"\"\n\nimport math\nfrom dataclasses import dataclass\n\nimport numpy as np\nimport torch\n\n\n@dataclass\nclass HeliosSchedulerOutput:\n    prev_sample: torch.FloatTensor\n    model_outputs: torch.FloatTensor | None = None\n    last_sample: torch.FloatTensor | None = None\n    this_order: int | None = None\n\n\nclass HeliosSchedulerConfig:\n    \"\"\"Mimics diffusers config interface for scheduler parameters.\"\"\"\n\n    def __init__(self, **kwargs):\n        for k, v in kwargs.items():\n            setattr(self, k, v)\n\n    def get(self, key, default=None):\n        return getattr(self, key, default)\n\n\nclass HeliosScheduler:\n    \"\"\"\n    Helios multi-stage scheduler supporting Euler, UniPC, and DMD solvers.\n\n    For Phase 1 T2V with stages=1, this is a standard flow-matching scheduler\n    with optional time shifting and UniPC multistep updates.\n    \"\"\"\n\n    order = 1\n\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        shift: float = 1.0,\n        stages: int = 1,\n        stage_range: list | None = None,\n        gamma: float = 1 / 3,\n        thresholding: bool = False,\n        prediction_type: str = \"flow_prediction\",\n        solver_order: int = 2,\n        predict_x0: bool = True,\n        solver_type: str = \"bh2\",\n        lower_order_final: bool = True,\n        disable_corrector: list[int] | None = None,\n        use_flow_sigmas: bool = True,\n        scheduler_type: str = \"unipc\",\n        use_dynamic_shifting: bool = False,\n        time_shift_type: str = \"linear\",\n        **kwargs,\n    ):\n        if stage_range is None:\n            # Evenly divide 
[0, 1] into 3 stages for pyramid SR\n            stage_range = [0, 1 / 3, 2 / 3, 1]\n        if disable_corrector is None:\n            disable_corrector = []\n\n        self.config = HeliosSchedulerConfig(\n            num_train_timesteps=num_train_timesteps,\n            shift=shift,\n            stages=stages,\n            stage_range=stage_range,\n            gamma=gamma,\n            thresholding=thresholding,\n            prediction_type=prediction_type,\n            solver_order=solver_order,\n            predict_x0=predict_x0,\n            solver_type=solver_type,\n            lower_order_final=lower_order_final,\n            disable_corrector=disable_corrector,\n            use_flow_sigmas=use_flow_sigmas,\n            scheduler_type=scheduler_type,\n            use_dynamic_shifting=use_dynamic_shifting,\n            time_shift_type=time_shift_type,\n        )\n\n        self.timestep_ratios = {}\n        self.timesteps_per_stage = {}\n        self.sigmas_per_stage = {}\n        self.start_sigmas = {}\n        self.end_sigmas = {}\n        self.ori_start_sigmas = {}\n\n        self.init_sigmas_for_each_stage()\n        self.sigma_min = self.sigmas[-1].item()\n        self.sigma_max = self.sigmas[0].item()\n        self.gamma = gamma\n\n        if solver_type not in [\"bh1\", \"bh2\"]:\n            raise NotImplementedError(f\"{solver_type} is not implemented\")\n\n        self.predict_x0 = predict_x0\n        self.model_outputs = [None] * solver_order\n        self.timestep_list = [None] * solver_order\n        self.lower_order_nums = 0\n        self.disable_corrector = disable_corrector\n        self.solver_p = None\n        self.last_sample = None\n        self._step_index = None\n        self._begin_index = None\n        self.num_inference_steps = None\n\n    def init_sigmas(self):\n        num_train_timesteps = self.config.num_train_timesteps\n        shift = self.config.shift\n\n        alphas = np.linspace(1, 1 / num_train_timesteps, 
num_train_timesteps + 1)\n        sigmas = 1.0 - alphas\n        sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy()\n        sigmas = torch.from_numpy(sigmas)\n        timesteps = (sigmas * num_train_timesteps).clone()\n\n        self._step_index = None\n        self._begin_index = None\n        self.timesteps = timesteps\n        self.sigmas = sigmas.to(\"cpu\")\n\n    def init_sigmas_for_each_stage(self):\n        self.init_sigmas()\n\n        stage_distance = []\n        stages = self.config.stages\n        training_steps = self.config.num_train_timesteps\n        stage_range = self.config.stage_range\n\n        for i_s in range(stages):\n            start_indice = int(stage_range[i_s] * training_steps)\n            start_indice = max(start_indice, 0)\n            end_indice = int(stage_range[i_s + 1] * training_steps)\n            end_indice = min(end_indice, training_steps)\n            start_sigma = self.sigmas[start_indice].item()\n            end_sigma = (\n                self.sigmas[end_indice].item() if end_indice < training_steps else 0.0\n            )\n            self.ori_start_sigmas[i_s] = start_sigma\n\n            if i_s != 0:\n                ori_sigma = 1 - start_sigma\n                gamma = self.config.gamma\n                corrected_sigma = (\n                    1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)\n                ) * ori_sigma\n                start_sigma = 1 - corrected_sigma\n\n            stage_distance.append(start_sigma - end_sigma)\n            self.start_sigmas[i_s] = start_sigma\n            self.end_sigmas[i_s] = end_sigma\n\n        tot_distance = sum(stage_distance)\n        for i_s in range(stages):\n            if i_s == 0:\n                start_ratio = 0.0\n            else:\n                start_ratio = sum(stage_distance[:i_s]) / tot_distance\n            if i_s == stages - 1:\n                # Use value just below 1.0 to avoid out-of-bounds indexing\n                
end_ratio = 1.0 - 1e-16\n            else:\n                end_ratio = sum(stage_distance[: i_s + 1]) / tot_distance\n            self.timestep_ratios[i_s] = (start_ratio, end_ratio)\n\n        for i_s in range(stages):\n            timestep_ratio = self.timestep_ratios[i_s]\n            # Clamp to max valid timestep (num_train_timesteps - 1)\n            timestep_max = min(\n                self.timesteps[int(timestep_ratio[0] * training_steps)], 999\n            )\n            timestep_min = self.timesteps[\n                min(int(timestep_ratio[1] * training_steps), training_steps - 1)\n            ]\n            timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1)\n            self.timesteps_per_stage[i_s] = (\n                timesteps[:-1]\n                if isinstance(timesteps, torch.Tensor)\n                else torch.from_numpy(timesteps[:-1])\n            )\n            # Sigma range [0.999, 0]: start just below 1.0 to avoid singularity\n            stage_sigmas = np.linspace(0.999, 0, training_steps + 1)\n            self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])\n\n    @property\n    def step_index(self):\n        return self._step_index\n\n    @property\n    def begin_index(self):\n        return self._begin_index\n\n    def set_begin_index(self, begin_index: int = 0):\n        self._begin_index = begin_index\n\n    def time_shift(self, mu, sigma, t):\n        if self.config.time_shift_type == \"exponential\":\n            return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)\n        elif self.config.time_shift_type == \"linear\":\n            return mu / (mu + (1 / t - 1) ** sigma)\n\n    def set_timesteps(\n        self,\n        num_inference_steps: int,\n        stage_index: int | None = None,\n        device: str | torch.device = None,\n        sigmas=None,\n        mu=None,\n        is_amplify_first_chunk: bool = False,\n    ):\n        if self.config.scheduler_type == \"dmd\":\n            if 
is_amplify_first_chunk:\n                num_inference_steps = num_inference_steps * 2 + 1\n            else:\n                num_inference_steps = num_inference_steps + 1\n\n        self.num_inference_steps = num_inference_steps\n        self.init_sigmas()\n\n        if self.config.stages == 1:\n            if sigmas is None:\n                sigmas = np.linspace(\n                    1,\n                    1 / self.config.num_train_timesteps,\n                    num_inference_steps + 1,\n                )[:-1].astype(np.float32)\n                if self.config.shift != 1.0:\n                    assert not self.config.use_dynamic_shifting\n                    sigmas = self.time_shift(self.config.shift, 1.0, sigmas)\n            timesteps = (sigmas * self.config.num_train_timesteps).copy()\n            sigmas = torch.from_numpy(sigmas)\n        else:\n            stage_timesteps = self.timesteps_per_stage[stage_index]\n            timesteps = np.linspace(\n                stage_timesteps[0].item(),\n                stage_timesteps[-1].item(),\n                num_inference_steps,\n            )\n            stage_sigmas = self.sigmas_per_stage[stage_index]\n            ratios = np.linspace(\n                stage_sigmas[0].item(), stage_sigmas[-1].item(), num_inference_steps\n            )\n            sigmas = torch.from_numpy(ratios)\n\n        self.timesteps = torch.from_numpy(timesteps).to(device=device)\n        self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(device=device)\n\n        self._step_index = None\n        self.reset_scheduler_history()\n\n        if self.config.scheduler_type == \"dmd\":\n            self.timesteps = self.timesteps[:-1]\n            self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]])\n\n        if self.config.use_dynamic_shifting:\n            assert self.config.shift == 1.0\n            self.sigmas = self.time_shift(mu, 1.0, self.sigmas)\n            if self.config.stages == 1:\n                self.timesteps = 
self.sigmas[:-1] * self.config.num_train_timesteps\n            else:\n                self.timesteps = self.timesteps_per_stage[\n                    stage_index\n                ].min() + self.sigmas[:-1] * (\n                    self.timesteps_per_stage[stage_index].max()\n                    - self.timesteps_per_stage[stage_index].min()\n                )\n\n    # ---------------------------------- Euler ----------------------------------\n    def index_for_timestep(self, timestep, schedule_timesteps=None):\n        if schedule_timesteps is None:\n            schedule_timesteps = self.timesteps\n        indices = (schedule_timesteps == timestep).nonzero()\n        pos = 1 if len(indices) > 1 else 0\n        return indices[pos].item()\n\n    def _init_step_index(self, timestep):\n        if self.begin_index is None:\n            if isinstance(timestep, torch.Tensor):\n                timestep = timestep.to(self.timesteps.device)\n            self._step_index = self.index_for_timestep(timestep)\n        else:\n            self._step_index = self._begin_index\n\n    def step_euler(\n        self,\n        model_output: torch.FloatTensor,\n        timestep=None,\n        sample: torch.FloatTensor = None,\n        return_dict: bool = True,\n        **kwargs,\n    ) -> HeliosSchedulerOutput | tuple:\n        if self.step_index is None:\n            self._step_index = 0\n\n        sample = sample.to(torch.float32)\n        sigma = self.sigmas[self.step_index]\n        sigma_next = self.sigmas[self.step_index + 1]\n\n        prev_sample = sample + (sigma_next - sigma) * model_output\n        prev_sample = prev_sample.to(model_output.dtype)\n\n        self._step_index += 1\n\n        if not return_dict:\n            return (prev_sample,)\n        return HeliosSchedulerOutput(prev_sample=prev_sample)\n\n    # ---------------------------------- UniPC ----------------------------------\n    def _sigma_to_alpha_sigma_t(self, sigma):\n        if 
self.config.use_flow_sigmas:\n            alpha_t = 1 - sigma\n            sigma_t = torch.clamp(sigma, min=1e-8)\n        else:\n            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)\n            sigma_t = sigma * alpha_t\n        return alpha_t, sigma_t\n\n    def convert_model_output(self, model_output, sample=None, sigma=None, **kwargs):\n        flag = False\n        if sigma is None:\n            flag = True\n            sigma = self.sigmas[self.step_index]\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)\n\n        if self.predict_x0:\n            if self.config.prediction_type == \"flow_prediction\":\n                if flag:\n                    sigma_t = self.sigmas[self.step_index]\n                else:\n                    sigma_t = sigma\n                x0_pred = sample - sigma_t * model_output\n            elif self.config.prediction_type == \"epsilon\":\n                x0_pred = (sample - sigma_t * model_output) / alpha_t\n            elif self.config.prediction_type == \"sample\":\n                x0_pred = model_output\n            elif self.config.prediction_type == \"v_prediction\":\n                x0_pred = alpha_t * sample - sigma_t * model_output\n            else:\n                raise ValueError(\n                    f\"prediction_type {self.config.prediction_type} not supported\"\n                )\n            return x0_pred\n        else:\n            if self.config.prediction_type == \"epsilon\":\n                return model_output\n            elif self.config.prediction_type == \"sample\":\n                return (sample - alpha_t * model_output) / sigma_t\n            elif self.config.prediction_type == \"v_prediction\":\n                return alpha_t * model_output + sigma_t * sample\n            else:\n                raise ValueError(\n                    f\"prediction_type {self.config.prediction_type} not supported\"\n                )\n\n    def multistep_uni_p_bh_update(\n        self, model_output, 
sample=None, order=None, sigma=None, sigma_next=None\n    ):\n        model_output_list = self.model_outputs\n        m0 = model_output_list[-1]\n        x = sample\n\n        if sigma_next is None and sigma is None:\n            sigma_t, sigma_s0 = (\n                self.sigmas[self.step_index + 1],\n                self.sigmas[self.step_index],\n            )\n        else:\n            sigma_t, sigma_s0 = sigma_next, sigma\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)\n        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)\n\n        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)\n        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)\n        h = lambda_t - lambda_s0\n        device = sample.device\n\n        rks = []\n        D1s = []\n        for i in range(1, order):\n            si = self.step_index - i\n            mi = model_output_list[-(i + 1)]\n            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])\n            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)\n            rk = (lambda_si - lambda_s0) / h\n            rks.append(rk)\n            D1s.append((mi - m0) / rk)\n\n        rks.append(1.0)\n        rks = torch.tensor(rks, device=device)\n\n        R = []\n        b = []\n\n        hh = -h if self.predict_x0 else h\n        h_phi_1 = torch.expm1(hh)\n        h_phi_k = h_phi_1 / hh - 1\n        factorial_i = 1\n\n        if self.config.solver_type == \"bh1\":\n            B_h = hh\n        elif self.config.solver_type == \"bh2\":\n            B_h = torch.expm1(hh)\n        else:\n            raise NotImplementedError()\n\n        for i in range(1, order + 1):\n            R.append(torch.pow(rks, i - 1))\n            b.append(h_phi_k * factorial_i / B_h)\n            factorial_i *= i + 1\n            h_phi_k = h_phi_k / hh - 1 / factorial_i\n\n        R = torch.stack(R)\n        b = torch.tensor(b, device=device)\n\n        if len(D1s) > 0:\n            D1s = 
torch.stack(D1s, dim=1)\n            if order == 2:\n                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)\n            else:\n                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)\n        else:\n            D1s = None\n\n        if self.predict_x0:\n            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0\n            pred_res = (\n                torch.einsum(\"k,bkc...->bc...\", rhos_p, D1s) if D1s is not None else 0\n            )\n            x_t = x_t_ - alpha_t * B_h * pred_res\n        else:\n            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0\n            pred_res = (\n                torch.einsum(\"k,bkc...->bc...\", rhos_p, D1s) if D1s is not None else 0\n            )\n            x_t = x_t_ - sigma_t * B_h * pred_res\n\n        return x_t.to(x.dtype)\n\n    def multistep_uni_c_bh_update(\n        self,\n        this_model_output,\n        last_sample=None,\n        this_sample=None,\n        order=None,\n        sigma_before=None,\n        sigma=None,\n    ):\n        model_output_list = self.model_outputs\n        m0 = model_output_list[-1]\n        x = last_sample\n        model_t = this_model_output\n\n        if sigma_before is None and sigma is None:\n            sigma_t, sigma_s0 = (\n                self.sigmas[self.step_index],\n                self.sigmas[self.step_index - 1],\n            )\n        else:\n            sigma_t, sigma_s0 = sigma, sigma_before\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)\n        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)\n\n        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)\n        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)\n        h = lambda_t - lambda_s0\n        device = this_sample.device\n\n        rks = []\n        D1s = []\n        for i in range(1, order):\n            si = self.step_index - (i + 1)\n            mi = model_output_list[-(i + 1)]\n            
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])\n            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)\n            rk = (lambda_si - lambda_s0) / h\n            rks.append(rk)\n            D1s.append((mi - m0) / rk)\n\n        rks.append(1.0)\n        rks = torch.tensor(rks, device=device)\n\n        R = []\n        b = []\n        hh = -h if self.predict_x0 else h\n        h_phi_1 = torch.expm1(hh)\n        h_phi_k = h_phi_1 / hh - 1\n        factorial_i = 1\n\n        if self.config.solver_type == \"bh1\":\n            B_h = hh\n        elif self.config.solver_type == \"bh2\":\n            B_h = torch.expm1(hh)\n        else:\n            raise NotImplementedError()\n\n        for i in range(1, order + 1):\n            R.append(torch.pow(rks, i - 1))\n            b.append(h_phi_k * factorial_i / B_h)\n            factorial_i *= i + 1\n            h_phi_k = h_phi_k / hh - 1 / factorial_i\n\n        R = torch.stack(R)\n        b = torch.tensor(b, device=device)\n\n        if len(D1s) > 0:\n            D1s = torch.stack(D1s, dim=1)\n        else:\n            D1s = None\n\n        if order == 1:\n            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)\n        else:\n            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)\n\n        if self.predict_x0:\n            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0\n            corr_res = (\n                torch.einsum(\"k,bkc...->bc...\", rhos_c[:-1], D1s)\n                if D1s is not None\n                else 0\n            )\n            D1_t = model_t - m0\n            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)\n        else:\n            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0\n            corr_res = (\n                torch.einsum(\"k,bkc...->bc...\", rhos_c[:-1], D1s)\n                if D1s is not None\n                else 0\n            )\n            D1_t = model_t - m0\n            x_t = x_t_ - sigma_t * 
B_h * (corr_res + rhos_c[-1] * D1_t)\n\n        return x_t.to(x.dtype)\n\n    def step_unipc(\n        self,\n        model_output,\n        timestep=None,\n        sample=None,\n        return_dict: bool = True,\n        **kwargs,\n    ) -> HeliosSchedulerOutput | tuple:\n        if self.num_inference_steps is None:\n            raise ValueError(\n                \"Number of inference steps is 'None', run 'set_timesteps' first\"\n            )\n\n        if self.step_index is None:\n            self._step_index = 0\n\n        use_corrector = (\n            self.step_index > 0\n            and self.step_index - 1 not in self.disable_corrector\n            and self.last_sample is not None\n        )\n\n        model_output_convert = self.convert_model_output(model_output, sample=sample)\n\n        if use_corrector:\n            sample = self.multistep_uni_c_bh_update(\n                this_model_output=model_output_convert,\n                last_sample=self.last_sample,\n                this_sample=sample,\n                order=self.this_order,\n            )\n\n        for i in range(self.config.solver_order - 1):\n            self.model_outputs[i] = self.model_outputs[i + 1]\n            self.timestep_list[i] = self.timestep_list[i + 1]\n        self.model_outputs[-1] = model_output_convert\n        self.timestep_list[-1] = timestep\n\n        if self.config.lower_order_final:\n            this_order = min(\n                self.config.solver_order, len(self.timesteps) - self.step_index\n            )\n        else:\n            this_order = self.config.solver_order\n        self.this_order = min(this_order, self.lower_order_nums + 1)\n        assert self.this_order > 0\n\n        self.last_sample = sample\n        prev_sample = self.multistep_uni_p_bh_update(\n            model_output=model_output,\n            sample=sample,\n            order=self.this_order,\n        )\n\n        if self.lower_order_nums < self.config.solver_order:\n            
self.lower_order_nums += 1\n\n        self._step_index += 1\n\n        if not return_dict:\n            return (prev_sample,)\n        return HeliosSchedulerOutput(prev_sample=prev_sample)\n\n    # ---------------------------------- DMD ----------------------------------\n    def add_noise(self, original_samples, noise, timestep, sigmas, timesteps):\n        sigmas = sigmas.to(noise.device)\n        timesteps = timesteps.to(noise.device)\n        timestep_id = torch.argmin(\n            (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1\n        )\n        sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1)\n        sample = (1 - sigma) * original_samples + sigma * noise\n        return sample.type_as(noise)\n\n    def convert_flow_pred_to_x0(self, flow_pred, xt, timestep, sigmas, timesteps):\n        original_dtype = flow_pred.dtype\n        device = flow_pred.device\n        flow_pred, xt, sigmas, timesteps = (\n            x.double().to(device) for x in (flow_pred, xt, sigmas, timesteps)\n        )\n        timestep_id = torch.argmin(\n            (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1\n        )\n        sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1)\n        x0_pred = xt - sigma_t * flow_pred\n        return x0_pred.to(original_dtype)\n\n    def step_dmd(\n        self,\n        model_output: torch.FloatTensor,\n        timestep=None,\n        sample: torch.FloatTensor = None,\n        return_dict: bool = True,\n        cur_sampling_step: int = 0,\n        dmd_noisy_tensor: torch.FloatTensor | None = None,\n        dmd_sigmas: torch.FloatTensor | None = None,\n        dmd_timesteps: torch.FloatTensor | None = None,\n        all_timesteps: torch.FloatTensor | None = None,\n        **kwargs,\n    ) -> HeliosSchedulerOutput | tuple:\n        pred_image_or_video = self.convert_flow_pred_to_x0(\n            flow_pred=model_output,\n            xt=sample,\n            timestep=torch.full(\n                
(model_output.shape[0],),\n                timestep,\n                dtype=torch.long,\n                device=model_output.device,\n            ),\n            sigmas=dmd_sigmas,\n            timesteps=dmd_timesteps,\n        )\n        if cur_sampling_step < len(all_timesteps) - 1:\n            prev_sample = self.add_noise(\n                pred_image_or_video,\n                dmd_noisy_tensor,\n                torch.full(\n                    (model_output.shape[0],),\n                    all_timesteps[cur_sampling_step + 1],\n                    dtype=torch.long,\n                    device=model_output.device,\n                ),\n                sigmas=dmd_sigmas,\n                timesteps=dmd_timesteps,\n            )\n        else:\n            prev_sample = pred_image_or_video\n\n        if not return_dict:\n            return (prev_sample,)\n        return HeliosSchedulerOutput(prev_sample=prev_sample)\n\n    # ---------------------------------- Main step ----------------------------------\n    def step(\n        self,\n        model_output,\n        timestep=None,\n        sample=None,\n        return_dict: bool = True,\n        **kwargs,\n    ) -> HeliosSchedulerOutput | tuple:\n        if self.config.scheduler_type == \"euler\":\n            return self.step_euler(\n                model_output=model_output,\n                timestep=timestep,\n                sample=sample,\n                return_dict=return_dict,\n            )\n        elif self.config.scheduler_type == \"unipc\":\n            return self.step_unipc(\n                model_output=model_output,\n                timestep=timestep,\n                sample=sample,\n                return_dict=return_dict,\n            )\n        elif self.config.scheduler_type == \"dmd\":\n            return self.step_dmd(\n                model_output=model_output,\n                timestep=timestep,\n                sample=sample,\n                return_dict=return_dict,\n                
**kwargs,\n            )\n        else:\n            raise NotImplementedError(\n                f\"Scheduler type '{self.config.scheduler_type}' not implemented\"\n            )\n\n    def reset_scheduler_history(self):\n        self.model_outputs = [None] * self.config.solver_order\n        self.timestep_list = [None] * self.config.solver_order\n        self.lower_order_nums = 0\n        self.disable_corrector = self.config.disable_corrector\n        self.solver_p = None\n        self.last_sample = None\n        self._step_index = None\n        self._begin_index = None\n\n    def set_shift(self, shift: float):\n        \"\"\"Update the shift parameter (called by SchedulerLoader after loading).\"\"\"\n        self.config.shift = shift\n        self.shift = shift\n\n    def __len__(self):\n        return self.config.num_train_timesteps\n\n\n# Alias for Helios-Distilled which uses \"HeliosDMDScheduler\" in scheduler_config.json\nHeliosDMDScheduler = HeliosScheduler\n\nEntryClass = [HeliosScheduler, \"HeliosDMDScheduler\"]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_self_forcing_flow_match.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nimport torch\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.schedulers.scheduling_utils import SchedulerMixin\nfrom diffusers.utils import BaseOutput\n\nfrom sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass SelfForcingFlowMatchSchedulerOutput(BaseOutput):\n    \"\"\"\n    Output class for the scheduler's `step` function output.\n\n    Args:\n        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):\n            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the\n            denoising loop.\n    \"\"\"\n\n    prev_sample: torch.FloatTensor\n\n\nclass SelfForcingFlowMatchScheduler(BaseScheduler, ConfigMixin, SchedulerMixin):\n    config_name = \"scheduler_config.json\"\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_inference_steps=100,\n        num_train_timesteps=1000,\n        shift=3.0,\n        sigma_max=1.0,\n        sigma_min=0.003 / 1.002,\n        inverse_timesteps=False,\n        extra_one_step=False,\n        reverse_sigmas=False,\n        *args,\n        **kwargs,\n    ):\n        self.num_train_timesteps = num_train_timesteps\n        self.shift = shift\n        self.sigma_max = sigma_max\n        self.sigma_min = sigma_min\n        self.inverse_timesteps = inverse_timesteps\n        self.extra_one_step = extra_one_step\n        self.reverse_sigmas = reverse_sigmas\n        self.set_timesteps(num_inference_steps)\n\n    def set_timesteps(\n        self,\n        num_inference_steps=100,\n        denoising_strength=1.0,\n        return_dict=False,\n        **kwargs,\n    ):\n        sigma_start = (\n      
      self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength\n        )\n        if self.extra_one_step:\n            self.sigmas = torch.linspace(\n                sigma_start, self.sigma_min, num_inference_steps + 1\n            )[:-1]\n        else:\n            self.sigmas = torch.linspace(\n                sigma_start, self.sigma_min, num_inference_steps\n            )\n        if self.inverse_timesteps:\n            self.sigmas = torch.flip(self.sigmas, dims=[0])\n        self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)\n        if self.reverse_sigmas:\n            self.sigmas = 1 - self.sigmas\n        self.timesteps = self.sigmas * self.num_train_timesteps\n\n    def step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: torch.FloatTensor,\n        sample: torch.FloatTensor,\n        to_final=False,\n        return_dict=False,\n        **kwargs,\n    ):\n        if timestep.ndim == 2:\n            timestep = timestep.flatten(0, 1)\n        elif timestep.ndim == 0:\n            # handles the case where timestep is a scalar, this occurs when we\n            # use this scheduler for ODE trajectory\n            timestep = timestep.unsqueeze(0)\n\n        self.sigmas = self.sigmas.to(model_output.device)\n        self.timesteps = self.timesteps.to(model_output.device)\n        timestep = timestep.to(model_output.device)\n\n        timestep_id = torch.argmin(\n            (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1\n        )\n        sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)\n        if to_final or (timestep_id + 1 >= len(self.timesteps)).any():\n            sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0\n        else:\n            sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1)\n        prev_sample = sample + model_output * (sigma_ - sigma)\n        if isinstance(prev_sample, torch.Tensor | float) and not return_dict:\n    
        return (prev_sample,)\n        return SelfForcingFlowMatchSchedulerOutput(prev_sample=prev_sample)\n\n    def add_noise(self, original_samples, noise, timestep):\n        \"\"\"\n        Diffusion forward corruption process.\n        Input:\n            - original_samples: the clean latent with shape [B*T, C, H, W]\n            - noise: the noise with shape [B*T, C, H, W]\n            - timestep: the timestep with shape [B*T]\n        Output: the corrupted latent with shape [B*T, C, H, W]\n        \"\"\"\n        if timestep.ndim == 2:\n            timestep = timestep.flatten(0, 1)\n        self.sigmas = self.sigmas.to(noise.device)\n        self.timesteps = self.timesteps.to(noise.device)\n        timestep_id = torch.argmin(\n            (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1\n        )\n        sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)\n        sample = (1 - sigma) * original_samples + sigma * noise\n        return sample.type_as(noise)\n\n    def scale_model_input(\n        self, sample: torch.Tensor, timestep: int | None = None\n    ) -> torch.Tensor:\n        return sample\n\n    def set_shift(self, shift: float) -> None:\n        self.shift = shift\n\n\nEntryClass = SelfForcingFlowMatchScheduler\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_unipc_multistep.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# DISCLAIMER: check https://huggingface.co/papers/2302.04867 and https://github.com/wl-zhao/UniPC for more info\n# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py\n# ==============================================================================\n#\n# Modified from diffusers==0.35.0.dev0\n#\n# ==============================================================================\n\nimport math\n\nimport numpy as np\nimport torch\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.schedulers.scheduling_utils import (\n    KarrasDiffusionSchedulers,\n    SchedulerMixin,\n    SchedulerOutput,\n)\nfrom diffusers.utils import deprecate, is_scipy_available\n\nfrom sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler\n\nif is_scipy_available():\n    import scipy.stats\n\n\n# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar\ndef betas_for_alpha_bar(\n    num_diffusion_timesteps,\n    max_beta=0.999,\n    alpha_transform_type=\"cosine\",\n):\n    \"\"\"\n    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of\n    (1-beta) over time from t 
= [0,1].\n\n    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up\n    to that part of the diffusion process.\n\n\n    Args:\n        num_diffusion_timesteps (`int`): the number of betas to produce.\n        max_beta (`float`): the maximum beta to use; use values lower than 1 to\n                     prevent singularities.\n        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.\n                     Choose from `cosine` or `exp`.\n\n    Returns:\n        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs\n    \"\"\"\n    if alpha_transform_type == \"cosine\":\n\n        def alpha_bar_fn(t):\n            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2\n\n    elif alpha_transform_type == \"exp\":\n\n        def alpha_bar_fn(t):\n            return math.exp(t * -12.0)\n\n    else:\n        raise ValueError(f\"Unsupported alpha_transform_type: {alpha_transform_type}\")\n\n    betas = []\n    for i in range(num_diffusion_timesteps):\n        t1 = i / num_diffusion_timesteps\n        t2 = (i + 1) / num_diffusion_timesteps\n        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))\n    return torch.tensor(betas, dtype=torch.float32)\n\n\n# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr\ndef rescale_zero_terminal_snr(betas):\n    \"\"\"\n    Rescales betas to have zero terminal SNR, based on https://huggingface.co/papers/2305.08891 (Algorithm 1)\n\n\n    Args:\n        betas (`torch.Tensor`):\n            the betas that the scheduler is being initialized with.\n\n    Returns:\n        `torch.Tensor`: rescaled betas with zero terminal SNR\n    \"\"\"\n    # Convert betas to alphas_bar_sqrt\n    alphas = 1.0 - betas\n    alphas_cumprod = torch.cumprod(alphas, dim=0)\n    alphas_bar_sqrt = alphas_cumprod.sqrt()\n\n    # Store old values.\n    alphas_bar_sqrt_0 =
alphas_bar_sqrt[0].clone()\n    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()\n\n    # Shift so the last timestep is zero.\n    alphas_bar_sqrt -= alphas_bar_sqrt_T\n\n    # Scale so the first timestep is back to the old value.\n    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)\n\n    # Convert alphas_bar_sqrt to betas\n    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt\n    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod\n    alphas = torch.cat([alphas_bar[0:1], alphas])\n    betas = 1 - alphas\n\n    return betas\n\n\nclass UniPCMultistepScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):\n    \"\"\"\n    `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.\n\n    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic\n    methods the library implements for all schedulers such as loading and saving.\n\n    Args:\n        num_train_timesteps (`int`, defaults to 1000):\n            The number of diffusion steps to train the model.\n        beta_start (`float`, defaults to 0.0001):\n            The starting `beta` value of inference.\n        beta_end (`float`, defaults to 0.02):\n            The final `beta` value.\n        beta_schedule (`str`, defaults to `\"linear\"`):\n            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from\n            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.\n        trained_betas (`np.ndarray`, *optional*):\n            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.\n        solver_order (`int`, default `2`):\n            The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`\n            due to the UniC. 
It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for\n            unconditional sampling.\n        prediction_type (`str`, defaults to `epsilon`, *optional*):\n            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),\n            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen\n            Video](https://imagen.research.google/video/paper.pdf) paper).\n        thresholding (`bool`, defaults to `False`):\n            Whether to use the \"dynamic thresholding\" method. This is unsuitable for latent-space diffusion models such\n            as Stable Diffusion.\n        dynamic_thresholding_ratio (`float`, defaults to 0.995):\n            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.\n        sample_max_value (`float`, defaults to 1.0):\n            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.\n        predict_x0 (`bool`, defaults to `True`):\n            Whether to use the updating algorithm on the predicted x0.\n        solver_type (`str`, default `bh2`):\n            Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`\n            otherwise.\n        lower_order_final (`bool`, default `True`):\n            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can\n            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.\n        disable_corrector (`list`, default `[]`):\n            Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`\n            and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale.
Corrector is\n            usually disabled during the first few steps.\n        solver_p (`SchedulerMixin`, default `None`):\n            Any other scheduler; if specified, the algorithm becomes `solver_p + UniC`.\n        use_karras_sigmas (`bool`, *optional*, defaults to `False`):\n            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,\n            the sigmas are determined according to a sequence of noise levels {σi}.\n        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):\n            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.\n        use_beta_sigmas (`bool`, *optional*, defaults to `False`):\n            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta\n            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.\n        use_flow_sigmas (`bool`, *optional*, defaults to `False`):\n            Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.\n        timestep_spacing (`str`, defaults to `\"linspace\"`):\n            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and\n            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.\n        steps_offset (`int`, defaults to 0):\n            An offset added to the inference steps, as required by some model families.\n        final_sigmas_type (`str`, defaults to `\"zero\"`):\n            The final `sigma` value for the noise schedule during the sampling process. If `\"sigma_min\"`, the final\n            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.\n        rescale_betas_zero_snr (`bool`, defaults to `False`):\n            Whether to rescale the betas to have zero terminal SNR.
This enables the model to generate very bright and\n            dark samples instead of limiting it to samples with medium brightness. Loosely related to\n            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).\n    \"\"\"\n\n    _compatibles = [e.name for e in KarrasDiffusionSchedulers]\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        beta_start: float = 0.0001,\n        beta_end: float = 0.02,\n        beta_schedule: str = \"linear\",\n        trained_betas: np.ndarray | list[float] | None = None,\n        solver_order: int = 2,\n        prediction_type: str = \"epsilon\",\n        thresholding: bool = False,\n        dynamic_thresholding_ratio: float = 0.995,\n        sample_max_value: float = 1.0,\n        predict_x0: bool = True,\n        solver_type: str = \"bh2\",\n        lower_order_final: bool = True,\n        disable_corrector: list[int] = [],\n        solver_p: SchedulerMixin = None,\n        use_karras_sigmas: bool | None = False,\n        use_exponential_sigmas: bool | None = False,\n        use_beta_sigmas: bool | None = False,\n        use_flow_sigmas: bool | None = False,\n        flow_shift: float | None = 1.0,\n        timestep_spacing: str = \"linspace\",\n        steps_offset: int = 0,\n        final_sigmas_type: str | None = \"zero\",  # \"zero\", \"sigma_min\"\n        rescale_betas_zero_snr: bool = False,\n        use_dynamic_shifting: bool = False,\n        time_shift_type: str = \"exponential\",\n    ):\n        if self.config.use_beta_sigmas and not is_scipy_available():\n            raise ImportError(\n                \"Make sure to install scipy if you want to use beta sigmas.\"\n            )\n        if (\n            sum(\n                [\n                    self.config.use_beta_sigmas,\n                    self.config.use_exponential_sigmas,\n    
                self.config.use_karras_sigmas,\n                ]\n            )\n            > 1\n        ):\n            raise ValueError(\n                \"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.\"\n            )\n        if trained_betas is not None:\n            self.betas = torch.tensor(trained_betas, dtype=torch.float32)\n        elif beta_schedule == \"linear\":\n            self.betas = torch.linspace(\n                beta_start, beta_end, num_train_timesteps, dtype=torch.float32\n            )\n        elif beta_schedule == \"scaled_linear\":\n            # this schedule is very specific to the latent diffusion model.\n            self.betas = (\n                torch.linspace(\n                    beta_start**0.5,\n                    beta_end**0.5,\n                    num_train_timesteps,\n                    dtype=torch.float32,\n                )\n                ** 2\n            )\n        elif beta_schedule == \"squaredcos_cap_v2\":\n            # Glide cosine schedule\n            self.betas = betas_for_alpha_bar(num_train_timesteps)\n        else:\n            raise NotImplementedError(\n                f\"{beta_schedule} is not implemented for {self.__class__}\"\n            )\n\n        if rescale_betas_zero_snr:\n            self.betas = rescale_zero_terminal_snr(self.betas)\n\n        self.alphas = 1.0 - self.betas\n        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)\n\n        if rescale_betas_zero_snr:\n            # Close to 0 without being 0 so first sigma is not inf\n            # FP16 smallest positive subnormal works well here\n            self.alphas_cumprod[-1] = 2**-24\n\n        # Currently we only support VP-type noise schedule\n        self.alpha_t = torch.sqrt(self.alphas_cumprod)\n        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)\n        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)\n        self.sigmas = ((1 - 
self.alphas_cumprod) / self.alphas_cumprod) ** 0.5\n\n        # standard deviation of the initial noise distribution\n        self.init_noise_sigma = 1.0\n\n        if solver_type not in [\"bh1\", \"bh2\"]:\n            if solver_type in [\"midpoint\", \"heun\", \"logrho\"]:\n                self.register_to_config(solver_type=\"bh2\")\n            else:\n                raise NotImplementedError(\n                    f\"{solver_type} is not implemented for {self.__class__}\"\n                )\n\n        self.predict_x0 = predict_x0\n        # settable values\n        self.num_inference_steps = None\n        timesteps = np.linspace(\n            0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32\n        )[::-1].copy()\n        self.timesteps = torch.from_numpy(timesteps)\n        self.num_train_timesteps = num_train_timesteps\n        self.model_outputs = [None] * solver_order\n        self.timestep_list = [None] * solver_order\n        self.lower_order_nums = 0\n        self.disable_corrector = disable_corrector\n        self.solver_p = solver_p\n        self.last_sample = None\n        self._step_index = None\n        self._begin_index = None\n        self.sigmas = self.sigmas.to(\"cpu\")  # to avoid too much CPU/GPU communication\n\n        BaseScheduler.__init__(self)\n\n    @property\n    def step_index(self):\n        \"\"\"\n        The index counter for the current timestep. It increases by 1 after each scheduler step.\n        \"\"\"\n        return self._step_index\n\n    @property\n    def begin_index(self):\n        \"\"\"\n        The index for the first timestep.
It should be set from the pipeline with the `set_begin_index` method.\n        \"\"\"\n        return self._begin_index\n\n    def set_shift(self, shift: float) -> None:\n        self.config.flow_shift = shift\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index\n    def set_begin_index(self, begin_index: int = 0):\n        \"\"\"\n        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.\n\n        Args:\n            begin_index (`int`):\n                The begin index for the scheduler.\n        \"\"\"\n        self._begin_index = begin_index\n\n    def set_timesteps(\n        self,\n        num_inference_steps: int,\n        device: str | torch.device = None,\n        mu: float | None = None,\n    ):\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain (to be run before inference).\n\n        Args:\n            num_inference_steps (`int`):\n                The number of diffusion steps used when generating samples with a pre-trained model.\n            device (`str` or `torch.device`, *optional*):\n                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.\n            mu (`float`, *optional*):\n                If given, `flow_shift` is set to `exp(mu)`; requires `use_dynamic_shifting=True` and `time_shift_type=\"exponential\"`.\n        \"\"\"\n        # \"linspace\", \"leading\", \"trailing\" corresponds to annotation of Table 2.
of https://huggingface.co/papers/2305.08891\n        if mu is not None:\n            assert (\n                self.config.use_dynamic_shifting\n                and self.config.time_shift_type == \"exponential\"\n            )\n            self.config.flow_shift = np.exp(mu)\n        if self.config.timestep_spacing == \"linspace\":\n            timesteps = (\n                np.linspace(\n                    0, self.config.num_train_timesteps - 1, num_inference_steps + 1\n                )\n                .round()[::-1][:-1]\n                .copy()\n                .astype(np.int64)\n            )\n        elif self.config.timestep_spacing == \"leading\":\n            step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)\n            # creates integer timesteps by multiplying by ratio\n            # casting to int to avoid issues when num_inference_step is power of 3\n            timesteps = (\n                (np.arange(0, num_inference_steps + 1) * step_ratio)\n                .round()[::-1][:-1]\n                .copy()\n                .astype(np.int64)\n            )\n            timesteps += self.config.steps_offset\n        elif self.config.timestep_spacing == \"trailing\":\n            step_ratio = self.config.num_train_timesteps / num_inference_steps\n            # creates integer timesteps by multiplying by ratio\n            # casting to int to avoid issues when num_inference_step is power of 3\n            timesteps = (\n                np.arange(self.config.num_train_timesteps, 0, -step_ratio)\n                .round()\n                .copy()\n                .astype(np.int64)\n            )\n            timesteps -= 1\n        else:\n            raise ValueError(\n                f\"{self.config.timestep_spacing} is not supported. 
Please make sure to choose one of 'linspace', 'leading' or 'trailing'.\"\n            )\n\n        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)\n        if self.config.use_karras_sigmas:\n            log_sigmas = np.log(sigmas)\n            sigmas = np.flip(sigmas).copy()\n            sigmas = self._convert_to_karras(\n                in_sigmas=sigmas, num_inference_steps=num_inference_steps\n            )\n            timesteps = np.array(\n                [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]\n            ).round()\n            if self.config.final_sigmas_type == \"sigma_min\":\n                sigma_last = sigmas[-1]\n            elif self.config.final_sigmas_type == \"zero\":\n                sigma_last = 0\n            else:\n                raise ValueError(\n                    f\"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}\"\n                )\n            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)\n        elif self.config.use_exponential_sigmas:\n            log_sigmas = np.log(sigmas)\n            sigmas = np.flip(sigmas).copy()\n            sigmas = self._convert_to_exponential(\n                in_sigmas=sigmas, num_inference_steps=num_inference_steps\n            )\n            timesteps = np.array(\n                [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]\n            )\n            if self.config.final_sigmas_type == \"sigma_min\":\n                sigma_last = sigmas[-1]\n            elif self.config.final_sigmas_type == \"zero\":\n                sigma_last = 0\n            else:\n                raise ValueError(\n                    f\"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}\"\n                )\n            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)\n        elif self.config.use_beta_sigmas:\n            log_sigmas = 
np.log(sigmas)\n            sigmas = np.flip(sigmas).copy()\n            sigmas = self._convert_to_beta(\n                in_sigmas=sigmas, num_inference_steps=num_inference_steps\n            )\n            timesteps = np.array(\n                [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]\n            )\n            if self.config.final_sigmas_type == \"sigma_min\":\n                sigma_last = sigmas[-1]\n            elif self.config.final_sigmas_type == \"zero\":\n                sigma_last = 0\n            else:\n                raise ValueError(\n                    f\"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}\"\n                )\n            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)\n        elif self.config.use_flow_sigmas:\n            alphas = np.linspace(\n                1, 1 / self.config.num_train_timesteps, num_inference_steps + 1\n            )\n            sigmas = 1.0 - alphas\n            sigmas = np.flip(\n                self.config.flow_shift\n                * sigmas\n                / (1 + (self.config.flow_shift - 1) * sigmas)\n            )[:-1].copy()\n            timesteps = (sigmas * self.config.num_train_timesteps).copy()\n            if self.config.final_sigmas_type == \"sigma_min\":\n                sigma_last = sigmas[-1]\n            elif self.config.final_sigmas_type == \"zero\":\n                sigma_last = 0\n            else:\n                raise ValueError(\n                    f\"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}\"\n                )\n            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)\n        else:\n            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)\n            if self.config.final_sigmas_type == \"sigma_min\":\n                sigma_last = (\n                    (1 - self.alphas_cumprod[0]) / 
self.alphas_cumprod[0]\n                ) ** 0.5\n            elif self.config.final_sigmas_type == \"zero\":\n                sigma_last = 0\n            else:\n                raise ValueError(\n                    f\"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}\"\n                )\n            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)\n\n        self.sigmas = torch.from_numpy(sigmas)\n        self.timesteps = torch.from_numpy(timesteps).to(\n            device=device, dtype=torch.int64\n        )\n\n        self.num_inference_steps = len(timesteps)\n\n        self.model_outputs = [\n            None,\n        ] * self.config.solver_order\n        self.lower_order_nums = 0\n        self.last_sample = None\n        if self.solver_p:\n            self.solver_p.set_timesteps(self.num_inference_steps, device=device)\n\n        # add an index counter for schedulers that allow duplicated timesteps\n        self._step_index = None\n        self._begin_index = None\n        self.sigmas = self.sigmas.to(\"cpu\")  # to avoid too much CPU/GPU communication\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample\n    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        \"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the\n        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by\n        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing\n        pixels from saturation at each step. 
We find that dynamic thresholding results in significantly better\n        photorealism as well as better image-text alignment, especially when using very large guidance weights.\"\n\n        https://huggingface.co/papers/2205.11487\n        \"\"\"\n        dtype = sample.dtype\n        batch_size, channels, *remaining_dims = sample.shape\n\n        if dtype not in (torch.float32, torch.float64):\n            sample = (\n                sample.float()\n            )  # upcast for quantile calculation, and clamp not implemented for cpu half\n\n        # Flatten sample for doing quantile calculation along each image\n        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))\n\n        abs_sample = sample.abs()  # \"a certain percentile absolute pixel value\"\n\n        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)\n        s = torch.clamp(\n            s, min=1, max=self.config.sample_max_value\n        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]\n        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0\n        sample = (\n            torch.clamp(sample, -s, s) / s\n        )  # \"we threshold xt0 to the range [-s, s] and then divide by s\"\n\n        sample = sample.reshape(batch_size, channels, *remaining_dims)\n        sample = sample.to(dtype)\n\n        return sample\n\n    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t\n    def _sigma_to_t(self, sigma, log_sigmas):\n        # get log sigma\n        log_sigma = np.log(np.maximum(sigma, 1e-10))\n\n        # get distribution\n        dists = log_sigma - log_sigmas[:, np.newaxis]\n\n        # get sigmas range\n        low_idx = (\n            np.cumsum((dists >= 0), axis=0)\n            .argmax(axis=0)\n            .clip(max=log_sigmas.shape[0] - 2)\n        )\n        high_idx = low_idx + 1\n\n        low = log_sigmas[low_idx]\n        high = 
log_sigmas[high_idx]\n\n        # interpolate sigmas\n        w = (low - log_sigma) / (low - high)\n        w = np.clip(w, 0, 1)\n\n        # transform interpolation to time range\n        t = (1 - w) * low_idx + w * high_idx\n        t = t.reshape(sigma.shape)\n        return t\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t\n    def _sigma_to_alpha_sigma_t(self, sigma):\n        if self.config.use_flow_sigmas:\n            alpha_t = 1 - sigma\n            sigma_t = sigma\n        else:\n            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)\n            sigma_t = sigma * alpha_t\n\n        return alpha_t, sigma_t\n\n    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras\n    def _convert_to_karras(\n        self, in_sigmas: torch.Tensor, num_inference_steps\n    ) -> torch.Tensor:\n        \"\"\"Constructs the noise schedule of Karras et al. (2022).\"\"\"\n\n        # Hack to make sure that other schedulers which copy this function don't break\n        # TODO: Add this logic to the other schedulers\n        if hasattr(self.config, \"sigma_min\"):\n            sigma_min = self.config.sigma_min\n        else:\n            sigma_min = None\n\n        if hasattr(self.config, \"sigma_max\"):\n            sigma_max = self.config.sigma_max\n        else:\n            sigma_max = None\n\n        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()\n        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()\n\n        rho = 7.0  # 7.0 is the value used in the paper\n        ramp = np.linspace(0, 1, num_inference_steps)\n        min_inv_rho = sigma_min ** (1 / rho)\n        max_inv_rho = sigma_max ** (1 / rho)\n        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho\n        return sigmas\n\n    # Copied from 
diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential\n    def _convert_to_exponential(\n        self, in_sigmas: torch.Tensor, num_inference_steps: int\n    ) -> torch.Tensor:\n        \"\"\"Constructs an exponential noise schedule.\"\"\"\n\n        # Hack to make sure that other schedulers which copy this function don't break\n        # TODO: Add this logic to the other schedulers\n        if hasattr(self.config, \"sigma_min\"):\n            sigma_min = self.config.sigma_min\n        else:\n            sigma_min = None\n\n        if hasattr(self.config, \"sigma_max\"):\n            sigma_max = self.config.sigma_max\n        else:\n            sigma_max = None\n\n        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()\n        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()\n\n        sigmas = np.exp(\n            np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)\n        )\n        return sigmas\n\n    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta\n    def _convert_to_beta(\n        self,\n        in_sigmas: torch.Tensor,\n        num_inference_steps: int,\n        alpha: float = 0.6,\n        beta: float = 0.6,\n    ) -> torch.Tensor:\n        \"\"\"From \"Beta Sampling is All You Need\" [arXiv:2407.12173] (Lee et. 
al, 2024)\"\"\"\n\n        # Hack to make sure that other schedulers which copy this function don't break\n        # TODO: Add this logic to the other schedulers\n        if hasattr(self.config, \"sigma_min\"):\n            sigma_min = self.config.sigma_min\n        else:\n            sigma_min = None\n\n        if hasattr(self.config, \"sigma_max\"):\n            sigma_max = self.config.sigma_max\n        else:\n            sigma_max = None\n\n        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()\n        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()\n\n        sigmas = np.array(\n            [\n                sigma_min + (ppf * (sigma_max - sigma_min))\n                for ppf in [\n                    scipy.stats.beta.ppf(timestep, alpha, beta)\n                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)\n                ]\n            ]\n        )\n        return sigmas\n\n    def convert_model_output(\n        self,\n        model_output: torch.Tensor,\n        *args,\n        sample: torch.Tensor = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Convert the model output to the corresponding type the UniPC algorithm needs.\n\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from the learned diffusion model.\n            timestep (`int`):\n                The current discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n\n        Returns:\n            `torch.Tensor`:\n                The converted model output.\n        \"\"\"\n        timestep = args[0] if len(args) > 0 else kwargs.pop(\"timestep\", None)\n        if sample is None:\n            if len(args) > 1:\n                sample = args[1]\n            else:\n                raise ValueError(\"missing `sample` as a required keyword argument\")\n        if 
timestep is not None:\n            deprecate(\n                \"timesteps\",\n                \"1.0.0\",\n                \"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n\n        sigma = self.sigmas[self.step_index]\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)\n\n        if self.predict_x0:\n            if self.config.prediction_type == \"epsilon\":\n                x0_pred = (sample - sigma_t * model_output) / alpha_t\n            elif self.config.prediction_type == \"sample\":\n                x0_pred = model_output\n            elif self.config.prediction_type == \"v_prediction\":\n                x0_pred = alpha_t * sample - sigma_t * model_output\n            elif self.config.prediction_type == \"flow_prediction\":\n                sigma_t = self.sigmas[self.step_index]\n                x0_pred = sample - sigma_t * model_output\n            else:\n                raise ValueError(\n                    f\"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, \"\n                    \"`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler.\"\n                )\n\n            if self.config.thresholding:\n                x0_pred = self._threshold_sample(x0_pred)\n\n            return x0_pred\n        else:\n            if self.config.prediction_type == \"epsilon\":\n                return model_output\n            elif self.config.prediction_type == \"sample\":\n                epsilon = (sample - alpha_t * model_output) / sigma_t\n                return epsilon\n            elif self.config.prediction_type == \"v_prediction\":\n                epsilon = alpha_t * model_output + sigma_t * sample\n                return epsilon\n            else:\n                raise ValueError(\n                    f\"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, 
`sample`, or\"\n                    \" `v_prediction` for the UniPCMultistepScheduler.\"\n                )\n\n    def multistep_uni_p_bh_update(\n        self,\n        model_output: torch.Tensor,\n        *args,\n        sample: torch.Tensor = None,\n        order: int = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"\n        One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.\n\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from the learned diffusion model at the current timestep.\n            prev_timestep (`int`):\n                The previous discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n            order (`int`):\n                The order of UniP at this timestep (corresponds to the *p* in UniPC-p).\n\n        Returns:\n            `torch.Tensor`:\n                The sample tensor at the previous timestep.\n        \"\"\"\n        prev_timestep = args[0] if len(args) > 0 else kwargs.pop(\"prev_timestep\", None)\n        if sample is None:\n            if len(args) > 1:\n                sample = args[1]\n            else:\n                raise ValueError(\"missing `sample` as a required keyword argument\")\n        if order is None:\n            if len(args) > 2:\n                order = args[2]\n            else:\n                raise ValueError(\"missing `order` as a required keyword argument\")\n        if prev_timestep is not None:\n            deprecate(\n                \"prev_timestep\",\n                \"1.0.0\",\n                \"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n        model_output_list = self.model_outputs\n\n        s0 = self.timestep_list[-1]\n        m0 = model_output_list[-1]\n        x = 
sample\n\n        if self.solver_p:\n            x_t = self.solver_p.step(model_output, s0, x).prev_sample\n            return x_t\n\n        sigma_t, sigma_s0 = (\n            self.sigmas[self.step_index + 1],\n            self.sigmas[self.step_index],\n        )\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)\n        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)\n\n        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)\n        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)\n\n        h = lambda_t - lambda_s0\n        device = sample.device\n\n        rks = []\n        D1s = []\n        for i in range(1, order):\n            si = self.step_index - i\n            mi = model_output_list[-(i + 1)]\n            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])\n            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)\n            rk = (lambda_si - lambda_s0) / h\n            rks.append(rk)\n            D1s.append((mi - m0) / rk)\n\n        rks.append(1.0)\n        rks = torch.tensor(rks, device=device)\n\n        R = []\n        b = []\n\n        hh = -h if self.predict_x0 else h\n        h_phi_1 = torch.expm1(hh)  # h\\phi_1(h) = e^h - 1\n        h_phi_k = h_phi_1 / hh - 1\n\n        factorial_i = 1\n\n        if self.config.solver_type == \"bh1\":\n            B_h = hh\n        elif self.config.solver_type == \"bh2\":\n            B_h = torch.expm1(hh)\n        else:\n            raise NotImplementedError()\n\n        for i in range(1, order + 1):\n            R.append(torch.pow(rks, i - 1))\n            b.append(h_phi_k * factorial_i / B_h)\n            factorial_i *= i + 1\n            h_phi_k = h_phi_k / hh - 1 / factorial_i\n\n        R = torch.stack(R)\n        b = torch.tensor(b, device=device)\n\n        if len(D1s) > 0:\n            D1s = torch.stack(D1s, dim=1)  # (B, K)\n            # for order 2, we use a simplified version\n            if order == 2:\n                rhos_p = 
torch.tensor([0.5], dtype=x.dtype, device=device)\n            else:\n                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)\n        else:\n            D1s = None\n\n        if self.predict_x0:\n            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0\n            if D1s is not None:\n                pred_res = torch.einsum(\"k,bkc...->bc...\", rhos_p, D1s)\n            else:\n                pred_res = 0\n            x_t = x_t_ - alpha_t * B_h * pred_res\n        else:\n            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0\n            if D1s is not None:\n                pred_res = torch.einsum(\"k,bkc...->bc...\", rhos_p, D1s)\n            else:\n                pred_res = 0\n            x_t = x_t_ - sigma_t * B_h * pred_res\n\n        x_t = x_t.to(x.dtype)\n        return x_t\n\n    def multistep_uni_c_bh_update(\n        self,\n        this_model_output: torch.Tensor,\n        *args,\n        last_sample: torch.Tensor = None,\n        this_sample: torch.Tensor = None,\n        order: int = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"\n        One step for the UniC (B(h) version).\n\n        Args:\n            this_model_output (`torch.Tensor`):\n                The model outputs at `x_t`.\n            this_timestep (`int`):\n                The current timestep `t`.\n            last_sample (`torch.Tensor`):\n                The generated sample before the last predictor `x_{t-1}`.\n            this_sample (`torch.Tensor`):\n                The generated sample after the last predictor `x_{t}`.\n            order (`int`):\n                The `p` of UniC-p at this step. 
The effective order of accuracy should be `order + 1`.\n\n        Returns:\n            `torch.Tensor`:\n                The corrected sample tensor at the current timestep.\n        \"\"\"\n        this_timestep = args[0] if len(args) > 0 else kwargs.pop(\"this_timestep\", None)\n        if last_sample is None:\n            if len(args) > 1:\n                last_sample = args[1]\n            else:\n                raise ValueError(\"missing `last_sample` as a required keyword argument\")\n        if this_sample is None:\n            if len(args) > 2:\n                this_sample = args[2]\n            else:\n                raise ValueError(\"missing `this_sample` as a required keyword argument\")\n        if order is None:\n            if len(args) > 3:\n                order = args[3]\n            else:\n                raise ValueError(\"missing `order` as a required keyword argument\")\n        if this_timestep is not None:\n            deprecate(\n                \"this_timestep\",\n                \"1.0.0\",\n                \"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n\n        model_output_list = self.model_outputs\n\n        m0 = model_output_list[-1]\n        x = last_sample\n        x_t = this_sample\n        model_t = this_model_output\n\n        sigma_t, sigma_s0 = (\n            self.sigmas[self.step_index],\n            self.sigmas[self.step_index - 1],\n        )\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)\n        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)\n\n        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)\n        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)\n\n        h = lambda_t - lambda_s0\n        device = this_sample.device\n\n        rks = []\n        D1s = []\n        for i in range(1, order):\n            si = self.step_index - (i + 1)\n            mi = 
model_output_list[-(i + 1)]\n            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])\n            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)\n            rk = (lambda_si - lambda_s0) / h\n            rks.append(rk)\n            D1s.append((mi - m0) / rk)\n\n        rks.append(1.0)\n        rks = torch.tensor(rks, device=device)\n\n        R = []\n        b = []\n\n        hh = -h if self.predict_x0 else h\n        h_phi_1 = torch.expm1(hh)  # h\\phi_1(h) = e^h - 1\n        h_phi_k = h_phi_1 / hh - 1\n\n        factorial_i = 1\n\n        if self.config.solver_type == \"bh1\":\n            B_h = hh\n        elif self.config.solver_type == \"bh2\":\n            B_h = torch.expm1(hh)\n        else:\n            raise NotImplementedError()\n\n        for i in range(1, order + 1):\n            R.append(torch.pow(rks, i - 1))\n            b.append(h_phi_k * factorial_i / B_h)\n            factorial_i *= i + 1\n            h_phi_k = h_phi_k / hh - 1 / factorial_i\n\n        R = torch.stack(R)\n        b = torch.tensor(b, device=device)\n\n        if len(D1s) > 0:\n            D1s = torch.stack(D1s, dim=1)\n        else:\n            D1s = None\n\n        # for order 1, we use a simplified version\n        if order == 1:\n            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)\n        else:\n            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)\n\n        if self.predict_x0:\n            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0\n            if D1s is not None:\n                corr_res = torch.einsum(\"k,bkc...->bc...\", rhos_c[:-1], D1s)\n            else:\n                corr_res = 0\n            D1_t = model_t - m0\n            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)\n        else:\n            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0\n            if D1s is not None:\n                corr_res = torch.einsum(\"k,bkc...->bc...\", rhos_c[:-1], D1s)\n            
else:\n                corr_res = 0\n            D1_t = model_t - m0\n            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)\n        x_t = x_t.to(x.dtype)\n        return x_t\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep\n    def index_for_timestep(self, timestep, schedule_timesteps=None):\n        if schedule_timesteps is None:\n            schedule_timesteps = self.timesteps\n\n        index_candidates = (schedule_timesteps == timestep).nonzero()\n\n        if len(index_candidates) == 0:\n            step_index = len(self.timesteps) - 1\n        # The sigma index that is taken for the **very** first `step`\n        # is always the second index (or the last index if there is only 1)\n        # This way we can ensure we don't accidentally skip a sigma in\n        # case we start in the middle of the denoising schedule (e.g. for image-to-image)\n        elif len(index_candidates) > 1:\n            step_index = index_candidates[1].item()\n        else:\n            step_index = index_candidates[0].item()\n\n        return step_index\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index\n    def _init_step_index(self, timestep):\n        \"\"\"\n        Initialize the step_index counter for the scheduler.\n        \"\"\"\n\n        if self.begin_index is None:\n            if isinstance(timestep, torch.Tensor):\n                timestep = timestep.to(self.timesteps.device)\n            self._step_index = self.index_for_timestep(timestep)\n        else:\n            self._step_index = self._begin_index\n\n    def step(\n        self,\n        model_output: torch.Tensor,\n        timestep: int | torch.Tensor,\n        sample: torch.Tensor,\n        return_dict: bool = True,\n    ) -> SchedulerOutput | tuple:\n        \"\"\"\n        Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the sample with\n        the multistep UniPC.\n\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from learned diffusion model.\n            timestep (`int`):\n                The current discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n            return_dict (`bool`):\n                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.\n\n        Returns:\n            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:\n                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a\n                tuple is returned where the first element is the sample tensor.\n\n        \"\"\"\n        if self.num_inference_steps is None:\n            raise ValueError(\n                \"Number of inference steps is 'None', you need to call 'set_timesteps' after creating the scheduler\"\n            )\n\n        if self.step_index is None:\n            self._init_step_index(timestep)\n\n        use_corrector = (\n            self.step_index > 0\n            and self.step_index - 1 not in self.disable_corrector\n            and self.last_sample is not None\n        )\n\n        model_output_convert = self.convert_model_output(model_output, sample=sample)\n        if use_corrector:\n            sample = self.multistep_uni_c_bh_update(\n                this_model_output=model_output_convert,\n                last_sample=self.last_sample,\n                this_sample=sample,\n                order=self.this_order,\n            )\n\n        for i in range(self.config.solver_order - 1):\n            self.model_outputs[i] = self.model_outputs[i + 1]\n            self.timestep_list[i] = self.timestep_list[i + 1]\n\n        self.model_outputs[-1] = model_output_convert\n        self.timestep_list[-1] = 
timestep\n\n        if self.config.lower_order_final:\n            this_order = min(\n                self.config.solver_order, len(self.timesteps) - self.step_index\n            )\n        else:\n            this_order = self.config.solver_order\n\n        self.this_order = min(\n            this_order, self.lower_order_nums + 1\n        )  # warmup for multistep\n        assert self.this_order > 0\n\n        self.last_sample = sample\n        prev_sample = self.multistep_uni_p_bh_update(\n            model_output=model_output,  # pass the original non-converted model output, in case solver-p is used\n            sample=sample,\n            order=self.this_order,\n        )\n\n        if self.lower_order_nums < self.config.solver_order:\n            self.lower_order_nums += 1\n\n        # upon completion increase step index by one\n        self._step_index += 1\n\n        if not return_dict:\n            return (prev_sample,)\n\n        return SchedulerOutput(prev_sample=prev_sample)\n\n    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:\n        \"\"\"\n        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the\n        current timestep.\n\n        Args:\n            sample (`torch.Tensor`):\n                The input sample.\n\n        Returns:\n            `torch.Tensor`:\n                A scaled input sample.\n        \"\"\"\n        return sample\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise\n    def add_noise(\n        self,\n        original_samples: torch.Tensor,\n        noise: torch.Tensor,\n        timesteps: torch.IntTensor,\n    ) -> torch.Tensor:\n        # Make sure sigmas and timesteps have the same device and dtype as original_samples\n        sigmas = self.sigmas.to(\n            device=original_samples.device, dtype=original_samples.dtype\n        )\n        if original_samples.device.type == 
\"mps\" and torch.is_floating_point(timesteps):\n            # mps does not support float64\n            schedule_timesteps = self.timesteps.to(\n                original_samples.device, dtype=torch.float32\n            )\n            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)\n        else:\n            schedule_timesteps = self.timesteps.to(original_samples.device)\n            timesteps = timesteps.to(original_samples.device)\n\n        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index\n        if self.begin_index is None:\n            step_indices = [\n                self.index_for_timestep(t, schedule_timesteps) for t in timesteps\n            ]\n        elif self.step_index is not None:\n            # add_noise is called after first denoising step (for inpainting)\n            step_indices = [self.step_index] * timesteps.shape[0]\n        else:\n            # add noise is called before first denoising step to create initial latent(img2img)\n            step_indices = [self.begin_index] * timesteps.shape[0]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < len(original_samples.shape):\n            sigma = sigma.unsqueeze(-1)\n\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)\n        noisy_samples = alpha_t * original_samples + sigma_t * noise\n        return noisy_samples\n\n    def __len__(self):\n        return self.config.num_train_timesteps\n\n\nEntryClass = UniPCMultistepScheduler\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/utils.py\n\"\"\"Utils for model executor.\"\"\"\n\nfrom typing import Any\n\nimport torch\n\n\ndef set_weight_attrs(\n    weight: torch.Tensor,\n    weight_attrs: dict[str, Any] | None,\n):\n    \"\"\"Set attributes on a weight tensor.\n\n    This method is used to set attributes on a weight tensor. This method\n    will not overwrite existing attributes.\n\n    Args:\n        weight: The weight tensor.\n        weight_attrs: A dictionary of attributes to set on the weight tensor.\n    \"\"\"\n    if weight_attrs is None:\n        return\n    for key, value in weight_attrs.items():\n        assert not hasattr(weight, key), f\"Overwriting existing tensor attribute: {key}\"\n\n        # NOTE(woosuk): During weight loading, we often do something like:\n        # narrowed_tensor = param.data.narrow(0, offset, len)\n        # narrowed_tensor.copy_(real_weight)\n        # expecting narrowed_tensor and param.data to share the same storage.\n        # However, on TPUs, narrowed_tensor will lazily propagate to the base\n        # tensor, which is param.data, leading to the redundant memory usage.\n        # This sometimes causes OOM errors during model loading. 
To avoid this,\n        # we sync the param tensor after its weight loader is called.\n        # TODO(woosuk): Remove this hack once we have a better solution.\n        from sglang.multimodal_gen.runtime.platforms import current_platform\n\n        if current_platform.is_tpu() and key == \"weight_loader\":\n            value = _make_synced_weight_loader(value)\n        setattr(weight, key, value)\n\n\ndef _make_synced_weight_loader(original_weight_loader) -> Any:\n\n    def _synced_weight_loader(param, *args, **kwargs):\n        original_weight_loader(param, *args, **kwargs)\n        torch._sync(param)\n\n    return _synced_weight_loader\n\n\ndef extract_layer_index(layer_name: str) -> int:\n    \"\"\"\n    Extract the layer index from the module name.\n    Examples:\n    - \"encoder.layers.0\" -> 0\n    - \"encoder.layers.1.self_attn\" -> 1\n    - \"2.self_attn\" -> 2\n    - \"model.encoder.layers.0.sub.1\" -> ValueError\n    \"\"\"\n    subnames = layer_name.split(\".\")\n    int_vals: list[int] = []\n    for subname in subnames:\n        try:\n            int_vals.append(int(subname))\n        except ValueError:\n            continue\n    assert len(int_vals) == 1, (\n        f\"layer name {layer_name} should\" \" only contain one integer\"\n    )\n    return int_vals[0]\n\n\ndef modulate(\n    x: torch.Tensor,\n    shift: torch.Tensor | None = None,\n    scale: torch.Tensor | None = None,\n) -> torch.Tensor:\n    \"\"\"modulate by shift and scale\"\"\"\n    if scale is None and shift is None:\n        return x\n    elif shift is None:\n        return x * (1 + scale.unsqueeze(1))  # type: ignore[union-attr]\n    elif scale is None:\n        return x + shift.unsqueeze(1)  # type: ignore[union-attr]\n    else:\n        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(\n            1\n        )  # type: ignore[union-attr]\n\n\ndef pred_noise_to_pred_video(\n    pred_noise: torch.Tensor,\n    noise_input_latent: torch.Tensor,\n    timestep: torch.Tensor,\n    
scheduler: Any,\n) -> torch.Tensor:\n    \"\"\"\n    Convert predicted noise to clean latent.\n\n    Args:\n    pred_noise: the predicted noise with shape [B, C, H, W]\n        where B is batch_size or batch_size * num_frames\n    noise_input_latent: the noisy latent with shape [B, C, H, W],\n    timestep: the timestep with shape [1] or [bs * num_frames] or [bs, num_frames]\n    scheduler: the scheduler\n\n    Returns:\n        the predicted video with shape [B, C, H, W]\n    \"\"\"\n    # If timestep is [bs, num_frames]\n    if timestep.ndim == 2:\n        timestep = timestep.flatten(0, 1)\n        assert timestep.numel() == noise_input_latent.shape[0]\n    elif timestep.ndim == 1:\n        # If timestep is [1]\n        if timestep.shape[0] == 1:\n            timestep = timestep.expand(noise_input_latent.shape[0])\n        else:\n            assert timestep.numel() == noise_input_latent.shape[0]\n    else:\n        raise ValueError(\n            f\"[pred_noise_to_pred_video] Invalid timestep shape: {timestep.shape}\"\n        )\n    # timestep shape should be [B]\n    dtype = pred_noise.dtype\n    device = pred_noise.device\n    pred_noise = pred_noise.double().to(device)\n    noise_input_latent = noise_input_latent.double().to(device)\n    sigmas = scheduler.sigmas.double().to(device)\n    timesteps = scheduler.timesteps.double().to(device)\n    timestep_id = torch.argmin(\n        (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1\n    )\n    sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)\n    pred_video = noise_input_latent - sigma_t * pred_noise\n    return pred_video.to(dtype)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vaes/autoencoder.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nfrom typing import Dict, Optional, Tuple, Union\n\nimport torch\nfrom diffusers.models.attention_processor import (\n    ADDED_KV_ATTENTION_PROCESSORS,\n    CROSS_ATTENTION_PROCESSORS,\n    Attention,\n    AttentionProcessor,\n    AttnAddedKVProcessor,\n    AttnProcessor,\n    FusedAttnProcessor2_0,\n)\nfrom diffusers.models.autoencoders.vae import (\n    Decoder,\n    DecoderOutput,\n    DiagonalGaussianDistribution,\n    Encoder,\n)\nfrom diffusers.models.modeling_outputs import AutoencoderKLOutput\nfrom torch import nn\n\nfrom sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig\n\n\nclass AutoencoderKL(nn.Module):\n    r\"\"\"\n    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.\n\n    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.\n        out_channels (int,  *optional*, defaults to 3): Number of channels in the output.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"DownEncoderBlock2D\",)`):\n            Tuple of downsample block types.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpDecoderBlock2D\",)`):\n            Tuple of upsample block types.\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):\n            Tuple of block output channels.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`): The activation function to use.\n        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.\n        sample_size (`int`, *optional*, defaults to `32`): Sample input size.\n        scaling_factor (`float`, *optional*, defaults to 0.18215):\n            The component-wise 
standard deviation of the trained latent space computed using the first batch of the\n            training set. This is used to scale the latent space to have unit variance when training the diffusion\n            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the\n            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1\n            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image\n            Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.\n        force_upcast (`bool`, *optional*, default to `True`):\n            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE\n            can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`\n            can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix\n        mid_block_add_attention (`bool`, *optional*, default to `True`):\n            If enabled, the mid_block of the Encoder and Decoder will have attention blocks. 
If set to false, the\n            mid_block will only have resnet blocks\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n    _no_split_modules = [\"BasicTransformerBlock\", \"ResnetBlock2D\"]\n\n    def __init__(\n        self,\n        config: FluxVAEConfig,\n    ):\n        super().__init__()\n        self.config = config\n        arch_config = config.arch_config\n\n        in_channels = arch_config.in_channels\n        out_channels = arch_config.out_channels\n        down_block_types = arch_config.down_block_types\n        up_block_types = arch_config.up_block_types\n        block_out_channels = arch_config.block_out_channels\n        layers_per_block = arch_config.layers_per_block\n        act_fn = arch_config.act_fn\n        latent_channels = arch_config.latent_channels\n        norm_num_groups = arch_config.norm_num_groups\n        sample_size = arch_config.sample_size\n        use_quant_conv = arch_config.use_quant_conv\n        use_post_quant_conv = arch_config.use_post_quant_conv\n        mid_block_add_attention = arch_config.mid_block_add_attention\n\n        # pass init params to Encoder\n        self.encoder = Encoder(\n            in_channels=in_channels,\n            out_channels=latent_channels,\n            down_block_types=down_block_types,\n            block_out_channels=block_out_channels,\n            layers_per_block=layers_per_block,\n            act_fn=act_fn,\n            norm_num_groups=norm_num_groups,\n            double_z=True,\n            mid_block_add_attention=mid_block_add_attention,\n        )\n\n        # pass init params to Decoder\n        self.decoder = Decoder(\n            in_channels=latent_channels,\n            out_channels=out_channels,\n            up_block_types=up_block_types,\n            block_out_channels=block_out_channels,\n            layers_per_block=layers_per_block,\n            norm_num_groups=norm_num_groups,\n            act_fn=act_fn,\n            
mid_block_add_attention=mid_block_add_attention,\n        )\n\n        self.quant_conv = (\n            nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)\n            if use_quant_conv\n            else None\n        )\n        self.post_quant_conv = (\n            nn.Conv2d(latent_channels, latent_channels, 1)\n            if use_post_quant_conv\n            else None\n        )\n\n        self.use_slicing = False\n        self.use_tiling = False\n\n        # only relevant if vae tiling is enabled\n        self.tile_sample_min_size = sample_size\n        sample_size = (\n            self.config.sample_size[0]\n            if isinstance(self.config.sample_size, (list, tuple))\n            else self.config.sample_size\n        )\n        self.tile_latent_min_size = int(\n            sample_size / (2 ** (len(self.config.block_out_channels) - 1))\n        )\n        self.tile_overlap_factor = 0.25\n\n    def enable_tiling(self, use_tiling: bool = True):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        self.use_tiling = use_tiling\n\n    def disable_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing\n        decoding in one step.\n        \"\"\"\n        self.enable_tiling(False)\n\n    def enable_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        self.use_slicing = True\n\n    def disable_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. 
If `enable_slicing` was previously enabled, this method will go back to computing\n        decoding in one step.\n        \"\"\"\n        self.use_slicing = False\n\n    @property\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model with\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(\n            name: str,\n            module: torch.nn.Module,\n            processors: Dict[str, AttentionProcessor],\n        ):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = module.get_processor()\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor\n    def set_attn_processor(\n        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]\n    ):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. 
This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        if all(\n            proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS\n            for proc in self.attn_processors.values()\n        ):\n            processor = AttnAddedKVProcessor()\n        elif all(\n            proc.__class__ in CROSS_ATTENTION_PROCESSORS\n            for proc in self.attn_processors.values()\n        ):\n            processor = AttnProcessor()\n        else:\n            raise ValueError(\n                f\"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}\"\n            )\n\n        
self.set_attn_processor(processor)\n\n    def _encode(self, x: torch.Tensor) -> torch.Tensor:\n        batch_size, num_channels, height, width = x.shape\n\n        if self.use_tiling and (\n            width > self.tile_sample_min_size or height > self.tile_sample_min_size\n        ):\n            return self._tiled_encode(x)\n\n        enc = self.encoder(x)\n        if self.quant_conv is not None:\n            enc = self.quant_conv(enc)\n\n        return enc\n\n    def encode(\n        self, x: torch.Tensor, return_dict: bool = True\n    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:\n        \"\"\"\n        Encode a batch of images into latents.\n\n        Args:\n            x (`torch.Tensor`): Input batch of images.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.\n\n        Returns:\n                The latent representations of the encoded images. 
If `return_dict` is True, a\n                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.\n        \"\"\"\n        if self.use_slicing and x.shape[0] > 1:\n            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]\n            h = torch.cat(encoded_slices)\n        else:\n            h = self._encode(x)\n\n        posterior = DiagonalGaussianDistribution(h)\n\n        if not return_dict:\n            return (posterior,)\n\n        return AutoencoderKLOutput(latent_dist=posterior)\n\n    def _decode(\n        self, z: torch.Tensor, return_dict: bool = True\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        if self.use_tiling and (\n            z.shape[-1] > self.tile_latent_min_size\n            or z.shape[-2] > self.tile_latent_min_size\n        ):\n            return self.tiled_decode(z, return_dict=return_dict)\n\n        if self.post_quant_conv is not None:\n            z = self.post_quant_conv(z)\n\n        dec = self.decoder(z)\n\n        if not return_dict:\n            return (dec,)\n\n        return DecoderOutput(sample=dec)\n\n    def decode(self, z: torch.FloatTensor) -> torch.FloatTensor:\n        \"\"\"\n        Decode a batch of latents into images.\n\n        Args:\n            z (`torch.Tensor`): Input batch of latent vectors.\n\n        Returns:\n            `torch.Tensor`: The decoded images.\n\n        \"\"\"\n\n        if self.use_slicing and z.shape[0] > 1:\n            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]\n            decoded = torch.cat(decoded_slices)\n        else:\n            decoded = self._decode(z).sample\n\n        return decoded\n\n    def blend_v(\n        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int\n    ) -> torch.Tensor:\n        blend_extent =
min(a.shape[2], b.shape[2], blend_extent)\n        for y in range(blend_extent):\n            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[\n                :, :, y, :\n            ] * (y / blend_extent)\n        return b\n\n    def blend_h(\n        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int\n    ) -> torch.Tensor:\n        blend_extent = min(a.shape[3], b.shape[3], blend_extent)\n        for x in range(blend_extent):\n            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[\n                :, :, :, x\n            ] * (x / blend_extent)\n        return b\n\n    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:\n        r\"\"\"Encode a batch of images using a tiled encoder.\n\n        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several\n        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is\n        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the\n        tiles overlap and are blended together to form a smooth output. 
You may still see tile-sized changes in the\n        output, but they should be much less noticeable.\n\n        Args:\n            x (`torch.Tensor`): Input batch of images.\n\n        Returns:\n            `torch.Tensor`:\n                The latent representation of the encoded images.\n        \"\"\"\n\n        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))\n        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)\n        row_limit = self.tile_latent_min_size - blend_extent\n\n        # Split the image into 512x512 tiles and encode them separately.\n        rows = []\n        for i in range(0, x.shape[2], overlap_size):\n            row = []\n            for j in range(0, x.shape[3], overlap_size):\n                tile = x[\n                    :,\n                    :,\n                    i : i + self.tile_sample_min_size,\n                    j : j + self.tile_sample_min_size,\n                ]\n                tile = self.encoder(tile)\n                if self.config.use_quant_conv:\n                    tile = self.quant_conv(tile)\n                row.append(tile)\n            rows.append(row)\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_extent)\n                result_row.append(tile[:, :, :row_limit, :row_limit])\n            result_rows.append(torch.cat(result_row, dim=3))\n\n        enc = torch.cat(result_rows, dim=2)\n        return enc\n\n    def tiled_encode(\n        self, x: torch.Tensor, return_dict: bool = True\n    ) -> AutoencoderKLOutput:\n        r\"\"\"Encode
a batch of images using a tiled encoder.\n\n        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several\n        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is\n        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the\n        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the\n        output, but they should be much less noticeable.\n\n        Args:\n            x (`torch.Tensor`): Input batch of images.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:\n                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain\n                `tuple` is returned.\n        \"\"\"\n        deprecation_message = (\n            \"The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the \"\n            \"implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able \"\n            \"to pass `return_dict`. 
You will also have to create a `DiagonalGaussianDistribution()` from the returned value.\"\n        )\n        # deprecate(\"tiled_encode\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))\n        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)\n        row_limit = self.tile_latent_min_size - blend_extent\n\n        # Split the image into 512x512 tiles and encode them separately.\n        rows = []\n        for i in range(0, x.shape[2], overlap_size):\n            row = []\n            for j in range(0, x.shape[3], overlap_size):\n                tile = x[\n                    :,\n                    :,\n                    i : i + self.tile_sample_min_size,\n                    j : j + self.tile_sample_min_size,\n                ]\n                tile = self.encoder(tile)\n                if self.config.use_quant_conv:\n                    tile = self.quant_conv(tile)\n                row.append(tile)\n            rows.append(row)\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_extent)\n                result_row.append(tile[:, :, :row_limit, :row_limit])\n            result_rows.append(torch.cat(result_row, dim=3))\n\n        moments = torch.cat(result_rows, dim=2)\n        posterior = DiagonalGaussianDistribution(moments)\n\n        if not return_dict:\n            return (posterior,)\n\n        return AutoencoderKLOutput(latent_dist=posterior)\n\n    def tiled_decode(\n        self, z: torch.Tensor, return_dict: bool = 
True\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        r\"\"\"\n        Decode a batch of images using a tiled decoder.\n\n        Args:\n            z (`torch.Tensor`): Input batch of latent vectors.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.vae.DecoderOutput`] or `tuple`:\n                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is\n                returned.\n        \"\"\"\n        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))\n        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)\n        row_limit = self.tile_sample_min_size - blend_extent\n\n        # Split z into overlapping 64x64 tiles and decode them separately.\n        # The tiles have an overlap to avoid seams between tiles.\n        rows = []\n        for i in range(0, z.shape[2], overlap_size):\n            row = []\n            for j in range(0, z.shape[3], overlap_size):\n                tile = z[\n                    :,\n                    :,\n                    i : i + self.tile_latent_min_size,\n                    j : j + self.tile_latent_min_size,\n                ]\n                if self.config.use_post_quant_conv:\n                    tile = self.post_quant_conv(tile)\n                decoded = self.decoder(tile)\n                row.append(decoded)\n            rows.append(row)\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)\n                if j > 0:\n                    tile 
= self.blend_h(row[j - 1], tile, blend_extent)\n                result_row.append(tile[:, :, :row_limit, :row_limit])\n            result_rows.append(torch.cat(result_row, dim=3))\n\n        dec = torch.cat(result_rows, dim=2)\n        if not return_dict:\n            return (dec,)\n\n        return DecoderOutput(sample=dec)\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        sample_posterior: bool = False,\n        generator: Optional[torch.Generator] = None,\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        r\"\"\"\n        Args:\n            sample (`torch.Tensor`): Input sample.\n            sample_posterior (`bool`, *optional*, defaults to `False`):\n                Whether to sample from the posterior.\n        \"\"\"\n        x = sample\n        posterior = self.encode(x).latent_dist\n        if sample_posterior:\n            z = posterior.sample(generator=generator)\n        else:\n            z = posterior.mode()\n        dec = self.decode(z).sample\n\n        return dec\n\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections\n    def fuse_qkv_projections(self):\n        \"\"\"\n        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)\n        are fused. 
For cross-attention modules, key and value projection matrices are fused.\n\n        > [!WARNING] > This API is 🧪 experimental.\n        \"\"\"\n        self.original_attn_processors = None\n\n        for _, attn_processor in self.attn_processors.items():\n            if \"Added\" in str(attn_processor.__class__.__name__):\n                raise ValueError(\n                    \"`fuse_qkv_projections()` is not supported for models having added KV projections.\"\n                )\n\n        self.original_attn_processors = self.attn_processors\n\n        for module in self.modules():\n            if isinstance(module, Attention):\n                module.fuse_projections(fuse=True)\n\n        self.set_attn_processor(FusedAttnProcessor2_0())\n\n\nEntryClass = AutoencoderKL\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_dc.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nfrom collections.abc import Iterable\n\nimport torch\nfrom torch import nn\n\nfrom sglang.multimodal_gen.configs.models.vaes.sana import SanaVAEConfig\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass AutoencoderDC(nn.Module):\n    \"\"\"Deep Compression Autoencoder wrapper with 32x spatial compression.\"\"\"\n\n    def __init__(self, config: SanaVAEConfig = None, **kwargs):\n        super().__init__()\n        self._config = config\n        self._inner_model = None\n        self._loaded_state_dict: dict[str, torch.Tensor] = {}\n\n    def _ensure_inner_model(self, state_dict: dict[str, torch.Tensor] | None = None):\n        if self._inner_model is not None:\n            return\n\n        from diffusers import AutoencoderDC as DiffusersAutoencoderDC\n\n        device = \"cpu\"\n        state_to_load = (\n            state_dict if state_dict is not None else self._loaded_state_dict\n        )\n        if state_to_load:\n            first_tensor = next(iter(state_to_load.values()))\n            device = first_tensor.device\n        hf_config = {}\n        if self._config is not None:\n            arch = self._config.arch_config\n            for key, value in vars(arch).items():\n                if key == \"extra_attrs\" and isinstance(value, dict):\n                    for ek, ev in value.items():\n                        hf_config[ek] = ev\n                elif not key.startswith(\"_\") and not callable(value):\n                    hf_config[key] = value\n\n        self._inner_model = DiffusersAutoencoderDC.from_config(hf_config)\n\n        if state_to_load:\n            missing, unexpected = self._inner_model.load_state_dict(\n                state_to_load, strict=False\n            )\n            if missing:\n                logger.warning(\n                    \"AutoencoderDC missing keys when loading: %d keys\", len(missing)\n             
   )\n                if len(missing) > 10:\n                    logger.debug(\"First 10 missing keys: %s\", list(missing)[:10])\n                else:\n                    logger.debug(\"Missing keys: %s\", list(missing))\n            if unexpected:\n                logger.debug(\n                    \"AutoencoderDC unexpected keys when loading: %d keys\",\n                    len(unexpected),\n                )\n            if state_dict is None:\n                self._loaded_state_dict.clear()\n\n        self._inner_model = self._inner_model.to(device)\n\n    @property\n    def config(self):\n        if self._inner_model is not None:\n            return self._inner_model.config\n        return self._config\n\n    @property\n    def dtype(self):\n        if self._inner_model is not None:\n            return next(self._inner_model.parameters()).dtype\n        return torch.float32\n\n    @property\n    def device(self):\n        if self._inner_model is not None:\n            return next(self._inner_model.parameters()).device\n        return torch.device(\"cpu\")\n\n    def encode(self, x: torch.Tensor, **kwargs):\n        self._ensure_inner_model()\n        return self._inner_model.encode(x, **kwargs)\n\n    def decode(self, z: torch.Tensor, **kwargs):\n        self._ensure_inner_model()\n        z = z.to(dtype=self.dtype)\n        return self._inner_model.decode(z, **kwargs)\n\n    def forward(self, x: torch.Tensor, **kwargs):\n        self._ensure_inner_model()\n        return self._inner_model(x, **kwargs)\n\n    def load_state_dict(\n        self,\n        state_dict: dict[str, torch.Tensor],\n        strict: bool = True,\n        assign: bool = False,\n    ):\n        \"\"\"Intercept load_state_dict to route weights into the inner diffusers model.\"\"\"\n        self._ensure_inner_model(state_dict=state_dict)\n\n    def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:\n        self._ensure_inner_model()\n        return 
self._inner_model.state_dict(*args, **kwargs)\n\n    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:\n        \"\"\"Buffer weights for deferred loading. The inner model is built lazily.\"\"\"\n        loaded_params: set[str] = set()\n        for name, weight in weights:\n            self._loaded_state_dict[name] = weight\n            loaded_params.add(name)\n        return loaded_params\n\n    def to(self, *args, **kwargs):\n        if self._inner_model is not None:\n            self._inner_model = self._inner_model.to(*args, **kwargs)\n        return super().to(*args, **kwargs)\n\n\nEntryClass = AutoencoderDC\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_flux2.py",
    "content": "import math\nfrom typing import Dict, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nfrom diffusers.models.attention_processor import (\n    ADDED_KV_ATTENTION_PROCESSORS,\n    CROSS_ATTENTION_PROCESSORS,\n    AttentionProcessor,\n    AttnAddedKVProcessor,\n    AttnProcessor,\n)\nfrom diffusers.models.autoencoders.vae import (\n    Decoder,\n    DecoderOutput,\n    DiagonalGaussianDistribution,\n    Encoder,\n)\nfrom diffusers.models.modeling_outputs import AutoencoderKLOutput\n\nfrom sglang.multimodal_gen.configs.models.vaes.flux import Flux2VAEConfig\nfrom sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE\n\n\nclass AutoencoderKLFlux2(ParallelTiledVAE):\n    r\"\"\"\n    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.\n\n    This model inherits from [`ParallelTiledVAE`]; see the superclass documentation for the methods it provides.\n\n    Parameters:\n        config (`Flux2VAEConfig`):\n            Model configuration. Architecture hyperparameters are read from `config.arch_config`.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n    _no_split_modules = [\"BasicTransformerBlock\", \"ResnetBlock2D\"]\n\n    def __init__(\n        self,\n        config: Flux2VAEConfig,\n    ):\n        super().__init__(config=config)\n\n        self.config = config\n        arch_config = config.arch_config\n\n        in_channels: int = arch_config.in_channels\n        out_channels: int = arch_config.out_channels\n        down_block_types: Tuple[str, ...] = arch_config.down_block_types\n        up_block_types: Tuple[str, ...] = arch_config.up_block_types\n        block_out_channels: Tuple[int, ...]
= arch_config.block_out_channels\n        layers_per_block: int = arch_config.layers_per_block\n        act_fn: str = arch_config.act_fn\n        latent_channels: int = arch_config.latent_channels\n        norm_num_groups: int = arch_config.norm_num_groups\n        sample_size: int = arch_config.sample_size\n        force_upcast: bool = arch_config.force_upcast\n        use_quant_conv: bool = arch_config.use_quant_conv\n        use_post_quant_conv: bool = arch_config.use_post_quant_conv\n        mid_block_add_attention: bool = arch_config.mid_block_add_attention\n        batch_norm_eps: float = arch_config.batch_norm_eps\n        batch_norm_momentum: float = arch_config.batch_norm_momentum\n        patch_size: Tuple[int, int] = arch_config.patch_size\n        # pass init params to Encoder\n        self.encoder = Encoder(\n            in_channels=in_channels,\n            out_channels=latent_channels,\n            down_block_types=down_block_types,\n            block_out_channels=block_out_channels,\n            layers_per_block=layers_per_block,\n            act_fn=act_fn,\n            norm_num_groups=norm_num_groups,\n            double_z=True,\n            mid_block_add_attention=mid_block_add_attention,\n        )\n\n        # pass init params to Decoder\n        self.decoder = Decoder(\n            in_channels=latent_channels,\n            out_channels=out_channels,\n            up_block_types=up_block_types,\n            block_out_channels=block_out_channels,\n            layers_per_block=layers_per_block,\n            norm_num_groups=norm_num_groups,\n            act_fn=act_fn,\n            mid_block_add_attention=mid_block_add_attention,\n        )\n\n        self.quant_conv = (\n            nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)\n            if use_quant_conv\n            else None\n        )\n        self.post_quant_conv = (\n            nn.Conv2d(latent_channels, latent_channels, 1)\n            if use_post_quant_conv\n            else 
None\n        )\n\n        self.bn = nn.BatchNorm2d(\n            math.prod(patch_size) * latent_channels,\n            eps=batch_norm_eps,\n            momentum=batch_norm_momentum,\n            affine=False,\n            track_running_stats=True,\n        )\n\n        self.use_slicing = False\n        self.use_tiling = False\n\n        # only relevant if vae tiling is enabled\n        self.tile_sample_min_size = self.config.sample_size\n        sample_size = (\n            self.config.sample_size[0]\n            if isinstance(self.config.sample_size, (list, tuple))\n            else self.config.sample_size\n        )\n        self.tile_latent_min_size = int(\n            sample_size / (2 ** (len(self.config.block_out_channels) - 1))\n        )\n        self.tile_overlap_factor = 0.25\n\n    @property\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model,\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(\n            name: str,\n            module: torch.nn.Module,\n            processors: Dict[str, AttentionProcessor],\n        ):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = module.get_processor()\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor\n    def
set_attn_processor(\n        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]\n    ):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. 
Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        if all(\n            proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS\n            for proc in self.attn_processors.values()\n        ):\n            processor = AttnAddedKVProcessor()\n        elif all(\n            proc.__class__ in CROSS_ATTENTION_PROCESSORS\n            for proc in self.attn_processors.values()\n        ):\n            processor = AttnProcessor()\n        else:\n            raise ValueError(\n                f\"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}\"\n            )\n\n        self.set_attn_processor(processor)\n\n    def _encode(self, x: torch.Tensor) -> torch.Tensor:\n        batch_size, num_channels, height, width = x.shape\n\n        if self.use_tiling and (\n            width > self.tile_sample_min_size or height > self.tile_sample_min_size\n        ):\n            return self._tiled_encode(x)\n\n        enc = self.encoder(x)\n        if self.quant_conv is not None:\n            enc = 
self.quant_conv(enc)\n\n        return enc\n\n    def encode(\n        self, x: torch.Tensor, return_dict: bool = True\n    ) -> DiagonalGaussianDistribution:\n        \"\"\"\n        Encode a batch of images into latents.\n\n        Args:\n            x (`torch.Tensor`): Input batch of images, shaped `(B, C, H, W)` or `(B, C, 1, H, W)`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Unused; accepted for API compatibility.\n\n        Returns:\n            [`DiagonalGaussianDistribution`]: The latent distribution of the encoded images.\n        \"\"\"\n\n        if x.ndim == 5:\n            # Accept a single-frame video tensor by dropping the frame dimension.\n            assert x.shape[2] == 1\n            x = x.squeeze(2)\n\n        if self.use_slicing and x.shape[0] > 1:\n            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]\n            h = torch.cat(encoded_slices)\n        else:\n            h = self._encode(x)\n\n        posterior = DiagonalGaussianDistribution(h)\n        return posterior\n\n    def _decode(\n        self, z: torch.Tensor, return_dict: bool = True\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        if self.use_tiling and (\n            z.shape[-1] > self.tile_latent_min_size\n            or z.shape[-2] > self.tile_latent_min_size\n        ):\n            return self.tiled_decode(z, return_dict=return_dict)\n\n        if self.post_quant_conv is not None:\n            z = self.post_quant_conv(z)\n\n        dec = self.decoder(z)\n\n        if not return_dict:\n            return (dec,)\n\n        return DecoderOutput(sample=dec)\n\n    def decode(\n        self, z: torch.FloatTensor, return_dict: bool = True, generator=None\n    ) -> Union[DecoderOutput, torch.FloatTensor]:\n        \"\"\"\n        Decode a batch of images.\n\n        Args:\n            z (`torch.Tensor`): Input batch of latent vectors.\n            return_dict
(`bool`, *optional*, defaults to `True`):\n                Unused; accepted for API compatibility.\n\n        Returns:\n            `torch.Tensor`: The decoded images.\n\n        \"\"\"\n        if self.use_slicing and z.shape[0] > 1:\n            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]\n            decoded = torch.cat(decoded_slices)\n        else:\n            decoded = self._decode(z).sample\n\n        return decoded\n\n    def blend_v(\n        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int\n    ) -> torch.Tensor:\n        blend_extent = min(a.shape[2], b.shape[2], blend_extent)\n        for y in range(blend_extent):\n            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[\n                :, :, y, :\n            ] * (y / blend_extent)\n        return b\n\n    def blend_h(\n        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int\n    ) -> torch.Tensor:\n        blend_extent = min(a.shape[3], b.shape[3], blend_extent)\n        for x in range(blend_extent):\n            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[\n                :, :, :, x\n            ] * (x / blend_extent)\n        return b\n\n    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:\n        r\"\"\"Encode a batch of images using a tiled encoder.\n\n        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several\n        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is\n        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the\n        tiles overlap and are blended together to form a smooth output.
You may still see tile-sized changes in the\n        output, but they should be much less noticeable.\n\n        Args:\n            x (`torch.Tensor`): Input batch of images.\n\n        Returns:\n            `torch.Tensor`:\n                The latent representation of the encoded images.\n        \"\"\"\n\n        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))\n        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)\n        row_limit = self.tile_latent_min_size - blend_extent\n\n        # Split the image into 512x512 tiles and encode them separately.\n        rows = []\n        for i in range(0, x.shape[2], overlap_size):\n            row = []\n            for j in range(0, x.shape[3], overlap_size):\n                tile = x[\n                    :,\n                    :,\n                    i : i + self.tile_sample_min_size,\n                    j : j + self.tile_sample_min_size,\n                ]\n                tile = self.encoder(tile)\n                if self.config.use_quant_conv:\n                    tile = self.quant_conv(tile)\n                row.append(tile)\n            rows.append(row)\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_extent)\n                result_row.append(tile[:, :, :row_limit, :row_limit])\n            result_rows.append(torch.cat(result_row, dim=3))\n\n        enc = torch.cat(result_rows, dim=2)\n        return enc\n\n    def tiled_encode(\n        self, x: torch.Tensor, return_dict: bool = True\n    ) -> AutoencoderKLOutput:\n        r\"\"\"Encode
a batch of images using a tiled encoder.\n\n        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several\n        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is\n        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the\n        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the\n        output, but they should be much less noticeable.\n\n        Args:\n            x (`torch.Tensor`): Input batch of images.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:\n                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain\n                `tuple` is returned.\n        \"\"\"\n        deprecation_message = (\n            \"The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the \"\n            \"implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able \"\n            \"to pass `return_dict`. 
You will also have to create a `DiagonalGaussianDistribution()` from the returned value.\"\n        )\n\n        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))\n        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)\n        row_limit = self.tile_latent_min_size - blend_extent\n\n        # Split the image into 512x512 tiles and encode them separately.\n        rows = []\n        for i in range(0, x.shape[2], overlap_size):\n            row = []\n            for j in range(0, x.shape[3], overlap_size):\n                tile = x[\n                    :,\n                    :,\n                    i : i + self.tile_sample_min_size,\n                    j : j + self.tile_sample_min_size,\n                ]\n                tile = self.encoder(tile)\n                if self.config.use_quant_conv:\n                    tile = self.quant_conv(tile)\n                row.append(tile)\n            rows.append(row)\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_extent)\n                result_row.append(tile[:, :, :row_limit, :row_limit])\n            result_rows.append(torch.cat(result_row, dim=3))\n\n        moments = torch.cat(result_rows, dim=2)\n        posterior = DiagonalGaussianDistribution(moments)\n\n        if not return_dict:\n            return (posterior,)\n\n        return AutoencoderKLOutput(latent_dist=posterior)\n\n    def tiled_decode(\n        self, z: torch.Tensor, return_dict: bool = True\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        r\"\"\"\n        Decode a batch of 
images using a tiled decoder.\n\n        Args:\n            z (`torch.Tensor`): Input batch of latent vectors.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.vae.DecoderOutput`] or `tuple`:\n                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is\n                returned.\n        \"\"\"\n        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))\n        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)\n        row_limit = self.tile_sample_min_size - blend_extent\n\n        # Split z into overlapping 64x64 tiles and decode them separately.\n        # The tiles have an overlap to avoid seams between tiles.\n        rows = []\n        for i in range(0, z.shape[2], overlap_size):\n            row = []\n            for j in range(0, z.shape[3], overlap_size):\n                tile = z[\n                    :,\n                    :,\n                    i : i + self.tile_latent_min_size,\n                    j : j + self.tile_latent_min_size,\n                ]\n                if self.config.use_post_quant_conv:\n                    tile = self.post_quant_conv(tile)\n                decoded = self.decoder(tile)\n                row.append(decoded)\n            rows.append(row)\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_extent)\n                result_row.append(tile[:, :, 
:row_limit, :row_limit])\n            result_rows.append(torch.cat(result_row, dim=3))\n\n        dec = torch.cat(result_rows, dim=2)\n        if not return_dict:\n            return (dec,)\n\n        return DecoderOutput(sample=dec)\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        sample_posterior: bool = False,\n        return_dict: bool = True,\n        generator: Optional[torch.Generator] = None,\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        r\"\"\"\n        Args:\n            sample (`torch.Tensor`): Input sample.\n            sample_posterior (`bool`, *optional*, defaults to `False`):\n                Whether to sample from the posterior.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.\n        \"\"\"\n        x = sample\n        posterior = self.encode(x).latent_dist\n        if sample_posterior:\n            z = posterior.sample(generator=generator)\n        else:\n            z = posterior.mode()\n        dec = self.decode(z).sample\n\n        if not return_dict:\n            return (dec,)\n\n        return DecoderOutput(sample=dec)\n\n\nEntryClass = AutoencoderKLFlux2\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom diffusers.models.activations import get_activation\nfrom diffusers.models.autoencoders.vae import (\n    DecoderOutput,\n    DiagonalGaussianDistribution,\n)\nfrom diffusers.models.modeling_outputs import AutoencoderKLOutput\n\nfrom sglang.multimodal_gen.configs.models.vaes.qwenimage import QwenImageVAEConfig\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)  # pylint: disable=invalid-name\n\nCACHE_T = 2\n\n\nclass QwenImageCausalConv3d(nn.Conv3d):\n    r\"\"\"\n    A custom 3D causal convolution layer with feature caching support.\n\n    This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature\n    caching for efficient inference.\n\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int or tuple): Size of the convolving kernel\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        padding (int or tuple, optional): Zero-padding added to all three sides of the input. 
Default: 0\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, Tuple[int, int, int]],\n        stride: Union[int, Tuple[int, int, int]] = 1,\n        padding: Union[int, Tuple[int, int, int]] = 0,\n    ) -> None:\n        super().__init__(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n        )\n\n        # Set up causal padding\n        self._padding = (\n            self.padding[2],\n            self.padding[2],\n            self.padding[1],\n            self.padding[1],\n            2 * self.padding[0],\n            0,\n        )\n        self.padding = (0, 0, 0)\n\n    def forward(self, x, cache_x=None):\n        padding = list(self._padding)\n        if cache_x is not None and self._padding[4] > 0:\n            cache_x = cache_x.to(x.device)\n            x = torch.cat([cache_x, x], dim=2)\n            padding[4] -= cache_x.shape[2]\n        x = F.pad(x, padding)\n        return super().forward(x)\n\n\nclass QwenImageRMS_norm(nn.Module):\n    r\"\"\"\n    A custom RMS normalization layer.\n\n    Args:\n        dim (int): The number of dimensions to normalize over.\n        channel_first (bool, optional): Whether the input tensor has channels as the first dimension.\n            Default is True.\n        images (bool, optional): Whether the input represents image data. Default is True.\n        bias (bool, optional): Whether to include a learnable bias term. 
Default is False.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        channel_first: bool = True,\n        images: bool = True,\n        bias: bool = False,\n    ) -> None:\n        super().__init__()\n        broadcastable_dims = (1, 1, 1) if not images else (1, 1)\n        shape = (dim, *broadcastable_dims) if channel_first else (dim,)\n\n        self.channel_first = channel_first\n        self.scale = dim**0.5\n        self.gamma = nn.Parameter(torch.ones(shape))\n        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0\n\n    def forward(self, x):\n        return (\n            F.normalize(x, dim=(1 if self.channel_first else -1))\n            * self.scale\n            * self.gamma\n            + self.bias\n        )\n\n\nclass QwenImageUpsample(nn.Upsample):\n    r\"\"\"\n    Perform upsampling while ensuring the output tensor has the same data type as the input.\n\n    Returns:\n        torch.Tensor: Upsampled tensor with the same data type as the input.\n    \"\"\"\n\n    def forward(self, x):\n        return super().forward(x.float()).type_as(x)\n\n\nclass QwenImageResample(nn.Module):\n    r\"\"\"\n    A custom resampling module for 2D and 3D data.\n\n    Args:\n        dim (int): The number of input/output channels.\n        mode (str): The resampling mode. 
Must be one of:\n            - 'none': No resampling (identity operation).\n            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.\n            - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.\n            - 'downsample2d': 2D downsampling with zero-padding and convolution.\n            - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.\n    \"\"\"\n\n    def __init__(self, dim: int, mode: str) -> None:\n        super().__init__()\n        self.dim = dim\n        self.mode = mode\n\n        # layers\n        if mode == \"upsample2d\":\n            self.resample = nn.Sequential(\n                QwenImageUpsample(scale_factor=(2.0, 2.0), mode=\"nearest-exact\"),\n                nn.Conv2d(dim, dim // 2, 3, padding=1),\n            )\n        elif mode == \"upsample3d\":\n            self.resample = nn.Sequential(\n                QwenImageUpsample(scale_factor=(2.0, 2.0), mode=\"nearest-exact\"),\n                nn.Conv2d(dim, dim // 2, 3, padding=1),\n            )\n            self.time_conv = QwenImageCausalConv3d(\n                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)\n            )\n\n        elif mode == \"downsample2d\":\n            self.resample = nn.Sequential(\n                nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))\n            )\n        elif mode == \"downsample3d\":\n            self.resample = nn.Sequential(\n                nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))\n            )\n            self.time_conv = QwenImageCausalConv3d(\n                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)\n            )\n\n        else:\n            self.resample = nn.Identity()\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        b, c, t, h, w = x.size()\n        if self.mode == \"upsample3d\":\n            if feat_cache is not None:\n             
   idx = feat_idx[0]\n                if feat_cache[idx] is None:\n                    feat_cache[idx] = \"Rep\"\n                    feat_idx[0] += 1\n                else:\n                    cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                    if (\n                        cache_x.shape[2] < 2\n                        and feat_cache[idx] is not None\n                        and feat_cache[idx] != \"Rep\"\n                    ):\n                        # cache last frame of last two chunk\n                        cache_x = torch.cat(\n                            [\n                                feat_cache[idx][:, :, -1, :, :]\n                                .unsqueeze(2)\n                                .to(cache_x.device),\n                                cache_x,\n                            ],\n                            dim=2,\n                        )\n                    if (\n                        cache_x.shape[2] < 2\n                        and feat_cache[idx] is not None\n                        and feat_cache[idx] == \"Rep\"\n                    ):\n                        cache_x = torch.cat(\n                            [torch.zeros_like(cache_x).to(cache_x.device), cache_x],\n                            dim=2,\n                        )\n                    if feat_cache[idx] == \"Rep\":\n                        x = self.time_conv(x)\n                    else:\n                        x = self.time_conv(x, feat_cache[idx])\n                    feat_cache[idx] = cache_x\n                    feat_idx[0] += 1\n\n                    x = x.reshape(b, 2, c, t, h, w)\n                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)\n                    x = x.reshape(b, c, t * 2, h, w)\n        t = x.shape[2]\n        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)\n        x = self.resample(x)\n        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)\n\n        if self.mode == 
\"downsample3d\":\n            if feat_cache is not None:\n                idx = feat_idx[0]\n                if feat_cache[idx] is None:\n                    feat_cache[idx] = x.clone()\n                    feat_idx[0] += 1\n                else:\n                    cache_x = x[:, :, -1:, :, :].clone()\n                    x = self.time_conv(\n                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)\n                    )\n                    feat_cache[idx] = cache_x\n                    feat_idx[0] += 1\n        return x\n\n\nclass QwenImageResidualBlock(nn.Module):\n    r\"\"\"\n    A custom residual block module.\n\n    Args:\n        in_dim (int): Number of input channels.\n        out_dim (int): Number of output channels.\n        dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.\n        non_linearity (str, optional): Type of non-linearity to use. Default is \"silu\".\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        out_dim: int,\n        dropout: float = 0.0,\n        non_linearity: str = \"silu\",\n    ) -> None:\n        super().__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.nonlinearity = get_activation(non_linearity)\n\n        # layers\n        self.norm1 = QwenImageRMS_norm(in_dim, images=False)\n        self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)\n        self.norm2 = QwenImageRMS_norm(out_dim, images=False)\n        self.dropout = nn.Dropout(dropout)\n        self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)\n        self.conv_shortcut = (\n            QwenImageCausalConv3d(in_dim, out_dim, 1)\n            if in_dim != out_dim\n            else nn.Identity()\n        )\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        # Apply shortcut connection\n        h = self.conv_shortcut(x)\n\n        # First normalization and activation\n        x = self.norm1(x)\n        x = 
self.nonlinearity(x)\n\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                cache_x = torch.cat(\n                    [\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n                )\n\n            x = self.conv1(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv1(x)\n\n        # Second normalization and activation\n        x = self.norm2(x)\n        x = self.nonlinearity(x)\n\n        # Dropout\n        x = self.dropout(x)\n\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                cache_x = torch.cat(\n                    [\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n                )\n\n            x = self.conv2(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv2(x)\n\n        # Add residual connection\n        return x + h\n\n\nclass QwenImageAttentionBlock(nn.Module):\n    r\"\"\"\n    Causal self-attention with a single head.\n\n    Args:\n        dim (int): The number of channels in the input tensor.\n    \"\"\"\n\n    def __init__(self, dim):\n        super().__init__()\n        self.dim = dim\n\n        # layers\n        self.norm = QwenImageRMS_norm(dim)\n        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)\n        self.proj = nn.Conv2d(dim, dim, 1)\n\n    def forward(self, x):\n        identity = x\n        batch_size, channels, time, 
height, width = x.size()\n\n        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)\n        x = self.norm(x)\n\n        # compute query, key, value\n        qkv = self.to_qkv(x)\n        qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)\n        qkv = qkv.permute(0, 1, 3, 2).contiguous()\n        q, k, v = qkv.chunk(3, dim=-1)\n\n        # apply attention\n        x = F.scaled_dot_product_attention(q, k, v)\n\n        x = (\n            x.squeeze(1)\n            .permute(0, 2, 1)\n            .reshape(batch_size * time, channels, height, width)\n        )\n\n        # output projection\n        x = self.proj(x)\n\n        # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]\n        x = x.view(batch_size, time, channels, height, width)\n        x = x.permute(0, 2, 1, 3, 4)\n\n        return x + identity\n\n\nclass QwenImageMidBlock(nn.Module):\n    \"\"\"\n    Middle block for QwenImageVAE encoder and decoder.\n\n    Args:\n        dim (int): Number of input/output channels.\n        dropout (float): Dropout rate.\n        non_linearity (str): Type of non-linearity to use.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        dropout: float = 0.0,\n        non_linearity: str = \"silu\",\n        num_layers: int = 1,\n    ):\n        super().__init__()\n        self.dim = dim\n\n        # Create the components\n        resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]\n        attentions = []\n        for _ in range(num_layers):\n            attentions.append(QwenImageAttentionBlock(dim))\n            resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        self.gradient_checkpointing = False\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        # First residual block\n        x = self.resnets[0](x, feat_cache, feat_idx)\n\n        # Process 
through attention and residual blocks\n        for attn, resnet in zip(self.attentions, self.resnets[1:]):\n            if attn is not None:\n                x = attn(x)\n\n            x = resnet(x, feat_cache, feat_idx)\n\n        return x\n\n\nclass QwenImageEncoder3d(nn.Module):\n    r\"\"\"\n    A 3D encoder module.\n\n    Args:\n        dim (int): The base number of channels in the first layer.\n        z_dim (int): The dimensionality of the latent space.\n        dim_mult (list of int): Multipliers for the number of channels in each block.\n        num_res_blocks (int): Number of residual blocks in each block.\n        attn_scales (list of float): Scales at which to apply attention mechanisms.\n        temperal_downsample (list of bool): Whether to downsample temporally in each block.\n        dropout (float): Dropout rate for the dropout layers.\n        non_linearity (str): Type of non-linearity to use.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim=128,\n        z_dim=4,\n        dim_mult=[1, 2, 4, 4],\n        num_res_blocks=2,\n        attn_scales=[],\n        temperal_downsample=[True, True, False],\n        dropout=0.0,\n        non_linearity: str = \"silu\",\n        input_channels: int = 3,\n    ):\n        super().__init__()\n        # dim = config.arch_config.dim\n        # z_dim = config.arch_config.z_dim\n        # dim_mult = config.arch_config.dim_mult\n        # num_res_blocks = config.arch_config.num_res_blocks\n        # attn_scales = config.arch_config.attn_scales\n        # temperal_downsample = config.arch_config.temperal_downsample\n        # dropout = config.arch_config.dropout\n        # non_linearity = config.arch_config.non_linearity\n        self.dim = dim\n        self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_downsample = temperal_downsample\n        self.nonlinearity = get_activation(non_linearity)\n\n   
     # dimensions\n        dims = [dim * u for u in [1] + dim_mult]\n        scale = 1.0\n\n        # init block\n        self.conv_in = QwenImageCausalConv3d(input_channels, dims[0], 3, padding=1)\n\n        # downsample blocks\n        self.down_blocks = nn.ModuleList([])\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):\n            # residual (+attention) blocks\n            for _ in range(num_res_blocks):\n                self.down_blocks.append(\n                    QwenImageResidualBlock(in_dim, out_dim, dropout)\n                )\n                if scale in attn_scales:\n                    self.down_blocks.append(QwenImageAttentionBlock(out_dim))\n                in_dim = out_dim\n\n            # downsample block\n            if i != len(dim_mult) - 1:\n                mode = \"downsample3d\" if temperal_downsample[i] else \"downsample2d\"\n                self.down_blocks.append(QwenImageResample(out_dim, mode=mode))\n                scale /= 2.0\n\n        # middle blocks\n        self.mid_block = QwenImageMidBlock(\n            out_dim, dropout, non_linearity, num_layers=1\n        )\n\n        # output blocks\n        self.norm_out = QwenImageRMS_norm(out_dim, images=False)\n        self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)\n\n        self.gradient_checkpointing = False\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat(\n                    [\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n                )\n            x = self.conv_in(x, feat_cache[idx])\n            feat_cache[idx] = 
cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv_in(x)\n\n        ## downsamples\n        for layer in self.down_blocks:\n            if feat_cache is not None:\n                x = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## middle\n        x = self.mid_block(x, feat_cache, feat_idx)\n\n        ## head\n        x = self.norm_out(x)\n        x = self.nonlinearity(x)\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat(\n                    [\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n                )\n            x = self.conv_out(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv_out(x)\n        return x\n\n\nclass QwenImageUpBlock(nn.Module):\n    \"\"\"\n    A block that handles upsampling for the QwenImageVAE decoder.\n\n    Args:\n        in_dim (int): Input dimension\n        out_dim (int): Output dimension\n        num_res_blocks (int): Number of residual blocks\n        dropout (float): Dropout rate\n        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')\n        non_linearity (str): Type of non-linearity to use\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        out_dim: int,\n        num_res_blocks: int,\n        dropout: float = 0.0,\n        upsample_mode: Optional[str] = None,\n        non_linearity: str = \"silu\",\n    ):\n        super().__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n\n        # Create layers list\n        resnets = []\n        
# Add residual blocks and attention if needed\n        current_dim = in_dim\n        for _ in range(num_res_blocks + 1):\n            resnets.append(\n                QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)\n            )\n            current_dim = out_dim\n\n        self.resnets = nn.ModuleList(resnets)\n\n        # Add upsampling layer if needed\n        self.upsamplers = None\n        if upsample_mode is not None:\n            self.upsamplers = nn.ModuleList(\n                [QwenImageResample(out_dim, mode=upsample_mode)]\n            )\n\n        self.gradient_checkpointing = False\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        \"\"\"\n        Forward pass through the upsampling block.\n\n        Args:\n            x (torch.Tensor): Input tensor\n            feat_cache (list, optional): Feature cache for causal convolutions\n            feat_idx (list, optional): Feature index for cache management\n\n        Returns:\n            torch.Tensor: Output tensor\n        \"\"\"\n        for resnet in self.resnets:\n            if feat_cache is not None:\n                x = resnet(x, feat_cache, feat_idx)\n            else:\n                x = resnet(x)\n\n        if self.upsamplers is not None:\n            if feat_cache is not None:\n                x = self.upsamplers[0](x, feat_cache, feat_idx)\n            else:\n                x = self.upsamplers[0](x)\n        return x\n\n\nclass QwenImageDecoder3d(nn.Module):\n    r\"\"\"\n    A 3D decoder module.\n\n    Args:\n        dim (int): The base number of channels in the first layer.\n        z_dim (int): The dimensionality of the latent space.\n        dim_mult (list of int): Multipliers for the number of channels in each block.\n        num_res_blocks (int): Number of residual blocks in each block.\n        attn_scales (list of float): Scales at which to apply attention mechanisms.\n        temperal_upsample (list of bool): Whether to upsample temporally in 
each block.\n        dropout (float): Dropout rate for the dropout layers.\n        non_linearity (str): Type of non-linearity to use.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim=128,\n        z_dim=4,\n        dim_mult=[1, 2, 4, 4],\n        num_res_blocks=2,\n        attn_scales=[],\n        temperal_upsample=[False, True, True],\n        dropout=0.0,\n        non_linearity: str = \"silu\",\n        input_channels=3,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_upsample = temperal_upsample\n\n        self.nonlinearity = get_activation(non_linearity)\n\n        # dimensions\n        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]\n        scale = 1.0 / 2 ** (len(dim_mult) - 2)\n\n        # init block\n        self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)\n\n        # middle blocks\n        self.mid_block = QwenImageMidBlock(\n            dims[0], dropout, non_linearity, num_layers=1\n        )\n\n        # upsample blocks\n        self.up_blocks = nn.ModuleList([])\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):\n            # residual (+attention) blocks\n            if i > 0:\n                in_dim = in_dim // 2\n\n            # Determine if we need upsampling\n            upsample_mode = None\n            if i != len(dim_mult) - 1:\n                upsample_mode = \"upsample3d\" if temperal_upsample[i] else \"upsample2d\"\n\n            # Create and add the upsampling block\n            up_block = QwenImageUpBlock(\n                in_dim=in_dim,\n                out_dim=out_dim,\n                num_res_blocks=num_res_blocks,\n                dropout=dropout,\n                upsample_mode=upsample_mode,\n                non_linearity=non_linearity,\n            )\n            
self.up_blocks.append(up_block)\n\n            # Update scale for next iteration\n            if upsample_mode is not None:\n                scale *= 2.0\n\n        # output blocks\n        self.norm_out = QwenImageRMS_norm(out_dim, images=False)\n        self.conv_out = QwenImageCausalConv3d(out_dim, input_channels, 3, padding=1)\n\n        self.gradient_checkpointing = False\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        ## conv1\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat(\n                    [\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n                )\n            x = self.conv_in(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv_in(x)\n\n        ## middle\n        x = self.mid_block(x, feat_cache, feat_idx)\n\n        ## upsamples\n        for up_block in self.up_blocks:\n            x = up_block(x, feat_cache, feat_idx)\n\n        ## head\n        x = self.norm_out(x)\n        x = self.nonlinearity(x)\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat(\n                    [\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n                )\n            x = self.conv_out(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n     
       feat_idx[0] += 1\n        else:\n            x = self.conv_out(x)\n        return x\n\n\nclass AutoencoderKLQwenImage(ParallelTiledVAE):\n    r\"\"\"\n    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.\n\n    This model inherits from [`ParallelTiledVAE`]. Check the superclass documentation for the generic tiling and\n    parallel tiled decoding methods it implements for all VAEs.\n    \"\"\"\n\n    _supports_gradient_checkpointing = False\n\n    # fmt: off\n    def __init__(\n        self,\n        config: QwenImageVAEConfig,\n    ) -> None:\n        # fmt: on\n        super().__init__(config=config)\n        base_dim = config.arch_config.base_dim\n        z_dim = config.arch_config.z_dim\n        dim_mult = config.arch_config.dim_mult\n        num_res_blocks = config.arch_config.num_res_blocks\n        attn_scales = config.arch_config.attn_scales\n        temperal_downsample = config.arch_config.temperal_downsample\n        dropout = config.arch_config.dropout\n        # non_linearity = config.arch_config.non_linearity\n        self.z_dim = z_dim\n        self.temperal_downsample = temperal_downsample\n        self.temperal_upsample = temperal_downsample[::-1]\n        self.input_channels = config.arch_config.input_channels\n        self.latents_mean = config.arch_config.latents_mean\n        self.config = config.arch_config\n\n        self.encoder = QwenImageEncoder3d(\n            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout,\n            input_channels=self.input_channels\n        )\n        self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)\n        self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)\n\n        self.decoder = QwenImageDecoder3d(\n            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout,\n            input_channels=self.input_channels\n        )\n\n        # When decoding 
a batch of video latents at a time, one can save memory by slicing across the batch dimension\n        # to perform decoding of a single video latent at a time.\n        self.use_slicing = False\n\n        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent\n        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the\n        # intermediate tiles together, the memory requirement can be lowered.\n        self.use_tiling = False\n\n        # The minimal tile height and width for spatial tiling to be used\n        self.tile_sample_min_height = 256\n        self.tile_sample_min_width = 256\n\n        # The stride between the starts of two consecutive spatial tiles\n        self.tile_sample_stride_height = 192\n        self.tile_sample_stride_width = 192\n\n        # Precompute and cache conv counts for encoder and decoder for clear_cache speedup\n        self._cached_conv_counts = {\n            \"decoder\": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules())\n            if self.decoder is not None\n            else 0,\n            \"encoder\": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules())\n            if self.encoder is not None\n            else 0,\n        }\n        cuda_device = get_local_torch_device()\n        # FIXME: hardcoded dtype\n        dtype = torch.bfloat16\n        latent_channels = config.arch_config.z_dim\n\n        self.shift_factor = (\n            torch.tensor(\n                config.arch_config.latents_mean\n            )\n            .view(1, latent_channels, 1, 1, 1)\n            .to(cuda_device, dtype)\n        )\n\n    def enable_tiling(\n        self,\n        tile_sample_min_height: Optional[int] = None,\n        tile_sample_min_width: Optional[int] = None,\n        tile_sample_stride_height: Optional[int] = None,\n        tile_sample_stride_width: Optional[int] = None,\n    ) -> None:\n  
      r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This saves a large amount of memory and allows\n        processing larger images.\n\n        Args:\n            tile_sample_min_height (`int`, *optional*):\n                The minimum height required for a sample to be separated into tiles across the height dimension.\n            tile_sample_min_width (`int`, *optional*):\n                The minimum width required for a sample to be separated into tiles across the width dimension.\n            tile_sample_stride_height (`int`, *optional*):\n                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling\n                artifacts produced across the height dimension.\n            tile_sample_stride_width (`int`, *optional*):\n                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling\n                artifacts produced across the width dimension.\n        \"\"\"\n        self.use_tiling = True\n        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height\n        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width\n        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height\n        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width\n\n    def disable_tiling(self) -> None:\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing\n        decoding in one step.\n        \"\"\"\n        self.use_tiling = False\n\n    def enable_slicing(self) -> None:\n        r\"\"\"\n        Enable sliced VAE decoding. 
When this option is enabled, the VAE will split the input tensor into slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        self.use_slicing = True\n\n    def disable_slicing(self) -> None:\n        r\"\"\"\n        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing\n        decoding in one step.\n        \"\"\"\n        self.use_slicing = False\n\n    def clear_cache(self):\n        # Reuse the conv counts precomputed in __init__ instead of re-walking the modules on every call\n        self._conv_num = self._cached_conv_counts[\"decoder\"]\n        self._conv_idx = [0]\n        self._feat_map = [None] * self._conv_num\n        # cache encode\n        self._enc_conv_num = self._cached_conv_counts[\"encoder\"]\n        self._enc_conv_idx = [0]\n        self._enc_feat_map = [None] * self._enc_conv_num\n\n    def _encode(self, x: torch.Tensor):\n        _, _, num_frame, height, width = x.shape\n\n        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):\n            return self.tiled_encode(x)\n\n        self.clear_cache()\n        iter_ = 1 + (num_frame - 1) // 4\n        for i in range(iter_):\n            self._enc_conv_idx = [0]\n            if i == 0:\n                out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)\n            else:\n                out_ = self.encoder(\n                    x[:, :, 1 + 4 * (i - 1): 1 + 4 * i, :, :],\n                    feat_cache=self._enc_feat_map,\n                    feat_idx=self._enc_conv_idx,\n                )\n                out = torch.cat([out, out_], 2)\n\n        enc = self.quant_conv(out)\n        self.clear_cache()\n        return enc\n\n    def encode(\n        self, x: 
torch.Tensor, return_dict: bool = True\n    ) -> DiagonalGaussianDistribution:\n        r\"\"\"\n        Encode a batch of images into latents.\n\n        Args:\n            x (`torch.Tensor`): Input batch of images.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Unused; kept for API compatibility.\n\n        Returns:\n            `DiagonalGaussianDistribution`:\n                The posterior distribution over the latent representations of the encoded videos.\n        \"\"\"\n        if self.use_slicing and x.shape[0] > 1:\n            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]\n            h = torch.cat(encoded_slices)\n        else:\n            h = self._encode(x)\n        posterior = DiagonalGaussianDistribution(h)\n\n        return posterior\n\n    def _decode(self, z: torch.Tensor, return_dict: bool = True):\n        _, _, num_frame, height, width = z.shape\n        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio\n        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio\n\n        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):\n            return self.tiled_decode(z, return_dict=return_dict)\n\n        self.clear_cache()\n        x = self.post_quant_conv(z)\n        for i in range(num_frame):\n            self._conv_idx = [0]\n            if i == 0:\n                out = self.decoder(x[:, :, i: i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)\n            else:\n                out_ = self.decoder(x[:, :, i: i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)\n                out = torch.cat([out, out_], 2)\n\n        out = torch.clamp(out, min=-1.0, max=1.0)\n        self.clear_cache()\n        if not return_dict:\n            
return (out,)\n\n        return DecoderOutput(sample=out)\n\n    def decode(self, z: torch.Tensor, return_dict: bool = True) -> torch.Tensor:\n        r\"\"\"\n        Decode a batch of images.\n\n        Args:\n            z (`torch.Tensor`): Input batch of latent vectors.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Unused; kept for API compatibility.\n\n        Returns:\n            `torch.Tensor`:\n                The decoded samples.\n        \"\"\"\n        if self.use_slicing and z.shape[0] > 1:\n            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]\n            decoded = torch.cat(decoded_slices)\n        else:\n            decoded = self._decode(z).sample\n\n        return decoded\n\n    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:\n        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)\n        for y in range(blend_extent):\n            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (\n                y / blend_extent\n            )\n        return b\n\n    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:\n        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)\n        for x in range(blend_extent):\n            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (\n                x / blend_extent\n            )\n        return b\n\n    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:\n        r\"\"\"Encode a batch of images using a tiled encoder.\n\n        Args:\n            x (`torch.Tensor`): Input batch of videos.\n\n        Returns:\n            `torch.Tensor`:\n                
The latent representation of the encoded videos.\n        \"\"\"\n        _, _, num_frames, height, width = x.shape\n        latent_height = height // self.spatial_compression_ratio\n        latent_width = width // self.spatial_compression_ratio\n\n        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio\n        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio\n        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio\n        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio\n\n        blend_height = tile_latent_min_height - tile_latent_stride_height\n        blend_width = tile_latent_min_width - tile_latent_stride_width\n\n        # Split x into overlapping tiles and encode them separately.\n        # The tiles have an overlap to avoid seams between tiles.\n        rows = []\n        for i in range(0, height, self.tile_sample_stride_height):\n            row = []\n            for j in range(0, width, self.tile_sample_stride_width):\n                self.clear_cache()\n                time = []\n                frame_range = 1 + (num_frames - 1) // 4\n                for k in range(frame_range):\n                    self._enc_conv_idx = [0]\n                    if k == 0:\n                        tile = x[:, :, :1, i: i + self.tile_sample_min_height, j: j + self.tile_sample_min_width]\n                    else:\n                        tile = x[\n                            :,\n                            :,\n                            1 + 4 * (k - 1): 1 + 4 * k,\n                            i: i + self.tile_sample_min_height,\n                            j: j + self.tile_sample_min_width,\n                        ]\n                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)\n                    tile = self.quant_conv(tile)\n                    
time.append(tile)\n                row.append(torch.cat(time, dim=2))\n            rows.append(row)\n        self.clear_cache()\n\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_width)\n                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])\n            result_rows.append(torch.cat(result_row, dim=-1))\n\n        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]\n        return enc\n\n    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:\n        r\"\"\"\n        Decode a batch of images using a tiled decoder.\n\n        Args:\n            z (`torch.Tensor`): Input batch of latent vectors.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.vae.DecoderOutput`] or `tuple`:\n                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is\n                returned.\n        \"\"\"\n        _, _, num_frames, height, width = z.shape\n        sample_height = height * self.spatial_compression_ratio\n        sample_width = width * self.spatial_compression_ratio\n\n        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio\n        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio\n        tile_latent_stride_height = self.tile_sample_stride_height // 
self.spatial_compression_ratio\n        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio\n\n        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height\n        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width\n\n        # Split z into overlapping tiles and decode them separately.\n        # The tiles have an overlap to avoid seams between tiles.\n        rows = []\n        for i in range(0, height, tile_latent_stride_height):\n            row = []\n            for j in range(0, width, tile_latent_stride_width):\n                self.clear_cache()\n                time = []\n                for k in range(num_frames):\n                    self._conv_idx = [0]\n                    tile = z[:, :, k: k + 1, i: i + tile_latent_min_height, j: j + tile_latent_min_width]\n                    tile = self.post_quant_conv(tile)\n                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)\n                    time.append(decoded)\n                row.append(torch.cat(time, dim=2))\n            rows.append(row)\n        self.clear_cache()\n\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_width)\n                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])\n            result_rows.append(torch.cat(result_row, dim=-1))\n\n        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]\n\n        if not return_dict:\n            return (dec,)\n        return 
DecoderOutput(sample=dec)\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        sample_posterior: bool = False,\n        return_dict: bool = True,\n        generator: Optional[torch.Generator] = None,\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        \"\"\"\n        Args:\n            sample (`torch.Tensor`): Input sample.\n            sample_posterior (`bool`, *optional*, defaults to `False`):\n                Whether to sample from the posterior instead of using its mode.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Passed through to `decode`.\n            generator (`torch.Generator`, *optional*):\n                Generator used when sampling from the posterior.\n        \"\"\"\n        x = sample\n        # encode() returns the DiagonalGaussianDistribution directly (it has no `latent_dist` attribute)\n        posterior = self.encode(x)\n        if sample_posterior:\n            z = posterior.sample(generator=generator)\n        else:\n            z = posterior.mode()\n        dec = self.decode(z, return_dict=return_dict)\n        return dec\n\n\nEntryClass = AutoencoderKLQwenImage\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vaes/common.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Iterator\nfrom math import prod\nfrom typing import Optional, cast\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom diffusers.models.autoencoders.vae import DiagonalGaussianDistribution\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom torch import nn\n\nfrom sglang.multimodal_gen.configs.models import VAEConfig\nfrom sglang.multimodal_gen.runtime.distributed import (\n    get_sp_parallel_rank,\n    get_sp_world_size,\n)\n\n\nclass ParallelTiledVAE(ABC, nn.Module):\n    tile_sample_min_height: int\n    tile_sample_min_width: int\n    tile_sample_min_num_frames: int\n    tile_sample_stride_height: int\n    tile_sample_stride_width: int\n    tile_sample_stride_num_frames: int\n    blend_num_frames: int\n    use_tiling: bool\n    use_temporal_tiling: bool\n    use_parallel_tiling: bool\n\n    def __init__(self, config: VAEConfig, **kwargs) -> None:\n        super().__init__()\n        self.config = config\n        self.tile_sample_min_height = config.tile_sample_min_height\n        self.tile_sample_min_width = config.tile_sample_min_width\n        self.tile_sample_min_num_frames = config.tile_sample_min_num_frames\n        self.tile_sample_stride_height = config.tile_sample_stride_height\n        self.tile_sample_stride_width = config.tile_sample_stride_width\n        self.tile_sample_stride_num_frames = config.tile_sample_stride_num_frames\n        self.blend_num_frames = config.blend_num_frames\n        self.use_tiling = config.use_tiling\n        self.use_temporal_tiling = config.use_temporal_tiling\n        self.use_parallel_tiling = config.use_parallel_tiling\n\n    @property\n    def device(self):\n        return next(self.parameters()).device\n\n    @property\n    def temporal_compression_ratio(self) -> int:\n        return cast(int, 
self.config.temporal_compression_ratio)\n\n    @property\n    def spatial_compression_ratio(self) -> int:\n        return cast(int, self.config.spatial_compression_ratio)\n\n    @property\n    def scaling_factor(self) -> float | torch.Tensor:\n        return cast(float | torch.Tensor, self.config.scaling_factor)\n\n    @abstractmethod\n    def _encode(self, *args, **kwargs) -> torch.Tensor:\n        pass\n\n    @abstractmethod\n    def _decode(self, *args, **kwargs) -> torch.Tensor:\n        pass\n\n    def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:\n        batch_size, num_channels, num_frames, height, width = x.shape\n        latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1\n\n        if (\n            self.use_tiling\n            and self.use_temporal_tiling\n            and num_frames > self.tile_sample_min_num_frames\n        ):\n            latents = self.tiled_encode(x)[:, :, :latent_num_frames]\n        elif self.use_tiling and (\n            width > self.tile_sample_min_width or height > self.tile_sample_min_height\n        ):\n            latents = self.spatial_tiled_encode(x)[:, :, :latent_num_frames]\n        else:\n            latents = self._encode(x)[:, :, :latent_num_frames]\n        return DiagonalGaussianDistribution(latents)\n\n    def decode(self, z: torch.Tensor) -> torch.Tensor:\n        batch_size, num_channels, num_frames, height, width = z.shape\n        tile_latent_min_height = (\n            self.tile_sample_min_height // self.spatial_compression_ratio\n        )\n        tile_latent_min_width = (\n            self.tile_sample_min_width // self.spatial_compression_ratio\n        )\n        tile_latent_min_num_frames = (\n            self.tile_sample_min_num_frames // self.temporal_compression_ratio\n        )\n        num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1\n\n        if self.use_tiling and self.use_parallel_tiling and get_sp_world_size() > 1:\n           
 return self.parallel_tiled_decode(z)[:, :, :num_sample_frames]\n        if (\n            self.use_tiling\n            and self.use_temporal_tiling\n            and num_frames > tile_latent_min_num_frames\n        ):\n            return self.tiled_decode(z)[:, :, :num_sample_frames]\n\n        if self.use_tiling and (\n            width > tile_latent_min_width or height > tile_latent_min_height\n        ):\n            return self.spatial_tiled_decode(z)[:, :, :num_sample_frames]\n\n        return self._decode(z)[:, :, :num_sample_frames]\n\n    def blend_v(\n        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int\n    ) -> torch.Tensor:\n        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)\n        for y in range(blend_extent):\n            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (\n                1 - y / blend_extent\n            ) + b[:, :, :, y, :] * (y / blend_extent)\n        return b\n\n    def blend_h(\n        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int\n    ) -> torch.Tensor:\n        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)\n        for x in range(blend_extent):\n            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (\n                1 - x / blend_extent\n            ) + b[:, :, :, :, x] * (x / blend_extent)\n        return b\n\n    def blend_t(\n        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int\n    ) -> torch.Tensor:\n        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)\n        for x in range(blend_extent):\n            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (\n                1 - x / blend_extent\n            ) + b[:, :, x, :, :] * (x / blend_extent)\n        return b\n\n    def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:\n        r\"\"\"Encode a batch of images using a tiled encoder.\n\n        Args:\n            x (`torch.Tensor`): Input batch of videos.\n\n        Returns:\n            `torch.Tensor`:\n 
               The latent representation of the encoded videos.\n        \"\"\"\n        _, _, _, height, width = x.shape\n        # latent_height = height // self.spatial_compression_ratio\n        # latent_width = width // self.spatial_compression_ratio\n\n        tile_latent_min_height = (\n            self.tile_sample_min_height // self.spatial_compression_ratio\n        )\n        tile_latent_min_width = (\n            self.tile_sample_min_width // self.spatial_compression_ratio\n        )\n        tile_latent_stride_height = (\n            self.tile_sample_stride_height // self.spatial_compression_ratio\n        )\n        tile_latent_stride_width = (\n            self.tile_sample_stride_width // self.spatial_compression_ratio\n        )\n\n        blend_height = tile_latent_min_height - tile_latent_stride_height\n        blend_width = tile_latent_min_width - tile_latent_stride_width\n\n        # Split x into overlapping tiles and encode them separately.\n        # The tiles have an overlap to avoid seams between tiles.\n        rows = []\n        for i in range(0, height, self.tile_sample_stride_height):\n            row = []\n            for j in range(0, width, self.tile_sample_stride_width):\n                tile = x[\n                    :,\n                    :,\n                    :,\n                    i : i + self.tile_sample_min_height,\n                    j : j + self.tile_sample_min_width,\n                ]\n                tile = self._encode(tile)\n                row.append(tile)\n            rows.append(row)\n\n        return self._merge_spatial_tiles(\n            rows,\n            blend_height,\n            blend_width,\n            tile_latent_stride_height,\n            tile_latent_stride_width,\n        )\n\n    def _parallel_data_generator(\n        self, gathered_results, gathered_dim_metadata\n    ) -> Iterator[tuple[torch.Tensor, int]]:\n        global_idx = 0\n        for i, per_rank_metadata in 
enumerate(gathered_dim_metadata):\n            _start_shape = 0\n            for shape in per_rank_metadata:\n                mul_shape = prod(shape)\n                yield (\n                    gathered_results[\n                        i, _start_shape : _start_shape + mul_shape\n                    ].reshape(shape),\n                    global_idx,\n                )\n                _start_shape += mul_shape\n                global_idx += 1\n\n    def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor:\n        \"\"\"\n        Parallel version of tiled_decode that distributes both temporal and spatial computation across GPUs\n        \"\"\"\n        world_size, rank = get_sp_world_size(), get_sp_parallel_rank()\n        B, C, T, H, W = z.shape\n\n        # Calculate parameters\n        tile_latent_min_height = (\n            self.tile_sample_min_height // self.spatial_compression_ratio\n        )\n        tile_latent_min_width = (\n            self.tile_sample_min_width // self.spatial_compression_ratio\n        )\n        tile_latent_min_num_frames = (\n            self.tile_sample_min_num_frames // self.temporal_compression_ratio\n        )\n        tile_latent_stride_height = (\n            self.tile_sample_stride_height // self.spatial_compression_ratio\n        )\n        tile_latent_stride_width = (\n            self.tile_sample_stride_width // self.spatial_compression_ratio\n        )\n        tile_latent_stride_num_frames = (\n            self.tile_sample_stride_num_frames // self.temporal_compression_ratio\n        )\n\n        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height\n        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width\n\n        # Calculate tile dimensions\n        num_t_tiles = (\n            T + tile_latent_stride_num_frames - 1\n        ) // tile_latent_stride_num_frames\n        num_h_tiles = (H + tile_latent_stride_height - 1) // tile_latent_stride_height\n       
 num_w_tiles = (W + tile_latent_stride_width - 1) // tile_latent_stride_width\n        total_spatial_tiles = num_h_tiles * num_w_tiles\n        total_tiles = num_t_tiles * total_spatial_tiles\n\n        # Calculate tiles per rank and padding\n        tiles_per_rank = (total_tiles + world_size - 1) // world_size\n        start_tile_idx = rank * tiles_per_rank\n        end_tile_idx = min((rank + 1) * tiles_per_rank, total_tiles)\n\n        local_results = []\n        local_dim_metadata = []\n        # Process assigned tiles\n        for local_idx, global_idx in enumerate(range(start_tile_idx, end_tile_idx)):\n            t_idx = global_idx // total_spatial_tiles\n            spatial_idx = global_idx % total_spatial_tiles\n            h_idx = spatial_idx // num_w_tiles\n            w_idx = spatial_idx % num_w_tiles\n\n            # Calculate positions\n            t_start = t_idx * tile_latent_stride_num_frames\n            h_start = h_idx * tile_latent_stride_height\n            w_start = w_idx * tile_latent_stride_width\n\n            # Extract and process tile\n            tile = z[\n                :,\n                :,\n                t_start : t_start + tile_latent_min_num_frames + 1,\n                h_start : h_start + tile_latent_min_height,\n                w_start : w_start + tile_latent_min_width,\n            ]\n\n            # Process tile\n            tile = self._decode(tile)\n\n            if t_start > 0:\n                tile = tile[:, :, 1:, :, :]\n\n            # Store metadata\n            shape = tile.shape\n            # Store decoded data (flattened)\n            decoded_flat = tile.reshape(-1)\n            local_results.append(decoded_flat)\n            local_dim_metadata.append(shape)\n\n        results = torch.cat(local_results, dim=0).contiguous()\n        del local_results\n        # first gather size to pad the results\n        local_size = torch.tensor(\n            [results.size(0)], device=results.device, dtype=torch.int64\n        
)\n        all_sizes = [\n            torch.zeros(1, device=results.device, dtype=torch.int64)\n            for _ in range(world_size)\n        ]\n        dist.all_gather(all_sizes, local_size)\n        max_size = max(size.item() for size in all_sizes)\n        padded_results = torch.zeros(max_size, device=results.device)\n        padded_results[: results.size(0)] = results\n        del results\n\n        # Gather all results\n        gathered_dim_metadata = [None] * world_size\n        gathered_results = (\n            torch.zeros_like(padded_results)\n            .repeat(world_size, *[1] * len(padded_results.shape))\n            .contiguous()\n        )  # use contiguous to make sure it won't copy data in the following operations\n        # TODO (PY): use sgl_diffusion distributed methods\n        dist.all_gather_into_tensor(gathered_results, padded_results)\n        dist.all_gather_object(gathered_dim_metadata, local_dim_metadata)\n        # Process gathered results\n        data: list = [\n            [[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)]\n            for _ in range(num_t_tiles)\n        ]\n        for current_data, global_idx in self._parallel_data_generator(\n            gathered_results, gathered_dim_metadata\n        ):\n            t_idx = global_idx // total_spatial_tiles\n            spatial_idx = global_idx % total_spatial_tiles\n            h_idx = spatial_idx // num_w_tiles\n            w_idx = spatial_idx % num_w_tiles\n            data[t_idx][h_idx][w_idx] = current_data\n        # Merge results\n        result_slices = []\n        last_slice_data = None\n        for i, tem_data in enumerate(data):\n            slice_data = self._merge_spatial_tiles(\n                tem_data,\n                blend_height,\n                blend_width,\n                self.tile_sample_stride_height,\n                self.tile_sample_stride_width,\n            )\n            if i > 0:\n                slice_data = self.blend_t(\n            
        last_slice_data, slice_data, self.blend_num_frames\n                )\n                result_slices.append(\n                    slice_data[:, :, : self.tile_sample_stride_num_frames, :, :]\n                )\n            else:\n                result_slices.append(\n                    slice_data[:, :, : self.tile_sample_stride_num_frames + 1, :, :]\n                )\n            last_slice_data = slice_data\n        dec = torch.cat(result_slices, dim=2)\n\n        return dec\n\n    def _merge_spatial_tiles(\n        self, tiles, blend_height, blend_width, stride_height, stride_width\n    ) -> torch.Tensor:\n        \"\"\"Helper function to merge spatial tiles with blending\"\"\"\n        result_rows = []\n        for i, row in enumerate(tiles):\n            result_row = []\n            for j, tile in enumerate(row):\n                if i > 0:\n                    tile = self.blend_v(tiles[i - 1][j], tile, blend_height)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_width)\n                result_row.append(tile[:, :, :, :stride_height, :stride_width])\n            result_rows.append(torch.cat(result_row, dim=-1))\n        return torch.cat(result_rows, dim=-2)\n\n    def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor:\n        r\"\"\"\n        Decode a batch of images using a tiled decoder.\n\n        Args:\n            z (`torch.Tensor`): Input batch of latent vectors.\n\n        Returns:\n            `torch.Tensor`:\n                The decoded images.\n        \"\"\"\n\n        _, _, _, height, width = z.shape\n        # sample_height = height * self.spatial_compression_ratio\n        # sample_width = width * self.spatial_compression_ratio\n\n        tile_latent_min_height = (\n            self.tile_sample_min_height // self.spatial_compression_ratio\n        )\n        tile_latent_min_width = (\n            self.tile_sample_min_width // self.spatial_compression_ratio\n        )\n        
tile_latent_stride_height = (\n            self.tile_sample_stride_height // self.spatial_compression_ratio\n        )\n        tile_latent_stride_width = (\n            self.tile_sample_stride_width // self.spatial_compression_ratio\n        )\n\n        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height\n        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width\n\n        # Split z into overlapping tiles and decode them separately.\n        # The tiles have an overlap to avoid seams between tiles.\n        rows = []\n        for i in range(0, height, tile_latent_stride_height):\n            row = []\n            for j in range(0, width, tile_latent_stride_width):\n                tile = z[\n                    :,\n                    :,\n                    :,\n                    i : i + tile_latent_min_height,\n                    j : j + tile_latent_min_width,\n                ]\n                decoded = self._decode(tile)\n                row.append(decoded)\n            rows.append(row)\n        return self._merge_spatial_tiles(\n            rows,\n            blend_height,\n            blend_width,\n            self.tile_sample_stride_height,\n            self.tile_sample_stride_width,\n        )\n\n    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:\n        _, _, num_frames, height, width = x.shape\n\n        # tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio\n        tile_latent_stride_num_frames = (\n            self.tile_sample_stride_num_frames // self.temporal_compression_ratio\n        )\n\n        row = []\n        for i in range(0, num_frames, self.tile_sample_stride_num_frames):\n            tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]\n            if self.use_tiling and (\n                height > self.tile_sample_min_height\n                or width > self.tile_sample_min_width\n            ):\n                tile = 
self.spatial_tiled_encode(tile)\n            else:\n                tile = self._encode(tile)\n            if i > 0:\n                tile = tile[:, :, 1:, :, :]\n            row.append(tile)\n        result_row = []\n        for i, tile in enumerate(row):\n            if i > 0:\n                tile = self.blend_t(row[i - 1], tile, self.blend_num_frames)\n                result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])\n            else:\n                result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])\n        enc = torch.cat(result_row, dim=2)\n        return enc\n\n    def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:\n        batch_size, num_channels, num_frames, height, width = z.shape\n\n        tile_latent_min_height = (\n            self.tile_sample_min_height // self.spatial_compression_ratio\n        )\n        tile_latent_min_width = (\n            self.tile_sample_min_width // self.spatial_compression_ratio\n        )\n        tile_latent_min_num_frames = (\n            self.tile_sample_min_num_frames // self.temporal_compression_ratio\n        )\n        tile_latent_stride_num_frames = (\n            self.tile_sample_stride_num_frames // self.temporal_compression_ratio\n        )\n\n        row = []\n        for i in range(0, num_frames, tile_latent_stride_num_frames):\n            tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]\n            if self.use_tiling and (\n                tile.shape[-1] > tile_latent_min_width\n                or tile.shape[-2] > tile_latent_min_height\n            ):\n                decoded = self.spatial_tiled_decode(tile)\n            else:\n                decoded = self._decode(tile)\n            if i > 0:\n                decoded = decoded[:, :, 1:, :, :]\n            row.append(decoded)\n        result_row = []\n        for i, tile in enumerate(row):\n            if i > 0:\n                tile = self.blend_t(row[i - 1], tile, self.blend_num_frames)\n    
            result_row.append(\n                    tile[:, :, : self.tile_sample_stride_num_frames, :, :]\n                )\n            else:\n                result_row.append(\n                    tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]\n                )\n\n        dec = torch.cat(result_row, dim=2)\n        return dec\n\n    def enable_tiling(\n        self,\n        tile_sample_min_height: int | None = None,\n        tile_sample_min_width: int | None = None,\n        tile_sample_min_num_frames: int | None = None,\n        tile_sample_stride_height: int | None = None,\n        tile_sample_stride_width: int | None = None,\n        tile_sample_stride_num_frames: int | None = None,\n        blend_num_frames: int | None = None,\n        use_tiling: bool | None = None,\n        use_temporal_tiling: bool | None = None,\n        use_parallel_tiling: bool | None = None,\n    ) -> None:\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow\n        processing larger images.\n\n        Args:\n            tile_sample_min_height (`int`, *optional*):\n                The minimum height required for a sample to be separated into tiles across the height dimension.\n            tile_sample_min_width (`int`, *optional*):\n                The minimum width required for a sample to be separated into tiles across the width dimension.\n            tile_sample_min_num_frames (`int`, *optional*):\n                The minimum number of frames required for a sample to be separated into tiles across the frame\n                dimension.\n            tile_sample_stride_height (`int`, *optional*):\n                The stride between two consecutive vertical tiles. 
This is to ensure that there are\n                no tiling artifacts produced across the height dimension.\n            tile_sample_stride_width (`int`, *optional*):\n                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling\n                artifacts produced across the width dimension.\n            tile_sample_stride_num_frames (`int`, *optional*):\n                The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts\n                produced across the frame dimension.\n        \"\"\"\n        self.use_tiling = True\n        self.tile_sample_min_height = (\n            tile_sample_min_height or self.tile_sample_min_height\n        )\n        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width\n        self.tile_sample_min_num_frames = (\n            tile_sample_min_num_frames or self.tile_sample_min_num_frames\n        )\n        self.tile_sample_stride_height = (\n            tile_sample_stride_height or self.tile_sample_stride_height\n        )\n        self.tile_sample_stride_width = (\n            tile_sample_stride_width or self.tile_sample_stride_width\n        )\n        self.tile_sample_stride_num_frames = (\n            tile_sample_stride_num_frames or self.tile_sample_stride_num_frames\n        )\n        if blend_num_frames is not None:\n            self.blend_num_frames = blend_num_frames\n        else:\n            self.blend_num_frames = (\n                self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames\n            )\n        self.use_tiling = use_tiling or self.use_tiling\n        self.use_temporal_tiling = use_temporal_tiling or self.use_temporal_tiling\n        self.use_parallel_tiling = use_parallel_tiling or self.use_parallel_tiling\n\n    def disable_tiling(self) -> None:\n        r\"\"\"\n        Disable tiled VAE decoding. 
If `enable_tiling` was previously enabled, this method will go back to computing\n        decoding in one step.\n        \"\"\"\n        self.use_tiling = False\n\n\n# adapted from https://github.com/huggingface/diffusers/blob/e7ffeae0a191f710881d1fbde00cd6ff025e81f2/src/diffusers/models/autoencoders/vae.py#L691\nclass DiagonalGaussianDistribution:\n\n    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):\n        self.parameters = parameters\n        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)\n        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)\n        self.deterministic = deterministic\n        self.std = torch.exp(0.5 * self.logvar)\n        self.var = torch.exp(self.logvar)\n        if self.deterministic:\n            self.var = self.std = torch.zeros_like(\n                self.mean, device=self.parameters.device, dtype=self.parameters.dtype\n            )\n\n    def sample(self, generator: torch.Generator | None = None) -> torch.Tensor:\n        # make sure sample is on the same device as the parameters and has same dtype\n        sample = randn_tensor(\n            self.mean.shape,\n            generator=generator,\n            device=self.parameters.device,\n            dtype=self.parameters.dtype,\n        )\n        x = self.mean + self.std * sample\n        return x\n\n    def kl(\n        self,\n        other: Optional[\"DiagonalGaussianDistribution\"] = None,\n        dims: tuple[int, ...] 
= (1, 2, 3),\n    ) -> torch.Tensor:\n        if self.deterministic:\n            return torch.Tensor([0.0])\n        else:\n            if other is None:\n                return 0.5 * torch.sum(\n                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,\n                    dim=dims,\n                )\n            else:\n                return 0.5 * torch.sum(\n                    torch.pow(self.mean - other.mean, 2) / other.var\n                    + self.var / other.var\n                    - 1.0\n                    - self.logvar\n                    + other.logvar,\n                    dim=dims,\n                )\n\n    def nll(\n        self, sample: torch.Tensor, dims: tuple[int, ...] = (1, 2, 3)\n    ) -> torch.Tensor:\n        if self.deterministic:\n            return torch.Tensor([0.0])\n        logtwopi = np.log(2.0 * np.pi)\n        return 0.5 * torch.sum(\n            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,\n            dim=dims,\n        )\n\n    def mode(self) -> torch.Tensor:\n        return self.mean\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vaes/dac.py",
    "content": "# Copied and adapted from: https://github.com/descriptinc/descript-audio-codec\n\n# SPDX-License-Identifier: MIT\n\nimport math\nfrom bisect import bisect_right\nfrom typing import Union\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom torch import nn\n\nfrom sglang.multimodal_gen.configs.models.vaes.dac import DacVAEConfig\nfrom sglang.multimodal_gen.runtime.models.vaes.common import (\n    DiagonalGaussianDistribution,\n)\n\n\n# Scripting this brings model speed up 1.4x\n@torch.jit.script\ndef snake(x, alpha):\n    shape = x.shape\n    x = x.reshape(shape[0], shape[1], -1)\n    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)\n    x = x.reshape(shape)\n    return x\n\n\nclass Snake1d(nn.Module):\n    def __init__(self, channels):\n        super().__init__()\n        self.alpha = nn.Parameter(torch.ones(1, channels, 1))\n\n    def forward(self, x):\n        return snake(x, self.alpha)\n\n\nclass VectorQuantize(nn.Module):\n    \"\"\"\n    Implementation of VQ similar to Karpathy's repo:\n    https://github.com/karpathy/deep-vector-quantization\n    Additionally uses following tricks from Improved VQGAN\n    (https://arxiv.org/pdf/2110.04627.pdf):\n        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space\n            for improved codebook usage\n        2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which\n            improves training stability\n    \"\"\"\n\n    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):\n        super().__init__()\n        self.codebook_size = codebook_size\n        self.codebook_dim = codebook_dim\n\n        self.in_proj = nn.Conv1d(input_dim, codebook_dim, kernel_size=1)\n        self.out_proj = nn.Conv1d(codebook_dim, input_dim, kernel_size=1)\n        self.codebook = nn.Embedding(codebook_size, codebook_dim)\n\n    def forward(self, z):\n        \"\"\"Quantize the input tensor using a fixed codebook and return the corresponding codebook vectors.\n\n        Args:\n            z (torch.Tensor): Input tensor with shape ``[B, D, T]``.\n\n        Returns:\n            tuple: A tuple containing:\n                - z_q (torch.Tensor): Quantized continuous representation with shape ``[B, D, T]``.\n                - commitment_loss (torch.Tensor): Commitment loss scalar to train encoder to predict\n                  vectors closer to codebook entries.\n                - codebook_loss (torch.Tensor): Codebook loss scalar to update the codebook.\n                - indices (torch.Tensor): Codebook indices (quantized discrete representation) with shape ``[B, T]``.\n                - z_e (torch.Tensor): Projected latents (continuous representation before quantization) with shape ``[B, D, T]``.\n        \"\"\"\n\n        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space\n        z_e = self.in_proj(z)  # z_e : (B x D x T)\n        z_q, indices = self.decode_latents(z_e)\n\n        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction=\"none\").mean([1, 2])\n        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction=\"none\").mean([1, 2])\n\n        z_q = (\n            z_e + (z_q - z_e).detach()\n        )  # noop in forward pass, straight-through gradient estimator in backward pass\n\n        z_q = self.out_proj(z_q)\n\n 
       return z_q, commitment_loss, codebook_loss, indices, z_e\n\n    def embed_code(self, embed_id):\n        return F.embedding(embed_id, self.codebook.weight)\n\n    def decode_code(self, embed_id):\n        return self.embed_code(embed_id).transpose(1, 2)\n\n    def decode_latents(self, latents):\n        encodings = rearrange(latents, \"b d t -> (b t) d\")\n        codebook = self.codebook.weight  # codebook: (N x D)\n\n        # L2 normalize encodings and codebook (ViT-VQGAN)\n        encodings = F.normalize(encodings)\n        codebook = F.normalize(codebook)\n\n        # Compute euclidean distance with codebook\n        dist = (\n            encodings.pow(2).sum(1, keepdim=True)\n            - 2 * encodings @ codebook.t()\n            + codebook.pow(2).sum(1, keepdim=True).t()\n        )\n        indices = rearrange((-dist).max(1)[1], \"(b t) -> b t\", b=latents.size(0))\n        z_q = self.decode_code(indices)\n        return z_q, indices\n\n\nclass ResidualVectorQuantize(nn.Module):\n    \"\"\"\n    Introduced in SoundStream: An end2end neural audio codec\n    https://arxiv.org/abs/2107.03312\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dim: int = 512,\n        n_codebooks: int = 9,\n        codebook_size: int = 1024,\n        codebook_dim: Union[int, list] = 8,\n        quantizer_dropout: float = 0.0,\n    ):\n        super().__init__()\n        if isinstance(codebook_dim, int):\n            codebook_dim = [codebook_dim for _ in range(n_codebooks)]\n\n        self.n_codebooks = n_codebooks\n        self.codebook_dim = codebook_dim\n        self.codebook_size = codebook_size\n        dim_offsets = [0]\n        for dim in self.codebook_dim:\n            dim_offsets.append(dim_offsets[-1] + dim)\n        self._codebook_dim_offsets = tuple(dim_offsets)\n\n        self.quantizers = nn.ModuleList(\n            [\n                VectorQuantize(input_dim, codebook_size, codebook_dim[i])\n                for i in range(n_codebooks)\n           
 ]\n        )\n        self.quantizer_dropout = quantizer_dropout\n\n    def forward(self, z, n_quantizers: int = None):\n        \"\"\"Quantize the input tensor using a fixed set of codebooks and return the corresponding codebook vectors.\n\n        Args:\n            z (torch.Tensor): Input tensor with shape ``[B, D, T]``.\n            n_quantizers (int, optional): Number of quantizers to use. If ``None``,\n                all quantizers are used. When ``n_quantizers`` < ``self.n_codebooks``,\n                quantizer dropout is applied. Note: if ``self.quantizer_dropout`` > 0\n                and in training mode, this argument is ignored and a random number of\n                quantizers is used.\n\n        Returns:\n            tuple: A tuple containing:\n                - z_q (torch.Tensor): Quantized continuous representation with shape ``[B, D, T]``.\n                - codes (torch.Tensor): Codebook indices for each codebook with shape ``[B, N, T]``\n                  (quantized discrete representation of input).\n                - latents (torch.Tensor): Projected latents with shape ``[B, N*D, T]``\n                  (continuous representation before quantization).\n                - commitment_loss (torch.Tensor): Commitment loss scalar to train encoder to predict\n                  vectors closer to codebook entries.\n                - codebook_loss (torch.Tensor): Codebook loss scalar to update the codebook.\n        \"\"\"\n        z_q = 0\n        residual = z\n        commitment_loss = 0\n        codebook_loss = 0\n\n        codebook_indices = []\n        latents = []\n\n        if n_quantizers is None:\n            n_quantizers = self.n_codebooks\n        quantizers = self.quantizers\n        if self.training:\n            batch_size = z.shape[0]\n            device = z.device\n            n_quantizers = torch.full(\n                (batch_size,),\n                self.n_codebooks + 1,\n                device=device,\n                
dtype=torch.long,\n            )\n            if self.quantizer_dropout > 0:\n                dropout = torch.randint(\n                    1,\n                    self.n_codebooks + 1,\n                    (batch_size,),\n                    device=device,\n                )\n                n_dropout = int(batch_size * self.quantizer_dropout)\n                if n_dropout > 0:\n                    n_quantizers[:n_dropout] = dropout[:n_dropout]\n\n            for i, quantizer in enumerate(quantizers):\n                z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(\n                    residual\n                )\n\n                # Create mask to apply quantizer dropout\n                mask = i < n_quantizers\n                z_q = z_q + z_q_i * mask[:, None, None]\n                residual = residual - z_q_i\n\n                # Sum losses\n                commitment_loss += (commitment_loss_i * mask).mean()\n                codebook_loss += (codebook_loss_i * mask).mean()\n\n                codebook_indices.append(indices_i)\n                latents.append(z_e_i)\n        else:\n            for i, quantizer in enumerate(quantizers):\n                if i >= n_quantizers:\n                    break\n                z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(\n                    residual\n                )\n                z_q = z_q + z_q_i\n                residual = residual - z_q_i\n\n                commitment_loss += commitment_loss_i.mean()\n                codebook_loss += codebook_loss_i.mean()\n\n                codebook_indices.append(indices_i)\n                latents.append(z_e_i)\n\n        codes = torch.stack(codebook_indices, dim=1)\n        latents = torch.cat(latents, dim=1)\n\n        return z_q, codes, latents, commitment_loss, codebook_loss\n\n    def from_codes(self, codes: torch.Tensor):\n        \"\"\"Reconstruct the continuous representation from quantized codes.\n\n        Args:\n  
          codes (torch.Tensor): Quantized discrete representation with shape ``[B, N, T]``.\n\n        Returns:\n            tuple: A tuple containing:\n                - z_q (torch.Tensor): Quantized continuous representation with shape ``[B, D, T]``.\n                - z_p (torch.Tensor): Concatenated latent space representation with shape ``[B, N*D, T]``.\n                - codes (torch.Tensor): Original input codebook indices with shape ``[B, N, T]``.\n        \"\"\"\n        z_q = 0.0\n        z_p = []\n        n_codebooks = codes.shape[1]\n        for i in range(n_codebooks):\n            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])\n            z_p.append(z_p_i)\n\n            z_q_i = self.quantizers[i].out_proj(z_p_i)\n            z_q = z_q + z_q_i\n        return z_q, torch.cat(z_p, dim=1), codes\n\n    def from_latents(self, latents: torch.Tensor):\n        \"\"\"Reconstruct the continuous representation from unquantized latents.\n\n        Args:\n            latents (torch.Tensor): Continuous representation after projection with shape ``[B, N*D, T]``.\n\n        Returns:\n            tuple: A tuple containing:\n                - z_q (torch.Tensor): Quantized representation of full-projected space with shape ``[B, D, T]``.\n                - z_p (torch.Tensor): Quantized representation of latent space with shape ``[B, N*D, T]``.\n                - codes (torch.Tensor): Codebook indices with shape ``[B, N, T]``.\n        \"\"\"\n        z_q = 0\n        z_p = []\n        codes = []\n        dims = self._codebook_dim_offsets\n        n_codebooks = bisect_right(dims, latents.shape[1]) - 1\n        for i in range(n_codebooks):\n            j, k = dims[i], dims[i + 1]\n            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])\n            z_p.append(z_p_i)\n            codes.append(codes_i)\n\n            z_q_i = self.quantizers[i].out_proj(z_p_i)\n            z_q = z_q + z_q_i\n\n        return z_q, torch.cat(z_p, dim=1), 
torch.stack(codes, dim=1)\n\n\nclass ResidualUnit(nn.Module):\n    def __init__(self, dim: int = 16, dilation: int = 1):\n        super().__init__()\n        pad = ((7 - 1) * dilation) // 2\n        self.block = nn.Sequential(\n            Snake1d(dim),\n            nn.Conv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),\n            Snake1d(dim),\n            nn.Conv1d(dim, dim, kernel_size=1),\n        )\n\n    def forward(self, x):\n        y = self.block(x)\n        pad = (x.shape[-1] - y.shape[-1]) // 2\n        if pad > 0:\n            x = x[..., pad:-pad]\n        return x + y\n\n\nclass EncoderBlock(nn.Module):\n    def __init__(self, dim: int = 16, stride: int = 1):\n        super().__init__()\n        self.block = nn.Sequential(\n            ResidualUnit(dim // 2, dilation=1),\n            ResidualUnit(dim // 2, dilation=3),\n            ResidualUnit(dim // 2, dilation=9),\n            Snake1d(dim // 2),\n            nn.Conv1d(\n                dim // 2,\n                dim,\n                kernel_size=2 * stride,\n                stride=stride,\n                padding=math.ceil(stride / 2),\n            ),\n        )\n\n    def forward(self, x):\n        return self.block(x)\n\n\nclass Encoder(nn.Module):\n    def __init__(\n        self,\n        d_model: int = 64,\n        strides: list = [2, 4, 8, 8],\n        d_latent: int = 64,\n    ):\n        super().__init__()\n        # Create first convolution\n        self.block = [nn.Conv1d(1, d_model, kernel_size=7, padding=3)]\n\n        # Create EncoderBlocks that double channels as they downsample by `stride`\n        for stride in strides:\n            d_model *= 2\n            self.block += [EncoderBlock(d_model, stride=stride)]\n\n        # Create last convolution\n        self.block += [\n            Snake1d(d_model),\n            nn.Conv1d(d_model, d_latent, kernel_size=3, padding=1),\n        ]\n\n        # Wrap block into nn.Sequential\n        self.block = 
nn.Sequential(*self.block)\n        self.enc_dim = d_model\n\n    def forward(self, x):\n        return self.block(x)\n\n\nclass DecoderBlock(nn.Module):\n    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):\n        super().__init__()\n        self.block = nn.Sequential(\n            Snake1d(input_dim),\n            nn.ConvTranspose1d(\n                input_dim,\n                output_dim,\n                kernel_size=2 * stride,\n                stride=stride,\n                padding=math.ceil(stride / 2),\n                output_padding=stride % 2,\n            ),\n            ResidualUnit(output_dim, dilation=1),\n            ResidualUnit(output_dim, dilation=3),\n            ResidualUnit(output_dim, dilation=9),\n        )\n\n    def forward(self, x):\n        return self.block(x)\n\n\nclass Decoder(nn.Module):\n    def __init__(\n        self,\n        input_channel,\n        channels,\n        rates,\n        d_out: int = 1,\n    ):\n        super().__init__()\n\n        # Add first conv layer\n        layers = [nn.Conv1d(input_channel, channels, kernel_size=7, padding=3)]\n\n        # Add upsampling + MRF blocks\n        for i, stride in enumerate(rates):\n            input_dim = channels // 2**i\n            output_dim = channels // 2 ** (i + 1)\n            layers += [DecoderBlock(input_dim, output_dim, stride)]\n\n        # Add final conv layer\n        layers += [\n            Snake1d(output_dim),\n            nn.Conv1d(output_dim, d_out, kernel_size=7, padding=3),\n            nn.Tanh(),\n        ]\n\n        self.model = nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.model(x)\n\n\nclass DAC(nn.Module):\n    def __init__(\n        self,\n        config: DacVAEConfig,\n    ):\n        super().__init__()\n\n        self.continuous = config.continuous\n        self.decoder_dim = config.decoder_dim\n        self.decoder_rates = config.decoder_rates\n        self.encoder_dim = config.encoder_dim\n   
     self.encoder_rates = config.encoder_rates\n        self.hop_length = math.prod(config.encoder_rates)\n        self.sample_rate = config.sample_rate\n\n        if config.latent_dim is None:\n            latent_dim = config.encoder_dim * (2 ** len(config.encoder_rates))\n        else:\n            latent_dim = config.latent_dim\n\n        self.latent_dim = latent_dim\n\n        if config.load_encoder:\n            self.encoder = Encoder(config.encoder_dim, config.encoder_rates, latent_dim)\n\n        if not config.continuous:\n            self.n_codebooks = config.n_codebooks\n            self.codebook_size = config.codebook_size\n            self.codebook_dim = config.codebook_dim\n            self.quantizer = ResidualVectorQuantize(\n                input_dim=latent_dim,\n                n_codebooks=config.n_codebooks,\n                codebook_size=config.codebook_size,\n                codebook_dim=config.codebook_dim,\n                quantizer_dropout=config.quantizer_dropout,\n            )\n        else:\n            self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)\n            self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)\n\n        if config.load_decoder:\n            self.decoder = Decoder(\n                latent_dim,\n                config.decoder_dim,\n                config.decoder_rates,\n            )\n\n        self.apply(self.init_weights)\n\n    @staticmethod\n    def init_weights(m):\n        if isinstance(m, nn.Conv1d):\n            nn.init.trunc_normal_(m.weight, std=0.02)\n            nn.init.constant_(m.bias, 0)\n\n    @property\n    def dtype(self):\n        return next(self.parameters()).dtype\n\n    @property\n    def device(self):\n        return next(self.parameters()).device\n\n    def preprocess(self, audio_data, sample_rate):\n        if sample_rate is None:\n            sample_rate = self.sample_rate\n        assert sample_rate == self.sample_rate\n\n        length = audio_data.shape[-1]\n    
    right_pad = math.ceil(length / self.hop_length) * self.hop_length - length\n        audio_data = nn.functional.pad(audio_data, (0, right_pad))\n\n        return audio_data\n\n    def encode(\n        self,\n        audio_data: torch.Tensor,\n        n_quantizers: int = None,\n    ):\n        \"\"\"Encode audio data into latent representations.\n\n        This method processes audio through the encoder network and optionally applies\n        vector quantization (in VQ mode) or projects to a Gaussian distribution (in\n        continuous mode) to produce latent representations.\n\n        Args:\n            audio_data (torch.Tensor): Audio data to encode, with shape ``[B, 1, T]``.\n            n_quantizers (int, optional): Number of quantizers to use. If ``None``,\n                all quantizers are used. Only applicable in VQ mode (``continuous=False``).\n\n        Returns:\n            tuple: A tuple containing:\n                - z (torch.Tensor): Encoded representation. In VQ mode, this is the\n                  quantized continuous representation with shape ``[B, D, T]``. 
In\n                  continuous mode, this is a ``DiagonalGaussianDistribution`` object.\n                - codes (torch.Tensor or None): Codebook indices with shape ``[B, N, T]``\n                  in VQ mode, ``None`` in continuous mode.\n                - latents (torch.Tensor or None): Projected latents with shape ``[B, N*D, T]``\n                  in VQ mode, ``None`` in continuous mode.\n                - commitment_loss (torch.Tensor): Commitment loss scalar.\n                - codebook_loss (torch.Tensor): Codebook loss scalar.\n\n        Note:\n            In continuous mode, the encoded representation is projected through a\n            quantization convolution layer and wrapped in a ``DiagonalGaussianDistribution``\n            for VAE training.\n        \"\"\"\n        z = self.encoder(audio_data)  # [B x D x T]\n        if not self.continuous:\n            z, codes, latents, commitment_loss, codebook_loss = self.quantizer(\n                z, n_quantizers\n            )\n        else:\n            z = self.quant_conv(z)  # [B x 2D x T]\n            z = DiagonalGaussianDistribution(z)\n            codes, latents, commitment_loss, codebook_loss = None, None, 0, 0\n\n        return z, codes, latents, commitment_loss, codebook_loss\n\n    def decode(self, z: torch.Tensor):\n        \"\"\"Decode latent representations back to audio waveforms.\n\n        This method takes latent representations (either quantized from VQ mode or sampled\n        from the posterior in continuous mode) and reconstructs the corresponding audio\n        through the decoder network.\n\n        Args:\n            z (torch.Tensor): Latent representation to decode, with shape ``[B, D, T]``.\n                In VQ mode (``continuous=False``), this is the quantized continuous\n                representation. 
In continuous mode (``continuous=True``), this is sampled\n                from the posterior distribution.\n\n        Returns:\n            torch.Tensor: Decoded audio data with shape ``[B, 1, T']``. The output length\n            T' is determined by the decoder's upsampling rates and may differ from the\n            input temporal dimension T.\n\n        Note:\n            In continuous mode (``continuous=True``), the input is first passed through\n            a post-quantization convolution layer before being fed to the decoder.\n        \"\"\"\n        if not self.continuous:\n            audio = self.decoder(z)\n        else:\n            z = self.post_quant_conv(z)\n            audio = self.decoder(z)\n\n        return audio\n\n    def forward(\n        self,\n        audio_data: torch.Tensor,\n        sample_rate: int = None,\n        n_quantizers: int = None,\n    ):\n        \"\"\"Model forward pass.\n\n        Args:\n            audio_data (torch.Tensor): Audio to encode, shape [B, 1, T].\n            sample_rate (int, optional): Sample rate in Hz. Defaults to\n                ``self.sample_rate`` when ``None``.\n            n_quantizers (int, optional): Number of quantizers to use. When ``None``,\n                all quantizers are used. 
Only used in VQ mode (``continuous=False``).\n\n        Returns:\n            dict: A dictionary containing different keys depending on the mode:\n\n            **VQ Mode (``continuous=False``):**\n                - \"audio\" (torch.Tensor): Decoded audio, shape [B, 1, length].\n                - \"z\" (torch.Tensor): Quantized continuous representation, shape [B, D, T].\n                - \"codes\" (torch.Tensor): Codebook indices, shape [B, N, T].\n                - \"latents\" (torch.Tensor): Projected latents, shape [B, N*D, T].\n                - \"vq/commitment_loss\" (torch.Tensor): Commitment loss.\n                - \"vq/codebook_loss\" (torch.Tensor): Codebook loss.\n\n            **Continuous Mode (``continuous=True``):**\n                - \"audio\" (torch.Tensor): Decoded audio, shape [B, 1, length].\n                - \"z\" (torch.Tensor): Latent representation, shape [B, D, T].\n                - \"kl_loss\" (torch.Tensor): KL divergence loss (for VAE training).\n        \"\"\"\n        length = audio_data.shape[-1]\n        audio_data = self.preprocess(audio_data, sample_rate)\n        if not self.continuous:\n            z, codes, latents, commitment_loss, codebook_loss = self.encode(\n                audio_data, n_quantizers\n            )\n\n            x = self.decode(z)\n            return {\n                \"audio\": x[..., :length],\n                \"z\": z,\n                \"codes\": codes,\n                \"latents\": latents,\n                \"vq/commitment_loss\": commitment_loss,\n                \"vq/codebook_loss\": codebook_loss,\n            }\n        else:\n            posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)\n            z = posterior.sample()\n            x = self.decode(z)\n\n            kl_loss = posterior.kl(dims=(1, 2))\n            kl_loss = kl_loss.mean()\n\n            return {\n                \"audio\": x[..., :length],\n                \"z\": z,\n                \"kl_loss\": kl_loss,\n          
  }\n\n\nEntryClass = DAC\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vaes/hunyuan3d_vae.py",
    "content": "# Copied and adapted from: https://github.com/Tencent-Hunyuan/Hunyuan3D-2\n\n\nfrom __future__ import annotations\n\nfrom typing import Callable, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom tqdm import tqdm\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n# Attention backend selection\nscaled_dot_product_attention = F.scaled_dot_product_attention\n\n\nclass CrossAttentionProcessor:\n    def __call__(self, attn, q, k, v):\n        out = scaled_dot_product_attention(q, k, v)\n        return out\n\n\nclass FlashVDMCrossAttentionProcessor:\n    def __init__(self, topk=None):\n        self.topk = topk\n\n    def __call__(self, attn, q, k, v):\n        if k.shape[-2] == 3072:\n            topk = 1024\n        elif k.shape[-2] == 512:\n            topk = 256\n        else:\n            topk = k.shape[-2] // 3\n\n        if self.topk is True:\n            q1 = q[:, :, ::100, :]\n            sim = q1 @ k.transpose(-1, -2)\n            sim = torch.mean(sim, -2)\n            topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)\n            topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])\n            v0 = torch.gather(v, dim=-2, index=topk_ind)\n            k0 = torch.gather(k, dim=-2, index=topk_ind)\n            out = scaled_dot_product_attention(q, k0, v0)\n        elif self.topk is False:\n            out = scaled_dot_product_attention(q, k, v)\n        else:\n            idx, counts = self.topk\n            start = 0\n            outs = []\n            for grid_coord, count in zip(idx, counts):\n                end = start + count\n                q_chunk = q[:, :, start:end, :]\n                k0, v0 = self.select_topkv(q_chunk, k, v, topk)\n                out = scaled_dot_product_attention(q_chunk, k0, v0)\n                outs.append(out)\n  
              start += count\n            out = torch.cat(outs, dim=-2)\n        self.topk = False\n        return out\n\n    def select_topkv(self, q_chunk, k, v, topk):\n        q1 = q_chunk[:, :, ::50, :]\n        sim = q1 @ k.transpose(-1, -2)\n        sim = torch.mean(sim, -2)\n        topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)\n        topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])\n        v0 = torch.gather(v, dim=-2, index=topk_ind)\n        k0 = torch.gather(k, dim=-2, index=topk_ind)\n        return k0, v0\n\n\nclass FlashVDMTopMCrossAttentionProcessor(FlashVDMCrossAttentionProcessor):\n    def select_topkv(self, q_chunk, k, v, topk):\n        q1 = q_chunk[:, :, ::30, :]\n        sim = q1 @ k.transpose(-1, -2)\n        # sim = sim.to(torch.float32)\n        sim = sim.softmax(-1)\n        sim = torch.mean(sim, 1)\n        activated_token = torch.where(sim > 1e-6)[2]\n        index = (\n            torch.unique(activated_token, return_counts=True)[0]\n            .unsqueeze(0)\n            .unsqueeze(0)\n            .unsqueeze(-1)\n        )\n        index = index.expand(-1, v.shape[1], -1, v.shape[-1])\n        v0 = torch.gather(v, dim=-2, index=index)\n        k0 = torch.gather(k, dim=-2, index=index)\n        return k0, v0\n\n\nclass FourierEmbedder(nn.Module):\n    def __init__(\n        self,\n        num_freqs: int = 6,\n        logspace: bool = True,\n        input_dim: int = 3,\n        include_input: bool = True,\n        include_pi: bool = True,\n    ) -> None:\n        \"\"\"The initialization\"\"\"\n\n        super().__init__()\n\n        if logspace:\n            frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)\n        else:\n            frequencies = torch.linspace(\n                1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32\n            )\n\n        if include_pi:\n            frequencies *= torch.pi\n\n        self.register_buffer(\"frequencies\", frequencies, 
persistent=False)\n        self.include_input = include_input\n        self.num_freqs = num_freqs\n\n        self.out_dim = self.get_dims(input_dim)\n\n    def get_dims(self, input_dim):\n        temp = 1 if self.include_input or self.num_freqs == 0 else 0\n        out_dim = input_dim * (self.num_freqs * 2 + temp)\n\n        return out_dim\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Forward process.\"\"\"\n\n        if self.num_freqs > 0:\n            embed = (x[..., None].contiguous() * self.frequencies).view(\n                *x.shape[:-1], -1\n            )\n            if self.include_input:\n                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)\n            else:\n                return torch.cat((embed.sin(), embed.cos()), dim=-1)\n        else:\n            return x\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\"\"\"\n\n    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n        self.scale_by_keep = scale_by_keep\n\n    def forward(self, x):\n        \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\"\"\"\n        if self.drop_prob == 0.0 or not self.training:\n            return x\n        keep_prob = 1 - self.drop_prob\n        shape = (x.shape[0],) + (1,) * (\n            x.ndim - 1\n        )  # work with diff dim tensors, not just 2D ConvNets\n        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n        if keep_prob > 0.0 and self.scale_by_keep:\n            random_tensor.div_(keep_prob)\n        return x * random_tensor\n\n    def extra_repr(self):\n        return f\"drop_prob={round(self.drop_prob, 3):0.3f}\"\n\n\nclass MLP(nn.Module):\n    def __init__(\n        self,\n        *,\n        width: int,\n        expand_ratio: int = 4,\n        output_width: int = 
None,\n        drop_path_rate: float = 0.0,\n    ):\n        super().__init__()\n        self.width = width\n        self.c_fc = nn.Linear(width, width * expand_ratio)\n        self.c_proj = nn.Linear(\n            width * expand_ratio, output_width if output_width is not None else width\n        )\n        self.gelu = nn.GELU()\n        self.drop_path = (\n            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()\n        )\n\n    def forward(self, x):\n        return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))\n\n\nclass QKVMultiheadCrossAttention(nn.Module):\n    def __init__(\n        self,\n        *,\n        heads: int,\n        n_data: Optional[int] = None,\n        width=None,\n        qk_norm=False,\n        norm_layer=nn.LayerNorm,\n    ):\n        super().__init__()\n        self.heads = heads\n        self.n_data = n_data\n        self.q_norm = (\n            norm_layer(width // heads, elementwise_affine=True, eps=1e-6)\n            if qk_norm\n            else nn.Identity()\n        )\n        self.k_norm = (\n            norm_layer(width // heads, elementwise_affine=True, eps=1e-6)\n            if qk_norm\n            else nn.Identity()\n        )\n\n        self.attn_processor = CrossAttentionProcessor()\n\n    def forward(self, q, kv):\n        _, n_ctx, _ = q.shape\n        bs, n_data, width = kv.shape\n        attn_ch = width // self.heads // 2\n        q = q.view(bs, n_ctx, self.heads, -1)\n        kv = kv.view(bs, n_data, self.heads, -1)\n        k, v = torch.split(kv, attn_ch, dim=-1)\n\n        q = self.q_norm(q)\n        k = self.k_norm(k)\n        q, k, v = map(\n            lambda t: rearrange(t, \"b n h d -> b h n d\", h=self.heads), (q, k, v)\n        )\n        out = self.attn_processor(self, q, k, v)\n        out = out.transpose(1, 2).reshape(bs, n_ctx, -1)\n        return out\n\n\nclass MultiheadCrossAttention(nn.Module):\n    def __init__(\n        self,\n        *,\n        width: int,\n        heads: 
int,\n        qkv_bias: bool = True,\n        n_data: Optional[int] = None,\n        data_width: Optional[int] = None,\n        norm_layer=nn.LayerNorm,\n        qk_norm: bool = False,\n        kv_cache: bool = False,\n    ):\n        super().__init__()\n        self.n_data = n_data\n        self.width = width\n        self.heads = heads\n        self.data_width = width if data_width is None else data_width\n        self.c_q = nn.Linear(width, width, bias=qkv_bias)\n        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)\n        self.c_proj = nn.Linear(width, width)\n        self.attention = QKVMultiheadCrossAttention(\n            heads=heads,\n            n_data=n_data,\n            width=width,\n            norm_layer=norm_layer,\n            qk_norm=qk_norm,\n        )\n        self.kv_cache = kv_cache\n        self.data = None\n\n    def forward(self, x, data):\n        x = self.c_q(x)\n        if self.kv_cache:\n            if self.data is None:\n                self.data = self.c_kv(data)\n                logger.info(\n                    \"Saving KV cache; this should be called only once per mesh\"\n                )\n            data = self.data\n        else:\n            data = self.c_kv(data)\n        x = self.attention(x, data)\n        x = self.c_proj(x)\n        return x\n\n\nclass ResidualCrossAttentionBlock(nn.Module):\n    def __init__(\n        self,\n        *,\n        n_data: Optional[int] = None,\n        width: int,\n        heads: int,\n        mlp_expand_ratio: int = 4,\n        data_width: Optional[int] = None,\n        qkv_bias: bool = True,\n        norm_layer=nn.LayerNorm,\n        qk_norm: bool = False,\n    ):\n        super().__init__()\n\n        if data_width is None:\n            data_width = width\n\n        self.attn = MultiheadCrossAttention(\n            n_data=n_data,\n            width=width,\n            heads=heads,\n            data_width=data_width,\n            qkv_bias=qkv_bias,\n            
norm_layer=norm_layer,\n            qk_norm=qk_norm,\n        )\n        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)\n        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)\n        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)\n        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)\n\n    def forward(self, x: torch.Tensor, data: torch.Tensor):\n        x = x + self.attn(self.ln_1(x), self.ln_2(data))\n        x = x + self.mlp(self.ln_3(x))\n        return x\n\n\nclass QKVMultiheadAttention(nn.Module):\n    def __init__(\n        self,\n        *,\n        heads: int,\n        n_ctx: int,\n        width=None,\n        qk_norm=False,\n        norm_layer=nn.LayerNorm,\n    ):\n        super().__init__()\n        self.heads = heads\n        self.n_ctx = n_ctx\n        self.q_norm = (\n            norm_layer(width // heads, elementwise_affine=True, eps=1e-6)\n            if qk_norm\n            else nn.Identity()\n        )\n        self.k_norm = (\n            norm_layer(width // heads, elementwise_affine=True, eps=1e-6)\n            if qk_norm\n            else nn.Identity()\n        )\n\n    def forward(self, qkv):\n        bs, n_ctx, width = qkv.shape\n        attn_ch = width // self.heads // 3\n        qkv = qkv.view(bs, n_ctx, self.heads, -1)\n        q, k, v = torch.split(qkv, attn_ch, dim=-1)\n\n        q = self.q_norm(q)\n        k = self.k_norm(k)\n\n        q, k, v = map(\n            lambda t: rearrange(t, \"b n h d -> b h n d\", h=self.heads), (q, k, v)\n        )\n        out = (\n            scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)\n        )\n        return out\n\n\nclass MultiheadAttention(nn.Module):\n    def __init__(\n        self,\n        *,\n        n_ctx: int,\n        width: int,\n        heads: int,\n        qkv_bias: bool,\n        norm_layer=nn.LayerNorm,\n        qk_norm: bool = False,\n        drop_path_rate: float = 0.0,\n    
):\n        super().__init__()\n        self.n_ctx = n_ctx\n        self.width = width\n        self.heads = heads\n        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)\n        self.c_proj = nn.Linear(width, width)\n        self.attention = QKVMultiheadAttention(\n            heads=heads,\n            n_ctx=n_ctx,\n            width=width,\n            norm_layer=norm_layer,\n            qk_norm=qk_norm,\n        )\n        self.drop_path = (\n            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()\n        )\n\n    def forward(self, x):\n        x = self.c_qkv(x)\n        x = self.attention(x)\n        x = self.drop_path(self.c_proj(x))\n        return x\n\n\nclass ResidualAttentionBlock(nn.Module):\n    def __init__(\n        self,\n        *,\n        n_ctx: int,\n        width: int,\n        heads: int,\n        qkv_bias: bool = True,\n        norm_layer=nn.LayerNorm,\n        qk_norm: bool = False,\n        drop_path_rate: float = 0.0,\n    ):\n        super().__init__()\n        self.attn = MultiheadAttention(\n            n_ctx=n_ctx,\n            width=width,\n            heads=heads,\n            qkv_bias=qkv_bias,\n            norm_layer=norm_layer,\n            qk_norm=qk_norm,\n            drop_path_rate=drop_path_rate,\n        )\n        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)\n        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)\n        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)\n\n    def forward(self, x: torch.Tensor):\n        x = x + self.attn(self.ln_1(x))\n        x = x + self.mlp(self.ln_2(x))\n        return x\n\n\nclass Transformer(nn.Module):\n    def __init__(\n        self,\n        *,\n        n_ctx: int,\n        width: int,\n        layers: int,\n        heads: int,\n        qkv_bias: bool = True,\n        norm_layer=nn.LayerNorm,\n        qk_norm: bool = False,\n        drop_path_rate: float = 0.0,\n    ):\n        super().__init__()\n    
    self.n_ctx = n_ctx\n        self.width = width\n        self.layers = layers\n        self.resblocks = nn.ModuleList(\n            [\n                ResidualAttentionBlock(\n                    n_ctx=n_ctx,\n                    width=width,\n                    heads=heads,\n                    qkv_bias=qkv_bias,\n                    norm_layer=norm_layer,\n                    qk_norm=qk_norm,\n                    drop_path_rate=drop_path_rate,\n                )\n                for _ in range(layers)\n            ]\n        )\n\n    def forward(self, x: torch.Tensor):\n        for block in self.resblocks:\n            x = block(x)\n        return x\n\n\nclass CrossAttentionDecoder(nn.Module):\n\n    def __init__(\n        self,\n        *,\n        num_latents: int,\n        out_channels: int,\n        fourier_embedder: FourierEmbedder,\n        width: int,\n        heads: int,\n        mlp_expand_ratio: int = 4,\n        downsample_ratio: int = 1,\n        enable_ln_post: bool = True,\n        qkv_bias: bool = True,\n        qk_norm: bool = False,\n        label_type: str = \"binary\",\n    ):\n        super().__init__()\n\n        self.enable_ln_post = enable_ln_post\n        self.fourier_embedder = fourier_embedder\n        self.downsample_ratio = downsample_ratio\n        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)\n        if self.downsample_ratio != 1:\n            self.latents_proj = nn.Linear(width * downsample_ratio, width)\n        if not self.enable_ln_post:\n            qk_norm = False\n        self.cross_attn_decoder = ResidualCrossAttentionBlock(\n            n_data=num_latents,\n            width=width,\n            mlp_expand_ratio=mlp_expand_ratio,\n            heads=heads,\n            qkv_bias=qkv_bias,\n            qk_norm=qk_norm,\n        )\n\n        if self.enable_ln_post:\n            self.ln_post = nn.LayerNorm(width)\n        self.output_proj = nn.Linear(width, out_channels)\n        self.label_type = 
label_type\n\n    def set_cross_attention_processor(self, processor):\n        self.cross_attn_decoder.attn.attention.attn_processor = processor\n\n    def forward(self, queries=None, query_embeddings=None, latents=None):\n        if query_embeddings is None:\n            fourier_out = self.fourier_embedder(queries)\n            query_embeddings = self.query_proj(fourier_out.to(latents.dtype))\n\n        if self.downsample_ratio != 1:\n            latents = self.latents_proj(latents)\n\n        x = self.cross_attn_decoder(query_embeddings, latents)\n\n        if self.enable_ln_post:\n            x = self.ln_post(x)\n\n        occ = self.output_proj(x)\n        return occ\n\n\ndef generate_dense_grid_points(\n    bbox_min: np.ndarray,\n    bbox_max: np.ndarray,\n    octree_resolution: int,\n    indexing: str = \"ij\",\n):\n    length = bbox_max - bbox_min\n    num_cells = octree_resolution\n\n    x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)\n    y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)\n    z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)\n    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)\n    xyz = np.stack((xs, ys, zs), axis=-1)\n    grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]\n\n    return xyz, grid_size, length\n\n\ndef extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):\n    \"\"\"Extract near-surface voxels for hierarchical decoding.\"\"\"\n    device = input_tensor.device\n\n    val = input_tensor + alpha\n    valid_mask = val > -9000\n\n    def get_neighbor(t, shift, axis):\n        if shift == 0:\n            return t.clone()\n        pad_dims = [0, 0, 0, 0, 0, 0]\n        if axis == 0:\n            pad_idx = 0 if shift > 0 else 1\n            pad_dims[pad_idx] = abs(shift)\n        elif axis == 1:\n            pad_idx = 2 if shift > 0 else 3\n            pad_dims[pad_idx] = abs(shift)\n        
elif axis == 2:\n            pad_idx = 4 if shift > 0 else 5\n            pad_dims[pad_idx] = abs(shift)\n\n        padded = F.pad(t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode=\"replicate\")\n\n        slice_dims = [slice(None)] * 3\n        if axis == 0:\n            slice_dims[0] = slice(shift, None) if shift > 0 else slice(None, shift)\n        elif axis == 1:\n            slice_dims[1] = slice(shift, None) if shift > 0 else slice(None, shift)\n        elif axis == 2:\n            slice_dims[2] = slice(shift, None) if shift > 0 else slice(None, shift)\n\n        padded = padded.squeeze(0).squeeze(0)\n        return padded[tuple(slice_dims)]\n\n    left = get_neighbor(val, 1, axis=0)\n    right = get_neighbor(val, -1, axis=0)\n    back = get_neighbor(val, 1, axis=1)\n    front = get_neighbor(val, -1, axis=1)\n    down = get_neighbor(val, 1, axis=2)\n    up = get_neighbor(val, -1, axis=2)\n\n    def safe_where(neighbor):\n        return torch.where(neighbor > -9000, neighbor, val)\n\n    left, right = safe_where(left), safe_where(right)\n    back, front = safe_where(back), safe_where(front)\n    down, up = safe_where(down), safe_where(up)\n\n    sign = torch.sign(val.to(torch.float32))\n    neighbors_sign = torch.stack(\n        [\n            torch.sign(left.to(torch.float32)),\n            torch.sign(right.to(torch.float32)),\n            torch.sign(back.to(torch.float32)),\n            torch.sign(front.to(torch.float32)),\n            torch.sign(down.to(torch.float32)),\n            torch.sign(up.to(torch.float32)),\n        ],\n        dim=0,\n    )\n\n    same_sign = torch.all(neighbors_sign == sign, dim=0)\n    mask = (~same_sign).to(torch.int32)\n    return mask * valid_mask.to(torch.int32)\n\n\nclass VanillaVolumeDecoder:\n    \"\"\"Standard volume decoder using dense grid evaluation.\"\"\"\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        latents: torch.FloatTensor,\n        geo_decoder: Callable,\n        bounds: Union[Tuple[float], 
List[float], float] = 1.01,\n        num_chunks: int = 10000,\n        octree_resolution: int = None,\n        enable_pbar: bool = True,\n        **kwargs,\n    ):\n        device = latents.device\n        dtype = latents.dtype\n        batch_size = latents.shape[0]\n\n        if isinstance(bounds, float):\n            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]\n\n        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])\n        xyz_samples, grid_size, length = generate_dense_grid_points(\n            bbox_min=bbox_min,\n            bbox_max=bbox_max,\n            octree_resolution=octree_resolution,\n            indexing=\"ij\",\n        )\n        xyz_samples = (\n            torch.from_numpy(xyz_samples)\n            .to(device, dtype=dtype)\n            .contiguous()\n            .reshape(-1, 3)\n        )\n\n        batch_logits = []\n        for start in tqdm(\n            range(0, xyz_samples.shape[0], num_chunks),\n            desc=\"Volume Decoding\",\n            disable=not enable_pbar,\n        ):\n            chunk_queries = xyz_samples[start : start + num_chunks, :]\n            chunk_queries = repeat(chunk_queries, \"p c -> b p c\", b=batch_size)\n            logits = geo_decoder(queries=chunk_queries, latents=latents)\n            batch_logits.append(logits)\n\n        grid_logits = torch.cat(batch_logits, dim=1)\n        grid_logits = grid_logits.view((batch_size, *grid_size)).float()\n\n        return grid_logits\n\n\nclass HierarchicalVolumeDecoding:\n    \"\"\"Hierarchical volume decoder with multi-resolution refinement.\"\"\"\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        latents: torch.FloatTensor,\n        geo_decoder: Callable,\n        bounds: Union[Tuple[float], List[float], float] = 1.01,\n        num_chunks: int = 10000,\n        mc_level: float = 0.0,\n        octree_resolution: int = None,\n        min_resolution: int = 63,\n        enable_pbar: bool = True,\n        **kwargs,\n    
):\n        device = latents.device\n        dtype = latents.dtype\n\n        resolutions = []\n        if octree_resolution < min_resolution:\n            resolutions.append(octree_resolution)\n        while octree_resolution >= min_resolution:\n            resolutions.append(octree_resolution)\n            octree_resolution = octree_resolution // 2\n        resolutions.reverse()\n\n        # 1. generate query points\n        if isinstance(bounds, float):\n            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]\n        bbox_min = np.array(bounds[0:3])\n        bbox_max = np.array(bounds[3:6])\n        bbox_size = bbox_max - bbox_min\n\n        xyz_samples, grid_size, length = generate_dense_grid_points(\n            bbox_min=bbox_min,\n            bbox_max=bbox_max,\n            octree_resolution=resolutions[0],\n            indexing=\"ij\",\n        )\n\n        dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)\n        dilate.weight = torch.nn.Parameter(\n            torch.ones(dilate.weight.shape, dtype=dtype, device=device)\n        )\n\n        grid_size = np.array(grid_size)\n        xyz_samples = (\n            torch.from_numpy(xyz_samples)\n            .to(device, dtype=dtype)\n            .contiguous()\n            .reshape(-1, 3)\n        )\n\n        # 2. 
latents to 3d volume\n        batch_logits = []\n        batch_size = latents.shape[0]\n        for start in tqdm(\n            range(0, xyz_samples.shape[0], num_chunks),\n            desc=f\"Hierarchical Volume Decoding [r{resolutions[0] + 1}]\",\n            disable=not enable_pbar,\n        ):\n            queries = xyz_samples[start : start + num_chunks, :]\n            batch_queries = repeat(queries, \"p c -> b p c\", b=batch_size)\n            logits = geo_decoder(queries=batch_queries, latents=latents)\n            batch_logits.append(logits)\n\n        grid_logits = torch.cat(batch_logits, dim=1).view(\n            (batch_size, grid_size[0], grid_size[1], grid_size[2])\n        )\n\n        for octree_depth_now in resolutions[1:]:\n            grid_size = np.array([octree_depth_now + 1] * 3)\n            resolution = bbox_size / octree_depth_now\n            next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)\n            next_logits = torch.full(\n                next_index.shape, -10000.0, dtype=dtype, device=device\n            )\n            curr_points = extract_near_surface_volume_fn(\n                grid_logits.squeeze(0), mc_level\n            )\n            curr_points += grid_logits.squeeze(0).abs() < 0.95\n\n            if octree_depth_now == resolutions[-1]:\n                expand_num = 0\n            else:\n                expand_num = 1\n            for i in range(expand_num):\n                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)\n            cidx_x, cidx_y, cidx_z = torch.where(curr_points > 0)\n            next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1\n            for i in range(2 - expand_num):\n                next_index = dilate(next_index.unsqueeze(0)).squeeze(0)\n            nidx = torch.where(next_index > 0)\n\n            # nidx holds integer (int64) indices; use float32 so the spacing and\n            # bbox offset are not truncated to integers.\n            next_points = torch.stack(nidx, dim=1)\n            next_points = next_points * torch.tensor(\n                resolution, dtype=torch.float32, 
device=device\n            ) + torch.tensor(bbox_min, dtype=torch.float32, device=device)\n\n            # Check if next_points is empty\n            if next_points.shape[0] == 0:\n                logger.warning(\n                    f\"No valid surface points found at resolution {octree_depth_now}, \"\n                    \"skipping this level\"\n                )\n                continue\n\n            batch_logits = []\n            for start in tqdm(\n                range(0, next_points.shape[0], num_chunks),\n                desc=f\"Hierarchical Volume Decoding [r{octree_depth_now + 1}]\",\n                disable=not enable_pbar,\n            ):\n                queries = next_points[start : start + num_chunks, :]\n                batch_queries = repeat(queries, \"p c -> b p c\", b=batch_size)\n                logits = geo_decoder(\n                    queries=batch_queries.to(latents.dtype), latents=latents\n                )\n                batch_logits.append(logits)\n            grid_logits = torch.cat(batch_logits, dim=1)\n            next_logits[nidx] = grid_logits[0, ..., 0]\n            grid_logits = next_logits.unsqueeze(0)\n        grid_logits[grid_logits == -10000.0] = float(\"nan\")\n\n        return grid_logits\n\n\nclass FlashVDMVolumeDecoding:\n    \"\"\"Flash VDM volume decoder with adaptive KV selection.\"\"\"\n\n    def __init__(self, topk_mode=\"mean\"):\n        if topk_mode not in [\"mean\", \"merge\"]:\n            raise ValueError(f\"Unsupported topk_mode {topk_mode}\")\n\n        if topk_mode == \"mean\":\n            self.processor = FlashVDMCrossAttentionProcessor()\n        else:\n            self.processor = FlashVDMTopMCrossAttentionProcessor()\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        latents: torch.FloatTensor,\n        geo_decoder: CrossAttentionDecoder,\n        bounds: Union[Tuple[float], List[float], float] = 1.01,\n        num_chunks: int = 10000,\n        mc_level: float = 0.0,\n        
octree_resolution: int = None,\n        min_resolution: int = 63,\n        mini_grid_num: int = 4,\n        enable_pbar: bool = True,\n        **kwargs,\n    ):\n        processor = self.processor\n        geo_decoder.set_cross_attention_processor(processor)\n\n        device = latents.device\n        dtype = latents.dtype\n\n        resolutions = []\n        orig_resolution = octree_resolution\n        if octree_resolution < min_resolution:\n            resolutions.append(octree_resolution)\n        while octree_resolution >= min_resolution:\n            resolutions.append(octree_resolution)\n            octree_resolution = octree_resolution // 2\n        resolutions.reverse()\n        resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1\n        for i, resolution in enumerate(resolutions[1:]):\n            resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)\n\n        if isinstance(bounds, float):\n            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]\n        bbox_min = np.array(bounds[0:3])\n        bbox_max = np.array(bounds[3:6])\n        bbox_size = bbox_max - bbox_min\n\n        xyz_samples, grid_size, length = generate_dense_grid_points(\n            bbox_min=bbox_min,\n            bbox_max=bbox_max,\n            octree_resolution=resolutions[0],\n            indexing=\"ij\",\n        )\n\n        logger.info(f\"FlashVDMVolumeDecoding Resolution: {resolutions}\")\n\n        dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)\n        dilate.weight = torch.nn.Parameter(\n            torch.ones(dilate.weight.shape, dtype=dtype, device=device)\n        )\n\n        grid_size = np.array(grid_size)\n\n        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)\n        batch_size = latents.shape[0]\n        mini_grid_size = xyz_samples.shape[0] // mini_grid_num\n        xyz_samples = (\n            xyz_samples.view(\n                mini_grid_num,\n                
mini_grid_size,\n                mini_grid_num,\n                mini_grid_size,\n                mini_grid_num,\n                mini_grid_size,\n                3,\n            )\n            .permute(0, 2, 4, 1, 3, 5, 6)\n            .reshape(-1, mini_grid_size * mini_grid_size * mini_grid_size, 3)\n        )\n\n        batch_logits = []\n        num_batchs = max(num_chunks // xyz_samples.shape[1], 1)\n        for start in tqdm(\n            range(0, xyz_samples.shape[0], num_batchs),\n            desc=\"FlashVDM Volume Decoding\",\n            disable=not enable_pbar,\n        ):\n            queries = xyz_samples[start : start + num_batchs, :]\n            batch = queries.shape[0]\n            batch_latents = repeat(latents.squeeze(0), \"p c -> b p c\", b=batch)\n            processor.topk = True\n            logits = geo_decoder(queries=queries, latents=batch_latents)\n            batch_logits.append(logits)\n\n        grid_logits = (\n            torch.cat(batch_logits, dim=0)\n            .reshape(\n                mini_grid_num,\n                mini_grid_num,\n                mini_grid_num,\n                mini_grid_size,\n                mini_grid_size,\n                mini_grid_size,\n            )\n            .permute(0, 3, 1, 4, 2, 5)\n            .contiguous()\n            .view((batch_size, grid_size[0], grid_size[1], grid_size[2]))\n        )\n\n        for octree_depth_now in resolutions[1:]:\n            grid_size = np.array([octree_depth_now + 1] * 3)\n            resolution = bbox_size / octree_depth_now\n            next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)\n            next_logits = torch.full(\n                next_index.shape, -10000.0, dtype=dtype, device=device\n            )\n            curr_points = extract_near_surface_volume_fn(\n                grid_logits.squeeze(0), mc_level\n            )\n            curr_points += grid_logits.squeeze(0).abs() < 0.95\n\n            expand_num = 0 if 
octree_depth_now == resolutions[-1] else 1\n            for _ in range(expand_num):\n                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)\n\n            cidx_x, cidx_y, cidx_z = torch.where(curr_points > 0)\n            next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1\n            for _ in range(2 - expand_num):\n                next_index = dilate(next_index.unsqueeze(0)).squeeze(0)\n            nidx = torch.where(next_index > 0)\n\n            next_points = torch.stack(nidx, dim=1)\n            next_points = next_points * torch.tensor(\n                resolution, dtype=torch.float32, device=device\n            ) + torch.tensor(bbox_min, dtype=torch.float32, device=device)\n\n            # Check if next_points is empty (no valid surface points found)\n            if next_points.shape[0] == 0:\n                # Skip this resolution level if no points found\n                # Use the previous grid_logits as fallback\n                logger.warning(\n                    f\"No valid surface points found at resolution {octree_depth_now}, \"\n                    f\"skipping this level and using previous resolution grid_logits\"\n                )\n                continue\n\n            query_grid_num = 6\n            min_val = next_points.min(axis=0).values\n            max_val = next_points.max(axis=0).values\n            vol_queries_index = (\n                (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)\n            )\n            index = torch.floor(vol_queries_index).long()\n            index = (\n                index[..., 0] * (query_grid_num**2)\n                + index[..., 1] * query_grid_num\n                + index[..., 2]\n            )\n            index = index.sort()\n            next_points = next_points[index.indices].unsqueeze(0).contiguous()\n            unique_values = torch.unique(index.values, return_counts=True)\n            grid_logits_flat = torch.zeros(\n                
(next_points.shape[1]), dtype=latents.dtype, device=latents.device\n            )\n            input_grid = [[], []]\n            logits_grid_list = []\n            start_num = 0\n            sum_num = 0\n            for grid_index, count in zip(\n                unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()\n            ):\n                if sum_num + count < num_chunks or sum_num == 0:\n                    sum_num += count\n                    input_grid[0].append(grid_index)\n                    input_grid[1].append(count)\n                else:\n                    processor.topk = input_grid\n                    logits_grid = geo_decoder(\n                        queries=next_points[:, start_num : start_num + sum_num],\n                        latents=latents,\n                    )\n                    start_num = start_num + sum_num\n                    logits_grid_list.append(logits_grid)\n                    input_grid = [[grid_index], [count]]\n                    sum_num = count\n            if sum_num > 0:\n                processor.topk = input_grid\n                logits_grid = geo_decoder(\n                    queries=next_points[:, start_num : start_num + sum_num],\n                    latents=latents,\n                )\n                logits_grid_list.append(logits_grid)\n            logits_grid = torch.cat(logits_grid_list, dim=1)\n            grid_logits_flat[index.indices] = logits_grid.squeeze(0).squeeze(-1)\n            next_logits[nidx] = grid_logits_flat\n            grid_logits = next_logits.unsqueeze(0)\n\n        grid_logits[grid_logits == -10000.0] = float(\"nan\")\n        return grid_logits\n\n\nclass Latent2MeshOutput:\n    \"\"\"Container for mesh output from VAE decoder.\"\"\"\n\n    def __init__(self, mesh_v=None, mesh_f=None):\n        self.mesh_v = mesh_v\n        self.mesh_f = mesh_f\n\n\ndef center_vertices(vertices):\n    \"\"\"Translate vertices so bounding box is centered at zero.\"\"\"\n    vert_min = 
vertices.min(dim=0)[0]\n    vert_max = vertices.max(dim=0)[0]\n    vert_center = 0.5 * (vert_min + vert_max)\n    return vertices - vert_center\n\n\nclass SurfaceExtractor:\n    \"\"\"Base class for surface extraction algorithms.\"\"\"\n\n    def _compute_box_stat(\n        self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int\n    ):\n        if isinstance(bounds, float):\n            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]\n\n        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])\n        bbox_size = bbox_max - bbox_min\n        grid_size = [\n            int(octree_resolution) + 1,\n            int(octree_resolution) + 1,\n            int(octree_resolution) + 1,\n        ]\n        return grid_size, bbox_min, bbox_size\n\n    def run(self, *args, **kwargs):\n        raise NotImplementedError\n\n    def __call__(self, grid_logits, **kwargs):\n        outputs = []\n        for i in range(grid_logits.shape[0]):\n            try:\n                vertices, faces = self.run(grid_logits[i], **kwargs)\n                vertices = vertices.astype(np.float32)\n                faces = np.ascontiguousarray(faces)\n                outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces))\n            except Exception:\n                import traceback\n\n                traceback.print_exc()\n                outputs.append(None)\n        return outputs\n\n\nclass MCSurfaceExtractor(SurfaceExtractor):\n    \"\"\"Marching Cubes surface extractor.\"\"\"\n\n    def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):\n        from skimage import measure\n\n        vertices, faces, normals, _ = measure.marching_cubes(\n            grid_logit.cpu().numpy(), mc_level, method=\"lewiner\"\n        )\n        grid_size, bbox_min, bbox_size = self._compute_box_stat(\n            bounds, octree_resolution\n        )\n        vertices = vertices / grid_size * bbox_size + bbox_min\n        return 
vertices, faces\n\n\nclass DMCSurfaceExtractor(SurfaceExtractor):\n    \"\"\"Differentiable Marching Cubes surface extractor.\"\"\"\n\n    def run(self, grid_logit, *, octree_resolution, **kwargs):\n        device = grid_logit.device\n        if not hasattr(self, \"dmc\"):\n            try:\n                from diso import DiffDMC\n\n                self.dmc = DiffDMC(dtype=torch.float32).to(device)\n            except ImportError:\n                raise ImportError(\n                    \"Please install diso via `pip install diso`, or set mc_algo to 'mc'\"\n                )\n        sdf = -grid_logit / octree_resolution\n        sdf = sdf.to(torch.float32).contiguous()\n        verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)\n        verts = center_vertices(verts)\n        vertices = verts.detach().cpu().numpy()\n        faces = faces.detach().cpu().numpy()[:, ::-1]\n        return vertices, faces\n\n\nSurfaceExtractors = {\n    \"mc\": MCSurfaceExtractor,\n    \"dmc\": DMCSurfaceExtractor,\n}\n\n\nclass VectsetVAE(nn.Module):\n    \"\"\"Base VAE class for vector set encoding.\"\"\"\n\n    def __init__(self, volume_decoder=None, surface_extractor=None):\n        super().__init__()\n        if volume_decoder is None:\n            volume_decoder = VanillaVolumeDecoder()\n        if surface_extractor is None:\n            surface_extractor = MCSurfaceExtractor()\n        self.volume_decoder = volume_decoder\n        self.surface_extractor = surface_extractor\n\n    def latents2mesh(self, latents: torch.FloatTensor, **kwargs):\n        \"\"\"Convert latents to mesh.\"\"\"\n        grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)\n        outputs = self.surface_extractor(grid_logits, **kwargs)\n        return outputs\n\n    def enable_flashvdm_decoder(\n        self,\n        enabled: bool = True,\n        adaptive_kv_selection=True,\n        topk_mode=\"mean\",\n        mc_algo=\"dmc\",\n    ):\n        
\"\"\"Enable or disable FlashVDM decoder for faster inference.\"\"\"\n        if enabled:\n            if adaptive_kv_selection:\n                self.volume_decoder = FlashVDMVolumeDecoding(topk_mode)\n            else:\n                self.volume_decoder = HierarchicalVolumeDecoding()\n            if mc_algo not in SurfaceExtractors:\n                raise ValueError(\n                    f\"Unsupported mc_algo {mc_algo}, available: {list(SurfaceExtractors.keys())}\"\n                )\n            self.surface_extractor = SurfaceExtractors[mc_algo]()\n        else:\n            self.volume_decoder = VanillaVolumeDecoder()\n            self.surface_extractor = MCSurfaceExtractor()\n\n\nclass ShapeVAE(VectsetVAE):\n    \"\"\"Shape VAE for 3D mesh generation from latent codes.\"\"\"\n\n    _aliases = [\"hy3dgen.shapegen.models.ShapeVAE\"]\n\n    def __init__(\n        self,\n        *,\n        num_latents: int,\n        embed_dim: int,\n        width: int,\n        heads: int,\n        num_decoder_layers: int,\n        num_encoder_layers: int = 8,\n        pc_size: int = 5120,\n        pc_sharpedge_size: int = 5120,\n        point_feats: int = 3,\n        downsample_ratio: int = 20,\n        geo_decoder_downsample_ratio: int = 1,\n        geo_decoder_mlp_expand_ratio: int = 4,\n        geo_decoder_ln_post: bool = True,\n        num_freqs: int = 8,\n        include_pi: bool = True,\n        qkv_bias: bool = True,\n        qk_norm: bool = False,\n        label_type: str = \"binary\",\n        drop_path_rate: float = 0.0,\n        scale_factor: float = 1.0,\n        use_ln_post: bool = True,\n        ckpt_path=None,\n    ):\n        super().__init__()\n        self.geo_decoder_ln_post = geo_decoder_ln_post\n        self.downsample_ratio = downsample_ratio\n\n        self.fourier_embedder = FourierEmbedder(\n            num_freqs=num_freqs, include_pi=include_pi\n        )\n\n        self.post_kl = nn.Linear(embed_dim, width)\n\n        self.transformer = 
Transformer(\n            n_ctx=num_latents,\n            width=width,\n            layers=num_decoder_layers,\n            heads=heads,\n            qkv_bias=qkv_bias,\n            qk_norm=qk_norm,\n            drop_path_rate=drop_path_rate,\n        )\n\n        self.geo_decoder = CrossAttentionDecoder(\n            fourier_embedder=self.fourier_embedder,\n            out_channels=1,\n            num_latents=num_latents,\n            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,\n            downsample_ratio=geo_decoder_downsample_ratio,\n            enable_ln_post=self.geo_decoder_ln_post,\n            width=width // geo_decoder_downsample_ratio,\n            heads=heads // geo_decoder_downsample_ratio,\n            qkv_bias=qkv_bias,\n            qk_norm=qk_norm,\n            label_type=label_type,\n        )\n\n        self.scale_factor = scale_factor\n        self.latent_shape = (num_latents, embed_dim)\n\n    def forward(self, latents):\n        latents = self.post_kl(latents)\n        latents = self.transformer(latents)\n        return latents\n\n    def decode(self, latents):\n        \"\"\"Decode latents to features.\"\"\"\n        latents = self.post_kl(latents)\n        latents = self.transformer(latents)\n        return latents\n\n\n# Entry class for model registry\nEntryClass = ShapeVAE\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vaes/hunyuanvae.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from diffusers\n\n# Copyright 2024 The Hunyuan Team, The HuggingFace Team and The sglang-diffusion Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sglang.multimodal_gen.configs.models.vaes import HunyuanVAEConfig\nfrom sglang.multimodal_gen.runtime.layers.activation import get_act_fn\nfrom sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE\n\n\ndef prepare_causal_attention_mask(\n    num_frames: int,\n    height_width: int,\n    dtype: torch.dtype,\n    device: torch.device,\n    batch_size: int | None = None,\n) -> torch.Tensor:\n    indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)\n    indices_blocks = indices.repeat_interleave(height_width)\n    x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing=\"xy\")\n    mask = torch.where(x <= y, 0, -float(\"inf\")).to(dtype=dtype)\n\n    if batch_size is not None:\n        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)\n    return mask\n\n\nclass HunyuanVAEAttention(nn.Module):\n\n    def __init__(\n        self, in_channels, heads, dim_head, eps, norm_num_groups, bias\n    ) -> None:\n        super().__init__()\n        self.in_channels = in_channels\n        self.heads = heads\n        
self.dim_head = dim_head\n        self.eps = eps\n        self.norm_num_groups = norm_num_groups\n        self.bias = bias\n\n        inner_dim = heads * dim_head\n\n        # Define the projection layers\n        self.to_q = nn.Linear(in_channels, inner_dim, bias=bias)\n        self.to_k = nn.Linear(in_channels, inner_dim, bias=bias)\n        self.to_v = nn.Linear(in_channels, inner_dim, bias=bias)\n        self.to_out = nn.Sequential(nn.Linear(inner_dim, in_channels, bias=bias))\n\n        # Optional normalization layers\n        self.group_norm = nn.GroupNorm(\n            norm_num_groups, in_channels, eps=eps, affine=True\n        )\n\n    def forward(\n        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None\n    ) -> torch.Tensor:\n        residual = hidden_states\n\n        batch_size, sequence_length, _ = hidden_states.shape\n\n        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        # Project to query, key, value\n        query = self.to_q(hidden_states)\n        key = self.to_k(hidden_states)\n        value = self.to_v(hidden_states)\n\n        # Reshape for multi-head attention\n        head_dim = self.dim_head\n\n        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)\n        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)\n\n        # Perform scaled dot-product attention\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        # Reshape back\n        hidden_states = hidden_states.transpose(1, 2).reshape(\n            batch_size, -1, self.heads * head_dim\n        )\n        hidden_states = hidden_states.to(query.dtype)\n\n        # Linear projection\n        hidden_states = self.to_out(hidden_states)\n\n        # Residual connection and rescale\n   
     hidden_states = hidden_states + residual\n\n        return hidden_states\n\n\nclass HunyuanVideoCausalConv3d(nn.Module):\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int | tuple[int, int, int] = 3,\n        stride: int | tuple[int, int, int] = 1,\n        padding: int | tuple[int, int, int] = 0,\n        dilation: int | tuple[int, int, int] = 1,\n        bias: bool = True,\n        pad_mode: str = \"replicate\",\n    ) -> None:\n        super().__init__()\n\n        kernel_size = (\n            (kernel_size, kernel_size, kernel_size)\n            if isinstance(kernel_size, int)\n            else kernel_size\n        )\n\n        self.pad_mode = pad_mode\n        self.time_causal_padding = (\n            kernel_size[0] // 2,\n            kernel_size[0] // 2,\n            kernel_size[1] // 2,\n            kernel_size[1] // 2,\n            kernel_size[2] - 1,\n            0,\n        )\n\n        self.conv = nn.Conv3d(\n            in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = F.pad(\n            hidden_states, self.time_causal_padding, mode=self.pad_mode\n        )\n        return self.conv(hidden_states)\n\n\nclass HunyuanVideoUpsampleCausal3D(nn.Module):\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int | None = None,\n        kernel_size: int = 3,\n        stride: int = 1,\n        bias: bool = True,\n        upsample_factor: tuple[int, ...] 
= (2, 2, 2),\n    ) -> None:\n        super().__init__()\n\n        out_channels = out_channels or in_channels\n        self.upsample_factor = upsample_factor\n\n        self.conv = HunyuanVideoCausalConv3d(\n            in_channels, out_channels, kernel_size, stride, bias=bias\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        num_frames = hidden_states.size(2)\n\n        first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2)\n        first_frame = F.interpolate(\n            first_frame.squeeze(2),\n            scale_factor=self.upsample_factor[1:],\n            mode=\"nearest\",\n        ).unsqueeze(2)\n\n        if num_frames > 1:\n            # See: https://github.com/pytorch/pytorch/issues/81665\n            # Unless you have a version of pytorch where non-contiguous implementation of F.interpolate\n            # is fixed, this will raise either a runtime error, or fail silently with bad outputs.\n            # If you are encountering an error here, make sure to try running encoding/decoding with\n            # `vae.enable_tiling()` first. 
If that doesn't work, open an issue at:\n            # https://github.com/huggingface/diffusers/issues\n            other_frames = other_frames.contiguous()\n            other_frames = F.interpolate(\n                other_frames, scale_factor=self.upsample_factor, mode=\"nearest\"\n            )\n            hidden_states = torch.cat((first_frame, other_frames), dim=2)\n        else:\n            hidden_states = first_frame\n\n        hidden_states = self.conv(hidden_states)\n        return hidden_states\n\n\nclass HunyuanVideoDownsampleCausal3D(nn.Module):\n\n    def __init__(\n        self,\n        channels: int,\n        out_channels: int | None = None,\n        padding: int = 1,\n        kernel_size: int = 3,\n        bias: bool = True,\n        stride=2,\n    ) -> None:\n        super().__init__()\n        out_channels = out_channels or channels\n\n        self.conv = HunyuanVideoCausalConv3d(\n            channels, out_channels, kernel_size, stride, padding, bias=bias\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.conv(hidden_states)\n        return hidden_states\n\n\nclass HunyuanVideoResnetBlockCausal3D(nn.Module):\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int | None = None,\n        dropout: float = 0.0,\n        groups: int = 32,\n        eps: float = 1e-6,\n        non_linearity: str = \"silu\",\n    ) -> None:\n        super().__init__()\n        out_channels = out_channels or in_channels\n\n        self.nonlinearity = get_act_fn(non_linearity)\n\n        self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True)\n        self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0)\n\n        self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True)\n        self.dropout = nn.Dropout(dropout)\n        self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0)\n\n        self.conv_shortcut = 
None\n        if in_channels != out_channels:\n            self.conv_shortcut = HunyuanVideoCausalConv3d(\n                in_channels, out_channels, 1, 1, 0\n            )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = hidden_states.contiguous()\n        residual = hidden_states\n\n        hidden_states = self.norm1(hidden_states)\n        hidden_states = self.nonlinearity(hidden_states)\n        hidden_states = self.conv1(hidden_states)\n\n        hidden_states = self.norm2(hidden_states)\n        hidden_states = self.nonlinearity(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.conv2(hidden_states)\n\n        if self.conv_shortcut is not None:\n            residual = self.conv_shortcut(residual)\n\n        hidden_states = hidden_states + residual\n        return hidden_states\n\n\nclass HunyuanVideoMidBlock3D(nn.Module):\n\n    def __init__(\n        self,\n        in_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_act_fn: str = \"silu\",\n        resnet_groups: int = 32,\n        add_attention: bool = True,\n        attention_head_dim: int = 1,\n    ) -> None:\n        super().__init__()\n        resnet_groups = (\n            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)\n        )\n        self.add_attention = add_attention\n\n        # There is always at least one resnet\n        resnets = [\n            HunyuanVideoResnetBlockCausal3D(\n                in_channels=in_channels,\n                out_channels=in_channels,\n                eps=resnet_eps,\n                groups=resnet_groups,\n                dropout=dropout,\n                non_linearity=resnet_act_fn,\n            )\n        ]\n        attentions: list[HunyuanVAEAttention | None] = []\n\n        for _ in range(num_layers):\n            if self.add_attention:\n                
attentions.append(\n                    HunyuanVAEAttention(\n                        in_channels,\n                        heads=in_channels // attention_head_dim,\n                        dim_head=attention_head_dim,\n                        eps=resnet_eps,\n                        norm_num_groups=resnet_groups,\n                        bias=True,\n                    )\n                )\n            else:\n                attentions.append(None)\n\n            resnets.append(\n                HunyuanVideoResnetBlockCausal3D(\n                    in_channels=in_channels,\n                    out_channels=in_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    non_linearity=resnet_act_fn,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        if torch.is_grad_enabled() and self.gradient_checkpointing:\n            hidden_states = self._gradient_checkpointing_func(\n                self.resnets[0], hidden_states\n            )\n\n            for attn, resnet in zip(self.attentions, self.resnets[1:], strict=True):\n                if attn is not None:\n                    batch_size, num_channels, num_frames, height, width = (\n                        hidden_states.shape\n                    )\n                    hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)\n                    attention_mask = prepare_causal_attention_mask(\n                        num_frames,\n                        height * width,\n                        hidden_states.dtype,\n                        hidden_states.device,\n                        batch_size=batch_size,\n                    )\n                    hidden_states = attn(hidden_states, 
attention_mask=attention_mask)\n                    hidden_states = hidden_states.unflatten(\n                        1, (num_frames, height, width)\n                    ).permute(0, 4, 1, 2, 3)\n\n                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)\n\n        else:\n            hidden_states = self.resnets[0](hidden_states)\n\n            for attn, resnet in zip(self.attentions, self.resnets[1:], strict=True):\n                if attn is not None:\n                    batch_size, num_channels, num_frames, height, width = (\n                        hidden_states.shape\n                    )\n                    hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)\n                    attention_mask = prepare_causal_attention_mask(\n                        num_frames,\n                        height * width,\n                        hidden_states.dtype,\n                        hidden_states.device,\n                        batch_size=batch_size,\n                    )\n                    hidden_states = attn(hidden_states, attention_mask=attention_mask)\n                    hidden_states = hidden_states.unflatten(\n                        1, (num_frames, height, width)\n                    ).permute(0, 4, 1, 2, 3)\n\n                hidden_states = resnet(hidden_states)\n\n        return hidden_states\n\n\nclass HunyuanVideoDownBlock3D(nn.Module):\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_act_fn: str = \"silu\",\n        resnet_groups: int = 32,\n        add_downsample: bool = True,\n        downsample_stride: tuple[int, ...] 
| int = 2,\n        downsample_padding: int = 1,\n    ) -> None:\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                HunyuanVideoResnetBlockCausal3D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    non_linearity=resnet_act_fn,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    HunyuanVideoDownsampleCausal3D(\n                        out_channels,\n                        out_channels=out_channels,\n                        padding=downsample_padding,\n                        stride=downsample_stride,\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        if torch.is_grad_enabled() and self.gradient_checkpointing:\n            for resnet in self.resnets:\n                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)\n        else:\n            for resnet in self.resnets:\n                hidden_states = resnet(hidden_states)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n        return hidden_states\n\n\nclass HunyuanVideoUpBlock3D(nn.Module):\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_act_fn: str = \"silu\",\n        
resnet_groups: int = 32,\n        add_upsample: bool = True,\n        upsample_scale_factor: tuple[int, ...] = (2, 2, 2),\n    ) -> None:\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            input_channels = in_channels if i == 0 else out_channels\n\n            resnets.append(\n                HunyuanVideoResnetBlockCausal3D(\n                    in_channels=input_channels,\n                    out_channels=out_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    non_linearity=resnet_act_fn,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList(\n                [\n                    HunyuanVideoUpsampleCausal3D(\n                        out_channels,\n                        out_channels=out_channels,\n                        upsample_factor=upsample_scale_factor,\n                    )\n                ]\n            )\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        if torch.is_grad_enabled() and self.gradient_checkpointing:\n            for resnet in self.resnets:\n                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)\n\n        else:\n            for resnet in self.resnets:\n                hidden_states = resnet(hidden_states)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states)\n\n        return hidden_states\n\n\nclass HunyuanVideoEncoder3D(nn.Module):\n    r\"\"\"\n    Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int = 3,\n  
      out_channels: int = 3,\n        down_block_types: tuple[str, ...] = (\n            \"HunyuanVideoDownBlock3D\",\n            \"HunyuanVideoDownBlock3D\",\n            \"HunyuanVideoDownBlock3D\",\n            \"HunyuanVideoDownBlock3D\",\n        ),\n        block_out_channels: tuple[int, ...] = (128, 256, 512, 512),\n        layers_per_block: int = 2,\n        norm_num_groups: int = 32,\n        act_fn: str = \"silu\",\n        double_z: bool = True,\n        mid_block_add_attention=True,\n        temporal_compression_ratio: int = 4,\n        spatial_compression_ratio: int = 8,\n    ) -> None:\n        super().__init__()\n\n        self.conv_in = HunyuanVideoCausalConv3d(\n            in_channels, block_out_channels[0], kernel_size=3, stride=1\n        )\n        self.mid_block: HunyuanVideoMidBlock3D | None = None\n        self.down_blocks = nn.ModuleList([])\n\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            if down_block_type != \"HunyuanVideoDownBlock3D\":\n                raise ValueError(f\"Unsupported down_block_type: {down_block_type}\")\n\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n            num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))\n            num_time_downsample_layers = int(np.log2(temporal_compression_ratio))\n\n            if temporal_compression_ratio == 4:\n                add_spatial_downsample = bool(i < num_spatial_downsample_layers)\n                add_time_downsample = bool(\n                    i >= (len(block_out_channels) - 1 - num_time_downsample_layers)\n                    and not is_final_block\n                )\n            elif temporal_compression_ratio == 8:\n                add_spatial_downsample = bool(i < num_spatial_downsample_layers)\n                add_time_downsample = bool(i < 
num_time_downsample_layers)\n            else:\n                raise ValueError(\n                    f\"Unsupported time_compression_ratio: {temporal_compression_ratio}\"\n                )\n\n            downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)\n            downsample_stride_T = (2,) if add_time_downsample else (1,)\n            downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)\n\n            down_block = HunyuanVideoDownBlock3D(\n                num_layers=layers_per_block,\n                in_channels=input_channel,\n                out_channels=output_channel,\n                add_downsample=bool(add_spatial_downsample or add_time_downsample),\n                resnet_eps=1e-6,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                downsample_stride=downsample_stride,\n                downsample_padding=0,\n            )\n\n            self.down_blocks.append(down_block)\n\n        self.mid_block = HunyuanVideoMidBlock3D(\n            in_channels=block_out_channels[-1],\n            resnet_eps=1e-6,\n            resnet_act_fn=act_fn,\n            attention_head_dim=block_out_channels[-1],\n            resnet_groups=norm_num_groups,\n            add_attention=mid_block_add_attention,\n        )\n\n        self.conv_norm_out = nn.GroupNorm(\n            num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6\n        )\n        self.conv_act = nn.SiLU()\n\n        conv_out_channels = 2 * out_channels if double_z else out_channels\n        self.conv_out = HunyuanVideoCausalConv3d(\n            block_out_channels[-1], conv_out_channels, kernel_size=3\n        )\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.conv_in(hidden_states)\n\n        if torch.is_grad_enabled() and self.gradient_checkpointing:\n            for down_block in self.down_blocks:\n     
           hidden_states = self._gradient_checkpointing_func(\n                    down_block, hidden_states\n                )\n\n            hidden_states = self._gradient_checkpointing_func(\n                self.mid_block, hidden_states\n            )\n        else:\n            for down_block in self.down_blocks:\n                hidden_states = down_block(hidden_states)\n            assert self.mid_block is not None\n            hidden_states = self.mid_block(hidden_states)\n\n        hidden_states = self.conv_norm_out(hidden_states)\n        hidden_states = self.conv_act(hidden_states)\n        hidden_states = self.conv_out(hidden_states)\n\n        return hidden_states\n\n\nclass HunyuanVideoDecoder3D(nn.Module):\n    r\"\"\"\n    Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int = 3,\n        out_channels: int = 3,\n        up_block_types: tuple[str, ...] = (\n            \"HunyuanVideoUpBlock3D\",\n            \"HunyuanVideoUpBlock3D\",\n            \"HunyuanVideoUpBlock3D\",\n            \"HunyuanVideoUpBlock3D\",\n        ),\n        block_out_channels: tuple[int, ...] 
= (128, 256, 512, 512),\n        layers_per_block: int = 2,\n        norm_num_groups: int = 32,\n        act_fn: str = \"silu\",\n        mid_block_add_attention=True,\n        time_compression_ratio: int = 4,\n        spatial_compression_ratio: int = 8,\n    ):\n        super().__init__()\n        self.layers_per_block = layers_per_block\n\n        self.conv_in = HunyuanVideoCausalConv3d(\n            in_channels, block_out_channels[-1], kernel_size=3, stride=1\n        )\n        self.up_blocks = nn.ModuleList([])\n\n        # mid\n        self.mid_block = HunyuanVideoMidBlock3D(\n            in_channels=block_out_channels[-1],\n            resnet_eps=1e-6,\n            resnet_act_fn=act_fn,\n            attention_head_dim=block_out_channels[-1],\n            resnet_groups=norm_num_groups,\n            add_attention=mid_block_add_attention,\n        )\n\n        # up\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            if up_block_type != \"HunyuanVideoUpBlock3D\":\n                raise ValueError(f\"Unsupported up_block_type: {up_block_type}\")\n\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n            num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))\n            num_time_upsample_layers = int(np.log2(time_compression_ratio))\n\n            if time_compression_ratio == 4:\n                add_spatial_upsample = bool(i < num_spatial_upsample_layers)\n                add_time_upsample = bool(\n                    i >= len(block_out_channels) - 1 - num_time_upsample_layers\n                    and not is_final_block\n                )\n            else:\n                raise ValueError(\n                    f\"Unsupported time_compression_ratio: 
{time_compression_ratio}\"\n                )\n\n            upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)\n            upsample_scale_factor_T = (2,) if add_time_upsample else (1,)\n            upsample_scale_factor = tuple(\n                upsample_scale_factor_T + upsample_scale_factor_HW\n            )\n\n            up_block = HunyuanVideoUpBlock3D(\n                num_layers=self.layers_per_block + 1,\n                in_channels=prev_output_channel,\n                out_channels=output_channel,\n                add_upsample=bool(add_spatial_upsample or add_time_upsample),\n                upsample_scale_factor=upsample_scale_factor,\n                resnet_eps=1e-6,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n            )\n\n            self.up_blocks.append(up_block)\n            prev_output_channel = output_channel\n\n        # out\n        self.conv_norm_out = nn.GroupNorm(\n            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6\n        )\n        self.conv_act = nn.SiLU()\n        self.conv_out = HunyuanVideoCausalConv3d(\n            block_out_channels[0], out_channels, kernel_size=3\n        )\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.conv_in(hidden_states)\n\n        if torch.is_grad_enabled() and self.gradient_checkpointing:\n            hidden_states = self._gradient_checkpointing_func(\n                self.mid_block, hidden_states\n            )\n\n            for up_block in self.up_blocks:\n                hidden_states = self._gradient_checkpointing_func(\n                    up_block, hidden_states\n                )\n        else:\n            hidden_states = self.mid_block(hidden_states)\n\n            for up_block in self.up_blocks:\n                hidden_states = up_block(hidden_states)\n\n        # post-process\n        hidden_states 
= self.conv_norm_out(hidden_states)\n        hidden_states = self.conv_act(hidden_states)\n        hidden_states = self.conv_out(hidden_states)\n\n        return hidden_states\n\n\nclass AutoencoderKLHunyuanVideo(ParallelTiledVAE):\n    r\"\"\"\n    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.\n    Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).\n\n    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented\n    for all models (such as downloading or saving).\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    def __init__(\n        self,\n        config: HunyuanVAEConfig,\n    ) -> None:\n        nn.Module.__init__(self)\n        ParallelTiledVAE.__init__(self, config)\n\n        # TODO(will): only pass in config. We do this by manually defining a\n        # config for hunyuan vae\n        self.block_out_channels = config.block_out_channels\n\n        if config.load_encoder:\n            self.encoder = HunyuanVideoEncoder3D(\n                in_channels=config.in_channels,\n                out_channels=config.latent_channels,\n                down_block_types=config.down_block_types,\n                block_out_channels=config.block_out_channels,\n                layers_per_block=config.layers_per_block,\n                norm_num_groups=config.norm_num_groups,\n                act_fn=config.act_fn,\n                double_z=True,\n                mid_block_add_attention=config.mid_block_add_attention,\n                temporal_compression_ratio=config.temporal_compression_ratio,\n                spatial_compression_ratio=config.spatial_compression_ratio,\n            )\n            self.quant_conv = nn.Conv3d(\n                2 * config.latent_channels, 2 * config.latent_channels, kernel_size=1\n            )\n\n        if config.load_decoder:\n            self.decoder = HunyuanVideoDecoder3D(\n                
in_channels=config.latent_channels,\n                out_channels=config.out_channels,\n                up_block_types=config.up_block_types,\n                block_out_channels=config.block_out_channels,\n                layers_per_block=config.layers_per_block,\n                norm_num_groups=config.norm_num_groups,\n                act_fn=config.act_fn,\n                time_compression_ratio=config.temporal_compression_ratio,\n                spatial_compression_ratio=config.spatial_compression_ratio,\n                mid_block_add_attention=config.mid_block_add_attention,\n            )\n            self.post_quant_conv = nn.Conv3d(\n                config.latent_channels, config.latent_channels, kernel_size=1\n            )\n\n    def _encode(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.encoder(x)\n        enc = self.quant_conv(x)\n        return enc\n\n    def _decode(self, z: torch.Tensor) -> torch.Tensor:\n        z = self.post_quant_conv(z)\n        dec = self.decoder(z)\n        return dec\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        sample_posterior: bool = False,\n        generator: torch.Generator | None = None,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Args:\n            sample (`torch.Tensor`): Input sample.\n            sample_posterior (`bool`, *optional*, defaults to `False`):\n                Whether to sample from the posterior.\n        \"\"\"\n        x = sample\n        posterior = self.encode(x).latent_dist\n        if sample_posterior:\n            z = posterior.sample(generator=generator)\n        else:\n            z = posterior.mode()\n        dec = self.decode(z)\n        return dec\n\n\nEntryClass = AutoencoderKLHunyuanVideo\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vaes/ltx_2_audio.py",
    "content": "from typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom diffusers.models.autoencoders.vae import (\n    DecoderOutput,\n    DiagonalGaussianDistribution,\n)\nfrom diffusers.models.modeling_outputs import AutoencoderKLOutput\nfrom torch import nn\n\nfrom sglang.multimodal_gen.configs.models.vaes.ltx_audio import LTXAudioVAEConfig\nfrom sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE\n\nLATENT_DOWNSAMPLE_FACTOR = 4\n\n\nclass LTX2AudioCausalConv2d(nn.Module):\n    \"\"\"\n    A causal 2D convolution that pads asymmetrically along the causal axis.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, Tuple[int, int]],\n        stride: int = 1,\n        dilation: Union[int, Tuple[int, int]] = 1,\n        groups: int = 1,\n        bias: bool = True,\n        causality_axis: str = \"height\",\n    ) -> None:\n        super().__init__()\n\n        self.causality_axis = causality_axis\n        kernel_size = (\n            (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size\n        )\n        dilation = (dilation, dilation) if isinstance(dilation, int) else dilation\n\n        pad_h = (kernel_size[0] - 1) * dilation[0]\n        pad_w = (kernel_size[1] - 1) * dilation[1]\n\n        if self.causality_axis == \"none\":\n            padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)\n        elif self.causality_axis in {\"width\", \"width-compatibility\"}:\n            padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)\n        elif self.causality_axis == \"height\":\n            padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)\n        else:\n            raise ValueError(f\"Invalid causality_axis: {causality_axis}\")\n\n        self.padding = padding\n        self.conv = nn.Conv2d(\n            in_channels,\n            out_channels,\n            kernel_size,\n  
          stride=stride,\n            padding=0,\n            dilation=dilation,\n            groups=groups,\n            bias=bias,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = F.pad(x, self.padding)\n        return self.conv(x)\n\n\nclass LTX2AudioPixelNorm(nn.Module):\n    \"\"\"\n    Per-pixel (per-location) RMS normalization layer.\n    \"\"\"\n\n    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:\n        super().__init__()\n        self.dim = dim\n        self.eps = eps\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)\n        rms = torch.sqrt(mean_sq + self.eps)\n        return x / rms\n\n\nclass LTX2AudioAttnBlock(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        norm_type: str = \"group\",\n    ) -> None:\n        super().__init__()\n        self.in_channels = in_channels\n\n        if norm_type == \"group\":\n            self.norm = nn.GroupNorm(\n                num_groups=32, num_channels=in_channels, eps=1e-6, affine=True\n            )\n        elif norm_type == \"pixel\":\n            self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)\n        else:\n            raise ValueError(f\"Invalid normalization type: {norm_type}\")\n        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)\n        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)\n        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)\n        self.proj_out = nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        h_ = self.norm(x)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        batch, channels, height, width = q.shape\n        q = q.reshape(batch, channels, height * width).permute(0, 2, 
1).contiguous()\n        k = k.reshape(batch, channels, height * width).contiguous()\n        attn = torch.bmm(q, k) * (int(channels) ** (-0.5))\n        attn = torch.nn.functional.softmax(attn, dim=2)\n\n        v = v.reshape(batch, channels, height * width)\n        attn = attn.permute(0, 2, 1).contiguous()\n        h_ = torch.bmm(v, attn).reshape(batch, channels, height, width)\n\n        h_ = self.proj_out(h_)\n        return x + h_\n\n\nclass LTX2AudioResnetBlock(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: Optional[int] = None,\n        conv_shortcut: bool = False,\n        dropout: float = 0.0,\n        temb_channels: int = 512,\n        norm_type: str = \"group\",\n        causality_axis: str = \"height\",\n    ) -> None:\n        super().__init__()\n        self.causality_axis = causality_axis\n\n        if (\n            self.causality_axis is not None\n            and self.causality_axis != \"none\"\n            and norm_type == \"group\"\n        ):\n            raise ValueError(\"Causal ResnetBlock with GroupNorm is not supported.\")\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n        self.use_conv_shortcut = conv_shortcut\n\n        if norm_type == \"group\":\n            self.norm1 = nn.GroupNorm(\n                num_groups=32, num_channels=in_channels, eps=1e-6, affine=True\n            )\n        elif norm_type == \"pixel\":\n            self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)\n        else:\n            raise ValueError(f\"Invalid normalization type: {norm_type}\")\n        self.non_linearity = nn.SiLU()\n        if causality_axis is not None:\n            self.conv1 = LTX2AudioCausalConv2d(\n                in_channels,\n                out_channels,\n                kernel_size=3,\n                stride=1,\n                causality_axis=causality_axis,\n            )\n 
       else:\n            self.conv1 = nn.Conv2d(\n                in_channels, out_channels, kernel_size=3, stride=1, padding=1\n            )\n        if temb_channels > 0:\n            self.temb_proj = nn.Linear(temb_channels, out_channels)\n        if norm_type == \"group\":\n            self.norm2 = nn.GroupNorm(\n                num_groups=32, num_channels=out_channels, eps=1e-6, affine=True\n            )\n        elif norm_type == \"pixel\":\n            self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)\n        else:\n            raise ValueError(f\"Invalid normalization type: {norm_type}\")\n        self.dropout = nn.Dropout(dropout)\n        if causality_axis is not None:\n            self.conv2 = LTX2AudioCausalConv2d(\n                out_channels,\n                out_channels,\n                kernel_size=3,\n                stride=1,\n                causality_axis=causality_axis,\n            )\n        else:\n            self.conv2 = nn.Conv2d(\n                out_channels, out_channels, kernel_size=3, stride=1, padding=1\n            )\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                if causality_axis is not None:\n                    self.conv_shortcut = LTX2AudioCausalConv2d(\n                        in_channels,\n                        out_channels,\n                        kernel_size=3,\n                        stride=1,\n                        causality_axis=causality_axis,\n                    )\n                else:\n                    self.conv_shortcut = nn.Conv2d(\n                        in_channels, out_channels, kernel_size=3, stride=1, padding=1\n                    )\n            else:\n                if causality_axis is not None:\n                    self.nin_shortcut = LTX2AudioCausalConv2d(\n                        in_channels,\n                        out_channels,\n                        kernel_size=1,\n                        stride=1,\n                        
causality_axis=causality_axis,\n                    )\n                else:\n                    self.nin_shortcut = nn.Conv2d(\n                        in_channels, out_channels, kernel_size=1, stride=1, padding=0\n                    )\n\n    def forward(\n        self, x: torch.Tensor, temb: Optional[torch.Tensor] = None\n    ) -> torch.Tensor:\n        h = self.norm1(x)\n        h = self.non_linearity(h)\n        h = self.conv1(h)\n\n        if temb is not None:\n            h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]\n\n        h = self.norm2(h)\n        h = self.non_linearity(h)\n        h = self.dropout(h)\n        h = self.conv2(h)\n\n        if self.in_channels != self.out_channels:\n            x = (\n                self.conv_shortcut(x)\n                if self.use_conv_shortcut\n                else self.nin_shortcut(x)\n            )\n\n        return x + h\n\n\nclass LTX2AudioDownsample(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        with_conv: bool,\n        causality_axis: Optional[str] = \"height\",\n    ) -> None:\n        super().__init__()\n        self.with_conv = with_conv\n        self.causality_axis = causality_axis\n\n        if self.with_conv:\n            self.conv = torch.nn.Conv2d(\n                in_channels, in_channels, kernel_size=3, stride=2, padding=0\n            )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.with_conv:\n            # Padding tuple is in the order: (left, right, top, bottom).\n            if self.causality_axis == \"none\":\n                pad = (0, 1, 0, 1)\n            elif self.causality_axis == \"width\":\n                pad = (2, 0, 0, 1)\n            elif self.causality_axis == \"height\":\n                pad = (0, 1, 2, 0)\n            elif self.causality_axis == \"width-compatibility\":\n                pad = (1, 0, 0, 1)\n            else:\n                raise ValueError(\n                    f\"Invalid 
`causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,\"\n                    f\" and `width-compatibility`.\"\n                )\n\n            x = F.pad(x, pad, mode=\"constant\", value=0)\n            x = self.conv(x)\n        else:\n            # with_conv=False implies that causality_axis is \"none\"\n            x = F.avg_pool2d(x, kernel_size=2, stride=2)\n        return x\n\n\nclass LTX2AudioUpsample(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        with_conv: bool,\n        causality_axis: Optional[str] = \"height\",\n    ) -> None:\n        super().__init__()\n        self.with_conv = with_conv\n        self.causality_axis = causality_axis\n        if self.with_conv:\n            if causality_axis is not None:\n                self.conv = LTX2AudioCausalConv2d(\n                    in_channels,\n                    in_channels,\n                    kernel_size=3,\n                    stride=1,\n                    causality_axis=causality_axis,\n                )\n            else:\n                self.conv = nn.Conv2d(\n                    in_channels, in_channels, kernel_size=3, stride=1, padding=1\n                )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode=\"nearest\")\n        if self.with_conv:\n            x = self.conv(x)\n            if self.causality_axis is None or self.causality_axis == \"none\":\n                pass\n            elif self.causality_axis == \"height\":\n                x = x[:, :, 1:, :]\n            elif self.causality_axis == \"width\":\n                x = x[:, :, :, 1:]\n            elif self.causality_axis == \"width-compatibility\":\n                pass\n            else:\n                raise ValueError(f\"Invalid causality_axis: {self.causality_axis}\")\n\n        return x\n\n\nclass LTX2AudioAudioPatchifier:\n    \"\"\"\n    Patchifier for spectrogram/audio 
latents.\n    \"\"\"\n\n    def __init__(\n        self,\n        patch_size: int,\n        sample_rate: int = 16000,\n        hop_length: int = 160,\n        audio_latent_downsample_factor: int = 4,\n        is_causal: bool = True,\n    ):\n        self.hop_length = hop_length\n        self.sample_rate = sample_rate\n        self.audio_latent_downsample_factor = audio_latent_downsample_factor\n        self.is_causal = is_causal\n        self._patch_size = (1, patch_size, patch_size)\n\n    def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:\n        batch, channels, time, freq = audio_latents.shape\n        return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)\n\n    def unpatchify(\n        self, audio_latents: torch.Tensor, channels: int, mel_bins: int\n    ) -> torch.Tensor:\n        batch, time, _ = audio_latents.shape\n        return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)\n\n    @property\n    def patch_size(self) -> Tuple[int, int, int]:\n        return self._patch_size\n\n\nclass LTX2AudioEncoder(nn.Module):\n    def __init__(\n        self,\n        base_channels: int = 128,\n        output_channels: int = 1,\n        num_res_blocks: int = 2,\n        attn_resolutions: Optional[Tuple[int, ...]] = None,\n        in_channels: int = 2,\n        resolution: int = 256,\n        latent_channels: int = 8,\n        ch_mult: Tuple[int, ...] 
= (1, 2, 4),\n        norm_type: str = \"group\",\n        causality_axis: Optional[str] = \"width\",\n        dropout: float = 0.0,\n        mid_block_add_attention: bool = False,\n        sample_rate: int = 16000,\n        mel_hop_length: int = 160,\n        is_causal: bool = True,\n        mel_bins: Optional[int] = 64,\n        double_z: bool = True,\n    ):\n        super().__init__()\n\n        self.sample_rate = sample_rate\n        self.mel_hop_length = mel_hop_length\n        self.is_causal = is_causal\n        self.mel_bins = mel_bins\n\n        self.base_channels = base_channels\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n        self.out_ch = output_channels\n        self.give_pre_end = False\n        self.tanh_out = False\n        self.norm_type = norm_type\n        self.latent_channels = latent_channels\n        self.channel_multipliers = ch_mult\n        self.attn_resolutions = attn_resolutions\n        self.causality_axis = causality_axis\n\n        base_block_channels = base_channels\n        base_resolution = resolution\n        self.z_shape = (1, latent_channels, base_resolution, base_resolution)\n\n        if self.causality_axis is not None:\n            self.conv_in = LTX2AudioCausalConv2d(\n                in_channels,\n                base_block_channels,\n                kernel_size=3,\n                stride=1,\n                causality_axis=self.causality_axis,\n            )\n        else:\n            self.conv_in = nn.Conv2d(\n                in_channels, base_block_channels, kernel_size=3, stride=1, padding=1\n            )\n\n        self.down = nn.ModuleList()\n        block_in = base_block_channels\n        curr_res = self.resolution\n\n        for level in range(self.num_resolutions):\n            stage = nn.Module()\n            stage.block = nn.ModuleList()\n            stage.attn 
= nn.ModuleList()\n            block_out = self.base_channels * self.channel_multipliers[level]\n\n            for _ in range(self.num_res_blocks):\n                stage.block.append(\n                    LTX2AudioResnetBlock(\n                        in_channels=block_in,\n                        out_channels=block_out,\n                        temb_channels=self.temb_ch,\n                        dropout=dropout,\n                        norm_type=self.norm_type,\n                        causality_axis=self.causality_axis,\n                    )\n                )\n                block_in = block_out\n                if self.attn_resolutions:\n                    if curr_res in self.attn_resolutions:\n                        stage.attn.append(\n                            LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)\n                        )\n\n            if level != self.num_resolutions - 1:\n                stage.downsample = LTX2AudioDownsample(\n                    block_in, True, causality_axis=self.causality_axis\n                )\n                curr_res = curr_res // 2\n\n            self.down.append(stage)\n\n        self.mid = nn.Module()\n        self.mid.block_1 = LTX2AudioResnetBlock(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n            norm_type=self.norm_type,\n            causality_axis=self.causality_axis,\n        )\n        if mid_block_add_attention:\n            self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)\n        else:\n            self.mid.attn_1 = nn.Identity()\n        self.mid.block_2 = LTX2AudioResnetBlock(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n            norm_type=self.norm_type,\n            causality_axis=self.causality_axis,\n        )\n\n        final_block_channels = block_in\n       
 z_channels = 2 * latent_channels if double_z else latent_channels\n        if self.norm_type == \"group\":\n            self.norm_out = nn.GroupNorm(\n                num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True\n            )\n        elif self.norm_type == \"pixel\":\n            self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)\n        else:\n            raise ValueError(f\"Invalid normalization type: {self.norm_type}\")\n        self.non_linearity = nn.SiLU()\n\n        if self.causality_axis is not None:\n            self.conv_out = LTX2AudioCausalConv2d(\n                final_block_channels,\n                z_channels,\n                kernel_size=3,\n                stride=1,\n                causality_axis=self.causality_axis,\n            )\n        else:\n            self.conv_out = nn.Conv2d(\n                final_block_channels, z_channels, kernel_size=3, stride=1, padding=1\n            )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        # hidden_states expected shape: (batch_size, channels, time, num_mel_bins)\n        hidden_states = self.conv_in(hidden_states)\n\n        for level in range(self.num_resolutions):\n            stage = self.down[level]\n            for block_idx, block in enumerate(stage.block):\n                hidden_states = block(hidden_states, temb=None)\n                if stage.attn:\n                    hidden_states = stage.attn[block_idx](hidden_states)\n\n            if level != self.num_resolutions - 1 and hasattr(stage, \"downsample\"):\n                hidden_states = stage.downsample(hidden_states)\n\n        hidden_states = self.mid.block_1(hidden_states, temb=None)\n        hidden_states = self.mid.attn_1(hidden_states)\n        hidden_states = self.mid.block_2(hidden_states, temb=None)\n\n        hidden_states = self.norm_out(hidden_states)\n        hidden_states = self.non_linearity(hidden_states)\n        hidden_states = self.conv_out(hidden_states)\n\n 
       return hidden_states\n\n\nclass LTX2AudioDecoder(nn.Module):\n    \"\"\"\n    Symmetric decoder that reconstructs audio spectrograms from latent features.\n\n    The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal\n    convolutions.\n    \"\"\"\n\n    def __init__(\n        self,\n        base_channels: int = 128,\n        output_channels: int = 1,\n        num_res_blocks: int = 2,\n        attn_resolutions: Optional[Tuple[int, ...]] = None,\n        in_channels: int = 2,\n        resolution: int = 256,\n        latent_channels: int = 8,\n        ch_mult: Tuple[int, ...] = (1, 2, 4),\n        norm_type: str = \"group\",\n        causality_axis: Optional[str] = \"width\",\n        dropout: float = 0.0,\n        mid_block_add_attention: bool = False,\n        sample_rate: int = 16000,\n        mel_hop_length: int = 160,\n        is_causal: bool = True,\n        mel_bins: Optional[int] = 64,\n    ) -> None:\n        super().__init__()\n\n        self.sample_rate = sample_rate\n        self.mel_hop_length = mel_hop_length\n        self.is_causal = is_causal\n        self.mel_bins = mel_bins\n        self.patchifier = LTX2AudioAudioPatchifier(\n            patch_size=1,\n            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,\n            sample_rate=sample_rate,\n            hop_length=mel_hop_length,\n            is_causal=is_causal,\n        )\n\n        self.base_channels = base_channels\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n        self.out_ch = output_channels\n        self.give_pre_end = False\n        self.tanh_out = False\n        self.norm_type = norm_type\n        self.latent_channels = latent_channels\n        self.channel_multipliers = ch_mult\n        self.attn_resolutions = attn_resolutions\n        self.causality_axis 
= causality_axis\n\n        base_block_channels = base_channels * self.channel_multipliers[-1]\n        base_resolution = resolution // (2 ** (self.num_resolutions - 1))\n        self.z_shape = (1, latent_channels, base_resolution, base_resolution)\n\n        if self.causality_axis is not None:\n            self.conv_in = LTX2AudioCausalConv2d(\n                latent_channels,\n                base_block_channels,\n                kernel_size=3,\n                stride=1,\n                causality_axis=self.causality_axis,\n            )\n        else:\n            self.conv_in = nn.Conv2d(\n                latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1\n            )\n        self.non_linearity = nn.SiLU()\n        self.mid = nn.Module()\n        self.mid.block_1 = LTX2AudioResnetBlock(\n            in_channels=base_block_channels,\n            out_channels=base_block_channels,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n            norm_type=self.norm_type,\n            causality_axis=self.causality_axis,\n        )\n        if mid_block_add_attention:\n            self.mid.attn_1 = LTX2AudioAttnBlock(\n                base_block_channels, norm_type=self.norm_type\n            )\n        else:\n            self.mid.attn_1 = nn.Identity()\n        self.mid.block_2 = LTX2AudioResnetBlock(\n            in_channels=base_block_channels,\n            out_channels=base_block_channels,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n            norm_type=self.norm_type,\n            causality_axis=self.causality_axis,\n        )\n\n        self.up = nn.ModuleList()\n        block_in = base_block_channels\n        curr_res = self.resolution // (2 ** (self.num_resolutions - 1))\n\n        for level in reversed(range(self.num_resolutions)):\n            stage = nn.Module()\n            stage.block = nn.ModuleList()\n            stage.attn = nn.ModuleList()\n            block_out = 
self.base_channels * self.channel_multipliers[level]\n\n            for _ in range(self.num_res_blocks + 1):\n                stage.block.append(\n                    LTX2AudioResnetBlock(\n                        in_channels=block_in,\n                        out_channels=block_out,\n                        temb_channels=self.temb_ch,\n                        dropout=dropout,\n                        norm_type=self.norm_type,\n                        causality_axis=self.causality_axis,\n                    )\n                )\n                block_in = block_out\n                if self.attn_resolutions:\n                    if curr_res in self.attn_resolutions:\n                        stage.attn.append(\n                            LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)\n                        )\n\n            if level != 0:\n                stage.upsample = LTX2AudioUpsample(\n                    block_in, True, causality_axis=self.causality_axis\n                )\n                curr_res *= 2\n\n            self.up.insert(0, stage)\n\n        final_block_channels = block_in\n\n        if self.norm_type == \"group\":\n            self.norm_out = nn.GroupNorm(\n                num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True\n            )\n        elif self.norm_type == \"pixel\":\n            self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)\n        else:\n            raise ValueError(f\"Invalid normalization type: {self.norm_type}\")\n\n        if self.causality_axis is not None:\n            self.conv_out = LTX2AudioCausalConv2d(\n                final_block_channels,\n                output_channels,\n                kernel_size=3,\n                stride=1,\n                causality_axis=self.causality_axis,\n            )\n        else:\n            self.conv_out = nn.Conv2d(\n                final_block_channels,\n                output_channels,\n                kernel_size=3,\n                stride=1,\n  
              padding=1,\n            )\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n    ) -> torch.Tensor:\n        _, _, frames, mel_bins = sample.shape\n\n        target_frames = frames * LATENT_DOWNSAMPLE_FACTOR\n\n        if self.causality_axis is not None:\n            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)\n\n        target_channels = self.out_ch\n        target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins\n\n        hidden_features = self.conv_in(sample)\n        hidden_features = self.mid.block_1(hidden_features, temb=None)\n        hidden_features = self.mid.attn_1(hidden_features)\n        hidden_features = self.mid.block_2(hidden_features, temb=None)\n\n        for level in reversed(range(self.num_resolutions)):\n            stage = self.up[level]\n            for block_idx, block in enumerate(stage.block):\n                hidden_features = block(hidden_features, temb=None)\n                if stage.attn:\n                    hidden_features = stage.attn[block_idx](hidden_features)\n\n            if level != 0 and hasattr(stage, \"upsample\"):\n                hidden_features = stage.upsample(hidden_features)\n\n        if self.give_pre_end:\n            return hidden_features\n\n        hidden = self.norm_out(hidden_features)\n        hidden = self.non_linearity(hidden)\n        decoded_output = self.conv_out(hidden)\n        decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output\n\n        _, _, current_time, current_freq = decoded_output.shape\n        target_time = target_frames\n        target_freq = target_mel_bins\n\n        decoded_output = decoded_output[\n            :,\n            :target_channels,\n            : min(current_time, target_time),\n            : min(current_freq, target_freq),\n        ]\n\n        time_padding_needed = target_time - decoded_output.shape[2]\n        freq_padding_needed = target_freq - 
decoded_output.shape[3]\n\n        if time_padding_needed > 0 or freq_padding_needed > 0:\n            padding = (\n                0,\n                max(freq_padding_needed, 0),\n                0,\n                max(time_padding_needed, 0),\n            )\n            decoded_output = F.pad(decoded_output, padding)\n\n        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]\n\n        return decoded_output\n\n\nclass AutoencoderKLLTX2Audio(ParallelTiledVAE):\n    r\"\"\"\n    LTX2 audio VAE for encoding and decoding audio latent representations.\n    \"\"\"\n\n    _supports_gradient_checkpointing = False\n\n    def __init__(\n        self,\n        config: LTXAudioVAEConfig,\n    ) -> None:\n        super().__init__(config=config)\n\n        causality_axis = config.arch_config.causality_axis\n        attn_resolutions = config.arch_config.attn_resolutions\n        base_channels = config.arch_config.base_channels\n        output_channels = config.arch_config.output_channels\n        ch_mult = config.arch_config.ch_mult\n        num_res_blocks = config.arch_config.num_res_blocks\n        in_channels = config.arch_config.in_channels\n        resolution = config.arch_config.resolution\n        latent_channels = config.arch_config.latent_channels\n        norm_type = config.arch_config.norm_type\n        dropout = config.arch_config.dropout\n        mid_block_add_attention = config.arch_config.mid_block_add_attention\n        sample_rate = config.arch_config.sample_rate\n        mel_hop_length = config.arch_config.mel_hop_length\n        is_causal = config.arch_config.is_causal\n        mel_bins = config.arch_config.mel_bins\n        double_z = config.arch_config.double_z\n\n        supported_causality_axes = {\"none\", \"width\", \"height\", \"width-compatibility\"}\n        if causality_axis not in supported_causality_axes:\n            raise ValueError(\n                f\"{causality_axis=} is not valid. 
Supported values: {supported_causality_axes}\"\n            )\n\n        attn_resolution_set = (\n            set(attn_resolutions) if attn_resolutions else attn_resolutions\n        )\n\n        self.encoder = LTX2AudioEncoder(\n            base_channels=base_channels,\n            output_channels=output_channels,\n            ch_mult=ch_mult,\n            num_res_blocks=num_res_blocks,\n            attn_resolutions=attn_resolution_set,\n            in_channels=in_channels,\n            resolution=resolution,\n            latent_channels=latent_channels,\n            norm_type=norm_type,\n            causality_axis=causality_axis,\n            dropout=dropout,\n            mid_block_add_attention=mid_block_add_attention,\n            sample_rate=sample_rate,\n            mel_hop_length=mel_hop_length,\n            is_causal=is_causal,\n            mel_bins=mel_bins,\n            double_z=double_z,\n        )\n\n        self.decoder = LTX2AudioDecoder(\n            base_channels=base_channels,\n            output_channels=output_channels,\n            ch_mult=ch_mult,\n            num_res_blocks=num_res_blocks,\n            attn_resolutions=attn_resolution_set,\n            in_channels=in_channels,\n            resolution=resolution,\n            latent_channels=latent_channels,\n            norm_type=norm_type,\n            causality_axis=causality_axis,\n            dropout=dropout,\n            mid_block_add_attention=mid_block_add_attention,\n            sample_rate=sample_rate,\n            mel_hop_length=mel_hop_length,\n            is_causal=is_causal,\n            mel_bins=mel_bins,\n        )\n\n        # Per-channel statistics for normalizing and denormalizing the latent representation. 
These statistics are computed over\n        # the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict\n        latents_std = torch.zeros((base_channels,))\n        latents_mean = torch.ones((base_channels,))\n        self.register_buffer(\"latents_mean\", latents_mean, persistent=True)\n        self.register_buffer(\"latents_std\", latents_std, persistent=True)\n\n        # TODO: confirm whether the mel compression ratio below is correct\n        self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR\n        self.use_slicing = False\n\n    def _encode(self, x: torch.Tensor) -> torch.Tensor:\n        return self.encoder(x)\n\n    def encode(self, x: torch.Tensor, return_dict: bool = True):\n        if self.use_slicing and x.shape[0] > 1:\n            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]\n            h = torch.cat(encoded_slices)\n        else:\n            h = self._encode(x)\n        posterior = DiagonalGaussianDistribution(h)\n\n        if not return_dict:\n            return (posterior,)\n        return AutoencoderKLOutput(latent_dist=posterior)\n\n    def _decode(self, z: torch.Tensor) -> torch.Tensor:\n        return self.decoder(z)\n\n    def decode(\n        self, z: torch.Tensor, return_dict: bool = True\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        if self.use_slicing and z.shape[0] > 1:\n            decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]\n            decoded = torch.cat(decoded_slices)\n        else:\n            decoded = self._decode(z)\n\n        if not return_dict:\n            return (decoded,)\n\n        return DecoderOutput(sample=decoded)\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        sample_posterior: bool = False,\n        return_dict: bool = True,\n        generator: Optional[torch.Generator] = None,\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        posterior = self.encode(sample).latent_dist\n        if sample_posterior:\n     
       z = posterior.sample(generator=generator)\n        else:\n            z = posterior.mode()\n        dec = self.decode(z)\n        if not return_dict:\n            return (dec.sample,)\n        return dec\n\n\nEntryClass = AutoencoderKLLTX2Audio\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vaes/ltx_2_vae.py",
    "content": "from typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nfrom diffusers.models.activations import get_activation\nfrom diffusers.models.autoencoders.vae import (\n    DecoderOutput,\n    DiagonalGaussianDistribution,\n)\nfrom diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings\nfrom diffusers.models.modeling_outputs import AutoencoderKLOutput\n\nfrom sglang.multimodal_gen.configs.models.vaes.ltx_video import LTXVideoVAEConfig\nfrom sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE\n\n\nclass PerChannelRMSNorm(nn.Module):\n    \"\"\"\n    Per-pixel (per-location) RMS normalization layer.\n\n    At each location, this layer normalizes the tensor by the root-mean-square of its values across the chosen\n    dimension (typically channels):\n\n        y = x / sqrt(mean(x^2, dim=channel_dim, keepdim=True) + eps)\n    \"\"\"\n\n    def __init__(self, channel_dim: int = 1, eps: float = 1e-8) -> None:\n        \"\"\"\n        Args:\n            channel_dim: Dimension along which to compute the RMS (typically channels).\n            eps: Small constant added for numerical stability.\n        \"\"\"\n        super().__init__()\n        self.channel_dim = channel_dim\n        self.eps = eps\n\n    def forward(\n        self, x: torch.Tensor, channel_dim: Optional[int] = None\n    ) -> torch.Tensor:\n        \"\"\"\n        Apply RMS normalization along `channel_dim`, falling back to the dimension configured at construction.\n        \"\"\"\n        if channel_dim is None:\n            channel_dim = self.channel_dim\n        # Compute mean of squared values along `channel_dim`, keep dimensions for broadcasting.\n        mean_sq = torch.mean(x**2, dim=channel_dim, keepdim=True)\n        # Normalize by the root-mean-square (RMS).\n        rms = torch.sqrt(mean_sq + self.eps)\n        return x / rms\n\n\n# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime\nclass LTX2VideoCausalConv3d(nn.Module):\n    def __init__(\n        self,\n        
in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, Tuple[int, int, int]] = 3,\n        stride: Union[int, Tuple[int, int, int]] = 1,\n        dilation: Union[int, Tuple[int, int, int]] = 1,\n        groups: int = 1,\n        spatial_padding_mode: str = \"zeros\",\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.kernel_size = (\n            kernel_size\n            if isinstance(kernel_size, tuple)\n            else (kernel_size, kernel_size, kernel_size)\n        )\n\n        dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1)\n        stride = stride if isinstance(stride, tuple) else (stride, stride, stride)\n        height_pad = self.kernel_size[1] // 2\n        width_pad = self.kernel_size[2] // 2\n        padding = (0, height_pad, width_pad)\n\n        self.conv = nn.Conv3d(\n            in_channels,\n            out_channels,\n            self.kernel_size,\n            stride=stride,\n            dilation=dilation,\n            groups=groups,\n            padding=padding,\n            padding_mode=spatial_padding_mode,\n        )\n\n    def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:\n        time_kernel_size = self.kernel_size[0]\n\n        if causal:\n            pad_left = hidden_states[:, :, :1, :, :].repeat(\n                (1, 1, time_kernel_size - 1, 1, 1)\n            )\n            hidden_states = torch.concatenate([pad_left, hidden_states], dim=2)\n        else:\n            pad_left = hidden_states[:, :, :1, :, :].repeat(\n                (1, 1, (time_kernel_size - 1) // 2, 1, 1)\n            )\n            pad_right = hidden_states[:, :, -1:, :, :].repeat(\n                (1, 1, (time_kernel_size - 1) // 2, 1, 1)\n            )\n            hidden_states = torch.concatenate(\n                [pad_left, hidden_states, pad_right], dim=2\n            )\n\n        hidden_states = 
self.conv(hidden_states)\n        return hidden_states\n\n\n# Like LTXVideoResnetBlock3d, but uses new causal Conv3d, normal Conv3d for the conv_shortcut, and the spatial padding\n# mode is configurable\nclass LTX2VideoResnetBlock3d(nn.Module):\n    r\"\"\"\n    A 3D ResNet block used in the LTX 2.0 audiovisual model.\n\n    Args:\n        in_channels (`int`):\n            Number of input channels.\n        out_channels (`int`, *optional*):\n            Number of output channels. If None, defaults to `in_channels`.\n        dropout (`float`, defaults to `0.0`):\n            Dropout rate.\n        eps (`float`, defaults to `1e-6`):\n            Epsilon value for normalization layers.\n        elementwise_affine (`bool`, defaults to `False`):\n            Whether to enable elementwise affinity in the normalization layers.\n        non_linearity (`str`, defaults to `\"swish\"`):\n            Activation function to use.\n        conv_shortcut (bool, defaults to `False`):\n            Whether or not to use a convolution shortcut.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: Optional[int] = None,\n        dropout: float = 0.0,\n        eps: float = 1e-6,\n        elementwise_affine: bool = False,\n        non_linearity: str = \"swish\",\n        inject_noise: bool = False,\n        timestep_conditioning: bool = False,\n        spatial_padding_mode: str = \"zeros\",\n    ) -> None:\n        super().__init__()\n\n        out_channels = out_channels or in_channels\n\n        self.nonlinearity = get_activation(non_linearity)\n\n        self.norm1 = PerChannelRMSNorm()\n        self.conv1 = LTX2VideoCausalConv3d(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=3,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n\n        self.norm2 = PerChannelRMSNorm()\n        self.dropout = nn.Dropout(dropout)\n        self.conv2 = LTX2VideoCausalConv3d(\n          
  in_channels=out_channels,\n            out_channels=out_channels,\n            kernel_size=3,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n\n        self.norm3 = None\n        self.conv_shortcut = None\n        if in_channels != out_channels:\n            self.norm3 = nn.LayerNorm(\n                in_channels, eps=eps, elementwise_affine=True, bias=True\n            )\n            # LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d\n            self.conv_shortcut = nn.Conv3d(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=1,\n                stride=1,\n            )\n\n        self.per_channel_scale1 = None\n        self.per_channel_scale2 = None\n        if inject_noise:\n            self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))\n            self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))\n\n        self.scale_shift_table = None\n        if timestep_conditioning:\n            self.scale_shift_table = nn.Parameter(\n                torch.randn(4, in_channels) / in_channels**0.5\n            )\n\n    def forward(\n        self,\n        inputs: torch.Tensor,\n        temb: Optional[torch.Tensor] = None,\n        generator: Optional[torch.Generator] = None,\n        causal: bool = True,\n    ) -> torch.Tensor:\n        hidden_states = inputs\n\n        hidden_states = self.norm1(hidden_states)\n\n        if self.scale_shift_table is not None:\n            temb = (\n                temb.unflatten(1, (4, -1))\n                + self.scale_shift_table[None, ..., None, None, None]\n            )\n            shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)\n            hidden_states = hidden_states * (1 + scale_1) + shift_1\n\n        hidden_states = self.nonlinearity(hidden_states)\n        hidden_states = self.conv1(hidden_states, causal=causal)\n\n        if self.per_channel_scale1 is not None:\n       
     spatial_shape = hidden_states.shape[-2:]\n            spatial_noise = torch.randn(\n                spatial_shape,\n                generator=generator,\n                device=hidden_states.device,\n                dtype=hidden_states.dtype,\n            )[None]\n            hidden_states = (\n                hidden_states\n                + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]\n            )\n\n        hidden_states = self.norm2(hidden_states)\n\n        if self.scale_shift_table is not None:\n            hidden_states = hidden_states * (1 + scale_2) + shift_2\n\n        hidden_states = self.nonlinearity(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.conv2(hidden_states, causal=causal)\n\n        if self.per_channel_scale2 is not None:\n            spatial_shape = hidden_states.shape[-2:]\n            spatial_noise = torch.randn(\n                spatial_shape,\n                generator=generator,\n                device=hidden_states.device,\n                dtype=hidden_states.dtype,\n            )[None]\n            hidden_states = (\n                hidden_states\n                + (spatial_noise * self.per_channel_scale2)[None, :, None, ...]\n            )\n\n        if self.norm3 is not None:\n            inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)\n\n        if self.conv_shortcut is not None:\n            inputs = self.conv_shortcut(inputs)\n\n        hidden_states = hidden_states + inputs\n        return hidden_states\n\n\n# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d\nclass LTXVideoDownsampler3d(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        stride: Union[int, Tuple[int, int, int]] = 1,\n        spatial_padding_mode: str = \"zeros\",\n    ) -> None:\n        super().__init__()\n\n        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)\n        
self.group_size = (\n            in_channels * self.stride[0] * self.stride[1] * self.stride[2]\n        ) // out_channels\n\n        out_channels = out_channels // (\n            self.stride[0] * self.stride[1] * self.stride[2]\n        )\n\n        self.conv = LTX2VideoCausalConv3d(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=3,\n            stride=1,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n\n    def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:\n        hidden_states = torch.cat(\n            [hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2\n        )\n\n        residual = (\n            hidden_states.unflatten(4, (-1, self.stride[2]))\n            .unflatten(3, (-1, self.stride[1]))\n            .unflatten(2, (-1, self.stride[0]))\n        )\n        residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)\n        residual = residual.unflatten(1, (-1, self.group_size))\n        residual = residual.mean(dim=2)\n\n        hidden_states = self.conv(hidden_states, causal=causal)\n        hidden_states = (\n            hidden_states.unflatten(4, (-1, self.stride[2]))\n            .unflatten(3, (-1, self.stride[1]))\n            .unflatten(2, (-1, self.stride[0]))\n        )\n        hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)\n        hidden_states = hidden_states + residual\n\n        return hidden_states\n\n\n# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d\nclass LTXVideoUpsampler3d(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        stride: Union[int, Tuple[int, int, int]] = 1,\n        residual: bool = False,\n        upscale_factor: int = 1,\n        spatial_padding_mode: str = \"zeros\",\n    ) -> None:\n        super().__init__()\n\n        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)\n        self.residual = 
residual\n        self.upscale_factor = upscale_factor\n\n        out_channels = (\n            in_channels * self.stride[0] * self.stride[1] * self.stride[2]\n        ) // upscale_factor\n\n        self.conv = LTX2VideoCausalConv3d(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=3,\n            stride=1,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n\n    def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:\n        batch_size, num_channels, num_frames, height, width = hidden_states.shape\n\n        if self.residual:\n            residual = hidden_states.reshape(\n                batch_size,\n                -1,\n                self.stride[0],\n                self.stride[1],\n                self.stride[2],\n                num_frames,\n                height,\n                width,\n            )\n            residual = (\n                residual.permute(0, 1, 5, 2, 6, 3, 7, 4)\n                .flatten(6, 7)\n                .flatten(4, 5)\n                .flatten(2, 3)\n            )\n            repeats = (\n                self.stride[0] * self.stride[1] * self.stride[2]\n            ) // self.upscale_factor\n            residual = residual.repeat(1, repeats, 1, 1, 1)\n            residual = residual[:, :, self.stride[0] - 1 :]\n\n        hidden_states = self.conv(hidden_states, causal=causal)\n        hidden_states = hidden_states.reshape(\n            batch_size,\n            -1,\n            self.stride[0],\n            self.stride[1],\n            self.stride[2],\n            num_frames,\n            height,\n            width,\n        )\n        hidden_states = (\n            hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4)\n            .flatten(6, 7)\n            .flatten(4, 5)\n            .flatten(2, 3)\n        )\n        hidden_states = hidden_states[:, :, self.stride[0] - 1 :]\n\n        if self.residual:\n            hidden_states = hidden_states + 
residual\n\n        return hidden_states\n\n\n# Like LTX 1.0 LTXVideo095DownBlock3D, but with the updated LTX2VideoResnetBlock3d\nclass LTX2VideoDownBlock3D(nn.Module):\n    r\"\"\"\n    Down block used in the LTXVideo model.\n\n    Args:\n        in_channels (`int`):\n            Number of input channels.\n        out_channels (`int`, *optional*):\n            Number of output channels. If None, defaults to `in_channels`.\n        num_layers (`int`, defaults to `1`):\n            Number of resnet layers.\n        dropout (`float`, defaults to `0.0`):\n            Dropout rate.\n        resnet_eps (`float`, defaults to `1e-6`):\n            Epsilon value for normalization layers.\n        resnet_act_fn (`str`, defaults to `\"swish\"`):\n            Activation function to use.\n        spatio_temporal_scale (`bool`, defaults to `True`):\n            Whether or not to use a downsampling layer. If not used, the output dimension is the same as the input\n            dimension.\n        downsample_type (`str`, defaults to `\"conv\"`):\n            Downsampling method: one of `\"conv\"`, `\"spatial\"`, `\"temporal\"` or `\"spatiotemporal\"`.\n        spatial_padding_mode (`str`, defaults to `\"zeros\"`):\n            Padding mode applied to the spatial dimensions of the convolutions.\n        is_causal (`bool`, defaults to `True`):\n            Whether this layer behaves causally (future frames depend only on past frames) or not.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: Optional[int] = None,\n        num_layers: int = 1,\n        dropout: float = 0.0,\n        resnet_eps: float = 1e-6,\n        resnet_act_fn: str = \"swish\",\n        spatio_temporal_scale: bool = True,\n        downsample_type: str = \"conv\",\n        spatial_padding_mode: str = \"zeros\",\n    ):\n        super().__init__()\n\n        out_channels = out_channels or in_channels\n\n        resnets = []\n        for _ in range(num_layers):\n            resnets.append(\n                LTX2VideoResnetBlock3d(\n                    in_channels=in_channels,\n                    out_channels=in_channels,\n                    dropout=dropout,\n                    
eps=resnet_eps,\n                    non_linearity=resnet_act_fn,\n                    spatial_padding_mode=spatial_padding_mode,\n                )\n            )\n        self.resnets = nn.ModuleList(resnets)\n\n        self.downsamplers = None\n        if spatio_temporal_scale:\n            self.downsamplers = nn.ModuleList()\n\n            if downsample_type == \"conv\":\n                self.downsamplers.append(\n                    LTX2VideoCausalConv3d(\n                        in_channels=in_channels,\n                        out_channels=in_channels,\n                        kernel_size=3,\n                        stride=(2, 2, 2),\n                        spatial_padding_mode=spatial_padding_mode,\n                    )\n                )\n            elif downsample_type == \"spatial\":\n                self.downsamplers.append(\n                    LTXVideoDownsampler3d(\n                        in_channels=in_channels,\n                        out_channels=out_channels,\n                        stride=(1, 2, 2),\n                        spatial_padding_mode=spatial_padding_mode,\n                    )\n                )\n            elif downsample_type == \"temporal\":\n                self.downsamplers.append(\n                    LTXVideoDownsampler3d(\n                        in_channels=in_channels,\n                        out_channels=out_channels,\n                        stride=(2, 1, 1),\n                        spatial_padding_mode=spatial_padding_mode,\n                    )\n                )\n            elif downsample_type == \"spatiotemporal\":\n                self.downsamplers.append(\n                    LTXVideoDownsampler3d(\n                        in_channels=in_channels,\n                        out_channels=out_channels,\n                        stride=(2, 2, 2),\n                        spatial_padding_mode=spatial_padding_mode,\n                    )\n                )\n\n        self.gradient_checkpointing = False\n\n    
def forward(\n        self,\n        hidden_states: torch.Tensor,\n        temb: Optional[torch.Tensor] = None,\n        generator: Optional[torch.Generator] = None,\n        causal: bool = True,\n    ) -> torch.Tensor:\n        r\"\"\"Forward method of the `LTXDownBlock3D` class.\"\"\"\n\n        for i, resnet in enumerate(self.resnets):\n            if torch.is_grad_enabled() and self.gradient_checkpointing:\n                hidden_states = self._gradient_checkpointing_func(\n                    resnet, hidden_states, temb, generator, causal\n                )\n            else:\n                hidden_states = resnet(hidden_states, temb, generator, causal=causal)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states, causal=causal)\n\n        return hidden_states\n\n\n# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d\n# Like LTX 1.0 LTXVideoMidBlock3d, but with the updated LTX2VideoResnetBlock3d\nclass LTX2VideoMidBlock3d(nn.Module):\n    r\"\"\"\n    A middle block used in the LTXVideo model.\n\n    Args:\n        in_channels (`int`):\n            Number of input channels.\n        num_layers (`int`, defaults to `1`):\n            Number of resnet layers.\n        dropout (`float`, defaults to `0.0`):\n            Dropout rate.\n        resnet_eps (`float`, defaults to `1e-6`):\n            Epsilon value for normalization layers.\n        resnet_act_fn (`str`, defaults to `\"swish\"`):\n            Activation function to use.\n        is_causal (`bool`, defaults to `True`):\n            Whether this layer behaves causally (future frames depend only on past frames) or not.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    def __init__(\n        self,\n        in_channels: int,\n        num_layers: int = 1,\n        dropout: float = 0.0,\n        resnet_eps: float = 1e-6,\n        resnet_act_fn: str = 
\"swish\",\n        inject_noise: bool = False,\n        timestep_conditioning: bool = False,\n        spatial_padding_mode: str = \"zeros\",\n    ) -> None:\n        super().__init__()\n\n        self.time_embedder = None\n        if timestep_conditioning:\n            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(\n                in_channels * 4, 0\n            )\n\n        resnets = []\n        for _ in range(num_layers):\n            resnets.append(\n                LTX2VideoResnetBlock3d(\n                    in_channels=in_channels,\n                    out_channels=in_channels,\n                    dropout=dropout,\n                    eps=resnet_eps,\n                    non_linearity=resnet_act_fn,\n                    inject_noise=inject_noise,\n                    timestep_conditioning=timestep_conditioning,\n                    spatial_padding_mode=spatial_padding_mode,\n                )\n            )\n        self.resnets = nn.ModuleList(resnets)\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        temb: Optional[torch.Tensor] = None,\n        generator: Optional[torch.Generator] = None,\n        causal: bool = True,\n    ) -> torch.Tensor:\n        r\"\"\"Forward method of the `LTXMidBlock3D` class.\"\"\"\n\n        if self.time_embedder is not None:\n            temb = self.time_embedder(\n                timestep=temb.flatten(),\n                resolution=None,\n                aspect_ratio=None,\n                batch_size=hidden_states.size(0),\n                hidden_dtype=hidden_states.dtype,\n            )\n            temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)\n\n        for i, resnet in enumerate(self.resnets):\n            if torch.is_grad_enabled() and self.gradient_checkpointing:\n                hidden_states = self._gradient_checkpointing_func(\n                    resnet, hidden_states, temb, generator, causal\n                )\n  
          else:\n                hidden_states = resnet(hidden_states, temb, generator, causal=causal)\n\n        return hidden_states\n\n\n# Like LTXVideoUpBlock3d but with no conv_in and the updated LTX2VideoResnetBlock3d\nclass LTX2VideoUpBlock3d(nn.Module):\n    r\"\"\"\n    Up block used in the LTXVideo model.\n\n    Args:\n        in_channels (`int`):\n            Number of input channels.\n        out_channels (`int`, *optional*):\n            Number of output channels. If None, defaults to `in_channels`.\n        num_layers (`int`, defaults to `1`):\n            Number of resnet layers.\n        dropout (`float`, defaults to `0.0`):\n            Dropout rate.\n        resnet_eps (`float`, defaults to `1e-6`):\n            Epsilon value for normalization layers.\n        resnet_act_fn (`str`, defaults to `\"swish\"`):\n            Activation function to use.\n        spatio_temporal_scale (`bool`, defaults to `True`):\n            Whether or not to use a spatio-temporal upsampling layer. If not used, the output dimension is the same as the input dimension.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: Optional[int] = None,\n        num_layers: int = 1,\n        dropout: float = 0.0,\n        resnet_eps: float = 1e-6,\n        resnet_act_fn: str = \"swish\",\n        spatio_temporal_scale: bool = True,\n        inject_noise: bool = False,\n        timestep_conditioning: bool = False,\n        upsample_residual: bool = False,\n        upscale_factor: int = 1,\n        spatial_padding_mode: str = \"zeros\",\n    ):\n        super().__init__()\n\n        out_channels = out_channels or in_channels\n\n        self.time_embedder = None\n        if timestep_conditioning:\n   
         self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(\n                in_channels * 4, 0\n            )\n\n        self.conv_in = None\n        if in_channels != out_channels:\n            self.conv_in = LTX2VideoResnetBlock3d(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                dropout=dropout,\n                eps=resnet_eps,\n                non_linearity=resnet_act_fn,\n                inject_noise=inject_noise,\n                timestep_conditioning=timestep_conditioning,\n                spatial_padding_mode=spatial_padding_mode,\n            )\n\n        self.upsamplers = None\n        if spatio_temporal_scale:\n            self.upsamplers = nn.ModuleList(\n                [\n                    LTXVideoUpsampler3d(\n                        out_channels * upscale_factor,\n                        stride=(2, 2, 2),\n                        residual=upsample_residual,\n                        upscale_factor=upscale_factor,\n                        spatial_padding_mode=spatial_padding_mode,\n                    )\n                ]\n            )\n\n        resnets = []\n        for _ in range(num_layers):\n            resnets.append(\n                LTX2VideoResnetBlock3d(\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    dropout=dropout,\n                    eps=resnet_eps,\n                    non_linearity=resnet_act_fn,\n                    inject_noise=inject_noise,\n                    timestep_conditioning=timestep_conditioning,\n                    spatial_padding_mode=spatial_padding_mode,\n                )\n            )\n        self.resnets = nn.ModuleList(resnets)\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        temb: Optional[torch.Tensor] = None,\n        generator: Optional[torch.Generator] = None,\n        causal: bool = True,\n  
  ) -> torch.Tensor:\n        if self.conv_in is not None:\n            hidden_states = self.conv_in(hidden_states, temb, generator, causal=causal)\n\n        if self.time_embedder is not None:\n            temb = self.time_embedder(\n                timestep=temb.flatten(),\n                resolution=None,\n                aspect_ratio=None,\n                batch_size=hidden_states.size(0),\n                hidden_dtype=hidden_states.dtype,\n            )\n            temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, causal=causal)\n\n        for i, resnet in enumerate(self.resnets):\n            if torch.is_grad_enabled() and self.gradient_checkpointing:\n                hidden_states = self._gradient_checkpointing_func(\n                    resnet, hidden_states, temb, generator, causal\n                )\n            else:\n                hidden_states = resnet(hidden_states, temb, generator, causal=causal)\n\n        return hidden_states\n\n\n# Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is\n# different, as is the layers_per_block (the 2.0 VAE is bigger)\nclass LTX2VideoEncoder3d(nn.Module):\n    r\"\"\"\n    The `LTX2VideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent\n    representation.\n\n    Args:\n        in_channels (`int`, defaults to 3):\n            Number of input channels.\n        out_channels (`int`, defaults to 128):\n            Number of latent channels.\n        block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`):\n            The number of output channels for each block.\n        spatio_temporal_scaling (`Tuple[bool, ...]`, defaults to `(True, True, True, True)`):\n            Whether a block should contain spatio-temporal downscaling layers or not.\n        
layers_per_block (`Tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`):\n            The number of layers per block.\n        downsample_type (`Tuple[str, ...]`, defaults to `(\"spatial\", \"temporal\", \"spatiotemporal\", \"spatiotemporal\")`):\n            The spatiotemporal downsampling pattern per block. Each entry can be one of\n                - `\"spatial\"` (downsample spatial dims by 2x)\n                - `\"temporal\"` (downsample temporal dim by 2x)\n                - `\"spatiotemporal\"` (downsample both spatial and temporal dims by 2x)\n        patch_size (`int`, defaults to `4`):\n            The size of spatial patches.\n        patch_size_t (`int`, defaults to `1`):\n            The size of temporal patches.\n        resnet_norm_eps (`float`, defaults to `1e-6`):\n            Epsilon value for ResNet normalization layers.\n        is_causal (`bool`, defaults to `True`):\n            Whether this layer behaves causally (future frames depend only on past frames) or not.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int = 3,\n        out_channels: int = 128,\n        block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048),\n        down_block_types: Tuple[str, ...] = (\n            \"LTX2VideoDownBlock3D\",\n            \"LTX2VideoDownBlock3D\",\n            \"LTX2VideoDownBlock3D\",\n            \"LTX2VideoDownBlock3D\",\n        ),\n        spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True),\n        layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2),\n        downsample_type: Tuple[str, ...] 
= (\n            \"spatial\",\n            \"temporal\",\n            \"spatiotemporal\",\n            \"spatiotemporal\",\n        ),\n        patch_size: int = 4,\n        patch_size_t: int = 1,\n        resnet_norm_eps: float = 1e-6,\n        is_causal: bool = True,\n        spatial_padding_mode: str = \"zeros\",\n    ):\n        super().__init__()\n\n        self.patch_size = patch_size\n        self.patch_size_t = patch_size_t\n        self.in_channels = in_channels * patch_size**2\n        self.is_causal = is_causal\n\n        output_channel = out_channels\n\n        self.conv_in = LTX2VideoCausalConv3d(\n            in_channels=self.in_channels,\n            out_channels=output_channel,\n            kernel_size=3,\n            stride=1,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n\n        # down blocks\n        num_block_out_channels = len(block_out_channels)\n        self.down_blocks = nn.ModuleList([])\n        for i in range(num_block_out_channels):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n\n            if down_block_types[i] == \"LTX2VideoDownBlock3D\":\n                down_block = LTX2VideoDownBlock3D(\n                    in_channels=input_channel,\n                    out_channels=output_channel,\n                    num_layers=layers_per_block[i],\n                    resnet_eps=resnet_norm_eps,\n                    spatio_temporal_scale=spatio_temporal_scaling[i],\n                    downsample_type=downsample_type[i],\n                    spatial_padding_mode=spatial_padding_mode,\n                )\n            else:\n                raise ValueError(f\"Unknown down block type: {down_block_types[i]}\")\n\n            self.down_blocks.append(down_block)\n\n        # mid block\n        self.mid_block = LTX2VideoMidBlock3d(\n            in_channels=output_channel,\n            num_layers=layers_per_block[-1],\n            resnet_eps=resnet_norm_eps,\n            
spatial_padding_mode=spatial_padding_mode,\n        )\n\n        # out\n        self.norm_out = PerChannelRMSNorm()\n        self.conv_act = nn.SiLU()\n        self.conv_out = LTX2VideoCausalConv3d(\n            in_channels=output_channel,\n            out_channels=out_channels + 1,\n            kernel_size=3,\n            stride=1,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self, hidden_states: torch.Tensor, causal: Optional[bool] = None\n    ) -> torch.Tensor:\n        r\"\"\"The forward method of the `LTX2VideoEncoder3d` class.\"\"\"\n\n        p = self.patch_size\n        p_t = self.patch_size_t\n\n        batch_size, num_channels, num_frames, height, width = hidden_states.shape\n        post_patch_num_frames = num_frames // p_t\n        post_patch_height = height // p\n        post_patch_width = width // p\n        # Honor an explicit causal=False; fall back to the configured default only when unset.\n        causal = self.is_causal if causal is None else causal\n\n        hidden_states = hidden_states.reshape(\n            batch_size,\n            num_channels,\n            post_patch_num_frames,\n            p_t,\n            post_patch_height,\n            p,\n            post_patch_width,\n            p,\n        )\n        # Fold the patch dims into channels in (C, p_t, p_w, p_h) order (width patch before height patch);\n        # pretrained conv_in weights depend on this exact channel layout.\n        hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4)\n        hidden_states = self.conv_in(hidden_states, causal=causal)\n\n        if torch.is_grad_enabled() and self.gradient_checkpointing:\n            for down_block in self.down_blocks:\n                hidden_states = self._gradient_checkpointing_func(\n                    down_block, hidden_states, None, None, causal\n                )\n\n            hidden_states = 
self._gradient_checkpointing_func(\n                self.mid_block, hidden_states, None, None, causal\n            )\n        else:\n            for down_block in self.down_blocks:\n                hidden_states = down_block(hidden_states, 
causal=causal)\n\n            hidden_states = self.mid_block(hidden_states, causal=causal)\n\n        hidden_states = self.norm_out(hidden_states)\n        hidden_states = self.conv_act(hidden_states)\n        hidden_states = self.conv_out(hidden_states, causal=causal)\n\n        # conv_out emits `out_channels` mean channels plus one shared log-variance channel; repeat that last\n        # channel so the output has 2 * out_channels channels (mean, then log-variance).\n        last_channel = hidden_states[:, -1:]\n        last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1)\n        hidden_states = torch.cat([hidden_states, last_channel], dim=1)\n\n        return hidden_states\n\n\n# Like LTX 1.0 LTXVideoDecoder3d, but has only 3 symmetric up blocks which are causal and residual with upsample_factor 2\nclass LTX2VideoDecoder3d(nn.Module):\n    r\"\"\"\n    The `LTX2VideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output\n    sample.\n\n    Args:\n        in_channels (`int`, defaults to 128):\n            Number of latent channels.\n        out_channels (`int`, defaults to 3):\n            Number of output channels.\n        block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 1024)`):\n            The number of output channels for each block.\n        spatio_temporal_scaling (`Tuple[bool, ...]`, defaults to `(True, True, True)`):\n            Whether a block should contain spatio-temporal upscaling layers or not.\n        layers_per_block (`Tuple[int, ...]`, defaults to `(5, 5, 5, 5)`):\n            The number of layers per block.\n        patch_size (`int`, defaults to `4`):\n            The size of spatial patches.\n        patch_size_t (`int`, defaults to `1`):\n            The size of temporal patches.\n        resnet_norm_eps (`float`, defaults to `1e-6`):\n            Epsilon value for ResNet normalization layers.\n        is_causal (`bool`, defaults to `False`):\n            Whether this layer behaves causally (future frames depend only on past frames) or not.\n        timestep_conditioning (`bool`, defaults to `False`):\n            Whether to condition the model 
on timesteps.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int = 128,\n        out_channels: int = 3,\n        block_out_channels: Tuple[int, ...] = (256, 512, 1024),\n        spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True),\n        layers_per_block: Tuple[int, ...] = (5, 5, 5, 5),\n        patch_size: int = 4,\n        patch_size_t: int = 1,\n        resnet_norm_eps: float = 1e-6,\n        is_causal: bool = False,\n        inject_noise: Tuple[bool, ...] = (False, False, False, False),\n        timestep_conditioning: bool = False,\n        upsample_residual: Tuple[bool, ...] = (True, True, True),\n        upsample_factor: Tuple[int, ...] = (2, 2, 2),\n        spatial_padding_mode: str = \"reflect\",\n    ) -> None:\n        super().__init__()\n\n        self.patch_size = patch_size\n        self.patch_size_t = patch_size_t\n        self.out_channels = out_channels * patch_size**2\n        self.is_causal = is_causal\n\n        block_out_channels = tuple(reversed(block_out_channels))\n        spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))\n        layers_per_block = tuple(reversed(layers_per_block))\n        inject_noise = tuple(reversed(inject_noise))\n        upsample_residual = tuple(reversed(upsample_residual))\n        upsample_factor = tuple(reversed(upsample_factor))\n        output_channel = block_out_channels[0]\n\n        self.conv_in = LTX2VideoCausalConv3d(\n            in_channels=in_channels,\n            out_channels=output_channel,\n            kernel_size=3,\n            stride=1,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n\n        self.mid_block = LTX2VideoMidBlock3d(\n            in_channels=output_channel,\n            num_layers=layers_per_block[0],\n            resnet_eps=resnet_norm_eps,\n            inject_noise=inject_noise[0],\n            timestep_conditioning=timestep_conditioning,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n\n        # 
up blocks\n        num_block_out_channels = len(block_out_channels)\n        self.up_blocks = nn.ModuleList([])\n        for i in range(num_block_out_channels):\n            input_channel = output_channel // upsample_factor[i]\n            output_channel = block_out_channels[i] // upsample_factor[i]\n\n            up_block = LTX2VideoUpBlock3d(\n                in_channels=input_channel,\n                out_channels=output_channel,\n                num_layers=layers_per_block[i + 1],\n                resnet_eps=resnet_norm_eps,\n                spatio_temporal_scale=spatio_temporal_scaling[i],\n                inject_noise=inject_noise[i + 1],\n                timestep_conditioning=timestep_conditioning,\n                upsample_residual=upsample_residual[i],\n                upscale_factor=upsample_factor[i],\n                spatial_padding_mode=spatial_padding_mode,\n            )\n\n            self.up_blocks.append(up_block)\n\n        # out\n        self.norm_out = PerChannelRMSNorm()\n        self.conv_act = nn.SiLU()\n        self.conv_out = LTX2VideoCausalConv3d(\n            in_channels=output_channel,\n            out_channels=self.out_channels,\n            kernel_size=3,\n            stride=1,\n            spatial_padding_mode=spatial_padding_mode,\n        )\n\n        # timestep embedding\n        self.time_embedder = None\n        self.scale_shift_table = None\n        self.timestep_scale_multiplier = None\n        if timestep_conditioning:\n            self.timestep_scale_multiplier = nn.Parameter(\n                torch.tensor(1000.0, dtype=torch.float32)\n            )\n            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(\n                output_channel * 2, 0\n            )\n            self.scale_shift_table = nn.Parameter(\n                torch.randn(2, output_channel) / output_channel**0.5\n            )\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: 
torch.Tensor,\n        temb: Optional[torch.Tensor] = None,\n        causal: Optional[bool] = None,\n    ) -> torch.Tensor:\n        # Honor an explicit causal=False; fall back to the configured default only when unset.\n        causal = self.is_causal if causal is None else causal\n\n        hidden_states = self.conv_in(hidden_states, causal=causal)\n\n        if self.timestep_scale_multiplier is not None:\n            temb = temb * self.timestep_scale_multiplier\n\n        if torch.is_grad_enabled() and self.gradient_checkpointing:\n            hidden_states = self._gradient_checkpointing_func(\n                self.mid_block, hidden_states, temb, None, causal\n            )\n\n            for up_block in self.up_blocks:\n                hidden_states = self._gradient_checkpointing_func(\n                    up_block, hidden_states, temb, None, causal\n                )\n        else:\n            hidden_states = self.mid_block(hidden_states, temb, causal=causal)\n\n            for up_block in self.up_blocks:\n                hidden_states = up_block(hidden_states, temb, causal=causal)\n\n        hidden_states = self.norm_out(hidden_states)\n\n        if self.time_embedder is not None:\n            temb = self.time_embedder(\n                timestep=temb.flatten(),\n                resolution=None,\n                aspect_ratio=None,\n                batch_size=hidden_states.size(0),\n                hidden_dtype=hidden_states.dtype,\n            )\n            temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))\n            temb = temb + self.scale_shift_table[None, ..., None, None, None]\n            shift, scale = temb.unbind(dim=1)\n            hidden_states = hidden_states * (1 + scale) + shift\n\n        hidden_states = self.conv_act(hidden_states)\n        hidden_states = self.conv_out(hidden_states, causal=causal)\n\n        p = self.patch_size\n        p_t = self.patch_size_t\n\n        # Unpatchify: split channels as (C, p_t, p_w, p_h), matching the encoder's patch order, and fold\n        # each patch factor back into frames, width and height.\n        batch_size, num_channels, num_frames, height, width = hidden_states.shape\n        hidden_states = hidden_states.reshape(\n            batch_size, -1, p_t, p, 
p, num_frames, height, width\n        )\n        hidden_states = (\n            hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3)\n            .flatten(6, 7)\n            .flatten(4, 5)\n            .flatten(2, 3)\n        )\n\n        return hidden_states\n\n\nclass AutoencoderKLLTX2Video(ParallelTiledVAE):\n    r\"\"\"\n    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.\n\n    This model inherits from [`ParallelTiledVAE`]. Check the superclass documentation for the generic methods it\n    provides.\n    \"\"\"\n\n    _supports_gradient_checkpointing = False\n\n    def __init__(self, config: LTXVideoVAEConfig):\n        super().__init__(config=config)\n        in_channels = config.arch_config.in_channels\n        latent_channels = config.arch_config.latent_channels\n        out_channels = config.arch_config.out_channels\n        block_out_channels = config.arch_config.block_out_channels\n        down_block_types = config.arch_config.down_block_types\n        spatio_temporal_scaling = config.arch_config.spatio_temporal_scaling\n        layers_per_block = config.arch_config.layers_per_block\n        downsample_type = config.arch_config.downsample_type\n        patch_size = config.arch_config.patch_size\n        patch_size_t = config.arch_config.patch_size_t\n        resnet_norm_eps = config.arch_config.resnet_norm_eps\n        encoder_causal = config.arch_config.encoder_causal\n        encoder_spatial_padding_mode = config.arch_config.encoder_spatial_padding_mode\n\n        decoder_block_out_channels = config.arch_config.decoder_block_out_channels\n        decoder_spatio_temporal_scaling = (\n            config.arch_config.decoder_spatio_temporal_scaling\n        )\n        decoder_layers_per_block = config.arch_config.decoder_layers_per_block\n        decoder_causal = config.arch_config.decoder_causal\n        decoder_spatial_padding_mode = 
config.arch_config.decoder_spatial_padding_mode\n\n        self.encoder = LTX2VideoEncoder3d(\n            in_channels=in_channels,\n            out_channels=latent_channels,\n            block_out_channels=block_out_channels,\n            down_block_types=down_block_types,\n            spatio_temporal_scaling=spatio_temporal_scaling,\n            layers_per_block=layers_per_block,\n            downsample_type=downsample_type,\n            patch_size=patch_size,\n            patch_size_t=patch_size_t,\n            resnet_norm_eps=resnet_norm_eps,\n            is_causal=encoder_causal,\n            spatial_padding_mode=encoder_spatial_padding_mode,\n        )\n\n        # Pass arguments by keyword: positionally, the padding mode would bind to `inject_noise`.\n        self.decoder = LTX2VideoDecoder3d(\n            in_channels=latent_channels,\n            out_channels=out_channels,\n            block_out_channels=decoder_block_out_channels,\n            spatio_temporal_scaling=decoder_spatio_temporal_scaling,\n            layers_per_block=decoder_layers_per_block,\n            patch_size=patch_size,\n            patch_size_t=patch_size_t,\n            resnet_norm_eps=resnet_norm_eps,\n            is_causal=decoder_causal,\n            # One flag for the mid block plus one per up block; noise injection is disabled.\n            inject_noise=(False,) * len(decoder_layers_per_block),\n            spatial_padding_mode=decoder_spatial_padding_mode,\n        )\n\n        latents_mean = torch.zeros((latent_channels,), requires_grad=False)\n        latents_std = torch.ones((latent_channels,), requires_grad=False)\n        self.register_buffer(\"latents_mean\", latents_mean, persistent=True)\n        self.register_buffer(\"latents_std\", latents_std, persistent=True)\n\n        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension\n        # to perform decoding of a single video latent at a time.\n        self.use_slicing = False\n\n        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent\n        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the\n        # intermediate tiles together, the memory requirement can be lowered.\n        self.use_tiling = False\n\n        # When decoding temporally long video latents, the memory requirement is very high. 
By decoding latent frames\n        # at a fixed frame batch size (based on `self.num_latent_frames_batch_size`), the memory requirement can be lowered.\n        self.use_framewise_encoding = False\n        self.use_framewise_decoding = False\n\n        # This can be configured based on the amount of GPU memory available.\n        # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.\n        # Setting it to higher values results in higher memory usage.\n        self.num_sample_frames_batch_size = 16\n        self.num_latent_frames_batch_size = 2\n\n        # The minimum sample height, width, and number of frames above which tiling is used\n        self.tile_sample_min_height = 512\n        self.tile_sample_min_width = 512\n        self.tile_sample_min_num_frames = 16\n\n        # The stride between consecutive tiles; adjacent tiles overlap by (min size - stride) samples\n        self.tile_sample_stride_height = 448\n        self.tile_sample_stride_width = 448\n        self.tile_sample_stride_num_frames = 8\n\n    def enable_tiling(\n        self,\n        tile_sample_min_height: Optional[int] = None,\n        tile_sample_min_width: Optional[int] = None,\n        tile_sample_min_num_frames: Optional[int] = None,\n        tile_sample_stride_height: Optional[int] = None,\n        tile_sample_stride_width: Optional[int] = None,\n        tile_sample_stride_num_frames: Optional[int] = None,\n    ) -> None:\n        r\"\"\"\n        Enable tiled VAE encoding and decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. 
This saves a large amount of memory and allows\n        processing larger videos.\n\n        Args:\n            tile_sample_min_height (`int`, *optional*):\n                The minimum height required for a sample to be separated into tiles across the height dimension.\n            tile_sample_min_width (`int`, *optional*):\n                The minimum width required for a sample to be separated into tiles across the width dimension.\n            tile_sample_min_num_frames (`int`, *optional*):\n                The minimum number of frames required for a sample to be separated into tiles across the frame dimension.\n            tile_sample_stride_height (`int`, *optional*):\n                The stride between two consecutive vertical tiles. Tiles overlap so that no tiling artifacts are\n                produced across the height dimension.\n            tile_sample_stride_width (`int`, *optional*):\n                The stride between two consecutive horizontal tiles. Tiles overlap so that no tiling artifacts are\n                produced across the width dimension.\n            tile_sample_stride_num_frames (`int`, *optional*):\n                The stride between two consecutive temporal tiles.\n        \"\"\"\n        self.use_tiling = True\n        self.tile_sample_min_height = (\n            tile_sample_min_height or self.tile_sample_min_height\n        )\n        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width\n        self.tile_sample_min_num_frames = (\n            tile_sample_min_num_frames or self.tile_sample_min_num_frames\n        )\n        self.tile_sample_stride_height = (\n            tile_sample_stride_height or self.tile_sample_stride_height\n        )\n        self.tile_sample_stride_width = (\n            tile_sample_stride_width or self.tile_sample_stride_width\n        )\n        self.tile_sample_stride_num_frames = (\n            tile_sample_stride_num_frames or self.tile_sample_stride_num_frames\n        )\n\n    def _encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor:\n        batch_size, num_channels, num_frames, height, width = x.shape\n\n        if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames:\n            return 
self._temporal_tiled_encode(x, causal=causal)\n\n        if self.use_tiling and (\n            width > self.tile_sample_min_width or height > self.tile_sample_min_height\n        ):\n            return self.tiled_encode(x, causal=causal)\n\n        enc = self.encoder(x, causal=causal)\n\n        return enc\n\n    def encode(\n        self, x: torch.Tensor, causal: Optional[bool] = None, return_dict: bool = True\n    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:\n        \"\"\"\n        Encode a batch of images into latents.\n\n        Args:\n            x (`torch.Tensor`): Input batch of images.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.\n\n        Returns:\n                The latent representations of the encoded videos. If `return_dict` is True, a\n                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.\n        \"\"\"\n        if self.use_slicing and x.shape[0] > 1:\n            encoded_slices = [\n                self._encode(x_slice, causal=causal) for x_slice in x.split(1)\n            ]\n            h = torch.cat(encoded_slices)\n        else:\n            h = self._encode(x, causal=causal)\n        posterior = DiagonalGaussianDistribution(h)\n\n        if not return_dict:\n            return (posterior,)\n        return AutoencoderKLOutput(latent_dist=posterior)\n\n    def _decode(\n        self,\n        z: torch.Tensor,\n        temb: Optional[torch.Tensor] = None,\n        causal: Optional[bool] = None,\n        return_dict: bool = True,\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        batch_size, num_channels, num_frames, height, width = z.shape\n        tile_latent_min_height = (\n            self.tile_sample_min_height // self.spatial_compression_ratio\n        )\n        tile_latent_min_width = (\n            
self.tile_sample_min_width // self.spatial_compression_ratio\n        )\n        tile_latent_min_num_frames = (\n            self.tile_sample_min_num_frames // self.temporal_compression_ratio\n        )\n\n        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:\n            return self._temporal_tiled_decode(\n                z, temb, causal=causal, return_dict=return_dict\n            )\n\n        if self.use_tiling and (\n            width > tile_latent_min_width or height > tile_latent_min_height\n        ):\n            return self.tiled_decode(z, temb, causal=causal, return_dict=return_dict)\n\n        dec = self.decoder(z, temb, causal=causal)\n\n        if not return_dict:\n            return (dec,)\n\n        return DecoderOutput(sample=dec)\n\n    def decode(\n        self,\n        z: torch.Tensor,\n        temb: Optional[torch.Tensor] = None,\n        causal: Optional[bool] = None,\n        return_dict: bool = True,\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        \"\"\"\n        Decode a batch of video latents.\n\n        Args:\n            z (`torch.Tensor`): Input batch of latent vectors.\n            temb (`torch.Tensor`, *optional*): Per-sample timesteps used when the decoder is timestep-conditioned.\n            causal (`bool`, *optional*): Whether to decode causally; defaults to the decoder's `is_causal` setting.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.vae.DecoderOutput`] or `tuple`:\n                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is\n                returned.\n        \"\"\"\n        if self.use_slicing and z.shape[0] > 1:\n            if temb is not None:\n                decoded_slices = [\n                    self._decode(z_slice, t_slice, causal=causal).sample\n                    for z_slice, t_slice in zip(z.split(1), temb.split(1))\n                ]\n            else:\n                decoded_slices = [\n                    self._decode(z_slice, causal=causal).sample\n                    for z_slice in z.split(1)\n       
         ]\n            decoded = torch.cat(decoded_slices)\n        else:\n            decoded = self._decode(z, temb, causal=causal).sample\n\n        if not return_dict:\n            return (decoded,)\n\n        return DecoderOutput(sample=decoded)\n\n    def blend_v(\n        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int\n    ) -> torch.Tensor:\n        blend_extent = min(a.shape[3], b.shape[3], blend_extent)\n        for y in range(blend_extent):\n            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (\n                1 - y / blend_extent\n            ) + b[:, :, :, y, :] * (y / blend_extent)\n        return b\n\n    def blend_h(\n        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int\n    ) -> torch.Tensor:\n        blend_extent = min(a.shape[4], b.shape[4], blend_extent)\n        for x in range(blend_extent):\n            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (\n                1 - x / blend_extent\n            ) + b[:, :, :, :, x] * (x / blend_extent)\n        return b\n\n    def blend_t(\n        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int\n    ) -> torch.Tensor:\n        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)\n        for x in range(blend_extent):\n            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (\n                1 - x / blend_extent\n            ) + b[:, :, x, :, :] * (x / blend_extent)\n        return b\n\n    def tiled_encode(\n        self, x: torch.Tensor, causal: Optional[bool] = None\n    ) -> torch.Tensor:\n        r\"\"\"Encode a batch of images using a tiled encoder.\n\n        Args:\n            x (`torch.Tensor`): Input batch of videos.\n\n        Returns:\n            `torch.Tensor`:\n                The latent representation of the encoded videos.\n        \"\"\"\n        batch_size, num_channels, num_frames, height, width = x.shape\n        latent_height = height // self.spatial_compression_ratio\n        latent_width = width // 
self.spatial_compression_ratio\n\n        tile_latent_min_height = (\n            self.tile_sample_min_height // self.spatial_compression_ratio\n        )\n        tile_latent_min_width = (\n            self.tile_sample_min_width // self.spatial_compression_ratio\n        )\n        tile_latent_stride_height = (\n            self.tile_sample_stride_height // self.spatial_compression_ratio\n        )\n        tile_latent_stride_width = (\n            self.tile_sample_stride_width // self.spatial_compression_ratio\n        )\n\n        blend_height = tile_latent_min_height - tile_latent_stride_height\n        blend_width = tile_latent_min_width - tile_latent_stride_width\n\n        # Split x into overlapping tiles and encode them separately.\n        # The tiles have an overlap to avoid seams between tiles.\n        rows = []\n        for i in range(0, height, self.tile_sample_stride_height):\n            row = []\n            for j in range(0, width, self.tile_sample_stride_width):\n                time = self.encoder(\n                    x[\n                        :,\n                        :,\n                        :,\n                        i : i + self.tile_sample_min_height,\n                        j : j + self.tile_sample_min_width,\n                    ],\n                    causal=causal,\n                )\n\n                row.append(time)\n            rows.append(row)\n\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_width)\n                result_row.append(\n                    tile[:, :, :, :tile_latent_stride_height, 
:tile_latent_stride_width]\n                )\n            result_rows.append(torch.cat(result_row, dim=4))\n\n        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]\n        return enc\n\n    def tiled_decode(\n        self,\n        z: torch.Tensor,\n        temb: Optional[torch.Tensor],\n        causal: Optional[bool] = None,\n        return_dict: bool = True,\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        r\"\"\"\n        Decode a batch of images using a tiled decoder.\n\n        Args:\n            z (`torch.Tensor`): Input batch of latent vectors.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.vae.DecoderOutput`] or `tuple`:\n                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is\n                returned.\n        \"\"\"\n\n        batch_size, num_channels, num_frames, height, width = z.shape\n        sample_height = height * self.spatial_compression_ratio\n        sample_width = width * self.spatial_compression_ratio\n\n        tile_latent_min_height = (\n            self.tile_sample_min_height // self.spatial_compression_ratio\n        )\n        tile_latent_min_width = (\n            self.tile_sample_min_width // self.spatial_compression_ratio\n        )\n        tile_latent_stride_height = (\n            self.tile_sample_stride_height // self.spatial_compression_ratio\n        )\n        tile_latent_stride_width = (\n            self.tile_sample_stride_width // self.spatial_compression_ratio\n        )\n\n        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height\n        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width\n\n        # Split z into overlapping tiles and decode them separately.\n        # The tiles have an overlap to avoid seams between 
tiles.\n        rows = []\n        for i in range(0, height, tile_latent_stride_height):\n            row = []\n            for j in range(0, width, tile_latent_stride_width):\n                time = self.decoder(\n                    z[\n                        :,\n                        :,\n                        :,\n                        i : i + tile_latent_min_height,\n                        j : j + tile_latent_min_width,\n                    ],\n                    temb,\n                    causal=causal,\n                )\n\n                row.append(time)\n            rows.append(row)\n\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_width)\n                result_row.append(\n                    tile[\n                        :,\n                        :,\n                        :,\n                        : self.tile_sample_stride_height,\n                        : self.tile_sample_stride_width,\n                    ]\n                )\n            result_rows.append(torch.cat(result_row, dim=4))\n\n        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]\n\n        if not return_dict:\n            return (dec,)\n\n        return DecoderOutput(sample=dec)\n\n    def _temporal_tiled_encode(\n        self, x: torch.Tensor, causal: Optional[bool] = None\n    ) -> AutoencoderKLOutput:\n        batch_size, num_channels, num_frames, height, width = x.shape\n        latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1\n\n        tile_latent_min_num_frames = (\n            
self.tile_sample_min_num_frames // self.temporal_compression_ratio\n        )\n        tile_latent_stride_num_frames = (\n            self.tile_sample_stride_num_frames // self.temporal_compression_ratio\n        )\n        blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames\n\n        row = []\n        for i in range(0, num_frames, self.tile_sample_stride_num_frames):\n            tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]\n            if self.use_tiling and (\n                height > self.tile_sample_min_height\n                or width > self.tile_sample_min_width\n            ):\n                tile = self.tiled_encode(tile, causal=causal)\n            else:\n                tile = self.encoder(tile, causal=causal)\n            if i > 0:\n                tile = tile[:, :, 1:, :, :]\n            row.append(tile)\n\n        result_row = []\n        for i, tile in enumerate(row):\n            if i > 0:\n                tile = self.blend_t(row[i - 1], tile, blend_num_frames)\n                result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])\n            else:\n                result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])\n\n        enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]\n        return enc\n\n    def _temporal_tiled_decode(\n        self,\n        z: torch.Tensor,\n        temb: Optional[torch.Tensor],\n        causal: Optional[bool] = None,\n        return_dict: bool = True,\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        batch_size, num_channels, num_frames, height, width = z.shape\n        num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1\n\n        tile_latent_min_height = (\n            self.tile_sample_min_height // self.spatial_compression_ratio\n        )\n        tile_latent_min_width = (\n            self.tile_sample_min_width // self.spatial_compression_ratio\n        )\n        tile_latent_min_num_frames = 
(\n            self.tile_sample_min_num_frames // self.temporal_compression_ratio\n        )\n        tile_latent_stride_num_frames = (\n            self.tile_sample_stride_num_frames // self.temporal_compression_ratio\n        )\n        blend_num_frames = (\n            self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames\n        )\n\n        row = []\n        for i in range(0, num_frames, tile_latent_stride_num_frames):\n            tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]\n            if self.use_tiling and (\n                tile.shape[-1] > tile_latent_min_width\n                or tile.shape[-2] > tile_latent_min_height\n            ):\n                decoded = self.tiled_decode(\n                    tile, temb, causal=causal, return_dict=True\n                ).sample\n            else:\n                decoded = self.decoder(tile, temb, causal=causal)\n            if i > 0:\n                decoded = decoded[:, :, :-1, :, :]\n            row.append(decoded)\n\n        result_row = []\n        for i, tile in enumerate(row):\n            if i > 0:\n                tile = self.blend_t(row[i - 1], tile, blend_num_frames)\n                tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]\n                result_row.append(tile)\n            else:\n                result_row.append(\n                    tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]\n                )\n\n        dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]\n\n        if not return_dict:\n            return (dec,)\n        return DecoderOutput(sample=dec)\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        temb: Optional[torch.Tensor] = None,\n        sample_posterior: bool = False,\n        encoder_causal: Optional[bool] = None,\n        decoder_causal: Optional[bool] = None,\n        return_dict: bool = True,\n        generator: Optional[torch.Generator] = None,\n    ) -> Union[torch.Tensor, 
DecoderOutput]:\n        x = sample\n        posterior = self.encode(x, causal=encoder_causal).latent_dist\n        if sample_posterior:\n            z = posterior.sample(generator=generator)\n        else:\n            z = posterior.mode()\n        dec = self.decode(z, temb, causal=decoder_causal)\n        if not return_dict:\n            return (dec.sample,)\n        return dec\n\n\nEntryClass = AutoencoderKLLTX2Video\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_common_utils.py",
    "content": "from __future__ import annotations\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\n\n\nclass AvgDown3D(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        factor_t,\n        factor_s=1,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.factor_t = factor_t\n        self.factor_s = factor_s\n        self.factor = self.factor_t * self.factor_s * self.factor_s\n\n        assert in_channels * self.factor % out_channels == 0\n        self.group_size = in_channels * self.factor // out_channels\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t\n        pad = (0, 0, 0, 0, pad_t, 0)\n        x = F.pad(x, pad)\n        B, C, T, H, W = x.shape\n        x = x.view(\n            B,\n            C,\n            T // self.factor_t,\n            self.factor_t,\n            H // self.factor_s,\n            self.factor_s,\n            W // self.factor_s,\n            self.factor_s,\n        )\n        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()\n        x = x.view(\n            B,\n            C * self.factor,\n            T // self.factor_t,\n            H // self.factor_s,\n            W // self.factor_s,\n        )\n        x = x.view(\n            B,\n            self.out_channels,\n            self.group_size,\n            T // self.factor_t,\n            H // self.factor_s,\n            W // self.factor_s,\n        )\n        x = x.mean(dim=2)\n        return x\n\n\nclass DupUp3D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        factor_t,\n        factor_s=1,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n\n        
self.factor_t = factor_t\n        self.factor_s = factor_s\n        self.factor = self.factor_t * self.factor_s * self.factor_s\n\n        assert out_channels * self.factor % in_channels == 0\n        self.repeats = out_channels * self.factor // in_channels\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = x.repeat_interleave(self.repeats, dim=1)\n        x = x.view(\n            x.size(0),\n            self.out_channels,\n            self.factor_t,\n            self.factor_s,\n            self.factor_s,\n            x.size(2),\n            x.size(3),\n            x.size(4),\n        )\n        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()\n        x = x.view(\n            x.size(0),\n            self.out_channels,\n            x.size(2) * self.factor_t,\n            x.size(4) * self.factor_s,\n            x.size(6) * self.factor_s,\n        )\n\n        _first_chunk = first_chunk.get() if first_chunk is not None else None\n        if _first_chunk:\n            x = x[:, :, self.factor_t - 1 :, :, :]\n        return x\n\n\nclass WanCausalConv3d(nn.Conv3d):\n    r\"\"\"\n    A custom 3D causal convolution layer with feature caching support.\n\n    This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature\n    caching for efficient inference.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int | tuple[int, int, int],\n        stride: int | tuple[int, int, int] = 1,\n        padding: int | tuple[int, int, int] = 0,\n    ) -> None:\n        super().__init__(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n        )\n        self.padding: tuple[int, int, int]\n        # Set up causal padding\n        self._padding: tuple[int, ...] 
= (\n            self.padding[2],\n            self.padding[2],\n            self.padding[1],\n            self.padding[1],\n            2 * self.padding[0],\n            0,\n        )\n        self.padding = (0, 0, 0)\n\n    def forward(self, x, cache_x=None):\n        padding = list(self._padding)\n        if cache_x is not None and self._padding[4] > 0:\n            cache_x = cache_x.to(x.device)\n            x = torch.cat([cache_x, x], dim=2)\n            padding[4] -= cache_x.shape[2]\n        x = F.pad(x, padding)\n        x = (\n            x.to(self.weight.dtype) if current_platform.is_mps() else x\n        )  # casting needed for mps since amp isn't supported\n        return super().forward(x)\n\n\nclass WanRMS_norm(nn.Module):\n    r\"\"\"\n    A custom RMS normalization layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        channel_first: bool = True,\n        images: bool = True,\n        bias: bool = False,\n    ) -> None:\n        super().__init__()\n        broadcastable_dims = (1, 1, 1) if not images else (1, 1)\n        shape = (dim, *broadcastable_dims) if channel_first else (dim,)\n\n        self.channel_first = channel_first\n        self.scale = dim**0.5\n        self.gamma = nn.Parameter(torch.ones(shape))\n        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0\n\n    def forward(self, x):\n        return (\n            F.normalize(x, dim=(1 if self.channel_first else -1))\n            * self.scale\n            * self.gamma\n            + self.bias\n        )\n\n\nclass WanUpsample(nn.Upsample):\n    r\"\"\"\n    Perform upsampling while ensuring the output tensor has the same data type as the input.\n    \"\"\"\n\n    def forward(self, x):\n        return super().forward(x.float()).type_as(x)\n\n\nis_first_frame = None\nfeat_cache = None\nfeat_idx = None\ncache_t = None\nfirst_chunk = None\n\n\ndef bind_context(\n    is_first_frame_var,\n    feat_cache_var,\n    feat_idx_var,\n    cache_t_value,\n    
first_chunk_var,\n):\n    global is_first_frame\n    global feat_cache\n    global feat_idx\n    global cache_t\n    global first_chunk\n    is_first_frame = is_first_frame_var\n    feat_cache = feat_cache_var\n    feat_idx = feat_idx_var\n    cache_t = cache_t_value\n    first_chunk = first_chunk_var\n\n\ndef _ensure_bound():\n    if (\n        is_first_frame is None\n        or feat_cache is None\n        or feat_idx is None\n        or cache_t is None\n        or first_chunk is None\n    ):\n        raise RuntimeError(\"common_utils.bind_context() must be called before use.\")\n\n\ndef resample_forward(self, x):\n    _ensure_bound()\n    b, c, t, h, w = x.size()\n    first_frame = is_first_frame.get()\n    if first_frame:\n        assert t == 1\n    _feat_cache = feat_cache.get()\n    _feat_idx = feat_idx.get()\n    if self.mode == \"upsample3d\":\n        if _feat_cache is not None:\n            idx = _feat_idx\n            if _feat_cache[idx] is None:\n                _feat_cache[idx] = \"Rep\"\n                _feat_idx += 1\n            else:\n                cache_x = x[:, :, -cache_t:, :, :].clone()\n                if (\n                    cache_x.shape[2] < 2\n                    and _feat_cache[idx] is not None\n                    and _feat_cache[idx] != \"Rep\"\n                ):\n                    # cache last frame of last two chunk\n                    cache_x = torch.cat(\n                        [\n                            _feat_cache[idx][:, :, -1, :, :]\n                            .unsqueeze(2)\n                            .to(cache_x.device),\n                            cache_x,\n                        ],\n                        dim=2,\n                    )\n                if (\n                    cache_x.shape[2] < 2\n                    and _feat_cache[idx] is not None\n                    and _feat_cache[idx] == \"Rep\"\n                ):\n                    cache_x = torch.cat(\n                        
[torch.zeros_like(cache_x).to(cache_x.device), cache_x],\n                        dim=2,\n                    )\n                if _feat_cache[idx] == \"Rep\":\n                    x = self.time_conv(x)\n                else:\n                    x = self.time_conv(x, _feat_cache[idx])\n                _feat_cache[idx] = cache_x\n                _feat_idx += 1\n\n                x = x.reshape(b, 2, c, t, h, w)\n                x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)\n                x = x.reshape(b, c, t * 2, h, w)\n            feat_cache.set(_feat_cache)\n            feat_idx.set(_feat_idx)\n        elif not first_frame and hasattr(self, \"time_conv\"):\n            x = self.time_conv(x)\n            x = x.reshape(b, 2, c, t, h, w)\n            x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)\n            x = x.reshape(b, c, t * 2, h, w)\n    t = x.shape[2]\n    x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)\n    x = self.resample(x)\n    x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)\n\n    _feat_cache = feat_cache.get()\n    _feat_idx = feat_idx.get()\n    if self.mode == \"downsample3d\":\n        if _feat_cache is not None:\n            idx = _feat_idx\n            if _feat_cache[idx] is None:\n                _feat_cache[idx] = x.clone()\n                _feat_idx += 1\n            else:\n                cache_x = x[:, :, -1:, :, :].clone()\n                x = self.time_conv(torch.cat([_feat_cache[idx][:, :, -1:, :, :], x], 2))\n                _feat_cache[idx] = cache_x\n                _feat_idx += 1\n            feat_cache.set(_feat_cache)\n            feat_idx.set(_feat_idx)\n        elif not first_frame and hasattr(self, \"time_conv\"):\n            x = self.time_conv(x)\n    return x\n\n\ndef residual_block_forward(self, x):\n    _ensure_bound()\n    # Apply shortcut connection\n    h = self.conv_shortcut(x)\n\n    # First normalization and activation\n    x = self.norm1(x)\n    x = 
self.nonlinearity(x)\n\n    _feat_cache = feat_cache.get()\n    _feat_idx = feat_idx.get()\n    if _feat_cache is not None:\n        idx = _feat_idx\n        cache_x = x[:, :, -cache_t:, :, :].clone()\n        if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:\n            cache_x = torch.cat(\n                [\n                    _feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),\n                    cache_x,\n                ],\n                dim=2,\n            )\n\n        x = self.conv1(x, _feat_cache[idx])\n        _feat_cache[idx] = cache_x\n        _feat_idx += 1\n        feat_cache.set(_feat_cache)\n        feat_idx.set(_feat_idx)\n    else:\n        x = self.conv1(x)\n\n    # Second normalization and activation\n    x = self.norm2(x)\n    x = self.nonlinearity(x)\n\n    # Dropout\n    x = self.dropout(x)\n\n    _feat_cache = feat_cache.get()\n    _feat_idx = feat_idx.get()\n    if _feat_cache is not None:\n        idx = _feat_idx\n        cache_x = x[:, :, -cache_t:, :, :].clone()\n        if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:\n            cache_x = torch.cat(\n                [\n                    _feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),\n                    cache_x,\n                ],\n                dim=2,\n            )\n\n        x = self.conv2(x, _feat_cache[idx])\n        _feat_cache[idx] = cache_x\n        _feat_idx += 1\n        feat_cache.set(_feat_cache)\n        feat_idx.set(_feat_idx)\n    else:\n        x = self.conv2(x)\n\n    # Add residual connection\n    return x + h\n\n\ndef attention_block_forward(self, x):\n    identity = x\n    batch_size, channels, num_frames, height, width = x.size()\n    x = x.permute(0, 2, 1, 3, 4).reshape(\n        batch_size * num_frames, channels, height, width\n    )\n    x = self.norm(x)\n\n    # compute query, key, value\n    qkv = self.to_qkv(x)\n    qkv = qkv.reshape(batch_size * num_frames, 1, channels * 3, -1)\n    qkv = 
qkv.permute(0, 1, 3, 2).contiguous()\n    q, k, v = qkv.chunk(3, dim=-1)\n\n    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)\n\n    x = (\n        x.squeeze(1)\n        .permute(0, 2, 1)\n        .reshape(batch_size * num_frames, channels, height, width)\n    )\n\n    # output projection\n    x = self.proj(x)\n\n    # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]\n    x = x.view(batch_size, num_frames, channels, height, width)\n    x = x.permute(0, 2, 1, 3, 4)\n\n    return x + identity\n\n\ndef mid_block_forward(self, x):\n    # First residual block\n    x = self.resnets[0](x)\n\n    # Process through attention and residual blocks\n    for attn, resnet in zip(self.attentions, self.resnets[1:], strict=True):\n        if attn is not None:\n            x = attn(x)\n\n        x = resnet(x)\n\n    return x\n\n\ndef residual_down_block_forward(self, x):\n    x_copy = x\n    for resnet in self.resnets:\n        x = resnet(x)\n    if self.downsampler is not None:\n        x = self.downsampler(x)\n\n    return x + self.avg_shortcut(x_copy)\n\n\ndef residual_up_block_forward(self, x):\n    if self.avg_shortcut is not None:\n        x_copy = x\n\n    for resnet in self.resnets:\n        x = resnet(x)\n\n    if self.upsampler is not None:\n        x = self.upsampler(x)\n\n    if self.avg_shortcut is not None:\n        x = x + self.avg_shortcut(x_copy)\n\n    return x\n\n\ndef up_block_forward(self, x):\n    for resnet in self.resnets:\n        x = resnet(x)\n\n    if self.upsamplers is not None:\n        x = self.upsamplers[0](x)\n    return x\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py",
    "content": "import math\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_sp_group,\n    get_sp_parallel_rank,\n    get_sp_world_size,\n)\nfrom sglang.multimodal_gen.runtime.layers.activation import get_act_fn\nfrom sglang.multimodal_gen.runtime.models.vaes.parallel.wan_common_utils import (\n    AvgDown3D,\n    DupUp3D,\n    WanCausalConv3d,\n    WanRMS_norm,\n    WanUpsample,\n    attention_block_forward,\n    mid_block_forward,\n    resample_forward,\n    residual_block_forward,\n    residual_down_block_forward,\n    residual_up_block_forward,\n    up_block_forward,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\n\n\ndef tensor_pad(x: torch.Tensor, len_to_pad: int, dim: int = -2):\n    x = torch.cat(\n        [\n            x,\n            torch.zeros(\n                *x.shape[:dim],\n                len_to_pad,\n                *x.shape[dim + 1 :],\n                dtype=x.dtype,\n                device=x.device,\n            ),\n        ],\n        dim=dim,\n    )\n    return x\n\n\ndef tensor_chunk(x: torch.Tensor, dim: int = -2, world_size: int = 1, rank: int = 0):\n    if x is None:\n        return None\n    if world_size <= 1:\n        return x\n    len_to_padding = (int(math.ceil(x.shape[dim] / world_size)) * world_size) - x.shape[\n        dim\n    ]\n    if len_to_padding != 0:\n        x = tensor_pad(x, len_to_padding, dim=dim)\n    return torch.chunk(x, world_size, dim=dim)[rank]\n\n\ndef split_for_parallel_encode(\n    x: torch.Tensor, downsample_count: int, world_size: int, rank: int\n):\n    orig_height = x.shape[-2]\n    expected_height = orig_height // (2**downsample_count)\n    factor = world_size * (2**downsample_count)\n    pad_h = (factor - orig_height % factor) % factor\n    if pad_h:\n        x = F.pad(x, (0, 0, 0, pad_h, 0, 0))\n    expected_local_height = (orig_height + 
pad_h) // (2**downsample_count) // world_size\n    x = tensor_chunk(x, dim=-2, world_size=world_size, rank=rank)\n    return x, expected_height, expected_local_height\n\n\ndef ensure_local_height(x: torch.Tensor, expected_local_height: int | None):\n    if expected_local_height is None:\n        return x\n    if x.shape[-2] < expected_local_height:\n        pad = expected_local_height - x.shape[-2]\n        return F.pad(x, (0, 0, 0, pad, 0, 0))\n    if x.shape[-2] > expected_local_height:\n        return x[..., :expected_local_height, :].contiguous()\n    return x\n\n\ndef split_for_parallel_decode(\n    x: torch.Tensor, upsample_count: int, world_size: int, rank: int\n):\n    expected_height = x.shape[-2] * (2**upsample_count)\n    x = tensor_chunk(x, dim=-2, world_size=world_size, rank=rank)\n    return x, expected_height\n\n\ndef gather_and_trim_height(x: torch.Tensor, expected_height: int | None):\n    if expected_height is None:\n        return x\n    x = get_sp_group().all_gather(x, dim=-2)\n    if x.shape[-2] != expected_height:\n        x = x[..., :expected_height, :].contiguous()\n    return x\n\n\ndef _ensure_recv_buf(\n    recv_buf: torch.Tensor | None, reference: torch.Tensor\n) -> torch.Tensor:\n    if (\n        recv_buf is None\n        or recv_buf.shape != reference.shape\n        or recv_buf.dtype != reference.dtype\n        or recv_buf.device != reference.device\n    ):\n        return torch.empty_like(reference)\n    return recv_buf\n\n\ndef halo_exchange(\n    x: torch.Tensor,\n    height_halo_size: int = 1,\n    recv_top_buf: torch.Tensor | None = None,\n    recv_bottom_buf: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    if height_halo_size == 0:\n        return x, recv_top_buf, recv_bottom_buf\n\n    sp_group = get_sp_group()\n    rank = get_sp_parallel_rank()\n    world_size = get_sp_world_size()\n    group = sp_group.device_group\n    group_ranks = sp_group.ranks\n\n    top_row = x[..., 
:height_halo_size, :].contiguous()\n    bottom_row = x[..., -height_halo_size:, :].contiguous()\n\n    recv_top_buf = _ensure_recv_buf(recv_top_buf, top_row)\n    recv_bottom_buf = _ensure_recv_buf(recv_bottom_buf, bottom_row)\n\n    # use batched P2P operations\n    p2p_ops = []\n\n    if rank > 0:\n        # has previous neighbor, recv previous rank's data to recv_top_buf and send top_row to it.\n        prev_rank = group_ranks[rank - 1]\n        p2p_ops.append(dist.P2POp(dist.irecv, recv_top_buf, prev_rank, group))\n        p2p_ops.append(dist.P2POp(dist.isend, top_row, prev_rank, group))\n    if rank < world_size - 1:\n        # has next neighbor, send bottom_row to next rank and recv next rank's data to recv_bottom_buf.\n        next_rank = group_ranks[rank + 1]\n        p2p_ops.append(dist.P2POp(dist.isend, bottom_row, next_rank, group))\n        p2p_ops.append(dist.P2POp(dist.irecv, recv_bottom_buf, next_rank, group))\n\n    if rank == 0:\n        recv_top_buf.zero_()\n    if rank == world_size - 1:\n        recv_bottom_buf.zero_()\n\n    if p2p_ops:\n        reqs = dist.batch_isend_irecv(p2p_ops)\n        for req in reqs:\n            req.wait()\n\n    return (\n        torch.concat([recv_top_buf, x, recv_bottom_buf], dim=-2),\n        recv_top_buf,\n        recv_bottom_buf,\n    )\n\n\nclass WanDistConv2d(nn.Conv2d):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int | tuple[int, int, int],\n        stride: int | tuple[int, int, int] = 1,\n        padding: int | tuple[int, int, int] = 0,\n        height_padding: tuple[int, int] | None = None,\n    ):\n        super().__init__(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n        )\n\n        self.height_halo_size = (self.kernel_size[-2] - 1) // 2\n        if height_padding is None:\n            height_padding = 
(self.padding[-2], self.padding[-2])\n        self.height_pad_top, self.height_pad_bottom = height_padding\n\n        self.padding: tuple[int, int]\n        if self.height_halo_size > 0:\n            self._padding = (self.padding[1], self.padding[1], 0, 0)\n        else:\n            self._padding = (\n                self.padding[1],\n                self.padding[1],\n                self.padding[0],\n                self.padding[0],\n            )\n\n        self.padding = (0, 0)\n        self._halo_recv_top_buf: torch.Tensor | None = None\n        self._halo_recv_bottom_buf: torch.Tensor | None = None\n        self.rank = get_sp_parallel_rank()\n        self.world_size = get_sp_world_size()\n\n    def forward(self, x):\n        x = F.pad(x, self._padding)\n\n        x_padded, self._halo_recv_top_buf, self._halo_recv_bottom_buf = halo_exchange(\n            x,\n            height_halo_size=self.height_halo_size,\n            recv_top_buf=self._halo_recv_top_buf,\n            recv_bottom_buf=self._halo_recv_bottom_buf,\n        )\n\n        pad_top = self.height_pad_top\n        stride = self.stride[-2]\n        global_start = self.rank * x.shape[-2]\n        if self.height_halo_size > 0 and stride > 1:\n            shift = (global_start - self.height_halo_size + pad_top) % stride\n            if shift:\n                x_padded = x_padded[..., shift:, :]\n                global_start += shift\n\n        out = super().forward(x_padded)\n\n        if self.height_halo_size == 0:\n            return out\n\n        local_height = x.shape[-2]\n        global_height = local_height * self.world_size\n        halo = self.height_halo_size\n        pad_bottom = self.height_pad_bottom\n        kernel = self.kernel_size[-2]\n        min_i = math.ceil(((-pad_top) - (global_start - halo)) / stride)\n        max_i = math.floor(\n            ((global_height - 1 + pad_bottom) - (kernel - 1) - (global_start - halo))\n            / stride\n        )\n        start = max(min_i, 0)\n  
      end = min(max_i + 1, out.shape[-2])\n        if start != 0 or end != out.shape[-2]:\n            out = out[..., start:end, :]\n\n        return out\n\n\nclass WanDistCausalConv3d(nn.Conv3d):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int | tuple[int, int, int],\n        stride: int | tuple[int, int, int] = 1,\n        padding: int | tuple[int, int, int] = 0,\n    ):\n        super().__init__(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n        )\n\n        self.height_pad_top = self.padding[1]\n        self.height_pad_bottom = self.padding[1]\n        self.height_halo_size = (self.kernel_size[-2] - 1) // 2\n\n        self.padding: tuple[int, int, int]\n        # Set up causal padding and let the halo exchange control height padding\n        if self.height_halo_size > 0:\n            self._padding: tuple[int, ...] = (\n                self.padding[2],\n                self.padding[2],\n                0,\n                0,\n                2 * self.padding[0],\n                0,\n            )\n        else:\n            self._padding: tuple[int, ...] 
= (\n                self.padding[2],\n                self.padding[2],\n                self.padding[1],\n                self.padding[1],\n                2 * self.padding[0],\n                0,\n            )\n        self.padding = (0, 0, 0)\n        self._halo_recv_top_buf: torch.Tensor | None = None\n        self._halo_recv_bottom_buf: torch.Tensor | None = None\n        self.rank = get_sp_parallel_rank()\n        self.world_size = get_sp_world_size()\n\n    def forward(self, x, cache_x=None):\n        padding = list(self._padding)\n        if cache_x is not None and self._padding[4] > 0:\n            cache_x = cache_x.to(x.device)\n            x = torch.cat([cache_x, x], dim=2)\n            padding[4] -= cache_x.shape[2]\n\n        x = F.pad(x, padding)\n\n        x = (\n            x.to(self.weight.dtype) if current_platform.is_mps() else x\n        )  # casting needed for mps since amp isn't supported\n\n        x_padded, self._halo_recv_top_buf, self._halo_recv_bottom_buf = halo_exchange(\n            x,\n            height_halo_size=self.height_halo_size,\n            recv_top_buf=self._halo_recv_top_buf,\n            recv_bottom_buf=self._halo_recv_bottom_buf,\n        )\n\n        pad_top = self.height_pad_top\n        stride = self.stride[-2]\n        global_start = self.rank * x.shape[-2]\n        if self.height_halo_size > 0 and stride > 1:\n            shift = (global_start - self.height_halo_size + pad_top) % stride\n            if shift:\n                x_padded = x_padded[..., shift:, :]\n                global_start += shift\n\n        out = super().forward(x_padded)\n\n        if self.height_halo_size == 0:\n            return out\n\n        local_height = x.shape[-2]\n        global_height = local_height * self.world_size\n        halo = self.height_halo_size\n        pad_bottom = self.height_pad_bottom\n        kernel = self.kernel_size[-2]\n        min_i = math.ceil(((-pad_top) - (global_start - halo)) / stride)\n        max_i = 
math.floor(\n            ((global_height - 1 + pad_bottom) - (kernel - 1) - (global_start - halo))\n            / stride\n        )\n        start = max(min_i, 0)\n        end = min(max_i + 1, out.shape[-2])\n        if start != 0 or end != out.shape[-2]:\n            out = out[..., start:end, :]\n\n        return out\n\n\nclass WanDistZeroPad2d(nn.Module):\n    \"\"\"Apply 2D padding once globally across sequence-parallel height splits.\"\"\"\n\n    def __init__(self, padding: tuple[int, int, int, int]) -> None:\n        super().__init__()\n        self.padding = padding  # (left, right, top, bottom)\n        self.rank = get_sp_parallel_rank()\n        self.world_size = get_sp_world_size()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        left, right, top, bottom = self.padding\n        if self.world_size <= 1:\n            return F.pad(x, (left, right, top, bottom))\n        # Only the first/last rank should contribute global top/bottom padding.\n        top = top if self.rank == 0 else 0\n        bottom = bottom if self.rank == self.world_size - 1 else 0\n        return F.pad(x, (left, right, top, bottom))\n\n\nclass WanDistResample(nn.Module):\n    r\"\"\"\n    A custom resampling module for 2D and 3D data used for parallel decoding.\n\n    Args:\n        dim (int): The number of input/output channels.\n        mode (str): The resampling mode. 
Must be one of:\n            - 'none': No resampling (identity operation).\n            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.\n            - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.\n            - 'downsample2d': 2D downsampling with zero-padding and convolution.\n            - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.\n    \"\"\"\n\n    def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:\n        super().__init__()\n        self.dim = dim\n        self.mode = mode\n\n        # default to dim // 2\n        if upsample_out_dim is None:\n            upsample_out_dim = dim // 2\n\n        # layers\n        # We support parallel encode/decode; downsample uses halo exchange as well.\n        if mode == \"upsample2d\":\n            self.resample = nn.Sequential(\n                WanUpsample(scale_factor=(2.0, 2.0), mode=\"nearest-exact\"),\n                WanDistConv2d(dim, upsample_out_dim, 3, padding=1),\n            )\n        elif mode == \"upsample3d\":\n            self.resample = nn.Sequential(\n                WanUpsample(scale_factor=(2.0, 2.0), mode=\"nearest-exact\"),\n                WanDistConv2d(dim, upsample_out_dim, 3, padding=1),\n            )\n            self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))\n\n        elif mode == \"downsample2d\":\n            self.resample = nn.Sequential(\n                WanDistZeroPad2d((0, 1, 0, 0)),\n                WanDistConv2d(dim, dim, 3, stride=(2, 2), height_padding=(0, 1)),\n            )\n        elif mode == \"downsample3d\":\n            self.resample = nn.Sequential(\n                WanDistZeroPad2d((0, 1, 0, 0)),\n                WanDistConv2d(dim, dim, 3, stride=(2, 2), height_padding=(0, 1)),\n            )\n            self.time_conv = WanCausalConv3d(\n                dim, dim, (3, 1, 1), 
stride=(2, 1, 1), padding=(0, 0, 0)\n            )\n\n        else:\n            self.resample = nn.Identity()\n\n    def forward(self, x):\n        return resample_forward(self, x)\n\n\nclass WanDistResidualBlock(nn.Module):\n    r\"\"\"\n    A custom residual block module.\n\n    Args:\n        in_dim (int): Number of input channels.\n        out_dim (int): Number of output channels.\n        dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.\n        non_linearity (str, optional): Type of non-linearity to use. Default is \"silu\".\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        out_dim: int,\n        dropout: float = 0.0,\n        non_linearity: str = \"silu\",\n    ) -> None:\n        super().__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.nonlinearity = get_act_fn(non_linearity)\n\n        # layers\n        self.norm1 = WanRMS_norm(in_dim, images=False)\n        self.conv1 = WanDistCausalConv3d(in_dim, out_dim, 3, padding=1)\n        self.norm2 = WanRMS_norm(out_dim, images=False)\n        self.dropout = nn.Dropout(dropout)\n        self.conv2 = WanDistCausalConv3d(out_dim, out_dim, 3, padding=1)\n        self.conv_shortcut = (\n            WanDistCausalConv3d(in_dim, out_dim, 1)\n            if in_dim != out_dim\n            else nn.Identity()\n        )\n\n    def forward(self, x):\n        return residual_block_forward(self, x)\n\n\nclass WanDistAttentionBlock(nn.Module):\n    r\"\"\"\n    Causal self-attention with a single head.\n\n    Args:\n        dim (int): The number of channels in the input tensor.\n    \"\"\"\n\n    def __init__(self, dim) -> None:\n        super().__init__()\n        self.dim = dim\n\n        # layers\n        self.norm = WanRMS_norm(dim)\n        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)\n        self.proj = nn.Conv2d(dim, dim, 1)\n        self.rank = get_sp_parallel_rank()\n        self.world_size = get_sp_world_size()\n        
self.sp_group = get_sp_group()\n\n    def forward(self, x):\n        if self.world_size > 1:\n            x = self.sp_group.all_gather(x, dim=-2)\n            x = x.contiguous()\n        x = attention_block_forward(self, x)\n        if self.world_size > 1:\n            x = torch.chunk(x, self.world_size, dim=-2)[self.rank]\n\n        return x\n\n\nclass WanDistMidBlock(nn.Module):\n    \"\"\"\n    Middle block for WanVAE encoder and decoder.\n\n    Args:\n        dim (int): Number of input/output channels.\n        dropout (float): Dropout rate.\n        non_linearity (str): Type of non-linearity to use.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        dropout: float = 0.0,\n        non_linearity: str = \"silu\",\n        num_layers: int = 1,\n    ):\n        super().__init__()\n        self.dim = dim\n\n        # Create the components\n        resnets = [WanDistResidualBlock(dim, dim, dropout, non_linearity)]\n        attentions = []\n        for _ in range(num_layers):\n            attentions.append(WanDistAttentionBlock(dim))\n            resnets.append(WanDistResidualBlock(dim, dim, dropout, non_linearity))\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        self.gradient_checkpointing = False\n\n    def forward(self, x):\n        return mid_block_forward(self, x)\n\n\nclass WanDistResidualDownBlock(nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        dropout,\n        num_res_blocks,\n        temperal_downsample=False,\n        down_flag=False,\n    ):\n        super().__init__()\n\n        # Shortcut path with downsample\n        self.avg_shortcut = AvgDown3D(\n            in_dim,\n            out_dim,\n            factor_t=2 if temperal_downsample else 1,\n            factor_s=2 if down_flag else 1,\n        )\n\n        # Main path with residual blocks and downsample\n        resnets = []\n        for _ in range(num_res_blocks):\n       
     resnets.append(WanDistResidualBlock(in_dim, out_dim, dropout))\n            in_dim = out_dim\n        self.resnets = nn.ModuleList(resnets)\n\n        # Add the final downsample block\n        if down_flag:\n            mode = \"downsample3d\" if temperal_downsample else \"downsample2d\"\n            self.downsampler = WanDistResample(out_dim, mode=mode)\n        else:\n            self.downsampler = None\n\n    def forward(self, x):\n        return residual_down_block_forward(self, x)\n\n\nclass WanDistResidualUpBlock(nn.Module):\n    \"\"\"\n    A block that handles upsampling for the WanVAE decoder.\n    Args:\n        in_dim (int): Input dimension\n        out_dim (int): Output dimension\n        num_res_blocks (int): Number of residual blocks\n        dropout (float): Dropout rate\n        temperal_upsample (bool): Whether to upsample on temporal dimension\n        up_flag (bool): Whether to upsample or not\n        non_linearity (str): Type of non-linearity to use\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        out_dim: int,\n        num_res_blocks: int,\n        dropout: float = 0.0,\n        temperal_upsample: bool = False,\n        up_flag: bool = False,\n        non_linearity: str = \"silu\",\n    ):\n        super().__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n\n        if up_flag:\n            self.avg_shortcut = DupUp3D(\n                in_dim,\n                out_dim,\n                factor_t=2 if temperal_upsample else 1,\n                factor_s=2,\n            )\n        else:\n            self.avg_shortcut = None\n\n        # create residual blocks\n        resnets = []\n        current_dim = in_dim\n        for _ in range(num_res_blocks + 1):\n            resnets.append(\n                WanDistResidualBlock(current_dim, out_dim, dropout, non_linearity)\n            )\n            current_dim = out_dim\n\n        self.resnets = nn.ModuleList(resnets)\n\n        # Add upsampling 
layer if needed\n        if up_flag:\n            upsample_mode = \"upsample3d\" if temperal_upsample else \"upsample2d\"\n            self.upsampler = WanDistResample(\n                out_dim, mode=upsample_mode, upsample_out_dim=out_dim\n            )\n        else:\n            self.upsampler = None\n\n        self.gradient_checkpointing = False\n\n    def forward(self, x):\n        return residual_up_block_forward(self, x)\n\n\nclass WanDistUpBlock(nn.Module):\n    \"\"\"\n    A block that handles upsampling for the WanVAE decoder.\n\n    Args:\n        in_dim (int): Input dimension\n        out_dim (int): Output dimension\n        num_res_blocks (int): Number of residual blocks\n        dropout (float): Dropout rate\n        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')\n        non_linearity (str): Type of non-linearity to use\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        out_dim: int,\n        num_res_blocks: int,\n        dropout: float = 0.0,\n        upsample_mode: str | None = None,\n        non_linearity: str = \"silu\",\n    ):\n        super().__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n\n        # Create layers list\n        resnets = []\n        # Add residual blocks and attention if needed\n        current_dim = in_dim\n        for _ in range(num_res_blocks + 1):\n            resnets.append(\n                WanDistResidualBlock(current_dim, out_dim, dropout, non_linearity)\n            )\n            current_dim = out_dim\n\n        self.resnets = nn.ModuleList(resnets)\n\n        # Add upsampling layer if needed\n        self.upsamplers = None\n        if upsample_mode is not None:\n            self.upsamplers = nn.ModuleList(\n                [WanDistResample(out_dim, mode=upsample_mode)]\n            )\n\n        self.gradient_checkpointing = False\n\n    def forward(self, x):\n        return up_block_forward(self, x)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vaes/wanvae.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\n# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport contextvars\nfrom contextlib import contextmanager\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nfrom einops import rearrange\n\nfrom sglang.multimodal_gen.configs.models.vaes import WanVAEConfig\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_sp_parallel_rank,\n    get_sp_world_size,\n)\nfrom sglang.multimodal_gen.runtime.layers.activation import get_act_fn\nfrom sglang.multimodal_gen.runtime.models.vaes.common import (\n    DiagonalGaussianDistribution,\n    ParallelTiledVAE,\n)\nfrom sglang.multimodal_gen.runtime.models.vaes.parallel.wan_common_utils import (\n    AvgDown3D,\n    DupUp3D,\n    WanCausalConv3d,\n    WanRMS_norm,\n    WanUpsample,\n    attention_block_forward,\n    bind_context,\n    mid_block_forward,\n    resample_forward,\n    residual_block_forward,\n    residual_down_block_forward,\n    residual_up_block_forward,\n    up_block_forward,\n)\nfrom sglang.multimodal_gen.runtime.models.vaes.parallel.wan_dist_utils import (\n    WanDistAttentionBlock,\n    WanDistCausalConv3d,\n    WanDistMidBlock,\n    WanDistResample,\n    WanDistResidualBlock,\n    WanDistResidualDownBlock,\n    WanDistResidualUpBlock,\n    WanDistUpBlock,\n 
   ensure_local_height,\n    gather_and_trim_height,\n    split_for_parallel_decode,\n    split_for_parallel_encode,\n)\n\nCACHE_T = 2\n\nis_first_frame = contextvars.ContextVar(\"is_first_frame\", default=False)\nfeat_cache = contextvars.ContextVar(\"feat_cache\", default=None)\nfeat_idx = contextvars.ContextVar(\"feat_idx\", default=0)\nfirst_chunk = contextvars.ContextVar(\"first_chunk\", default=None)\n\nbind_context(is_first_frame, feat_cache, feat_idx, CACHE_T, first_chunk)\n\n\n@contextmanager\ndef forward_context(\n    first_frame_arg=False, feat_cache_arg=None, feat_idx_arg=None, first_chunk_arg=None\n):\n    is_first_frame_token = is_first_frame.set(first_frame_arg)\n    feat_cache_token = feat_cache.set(feat_cache_arg)\n    feat_idx_token = feat_idx.set(feat_idx_arg)\n    first_chunk_token = first_chunk.set(first_chunk_arg)\n    try:\n        yield\n    finally:\n        is_first_frame.reset(is_first_frame_token)\n        feat_cache.reset(feat_cache_token)\n        feat_idx.reset(feat_idx_token)\n        first_chunk.reset(first_chunk_token)\n\n\nclass WanResample(nn.Module):\n    r\"\"\"\n    A custom resampling module for 2D and 3D data.\n\n    Args:\n        dim (int): The number of input/output channels.\n        mode (str): The resampling mode. 
Must be one of:\n            - 'none': No resampling (identity operation).\n            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.\n            - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.\n            - 'downsample2d': 2D downsampling with zero-padding and convolution.\n            - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.\n    \"\"\"\n\n    def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:\n        super().__init__()\n        self.dim = dim\n        self.mode = mode\n\n        # default to dim //2\n        if upsample_out_dim is None:\n            upsample_out_dim = dim // 2\n\n        # layers\n        if mode == \"upsample2d\":\n            self.resample = nn.Sequential(\n                WanUpsample(scale_factor=(2.0, 2.0), mode=\"nearest-exact\"),\n                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),\n            )\n        elif mode == \"upsample3d\":\n            self.resample = nn.Sequential(\n                WanUpsample(scale_factor=(2.0, 2.0), mode=\"nearest-exact\"),\n                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),\n            )\n            self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))\n\n        elif mode == \"downsample2d\":\n            self.resample = nn.Sequential(\n                nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))\n            )\n        elif mode == \"downsample3d\":\n            self.resample = nn.Sequential(\n                nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))\n            )\n            self.time_conv = WanCausalConv3d(\n                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)\n            )\n\n        else:\n            self.resample = nn.Identity()\n\n    def forward(self, x):\n        return resample_forward(self, x)\n\n\nclass 
WanResidualBlock(nn.Module):\n    r\"\"\"\n    A custom residual block module.\n\n    Args:\n        in_dim (int): Number of input channels.\n        out_dim (int): Number of output channels.\n        dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.\n        non_linearity (str, optional): Type of non-linearity to use. Default is \"silu\".\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        out_dim: int,\n        dropout: float = 0.0,\n        non_linearity: str = \"silu\",\n    ) -> None:\n        super().__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.nonlinearity = get_act_fn(non_linearity)\n\n        # layers\n        self.norm1 = WanRMS_norm(in_dim, images=False)\n        self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)\n        self.norm2 = WanRMS_norm(out_dim, images=False)\n        self.dropout = nn.Dropout(dropout)\n        self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)\n        self.conv_shortcut = (\n            WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()\n        )\n\n    def forward(self, x):\n        return residual_block_forward(self, x)\n\n\nclass WanAttentionBlock(nn.Module):\n    r\"\"\"\n    Causal self-attention with a single head.\n\n    Args:\n        dim (int): The number of channels in the input tensor.\n    \"\"\"\n\n    def __init__(self, dim) -> None:\n        super().__init__()\n        self.dim = dim\n\n        # layers\n        self.norm = WanRMS_norm(dim)\n        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)\n        self.proj = nn.Conv2d(dim, dim, 1)\n\n    def forward(self, x):\n        return attention_block_forward(self, x)\n\n\nclass WanMidBlock(nn.Module):\n    \"\"\"\n    Middle block for WanVAE encoder and decoder.\n\n    Args:\n        dim (int): Number of input/output channels.\n        dropout (float): Dropout rate.\n        non_linearity (str): Type of non-linearity to use.\n   
 \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        dropout: float = 0.0,\n        non_linearity: str = \"silu\",\n        num_layers: int = 1,\n    ):\n        super().__init__()\n        self.dim = dim\n\n        # Create the components\n        resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]\n        attentions = []\n        for _ in range(num_layers):\n            attentions.append(WanAttentionBlock(dim))\n            resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        self.gradient_checkpointing = False\n\n    def forward(self, x):\n        return mid_block_forward(self, x)\n\n\nclass WanResidualDownBlock(nn.Module):\n\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        dropout,\n        num_res_blocks,\n        temperal_downsample=False,\n        down_flag=False,\n    ):\n        super().__init__()\n\n        # Shortcut path with downsample\n        self.avg_shortcut = AvgDown3D(\n            in_dim,\n            out_dim,\n            factor_t=2 if temperal_downsample else 1,\n            factor_s=2 if down_flag else 1,\n        )\n\n        # Main path with residual blocks and downsample\n        resnets = []\n        for _ in range(num_res_blocks):\n            resnets.append(WanResidualBlock(in_dim, out_dim, dropout))\n            in_dim = out_dim\n        self.resnets = nn.ModuleList(resnets)\n\n        # Add the final downsample block\n        if down_flag:\n            mode = \"downsample3d\" if temperal_downsample else \"downsample2d\"\n            self.downsampler = WanResample(out_dim, mode=mode)\n        else:\n            self.downsampler = None\n\n    def forward(self, x):\n        return residual_down_block_forward(self, x)\n\n\nclass WanEncoder3d(nn.Module):\n    r\"\"\"\n    A 3D encoder module.\n\n    Args:\n        dim (int): The base number of channels in 
the first layer.\n        z_dim (int): The dimensionality of the latent space.\n        dim_mult (list of int): Multipliers for the number of channels in each block.\n        num_res_blocks (int): Number of residual blocks in each block.\n        attn_scales (list of float): Scales at which to apply attention mechanisms.\n        temperal_downsample (list of bool): Whether to downsample temporally in each block.\n        dropout (float): Dropout rate for the dropout layers.\n        non_linearity (str): Type of non-linearity to use.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int = 3,\n        dim=128,\n        z_dim=4,\n        dim_mult=(1, 2, 4, 4),\n        num_res_blocks=2,\n        attn_scales=(),\n        temperal_downsample=(True, True, False),\n        dropout=0.0,\n        non_linearity: str = \"silu\",\n        is_residual: bool = False,  # wan 2.2 vae use a residual downblock\n        use_parallel_encode: bool = False,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.z_dim = z_dim\n        dim_mult = list(dim_mult)\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = list(attn_scales)\n        self.temperal_downsample = list(temperal_downsample)\n        self.nonlinearity = get_act_fn(non_linearity)\n        self.use_parallel_encode = use_parallel_encode\n        self.downsample_count = max(len(dim_mult) - 1, 0)\n\n        # dimensions\n        dims = [dim * u for u in [1] + dim_mult]\n        scale = 1.0\n\n        world_size = 1\n        if dist.is_initialized():\n            world_size = get_sp_world_size()\n\n        if use_parallel_encode and world_size > 1:\n            CausalConv3d = WanDistCausalConv3d\n            ResidualDownBlock = WanDistResidualDownBlock\n            ResidualBlock = WanDistResidualBlock\n            AttentionBlock = WanDistAttentionBlock\n            Resample = WanDistResample\n            MidBlock = WanDistMidBlock\n    
    else:\n            CausalConv3d = WanCausalConv3d\n            ResidualDownBlock = WanResidualDownBlock\n            ResidualBlock = WanResidualBlock\n            AttentionBlock = WanAttentionBlock\n            Resample = WanResample\n            MidBlock = WanMidBlock\n\n        # init block\n        self.conv_in = CausalConv3d(in_channels, dims[0], 3, padding=1)\n\n        # downsample blocks\n        self.down_blocks = nn.ModuleList([])\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:], strict=True)):\n            # residual (+attention) blocks\n            if is_residual:\n                self.down_blocks.append(\n                    ResidualDownBlock(\n                        in_dim,\n                        out_dim,\n                        dropout,\n                        num_res_blocks,\n                        temperal_downsample=(\n                            temperal_downsample[i] if i != len(dim_mult) - 1 else False\n                        ),\n                        down_flag=i != len(dim_mult) - 1,\n                    )\n                )\n            else:\n                for _ in range(num_res_blocks):\n                    self.down_blocks.append(ResidualBlock(in_dim, out_dim, dropout))\n                    if scale in attn_scales:\n                        self.down_blocks.append(AttentionBlock(out_dim))\n                    in_dim = out_dim\n\n                # downsample block\n                if i != len(dim_mult) - 1:\n                    mode = \"downsample3d\" if temperal_downsample[i] else \"downsample2d\"\n                    self.down_blocks.append(Resample(out_dim, mode=mode))\n                    scale /= 2.0\n\n        # middle blocks\n        self.mid_block = MidBlock(out_dim, dropout, non_linearity, num_layers=1)\n\n        # output blocks\n        self.norm_out = WanRMS_norm(out_dim, images=False)\n        self.conv_out = CausalConv3d(out_dim, z_dim, 3, padding=1)\n\n        self.gradient_checkpointing = 
False\n        self.world_size = 1\n        self.rank = 0\n        if dist.is_initialized():\n            self.world_size = get_sp_world_size()\n            self.rank = get_sp_parallel_rank()\n\n    def forward(self, x):\n        expected_local_height = None\n        expected_height = None\n        if self.use_parallel_encode and self.world_size > 1:\n            x, expected_height, expected_local_height = split_for_parallel_encode(\n                x, self.downsample_count, self.world_size, self.rank\n            )\n\n        _feat_cache = feat_cache.get()\n        _feat_idx = feat_idx.get()\n        if _feat_cache is not None:\n            idx = _feat_idx\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat(\n                    [\n                        _feat_cache[idx][:, :, -1, :, :]\n                        .unsqueeze(2)\n                        .to(cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n                )\n            x = self.conv_in(x, _feat_cache[idx])\n            _feat_cache[idx] = cache_x\n            _feat_idx += 1\n            feat_cache.set(_feat_cache)\n            feat_idx.set(_feat_idx)\n        else:\n            x = self.conv_in(x)\n\n        ## downsamples\n        for layer in self.down_blocks:\n            x = layer(x)\n\n        ## middle\n        if self.use_parallel_encode and self.world_size > 1:\n            x = ensure_local_height(x, expected_local_height)\n        x = self.mid_block(x)\n\n        ## head\n        x = self.norm_out(x)\n        x = self.nonlinearity(x)\n\n        _feat_cache = feat_cache.get()\n        _feat_idx = feat_idx.get()\n        if _feat_cache is not None:\n            idx = _feat_idx\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and 
_feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat(\n                    [\n                        _feat_cache[idx][:, :, -1, :, :]\n                        .unsqueeze(2)\n                        .to(cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n                )\n            x = self.conv_out(x, _feat_cache[idx])\n            _feat_cache[idx] = cache_x\n            _feat_idx += 1\n            feat_cache.set(_feat_cache)\n            feat_idx.set(_feat_idx)\n        else:\n            x = self.conv_out(x)\n\n        if self.use_parallel_encode and self.world_size > 1:\n            x = gather_and_trim_height(x, expected_height)\n        return x\n\n\n# adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_wan.py\nclass WanResidualUpBlock(nn.Module):\n    \"\"\"\n    A block that handles upsampling for the WanVAE decoder.\n    Args:\n        in_dim (int): Input dimension\n        out_dim (int): Output dimension\n        num_res_blocks (int): Number of residual blocks\n        dropout (float): Dropout rate\n        temperal_upsample (bool): Whether to upsample on temporal dimension\n        up_flag (bool): Whether to upsample or not\n        non_linearity (str): Type of non-linearity to use\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        out_dim: int,\n        num_res_blocks: int,\n        dropout: float = 0.0,\n        temperal_upsample: bool = False,\n        up_flag: bool = False,\n        non_linearity: str = \"silu\",\n    ):\n        super().__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n\n        if up_flag:\n            self.avg_shortcut = DupUp3D(\n                in_dim,\n                out_dim,\n                factor_t=2 if temperal_upsample else 1,\n                factor_s=2,\n            )\n        else:\n     
       self.avg_shortcut = None\n\n        # create residual blocks\n        resnets = []\n        current_dim = in_dim\n        for _ in range(num_res_blocks + 1):\n            resnets.append(\n                WanResidualBlock(current_dim, out_dim, dropout, non_linearity)\n            )\n            current_dim = out_dim\n\n        self.resnets = nn.ModuleList(resnets)\n\n        # Add upsampling layer if needed\n        if up_flag:\n            upsample_mode = \"upsample3d\" if temperal_upsample else \"upsample2d\"\n            self.upsampler = WanResample(\n                out_dim, mode=upsample_mode, upsample_out_dim=out_dim\n            )\n        else:\n            self.upsampler = None\n\n        self.gradient_checkpointing = False\n\n    def forward(self, x):\n        return residual_up_block_forward(self, x)\n\n\nclass WanUpBlock(nn.Module):\n    \"\"\"\n    A block that handles upsampling for the WanVAE decoder.\n\n    Args:\n        in_dim (int): Input dimension\n        out_dim (int): Output dimension\n        num_res_blocks (int): Number of residual blocks\n        dropout (float): Dropout rate\n        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')\n        non_linearity (str): Type of non-linearity to use\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        out_dim: int,\n        num_res_blocks: int,\n        dropout: float = 0.0,\n        upsample_mode: str | None = None,\n        non_linearity: str = \"silu\",\n    ):\n        super().__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n\n        # Create layers list\n        resnets = []\n        # Add residual blocks and attention if needed\n        current_dim = in_dim\n        for _ in range(num_res_blocks + 1):\n            resnets.append(\n                WanResidualBlock(current_dim, out_dim, dropout, non_linearity)\n            )\n            current_dim = out_dim\n\n        self.resnets = 
nn.ModuleList(resnets)\n\n        # Add upsampling layer if needed\n        self.upsamplers = None\n        if upsample_mode is not None:\n            self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])\n\n        self.gradient_checkpointing = False\n\n    def forward(self, x):\n        return up_block_forward(self, x)\n\n\nclass WanDecoder3d(nn.Module):\n    r\"\"\"\n    A 3D decoder module.\n\n    Args:\n        dim (int): The base number of channels in the first layer.\n        z_dim (int): The dimensionality of the latent space.\n        dim_mult (list of int): Multipliers for the number of channels in each block.\n        num_res_blocks (int): Number of residual blocks in each block.\n        attn_scales (list of float): Scales at which to apply attention mechanisms.\n        temperal_upsample (list of bool): Whether to upsample temporally in each block.\n        dropout (float): Dropout rate for the dropout layers.\n        non_linearity (str): Type of non-linearity to use.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim=128,\n        z_dim=4,\n        dim_mult=(1, 2, 4, 4),\n        num_res_blocks=2,\n        attn_scales=(),\n        temperal_upsample=(False, True, True),\n        dropout=0.0,\n        non_linearity: str = \"silu\",\n        out_channels: int = 3,\n        is_residual: bool = False,\n        use_parallel_decode: bool = False,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.z_dim = z_dim\n        dim_mult = list(dim_mult)\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = list(attn_scales)\n        self.temperal_upsample = list(temperal_upsample)\n\n        self.nonlinearity = get_act_fn(non_linearity)\n        self.use_parallel_decode = use_parallel_decode\n        self.upsample_count = 0\n\n        # dimensions\n        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]\n\n        world_size = 1\n        if 
dist.is_initialized():\n            world_size = get_sp_world_size()\n\n        if use_parallel_decode and world_size > 1:\n            CausalConv3d = WanDistCausalConv3d\n            MidBlock = WanDistMidBlock\n            ResidualUpBlock = WanDistResidualUpBlock\n            UpBlock = WanDistUpBlock\n        else:\n            CausalConv3d = WanCausalConv3d\n            MidBlock = WanMidBlock\n            ResidualUpBlock = WanResidualUpBlock\n            UpBlock = WanUpBlock\n\n        # init block\n        self.conv_in = CausalConv3d(z_dim, dims[0], 3, padding=1)\n\n        # middle blocks\n        self.mid_block = MidBlock(dims[0], dropout, non_linearity, num_layers=1)\n\n        # upsample blocks\n        self.upsample_count = 0\n        self.up_blocks = nn.ModuleList([])\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:], strict=True)):\n            # residual (+attention) blocks\n            if i > 0 and not is_residual:\n                # wan vae 2.1\n                in_dim = in_dim // 2\n\n            # determine if we need upsampling\n            up_flag = i != len(dim_mult) - 1\n            # determine upsampling mode, if not upsampling, set to None\n            upsample_mode = None\n            if up_flag and temperal_upsample[i]:\n                upsample_mode = \"upsample3d\"\n            elif up_flag:\n                upsample_mode = \"upsample2d\"\n\n            # Create and add the upsampling block\n            if is_residual:\n                up_block = ResidualUpBlock(\n                    in_dim=in_dim,\n                    out_dim=out_dim,\n                    num_res_blocks=num_res_blocks,\n                    dropout=dropout,\n                    temperal_upsample=temperal_upsample[i] if up_flag else False,\n                    up_flag=up_flag,\n                    non_linearity=non_linearity,\n                )\n            else:\n                up_block = UpBlock(\n                    in_dim=in_dim,\n                    
out_dim=out_dim,\n                    num_res_blocks=num_res_blocks,\n                    dropout=dropout,\n                    upsample_mode=upsample_mode,\n                    non_linearity=non_linearity,\n                )\n            self.up_blocks.append(up_block)\n            if up_flag:\n                self.upsample_count += 1\n\n        # output blocks\n        self.norm_out = WanRMS_norm(out_dim, images=False)\n        self.conv_out = CausalConv3d(out_dim, out_channels, 3, padding=1)\n\n        self.gradient_checkpointing = False\n        self.world_size = 1\n        self.rank = 0\n        if dist.is_initialized():\n            self.world_size = get_sp_world_size()\n            self.rank = get_sp_parallel_rank()\n\n    def forward(self, x):\n        expected_height = None\n        if self.use_parallel_decode and self.world_size > 1:\n            x, expected_height = split_for_parallel_decode(\n                x, self.upsample_count, self.world_size, self.rank\n            )\n\n        ## conv1\n        _feat_cache = feat_cache.get()\n        _feat_idx = feat_idx.get()\n        if _feat_cache is not None:\n            idx = _feat_idx\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat(\n                    [\n                        _feat_cache[idx][:, :, -1, :, :]\n                        .unsqueeze(2)\n                        .to(cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n                )\n            x = self.conv_in(x, _feat_cache[idx])\n            _feat_cache[idx] = cache_x\n            _feat_idx += 1\n            feat_cache.set(_feat_cache)\n            feat_idx.set(_feat_idx)\n        else:\n            x = self.conv_in(x)\n\n        ## middle\n        x = self.mid_block(x)\n\n        ## upsamples\n        for up_block in 
self.up_blocks:\n            x = up_block(x)\n\n        ## head\n        x = self.norm_out(x)\n        x = self.nonlinearity(x)\n        _feat_cache = feat_cache.get()\n        _feat_idx = feat_idx.get()\n        if _feat_cache is not None:\n            idx = _feat_idx\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat(\n                    [\n                        _feat_cache[idx][:, :, -1, :, :]\n                        .unsqueeze(2)\n                        .to(cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n                )\n            x = self.conv_out(x, _feat_cache[idx])\n            _feat_cache[idx] = cache_x\n            _feat_idx += 1\n            feat_cache.set(_feat_cache)\n            feat_idx.set(_feat_idx)\n        else:\n            x = self.conv_out(x)\n\n        if self.use_parallel_decode and self.world_size > 1:\n            x = gather_and_trim_height(x, expected_height)\n        return x\n\n\ndef patchify(x, patch_size):\n    if patch_size == 1:\n        return x\n\n    if x.dim() == 4:\n        x = rearrange(x, \"b c (h q) (w r) -> b (c r q) h w\", q=patch_size, r=patch_size)\n    elif x.dim() == 5:\n        x = rearrange(\n            x,\n            \"b c f (h q) (w r) -> b (c r q) f h w\",\n            q=patch_size,\n            r=patch_size,\n        )\n    else:\n        raise ValueError(f\"Invalid input shape: {x.shape}\")\n\n    return x\n\n\ndef unpatchify(x, patch_size):\n    if patch_size == 1:\n        return x\n\n    if x.dim() == 4:\n        x = rearrange(x, \"b (c r q) h w -> b c (h q) (w r)\", q=patch_size, r=patch_size)\n    elif x.dim() == 5:\n        x = rearrange(\n            x,\n            \"b (c r q) f h w -> b c f (h q) (w r)\",\n            q=patch_size,\n            r=patch_size,\n        )\n\n  
  return x\n\n\nclass AutoencoderKLWan(ParallelTiledVAE):\n    r\"\"\"\n    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.\n    Introduced in [Wan 2.1].\n    \"\"\"\n\n    _supports_gradient_checkpointing = False\n\n    def __init__(\n        self,\n        config: WanVAEConfig,\n    ) -> None:\n        nn.Module.__init__(self)\n        ParallelTiledVAE.__init__(self, config)\n\n        self.z_dim = config.z_dim\n        self.temperal_downsample = list(config.temperal_downsample)\n        self.temperal_upsample = list(config.temperal_downsample)[::-1]\n\n        if config.decoder_base_dim is None:\n            decoder_base_dim = config.base_dim\n        else:\n            decoder_base_dim = config.decoder_base_dim\n\n        self.latents_mean = list(config.latents_mean)\n        self.latents_std = list(config.latents_std)\n        self.shift_factor = config.shift_factor\n        self.use_parallel_encode = getattr(config, \"use_parallel_encode\", False)\n        self.use_parallel_decode = getattr(config, \"use_parallel_decode\", False)\n\n        if config.load_encoder:\n            self.encoder = WanEncoder3d(\n                in_channels=config.in_channels,\n                dim=config.base_dim,\n                z_dim=self.z_dim * 2,\n                dim_mult=config.dim_mult,\n                num_res_blocks=config.num_res_blocks,\n                attn_scales=config.attn_scales,\n                temperal_downsample=self.temperal_downsample,\n                dropout=config.dropout,\n                is_residual=config.is_residual,\n                use_parallel_encode=self.use_parallel_encode,\n            )\n        self.quant_conv = WanCausalConv3d(self.z_dim * 2, self.z_dim * 2, 1)\n        self.post_quant_conv = WanCausalConv3d(self.z_dim, self.z_dim, 1)\n\n        if config.load_decoder:\n            self.decoder = WanDecoder3d(\n                dim=decoder_base_dim,\n                z_dim=self.z_dim,\n   
             dim_mult=config.dim_mult,\n                num_res_blocks=config.num_res_blocks,\n                attn_scales=config.attn_scales,\n                temperal_upsample=self.temperal_upsample,\n                dropout=config.dropout,\n                out_channels=config.out_channels,\n                is_residual=config.is_residual,\n                use_parallel_decode=self.use_parallel_decode,\n            )\n\n        self.use_feature_cache = config.use_feature_cache\n\n    def clear_cache(self) -> None:\n\n        def _count_conv3d(model) -> int:\n            count = 0\n            for m in model.modules():\n                if isinstance(m, WanCausalConv3d) or isinstance(m, WanDistCausalConv3d):\n                    count += 1\n            return count\n\n        if self.config.load_decoder:\n            self._conv_num = _count_conv3d(self.decoder)\n            self._conv_idx = 0\n            self._feat_map = [None] * self._conv_num\n        # cache encode\n        if self.config.load_encoder:\n            self._enc_conv_num = _count_conv3d(self.encoder)\n            self._enc_conv_idx = 0\n            self._enc_feat_map = [None] * self._enc_conv_num\n\n    def encode(self, x: torch.Tensor) -> torch.Tensor:\n        if self.use_feature_cache:\n            self.clear_cache()\n            if self.config.patch_size is not None:\n                x = patchify(x, patch_size=self.config.patch_size)\n            with forward_context(\n                feat_cache_arg=self._enc_feat_map, feat_idx_arg=self._enc_conv_idx\n            ):\n                t = x.shape[2]\n                iter_ = 1 + (t - 1) // 4\n                for i in range(iter_):\n                    feat_idx.set(0)\n                    if i == 0:\n                        out = self.encoder(x[:, :, :1, :, :])\n                    else:\n                        out_ = self.encoder(x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :])\n                        out = torch.cat([out, out_], 2)\n            enc = 
self.quant_conv(out)\n            mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]\n            enc = torch.cat([mu, logvar], dim=1)\n            enc = DiagonalGaussianDistribution(enc)\n            self.clear_cache()\n        else:\n            for block in self.encoder.down_blocks:\n                if isinstance(block, WanResample) and block.mode == \"downsample3d\":\n                    _padding = list(block.time_conv._padding)\n                    _padding[4] = 2\n                    block.time_conv._padding = tuple(_padding)\n            enc = ParallelTiledVAE.encode(self, x)\n\n        return enc\n\n    def _encode(self, x: torch.Tensor, first_frame=False) -> torch.Tensor:\n        with forward_context(first_frame_arg=first_frame):\n            out = self.encoder(x)\n        enc = self.quant_conv(out)\n        mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]\n        enc = torch.cat([mu, logvar], dim=1)\n        return enc\n\n    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:\n        first_frame = x[:, :, 0, :, :].unsqueeze(2)\n        first_frame = self._encode(first_frame, first_frame=True)\n\n        enc = ParallelTiledVAE.tiled_encode(self, x)\n        enc = enc[:, :, 1:]\n        enc = torch.cat([first_frame, enc], dim=2)\n        return enc\n\n    def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:\n        first_frame = x[:, :, 0, :, :].unsqueeze(2)\n        first_frame = self._encode(first_frame, first_frame=True)\n\n        enc = ParallelTiledVAE.spatial_tiled_encode(self, x)\n        enc = enc[:, :, 1:]\n        enc = torch.cat([first_frame, enc], dim=2)\n        return enc\n\n    def decode(self, z: torch.Tensor) -> torch.Tensor:\n        if self.use_feature_cache:\n            self.clear_cache()\n            iter_ = z.shape[2]\n            x = self.post_quant_conv(z)\n            with forward_context(\n                feat_cache_arg=self._feat_map, 
feat_idx_arg=self._conv_idx\n            ):\n                for i in range(iter_):\n                    feat_idx.set(0)\n                    if i == 0:\n                        first_chunk.set(True)\n                        out = self.decoder(x[:, :, i : i + 1, :, :])\n                    else:\n                        first_chunk.set(False)\n                        out_ = self.decoder(x[:, :, i : i + 1, :, :])\n                        out = torch.cat([out, out_], 2)\n\n            if self.config.patch_size is not None:\n                out = unpatchify(out, patch_size=self.config.patch_size)\n\n            out = out.float()\n            out = torch.clamp(out, min=-1.0, max=1.0)\n            self.clear_cache()\n        else:\n            out = ParallelTiledVAE.decode(self, z)\n\n        return out\n\n    def _decode(self, z: torch.Tensor, first_frame=False) -> torch.Tensor:\n        x = self.post_quant_conv(z)\n        with forward_context(first_frame_arg=first_frame):\n            out = self.decoder(x)\n\n        out = torch.clamp(out, min=-1.0, max=1.0)\n\n        return out\n\n    def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:\n        self.blend_num_frames *= 2\n        dec = ParallelTiledVAE.tiled_decode(self, z)\n        start_frame_idx = self.temporal_compression_ratio - 1\n        dec = dec[:, :, start_frame_idx:]\n        return dec\n\n    def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor:\n        dec = ParallelTiledVAE.spatial_tiled_decode(self, z)\n        start_frame_idx = self.temporal_compression_ratio - 1\n        dec = dec[:, :, start_frame_idx:]\n        return dec\n\n    def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor:\n        self.blend_num_frames *= 2\n        dec = ParallelTiledVAE.parallel_tiled_decode(self, z)\n        start_frame_idx = self.temporal_compression_ratio - 1\n        dec = dec[:, :, start_frame_idx:]\n        return dec\n\n    def forward(\n        self,\n        sample: 
torch.Tensor,\n        sample_posterior: bool = False,\n        generator: torch.Generator | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            sample (`torch.Tensor`): Input sample.\n            sample_posterior (`bool`, *optional*, defaults to `False`):\n                Whether to draw a sample from the posterior distribution instead of using its mode.\n            generator (`torch.Generator`, *optional*):\n                Random number generator used when `sample_posterior` is `True`.\n\n        Returns:\n            `torch.Tensor`: The decoded sample.\n        \"\"\"\n        x = sample\n        posterior = self.encode(x).latent_dist\n        if sample_posterior:\n            z = posterior.sample(generator=generator)\n        else:\n            z = posterior.mode()\n        dec = self.decode(z)\n        return dec\n\n\nEntryClass = AutoencoderKLWan\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vision_utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nimport os\nimport tempfile\nfrom collections.abc import Callable\nfrom urllib.parse import unquote, urlparse\n\nimport imageio\nimport numpy as np\nimport PIL.Image\nimport PIL.ImageOps\nimport requests\nimport torch\nfrom packaging import version\n\nif version.parse(version.parse(PIL.__version__).base_version) >= version.parse(\"9.1.0\"):\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.Resampling.BILINEAR,\n        \"bilinear\": PIL.Image.Resampling.BILINEAR,\n        \"bicubic\": PIL.Image.Resampling.BICUBIC,\n        \"lanczos\": PIL.Image.Resampling.LANCZOS,\n        \"nearest\": PIL.Image.Resampling.NEAREST,\n    }\nelse:\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.LINEAR,\n        \"bilinear\": PIL.Image.BILINEAR,\n        \"bicubic\": PIL.Image.BICUBIC,\n        \"lanczos\": PIL.Image.LANCZOS,\n        \"nearest\": PIL.Image.NEAREST,\n    }\n\n\ndef pil_to_numpy(images: list[PIL.Image.Image] | PIL.Image.Image) -> np.ndarray:\n    r\"\"\"\n    Convert a PIL image or a list of PIL images to NumPy arrays.\n\n    Args:\n        images (`PIL.Image.Image` or `List[PIL.Image.Image]`):\n            The PIL image or list of images to convert to NumPy format.\n\n    Returns:\n        `np.ndarray`:\n            A NumPy array representation of the images.\n    \"\"\"\n    if not isinstance(images, list):\n        images = [images]\n    images = [np.array(image).astype(np.float32) / 255.0 for image in images]\n    images_arr: np.ndarray = np.stack(images, axis=0)\n\n    return images_arr\n\n\ndef numpy_to_pt(images: np.ndarray) -> torch.Tensor:\n    r\"\"\"\n    Convert a NumPy image to a PyTorch tensor.\n\n    Args:\n        images (`np.ndarray`):\n            The NumPy image array to convert to PyTorch format.\n\n    Returns:\n        `torch.Tensor`:\n            A PyTorch tensor representation of the images.\n    
\"\"\"\n    if images.ndim == 3:\n        images = images[..., None]\n\n    images = torch.from_numpy(images.transpose(0, 3, 1, 2))\n    return images\n\n\ndef normalize(images: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:\n    r\"\"\"\n    Normalize an image array to [-1,1].\n\n    Args:\n        images (`np.ndarray` or `torch.Tensor`):\n            The image array to normalize.\n\n    Returns:\n        `np.ndarray` or `torch.Tensor`:\n            The normalized image array.\n    \"\"\"\n    return 2.0 * images - 1.0\n\n\n# adapted from diffusers.utils import load_image\ndef load_image(\n    image: str | PIL.Image.Image,\n    convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] | None = None,\n) -> PIL.Image.Image:\n    \"\"\"\n    Loads `image` to a PIL Image.\n\n    Args:\n        image (`str` or `PIL.Image.Image`):\n            The image to convert to the PIL Image format.\n        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):\n            A conversion method to apply to the image after loading it. When set to `None` the image will be converted\n            \"RGB\".\n    \"\"\"\n    if isinstance(image, str):\n        if image.startswith(\"http://\") or image.startswith(\"https://\"):\n            image = PIL.Image.open(requests.get(image, stream=True).raw)\n        elif os.path.isfile(image):\n            image = PIL.Image.open(image)\n        else:\n            raise ValueError(\n                f\"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path.\"\n            )\n    elif isinstance(image, PIL.Image.Image):\n        image = image\n    else:\n        raise ValueError(\n            \"Incorrect format used for the image. 
Should be a URL linking to an image, a local path, or a PIL image.\"\n        )\n\n    image = PIL.ImageOps.exif_transpose(image)\n\n    if convert_method is not None:\n        image = convert_method(image)\n    else:\n        image = image.convert(\"RGB\")\n\n    return image\n\n\n# adapted from diffusers.utils import load_video\ndef load_video(\n    video: str,\n    convert_method: (\n        Callable[[list[PIL.Image.Image]], list[PIL.Image.Image]] | None\n    ) = None,\n) -> list[PIL.Image.Image]:\n    \"\"\"\n    Loads `video` to a list of PIL Image.\n    Args:\n        video (`str`):\n            A URL or Path to a video to convert to a list of PIL Image format.\n        convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):\n            A conversion method to apply to the video after loading it. When set to `None` the images will be converted\n            to \"RGB\".\n    Returns:\n        `List[PIL.Image.Image]`:\n            The video as a list of PIL images.\n    \"\"\"\n    is_url = video.startswith(\"http://\") or video.startswith(\"https://\")\n    is_file = os.path.isfile(video)\n    was_tempfile_created = False\n\n    if not (is_url or is_file):\n        raise ValueError(\n            f\"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path.\"\n        )\n\n    if is_url:\n        response = requests.get(video, stream=True)\n        if response.status_code != 200:\n            raise ValueError(\n                f\"Failed to download video. 
Status code: {response.status_code}\"\n            )\n\n        parsed_url = urlparse(video)\n        file_name = os.path.basename(unquote(parsed_url.path))\n\n        suffix = os.path.splitext(file_name)[1] or \".mp4\"\n        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:\n            video_path = temp_file.name\n            was_tempfile_created = True\n            video_data = response.iter_content(chunk_size=8192)\n            for chunk in video_data:\n                temp_file.write(chunk)\n\n        video = video_path\n\n    pil_images = []\n    if video.endswith(\".gif\"):\n        gif = PIL.Image.open(video)\n        try:\n            while True:\n                pil_images.append(gif.copy())\n                gif.seek(gif.tell() + 1)\n        except EOFError:\n            pass\n\n    else:\n        try:\n            imageio.plugins.ffmpeg.get_exe()\n        except AttributeError:\n            raise AttributeError(\n                \"Unable to find an ffmpeg installation on your machine. Please install it via `pip install imageio-ffmpeg`.\"\n            ) from None\n\n        with imageio.get_reader(video) as reader:\n            # Read all frames\n            for frame in reader:\n                pil_images.append(PIL.Image.fromarray(frame))\n\n    if was_tempfile_created:\n        os.remove(video_path)\n\n    if convert_method is not None:\n        pil_images = convert_method(pil_images)\n\n    return pil_images\n\n\ndef get_default_height_width(\n    image: PIL.Image.Image | np.ndarray | torch.Tensor,\n    vae_scale_factor: int,\n    height: int | None = None,\n    width: int | None = None,\n) -> tuple[int, int]:\n    r\"\"\"\n    Returns the height and width of the image, each rounded down to the nearest integer multiple of `vae_scale_factor`.\n\n    Args:\n        image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):\n            The image input, which can be a PIL image, NumPy array, or PyTorch tensor. 
If it is a NumPy array, it\n            should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch\n            tensor, it should have shape `[batch, channels, height, width]`.\n        vae_scale_factor (`int`):\n            The returned height and width are rounded down to a multiple of this value.\n        height (`Optional[int]`, *optional*, defaults to `None`):\n            The height of the preprocessed image. If `None`, the height of the `image` input will be used.\n        width (`Optional[int]`, *optional*, defaults to `None`):\n            The width of the preprocessed image. If `None`, the width of the `image` input will be used.\n\n    Returns:\n        `Tuple[int, int]`:\n            A tuple containing the height and width, both rounded down to the nearest integer multiple of\n            `vae_scale_factor`.\n    \"\"\"\n\n    if height is None:\n        if isinstance(image, PIL.Image.Image):\n            height = image.height\n        elif isinstance(image, torch.Tensor):\n            height = image.shape[2]\n        else:\n            height = image.shape[1]\n\n    if width is None:\n        if isinstance(image, PIL.Image.Image):\n            width = image.width\n        elif isinstance(image, torch.Tensor):\n            width = image.shape[3]\n        else:\n            width = image.shape[2]\n\n    width, height = (\n        x - x % vae_scale_factor for x in (width, height)\n    )  # round down to an integer multiple of vae_scale_factor\n\n    return height, width\n\n\ndef resize(\n    image: PIL.Image.Image | np.ndarray | torch.Tensor,\n    height: int,\n    width: int,\n    resize_mode: str = \"default\",  # \"default\", \"fill\", \"crop\"\n    resample: str = \"lanczos\",\n) -> PIL.Image.Image | np.ndarray | torch.Tensor:\n    \"\"\"\n    Resize image.\n\n    Args:\n        image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):\n            The image input, can be a PIL image, numpy array or pytorch tensor.\n        height (`int`):\n            The height to resize to.\n        width (`int`):\n            The width to resize to.\n     
   resize_mode (`str`, *optional*, defaults to `default`):\n            The resize mode to use, can be one of `default`, `fill`, or `crop`. If `default`, will resize the image to fit\n            within the specified width and height, and it may not maintain the original aspect ratio. If `fill`,\n            will resize the image to fit within the specified width and height, maintaining the aspect ratio, and\n            then center the image within the dimensions, filling the empty area with data from the image. If `crop`, will resize\n            the image to fit within the specified width and height, maintaining the aspect ratio, and then center\n            the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only\n            supported for PIL image input. This implementation currently supports only `default` mode with PIL image\n            input; any other mode raises a `ValueError`.\n\n    Returns:\n        `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:\n            The resized image.\n    \"\"\"\n    if resize_mode != \"default\" and not isinstance(image, PIL.Image.Image):\n        raise ValueError(\n            f\"Only PIL image input is supported for resize_mode {resize_mode}\"\n        )\n    assert isinstance(image, PIL.Image.Image)\n    if resize_mode == \"default\":\n        image = image.resize((width, height), resample=PIL_INTERPOLATION[resample])\n    else:\n        raise ValueError(f\"resize_mode {resize_mode} is not supported\")\n    return image\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/models/vocoder/ltx_2_vocoder.py",
    "content": "import math\nfrom abc import ABC\nfrom typing import Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sglang.multimodal_gen.configs.models.vocoder.ltx_vocoder import LTXVocoderConfig\n\n\nclass ResBlock(nn.Module):\n    def __init__(\n        self,\n        channels: int,\n        kernel_size: int = 3,\n        stride: int = 1,\n        dilations: Tuple[int, ...] = (1, 3, 5),\n        leaky_relu_negative_slope: float = 0.1,\n        padding_mode: str = \"same\",\n    ):\n        super().__init__()\n        self.dilations = dilations\n        self.negative_slope = leaky_relu_negative_slope\n\n        self.convs1 = nn.ModuleList(\n            [\n                nn.Conv1d(\n                    channels,\n                    channels,\n                    kernel_size,\n                    stride=stride,\n                    dilation=dilation,\n                    padding=padding_mode,\n                )\n                for dilation in dilations\n            ]\n        )\n\n        self.convs2 = nn.ModuleList(\n            [\n                nn.Conv1d(\n                    channels,\n                    channels,\n                    kernel_size,\n                    stride=stride,\n                    dilation=1,\n                    padding=padding_mode,\n                )\n                for _ in range(len(dilations))\n            ]\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        for conv1, conv2 in zip(self.convs1, self.convs2):\n            xt = F.leaky_relu(x, negative_slope=self.negative_slope)\n            xt = conv1(xt)\n            xt = F.leaky_relu(xt, negative_slope=self.negative_slope)\n            xt = conv2(xt)\n            x = x + xt\n        return x\n\n\nclass LTX2Vocoder(ABC, nn.Module):\n    r\"\"\"\n    LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.\n    \"\"\"\n\n    def __init__(\n        self,\n        config: 
LTXVocoderConfig,\n    ):\n        super().__init__()\n        self.config = config\n        self.sample_rate = (\n            getattr(config.arch_config, \"sample_rate\", None)\n            or getattr(config.arch_config, \"sampling_rate\", None)\n            or getattr(config.arch_config, \"audio_sample_rate\", None)\n        )\n\n        in_channels = config.arch_config.in_channels\n        hidden_channels = config.arch_config.hidden_channels\n        out_channels = config.arch_config.out_channels\n        upsample_kernel_sizes = config.arch_config.upsample_kernel_sizes\n        upsample_factors = config.arch_config.upsample_factors\n        resnet_kernel_sizes = config.arch_config.resnet_kernel_sizes\n        resnet_dilations = config.arch_config.resnet_dilations\n        leaky_relu_negative_slope = config.arch_config.leaky_relu_negative_slope\n\n        self.num_upsample_layers = len(upsample_kernel_sizes)\n        self.resnets_per_upsample = len(resnet_kernel_sizes)\n        self.out_channels = out_channels\n        self.total_upsample_factor = math.prod(upsample_factors)\n        self.negative_slope = leaky_relu_negative_slope\n\n        if self.num_upsample_layers != len(upsample_factors):\n            raise ValueError(\n                f\"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length\"\n                f\" {self.num_upsample_layers} and {len(upsample_factors)}, respectively.\"\n            )\n\n        if self.resnets_per_upsample != len(resnet_dilations):\n            raise ValueError(\n                f\"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length\"\n                f\" {self.resnets_per_upsample} and {len(resnet_dilations)}, respectively.\"\n            )\n\n        self.conv_in = nn.Conv1d(\n            in_channels, hidden_channels, kernel_size=7, stride=1, padding=3\n        )\n\n        self.upsamplers = nn.ModuleList()\n        self.resnets = 
nn.ModuleList()\n        input_channels = hidden_channels\n        for i, (stride, kernel_size) in enumerate(\n            zip(upsample_factors, upsample_kernel_sizes)\n        ):\n            output_channels = input_channels // 2\n            self.upsamplers.append(\n                nn.ConvTranspose1d(\n                    input_channels,  # hidden_channels // (2 ** i)\n                    output_channels,  # hidden_channels // (2 ** (i + 1))\n                    kernel_size,\n                    stride=stride,\n                    padding=(kernel_size - stride) // 2,\n                )\n            )\n\n            for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):\n                self.resnets.append(\n                    ResBlock(\n                        output_channels,\n                        kernel_size,\n                        dilations=dilations,\n                        leaky_relu_negative_slope=leaky_relu_negative_slope,\n                    )\n                )\n            input_channels = output_channels\n\n        self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)\n\n    def forward(\n        self, hidden_states: torch.Tensor, time_last: bool = False\n    ) -> torch.Tensor:\n        r\"\"\"\n        Forward pass of the vocoder.\n\n        Args:\n            hidden_states (`torch.Tensor`):\n                Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`\n                is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is\n                `True`.\n            time_last (`bool`, *optional*, defaults to `False`):\n                Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension.\n\n        Returns:\n            `torch.Tensor`:\n                Audio waveform tensor of shape (batch_size, out_channels, audio_length)\n        \"\"\"\n\n        # Ensure that 
the time/frame dimension is last\n        if not time_last:\n            hidden_states = hidden_states.transpose(2, 3)\n        # Combine channels and frequency (mel bins) dimensions\n        hidden_states = hidden_states.flatten(1, 2)\n\n        hidden_states = self.conv_in(hidden_states)\n\n        for i in range(self.num_upsample_layers):\n            hidden_states = F.leaky_relu(\n                hidden_states, negative_slope=self.negative_slope\n            )\n            hidden_states = self.upsamplers[i](hidden_states)\n\n            # Run all resnets in parallel on hidden_states\n            start = i * self.resnets_per_upsample\n            end = (i + 1) * self.resnets_per_upsample\n            resnet_outputs = torch.stack(\n                [self.resnets[j](hidden_states) for j in range(start, end)], dim=0\n            )\n\n            hidden_states = torch.mean(resnet_outputs, dim=0)\n\n        # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of\n        # 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended\n        hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)\n        hidden_states = self.conv_out(hidden_states)\n        hidden_states = torch.tanh(hidden_states)\n\n        return hidden_states\n\n\nEntryClass = LTX2Vocoder\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/comfyui_flux_pipeline.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n# SPDX-License-Identifier: Apache-2.0\n\nimport os\nimport re\nfrom typing import Any, Generator\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models.dits.flux import FluxConfig\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.models.registry import ModelRegistry\nfrom sglang.multimodal_gen.runtime.models.schedulers.scheduling_comfyui_passthrough import (\n    ComfyUIPassThroughScheduler,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import (\n    ComfyUILatentPreparationStage,\n    DenoisingStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\n\nlogger = init_logger(__name__)\n\n\nclass ComfyUIFluxPipeline(LoRAPipeline, ComposedPipelineBase):\n    \"\"\"\n    Simplified pipeline for ComfyUI integration with only denoising stage.\n\n    This pipeline requires pre-processed inputs:\n    - prompt_embeds: Pre-encoded text embeddings (list of tensors)\n    - negative_prompt_embeds: Pre-encoded negative prompt embeddings (if using CFG)\n    - latents: Optional initial noise latents (will be generated if not provided)\n\n    Usage:\n        generator = DiffGenerator.from_pretrained(\n            model_path=\"path/to/model\",\n            pipeline_class_name=\"ComfyUIFluxPipeline\",\n            device=\"cuda\",\n        )\n    \"\"\"\n\n    pipeline_name = \"ComfyUIFluxPipeline\"\n\n    # Configuration classes for safetensors files without model_index.json\n    from sglang.multimodal_gen.configs.pipeline_configs.flux import FluxPipelineConfig\n    from 
sglang.multimodal_gen.configs.sample.flux import FluxSamplingParams\n\n    pipeline_config_cls = FluxPipelineConfig\n    sampling_params_cls = FluxSamplingParams\n\n    _required_config_modules = [\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def initialize_pipeline(self, server_args: ServerArgs):\n        \"\"\"\n        Initialize the pipeline with ComfyUI pass-through scheduler.\n        This scheduler does not modify latents, allowing ComfyUI to handle denoising.\n        \"\"\"\n        self.modules[\"scheduler\"] = ComfyUIPassThroughScheduler(\n            num_train_timesteps=1000\n        )\n\n        if hasattr(server_args.pipeline_config, \"vae_config\"):\n            vae_config = server_args.pipeline_config.vae_config\n            if hasattr(vae_config, \"post_init\") and not hasattr(\n                vae_config, \"_post_init_called\"\n            ):\n                vae_config.post_init()\n                logger.info(\n                    \"Called vae_config.post_init() to set spatial_compression_ratio. \"\n                    f\"spatial_compression_ratio={vae_config.arch_config.spatial_compression_ratio}\"\n                )\n\n    def load_modules(\n        self,\n        server_args: ServerArgs,\n        loaded_modules: dict[str, torch.nn.Module] | None = None,\n    ) -> dict[str, Any]:\n        \"\"\"\n        Load modules for ComfyUIFluxPipeline.\n\n        If model_path is a safetensors file, load transformer directly from it\n        without requiring model_index.json. 
Otherwise, fall back to default loading.\n        \"\"\"\n        if os.path.isfile(self.model_path) and self.model_path.endswith(\".safetensors\"):\n            logger.info(\n                \"Detected safetensors file, loading transformer directly from: %s\",\n                self.model_path,\n            )\n            return self._load_transformer_from_safetensors(server_args, loaded_modules)\n        else:\n            logger.info(\n                \"Model path is a directory, using default loading method: %s\",\n                self.model_path,\n            )\n            return super().load_modules(server_args, loaded_modules)\n\n    def _load_and_convert_weights_from_safetensors(\n        self,\n        model_cls: type,\n        dit_config: FluxConfig,\n        hf_config: dict,\n        safetensors_list: list[str],\n        updated_mapping: dict,\n        qkv_size: int,\n        mlp_hidden_dim: int,\n        has_guidance_embeds: bool,\n        default_dtype: torch.dtype,\n    ) -> tuple[torch.nn.Module, dict]:\n        \"\"\"\n        Load and convert weights from safetensors file, then load them into the model.\n        \"\"\"\n        from sglang.multimodal_gen.runtime.loader.utils import (\n            get_param_names_mapping,\n            set_default_torch_dtype,\n        )\n        from sglang.multimodal_gen.runtime.loader.weight_utils import (\n            safetensors_weights_iterator,\n        )\n\n        logger.info(\n            \"Converting ComfyUI Flux weights to SGLang format and loading model...\"\n        )\n\n        # Create model on target device\n        device = get_local_torch_device()\n        with set_default_torch_dtype(default_dtype):\n            model = model_cls(**{\"config\": dit_config, \"hf_config\": hf_config})\n            model = model.to(device)\n\n        # Verify model has guidance_embedder if config says it should\n        has_guidance_embedder = hasattr(model.time_text_embed, \"guidance_embedder\")\n        if 
has_guidance_embeds and not has_guidance_embedder:\n            logger.warning(\n                \"Config has guidance_embeds=True but model doesn't have guidance_embedder. \"\n                \"This may indicate a configuration mismatch.\"\n            )\n        elif not has_guidance_embeds and has_guidance_embedder:\n            logger.warning(\n                \"Config has guidance_embeds=False but model has guidance_embedder. \"\n                \"This may indicate a configuration mismatch.\"\n            )\n\n        # Note: guidance_in mappings are already included in comfyui_flux_mappings above.\n        # If model doesn't support guidance embeddings, the weights will be filtered out\n        # in _convert_comfyui_weights() based on has_guidance_embeds flag.\n\n        param_names_mapping_fn = get_param_names_mapping(updated_mapping)\n\n        weight_iterator = safetensors_weights_iterator(safetensors_list)\n        converted_weights = self._convert_comfyui_weights(\n            weight_iterator=weight_iterator,\n            qkv_size=qkv_size,\n            mlp_hidden_dim=mlp_hidden_dim,\n            has_guidance_embeds=has_guidance_embeds,\n        )\n\n        model_state_dict = model.state_dict()\n        missing_keys = set(model_state_dict.keys())\n        unexpected_keys = []\n        loaded_count = 0\n        reverse_param_names_mapping = {}\n\n        # Handle merged parameters (collect all parts before merging)\n        from collections import defaultdict\n\n        to_merge_params = defaultdict(dict)\n\n        # Process weights incrementally: load immediately after conversion\n        for source_name, tensor in converted_weights:\n            target_name, merge_index, num_params_to_merge = param_names_mapping_fn(\n                source_name\n            )\n            reverse_param_names_mapping[target_name] = (\n                source_name,\n                merge_index,\n                num_params_to_merge,\n            )\n\n            if 
merge_index is not None:\n                # Collect parts for merging\n                to_merge_params[target_name][merge_index] = tensor\n                if len(to_merge_params[target_name]) == num_params_to_merge:\n                    # All parts collected, merge them\n                    sorted_tensors = [\n                        to_merge_params[target_name][i]\n                        for i in range(num_params_to_merge)\n                    ]\n                    merged_tensor = torch.cat(sorted_tensors, dim=0)\n                    # Load immediately after merging\n                    if target_name in model_state_dict:\n                        param = model_state_dict[target_name]\n                        loaded_tensor = merged_tensor.to(\n                            device=param.device, dtype=param.dtype\n                        )\n                        param.data.copy_(loaded_tensor)\n                        missing_keys.discard(target_name)\n                        loaded_count += 1\n                        del merged_tensor, loaded_tensor\n                    else:\n                        unexpected_keys.append(target_name)\n                    # Drop references to the collected parts so they can be freed\n                    del to_merge_params[target_name]\n                    del sorted_tensors\n            else:\n                # Direct mapping, load immediately\n                if target_name in model_state_dict:\n                    param = model_state_dict[target_name]\n                    # Check shape compatibility\n                    if tensor.shape != param.shape:\n                        logger.warning(\n                            f\"Shape mismatch for {target_name}: \"\n                            f\"loaded {tensor.shape} vs model {param.shape}, skipping. 
\"\n                            f\"Source: {source_name}\"\n                        )\n                        unexpected_keys.append(target_name)\n                        del tensor\n                        continue\n\n                    # Debug logging for norm_out.linear to verify mapping\n                    if (\n                        \"norm_out.linear\" in target_name\n                        or \"final_layer.adaLN_modulation\" in source_name\n                    ):\n                        logger.info(\n                            f\"Loading norm_out.linear: {source_name} -> {target_name}, \"\n                            f\"shape: {tensor.shape}\"\n                        )\n\n                    loaded_tensor = tensor.to(device=param.device, dtype=param.dtype)\n                    param.data.copy_(loaded_tensor)\n                    missing_keys.discard(target_name)\n                    loaded_count += 1\n                    del tensor, loaded_tensor\n                else:\n                    # Debug logging for unmapped parameters\n                    if \"norm_out.linear\" in target_name:\n                        logger.warning(\n                            f\"norm_out.linear parameter {target_name} not found in model state_dict. 
\"\n                            f\"Source: {source_name}\"\n                        )\n                    unexpected_keys.append(target_name)\n\n        optional_missing_keys = []\n        required_missing_keys = []\n        for key in missing_keys:\n            if key.endswith(\".bias\"):\n                # Check if corresponding weight exists (if weight exists but bias doesn't, it's optional)\n                weight_key = key.replace(\".bias\", \".weight\")\n                if weight_key not in missing_keys:\n                    optional_missing_keys.append(key)\n                else:\n                    required_missing_keys.append(key)\n            else:\n                required_missing_keys.append(key)\n\n        if required_missing_keys:\n            logger.warning(\n                f\"Required missing keys (first 10): {required_missing_keys[:10]}...\"\n            )\n        if optional_missing_keys:\n            logger.info(\n                f\"Optional missing keys (bias parameters, {len(optional_missing_keys)} total): \"\n                f\"These will use default values (zeros)\"\n            )\n        if unexpected_keys:\n            logger.warning(f\"Unexpected keys (first 10): {unexpected_keys[:10]}...\")\n\n        logger.info(f\"Successfully loaded {loaded_count} weight tensors\")\n\n        return model, reverse_param_names_mapping\n\n    def _convert_comfyui_weights(\n        self,\n        weight_iterator: Generator[tuple[str, torch.Tensor], None, None],\n        qkv_size: int,\n        mlp_hidden_dim: int,\n        has_guidance_embeds: bool,\n    ) -> Generator[tuple[str, torch.Tensor], None, None]:\n        \"\"\"\n        Convert ComfyUI Flux weights to SGLang format.\n        Splits fused qkv weights into to_q/to_k/to_v plus proj_mlp.\n        Filters out guidance_in weights if model doesn't support guidance embeddings.\n        Handles scale/shift order difference between ComfyUI and AdaLayerNormContinuous.\n        \"\"\"\n        for 
name, tensor in weight_iterator:\n            if not has_guidance_embeds and name.startswith(\"guidance_in.\"):\n                logger.debug(\n                    f\"Skipping {name} (model doesn't support guidance embeddings)\"\n                )\n                continue\n\n            # Split fused qkv in double blocks into separate q/k/v projections\n            match = re.match(\n                r\"double_blocks\\.(\\d+)\\.(img_attn|txt_attn)\\.qkv\\.(weight|bias)$\", name\n            )\n            if match:\n                block_idx, attn_type, param_type = match.groups()\n                hidden_size = qkv_size // 3\n\n                if tensor.shape[0] < 3 * hidden_size:\n                    logger.warning(\n                        f\"{name} shape {tensor.shape} smaller than expected qkv size {3 * hidden_size}, skipping\"\n                    )\n                    continue\n\n                if param_type == \"bias\":\n                    q_tensor = tensor[:hidden_size]\n                    k_tensor = tensor[hidden_size : 2 * hidden_size]\n                    v_tensor = tensor[2 * hidden_size : 3 * hidden_size]\n                else:\n                    q_tensor = tensor[:hidden_size, :]\n                    k_tensor = tensor[hidden_size : 2 * hidden_size, :]\n                    v_tensor = tensor[2 * hidden_size : 3 * hidden_size, :]\n\n                target_prefix = f\"transformer_blocks.{block_idx}.attn\"\n                if attn_type == \"img_attn\":\n                    yield f\"{target_prefix}.to_q.{param_type}\", q_tensor\n                    yield f\"{target_prefix}.to_k.{param_type}\", k_tensor\n                    yield f\"{target_prefix}.to_v.{param_type}\", v_tensor\n                else:\n                    # txt_attn corresponds to encoder projections\n                    yield f\"{target_prefix}.add_q_proj.{param_type}\", q_tensor\n                    yield f\"{target_prefix}.add_k_proj.{param_type}\", k_tensor\n                    
yield f\"{target_prefix}.add_v_proj.{param_type}\", v_tensor\n                continue\n\n            match = re.match(r\"single_blocks\\.(\\d+)\\.linear1\\.(weight|bias)$\", name)\n            if match:\n                block_idx, param_type = match.groups()\n                expected_size = qkv_size + mlp_hidden_dim\n\n                if tensor.shape[0] < expected_size:\n                    logger.warning(\n                        f\"linear1.{param_type} shape {tensor.shape} doesn't match \"\n                        f\"expected size {expected_size}, skipping\"\n                    )\n                    continue\n\n                # Split tensor\n                qkv_tensor = (\n                    tensor[:qkv_size] if param_type == \"bias\" else tensor[:qkv_size, :]\n                )\n                mlp_tensor = (\n                    tensor[qkv_size:] if param_type == \"bias\" else tensor[qkv_size:, :]\n                )\n\n                # Split qkv into q/k/v for single blocks\n                hidden_size = qkv_size // 3\n                if param_type == \"bias\":\n                    q_tensor = qkv_tensor[:hidden_size]\n                    k_tensor = qkv_tensor[hidden_size : 2 * hidden_size]\n                    v_tensor = qkv_tensor[2 * hidden_size : 3 * hidden_size]\n                else:\n                    q_tensor = qkv_tensor[:hidden_size, :]\n                    k_tensor = qkv_tensor[hidden_size : 2 * hidden_size, :]\n                    v_tensor = qkv_tensor[2 * hidden_size : 3 * hidden_size, :]\n\n                yield f\"single_transformer_blocks.{block_idx}.attn.to_q.{param_type}\", q_tensor\n                yield f\"single_transformer_blocks.{block_idx}.attn.to_k.{param_type}\", k_tensor\n                yield f\"single_transformer_blocks.{block_idx}.attn.to_v.{param_type}\", v_tensor\n                yield f\"single_transformer_blocks.{block_idx}.proj_mlp.{param_type}\", mlp_tensor\n            elif name == 
\"final_layer.adaLN_modulation.1.weight\":\n                # ComfyUI: output order is [shift, scale]\n                # AdaLayerNormContinuous: expects [scale, shift]\n                # Need to swap the first half and second half of the weight matrix\n                # Weight shape: (2 * hidden_size, hidden_size)\n                # Split into two halves and swap them\n                half_size = tensor.shape[0] // 2\n                shift_weights = tensor[:half_size, :]\n                scale_weights = tensor[half_size:, :]\n                # Swap: put scale first, then shift\n                swapped_tensor = torch.cat([scale_weights, shift_weights], dim=0)\n                logger.info(\n                    f\"Swapped scale/shift order for {name}: \"\n                    f\"shape {tensor.shape} -> {swapped_tensor.shape}\"\n                )\n                yield name, swapped_tensor\n            elif name == \"final_layer.adaLN_modulation.1.bias\":\n                # Same swap for bias: (2 * hidden_size,)\n                half_size = tensor.shape[0] // 2\n                shift_bias = tensor[:half_size]\n                scale_bias = tensor[half_size:]\n                swapped_tensor = torch.cat([scale_bias, shift_bias], dim=0)\n                logger.info(\n                    f\"Swapped scale/shift order for {name}: \"\n                    f\"shape {tensor.shape} -> {swapped_tensor.shape}\"\n                )\n                yield name, swapped_tensor\n            else:\n                # Other weights pass through (handled by param_names_mapping)\n                yield name, tensor\n\n    def _load_transformer_from_safetensors(\n        self,\n        server_args: ServerArgs,\n        loaded_modules: dict[str, torch.nn.Module] | None = None,\n    ) -> dict[str, Any]:\n        \"\"\"\n        Load transformer directly from safetensors file without model_index.json.\n        \"\"\"\n        if loaded_modules is not None and \"transformer\" in loaded_modules:\n    
        logger.info(\"Using provided transformer module\")\n            components = {\n                \"transformer\": loaded_modules[\"transformer\"],\n                \"scheduler\": self.modules.get(\"scheduler\"),\n            }\n            return components\n\n        if hasattr(server_args.pipeline_config, \"dit_config\"):\n            dit_config = server_args.pipeline_config.dit_config\n            if not isinstance(dit_config, FluxConfig):\n                logger.warning(\"dit_config is not FluxConfig, creating new FluxConfig\")\n                dit_config = FluxConfig()\n                server_args.pipeline_config.dit_config = dit_config\n        else:\n            logger.info(\"Creating default FluxConfig\")\n            dit_config = FluxConfig()\n            server_args.pipeline_config.dit_config = dit_config\n\n        # Set guidance_embeds to True for ComfyUI Flux models\n        dit_config.arch_config.guidance_embeds = True\n        logger.info(\"Set guidance_embeds=True for ComfyUI Flux model\")\n\n        if dit_config.arch_config.param_names_mapping is None:\n            dit_config.arch_config.param_names_mapping = {}\n\n        # ComfyUI Flux uses different parameter names than SGLang Flux\n        # Key differences:\n        # - ComfyUI: single_blocks.{i}.linear1 (fused QKV + MLP input)\n        # - SGLang: single_transformer_blocks.{i}.attn.to_qkv + proj_mlp (separate)\n        # - ComfyUI: single_blocks.{i}.linear2\n        # - SGLang: single_transformer_blocks.{i}.proj_out\n        # - ComfyUI: double_blocks.{i}.img_attn.qkv / txt_attn.qkv\n        # - SGLang: transformer_blocks.{i}.attn.to_qkv / attn.to_added_qkv\n\n        # Note: For fused layers like linear1, we need custom weight splitting logic\n        # which will be handled in the weight conversion function below\n        comfyui_flux_mappings = {\n            # Double stream blocks - attention layers\n            r\"double_blocks\\.(\\d+)\\.img_attn\\.qkv\\.(weight|bias)$\": (\n    
            r\"transformer_blocks.\\1.attn.to_qkv.\\2\",\n                None,\n                None,\n            ),\n            r\"double_blocks\\.(\\d+)\\.txt_attn\\.qkv\\.(weight|bias)$\": (\n                r\"transformer_blocks.\\1.attn.to_added_qkv.\\2\",\n                None,\n                None,\n            ),\n            r\"double_blocks\\.(\\d+)\\.img_attn\\.proj\\.(weight|bias)$\": (\n                r\"transformer_blocks.\\1.attn.to_out.0.\\2\",\n                None,\n                None,\n            ),\n            r\"double_blocks\\.(\\d+)\\.txt_attn\\.proj\\.(weight|bias)$\": (\n                r\"transformer_blocks.\\1.attn.to_add_out.\\2\",\n                None,\n                None,\n            ),\n            r\"double_blocks\\.(\\d+)\\.img_attn\\.norm\\.query_norm\\.scale$\": (\n                r\"transformer_blocks.\\1.attn.norm_q.weight\",\n                None,\n                None,\n            ),\n            r\"double_blocks\\.(\\d+)\\.img_attn\\.norm\\.key_norm\\.scale$\": (\n                r\"transformer_blocks.\\1.attn.norm_k.weight\",\n                None,\n                None,\n            ),\n            r\"double_blocks\\.(\\d+)\\.txt_attn\\.norm\\.query_norm\\.scale$\": (\n                r\"transformer_blocks.\\1.attn.norm_added_q.weight\",\n                None,\n                None,\n            ),\n            r\"double_blocks\\.(\\d+)\\.txt_attn\\.norm\\.key_norm\\.scale$\": (\n                r\"transformer_blocks.\\1.attn.norm_added_k.weight\",\n                None,\n                None,\n            ),\n            # Double stream blocks - MLP layers (map to net structure)\n            r\"double_blocks\\.(\\d+)\\.img_mlp\\.0\\.(weight|bias)$\": (\n                r\"transformer_blocks.\\1.ff.net.0.proj.\\2\",\n                None,\n                None,\n            ),\n            r\"double_blocks\\.(\\d+)\\.img_mlp\\.2\\.(weight|bias)$\": (\n                r\"transformer_blocks.\\1.ff.net.2.\\2\",\n 
               None,\n                None,\n            ),\n            r\"double_blocks\\.(\\d+)\\.txt_mlp\\.0\\.(weight|bias)$\": (\n                r\"transformer_blocks.\\1.ff_context.net.0.proj.\\2\",\n                None,\n                None,\n            ),\n            r\"double_blocks\\.(\\d+)\\.txt_mlp\\.2\\.(weight|bias)$\": (\n                r\"transformer_blocks.\\1.ff_context.net.2.\\2\",\n                None,\n                None,\n            ),\n            # Double stream blocks - modulation layers\n            r\"double_blocks\\.(\\d+)\\.img_mod\\.lin\\.(weight|bias)$\": (\n                r\"transformer_blocks.\\1.norm1.linear.\\2\",\n                None,\n                None,\n            ),\n            r\"double_blocks\\.(\\d+)\\.txt_mod\\.lin\\.(weight|bias)$\": (\n                r\"transformer_blocks.\\1.norm1_context.linear.\\2\",\n                None,\n                None,\n            ),\n            # Single stream blocks - linear2 maps to proj_out\n            r\"single_blocks\\.(\\d+)\\.linear2\\.(weight|bias)$\": (\n                r\"single_transformer_blocks.\\1.proj_out.\\2\",\n                None,\n                None,\n            ),\n            # Single stream blocks - norm layers (scale -> weight)\n            r\"single_blocks\\.(\\d+)\\.norm\\.query_norm\\.scale$\": (\n                r\"single_transformer_blocks.\\1.attn.norm_q.weight\",\n                None,\n                None,\n            ),\n            r\"single_blocks\\.(\\d+)\\.norm\\.key_norm\\.scale$\": (\n                r\"single_transformer_blocks.\\1.attn.norm_k.weight\",\n                None,\n                None,\n            ),\n            # Single stream blocks - modulation (maps to norm.linear)\n            r\"single_blocks\\.(\\d+)\\.modulation\\.lin\\.(weight|bias)$\": (\n                r\"single_transformer_blocks.\\1.norm.linear.\\2\",\n                None,\n                None,\n            ),\n            # Time and guidance 
embeddings\n            r\"^time_in\\.in_layer\\.(weight|bias)$\": (\n                r\"time_text_embed.timestep_embedder.linear_1.\\1\",\n                None,\n                None,\n            ),\n            r\"^time_in\\.out_layer\\.(weight|bias)$\": (\n                r\"time_text_embed.timestep_embedder.linear_2.\\1\",\n                None,\n                None,\n            ),\n            r\"^txt_in\\.(weight|bias)$\": (r\"context_embedder.\\1\", None, None),\n            r\"^vector_in\\.in_layer\\.(weight|bias)$\": (\n                r\"time_text_embed.text_embedder.linear_1.\\1\",\n                None,\n                None,\n            ),\n            r\"^vector_in\\.out_layer\\.(weight|bias)$\": (\n                r\"time_text_embed.text_embedder.linear_2.\\1\",\n                None,\n                None,\n            ),\n            # Final layer mappings\n            r\"^final_layer\\.linear\\.(weight|bias)$\": (r\"proj_out.\\1\", None, None),\n            r\"^final_layer\\.norm_final\\.(weight|bias)$\": (r\"norm_out.\\1\", None, None),\n            r\"^final_layer\\.adaLN_modulation\\.1\\.(weight|bias)$\": (\n                r\"norm_out.linear.\\1\",\n                None,\n                None,\n            ),\n            # Image input embedding\n            r\"^img_in\\.(weight|bias)$\": (r\"x_embedder.\\1\", None, None),\n            # Guidance embeddings (if model supports guidance)\n            r\"^guidance_in\\.in_layer\\.(weight|bias)$\": (\n                r\"time_text_embed.guidance_embedder.linear_1.\\1\",\n                None,\n                None,\n            ),\n            r\"^guidance_in\\.out_layer\\.(weight|bias)$\": (\n                r\"time_text_embed.guidance_embedder.linear_2.\\1\",\n                None,\n                None,\n            ),\n        }\n\n        # Merge ComfyUI mappings with existing mappings (ComfyUI mappings take precedence)\n        updated_mapping = {\n            
**dit_config.arch_config.param_names_mapping,\n            **comfyui_flux_mappings,\n        }\n        dit_config.arch_config.param_names_mapping = updated_mapping\n        logger.info(\n            \"Added ComfyUI weight name mappings for Flux model. \"\n            f\"Total mappings: {len(updated_mapping)}\"\n        )\n\n        cls_name = \"FluxTransformer2DModel\"\n        model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)\n        logger.info(\"Resolved transformer class: %s\", cls_name)\n\n        original_mapping = None\n        if comfyui_flux_mappings:\n            original_mapping = model_cls.param_names_mapping\n            model_cls.param_names_mapping = updated_mapping\n            logger.info(\n                \"Temporarily updated model class param_names_mapping with ComfyUI mappings. \"\n                f\"Total mappings: {len(updated_mapping)}\"\n            )\n\n        safetensors_list = [self.model_path]\n        logger.info(\"Loading weights from: %s\", safetensors_list)\n        default_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision]\n        server_args.model_paths[\"transformer\"] = os.path.dirname(self.model_path) or \".\"\n        hf_config = {}\n\n        hidden_size = (\n            dit_config.arch_config.num_attention_heads\n            * dit_config.arch_config.attention_head_dim\n        )\n        mlp_ratio = getattr(dit_config.arch_config, \"mlp_ratio\", 4.0)\n        mlp_hidden_dim = int(hidden_size * mlp_ratio)\n        qkv_size = 3 * hidden_size\n        has_guidance_embeds = True\n\n        # Load and convert weights from safetensors file\n        model, reverse_param_names_mapping = (\n            self._load_and_convert_weights_from_safetensors(\n                model_cls=model_cls,\n                dit_config=dit_config,\n                hf_config=hf_config,\n                safetensors_list=safetensors_list,\n                updated_mapping=updated_mapping,\n                qkv_size=qkv_size,\n     
           mlp_hidden_dim=mlp_hidden_dim,\n                has_guidance_embeds=has_guidance_embeds,\n                default_dtype=default_dtype,\n            )\n        )\n\n        model = model.eval()\n        for param in model.parameters():\n            param.requires_grad = False\n\n        model.reverse_param_names_mapping = reverse_param_names_mapping\n\n        if original_mapping is not None:\n            model_cls.param_names_mapping = original_mapping\n\n        total_params = sum(p.numel() for p in model.parameters())\n        logger.info(\"Loaded transformer with %.2fB parameters\", total_params / 1e9)\n\n        components = {\n            \"transformer\": model,\n            \"scheduler\": self.modules.get(\"scheduler\"),\n        }\n\n        logger.info(\"Successfully loaded modules: %s\", list(components.keys()))\n        return components\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        logger.info(\n            \"ComfyUIFluxPipeline.create_pipeline_stages() called - creating latent_preparation_stage and denoising_stage\"\n        )\n\n        self.add_stages(\n            [\n                ComfyUILatentPreparationStage(\n                    scheduler=self.get_module(\"scheduler\"),\n                    transformer=self.get_module(\"transformer\"),\n                ),\n                DenoisingStage(\n                    transformer=self.get_module(\"transformer\"),\n                    scheduler=self.get_module(\"scheduler\"),\n                ),\n            ]\n        )\n\n        logger.info(\n            f\"ComfyUIFluxPipeline stages created: {list(self._stage_name_mapping.keys())}\"\n        )\n\n\nEntryClass = ComfyUIFluxPipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/comfyui_qwen_image_pipeline.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nimport os\nfrom itertools import chain\nfrom typing import Any\n\nimport torch\nfrom torch.distributed import init_device_mesh\nfrom torch.distributed.fsdp import MixedPrecisionPolicy\n\nfrom sglang.multimodal_gen.configs.models.dits.qwenimage import QwenImageDitConfig\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.loader.fsdp_load import (\n    load_model_from_full_model_state_dict,\n    shard_model,\n)\nfrom sglang.multimodal_gen.runtime.loader.utils import (\n    get_param_names_mapping,\n    set_default_torch_dtype,\n)\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import (\n    safetensors_weights_iterator,\n)\nfrom sglang.multimodal_gen.runtime.models.registry import ModelRegistry\nfrom sglang.multimodal_gen.runtime.models.schedulers.scheduling_comfyui_passthrough import (\n    ComfyUIPassThroughScheduler,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import (\n    ComfyUILatentPreparationStage,\n    DenoisingStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE, set_mixed_precision_policy\n\nlogger = init_logger(__name__)\n\n\nclass ComfyUIQwenImagePipelineBase(LoRAPipeline, ComposedPipelineBase):\n    \"\"\"\n    Base pipeline for ComfyUI QwenImage integration with only denoising stage.\n\n    This pipeline requires pre-processed inputs:\n    - prompt_embeds: Pre-encoded text embeddings (list of tensors)\n    - latents: Pre-processed image latents in sequence format [B, S, D]\n\n    Usage:\n        generator = DiffGenerator.from_pretrained(\n            
model_path=\"path/to/model\",\n            pipeline_class_name=\"ComfyUIQwenImagePipeline\",\n            device=\"cuda\",\n        )\n    \"\"\"\n\n    # Subclasses should override this\n    zero_cond_t: bool = False\n\n    pipeline_name = \"ComfyUIQwenImagePipeline\"\n\n    _required_config_modules = [\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def initialize_pipeline(self, server_args: ServerArgs):\n        \"\"\"\n        Initialize the pipeline with ComfyUI pass-through scheduler.\n        This scheduler does not modify latents, allowing ComfyUI to handle denoising.\n        \"\"\"\n        self.modules[\"scheduler\"] = ComfyUIPassThroughScheduler(\n            num_train_timesteps=1000\n        )\n\n        # Ensure VAE config is properly initialized even though we don't load the VAE model\n        vae_config = server_args.pipeline_config.vae_config\n        vae_config.post_init()\n        logger.info(\n            \"Called vae_config.post_init() to set vae_scale_factor. \"\n            f\"vae_scale_factor={vae_config.arch_config.vae_scale_factor}\"\n        )\n\n    def load_modules(\n        self,\n        server_args: ServerArgs,\n        loaded_modules: dict[str, torch.nn.Module] | None = None,\n    ) -> dict[str, Any]:\n        \"\"\"\n        Load modules for ComfyUIQwenImagePipeline.\n\n        If model_path is a safetensors file, load transformer directly from it\n        without requiring model_index.json. 
Otherwise, fall back to default loading.\n        \"\"\"\n        if os.path.isfile(self.model_path) and self.model_path.endswith(\".safetensors\"):\n            logger.info(\n                \"Detected safetensors file, loading transformer directly from: %s\",\n                self.model_path,\n            )\n            return self._load_transformer_from_safetensors(server_args, loaded_modules)\n        else:\n            logger.info(\n                \"Model path is a directory, using default loading method: %s\",\n                self.model_path,\n            )\n            return super().load_modules(server_args, loaded_modules)\n\n    def _load_transformer_from_safetensors(\n        self,\n        server_args: ServerArgs,\n        loaded_modules: dict[str, torch.nn.Module] | None = None,\n    ) -> dict[str, Any]:\n        \"\"\"Load transformer directly from safetensors without model_index.json.\"\"\"\n\n        # 1) Fast path: use provided module\n        if loaded_modules is not None and \"transformer\" in loaded_modules:\n            logger.info(\"Using provided transformer module\")\n            return {\n                \"transformer\": loaded_modules[\"transformer\"],\n                \"scheduler\": self.modules.get(\"scheduler\"),\n            }\n\n        # 2) Build config and mappings\n        dit_config, updated_mapping, model_cls, default_dtype = (\n            self._prepare_dit_config_and_mapping(server_args)\n        )\n        safetensors_list = [self.model_path]\n        logger.info(\"Loading weights from: %s\", safetensors_list)\n\n        # 3) Instantiate model (meta) and optionally shard\n        model = self._instantiate_model(\n            model_cls, dit_config, default_dtype, updated_mapping, server_args\n        )\n\n        # 4) Load weights\n        self._load_weights_into_model(\n            model, safetensors_list, default_dtype, updated_mapping, server_args\n        )\n\n        components = {\n            \"transformer\": model,\n  
          \"scheduler\": self.modules.get(\"scheduler\"),\n        }\n        logger.info(\"Successfully loaded modules: %s\", list(components.keys()))\n        return components\n\n    def _prepare_dit_config_and_mapping(self, server_args: ServerArgs):\n        from sglang.multimodal_gen.configs.models.dits.qwenimage import (\n            QwenImageArchConfig,\n        )\n\n        comfyui_arch_config = QwenImageArchConfig(\n            patch_size=2,\n            in_channels=64,\n            out_channels=16,\n            num_layers=60,\n            attention_head_dim=128,\n            num_attention_heads=24,\n            joint_attention_dim=3584,\n            pooled_projection_dim=768,\n            guidance_embeds=False,\n            axes_dims_rope=(16, 56, 56),\n            zero_cond_t=self.zero_cond_t,\n        )\n        dit_config = QwenImageDitConfig(arch_config=comfyui_arch_config)\n        server_args.pipeline_config.dit_config = dit_config\n\n        if dit_config.arch_config.param_names_mapping is None:\n            dit_config.arch_config.param_names_mapping = {}\n\n        comfyui_qwen_mappings = {r\"^model\\.diffusion_model\\.(.*)$\": r\"\\1\"}\n        updated_mapping = {\n            **dit_config.arch_config.param_names_mapping,\n            **comfyui_qwen_mappings,\n        }\n        dit_config.arch_config.param_names_mapping = updated_mapping\n        logger.info(\n            \"Added ComfyUI weight name mappings to param_names_mapping. 
\"\n            f\"Total mappings: {len(updated_mapping)}\"\n        )\n\n        cls_name = \"QwenImageTransformer2DModel\"\n        model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)\n        logger.info(\"Resolved transformer class: %s\", cls_name)\n\n        default_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision]\n        server_args.model_paths[\"transformer\"] = os.path.dirname(self.model_path) or \".\"\n        assert server_args.hsdp_shard_dim is not None, \"hsdp_shard_dim must be set\"\n        logger.info(\n            \"Loading %s from safetensors file, default_dtype: %s\",\n            cls_name,\n            default_dtype,\n        )\n        return dit_config, updated_mapping, model_cls, default_dtype\n\n    def _instantiate_model(\n        self,\n        model_cls,\n        dit_config,\n        default_dtype,\n        updated_mapping,\n        server_args: ServerArgs,\n    ):\n        from sglang.multimodal_gen.runtime.platforms import current_platform\n\n        hf_config = {}\n        original_mapping = model_cls.param_names_mapping\n        model_cls.param_names_mapping = updated_mapping\n        logger.info(\n            \"Temporarily updated model class param_names_mapping with ComfyUI mappings. 
\"\n            f\"Total mappings: {len(updated_mapping)}\"\n        )\n\n        try:\n            mp_policy = MixedPrecisionPolicy(\n                torch.bfloat16, torch.float32, None, cast_forward_inputs=False\n            )\n            set_mixed_precision_policy(\n                param_dtype=torch.bfloat16,\n                reduce_dtype=torch.float32,\n                output_dtype=None,\n                mp_policy=mp_policy,\n            )\n\n            with set_default_torch_dtype(default_dtype), torch.device(\"meta\"):\n                model = model_cls(**{\"config\": dit_config, \"hf_config\": hf_config})\n\n            use_fsdp = server_args.use_fsdp_inference\n            if current_platform.is_mps():\n                use_fsdp = False\n                logger.info(\"Disabling FSDP for MPS platform as it's not compatible\")\n\n            if use_fsdp:\n                device_mesh = init_device_mesh(\n                    current_platform.device_type,\n                    mesh_shape=(\n                        server_args.hsdp_replicate_dim,\n                        server_args.hsdp_shard_dim,\n                    ),\n                    mesh_dim_names=(\"replicate\", \"shard\"),\n                )\n                shard_model(\n                    model,\n                    cpu_offload=server_args.dit_cpu_offload,\n                    reshard_after_forward=True,\n                    mp_policy=mp_policy,\n                    mesh=device_mesh,\n                    fsdp_shard_conditions=model._fsdp_shard_conditions,\n                    pin_cpu_memory=server_args.pin_cpu_memory,\n                )\n        finally:\n            model_cls.param_names_mapping = original_mapping\n\n        return model\n\n    def _load_weights_into_model(\n        self,\n        model,\n        safetensors_list,\n        default_dtype,\n        updated_mapping,\n        server_args: ServerArgs,\n    ):\n        # Create weight iterator for loading\n        weight_iterator = 
safetensors_weights_iterator(safetensors_list)\n\n        # Load weights\n        param_names_mapping_fn = get_param_names_mapping(updated_mapping)\n        load_model_from_full_model_state_dict(\n            model,\n            weight_iterator,\n            get_local_torch_device(),\n            default_dtype,\n            strict=True,\n            cpu_offload=server_args.dit_cpu_offload,\n            param_names_mapping=param_names_mapping_fn,\n        )\n\n        # Check for meta parameters\n        for n, p in chain(model.named_parameters(), model.named_buffers()):\n            if p.is_meta:\n                raise RuntimeError(f\"Unexpected param or buffer {n} on meta device.\")\n            if isinstance(p, torch.nn.Parameter):\n                p.requires_grad = False\n\n        total_params = sum(p.numel() for p in model.parameters())\n        logger.info(\"Loaded transformer with %.2fB parameters\", total_params / 1e9)\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        logger.info(\n            f\"{self.__class__.__name__}.create_pipeline_stages() called - creating latent_preparation_stage and denoising_stage\"\n        )\n\n        self.add_stages(\n            [\n                ComfyUILatentPreparationStage(\n                    scheduler=self.get_module(\"scheduler\"),\n                    transformer=self.get_module(\"transformer\"),\n                ),\n                DenoisingStage(\n                    transformer=self.get_module(\"transformer\"),\n                    scheduler=self.get_module(\"scheduler\"),\n                ),\n            ]\n        )\n\n        logger.info(\n            f\"{self.__class__.__name__} stages created: {list(self._stage_name_mapping.keys())}\"\n        )\n\n\nclass ComfyUIQwenImagePipeline(ComfyUIQwenImagePipelineBase):\n    \"\"\"ComfyUI QwenImage pipeline for text-to-image generation.\"\"\"\n\n    pipeline_name = \"ComfyUIQwenImagePipeline\"\n    zero_cond_t = False\n\n    from 
sglang.multimodal_gen.configs.pipeline_configs.qwen_image import (\n        QwenImagePipelineConfig,\n    )\n    from sglang.multimodal_gen.configs.sample.qwenimage import QwenImageSamplingParams\n\n    pipeline_config_cls = QwenImagePipelineConfig\n    sampling_params_cls = QwenImageSamplingParams\n\n\nclass ComfyUIQwenImageEditPipeline(ComfyUIQwenImagePipelineBase):\n    \"\"\"ComfyUI QwenImage pipeline for image-to-image editing.\"\"\"\n\n    pipeline_name = \"ComfyUIQwenImageEditPipeline\"\n    zero_cond_t = True\n\n    from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import (\n        QwenImageEditPlusPipelineConfig,\n    )\n    from sglang.multimodal_gen.configs.sample.qwenimage import (\n        QwenImageEditPlusSamplingParams,\n    )\n\n    pipeline_config_cls = QwenImageEditPlusPipelineConfig\n    sampling_params_cls = QwenImageEditPlusSamplingParams\n\n\nEntryClass = [ComfyUIQwenImagePipeline, ComfyUIQwenImageEditPipeline]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/comfyui_zimage_pipeline.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n# SPDX-License-Identifier: Apache-2.0\n\nimport os\nimport re\nfrom collections.abc import Generator\nfrom itertools import chain\nfrom typing import Any\n\nimport torch\nfrom torch.distributed import init_device_mesh\nfrom torch.distributed.fsdp import MixedPrecisionPolicy\n\nfrom sglang.multimodal_gen.configs.models.dits.zimage import ZImageDitConfig\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.loader.fsdp_load import (\n    load_model_from_full_model_state_dict,\n    shard_model,\n)\nfrom sglang.multimodal_gen.runtime.loader.utils import (\n    get_param_names_mapping,\n    set_default_torch_dtype,\n)\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import (\n    safetensors_weights_iterator,\n)\nfrom sglang.multimodal_gen.runtime.models.registry import ModelRegistry\nfrom sglang.multimodal_gen.runtime.models.schedulers.scheduling_comfyui_passthrough import (\n    ComfyUIPassThroughScheduler,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import (\n    ComfyUILatentPreparationStage,\n    DenoisingStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE, set_mixed_precision_policy\n\nlogger = init_logger(__name__)\n\n\nclass ComfyUIZImagePipeline(LoRAPipeline, ComposedPipelineBase):\n    \"\"\"\n    Simplified pipeline for ComfyUI integration with only denoising stage.\n\n    This pipeline requires pre-processed inputs:\n    - prompt_embeds: Pre-encoded text embeddings (list of tensors)\n    - negative_prompt_embeds: Pre-encoded negative prompt embeddings (if 
using CFG)\n    - latents: Optional initial noise latents (will be generated if not provided)\n\n    Usage:\n        generator = DiffGenerator.from_pretrained(\n            model_path=\"path/to/model\",\n            pipeline_class_name=\"ComfyUIZImagePipeline\",\n            device=\"cuda\",\n        )\n    \"\"\"\n\n    pipeline_name = \"ComfyUIZImagePipeline\"\n    from sglang.multimodal_gen.configs.pipeline_configs.zimage import (\n        ZImagePipelineConfig,\n    )\n    from sglang.multimodal_gen.configs.sample.zimage import ZImageSamplingParams\n\n    pipeline_config_cls = ZImagePipelineConfig\n    sampling_params_cls = ZImageSamplingParams\n\n    _required_config_modules = [\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def initialize_pipeline(self, server_args: ServerArgs):\n        \"\"\"\n        Initialize the pipeline with ComfyUI pass-through scheduler.\n        This scheduler does not modify latents, allowing ComfyUI to handle denoising.\n        \"\"\"\n        self.modules[\"scheduler\"] = ComfyUIPassThroughScheduler(\n            num_train_timesteps=1000\n        )\n\n        # Ensure VAE config is properly initialized even though we don't load the VAE model\n        # This is necessary because get_freqs_cis uses spatial_compression_ratio\n        if hasattr(server_args.pipeline_config, \"vae_config\"):\n            vae_config = server_args.pipeline_config.vae_config\n            if hasattr(vae_config, \"post_init\") and not hasattr(\n                vae_config, \"_post_init_called\"\n            ):\n                vae_config.post_init()\n                logger.info(\n                    \"Called vae_config.post_init() to set spatial_compression_ratio. 
\"\n                    f\"spatial_compression_ratio={vae_config.arch_config.spatial_compression_ratio}\"\n                )\n\n    def load_modules(\n        self,\n        server_args: ServerArgs,\n        loaded_modules: dict[str, torch.nn.Module] | None = None,\n    ) -> dict[str, Any]:\n        \"\"\"\n        Load modules for ComfyUIZImagePipeline.\n\n        If model_path is a safetensors file, load transformer directly from it\n        without requiring model_index.json. Otherwise, fall back to default loading.\n        \"\"\"\n        if os.path.isfile(self.model_path) and self.model_path.endswith(\".safetensors\"):\n            logger.info(\n                \"Detected safetensors file, loading transformer directly from: %s\",\n                self.model_path,\n            )\n            return self._load_transformer_from_safetensors(server_args, loaded_modules)\n        else:\n            logger.info(\n                \"Model path is a directory, using default loading method: %s\",\n                self.model_path,\n            )\n            return super().load_modules(server_args, loaded_modules)\n\n    def _convert_comfyui_qkv_weights(\n        self,\n        weight_iterator: Generator[tuple[str, torch.Tensor], None, None],\n        dim: int,\n        num_heads: int,\n        num_kv_heads: int,\n    ) -> Generator[tuple[str, torch.Tensor], None, None]:\n        \"\"\"\n        Convert ComfyUI zimage qkv weights to SGLang format.\n        Splits merged qkv.weight into separate to_q, to_k, to_v weights.\n\n        Args:\n            weight_iterator: Iterator yielding (name, tensor) pairs from safetensors\n            dim: Model dimension\n            num_heads: Number of attention heads\n            num_kv_heads: Number of key-value heads\n\n        Yields:\n            (name, tensor) pairs with qkv weights split into to_q, to_k, to_v\n        \"\"\"\n        head_dim = dim // num_heads\n        q_size = dim\n        k_size = head_dim * num_kv_heads\n    
    v_size = head_dim * num_kv_heads\n\n        for name, tensor in weight_iterator:\n            # Match qkv weights in layers, noise_refiner, or context_refiner\n            # Pattern: (layers|noise_refiner|context_refiner).{i}.attention.qkv.(weight|bias)\n            match = re.match(\n                r\"(layers|noise_refiner|context_refiner)\\.(\\d+)\\.attention\\.qkv\\.(weight|bias)$\",\n                name,\n            )\n            if match:\n                module_name, layer_idx, param_type = match.groups()\n                base_name = f\"{module_name}.{layer_idx}.attention\"\n\n                if param_type == \"weight\":\n                    # Weight shape: (q_size + k_size + v_size, dim)\n                    # Split into q, k, v\n                    q_weight = tensor[:q_size, :]\n                    k_weight = tensor[q_size : q_size + k_size, :]\n                    v_weight = tensor[q_size + k_size :, :]\n\n                    logger.debug(\n                        f\"Splitting {name} (shape {tensor.shape}) into \"\n                        f\"to_q ({q_weight.shape}), to_k ({k_weight.shape}), to_v ({v_weight.shape})\"\n                    )\n\n                    yield f\"{base_name}.to_q.weight\", q_weight\n                    yield f\"{base_name}.to_k.weight\", k_weight\n                    yield f\"{base_name}.to_v.weight\", v_weight\n                else:  # bias\n                    # Bias shape: (q_size + k_size + v_size,)\n                    # Split into q, k, v\n                    q_bias = tensor[:q_size]\n                    k_bias = tensor[q_size : q_size + k_size]\n                    v_bias = tensor[q_size + k_size :]\n\n                    logger.debug(\n                        f\"Splitting {name} (shape {tensor.shape}) into \"\n                        f\"to_q ({q_bias.shape}), to_k ({k_bias.shape}), to_v ({v_bias.shape})\"\n                    )\n\n                    yield f\"{base_name}.to_q.bias\", q_bias\n                    yield 
f\"{base_name}.to_k.bias\", k_bias\n                    yield f\"{base_name}.to_v.bias\", v_bias\n            else:\n                # Pass through other weights unchanged\n                yield name, tensor\n\n    def _load_transformer_from_safetensors(\n        self,\n        server_args: ServerArgs,\n        loaded_modules: dict[str, torch.nn.Module] | None = None,\n    ) -> dict[str, Any]:\n        \"\"\"\n        Load transformer directly from safetensors file without model_index.json.\n\n        This method:\n        1. Uses hardcoded ZImageDitConfig for zimage model\n        2. Loads transformer from the safetensors file\n        3. Uses ComfyUIPassThroughScheduler (already created in initialize_pipeline)\n        \"\"\"\n        # Check if transformer is already provided\n        if loaded_modules is not None and \"transformer\" in loaded_modules:\n            logger.info(\"Using provided transformer module\")\n            components = {\n                \"transformer\": loaded_modules[\"transformer\"],\n                \"scheduler\": self.modules.get(\"scheduler\"),\n            }\n            return components\n\n        if hasattr(server_args.pipeline_config, \"dit_config\"):\n            dit_config = server_args.pipeline_config.dit_config\n            if not isinstance(dit_config, ZImageDitConfig):\n                logger.warning(\n                    \"dit_config is not ZImageDitConfig, creating new ZImageDitConfig\"\n                )\n                dit_config = ZImageDitConfig()\n                server_args.pipeline_config.dit_config = dit_config\n        else:\n            logger.info(\"Creating default ZImageDitConfig\")\n            dit_config = ZImageDitConfig()\n            server_args.pipeline_config.dit_config = dit_config\n\n        if dit_config.arch_config.param_names_mapping is None:\n            dit_config.arch_config.param_names_mapping = {}\n\n        # Add mappings for norm layers: map from ComfyUI format (k_norm/q_norm) to SGLang 
format (norm_k/norm_q)\n        # The regex matches the source name from safetensors, and the tuple specifies the target name in the model\n        # Note: qkv weights are handled separately by _convert_comfyui_qkv_weights function\n        comfyui_norm_mappings = {\n            r\"(.*)\\.attention\\.k_norm\\.weight$\": (\n                r\"\\1.attention.norm_k.weight\",\n                None,\n                None,\n            ),\n            r\"(.*)\\.attention\\.q_norm\\.weight$\": (\n                r\"\\1.attention.norm_q.weight\",\n                None,\n                None,\n            ),\n            r\"(.*)\\.attention\\.out\\.weight$\": (\n                r\"\\1.attention.to_out.0.weight\",\n                None,\n                None,\n            ),\n            r\"^final_layer\\.(.*)$\": (r\"all_final_layer.2-1.\\1\", None, None),\n            r\"^x_embedder\\.(.*)$\": (r\"all_x_embedder.2-1.\\1\", None, None),\n        }\n\n        # Merge ComfyUI mappings with existing mappings (ComfyUI mappings take precedence)\n        updated_mapping = {\n            **dit_config.arch_config.param_names_mapping,\n            **comfyui_norm_mappings,\n        }\n        dit_config.arch_config.param_names_mapping = updated_mapping\n        logger.info(\n            \"Added ComfyUI weight name mappings (k_norm/q_norm -> norm_k/norm_q) to param_names_mapping. 
\"\n            f\"Total mappings: {len(updated_mapping)}\"\n        )\n\n        cls_name = \"ZImageTransformer2DModel\"\n        model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)\n        logger.info(\"Resolved transformer class: %s\", cls_name)\n        safetensors_list = [self.model_path]\n        logger.info(\"Loading weights from: %s\", safetensors_list)\n\n        default_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision]\n        server_args.model_paths[\"transformer\"] = os.path.dirname(self.model_path) or \".\"\n        hf_config = {}\n\n        assert server_args.hsdp_shard_dim is not None, \"hsdp_shard_dim must be set\"\n        logger.info(\n            \"Loading %s from safetensors file, default_dtype: %s\",\n            cls_name,\n            default_dtype,\n        )\n\n        original_mapping = model_cls.param_names_mapping\n        model_cls.param_names_mapping = updated_mapping\n        logger.info(\n            \"Temporarily updated model class param_names_mapping with ComfyUI mappings. 
\"\n            f\"Total mappings: {len(updated_mapping)}\"\n        )\n\n        try:\n            # Create model first (same as maybe_load_fsdp_model)\n            from sglang.multimodal_gen.runtime.platforms import current_platform\n\n            mp_policy = MixedPrecisionPolicy(\n                torch.bfloat16, torch.float32, None, cast_forward_inputs=False\n            )\n\n            set_mixed_precision_policy(\n                param_dtype=torch.bfloat16,\n                reduce_dtype=torch.float32,\n                output_dtype=None,\n                mp_policy=mp_policy,\n            )\n\n            with set_default_torch_dtype(default_dtype), torch.device(\"meta\"):\n                model = model_cls(**{\"config\": dit_config, \"hf_config\": hf_config})\n\n            # Check if we should use FSDP\n            use_fsdp = server_args.use_fsdp_inference\n            if current_platform.is_mps():\n                use_fsdp = False\n                logger.info(\"Disabling FSDP for MPS platform as it's not compatible\")\n\n            if use_fsdp:\n                world_size = server_args.hsdp_replicate_dim * server_args.hsdp_shard_dim\n                device_mesh = init_device_mesh(\n                    current_platform.device_type,\n                    mesh_shape=(\n                        server_args.hsdp_replicate_dim,\n                        server_args.hsdp_shard_dim,\n                    ),\n                    mesh_dim_names=(\"replicate\", \"shard\"),\n                )\n                shard_model(\n                    model,\n                    cpu_offload=server_args.dit_cpu_offload,\n                    reshard_after_forward=True,\n                    mp_policy=mp_policy,\n                    mesh=device_mesh,\n                    fsdp_shard_conditions=model._fsdp_shard_conditions,\n                    pin_cpu_memory=server_args.pin_cpu_memory,\n                )\n\n            # Get model dimensions for qkv splitting\n            arch_config = 
dit_config.arch_config\n            dim = arch_config.dim\n            num_heads = arch_config.num_attention_heads\n            num_kv_heads = arch_config.n_kv_heads\n\n            # Create weight iterator with qkv conversion\n            base_weight_iterator = safetensors_weights_iterator(safetensors_list)\n            converted_weight_iterator = self._convert_comfyui_qkv_weights(\n                base_weight_iterator, dim, num_heads, num_kv_heads\n            )\n\n            # Load weights\n            param_names_mapping_fn = get_param_names_mapping(updated_mapping)\n            load_model_from_full_model_state_dict(\n                model,\n                converted_weight_iterator,\n                get_local_torch_device(),\n                default_dtype,\n                strict=True,\n                cpu_offload=server_args.dit_cpu_offload,\n                param_names_mapping=param_names_mapping_fn,\n            )\n\n            # Check for meta parameters\n            for n, p in chain(model.named_parameters(), model.named_buffers()):\n                if p.is_meta:\n                    raise RuntimeError(\n                        f\"Unexpected param or buffer {n} on meta device.\"\n                    )\n                if isinstance(p, torch.nn.Parameter):\n                    p.requires_grad = False\n        finally:\n            model_cls.param_names_mapping = original_mapping\n\n        total_params = sum(p.numel() for p in model.parameters())\n        logger.info(\"Loaded transformer with %.2fB parameters\", total_params / 1e9)\n\n        components = {\n            \"transformer\": model,\n            \"scheduler\": self.modules.get(\"scheduler\"),\n        }\n\n        logger.info(\"Successfully loaded modules: %s\", list(components.keys()))\n        return components\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        logger.info(\n            \"ComfyUIZImagePipeline.create_pipeline_stages() called - creating 
latent_preparation_stage and denoising_stage\"\n        )\n\n        self.add_stages(\n            [\n                ComfyUILatentPreparationStage(\n                    scheduler=self.get_module(\"scheduler\"),\n                    transformer=self.get_module(\"transformer\"),\n                ),\n                DenoisingStage(\n                    transformer=self.get_module(\"transformer\"),\n                    scheduler=self.get_module(\"scheduler\"),\n                ),\n            ]\n        )\n\n        logger.info(\n            f\"ComfyUIZImagePipeline stages created: {list(self._stage_name_mapping.keys())}\"\n        )\n\n\nEntryClass = ComfyUIZImagePipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nDiffusers backend pipeline wrapper.\n\nThis module provides a wrapper that allows running any diffusers-supported model\nthrough sglang's infrastructure using vanilla diffusers pipelines.\n\"\"\"\n\nimport argparse\nimport inspect\nimport re\nimport warnings\nfrom io import BytesIO\nfrom typing import Any\n\nimport numpy as np\nimport requests\nimport torch\nimport torchvision.transforms as T\nfrom diffusers import DiffusionPipeline\nfrom PIL import Image\n\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import PipelineConfig\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.executors.pipeline_executor import (\n    PipelineExecutor,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.executors.sync_executor import (\n    SyncExecutor,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import PipelineStage\nfrom sglang.multimodal_gen.runtime.platforms import (\n    AttentionBackendEnum,\n    current_platform,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass DiffusersExecutionStage(PipelineStage):\n    \"\"\"Pipeline stage that wraps diffusers pipeline execution.\"\"\"\n\n    def __init__(self, diffusers_pipe: DiffusionPipeline):\n        super().__init__()\n        self.diffusers_pipe = diffusers_pipe\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        \"\"\"Execute the diffusers pipeline.\"\"\"\n\n        kwargs = self._build_pipeline_kwargs(batch)\n\n        # 
Filter kwargs to only those supported by the pipeline, warn about ignored args\n        kwargs, _ = self._filter_pipeline_kwargs(kwargs)\n\n        # Request tensor output for cleaner handling\n        if \"output_type\" not in kwargs:\n            kwargs[\"output_type\"] = \"pt\"\n\n        with torch.no_grad(), warnings.catch_warnings(record=True):\n            warnings.simplefilter(\"always\")\n            try:\n                output = self.diffusers_pipe(**kwargs)\n            except TypeError as e:\n                # Some pipelines don't support output_type=\"pt\"\n                if \"output_type\" in str(e):\n                    kwargs.pop(\"output_type\", None)\n                    output = self.diffusers_pipe(**kwargs)\n                else:\n                    raise\n\n        batch.output = self._extract_output(output)\n        if batch.output is not None:\n            batch.output = self._postprocess_output(batch.output)\n\n        return batch\n\n    def _filter_pipeline_kwargs(\n        self, kwargs: dict[str, Any], *, strict: bool = False\n    ) -> tuple[dict[str, Any], list[str]]:\n        \"\"\"Filter kwargs to those accepted by the pipeline's __call__.\n\n        Args:\n            kwargs: Arguments to filter\n            strict: If True, raise ValueError on unsupported args; otherwise warn\n\n        Returns:\n            Tuple of (filtered_kwargs, ignored_keys)\n        \"\"\"\n        try:\n            sig = inspect.signature(self.diffusers_pipe.__call__)\n        except (ValueError, TypeError):\n            return kwargs, []\n\n        params = sig.parameters\n        accepts_var_kwargs = any(\n            p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()\n        )\n        if accepts_var_kwargs:\n            return kwargs, []\n\n        valid = set(params.keys()) - {\"self\"}\n\n        filtered = {}\n        ignored = []\n        for k, v in kwargs.items():\n            if k in valid:\n                filtered[k] = v\n      
      else:\n                ignored.append(k)\n\n        if ignored:\n            pipe_name = type(self.diffusers_pipe).__name__\n            msg = (\n                f\"Pipeline '{pipe_name}' does not support: {', '.join(sorted(ignored))}. \"\n                \"These arguments will be ignored.\"\n            )\n            if strict:\n                raise ValueError(msg)\n            logger.warning(msg)\n\n        return filtered, ignored\n\n    def _extract_output(self, output: Any) -> torch.Tensor | None:\n        \"\"\"Extract tensor output from pipeline result.\"\"\"\n        for attr in [\"images\", \"frames\", \"video\", \"sample\", \"pred_original_sample\"]:\n            data = getattr(output, attr, None)\n            if data is None:\n                continue\n\n            result = self._convert_to_tensor(data)\n            if result is not None:\n                logger.debug(\n                    \"Extracted output from '%s': shape=%s, dtype=%s\",\n                    attr,\n                    result.shape,\n                    result.dtype,\n                )\n                return result\n\n        logger.warning(\"Could not extract output from pipeline result\")\n        return None\n\n    def _convert_to_tensor(self, data: Any) -> torch.Tensor | None:\n        \"\"\"Convert various data formats to a tensor.\"\"\"\n        if isinstance(data, torch.Tensor):\n            return data\n\n        if isinstance(data, np.ndarray):\n            tensor = torch.from_numpy(data).float()\n            if tensor.max() > 1.0:\n                tensor = tensor / 255.0\n            # (B, H, W, C) -> (B, C, H, W) or (B, T, H, W, C) -> (B, C, T, H, W)\n            if tensor.ndim == 4:\n                tensor = tensor.permute(0, 3, 1, 2)\n            elif tensor.ndim == 5:\n                tensor = tensor.permute(0, 4, 1, 2, 3)\n            return tensor\n\n        if isinstance(data, Image.Image):\n            return T.ToTensor()(data)\n\n        if isinstance(data, 
list) and len(data) > 0:\n            return self._convert_list_to_tensor(data)\n\n        return None\n\n    def _convert_list_to_tensor(self, data: list) -> torch.Tensor | None:\n        \"\"\"Convert a list of items to a tensor.\"\"\"\n        first = data[0]\n\n        # Nested list (e.g., [[frame1, frame2, ...]] for video batches)\n        if isinstance(first, list) and len(first) > 0:\n            data = first\n            first = data[0]\n\n        if isinstance(first, Image.Image):\n            tensors = [T.ToTensor()(img) for img in data]\n            stacked = torch.stack(tensors)\n            if len(tensors) > 1:\n                return stacked.permute(1, 0, 2, 3)  # (T, C, H, W) -> (C, T, H, W)\n            return stacked[0]\n\n        if isinstance(first, torch.Tensor):\n            stacked = torch.stack(data)\n            if len(data) > 1:\n                return stacked.permute(1, 0, 2, 3)\n            return stacked[0]\n\n        if isinstance(first, np.ndarray):\n            tensors = [torch.from_numpy(arr).float() for arr in data]\n            if tensors[0].max() > 1.0:\n                tensors = [t / 255.0 for t in tensors]\n            if tensors[0].ndim == 3:\n                tensors = [t.permute(2, 0, 1) for t in tensors]\n            stacked = torch.stack(tensors)\n            if len(data) > 1:\n                return stacked.permute(1, 0, 2, 3)\n            return stacked[0]\n\n        return None\n\n    def _postprocess_output(self, output: torch.Tensor) -> torch.Tensor:\n        \"\"\"Post-process output tensor to ensure valid values and correct shape.\"\"\"\n        output = output.cpu().float()\n\n        # Handle NaN or Inf values\n        if torch.isnan(output).any() or torch.isinf(output).any():\n            logger.warning(\"Output contains invalid values, fixing...\")\n            output = torch.nan_to_num(output, nan=0.5, posinf=1.0, neginf=0.0)\n\n        # Normalize to [0, 1] range if needed\n        min_val, max_val = 
output.min().item(), output.max().item()\n        if min_val < -0.5 or max_val > 1.5:\n            output = (output + 1) / 2\n\n        output = output.clamp(0, 1)\n\n        # Ensure correct shape for downstream processing\n        output = self._fix_output_shape(output)\n\n        logger.debug(\"Final output tensor shape: %s\", output.shape)\n        return output\n\n    def _fix_output_shape(self, output: torch.Tensor) -> torch.Tensor:\n        \"\"\"Fix tensor shape for downstream processing.\n\n        Expected: (B, C, H, W) for images or (B, C, T, H, W) for videos.\n        \"\"\"\n        if output.dim() == 5:\n            # Video: (B, T, C, H, W) -> (B, C, T, H, W)\n            return output.permute(0, 2, 1, 3, 4)\n\n        if output.dim() == 4:\n            if output.shape[0] == 1 or output.shape[1] in [1, 3, 4]:\n                return output  # Already (B, C, H, W)\n            # (T, C, H, W) -> (1, C, T, H, W)\n            return output.unsqueeze(0).permute(0, 2, 1, 3, 4)\n\n        if output.dim() == 3:\n            c, h, w = output.shape\n            if c > 4 and w <= 4:\n                output = output.permute(2, 0, 1)\n            if output.shape[0] == 1:\n                output = output.repeat(3, 1, 1)\n            return output.unsqueeze(0)\n\n        if output.dim() == 2:\n            return output.unsqueeze(0).repeat(3, 1, 1).unsqueeze(0)\n\n        return output\n\n    def _build_pipeline_kwargs(self, batch: Req) -> dict[str, Any]:\n        \"\"\"Build kwargs dict for diffusers pipeline call.\"\"\"\n        kwargs = {}\n\n        if batch.prompt is not None:\n            kwargs[\"prompt\"] = batch.prompt\n\n        if batch.negative_prompt:\n            kwargs[\"negative_prompt\"] = batch.negative_prompt\n\n        if batch.num_inference_steps is not None:\n            kwargs[\"num_inference_steps\"] = batch.num_inference_steps\n\n        if batch.guidance_scale is not None:\n            kwargs[\"guidance_scale\"] = batch.guidance_scale\n\n    
    if batch.true_cfg_scale is not None:\n            kwargs[\"true_cfg_scale\"] = batch.true_cfg_scale\n\n        if batch.height is not None:\n            kwargs[\"height\"] = batch.height\n\n        if batch.width is not None:\n            kwargs[\"width\"] = batch.width\n\n        if batch.num_frames is not None and batch.num_frames > 1:\n            kwargs[\"num_frames\"] = batch.num_frames\n\n        # Generator for reproducibility\n        if batch.generator is not None:\n            kwargs[\"generator\"] = batch.generator\n        elif batch.seed is not None:\n            device = self._get_generator_device(batch)\n            kwargs[\"generator\"] = torch.Generator(device=device).manual_seed(batch.seed)\n\n        # Image input for img2img or inpainting\n        image = self._load_input_image(batch)\n        if image is not None:\n            kwargs[\"image\"] = image\n\n        if batch.num_outputs_per_prompt > 1:\n            kwargs[\"num_images_per_prompt\"] = batch.num_outputs_per_prompt\n\n        # Extra diffusers-specific kwargs\n        if batch.extra:\n            diffusers_kwargs = batch.extra.get(\"diffusers_kwargs\", {})\n            if diffusers_kwargs:\n                kwargs.update(diffusers_kwargs)\n\n        return kwargs\n\n    def _get_generator_device(self, batch: Req) -> str:\n        \"\"\"Resolve RNG device consistently with the non-diffusers path.\n\n        Diffusers CPU offload can temporarily park modules on CPU, but that\n        should not silently switch a CUDA request to CPU RNG, otherwise the\n        same seed produces different outputs depending on runtime placement.\n        \"\"\"\n        if batch.generator_device == \"cpu\":\n            return \"cpu\"\n        return current_platform.device_type\n\n    def _load_input_image(self, batch: Req) -> Image.Image | None:\n        \"\"\"Load input image from batch.\"\"\"\n        # Check for PIL image in condition_image or pixel_values\n        if batch.condition_image is not 
None and isinstance(\n            batch.condition_image, Image.Image\n        ):\n            return batch.condition_image\n        if batch.pixel_values is not None and isinstance(\n            batch.pixel_values, Image.Image\n        ):\n            return batch.pixel_values\n\n        if not batch.image_path:\n            return None\n\n        if isinstance(batch.image_path, list):\n            batch.image_path = batch.image_path[0]\n\n        try:\n            if batch.image_path.startswith((\"http://\", \"https://\")):\n                response = requests.get(batch.image_path, timeout=30)\n                response.raise_for_status()\n                return Image.open(BytesIO(response.content)).convert(\"RGB\")\n            return Image.open(batch.image_path).convert(\"RGB\")\n        except Exception as e:\n            logger.error(\"Failed to load image from %s: %s\", batch.image_path, e)\n            return None\n\n\nclass DiffusersPipeline(ComposedPipelineBase):\n    \"\"\"\n    Pipeline wrapper that uses vanilla diffusers pipelines.\n\n    This allows running any diffusers-supported model through sglang's infrastructure\n    without requiring native sglang implementation.\n    \"\"\"\n\n    pipeline_name = \"DiffusersPipeline\"\n    is_video_pipeline = False\n    _required_config_modules: list[str] = []\n\n    def __init__(\n        self,\n        model_path: str,\n        server_args: ServerArgs,\n        required_config_modules: list[str] | None = None,\n        loaded_modules: dict[str, torch.nn.Module] | None = None,\n        executor: PipelineExecutor | None = None,\n    ):\n        self.server_args = server_args\n        self.model_path = model_path\n        self._stages: list[PipelineStage] = []\n        self._stage_name_mapping: dict[str, PipelineStage] = {}\n        self.modules: dict[str, Any] = {}\n        self.memory_usages: dict[str, float] = {}\n        self.post_init_called = False\n        self.executor = executor or 
SyncExecutor(server_args=server_args)\n        self._cache_dit_enabled = False\n\n        logger.info(\"Loading diffusers pipeline from %s\", model_path)\n        self.diffusers_pipe = self._load_diffusers_pipeline(model_path, server_args)\n        self._detect_pipeline_type()\n\n    def _load_diffusers_pipeline(\n        self, model_path: str, server_args: ServerArgs\n    ) -> DiffusionPipeline:\n        \"\"\"Load the diffusers pipeline.\n\n        Optimizations applied:\n        - device_map: Loads models directly to GPU, warming up CUDA caching allocator\n          to avoid small tensor allocations during inference.\n        - Parallel shard loading: When using device_map with accelerate, model shards\n          are loaded in parallel for faster initialization.\n        \"\"\"\n\n        original_model_path = model_path  # Keep original for custom_pipeline\n        model_path = maybe_download_model(model_path, force_diffusers_model=True)\n        self.model_path = model_path\n\n        dtype = self._get_dtype(server_args)\n        logger.info(\"Loading diffusers pipeline with dtype=%s\", dtype)\n\n        # Build common kwargs for from_pretrained\n        load_kwargs = {\n            \"torch_dtype\": dtype,\n            \"trust_remote_code\": server_args.trust_remote_code,\n            \"revision\": server_args.revision,\n        }\n\n        # Add quantization config if provided (e.g., BitsAndBytesConfig for 4/8-bit)\n        quant_config = getattr(server_args.pipeline_config, \"quantization_config\", None)\n        if quant_config is not None:\n            load_kwargs[\"quantization_config\"] = quant_config\n            logger.info(\"Using quantization config: %s\", type(quant_config).__name__)\n\n        try:\n            pipe = DiffusionPipeline.from_pretrained(model_path, **load_kwargs)\n        except AttributeError as e:\n            if \"has no attribute\" in str(e):\n                # Custom pipeline class not in diffusers - try loading with 
custom_pipeline\n                logger.info(\n                    \"Pipeline class not found in diffusers, trying custom_pipeline from repo...\"\n                )\n                try:\n                    custom_kwargs = {\n                        **load_kwargs,\n                        \"custom_pipeline\": original_model_path,\n                    }\n                    custom_kwargs[\"trust_remote_code\"] = True\n                    pipe = DiffusionPipeline.from_pretrained(\n                        model_path, **custom_kwargs\n                    )\n                except Exception as e2:\n                    match = re.search(r\"has no attribute (\\w+)\", str(e))\n                    class_name = match.group(1) if match else \"unknown\"\n                    raise RuntimeError(\n                        f\"Pipeline class '{class_name}' not found in diffusers and no custom pipeline.py in repo. \"\n                        f\"Try: pip install --upgrade diffusers (some pipelines require latest version). 
\"\n                        f\"Original error: {e}\"\n                    ) from e2\n            else:\n                raise\n        except Exception as e:\n            # Only retry with float32 for dtype-related errors\n            if \"dtype\" in str(e).lower() or \"float\" in str(e).lower():\n                logger.warning(\n                    \"Failed with dtype=%s, falling back to float32: %s\", dtype, e\n                )\n                load_kwargs[\"torch_dtype\"] = torch.float32\n                pipe = DiffusionPipeline.from_pretrained(model_path, **load_kwargs)\n            else:\n                raise\n\n        # Use CPU offload (all-or-nothing in diffusers) if any component offload is requested.\n        any_offload = (\n            server_args.dit_cpu_offload\n            or server_args.text_encoder_cpu_offload\n            or server_args.image_encoder_cpu_offload\n            or server_args.vae_cpu_offload\n        )\n        if any_offload:\n            device = get_local_torch_device()\n            gpu_id = device.index if device.index is not None else 0\n            pipe.enable_model_cpu_offload(gpu_id=gpu_id)\n            logger.info(\n                \"Enabled model CPU offload for diffusers pipeline (gpu_id=%d)\", gpu_id\n            )\n        else:\n            pipe = pipe.to(get_local_torch_device())\n        # Apply VAE memory optimizations from pipeline config\n        self._apply_vae_optimizations(pipe, server_args)\n        # Apply attention backend if specified\n        self._apply_attention_backend(pipe, server_args)\n        # Apply cache-dit acceleration if configured\n        pipe = self._apply_cache_dit(pipe, server_args)\n        # Apply torch.compile if enabled and supported\n        pipe = self._apply_torch_compile(pipe, server_args)\n        logger.info(\"Loaded diffusers pipeline: %s\", pipe.__class__.__name__)\n        return pipe\n\n    def _apply_vae_optimizations(\n        self, pipe: DiffusionPipeline, server_args: 
ServerArgs\n    ) -> None:\n        \"\"\"Apply VAE memory optimizations (tiling, slicing) from pipeline config.\"\"\"\n        config = server_args.pipeline_config\n\n        # VAE slicing: decode latents slice-by-slice for lower peak memory\n        # https://huggingface.co/docs/diffusers/optimization/memory#vae-slicing\n        if config.vae_slicing:\n            if hasattr(pipe, \"vae\") and hasattr(pipe.vae, \"enable_slicing\"):\n                pipe.vae.enable_slicing()\n                logger.info(\"Enabled VAE slicing for lower memory usage\")\n            elif hasattr(pipe, \"enable_vae_slicing\"):\n                pipe.enable_vae_slicing()\n                logger.info(\"Enabled VAE slicing for lower memory usage\")\n            else:\n                logger.warning(\n                    \"VAE slicing is not available: neither \"\n                    \"`pipe.vae.enable_slicing()` nor `pipe.enable_vae_slicing()` was found.\"\n                )\n\n        # VAE tiling: decode latents tile-by-tile for large images\n        # https://huggingface.co/docs/diffusers/optimization/memory#vae-tiling\n        if config.vae_tiling:\n            if hasattr(pipe, \"vae\") and hasattr(pipe.vae, \"enable_tiling\"):\n                pipe.vae.enable_tiling()\n                logger.info(\"Enabled VAE tiling for large image support\")\n            elif hasattr(pipe, \"enable_vae_tiling\"):\n                pipe.enable_vae_tiling()\n                logger.info(\"Enabled VAE tiling for large image support\")\n            else:\n                logger.warning(\n                    \"VAE tiling is not available: neither \"\n                    \"`pipe.vae.enable_tiling()` nor `pipe.enable_vae_tiling()` was found.\"\n                )\n\n    def _apply_attention_backend(\n        self, pipe: DiffusionPipeline, server_args: ServerArgs\n    ) -> None:\n        \"\"\"Apply attention backend setting from pipeline config or server_args.\n\n        See: 
https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends\n        Available backends: flash, _flash_3_hub, sage, xformers, native, etc.\n        \"\"\"\n        backend = server_args.attention_backend\n\n        if backend is None:\n            backend = getattr(\n                server_args.pipeline_config, \"diffusers_attention_backend\", None\n            )\n\n        if backend is None:\n            return\n\n        backend = backend.lower()\n        sglang_backends = {e.name.lower() for e in AttentionBackendEnum} | {\n            \"fa3\",\n            \"fa4\",\n        }\n        if backend in sglang_backends:\n            logger.debug(\n                \"Skipping diffusers attention backend '%s' because it matches a \"\n                \"SGLang backend name. Use diffusers backend names when running \"\n                \"the diffusers backend.\",\n                backend,\n            )\n            return\n\n        for component_name in [\"transformer\", \"unet\"]:\n            component = getattr(pipe, component_name, None)\n            if component is not None and hasattr(component, \"set_attention_backend\"):\n                try:\n                    component.set_attention_backend(backend)\n                    logger.info(\n                        \"Set attention backend '%s' on %s\", backend, component_name\n                    )\n                except Exception as e:\n                    logger.warning(\n                        \"Failed to set attention backend '%s' on %s: %s\",\n                        backend,\n                        component_name,\n                        e,\n                    )\n\n    def _apply_cache_dit(\n        self, pipe: DiffusionPipeline, server_args: ServerArgs\n    ) -> DiffusionPipeline:\n        \"\"\"Enable cache-dit for diffusers pipeline if configured.\"\"\"\n        cache_dit_config = server_args.cache_dit_config\n        if not cache_dit_config:\n            return pipe\n\n        try:\n  
          import cache_dit\n        except ImportError as e:\n            raise RuntimeError(\n                \"cache-dit is required for --cache-dit-config. \"\n                \"Install it with `pip install cache-dit`.\"\n            ) from e\n\n        if not hasattr(cache_dit, \"load_configs\"):\n            raise RuntimeError(\n                \"cache-dit>=1.2.0 is required for --cache-dit-config. \"\n                \"Please upgrade cache-dit.\"\n            )\n\n        try:\n            cache_options = cache_dit.load_configs(cache_dit_config)\n        except Exception as e:\n            raise ValueError(\n                \"Failed to load cache-dit config. Provide a YAML/JSON path (or a dict \"\n                \"supported by cache-dit>=1.2.0).\"\n            ) from e\n\n        try:\n            pipe = cache_dit.enable_cache(pipe, **cache_options)\n        except Exception:\n            # cache-dit is an external integration and can raise a variety of errors.\n            logger.exception(\"Failed to enable cache-dit for diffusers pipeline\")\n            raise\n\n        logger.info(\"Enabled cache-dit for diffusers pipeline\")\n        self._cache_dit_enabled = True\n        return pipe\n\n    def _apply_torch_compile(self, pipe: Any, server_args: ServerArgs) -> Any:\n        \"\"\"Apply torch.compile to the pipeline if configured and supported.\"\"\"\n        if not server_args.enable_torch_compile:\n            return pipe\n\n        # Check whether the pipeline has 'transformer' or 'unet' components, which are\n        # typically the most expensive parts to compile. Some video pipelines\n        # (e.g., the Wan 2.2 series) also have 'transformer_2', so check for that too.\n        compilable_components = [\"transformer\", \"transformer_2\", \"unet\"]\n        if not any(hasattr(pipe, comp) for comp in compilable_components):\n            logger.warning(\n                \"Pipeline does not have 'transformer' or 'unet' components. 
\"\n                \"torch.compile may not provide significant benefits and could increase latency.\"\n            )\n            return pipe\n\n        if self._cache_dit_enabled:\n            try:\n                import cache_dit\n\n                if hasattr(cache_dit, \"set_compile_configs\"):\n                    cache_dit.set_compile_configs()\n            except Exception as e:\n                logger.warning(\n                    f\"Failed to set torch_compile configs for cache-dit: {e}\"\n                )\n\n        for comp in compilable_components:\n            if hasattr(pipe, comp):\n                try:\n                    component = getattr(pipe, comp)\n                    # TODO(DefTruth): Add support for 'compile_repeated_blocks' for 'transformer'\n                    # modules which can significantly reduce compilation time for large models\n                    # with repeated blocks.\n                    if isinstance(component, torch.nn.Module) and hasattr(\n                        component, \"compile\"\n                    ):\n                        # Prefer in-place compilation if supported. 
According to PyTorch documentation:\n                        # https://docs.pytorch.org/docs/stable/generated/torch.compile.html\n                        component.compile()\n                    else:\n                        compiled_component = torch.compile(component)\n                        setattr(pipe, comp, compiled_component)\n                    logger.info(\n                        f\"Applied torch.compile to {comp} component of the pipeline\"\n                    )\n                except Exception as e:\n                    logger.warning(f\"Failed to apply torch.compile to {comp}: {e}\")\n\n        return pipe\n\n    def _get_dtype(self, server_args: ServerArgs) -> torch.dtype:\n        dtype = (\n            torch.bfloat16\n            if torch.get_device_module().is_bf16_supported()\n            else torch.float16\n        )\n\n        dit_precision = server_args.pipeline_config.dit_precision\n        if dit_precision == \"fp16\":\n            dtype = torch.float16\n        elif dit_precision == \"bf16\":\n            dtype = torch.bfloat16\n        elif dit_precision == \"fp32\":\n            dtype = torch.float32\n\n        return dtype\n\n    def _detect_pipeline_type(self) -> None:\n        \"\"\"Detect if this is an image or video pipeline.\"\"\"\n        pipe_class_name = self.diffusers_pipe.__class__.__name__.lower()\n        video_indicators = [\"video\", \"animat\", \"cogvideo\", \"wan\", \"hunyuan\"]\n        self.is_video_pipeline = any(ind in pipe_class_name for ind in video_indicators)\n        logger.debug(\n            \"Detected pipeline type: %s\",\n            \"video\" if self.is_video_pipeline else \"image\",\n        )\n\n    def load_modules(\n        self,\n        server_args: ServerArgs,\n        loaded_modules: dict[str, torch.nn.Module] | None = None,\n    ) -> dict[str, Any]:\n        \"\"\"Skip sglang's module loading - diffusers handles it.\"\"\"\n        return {\"diffusers_pipeline\": self.diffusers_pipe}\n\n    def 
create_pipeline_stages(self, server_args: ServerArgs) -> None:\n        \"\"\"Create the execution stage wrapping the diffusers pipeline.\"\"\"\n        self.add_stage(\n            stage_name=\"diffusers_execution\",\n            stage=DiffusersExecutionStage(self.diffusers_pipe),\n        )\n\n    def initialize_pipeline(self, server_args: ServerArgs) -> None:\n        pass\n\n    def post_init(self) -> None:\n        \"\"\"Post initialization hook.\"\"\"\n        if self.post_init_called:\n            return\n        self.post_init_called = True\n        self.initialize_pipeline(self.server_args)\n        self.create_pipeline_stages(self.server_args)\n\n    def add_stage(\n        self, stage_name: str | None, stage: PipelineStage\n    ) -> \"DiffusersPipeline\":\n        \"\"\"Add a stage to the pipeline and return self to allow chaining.\"\"\"\n        if stage_name is None:\n            stage_name = self._infer_stage_name(stage)\n        if stage_name in self._stage_name_mapping:\n            raise ValueError(f\"Duplicate stage name detected: {stage_name}\")\n\n        self._stages.append(stage)\n        self._stage_name_mapping[stage_name] = stage\n        return self\n\n    @property\n    def stages(self) -> list[PipelineStage]:\n        \"\"\"List of stages in the pipeline.\"\"\"\n        return self._stages\n\n    @torch.no_grad()\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        \"\"\"Execute the pipeline on the given batch.\"\"\"\n        if not self.post_init_called:\n            self.post_init()\n        return self.executor.execute(self.stages, batch, server_args)\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        model_path: str,\n        device: str | None = None,\n        torch_dtype: torch.dtype | None = None,\n        pipeline_config: str | PipelineConfig | None = None,\n        args: argparse.Namespace | None = None,\n        required_config_modules: list[str] | None = None,\n        loaded_modules: dict[str, torch.nn.Module] | None = None,\n        **kwargs,\n    ) -> 
\"DiffusersPipeline\":\n        \"\"\"Load a pipeline from a pretrained model using diffusers backend.\"\"\"\n        kwargs[\"model_path\"] = model_path\n        server_args = ServerArgs.from_kwargs(**kwargs)\n\n        pipe = cls(\n            model_path,\n            server_args,\n            required_config_modules=required_config_modules,\n            loaded_modules=loaded_modules,\n        )\n        pipe.post_init()\n        return pipe\n\n    def get_module(self, module_name: str, default_value: Any = None) -> Any:\n        \"\"\"Get a module by name.\"\"\"\n        if module_name == \"diffusers_pipeline\":\n            return self.diffusers_pipe\n        return self.modules.get(module_name, default_value)\n\n\nEntryClass = DiffusersPipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/flux.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\nfrom sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import (\n    InputValidationStage,\n    TextEncodingStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\n# TODO(will): move PRECISION_TO_TYPE to better place\n\nlogger = init_logger(__name__)\n\n\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    max_shift: float = 1.15,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\ndef prepare_mu(batch: Req, server_args: ServerArgs):\n    height = batch.height\n    width = batch.width\n    vae_scale_factor = (\n        server_args.pipeline_config.vae_config.arch_config.vae_scale_factor\n    )\n    image_seq_len = (int(height) // vae_scale_factor) * (int(width) // vae_scale_factor)\n\n    mu = calculate_shift(\n        image_seq_len,\n        # hard code, since scheduler_config is not in PipelineConfig now\n        256,\n        4096,\n        0.5,\n        1.15,\n    )\n    return \"mu\", mu\n\n\nclass FluxPipeline(LoRAPipeline, ComposedPipelineBase):\n    pipeline_name = \"FluxPipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"text_encoder_2\",\n        \"tokenizer\",\n        \"tokenizer_2\",\n        \"vae\",\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        
self.add_stage(InputValidationStage())\n\n        self.add_stage(\n            TextEncodingStage(\n                text_encoders=[\n                    self.get_module(\"text_encoder\"),\n                    self.get_module(\"text_encoder_2\"),\n                ],\n                tokenizers=[\n                    self.get_module(\"tokenizer\"),\n                    self.get_module(\"tokenizer_2\"),\n                ],\n            ),\n            \"prompt_encoding_stage_primary\",\n        )\n\n        self.add_standard_timestep_preparation_stage(prepare_extra_kwargs=[prepare_mu])\n        self.add_standard_latent_preparation_stage()\n        self.add_standard_denoising_stage()\n        self.add_standard_decoding_stage()\n\n\nEntryClass = FluxPipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/flux_2.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n# SPDX-License-Identifier: Apache-2.0\n\nfrom diffusers.image_processor import VaeImageProcessor\n\nfrom sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline, Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\ndef compute_empirical_mu(batch: Req, server_args: ServerArgs):\n    num_steps = batch.num_inference_steps\n    image_seq_len = batch.raw_latent_shape[1]\n    a1, b1 = 8.73809524e-05, 1.89833333\n    a2, b2 = 0.00016927, 0.45666666\n\n    if image_seq_len > 4300:\n        mu = a2 * image_seq_len + b2\n        return \"mu\", float(mu)\n\n    m_200 = a2 * image_seq_len + b2\n    m_10 = a1 * image_seq_len + b1\n\n    a = (m_200 - m_10) / 190.0\n    b = m_200 - 200.0 * a\n    mu = a * num_steps + b\n\n    return \"mu\", float(mu)\n\n\nclass Flux2Pipeline(LoRAPipeline, ComposedPipelineBase):\n    pipeline_name = \"Flux2Pipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"tokenizer\",\n        \"vae\",\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        vae_image_processor = VaeImageProcessor(\n            vae_scale_factor=server_args.pipeline_config.vae_config.arch_config.vae_scale_factor\n            * 2\n        )\n\n        self.add_standard_ti2i_stages(\n            include_input_validation=True,\n            vae_image_processor=vae_image_processor,\n            prompt_encoding=\"text\",\n            image_vae_stage_kwargs={\"vae_image_processor\": vae_image_processor},\n            prepare_extra_timestep_kwargs=[compute_empirical_mu],\n        )\n\n\nEntryClass = Flux2Pipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/flux_2_klein.py",
    "content": "from sglang.multimodal_gen.runtime.pipelines.flux_2 import Flux2Pipeline\n\n\nclass Flux2KleinPipeline(Flux2Pipeline):\n    pipeline_name = \"Flux2KleinPipeline\"\n\n\nEntryClass = Flux2KleinPipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/glm_image.py",
    "content": "from sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import DenoisingStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.model_specific_stages.glm_image import (\n    GlmImageBeforeDenoisingStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass GlmImagePipeline(LoRAPipeline, ComposedPipelineBase):\n    pipeline_name = \"GlmImagePipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"tokenizer\",\n        \"vae\",\n        \"vision_language_encoder\",\n        \"processor\",\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        self.add_stage(\n            GlmImageBeforeDenoisingStage(\n                vae=self.get_module(\"vae\"),\n                text_encoder=self.get_module(\"text_encoder\"),\n                tokenizer=self.get_module(\"tokenizer\"),\n                processor=self.get_module(\"processor\"),\n                transformer=self.get_module(\"transformer\"),\n                scheduler=self.get_module(\"scheduler\"),\n                vision_language_encoder=self.get_module(\"vision_language_encoder\"),\n            ),\n            \"glm_image_before_denoising_stage\",\n        )\n\n        self.add_stage(\n            DenoisingStage(\n                transformer=self.get_module(\"transformer\"),\n                scheduler=self.get_module(\"scheduler\"),\n            ),\n        )\n\n        self.add_standard_decoding_stage()\n\n\nEntryClass = [GlmImagePipeline]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/helios_pipeline.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nHelios video diffusion pipeline implementation.\n\nThis module contains an implementation of the Helios video diffusion pipeline\nusing the modular pipeline architecture. Phase 1: T2V only.\n\"\"\"\n\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.lora_pipeline import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import (\n    InputValidationStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.model_specific_stages.helios_decoding import (\n    HeliosDecodingStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.model_specific_stages.helios_denoising import (\n    HeliosChunkedDenoisingStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass HeliosPipeline(LoRAPipeline, ComposedPipelineBase):\n    \"\"\"\n    Helios video diffusion pipeline with LoRA support.\n\n    Implements the Helios T2V pipeline with chunked denoising,\n    multi-term memory history, and CFG Zero Star guidance.\n    \"\"\"\n\n    pipeline_name = \"HeliosPipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"tokenizer\",\n        \"vae\",\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def initialize_pipeline(self, server_args: ServerArgs):\n        # Use the scheduler loaded from model's scheduler_config.json as-is.\n        # It contains critical config: use_dynamic_shifting=true,\n        # time_shift_type=\"exponential\", etc.\n        scheduler = self.modules.get(\"scheduler\")\n        if scheduler is not None and server_args.pipeline_config.flow_shift is not None:\n            scheduler.set_shift(server_args.pipeline_config.flow_shift)\n\n        # Configure scheduler for Stage 
2/3 if enabled\n        pipeline_config = server_args.pipeline_config\n        if scheduler is not None and pipeline_config.is_enable_stage2:\n            scheduler.config.stages = pipeline_config.pyramid_num_stages\n            scheduler.config.scheduler_type = pipeline_config.scheduler_type\n            scheduler.config.gamma = pipeline_config.gamma\n            scheduler.init_sigmas_for_each_stage()\n\n    def create_pipeline_stages(self, server_args: ServerArgs) -> None:\n        self.add_stage(InputValidationStage())\n        self.add_standard_text_encoding_stage()\n        self.add_standard_latent_preparation_stage()\n        # Skip standard timestep preparation — the Helios denoising stage\n        # handles scheduler.set_timesteps internally per-chunk with mu.\n        self.add_stage(\n            HeliosChunkedDenoisingStage(\n                transformer=self.get_module(\"transformer\"),\n                scheduler=self.modules[\"scheduler\"],\n            ),\n            \"helios_chunked_denoising_stage\",\n        )\n        # Helios-specific decoding: decode each chunk's latents separately\n        # to avoid temporal artifacts from Wan VAE causal convolutions\n        self.add_stage(\n            HeliosDecodingStage(vae=self.get_module(\"vae\"), pipeline=self),\n            \"helios_decoding_stage\",\n        )\n\n\nclass HeliosPyramidPipeline(HeliosPipeline):\n    \"\"\"Helios pyramid SR pipeline (used by Helios-Mid and Helios-Distilled).\"\"\"\n\n    pipeline_name = \"HeliosPyramidPipeline\"\n\n\nEntryClass = [HeliosPipeline, HeliosPyramidPipeline]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/hunyuan3d_pipeline.py",
    "content": "\"\"\"\nHunyuan3D image-to-mesh pipeline implementation.\n\nShape pipeline: BeforeDenoising -> Denoising -> Export -> Save\nPaint pipeline (optional): Preprocess -> TexGen -> Postprocess\n\"\"\"\n\nfrom __future__ import annotations\n\nimport glob\nimport importlib\nimport os\nfrom itertools import chain\nfrom typing import Any\n\nimport torch\nimport torch.nn as nn\n\nfrom sglang.multimodal_gen.configs.pipeline_configs.hunyuan3d import (\n    Hunyuan3D2PipelineConfig,\n)\nfrom sglang.multimodal_gen.runtime.loader.fsdp_load import (\n    load_model_from_full_model_state_dict,\n    set_default_torch_dtype,\n)\nfrom sglang.multimodal_gen.runtime.loader.utils import get_param_names_mapping\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import (\n    Hunyuan3DPaintPostprocessStage,\n    Hunyuan3DPaintPreprocessStage,\n    Hunyuan3DPaintTexGenStage,\n    Hunyuan3DShapeBeforeDenoisingStage,\n    Hunyuan3DShapeDenoisingStage,\n    Hunyuan3DShapeExportStage,\n    Hunyuan3DShapeSaveStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass Hunyuan3D2Pipeline(ComposedPipelineBase):\n    \"\"\"Hunyuan3D 2.0 image-to-mesh pipeline.\n\n    Shape pipeline: BeforeDenoising -> Denoising -> Export -> Save\n    Paint pipeline (optional): Preprocess -> TexGen -> Postprocess\n    \"\"\"\n\n    pipeline_name = \"Hunyuan3D2Pipeline\"\n    _required_config_modules = [\n        \"hy3dshape_model\",\n        \"hy3dshape_vae\",\n        \"hy3dshape_scheduler\",\n        \"hy3dshape_conditioner\",\n        \"hy3dshape_image_processor\",\n    ]\n\n    def _load_config(self) -> dict[str, Any]:\n        return {\n            \"_class_name\": self.pipeline_name,\n            \"_diffusers_version\": \"0.0.0\",\n            
\"hy3dshape_model\": [\"diffusers\", \"Hunyuan3DShapeModel\"],\n            \"hy3dshape_vae\": [\"diffusers\", \"Hunyuan3DShapeVAE\"],\n            \"hy3dshape_scheduler\": [\"diffusers\", \"Hunyuan3DShapeScheduler\"],\n            \"hy3dshape_conditioner\": [\"diffusers\", \"Hunyuan3DShapeConditioner\"],\n            \"hy3dshape_image_processor\": [\"diffusers\", \"Hunyuan3DShapeImageProcessor\"],\n        }\n\n    # Class resolution\n    @staticmethod\n    def _resolve_class(target: str) -> Any:\n        \"\"\"Resolve a YAML target string to a Python class.\"\"\"\n        from sglang.multimodal_gen.runtime.models.registry import ModelRegistry\n\n        cls = ModelRegistry.resolve_by_alias(target)\n        if cls is not None:\n            return cls\n\n        class_name = target.rsplit(\".\", 1)[-1]\n        try:\n            cls, _ = ModelRegistry.resolve_model_cls(class_name)\n            return cls\n        except Exception:\n            pass\n\n        from sglang.multimodal_gen.runtime.utils.mesh3d_utils import (\n            resolve_hunyuan3d_tool,\n        )\n\n        for name in (target, class_name):\n            tool_cls = resolve_hunyuan3d_tool(name)\n            if tool_cls is not None:\n                return tool_cls\n\n        module, cls_name = target.rsplit(\".\", 1)\n        return getattr(importlib.import_module(module, package=None), cls_name)\n\n    # Path / checkpoint resolution\n    @staticmethod\n    def _resolve_shape_dir(\n        model_path: str,\n        subfolder: str,\n        use_safetensors: bool,\n        variant: str | None,\n    ) -> tuple[str, str]:\n        \"\"\"Locate (or download) the shape subfolder and return (config_path, ckpt_path).\"\"\"\n        local_path = os.path.join(model_path, subfolder)\n        if not os.path.exists(local_path):\n            local_path = os.path.expanduser(local_path)\n\n        if not os.path.exists(local_path):\n            logger.info(\n                \"Local path %s not found, 
downloading from HuggingFace Hub\",\n                local_path,\n            )\n            from huggingface_hub import snapshot_download\n\n            downloaded = snapshot_download(\n                repo_id=model_path,\n                allow_patterns=[f\"{subfolder}/*\"],\n            )\n            local_path = os.path.join(downloaded, subfolder)\n\n        config_path = os.path.join(local_path, \"config.yaml\")\n        if not os.path.exists(config_path):\n            for alt in (\"config.yml\", \"model_config.yaml\"):\n                alt_path = os.path.join(local_path, alt)\n                if os.path.exists(alt_path):\n                    config_path = alt_path\n                    break\n\n        if use_safetensors:\n            ckpt_name = (\n                f\"model.{variant}.safetensors\" if variant else \"model.safetensors\"\n            )\n        else:\n            ckpt_name = f\"model-{variant}.ckpt\" if variant else \"model.ckpt\"\n\n        ckpt_path = os.path.join(local_path, ckpt_name)\n        if not os.path.exists(ckpt_path):\n            pattern = \"*.safetensors\" if use_safetensors else \"*.ckpt\"\n            files = glob.glob(os.path.join(local_path, pattern))\n            if files:\n                ckpt_path = files[0]\n\n        logger.info(\"Config path: %s\", config_path)\n        logger.info(\"Checkpoint path: %s\", ckpt_path)\n        return config_path, ckpt_path\n\n    @staticmethod\n    def _resolve_paint_dir(model_path: str, subfolder: str) -> str:\n        \"\"\"Locate (or download) the paint subfolder and return its local path.\"\"\"\n        local_path = os.path.join(model_path, subfolder)\n        if not os.path.exists(local_path):\n            local_path = os.path.expanduser(local_path)\n\n        if not os.path.exists(local_path):\n            logger.info(\n                \"Local path %s not found, downloading from HuggingFace Hub\",\n                local_path,\n            )\n            from huggingface_hub import 
snapshot_download\n\n            downloaded = snapshot_download(\n                repo_id=model_path,\n                allow_patterns=[f\"{subfolder}/*\"],\n            )\n            local_path = os.path.join(downloaded, subfolder)\n\n        for subdir in (\"vae\", \"unet\"):\n            config_file = os.path.join(local_path, subdir, \"config.json\")\n            if not os.path.exists(config_file):\n                raise FileNotFoundError(\n                    f\"Paint model incomplete: {config_file} not found. \"\n                    \"Download the model or check network connectivity.\"\n                )\n\n        logger.info(\"Resolved paint model directory: %s\", local_path)\n        return local_path\n\n    @staticmethod\n    def _load_and_split_checkpoint(\n        ckpt_path: str, use_safetensors: bool\n    ) -> dict[str, dict[str, torch.Tensor]]:\n        \"\"\"Load a bundled checkpoint and split by the first '.' in each key.\"\"\"\n        if use_safetensors:\n            import safetensors.torch\n\n            flat = safetensors.torch.load_file(ckpt_path, device=\"cpu\")\n            ckpt: dict[str, dict[str, torch.Tensor]] = {}\n            for key, value in flat.items():\n                component = key.split(\".\")[0]\n                sub_key = key[len(component) + 1 :]\n                ckpt.setdefault(component, {})[sub_key] = value\n            return ckpt\n        else:\n            return torch.load(ckpt_path, map_location=\"cpu\", weights_only=True)\n\n    # Component loading helpers\n    @classmethod\n    def _load_dit_model(\n        cls,\n        cfg: dict[str, Any],\n        weights: dict[str, torch.Tensor],\n        device: torch.device,\n        dtype: torch.dtype,\n    ) -> nn.Module:\n        \"\"\"Load the DiT model using meta-device instantiation + standard weight loading.\"\"\"\n        if \"target\" not in cfg:\n            raise KeyError(\"Expected key 'target' in model config.\")\n        target_cls = 
cls._resolve_class(cfg[\"target\"])\n        params = cfg.get(\"params\", {})\n\n        if hasattr(target_cls, \"build_config_from_params\"):\n            dit_config = target_cls.build_config_from_params(params)\n            init_kwargs: dict[str, Any] = {\"config\": dit_config, \"hf_config\": {}}\n        else:\n            init_kwargs = params\n\n        with set_default_torch_dtype(dtype), torch.device(\"meta\"):\n            model = target_cls(**init_kwargs)\n\n        weight_iterator = ((k, v) for k, v in weights.items())\n        param_names_mapping_fn = get_param_names_mapping(model.param_names_mapping)\n\n        load_model_from_full_model_state_dict(\n            model,\n            weight_iterator,\n            device,\n            dtype,\n            strict=False,\n            param_names_mapping=param_names_mapping_fn,\n        )\n\n        for name, p in chain(model.named_parameters(), model.named_buffers()):\n            if p.is_meta:\n                raise RuntimeError(f\"Unexpected param or buffer {name} on meta device.\")\n            if isinstance(p, nn.Parameter):\n                p.requires_grad = False\n\n        return model.eval()\n\n    @classmethod\n    def _load_simple_component(\n        cls,\n        cfg: dict[str, Any],\n        weights: dict[str, torch.Tensor] | None,\n        device: torch.device,\n        dtype: torch.dtype,\n    ) -> nn.Module:\n        \"\"\"Load a component (VAE / conditioner) with direct instantiation + state_dict.\"\"\"\n        if \"target\" not in cfg:\n            raise KeyError(\"Expected key 'target' in component config.\")\n        target_cls = cls._resolve_class(cfg[\"target\"])\n        params = cfg.get(\"params\", {})\n\n        with set_default_torch_dtype(dtype):\n            component = target_cls(**params)\n\n        if weights is not None:\n            component.load_state_dict(weights, strict=False)\n\n        component.to(device=device, dtype=dtype)\n        return component.eval()\n\n    
@classmethod\n    def _instantiate_component(cls, cfg: dict[str, Any]) -> Any:\n        \"\"\"Instantiate a lightweight component (scheduler / image_processor) without weights.\"\"\"\n        if \"target\" not in cfg:\n            raise KeyError(\"Expected key 'target' in component config.\")\n        target_cls = cls._resolve_class(cfg[\"target\"])\n        params = cfg.get(\"params\", {})\n        return target_cls(**params)\n\n    # Module loading override\n    def load_modules(\n        self,\n        server_args: ServerArgs,\n        loaded_modules: dict[str, torch.nn.Module] | None = None,\n    ) -> dict[str, Any]:\n        \"\"\"Load all Hunyuan3D shape components from a bundled checkpoint.\"\"\"\n        import yaml\n\n        from sglang.multimodal_gen.runtime.distributed import get_local_torch_device\n\n        config = server_args.pipeline_config\n        if not isinstance(config, Hunyuan3D2PipelineConfig):\n            raise TypeError(f\"Expected Hunyuan3D2PipelineConfig, got {type(config)}\")\n\n        model_path = config.shape_model_path or server_args.model_path\n\n        logger.info(\"Loading Hunyuan3D shape models from %s\", model_path)\n\n        config_path, ckpt_path = self._resolve_shape_dir(\n            model_path,\n            config.shape_subfolder,\n            config.shape_use_safetensors,\n            config.shape_variant,\n        )\n\n        with open(config_path, \"r\") as f:\n            model_config = yaml.safe_load(f)\n\n        ckpt = self._load_and_split_checkpoint(ckpt_path, config.shape_use_safetensors)\n\n        dtype = torch.float16\n        if config.shape_variant and \"bf16\" in config.shape_variant:\n            dtype = torch.bfloat16\n        device = get_local_torch_device()\n\n        components: dict[str, Any] = {}\n\n        components[\"hy3dshape_model\"] = self._load_dit_model(\n            model_config[\"model\"], ckpt[\"model\"], device, dtype\n        )\n\n        components[\"hy3dshape_vae\"] = 
self._load_simple_component(\n            model_config[\"vae\"], ckpt.get(\"vae\"), device, dtype\n        )\n\n        components[\"hy3dshape_conditioner\"] = self._load_simple_component(\n            model_config[\"conditioner\"], ckpt.get(\"conditioner\"), device, dtype\n        )\n\n        components[\"hy3dshape_scheduler\"] = self._instantiate_component(\n            model_config[\"scheduler\"]\n        )\n        components[\"hy3dshape_image_processor\"] = self._instantiate_component(\n            model_config[\"image_processor\"]\n        )\n\n        logger.info(\"All Hunyuan3D shape components loaded successfully\")\n\n        if config.paint_enable:\n            try:\n                paint_dir = self._resolve_paint_dir(\n                    server_args.model_path, config.paint_subfolder\n                )\n                components[\"hy3dpaint_dir\"] = paint_dir\n            except Exception as e:\n                logger.warning(\"Failed to resolve paint model path: %s\", e)\n\n        return components\n\n    # Pipeline lifecycle\n    def initialize_pipeline(self, server_args: ServerArgs):\n        config = server_args.pipeline_config\n        if not isinstance(config, Hunyuan3D2PipelineConfig):\n            raise TypeError(\n                \"Hunyuan3D2Pipeline requires Hunyuan3D2PipelineConfig, \"\n                f\"got {type(config)}\"\n            )\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        config = server_args.pipeline_config\n        assert isinstance(config, Hunyuan3D2PipelineConfig)\n\n        # Shape: 4 stages\n        self.add_stage(\n            stage_name=\"shape_before_denoising\",\n            stage=Hunyuan3DShapeBeforeDenoisingStage(\n                image_processor=self.get_module(\"hy3dshape_image_processor\"),\n                conditioner=self.get_module(\"hy3dshape_conditioner\"),\n                vae=self.get_module(\"hy3dshape_vae\"),\n                model=self.get_module(\"hy3dshape_model\"),\n   
             scheduler=self.get_module(\"hy3dshape_scheduler\"),\n                config=config,\n            ),\n        )\n        self.add_stage(\n            stage_name=\"shape_denoising\",\n            stage=Hunyuan3DShapeDenoisingStage(\n                transformer=self.get_module(\"hy3dshape_model\"),\n                scheduler=self.get_module(\"hy3dshape_scheduler\"),\n            ),\n        )\n        self.add_stage(\n            stage_name=\"shape_export\",\n            stage=Hunyuan3DShapeExportStage(\n                vae=self.get_module(\"hy3dshape_vae\"),\n                config=config,\n            ),\n        )\n        self.add_stage(\n            stage_name=\"shape_save\",\n            stage=Hunyuan3DShapeSaveStage(config=config),\n        )\n\n        # Paint: 3 stages (optional)\n        if config.paint_enable:\n            self.add_stage(\n                stage_name=\"paint_preprocess\",\n                stage=Hunyuan3DPaintPreprocessStage(config=config),\n            )\n            self.add_stage(\n                stage_name=\"paint_texgen\",\n                stage=Hunyuan3DPaintTexGenStage(\n                    config=config,\n                    paint_dir=self.get_module(\"hy3dpaint_dir\"),\n                ),\n            )\n            self.add_stage(\n                stage_name=\"paint_postprocess\",\n                stage=Hunyuan3DPaintPostprocessStage(config=config),\n            )\n\n\nEntryClass = Hunyuan3D2Pipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/hunyuan_pipeline.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nHunyuan video diffusion pipeline implementation.\n\nThis module contains an implementation of the Hunyuan video diffusion pipeline\nusing the modular pipeline architecture.\n\"\"\"\n\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import (\n    InputValidationStage,\n    TextEncodingStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\n# TODO(will): move PRECISION_TO_TYPE to better place\n\nlogger = init_logger(__name__)\n\n\nclass HunyuanVideoPipeline(ComposedPipelineBase):\n\n    pipeline_name = \"HunyuanVideoPipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"text_encoder_2\",\n        \"tokenizer\",\n        \"tokenizer_2\",\n        \"vae\",\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        self.add_stage(InputValidationStage())\n        self.add_stage(\n            TextEncodingStage(\n                text_encoders=[\n                    self.get_module(\"text_encoder\"),\n                    self.get_module(\"text_encoder_2\"),\n                ],\n                tokenizers=[\n                    self.get_module(\"tokenizer\"),\n                    self.get_module(\"tokenizer_2\"),\n                ],\n            ),\n            \"prompt_encoding_stage_primary\",\n        )\n        self.add_standard_timestep_preparation_stage()\n        self.add_standard_latent_preparation_stage()\n        self.add_standard_denoising_stage()\n        self.add_standard_decoding_stage()\n\n\nEntryClass = HunyuanVideoPipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py",
    "content": "import inspect\nimport json\nimport math\nimport os\n\nimport numpy as np\nimport torch\nfrom diffusers import FlowMatchEulerDiscreteScheduler\n\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import (\n    InputValidationStage,\n    LTX2AVDecodingStage,\n    LTX2AVDenoisingStage,\n    LTX2AVLatentPreparationStage,\n    LTX2TextConnectorStage,\n    TextEncodingStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    max_shift: float = 1.15,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\ndef prepare_mu(batch: Req, server_args: ServerArgs):\n    # The shift uses a fixed sequence length of 4096, so mu does not depend on\n    # the request resolution or frame count; `batch` and `server_args` are kept\n    # only to match the prepare_extra_kwargs callback signature.\n    # Values from LTX2Pipeline in diffusers\n    mu = calculate_shift(\n        4096,\n        base_seq_len=1024,\n        max_seq_len=4096,\n        base_shift=0.95,\n        max_shift=2.05,\n    )\n    return \"mu\", 
mu\n\n\ndef _load_component_config(model_path: str, component_name: str):\n    \"\"\"Helper to load component config from model_index.json or config.json\"\"\"\n    try:\n        # Try loading model_index.json first\n        index_path = os.path.join(model_path, \"model_index.json\")\n        if os.path.exists(index_path):\n            with open(index_path, \"r\") as f:\n                index = json.load(f)\n\n            if component_name in index:\n                # It's a subfolder\n                subfolder = index[component_name][1]\n                config_path = os.path.join(model_path, subfolder, \"config.json\")\n                if os.path.exists(config_path):\n                    with open(config_path, \"r\") as f:\n                        return json.load(f)\n\n        # Fallback to direct config.json in subfolder if standard structure\n        config_path = os.path.join(model_path, component_name, \"config.json\")\n        if os.path.exists(config_path):\n            with open(config_path, \"r\") as f:\n                return json.load(f)\n\n    except Exception as e:\n        logger.warning(f\"Failed to load config for {component_name}: {e}\")\n\n    return {}\n\n\ndef _filter_kwargs_for_cls(cls, kwargs):\n    \"\"\"Filter kwargs to only include those accepted by cls.__init__\"\"\"\n    sig = inspect.signature(cls.__init__)\n    return {k: v for k, v in kwargs.items() if k in sig.parameters}\n\n\nclass LTX2FlowMatchScheduler(FlowMatchEulerDiscreteScheduler):\n    \"\"\"Override ``_time_shift_exponential`` to use torch f32 instead of numpy f64.\"\"\"\n\n    def _time_shift_exponential(self, mu, sigma, t):\n        if isinstance(t, np.ndarray):\n            t_torch = torch.from_numpy(t).to(torch.float32)\n            result = math.exp(mu) / (math.exp(mu) + (1 / t_torch - 1) ** sigma)\n            return result.numpy()\n        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)\n\n\nclass LTX2Pipeline(ComposedPipelineBase):\n    # NOTE: must match 
`model_index.json`'s `_class_name` for native dispatch.\n    pipeline_name = \"LTX2Pipeline\"\n\n    _required_config_modules = [\n        \"transformer\",\n        \"text_encoder\",\n        \"tokenizer\",\n        \"scheduler\",\n        \"vae\",\n        \"audio_vae\",\n        \"vocoder\",\n        \"connectors\",\n    ]\n\n    def initialize_pipeline(self, server_args: ServerArgs):\n        orig = self.get_module(\"scheduler\")\n        self.modules[\"scheduler\"] = LTX2FlowMatchScheduler.from_config(orig.config)\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        self.add_stages(\n            [\n                InputValidationStage(),\n                TextEncodingStage(\n                    text_encoders=[self.get_module(\"text_encoder\")],\n                    tokenizers=[self.get_module(\"tokenizer\")],\n                ),\n                LTX2TextConnectorStage(connectors=self.get_module(\"connectors\")),\n            ]\n        )\n\n        self.add_standard_timestep_preparation_stage(prepare_extra_kwargs=[prepare_mu])\n\n        self.add_stages(\n            [\n                LTX2AVLatentPreparationStage(\n                    scheduler=self.get_module(\"scheduler\"),\n                    transformer=self.get_module(\"transformer\"),\n                    audio_vae=self.get_module(\"audio_vae\"),\n                ),\n                LTX2AVDenoisingStage(\n                    transformer=self.get_module(\"transformer\"),\n                    scheduler=self.get_module(\"scheduler\"),\n                    vae=self.get_module(\"vae\"),\n                    audio_vae=self.get_module(\"audio_vae\"),\n                ),\n                LTX2AVDecodingStage(\n                    vae=self.get_module(\"vae\"),\n                    audio_vae=self.get_module(\"audio_vae\"),\n                    vocoder=self.get_module(\"vocoder\"),\n                    pipeline=self,\n                ),\n            ]\n        )\n\n\nEntryClass = LTX2Pipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/mova_pipeline.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nMOVA pipeline integration (native SGLang pipeline).\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom sglang.multimodal_gen.configs.pipeline_configs.mova import MOVAPipelineConfig\nfrom sglang.multimodal_gen.configs.sample.mova import MOVASamplingParams\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import (\n    ImageVAEEncodingStage,\n    InputValidationStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.model_specific_stages.mova import (\n    MOVADecodingStage,\n    MOVADenoisingStage,\n    MOVALatentPreparationStage,\n    MOVATimestepPreparationStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass MOVAPipeline(ComposedPipelineBase):\n    \"\"\"MOVA pipeline with SGLang stage orchestration.\"\"\"\n\n    pipeline_name = \"MOVA\"\n    is_video_pipeline = True\n    _required_config_modules = [\n        \"video_vae\",\n        \"audio_vae\",\n        \"text_encoder\",\n        \"tokenizer\",\n        \"scheduler\",\n        \"video_dit\",\n        \"video_dit_2\",\n        \"audio_dit\",\n        \"dual_tower_bridge\",\n    ]\n    pipeline_config_cls = MOVAPipelineConfig\n    sampling_params_cls = MOVASamplingParams\n\n    def initialize_pipeline(self, server_args: ServerArgs) -> None:\n        \"\"\"\n        Initialize the pipeline.\n\n        MOVA supports Context Parallel (sequence parallel) through USPAttention,\n        which uses Ulysses-style all-to-all communication for distributed attention.\n        \"\"\"\n        if server_args.sp_degree > 1:\n            logger.info(\n                \"MOVA Context Parallel enabled with sp_degree=%d. 
\"\n                \"Using USPAttention for distributed self-attention.\",\n                server_args.sp_degree,\n            )\n\n    def create_pipeline_stages(self, server_args: ServerArgs) -> None:\n        self.add_stage(InputValidationStage())\n        self.add_standard_text_encoding_stage()\n        if getattr(self.get_module(\"video_dit\"), \"require_vae_embedding\", True):\n            self.add_stage(ImageVAEEncodingStage(vae=self.get_module(\"video_vae\")))\n        self.add_stage(\n            MOVALatentPreparationStage(\n                audio_vae=self.get_module(\"audio_vae\"),\n                require_vae_embedding=getattr(\n                    self.get_module(\"video_dit\"), \"require_vae_embedding\", True\n                ),\n            ),\n            \"mova_latent_preparation_stage\",\n        )\n        self.add_stage(\n            MOVATimestepPreparationStage(\n                scheduler=self.get_module(\"scheduler\"),\n            ),\n            \"mova_timestep_preparation_stage\",\n        )\n        self.add_stage(\n            MOVADenoisingStage(\n                video_dit=self.get_module(\"video_dit\"),\n                video_dit_2=self.get_module(\"video_dit_2\"),\n                audio_dit=self.get_module(\"audio_dit\"),\n                dual_tower_bridge=self.get_module(\"dual_tower_bridge\"),\n                scheduler=self.get_module(\"scheduler\"),\n            ),\n            \"mova_denoising_stage\",\n        )\n        self.add_stage(\n            MOVADecodingStage(\n                video_vae=self.get_module(\"video_vae\"),\n                audio_vae=self.get_module(\"audio_vae\"),\n            ),\n            \"mova_decoding_stage\",\n        )\n\n\nclass MOVAPipelineAlias(MOVAPipeline):\n    pipeline_name = \"MOVAPipeline\"\n\n\nEntryClass = [MOVAPipeline, MOVAPipelineAlias]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/qwen_image.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nfrom diffusers.image_processor import VaeImageProcessor\n\nfrom sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.model_specific_stages.qwen_image_layered import (\n    QwenImageLayeredBeforeDenoisingStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\n# TODO(will): move PRECISION_TO_TYPE to better place\n\nlogger = init_logger(__name__)\n\n\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    max_shift: float = 1.15,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\ndef prepare_mu(batch: Req, server_args: ServerArgs):\n    height = batch.height\n    width = batch.width\n    vae_scale_factor = server_args.pipeline_config.vae_config.vae_scale_factor\n    image_seq_len = (int(height) // vae_scale_factor // 2) * (\n        int(width) // vae_scale_factor // 2\n    )\n    mu = calculate_shift(\n        image_seq_len,\n        # hard code, since scheduler_config is not in PipelineConfig now\n        256,\n        8192,\n        0.5,\n        0.9,\n    )\n    return \"mu\", mu\n\n\nclass QwenImagePipeline(LoRAPipeline, ComposedPipelineBase):\n    pipeline_name = \"QwenImagePipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"tokenizer\",\n        \"vae\",\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def create_pipeline_stages(self, server_args: 
ServerArgs):\n        self.add_standard_t2i_stages(prepare_extra_timestep_kwargs=[prepare_mu])\n\n\nclass QwenImageEditPipeline(LoRAPipeline, ComposedPipelineBase):\n    pipeline_name = \"QwenImageEditPipeline\"\n\n    _required_config_modules = [\n        \"processor\",\n        \"scheduler\",\n        \"text_encoder\",\n        \"tokenizer\",\n        \"transformer\",\n        \"vae\",\n    ]\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        vae_image_processor = VaeImageProcessor(\n            vae_scale_factor=server_args.pipeline_config.vae_config.arch_config.vae_scale_factor\n            * 2\n        )\n\n        self.add_standard_ti2i_stages(\n            vae_image_processor=vae_image_processor,\n            prompt_encoding=\"image_encoding\",\n            image_processor_key=\"processor\",\n            prompt_text_encoder_key=\"text_encoder\",\n            prepare_extra_timestep_kwargs=[prepare_mu],\n        )\n\n\nclass QwenImageEditPlusPipeline(QwenImageEditPipeline):\n    pipeline_name = \"QwenImageEditPlusPipeline\"\n\n\ndef prepare_mu_layered(batch: Req, server_args: ServerArgs):\n    base_seqlen = 256 * 256 / 16 / 16\n    mu = (batch.image_latent.shape[1] / base_seqlen) ** 0.5\n    return \"mu\", mu\n\n\nclass QwenImageLayeredPipeline(QwenImageEditPipeline):\n    pipeline_name = \"QwenImageLayeredPipeline\"\n\n    _required_config_modules = [\n        \"vae\",\n        \"tokenizer\",\n        \"processor\",\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        self.add_stage(\n            QwenImageLayeredBeforeDenoisingStage(\n                vae=self.get_module(\"vae\"),\n                tokenizer=self.get_module(\"tokenizer\"),\n                processor=self.get_module(\"processor\"),\n                transformer=self.get_module(\"transformer\"),\n                scheduler=self.get_module(\"scheduler\"),\n                
model_path=self.model_path,\n            )\n        )\n\n        self.add_standard_timestep_preparation_stage(\n            prepare_extra_kwargs=[prepare_mu_layered]\n        )\n        self.add_standard_denoising_stage()\n        self.add_standard_decoding_stage()\n\n\nEntryClass = [\n    QwenImagePipeline,\n    QwenImageEditPipeline,\n    QwenImageEditPlusPipeline,\n    QwenImageLayeredPipeline,\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/sana.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n#\n# SANA text-to-image pipeline.\n#\n# Stage order matches Flux (InputValidation -> TextEncoding -> TimestepPrep ->\n# LatentPrep -> Denoising -> Decoding) rather than the add_standard_t2i_stages\n# helper (which puts LatentPrep before TimestepPrep). Both orderings are\n# functionally equivalent since these stages are independent.\n#\n# SANA uses a single text encoder (Gemma2), so only one text_encoder + tokenizer\n# pair is registered — unlike Flux which has text_encoder + text_encoder_2.\n# The pipeline_name must match the _class_name in HF model_index.json.\n\nfrom sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import (\n    InputValidationStage,\n    TextEncodingStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass SanaPipeline(LoRAPipeline, ComposedPipelineBase):\n    pipeline_name = \"SanaPipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"tokenizer\",\n        \"vae\",\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        self.add_stage(InputValidationStage())\n\n        self.add_stage(\n            TextEncodingStage(\n                text_encoders=[self.get_module(\"text_encoder\")],\n                tokenizers=[self.get_module(\"tokenizer\")],\n            ),\n            \"prompt_encoding_stage_primary\",\n        )\n\n        self.add_standard_timestep_preparation_stage()\n        self.add_standard_latent_preparation_stage()\n        self.add_standard_denoising_stage()\n        self.add_standard_decoding_stage()\n\n\nEntryClass = SanaPipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/wan_causal_dmd_pipeline.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nWan causal DMD pipeline implementation.\n\nThis module wires the causal DMD denoising stage into the modular pipeline.\n\"\"\"\n\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.lora_pipeline import LoRAPipeline\n\n# isort: off\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import (\n    CausalDMDDenoisingStage,\n    InputValidationStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\n# isort: on\n\nlogger = init_logger(__name__)\n\n\nclass WanCausalDMDPipeline(LoRAPipeline, ComposedPipelineBase):\n    pipeline_name = \"WanCausalDMDPipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"tokenizer\",\n        \"vae\",\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def create_pipeline_stages(self, server_args: ServerArgs) -> None:\n        self.add_stage(InputValidationStage())\n        self.add_standard_text_encoding_stage()\n        self.add_standard_latent_preparation_stage()\n\n        self.add_stage(\n            CausalDMDDenoisingStage(\n                transformer=self.get_module(\"transformer\"),\n                scheduler=self.get_module(\"scheduler\"),\n            ),\n        )\n\n        self.add_standard_decoding_stage()\n\n\nEntryClass = WanCausalDMDPipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/wan_dmd_pipeline.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nWan video diffusion pipeline implementation.\n\nThis module contains an implementation of the Wan video diffusion pipeline\nusing the modular pipeline architecture.\n\"\"\"\n\nfrom sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_match_euler_discrete import (\n    FlowMatchEulerDiscreteScheduler,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.lora_pipeline import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\n# isort: off\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import (\n    DmdDenoisingStage,\n    InputValidationStage,\n)\n\n# isort: on\n\nlogger = init_logger(__name__)\n\n\nclass WanDMDPipeline(LoRAPipeline, ComposedPipelineBase):\n    \"\"\"\n    Wan video diffusion pipeline with LoRA support.\n    \"\"\"\n\n    pipeline_name = \"WanDMDPipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"tokenizer\",\n        \"vae\",\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def initialize_pipeline(self, server_args: ServerArgs):\n\n        self.modules[\"scheduler\"] = FlowMatchEulerDiscreteScheduler(\n            shift=server_args.pipeline_config.flow_shift\n        )\n\n    def create_pipeline_stages(self, server_args: ServerArgs) -> None:\n        self.add_stages(\n            [\n                InputValidationStage(),\n            ]\n        )\n\n        self.add_standard_text_encoding_stage()\n\n        self.add_standard_timestep_preparation_stage()\n        self.add_standard_latent_preparation_stage()\n\n        self.add_stages(\n            [\n                DmdDenoisingStage(\n                    
transformer=self.get_module(\"transformer\"),\n                    scheduler=self.get_module(\"scheduler\"),\n                ),\n            ]\n        )\n\n        self.add_standard_decoding_stage()\n\n\nEntryClass = WanDMDPipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/wan_i2v_dmd_pipeline.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nWan video diffusion pipeline implementation.\n\nThis module contains an implementation of the Wan video diffusion pipeline\nusing the modular pipeline architecture.\n\"\"\"\n\nfrom sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_match_euler_discrete import (\n    FlowMatchEulerDiscreteScheduler,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.lora_pipeline import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import DmdDenoisingStage\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass WanImageToVideoDmdPipeline(LoRAPipeline, ComposedPipelineBase):\n    pipeline_name = \"WanImageToVideoDmdPipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"tokenizer\",\n        \"vae\",\n        \"transformer\",\n        \"scheduler\",\n        \"image_encoder\",\n        \"image_processor\",\n    ]\n\n    def initialize_pipeline(self, server_args: ServerArgs):\n        self.modules[\"scheduler\"] = FlowMatchEulerDiscreteScheduler(\n            shift=server_args.pipeline_config.flow_shift\n        )\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        self.add_standard_ti2v_stages(\n            image_vae_encoding_position=\"after_latent\",\n            denoising_stage_factory=lambda: DmdDenoisingStage(\n                transformer=self.get_module(\"transformer\"),\n                scheduler=self.get_module(\"scheduler\"),\n                transformer_2=self.get_module(\"transformer_2\"),\n            ),\n        )\n\n\nEntryClass = WanImageToVideoDmdPipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/wan_i2v_pipeline.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nWan video diffusion pipeline implementation.\n\nThis module contains an implementation of the Wan video diffusion pipeline\nusing the modular pipeline architecture.\n\"\"\"\n\nfrom sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_unipc_multistep import (\n    FlowUniPCMultistepScheduler,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.lora_pipeline import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass WanImageToVideoPipeline(LoRAPipeline, ComposedPipelineBase):\n    pipeline_name = \"WanImageToVideoPipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"tokenizer\",\n        \"vae\",\n        \"transformer\",\n        \"scheduler\",\n        \"image_encoder\",\n        \"image_processor\",\n    ]\n\n    def initialize_pipeline(self, server_args: ServerArgs):\n        self.modules[\"scheduler\"] = FlowUniPCMultistepScheduler(\n            shift=server_args.pipeline_config.flow_shift\n        )\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        self.add_standard_ti2v_stages()\n\n\nEntryClass = WanImageToVideoPipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/wan_pipeline.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nWan video diffusion pipeline implementation.\n\nThis module contains an implementation of the Wan video diffusion pipeline\nusing the modular pipeline architecture.\n\"\"\"\n\nfrom sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_unipc_multistep import (\n    FlowUniPCMultistepScheduler,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.lora_pipeline import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass WanPipeline(LoRAPipeline, ComposedPipelineBase):\n    \"\"\"\n    Wan video diffusion pipeline with LoRA support.\n    \"\"\"\n\n    pipeline_name = \"WanPipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"tokenizer\",\n        \"vae\",\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def initialize_pipeline(self, server_args: ServerArgs):\n        # We use UniPCMScheduler from Wan2.1 official repo, not the one in diffusers.\n        self.modules[\"scheduler\"] = FlowUniPCMultistepScheduler(\n            shift=server_args.pipeline_config.flow_shift\n        )\n\n    def create_pipeline_stages(self, server_args: ServerArgs) -> None:\n        self.add_standard_t2i_stages()\n\n\nEntryClass = WanPipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines/zimage_pipeline.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n# SPDX-License-Identifier: Apache-2.0\n\n\nfrom sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline, Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    max_shift: float = 1.15,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\ndef prepare_mu(batch: Req, server_args: ServerArgs):\n    height = batch.height\n    width = batch.width\n    vae_scale_factor = server_args.pipeline_config.vae_config.vae_scale_factor\n    image_seq_len = ((int(height) // vae_scale_factor) // 2) * (\n        (int(width) // vae_scale_factor) // 2\n    )\n    mu = calculate_shift(\n        image_seq_len,\n        # hard code, since scheduler_config is not in PipelineConfig now\n        256,\n        4096,\n        0.5,\n        1.15,\n    )\n    return \"mu\", mu\n\n\nclass ZImagePipeline(LoRAPipeline, ComposedPipelineBase):\n    pipeline_name = \"ZImagePipeline\"\n\n    _required_config_modules = [\n        \"text_encoder\",\n        \"tokenizer\",\n        \"vae\",\n        \"transformer\",\n        \"scheduler\",\n    ]\n\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        self.add_standard_t2i_stages(prepare_extra_timestep_kwargs=[prepare_mu])\n\n\nEntryClass = ZImagePipeline\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nDiffusion pipelines for sglang.multimodal_gen.\n\nThis package contains diffusion pipelines for generating videos and images.\n\"\"\"\n\nfrom typing import cast\n\nfrom sglang.multimodal_gen.registry import get_model_info\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.lora_pipeline import LoRAPipeline\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (\n    maybe_download_model,\n    verify_model_config_and_directory,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass PipelineWithLoRA(LoRAPipeline, ComposedPipelineBase):\n    \"\"\"Type for a pipeline that has both ComposedPipelineBase and LoRAPipeline functionality.\"\"\"\n\n    pass\n\n\ndef build_pipeline(\n    server_args: ServerArgs,\n) -> PipelineWithLoRA:\n    \"\"\"\n    Only works with valid hf diffusers configs. (model_index.json)\n    We want to build a pipeline based on the inference args mode_path:\n    1. download the model from the hub if it's not already downloaded\n    2. verify the model config and directory\n    3. 
based on the config, determine the pipeline class\n    \"\"\"\n    model_path = server_args.model_path\n\n    # Check if pipeline class is explicitly specified\n    if server_args.pipeline_class_name:\n        from sglang.multimodal_gen.registry import (\n            _PIPELINE_REGISTRY,\n            _discover_and_register_pipelines,\n        )\n\n        _discover_and_register_pipelines()\n        logger.info(f\"Requested pipeline_class_name: {server_args.pipeline_class_name}\")\n        logger.info(\n            f\"Available pipelines in registry: {list(_PIPELINE_REGISTRY.keys())}\"\n        )\n        pipeline_cls = _PIPELINE_REGISTRY.get(server_args.pipeline_class_name)\n        if pipeline_cls is None:\n            raise ValueError(\n                f\"Pipeline class '{server_args.pipeline_class_name}' not found in registry. \"\n                f\"Available pipelines: {list(_PIPELINE_REGISTRY.keys())}\"\n            )\n        logger.info(\n            f\"✓ Using explicitly specified pipeline: {server_args.pipeline_class_name} (class: {pipeline_cls.__name__})\"\n        )\n    else:\n        logger.info(\"No pipeline_class_name specified, using model_index.json\")\n        model_info = get_model_info(\n            model_path,\n            backend=server_args.backend,\n            model_id=server_args.model_id,\n        )\n        pipeline_cls = model_info.pipeline_cls\n        logger.info(f\"Using pipeline from model_index.json: {pipeline_cls.__name__}\")\n\n    # instantiate the pipelines\n    pipeline = pipeline_cls(model_path, server_args)\n\n    logger.info(\"Pipeline instantiated\")\n\n    return cast(PipelineWithLoRA, pipeline)\n\n\n__all__ = [\n    \"build_pipeline\",\n    \"ComposedPipelineBase\",\n    \"Req\",\n    \"LoRAPipeline\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nBase class for composed pipelines.\n\nThis module defines the base class for pipelines that are composed of multiple stages.\n\"\"\"\n\nimport os\nimport re\nfrom abc import ABC, abstractmethod\nfrom typing import Any, Callable, Literal, cast\n\nimport torch\nfrom tqdm import tqdm\n\nfrom sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import (\n    PipelineComponentLoader,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.executors.pipeline_executor import (\n    PipelineExecutor,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import (\n    DecodingStage,\n    DenoisingStage,\n    ImageEncodingStage,\n    ImageVAEEncodingStage,\n    InputValidationStage,\n    LatentPreparationStage,\n    PipelineStage,\n    TextEncodingStage,\n    TimestepPreparationStage,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (\n    maybe_download_model,\n    verify_model_config_and_directory,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass ComposedPipelineBase(ABC):\n    \"\"\"\n    Base class for pipelines composed of multiple stages.\n\n    This class provides the framework for creating pipelines by composing multiple\n    stages together. 
Each stage is responsible for a specific part of the diffusion\n    process, and the pipeline orchestrates the execution of these stages.\n    \"\"\"\n\n    is_video_pipeline: bool = False  # To be overridden by video pipelines\n    # should contain only the modules to be loaded\n    _required_config_modules: list[str] = []\n    _extra_config_module_map: dict[str, str] = {}\n    server_args: ServerArgs | None = None\n    modules: dict[str, Any] = {}\n    executor: PipelineExecutor | None = None\n\n    # the name of the pipeline it is associated with in diffusers\n    pipeline_name: str\n\n    def is_lora_effective(self):\n        return False\n\n    def is_lora_set(self):\n        return False\n\n    def __init__(\n        self,\n        model_path: str,\n        server_args: ServerArgs,\n        required_config_modules: list[str] | None = None,\n        loaded_modules: dict[str, torch.nn.Module] | None = None,\n        executor: PipelineExecutor | None = None,\n    ):\n        \"\"\"\n        Initialize the pipeline. After __init__, the pipeline should be ready to\n        use. 
The pipeline should be stateless and not hold any batch state.\n        \"\"\"\n        self.server_args = server_args\n\n        self.model_path: str = model_path\n        self._stages: list[PipelineStage] = []\n        self._stage_name_mapping: dict[str, PipelineStage] = {}\n        self.executor = executor or self.build_executor(server_args=server_args)\n\n        if required_config_modules is not None:\n            self._required_config_modules = required_config_modules\n\n        if self._required_config_modules is None:\n            raise NotImplementedError(\"Subclass must set _required_config_modules\")\n\n        # [module_name, gpu memory usage]\n        self.memory_usages: dict[str, float] = {}\n        # Load modules directly in initialization\n        logger.info(\"Loading pipeline modules...\")\n        self.modules = self.load_modules(server_args, loaded_modules)\n\n        self.__post_init__()\n\n    def build_executor(self, server_args: ServerArgs):\n        # TODO\n        from sglang.multimodal_gen.runtime.pipelines_core.executors.parallel_executor import (\n            ParallelExecutor,\n        )\n\n        # return SyncExecutor(server_args=server_args)\n        return ParallelExecutor(server_args=server_args)\n\n    def __post_init__(self) -> None:\n        assert self.server_args is not None, \"server_args must be set\"\n        self.initialize_pipeline(self.server_args)\n\n        logger.info(\"Creating pipeline stages...\")\n        self.create_pipeline_stages(self.server_args)\n\n    def get_module(self, module_name: str, default_value: Any = None) -> Any:\n        return self.modules.get(module_name, default_value)\n\n    def add_module(self, module_name: str, module: Any):\n        self.modules[module_name] = module\n\n    def _load_config(self) -> dict[str, Any]:\n        model_path = maybe_download_model(self.model_path, force_diffusers_model=True)\n        self.model_path = model_path\n        logger.info(\"Model path: %s\", 
model_path)\n        config = verify_model_config_and_directory(model_path)\n        return cast(dict[str, Any], config)\n\n    @property\n    def required_config_modules(self) -> list[str]:\n        \"\"\"\n        List of modules that are required by the pipeline. The names should match\n        the diffusers directory and model_index.json file. These modules will be\n        loaded using the PipelineComponentLoader and made available in the\n        modules dictionary. Access these modules using the get_module method.\n\n        class ConcretePipeline(ComposedPipelineBase):\n            _required_config_modules = [\"vae\", \"text_encoder\", \"transformer\", \"scheduler\", \"tokenizer\"]\n\n\n            @property\n            def required_config_modules(self):\n                return self._required_config_modules\n        \"\"\"\n        return self._required_config_modules\n\n    @property\n    def stages(self) -> list[PipelineStage]:\n        \"\"\"\n        List of stages in the pipeline.\n        \"\"\"\n        return self._stages\n\n    @abstractmethod\n    def create_pipeline_stages(self, server_args: ServerArgs):\n        \"\"\"\n        Create the inference pipeline stages.\n        \"\"\"\n        raise NotImplementedError\n\n    def initialize_pipeline(self, server_args: ServerArgs):\n        \"\"\"\n        Initialize the pipeline.\n        \"\"\"\n        return\n\n    def _resolve_component_path(\n        self, server_args: ServerArgs, module_name: str, load_module_name: str\n    ) -> str:\n        override_path = server_args.component_paths.get(module_name)\n        if override_path is not None:\n            # overridden with args like --vae-path\n            component_model_path = maybe_download_model(override_path)\n        else:\n            component_model_path = os.path.join(self.model_path, load_module_name)\n\n        logger.debug(\"Resolved component path: %s\", component_model_path)\n        return component_model_path\n\n    def 
load_modules(\n        self,\n        server_args: ServerArgs,\n        loaded_modules: dict[str, torch.nn.Module] | None = None,\n    ) -> dict[str, Any]:\n        \"\"\"\n        Load the modules from the config.\n        loaded_modules: Optional[Dict[str, torch.nn.Module]] = None,\n        If provided, loaded_modules will be used instead of loading from config/pretrained weights.\n        \"\"\"\n\n        model_index = self._load_config()\n        logger.info(\"Loading pipeline modules from config: %s\", model_index)\n\n        # remove keys that are not pipeline modules\n        model_index.pop(\"_class_name\")\n        model_index.pop(\"_diffusers_version\")\n        if (\n            \"boundary_ratio\" in model_index\n            and model_index[\"boundary_ratio\"] is not None\n        ):\n            has_transformer = (\n                \"transformer\" in model_index\n                or \"transformer_2\" in model_index\n                or \"transformer\" in self.required_config_modules\n                or \"transformer_2\" in self.required_config_modules\n            )\n            if has_transformer:\n                logger.info(\n                    \"MoE pipeline detected. 
Adding transformer_2 to self.required_config_modules...\"\n                )\n                if \"transformer_2\" not in self.required_config_modules:\n                    self.required_config_modules.append(\"transformer_2\")\n            else:\n                logger.info(\n                    \"Boundary ratio found in model_index.json without transformers; \"\n                    \"using it for pipeline config only.\"\n                )\n            logger.info(\n                \"Setting boundary ratio to %s\",\n                model_index[\"boundary_ratio\"],\n            )\n            server_args.pipeline_config.dit_config.boundary_ratio = model_index[\n                \"boundary_ratio\"\n            ]\n\n        model_index.pop(\"boundary_ratio\", None)\n        # used by Wan2.2 ti2v\n        model_index.pop(\"expand_timesteps\", None)\n\n        # some sanity checks\n        assert (\n            len(model_index) > 1\n        ), \"model_index.json must contain at least one pipeline module\"\n\n        model_index = {\n            required_module: model_index[required_module]\n            for required_module in self.required_config_modules\n        }\n\n        for module_name in self.required_config_modules:\n            if (\n                module_name not in model_index\n                and module_name in self._extra_config_module_map\n            ):\n                extra_module_value = self._extra_config_module_map[module_name]\n                logger.warning(\n                    \"model_index.json does not contain a %s module, but found {%s: %s} in _extra_config_module_map, adding to model_index.\",\n                    module_name,\n                    module_name,\n                    extra_module_value,\n                )\n                if extra_module_value in model_index:\n                    logger.info(\n                        \"Using module %s for %s\", extra_module_value, module_name\n                    )\n                    
model_index[module_name] = model_index[extra_module_value]\n                    continue\n                else:\n                    raise ValueError(\n                        f\"Required module key: {module_name} value: {model_index.get(module_name)} was not found in loaded modules {model_index.keys()}\"\n                    )\n\n        # all the component models used by the pipeline\n        required_modules = self.required_config_modules\n        logger.info(\"Loading required components: %s\", required_modules)\n\n        loaded_components = {}\n        for module_name, (\n            transformers_or_diffusers,\n            architecture,\n        ) in tqdm(iterable=model_index.items(), desc=\"Loading required modules\"):\n            if transformers_or_diffusers is None:\n                logger.warning(\n                    \"Module %s in model_index.json has null value, removing from required_config_modules\",\n                    module_name,\n                )\n                if module_name in self.required_config_modules:\n                    self.required_config_modules.remove(module_name)\n                continue\n            if module_name not in required_modules:\n                logger.info(\"Skipping module %s\", module_name)\n                continue\n            if loaded_modules is not None and module_name in loaded_modules:\n                logger.info(\"Using module %s already provided\", module_name)\n                loaded_components[module_name] = loaded_modules[module_name]\n                continue\n\n            # we load the module from the extra config module map if it exists\n            if module_name in self._extra_config_module_map:\n                load_module_name = self._extra_config_module_map[module_name]\n            else:\n                load_module_name = module_name\n\n            component_model_path = self._resolve_component_path(\n                server_args, module_name, load_module_name\n            )\n            
module, memory_usage = PipelineComponentLoader.load_component(\n                component_name=load_module_name,\n                component_model_path=component_model_path,\n                transformers_or_diffusers=transformers_or_diffusers,\n                server_args=server_args,\n            )\n\n            self.memory_usages[load_module_name] = memory_usage\n\n            if module_name in loaded_components:\n                logger.warning(\"Overwriting module %s\", module_name)\n            loaded_components[module_name] = module\n\n        # Check if all required modules were loaded\n        for module_name in required_modules:\n            if (\n                module_name not in loaded_components\n                or loaded_components[module_name] is None\n            ):\n                raise ValueError(\n                    f\"Required module: {module_name} was not found in loaded modules: {list(loaded_components.keys())}\"\n                )\n\n        logger.debug(\n            \"Memory usage of loaded modules (GiB): %s. 
avail mem: %s GB\",\n            self.memory_usages,\n            round(current_platform.get_available_gpu_memory(), 2),\n        )\n\n        return loaded_components\n\n    @staticmethod\n    def _infer_stage_name(stage: PipelineStage) -> str:\n        class_name = stage.__class__.__name__\n        # snake_case\n        name = re.sub(r\"(?<!^)(?=[A-Z])\", \"_\", class_name).lower()\n        if not name.endswith(\"_stage\"):\n            name += \"_stage\"\n        return name\n\n    def add_stage(\n        self, stage: PipelineStage, stage_name: str | None = None\n    ) -> \"ComposedPipelineBase\":\n\n        assert self.modules is not None, \"No modules are registered\"\n\n        if stage_name is None:\n            stage_name = self._infer_stage_name(stage)\n        if stage_name in self._stage_name_mapping:\n            raise ValueError(f\"Duplicate stage name detected: {stage_name}\")\n\n        self._stages.append(stage)\n        self._stage_name_mapping[stage_name] = stage\n        return self\n\n    def add_stages(\n        self, stages: list[PipelineStage | tuple[PipelineStage, str]]\n    ) -> \"ComposedPipelineBase\":\n\n        for item in stages:\n            if isinstance(item, tuple):\n                stage, name = item\n                self.add_stage(stage, name)\n            else:\n                self.add_stage(item)\n        return self\n\n    def add_stage_if(\n        self,\n        condition: bool | Callable[[], bool],\n        stage: PipelineStage,\n    ) -> \"ComposedPipelineBase\":\n        should_add = condition() if callable(condition) else condition\n        if should_add:\n            self.add_stage(stage)\n        return self\n\n    def get_stage(self, stage_name: str) -> PipelineStage | None:\n        \"\"\"Get a stage by name.\"\"\"\n        return self._stage_name_mapping.get(stage_name)\n\n    def add_standard_text_encoding_stage(\n        self,\n        text_encoder_key: str = \"text_encoder\",\n        tokenizer_key: str = 
\"tokenizer\",\n    ) -> \"ComposedPipelineBase\":\n        return self.add_stage(\n            TextEncodingStage(\n                text_encoders=[self.get_module(text_encoder_key)],\n                tokenizers=[self.get_module(tokenizer_key)],\n            ),\n        )\n\n    def add_standard_timestep_preparation_stage(\n        self,\n        scheduler_key: str = \"scheduler\",\n        prepare_extra_kwargs: list[Callable] | None = [],\n    ) -> \"ComposedPipelineBase\":\n        return self.add_stage(\n            TimestepPreparationStage(\n                scheduler=self.get_module(scheduler_key),\n                prepare_extra_set_timesteps_kwargs=prepare_extra_kwargs,\n            ),\n        )\n\n    def add_standard_latent_preparation_stage(\n        self,\n        scheduler_key: str = \"scheduler\",\n        transformer_key: str = \"transformer\",\n    ) -> \"ComposedPipelineBase\":\n        return self.add_stage(\n            LatentPreparationStage(\n                scheduler=self.get_module(scheduler_key),\n                transformer=self.get_module(transformer_key),\n            ),\n        )\n\n    def add_standard_denoising_stage(\n        self,\n        transformer_key: str = \"transformer\",\n        transformer_2_key: str | None = \"transformer_2\",\n        scheduler_key: str = \"scheduler\",\n        vae_key: str | None = \"vae\",\n    ) -> \"ComposedPipelineBase\":\n\n        kwargs = {\n            \"transformer\": self.get_module(transformer_key),\n            \"scheduler\": self.get_module(scheduler_key),\n        }\n\n        if transformer_2_key:\n            transformer_2 = self.get_module(transformer_2_key, None)\n            if transformer_2 is not None:\n                kwargs[\"transformer_2\"] = transformer_2\n\n        if vae_key:\n            vae = self.get_module(vae_key, None)\n            if vae is not None:\n                kwargs[\"vae\"] = vae\n                kwargs[\"pipeline\"] = self\n\n        return 
self.add_stage(DenoisingStage(**kwargs))\n\n    def add_standard_decoding_stage(\n        self,\n        vae_key: str = \"vae\",\n    ) -> \"ComposedPipelineBase\":\n\n        return self.add_stage(\n            DecodingStage(\n                vae=self.get_module(vae_key),\n                pipeline=self,\n                component_name=vae_key,\n            ),\n        )\n\n    def add_standard_t2i_stages(\n        self,\n        include_input_validation: bool = True,\n        prepare_extra_timestep_kwargs: list[Callable] | None = [],\n    ) -> \"ComposedPipelineBase\":\n\n        if include_input_validation:\n            self.add_stage(InputValidationStage())\n\n        self.add_standard_text_encoding_stage()\n\n        self.add_standard_latent_preparation_stage()\n        self.add_standard_timestep_preparation_stage(\n            prepare_extra_kwargs=prepare_extra_timestep_kwargs\n        )\n        self.add_standard_denoising_stage()\n        self.add_standard_decoding_stage()\n\n        return self\n\n    def add_standard_ti2i_stages(\n        self,\n        *,\n        include_input_validation: bool = True,\n        vae_image_processor: Any | None = None,\n        prompt_encoding: Literal[\"text\", \"image_encoding\"] = \"text\",\n        text_encoder_key: str = \"text_encoder\",\n        tokenizer_key: str = \"tokenizer\",\n        image_processor_key: str = \"processor\",\n        prompt_text_encoder_key: str = \"text_encoder\",\n        image_vae_key: str = \"vae\",\n        image_vae_stage_kwargs: dict[str, Any] | None = None,\n        prepare_extra_timestep_kwargs: list[Callable] | None = [],\n    ) -> \"ComposedPipelineBase\":\n        if include_input_validation:\n            self.add_stage(\n                InputValidationStage(vae_image_processor=vae_image_processor)\n            )\n\n        if prompt_encoding == \"text\":\n            self.add_standard_text_encoding_stage(\n                text_encoder_key=text_encoder_key,\n                
tokenizer_key=tokenizer_key,\n            )\n        elif prompt_encoding == \"image_encoding\":\n            self.add_stage(\n                ImageEncodingStage(\n                    image_processor=self.get_module(image_processor_key),\n                    text_encoder=self.get_module(prompt_text_encoder_key),\n                ),\n            )\n        else:\n            raise ValueError(f\"Unknown prompt_encoding: {prompt_encoding}\")\n\n        self.add_stage(\n            ImageVAEEncodingStage(\n                vae=self.get_module(image_vae_key),\n                **(image_vae_stage_kwargs or {}),\n            ),\n        )\n\n        self.add_standard_latent_preparation_stage()\n\n        self.add_standard_timestep_preparation_stage(\n            prepare_extra_kwargs=prepare_extra_timestep_kwargs\n        )\n        self.add_standard_denoising_stage()\n        self.add_standard_decoding_stage()\n        return self\n\n    def add_standard_ti2v_stages(\n        self,\n        *,\n        include_input_validation: bool = True,\n        vae_image_processor: Any | None = None,\n        text_encoder_key: str = \"text_encoder\",\n        tokenizer_key: str = \"tokenizer\",\n        image_encoder_key: str = \"image_encoder\",\n        image_processor_key: str = \"image_processor\",\n        image_vae_key: str = \"vae\",\n        image_vae_stage_kwargs: dict[str, Any] | None = None,\n        image_vae_encoding_position: Literal[\n            \"before_timestep\", \"after_latent\"\n        ] = \"before_timestep\",\n        prepare_extra_timestep_kwargs: list[Callable] | None = [],\n        denoising_stage_factory: Callable[[], PipelineStage] | None = None,\n    ) -> \"ComposedPipelineBase\":\n        if include_input_validation:\n            self.add_stage(\n                InputValidationStage(vae_image_processor=vae_image_processor)\n            )\n\n        self.add_standard_text_encoding_stage(\n            text_encoder_key=text_encoder_key,\n            
tokenizer_key=tokenizer_key,\n        )\n\n        image_encoder = self.get_module(image_encoder_key, None)\n        image_processor = self.get_module(image_processor_key, None)\n        self.add_stage_if(\n            image_encoder is not None and image_processor is not None,\n            ImageEncodingStage(\n                image_encoder=image_encoder,\n                image_processor=image_processor,\n            ),\n        )\n\n        if image_vae_encoding_position == \"before_timestep\":\n            self.add_stage(\n                ImageVAEEncodingStage(\n                    vae=self.get_module(image_vae_key),\n                    **(image_vae_stage_kwargs or {}),\n                )\n            )\n\n        self.add_standard_latent_preparation_stage()\n        self.add_standard_timestep_preparation_stage(\n            prepare_extra_kwargs=prepare_extra_timestep_kwargs\n        )\n        if image_vae_encoding_position == \"after_latent\":\n            self.add_stage(\n                ImageVAEEncodingStage(\n                    vae=self.get_module(image_vae_key),\n                    **(image_vae_stage_kwargs or {}),\n                )\n            )\n        elif image_vae_encoding_position != \"before_timestep\":\n            raise ValueError(\n                f\"Unknown image_vae_encoding_position: {image_vae_encoding_position}\"\n            )\n\n        if denoising_stage_factory is None:\n            self.add_standard_denoising_stage()\n        else:\n            self.add_stage(denoising_stage_factory())\n\n        self.add_standard_decoding_stage()\n        return self\n\n    # TODO(will): don't hardcode no_grad\n    @torch.no_grad()\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> OutputBatch:\n        \"\"\"\n        Generate a video or image using the pipeline.\n\n        Args:\n            batch: The batch to generate from.\n            server_args: The inference arguments.\n        Returns:\n      
      OutputBatch: The output batch containing the generated video or image.\n        \"\"\"\n\n        if self.is_lora_set() and not self.is_lora_effective():\n            logger.warning(\n                \"LoRA adapter is set, but not effective. Please make sure the LoRA weights are merged\"\n            )\n\n        # Execute each stage\n        if not batch.is_warmup and not batch.suppress_logs:\n            logger.info(\n                \"Running pipeline stages: %s\",\n                list(self._stage_name_mapping.keys()),\n                main_process_only=True,\n            )\n\n        return self.executor.execute_with_profiling(self.stages, batch, server_args)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/executors/parallel_executor.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nfrom typing import List\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.distributed import get_sp_group\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_cfg_group,\n    get_classifier_free_guidance_rank,\n    get_world_rank,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.executors.pipeline_executor import (\n    PipelineExecutor,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import (\n    PipelineStage,\n    StageParallelismType,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.distributed import broadcast_pyobj\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass ParallelExecutor(PipelineExecutor):\n    \"\"\"\n    The correctness of the execution relies on the parallelism_type declared by stages\n\n    \"\"\"\n\n    def collect_from_main(self, batches: list[Req]):\n\n        # TODO: fix this condition\n        if self.server_args.sp_degree != 1:\n            sp_group = get_sp_group()\n            batches = broadcast_pyobj(\n                batches,\n                sp_group.rank,\n                sp_group.cpu_group,\n                src=sp_group.ranks[0],\n            )\n\n        if self.server_args.enable_cfg_parallel:\n            batches = broadcast_pyobj(\n                batches,\n                self.worker.cfg_group.rank,\n                self.worker.cfg_cpu_group,\n                src=self.worker.cfg_group.ranks[0],\n            )\n\n    def _execute(\n        self,\n        stages: List[PipelineStage],\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> OutputBatch:\n        \"\"\"\n        Execute all pipeline stages 
respecting their declared parallelism type.\n        \"\"\"\n        if server_args.enable_cfg_parallel:\n            rank = get_classifier_free_guidance_rank()\n        else:\n            rank = get_world_rank()\n        cfg_group = get_cfg_group()\n\n        # TODO: decide when to gather on main when CFG_PARALLEL -> MAIN_RANK_ONLY\n        for stage in stages:\n            paradigm = stage.parallelism_type\n\n            if paradigm == StageParallelismType.MAIN_RANK_ONLY:\n                if rank == 0:\n                    # Only main rank executes, others just wait\n                    batch = stage(batch, server_args)\n                torch.distributed.barrier()\n\n            elif paradigm == StageParallelismType.CFG_PARALLEL:\n                obj_list = [batch] if rank == 0 else []\n                broadcasted_list = broadcast_pyobj(\n                    obj_list, rank=rank, dist_group=cfg_group.cpu_group, src=0\n                )\n                if rank != 0:\n                    batch = broadcasted_list[0]\n                batch = stage(batch, server_args)\n\n                torch.distributed.barrier()\n\n            elif paradigm == StageParallelismType.REPLICATED:\n                batch = stage(batch, server_args)\n        return batch\n\n    def execute(\n        self,\n        stages: List[PipelineStage],\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> OutputBatch:\n        batch = self._execute(stages, batch, server_args)\n        return batch\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/executors/pipeline_executor.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nBase class for all pipeline executors.\n\"\"\"\n\nimport contextlib\nfrom abc import ABC, abstractmethod\nfrom typing import TYPE_CHECKING, List\n\nfrom sglang.multimodal_gen.runtime.distributed import get_world_rank\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler\nfrom sglang.multimodal_gen.runtime.utils.profiler import SGLDiffusionProfiler\n\nif TYPE_CHECKING:\n    # Only for type checkers; avoids runtime circular import\n    from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage\n\nlogger = init_logger(__name__)\n\n\nclass Timer(StageProfiler):\n    \"\"\"\n    A wrapper around StageProfiler to maintain backward compatibility.\n    It forces simple logging behavior (log start/end) regardless of env vars.\n    \"\"\"\n\n    def __init__(self, name=\"Stage\"):\n        super().__init__(\n            stage_name=name, logger=logger, metrics=None, log_stage_start_end=True\n        )\n\n\nclass PipelineExecutor(ABC):\n    \"\"\"\n    Abstract base class for all pipeline executors.\n\n    Executors orchestrate the execution of pipeline, with managing the parallel and communications required by stages\n\n    \"\"\"\n\n    def __init__(self, server_args):\n        self.server_args = server_args\n\n    def execute_with_profiling(\n        self,\n        stages: List[\"PipelineStage\"],\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> OutputBatch:\n\n        with self.profile_execution(batch, dump_rank=0):\n            batch = self.execute(stages, batch, server_args)\n\n        return batch\n\n    @abstractmethod\n    def execute(\n        
self,\n        stages: List[\"PipelineStage\"],\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> OutputBatch:\n        \"\"\"\n        Execute the pipeline stages.\n\n        Args:\n            stages: A list of pipeline stages to execute.\n            batch: The batch to process.\n            server_args: The server arguments.\n\n        Returns:\n            The processed batch.\n        \"\"\"\n        raise NotImplementedError\n\n    @contextlib.contextmanager\n    def profile_execution(self, batch: Req, dump_rank: int = 0):\n        \"\"\"\n        Context manager for profiling execution.\n        \"\"\"\n        do_profile = batch.profile and not batch.is_warmup\n        if not do_profile:\n            # fast forward\n            yield\n            return\n\n        request_id = batch.request_id\n        rank = get_world_rank()\n\n        profiler = SGLDiffusionProfiler(\n            request_id=request_id,\n            rank=rank,\n            full_profile=batch.profile_all_stages,\n            num_steps=batch.num_profiled_timesteps,\n            num_inference_steps=batch.num_inference_steps,\n        )\n        try:\n            yield\n        finally:\n            profiler.stop(dump_rank=dump_rank)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/executors/sync_executor.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nSynchronous pipeline executor implementation.\n\"\"\"\n\nfrom typing import List\n\nfrom sglang.multimodal_gen.runtime.pipelines_core.executors.pipeline_executor import (\n    PipelineExecutor,\n    SGLDiffusionProfiler,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import PipelineStage\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\n\n\nclass SyncExecutor(PipelineExecutor):\n    \"\"\"\n    A simple synchronous executor that runs stages sequentially.\n    \"\"\"\n\n    def run_profile_all_stages(\n        self,\n        stages: List[PipelineStage],\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> OutputBatch:\n        \"\"\"\n        Execute all pipeline stages sequentially.\n        \"\"\"\n        for stage in stages:\n            batch = stage(batch, server_args)\n            profiler = SGLDiffusionProfiler.get_instance()\n            if profiler:\n                profiler.step_stage()\n        return batch\n\n    def execute(\n        self,\n        stages: List[PipelineStage],\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> OutputBatch:\n        \"\"\"\n        Execute the pipeline stages sequentially.\n        \"\"\"\n\n        batch = self.run_profile_all_stages(stages, batch, server_args)\n\n        return batch\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/lora_format_adapter.py",
    "content": "from __future__ import annotations\n\nimport logging\nfrom enum import Enum\nfrom typing import Dict, Iterable, Mapping, Optional\n\nimport torch\nfrom diffusers.loaders import lora_conversion_utils as lcu\n\nlogger = logging.getLogger(\"LoRAFormatAdapter\")\n\n\nclass LoRAFormat(str, Enum):\n    \"\"\"Supported external LoRA formats before normalization.\"\"\"\n\n    STANDARD = \"standard\"\n    NON_DIFFUSERS_SD = \"non-diffusers-sd\"\n    QWEN_IMAGE_STANDARD = \"qwen-image-standard\"\n    XLABS_FLUX = \"xlabs-ai\"\n    KOHYA_FLUX = \"kohya-flux\"\n    WAN = \"wan\"\n    AI_TOOLKIT_FLUX = \"ai-toolkit-flux\"\n\n\ndef _sample_keys(keys: Iterable[str], k: int = 20) -> list[str]:\n    out = []\n    for i, key in enumerate(keys):\n        if i >= k:\n            break\n        out.append(key)\n    return out\n\n\ndef _has_substring_key(keys: Iterable[str], substr: str) -> bool:\n    return any(substr in k for k in keys)\n\n\ndef _has_prefix_key(keys: Iterable[str], prefix: str) -> bool:\n    return any(k.startswith(prefix) for k in keys)\n\n\ndef _looks_like_xlabs_flux_key(k: str) -> bool:\n    \"\"\"XLabs FLUX-style keys under double_blocks/single_blocks with lora down/up.\"\"\"\n    if not (k.endswith(\".down.weight\") or k.endswith(\".up.weight\")):\n        return False\n\n    if not k.startswith(\n        (\n            \"double_blocks.\",\n            \"single_blocks.\",\n            \"diffusion_model.double_blocks\",\n            \"diffusion_model.single_blocks\",\n        )\n    ):\n        return False\n\n    return \".processor.\" in k or \".proj_lora\" in k or \".qkv_lora\" in k\n\n\ndef _looks_like_kohya_flux(state_dict: Mapping[str, torch.Tensor]) -> bool:\n    \"\"\"Kohya FLUX LoRA (flux_lora.py) under lora_unet_double/single_blocks_ prefixes.\"\"\"\n    if not state_dict:\n        return False\n    keys = state_dict.keys()\n    return any(\n        k.startswith(\"lora_unet_double_blocks_\")\n        or 
k.startswith(\"lora_unet_single_blocks_\")\n        for k in keys\n    )\n\n\ndef _looks_like_non_diffusers_sd(state_dict: Mapping[str, torch.Tensor]) -> bool:\n    \"\"\"Classic non-diffusers SD LoRA (Kohya/A1111/sd-scripts).\"\"\"\n    if not state_dict:\n        return False\n    keys = state_dict.keys()\n    return all(\n        k.startswith((\"lora_unet_\", \"lora_te_\", \"lora_te1_\", \"lora_te2_\")) for k in keys\n    )\n\n\ndef _looks_like_wan_lora(state_dict: Mapping[str, torch.Tensor]) -> bool:\n    \"\"\"Wan2.2 distill LoRAs (Wan-AI / Wan2.2-Distill-Loras style).\"\"\"\n    if not state_dict:\n        return False\n\n    for k in state_dict.keys():\n        if not k.startswith(\"diffusion_model.blocks.\"):\n            continue\n        if \".lora_down\" not in k and \".lora_up\" not in k:\n            continue\n        if \".cross_attn.\" in k or \".self_attn.\" in k or \".ffn.\" in k or \".norm3.\" in k:\n            return True\n\n    return False\n\n\ndef _looks_like_qwen_image(state_dict: Mapping[str, torch.Tensor]) -> bool:\n    keys = list(state_dict.keys())\n    if not keys:\n        return False\n    return _has_prefix_key(keys, \"transformer.transformer_blocks.\") and (\n        _has_substring_key(keys, \".lora.down.weight\")\n        or _has_substring_key(keys, \".lora.up.weight\")\n    )\n\n\ndef _looks_like_ai_toolkit_flux_lora(state_dict: Mapping[str, torch.Tensor]) -> bool:\n    \"\"\"Detect ai-toolkit/ComfyUI trained Flux LoRA with double_blocks/single_blocks naming.\n\n    Key patterns: double_blocks.{N}.img_attn.proj.lora_A.weight\n    \"\"\"\n    keys = list(state_dict.keys())\n    if not keys:\n        return False\n\n    has_double_blocks = any(\n        k.startswith(\"double_blocks.\")\n        or k.startswith(\"base_model.model.double_blocks.\")\n        for k in keys\n    )\n    has_single_blocks = any(\n        k.startswith(\"single_blocks.\")\n        or k.startswith(\"base_model.model.single_blocks.\")\n        for k in keys\n  
  )\n    has_lora_ab = _has_substring_key(keys, \".lora_A\") or _has_substring_key(\n        keys, \".lora_B\"\n    )\n\n    return (has_double_blocks or has_single_blocks) and has_lora_ab\n\n\ndef detect_lora_format_from_state_dict(\n    state_dict: Mapping[str, torch.Tensor],\n) -> LoRAFormat:\n    \"\"\"Classify LoRA format by key patterns only.\"\"\"\n    keys = list(state_dict.keys())\n    if not keys:\n        return LoRAFormat.STANDARD\n\n    if _looks_like_ai_toolkit_flux_lora(state_dict):\n        return LoRAFormat.AI_TOOLKIT_FLUX\n\n    if _has_substring_key(keys, \".lora_A\") or _has_substring_key(keys, \".lora_B\"):\n        return LoRAFormat.STANDARD\n\n    if any(_looks_like_xlabs_flux_key(k) for k in keys):\n        return LoRAFormat.XLABS_FLUX\n    if _looks_like_kohya_flux(state_dict):\n        return LoRAFormat.KOHYA_FLUX\n\n    if _looks_like_wan_lora(state_dict):\n        return LoRAFormat.WAN\n\n    if _looks_like_qwen_image(state_dict):\n        return LoRAFormat.STANDARD\n\n    if _looks_like_non_diffusers_sd(state_dict):\n        return LoRAFormat.NON_DIFFUSERS_SD\n\n    if _has_substring_key(keys, \".lora.down\") or _has_substring_key(keys, \".lora_up\"):\n        return LoRAFormat.NON_DIFFUSERS_SD\n\n    return LoRAFormat.STANDARD\n\n\ndef _convert_qwen_image_standard(\n    state_dict: Mapping[str, torch.Tensor],\n    log: logging.Logger,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Qwen-Image: transformer.*.lora.down/up -> transformer_blocks.*.lora_A/B.\"\"\"\n    out: Dict[str, torch.Tensor] = {}\n\n    for name, tensor in state_dict.items():\n        new_name = name\n\n        if new_name.startswith(\"transformer.\"):\n            new_name = new_name[len(\"transformer.\") :]\n\n        if new_name.endswith(\".lora.down.weight\"):\n            new_name = new_name.replace(\".lora.down.weight\", \".lora_A.weight\")\n        elif new_name.endswith(\".lora.up.weight\"):\n            new_name = new_name.replace(\".lora.up.weight\", 
\".lora_B.weight\")\n\n        out[new_name] = tensor\n\n    return out\n\n\ndef _convert_non_diffusers_sd_simple(\n    state_dict: Mapping[str, torch.Tensor],\n    log: logging.Logger,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Generic down/up -> A/B conversion for non-diffusers SD-like formats.\"\"\"\n    out: Dict[str, torch.Tensor] = {}\n\n    for name, tensor in state_dict.items():\n        new_name = name\n\n        if \"lora_down.weight\" in new_name:\n            new_name = new_name.replace(\"lora_down.weight\", \"lora_A.weight\")\n        elif \"lora_up.weight\" in new_name:\n            new_name = new_name.replace(\"lora_up.weight\", \"lora_B.weight\")\n        elif new_name.endswith(\".lora_down\"):\n            new_name = new_name.replace(\".lora_down\", \".lora_A\")\n        elif new_name.endswith(\".lora_up\"):\n            new_name = new_name.replace(\".lora_up\", \".lora_B\")\n\n        out[new_name] = tensor\n\n    sample = _sample_keys(out.keys(), 20)\n    log.info(\n        \"[LoRAFormatAdapter] after NON_DIFFUSERS_SD simple conversion, \"\n        \"sample keys (<=20): %s\",\n        \", \".join(sample),\n    )\n    return out\n\n\ndef _convert_with_diffusers_utils_if_available(\n    state_dict: Mapping[str, torch.Tensor],\n    log: logging.Logger,\n) -> Optional[Dict[str, torch.Tensor]]:\n    \"\"\"Use diffusers.lora_conversion_utils if available.\"\"\"\n    try:\n        if hasattr(lcu, \"maybe_convert_state_dict\"):\n            converted = lcu.maybe_convert_state_dict(  # type: ignore[attr-defined]\n                state_dict\n            )\n        else:\n            converted = dict(state_dict)\n\n        if not isinstance(converted, dict):\n            converted = dict(converted)\n\n        sample = _sample_keys(converted.keys(), 20)\n        log.info(\n            \"[LoRAFormatAdapter] diffusers.lora_conversion_utils converted keys, \"\n            \"sample keys (<=20): %s\",\n            \", \".join(sample),\n        )\n        return 
converted\n    except Exception as exc:  # pragma: no cover\n        log.warning(\n            \"[LoRAFormatAdapter] diffusers lora_conversion_utils failed, \"\n            \"falling back to internal converters. Error: %s\",\n            exc,\n        )\n        return None\n\n\ndef _convert_via_diffusers_candidates(\n    state_dict: Mapping[str, torch.Tensor],\n    candidate_names: tuple[str, ...],\n    log: logging.Logger,\n    unavailable_warning: str,\n    no_converter_warning: str,\n    success_info: str,\n    all_failed_warning: str,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Try multiple named converters in lora_conversion_utils, use the first that works.\"\"\"\n    converters = [\n        (n, getattr(lcu, n)) for n in candidate_names if callable(getattr(lcu, n, None))\n    ]\n    if not converters:\n        log.warning(no_converter_warning)\n        return dict(state_dict)\n\n    last_err: Optional[Exception] = None\n\n    for name, fn in converters:\n        try:\n            sd_copy = dict(state_dict)\n            out = fn(sd_copy)\n            if isinstance(out, tuple) and isinstance(out[0], dict):\n                out = out[0]\n            if not isinstance(out, dict):\n                raise TypeError(f\"Converter {name} returned {type(out)}\")\n            log.info(success_info.format(name=name))\n            return out\n        except Exception as exc:\n            last_err = exc\n\n    log.warning(all_failed_warning.format(last_err=last_err))\n    return dict(state_dict)\n\n\ndef _convert_xlabs_ai_via_diffusers(\n    state_dict: Mapping[str, torch.Tensor],\n    log: logging.Logger,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Convert XLabs FLUX LoRA via diffusers helpers.\"\"\"\n    return _convert_via_diffusers_candidates(\n        state_dict,\n        (\n            \"_convert_xlabs_flux_lora_to_diffusers\",\n            \"convert_xlabs_lora_state_dict_to_diffusers\",\n            \"convert_xlabs_lora_to_diffusers\",\n            
\"convert_xlabs_flux_lora_to_diffusers\",\n        ),\n        log=log,\n        unavailable_warning=(\n            \"[LoRAFormatAdapter] XLabs FLUX detected but diffusers is unavailable.\"\n        ),\n        no_converter_warning=(\n            \"[LoRAFormatAdapter] No XLabs FLUX converter found in diffusers.\"\n        ),\n        success_info=\"[LoRAFormatAdapter] Converted XLabs FLUX LoRA using {name}\",\n        all_failed_warning=(\n            \"[LoRAFormatAdapter] All XLabs FLUX converters failed; \"\n            \"last error: {last_err}\"\n        ),\n    )\n\n\ndef _convert_kohya_flux_via_diffusers(\n    state_dict: Mapping[str, torch.Tensor],\n    log: logging.Logger,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Convert Kohya FLUX LoRA via diffusers helpers.\"\"\"\n    return _convert_via_diffusers_candidates(\n        state_dict,\n        (\n            \"_convert_kohya_flux_lora_to_diffusers\",\n            \"convert_kohya_flux_lora_to_diffusers\",\n        ),\n        log=log,\n        unavailable_warning=(\n            \"[LoRAFormatAdapter] Kohya FLUX detected but diffusers is unavailable.\"\n        ),\n        no_converter_warning=\"[LoRAFormatAdapter] No Kohya FLUX converter found.\",\n        success_info=\"[LoRAFormatAdapter] Converted Kohya FLUX LoRA using {name}\",\n        all_failed_warning=(\n            \"[LoRAFormatAdapter] Kohya FLUX conversion failed; \"\n            \"last error: {last_err}\"\n        ),\n    )\n\n\ndef _convert_ai_toolkit_flux_lora(\n    state_dict: Mapping[str, torch.Tensor],\n    log: logging.Logger,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Convert ai-toolkit/ComfyUI trained Flux LoRA to SGLang format.\n\n    Handles the naming convention conversion:\n    - double_blocks.{N}.img_attn.qkv -> transformer_blocks.{N}.attn.to_q/k/v\n    - double_blocks.{N}.txt_attn.qkv -> transformer_blocks.{N}.attn.add_q/k/v_proj\n    - double_blocks.{N}.img_attn.proj -> transformer_blocks.{N}.attn.to_out.0\n    - 
double_blocks.{N}.txt_attn.proj -> transformer_blocks.{N}.attn.to_add_out\n    - double_blocks -> transformer_blocks\n    - single_blocks -> single_transformer_blocks\n    \"\"\"\n    out: Dict[str, torch.Tensor] = {}\n    original_state_dict: Dict[str, torch.Tensor] = {}\n\n    for name, tensor in state_dict.items():\n        new_name = name\n        if new_name.startswith(\"diffusion_model.\"):\n            new_name = new_name[len(\"diffusion_model.\") :]\n        if new_name.startswith(\"base_model.model.\"):\n            new_name = new_name[len(\"base_model.model.\") :]\n        original_state_dict[new_name] = tensor\n\n    num_double_layers = 0\n    num_single_layers = 0\n    for key in original_state_dict.keys():\n        if key.startswith(\"single_blocks.\"):\n            parts = key.split(\".\")\n            if len(parts) > 1 and parts[1].isdigit():\n                num_single_layers = max(num_single_layers, int(parts[1]) + 1)\n        elif key.startswith(\"double_blocks.\"):\n            parts = key.split(\".\")\n            if len(parts) > 1 and parts[1].isdigit():\n                num_double_layers = max(num_double_layers, int(parts[1]) + 1)\n\n    lora_keys = (\"lora_A\", \"lora_B\")\n    attn_types = (\"img_attn\", \"txt_attn\")\n\n    for sl in range(num_single_layers):\n        single_block_prefix = f\"single_blocks.{sl}\"\n        attn_prefix = f\"single_transformer_blocks.{sl}.attn\"\n\n        for lora_key in lora_keys:\n            linear1_key = f\"{single_block_prefix}.linear1.{lora_key}.weight\"\n            if linear1_key in original_state_dict:\n                out[f\"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight\"] = (\n                    original_state_dict.pop(linear1_key)\n                )\n\n            linear2_key = f\"{single_block_prefix}.linear2.{lora_key}.weight\"\n            if linear2_key in original_state_dict:\n                out[f\"{attn_prefix}.to_out.{lora_key}.weight\"] = (\n                    
original_state_dict.pop(linear2_key)\n                )\n\n    for dl in range(num_double_layers):\n        transformer_block_prefix = f\"transformer_blocks.{dl}\"\n\n        for lora_key in lora_keys:\n            for attn_type in attn_types:\n                attn_prefix = f\"{transformer_block_prefix}.attn\"\n                qkv_key = f\"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight\"\n\n                if qkv_key not in original_state_dict:\n                    continue\n\n                fused_qkv_weight = original_state_dict.pop(qkv_key)\n\n                if lora_key == \"lora_A\":\n                    diff_attn_proj_keys = (\n                        [\"to_q\", \"to_k\", \"to_v\"]\n                        if attn_type == \"img_attn\"\n                        else [\"add_q_proj\", \"add_k_proj\", \"add_v_proj\"]\n                    )\n                    for proj_key in diff_attn_proj_keys:\n                        out[f\"{attn_prefix}.{proj_key}.{lora_key}.weight\"] = (\n                            fused_qkv_weight\n                        )\n                else:\n                    if fused_qkv_weight.shape[0] % 3 != 0:\n                        log.warning(\n                            \"[LoRAFormatAdapter] QKV weight shape %s not divisible by 3, \"\n                            \"may cause shape mismatch for %s\",\n                            fused_qkv_weight.shape,\n                            qkv_key,\n                        )\n                    sample_q, sample_k, sample_v = torch.chunk(\n                        fused_qkv_weight, 3, dim=0\n                    )\n\n                    if attn_type == \"img_attn\":\n                        out[f\"{attn_prefix}.to_q.{lora_key}.weight\"] = sample_q\n                        out[f\"{attn_prefix}.to_k.{lora_key}.weight\"] = sample_k\n                        out[f\"{attn_prefix}.to_v.{lora_key}.weight\"] = sample_v\n                    else:\n                        
out[f\"{attn_prefix}.add_q_proj.{lora_key}.weight\"] = sample_q\n                        out[f\"{attn_prefix}.add_k_proj.{lora_key}.weight\"] = sample_k\n                        out[f\"{attn_prefix}.add_v_proj.{lora_key}.weight\"] = sample_v\n\n        proj_mappings = [\n            (\"img_attn.proj\", \"attn.to_out.0\"),\n            (\"txt_attn.proj\", \"attn.to_add_out\"),\n        ]\n        for org_proj, diff_proj in proj_mappings:\n            for lora_key in lora_keys:\n                original_key = f\"double_blocks.{dl}.{org_proj}.{lora_key}.weight\"\n                if original_key in original_state_dict:\n                    diffusers_key = (\n                        f\"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight\"\n                    )\n                    out[diffusers_key] = original_state_dict.pop(original_key)\n\n    for key, tensor in original_state_dict.items():\n        new_key = key.replace(\"double_blocks.\", \"transformer_blocks.\")\n        new_key = new_key.replace(\"single_blocks.\", \"single_transformer_blocks.\")\n        out[new_key] = tensor\n\n    extra_mappings = {\n        \"img_in\": \"x_embedder\",\n        \"txt_in\": \"context_embedder\",\n        \"time_in.in_layer\": \"time_guidance_embed.timestep_embedder.linear_1\",\n        \"time_in.out_layer\": \"time_guidance_embed.timestep_embedder.linear_2\",\n        \"final_layer.linear\": \"proj_out\",\n        \"final_layer.adaLN_modulation.1\": \"norm_out.linear\",\n        \"single_stream_modulation.lin\": \"single_stream_modulation.linear\",\n        \"double_stream_modulation_img.lin\": \"double_stream_modulation_img.linear\",\n        \"double_stream_modulation_txt.lin\": \"double_stream_modulation_txt.linear\",\n    }\n\n    final_out: Dict[str, torch.Tensor] = {}\n    for key, tensor in out.items():\n        new_key = key\n        for org_key, diff_key in extra_mappings.items():\n            if key.startswith(org_key):\n                new_key = 
key.replace(org_key, diff_key, 1)\n                break\n        final_out[new_key] = tensor\n\n    sample = _sample_keys(final_out.keys(), 20)\n    log.info(\n        \"[LoRAFormatAdapter] after AI_TOOLKIT_FLUX conversion, \"\n        \"sample keys (<=20): %s\",\n        \", \".join(sample),\n    )\n    return final_out\n\n\ndef convert_lora_state_dict_by_format(\n    state_dict: Mapping[str, torch.Tensor],\n    fmt: LoRAFormat,\n    log: logging.Logger,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Normalize a raw LoRA state_dict into A/B + .weight naming.\"\"\"\n    if fmt == LoRAFormat.QWEN_IMAGE_STANDARD:\n        return _convert_qwen_image_standard(state_dict, log)\n\n    if fmt == LoRAFormat.AI_TOOLKIT_FLUX:\n        return _convert_ai_toolkit_flux_lora(state_dict, log)\n\n    if fmt == LoRAFormat.XLABS_FLUX:\n        converted = _convert_xlabs_ai_via_diffusers(state_dict, log)\n        return _convert_non_diffusers_sd_simple(converted, log)\n\n    if fmt == LoRAFormat.KOHYA_FLUX:\n        converted = _convert_kohya_flux_via_diffusers(state_dict, log)\n        return _convert_non_diffusers_sd_simple(converted, log)\n\n    if fmt == LoRAFormat.WAN:\n        maybe = _convert_with_diffusers_utils_if_available(state_dict, log)\n        if maybe is None:\n            maybe = dict(state_dict)\n        return _convert_non_diffusers_sd_simple(maybe, log)\n\n    if fmt == LoRAFormat.STANDARD:\n        maybe = _convert_with_diffusers_utils_if_available(state_dict, log)\n        if maybe is None:\n            maybe = dict(state_dict)\n\n        if _looks_like_qwen_image(maybe):\n            return _convert_qwen_image_standard(maybe, log)\n\n        return maybe\n\n    if fmt == LoRAFormat.NON_DIFFUSERS_SD:\n        maybe = _convert_with_diffusers_utils_if_available(state_dict, log)\n        if maybe is None:\n            maybe = dict(state_dict)\n        return _convert_non_diffusers_sd_simple(maybe, log)\n\n    log.info(\n        \"[LoRAFormatAdapter] format %s not 
handled specially, returning as-is\",\n        fmt,\n    )\n    return dict(state_dict)\n\n\ndef normalize_lora_state_dict(\n    state_dict: Mapping[str, torch.Tensor],\n    logger: Optional[logging.Logger] = None,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Normalize any supported LoRA format into a single canonical layout.\"\"\"\n    log = logger or globals()[\"logger\"]\n\n    keys = list(state_dict.keys())\n    log.info(\n        \"[LoRAFormatAdapter] normalize_lora_state_dict called, #keys=%d\",\n        len(keys),\n    )\n    if keys:\n        log.info(\n            \"[LoRAFormatAdapter] before convert, sample keys (<=20): %s\",\n            \", \".join(_sample_keys(keys, 20)),\n        )\n\n    fmt = detect_lora_format_from_state_dict(state_dict)\n    log.info(\"[LoRAFormatAdapter] detected format: %s\", fmt)\n\n    normalized = convert_lora_state_dict_by_format(state_dict, fmt, log)\n\n    norm_keys = list(normalized.keys())\n    if norm_keys:\n        log.info(\n            \"[LoRAFormatAdapter] after convert, sample keys (<=20): %s\",\n            \", \".join(_sample_keys(norm_keys, 20)),\n        )\n\n    return normalized\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\nimport os\nfrom collections import defaultdict\nfrom collections.abc import Hashable\nfrom contextlib import contextmanager\nfrom typing import Any\n\nimport torch\nimport torch.distributed as dist\nfrom safetensors.torch import load_file\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.layers.lora.linear import (\n    BaseLayerWithLoRA,\n    replace_submodule,\n    wrap_with_lora_layer,\n)\nfrom sglang.multimodal_gen.runtime.loader.utils import get_param_names_mapping\nfrom sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (\n    ComposedPipelineBase,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.lora_format_adapter import (\n    normalize_lora_state_dict,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_lora\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\n# to avoid deadlocks when forking\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\nlogger = init_logger(__name__)\n\n\nclass LoRAPipeline(ComposedPipelineBase):\n    \"\"\"\n    Pipeline that supports injecting LoRA adapters into the diffusion transformer.\n    \"\"\"\n\n    # Type annotations for instance attributes (initialized in __init__)\n    # [lora_nickname][target_LoRA_weight_name_in_SGLang_dit] = weight\n    # e.g., [jinx][transformer_blocks.0.attn.to_v.lora_A]\n    lora_adapters: dict[str, dict[str, torch.Tensor]]\n    loaded_adapter_paths: dict[str, str]  # nickname -> lora_path\n    # Track current adapter per module: {\"transformer\": \"high_lora\", \"transformer_2\": \"low_lora\"}\n    cur_adapter_name: dict[str, str]\n    cur_adapter_path: dict[str, str]\n    cur_adapter_strength: dict[str, float]  # Track current strength per module\n    
cur_adapter_config: dict[str, tuple[list[str], list[float]]]\n    # [dit_layer_name] = wrapped_lora_layer\n    lora_layers: dict[str, BaseLayerWithLoRA]\n    lora_layers_critic: dict[str, BaseLayerWithLoRA]\n    lora_layers_transformer_2: dict[str, BaseLayerWithLoRA]\n    server_args: ServerArgs\n    exclude_lora_layers: list[str]\n    device: torch.device\n    lora_target_modules: list[str] | None\n    lora_path: str | None\n    lora_nickname: str\n    lora_rank: int | None\n    lora_alpha: int | None\n    lora_initialized: bool\n    # Track merge status per module: {\"transformer\": True, \"transformer_2\": False}\n    is_lora_merged: dict[str, bool]\n    # Valid target values for set_lora (class constant, immutable)\n    VALID_TARGETS: list[str] = [\"all\", \"transformer\", \"transformer_2\", \"critic\"]\n\n    def __init__(self, *args, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n        # Initialize all mutable instance attributes to avoid sharing across instances\n        self.lora_adapters = defaultdict(dict)\n        self.loaded_adapter_paths = {}\n        self.cur_adapter_name = {}\n        self.cur_adapter_path = {}\n        self.cur_adapter_strength = {}\n        # Track full LoRA config: {module_name: (nickname_list, strength_list)}\n        self.cur_adapter_config = {}\n        self.lora_layers = {}\n        self.lora_layers_critic = {}\n        self.lora_layers_transformer_2 = {}\n        self.is_lora_merged = {}\n        self.lora_initialized = False\n        self.lora_rank = None\n        self.lora_alpha = None\n        self.lora_path = None\n        self.lora_nickname = \"default\"\n\n        # Initialize from server_args\n        self.device = get_local_torch_device()\n        self.exclude_lora_layers = (\n            self.server_args.pipeline_config.dit_config.arch_config.exclude_lora_layers\n        )\n        self.lora_target_modules = self.server_args.lora_target_modules\n        self.lora_path = self.server_args.lora_path\n  
      self.lora_nickname = self.server_args.lora_nickname\n        if self.lora_path is not None:\n            self.convert_to_lora_layers()\n            self.set_lora(\n                self.lora_nickname, self.lora_path, strength=self.server_args.lora_scale  # type: ignore\n            )  # type: ignore\n\n    def is_target_layer(self, module_name: str) -> bool:\n        if self.lora_target_modules is None:\n            return True\n        return any(\n            target_name in module_name for target_name in self.lora_target_modules\n        )\n\n    def _get_target_lora_layers(\n        self, target: str\n    ) -> tuple[list[tuple[str, dict[str, BaseLayerWithLoRA]]], str | None]:\n        \"\"\"\n        Return a list of (module_name, lora_layers_dict) based on the target.\n\n        Args:\n            target: One of \"all\", \"transformer\", \"transformer_2\", \"critic\".\n\n        Returns:\n            A tuple of (result, error_message):\n            - result: List of tuples (module_name, lora_layers_dict) to operate on.\n            - error_message: Error description if target is invalid or module doesn't exist, None otherwise.\n        \"\"\"\n        if target == \"all\":\n            result: list[tuple[str, dict[str, BaseLayerWithLoRA]]] = [\n                (\"transformer\", self.lora_layers)\n            ]\n            if self.lora_layers_transformer_2:\n                result.append((\"transformer_2\", self.lora_layers_transformer_2))\n            if self.lora_layers_critic:\n                result.append((\"critic\", self.lora_layers_critic))\n            return result, None\n        elif target == \"transformer\":\n            return [(\"transformer\", self.lora_layers)], None\n        elif target == \"transformer_2\":\n            if not self.lora_layers_transformer_2:\n                return [], \"transformer_2 does not exist in this pipeline\"\n            return [(\"transformer_2\", self.lora_layers_transformer_2)], None\n        elif target == 
\"critic\":\n            if not self.lora_layers_critic:\n                return (\n                    [],\n                    \"critic (fake_score_transformer) does not exist in this pipeline\",\n                )\n            return [(\"critic\", self.lora_layers_critic)], None\n        else:\n            return [], f\"Invalid target: {target}. Valid targets: {self.VALID_TARGETS}\"\n\n    @contextmanager\n    def _temporarily_disable_offload(\n        self,\n        target_modules: list[tuple[str, dict[str, BaseLayerWithLoRA]]] | None = None,\n        target: str | None = None,\n        use_module_names_only: bool = False,\n    ):\n        \"\"\"\n        Context manager to temporarily disable layerwise offload for the given modules.\n\n        Args:\n            target_modules: List of (module_name, lora_layers_dict) tuples. If None, will be determined from target.\n            target: Target string (\"all\", \"transformer\", etc.). Used if target_modules is None.\n            use_module_names_only: If True, determine module names directly from target without requiring\n                                   LoRA initialization. 
Used for convert_to_lora_layers scenario.\n\n        Yields:\n            List of modules that had offload disabled.\n        \"\"\"\n        from sglang.multimodal_gen.runtime.utils.layerwise_offload import (\n            OffloadableDiTMixin,\n        )\n\n        module_names = []\n        if target_modules is not None:\n            # Extract module names from target_modules\n            module_names = [module_name for module_name, _ in target_modules]\n        elif target is not None:\n            if use_module_names_only:\n                if target == \"all\":\n                    # convert_to_lora_layers also wraps the critic, so include it here\n                    module_names = [\"transformer\", \"transformer_2\", \"critic\"]\n                elif target in [\"transformer\", \"transformer_2\", \"critic\"]:\n                    module_names = [target]\n            else:\n                target_modules, _ = self._get_target_lora_layers(target)\n                if target_modules:\n                    module_names = [module_name for module_name, _ in target_modules]\n        else:\n            yield []\n            return\n\n        if not module_names:\n            yield []\n            return\n\n        # clear device cache to free up unused memory\n        if torch.get_device_module().is_available():\n            torch.get_device_module().synchronize()\n            torch.get_device_module().empty_cache()\n\n        offload_disabled_modules = []\n        for module_name in module_names:\n            # The LoRA target \"critic\" is registered in self.modules as\n            # \"fake_score_transformer\"; map it so its offload is disabled too.\n            module_key = (\n                \"fake_score_transformer\" if module_name == \"critic\" else module_name\n            )\n            module = self.modules.get(module_key)\n            if module is not None and isinstance(module, OffloadableDiTMixin):\n                if module.layerwise_offload_managers is not None:\n                    
module.disable_offload()\n                    offload_disabled_modules.append(module)\n\n        try:\n            yield offload_disabled_modules\n        finally:\n            # Re-enable layerwise offload: sync weights to CPU and restore hooks\n            for module in offload_disabled_modules:\n                module.enable_offload()\n\n    def 
convert_module_lora_layers(\n        self,\n        module: torch.nn.Module,\n        module_name: str,\n        target_lora_layers: dict[str, BaseLayerWithLoRA],\n        check_exclude: bool = True,\n    ) -> int:\n        \"\"\"\n        Convert layers in a module to LoRA layers.\n\n        Args:\n            module: The module to convert.\n            module_name: The name of the module (for replace_submodule).\n            target_lora_layers: The dictionary to store the converted LoRA layers.\n            check_exclude: Whether to check the exclude_lora_layers list.\n\n        Returns:\n            The number of layers converted.\n        \"\"\"\n        converted_count = 0\n        for name, layer in module.named_modules():\n            if not self.is_target_layer(name):\n                continue\n\n            if check_exclude:\n                excluded = any(\n                    exclude_layer in name for exclude_layer in self.exclude_lora_layers\n                )\n                if excluded:\n                    continue\n\n            lora_layer = wrap_with_lora_layer(\n                layer,\n                lora_rank=self.lora_rank,\n                lora_alpha=self.lora_alpha,\n            )\n            if lora_layer is not None:\n                target_lora_layers[name] = lora_layer\n                replace_submodule(self.modules[module_name], name, lora_layer)\n                converted_count += 1\n\n        return converted_count\n\n    def convert_to_lora_layers(self) -> None:\n        \"\"\"\n        Unified method to convert the transformer to a LoRA transformer.\n        \"\"\"\n        if self.lora_initialized:\n            return\n        self.lora_initialized = True\n\n        # Convert transformer\n        converted_count = self.convert_module_lora_layers(\n            self.modules[\"transformer\"],\n            \"transformer\",\n            self.lora_layers,\n            check_exclude=True,\n        )\n        logger.info(\"Converted %d 
layers to LoRA layers\", converted_count)\n\n        # Convert transformer_2 if exists (e.g., Wan2.2 A14B dual-transformer)\n        if (\n            \"transformer_2\" in self.modules\n            and self.modules[\"transformer_2\"] is not None\n        ):\n            converted_count_2 = self.convert_module_lora_layers(\n                self.modules[\"transformer_2\"],\n                \"transformer_2\",\n                self.lora_layers_transformer_2,\n                check_exclude=True,\n            )\n            logger.info(\n                \"Converted %d layers to LoRA layers in transformer_2\", converted_count_2\n            )\n\n        # Convert fake_score_transformer if exists\n        if \"fake_score_transformer\" in self.modules:\n            converted_count_critic = self.convert_module_lora_layers(\n                self.modules[\"fake_score_transformer\"],\n                \"fake_score_transformer\",\n                self.lora_layers_critic,\n                check_exclude=False,\n            )\n            logger.info(\n                \"Converted %d layers to LoRA layers in the critic model\",\n                converted_count_critic,\n            )\n\n    def _normalize_lora_params(\n        self,\n        lora_nickname: str | list[str],\n        lora_path: str | None | list[str | None],\n        strength: float | list[float],\n        target: str | list[str],\n    ) -> tuple[list[str], list[str | None], list[float], list[str]]:\n        \"\"\"\n        Normalize LoRA parameters to lists for multi-LoRA support.\n\n        Requirements:\n        - each nickname must have a corresponding lora_path (no implicit repeat)\n        - strength / target if scalar broadcast, else length must match nickname\n        \"\"\"\n        # nickname\n        if isinstance(lora_nickname, str):\n            lora_nicknames = [lora_nickname]\n        else:\n            lora_nicknames = lora_nickname\n\n        # lora_path: require 1:1 mapping with nickname (no implicit 
repeat)\n        if isinstance(lora_path, list):\n            lora_paths = lora_path\n        else:\n            lora_paths = [lora_path]\n        if len(lora_paths) != len(lora_nicknames):\n            raise ValueError(\n                f\"Length mismatch: lora_nickname has {len(lora_nicknames)} items, \"\n                f\"but lora_path has {len(lora_paths)} items. \"\n                \"Provide one path per nickname.\"\n            )\n\n        # strength and target: allow scalar broadcast, else length must match\n        if isinstance(strength, (int, float)):\n            strengths = [float(strength)] * len(lora_nicknames)\n        else:\n            strengths = [float(s) for s in strength]\n        if len(strengths) != len(lora_nicknames):\n            raise ValueError(\n                f\"Length mismatch: lora_nickname has {len(lora_nicknames)} items, \"\n                f\"but strength has {len(strengths)} items\"\n            )\n\n        if isinstance(target, str):\n            targets = [target] * len(lora_nicknames)\n        else:\n            targets = target\n        if len(targets) != len(lora_nicknames):\n            raise ValueError(\n                f\"Length mismatch: lora_nickname has {len(lora_nicknames)} items, \"\n                f\"but target has {len(targets)} items\"\n            )\n        return lora_nicknames, lora_paths, strengths, targets\n\n    def _check_lora_config_matches(\n        self,\n        module_name: str,\n        target_nicknames: list[str],\n        target_strengths: list[float],\n        adapter_updated: bool,\n    ) -> bool:\n        \"\"\"\n        Check if current LoRA configuration matches the target configuration.\n\n        Args:\n            module_name: The name of the module to check.\n            target_nicknames: List of LoRA nicknames to apply.\n            target_strengths: List of LoRA strengths to apply.\n            adapter_updated: Whether any adapter was updated/loaded.\n\n        Returns:\n            
True if the configuration matches exactly (including order and strength), False otherwise.\n        \"\"\"\n        if not self.is_lora_merged.get(module_name, False):\n            return False\n        if adapter_updated:\n            return False  # Adapter was updated, need to reapply\n\n        stored_config = self.cur_adapter_config.get(module_name)\n        if stored_config is None:\n            return False\n\n        stored_nicknames, stored_strengths = stored_config\n        # Compare: nickname list and strength list must match exactly (including order)\n        return (\n            stored_nicknames == target_nicknames\n            and stored_strengths == target_strengths\n        )\n\n    def _apply_lora_to_layers(\n        self,\n        lora_layers: dict[str, BaseLayerWithLoRA],\n        lora_nicknames: list[str],\n        lora_paths: list[str | None],\n        rank: int,\n        strengths: list[float],\n        clear_existing: bool = False,\n    ) -> int:\n        \"\"\"\n        Apply LoRA weights to the given lora_layers. Supports multiple LoRA adapters.\n\n        Args:\n            lora_layers: The dictionary of LoRA layers to apply weights to.\n            lora_nicknames: The list of nicknames of the LoRA adapters.\n            lora_paths: The list of paths to the LoRA adapters. Must match length of lora_nicknames.\n            rank: The distributed rank (for logging).\n            strengths: The list of LoRA strengths for merge. 
Must match length of lora_nicknames.\n            clear_existing: If True, clear existing LoRA weights before adding new ones.\n\n        Returns:\n            The number of layers that had LoRA weights applied.\n        \"\"\"\n        if len(lora_paths) != len(lora_nicknames):\n            raise ValueError(\n                f\"Length mismatch: lora_nicknames has {len(lora_nicknames)} items, \"\n                f\"but lora_paths has {len(lora_paths)} items\"\n            )\n        if len(strengths) != len(lora_nicknames):\n            raise ValueError(\n                f\"Length mismatch: lora_nicknames has {len(lora_nicknames)} items, \"\n                f\"but strengths has {len(strengths)} items\"\n            )\n\n        adapted_count = 0\n        for name, layer in lora_layers.items():\n            # Apply all LoRA adapters in order\n            for idx, (nickname, path, lora_strength) in enumerate(\n                zip(lora_nicknames, lora_paths, strengths)\n            ):\n                lora_A_name = name + \".lora_A\"\n                lora_B_name = name + \".lora_B\"\n                if (\n                    lora_A_name in self.lora_adapters[nickname]\n                    and lora_B_name in self.lora_adapters[nickname]\n                ):\n                    # Some LoRA checkpoints (e.g. 
Lightning distill) store per-layer alpha as \"<layer>.alpha\".\n                    # If present, we must apply the standard LoRA scaling: scale = alpha / rank.\n                    try:\n                        inferred_rank = int(\n                            self.lora_adapters[nickname][lora_A_name].shape[0]\n                        )\n                    except Exception:\n                        inferred_rank = None\n                    # Default to None for some checkpoints without \"<layer>.alpha\"\n                    inferred_alpha: int | None = None\n                    alpha_key = name + \".alpha\"\n                    if alpha_key in self.lora_adapters[nickname]:\n                        try:\n                            inferred_alpha = int(\n                                self.lora_adapters[nickname][alpha_key].item()\n                            )\n                        except Exception:\n                            inferred_alpha = None\n\n                    if inferred_rank is not None:\n                        layer.lora_rank = inferred_rank\n                        layer.lora_alpha = (\n                            inferred_alpha\n                            if inferred_alpha is not None\n                            else inferred_rank\n                        )\n\n                    layer.set_lora_weights(\n                        self.lora_adapters[nickname][lora_A_name],\n                        self.lora_adapters[nickname][lora_B_name],\n                        lora_path=path,\n                        strength=lora_strength,\n                        clear_existing=(\n                            clear_existing and idx == 0\n                        ),  # Only clear on first LoRA\n                    )\n                    adapted_count += 1\n                else:\n                    if rank == 0 and idx == 0:  # Only warn for first missing LoRA\n                        logger.warning(\n                            \"LoRA adapter %s does not 
contain the weights for layer '%s'. LoRA will not be applied to it.\",\n                            path,\n                            name,\n                        )\n                    # Only disable if no LoRA was applied at all\n                    if idx == len(lora_nicknames) - 1:\n                        has_any_lora = any(\n                            name + \".lora_A\" in self.lora_adapters[n]\n                            and name + \".lora_B\" in self.lora_adapters[n]\n                            for n in lora_nicknames\n                        )\n                        if not has_any_lora:\n                            layer.disable_lora = True\n        return adapted_count\n\n    def is_lora_effective(self, target: str = \"all\") -> bool:\n        \"\"\"\n        Check if LoRA is currently effective (merged) for the specified target.\n\n        Args:\n            target: Which transformer to check. \"all\" returns True if any is merged.\n        \"\"\"\n        if target == \"all\":\n            return any(self.is_lora_merged.values())\n        return self.is_lora_merged.get(target, False)\n\n    def is_lora_set(self, target: str = \"all\") -> bool:\n        \"\"\"\n        Check if LoRA has been set for the specified target.\n\n        Args:\n            target: Which transformer to check. 
\"all\" returns True if any is set.\n        \"\"\"\n        if not self.lora_initialized:\n            return False\n        if target == \"all\":\n            return bool(self.cur_adapter_name)\n        return target in self.cur_adapter_name\n\n    def load_lora_adapter(self, lora_path: str, lora_nickname: str, rank: int):\n        \"\"\"\n        Load the LoRA adapter and set up self.lora_adapters for later weight replacement.\n        \"\"\"\n        assert lora_path is not None\n\n        # Only rank 0 downloads to avoid race conditions where other ranks\n        # try to load incomplete downloads\n        if rank == 0:\n            lora_local_path = maybe_download_lora(lora_path)\n        else:\n            lora_local_path = None\n\n        # Synchronize all ranks after download completes\n        if dist.is_initialized():\n            dist.barrier()\n\n        # Non-rank-0 workers now download (will hit cache since rank 0 completed)\n        if rank != 0:\n            lora_local_path = maybe_download_lora(lora_path)\n\n        raw_state_dict = load_file(lora_local_path)\n        lora_state_dict = normalize_lora_state_dict(raw_state_dict, logger=logger)\n\n        if lora_nickname in self.lora_adapters:\n            self.lora_adapters[lora_nickname].clear()\n\n        config = self.server_args.pipeline_config.dit_config.arch_config\n\n        param_names_mapping_fn = get_param_names_mapping(\n            config.param_names_mapping\n            or self.modules[\"transformer\"].param_names_mapping\n        )\n        lora_param_names_mapping_fn = get_param_names_mapping(\n            config.lora_param_names_mapping\n            or self.modules[\"transformer\"].lora_param_names_mapping\n        )\n\n        to_merge_params: defaultdict[Hashable, dict[Any, Any]] = defaultdict(dict)\n        for name, weight in lora_state_dict.items():\n            name = name.replace(\"diffusion_model.\", \"\")\n            name = name.replace(\".weight\", \"\")\n            # misc-format 
-> HF-format\n            name, _, _ = lora_param_names_mapping_fn(name)\n            # HF-format (LoRA) -> SGLang-dit-format\n            target_name, merge_index, num_params_to_merge = param_names_mapping_fn(name)\n            # for fuse B(out_dim, r) @ A(r, in_dim) -> (N, out_dim, r) @ (N, r, in_dim)\n            # see param mapping in HunyuanVideoArchConfig\n            if merge_index is not None:\n                to_merge_params[target_name][merge_index] = weight\n                if len(to_merge_params[target_name]) == num_params_to_merge:\n                    sorted_tensors = [\n                        to_merge_params[target_name][i]\n                        for i in range(num_params_to_merge)\n                    ]\n                    # Use stack instead of cat because it needs to be compatible with TP.\n                    weight = torch.stack(sorted_tensors, dim=0)\n                    del to_merge_params[target_name]\n                else:\n                    continue\n\n            if target_name in self.lora_adapters[lora_nickname]:\n                raise ValueError(\n                    f\"Dit target weight name {target_name} already exists in lora_adapters[{lora_nickname}]\"\n                )\n            self.lora_adapters[lora_nickname][target_name] = weight.to(self.device)\n        self.loaded_adapter_paths[lora_nickname] = lora_path\n        logger.info(\"Rank %d: loaded LoRA adapter %s\", rank, lora_path)\n\n    def set_lora(\n        self,\n        lora_nickname: str | list[str],\n        lora_path: str | None | list[str | None] = None,\n        target: str | list[str] = \"all\",\n        strength: float | list[float] = 1.0,\n    ):  # type: ignore\n        \"\"\"\n        Load LoRA adapter(s) into the pipeline and apply them to the specified transformer(s).\n        Supports both single LoRA (backward compatible) and multiple LoRA adapters.\n        \"\"\"\n        # Normalize inputs to lists for multi-LoRA support\n        lora_nicknames, 
lora_paths, strengths, targets = self._normalize_lora_params(\n            lora_nickname, lora_path, strength, target\n        )\n\n        # Validate targets\n        invalid_targets = [t for t in targets if t not in self.VALID_TARGETS]\n        if invalid_targets:\n            raise ValueError(\n                f\"Invalid target(s): {invalid_targets}. Valid targets: {self.VALID_TARGETS}\"\n            )\n\n        # Disable layerwise offload before convert_to_lora_layers to ensure weights are accessible\n        # This is critical because convert_to_lora_layers needs to save cpu_weight from actual weights,\n        # not from offloaded placeholder tensors\n        if not self.lora_initialized:\n            with self._temporarily_disable_offload(\n                target=\"all\", use_module_names_only=True\n            ):\n                self.convert_to_lora_layers()\n\n        # Check adapter presence and load missing adapters\n        adapter_updated = False\n        rank = dist.get_rank()\n\n        # load required adapters\n        for nickname, path in zip(lora_nicknames, lora_paths):\n            if nickname not in self.lora_adapters and path is None:\n                raise ValueError(\n                    f\"Adapter {nickname} not found in the pipeline. 
Please provide lora_path to load it.\"\n                )\n            # Check if adapter needs to be loaded\n            should_load = False\n            if path is not None:\n                if nickname not in self.loaded_adapter_paths:\n                    should_load = True\n                elif self.loaded_adapter_paths[nickname] != path:\n                    should_load = True\n            if should_load:\n                adapter_updated = True\n                self.load_lora_adapter(path, nickname, rank)\n\n        # Group by target to apply separately\n        target_to_indices = {}\n        for idx, tgt in enumerate(targets):\n            if tgt not in target_to_indices:\n                target_to_indices[tgt] = []\n            target_to_indices[tgt].append(idx)\n\n        adapted_count = 0\n        for tgt, idx_list in target_to_indices.items():\n            target_modules, error = self._get_target_lora_layers(tgt)\n            if error:\n                logger.warning(\"set_lora: %s\", error)\n            if not target_modules:\n                continue\n\n            # Disable layerwise offload if enabled: load all layers to GPU\n            # the LoRA weights merging process requires weights being on device\n            with self._temporarily_disable_offload(target_modules=target_modules):\n                tgt_nicknames = [lora_nicknames[i] for i in idx_list]\n                tgt_paths = [lora_paths[i] for i in idx_list]\n                tgt_strengths = [strengths[i] for i in idx_list]\n\n                merged_name = (\n                    \",\".join(tgt_nicknames)\n                    if len(tgt_nicknames) > 1\n                    else tgt_nicknames[0]\n                )\n\n                # Skip if LoRA configuration matches exactly (including order and strength)\n                # Since all modules for the same target apply the same config, checking one is sufficient\n                first_module_name, _ = target_modules[0]\n                if 
self._check_lora_config_matches(\n                    first_module_name, tgt_nicknames, tgt_strengths, adapter_updated\n                ):\n                    logger.info(\"LoRA configuration matches exactly, skipping\")\n                    continue\n\n                # Apply LoRA to modules for this target\n                for module_name, lora_layers_dict in target_modules:\n                    count = self._apply_lora_to_layers(\n                        lora_layers_dict,\n                        tgt_nicknames,\n                        tgt_paths,\n                        rank,\n                        tgt_strengths,\n                        clear_existing=True,\n                    )\n                    adapted_count += count\n                    self.cur_adapter_name[module_name] = merged_name\n                    self.cur_adapter_path[module_name] = \",\".join(\n                        str(p or self.loaded_adapter_paths.get(n, \"\"))\n                        for n, p in zip(tgt_nicknames, tgt_paths)\n                    )\n                    self.is_lora_merged[module_name] = True\n                    self.cur_adapter_strength[module_name] = tgt_strengths[0]\n                    # Store full configuration for multi-LoRA support (preserves order and all strengths)\n                    self.cur_adapter_config[module_name] = (\n                        tgt_nicknames.copy(),\n                        tgt_strengths.copy(),\n                    )\n\n        logger.info(\n            \"Rank %d: LoRA adapter(s) %s applied to %d layers (targets: %s, strengths: %s)\",\n            rank,\n            \", \".join(map(str, lora_paths)) if lora_paths else None,\n            adapted_count,\n            \", \".join(targets) if len(set(targets)) > 1 else targets[0],\n            (\n                \", \".join(f\"{s:.2f}\" for s in strengths)\n                if len(strengths) > 1\n                else f\"{strengths[0]:.2f}\"\n            ),\n        )\n\n    def 
merge_lora_weights(self, target: str = \"all\", strength: float = 1.0) -> None:\n        \"\"\"\n        Merge LoRA weights into the base model for the specified target.\n\n        This operation is idempotent - calling it when LoRA is already merged is safe.\n\n        Args:\n            target: Which transformer(s) to merge. One of \"all\", \"transformer\",\n                    \"transformer_2\", \"critic\".\n            strength: LoRA strength for merge, default 1.0.\n        \"\"\"\n        target_modules, error = self._get_target_lora_layers(target)\n        if error:\n            logger.warning(\"merge_lora_weights: %s\", error)\n        if not target_modules:\n            return\n\n        # Disable layerwise offload if enabled: load all layers to GPU\n        with self._temporarily_disable_offload(target_modules=target_modules):\n            for module_name, lora_layers_dict in target_modules:\n                if self.is_lora_merged.get(module_name, False):\n                    # Check if strength is the same - if so, skip (idempotent)\n                    if self.cur_adapter_strength.get(module_name) == strength:\n                        logger.warning(\n                            \"LoRA weights are already merged for %s with same strength\",\n                            module_name,\n                        )\n                        continue\n                    # Different strength requested - allow re-merge (layer handles unmerge internally)\n                    logger.info(\n                        \"Re-merging LoRA weights for %s with new strength %s\",\n                        module_name,\n                        strength,\n                    )\n                for name, layer in lora_layers_dict.items():\n                    # Only re-enable LoRA for layers that actually have LoRA weights\n                    has_lora_weights = (\n                        hasattr(layer, \"lora_A\") and layer.lora_A is not None\n                    )\n             
       if not has_lora_weights:\n                        continue\n                    if hasattr(layer, \"disable_lora\"):\n                        layer.disable_lora = False\n                    try:\n                        layer.merge_lora_weights(strength=strength)\n                    except Exception as e:\n                        logger.warning(\"Could not merge layer %s: %s\", name, e)\n                        continue\n                self.is_lora_merged[module_name] = True\n                self.cur_adapter_strength[module_name] = strength\n                logger.info(\n                    \"LoRA weights merged for %s (strength: %s)\", module_name, strength\n                )\n\n    def unmerge_lora_weights(self, target: str = \"all\") -> None:\n        \"\"\"\n        Unmerge LoRA weights from the base model for the specified target.\n        This also disables LoRA so it won't be computed on-the-fly.\n\n        This operation is idempotent - calling it when LoRA is not merged is safe.\n\n        Args:\n            target: Which transformer(s) to unmerge. 
One of \"all\", \"transformer\",\n                    \"transformer_2\", \"critic\".\n        \"\"\"\n        target_modules, error = self._get_target_lora_layers(target)\n        if error:\n            logger.warning(\"unmerge_lora_weights: %s\", error)\n        if not target_modules:\n            return\n\n        # Disable layerwise offload if enabled: load all layers to GPU\n\n        for module_name, lora_layers_dict in target_modules:\n            if not self.is_lora_merged.get(module_name, False):\n                logger.warning(\n                    \"LoRA weights are not merged for %s, skipping\", module_name\n                )\n                continue\n            with self._temporarily_disable_offload(target_modules=target_modules):\n                for name, layer in lora_layers_dict.items():\n                    # Check layer-level state to avoid raising exception\n                    if hasattr(layer, \"merged\") and not layer.merged:\n                        logger.warning(\"Layer %s is not merged, skipping\", name)\n                        # Still disable LoRA to prevent on-the-fly computation\n                        if hasattr(layer, \"disable_lora\"):\n                            layer.disable_lora = True\n                        continue\n                    try:\n                        layer.unmerge_lora_weights()\n                        # Disable LoRA after unmerge to prevent on-the-fly computation\n                        if hasattr(layer, \"disable_lora\"):\n                            layer.disable_lora = True\n                    except ValueError as e:\n                        logger.warning(\"Could not unmerge layer %s: %s\", name, e)\n                        # Still disable LoRA even if unmerge failed\n                        if hasattr(layer, \"disable_lora\"):\n                            layer.disable_lora = True\n                        continue\n                self.is_lora_merged[module_name] = False\n                
self.cur_adapter_strength.pop(module_name, None)\n                self.cur_adapter_config.pop(module_name, None)\n            logger.info(\"LoRA weights unmerged for %s\", module_name)\n\n    def get_lora_status(self) -> dict[str, Any]:\n        \"\"\"\n        Summarize loaded LoRA adapters and current application status per module.\n\n        Returns a plain Python dict with no tensor values to allow safe JSON serialization.\n        \"\"\"\n        # Loaded adapters: list of {nickname, path}\n        loaded_adapters = [\n            {\"nickname\": nickname, \"path\": path}\n            for nickname, path in self.loaded_adapter_paths.items()\n        ]\n\n        def _module_status(module_name: str) -> list[dict] | None:\n            # return list of dict to support multi-lora in the future\n            if not self.is_lora_merged.get(module_name, False):\n                return None\n            else:\n                return [\n                    {\n                        \"nickname\": self.cur_adapter_name.get(module_name, None),\n                        \"path\": self.cur_adapter_path.get(module_name, None),\n                        \"merged\": self.is_lora_merged.get(module_name, False),\n                        \"strength\": self.cur_adapter_strength.get(module_name, None),\n                    }\n                ]\n\n        # Build active usage per module only for modules that exist in this pipeline\n        active: dict[str, Any] = {}\n        if (\n            \"transformer\" in self.modules\n            and self.modules[\"transformer\"] is not None\n            and (status := _module_status(\"transformer\")) is not None\n        ):\n            active[\"transformer\"] = status\n        if (\n            \"transformer_2\" in self.modules\n            and self.modules[\"transformer_2\"] is not None\n            and (status := _module_status(\"transformer_2\")) is not None\n        ):\n            active[\"transformer_2\"] = status\n        if (\n         
   \"fake_score_transformer\" in self.modules\n            and self.modules[\"fake_score_transformer\"] is not None\n            and (status := _module_status(\"critic\")) is not None\n        ):\n            active[\"critic\"] = status\n\n        return {\n            \"loaded_adapters\": loaded_adapters,\n            \"active\": active,\n        }\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Inspired by SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/forward_batch_info.py\n\"\"\"\nData structures for functional pipeline processing.\n\nThis module defines the dataclasses used to pass state between pipeline components\nin a functional manner, reducing the need for explicit parameter passing.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport pprint\nfrom copy import deepcopy\nfrom dataclasses import MISSING, asdict, dataclass, field, fields\nfrom typing import Any, Optional\n\nimport PIL.Image\nimport torch\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import (\n    _sanitize_for_logging,\n    init_logger,\n)\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import RequestMetrics\nfrom sglang.multimodal_gen.utils import align_to\n\nlogger = init_logger(__name__)\n\nSAMPLING_PARAMS_FIELDS = {f.name for f in fields(SamplingParams)}\n\n\n@dataclass(init=False)\nclass Req:\n    \"\"\"\n    Complete state passed through the pipeline execution.\n\n    This dataclass contains all information needed during the diffusion pipeline\n    execution, allowing methods to update specific components without needing\n    to manage numerous individual parameters.\n\n    [IMPORTANT] Fields that overlap with SamplingParams are automatically delegated to the\n    sampling_params member via __getattr__ and __setattr__.\n    \"\"\"\n\n    sampling_params: SamplingParams | None = None\n\n    generator: torch.Generator | list[torch.Generator] | None = None\n\n    # Image encoder hidden states\n    image_embeds: list[torch.Tensor] = field(default_factory=list)\n\n    original_condition_image_size: tuple[int, int] = None\n    
condition_image: torch.Tensor | PIL.Image.Image | None = None\n    vae_image: torch.Tensor | PIL.Image.Image | None = None\n    pixel_values: torch.Tensor | PIL.Image.Image | None = None\n    preprocessed_image: torch.Tensor | None = None\n\n    output_file_ext: str | None = None\n    # Primary encoder embeddings\n    prompt_embeds: list[torch.Tensor] | torch.Tensor = field(default_factory=list)\n    negative_prompt_embeds: list[torch.Tensor] | None = None\n    prompt_attention_mask: list[torch.Tensor] | None = None\n    negative_attention_mask: list[torch.Tensor] | None = None\n    clip_embedding_pos: list[torch.Tensor] | None = None\n    clip_embedding_neg: list[torch.Tensor] | None = None\n\n    pooled_embeds: list[torch.Tensor] = field(default_factory=list)\n    neg_pooled_embeds: list[torch.Tensor] = field(default_factory=list)\n\n    # Additional text-related parameters\n    max_sequence_length: int | None = None\n    prompt_template: dict[str, Any] | None = None\n    do_classifier_free_guidance: bool = False\n\n    seeds: list[int] | None = None\n\n    # Tracking if embeddings are already processed\n    is_prompt_processed: bool = False\n\n    # Audio Embeddings (LTX-2)\n    audio_prompt_embeds: list[torch.Tensor] | torch.Tensor = field(default_factory=list)\n    negative_audio_prompt_embeds: list[torch.Tensor] | torch.Tensor = field(\n        default_factory=list\n    )\n\n    # Latent tensors\n    latents: torch.Tensor | None = None\n    y: torch.Tensor | None = None\n    # Flux-2\n    latent_ids: torch.Tensor | None = None\n\n    # Audio Latents\n    audio_latents: torch.Tensor | None = None\n    audio_noise: torch.Tensor | None = None\n    raw_audio_latent_shape: tuple[int, ...] 
| None = None\n\n    # Audio Parameters\n    generate_audio: bool = True\n\n    raw_latent_shape: torch.Tensor | None = None\n    noise_pred: torch.Tensor | None = None\n    # vae-encoded condition image\n    image_latent: torch.Tensor | list[torch.Tensor] | None = None\n    condition_image_latent_ids: torch.Tensor | list[torch.Tensor] | None = None\n    vae_image_sizes: list[tuple[int, int]] | None = None\n\n    # Latent dimensions\n    height_latents: list[int] | int | None = None\n    width_latents: list[int] | int | None = None\n\n    # Timesteps\n    timesteps: torch.Tensor | None = None\n    paired_timesteps: torch.Tensor | None = None\n    timestep: torch.Tensor | float | int | None = None\n    step_index: int | None = None\n\n    eta: float = 0.0\n    sigmas: list[float] | None = None\n\n    n_tokens: int | None = None\n\n    # Other parameters that may be needed by specific schedulers\n    extra_step_kwargs: dict[str, Any] = field(default_factory=dict)\n\n    # Component modules (populated by the pipeline)\n    modules: dict[str, Any] = field(default_factory=dict)\n\n    trajectory_timesteps: list[torch.Tensor] | None = None\n    trajectory_latents: torch.Tensor | None = None\n    trajectory_audio_latents: torch.Tensor | None = None\n\n    # Extra parameters that might be needed by specific pipeline implementations\n    extra: dict[str, Any] = field(default_factory=dict)\n\n    is_warmup: bool = False\n\n    # STA parameters\n    STA_param: list | None = None\n    is_cfg_negative: bool = False\n    mask_search_final_result_pos: list[list] | None = None\n    mask_search_final_result_neg: list[list] | None = None\n\n    # VSA parameters\n    VSA_sparsity: float = 0.0\n\n    # stage logging\n    metrics: Optional[\"RequestMetrics\"] = None\n\n    # results\n    output: torch.Tensor | None = None\n    audio: torch.Tensor | None = None\n    audio_sample_rate: int | None = None\n\n    def __init__(self, **kwargs):\n        # Initialize dataclass fields\n        
for name, field in self.__class__.__dataclass_fields__.items():\n            if name in kwargs:\n                object.__setattr__(self, name, kwargs.pop(name))\n            elif field.default is not MISSING:\n                object.__setattr__(self, name, field.default)\n            elif field.default_factory is not MISSING:\n                object.__setattr__(self, name, field.default_factory())\n\n        for name, value in kwargs.items():\n            setattr(self, name, value)\n\n        self.validate()\n\n    def __getattr__(self, name: str) -> Any:\n        \"\"\"\n        Delegate attribute access to sampling_params if not found in Req.\n        This is only called when the attribute is not found in the instance.\n        \"\"\"\n        if name == \"sampling_params\":\n            raise AttributeError(\n                f\"'{type(self).__name__}' object has no attribute '{name}'\"\n            )\n\n        sampling_params = object.__getattribute__(self, \"sampling_params\")\n        if sampling_params is not None and hasattr(sampling_params, name):\n            return getattr(sampling_params, name)\n\n        raise AttributeError(\n            f\"'{type(self).__name__}' object has no attribute '{name}'\"\n        )\n\n    def __setattr__(self, name: str, value: Any) -> None:\n        \"\"\"\n        Smart attribute setting:\n        1. If field exists in Req, set it in Req\n        2. Else if field exists in sampling_params, set it in sampling_params\n        3. 
Else set it in Req (for dynamic attributes)\n        \"\"\"\n        if name == \"sampling_params\":\n            object.__setattr__(self, name, value)\n            return\n\n        if name in self.__class__.__dataclass_fields__:\n            object.__setattr__(self, name, value)\n            return\n\n        try:\n            sampling_params = object.__getattribute__(self, \"sampling_params\")\n        except AttributeError:\n            sampling_params = None\n\n        if sampling_params is not None and hasattr(sampling_params, name):\n            setattr(sampling_params, name, value)\n            return\n\n        if sampling_params is None and name in SAMPLING_PARAMS_FIELDS:\n            new_sp = SamplingParams()\n            object.__setattr__(self, \"sampling_params\", new_sp)\n            setattr(new_sp, name, value)\n            return\n\n        object.__setattr__(self, name, value)\n\n    @property\n    def batch_size(self):\n        # Determine batch size\n        if isinstance(self.prompt, list):\n            batch_size = len(self.prompt)\n        elif self.prompt is not None:\n            batch_size = 1\n        else:\n            batch_size = self.prompt_embeds[0].shape[0]\n\n        # Adjust batch size for number of videos per prompt\n        batch_size *= self.num_outputs_per_prompt\n        return batch_size\n\n    def output_file_path(self, num_outputs=1, output_idx=None):\n        output_file_name = self.output_file_name\n        if num_outputs > 1 and output_file_name:\n            base, ext = os.path.splitext(output_file_name)\n            output_file_name = f\"{base}_{output_idx}{ext}\"\n\n        if self.output_path is None or not output_file_name:\n            return None\n        return os.path.join(self.output_path, output_file_name)\n\n    def set_as_warmup(self, warmup_steps: int = 1):\n        self.is_warmup = True\n        self.save_output = False\n        self.suppress_logs = True\n        
self.extra[\"cache_dit_num_inference_steps\"] = self.num_inference_steps\n        self.num_inference_steps = warmup_steps\n\n    def copy_as_warmup(self, warmup_steps: int = 1) -> \"Req\":\n        req = deepcopy(self)\n        req.set_as_warmup(warmup_steps)\n        return req\n\n    def validate(self):\n        \"\"\"Initialize dependent fields after dataclass initialization.\"\"\"\n        # Set do_classifier_free_guidance based on guidance scale and negative prompt\n        if self.guidance_scale > 1.0 and self.negative_prompt is not None:\n            self.do_classifier_free_guidance = True\n        if self.negative_prompt_embeds is None:\n            self.negative_prompt_embeds = []\n        if self.guidance_scale_2 is None:\n            self.guidance_scale_2 = self.guidance_scale\n\n        self.metrics = RequestMetrics(request_id=self.request_id)\n\n    def adjust_size(self, server_args: ServerArgs):\n        pass\n\n    def __str__(self):\n        return pprint.pformat(asdict(self), indent=2, width=120)\n\n    def log(self, server_args: ServerArgs):\n        if self.is_warmup or self.suppress_logs:\n            return\n        # TODO: in some cases (e.g., TI2I), height and width might be undecided at this moment\n        if self.height:\n            target_height = align_to(self.height, 16)\n        else:\n            target_height = -1\n        if self.width:\n            target_width = align_to(self.width, 16)\n        else:\n            target_width = -1\n\n        if logger.isEnabledFor(logging.DEBUG):\n            display_prompt = self.prompt\n            display_neg_prompt = self.negative_prompt\n        else:\n            display_prompt = _sanitize_for_logging(self.prompt, key_hint=\"prompt\")\n            display_neg_prompt = _sanitize_for_logging(\n                self.negative_prompt, key_hint=\"negative_prompt\"\n            )\n\n        debug_str = f\"\"\"Sampling params:\n                       width: {target_width}\n                      
height: {target_height}\n                  num_frames: {self.num_frames}\n                         fps: {self.fps}\n                      prompt: {display_prompt}\n                  neg_prompt: {display_neg_prompt}\n                        seed: {self.seed}\n                 infer_steps: {self.num_inference_steps}\n      num_outputs_per_prompt: {self.num_outputs_per_prompt}\n              guidance_scale: {self.guidance_scale}\n     embedded_guidance_scale: {server_args.pipeline_config.embedded_cfg_scale}\n                    n_tokens: {self.n_tokens}\n                  flow_shift: {server_args.pipeline_config.flow_shift}\n                  image_path: {self.image_path}\n                 save_output: {self.save_output}\n            output_file_path: {self.output_file_path()}\n        \"\"\"  # type: ignore[attr-defined]\n        logger.info(debug_str)\n\n\n@dataclass\nclass OutputBatch:\n    \"\"\"\n    Final output (after pipeline completion)\n    \"\"\"\n\n    output: torch.Tensor | None = None\n    audio: torch.Tensor | None = None\n    audio_sample_rate: int | None = None\n    trajectory_timesteps: list[torch.Tensor] | None = None\n    trajectory_latents: torch.Tensor | None = None\n    trajectory_decoded: list[torch.Tensor] | None = None\n    error: str | None = None\n    output_file_paths: list[str] | None = None\n\n    # logged metrics info, directly from Req.metrics\n    metrics: Optional[\"RequestMetrics\"] = None\n\n    # For ComfyUI integration: noise prediction from denoising stage\n    noise_pred: torch.Tensor | None = None\n    peak_memory_mb: float = 0.0\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nPipeline stages for diffusion models.\n\nThis package contains the various stages that can be composed to create\ncomplete diffusion pipelines.\n\"\"\"\n\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.causal_denoising import (\n    CausalDMDDenoisingStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.comfyui_latent_preparation import (\n    ComfyUILatentPreparationStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.decoding import DecodingStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.decoding_av import (\n    LTX2AVDecodingStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.denoising import DenoisingStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.denoising_av import (\n    LTX2AVDenoisingStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.denoising_dmd import (\n    DmdDenoisingStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.encoding import EncodingStage\n\n# Hunyuan3D paint stages\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.hunyuan3d_paint import (\n    Hunyuan3DPaintPostprocessStage,\n    Hunyuan3DPaintPreprocessStage,\n    Hunyuan3DPaintTexGenStage,\n)\n\n# Hunyuan3D shape stages\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.hunyuan3d_shape import (\n    Hunyuan3DShapeBeforeDenoisingStage,\n    Hunyuan3DShapeDenoisingStage,\n    Hunyuan3DShapeExportStage,\n    Hunyuan3DShapeSaveStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.image_encoding import (\n    ImageEncodingStage,\n    ImageVAEEncodingStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.input_validation import (\n    InputValidationStage,\n)\nfrom 
sglang.multimodal_gen.runtime.pipelines_core.stages.latent_preparation import (\n    LatentPreparationStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.latent_preparation_av import (\n    LTX2AVLatentPreparationStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.text_connector import (\n    LTX2TextConnectorStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.text_encoding import (\n    TextEncodingStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.timestep_preparation import (\n    TimestepPreparationStage,\n)\n\n__all__ = [\n    \"PipelineStage\",\n    \"InputValidationStage\",\n    \"TimestepPreparationStage\",\n    \"LatentPreparationStage\",\n    \"ComfyUILatentPreparationStage\",\n    \"LTX2AVLatentPreparationStage\",\n    \"DenoisingStage\",\n    \"DmdDenoisingStage\",\n    \"LTX2AVDenoisingStage\",\n    \"CausalDMDDenoisingStage\",\n    \"EncodingStage\",\n    \"DecodingStage\",\n    \"LTX2AVDecodingStage\",\n    \"ImageEncodingStage\",\n    \"ImageVAEEncodingStage\",\n    \"TextEncodingStage\",\n    \"LTX2TextConnectorStage\",\n    # Hunyuan3D shape stages\n    \"Hunyuan3DShapeBeforeDenoisingStage\",\n    \"Hunyuan3DShapeDenoisingStage\",\n    \"Hunyuan3DShapeExportStage\",\n    \"Hunyuan3DShapeSaveStage\",\n    # Hunyuan3D paint stages\n    \"Hunyuan3DPaintPreprocessStage\",\n    \"Hunyuan3DPaintTexGenStage\",\n    \"Hunyuan3DPaintPostprocessStage\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nBase classes for pipeline stages.\n\nThis module defines the abstract base classes for pipeline stages that can be\ncomposed to create complete diffusion pipelines.\n\"\"\"\n\nfrom abc import ABC, abstractmethod\nfrom enum import Enum, auto\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler\n\nlogger = init_logger(__name__)\n\n\nclass StageParallelismType(Enum):\n    # Executed on all GPUs\n    REPLICATED = auto()\n    # Executed on the main rank only\n    MAIN_RANK_ONLY = auto()\n    # Requires CFG parallelism\n    CFG_PARALLEL = auto()\n\n\nclass StageVerificationError(Exception):\n    \"\"\"Exception raised when stage verification fails.\"\"\"\n\n    pass\n\n\nclass PipelineStage(ABC):\n    \"\"\"\n    Abstract base class for all pipeline stages.\n\n    A pipeline stage represents a discrete step in the diffusion process that can be\n    composed with other stages to create a complete pipeline. 
Each stage is responsible\n    for a specific part of the process, such as prompt encoding, latent preparation, etc.\n    \"\"\"\n\n    def __init__(self):\n        self.server_args = get_global_server_args()\n\n    def log_info(self, msg, *args):\n        \"\"\"Logs an informational message with the stage name as a prefix.\"\"\"\n        if self.server_args.comfyui_mode:\n            return\n        logger.info(f\"[{self.__class__.__name__}] {msg}\", *args)\n\n    def log_warning(self, msg, *args):\n        \"\"\"Logs a warning message with the stage name as a prefix.\"\"\"\n        logger.warning(f\"[{self.__class__.__name__}] {msg}\", *args)\n\n    def log_error(self, msg, *args):\n        \"\"\"Logs an error message with the stage name as a prefix.\"\"\"\n        logger.error(f\"[{self.__class__.__name__}] {msg}\", *args)\n\n    def log_debug(self, msg, *args):\n        \"\"\"Logs a debug message with the stage name as a prefix.\"\"\"\n        logger.debug(f\"[{self.__class__.__name__}] {msg}\", *args)\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"\n        Verify the input for the stage.\n\n        Example:\n            from sglang.multimodal_gen.runtime.pipelines_core.stages.validators import StageValidators as V, VerificationResult\n\n            def verify_input(self, batch, server_args):\n                result = VerificationResult()\n                result.add_check(\"height\", batch.height, V.positive_int_divisible(8))\n                result.add_check(\"width\", batch.width, V.positive_int_divisible(8))\n                result.add_check(\"image_latent\", batch.image_latent, V.is_tensor)\n                return result\n\n        \"\"\"\n        # Default implementation - no verification\n        return VerificationResult()\n\n    def maybe_free_model_hooks(self):\n        pass\n\n    def load_model(self):\n        \"\"\"\n        Load the model for the stage.\n        \"\"\"\n        pass\n\n    def offload_model(self):\n        \"\"\"\n        Offload the model for the stage.\n        \"\"\"\n        
pass\n\n    # Executes on all ranks by default\n    @property\n    def parallelism_type(self) -> StageParallelismType:\n        # if get_global_server_args().enable_cfg_parallel:\n        #     return StageParallelismType.MAIN_RANK_ONLY\n        return StageParallelismType.REPLICATED\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"\n        Verify the output for the stage.\n\n        Args:\n            batch: The batch returned by this stage's forward pass.\n            server_args: The server arguments.\n\n        Returns:\n            A VerificationResult containing the verification status.\n        \"\"\"\n        # Default implementation - no verification\n        return VerificationResult()\n\n    def _run_verification(\n        self,\n        verification_result: VerificationResult,\n        stage_name: str,\n        verification_type: str,\n    ) -> None:\n        \"\"\"\n        Run verification and raise errors if any checks fail.\n\n        Args:\n            verification_result: Results from verify_input or verify_output\n            stage_name: Name of the current stage\n            verification_type: \"input\" or \"output\"\n        \"\"\"\n        if not verification_result.is_valid():\n            failed_fields = verification_result.get_failed_fields()\n            if failed_fields:\n                # Get detailed failure information\n                detailed_summary = verification_result.get_failure_summary()\n\n                failed_fields_str = \", \".join(failed_fields)\n                error_msg = (\n                    f\"{verification_type.capitalize()} verification failed for {stage_name}: \"\n                    f\"Failed fields: {failed_fields_str}\\n\"\n                    f\"Details: {detailed_summary}\"\n                )\n                raise StageVerificationError(error_msg)\n\n    @property\n    def device(self) -> torch.device:\n        \"\"\"Get the device for this stage.\"\"\"\n        return torch.device(\n         
   current_platform.device_type,\n        )\n\n    def set_logging(self, enable: bool):\n        \"\"\"\n        Enable or disable logging for this stage.\n\n        Args:\n            enable: Whether to enable logging.\n        \"\"\"\n        self._enable_logging = enable\n\n    def __call__(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n        \"\"\"\n        Execute the stage's processing on the batch, wrapped in input/output\n        verification and profiling. Should not be overridden by subclasses.\n\n        Args:\n            batch: The batch to process.\n            server_args: The server arguments.\n\n        Returns:\n            The updated batch information after this stage's processing.\n        \"\"\"\n        stage_name = self.__class__.__name__\n\n        # Pre-execution input verification\n        try:\n            input_result = self.verify_input(batch, server_args)\n            self._run_verification(input_result, stage_name, \"input\")\n        except Exception as e:\n            logger.error(\"Input verification failed for %s: %s\", stage_name, str(e))\n            raise\n\n        # Execute the actual stage logic with unified profiling\n        with StageProfiler(\n            stage_name,\n            logger=logger,\n            metrics=batch.metrics,\n            log_stage_start_end=not batch.is_warmup\n            and not (self.server_args and self.server_args.comfyui_mode),\n            perf_dump_path_provided=batch.perf_dump_path is not None,\n        ):\n            result = self.forward(batch, server_args)\n\n        # Post-execution output verification\n        try:\n            output_result = self.verify_output(result, server_args)\n            self._run_verification(output_result, stage_name, \"output\")\n        except Exception as e:\n            logger.error(\"Output verification failed for %s: %s\", stage_name, str(e))\n            raise\n\n        return result\n\n    @abstractmethod\n    def forward(\n        self,\n        batch: 
Req,\n        server_args: ServerArgs,\n    ) -> Req:\n        \"\"\"\n        Forward pass of the stage's processing.\n\n        This method should be implemented by subclasses to provide the forward\n        processing logic for the stage.\n\n        Args:\n            batch: The batch to process.\n            server_args: The server arguments.\n\n        Returns:\n            The updated batch information after this stage's processing.\n        \"\"\"\n        raise NotImplementedError\n\n    def backward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n        raise NotImplementedError\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/causal_denoising.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nimport torch  # type: ignore\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context\nfrom sglang.multimodal_gen.runtime.models.utils import pred_noise_to_pred_video\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.denoising import DenoisingStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    StageValidators as V,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.platforms import (\n    AttentionBackendEnum,\n    current_platform,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass CausalDMDDenoisingStage(DenoisingStage):\n    \"\"\"\n    Denoising stage for causal diffusion.\n    \"\"\"\n\n    def __init__(self, transformer, scheduler) -> None:\n        super().__init__(transformer, scheduler)\n        # KV and cross-attention cache state (initialized on first forward)\n        self.kv_cache1: list | None = None\n        self.crossattn_cache: list | None = None\n        # Model-dependent constants (aligned with causal_inference.py assumptions)\n        self.num_transformer_blocks = self.transformer.config.arch_config.num_layers\n        self.num_frames_per_block = (\n            self.transformer.config.arch_config.num_frames_per_block\n        )\n        self.sliding_window_num_frames = (\n            self.transformer.config.arch_config.sliding_window_num_frames\n        )\n\n        try:\n            self.local_attn_size = getattr(\n                self.transformer.model, \"local_attn_size\", -1\n            )  # type: 
ignore\n        except Exception:\n            self.local_attn_size = -1\n\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n        target_dtype = torch.bfloat16\n        autocast_enabled = (\n            target_dtype != torch.float32\n        ) and not server_args.disable_autocast\n\n        latent_seq_length = batch.latents.shape[-1] * batch.latents.shape[-2]\n        patch_ratio = (\n            self.transformer.config.arch_config.patch_size[-1]\n            * self.transformer.config.arch_config.patch_size[-2]\n        )\n        self.frame_seq_length = latent_seq_length // patch_ratio\n        # TODO(will): make this a parameter once we add i2v support\n        independent_first_frame = self.transformer.independent_first_frame\n\n        # Timesteps for DMD\n        timesteps = torch.tensor(\n            server_args.pipeline_config.dmd_denoising_steps, dtype=torch.long\n        ).cpu()\n\n        if server_args.pipeline_config.warp_denoising_step:\n            logger.info(\"Warping timesteps...\")\n            scheduler_timesteps = torch.cat(\n                (self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32))\n            )\n            timesteps = scheduler_timesteps[1000 - timesteps]\n        timesteps = timesteps.to(get_local_torch_device())\n        logger.info(\"Using timesteps: %s\", timesteps)\n\n        # Image kwargs (kept empty unless caller provides compatible args)\n        image_kwargs: dict = {}\n\n        pos_cond_kwargs = self.prepare_extra_func_kwargs(\n            self.transformer.forward,\n            {\n                # \"encoder_hidden_states_2\": batch.clip_embedding_pos,\n                \"encoder_attention_mask\": batch.prompt_attention_mask,\n            },\n        )\n\n        # STA\n        if self.attn_backend.get_enum() == AttentionBackendEnum.SLIDING_TILE_ATTN:\n            self.prepare_sta_param(batch, server_args)\n\n        # Latents and prompts\n     
   assert batch.latents is not None, \"latents must be provided\"\n        latents = batch.latents  # [B, C, T, H, W]\n        b, c, t, h, w = latents.shape\n        prompt_embeds = batch.prompt_embeds\n        assert torch.isnan(prompt_embeds[0]).sum() == 0\n\n        # Initialize or reset caches\n        if self.kv_cache1 is None:\n            self._initialize_kv_cache(\n                batch_size=latents.shape[0], dtype=target_dtype, device=latents.device\n            )\n            self._initialize_crossattn_cache(\n                batch_size=latents.shape[0],\n                max_text_len=server_args.pipeline_config.text_encoder_configs[\n                    0\n                ].arch_config.text_len,\n                dtype=target_dtype,\n                device=latents.device,\n            )\n        else:\n            assert self.crossattn_cache is not None\n            # reset cross-attention cache\n            for block_index in range(self.num_transformer_blocks):\n                self.crossattn_cache[block_index][\"is_init\"] = False  # type: ignore\n            # reset kv cache pointers\n            for block_index in range(len(self.kv_cache1)):\n                self.kv_cache1[block_index][\"global_end_index\"] = (\n                    torch.tensor(  # type: ignore\n                        [0], dtype=torch.long, device=latents.device\n                    )\n                )\n                self.kv_cache1[block_index][\"local_end_index\"] = (\n                    torch.tensor(  # type: ignore\n                        [0], dtype=torch.long, device=latents.device\n                    )\n                )\n\n        # Optional: cache context features from provided image latents prior to generation\n        current_start_frame = 0\n        if getattr(batch, \"image_latent\", None) is not None:\n            image_latent = batch.image_latent\n            assert image_latent is not None\n            input_frames = image_latent.shape[2]\n            # timestep 
zero (or configured context noise) for cache warm-up\n            t_zero = torch.zeros(\n                [latents.shape[0]], device=latents.device, dtype=torch.long\n            )\n            if independent_first_frame and input_frames >= 1:\n                # warm-up with the very first frame independently\n                image_first_btchw = (\n                    image_latent[:, :, :1, :, :].to(target_dtype).permute(0, 2, 1, 3, 4)\n                )\n                with torch.autocast(\n                    device_type=current_platform.device_type,\n                    dtype=target_dtype,\n                    enabled=autocast_enabled,\n                ):\n                    _ = self.transformer(\n                        image_first_btchw,\n                        prompt_embeds,\n                        t_zero,\n                        kv_cache=self.kv_cache1,\n                        crossattn_cache=self.crossattn_cache,\n                        current_start=current_start_frame * self.frame_seq_length,\n                        **image_kwargs,\n                        **pos_cond_kwargs,\n                    )\n                current_start_frame += 1\n                remaining_frames = input_frames - 1\n            else:\n                remaining_frames = input_frames\n\n            # process remaining input frames in blocks of num_frames_per_block\n            while remaining_frames > 0:\n                block = min(self.num_frames_per_block, remaining_frames)\n                ref_btchw = (\n                    image_latent[\n                        :, :, current_start_frame : current_start_frame + block, :, :\n                    ]\n                    .to(target_dtype)\n                    .permute(0, 2, 1, 3, 4)\n                )\n                with torch.autocast(\n                    device_type=current_platform.device_type,\n                    dtype=target_dtype,\n                    enabled=autocast_enabled,\n                ):\n                    _ = self.transformer(\n                        ref_btchw,\n       
                 prompt_embeds,\n                        t_zero,\n                        kv_cache=self.kv_cache1,\n                        crossattn_cache=self.crossattn_cache,\n                        current_start=current_start_frame * self.frame_seq_length,\n                        **image_kwargs,\n                        **pos_cond_kwargs,\n                    )\n                current_start_frame += block\n                remaining_frames -= block\n\n        # Base position offset from any cache warm-up\n        pos_start_base = current_start_frame\n\n        # Determine block sizes\n        if not independent_first_frame or batch.image_latent is not None:\n            if t % self.num_frames_per_block != 0:\n                raise ValueError(\n                    \"num_frames must be divisible by num_frames_per_block for causal DMD denoising\"\n                )\n            num_blocks = t // self.num_frames_per_block\n            block_sizes = [self.num_frames_per_block] * num_blocks\n            start_index = 0\n        else:\n            if (t - 1) % self.num_frames_per_block != 0:\n                raise ValueError(\n                    \"(num_frames - 1) must be divisible by num_frames_per_block when independent_first_frame=True\"\n                )\n            num_blocks = (t - 1) // self.num_frames_per_block\n            block_sizes = [1] + [self.num_frames_per_block] * num_blocks\n            start_index = 0\n\n        # DMD loop in causal blocks\n        with self.progress_bar(total=len(block_sizes) * len(timesteps)) as progress_bar:\n            for current_num_frames in block_sizes:\n                current_latents = latents[\n                    :, :, start_index : start_index + current_num_frames, :, :\n                ]\n                # use BTCHW for DMD conversion routines\n                noise_latents_btchw = 
current_latents.permute(0, 2, 1, 3, 4)\n                video_raw_latent_shape = noise_latents_btchw.shape\n\n                for i, t_cur in enumerate(timesteps):\n                    # Copy for pred conversion\n                    noise_latents = noise_latents_btchw.clone()\n                    latent_model_input = current_latents.to(target_dtype)\n\n                    if (\n                        batch.image_latent is not None\n                        and independent_first_frame\n                        and start_index == 0\n                    ):\n                        latent_model_input = torch.cat(\n                            [latent_model_input, batch.image_latent.to(target_dtype)],\n                            dim=2,\n                        )\n\n                    # Prepare inputs\n                    t_expand = t_cur.repeat(latent_model_input.shape[0])\n\n                    # Attention metadata if needed\n                    if (\n                        self.attn_backend.get_enum()\n                        == AttentionBackendEnum.VIDEO_SPARSE_ATTN\n                    ):\n                        self.attn_metadata_builder_cls = (\n                            self.attn_backend.get_builder_cls()\n                        )\n                        if self.attn_metadata_builder_cls is not None:\n                            self.attn_metadata_builder = (\n                                self.attn_metadata_builder_cls()\n                            )\n                            attn_metadata = self.attn_metadata_builder.build(  # type: ignore\n                                current_timestep=i,  # type: ignore\n                                raw_latent_shape=(\n                                    current_num_frames,\n                                    h,\n                                    w,\n                                ),  # type: ignore\n                                patch_size=server_args.pipeline_config.dit_config.patch_size,  # type: 
ignore\n                                STA_param=batch.STA_param,  # type: ignore\n                                VSA_sparsity=server_args.attention_backend_config.VSA_sparsity,  # type: ignore\n                                device=get_local_torch_device(),  # type: ignore\n                            )  # type: ignore\n                            assert (\n                                attn_metadata is not None\n                            ), \"attn_metadata cannot be None\"\n                        else:\n                            attn_metadata = None\n                    else:\n                        attn_metadata = None\n\n                    with (\n                        torch.autocast(\n                            device_type=current_platform.device_type,\n                            dtype=target_dtype,\n                            enabled=autocast_enabled,\n                        ),\n                        set_forward_context(\n                            current_timestep=i,\n                            attn_metadata=attn_metadata,\n                            forward_batch=batch,\n                        ),\n                    ):\n                        # Run transformer; follow DMD stage pattern\n                        t_expanded_noise = t_cur * torch.ones(\n                            (latent_model_input.shape[0], 1),\n                            device=latent_model_input.device,\n                            dtype=torch.long,\n                        )\n                        pred_noise_btchw = self.transformer(\n                            latent_model_input,\n                            prompt_embeds,\n                            t_expanded_noise,\n                            kv_cache=self.kv_cache1,\n                            crossattn_cache=self.crossattn_cache,\n                            current_start=(pos_start_base + start_index)\n                            * self.frame_seq_length,\n                            
start_frame=start_index,\n                            **image_kwargs,\n                            **pos_cond_kwargs,\n                        ).permute(0, 2, 1, 3, 4)\n\n                    # Convert pred noise to pred video with FM Euler scheduler utilities\n                    pred_video_btchw = pred_noise_to_pred_video(\n                        pred_noise=pred_noise_btchw.flatten(0, 1),\n                        noise_input_latent=noise_latents.flatten(0, 1),\n                        timestep=t_expand,\n                        scheduler=self.scheduler,\n                    ).unflatten(0, pred_noise_btchw.shape[:2])\n\n                    if i < len(timesteps) - 1:\n                        next_timestep = timesteps[i + 1] * torch.ones(\n                            [1], dtype=torch.long, device=pred_video_btchw.device\n                        )\n                        noise = torch.randn(\n                            video_raw_latent_shape,\n                            dtype=pred_video_btchw.dtype,\n                            generator=(\n                                batch.generator[0]\n                                if isinstance(batch.generator, list)\n                                else batch.generator\n                            ),\n                            device=self.device,\n                        )\n                        noise_btchw = noise\n                        noise_latents_btchw = self.scheduler.add_noise(\n                            pred_video_btchw.flatten(0, 1),\n                            noise_btchw.flatten(0, 1),\n                            next_timestep,\n                        ).unflatten(0, pred_video_btchw.shape[:2])\n                        current_latents = noise_latents_btchw.permute(0, 2, 1, 3, 4)\n                    else:\n                        current_latents = pred_video_btchw.permute(0, 2, 1, 3, 4)\n\n                    if progress_bar is not None:\n                        progress_bar.update()\n\n              
  # Write back and advance\n                latents[:, :, start_index : start_index + current_num_frames, :, :] = (\n                    current_latents\n                )\n\n                # Re-run with context timestep to update KV cache using clean context\n                context_noise = getattr(server_args.pipeline_config, \"context_noise\", 0)\n                t_context = torch.ones(\n                    [latents.shape[0]], device=latents.device, dtype=torch.long\n                ) * int(context_noise)\n                context_bcthw = current_latents.to(target_dtype)\n                with (\n                    torch.autocast(\n                        device_type=current_platform.device_type,\n                        dtype=target_dtype,\n                        enabled=autocast_enabled,\n                    ),\n                    set_forward_context(\n                        current_timestep=0,\n                        attn_metadata=attn_metadata,\n                        forward_batch=batch,\n                    ),\n                ):\n                    t_expanded_context = t_context.unsqueeze(1)\n                    _ = self.transformer(\n                        context_bcthw,\n                        prompt_embeds,\n                        t_expanded_context,\n                        kv_cache=self.kv_cache1,\n                        crossattn_cache=self.crossattn_cache,\n                        current_start=(pos_start_base + start_index)\n                        * self.frame_seq_length,\n                        start_frame=start_index,\n                        **image_kwargs,\n                        **pos_cond_kwargs,\n                    )\n                start_index += current_num_frames\n\n        batch.latents = latents\n        return batch\n\n    def _initialize_kv_cache(self, batch_size, dtype, device) -> None:\n        \"\"\"\n        Initialize a Per-GPU KV cache aligned with the Wan model assumptions.\n        \"\"\"\n        kv_cache1 = 
[]\n        num_attention_heads = self.transformer.num_attention_heads\n        attention_head_dim = self.transformer.attention_head_dim\n        if self.local_attn_size != -1:\n            kv_cache_size = self.local_attn_size * self.frame_seq_length\n        else:\n            kv_cache_size = self.frame_seq_length * self.sliding_window_num_frames\n\n        for _ in range(self.num_transformer_blocks):\n            kv_cache1.append(\n                {\n                    \"k\": torch.zeros(\n                        [\n                            batch_size,\n                            kv_cache_size,\n                            num_attention_heads,\n                            attention_head_dim,\n                        ],\n                        dtype=dtype,\n                        device=device,\n                    ),\n                    \"v\": torch.zeros(\n                        [\n                            batch_size,\n                            kv_cache_size,\n                            num_attention_heads,\n                            attention_head_dim,\n                        ],\n                        dtype=dtype,\n                        device=device,\n                    ),\n                    \"global_end_index\": torch.tensor(\n                        [0], dtype=torch.long, device=device\n                    ),\n                    \"local_end_index\": torch.tensor(\n                        [0], dtype=torch.long, device=device\n                    ),\n                }\n            )\n\n        self.kv_cache1 = kv_cache1\n\n    def _initialize_crossattn_cache(\n        self, batch_size, max_text_len, dtype, device\n    ) -> None:\n        \"\"\"\n        Initialize a Per-GPU cross-attention cache aligned with the Wan model assumptions.\n        \"\"\"\n        crossattn_cache = []\n        num_attention_heads = self.transformer.num_attention_heads\n        attention_head_dim = self.transformer.attention_head_dim\n        for _ in 
range(self.num_transformer_blocks):\n            crossattn_cache.append(\n                {\n                    \"k\": torch.zeros(\n                        [\n                            batch_size,\n                            max_text_len,\n                            num_attention_heads,\n                            attention_head_dim,\n                        ],\n                        dtype=dtype,\n                        device=device,\n                    ),\n                    \"v\": torch.zeros(\n                        [\n                            batch_size,\n                            max_text_len,\n                            num_attention_heads,\n                            attention_head_dim,\n                        ],\n                        dtype=dtype,\n                        device=device,\n                    ),\n                    \"is_init\": False,\n                }\n            )\n        self.crossattn_cache = crossattn_cache\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify denoising stage inputs.\"\"\"\n        result = VerificationResult()\n        result.add_check(\"latents\", batch.latents, [V.is_tensor, V.with_dims(5)])\n        result.add_check(\"prompt_embeds\", batch.prompt_embeds, V.list_not_empty)\n        result.add_check(\"image_embeds\", batch.image_embeds, V.is_list)\n        result.add_check(\n            \"image_latent\", batch.image_latent, V.none_or_tensor_with_dims(5)\n        )\n        result.add_check(\n            \"num_inference_steps\", batch.num_inference_steps, V.positive_int\n        )\n        result.add_check(\"guidance_scale\", batch.guidance_scale, V.non_negative_float)\n        result.add_check(\"eta\", batch.eta, V.non_negative_float)\n        result.add_check(\"generator\", batch.generator, V.generator_or_list_generators)\n        result.add_check(\n            \"do_classifier_free_guidance\",\n            
batch.do_classifier_free_guidance,\n            V.bool_value,\n        )\n        result.add_check(\n            \"negative_prompt_embeds\",\n            batch.negative_prompt_embeds,\n            lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x),\n        )\n        return result\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/comfyui_latent_preparation.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nComfyUI latent preparation stage with device mismatch fix.\nThis stage extends LatentPreparationStage to handle device mismatch issues\nthat occur when tensors are pickled and unpickled via broadcast_pyobj in\nmulti-GPU scenarios.\n\"\"\"\n\nimport dataclasses\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.distributed import (\n    get_local_torch_device,\n    get_sp_group,\n)\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.latent_preparation import (\n    LatentPreparationStage,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass ComfyUILatentPreparationStage(LatentPreparationStage):\n    \"\"\"\n    ComfyUI-specific latent preparation stage with device mismatch fix.\n\n    This stage extends LatentPreparationStage to automatically fix device\n    mismatches for tensor fields on non-source ranks in multi-GPU scenarios.\n    \"\"\"\n\n    @staticmethod\n    def _fix_tensor_device(value, target_device):\n        \"\"\"Recursively fix tensor device, handling single tensors, lists, and tuples.\"\"\"\n        if isinstance(value, torch.Tensor):\n            if value.device != target_device:\n                return value.detach().clone().to(target_device)\n            return value\n        elif isinstance(value, list):\n            return [\n                ComfyUILatentPreparationStage._fix_tensor_device(v, target_device)\n                for v in value\n            ]\n        elif isinstance(value, tuple):\n            return tuple(\n                ComfyUILatentPreparationStage._fix_tensor_device(v, target_device)\n                for v in value\n            )\n        return value\n\n    @staticmethod\n    def _has_tensor(value):\n        \"\"\"Check if value contains any tensor.\"\"\"\n        if isinstance(value, torch.Tensor):\n 
           return True\n        elif isinstance(value, (list, tuple)):\n            return any(ComfyUILatentPreparationStage._has_tensor(v) for v in value)\n        return False\n\n    def forward(self, batch, server_args):\n        \"\"\"\n        Prepare latents with device mismatch fix for ComfyUI pipelines.\n\n        This method first fixes device mismatches for all tensor fields,\n        then calls the parent class's forward method, and ensures raw_latent_shape\n        is set correctly (before packing, for proper unpadding later).\n        \"\"\"\n        # Fix device mismatch for tensor fields on non-source ranks\n        if get_sp_world_size() > 1:\n            sp_group = get_sp_group()\n            target_device = get_local_torch_device()\n\n            if sp_group.rank != 0:\n                logger.debug(\n                    f\"[ComfyUILatentPreparationStage] Fixing tensor device on rank={sp_group.rank} \"\n                    f\"target_device={target_device}\"\n                )\n\n                if dataclasses.is_dataclass(batch):\n                    for field in dataclasses.fields(batch):\n                        value = getattr(batch, field.name, None)\n                        if value is not None and self._has_tensor(value):\n                            fixed_value = self._fix_tensor_device(value, target_device)\n                            setattr(batch, field.name, fixed_value)\n                else:\n                    for attr_name in dir(batch):\n                        if not attr_name.startswith(\"_\") and not callable(\n                            getattr(batch, attr_name, None)\n                        ):\n                            try:\n                                value = getattr(batch, attr_name, None)\n                                if value is not None and self._has_tensor(value):\n                                    fixed_value = self._fix_tensor_device(\n                                        value, target_device\n        
                            )\n                                    setattr(batch, attr_name, fixed_value)\n                            except (AttributeError, TypeError):\n                                continue\n\n        original_latents_shape = None\n        if batch.latents is not None:\n            original_latents_shape = batch.latents.shape\n\n        # Call parent class's forward method\n        result = super().forward(batch, server_args)\n\n        if original_latents_shape is not None:\n            # Preserve the original shape before any potential packing/conversion\n            # (e.g., 4D spatial -> 3D sequence) to ensure proper unpadding later.\n            result.raw_latent_shape = original_latents_shape\n\n        return result\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nDecoding stage for diffusion pipelines.\n\"\"\"\n\nimport weakref\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.loader.component_loaders.vae_loader import VAELoader\nfrom sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import (\n    PipelineStage,\n    StageParallelismType,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\n\nlogger = init_logger(__name__)\n\n\ndef _ensure_tensor_decode_output(decode_output):\n    \"\"\"\n    Ensure VAE decode output is a tensor.\n\n    Some VAE implementations return DecoderOutput objects with a .sample attribute,\n    tuples, or tensors directly. 
This function normalizes the output to always be a tensor.\n\n    Args:\n        decode_output: Output from VAE.decode(), can be DecoderOutput, tuple, or torch.Tensor\n\n    Returns:\n        torch.Tensor: The decoded image tensor\n    \"\"\"\n    if isinstance(decode_output, tuple):\n        return decode_output[0]\n    if hasattr(decode_output, \"sample\"):\n        return decode_output.sample\n    return decode_output\n\n\nclass DecodingStage(PipelineStage):\n    \"\"\"\n    Stage for decoding latent representations into pixel space.\n\n    This stage handles the decoding of latent representations into the final\n    output format (e.g., pixel values).\n    \"\"\"\n\n    def __init__(self, vae, pipeline=None, component_name: str = \"vae\") -> None:\n        super().__init__()\n        self.vae: ParallelTiledVAE = vae\n        self.pipeline = weakref.ref(pipeline) if pipeline else None\n        self.component_name = component_name\n\n    @property\n    def parallelism_type(self) -> StageParallelismType:\n        if get_global_server_args().enable_cfg_parallel:\n            return StageParallelismType.MAIN_RANK_ONLY\n        return StageParallelismType.REPLICATED\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify decoding stage inputs.\"\"\"\n        result = VerificationResult()\n        # Denoised latents for VAE decoding: [batch_size, channels, frames, height_latents, width_latents]\n        # result.add_check(\"latents\", batch.latents, [V.is_tensor, V.with_dims(5)])\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify decoding stage outputs.\"\"\"\n        result = VerificationResult()\n        # Decoded video/images: [batch_size, channels, frames, height, width]\n        # result.add_check(\"output\", batch.output, [V.is_tensor, V.with_dims(5)])\n        return result\n\n    def scale_and_shift(self, latents: 
torch.Tensor, server_args):\n        scaling_factor, shift_factor = (\n            server_args.pipeline_config.get_decode_scale_and_shift(\n                latents.device, latents.dtype, self.vae\n            )\n        )\n\n        # 1. scale\n        if isinstance(scaling_factor, torch.Tensor):\n            latents = latents / scaling_factor.to(latents.device, latents.dtype)\n        else:\n            latents = latents / scaling_factor\n\n        # 2. apply shifting if needed\n        if shift_factor is not None:\n            if isinstance(shift_factor, torch.Tensor):\n                latents += shift_factor.to(latents.device, latents.dtype)\n            else:\n                latents += shift_factor\n        return latents\n\n    @torch.no_grad()\n    def decode(self, latents: torch.Tensor, server_args: ServerArgs) -> torch.Tensor:\n        \"\"\"\n        Decode latent representations into pixel space using VAE.\n\n        Args:\n            latents: Input latent tensor with shape (batch, channels, frames, height_latents, width_latents)\n            server_args: Configuration containing:\n                - disable_autocast: Whether to disable automatic mixed precision (default: False)\n                - pipeline_config.vae_precision: VAE computation precision (\"fp32\", \"fp16\", \"bf16\")\n                - pipeline_config.vae_tiling: Whether to enable VAE tiling for memory efficiency\n\n        Returns:\n            Decoded video tensor with shape (batch, channels, frames, height, width),\n            normalized to the [0, 1] range and left on the local device\n        \"\"\"\n        vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]\n        self.vae = self.vae.to(device=get_local_torch_device(), dtype=vae_dtype)\n        latents = latents.to(get_local_torch_device())\n        vae_autocast_enabled = (\n            vae_dtype != torch.float32\n        ) and not server_args.disable_autocast\n\n        # scale and shift\n        latents = 
self.scale_and_shift(latents, server_args)\n        # Preprocess latents before decoding (e.g., unpatchify for standard Flux2 VAE)\n        latents = server_args.pipeline_config.preprocess_decoding(\n            latents, server_args, vae=self.vae\n        )\n\n        # Decode latents\n        with torch.autocast(\n            device_type=current_platform.device_type,\n            dtype=vae_dtype,\n            enabled=vae_autocast_enabled,\n        ):\n            try:\n                # TODO: make it more specific\n                if server_args.pipeline_config.vae_tiling:\n                    self.vae.enable_tiling()\n            except Exception:\n                pass\n            if not vae_autocast_enabled:\n                latents = latents.to(vae_dtype)\n            decode_output = self.vae.decode(latents)\n            image = _ensure_tensor_decode_output(decode_output)\n\n        # De-normalize image to [0, 1] range\n        image = (image / 2 + 0.5).clamp(0, 1)\n        return image\n\n    def load_model(self):\n        # load vae if not already loaded (used for memory constrained devices)\n        pipeline = self.pipeline() if self.pipeline else None\n        if not self.server_args.model_loaded[self.component_name]:\n            loader = VAELoader()\n            self.vae, _ = loader.load(\n                self.server_args.model_paths[self.component_name],\n                self.server_args,\n                component_name=self.component_name,\n                transformers_or_diffusers=loader.expected_library,\n            )\n            if pipeline:\n                pipeline.add_module(self.component_name, self.vae)\n            self.server_args.model_loaded[self.component_name] = True\n\n    def offload_model(self):\n        # Offload models if needed\n        self.maybe_free_model_hooks()\n\n        if self.server_args.vae_cpu_offload:\n            self.vae.to(\"cpu\", non_blocking=True)\n\n        if torch.backends.mps.is_available():\n            # 
Flush lazy MPS kernels before freeing weights to avoid hangs.\n            torch.mps.synchronize()\n            del self.vae\n            pipeline = self.pipeline() if self.pipeline else None\n            if pipeline is not None and self.component_name in pipeline.modules:\n                del pipeline.modules[self.component_name]\n            self.server_args.model_loaded[self.component_name] = False\n\n    @torch.no_grad()\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> OutputBatch:\n        \"\"\"\n        Decode latent representations into pixel space.\n\n        This method processes the batch through the VAE decoder, converting latent\n        representations to pixel-space video/images. It also optionally decodes\n        trajectory latents for visualization purposes.\n\n        \"\"\"\n        # load vae if not already loaded (used for memory constrained devices)\n        self.load_model()\n\n        frames = self.decode(batch.latents, server_args)\n\n        # decode trajectory latents if needed\n        if batch.return_trajectory_decoded:\n            assert (\n                batch.trajectory_latents is not None\n            ), \"batch should have trajectory latents\"\n\n            # 1. Batch trajectory decoding to improve GPU utilization\n            # batch.trajectory_latents is [batch_size, timesteps, channels, frames, height, width]\n            B, T, C, F, H, W = batch.trajectory_latents.shape\n            flat_latents = batch.trajectory_latents.view(B * T, C, F, H, W)\n\n            logger.info(\"decoding %s trajectory latents in batch\", B * T)\n            # Use the optimized batch decode\n            all_decoded = self.decode(flat_latents, server_args)\n\n            # 2. 
Reshape back\n            # Keep on GPU to allow faster vectorized post-processing\n            decoded_tensor = all_decoded.view(B, T, *all_decoded.shape[1:])\n\n            # Convert to list of tensors (per timestep) as expected by OutputBatch\n            # Each element in list is [B, channels, frames, H_out, W_out]\n            trajectory_decoded = [decoded_tensor[:, i] for i in range(T)]\n        else:\n            trajectory_decoded = None\n\n        frames = server_args.pipeline_config.post_decoding(frames, server_args)\n\n        # Update batch with decoded image\n        output_batch = OutputBatch(\n            output=frames,\n            trajectory_timesteps=batch.trajectory_timesteps,\n            trajectory_latents=batch.trajectory_latents,\n            trajectory_decoded=trajectory_decoded,\n            metrics=batch.metrics,\n        )\n\n        # Keep VAE resident during warmup; the real request needs it next.\n        if not getattr(batch, \"is_warmup\", False):\n            self.offload_model()\n\n        return output_batch\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding_av.py",
    "content": "import torch\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.decoding import DecodingStage\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\n\nlogger = init_logger(__name__)\n\n\nclass LTX2AVDecodingStage(DecodingStage):\n    \"\"\"\n    LTX-2 specific decoding stage that handles both video and audio decoding.\n    \"\"\"\n\n    def __init__(self, vae, audio_vae, vocoder, pipeline=None):\n        super().__init__(vae, pipeline)\n        self.audio_vae = audio_vae\n        self.vocoder = vocoder\n        # Add video processor for postprocessing\n        from diffusers.video_processor import VideoProcessor\n\n        self.video_processor = VideoProcessor(vae_scale_factor=32)\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> OutputBatch:\n        self.load_model()\n\n        self.vae = self.vae.to(get_local_torch_device())\n        self.vae.eval()\n        latents = batch.latents.to(get_local_torch_device())\n\n        vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]\n        vae_autocast_enabled = (\n            vae_dtype != torch.float32\n        ) and not server_args.disable_autocast\n\n        original_dtype = vae_dtype\n        self.vae.to(torch.bfloat16)\n        latents = latents.to(torch.bfloat16)\n        std = self.vae.latents_std.view(1, -1, 1, 1, 1).to(latents)\n        mean = self.vae.latents_mean.view(1, -1, 1, 1, 1).to(latents)\n        latents = latents * std + mean\n        latents = server_args.pipeline_config.preprocess_decoding(\n            latents, server_args, vae=self.vae\n        )\n\n        
with torch.autocast(\n            device_type=current_platform.device_type,\n            dtype=vae_dtype,\n            enabled=vae_autocast_enabled,\n        ):\n            try:\n                if server_args.pipeline_config.vae_tiling:\n                    self.vae.enable_tiling()\n            except Exception:\n                pass\n            decode_output = self.vae.decode(latents)\n            if isinstance(decode_output, tuple):\n                video = decode_output[0]\n            elif hasattr(decode_output, \"sample\"):\n                video = decode_output.sample\n            else:\n                video = decode_output\n\n        self.vae.to(original_dtype)\n        video = self.video_processor.postprocess_video(video, output_type=\"np\")\n\n        output_batch = OutputBatch(\n            output=video,\n            trajectory_timesteps=batch.trajectory_timesteps,\n            trajectory_latents=batch.trajectory_latents,\n            trajectory_decoded=None,\n            metrics=batch.metrics,\n        )\n\n        # 2. 
Decode Audio\n        try:\n            audio_latents = batch.audio_latents\n        except AttributeError:\n            audio_latents = None\n        if audio_latents is not None:\n            # Ensure device/dtype\n            device = get_local_torch_device()\n            self.audio_vae = self.audio_vae.to(device)\n            self.vocoder = self.vocoder.to(device)\n            self.audio_vae.eval()\n            self.vocoder.eval()\n            try:\n                dtype = self.audio_vae.dtype\n            except AttributeError:\n                dtype = None\n            if dtype is None:\n                try:\n                    dtype = next(self.audio_vae.parameters()).dtype\n                except StopIteration:\n                    dtype = torch.float32\n            audio_latents = audio_latents.to(device, dtype=dtype)\n            try:\n                latents_std = self.audio_vae.latents_std\n            except AttributeError:\n                latents_std = None\n            if isinstance(latents_std, torch.Tensor) and torch.all(latents_std == 0):\n                logger.warning(\n                    \"audio_vae.latents_std is all zeros; audio denorm may be incorrect.\"\n                )\n\n            with torch.no_grad():\n                # Decode latents to spectrogram\n                spectrogram = self.audio_vae.decode(audio_latents, return_dict=False)[0]\n                if hasattr(self.vocoder, \"conv_in\") and hasattr(\n                    self.vocoder.conv_in, \"in_channels\"\n                ):\n                    expected_in = int(self.vocoder.conv_in.in_channels)\n                    actual_in = int(spectrogram.shape[1]) * int(spectrogram.shape[3])\n                    if actual_in != expected_in:\n                        raise ValueError(\n                            f\"Vocoder expects channels*mel_bins={expected_in}, got {actual_in} from spectrogram shape {tuple(spectrogram.shape)}\"\n                        )\n                # Decode 
spectrogram to waveform\n                waveform = self.vocoder(spectrogram)\n            output_batch.audio = waveform.cpu().float()\n            try:\n                pipeline_audio_cfg = server_args.pipeline_config.audio_vae_config\n            except AttributeError:\n                pipeline_audio_cfg = None\n            try:\n                pipeline_audio_arch = pipeline_audio_cfg.arch_config  # type: ignore[union-attr]\n            except AttributeError:\n                pipeline_audio_arch = None\n            try:\n                pipeline_audio_sr = pipeline_audio_arch.sample_rate  # type: ignore[union-attr]\n            except AttributeError:\n                pipeline_audio_sr = None\n\n            try:\n                vocoder_sr = self.vocoder.sample_rate\n            except AttributeError:\n                vocoder_sr = None\n            try:\n                audio_vae_sr = self.audio_vae.sample_rate\n            except AttributeError:\n                audio_vae_sr = None\n            output_batch.audio_sample_rate = (\n                vocoder_sr or audio_vae_sr or pipeline_audio_sr\n            )\n\n        self.offload_model()\n        return output_batch\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nDenoising stage for diffusion pipelines.\n\"\"\"\n\nimport inspect\nimport math\nimport os\nimport time\nimport weakref\nfrom collections.abc import Iterable\nfrom functools import lru_cache\nfrom typing import Any\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom tqdm.auto import tqdm\n\nfrom sglang.multimodal_gen import envs\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import ModelTaskType, STA_Mode\nfrom sglang.multimodal_gen.configs.pipeline_configs.wan import (\n    Wan2_2_TI2V_5B_Config,\n)\nfrom sglang.multimodal_gen.runtime.cache.cache_dit_integration import (\n    CacheDitConfig,\n    enable_cache_on_dual_transformer,\n    enable_cache_on_transformer,\n    get_scm_mask,\n    refresh_context_on_dual_transformer,\n    refresh_context_on_transformer,\n)\nfrom sglang.multimodal_gen.runtime.distributed import (\n    cfg_model_parallel_all_reduce,\n    get_local_torch_device,\n    get_sp_group,\n    get_sp_parallel_rank,\n    get_sp_world_size,\n    get_tp_group,\n    get_world_group,\n    get_world_size,\n)\nfrom sglang.multimodal_gen.runtime.distributed.communication_op import (\n    sequence_model_parallel_all_gather,\n)\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_cfg_group,\n    get_classifier_free_guidance_rank,\n)\nfrom sglang.multimodal_gen.runtime.layers.attention.selector import get_attn_backend\nfrom sglang.multimodal_gen.runtime.layers.attention.STA_configuration import (\n    configure_sta,\n    save_mask_search_results,\n)\nfrom sglang.multimodal_gen.runtime.loader.component_loaders.transformer_loader import (\n    TransformerLoader,\n)\nfrom sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom 
sglang.multimodal_gen.runtime.pipelines_core.stages.base import (\n    PipelineStage,\n    StageParallelismType,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    StageValidators as V,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.platforms import (\n    AttentionBackendEnum,\n    current_platform,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler\nfrom sglang.multimodal_gen.runtime.utils.profiler import SGLDiffusionProfiler\nfrom sglang.multimodal_gen.utils import dict_to_3d_list, masks_like\nfrom sglang.srt.utils.common import get_compiler_backend\n\nlogger = init_logger(__name__)\n\n\nclass DenoisingStage(PipelineStage):\n    \"\"\"\n    Stage for running the denoising loop in diffusion pipelines.\n\n    This stage handles the iterative denoising process that transforms\n    the initial noise into the final output.\n    \"\"\"\n\n    def __init__(\n        self, transformer, scheduler, pipeline=None, transformer_2=None, vae=None\n    ) -> None:\n        super().__init__()\n        self.transformer = transformer\n        self.transformer_2 = transformer_2\n\n        hidden_size = self.server_args.pipeline_config.dit_config.hidden_size\n        num_attention_heads = (\n            self.server_args.pipeline_config.dit_config.num_attention_heads\n        )\n        attn_head_size = hidden_size // num_attention_heads\n\n        # torch compile\n        for transformer in filter(None, [self.transformer, self.transformer_2]):\n            self._maybe_enable_torch_compile(transformer)\n\n        self.scheduler = scheduler\n        self.vae = vae\n        self.pipeline = 
weakref.ref(pipeline) if pipeline else None\n\n        # TODO(will): hack, should use the actual one in dit\n        self.attn_backend = get_attn_backend(\n            head_size=attn_head_size,\n            dtype=torch.float16,\n        )\n\n        # cfg\n        self.guidance = None\n\n        # misc\n        self.profiler = None\n        # cache-dit state (for delayed mounting and idempotent control)\n        self._cache_dit_enabled = False\n        self._cached_num_steps = None\n        self._is_warmed_up = False\n\n    def _maybe_enable_torch_compile(self, module: object) -> None:\n        \"\"\"\n        Compile a module with torch.compile, and enable inductor overlap tweak if available.\n        No-op if torch compile is disabled or the object is not a nn.Module.\n        \"\"\"\n        if not self.server_args.enable_torch_compile or not isinstance(\n            module, nn.Module\n        ):\n            return\n        compile_kwargs: dict[str, Any] = {\"fullgraph\": False, \"dynamic\": None}\n\n        if current_platform.is_npu():\n            backend = get_compiler_backend()\n            compile_kwargs[\"backend\"] = backend\n            compile_kwargs[\"dynamic\"] = False\n            logger.info(\"Compiling transformer with torchair backend on NPU\")\n        else:\n            try:\n                import torch._inductor.config as _inductor_cfg\n\n                _inductor_cfg.reorder_for_compute_comm_overlap = True\n            except ImportError:\n                pass\n            mode = os.environ.get(\n                \"SGLANG_TORCH_COMPILE_MODE\", \"max-autotune-no-cudagraphs\"\n            )\n            compile_kwargs[\"mode\"] = mode\n            logger.info(f\"Compiling transformer with mode: {mode}\")\n\n        # TODO(triple-mu): support customized fullgraph and dynamic in the future\n        module.compile(**compile_kwargs)\n\n    def _maybe_enable_cache_dit(\n        self, num_inference_steps: int | tuple[int, int], batch: Req\n    ) -> 
None:\n        \"\"\"Enable cache-dit on the transformers if configured (idempotent).\n\n        This method should be called after the transformer is fully loaded\n        and before torch.compile is applied.\n\n        For dual-transformer models (e.g., Wan2.2), this enables cache-dit on both\n        transformers with (potentially) different configurations.\n\n        \"\"\"\n        if isinstance(num_inference_steps, tuple):\n            num_high_noise_steps, num_low_noise_steps = num_inference_steps\n\n        # NOTE: When a new request arrives, we need to refresh the cache-dit context.\n        if self._cache_dit_enabled:\n            scm_preset = envs.SGLANG_CACHE_DIT_SCM_PRESET\n            scm_preset = None if scm_preset == \"none\" else scm_preset\n            if isinstance(num_inference_steps, tuple):\n                refresh_context_on_dual_transformer(\n                    self.transformer,\n                    self.transformer_2,\n                    num_high_noise_steps,\n                    num_low_noise_steps,\n                    scm_preset=scm_preset,\n                )\n            else:\n                refresh_context_on_transformer(\n                    self.transformer,\n                    num_inference_steps,\n                    scm_preset=scm_preset,\n                )\n            return\n\n        # check if cache-dit is enabled in config\n        if not envs.SGLANG_CACHE_DIT_ENABLED or batch.is_warmup:\n            return\n\n        world_size = get_world_size()\n        parallelized = world_size > 1\n\n        sp_group = None\n        tp_group = None\n        if parallelized:\n            sp_group_candidate = get_sp_group()\n            tp_group_candidate = get_tp_group()\n\n            sp_world_size = sp_group_candidate.world_size if sp_group_candidate else 1\n            tp_world_size = tp_group_candidate.world_size if tp_group_candidate else 1\n\n            has_sp = sp_world_size > 1\n            has_tp = tp_world_size > 1\n\n    
        sp_group = sp_group_candidate.device_group if has_sp else None\n            tp_group = tp_group_candidate.device_group if has_tp else None\n\n            logger.info(\n                \"cache-dit enabled in distributed environment (world_size=%d, has_sp=%s, has_tp=%s)\",\n                world_size,\n                has_sp,\n                has_tp,\n            )\n        # === Parse SCM configuration from envs ===\n        # SCM is shared between primary and secondary transformers\n        scm_preset = envs.SGLANG_CACHE_DIT_SCM_PRESET\n        scm_compute_bins_str = envs.SGLANG_CACHE_DIT_SCM_COMPUTE_BINS\n        scm_cache_bins_str = envs.SGLANG_CACHE_DIT_SCM_CACHE_BINS\n        scm_policy = envs.SGLANG_CACHE_DIT_SCM_POLICY\n\n        # parse custom bins if provided (both must be set together)\n        scm_compute_bins = None\n        scm_cache_bins = None\n        if scm_compute_bins_str and scm_cache_bins_str:\n            try:\n                scm_compute_bins = [\n                    int(x.strip()) for x in scm_compute_bins_str.split(\",\")\n                ]\n                scm_cache_bins = [int(x.strip()) for x in scm_cache_bins_str.split(\",\")]\n            except ValueError as e:\n                logger.warning(\"Failed to parse SCM bins: %s. SCM disabled.\", e)\n                scm_preset = \"none\"\n        elif scm_compute_bins_str or scm_cache_bins_str:\n            # Only one of the bins was provided - warn user\n            logger.warning(\n                \"SCM custom bins require both compute_bins and cache_bins. \"\n                \"Only one was provided (compute=%s, cache=%s). 
Falling back to preset '%s'.\",\n                scm_compute_bins_str,\n                scm_cache_bins_str,\n                scm_preset,\n            )\n\n        # generate SCM mask using cache-dit's steps_mask()\n        # cache-dit handles step count validation and scaling internally\n        steps_computation_mask = get_scm_mask(\n            preset=scm_preset,\n            num_inference_steps=(\n                num_inference_steps\n                if isinstance(num_inference_steps, int)\n                else num_high_noise_steps\n            ),\n            compute_bins=scm_compute_bins,\n            cache_bins=scm_cache_bins,\n        )\n\n        if isinstance(num_inference_steps, tuple):\n            steps_computation_mask_2 = get_scm_mask(\n                preset=scm_preset,\n                num_inference_steps=num_low_noise_steps,\n                compute_bins=scm_compute_bins,\n                cache_bins=scm_cache_bins,\n            )\n\n        # build config for primary transformer (high-noise expert)\n        primary_config = CacheDitConfig(\n            enabled=True,\n            Fn_compute_blocks=envs.SGLANG_CACHE_DIT_FN,\n            Bn_compute_blocks=envs.SGLANG_CACHE_DIT_BN,\n            max_warmup_steps=envs.SGLANG_CACHE_DIT_WARMUP,\n            residual_diff_threshold=envs.SGLANG_CACHE_DIT_RDT,\n            max_continuous_cached_steps=envs.SGLANG_CACHE_DIT_MC,\n            enable_taylorseer=envs.SGLANG_CACHE_DIT_TAYLORSEER,\n            taylorseer_order=envs.SGLANG_CACHE_DIT_TS_ORDER,\n            num_inference_steps=(\n                num_inference_steps\n                if isinstance(num_inference_steps, int)\n                else num_high_noise_steps\n            ),\n            # SCM fields\n            steps_computation_mask=steps_computation_mask,\n            steps_computation_policy=scm_policy,\n        )\n\n        if self.transformer_2 is not None:\n            # dual transformer\n            # build config for secondary transformer 
(low-noise expert)\n            # uses secondary parameters which inherit from primary if not explicitly set\n            secondary_config = CacheDitConfig(\n                enabled=True,\n                Fn_compute_blocks=envs.SGLANG_CACHE_DIT_SECONDARY_FN,\n                Bn_compute_blocks=envs.SGLANG_CACHE_DIT_SECONDARY_BN,\n                max_warmup_steps=envs.SGLANG_CACHE_DIT_SECONDARY_WARMUP,\n                residual_diff_threshold=envs.SGLANG_CACHE_DIT_SECONDARY_RDT,\n                max_continuous_cached_steps=envs.SGLANG_CACHE_DIT_SECONDARY_MC,\n                enable_taylorseer=envs.SGLANG_CACHE_DIT_SECONDARY_TAYLORSEER,\n                taylorseer_order=envs.SGLANG_CACHE_DIT_SECONDARY_TS_ORDER,\n                num_inference_steps=num_low_noise_steps,\n                # SCM fields - shared with primary\n                steps_computation_mask=steps_computation_mask_2,\n                steps_computation_policy=scm_policy,\n            )\n\n            # for dual transformers, must use BlockAdapter to enable cache on both simultaneously.\n            # Don't call enable_cache separately on each transformer.\n            self.transformer, self.transformer_2 = enable_cache_on_dual_transformer(\n                self.transformer,\n                self.transformer_2,\n                primary_config,\n                secondary_config,\n                model_name=\"wan2.2\",\n                sp_group=sp_group,\n                tp_group=tp_group,\n            )\n            logger.info(\n                \"cache-dit enabled on dual transformers (steps=%d, %d)\",\n                num_high_noise_steps,\n                num_low_noise_steps,\n            )\n        else:\n            # single transformer\n            self.transformer = enable_cache_on_transformer(\n                self.transformer,\n                primary_config,\n                model_name=\"transformer\",\n                sp_group=sp_group,\n                tp_group=tp_group,\n            )\n      
      logger.info(\n                \"cache-dit enabled on transformer (steps=%d, Fn=%d, Bn=%d, rdt=%.3f)\",\n                num_inference_steps,\n                envs.SGLANG_CACHE_DIT_FN,\n                envs.SGLANG_CACHE_DIT_BN,\n                envs.SGLANG_CACHE_DIT_RDT,\n            )\n\n        self._cache_dit_enabled = True\n        self._cached_num_steps = num_inference_steps\n\n    @lru_cache(maxsize=8)\n    def _build_guidance(self, batch_size, target_dtype, device, guidance_val):\n        \"\"\"Builds a guidance tensor. This method is cached.\"\"\"\n        return (\n            torch.full(\n                (batch_size,),\n                guidance_val,\n                dtype=target_dtype,\n                device=device,\n            )\n            * 1000.0\n        )\n\n    def get_or_build_guidance(self, bsz: int, dtype, device):\n        \"\"\"\n        Get the guidance tensor, using a cached version if available.\n\n        This method retrieves a cached guidance tensor using `_build_guidance`.\n        The caching is based on batch size, dtype, device, and the guidance value,\n        preventing repeated tensor creation within the denoising loop.\n        \"\"\"\n        if self.server_args.pipeline_config.should_use_guidance:\n            # TODO: should the guidance_scale be picked-up from sampling_params?\n            guidance_val = self.server_args.pipeline_config.embedded_cfg_scale\n            return self._build_guidance(bsz, dtype, device, guidance_val)\n        else:\n            return None\n\n    @property\n    def parallelism_type(self) -> StageParallelismType:\n        # return StageParallelismType.CFG_PARALLEL if get_global_server_args().enable_cfg_parallel else StageParallelismType.REPLICATED\n        return StageParallelismType.REPLICATED\n\n    def _preprocess_latents_for_ti2v(\n        self, latents, target_dtype, batch, server_args: ServerArgs\n    ):\n        # FIXME: should probably move to latent preparation stage, to handle with 
offload\n        # Wan2.2 TI2V directly replaces the first frame of the latent with\n        # the image latent instead of appending along the channel dim\n        assert batch.image_latent is None, \"TI2V task should not have image latents\"\n        assert self.vae is not None, \"VAE is not provided for TI2V task\"\n        self.vae = self.vae.to(batch.condition_image.device)\n        z = self.vae.encode(batch.condition_image).mean.float()\n        if self.vae.device != \"cpu\" and server_args.vae_cpu_offload:\n            self.vae = self.vae.to(\"cpu\")\n        if hasattr(self.vae, \"shift_factor\") and self.vae.shift_factor is not None:\n            if isinstance(self.vae.shift_factor, torch.Tensor):\n                z -= self.vae.shift_factor.to(z.device, z.dtype)\n            else:\n                z -= self.vae.shift_factor\n\n        if isinstance(self.vae.scaling_factor, torch.Tensor):\n            z = z * self.vae.scaling_factor.to(z.device, z.dtype)\n        else:\n            z = z * self.vae.scaling_factor\n        # z: [B, C, 1, H, W]\n        latent_model_input = latents.to(target_dtype)\n        # Keep as [B, C, T, H, W] for proper broadcasting\n        assert latent_model_input.ndim == 5\n\n        # Create mask with proper shape [B, C, T, H, W]\n        latent_for_mask = latent_model_input.squeeze(0)  # [C, T, H, W]\n        _, reserved_frames_masks = masks_like([latent_for_mask], zero=True)\n        reserved_frames_mask = reserved_frames_masks[0].unsqueeze(0)  # [1, C, T, H, W]\n\n        # replace GLOBAL first frame with image - proper broadcasting\n        # z: [B, C, 1, H, W], reserved_frames_mask: [1, C, T, H, W]\n        # Both will broadcast correctly\n        latents = (\n            1.0 - reserved_frames_mask\n        ) * z + reserved_frames_mask * latent_model_input\n        assert latents.ndim == 5\n        latents = latents.to(get_local_torch_device())\n        batch.latents = latents\n\n        F = batch.num_frames\n        
temporal_scale = (\n            server_args.pipeline_config.vae_config.arch_config.scale_factor_temporal\n        )\n        spatial_scale = (\n            server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial\n        )\n        patch_size = server_args.pipeline_config.dit_config.arch_config.patch_size\n        seq_len = (\n            ((F - 1) // temporal_scale + 1)\n            * (batch.height // spatial_scale)\n            * (batch.width // spatial_scale)\n            // (patch_size[1] * patch_size[2])\n        )\n        seq_len = int(math.ceil(seq_len / get_sp_world_size())) * get_sp_world_size()\n        return seq_len, z, reserved_frames_masks\n\n    def _postprocess_latents_for_ti2v(self, z, reserved_frames_masks, batch):\n        rank_in_sp_group = get_sp_parallel_rank()\n        sp_world_size = get_sp_world_size()\n\n        if getattr(batch, \"did_sp_shard_latents\", False):\n            # Shard z (image latent) along time dimension\n            # z shape: [1, C, 1, H, W] - only first frame\n            # Only rank 0 has the first frame after sharding\n            if z.shape[2] == 1:\n                # z is single frame, only rank 0 needs it\n                if rank_in_sp_group == 0:\n                    z_sp = z\n                else:\n                    # Other ranks don't have the first frame\n                    z_sp = None\n            else:\n                # Should not happen for TI2V\n                z_sp = z\n\n            # Shard reserved_frames_mask along time dimension to match sharded latents\n            # reserved_frames_mask is a list from masks_like, extract reserved_frames_mask[0] first\n            # reserved_frames_mask[0] shape: [C, T, H, W]\n            # All ranks need their portion of reserved_frames_mask for timestep calculation\n            if reserved_frames_masks is not None:\n                reserved_frames_mask = reserved_frames_masks[\n                    0\n                ]  # Extract tensor from 
list\n                time_dim = reserved_frames_mask.shape[1]  # [C, T, H, W]\n                if time_dim > 0 and time_dim % sp_world_size == 0:\n                    reserved_frames_mask_sp_tensor = rearrange(\n                        reserved_frames_mask,\n                        \"c (n t) h w -> c n t h w\",\n                        n=sp_world_size,\n                    ).contiguous()\n                    reserved_frames_mask_sp_tensor = reserved_frames_mask_sp_tensor[\n                        :, rank_in_sp_group, :, :, :\n                    ]\n                    reserved_frames_mask_sp = (\n                        reserved_frames_mask_sp_tensor  # Store as tensor, not list\n                    )\n                else:\n                    reserved_frames_mask_sp = reserved_frames_mask\n            else:\n                reserved_frames_mask_sp = None\n        else:\n            # SP not enabled or latents not sharded\n            z_sp = z\n            reserved_frames_mask_sp = (\n                reserved_frames_masks[0] if reserved_frames_masks is not None else None\n            )  # Extract tensor\n\n        return reserved_frames_mask_sp, z_sp\n\n    def _handle_boundary_ratio(\n        self,\n        server_args,\n        batch,\n    ):\n        \"\"\"\n        (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert\n        \"\"\"\n        boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio\n        if batch.boundary_ratio is not None:\n            logger.info(\n                \"Overriding boundary ratio from %s to %s\",\n                boundary_ratio,\n                batch.boundary_ratio,\n            )\n            boundary_ratio = batch.boundary_ratio\n\n        if boundary_ratio is not None:\n            boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps\n        else:\n            boundary_timestep = None\n\n        return boundary_timestep\n\n    def _prepare_denoising_loop(self, 
batch: Req, server_args: ServerArgs):\n        \"\"\"\n        Prepare all necessary invariant variables for the denoising loop.\n\n        Returns:\n            A dictionary containing all the prepared variables for the denoising loop.\n        \"\"\"\n        assert self.transformer is not None\n        pipeline = self.pipeline() if self.pipeline else None\n\n        boundary_timestep = self._handle_boundary_ratio(server_args, batch)\n        # Get timesteps and calculate warmup steps\n        timesteps = batch.timesteps\n        num_inference_steps = batch.num_inference_steps\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n\n        if self.transformer_2 is not None:\n            assert boundary_timestep is not None, \"boundary_timestep must be provided\"\n            num_high_noise_steps = (timesteps >= boundary_timestep).sum().item()\n            num_low_noise_steps = num_inference_steps - num_high_noise_steps\n            cache_dit_num_inference_steps = (num_high_noise_steps, num_low_noise_steps)\n        else:\n            cache_dit_num_inference_steps = num_inference_steps\n\n        if not server_args.model_loaded[\"transformer\"]:\n            # FIXME: reuse more code\n            loader = TransformerLoader()\n            self.transformer = loader.load(\n                server_args.model_paths[\"transformer\"], server_args, \"transformer\"\n            )\n            # enable cache-dit before torch.compile (delayed mounting)\n            self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch)\n            self._maybe_enable_torch_compile(self.transformer)\n            if pipeline:\n                pipeline.add_module(\"transformer\", self.transformer)\n            server_args.model_loaded[\"transformer\"] = True\n        else:\n            self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch)\n\n        # Prepare extra step kwargs for scheduler\n        extra_step_kwargs = 
self.prepare_extra_func_kwargs(\n            self.scheduler.step,\n            {\"generator\": batch.generator, \"eta\": batch.eta},\n        )\n\n        # Setup precision and autocast settings\n        target_dtype = torch.bfloat16\n        autocast_enabled = (\n            target_dtype != torch.float32\n        ) and not server_args.disable_autocast\n\n        # Prepare image latents and embeddings for I2V generation\n        image_embeds = batch.image_embeds\n        if len(image_embeds) > 0:\n            image_embeds = [\n                image_embed.to(target_dtype) for image_embed in image_embeds\n            ]\n\n        # Prepare STA parameters\n        if self.attn_backend.get_enum() == AttentionBackendEnum.SLIDING_TILE_ATTN:\n            self.prepare_sta_param(batch, server_args)\n\n        # Get latents and embeddings\n        latents = batch.latents\n        prompt_embeds = batch.prompt_embeds\n        # Removed Tensor truthiness assert to avoid GPU sync\n        neg_prompt_embeds = None\n        if batch.do_classifier_free_guidance:\n            neg_prompt_embeds = batch.negative_prompt_embeds\n            assert neg_prompt_embeds is not None\n            # Removed Tensor truthiness assert to avoid GPU sync\n\n        # specifically for Wan2_2_TI2V_5B_Config, not applicable for FastWan2_2_TI2V_5B_Config\n        should_preprocess_for_wan_ti2v = (\n            server_args.pipeline_config.task_type == ModelTaskType.TI2V\n            and batch.condition_image is not None\n            and type(server_args.pipeline_config) is Wan2_2_TI2V_5B_Config\n        )\n\n        # TI2V specific preparations - before SP sharding\n        if should_preprocess_for_wan_ti2v:\n            seq_len, z, reserved_frames_masks = self._preprocess_latents_for_ti2v(\n                latents, target_dtype, batch, server_args\n            )\n        else:\n            seq_len, z, reserved_frames_masks = (\n                None,\n                None,\n                None,\n        
    )\n\n        # Handle sequence parallelism after TI2V processing\n        self._preprocess_sp_latents(batch, server_args)\n        latents = batch.latents\n\n        # Shard z and reserved_frames_mask for TI2V if SP is enabled\n        if should_preprocess_for_wan_ti2v:\n            reserved_frames_mask_sp, z_sp = self._postprocess_latents_for_ti2v(\n                z, reserved_frames_masks, batch\n            )\n        else:\n            reserved_frames_mask_sp, z_sp = (\n                reserved_frames_masks[0] if reserved_frames_masks is not None else None\n            ), z\n\n        guidance = self.get_or_build_guidance(\n            # TODO: replace with raw_latent_shape?\n            latents.shape[0],\n            latents.dtype,\n            latents.device,\n        )\n\n        image_kwargs = self.prepare_extra_func_kwargs(\n            getattr(self.transformer, \"forward\", self.transformer),\n            {\n                # TODO: make sure on-device\n                \"encoder_hidden_states_image\": image_embeds,\n                \"mask_strategy\": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24),\n            },\n        )\n\n        pos_cond_kwargs = self.prepare_extra_func_kwargs(\n            getattr(self.transformer, \"forward\", self.transformer),\n            {\n                \"encoder_hidden_states_2\": batch.clip_embedding_pos,\n                \"encoder_attention_mask\": batch.prompt_attention_mask,\n            }\n            | server_args.pipeline_config.prepare_pos_cond_kwargs(\n                batch,\n                self.device,\n                getattr(self.transformer, \"rotary_emb\", None),\n                dtype=target_dtype,\n            )\n            | dict(\n                encoder_hidden_states=server_args.pipeline_config.get_pos_prompt_embeds(\n                    batch\n                )\n            ),\n        )\n\n        if batch.do_classifier_free_guidance:\n            neg_cond_kwargs = 
self.prepare_extra_func_kwargs(\n                getattr(self.transformer, \"forward\", self.transformer),\n                {\n                    \"encoder_hidden_states_2\": batch.clip_embedding_neg,\n                    \"encoder_attention_mask\": batch.negative_attention_mask,\n                }\n                | server_args.pipeline_config.prepare_neg_cond_kwargs(\n                    batch,\n                    self.device,\n                    getattr(self.transformer, \"rotary_emb\", None),\n                    dtype=target_dtype,\n                )\n                | dict(\n                    encoder_hidden_states=server_args.pipeline_config.get_neg_prompt_embeds(\n                        batch\n                    )\n                ),\n            )\n        else:\n            neg_cond_kwargs = {}\n\n        return {\n            \"extra_step_kwargs\": extra_step_kwargs,\n            \"target_dtype\": target_dtype,\n            \"autocast_enabled\": autocast_enabled,\n            \"timesteps\": timesteps,\n            \"num_inference_steps\": num_inference_steps,\n            \"num_warmup_steps\": num_warmup_steps,\n            \"image_kwargs\": image_kwargs,\n            \"pos_cond_kwargs\": pos_cond_kwargs,\n            \"neg_cond_kwargs\": neg_cond_kwargs,\n            \"latents\": latents,\n            \"prompt_embeds\": prompt_embeds,\n            \"neg_prompt_embeds\": neg_prompt_embeds,\n            \"boundary_timestep\": boundary_timestep,\n            \"z\": z_sp,  # Use SP-sharded version\n            # ndim == 5\n            \"reserved_frames_mask\": reserved_frames_mask_sp,  # Use SP-sharded version\n            \"seq_len\": seq_len,\n            \"guidance\": guidance,\n        }\n\n    def _post_denoising_loop(\n        self,\n        batch: Req,\n        latents: torch.Tensor,\n        trajectory_latents: list,\n        trajectory_timesteps: list,\n        server_args: ServerArgs,\n        is_warmup: bool = False,\n    ):\n        # 
Gather results if using sequence parallelism\n        if trajectory_latents:\n            trajectory_tensor = torch.stack(trajectory_latents, dim=1)\n            trajectory_timesteps_tensor = torch.stack(trajectory_timesteps, dim=0)\n        else:\n            trajectory_tensor = None\n            trajectory_timesteps_tensor = None\n\n        # Gather results if using sequence parallelism\n        latents, trajectory_tensor = self._postprocess_sp_latents(\n            batch, latents, trajectory_tensor\n        )\n\n        # Gather noise_pred if using sequence parallelism\n        # noise_pred has the same shape as latents (sharded along sequence dimension)\n        if (\n            get_sp_world_size() > 1\n            and getattr(batch, \"did_sp_shard_latents\", False)\n            and server_args.comfyui_mode\n            and hasattr(batch, \"noise_pred\")\n            and batch.noise_pred is not None\n        ):\n            batch.noise_pred = server_args.pipeline_config.gather_latents_for_sp(\n                batch.noise_pred\n            )\n            if hasattr(batch, \"raw_latent_shape\"):\n                orig_s = batch.raw_latent_shape[1]\n                if batch.noise_pred.shape[1] > orig_s:\n                    batch.noise_pred = batch.noise_pred[:, :orig_s, :]\n\n        if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:\n            batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()\n            batch.trajectory_latents = trajectory_tensor.cpu()\n\n        # Update batch with final latents\n        batch.latents = self.server_args.pipeline_config.post_denoising_loop(\n            latents, batch\n        )\n\n        # Save STA mask search results if needed\n        if (\n            not is_warmup\n            and self.attn_backend.get_enum() == AttentionBackendEnum.SLIDING_TILE_ATTN\n            and server_args.attention_backend_config.STA_mode == \"STA_SEARCHING\"\n        ):\n            
self.save_sta_search_results(batch)\n\n        # Capture references before potential deletion on MPS\n        dits = list(filter(None, [self.transformer, self.transformer_2]))\n\n        # deallocate transformer if on mps\n        pipeline = self.pipeline() if self.pipeline else None\n        if torch.backends.mps.is_available() and not is_warmup:\n            logger.info(\n                \"Memory before deallocating transformer: %s\",\n                torch.mps.current_allocated_memory(),\n            )\n            del self.transformer\n            if pipeline is not None and \"transformer\" in pipeline.modules:\n                del pipeline.modules[\"transformer\"]\n            server_args.model_loaded[\"transformer\"] = False\n            logger.info(\n                \"Memory after deallocating transformer: %s\",\n                torch.mps.current_allocated_memory(),\n            )\n\n        # reset offload managers with prefetching first layer for next forward\n        for dit in dits:\n            if isinstance(dit, OffloadableDiTMixin):\n                # release all DiT weights to avoid peak VRAM usage, which may increase the latency for the next request\n                # TODO: should we make this an option?\n                for manager in dit.layerwise_offload_managers:\n                    manager.release_all()\n\n    def _preprocess_sp_latents(self, batch: Req, server_args: ServerArgs):\n        \"\"\"Shard latents for Sequence Parallelism if applicable.\"\"\"\n        if get_sp_world_size() <= 1:\n            return\n\n        if batch.latents is not None:\n            (\n                batch.latents,\n                did_shard,\n            ) = server_args.pipeline_config.shard_latents_for_sp(batch, batch.latents)\n            batch.did_sp_shard_latents = did_shard\n        else:\n            batch.did_sp_shard_latents = False\n\n        # image_latent must be sharded consistently with latents when it is\n        # concatenated along the sequence 
dimension in the denoising loop.\n        if batch.image_latent is not None:\n            batch.image_latent, _ = server_args.pipeline_config.shard_latents_for_sp(\n                batch, batch.image_latent\n            )\n\n    def _postprocess_sp_latents(\n        self,\n        batch: Req,\n        latents: torch.Tensor,\n        trajectory_tensor: torch.Tensor | None,\n    ) -> tuple[torch.Tensor, torch.Tensor | None]:\n        \"\"\"Gather latents after Sequence Parallelism if they were sharded.\"\"\"\n        if get_sp_world_size() > 1 and getattr(batch, \"did_sp_shard_latents\", False):\n            latents = self.server_args.pipeline_config.gather_latents_for_sp(latents)\n            if trajectory_tensor is not None:\n                # trajectory_tensor shapes:\n                # - video: [b, num_steps, c, t_local, h, w] -> gather on dim=3\n                # - image: [b, num_steps, s_local, d] -> gather on dim=2\n                trajectory_tensor = trajectory_tensor.to(get_local_torch_device())\n                gather_dim = 3 if trajectory_tensor.dim() >= 5 else 2\n                trajectory_tensor = sequence_model_parallel_all_gather(\n                    trajectory_tensor, dim=gather_dim\n                )\n                if gather_dim == 2 and hasattr(batch, \"raw_latent_shape\"):\n                    orig_s = batch.raw_latent_shape[1]\n                    if trajectory_tensor.shape[2] > orig_s:\n                        trajectory_tensor = trajectory_tensor[:, :, :orig_s, :]\n        return latents, trajectory_tensor\n\n    def step_profile(self):\n        profiler = SGLDiffusionProfiler.get_instance()\n        if profiler:\n            profiler.step_denoising_step()\n\n    def _manage_device_placement(\n        self,\n        model_to_use: nn.Module,\n        model_to_offload: nn.Module | None,\n        server_args: ServerArgs,\n    ):\n        \"\"\"\n        Manages the offload / load behavior of dit\n        \"\"\"\n        if not 
server_args.dit_cpu_offload:\n            return\n\n        # FSDP manages offloading internally\n        if server_args.use_fsdp_inference:\n            return\n\n        # Offload the unused model if it's on CUDA\n        if (\n            model_to_offload is not None\n            and next(model_to_offload.parameters()).device.type == \"cuda\"\n        ):\n            model_to_offload.to(\"cpu\")\n\n        # Load the model to use if it's on CPU\n        if (\n            model_to_use is not None\n            and next(model_to_use.parameters()).device.type == \"cpu\"\n        ):\n            model_to_use.to(get_local_torch_device())\n\n    def _select_and_manage_model(\n        self,\n        t_int: int,\n        boundary_timestep: float | None,\n        server_args: ServerArgs,\n        batch: Req,\n    ):\n        if boundary_timestep is None or t_int >= boundary_timestep:\n            # High-noise stage\n            current_model = self.transformer\n            model_to_offload = self.transformer_2\n            current_guidance_scale = batch.guidance_scale\n        else:\n            # Low-noise stage\n            current_model = self.transformer_2\n            model_to_offload = self.transformer\n            current_guidance_scale = batch.guidance_scale_2\n\n        self._manage_device_placement(current_model, model_to_offload, server_args)\n\n        assert current_model is not None, \"The model for the current step is not set.\"\n        return current_model, current_guidance_scale\n\n    def expand_timestep_before_forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n        t_device,\n        target_dtype,\n        seq_len: int | None,\n        reserved_frames_mask,\n    ):\n        bsz = batch.raw_latent_shape[0]\n        should_preprocess_for_wan_ti2v = (\n            server_args.pipeline_config.task_type == ModelTaskType.TI2V\n            and batch.condition_image is not None\n            and type(server_args.pipeline_config) 
is Wan2_2_TI2V_5B_Config\n        )\n\n        # expand timestep\n        if should_preprocess_for_wan_ti2v:\n            # Explicitly cast t_device to the target float type at the beginning.\n            # This ensures any precision-based rounding (e.g., float32(999.0) -> bfloat16(1000.0))\n            # is applied consistently *before* it's used by any rank.\n            t_device_rounded = t_device.to(target_dtype)\n\n            local_seq_len = seq_len\n            if get_sp_world_size() > 1 and getattr(\n                batch, \"did_sp_shard_latents\", False\n            ):\n                local_seq_len = seq_len // get_sp_world_size()\n\n            if get_sp_parallel_rank() == 0 and reserved_frames_mask is not None:\n                # Rank 0 has the first frame, create a special timestep tensor\n                # NOTE: The spatial downsampling in the next line is suspicious but kept\n                # to match original model's potential training configuration.\n                temp_ts = (\n                    reserved_frames_mask[0][:, ::2, ::2] * t_device_rounded\n                ).flatten()\n\n                # Pad to full local sequence length\n                temp_ts = torch.cat(\n                    [\n                        temp_ts,\n                        temp_ts.new_ones(local_seq_len - temp_ts.size(0))\n                        * t_device_rounded,\n                    ]\n                )\n                timestep = temp_ts.unsqueeze(0).repeat(bsz, 1)\n            else:\n                # Other ranks get a uniform timestep tensor of the correct shape [B, local_seq_len]\n                timestep = t_device.repeat(bsz, local_seq_len)\n        else:\n            timestep = t_device.repeat(bsz)\n        return timestep\n\n    def post_forward_for_ti2v_task(\n        self, batch: Req, server_args: ServerArgs, reserved_frames_mask, latents, z\n    ):\n        \"\"\"\n        For Wan2.2 ti2v task, global first frame should be replaced with encoded image 
after each timestep\n        \"\"\"\n        should_preprocess_for_wan_ti2v = (\n            server_args.pipeline_config.task_type == ModelTaskType.TI2V\n            and batch.condition_image is not None\n            and type(server_args.pipeline_config) is Wan2_2_TI2V_5B_Config\n        )\n        if should_preprocess_for_wan_ti2v:\n            # Apply TI2V mask blending with SP-aware z and reserved_frames_mask.\n            # This ensures the first frame is always the condition image after each step.\n            # This is only applied on rank 0, where z is not None.\n            if z is not None and reserved_frames_mask is not None:\n                # z: [1, C, 1, H, W]\n                # latents: [1, C, T_local, H, W]\n                # reserved_frames_mask: [C, T_local, H, W]\n                # Unsqueeze mask to [1, C, T_local, H, W] for broadcasting.\n                # z will broadcast along the time dimension.\n                latents = (\n                    1.0 - reserved_frames_mask.unsqueeze(0)\n                ) * z + reserved_frames_mask.unsqueeze(0) * latents\n\n        return latents\n\n    @torch.no_grad()\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n        \"\"\"\n        Run the denoising loop.\n        \"\"\"\n        # Prepare variables for the denoising loop\n\n        prepared_vars = self._prepare_denoising_loop(batch, server_args)\n        extra_step_kwargs = prepared_vars[\"extra_step_kwargs\"]\n        target_dtype = prepared_vars[\"target_dtype\"]\n        autocast_enabled = prepared_vars[\"autocast_enabled\"]\n        timesteps = prepared_vars[\"timesteps\"]\n        num_inference_steps = prepared_vars[\"num_inference_steps\"]\n        num_warmup_steps = prepared_vars[\"num_warmup_steps\"]\n        image_kwargs = prepared_vars[\"image_kwargs\"]\n        pos_cond_kwargs = prepared_vars[\"pos_cond_kwargs\"]\n        neg_cond_kwargs = prepared_vars[\"neg_cond_kwargs\"]\n        
latents = prepared_vars[\"latents\"]\n        boundary_timestep = prepared_vars[\"boundary_timestep\"]\n        z = prepared_vars[\"z\"]\n        reserved_frames_mask = prepared_vars[\"reserved_frames_mask\"]\n        seq_len = prepared_vars[\"seq_len\"]\n        guidance = prepared_vars[\"guidance\"]\n\n        # Initialize lists for ODE trajectory\n        trajectory_timesteps: list[torch.Tensor] = []\n        trajectory_latents: list[torch.Tensor] = []\n\n        # Run denoising loop\n        denoising_start_time = time.time()\n\n        # to avoid device-sync caused by timestep comparison\n        is_warmup = batch.is_warmup\n        self.scheduler.set_begin_index(0)\n        timesteps_cpu = timesteps.cpu()\n        num_timesteps = timesteps_cpu.shape[0]\n        with torch.autocast(\n            device_type=current_platform.device_type,\n            dtype=target_dtype,\n            enabled=autocast_enabled,\n        ):\n            with self.progress_bar(total=num_inference_steps) as progress_bar:\n                for i, t_host in enumerate(timesteps_cpu):\n                    with StageProfiler(\n                        f\"denoising_step_{i}\",\n                        logger=logger,\n                        metrics=batch.metrics,\n                        perf_dump_path_provided=batch.perf_dump_path is not None,\n                    ):\n                        t_int = int(t_host.item())\n                        t_device = timesteps[i]\n                        current_model, current_guidance_scale = (\n                            self._select_and_manage_model(\n                                t_int=t_int,\n                                boundary_timestep=boundary_timestep,\n                                server_args=server_args,\n                                batch=batch,\n                            )\n                        )\n\n                        # Expand latents for I2V\n                        latent_model_input = latents.to(target_dtype)\n      
                  if batch.image_latent is not None:\n                            assert (\n                                not server_args.pipeline_config.task_type\n                                == ModelTaskType.TI2V\n                            ), \"image latents should not be provided for TI2V task\"\n                            latent_model_input = torch.cat(\n                                [latent_model_input, batch.image_latent], dim=1\n                            ).to(target_dtype)\n\n                        timestep = self.expand_timestep_before_forward(\n                            batch,\n                            server_args,\n                            t_device,\n                            target_dtype,\n                            seq_len,\n                            reserved_frames_mask,\n                        )\n\n                        latent_model_input = self.scheduler.scale_model_input(\n                            latent_model_input, t_device\n                        )\n\n                        # Predict noise residual\n                        attn_metadata = self._build_attn_metadata(\n                            i,\n                            batch,\n                            server_args,\n                            timestep_value=t_int,\n                            timesteps=timesteps_cpu,\n                        )\n                        noise_pred = self._predict_noise_with_cfg(\n                            current_model=current_model,\n                            latent_model_input=latent_model_input,\n                            timestep=timestep,\n                            batch=batch,\n                            timestep_index=i,\n                            attn_metadata=attn_metadata,\n                            target_dtype=target_dtype,\n                            current_guidance_scale=current_guidance_scale,\n                            image_kwargs=image_kwargs,\n                            
pos_cond_kwargs=pos_cond_kwargs,\n                            neg_cond_kwargs=neg_cond_kwargs,\n                            server_args=server_args,\n                            guidance=guidance,\n                            latents=latents,\n                        )\n\n                        # Save noise_pred to batch for external access (e.g., ComfyUI)\n                        if server_args.comfyui_mode:\n                            batch.noise_pred = noise_pred\n\n                        # Compute the previous noisy sample\n                        latents = self.scheduler.step(\n                            model_output=noise_pred,\n                            timestep=t_device,\n                            sample=latents,\n                            **extra_step_kwargs,\n                            return_dict=False,\n                        )[0]\n\n                        latents = self.post_forward_for_ti2v_task(\n                            batch, server_args, reserved_frames_mask, latents, z\n                        )\n\n                        # save trajectory latents if needed\n                        if batch.return_trajectory_latents:\n                            trajectory_timesteps.append(t_host)\n                            trajectory_latents.append(latents)\n\n                        # Update progress bar\n                        if i == num_timesteps - 1 or (\n                            (i + 1) > num_warmup_steps\n                            and (i + 1) % self.scheduler.order == 0\n                            and progress_bar is not None\n                        ):\n                            progress_bar.update()\n\n                        if not is_warmup:\n                            self.step_profile()\n\n        denoising_end_time = time.time()\n\n        if num_timesteps > 0 and not is_warmup:\n            self.log_info(\n                \"average time per step: %.4f seconds\",\n                (denoising_end_time - 
denoising_start_time) / len(timesteps),\n            )\n\n        self._post_denoising_loop(\n            batch=batch,\n            latents=latents,\n            trajectory_latents=trajectory_latents,\n            trajectory_timesteps=trajectory_timesteps,\n            server_args=server_args,\n            is_warmup=is_warmup,\n        )\n        return batch\n\n    # TODO: this extends the preparation stage; subclasses or passed-in variables should decide which kwargs to prepare\n    def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]:\n        \"\"\"\n        Prepare extra kwargs for the scheduler step / denoise step.\n\n        Args:\n            func: The function to prepare kwargs for.\n            kwargs: The kwargs to prepare.\n        \"\"\"\n        import functools\n\n        # Handle cache-dit's partial wrapping logic.\n        # Cache-dit wraps the forward method with functools.partial where args[0] is the instance.\n        # We access `_original_forward` if available to inspect the underlying signature.\n        # See: https://github.com/vipshop/cache-dit\n        if isinstance(func, functools.partial) and func.args:\n            func = getattr(func.args[0], \"_original_forward\", func)\n\n        # Unwrap any decorators (e.g. 
functools.wraps)\n        target_func = inspect.unwrap(func)\n\n        # Filter kwargs based on the signature\n        params = inspect.signature(target_func).parameters\n        return {k: v for k, v in kwargs.items() if k in params}\n\n    def progress_bar(\n        self, iterable: Iterable | None = None, total: int | None = None\n    ) -> tqdm:\n        \"\"\"\n        Create a progress bar for the denoising process.\n        \"\"\"\n        local_rank = get_world_group().local_rank\n        disable = local_rank != 0\n        return tqdm(iterable=iterable, total=total, disable=disable)\n\n    def rescale_noise_cfg(\n        self, noise_cfg, noise_pred_text, guidance_rescale=0.0\n    ) -> torch.Tensor:\n        \"\"\"\n        Rescale noise prediction according to guidance_rescale.\n\n        Based on findings of \"Common Diffusion Noise Schedules and Sample Steps are Flawed\"\n        (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.\n\n        Args:\n            noise_cfg: The noise prediction with guidance.\n            noise_pred_text: The text-conditioned noise prediction.\n            guidance_rescale: The guidance rescale factor.\n\n        Returns:\n            The rescaled noise prediction.\n        \"\"\"\n        std_text = noise_pred_text.std(\n            dim=list(range(1, noise_pred_text.ndim)), keepdim=True\n        )\n        std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n        # Rescale the results from guidance (fixes overexposure)\n        noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n        # Mix with the original results from guidance by factor guidance_rescale\n        noise_cfg = (\n            guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n        )\n        return noise_cfg\n\n    def _build_attn_metadata(\n        self,\n        i: int,\n        batch: Req,\n        server_args: ServerArgs,\n        *,\n        timestep_value: int | None = None,\n        
timesteps: torch.Tensor | None = None,\n    ) -> Any | None:\n        \"\"\"\n        Build attention metadata for custom attention backends.\n\n        Args:\n            i: The current timestep index.\n        \"\"\"\n        attn_metadata = None\n        self.attn_metadata_builder = None\n        try:\n            self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()\n        except NotImplementedError:\n            self.attn_metadata_builder_cls = None\n        if self.attn_metadata_builder_cls:\n            self.attn_metadata_builder = self.attn_metadata_builder_cls()\n        if (\n            self.attn_backend.get_enum() == AttentionBackendEnum.SLIDING_TILE_ATTN\n            or self.attn_backend.get_enum() == AttentionBackendEnum.VIDEO_SPARSE_ATTN\n        ):\n            attn_metadata = self.attn_metadata_builder.build(\n                current_timestep=i,\n                raw_latent_shape=batch.raw_latent_shape[2:5],\n                patch_size=server_args.pipeline_config.dit_config.patch_size,\n                STA_param=batch.STA_param,\n                VSA_sparsity=server_args.attention_backend_config.VSA_sparsity,\n                device=get_local_torch_device(),\n            )\n        elif (\n            self.attn_backend.get_enum() == AttentionBackendEnum.SPARSE_VIDEO_GEN_2_ATTN\n        ):\n            if timestep_value is None or timesteps is None:\n                raise ValueError(\n                    \"timestep_value and timesteps must be provided for SVG2 attention metadata\"\n                )\n\n            svg2_cfg = server_args.attention_backend_config or {}\n            num_layers = server_args.pipeline_config.dit_config.num_layers\n            if (\n                server_args.pipeline_config.dit_config.prefix.lower() == \"hunyuan\"\n                and hasattr(server_args.pipeline_config.dit_config, \"num_single_layers\")\n            ):\n                num_layers += 
server_args.pipeline_config.dit_config.num_single_layers\n            first_layers_fp = svg2_cfg.get(\"svg2_first_layers_fp\", 0.03)\n            if first_layers_fp <= 1.0:\n                first_layers_fp = math.floor(first_layers_fp * num_layers)\n            first_layers_fp = max(0, min(int(first_layers_fp), num_layers))\n\n            first_times_fp = svg2_cfg.get(\"svg2_first_times_fp\", 0.2)\n            if first_times_fp <= 1.0:\n                num_fp_steps = math.floor(first_times_fp * len(timesteps))\n                if num_fp_steps > 0:\n                    first_times_fp = float(timesteps[num_fp_steps - 1].item() - 1)\n                else:\n                    first_times_fp = float(timesteps.max().item() + 1)\n\n            current_timestep = int(timestep_value)\n\n            cache = batch.extra.get(\"svg2_cache\")\n            if cache is None:\n                from sglang.multimodal_gen.runtime.layers.attention.backends.sparse_video_gen_2_attn import (\n                    Svg2Cache,\n                )\n\n                cache = Svg2Cache()\n                batch.extra[\"svg2_cache\"] = cache\n\n            patch_size = server_args.pipeline_config.dit_config.patch_size\n            if isinstance(patch_size, list):\n                patch_size = tuple(patch_size)\n            if isinstance(patch_size, int):\n                patch_size_t = getattr(\n                    server_args.pipeline_config.dit_config, \"patch_size_t\", None\n                )\n                if patch_size_t is not None:\n                    patch_size = (patch_size_t, patch_size, patch_size)\n\n            context_length = 0\n            prompt_length = None\n            if server_args.pipeline_config.dit_config.prefix.lower() == \"hunyuan\":\n                prompt_embeds = server_args.pipeline_config.get_pos_prompt_embeds(batch)\n                if isinstance(prompt_embeds, list):\n                    text_embeds = prompt_embeds[0] if prompt_embeds else None\n                
else:\n                    text_embeds = prompt_embeds\n                if isinstance(text_embeds, torch.Tensor) and text_embeds.ndim >= 2:\n                    context_length = int(text_embeds.shape[1])\n                if context_length > 0 and batch.prompt_attention_mask:\n                    mask = batch.prompt_attention_mask[0]\n                    if isinstance(mask, torch.Tensor):\n                        if mask.shape[-1] > context_length:\n                            mask = mask[:, -context_length:]\n                        prompt_length = int(mask[0].sum().item())\n                if prompt_length is None:\n                    prompt_length = context_length\n\n            attn_metadata = self.attn_metadata_builder.build(\n                current_timestep=current_timestep,\n                raw_latent_shape=batch.raw_latent_shape,\n                patch_size=patch_size,\n                num_q_centroids=svg2_cfg.get(\"svg2_num_q_centroids\", 300),\n                num_k_centroids=svg2_cfg.get(\"svg2_num_k_centroids\", 1000),\n                top_p_kmeans=svg2_cfg.get(\"svg2_top_p_kmeans\", 0.9),\n                min_kc_ratio=svg2_cfg.get(\"svg2_min_kc_ratio\", 0.1),\n                kmeans_iter_init=svg2_cfg.get(\"svg2_kmeans_iter_init\", 50),\n                kmeans_iter_step=svg2_cfg.get(\"svg2_kmeans_iter_step\", 2),\n                zero_step_kmeans_init=svg2_cfg.get(\"svg2_zero_step_kmeans_init\", False),\n                first_layers_fp=first_layers_fp,\n                first_times_fp=first_times_fp,\n                context_length=context_length,\n                prompt_length=prompt_length,\n                cache=cache,\n                calculate_density=False,  # only need density when doing head load balancing\n            )\n        elif self.attn_backend.get_enum() == AttentionBackendEnum.VMOBA_ATTN:\n            moba_params = server_args.attention_backend_config.moba_config.copy()\n            moba_params.update(\n                {\n             
       \"current_timestep\": i,\n                    \"raw_latent_shape\": batch.raw_latent_shape[2:5],\n                    \"patch_size\": server_args.pipeline_config.dit_config.patch_size,\n                    \"device\": get_local_torch_device(),\n                }\n            )\n            attn_metadata = self.attn_metadata_builder.build(**moba_params)\n        elif self.attn_backend.get_enum() == AttentionBackendEnum.FA:\n            attn_metadata = self.attn_metadata_builder.build(\n                raw_latent_shape=batch.raw_latent_shape\n            )\n        else:\n            # attn_metadata can be None for SDPA attention backend\n            return None\n\n        return attn_metadata\n\n    def _predict_noise(\n        self,\n        current_model,\n        latent_model_input,\n        timestep,\n        target_dtype,\n        guidance: torch.Tensor,\n        **kwargs,\n    ):\n        return current_model(\n            hidden_states=latent_model_input,\n            timestep=timestep,\n            guidance=guidance,\n            **kwargs,\n        )\n\n    def _predict_noise_with_cfg(\n        self,\n        current_model: nn.Module,\n        latent_model_input: torch.Tensor,\n        timestep,\n        batch: Req,\n        timestep_index: int,\n        attn_metadata,\n        target_dtype,\n        current_guidance_scale,\n        image_kwargs: dict[str, Any],\n        pos_cond_kwargs: dict[str, Any],\n        neg_cond_kwargs: dict[str, Any],\n        server_args,\n        guidance,\n        latents,\n    ):\n        \"\"\"\n        Predict the noise residual with classifier-free guidance.\n\n        Args:\n            current_model: The transformer model to use for the current step.\n            latent_model_input: The input latents for the model.\n            timestep: The expanded timestep tensor.\n            batch: The current batch information.\n            timestep_index: The current timestep index.\n            attn_metadata: 
Attention metadata for custom backends.\n            target_dtype: The target data type for 
autocasting.\n            current_guidance_scale: The guidance scale for the current step.\n            image_kwargs: Keyword arguments for image conditioning.\n            pos_cond_kwargs: Keyword arguments for positive prompt conditioning.\n            neg_cond_kwargs: Keyword arguments for negative prompt conditioning.\n\n        Returns:\n            The predicted noise.\n        \"\"\"\n        noise_pred_cond: torch.Tensor | None = None\n        noise_pred_uncond: torch.Tensor | None = None\n        cfg_rank = get_classifier_free_guidance_rank()\n        # positive pass\n        if not (server_args.enable_cfg_parallel and cfg_rank != 0):\n            batch.is_cfg_negative = False\n            with set_forward_context(\n                current_timestep=timestep_index,\n                attn_metadata=attn_metadata,\n                forward_batch=batch,\n            ):\n                noise_pred_cond = self._predict_noise(\n                    current_model=current_model,\n                    latent_model_input=latent_model_input,\n                    timestep=timestep,\n                    target_dtype=target_dtype,\n                    guidance=guidance,\n                    **image_kwargs,\n                    **pos_cond_kwargs,\n                )\n                # TODO: can it be moved to after _predict_noise_with_cfg?\n                noise_pred_cond = server_args.pipeline_config.slice_noise_pred(\n                    noise_pred_cond, latents\n                )\n        if not batch.do_classifier_free_guidance:\n            # If CFG is disabled, we are done. 
Return the conditional prediction.\n            return noise_pred_cond\n\n        # negative pass\n        if not server_args.enable_cfg_parallel or cfg_rank != 0:\n            batch.is_cfg_negative = True\n            with set_forward_context(\n                current_timestep=timestep_index,\n                attn_metadata=attn_metadata,\n                forward_batch=batch,\n            ):\n                noise_pred_uncond = self._predict_noise(\n                    current_model=current_model,\n                    latent_model_input=latent_model_input,\n                    timestep=timestep,\n                    target_dtype=target_dtype,\n                    guidance=guidance,\n                    **image_kwargs,\n                    **neg_cond_kwargs,\n                )\n                noise_pred_uncond = server_args.pipeline_config.slice_noise_pred(\n                    noise_pred_uncond, latents\n                )\n\n        # Combine predictions\n        if server_args.enable_cfg_parallel:\n            # Each rank computes its partial contribution and we sum via all-reduce:\n            #   final = s*cond + (1-s)*uncond\n            if cfg_rank == 0:\n                assert noise_pred_cond is not None\n                partial = current_guidance_scale * noise_pred_cond\n            else:\n                assert noise_pred_uncond is not None\n                partial = (1 - current_guidance_scale) * noise_pred_uncond\n\n            noise_pred = cfg_model_parallel_all_reduce(partial)\n\n            if batch.cfg_normalization and float(batch.cfg_normalization) > 0:\n                factor = float(batch.cfg_normalization)\n                pred_f = noise_pred.float()\n                new_norm = torch.linalg.vector_norm(pred_f)\n                if cfg_rank == 0:\n                    cond_f = noise_pred_cond.float()\n                    ori_norm = torch.linalg.vector_norm(cond_f)\n                else:\n                    ori_norm = torch.empty_like(new_norm)\n   
             ori_norm = get_cfg_group().broadcast(ori_norm, src=0)\n                max_norm = ori_norm * factor\n\n                if new_norm > max_norm:\n                    noise_pred = noise_pred * (max_norm / new_norm)\n\n            # Guidance rescale: broadcast std(cond) from rank 0, compute std(cfg) locally\n            if batch.guidance_rescale > 0.0:\n                std_cfg = noise_pred.std(\n                    dim=list(range(1, noise_pred.ndim)), keepdim=True\n                )\n                if cfg_rank == 0:\n                    assert noise_pred_cond is not None\n                    std_text = noise_pred_cond.std(\n                        dim=list(range(1, noise_pred_cond.ndim)), keepdim=True\n                    )\n                else:\n                    std_text = torch.empty_like(std_cfg)\n                # Broadcast std_text from local src=0 to all ranks in CFG group\n                std_text = get_cfg_group().broadcast(std_text, src=0)\n                noise_pred_rescaled = noise_pred * (std_text / std_cfg)\n                noise_pred = (\n                    batch.guidance_rescale * noise_pred_rescaled\n                    + (1 - batch.guidance_rescale) * noise_pred\n                )\n            return noise_pred\n        else:\n            # Serial CFG: both cond and uncond are available locally\n            assert noise_pred_cond is not None and noise_pred_uncond is not None\n            noise_pred = noise_pred_uncond + current_guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n\n            if batch.cfg_normalization and float(batch.cfg_normalization) > 0:\n                factor = float(batch.cfg_normalization)\n                cond_f = noise_pred_cond.float()\n                pred_f = noise_pred.float()\n                ori_norm = torch.linalg.vector_norm(cond_f)\n                new_norm = torch.linalg.vector_norm(pred_f)\n                max_norm = ori_norm * factor\n\n                if 
new_norm > max_norm:\n                    noise_pred = noise_pred * (max_norm / new_norm)\n\n            if batch.guidance_rescale > 0.0:\n                noise_pred = self.rescale_noise_cfg(\n                    noise_pred,\n                    noise_pred_cond,\n                    guidance_rescale=batch.guidance_rescale,\n                )\n            return noise_pred\n\n    def prepare_sta_param(self, batch: Req, server_args: ServerArgs):\n        \"\"\"\n        Prepare Sliding Tile Attention (STA) parameters and settings.\n        \"\"\"\n        # TODO(kevin): STA mask search currently only supports Wan2.1 with 69x768x1280\n        try:\n            STA_mode = STA_Mode[server_args.attention_backend_config.STA_mode]\n        except Exception as e:\n            # STA_mode is unbound here, so log the raw configured value instead\n            logger.error(\n                \"Passed STA_mode: %s doesn't exist\",\n                server_args.attention_backend_config.STA_mode,\n            )\n            raise e\n        skip_time_steps = server_args.attention_backend_config.skip_time_steps\n        if batch.timesteps is None:\n            raise ValueError(\"Timesteps must be provided\")\n        timesteps_num = batch.timesteps.shape[0]\n\n        logger.info(\"STA_mode: %s\", STA_mode)\n        if (batch.num_frames, batch.height, batch.width) != (\n            69,\n            768,\n            1280,\n        ) and STA_mode != \"STA_inference\":\n            raise NotImplementedError(\n                \"STA mask search/tuning is not supported for this resolution\"\n            )\n\n        if (\n            STA_mode == STA_Mode.STA_SEARCHING\n            or STA_mode == STA_Mode.STA_TUNING\n            or STA_mode == STA_Mode.STA_TUNING_CFG\n        ):\n            size = (batch.width, batch.height)\n            if size == (1280, 768):\n                # TODO: make it configurable\n                sparse_mask_candidates_searching = [\n                    \"3, 1, 10\",\n 
                   \"1, 5, 7\",\n                    \"3, 3, 3\",\n                    \"1, 6, 5\",\n                    \"1, 3, 10\",\n                    \"3, 
6, 1\",\n                ]\n                sparse_mask_candidates_tuning = [\n                    \"3, 1, 10\",\n                    \"1, 5, 7\",\n                    \"3, 3, 3\",\n                    \"1, 6, 5\",\n                    \"1, 3, 10\",\n                    \"3, 6, 1\",\n                ]\n                full_mask = [\"3,6,10\"]\n            else:\n                raise NotImplementedError(\n                    \"STA mask search is not supported for this resolution\"\n                )\n        layer_num = self.transformer.config.num_layers\n        # specific for HunyuanVideo\n        if hasattr(self.transformer.config, \"num_single_layers\"):\n            layer_num += self.transformer.config.num_single_layers\n        head_num = self.transformer.config.num_attention_heads\n\n        if STA_mode == STA_Mode.STA_SEARCHING:\n            STA_param = configure_sta(\n                mode=STA_Mode.STA_SEARCHING,\n                layer_num=layer_num,\n                head_num=head_num,\n                time_step_num=timesteps_num,\n                mask_candidates=sparse_mask_candidates_searching + full_mask,\n                # last is full mask; Can add more sparse masks while keep last one as full mask\n            )\n        elif STA_mode == STA_Mode.STA_TUNING:\n            STA_param = configure_sta(\n                mode=STA_Mode.STA_TUNING,\n                layer_num=layer_num,\n                head_num=head_num,\n                time_step_num=timesteps_num,\n                mask_search_files_path=f\"output/mask_search_result_pos_{size[0]}x{size[1]}/\",\n                mask_candidates=sparse_mask_candidates_tuning,\n                full_attention_mask=[int(x) for x in full_mask[0].split(\",\")],\n                skip_time_steps=skip_time_steps,  # Use full attention for first 12 steps\n                save_dir=f\"output/mask_search_strategy_{size[0]}x{size[1]}/\",  # Custom save directory\n                timesteps=timesteps_num,\n            )\n      
  elif STA_mode == STA_Mode.STA_TUNING_CFG:\n            STA_param = configure_sta(\n                mode=STA_Mode.STA_TUNING_CFG,\n                layer_num=layer_num,\n                head_num=head_num,\n                time_step_num=timesteps_num,\n                mask_search_files_path_pos=f\"output/mask_search_result_pos_{size[0]}x{size[1]}/\",\n                mask_search_files_path_neg=f\"output/mask_search_result_neg_{size[0]}x{size[1]}/\",\n                mask_candidates=sparse_mask_candidates_tuning,\n                full_attention_mask=[int(x) for x in full_mask[0].split(\",\")],\n                skip_time_steps=skip_time_steps,\n                save_dir=f\"output/mask_search_strategy_{size[0]}x{size[1]}/\",\n                timesteps=timesteps_num,\n            )\n        elif STA_mode == STA_Mode.STA_INFERENCE:\n            import sglang.multimodal_gen.envs as envs\n\n            config_file = envs.SGLANG_DIFFUSION_ATTENTION_CONFIG\n            if config_file is None:\n                raise ValueError(\"SGLANG_DIFFUSION_ATTENTION_CONFIG is not set\")\n            STA_param = configure_sta(\n                mode=STA_Mode.STA_INFERENCE,\n                layer_num=layer_num,\n                head_num=head_num,\n                time_step_num=timesteps_num,\n                load_path=config_file,\n            )\n\n        batch.STA_param = STA_param\n        batch.mask_search_final_result_pos = [[] for _ in range(timesteps_num)]\n        batch.mask_search_final_result_neg = [[] for _ in range(timesteps_num)]\n\n    def save_sta_search_results(self, batch: Req):\n        \"\"\"\n        Save the STA mask search results.\n\n        Args:\n            batch: The current batch information.\n        \"\"\"\n        size = (batch.width, batch.height)\n        if size == (1280, 768):\n            # TODO: make it configurable\n            sparse_mask_candidates_searching = [\n                \"3, 1, 10\",\n                \"1, 5, 7\",\n                \"3, 3, 
3\",\n                \"1, 6, 5\",\n                \"1, 3, 10\",\n                \"3, 6, 1\",\n            ]\n        else:\n            raise NotImplementedError(\n                \"STA mask search is not supported for this resolution\"\n            )\n\n        if batch.mask_search_final_result_pos is not None and batch.prompt is not None:\n            save_mask_search_results(\n                [dict(layer_data) for layer_data in batch.mask_search_final_result_pos],\n                prompt=str(batch.prompt),\n                mask_strategies=sparse_mask_candidates_searching,\n                output_dir=f\"output/mask_search_result_pos_{size[0]}x{size[1]}/\",\n            )\n        if batch.mask_search_final_result_neg is not None and batch.prompt is not None:\n            save_mask_search_results(\n                [dict(layer_data) for layer_data in batch.mask_search_final_result_neg],\n                prompt=str(batch.prompt),\n                mask_strategies=sparse_mask_candidates_searching,\n                output_dir=f\"output/mask_search_result_neg_{size[0]}x{size[1]}/\",\n            )\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify denoising stage inputs.\"\"\"\n        result = VerificationResult()\n        result.add_check(\"timesteps\", batch.timesteps, [V.is_tensor, V.min_dims(1)])\n        # disable temporarily for image-generation models\n        # result.add_check(\"latents\", batch.latents, [V.is_tensor, V.with_dims(5)])\n        result.add_check(\"prompt_embeds\", batch.prompt_embeds, V.list_not_empty)\n        result.add_check(\"image_embeds\", batch.image_embeds, V.is_list)\n        # result.add_check(\n        #     \"image_latent\", batch.image_latent, V.none_or_tensor_with_dims(5)\n        # )\n        result.add_check(\n            \"num_inference_steps\", batch.num_inference_steps, V.positive_int\n        )\n        result.add_check(\"guidance_scale\", batch.guidance_scale, 
V.non_negative_float)\n        result.add_check(\"eta\", batch.eta, V.non_negative_float)\n        result.add_check(\"generator\", batch.generator, V.generator_or_list_generators)\n        result.add_check(\n            \"do_classifier_free_guidance\",\n            batch.do_classifier_free_guidance,\n            V.bool_value,\n        )\n        result.add_check(\n            \"negative_prompt_embeds\",\n            batch.negative_prompt_embeds,\n            lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x),\n        )\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify denoising stage outputs.\"\"\"\n        result = VerificationResult()\n        # result.add_check(\"latents\", batch.latents, [V.is_tensor, V.with_dims(5)])\n        return result\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py",
    "content": "import copy\nimport math\nimport time\nfrom io import BytesIO\n\nimport av\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom diffusers.models.autoencoders.vae import DiagonalGaussianDistribution\nfrom diffusers.models.modeling_outputs import AutoencoderKLOutput\n\nfrom sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context\nfrom sglang.multimodal_gen.runtime.models.vision_utils import (\n    load_image,\n    normalize,\n    numpy_to_pt,\n    pil_to_numpy,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.denoising import DenoisingStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    StageValidators as V,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\n\nlogger = init_logger(__name__)\n\n\nclass LTX2AVDenoisingStage(DenoisingStage):\n    \"\"\"\n    LTX-2 specific denoising stage that handles joint video and audio generation.\n    \"\"\"\n\n    def __init__(self, transformer, scheduler, vae=None, audio_vae=None, **kwargs):\n        super().__init__(\n            transformer=transformer, scheduler=scheduler, vae=vae, **kwargs\n        )\n        self.audio_vae = audio_vae\n\n    @staticmethod\n    def _get_video_latent_num_frames_for_model(\n        batch: Req, server_args: ServerArgs, latents: torch.Tensor\n    ) -> int:\n        \"\"\"Return the latent-frame length the DiT model should see.\n\n        - If 
video latents were time-sharded for SP and are packed as token latents\n          ([B, S, D]), the model only sees the local shard and must use the local\n          latent-frame count (stored on the batch during SP sharding).\n        - Otherwise, fall back to the global latent-frame count inferred from the\n          requested output frames and the VAE temporal compression ratio.\n        \"\"\"\n        did_sp_shard = bool(getattr(batch, \"did_sp_shard_latents\", False))\n        is_token_latents = isinstance(latents, torch.Tensor) and latents.ndim == 3\n\n        if did_sp_shard and is_token_latents:\n            if not hasattr(batch, \"sp_video_latent_num_frames\"):\n                raise ValueError(\n                    \"SP-sharded LTX2 token latents require `batch.sp_video_latent_num_frames` \"\n                    \"to be set by `LTX2PipelineConfig.shard_latents_for_sp()`.\"\n                )\n            return int(batch.sp_video_latent_num_frames)\n\n        pc = server_args.pipeline_config\n        return int((batch.num_frames - 1) // int(pc.vae_temporal_compression) + 1)\n\n    @staticmethod\n    def _truncate_sp_padded_token_latents(\n        batch: Req, latents: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"Remove token padding introduced by SP time-sharding (if applicable).\"\"\"\n        did_sp_shard = bool(getattr(batch, \"did_sp_shard_latents\", False))\n        if not did_sp_shard or not (\n            isinstance(latents, torch.Tensor) and latents.ndim == 3\n        ):\n            return latents\n\n        raw_shape = getattr(batch, \"raw_latent_shape\", None)\n        if not (isinstance(raw_shape, tuple) and len(raw_shape) == 3):\n            return latents\n\n        orig_s = int(raw_shape[1])\n        cur_s = int(latents.shape[1])\n        if cur_s == orig_s:\n            return latents\n        if cur_s < orig_s:\n            raise ValueError(\n                f\"Unexpected gathered token-latents seq_len {cur_s} < original seq_len 
{orig_s}.\"\n            )\n        return latents[:, :orig_s, :].contiguous()\n\n    def _maybe_enable_cache_dit(self, num_inference_steps: int, batch: Req) -> None:\n        \"\"\"Disable cache-dit for TI2V-style requests (image-conditioned), to avoid stale activations.\n\n        NOTE: base denoising stage calls this hook with (num_inference_steps, batch).\n        \"\"\"\n        if getattr(self, \"_disable_cache_dit_for_request\", False):\n            return\n        return super()._maybe_enable_cache_dit(num_inference_steps, batch)\n\n    @staticmethod\n    def _resize_center_crop(\n        img: PIL.Image.Image, *, width: int, height: int\n    ) -> PIL.Image.Image:\n        return img.resize((width, height), resample=PIL.Image.Resampling.BILINEAR)\n\n    @staticmethod\n    def _apply_video_codec_compression(\n        img_array: np.ndarray, crf: int = 33\n    ) -> np.ndarray:\n        \"\"\"Encode as a single H.264 frame and decode back to simulate compression artifacts.\"\"\"\n        if crf == 0:\n            return img_array\n        height, width = img_array.shape[0] // 2 * 2, img_array.shape[1] // 2 * 2\n        img_array = img_array[:height, :width]\n        buffer = BytesIO()\n        container = av.open(buffer, mode=\"w\", format=\"mp4\")\n        stream = container.add_stream(\n            \"libx264\", rate=1, options={\"crf\": str(crf), \"preset\": \"veryfast\"}\n        )\n        stream.height, stream.width = height, width\n        frame = av.VideoFrame.from_ndarray(img_array, format=\"rgb24\").reformat(\n            format=\"yuv420p\"\n        )\n        container.mux(stream.encode(frame))\n        container.mux(stream.encode())\n        container.close()\n        buffer.seek(0)\n        container = av.open(buffer)\n        decoded = next(container.decode(container.streams.video[0]))\n        container.close()\n        return decoded.to_ndarray(format=\"rgb24\")\n\n    @staticmethod\n    def _resize_center_crop_tensor(\n        img: 
PIL.Image.Image,\n        *,\n        width: int,\n        height: int,\n        device: torch.device,\n        dtype: torch.dtype,\n        apply_codec_compression: bool = True,\n        codec_crf: int = 33,\n    ) -> torch.Tensor:\n        \"\"\"Resize, center-crop, and normalize to [1, C, 1, H, W] tensor in [-1, 1].\"\"\"\n        img_array = np.array(img).astype(np.uint8)[..., :3]\n        if apply_codec_compression:\n            img_array = LTX2AVDenoisingStage._apply_video_codec_compression(\n                img_array, crf=codec_crf\n            )\n        tensor = (\n            torch.from_numpy(img_array.astype(np.float32))\n            .permute(2, 0, 1)\n            .unsqueeze(0)\n            .to(device=device)\n        )\n        src_h, src_w = tensor.shape[2], tensor.shape[3]\n        scale = max(height / src_h, width / src_w)\n        new_h, new_w = math.ceil(src_h * scale), math.ceil(src_w * scale)\n        tensor = torch.nn.functional.interpolate(\n            tensor, size=(new_h, new_w), mode=\"bilinear\", align_corners=False\n        )\n        top, left = (new_h - height) // 2, (new_w - width) // 2\n        tensor = tensor[:, :, top : top + height, left : left + width]\n        return ((tensor / 127.5 - 1.0).to(dtype=dtype)).unsqueeze(2)\n\n    @staticmethod\n    def _pil_to_normed_tensor(img: PIL.Image.Image) -> torch.Tensor:\n        # PIL -> numpy [0,1] -> torch [B,C,H,W], then [-1,1]\n        arr = pil_to_numpy(img)\n        t = numpy_to_pt(arr)\n        return normalize(t)\n\n    @staticmethod\n    def _should_apply_ltx2_ti2v(batch: Req) -> bool:\n        \"\"\"True if we have an image-latent token prefix to condition with.\n\n        SP note: when token latents are time-sharded, only the rank that owns the\n        *global* first latent frame should apply TI2V conditioning (rank with start_frame==0).\n        \"\"\"\n        if (\n            batch.image_latent is None\n            or int(getattr(batch, \"ltx2_num_image_tokens\", 0)) <= 0\n   
     ):\n            return False\n        did_sp_shard = bool(getattr(batch, \"did_sp_shard_latents\", False))\n        if not did_sp_shard:\n            return True\n        return int(getattr(batch, \"sp_video_start_frame\", 0)) == 0\n\n    def _prepare_ltx2_image_latent(self, batch: Req, server_args: ServerArgs) -> None:\n        \"\"\"Encode `batch.image_path` into packed token latents for LTX-2 TI2V.\"\"\"\n        if (\n            batch.image_latent is not None\n            and int(getattr(batch, \"ltx2_num_image_tokens\", 0)) > 0\n        ):\n            return\n        batch.ltx2_num_image_tokens = 0\n        batch.image_latent = None\n\n        if batch.image_path is None:\n            return\n        if batch.width is None or batch.height is None:\n            raise ValueError(\"width/height must be provided for LTX-2 TI2V.\")\n        if self.vae is None:\n            raise ValueError(\"VAE must be provided for LTX-2 TI2V.\")\n\n        image_path = (\n            batch.image_path[0]\n            if isinstance(batch.image_path, list)\n            else batch.image_path\n        )\n\n        img = load_image(image_path)\n        batch.condition_image = self._resize_center_crop(\n            img, width=int(batch.width), height=int(batch.height)\n        )\n\n        latents_device = (\n            batch.latents.device\n            if isinstance(batch.latents, torch.Tensor)\n            else torch.device(\"cpu\")\n        )\n        encode_dtype = batch.latents.dtype\n        original_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]\n        self.vae = self.vae.to(device=latents_device, dtype=encode_dtype)\n        vae_autocast_enabled = (\n            original_dtype != torch.float32\n        ) and not server_args.disable_autocast\n\n        video_condition = self._resize_center_crop_tensor(\n            img,\n            width=int(batch.width),\n            height=int(batch.height),\n            device=latents_device,\n            
dtype=encode_dtype,\n        )\n\n        with torch.autocast(\n            device_type=current_platform.device_type,\n            dtype=original_dtype,\n            enabled=vae_autocast_enabled,\n        ):\n            try:\n                if server_args.pipeline_config.vae_tiling:\n                    self.vae.enable_tiling()\n            except Exception:\n                pass\n            if not vae_autocast_enabled:\n                video_condition = video_condition.to(encode_dtype)\n\n            latent_dist: DiagonalGaussianDistribution = self.vae.encode(video_condition)\n            if isinstance(latent_dist, AutoencoderKLOutput):\n                latent_dist = latent_dist.latent_dist\n\n        mode = server_args.pipeline_config.vae_config.encode_sample_mode()\n        if mode == \"argmax\":\n            latent = latent_dist.mode()\n        elif mode == \"sample\":\n            if batch.generator is None:\n                raise ValueError(\"Generator must be provided for VAE sampling.\")\n            latent = latent_dist.sample(batch.generator)\n        else:\n            raise ValueError(f\"Unsupported encode_sample_mode: {mode}\")\n\n        # Per-channel normalization: normalized = (x - mean) / std\n        mean = self.vae.latents_mean.view(1, -1, 1, 1, 1).to(latent)\n        std = self.vae.latents_std.view(1, -1, 1, 1, 1).to(latent)\n        latent = (latent - mean) / std\n\n        packed = server_args.pipeline_config.maybe_pack_latents(\n            latent, latent.shape[0], batch\n        )\n        if not (isinstance(packed, torch.Tensor) and packed.ndim == 3):\n            raise ValueError(\"Expected packed image latents [B, S0, D].\")\n\n        # Fail-fast token count: must match one latent frame's tokens.\n        vae_sf = int(server_args.pipeline_config.vae_scale_factor)\n        patch = int(server_args.pipeline_config.patch_size)\n        latent_h = int(batch.height) // vae_sf\n        latent_w = int(batch.width) // vae_sf\n        
expected_tokens = (latent_h // patch) * (latent_w // patch)\n        if int(packed.shape[1]) != int(expected_tokens):\n            raise ValueError(\n                \"LTX-2 conditioning token count mismatch: \"\n                f\"{int(packed.shape[1])=} {int(expected_tokens)=}.\"\n            )\n\n        batch.image_latent = packed\n        batch.ltx2_num_image_tokens = int(packed.shape[1])\n\n        if batch.debug:\n            logger.info(\n                \"LTX2 TI2V conditioning prepared: %d tokens (shape=%s) for %sx%s\",\n                batch.ltx2_num_image_tokens,\n                tuple(batch.image_latent.shape),\n                batch.width,\n                batch.height,\n            )\n\n        self.vae.to(original_dtype)\n        if server_args.vae_cpu_offload:\n            self.vae = self.vae.to(\"cpu\")\n\n    @torch.no_grad()\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        \"\"\"\n        Run the denoising loop.\n\n        Args:\n            batch: The current batch information.\n            server_args: The inference arguments.\n\n        Returns:\n            The batch with denoised latents.\n        \"\"\"\n        # Disable cache-dit for image-conditioned requests (TI2V-style) for correctness/debuggability.\n        self._disable_cache_dit_for_request = batch.image_path is not None\n\n        # Prepare variables for the denoising loop\n        prepared_vars = self._prepare_denoising_loop(batch, server_args)\n        extra_step_kwargs = prepared_vars[\"extra_step_kwargs\"]\n        target_dtype = prepared_vars[\"target_dtype\"]\n        autocast_enabled = prepared_vars[\"autocast_enabled\"]\n        timesteps = prepared_vars[\"timesteps\"]\n        num_inference_steps = prepared_vars[\"num_inference_steps\"]\n        num_warmup_steps = prepared_vars[\"num_warmup_steps\"]\n        image_kwargs = prepared_vars[\"image_kwargs\"]\n        pos_cond_kwargs = prepared_vars[\"pos_cond_kwargs\"]\n        neg_cond_kwargs 
= prepared_vars[\"neg_cond_kwargs\"]\n        latents = prepared_vars[\"latents\"]\n        boundary_timestep = prepared_vars[\"boundary_timestep\"]\n        z = prepared_vars[\"z\"]\n        reserved_frames_mask = prepared_vars[\"reserved_frames_mask\"]\n        seq_len = prepared_vars[\"seq_len\"]\n        guidance = prepared_vars[\"guidance\"]\n\n        audio_latents = batch.audio_latents\n        audio_scheduler = copy.deepcopy(self.scheduler)\n\n        # Prepare TI2V conditioning once (encode image -> patchify tokens).\n        self._prepare_ltx2_image_latent(batch, server_args)\n\n        # For LTX-2 packed token latents, SP sharding happens on the time dimension\n        # (frames). The model must see local latent frames (RoPE offset is applied\n        # inside the model using SP rank).\n        latent_num_frames_for_model = self._get_video_latent_num_frames_for_model(\n            batch=batch, server_args=server_args, latents=latents\n        )\n        latent_height = batch.height // server_args.pipeline_config.vae_scale_factor\n        latent_width = batch.width // server_args.pipeline_config.vae_scale_factor\n\n        # Initialize lists for ODE trajectory\n        trajectory_timesteps: list[torch.Tensor] = []\n        trajectory_latents: list[torch.Tensor] = []\n        trajectory_audio_latents: list[torch.Tensor] = []\n\n        # Run denoising loop\n        denoising_start_time = time.time()\n\n        # to avoid device-sync caused by timestep comparison\n        is_warmup = batch.is_warmup\n        self.scheduler.set_begin_index(0)\n        audio_scheduler.set_begin_index(0)\n        timesteps_cpu = timesteps.cpu()\n        num_timesteps = timesteps_cpu.shape[0]\n\n        do_ti2v = self._should_apply_ltx2_ti2v(batch)\n        num_img_tokens = int(getattr(batch, \"ltx2_num_image_tokens\", 0))\n        denoise_mask = None\n        clean_latent = None\n        if do_ti2v:\n            if not (isinstance(latents, torch.Tensor) and latents.ndim == 
3):\n                raise ValueError(\"LTX-2 TI2V expects packed token latents [B, S, D].\")\n            latents[:, :num_img_tokens, :] = batch.image_latent[\n                :, :num_img_tokens, :\n            ].to(device=latents.device, dtype=latents.dtype)\n            denoise_mask = torch.ones(\n                (latents.shape[0], latents.shape[1], 1),\n                device=latents.device,\n                dtype=torch.float32,\n            )\n            denoise_mask[:, :num_img_tokens, :] = 0.0\n            clean_latent = latents.detach().clone()\n            clean_latent[:, :num_img_tokens, :] = batch.image_latent[\n                :, :num_img_tokens, :\n            ].to(device=latents.device, dtype=latents.dtype)\n\n        with torch.autocast(\n            device_type=current_platform.device_type,\n            dtype=target_dtype,\n            enabled=autocast_enabled,\n        ):\n            with self.progress_bar(total=num_inference_steps) as progress_bar:\n                for i, t_host in enumerate(timesteps_cpu):\n                    with StageProfiler(\n                        f\"denoising_step_{i}\",\n                        logger=logger,\n                        metrics=batch.metrics,\n                        perf_dump_path_provided=batch.perf_dump_path is not None,\n                    ):\n                        t_int = int(t_host.item())\n                        t_device = timesteps[i]\n                        current_model, current_guidance_scale = (\n                            self._select_and_manage_model(\n                                t_int=t_int,\n                                boundary_timestep=boundary_timestep,\n                                server_args=server_args,\n                                batch=batch,\n                            )\n                        )\n\n                        # Predict noise residual\n                        attn_metadata = self._build_attn_metadata(i, batch, server_args)\n\n                    
    # === LTX-2 sigma-space Euler step (flow matching) ===\n                        # Use scheduler-generated sigmas (includes terminal sigma=0).\n                        sigmas = getattr(self.scheduler, \"sigmas\", None)\n                        if sigmas is None or not isinstance(sigmas, torch.Tensor):\n                            raise ValueError(\n                                \"Expected scheduler.sigmas to be a tensor for LTX-2.\"\n                            )\n                        sigma = sigmas[i].to(device=latents.device, dtype=torch.float32)\n                        sigma_next = sigmas[i + 1].to(\n                            device=latents.device, dtype=torch.float32\n                        )\n                        dt = sigma_next - sigma\n\n                        latent_model_input = latents.to(target_dtype)\n                        audio_latent_model_input = audio_latents.to(target_dtype)\n\n                        latent_num_frames = latent_num_frames_for_model\n\n                        # Audio latent dims\n                        if audio_latent_model_input.ndim == 3:\n                            audio_num_frames_latent = int(\n                                audio_latent_model_input.shape[1]\n                            )\n                        elif audio_latent_model_input.ndim == 4:\n                            audio_num_frames_latent = int(\n                                audio_latent_model_input.shape[2]\n                            )\n                        else:\n                            raise ValueError(\n                                f\"Unexpected audio latents rank: {audio_latent_model_input.ndim}, shape={tuple(audio_latent_model_input.shape)}\"\n                            )\n\n                        # LTX-2 model can generate coords internally.\n                        video_coords = None\n                        audio_coords = None\n\n                        timestep = 
t_device.expand(int(latent_model_input.shape[0]))\n                        if do_ti2v and denoise_mask is not None:\n                            timestep_video = timestep.unsqueeze(\n                                -1\n                            ) * denoise_mask.squeeze(-1)\n                        else:\n                            timestep_video = timestep\n                        timestep_audio = timestep\n\n                        # Conditions\n                        encoder_hidden_states = batch.prompt_embeds[0]\n                        audio_encoder_hidden_states = batch.audio_prompt_embeds[0]\n                        encoder_attention_mask = batch.prompt_attention_mask\n\n                        # Follow ltx-pipelines structure: separate pos/neg forward passes,\n                        # then apply CFG on denoised (x0) predictions.\n                        with set_forward_context(\n                            current_timestep=i, attn_metadata=attn_metadata\n                        ):\n                            v_pos, a_v_pos = current_model(\n                                hidden_states=latent_model_input,\n                                audio_hidden_states=audio_latent_model_input,\n                                encoder_hidden_states=encoder_hidden_states,\n                                audio_encoder_hidden_states=audio_encoder_hidden_states,\n                                timestep=timestep_video,\n                                audio_timestep=timestep_audio,\n                                encoder_attention_mask=encoder_attention_mask,\n                                audio_encoder_attention_mask=encoder_attention_mask,\n                                num_frames=latent_num_frames,\n                                height=latent_height,\n                                width=latent_width,\n                                fps=batch.fps,\n                                audio_num_frames=audio_num_frames_latent,\n                                
video_coords=video_coords,\n                                audio_coords=audio_coords,\n                                return_latents=False,\n                                return_dict=False,\n                            )\n\n                            if batch.do_classifier_free_guidance:\n                                neg_encoder_hidden_states = (\n                                    batch.negative_prompt_embeds[0]\n                                )\n                                neg_audio_encoder_hidden_states = (\n                                    batch.negative_audio_prompt_embeds[0]\n                                )\n                                neg_encoder_attention_mask = (\n                                    batch.negative_attention_mask\n                                )\n\n                                v_neg, a_v_neg = current_model(\n                                    hidden_states=latent_model_input,\n                                    audio_hidden_states=audio_latent_model_input,\n                                    encoder_hidden_states=neg_encoder_hidden_states,\n                                    audio_encoder_hidden_states=neg_audio_encoder_hidden_states,\n                                    timestep=timestep_video,\n                                    audio_timestep=timestep_audio,\n                                    encoder_attention_mask=neg_encoder_attention_mask,\n                                    audio_encoder_attention_mask=neg_encoder_attention_mask,\n                                    num_frames=latent_num_frames,\n                                    height=latent_height,\n                                    width=latent_width,\n                                    fps=batch.fps,\n                                    audio_num_frames=audio_num_frames_latent,\n                                    video_coords=video_coords,\n                                    audio_coords=audio_coords,\n                                  
  return_latents=False,\n                                    return_dict=False,\n                                )\n                            else:\n                                v_neg = None\n                                a_v_neg = None\n\n                        v_pos = v_pos.float()\n                        a_v_pos = a_v_pos.float()\n                        if v_neg is not None:\n                            v_neg = v_neg.float()\n                        if a_v_neg is not None:\n                            a_v_neg = a_v_neg.float()\n\n                        # Velocity -> denoised (x0): x0 = x - sigma * v\n                        sigma_val = float(sigma.item())\n                        denoised_video = (latents.float() - sigma_val * v_pos).to(\n                            latents.dtype\n                        )\n                        denoised_audio = (\n                            audio_latents.float() - sigma_val * a_v_pos\n                        ).to(audio_latents.dtype)\n\n                        if (\n                            batch.do_classifier_free_guidance\n                            and v_neg is not None\n                            and a_v_neg is not None\n                        ):\n                            denoised_video_neg = (\n                                latents.float() - sigma_val * v_neg\n                            ).to(latents.dtype)\n                            denoised_audio_neg = (\n                                audio_latents.float() - sigma_val * a_v_neg\n                            ).to(audio_latents.dtype)\n                            denoised_video = denoised_video + (\n                                batch.guidance_scale - 1.0\n                            ) * (denoised_video - denoised_video_neg)\n                            denoised_audio = denoised_audio + (\n                                batch.guidance_scale - 1.0\n                            ) * (denoised_audio - denoised_audio_neg)\n\n                        
# Apply conditioning mask (keep conditioned tokens clean).\n                        if (\n                            do_ti2v\n                            and denoise_mask is not None\n                            and clean_latent is not None\n                        ):\n                            denoised_video = (\n                                denoised_video * denoise_mask\n                                + clean_latent.float() * (1.0 - denoise_mask)\n                            )\n\n                        # Euler step in sigma space: x_next = x + (sigma_next - sigma) * v,\n                        # where v = (x - x0) / sigma.\n                        if sigma_val == 0.0:\n                            v_video = torch.zeros_like(denoised_video)\n                            v_audio = torch.zeros_like(denoised_audio)\n                        else:\n                            v_video = (\n                                (latents.float() - denoised_video.float()) / sigma_val\n                            ).to(latents.dtype)\n                            v_audio = (\n                                (audio_latents.float() - denoised_audio.float())\n                                / sigma_val\n                            ).to(audio_latents.dtype)\n\n                        latents = (latents.float() + v_video.float() * dt).to(\n                            dtype=latents.dtype\n                        )\n                        audio_latents = (\n                            audio_latents.float() + v_audio.float() * dt\n                        ).to(dtype=audio_latents.dtype)\n\n                        if do_ti2v:\n                            latents[:, :num_img_tokens, :] = batch.image_latent[\n                                :, :num_img_tokens, :\n                            ].to(device=latents.device, dtype=latents.dtype)\n\n                        latents = self.post_forward_for_ti2v_task(\n                            batch, server_args, reserved_frames_mask, latents, 
z\n                        )\n\n                        # save trajectory latents if needed\n                        if batch.return_trajectory_latents:\n                            trajectory_timesteps.append(t_host)\n                            trajectory_latents.append(latents)\n                            if audio_latents is not None:\n                                trajectory_audio_latents.append(audio_latents)\n\n                        # Update progress bar\n                        if i == num_timesteps - 1 or (\n                            (i + 1) > num_warmup_steps\n                            and (i + 1) % self.scheduler.order == 0\n                            and progress_bar is not None\n                        ):\n                            progress_bar.update()\n\n                        if not is_warmup:\n                            self.step_profile()\n\n        denoising_end_time = time.time()\n\n        if num_timesteps > 0 and not is_warmup:\n            self.log_info(\n                \"average time per step: %.4f seconds\",\n                (denoising_end_time - denoising_start_time) / len(timesteps),\n            )\n\n        batch.audio_latents = audio_latents\n        self._post_denoising_loop(\n            batch=batch,\n            latents=latents,\n            trajectory_latents=trajectory_latents,\n            trajectory_timesteps=trajectory_timesteps,\n            trajectory_audio_latents=trajectory_audio_latents,\n            server_args=server_args,\n            is_warmup=is_warmup,\n        )\n\n        return batch\n\n    def _post_denoising_loop(\n        self,\n        batch: Req,\n        latents: torch.Tensor,\n        trajectory_latents: list,\n        trajectory_timesteps: list,\n        trajectory_audio_latents: list,\n        server_args: ServerArgs,\n        is_warmup: bool = False,\n    ):\n        # 1. 
Handle trajectory (video): copied from the base class\n        if trajectory_latents:\n            trajectory_tensor = torch.stack(trajectory_latents, dim=1)\n            trajectory_timesteps_tensor = torch.stack(trajectory_timesteps, dim=0)\n        else:\n            trajectory_tensor = None\n            trajectory_timesteps_tensor = None\n\n        latents, trajectory_tensor = self._postprocess_sp_latents(\n            batch, latents, trajectory_tensor\n        )\n\n        # If SP time-sharding padded whole frames' worth of tokens, remove the padding\n        # after gather and before unpacking.\n        latents = self._truncate_sp_padded_token_latents(batch, latents)\n\n        if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:\n            batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()\n            batch.trajectory_latents = trajectory_tensor.cpu()\n\n        # 2. Handle trajectory (audio): LTX-2 specific\n        if trajectory_audio_latents:\n            trajectory_audio_tensor = torch.stack(trajectory_audio_latents, dim=1)\n            # SP gather is not applied to audio latents (it is unclear whether it is needed).\n            batch.trajectory_audio_latents = trajectory_audio_tensor.cpu()\n\n        # 3. Unpack and denormalize via pipeline_config._unpad_and_unpack_latents.\n        # `latents` holds the video latents; `batch.audio_latents` holds the audio latents.\n\n        audio_latents = batch.audio_latents\n\n        # NOTE: self.vae and self.audio_vae should be populated via __init__ or set manually\n        if self.vae is None or self.audio_vae is None:\n            logger.warning(\n                \"VAE or Audio VAE not found in DenoisingStage. 
Skipping unpack and denormalize.\"\n            )\n            batch.latents = latents\n            batch.audio_latents = audio_latents\n        else:\n            latents, audio_latents = (\n                server_args.pipeline_config._unpad_and_unpack_latents(\n                    latents, audio_latents, batch, self.vae, self.audio_vae\n                )\n            )\n\n            batch.latents = latents\n            batch.audio_latents = audio_latents\n\n        if isinstance(self.transformer, OffloadableDiTMixin):\n            for manager in self.transformer.layerwise_offload_managers:\n                manager.release_all()\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify denoising stage inputs.\n\n        Note: LTX-2 connector stage converts `prompt_embeds`/`negative_prompt_embeds`\n        from list-of-tensors to a single tensor (video context) and stores audio\n        context separately.\n        \"\"\"\n\n        result = VerificationResult()\n        result.add_check(\"timesteps\", batch.timesteps, [V.is_tensor, V.min_dims(1)])\n\n        # LTX-2 may carry prompt embeddings as either a tensor (preferred) or legacy list.\n        result.add_check(\n            \"prompt_embeds\",\n            batch.prompt_embeds,\n            lambda x: V.is_tensor(x) or V.list_not_empty(x),\n        )\n\n        # Keep base expectation: image_embeds is always a list (may be empty).\n        result.add_check(\"image_embeds\", batch.image_embeds, V.is_list)\n\n        result.add_check(\n            \"num_inference_steps\", batch.num_inference_steps, V.positive_int\n        )\n        result.add_check(\"guidance_scale\", batch.guidance_scale, V.non_negative_float)\n        result.add_check(\"eta\", batch.eta, V.non_negative_float)\n        result.add_check(\"generator\", batch.generator, V.generator_or_list_generators)\n        result.add_check(\n            \"do_classifier_free_guidance\",\n            
batch.do_classifier_free_guidance,\n            V.bool_value,\n        )\n\n        # When CFG is enabled, negative prompt embeddings must exist (tensor or legacy list).\n        result.add_check(\n            \"negative_prompt_embeds\",\n            batch.negative_prompt_embeds,\n            lambda x: (not batch.do_classifier_free_guidance)\n            or V.is_tensor(x)\n            or V.list_not_empty(x),\n        )\n        return result\n\n    def do_classifier_free_guidance(self, batch: Req) -> bool:\n        return batch.guidance_scale > 1.0\n\n\nclass LTX2RefinementStage(LTX2AVDenoisingStage):\n    def __init__(\n        self, transformer, scheduler, distilled_sigmas, vae=None, audio_vae=None\n    ):\n        super().__init__(transformer, scheduler, vae, audio_vae)\n        self.distilled_sigmas = torch.tensor(distilled_sigmas)\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        # 1. Add noise to latents\n        noise_scale = self.distilled_sigmas[0].to(batch.latents.device)\n        noise = torch.randn_like(batch.latents)\n        batch.latents = batch.latents + noise * noise_scale\n\n        # 2. 
Run denoising loop with distilled_sigmas\n        # Save original sigmas\n        original_sigmas = self.scheduler.sigmas\n        original_timesteps = self.scheduler.timesteps\n        original_num_inference_steps = self.scheduler.num_inference_steps\n\n        # Set distilled sigmas\n        self.scheduler.sigmas = self.distilled_sigmas.to(self.scheduler.sigmas.device)\n        # Approximation for timesteps\n        self.scheduler.timesteps = self.scheduler.sigmas * 1000\n        self.scheduler.num_inference_steps = len(self.distilled_sigmas) - 1\n\n        # Call parent forward\n        try:\n            batch = super().forward(batch, server_args)\n        finally:\n            # Restore original sigmas\n            self.scheduler.sigmas = original_sigmas\n            self.scheduler.timesteps = original_timesteps\n            self.scheduler.num_inference_steps = original_num_inference_steps\n\n        return batch\n\n    def do_classifier_free_guidance(self, batch: Req) -> bool:\n        return False  # Stage 2 uses simple denoising (no CFG)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nimport time\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context\nfrom sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_match_euler_discrete import (\n    FlowMatchEulerDiscreteScheduler,\n)\nfrom sglang.multimodal_gen.runtime.models.utils import pred_noise_to_pred_video\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages import DenoisingStage\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler\nfrom sglang.multimodal_gen.utils import dict_to_3d_list\n\nlogger = init_logger(__name__)\n\n\nclass DmdDenoisingStage(DenoisingStage):\n    \"\"\"\n    Denoising stage for DMD.\n    \"\"\"\n\n    def __init__(self, transformer, scheduler, transformer_2=None) -> None:\n        super().__init__(\n            transformer=transformer, scheduler=scheduler, transformer_2=transformer_2\n        )\n        self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0)\n\n    def _preprocess_sp_latents(self, batch: Req, server_args: ServerArgs):\n        # 1. to shard latents (B, C, T, H, W) along dim 2\n        super()._preprocess_sp_latents(batch, server_args)\n\n        # 2. 
DMD expects (B, T, C, H, W) for the main latents in the loop\n        if batch.latents is not None:\n            batch.latents = batch.latents.permute(0, 2, 1, 3, 4)\n\n        # Note: batch.image_latent is kept as (B, C, T, H, W) here\n\n    def _postprocess_sp_latents(\n        self,\n        batch: Req,\n        latents: torch.Tensor,\n        trajectory_tensor: torch.Tensor | None,\n    ) -> tuple[torch.Tensor, torch.Tensor | None]:\n        # 1. convert back from DMD's (B, T, C, H, W) to standard (B, C, T, H, W)\n        # this is because base gather_latents_for_sp expects dim=2 for T\n        latents = latents.permute(0, 2, 1, 3, 4)\n\n        # 2. use base method to gather\n        return super()._postprocess_sp_latents(batch, latents, trajectory_tensor)\n\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n        \"\"\"\n        Run the denoising loop.\n        \"\"\"\n        prepared_vars = self._prepare_denoising_loop(batch, server_args)\n\n        target_dtype = prepared_vars[\"target_dtype\"]\n        autocast_enabled = prepared_vars[\"autocast_enabled\"]\n        num_warmup_steps = prepared_vars[\"num_warmup_steps\"]\n        latents = prepared_vars[\"latents\"]\n        video_raw_latent_shape = latents.shape\n\n        timesteps = torch.tensor(\n            server_args.pipeline_config.dmd_denoising_steps,\n            dtype=torch.long,\n            device=get_local_torch_device(),\n        )\n\n        # prepare image_kwargs\n        image_embeds = batch.image_embeds\n        if len(image_embeds) > 0:\n            image_embeds = [img.to(target_dtype) for img in image_embeds]\n\n        image_kwargs = self.prepare_extra_func_kwargs(\n            self.transformer.forward,\n            {\n                \"encoder_hidden_states_image\": image_embeds,\n                \"mask_strategy\": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24),\n            },\n        )\n\n        pos_cond_kwargs = 
prepared_vars[\"pos_cond_kwargs\"]\n\n        denoising_loop_start_time = time.time()\n        with self.progress_bar(total=len(timesteps)) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # Skip if interrupted\n                if hasattr(self, \"interrupt\") and self.interrupt:\n                    continue\n\n                with StageProfiler(\n                    f\"denoising_step_{i}\",\n                    logger=logger,\n                    metrics=batch.metrics,\n                    perf_dump_path_provided=batch.perf_dump_path is not None,\n                ):\n                    t_int = int(t.item())\n                    if self.transformer_2 is not None:\n                        current_model, current_guidance_scale = (\n                            self._select_and_manage_model(\n                                t_int=t_int,\n                                boundary_timestep=self._handle_boundary_ratio(\n                                    server_args, batch\n                                ),\n                                server_args=server_args,\n                                batch=batch,\n                            )\n                        )\n                    else:\n                        current_model = self.transformer\n                        self._manage_device_placement(current_model, None, server_args)\n                    # Expand latents for I2V\n                    noise_latents = latents.clone()\n                    latent_model_input = latents.to(target_dtype)\n\n                    if batch.image_latent is not None:\n                        latent_model_input = torch.cat(\n                            [\n                                latent_model_input,\n                                batch.image_latent.permute(0, 2, 1, 3, 4),\n                            ],\n                            dim=2,\n                        ).to(target_dtype)\n                    assert not torch.isnan(\n               
         latent_model_input\n                    ).any(), \"latent_model_input contains nan\"\n\n                    # Prepare inputs for transformer\n                    t_expand = t.repeat(latent_model_input.shape[0])\n\n                    guidance_expand = self.get_or_build_guidance(\n                        latent_model_input.shape[0],\n                        target_dtype,\n                        get_local_torch_device(),\n                    )\n\n                    # Predict noise residual\n                    with torch.autocast(\n                        device_type=current_platform.device_type,\n                        dtype=target_dtype,\n                        enabled=autocast_enabled,\n                    ):\n                        attn_metadata = self._build_attn_metadata(i, batch, server_args)\n\n                        batch.is_cfg_negative = False\n                        with set_forward_context(\n                            current_timestep=i,\n                            attn_metadata=attn_metadata,\n                            forward_batch=batch,\n                        ):\n                            # Run transformer\n                            pred_noise = current_model(\n                                hidden_states=latent_model_input.permute(0, 2, 1, 3, 4),\n                                timestep=t_expand,\n                                guidance=guidance_expand,\n                                **image_kwargs,\n                                **pos_cond_kwargs,\n                            ).permute(0, 2, 1, 3, 4)\n\n                        pred_video = pred_noise_to_pred_video(\n                            pred_noise=pred_noise.flatten(0, 1),\n                            noise_input_latent=noise_latents.flatten(0, 1),\n                            timestep=t_expand,\n                            scheduler=self.scheduler,\n                        ).unflatten(0, pred_noise.shape[:2])\n\n                        if i < len(timesteps) 
- 1:\n                            next_timestep = timesteps[i + 1] * torch.ones(\n                                [1], dtype=torch.long, device=pred_video.device\n                            )\n                            noise = torch.randn(\n                                video_raw_latent_shape,\n                                dtype=pred_video.dtype,\n                                generator=batch.generator[0],\n                                device=self.device,\n                            )\n                            latents = self.scheduler.add_noise(\n                                pred_video.flatten(0, 1),\n                                noise.flatten(0, 1),\n                                next_timestep,\n                            ).unflatten(0, pred_video.shape[:2])\n                        else:\n                            latents = pred_video\n\n                        # Update progress bar\n                        if i == len(timesteps) - 1 or (\n                            (i + 1) > num_warmup_steps\n                            and (i + 1) % self.scheduler.order == 0\n                            and progress_bar is not None\n                        ):\n                            progress_bar.update()\n\n                    self.step_profile()\n\n        denoising_loop_end_time = time.time()\n        if len(timesteps) > 0:\n            self.log_info(\n                \"average time per step: %.4f seconds\",\n                (denoising_loop_end_time - denoising_loop_start_time) / len(timesteps),\n            )\n\n        self._post_denoising_loop(\n            batch=batch,\n            latents=latents,\n            trajectory_latents=[],\n            trajectory_timesteps=[],\n            server_args=server_args,\n        )\n\n        return batch\n\n    def _select_and_manage_model(\n        self,\n        t_int: int,\n        boundary_timestep: float | None,\n        server_args: ServerArgs,\n        batch: Req,\n    ):\n        if 
boundary_timestep is None or t_int >= boundary_timestep:\n            # High-noise stage\n            current_model = self.transformer\n            model_to_offload = self.transformer_2\n            current_guidance_scale = batch.guidance_scale\n        else:\n            # Low-noise stage\n            current_model = self.transformer_2\n            model_to_offload = self.transformer\n            current_guidance_scale = batch.guidance_scale_2\n\n        self._manage_device_placement(current_model, model_to_offload, server_args)\n\n        assert current_model is not None, \"The model for the current step is not set.\"\n        return current_model, current_guidance_scale\n\n    def _manage_device_placement(\n        self,\n        model_to_use: torch.nn.Module,\n        model_to_offload: torch.nn.Module | None,\n        server_args: ServerArgs,\n    ):\n        \"\"\"\n        Manages the offload / load behavior of dit\n        \"\"\"\n        if not server_args.dit_cpu_offload:\n            return\n\n        # Offload the unused model if it's on CUDA\n        if (\n            model_to_offload is not None\n            and next(model_to_offload.parameters()).device.type == \"cuda\"\n        ):\n            model_to_offload.to(\"cpu\")\n\n        # Load the model to use if it's on CPU\n        if (\n            model_to_use is not None\n            and next(model_to_use.parameters()).device.type == \"cpu\"\n        ):\n            model_to_use.to(get_local_torch_device())\n\n    def _handle_boundary_ratio(\n        self,\n        server_args,\n        batch,\n    ):\n        \"\"\"\n        (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert\n        \"\"\"\n        boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio\n        if batch.boundary_ratio is not None:\n            logger.info(\n                \"Overriding boundary ratio from %s to %s\",\n                boundary_ratio,\n                
batch.boundary_ratio,\n            )\n            boundary_ratio = batch.boundary_ratio\n\n        if boundary_ratio is not None:\n            boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps\n        else:\n            boundary_timestep = None\n\n        return boundary_timestep\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/encoding.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nEncoding stage for diffusion pipelines.\n\"\"\"\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    V,  # Import validators\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\n\nlogger = init_logger(__name__)\n\n\nclass EncodingStage(PipelineStage):\n    \"\"\"\n    Stage for encoding pixel space representations into latent space.\n\n    This stage handles the encoding of pixel-space video/images into latent\n    representations for further processing in the diffusion pipeline.\n    \"\"\"\n\n    def __init__(self, vae: ParallelTiledVAE) -> None:\n        super().__init__()\n        self.vae: ParallelTiledVAE = vae\n\n    @torch.no_grad()\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify encoding stage inputs.\"\"\"\n        result = VerificationResult()\n        # Input video/images for VAE encoding: [batch_size, channels, frames, height, width]\n        result.add_check(\"latents\", batch.latents, [V.is_tensor, V.with_dims(5)])\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify encoding stage outputs.\"\"\"\n 
       result = VerificationResult()\n        # Encoded latents: [batch_size, channels, frames, height_latents, width_latents]\n        result.add_check(\"latents\", batch.latents, [V.is_tensor, V.with_dims(5)])\n        return result\n\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n        \"\"\"\n        Encode pixel space representations into latent space.\n\n        Returns:\n            The batch with encoded latents.\n        \"\"\"\n        assert batch.latents is not None and isinstance(batch.latents, torch.Tensor)\n\n        self.vae = self.vae.to(get_local_torch_device())\n\n        # Setup VAE precision\n        vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]\n        vae_autocast_enabled = (\n            vae_dtype != torch.float32\n        ) and not server_args.disable_autocast\n\n        # Normalize input to [-1, 1] range (reverse of decoding normalization)\n        latents = (batch.latents * 2.0 - 1.0).clamp(-1, 1)\n\n        # Move to appropriate device and dtype\n        latents = latents.to(get_local_torch_device())\n\n        # Encode image to latents\n        with torch.autocast(\n            device_type=current_platform.device_type,\n            dtype=vae_dtype,\n            enabled=vae_autocast_enabled,\n        ):\n            if server_args.pipeline_config.vae_tiling:\n                self.vae.enable_tiling()\n            # if server_args.vae_sp:\n            #     self.vae.enable_parallel()\n            if not vae_autocast_enabled:\n                latents = latents.to(vae_dtype)\n            latents = self.vae.encode(latents).mean\n\n        # Update batch with encoded latents\n        batch.latents = latents\n\n        # Offload models if needed\n        self.maybe_free_model_hooks()\n\n        if server_args.vae_cpu_offload:\n            self.vae.to(\"cpu\")\n\n        return batch\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_paint.py",
    "content": "\"\"\"\nHunyuan3D paint/texture generation stages.\n\nThree-stage pipeline: Preprocess -> TexGen -> Postprocess.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nfrom typing import Any\n\nimport numpy as np\nimport torch\nfrom diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (\n    retrieve_timesteps,\n)\nfrom einops import rearrange\n\nfrom sglang.multimodal_gen.configs.pipeline_configs.hunyuan3d import (\n    Hunyuan3D2PipelineConfig,\n)\nfrom sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import (\n    PipelineStage,\n    StageParallelismType,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    StageValidators as V,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n# Utility functions\ndef guidance_scale_embedding(\n    w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32\n) -> torch.Tensor:\n    \"\"\"Generate guidance scale embeddings.\"\"\"\n    assert len(w.shape) == 1\n    w = w * 1000.0\n\n    half_dim = embedding_dim // 2\n    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n    emb = w.to(dtype)[:, None] * emb[None, :]\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n    if embedding_dim % 2 == 1:\n        emb = torch.nn.functional.pad(emb, (0, 1))\n    assert emb.shape == (w.shape[0], embedding_dim)\n    return emb\n\n\ndef extract_into_tensor(\n    a: torch.Tensor, t: torch.Tensor, x_shape: tuple, n_gen: int\n) -> torch.Tensor:\n    
\"\"\"Extract values from tensor and reshape for multi-view generation.\"\"\"\n    out = a.gather(-1, t)\n    out = out.repeat(n_gen)\n    out = rearrange(out, \"(b n) -> b n\", n=n_gen)\n    b, c, *_ = out.shape\n    return out.reshape(b, c, *((1,) * (len(x_shape) - 2)))\n\n\ndef get_predicted_original_sample(\n    model_output: torch.Tensor,\n    timesteps: torch.Tensor,\n    sample: torch.Tensor,\n    prediction_type: str,\n    alphas: torch.Tensor,\n    sigmas: torch.Tensor,\n    n_gen: int,\n) -> torch.Tensor:\n    \"\"\"Get predicted original sample from model output.\"\"\"\n    alphas = extract_into_tensor(alphas, timesteps, sample.shape, n_gen)\n    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape, n_gen)\n    model_output = rearrange(model_output, \"(b n) c h w -> b n c h w\", n=n_gen)\n\n    if prediction_type == \"epsilon\":\n        pred_x_0 = (sample - sigmas * model_output) / alphas\n    elif prediction_type == \"sample\":\n        pred_x_0 = model_output\n    elif prediction_type == \"v_prediction\":\n        pred_x_0 = alphas * sample - sigmas * model_output\n    else:\n        raise ValueError(\n            f\"Prediction type {prediction_type} is not supported; \"\n            \"currently, `epsilon`, `sample`, and `v_prediction` are supported.\"\n        )\n\n    return pred_x_0\n\n\ndef get_predicted_noise(\n    model_output: torch.Tensor,\n    timesteps: torch.Tensor,\n    sample: torch.Tensor,\n    prediction_type: str,\n    alphas: torch.Tensor,\n    sigmas: torch.Tensor,\n    n_gen: int,\n) -> torch.Tensor:\n    \"\"\"Get predicted noise from model output.\"\"\"\n    alphas = extract_into_tensor(alphas, timesteps, sample.shape, n_gen)\n    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape, n_gen)\n    model_output = rearrange(model_output, \"(b n) c h w -> b n c h w\", n=n_gen)\n\n    if prediction_type == \"epsilon\":\n        pred_epsilon = model_output\n    elif prediction_type == \"sample\":\n        pred_epsilon = 
(sample - alphas * model_output) / sigmas\n    elif prediction_type == \"v_prediction\":\n        pred_epsilon = alphas * model_output + sigmas * sample\n    else:\n        raise ValueError(\n            f\"Prediction type {prediction_type} is not supported; \"\n            \"currently, `epsilon`, `sample`, and `v_prediction` are supported.\"\n        )\n\n    return pred_epsilon\n\n\ndef to_rgb_image(maybe_rgba):\n    \"\"\"Convert RGBA image to RGB.\"\"\"\n    from PIL import Image\n\n    if maybe_rgba.mode == \"RGB\":\n        return maybe_rgba\n    if maybe_rgba.mode == \"RGBA\":\n        rgba = maybe_rgba\n        img = np.random.randint(\n            127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=np.uint8\n        )\n        img = Image.fromarray(img, \"RGB\")\n        img.paste(rgba, mask=rgba.getchannel(\"A\"))\n        return img\n    raise ValueError(f\"Unsupported image type: {maybe_rgba.mode}\")\n\n\nclass DDIMSolver:\n    \"\"\"DDIM solver for fast sampling.\"\"\"\n\n    def __init__(\n        self,\n        alpha_cumprods: np.ndarray,\n        timesteps: int = 1000,\n        ddim_timesteps: int = 50,\n    ):\n        step_ratio = timesteps // ddim_timesteps\n        self.ddim_timesteps = (\n            np.arange(1, ddim_timesteps + 1) * step_ratio\n        ).round().astype(np.int64) - 1\n        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]\n        self.ddim_alpha_cumprods_prev = np.asarray(\n            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()\n        )\n        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()\n        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)\n        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)\n\n    def to(self, device: torch.device) -> \"DDIMSolver\":\n        self.ddim_timesteps = self.ddim_timesteps.to(device)\n        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)\n        
self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)\n        return self\n\n    def ddim_step(\n        self,\n        pred_x0: torch.Tensor,\n        pred_noise: torch.Tensor,\n        timestep_index: torch.Tensor,\n        n_gen: int,\n    ) -> torch.Tensor:\n        alpha_cumprod_prev = extract_into_tensor(\n            self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape, n_gen\n        )\n        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise\n        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt\n        return x_prev\n\n\ndef _recorrect_rgb(\n    src_image: torch.Tensor,\n    target_image: torch.Tensor,\n    alpha_channel: torch.Tensor,\n    scale: float = 0.95,\n) -> torch.Tensor:\n    \"\"\"Correct RGB values to match target color distribution.\"\"\"\n\n    def flat_and_mask(bgr, a):\n        mask = torch.where(a > 0.5, True, False)\n        bgr_flat = bgr.reshape(-1, bgr.shape[-1])\n        mask_flat = mask.reshape(-1)\n        bgr_flat_masked = bgr_flat[mask_flat, :]\n        return bgr_flat_masked\n\n    src_flat = flat_and_mask(src_image, alpha_channel)\n    target_flat = flat_and_mask(target_image, alpha_channel)\n    corrected_bgr = torch.zeros_like(src_image)\n\n    for i in range(3):\n        src_mean, src_stddev = torch.mean(src_flat[:, i]), torch.std(src_flat[:, i])\n        target_mean, target_stddev = torch.mean(target_flat[:, i]), torch.std(\n            target_flat[:, i]\n        )\n        corrected_bgr[:, :, i] = torch.clamp(\n            (src_image[:, :, i] - scale * src_mean) * (target_stddev / src_stddev)\n            + scale * target_mean,\n            0,\n            1,\n        )\n\n    src_mse = torch.mean((src_image - target_image) ** 2)\n    modify_mse = torch.mean((corrected_bgr - target_image) ** 2)\n    if src_mse < modify_mse:\n        corrected_bgr = torch.cat([src_image, alpha_channel], dim=-1)\n    else:\n        corrected_bgr = torch.cat([corrected_bgr, alpha_channel], 
dim=-1)\n\n    return corrected_bgr\n\n\n# Stage 1: Preprocess (UV unwrap + delight + multi-view rendering)\nclass Hunyuan3DPaintPreprocessStage(PipelineStage):\n    \"\"\"Preprocessing: UV unwrap + delight in parallel, then multi-view rendering.\"\"\"\n\n    CAMERA_AZIMS = [0, 90, 180, 270, 0, 180]\n    CAMERA_ELEVS = [0, 0, 0, 0, 90, -90]\n    VIEW_WEIGHTS = [1, 0.1, 0.5, 0.1, 0.05, 0.05]\n\n    @property\n    def parallelism_type(self) -> StageParallelismType:\n        return StageParallelismType.MAIN_RANK_ONLY\n\n    def __init__(self, config: Hunyuan3D2PipelineConfig) -> None:\n        super().__init__()\n        self.config = config\n        self._delight_pipeline = None\n        self._delight_loaded = False\n        self._renderer = None\n        self._renderer_loaded = False\n\n    # --- UV unwrap ---\n\n    def _do_uv_unwrap(self, batch: Req, server_args: ServerArgs) -> Req:\n        import time\n\n        from sglang.multimodal_gen.runtime.utils.mesh3d_utils import mesh_uv_wrap\n\n        mesh = batch.extra[\"shape_meshes\"]\n        if isinstance(mesh, list):\n            mesh = mesh[0]\n\n        try:\n            start_time = time.time()\n            mesh = mesh_uv_wrap(mesh)\n            elapsed = time.time() - start_time\n            logger.info(f\"UV unwrapping completed in {elapsed:.2f}s\")\n        except Exception as e:\n            logger.warning(f\"UV unwrapping failed: {e}\")\n\n        batch.extra[\"paint_mesh\"] = mesh\n        return batch\n\n    # --- Delight ---\n\n    def _load_delight_model(self, server_args: ServerArgs):\n        if self._delight_loaded:\n            return\n\n        from diffusers import (\n            EulerAncestralDiscreteScheduler,\n            StableDiffusionInstructPix2PixPipeline,\n        )\n        from huggingface_hub import snapshot_download\n\n        model_path = server_args.model_path\n        delight_subfolder = getattr(\n            self.config, \"delight_subfolder\", \"hunyuan3d-delight-v2-0\"\n       
 )\n\n        local_path = os.path.join(model_path, delight_subfolder)\n        if not os.path.exists(local_path):\n            local_path = os.path.expanduser(local_path)\n\n        if not os.path.exists(local_path):\n            try:\n                downloaded = snapshot_download(\n                    repo_id=model_path,\n                    allow_patterns=[f\"{delight_subfolder}/*\"],\n                )\n                local_path = os.path.join(downloaded, delight_subfolder)\n            except Exception as e:\n                logger.warning(\"Could not download delight model: %s\", e)\n                local_path = None\n\n        if local_path and os.path.exists(local_path):\n            pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(\n                local_path,\n                torch_dtype=torch.float16,\n                safety_checker=None,\n            )\n            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(\n                pipeline.scheduler.config\n            )\n            pipeline.set_progress_bar_config(disable=True)\n            self._delight_pipeline = pipeline.to(self.device, torch.float16)\n            logger.info(\"Delight model loaded successfully\")\n        else:\n            logger.warning(\n                \"Delight model not available, skipping delight preprocessing\"\n            )\n\n        self._delight_loaded = True\n\n    @torch.no_grad()\n    def _run_delight(self, image):\n        import cv2\n        from PIL import Image as PILImage\n\n        image = image.resize((512, 512))\n\n        if image.mode == \"RGBA\":\n            image_array = np.array(image)\n            alpha_channel = image_array[:, :, 3]\n            erosion_size = 3\n            kernel = np.ones((erosion_size, erosion_size), np.uint8)\n            alpha_channel = cv2.erode(alpha_channel, kernel, iterations=1)\n            image_array[alpha_channel == 0, :3] = 255\n            image_array[:, :, 3] = alpha_channel\n      
      image = PILImage.fromarray(image_array)\n\n            image_tensor = torch.tensor(np.array(image) / 255.0).to(self.device)\n            alpha = image_tensor[:, :, 3:]\n            rgb_target = image_tensor[:, :, :3]\n        else:\n            image_tensor = torch.tensor(np.array(image) / 255.0).to(self.device)\n            alpha = torch.ones_like(image_tensor)[:, :, :1]\n            rgb_target = image_tensor[:, :, :3]\n\n        image = image.convert(\"RGB\")\n\n        image = self._delight_pipeline(\n            prompt=self.config.delight_prompt,\n            image=image,\n            generator=torch.manual_seed(42),\n            height=512,\n            width=512,\n            num_inference_steps=self.config.delight_num_inference_steps,\n            image_guidance_scale=self.config.delight_cfg_image,\n            guidance_scale=self.config.delight_guidance_scale,\n        ).images[0]\n\n        image_tensor = torch.tensor(np.array(image) / 255.0).to(self.device)\n        rgb_src = image_tensor[:, :, :3]\n        image = _recorrect_rgb(rgb_src, rgb_target, alpha)\n        image = image[:, :, :3] * image[:, :, 3:] + torch.ones_like(image[:, :, :3]) * (\n            1.0 - image[:, :, 3:]\n        )\n        image = PILImage.fromarray((image.cpu().numpy() * 255).astype(np.uint8))\n\n        return image\n\n    def _do_delight(self, batch: Req, server_args: ServerArgs) -> Req:\n        from PIL import Image\n\n        from sglang.multimodal_gen.runtime.utils.mesh3d_utils import recenter_image\n\n        image = Image.open(batch.image_path)\n        image = recenter_image(image)\n\n        if not self.config.delight_enable:\n            logger.info(\"Delight preprocessing disabled, using original image\")\n            batch.extra[\"delighted_image\"] = image\n            return batch\n\n        self._load_delight_model(server_args)\n        if self._delight_pipeline is not None:\n            try:\n                image = self._run_delight(image)\n              
  logger.info(\"Image delight completed\")\n            except Exception as e:\n                logger.warning(f\"Image delight failed: {e}\")\n\n        batch.extra[\"delighted_image\"] = image\n        return batch\n\n    # --- Multi-view rendering ---\n\n    def _init_renderer(self):\n        if self._renderer_loaded:\n            return\n\n        from sglang.multimodal_gen.runtime.utils.mesh3d_utils import MeshRender\n\n        self._renderer = MeshRender(\n            default_resolution=self.config.paint_render_size,\n            texture_size=self.config.paint_texture_size,\n        )\n        self._renderer_loaded = True\n        logger.info(\"Mesh renderer initialized\")\n\n    def _render_multiview(self, mesh) -> tuple:\n        self._init_renderer()\n        self._renderer.load_mesh(mesh)\n\n        normal_maps = self._renderer.render_normal_multiview(\n            self.CAMERA_ELEVS, self.CAMERA_AZIMS, use_abs_coor=True\n        )\n        position_maps = self._renderer.render_position_multiview(\n            self.CAMERA_ELEVS, self.CAMERA_AZIMS\n        )\n\n        logger.info(f\"Rendered {len(normal_maps)} views for texture generation\")\n        return normal_maps, position_maps\n\n    # --- Forward ---\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        if batch.extra.get(\"_mesh_failed\"):\n            logger.warning(\"Mesh generation failed, skipping paint preprocessing\")\n            batch.extra[\"paint_mesh\"] = None\n            batch.extra[\"delighted_image\"] = None\n            batch.extra[\"normal_maps\"] = []\n            batch.extra[\"position_maps\"] = []\n            batch.extra[\"camera_azims\"] = self.CAMERA_AZIMS\n            batch.extra[\"camera_elevs\"] = self.CAMERA_ELEVS\n            batch.extra[\"view_weights\"] = self.VIEW_WEIGHTS\n            batch.extra[\"renderer\"] = None\n            return batch\n\n        import concurrent.futures\n        import copy\n\n        # 1. 
UV unwrap + delight in parallel\n        batch_for_uv = batch\n        batch_for_delight = copy.copy(batch)\n        batch_for_delight.extra = batch.extra.copy()\n\n        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:\n            uv_future = executor.submit(self._do_uv_unwrap, batch_for_uv, server_args)\n            delight_future = executor.submit(\n                self._do_delight, batch_for_delight, server_args\n            )\n            uv_future.result()\n            delight_future.result()\n\n        batch.extra[\"paint_mesh\"] = batch_for_uv.extra.get(\"paint_mesh\")\n        batch.extra[\"delighted_image\"] = batch_for_delight.extra.get(\"delighted_image\")\n\n        # 2. Multi-view rendering\n        normal_maps, position_maps = self._render_multiview(batch.extra[\"paint_mesh\"])\n        batch.extra[\"normal_maps\"] = normal_maps\n        batch.extra[\"position_maps\"] = position_maps\n        batch.extra[\"camera_azims\"] = self.CAMERA_AZIMS\n        batch.extra[\"camera_elevs\"] = self.CAMERA_ELEVS\n        batch.extra[\"view_weights\"] = self.VIEW_WEIGHTS\n        batch.extra[\"renderer\"] = self._renderer\n\n        return batch\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        result = VerificationResult()\n        result.add_check(\"shape_meshes\", batch.extra.get(\"shape_meshes\"), V.not_none)\n        result.add_check(\"image_path\", batch.image_path, V.not_none)\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        result = VerificationResult()\n        result.add_check(\"paint_mesh\", batch.extra.get(\"paint_mesh\"), V.not_none)\n        result.add_check(\n            \"delighted_image\", batch.extra.get(\"delighted_image\"), V.not_none\n        )\n        result.add_check(\"normal_maps\", batch.extra.get(\"normal_maps\"), V.is_list)\n        result.add_check(\"position_maps\", 
batch.extra.get(\"position_maps\"), V.is_list)\n        result.add_check(\"renderer\", batch.extra.get(\"renderer\"), V.not_none)\n        return result\n\n\n# Stage 2: TexGen (model loading + input prep + denoising + decode)\nclass Hunyuan3DPaintTexGenStage(PipelineStage):\n    def __init__(\n        self,\n        config: Hunyuan3D2PipelineConfig,\n        paint_dir: str | None = None,\n        transformer: Any = None,\n        scheduler: Any = None,\n        vae: Any = None,\n        vae_scale_factor: int = 8,\n        image_processor: Any = None,\n        solver: Any = None,\n        is_turbo: bool = False,\n    ) -> None:\n        super().__init__()\n        self.config = config\n        self.paint_dir = paint_dir\n        self.transformer = transformer\n        self.scheduler = scheduler\n        self.vae = vae\n        self.vae_scale_factor = vae_scale_factor\n        self.image_processor = image_processor\n        self.solver = solver\n        self.is_turbo = is_turbo\n        self._loaded = transformer is not None\n\n    @property\n    def parallelism_type(self) -> StageParallelismType:\n        return StageParallelismType.MAIN_RANK_ONLY\n\n    def _load_paint_models(self, server_args: ServerArgs) -> None:\n        \"\"\"Load paint models from pre-resolved local path (no network).\"\"\"\n        if self._loaded:\n            return\n        if self.paint_dir is None:\n            logger.warning(\"No paint model directory resolved, skipping\")\n            self._loaded = True\n            return\n        try:\n            self._do_load_paint(server_args)\n            logger.info(\"Paint pipeline loaded successfully\")\n        except Exception as e:\n            logger.warning(\"Failed to load paint pipeline: %s\", e)\n            self.vae = None\n            self.transformer = None\n            self.scheduler = None\n        self._loaded = True\n\n    def _do_load_paint(self, server_args: ServerArgs) -> None:\n        import json\n\n        from diffusers 
import AutoencoderKL\n        from diffusers.image_processor import VaeImageProcessor\n\n        from sglang.multimodal_gen.runtime.models.dits.hunyuan3d import (\n            UNet2p5DConditionModel,\n        )\n\n        local_path = self.paint_dir\n        logger.info(\"Loading paint model from %s\", local_path)\n        vae_dir = os.path.join(local_path, \"vae\")\n        with open(os.path.join(vae_dir, \"config.json\"), \"r\") as f:\n            vae_config = json.load(f)\n        vae_config = {k: v for k, v in vae_config.items() if not k.startswith(\"_\")}\n        self.vae = AutoencoderKL(**vae_config)\n        st_path = os.path.join(vae_dir, \"diffusion_pytorch_model.safetensors\")\n        bin_path = os.path.join(vae_dir, \"diffusion_pytorch_model.bin\")\n        if os.path.exists(st_path):\n            from safetensors.torch import load_file\n\n            state_dict = load_file(st_path)\n        elif os.path.exists(bin_path):\n            state_dict = torch.load(bin_path, map_location=\"cpu\", weights_only=True)\n        else:\n            raise FileNotFoundError(f\"No VAE weights in {vae_dir}\")\n        self.vae.load_state_dict(state_dict)\n        self.vae = self.vae.to(device=self.device, dtype=torch.float16).eval()\n        self.transformer = UNet2p5DConditionModel.from_pretrained(\n            os.path.join(local_path, \"unet\"),\n            torch_dtype=torch.float16,\n        ).to(self.device)\n        self.is_turbo = bool(getattr(self.config, \"paint_turbo_mode\", False))\n        sched_path = os.path.join(local_path, \"scheduler\", \"scheduler_config.json\")\n        with open(sched_path, \"r\") as f:\n            sched_cfg = json.load(f)\n        if self.is_turbo:\n            from diffusers import LCMScheduler\n\n            self.scheduler = LCMScheduler.from_config(sched_cfg)\n        else:\n            from diffusers import EulerAncestralDiscreteScheduler\n\n            self.scheduler = EulerAncestralDiscreteScheduler.from_config(\n            
    sched_cfg, timestep_spacing=\"trailing\"\n            )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.solver = DDIMSolver(\n            self.scheduler.alphas_cumprod.cpu().numpy(),\n            timesteps=self.scheduler.config.num_train_timesteps,\n            ddim_timesteps=30,\n        ).to(self.device)\n        if server_args.enable_torch_compile:\n            compile_mode = os.environ.get(\n                \"SGLANG_TORCH_COMPILE_MODE\", \"max-autotune-no-cudagraphs\"\n            )\n            logger.info(\"Compiling paint transformer with mode: %s\", compile_mode)\n            self.transformer.compile(mode=compile_mode, fullgraph=False, dynamic=None)\n\n    def _convert_pil_list_to_tensor(\n        self, images: list, device: torch.device\n    ) -> torch.Tensor:\n        bg_c = [1.0, 1.0, 1.0]\n        images_tensor = []\n        for batch_imgs in images:\n            view_imgs = []\n            for pil_img in batch_imgs:\n                if pil_img.mode == \"L\":\n                    pil_img = pil_img.point(\n                        lambda x: 255 if x > 1 else 0, mode=\"1\"\n                    ).convert(\"RGB\")\n                img = np.asarray(pil_img, dtype=np.float32) / 255.0\n                if img.shape[2] > 3:\n                    alpha = img[:, :, 3:]\n                    img = img[:, :, :3] * alpha + bg_c * (1 - alpha)\n                img = (\n                    torch.from_numpy(img)\n                    .permute(2, 0, 1)\n                    .unsqueeze(0)\n                    .contiguous()\n                    .to(device=device, dtype=self.vae.dtype)\n                )\n                view_imgs.append(img)\n            view_imgs = torch.cat(view_imgs, dim=0)\n            images_tensor.append(view_imgs.unsqueeze(0))\n        return torch.cat(images_tensor, dim=0)\n\n    @torch.no_grad()\n    def 
_encode_images(self, images: torch.Tensor) -> torch.Tensor:\n        batch_size = images.shape[0]\n        images = rearrange(images, \"b n c h w -> (b n) c h w\")\n        dtype = next(self.vae.parameters()).dtype\n        images = (images - 0.5) * 2.0\n        posterior = self.vae.encode(images.to(dtype)).latent_dist\n        latents = posterior.sample() * self.vae.config.scaling_factor\n        return rearrange(latents, \"(b n) c h w -> b n c h w\", b=batch_size)\n\n    @staticmethod\n    def _compute_camera_index(azim: float, elev: float) -> int:\n        base_idx = int(((azim // 30) + 9) % 12)\n        if elev == 0:\n            base, divisor = 12, 1\n        elif elev == 20:\n            base, divisor = 24, 1\n        elif elev == -20:\n            base, divisor = 0, 1\n        elif elev == 90:\n            base, divisor = 40, 3\n        elif elev == -90:\n            base, divisor = 36, 3\n        else:\n            base, divisor = 12, 1\n        return base + (base_idx // divisor)\n\n    def _prepare_denoising_inputs(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> dict[str, Any]:\n        import random\n\n        from diffusers.utils.torch_utils import randn_tensor\n\n        device = self.device\n        normal_maps = batch.extra[\"normal_maps\"]\n        position_maps = batch.extra[\"position_maps\"]\n        camera_azims = batch.extra[\"camera_azims\"]\n        camera_elevs = batch.extra[\"camera_elevs\"]\n\n        num_steps = self.config.paint_num_inference_steps\n        guidance_scale = self.config.paint_guidance_scale\n        render_size = self.config.paint_resolution\n        num_in_batch = len(normal_maps)\n\n        seed = 0\n        random.seed(seed)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        generator = torch.Generator(device=device).manual_seed(seed)\n\n        image = batch.extra[\"delighted_image\"]\n        if not isinstance(image, list):\n            image = [image]\n        
image = [to_rgb_image(img) for img in image]\n\n        image_vae = [\n            torch.tensor(np.array(img, dtype=np.float32) / 255.0) for img in image\n        ]\n        image_vae = [\n            iv.unsqueeze(0).permute(0, 3, 1, 2).unsqueeze(0) for iv in image_vae\n        ]\n        image_vae = torch.cat(image_vae, dim=1).to(device=device, dtype=self.vae.dtype)\n        ref_latents = self._encode_images(image_vae)\n\n        target_size = render_size\n        if isinstance(normal_maps, list):\n            normal_maps = [\n                (\n                    img.resize((target_size, target_size))\n                    if hasattr(img, \"resize\")\n                    else img\n                )\n                for img in normal_maps\n            ]\n            normal_maps = self._convert_pil_list_to_tensor([normal_maps], device)\n        if isinstance(position_maps, list):\n            position_maps = [\n                (\n                    img.resize((target_size, target_size))\n                    if hasattr(img, \"resize\")\n                    else img\n                )\n                for img in position_maps\n            ]\n            position_maps = self._convert_pil_list_to_tensor([position_maps], device)\n\n        normal_imgs = (\n            self._encode_images(normal_maps) if normal_maps is not None else None\n        )\n        position_imgs = (\n            self._encode_images(position_maps) if position_maps is not None else None\n        )\n\n        camera_info = [\n            self._compute_camera_index(azim, elev)\n            for azim, elev in zip(camera_azims, camera_elevs)\n        ]\n        camera_info_gen = torch.tensor([camera_info], device=device, dtype=torch.int64)\n        camera_info_ref = torch.tensor([[0]], device=device, dtype=torch.int64)\n\n        do_cfg = guidance_scale > 1 and not self.is_turbo\n\n        if self.is_turbo and position_maps is not None:\n            from 
sglang.multimodal_gen.runtime.models.dits.hunyuan3d import (\n                compute_multi_resolution_discrete_voxel_indice,\n                compute_multi_resolution_mask,\n            )\n\n            position_attn_mask = compute_multi_resolution_mask(position_maps)\n            position_voxel_indices = compute_multi_resolution_discrete_voxel_indice(\n                position_maps\n            )\n        else:\n            position_attn_mask = None\n            position_voxel_indices = None\n\n        if do_cfg:\n            negative_ref_latents = torch.zeros_like(ref_latents)\n            ref_latents = torch.cat([negative_ref_latents, ref_latents])\n            ref_scale = torch.as_tensor([0.0, 1.0]).to(ref_latents)\n            if normal_imgs is not None:\n                normal_imgs = torch.cat((normal_imgs, normal_imgs))\n            if position_imgs is not None:\n                position_imgs = torch.cat((position_imgs, position_imgs))\n            if position_maps is not None:\n                position_maps = torch.cat((position_maps, position_maps))\n            camera_info_gen = torch.cat((camera_info_gen, camera_info_gen))\n            camera_info_ref = torch.cat((camera_info_ref, camera_info_ref))\n        else:\n            ref_scale = None\n\n        model_kwargs = {\n            \"ref_latents\": ref_latents,\n            \"num_in_batch\": num_in_batch,\n        }\n        if ref_scale is not None:\n            model_kwargs[\"ref_scale\"] = ref_scale\n        if normal_imgs is not None:\n            model_kwargs[\"normal_imgs\"] = normal_imgs\n        if position_imgs is not None:\n            model_kwargs[\"position_imgs\"] = position_imgs\n        if position_maps is not None:\n            model_kwargs[\"position_maps\"] = position_maps\n        model_kwargs[\"camera_info_gen\"] = camera_info_gen\n        model_kwargs[\"camera_info_ref\"] = camera_info_ref\n        if position_attn_mask is not None:\n            model_kwargs[\"position_attn_mask\"] 
= position_attn_mask\n        if position_voxel_indices is not None:\n            model_kwargs[\"position_voxel_indices\"] = position_voxel_indices\n\n        prompt_embeds = self.transformer.learned_text_clip_gen.repeat(1, 1, 1)\n        negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n\n        if self.is_turbo:\n            bsz = 3\n            index = torch.arange(29, -1, -bsz, device=device).long()\n            timesteps = self.solver.ddim_timesteps[index]\n            self.scheduler.set_timesteps(timesteps=timesteps.cpu(), device=device)\n            timesteps = self.scheduler.timesteps\n        else:\n            timesteps, num_steps = retrieve_timesteps(\n                self.scheduler, num_steps, device, None, None\n            )\n\n        num_channels_latents = self.transformer.config.in_channels\n        latent_shape = (\n            num_in_batch,\n            num_channels_latents,\n            render_size // self.vae_scale_factor,\n            render_size // self.vae_scale_factor,\n        )\n        latents = randn_tensor(\n            latent_shape, generator=generator, device=device, dtype=prompt_embeds.dtype\n        )\n        latents = latents * self.scheduler.init_noise_sigma\n\n        return {\n            \"timesteps\": timesteps,\n            \"latents\": latents,\n            \"prompt_embeds\": prompt_embeds,\n            \"negative_prompt_embeds\": negative_prompt_embeds,\n            \"model_kwargs\": model_kwargs,\n            \"num_in_batch\": num_in_batch,\n            \"num_inference_steps\": num_steps,\n            \"guidance_scale\": guidance_scale,\n            \"do_cfg\": do_cfg,\n            \"generator\": generator,\n            \"num_channels_latents\": num_channels_latents,\n        }\n\n    @torch.no_grad()\n    def _denoise_loop(\n        self,\n        timesteps: torch.Tensor,\n        latents: torch.Tensor,\n        prompt_embeds: torch.Tensor,\n        negative_prompt_embeds: torch.Tensor,\n        model_kwargs: 
dict[str, Any],\n        num_in_batch: int,\n        guidance_scale: float,\n        do_cfg: bool,\n        generator: torch.Generator,\n        num_channels_latents: int,\n    ) -> torch.Tensor:\n        import inspect\n\n        if do_cfg:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        extra_step_kwargs = {}\n        if \"eta\" in inspect.signature(self.scheduler.step).parameters:\n            extra_step_kwargs[\"eta\"] = 0.0\n        if \"generator\" in inspect.signature(self.scheduler.step).parameters:\n            extra_step_kwargs[\"generator\"] = generator\n\n        for step_idx, t in enumerate(timesteps):\n            latents = rearrange(latents, \"(b n) c h w -> b n c h w\", n=num_in_batch)\n            latent_model_input = torch.cat([latents] * 2) if do_cfg else latents\n            latent_model_input = rearrange(\n                latent_model_input, \"b n c h w -> (b n) c h w\"\n            )\n            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n            latent_model_input = rearrange(\n                latent_model_input, \"(b n) c h w -> b n c h w\", n=num_in_batch\n            )\n\n            with set_forward_context(\n                current_timestep=step_idx,\n                attn_metadata=None,\n            ):\n                noise_pred = self.transformer(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=None,\n                    cross_attention_kwargs=None,\n                    added_cond_kwargs=None,\n                    return_dict=False,\n                    **model_kwargs,\n                )[0]\n\n            latents = rearrange(latents, \"b n c h w -> (b n) c h w\")\n\n            if do_cfg:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                noise_pred = noise_pred_uncond + guidance_scale * (\n                  
  noise_pred_text - noise_pred_uncond\n                )\n\n            latents = self.scheduler.step(\n                noise_pred,\n                t,\n                latents[:, :num_channels_latents, :, :],\n                **extra_step_kwargs,\n                return_dict=False,\n            )[0]\n\n        return latents\n\n    @torch.no_grad()\n    def _decode_latents(self, latents: torch.Tensor) -> list:\n        image = self.vae.decode(\n            latents / self.vae.config.scaling_factor, return_dict=False\n        )[0]\n        return self.image_processor.postprocess(image, output_type=\"pil\")\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        if batch.extra.get(\"_mesh_failed\"):\n            logger.warning(\"Mesh generation failed, skipping paint texgen\")\n            batch.extra[\"multiview_textures\"] = []\n            return batch\n\n        self._load_paint_models(server_args)\n\n        delighted_image = batch.extra[\"delighted_image\"]\n        normal_maps = batch.extra[\"normal_maps\"]\n\n        if self.transformer is not None:\n            try:\n                prepared = self._prepare_denoising_inputs(batch, server_args)\n\n                latents = self._denoise_loop(\n                    timesteps=prepared[\"timesteps\"],\n                    latents=prepared[\"latents\"],\n                    prompt_embeds=prepared[\"prompt_embeds\"],\n                    negative_prompt_embeds=prepared[\"negative_prompt_embeds\"],\n                    model_kwargs=prepared[\"model_kwargs\"],\n                    num_in_batch=prepared[\"num_in_batch\"],\n                    guidance_scale=prepared[\"guidance_scale\"],\n                    do_cfg=prepared[\"do_cfg\"],\n                    generator=prepared[\"generator\"],\n                    num_channels_latents=prepared[\"num_channels_latents\"],\n                )\n\n                multiview_textures = self._decode_latents(latents)\n                logger.info(\n           
         \"Paint pipeline generated %d textures\", len(multiview_textures)\n                )\n\n            except Exception as e:\n                logger.error(f\"Paint pipeline execution failed: {e}\")\n                import traceback\n\n                traceback.print_exc()\n                render_size = self.config.paint_resolution\n                multiview_textures = [\n                    delighted_image.resize((render_size, render_size))\n                    for _ in range(len(normal_maps))\n                ]\n        else:\n            logger.warning(\n                \"Paint pipeline not available, using reference image for all views\"\n            )\n            render_size = self.config.paint_resolution\n            multiview_textures = [\n                delighted_image.resize((render_size, render_size))\n                for _ in range(len(normal_maps))\n            ]\n\n        batch.extra[\"multiview_textures\"] = multiview_textures\n        logger.info(f\"Generated {len(multiview_textures)} texture views\")\n        return batch\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        if batch.extra.get(\"_mesh_failed\"):\n            return VerificationResult()\n        result = VerificationResult()\n        result.add_check(\n            \"delighted_image\", batch.extra.get(\"delighted_image\"), V.not_none\n        )\n        result.add_check(\"normal_maps\", batch.extra.get(\"normal_maps\"), V.is_list)\n        result.add_check(\"position_maps\", batch.extra.get(\"position_maps\"), V.is_list)\n        result.add_check(\"camera_azims\", batch.extra.get(\"camera_azims\"), V.is_list)\n        result.add_check(\"camera_elevs\", batch.extra.get(\"camera_elevs\"), V.is_list)\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        result = VerificationResult()\n        result.add_check(\n            \"multiview_textures\", 
batch.extra.get(\"multiview_textures\"), V.is_list\n        )\n        return result\n\n\n# Stage 3: Postprocess (texture baking + mesh export)\nclass Hunyuan3DPaintPostprocessStage(PipelineStage):\n    \"\"\"Texture baking from multi-view images and final mesh export.\"\"\"\n\n    @property\n    def parallelism_type(self) -> StageParallelismType:\n        return StageParallelismType.MAIN_RANK_ONLY\n\n    def __init__(self, config: Hunyuan3D2PipelineConfig) -> None:\n        super().__init__()\n        self.config = config\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> OutputBatch:\n        if batch.extra.get(\"_mesh_failed\"):\n            logger.warning(\"Mesh generation failed, skipping paint postprocess\")\n            return OutputBatch(output_file_paths=[], metrics=batch.metrics)\n\n        renderer = batch.extra[\"renderer\"]\n        multiview_textures = batch.extra[\"multiview_textures\"]\n        camera_elevs = batch.extra[\"camera_elevs\"]\n        camera_azims = batch.extra[\"camera_azims\"]\n        view_weights = batch.extra[\"view_weights\"]\n\n        render_size = getattr(self.config, \"paint_render_size\", 2048)\n        resized_textures = []\n        for tex in multiview_textures:\n            if hasattr(tex, \"resize\"):\n                resized_textures.append(tex.resize((render_size, render_size)))\n            else:\n                resized_textures.append(tex)\n\n        try:\n            texture, mask = renderer.bake_from_multiview(\n                resized_textures,\n                camera_elevs,\n                camera_azims,\n                view_weights,\n                method=\"fast\",\n            )\n\n            mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(\"uint8\")\n            texture = renderer.texture_inpaint(texture, mask_np)\n\n            renderer.set_texture(texture)\n            textured_mesh = renderer.save_mesh()\n            logger.info(\"Texture baking completed\")\n        except 
Exception as e:\n            logger.error(f\"Texture baking failed: {e}\")\n            textured_mesh = batch.extra[\"paint_mesh\"]\n\n        obj_path = batch.extra[\"shape_obj_path\"]\n        return_path = batch.extra[\"shape_return_path\"]\n\n        try:\n            textured_mesh.export(obj_path)\n            if self.config.paint_save_glb:\n                glb_path = obj_path[:-4] + \".glb\"\n                textured_mesh.export(glb_path)\n                return_path = glb_path\n                self._cleanup_obj_artifacts(obj_path)\n        except Exception as e:\n            logger.error(f\"Mesh export failed: {e}\")\n\n        return OutputBatch(output_file_paths=[return_path], metrics=batch.metrics)\n\n    @staticmethod\n    def _cleanup_obj_artifacts(obj_path: str) -> None:\n        \"\"\"Remove OBJ file and trimesh-generated material artifacts.\"\"\"\n        obj_dir = os.path.dirname(obj_path) or \".\"\n        targets = [obj_path]\n        for f in os.listdir(obj_dir):\n            if f.endswith(\".mtl\") or (f.startswith(\"material\") and f.endswith(\".png\")):\n                targets.append(os.path.join(obj_dir, f))\n        for path in targets:\n            try:\n                os.remove(path)\n            except OSError:\n                pass\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        if batch.extra.get(\"_mesh_failed\"):\n            return VerificationResult()\n        result = VerificationResult()\n        result.add_check(\"renderer\", batch.extra.get(\"renderer\"), V.not_none)\n        result.add_check(\n            \"multiview_textures\", batch.extra.get(\"multiview_textures\"), V.is_list\n        )\n        result.add_check(\"camera_elevs\", batch.extra.get(\"camera_elevs\"), V.is_list)\n        result.add_check(\"camera_azims\", batch.extra.get(\"camera_azims\"), V.is_list)\n        result.add_check(\"view_weights\", batch.extra.get(\"view_weights\"), V.is_list)\n        return 
result\n\n\n__all__ = [\n    \"Hunyuan3DPaintPreprocessStage\",\n    \"Hunyuan3DPaintTexGenStage\",\n    \"Hunyuan3DPaintPostprocessStage\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_shape.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nHunyuan3D shape generation stages.\n\nFour-stage pipeline: BeforeDenoising -> Denoising -> Export -> Save.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom sglang.multimodal_gen.configs.pipeline_configs.hunyuan3d import (\n    Hunyuan3D2PipelineConfig,\n)\nfrom sglang.multimodal_gen.runtime.loader.component_loaders.transformer_loader import (\n    TransformerLoader,\n)\nfrom sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.denoising import DenoisingStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    StageValidators as V,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.runtime.utils.mesh3d_utils import export_to_trimesh\n\nlogger = init_logger(__name__)\n\n\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps=None,\n    device=None,\n    timesteps=None,\n    sigmas=None,\n    **kwargs,\n):\n    \"\"\"Retrieve timesteps from scheduler.\"\"\"\n    import inspect\n\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of timesteps or sigmas can be passed.\")\n\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(\n            inspect.signature(scheduler.set_timesteps).parameters.keys()\n        )\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"Scheduler {scheduler.__class__} doesn't support 
custom timesteps.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n\n    elif sigmas is not None:\n        accepts_sigmas = \"sigmas\" in set(\n            inspect.signature(scheduler.set_timesteps).parameters.keys()\n        )\n        if not accepts_sigmas:\n            raise ValueError(\n                f\"Scheduler {scheduler.__class__} doesn't support custom sigmas.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n\n    return timesteps, num_inference_steps\n\n\ndef _prepare_shape_image(image_processor, image, mask=None) -> dict:\n    \"\"\"Prepare shape image for conditioning.\"\"\"\n    if isinstance(image, torch.Tensor) and isinstance(mask, torch.Tensor):\n        return {\"image\": image, \"mask\": mask}\n\n    if isinstance(image, str) and not os.path.exists(image):\n        raise FileNotFoundError(f\"Couldn't find image at path {image}\")\n\n    if not isinstance(image, list):\n        image = [image]\n\n    outputs = [image_processor(img) for img in image]\n    cond_input = {k: [] for k in outputs[0].keys()}\n    for output in outputs:\n        for key, value in output.items():\n            cond_input[key].append(value)\n    for key, value in cond_input.items():\n        if isinstance(value[0], torch.Tensor):\n            cond_input[key] = torch.cat(value, dim=0)\n    return cond_input\n\n\ndef _move_to_device(payload, device, dtype):\n    \"\"\"Recursively move tensors in payload to specified device and dtype.\"\"\"\n    if isinstance(payload, torch.Tensor):\n        return payload.to(device=device, dtype=dtype)\n    if isinstance(payload, dict):\n        return 
{k: _move_to_device(v, device, dtype) for k, v in payload.items()}\n    if isinstance(payload, list):\n        return [_move_to_device(v, device, dtype) for v in payload]\n    return payload\n\n\nclass Hunyuan3DShapeBeforeDenoisingStage(PipelineStage):\n    \"\"\"Monolithic pre-processing stage for Hunyuan3D shape generation.\n\n    Consolidates input validation, image preprocessing, conditioning, and\n    latent/timestep preparation into a single stage.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_processor: Any,\n        conditioner: Any,\n        vae: Any,\n        model: Any,\n        scheduler: Any,\n        config: Hunyuan3D2PipelineConfig,\n    ) -> None:\n        super().__init__()\n        self.image_processor = image_processor\n        self.conditioner = conditioner\n        self.vae = vae\n        self.model = model\n        self.scheduler = scheduler\n        self.config = config\n\n    def _validate_input(self, batch: Req, server_args: ServerArgs) -> None:\n        if batch.image_path is None:\n            raise ValueError(\"Hunyuan3D requires 'image_path' input.\")\n        if isinstance(batch.image_path, list):\n            if len(batch.image_path) != 1:\n                raise ValueError(\"Hunyuan3D only supports a single image input.\")\n            batch.image_path = batch.image_path[0]\n        if not isinstance(batch.image_path, str):\n            raise ValueError(\n                f\"Hunyuan3D expects image_path as str, got {type(batch.image_path)}\"\n            )\n        if not os.path.exists(batch.image_path):\n            raise FileNotFoundError(f\"Image path not found: {batch.image_path}\")\n        if batch.num_outputs_per_prompt != 1:\n            raise ValueError(\"Hunyuan3D only supports num_outputs_per_prompt=1.\")\n\n    def _prepare_latents(self, batch_size, dtype, device, generator):\n        from diffusers.utils.torch_utils import randn_tensor\n\n        shape = (batch_size, *self.vae.latent_shape)\n        
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        return latents * getattr(self.scheduler, \"init_noise_sigma\", 1.0)\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        # 1. Input validation\n        self._validate_input(batch, server_args)\n\n        # 2. Image preprocessing\n        cond_inputs = _prepare_shape_image(self.image_processor, batch.image_path)\n        image = cond_inputs.pop(\"image\")\n\n        device = self.device\n        dtype = next(self.model.parameters()).dtype\n        image = _move_to_device(image, device, dtype)\n        cond_inputs = _move_to_device(cond_inputs, device, dtype)\n\n        # 3. Conditioning with CFG\n        do_cfg = batch.guidance_scale >= 0 and not (\n            hasattr(self.model, \"guidance_embed\") and self.model.guidance_embed is True\n        )\n\n        cond = self.conditioner(image=image, **cond_inputs)\n        if do_cfg:\n            un_cond = self.conditioner.unconditional_embedding(\n                image.shape[0], **cond_inputs\n            )\n\n            def cat_recursive(a, b):\n                if isinstance(a, torch.Tensor):\n                    return torch.cat([a, b], dim=0).to(dtype)\n                out = {}\n                for key in a.keys():\n                    out[key] = cat_recursive(a[key], b[key])\n                return out\n\n            cond = cat_recursive(cond, un_cond)\n\n        # 4. 
Latent and timestep preparation\n        batch_size = image.shape[0]\n        sigmas = np.linspace(0, 1, batch.num_inference_steps)\n        timesteps, _ = retrieve_timesteps(\n            self.scheduler,\n            batch.num_inference_steps,\n            device,\n            sigmas=sigmas,\n        )\n\n        generator = batch.generator\n        if generator is None and batch.seed is not None:\n            generator = torch.Generator(device=device).manual_seed(batch.seed)\n\n        latents = self._prepare_latents(batch_size, dtype, device, generator)\n\n        guidance = None\n        if hasattr(self.model, \"guidance_embed\") and self.model.guidance_embed is True:\n            guidance = torch.tensor(\n                [batch.guidance_scale] * batch_size, device=device, dtype=dtype\n            )\n\n        # 5. Populate batch\n        batch.prompt_embeds = [cond]\n        batch.do_classifier_free_guidance = do_cfg\n        batch.timesteps = timesteps\n        batch.latents = latents\n        batch.extra[\"shape_guidance\"] = guidance\n        batch.extra[\"shape_image\"] = image\n        return batch\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        result = VerificationResult()\n        result.add_check(\"image_path\", batch.image_path, V.not_none)\n        result.add_check(\n            \"num_inference_steps\", batch.num_inference_steps, V.positive_int\n        )\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        result = VerificationResult()\n        result.add_check(\"timesteps\", batch.timesteps, [V.is_tensor, V.min_dims(1)])\n        result.add_check(\"latents\", batch.latents, V.is_tensor)\n        result.add_check(\"prompt_embeds\", batch.prompt_embeds, V.list_not_empty)\n        return result\n\n\nclass Hunyuan3DShapeDenoisingStage(DenoisingStage):\n    \"\"\"Denoising stage for Hunyuan3D shape generation.\"\"\"\n\n    def 
__init__(self, transformer: Any, scheduler: Any, **kwargs) -> None:\n        super().__init__(transformer=transformer, scheduler=scheduler, **kwargs)\n\n    def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):\n        \"\"\"Prepare Hunyuan3D-specific variables for the base denoising loop.\"\"\"\n        assert self.transformer is not None\n        pipeline = self.pipeline() if self.pipeline else None\n        cache_dit_num_inference_steps = batch.extra.get(\n            \"cache_dit_num_inference_steps\", batch.num_inference_steps\n        )\n        if not server_args.model_loaded[\"transformer\"]:\n            loader = TransformerLoader()\n            self.transformer = loader.load(\n                server_args.model_paths[\"transformer\"], server_args, \"transformer\"\n            )\n            self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch)\n            self._maybe_enable_torch_compile(self.transformer)\n            if pipeline:\n                pipeline.add_module(\"transformer\", self.transformer)\n            server_args.model_loaded[\"transformer\"] = True\n        else:\n            self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch)\n\n        timesteps = batch.timesteps\n        if timesteps is None:\n            raise ValueError(\"Timesteps must be provided\")\n\n        latents = batch.latents\n        if latents is None:\n            raise ValueError(\"Latents must be provided\")\n\n        cond = batch.prompt_embeds[0] if batch.prompt_embeds else None\n        if cond is None:\n            raise ValueError(\"Conditioning (prompt_embeds) must be provided\")\n\n        if batch.raw_latent_shape is None:\n            batch.raw_latent_shape = latents.shape\n\n        guidance = batch.extra.get(\"shape_guidance\")\n        num_inference_steps = batch.num_inference_steps\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n\n        extra_step_kwargs = 
self.prepare_extra_func_kwargs(\n            self.scheduler.step,\n            {\"generator\": batch.generator, \"eta\": batch.eta},\n        )\n\n        target_dtype = next(self.transformer.parameters()).dtype\n        autocast_enabled = False\n\n        pos_cond_kwargs = {\"encoder_hidden_states\": cond}\n        neg_cond_kwargs = {}\n\n        return {\n            \"extra_step_kwargs\": extra_step_kwargs,\n            \"target_dtype\": target_dtype,\n            \"autocast_enabled\": autocast_enabled,\n            \"timesteps\": timesteps,\n            \"num_inference_steps\": num_inference_steps,\n            \"num_warmup_steps\": num_warmup_steps,\n            \"image_kwargs\": {},\n            \"pos_cond_kwargs\": pos_cond_kwargs,\n            \"neg_cond_kwargs\": neg_cond_kwargs,\n            \"latents\": latents,\n            \"prompt_embeds\": batch.prompt_embeds,\n            \"neg_prompt_embeds\": None,\n            \"boundary_timestep\": None,\n            \"z\": None,\n            \"reserved_frames_mask\": None,\n            \"seq_len\": None,\n            \"guidance\": guidance,\n        }\n\n    def _predict_noise(\n        self,\n        current_model,\n        latent_model_input,\n        timestep,\n        target_dtype,\n        guidance: torch.Tensor,\n        **kwargs,\n    ):\n        \"\"\"Hunyuan3D-specific noise prediction with normalized timestep.\"\"\"\n        cond = kwargs.get(\"encoder_hidden_states\")\n        timestep_norm = timestep / self.scheduler.config.num_train_timesteps\n        return current_model(latent_model_input, timestep_norm, cond, guidance=guidance)\n\n    def _predict_noise_with_cfg(\n        self,\n        current_model,\n        latent_model_input: torch.Tensor,\n        timestep,\n        batch: Req,\n        timestep_index: int,\n        attn_metadata,\n        target_dtype,\n        current_guidance_scale,\n        image_kwargs: dict[str, Any],\n        pos_cond_kwargs: dict[str, Any],\n        neg_cond_kwargs: 
dict[str, Any],\n        server_args,\n        guidance,\n        latents,\n    ):\n        \"\"\"Hunyuan3D-specific CFG: concat latents, single forward, then split.\"\"\"\n        cond = pos_cond_kwargs.get(\"encoder_hidden_states\")\n        do_cfg = batch.do_classifier_free_guidance\n\n        if do_cfg:\n            latent_input = torch.cat([latent_model_input] * 2)\n        else:\n            latent_input = latent_model_input\n\n        timestep_expanded = timestep.expand(latent_input.shape[0]).to(latents.dtype)\n\n        with set_forward_context(\n            current_timestep=timestep_index,\n            attn_metadata=attn_metadata,\n            forward_batch=batch,\n        ):\n            noise_pred = self._predict_noise(\n                current_model=current_model,\n                latent_model_input=latent_input,\n                timestep=timestep_expanded,\n                target_dtype=target_dtype,\n                guidance=guidance,\n                encoder_hidden_states=cond,\n            )\n\n        if do_cfg:\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + current_guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n\n        return noise_pred\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        result = VerificationResult()\n        result.add_check(\"timesteps\", batch.timesteps, [V.is_tensor, V.min_dims(1)])\n        result.add_check(\"latents\", batch.latents, V.is_tensor)\n        result.add_check(\"prompt_embeds\", batch.prompt_embeds, V.list_not_empty)\n        result.add_check(\n            \"num_inference_steps\", batch.num_inference_steps, V.positive_int\n        )\n        result.add_check(\"guidance_scale\", batch.guidance_scale, V.non_negative_float)\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        result = 
VerificationResult()\n        result.add_check(\"latents\", batch.latents, V.is_tensor)\n        return result\n\n\nclass Hunyuan3DShapeExportStage(PipelineStage):\n    \"\"\"VAE decoding and mesh extraction stage.\"\"\"\n\n    def __init__(self, vae: Any, config: Hunyuan3D2PipelineConfig) -> None:\n        super().__init__()\n        self.vae = vae\n        self.config = config\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        if self.config.shape_mc_algo is not None:\n            try:\n                from sglang.multimodal_gen.runtime.models.vaes.hunyuan3d_vae import (\n                    SurfaceExtractors,\n                )\n\n                self.vae.surface_extractor = SurfaceExtractors[\n                    self.config.shape_mc_algo\n                ]()\n            except ImportError:\n                logger.warning(\n                    f\"Could not load SurfaceExtractors for mc_algo={self.config.shape_mc_algo}\"\n                )\n\n        latents = batch.latents\n\n        if self.config.shape_output_type != \"latent\":\n            latents = 1.0 / self.vae.scale_factor * latents\n            latents = self.vae(latents)\n\n            outputs = self.vae.latents2mesh(\n                latents,\n                bounds=self.config.shape_box_v,\n                mc_level=self.config.shape_mc_level,\n                num_chunks=self.config.shape_num_chunks,\n                octree_resolution=self.config.shape_octree_resolution,\n                mc_algo=self.config.shape_mc_algo,\n                enable_pbar=False,\n            )\n        else:\n            outputs = latents\n\n        if self.config.shape_output_type == \"trimesh\":\n            outputs = export_to_trimesh(outputs)\n\n        batch.extra[\"shape_meshes\"] = outputs\n        return batch\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        result = VerificationResult()\n        result.add_check(\"latents\", 
batch.latents, V.is_tensor)\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        result = VerificationResult()\n        result.add_check(\"shape_meshes\", batch.extra.get(\"shape_meshes\"), V.not_none)\n        return result\n\n\nclass Hunyuan3DShapeSaveStage(PipelineStage):\n    \"\"\"Mesh file export and output decision stage.\"\"\"\n\n    def __init__(self, config: Hunyuan3D2PipelineConfig) -> None:\n        super().__init__()\n        self.config = config\n\n    def _get_output_paths(self, batch: Req) -> tuple[str, str]:\n        output_path = batch.output_file_path() or os.path.join(\n            batch.output_path, \"output.obj\"\n        )\n        if output_path.endswith(\".glb\"):\n            obj_path = output_path[:-4] + \".obj\"\n            return obj_path, output_path\n        if output_path.endswith(\".obj\"):\n            return output_path, output_path\n        return output_path + \".obj\", output_path + \".obj\"\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req | OutputBatch:\n        mesh_outputs = batch.extra[\"shape_meshes\"]\n        mesh = mesh_outputs[0] if isinstance(mesh_outputs, list) else mesh_outputs\n        if isinstance(mesh, list):\n            mesh = mesh[0]\n\n        if mesh is None:\n            if batch.is_warmup:\n                logger.info(\n                    \"Skipping mesh export during warmup \"\n                    \"(surface extraction returned None)\"\n                )\n                batch.extra[\"_mesh_failed\"] = True\n                if self.config.paint_enable:\n                    return batch\n                return OutputBatch(output_file_paths=[], metrics=batch.metrics)\n            raise RuntimeError(\n                \"Mesh generation failed: surface extraction returned None. 
\"\n                \"The surface level may be outside the volume data range.\"\n            )\n\n        obj_path, return_path = self._get_output_paths(batch)\n        output_dir = os.path.dirname(obj_path)\n        if output_dir:\n            os.makedirs(output_dir, exist_ok=True)\n        mesh.export(obj_path)\n\n        batch.extra[\"shape_obj_path\"] = obj_path\n        batch.extra[\"shape_return_path\"] = return_path\n\n        if self.config.paint_enable:\n            return batch\n\n        if return_path.endswith(\".glb\"):\n            return_path = obj_path\n        return OutputBatch(output_file_paths=[return_path], timings=batch.timings)\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        result = VerificationResult()\n        result.add_check(\"shape_meshes\", batch.extra.get(\"shape_meshes\"), V.not_none)\n        return result\n\n\n__all__ = [\n    \"retrieve_timesteps\",\n    \"Hunyuan3DShapeBeforeDenoisingStage\",\n    \"Hunyuan3DShapeDenoisingStage\",\n    \"Hunyuan3DShapeExportStage\",\n    \"Hunyuan3DShapeSaveStage\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nImage encoding stages for I2V diffusion pipelines.\n\nThis module contains implementations of image encoding stages for diffusion pipelines.\n\"\"\"\n\nimport inspect\n\nimport PIL\nimport torch\nfrom diffusers.models.autoencoders.vae import DiagonalGaussianDistribution\nfrom diffusers.models.modeling_outputs import AutoencoderKLOutput\n\nfrom sglang.multimodal_gen.configs.pipeline_configs.qwen_image import (\n    qwen_image_postprocess_text,\n)\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context\nfrom sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE\nfrom sglang.multimodal_gen.runtime.models.vision_utils import (\n    normalize,\n    numpy_to_pt,\n    pil_to_numpy,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    StageValidators as V,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\n\nlogger = init_logger(__name__)\n\n\nclass ImageEncodingStage(PipelineStage):\n    \"\"\"\n    Stage for encoding image prompts into embeddings for diffusion models.\n\n    This stage handles the encoding of image prompts into the embedding space\n    expected by the diffusion model.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_processor,\n        image_encoder=None,\n        
text_encoder=None,\n    ) -> None:\n        \"\"\"\n        Initialize the image encoding stage.\n\n        Args:\n            image_processor: Processor that turns input images (and optional text) into model inputs.\n            image_encoder: Optional vision encoder that produces image embeddings.\n            text_encoder: Optional multimodal encoder (e.g. Qwen-Image-Edit) that encodes input_ids and pixel values.\n        \"\"\"\n        super().__init__()\n        self.image_processor = image_processor\n        self.image_encoder = image_encoder\n        self.text_encoder = text_encoder\n\n    def load_model(self):\n        if self.server_args.image_encoder_cpu_offload:\n            device = get_local_torch_device()\n            self.move_to_device(device)\n\n    def offload_model(self):\n        if self.server_args.image_encoder_cpu_offload:\n            self.move_to_device(\"cpu\")\n\n    def move_to_device(self, device):\n        if self.server_args.use_fsdp_inference:\n            return\n        fields = [\n            \"image_processor\",\n            \"image_encoder\",\n        ]\n        for field in fields:\n            processor = getattr(self, field, None)\n            if processor and hasattr(processor, \"to\"):\n                setattr(self, field, processor.to(device))\n\n    def encoding_qwen_image_edit(self, outputs, image_inputs):\n        # encoder hidden state\n        prompt_embeds = qwen_image_postprocess_text(outputs, image_inputs, 64)\n        return prompt_embeds\n\n    @torch.no_grad()\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n        \"\"\"\n        Encode the condition image into image embeddings, or into prompt embeddings when a text encoder is used.\n        \"\"\"\n\n        if batch.condition_image is None:\n            return batch\n        cuda_device = get_local_torch_device()\n\n        self.load_model()\n\n        image_processor_kwargs = (\n            
server_args.pipeline_config.prepare_image_processor_kwargs(batch)\n        )\n        per_prompt_images = image_processor_kwargs.pop(\"per_prompt_images\", None)\n        texts = image_processor_kwargs.pop(\"text\", None)\n\n        if per_prompt_images is None:\n            per_prompt_images 
= [batch.condition_image]\n            texts = [None] if texts is None else texts\n\n        all_prompt_embeds = []\n        all_neg_prompt_embeds = []\n\n        image_processor_call_params = inspect.signature(\n            self.image_processor.__call__\n        ).parameters\n        image_processor_kwargs = {\n            k: v\n            for k, v in image_processor_kwargs.items()\n            if k in image_processor_call_params\n        }\n\n        for idx, prompt_images in enumerate(per_prompt_images):\n            if not prompt_images:\n                continue\n\n            cur_kwargs = image_processor_kwargs.copy()\n            if texts and idx < len(texts) and \"text\" in image_processor_call_params:\n                cur_kwargs[\"text\"] = [texts[idx]]\n\n            image_inputs = self.image_processor(\n                images=prompt_images, return_tensors=\"pt\", **cur_kwargs\n            ).to(cuda_device)\n\n            if self.image_encoder:\n                # if an image encoder is provided\n                with set_forward_context(current_timestep=0, attn_metadata=None):\n                    outputs = self.image_encoder(\n                        **image_inputs,\n                        **server_args.pipeline_config.image_encoder_extra_args,\n                    )\n                    image_embeds = server_args.pipeline_config.postprocess_image(\n                        outputs\n                    )\n                batch.image_embeds.append(image_embeds)\n            elif self.text_encoder:\n                # if a text encoder is provided, e.g. Qwen-Image-Edit\n                # 1. 
neg prompt embeds\n                if batch.do_classifier_free_guidance:\n                    neg_image_processor_kwargs = (\n                        server_args.pipeline_config.prepare_image_processor_kwargs(\n                            batch, neg=True\n                        )\n                    )\n                    neg_image_processor_kwargs.pop(\"per_prompt_images\", None)\n                    neg_texts = neg_image_processor_kwargs.pop(\"text\", None)\n                    if neg_texts and idx < len(neg_texts):\n                        neg_image_processor_kwargs[\"text\"] = [neg_texts[idx]]\n                    neg_image_inputs = self.image_processor(\n                        images=prompt_images,\n                        return_tensors=\"pt\",\n                        **neg_image_processor_kwargs,\n                    ).to(cuda_device)\n\n                with set_forward_context(current_timestep=0, attn_metadata=None):\n                    outputs = self.text_encoder(\n                        input_ids=image_inputs.input_ids,\n                        attention_mask=image_inputs.attention_mask,\n                        pixel_values=image_inputs.pixel_values,\n                        image_grid_thw=image_inputs.image_grid_thw,\n                        output_hidden_states=True,\n                    )\n                    if batch.do_classifier_free_guidance:\n                        neg_outputs = self.text_encoder(\n                            input_ids=neg_image_inputs.input_ids,\n                            attention_mask=neg_image_inputs.attention_mask,\n                            pixel_values=neg_image_inputs.pixel_values,\n                            image_grid_thw=neg_image_inputs.image_grid_thw,\n                            output_hidden_states=True,\n                        )\n\n                all_prompt_embeds.append(\n                    self.encoding_qwen_image_edit(outputs, image_inputs)\n                )\n                if 
batch.do_classifier_free_guidance:\n                    all_neg_prompt_embeds.append(\n                        self.encoding_qwen_image_edit(neg_outputs, neg_image_inputs)\n                    )\n\n        if all_prompt_embeds:\n            batch.prompt_embeds.append(torch.cat(all_prompt_embeds, dim=0))\n        if all_neg_prompt_embeds:\n            batch.negative_prompt_embeds.append(torch.cat(all_neg_prompt_embeds, dim=0))\n\n        self.offload_model()\n\n        return batch\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify image encoding stage inputs.\"\"\"\n        result = VerificationResult()\n        if batch.debug:\n            logger.debug(f\"{batch.condition_image=}\")\n            logger.debug(f\"{batch.image_embeds=}\")\n        result.add_check(\"pil_image\", batch.condition_image, V.not_none)\n        result.add_check(\"image_embeds\", batch.image_embeds, V.is_list)\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify image encoding stage outputs.\"\"\"\n        result = VerificationResult()\n        # result.add_check(\"image_embeds\", batch.image_embeds, V.list_of_tensors_dims(3))\n        return result\n\n\nclass ImageVAEEncodingStage(PipelineStage):\n    \"\"\"\n    Stage for encoding pixel representations into latent space.\n\n    This stage handles the encoding of pixel representations into the final\n    input format (e.g., image_latents).\n    \"\"\"\n\n    def __init__(self, vae: ParallelTiledVAE, **kwargs) -> None:\n        super().__init__()\n        self.vae: ParallelTiledVAE = vae\n\n    def load_model(self):\n        self.vae = self.vae.to(get_local_torch_device())\n\n    def offload_model(self):\n        if self.server_args.vae_cpu_offload:\n            self.vae = self.vae.to(\"cpu\")\n\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n  
      \"\"\"\n        Encode pixel representations into latent space.\n        \"\"\"\n\n        if batch.condition_image is None:\n            return batch\n\n        self.load_model()\n        num_frames = batch.num_frames\n\n        images = (\n            batch.vae_image if batch.vae_image is not None else batch.condition_image\n        )\n        if not isinstance(images, list):\n            images = [images]\n\n        all_image_latents = []\n        prepare_condition_image_latent_ids = getattr(\n            server_args.pipeline_config, \"prepare_condition_image_latent_ids\", None\n        )\n        condition_latents = [] if callable(prepare_condition_image_latent_ids) else None\n        for image in images:\n            image = self.preprocess(\n                image,\n            ).to(get_local_torch_device(), dtype=torch.float32)\n\n            # (B, C, H, W) -> (B, C, 1, H, W)\n            image = image.unsqueeze(2)\n\n            if num_frames == 1:\n                video_condition = image\n            else:\n                video_condition = torch.cat(\n                    [\n                        image,\n                        image.new_zeros(\n                            image.shape[0],\n                            image.shape[1],\n                            num_frames - 1,\n                            image.shape[3],\n                            image.shape[4],\n                        ),\n                    ],\n                    dim=2,\n                )\n            video_condition = video_condition.to(\n                device=get_local_torch_device(), dtype=torch.float32\n            )\n\n            # Setup VAE precision\n            vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]\n            vae_autocast_enabled = (\n                vae_dtype != torch.float32\n            ) and not server_args.disable_autocast\n\n            # Encode Image\n            with torch.autocast(\n                
device_type=current_platform.device_type,\n                dtype=vae_dtype,\n                enabled=vae_autocast_enabled,\n            ):\n                if server_args.pipeline_config.vae_tiling:\n                    self.vae.enable_tiling()\n                # if server_args.vae_sp:\n                #     self.vae.enable_parallel()\n                if not vae_autocast_enabled:\n                    video_condition = video_condition.to(vae_dtype)\n                latent_dist: DiagonalGaussianDistribution = self.vae.encode(\n                    video_condition\n                )\n                # for auto_encoder from diffusers\n                if isinstance(latent_dist, AutoencoderKLOutput):\n                    latent_dist = latent_dist.latent_dist\n\n            generator = batch.generator\n            if generator is None:\n                raise ValueError(\"Generator must be provided\")\n\n            sample_mode = server_args.pipeline_config.vae_config.encode_sample_mode()\n\n            latent_condition = self.retrieve_latents(\n                latent_dist, generator, sample_mode=sample_mode\n            )\n            latent_condition = server_args.pipeline_config.postprocess_vae_encode(\n                latent_condition, self.vae\n            )\n\n            scaling_factor, shift_factor = (\n                server_args.pipeline_config.get_decode_scale_and_shift(\n                    device=latent_condition.device,\n                    dtype=latent_condition.dtype,\n                    vae=self.vae,\n                )\n            )\n\n            # apply shift & scale if needed\n            if isinstance(shift_factor, torch.Tensor):\n                shift_factor = shift_factor.to(latent_condition.device)\n\n            if isinstance(scaling_factor, torch.Tensor):\n                scaling_factor = scaling_factor.to(latent_condition.device)\n\n            latent_condition -= shift_factor\n            latent_condition = latent_condition * scaling_factor\n\n 
           if condition_latents is not None:\n                condition_latents.append(latent_condition)\n\n            image_latent = server_args.pipeline_config.postprocess_image_latent(\n                latent_condition, batch\n            )\n            all_image_latents.append(image_latent)\n\n        batch.image_latent = torch.cat(all_image_latents, dim=1)\n        if condition_latents is not None:\n            prepare_condition_image_latent_ids(condition_latents, batch)\n\n        self.offload_model()\n        return batch\n\n    def retrieve_latents(\n        self,\n        encoder_output: DiagonalGaussianDistribution,\n        generator: torch.Generator | None = None,\n        sample_mode: str = \"sample\",\n    ):\n        if sample_mode == \"sample\":\n            return encoder_output.sample(generator)\n        elif sample_mode == \"argmax\":\n            return encoder_output.mode()\n        else:\n            raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n    def preprocess(\n        self,\n        image: torch.Tensor | PIL.Image.Image,\n    ) -> torch.Tensor:\n\n        if isinstance(image, PIL.Image.Image):\n            image = pil_to_numpy(image)  # to np\n            image = numpy_to_pt(image)  # to pt\n\n        do_normalize = True\n        if image.min() < 0:\n            do_normalize = False\n        if do_normalize:\n            image = normalize(image)\n\n        return image\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify encoding stage inputs.\"\"\"\n        result = VerificationResult()\n\n        assert batch.condition_image is None or (\n            isinstance(batch.condition_image, PIL.Image.Image)\n            or isinstance(batch.condition_image, torch.Tensor)\n            or isinstance(batch.condition_image, list)\n        )\n        assert batch.height is not None and isinstance(batch.height, int)\n        assert batch.width is not None 
and isinstance(batch.width, int)\n        assert batch.num_frames is not None and isinstance(batch.num_frames, int)\n\n        result.add_check(\"generator\", batch.generator, V.generator_or_list_generators)\n        result.add_check(\"height\", batch.height, V.positive_int)\n        result.add_check(\"width\", batch.width, V.positive_int)\n        result.add_check(\"num_frames\", batch.num_frames, V.positive_int)\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify encoding stage outputs.\"\"\"\n        result = VerificationResult()\n        # result.add_check(\n        #     \"image_latent\", batch.image_latent, [V.is_tensor, V.with_dims(5)]\n        # )\n        return result\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/input_validation.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nInput validation stage for diffusion pipelines.\n\"\"\"\n\nimport numpy as np\nimport torch\nimport torchvision.transforms.functional as TF\nfrom PIL import Image\n\nfrom sglang.multimodal_gen.configs.pipeline_configs import WanI2V480PConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import ModelTaskType\nfrom sglang.multimodal_gen.configs.pipeline_configs.mova import MOVAPipelineConfig\nfrom sglang.multimodal_gen.runtime.models.vision_utils import load_image, load_video\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    StageValidators,\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import best_output_size\n\nlogger = init_logger(__name__)\n\n# Alias for convenience\nV = StageValidators\n\n\n# TODO: since this might change sampling params after logging, should be do this beforehand?\n\n\nclass InputValidationStage(PipelineStage):\n    \"\"\"\n    Stage for validating and preparing inputs for diffusion pipelines.\n\n    This stage validates that all required inputs are present and properly formatted\n    before proceeding with the diffusion process.\n\n    In this stage, input image and output image may be resized\n    \"\"\"\n\n    def __init__(self, vae_image_processor=None):\n        super().__init__()\n        self.vae_image_processor = vae_image_processor\n\n    @staticmethod\n    def _calculate_dimensions_from_area(\n        max_area: float, aspect_ratio: float, mod_value: int\n    ) -> tuple[int, 
int]:\n        \"\"\"\n        Calculate output dimensions based on maximum area and aspect ratio.\n\n        Args:\n            max_area: Maximum area constraint for the output\n            aspect_ratio: Target aspect ratio (height/width)\n            mod_value: Value to round dimensions to (typically vae_scale * patch_size)\n\n        Returns:\n            Tuple of (width, height) rounded to multiples of mod_value\n        \"\"\"\n        height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value\n        width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value\n        return width, height\n\n    def _generate_seeds(self, batch: Req, server_args: ServerArgs):\n        \"\"\"Generate seeds for the inference\"\"\"\n        seed = batch.seed\n        num_videos_per_prompt = batch.num_outputs_per_prompt\n\n        assert seed is not None\n        seeds = [seed + i for i in range(num_videos_per_prompt)]\n        batch.seeds = seeds\n\n        # Create generators based on generator_device parameter\n        # Note: This will overwrite any existing batch.generator\n        generator_device = batch.generator_device\n\n        if generator_device == \"cpu\":\n            device_str = \"cpu\"\n        else:\n            device_str = current_platform.device_type\n\n        batch.generator = [\n            torch.Generator(device_str).manual_seed(seed) for seed in seeds\n        ]\n\n    def preprocess_condition_image(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n        condition_image_width,\n        condition_image_height,\n    ):\n        \"\"\"\n        preprocess condition image\n        NOTE: condition image resizing is only allowed in InputValidationStage\n        \"\"\"\n        if batch.condition_image is not None and (\n            server_args.pipeline_config.task_type == ModelTaskType.I2I\n            or server_args.pipeline_config.task_type == ModelTaskType.TI2I\n        ):\n            # calculate new 
condition image size\n            if not isinstance(batch.condition_image, list):\n                batch.condition_image = [batch.condition_image]\n\n            processed_images = []\n            final_image = batch.condition_image[-1]\n            config = server_args.pipeline_config\n            config.preprocess_vae_image(batch, self.vae_image_processor)\n\n            for img in batch.condition_image:\n                size = config.calculate_condition_image_size(img, img.width, img.height)\n                if size is not None:\n                    width, height = size\n                    img, _ = config.preprocess_condition_image(\n                        img, width, height, self.vae_image_processor\n                    )\n\n                processed_images.append(img)\n\n            batch.condition_image = processed_images\n            calculated_size = config.prepare_calculated_size(final_image)\n\n            # adjust output image size\n            if calculated_size is not None:\n                calculated_width, calculated_height = calculated_size\n                width = batch.width or calculated_width\n                height = batch.height or calculated_height\n                multiple_of = (\n                    server_args.pipeline_config.vae_config.get_vae_scale_factor() * 2\n                )\n                width = width // multiple_of * multiple_of\n                height = height // multiple_of * multiple_of\n                batch.width = width\n                batch.height = height\n\n        elif server_args.pipeline_config.task_type == ModelTaskType.TI2V:\n            if server_args.pipeline_config.skip_input_image_preprocess:\n                return\n            # duplicate with vae_image_processor\n            # further processing for ti2v task\n            if isinstance(\n                batch.condition_image, list\n            ):  # multi-image input is not supported yet.\n                batch.condition_image = batch.condition_image[0]\n\n   
         img = batch.condition_image\n            ih, iw = img.height, img.width\n            patch_size = server_args.pipeline_config.dit_config.arch_config.patch_size\n            vae_stride = (\n                server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial\n            )\n            dh, dw = patch_size[1] * vae_stride, patch_size[2] * vae_stride\n            max_area = 704 * 1280\n            ow, oh = best_output_size(iw, ih, dw, dh, max_area)\n\n            scale = max(ow / iw, oh / ih)\n            img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)\n            logger.debug(\"resized condition image to: %sx%s\", img.height, img.width)\n\n            # center-crop\n            x1 = (img.width - ow) // 2\n            y1 = (img.height - oh) // 2\n            img = img.crop((x1, y1, x1 + ow, y1 + oh))\n            assert img.width == ow and img.height == oh\n\n            # to tensor\n            img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1)\n            img = img.unsqueeze(0)\n            batch.height = oh\n            batch.width = ow\n            # TODO: should we store in a new field: pixel values?\n            batch.condition_image = img\n\n        elif isinstance(server_args.pipeline_config, WanI2V480PConfig):\n            # TODO: could we merge with above?\n            # resize image only, Wan2.1 I2V\n            if isinstance(batch.condition_image, list):\n                batch.condition_image = batch.condition_image[\n                    0\n                ]  # multi-image input is not supported yet.\n\n            max_area = server_args.pipeline_config.max_area\n            aspect_ratio = condition_image_height / condition_image_width\n            mod_value = (\n                server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial\n                * server_args.pipeline_config.dit_config.arch_config.patch_size[1]\n            )\n            width, height = 
self._calculate_dimensions_from_area(\n                max_area, aspect_ratio, mod_value\n            )\n\n            batch.condition_image = batch.condition_image.resize((width, height))\n            batch.height = height\n            batch.width = width\n\n        elif issubclass(type(server_args.pipeline_config), MOVAPipelineConfig):\n            # resize image only, MOVA\n            image = batch.condition_image\n            if isinstance(image, list):\n                image = image[0]  # multi-image input is not supported yet.\n\n            max_area = server_args.pipeline_config.max_area\n            if hasattr(batch, \"height\") and hasattr(batch, \"width\"):\n                aspect_ratio = batch.height / batch.width\n            else:\n                aspect_ratio = (\n                    batch.sampling_params.height / batch.sampling_params.width\n                )\n            mod_value = (\n                server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial\n                * server_args.pipeline_config.dit_config.arch_config.patch_size[1]\n            )\n            width, height = self._calculate_dimensions_from_area(\n                max_area, aspect_ratio, mod_value\n            )\n\n            config = server_args.pipeline_config\n            image, (final_w, final_h) = (\n                server_args.pipeline_config.preprocess_condition_image(\n                    image, width, height, self.vae_image_processor\n                )\n            )\n            batch.condition_image = image\n            batch.width = final_w\n            batch.height = final_h\n\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n        \"\"\"\n        Validate and prepare inputs.\n        \"\"\"\n\n        self._generate_seeds(batch, server_args)\n\n        if (\n            server_args.pipeline_config.task_type == ModelTaskType.I2M\n            and batch.num_inference_steps is None\n            and 
hasattr(server_args.pipeline_config, \"shape_num_inference_steps\")\n        ):\n            batch.num_inference_steps = (\n                server_args.pipeline_config.shape_num_inference_steps\n            )\n\n        # Ensure prompt is properly formatted (I2M can be image-only)\n        if (\n            server_args.pipeline_config.task_type != ModelTaskType.I2M\n            and batch.prompt is None\n            and batch.prompt_embeds is None\n        ):\n            raise ValueError(\"Either `prompt` or `prompt_embeds` must be provided\")\n\n        # Ensure negative prompt is properly formatted if using classifier-free guidance\n        if (\n            batch.do_classifier_free_guidance\n            and batch.negative_prompt is None\n            and batch.negative_prompt_embeds is None\n        ):\n            raise ValueError(\n                \"For classifier-free guidance, either `negative_prompt` or \"\n                \"`negative_prompt_embeds` must be provided\"\n            )\n\n        # Validate number of inference steps\n        if batch.num_inference_steps <= 0:\n            raise ValueError(\n                f\"Number of inference steps must be positive, but got {batch.num_inference_steps}\"\n            )\n\n        # Validate guidance scale if using classifier-free guidance\n        if batch.do_classifier_free_guidance and batch.guidance_scale < 0:\n            raise ValueError(\n                f\"Guidance scale must be non-negative, but got {batch.guidance_scale}\"\n            )\n\n        # for i2v, get image from image_path\n        # @TODO(Wei) hard-coded for wan2.2 5b ti2v for now. 
Should put this in image_encoding stage\n        if batch.image_path is not None:\n            if isinstance(batch.image_path, list):\n                batch.condition_image = []\n                for path in batch.image_path:\n                    if path.endswith(\".mp4\"):\n                        image = load_video(path)[0]\n                    else:\n                        image = load_image(path)\n                    batch.condition_image.append(image)\n\n                # Use the first image for size reference\n                condition_image_width = batch.condition_image[0].width\n                condition_image_height = batch.condition_image[0].height\n                batch.original_condition_image_size = (\n                    condition_image_width,\n                    condition_image_height,\n                )\n            else:\n                if batch.image_path.endswith(\".mp4\"):\n                    image = load_video(batch.image_path)[0]\n                else:\n                    image = load_image(batch.image_path)\n                batch.condition_image = image\n                condition_image_width, condition_image_height = (\n                    image.width,\n                    image.height,\n                )\n                batch.original_condition_image_size = image.size\n\n            if server_args.pipeline_config.task_type != ModelTaskType.I2M:\n                self.preprocess_condition_image(\n                    batch, server_args, condition_image_width, condition_image_height\n                )\n\n        # if height or width is not specified at this point, set default to 720p\n        default_height = 720\n        default_width = 1280\n        if batch.height is None and batch.width is None:\n            batch.height = default_height\n            batch.width = default_width\n        elif batch.height is None:\n            batch.height = batch.width * default_height // default_width\n        elif batch.width is None:\n            
batch.width = batch.height * default_width // default_height\n\n        return batch\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify input validation stage inputs.\"\"\"\n        result = VerificationResult()\n        result.add_check(\"seed\", batch.seed, [V.not_none, V.non_negative_int])\n        result.add_check(\n            \"num_videos_per_prompt\", batch.num_outputs_per_prompt, V.positive_int\n        )\n        if server_args.pipeline_config.task_type != ModelTaskType.I2M:\n            result.add_check(\n                \"prompt_or_embeds\",\n                None,\n                lambda _: V.string_or_list_strings(batch.prompt)\n                or V.list_not_empty(batch.prompt_embeds),\n            )\n\n        if server_args.pipeline_config.task_type != ModelTaskType.I2M:\n            result.add_check(\n                \"num_inference_steps\", batch.num_inference_steps, V.positive_int\n            )\n        else:\n            result.add_check(\n                \"num_inference_steps\",\n                batch.num_inference_steps,\n                lambda x: x is None or V.positive_int(x),\n            )\n        result.add_check(\n            \"guidance_scale\",\n            batch.guidance_scale,\n            lambda x: not batch.do_classifier_free_guidance or V.non_negative_float(x),\n        )\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify input validation stage outputs.\"\"\"\n        result = VerificationResult()\n        result.add_check(\"height\", batch.height, V.positive_int)\n        result.add_check(\"width\", batch.width, V.positive_int)\n        result.add_check(\"seeds\", batch.seeds, V.list_not_empty)\n        result.add_check(\"generator\", batch.generator, V.generator_or_list_generators)\n        return result\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/latent_preparation.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nLatent preparation stage for diffusion pipelines.\n\"\"\"\n\nfrom diffusers.utils.torch_utils import randn_tensor\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    StageValidators as V,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass LatentPreparationStage(PipelineStage):\n    \"\"\"\n    Stage for preparing initial latent variables for the diffusion process.\n\n    This stage handles the preparation of the initial latent variables that will be\n    denoised during the diffusion process.\n    \"\"\"\n\n    def __init__(self, scheduler, transformer) -> None:\n        super().__init__()\n        self.scheduler = scheduler\n        self.transformer = transformer\n\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n        \"\"\"\n        Prepare initial latent variables for the diffusion process.\n\n\n\n        Returns:\n            The batch with prepared latent variables.\n        \"\"\"\n\n        # Adjust video length based on VAE version if needed\n        latent_num_frames = self.adjust_video_length(batch, server_args)\n\n        batch_size = batch.batch_size\n\n        # Get required parameters\n        dtype = batch.prompt_embeds[0].dtype\n        device = get_local_torch_device()\n        generator = batch.generator\n        latents = batch.latents\n        
num_frames = (\n            latent_num_frames if latent_num_frames is not None else batch.num_frames\n        )\n        height = batch.height\n        width = batch.width\n\n        # TODO(will): remove this once we add input/output validation for stages\n        if height is None or width is None:\n            raise ValueError(\"Height and width must be provided\")\n\n        # Validate generator if it's a list\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        # Generate or use provided latents\n        if latents is None:\n            shape = server_args.pipeline_config.prepare_latent_shape(\n                batch, batch_size, num_frames\n            )\n            latents = randn_tensor(\n                shape, generator=generator, device=device, dtype=dtype\n            )\n\n            latent_ids = server_args.pipeline_config.maybe_prepare_latent_ids(latents)\n\n            if latent_ids is not None:\n                batch.latent_ids = latent_ids.to(device=device)\n\n            latents = server_args.pipeline_config.maybe_pack_latents(\n                latents, batch_size, batch\n            )\n        else:\n            latents = latents.to(device)\n\n        # Scale the initial noise if needed\n        if hasattr(self.scheduler, \"init_noise_sigma\"):\n            latents = latents * self.scheduler.init_noise_sigma\n        # Update batch with prepared latents\n        batch.latents = latents\n        batch.raw_latent_shape = latents.shape\n        return batch\n\n    def adjust_video_length(self, batch: Req, server_args: ServerArgs) -> int:\n        \"\"\"\n        Adjust video length based on VAE version.\n        \"\"\"\n\n        video_length = 
batch.num_frames\n        latent_num_frames = video_length\n        use_temporal_scaling_frames = (\n            server_args.pipeline_config.vae_config.use_temporal_scaling_frames\n        )\n        if use_temporal_scaling_frames:\n            temporal_scale_factor = (\n                server_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio\n            )\n            latent_num_frames = (video_length - 1) // temporal_scale_factor + 1\n        return int(latent_num_frames)\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify latent preparation stage inputs.\"\"\"\n        result = VerificationResult()\n        result.add_check(\n            \"prompt_or_embeds\",\n            None,\n            lambda _: V.string_or_list_strings(batch.prompt)\n            or V.list_not_empty(batch.prompt_embeds),\n        )\n        result.add_check(\"prompt_embeds\", batch.prompt_embeds, V.list_of_tensors)\n        result.add_check(\n            \"num_videos_per_prompt\", batch.num_outputs_per_prompt, V.positive_int\n        )\n        result.add_check(\"generator\", batch.generator, V.generator_or_list_generators)\n        result.add_check(\"num_frames\", batch.num_frames, V.positive_int)\n        result.add_check(\"height\", batch.height, V.positive_int)\n        result.add_check(\"width\", batch.width, V.positive_int)\n        result.add_check(\"latents\", batch.latents, V.none_or_tensor)\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify latent preparation stage outputs.\"\"\"\n        result = VerificationResult()\n        if batch.debug:\n            logger.debug(f\"{batch.raw_latent_shape=}\")\n        # disable temporarily for image-generation models\n        # result.add_check(\"latents\", batch.latents, [V.is_tensor, V.with_dims(5)])\n        result.add_check(\"raw_latent_shape\", batch.raw_latent_shape, 
V.is_tuple)\n        return result\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/latent_preparation_av.py",
    "content": "import torch\nfrom diffusers.utils.torch_utils import randn_tensor\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.latent_preparation import (\n    LatentPreparationStage,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    StageValidators as V,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass LTX2AVLatentPreparationStage(LatentPreparationStage):\n    \"\"\"\n    LTX-2 specific latent preparation stage that handles both video and audio latents.\n    \"\"\"\n\n    def __init__(self, scheduler, transformer=None, audio_vae=None):\n        super().__init__(scheduler, transformer)\n        self.audio_vae = audio_vae\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify latent preparation stage inputs.\"\"\"\n        result = VerificationResult()\n        result.add_check(\n            \"prompt_or_embeds\",\n            None,\n            lambda _: V.string_or_list_strings(batch.prompt)\n            or V.list_not_empty(batch.prompt_embeds)\n            or V.is_tensor(batch.prompt_embeds),\n        )\n\n        if isinstance(batch.prompt_embeds, list):\n            result.add_check(\"prompt_embeds\", batch.prompt_embeds, V.list_of_tensors)\n        else:\n            result.add_check(\"prompt_embeds\", batch.prompt_embeds, V.is_tensor)\n\n        result.add_check(\n            \"num_videos_per_prompt\", batch.num_outputs_per_prompt, V.positive_int\n        )\n        result.add_check(\"generator\", batch.generator, V.generator_or_list_generators)\n        
result.add_check(\"num_frames\", batch.num_frames, V.positive_int)\n        result.add_check(\"height\", batch.height, V.positive_int)\n        result.add_check(\"width\", batch.width, V.positive_int)\n        result.add_check(\"latents\", batch.latents, V.none_or_tensor)\n        return result\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        # 1. Prepare Video Latents using base class logic\n        # This sets batch.latents and batch.raw_latent_shape\n        batch = super().forward(batch, server_args)\n\n        # 2. Prepare Audio Latents (optional)\n        # Default to True if not specified\n        try:\n            generate_audio = batch.generate_audio\n        except AttributeError:\n            generate_audio = True\n        if not generate_audio:\n            batch.audio_latents = None\n            batch.raw_audio_latent_shape = None\n            return batch\n\n        device = get_local_torch_device()\n        if isinstance(batch.prompt_embeds, list) and batch.prompt_embeds:\n            dtype = batch.prompt_embeds[0].dtype\n        elif isinstance(batch.prompt_embeds, torch.Tensor):\n            dtype = batch.prompt_embeds.dtype\n        else:\n            dtype = torch.float16\n        generator = batch.generator\n\n        audio_latents = batch.audio_latents\n        batch_size = batch.batch_size\n        num_frames = batch.num_frames\n\n        if audio_latents is None:\n            shape = server_args.pipeline_config.prepare_audio_latent_shape(\n                batch, batch_size, num_frames\n            )\n\n            audio_latents = randn_tensor(\n                shape, generator=generator, device=device, dtype=dtype\n            )\n        else:\n            audio_latents = audio_latents.to(device)\n\n        audio_latents = server_args.pipeline_config.maybe_pack_audio_latents(\n            audio_latents, batch_size, batch\n        )\n\n        # Store in batch\n        batch.audio_latents = audio_latents\n        
batch.raw_audio_latent_shape = audio_latents.shape\n\n        return batch\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/glm_image.py",
    "content": "import inspect\nimport re\nimport time\nfrom math import sqrt\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL\nimport torch\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.utils.torch_utils import randn_tensor\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context\nfrom sglang.multimodal_gen.runtime.models.dits.glm_image import GlmImageKVCache\nfrom sglang.multimodal_gen.runtime.models.vision_utils import load_image\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    base_shift: float = 0.25,\n    max_shift: float = 0.75,\n) -> float:\n    m = (image_seq_len / base_seq_len) ** 0.5\n    mu = m * max_shift + base_shift\n    return mu\n\n\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`.\n    \"\"\"\n    accepts_timesteps = \"timesteps\" in set(\n        inspect.signature(scheduler.set_timesteps).parameters.keys()\n    )\n    accepts_sigmas = \"sigmas\" in set(\n        inspect.signature(scheduler.set_timesteps).parameters.keys()\n    )\n\n    if timesteps is not None and sigmas is not None:\n        if not accepts_timesteps and not accepts_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep or sigma schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(\n            timesteps=timesteps, sigmas=sigmas, device=device, **kwargs\n        )\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif timesteps is not None and sigmas is None:\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif timesteps is None and sigmas is not None:\n        if not accepts_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor,\n    generator: Optional[torch.Generator] = None,\n    sample_mode: str = \"sample\",\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\nclass GlmImageBeforeDenoisingStage(PipelineStage):\n    r\"\"\"\n    Pipeline for text-to-image generation using GLM-Image.\n\n    This pipeline integrates both the AR (autoregressive) model for token generation and the DiT (diffusion\n    transformer) model for image decoding.\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`T5EncoderModel`]):\n            Frozen text-encoder for glyph embeddings.\n        tokenizer (`PreTrainedTokenizer`):\n            Tokenizer for the text encoder.\n        processor (`AutoProcessor`):\n            Processor for the AR model to handle chat templates and tokenization.\n        vision_language_encoder ([`GlmImageForConditionalGeneration`]):\n            The AR model that generates image 
tokens from text prompts.\n        transformer ([`GlmImageTransformer2DModel`]):\n            A text conditioned transformer to denoise the encoded image latents (DiT).\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n    \"\"\"\n\n    def __init__(\n        self,\n        tokenizer,\n        processor,\n        text_encoder,\n        vision_language_encoder,\n        vae,\n        transformer,\n        scheduler,\n    ):\n        super().__init__()\n\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.text_encoder = text_encoder\n        self.vision_language_encoder = vision_language_encoder\n        self.vae = vae\n        self.transformer = transformer\n        self.scheduler = scheduler\n\n        self.vae_scale_factor = (\n            2 ** (len(self.vae.config.block_out_channels) - 1)\n            if getattr(self, \"vae\", None)\n            else 8\n        )\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n        self.default_sample_size = (\n            self.transformer.config.sample_size\n            if hasattr(self, \"transformer\")\n            and self.transformer is not None\n            and hasattr(self.transformer.config, \"sample_size\")\n            else 128\n        )\n\n    def _parse_and_expand_shape_info(\n        self, prompt: str\n    ) -> Tuple[str, int, int, int, int]:\n        \"\"\"\n        Parse the shape info from prompt and expand it for AR model.\n\n        Args:\n            prompt: The prompt containing <sop>H W<eop> shape specification\n\n        Returns:\n            Tuple of (expanded_prompt, token_h, token_w, prev_token_h, prev_token_w)\n        \"\"\"\n        match = re.search(r\"<sop>(\\d+)\\s+(\\d+)<eop>\", prompt)\n        if match is None:\n            raise ValueError(\n                f\"Prompt must contain shape info in format '<sop>H W<eop>', 
got: {prompt}\"\n            )\n\n        token_h, token_w = int(match.group(1)), int(match.group(2))\n        ratio = token_h / token_w\n        prev_token_h = int(sqrt(ratio) * 16)\n        prev_token_w = int(sqrt(1 / ratio) * 16)\n\n        old_shape = f\"<sop>{token_h} {token_w}<eop>\"\n        new_shape = (\n            f\"<sop>{token_h} {token_w}<eop><sop>{prev_token_h} {prev_token_w}<eop>\"\n        )\n        expanded_prompt = prompt.replace(old_shape, new_shape)\n\n        return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w\n\n    def _build_image_grid_thw(\n        self,\n        token_h: int,\n        token_w: int,\n        prev_token_h: int,\n        prev_token_w: int,\n        existing_grid: Optional[torch.Tensor] = None,\n        device: Optional[torch.device] = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Build image grid tensor for AR model.\n\n        For text-to-image: creates grid for large image + small image For image-to-image: appends new image to existing\n        grid\n        \"\"\"\n        if existing_grid is None or existing_grid.numel() == 0:\n            # Text-to-image: large image + small image\n            return torch.tensor(\n                [\n                    [1, token_h, token_w],\n                    [1, prev_token_h, prev_token_w],\n                ],\n                device=device,\n            )\n        else:\n            # Image-to-image: append to existing\n            return torch.cat(\n                [existing_grid, torch.tensor([[1, token_h, token_w]], device=device)],\n                dim=0,\n            )\n\n    def _calculate_ar_generation_params(\n        self,\n        token_h: int,\n        token_w: int,\n        prev_token_h: int,\n        prev_token_w: int,\n        is_text_to_image: bool,\n    ) -> Tuple[int, int]:\n        \"\"\"\n        Calculate max_new_tokens and large_image_start_offset for AR generation.\n        \"\"\"\n        large_image_tokens = token_h * token_w\n    
    small_image_tokens = prev_token_h * prev_token_w\n\n        if is_text_to_image:\n            max_new_tokens = small_image_tokens + large_image_tokens + 1\n            large_image_start_offset = small_image_tokens\n        else:\n            max_new_tokens = large_image_tokens + 1\n            large_image_start_offset = 0\n\n        return max_new_tokens, large_image_start_offset\n\n    def _extract_large_image_tokens(\n        self,\n        outputs: torch.Tensor,\n        input_length: int,\n        large_image_start_offset: int,\n        large_image_tokens: int,\n    ) -> torch.Tensor:\n        \"\"\"\n        Extract the large image tokens from AR model output.\n        \"\"\"\n        generated_tokens = outputs[0][input_length:]\n        large_image_start = large_image_start_offset\n        large_image_end = large_image_start + large_image_tokens\n        return generated_tokens[large_image_start:large_image_end]\n\n    def _upsample_d32_to_d16(\n        self, token_ids: torch.Tensor, token_h: int, token_w: int\n    ) -> torch.Tensor:\n        \"\"\"\n        Upsample token IDs from d32 format to d16 format.\n\n        AR model generates tokens at d32 resolution (each token = 32x32 pixels). DiT expects tokens at d16 resolution\n        (each token = 16x16 pixels). 
This function performs 2x nearest-neighbor upsampling.\n\n        Args:\n            token_ids: Token IDs of shape [N] where N = token_h * token_w\n            token_h: Height in d32 token units\n            token_w: Width in d32 token units\n\n        Returns:\n            Upsampled token IDs of shape [1, N*4] where N*4 = (token_h*2) * (token_w*2)\n        \"\"\"\n        # Reshape to spatial format: [1, 1, H, W]\n        token_ids = token_ids.view(1, 1, token_h, token_w)\n\n        # 2x nearest-neighbor upsampling\n        token_ids = torch.nn.functional.interpolate(\n            token_ids.float(), scale_factor=2, mode=\"nearest\"\n        ).to(dtype=torch.long)\n\n        # Flatten back to [1, H*W*4]\n        token_ids = token_ids.view(1, -1)\n\n        return token_ids\n\n    @staticmethod\n    def _compute_generation_params(\n        image_grid_thw,\n        is_text_to_image: bool,\n    ):\n        grid_sizes = []\n        grid_hw = []\n\n        for i in range(image_grid_thw.shape[0]):\n            t, h, w = image_grid_thw[i].tolist()\n            grid_sizes.append(int(h * w))\n            grid_hw.append((int(h), int(w)))\n\n        if not is_text_to_image:\n            max_new_tokens = grid_sizes[-1] + 1\n            large_image_start_offset = 0\n            target_grid_h, target_grid_w = grid_hw[-1]\n        else:\n            total_tokens = sum(grid_sizes)\n            max_new_tokens = total_tokens + 1\n            large_image_start_offset = sum(grid_sizes[1:])\n            target_grid_h, target_grid_w = grid_hw[0]\n        return max_new_tokens, large_image_start_offset, target_grid_h, target_grid_w\n\n    @staticmethod\n    def _upsample_token_ids(\n        token_ids: torch.Tensor, token_h: int, token_w: int\n    ) -> torch.Tensor:\n        token_ids = token_ids.view(1, 1, token_h, token_w)\n        token_ids = torch.nn.functional.interpolate(\n            token_ids.float(), scale_factor=2, mode=\"nearest\"\n        ).to(dtype=torch.long)\n        
token_ids = token_ids.view(1, -1)\n        return token_ids\n\n    def generate_prior_tokens(\n        self,\n        prompt: str,\n        height: int,\n        width: int,\n        image: Optional[List[PIL.Image.Image]] = None,\n        factor: int = 32,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        \"\"\"\n        Generate prior tokens using the AR (vision_language_encoder) model.\n\n        Args:\n            prompt: The text prompt with shape info (e.g., \"description<sop>36 24<eop>\")\n            height: Target height in pixels, rounded down to a multiple of `factor`\n            width: Target width in pixels, rounded down to a multiple of `factor`\n            image: Optional list of condition images for i2i\n            factor: Granularity (in pixels) that height and width are rounded down to\n\n        Returns:\n            Tuple of (prior_token_ids, prior_token_image_ids)\n            - prior_token_ids: Upsampled to d16 format, shape [1, token_h*token_w*4]\n            - prior_token_image_ids: Token IDs of the condition images, or None for t2i\n        \"\"\"\n        device = self.vision_language_encoder.device\n        height = (height // factor) * factor\n        width = (width // factor) * factor\n\n        is_text_to_image = image is None or len(image) == 0\n        # Build messages for processor\n        content = []\n        if image is not None:\n            for img in image:\n                content.append({\"type\": \"image\", \"image\": img})\n        content.append({\"type\": \"text\", \"text\": prompt})\n        messages = [{\"role\": \"user\", \"content\": content}]\n\n        inputs = self.processor.apply_chat_template(\n            messages,\n            tokenize=True,\n            target_h=height,\n            target_w=width,\n            return_dict=True,\n            return_tensors=\"pt\",\n        ).to(device)\n\n        image_grid_thw = inputs.get(\"image_grid_thw\")\n        max_new_tokens, large_image_offset, token_h, token_w = (\n            
self._compute_generation_params(\n                image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image\n            )\n        )\n\n        prior_token_image_ids = None\n        if image is not 
None:\n            prior_token_image_embed = self.vision_language_encoder.get_image_features(\n                inputs[\"pixel_values\"], image_grid_thw[:-1]\n            )\n            prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)\n            prior_token_image_ids = self.vision_language_encoder.get_image_tokens(\n                prior_token_image_embed, image_grid_thw[:-1]\n            )\n\n        # For GLM-Image, greedy decoding is not allowed; it may cause repetitive outputs.\n        # max_new_tokens must be exactly grid_h * grid_w + 1 (the +1 is for EOS).\n        outputs = self.vision_language_encoder.generate(\n            **inputs,\n            max_new_tokens=max_new_tokens,\n            do_sample=True,\n        )\n\n        prior_token_ids_d32 = self._extract_large_image_tokens(\n            outputs,\n            inputs[\"input_ids\"].shape[-1],\n            large_image_offset,\n            token_h * token_w,\n        )\n        prior_token_ids = self._upsample_token_ids(\n            prior_token_ids_d32, token_h, token_w\n        )\n\n        return prior_token_ids, prior_token_image_ids\n\n    def get_glyph_texts(self, prompt):\n        prompt = prompt[0] if isinstance(prompt, list) else prompt\n        ocr_texts = (\n            re.findall(r\"'([^']*)'\", prompt)\n            + re.findall(r\"“([^“”]*)”\", prompt)\n            + re.findall(r'\"([^\"]*)\"', prompt)\n            + re.findall(r\"「([^「」]*)」\", prompt)\n        )\n        return ocr_texts\n\n    def _get_glyph_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        max_sequence_length: int = 2048,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        glyph_texts = self.get_glyph_texts(prompt)\n        input_ids = self.tokenizer(\n            glyph_texts if len(glyph_texts) > 0 else 
[\"\"],\n            max_length=max_sequence_length,\n            truncation=True,\n        ).input_ids\n        input_ids = [\n            [self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_\n            for input_ids_ in input_ids\n        ]\n        max_length = max(len(input_ids_) for input_ids_ in input_ids)\n        attention_mask = torch.tensor(\n            [\n                [1] * len(input_ids_) + [0] * (max_length - len(input_ids_))\n                for input_ids_ in input_ids\n            ],\n            device=device,\n        )\n        input_ids = torch.tensor(\n            [\n                input_ids_\n                + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_))\n                for input_ids_ in input_ids\n            ],\n            device=device,\n        )\n        outputs = self.text_encoder(input_ids, attention_mask=attention_mask)\n        glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)\n\n        return glyph_embeds.to(device=device, dtype=dtype)\n\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        do_classifier_free_guidance: bool = True,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n        max_sequence_length: int = 2048,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):\n                Whether to use classifier free guidance or not.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            device (`torch.device`, *optional*):\n                torch device\n            dtype (`torch.dtype`, *optional*):\n                torch dtype\n            max_sequence_length (`int`, defaults to `2048`):\n                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.\n        \"\"\"\n        device = device or self._execution_device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            prompt_embeds = self._get_glyph_embeds(\n                prompt, max_sequence_length, device, dtype\n            )\n\n        seq_len = prompt_embeds.size(1)\n        prompt_embeds = prompt_embeds.repeat(1, 1, 1)\n        prompt_embeds = prompt_embeds.view(1, seq_len, -1)\n\n        negative_prompt_embeds = None\n        if do_classifier_free_guidance:\n            negative_prompt = \"\"\n            negative_prompt = (\n                batch_size * [negative_prompt]\n                if isinstance(negative_prompt, str)\n                else negative_prompt\n            )\n\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n\n            negative_prompt_embeds = self._get_glyph_embeds(\n                negative_prompt, max_sequence_length, device, dtype\n            )\n\n            seq_len = negative_prompt_embeds.size(1)\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, 1, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(1, seq_len, -1)\n\n        return prompt_embeds, negative_prompt_embeds\n\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n    ):\n\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        return latents\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_on_step_end_tensor_inputs,\n        prompt_embeds=None,\n    ):\n        # Height and width must both be multiples of the VAE downsampling factor times the patch size\n        multiple_of = self.vae_scale_factor * self.transformer.config.patch_size\n        if (height is not None and height % multiple_of != 0) or (\n            width is not None and width % multiple_of != 0\n        ):\n            logger.warning(\n                f\"`height` and `width` have to be divisible by {multiple_of} but are {height} and {width}. 
Dimensions will be resized accordingly\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs\n            for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (\n            not isinstance(prompt, str) and not isinstance(prompt, list)\n        ):\n            raise ValueError(\n                f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\"\n            )\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def attention_kwargs(self):\n        return self._attention_kwargs\n\n    @property\n    def current_timestep(self):\n        return self._current_timestep\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n\n        guidance_scale = batch.guidance_scale\n        prompt = batch.prompt\n        
num_inference_steps = batch.num_inference_steps\n        if batch.image_path is not None:\n            ar_condition_images = [\n                load_image(img_path) for img_path in batch.image_path\n            ]\n        else:\n            ar_condition_images = None\n\n        height = batch.height\n        width = batch.width\n\n        device = get_local_torch_device()\n        max_sequence_length = 1024\n        generator = torch.Generator(device=device).manual_seed(batch.seed)\n        attention_kwargs = {}\n        prompt_embeds = None\n        do_classifier_free_guidance = True\n        dtype = torch.bfloat16\n\n        self._guidance_scale = guidance_scale\n        self._current_timestep = None\n        self._interrupt = False\n\n        batch_size = 1\n\n        device = get_local_torch_device()\n\n        if ar_condition_images is not None:\n            height = height or ar_condition_images[0].height\n            width = width or ar_condition_images[0].width\n        time_start = time.time()\n        prior_token_id, prior_token_image_ids = self.generate_prior_tokens(\n            prompt=prompt,\n            image=ar_condition_images,\n            height=height,\n            width=width,\n        )\n        prior_token_id = prior_token_id.to(device=device)\n        time_end = time.time()\n        logger.info(f\"generate_prior_tokens time: {time_end - time_start}\")\n\n        # 3. Encode input prompt\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            do_classifier_free_guidance,\n            prompt_embeds=prompt_embeds,\n            max_sequence_length=max_sequence_length,\n            device=device,\n            dtype=dtype,\n        )\n\n        # 4. 
process images\n        if ar_condition_images is not None:\n            preprocessed_condition_images = []\n            for img in ar_condition_images:\n                image_height, image_width = (\n                    img.size[::-1]\n                    if isinstance(img, PIL.Image.Image)\n                    else img.shape[:2]\n                )\n                multiple_of = self.vae_scale_factor * self.transformer.config.patch_size\n                image_height = (image_height // multiple_of) * multiple_of\n                image_width = (image_width // multiple_of) * multiple_of\n                img = self.image_processor.preprocess(\n                    img, height=image_height, width=image_width\n                )\n                preprocessed_condition_images.append(img)\n            ar_condition_images = preprocessed_condition_images\n\n        # 5. Prepare latents and (optional) condition_images kv cache\n        latent_channels = self.transformer.config.in_channels\n        latents = self.prepare_latents(\n            batch_size=1,\n            num_channels_latents=latent_channels,\n            height=height,\n            width=width,\n            dtype=torch.float32,\n            device=device,\n            generator=generator,\n        )\n\n        kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)\n\n        if ar_condition_images is not None:\n            latents_mean = torch.tensor(self.vae.config.latents_mean).view(\n                1, self.vae.config.latent_channels, 1, 1\n            )\n            latents_std = torch.tensor(self.vae.config.latents_std).view(\n                1, self.vae.config.latent_channels, 1, 1\n            )\n\n            latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype)\n            latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype)\n\n            for condition_image, condition_image_prior_token_id in zip(\n                ar_condition_images, 
prior_token_image_ids\n            ):\n                condition_image = condition_image.to(\n                    device=device, dtype=prompt_embeds.dtype\n                )\n\n                condition_latent = retrieve_latents(\n                    self.vae.encode(condition_image),\n                    generator=generator,\n                    sample_mode=\"argmax\",\n                )\n                condition_latent = (condition_latent - latents_mean) / latents_std\n\n                # Do not remove.\n                # This runs the reference image through a forward pass at\n                # timestep 0 and writes its keys/values into the KV cache.\n                with set_forward_context(current_timestep=1, attn_metadata=None):\n                    _ = self.transformer(\n                        hidden_states=condition_latent,\n                        encoder_hidden_states=torch.zeros_like(prompt_embeds)[\n                            :1, :0, ...\n                        ],\n                        prior_token_id=condition_image_prior_token_id,\n                        prior_token_drop=torch.full_like(\n                            condition_image_prior_token_id, False, dtype=torch.bool\n                        ),\n                        timestep=torch.zeros((1,), device=device),\n                        target_size=torch.tensor(\n                            [condition_image.shape[-2:]], device=device\n                        ),\n                        crop_coords=torch.zeros((1, 2), device=device),\n                        attention_kwargs=attention_kwargs,\n                        kv_caches=kv_caches,\n                        kv_caches_mode=\"write\",\n                    )\n\n        # 6. 
Prepare additional timestep conditions\n        target_size = (height, width)\n        target_size = torch.tensor(\n            [target_size], dtype=prompt_embeds.dtype, device=device\n        )\n        crops_coords_top_left = torch.tensor(\n            [(0, 0)], dtype=prompt_embeds.dtype, device=device\n        )\n\n        # Prepare timesteps\n        image_seq_len = (\n            (height // self.vae_scale_factor) * (width // self.vae_scale_factor)\n        ) // (self.transformer.config.patch_size**2)\n        timesteps = np.linspace(\n            self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1\n        )[:-1]\n        timesteps = timesteps.astype(np.int64).astype(np.float32)\n        sigmas = timesteps / self.scheduler.config.num_train_timesteps\n        mu = calculate_shift(\n            image_seq_len,\n            self.scheduler.config.get(\"base_image_seq_len\", 256),\n            self.scheduler.config.get(\"base_shift\", 0.25),\n            self.scheduler.config.get(\"max_shift\", 0.75),\n        )\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu\n        )\n        self._num_timesteps = len(timesteps)\n\n        # 7. 
Prepare for denoising loop\n\n        batch.prompt_embeds = [prompt_embeds]\n        batch.negative_prompt_embeds = [negative_prompt_embeds]\n        batch.latents = latents\n        batch.timesteps = timesteps\n        batch.num_inference_steps = num_inference_steps\n        batch.sigmas = sigmas.tolist()  # Convert numpy array to list for validation\n        batch.generator = generator\n        batch.raw_latent_shape = latents.shape\n\n        batch.prior_token_id = prior_token_id\n        batch.prior_token_drop_cond = torch.full_like(\n            prior_token_id, False, dtype=torch.bool\n        )\n        batch.prior_token_drop_uncond = torch.full_like(\n            prior_token_id, True, dtype=torch.bool\n        )\n        batch.target_size = target_size\n        batch.crop_coords = crops_coords_top_left\n\n        batch.kv_caches = kv_caches\n\n        batch.height = height\n        batch.width = width\n\n        return batch\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/helios_decoding.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nHelios-specific decoding stage.\n\nDecodes latent chunks one at a time (matching diffusers HeliosPipeline behavior)\nto avoid temporal artifacts at chunk boundaries caused by Wan VAE's causal convolutions.\n\"\"\"\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.decoding import (\n    DecodingStage,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass HeliosDecodingStage(DecodingStage):\n    \"\"\"\n    Helios-specific decoding stage that decodes latent chunks independently.\n\n    The Wan VAE uses causal 3D convolutions with feature caching. When decoding\n    the full latent sequence at once, the causal conv processes all frames with\n    continuous context, producing a different number of output frames per latent\n    frame compared to chunk-by-chunk decoding. 
This causes temporal misalignment\n    and visible seams at chunk boundaries.\n\n    This stage decodes each chunk's latents separately (matching diffusers'\n    HeliosPipeline behavior) and concatenates the results in pixel space.\n    \"\"\"\n\n    @torch.no_grad()\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> OutputBatch:\n        latent_chunks = getattr(batch, \"latent_chunks\", None)\n\n        if latent_chunks is None or len(latent_chunks) <= 1:\n            # No chunked latents or single chunk — use standard decode\n            return super().forward(batch, server_args)\n\n        # Load VAE if needed\n        self.load_model()\n\n        # Decode each chunk separately and concatenate in pixel space\n        video_chunks = []\n        for chunk_latents in latent_chunks:\n            chunk_video = self.decode(chunk_latents, server_args)\n            video_chunks.append(chunk_video)\n\n        frames = torch.cat(video_chunks, dim=2)\n        frames = server_args.pipeline_config.post_decoding(frames, server_args)\n\n        output_batch = OutputBatch(\n            output=frames,\n            trajectory_timesteps=batch.trajectory_timesteps,\n            trajectory_latents=batch.trajectory_latents,\n            trajectory_decoded=None,\n            metrics=batch.metrics,\n        )\n\n        self.offload_model()\n        return output_batch\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/helios_denoising.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nHelios-specific chunked denoising stage.\n\nImplements Stage 1 chunked denoising with multi-term memory history\nand CFG Zero Star guidance. VAE decoding is handled by the standard\nDecodingStage downstream.\n\"\"\"\n\nimport math\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import (\n    PipelineStage,\n    StageParallelismType,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\n\nlogger = init_logger(__name__)\n\n\ndef optimized_scale(positive_flat, negative_flat):\n    \"\"\"CFG Zero Star: compute optimal guidance scale.\"\"\"\n    positive_flat = positive_flat.float()\n    negative_flat = negative_flat.float()\n    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)\n    squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8\n    return dot_product / squared_norm\n\n\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    max_shift: float = 1.15,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\ndef sample_block_noise(\n    batch_size, channel, num_frames, height, width, gamma, patch_size=(1, 2, 2)\n):\n    \"\"\"Generate spatially-correlated block noise for pyramid SR.\"\"\"\n    _, ph, pw = patch_size\n    block_size = ph 
* pw\n\n    # Explicitly use CPU to avoid requiring MAGMA for cholesky on ROCm/CUDA\n    cov = (\n        torch.eye(block_size, device=\"cpu\") * (1 + gamma)\n        - torch.ones(block_size, block_size, device=\"cpu\") * gamma\n    )\n    dist = torch.distributions.MultivariateNormal(\n        torch.zeros(block_size, device=\"cpu\"), covariance_matrix=cov\n    )\n    block_number = batch_size * channel * num_frames * (height // ph) * (width // pw)\n\n    noise = dist.sample((block_number,))\n    noise = noise.view(\n        batch_size, channel, num_frames, height // ph, width // pw, ph, pw\n    )\n    noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(\n        batch_size, channel, num_frames, height, width\n    )\n    return noise\n\n\nclass HeliosChunkedDenoisingStage(PipelineStage):\n    \"\"\"\n    Helios chunked denoising stage implementing Stage 1 loop.\n\n    Iterates over video chunks, manages history buffers (short/mid/long),\n    runs transformer per chunk with CFG guidance, scheduler step,\n    and accumulates denoised latents. 
VAE decoding is left to DecodingStage.\n    \"\"\"\n\n    def __init__(self, transformer, scheduler):\n        super().__init__()\n        self.transformer = transformer\n        self.scheduler = scheduler\n\n    @property\n    def parallelism_type(self):\n        return StageParallelismType.REPLICATED\n\n    def _denoise_one_chunk(\n        self,\n        latents,\n        prompt_embeds,\n        negative_prompt_embeds,\n        timesteps,\n        guidance_scale,\n        indices_hidden_states,\n        indices_latents_history_short,\n        indices_latents_history_mid,\n        indices_latents_history_long,\n        latents_history_short,\n        latents_history_mid,\n        latents_history_long,\n        target_dtype,\n        device,\n        is_cfg_zero_star=True,\n        use_zero_init=True,\n        zero_steps=1,\n        batch=None,\n        server_args=None,\n        global_step_offset=0,\n    ):\n        \"\"\"Denoise a single chunk with full timestep loop.\"\"\"\n        batch_size = latents.shape[0]\n        do_cfg = guidance_scale > 1.0\n\n        for i, t in enumerate(timesteps):\n            with StageProfiler(\n                f\"denoising_step_{global_step_offset + i}\",\n                logger=logger,\n                metrics=batch.metrics if batch is not None else None,\n                perf_dump_path_provided=(\n                    batch.perf_dump_path is not None if batch is not None else False\n                ),\n            ):\n                timestep = t.expand(batch_size)\n                latent_model_input = latents.to(target_dtype)\n\n                with set_forward_context(\n                    current_timestep=t,\n                    forward_batch=batch,\n                    attn_metadata=None,\n                ):\n                    noise_pred = self.transformer(\n                        hidden_states=latent_model_input,\n                        timestep=timestep,\n                        encoder_hidden_states=prompt_embeds,\n  
                      indices_hidden_states=indices_hidden_states,\n                        indices_latents_history_short=indices_latents_history_short,\n                        indices_latents_history_mid=indices_latents_history_mid,\n                        indices_latents_history_long=indices_latents_history_long,\n                        latents_history_short=(\n                            latents_history_short.to(target_dtype)\n                            if latents_history_short is not None\n                            else None\n                        ),\n                        latents_history_mid=(\n                            latents_history_mid.to(target_dtype)\n                            if latents_history_mid is not None\n                            else None\n                        ),\n                        latents_history_long=(\n                            latents_history_long.to(target_dtype)\n                            if latents_history_long is not None\n                            else None\n                        ),\n                    )\n\n                if do_cfg:\n                    with set_forward_context(\n                        current_timestep=t,\n                        forward_batch=batch,\n                        attn_metadata=None,\n                    ):\n                        noise_uncond = self.transformer(\n                            hidden_states=latent_model_input,\n                            timestep=timestep,\n                            encoder_hidden_states=negative_prompt_embeds,\n                            indices_hidden_states=indices_hidden_states,\n                            indices_latents_history_short=indices_latents_history_short,\n                            indices_latents_history_mid=indices_latents_history_mid,\n                            indices_latents_history_long=indices_latents_history_long,\n                            latents_history_short=(\n                                
latents_history_short.to(target_dtype)\n                                if latents_history_short is not None\n                                else None\n                            ),\n                            latents_history_mid=(\n                                latents_history_mid.to(target_dtype)\n                                if latents_history_mid is not None\n                                else None\n                            ),\n                            latents_history_long=(\n                                latents_history_long.to(target_dtype)\n                                if latents_history_long is not None\n                                else None\n                            ),\n                        )\n\n                    if is_cfg_zero_star:\n                        noise_pred_text = noise_pred\n                        positive_flat = noise_pred_text.reshape(batch_size, -1)\n                        negative_flat = noise_uncond.reshape(batch_size, -1)\n\n                        alpha = optimized_scale(positive_flat, negative_flat)\n                        alpha = alpha.view(\n                            batch_size, *([1] * (len(noise_pred_text.shape) - 1))\n                        )\n                        alpha = alpha.to(noise_pred_text.dtype)\n\n                        if (i <= zero_steps) and use_zero_init:\n                            noise_pred = noise_pred_text * 0.0\n                        else:\n                            noise_pred = noise_uncond * alpha + guidance_scale * (\n                                noise_pred_text - noise_uncond * alpha\n                            )\n                    else:\n                        noise_pred = noise_uncond + guidance_scale * (\n                            noise_pred - noise_uncond\n                        )\n\n                latents = self.scheduler.step(\n                    noise_pred, t, latents, return_dict=False\n                )[0]\n\n        return latents\n\n    
def _denoise_one_chunk_stage2(\n        self,\n        latents,\n        prompt_embeds,\n        negative_prompt_embeds,\n        guidance_scale,\n        indices_hidden_states,\n        indices_latents_history_short,\n        indices_latents_history_mid,\n        indices_latents_history_long,\n        latents_history_short,\n        latents_history_mid,\n        latents_history_long,\n        target_dtype,\n        device,\n        pyramid_num_stages,\n        pyramid_num_inference_steps_list,\n        is_distilled,\n        is_amplify_first_chunk,\n        gamma,\n        is_cfg_zero_star=True,\n        use_zero_init=True,\n        zero_steps=1,\n        batch=None,\n        server_args=None,\n        global_step_offset=0,\n    ):\n        \"\"\"Denoise a single chunk using pyramid super-resolution (Stage 2).\"\"\"\n        batch_size, num_channel, num_frames, height, width = latents.shape\n        patch_size = self.transformer.patch_size\n\n        # Downsample to lowest pyramid level\n        latents = latents.permute(0, 2, 1, 3, 4).reshape(\n            batch_size * num_frames, num_channel, height, width\n        )\n        for _ in range(pyramid_num_stages - 1):\n            height //= 2\n            width //= 2\n            latents = F.interpolate(latents, size=(height, width), mode=\"bilinear\") * 2\n        latents = latents.reshape(\n            batch_size, num_frames, num_channel, height, width\n        ).permute(0, 2, 1, 3, 4)\n\n        start_point_list = None\n        if is_distilled:\n            start_point_list = [latents]\n\n        do_cfg = guidance_scale > 1.0\n        step_counter = global_step_offset\n\n        for i_s in range(pyramid_num_stages):\n            # Compute mu for current resolution\n            image_seq_len = (\n                latents.shape[-1]\n                * latents.shape[-2]\n                * latents.shape[-3]\n                // (patch_size[0] * patch_size[1] * patch_size[2])\n            )\n            mu = 
calculate_shift(image_seq_len)\n\n            self.scheduler.set_timesteps(\n                pyramid_num_inference_steps_list[i_s],\n                i_s,\n                device=device,\n                mu=mu,\n                is_amplify_first_chunk=is_amplify_first_chunk,\n            )\n            timesteps = self.scheduler.timesteps\n\n            if i_s > 0:\n                # Upsample 2x nearest-neighbor\n                height *= 2\n                width *= 2\n                latents = latents.permute(0, 2, 1, 3, 4).reshape(\n                    batch_size * num_frames,\n                    num_channel,\n                    height // 2,\n                    width // 2,\n                )\n                latents = F.interpolate(latents, size=(height, width), mode=\"nearest\")\n                latents = latents.reshape(\n                    batch_size, num_frames, num_channel, height, width\n                ).permute(0, 2, 1, 3, 4)\n\n                # Renoise with correlated block noise\n                ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s]\n                alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)\n                beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)\n\n                bs, ch, nf, h, w = latents.shape\n                noise = sample_block_noise(bs, ch, nf, h, w, gamma, patch_size)\n                noise = noise.to(device=device, dtype=target_dtype)\n                latents = alpha * latents + beta * noise\n\n                if is_distilled:\n                    start_point_list.append(latents)\n\n            # Denoising loop for this pyramid stage\n            for idx, t in enumerate(timesteps):\n                with StageProfiler(\n                    f\"denoising_step_{step_counter}\",\n                    logger=logger,\n                    metrics=batch.metrics if batch is not None else None,\n                    perf_dump_path_provided=(\n                        batch.perf_dump_path is not None 
if batch is not None else False\n                    ),\n                ):\n                    timestep = t.expand(batch_size)\n                    latent_model_input = latents.to(target_dtype)\n\n                    with set_forward_context(\n                        current_timestep=t,\n                        forward_batch=batch,\n                        attn_metadata=None,\n                    ):\n                        noise_pred = self.transformer(\n                            hidden_states=latent_model_input,\n                            timestep=timestep,\n                            encoder_hidden_states=prompt_embeds,\n                            indices_hidden_states=indices_hidden_states,\n                            indices_latents_history_short=indices_latents_history_short,\n                            indices_latents_history_mid=indices_latents_history_mid,\n                            indices_latents_history_long=indices_latents_history_long,\n                            latents_history_short=(\n                                latents_history_short.to(target_dtype)\n                                if latents_history_short is not None\n                                else None\n                            ),\n                            latents_history_mid=(\n                                latents_history_mid.to(target_dtype)\n                                if latents_history_mid is not None\n                                else None\n                            ),\n                            latents_history_long=(\n                                latents_history_long.to(target_dtype)\n                                if latents_history_long is not None\n                                else None\n                            ),\n                        )\n\n                    if do_cfg:\n                        with set_forward_context(\n                            current_timestep=t,\n                            forward_batch=batch,\n           
                 attn_metadata=None,\n                        ):\n                            noise_uncond = self.transformer(\n                                hidden_states=latent_model_input,\n                                timestep=timestep,\n                                encoder_hidden_states=negative_prompt_embeds,\n                                indices_hidden_states=indices_hidden_states,\n                                indices_latents_history_short=indices_latents_history_short,\n                                indices_latents_history_mid=indices_latents_history_mid,\n                                indices_latents_history_long=indices_latents_history_long,\n                                latents_history_short=(\n                                    latents_history_short.to(target_dtype)\n                                    if latents_history_short is not None\n                                    else None\n                                ),\n                                latents_history_mid=(\n                                    latents_history_mid.to(target_dtype)\n                                    if latents_history_mid is not None\n                                    else None\n                                ),\n                                latents_history_long=(\n                                    latents_history_long.to(target_dtype)\n                                    if latents_history_long is not None\n                                    else None\n                                ),\n                            )\n\n                        if is_cfg_zero_star:\n                            noise_pred_text = noise_pred\n                            positive_flat = noise_pred_text.reshape(batch_size, -1)\n                            negative_flat = noise_uncond.reshape(batch_size, -1)\n\n                            alpha_cfg = optimized_scale(positive_flat, negative_flat)\n                            alpha_cfg = alpha_cfg.view(\n         
                       batch_size,\n                                *([1] * (len(noise_pred_text.shape) - 1)),\n                            )\n                            alpha_cfg = alpha_cfg.to(noise_pred_text.dtype)\n\n                            if (i_s == 0 and idx <= zero_steps) and use_zero_init:\n                                noise_pred = noise_pred_text * 0.0\n                            else:\n                                noise_pred = (\n                                    noise_uncond * alpha_cfg\n                                    + guidance_scale\n                                    * (noise_pred_text - noise_uncond * alpha_cfg)\n                                )\n                        else:\n                            noise_pred = noise_uncond + guidance_scale * (\n                                noise_pred - noise_uncond\n                            )\n\n                    latents = self.scheduler.step(\n                        noise_pred,\n                        t,\n                        latents,\n                        return_dict=False,\n                        cur_sampling_step=idx,\n                        dmd_noisy_tensor=(\n                            start_point_list[i_s]\n                            if start_point_list is not None\n                            else None\n                        ),\n                        dmd_sigmas=self.scheduler.sigmas,\n                        dmd_timesteps=self.scheduler.timesteps,\n                        all_timesteps=timesteps,\n                    )[0]\n\n                step_counter += 1\n\n        return latents, step_counter\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        \"\"\"Run the Helios chunked denoising loop.\"\"\"\n        pipeline_config = server_args.pipeline_config\n        device = (\n            batch.latents.device\n            if hasattr(batch, \"latents\") and batch.latents is not None\n            else torch.device(\"cuda\")\n        )\n 
       target_dtype = PRECISION_TO_TYPE.get(\n            server_args.pipeline_config.precision, torch.bfloat16\n        )\n\n        # Get config params\n        num_latent_frames_per_chunk = pipeline_config.num_latent_frames_per_chunk\n        history_sizes = sorted(list(pipeline_config.history_sizes), reverse=True)\n        is_cfg_zero_star = pipeline_config.is_cfg_zero_star\n        zero_steps = pipeline_config.zero_steps\n        keep_first_frame = pipeline_config.keep_first_frame\n        guidance_scale = batch.guidance_scale\n        num_inference_steps = batch.num_inference_steps\n\n        # Stage 2 params\n        is_enable_stage2 = pipeline_config.is_enable_stage2\n        pyramid_num_stages = pipeline_config.pyramid_num_stages\n        pyramid_num_inference_steps_list = (\n            pipeline_config.pyramid_num_inference_steps_list\n        )\n        is_distilled = pipeline_config.is_distilled\n        is_amplify_first_chunk = pipeline_config.is_amplify_first_chunk\n        gamma = pipeline_config.gamma\n\n        # Move transformer to GPU if CPU-offloaded\n        if server_args.dit_cpu_offload and not server_args.use_fsdp_inference:\n            if next(self.transformer.parameters()).device.type == \"cpu\":\n                self.transformer.to(get_local_torch_device())\n\n        # Get encoder outputs (prompt_embeds is a list of tensors, one per encoder)\n        prompt_embeds = batch.prompt_embeds\n        if isinstance(prompt_embeds, list):\n            prompt_embeds = prompt_embeds[0]\n        prompt_embeds = prompt_embeds.to(target_dtype)\n        negative_prompt_embeds = batch.negative_prompt_embeds\n        if isinstance(negative_prompt_embeds, list):\n            negative_prompt_embeds = (\n                negative_prompt_embeds[0] if negative_prompt_embeds else None\n            )\n        if negative_prompt_embeds is not None:\n            negative_prompt_embeds = negative_prompt_embeds.to(target_dtype)\n\n        # Scale factors inherited 
from the Wan VAE used by Helios\n        # (AutoencoderKLWan: temporal_compression_ratio=4, spatial_compression_ratio=8)\n        vae_scale_factor_temporal = 4\n        vae_scale_factor_spatial = 8\n\n        # Compute chunking\n        height = batch.height\n        width = batch.width\n        num_frames = batch.num_frames\n        num_channels_latents = self.transformer.in_channels\n\n        window_num_frames = (\n            num_latent_frames_per_chunk - 1\n        ) * vae_scale_factor_temporal + 1\n        num_latent_chunk = max(\n            1, (num_frames + window_num_frames - 1) // window_num_frames\n        )\n        num_history_latent_frames = sum(history_sizes)\n        batch_size = 1  # Helios processes one video at a time\n\n        # Prepare history latents\n        if not keep_first_frame:\n            history_sizes[-1] = history_sizes[-1] + 1\n        history_latents = torch.zeros(\n            batch_size,\n            num_channels_latents,\n            num_history_latent_frames,\n            height // vae_scale_factor_spatial,\n            width // vae_scale_factor_spatial,\n            device=device,\n            dtype=torch.float32,\n        )\n\n        # Build frame indices\n        if keep_first_frame:\n            indices = torch.arange(\n                0, sum([1, *history_sizes, num_latent_frames_per_chunk])\n            )\n            (\n                indices_prefix,\n                indices_latents_history_long,\n                indices_latents_history_mid,\n                indices_latents_history_1x,\n                indices_hidden_states,\n            ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0)\n            indices_latents_history_short = torch.cat(\n                [indices_prefix, indices_latents_history_1x], dim=0\n            )\n        else:\n            indices = torch.arange(\n                0, sum([*history_sizes, num_latent_frames_per_chunk])\n            )\n            (\n                
indices_latents_history_long,\n                indices_latents_history_mid,\n                indices_latents_history_short,\n                indices_hidden_states,\n            ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0)\n\n        indices_hidden_states = indices_hidden_states.unsqueeze(0)\n        indices_latents_history_short = indices_latents_history_short.unsqueeze(0)\n        indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0)\n        indices_latents_history_long = indices_latents_history_long.unsqueeze(0)\n\n        # Set up scheduler\n        patch_size = self.transformer.patch_size\n        image_seq_len = (\n            num_latent_frames_per_chunk\n            * (height // vae_scale_factor_spatial)\n            * (width // vae_scale_factor_spatial)\n            // (patch_size[0] * patch_size[1] * patch_size[2])\n        )\n        # Sigma schedule from near-1.0 (pure noise) to 0.0 (clean); 0.999 avoids singularity\n        sigmas = np.linspace(0.999, 0.0, num_inference_steps + 1)[:-1]\n        mu = calculate_shift(image_seq_len)\n\n        # Chunk loop\n        image_latents = None\n        total_generated_latent_frames = 0\n        chunk_latents_list = []  # Store per-chunk latents for chunk-by-chunk decode\n        global_step_offset = 0  # Track step index across chunks for perf logging\n\n        self.log_info(\n            f\"Starting chunked denoising: {num_latent_chunk} chunks, \"\n            f\"{num_inference_steps} steps each\"\n        )\n\n        for k in range(num_latent_chunk):\n            is_first_chunk = k == 0\n\n            # Extract history\n            if keep_first_frame:\n                (\n                    latents_history_long,\n                    latents_history_mid,\n                    latents_history_1x,\n                ) = history_latents[:, :, -num_history_latent_frames:].split(\n                    history_sizes, dim=2\n                )\n                if image_latents 
is None and is_first_chunk:\n                    latents_prefix = torch.zeros(\n                        (\n                            batch_size,\n                            num_channels_latents,\n                            1,\n                            latents_history_1x.shape[-2],\n                            latents_history_1x.shape[-1],\n                        ),\n                        device=device,\n                        dtype=latents_history_1x.dtype,\n                    )\n                else:\n                    latents_prefix = image_latents\n                latents_history_short = torch.cat(\n                    [latents_prefix, latents_history_1x], dim=2\n                )\n            else:\n                (\n                    latents_history_long,\n                    latents_history_mid,\n                    latents_history_short,\n                ) = history_latents[:, :, -num_history_latent_frames:].split(\n                    history_sizes, dim=2\n                )\n\n            # Generate noise latents for this chunk\n            # Use batch.generator to ensure identical noise across SP ranks\n            latent_shape = (\n                batch_size,\n                num_channels_latents,\n                (window_num_frames - 1) // vae_scale_factor_temporal + 1,\n                height // vae_scale_factor_spatial,\n                width // vae_scale_factor_spatial,\n            )\n            generator = batch.generator\n            if isinstance(generator, list):\n                generator = generator[0] if len(generator) > 0 else None\n            gen_device = generator.device if generator is not None else device\n            latents = torch.randn(\n                latent_shape,\n                generator=generator,\n                device=gen_device,\n                dtype=torch.float32,\n            )\n            if latents.device != device:\n                latents = latents.to(device)\n\n            if is_enable_stage2:\n  
              # Stage 2: Pyramid SR denoising (handles scheduler internally)\n                latents, global_step_offset = self._denoise_one_chunk_stage2(\n                    latents=latents,\n                    prompt_embeds=prompt_embeds,\n                    negative_prompt_embeds=negative_prompt_embeds,\n                    guidance_scale=guidance_scale,\n                    indices_hidden_states=indices_hidden_states,\n                    indices_latents_history_short=indices_latents_history_short,\n                    indices_latents_history_mid=indices_latents_history_mid,\n                    indices_latents_history_long=indices_latents_history_long,\n                    latents_history_short=latents_history_short,\n                    latents_history_mid=latents_history_mid,\n                    latents_history_long=latents_history_long,\n                    target_dtype=target_dtype,\n                    device=device,\n                    pyramid_num_stages=pyramid_num_stages,\n                    pyramid_num_inference_steps_list=pyramid_num_inference_steps_list,\n                    is_distilled=is_distilled,\n                    is_amplify_first_chunk=(is_amplify_first_chunk and is_first_chunk),\n                    gamma=gamma,\n                    is_cfg_zero_star=is_cfg_zero_star,\n                    use_zero_init=True,\n                    zero_steps=zero_steps,\n                    batch=batch,\n                    server_args=server_args,\n                    global_step_offset=global_step_offset,\n                )\n            else:\n                # Stage 1: Standard flat denoising\n                self.scheduler.set_timesteps(\n                    num_inference_steps, device=device, sigmas=sigmas, mu=mu\n                )\n                timesteps = self.scheduler.timesteps\n\n                latents = self._denoise_one_chunk(\n                    latents=latents,\n                    prompt_embeds=prompt_embeds,\n                    
negative_prompt_embeds=negative_prompt_embeds,\n                    timesteps=timesteps,\n                    guidance_scale=guidance_scale,\n                    indices_hidden_states=indices_hidden_states,\n                    indices_latents_history_short=indices_latents_history_short,\n                    indices_latents_history_mid=indices_latents_history_mid,\n                    indices_latents_history_long=indices_latents_history_long,\n                    latents_history_short=latents_history_short,\n                    latents_history_mid=latents_history_mid,\n                    latents_history_long=latents_history_long,\n                    target_dtype=target_dtype,\n                    device=device,\n                    is_cfg_zero_star=is_cfg_zero_star,\n                    use_zero_init=True,\n                    zero_steps=zero_steps,\n                    batch=batch,\n                    server_args=server_args,\n                    global_step_offset=global_step_offset,\n                )\n                global_step_offset += num_inference_steps\n\n            # Extract first frame as image_latents for subsequent chunks\n            if keep_first_frame and is_first_chunk and image_latents is None:\n                image_latents = latents[:, :, 0:1, :, :]\n\n            # Update history\n            total_generated_latent_frames += latents.shape[2]\n            history_latents = torch.cat([history_latents, latents], dim=2)\n            chunk_latents_list.append(latents)\n\n        # Move transformer back to CPU after denoising\n        if server_args.dit_cpu_offload and not server_args.use_fsdp_inference:\n            if next(self.transformer.parameters()).device.type != \"cpu\":\n                self.transformer.to(\"cpu\")\n        torch.cuda.empty_cache()\n\n        # Store per-chunk latents for chunk-by-chunk VAE decode (matches diffusers behavior).\n        # The standard DecodingStage will check for this attribute and decode each chunk\n    
    # separately to avoid temporal artifacts at chunk boundaries.\n        batch.latent_chunks = chunk_latents_list\n        batch.latents = history_latents[:, :, -total_generated_latent_frames:]\n\n        return batch\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nMOVA-specific pipeline stages.\n\nSequence Parallelism (SP) Support:\n- Video latents are sharded along the sequence dimension (T*H*W) after patchify\n- Audio latents are sharded along the sequence dimension (L) after patchify\n- USPAttention handles all-to-all communication internally\n- Latents are gathered before unpatchify to restore full sequence\n\"\"\"\n\nfrom __future__ import annotations\n\nimport functools\nimport inspect\nimport os\nfrom collections.abc import Iterable\n\nimport torch\nimport torch.nn as nn\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom tqdm.auto import tqdm\n\nfrom sglang.multimodal_gen.runtime.distributed import (\n    get_local_torch_device,\n    get_world_group,\n)\nfrom sglang.multimodal_gen.runtime.distributed.communication_op import (\n    cfg_model_parallel_all_reduce,\n    sequence_model_parallel_all_gather,\n)\nfrom sglang.multimodal_gen.runtime.distributed.parallel_state import (\n    get_cfg_group,\n    get_classifier_free_guidance_rank,\n    get_sp_parallel_rank,\n    get_sp_world_size,\n)\nfrom sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context\n\n# Both audio and video DiT use the same sinusoidal_embedding_1d function\n# Import from mova_video_dit where it's defined (mova_audio_dit re-exports it)\nfrom sglang.multimodal_gen.runtime.models.dits.mova_video_dit import (\n    sinusoidal_embedding_1d,\n)\n\n# Create aliases for backward compatibility\nvideo_sinusoidal_embedding_1d = sinusoidal_embedding_1d\naudio_sinusoidal_embedding_1d = sinusoidal_embedding_1d\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import (\n    PipelineStage,\n    StageParallelismType,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.decoding import (\n    _ensure_tensor_decode_output,\n)\nfrom 
sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    StageValidators as V,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args\nfrom sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler\nfrom sglang.multimodal_gen.runtime.utils.profiler import SGLDiffusionProfiler\nfrom sglang.multimodal_gen.utils import PRECISION_TO_TYPE\nfrom sglang.srt.utils.common import get_compiler_backend\n\nlogger = init_logger(__name__)\n\n\nclass MOVALatentPreparationStage(PipelineStage):\n    \"\"\"Prepare video/audio noise latents for MOVA.\"\"\"\n\n    def __init__(self, audio_vae, require_vae_embedding: bool = True) -> None:\n        super().__init__()\n        self.audio_vae = audio_vae\n        self.require_vae_embedding = require_vae_embedding\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        batch_size = batch.batch_size\n        num_frames = batch.num_frames\n        if num_frames is None:\n            raise ValueError(\"num_frames is required for MOVA\")\n\n        audio_num_samples = int(self.audio_vae.sample_rate * num_frames / batch.fps)\n\n        video_shape = server_args.pipeline_config.prepare_latent_shape(\n            batch, batch_size, num_frames\n        )\n        audio_shape = server_args.pipeline_config.prepare_audio_latent_shape(\n            batch_size, audio_num_samples, self.audio_vae\n        )\n\n        device = get_local_torch_device()\n        generator = batch.generator\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list 
of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        dit_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision]\n        batch.latents = randn_tensor(\n            video_shape, generator=generator, device=device, dtype=dit_dtype\n        )\n        batch.audio_latents = randn_tensor(\n            audio_shape, generator=generator, device=device, dtype=dit_dtype\n        )\n\n        if batch.image_latent is not None:\n            batch.y = batch.image_latent.to(device=device, dtype=dit_dtype)\n        elif self.require_vae_embedding:\n            raise ValueError(\"MOVA requires reference image latents for denoising\")\n        return batch\n\n\nclass MOVATimestepPreparationStage(PipelineStage):\n    \"\"\"Prepare paired timesteps for MOVA.\"\"\"\n\n    def __init__(self, scheduler) -> None:\n        super().__init__()\n        self.scheduler = scheduler\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        self.scheduler.set_timesteps(\n            batch.num_inference_steps,\n            denoising_strength=1.0,\n            shift=getattr(batch, \"sigma_shift\", self.scheduler.shift),\n        )\n        self.scheduler.set_pair_postprocess_by_name(\n            \"dual_sigma_shift\",\n            visual_shift=getattr(batch, \"visual_shift\", 5.0),\n            audio_shift=getattr(batch, \"audio_shift\", 5.0),\n        )\n        paired = self.scheduler.get_pairs()\n        batch.paired_timesteps = paired\n        batch.timesteps = paired\n        return batch\n\n\nclass MOVADenoisingStage(PipelineStage):\n    \"\"\"Run MOVA dual-tower denoising loop.\"\"\"\n\n    def __init__(self, video_dit, video_dit_2, audio_dit, dual_tower_bridge, scheduler):\n        super().__init__()\n        self.video_dit = video_dit\n        self.video_dit_2 = video_dit_2\n        self.audio_dit = 
audio_dit\n        self.dual_tower_bridge = dual_tower_bridge\n        self.scheduler = scheduler\n        self._cache_dit_enabled = False\n        self._cached_num_steps = None\n        self._torch_compiled = False\n\n    @property\n    def parallelism_type(self) -> StageParallelismType:\n        if get_global_server_args().enable_cfg_parallel:\n            return StageParallelismType.CFG_PARALLEL\n        return StageParallelismType.REPLICATED\n\n    def _predict(\n        self,\n        visual_dit,\n        visual_latents,\n        audio_latents,\n        y,\n        context,\n        timestep,\n        audio_timestep,\n        video_fps,\n        timestep_index: int,\n        attn_metadata,\n        forward_batch: Req | None = None,\n    ):\n        # Set forward context for distributed attention (USPAttention)\n        with set_forward_context(\n            current_timestep=timestep_index,\n            attn_metadata=attn_metadata,\n            forward_batch=forward_batch,\n        ):\n            return self.inference_single_step(\n                visual_dit=visual_dit,\n                visual_latents=visual_latents,\n                audio_latents=audio_latents,\n                y=y,\n                context=context,\n                timestep=timestep,\n                audio_timestep=audio_timestep,\n                video_fps=video_fps,\n            )\n\n    def _cfg_combine(self, pos, neg, guidance_scale, cfg_rank, enable_cfg_parallel):\n        if not enable_cfg_parallel:\n            return neg + guidance_scale * (pos - neg)\n        if cfg_rank == 0:\n            partial = guidance_scale * pos\n        else:\n            partial = (1 - guidance_scale) * neg\n        return cfg_model_parallel_all_reduce(partial)\n\n    def _maybe_enable_torch_compile(self, module: nn.Module, server_args: ServerArgs):\n        \"\"\"\n        Compile a module with torch.compile, and enable inductor overlap tweak if available.\n        No-op if torch compile is disabled or 
the object is not a nn.Module.\n        \"\"\"\n        if not server_args.enable_torch_compile or not isinstance(module, nn.Module):\n            return\n        compile_kwargs: dict[str, object] = {\"fullgraph\": False, \"dynamic\": None}\n\n        if current_platform.is_npu():\n            backend = get_compiler_backend()\n            compile_kwargs[\"backend\"] = backend\n            compile_kwargs[\"dynamic\"] = False\n            logger.info(\n                \"Compiling %s with torchair backend on NPU\",\n                module.__class__.__name__,\n            )\n        else:\n            try:\n                import torch._inductor.config as _inductor_cfg\n\n                _inductor_cfg.reorder_for_compute_comm_overlap = True\n            except ImportError:\n                pass\n            mode = os.environ.get(\n                \"SGLANG_TORCH_COMPILE_MODE\", \"max-autotune-no-cudagraphs\"\n            )\n            compile_kwargs[\"mode\"] = mode\n            logger.info(\"Compiling %s with mode: %s\", module.__class__.__name__, mode)\n\n        # TODO(triple-mu): support customized fullgraph and dynamic in the future\n        module.compile(**compile_kwargs)\n\n    def _maybe_compile_dits(self, server_args: ServerArgs):\n        if self._torch_compiled or not server_args.enable_torch_compile:\n            return\n        for module in filter(None, [self.video_dit, self.video_dit_2, self.audio_dit]):\n            self._maybe_enable_torch_compile(module, server_args)\n        self._torch_compiled = True\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify denoising stage inputs.\"\"\"\n        result = VerificationResult()\n        result.add_check(\"y\", batch.y, V.is_tensor)\n        result.add_check(\"paired_timesteps\", batch.paired_timesteps, V.is_tensor)\n        result.add_check(\"latents\", batch.latents, V.is_tensor)\n        result.add_check(\"audio_latents\", batch.audio_latents, 
V.is_tensor)\n        result.add_check(\"prompt_embeds\", batch.prompt_embeds, V.list_not_empty)\n        result.add_check(\n            \"negative_prompt_embeds\",\n            batch.negative_prompt_embeds,\n            lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x),\n        )\n        result.add_check(\n            \"num_inference_steps\", batch.num_inference_steps, V.positive_int\n        )\n        result.add_check(\"guidance_scale\", batch.guidance_scale, V.non_negative_float)\n        result.add_check(\n            \"guidance_rescale\", batch.guidance_rescale, V.non_negative_float\n        )\n        result.add_check(\n            \"do_classifier_free_guidance\",\n            batch.do_classifier_free_guidance,\n            V.bool_value,\n        )\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify denoising stage outputs.\"\"\"\n        result = VerificationResult()\n        result.add_check(\"latents\", batch.latents, V.is_tensor)\n        result.add_check(\"audio_latents\", batch.audio_latents, V.is_tensor)\n        return result\n\n    def progress_bar(\n        self, iterable: Iterable | None = None, total: int | None = None\n    ) -> tqdm:\n        \"\"\"\n        Create a progress bar for the denoising process.\n        \"\"\"\n        local_rank = get_world_group().local_rank\n        disable = local_rank != 0\n        return tqdm(iterable=iterable, total=total, disable=disable)\n\n    def step_profile(self):\n        profiler = SGLDiffusionProfiler.get_instance()\n        if profiler:\n            profiler.step_denoising_step()\n\n    def rescale_noise_cfg(\n        self, noise_cfg, noise_pred_text, guidance_rescale=0.0\n    ) -> torch.Tensor:\n        \"\"\"\n        Rescale noise prediction according to guidance_rescale.\n\n        Based on findings of \"Common Diffusion Noise Schedules and Sample Steps are Flawed\"\n        
(https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.\n        \"\"\"\n        std_text = noise_pred_text.std(\n            dim=list(range(1, noise_pred_text.ndim)), keepdim=True\n        )\n        std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n        noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n        noise_cfg = (\n            guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n        )\n        return noise_cfg\n\n    def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, object]:\n        if not kwargs:\n            return {}\n\n        if isinstance(func, functools.partial) and func.args:\n            func = getattr(func.args[0], \"_original_forward\", func)\n\n        target_func = inspect.unwrap(func)\n        params = inspect.signature(target_func).parameters\n        return {k: v for k, v in kwargs.items() if k in params}\n\n    def _build_attn_metadata(\n        self, i: int, batch: Req, server_args: ServerArgs\n    ) -> object | None:\n        return None\n\n    def _manage_device_placement(\n        self,\n        model_to_use: nn.Module | None,\n        model_to_offload: nn.Module | None,\n        server_args: ServerArgs,\n    ):\n        if not server_args.dit_cpu_offload:\n            return\n\n        if (\n            model_to_offload is not None\n            and next(model_to_offload.parameters()).device.type == \"cuda\"\n        ):\n            model_to_offload.to(\"cpu\")\n\n        if (\n            model_to_use is not None\n            and next(model_to_use.parameters()).device.type == \"cpu\"\n        ):\n            model_to_use.to(get_local_torch_device())\n\n    def _select_visual_dit(\n        self, timestep: float, boundary_ratio: float | None, server_args: ServerArgs\n    ):\n        if boundary_ratio is None or self.video_dit_2 is None:\n            self._manage_device_placement(self.video_dit, None, server_args)\n            return self.video_dit\n\n        
boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps\n        if timestep >= boundary_timestep:\n            current_model = self.video_dit\n            model_to_offload = self.video_dit_2\n        else:\n            current_model = self.video_dit_2\n            model_to_offload = self.video_dit\n\n        self._manage_device_placement(current_model, model_to_offload, server_args)\n        return current_model\n\n    def _ensure_shared_models_on_device(self, server_args: ServerArgs):\n        \"\"\"Ensure shared denoising modules are on the active device when cpu offload is enabled.\"\"\"\n        self._manage_device_placement(self.audio_dit, None, server_args)\n        self._manage_device_placement(self.dual_tower_bridge, None, server_args)\n\n    def _apply_guidance_rescale(\n        self,\n        noise_pred,\n        noise_pred_text,\n        guidance_rescale,\n        cfg_rank,\n        enable_cfg_parallel,\n    ):\n        if guidance_rescale <= 0.0:\n            return noise_pred\n        if enable_cfg_parallel:\n            std_cfg = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)\n            if cfg_rank == 0:\n                assert noise_pred_text is not None\n                std_text = noise_pred_text.std(\n                    dim=list(range(1, noise_pred_text.ndim)), keepdim=True\n                )\n            else:\n                std_text = torch.empty_like(std_cfg)\n            std_text = get_cfg_group().broadcast(std_text, src=0)\n            noise_pred_rescaled = noise_pred * (std_text / std_cfg)\n            return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * (\n                noise_pred\n            )\n        return self.rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale)\n\n    @torch.no_grad()\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        self._maybe_compile_dits(server_args)\n        self._ensure_shared_models_on_device(server_args)\n\n   
     paired_timesteps = batch.paired_timesteps\n        if paired_timesteps is None:\n            raise ValueError(\"paired_timesteps must be set for MOVA\")\n\n        y = batch.y if batch.y is not None else batch.image_latent\n        if getattr(self.video_dit, \"require_vae_embedding\", False) and y is None:\n            raise ValueError(\"MOVA requires reference image latents for denoising\")\n\n        boundary_ratio = server_args.pipeline_config.boundary_ratio\n        total_steps = paired_timesteps.shape[0]\n        cfg_rank = get_classifier_free_guidance_rank()\n        enable_cfg_parallel = server_args.enable_cfg_parallel\n\n        is_warmup = batch.is_warmup\n        extra_step_kwargs = self.prepare_extra_func_kwargs(\n            self.scheduler.step_from_to,\n            getattr(batch, \"extra_step_kwargs\", None) or {},\n        )\n\n        metrics = getattr(batch, \"metrics\", None)\n        perf_dump_path_provided = getattr(batch, \"perf_dump_path\", None) is not None\n\n        with self.progress_bar(total=total_steps) as progress_bar:\n            for idx_step in range(total_steps):\n                with StageProfiler(\n                    f\"denoising_step_{idx_step}\",\n                    logger=logger,\n                    metrics=metrics,\n                    perf_dump_path_provided=perf_dump_path_provided,\n                ):\n                    pair_t = paired_timesteps[idx_step]\n                    if getattr(pair_t, \"shape\", None) == (2,):\n                        timestep, audio_timestep = pair_t\n                    else:\n                        timestep = pair_t\n                        audio_timestep = pair_t\n\n                    cur_visual_dit = self._select_visual_dit(\n                        timestep.item(), boundary_ratio, server_args\n                    )\n\n                    timestep = timestep.unsqueeze(0).to(device=get_local_torch_device())\n                    audio_timestep = audio_timestep.unsqueeze(0).to(\n      
                  device=get_local_torch_device()\n                    )\n\n                    attn_metadata = self._build_attn_metadata(\n                        idx_step, batch, server_args\n                    )\n\n                    if not batch.do_classifier_free_guidance:\n                        visual_noise_pred, audio_noise_pred = self._predict(\n                            cur_visual_dit,\n                            batch.latents,\n                            batch.audio_latents,\n                            y,\n                            batch.prompt_embeds[0],\n                            timestep,\n                            audio_timestep,\n                            batch.fps,\n                            idx_step,\n                            attn_metadata,\n                            batch,\n                        )\n                    else:\n                        if enable_cfg_parallel:\n                            if cfg_rank == 0:\n                                pos = self._predict(\n                                    cur_visual_dit,\n                                    batch.latents,\n                                    batch.audio_latents,\n                                    y,\n                                    batch.prompt_embeds[0],\n                                    timestep,\n                                    audio_timestep,\n                                    batch.fps,\n                                    idx_step,\n                                    attn_metadata,\n                                    batch,\n                                )\n                                neg = (None, None)\n                            else:\n                                pos = (None, None)\n                                neg = self._predict(\n                                    cur_visual_dit,\n                                    batch.latents,\n                                    batch.audio_latents,\n                       
             y,\n                                    batch.negative_prompt_embeds[0],\n                                    timestep,\n                                    audio_timestep,\n                                    batch.fps,\n                                    idx_step,\n                                    attn_metadata,\n                                    batch,\n                                )\n                        else:\n                            pos = self._predict(\n                                cur_visual_dit,\n                                batch.latents,\n                                batch.audio_latents,\n                                y,\n                                batch.prompt_embeds[0],\n                                timestep,\n                                audio_timestep,\n                                batch.fps,\n                                idx_step,\n                                attn_metadata,\n                                batch,\n                            )\n                            neg = self._predict(\n                                cur_visual_dit,\n                                batch.latents,\n                                batch.audio_latents,\n                                y,\n                                batch.negative_prompt_embeds[0],\n                                timestep,\n                                audio_timestep,\n                                batch.fps,\n                                idx_step,\n                                attn_metadata,\n                                batch,\n                            )\n\n                        visual_noise_pred = self._cfg_combine(\n                            pos[0] if pos[0] is not None else neg[0],\n                            neg[0] if neg[0] is not None else pos[0],\n                            batch.guidance_scale,\n                            cfg_rank,
\n                            enable_cfg_parallel,\n                        )\n                        audio_noise_pred = self._cfg_combine(\n                            pos[1] if pos[1] is not None else neg[1],\n                            neg[1] if neg[1] is not None else pos[1],\n                            batch.guidance_scale,\n                            cfg_rank,\n                            enable_cfg_parallel,\n                        )\n\n                        if batch.guidance_rescale > 0.0:\n                            visual_noise_pred = self._apply_guidance_rescale(\n                                visual_noise_pred,\n                                pos[0] if pos[0] is not None else None,\n                                batch.guidance_rescale,\n                                cfg_rank,\n                                enable_cfg_parallel,\n                            )\n                            audio_noise_pred = self._apply_guidance_rescale(\n                                audio_noise_pred,\n                                pos[1] if pos[1] is not None else None,\n                                batch.guidance_rescale,\n                                cfg_rank,\n                                enable_cfg_parallel,\n                            )\n\n                    if idx_step + 1 < total_steps:\n                        next_pair_t = paired_timesteps[idx_step + 1]\n                        if getattr(next_pair_t, \"shape\", None) == (2,):\n                            next_timestep, next_audio_timestep = next_pair_t\n                        else:\n                            next_timestep = next_pair_t\n                            next_audio_timestep = next_pair_t\n                    else:\n                        next_timestep = None\n                        next_audio_timestep = None\n\n                    batch.latents = self.scheduler.step_from_to(
\n                        visual_noise_pred,\n                        timestep,\n                        next_timestep,\n                        batch.latents,\n                        **extra_step_kwargs,\n                    )\n                    batch.audio_latents = self.scheduler.step_from_to(\n                        audio_noise_pred,\n                        audio_timestep,\n                        next_audio_timestep,\n                        batch.audio_latents,\n                        **extra_step_kwargs,\n                    )\n\n                    if progress_bar is not None:\n                        progress_bar.update()\n                    if not is_warmup and hasattr(self, \"step_profile\"):\n                        self.step_profile()\n\n        for dit in filter(None, [self.video_dit, self.video_dit_2, self.audio_dit]):\n            if isinstance(dit, OffloadableDiTMixin):\n                dit.prepare_for_next_req()\n\n        return batch\n\n    def _shard_sequence_for_sp(\n        self, x: torch.Tensor, dim: int = 1\n    ) -> tuple[torch.Tensor, int]:\n        \"\"\"\n        Shard tensor along sequence dimension for Sequence Parallelism.\n\n        Args:\n            x: Input tensor\n            dim: Dimension to shard along\n\n        Returns:\n            (sharded_tensor, pad_len)\n        \"\"\"\n        sp_size = get_sp_world_size()\n        if sp_size <= 1:\n            return x, 0\n\n        sp_rank = get_sp_parallel_rank()\n        seq_len = x.shape[dim]\n\n        # Pad if needed\n        pad_len = (sp_size - (seq_len % sp_size)) % sp_size\n        if pad_len > 0:\n            pad_shape = list(x.shape)\n            pad_shape[dim] = pad_len\n            pad = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)\n            x = torch.cat([x, pad], dim=dim)\n\n        # Shard\n        chunk_size = x.shape[dim] // sp_size\n        start = sp_rank * 
chunk_size\n        end = start + chunk_size\n        idx = [slice(None)] * x.dim()\n        idx[dim] = slice(start, end)\n        return x[tuple(idx)], pad_len\n\n    def _gather_sequence_from_sp(\n        self, x: torch.Tensor, pad_len: int, dim: int = 1\n    ) -> torch.Tensor:\n        \"\"\"\n        Gather tensor along sequence dimension after Sequence Parallelism.\n\n        Args:\n            x: Sharded tensor\n            pad_len: Padding length that was added during sharding\n            dim: Dimension to gather along\n\n        Returns:\n            Gathered tensor with padding removed\n        \"\"\"\n        sp_size = get_sp_world_size()\n        if sp_size <= 1:\n            return x\n\n        gathered = sequence_model_parallel_all_gather(x, dim=dim)\n        if pad_len > 0:\n            idx = [slice(None)] * gathered.dim()\n            idx[dim] = slice(0, gathered.shape[dim] - pad_len)\n            gathered = gathered[tuple(idx)]\n        return gathered\n\n    def inference_single_step(\n        self,\n        visual_dit,\n        visual_latents: torch.Tensor,\n        audio_latents: torch.Tensor,\n        y,\n        context: torch.Tensor,\n        timestep: torch.Tensor,\n        audio_timestep: torch.Tensor,\n        video_fps: float,\n    ):\n        \"\"\"\n        Single inference step for MOVA dual-tower denoising.\n\n        Supports Sequence Parallelism (SP):\n        - After patchify, sequences are sharded across SP ranks\n        - USPAttention handles distributed attention communication\n        - Before unpatchify, sequences are gathered back\n        \"\"\"\n        model_dtype = visual_dit.time_embedding.fc_in.weight.dtype\n        device = visual_latents.device\n\n        visual_context = context.to(device=device, dtype=model_dtype)\n        audio_context = context.to(device=device, dtype=model_dtype)\n        with torch.autocast(\n            device_type=current_platform.device_type, dtype=torch.float32\n        ):\n            
visual_t = visual_dit.time_embedding(\n                video_sinusoidal_embedding_1d(visual_dit.freq_dim, timestep)\n            )\n            visual_t_mod, _ = visual_dit.time_projection(visual_t)\n            visual_t_mod = visual_t_mod.unflatten(1, (6, visual_dit.dim))\n\n            audio_t = self.audio_dit.time_embedding(\n                audio_sinusoidal_embedding_1d(self.audio_dit.freq_dim, audio_timestep)\n            )\n            audio_t_mod, _ = self.audio_dit.time_projection(audio_t)\n            audio_t_mod = audio_t_mod.unflatten(1, (6, self.audio_dit.dim))\n\n        visual_t = visual_t.to(model_dtype)\n        visual_t_mod = visual_t_mod.to(model_dtype)\n        audio_t = audio_t.to(model_dtype)\n        audio_t_mod = audio_t_mod.to(model_dtype)\n\n        visual_context_emb = visual_dit.text_embedding(visual_context)\n        audio_context_emb = self.audio_dit.text_embedding(audio_context)\n\n        visual_x = visual_latents.to(model_dtype)\n        audio_x = audio_latents.to(model_dtype)\n\n        if getattr(visual_dit, \"require_vae_embedding\", False):\n            visual_x = torch.cat([visual_x, y], dim=1)\n\n        # Patchify visual latents\n        visual_x, (t, h, w) = visual_dit.patchify(visual_x)\n        grid_size = (t, h, w)\n        full_visual_seq_len = t * h * w\n\n        # Build visual freqs for full sequence\n        visual_dit._init_freqs()\n        visual_freqs = tuple(freq.to(visual_x.device) for freq in visual_dit.freqs)\n        visual_freqs = (\n            torch.cat(\n                [\n                    visual_freqs[0][:t].view(t, 1, 1, -1).expand(t, h, w, -1),\n                    visual_freqs[1][:h].view(1, h, 1, -1).expand(t, h, w, -1),\n                    visual_freqs[2][:w].view(1, 1, w, -1).expand(t, h, w, -1),\n                ],\n                dim=-1,\n            )\n            .reshape(full_visual_seq_len, 1, -1)\n            .to(visual_x.device)\n        )\n\n        # Patchify audio latents\n        
audio_x, (f,) = self.audio_dit.patchify(audio_x, None)\n        full_audio_seq_len = f\n\n        # Build audio freqs for full sequence\n        self.audio_dit._init_freqs()\n        audio_freqs = (\n            torch.cat(\n                [\n                    self.audio_dit.freqs[0][:f].view(f, -1).expand(f, -1),\n                    self.audio_dit.freqs[1][:f].view(f, -1).expand(f, -1),\n                    self.audio_dit.freqs[2][:f].view(f, -1).expand(f, -1),\n                ],\n                dim=-1,\n            )\n            .reshape(full_audio_seq_len, 1, -1)\n            .to(audio_x.device)\n        )\n\n        # Shard sequences for SP\n        visual_x, visual_pad_len = self._shard_sequence_for_sp(visual_x, dim=1)\n        audio_x, audio_pad_len = self._shard_sequence_for_sp(audio_x, dim=1)\n\n        # Shard freqs to match local sequence length\n        visual_freqs, _ = self._shard_sequence_for_sp(visual_freqs, dim=0)\n        audio_freqs, _ = self._shard_sequence_for_sp(audio_freqs, dim=0)\n\n        # Forward through dual-tower DiT\n        visual_x, audio_x = self.forward_dual_tower_dit(\n            visual_dit=visual_dit,\n            visual_x=visual_x,\n            audio_x=audio_x,\n            visual_context=visual_context_emb,\n            audio_context=audio_context_emb,\n            visual_t_mod=visual_t_mod,\n            audio_t_mod=audio_t_mod,\n            visual_freqs=visual_freqs,\n            audio_freqs=audio_freqs,\n            grid_size=grid_size,\n            video_fps=video_fps,\n            full_visual_seq_len=full_visual_seq_len,\n            full_audio_seq_len=full_audio_seq_len,\n        )\n\n        # Gather sequences back from SP before head/unpatchify\n        visual_x = self._gather_sequence_from_sp(visual_x, visual_pad_len, dim=1)\n        audio_x = self._gather_sequence_from_sp(audio_x, audio_pad_len, dim=1)\n\n        visual_output = visual_dit.head(visual_x, visual_t)\n        visual_output = 
visual_dit.unpatchify(visual_output, grid_size)\n\n        audio_output = self.audio_dit.head(audio_x, audio_t)\n        audio_output = self.audio_dit.unpatchify(audio_output, (f,))\n\n        return visual_output.float(), audio_output.float()\n\n    def forward_dual_tower_dit(\n        self,\n        visual_dit,\n        visual_x: torch.Tensor,\n        audio_x: torch.Tensor,\n        visual_context: torch.Tensor,\n        audio_context: torch.Tensor,\n        visual_t_mod: torch.Tensor,\n        audio_t_mod: torch.Tensor,\n        visual_freqs: torch.Tensor,\n        audio_freqs: torch.Tensor,\n        grid_size: tuple[int, int, int],\n        video_fps: float,\n        full_visual_seq_len: int,\n        full_audio_seq_len: int,\n        condition_scale: float | None = 1.0,\n        a2v_condition_scale: float | None = None,\n        v2a_condition_scale: float | None = None,\n    ):\n        \"\"\"\n        Forward pass through dual-tower DiT with cross-modal interaction.\n\n        Sequence Parallelism (SP) Support:\n        - visual_x and audio_x are already sharded along sequence dimension\n        - visual_freqs and audio_freqs match the local sequence length\n        - USPAttention in self-attention handles distributed communication\n        - LocalAttention in cross-attention operates on local sequence vs replicated context\n        - Cross-modal attention (dual_tower_bridge) uses LocalAttention (no SP communication)\n\n        Args:\n            full_visual_seq_len: Full visual sequence length before SP sharding\n            full_audio_seq_len: Full audio sequence length before SP sharding\n        \"\"\"\n        min_layers = min(len(visual_dit.blocks), len(self.audio_dit.blocks))\n        visual_layers = len(visual_dit.blocks)\n        sp_size = get_sp_world_size()\n\n        # Build RoPE frequencies for cross-attention if needed (only used when SP == 1)\n        # When SP > 1, we rebuild freqs inside the loop after gathering full sequences\n        
visual_rope_cos_sin, audio_rope_cos_sin = (\n            self.dual_tower_bridge.build_aligned_freqs(\n                video_fps=video_fps,\n                grid_size=grid_size,\n                audio_steps=full_audio_seq_len,\n                device=visual_x.device,\n                dtype=visual_x.dtype,\n            )\n        )\n        if visual_rope_cos_sin is not None:\n            visual_rope_cos_sin = [\n                self._shard_sequence_for_sp(rope_cos_sin, dim=1)[0]\n                for rope_cos_sin in visual_rope_cos_sin\n            ]\n        if audio_rope_cos_sin is not None:\n            audio_rope_cos_sin = [\n                self._shard_sequence_for_sp(rope_cos_sin, dim=1)[0]\n                for rope_cos_sin in audio_rope_cos_sin\n            ]\n\n        for layer_idx in range(min_layers):\n            visual_block = visual_dit.blocks[layer_idx]\n            audio_block = self.audio_dit.blocks[layer_idx]\n\n            # Cross-modal interaction via dual tower bridge\n            # Bridge operations (PerFrameAttentionPooling, RoPE) expect full sequences\n            # When SP is enabled, we need to gather before bridge and shard after\n            if self.dual_tower_bridge.should_interact(layer_idx, \"a2v\"):\n                visual_x, audio_x = self.dual_tower_bridge(\n                    layer_idx,\n                    visual_x,\n                    audio_x,\n                    x_freqs=visual_rope_cos_sin,\n                    y_freqs=audio_rope_cos_sin,\n                    a2v_condition_scale=a2v_condition_scale,\n                    v2a_condition_scale=v2a_condition_scale,\n                    condition_scale=condition_scale,\n                    video_grid_size=grid_size,\n                )\n\n            # Self-attention and FFN in DiT blocks\n            visual_x = visual_block(\n                visual_x, visual_context, visual_t_mod, visual_freqs\n            )\n            audio_x = audio_block(audio_x, audio_context, audio_t_mod, 
audio_freqs)\n\n        # Process remaining visual layers (if visual has more layers than audio)\n        for layer_idx in range(min_layers, visual_layers):\n            visual_block = visual_dit.blocks[layer_idx]\n            visual_x = visual_block(\n                visual_x, visual_context, visual_t_mod, visual_freqs\n            )\n\n        return visual_x, audio_x\n\n\nclass MOVADecodingStage(PipelineStage):\n    \"\"\"Decode video and audio outputs for MOVA.\"\"\"\n\n    def __init__(self, video_vae, audio_vae) -> None:\n        super().__init__()\n        self.video_vae = video_vae\n        self.audio_vae = audio_vae\n\n    @property\n    def parallelism_type(self) -> StageParallelismType:\n        if get_global_server_args().enable_cfg_parallel:\n            return StageParallelismType.MAIN_RANK_ONLY\n        return StageParallelismType.REPLICATED\n\n    @torch.no_grad()\n    def forward(self, batch: Req, server_args: ServerArgs) -> OutputBatch:\n        self.video_vae = self.video_vae.to(get_local_torch_device())\n        self.audio_vae = self.audio_vae.to(get_local_torch_device())\n\n        video_latents = server_args.pipeline_config.denormalize_video_latents(\n            batch.latents, self.video_vae\n        )\n\n        vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]\n        vae_autocast_enabled = (\n            vae_dtype != torch.float32\n        ) and not server_args.disable_autocast\n\n        with torch.autocast(\n            device_type=current_platform.device_type,\n            dtype=vae_dtype,\n            enabled=vae_autocast_enabled,\n        ):\n            if server_args.pipeline_config.vae_tiling:\n                self.video_vae.enable_tiling()\n            if not vae_autocast_enabled:\n                video_latents = video_latents.to(vae_dtype)\n            decode_output = self.video_vae.decode(video_latents)\n            video = _ensure_tensor_decode_output(decode_output)\n\n        video = (video / 2 + 
0.5).clamp(0, 1)\n\n        with torch.autocast(\n            device_type=current_platform.device_type, dtype=torch.float32\n        ):\n            audio = self.audio_vae.decode(batch.audio_latents)\n        output_batch = OutputBatch(\n            output=video,\n            audio=audio,\n            audio_sample_rate=getattr(self.audio_vae, \"sample_rate\", None),\n            metrics=batch.metrics,\n        )\n        return output_batch\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py",
    "content": "import inspect\nimport math\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.utils.torch_utils import randn_tensor\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context\nfrom sglang.multimodal_gen.runtime.models.vision_utils import load_image\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus.calculate_dimensions\ndef calculate_dimensions(target_area, ratio):\n    width = math.sqrt(target_area * ratio)\n    height = width / ratio\n\n    width = round(width / 32) * 32\n    height = round(height / 32) * 32\n\n    return width, height\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor,\n    generator: Optional[torch.Generator] = None,\n    sample_mode: str = \"sample\",\n):\n    if sample_mode == \"sample\":\n        return encoder_output.sample(generator)\n    elif sample_mode == \"argmax\":\n        return encoder_output.mode()\n    else:\n        return encoder_output\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` 
method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\n            \"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\"\n        )\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(\n            inspect.signature(scheduler.set_timesteps).parameters.keys()\n        )\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(\n            inspect.signature(scheduler.set_timesteps).parameters.keys()\n        )\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass QwenImageLayeredBeforeDenoisingStage(PipelineStage):\n    def __init__(\n        self, vae, tokenizer, processor, transformer, scheduler, model_path\n    ) -> None:\n        super().__init__()\n        self.vae = vae.to(torch.bfloat16)\n        from transformers import Qwen2_5_VLForConditionalGeneration\n\n        self.text_encoder = (\n            Qwen2_5_VLForConditionalGeneration.from_pretrained(\n                model_path, subfolder=\"text_encoder\"\n            )\n            .to(get_local_torch_device())\n            .to(torch.bfloat16)\n        )\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.transformer = transformer\n        self.scheduler = scheduler\n\n        self.vae_scale_factor = (\n            2 ** len(self.vae.temperal_downsample) if getattr(self, \"vae\", None) else 8\n        )\n        self.image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor * 2\n        
)\n        self.vl_processor = processor\n        self.tokenizer_max_length = 1024\n        self.latent_channels = self.vae.z_dim if getattr(self, \"vae\", None) else 16\n\n        self.prompt_template_encode = \"<|im_start|>system\\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\\n<|im_start|>user\\n{}<|im_end|>\\n<|im_start|>assistant\\n\"\n        self.prompt_template_encode_start_idx = 34\n        self.image_caption_prompt_cn = \"\"\"<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n<|im_start|>user\\n# 图像标注器\\n你是一个专业的图像标注器。请基于输入图像，撰写图注:\\n1.\n使用自然、描述性的语言撰写图注，不要使用结构化形式或富文本形式。\\n2. 通过加入以下内容，丰富图注细节：\\n - 对象的属性：如数量、颜色、形状、大小、位置、材质、状态、动作等\\n -\n对象间的视觉关系：如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\\n - 环境细节：例如天气、光照、颜色、纹理、气氛等\\n - 文字内容：识别图像中清晰可见的文字，不做翻译和解释，用引号在图注中强调\\n3.\n保持真实性与准确性：\\n - 不要使用笼统的描述\\n -\n描述图像中所有可见的信息，但不要加入没有在图像中出现的内容\\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\\n<|im_start|>assistant\\n\"\"\"\n        self.image_caption_prompt_en = \"\"\"<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n<|im_start|>user\\n# Image Annotator\\nYou are a professional\nimage annotator. Please write an image caption based on the input image:\\n1. Write the caption using natural,\ndescriptive language without structured formats or rich text.\\n2. Enrich caption details by including: \\n - Object\nattributes, such as quantity, color, shape, size, material, state, position, actions, and so on\\n - Vision Relations\nbetween objects, such as spatial relations, functional relations, possessive relations, attachment relations, action\nrelations, comparative relations, causal relations, and so on\\n - Environmental details, such as weather, lighting,\ncolors, textures, atmosphere, and so on\\n - Identify the text clearly visible in the image, without translation or\nexplanation, and highlight it in the caption with quotation marks\\n3. 
Maintain authenticity and accuracy:\\n - Avoid\ngeneralizations\\n - Describe all visible information in the image, while do not add information not explicitly shown in\nthe image\\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\\n<|im_start|>assistant\\n\"\"\"\n        self.default_sample_size = 128\n\n    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden\n    def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):\n        bool_mask = mask.bool()\n        valid_lengths = bool_mask.sum(dim=1)\n        selected = hidden_states[bool_mask]\n        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)\n\n        return split_result\n\n    def get_image_caption(self, prompt_image, use_en_prompt=True, device=None):\n        if use_en_prompt:\n            prompt = self.image_caption_prompt_en\n        else:\n            prompt = self.image_caption_prompt_cn\n        model_inputs = self.vl_processor(\n            text=prompt,\n            images=prompt_image,\n            padding=True,\n            return_tensors=\"pt\",\n        ).to(device)\n        with set_forward_context(current_timestep=0, attn_metadata=None):\n            generated_ids = self.text_encoder.generate(\n                **model_inputs, max_new_tokens=512\n            )\n            generated_ids_trimmed = [\n                out_ids[len(in_ids) :]\n                for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)\n            ]\n            output_text = self.vl_processor.batch_decode(\n                generated_ids_trimmed,\n                skip_special_tokens=True,\n                clean_up_tokenization_spaces=False,\n            )[0]\n            return output_text.strip()\n\n    def _get_qwen_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        dtype 
= dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        template = self.prompt_template_encode\n        drop_idx = self.prompt_template_encode_start_idx\n        txt = [template.format(e) for e in prompt]\n        txt_tokens = self.tokenizer(\n            txt,\n            padding=True,\n            return_tensors=\"pt\",\n        ).to(device)\n        encoder_hidden_states = self.text_encoder(\n            input_ids=txt_tokens.input_ids,\n            attention_mask=txt_tokens.attention_mask,\n            output_hidden_states=True,\n        )\n        hidden_states = encoder_hidden_states.hidden_states[-1]\n        split_hidden_states = self._extract_masked_hidden(\n            hidden_states, txt_tokens.attention_mask\n        )\n        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]\n        attn_mask_list = [\n            torch.ones(e.size(0), dtype=torch.long, device=e.device)\n            for e in split_hidden_states\n        ]\n        max_seq_len = max([e.size(0) for e in split_hidden_states])\n        prompt_embeds = torch.stack(\n            [\n                torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))])\n                for u in split_hidden_states\n            ]\n        )\n        encoder_attention_mask = torch.stack(\n            [\n                torch.cat([u, u.new_zeros(max_seq_len - u.size(0))])\n                for u in attn_mask_list\n            ]\n        )\n\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        return prompt_embeds, encoder_attention_mask\n\n    @staticmethod\n    def _pack_latents(latents, batch_size, num_channels_latents, height, width, layers):\n        latents = latents.view(\n            batch_size, layers, num_channels_latents, height // 2, 2, width // 2, 2\n        )\n        latents = latents.permute(0, 1, 3, 5, 2, 4, 6)\n        latents = latents.reshape(\n            batch_size, layers * (height 
// 2) * (width // 2), num_channels_latents * 4\n        )\n\n        return latents\n\n    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_embeds_mask: Optional[torch.Tensor] = None,\n        max_sequence_length: int = 1024,\n    ):\n        r\"\"\"\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n        \"\"\"\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt_embeds is None:\n            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(\n                prompt, device\n            )\n\n        prompt_embeds = prompt_embeds[:, :max_sequence_length]\n        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]\n\n        _, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(num_images_per_prompt, seq_len, -1)\n        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds_mask = prompt_embeds_mask.view(num_images_per_prompt, seq_len)\n\n        return prompt_embeds, prompt_embeds_mask\n\n    # Copied from 
diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image\n    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):\n        self.vae = self.vae.to(get_local_torch_device())\n        if isinstance(generator, list):\n            image_latents = [\n                retrieve_latents(\n                    self.vae.encode(image[i : i + 1]),\n                    generator=generator[i],\n                    sample_mode=\"argmax\",\n                )\n                for i in range(image.shape[0])\n            ]\n            image_latents = torch.cat(image_latents, dim=0)\n        else:\n            image_latents = retrieve_latents(\n                self.vae.encode(image), generator=generator, sample_mode=\"argmax\"\n            )\n        latents_mean = (\n            torch.tensor(self.vae.config.latents_mean)\n            .view(1, self.latent_channels, 1, 1, 1)\n            .to(image_latents.device, image_latents.dtype)\n        )\n        latents_std = (\n            torch.tensor(self.vae.config.latents_std)\n            .view(1, self.latent_channels, 1, 1, 1)\n            .to(image_latents.device, image_latents.dtype)\n        )\n        image_latents = (image_latents - latents_mean) / latents_std\n        self.vae.to(\"cpu\")\n        return image_latents\n\n    def prepare_latents(\n        self,\n        image,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        layers,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        # VAE applies 8x compression on images but we must also account for packing which requires\n        # latent height and width to be divisible by 2.\n        height = 2 * (int(height) // (self.vae_scale_factor * 2))\n        width = 2 * (int(width) // (self.vae_scale_factor * 2))\n        shape = (\n            batch_size,\n            layers + 1,\n            num_channels_latents,\n            height,\n       
     width,\n        )  ### the generated first image is combined image\n\n        image_latents = None\n        if image is not None:\n            image = image.to(device=device, dtype=dtype)\n            if image.shape[1] != self.latent_channels:\n                image_latents = self._encode_vae_image(image=image, generator=generator)\n            else:\n                image_latents = image\n            if (\n                batch_size > image_latents.shape[0]\n                and batch_size % image_latents.shape[0] == 0\n            ):\n                # expand init_latents for batch_size\n                additional_image_per_prompt = batch_size // image_latents.shape[0]\n                image_latents = torch.cat(\n                    [image_latents] * additional_image_per_prompt, dim=0\n                )\n            elif (\n                batch_size > image_latents.shape[0]\n                and batch_size % image_latents.shape[0] != 0\n            ):\n                raise ValueError(\n                    f\"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts.\"\n                )\n            else:\n                image_latents = torch.cat([image_latents], dim=0)\n\n            image_latent_height, image_latent_width = image_latents.shape[3:]\n            image_latents = image_latents.permute(\n                0, 2, 1, 3, 4\n            )  # (b, c, f, h, w) -> (b, f, c, h, w)\n            image_latents = self._pack_latents(\n                image_latents,\n                batch_size,\n                num_channels_latents,\n                image_latent_height,\n                image_latent_width,\n                1,\n            )\n\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n        if latents is None:\n            latents = randn_tensor(\n                shape, generator=generator, device=device, dtype=dtype\n            )\n            latents = self._pack_latents(\n                latents, batch_size, num_channels_latents, height, width, layers + 1\n            )\n        else:\n            latents = latents.to(device=device, dtype=dtype)\n\n        return latents, image_latents\n\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n        use_en_prompt = True\n        device = get_local_torch_device()\n        layers = batch.num_frames\n        num_inference_steps = batch.num_inference_steps\n        generator = batch.generator\n\n        assert batch.image_path is not None\n        image = load_image(batch.image_path[0])\n        image = image.convert(\"RGBA\")\n        image_size = image.size\n        resolution = server_args.pipeline_config.resolution\n        calculated_width, calculated_height = calculate_dimensions(\n            resolution * resolution, image_size[0] / image_size[1]\n        )\n\n        height = calculated_height\n        width = calculated_width\n\n        multiple_of = self.vae_scale_factor * 2\n        width = width // multiple_of * multiple_of\n        height = height // multiple_of * multiple_of\n\n        # if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):\n        image = self.image_processor.resize(image, calculated_height, calculated_width)\n        prompt_image = image\n        image = self.image_processor.preprocess(\n            image, calculated_height, calculated_width\n        )\n        image = image.unsqueeze(2)\n        image = image.to(dtype=torch.bfloat16)\n\n        prompt = self.get_image_caption(\n            prompt_image, use_en_prompt=use_en_prompt, device=device\n        )\n\n        prompt_embeds, 
prompt_embeds_mask = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n        )\n\n        negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(\n            prompt=batch.negative_prompt,\n            device=device,\n        )\n\n        num_channels_latents = self.transformer.config.in_channels // 4\n        latents, image_latents = self.prepare_latents(\n            image,\n            1,\n            num_channels_latents,\n            height,\n            width,\n            layers,\n            prompt_embeds.dtype,\n            device,\n            generator,\n        )\n        img_shapes = [\n            [\n                *[\n                    (\n                        1,\n                        height // self.vae_scale_factor // 2,\n                        width // self.vae_scale_factor // 2,\n                    )\n                    for _ in range(layers + 1)\n                ],\n                (\n                    1,\n                    calculated_height // self.vae_scale_factor // 2,\n                    calculated_width // self.vae_scale_factor // 2,\n                ),\n            ]\n        ]\n\n        # 5. 
Prepare timesteps\n        sigmas = np.linspace(1.0, 0, num_inference_steps + 1)[:-1]\n        image_seq_len = latents.shape[1]\n        base_seqlen = 256 * 256 / 16 / 16\n        mu = (image_latents.shape[1] / base_seqlen) ** 0.5\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            sigmas=sigmas,\n            mu=mu,\n        )\n\n        txt_seq_lens = (\n            prompt_embeds_mask.sum(dim=1).tolist()\n            if prompt_embeds_mask is not None\n            else None\n        )\n        negative_txt_seq_lens = (\n            negative_prompt_embeds_mask.sum(dim=1).tolist()\n            if negative_prompt_embeds_mask is not None\n            else None\n        )\n        is_rgb = torch.tensor([0]).to(device=device, dtype=torch.long)\n\n        batch.prompt_embeds = [prompt_embeds]\n        batch.prompt_embeds_mask = [prompt_embeds_mask]\n        batch.negative_prompt_embeds = [negative_prompt_embeds]\n        batch.negative_prompt_embeds_mask = [negative_prompt_embeds_mask]\n        batch.latents = latents\n        batch.image_latent = image_latents\n        batch.num_inference_steps = num_inference_steps\n        batch.sigmas = sigmas.tolist()  # Convert numpy array to list for validation\n        batch.generator = torch.manual_seed(0)\n        batch.original_condition_image_size = image_size\n        batch.raw_latent_shape = latents.shape\n        batch.txt_seq_lens = txt_seq_lens\n        batch.img_shapes = img_shapes\n\n        return batch\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_connector.py",
    "content": "import torch\n\nfrom sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\n\n\nclass LTX2TextConnectorStage(PipelineStage):\n    \"\"\"\n    Stage for applying LTX-2 Text Connectors to split/transform text embeddings\n    into video and audio contexts.\n    \"\"\"\n\n    def __init__(self, connectors):\n        super().__init__()\n        self.connectors = connectors\n\n    def forward(self, batch: Req, server_args: ServerArgs) -> Req:\n        # Input: batch.prompt_embeds (from Gemma, [B, S, D])\n        # Output: batch.prompt_embeds (Video Context), batch.audio_prompt_embeds (Audio Context)\n\n        prompt_embeds = batch.prompt_embeds\n        prompt_attention_mask = batch.prompt_attention_mask\n        neg_prompt_embeds = batch.negative_prompt_embeds\n        neg_prompt_attention_mask = batch.negative_attention_mask\n\n        if isinstance(prompt_embeds, list):\n            prompt_embeds = prompt_embeds[0] if len(prompt_embeds) > 0 else None\n\n        if isinstance(prompt_attention_mask, list):\n            prompt_attention_mask = (\n                prompt_attention_mask[0] if len(prompt_attention_mask) > 0 else None\n            )\n\n        if isinstance(neg_prompt_embeds, list):\n            neg_prompt_embeds = (\n                neg_prompt_embeds[0] if len(neg_prompt_embeds) > 0 else None\n            )\n\n        if isinstance(neg_prompt_attention_mask, list):\n            neg_prompt_attention_mask = (\n                neg_prompt_attention_mask[0]\n                if len(neg_prompt_attention_mask) > 0\n                else None\n            )\n\n        # Handle CFG: Concatenate negative and positive inputs\n        if batch.do_classifier_free_guidance:\n\n            # Concatenate: 
[Negative, Positive]\n            prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0)\n            prompt_attention_mask = torch.cat(\n                [neg_prompt_attention_mask, prompt_attention_mask], dim=0\n            )\n\n        # Prepare additive mask for connectors (as per Diffusers implementation)\n        dtype = prompt_embeds.dtype\n\n        additive_attention_mask = (1 - prompt_attention_mask.to(dtype)) * -1000000.0\n\n        # Call connectors\n        # Expects: prompt_embeds, attention_mask, additive_mask=True\n        with set_forward_context(current_timestep=None, attn_metadata=None):\n            connector_prompt_embeds, connector_audio_prompt_embeds, connector_mask = (\n                self.connectors(\n                    prompt_embeds, additive_attention_mask, additive_mask=True\n                )\n            )\n\n        # Split results if CFG was enabled\n        if batch.do_classifier_free_guidance:\n            neg_embeds, pos_embeds = connector_prompt_embeds.chunk(2, dim=0)\n            neg_audio_embeds, pos_audio_embeds = connector_audio_prompt_embeds.chunk(\n                2, dim=0\n            )\n            neg_mask, pos_mask = connector_mask.chunk(2, dim=0)\n\n            batch.prompt_embeds = [pos_embeds]\n            batch.audio_prompt_embeds = [pos_audio_embeds]\n            batch.prompt_attention_mask = pos_mask\n\n            batch.negative_prompt_embeds = [neg_embeds]\n            batch.negative_audio_prompt_embeds = [neg_audio_embeds]\n            batch.negative_attention_mask = neg_mask\n        else:\n            # Update positive fields\n            batch.prompt_embeds = [connector_prompt_embeds]\n            batch.audio_prompt_embeds = [connector_audio_prompt_embeds]\n            batch.prompt_attention_mask = connector_mask\n\n        return batch\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nPrompt encoding stages for diffusion pipelines.\n\nThis module contains implementations of prompt encoding stages for diffusion pipelines.\n\"\"\"\n\nimport torch\n\nfrom sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput\nfrom sglang.multimodal_gen.configs.pipeline_configs import FluxPipelineConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.flux import Flux2PipelineConfig\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    StageValidators as V,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass TextEncodingStage(PipelineStage):\n    \"\"\"\n    Stage for encoding text prompts into embeddings for diffusion models.\n\n    This stage handles the encoding of text prompts into the embedding space\n    expected by the diffusion model.\n    \"\"\"\n\n    def __init__(self, text_encoders, tokenizers) -> None:\n        \"\"\"\n        Initialize the prompt encoding stage.\n\n        \"\"\"\n        super().__init__()\n        self.tokenizers = tokenizers\n        self.text_encoders = text_encoders\n\n    @torch.no_grad()\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n        \"\"\"\n        Encode the prompt into text encoder hidden states.\n        \"\"\"\n        
assert len(self.tokenizers) == len(self.text_encoders)\n        assert len(self.text_encoders) == len(\n            server_args.pipeline_config.text_encoder_configs\n        )\n\n        # Encode positive prompt with all available encoders\n        assert batch.prompt is not None\n        prompt_text: str | list[str] = batch.prompt\n\n        all_indices: list[int] = list(range(len(self.text_encoders)))\n\n        prompt_embeds_list, prompt_masks_list, pooler_embeds_list = self.encode_text(\n            prompt_text,\n            server_args,\n            encoder_index=all_indices,\n            return_attention_mask=True,\n        )\n\n        for pe in prompt_embeds_list:\n            batch.prompt_embeds.append(pe)\n\n        for pe in pooler_embeds_list:\n            batch.pooled_embeds.append(pe)\n\n        if batch.prompt_attention_mask is None:\n            batch.prompt_attention_mask = []\n            for am in prompt_masks_list:\n                batch.prompt_attention_mask.append(am)\n\n        # Encode negative prompt if CFG is enabled\n        if batch.do_classifier_free_guidance:\n            assert isinstance(batch.negative_prompt, str)\n            neg_embeds_list, neg_masks_list, neg_pooler_embeds_list = self.encode_text(\n                batch.negative_prompt,\n                server_args,\n                encoder_index=all_indices,\n                return_attention_mask=True,\n            )\n\n            assert batch.negative_prompt_embeds is not None\n\n            for ne in neg_embeds_list:\n                batch.negative_prompt_embeds.append(ne)\n\n            for pe in neg_pooler_embeds_list:\n                batch.neg_pooled_embeds.append(pe)\n            if batch.negative_attention_mask is None:\n                batch.negative_attention_mask = []\n                for nm in neg_masks_list:\n                    batch.negative_attention_mask.append(nm)\n\n        return batch\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> 
VerificationResult:\n        \"\"\"Verify text encoding stage inputs.\"\"\"\n        result = VerificationResult()\n        result.add_check(\"prompt\", batch.prompt, V.string_or_list_strings)\n        result.add_check(\n            \"negative_prompt\",\n            batch.negative_prompt,\n            lambda x: not batch.do_classifier_free_guidance or V.string_not_none(x),\n        )\n        result.add_check(\n            \"do_classifier_free_guidance\",\n            batch.do_classifier_free_guidance,\n            V.bool_value,\n        )\n        result.add_check(\"prompt_embeds\", batch.prompt_embeds, V.is_list)\n        result.add_check(\n            \"negative_prompt_embeds\", batch.negative_prompt_embeds, V.none_or_list\n        )\n        return result\n\n    def prepare_tokenizer_kwargs(self, tokenizer_kwargs, **kwargs):\n        tok_kwargs = tokenizer_kwargs | kwargs\n\n        return tok_kwargs\n\n    @torch.no_grad()\n    def encode_text(\n        self,\n        text: str | list[str],\n        server_args: ServerArgs,\n        encoder_index: int | list[int] | None = None,\n        return_attention_mask: bool = False,\n        return_type: str = \"list\",  # one of: \"list\", \"dict\", \"stack\"\n        device: torch.device | str | None = None,\n        dtype: torch.dtype | None = None,\n        max_length: int | None = None,\n        truncation: bool | None = None,\n        padding: bool | str | None = None,\n        return_overflowing_tokens=None,\n        return_length=None,\n    ):\n        \"\"\"\n        Encode plain text using selected text encoder(s) and return embeddings.\n\n        Args:\n            text: A single string or a list of strings to encode.\n            server_args: The inference arguments providing pipeline config,\n                including tokenizer and encoder settings, preprocess and postprocess\n                functions.\n            encoder_index: Encoder selector by index. 
Accepts an int or list of ints.\n            return_attention_mask: If True, also return attention masks for each\n                selected encoder.\n            return_type: \"list\" (default) returns a list aligned with selection;\n                \"dict\" returns a dict keyed by encoder index as a string; \"stack\" stacks along a\n                new first dimension (requires matching shapes).\n            device: Optional device override for inputs; defaults to local torch device.\n            dtype: Optional dtype to cast returned embeddings to.\n            max_length: Currently unused; tokenizer settings come from the pipeline config.\n            truncation: Currently unused; tokenizer settings come from the pipeline config.\n            padding: Currently unused; tokenizer settings come from the pipeline config.\n\n        Returns:\n            Depending on return_type and return_attention_mask:\n            - list: (embeds, pooled_embeds), or (embeds, attention_masks, pooled_embeds)\n              when return_attention_mask is True; each is a List[Tensor], and\n              pooled_embeds is only populated for Flux v1 pipelines\n            - dict: Dict[str, Tensor] or (Dict[str, Tensor], Dict[str, Tensor])\n            - stack: Tensor of shape [num_encoders, ...] 
or a tuple with stacked\n              attention masks\n        \"\"\"\n\n        assert len(self.tokenizers) == len(self.text_encoders)\n        assert len(self.text_encoders) == len(\n            server_args.pipeline_config.text_encoder_configs\n        )\n\n        # Resolve selection into indices\n        if encoder_index is None:\n            indices: list[int] = [0]\n        elif isinstance(encoder_index, int):\n            indices = [encoder_index]\n        else:\n            indices = list(encoder_index)\n        # Validate that every selected index is within range\n        num_encoders = len(self.text_encoders)\n        for idx in indices:\n            if idx < 0 or idx >= num_encoders:\n                raise IndexError(\n                    f\"encoder index {idx} out of range [0, {num_encoders - 1}]\"\n                )\n\n        # Normalize input to list[str]\n        assert isinstance(text, str | list)\n        if isinstance(text, str):\n            texts: list[str] = [text]\n        else:\n            texts = text\n\n        embeds_list: list[torch.Tensor] = []\n        pooled_embeds_list: list[torch.Tensor] = []\n\n        attn_masks_list: list[torch.Tensor] = []\n\n        preprocess_funcs = server_args.pipeline_config.preprocess_text_funcs\n        postprocess_funcs = server_args.pipeline_config.postprocess_text_funcs\n        text_encoder_extra_args = server_args.pipeline_config.text_encoder_extra_args\n        encoder_cfgs = server_args.pipeline_config.text_encoder_configs\n\n        if return_type not in (\"list\", \"dict\", \"stack\"):\n            raise ValueError(\n                f\"Invalid return_type '{return_type}'. 
Expected one of: 'list', 'dict', 'stack'\"\n            )\n\n        target_device = device if device is not None else get_local_torch_device()\n\n        for i in indices:\n            tokenizer = self.tokenizers[i]\n            text_encoder = self.text_encoders[i]\n            encoder_config = encoder_cfgs[i]\n            preprocess_func = preprocess_funcs[i]\n            postprocess_func = postprocess_funcs[i]\n            text_encoder_extra_arg = (\n                text_encoder_extra_args[i]\n                if i < len(text_encoder_extra_args) and text_encoder_extra_args[i]\n                else {}\n            )\n\n            processed_text_list: list[str] = []\n            for prompt_str in texts:\n                preprocessed = preprocess_func(prompt_str)\n                processed_text_list.append(preprocessed)\n\n            # Prepare tokenizer args\n            tok_kwargs = self.prepare_tokenizer_kwargs(\n                encoder_config.tokenizer_kwargs,\n                **text_encoder_extra_arg,\n            )\n\n            text_inputs: dict = server_args.pipeline_config.tokenize_prompt(\n                processed_text_list, tokenizer, tok_kwargs\n            ).to(target_device)\n\n            input_ids = text_inputs[\"input_ids\"]\n            is_flux_v1 = isinstance(\n                server_args.pipeline_config, FluxPipelineConfig\n            ) and not isinstance(server_args.pipeline_config, Flux2PipelineConfig)\n            is_flux_t5 = is_flux_v1 and i == 1\n\n            if is_flux_t5:\n                attention_mask = torch.ones(input_ids.shape[:2], device=target_device)\n            else:\n                attention_mask = text_inputs[\"attention_mask\"]\n            with set_forward_context(current_timestep=0, attn_metadata=None):\n                outputs: BaseEncoderOutput = text_encoder(\n                    input_ids=input_ids,\n                    attention_mask=attention_mask,\n                    output_hidden_states=True,\n                
    use_cache=False,\n                )\n            prompt_embeds = postprocess_func(outputs, text_inputs)\n            if dtype is not None:\n                prompt_embeds = prompt_embeds.to(dtype=dtype)\n\n            embeds_list.append(prompt_embeds)\n            if is_flux_v1:\n                pooled_embeds_list.append(outputs.pooler_output)\n            if return_attention_mask:\n                attn_masks_list.append(attention_mask)\n\n        # Shape results according to return_type\n        if return_type == \"list\":\n            if return_attention_mask:\n                return embeds_list, attn_masks_list, pooled_embeds_list\n            return embeds_list, pooled_embeds_list\n\n        if return_type == \"dict\":\n            key_strs = [str(i) for i in indices]\n            embeds_dict = {k: v for k, v in zip(key_strs, embeds_list, strict=False)}\n            if return_attention_mask:\n                attn_dict = {\n                    k: v for k, v in zip(key_strs, attn_masks_list, strict=False)\n                }\n                return embeds_dict, attn_dict\n            return embeds_dict\n\n        # return_type == \"stack\"\n        # Validate shapes are compatible\n        base_shape = list(embeds_list[0].shape)\n        for t in embeds_list[1:]:\n            if list(t.shape) != base_shape:\n                raise ValueError(\n                    f\"Cannot stack embeddings with differing shapes: {[list(t.shape) for t in embeds_list]}\"\n                )\n        stacked_embeds = torch.stack(embeds_list, dim=0)\n        if return_attention_mask:\n            base_mask_shape = list(attn_masks_list[0].shape)\n            for m in attn_masks_list[1:]:\n                if list(m.shape) != base_mask_shape:\n                    raise ValueError(\n                        f\"Cannot stack attention masks with differing shapes: {[list(m.shape) for m in attn_masks_list]}\"\n                    )\n            stacked_masks = torch.stack(attn_masks_list, 
dim=0)\n            return stacked_embeds, stacked_masks\n        return stacked_embeds\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify text encoding stage outputs.\"\"\"\n        result = VerificationResult()\n        result.add_check(\n            \"prompt_embeds\", batch.prompt_embeds, V.list_of_tensors_min_dims(2)\n        )\n        result.add_check(\n            \"negative_prompt_embeds\",\n            batch.negative_prompt_embeds,\n            lambda x: not batch.do_classifier_free_guidance\n            or V.list_of_tensors_with_min_dims(x, 2),\n        )\n        if batch.debug:\n            logger.debug(f\"{batch.prompt_embeds=}\")\n            logger.debug(f\"{batch.negative_prompt_embeds=}\")\n        return result\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/timestep_preparation.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nTimestep preparation stages for diffusion pipelines.\n\nThis module contains implementations of timestep preparation stages for diffusion pipelines.\n\"\"\"\n\nimport inspect\nfrom typing import Any, Callable, Tuple\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.distributed import get_local_torch_device\nfrom sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.base import (\n    PipelineStage,\n    StageParallelismType,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    StageValidators as V,\n)\nfrom sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (\n    VerificationResult,\n)\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nclass TimestepPreparationStage(PipelineStage):\n    \"\"\"\n    Stage for preparing timesteps for the diffusion process.\n\n    This stage handles the preparation of the timestep sequence that will be used\n    during the diffusion process.\n    \"\"\"\n\n    def __init__(\n        self,\n        scheduler,\n        prepare_extra_set_timesteps_kwargs: list[\n            Callable[[Req, ServerArgs], Tuple[str, Any]]\n        ] = [],\n    ) -> None:\n        super().__init__()\n        self.scheduler = scheduler\n        self.prepare_extra_set_timesteps_kwargs = (\n            prepare_extra_set_timesteps_kwargs or []\n        )\n\n    @property\n    def parallelism_type(self) -> StageParallelismType:\n        return StageParallelismType.REPLICATED\n\n    def forward(\n        self,\n        batch: Req,\n        server_args: ServerArgs,\n    ) -> Req:\n        \"\"\"\n        Prepare timesteps for the diffusion process.\n\n\n\n        Returns:\n       
     The batch with prepared timesteps.\n        \"\"\"\n        scheduler = self.scheduler\n        device = get_local_torch_device()\n        num_inference_steps = batch.num_inference_steps\n        timesteps = batch.timesteps\n        sigmas = batch.sigmas\n        n_tokens = batch.n_tokens\n\n        sigmas = server_args.pipeline_config.prepare_sigmas(sigmas, num_inference_steps)\n        batch.sigmas = sigmas\n\n        # Prepare extra kwargs for set_timesteps\n        extra_set_timesteps_kwargs = {}\n        if (\n            n_tokens is not None\n            and \"n_tokens\" in inspect.signature(scheduler.set_timesteps).parameters\n        ):\n            extra_set_timesteps_kwargs[\"n_tokens\"] = n_tokens\n\n        for callee in self.prepare_extra_set_timesteps_kwargs:\n            key, value = callee(batch, server_args)\n            assert isinstance(key, str)\n            extra_set_timesteps_kwargs[key] = value\n            if key == \"mu\":\n                batch.extra[\"mu\"] = value\n\n        # Handle custom timesteps or sigmas\n        if timesteps is not None and sigmas is not None:\n            raise ValueError(\n                \"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\"\n            )\n\n        if timesteps is not None:\n            accepts_timesteps = (\n                \"timesteps\" in inspect.signature(scheduler.set_timesteps).parameters\n            )\n            if not accepts_timesteps:\n                raise ValueError(\n                    f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                    f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n                )\n            scheduler.set_timesteps(\n                timesteps=timesteps, device=device, **extra_set_timesteps_kwargs\n            )\n            timesteps = scheduler.timesteps\n        elif sigmas is not None:\n            accept_sigmas = (\n                \"sigmas\" in inspect.signature(scheduler.set_timesteps).parameters\n            )\n            if not accept_sigmas:\n                raise ValueError(\n                    f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                    f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n                )\n            scheduler.set_timesteps(\n                sigmas=sigmas, device=device, **extra_set_timesteps_kwargs\n            )\n            timesteps = scheduler.timesteps\n        else:\n            scheduler.set_timesteps(\n                num_inference_steps, device=device, **extra_set_timesteps_kwargs\n            )\n            timesteps = scheduler.timesteps\n\n        # Update batch with prepared timesteps\n        batch.timesteps = timesteps\n        if not batch.is_warmup:\n            self.log_debug(\"timesteps: %s\", timesteps)\n        return batch\n\n    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify timestep preparation stage inputs.\"\"\"\n        result = VerificationResult()\n        result.add_check(\n            \"num_inference_steps\", batch.num_inference_steps, V.positive_int\n        )\n        result.add_check(\"timesteps\", batch.timesteps, V.none_or_tensor)\n        result.add_check(\"sigmas\", batch.sigmas, V.none_or_list)\n        result.add_check(\"n_tokens\", batch.n_tokens, V.none_or_positive_int)\n        return result\n\n    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:\n        \"\"\"Verify timestep preparation stage 
outputs.\"\"\"\n        if (\n            batch.is_warmup\n            and isinstance(batch.timesteps, torch.Tensor)\n            and torch.isnan(batch.timesteps).any()\n        ):\n            # When num_inference_steps == 1 and the last sigma is 1, 1 / last_sigma can produce NaN.\n            # This is a workaround for warmup requests only.\n            batch.timesteps = torch.ones(\n                (1,), dtype=torch.float32, device=get_local_torch_device()\n            )\n\n        result = VerificationResult()\n        result.add_check(\"timesteps\", batch.timesteps, [V.is_tensor, V.with_dims(1)])\n        return result\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/pipelines_core/stages/validators.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nCommon validators for pipeline stage verification.\n\nThis module provides reusable validation functions that can be used across\nall pipeline stages for input/output verification.\n\"\"\"\n\nfrom collections.abc import Callable\nfrom typing import Any\n\nimport torch\n\n\nclass StageValidators:\n    \"\"\"Common validators for pipeline stages.\"\"\"\n\n    @staticmethod\n    def not_none(value: Any) -> bool:\n        \"\"\"Check if value is not None.\"\"\"\n        return value is not None\n\n    @staticmethod\n    def positive_int(value: Any) -> bool:\n        \"\"\"Check if value is a positive integer.\"\"\"\n        return isinstance(value, int) and value > 0\n\n    @staticmethod\n    def non_negative_int(value: Any) -> bool:\n        \"\"\"Check if value is a non-negative integer.\"\"\"\n        return isinstance(value, int) and value >= 0\n\n    @staticmethod\n    def positive_float(value: Any) -> bool:\n        \"\"\"Check if value is a positive float.\"\"\"\n        return isinstance(value, int | float) and value > 0\n\n    @staticmethod\n    def non_negative_float(value: Any) -> bool:\n        \"\"\"Check if value is a non-negative float.\"\"\"\n        return isinstance(value, int | float) and value >= 0\n\n    @staticmethod\n    def divisible_by(value: Any, divisor: int) -> bool:\n        \"\"\"Check if value is divisible by divisor.\"\"\"\n        return value is not None and isinstance(value, int) and value % divisor == 0\n\n    @staticmethod\n    def is_tensor(value: Any) -> bool:\n        \"\"\"Check if value is a torch tensor and doesn't contain NaN values.\"\"\"\n        if not isinstance(value, torch.Tensor):\n            return False\n        return not torch.isnan(value).any().item()\n\n    @staticmethod\n    def tensor_with_dims(value: Any, dims: int) -> bool:\n        \"\"\"Check if value is a tensor with 
specific dimensions and no NaN values.\"\"\"\n        if not isinstance(value, torch.Tensor):\n            return False\n        if value.dim() != dims:\n            return False\n        return not torch.isnan(value).any().item()\n\n    @staticmethod\n    def tensor_min_dims(value: Any, min_dims: int) -> bool:\n        \"\"\"Check if value is a tensor with at least min_dims dimensions and no NaN values.\"\"\"\n        if not isinstance(value, torch.Tensor):\n            return False\n        if value.dim() < min_dims:\n            return False\n        return not torch.isnan(value).any().item()\n\n    @staticmethod\n    def tensor_shape_matches(value: Any, expected_shape: tuple) -> bool:\n        \"\"\"Check if tensor shape matches expected shape (None for any size) and no NaN values.\"\"\"\n        if not isinstance(value, torch.Tensor):\n            return False\n        if len(value.shape) != len(expected_shape):\n            return False\n        for actual, expected in zip(value.shape, expected_shape, strict=True):\n            if expected is not None and actual != expected:\n                return False\n        return not torch.isnan(value).any().item()\n\n    @staticmethod\n    def list_not_empty(value: Any) -> bool:\n        \"\"\"Check if value is a non-empty list.\"\"\"\n        return isinstance(value, list) and len(value) > 0\n\n    @staticmethod\n    def list_length(value: Any, length: int) -> bool:\n        \"\"\"Check if list has specific length.\"\"\"\n        return isinstance(value, list) and len(value) == length\n\n    @staticmethod\n    def list_min_length(value: Any, min_length: int) -> bool:\n        \"\"\"Check if list has at least min_length items.\"\"\"\n        return isinstance(value, list) and len(value) >= min_length\n\n    @staticmethod\n    def string_not_empty(value: Any) -> bool:\n        \"\"\"Check if value is a non-empty string.\"\"\"\n        return isinstance(value, str) and len(value.strip()) > 0\n\n    @staticmethod\n    
def string_not_none(value: Any) -> bool:\n        \"\"\"Check if value is a non-empty string.\"\"\"\n        return isinstance(value, str) and len(value) > 0\n\n    @staticmethod\n    def string_or_list_strings(value: Any) -> bool:\n        \"\"\"Check if value is a string or list of strings.\"\"\"\n        if isinstance(value, str):\n            return True\n        if isinstance(value, list):\n            return all(isinstance(item, str) for item in value)\n        return False\n\n    @staticmethod\n    def bool_value(value: Any) -> bool:\n        \"\"\"Check if value is a boolean.\"\"\"\n        return isinstance(value, bool)\n\n    @staticmethod\n    def generator_or_list_generators(value: Any) -> bool:\n        \"\"\"Check if value is a Generator or list of Generators.\"\"\"\n        if isinstance(value, torch.Generator):\n            return True\n        if isinstance(value, list):\n            return all(isinstance(item, torch.Generator) for item in value)\n        return False\n\n    @staticmethod\n    def is_list(value: Any) -> bool:\n        \"\"\"Check if value is a list (can be empty).\"\"\"\n        return isinstance(value, list)\n\n    @staticmethod\n    def is_tuple(value: Any) -> bool:\n        \"\"\"Check if value is a tuple.\"\"\"\n        return isinstance(value, tuple)\n\n    @staticmethod\n    def none_or_tensor(value: Any) -> bool:\n        \"\"\"Check if value is None or a tensor without NaN values.\"\"\"\n        if value is None:\n            return True\n        if not isinstance(value, torch.Tensor):\n            return False\n        return not torch.isnan(value).any().item()\n\n    @staticmethod\n    def list_of_tensors_with_dims(value: Any, dims: int) -> bool:\n        \"\"\"Check if value is a non-empty list where all items are tensors with specific dimensions and no NaN values.\"\"\"\n        if not isinstance(value, list) or len(value) == 0:\n            return False\n        for item in value:\n            if not isinstance(item, 
torch.Tensor):\n                return False\n            if item.dim() != dims:\n                return False\n            if torch.isnan(item).any().item():\n                return False\n        return True\n\n    @staticmethod\n    def list_of_tensors(value: Any) -> bool:\n        \"\"\"Check if value is a non-empty list where all items are tensors without NaN values.\"\"\"\n        if not isinstance(value, list) or len(value) == 0:\n            return False\n        for item in value:\n            if not isinstance(item, torch.Tensor):\n                return False\n            if torch.isnan(item).any().item():\n                return False\n        return True\n\n    @staticmethod\n    def list_of_tensors_with_min_dims(value: Any, min_dims: int) -> bool:\n        \"\"\"Check if value is a non-empty list where all items are tensors with at least min_dims dimensions and no NaN values.\"\"\"\n        if not isinstance(value, list) or len(value) == 0:\n            return False\n        for item in value:\n            if not isinstance(item, torch.Tensor):\n                return False\n            if item.dim() < min_dims:\n                return False\n            if torch.isnan(item).any().item():\n                return False\n        return True\n\n    @staticmethod\n    def none_or_tensor_with_dims(dims: int) -> Callable[[Any], bool]:\n        \"\"\"Return a validator that checks if value is None or a tensor with specific dimensions and no NaN values.\"\"\"\n\n        def validator(value: Any) -> bool:\n            if value is None:\n                return True\n            if not isinstance(value, torch.Tensor):\n                return False\n            if value.dim() != dims:\n                return False\n            return not torch.isnan(value).any().item()\n\n        return validator\n\n    @staticmethod\n    def none_or_list(value: Any) -> bool:\n        \"\"\"Check if value is None or a list.\"\"\"\n        return value is None or isinstance(value, 
list)\n\n    @staticmethod\n    def none_or_positive_int(value: Any) -> bool:\n        \"\"\"Check if value is None or a positive integer.\"\"\"\n        return value is None or (isinstance(value, int) and value > 0)\n\n    # Helper methods that return functions for common patterns\n    @staticmethod\n    def with_dims(dims: int) -> Callable[[Any], bool]:\n        \"\"\"Return a validator that checks if tensor has specific dimensions and no NaN values.\"\"\"\n\n        def validator(value: Any) -> bool:\n            return StageValidators.tensor_with_dims(value, dims)\n\n        return validator\n\n    @staticmethod\n    def min_dims(min_dims: int) -> Callable[[Any], bool]:\n        \"\"\"Return a validator that checks if tensor has at least min_dims dimensions and no NaN values.\"\"\"\n\n        def validator(value: Any) -> bool:\n            return StageValidators.tensor_min_dims(value, min_dims)\n\n        return validator\n\n    @staticmethod\n    def divisible(divisor: int) -> Callable[[Any], bool]:\n        \"\"\"Return a validator that checks if value is divisible by divisor.\"\"\"\n\n        def validator(value: Any) -> bool:\n            return StageValidators.divisible_by(value, divisor)\n\n        return validator\n\n    @staticmethod\n    def positive_int_divisible(divisor: int) -> Callable[[Any], bool]:\n        \"\"\"Return a validator that checks if value is a positive integer divisible by divisor.\"\"\"\n\n        def validator(value: Any) -> bool:\n            return (\n                isinstance(value, int)\n                and value > 0\n                and StageValidators.divisible_by(value, divisor)\n            )\n\n        return validator\n\n    @staticmethod\n    def list_of_tensors_dims(dims: int) -> Callable[[Any], bool]:\n        \"\"\"Return a validator that checks if value is a list of tensors with specific dimensions and no NaN values.\"\"\"\n\n        def validator(value: Any) -> bool:\n            return 
StageValidators.list_of_tensors_with_dims(value, dims)\n\n        return validator\n\n    @staticmethod\n    def list_of_tensors_min_dims(min_dims: int) -> Callable[[Any], bool]:\n        \"\"\"Return a validator that checks if value is a list of tensors with at least min_dims dimensions and no NaN values.\"\"\"\n\n        def validator(value: Any) -> bool:\n            return StageValidators.list_of_tensors_with_min_dims(value, min_dims)\n\n        return validator\n\n\nclass ValidationFailure:\n    \"\"\"Details about a specific validation failure.\"\"\"\n\n    def __init__(\n        self,\n        validator_name: str,\n        actual_value: Any,\n        expected: str | None = None,\n        error_msg: str | None = None,\n    ):\n        self.validator_name = validator_name\n        self.actual_value = actual_value\n        self.expected = expected\n        self.error_msg = error_msg\n\n    def __str__(self) -> str:\n        parts = [f\"Validator '{self.validator_name}' failed\"]\n\n        if self.error_msg:\n            parts.append(f\"Error: {self.error_msg}\")\n\n        # Add actual value info (but limit very long representations)\n        actual_str = self._format_value(self.actual_value)\n        parts.append(f\"Actual: {actual_str}\")\n\n        if self.expected:\n            parts.append(f\"Expected: {self.expected}\")\n\n        return \". 
\".join(parts)\n\n    def _format_value(self, value: Any) -> str:\n        \"\"\"Format a value for display in error messages.\"\"\"\n        if value is None:\n            return \"None\"\n        elif isinstance(value, torch.Tensor):\n            return f\"tensor(shape={list(value.shape)}, dtype={value.dtype})\"\n        elif isinstance(value, list):\n            if len(value) == 0:\n                return \"[]\"\n            elif len(value) <= 3:\n                item_strs = [self._format_value(item) for item in value]\n                return f\"[{', '.join(item_strs)}]\"\n            else:\n                return f\"list(length={len(value)}, first_item={self._format_value(value[0])})\"\n        elif isinstance(value, str):\n            if len(value) > 50:\n                return f\"'{value[:47]}...'\"\n            else:\n                return f\"'{value}'\"\n        else:\n            return f\"{type(value).__name__}({value})\"\n\n\nclass VerificationResult:\n    \"\"\"Wrapper class for stage verification results.\"\"\"\n\n    def __init__(self) -> None:\n        self._checks: dict[str, bool] = {}\n        self._failures: dict[str, list[ValidationFailure]] = {}\n\n    def add_check(\n        self,\n        field_name: str,\n        value: Any,\n        validators: Callable[[Any], bool] | list[Callable[[Any], bool]],\n    ) -> \"VerificationResult\":\n        \"\"\"\n        Add a validation check for a field.\n\n        Args:\n            field_name: Name of the field being checked\n            value: The actual value to validate\n            validators: Single validation function or list of validation functions.\n                       Each function will be called with the value as its first argument.\n\n        Returns:\n            Self for method chaining\n\n        Examples:\n            # Single validator\n            result.add_check(\"tensor\", my_tensor, V.is_tensor)\n\n            # Multiple validators (all must pass)\n            
result.add_check(\"latents\", batch.latents, [V.is_tensor, V.with_dims(5)])\n\n            # Using partial functions for parameters\n            result.add_check(\"height\", batch.height, [V.not_none, V.divisible(8)])\n        \"\"\"\n        if not isinstance(validators, list):\n            validators = [validators]\n\n        failures = []\n        all_passed = True\n\n        # Apply all validators and collect detailed failure info\n        for validator in validators:\n            try:\n                passed = validator(value)\n                if not passed:\n                    all_passed = False\n                    failure = self._create_validation_failure(validator, value)\n                    failures.append(failure)\n            except Exception as e:\n                # If any validator raises an exception, consider the check failed\n                all_passed = False\n                validator_name = getattr(validator, \"__name__\", str(validator))\n                failure = ValidationFailure(\n                    validator_name=validator_name,\n                    actual_value=value,\n                    error_msg=f\"Exception during validation: {str(e)}\",\n                )\n                failures.append(failure)\n\n        self._checks[field_name] = all_passed\n        if not all_passed:\n            self._failures[field_name] = failures\n\n        return self\n\n    def _create_validation_failure(\n        self, validator: Callable, value: Any\n    ) -> ValidationFailure:\n        \"\"\"Create a ValidationFailure with detailed information.\"\"\"\n        validator_name = getattr(validator, \"__name__\", str(validator))\n\n        # Try to extract meaningful expected value info based on validator type\n        expected = None\n        error_msg = None\n\n        # Handle common validator patterns\n        if hasattr(validator, \"__closure__\") and validator.__closure__:\n            # This is likely a closure (like our helper functions)\n          
  if \"dims\" in validator_name or \"with_dims\" in str(validator):\n                if isinstance(value, torch.Tensor):\n                    expected = f\"tensor with {validator.__closure__[0].cell_contents} dimensions\"\n                else:\n                    expected = \"tensor with specific dimensions\"\n            elif \"divisible\" in str(validator):\n                expected = (\n                    f\"integer divisible by {validator.__closure__[0].cell_contents}\"\n                )\n\n        # Handle specific validator types and check for NaN values\n        if validator_name == \"is_tensor\":\n            expected = \"torch.Tensor without NaN values\"\n            if isinstance(value, torch.Tensor) and torch.isnan(value).any().item():\n                error_msg = (\n                    f\"tensor contains {torch.isnan(value).sum().item()} NaN values\"\n                )\n        elif validator_name == \"positive_int\":\n            expected = \"positive integer\"\n        elif validator_name == \"not_none\":\n            expected = \"non-None value\"\n        elif validator_name == \"list_not_empty\":\n            expected = \"non-empty list\"\n        elif validator_name == \"bool_value\":\n            expected = \"boolean value\"\n        elif (\n            \"tensor_with_dims\" in validator_name or \"tensor_min_dims\" in validator_name\n        ):\n            if isinstance(value, torch.Tensor):\n                if torch.isnan(value).any().item():\n                    error_msg = f\"tensor has {value.dim()} dimensions but contains {torch.isnan(value).sum().item()} NaN values\"\n                else:\n                    error_msg = f\"tensor has {value.dim()} dimensions\"\n        elif validator_name == \"is_list\":\n            expected = \"list\"\n        elif validator_name == \"none_or_tensor\":\n            expected = \"None or tensor without NaN values\"\n            if isinstance(value, torch.Tensor) and torch.isnan(value).any().item():\n   
             error_msg = (\n                    f\"tensor contains {torch.isnan(value).sum().item()} NaN values\"\n                )\n        elif validator_name == \"list_of_tensors\":\n            expected = \"non-empty list of tensors without NaN values\"\n            if isinstance(value, list) and len(value) > 0:\n                nan_count = 0\n                for item in value:\n                    if (\n                        isinstance(item, torch.Tensor)\n                        and torch.isnan(item).any().item()\n                    ):\n                        nan_count += torch.isnan(item).sum().item()\n                if nan_count > 0:\n                    error_msg = (\n                        f\"list contains tensors with total {nan_count} NaN values\"\n                    )\n        elif \"list_of_tensors_with_dims\" in validator_name:\n            expected = (\n                \"non-empty list of tensors with specific dimensions and no NaN values\"\n            )\n            if isinstance(value, list) and len(value) > 0:\n                nan_count = 0\n                for item in value:\n                    if (\n                        isinstance(item, torch.Tensor)\n                        and torch.isnan(item).any().item()\n                    ):\n                        nan_count += torch.isnan(item).sum().item()\n                if nan_count > 0:\n                    error_msg = (\n                        f\"list contains tensors with total {nan_count} NaN values\"\n                    )\n\n        return ValidationFailure(\n            validator_name=validator_name,\n            actual_value=value,\n            expected=expected,\n            error_msg=error_msg,\n        )\n\n    def is_valid(self) -> bool:\n        \"\"\"Check if all validations passed.\"\"\"\n        return all(self._checks.values())\n\n    def get_failed_fields(self) -> list[str]:\n        \"\"\"Get list of fields that failed validation.\"\"\"\n        return [field for 
field, passed in self._checks.items() if not passed]\n\n    def get_detailed_failures(self) -> dict[str, list[ValidationFailure]]:\n        \"\"\"Get detailed failure information for each failed field.\"\"\"\n        return self._failures.copy()\n\n    def get_failure_summary(self) -> str:\n        \"\"\"Get a comprehensive summary of all validation failures.\"\"\"\n        if self.is_valid():\n            return \"All validations passed\"\n\n        summary_parts = []\n        for field_name, failures in self._failures.items():\n            field_summary = f\"\\n  Field '{field_name}':\"\n            for i, failure in enumerate(failures, 1):\n                field_summary += f\"\\n    {i}. {failure}\"\n            summary_parts.append(field_summary)\n\n        return \"Validation failures:\" + \"\".join(summary_parts)\n\n    def to_dict(self) -> dict:\n        \"\"\"Convert to dictionary for backward compatibility.\"\"\"\n        return self._checks.copy()\n\n\n# Alias for convenience\nV = StageValidators\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/platforms/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/__init__.py\n\nimport traceback\nfrom typing import TYPE_CHECKING\n\n# imported by other files, do not remove\nfrom sglang.multimodal_gen.runtime.platforms.interface import (  # noqa: F401\n    AttentionBackendEnum,\n    Platform,\n    PlatformEnum,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import resolve_obj_by_qualname\n\nlogger = init_logger(__name__)\n\n\ndef cuda_platform_plugin() -> str | None:\n    is_cuda = False\n\n    try:\n        from sglang.multimodal_gen.utils import import_pynvml\n\n        pynvml = import_pynvml()  # type: ignore[no-untyped-call]\n        pynvml.nvmlInit()\n        try:\n            # NOTE: Edge case: sgl_diffusion cpu build on a GPU machine.\n            # Third-party pynvml can be imported in cpu build,\n            # we need to check if sgl_diffusion is built with cpu too.\n            # Otherwise, sgl_diffusion will always activate cuda plugin\n            # on a GPU machine, even if in a cpu build.\n            is_cuda = pynvml.nvmlDeviceGetCount() > 0\n        finally:\n            pynvml.nvmlShutdown()\n    except Exception as e:\n        if \"nvml\" not in e.__class__.__name__.lower():\n            # If the error is not related to NVML, re-raise it.\n            raise e\n\n        # CUDA is supported on Jetson, but NVML may not be.\n        import os\n\n        def cuda_is_jetson() -> bool:\n            return os.path.isfile(\"/etc/nv_tegra_release\") or os.path.exists(\n                \"/sys/class/tegra-firmware\"\n            )\n\n        if cuda_is_jetson():\n            is_cuda = True\n    if is_cuda:\n        logger.debug(\"CUDA is available\")\n\n    return (\n        \"sglang.multimodal_gen.runtime.platforms.cuda.CudaPlatform\" if is_cuda 
else None\n    )\n\n\ndef mps_platform_plugin() -> str | None:\n    \"\"\"Detect if MPS (Metal Performance Shaders) is available on macOS.\"\"\"\n    is_mps = False\n\n    try:\n        import torch\n\n        if torch.backends.mps.is_available():\n            is_mps = True\n            logger.debug(\"MPS (Metal Performance Shaders) is available\")\n    except Exception as e:\n        logger.debug(\"MPS detection failed: %s\", e)\n\n    return \"sglang.multimodal_gen.runtime.platforms.mps.MpsPlatform\" if is_mps else None\n\n\ndef cpu_platform_plugin() -> str | None:\n    \"\"\"Detect if CPU platform should be used.\"\"\"\n    # CPU is always available as a fallback\n    return \"sglang.multimodal_gen.runtime.platforms.cpu.CpuPlatform\"\n\n\ndef rocm_platform_plugin() -> str | None:\n    is_rocm = False\n\n    try:\n        import amdsmi\n\n        amdsmi.amdsmi_init()\n        try:\n            if len(amdsmi.amdsmi_get_processor_handles()) > 0:\n                is_rocm = True\n                logger.debug(\"ROCm platform is available\")\n        finally:\n            amdsmi.amdsmi_shut_down()\n    except Exception as e:\n        logger.debug(\"ROCm platform is unavailable: %s\", e)\n\n    return (\n        \"sglang.multimodal_gen.runtime.platforms.rocm.RocmPlatform\" if is_rocm else None\n    )\n\n\ndef npu_platform_plugin() -> str | None:\n    is_npu = False\n\n    try:\n        import torch\n\n        if torch.npu.is_available():\n            is_npu = True\n            logger.debug(\"NPU is available\")\n    except Exception as e:\n        logger.debug(\"NPU detection failed: %s\", e)\n    return (\n        \"sglang.multimodal_gen.runtime.platforms.npu.NPUPlatformBase\"\n        if is_npu\n        else None\n    )\n\n\ndef musa_platform_plugin() -> str | None:\n    is_musa = False\n\n    try:\n        import pymtml\n\n        pymtml.mtmlLibraryInit()\n        try:\n            is_musa = pymtml.mtmlLibraryCountDevice() > 0\n        finally:\n            
pymtml.mtmlLibraryShutDown()\n    except Exception as e:\n        logger.debug(\"MUSA platform is unavailable: %s\", e)\n\n    return (\n        \"sglang.multimodal_gen.runtime.platforms.musa.MusaPlatform\" if is_musa else None\n    )\n\n\nbuiltin_platform_plugins = {\n    \"cuda\": cuda_platform_plugin,\n    \"rocm\": rocm_platform_plugin,\n    \"mps\": mps_platform_plugin,\n    \"cpu\": cpu_platform_plugin,\n    \"npu\": npu_platform_plugin,\n    \"musa\": musa_platform_plugin,\n}\n\n\ndef resolve_current_platform_cls_qualname() -> str:\n    # TODO(will): if we need to support other platforms, we should consider if\n    # vLLM's plugin architecture is suitable for our needs.\n\n    # Try MPS first on macOS\n    platform_cls_qualname = mps_platform_plugin()\n    if platform_cls_qualname is not None:\n        return platform_cls_qualname\n\n    # Fall back to ROCm\n    platform_cls_qualname = rocm_platform_plugin()\n    if platform_cls_qualname is not None:\n        return platform_cls_qualname\n\n    # Fall back to CUDA\n    platform_cls_qualname = cuda_platform_plugin()\n    if platform_cls_qualname is not None:\n        return platform_cls_qualname\n\n    # Fall back to NPU\n    platform_cls_qualname = npu_platform_plugin()\n    if platform_cls_qualname is not None:\n        return platform_cls_qualname\n\n    # Fall back to MUSA\n    platform_cls_qualname = musa_platform_plugin()\n    if platform_cls_qualname is not None:\n        return platform_cls_qualname\n\n    # Fall back to CPU as last resort\n    platform_cls_qualname = cpu_platform_plugin()\n    if platform_cls_qualname is not None:\n        return platform_cls_qualname\n\n    raise RuntimeError(\"No platform plugin found. Please check your \" \"installation.\")\n\n\n_current_platform: Platform | None = None\n_init_trace: str = \"\"\n\ncurrent_platform: Platform\n\n\ndef __getattr__(name: str):\n    if name == \"current_platform\":\n        # lazy init current_platform.\n        # 1. 
out-of-tree platform plugins need `from sglang.multimodal_gen.runtime.platforms import\n        #    Platform` so that they can inherit `Platform` class. Therefore,\n        #    we cannot resolve `current_platform` during the import of\n        #    `sglang.multimodal_gen.runtime.platforms`.\n        global _current_platform\n        if _current_platform is None:\n            platform_cls_qualname = resolve_current_platform_cls_qualname()\n            _current_platform = resolve_obj_by_qualname(platform_cls_qualname)()\n            global _init_trace\n            _init_trace = \"\".join(traceback.format_stack())\n        return _current_platform\n    elif name in globals():\n        return globals()[name]\n    else:\n        raise AttributeError(f\"No attribute named '{name}' exists in {__name__}.\")\n\n\n__all__ = [\"Platform\", \"PlatformEnum\", \"current_platform\", \"_init_trace\"]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/platforms/cpu.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/cpu.py\n\nimport platform\nfrom functools import lru_cache\nfrom typing import Any\n\nimport psutil\nimport torch\n\nfrom sglang.multimodal_gen.runtime.platforms.interface import (\n    CpuArchEnum,\n    Platform,\n    PlatformEnum,\n)\n\n\nclass CpuPlatform(Platform):\n    _enum = PlatformEnum.CPU\n    device_name = \"CPU\"\n    device_type = \"cpu\"\n    dispatch_key = \"CPU\"\n\n    @classmethod\n    def get_cpu_architecture(cls) -> CpuArchEnum:\n        \"\"\"Get the CPU architecture.\"\"\"\n        machine = platform.machine().lower()\n        if machine in (\"x86_64\", \"amd64\", \"i386\", \"i686\"):\n            return CpuArchEnum.X86\n        elif machine in (\"arm64\", \"aarch64\"):\n            return CpuArchEnum.ARM\n        else:\n            return CpuArchEnum.UNSPECIFIED\n\n    @classmethod\n    def get_device_name(cls, device_id: int = 0) -> str:\n        return platform.processor()\n\n    @classmethod\n    def get_device_uuid(cls, device_id: int = 0) -> str:\n        return platform.machine()\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def get_device_total_memory(cls, device_id: int = 0) -> int:\n\n        return psutil.virtual_memory().total\n\n    @classmethod\n    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:\n        return True\n\n    @classmethod\n    def get_current_memory_usage(\n        cls, device: torch.types.Device | None = None\n    ) -> float:\n        # For CPU, we can't easily get memory usage without additional libraries\n        return 0.0\n\n    @classmethod\n    def get_available_gpu_memory(\n        cls,\n        device_id: int = 0,\n        distributed: bool = False,\n        empty_cache: bool = True,\n        cpu_group: Any = None,\n    ) -> float:\n\n        total_free_memory = 
psutil.virtual_memory().available\n        # For simplicity, we assume 1 NUMA node for now in this platform abstraction\n        # as get_cpu_ids_by_node is not available in multimodal_gen.runtime.utils\n        n_numa_node = 1\n        free_memory = total_free_memory / n_numa_node\n\n        if distributed:\n            import torch.distributed as dist\n\n            tensor = torch.tensor(free_memory, dtype=torch.float32)\n            dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group)\n            free_memory = float(tensor.item())\n\n        return free_memory / (1 << 30)\n\n    @classmethod\n    def get_device_communicator_cls(cls) -> str:\n        return \"sglang.multimodal_gen.runtime.distributed.device_communicators.cpu_communicator.CpuCommunicator\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/platforms/cuda.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/cuda.py\n\"\"\"Code inside this file can safely assume cuda platform, e.g. importing\npynvml. However, it should not initialize cuda context.\n\"\"\"\n\nimport os\nfrom collections.abc import Callable\nfrom functools import lru_cache, wraps\nfrom typing import Any, TypeVar\n\nimport psutil\nimport torch\nfrom typing_extensions import ParamSpec\n\nfrom sglang.multimodal_gen import envs\nfrom sglang.multimodal_gen.runtime.platforms.interface import (\n    AttentionBackendEnum,\n    DeviceCapability,\n    Platform,\n    PlatformEnum,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import import_pynvml\n\nlogger = init_logger(__name__)\n\n_P = ParamSpec(\"_P\")\n_R = TypeVar(\"_R\")\n\npynvml = import_pynvml()  # type: ignore[no-untyped-call]\n\n# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models\n# see https://github.com/huggingface/diffusers/issues/9704 for details\ntorch.backends.cuda.enable_cudnn_sdp(False)\n\n\ndef device_id_to_physical_device_id(device_id: int) -> int:\n    if \"CUDA_VISIBLE_DEVICES\" in os.environ:\n        device_ids = os.environ[\"CUDA_VISIBLE_DEVICES\"].split(\",\")\n        if device_ids == [\"\"]:\n            msg = (\n                \"CUDA_VISIBLE_DEVICES is set to empty string, which means\"\n                \" GPU support is disabled. If you are using ray, please unset\"\n                \" the environment variable `CUDA_VISIBLE_DEVICES` inside the\"\n                \" worker/actor. 
\"\n                \"Check https://github.com/vllm-project/vllm/issues/8402 for\"\n                \" more information.\"\n            )\n            raise RuntimeError(msg)\n        physical_device_id = device_ids[device_id]\n        return int(physical_device_id)\n    else:\n        return device_id\n\n\ndef with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:\n    @wraps(fn)\n    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:\n        pynvml.nvmlInit()\n        try:\n            return fn(*args, **kwargs)\n        finally:\n            pynvml.nvmlShutdown()\n\n    return wrapper\n\n\nclass CudaPlatformBase(Platform):\n    _enum = PlatformEnum.CUDA\n    device_name: str = \"cuda\"\n    device_type: str = \"cuda\"\n    dispatch_key: str = \"CUDA\"\n    device_control_env_var: str = \"CUDA_VISIBLE_DEVICES\"\n\n    @classmethod\n    def get_local_torch_device(cls) -> torch.device:\n        return torch.device(f\"cuda:{envs.LOCAL_RANK}\")\n\n    @classmethod\n    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:\n        raise NotImplementedError\n\n    @classmethod\n    def get_device_name(cls, device_id: int = 0) -> str:\n        raise NotImplementedError\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def get_device_total_memory(cls, device_id: int = 0) -> int:\n        raise NotImplementedError\n\n    @classmethod\n    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:\n        if enforce_eager:\n            logger.warning(\n                \"To see benefits of async output processing, enable CUDA \"\n                \"graph. 
Since, enforce-eager is enabled, async output \"\n                \"processor cannot be used\"\n            )\n            return False\n        return True\n\n    @classmethod\n    def is_full_nvlink(cls, device_ids: list[int]) -> bool:\n        raise NotImplementedError\n\n    @classmethod\n    def log_warnings(cls) -> None:\n        pass\n\n    @classmethod\n    def get_current_memory_usage(\n        cls, device: torch.types.Device | None = None\n    ) -> float:\n        torch.cuda.reset_peak_memory_stats(device)\n        return float(torch.cuda.max_memory_allocated(device))\n\n    @classmethod\n    def get_available_gpu_memory(\n        cls,\n        device_id: int = 0,\n        distributed: bool = False,\n        empty_cache: bool = True,\n        cpu_group: Any = None,\n    ) -> float:\n        if empty_cache:\n            torch.cuda.empty_cache()\n\n        if torch.distributed.is_initialized():\n            device_id = torch.distributed.get_rank()\n\n        device_props = torch.cuda.get_device_properties(device_id)\n        if device_props.is_integrated:\n            free_gpu_memory = psutil.virtual_memory().available\n        else:\n            free_gpu_memory, _ = torch.cuda.mem_get_info(device_id)\n\n        if distributed:\n            import torch.distributed as dist\n\n            tensor = torch.tensor(free_gpu_memory, dtype=torch.float32, device=\"cuda\")\n            dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group)\n            free_gpu_memory = float(tensor.item())\n\n        return free_gpu_memory / (1 << 30)\n\n    @classmethod\n    def get_attn_backend_cls_str(\n        cls,\n        selected_backend: AttentionBackendEnum | None,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> str:\n        target_backend: AttentionBackendEnum | None = None\n        # TODO(will): maybe come up with a more general interface for local attention\n        # if distributed is False, we always try to use Flash attn\n        if 
selected_backend == AttentionBackendEnum.SLIDING_TILE_ATTN:\n            try:\n                from st_attn import sliding_tile_attention  # noqa: F401\n\n                from sglang.multimodal_gen.runtime.layers.attention.backends.sliding_tile_attn import (  # noqa: F401\n                    SlidingTileAttentionBackend,\n                )\n\n                logger.info(\"Using Sliding Tile Attention backend\")\n\n                return \"sglang.multimodal_gen.runtime.layers.attention.backends.sliding_tile_attn.SlidingTileAttentionBackend\"\n            except ImportError as e:\n                logger.error(\n                    \"Failed to import Sliding Tile Attention backend: %s\", str(e)\n                )\n                raise ImportError(\n                    \"Sliding Tile Attention backend is not installed. \"\n                ) from e\n        elif selected_backend == AttentionBackendEnum.SAGE_ATTN:\n            try:\n                from sageattention import sageattn  # noqa: F401\n\n                from sglang.multimodal_gen.runtime.layers.attention.backends.sage_attn import (  # noqa: F401\n                    SageAttentionBackend,\n                )\n\n                logger.info(\"Using Sage Attention backend\")\n\n                return \"sglang.multimodal_gen.runtime.layers.attention.backends.sage_attn.SageAttentionBackend\"\n            except ImportError as e:\n                logger.info(e)\n                logger.info(\n                    \"Sage Attention backend is not installed (To install it, run `pip install sageattention==2.2.0 --no-build-isolation`). 
Falling back to Flash Attention.\"\n                )\n                target_backend = AttentionBackendEnum.FA\n        elif selected_backend == AttentionBackendEnum.SAGE_ATTN_3:\n            try:\n                from sglang.multimodal_gen.runtime.layers.attention.backends.sage_attn3 import (  # noqa: F401\n                    SageAttention3Backend,\n                )\n\n                logger.info(\"Using Sage Attention 3 backend\")\n                return \"sglang.multimodal_gen.runtime.layers.attention.backends.sage_attn3.SageAttention3Backend\"\n            except ImportError as e:\n                logger.info(e)\n                logger.info(\n                    \"Sage Attention 3 backend is not installed (To install it, see https://github.com/thu-ml/SageAttention/tree/main/sageattention3_blackwell#installation). Falling back to Torch SDPA.\"\n                )\n                target_backend = AttentionBackendEnum.TORCH_SDPA\n        elif selected_backend == AttentionBackendEnum.VIDEO_SPARSE_ATTN:\n            try:\n                from vsa import block_sparse_attn  # noqa: F401\n\n                from sglang.multimodal_gen.runtime.layers.attention.backends.video_sparse_attn import (  # noqa: F401\n                    VideoSparseAttentionBackend,\n                )\n\n                logger.info(\"Using Video Sparse Attention backend\")\n\n                return \"sglang.multimodal_gen.runtime.layers.attention.backends.video_sparse_attn.VideoSparseAttentionBackend\"\n            except ImportError as e:\n                logger.error(\n                    \"Failed to import Video Sparse Attention backend: %s\", str(e)\n                )\n                raise ImportError(\n                    \"Video Sparse Attention backend is not installed.\"\n                ) from e\n        elif selected_backend == AttentionBackendEnum.SPARSE_VIDEO_GEN_2_ATTN:\n            try:\n                from svg.kernels.triton.permute import (  # noqa: F401\n                    
apply_inverse_permutation_triton,\n                    permute_tensor_by_labels_triton,\n                )\n                from svg.kmeans_utils import (  # noqa: F401\n                    batch_kmeans_Euclid,\n                    density_calculation,\n                    dynamic_block_sparse_fwd_flashinfer,\n                    identify_dynamic_map,\n                )\n\n                from sglang.multimodal_gen.runtime.layers.attention.backends.sparse_video_gen_2_attn import (  # noqa: F401\n                    SparseVideoGen2AttentionBackend,\n                )\n\n                logger.info(\"Using Sparse Video Gen 2 (SAP) Attention backend\")\n                return \"sglang.multimodal_gen.runtime.layers.attention.backends.sparse_video_gen_2_attn.SparseVideoGen2AttentionBackend\"\n            except ImportError as e:\n                logger.error(\n                    \"Failed to import Sparse Video Gen 2 (SAP) Attention backend: %s\",\n                    str(e),\n                )\n                raise ImportError(\n                    \"Sparse Video Gen 2 (SAP) Attention backend is not installed. 
\"\n                    \"Please install it by following the instructions at \"\n                    \"https://github.com/svg-project/Sparse-VideoGen\"\n                ) from e\n        elif selected_backend == AttentionBackendEnum.VMOBA_ATTN:\n            try:\n                from kernel.attn.vmoba_attn.vmoba import moba_attn_varlen  # noqa: F401\n\n                from sglang.multimodal_gen.runtime.layers.attention.backends.vmoba import (  # noqa: F401\n                    VMOBAAttentionBackend,\n                )\n\n                logger.info(\"Using Video MOBA Attention backend\")\n\n                return \"sglang.multimodal_gen.runtime.layers.attention.backends.vmoba.VMOBAAttentionBackend\"\n            except ImportError as e:\n                logger.error(\n                    \"Failed to import Video MoBA Attention backend: %s\", str(e)\n                )\n                raise ImportError(\n                    \"Video MoBA Attention backend is not installed. \"\n                ) from e\n        elif selected_backend == AttentionBackendEnum.AITER:\n            logger.info(\"Using AITer backend\")\n            return \"sglang.multimodal_gen.runtime.layers.attention.backends.aiter.AITerBackend\"\n        elif selected_backend == AttentionBackendEnum.TORCH_SDPA:\n            logger.info(\"Using Torch SDPA backend\")\n            return \"sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend\"\n        elif selected_backend == AttentionBackendEnum.SLA_ATTN:\n            logger.info(\"Using Sparse Linear Attention backend\")\n            return \"sglang.multimodal_gen.runtime.layers.attention.backends.sparse_linear_attn.SparseLinearAttentionBackend\"\n        elif selected_backend == AttentionBackendEnum.SAGE_SLA_ATTN:\n            logger.info(\"Using Sage Sparse Linear Attention backend\")\n            return \"sglang.multimodal_gen.runtime.layers.attention.backends.sparse_linear_attn.SageSparseLinearAttentionBackend\"\n        elif 
selected_backend == AttentionBackendEnum.FA2:\n            from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn_2 import (  # noqa: F401\n                FlashAttention2Backend,\n            )\n\n            logger.info(\"Using FlashAttention2 backend\")\n            return \"sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn_2.FlashAttention2Backend\"\n        elif selected_backend in [\n            AttentionBackendEnum.FA,\n        ]:\n            if cls.is_sm120():\n                logger.info(\n                    \"FlashAttention is not supported on SM12.x in this build; falling back to Torch SDPA.\"\n                )\n                target_backend = AttentionBackendEnum.TORCH_SDPA\n            elif cls.is_blackwell():\n                from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import (\n                    set_fa_ver,\n                )\n\n                set_fa_ver(4)\n                target_backend = AttentionBackendEnum.FA\n            else:\n                target_backend = AttentionBackendEnum.FA\n        elif selected_backend:\n            raise ValueError(f\"Invalid attention backend for {cls.device_name}\")\n        else:\n            if cls.is_sm120():\n                # On SM12.x, the sgl-kernel FlashAttention wheels may not include\n                # support yet. 
Default to Torch SDPA for correctness.\n                logger.info(\"Defaulting to Torch SDPA backend on SM12.x\")\n                target_backend = AttentionBackendEnum.TORCH_SDPA\n            elif cls.is_blackwell():\n                from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import (\n                    set_fa_ver,\n                )\n\n                set_fa_ver(4)\n                target_backend = AttentionBackendEnum.FA\n            else:\n                target_backend = AttentionBackendEnum.FA\n\n        # Ensure we have a target backend selected before validation/fallback.\n        if target_backend is None:\n            target_backend = AttentionBackendEnum.FA\n\n        if target_backend == AttentionBackendEnum.FA and cls.is_blackwell():\n            from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import (\n                set_fa_ver,\n            )\n\n            set_fa_ver(4)\n\n        if not cls.has_device_capability(80):\n            logger.info(\"Cannot use FlashAttention backend for Volta and Turing GPUs.\")\n            target_backend = AttentionBackendEnum.TORCH_SDPA\n        elif dtype not in (torch.float16, torch.bfloat16):\n            logger.info(\n                \"Cannot use FlashAttention backend for dtype other than \"\n                \"torch.float16 or torch.bfloat16.\"\n            )\n            target_backend = AttentionBackendEnum.TORCH_SDPA\n        # FlashAttn is valid for the model, checking if the package is\n        # installed.\n        if target_backend == AttentionBackendEnum.FA:\n            try:\n                from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import (  # noqa: F401\n                    FlashAttentionBackend,\n                )\n\n                supported_sizes = FlashAttentionBackend.get_supported_head_sizes()\n                if head_size not in supported_sizes:\n                    logger.info(\n                        
\"Cannot use FlashAttention backend for head size %d.\",\n                        head_size,\n                    )\n                    target_backend = AttentionBackendEnum.TORCH_SDPA\n            except ImportError:\n                logger.info(\n                    \"Cannot use FlashAttention backend because the \"\n                    \"flash_attn package is not found. \"\n                    \"Make sure that flash_attn was built and installed \"\n                    \"(on by default).\"\n                )\n                target_backend = AttentionBackendEnum.TORCH_SDPA\n\n        if target_backend == AttentionBackendEnum.TORCH_SDPA:\n            logger.info(\"Using Torch SDPA backend\")\n\n            return \"sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend\"\n\n        logger.info(\"Using FlashAttention (FA3 for hopper, FA4 for blackwell) backend\")\n\n        return \"sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn.FlashAttentionBackend\"\n\n    @classmethod\n    def get_device_communicator_cls(cls) -> str:\n        return \"sglang.multimodal_gen.runtime.distributed.device_communicators.cuda_communicator.CudaCommunicator\"  # noqa\n\n\n# NVML utils\n# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,\n# all the related functions work on real physical device ids.\n# the major benefit of using NVML is that it will not initialize CUDA\nclass NvmlCudaPlatform(CudaPlatformBase):\n    @classmethod\n    @lru_cache(maxsize=8)\n    @with_nvml_context\n    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:\n        try:\n            physical_device_id = device_id_to_physical_device_id(device_id)\n            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)\n            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)\n            return DeviceCapability(major=major, minor=minor)\n        except RuntimeError:\n            return None\n\n    @classmethod\n    
@lru_cache(maxsize=8)\n    @with_nvml_context\n    def has_device_capability(\n        cls,\n        capability: tuple[int, int] | int,\n        device_id: int = 0,\n    ) -> bool:\n        try:\n            return bool(super().has_device_capability(capability, device_id))\n        except RuntimeError:\n            return False\n\n    @classmethod\n    @lru_cache(maxsize=8)\n    @with_nvml_context\n    def get_device_name(cls, device_id: int = 0) -> str:\n        physical_device_id = device_id_to_physical_device_id(device_id)\n        return cls._get_physical_device_name(physical_device_id)\n\n    @classmethod\n    @lru_cache(maxsize=8)\n    @with_nvml_context\n    def get_device_uuid(cls, device_id: int = 0) -> str:\n        physical_device_id = device_id_to_physical_device_id(device_id)\n        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)\n        return str(pynvml.nvmlDeviceGetUUID(handle))\n\n    @classmethod\n    @lru_cache(maxsize=8)\n    @with_nvml_context\n    def get_device_total_memory(cls, device_id: int = 0) -> int:\n        physical_device_id = device_id_to_physical_device_id(device_id)\n        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)\n        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)\n\n    @classmethod\n    @with_nvml_context\n    def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:\n        \"\"\"\n        query if the set of gpus are fully connected by nvlink (1 hop)\n        \"\"\"\n        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]\n        for i, handle in enumerate(handles):\n            for j, peer_handle in enumerate(handles):\n                if i < j:\n                    try:\n                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(\n                            handle,\n                            peer_handle,\n                            pynvml.NVML_P2P_CAPS_INDEX_NVLINK,\n                        )\n                        if 
p2p_status != pynvml.NVML_P2P_STATUS_OK:\n                            return False\n                    except pynvml.NVMLError:\n                        logger.exception(\n                            \"NVLink detection failed. This is normal if\"\n                            \" your machine has no NVLink equipped.\"\n                        )\n                        return False\n        return True\n\n    @classmethod\n    def _get_physical_device_name(cls, device_id: int = 0) -> str:\n        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)\n        return str(pynvml.nvmlDeviceGetName(handle))\n\n    @classmethod\n    @with_nvml_context\n    def log_warnings(cls) -> None:\n        device_ids: int = pynvml.nvmlDeviceGetCount()\n        if device_ids > 1:\n            device_names = [cls._get_physical_device_name(i) for i in range(device_ids)]\n            if (\n                len(set(device_names)) > 1\n                and os.environ.get(\"CUDA_DEVICE_ORDER\") != \"PCI_BUS_ID\"\n            ):\n                logger.warning(\n                    \"Detected different devices in the system: %s. 
Please\"\n                    \" make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to \"\n                    \"avoid unexpected behavior.\",\n                    \", \".join(device_names),\n                )\n\n\nclass NonNvmlCudaPlatform(CudaPlatformBase):\n    @classmethod\n    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:\n        major, minor = torch.cuda.get_device_capability(device_id)\n        return DeviceCapability(major=major, minor=minor)\n\n    @classmethod\n    def get_device_name(cls, device_id: int = 0) -> str:\n        return str(torch.cuda.get_device_name(device_id))\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def get_device_total_memory(cls, device_id: int = 0) -> int:\n        device_props = torch.cuda.get_device_properties(device_id)\n        return int(device_props.total_memory)\n\n    @classmethod\n    def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:\n        logger.exception(\n            \"NVLink detection not possible, as context support was\"\n            \" not found. Assuming no NVLink available.\"\n        )\n        return False\n\n\n# Autodetect either NVML-enabled or non-NVML platform\n# based on whether NVML is available.\nnvml_available = False\ntry:\n    try:\n        pynvml.nvmlInit()\n        nvml_available = True\n    except Exception:\n        # On Jetson, NVML is not supported.\n        nvml_available = False\nfinally:\n    if nvml_available:\n        pynvml.nvmlShutdown()\n\nCudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform\n\ntry:\n    from sphinx.ext.autodoc.mock import _MockModule\n\n    if not isinstance(pynvml, _MockModule):\n        CudaPlatform.log_warnings()\nexcept ModuleNotFoundError:\n    CudaPlatform.log_warnings()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/platforms/interface.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/interface.py\nfrom __future__ import annotations\n\nimport enum\nimport random\nfrom functools import lru_cache\nfrom typing import TYPE_CHECKING, Any, NamedTuple\n\nimport numpy as np\nimport torch\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.utils import resolve_obj_by_qualname\n\nif TYPE_CHECKING:\n    from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (\n        AttentionImpl,\n    )\n\nlogger = init_logger(__name__)\n\n\nclass AttentionBackendEnum(enum.Enum):\n    FA2 = enum.auto()\n    FA = enum.auto()\n    SLIDING_TILE_ATTN = enum.auto()\n    TORCH_SDPA = enum.auto()\n    SAGE_ATTN = enum.auto()\n    SAGE_ATTN_3 = enum.auto()\n    VIDEO_SPARSE_ATTN = enum.auto()\n    SPARSE_VIDEO_GEN_2_ATTN = enum.auto()\n    VMOBA_ATTN = enum.auto()\n    AITER = enum.auto()\n    AITER_SAGE = enum.auto()\n    SLA_ATTN = enum.auto()\n    SAGE_SLA_ATTN = enum.auto()\n    NO_ATTENTION = enum.auto()\n\n    def __str__(self):\n        return self.name.lower()\n\n    @property\n    def is_sparse(self) -> bool:\n        return self in {\n            AttentionBackendEnum.SLIDING_TILE_ATTN,\n            AttentionBackendEnum.VIDEO_SPARSE_ATTN,\n            AttentionBackendEnum.SPARSE_VIDEO_GEN_2_ATTN,\n            AttentionBackendEnum.VMOBA_ATTN,\n            AttentionBackendEnum.SLA_ATTN,\n            AttentionBackendEnum.SAGE_SLA_ATTN,\n        }\n\n\nclass PlatformEnum(enum.Enum):\n    CUDA = enum.auto()\n    ROCM = enum.auto()\n    TPU = enum.auto()\n    CPU = enum.auto()\n    MPS = enum.auto()\n    NPU = enum.auto()\n    MUSA = enum.auto()\n    OOT = enum.auto()\n    UNSPECIFIED = enum.auto()\n\n\nclass CpuArchEnum(enum.Enum):\n    X86 = enum.auto()\n    ARM = enum.auto()\n    
UNSPECIFIED = enum.auto()\n\n\nclass DeviceCapability(NamedTuple):\n    major: int\n    minor: int\n\n    def as_version_str(self) -> str:\n        return f\"{self.major}.{self.minor}\"\n\n    def to_int(self) -> int:\n        \"\"\"\n        Express device capability as an integer ``<major><minor>``.\n\n        It is assumed that the minor version is always a single digit.\n        \"\"\"\n        assert 0 <= self.minor < 10\n        return self.major * 10 + self.minor\n\n\nclass Platform:\n    _enum: PlatformEnum\n    device_name: str\n    device_type: str\n    device: torch.device | None = None  # Dummy attribute for compatibility\n\n    # available dispatch keys:\n    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa\n    # use \"CPU\" as a fallback for platforms not registered in PyTorch\n    dispatch_key: str = \"CPU\"\n\n    # The torch.compile backend for compiling simple and\n    # standalone functions. The default value is \"inductor\" to keep\n    # the same behavior as PyTorch.\n    # NOTE: for the forward part of the model, vLLM has another separate\n    # compilation strategy.\n    simple_compile_backend: str = \"inductor\"\n\n    supported_quantization: list[str] = []\n\n    @lru_cache(maxsize=1)\n    def is_cuda(self) -> bool:\n        return self.is_cuda_static()\n\n    @lru_cache(maxsize=1)\n    def is_npu(self) -> bool:\n        return self._enum == PlatformEnum.NPU\n\n    @lru_cache(maxsize=1)\n    def is_rocm(self) -> bool:\n        return self.is_rocm_static()\n\n    @lru_cache(maxsize=1)\n    def is_tpu(self) -> bool:\n        return self._enum == PlatformEnum.TPU\n\n    @lru_cache(maxsize=1)\n    def is_cpu(self) -> bool:\n        return self._enum == PlatformEnum.CPU\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def is_blackwell(cls):\n        if not cls.is_cuda_static():\n            return False\n        return torch.cuda.get_device_capability()[0] == 10\n\n    
@classmethod\n    @lru_cache(maxsize=1)\n    def is_hopper(cls):\n        if not cls.is_cuda_static():\n            return False\n        return torch.cuda.get_device_capability() == (9, 0)\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def is_sm120(cls):\n        if not cls.is_cuda_static():\n            return False\n        return torch.cuda.get_device_capability()[0] == 12\n\n    @classmethod\n    def is_cuda_static(cls) -> bool:\n        return getattr(cls, \"_enum\", None) == PlatformEnum.CUDA\n\n    @classmethod\n    def is_rocm_static(cls) -> bool:\n        return getattr(cls, \"_enum\", None) == PlatformEnum.ROCM\n\n    @lru_cache(maxsize=1)\n    def is_hpu(self) -> bool:\n        return hasattr(torch, \"hpu\") and torch.hpu.is_available()\n\n    @lru_cache(maxsize=1)\n    def is_xpu(self) -> bool:\n        return hasattr(torch, \"xpu\") and torch.xpu.is_available()\n\n    @lru_cache(maxsize=1)\n    def is_npu(self) -> bool:\n        return hasattr(torch, \"npu\") and torch.npu.is_available()\n\n    def is_out_of_tree(self) -> bool:\n        return self._enum == PlatformEnum.OOT\n\n    @lru_cache(maxsize=1)\n    def is_cuda_alike(self) -> bool:\n        \"\"\"Stateless version of :func:`torch.cuda.is_available`.\"\"\"\n        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM, PlatformEnum.MUSA)\n\n    @lru_cache(maxsize=1)\n    def is_mps(self) -> bool:\n        return self._enum == PlatformEnum.MPS\n\n    @lru_cache(maxsize=1)\n    def is_musa(self):\n        try:\n            return hasattr(torch, \"musa\") and torch.musa.is_available()\n        except ModuleNotFoundError:\n            return False\n\n    @lru_cache(maxsize=1)\n    def is_hip(self) -> bool:\n        return self.is_rocm()\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def is_amp_supported(cls) -> bool:\n        return True\n\n    @classmethod\n    def get_local_torch_device(cls) -> torch.device:\n        raise NotImplementedError\n\n    @classmethod\n    def 
get_attn_backend_cls_str(\n        cls,\n        selected_backend: AttentionBackendEnum | None,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> str:\n        \"\"\"Get the attention backend class of a device.\"\"\"\n        return \"\"\n\n    @classmethod\n    def get_device_capability(\n        cls,\n        device_id: int = 0,\n    ) -> DeviceCapability | None:\n        \"\"\"Stateless version of :func:`torch.cuda.get_device_capability`.\"\"\"\n        return None\n\n    @classmethod\n    def has_device_capability(\n        cls,\n        capability: tuple[int, int] | int,\n        device_id: int = 0,\n    ) -> bool:\n        \"\"\"\n        Test whether this platform is compatible with a device capability.\n\n        The ``capability`` argument can either be:\n\n        - A tuple ``(major, minor)``.\n        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)\n        \"\"\"\n        current_capability = cls.get_device_capability(device_id=device_id)\n        if current_capability is None:\n            return False\n\n        if isinstance(capability, tuple):\n            return current_capability >= capability\n\n        return current_capability.to_int() >= capability\n\n    @classmethod\n    def get_device_name(cls, device_id: int = 0) -> str:\n        \"\"\"Get the name of a device.\"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    def get_device_uuid(cls, device_id: int = 0) -> str:\n        \"\"\"Get the uuid of a device, e.g. 
the PCI bus ID.\"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def get_device_total_memory(cls, device_id: int = 0) -> int:\n        \"\"\"Get the total memory of a device in bytes.\"\"\"\n        raise NotImplementedError\n\n    @lru_cache(maxsize=1)\n    def get_device(self, local_rank: int) -> torch.device:\n        if self.is_cuda() or self.is_rocm():\n            return torch.device(\"cuda\", local_rank)\n        elif self.is_npu():\n            return torch.device(\"npu\", local_rank)\n        elif self.is_musa():\n            return torch.device(\"musa\", local_rank)\n        elif self.is_mps():\n            return torch.device(\"mps\")\n        else:\n            return torch.device(\"cpu\")\n\n    @lru_cache(maxsize=1)\n    def get_torch_distributed_backend_str(self) -> str:\n        if self.is_cuda_alike():\n            return \"nccl\"\n        elif self.is_npu():\n            return \"hccl\"\n        elif self.is_musa():\n            return \"mccl\"\n        elif self.is_mps():\n            return \"gloo\"\n        else:\n            raise NotImplementedError(\n                \"No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available\"\n            )\n\n    @classmethod\n    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:\n        \"\"\"\n        Check if the current platform supports async output.\n        \"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    def inference_mode(cls):\n        \"\"\"A device-specific wrapper of `torch.inference_mode`.\n\n        This wrapper is recommended because some hardware backends such as TPU\n        do not support `torch.inference_mode`. 
In such a case, they will fall\n        back to `torch.no_grad` by overriding this method.\n        \"\"\"\n        return torch.inference_mode(mode=True)\n\n    @classmethod\n    def seed_everything(cls, seed: int | None = None) -> None:\n        \"\"\"\n        Set the seed of each random module.\n        `torch.manual_seed` will set seed on all devices.\n\n        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20\n        \"\"\"\n        if seed is not None:\n            random.seed(seed)\n            np.random.seed(seed)\n            torch.manual_seed(seed)\n            torch.get_device_module().manual_seed_all(seed)\n\n    @classmethod\n    def verify_model_arch(cls, model_arch: str) -> None:\n        \"\"\"\n        Verify whether the current platform supports the specified model\n        architecture.\n\n        - This will raise an Error or Warning based on the model support on\n        the current platform.\n        - By default all models are considered supported.\n        \"\"\"\n        pass\n\n    @classmethod\n    def verify_quantization(cls, quant: str) -> None:\n        \"\"\"\n        Verify whether the quantization is supported by the current platform.\n        \"\"\"\n        if cls.supported_quantization and quant not in cls.supported_quantization:\n            raise ValueError(\n                f\"{quant} quantization is currently not supported in \"\n                f\"{cls.device_name}.\"\n            )\n\n    @classmethod\n    def get_current_memory_usage(\n        cls, device: torch.types.Device | None = None\n    ) -> float:\n        \"\"\"\n        Return the memory usage in bytes.\n        \"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    def get_available_gpu_memory(\n        cls,\n        device_id: int = 0,\n        distributed: bool = False,\n        empty_cache: bool = True,\n        cpu_group: Any = None,\n    ) -> float:\n        \"\"\"\n      
  Return the available memory in GiB.\n        \"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    def get_device_communicator_cls(cls) -> str:\n        \"\"\"\n        Get device specific communicator class for distributed communication.\n        \"\"\"\n        return \"sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase\"  # noqa\n\n    @classmethod\n    def get_cpu_architecture(cls) -> CpuArchEnum:\n        \"\"\"Get the CPU architecture of the current platform.\"\"\"\n        return CpuArchEnum.UNSPECIFIED\n\n    @classmethod\n    def enable_dit_layerwise_offload_for_wan_by_default(cls) -> bool:\n        \"\"\"Whether to enable DIT layerwise offload by default on the current platform.\"\"\"\n        return True\n\n    @classmethod\n    def optimize_vae(cls, vae: torch.nn.Module) -> torch.nn.Module:\n        \"\"\"Apply platform-specific optimizations to VAE after loading.\"\"\"\n        return vae\n\n    def get_attn_backend(self, *args, **kwargs) -> AttentionImpl:\n        attention_cls_str = self.get_attn_backend_cls_str(*args, **kwargs)\n        return resolve_obj_by_qualname(attention_cls_str)\n\n\nclass UnspecifiedPlatform(Platform):\n    _enum = PlatformEnum.UNSPECIFIED\n    device_type = \"\"\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/platforms/mps.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\nfrom functools import lru_cache\nfrom typing import Any\n\nimport psutil\nimport torch\n\nfrom sglang.multimodal_gen.runtime.platforms import (\n    AttentionBackendEnum,\n    Platform,\n    PlatformEnum,\n)\nfrom sglang.multimodal_gen.runtime.platforms.interface import DeviceCapability\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\n# SPDX-License-Identifier: Apache-2.0\n\n\nlogger = init_logger(__name__)\n\n\nclass MpsPlatform(Platform):\n    _enum = PlatformEnum.MPS\n    device_name: str = \"mps\"\n    device_type: str = \"mps\"\n    dispatch_key: str = \"MPS\"\n    device_control_env_var: str = \"MPS_VISIBLE_DEVICES\"\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def is_amp_supported(cls) -> bool:\n        return False\n\n    @classmethod\n    def get_local_torch_device(cls) -> torch.device:\n        return torch.device(\"mps\")\n\n    @classmethod\n    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:\n        raise NotImplementedError\n\n    @classmethod\n    def get_device_name(cls, device_id: int = 0) -> str:\n        raise NotImplementedError\n\n    @classmethod\n    def get_device_uuid(cls, device_id: int = 0) -> str:\n        raise NotImplementedError\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def get_device_total_memory(cls, device_id: int = 0) -> int:\n\n        return psutil.virtual_memory().total\n\n    @classmethod\n    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:\n        if enforce_eager:\n            logger.warning(\n                \"To see benefits of async output processing, enable MPS \"\n                \"graph. 
Since enforce-eager is enabled, async output \"\n                \"processor cannot be used\"\n            )\n            return False\n        return True\n\n    @classmethod\n    def get_current_memory_usage(\n        cls, device: torch.types.Device | None = None\n    ) -> float:\n        return 0.0\n\n    @classmethod\n    def get_available_gpu_memory(\n        cls,\n        device_id: int = 0,\n        distributed: bool = False,\n        empty_cache: bool = True,\n        cpu_group: Any = None,\n    ) -> float:\n\n        if empty_cache:\n            torch.mps.empty_cache()\n\n        # For MPS, available memory is essentially the system available memory\n        free_memory = psutil.virtual_memory().available\n\n        if distributed:\n            import torch.distributed as dist\n\n            tensor = torch.tensor(free_memory, dtype=torch.float32)\n            dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group)\n            free_memory = float(tensor.item())\n\n        return free_memory / (1 << 30)\n\n    @classmethod\n    def get_attn_backend_cls_str(\n        cls,\n        selected_backend: AttentionBackendEnum | None,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> str:\n        # MPS supports SDPA (Scaled Dot-Product Attention) which is the most compatible\n        logger.info(\"Using Torch SDPA backend for MPS.\")\n        return (\n            \"sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend\"\n        )\n\n    @classmethod\n    def get_device_communicator_cls(cls) -> str:\n        # Use base communicator for MPS\n        return \"sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase\"\n\n    @classmethod\n    def seed_everything(cls, seed: int | None = None) -> None:\n        \"\"\"Set the seed for MPS device.\"\"\"\n        if seed is not None:\n            import random\n\n            import numpy as np\n\n            random.seed(seed)\n  
           np.random.seed(seed)\n            torch.manual_seed(seed)\n            # MPS doesn't have manual_seed_all like CUDA\n            # The manual_seed above should be sufficient\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/platforms/musa.py",
    "content": "\"\"\"\nThis file is a platform abstraction for MThreads (MUSA) GPUs,\nadjusted to match the structure and interface of `cuda.py`.\n\"\"\"\n\nimport os\nfrom collections.abc import Callable\nfrom functools import lru_cache, wraps\nfrom typing import Any, TypeVar\n\nimport psutil\nimport pymtml\n\n# isort: off\nimport torch\nimport torchada  # noqa: F401\n\n# isort: on\nfrom typing_extensions import ParamSpec\n\nfrom sglang.multimodal_gen import envs\nfrom sglang.multimodal_gen.runtime.platforms.interface import (\n    AttentionBackendEnum,\n    DeviceCapability,\n    Platform,\n    PlatformEnum,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n_P = ParamSpec(\"_P\")\n_R = TypeVar(\"_R\")\n\n\ndef device_id_to_physical_device_id(device_id: int) -> int:\n    if \"MUSA_VISIBLE_DEVICES\" in os.environ:\n        device_ids = os.environ[\"MUSA_VISIBLE_DEVICES\"].split(\",\")\n        if device_ids == [\"\"]:\n            msg = (\n                \"MUSA_VISIBLE_DEVICES is set to empty string, which means\"\n                \" GPU support is disabled. If you are using ray, please unset\"\n                \" the environment variable `MUSA_VISIBLE_DEVICES` inside the\"\n                \" worker/actor. 
\"\n                \"Check https://github.com/vllm-project/vllm/issues/8402 for\"\n                \" more information.\"\n            )\n            raise RuntimeError(msg)\n        physical_device_id = device_ids[device_id]\n        return int(physical_device_id)\n    else:\n        return device_id\n\n\ndef with_mtml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:\n    @wraps(fn)\n    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:\n        pymtml.nvmlInit()\n        try:\n            return fn(*args, **kwargs)\n        finally:\n            pymtml.nvmlShutdown()\n\n    return wrapper\n\n\nclass MusaPlatformBase(Platform):\n    _enum = PlatformEnum.MUSA\n    device_name: str = \"musa\"\n    device_type: str = \"musa\"\n    dispatch_key: str = \"MUSA\"\n    device_control_env_var: str = \"MUSA_VISIBLE_DEVICES\"\n\n    @classmethod\n    def get_local_torch_device(cls) -> torch.device:\n        return torch.device(f\"musa:{envs.LOCAL_RANK}\")\n\n    @classmethod\n    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:\n        raise NotImplementedError\n\n    @classmethod\n    def get_device_name(cls, device_id: int = 0) -> str:\n        raise NotImplementedError\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def get_device_total_memory(cls, device_id: int = 0) -> int:\n        raise NotImplementedError\n\n    @classmethod\n    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:\n        if enforce_eager:\n            logger.warning(\n                \"To see benefits of async output processing, enable MUSA \"\n                \"graph. 
Since enforce-eager is enabled, async output \"\n                \"processor cannot be used\"\n            )\n            return False\n        return True\n\n    @classmethod\n    def is_full_mtlink(cls, device_ids: list[int]) -> bool:\n        raise NotImplementedError\n\n    @classmethod\n    def log_warnings(cls) -> None:\n        pass\n\n    @classmethod\n    def get_current_memory_usage(\n        cls, device: torch.types.Device | None = None\n    ) -> float:\n        torch.cuda.reset_peak_memory_stats(device)\n        return float(torch.cuda.max_memory_allocated(device))\n\n    @classmethod\n    def get_available_gpu_memory(\n        cls,\n        device_id: int = 0,\n        distributed: bool = False,\n        empty_cache: bool = True,\n        cpu_group: Any = None,\n    ) -> float:\n        if empty_cache:\n            torch.cuda.empty_cache()\n\n        if torch.distributed.is_initialized():\n            device_id = torch.distributed.get_rank()\n\n        device_props = torch.cuda.get_device_properties(device_id)\n        if device_props.is_integrated:\n            free_gpu_memory = psutil.virtual_memory().available\n        else:\n            free_gpu_memory, _ = torch.cuda.mem_get_info(device_id)\n\n        if distributed:\n            import torch.distributed as dist\n\n            tensor = torch.tensor(free_gpu_memory, dtype=torch.float32, device=\"musa\")\n            dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group)\n            free_gpu_memory = float(tensor.item())\n\n        return free_gpu_memory / (1 << 30)\n\n    @classmethod\n    def get_attn_backend_cls_str(\n        cls,\n        selected_backend: AttentionBackendEnum | None,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> str:\n        logger.info(\"Using Torch SDPA backend.\")\n        return (\n            \"sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend\"\n        )\n\n    @classmethod\n    def get_device_communicator_cls(cls) -> 
str:\n        return \"sglang.multimodal_gen.runtime.distributed.device_communicators.cuda_communicator.CudaCommunicator\"  # noqa\n\n\n# MTML utils\n# Note that MTML is not affected by `MUSA_VISIBLE_DEVICES`,\n# all the related functions work on real physical device ids.\n# the major benefit of using MTML is that it will not initialize MUSA\nclass MtmlMusaPlatform(MusaPlatformBase):\n    @classmethod\n    @lru_cache(maxsize=8)\n    @with_mtml_context\n    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:\n        try:\n            physical_device_id = device_id_to_physical_device_id(device_id)\n            handle = pymtml.nvmlDeviceGetHandleByIndex(physical_device_id)\n            major, minor = pymtml.nvmlDeviceGetCudaComputeCapability(handle)\n            return DeviceCapability(major=major, minor=minor)\n        except RuntimeError:\n            return None\n\n    @classmethod\n    @lru_cache(maxsize=8)\n    @with_mtml_context\n    def has_device_capability(\n        cls,\n        capability: tuple[int, int] | int,\n        device_id: int = 0,\n    ) -> bool:\n        try:\n            return bool(super().has_device_capability(capability, device_id))\n        except RuntimeError:\n            return False\n\n    @classmethod\n    @lru_cache(maxsize=8)\n    @with_mtml_context\n    def get_device_name(cls, device_id: int = 0) -> str:\n        physical_device_id = device_id_to_physical_device_id(device_id)\n        return cls._get_physical_device_name(physical_device_id)\n\n    @classmethod\n    @lru_cache(maxsize=8)\n    @with_mtml_context\n    def get_device_uuid(cls, device_id: int = 0) -> str:\n        physical_device_id = device_id_to_physical_device_id(device_id)\n        handle = pymtml.nvmlDeviceGetHandleByIndex(physical_device_id)\n        return str(pymtml.nvmlDeviceGetUUID(handle))\n\n    @classmethod\n    @lru_cache(maxsize=8)\n    @with_mtml_context\n    def get_device_total_memory(cls, device_id: int = 0) -> int:\n        
physical_device_id = device_id_to_physical_device_id(device_id)\n        handle = pymtml.nvmlDeviceGetHandleByIndex(physical_device_id)\n        return int(pymtml.nvmlDeviceGetMemoryInfo(handle).total)\n\n    @classmethod\n    @with_mtml_context\n    def is_full_mtlink(cls, physical_device_ids: list[int]) -> bool:\n        \"\"\"\n        query if the set of gpus are fully connected by mtlink (1 hop)\n        \"\"\"\n        handles = [pymtml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]\n        for i, handle in enumerate(handles):\n            for j, peer_handle in enumerate(handles):\n                if i < j:\n                    try:\n                        p2p_status = pymtml.nvmlDeviceGetP2PStatus(\n                            handle,\n                            peer_handle,\n                            pymtml.NVML_P2P_CAPS_INDEX_NVLINK,\n                        )\n                        if p2p_status != pymtml.NVML_P2P_STATUS_OK:\n                            return False\n                    except pymtml.NVMLError:\n                        logger.exception(\n                            \"MTLink detection failed. 
This is normal if\"\n                            \" your machine has no MTLink equipped.\"\n                        )\n                        return False\n        return True\n\n    @classmethod\n    def _get_physical_device_name(cls, device_id: int = 0) -> str:\n        handle = pymtml.nvmlDeviceGetHandleByIndex(device_id)\n        return str(pymtml.nvmlDeviceGetName(handle))\n\n    @classmethod\n    @with_mtml_context\n    def log_warnings(cls) -> None:\n        device_ids: int = pymtml.nvmlDeviceGetCount()\n        if device_ids > 1:\n            device_names = [cls._get_physical_device_name(i) for i in range(device_ids)]\n            if (\n                len(set(device_names)) > 1\n                and os.environ.get(\"MUSA_DEVICE_ORDER\") != \"PCI_BUS_ID\"\n            ):\n                logger.warning(\n                    \"Detected different devices in the system: %s. Please\"\n                    \" make sure to set `MUSA_DEVICE_ORDER=PCI_BUS_ID` to \"\n                    \"avoid unexpected behavior.\",\n                    \", \".join(device_names),\n                )\n\n\nclass NonMtmlMusaPlatform(MusaPlatformBase):\n    @classmethod\n    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:\n        major, minor = torch.cuda.get_device_capability(device_id)\n        return DeviceCapability(major=major, minor=minor)\n\n    @classmethod\n    def get_device_name(cls, device_id: int = 0) -> str:\n        return str(torch.cuda.get_device_name(device_id))\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def get_device_total_memory(cls, device_id: int = 0) -> int:\n        device_props = torch.cuda.get_device_properties(device_id)\n        return int(device_props.total_memory)\n\n    @classmethod\n    def is_full_mtlink(cls, physical_device_ids: list[int]) -> bool:\n        logger.error(\n            \"MTLink detection not possible, as context support was\"\n            \" not found. 
Assuming no MTLink available.\"\n        )\n        return False\n\n\n# Autodetect either MTML-enabled or non-MTML platform\n# based on whether MTML is available.\nmtml_available = False\n\nif \"MUSA_DISABLE_MTML\" not in os.environ:\n    try:\n        try:\n            pymtml.nvmlInit()\n            mtml_available = True\n        except Exception:\n            mtml_available = False\n    finally:\n        if mtml_available:\n            pymtml.nvmlShutdown()\n\nMusaPlatform = MtmlMusaPlatform if mtml_available else NonMtmlMusaPlatform\n\ntry:\n    from sphinx.ext.autodoc.mock import _MockModule\n\n    if not isinstance(pymtml, _MockModule):\n        MusaPlatform.log_warnings()\nexcept ModuleNotFoundError:\n    MusaPlatform.log_warnings()\n\nif __name__ == \"__main__\":\n    print(MusaPlatform.__name__)\n    print(MusaPlatform.get_device_name())\n    print(MusaPlatform.get_device_capability())\n    print(MusaPlatform.get_device_total_memory())\n    print(MusaPlatform.is_full_mtlink([0, 1, 2, 3, 4, 5, 6, 7]))\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/platforms/npu.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm-ascend: https://github.com/vllm-project/vllm-ascend/blob/main/vllm_ascend/platform.py\n\nimport os\nfrom typing import Any\n\nimport torch\n\nfrom sglang.multimodal_gen import envs\nfrom sglang.multimodal_gen.runtime.platforms.interface import (\n    AttentionBackendEnum,\n    DeviceCapability,\n    Platform,\n    PlatformEnum,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\ndef device_id_to_physical_device_id(device_id: int) -> int:\n    if \"ASCEND_RT_VISIBLE_DEVICES\" in os.environ:\n        device_ids = os.environ[\"ASCEND_RT_VISIBLE_DEVICES\"].split(\",\")\n        if device_ids == [\"\"]:\n            msg = (\n                \"ASCEND_RT_VISIBLE_DEVICES is set to empty string, which means\"\n                \" NPU support is disabled\"\n            )\n            raise RuntimeError(msg)\n        physical_device_id = device_ids[device_id]\n        return int(physical_device_id)\n    else:\n        return device_id\n\n\nclass NPUPlatformBase(Platform):\n    _enum = PlatformEnum.NPU\n    device_name: str = \"npu\"\n    device_type: str = \"npu\"\n    dispatch_key: str = \"NPU\"\n    device_control_env_var: str = \"ASCEND_RT_VISIBLE_DEVICES\"\n\n    @classmethod\n    def get_local_torch_device(cls) -> torch.device:\n        return torch.device(f\"npu:{envs.LOCAL_RANK}\")\n\n    @classmethod\n    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:\n        return None\n\n    @classmethod\n    def get_device_name(cls, device_id: int = 0) -> str:\n        return str(torch.npu.get_device_name(device_id))\n\n    @classmethod\n    def get_device_total_memory(cls, device_id: int = 0) -> int:\n        device_props = torch.npu.get_device_properties(device_id)\n        return int(device_props.total_memory)\n\n    @classmethod\n    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:\n        
if enforce_eager:\n            logger.warning(\n                \"To see benefits of async output processing, enable NPU \"\n                \"graph. Since enforce-eager is enabled, async output \"\n                \"processor cannot be used\"\n            )\n            return False\n        return True\n\n    @classmethod\n    def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:\n        logger.error(\n            \"NVLink detection not possible, as context support was\"\n            \" not found. Assuming no NVLink available.\"\n        )\n        return False\n\n    @classmethod\n    def get_available_gpu_memory(\n        cls,\n        device_id: int = 0,\n        distributed: bool = False,\n        empty_cache: bool = True,\n        cpu_group: Any = None,\n    ) -> float:\n        if empty_cache:\n            torch.npu.empty_cache()\n\n        free_gpu_memory, _ = torch.npu.mem_get_info(device_id)\n\n        if distributed:\n            import torch.distributed as dist\n\n            tensor = torch.tensor(free_gpu_memory, dtype=torch.float32, device=\"npu\")\n            dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group)\n            free_gpu_memory = float(tensor.item())\n\n        return free_gpu_memory / (1 << 30)\n\n    @classmethod\n    def log_warnings(cls) -> None:\n        pass\n\n    @classmethod\n    def get_current_memory_usage(\n        cls, device: torch.types.Device | None = None\n    ) -> float:\n        torch.npu.reset_peak_memory_stats(device)\n        return float(torch.npu.max_memory_allocated(device))\n\n    @classmethod\n    def get_attn_backend_cls_str(\n        cls,\n        selected_backend: AttentionBackendEnum | None,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> str:\n        logger.info(\"Using Torch SDPA backend.\")\n        return (\n            \"sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend\"\n        )\n\n    @classmethod\n    def 
get_device_communicator_cls(cls) -> str:\n        return \"sglang.multimodal_gen.runtime.distributed.device_communicators.cuda_communicator.CudaCommunicator\"  # noqa\n\n    @classmethod\n    def enable_dit_layerwise_offload_for_wan_by_default(cls) -> bool:\n        \"\"\"The performance of the layerwise_offload feature depends on the device's memory size and the memory size occupied by the model. Use --dit-layerwise-offload True if it is suitable for your case.\"\"\"\n        return False\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/platforms/rocm.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from rocm/vllm: https://github.com/ROCm/vllm/blob/v0.7.3%2Brocm/vllm/platforms/rocm.py\n\"\"\"\nThis file is a platform abstraction for ROCm GPUs,\nadjusted to match the structure and interface of `cuda.py`.\n\"\"\"\n\nfrom functools import lru_cache\nfrom typing import Any\n\nimport torch\n\nimport sglang.multimodal_gen.envs as envs\nfrom sglang.multimodal_gen.runtime.platforms.interface import (\n    AttentionBackendEnum,\n    DeviceCapability,\n    Platform,\n    PlatformEnum,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n# ROCm uses the same torch.cuda interface\nclass RocmPlatform(Platform):\n    _enum = PlatformEnum.ROCM\n    device_name: str = \"rocm\"\n    device_type: str = \"cuda\"  # torch uses 'cuda' backend string\n    dispatch_key: str = \"CUDA\"\n    device_control_env_var: str = \"CUDA_VISIBLE_DEVICES\"\n\n    @classmethod\n    def get_local_torch_device(cls) -> torch.device:\n        return torch.device(f\"cuda:{envs.LOCAL_RANK}\")\n\n    @classmethod\n    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:\n        major, minor = torch.cuda.get_device_capability(device_id)\n        return DeviceCapability(major=major, minor=minor)\n\n    @classmethod\n    def get_device_name(cls, device_id: int = 0) -> str:\n        return str(torch.cuda.get_device_name(device_id))\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def get_device_total_memory(cls, device_id: int = 0) -> int:\n        return torch.cuda.get_device_properties(device_id).total_memory\n\n    @classmethod\n    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:\n        if enforce_eager:\n            logger.warning(\n                \"To see benefits of async output processing, enable CUDA graph. 
\"\n                \"Since enforce-eager is enabled, async output processor cannot be used\"\n            )\n            return False\n        return True\n\n    @classmethod\n    def log_warnings(cls) -> None:\n        pass  # ROCm-specific warnings can be added here\n\n    @classmethod\n    def get_current_memory_usage(cls, device: torch.device | None = None) -> float:\n        torch.cuda.reset_peak_memory_stats(device)\n        return float(torch.cuda.max_memory_allocated(device))\n\n    @classmethod\n    def get_available_gpu_memory(\n        cls,\n        device_id: int = 0,\n        distributed: bool = False,\n        empty_cache: bool = True,\n        cpu_group: Any = None,\n    ) -> float:\n        if empty_cache:\n            torch.cuda.empty_cache()\n\n        free_gpu_memory, _ = torch.cuda.mem_get_info(device_id)\n\n        if distributed:\n            import torch.distributed as dist\n\n            tensor = torch.tensor(free_gpu_memory, dtype=torch.float32, device=\"cuda\")\n            dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group)\n            free_gpu_memory = float(tensor.item())\n\n        return free_gpu_memory / (1 << 30)\n\n    @classmethod\n    def get_attn_backend_cls_str(\n        cls,\n        selected_backend: AttentionBackendEnum | None,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> str:\n        if selected_backend == AttentionBackendEnum.TORCH_SDPA:\n            logger.info(\"Using Torch SDPA backend.\")\n            return \"sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend\"\n\n        elif selected_backend in (AttentionBackendEnum.FA, None):\n            pass\n\n        elif selected_backend == AttentionBackendEnum.AITER:\n            if dtype not in (torch.float16, torch.bfloat16):\n                logger.warning(\n                    \"AITer backend works best with fp16/bf16 inputs but got dtype=%s. 
\"\n                    \"Proceeding with AITer anyway.\",\n                    dtype,\n                )\n            logger.info(\"Using AITer backend on ROCm.\")\n            return \"sglang.multimodal_gen.runtime.layers.attention.backends.aiter.AITerBackend\"\n\n        elif selected_backend == AttentionBackendEnum.AITER_SAGE:\n            if dtype in (torch.float16, torch.bfloat16):\n                logger.info(\"Using AITER Sage backend on ROCm.\")\n                return \"sglang.multimodal_gen.runtime.layers.attention.backends.aiter_sage.AITERSageBackend\"\n            else:\n                logger.warning(\n                    \"AITER Sage backend only supports bf16/fp16 inputs but got dtype=%s.\",\n                    dtype,\n                )\n\n        elif selected_backend in (\n            AttentionBackendEnum.SLIDING_TILE_ATTN,\n            AttentionBackendEnum.SAGE_ATTN,\n        ):\n            raise ValueError(\n                f\"{selected_backend.name} is not supported on {cls.device_name}.\"\n            )\n        elif selected_backend:\n            raise ValueError(\n                f\"Invalid attention backend for {cls.device_name}: {selected_backend}\"\n            )\n\n        target_backend = AttentionBackendEnum.FA\n        if dtype not in (torch.float16, torch.bfloat16):\n            logger.info(\n                \"Cannot use FlashAttention backend for dtype other than \"\n                \"torch.float16 or torch.bfloat16.\"\n            )\n            target_backend = AttentionBackendEnum.TORCH_SDPA\n\n        if target_backend == AttentionBackendEnum.FA:\n            try:\n                import flash_attn  # noqa: F401\n\n                from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import (  # noqa: F401\n                    FlashAttentionBackend,\n                )\n\n                supported_sizes = FlashAttentionBackend.get_supported_head_sizes()\n                if head_size not in supported_sizes:\n   
                 logger.info(\n                        \"Cannot use FlashAttention-2 backend for head size %d.\",\n                        head_size,\n                    )\n                    target_backend = AttentionBackendEnum.TORCH_SDPA\n            except ImportError:\n                logger.info(\n                    \"Cannot use FlashAttention backend because the \"\n                    \"flash_attn package is not found. \"\n                    \"Make sure that flash_attn was built and installed \"\n                    \"(on by default).\"\n                )\n                target_backend = AttentionBackendEnum.TORCH_SDPA\n\n        if target_backend == AttentionBackendEnum.TORCH_SDPA:\n            logger.info(\"Using Torch SDPA backend.\")\n\n            return \"sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend\"\n\n        logger.info(\"Using Flash Attention backend.\")\n\n        return \"sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn.FlashAttentionBackend\"\n\n    @classmethod\n    def get_device_communicator_cls(cls) -> str:\n        return \"sglang.multimodal_gen.runtime.distributed.device_communicators.cuda_communicator.CudaCommunicator\"  # works for ROCm too\n\n    @classmethod\n    def optimize_vae(cls, vae: torch.nn.Module) -> torch.nn.Module:\n        \"\"\"Replace nn.GroupNorm with AITer GroupNorm for improved ROCm VAE performance.\"\"\"\n        if not envs.SGLANG_USE_ROCM_VAE:\n            return vae\n        try:\n            from aiter.ops.groupnorm import GroupNorm as AiterGroupNorm\n\n            count = cls._replace_groupnorm(vae, AiterGroupNorm)\n            if count > 0:\n                logger.info(\n                    \"Replaced %d nn.GroupNorm modules with AITer GroupNorm in VAE\",\n                    count,\n                )\n        except Exception:\n            logger.warning(\n                \"Failed to apply AITer GroupNorm to VAE.\",\n                exc_info=True,\n           
 )\n        return vae\n\n    @staticmethod\n    def _replace_groupnorm(module: torch.nn.Module, aiter_gn_cls: type) -> int:\n        count = 0\n        for name, child in module.named_children():\n            if isinstance(child, torch.nn.GroupNorm) and child.affine:\n                replacement = aiter_gn_cls(\n                    num_groups=child.num_groups,\n                    num_channels=child.num_channels,\n                    eps=child.eps,\n                    affine=True,\n                    device=child.weight.device,\n                    dtype=child.weight.dtype,\n                )\n                replacement.weight = child.weight\n                replacement.bias = child.bias\n                setattr(module, name, replacement)\n                count += 1\n            else:\n                count += RocmPlatform._replace_groupnorm(child, aiter_gn_cls)\n        return count\n\n    @classmethod\n    def enable_dit_layerwise_offload_for_wan_by_default(cls) -> bool:\n        \"\"\"ROCm performs better without DIT layerwise offload on Wan.\"\"\"\n        return False\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/postprocess/__init__.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"Frame interpolation and upscaling support for SGLang diffusion pipelines.\"\"\"\n\nfrom sglang.multimodal_gen.runtime.postprocess.realesrgan_upscaler import (\n    ImageUpscaler,\n    upscale_frames,\n)\nfrom sglang.multimodal_gen.runtime.postprocess.rife_interpolator import (\n    FrameInterpolator,\n    interpolate_video_frames,\n)\n\n__all__ = [\n    \"FrameInterpolator\",\n    \"interpolate_video_frames\",\n    \"ImageUpscaler\",\n    \"upscale_frames\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/postprocess/realesrgan_upscaler.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nReal-ESRGAN upscaling for SGLang diffusion pipelines.\n\nReal-ESRGAN model code is vendored and adapted from:\n  - https://github.com/xinntao/Real-ESRGAN  (BSD-3-Clause License)\n  Copyright (c) 2021 xinntao\n\nThe ImageUpscaler wrapper and integration code are original work.\n\"\"\"\n\nimport math\nimport os\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n# Default HuggingFace repo and filename for Real-ESRGAN weights\n_DEFAULT_REALESRGAN_HF_REPO = \"ai-forever/Real-ESRGAN\"\n_DEFAULT_REALESRGAN_FILENAME = \"RealESRGAN_x4.pth\"\n\n# Module-level cache: model_path -> UpscalerModel instance\n_MODEL_CACHE: dict[str, \"UpscalerModel\"] = {}\n\n\n# ---------------------------------------------------------------------------\n# Vendored Real-ESRGAN architecture code\n# (SRVGGNetCompact, ResidualDenseBlock, RRDB, RRDBNet)\n# ---------------------------------------------------------------------------\n\n\nclass SRVGGNetCompact(nn.Module):\n    \"\"\"Compact VGG-style network for super resolution.\n\n    Corresponds to ``realesr-animevideov3`` and ``realesr-general-x4v3``.\n    Reference: xinntao/Real-ESRGAN (BSD-3-Clause).\n    \"\"\"\n\n    def __init__(\n        self,\n        num_in_ch: int = 3,\n        num_out_ch: int = 3,\n        num_feat: int = 64,\n        num_conv: int = 16,\n        upscale: int = 4,\n        act_type: str = \"prelu\",\n    ):\n        super().__init__()\n        self.num_in_ch = num_in_ch\n        self.num_out_ch = num_out_ch\n        self.num_feat = num_feat\n        self.num_conv = num_conv\n        self.upscale = upscale\n        self.act_type = act_type\n\n        self.body = nn.ModuleList()\n        # first conv\n        
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))\n        # first activation\n        self.body.append(self._make_act(act_type, num_feat))\n        # body convs + activations\n        for _ in range(num_conv):\n            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))\n            self.body.append(self._make_act(act_type, num_feat))\n        # last conv: maps to out_ch * upscale^2 for pixel shuffle\n        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))\n        self.upsampler = nn.PixelShuffle(upscale)\n\n    @staticmethod\n    def _make_act(act_type: str, num_feat: int) -> nn.Module:\n        if act_type == \"relu\":\n            return nn.ReLU(inplace=True)\n        elif act_type == \"prelu\":\n            return nn.PReLU(num_parameters=num_feat)\n        elif act_type == \"leakyrelu\":\n            return nn.LeakyReLU(negative_slope=0.1, inplace=True)\n        else:\n            raise ValueError(f\"Unsupported activation type: {act_type}\")\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        out = x\n        for layer in self.body:\n            out = layer(out)\n        out = self.upsampler(out)\n        # residual addition with nearest upsampled input\n        base = F.interpolate(x, scale_factor=self.upscale, mode=\"nearest\")\n        return out + base\n\n\nclass ResidualDenseBlock(nn.Module):\n    \"\"\"Residual Dense Block used in RRDB (RealESRGAN_x4plus).\"\"\"\n\n    def __init__(self, num_feat: int = 64, num_grow_ch: int = 32):\n        super().__init__()\n        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)\n        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)\n        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)\n        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)\n        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)\n        self.lrelu = nn.LeakyReLU(negative_slope=0.2, 
inplace=True)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x1 = self.lrelu(self.conv1(x))\n        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))\n        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))\n        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))\n        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))\n        return x5 * 0.2 + x\n\n\nclass RRDB(nn.Module):\n    \"\"\"Residual in Residual Dense Block.\"\"\"\n\n    def __init__(self, num_feat: int, num_grow_ch: int = 32):\n        super().__init__()\n        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)\n        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)\n        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        out = self.rdb1(x)\n        out = self.rdb2(out)\n        out = self.rdb3(out)\n        return out * 0.2 + x\n\n\nclass RRDBNet(nn.Module):\n    \"\"\"RRDB network for RealESRGAN_x4plus (heavier, higher quality for photos).\"\"\"\n\n    def __init__(\n        self,\n        num_in_ch: int = 3,\n        num_out_ch: int = 3,\n        scale: int = 4,\n        num_feat: int = 64,\n        num_block: int = 23,\n        num_grow_ch: int = 32,\n    ):\n        super().__init__()\n        self.scale = scale\n        in_ch = num_in_ch\n        if scale == 2:\n            in_ch = num_in_ch * 4\n        elif scale == 1:\n            in_ch = num_in_ch * 16\n        self.conv_first = nn.Conv2d(in_ch, num_feat, 3, 1, 1)\n        self.body = nn.Sequential(\n            *[RRDB(num_feat, num_grow_ch) for _ in range(num_block)]\n        )\n        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)\n        # upsample\n        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)\n        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)\n        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)\n        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 
1)\n        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.scale == 2:\n            feat = F.pixel_unshuffle(x, 2)\n        elif self.scale == 1:\n            feat = F.pixel_unshuffle(x, 4)\n        else:\n            feat = x\n        feat = self.conv_first(feat)\n        body_feat = self.conv_body(self.body(feat))\n        feat = feat + body_feat\n        feat = self.lrelu(\n            self.conv_up1(F.interpolate(feat, scale_factor=2, mode=\"nearest\"))\n        )\n        feat = self.lrelu(\n            self.conv_up2(F.interpolate(feat, scale_factor=2, mode=\"nearest\"))\n        )\n        return self.conv_last(self.lrelu(self.conv_hr(feat)))\n\n\n# ---------------------------------------------------------------------------\n# Architecture auto-detection\n# ---------------------------------------------------------------------------\n\n\ndef _build_net_from_state_dict(state_dict: dict) -> nn.Module:\n    \"\"\"Detect architecture from checkpoint keys and return an unloaded network.\"\"\"\n    if \"conv_first.weight\" in state_dict:\n        # RRDBNet (e.g., RealESRGAN_x4plus)\n        num_feat = state_dict[\"conv_first.weight\"].shape[0]\n        num_block = sum(\n            1\n            for k in state_dict\n            if k.startswith(\"body.\") and k.endswith(\".rdb1.conv1.weight\")\n        )\n        num_grow_ch = state_dict[\"body.0.rdb1.conv1.weight\"].shape[0]\n        logger.info(\n            \"Detected RRDBNet: num_feat=%d, num_block=%d, num_grow_ch=%d\",\n            num_feat,\n            num_block,\n            num_grow_ch,\n        )\n        return RRDBNet(\n            num_in_ch=3,\n            num_out_ch=3,\n            scale=4,\n            num_feat=num_feat,\n            num_block=num_block,\n            num_grow_ch=num_grow_ch,\n        )\n    else:\n        # SRVGGNetCompact (e.g., realesr-animevideov3)\n        num_feat = 
state_dict[\"body.0.weight\"].shape[0]\n        # body layout: [first_conv, first_act, (conv, act)*num_conv, last_conv]\n        # count 4-D weight tensors = first_conv + loop_convs + last_conv = num_conv + 2\n        conv_keys = sorted(\n            [\n                k\n                for k in state_dict\n                if k.startswith(\"body.\")\n                and k.endswith(\".weight\")\n                and state_dict[k].dim() == 4\n            ],\n            key=lambda k: int(k.split(\".\")[1]),\n        )\n        num_conv = len(conv_keys) - 2  # subtract first and last\n        # upscale from last conv output channels: out_ch = num_out_ch * upscale^2\n        last_out_ch = state_dict[conv_keys[-1]].shape[0]\n        upscale = int(math.sqrt(last_out_ch / 3))\n        logger.info(\n            \"Detected SRVGGNetCompact: num_feat=%d, num_conv=%d, upscale=%d\",\n            num_feat,\n            num_conv,\n            upscale,\n        )\n        return SRVGGNetCompact(\n            num_in_ch=3,\n            num_out_ch=3,\n            num_feat=num_feat,\n            num_conv=num_conv,\n            upscale=upscale,\n            act_type=\"prelu\",\n        )\n\n\n# ---------------------------------------------------------------------------\n# UpscalerModel\n# ---------------------------------------------------------------------------\n\n\nclass UpscalerModel:\n    \"\"\"Wraps a Real-ESRGAN network, provides load() and upscale() API.\"\"\"\n\n    def __init__(self, net: nn.Module, scale: int):\n        self.net = net\n        self.scale = scale  # the model's native upscaling factor (e.g. 
4)\n\n    @property\n    def device(self) -> torch.device:\n        return next(self.net.parameters()).device\n\n    def upscale(self, frame: np.ndarray, outscale: float | None = None) -> np.ndarray:\n        \"\"\"Upscale a single HWC uint8 frame → HWC uint8 frame.\n\n        Args:\n            frame:    Input HWC uint8 numpy array.\n            outscale: Desired final upscaling factor. If different from the\n                      model's native scale, a cheap resize is applied after\n                      the network output (same approach as the official\n                      Real-ESRGAN ``inference_realesrgan.py --outscale``).\n                      ``None`` means use the model's native scale as-is.\n        \"\"\"\n        h, w = frame.shape[:2]\n        img = frame.astype(np.float32) / 255.0\n        # Match the network's device and dtype (fp16 when half_precision=True);\n        # an fp32 input would otherwise fail against fp16 weights.\n        param = next(self.net.parameters())\n        img_t = (\n            torch.from_numpy(img)\n            .permute(2, 0, 1)\n            .unsqueeze(0)\n            .to(device=param.device, dtype=param.dtype)\n        )\n        with torch.no_grad():\n            out = self.net(img_t)\n\n        # If the desired outscale differs from the model's native scale,\n        # resize to (h * outscale, w * outscale).\n        if outscale is not None and outscale != self.scale:\n            target_h = int(h * outscale)\n            target_w = int(w * outscale)\n            out = F.interpolate(\n                out, size=(target_h, target_w), mode=\"bicubic\", align_corners=False\n            )\n\n        # Cast back to fp32 before converting to a uint8 numpy frame.\n        out_np = out.squeeze(0).permute(1, 2, 0).float().clamp(0.0, 1.0).cpu().numpy()\n        return (out_np * 255.0).astype(np.uint8)\n\n\n# ---------------------------------------------------------------------------\n# ImageUpscaler public class\n# ---------------------------------------------------------------------------\n\n\nclass ImageUpscaler:\n    \"\"\"\n    Lazy-loaded Real-ESRGAN upscaler.\n\n    Weights are downloaded and cached on first call to `.upscale()`.\n    Supports both SRVGGNetCompact (lightweight) and RRDBNet (heavier); the\n    architecture is auto-detected from the checkpoint keys.\n    \"\"\"\n\n    def __init__(\n        self,\n        model_path: 
Optional[str] = None,\n        scale: int = 4,\n        half_precision: bool = False,\n    ):\n        self._model_path = model_path\n        self._scale = scale\n        self._half_precision = half_precision\n\n    def _ensure_model_loaded(self) -> UpscalerModel:\n        \"\"\"Download/load Real-ESRGAN weights, detect arch, and cache globally.\"\"\"\n        model_path = self._model_path or _DEFAULT_REALESRGAN_HF_REPO\n\n        # Resolve: local .pth pass-through, or HF repo → download single file\n        resolved_path = _resolve_model_path(model_path)\n\n        if resolved_path in _MODEL_CACHE:\n            return _MODEL_CACHE[resolved_path]\n\n        logger.info(\"Loading Real-ESRGAN weights from %s\", resolved_path)\n        try:\n            state_dict = torch.load(\n                resolved_path, map_location=\"cpu\", weights_only=True\n            )\n        except Exception as e:\n            raise RuntimeError(\n                f\"Failed to load Real-ESRGAN checkpoint from '{resolved_path}'. \"\n                f\"The file may be corrupted or not a valid PyTorch checkpoint. \"\n                f\"Original error: {e}\"\n            ) from e\n\n        # Some checkpoints wrap weights under a 'params' or 'params_ema' key\n        if \"params_ema\" in state_dict:\n            state_dict = state_dict[\"params_ema\"]\n        elif \"params\" in state_dict:\n            state_dict = state_dict[\"params\"]\n\n        try:\n            net = _build_net_from_state_dict(state_dict)\n            net.load_state_dict(state_dict, strict=True)\n        except (RuntimeError, KeyError) as e:\n            raise RuntimeError(\n                f\"Real-ESRGAN weight file '{resolved_path}' is not compatible \"\n                f\"with the supported architectures (SRVGGNetCompact / RRDBNet). \"\n                f\"Please ensure you are using a valid Real-ESRGAN checkpoint. 
\"\n                f\"Original error: {e}\"\n            ) from e\n        net.eval()\n\n        device = current_platform.get_local_torch_device()\n        if self._half_precision:\n            net = net.half()\n        net = net.to(device)\n\n        # Detect the model's native scale from network architecture\n        native_scale = 4  # sensible default\n        if hasattr(net, \"upscale\"):\n            native_scale = net.upscale\n        elif hasattr(net, \"scale\"):\n            native_scale = net.scale\n\n        model = UpscalerModel(net=net, scale=native_scale)\n        _MODEL_CACHE[resolved_path] = model\n        logger.info(\n            \"Real-ESRGAN model loaded on device: %s (native_scale=%dx, outscale=%s)\",\n            device,\n            native_scale,\n            f\"{self._scale}x\" if self._scale != native_scale else \"native\",\n        )\n        return model\n\n    def upscale(self, frames: list[np.ndarray]) -> list[np.ndarray]:\n        \"\"\"Upscale a list of HWC uint8 frames.\n\n        Uses the model's native scale for super-resolution, then resizes to\n        the desired ``outscale`` if it differs (cheap bicubic resize).\n        \"\"\"\n        if not frames:\n            return frames\n        model = self._ensure_model_loaded()\n        outscale = self._scale if self._scale != model.scale else None\n        return [model.upscale(frame, outscale=outscale) for frame in frames]\n\n\n# ---------------------------------------------------------------------------\n# HF download helper\n# ---------------------------------------------------------------------------\n\n\ndef _resolve_model_path(model_path: str) -> str:\n    \"\"\"Return a local .pth file path.\n\n    Accepts:\n    - An existing local file path (pass-through).\n    - A HuggingFace ``repo_id`` → downloads the default weight file\n      (``RealESRGAN_x4.pth``).\n    - A HuggingFace ``repo_id:filename`` → downloads *filename* from *repo_id*,\n      allowing users to specify 
custom weight files hosted on HF.\n    \"\"\"\n    if os.path.isfile(model_path):\n        return model_path\n\n    # Parse optional \"repo_id:filename\" syntax; fall back to default filename.\n    if \":\" in model_path and not model_path.startswith(\"/\"):\n        repo_id, filename = model_path.split(\":\", 1)\n    else:\n        repo_id = model_path\n        filename = _DEFAULT_REALESRGAN_FILENAME\n\n    try:\n        from huggingface_hub import hf_hub_download\n    except ImportError as e:\n        raise ImportError(\n            \"huggingface_hub is required to download Real-ESRGAN weights. \"\n            \"Install it with: pip install huggingface_hub\"\n        ) from e\n\n    logger.info(\n        \"Downloading Real-ESRGAN weights from HF repo %s (file: %s)\",\n        repo_id,\n        filename,\n    )\n    try:\n        local_path = hf_hub_download(\n            repo_id=repo_id,\n            filename=filename,\n        )\n    except Exception as e:\n        raise FileNotFoundError(\n            f\"Failed to download Real-ESRGAN weights from HuggingFace repo \"\n            f\"'{repo_id}' (file: '{filename}'). If you are using a custom \"\n            f\"model, provide either a local .pth file path or use the \"\n            f\"'repo_id:filename' format (e.g. 'my-org/my-esrgan:weights.pth'). \"\n            f\"Original error: {e}\"\n        ) from e\n    return local_path\n\n\n# ---------------------------------------------------------------------------\n# Module-level convenience function\n# ---------------------------------------------------------------------------\n\n\ndef upscale_frames(\n    frames: list[np.ndarray],\n    model_path: Optional[str] = None,\n    scale: int = 4,\n    half_precision: bool = False,\n) -> list[np.ndarray]:\n    \"\"\"\n    Convenience wrapper around ImageUpscaler.\n\n    The model always runs at its native resolution (e.g. 4× for\n    ``RealESRGAN_x4.pth``).  
If *scale* differs from the native factor,\n    a cheap bicubic resize is applied after the network output – the same\n    approach used by the official Real-ESRGAN ``--outscale`` flag.\n\n    Args:\n        frames:         List of uint8 HWC numpy frames.\n        model_path:     Local .pth file, HuggingFace repo ID, or\n                        ``repo_id:filename`` for a custom weight file.\n                        None → default ``ai-forever/Real-ESRGAN`` with\n                        ``RealESRGAN_x4.pth``.\n        scale:          Desired final upscaling factor (e.g. 2, 3, 4).\n                        The model runs at its native factor (4× for the\n                        default weights); the output is resized to match\n                        *scale* when it differs.\n        half_precision: Use fp16 inference (faster on supported GPUs).\n\n    Returns:\n        List of upscaled uint8 HWC numpy frames.\n    \"\"\"\n    upscaler = ImageUpscaler(\n        model_path=model_path, scale=scale, half_precision=half_precision\n    )\n    return upscaler.upscale(frames)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/postprocess/rife_interpolator.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"\nRIFE 4.22.lite frame interpolation for SGLang diffusion pipelines.\n\nRIFE model code is vendored and adapted from:\n  - https://github.com/hzwer/ECCV2022-RIFE  (MIT License)\n  - https://github.com/hzwer/Practical-RIFE  (MIT License)\n  Copyright (c) 2021 Zhewei Huang\n\nThe FrameInterpolator wrapper and integration code are original work.\n\"\"\"\n\nimport os\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n# Default HuggingFace repo for RIFE 4.22.lite weights\n_DEFAULT_RIFE_HF_REPO = \"elfgum/RIFE-4.22.lite\"\n\n# Module-level cache: model_path -> Model instance\n_MODEL_CACHE: dict[str, \"Model\"] = {}\n\n\n# ---------------------------------------------------------------------------\n# Vendored RIFE 4.22.lite model code\n# (IFBlock, IFNet_HDv3 backbone, Model wrapper)\n# ---------------------------------------------------------------------------\n\n\ndef warp(tenInput: torch.Tensor, tenFlow: torch.Tensor) -> torch.Tensor:\n    \"\"\"Warp tenInput by tenFlow using grid_sample.\"\"\"\n    # Build base grid for the current size\n    tenHorizontal = (\n        torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device)\n        .view(1, 1, 1, tenFlow.shape[3])\n        .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)\n    )\n    tenVertical = (\n        torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device)\n        .view(1, 1, tenFlow.shape[2], 1)\n        .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])\n    )\n    tenGrid = torch.cat([tenHorizontal, tenVertical], dim=1)\n\n    tenFlow = torch.cat(\n        [\n            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),\n            tenFlow[:, 1:2, :, :] / 
((tenInput.shape[2] - 1.0) / 2.0),\n        ],\n        dim=1,\n    )\n\n    grid = (tenGrid + tenFlow).permute(0, 2, 3, 1)\n    return F.grid_sample(\n        input=tenInput,\n        grid=grid,\n        mode=\"bilinear\",\n        padding_mode=\"border\",\n        align_corners=True,\n    )\n\n\ndef _conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):\n    \"\"\"Conv2d + LeakyReLU helper (matches RIFE 4.22 conv()).\"\"\"\n    return nn.Sequential(\n        nn.Conv2d(\n            in_planes,\n            out_planes,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n            dilation=dilation,\n            bias=True,\n        ),\n        nn.LeakyReLU(0.2, True),\n    )\n\n\nclass ResConv(nn.Module):\n    \"\"\"Residual convolution block with learnable beta scaling (RIFE 4.22).\"\"\"\n\n    def __init__(self, c: int, dilation: int = 1):\n        super().__init__()\n        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)\n        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)\n        self.relu = nn.LeakyReLU(0.2, True)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.relu(self.conv(x) * self.beta + x)\n\n\nclass IFBlock(nn.Module):\n    \"\"\"Single-scale optical flow + mask + feature block (RIFE 4.22).\"\"\"\n\n    def __init__(self, in_planes: int, c: int = 64):\n        super().__init__()\n        self.conv0 = nn.Sequential(\n            _conv(in_planes, c // 2, 3, 2, 1),\n            _conv(c // 2, c, 3, 2, 1),\n        )\n        self.convblock = nn.Sequential(\n            ResConv(c),\n            ResConv(c),\n            ResConv(c),\n            ResConv(c),\n            ResConv(c),\n            ResConv(c),\n            ResConv(c),\n            ResConv(c),\n        )\n        self.lastconv = nn.Sequential(\n            nn.ConvTranspose2d(c, 4 * 13, 4, 2, 1),\n            nn.PixelShuffle(2),\n        )\n\n    
def forward(\n        self,\n        x: torch.Tensor,\n        flow: Optional[torch.Tensor] = None,\n        scale: float = 1.0,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        x = F.interpolate(\n            x, scale_factor=1.0 / scale, mode=\"bilinear\", align_corners=False\n        )\n        if flow is not None:\n            flow = (\n                F.interpolate(\n                    flow,\n                    scale_factor=1.0 / scale,\n                    mode=\"bilinear\",\n                    align_corners=False,\n                )\n                * 1.0\n                / scale\n            )\n            x = torch.cat((x, flow), 1)\n        feat = self.conv0(x)\n        feat = self.convblock(feat)\n        tmp = self.lastconv(feat)\n        tmp = F.interpolate(\n            tmp, scale_factor=scale, mode=\"bilinear\", align_corners=False\n        )\n        flow = tmp[:, :4] * scale\n        mask = tmp[:, 4:5]\n        feat = tmp[:, 5:]\n        return flow, mask, feat\n\n\nclass Head(nn.Module):\n    \"\"\"Feature encoder producing 4-channel features at full resolution (RIFE 4.22).\"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1)\n        self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1)\n        self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1)\n        self.cnn3 = nn.ConvTranspose2d(16, 4, 4, 2, 1)\n        self.relu = nn.LeakyReLU(0.2, True)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x0 = self.cnn0(x)\n        x = self.relu(x0)\n        x1 = self.cnn1(x)\n        x = self.relu(x1)\n        x2 = self.cnn2(x)\n        x = self.relu(x2)\n        x3 = self.cnn3(x)\n        return x3\n\n\nclass IFNet(nn.Module):\n    \"\"\"4-scale IFNet optical flow network (RIFE 4.22 backbone).\"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.block0 = IFBlock(7 + 8, c=192)\n        self.block1 = IFBlock(8 + 4 + 8 + 8, c=128)\n        self.block2 = IFBlock(8 + 4 + 8 
+ 8, c=64)\n        self.block3 = IFBlock(8 + 4 + 8 + 8, c=32)\n        self.encode = Head()\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        timestep: float = 0.5,\n        scale_list: Optional[list] = None,\n    ) -> tuple[list, torch.Tensor, list]:\n        if scale_list is None:\n            scale_list = [8, 4, 2, 1]\n\n        channel = x.shape[1] // 2\n        img0 = x[:, :channel]\n        img1 = x[:, channel:]\n\n        if not torch.is_tensor(timestep):\n            timestep = (x[:, :1].clone() * 0 + 1) * timestep\n        else:\n            timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])\n\n        f0 = self.encode(img0[:, :3])\n        f1 = self.encode(img1[:, :3])\n\n        flow_list = []\n        merged = []\n        mask_list = []\n        warped_img0 = img0\n        warped_img1 = img1\n        flow = None\n        mask = None\n\n        block = [self.block0, self.block1, self.block2, self.block3]\n        for i in range(4):\n            if flow is None:\n                flow, mask, feat = block[i](\n                    torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1),\n                    None,\n                    scale=scale_list[i],\n                )\n            else:\n                wf0 = warp(f0, flow[:, :2])\n                wf1 = warp(f1, flow[:, 2:4])\n                fd, m0, feat = block[i](\n                    torch.cat(\n                        (\n                            warped_img0[:, :3],\n                            warped_img1[:, :3],\n                            wf0,\n                            wf1,\n                            timestep,\n                            mask,\n                            feat,\n                        ),\n                        1,\n                    ),\n                    flow,\n                    scale=scale_list[i],\n                )\n                mask = m0\n                flow = flow + fd\n\n            mask_list.append(mask)\n           
 flow_list.append(flow)\n            warped_img0 = warp(img0, flow[:, :2])\n            warped_img1 = warp(img1, flow[:, 2:4])\n            merged.append((warped_img0, warped_img1))\n\n        mask = torch.sigmoid(mask)\n        merged[3] = warped_img0 * mask + warped_img1 * (1 - mask)\n\n        return flow_list, mask_list[3], merged\n\n\nclass Model:\n    \"\"\"Wraps IFNet, provides load_model() and inference() API.\"\"\"\n\n    def __init__(self):\n        self.flownet = IFNet()\n        self.device_type: str = \"cpu\"\n\n    def eval(self) -> \"Model\":\n        self.flownet.eval()\n        return self\n\n    def device(self) -> torch.device:\n        return next(self.flownet.parameters()).device\n\n    def load_model(self, path: str, strip_module_prefix: bool = True) -> None:\n        \"\"\"Load weights from {path}/flownet.pkl.\n\n        Args:\n            path: Directory containing ``flownet.pkl``.\n            strip_module_prefix: If True, strip the ``module.`` prefix that\n                ``DataParallel`` / ``DistributedDataParallel`` adds to keys.\n        \"\"\"\n        flownet_path = os.path.join(path, \"flownet.pkl\")\n        if not os.path.isfile(flownet_path):\n            raise FileNotFoundError(\n                f\"RIFE weight file not found: {flownet_path}\\n\"\n                \"Expected layout: <model_path>/flownet.pkl\"\n            )\n\n        def convert(param):\n            if strip_module_prefix:\n                return {\n                    k.replace(\"module.\", \"\"): v\n                    for k, v in param.items()\n                    if \"module.\" in k\n                }\n            else:\n                return {k: v for k, v in param.items() if \"module.\" not in k}\n\n        state = torch.load(flownet_path, map_location=\"cpu\", weights_only=False)\n        self.flownet.load_state_dict(convert(state), strict=False)\n        logger.info(\"Loaded RIFE weights from %s\", flownet_path)\n\n    def inference(\n        self,\n      
  img0: torch.Tensor,\n        img1: torch.Tensor,\n        scale: float = 1.0,\n        timestep: float = 0.5,\n    ) -> torch.Tensor:\n        \"\"\"Interpolate a single intermediate frame between img0 and img1.\"\"\"\n        n, c, h, w = img0.shape\n\n        # Pad to multiples of 32 so that RIFE's downsample/upsample round-trips\n        # preserve spatial dimensions exactly.\n        ph = ((h - 1) // 32 + 1) * 32\n        pw = ((w - 1) // 32 + 1) * 32\n        pad = (0, pw - w, 0, ph - h)\n        img0 = F.pad(img0, pad)\n        img1 = F.pad(img1, pad)\n\n        imgs = torch.cat((img0, img1), 1)\n        scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]\n        with torch.no_grad():\n            flow_list, mask, merged = self.flownet(\n                imgs,\n                timestep=timestep,\n                scale_list=scale_list,\n            )\n\n        # Crop back to original resolution\n        return merged[3][:, :, :h, :w]\n\n\n# ---------------------------------------------------------------------------\n# FrameInterpolator public class\n# ---------------------------------------------------------------------------\n\n\nclass FrameInterpolator:\n    \"\"\"\n    Lazy-loaded RIFE 4.22.lite frame interpolator.\n\n    Weights are loaded on first call to `.interpolate()` and cached globally\n    per model_path to avoid reloading across requests.\n    \"\"\"\n\n    def __init__(self, model_path: Optional[str] = None):\n        self._model_path = model_path\n        self._resolved_path: Optional[str] = None\n\n    def _ensure_model_loaded(self) -> Model:\n        \"\"\"Load RIFE model weights.\n\n        Accepts a local directory **or** a HuggingFace repo ID.  
When *None*\n        (the default) the weights are downloaded (and cached) automatically\n        from ``elfgum/RIFE-4.22.lite`` via ``maybe_download_model()``.\n        \"\"\"\n        from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (\n            maybe_download_model,\n        )\n\n        model_path = self._model_path or _DEFAULT_RIFE_HF_REPO\n\n        # Resolve: local path pass-through, HF repo ID → download & cache\n        model_path = maybe_download_model(model_path)\n\n        self._resolved_path = model_path\n\n        if model_path in _MODEL_CACHE:\n            return _MODEL_CACHE[model_path]\n\n        device = current_platform.get_local_torch_device()\n        model = Model()\n        model.load_model(model_path, strip_module_prefix=True)\n        model.eval()\n        model.flownet = model.flownet.to(device)\n        _MODEL_CACHE[model_path] = model\n        logger.info(\"RIFE model loaded on device: %s\", device)\n        return model\n\n    @staticmethod\n    def _frame_to_tensor(frame: np.ndarray, device: torch.device) -> torch.Tensor:\n        \"\"\"Convert uint8 HWC numpy frame to float32 CHW tensor on device.\"\"\"\n        t = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0\n        return t.to(device)\n\n    @staticmethod\n    def _tensor_to_frame(t: torch.Tensor) -> np.ndarray:\n        \"\"\"Convert float32 CHW tensor (batch=1) to uint8 HWC numpy frame.\"\"\"\n        arr = t.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0).cpu().numpy()\n        return (arr * 255.0).astype(np.uint8)\n\n    def _make_inference(\n        self, model: Model, I0: torch.Tensor, I1: torch.Tensor, n: int, scale: float\n    ) -> list[torch.Tensor]:\n        \"\"\"\n        Recursively generate 2 * n - 1 intermediate frames between I0 and I1.\n\n        *n* must be a power of two >= 1; ``n == 0`` would recurse forever.\n        Returns a list of intermediate frame tensors (not including I0 or I1).\n        \"\"\"\n        if n == 1:\n            return [model.inference(I0, I1, scale=scale)]\n        mid = 
model.inference(I0, I1, scale=scale)\n        return (\n            self._make_inference(model, I0, mid, n // 2, scale)\n            + [mid]\n            + self._make_inference(model, mid, I1, n // 2, scale)\n        )\n\n    def interpolate(\n        self,\n        frames: list[np.ndarray],\n        exp: int = 1,\n        scale: float = 1.0,\n    ) -> tuple[list[np.ndarray], int]:\n        \"\"\"\n        Interpolate frames using RIFE.\n\n        Args:\n            frames: List of uint8 numpy arrays with shape [H, W, 3].\n            exp:    Exponent for interpolation factor. 1 → 2×, 2 → 4×.\n            scale:  RIFE inference scale. Use 0.5 for high-resolution inputs.\n\n        Returns:\n            (interpolated_frames, multiplier) where multiplier = 2**exp.\n        \"\"\"\n        if len(frames) < 2:\n            logger.warning(\n                \"Frame interpolation requires at least 2 frames; returning input unchanged.\"\n            )\n            return frames, 1\n\n        model = self._ensure_model_loaded()\n        device = model.device()\n\n        n_intermediate = 2**exp // 2  # intermediates per adjacent pair\n\n        result: list[np.ndarray] = []\n        for i in range(len(frames) - 1):\n            I0 = self._frame_to_tensor(frames[i], device)\n            I1 = self._frame_to_tensor(frames[i + 1], device)\n\n            intermediate_tensors = self._make_inference(\n                model, I0, I1, n_intermediate, scale\n            )\n\n            result.append(frames[i])\n            for t in intermediate_tensors:\n                result.append(self._tensor_to_frame(t))\n\n        result.append(frames[-1])\n        multiplier = 2**exp\n        return result, multiplier\n\n\n# ---------------------------------------------------------------------------\n# Module-level convenience function\n# ---------------------------------------------------------------------------\n\n\ndef interpolate_video_frames(\n    frames: list[np.ndarray],\n    exp: int = 
1,\n    scale: float = 1.0,\n    model_path: Optional[str] = None,\n) -> tuple[list[np.ndarray], int]:\n    \"\"\"\n    Convenience wrapper around FrameInterpolator.\n\n    Args:\n        frames:     List of uint8 HWC numpy frames.\n        exp:        Interpolation exponent (1=2×, 2=4×).\n        scale:      RIFE inference scale (default 1.0; use 0.5 for high-res).\n        model_path: Local directory or HuggingFace repo ID containing\n                    ``flownet.pkl``.  *None* → default ``elfgum/RIFE-4.22.lite``.\n\n    Returns:\n        (interpolated_frames, multiplier)\n    \"\"\"\n    interpolator = FrameInterpolator(model_path=model_path)\n    return interpolator.interpolate(frames, exp=exp, scale=scale)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/scheduler_client.py",
    "content": "import pickle\nfrom typing import Any\n\nimport zmq\nimport zmq.asyncio\n\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\nasync def run_zeromq_broker(server_args: ServerArgs):\n    \"\"\"\n    This function runs as a background task in the FastAPI process.\n    It listens for TCP requests from offline clients (e.g., DiffGenerator).\n    \"\"\"\n    ctx = zmq.asyncio.Context()\n    # This is the REP socket that listens for requests from DiffGenerator\n    socket = ctx.socket(zmq.REP)\n    broker_endpoint = f\"tcp://*:{server_args.broker_port}\"\n    socket.bind(broker_endpoint)\n    logger.info(f\"ZMQ Broker is listening for offline jobs on {broker_endpoint}\")\n\n    while True:\n        try:\n            # 1. Receive a request from an offline client\n            payload = await socket.recv()\n            request_batch = pickle.loads(payload)\n            logger.info(\"Broker received an offline job from a client.\")\n\n            # 2. Forward the request to the main Scheduler via the shared client\n            response_batch = await async_scheduler_client.forward(request_batch)\n\n            # 3. 
Send the Scheduler's reply back to the offline client\n            await socket.send(pickle.dumps(response_batch))\n\n        except Exception as e:\n            logger.error(f\"Error in ZMQ Broker: {e}\", exc_info=True)\n            # A reply must be sent to prevent the client from hanging\n            try:\n                await socket.send(pickle.dumps({\"status\": \"error\", \"message\": str(e)}))\n            except Exception:\n                pass\n\n\nclass SchedulerClient:\n    \"\"\"\n    A synchronous, singleton client for communicating with the Scheduler service.\n    Designed for use in DiffGenerator, where synchronous usage is preferred\n    \"\"\"\n\n    def __init__(self):\n        self.context = None\n        self.scheduler_socket = None\n        self.server_args = None\n\n    def initialize(self, server_args: ServerArgs):\n        if self.context is not None and not self.context.closed:\n            logger.warning(\"SchedulerClient is already initialized. Re-initializing.\")\n            self.close()\n\n        self.server_args = server_args\n        self.context = zmq.Context()\n        self.scheduler_socket = self.context.socket(zmq.REQ)\n\n        # Set socket options for the main communication socket\n        self.scheduler_socket.setsockopt(zmq.LINGER, 0)\n\n        # 100 minute timeout for generation\n        self.scheduler_socket.setsockopt(zmq.RCVTIMEO, 6000000)\n\n        scheduler_endpoint = self.server_args.scheduler_endpoint\n        self.scheduler_socket.connect(scheduler_endpoint)\n        logger.debug(\n            f\"SchedulerClient connected to backend scheduler at {scheduler_endpoint}\"\n        )\n\n    def forward(self, batch: Any) -> Any:\n        \"\"\"Sends a batch or request to the scheduler and waits for the response.\"\"\"\n        try:\n            self.scheduler_socket.send_pyobj(batch)\n            output_batch = self.scheduler_socket.recv_pyobj()\n            return output_batch\n        except zmq.error.Again:\n       
     logger.error(\"Timeout waiting for response from scheduler.\")\n            raise TimeoutError(\"Scheduler did not respond in time.\")\n\n    def ping(self) -> bool:\n        \"\"\"\n        Checks if the scheduler server is alive using a temporary socket.\n        \"\"\"\n        if self.context is None or self.context.closed:\n            logger.error(\"Cannot ping: client is not initialized.\")\n            return False\n\n        ping_socket = self.context.socket(zmq.REQ)\n        ping_socket.setsockopt(zmq.LINGER, 0)\n        ping_socket.setsockopt(zmq.RCVTIMEO, 2000)  # 2-second timeout for pings\n\n        endpoint = self.server_args.scheduler_endpoint\n\n        try:\n            ping_socket.connect(endpoint)\n            ping_socket.send_pyobj({\"method\": \"ping\"})\n            ping_socket.recv_pyobj()\n            return True\n        except zmq.error.Again:\n            return False\n        finally:\n            ping_socket.close()\n\n    def close(self):\n        \"\"\"Closes the socket and terminates the context.\"\"\"\n        if self.scheduler_socket:\n            self.scheduler_socket.close()\n            self.scheduler_socket = None\n        if self.context:\n            self.context.term()\n            self.context = None\n\n\nclass AsyncSchedulerClient:\n    \"\"\"\n    An asynchronous, singleton client for communicating with the Scheduler service.\n    Designed for use in asynchronous environments like FastAPI entrypoints.\n\n    To support high concurrency, it creates a new REQ socket for each request\n    rather than sharing a single one (which would cause ZMQ state errors).\n    \"\"\"\n\n    def __init__(self):\n        self.context = None\n        self.server_args = None\n\n    def initialize(self, server_args: ServerArgs):\n        if self.context is not None and not self.context.closed:\n            logger.warning(\n                \"AsyncSchedulerClient is already initialized. 
Re-initializing.\"\n            )\n            self.close()\n\n        self.server_args = server_args\n        self.context = zmq.asyncio.Context()\n        logger.debug(\"AsyncSchedulerClient initialized with zmq.asyncio.Context\")\n\n    async def forward(self, batch: Any) -> Any:\n        \"\"\"Sends a batch or request to the scheduler and waits for the response.\"\"\"\n        if self.context is None:\n            raise RuntimeError(\n                \"AsyncSchedulerClient is not initialized. Call initialize() first.\"\n            )\n\n        # Create a temporary REQ socket for this request to allow concurrency\n        socket = self.context.socket(zmq.REQ)\n        socket.setsockopt(zmq.LINGER, 0)\n        # 100 minute timeout\n        socket.setsockopt(zmq.RCVTIMEO, 6000000)\n\n        endpoint = self.server_args.scheduler_endpoint\n        socket.connect(endpoint)\n\n        try:\n            await socket.send(pickle.dumps(batch))\n            payload = await socket.recv()\n            return pickle.loads(payload)\n        except zmq.error.Again:\n            logger.error(\"Timeout waiting for response from scheduler.\")\n            raise TimeoutError(\"Scheduler did not respond in time.\")\n        finally:\n            socket.close()\n\n    async def ping(self) -> bool:\n        \"\"\"\n        Checks if the scheduler server is alive using a temporary socket.\n        \"\"\"\n        if self.context is None or self.context.closed:\n            logger.error(\"Cannot ping: client is not initialized.\")\n            return False\n\n        ping_socket = self.context.socket(zmq.REQ)\n        ping_socket.setsockopt(zmq.LINGER, 0)\n        ping_socket.setsockopt(zmq.RCVTIMEO, 2000)\n\n        endpoint = self.server_args.scheduler_endpoint\n\n        try:\n            ping_socket.connect(endpoint)\n            await ping_socket.send(pickle.dumps({\"method\": \"ping\"}))\n            await ping_socket.recv()\n            return True\n        except 
zmq.error.Again:\n            return False\n        finally:\n            ping_socket.close()\n\n    def close(self):\n        \"\"\"Terminates the context (per-request sockets are already closed after each call).\"\"\"\n        if self.context:\n            self.context.term()\n            self.context = None\n\n\n# Singleton instances for easy access\nasync_scheduler_client = AsyncSchedulerClient()\nsync_scheduler_client = SchedulerClient()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/server_args.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Inspired by SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py\n\"\"\"The arguments of sglang-diffusion Inference.\"\"\"\n\nimport argparse\nimport dataclasses\nimport json\nimport math\nimport os\nimport random\nimport sys\nimport tempfile\nfrom dataclasses import field\nfrom enum import Enum\nfrom typing import Any, Optional\n\nimport addict\nimport yaml\n\nfrom sglang.multimodal_gen import envs\nfrom sglang.multimodal_gen.configs.models.encoders import T5Config\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import PipelineConfig\nfrom sglang.multimodal_gen.configs.quantization import NunchakuSVDQuantArgs\nfrom sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import (\n    NunchakuConfig,\n)\nfrom sglang.multimodal_gen.runtime.loader.utils import BYTES_PER_GB\nfrom sglang.multimodal_gen.runtime.platforms import (\n    AttentionBackendEnum,\n    current_platform,\n)\nfrom sglang.multimodal_gen.runtime.utils.common import (\n    is_port_available,\n    is_valid_ipv6_address,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import (\n    _sanitize_for_logging,\n    configure_logger,\n    init_logger,\n)\nfrom sglang.multimodal_gen.utils import (\n    FlexibleArgumentParser,\n    StoreBoolean,\n    expand_path_fields,\n    expand_path_kwargs,\n)\n\nlogger = init_logger(__name__)\n\n\nclass Backend(str, Enum):\n    \"\"\"\n    Enumeration for different model backends.\n    - AUTO: Automatically select backend (prefer sglang native, fallback to diffusers)\n    - SGLANG: Use sglang's native optimized implementation\n    - DIFFUSERS: Use vanilla diffusers pipeline (supports all diffusers models)\n    \"\"\"\n\n    AUTO = \"auto\"\n    SGLANG = \"sglang\"\n    DIFFUSERS = \"diffusers\"\n\n    @classmethod\n    def from_string(cls, value: str) -> \"Backend\":\n      
  \"\"\"Convert string to Backend enum.\"\"\"\n        try:\n            return cls(value.lower())\n        except ValueError:\n            raise ValueError(\n                f\"Invalid backend: {value}. Must be one of: {', '.join([m.value for m in cls])}\"\n            ) from None\n\n    @classmethod\n    def choices(cls) -> list[str]:\n        \"\"\"Get all available choices as strings for argparse.\"\"\"\n        return [backend.value for backend in cls]\n\n\n@dataclasses.dataclass\nclass ServerArgs:\n    # Model and path configuration (for convenience)\n    model_path: str\n\n    # explicit model ID override (e.g. \"Qwen-Image\")\n    model_id: str | None = None\n\n    # Model backend (sglang native or diffusers)\n    backend: Backend = Backend.AUTO\n\n    # Attention\n    attention_backend: str = None\n    attention_backend_config: addict.Dict | None = None\n    cache_dit_config: str | dict[str, Any] | None = (\n        None  # cache-dit config for diffusers\n    )\n\n    # Distributed executor backend\n    nccl_port: Optional[int] = None\n\n    # HuggingFace specific parameters\n    trust_remote_code: bool = False\n    revision: str | None = None\n\n    # Parallelism\n    num_gpus: int = 1\n    tp_size: Optional[int] = None\n    sp_degree: Optional[int] = None\n    # sequence parallelism\n    ulysses_degree: Optional[int] = None\n    ring_degree: Optional[int] = None\n    # data parallelism\n    # number of data parallelism groups\n    dp_size: int = 1\n    # number of gpu in a dp group\n    dp_degree: int = 1\n    # cfg parallel\n    enable_cfg_parallel: bool = False\n\n    hsdp_replicate_dim: int = 1\n    hsdp_shard_dim: Optional[int] = None\n    dist_timeout: int | None = 3600  # 1 hour\n\n    pipeline_config: PipelineConfig = field(default_factory=PipelineConfig, repr=False)\n\n    # Pipeline override\n    pipeline_class_name: str | None = (\n        None  # Override pipeline class from model_index.json\n    )\n\n    # LoRA parameters\n    # (Wenxuan) 
prefer to keep it here instead of in pipeline config to not make it complicated.\n    lora_path: str | None = None\n    lora_nickname: str = \"default\"  # for swapping adapters in the pipeline\n    lora_scale: float = 1.0  # LoRA scale for merging (e.g., 0.125 for Hyper-SD)\n\n    # Component path overrides (key = model_index.json component name, value = path)\n    component_paths: dict[str, str] = field(default_factory=dict)\n\n    # path to pre-quantized transformer weights (single .safetensors or directory).\n    transformer_weights_path: str | None = None\n    # can restrict layers to adapt, e.g. [\"q_proj\"]\n    # Will adapt only q, k, v, o by default.\n    lora_target_modules: list[str] | None = None\n\n    # CPU offload parameters\n    dit_cpu_offload: bool | None = None\n    dit_layerwise_offload: bool | None = None\n    dit_offload_prefetch_size: float = 0.0\n    text_encoder_cpu_offload: bool | None = None\n    image_encoder_cpu_offload: bool | None = None\n    vae_cpu_offload: bool | None = None\n    use_fsdp_inference: bool = False\n    pin_cpu_memory: bool = True\n\n    # ComfyUI integration\n    comfyui_mode: bool = False\n\n    # Compilation\n    enable_torch_compile: bool = False\n\n    # warmup\n    warmup: bool = False\n    warmup_resolutions: list[str] = None\n    warmup_steps: int = 1\n\n    disable_autocast: bool | None = None\n\n    # Quantization / Nunchaku SVDQuant configuration\n    nunchaku_config: NunchakuSVDQuantArgs | NunchakuConfig | None = field(\n        default_factory=NunchakuSVDQuantArgs, repr=False\n    )\n\n    # Master port for distributed inference\n    # TODO: do not hard code\n    master_port: int | None = None\n\n    # http server endpoint config\n    host: str | None = \"127.0.0.1\"\n    port: int | None = 30000\n\n    # TODO: webui and their endpoint, check if webui_port is available.\n    webui: bool = False\n    webui_port: int | None = 12312\n\n    scheduler_port: int = 5555\n\n    output_path: str | None = 
\"outputs/\"\n    input_save_path: str | None = \"inputs/uploads\"\n\n    # Prompt text file for batch processing\n    prompt_file_path: str | None = None\n\n    # model paths for correct deallocation\n    model_paths: dict[str, str] = field(default_factory=dict)\n    model_loaded: dict[str, bool] = field(\n        default_factory=lambda: {\n            \"transformer\": True,\n            \"vae\": True,\n            \"video_vae\": True,\n            \"audio_vae\": True,\n            \"video_dit\": True,\n            \"audio_dit\": True,\n            \"dual_tower_bridge\": True,\n        }\n    )\n\n    # # DMD parameters\n    # dmd_denoising_steps: List[int] | None = field(default=None)\n\n    # MoE parameters used by Wan2.2\n    boundary_ratio: float | None = None\n\n    # Logging\n    log_level: str = \"info\"\n\n    @property\n    def broker_port(self) -> int:\n        return self.port + 1\n\n    @property\n    def is_local_mode(self) -> bool:\n        \"\"\"\n        If no server is running when a generation task begins, 'local_mode' will be enabled: a dedicated server will be launched\n        \"\"\"\n        return self.host is None or self.port is None\n\n    def _adjust_path(self):\n        expand_path_fields(self)\n        self._adjust_save_paths()\n\n    def _adjust_parameters(self):\n        \"\"\"set defaults and normalize values.\"\"\"\n        self._adjust_offload()\n        self._adjust_path()\n        self._adjust_quant_config()\n        self._adjust_warmup()\n        self._adjust_network_ports()\n        # adjust parallelism before attention backend\n        self._adjust_parallelism()\n        self._adjust_attention_backend()\n        self._adjust_platform_specific()\n        self._adjust_autocast()\n        self.adjust_pipeline_config()\n\n    def _validate_parameters(self):\n        \"\"\"check consistency and raise errors for invalid configs\"\"\"\n        self._validate_pipeline()\n        self._validate_offload()\n        
self._validate_parallelism()\n        self._validate_cfg_parallel()\n\n    def _adjust_save_paths(self):\n        \"\"\"Normalize empty-string save paths to None (disabled).\"\"\"\n        if self.output_path is not None and self.output_path.strip() == \"\":\n            self.output_path = None\n        if self.input_save_path is not None and self.input_save_path.strip() == \"\":\n            self.input_save_path = None\n\n    def _adjust_quant_config(self):\n        \"\"\"validate and adjust\"\"\"\n\n        # nunchaku\n        ncfg = self.nunchaku_config\n        if ncfg is None or isinstance(ncfg, NunchakuConfig):\n            return\n        ncfg.validate()\n\n        # propagate the path to server_args\n        if ncfg.transformer_weights_path:\n            self.transformer_weights_path = ncfg.transformer_weights_path\n\n        if not ncfg.enable_svdquant or not ncfg.transformer_weights_path:\n            self.nunchaku_config = None\n        else:\n            self.nunchaku_config = NunchakuConfig(\n                precision=self.nunchaku_config.quantization_precision,\n                rank=self.nunchaku_config.quantization_rank,\n                act_unsigned=self.nunchaku_config.quantization_act_unsigned,\n                transformer_weights_path=self.nunchaku_config.transformer_weights_path,\n            )\n\n    def adjust_pipeline_config(self):\n        # enable parallel folding when SP is enabled\n        if self.tp_size != 1 or self.sp_degree <= 1:\n            return\n\n        enabled = False\n        for text_encoder_config in self.pipeline_config.text_encoder_configs:\n            if isinstance(text_encoder_config, T5Config):\n                text_encoder_config.parallel_folding = True\n                enabled = True\n                text_encoder_config.parallel_folding_mode = \"sp\"\n\n        if enabled:\n            logger.info(\n                \"Enabled T5 text encoder parallel folding (mode=sp) for %s (tp_size=%s, sp_degree=%s).\",\n           
     self.__class__.__name__,\n                self.tp_size,\n                self.sp_degree,\n            )\n\n    def _adjust_offload(self):\n        # TODO: to be handled by each platform\n        if current_platform.get_device_total_memory() / BYTES_PER_GB < 30:\n            logger.info(\"Enabling all offloading for GPU with low device memory\")\n            if self.dit_cpu_offload is None:\n                self.dit_cpu_offload = True\n            if self.text_encoder_cpu_offload is None:\n                self.text_encoder_cpu_offload = True\n            if self.image_encoder_cpu_offload is None:\n                self.image_encoder_cpu_offload = True\n            if self.vae_cpu_offload is None:\n                self.vae_cpu_offload = True\n        elif self.pipeline_config.task_type.is_image_gen():\n            logger.info(\n                \"Disabling some offloading (except dit, text_encoder) for image generation model\"\n            )\n            if self.dit_cpu_offload is None:\n                self.dit_cpu_offload = True\n            if self.text_encoder_cpu_offload is None:\n                self.text_encoder_cpu_offload = True\n            if self.image_encoder_cpu_offload is None:\n                self.image_encoder_cpu_offload = False\n            if self.vae_cpu_offload is None:\n                self.vae_cpu_offload = False\n        else:\n            if self.dit_cpu_offload is None:\n                self.dit_cpu_offload = True\n            if self.text_encoder_cpu_offload is None:\n                self.text_encoder_cpu_offload = True\n            if self.image_encoder_cpu_offload is None:\n                self.image_encoder_cpu_offload = True\n            if self.vae_cpu_offload is None:\n                self.vae_cpu_offload = True\n\n    def _adjust_attention_backend(self):\n        if self.attention_backend in [\"fa3\", \"fa4\"]:\n            self.attention_backend = \"fa\"\n\n        # attention_backend_config\n        if 
self.attention_backend_config is None:\n            self.attention_backend_config = addict.Dict()\n        elif isinstance(self.attention_backend_config, str):\n            self.attention_backend_config = addict.Dict(\n                self._parse_attention_backend_config(self.attention_backend_config)\n            )\n\n        if self.ring_degree > 1:\n            if self.attention_backend is not None and self.attention_backend not in (\n                \"fa\",\n                \"sage_attn\",\n            ):\n                raise ValueError(\n                    \"Ring Attention is only supported for flash attention or sage attention backend for now\"\n                )\n            if self.attention_backend is None:\n                self.attention_backend = \"fa\"\n                logger.info(\n                    \"Ring Attention is currently only supported for flash attention or sage attention; \"\n                    \"attention_backend has been automatically set to flash attention\"\n                )\n\n        if self.attention_backend is None and self.backend != Backend.DIFFUSERS:\n            self._set_default_attention_backend()\n\n    def _adjust_warmup(self):\n        if self.warmup_resolutions is not None:\n            self.warmup = True\n\n        if self.warmup:\n            logger.info(\n                \"Warmup enabled, the launch time is expected to be longer than usual\"\n            )\n\n    def _adjust_network_ports(self):\n        self.port = self.settle_port(self.port)\n        initial_scheduler_port = self.scheduler_port + (\n            random.randint(0, 100) if self.scheduler_port == 5555 else 0\n        )\n        self.scheduler_port = self.settle_port(initial_scheduler_port)\n        initial_master_port = (\n            self.master_port\n            if self.master_port is not None\n            else (30005 + random.randint(0, 100))\n        )\n        self.master_port = self.settle_port(initial_master_port, 37)\n\n    def 
_adjust_parallelism(self):\n        if self.tp_size is None:\n            self.tp_size = 1\n\n        if self.hsdp_shard_dim is None:\n            self.hsdp_shard_dim = self.num_gpus\n\n        # adjust sp_degree: allocate all remaining GPUs after TP and DP\n        if self.sp_degree is None:\n            num_gpus_per_group = self.dp_size * self.tp_size\n            if self.enable_cfg_parallel:\n                num_gpus_per_group *= 2\n            if self.num_gpus % num_gpus_per_group == 0:\n                self.sp_degree = self.num_gpus // num_gpus_per_group\n            else:\n                # Will be validated later\n                self.sp_degree = 1\n\n        if (\n            self.ulysses_degree is None\n            and self.ring_degree is None\n            and self.sp_degree != 1\n        ):\n            self.ulysses_degree = self.sp_degree\n            logger.info(\n                f\"Automatically set ulysses_degree=sp_degree={self.ulysses_degree} for best performance\"\n            )\n\n        if self.ulysses_degree is None:\n            self.ulysses_degree = 1\n            logger.debug(\n                f\"Ulysses degree not set, using default value {self.ulysses_degree}\"\n            )\n\n        if self.ring_degree is None:\n            self.ring_degree = 1\n            logger.debug(f\"Ring degree not set, using default value {self.ring_degree}\")\n\n    def _adjust_platform_specific(self):\n        if current_platform.is_mps():\n            self.use_fsdp_inference = False\n            self.dit_layerwise_offload = False\n\n        # automatically enable dit_layerwise_offload for Wan/MOVA models if appropriate\n        if not envs.SGLANG_CACHE_DIT_ENABLED:\n            pipeline_name_lower = self.pipeline_config.__class__.__name__.lower()\n            if (\n                (\"wan\" in pipeline_name_lower or \"mova\" in pipeline_name_lower)\n                and self.dit_layerwise_offload is None\n                and 
current_platform.enable_dit_layerwise_offload_for_wan_by_default()\n            ):\n                logger.info(\n                    f\"Automatically enable dit_layerwise_offload for {self.pipeline_config.__class__.__name__} \"\n                    \"for low memory and performance balance\"\n                )\n                self.dit_layerwise_offload = True\n\n    def _adjust_autocast(self):\n        if self.disable_autocast is None:\n            self.disable_autocast = not self.pipeline_config.enable_autocast\n\n    def _parse_attention_backend_config(self, config_str: str) -> dict[str, Any]:\n        \"\"\"parse attention backend config from string.\"\"\"\n        if not config_str:\n            return {}\n\n        # 1. treat as file path\n        if os.path.exists(config_str):\n            if config_str.endswith((\".yaml\", \".yml\")):\n                with open(config_str, \"r\") as f:\n                    return yaml.safe_load(f)\n            elif config_str.endswith(\".json\"):\n                with open(config_str, \"r\") as f:\n                    return json.load(f)\n\n        # 2. treat as JSON string\n        try:\n            return json.loads(config_str)\n        except json.JSONDecodeError:\n            pass\n\n        # 3. treat as k=v pairs (simple implementation). 
e.g., \"sparsity=0.5,enable_x=true\"\n        try:\n            config = {}\n            pairs = config_str.split(\",\")\n            for pair in pairs:\n                k, v = pair.split(\"=\", 1)\n                k = k.strip()\n                v = v.strip()\n                if v.lower() == \"true\":\n                    v = True\n                elif v.lower() == \"false\":\n                    v = False\n                elif v.replace(\".\", \"\", 1).isdigit():\n                    v = float(v) if \".\" in v else int(v)\n                config[k] = v\n            return config\n        except Exception:\n            raise ValueError(f\"Could not parse attention backend config: {config_str}\")\n\n    def __post_init__(self):\n        # configure logger before use\n        configure_logger(server_args=self)\n\n        # 1. adjust parameters\n        self._adjust_parameters()\n\n        # 2. Validate parameters\n        self._validate_parameters()\n\n        # log clean server_args\n        try:\n            safe_args = _sanitize_for_logging(self, key_hint=\"server_args\")\n            logger.info(\"server_args: %s\", json.dumps(safe_args, ensure_ascii=False))\n        except Exception:\n            # Fallback to default repr if sanitization fails\n            logger.info(f\"server_args: {self}\")\n\n    @staticmethod\n    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:\n        # Model and path configuration\n        parser.add_argument(\n            \"--model-path\",\n            type=str,\n            help=\"The path of the model weights. This can be a local folder or a Hugging Face repo ID.\",\n        )\n        parser.add_argument(\n            \"--model-id\",\n            type=str,\n            default=ServerArgs.model_id,\n            help=(\n                \"Override the model ID used for config resolution. 
\"\n                \"Useful when --model-path is a local directory whose name does not match \"\n                \"any registered HF repo name. Should be the repo name portion of the HF ID \"\n                \"(e.g. 'Qwen-Image' for 'Qwen/Qwen-Image').\"\n            ),\n        )\n        # attention\n        parser.add_argument(\n            \"--attention-backend\",\n            type=str,\n            default=None,\n            help=(\n                \"The attention backend to use. For SGLang-native pipelines, use \"\n                \"values like fa, torch_sdpa, sage_attn, etc. For diffusers pipelines, \"\n                \"use diffusers attention backend names such as flash, _flash_3_hub, \"\n                \"sage, or xformers.\"\n            ),\n        )\n        parser.add_argument(\n            \"--attention-backend-config\",\n            type=str,\n            default=None,\n            help=\"Configuration for the attention backend. Can be a JSON string, a path to a JSON/YAML file, or key=value pairs.\",\n        )\n        parser.add_argument(\n            \"--cache-dit-config\",\n            type=str,\n            default=ServerArgs.cache_dit_config,\n            help=\"Path to a Cache-DiT YAML/JSON config. 
Enables cache-dit for diffusers backend.\",\n        )\n\n        # HuggingFace specific parameters\n        parser.add_argument(\n            \"--trust-remote-code\",\n            action=StoreBoolean,\n            default=ServerArgs.trust_remote_code,\n            help=\"Trust remote code when loading HuggingFace models\",\n        )\n        parser.add_argument(\n            \"--revision\",\n            type=str,\n            default=ServerArgs.revision,\n            help=\"The specific model version to use (can be a branch name, tag name, or commit id)\",\n        )\n\n        # Parallelism\n        parser.add_argument(\n            \"--num-gpus\",\n            type=int,\n            default=ServerArgs.num_gpus,\n            help=\"The number of GPUs to use.\",\n        )\n        parser.add_argument(\n            \"--tp-size\",\n            type=int,\n            default=None,\n            help=\"The tensor parallelism size. Defaults to 1 if not specified.\",\n        )\n        parser.add_argument(\n            \"--sp-degree\",\n            type=int,\n            default=None,\n            help=\"The sequence parallelism size. If not specified, will use all remaining GPUs after accounting for TP and DP.\",\n        )\n        parser.add_argument(\n            \"--ulysses-degree\",\n            type=int,\n            default=ServerArgs.ulysses_degree,\n            help=\"Ulysses sequence parallel degree. Used in attention layer.\",\n        )\n        parser.add_argument(\n            \"--ring-degree\",\n            type=int,\n            default=ServerArgs.ring_degree,\n            help=\"Ring sequence parallel degree. 
Used in attention layer.\",\n        )\n        parser.add_argument(\n            \"--enable-cfg-parallel\",\n            action=\"store_true\",\n            default=ServerArgs.enable_cfg_parallel,\n            help=\"Enable CFG parallelism.\",\n        )\n        parser.add_argument(\n            \"--data-parallel-size\",\n            \"--dp-size\",\n            \"--dp\",\n            type=int,\n            default=ServerArgs.dp_size,\n            help=\"The data parallelism size.\",\n        )\n\n        parser.add_argument(\n            \"--hsdp-replicate-dim\",\n            type=int,\n            default=ServerArgs.hsdp_replicate_dim,\n            help=\"The replicate dimension size for HSDP (hybrid sharded data parallelism).\",\n        )\n        parser.add_argument(\n            \"--hsdp-shard-dim\",\n            type=int,\n            default=None,\n            help=\"The shard dimension size for HSDP. Defaults to num_gpus if not specified.\",\n        )\n        parser.add_argument(\n            \"--dist-timeout\",\n            type=int,\n            default=ServerArgs.dist_timeout,\n            help=\"Timeout for torch.distributed operations in seconds. \"\n            \"Increase this value if you encounter 'Connection closed by peer' errors after the service is idle.
\",\n        )\n\n        # Prompt text file for batch processing\n        parser.add_argument(\n            \"--prompt-file-path\",\n            type=str,\n            default=ServerArgs.prompt_file_path,\n            help=\"Path to a text file containing prompts (one per line) for batch processing\",\n        )\n\n        parser.add_argument(\n            \"--mask-strategy-file-path\",\n            type=str,\n            help=\"Path to mask strategy JSON file for STA\",\n        )\n        parser.add_argument(\n            \"--enable-torch-compile\",\n            action=StoreBoolean,\n            default=ServerArgs.enable_torch_compile,\n            help=\"Use torch.compile to speed up DiT inference. \"\n            + \"However, it will likely cause precision drift. See https://github.com/pytorch/pytorch/issues/145213\",\n        )\n\n        # warmup\n        parser.add_argument(\n            \"--warmup\",\n            action=StoreBoolean,\n            default=ServerArgs.warmup,\n            help=\"Perform some warmup after server starts (if `--warmup-resolutions` is specified) or before processing the first request (if `--warmup-resolutions` is not specified). \"\n            \"Recommended to enable when benchmarking to ensure fair comparison and best performance. \"\n            \"When enabled with `--warmup-resolutions` unspecified, look for the line ending with `(with warmup excluded)` for actual processing time.\",\n        )\n        parser.add_argument(\n            \"--warmup-resolutions\",\n            type=str,\n            nargs=\"+\",\n            default=ServerArgs.warmup_resolutions,\n            help=\"Specify resolutions for the server to warm up.
e.g., `--warmup-resolutions 256x256 720x720`\",\n        )\n        parser.add_argument(\n            \"--warmup-steps\",\n            type=int,\n            default=ServerArgs.warmup_steps,\n            help=\"The number of warmup steps to perform for each resolution.\",\n        )\n\n        parser.add_argument(\n            \"--dit-cpu-offload\",\n            action=StoreBoolean,\n            help=\"Use CPU offload for DiT inference. Enable if you run out of memory with FSDP.\",\n        )\n        parser.add_argument(\n            \"--dit-layerwise-offload\",\n            action=StoreBoolean,\n            default=ServerArgs.dit_layerwise_offload,\n            help=\"Enable layerwise CPU offload with async H2D prefetch overlap for supported DiT models (e.g., Wan, MOVA). \"\n            \"Cannot be used together with cache-dit (SGLANG_CACHE_DIT_ENABLED), dit_cpu_offload, or use_fsdp_inference.\",\n        )\n        parser.add_argument(\n            \"--dit-offload-prefetch-size\",\n            type=float,\n            default=ServerArgs.dit_offload_prefetch_size,\n            help=\"The prefetch size for --dit-layerwise-offload. If the value is between 0.0 and 1.0, it is treated as a ratio of the total number of layers. If the value is >= 1, it is treated as the absolute number of layers. 0.0 means prefetch 1 layer (lowest memory). Values above 0.5 may give peak memory close to running without offload, but with worse performance.\",\n        )\n        parser.add_argument(\n            \"--use-fsdp-inference\",\n            action=StoreBoolean,\n            help=\"Use FSDP for inference by sharding the model weights. The added latency is very low due to prefetching; enable if you run out of memory.\",\n        )\n        parser.add_argument(\n            \"--text-encoder-cpu-offload\",\n            action=StoreBoolean,\n            help=\"Use CPU offload for the text encoder.
Enable if you run out of memory.\",\n        )\n        parser.add_argument(\n            \"--image-encoder-cpu-offload\",\n            action=StoreBoolean,\n            help=\"Use CPU offload for the image encoder. Enable if you run out of memory.\",\n        )\n        parser.add_argument(\n            \"--vae-cpu-offload\",\n            action=StoreBoolean,\n            help=\"Use CPU offload for the VAE. Enable if you run out of memory.\",\n        )\n        parser.add_argument(\n            \"--pin-cpu-memory\",\n            action=StoreBoolean,\n            help='Pin memory for CPU offload. Only added as a temp workaround if it throws \"CUDA error: invalid argument\". '\n            \"Should be enabled in almost all cases\",\n        )\n        parser.add_argument(\n            \"--disable-autocast\",\n            action=StoreBoolean,\n            help=\"Disable autocast for the denoising loop and VAE decoding in pipeline sampling\",\n        )\n\n        # Nunchaku SVDQuant quantization parameters\n        NunchakuSVDQuantArgs.add_cli_args(parser)\n\n        # Master port for distributed inference\n        parser.add_argument(\n            \"--master-port\",\n            type=int,\n            default=ServerArgs.master_port,\n            help=\"Master port for distributed inference.
If not set, a random free port will be used.\",\n        )\n        parser.add_argument(\n            \"--scheduler-port\",\n            type=int,\n            default=ServerArgs.scheduler_port,\n            help=\"Port for the scheduler server.\",\n        )\n        parser.add_argument(\n            \"--host\",\n            type=str,\n            default=ServerArgs.host,\n            help=\"Host for the HTTP API server.\",\n        )\n        parser.add_argument(\n            \"--port\",\n            type=int,\n            default=ServerArgs.port,\n            help=\"Port for the HTTP API server.\",\n        )\n        parser.add_argument(\n            \"--webui\",\n            action=StoreBoolean,\n            default=ServerArgs.webui,\n            help=\"Whether to use webui for better display\",\n        )\n\n        parser.add_argument(\n            \"--webui-port\",\n            type=int,\n            default=ServerArgs.webui_port,\n            help=\"Port for the webui.\",\n        )\n        parser.add_argument(\n            \"--output-path\",\n            type=str,\n            default=ServerArgs.output_path,\n            help='Directory path to save generated images/videos. Set to \"\" to disable persistent saving.',\n        )\n        parser.add_argument(\n            \"--input-save-path\",\n            type=str,\n            default=ServerArgs.input_save_path,\n            help='Directory path to save uploaded input images/videos.
Set to \"\" to disable persistent saving.',\n        )\n\n        # LoRA\n        parser.add_argument(\n            \"--lora-path\",\n            type=str,\n            default=ServerArgs.lora_path,\n            help=\"The path to the LoRA adapter weights (can be local file path or HF hub id) to launch with\",\n        )\n        parser.add_argument(\n            \"--lora-nickname\",\n            type=str,\n            default=ServerArgs.lora_nickname,\n            help=\"The nickname for the LoRA adapter to launch with\",\n        )\n        parser.add_argument(\n            \"--lora-scale\",\n            type=float,\n            default=ServerArgs.lora_scale,\n            help=\"LoRA scale for merging (e.g., 0.125 for Hyper-SD). Same as lora_scale in Diffusers\",\n        )\n        # Add pipeline configuration arguments\n        PipelineConfig.add_cli_args(parser)\n\n        # Logging\n        parser.add_argument(\n            \"--log-level\",\n            type=str,\n            default=ServerArgs.log_level,\n            help=\"The logging level of all loggers.\",\n        )\n        parser.add_argument(\n            \"--backend\",\n            type=str,\n            choices=Backend.choices(),\n            default=ServerArgs.backend.value,\n            help=\"The model backend to use. 'auto' prefers sglang native and falls back to diffusers. \"\n            \"'sglang' uses native optimized implementation. 
'diffusers' uses vanilla diffusers pipeline.\",\n        )\n        return parser\n\n    def url(self):\n        host = self.host\n        if not host or host == \"0.0.0.0\":\n            host = \"127.0.0.1\"\n        elif host == \"::\":\n            host = \"::1\"\n        if is_valid_ipv6_address(host):\n            return f\"http://[{host}]:{self.port}\"\n        else:\n            return f\"http://{host}:{self.port}\"\n\n    @property\n    def scheduler_endpoint(self):\n        \"\"\"\n        Internal endpoint for scheduler.\n        Prefers the configured host but normalizes localhost -> 127.0.0.1 to avoid ZMQ issues.\n        \"\"\"\n        scheduler_host = self.host\n        if scheduler_host is None or scheduler_host == \"localhost\":\n            scheduler_host = \"127.0.0.1\"\n        return f\"tcp://{scheduler_host}:{self.scheduler_port}\"\n\n    def settle_port(\n        self, port: int, port_inc: int = 42, max_attempts: int = 100\n    ) -> int:\n        \"\"\"\n        Find an available port with retry logic.\n        \"\"\"\n        attempts = 0\n        original_port = port\n\n        while attempts < max_attempts:\n            if is_port_available(port):\n                if attempts > 0:\n                    logger.info(\n                        f\"Port {original_port} was unavailable, using port {port} instead\"\n                    )\n                return port\n\n            attempts += 1\n            if port < 60000:\n                port += port_inc\n            else:\n                # Wrap around with randomization to avoid collision\n                port = 5000 + random.randint(0, 1000)\n\n        raise RuntimeError(\n            f\"Failed to find available port after {max_attempts} attempts \"\n            f\"(started from port {original_port})\"\n        )\n\n    @staticmethod\n    def _extract_component_paths(\n        unknown_args: list[str],\n    ) -> tuple[dict[str, str], list[str]]:\n        \"\"\"\n        Extract dynamic 
``--<component>-path`` args from unrecognised CLI args.\n        \"\"\"\n        component_paths: dict[str, str] = {}\n        remaining: list[str] = []\n        i = 0\n        while i < len(unknown_args):\n            arg = unknown_args[i]\n            key_part = arg.split(\"=\", 1)[0] if \"=\" in arg else arg\n            if key_part.startswith(\"--\") and key_part.endswith(\"-path\"):\n                component = key_part[2:-5].replace(\"-\", \"_\")\n                if \"=\" in arg:\n                    component_paths[component] = arg.split(\"=\", 1)[1]\n                elif i + 1 < len(unknown_args) and not unknown_args[i + 1].startswith(\n                    \"-\"\n                ):\n                    i += 1\n                    component_paths[component] = unknown_args[i]\n                else:\n                    remaining.append(arg)\n                    i += 1\n                    continue\n            else:\n                remaining.append(arg)\n            i += 1\n\n        # canonicalize and validate\n        for component, path in component_paths.items():\n            path = os.path.expanduser(path)\n            component_paths[component] = path\n        return component_paths, remaining\n\n    @classmethod\n    def from_cli_args(\n        cls, args: argparse.Namespace, unknown_args: list[str] | None = None\n    ) -> \"ServerArgs\":\n        if unknown_args is None:\n            unknown_args = []\n\n        # extract dynamic --<component>-path from unknown args\n        dynamic_paths, remaining = cls._extract_component_paths(unknown_args)\n        if remaining:\n            raise SystemExit(f\"error: unrecognized arguments: {' '.join(remaining)}\")\n\n        provided_args = cls.get_provided_args(args, unknown_args)\n\n        # Handle config file\n        config_file = provided_args.get(\"config\")\n        if config_file:\n            config_args = cls.load_config_file(config_file)\n            provided_args = {**config_args, 
**provided_args}\n\n        if dynamic_paths:\n            existing = dict(provided_args.get(\"component_paths\") or {})\n            existing.update(dynamic_paths)\n            provided_args[\"component_paths\"] = existing\n\n        return cls.from_dict(provided_args)\n\n    @classmethod\n    def from_dict(cls, kwargs: dict[str, Any]) -> \"ServerArgs\":\n        \"\"\"Create a ServerArgs object from a dictionary.\"\"\"\n        kwargs = expand_path_kwargs(dict(kwargs))\n        attrs = [attr.name for attr in dataclasses.fields(cls)]\n        server_args_kwargs: dict[str, Any] = {}\n\n        component_paths = dict(kwargs.get(\"component_paths\") or {})\n        if component_paths:\n            server_args_kwargs[\"component_paths\"] = component_paths\n\n        for attr in attrs:\n            if attr == \"pipeline_config\":\n                pipeline_config = PipelineConfig.from_kwargs(kwargs)\n                logger.debug(f\"Using PipelineConfig: {type(pipeline_config)}\")\n                server_args_kwargs[\"pipeline_config\"] = pipeline_config\n            elif attr == \"nunchaku_config\":\n                nunchaku_config = NunchakuSVDQuantArgs.from_dict(kwargs)\n                server_args_kwargs[\"nunchaku_config\"] = nunchaku_config\n            elif attr in kwargs:\n                server_args_kwargs[attr] = kwargs[attr]\n\n        return cls(**server_args_kwargs)\n\n    @staticmethod\n    def load_config_file(config_file: str) -> dict[str, Any]:\n        \"\"\"Load a config file.\"\"\"\n        if config_file.endswith(\".json\"):\n            with open(config_file, \"r\") as f:\n                return json.load(f)\n        elif config_file.endswith((\".yaml\", \".yml\")):\n            try:\n                import yaml\n            except ImportError:\n                raise ImportError(\n                    \"Please install PyYAML to use YAML config files. 
\"\n                    \"`pip install pyyaml`\"\n                )\n            with open(config_file, \"r\") as f:\n                return yaml.safe_load(f)\n        else:\n            raise ValueError(f\"Unsupported config file format: {config_file}\")\n\n    @classmethod\n    def from_kwargs(cls, **kwargs: Any) -> \"ServerArgs\":\n        # Convert backend string to enum if necessary\n        if \"backend\" in kwargs and isinstance(kwargs[\"backend\"], str):\n            kwargs[\"backend\"] = Backend.from_string(kwargs[\"backend\"])\n\n        kwargs[\"pipeline_config\"] = PipelineConfig.from_kwargs(kwargs)\n        return cls(**kwargs)\n\n    @staticmethod\n    def get_provided_args(\n        args: argparse.Namespace, unknown_args: list[str]\n    ) -> dict[str, Any]:\n        \"\"\"Get the arguments provided by the user.\"\"\"\n        provided_args = {}\n        # We need to check against the raw command-line arguments to see what was\n        # explicitly provided by the user, vs. 
what's a default value from argparse.\n        raw_argv = sys.argv + unknown_args\n\n        # Create a set of argument names that were present on the command line.\n        # This handles both styles: '--arg=value' and '--arg value'.\n        provided_arg_names = set()\n        for arg in raw_argv:\n            if arg.startswith(\"--\"):\n                # For '--arg=value', this gets 'arg'; for '--arg', this also gets 'arg'.\n                arg_name = arg.split(\"=\", 1)[0].replace(\"-\", \"_\").lstrip(\"_\")\n                provided_arg_names.add(arg_name)\n\n        # Populate provided_args if the argument from the namespace was on the command line.\n        for k, v in vars(args).items():\n            if k in provided_arg_names:\n                provided_args[k] = v\n\n        return provided_args\n\n    def _validate_pipeline(self):\n        if self.pipeline_config is None:\n            raise ValueError(\"pipeline_config is not set in ServerArgs\")\n\n        self.pipeline_config.check_pipeline_config()\n\n    def _validate_offload(self):\n        # validate dit_offload_prefetch_size\n        if self.dit_offload_prefetch_size > 1 and (\n            isinstance(self.dit_offload_prefetch_size, float)\n            and not self.dit_offload_prefetch_size.is_integer()\n        ):\n            self.dit_offload_prefetch_size = int(\n                math.floor(self.dit_offload_prefetch_size)\n            )\n            logger.info(\n                f\"Invalid --dit-offload-prefetch-size value passed, truncated to: {self.dit_offload_prefetch_size}\"\n            )\n\n        if 0.5 <= self.dit_offload_prefetch_size < 1.0:\n            logger.info(\n                \"We do not recommend --dit-offload-prefetch-size to be between 0.5 and 1.0\"\n            )\n\n        # validate dit_layerwise_offload conflicts\n        if self.dit_layerwise_offload:\n            if self.dit_offload_prefetch_size < 0.0:\n                raise ValueError(\"dit_offload_prefetch_size must 
be non-negative\")\n\n            if self.use_fsdp_inference:\n                logger.warning(\n                    \"dit_layerwise_offload is enabled, automatically disabling use_fsdp_inference.\"\n                )\n                self.use_fsdp_inference = False\n\n            if self.dit_cpu_offload is None:\n                logger.warning(\n                    \"dit_layerwise_offload is enabled, automatically disabling dit_cpu_offload.\"\n                )\n                self.dit_cpu_offload = False\n\n            if envs.SGLANG_CACHE_DIT_ENABLED:\n                raise ValueError(\n                    \"dit_layerwise_offload cannot be enabled together with cache-dit. \"\n                    \"cache-dit may reuse skipped blocks whose weights have been released by layerwise offload, \"\n                    \"causing shape mismatch errors. \"\n                    \"Please disable either --dit-layerwise-offload or SGLANG_CACHE_DIT_ENABLED.\"\n                )\n\n    def _validate_parallelism(self):\n        if self.sp_degree > self.num_gpus or self.num_gpus % self.sp_degree != 0:\n            raise ValueError(\n                f\"num_gpus ({self.num_gpus}) must be >= and divisible by sp_degree ({self.sp_degree})\"\n            )\n\n        if (\n            self.hsdp_replicate_dim > self.num_gpus\n            or self.num_gpus % self.hsdp_replicate_dim != 0\n        ):\n            raise ValueError(\n                f\"num_gpus ({self.num_gpus}) must be >= and divisible by hsdp_replicate_dim ({self.hsdp_replicate_dim})\"\n            )\n\n        if (\n            self.hsdp_shard_dim > self.num_gpus\n            or self.num_gpus % self.hsdp_shard_dim != 0\n        ):\n            raise ValueError(\n                f\"num_gpus ({self.num_gpus}) must be >= and divisible by hsdp_shard_dim ({self.hsdp_shard_dim})\"\n            )\n\n        if self.dp_size < 1:\n            raise ValueError(\"--dp-size must be a positive integer\")\n\n        if self.num_gpus % self.dp_size != 0:\n            raise ValueError(\n                f\"num_gpus 
({self.num_gpus}) must be
divisible by dp_size ({self.dp_size})\"\n            )\n\n        if self.dp_size > 1:\n            raise ValueError(\"DP is not yet supported\")\n\n        num_gpus_per_group = self.dp_size * self.tp_size\n        if self.enable_cfg_parallel:\n            num_gpus_per_group *= 2\n\n        if self.num_gpus % num_gpus_per_group != 0:\n            raise ValueError(\n                f\"num_gpus ({self.num_gpus}) must be divisible by (dp_size * tp_size{' * 2' if self.enable_cfg_parallel else ''}) = {num_gpus_per_group}\"\n            )\n\n        if self.sp_degree != self.ring_degree * self.ulysses_degree:\n            raise ValueError(\n                f\"sp_degree ({self.sp_degree}) must equal ring_degree * ulysses_degree \"\n                f\"({self.ring_degree} * {self.ulysses_degree} = {self.ring_degree * self.ulysses_degree})\"\n            )\n\n        if os.getenv(\"SGLANG_CACHE_DIT_ENABLED\", \"\").lower() == \"true\":\n            has_sp = self.sp_degree > 1\n            has_tp = self.tp_size > 1\n            if has_sp and has_tp:\n                logger.warning(\n                    \"cache-dit is enabled with hybrid parallelism (SP + TP).
\"\n                    \"Proceeding anyway (SGLang integration may support this mode).\"\n                )\n\n    def _validate_cfg_parallel(self):\n        if self.enable_cfg_parallel and self.num_gpus == 1:\n            raise ValueError(\n                \"CFG Parallelism is enabled via `--enable-cfg-parallel`, but num_gpus == 1\"\n            )\n\n    def _set_default_attention_backend(self) -> None:\n        \"\"\"Configure ROCm defaults when users do not specify an attention backend.\"\"\"\n        if current_platform.is_rocm():\n            default_backend = AttentionBackendEnum.AITER.name.lower()\n            self.attention_backend = default_backend\n            logger.info(\n                \"Attention backend not specified. Using '%s' by default on ROCm \"\n                \"to match SGLang SRT defaults.\",\n                default_backend,\n            )\n\n\n@dataclasses.dataclass\nclass PortArgs:\n    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)\n    scheduler_input_ipc_name: str\n\n    # The port for nccl initialization (torch.dist)\n    nccl_port: int\n\n    # The ipc filename for rpc call between Engine and Scheduler\n    rpc_ipc_name: str\n\n    # The ipc filename for Scheduler to send metrics\n    metrics_ipc_name: str\n\n    # Master port for distributed inference\n    master_port: int | None = None\n\n    @staticmethod\n    def from_server_args(\n        server_args: ServerArgs, dp_rank: Optional[int] = None\n    ) -> \"PortArgs\":\n        if server_args.nccl_port is None:\n            nccl_port = server_args.scheduler_port + random.randint(100, 1000)\n            while True:\n                if is_port_available(nccl_port):\n                    break\n                if nccl_port < 60000:\n                    nccl_port += 42\n                else:\n                    nccl_port -= 43\n        else:\n            nccl_port = server_args.nccl_port\n\n        # Normal case, use IPC within a single node\n        
return PortArgs(\n            scheduler_input_ipc_name=f\"ipc://{tempfile.NamedTemporaryFile(delete=False).name}\",\n            nccl_port=nccl_port,\n            rpc_ipc_name=f\"ipc://{tempfile.NamedTemporaryFile(delete=False).name}\",\n            metrics_ipc_name=f\"ipc://{tempfile.NamedTemporaryFile(delete=False).name}\",\n            master_port=server_args.master_port,\n        )\n\n\n_global_server_args = None\n\n\ndef prepare_server_args(argv: list[str]) -> ServerArgs:\n    \"\"\"\n    Prepare the inference arguments from the command line arguments.\n    \"\"\"\n    parser = FlexibleArgumentParser()\n    ServerArgs.add_cli_args(parser)\n    raw_args, unknown_args = parser.parse_known_args(argv)\n    server_args = ServerArgs.from_cli_args(raw_args, unknown_args)\n    return server_args\n\n\ndef set_global_server_args(server_args: ServerArgs):\n    \"\"\"\n    Set the global sgl_diffusion config for each process\n    \"\"\"\n    global _global_server_args\n    _global_server_args = server_args\n\n\ndef get_global_server_args() -> ServerArgs:\n    if _global_server_args is None:\n        # In CI, custom ops/modules are often tested directly without setting\n        # the sgl_diffusion config; such callers currently hit the error below.\n        # TODO(will): may need to handle this for CI (e.g. fall back to a default config).\n        raise ValueError(\"Global sgl_diffusion args are not set.\")\n    return _global_server_args\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/utils/common.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nimport ipaddress\nimport logging\nimport os\nimport platform\nimport signal\nimport socket\nimport sys\nimport threading\nfrom functools import lru_cache\n\nimport psutil\nimport torch\nimport zmq\n\n# use the native logger to avoid circular import\nlogger = logging.getLogger(__name__)\n\n\ndef kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):\n    \"\"\"Kill the process and all its child processes.\"\"\"\n    # Remove sigchld handler to avoid spammy logs.\n    if threading.current_thread() is threading.main_thread():\n        signal.signal(signal.SIGCHLD, signal.SIG_DFL)\n\n    if parent_pid is None:\n        parent_pid = os.getpid()\n        include_parent = False\n\n    try:\n        itself = psutil.Process(parent_pid)\n    except psutil.NoSuchProcess:\n        return\n\n    children = itself.children(recursive=True)\n    for child in children:\n        if child.pid == skip_pid:\n            continue\n        try:\n            child.kill()\n        except psutil.NoSuchProcess:\n            pass\n\n    if include_parent:\n        try:\n            if parent_pid == os.getpid():\n                itself.kill()\n                sys.exit(0)\n\n            itself.kill()\n\n            # Sometimes processes cannot be killed with SIGKILL (e.g., PID=1 launched by kubernetes),\n            # so we send an additional signal to kill them.\n            itself.send_signal(signal.SIGQUIT)\n        except psutil.NoSuchProcess:\n            pass\n\n\ndef add_prefix(name: str, prefix: str) -> str:\n    \"\"\"Add a weight path prefix to a module name.\n\n    Args:\n        name: base module name.\n        prefix: weight prefix string to be added to the front of `name`, joined with `.`.\n\n    Returns:\n        The string `prefix.name` if prefix is non-empty, otherwise just `name`.\n    \"\"\"\n    return name if not prefix else f\"{prefix}.{name}\"\n\n\ndef
is_valid_ipv6_address(address: str) -> bool:\n    try:\n        ipaddress.IPv6Address(address)\n        return True\n    except ValueError:\n        return False\n\n\ndef configure_ipv6(dist_init_addr):\n    addr = dist_init_addr\n    end = addr.find(\"]\")\n    if end == -1:\n        raise ValueError(\"invalid IPv6 address format: missing ']'\")\n\n    host = addr[: end + 1]\n\n    # this only validates the address without brackets: we still need the below checks.\n    # if it's invalid, immediately raise an error so we know it's not a formatting issue.\n    if not is_valid_ipv6_address(host[1:end]):\n        raise ValueError(f\"invalid IPv6 address: {host}\")\n\n    port_str = None\n    if len(addr) > end + 1:\n        if addr[end + 1] == \":\":\n            port_str = addr[end + 2 :]\n        else:\n            raise ValueError(\"invalid IPv6 address format: expected ':' after ']'\")\n\n    if not port_str:\n        raise ValueError(\n            \"a port must be specified in IPv6 address (format: [ipv6]:port)\"\n        )\n\n    try:\n        port = int(port_str)\n    except ValueError:\n        raise ValueError(f\"invalid port in IPv6 address: '{port_str}'\")\n    return port, host\n\n\ndef is_port_available(port):\n    \"\"\"Return whether a port is available.\"\"\"\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        try:\n            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n            s.bind((\"\", port))\n            s.listen(1)\n            return True\n        except socket.error:\n            return False\n        except OverflowError:\n            return False\n\n\ndef get_zmq_socket(\n    context: zmq.Context,\n    socket_type: zmq.SocketType,\n    endpoint: str,\n    bind: bool,\n    max_bind_retries: int = 10,\n) -> tuple[zmq.Socket, str]:\n    \"\"\"\n    Create and configure a ZMQ socket.\n\n    Args:\n        context: ZMQ context\n        socket_type: Type of ZMQ socket\n        endpoint: Endpoint string
(e.g., \"tcp://localhost:5555\")\n        bind: Whether to bind (True) or connect (False)\n        max_bind_retries: Maximum number of retries if bind fails due to address already in use\n\n    Returns:\n        A tuple of (socket, actual_endpoint). The actual_endpoint may differ from the\n        requested endpoint if bind retry was needed.\n    \"\"\"\n    mem = psutil.virtual_memory()\n    total_mem = mem.total / 1024**3\n    available_mem = mem.available / 1024**3\n    if total_mem > 32 and available_mem > 16:\n        buf_size = int(0.5 * 1024**3)\n    else:\n        buf_size = -1\n\n    socket = context.socket(socket_type)\n    if endpoint.find(\"[\") != -1:\n        socket.setsockopt(zmq.IPV6, 1)\n\n    def set_send_opt():\n        socket.setsockopt(zmq.SNDHWM, 0)\n        socket.setsockopt(zmq.SNDBUF, buf_size)\n\n    def set_recv_opt():\n        socket.setsockopt(zmq.RCVHWM, 0)\n        socket.setsockopt(zmq.RCVBUF, buf_size)\n\n    if socket_type == zmq.PUSH:\n        set_send_opt()\n    elif socket_type == zmq.PULL:\n        set_recv_opt()\n    elif socket_type in [zmq.DEALER, zmq.REQ, zmq.REP, zmq.ROUTER]:\n        set_send_opt()\n        set_recv_opt()\n    else:\n        raise ValueError(f\"Unsupported socket type: {socket_type}\")\n\n    if bind:\n        # Parse port from endpoint for retry logic\n        import re\n\n        port_match = re.search(r\":(\\d+)$\", endpoint)\n\n        if port_match and max_bind_retries > 1:\n            original_port = int(port_match.group(1))\n            last_exception = None\n\n            for attempt in range(max_bind_retries):\n                try:\n                    current_endpoint = endpoint\n                    if attempt > 0:\n                        # Try next port (increment by 42 to match settle_port logic)\n                        current_port = original_port + attempt * 42\n                        current_endpoint = re.sub(\n                            r\":(\\d+)$\", f\":{current_port}\", endpoint\n  
                      )\n                        logger.info(\n                            f\"ZMQ bind failed for port {original_port + (attempt - 1) * 42}, \"\n                            f\"retrying with port {current_port} (attempt {attempt + 1}/{max_bind_retries})\"\n                        )\n\n                    socket.bind(current_endpoint)\n\n                    if attempt > 0:\n                        logger.warning(\n                            f\"Successfully bound ZMQ socket to {current_endpoint} after {attempt + 1} attempts. \"\n                            f\"Original port {original_port} was unavailable.\"\n                        )\n\n                    return socket, current_endpoint\n\n                except zmq.ZMQError as e:\n                    last_exception = e\n                    if e.errno == zmq.EADDRINUSE and attempt < max_bind_retries - 1:\n                        # Address already in use, try next port\n                        continue\n                    elif attempt == max_bind_retries - 1:\n                        # Last attempt failed\n                        logger.error(\n                            f\"Failed to bind ZMQ socket after {max_bind_retries} attempts. 
\"\n                            f\"Original endpoint: {endpoint}, Last tried port: {original_port + attempt * 42}\"\n                        )\n                        raise\n                    else:\n                        # Different error, raise immediately\n                        raise\n\n            # Should not reach here, but just in case\n            if last_exception:\n                raise last_exception\n        else:\n            # No retry logic needed (either no port in endpoint or max_bind_retries == 1)\n            socket.bind(endpoint)\n            return socket, endpoint\n    else:\n        socket.connect(endpoint)\n        return socket, endpoint\n\n    return socket, endpoint\n\n\n# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip\n\n\n@lru_cache(maxsize=1)\ndef is_host_cpu_x86() -> bool:\n    machine = platform.machine().lower()\n    return (\n        machine in (\"x86_64\", \"amd64\", \"i386\", \"i686\")\n        and hasattr(torch, \"cpu\")\n        and torch.cpu.is_available()\n    )\n\n\n# cuda\n\n\ndef set_cuda_arch():\n    capability = torch.cuda.get_device_capability()\n    arch = f\"{capability[0]}.{capability[1]}\"\n    os.environ[\"TORCH_CUDA_ARCH_LIST\"] = f\"{arch}{'+PTX' if arch == '9.0' else ''}\"\n\n\n# musa\n\n\ndef set_musa_arch():\n    capability = torch.cuda.get_device_capability()\n    arch = f\"{capability[0]}{capability[1]}\"\n    os.environ[\"TORCH_MUSA_ARCH_LIST\"] = f\"{arch}\"\n\n\n# env var managements\n\n_warned_bool_env_var_keys = set()\n\n\ndef get_bool_env_var(name: str, default: str = \"false\") -> bool:\n    value = os.getenv(name, default)\n    value = str(value).strip().lower()\n\n    truthy_values = {\"1\", \"true\", \"yes\", \"y\", \"t\", \"on\"}\n    falsy_values = {\"0\", \"false\", \"no\", \"n\", \"f\", \"off\", \"\"}\n\n    if (value not in truthy_values) and (value not in falsy_values):\n        if value not in _warned_bool_env_var_keys:\n            logger.warning(\n                
f\"get_bool_env_var({name}) got unrecognized value={value}, treating it as false\"\n            )\n        _warned_bool_env_var_keys.add(value)\n\n    return value in truthy_values\n\n\ntry:\n    import sgl_kernel  # noqa: F401\n\n    is_intel_amx_backend_available = hasattr(\n        torch.ops.sgl_kernel, \"convert_weight_packed\"\n    )\nexcept Exception:\n    is_intel_amx_backend_available = False\n\ntry:\n    # Moved torch._C._cpu._is_amx_tile_supported() out of cpu_has_amx_support\n    # to support torch.compile\n    is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported()\nexcept Exception:\n    is_amx_tile_supported = False\n\n\ndef cpu_has_amx_support():\n    return is_amx_tile_supported and is_intel_amx_backend_available\n\n\ndef use_intel_amx_backend(layer):\n    return getattr(layer, \"use_intel_amx_backend\", False)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/utils/distributed.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nimport pickle\nfrom typing import Any, List, Optional\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\n\n\ndef broadcast_pyobj(\n    data: List[Any],\n    rank: int,\n    dist_group: Optional[torch.distributed.ProcessGroup] = None,\n    src: int = 0,\n    force_cpu_device: bool = True,\n):\n    \"\"\"Broadcast inputs from src rank to all other ranks with torch.dist backend.\n    The `rank` here refers to this process's rank on the global process group\n    (regardless of the dist_group argument); `src` is likewise a global rank.\n    \"\"\"\n\n    device = torch.device(\n        current_platform.device_type if not force_cpu_device else \"cpu\"\n    )\n\n    if rank == src:\n        if data is None or len(data) == 0:\n            tensor_size = torch.tensor([0], dtype=torch.long, device=device)\n            dist.broadcast(tensor_size, src=src, group=dist_group)\n        else:\n            serialized_data = pickle.dumps(data)\n            size = len(serialized_data)\n\n            tensor_data = torch.ByteTensor(\n                np.frombuffer(serialized_data, dtype=np.uint8).copy()\n            ).to(device)\n            tensor_size = torch.tensor([size], dtype=torch.long, device=device)\n\n            dist.broadcast(tensor_size, src=src, group=dist_group)\n            dist.broadcast(tensor_data, src=src, group=dist_group)\n        return data\n    else:\n        tensor_size = torch.tensor([0], dtype=torch.long, device=device)\n        dist.broadcast(tensor_size, src=src, group=dist_group)\n        size = tensor_size.item()\n\n        if size == 0:\n            return []\n\n        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)\n        dist.broadcast(tensor_data, src=src, group=dist_group)\n\n        serialized_data = bytes(tensor_data.cpu().numpy())\n        data = pickle.loads(serialized_data)\n        return 
data\n\n\ndef generate_masked_orthogonal_rank_groups(\n    world_size: int, parallel_size: list[int], mask: list[bool]\n) -> list[list[int]]:\n    \"\"\"Generate orthogonal parallel groups based on the parallel size and mask.\n\n    Arguments:\n        world_size (int): world size\n\n        parallel_size (List[int]):\n            The parallel size of each orthogonal parallel type. For example, if\n            tensor_parallel_size = 2, pipeline_model_parallel_size = 3, data_parallel_size = 4,\n            and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4].\n\n        mask (List[bool]):\n            The mask controls which parallel methods the generated groups represent. If mask[i] is\n            True, the generated group contains the i-th parallelism method. For example,\n            if parallel_size = [tp_size, pp_size, dp_size] and mask = [True, False, True], then\n            the generated group is the `tp-dp` group; if mask = [False, True, False], then the\n            generated group is the `pp` group.\n\n    Algorithm:\n        For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and\n        local_rank satisfy the following equation:\n            global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size (1)\n            where tp_rank in [0, tp_size), dp_rank in [0, dp_size), pp_rank in [0, pp_size)\n\n        Suppose we want the `dp_group` (tp_size * pp_size groups of dp_size ranks each;\n        for example, if the gpu size is 8, the order is 'tp-pp-dp', and the size is '2-2-2',\n        the dp_group is [[0, 4], [1, 5], [2, 6], [3, 7]]).\n        Then tp_rank and pp_rank are combined to form the `dp_group_index`:\n            dp_group_index = tp_rank + pp_rank * tp_size (2)\n\n        So, given that tp_rank and pp_rank satisfy equation (2), and dp_rank in\n        range(0, dp_size), the ranks in dp_group[dp_group_index] satisfy\n        equation (1).\n\n        This function solves this math problem.\n\n    For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4],\n    and the mask = [False, True, False].
 Then,\n        dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2\n        dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2\n        ...\n        dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2\n\n        dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4]\n        dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5]\n        ...\n        dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23]\n    \"\"\"\n\n    def prefix_product(a: List[int], init=1) -> List[int]:\n        r = [init]\n        for v in a:\n            init = init * v\n            r.append(init)\n        return r\n\n    def inner_product(a: List[int], b: List[int]) -> int:\n        return sum([x * y for x, y in zip(a, b)])\n\n    def decompose(index, shape, stride=None):\n        \"\"\"\n        This function solves the math problem below:\n            Given the equation\n                index = sum(idx[i] * stride[i])\n            and the values of index and stride,\n            return idx.\n        This function is used to get the tp/dp/pp ranks\n        from group_index and rank_in_group.\n        \"\"\"\n        if stride is None:\n            stride = prefix_product(shape)\n        idx = [(index // d) % s for s, d in zip(shape, stride)]\n        # stride is a prefix_product result.
 And the value of stride[-1]\n        # is not used.\n        assert (\n            sum([x * y for x, y in zip(idx, stride[:-1])]) == index\n        ), \"index {} with shape {} mismatches the returned idx {}\".format(index, shape, idx)\n        return idx\n\n    masked_shape = [s for s, m in zip(parallel_size, mask) if m]\n    unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]\n\n    global_stride = prefix_product(parallel_size)\n    masked_stride = [d for d, m in zip(global_stride, mask) if m]\n    unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]\n\n    group_size = prefix_product(masked_shape)[-1]\n    num_of_group = world_size // group_size\n\n    ranks = []\n    for group_index in range(num_of_group):\n        # get indices from unmasked for group_index.\n        decomposed_group_idx = decompose(group_index, unmasked_shape)\n        rank = []\n        for rank_in_group in range(group_size):\n            # get indices from masked for rank_in_group.\n            decomposed_rank_idx = decompose(rank_in_group, masked_shape)\n            rank.append(\n                inner_product(decomposed_rank_idx, masked_stride)\n                + inner_product(decomposed_group_idx, unmasked_stride)\n            )\n        ranks.append(rank)\n    return ranks\n\n\nclass RankGenerator(object):\n    def __init__(\n        self,\n        tp: int,\n        sp: int,\n        pp: int,\n        cfg: int,\n        dp: int,\n        order: str,\n        rank_offset: int = 0,\n    ) -> None:\n        self.tp = tp\n        self.sp = sp\n        self.pp = pp\n        self.cfg = cfg\n        self.dp = dp\n        self.rank_offset = rank_offset\n        self.world_size = tp * sp * pp * cfg * dp\n\n        self.name_to_size = {\n            \"tp\": self.tp,\n            \"sp\": self.sp,\n            \"pp\": self.pp,\n            \"cfg\": self.cfg,\n            \"dp\": self.dp,\n        }\n        order = order.lower()\n\n        for name in 
self.name_to_size.keys():\n            if name not in order and self.name_to_size[name] != 1:\n                raise RuntimeError(\n                    f\"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({order}).\"\n                )\n            elif name not in order:\n                order = order + \"-\" + name\n\n        self.order = order\n        self.ordered_size = []\n\n        for token in order.split(\"-\"):\n            self.ordered_size.append(self.name_to_size[token])\n\n    def get_mask(self, order: str, token: str):\n        ordered_token = order.split(\"-\")\n        token = token.split(\"-\")\n        mask = [False] * len(ordered_token)\n        for t in token:\n            mask[ordered_token.index(t)] = True\n        return mask\n\n    def get_ranks(self, token):\n        \"\"\"Get rank group by input token.\n\n        Arguments:\n            token (str):\n                Specifies the rank type to get. To obtain multiple\n                parallel types, separate them with a hyphen '-'.\n                For example, to obtain the TP_DP group, the token\n                should be 'tp-dp'.\n\n        \"\"\"\n        mask = self.get_mask(self.order, token)\n        ranks = generate_masked_orthogonal_rank_groups(\n            self.world_size, self.ordered_size, mask\n        )\n        if self.rank_offset > 0:\n            for rank_group in ranks:\n                for i in range(len(rank_group)):\n                    rank_group[i] += self.rank_offset\n        return ranks\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/hf_transformers_utils.py\n\n# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utilities for Huggingface Transformers.\"\"\"\n\nimport contextlib\nimport glob\nimport json\nimport os\nimport shutil\nimport time\nfrom functools import reduce\nfrom pathlib import Path\nfrom typing import Any, Optional, Union, cast\n\nfrom diffusers.loaders.lora_base import (\n    _best_guess_weight_name,  # watch out for potential removal from diffusers\n)\nfrom huggingface_hub.errors import (\n    LocalEntryNotFoundError,\n    RepositoryNotFoundError,\n    RevisionNotFoundError,\n)\nfrom requests.exceptions import ConnectionError as RequestsConnectionError\nfrom requests.exceptions import RequestException\nfrom transformers import AutoConfig, PretrainedConfig\n\nfrom sglang.multimodal_gen.runtime.loader.utils import _clean_hf_config_inplace\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import get_lock\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.srt.environ import envs\nfrom sglang.utils import is_in_ci\n\nlogger = 
init_logger(__name__)\n\n\ndef _check_index_files_for_missing_shards(\n    model_path: str,\n) -> tuple[bool, list[str], list[str]]:\n    \"\"\"\n    Check all subdirectories for missing shards based on index files.\n\n    This catches cases where a model download was interrupted, leaving\n    some safetensors shards missing while the index file exists.\n\n    Args:\n        model_path: Path to the model directory\n\n    Returns:\n        Tuple of (all_valid, missing_files, checked_subdirs)\n    \"\"\"\n    missing_files = []\n    checked_subdirs = []\n\n    # Add common subdirectories for diffusers models\n    try:\n        subdirs = os.listdir(model_path)\n    except OSError as e:\n        logger.warning(\"Failed to list model directory %s: %s\", model_path, e)\n        return True, [], []  # Assume valid if we can't check\n\n    # Check the root directory and all subdirectories that might contain model weights\n    dirs_to_check = [model_path]\n\n    for subdir in subdirs:\n        subdir_path = os.path.join(model_path, subdir)\n        if os.path.isdir(subdir_path):\n            dirs_to_check.append(subdir_path)\n\n    for dir_path in dirs_to_check:\n        # Find all safetensors index files\n        index_files = glob.glob(os.path.join(dir_path, \"*.safetensors.index.json\"))\n\n        for index_file in index_files:\n            checked_subdirs.append(os.path.basename(dir_path))\n            try:\n                with open(index_file) as f:\n                    index_data = json.load(f)\n\n                weight_map = index_data.get(\"weight_map\", {})\n                if not weight_map:\n                    continue\n\n                # Get unique files referenced in weight_map\n                required_files = set(weight_map.values())\n\n                for file_name in required_files:\n                    file_path = os.path.join(dir_path, file_name)\n                    if not os.path.exists(file_path):\n                        relative_path = 
os.path.relpath(file_path, model_path)\n                        missing_files.append(relative_path)\n\n            except Exception as e:\n                logger.warning(\"Failed to read index file %s: %s\", index_file, e)\n                continue\n\n    return len(missing_files) == 0, missing_files, checked_subdirs\n\n\ndef _cleanup_model_cache(model_path: str, reason: str) -> bool:\n    \"\"\"\n    Remove the model cache directory to force a clean re-download.\n\n    Args:\n        model_path: Path to the model directory (snapshot path)\n        reason: Reason for cleanup (for logging)\n\n    Returns:\n        True if cleanup was performed, False otherwise\n    \"\"\"\n    # Navigate up to the model root directory: snapshots/hash -> snapshots -> model_root\n    # HF cache structure: models--org--name/snapshots/hash/\n    try:\n        snapshot_dir = os.path.abspath(model_path)\n        snapshots_dir = os.path.dirname(snapshot_dir)\n        repo_folder = os.path.dirname(snapshots_dir)\n\n        # Verify this looks like an HF cache structure\n        if os.path.basename(snapshots_dir) != \"snapshots\":\n            logger.warning(\n                \"Model path %s doesn't appear to be in HF cache structure, skipping cleanup\",\n                model_path,\n            )\n            return False\n\n        logger.warning(\n            \"Removing model cache at %s. Reason: %s\",\n            repo_folder,\n            reason,\n        )\n        shutil.rmtree(repo_folder)\n        logger.info(\"Successfully removed corrupted cache directory\")\n        return True\n    except Exception as e:\n        logger.error(\n            \"Failed to remove corrupted cache directory %s: %s. 
\"\n            \"Manual cleanup may be required.\",\n            model_path,\n            e,\n        )\n        return False\n\n\ndef _ci_validate_diffusers_model(model_path: str) -> tuple[bool, bool]:\n    \"\"\"\n    CI-specific validation for diffusers models.\n\n    Checks all subdirectories (transformer, transformer_2, vae, etc.) for\n    missing shards based on their index files. If issues are found in CI,\n    cleans up the cache to force re-download.\n\n    Args:\n        model_path: Path to the model directory\n\n    Returns:\n        Tuple of (is_valid, cleanup_performed)\n        - is_valid: True if the model is valid\n        - cleanup_performed: True if cleanup was performed (only relevant when is_valid=False)\n    \"\"\"\n    if not is_in_ci():\n        return True, False\n    is_valid, missing_files, checked_subdirs = _check_index_files_for_missing_shards(\n        model_path\n    )\n\n    if not is_valid:\n        logger.error(\n            \"CI validation failed for %s. Missing %d file(s): %s. \"\n            \"Checked subdirectories: %s\",\n            model_path,\n            len(missing_files),\n            missing_files[:5] if len(missing_files) > 5 else missing_files,\n            checked_subdirs,\n        )\n        cleanup_performed = _cleanup_model_cache(\n            model_path,\n            f\"Missing {len(missing_files)} shard file(s): {missing_files[:3]}\",\n        )\n        return False, cleanup_performed\n\n    if checked_subdirs:\n        logger.info(\n            \"CI validation passed for %s. 
Checked subdirectories: %s\",\n            model_path,\n            checked_subdirs,\n        )\n\n    return True, False\n\n\ndef _verify_diffusers_model_complete(path: str) -> bool:\n    \"\"\"Check if a diffusers model directory has all required component subdirectories.\"\"\"\n    config_path = os.path.join(path, \"model_index.json\")\n    if not os.path.exists(config_path):\n        return False\n\n    try:\n        with open(config_path) as config_file:\n            model_index = json.load(config_file)\n    except Exception as exc:\n        logger.warning(\"Failed to read model_index.json at %s: %s\", config_path, exc)\n        return False\n\n    component_keys = [\n        key\n        for key, value in model_index.items()\n        if isinstance(value, (list, tuple))\n        and len(value) == 2\n        and all(isinstance(item, str) for item in value)\n    ]\n    if component_keys:\n        return all(os.path.exists(os.path.join(path, key)) for key in component_keys)\n\n    return os.path.exists(os.path.join(path, \"transformer\")) and os.path.exists(\n        os.path.join(path, \"vae\")\n    )\n\n\n_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {\n    # ChatGLMConfig.model_type: ChatGLMConfig,\n    # DbrxConfig.model_type: DbrxConfig,\n    # ExaoneConfig.model_type: ExaoneConfig,\n    # Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,\n}\n\nfor name, cls in _CONFIG_REGISTRY.items():\n    with contextlib.suppress(ValueError):\n        AutoConfig.register(name, cls)\n\n\ndef download_from_hf(model_path: str):\n    if os.path.exists(model_path):\n        return model_path\n\n    return snapshot_download(model_path, allow_patterns=[\"*.json\", \"*.bin\", \"*.model\"])\n\n\ndef get_hf_config(\n    component_model_path: str,\n    trust_remote_code: bool,\n    revision: str | None = None,\n    model_override_args: dict | None = None,\n    **kwargs,\n) -> PretrainedConfig:\n    if check_gguf_file(component_model_path):\n        raise 
NotImplementedError(\"GGUF models are not supported.\")\n\n    config = AutoConfig.from_pretrained(\n        component_model_path,\n        trust_remote_code=trust_remote_code,\n        revision=revision,\n        **kwargs,\n    )\n    if config.model_type in _CONFIG_REGISTRY:\n        config_class = _CONFIG_REGISTRY[config.model_type]\n        config = config_class.from_pretrained(component_model_path, revision=revision)\n        # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.\n        config._name_or_path = component_model_path\n    if model_override_args:\n        config.update(model_override_args)\n\n    return config\n\n\ndef get_config(\n    model: str,\n    trust_remote_code: bool,\n    revision: Optional[str] = None,\n    model_override_args: Optional[dict] = None,\n    **kwargs,\n):\n    return AutoConfig.from_pretrained(\n        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs\n    )\n\n\ndef load_dict(file_path):\n    if not os.path.exists(file_path):\n        return {}\n    try:\n        # Load the config directly from the file\n        with open(file_path) as f:\n            config_dict: dict[str, Any] = json.load(f)\n        if \"_diffusers_version\" in config_dict:\n            config_dict.pop(\"_diffusers_version\")\n        # TODO(will): apply any overrides from inference args\n        return config_dict\n    except Exception as e:\n        raise RuntimeError(\n            f\"Failed to load diffusers config from {file_path}: {e}\"\n        ) from e\n\n\ndef get_diffusers_component_config(\n    component_path: str,\n) -> dict[str, Any]:\n    \"\"\"Gets a configuration of a submodule for the given diffusers model.\"\"\"\n    # Download from HuggingFace Hub if path doesn't exist locally\n    if not os.path.exists(component_path):\n        component_path = maybe_download_model(component_path)\n\n    config_names = [\"generation_config.json\"]\n    # By default, we load config.json, but 
scheduler_config.json for scheduler\n    if \"scheduler\" in component_path:\n        config_names.append(\"scheduler_config.json\")\n    else:\n        config_names.append(\"config.json\")\n\n    config_file_paths = [\n        os.path.join(component_path, config_name) for config_name in config_names\n    ]\n\n    combined_config = reduce(\n        lambda acc, path: acc | load_dict(path), config_file_paths, {}\n    )\n\n    _clean_hf_config_inplace(combined_config)\n\n    logger.debug(\"HF model config: %s\", combined_config)\n\n    return combined_config\n\n\n# Models don't use the same configuration key for determining the maximum\n# context length.  Store them here so we can sanely check them.\n# NOTE: The ordering here is important. Some models have two of these and we\n# have a preference for which value gets used.\nCONTEXT_LENGTH_KEYS = [\n    \"max_sequence_length\",\n    \"seq_length\",\n    \"max_seq_len\",\n    \"model_max_length\",\n    \"max_position_embeddings\",\n]\n\n\ndef attach_additional_stop_token_ids(tokenizer):\n    # Special handling for stop token <|eom_id|> generated by llama 3 tool use.\n    if \"<|eom_id|>\" in tokenizer.get_added_vocab():\n        tokenizer.additional_stop_token_ids = {\n            tokenizer.get_added_vocab()[\"<|eom_id|>\"]\n        }\n    else:\n        tokenizer.additional_stop_token_ids = None\n\n\ndef check_gguf_file(model: str | os.PathLike) -> bool:\n    \"\"\"Check if the file is a GGUF model.\"\"\"\n    model = Path(model)\n    if not model.is_file():\n        return False\n    elif model.suffix == \".gguf\":\n        return True\n\n    with open(model, \"rb\") as f:\n        header = f.read(4)\n    return header == b\"GGUF\"\n\n\ndef maybe_download_lora(\n    model_name_or_path: str, local_dir: str | None = None, download: bool = True\n) -> str:\n    \"\"\"\n    Check if the model path is a Hugging Face Hub model ID and download it if needed.\n    Args:\n        model_name_or_path: Local path or Hugging Face 
Hub model ID\n        local_dir: Local directory to save the model\n        download: Whether to download the model from Hugging Face Hub\n\n    Returns:\n        Local path to the model\n    \"\"\"\n    allow_patterns = [\"*.json\", \"*.safetensors\", \"*.bin\"]\n\n    local_path = maybe_download_model(\n        model_name_or_path,\n        local_dir,\n        download,\n        is_lora=True,\n        allow_patterns=allow_patterns,\n    )\n    # return directly if local_path is a file\n    if os.path.isfile(local_path):\n        return local_path\n\n    weight_name = _best_guess_weight_name(local_path, file_extension=\".safetensors\")\n    # AMD workaround: PR 15813 changed from model_name_or_path to local_path,\n    # which can return None. Fall back to original behavior on ROCm.\n    if weight_name is None and current_platform.is_rocm():\n        weight_name = _best_guess_weight_name(\n            model_name_or_path, file_extension=\".safetensors\"\n        )\n    return os.path.join(local_path, weight_name)\n\n\ndef verify_model_config_and_directory(model_path: str) -> dict[str, Any]:\n    \"\"\"\n    Verify that the model directory contains a valid diffusers configuration.\n\n    Args:\n        model_path: Path to the model directory\n\n    Returns:\n        The loaded model configuration as a dictionary\n    \"\"\"\n\n    # Check for model_index.json which is required for diffusers models\n    config_path = os.path.join(model_path, \"model_index.json\")\n    if not os.path.exists(config_path):\n        raise ValueError(\n            f\"Model directory {model_path} does not contain model_index.json. 
\"\n            \"Only HuggingFace diffusers format is supported.\"\n        )\n\n    # Load the config\n    with open(config_path) as f:\n        config = json.load(f)\n\n    # Verify diffusers version exists\n    if \"_diffusers_version\" not in config:\n        raise ValueError(\"model_index.json does not contain _diffusers_version\")\n\n    logger.info(\"Diffusers version: %s\", config[\"_diffusers_version\"])\n\n    component_keys = [\n        key\n        for key, value in config.items()\n        if isinstance(value, (list, tuple))\n        and len(value) == 2\n        and all(isinstance(item, str) for item in value)\n    ]\n    if component_keys:\n        missing_components = [\n            component_key\n            for component_key in component_keys\n            if not os.path.exists(os.path.join(model_path, component_key))\n        ]\n        if missing_components:\n            missing_str = \", \".join(missing_components)\n            raise ValueError(\n                f\"Model directory {model_path} is missing required component \"\n                f\"directories: {missing_str}.\"\n            )\n    else:\n        transformer_dir = os.path.join(model_path, \"transformer\")\n        vae_dir = os.path.join(model_path, \"vae\")\n        if not os.path.exists(transformer_dir):\n            raise ValueError(\n                f\"Model directory {model_path} does not contain a transformer/ directory.\"\n            )\n        if not os.path.exists(vae_dir):\n            raise ValueError(\n                f\"Model directory {model_path} does not contain a vae/ directory.\"\n            )\n    return cast(dict[str, Any], config)\n\n\ndef maybe_download_model_index(model_name_or_path: str) -> dict[str, Any]:\n    \"\"\"\n    Download and extract just the model_index.json for a Hugging Face model.\n\n    Args:\n        model_name_or_path: Path or HF Hub model ID\n\n    Returns:\n        The parsed model_index.json as a dictionary\n    \"\"\"\n    import 
tempfile\n\n    from huggingface_hub.errors import EntryNotFoundError\n\n    # If it's a local path, verify it directly\n    if os.path.exists(model_name_or_path):\n        try:\n            return verify_model_config_and_directory(model_name_or_path)\n        except ValueError:\n            # Not a pipeline, maybe a single model.\n            config_path = os.path.join(model_name_or_path, \"config.json\")\n            if os.path.exists(config_path):\n                with open(config_path) as f:\n                    config = json.load(f)\n                return config\n            raise\n\n    # For remote models, download just the model_index.json\n    try:\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            # Download just the model_index.json file\n            model_index_path = hf_hub_download(\n                repo_id=model_name_or_path,\n                filename=\"model_index.json\",\n                local_dir=tmp_dir,\n            )\n\n            # Load the model_index.json\n            with open(model_index_path) as f:\n                config: dict[str, Any] = json.load(f)\n\n            # Verify it has the required fields\n            if \"_class_name\" not in config:\n                raise ValueError(\n                    f\"model_index.json for {model_name_or_path} does not contain _class_name field\"\n                )\n\n            if \"_diffusers_version\" not in config:\n                raise ValueError(\n                    f\"model_index.json for {model_name_or_path} does not contain _diffusers_version field\"\n                )\n\n            # Add the pipeline name for downstream use\n            config[\"pipeline_name\"] = config[\"_class_name\"]\n\n            logger.debug(\n                \"Downloaded model_index.json for %s, pipeline: %s\",\n                model_name_or_path,\n                config[\"_class_name\"],\n            )\n            return config\n    except EntryNotFoundError:\n        logger.warning(\n       
     \"model_index.json not found for %s. Assuming it is a single model and downloading it.\",\n            model_name_or_path,\n        )\n        local_path = maybe_download_model(model_name_or_path)\n        config_path = os.path.join(local_path, \"config.json\")\n        if not os.path.exists(config_path):\n            raise ValueError(\n                f\"Failed to find config.json for {model_name_or_path} after failing to find model_index.json. \"\n                \"You might be looking for a model whose name ends with '-Diffusers'.\"\n            )\n        with open(config_path) as f:\n            config = json.load(f)\n        return config\n    except Exception as e:\n        raise ValueError(\n            f\"Failed to download or parse model_index.json for {model_name_or_path}: {e}\"\n        ) from e\n\n\ndef maybe_download_model(\n    model_name_or_path: str,\n    local_dir: str | None = None,\n    download: bool = True,\n    is_lora: bool = False,\n    allow_patterns: list[str] | None = None,\n    force_diffusers_model: bool = False,\n) -> str:\n    \"\"\"\n    Check if the model path is a Hugging Face Hub model ID and download it if needed.\n\n    Args:\n        model_name_or_path: Local path or Hugging Face Hub model ID\n        local_dir: Local directory to save the model\n        download: Whether to download the model from Hugging Face Hub\n        is_lora: If True, skip model completeness verification (LoRA models don't have transformer/vae directories)\n        allow_patterns: If set, only files matching these patterns are downloaded\n        force_diffusers_model: If True, apply the diffusers model completeness check; otherwise the path is treated as a component model\n    Returns:\n        Local path to the model\n    \"\"\"\n\n    # 1.
Local path check: if path exists locally, verify it's complete (skip for LoRA)\n    if os.path.exists(model_name_or_path):\n        if not force_diffusers_model:\n            return model_name_or_path\n        if is_lora or _verify_diffusers_model_complete(model_name_or_path):\n            if not is_lora:\n                is_valid, cleanup_performed = _ci_validate_diffusers_model(\n                    model_name_or_path\n                )\n                if not is_valid:\n                    if cleanup_performed:\n                        logger.warning(\n                            \"CI validation failed for local model at %s, \"\n                            \"cache has been cleaned up, will re-download\",\n                            model_name_or_path,\n                        )\n                        # Fall through to download\n                    else:\n                        raise ValueError(\n                            f\"CI validation failed for local model at {model_name_or_path}. \"\n                            \"Some safetensors shards are missing. \"\n                            \"Please manually delete the model directory and retry.\"\n                        )\n                else:\n                    logger.info(\"Model already exists locally and is complete\")\n                    return model_name_or_path\n            else:\n                logger.info(\"Model already exists locally and is complete\")\n                return model_name_or_path\n        else:\n            logger.warning(\n                \"Local model at %s appears incomplete (missing required components), \"\n                \"will attempt re-download\",\n                model_name_or_path,\n            )\n\n    # 2. 
Cache-first strategy (Fast Path)\n    # Try to read from HF cache without network access\n    try:\n        logger.info(\n            \"Checking for cached model in HF Hub cache for %s...\", model_name_or_path\n        )\n        local_path = snapshot_download(\n            repo_id=model_name_or_path,\n            ignore_patterns=[\"*.onnx\", \"*.msgpack\"],\n            local_dir=local_dir,\n            local_files_only=True,\n            max_workers=8,\n        )\n        if not force_diffusers_model:\n            return str(local_path)\n        if is_lora or _verify_diffusers_model_complete(local_path):\n            if not is_lora:\n                is_valid, cleanup_performed = _ci_validate_diffusers_model(local_path)\n                if not is_valid:\n                    logger.warning(\n                        \"CI validation failed for cached model at %s, \"\n                        \"%s, will re-download\",\n                        local_path,\n                        (\n                            \"cache has been cleaned up\"\n                            if cleanup_performed\n                            else \"cleanup was not performed\"\n                        ),\n                    )\n                    # Fall through to download\n                else:\n                    logger.info(\"Found complete model in cache at %s\", local_path)\n                    return str(local_path)\n            else:\n                logger.info(\"Found complete model in cache at %s\", local_path)\n                return str(local_path)\n        else:\n            if not download:\n                raise ValueError(\n                    f\"Model {model_name_or_path} found in cache but is incomplete and download=False.\"\n                )\n            logger.info(\n                \"Model found in cache but incomplete, will download from HF Hub\"\n            )\n    except LocalEntryNotFoundError:\n        if not download:\n            raise ValueError(\n                
f\"Model {model_name_or_path} not found in local cache and download=False.\"\n            )\n        logger.info(\"Model not found in cache, will download from HF Hub\")\n    except Exception as e:\n        logger.warning(\n            \"Unexpected error while checking cache for %s: %s, will attempt download\",\n            model_name_or_path,\n            e,\n        )\n        if not download:\n            raise ValueError(\n                f\"Error checking cache for {model_name_or_path} and download=False: {e}\"\n            ) from e\n\n    # 3. Download strategy (with retry mechanism)\n    MAX_RETRIES = 5\n    for attempt in range(MAX_RETRIES):\n        try:\n            logger.info(\n                \"Downloading model snapshot from HF Hub for %s (attempt %d/%d)...\",\n                model_name_or_path,\n                attempt + 1,\n                MAX_RETRIES,\n            )\n            with get_lock(model_name_or_path).acquire(poll_interval=2):\n                local_path = snapshot_download(\n                    repo_id=model_name_or_path,\n                    ignore_patterns=[\"*.onnx\", \"*.msgpack\"],\n                    allow_patterns=allow_patterns,\n                    local_dir=local_dir,\n                    max_workers=8,\n                )\n\n            if not force_diffusers_model:\n                return str(local_path)\n            # Verify downloaded model is complete (skip for LoRA)\n            elif not is_lora and not _verify_diffusers_model_complete(local_path):\n                logger.warning(\n                    \"Downloaded model at %s is incomplete, retrying with force_download=True\",\n                    local_path,\n                )\n                with get_lock(model_name_or_path).acquire(poll_interval=2):\n                    local_path = snapshot_download(\n                        repo_id=model_name_or_path,\n                        ignore_patterns=[\"*.onnx\", \"*.msgpack\"],\n                        
local_dir=local_dir,\n                        max_workers=8,\n                        force_download=True,\n                    )\n                if not _verify_diffusers_model_complete(local_path):\n                    raise ValueError(\n                        f\"Downloaded model at {local_path} is still incomplete after forced re-download. \"\n                        \"The model repository may be missing required components (model_index.json, transformer/, or vae/).\"\n                    )\n\n            # CI validation: check all subdirectories for missing shards after download\n            if not is_lora:\n                is_valid, cleanup_performed = _ci_validate_diffusers_model(local_path)\n                if not is_valid:\n                    # In CI, if validation fails after download, we have a serious issue\n                    # If cleanup was performed, the next retry should get a fresh download\n                    raise ValueError(\n                        f\"CI validation failed for downloaded model at {local_path}. \"\n                        f\"Some safetensors shards are missing. Cleanup performed: {cleanup_performed}.\"\n                    )\n\n            logger.info(\"Downloaded model to %s\", local_path)\n            return str(local_path)\n\n        except (RepositoryNotFoundError, RevisionNotFoundError) as e:\n            raise ValueError(\n                f\"Model or revision not found at {model_name_or_path}. \"\n                f\"Please check the model ID or ensure you have access to the repository. 
Error: {e}\"\n            ) from e\n        except (RequestException, RequestsConnectionError) as e:\n            if attempt == MAX_RETRIES - 1:\n                raise ValueError(\n                    f\"Could not find model at {model_name_or_path} and failed to download from HF Hub \"\n                    f\"after {MAX_RETRIES} attempts due to network error: {e}\"\n                ) from e\n            wait_time = 2**attempt\n            logger.warning(\n                \"Download failed (attempt %d/%d) due to network error: %s. \"\n                \"Retrying in %d seconds...\",\n                attempt + 1,\n                MAX_RETRIES,\n                e,\n                wait_time,\n            )\n            time.sleep(wait_time)\n        except Exception as e:\n            raise ValueError(\n                f\"Could not find model at {model_name_or_path} and failed to download from HF Hub: {e}\"\n            ) from e\n\n\n# Unified download functions with Hugging Face-compatible names\ndef hf_hub_download(\n    repo_id: str,\n    filename: str,\n    local_dir: Optional[Union[str, Path]] = None,\n    **kwargs,\n) -> str:\n    \"\"\"Unified hf_hub_download that supports both Hugging Face Hub and ModelScope.\"\"\"\n    if envs.SGLANG_USE_MODELSCOPE.get():\n        from modelscope import model_file_download\n\n        return model_file_download(\n            model_id=repo_id,\n            file_path=filename,\n            cache_dir=local_dir,\n            **kwargs,\n        )\n    else:\n        from huggingface_hub import hf_hub_download as _hf_hub_download\n\n        return _hf_hub_download(\n            repo_id=repo_id,\n            filename=filename,\n            local_dir=local_dir,\n            **kwargs,\n        )\n\n\ndef snapshot_download(\n    repo_id: str,\n    local_dir: Optional[Union[str, Path]] = None,\n    ignore_patterns: Optional[Union[list[str], str]] = None,\n    allow_patterns: Optional[Union[list[str], str]] = None,\n    local_files_only: 
bool = False,\n    max_workers: int = 8,\n    **kwargs,\n) -> str:\n    \"\"\"Unified snapshot_download that supports both Hugging Face Hub and ModelScope.\"\"\"\n    if envs.SGLANG_USE_MODELSCOPE.get():\n        from modelscope import snapshot_download as _ms_snapshot_download\n\n        ms_kwargs = {\n            \"model_id\": repo_id,\n            \"local_dir\": local_dir,\n            \"ignore_patterns\": ignore_patterns,\n            \"allow_patterns\": allow_patterns,\n            \"local_files_only\": local_files_only,\n            \"max_workers\": max_workers,\n        }\n        ms_kwargs.update(kwargs)\n        return _ms_snapshot_download(**ms_kwargs)\n    else:\n        from huggingface_hub import snapshot_download as _hf_snapshot_download\n\n        hf_kwargs = {\n            \"repo_id\": repo_id,\n            \"local_dir\": local_dir,\n            \"ignore_patterns\": ignore_patterns,\n            \"allow_patterns\": allow_patterns,\n            \"local_files_only\": local_files_only,\n            \"max_workers\": max_workers,\n            \"etag_timeout\": 60,\n        }\n        hf_kwargs.update(kwargs)\n        return _hf_snapshot_download(**hf_kwargs)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py",
    "content": "import re\nfrom itertools import chain\nfrom typing import Any, Dict, List, Set, Tuple\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\n# Adapted from skywork AI Infra diffusion optimize\nclass LayerwiseOffloadManager:\n    \"\"\"A lightweight layerwise CPU offload manager.\n\n    This utility offloads per-layer parameters/buffers from GPU to CPU, and\n    supports async H2D prefetch using a dedicated CUDA stream.\n\n    Typical usage:\n    - Construct the manager with the target model and the list-like module\n      attribute that represents transformer blocks (e.g. ``blocks``).\n    - Call :meth:`initialize` once to offload weights and prefetch layer 0.\n    - During forward, call :meth:`prefetch_layer` for the next layer and\n      :meth:`release_layer` for the finished layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        model: torch.nn.Module,\n        *,\n        layers_attr_str: str,\n        num_layers: int,\n        enabled: bool,\n        pin_cpu_memory: bool = True,\n        prefetch_size: int = 1,\n    ) -> None:\n        self.model = model\n        self.layers_attr_str = layers_attr_str\n        self.num_layers = num_layers\n        self.pin_cpu_memory = pin_cpu_memory\n        self.prefetch_size = min(max(1, prefetch_size), self.num_layers)\n        self.enabled = bool(enabled and torch.get_device_module().is_available())\n        if not self.enabled:\n            return\n        self.device = torch.device(\n            current_platform.device_type, torch.get_device_module().current_device()\n        )\n        self.copy_stream = torch.get_device_module().Stream()\n\n        self._layer_name_re = re.compile(\n            rf\"(^|\\.){re.escape(layers_attr_str)}\\.(\\d+)(\\.|$)\"\n        )\n\n        # 
layer_idx -> {dtype: consolidated_pinned_cpu_tensor}\n        # all weights of one layer that share a dtype, concatenated into a single flat buffer\n        self._consolidated_cpu_weights: Dict[int, Dict[torch.dtype, torch.Tensor]] = {}\n        # layer_idx -> {name: {dtype, offset, numel, shape}}\n        # where each weight lives inside its layer's per-dtype buffer, plus its original shape\n        self._weight_metadata: Dict[int, Dict[str, Dict[str, Any]]] = {}\n        # indices of layers whose weights are currently resident on the GPU\n        self._gpu_layers: Set[int] = set()\n        # layer_idx -> torch.get_device_module().Event recorded after the layer's H2D copy; the pre-hook waits on it so the weights are resident before the layer runs\n        self._prefetch_events: Dict[int, torch.get_device_module().Event] = {}\n\n        self._named_parameters: Dict[str, torch.nn.Parameter] = {}\n        self._named_buffers: Dict[str, torch.Tensor] = {}\n        self._offload_placeholders: Dict[torch.dtype, torch.Tensor] = {}\n        # Store forward hooks for removal\n        self._forward_hooks: List[Any] = []\n\n        self._initialize()\n\n    def _match_layer_idx(self, name: str) -> int | None:\n        m = self._layer_name_re.search(name)\n        if not m:\n            return None\n        try:\n            return int(m.group(2))\n        except Exception:\n            return None\n\n    def _get_shared_empty_tensor(self, dtype: torch.dtype) -> torch.Tensor:\n        placeholder = self._offload_placeholders.get(dtype)\n        if placeholder is None:\n            placeholder = torch.empty((1,), device=self.device, dtype=dtype)\n            self._offload_placeholders[dtype] = placeholder\n        return placeholder\n\n    @torch.compiler.disable\n    def _initialize(self) -> None:\n        if not self.enabled:\n            return\n\n        self._named_parameters = dict(self.model.named_parameters())\n        self._named_buffers = dict(self.model.named_buffers())\n\n        # 1. 
collect and group tensors by layer and dtype\n        layer_groups: Dict[int, Dict[torch.dtype, List[Tuple[str, torch.Tensor]]]] = {}\n        all_tensors = chain(self._named_parameters.items(), self._named_buffers.items())\n        for name, tensor in all_tensors:\n            layer_idx = self._match_layer_idx(name)\n            if layer_idx is None or layer_idx >= self.num_layers:\n                continue\n            layer_groups.setdefault(layer_idx, {}).setdefault(tensor.dtype, []).append(\n                (name, tensor)\n            )\n\n        # 2. concat and offload (in pinned memory)\n        for layer_idx, dtype_to_params in layer_groups.items():\n            self._consolidated_cpu_weights[layer_idx] = {}\n            self._weight_metadata[layer_idx] = {}\n\n            for dtype, weights in dtype_to_params.items():\n                total_numel = sum(t.numel() for _, t in weights)\n\n                # create concatenated CPU buffer (in pinned memory)\n                cpu_buffer = torch.empty(\n                    total_numel, dtype=dtype, pin_memory=self.pin_cpu_memory\n                )\n\n                # offload weights to the buffer\n                current_offset = 0\n                for name, weight in weights:\n                    numel = weight.numel()\n                    cpu_buffer[current_offset : current_offset + numel].copy_(\n                        weight.flatten()\n                    )\n                    self._weight_metadata[layer_idx][name] = {\n                        \"dtype\": dtype,\n                        \"offset\": current_offset,\n                        \"numel\": numel,\n                        \"shape\": weight.shape,\n                    }\n\n                    weight.data = self._get_shared_empty_tensor(dtype)\n\n                    current_offset += numel\n\n                self._consolidated_cpu_weights[layer_idx][dtype] = cpu_buffer\n\n        # prefetch the first layer for warm-up\n        
self.prepare_for_next_req(non_blocking=False)\n\n        self.register_forward_hooks()\n        logger.info(\n            f\"LayerwiseOffloadManager initialized with num prefetched layer: {self.prefetch_size}, total num layers: {self.num_layers}\"\n        )\n\n    def prepare_for_next_req(self, non_blocking=True):\n        \"\"\"\n        Prepare for the next round of denoising loop with prefetching the necessary layers\n        \"\"\"\n        for i in range(self.prefetch_size):\n            self.prefetch_layer(i, non_blocking=non_blocking)\n        if not non_blocking and self.copy_stream is not None:\n            torch.get_device_module().current_stream().wait_stream(self.copy_stream)\n\n    def get_target_with_name(self, name: str) -> torch.Tensor:\n        \"\"\"get the target model weight/buffer to be replaced\"\"\"\n        if name in self._named_parameters:\n            target = self._named_parameters[name]\n        else:\n            target = self._named_buffers[name]\n        return target\n\n    @torch.compiler.disable\n    def prefetch_layer(self, layer_idx: int, non_blocking: bool = True) -> None:\n        \"\"\"\n        idempotent\n        \"\"\"\n        if not self.enabled or self.device is None or self.copy_stream is None:\n            return\n        if layer_idx < 0 or layer_idx >= self.num_layers:\n            return\n        if layer_idx in self._gpu_layers:\n            return\n        if layer_idx not in self._consolidated_cpu_weights:\n            return\n        self.copy_stream.wait_stream(torch.get_device_module().current_stream())\n\n        # create gpu buffer and load from CPU buffer\n        gpu_buffers: Dict[torch.dtype, torch.Tensor] = {}\n        with torch.get_device_module().stream(self.copy_stream):\n            for dtype, cpu_buffer in self._consolidated_cpu_weights[layer_idx].items():\n                gpu_buffer = torch.empty(\n                    cpu_buffer.shape, dtype=dtype, device=self.device\n                )\n         
       gpu_buffer.copy_(cpu_buffer, non_blocking=non_blocking)\n                gpu_buffers[dtype] = gpu_buffer\n\n        # record the prefetch event of this layer\n        event = torch.get_device_module().Event()\n        event.record(self.copy_stream)\n        self._prefetch_events[layer_idx] = event\n\n        # restore model's weights by their metadata using gpu buffer\n        for name, meta in self._weight_metadata[layer_idx].items():\n            dtype = meta[\"dtype\"]\n            gpu_buffer = gpu_buffers[dtype]\n\n            # map the parameter's data to the correct slice of the GPU buffer\n            target = self.get_target_with_name(name)\n            target.data = gpu_buffer[\n                meta[\"offset\"] : meta[\"offset\"] + meta[\"numel\"]\n            ].view(meta[\"shape\"])\n\n        self._gpu_layers.add(layer_idx)\n\n    @torch.compiler.disable\n    def release_layer(self, layer_idx: int) -> None:\n        \"\"\"\n        lightweight release layer weights\n        Basically set the reference count to the gpu weight tensor to zero. 
The CPU copies of the weights are left untouched.\n        \"\"\"\n        if not self.enabled or self.device is None:\n            return\n\n        # clear prefetch event, since it's useless and needs to be reset\n        self._prefetch_events.pop(layer_idx, None)\n\n        if layer_idx not in self._gpu_layers:\n            return\n\n        for name, meta in self._weight_metadata.get(layer_idx, {}).items():\n            target = self.get_target_with_name(name)\n            # Wraparound prefetch will reload the layer when it is needed again\n            target.data = self._get_shared_empty_tensor(meta[\"dtype\"])\n\n        self._gpu_layers.discard(layer_idx)\n\n    @torch.compiler.disable\n    def release_all(self) -> None:\n        if not self.enabled or self.device is None:\n            return\n        if self.copy_stream is not None:\n            torch.get_device_module().current_stream().wait_stream(self.copy_stream)\n\n        for layer_idx in list(self._gpu_layers):\n            self.release_layer(layer_idx)\n\n    @torch.compiler.disable\n    def load_all_layers(self) -> None:\n        \"\"\"Load all layers from CPU to GPU.\"\"\"\n        if not self.enabled or self.device is None:\n            return\n        if self.copy_stream is not None:\n            torch.get_device_module().current_stream().wait_stream(self.copy_stream)\n\n        for layer_idx in range(self.num_layers):\n            if layer_idx not in self._gpu_layers:\n                self.prefetch_layer(layer_idx, non_blocking=False)\n\n    @torch.compiler.disable\n    def sync_layer_to_cpu(self, layer_idx: int) -> None:\n        \"\"\"Sync a layer's weights from GPU back to CPU.\"\"\"\n        if not self.enabled or layer_idx not in self._gpu_layers:\n            return\n        if layer_idx not in self._consolidated_cpu_weights:\n            return\n\n        if self.copy_stream is not None:\n            
torch.get_device_module().current_stream().wait_stream(self.copy_stream)\n\n        # Collect current GPU 
weights and write back to CPU buffer\n        for name, meta in self._weight_metadata.get(layer_idx, {}).items():\n            target = self.get_target_with_name(name)\n            gpu_weight = target.data.flatten().cpu()\n\n            dtype = meta[\"dtype\"]\n            cpu_buffer = self._consolidated_cpu_weights[layer_idx][dtype]\n            offset = meta[\"offset\"]\n            numel = meta[\"numel\"]\n            cpu_buffer[offset : offset + numel].copy_(gpu_weight)\n\n    @torch.compiler.disable\n    def sync_all_layers_to_cpu(self) -> None:\n        \"\"\"Sync all loaded layers' weights from GPU back to CPU.\"\"\"\n        if not self.enabled or self.device is None:\n            return\n        if self.copy_stream is not None:\n            torch.get_device_module().current_stream().wait_stream(self.copy_stream)\n\n        for layer_idx in list(self._gpu_layers):\n            self.sync_layer_to_cpu(layer_idx)\n\n    @torch.compiler.disable\n    def update_cpu_weights(\n        self, weight_dict: Dict[str, torch.Tensor]\n    ) -> Set[str] | None:\n        \"\"\"Update consolidated CPU buffers with new weights.\n\n        When layerwise offload (--dit-layerwise-offload) is enabled, the\n        offload manager replaces GPU parameters with small torch.empty((1,))\n        placeholders while real weights live in consolidated pinned CPU\n        buffers.\n\n        The refit process writes new weights directly into the CPU buffers,\n        bypassing the placeholders.  
For any layer that happens to be resident\n        on the GPU at update time, the live GPU tensor is also updated.\n\n        Args:\n            weight_dict: Mapping of parameter name to new weight tensor.\n\n        Returns:\n            Set of parameter names that were successfully updated.\n\n        Raises:\n            ValueError: If a weight's shape does not match the recorded\n                metadata (i.e., the real shape, not the placeholder shape).\n        \"\"\"\n        if not self.enabled:\n            return None\n\n        updated_names: Set[str] = set()\n        for name, loaded_weight in weight_dict.items():\n            layer_idx = self._match_layer_idx(name)\n            if layer_idx is None:\n                continue\n            meta_layer = self._weight_metadata.get(layer_idx)\n            if meta_layer is None or name not in meta_layer:\n                continue\n\n            meta = meta_layer[name]\n            if tuple(meta[\"shape\"]) != tuple(loaded_weight.shape):\n                raise ValueError(\n                    f\"Shape mismatch for {name}: \"\n                    f\"expected={tuple(meta['shape'])}, \"\n                    f\"loaded={tuple(loaded_weight.shape)}\"\n                )\n\n            dtype = meta[\"dtype\"]\n            offset = meta[\"offset\"]\n            numel = meta[\"numel\"]\n            cpu_buffer = self._consolidated_cpu_weights[layer_idx][dtype]\n            cpu_buffer[offset : offset + numel].copy_(\n                loaded_weight.to(dtype=dtype).flatten()\n            )\n\n            # If this layer is currently on GPU, update the live parameter.\n            if layer_idx in self._gpu_layers:\n                target = self.get_target_with_name(name)\n                target.data.copy_(loaded_weight.to(dtype=target.dtype))\n\n            updated_names.add(name)\n\n        return updated_names\n\n    def iter_cpu_weights(self):\n        \"\"\"Yield (name, tensor) pairs from consolidated CPU buffers.\n\n     
   This reconstructs the original weight tensors (with correct shapes)\n        from the flat CPU buffers using stored metadata.  Unlike\n        model.named_parameters(), which returns (1,) placeholders\n        when offload is enabled, this method returns the real weights and\n        can be used for checksum computation.\n        \"\"\"\n        for layer_idx in sorted(self._weight_metadata):\n            for name, meta in self._weight_metadata[layer_idx].items():\n                dtype = meta[\"dtype\"]\n                offset = meta[\"offset\"]\n                numel = meta[\"numel\"]\n                shape = meta[\"shape\"]\n                cpu_buffer = self._consolidated_cpu_weights[layer_idx][dtype]\n                yield name, cpu_buffer[offset : offset + numel].reshape(shape)\n\n    def register_forward_hooks(self) -> None:\n        if not self.enabled:\n            return\n\n        layers = getattr(self.model, self.layers_attr_str)\n\n        def make_pre_hook(i):\n            def hook(module, input):\n                # wait only for the current layer if it's being prefetched\n                if i == 0:\n                    self.prepare_for_next_req(non_blocking=False)\n                if i in self._prefetch_events:\n                    torch.get_device_module().current_stream().wait_event(\n                        self._prefetch_events[i]\n                    )\n\n                # trigger batch prefetch (i + prefetch_size ~ i + 2 * prefetch_size) if needed\n                if i % self.prefetch_size == 0:\n                    for j in range(i + self.prefetch_size, i + 2 * self.prefetch_size):\n                        layer_to_prefetch = j % self.num_layers\n                        self.prefetch_layer(layer_to_prefetch, non_blocking=True)\n\n            return hook\n\n        def make_post_hook(i):\n            def hook(module, input, output):\n                # Previously we waited here until the copy stream for the next layer had finished.\n                # 
Now, for any prefetch_size, this hook only releases the finished layer; the pre-hook waits on the prefetch event of the layer about to run.\n                self.release_layer(i)\n\n            return hook\n\n        # register prefetch & release hooks for each layer\n        self._forward_hooks.clear()\n        for i, layer in enumerate(layers):\n            pre_hook_handle = layer.register_forward_pre_hook(make_pre_hook(i))\n            post_hook_handle = layer.register_forward_hook(make_post_hook(i))\n            self._forward_hooks.extend([pre_hook_handle, post_hook_handle])\n\n    def remove_forward_hooks(self) -> None:\n        \"\"\"Remove all registered forward hooks.\"\"\"\n        for hook_handle in self._forward_hooks:\n            hook_handle.remove()\n        self._forward_hooks.clear()\n\n\nclass OffloadableDiTMixin:\n    \"\"\"\n    A mixin that registers forward hooks for a DiT to enable layerwise offload\n    \"\"\"\n\n    # the list of names of a DiT's layers/blocks\n    layer_names: List[str]\n    layerwise_offload_managers: list[LayerwiseOffloadManager] = []\n\n    def configure_layerwise_offload(self, server_args: ServerArgs):\n        self.layerwise_offload_managers = []\n        for layer_name in self.layer_names:\n            # a manager per layer-list\n            module_list = getattr(self, layer_name, None)\n            if module_list is None or not isinstance(module_list, torch.nn.ModuleList):\n                continue\n\n            num_layers = len(module_list)\n            if server_args.dit_offload_prefetch_size < 1.0:\n                prefetch_size = 1 + int(\n                    round(server_args.dit_offload_prefetch_size * (num_layers - 1))\n                )\n            else:\n                prefetch_size = int(server_args.dit_offload_prefetch_size)\n\n            manager = LayerwiseOffloadManager(\n                model=self,\n                layers_attr_str=layer_name,\n                num_layers=num_layers,\n                enabled=True,\n              
  
pin_cpu_memory=server_args.pin_cpu_memory,\n                prefetch_size=prefetch_size,\n            )\n            self.layerwise_offload_managers.append(manager)\n\n        logger.info(\n            f\"Enabled layerwise offload for {self.__class__.__name__} on modules: {self.layer_names}\"\n        )\n\n    def prepare_for_next_req(self):\n        if self.layerwise_offload_managers is None:\n            return\n        for manager in self.layerwise_offload_managers:\n            manager.prepare_for_next_req(non_blocking=True)\n\n    def disable_offload(self) -> None:\n        \"\"\"Disable layerwise offload: load all layers to GPU and remove hooks.\"\"\"\n        if self.layerwise_offload_managers is None:\n            return\n        for manager in self.layerwise_offload_managers:\n            if manager.enabled:\n                manager.remove_forward_hooks()\n                manager.load_all_layers()\n\n    def enable_offload(self) -> None:\n        \"\"\"Re-enable layerwise offload: sync weights to CPU, release layers, and restore hooks.\"\"\"\n        if self.layerwise_offload_managers is None:\n            return\n        for manager in self.layerwise_offload_managers:\n            if manager.enabled:\n                manager.sync_all_layers_to_cpu()\n                manager.release_all()\n                manager.register_forward_hooks()\n\n\ndef iter_materialized_weights(module: torch.nn.Module):\n    \"\"\"Yield (name, tensor) pairs with materialized weights, even under offload.\n\n    When layerwise offload is active, module.named_parameters() returns\n    (1,) placeholders for offloaded layers.  
This function reads the\n    actual data from the offload manager's CPU buffers and chains it with\n    the non-offloaded parameters.\n    \"\"\"\n    offload_managers: list = []\n    if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers:\n        offload_managers = [m for m in module.layerwise_offload_managers if m.enabled]\n\n    if not offload_managers:\n        yield from module.named_parameters()\n        return\n\n    # Collect offloaded names and their real tensors from CPU buffers.\n    offloaded_names: set[str] = set()\n    for manager in offload_managers:\n        for name, tensor in manager.iter_cpu_weights():\n            offloaded_names.add(name)\n            yield name, tensor\n\n    # Yield non-offloaded parameters (e.g. final norms, embeddings).\n    for name, param in module.named_parameters():\n        if name not in offloaded_names:\n            yield name, param\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/utils/logging_utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/logger.py\n\"\"\"Logging configuration for sglang.multimodal_gen.\"\"\"\n\nimport argparse\nimport contextlib\nimport dataclasses\nimport datetime\nimport inspect\nimport logging\nimport os\nimport sys\nimport time\nfrom contextlib import contextmanager\nfrom enum import Enum\nfrom functools import lru_cache, partial\nfrom logging import Logger\nfrom types import MethodType\nfrom typing import Any, cast\n\nimport sglang.multimodal_gen.envs as envs\n\nSGLANG_DIFFUSION_LOGGING_LEVEL = envs.SGLANG_DIFFUSION_LOGGING_LEVEL\nSGLANG_DIFFUSION_LOGGING_PREFIX = envs.SGLANG_DIFFUSION_LOGGING_PREFIX\n\n# color\nCYAN = \"\\033[1;36m\"\nRED = \"\\033[91m\"\nGREEN = \"\\033[92m\"\nYELLOW = \"\\033[93m\"\nRESET = \"\\033[0;0m\"\n\n_FORMAT = (\n    f\"{SGLANG_DIFFUSION_LOGGING_PREFIX}%(levelname)s %(asctime)s \"\n    \"[%(filename)s: %(lineno)d] %(message)s\"\n)\n\n# _FORMAT = \"[%(asctime)s] %(message)s\"\n_DATE_FORMAT = \"%m-%d %H:%M:%S\"\n\nDEFAULT_LOGGING_CONFIG = {\n    \"formatters\": {\n        \"sgl_diffusion\": {\n            \"class\": \"sglang.multimodal_gen.runtime.utils.logging_utils.ColoredFormatter\",\n            \"datefmt\": _DATE_FORMAT,\n            \"format\": _FORMAT,\n        },\n    },\n    \"handlers\": {\n        \"sgl_diffusion\": {\n            \"class\": \"logging.StreamHandler\",\n            \"formatter\": \"sgl_diffusion\",\n            \"level\": SGLANG_DIFFUSION_LOGGING_LEVEL,\n            \"stream\": \"ext://sys.stdout\",\n        },\n    },\n    \"loggers\": {\n        \"sgl_diffusion\": {\n            \"handlers\": [\"sgl_diffusion\"],\n            \"level\": \"WARNING\",\n            \"propagate\": False,\n        },\n    },\n    \"root\": {\n        \"handlers\": [\"sgl_diffusion\"],\n        \"level\": \"DEBUG\",\n    },\n    \"version\": 1,\n    
\"disable_existing_loggers\": False,\n}\n\n\nclass ColoredFormatter(logging.Formatter):\n    \"\"\"A logging formatter that adds color to log levels.\"\"\"\n\n    LEVEL_COLORS = {\n        logging.ERROR: RED,\n        logging.WARNING: YELLOW,\n    }\n\n    def format(self, record: logging.LogRecord) -> str:\n        \"\"\"Adds color to the log\"\"\"\n\n        formatted_message = super().format(record)\n\n        color = self.LEVEL_COLORS.get(record.levelno)\n        if color:\n            formatted_message = f\"{color}{formatted_message}{RESET}\"\n\n        return formatted_message\n\n\nclass SortedHelpFormatter(argparse.HelpFormatter):\n    \"\"\"SortedHelpFormatter that sorts arguments by their option strings.\"\"\"\n\n    def add_arguments(self, actions):\n        actions = sorted(actions, key=lambda x: x.option_strings)\n        super().add_arguments(actions)\n\n\n@lru_cache\ndef _print_info_once(logger: Logger, msg: str) -> None:\n    # Set the stacklevel to 2 to print the original caller's line info\n    logger.info(msg, stacklevel=2)\n\n\n@lru_cache\ndef _print_warning_once(logger: Logger, msg: str) -> None:\n    # Set the stacklevel to 2 to print the original caller's line info\n    logger.warning(msg, stacklevel=2)\n\n\ndef get_is_main_process():\n    try:\n        rank = int(os.environ[\"RANK\"])\n    except (KeyError, ValueError):\n        rank = 0\n    return rank == 0\n\n\ndef get_is_local_main_process():\n    try:\n        rank = int(os.environ[\"LOCAL_RANK\"])\n    except (KeyError, ValueError):\n        rank = 0\n    return rank == 0\n\n\ndef _log_process_aware(\n    server_log_level: int,\n    level: int,\n    logger_self: Logger,\n    msg: object,\n    *args: Any,\n    main_process_only: bool,\n    local_main_process_only: bool,\n    **kwargs: Any,\n) -> None:\n    \"\"\"Helper function to log a message if the process rank matches the criteria.\"\"\"\n    is_main_process = get_is_main_process()\n    is_local_main_process = 
get_is_local_main_process()\n    should_log = (\n        not main_process_only\n        and not local_main_process_only\n        or (main_process_only and is_main_process)\n        or (local_main_process_only and is_local_main_process)\n        or server_log_level <= logging.DEBUG\n    )\n\n    if should_log:\n        # stacklevel=3 to show the original caller's location,\n        # as this function is called by the patched methods.\n        if \"stacklevel\" in kwargs:\n            logger_self.log(level, msg, *args, **kwargs)\n        else:\n            logger_self.log(level, msg, *args, stacklevel=3, **kwargs)\n\n\nclass _SGLDiffusionLogger(Logger):\n    \"\"\"\n    Note:\n        This class is just to provide type information.\n        We actually patch the methods directly on the :class:`logging.Logger`\n        instance to avoid conflicting with other libraries such as\n        `intel_extension_for_pytorch.utils._logger`.\n    \"\"\"\n\n    def info_once(self, msg: str) -> None:\n        \"\"\"\n        As :meth:`info`, but subsequent calls with the same message\n        are silently dropped.\n        \"\"\"\n        _print_info_once(self, msg)\n\n    def warning_once(self, msg: str) -> None:\n        \"\"\"\n        As :meth:`warning`, but subsequent calls with the same message\n        are silently dropped.\n        \"\"\"\n        _print_warning_once(self, msg)\n\n    def info(  # type: ignore[override]\n        self,\n        msg: object,\n        *args: Any,\n        main_process_only: bool = True,\n        local_main_process_only: bool = True,\n        **kwargs: Any,\n    ) -> None: ...\n\n    def debug(  # type: ignore[override]\n        self,\n        msg: object,\n        *args: Any,\n        main_process_only: bool = True,\n        local_main_process_only: bool = True,\n        **kwargs: Any,\n    ) -> None: ...\n\n    def warning(  # type: ignore[override]\n        self,\n        msg: object,\n        *args: Any,\n        main_process_only: bool = 
False,\n        local_main_process_only: bool = True,\n        **kwargs: Any,\n    ) -> None: ...\n\n    def error(  # type: ignore[override]\n        self,\n        msg: object,\n        *args: Any,\n        main_process_only: bool = False,\n        local_main_process_only: bool = True,\n        **kwargs: Any,\n    ) -> None: ...\n\n\ndef init_logger(name: str) -> _SGLDiffusionLogger:\n    \"\"\"The main purpose of this function is to ensure that loggers are\n    retrieved in such a way that we can be sure the root sgl_diffusion logger has\n    already been configured.\"\"\"\n\n    logger = logging.getLogger(name)\n\n    server_log_level = logger.getEffectiveLevel()\n\n    # Patch instance methods\n    setattr(logger, \"info_once\", MethodType(_print_info_once, logger))\n    setattr(logger, \"warning_once\", MethodType(_print_warning_once, logger))\n\n    def _create_patched_method(\n        level: int,\n        main_process_only_default: bool,\n        local_main_process_only_default: bool,\n    ):\n        def _method(\n            self: Logger,\n            msg: object,\n            *args: Any,\n            main_process_only: bool = main_process_only_default,\n            local_main_process_only: bool = local_main_process_only_default,\n            **kwargs: Any,\n        ) -> None:\n            _log_process_aware(\n                server_log_level,\n                level,\n                self,\n                msg,\n                *args,\n                main_process_only=main_process_only,\n                local_main_process_only=local_main_process_only,\n                **kwargs,\n            )\n\n        return _method\n\n    setattr(\n        logger,\n        \"info\",\n        MethodType(_create_patched_method(logging.INFO, True, True), logger),\n    )\n    setattr(\n        logger,\n        \"debug\",\n        MethodType(_create_patched_method(logging.DEBUG, True, True), logger),\n    )\n    setattr(\n        logger,\n        \"warning\",\n        
MethodType(_create_patched_method(logging.WARNING, False, True), logger),\n    )\n    setattr(\n        logger,\n        \"error\",\n        MethodType(_create_patched_method(logging.ERROR, False, False), logger),\n    )\n\n    return cast(_SGLDiffusionLogger, logger)\n\n\nlogger = init_logger(__name__)\n\n\ndef _is_torch_tensor(obj: Any) -> tuple[bool, Any]:\n    \"\"\"Return (is_tensor, torch_module_or_None) without importing torch at module import time.\"\"\"\n    try:\n        import torch  # type: ignore\n\n        return isinstance(obj, torch.Tensor), torch\n    except Exception:\n        return False, None\n\n\ndef _sanitize_for_logging(obj: Any, key_hint: str | None = None) -> Any:\n    \"\"\"Recursively convert objects to JSON-serializable forms for concise logging.\n\n    Rules:\n    - Drop any field/dict key named 'param_names_mapping'.\n    - Render Enums using their value.\n    - Render torch.Tensor as a compact summary; if key name is 'scaling_factor', include stats.\n    - Dataclasses are expanded to dicts and sanitized recursively.\n    - Callables/functions are rendered as their qualified name.\n    - Redact sensitive fields like 'prompt' and 'negative_prompt' (only show length).\n    - Fallback to str(...) 
for unknown types.\n    \"\"\"\n    if obj is None or isinstance(obj, (str, int, float, bool)):\n        if key_hint in (\"prompt\", \"negative_prompt\"):\n            if isinstance(obj, str):\n                return f\"<redacted, len={len(obj)}>\"\n        return obj\n\n    if isinstance(obj, Enum):\n        return obj.value\n\n    is_tensor, torch_mod = _is_torch_tensor(obj)\n    if is_tensor:\n        try:\n            ten = obj.detach().cpu()\n            if key_hint == \"scaling_factor\":\n                stats = {\n                    \"shape\": list(ten.shape),\n                    \"dtype\": str(ten.dtype),\n                }\n                try:\n                    stats[\"min\"] = float(ten.min().item())\n                except Exception:\n                    pass\n                try:\n                    stats[\"max\"] = float(ten.max().item())\n                except Exception:\n                    pass\n                try:\n                    stats[\"mean\"] = float(ten.float().mean().item())\n                except Exception:\n                    pass\n                return {\"tensor\": \"scaling_factor\", **stats}\n            return {\"tensor\": True, \"shape\": list(ten.shape), \"dtype\": str(ten.dtype)}\n        except Exception:\n            return \"<tensor>\"\n\n    if dataclasses.is_dataclass(obj):\n        result: dict[str, Any] = {}\n        for f in dataclasses.fields(obj):\n            if not f.repr:\n                continue\n            name = f.name\n            if \"names_mapping\" in name:\n                continue\n            try:\n                value = getattr(obj, name)\n            except Exception:\n                continue\n            result[name] = _sanitize_for_logging(value, key_hint=name)\n        return result\n\n    if isinstance(obj, dict):\n        result_dict: dict[str, Any] = {}\n        for k, v in obj.items():\n            try:\n                key_str = str(k)\n            except Exception:\n               
 key_str = \"<key>\"\n            if key_str == \"param_names_mapping\":\n                continue\n            result_dict[key_str] = _sanitize_for_logging(v, key_hint=key_str)\n        return result_dict\n\n    if isinstance(obj, (list, tuple, set)):\n        return [_sanitize_for_logging(x, key_hint=key_hint) for x in obj]\n\n    try:\n        if inspect.isroutine(obj) or inspect.isclass(obj):\n            module = getattr(obj, \"__module__\", \"\")\n            qn = getattr(obj, \"__qualname__\", getattr(obj, \"__name__\", \"<callable>\"))\n            return f\"{module}.{qn}\" if module else qn\n    except Exception:\n        pass\n\n    try:\n        return str(obj)\n    except Exception:\n        return \"<unserializable>\"\n\n\ndef _trace_calls(log_path, root_dir, frame, event, arg=None):\n    if event in [\"call\", \"return\"]:\n        # Extract the filename, line number, function name, and the code object\n        filename = frame.f_code.co_filename\n        lineno = frame.f_lineno\n        func_name = frame.f_code.co_name\n        if not filename.startswith(root_dir):\n            # only log the functions in the sgl_diffusion root_dir\n            return\n        # Log every function call or return\n        try:\n            last_frame = frame.f_back\n            if last_frame is not None:\n                last_filename = last_frame.f_code.co_filename\n                last_lineno = last_frame.f_lineno\n                last_func_name = last_frame.f_code.co_name\n            else:\n                # initial frame\n                last_filename = \"\"\n                last_lineno = 0\n                last_func_name = \"\"\n            with open(log_path, \"a\") as f:\n                ts = datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S.%f\")\n                if event == \"call\":\n                    f.write(\n                        f\"{ts} Call to\"\n                        f\" {func_name} in {filename}:{lineno}\"\n                        f\" from 
{last_func_name} in {last_filename}:\"\n                        f\"{last_lineno}\\n\"\n                    )\n                else:\n                    f.write(\n                        f\"{ts} Return from\"\n                        f\" {func_name} in {filename}:{lineno}\"\n                        f\" to {last_func_name} in {last_filename}:\"\n                        f\"{last_lineno}\\n\"\n                    )\n        except NameError:\n            # modules are deleted during shutdown\n            pass\n    return partial(_trace_calls, log_path, root_dir)\n\n\ndef enable_trace_function_call(log_file_path: str, root_dir: str | None = None):\n    \"\"\"\n    Enable tracing of every function call in code under `root_dir`.\n    This is useful for debugging hangs or crashes.\n    `log_file_path` is the path to the log file.\n    `root_dir` is the root directory of the code to trace. If None, it is the\n    sgl_diffusion root directory.\n\n    Note that this call is thread-level, any threads calling this function\n    will have the trace enabled. Other threads will not be affected.\n    \"\"\"\n    logger.warning(\n        \"SGLANG_DIFFUSION_TRACE_FUNCTION is enabled. It will record every\"\n        \" function executed by Python. This will slow down the code. 
It \"\n        \"is suggested to be used for debugging hang or crashes only.\"\n    )\n    logger.info(\"Trace frame log is saved to %s\", log_file_path)\n    if root_dir is None:\n        # by default, this is the sgl_diffusion root directory\n        root_dir = os.path.dirname(os.path.dirname(__file__))\n    sys.settrace(partial(_trace_calls, log_file_path, root_dir))\n\n\ndef set_uvicorn_logging_configs():\n    from uvicorn.config import LOGGING_CONFIG\n\n    LOGGING_CONFIG[\"formatters\"][\"default\"][\n        \"fmt\"\n    ] = \"[%(asctime)s] %(levelprefix)s %(message)s\"\n    LOGGING_CONFIG[\"formatters\"][\"default\"][\"datefmt\"] = \"%Y-%m-%d %H:%M:%S\"\n    LOGGING_CONFIG[\"formatters\"][\"access\"][\n        \"fmt\"\n    ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - \"%(request_line)s\" %(status_code)s'\n    LOGGING_CONFIG[\"formatters\"][\"access\"][\"datefmt\"] = \"%Y-%m-%d %H:%M:%S\"\n\n\ndef configure_logger(server_args, prefix: str = \"\"):\n    log_format = f\"[%(asctime)s{prefix}] %(message)s\"\n    datefmt = \"%m-%d %H:%M:%S\"\n\n    formatter = ColoredFormatter(log_format, datefmt=datefmt)\n    handler = logging.StreamHandler(sys.stdout)\n    handler.setFormatter(formatter)\n\n    root = logging.getLogger()\n    root.handlers.clear()\n    root.addHandler(handler)\n    root.setLevel(getattr(logging, server_args.log_level.upper()))\n\n    set_uvicorn_logging_configs()\n\n\n@lru_cache(maxsize=1)\ndef get_log_level() -> int:\n    root = logging.getLogger()\n    return root.level\n\n\ndef suppress_loggers(loggers_to_suppress: list[str], level: int = logging.WARNING):\n    original_levels = {}\n\n    for logger_name in loggers_to_suppress:\n        logger = logging.getLogger(logger_name)\n        original_levels[logger_name] = logger.level\n        logger.setLevel(level)\n\n    return original_levels\n\n\ndef globally_suppress_loggers():\n    # globally suppress some obsessive loggers\n    target_names = [\n        \"imageio\",\n        
\"imageio_ffmpeg\",\n        \"PIL\",\n        \"PIL_Image\",\n        \"python_multipart.multipart\",\n        \"filelock\",\n        \"urllib3\",\n        \"httpx\",\n        \"httpcore\",\n    ]\n\n    for name in target_names:\n        logging.getLogger(name).setLevel(logging.ERROR)\n\n\n# source: https://github.com/vllm-project/vllm/blob/a11f4a81e027efd9ef783b943489c222950ac989/vllm/utils/system_utils.py#L60\n@contextlib.contextmanager\ndef suppress_stdout():\n    \"\"\"\n    Suppress stdout from C libraries at the file descriptor level.\n\n    Only suppresses stdout, not stderr, to preserve error messages.\n    Example:\n        with suppress_stdout():\n            # C library calls that would normally print to stdout\n            torch.distributed.new_group(ranks, backend=\"gloo\")\n    \"\"\"\n    # Don't suppress if logging level is DEBUG\n\n    stdout_fd = sys.stdout.fileno()\n    stdout_dup = os.dup(stdout_fd)\n    devnull_fd = os.open(os.devnull, os.O_WRONLY)\n\n    try:\n        sys.stdout.flush()\n        os.dup2(devnull_fd, stdout_fd)\n        yield\n    finally:\n        sys.stdout.flush()\n        os.dup2(stdout_dup, stdout_fd)\n        os.close(stdout_dup)\n        os.close(devnull_fd)\n\n\nclass GenerationTimer:\n    def __init__(self):\n        self.start_time = 0.0\n        self.end_time = 0.0\n        self.duration = 0.0\n\n\n@contextmanager\ndef log_generation_timer(\n    logger: logging.Logger,\n    prompt: str,\n    request_idx: int | None = None,\n    total_requests: int | None = None,\n):\n    if request_idx is not None and total_requests is not None:\n        logger.info(\n            \"Processing prompt %d/%d: %s\",\n            request_idx,\n            total_requests,\n            _sanitize_for_logging(prompt, key_hint=\"prompt\"),\n        )\n\n    timer = GenerationTimer()\n    timer.start_time = time.perf_counter()\n    try:\n        yield timer\n        timer.end_time = time.perf_counter()\n        timer.duration = timer.end_time 
- timer.start_time\n        logger.info(\n            f\"Pixel data generated successfully in {GREEN}%.2f{RESET} seconds\",\n            timer.duration,\n        )\n    except Exception as e:\n        if request_idx is not None:\n            logger.error(\n                \"Failed to generate output for prompt %d: %s\",\n                request_idx,\n                e,\n                exc_info=True,\n            )\n        else:\n            logger.error(\n                f\"Failed to generate output for prompt: {e}\",\n                exc_info=True,\n            )\n        raise\n\n\ndef log_batch_completion(\n    logger: logging.Logger, num_outputs: int, total_time: float\n) -> None:\n    logger.info(\n        f\"Completed batch processing. Generated %d outputs in {GREEN}%.2f{RESET} seconds\",\n        num_outputs,\n        total_time,\n    )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/utils/mesh3d_utils.py",
    "content": "\"\"\"Adapted from Hunyuan3D-2: https://github.com/Tencent/Hunyuan3D-2\"\"\"\n\nfrom __future__ import annotations\n\nimport math\nfrom typing import Any, List, Optional, Tuple, Union\n\nimport cv2\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport trimesh\nfrom einops import rearrange, repeat\nfrom PIL import Image\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n# Import C++ mesh processor extension\nfrom sglang.multimodal_gen.csrc.render.mesh_processor import meshVerticeInpaint\n\n\ndef transform_pos(\n    mtx: Union[np.ndarray, torch.Tensor],\n    pos: torch.Tensor,\n    keepdim: bool = False,\n) -> torch.Tensor:\n    \"\"\"Transform positions by a matrix.\"\"\"\n    t_mtx = torch.from_numpy(mtx).to(pos.device) if isinstance(mtx, np.ndarray) else mtx\n\n    if pos.shape[-1] == 3:\n        posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).to(pos.device)], axis=1)\n    else:\n        posw = pos\n\n    if keepdim:\n        return torch.matmul(posw, t_mtx.t())[...]\n    else:\n        return torch.matmul(posw, t_mtx.t())[None, ...]\n\n\ndef get_mv_matrix(\n    elev: float,\n    azim: float,\n    camera_distance: float,\n    center: Optional[np.ndarray] = None,\n) -> np.ndarray:\n    \"\"\"Compute model-view matrix from camera parameters.\"\"\"\n    elev = -elev\n    azim += 90\n\n    elev_rad = math.radians(elev)\n    azim_rad = math.radians(azim)\n\n    camera_position = np.array(\n        [\n            camera_distance * math.cos(elev_rad) * math.cos(azim_rad),\n            camera_distance * math.cos(elev_rad) * math.sin(azim_rad),\n            camera_distance * math.sin(elev_rad),\n        ]\n    )\n\n    if center is None:\n        center = np.array([0, 0, 0])\n    else:\n        center = np.array(center)\n\n    lookat = center - camera_position\n    lookat = lookat / np.linalg.norm(lookat)\n\n    up = np.array([0, 0, 1.0])\n    right = np.cross(lookat, 
up)\n    right = right / np.linalg.norm(right)\n    up = np.cross(right, lookat)\n    up = up / np.linalg.norm(up)\n\n    c2w = np.concatenate(\n        [np.stack([right, up, -lookat], axis=-1), camera_position[:, None]], axis=-1\n    )\n\n    w2c = np.zeros((4, 4))\n    w2c[:3, :3] = np.transpose(c2w[:3, :3], (1, 0))\n    w2c[:3, 3:] = -np.matmul(np.transpose(c2w[:3, :3], (1, 0)), c2w[:3, 3:])\n    w2c[3, 3] = 1.0\n\n    return w2c.astype(np.float32)\n\n\ndef get_orthographic_projection_matrix(\n    left: float = -1,\n    right: float = 1,\n    bottom: float = -1,\n    top: float = 1,\n    near: float = 0,\n    far: float = 2,\n) -> np.ndarray:\n    \"\"\"Compute orthographic projection matrix.\"\"\"\n    ortho_matrix = np.eye(4, dtype=np.float32)\n    ortho_matrix[0, 0] = 2 / (right - left)\n    ortho_matrix[1, 1] = 2 / (top - bottom)\n    ortho_matrix[2, 2] = -2 / (far - near)\n    ortho_matrix[0, 3] = -(right + left) / (right - left)\n    ortho_matrix[1, 3] = -(top + bottom) / (top - bottom)\n    ortho_matrix[2, 3] = -(far + near) / (far - near)\n    return ortho_matrix\n\n\ndef get_perspective_projection_matrix(\n    fovy: float,\n    aspect_wh: float,\n    near: float,\n    far: float,\n) -> np.ndarray:\n    \"\"\"Compute perspective projection matrix.\"\"\"\n    fovy_rad = math.radians(fovy)\n    return np.array(\n        [\n            [1.0 / (math.tan(fovy_rad / 2.0) * aspect_wh), 0, 0, 0],\n            [0, 1.0 / math.tan(fovy_rad / 2.0), 0, 0],\n            [0, 0, -(far + near) / (far - near), -2.0 * far * near / (far - near)],\n            [0, 0, -1, 0],\n        ]\n    ).astype(np.float32)\n\n\ndef export_to_trimesh(mesh_output: Any) -> Any:\n    \"\"\"Convert mesh output to trimesh format.\"\"\"\n    if isinstance(mesh_output, list):\n        outputs = []\n        for mesh in mesh_output:\n            if mesh is None:\n                outputs.append(None)\n            else:\n                # Reverse face winding\n                mesh.mesh_f = 
mesh.mesh_f[:, ::-1]\n                mesh_obj = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)\n                outputs.append(mesh_obj)\n        return outputs\n    else:\n        mesh_output.mesh_f = mesh_output.mesh_f[:, ::-1]\n        return trimesh.Trimesh(mesh_output.mesh_v, mesh_output.mesh_f)\n\n\ndef mesh_uv_wrap(mesh: Any) -> Any:\n    \"\"\"Apply UV unwrapping to mesh. In-place like native Hunyuan3D-2 for same layout.\"\"\"\n    try:\n        import xatlas\n    except ImportError:\n        logger.warning(\"xatlas not available, skipping UV unwrap\")\n        return mesh\n\n    if isinstance(mesh, trimesh.Scene):\n        mesh = mesh.dump(concatenate=True)\n\n    if len(mesh.faces) > 500000000:\n        raise ValueError(\n            \"The mesh has more than 500,000,000 faces, which is not supported.\"\n        )\n\n    vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)\n\n    mesh.vertices = mesh.vertices[vmapping]\n    mesh.faces = indices\n    if not hasattr(mesh.visual, \"uv\"):\n        mesh.visual = trimesh.visual.TextureVisuals(\n            uv=uvs, material=trimesh.visual.material.SimpleMaterial()\n        )\n    else:\n        mesh.visual.uv = uvs\n\n    return mesh\n\n\ndef stride_from_shape(shape: Tuple[int, ...]) -> List[int]:\n    \"\"\"Compute stride from shape for scatter operations.\"\"\"\n    stride = [1]\n    for x in reversed(shape[1:]):\n        stride.append(stride[-1] * x)\n    return list(reversed(stride))\n\n\ndef scatter_add_nd_with_count(\n    input: torch.Tensor,\n    count: torch.Tensor,\n    indices: torch.Tensor,\n    values: torch.Tensor,\n    weights: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Scatter add with counting for texture baking.\"\"\"\n    D = indices.shape[-1]\n    C = input.shape[-1]\n    size = input.shape[:-1]\n    stride = stride_from_shape(size)\n\n    assert len(size) == D\n\n    input = input.view(-1, C)\n    count = count.view(-1, 1)\n\n    
flatten_indices = (\n        indices * torch.tensor(stride, dtype=torch.long, device=indices.device)\n    ).sum(-1)\n\n    if weights is None:\n        weights = torch.ones_like(values[..., :1])\n\n    input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)\n    count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)\n\n    return input.view(*size, C), count.view(*size, 1)\n\n\ndef linear_grid_put_2d(\n    H: int,\n    W: int,\n    coords: torch.Tensor,\n    values: torch.Tensor,\n    return_count: bool = False,\n) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n    \"\"\"Put values into a 2D grid using linear interpolation.\"\"\"\n    C = values.shape[-1]\n\n    indices = coords * torch.tensor(\n        [H - 1, W - 1], dtype=torch.float32, device=coords.device\n    )\n    indices_00 = indices.floor().long()\n    indices_00[:, 0].clamp_(0, H - 2)\n    indices_00[:, 1].clamp_(0, W - 2)\n    indices_01 = indices_00 + torch.tensor(\n        [0, 1], dtype=torch.long, device=indices.device\n    )\n    indices_10 = indices_00 + torch.tensor(\n        [1, 0], dtype=torch.long, device=indices.device\n    )\n    indices_11 = indices_00 + torch.tensor(\n        [1, 1], dtype=torch.long, device=indices.device\n    )\n\n    h = indices[..., 0] - indices_00[..., 0].float()\n    w = indices[..., 1] - indices_00[..., 1].float()\n    w_00 = (1 - h) * (1 - w)\n    w_01 = (1 - h) * w\n    w_10 = h * (1 - w)\n    w_11 = h * w\n\n    result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)\n    count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)\n    weights = torch.ones_like(values[..., :1])\n\n    result, count = scatter_add_nd_with_count(\n        result,\n        count,\n        indices_00,\n        values * w_00.unsqueeze(1),\n        weights * w_00.unsqueeze(1),\n    )\n    result, count = scatter_add_nd_with_count(\n        result,\n        count,\n        indices_01,\n        values * w_01.unsqueeze(1),\n      
  weights * w_01.unsqueeze(1),\n    )\n    result, count = scatter_add_nd_with_count(\n        result,\n        count,\n        indices_10,\n        values * w_10.unsqueeze(1),\n        weights * w_10.unsqueeze(1),\n    )\n    result, count = scatter_add_nd_with_count(\n        result,\n        count,\n        indices_11,\n        values * w_11.unsqueeze(1),\n        weights * w_11.unsqueeze(1),\n    )\n\n    if return_count:\n        return result, count\n\n    mask = count.squeeze(-1) > 0\n    result[mask] = result[mask] / count[mask].repeat(1, C)\n\n    return result\n\n\nclass MeshRender:\n    \"\"\"Mesh renderer using CUDA rasterization for texture generation.\"\"\"\n\n    def __init__(\n        self,\n        camera_distance: float = 1.45,\n        camera_type: str = \"orth\",\n        default_resolution: int = 1024,\n        texture_size: int = 1024,\n        bake_mode: str = \"linear\",\n        device: str = \"cuda\",\n    ):\n        \"\"\"Initialize the mesh renderer.\"\"\"\n        self.device = device\n\n        self.set_default_render_resolution(default_resolution)\n        self.set_default_texture_resolution(texture_size)\n\n        self.camera_distance = camera_distance\n        self.camera_type = camera_type\n        self.bake_angle_thres = 75\n        self.bake_unreliable_kernel_size = int(\n            (2 / 512) * max(self.default_resolution[0], self.default_resolution[1])\n        )\n        self.bake_mode = bake_mode\n\n        # Set up camera projection matrix\n        if camera_type == \"orth\":\n            self.ortho_scale = 1.2\n            self.camera_proj_mat = get_orthographic_projection_matrix(\n                left=-self.ortho_scale * 0.5,\n                right=self.ortho_scale * 0.5,\n                bottom=-self.ortho_scale * 0.5,\n                top=self.ortho_scale * 0.5,\n                near=0.1,\n                far=100,\n            )\n        elif camera_type == \"perspective\":\n            self.camera_proj_mat = 
get_perspective_projection_matrix(\n                49.13,\n                self.default_resolution[1] / self.default_resolution[0],\n                0.01,\n                100.0,\n            )\n        else:\n            raise ValueError(f\"Unknown camera type: {camera_type}\")\n\n        # Mesh data\n        self.vtx_pos = None\n        self.pos_idx = None\n        self.vtx_uv = None\n        self.uv_idx = None\n        self.tex = None\n        self.mesh_copy = None\n        self.scale_factor = 1.0\n\n    def set_default_render_resolution(\n        self, default_resolution: Union[int, Tuple[int, int]]\n    ):\n        \"\"\"Set default rendering resolution.\"\"\"\n        if isinstance(default_resolution, int):\n            default_resolution = (default_resolution, default_resolution)\n        self.default_resolution = default_resolution\n\n    def set_default_texture_resolution(self, texture_size: Union[int, Tuple[int, int]]):\n        \"\"\"Set default texture resolution.\"\"\"\n        if isinstance(texture_size, int):\n            texture_size = (texture_size, texture_size)\n        self.texture_size = texture_size\n\n    def _rasterize(\n        self,\n        pos_clip: torch.Tensor,\n        tri: torch.Tensor,\n        resolution: Tuple[int, int],\n    ) -> torch.Tensor:\n        \"\"\"Rasterize using CUDA rasterizer.\"\"\"\n        from sglang.multimodal_gen.csrc.render.hunyuan3d_rasterizer import rasterize\n\n        if pos_clip.dim() == 2:\n            pos_clip = pos_clip.unsqueeze(0)\n\n        findices, barycentric = rasterize(pos_clip, tri, resolution)\n        rast_out = torch.cat((barycentric, findices.unsqueeze(-1).float()), dim=-1)\n        rast_out = rast_out.unsqueeze(0)\n        return rast_out\n\n    def _interpolate(\n        self,\n        attr: torch.Tensor,\n        rast_out: torch.Tensor,\n        tri: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Interpolate vertex attributes.\"\"\"\n        from 
sglang.multimodal_gen.csrc.render.hunyuan3d_rasterizer import interpolate\n\n        barycentric = rast_out[0, ..., :-1]\n        findices = rast_out[0, ..., -1].int()\n\n        if attr.dim() == 2:\n            attr = attr.unsqueeze(0)\n\n        result = interpolate(attr, findices, barycentric, tri)\n        return result\n\n    def load_mesh(\n        self,\n        mesh: Union[trimesh.Trimesh, trimesh.Scene],\n        scale_factor: float = 1.15,\n        auto_center: bool = True,\n    ):\n        \"\"\"Load a mesh for rendering.\"\"\"\n        if isinstance(mesh, trimesh.Scene):\n            mesh = mesh.dump(concatenate=True)\n\n        self.mesh_copy = mesh.copy()\n\n        vtx_pos = mesh.vertices.astype(np.float32)\n        pos_idx = mesh.faces.astype(np.int32)\n\n        # Get UV coordinates if available\n        if hasattr(mesh.visual, \"uv\") and mesh.visual.uv is not None:\n            vtx_uv = mesh.visual.uv.astype(np.float32)\n            uv_idx = pos_idx.copy()\n        else:\n            vtx_uv = None\n            uv_idx = None\n\n        self.vtx_pos = torch.from_numpy(vtx_pos).to(self.device).float()\n        self.pos_idx = torch.from_numpy(pos_idx).to(self.device).to(torch.int32)\n\n        if vtx_uv is not None and uv_idx is not None:\n            self.vtx_uv = torch.from_numpy(vtx_uv).to(self.device).float()\n            self.uv_idx = torch.from_numpy(uv_idx).to(self.device).to(torch.int32)\n        else:\n            self.vtx_uv = None\n            self.uv_idx = None\n\n        # Coordinate transformation (Y-up to Z-up)\n        self.vtx_pos[:, [0, 1]] = -self.vtx_pos[:, [0, 1]]\n        self.vtx_pos[:, [1, 2]] = self.vtx_pos[:, [2, 1]]\n        if self.vtx_uv is not None:\n            self.vtx_uv[:, 1] = 1.0 - self.vtx_uv[:, 1]\n\n        if auto_center:\n            max_bb = (self.vtx_pos - 0).max(0)[0]\n            min_bb = (self.vtx_pos - 0).min(0)[0]\n            center = (max_bb + min_bb) / 2\n            scale = torch.norm(self.vtx_pos - 
center, dim=1).max() * 2.0\n            self.vtx_pos = (self.vtx_pos - center) * (scale_factor / float(scale))\n            self.scale_factor = scale_factor\n\n    def save_mesh(self) -> trimesh.Trimesh:\n        \"\"\"Save mesh with current texture, reusing the original mesh object.\"\"\"\n        texture_data = self.get_texture()\n        texture_img = Image.fromarray((texture_data * 255).astype(np.uint8))\n\n        material = trimesh.visual.material.SimpleMaterial(\n            image=texture_img, diffuse=(255, 255, 255)\n        )\n        self.mesh_copy.visual = trimesh.visual.TextureVisuals(\n            uv=self.mesh_copy.visual.uv, image=texture_img, material=material\n        )\n        return self.mesh_copy\n\n    def get_mesh(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:\n        \"\"\"Get mesh data with inverse coordinate transformation.\"\"\"\n        vtx_pos = self.vtx_pos.cpu().numpy().copy()\n        pos_idx = self.pos_idx.cpu().numpy()\n\n        # Inverse coordinate transformation\n        vtx_pos[:, [1, 2]] = vtx_pos[:, [2, 1]]\n        vtx_pos[:, [0, 1]] = -vtx_pos[:, [0, 1]]\n\n        if self.vtx_uv is not None:\n            vtx_uv = self.vtx_uv.cpu().numpy().copy()\n            vtx_uv[:, 1] = 1.0 - vtx_uv[:, 1]\n            uv_idx = self.uv_idx.cpu().numpy()\n        else:\n            vtx_uv = None\n            uv_idx = None\n\n        return vtx_pos, pos_idx, vtx_uv, uv_idx\n\n    def set_texture(self, tex: Union[np.ndarray, torch.Tensor, Image.Image]):\n        \"\"\"Set texture for the mesh.\"\"\"\n        if isinstance(tex, np.ndarray):\n            if tex.max() <= 1.0:\n                tex = (tex * 255).astype(np.uint8)\n            tex = Image.fromarray(tex.astype(np.uint8))\n        elif isinstance(tex, torch.Tensor):\n            tex_np = tex.cpu().numpy()\n            if tex_np.max() <= 1.0:\n                tex_np = (tex_np * 255).astype(np.uint8)\n            tex = Image.fromarray(tex_np.astype(np.uint8))\n\n      
  tex = tex.resize(self.texture_size).convert(\"RGB\")\n        tex = np.array(tex) / 255.0\n        self.tex = torch.from_numpy(tex).to(self.device).float()\n\n    def get_texture(self) -> np.ndarray:\n        \"\"\"Get current texture as numpy array.\"\"\"\n        if self.tex is None:\n            return np.ones((*self.texture_size, 3), dtype=np.float32)\n        return self.tex.cpu().numpy()\n\n    def _get_pos_from_mvp(\n        self,\n        elev: float,\n        azim: float,\n        camera_distance: Optional[float] = None,\n        center: Optional[np.ndarray] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Get camera-space and clip-space positions.\"\"\"\n        proj = self.camera_proj_mat\n        r_mv = get_mv_matrix(\n            elev=elev,\n            azim=azim,\n            camera_distance=(\n                self.camera_distance if camera_distance is None else camera_distance\n            ),\n            center=center,\n        )\n\n        pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True)\n        pos_clip = transform_pos(proj, pos_camera)\n\n        return pos_camera, pos_clip\n\n    def render_normal(\n        self,\n        elev: float,\n        azim: float,\n        camera_distance: Optional[float] = None,\n        center: Optional[np.ndarray] = None,\n        resolution: Optional[Tuple[int, int]] = None,\n        bg_color: List[float] = [1, 1, 1],\n        use_abs_coor: bool = False,\n        normalize_rgb: bool = True,\n        return_type: str = \"th\",\n    ) -> Union[torch.Tensor, np.ndarray, Image.Image]:\n        \"\"\"Render normal map from a viewpoint.\"\"\"\n        pos_camera, pos_clip = self._get_pos_from_mvp(\n            elev, azim, camera_distance, center\n        )\n\n        if resolution is None:\n            resolution = self.default_resolution\n        if isinstance(resolution, (int, float)):\n            resolution = (int(resolution), int(resolution))\n\n        rast_out = 
self._rasterize(pos_clip, self.pos_idx, resolution)\n\n        # Compute face normals\n        if use_abs_coor:\n            mesh_triangles = self.vtx_pos[self.pos_idx[:, :3].long(), :]\n        else:\n            pos_camera_3d = pos_camera[:, :3] / pos_camera[:, 3:4]\n            mesh_triangles = pos_camera_3d[self.pos_idx[:, :3].long(), :]\n\n        face_normals = F.normalize(\n            torch.cross(\n                mesh_triangles[:, 1, :] - mesh_triangles[:, 0, :],\n                mesh_triangles[:, 2, :] - mesh_triangles[:, 0, :],\n                dim=-1,\n            ),\n            dim=-1,\n        )\n\n        # Compute vertex normals\n        vertex_normals = trimesh.geometry.mean_vertex_normals(\n            vertex_count=self.vtx_pos.shape[0],\n            faces=self.pos_idx.cpu().numpy(),\n            face_normals=face_normals.cpu().numpy(),\n        )\n        vertex_normals = (\n            torch.from_numpy(vertex_normals).float().to(self.device).contiguous()\n        )\n\n        # Interpolate normals\n        normal = self._interpolate(vertex_normals[None, ...], rast_out, self.pos_idx)\n\n        # Apply visibility mask\n        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)\n        bg_tensor = torch.tensor(bg_color, dtype=torch.float32, device=self.device)\n        normal = normal * visible_mask + bg_tensor * (1 - visible_mask)\n\n        if normalize_rgb:\n            normal = (normal + 1) * 0.5\n\n        image = normal[0, ...]\n\n        if return_type == \"np\":\n            image = image.cpu().numpy()\n        elif return_type == \"pl\":\n            image = image.cpu().numpy() * 255\n            image = Image.fromarray(image.astype(np.uint8))\n\n        return image\n\n    def render_position(\n        self,\n        elev: float,\n        azim: float,\n        camera_distance: Optional[float] = None,\n        center: Optional[np.ndarray] = None,\n        resolution: Optional[Tuple[int, int]] = None,\n        bg_color: List[float] = 
[1, 1, 1],\n        return_type: str = \"th\",\n    ) -> Union[torch.Tensor, np.ndarray, Image.Image]:\n        \"\"\"Render position map from a viewpoint.\"\"\"\n        pos_camera, pos_clip = self._get_pos_from_mvp(\n            elev, azim, camera_distance, center\n        )\n\n        if resolution is None:\n            resolution = self.default_resolution\n        if isinstance(resolution, (int, float)):\n            resolution = (int(resolution), int(resolution))\n\n        rast_out = self._rasterize(pos_clip, self.pos_idx, resolution)\n\n        # Position colors (normalized vertex positions)\n        tex_position = 0.5 - self.vtx_pos[:, :3] / self.scale_factor\n        tex_position = tex_position.contiguous()\n\n        # Interpolate positions\n        position = self._interpolate(tex_position[None, ...], rast_out, self.pos_idx)\n\n        # Apply visibility mask\n        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)\n        bg_tensor = torch.tensor(bg_color, dtype=torch.float32, device=self.device)\n        position = position * visible_mask + bg_tensor * (1 - visible_mask)\n\n        image = position[0, ...]\n\n        if return_type == \"np\":\n            image = image.cpu().numpy()\n        elif return_type == \"pl\":\n            image = image.cpu().numpy() * 255\n            image = Image.fromarray(image.astype(np.uint8))\n\n        return image\n\n    def render_normal_multiview(\n        self,\n        camera_elevs: List[float],\n        camera_azims: List[float],\n        use_abs_coor: bool = True,\n    ) -> List[Image.Image]:\n        \"\"\"Render normal maps from multiple viewpoints.\"\"\"\n        normal_maps = []\n        for elev, azim in zip(camera_elevs, camera_azims):\n            normal_map = self.render_normal(\n                elev, azim, use_abs_coor=use_abs_coor, return_type=\"pl\"\n            )\n            normal_maps.append(normal_map)\n        return normal_maps\n\n    def render_position_multiview(\n        self,\n        
camera_elevs: List[float],\n        camera_azims: List[float],\n    ) -> List[Image.Image]:\n        \"\"\"Render position maps from multiple viewpoints.\"\"\"\n        position_maps = []\n        for elev, azim in zip(camera_elevs, camera_azims):\n            position_map = self.render_position(elev, azim, return_type=\"pl\")\n            position_maps.append(position_map)\n        return position_maps\n\n    def _render_sketch_from_depth(self, depth_image: torch.Tensor) -> torch.Tensor:\n        \"\"\"Render sketch from depth using edge detection.\"\"\"\n        depth_image_np = depth_image.cpu().numpy()\n        depth_image_np = (depth_image_np * 255).astype(np.uint8)\n        depth_edges = cv2.Canny(depth_image_np, 30, 80)\n        sketch_image = (\n            torch.from_numpy(depth_edges).to(depth_image.device).float() / 255.0\n        )\n        sketch_image = sketch_image.unsqueeze(-1)\n        return sketch_image\n\n    def back_project(\n        self,\n        image: Union[Image.Image, np.ndarray, torch.Tensor],\n        elev: float,\n        azim: float,\n        camera_distance: Optional[float] = None,\n        center: Optional[np.ndarray] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"Back-project an image onto mesh UV space.\"\"\"\n        if isinstance(image, Image.Image):\n            image = torch.tensor(np.array(image) / 255.0)\n        elif isinstance(image, np.ndarray):\n            image = torch.tensor(image)\n        if image.dim() == 2:\n            image = image.unsqueeze(-1)\n        image = image.float().to(self.device)\n        resolution = image.shape[:2]\n        channel = image.shape[-1]\n\n        pos_camera, pos_clip = self._get_pos_from_mvp(\n            elev, azim, camera_distance, center\n        )\n\n        rast_out = self._rasterize(pos_clip, self.pos_idx, resolution)\n        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...]\n\n        # Compute vertex normals for angle-based 
weighting\n        pos_camera_3d = pos_camera[:, :3] / pos_camera[:, 3:4]\n        v0 = pos_camera_3d[self.pos_idx[:, 0].long(), :]\n        v1 = pos_camera_3d[self.pos_idx[:, 1].long(), :]\n        v2 = pos_camera_3d[self.pos_idx[:, 2].long(), :]\n        face_normals = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)\n\n        vertex_normals = trimesh.geometry.mean_vertex_normals(\n            vertex_count=self.vtx_pos.shape[0],\n            faces=self.pos_idx.cpu().numpy(),\n            face_normals=face_normals.cpu().numpy(),\n        )\n        vertex_normals = (\n            torch.from_numpy(vertex_normals).float().to(self.device).contiguous()\n        )\n\n        # Interpolate normals and UVs\n        normal = self._interpolate(vertex_normals[None, ...], rast_out, self.pos_idx)\n        normal = normal[0, ...]\n\n        if self.vtx_uv is not None:\n            uv = self._interpolate(self.vtx_uv[None, ...], rast_out, self.uv_idx)\n        else:\n            # No UV coordinates\n            texture = torch.zeros(\n                self.texture_size[1], self.texture_size[0], channel, device=self.device\n            )\n            cos_map = torch.zeros(\n                self.texture_size[1], self.texture_size[0], 1, device=self.device\n            )\n            boundary_map = torch.zeros_like(cos_map)\n            return texture, cos_map, boundary_map\n\n        # Compute depth for sketch\n        tex_depth = pos_camera_3d[:, 2].reshape(1, -1, 1).contiguous()\n        depth = self._interpolate(tex_depth, rast_out, self.pos_idx)[0, ...]\n        depth_masked = depth[visible_mask > 0]\n        if depth_masked.numel() > 0:\n            depth_max, depth_min = depth_masked.max(), depth_masked.min()\n            depth_normalized = (depth - depth_min) / (depth_max - depth_min + 1e-8)\n        else:\n            depth_normalized = depth\n        depth_image = depth_normalized * visible_mask\n\n        sketch_image = 
self._render_sketch_from_depth(depth_image)\n\n        # Cosine weighting\n        lookat = torch.tensor([[0, 0, -1]], device=self.device)\n        cos_image = torch.nn.functional.cosine_similarity(lookat, normal.view(-1, 3))\n        cos_image = cos_image.view(normal.shape[0], normal.shape[1], 1)\n\n        cos_thres = np.cos(self.bake_angle_thres / 180 * np.pi)\n        cos_image[cos_image < cos_thres] = 0\n\n        # Shrink visible mask\n        kernel_size = self.bake_unreliable_kernel_size * 2 + 1\n        kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32).to(\n            sketch_image.device\n        )\n\n        visible_mask_proc = visible_mask.permute(2, 0, 1).unsqueeze(0).float()\n        visible_mask_proc = F.conv2d(\n            1.0 - visible_mask_proc, kernel, padding=kernel_size // 2\n        )\n        visible_mask_proc = 1.0 - (visible_mask_proc > 0).float()\n        visible_mask_proc = visible_mask_proc.squeeze(0).permute(1, 2, 0)\n\n        sketch_proc = sketch_image.permute(2, 0, 1).unsqueeze(0)\n        sketch_proc = F.conv2d(sketch_proc, kernel, padding=kernel_size // 2)\n        sketch_proc = (sketch_proc > 0).float()\n        sketch_proc = sketch_proc.squeeze(0).permute(1, 2, 0)\n        visible_mask_proc = visible_mask_proc * (sketch_proc < 0.5)\n\n        cos_image[visible_mask_proc == 0] = 0\n\n        # Linear baking\n        proj_mask = (visible_mask_proc != 0).view(-1)\n        uv_flat = uv.squeeze(0).contiguous().view(-1, 2)[proj_mask]\n        image_flat = image.squeeze(0).contiguous().view(-1, channel)[proj_mask]\n        cos_flat = cos_image.contiguous().view(-1, 1)[proj_mask]\n        sketch_flat = sketch_image.contiguous().view(-1, 1)[proj_mask]\n\n        texture = linear_grid_put_2d(\n            self.texture_size[1], self.texture_size[0], uv_flat[..., [1, 0]], image_flat\n        )\n        cos_map = linear_grid_put_2d(\n            self.texture_size[1], self.texture_size[0], uv_flat[..., [1, 0]], 
cos_flat\n        )\n        boundary_map = linear_grid_put_2d(\n            self.texture_size[1],\n            self.texture_size[0],\n            uv_flat[..., [1, 0]],\n            sketch_flat,\n        )\n\n        return texture, cos_map, boundary_map\n\n    def bake_from_multiview(\n        self,\n        views: List[Image.Image],\n        camera_elevs: List[float],\n        camera_azims: List[float],\n        view_weights: List[float],\n        method: str = \"fast\",\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Bake texture from multiple views.\"\"\"\n        project_textures, project_weighted_cos_maps = [], []\n        bake_exp = 4\n\n        for view, camera_elev, camera_azim, weight in zip(\n            views, camera_elevs, camera_azims, view_weights\n        ):\n            project_texture, project_cos_map, _ = self.back_project(\n                view, camera_elev, camera_azim\n            )\n            project_cos_map = weight * (project_cos_map**bake_exp)\n            project_textures.append(project_texture)\n            project_weighted_cos_maps.append(project_cos_map)\n\n        if method == \"fast\":\n            texture, ori_trust_map = self.fast_bake_texture(\n                project_textures, project_weighted_cos_maps\n            )\n        else:\n            raise ValueError(f\"Unknown bake method: {method}\")\n\n        return texture, ori_trust_map > 1e-8\n\n    @torch.no_grad()\n    def fast_bake_texture(\n        self,\n        textures: List[torch.Tensor],\n        cos_maps: List[torch.Tensor],\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Fast texture baking by weighted averaging.\"\"\"\n        channel = textures[0].shape[-1]\n        texture_merge = torch.zeros(self.texture_size + (channel,)).to(self.device)\n        trust_map_merge = torch.zeros(self.texture_size + (1,)).to(self.device)\n\n        for texture, cos_map in zip(textures, cos_maps):\n            view_sum = (cos_map > 0).sum()\n            
painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()\n            if view_sum > 0 and painted_sum / view_sum > 0.99:\n                continue\n            texture_merge += texture * cos_map\n            trust_map_merge += cos_map\n\n        texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)\n        texture_merge = texture_merge.clamp(0.0, 1.0)\n\n        return texture_merge, trust_map_merge > 1e-8\n\n    def texture_inpaint(\n        self,\n        texture: torch.Tensor,\n        mask: Union[torch.Tensor, np.ndarray],\n    ) -> torch.Tensor:\n        \"\"\"Inpaint missing regions in UV texture using mesh-aware method.\"\"\"\n        if isinstance(texture, torch.Tensor):\n            texture_np = texture.cpu().numpy()\n        else:\n            texture_np = texture\n\n        if isinstance(mask, torch.Tensor):\n            mask_np = mask.cpu().numpy()\n        else:\n            mask_np = mask\n\n        # Ensure proper format\n        if texture_np.max() <= 1.0:\n            texture_np = texture_np.astype(np.float32)\n        else:\n            texture_np = (texture_np / 255.0).astype(np.float32)\n\n        if mask_np.ndim == 3:\n            mask_np = mask_np.squeeze(-1)\n        if mask_np.dtype == np.uint8:\n            mask_uint8 = mask_np\n        else:\n            mask_uint8 = ((mask_np > 0) * 255).astype(np.uint8)\n\n        # Get mesh data for mesh-aware inpainting\n        vtx_pos, pos_idx, vtx_uv, uv_idx = self.get_mesh()\n\n        if vtx_uv is not None and uv_idx is not None:\n            texture_np, mask_uint8 = meshVerticeInpaint(\n                texture_np, mask_uint8, vtx_pos, vtx_uv, pos_idx, uv_idx\n            )\n\n        # Final OpenCV inpainting for remaining holes\n        texture_uint8 = (texture_np * 255).astype(np.uint8)\n        inpaint_mask = 255 - mask_uint8\n        texture_inpainted = cv2.inpaint(texture_uint8, inpaint_mask, 3, cv2.INPAINT_NS)\n\n        return torch.from_numpy(texture_inpainted / 
255.0).float().to(self.device)\n\n    # Alias for compatibility\n    uv_inpaint = texture_inpaint\n\n\ndef array_to_tensor(np_array):\n    \"\"\"Convert numpy array to normalized tensor.\"\"\"\n    image_pt = torch.tensor(np_array).float()\n    image_pt = image_pt / 255 * 2 - 1\n    image_pt = rearrange(image_pt, \"h w c -> c h w\")\n    image_pts = repeat(image_pt, \"c h w -> b c h w\", b=1)\n    return image_pts\n\n\ndef recenter_image(image, border_ratio=0.2):\n    \"\"\"Recenter a PIL image, cropping to non-transparent content with a border.\"\"\"\n    from PIL import Image as PILImage\n\n    if image.mode == \"RGB\":\n        return image\n    elif image.mode == \"L\":\n        return image.convert(\"RGB\")\n    if image.mode != \"RGBA\":\n        image = image.convert(\"RGBA\")\n\n    alpha_channel = np.array(image)[:, :, 3]\n    non_zero_indices = np.argwhere(alpha_channel > 0)\n    if non_zero_indices.size == 0:\n        raise ValueError(\"Image is fully transparent\")\n\n    min_row, min_col = non_zero_indices.min(axis=0)\n    max_row, max_col = non_zero_indices.max(axis=0)\n\n    cropped_image = image.crop((min_col, min_row, max_col + 1, max_row + 1))\n\n    width, height = cropped_image.size\n    border_width = int(width * border_ratio)\n    border_height = int(height * border_ratio)\n\n    new_width = width + 2 * border_width\n    new_height = height + 2 * border_height\n    square_size = max(new_width, new_height)\n\n    new_image = PILImage.new(\"RGBA\", (square_size, square_size), (255, 255, 255, 0))\n\n    paste_x = (square_size - new_width) // 2 + border_width\n    paste_y = (square_size - new_height) // 2 + border_height\n    new_image.paste(cropped_image, (paste_x, paste_y))\n    return new_image\n\n\nclass ImageProcessorV2:\n    \"\"\"Image processor for Hunyuan3D single-view input.\"\"\"\n\n    # External module path aliases for compatibility with Hunyuan3D configs\n    _aliases = [\n        \"hy3dshape.preprocessors.ImageProcessorV2\",\n       
 \"hy3dgen.shapegen.preprocessors.ImageProcessorV2\",\n    ]\n\n    def __init__(self, size=512, border_ratio=None):\n        self.size = size\n        self.border_ratio = border_ratio\n\n    @staticmethod\n    def recenter(image, border_ratio: float = 0.2):\n        \"\"\"recenter an image to leave some empty space at the image border.\"\"\"\n\n        if image.shape[-1] == 4:\n            mask = image[..., 3]\n        else:\n            mask = np.ones_like(image[..., 0:1]) * 255\n            image = np.concatenate([image, mask], axis=-1)\n            mask = mask[..., 0]\n\n        height, width, channels = image.shape\n\n        size = max(height, width)\n        result = np.zeros((size, size, channels), dtype=np.uint8)\n\n        coords = np.nonzero(mask)\n        x_min, x_max = coords[0].min(), coords[0].max()\n        y_min, y_max = coords[1].min(), coords[1].max()\n        crop_h = x_max - x_min\n        crop_w = y_max - y_min\n        if crop_h == 0 or crop_w == 0:\n            raise ValueError(\"input image is empty\")\n        desired_size = int(size * (1 - border_ratio))\n        scale = desired_size / max(crop_h, crop_w)\n        scaled_h = int(crop_h * scale)\n        scaled_w = int(crop_w * scale)\n        x2_min = (size - scaled_h) // 2\n        x2_max = x2_min + scaled_h\n\n        y2_min = (size - scaled_w) // 2\n        y2_max = y2_min + scaled_w\n\n        result[x2_min:x2_max, y2_min:y2_max] = cv2.resize(\n            image[x_min:x_max, y_min:y_max],\n            (scaled_w, scaled_h),\n            interpolation=cv2.INTER_AREA,\n        )\n\n        bg = np.ones((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255\n\n        mask = result[..., 3:].astype(np.float32) / 255\n        result = result[..., :3] * mask + bg * (1 - mask)\n\n        mask = mask * 255\n        result = result.clip(0, 255).astype(np.uint8)\n        mask = mask.clip(0, 255).astype(np.uint8)\n        return result, mask\n\n    def load_image(self, image, 
border_ratio=0.15, to_tensor=True):\n        if isinstance(image, str):\n            image = cv2.imread(image, cv2.IMREAD_UNCHANGED)\n            image, mask = self.recenter(image, border_ratio=border_ratio)\n            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n        elif isinstance(image, Image.Image):\n            image = image.convert(\"RGBA\")\n            image = np.asarray(image)\n            image, mask = self.recenter(image, border_ratio=border_ratio)\n\n        image = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)\n        mask = cv2.resize(mask, (self.size, self.size), interpolation=cv2.INTER_NEAREST)\n        mask = mask[..., np.newaxis]\n\n        if to_tensor:\n            image = array_to_tensor(image)\n            mask = array_to_tensor(mask)\n        return image, mask\n\n    def __call__(self, image, border_ratio=0.15, to_tensor=True, **kwargs):\n        if self.border_ratio is not None:\n            border_ratio = self.border_ratio\n        image, mask = self.load_image(\n            image, border_ratio=border_ratio, to_tensor=to_tensor\n        )\n        outputs = {\"image\": image, \"mask\": mask}\n        return outputs\n\n\nclass MVImageProcessorV2(ImageProcessorV2):\n    \"\"\"Multi-view image processor for Hunyuan3D.\"\"\"\n\n    # External module path aliases for compatibility with Hunyuan3D configs\n    _aliases = [\n        \"hy3dshape.preprocessors.MVImageProcessorV2\",\n    ]\n\n    return_view_idx = True\n\n    def __init__(self, size=512, border_ratio=None):\n        super().__init__(size, border_ratio)\n        self.view2idx = {\"front\": 0, \"left\": 1, \"back\": 2, \"right\": 3}\n\n    def __call__(self, image_dict, border_ratio=0.15, to_tensor=True, **kwargs):\n        if self.border_ratio is not None:\n            border_ratio = self.border_ratio\n\n        images = []\n        masks = []\n        view_idxs = []\n        for view_tag, image in image_dict.items():\n            
view_idxs.append(self.view2idx[view_tag])\n            image, mask = self.load_image(\n                image, border_ratio=border_ratio, to_tensor=to_tensor\n            )\n            images.append(image)\n            masks.append(mask)\n\n        zipped_lists = zip(view_idxs, images, masks)\n        sorted_zipped_lists = sorted(zipped_lists)\n        view_idxs, images, masks = zip(*sorted_zipped_lists)\n\n        image = torch.cat(images, 0).unsqueeze(0)\n        mask = torch.cat(masks, 0).unsqueeze(0)\n        outputs = {\"image\": image, \"mask\": mask, \"view_idxs\": view_idxs}\n        return outputs\n\n\n# All tool classes available in this module for resolution\nTOOL_CLASSES = (\n    ImageProcessorV2,\n    MVImageProcessorV2,\n)\n\n\ndef resolve_hunyuan3d_tool(target: str):\n    \"\"\"Resolve a Hunyuan3D tool class by target string.\"\"\"\n    # First, try to match against _aliases\n    for cls in TOOL_CLASSES:\n        aliases = getattr(cls, \"_aliases\", [])\n        if target in aliases:\n            return cls\n\n    # Then, try to match against class names\n    for cls in TOOL_CLASSES:\n        if cls.__name__ == target:\n            return cls\n\n    return None\n\n\n__all__ = [\n    \"transform_pos\",\n    \"get_mv_matrix\",\n    \"get_orthographic_projection_matrix\",\n    \"get_perspective_projection_matrix\",\n    \"export_to_trimesh\",\n    \"mesh_uv_wrap\",\n    \"meshVerticeInpaint\",\n    \"stride_from_shape\",\n    \"scatter_add_nd_with_count\",\n    \"linear_grid_put_2d\",\n    \"MeshRender\",\n    \"recenter_image\",\n    \"array_to_tensor\",\n    \"ImageProcessorV2\",\n    \"MVImageProcessorV2\",\n    \"TOOL_CLASSES\",\n    \"resolve_hunyuan3d_tool\",\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/utils/perf_logger.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\nimport dataclasses\nimport json\nimport logging\nimport os\nimport subprocess\nimport sys\nimport time\nfrom datetime import datetime\nfrom functools import lru_cache\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional\n\nimport torch\nfrom dateutil.tz import UTC\n\nimport sglang\nimport sglang.multimodal_gen.envs as envs\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import (\n    CYAN,\n    RESET,\n    _SGLDiffusionLogger,\n    get_is_main_process,\n    init_logger,\n)\n\nlogger = init_logger(__name__)\n\n\n@dataclasses.dataclass\nclass MemorySnapshot:\n    allocated_mb: float  # current allocated memory\n    reserved_mb: float  # current reserved memory (actual VRAM)\n    peak_allocated_mb: float  # peak allocated since last reset\n    peak_reserved_mb: float  # peak reserved since last reset\n\n    def to_dict(self) -> Dict[str, Any]:\n        return {\n            \"allocated_mb\": round(self.allocated_mb, 2),\n            \"reserved_mb\": round(self.reserved_mb, 2),\n            \"peak_allocated_mb\": round(self.peak_allocated_mb, 2),\n            \"peak_reserved_mb\": round(self.peak_reserved_mb, 2),\n        }\n\n\n@dataclasses.dataclass\nclass RequestMetrics:\n    \"\"\"Performance metrics for a single request, including timings and memory snapshots.\"\"\"\n\n    def __init__(self, request_id: str):\n        self.request_id = request_id\n        self.stages: Dict[str, float] = {}\n        self.steps: list[float] = []\n        self.total_duration_ms: float = 0.0\n        # memory tracking: {checkpoint_name: MemorySnapshot}\n        self.memory_snapshots: Dict[str, MemorySnapshot] = {}\n\n    @property\n    def total_duration_s(self) -> float:\n        return self.total_duration_ms / 1000.0\n\n    def record_stage(self, stage_name: str, duration_s: float):\n        \"\"\"Records the 
duration of a pipeline stage\"\"\"\n        self.stages[stage_name] = duration_s * 1000  # Store as milliseconds\n\n    def record_steps(self, index: int, duration_s: float):\n        \"\"\"Records the duration of a denoising step\"\"\"\n        assert index == len(self.steps)\n        self.steps.append(duration_s * 1000)\n\n    def record_memory_snapshot(self, checkpoint_name: str, snapshot: MemorySnapshot):\n        self.memory_snapshots[checkpoint_name] = snapshot\n\n    def to_dict(self) -> Dict[str, Any]:\n        \"\"\"Serializes the metrics data to a dictionary.\"\"\"\n        return {\n            \"request_id\": self.request_id,\n            \"stages\": self.stages,\n            \"steps\": self.steps,\n            \"total_duration_ms\": self.total_duration_ms,\n            \"memory_snapshots\": {\n                name: snapshot.to_dict()\n                for name, snapshot in self.memory_snapshots.items()\n            },\n        }\n\n\ndef get_diffusion_perf_log_dir() -> str:\n    \"\"\"\n    Determines the directory for performance logs.\n    \"\"\"\n    log_dir = os.environ.get(\"SGLANG_PERF_LOG_DIR\")\n    if log_dir:\n        return os.path.abspath(log_dir)\n    if log_dir is None:\n        sglang_path = Path(sglang.__file__).resolve()\n        target_path = (sglang_path.parent / \"../../.cache/logs\").resolve()\n        return str(target_path)\n    return \"\"\n\n\n@lru_cache(maxsize=1)\ndef get_git_commit_hash() -> str:\n    try:\n        commit_hash = os.environ.get(\"SGLANG_GIT_COMMIT\")\n        if not commit_hash:\n            commit_hash = (\n                subprocess.check_output(\n                    [\"git\", \"rev-parse\", \"HEAD\"], stderr=subprocess.DEVNULL\n                )\n                .strip()\n                .decode(\"utf-8\")\n            )\n        return commit_hash\n    except (subprocess.CalledProcessError, FileNotFoundError):\n
        return \"N/A\"\n\n\ndef capture_memory_snapshot() -> MemorySnapshot:\n    if not torch.get_device_module().is_available():\n        return MemorySnapshot(\n            allocated_mb=0.0,\n            reserved_mb=0.0,\n            peak_allocated_mb=0.0,\n            peak_reserved_mb=0.0,\n        )\n\n    allocated = torch.get_device_module().memory_allocated()\n    reserved = torch.get_device_module().memory_reserved()\n    peak_allocated = torch.get_device_module().max_memory_allocated()\n    peak_reserved = torch.get_device_module().max_memory_reserved()\n\n    return MemorySnapshot(\n        allocated_mb=allocated / (1024**2),\n        reserved_mb=reserved / (1024**2),\n        peak_allocated_mb=peak_allocated / (1024**2),\n        peak_reserved_mb=peak_reserved / (1024**2),\n    )\n\n\n@dataclasses.dataclass\nclass RequestPerfRecord:\n    request_id: str\n\n    timestamp: str\n    commit_hash: str\n    tag: str\n\n    stages: list[dict]\n    steps: list[float]\n    total_duration_ms: float\n    memory_snapshots: dict[str, dict] = dataclasses.field(default_factory=dict)\n\n    def __init__(\n        self,\n        request_id,\n        commit_hash,\n        tag,\n        stages,\n        steps,\n        total_duration_ms,\n        memory_snapshots=None,\n        timestamp=None,\n    ):\n        self.request_id = request_id\n        if timestamp is not None:\n            self.timestamp = timestamp\n        else:\n            self.timestamp = datetime.now(UTC).isoformat()\n\n        self.commit_hash = commit_hash\n        self.tag = tag\n        self.stages = stages\n        self.steps = steps\n        self.total_duration_ms = total_duration_ms\n        self.memory_snapshots = memory_snapshots or {}\n\n\nclass StageProfiler:\n    \"\"\"\n    A unified context manager that records performance metrics (usually of a single stage or a denoising step) into a provided RequestMetrics object (usually from a Req).\n    \"\"\"\n\n    def __init__(\n        self,\n        stage_name: str,\n
        logger: _SGLDiffusionLogger,\n        metrics: Optional[\"RequestMetrics\"],\n        log_stage_start_end: bool = False,\n        perf_dump_path_provided: bool = False,\n        capture_memory: bool = False,\n    ):\n        self.stage_name = stage_name\n        self.metrics = metrics\n        self.logger = logger\n        self.start_time = 0.0\n        self.log_timing = perf_dump_path_provided or envs.SGLANG_DIFFUSION_STAGE_LOGGING\n        self.log_stage_start_end = log_stage_start_end\n        self.capture_memory = capture_memory\n\n    def __enter__(self):\n        if self.log_stage_start_end:\n            msg = f\"[{self.stage_name}] started...\"\n            if self.logger.isEnabledFor(logging.DEBUG):\n                msg += f\" ({round(current_platform.get_available_gpu_memory(), 2)} GB left)\"\n            self.logger.info(msg)\n\n        if (self.log_timing and self.metrics) or self.log_stage_start_end:\n            if (\n                os.environ.get(\"SGLANG_DIFFUSION_SYNC_STAGE_PROFILING\", \"0\") == \"1\"\n                and self.stage_name.startswith(\"denoising_step_\")\n                and torch.get_device_module().is_available()\n            ):\n                torch.get_device_module().synchronize()\n            self.start_time = time.perf_counter()\n\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        if not ((self.log_timing and self.metrics) or self.log_stage_start_end):\n            return False\n\n        if (\n            os.environ.get(\"SGLANG_DIFFUSION_SYNC_STAGE_PROFILING\", \"0\") == \"1\"\n            and self.stage_name.startswith(\"denoising_step_\")\n            and torch.get_device_module().is_available()\n        ):\n            torch.get_device_module().synchronize()\n        execution_time_s = time.perf_counter() - self.start_time\n\n        if exc_type:\n            self.logger.error(\n                \"[%s] Error during execution after %.4f ms: %s\",\n                self.stage_name,\n
                execution_time_s * 1000,\n                exc_val,\n                exc_info=True,\n            )\n            return False\n\n        if self.log_stage_start_end:\n            self.logger.info(\n                f\"[{self.stage_name}] finished in {execution_time_s:.4f} seconds\",\n            )\n\n        if self.log_timing and self.metrics:\n            if \"denoising_step_\" in self.stage_name:\n                index = int(self.stage_name[len(\"denoising_step_\") :])\n                self.metrics.record_steps(index, execution_time_s)\n            else:\n                self.metrics.record_stage(self.stage_name, execution_time_s)\n\n            # capture memory snapshot after stage if requested\n            if self.capture_memory and torch.get_device_module().is_available():\n                snapshot = capture_memory_snapshot()\n                self.metrics.record_memory_snapshot(\n                    f\"after_{self.stage_name}\", snapshot\n                )\n\n        return False\n\n\nclass PerformanceLogger:\n    \"\"\"\n    A global utility class for logging performance metrics for all requests, categorized by request ID.\n\n    Serves both as a runtime logger (stream to file) and a dump utility.\n\n    Note that RequestMetrics stores the performance metrics of a single request.\n    \"\"\"\n\n    @classmethod\n    def dump_benchmark_report(\n        cls,\n        file_path: str,\n        metrics: \"RequestMetrics\",\n        meta: Optional[Dict[str, Any]] = None,\n        tag: str = \"benchmark_dump\",\n    ):\n        \"\"\"\n        Class method to dump a standardized benchmark report to a file.\n        Eliminates duplicate logic in CLI/Client code.\n        \"\"\"\n        formatted_steps = [\n            {\"name\": name, \"duration_ms\": duration_ms}\n            for name, duration_ms in metrics.stages.items()\n        ]\n\n        denoise_steps_ms = [\n            {\"step\": idx, \"duration_ms\": duration_ms}\n            for idx, duration_ms in enumerate(metrics.steps)\n
        ]\n\n        memory_checkpoints = {\n            name: snapshot.to_dict()\n            for name, snapshot in metrics.memory_snapshots.items()\n        }\n\n        report = {\n            \"timestamp\": datetime.now(UTC).isoformat(),\n            \"request_id\": metrics.request_id,\n            \"commit_hash\": get_git_commit_hash(),\n            \"tag\": tag,\n            \"total_duration_ms\": metrics.total_duration_ms,\n            \"steps\": formatted_steps,\n            \"denoise_steps_ms\": denoise_steps_ms,\n            \"memory_checkpoints\": memory_checkpoints,\n            \"meta\": meta or {},\n        }\n\n        try:\n            abs_path = os.path.abspath(file_path)\n            os.makedirs(os.path.dirname(abs_path), exist_ok=True)\n            with open(abs_path, \"w\", encoding=\"utf-8\") as f:\n                json.dump(report, f, indent=2)\n            logger.info(f\"Metrics dumped to: {CYAN}{abs_path}{RESET}\")\n        except IOError as e:\n            logger.error(f\"Failed to dump metrics to {abs_path}: {e}\")\n\n    @classmethod\n    def log_request_summary(\n        cls,\n        metrics: \"RequestMetrics\",\n        tag: str = \"total_inference_time\",\n    ):\n        \"\"\"Logs the stage metrics and total duration for a completed request\n        to the performance.log file.\n\n        Note that this covers only the time spent inside the server; postprocessing is not included.\n        \"\"\"\n        formatted_stages = [\n            {\"name\": name, \"execution_time_ms\": duration_ms}\n            for name, duration_ms in metrics.stages.items()\n        ]\n\n        memory_checkpoints = {\n            name: snapshot.to_dict()\n            for name, snapshot in metrics.memory_snapshots.items()\n        }\n\n        record = RequestPerfRecord(\n            metrics.request_id,\n            commit_hash=get_git_commit_hash(),\n            tag=\"pipeline_stage_metrics\",\n            stages=formatted_stages,\n
            steps=metrics.steps,\n            total_duration_ms=metrics.total_duration_ms,\n            memory_snapshots=memory_checkpoints,\n        )\n\n        try:\n            if get_is_main_process():\n                log_dir = get_diffusion_perf_log_dir()\n                if not os.path.exists(log_dir):\n                    os.makedirs(log_dir, exist_ok=True)\n\n                log_file = os.path.join(log_dir, \"performance.log\")\n\n                with open(log_file, \"a\", encoding=\"utf-8\") as f:\n                    f.write(json.dumps(dataclasses.asdict(record)) + \"\\n\")\n\n        except (OSError, PermissionError) as e:\n            print(f\"WARNING: Failed to log performance record: {e}\", file=sys.stderr)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/utils/profiler.py",
    "content": "import gzip\nimport os\n\nimport torch\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import CYAN, RESET, init_logger\n\nif current_platform.is_npu():\n    import torch_npu\n\n    patches = [\n        [\"profiler.profile\", torch_npu.profiler.profile],\n        [\"profiler.schedule\", torch_npu.profiler.schedule],\n    ]\n    torch_npu._apply_patches(patches)\n\nlogger = init_logger(__name__)\n\n\nclass SGLDiffusionProfiler:\n    \"\"\"\n    A wrapper around torch.profiler to simplify usage in pipelines.\n    Supports both full profiling and scheduled profiling.\n\n\n    1. if profile_all_stages is on: profile all stages, including all denoising steps\n    2. otherwise, if num_profiled_timesteps is specified: profile {num_profiled_timesteps} denoising steps. profile all steps if num_profiled_timesteps==-1\n    \"\"\"\n\n    _instance = None\n\n    def __init__(\n        self,\n        request_id: str | None = None,\n        rank: int = 0,\n        full_profile: bool = False,\n        num_steps: int | None = None,\n        num_inference_steps: int | None = None,\n        log_dir: str | None = None,\n    ):\n        self.request_id = request_id or \"profile_trace\"\n        self.rank = rank\n        self.full_profile = full_profile\n\n        self.log_dir = (\n            log_dir\n            if log_dir is not None\n            else os.getenv(\"SGLANG_TORCH_PROFILER_DIR\", \"./logs\")\n        )\n\n        try:\n            os.makedirs(self.log_dir, exist_ok=True)\n        except OSError:\n            pass\n\n        activities = [torch.profiler.ProfilerActivity.CPU]\n        if torch.cuda.is_available() or (\n            hasattr(torch, \"musa\") and torch.musa.is_available()\n        ):\n            activities.append(torch.profiler.ProfilerActivity.CUDA)\n        if current_platform.is_npu():\n            activities.append(torch_npu.profiler.ProfilerActivity.NPU)\n\n     
   common_torch_profiler_args = dict(\n            activities=activities,\n            record_shapes=True,\n            with_stack=True,\n            on_trace_ready=(\n                None\n                if not current_platform.is_npu()\n                else torch_npu.profiler.tensorboard_trace_handler(self.log_dir)\n            ),\n        )\n        if self.full_profile:\n            # profile all stages\n            self.profiler = torch.profiler.profile(**common_torch_profiler_args)\n            self.profile_mode_id = \"full stages\"\n        else:\n            # profile denoising stage only\n            warmup = 1\n            num_actual_steps = num_inference_steps if num_steps == -1 else num_steps\n            self.num_active_steps = num_actual_steps + warmup\n            self.profiler = torch.profiler.profile(\n                **common_torch_profiler_args,\n                schedule=torch.profiler.schedule(\n                    skip_first=0,\n                    wait=0,\n                    warmup=warmup,\n                    active=self.num_active_steps,\n                    repeat=1,\n                ),\n            )\n            self.profile_mode_id = f\"{num_actual_steps} steps\"\n\n        logger.info(f\"Profiling request: {request_id} for {self.profile_mode_id}...\")\n\n        self.has_stopped = False\n\n        SGLDiffusionProfiler._instance = self\n        self.start()\n\n    def start(self):\n        logger.info(\"Starting Profiler...\")\n        self.profiler.start()\n\n    def _step(self):\n        self.profiler.step()\n\n    def step_stage(self):\n        if self.full_profile:\n            self._step()\n\n    def step_denoising_step(self):\n        if not self.full_profile:\n            if self.num_active_steps >= 0:\n                self._step()\n                self.num_active_steps -= 1\n            else:\n                # early exit when enough steps are captured, to reduce the trace file size\n                self.stop(dump_rank=0)\n\n   
 @classmethod\n    def get_instance(cls) -> \"SGLDiffusionProfiler\":\n        return cls._instance\n\n    def stop(self, export_trace: bool = True, dump_rank: int | None = None):\n        if self.has_stopped:\n            return\n        self.has_stopped = True\n        logger.info(\"Stopping Profiler...\")\n        if torch.cuda.is_available() or (\n            hasattr(torch, \"musa\") and torch.musa.is_available()\n        ):\n            torch.cuda.synchronize()\n        if current_platform.is_npu():\n            torch.npu.synchronize()\n            export_trace = False  # set to false because our internal torch_npu.profiler will generate trace file\n        self.profiler.stop()\n\n        if export_trace:\n            if dump_rank is not None and dump_rank != self.rank:\n                pass\n            else:\n                self._export_trace()\n\n        SGLDiffusionProfiler._instance = None\n\n    def _export_trace(self):\n\n        try:\n            os.makedirs(self.log_dir, exist_ok=True)\n            sanitized_profile_mode_id = self.profile_mode_id.replace(\" \", \"_\")\n            trace_path = os.path.abspath(\n                os.path.join(\n                    self.log_dir,\n                    f\"{self.request_id}-{sanitized_profile_mode_id}-global-rank{self.rank}.trace.json.gz\",\n                )\n            )\n            self.profiler.export_chrome_trace(trace_path)\n\n            if self._check_trace_integrity(trace_path):\n                logger.info(f\"Saved profiler traces to: {CYAN}{trace_path}{RESET}\")\n            else:\n                logger.warning(f\"Trace file may be corrupted: {trace_path}\")\n        except Exception as e:\n            logger.error(f\"Failed to save trace: {e}\")\n\n    def _check_trace_integrity(self, trace_path: str) -> bool:\n        try:\n            if not os.path.exists(trace_path) or os.path.getsize(trace_path) == 0:\n                return False\n\n            with gzip.open(trace_path, \"rb\") as f:\n  
              content = f.read()\n                if content.count(b\"\\x1f\\x8b\") > 1:\n                    logger.warning(\"Multiple gzip headers detected\")\n                    return False\n\n            return True\n        except Exception as e:\n            logger.warning(f\"Trace file integrity check failed: {e}\")\n            return False\n"
  },
  {
    "path": "python/sglang/multimodal_gen/runtime/utils/quantization_utils.py",
    "content": "import glob\nimport json\nimport os\nfrom pathlib import Path\nfrom typing import Dict, List, Optional\n\nfrom safetensors import safe_open\n\nfrom sglang.multimodal_gen.runtime.layers.quantization import (\n    QuantizationConfig,\n    get_quantization_config,\n)\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n\ndef find_quant_modelslim_config(model_config, component_model_path):\n    quant_config_file = Path(component_model_path, \"quant_model_description.json\")\n    quant_cfg = None\n    if quant_config_file.is_file():\n        with open(quant_config_file) as f:\n            quant_cfg = json.load(f)\n        # This field is required for flagless model loading but is not present in\n        # modelslim model description, so we're adding it here manually.\n        quant_cfg[\"quant_method\"] = \"modelslim\"\n\n    return quant_cfg\n\n\ndef replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:\n    for prefix, new_prefix in prefix_mapping.items():\n        if key.startswith(prefix):\n            key = key.replace(prefix, new_prefix, 1)\n    return key\n\n\ndef get_quant_config(\n    model_config,\n    component_model_path: str,\n    packed_modules_mapping: Dict[str, List[str]] = {},\n    remap_prefix: Dict[str, str] | None = None,\n) -> QuantizationConfig:\n\n    quant_cfg = find_quant_modelslim_config(model_config, component_model_path)\n    if quant_cfg is not None:\n        quant_cls = get_quantization_config(quant_cfg[\"quant_method\"])\n        return quant_cls.from_config(quant_cfg)\n    else:\n        if \"quantization_config\" not in model_config:\n            return None\n        quant_cls = get_quantization_config(\n            model_config[\"quantization_config\"][\"quant_method\"]\n        )\n\n        # GGUF doesn't have config file\n        if model_config[\"quantization_config\"][\"quant_method\"] == \"gguf\":\n            return quant_cls.from_config({})\n\n  
      # Read the quantization config from the HF model config, if available.\n        hf_quant_config = model_config[\"quantization_config\"]\n        # some vision models may keep quantization_config in their text_config\n        hf_text_config = getattr(model_config, \"text_config\", None)\n        if hf_quant_config is None and hf_text_config is not None:\n            hf_quant_config = getattr(hf_text_config, \"quantization_config\", None)\n        if hf_quant_config is None:\n            # compressed-tensors uses a compression_config\n            hf_quant_config = getattr(model_config, \"compression_config\", None)\n        if hf_quant_config is not None:\n            hf_quant_config[\"packed_modules_mapping\"] = packed_modules_mapping\n            return quant_cls.from_config(hf_quant_config)\n        # In case of bitsandbytes/QLoRA, get quant config from the adapter model.\n        else:\n            model_name_or_path = model_config[\"model_path\"]\n        is_local = os.path.isdir(model_name_or_path)\n        hf_folder = model_name_or_path\n\n        possible_config_filenames = quant_cls.get_config_filenames()\n\n        # If the quantization config is not found, use the default config.\n        if not possible_config_filenames:\n            return quant_cls()\n\n        config_files = glob.glob(os.path.join(hf_folder, \"*.json\"))\n\n        quant_config_files = [\n            f\n            for f in config_files\n            if any(f.endswith(x) for x in possible_config_filenames)\n        ]\n        if len(quant_config_files) == 0:\n            raise ValueError(\n                f\"Cannot find the config file for {model_config['quantization_config']['quant_method']}\"\n            )\n        if len(quant_config_files) > 1:\n            raise ValueError(\n                f\"Found multiple config files for {model_config['quantization_config']['quant_method']}: \"\n                f\"{quant_config_files}\"\n            )\n\n        quant_config_file = 
quant_config_files[0]\n        with open(quant_config_file) as f:\n            config = json.load(f)\n            if remap_prefix is not None:\n                exclude_modules = [\n                    replace_prefix(key, remap_prefix)\n                    for key in config[\"quantization\"][\"exclude_modules\"]\n                ]\n                config[\"quantization\"][\"exclude_modules\"] = exclude_modules\n            config[\"packed_modules_mapping\"] = packed_modules_mapping\n            return quant_cls.from_config(config)\n\n\ndef handle_fp8_metadata_format(quant_config_dict):\n    layers = quant_config_dict.get(\"layers\", {})\n    if any(\n        isinstance(v, dict) and \"float8\" in v.get(\"format\", \"\") for v in layers.values()\n    ):\n        quant_config_dict[\"quant_method\"] = \"fp8\"\n        quant_config_dict[\"activation_scheme\"] = \"dynamic\"\n    return quant_config_dict\n\n\ndef get_quant_config_from_safetensors_metadata(\n    file_path: str,\n) -> Optional[QuantizationConfig]:\n    \"\"\"Extract quantization config from a safetensors file's metadata header.\n    Returns None if no recognizable quantization metadata is found.\n    \"\"\"\n    metadata = get_metadata_from_safetensors_file(file_path)\n    if not metadata:\n        return None\n\n    quant_config_str = metadata.get(\"_quantization_metadata\")\n    if not quant_config_str:\n        return None\n    try:\n        quant_config_dict = json.loads(quant_config_str)\n    except Exception as _e:\n        return None\n\n    # handle diffusers fp8 safetensors metadata format\n    if (\n        \"quant_method\" not in quant_config_dict\n        and \"format_version\" in quant_config_dict\n        and \"layers\" in quant_config_dict\n    ):\n        quant_config_dict = handle_fp8_metadata_format(quant_config_dict)\n\n    quant_method = quant_config_dict.get(\"quant_method\")\n    if not quant_method:\n        return None\n\n    try:\n        quant_cls = 
get_quantization_config(quant_method)\n        config = quant_cls.from_config(quant_config_dict)\n        logger.debug(f\"Get quantization config from safetensors file: {file_path}\")\n        return config\n    except Exception as _e:\n        return None\n\n\ndef get_metadata_from_safetensors_file(file_path: str):\n    try:\n        with safe_open(file_path, framework=\"pt\", device=\"cpu\") as f:\n            metadata = f.metadata()\n            return metadata\n    except Exception as e:\n        logger.warning(e)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/cli/test_generate_common.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n\"\"\"\nCommon generate cli test, one test for image and video each\n\"\"\"\n\nimport dataclasses\nimport os\nimport shlex\nimport subprocess\nimport sys\nimport unittest\nfrom typing import Optional\n\nfrom PIL import Image\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import DataType\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.test.test_utils import check_image_size\n\nlogger = init_logger(__name__)\n\n\n@dataclasses.dataclass\nclass TestResult:\n    name: str\n    key: str\n    succeed: bool\n\n\ndef run_command(command) -> Optional[float]:\n    \"\"\"Runs a command and returns the execution time and status.\"\"\"\n    print(f\"Running command: {shlex.join(command)}\")\n\n    with subprocess.Popen(\n        command,\n        stdout=subprocess.PIPE,\n        stderr=subprocess.STDOUT,\n        text=True,\n        encoding=\"utf-8\",\n    ) as process:\n        for line in process.stdout:\n            sys.stdout.write(line)\n        process.wait()\n        if process.returncode == 0:\n            return True\n        print(f\"Command failed with exit code {process.returncode}\")\n    return False\n\n\nclass CLIBase(unittest.TestCase):\n    model_path: str = None\n    extra_args = []\n    data_type: DataType = None\n    # tested on h100\n\n    width: int = 720\n    height: int = 720\n    output_path: str = \"test_outputs\"\n\n    def setUp(self):\n        super().setUp()\n        if not os.path.exists(self.output_path):\n            os.makedirs(self.output_path, exist_ok=True)\n        if os.path.exists(self.output_path):\n            for f in os.listdir(self.output_path):\n                path = os.path.join(self.output_path, f)\n                if os.path.isfile(path):\n                    os.remove(path)\n\n    def tearDown(self):\n        super().tearDown()\n        if os.path.exists(self.output_path):\n  
          for f in os.listdir(self.output_path):\n                path = os.path.join(self.output_path, f)\n                if os.path.isfile(path):\n                    os.remove(path)\n\n    def get_base_command(self):\n        return [\n            \"sglang\",\n            \"generate\",\n            \"--prompt\",\n            \"A curious raccoon\",\n            \"--save-output\",\n            \"--log-level=debug\",\n            f\"--width={self.width}\",\n            f\"--height={self.height}\",\n            f\"--output-path={self.output_path}\",\n        ]\n\n    def _run_command(self, name: str, model_path: str, args=[]):\n        command = (\n            self.get_base_command()\n            + [f\"--model-path={model_path}\"]\n            + shlex.split(args or \"\")\n            + [\"--output-file-name\", f\"{name}\"]\n            + self.extra_args\n        )\n        succeed = run_command(command)\n        status = \"Success\" if succeed else \"Failed\"\n\n        return name, status\n\n    def _run_test(self, name: str, args, model_path: str, test_key: str):\n        name, status = self._run_command(name, args=args, model_path=model_path)\n        self.verify(status, name)\n\n    def verify(self, status, name):\n        print(\"-\" * 80)\n        print(\"\\n\" * 3)\n\n        # test task status\n        self.assertEqual(status, \"Success\", f\"{name} command failed\")\n\n        # test output file\n        path = os.path.join(\n            self.output_path, f\"{name}.{self.data_type.get_default_extension()}\"\n        )\n        self.assertTrue(os.path.exists(path), f\"Output file not exist for {path}\")\n        if self.data_type == DataType.IMAGE:\n            with Image.open(path) as image:\n                check_image_size(self, image, self.width, self.height)\n\n    def model_name(self):\n        return self.model_path.split(\"/\")[-1]\n\n    def test_single_gpu(self):\n        \"\"\"single gpu\"\"\"\n        self._run_test(\n            
name=f\"{self.model_name()}_single_gpu\",\n            args=None,\n            model_path=self.model_path,\n            test_key=\"test_single_gpu\",\n        )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/cli/test_generate_i2i.py",
    "content": "import os\nimport unittest\n\nfrom PIL import Image\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import DataType\nfrom sglang.multimodal_gen.test.cli.test_generate_common import CLIBase, run_command\nfrom sglang.multimodal_gen.test.test_utils import (\n    DEFAULT_QWEN_IMAGE_EDIT_2511_MODEL_NAME_FOR_TEST,\n    check_image_size,\n)\n\n\nclass TestQwenImageEditI2I(CLIBase):\n    model_path: str = DEFAULT_QWEN_IMAGE_EDIT_2511_MODEL_NAME_FOR_TEST\n    data_type: DataType = DataType.IMAGE\n    width: int = 512\n    height: int = 512\n\n    test_image_urls = [\n        \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_1.jpg\",\n        \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_2.jpg\",\n    ]\n\n    def get_base_command(self):\n        return [\n            \"sglang\",\n            \"generate\",\n            \"--save-output\",\n            \"--log-level=info\",\n            f\"--width={self.width}\",\n            f\"--height={self.height}\",\n            f\"--output-path={self.output_path}\",\n        ]\n\n    def verify_multi_output(self, name: str, num_outputs: int):\n        output_files = []\n        try:\n            all_files = os.listdir(self.output_path)\n            ext = self.data_type.get_default_extension()\n            for f in all_files:\n                if f.endswith(f\".{ext}\"):\n                    output_files.append(f)\n\n            self.assertEqual(\n                len(output_files),\n                num_outputs,\n                f\"Expected {num_outputs} output files, found {len(output_files)}: {output_files}\",\n            )\n\n            for f in output_files:\n                path = os.path.join(self.output_path, f)\n                with Image.open(path) as image:\n                    check_image_size(self, image, self.width, self.height)\n        finally:\n            for f in output_files:\n                path = os.path.join(self.output_path, 
f)\n                if os.path.exists(path):\n                    os.remove(path)\n\n    def test_single_prompt_single_image(self):\n        \"\"\"Case 1: Single prompt + single image.\"\"\"\n        name = \"single_prompt_single_image\"\n\n        command = self.get_base_command() + [\n            f\"--model-path={self.model_path}\",\n            \"--prompt\",\n            \"Add a red hat\",\n            \"--image-path\",\n            self.test_image_urls[0],\n        ]\n\n        succeed = run_command(command)\n        self.assertTrue(succeed, f\"{name} command failed\")\n        self.verify_multi_output(name, 1)\n\n    def test_single_prompt_multi_image(self):\n        \"\"\"Case 2: Single prompt + multiple images (image composition).\"\"\"\n        name = \"single_prompt_multi_image\"\n\n        command = self.get_base_command() + [\n            f\"--model-path={self.model_path}\",\n            \"--prompt\",\n            \"Combine both images\",\n            \"--image-path\",\n            *self.test_image_urls,\n        ]\n\n        succeed = run_command(command)\n        self.assertTrue(succeed, f\"{name} command failed\")\n        self.verify_multi_output(name, 1)\n\n    def test_multi_prompt_multi_image(self):\n        \"\"\"Case 3: Multiple prompts + multiple images (image editing).\"\"\"\n        name = \"multi_prompt_multi_image\"\n\n        command = self.get_base_command() + [\n            f\"--model-path={self.model_path}\",\n            \"--prompt\",\n            \"Convert to oil painting style\",\n            \"Convert to watercolor style\",\n            \"--image-path\",\n            *self.test_image_urls,\n        ]\n\n        succeed = run_command(command)\n        self.assertTrue(succeed, f\"{name} command failed\")\n        self.verify_multi_output(name, 2)\n\n    def test_multi_prompt_single_image(self):\n        \"\"\"Case 4: Multiple prompts + single image (image editing).\"\"\"\n        name = \"multi_prompt_single_image\"\n\n        command 
= self.get_base_command() + [\n            f\"--model-path={self.model_path}\",\n            \"--prompt\",\n            \"Add a red hat\",\n            \"Change to blue background\",\n            \"--image-path\",\n            self.test_image_urls[0],\n        ]\n\n        succeed = run_command(command)\n        self.assertTrue(succeed, f\"{name} command failed\")\n        self.verify_multi_output(name, 2)\n\n\ndel CLIBase\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/cli/test_generate_t2i_perf.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\nimport unittest\n\nfrom sglang.multimodal_gen.configs.sample.sampling_params import DataType\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.test.cli.test_generate_common import CLIBase\nfrom sglang.multimodal_gen.test.test_utils import DEFAULT_FLUX_1_DEV_MODEL_NAME_FOR_TEST\n\nlogger = init_logger(__name__)\n\n\nclass TestFlux_T2V(CLIBase):\n    model_path = DEFAULT_FLUX_1_DEV_MODEL_NAME_FOR_TEST\n    extra_args = []\n    data_type: DataType = DataType.IMAGE\n\n\ndel CLIBase\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/run_suite.py",
    "content": "\"\"\"\nTest runner for multimodal_gen that manages test suites and parallel execution.\n\nUsage:\n    python3 run_suite.py --suite <suite_name> --partition-id <id> --total-partitions <num>\n\nExample:\n    python3 run_suite.py --suite 1-gpu --partition-id 0 --total-partitions 4\n\"\"\"\n\nimport argparse\nimport os\nimport random\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nimport tabulate\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\n\nlogger = init_logger(__name__)\n\n_UPDATE_WEIGHTS_FROM_DISK_TEST_FILE = \"test_update_weights_from_disk.py\"\n_UPDATE_WEIGHTS_MODEL_PAIR_ENV = \"SGLANG_MMGEN_UPDATE_WEIGHTS_PAIR\"\n_UPDATE_WEIGHTS_MODEL_PAIR_IDS = (\n    \"FLUX.2-klein-base-4B\",\n    \"Qwen-Image\",\n)\n\nSUITES = {\n    # no GPU required; safe to run on any CPU-only runner\n    \"unit\": [\n        \"../unit/test_sampling_params.py\",\n        \"../unit/test_storage.py\",\n        \"../unit/test_lora_format_adapter.py\",\n        \"../unit/test_server_args.py\",\n        # add new unit tests here\n    ],\n    \"1-gpu\": [\n        \"test_server_a.py\",\n        \"test_server_b.py\",\n        # cli test\n        \"../cli/test_generate_t2i_perf.py\",\n        \"test_update_weights_from_disk.py\",\n        # add new 1-gpu test files here\n    ],\n    \"2-gpu\": [\n        \"test_server_2_gpu_a.py\",\n        \"test_server_2_gpu_b.py\",\n        # add new 2-gpu test files here\n    ],\n}\n\nsuites_ascend = {\n    \"1-npu\": [\n        \"ascend/test_server_1_npu.py\",\n        # add new 1-npu test files here\n    ],\n    \"2-npu\": [\n        \"ascend/test_server_2_npu.py\",\n        # add new 2-npu test files here\n    ],\n    \"8-npu\": [\n        \"ascend/test_server_8_npu.py\",\n        # add new 8-npu test files here\n    ],\n}\n\nSUITES.update(suites_ascend)\nSTRICT_SUITES = {\"unit\"}\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Run multimodal_gen test suite\")\n    
parser.add_argument(\n        \"--suite\",\n        type=str,\n        required=True,\n        choices=list(SUITES.keys()),\n        help=\"The test suite to run (e.g., 1-gpu, 2-gpu)\",\n    )\n    parser.add_argument(\n        \"--partition-id\",\n        type=int,\n        default=0,\n        help=\"Index of the current partition (for parallel execution)\",\n    )\n    parser.add_argument(\n        \"--total-partitions\",\n        type=int,\n        default=1,\n        help=\"Total number of partitions\",\n    )\n    parser.add_argument(\n        \"--base-dir\",\n        type=str,\n        default=\"server\",\n        help=\"Base directory for tests relative to this script's parent\",\n    )\n    parser.add_argument(\n        \"-k\",\n        \"--filter\",\n        type=str,\n        default=None,\n        help=\"Pytest filter expression (passed to pytest -k)\",\n    )\n    parser.add_argument(\n        \"--continue-on-error\",\n        action=\"store_true\",\n        default=False,\n        help=\"Continue running remaining tests even if one fails (for CI consistency; pytest already continues by default)\",\n    )\n    return parser.parse_args()\n\n\ndef collect_test_items(files, filter_expr=None):\n    \"\"\"Collect test item node IDs from the given files using pytest --collect-only.\"\"\"\n    cmd = [sys.executable, \"-m\", \"pytest\", \"--collect-only\", \"-q\"]\n    if filter_expr:\n        cmd.extend([\"-k\", filter_expr])\n    cmd.extend(files)\n\n    print(f\"Collecting tests with command: {' '.join(cmd)}\")\n    result = subprocess.run(cmd, capture_output=True, text=True)\n\n    # Check for collection errors\n    # pytest exit codes:\n    #   0: success\n    #   1: tests collected but some had errors during collection\n    #   2: test execution interrupted\n    #   3: internal error\n    #   4: command line usage error\n    #   5: no tests collected (may be expected with filters)\n    if result.returncode not in (0, 5):\n        error_msg = (\n           
 f\"pytest --collect-only failed with exit code {result.returncode}\\n\"\n            f\"Command: {' '.join(cmd)}\\n\"\n        )\n        if result.stderr:\n            error_msg += f\"stderr:\\n{result.stderr}\\n\"\n        if result.stdout:\n            error_msg += f\"stdout:\\n{result.stdout}\\n\"\n        logger.error(error_msg)\n        raise RuntimeError(error_msg)\n\n    if result.returncode == 5:\n        print(\n            \"No tests were collected (exit code 5). This may be expected with filters.\"\n        )\n\n    # Parse the output to extract test node IDs\n    # pytest -q outputs lines like: test_file.py::TestClass::test_method[param]\n    test_items = []\n    for line in result.stdout.strip().split(\"\\n\"):\n        line = line.strip()\n        # Skip empty lines and summary lines\n        if line and \"::\" in line and not line.startswith((\"=\", \"-\", \" \")):\n            # Handle lines that might have extra info after the test ID\n            test_id = line.split()[0] if \" \" in line else line\n            if \"::\" in test_id:\n                test_items.append(test_id)\n\n    print(f\"Collected {len(test_items)} test items\")\n    return test_items\n\n\ndef run_pytest(files, filter_expr=None):\n    if not files:\n        print(\"No files to run.\")\n        return 0\n\n    base_cmd = [sys.executable, \"-m\", \"pytest\", \"-s\", \"-v\"]\n\n    # Add pytest -k filter if provided\n    if filter_expr:\n        base_cmd.extend([\"-k\", filter_expr])\n\n    max_retries = 6\n    # retry if the perf assertion failed, for {max_retries} times\n    for i in range(max_retries + 1):\n        cmd = list(base_cmd)\n        if i > 0:\n            cmd.append(\"--last-failed\")\n        # Always include files to constrain test discovery scope\n        # This prevents pytest from scanning the entire rootdir and\n        # discovering unrelated tests that may have missing dependencies\n        cmd.extend(files)\n\n        if i > 0:\n            print(\n      
          f\"Performance assertion failed. Retrying ({i}/{max_retries}) with --last-failed...\"\n            )\n\n        print(f\"Running command: {' '.join(cmd)}\")\n\n        process = subprocess.Popen(\n            cmd,\n            stdout=subprocess.PIPE,\n            stderr=subprocess.STDOUT,\n            bufsize=0,\n        )\n\n        output_bytes = bytearray()\n        while True:\n            chunk = process.stdout.read(4096)\n            if not chunk:\n                break\n            sys.stdout.buffer.write(chunk)\n            sys.stdout.buffer.flush()\n            output_bytes.extend(chunk)\n\n        process.wait()\n        returncode = process.returncode\n\n        if returncode == 0:\n            return 0\n\n        # Exit code 5 means no tests were collected/selected - treat as success\n        # when using filters, since some partitions may have all tests filtered out\n        if returncode == 5:\n            print(\n                \"No tests collected (exit code 5). This is expected when filters \"\n                \"deselect all tests in a partition. 
Treating as success.\"\n            )\n            return 0\n\n        # check if the failure is due to an assertion in test_server_utils.py\n        full_output = output_bytes.decode(\"utf-8\", errors=\"replace\")\n        is_perf_assertion = (\n            \"multimodal_gen/test/server/test_server_utils.py\" in full_output\n            and \"AssertionError\" in full_output\n        )\n\n        is_flaky_ci_assertion = (\n            \"SafetensorError\" in full_output or \"FileNotFoundError\" in full_output\n        )\n\n        is_oom_error = (\n            \"out of memory\" in full_output.lower()\n            or \"oom killer\" in full_output.lower()\n        )\n\n        if not (is_perf_assertion or is_flaky_ci_assertion or is_oom_error):\n            return returncode\n\n    print(\"Max retries exceeded\")\n    return returncode\n\n\ndef _is_in_ci() -> bool:\n    return os.environ.get(\"SGLANG_IS_IN_CI\", \"\").lower() in (\"1\", \"true\", \"yes\", \"on\")\n\n\ndef _maybe_pin_update_weights_model_pair(suite_files_rel: list[str]) -> None:\n    if not _is_in_ci():\n        return\n    if _UPDATE_WEIGHTS_FROM_DISK_TEST_FILE not in suite_files_rel:\n        return\n    if os.environ.get(_UPDATE_WEIGHTS_MODEL_PAIR_ENV):\n        print(\n            f\"Using preset {_UPDATE_WEIGHTS_MODEL_PAIR_ENV}=\"\n            f\"{os.environ[_UPDATE_WEIGHTS_MODEL_PAIR_ENV]}\"\n        )\n        return\n\n    selected_pair = random.choice(_UPDATE_WEIGHTS_MODEL_PAIR_IDS)\n    os.environ[_UPDATE_WEIGHTS_MODEL_PAIR_ENV] = selected_pair\n    print(f\"Selected {_UPDATE_WEIGHTS_MODEL_PAIR_ENV}={selected_pair} for this CI run\")\n\n\ndef main():\n    args = parse_args()\n\n    # 1. resolve base path\n    current_file_path = Path(__file__).resolve()\n    test_root_dir = current_file_path.parent\n    target_dir = test_root_dir / args.base_dir\n\n    if not target_dir.exists():\n        print(f\"Error: Target directory {target_dir} does not exist.\")\n        sys.exit(1)\n\n    # 2. 
get files from suite\n    suite_files_rel = SUITES[args.suite]\n    _maybe_pin_update_weights_model_pair(suite_files_rel)\n\n    suite_files_abs = []\n    for f_rel in suite_files_rel:\n        f_abs = target_dir / f_rel\n        if not f_abs.exists():\n            msg = f\"Test file {f_rel} not found in {target_dir}.\"\n            if args.suite in STRICT_SUITES:\n                print(f\"Error: {msg}\")\n                sys.exit(1)\n            print(f\"Warning: {msg} Skipping.\")\n            continue\n        suite_files_abs.append(str(f_abs))\n\n    if not suite_files_abs:\n        print(f\"No valid test files found for suite '{args.suite}'.\")\n        sys.exit(1 if args.suite in STRICT_SUITES else 0)\n\n    # 3. collect all test items and partition by items (not files)\n    all_test_items = collect_test_items(suite_files_abs, filter_expr=args.filter)\n\n    if not all_test_items:\n        print(f\"No test items found for suite '{args.suite}'.\")\n        sys.exit(0)\n\n    # Partition by test items\n    my_items = [\n        item\n        for i, item in enumerate(all_test_items)\n        if i % args.total_partitions == args.partition_id\n    ]\n\n    # Print test info at beginning (similar to test/run_suite.py pretty_print_tests)\n    partition_info = f\"{args.partition_id + 1}/{args.total_partitions} (0-based id={args.partition_id})\"\n    headers = [\"Suite\", \"Partition\"]\n    rows = [[args.suite, partition_info]]\n    msg = tabulate.tabulate(rows, headers=headers, tablefmt=\"psql\") + \"\\n\"\n    msg += f\"✅ Enabled {len(my_items)} test(s):\\n\"\n    for item in my_items:\n        msg += f\"  - {item}\\n\"\n    print(msg, flush=True)\n    print(\n        f\"Suite: {args.suite} | Partition: {args.partition_id}/{args.total_partitions}\"\n    )\n    print(f\"Selected {len(suite_files_abs)} files:\")\n    for f in suite_files_abs:\n        print(f\"  - {os.path.basename(f)}\")\n\n    if not my_items:\n        print(\"No items assigned to this partition. 
Exiting success.\")\n        sys.exit(0)\n\n    print(f\"Running {len(my_items)} items in this shard: {', '.join(my_items)}\")\n\n    # 4. execute with the specific test items\n    exit_code = run_pytest(my_items)\n\n    # Print tests again at the end for visibility\n    msg = \"\\n\" + tabulate.tabulate(rows, headers=headers, tablefmt=\"psql\") + \"\\n\"\n    msg += f\"✅ Executed {len(my_items)} test(s):\\n\"\n    for item in my_items:\n        msg += f\"  - {item}\\n\"\n    print(msg, flush=True)\n\n    sys.exit(exit_code)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nGenerate diffusion CI outputs for consistency testing.\n\nThis script reuses the CI test code by calling run_suite.py with SGLANG_GEN_GT=1,\nensuring that GT generation uses exactly the same code path as CI tests.\n\nUsage:\n    python gen_diffusion_ci_outputs.py --suite 1-gpu --partition-id 0 --total-partitions 2 --out-dir ./output\n    python gen_diffusion_ci_outputs.py --suite 1-gpu --case-ids qwen_image_t2i flux_image_t2i --out-dir ./output\n\"\"\"\n\nimport argparse\nimport os\nimport sys\nfrom pathlib import Path\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.test.run_suite import SUITES, collect_test_items, run_pytest\n\nlogger = init_logger(__name__)\n\n\ndef main():\n    \"\"\"Main entry point.\"\"\"\n    parser = argparse.ArgumentParser(description=\"Generate diffusion CI outputs\")\n    parser.add_argument(\n        \"--suite\",\n        type=str,\n        choices=[\"1-gpu\", \"2-gpu\"],\n        required=True,\n        help=\"Test suite to run (1-gpu or 2-gpu)\",\n    )\n    parser.add_argument(\n        \"--partition-id\",\n        type=int,\n        required=False,\n        help=\"Partition ID for matrix partitioning (0-based)\",\n    )\n    parser.add_argument(\n        \"--total-partitions\",\n        type=int,\n        required=False,\n        help=\"Total number of partitions\",\n    )\n    parser.add_argument(\n        \"--out-dir\",\n        type=str,\n        required=True,\n        help=\"Output directory for generated files\",\n    )\n    parser.add_argument(\n        \"--continue-on-error\",\n        action=\"store_true\",\n        help=\"Continue processing other cases if one fails\",\n    )\n    parser.add_argument(\n        \"--case-ids\",\n        type=str,\n        nargs=\"*\",\n        required=False,\n        help=\"Specific case IDs to run (space-separated). 
If provided, only these cases will be run.\",\n    )\n\n    args = parser.parse_args()\n\n    # Validate partition arguments\n    if args.partition_id is not None and args.total_partitions is not None:\n        if args.partition_id < 0 or args.partition_id >= args.total_partitions:\n            parser.error(f\"partition-id must be in range [0, {args.total_partitions})\")\n    elif args.partition_id is not None or args.total_partitions is not None:\n        parser.error(\n            \"Both --partition-id and --total-partitions must be provided together\"\n        )\n\n    # Create output directory\n    out_dir = Path(args.out_dir)\n    out_dir.mkdir(parents=True, exist_ok=True)\n\n    # Set environment variables for GT generation mode\n    os.environ[\"SGLANG_GEN_GT\"] = \"1\"\n    os.environ[\"SGLANG_GT_OUTPUT_DIR\"] = str(out_dir.absolute())\n    os.environ[\"SGLANG_SKIP_CONSISTENCY\"] = (\n        \"1\"  # Skip consistency checks in GT gen mode\n    )\n\n    logger.info(f\"GT generation mode enabled\")\n    logger.info(f\"Output directory: {out_dir}\")\n\n    # Resolve test files path (same as run_suite.py)\n    current_file_path = Path(__file__).resolve()\n    test_root_dir = current_file_path.parent.parent  # scripts -> test\n    target_dir = test_root_dir / \"server\"\n\n    # Get files from suite (same as run_suite.py)\n    suite_files_rel = SUITES[args.suite]\n    suite_files_abs = []\n    for f_rel in suite_files_rel:\n        f_abs = target_dir / f_rel\n        if not f_abs.exists():\n            logger.warning(f\"Test file {f_rel} not found in {target_dir}. 
Skipping.\")\n            continue\n        suite_files_abs.append(str(f_abs))\n\n    if not suite_files_abs:\n        logger.error(f\"No valid test files found for suite '{args.suite}'.\")\n        sys.exit(1)\n\n    # Build pytest filter for case_ids if provided\n    filter_expr = None\n    if args.case_ids:\n        # pytest parametrized test format: test_diffusion_generation[case_id]\n        filters = [f\"test_diffusion_generation[{case_id}]\" for case_id in args.case_ids]\n        filter_expr = \" or \".join(filters)\n        logger.info(f\"Filtering by case IDs: {args.case_ids}\")\n\n    # Collect all test items (same as run_suite.py)\n    all_test_items = collect_test_items(suite_files_abs, filter_expr=filter_expr)\n\n    if not all_test_items:\n        logger.warning(f\"No test items found for suite '{args.suite}'.\")\n        sys.exit(0)\n\n    # Partition by test items (same as run_suite.py)\n    partition_id = args.partition_id if args.partition_id is not None else 0\n    total_partitions = args.total_partitions if args.total_partitions is not None else 1\n\n    my_items = [\n        item\n        for i, item in enumerate(all_test_items)\n        if i % total_partitions == partition_id\n    ]\n\n    logger.info(\n        f\"Partition {partition_id}/{total_partitions}: \"\n        f\"running {len(my_items)} of {len(all_test_items)} test items\"\n    )\n\n    if not my_items:\n        logger.warning(\"No items assigned to this partition. Exiting success.\")\n        sys.exit(0)\n\n    # Run pytest with the specific test items (same as run_suite.py)\n    exit_code = run_pytest(my_items)\n\n    if exit_code != 0:\n        if args.continue_on_error:\n            logger.warning(f\"pytest exited with code {exit_code}\")\n        else:\n            sys.exit(exit_code)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/scripts/gen_perf_baselines.py",
    "content": "import argparse\nimport inspect\nimport json\nimport os\nimport re\nimport sys\nfrom pathlib import Path\n\nfrom openai import OpenAI\n\nfrom sglang.multimodal_gen.test.server.test_server_utils import (\n    ServerManager,\n    get_generate_fn,\n)\nfrom sglang.multimodal_gen.test.server.testcase_configs import (\n    BASELINE_CONFIG,\n    DiffusionTestCase,\n)\nfrom sglang.multimodal_gen.test.test_utils import (\n    get_dynamic_server_port,\n    wait_for_req_perf_record,\n)\n\n\ndef _all_cases() -> list[DiffusionTestCase]:\n    import sglang.multimodal_gen.test.server.testcase_configs as cfg\n\n    cases: list[DiffusionTestCase] = []\n    for _, v in inspect.getmembers(cfg):\n        if isinstance(v, list) and v and isinstance(v[0], DiffusionTestCase):\n            cases.extend(v)\n\n    seen: set[str] = set()\n    out: list[DiffusionTestCase] = []\n    for c in cases:\n        if c.id not in seen:\n            seen.add(c.id)\n            out.append(c)\n    return out\n\n\ndef _baseline_path() -> Path:\n    import sglang.multimodal_gen.test.server.testcase_configs as cfg\n\n    return Path(cfg.__file__).with_name(\"perf_baselines.json\")\n\n\ndef _openai_client(port: int) -> OpenAI:\n    return OpenAI(api_key=\"sglang-anything\", base_url=f\"http://localhost:{port}/v1\")\n\n\ndef _build_server_extra_args(case: DiffusionTestCase) -> str:\n    server_args = case.server_args\n    a = os.environ.get(\"SGLANG_TEST_SERVE_ARGS\", \"\")\n    a += f\" --num-gpus {server_args.num_gpus}\"\n    if server_args.tp_size is not None:\n        a += f\" --tp-size {server_args.tp_size}\"\n    if server_args.ulysses_degree is not None:\n        a += f\" --ulysses-degree {server_args.ulysses_degree}\"\n    if server_args.dit_layerwise_offload:\n        a += \" --dit-layerwise-offload true\"\n    if server_args.dit_offload_prefetch_size:\n        a += f\" --dit-offload-prefetch-size {server_args.dit_offload_prefetch_size}\"\n    if 
server_args.text_encoder_cpu_offload:\n        a += \" --text-encoder-cpu-offload\"\n    if server_args.ring_degree is not None:\n        a += f\" --ring-degree {server_args.ring_degree}\"\n    if server_args.lora_path:\n        a += f\" --lora-path {server_args.lora_path}\"\n\n    # default warmup\n    a += \" --warmup\"\n\n    for extra_arg in server_args.extras:\n        a += f\" {extra_arg}\"\n    return a\n\n\ndef _build_env_vars(case: DiffusionTestCase) -> dict[str, str]:\n    if case.server_args.enable_cache_dit:\n        return {\"SGLANG_CACHE_DIT_ENABLED\": \"true\"}\n    return {}\n\n\ndef _torch_cleanup() -> None:\n    try:\n        import gc\n\n        gc.collect()\n    except Exception:\n        pass\n    try:\n        import torch\n\n        if torch.get_device_module().is_available():\n            torch.get_device_module().synchronize()\n            torch.get_device_module().empty_cache()\n    except Exception:\n        pass\n\n\ndef _run_case(case: DiffusionTestCase) -> dict:\n    default_port = get_dynamic_server_port()\n    port = int(os.environ.get(\"SGLANG_TEST_SERVER_PORT\", default_port))\n    mgr = ServerManager(\n        model=case.server_args.model_path,\n        port=port,\n        wait_deadline=float(os.environ.get(\"SGLANG_TEST_WAIT_SECS\", \"1200\")),\n        extra_args=_build_server_extra_args(case),\n        env_vars=_build_env_vars(case),\n    )\n    ctx = mgr.start()\n    try:\n        sp = case.sampling_params\n        output_size = os.environ.get(\"SGLANG_TEST_OUTPUT_SIZE\", sp.output_size)\n        client = _openai_client(ctx.port)\n        gen = get_generate_fn(\n            model_path=case.server_args.model_path,\n            modality=case.server_args.modality,\n            sampling_params=sp,\n        )\n        rid, _ = gen(case.id, client)\n        rec = wait_for_req_perf_record(\n            rid,\n            ctx.perf_log_path,\n            timeout=float(os.environ.get(\"SGLANG_PERF_TIMEOUT\", \"300\")),\n        )\n       
 if rec is None:\n            raise RuntimeError(f\"missing perf record: {case.id}\")\n        from sglang.multimodal_gen.test.server.testcase_configs import (\n            PerformanceSummary,\n        )\n\n        perf = PerformanceSummary.from_req_perf_record(\n            rec, BASELINE_CONFIG.step_fractions\n        )\n        if case.server_args.modality == \"video\" and sp.num_frames and sp.num_frames > 0:\n            if \"per_frame_generation\" not in perf.stage_metrics:\n                perf.stage_metrics[\"per_frame_generation\"] = perf.e2e_ms / sp.num_frames\n\n        return {\n            \"stages_ms\": {k: round(v, 2) for k, v in perf.stage_metrics.items()},\n            \"denoise_step_ms\": {\n                str(k): round(v, 2) for k, v in perf.all_denoise_steps.items()\n            },\n            \"expected_e2e_ms\": round(perf.e2e_ms, 2),\n            \"expected_avg_denoise_ms\": round(perf.avg_denoise_ms, 2),\n            \"expected_median_denoise_ms\": round(perf.median_denoise_ms, 2),\n        }\n    finally:\n        ctx.cleanup()\n\n\ndef main() -> int:\n    ap = argparse.ArgumentParser()\n    ap.add_argument(\"--baseline\", default=\"\")\n    ap.add_argument(\"--out\", default=\"\")\n    ap.add_argument(\"--match\", default=\"\")\n    ap.add_argument(\"--case\", action=\"append\", default=[])\n    ap.add_argument(\"--all-from-baseline\", action=\"store_true\")\n    ap.add_argument(\"--timeout\", type=float, default=300.0)\n    args = ap.parse_args()\n\n    os.environ.setdefault(\"SGLANG_GEN_BASELINE\", \"1\")\n    os.environ[\"SGLANG_PERF_TIMEOUT\"] = str(args.timeout)\n\n    baseline_path = Path(args.baseline) if args.baseline else _baseline_path()\n    out_path = Path(args.out) if args.out else baseline_path\n    data = json.loads(baseline_path.read_text(encoding=\"utf-8\"))\n    scenarios = data.setdefault(\"scenarios\", {})\n\n    ids = set(args.case) if args.case else None\n    pat = re.compile(args.match) if args.match else None\n    
if args.all_from_baseline:\n        ids = set(scenarios.keys())\n        pat = None\n\n    all_cases = _all_cases()\n    cases = []\n    for c in all_cases:\n        if ids and c.id not in ids:\n            continue\n        if pat and not pat.search(c.id):\n            continue\n        cases.append(c)\n\n    if args.all_from_baseline and ids:\n        case_ids = {c.id for c in all_cases}\n        missing = sorted([i for i in ids if i not in case_ids])\n        if missing:\n            sys.stderr.write(f\"missing cases in testcase_configs.py: {len(missing)}\\n\")\n\n    if not cases:\n        return 0\n\n    for c in cases:\n        prev = scenarios.get(c.id, {})\n        note = prev.get(\"notes\")\n        baseline = _run_case(c)\n        if note is not None:\n            baseline[\"notes\"] = note\n        scenarios[c.id] = baseline\n        sys.stdout.write(f\"{c.id}\\n\")\n        sys.stdout.flush()\n        _torch_cleanup()\n\n    out_path.write_text(json.dumps(data, indent=4) + \"\\n\", encoding=\"utf-8\")\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/ascend/perf_baselines_npu.json",
    "content": "{\n    \"metadata\": {\n        \"model\": \"Diffusion Server\",\n        \"hardware\": \"CI A2 64GB pool\",\n        \"description\": \"Reference numbers captured from the CI diffusion server baseline run\"\n    },\n    \"scenarios\": {\n        \"flux_image_t2i_npu\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.07,\n                \"TextEncodingStage\": 154.51,\n                \"TimestepPreparationStage\": 53.52,\n                \"LatentPreparationStage\": 0.39,\n                \"DenoisingStage\": 19423.39,\n                \"DecodingStage\": 40.14\n            },\n            \"denoise_step_ms\": {\n                \"0\": 123.16,\n                \"1\": 91.7,\n                \"2\": 265.62,\n                \"3\": 402.68,\n                \"4\": 402.86,\n                \"5\": 402.78,\n                \"6\": 402.99,\n                \"7\": 402.77,\n                \"8\": 402.59,\n                \"9\": 402.93,\n                \"10\": 402.05,\n                \"11\": 402.99,\n                \"12\": 402.29,\n                \"13\": 403.07,\n                \"14\": 402.62,\n                \"15\": 402.99,\n                \"16\": 402.68,\n                \"17\": 403.0,\n                \"18\": 402.74,\n                \"19\": 402.85,\n                \"20\": 402.83,\n                \"21\": 403.03,\n                \"22\": 402.56,\n                \"23\": 402.84,\n                \"24\": 402.79,\n                \"25\": 402.95,\n                \"26\": 402.65,\n                \"27\": 403.01,\n                \"28\": 402.66,\n                \"29\": 402.92,\n                \"30\": 402.75,\n                \"31\": 403.0,\n                \"32\": 402.9,\n                \"33\": 402.48,\n                \"34\": 402.85,\n                \"35\": 402.03,\n                \"36\": 402.93,\n                \"37\": 402.3,\n                \"38\": 403.12,\n                \"39\": 402.83,\n                \"40\": 
402.84,\n                \"41\": 402.75,\n                \"42\": 402.97,\n                \"43\": 402.62,\n                \"44\": 402.91,\n                \"45\": 402.81,\n                \"46\": 402.97,\n                \"47\": 402.57,\n                \"48\": 403.0,\n                \"49\": 402.75\n            },\n            \"expected_e2e_ms\": 23819.1,\n            \"expected_avg_denoise_ms\": 388.22,\n            \"expected_median_denoise_ms\": 402.82\n        },\n        \"flux_2_image_t2i_2npu\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.06,\n                \"TextEncodingStage\": 5628.31,\n                \"ImageVAEEncodingStage\": 0.01,\n                \"LatentPreparationStage\": 0.75,\n                \"TimestepPreparationStage\": 30.68,\n                \"DenoisingStage\": 55002.26,\n                \"DecodingStage\": 43.73\n            },\n            \"denoise_step_ms\": {\n                \"0\": 110.35,\n                \"1\": 301.82,\n                \"2\": 1139.81,\n                \"3\": 1114.17,\n                \"4\": 1099.34,\n                \"5\": 1099.12,\n                \"6\": 1100.16,\n                \"7\": 1099.67,\n                \"8\": 1099.09,\n                \"9\": 1089.81,\n                \"10\": 1109.73,\n                \"11\": 1099.97,\n                \"12\": 1100.26,\n                \"13\": 1099.67,\n                \"14\": 1099.79,\n                \"15\": 1099.6,\n                \"16\": 1100.16,\n                \"17\": 1099.87,\n                \"18\": 1100.02,\n                \"19\": 1099.34,\n                \"20\": 1099.6,\n                \"21\": 1099.45,\n                \"22\": 1100.2,\n                \"23\": 1099.29,\n                \"24\": 1098.86,\n                \"25\": 1090.38,\n                \"26\": 1109.19,\n                \"27\": 1099.67,\n                \"28\": 1100.06,\n                \"29\": 1099.22,\n                \"30\": 1100.08,\n                
\"31\": 1098.86,\n                \"32\": 1099.73,\n                \"33\": 1099.11,\n                \"34\": 1100.13,\n                \"35\": 1103.97,\n                \"36\": 1095.26,\n                \"37\": 1099.38,\n                \"38\": 1099.34,\n                \"39\": 1099.17,\n                \"40\": 1100.08,\n                \"41\": 1089.89,\n                \"42\": 1106.69,\n                \"43\": 1102.57,\n                \"44\": 1100.17,\n                \"45\": 1099.21,\n                \"46\": 1100.42,\n                \"47\": 1099.38,\n                \"48\": 1099.59,\n                \"49\": 1099.47\n            },\n            \"expected_e2e_ms\": 64195.08,\n            \"expected_avg_denoise_ms\": 1065.0,\n            \"expected_median_denoise_ms\": 1099.63\n        },\n        \"wan2_1_t2v_1.3b_1_npu\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.07,\n                \"TextEncodingStage\": 876.11,\n                \"LatentPreparationStage\": 0.25,\n                \"TimestepPreparationStage\": 2.9,\n                \"DenoisingStage\": 26188.0,\n                \"DecodingStage\": 320.03,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 103.56,\n                \"1\": 329.59,\n                \"2\": 545.23,\n                \"3\": 537.0,\n                \"4\": 536.27,\n                \"5\": 536.29,\n                \"6\": 536.33,\n                \"7\": 536.0,\n                \"8\": 536.17,\n                \"9\": 536.28,\n                \"10\": 535.53,\n                \"11\": 536.04,\n                \"12\": 536.42,\n                \"13\": 536.09,\n                \"14\": 536.32,\n                \"15\": 536.25,\n                \"16\": 536.36,\n                \"17\": 536.21,\n                \"18\": 536.29,\n                \"19\": 536.15,\n                \"20\": 536.28,\n                \"21\": 536.5,\n                \"22\": 
536.46,\n                \"23\": 536.06,\n                \"24\": 536.45,\n                \"25\": 536.24,\n                \"26\": 536.14,\n                \"27\": 536.13,\n                \"28\": 536.22,\n                \"29\": 536.15,\n                \"30\": 535.94,\n                \"31\": 536.1,\n                \"32\": 536.13,\n                \"33\": 536.2,\n                \"34\": 536.24,\n                \"35\": 536.34,\n                \"36\": 536.54,\n                \"37\": 536.42,\n                \"38\": 536.41,\n                \"39\": 536.42,\n                \"40\": 536.13,\n                \"41\": 536.32,\n                \"42\": 536.23,\n                \"43\": 536.16,\n                \"44\": 536.05,\n                \"45\": 536.18,\n                \"46\": 536.08,\n                \"47\": 536.34,\n                \"48\": 536.26,\n                \"49\": 535.41\n            },\n            \"expected_e2e_ms\": 38738.17,\n            \"expected_avg_denoise_ms\": 523.62,\n            \"expected_median_denoise_ms\": 536.23\n        },\n        \"wan2_2_t2v_14b_w8a8_8npu\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.07,\n                \"TextEncodingStage\": 301.21,\n                \"LatentPreparationStage\": 0.2,\n                \"TimestepPreparationStage\": 2.68,\n                \"DenoisingStage\": 83661.46,\n                \"DecodingStage\": 232.94,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 1919.92,\n                \"1\": 2099.45,\n                \"2\": 2092.11,\n                \"3\": 2090.84,\n                \"4\": 2089.89,\n                \"5\": 2090.6,\n                \"6\": 2090.77,\n                \"7\": 2091.43,\n                \"8\": 2091.24,\n                \"9\": 2067.83,\n                \"10\": 2078.02,\n                \"11\": 2090.75,\n                \"12\": 2108.36,\n                \"13\": 2096.16,\n  
              \"14\": 2091.74,\n                \"15\": 2091.47,\n                \"16\": 2091.6,\n                \"17\": 2091.94,\n                \"18\": 2091.39,\n                \"19\": 2090.69,\n                \"20\": 2090.27,\n                \"21\": 2090.77,\n                \"22\": 2090.24,\n                \"23\": 2091.65,\n                \"24\": 2091.21,\n                \"25\": 2126.82,\n                \"26\": 2338.39,\n                \"27\": 2085.18,\n                \"28\": 2084.68,\n                \"29\": 2084.71,\n                \"30\": 2051.48,\n                \"31\": 2104.3,\n                \"32\": 2084.58,\n                \"33\": 2085.04,\n                \"34\": 2085.03,\n                \"35\": 2084.58,\n                \"36\": 2084.41,\n                \"37\": 2085.16,\n                \"38\": 2084.88,\n                \"39\": 2083.54\n            },\n            \"expected_e2e_ms\": 91733.92,\n            \"expected_avg_denoise_ms\": 2091.33,\n            \"expected_median_denoise_ms\": 2090.72\n        }\n    }\n}\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/ascend/test_server_1_npu.py",
    "content": "\"\"\"\nConfig-driven diffusion performance test with pytest parametrization.\n\nIf the actual run is significantly better than the baseline, the improved cases are printed along with their updated baselines.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport pytest\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.test.server.ascend.testcase_configs_npu import ONE_NPU_CASES\nfrom sglang.multimodal_gen.test.server.test_server_common import (  # noqa: F401\n    DiffusionServerBase,\n    diffusion_server,\n)\nfrom sglang.multimodal_gen.test.server.testcase_configs import DiffusionTestCase\n\nlogger = init_logger(__name__)\n\n\nclass TestDiffusionServerOneNpu(DiffusionServerBase):\n    \"\"\"Performance tests for 1-NPU diffusion cases.\"\"\"\n\n    @pytest.fixture(params=ONE_NPU_CASES, ids=lambda c: c.id)\n    def case(self, request) -> DiffusionTestCase:\n        \"\"\"Provide a DiffusionTestCase for each 1-NPU test.\"\"\"\n        return request.param\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/ascend/test_server_2_npu.py",
    "content": "\"\"\"\nConfig-driven diffusion performance test with pytest parametrization.\n\nIf the actual run is significantly better than the baseline, the improved cases are printed along with their updated baselines.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport pytest\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.test.server.ascend.testcase_configs_npu import TWO_NPU_CASES\nfrom sglang.multimodal_gen.test.server.test_server_common import (  # noqa: F401\n    DiffusionServerBase,\n    diffusion_server,\n)\nfrom sglang.multimodal_gen.test.server.testcase_configs import DiffusionTestCase\n\nlogger = init_logger(__name__)\n\n\nclass TestDiffusionServerTwoNpu(DiffusionServerBase):\n    \"\"\"Performance tests for 2-NPU diffusion cases.\"\"\"\n\n    @pytest.fixture(params=TWO_NPU_CASES, ids=lambda c: c.id)\n    def case(self, request) -> DiffusionTestCase:\n        \"\"\"Provide a DiffusionTestCase for each 2-NPU test.\"\"\"\n        return request.param\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/ascend/test_server_8_npu.py",
    "content": "\"\"\"\nConfig-driven diffusion performance test with pytest parametrization.\n\nIf the actual run is significantly better than the baseline, the improved cases are printed along with their updated baselines.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport pytest\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.test.server.ascend.testcase_configs_npu import (\n    EIGHT_NPU_CASES,\n)\nfrom sglang.multimodal_gen.test.server.test_server_common import (  # noqa: F401\n    DiffusionServerBase,\n    diffusion_server,\n)\nfrom sglang.multimodal_gen.test.server.testcase_configs import DiffusionTestCase\n\nlogger = init_logger(__name__)\n\n\nclass TestDiffusionServerEightNpu(DiffusionServerBase):\n    \"\"\"Performance tests for 8-NPU diffusion cases.\"\"\"\n\n    @pytest.fixture(params=EIGHT_NPU_CASES, ids=lambda c: c.id)\n    def case(self, request) -> DiffusionTestCase:\n        \"\"\"Provide a DiffusionTestCase for each 8-NPU test.\"\"\"\n        return request.param\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py",
    "content": "from sglang.multimodal_gen.test.server.testcase_configs import (\n    T2V_PROMPT,\n    DiffusionSamplingParams,\n    DiffusionServerArgs,\n    DiffusionTestCase,\n    T2I_sampling_params,\n)\n\nONE_NPU_CASES: list[DiffusionTestCase] = [\n    # === Text to Image (T2I) ===\n    DiffusionTestCase(\n        \"flux_image_t2i_npu\",\n        DiffusionServerArgs(\n            model_path=\"/root/.cache/modelscope/hub/models/black-forest-labs/FLUX.1-dev\",\n            modality=\"image\",\n        ),\n        T2I_sampling_params,\n    ),\n    # === Text to Video (T2V) ===\n    DiffusionTestCase(\n        \"wan2_1_t2v_1.3b_1_npu\",\n        DiffusionServerArgs(\n            model_path=\"/root/.cache/modelscope/hub/models/Wan-AI/Wan2.1-T2V-1.3B-Diffusers\",\n            modality=\"video\",\n            custom_validator=\"video\",\n        ),\n        DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n        ),\n    ),\n]\n\nTWO_NPU_CASES: list[DiffusionTestCase] = [\n    # === Text to Image (T2I) ===\n    DiffusionTestCase(\n        \"flux_2_image_t2i_2npu\",\n        DiffusionServerArgs(\n            model_path=\"/root/.cache/modelscope/hub/models/black-forest-labs/FLUX.2-dev\",\n            modality=\"image\",\n            num_gpus=2,\n            tp_size=2,\n        ),\n        T2I_sampling_params,\n    ),\n]\n\nEIGHT_NPU_CASES: list[DiffusionTestCase] = [\n    # === Text to Video (T2V) ===\n    DiffusionTestCase(\n        \"wan2_2_t2v_14b_w8a8_8npu\",\n        DiffusionServerArgs(\n            model_path=\"/root/.cache/modelscope/hub/models/Eco-Tech/Wan2.2-T2V-A14B-Diffusers-w8a8\",\n            modality=\"video\",\n            custom_validator=\"video\",\n            num_gpus=8,\n            tp_size=4,\n        ),\n        DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n        ),\n    ),\n]\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/conftest.py",
    "content": "import os\n\nimport pytest\n\nprint(\"[CONFTEST] Loading conftest.py at import time\")\n\n\ndef pytest_configure(config):\n    \"\"\"\n    Create the perf results StashKey once and store it in config.\n    This hook runs once per test session, before module double-import issues.\n    \"\"\"\n    if not hasattr(config, \"_diffusion_perf_key\"):\n        config._diffusion_perf_key = pytest.StashKey[list]()\n        print(f\"[CONFTEST] Created perf_results_key: {config._diffusion_perf_key}\")\n\n\ndef add_perf_results(config, results: list):\n    \"\"\"Add performance results to the shared stash.\"\"\"\n    # Get the shared key from config (created once in pytest_configure)\n    key = config._diffusion_perf_key\n    existing = config.stash.get(key, [])\n    existing.extend(results)\n    config.stash[key] = existing\n    print(f\"[CONFTEST] Added {len(results)} results, total now: {len(existing)}\")\n\n\n@pytest.fixture(scope=\"session\")\ndef perf_config(request):\n    \"\"\"Provide access to pytest config for storing perf results.\"\"\"\n    return request.config\n\n\ndef _write_github_step_summary(content: str):\n    \"\"\"Write content to GitHub Step Summary if available.\"\"\"\n    summary_file = os.environ.get(\"GITHUB_STEP_SUMMARY\")\n    if summary_file:\n        with open(summary_file, \"a\") as f:\n            f.write(content)\n\n\ndef _write_results_json(results: list, output_path: str = \"diffusion-results.json\"):\n    \"\"\"Write performance results to JSON file for CI artifact collection.\"\"\"\n    import json\n\n    try:\n        with open(output_path, \"w\") as f:\n            json.dump(results, f, indent=2)\n        print(f\"[CONFTEST] Wrote results to {output_path}\")\n    except Exception as e:\n        print(f\"[CONFTEST] Failed to write results JSON: {e}\")\n\n\ndef _generate_diffusion_markdown_report(results: list) -> str:\n    \"\"\"Generate a markdown report for diffusion performance results.\"\"\"\n    if not results:\n        
return \"\"\n\n    gpu_config = os.environ.get(\"GPU_CONFIG\", \"\")\n    header = \"## Diffusion Performance Summary\"\n    if gpu_config:\n        header += f\" [{gpu_config}]\"\n    header += \"\\n\\n\"\n\n    # Main performance table\n    markdown = header\n    markdown += \"| Test Suite | Test Name | Modality | E2E (ms) | Avg Denoise (ms) | Median Denoise (ms) |\\n\"\n    markdown += \"| ---------- | --------- | -------- | -------- | ---------------- | ------------------- |\\n\"\n\n    for entry in sorted(results, key=lambda x: (x[\"class_name\"], x[\"test_name\"])):\n        modality = entry.get(\"modality\", \"image\")\n        markdown += (\n            f\"| {entry['class_name']} | {entry['test_name']} | {modality} | \"\n            f\"{entry['e2e_ms']:.2f} | {entry['avg_denoise_ms']:.2f} | \"\n            f\"{entry['median_denoise_ms']:.2f} |\\n\"\n        )\n\n    # Video-specific metrics table (if any video tests)\n    video_results = [r for r in results if r.get(\"modality\") == \"video\"]\n    if video_results:\n        markdown += \"\\n### Video Generation Metrics\\n\\n\"\n        markdown += \"| Test Name | FPS | Total Frames | Avg Frame Time (ms) |\\n\"\n        markdown += \"| --------- | --- | ------------ | ------------------- |\\n\"\n        for entry in video_results:\n            fps = entry.get(\"frames_per_second\", \"N/A\")\n            frames = entry.get(\"total_frames\", \"N/A\")\n            avg_frame = entry.get(\"avg_frame_time_ms\", \"N/A\")\n            if isinstance(fps, float):\n                fps = f\"{fps:.2f}\"\n            if isinstance(avg_frame, float):\n                avg_frame = f\"{avg_frame:.2f}\"\n            markdown += f\"| {entry['test_name']} | {fps} | {frames} | {avg_frame} |\\n\"\n\n    return markdown\n\n\ndef pytest_sessionfinish(session):\n    \"\"\"\n    This hook is called by pytest at the end of the entire test session.\n    It prints a consolidated summary of all performance results.\n    \"\"\"\n    # Get 
results from stash using the shared key from config\n    key = session.config._diffusion_perf_key\n    results = session.config.stash.get(key, [])\n    print(f\"\\n[DEBUG] pytest_sessionfinish called, has {len(results)} entries\")\n    if not results:\n        print(\"[DEBUG] No results collected, skipping summary output\")\n        return\n\n    # Print to stdout (existing behavior)\n    print(\"\\n\\n\" + \"=\" * 35 + \" Performance Summary \" + \"=\" * 35)\n    print(\n        f\"{'Test Suite':<30} | {'Test Name':<20} | {'E2E (ms)':>12} | {'Avg Denoise (ms)':>18} | {'Median Denoise (ms)':>20}\"\n    )\n    print(\n        \"-\" * 30\n        + \"-+-\"\n        + \"-\" * 20\n        + \"-+-\"\n        + \"-\" * 12\n        + \"-+-\"\n        + \"-\" * 18\n        + \"-+-\"\n        + \"-\" * 20\n    )\n\n    for entry in sorted(results, key=lambda x: x[\"class_name\"]):\n        print(\n            f\"{entry['class_name']:<30} | {entry['test_name']:<20} | {entry['e2e_ms']:>12.2f} | \"\n            f\"{entry['avg_denoise_ms']:>18.2f} | {entry['median_denoise_ms']:>20.2f}\"\n        )\n\n    print(\"=\" * 91)\n\n    print(\"\\n\\n\" + \"=\" * 36 + \" Detailed Reports \" + \"=\" * 37)\n    for entry in sorted(results, key=lambda x: x[\"class_name\"]):\n        print(f\"\\n--- Details for {entry['class_name']} / {entry['test_name']} ---\")\n        stage_report = \", \".join(\n            f\"{name}:{duration:.2f}ms\"\n            for name, duration in entry.get(\"stage_metrics\", {}).items()\n        )\n        if stage_report:\n            print(f\"    Stages: {stage_report}\")\n\n        sampled_steps = entry.get(\"sampled_steps\") or {}\n        if sampled_steps:\n            step_report = \", \".join(\n                f\"{idx}:{duration:.2f}ms\"\n                for idx, duration in sorted(sampled_steps.items())\n            )\n            print(f\"    Sampled Steps: {step_report}\")\n    print(\"=\" * 91)\n\n    # Write to GitHub Step Summary (new behavior for 
CI monitoring)\n    markdown_report = _generate_diffusion_markdown_report(results)\n    if markdown_report:\n        _write_github_step_summary(markdown_report)\n\n    # Write results to JSON file for CI artifact collection\n    _write_results_json(results)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/perf_baselines.json",
    "content": "{\n    \"metadata\": {\n        \"model\": \"Diffusion Server\",\n        \"hardware\": \"CI H100 80GB pool\",\n        \"description\": \"Reference numbers captured from the CI diffusion server baseline run\"\n    },\n    \"tolerances\": {\n        \"long_term\": {\n            \"e2e\": 0.1,\n            \"denoise_stage\": 0.05,\n            \"non_denoise_stage\": 0.4,\n            \"denoise_step\": 0.2,\n            \"denoise_agg\": 0.1\n        },\n        \"pr_test\": {\n            \"e2e\": 0.15,\n            \"denoise_stage\": 0.1,\n            \"non_denoise_stage\": 0.6,\n            \"denoise_step\": 0.25,\n            \"denoise_agg\": 0.15\n        }\n    },\n    \"improvement_reporting\": {\n        \"threshold\": 0.2\n    },\n    \"sampling\": {\n        \"step_fractions\": [\n            0.0,\n            0.2,\n            0.4,\n            0.6,\n            0.8,\n            1.0\n        ]\n    },\n    \"scenarios\": {\n        \"qwen_image_t2i\": {\n            \"notes\": \"Single-image generation using the default prompt\",\n            \"stages_ms\": {\n                \"DecodingStage\": 51.86,\n                \"TextEncodingStage\": 611.83,\n                \"InputValidationStage\": 0.05,\n                \"DenoisingStage\": 14289.46,\n                \"LatentPreparationStage\": 0.2,\n                \"TimestepPreparationStage\": 3.34\n            },\n            \"denoise_step_ms\": {\n                \"0\": 240.5,\n                \"1\": 279.1,\n                \"2\": 283.29,\n                \"3\": 296.63,\n                \"4\": 287.72,\n                \"5\": 283.39,\n                \"6\": 283.98,\n                \"7\": 291.82,\n                \"8\": 283.1,\n                \"9\": 284.43,\n                \"10\": 288.95,\n                \"11\": 285.6,\n                \"12\": 285.99,\n                \"13\": 285.47,\n                \"14\": 289.66,\n                \"15\": 285.74,\n                \"16\": 284.15,\n          
      \"17\": 290.27,\n                \"18\": 288.04,\n                \"19\": 284.57,\n                \"20\": 286.69,\n                \"21\": 288.95,\n                \"22\": 287.09,\n                \"23\": 285.6,\n                \"24\": 289.31,\n                \"25\": 285.48,\n                \"26\": 285.53,\n                \"27\": 288.13,\n                \"28\": 287.65,\n                \"29\": 285.97,\n                \"30\": 288.9,\n                \"31\": 287.97,\n                \"32\": 286.48,\n                \"33\": 285.38,\n                \"34\": 286.62,\n                \"35\": 288.22,\n                \"36\": 285.6,\n                \"37\": 286.61,\n                \"38\": 287.06,\n                \"39\": 286.2,\n                \"40\": 284.6,\n                \"41\": 285.69,\n                \"42\": 288.46,\n                \"43\": 285.53,\n                \"44\": 285.34,\n                \"45\": 285.74,\n                \"46\": 287.25,\n                \"47\": 285.0,\n                \"48\": 286.82,\n                \"49\": 287.19\n            },\n            \"expected_e2e_ms\": 14959.11,\n            \"expected_avg_denoise_ms\": 285.67,\n            \"expected_median_denoise_ms\": 286.1\n        },\n        \"qwen_image_t2i_2_gpus\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.04,\n                \"TextEncodingStage\": 693.2,\n                \"TimestepPreparationStage\": 2.84,\n                \"LatentPreparationStage\": 9.13,\n                \"DenoisingStage\": 24529.77,\n                \"DecodingStage\": 612.79\n            },\n            \"denoise_step_ms\": {\n                \"0\": 405.94,\n                \"1\": 420.06,\n                \"2\": 414.79,\n                \"3\": 392.4,\n                \"4\": 408.14,\n                \"5\": 605.0,\n                \"6\": 469.39,\n                \"7\": 574.04,\n                \"8\": 539.61,\n                \"9\": 452.93,\n                \"10\": 
279.36,\n                \"11\": 271.8,\n                \"12\": 438.26,\n                \"13\": 552.65,\n                \"14\": 576.1,\n                \"15\": 679.84,\n                \"16\": 543.0,\n                \"17\": 512.81,\n                \"18\": 522.27,\n                \"19\": 545.06,\n                \"20\": 545.85,\n                \"21\": 523.83,\n                \"22\": 519.36,\n                \"23\": 513.78,\n                \"24\": 532.54,\n                \"25\": 524.94,\n                \"26\": 542.59,\n                \"27\": 570.91,\n                \"28\": 568.73,\n                \"29\": 564.52,\n                \"30\": 564.57,\n                \"31\": 544.94,\n                \"32\": 496.81,\n                \"33\": 488.98,\n                \"34\": 457.18,\n                \"35\": 441.42,\n                \"36\": 437.44,\n                \"37\": 477.6,\n                \"38\": 429.17,\n                \"39\": 465.55,\n                \"40\": 448.25,\n                \"41\": 511.83,\n                \"42\": 450.6,\n                \"43\": 375.78,\n                \"44\": 504.4,\n                \"45\": 524.44,\n                \"46\": 535.22,\n                \"47\": 514.52,\n                \"48\": 431.58,\n                \"49\": 410.68\n            },\n            \"expected_e2e_ms\": 25850.45,\n            \"expected_avg_denoise_ms\": 490.43,\n            \"expected_median_denoise_ms\": 512.32\n        },\n        \"flux_image_t2i\": {\n            \"stages_ms\": {\n                \"DecodingStage\": 32.72,\n                \"TextEncodingStage\": 51.96,\n                \"InputValidationStage\": 0.03,\n                \"DenoisingStage\": 7545.16,\n                \"LatentPreparationStage\": 0.2,\n                \"TimestepPreparationStage\": 2.43\n            },\n            \"denoise_step_ms\": {\n                \"0\": 50.06,\n                \"1\": 58.88,\n                \"2\": 151.24,\n                \"3\": 150.97,\n           
     \"4\": 151.23,\n                \"5\": 151.63,\n                \"6\": 159.11,\n                \"7\": 158.31,\n                \"8\": 153.42,\n                \"9\": 151.42,\n                \"10\": 151.91,\n                \"11\": 151.05,\n                \"12\": 151.52,\n                \"13\": 157.2,\n                \"14\": 152.76,\n                \"15\": 153.85,\n                \"16\": 153.02,\n                \"17\": 151.09,\n                \"18\": 151.49,\n                \"19\": 155.13,\n                \"20\": 155.2,\n                \"21\": 152.82,\n                \"22\": 152.2,\n                \"23\": 150.99,\n                \"24\": 152.74,\n                \"25\": 153.45,\n                \"26\": 153.63,\n                \"27\": 154.92,\n                \"28\": 152.72,\n                \"29\": 151.84,\n                \"30\": 151.84,\n                \"31\": 152.44,\n                \"32\": 153.03,\n                \"33\": 154.07,\n                \"34\": 152.36,\n                \"35\": 153.48,\n                \"36\": 152.05,\n                \"37\": 152.45,\n                \"38\": 152.42,\n                \"39\": 154.91,\n                \"40\": 152.68,\n                \"41\": 153.43,\n                \"42\": 151.62,\n                \"43\": 153.52,\n                \"44\": 153.13,\n                \"45\": 152.85,\n                \"46\": 152.33,\n                \"47\": 151.61,\n                \"48\": 152.4,\n                \"49\": 152.33\n            },\n            \"expected_e2e_ms\": 7798.99,\n            \"expected_avg_denoise_ms\": 150.77,\n            \"expected_median_denoise_ms\": 152.45\n        },\n        \"flux_2_image_t2i\": {\n            \"stages_ms\": {\n                \"LatentPreparationStage\": 0.52,\n                \"TimestepPreparationStage\": 2.91,\n                \"TextEncodingStage\": 518.54,\n                \"ImageVAEEncodingStage\": 0.0,\n                \"InputValidationStage\": 0.05,\n                
\"DenoisingStage\": 24901.97,\n                \"DecodingStage\": 8.98\n            },\n            \"denoise_step_ms\": {\n                \"0\": 69.14,\n                \"1\": 132.57,\n                \"2\": 508.67,\n                \"3\": 493.52,\n                \"4\": 504.31,\n                \"5\": 492.99,\n                \"6\": 501.91,\n                \"7\": 495.18,\n                \"8\": 500.87,\n                \"9\": 497.36,\n                \"10\": 498.74,\n                \"11\": 497.46,\n                \"12\": 499.08,\n                \"13\": 494.65,\n                \"14\": 500.35,\n                \"15\": 496.89,\n                \"16\": 500.23,\n                \"17\": 497.01,\n                \"18\": 501.68,\n                \"19\": 493.8,\n                \"20\": 501.1,\n                \"21\": 494.81,\n                \"22\": 501.04,\n                \"23\": 499.27,\n                \"24\": 500.04,\n                \"25\": 497.14,\n                \"26\": 499.05,\n                \"27\": 494.91,\n                \"28\": 496.89,\n                \"29\": 498.53,\n                \"30\": 497.94,\n                \"31\": 497.09,\n                \"32\": 497.7,\n                \"33\": 497.58,\n                \"34\": 496.43,\n                \"35\": 497.7,\n                \"36\": 497.37,\n                \"37\": 497.17,\n                \"38\": 499.27,\n                \"39\": 495.52,\n                \"40\": 501.67,\n                \"41\": 495.11,\n                \"42\": 500.69,\n                \"43\": 501.61,\n                \"44\": 501.91,\n                \"45\": 495.58,\n                \"46\": 499.37,\n                \"47\": 496.8,\n                \"48\": 497.49,\n                \"49\": 495.69\n            },\n            \"expected_e2e_ms\": 25832.82,\n            \"expected_avg_denoise_ms\": 489.43,\n            \"expected_median_denoise_ms\": 497.53\n        },\n        \"flux_2_klein_image_t2i\": {\n            \"stages_ms\": 
{\n                \"DecodingStage\": 9.27,\n                \"TextEncodingStage\": 92.17,\n                \"InputValidationStage\": 0.05,\n                \"ImageVAEEncodingStage\": 0.0,\n                \"DenoisingStage\": 252.01,\n                \"LatentPreparationStage\": 0.42,\n                \"TimestepPreparationStage\": 1.5\n            },\n            \"denoise_step_ms\": {\n                \"0\": 19.91,\n                \"1\": 19.32,\n                \"2\": 51.99,\n                \"3\": 61.78\n            },\n            \"expected_e2e_ms\": 430.73,\n            \"expected_avg_denoise_ms\": 38.25,\n            \"expected_median_denoise_ms\": 35.95\n        },\n        \"layerwise_offload\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.06,\n                \"TextEncodingStage\": 513.58,\n                \"LatentPreparationStage\": 0.46,\n                \"TimestepPreparationStage\": 2.38,\n                \"DenoisingStage\": 52187.62,\n                \"DecodingStage\": 190.31\n            },\n            \"denoise_step_ms\": {\n                \"0\": 1033.45,\n                \"1\": 137.03,\n                \"2\": 1046.96,\n                \"3\": 1039.28,\n                \"4\": 1039.05,\n                \"5\": 1043.91,\n                \"6\": 1041.75,\n                \"7\": 1037.6,\n                \"8\": 1043.54,\n                \"9\": 1048.63,\n                \"10\": 1039.8,\n                \"11\": 1042.25,\n                \"12\": 1041.54,\n                \"13\": 1045.89,\n                \"14\": 1038.99,\n                \"15\": 1041.82,\n                \"16\": 1038.32,\n                \"17\": 1045.53,\n                \"18\": 1046.54,\n                \"19\": 1041.22,\n                \"20\": 1044.55,\n                \"21\": 1041.31,\n                \"22\": 1051.28,\n                \"23\": 1043.12,\n                \"24\": 1044.65,\n                \"25\": 1042.25,\n                \"26\": 1046.47,\n      
          \"27\": 1052.9,\n                \"28\": 1039.04,\n                \"29\": 1042.39,\n                \"30\": 1045.33,\n                \"31\": 1038.05,\n                \"32\": 1037.76,\n                \"33\": 1037.93,\n                \"34\": 1052.85,\n                \"35\": 1045.59,\n                \"36\": 1054.32,\n                \"37\": 1044.59,\n                \"38\": 1043.57,\n                \"39\": 1041.93,\n                \"40\": 1043.59,\n                \"41\": 1046.17,\n                \"42\": 1046.92,\n                \"43\": 1047.04,\n                \"44\": 1046.8,\n                \"45\": 1041.86,\n                \"46\": 1041.05,\n                \"47\": 1044.04,\n                \"48\": 1039.77,\n                \"49\": 1047.12\n            },\n            \"expected_e2e_ms\": 53290.15,\n            \"expected_avg_denoise_ms\": 1025.35,\n            \"expected_median_denoise_ms\": 1043.33\n        },\n        \"flux_2_ti2i\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 99.82,\n                \"TextEncodingStage\": 519.88,\n                \"ImageVAEEncodingStage\": 254.56,\n                \"LatentPreparationStage\": 12.4,\n                \"TimestepPreparationStage\": 2.71,\n                \"DenoisingStage\": 54705.41,\n                \"DecodingStage\": 469.47\n            },\n            \"denoise_step_ms\": {\n                \"0\": 1067.03,\n                \"1\": 271.58,\n                \"2\": 1073.07,\n                \"3\": 1071.93,\n                \"4\": 1100.0,\n                \"5\": 1102.28,\n                \"6\": 1088.3,\n                \"7\": 1089.09,\n                \"8\": 1086.95,\n                \"9\": 1089.33,\n                \"10\": 1089.28,\n                \"11\": 1096.51,\n                \"12\": 1098.88,\n                \"13\": 1080.84,\n                \"14\": 1098.44,\n                \"15\": 1100.88,\n                \"16\": 1086.83,\n                \"17\": 
1090.58,\n                \"18\": 1096.35,\n                \"19\": 1086.25,\n                \"20\": 1082.71,\n                \"21\": 1097.6,\n                \"22\": 1098.72,\n                \"23\": 1100.9,\n                \"24\": 1099.02,\n                \"25\": 1101.52,\n                \"26\": 1098.75,\n                \"27\": 1101.41,\n                \"28\": 1091.75,\n                \"29\": 1087.2,\n                \"30\": 1101.33,\n                \"31\": 1098.14,\n                \"32\": 1100.14,\n                \"33\": 1098.91,\n                \"34\": 1100.05,\n                \"35\": 1099.12,\n                \"36\": 1100.22,\n                \"37\": 1103.29,\n                \"38\": 1092.79,\n                \"39\": 1086.59,\n                \"40\": 1094.81,\n                \"41\": 1105.6,\n                \"42\": 1100.54,\n                \"43\": 1099.95,\n                \"44\": 1096.5,\n                \"45\": 1086.69,\n                \"46\": 1095.85,\n                \"47\": 1092.85,\n                \"48\": 1086.17,\n                \"49\": 1099.67\n            },\n            \"expected_e2e_ms\": 56308.23,\n            \"expected_avg_denoise_ms\": 1077.26,\n            \"expected_median_denoise_ms\": 1096.5\n        },\n        \"flux_2_ti2i_multi_image_cache_dit\": {\n            \"stages_ms\": {\n                \"ImageVAEEncodingStage\": 282.83,\n                \"DenoisingStage\": 26936.93,\n                \"DecodingStage\": 129.33,\n                \"TextEncodingStage\": 737.01,\n                \"LatentPreparationStage\": 0.84,\n                \"TimestepPreparationStage\": 20.57,\n                \"InputValidationStage\": 84.29\n            },\n            \"denoise_step_ms\": {\n                \"0\": 1846.55,\n                \"1\": 232.6,\n                \"2\": 1621.05,\n                \"3\": 1606.1,\n                \"4\": 1441.78,\n                \"5\": 59.2,\n                \"6\": 60.32,\n                \"7\": 235.36,\n 
               \"8\": 1443.21,\n                \"9\": 59.71,\n                \"10\": 59.79,\n                \"11\": 236.12,\n                \"12\": 1439.11,\n                \"13\": 59.97,\n                \"14\": 60.46,\n                \"15\": 239.29,\n                \"16\": 1441.32,\n                \"17\": 60.76,\n                \"18\": 60.68,\n                \"19\": 240.16,\n                \"20\": 1442.46,\n                \"21\": 60.14,\n                \"22\": 61.1,\n                \"23\": 239.26,\n                \"24\": 1443.28,\n                \"25\": 59.41,\n                \"26\": 60.74,\n                \"27\": 238.02,\n                \"28\": 1444.38,\n                \"29\": 59.09,\n                \"30\": 59.12,\n                \"31\": 241.69,\n                \"32\": 1441.99,\n                \"33\": 60.44,\n                \"34\": 61.93,\n                \"35\": 241.0,\n                \"36\": 1443.56,\n                \"37\": 60.55,\n                \"38\": 61.07,\n                \"39\": 238.11,\n                \"40\": 1443.33,\n                \"41\": 59.08,\n                \"42\": 60.43,\n                \"43\": 239.23,\n                \"44\": 1444.53,\n                \"45\": 61.77,\n                \"46\": 61.65,\n                \"47\": 239.84,\n                \"48\": 1443.89,\n                \"49\": 59.92\n            },\n            \"expected_e2e_ms\": 28591.93,\n            \"expected_avg_denoise_ms\": 533.42,\n            \"expected_median_denoise_ms\": 235.74\n        },\n        \"flux_image_t2i_2_gpus\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.03,\n                \"TextEncodingStage\": 74.47,\n                \"TimestepPreparationStage\": 2.23,\n                \"LatentPreparationStage\": 6.17,\n                \"DenoisingStage\": 8400.49,\n                \"DecodingStage\": 381.56\n            },\n            \"denoise_step_ms\": {\n                \"0\": 73.27,\n                
\"1\": 166.6,\n                \"2\": 167.31,\n                \"3\": 168.7,\n                \"4\": 168.83,\n                \"5\": 171.05,\n                \"6\": 174.64,\n                \"7\": 170.92,\n                \"8\": 169.69,\n                \"9\": 169.21,\n                \"10\": 167.71,\n                \"11\": 177.62,\n                \"12\": 166.44,\n                \"13\": 174.61,\n                \"14\": 170.43,\n                \"15\": 169.47,\n                \"16\": 167.24,\n                \"17\": 169.15,\n                \"18\": 169.51,\n                \"19\": 172.3,\n                \"20\": 172.19,\n                \"21\": 172.36,\n                \"22\": 168.39,\n                \"23\": 168.47,\n                \"24\": 170.55,\n                \"25\": 170.96,\n                \"26\": 168.43,\n                \"27\": 169.01,\n                \"28\": 169.62,\n                \"29\": 170.95,\n                \"30\": 171.83,\n                \"31\": 171.92,\n                \"32\": 170.1,\n                \"33\": 170.46,\n                \"34\": 169.91,\n                \"35\": 168.91,\n                \"36\": 170.27,\n                \"37\": 170.23,\n                \"38\": 169.62,\n                \"39\": 169.66,\n                \"40\": 169.57,\n                \"41\": 169.42,\n                \"42\": 168.59,\n                \"43\": 171.12,\n                \"44\": 169.6,\n                \"45\": 169.93,\n                \"46\": 171.23,\n                \"47\": 171.03,\n                \"48\": 170.14,\n                \"49\": 169.4\n            },\n            \"expected_e2e_ms\": 9006.3,\n            \"expected_avg_denoise_ms\": 167.89,\n            \"expected_median_denoise_ms\": 169.67\n        },\n        \"zimage_image_t2i\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.03,\n                \"TextEncodingStage\": 403.47,\n                \"TimestepPreparationStage\": 1.41,\n                
\"LatentPreparationStage\": 0.11,\n                \"DenoisingStage\": 756.21,\n                \"DecodingStage\": 29.41\n            },\n            \"denoise_step_ms\": {\n                \"0\": 22.29,\n                \"1\": 75.04,\n                \"2\": 93.82,\n                \"3\": 93.34,\n                \"4\": 93.38,\n                \"5\": 93.58,\n                \"6\": 94.01,\n                \"7\": 93.97,\n                \"8\": 94.32\n            },\n            \"expected_e2e_ms\": 1292.92,\n            \"expected_avg_denoise_ms\": 83.75,\n            \"expected_median_denoise_ms\": 93.58\n        },\n        \"zimage_image_t2i_fp8\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.04,\n                \"TextEncodingStage\": 428.59,\n                \"LatentPreparationStage\": 0.14,\n                \"TimestepPreparationStage\": 47.26,\n                \"DenoisingStage\": 778.56,\n                \"DecodingStage\": 10.39\n            },\n            \"denoise_step_ms\": {\n                \"0\": 40.9,\n                \"1\": 61.08,\n                \"2\": 95.65,\n                \"3\": 95.83,\n                \"4\": 95.65,\n                \"5\": 96.09,\n                \"6\": 96.23,\n                \"7\": 96.04,\n                \"8\": 96.29\n            },\n            \"expected_e2e_ms\": 1370.28,\n            \"expected_avg_denoise_ms\": 85.97,\n            \"expected_median_denoise_ms\": 95.83\n        },\n        \"zimage_image_t2i_multi_lora\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.04,\n                \"TextEncodingStage\": 413.69,\n                \"TimestepPreparationStage\": 1.3,\n                \"LatentPreparationStage\": 0.11,\n                \"DenoisingStage\": 813.7,\n                \"DecodingStage\": 34.51\n            },\n            \"denoise_step_ms\": {\n                \"0\": 30.35,\n                \"1\": 74.53,\n                \"2\": 99.34,\n                
\"3\": 100.92,\n                \"4\": 99.46,\n                \"5\": 100.57,\n                \"6\": 99.72,\n                \"7\": 100.86,\n                \"8\": 103.87\n            },\n            \"expected_e2e_ms\": 1464.31,\n            \"expected_avg_denoise_ms\": 89.96,\n            \"expected_median_denoise_ms\": 99.72\n        },\n        \"zimage_image_t2i_2_gpus\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.08,\n                \"TextEncodingStage\": 420.74,\n                \"TimestepPreparationStage\": 1.5,\n                \"LatentPreparationStage\": 0.12,\n                \"DenoisingStage\": 1304.07,\n                \"DecodingStage\": 37.83\n            },\n            \"denoise_step_ms\": {\n                \"0\": 49.76,\n                \"1\": 155.22,\n                \"2\": 155.98,\n                \"3\": 156.16,\n                \"4\": 157.04,\n                \"5\": 156.54,\n                \"6\": 156.29,\n                \"7\": 157.36,\n                \"8\": 156.05\n            },\n            \"expected_e2e_ms\": 1464.87,\n            \"expected_avg_denoise_ms\": 144.49,\n            \"expected_median_denoise_ms\": 156.16\n        },\n        \"qwen_image_edit_ti2i\": {\n            \"stages_ms\": {\n                \"LatentPreparationStage\": 0.16,\n                \"TimestepPreparationStage\": 2.62,\n                \"ImageEncodingStage\": 1174.26,\n                \"ImageVAEEncodingStage\": 132.67,\n                \"InputValidationStage\": 38.1,\n                \"DenoisingStage\": 38135.64,\n                \"DecodingStage\": 139.72\n            },\n            \"denoise_step_ms\": {\n                \"0\": 618.31,\n                \"1\": 769.07,\n                \"2\": 766.91,\n                \"3\": 762.77,\n                \"4\": 764.26,\n                \"5\": 765.27,\n                \"6\": 767.35,\n                \"7\": 764.18,\n                \"8\": 766.16,\n                \"9\": 766.89,\n   
             \"10\": 766.1,\n                \"11\": 764.96,\n                \"12\": 763.52,\n                \"13\": 765.22,\n                \"14\": 765.44,\n                \"15\": 763.9,\n                \"16\": 763.19,\n                \"17\": 764.83,\n                \"18\": 765.36,\n                \"19\": 765.19,\n                \"20\": 765.96,\n                \"21\": 765.74,\n                \"22\": 765.87,\n                \"23\": 764.85,\n                \"24\": 765.44,\n                \"25\": 765.95,\n                \"26\": 766.21,\n                \"27\": 767.91,\n                \"28\": 765.45,\n                \"29\": 764.81,\n                \"30\": 766.26,\n                \"31\": 765.37,\n                \"32\": 766.71,\n                \"33\": 765.67,\n                \"34\": 766.64,\n                \"35\": 765.98,\n                \"36\": 766.04,\n                \"37\": 764.19,\n                \"38\": 765.15,\n                \"39\": 766.33,\n                \"40\": 767.68,\n                \"41\": 765.36,\n                \"42\": 766.61,\n                \"43\": 766.06,\n                \"44\": 765.26,\n                \"45\": 765.29,\n                \"46\": 764.64,\n                \"47\": 766.07,\n                \"48\": 762.89,\n                \"49\": 763.01\n            },\n            \"expected_e2e_ms\": 39706.9,\n            \"expected_avg_denoise_ms\": 762.57,\n            \"expected_median_denoise_ms\": 765.44\n        },\n        \"qwen_image_t2i_cache_dit_enabled\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.05,\n                \"TextEncodingStage\": 675.95,\n                \"TimestepPreparationStage\": 3.21,\n                \"LatentPreparationStage\": 0.2,\n                \"DenoisingStage\": 5248.83,\n                \"DecodingStage\": 52.24\n            },\n            \"denoise_step_ms\": {\n                \"0\": 227.68,\n                \"1\": 277.41,\n                \"2\": 
276.7,\n                \"3\": 291.52,\n                \"4\": 52.8,\n                \"5\": 6.58,\n                \"6\": 231.58,\n                \"7\": 52.55,\n                \"8\": 7.69,\n                \"9\": 230.59,\n                \"10\": 52.58,\n                \"11\": 7.14,\n                \"12\": 6.95,\n                \"13\": 234.71,\n                \"14\": 53.28,\n                \"15\": 7.09,\n                \"16\": 6.63,\n                \"17\": 233.93,\n                \"18\": 52.71,\n                \"19\": 6.64,\n                \"20\": 6.5,\n                \"21\": 231.37,\n                \"22\": 52.28,\n                \"23\": 6.61,\n                \"24\": 6.48,\n                \"25\": 232.86,\n                \"26\": 54.92,\n                \"27\": 7.51,\n                \"28\": 7.19,\n                \"29\": 233.51,\n                \"30\": 52.97,\n                \"31\": 6.72,\n                \"32\": 7.02,\n                \"33\": 233.14,\n                \"34\": 52.47,\n                \"35\": 6.66,\n                \"36\": 6.52,\n                \"37\": 233.84,\n                \"38\": 51.49,\n                \"39\": 6.87,\n                \"40\": 6.74,\n                \"41\": 233.75,\n                \"42\": 52.65,\n                \"43\": 6.62,\n                \"44\": 6.55,\n                \"45\": 233.45,\n                \"46\": 52.33,\n                \"47\": 6.55,\n                \"48\": 232.58,\n                \"49\": 52.84\n            },\n            \"expected_e2e_ms\": 5982.78,\n            \"expected_avg_denoise_ms\": 104.84,\n            \"expected_median_denoise_ms\": 102.01\n        },\n        \"wan2_1_t2v_1.3b_teacache_enabled\": {\n            \"stages_ms\": {\n                \"DenoisingStage\": 4598.36,\n                \"InputValidationStage\": 0.07,\n                \"DecodingStage\": 552.92,\n                \"LatentPreparationStage\": 0.26,\n                \"per_frame_generation\": null,\n               
 \"TextEncodingStage\": 1114.01,\n                \"TimestepPreparationStage\": 2.1\n            },\n            \"denoise_step_ms\": {\n                \"0\": 94.24,\n                \"1\": 172.68,\n                \"2\": 169.48,\n                \"3\": 169.08,\n                \"4\": 168.38,\n                \"5\": 167.27,\n                \"6\": 62.95,\n                \"7\": 119.56,\n                \"8\": 53.34,\n                \"9\": 121.85,\n                \"10\": 47.64,\n                \"11\": 125.75,\n                \"12\": 3.24,\n                \"13\": 48.21,\n                \"14\": 125.17,\n                \"15\": 3.71,\n                \"16\": 48.15,\n                \"17\": 124.61,\n                \"18\": 3.3,\n                \"19\": 47.25,\n                \"20\": 129.33,\n                \"21\": 3.11,\n                \"22\": 48.03,\n                \"23\": 127.46,\n                \"24\": 3.37,\n                \"25\": 45.6,\n                \"26\": 127.17,\n                \"27\": 3.35,\n                \"28\": 49.83,\n                \"29\": 125.42,\n                \"30\": 3.19,\n                \"31\": 42.76,\n                \"32\": 131.19,\n                \"33\": 2.93,\n                \"34\": 130.04,\n                \"35\": 44.77,\n                \"36\": 131.45,\n                \"37\": 44.06,\n                \"38\": 131.02,\n                \"39\": 43.48,\n                \"40\": 130.42,\n                \"41\": 45.24,\n                \"42\": 129.46,\n                \"43\": 44.6,\n                \"44\": 130.33,\n                \"45\": 173.84,\n                \"46\": 175.58,\n                \"47\": 168.16,\n                \"48\": 173.85,\n                \"49\": 177.56\n            },\n            \"expected_e2e_ms\": 6497.84,\n            \"expected_avg_denoise_ms\": 91.85,\n            \"expected_median_denoise_ms\": 120.7\n        },\n        \"wan2_1_t2v_1.3b\": {\n            \"stages_ms\": {\n                
\"InputValidationStage\": 0.07,\n                \"TextEncodingStage\": 2237.78,\n                \"TimestepPreparationStage\": 2.1,\n                \"LatentPreparationStage\": 0.84,\n                \"DenoisingStage\": 13041.23,\n                \"DecodingStage\": 1274.63,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 224.71,\n                \"1\": 248.13,\n                \"2\": 246.48,\n                \"3\": 247.87,\n                \"4\": 249.38,\n                \"5\": 246.76,\n                \"6\": 250.42,\n                \"7\": 250.81,\n                \"8\": 250.98,\n                \"9\": 249.9,\n                \"10\": 246.72,\n                \"11\": 249.79,\n                \"12\": 250.46,\n                \"13\": 249.19,\n                \"14\": 247.55,\n                \"15\": 250.12,\n                \"16\": 247.57,\n                \"17\": 247.21,\n                \"18\": 247.32,\n                \"19\": 247.42,\n                \"20\": 248.21,\n                \"21\": 247.19,\n                \"22\": 247.72,\n                \"23\": 247.45,\n                \"24\": 247.9,\n                \"25\": 247.87,\n                \"26\": 247.18,\n                \"27\": 247.65,\n                \"28\": 246.91,\n                \"29\": 248.26,\n                \"30\": 247.82,\n                \"31\": 247.73,\n                \"32\": 247.38,\n                \"33\": 247.84,\n                \"34\": 247.46,\n                \"35\": 247.52,\n                \"36\": 247.94,\n                \"37\": 248.76,\n                \"38\": 248.01,\n                \"39\": 247.45,\n                \"40\": 247.84,\n                \"41\": 248.33,\n                \"42\": 247.41,\n                \"43\": 248.16,\n                \"44\": 248.18,\n                \"45\": 248.44,\n                \"46\": 248.65,\n                \"47\": 247.73,\n                \"48\": 247.48,\n                
\"49\": 247.54\n            },\n            \"expected_e2e_ms\": 18382.19,\n            \"expected_avg_denoise_ms\": 260.76,\n            \"expected_median_denoise_ms\": 247.84\n        },\n        \"wan2_1_t2v_1.3b_text_encoder_cpu_offload\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.09,\n                \"TextEncodingStage\": 2480.54,\n                \"TimestepPreparationStage\": 3.73,\n                \"LatentPreparationStage\": 1.34,\n                \"DenoisingStage\": 12514.88,\n                \"DecodingStage\": 1147.6,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 487.21,\n                \"1\": 243.47,\n                \"2\": 244.28,\n                \"3\": 244.06,\n                \"4\": 244.77,\n                \"5\": 245.86,\n                \"6\": 245.38,\n                \"7\": 246.74,\n                \"8\": 246.28,\n                \"9\": 245.58,\n                \"10\": 245.6,\n                \"11\": 245.21,\n                \"12\": 245.08,\n                \"13\": 245.03,\n                \"14\": 245.53,\n                \"15\": 245.36,\n                \"16\": 246.17,\n                \"17\": 245.32,\n                \"18\": 244.37,\n                \"19\": 246.83,\n                \"20\": 245.87,\n                \"21\": 244.93,\n                \"22\": 245.11,\n                \"23\": 245.23,\n                \"24\": 245.76,\n                \"25\": 245.44,\n                \"26\": 246.47,\n                \"27\": 244.56,\n                \"28\": 244.76,\n                \"29\": 244.79,\n                \"30\": 244.76,\n                \"31\": 244.8,\n                \"32\": 245.11,\n                \"33\": 245.27,\n                \"34\": 245.37,\n                \"35\": 245.3,\n                \"36\": 244.84,\n                \"37\": 245.26,\n                \"38\": 245.38,\n                \"39\": 245.31,\n                
\"40\": 244.7,\n                \"41\": 245.84,\n                \"42\": 245.66,\n                \"43\": 246.68,\n                \"44\": 245.38,\n                \"45\": 245.98,\n                \"46\": 246.02,\n                \"47\": 245.96,\n                \"48\": 245.31,\n                \"49\": 244.99\n            },\n            \"expected_e2e_ms\": 16161.11,\n            \"expected_avg_denoise_ms\": 250.18,\n            \"expected_median_denoise_ms\": 245.32\n        },\n        \"wan2_1_t2v_1.3b_cfg_parallel\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.08,\n                \"TextEncodingStage\": 2700.44,\n                \"TimestepPreparationStage\": 2.82,\n                \"LatentPreparationStage\": 2.0,\n                \"DenoisingStage\": 11640.75,\n                \"DecodingStage\": 890.63,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 266.91,\n                \"1\": 211.32,\n                \"2\": 206.59,\n                \"3\": 208.12,\n                \"4\": 210.68,\n                \"5\": 210.28,\n                \"6\": 213.92,\n                \"7\": 211.25,\n                \"8\": 212.89,\n                \"9\": 205.35,\n                \"10\": 205.92,\n                \"11\": 208.99,\n                \"12\": 207.1,\n                \"13\": 208.1,\n                \"14\": 206.52,\n                \"15\": 205.5,\n                \"16\": 205.24,\n                \"17\": 204.93,\n                \"18\": 207.05,\n                \"19\": 203.78,\n                \"20\": 205.23,\n                \"21\": 203.87,\n                \"22\": 204.28,\n                \"23\": 203.8,\n                \"24\": 206.02,\n                \"25\": 207.2,\n                \"26\": 209.53,\n                \"27\": 207.46,\n                \"28\": 206.77,\n                \"29\": 208.14,\n                \"30\": 208.05,\n                \"31\": 208.78,\n      
          \"32\": 209.23,\n                \"33\": 209.72,\n                \"34\": 208.26,\n                \"35\": 208.55,\n                \"36\": 205.24,\n                \"37\": 204.96,\n                \"38\": 203.77,\n                \"39\": 210.2,\n                \"40\": 202.57,\n                \"41\": 204.77,\n                \"42\": 204.96,\n                \"43\": 203.8,\n                \"44\": 203.9,\n                \"45\": 204.49,\n                \"46\": 207.75,\n                \"47\": 209.09,\n                \"48\": 207.51,\n                \"49\": 207.38\n            },\n            \"expected_e2e_ms\": 15245.6,\n            \"expected_avg_denoise_ms\": 224.37,\n            \"expected_median_denoise_ms\": 207.15\n        },\n        \"turbo_wan2_1_t2v_1.3b\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.06,\n                \"TextEncodingStage\": 2508.95,\n                \"TimestepPreparationStage\": 73.51,\n                \"LatentPreparationStage\": 1.34,\n                \"DmdDenoisingStage\": 1285.25,\n                \"DecodingStage\": 805.04,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 897.62,\n                \"1\": 126.04,\n                \"2\": 126.52,\n                \"3\": 128.26\n            },\n            \"expected_e2e_ms\": 4686.66,\n            \"expected_avg_denoise_ms\": 319.61,\n            \"expected_median_denoise_ms\": 127.39\n        },\n        \"wan2_2_ti2v_5b\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 96.27,\n                \"TextEncodingStage\": 2238.81,\n                \"TimestepPreparationStage\": 2.39,\n                \"LatentPreparationStage\": 27.62,\n                \"DenoisingStage\": 134069.79,\n                \"DecodingStage\": 13559.79,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                
\"0\": 3181.0,\n                \"1\": 2561.67,\n                \"2\": 2578.49,\n                \"3\": 2582.1,\n                \"4\": 2572.24,\n                \"5\": 2577.72,\n                \"6\": 2581.35,\n                \"7\": 2578.79,\n                \"8\": 2584.98,\n                \"9\": 2588.49,\n                \"10\": 2594.37,\n                \"11\": 2591.19,\n                \"12\": 2591.32,\n                \"13\": 2595.35,\n                \"14\": 2594.35,\n                \"15\": 2595.62,\n                \"16\": 2596.35,\n                \"17\": 2596.11,\n                \"18\": 2597.24,\n                \"19\": 2603.13,\n                \"20\": 2599.9,\n                \"21\": 2601.48,\n                \"22\": 2603.58,\n                \"23\": 2601.13,\n                \"24\": 2600.47,\n                \"25\": 2604.13,\n                \"26\": 2606.04,\n                \"27\": 2605.3,\n                \"28\": 2602.02,\n                \"29\": 2601.83,\n                \"30\": 2603.57,\n                \"31\": 2606.63,\n                \"32\": 2606.1,\n                \"33\": 2602.24,\n                \"34\": 2603.29,\n                \"35\": 2602.34,\n                \"36\": 2602.16,\n                \"37\": 2608.14,\n                \"38\": 2603.48,\n                \"39\": 2601.7,\n                \"40\": 2603.96,\n                \"41\": 2604.58,\n                \"42\": 2606.67,\n                \"43\": 2603.52,\n                \"44\": 2599.88,\n                \"45\": 2598.66,\n                \"46\": 2600.74,\n                \"47\": 2602.31,\n                \"48\": 2608.4,\n                \"49\": 2606.02\n            },\n            \"expected_e2e_ms\": 150004.2,\n            \"expected_avg_denoise_ms\": 2608.84,\n            \"expected_median_denoise_ms\": 2601.59\n        },\n        \"qwen_image_edit_2509_ti2i\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 213.24,\n                
\"ImageEncodingStage\": 1089.12,\n                \"ImageVAEEncodingStage\": 304.56,\n                \"TimestepPreparationStage\": 2.94,\n                \"LatentPreparationStage\": 0.2,\n                \"DenoisingStage\": 50724.5,\n                \"DecodingStage\": 601.02\n            },\n            \"denoise_step_ms\": {\n                \"0\": 1057.09,\n                \"1\": 1267.06,\n                \"2\": 1268.33,\n                \"3\": 1268.94,\n                \"4\": 1270.36,\n                \"5\": 1270.44,\n                \"6\": 1268.61,\n                \"7\": 1270.21,\n                \"8\": 1274.98,\n                \"9\": 1271.57,\n                \"10\": 1273.15,\n                \"11\": 1271.56,\n                \"12\": 1272.69,\n                \"13\": 1271.62,\n                \"14\": 1274.04,\n                \"15\": 1276.81,\n                \"16\": 1272.2,\n                \"17\": 1269.33,\n                \"18\": 1275.96,\n                \"19\": 1274.43,\n                \"20\": 1272.57,\n                \"21\": 1275.28,\n                \"22\": 1273.63,\n                \"23\": 1275.06,\n                \"24\": 1277.39,\n                \"25\": 1277.27,\n                \"26\": 1274.74,\n                \"27\": 1273.38,\n                \"28\": 1276.77,\n                \"29\": 1275.59,\n                \"30\": 1275.51,\n                \"31\": 1274.9,\n                \"32\": 1274.8,\n                \"33\": 1279.03,\n                \"34\": 1272.9,\n                \"35\": 1274.67,\n                \"36\": 1272.61,\n                \"37\": 1272.82,\n                \"38\": 1276.41,\n                \"39\": 1273.55\n            },\n            \"expected_e2e_ms\": 52938.04,\n            \"expected_avg_denoise_ms\": 1267.96,\n            \"expected_median_denoise_ms\": 1273.46\n        },\n        \"qwen_image_layered_i2i\": {\n            \"stages_ms\": {\n                \"QwenImageLayeredBeforeDenoisingStage\": 2897.28,\n            
    \"DecodingStage\": 312.93,\n                \"DenoisingStage\": 39417.66,\n                \"TimestepPreparationStage\": 2.29\n            },\n            \"denoise_step_ms\": {\n                \"0\": 657.28,\n                \"1\": 799.2,\n                \"2\": 790.35,\n                \"3\": 785.79,\n                \"4\": 792.9,\n                \"5\": 795.78,\n                \"6\": 791.28,\n                \"7\": 790.87,\n                \"8\": 786.47,\n                \"9\": 791.03,\n                \"10\": 788.77,\n                \"11\": 790.57,\n                \"12\": 788.7,\n                \"13\": 786.01,\n                \"14\": 791.43,\n                \"15\": 789.88,\n                \"16\": 791.18,\n                \"17\": 792.78,\n                \"18\": 792.06,\n                \"19\": 790.47,\n                \"20\": 792.48,\n                \"21\": 789.13,\n                \"22\": 792.12,\n                \"23\": 789.36,\n                \"24\": 790.2,\n                \"25\": 790.87,\n                \"26\": 792.37,\n                \"27\": 794.92,\n                \"28\": 792.9,\n                \"29\": 791.43,\n                \"30\": 793.01,\n                \"31\": 793.71,\n                \"32\": 794.15,\n                \"33\": 787.93,\n                \"34\": 792.12,\n                \"35\": 794.01,\n                \"36\": 789.05,\n                \"37\": 790.51,\n                \"38\": 793.29,\n                \"39\": 791.94,\n                \"40\": 788.94,\n                \"41\": 788.85,\n                \"42\": 789.76,\n                \"43\": 788.89,\n                \"44\": 791.62,\n                \"45\": 788.04,\n                \"46\": 790.03,\n                \"47\": 786.82,\n                \"48\": 789.75,\n                \"49\": 789.0\n            },\n            \"expected_e2e_ms\": 42660.88,\n            \"expected_avg_denoise_ms\": 788.2,\n            \"expected_median_denoise_ms\": 790.72\n        },\n        
\"fastwan2_2_ti2v_5b\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 300.0,\n                \"TextEncodingStage\": 843.86,\n                \"TimestepPreparationStage\": 58.66,\n                \"LatentPreparationStage\": 28.55,\n                \"DmdDenoisingStage\": 499.34,\n                \"DecodingStage\": 1924.01,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 164.76,\n                \"1\": 165.6,\n                \"2\": 165.84\n            },\n            \"expected_e2e_ms\": 7722.91,\n            \"expected_avg_denoise_ms\": 165.42,\n            \"expected_median_denoise_ms\": 165.66\n        },\n        \"fast_hunyuan_video\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.34,\n                \"TextEncodingStage\": 550.63,\n                \"TimestepPreparationStage\": 44.28,\n                \"LatentPreparationStage\": 0.29,\n                \"DenoisingStage\": 9154.39,\n                \"DecodingStage\": 5995.09,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 485.99,\n                \"1\": 1399.84,\n                \"2\": 1399.91,\n                \"3\": 1397.79,\n                \"4\": 1400.61,\n                \"5\": 1402.53\n            },\n            \"expected_e2e_ms\": 16672.15,\n            \"expected_avg_denoise_ms\": 1608.46,\n            \"expected_median_denoise_ms\": 1488.48\n        },\n        \"wan2_2_i2v_a14b_2gpu\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 18.45,\n                \"TextEncodingStage\": 3337.77,\n                \"TimestepPreparationStage\": 2.9,\n                \"LatentPreparationStage\": 1.25,\n                \"ImageVAEEncodingStage\": 1655.89,\n                \"DenoisingStage\": 106972.82,\n                \"DecodingStage\": 1355.52,\n                
\"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 1525.6,\n                \"1\": 1582.6,\n                \"2\": 1597.84,\n                \"3\": 1601.34,\n                \"4\": 1600.86,\n                \"5\": 1598.32,\n                \"6\": 1600.93,\n                \"7\": 1599.88,\n                \"8\": 1600.0,\n                \"9\": 1600.55,\n                \"10\": 1599.27,\n                \"11\": 1600.59,\n                \"12\": 1600.17,\n                \"13\": 1599.72,\n                \"14\": 1599.76,\n                \"15\": 24098.85,\n                \"16\": 1601.29,\n                \"17\": 1598.89,\n                \"18\": 1600.12,\n                \"19\": 1600.52,\n                \"20\": 1599.59,\n                \"21\": 1600.37,\n                \"22\": 1600.35,\n                \"23\": 1599.7,\n                \"24\": 1599.92,\n                \"25\": 1599.75,\n                \"26\": 1600.2,\n                \"27\": 1600.06,\n                \"28\": 1600.41,\n                \"29\": 1599.35,\n                \"30\": 1600.69,\n                \"31\": 1600.15,\n                \"32\": 1599.33,\n                \"33\": 1599.86,\n                \"34\": 1600.52,\n                \"35\": 1599.84,\n                \"36\": 1600.38,\n                \"37\": 1599.23,\n                \"38\": 1600.27,\n                \"39\": 1599.78\n            },\n            \"expected_e2e_ms\": 123182.9887,\n            \"expected_avg_denoise_ms\": 2831.0,\n            \"expected_median_denoise_ms\": 1600.09\n        },\n        \"turbo_wan2_2_i2v_a14b_2gpu\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 25.01,\n                \"TextEncodingStage\": 5198.6,\n                \"TimestepPreparationStage\": 56.26,\n                \"LatentPreparationStage\": 1.4,\n                \"ImageVAEEncodingStage\": 1001.89,\n                \"DmdDenoisingStage\": 4487.79,\n                
\"DecodingStage\": 821.01,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 3042.56,\n                \"1\": 485.88,\n                \"2\": 721.2,\n                \"3\": 475.58\n            },\n            \"expected_e2e_ms\": 11605.97,\n            \"expected_avg_denoise_ms\": 1120.4,\n            \"expected_median_denoise_ms\": 481.74\n        },\n        \"wan2_1_i2v_14b_480P_2gpu\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 38.23,\n                \"TextEncodingStage\": 3550.36,\n                \"ImageEncodingStage\": 3462.55,\n                \"TimestepPreparationStage\": 2.6,\n                \"LatentPreparationStage\": 9.73,\n                \"ImageVAEEncodingStage\": 2290.98,\n                \"DenoisingStage\": 415021.17,\n                \"DecodingStage\": 3016.1,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 10200.25,\n                \"1\": 8222.39,\n                \"2\": 8279.38,\n                \"3\": 8301.48,\n                \"4\": 8338.87,\n                \"5\": 8352.39,\n                \"6\": 8354.64,\n                \"7\": 8353.64,\n                \"8\": 8315.58,\n                \"9\": 8308.48,\n                \"10\": 8299.65,\n                \"11\": 8292.7,\n                \"12\": 8292.73,\n                \"13\": 8285.21,\n                \"14\": 8276.06,\n                \"15\": 8270.41,\n                \"16\": 8273.04,\n                \"17\": 8266.04,\n                \"18\": 8267.7,\n                \"19\": 8264.06,\n                \"20\": 8259.32,\n                \"21\": 8257.26,\n                \"22\": 8253.02,\n                \"23\": 8251.77,\n                \"24\": 8260.97,\n                \"25\": 8251.39,\n                \"26\": 8237.43,\n                \"27\": 8241.33,\n                \"28\": 8235.96,\n                \"29\": 
8240.6,\n                \"30\": 8232.48,\n                \"31\": 8237.85,\n                \"32\": 8244.3,\n                \"33\": 8236.79,\n                \"34\": 8239.83,\n                \"35\": 8239.89,\n                \"36\": 8239.12,\n                \"37\": 8246.74,\n                \"38\": 8235.67,\n                \"39\": 8242.77,\n                \"40\": 8241.17,\n                \"41\": 8240.24,\n                \"42\": 8237.01,\n                \"43\": 8231.26,\n                \"44\": 8232.85,\n                \"45\": 8226.56,\n                \"46\": 8236.98,\n                \"47\": 8226.73,\n                \"48\": 8220.49,\n                \"49\": 8217.04\n            },\n            \"expected_e2e_ms\": 426697.37,\n            \"expected_avg_denoise_ms\": 8300.19,\n            \"expected_median_denoise_ms\": 8267.01\n        },\n        \"wan2_1_i2v_14b_720P_2gpu\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 53.67,\n                \"TextEncodingStage\": 2838,\n                \"ImageEncodingStage\": 3123.99,\n                \"TimestepPreparationStage\": 3.39,\n                \"LatentPreparationStage\": 8.41,\n                \"ImageVAEEncodingStage\": 2261.05,\n                \"DenoisingStage\": 417418.12,\n                \"DecodingStage\": 2968.35\n            },\n            \"denoise_step_ms\": {\n                \"0\": 11848.08,\n                \"1\": 8220.3,\n                \"2\": 8274.3,\n                \"3\": 8298.9,\n                \"4\": 8303.34,\n                \"5\": 8322.44,\n                \"6\": 8314.37,\n                \"7\": 8318.54,\n                \"8\": 8304.94,\n                \"9\": 8303.04,\n                \"10\": 8305.22,\n                \"11\": 8296.22,\n                \"12\": 8289.2,\n                \"13\": 8294.19,\n                \"14\": 8294.87,\n                \"15\": 8285.96,\n                \"16\": 8284.98,\n                \"17\": 8281.61,\n                
\"18\": 8277.35,\n                \"19\": 8287.46,\n                \"20\": 8280.3,\n                \"21\": 8279.18,\n                \"22\": 8279.37,\n                \"23\": 8280.16,\n                \"24\": 8282.67,\n                \"25\": 8272.14,\n                \"26\": 8279.37,\n                \"27\": 8271.66,\n                \"28\": 8274.6,\n                \"29\": 8272.88,\n                \"30\": 8273.76,\n                \"31\": 8266.17,\n                \"32\": 8267.77,\n                \"33\": 8266.88,\n                \"34\": 8263.14,\n                \"35\": 8265.97,\n                \"36\": 8267.76,\n                \"37\": 8268.03,\n                \"38\": 8262.24,\n                \"39\": 8261.4,\n                \"40\": 8263.65,\n                \"41\": 8272.46,\n                \"42\": 8254.9,\n                \"43\": 8261.03,\n                \"44\": 8252.92,\n                \"45\": 8262.49,\n                \"46\": 8253.67,\n                \"47\": 8254.92,\n                \"48\": 8257.08,\n                \"49\": 8236.56\n            },\n            \"expected_e2e_ms\": 427536.9,\n            \"expected_avg_denoise_ms\": 8348.21,\n            \"expected_median_denoise_ms\": 8274.45\n        },\n        \"wan2_2_t2v_a14b_2gpu\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.07,\n                \"TextEncodingStage\": 2575.3,\n                \"TimestepPreparationStage\": 1.99,\n                \"LatentPreparationStage\": 1.26,\n                \"DenoisingStage\": 156678.8406,\n                \"DecodingStage\": 2702.7,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 17908.3,\n                \"1\": 2379.69,\n                \"2\": 2393.59,\n                \"3\": 2400.91,\n                \"4\": 2398.76,\n                \"5\": 2403.1,\n                \"6\": 2403.26,\n                \"7\": 2399.48,\n                \"8\": 
2401.33,\n                \"9\": 2398.4,\n                \"10\": 2401.14,\n                \"11\": 2409.1,\n                \"12\": 2401.16,\n                \"13\": 2408.74,\n                \"14\": 2404.97,\n                \"15\": 2400.51,\n                \"16\": 2402.84,\n                \"17\": 2401.87,\n                \"18\": 2399.67,\n                \"19\": 2400.71,\n                \"20\": 2399.23,\n                \"21\": 2400.13,\n                \"22\": 2400.64,\n                \"23\": 2399.15,\n                \"24\": 2399.58,\n                \"25\": 2400.26,\n                \"26\": 35247.02,\n                \"27\": 2390.25,\n                \"28\": 2398.42,\n                \"29\": 2399.8,\n                \"30\": 2400.08,\n                \"31\": 2400.58,\n                \"32\": 2403.68,\n                \"33\": 2399.37,\n                \"34\": 2401.53,\n                \"35\": 2399.69,\n                \"36\": 2399.9,\n                \"37\": 2400.75,\n                \"38\": 2398.97,\n                \"39\": 2399.12\n            },\n            \"expected_e2e_ms\": 149864.99,\n            \"expected_avg_denoise_ms\": 3608.89,\n            \"expected_median_denoise_ms\": 2400.38\n        },\n        \"wan2_1_t2v_14b_2gpu\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.05,\n                \"TextEncodingStage\": 2310.34,\n                \"TimestepPreparationStage\": 2.42,\n                \"LatentPreparationStage\": 27.7,\n                \"DenoisingStage\": 803631.52,\n                \"DecodingStage\": 8898.74,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 17347.88,\n                \"1\": 15956.93,\n                \"2\": 16027.54,\n                \"3\": 16054.15,\n                \"4\": 16081.46,\n                \"5\": 16062.7,\n                \"6\": 16058.56,\n                \"7\": 16057.58,\n                \"8\": 
16061.04,\n                \"9\": 16120.97,\n                \"10\": 16036.84,\n                \"11\": 16019.6,\n                \"12\": 16042.29,\n                \"13\": 16039.87,\n                \"14\": 16063.0,\n                \"15\": 16036.16,\n                \"16\": 16079.82,\n                \"17\": 16019.7,\n                \"18\": 16061.5,\n                \"19\": 16039.95,\n                \"20\": 16009.42,\n                \"21\": 16051.01,\n                \"22\": 16039.31,\n                \"23\": 16048.22,\n                \"24\": 16071.41,\n                \"25\": 16078.75,\n                \"26\": 16061.78,\n                \"27\": 16018.39,\n                \"28\": 16041.44,\n                \"29\": 16039.64,\n                \"30\": 16041.89,\n                \"31\": 16039.6,\n                \"32\": 16038.97,\n                \"33\": 15999.48,\n                \"34\": 16019.93,\n                \"35\": 16040.27,\n                \"36\": 16020.3,\n                \"37\": 16039.38,\n                \"38\": 15999.4,\n                \"39\": 16022.15,\n                \"40\": 16042.32,\n                \"41\": 16016.62,\n                \"42\": 15998.92,\n                \"43\": 16041.48,\n                \"44\": 15999.63,\n                \"45\": 16003.21,\n                \"46\": 15995.91,\n                \"47\": 16023.52,\n                \"48\": 16016.64,\n                \"49\": 16019.6\n            },\n            \"expected_e2e_ms\": 814884.71,\n            \"expected_avg_denoise_ms\": 16062.92,\n            \"expected_median_denoise_ms\": 16039.62\n        },\n        \"wan2_2_t2v_a14b_lora_2gpu\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.09,\n                \"TextEncodingStage\": 2552.97,\n                \"TimestepPreparationStage\": 1.99,\n                \"LatentPreparationStage\": 1.29,\n                \"DenoisingStage\": 154340.69,\n                \"DecodingStage\": 2730.86,\n                
\"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 26510.7,\n                \"1\": 2381.25,\n                \"2\": 2396.9,\n                \"3\": 2400.96,\n                \"4\": 2402.47,\n                \"5\": 2399.6,\n                \"6\": 2400.5,\n                \"7\": 2401.13,\n                \"8\": 2399.32,\n                \"9\": 2400.0,\n                \"10\": 2401.35,\n                \"11\": 2400.04,\n                \"12\": 2408.27,\n                \"13\": 2407.08,\n                \"14\": 2405.92,\n                \"15\": 2403.99,\n                \"16\": 2402.12,\n                \"17\": 2402.52,\n                \"18\": 2398.08,\n                \"19\": 2399.9,\n                \"20\": 2400.14,\n                \"21\": 2398.64,\n                \"22\": 2401.32,\n                \"23\": 2400.75,\n                \"24\": 2399.27,\n                \"25\": 2400.21,\n                \"26\": 36387.55,\n                \"27\": 2399.77,\n                \"28\": 2398.09,\n                \"29\": 2404.64,\n                \"30\": 2400.68,\n                \"31\": 2404.3,\n                \"32\": 2392.44,\n                \"33\": 2390.56,\n                \"34\": 2396.05,\n                \"35\": 2394.86,\n                \"36\": 2396.07,\n                \"37\": 2398.49,\n                \"38\": 2394.77,\n                \"39\": 2394.19\n            },\n            \"expected_e2e_ms\": 159643.06,\n            \"expected_avg_denoise_ms\": 3851.87,\n            \"expected_median_denoise_ms\": 2400.09\n        },\n        \"wan2_1_t2v_1_3b_lora_1gpu\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.06,\n                \"TextEncodingStage\": 2467.44,\n                \"TimestepPreparationStage\": 2.96,\n                \"LatentPreparationStage\": 1.87,\n                \"DenoisingStage\": 14859.47,\n                \"DecodingStage\": 1199.31,\n                
\"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 233.29,\n                \"1\": 265.02,\n                \"2\": 257.83,\n                \"3\": 260.27,\n                \"4\": 261.43,\n                \"5\": 258.58,\n                \"6\": 256.64,\n                \"7\": 256.91,\n                \"8\": 258.41,\n                \"9\": 257.84,\n                \"10\": 257.08,\n                \"11\": 257.0,\n                \"12\": 258.44,\n                \"13\": 257.1,\n                \"14\": 256.95,\n                \"15\": 257.2,\n                \"16\": 256.84,\n                \"17\": 257.64,\n                \"18\": 257.22,\n                \"19\": 257.42,\n                \"20\": 256.91,\n                \"21\": 256.99,\n                \"22\": 257.17,\n                \"23\": 257.63,\n                \"24\": 258.89,\n                \"25\": 257.46,\n                \"26\": 257.3,\n                \"27\": 257.42,\n                \"28\": 257.19,\n                \"29\": 257.65,\n                \"30\": 257.39,\n                \"31\": 256.93,\n                \"32\": 258.23,\n                \"33\": 257.62,\n                \"34\": 281.86,\n                \"35\": 295.86,\n                \"36\": 296.73,\n                \"37\": 287.21,\n                \"38\": 300.87,\n                \"39\": 303.47,\n                \"40\": 294.09,\n                \"41\": 270.52,\n                \"42\": 256.53,\n                \"43\": 256.58,\n                \"44\": 256.29,\n                \"45\": 255.81,\n                \"46\": 256.34,\n                \"47\": 256.08,\n                \"48\": 255.92,\n                \"49\": 255.87\n            },\n            \"expected_e2e_ms\": 18547.46,\n            \"expected_avg_denoise_ms\": 297.09,\n            \"expected_median_denoise_ms\": 257.42\n        },\n        \"wan2_1_i2v_14b_lora_2gpu\": {\n            \"stages_ms\": {\n                
\"InputValidationStage\": 23.97,\n                \"TextEncodingStage\": 2485.39,\n                \"ImageEncodingStage\": 2372.07,\n                \"TimestepPreparationStage\": 2.6,\n                \"LatentPreparationStage\": 0.18,\n                \"ImageVAEEncodingStage\": 2500.13,\n                \"DenoisingStage\": 193514.04,\n                \"DecodingStage\": 3341.78,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 6680.62,\n                \"1\": 3765.8,\n                \"2\": 3774.63,\n                \"3\": 3772.93,\n                \"4\": 3781.13,\n                \"5\": 3778.22,\n                \"6\": 3776.41,\n                \"7\": 3772.02,\n                \"8\": 3776.15,\n                \"9\": 3768.82,\n                \"10\": 3775.31,\n                \"11\": 3771.32,\n                \"12\": 3774.33,\n                \"13\": 3772.5,\n                \"14\": 3778.41,\n                \"15\": 3775.31,\n                \"16\": 3771.38,\n                \"17\": 3774.87,\n                \"18\": 3780.01,\n                \"19\": 3772.85,\n                \"20\": 3773.65,\n                \"21\": 3774.47,\n                \"22\": 3774.39,\n                \"23\": 3773.08,\n                \"24\": 3776.71,\n                \"25\": 3780.01,\n                \"26\": 3774.83,\n                \"27\": 3773.27,\n                \"28\": 3773.76,\n                \"29\": 3772.75,\n                \"30\": 3773.01,\n                \"31\": 3773.34,\n                \"32\": 3773.13,\n                \"33\": 3774.12,\n                \"34\": 3772.19,\n                \"35\": 3774.7,\n                \"36\": 3773.98,\n                \"37\": 3772.47,\n                \"38\": 3771.72,\n                \"39\": 3774.07,\n                \"40\": 3773.71,\n                \"41\": 3773.6,\n                \"42\": 3772.12,\n                \"43\": 3773.75,\n                \"44\": 3782.43,\n  
              \"45\": 3779.66,\n                \"46\": 3779.86,\n                \"47\": 3774.58,\n                \"48\": 3770.54,\n                \"49\": 3776.76\n            },\n            \"expected_e2e_ms\": 204257.12,\n            \"expected_avg_denoise_ms\": 3855.55,\n            \"expected_median_denoise_ms\": 3774.03\n        },\n        \"flux_2_image_t2i_2_gpus\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.05,\n                \"TextEncodingStage\": 518.88,\n                \"ImageVAEEncodingStage\": 0.0,\n                \"LatentPreparationStage\": 0.45,\n                \"TimestepPreparationStage\": 3.41,\n                \"DenoisingStage\": 26377.63,\n                \"DecodingStage\": 321.94\n            },\n            \"denoise_step_ms\": {\n                \"0\": 129.07,\n                \"1\": 437.16,\n                \"2\": 437.7,\n                \"3\": 437.67,\n                \"4\": 437.84,\n                \"5\": 438.03,\n                \"6\": 438.09,\n                \"7\": 437.65,\n                \"8\": 437.95,\n                \"9\": 438.31,\n                \"10\": 437.99,\n                \"11\": 438.54,\n                \"12\": 438.47,\n                \"13\": 438.2,\n                \"14\": 438.56,\n                \"15\": 438.69,\n                \"16\": 438.69,\n                \"17\": 438.98,\n                \"18\": 437.96,\n                \"19\": 438.9,\n                \"20\": 438.87,\n                \"21\": 438.04,\n                \"22\": 437.88,\n                \"23\": 439.09,\n                \"24\": 438.61,\n                \"25\": 437.68,\n                \"26\": 439.2,\n                \"27\": 439.63,\n                \"28\": 438.65,\n                \"29\": 439.32,\n                \"30\": 439.01,\n                \"31\": 438.84,\n                \"32\": 438.72,\n                \"33\": 439.09,\n                \"34\": 438.3,\n                \"35\": 439.48,\n                
\"36\": 438.2,\n                \"37\": 439.67,\n                \"38\": 440.65,\n                \"39\": 439.96,\n                \"40\": 439.0,\n                \"41\": 439.2,\n                \"42\": 439.37,\n                \"43\": 439.98,\n                \"44\": 438.6,\n                \"45\": 439.58,\n                \"46\": 440.23,\n                \"47\": 440.1,\n                \"48\": 440.21,\n                \"49\": 439.22\n            },\n            \"expected_e2e_ms\": 27624.8,\n            \"expected_avg_denoise_ms\": 518.23,\n            \"expected_median_denoise_ms\": 528.06\n        },\n        \"qwen_image_edit_2511_ti2i\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 55.15,\n                \"ImageEncodingStage\": 770.33,\n                \"ImageVAEEncodingStage\": 88.06,\n                \"TimestepPreparationStage\": 2.12,\n                \"LatentPreparationStage\": 0.14,\n                \"DenoisingStage\": 23869.32,\n                \"DecodingStage\": 108.23\n            },\n            \"denoise_step_ms\": {\n                \"0\": 478.35,\n                \"1\": 608.56,\n                \"2\": 588.51,\n                \"3\": 607.26,\n                \"4\": 599.37,\n                \"5\": 595.19,\n                \"6\": 603.22,\n                \"7\": 594.48,\n                \"8\": 605.06,\n                \"9\": 597.63,\n                \"10\": 601.03,\n                \"11\": 597.18,\n                \"12\": 598.82,\n                \"13\": 600.05,\n                \"14\": 598.57,\n                \"15\": 601.4,\n                \"16\": 595.17,\n                \"17\": 599.21,\n                \"18\": 600.86,\n                \"19\": 600.93,\n                \"20\": 600.35,\n                \"21\": 600.63,\n                \"22\": 597.58,\n                \"23\": 600.73,\n                \"24\": 599.36,\n                \"25\": 600.48,\n                \"26\": 600.33,\n                \"27\": 599.34,\n      
          \"28\": 599.61,\n                \"29\": 599.71,\n                \"30\": 596.03,\n                \"31\": 599.85,\n                \"32\": 599.36,\n                \"33\": 601.58,\n                \"34\": 597.91,\n                \"35\": 600.79,\n                \"36\": 599.29,\n                \"37\": 601.64,\n                \"38\": 598.24,\n                \"39\": 599.87\n            },\n            \"expected_e2e_ms\": 24895.28,\n            \"expected_avg_denoise_ms\": 596.59,\n            \"expected_median_denoise_ms\": 599.66\n        },\n        \"fsdp-inference\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.04,\n                \"TextEncodingStage\": 411.12,\n                \"TimestepPreparationStage\": 1.44,\n                \"LatentPreparationStage\": 0.1,\n                \"DenoisingStage\": 1569.61,\n                \"DecodingStage\": 41.43\n            },\n            \"denoise_step_ms\": {\n                \"0\": 165.33,\n                \"1\": 158.34,\n                \"2\": 167.65,\n                \"3\": 179.11,\n                \"4\": 183.98,\n                \"5\": 175.08,\n                \"6\": 178.34,\n                \"7\": 178.53,\n                \"8\": 178.08\n            },\n            \"expected_e2e_ms\": 2103.05,\n            \"expected_avg_denoise_ms\": 173.83,\n            \"expected_median_denoise_ms\": 178.08\n        },\n        \"hunyuan3d_shape_gen\": {\n            \"stages_ms\": {\n                \"Hunyuan3DShapeBeforeDenoisingStage\": 31.42,\n                \"Hunyuan3DShapeDenoisingStage\": 3259.83,\n                \"Hunyuan3DShapeExportStage\": 8735.55,\n                \"Hunyuan3DShapeSaveStage\": 981.64,\n                \"Hunyuan3DPaintPreprocessStage\": 226071.67,\n                \"Hunyuan3DPaintTexGenStage\": 11083.05,\n                \"Hunyuan3DPaintPostprocessStage\": 7469.29\n            },\n            \"denoise_step_ms\": {\n                \"0\": 32.26,\n          
      \"1\": 63.34,\n                \"2\": 65.44,\n                \"3\": 65.44,\n                \"4\": 65.6,\n                \"5\": 65.81,\n                \"6\": 65.82,\n                \"7\": 65.48,\n                \"8\": 65.9,\n                \"9\": 65.77,\n                \"10\": 65.54,\n                \"11\": 65.68,\n                \"12\": 65.85,\n                \"13\": 65.77,\n                \"14\": 65.7,\n                \"15\": 65.78,\n                \"16\": 66.0,\n                \"17\": 66.15,\n                \"18\": 65.91,\n                \"19\": 66.5,\n                \"20\": 65.76,\n                \"21\": 66.08,\n                \"22\": 66.06,\n                \"23\": 66.23,\n                \"24\": 65.79,\n                \"25\": 65.58,\n                \"26\": 65.88,\n                \"27\": 65.67,\n                \"28\": 65.87,\n                \"29\": 66.09,\n                \"30\": 65.81,\n                \"31\": 65.91,\n                \"32\": 66.18,\n                \"33\": 65.93,\n                \"34\": 66.26,\n                \"35\": 66.26,\n                \"36\": 66.27,\n                \"37\": 65.57,\n                \"38\": 66.02,\n                \"39\": 66.19,\n                \"40\": 65.23,\n                \"41\": 66.11,\n                \"42\": 66.18,\n                \"43\": 65.86,\n                \"44\": 65.86,\n                \"45\": 65.92,\n                \"46\": 65.65,\n                \"47\": 65.78,\n                \"48\": 66.01,\n                \"49\": 66.08\n            },\n            \"expected_e2e_ms\": 257696.97,\n            \"expected_avg_denoise_ms\": 65.16,\n            \"expected_median_denoise_ms\": 65.86\n        },\n        \"wan2_1_t2v_1.3b_frame_interp_2x\": {\n            \"stages_ms\": {\n                \"TextEncodingStage\": 1104.4,\n                \"TimestepPreparationStage\": 2.19,\n                \"LatentPreparationStage\": 0.15,\n                \"DenoisingStage\": 8502.22,\n        
        \"DecodingStage\": 498.36,\n                \"InputValidationStage\": 0.07\n            },\n            \"denoise_step_ms\": {\n                \"0\": 91.83,\n                \"1\": 174.57,\n                \"2\": 170.48,\n                \"3\": 169.33,\n                \"4\": 169.24,\n                \"5\": 177.43,\n                \"6\": 173.73,\n                \"7\": 171.67,\n                \"8\": 170.98,\n                \"9\": 168.61,\n                \"10\": 169.96,\n                \"11\": 174.75,\n                \"12\": 172.33,\n                \"13\": 170.62,\n                \"14\": 169.84,\n                \"15\": 168.86,\n                \"16\": 171.32,\n                \"17\": 174.7,\n                \"18\": 172.31,\n                \"19\": 171.71,\n                \"20\": 170.98,\n                \"21\": 169.83,\n                \"22\": 170.54,\n                \"23\": 173.08,\n                \"24\": 172.11,\n                \"25\": 171.49,\n                \"26\": 171.0,\n                \"27\": 170.9,\n                \"28\": 171.78,\n                \"29\": 173.44,\n                \"30\": 171.14,\n                \"31\": 170.72,\n                \"32\": 170.64,\n                \"33\": 170.58,\n                \"34\": 172.51,\n                \"35\": 171.74,\n                \"36\": 171.57,\n                \"37\": 170.73,\n                \"38\": 171.49,\n                \"39\": 170.98,\n                \"40\": 172.63,\n                \"41\": 171.88,\n                \"42\": 171.71,\n                \"43\": 170.94,\n                \"44\": 170.31,\n                \"45\": 171.25,\n                \"46\": 171.43,\n                \"47\": 171.55,\n                \"48\": 172.08,\n                \"49\": 169.92\n            },\n            \"expected_e2e_ms\": 10464.97,\n            \"expected_avg_denoise_ms\": 169.92,\n            \"expected_median_denoise_ms\": 171.37\n        },\n        \"flux_2_klein_ti2i_2_gpus\": {\n            
\"stages_ms\": {\n                \"InputValidationStage\": 40.19,\n                \"TextEncodingStage\": 88.84,\n                \"ImageVAEEncodingStage\": 80.81,\n                \"LatentPreparationStage\": 1.05,\n                \"TimestepPreparationStage\": 28.64,\n                \"DenoisingStage\": 354.04,\n                \"DecodingStage\": 11.11\n            },\n            \"denoise_step_ms\": {\n                \"0\": 33.54,\n                \"1\": 61.3,\n                \"2\": 86.9,\n                \"3\": 87.55\n            },\n            \"expected_e2e_ms\": 716.81,\n            \"expected_avg_denoise_ms\": 67.32,\n            \"expected_median_denoise_ms\": 74.1\n        },\n        \"flux_2_image_t2i_upscaling_4x\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.15,\n                \"TextEncodingStage\": 537.4,\n                \"ImageVAEEncodingStage\": 0.01,\n                \"LatentPreparationStage\": 1.11,\n                \"TimestepPreparationStage\": 39.57,\n                \"DenoisingStage\": 24738.35,\n                \"DecodingStage\": 14.04\n            },\n            \"denoise_step_ms\": {\n                \"0\": 70.71,\n                \"1\": 476.99,\n                \"2\": 502.01,\n                \"3\": 483.39,\n                \"4\": 500.65,\n                \"5\": 487.43,\n                \"6\": 502.39,\n                \"7\": 484.76,\n                \"8\": 498.5,\n                \"9\": 489.98,\n                \"10\": 499.22,\n                \"11\": 490.47,\n                \"12\": 498.98,\n                \"13\": 491.42,\n                \"14\": 495.17,\n                \"15\": 492.24,\n                \"16\": 494.69,\n                \"17\": 491.68,\n                \"18\": 497.13,\n                \"19\": 493.63,\n                \"20\": 495.29,\n                \"21\": 496.19,\n                \"22\": 496.52,\n                \"23\": 496.31,\n                \"24\": 493.64,\n                
\"25\": 494.35,\n                \"26\": 493.27,\n                \"27\": 495.52,\n                \"28\": 493.06,\n                \"29\": 494.66,\n                \"30\": 494.08,\n                \"31\": 496.01,\n                \"32\": 494.79,\n                \"33\": 495.81,\n                \"34\": 493.94,\n                \"35\": 495.56,\n                \"36\": 493.5,\n                \"37\": 495.98,\n                \"38\": 495.82,\n                \"39\": 496.7,\n                \"40\": 495.29,\n                \"41\": 496.84,\n                \"42\": 495.67,\n                \"43\": 495.32,\n                \"44\": 496.48,\n                \"45\": 496.03,\n                \"46\": 495.65,\n                \"47\": 498.27,\n                \"48\": 496.44,\n                \"49\": 496.79\n            },\n            \"expected_e2e_ms\": 25735.08,\n            \"expected_avg_denoise_ms\": 486.1,\n            \"expected_median_denoise_ms\": 495.42\n        },\n        \"wan2_1_t2v_1.3b_upscaling_4x\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.08,\n                \"TextEncodingStage\": 1164.21,\n                \"LatentPreparationStage\": 0.26,\n                \"TimestepPreparationStage\": 3.86,\n                \"DenoisingStage\": 10234.35,\n                \"DecodingStage\": 499.44\n            },\n            \"denoise_step_ms\": {\n                \"0\": 200.87,\n                \"1\": 202.85,\n                \"2\": 203.32,\n                \"3\": 206.02,\n                \"4\": 205.75,\n                \"5\": 204.14,\n                \"6\": 205.12,\n                \"7\": 204.65,\n                \"8\": 203.99,\n                \"9\": 204.96,\n                \"10\": 204.35,\n                \"11\": 206.89,\n                \"12\": 200.69,\n                \"13\": 209.67,\n                \"14\": 204.91,\n                \"15\": 203.5,\n                \"16\": 206.73,\n                \"17\": 202.43,\n                
\"18\": 205.92,\n                \"19\": 204.61,\n                \"20\": 211.47,\n                \"21\": 197.43,\n                \"22\": 203.58,\n                \"23\": 205.82,\n                \"24\": 204.01,\n                \"25\": 205.06,\n                \"26\": 204.86,\n                \"27\": 206.03,\n                \"28\": 200.78,\n                \"29\": 206.99,\n                \"30\": 206.58,\n                \"31\": 202.84,\n                \"32\": 204.51,\n                \"33\": 204.19,\n                \"34\": 202.89,\n                \"35\": 204.55,\n                \"36\": 205.03,\n                \"37\": 204.2,\n                \"38\": 203.92,\n                \"39\": 204.9,\n                \"40\": 203.24,\n                \"41\": 204.21,\n                \"42\": 205.76,\n                \"43\": 205.32,\n                \"44\": 202.63,\n                \"45\": 205.67,\n                \"46\": 204.55,\n                \"47\": 202.89,\n                \"48\": 205.29,\n                \"49\": 203.87\n            },\n            \"expected_e2e_ms\": 12021.58,\n            \"expected_avg_denoise_ms\": 204.49,\n            \"expected_median_denoise_ms\": 204.55\n        },\n        \"wan2_1_t2v_1.3b_frame_interp_2x_upscaling_4x\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.03,\n                \"TextEncodingStage\": 1089.94,\n                \"LatentPreparationStage\": 0.12,\n                \"TimestepPreparationStage\": 1.99,\n                \"DenoisingStage\": 8617.56,\n                \"DecodingStage\": 469.68\n            },\n            \"denoise_step_ms\": {\n                \"0\": 122.51,\n                \"1\": 172.31,\n                \"2\": 169.58,\n                \"3\": 171.76,\n                \"4\": 171.5,\n                \"5\": 174.24,\n                \"6\": 175.85,\n                \"7\": 171.2,\n                \"8\": 172.0,\n                \"9\": 172.13,\n                \"10\": 171.22,\n     
           \"11\": 174.8,\n                \"12\": 174.01,\n                \"13\": 172.28,\n                \"14\": 172.77,\n                \"15\": 173.52,\n                \"16\": 172.19,\n                \"17\": 175.12,\n                \"18\": 172.78,\n                \"19\": 175.1,\n                \"20\": 171.54,\n                \"21\": 173.38,\n                \"22\": 171.61,\n                \"23\": 174.14,\n                \"24\": 174.06,\n                \"25\": 172.32,\n                \"26\": 173.08,\n                \"27\": 173.94,\n                \"28\": 173.32,\n                \"29\": 174.3,\n                \"30\": 173.63,\n                \"31\": 172.21,\n                \"32\": 174.4,\n                \"33\": 173.25,\n                \"34\": 173.54,\n                \"35\": 175.12,\n                \"36\": 172.93,\n                \"37\": 172.76,\n                \"38\": 174.73,\n                \"39\": 174.46,\n                \"40\": 172.66,\n                \"41\": 174.58,\n                \"42\": 173.9,\n                \"43\": 174.88,\n                \"44\": 172.35,\n                \"45\": 173.52,\n                \"46\": 175.94,\n                \"47\": 172.88,\n                \"48\": 174.97,\n                \"49\": 172.94\n            },\n            \"expected_e2e_ms\": 10425.77,\n            \"expected_avg_denoise_ms\": 172.28,\n            \"expected_median_denoise_ms\": 173.29\n        },\n        \"helios_base_t2v\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.07,\n                \"TextEncodingStage\": 1103.97,\n                \"LatentPreparationStage\": 0.24,\n                \"HeliosChunkedDenoisingStage\": 118580.37,\n                \"HeliosDecodingStage\": 664.79,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {},\n            \"expected_e2e_ms\": 120413.51,\n            \"expected_avg_denoise_ms\": 0.0,\n            
\"expected_median_denoise_ms\": 0.0\n        },\n        \"helios_distilled_t2v\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.13,\n                \"TextEncodingStage\": 581.79,\n                \"LatentPreparationStage\": 0.18,\n                \"HeliosChunkedDenoisingStage\": 49752.88,\n                \"HeliosDecodingStage\": 666.69,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {},\n            \"expected_e2e_ms\": 51038.66,\n            \"expected_avg_denoise_ms\": 0.0,\n            \"expected_median_denoise_ms\": 0.0\n        },\n        \"helios_mid_t2v\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.05,\n                \"TextEncodingStage\": 1101.99,\n                \"LatentPreparationStage\": 0.16,\n                \"HeliosChunkedDenoisingStage\": 77728.72,\n                \"HeliosDecodingStage\": 661.23,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {},\n            \"expected_e2e_ms\": 79600.62,\n            \"expected_avg_denoise_ms\": 0.0,\n            \"expected_median_denoise_ms\": 0.0\n        },\n        \"helios_base_t2v\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.04,\n                \"TextEncodingStage\": 1102.45,\n                \"LatentPreparationStage\": 0.14,\n                \"HeliosChunkedDenoisingStage\": 116964.69,\n                \"HeliosDecodingStage\": 664.76,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 1893.3,\n                \"1\": 1900.93,\n                \"2\": 1934.08,\n                \"3\": 1897.65,\n                \"4\": 1907.59,\n                \"5\": 1909.1,\n                \"6\": 1911.51,\n                \"7\": 1909.25,\n                \"8\": 1911.69,\n                \"9\": 1911.77,\n                \"10\": 1913.35,\n              
  \"11\": 1915.44,\n                \"12\": 1912.11,\n                \"13\": 1910.08,\n                \"14\": 1911.77,\n                \"15\": 1908.22,\n                \"16\": 1908.83,\n                \"17\": 1910.11,\n                \"18\": 1908.19,\n                \"19\": 1911.99,\n                \"20\": 1909.96,\n                \"21\": 1910.32,\n                \"22\": 1911.76,\n                \"23\": 1911.87,\n                \"24\": 1908.91,\n                \"25\": 1912.41,\n                \"26\": 1913.15,\n                \"27\": 1908.34,\n                \"28\": 1913.21,\n                \"29\": 1911.98,\n                \"30\": 1912.16,\n                \"31\": 1914.17,\n                \"32\": 1911.45,\n                \"33\": 1912.5,\n                \"34\": 1914.48,\n                \"35\": 1912.64,\n                \"36\": 1912.24,\n                \"37\": 1914.48,\n                \"38\": 1911.06,\n                \"39\": 1915.45,\n                \"40\": 1914.0,\n                \"41\": 1912.99,\n                \"42\": 1913.68,\n                \"43\": 1914.09,\n                \"44\": 1915.83,\n                \"45\": 1913.36,\n                \"46\": 1914.84,\n                \"47\": 1915.31,\n                \"48\": 1915.58,\n                \"49\": 1912.63\n            },\n            \"expected_e2e_ms\": 118821.41,\n            \"expected_avg_denoise_ms\": 1911.64,\n            \"expected_median_denoise_ms\": 1912.05\n        },\n        \"helios_mid_t2v\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.09,\n                \"TextEncodingStage\": 1102.28,\n                \"LatentPreparationStage\": 0.23,\n                \"HeliosChunkedDenoisingStage\": 77947.9,\n                \"HeliosDecodingStage\": 664.96,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 404.46,\n                \"1\": 404.88,\n                \"2\": 
405.35,\n                \"3\": 406.01,\n                \"4\": 404.97,\n                \"5\": 405.07,\n                \"6\": 405.06,\n                \"7\": 404.98,\n                \"8\": 405.39,\n                \"9\": 405.52,\n                \"10\": 405.76,\n                \"11\": 405.53,\n                \"12\": 405.16,\n                \"13\": 405.46,\n                \"14\": 405.75,\n                \"15\": 405.69,\n                \"16\": 405.26,\n                \"17\": 405.23,\n                \"18\": 405.42,\n                \"19\": 405.99,\n                \"20\": 663.39,\n                \"21\": 666.6,\n                \"22\": 665.73,\n                \"23\": 666.37,\n                \"24\": 667.43,\n                \"25\": 668.28,\n                \"26\": 667.96,\n                \"27\": 668.93,\n                \"28\": 667.78,\n                \"29\": 668.15,\n                \"30\": 668.91,\n                \"31\": 667.22,\n                \"32\": 669.31,\n                \"33\": 666.57,\n                \"34\": 669.78,\n                \"35\": 668.38,\n                \"36\": 669.95,\n                \"37\": 668.76,\n                \"38\": 667.82,\n                \"39\": 668.98,\n                \"40\": 1891.05,\n                \"41\": 1893.52,\n                \"42\": 1893.48,\n                \"43\": 1892.79,\n                \"44\": 1892.03,\n                \"45\": 1892.87,\n                \"46\": 1895.55,\n                \"47\": 1892.19,\n                \"48\": 1892.89,\n                \"49\": 1892.32,\n                \"50\": 1890.25,\n                \"51\": 1894.1,\n                \"52\": 1890.67,\n                \"53\": 1892.09,\n                \"54\": 1892.64,\n                \"55\": 1891.91,\n                \"56\": 1894.27,\n                \"57\": 1893.62,\n                \"58\": 1892.65,\n                \"59\": 1891.9\n            },\n            \"expected_e2e_ms\": 79824.32,\n            \"expected_avg_denoise_ms\": 
988.6,\n            \"expected_median_denoise_ms\": 668.05\n        },\n        \"helios_distilled_t2v\": {\n            \"stages_ms\": {\n                \"InputValidationStage\": 0.05,\n                \"TextEncodingStage\": 552.02,\n                \"LatentPreparationStage\": 0.13,\n                \"HeliosChunkedDenoisingStage\": 57879.88,\n                \"HeliosDecodingStage\": 663.31,\n                \"per_frame_generation\": null\n            },\n            \"denoise_step_ms\": {\n                \"0\": 207.03,\n                \"1\": 204.36,\n                \"2\": 203.87,\n                \"3\": 204.51,\n                \"4\": 206.21,\n                \"5\": 205.54,\n                \"6\": 205.06,\n                \"7\": 205.45,\n                \"8\": 205.96,\n                \"9\": 205.95,\n                \"10\": 205.22,\n                \"11\": 204.43,\n                \"12\": 205.14,\n                \"13\": 205.06,\n                \"14\": 205.11,\n                \"15\": 206.09,\n                \"16\": 205.1,\n                \"17\": 204.99,\n                \"18\": 204.55,\n                \"19\": 205.14,\n                \"20\": 337.47,\n                \"21\": 337.06,\n                \"22\": 337.68,\n                \"23\": 336.58,\n                \"24\": 335.98,\n                \"25\": 335.84,\n                \"26\": 336.01,\n                \"27\": 335.61,\n                \"28\": 335.79,\n                \"29\": 335.62,\n                \"30\": 336.69,\n                \"31\": 335.98,\n                \"32\": 336.15,\n                \"33\": 336.55,\n                \"34\": 336.98,\n                \"35\": 337.33,\n                \"36\": 336.34,\n                \"37\": 335.94,\n                \"38\": 336.69,\n                \"39\": 336.14,\n                \"40\": 954.88,\n                \"41\": 956.2,\n                \"42\": 953.9,\n                \"43\": 953.49,\n                \"44\": 957.1,\n                \"45\": 
956.95,\n                \"46\": 955.02,\n                \"47\": 954.98,\n                \"48\": 956.0,\n                \"49\": 956.63,\n                \"50\": 958.66,\n                \"51\": 957.26,\n                \"52\": 956.73,\n                \"53\": 955.06,\n                \"54\": 957.04,\n                \"55\": 958.07,\n                \"56\": 958.28,\n                \"57\": 957.99,\n                \"58\": 957.61,\n                \"59\": 956.98\n            },\n            \"expected_e2e_ms\": 59168.9,\n            \"expected_avg_denoise_ms\": 499.37,\n            \"expected_median_denoise_ms\": 336.25\n        }\n    }\n}\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/test_server_2_gpu_a.py",
    "content": "\"\"\"\n2 GPU tests\n\"\"\"\n\nfrom __future__ import annotations\n\nimport pytest\n\nfrom sglang.multimodal_gen.test.server.test_server_common import (  # noqa: F401\n    DiffusionServerBase,\n    diffusion_server,\n)\nfrom sglang.multimodal_gen.test.server.testcase_configs import (\n    TWO_GPU_CASES_A,\n    DiffusionTestCase,\n)\n\n\nclass TestDiffusionServerTwoGpu(DiffusionServerBase):\n    \"\"\"Performance tests for 2-GPU diffusion cases.\"\"\"\n\n    @pytest.fixture(params=TWO_GPU_CASES_A, ids=lambda c: c.id)\n    def case(self, request) -> DiffusionTestCase:\n        \"\"\"Provide a DiffusionTestCase for each 2-GPU test.\"\"\"\n        return request.param\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/test_server_2_gpu_b.py",
    "content": "\"\"\"\n2 GPU tests\n\"\"\"\n\nfrom __future__ import annotations\n\nimport pytest\n\nfrom sglang.multimodal_gen.test.server.test_server_common import (  # noqa: F401\n    DiffusionServerBase,\n    diffusion_server,\n)\nfrom sglang.multimodal_gen.test.server.testcase_configs import (\n    TWO_GPU_CASES_B,\n    DiffusionTestCase,\n)\n\n\nclass TestDiffusionServerTwoGpu(DiffusionServerBase):\n    \"\"\"Performance tests for 2-GPU diffusion cases.\"\"\"\n\n    @pytest.fixture(params=TWO_GPU_CASES_B, ids=lambda c: c.id)\n    def case(self, request) -> DiffusionTestCase:\n        \"\"\"Provide a DiffusionTestCase for each 2-GPU test.\"\"\"\n        return request.param\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/test_server_a.py",
    "content": "\"\"\"\nConfig-driven diffusion performance test with pytest parametrization.\n\nIf the actual run is significantly better than the baseline, the improved cases and their updated baselines are printed.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport pytest\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.test.server.test_server_common import (  # noqa: F401\n    DiffusionServerBase,\n    diffusion_server,\n)\nfrom sglang.multimodal_gen.test.server.testcase_configs import (\n    ONE_GPU_CASES_A,\n    DiffusionTestCase,\n)\n\nlogger = init_logger(__name__)\n\n\nclass TestDiffusionServerOneGpu(DiffusionServerBase):\n    \"\"\"Performance tests for 1-GPU diffusion cases.\"\"\"\n\n    @pytest.fixture(params=ONE_GPU_CASES_A, ids=lambda c: c.id)\n    def case(self, request) -> DiffusionTestCase:\n        \"\"\"Provide a DiffusionTestCase for each 1-GPU test.\"\"\"\n        return request.param\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/test_server_b.py",
    "content": "\"\"\"\nConfig-driven diffusion performance test with pytest parametrization.\n\nIf the actual run is significantly better than the baseline, the improved cases and their updated baselines are printed.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport pytest\n\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.test.server.test_server_common import (  # noqa: F401\n    DiffusionServerBase,\n    diffusion_server,\n)\nfrom sglang.multimodal_gen.test.server.testcase_configs import (\n    ONE_GPU_CASES_B,\n    DiffusionTestCase,\n)\n\nlogger = init_logger(__name__)\n\n\nclass TestDiffusionServerOneGpu(DiffusionServerBase):\n    \"\"\"Performance tests for 1-GPU diffusion cases.\"\"\"\n\n    @pytest.fixture(params=ONE_GPU_CASES_B, ids=lambda c: c.id)\n    def case(self, request) -> DiffusionTestCase:\n        \"\"\"Provide a DiffusionTestCase for each 1-GPU test.\"\"\"\n        return request.param\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/test_server_common.py",
    "content": "\"\"\"\nConfig-driven diffusion generation test with pytest parametrization.\n\n\nIf the actual run is significantly better than the baseline, the improved cases with their updated baseline will be printed\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nfrom pathlib import Path\nfrom typing import Any, Callable\n\nimport openai\nimport pytest\nimport requests\nfrom openai import OpenAI\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import RequestPerfRecord\nfrom sglang.multimodal_gen.test.server import conftest\nfrom sglang.multimodal_gen.test.server.test_server_utils import (\n    VALIDATOR_REGISTRY,\n    PerformanceValidator,\n    ServerContext,\n    ServerManager,\n    get_generate_fn,\n)\nfrom sglang.multimodal_gen.test.server.testcase_configs import (\n    BASELINE_CONFIG,\n    DiffusionTestCase,\n    PerformanceSummary,\n    ScenarioConfig,\n)\nfrom sglang.multimodal_gen.test.test_utils import (\n    _consistency_gt_filenames,\n    extract_key_frames_from_video,\n    get_dynamic_server_port,\n    wait_for_req_perf_record,\n)\n\nlogger = init_logger(__name__)\n\n\n@pytest.fixture\ndef diffusion_server(case: DiffusionTestCase) -> ServerContext:\n    \"\"\"Start a diffusion server for a single case and tear it down afterwards.\"\"\"\n    server_args = case.server_args\n\n    # Skip ring attention tests on AMD/ROCm - Ring Attention requires Flash Attention\n    # which is not available on AMD. 
Use Ulysses parallelism instead.\n    if (\n        current_platform.is_hip()\n        and server_args.ring_degree is not None\n        and server_args.ring_degree > 1\n    ):\n        pytest.skip(\n            f\"Skipping {case.id}: Ring Attention (ring_degree={server_args.ring_degree}) \"\n            \"requires Flash Attention which is not available on AMD/ROCm\"\n        )\n\n    default_port = get_dynamic_server_port()\n    port = int(os.environ.get(\"SGLANG_TEST_SERVER_PORT\", default_port))\n    sampling_params = case.sampling_params\n    extra_args = os.environ.get(\"SGLANG_TEST_SERVE_ARGS\", \"\")\n\n    # In GT generation mode, force --backend diffusers\n    if os.environ.get(\"SGLANG_GEN_GT\", \"0\") == \"1\":\n        if \"--backend\" not in extra_args:\n            extra_args = \"--backend diffusers \" + extra_args.strip()\n\n    extra_args += f\" --num-gpus {server_args.num_gpus}\"\n\n    if server_args.tp_size is not None:\n        extra_args += f\" --tp-size {server_args.tp_size}\"\n\n    if server_args.ulysses_degree is not None:\n        extra_args += f\" --ulysses-degree {server_args.ulysses_degree}\"\n\n    if server_args.dit_layerwise_offload:\n        extra_args += f\" --dit-layerwise-offload true\"\n\n    if server_args.dit_offload_prefetch_size:\n        extra_args += (\n            f\" --dit-offload-prefetch-size {server_args.dit_offload_prefetch_size}\"\n        )\n\n    if server_args.text_encoder_cpu_offload:\n        extra_args += f\" --text-encoder-cpu-offload\"\n\n    if server_args.ring_degree is not None:\n        extra_args += f\" --ring-degree {server_args.ring_degree}\"\n\n    # LoRA support\n    if server_args.lora_path:\n        extra_args += f\" --lora-path {server_args.lora_path}\"\n\n    # default warmup\n    extra_args += f\" --warmup\"\n\n    for arg in server_args.extras:\n        extra_args += f\" {arg}\"\n\n    # Build custom environment variables\n    env_vars = {}\n    if server_args.enable_cache_dit:\n        
env_vars[\"SGLANG_CACHE_DIT_ENABLED\"] = \"true\"\n\n    # start server\n    manager = ServerManager(\n        model=server_args.model_path,\n        port=port,\n        wait_deadline=float(os.environ.get(\"SGLANG_TEST_WAIT_SECS\", \"1200\")),\n        extra_args=extra_args,\n        env_vars=env_vars,\n    )\n    ctx = manager.start()\n\n    try:\n        # Reconstruct output size for OpenAI API\n        # Allow override via environment variable (useful for AMD where large resolutions can cause GPU hang)\n        output_size = os.environ.get(\n            \"SGLANG_TEST_OUTPUT_SIZE\", sampling_params.output_size\n        )\n    except Exception as exc:\n        logger.error(\"Warm-up failed for %s: %s\", case.id, exc)\n        ctx.cleanup()\n        raise\n\n    try:\n        yield ctx\n    finally:\n        ctx.cleanup()\n\n\nclass DiffusionServerBase:\n    \"\"\"Performance tests for all diffusion models/scenarios.\n\n    This single test class runs against all cases defined in ONE_GPU_CASES.\n    Each case gets its own server instance via the parametrized fixture.\n    \"\"\"\n\n    _perf_results: list[dict[str, Any]] = []\n    _improved_baselines: list[dict[str, Any]] = []\n    _pytest_config = None  # Store pytest config for stash access\n\n    @classmethod\n    def setup_class(cls):\n        cls._perf_results = []\n        cls._improved_baselines = []\n\n    @classmethod\n    def teardown_class(cls):\n        print(\n            f\"\\n[DEBUG teardown_class] Called for {cls.__name__}, _perf_results has {len(cls._perf_results)} entries\"\n        )\n        if cls._pytest_config:\n            # Add results to pytest stash (shared across all import contexts)\n            for result in cls._perf_results:\n                result[\"class_name\"] = cls.__name__\n            conftest.add_perf_results(cls._pytest_config, cls._perf_results)\n            print(\n                f\"[DEBUG teardown_class] Added {len(cls._perf_results)} results to stash\"\n            )\n  
      else:\n            print(\n                \"[DEBUG teardown_class] No pytest_config available, skipping stash update\"\n            )\n\n        if cls._improved_baselines:\n            import json\n\n            output = \"\"\"\n--- POTENTIAL BASELINE IMPROVEMENTS DETECTED ---\nThe following test cases performed significantly better than their baselines.\nConsider updating perf_baselines.json with the snippets below:\n\"\"\"\n            for item in cls._improved_baselines:\n                output += (\n                    f'\\n\"{item[\"id\"]}\": {json.dumps(item[\"baseline\"], indent=4)},\\n'\n                )\n            print(output)\n\n    @pytest.fixture(autouse=True)\n    def _capture_pytest_config(self, request):\n        \"\"\"Capture pytest config for use in teardown_class.\"\"\"\n        self.__class__._pytest_config = request.config\n\n    def _client(self, ctx: ServerContext) -> OpenAI:\n        \"\"\"Get OpenAI client for the server.\"\"\"\n        return OpenAI(\n            api_key=\"sglang-anything\",\n            base_url=f\"http://localhost:{ctx.port}/v1\",\n        )\n\n    def run_and_collect(\n        self,\n        ctx: ServerContext,\n        case_id: str,\n        generate_fn: Callable[[str, openai.Client], tuple[str, bytes]],\n    ) -> tuple[RequestPerfRecord, bytes]:\n        \"\"\"Run generation and collect performance records.\n\n        Returns:\n            Tuple of (performance_record, content_bytes)\n        \"\"\"\n        log_path = ctx.perf_log_path\n        log_wait_timeout = 30\n\n        client = self._client(ctx)\n        rid, content = generate_fn(case_id, client)\n\n        req_perf_record = wait_for_req_perf_record(\n            rid,\n            log_path,\n            timeout=log_wait_timeout,\n        )\n\n        return (req_perf_record, content)\n\n    def _validate_and_record(\n        self,\n        case: DiffusionTestCase,\n        perf_record: RequestPerfRecord,\n    ) -> None:\n        \"\"\"Validate 
metrics and record results.\"\"\"\n        is_baseline_generation_mode = os.environ.get(\"SGLANG_GEN_BASELINE\", \"0\") == \"1\"\n\n        scenario = BASELINE_CONFIG.scenarios.get(case.id)\n        missing_scenario = False\n        if scenario is None:\n            # Create dummy scenario to allow metric collection\n            scenario = type(\n                \"DummyScenario\",\n                (),\n                {\n                    \"expected_e2e_ms\": 0,\n                    \"expected_avg_denoise_ms\": 0,\n                    \"expected_median_denoise_ms\": 0,\n                    \"stages_ms\": {},\n                    \"denoise_step_ms\": {},\n                },\n            )()\n            if not is_baseline_generation_mode:\n                missing_scenario = True\n\n        validator_name = case.server_args.custom_validator or \"default\"\n        validator_class = VALIDATOR_REGISTRY.get(validator_name, PerformanceValidator)\n\n        validator = validator_class(\n            scenario=scenario,\n            tolerances=BASELINE_CONFIG.tolerances,\n            step_fractions=BASELINE_CONFIG.step_fractions,\n        )\n\n        summary = validator.collect_metrics(perf_record)\n\n        if case.run_perf_check:\n            if is_baseline_generation_mode or missing_scenario:\n                self._dump_baseline_for_testcase(case, summary, missing_scenario)\n                if missing_scenario:\n                    pytest.fail(\n                        f\"Testcase '{case.id}' not found in perf_baselines.json\"\n                    )\n                return\n\n            self._check_for_improvement(case, summary, scenario)\n\n            # only run performance validation if run_perf_check is True\n            try:\n                validator.validate(perf_record, case.sampling_params.num_frames)\n            except AssertionError as e:\n                logger.error(f\"Performance validation failed for {case.id}:\\n{e}\")\n                
self._dump_baseline_for_testcase(case, summary, missing_scenario)\n                raise\n\n        result = {\n            \"test_name\": case.id,\n            \"modality\": case.server_args.modality,\n            \"e2e_ms\": summary.e2e_ms,\n            \"avg_denoise_ms\": summary.avg_denoise_ms,\n            \"median_denoise_ms\": summary.median_denoise_ms,\n            \"stage_metrics\": summary.stage_metrics,\n            \"sampled_steps\": summary.sampled_steps,\n        }\n\n        # video-specific metrics\n        if summary.frames_per_second:\n            result.update(\n                {\n                    \"frames_per_second\": summary.frames_per_second,\n                    \"total_frames\": summary.total_frames,\n                    \"avg_frame_time_ms\": summary.avg_frame_time_ms,\n                }\n            )\n\n        self.__class__._perf_results.append(result)\n        print(\n            f\"[DEBUG _validate_and_record] Appended result for {case.id}, class {self.__class__.__name__} now has {len(self.__class__._perf_results)} results\"\n        )\n\n    def _check_for_improvement(\n        self,\n        case: DiffusionTestCase,\n        summary: PerformanceSummary,\n        scenario: \"ScenarioConfig\",\n    ) -> None:\n        \"\"\"Check for potential significant performance improvements and record them.\"\"\"\n        is_improved = False\n        threshold = BASELINE_CONFIG.improvement_threshold\n\n        def is_sig_faster(actual, expected):\n            if expected == 0 or expected is None:\n                return False\n            return actual < expected * (1 - threshold)\n\n        def safe_get_metric(metric_dict, key):\n            val = metric_dict.get(key)\n            return val if val is not None else float(\"inf\")\n\n        # Check for any significant improvement\n        if (\n            is_sig_faster(summary.e2e_ms, scenario.expected_e2e_ms)\n            or is_sig_faster(summary.avg_denoise_ms, 
scenario.expected_avg_denoise_ms)\n            or is_sig_faster(\n                summary.median_denoise_ms, scenario.expected_median_denoise_ms\n            )\n        ):\n            is_improved = True\n        # Combine metrics, always taking the better (lower) value\n        new_stages = {\n            stage: min(\n                safe_get_metric(summary.stage_metrics, stage),\n                safe_get_metric(scenario.stages_ms, stage),\n            )\n            for stage in set(summary.stage_metrics) | set(scenario.stages_ms)\n        }\n        new_denoise_steps = {\n            step: min(\n                safe_get_metric(summary.all_denoise_steps, step),\n                safe_get_metric(scenario.denoise_step_ms, step),\n            )\n            for step in set(summary.all_denoise_steps.keys())\n            | set(scenario.denoise_step_ms)\n        }\n\n        # Check for stage-level improvements\n        if not is_improved:\n            for stage, new_val in new_stages.items():\n                if is_sig_faster(new_val, scenario.stages_ms.get(stage, float(\"inf\"))):\n                    is_improved = True\n                    break\n        if not is_improved:\n            for step, new_val in new_denoise_steps.items():\n                if is_sig_faster(\n                    new_val, scenario.denoise_step_ms.get(step, float(\"inf\"))\n                ):\n                    is_improved = True\n                    break\n\n        if is_improved:\n            new_baseline = {\n                \"stages_ms\": {k: round(v, 2) for k, v in new_stages.items()},\n                \"denoise_step_ms\": {\n                    str(k): round(v, 2) for k, v in new_denoise_steps.items()\n                },\n                \"expected_e2e_ms\": round(\n                    min(summary.e2e_ms, scenario.expected_e2e_ms), 2\n                ),\n                \"expected_avg_denoise_ms\": round(\n                    min(summary.avg_denoise_ms, 
scenario.expected_avg_denoise_ms), 2\n                ),\n                \"expected_median_denoise_ms\": round(\n                    min(summary.median_denoise_ms, scenario.expected_median_denoise_ms),\n                    2,\n                ),\n            }\n            self._improved_baselines.append({\"id\": case.id, \"baseline\": new_baseline})\n\n    def _dump_baseline_for_testcase(\n        self,\n        case: DiffusionTestCase,\n        summary: \"PerformanceSummary\",\n        missing_scenario: bool = False,\n    ) -> None:\n        \"\"\"Dump performance metrics as a JSON scenario for baselines.\"\"\"\n        import json\n\n        denoise_steps_formatted = {\n            str(k): round(v, 2) for k, v in summary.all_denoise_steps.items()\n        }\n        stages_formatted = {k: round(v, 2) for k, v in summary.stage_metrics.items()}\n\n        baseline = {\n            \"stages_ms\": stages_formatted,\n            \"denoise_step_ms\": denoise_steps_formatted,\n            \"expected_e2e_ms\": round(summary.e2e_ms, 2),\n            \"expected_avg_denoise_ms\": round(summary.avg_denoise_ms, 2),\n            \"expected_median_denoise_ms\": round(summary.median_denoise_ms, 2),\n        }\n\n        # Video-specific metrics\n        if case.server_args.modality == \"video\":\n            if \"per_frame_generation\" not in baseline[\"stages_ms\"]:\n                baseline[\"stages_ms\"][\"per_frame_generation\"] = (\n                    round(summary.avg_frame_time_ms, 2)\n                    if summary.avg_frame_time_ms\n                    else None\n                )\n        action = \"add\" if missing_scenario else \"update\"\n        output = f\"\"\"\n{action} this baseline in the \"scenarios\" section of perf_baselines.json:\n\n\"{case.id}\": {json.dumps(baseline, indent=4)}\n\n\"\"\"\n        logger.error(output)\n\n    def _save_gt_output(\n        self,\n        case: DiffusionTestCase,\n        content: bytes,\n    ) -> None:\n        \"\"\"Save 
generated content as ground truth files.\n\n        Args:\n            case: Test case configuration\n            content: Generated content bytes (image or video)\n        \"\"\"\n        gt_output_dir = os.environ.get(\"SGLANG_GT_OUTPUT_DIR\")\n        if not gt_output_dir:\n            logger.error(\"SGLANG_GT_OUTPUT_DIR not set, cannot save GT output\")\n            return\n\n        out_dir = Path(gt_output_dir)\n        out_dir.mkdir(parents=True, exist_ok=True)\n\n        num_gpus = case.server_args.num_gpus\n        is_video = case.server_args.modality == \"video\"\n\n        if is_video:\n            # Extract key frames from video\n            frames = extract_key_frames_from_video(\n                content, num_frames=case.sampling_params.num_frames\n            )\n\n            if len(frames) != 3:\n                logger.warning(\n                    f\"{case.id}: expected 3 frames, got {len(frames)}, skipping frame save\"\n                )\n                return\n\n            # Save frames (reuse naming from _consistency_gt_filenames)\n            filenames = _consistency_gt_filenames(case.id, num_gpus, is_video=True)\n            from PIL import Image\n\n            for frame, fn in zip(frames, filenames):\n                frame_path = out_dir / fn\n                Image.fromarray(frame).save(frame_path)\n                logger.info(f\"Saved GT frame: {frame_path}\")\n        else:\n            # Save image\n            from sglang.multimodal_gen.test.test_utils import detect_image_format\n\n            detected_format = detect_image_format(content)\n            filenames = _consistency_gt_filenames(\n                case.id, num_gpus, is_video=False, output_format=detected_format\n            )\n            output_path = out_dir / filenames[0]\n            output_path.write_bytes(content)\n            logger.info(f\"Saved GT image: {output_path} (format: {detected_format})\")\n\n    def _test_lora_api_functionality(\n        self,\n        ctx: 
ServerContext,\n        case: DiffusionTestCase,\n        generate_fn: Callable[[str, openai.Client], tuple[str, bytes]],\n    ) -> None:\n        \"\"\"\n        Test LoRA API functionality with end-to-end validation: merge, unmerge, and set_lora.\n        This test verifies that each API call succeeds AND that generation works after each operation.\n        \"\"\"\n        base_url = f\"http://localhost:{ctx.port}/v1\"\n        client = OpenAI(base_url=base_url, api_key=\"dummy\")\n\n        # Test 1: unmerge_lora_weights - API should succeed and generation should work\n        logger.info(\"[LoRA E2E] Testing unmerge_lora_weights for %s\", case.id)\n        resp = requests.post(f\"{base_url}/unmerge_lora_weights\")\n        assert resp.status_code == 200, f\"unmerge_lora_weights failed: {resp.text}\"\n\n        logger.info(\"[LoRA E2E] Verifying generation after unmerge for %s\", case.id)\n        rid_after_unmerge, _ = generate_fn(case.id, client)\n        assert rid_after_unmerge is not None, \"Generation after unmerge failed\"\n        logger.info(\"[LoRA E2E] Generation after unmerge succeeded\")\n\n        # Test 2: merge_lora_weights - API should succeed and generation should work\n        logger.info(\"[LoRA E2E] Testing merge_lora_weights for %s\", case.id)\n        resp = requests.post(f\"{base_url}/merge_lora_weights\")\n        assert resp.status_code == 200, f\"merge_lora_weights failed: {resp.text}\"\n\n        logger.info(\"[LoRA E2E] Verifying generation after re-merge for %s\", case.id)\n        rid_after_merge, _ = generate_fn(case.id, client)\n        assert rid_after_merge is not None, \"Generation after merge failed\"\n        logger.info(\"[LoRA E2E] Generation after merge succeeded\")\n\n        # Test 3: set_lora (re-set the same adapter) - API should succeed and generation should work\n        logger.info(\"[LoRA E2E] Testing set_lora for %s\", case.id)\n        resp = requests.post(f\"{base_url}/set_lora\", json={\"lora_nickname\": 
\"default\"})\n        assert resp.status_code == 200, f\"set_lora failed: {resp.text}\"\n\n        logger.info(\"[LoRA E2E] Verifying generation after set_lora for %s\", case.id)\n        rid_after_set, _ = generate_fn(case.id, client)\n        assert rid_after_set is not None, \"Generation after set_lora failed\"\n        logger.info(\"[LoRA E2E] Generation after set_lora succeeded\")\n\n        # Test 4: list_loras - API should return the expected list of LoRA adapters\n        logger.info(\"[LoRA E2E] Testing list_loras for %s\", case.id)\n        resp = requests.get(f\"{base_url}/list_loras\")\n        assert resp.status_code == 200, f\"list_loras failed: {resp.text}\"\n        lora_info = resp.json()\n        logger.info(\"[LoRA E2E] list_loras returned %s\", lora_info)\n        assert (\n            isinstance(lora_info[\"loaded_adapters\"], list)\n            and len(lora_info[\"loaded_adapters\"]) > 0\n        ), \"loaded_adapters should be a non-empty list\"\n        assert any(\n            a.get(\"nickname\") == \"default\" for a in lora_info[\"loaded_adapters\"]\n        ), f\"nickname 'default' not found in loaded_adapters: {lora_info['loaded_adapters']}\"\n        logger.info(\"[LoRA E2E] list_loras returned expected LoRA adapters\")\n\n        logger.info(\"[LoRA E2E] All LoRA API E2E tests passed for %s\", case.id)\n\n    def _test_lora_dynamic_switch_e2e(\n        self,\n        ctx: ServerContext,\n        case: DiffusionTestCase,\n        generate_fn: Callable[[str, openai.Client], tuple[str, bytes]],\n        second_lora_path: str,\n    ) -> None:\n        \"\"\"\n        Test dynamic LoRA switching with end-to-end validation.\n        This test verifies that switching between LoRA adapters works correctly\n        and generation succeeds after each switch.\n        \"\"\"\n        base_url = f\"http://localhost:{ctx.port}/v1\"\n        client = OpenAI(base_url=base_url, api_key=\"dummy\")\n\n        # Test 1: Generate with initial LoRA\n       
 logger.info(\n            \"[LoRA Switch E2E] Testing generation with initial LoRA for %s\", case.id\n        )\n        rid_initial, _ = generate_fn(case.id, client)\n        assert rid_initial is not None, \"Generation with initial LoRA failed\"\n        logger.info(\"[LoRA Switch E2E] Generation with initial LoRA succeeded\")\n\n        # Test 2: Switch to second LoRA and generate\n        logger.info(\n            \"[LoRA Switch E2E] Switching to second LoRA adapter for %s\", case.id\n        )\n        resp = requests.post(\n            f\"{base_url}/set_lora\",\n            json={\"lora_nickname\": \"lora2\", \"lora_path\": second_lora_path},\n        )\n        assert (\n            resp.status_code == 200\n        ), f\"set_lora to second adapter failed: {resp.text}\"\n\n        logger.info(\n            \"[LoRA Switch E2E] Verifying generation with second LoRA for %s\", case.id\n        )\n        rid_second, _ = generate_fn(case.id, client)\n        assert rid_second is not None, \"Generation with second LoRA failed\"\n        logger.info(\"[LoRA Switch E2E] Generation with second LoRA succeeded\")\n\n        # Test 3: Switch back to original LoRA and generate\n        logger.info(\"[LoRA Switch E2E] Switching back to original LoRA for %s\", case.id)\n        resp = requests.post(f\"{base_url}/set_lora\", json={\"lora_nickname\": \"default\"})\n        assert resp.status_code == 200, f\"set_lora back to default failed: {resp.text}\"\n\n        logger.info(\n            \"[LoRA Switch E2E] Verifying generation after switching back for %s\",\n            case.id,\n        )\n        rid_switched_back, _ = generate_fn(case.id, client)\n        assert rid_switched_back is not None, \"Generation after switching back failed\"\n        logger.info(\"[LoRA Switch E2E] Generation after switching back succeeded\")\n\n        logger.info(\n            \"[LoRA Switch E2E] All dynamic switch E2E tests passed for %s\", case.id\n        )\n\n    def 
_test_dynamic_lora_loading(\n        self,\n        ctx: ServerContext,\n        case: DiffusionTestCase,\n    ) -> None:\n        \"\"\"\n        Test dynamic LoRA loading after server startup.\n\n        This test reproduces the LayerwiseOffload + set_lora issue:\n        - Server starts WITHOUT lora_path (LayerwiseOffloadManager initializes first)\n        - Then set_lora is called via API to load LoRA dynamically\n        - This tests the interaction between layerwise offload and dynamic LoRA loading\n        \"\"\"\n        base_url = f\"http://localhost:{ctx.port}/v1\"\n        dynamic_lora_path = case.server_args.dynamic_lora_path\n\n        # Call set_lora to load LoRA dynamically after server startup\n        logger.info(\n            \"[Dynamic LoRA] Loading LoRA dynamically via set_lora API for %s\", case.id\n        )\n        logger.info(\"[Dynamic LoRA] LoRA path: %s\", dynamic_lora_path)\n        resp = requests.post(\n            f\"{base_url}/set_lora\",\n            json={\"lora_nickname\": \"default\", \"lora_path\": dynamic_lora_path},\n        )\n        assert resp.status_code == 200, f\"Dynamic set_lora failed: {resp.text}\"\n        logger.info(\"[Dynamic LoRA] set_lora succeeded for %s\", case.id)\n\n    def _test_multi_lora_e2e(\n        self,\n        ctx: ServerContext,\n        case: DiffusionTestCase,\n        generate_fn: Callable[[str, openai.Client], tuple[str, bytes]],\n        first_lora_path: str,\n        second_lora_path: str,\n    ) -> None:\n        \"\"\"\n        Test multiple LoRA adapters with different set_lora input scenarios.\n        Tests: basic multi-LoRA, different strengths, cached adapters, switch back to single.\n        \"\"\"\n        base_url = f\"http://localhost:{ctx.port}/v1\"\n        client = OpenAI(base_url=base_url, api_key=\"dummy\")\n\n        # Test 1: Basic multi-LoRA with list format\n        resp = requests.post(\n            f\"{base_url}/set_lora\",\n            json={\n                
\"lora_nickname\": [\"default\", \"lora2\"],\n                \"lora_path\": [first_lora_path, second_lora_path],\n                \"target\": \"all\",\n                \"strength\": [1.0, 1.0],\n            },\n        )\n        assert (\n            resp.status_code == 200\n        ), f\"set_lora with multiple adapters failed: {resp.text}\"\n        rid, _ = generate_fn(case.id, client)\n        assert rid is not None\n\n        # Test 2: Different strengths\n        resp = requests.post(\n            f\"{base_url}/set_lora\",\n            json={\n                \"lora_nickname\": [\"default\", \"lora2\"],\n                \"lora_path\": [first_lora_path, second_lora_path],\n                \"target\": \"all\",\n                \"strength\": [0.8, 0.5],\n            },\n        )\n        assert (\n            resp.status_code == 200\n        ), f\"set_lora with different strengths failed: {resp.text}\"\n        rid, _ = generate_fn(case.id, client)\n        assert rid is not None\n\n        # Test 3: Different targets\n        requests.post(f\"{base_url}/set_lora\", json={\"lora_nickname\": \"default\"})\n        resp = requests.post(\n            f\"{base_url}/set_lora\",\n            json={\n                \"lora_nickname\": [\"default\", \"lora2\"],\n                \"lora_path\": [first_lora_path, second_lora_path],\n                \"target\": [\"transformer\", \"transformer_2\"],\n                \"strength\": [0.8, 0.5],\n            },\n        )\n        assert (\n            resp.status_code == 200\n        ), f\"set_lora with cached adapters failed: {resp.text}\"\n        rid, _ = generate_fn(case.id, client)\n        assert rid is not None\n\n        # Test 4: Switch back to single LoRA\n        resp = requests.post(f\"{base_url}/set_lora\", json={\"lora_nickname\": \"default\"})\n        assert (\n            resp.status_code == 200\n        ), f\"set_lora back to single adapter failed: {resp.text}\"\n        rid, _ = generate_fn(case.id, 
client)\n        assert rid is not None\n\n        logger.info(\"[Multi-LoRA] All multi-LoRA tests passed for %s\", case.id)\n\n    def _test_v1_models_endpoint(\n        self, ctx: ServerContext, case: DiffusionTestCase\n    ) -> None:\n        \"\"\"\n        Test /v1/models endpoint returns OpenAI-compatible response.\n        This endpoint is required for sgl-model-gateway router compatibility.\n        \"\"\"\n        base_url = f\"http://localhost:{ctx.port}\"\n\n        # Test GET /v1/models\n        logger.info(\"[Models API] Testing GET /v1/models for %s\", case.id)\n        resp = requests.get(f\"{base_url}/v1/models\")\n        assert resp.status_code == 200, f\"/v1/models failed: {resp.text}\"\n\n        data = resp.json()\n        assert (\n            data[\"object\"] == \"list\"\n        ), f\"Expected object='list', got {data.get('object')}\"\n        assert len(data[\"data\"]) >= 1, \"Expected at least one model in response\"\n\n        model = data[\"data\"][0]\n        assert \"id\" in model, \"Model missing 'id' field\"\n        assert (\n            model[\"object\"] == \"model\"\n        ), f\"Expected object='model', got {model.get('object')}\"\n        assert (\n            model[\"id\"] == case.server_args.model_path\n        ), f\"Model ID mismatch: expected {case.server_args.model_path}, got {model['id']}\"\n\n        # Verify extended diffusion-specific fields\n        assert \"num_gpus\" in model, \"Model missing 'num_gpus' field\"\n        assert \"task_type\" in model, \"Model missing 'task_type' field\"\n        assert \"dit_precision\" in model, \"Model missing 'dit_precision' field\"\n        assert \"vae_precision\" in model, \"Model missing 'vae_precision' field\"\n        assert (\n            model[\"num_gpus\"] == case.server_args.num_gpus\n        ), f\"num_gpus mismatch: expected {case.server_args.num_gpus}, got {model['num_gpus']}\"\n        # Verify task_type is consistent with the modality specified in the test config.\n  
      # We can't access pipeline_config from test config, but we can validate against modality.\n        modality_to_valid_task_types = {\n            \"image\": {\"T2I\", \"I2I\", \"TI2I\"},\n            \"video\": {\"T2V\", \"I2V\", \"TI2V\"},\n            \"3d\": {\"I2M\"},\n        }\n        valid_task_types = modality_to_valid_task_types.get(\n            case.server_args.modality, set()\n        )\n        assert model[\"task_type\"] in valid_task_types, (\n            f\"task_type '{model['task_type']}' not valid for modality \"\n            f\"'{case.server_args.modality}'. Expected one of: {valid_task_types}\"\n        )\n        logger.info(\n            \"[Models API] GET /v1/models returned valid response with extended fields\"\n        )\n\n        # Test GET /v1/models/{model_path}\n        model_path = model[\"id\"]\n        logger.info(\"[Models API] Testing GET /v1/models/%s\", model_path)\n        resp = requests.get(f\"{base_url}/v1/models/{model_path}\")\n        assert resp.status_code == 200, f\"/v1/models/{model_path} failed: {resp.text}\"\n\n        single_model = resp.json()\n        assert single_model[\"id\"] == model_path, \"Single model ID mismatch\"\n        assert single_model[\"object\"] == \"model\", \"Single model object type mismatch\"\n\n        # Verify extended fields on single model endpoint too\n        assert \"num_gpus\" in single_model, \"Single model missing 'num_gpus' field\"\n        assert \"task_type\" in single_model, \"Single model missing 'task_type' field\"\n        assert single_model[\"task_type\"] in valid_task_types, (\n            f\"Single model task_type '{single_model['task_type']}' not valid for modality \"\n            f\"'{case.server_args.modality}'. 
Expected one of: {valid_task_types}\"\n        )\n        logger.info(\n            \"[Models API] GET /v1/models/%s returned valid response with extended fields\",\n            model_path,\n        )\n\n        # Test GET /v1/models/{non_existent_model} returns 404\n        logger.info(\"[Models API] Testing GET /v1/models/non_existent_model\")\n        resp = requests.get(f\"{base_url}/v1/models/non_existent_model\")\n        assert resp.status_code == 404, f\"Expected 404, got {resp.status_code}\"\n        error_data = resp.json()\n        assert \"error\" in error_data, \"404 response missing 'error' field\"\n        assert (\n            error_data[\"error\"][\"code\"] == \"model_not_found\"\n        ), f\"Incorrect error code: {error_data['error'].get('code')}\"\n        logger.info(\"[Models API] GET /v1/models/non_existent returns 404 as expected\")\n\n        logger.info(\"[Models API] All /v1/models tests passed for %s\", case.id)\n\n    def _test_t2v_rejects_input_reference(\n        self, ctx: ServerContext, case: DiffusionTestCase\n    ) -> None:\n        if case.server_args.modality != \"video\":\n            return\n\n        base_url = f\"http://localhost:{ctx.port}\"\n        resp = requests.get(f\"{base_url}/v1/models\")\n        assert resp.status_code == 200, f\"/v1/models failed: {resp.text}\"\n        data = resp.json().get(\"data\", [])\n        if not data:\n            pytest.fail(\"/v1/models returned empty model list\")\n\n        task_type = data[0].get(\"task_type\")\n        if task_type != \"T2V\":\n            return\n\n        prompt = case.sampling_params.prompt or \"test\"\n        payload = {\"prompt\": prompt, \"input_reference\": \"dummy\"}\n        if case.sampling_params.output_size:\n            payload[\"size\"] = case.sampling_params.output_size\n\n        resp = requests.post(f\"{base_url}/v1/videos\", json=payload)\n        assert (\n            resp.status_code == 400\n        ), f\"Expected 400 for T2V input_reference, got 
{resp.status_code}: {resp.text}\"\n        detail = resp.json().get(\"detail\", \"\")\n        assert (\n            \"input_reference is not supported\" in detail\n        ), f\"Unexpected error detail for T2V input_reference: {detail}\"\n\n    def test_diffusion_generation(\n        self,\n        case: DiffusionTestCase,\n        diffusion_server: ServerContext,\n    ):\n        \"\"\"Single parametrized test that runs for all cases.\n\n        This test performs:\n        1. Generation\n        2. Performance validation against baselines\n        3. Consistency validation against ground truth\n\n        Pytest will execute this test once per case in ONE_GPU_CASES,\n        with test IDs like:\n        - test_diffusion_generation[qwen_image_text]\n        - test_diffusion_generation[qwen_image_edit]\n        - etc.\n        \"\"\"\n        # Check if we're in GT generation mode\n        is_gt_gen_mode = os.environ.get(\"SGLANG_GEN_GT\", \"0\") == \"1\"\n\n        # Dynamic LoRA loading test - tests LayerwiseOffload + set_lora interaction\n        # Server starts WITHOUT lora_path, then set_lora is called after startup\n        if case.server_args.dynamic_lora_path and not is_gt_gen_mode:\n            self._test_dynamic_lora_loading(diffusion_server, case)\n\n        generate_fn = get_generate_fn(\n            model_path=case.server_args.model_path,\n            modality=case.server_args.modality,\n            sampling_params=case.sampling_params,\n        )\n\n        # Single generation - output is reused for both validations\n        perf_record, content = self.run_and_collect(\n            diffusion_server,\n            case.id,\n            generate_fn,\n        )\n\n        if is_gt_gen_mode:\n            # GT generation mode: save output and skip all validations/tests\n            self._save_gt_output(case, content)\n            return\n\n        # Validation 1: Performance\n        self._validate_and_record(case, perf_record)\n\n        # Mesh correctness 
check (Chamfer Distance) for 3D models\n        if case.server_args.custom_validator == \"mesh\":\n            from sglang.multimodal_gen.test.server.test_server_utils import (\n                MESH_OUTPUT_PATHS,\n                validate_mesh_correctness,\n            )\n\n            mesh_path = MESH_OUTPUT_PATHS.pop(case.id, None)\n            if mesh_path:\n                validate_mesh_correctness(mesh_path)\n\n        # Test /v1/models endpoint for router compatibility\n        self._test_v1_models_endpoint(diffusion_server, case)\n        self._test_t2v_rejects_input_reference(diffusion_server, case)\n\n        # LoRA API functionality test with E2E validation (only for LoRA-enabled cases)\n        if case.server_args.lora_path or case.server_args.dynamic_lora_path:\n            self._test_lora_api_functionality(diffusion_server, case, generate_fn)\n\n            # Test dynamic LoRA switching (requires a second LoRA adapter)\n            if case.server_args.second_lora_path:\n                self._test_lora_dynamic_switch_e2e(\n                    diffusion_server,\n                    case,\n                    generate_fn,\n                    case.server_args.second_lora_path,\n                )\n\n                # Test multi-LoRA functionality\n                self._test_multi_lora_e2e(\n                    diffusion_server,\n                    case,\n                    generate_fn,\n                    case.server_args.lora_path,\n                    case.server_args.second_lora_path,\n                )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/test_server_utils.py",
    "content": "\"\"\"\nServer management and performance validation for diffusion tests.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport base64\nimport os\nimport shlex\nimport subprocess\nimport sys\nimport tempfile\nimport threading\nimport time\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Any, Callable, Sequence\nfrom urllib.request import urlopen\n\nimport pytest\nfrom openai import Client\n\nfrom sglang.multimodal_gen.benchmarks.compare_perf import calculate_upper_bound\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.common import kill_process_tree\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import (\n    globally_suppress_loggers,\n    init_logger,\n)\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import RequestPerfRecord\nfrom sglang.multimodal_gen.test.server.testcase_configs import (\n    DiffusionSamplingParams,\n    PerformanceSummary,\n    ScenarioConfig,\n    ToleranceConfig,\n)\nfrom sglang.multimodal_gen.test.slack_utils import upload_file_to_slack\nfrom sglang.multimodal_gen.test.test_utils import (\n    get_expected_image_format,\n    get_video_frame_count,\n    is_image_url,\n    prepare_perf_log,\n    validate_image,\n    validate_image_file,\n    validate_openai_video,\n    validate_video_file,\n)\n\nlogger = init_logger(__name__)\n\nglobally_suppress_loggers()\n\n# Tracks mesh output file paths from generate_mesh for later correctness validation.\n# Keyed by case_id, cleaned up after use.\nMESH_OUTPUT_PATHS: dict[str, str] = {}\n\n\ndef download_image_from_url(url: str) -> Path:\n    \"\"\"Download an image from a URL to a temporary file.\n\n    Args:\n        url: The URL of the image to download\n\n    Returns:\n        Path to the downloaded temporary file\n    \"\"\"\n    logger.info(f\"Downloading image from URL: {url}\")\n\n    # Determine file extension from URL\n    ext = \".jpg\"  # default\n    if 
url.lower().endswith((\".png\", \".jpeg\", \".jpg\", \".webp\", \".gif\")):\n        ext = url[url.rfind(\".\") :]\n\n    # Create temporary file\n    temp_file = (\n        Path(tempfile.gettempdir()) / f\"diffusion_test_image_{int(time.time())}{ext}\"\n    )\n\n    try:\n        with urlopen(url, timeout=30) as response:\n            temp_file.write_bytes(response.read())\n        logger.info(f\"Downloaded image to: {temp_file}\")\n        return temp_file\n    except Exception as e:\n        logger.error(f\"Failed to download image from {url}: {e}\")\n        raise\n\n\ndef parse_dimensions(size_string: str | None) -> tuple[int | None, int | None]:\n    \"\"\"Parse a size string in \"widthxheight\" format to (width, height) tuple.\n\n    Args:\n        size_string: Size string in \"widthxheight\" format (e.g., \"1024x1024\") or None.\n                    Spaces are automatically stripped.\n\n    Returns:\n        Tuple of (width, height) as integers if parsing succeeds, (None, None) otherwise.\n    \"\"\"\n    if not size_string:\n        return (None, None)\n\n    # Strip spaces from the entire string\n    size_string = size_string.strip()\n    if not size_string:\n        return (None, None)\n\n    # Split by \"x\"\n    parts = size_string.split(\"x\")\n    if len(parts) != 2:\n        return (None, None)\n\n    # Strip spaces from each part and try to convert to int\n    try:\n        width_str = parts[0].strip()\n        height_str = parts[1].strip()\n\n        if not width_str or not height_str:\n            return (None, None)\n\n        width = int(width_str)\n        height = int(height_str)\n\n        # Validate that both are positive\n        if width <= 0 or height <= 0:\n            return (None, None)\n\n        return (width, height)\n    except ValueError:\n        return (None, None)\n\n\n@dataclass\nclass ServerContext:\n    \"\"\"Context for a running diffusion server.\"\"\"\n\n    port: int\n    process: subprocess.Popen\n    model: str\n    
stdout_file: Path\n    perf_log_path: Path\n    log_dir: Path\n    _stdout_fh: Any = field(repr=False)\n    _log_thread: threading.Thread | None = field(default=None, repr=False)\n\n    def cleanup(self) -> None:\n        \"\"\"Clean up server resources.\"\"\"\n        try:\n            kill_process_tree(self.process.pid)\n        except Exception:\n            pass\n        try:\n            self._stdout_fh.flush()\n            self._stdout_fh.close()\n        except Exception:\n            pass\n\n        # ROCm/AMD: Extra cleanup to ensure GPU memory is released between tests\n        # This is needed because ROCm memory release can be slower than CUDA\n        if current_platform.is_hip():\n            self._cleanup_rocm_gpu_memory()\n            # Clean up downloaded models if HF cache is not persistent\n            # This prevents disk exhaustion in CI when cache is not mounted\n            self._cleanup_hf_cache_if_not_persistent()\n\n    def _cleanup_hf_cache_if_not_persistent(self) -> None:\n        \"\"\"Clean up HF cache if it's not on a persistent volume.\n\n        When running in CI without persistent cache, downloaded models accumulate\n        and can cause disk/memory exhaustion. 
This cleans up the model after each\n        test if the cache is not persistent.\n        \"\"\"\n        import shutil\n\n        hf_home = os.environ.get(\"HF_HOME\", \"\")\n        if not hf_home:\n            return\n\n        hf_hub_cache = os.path.join(hf_home, \"hub\")\n\n        # Treat the HF cache as persistent only if a marker file is present\n        persistent_marker = os.path.join(hf_home, \".persistent_cache\")\n        if os.path.exists(persistent_marker):\n            logger.info(\"HF cache is persistent, skipping cleanup\")\n            return\n\n        # Nothing to clean up if the hub cache directory does not exist\n        if not os.path.exists(hf_hub_cache):\n            return\n\n        try:\n            # Get model cache directories\n            model_dirs = [\n                d\n                for d in os.listdir(hf_hub_cache)\n                if d.startswith(\"models--\")\n                and os.path.isdir(os.path.join(hf_hub_cache, d))\n            ]\n\n            # If there are cached models but no persistent marker, clean up\n            # to prevent disk exhaustion in CI\n            if model_dirs:\n                logger.info(\n                    \"HF cache appears non-persistent (no .persistent_cache marker), \"\n                    \"cleaning up %d model(s) to prevent disk exhaustion\",\n                    len(model_dirs),\n                )\n                for model_dir in model_dirs:\n                    model_path = os.path.join(hf_hub_cache, model_dir)\n                    try:\n                        shutil.rmtree(model_path)\n                        logger.info(\"Cleaned up model cache: %s\", model_dir)\n                    except Exception as e:\n                        logger.warning(\"Failed to clean up %s: %s\", model_dir, e)\n        except Exception as e:\n            
logger.warning(\"Error during HF cache cleanup: %s\", e)\n\n    def _cleanup_rocm_gpu_memory(self) -> None:\n        \"\"\"ROCm-specific cleanup to ensure GPU memory is fully released.\"\"\"\n        import gc\n\n        # Wait for process to fully terminate\n        try:\n            self.process.wait(timeout=30)\n        except Exception:\n            pass\n\n        # Force garbage collection multiple times\n        for _ in range(3):\n            gc.collect()\n\n        # Clear HIP memory on all GPUs\n        try:\n            import torch\n\n            for i in range(torch.cuda.device_count()):\n                with torch.cuda.device(i):\n                    torch.cuda.empty_cache()\n                    torch.cuda.synchronize()\n        except Exception:\n            pass\n\n        # Wait for GPU memory to be released (ROCm can be much slower than CUDA)\n        # The GPU driver needs time to reclaim memory from killed processes\n        time.sleep(15)\n\n\nclass ServerManager:\n    \"\"\"Manages diffusion server lifecycle.\"\"\"\n\n    def __init__(\n        self,\n        model: str,\n        port: int,\n        wait_deadline: float = 1200.0,\n        extra_args: str = \"\",\n        env_vars: dict[str, str] | None = None,\n    ):\n        self.model = model\n        self.port = port\n        self.wait_deadline = wait_deadline\n        self.extra_args = extra_args\n        self.env_vars = env_vars or {}\n\n    def _wait_for_rocm_gpu_memory_clear(self, max_wait: float = 60.0) -> None:\n        \"\"\"ROCm-specific: Wait for GPU memory to be mostly free before starting.\n\n        ROCm GPU memory release from killed processes can be significantly slower\n        than CUDA, so we need to wait longer and be more patient.\n        \"\"\"\n        try:\n            import torch\n\n            if not torch.cuda.is_available():\n                return\n\n            start_time = time.time()\n            last_total_used = float(\"inf\")\n\n            while 
time.time() - start_time < max_wait:\n                # Check GPU memory usage\n                total_used = 0\n                for i in range(torch.cuda.device_count()):\n                    mem_info = torch.cuda.mem_get_info(i)\n                    free, total = mem_info\n                    used = total - free\n                    total_used += used\n\n                # If less than 5GB is used across all GPUs, we're good\n                if total_used < 5 * 1024 * 1024 * 1024:  # 5GB\n                    logger.info(\n                        \"[server-test] ROCm GPU memory is clear (used: %.2f GB)\",\n                        total_used / (1024**3),\n                    )\n                    return\n\n                # Log progress\n                elapsed = int(time.time() - start_time)\n                if total_used < last_total_used:\n                    logger.info(\n                        \"[server-test] ROCm: GPU memory clearing (used: %.2f GB, elapsed: %ds)\",\n                        total_used / (1024**3),\n                        elapsed,\n                    )\n                else:\n                    logger.info(\n                        \"[server-test] ROCm: Waiting for GPU memory (used: %.2f GB, elapsed: %ds)\",\n                        total_used / (1024**3),\n                        elapsed,\n                    )\n                last_total_used = total_used\n                time.sleep(3)\n\n            # Final warning with detailed GPU info\n            logger.warning(\n                \"[server-test] ROCm GPU memory not fully cleared after %.0fs (used: %.2f GB). 
\"\n                \"Proceeding anyway - this may cause OOM.\",\n                max_wait,\n                total_used / (1024**3),\n            )\n        except Exception as e:\n            logger.debug(\"[server-test] Could not check ROCm GPU memory: %s\", e)\n\n    def start(self) -> ServerContext:\n        \"\"\"Start the diffusion server and wait for readiness.\"\"\"\n        # ROCm/AMD: Wait for GPU memory to be clear before starting\n        # This prevents OOM when running sequential tests on ROCm\n        if current_platform.is_hip():\n            self._wait_for_rocm_gpu_memory_clear()\n\n        log_dir, perf_log_path = prepare_perf_log()\n\n        safe_model_name = self.model.replace(\"/\", \"_\")\n        stdout_path = (\n            Path(tempfile.gettempdir())\n            / f\"sgl_server_{self.port}_{safe_model_name}.log\"\n        )\n        stdout_path.unlink(missing_ok=True)\n\n        command = [\n            \"sglang\",\n            \"serve\",\n            \"--model-path\",\n            self.model,\n            \"--port\",\n            str(self.port),\n            \"--log-level=debug\",\n        ]\n        if self.extra_args.strip():\n            command.extend(self.extra_args.strip().split())\n\n        env = os.environ.copy()\n        env[\"SGLANG_DIFFUSION_STAGE_LOGGING\"] = \"1\"\n        env[\"SGLANG_PERF_LOG_DIR\"] = log_dir.as_posix()\n\n        # Apply custom environment variables\n        env.update(self.env_vars)\n\n        # TODO: unify with run_command\n        logger.info(f\"Running command: {shlex.join(command)}\")\n\n        process = subprocess.Popen(\n            command,\n            stdout=subprocess.PIPE,\n            stderr=subprocess.STDOUT,\n            text=True,\n            bufsize=1,\n            env=env,\n        )\n\n        log_thread = None\n        stdout_fh = stdout_path.open(\"w\", encoding=\"utf-8\", buffering=1)\n        if process.stdout:\n\n            def _log_pipe(pipe: Any, file: Any) -> None:\n         
       \"\"\"Read from pipe and write to file and stdout.\"\"\"\n                try:\n                    with pipe:\n                        for line in iter(pipe.readline, \"\"):\n                            sys.stdout.write(line)\n                            sys.stdout.flush()\n                            file.write(line)\n                            file.flush()\n                except Exception as e:\n                    logger.error(\"Log pipe thread error: %s\", e)\n                finally:\n                    file.close()\n                    logger.debug(\"Log pipe thread finished.\")\n\n            log_thread = threading.Thread(\n                target=_log_pipe, args=(process.stdout, stdout_fh)\n            )\n            log_thread.daemon = True\n            log_thread.start()\n\n        logger.info(\n            \"[server-test] Starting server pid=%s, model=%s, log=%s\",\n            process.pid,\n            self.model,\n            stdout_path,\n        )\n\n        self._wait_for_ready(process, stdout_path)\n\n        return ServerContext(\n            port=self.port,\n            process=process,\n            model=self.model,\n            stdout_file=stdout_path,\n            perf_log_path=perf_log_path,\n            log_dir=log_dir,\n            _stdout_fh=stdout_fh,\n            _log_thread=log_thread,\n        )\n\n    def _wait_for_ready(self, process: subprocess.Popen, stdout_path: Path) -> None:\n        \"\"\"Wait for server to become ready.\"\"\"\n        start = time.time()\n        ready_message = \"Application startup complete.\"\n        log_period = 30\n        prev_log_period_count = 0\n\n        while time.time() - start < self.wait_deadline:\n            if process.poll() is not None:\n                tail = self._get_log_tail(stdout_path)\n                raise RuntimeError(\n                    f\"Server exited early (code {process.returncode}).\\n{tail}\"\n                )\n\n            if stdout_path.exists():\n             
   try:\n                    content = stdout_path.read_text(encoding=\"utf-8\", errors=\"ignore\")\n                    if ready_message in content:\n                        logger.info(\"[server-test] Server ready\")\n                        return\n                except Exception as e:\n                    logger.debug(\"Could not read log yet: %s\", e)\n\n            elapsed = int(time.time() - start)\n            if (elapsed // log_period) > prev_log_period_count:\n                prev_log_period_count = elapsed // log_period\n                logger.info(\"[server-test] Waiting for server... elapsed=%ss\", elapsed)\n            time.sleep(1)\n\n        tail = self._get_log_tail(stdout_path)\n        raise TimeoutError(f\"Server not ready within {self.wait_deadline}s.\\n{tail}\")\n\n    @staticmethod\n    def _get_log_tail(path: Path, lines: int = 200) -> str:\n        \"\"\"Get the last N lines from a log file.\"\"\"\n        try:\n            content = path.read_text(encoding=\"utf-8\", errors=\"ignore\")\n            return \"\\n\".join(content.splitlines()[-lines:])\n        except Exception:\n            return \"\"\n\n\nclass PerformanceValidator:\n    \"\"\"Validates performance metrics against expectations.\"\"\"\n\n    is_video_gen: bool = False\n\n    def __init__(\n        self,\n        scenario: ScenarioConfig,\n        tolerances: ToleranceConfig,\n        step_fractions: Sequence[float],\n    ):\n        self.scenario = scenario\n        self.tolerances = tolerances\n        self.step_fractions = step_fractions\n        self.is_baseline_generation_mode = (\n            os.environ.get(\"SGLANG_GEN_BASELINE\", \"0\") == \"1\"\n        )\n\n    def _assert_le(\n        self,\n        name: str,\n        actual: float,\n        expected: float,\n        tolerance: float,\n        min_abs_tolerance_ms: float = 20.0,\n    ):\n        \"\"\"Assert that actual is less than or equal to expected within a tolerance.\n\n        Uses the larger of relative 
tolerance or absolute tolerance to prevent\n        flaky failures on very fast operations.\n\n        For AMD GPUs, uses 100% higher tolerance and issues warning instead of assertion.\n        \"\"\"\n        # Check if running on AMD GPU\n        is_amd = current_platform.is_hip()\n\n        if is_amd:\n            # Use 100% higher tolerance for AMD (2x the expected value)\n            amd_tolerance = 1.0  # 100%\n            upper_bound = calculate_upper_bound(\n                expected, amd_tolerance, min_abs_tolerance_ms\n            )\n            if actual > upper_bound:\n                logger.warning(\n                    f\"[AMD PERF WARNING] Validation would fail for '{name}'.\\n\"\n                    f\"  Actual:   {actual:.4f}ms\\n\"\n                    f\"  Expected: {expected:.4f}ms\\n\"\n                    f\"  AMD Limit: {upper_bound:.4f}ms \"\n                    f\"(rel_tol: {amd_tolerance:.1%}, abs_pad: {min_abs_tolerance_ms}ms)\\n\"\n                    f\"  Original tolerance was: {tolerance:.1%}\"\n                )\n        else:\n            upper_bound = calculate_upper_bound(\n                expected, tolerance, min_abs_tolerance_ms\n            )\n            assert actual <= upper_bound, (\n                f\"Validation failed for '{name}'.\\n\"\n                f\"  Actual:   {actual:.4f}ms\\n\"\n                f\"  Expected: {expected:.4f}ms\\n\"\n                f\"  Limit:    {upper_bound:.4f}ms \"\n                f\"(rel_tol: {tolerance:.1%}, abs_pad: {min_abs_tolerance_ms}ms)\"\n            )\n\n    def validate(\n        self, perf_record: RequestPerfRecord, *args, **kwargs\n    ) -> PerformanceSummary:\n        \"\"\"Validate all performance metrics and return summary.\"\"\"\n        summary = self.collect_metrics(perf_record)\n        if self.is_baseline_generation_mode:\n            return summary\n\n        self._validate_e2e(summary)\n        self._validate_denoise_agg(summary)\n        
self._validate_denoise_steps(summary)\n        self._validate_stages(summary)\n\n        return summary\n\n    def collect_metrics(\n        self,\n        perf_record: RequestPerfRecord,\n    ) -> PerformanceSummary:\n        return PerformanceSummary.from_req_perf_record(perf_record, self.step_fractions)\n\n    def _validate_e2e(self, summary: PerformanceSummary) -> None:\n        \"\"\"Validate end-to-end performance.\"\"\"\n        assert summary.e2e_ms > 0, \"E2E duration missing\"\n        self._assert_le(\n            \"E2E Latency\",\n            summary.e2e_ms,\n            self.scenario.expected_e2e_ms,\n            self.tolerances.e2e,\n        )\n\n    def _validate_denoise_agg(self, summary: PerformanceSummary) -> None:\n        \"\"\"Validate aggregate denoising metrics.\"\"\"\n        assert summary.avg_denoise_ms > 0, \"Denoising step timings missing\"\n\n        self._assert_le(\n            \"Average Denoise Step\",\n            summary.avg_denoise_ms,\n            self.scenario.expected_avg_denoise_ms,\n            self.tolerances.denoise_agg,\n        )\n        self._assert_le(\n            \"Median Denoise Step\",\n            summary.median_denoise_ms,\n            self.scenario.expected_median_denoise_ms,\n            self.tolerances.denoise_agg,\n        )\n\n    def _validate_denoise_steps(self, summary: PerformanceSummary) -> None:\n        \"\"\"Validate individual denoising steps.\"\"\"\n        for idx, actual in summary.sampled_steps.items():\n            expected = self.scenario.denoise_step_ms.get(idx)\n            if expected is None:\n                continue\n            # FIXME: hardcode, looser for first step\n            tolerance = 0.4 if idx == 0 else self.tolerances.denoise_step\n\n            self._assert_le(\n                f\"Denoise Step {idx}\",\n                actual,\n                expected,\n                tolerance,\n            )\n\n    def _validate_stages(self, summary: PerformanceSummary) -> None:\n        
\"\"\"Validate stage-level metrics.\"\"\"\n        assert summary.stage_metrics, \"Stage metrics missing\"\n\n        for stage, expected in self.scenario.stages_ms.items():\n            if stage == \"per_frame_generation\" and self.is_video_gen:\n                continue\n            actual = summary.stage_metrics.get(stage)\n            assert actual is not None, f\"Stage {stage} timing missing\"\n            tolerance = (\n                self.tolerances.denoise_stage\n                if stage == \"DenoisingStage\"\n                else self.tolerances.non_denoise_stage\n            )\n            self._assert_le(\n                f\"Stage '{stage}'\",\n                actual,\n                expected,\n                tolerance,\n                min_abs_tolerance_ms=120.0,  # relax absolute tolerance for non-denoising stages\n            )\n\n\nclass VideoPerformanceValidator(PerformanceValidator):\n    \"\"\"Extended validator for video diffusion with frame-level metrics.\"\"\"\n\n    is_video_gen = True\n\n    def validate(\n        self,\n        perf_record: RequestPerfRecord,\n        num_frames: int | None = None,\n    ) -> PerformanceSummary:\n        \"\"\"Validate video metrics including frame generation rates.\"\"\"\n        summary = super().validate(perf_record)\n\n        if num_frames and summary.e2e_ms > 0:\n            summary.total_frames = num_frames\n            summary.avg_frame_time_ms = summary.e2e_ms / num_frames\n            summary.frames_per_second = 1000.0 / summary.avg_frame_time_ms\n\n            if not self.is_baseline_generation_mode:\n                self._validate_frame_rate(summary)\n\n        return summary\n\n    def _validate_frame_rate(self, summary: PerformanceSummary) -> None:\n        \"\"\"Validate frame generation performance.\"\"\"\n        expected_frame_time = self.scenario.stages_ms.get(\"per_frame_generation\")\n        if expected_frame_time and summary.avg_frame_time_ms:\n            self._assert_le(\n          
      \"Average Frame Time\",\n                summary.avg_frame_time_ms,\n                expected_frame_time,\n                self.tolerances.denoise_stage,\n            )\n\n\nclass MeshValidator(PerformanceValidator):\n    \"\"\"Validator for 3D mesh generation. Inherits perf validation from PerformanceValidator.\"\"\"\n\n    pass\n\n\nHUNYUAN3D_REFERENCE_URL = (\n    \"https://raw.githubusercontent.com/sgl-project/sgl-test-files/\"\n    \"main/diffusion-ci/consistency_gt/1-gpu/hunyuan3d_2_0/hunyuan3d.glb\"\n)\n\n\ndef _download_reference_mesh(url: str) -> Path:\n    \"\"\"Download a reference mesh from URL, caching in temp dir.\"\"\"\n    import hashlib\n\n    cache_name = f\"ref_mesh_{hashlib.md5(url.encode()).hexdigest()}.glb\"\n    cache_path = Path(tempfile.gettempdir()) / cache_name\n    if cache_path.exists():\n        logger.info(f\"Using cached reference mesh: {cache_path}\")\n        return cache_path\n\n    logger.info(f\"Downloading reference mesh from: {url}\")\n    with urlopen(url, timeout=60) as resp:\n        cache_path.write_bytes(resp.read())\n    logger.info(f\"Reference mesh cached at: {cache_path}\")\n    return cache_path\n\n\ndef validate_mesh_correctness(\n    generated_mesh_path: str,\n    reference_url: str = HUNYUAN3D_REFERENCE_URL,\n    num_sample_points: int = 4096,\n    cd_threshold_ratio: float = 0.01,\n    random_seed: int = 42,\n):\n    \"\"\"Validate mesh geometric similarity against a reference via Chamfer Distance.\n\n    Downloads the reference mesh from a URL (cached), samples point clouds from\n    both meshes, and asserts Chamfer Distance is within threshold.\n    \"\"\"\n    import numpy as np\n\n    try:\n        import trimesh\n    except ImportError:\n        pytest.fail(\"trimesh is required for mesh validation: pip install trimesh\")\n\n    from scipy.spatial import cKDTree\n\n    # Load generated mesh\n    generated_mesh = trimesh.load(generated_mesh_path)\n    if isinstance(generated_mesh, trimesh.Scene):\n      
  generated_mesh = generated_mesh.dump(concatenate=True)\n\n    # Download and load reference mesh\n    ref_path = _download_reference_mesh(reference_url)\n    reference_mesh = trimesh.load(str(ref_path))\n    if isinstance(reference_mesh, trimesh.Scene):\n        reference_mesh = reference_mesh.dump(concatenate=True)\n\n    # Bounding box diagonal for threshold normalization\n    ref_bbox = reference_mesh.bounding_box.bounds\n    bbox_diagonal = float(np.linalg.norm(ref_bbox[1] - ref_bbox[0]))\n    cd_threshold = cd_threshold_ratio * bbox_diagonal\n\n    # Sample point clouds\n    np.random.seed(random_seed)\n    gen_points = np.array(\n        generated_mesh.sample(num_sample_points, return_index=True)[0]\n    )\n    ref_points = np.array(\n        reference_mesh.sample(num_sample_points, return_index=True)[0]\n    )\n\n    # Bidirectional Chamfer Distance\n    tree1 = cKDTree(gen_points)\n    tree2 = cKDTree(ref_points)\n    forward_cd = float(np.mean(tree2.query(gen_points)[0] ** 2))\n    backward_cd = float(np.mean(tree1.query(ref_points)[0] ** 2))\n    total_cd = forward_cd + backward_cd\n\n    assert total_cd <= cd_threshold, (\n        f\"Chamfer Distance check failed: total_cd={total_cd:.6f}, \"\n        f\"threshold={cd_threshold:.6f} ({cd_threshold_ratio * 100:.2f}% of bbox diagonal {bbox_diagonal:.4f})\"\n    )\n\n\n# Registry of validators by name\nVALIDATOR_REGISTRY = {\n    \"default\": PerformanceValidator,\n    \"video\": VideoPerformanceValidator,\n    \"mesh\": MeshValidator,\n}\n\n\ndef get_generate_fn(\n    model_path: str,\n    modality: str,\n    sampling_params: DiffusionSamplingParams,\n) -> Callable[[str, Client], tuple[str, bytes]]:\n    \"\"\"Return appropriate generation function for the case.\"\"\"\n    # Allow override via environment variable (useful for AMD where large resolutions cause slow VAE)\n    output_size = os.environ.get(\"SGLANG_TEST_OUTPUT_SIZE\", sampling_params.output_size)\n    n = 
sampling_params.num_outputs_per_prompt\n\n    def _create_and_download_video(\n        client,\n        case_id,\n        *,\n        model: str,\n        size: str,\n        prompt: str | None = None,\n        seconds: int | None = None,\n        input_reference: Any | None = None,\n        extra_body: dict[str, Any] | None = None,\n        expected_frame_count: int | None = None,\n    ) -> tuple[str, bytes]:\n        \"\"\"\n        Create a video job via /v1/videos, poll until completion,\n        then download the binary content and validate it.\n\n        Returns a (request-id, video bytes) tuple\n        \"\"\"\n\n        create_kwargs: dict[str, Any] = {\n            \"model\": model,\n            \"size\": size,\n        }\n        if prompt is not None:\n            create_kwargs[\"prompt\"] = prompt\n        if seconds is not None:\n            create_kwargs[\"seconds\"] = seconds\n        if input_reference is not None:\n            create_kwargs[\"input_reference\"] = input_reference  # triggers multipart\n        if extra_body is not None:\n            create_kwargs[\"extra_body\"] = extra_body\n\n        job = client.videos.create(**create_kwargs)  # type: ignore[attr-defined]\n        video_id = job.id\n\n        job_completed = False\n        is_baseline_generation_mode = os.environ.get(\"SGLANG_GEN_BASELINE\", \"0\") == \"1\"\n        # Check if running on AMD GPU - use longer timeout\n        is_amd = current_platform.is_hip()\n        if is_baseline_generation_mode:\n            timeout = 3600.0\n        elif is_amd:\n            timeout = 2400.0  # 40 minutes for AMD\n        else:\n            timeout = 1200.0\n        deadline = time.time() + timeout\n        while True:\n            page = client.videos.list()  # type: ignore[attr-defined]\n            item = next((v for v in page.data if v.id == video_id), None)\n\n            if item and getattr(item, \"status\", None) == \"completed\":\n                job_completed = True\n                break\n\n            if time.time() > 
deadline:\n                break\n\n            time.sleep(1)\n\n        if not job_completed:\n            if is_baseline_generation_mode:\n                logger.warning(\n                    f\"{case_id}: video job {video_id} timed out during baseline generation. \"\n                    \"Attempting to collect performance data anyway.\"\n                )\n                return (video_id, b\"\")\n\n            if is_amd:\n                logger.warning(\n                    f\"[AMD TIMEOUT WARNING] {case_id}: video job {video_id} did not complete \"\n                    f\"within {timeout}s timeout. This may indicate performance issues on AMD.\"\n                )\n                pytest.skip(\n                    f\"{case_id}: video job timed out on AMD after {timeout}s - skipping\"\n                )\n\n            pytest.fail(f\"{case_id}: video job {video_id} did not complete in time\")\n\n        # download video\n        resp = client.videos.download_content(video_id=video_id)  # type: ignore[attr-defined]\n        content = resp.read()\n        validate_openai_video(content)\n\n        expected_filename = f\"{video_id}.mp4\"\n        tmp_path = expected_filename\n        with open(tmp_path, \"wb\") as f:\n            f.write(content)\n\n        # Validate output file\n        expected_width, expected_height = parse_dimensions(size)\n        validate_video_file(\n            tmp_path, expected_filename, expected_width, expected_height\n        )\n\n        if expected_frame_count is not None:\n            actual_count = get_video_frame_count(tmp_path)\n            assert actual_count == expected_frame_count, (\n                f\"{case_id}: frame count mismatch after interpolation — \"\n                f\"expected {expected_frame_count}, got {actual_count}\"\n            )\n\n        upload_file_to_slack(\n            case_id=case_id,\n            model=model_path,\n            prompt=sampling_params.prompt,\n            file_path=tmp_path,\n            
origin_file_path=sampling_params.image_path,\n        )\n        os.remove(tmp_path)\n\n        return (video_id, content)\n\n    video_seconds = sampling_params.seconds or 4\n\n    def generate_image(case_id, client) -> tuple[str, bytes]:\n        \"\"\"T2I: Text to Image generation.\"\"\"\n        if not sampling_params.prompt:\n            pytest.skip(f\"{case_id}: no text prompt configured\")\n\n        # Request parameters that affect output format\n        req_output_format = None  # Not specified in current request\n        req_background = None  # Not specified in current request\n\n        # Build extra_body for optional features\n        extra_body = dict(sampling_params.extras)\n\n        response = client.images.with_raw_response.generate(\n            model=model_path,\n            prompt=sampling_params.prompt,\n            n=n,\n            size=output_size,\n            response_format=\"b64_json\",\n            extra_body=extra_body if extra_body else None,\n        )\n        result = response.parse()\n        validate_image(result.data[0].b64_json)\n\n        rid = result.id\n\n        img_data = base64.b64decode(result.data[0].b64_json)\n        # Infer expected format from request parameters\n        expected_ext = get_expected_image_format(req_output_format, req_background)\n        expected_filename = f\"{result.created}.{expected_ext}\"\n        tmp_path = expected_filename\n        with open(tmp_path, \"wb\") as f:\n            f.write(img_data)\n\n        # Validate output file\n        expected_width, expected_height = parse_dimensions(output_size)\n        if (\n            sampling_params.extras.get(\"enable_upscaling\")\n            and expected_width\n            and expected_height\n        ):\n            expected_width *= sampling_params.extras.get(\"upscaling_scale\", 4)\n            expected_height *= sampling_params.extras.get(\"upscaling_scale\", 4)\n        validate_image_file(\n            tmp_path,\n            
expected_filename,\n            expected_width,\n            expected_height,\n            output_format=req_output_format,\n            background=req_background,\n        )\n\n        upload_file_to_slack(\n            case_id=case_id,\n            model=model_path,\n            prompt=sampling_params.prompt,\n            file_path=tmp_path,\n        )\n        os.remove(tmp_path)\n\n        return (rid, img_data)\n\n    def generate_image_edit(case_id, client) -> tuple[str, bytes]:\n        \"\"\"TI2I: Text + Image -> Image edit.\"\"\"\n        if not sampling_params.prompt or not sampling_params.image_path:\n            pytest.skip(f\"{case_id}: no edit config\")\n\n        image_paths = sampling_params.image_path\n\n        if not isinstance(image_paths, list):\n            image_paths = [image_paths]\n\n        new_image_paths = []\n        for image_path in image_paths:\n            if is_image_url(image_path):\n                new_image_paths.append(download_image_from_url(str(image_path)))\n            else:\n                local_path = Path(image_path)\n                new_image_paths.append(local_path)\n                if not local_path.exists():\n                    pytest.skip(f\"{case_id}: file missing: {image_path}\")\n\n        image_paths = new_image_paths\n\n        # Request parameters that affect output format\n        req_output_format = (\n            sampling_params.output_format\n        )  # Not specified in current request\n        req_background = None  # Not specified in current request\n\n        # Build extra_body for optional features\n        extra_body = {\"num_frames\": sampling_params.num_frames}\n        extra_body.update(sampling_params.extras)\n\n        images = [open(image_path, \"rb\") for image_path in image_paths]\n        try:\n            response = client.images.with_raw_response.edit(\n                model=model_path,\n                image=images,\n                prompt=sampling_params.prompt,\n                
n=n,\n                size=output_size,\n                response_format=\"b64_json\",\n                output_format=req_output_format,\n                extra_body=extra_body,\n            )\n        finally:\n            for img in images:\n                img.close()\n\n        result = response.parse()\n        validate_image(result.data[0].b64_json)\n\n        img_data = base64.b64decode(result.data[0].b64_json)\n        rid = result.id\n\n        # Infer expected format from request parameters\n        expected_ext = get_expected_image_format(req_output_format, req_background)\n        expected_filename = f\"{rid}.{expected_ext}\"\n        tmp_path = expected_filename\n        with open(tmp_path, \"wb\") as f:\n            f.write(img_data)\n\n        # Validate output file\n        expected_width, expected_height = parse_dimensions(output_size)\n        validate_image_file(\n            tmp_path,\n            expected_filename,\n            expected_width,\n            expected_height,\n            output_format=req_output_format,\n            background=req_background,\n        )\n\n        upload_file_to_slack(\n            case_id=case_id,\n            model=model_path,\n            prompt=sampling_params.prompt,\n            file_path=tmp_path,\n            origin_file_path=sampling_params.image_path,\n        )\n        os.remove(tmp_path)\n\n        return (rid, img_data)\n\n    def generate_image_edit_url(case_id, client) -> tuple[str, bytes]:\n        \"\"\"TI2I: Text + Image -> 
Image edit using direct URL transfer (no pre-download).\"\"\"\n        if not sampling_params.prompt or not sampling_params.image_path:\n            pytest.skip(f\"{case_id}: no edit config\")\n        # Handle both single URL and list of URLs\n        image_urls = sampling_params.image_path\n        if not isinstance(image_urls, list):\n            image_urls = [image_urls]\n\n        # Validate all URLs\n        for url in image_urls:\n            if not is_image_url(url):\n                pytest.skip(\n                    f\"{case_id}: image_path must be a URL for URL direct test: {url}\"\n                )\n\n        # Request parameters that affect output format\n        req_output_format = (\n            sampling_params.output_format\n        )  # Not specified in current request\n        req_background = None  # Not specified in current request\n\n        response = client.images.with_raw_response.edit(\n            model=model_path,\n            prompt=sampling_params.prompt,\n            image=[],  # Only for OpenAI verification\n            n=n,\n            size=sampling_params.output_size,\n            response_format=\"b64_json\",\n            output_format=req_output_format,\n            extra_body={\"url\": image_urls, \"num_frames\": sampling_params.num_frames},\n        )\n\n        result = response.parse()\n        rid = result.id\n\n        validate_image(result.data[0].b64_json)\n\n        # Save and upload result for verification\n        img_data = base64.b64decode(result.data[0].b64_json)\n        # Infer expected format from request parameters\n        expected_ext = get_expected_image_format(req_output_format, req_background)\n        expected_filename = f\"{rid}.{expected_ext}\"\n        tmp_path = expected_filename\n        with open(tmp_path, \"wb\") as f:\n            f.write(img_data)\n\n        # Validate output file\n        expected_width, expected_height = parse_dimensions(sampling_params.output_size)\n        
validate_image_file(\n            tmp_path,\n            expected_filename,\n            expected_width,\n            expected_height,\n            output_format=req_output_format,\n            background=req_background,\n        )\n\n        upload_file_to_slack(\n            case_id=case_id,\n            model=model_path,\n            prompt=sampling_params.prompt,\n            file_path=tmp_path,\n            origin_file_path=str(sampling_params.image_path),\n        )\n        os.remove(tmp_path)\n\n        return (rid, img_data)\n\n    def generate_video(case_id, client) -> tuple[str, bytes]:\n        \"\"\"T2V: Text ? Video.\"\"\"\n        if not sampling_params.prompt:\n            pytest.skip(f\"{case_id}: no text prompt configured\")\n\n        # Build extra_body for optional features\n        extra_body = dict(sampling_params.extras)\n        if sampling_params.num_frames:\n            extra_body[\"num_frames\"] = sampling_params.num_frames\n\n        # Compute expected output frame count for validation\n        expected_frame_count = None\n        if (\n            sampling_params.extras.get(\"enable_frame_interpolation\")\n            and sampling_params.num_frames\n        ):\n            n = sampling_params.num_frames\n            exp = sampling_params.extras.get(\"frame_interpolation_exp\", 1)\n            expected_frame_count = (n - 1) * (2**exp) + 1\n\n        return _create_and_download_video(\n            client,\n            case_id,\n            model=model_path,\n            prompt=sampling_params.prompt,\n            size=output_size,\n            seconds=video_seconds,\n            extra_body=extra_body if extra_body else None,\n            expected_frame_count=expected_frame_count,\n        )\n\n    def generate_image_to_video(case_id, client) -> tuple[str, bytes]:\n        \"\"\"I2V: Image -> Video (optional prompt).\"\"\"\n        if not sampling_params.image_path:\n            pytest.skip(f\"{case_id}: no input image configured\")\n\n    
    if is_image_url(sampling_params.image_path):\n            image_path = download_image_from_url(str(sampling_params.image_path))\n        else:\n            image_path = Path(sampling_params.image_path)\n            if not image_path.exists():\n                pytest.skip(f\"{case_id}: file missing: {image_path}\")\n\n        # Build extra_body for optional features\n        extra_body = dict(sampling_params.extras)\n\n        with image_path.open(\"rb\") as fh:\n            return _create_and_download_video(\n                client,\n                case_id,\n                model=model_path,\n                prompt=sampling_params.prompt,\n                size=output_size,\n                seconds=video_seconds,\n                input_reference=fh,\n                extra_body=extra_body if extra_body else None,\n            )\n\n    def generate_text_url_image_to_video(case_id, client) -> tuple[str, bytes]:\n        if not sampling_params.prompt or not sampling_params.image_path:\n            pytest.skip(f\"{case_id}: no edit config\")\n\n        # Build extra_body for optional features\n        extra_body = {\"reference_url\": sampling_params.image_path}\n        extra_body.update(sampling_params.extras)\n\n        return _create_and_download_video(\n            client,\n            case_id,\n            model=model_path,\n            prompt=sampling_params.prompt,\n            size=sampling_params.output_size,\n            seconds=video_seconds,\n            extra_body={\n                \"reference_url\": sampling_params.image_path,\n                \"fps\": sampling_params.fps,\n                \"num_frames\": sampling_params.num_frames,\n            },\n        )\n\n    def generate_text_image_to_video(case_id, client) -> tuple[str, bytes]:\n        \"\"\"TI2V: Text + Image -> Video.\"\"\"\n        if not sampling_params.prompt or not sampling_params.image_path:\n            pytest.skip(f\"{case_id}: no edit config\")\n\n        if 
is_image_url(sampling_params.image_path):\n            image_path = download_image_from_url(str(sampling_params.image_path))\n        else:\n            image_path = Path(sampling_params.image_path)\n            if not image_path.exists():\n                pytest.skip(f\"{case_id}: file missing: {image_path}\")\n\n        # Build extra_body for optional features\n        extra_body = dict(sampling_params.extras)\n\n        with image_path.open(\"rb\") as fh:\n            return _create_and_download_video(\n                client,\n                case_id,\n                model=model_path,\n                prompt=sampling_params.prompt,\n                size=output_size,\n                seconds=video_seconds,\n                input_reference=fh,\n                extra_body={\n                    \"fps\": sampling_params.fps,\n                    \"num_frames\": sampling_params.num_frames,\n                    **extra_body,\n                },\n            )\n\n    def generate_mesh(case_id, client) -> tuple[str, bytes]:\n        \"\"\"I2M: Image to Mesh generation using async /v1/meshes API.\"\"\"\n        import requests as http_requests\n\n        if not sampling_params.image_path:\n            pytest.skip(f\"{case_id}: no input image configured for mesh generation\")\n\n        image_path = sampling_params.image_path\n        if isinstance(image_path, str) and is_image_url(image_path):\n            image_path = download_image_from_url(image_path)\n        elif isinstance(image_path, Path):\n            if not image_path.exists():\n                pytest.skip(f\"{case_id}: image file missing: {image_path}\")\n        else:\n            image_path = Path(str(image_path))\n            if not image_path.exists():\n                pytest.skip(f\"{case_id}: image file missing: {image_path}\")\n\n        base_url = str(client.base_url).rstrip(\"/\")\n        if base_url.endswith(\"/v1\"):\n            base_url = base_url[:-3]\n\n        create_url = 
f\"{base_url}/v1/meshes\"\n\n        with open(str(image_path), \"rb\") as img_file:\n            files = {\"image\": (Path(str(image_path)).name, img_file, \"image/png\")}\n            data = {\n                \"prompt\": \"generate 3d mesh\",\n                \"model\": model_path,\n                \"seed\": \"0\",\n                \"guidance_scale\": \"5.0\",\n                \"num_inference_steps\": \"50\",\n            }\n\n            logger.info(f\"[Mesh Gen] Sending request to {create_url}\")\n\n            try:\n                response = http_requests.post(\n                    create_url, files=files, data=data, timeout=60\n                )\n            except Exception as e:\n                pytest.fail(f\"{case_id}: mesh creation request failed: {e}\")\n\n        if response.status_code != 200:\n            pytest.fail(f\"{case_id}: mesh creation failed: {response.text}\")\n\n        job = response.json()\n        mesh_id = job.get(\"id\")\n        if not mesh_id:\n            pytest.fail(f\"{case_id}: no mesh id in response: {job}\")\n\n        poll_url = f\"{base_url}/v1/meshes/{mesh_id}\"\n        poll_interval = 5\n        max_wait = 1200\n        elapsed = 0\n\n        while elapsed < max_wait:\n            time.sleep(poll_interval)\n            elapsed += poll_interval\n\n            try:\n                poll_resp = http_requests.get(poll_url, timeout=30)\n            except Exception as e:\n                logger.warning(f\"[Mesh Gen] Poll failed: {e}\")\n                continue\n\n            if poll_resp.status_code != 200:\n                continue\n\n            status_data = poll_resp.json()\n            status = status_data.get(\"status\", \"\")\n\n            if status == \"completed\":\n                content_url = f\"{base_url}/v1/meshes/{mesh_id}/content\"\n                try:\n                    content_resp = http_requests.get(content_url, timeout=60)\n                except Exception as e:\n                    
pytest.fail(f\"{case_id}: mesh download failed: {e}\")\n\n                if content_resp.status_code != 200:\n                    pytest.fail(f\"{case_id}: mesh download failed: {content_resp.text}\")\n\n                temp_path = Path(tempfile.gettempdir()) / f\"mesh_test_{mesh_id}.glb\"\n                temp_path.write_bytes(content_resp.content)\n                MESH_OUTPUT_PATHS[case_id] = str(temp_path)\n\n                logger.info(f\"[Mesh Gen] Mesh downloaded to {temp_path}\")\n                return (mesh_id, b\"\")\n            elif status == \"failed\":\n                error = status_data.get(\"error\", {})\n                pytest.fail(f\"{case_id}: mesh generation failed: {error}\")\n\n        pytest.fail(f\"{case_id}: mesh generation timed out after {max_wait}s\")\n\n    if modality == \"3d\":\n        fn = generate_mesh\n    elif modality == \"video\":\n        if sampling_params.image_path and sampling_params.prompt:\n            if getattr(sampling_params, \"direct_url_test\", False):\n                fn = generate_text_url_image_to_video\n            else:\n                fn = generate_text_image_to_video\n        elif sampling_params.image_path:\n            fn = generate_image_to_video\n        else:\n            fn = generate_video\n    elif sampling_params.prompt and sampling_params.image_path:\n        if getattr(sampling_params, \"direct_url_test\", False):\n            fn = generate_image_edit_url\n        else:\n            fn = generate_image_edit\n    else:\n        fn = generate_image\n\n    return fn\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py",
    "content": "\"\"\"Tests for diffusion `update_weights_from_disk`.\n\nThis module verifies the ability to update model weights in place without restarting\nthe server, which is critical for RL workflows and iterative fine-tuning scenarios.\n\nAuthor:\n\nMenyang Liu, https://github.com/dreamyang-liu\nChenyang Zhao, https://github.com/zhaochenyang20\n\nWe use two model pairs for testing (base model / instruct model pairs):\n\n- FLUX.2-klein-base-4B / FLUX.2-klein-4B\n- Qwen/Qwen-Image / Qwen/Qwen-Image-2512\n\nThese model pairs share the same architecture but differ in transformer\nweights. The basic testing logic is to refit the instruct model into the\nbase model and verify the checksum of the transformer weights are the same,\nwhich simulates the real-world RL scenario. However, since these two model\npairs only differ in transformer weights, and we want to verify update a\nspecific module with update_weights_from_disk API, we need to create a perturbed\ninstruct model that adds noise to the vae weights. In this sense, the instruct\nmodel differs from the base model in vae and transformer weights, the text\nencoder are still the same.\n\nTo strictly verify the correctness of the refit API, we compare the checksum in\nSHA-256 on the disk and the server.\n\nNOTE and TODO: In the refit a specific module test, we randomly select one module\nfrom the transformer and vae to refit the server and keep other modules the same.\nAs described above, the vae's weights are perturbed. If we select the vae to be the\ntarget module, ideally speaking, we should assert that the refitted vae's checksum\nis the same as directly computed from the perturbed vae weights in the disk. 
However,\nsince there is complex weight-name remapping and QKV merging during model loading,\nit is not easy to directly compare the server and disk checksums for the vae and text encoder.\nTherefore, if the target module is vae, we only verify that the refitted vae's checksum\nis different from the base model's vae checksum.\n\nAdding a server-vs-disk checksum comparison for the vae and text encoder to this test\nwould be a good issue for the community to solve.\n\n=============================================================================\n\nTest organization:\n\n7 test cases in 2 classes;\ntwo model pairs are tested locally, one in CI.\n\n=============================================================================\n\nClass 1: TestUpdateWeightsFromDisk                  (6 tests) - API contract, checksum & rollback\nClass 2: TestUpdateWeightsFromDiskWithOffload       (1 test) - Offload-aware update + checksum\n\n-----------------------------------------------------------------------------\n\nClass 1: TestUpdateWeightsFromDisk\n\nValidate the update_weights_from_disk API contract, request/response shape,\nerror handling, checksum verification, and corrupted-weight rollback.\n\nAll tests share one class-scoped server (same process, same in-memory weights).\nTests that require \"base model then update\" should explicitly reset the server to the\nbase model first so behavior is order-independent and updates are real\n(base -> perturbed), not no-ops (perturbed -> perturbed).\n\n  • test_update_weights_from_disk_default\n\n    base model -> perturbed model with flush_cache=True.\n    Verifies the after-update transformer checksum == perturbed model's\n    transformer disk checksum.\n\n  • test_update_weights_specific_modules\n\n    base -> perturbed with flush_cache=False. Randomly selects one module\n    from _DIFFERING_MODULES (transformer and vae) as target_modules, updates\n    only that module. 
Verifies that:\n    (1) targeted module's in-memory checksum changed;\n    (2) non-targeted modules' in-memory checksums are unchanged.\n\n  • test_update_weights_nonexistent_model\n\n    model_path set to a non-existent path; must fail (400, success=False).\n\n    Ensure the server is healthy after the failed update and the server's transformer\n    checksum equals the base model's transformer disk checksum.\n\n  • test_update_weights_missing_model_path\n\n    Request body empty (no model_path); must fail (400, success=False).\n\n    Ensure the server is healthy after the failed update and the server's transformer\n    checksum equals the base model's transformer disk checksum.\n\n  • test_update_weights_nonexistent_module\n\n    target_modules=[\"nonexistent_module\"]; must fail (400, success=False).\n\n    Verify the server is healthy after the failed update and the server's transformer\n    checksum equals the base model's transformer disk checksum.\n\n  • test_corrupted_weights_rollback\n\n    All-or-nothing rollback: We first refit the server from base model ->\n    perturbed model. We manually truncate the vae weights of the base\n    model to get a corrupted model. We then call the refit to update\n    the server from the perturbed model -> corrupted model. Verify that:\n\n    1. The update fails due to the truncated vae, and the server should roll back to the\n    perturbed model, i.e., server's transformer weights == perturbed model's\n    transformer weights != base model's transformer weights.\n\n    2. After the rollback, server's vae weights == perturbed model's vae\n    weights != base model's vae weights.\n\n    3. After the rollback, server's text encoder weights == base model's\n    text encoder weights == perturbed model's text encoder weights.\n\n-----------------------------------------------------------------------------\n\nClass 2: TestUpdateWeightsFromDiskWithOffload\n\nEnsure weight updates and checksum verification work when layerwise offload is enabled\n(--dit-layerwise-offload). 
With offload, parameters live in CPU buffers, leaving only\nsmall torch.empty((1,)) placeholders on GPU; the updater must write into CPU buffers\nand update prefetched GPU tensors without shape mismatch.\n\n  • test_update_weights_with_offload_enabled\n\n    Server with --dit-layerwise-offload (base). Load perturbed checkpoint;\n    must succeed (200, success=True), no \"Shape mismatch\". The server's transformer checksum\n    matches the perturbed model's transformer disk checksum.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport functools\nimport os\nimport random\nimport shutil\nimport tempfile\nimport threading\nfrom collections.abc import Callable\n\nimport pytest\nimport requests\nfrom safetensors.torch import load_file, save_file\n\nfrom sglang.multimodal_gen.runtime.loader.utils import (\n    _list_safetensors_files,\n)\nfrom sglang.multimodal_gen.runtime.loader.weight_utils import (\n    compute_weights_checksum,\n    safetensors_weights_iterator,\n)\nfrom sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.test.server.test_server_utils import (\n    ServerManager,\n)\nfrom sglang.multimodal_gen.test.test_utils import (\n    DEFAULT_FLUX_2_KLEIN_4B_MODEL_NAME_FOR_TEST,\n    DEFAULT_FLUX_2_KLEIN_BASE_4B_MODEL_NAME_FOR_TEST,\n    DEFAULT_QWEN_IMAGE_2512_MODEL_NAME_FOR_TEST,\n    DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST,\n    get_dynamic_server_port,\n    is_in_ci,\n)\n\nlogger = init_logger(__name__)\n\n\n_TRANSFORMER_MODULE = \"transformer\"\n_VAE_MODULE = \"vae\"\n_TEXT_ENCODER_MODULE_PREFIX = \"text_encoder\"\n\n\n# Modules whose weights differ between the base model and the\n# perturbed checkpoint.\n_DIFFERING_MODULES: list[str] = [_TRANSFORMER_MODULE, _VAE_MODULE]\n\n_ALL_MODEL_PAIRS: list[tuple[str, str]] = [\n    (\n        DEFAULT_FLUX_2_KLEIN_BASE_4B_MODEL_NAME_FOR_TEST,\n        
DEFAULT_FLUX_2_KLEIN_4B_MODEL_NAME_FOR_TEST,\n    ),\n    (\n        DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST,\n        DEFAULT_QWEN_IMAGE_2512_MODEL_NAME_FOR_TEST,\n    ),\n]\n\n\n_CI_MODEL_PAIR_ENV = \"SGLANG_MMGEN_UPDATE_WEIGHTS_PAIR\"\n\n\ndef _resolve_active_model_pairs() -> list[tuple[str, str]]:\n    if not is_in_ci():\n        return _ALL_MODEL_PAIRS\n\n    pair_by_id = {pair[0].split(\"/\")[-1]: pair for pair in _ALL_MODEL_PAIRS}\n    selected_pair_id = os.environ.get(_CI_MODEL_PAIR_ENV)\n    if selected_pair_id is None:\n        return [random.choice(_ALL_MODEL_PAIRS)]\n\n    selected_pair = pair_by_id.get(selected_pair_id)\n    if selected_pair is None:\n        valid_ids = \", \".join(sorted(pair_by_id))\n        raise ValueError(\n            f\"Invalid {_CI_MODEL_PAIR_ENV}={selected_pair_id!r}. \"\n            f\"Expected one of: {valid_ids}.\"\n        )\n    return [selected_pair]\n\n\n_ACTIVE_MODEL_PAIRS = _resolve_active_model_pairs()\n_PAIR_IDS = [p[0].split(\"/\")[-1] for p in _ACTIVE_MODEL_PAIRS]\n\n\n@functools.lru_cache(maxsize=None)\ndef _compute_checksum_from_disk(model_path: str, module_name: str) -> str:\n    \"\"\"Compute SHA-256 checksum from safetensors files on disk.\n\n    Uses the same compute_weights_checksum function as the server,\n    so the checksums are directly comparable.\n\n    Results are cached (keyed on model_path and module_name) because the\n    same disk checksum is requested multiple times across tests.\n    \"\"\"\n    local_path = maybe_download_model(model_path)\n    weights_dir = os.path.join(local_path, module_name)\n    assert os.path.exists(\n        weights_dir\n    ), f\"No weights dir for {module_name} in {local_path}\"\n\n    safetensors_files = _list_safetensors_files(weights_dir)\n    assert safetensors_files, f\"No safetensors files in {weights_dir}\"\n\n    return compute_weights_checksum(safetensors_weights_iterator(safetensors_files))\n\n\ndef _clone_model_with_modified_module(\n    src_model: str,\n   
 dst_model: str,\n    target_module: str,\n    transform_safetensor: Callable[[str, str], None],\n) -> None:\n    # Symlink root-level files (model_index.json, etc.).\n    for fname in os.listdir(src_model):\n        src_path = os.path.join(src_model, fname)\n        dst_path = os.path.join(dst_model, fname)\n        if os.path.isfile(src_path) and not os.path.exists(dst_path):\n            os.symlink(src_path, dst_path)\n\n    for module_dir in sorted(os.listdir(src_model)):\n        src_dir = os.path.join(src_model, module_dir)\n        dst_dir = os.path.join(dst_model, module_dir)\n        if not os.path.isdir(src_dir):\n            continue\n\n        if module_dir != target_module:\n            if not os.path.exists(dst_dir):\n                os.symlink(src_dir, dst_dir)\n            continue\n\n        os.makedirs(dst_dir, exist_ok=True)\n        transformed = False\n        for fname in sorted(os.listdir(src_dir)):\n            src_file = os.path.join(src_dir, fname)\n            dst_file = os.path.join(dst_dir, fname)\n            if not os.path.isfile(src_file):\n                continue\n\n            if not fname.endswith(\".safetensors\") or transformed:\n                if not os.path.exists(dst_file):\n                    os.symlink(src_file, dst_file)\n                continue\n\n            transform_safetensor(src_file, dst_file)\n            transformed = True\n\n\ndef _truncate_safetensor(src_file: str, dst_file: str) -> None:\n    shutil.copy2(src_file, dst_file)\n    size = os.path.getsize(dst_file)\n    with open(dst_file, \"r+b\") as f:\n        f.truncate(size - 2)\n    logger.info(\n        \"Created corrupted safetensors: %s (%d -> %d bytes)\",\n        dst_file,\n        size,\n        size - 2,\n    )\n\n\ndef _perturb_safetensor(src_file: str, dst_file: str) -> None:\n\n    tensors = load_file(src_file)\n    perturbed = {\n        k: (t + 0.01 if t.is_floating_point() else t) for k, t in tensors.items()\n    }\n    save_file(perturbed, 
dst_file)\n    logger.info(\"Created perturbed safetensors: %s\", dst_file)\n\n\nclass _UpdateWeightsApiMixin:\n    def _update_weights(\n        self,\n        base_url: str,\n        model_path: str,\n        flush_cache: bool = True,\n        target_modules: list[str] | None = None,\n        timeout: int = 300,\n    ) -> tuple[dict, int]:\n        payload = {\"model_path\": model_path, \"flush_cache\": flush_cache}\n        if target_modules is not None:\n            payload[\"target_modules\"] = target_modules\n        response = requests.post(\n            f\"{base_url}/update_weights_from_disk\",\n            json=payload,\n            timeout=timeout,\n        )\n        return response.json(), response.status_code\n\n    def _get_weights_checksum(\n        self,\n        base_url: str,\n        module_names: list[str] | None = None,\n        timeout: int = 300,\n    ) -> dict:\n        payload = {}\n        if module_names is not None:\n            payload[\"module_names\"] = module_names\n        response = requests.post(\n            f\"{base_url}/get_weights_checksum\",\n            json=payload,\n            timeout=timeout,\n        )\n        assert (\n            response.status_code == 200\n        ), f\"get_weights_checksum failed: {response.status_code} {response.text}\"\n        return response.json()\n\n    def _assert_server_matches_model(\n        self,\n        base_url: str,\n        expected_model: str,\n    ) -> None:\n        server_checksums = self._get_weights_checksum(\n            base_url, module_names=[_TRANSFORMER_MODULE]\n        )\n        expected_cs = _compute_checksum_from_disk(expected_model, _TRANSFORMER_MODULE)\n        server_cs = server_checksums.get(_TRANSFORMER_MODULE)\n        assert server_cs == expected_cs, (\n            f\"Checksum mismatch on '{_TRANSFORMER_MODULE}'\\n\"\n            f\"  expected({expected_model}): {expected_cs}\\n\"\n            f\"  server: {server_cs}\"\n        )\n\n\nclass 
TestUpdateWeightsFromDisk(_UpdateWeightsApiMixin):\n\n    @pytest.fixture(\n        scope=\"class\",\n        params=_ACTIVE_MODEL_PAIRS,\n        ids=_PAIR_IDS,\n    )\n    def diffusion_server_no_offload(self, request):\n        default_model, source_model = request.param\n        port = get_dynamic_server_port()\n        wait_deadline = float(os.environ.get(\"SGLANG_TEST_WAIT_SECS\", \"600\"))\n\n        manager = ServerManager(\n            model=default_model,\n            port=port,\n            wait_deadline=wait_deadline,\n            extra_args=\"--num-gpus 1\",\n        )\n\n        # Ensure models are local before spawning threads that need the paths.\n        local_default = maybe_download_model(default_model)\n        local_source = maybe_download_model(source_model)\n\n        perturbed_vae_model_dir = tempfile.mkdtemp(prefix=\"sglang_perturbed_vae_\")\n        corrupted_vae_model_dir = tempfile.mkdtemp(prefix=\"sglang_corrupted_\")\n\n        # Run all disk I/O in background while the server boots.\n        bg_threads = [\n            threading.Thread(\n                target=_compute_checksum_from_disk, args=(default_model, module)\n            )\n            for module in _DIFFERING_MODULES\n        ] + [\n            threading.Thread(\n                target=_clone_model_with_modified_module,\n                args=(\n                    local_source,\n                    perturbed_vae_model_dir,\n                    _VAE_MODULE,\n                    _perturb_safetensor,\n                ),\n            ),\n            threading.Thread(\n                target=_clone_model_with_modified_module,\n                args=(\n                    local_default,\n                    corrupted_vae_model_dir,\n                    _VAE_MODULE,\n                    _truncate_safetensor,\n                ),\n            ),\n        ]\n        for t in bg_threads:\n            t.start()\n\n        ctx = manager.start()\n        for t in bg_threads:\n            
t.join()\n\n        # Sanity: all _DIFFERING_MODULES should differ between base and perturbed.\n        for module in _DIFFERING_MODULES:\n            assert _compute_checksum_from_disk(\n                default_model, module\n            ) != _compute_checksum_from_disk(perturbed_vae_model_dir, module), (\n                f\"Assumption violated: {module} should differ between \"\n                f\"{default_model} and {perturbed_vae_model_dir}\"\n            )\n\n        try:\n            yield ctx, default_model, perturbed_vae_model_dir, corrupted_vae_model_dir\n        finally:\n            ctx.cleanup()\n            shutil.rmtree(perturbed_vae_model_dir, ignore_errors=True)\n            shutil.rmtree(corrupted_vae_model_dir, ignore_errors=True)\n\n    def test_update_weights_from_disk_default(self, diffusion_server_no_offload):\n        \"\"\"Default update (target_modules=None, flush_cache=True): all changed modules updated.\"\"\"\n        ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload\n        base_url = f\"http://localhost:{ctx.port}\"\n\n        self._update_weights(base_url, default_model, flush_cache=True)\n\n        result, status_code = self._update_weights(\n            base_url, perturbed_model_dir, flush_cache=True\n        )\n        assert status_code == 200\n        assert result.get(\"success\", False), f\"Update failed: {result.get('message')}\"\n\n        self._assert_server_matches_model(base_url, perturbed_model_dir)\n\n    def test_update_weights_specific_modules(self, diffusion_server_no_offload):\n        ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload\n        base_url = f\"http://localhost:{ctx.port}\"\n\n        # Reset server to default_model.\n        self._update_weights(base_url, default_model)\n        before_checksums = self._get_weights_checksum(\n            base_url, module_names=_DIFFERING_MODULES\n        )\n\n        target_modules = [random.choice(_DIFFERING_MODULES)]\n     
   result, status_code = self._update_weights(\n            base_url,\n            perturbed_model_dir,\n            target_modules=target_modules,\n            flush_cache=False,\n        )\n        assert status_code == 200, f\"Update failed: {result}\"\n        assert result.get(\"success\", False), f\"Update failed: {result.get('message')}\"\n\n        after_checksums = self._get_weights_checksum(\n            base_url, module_names=_DIFFERING_MODULES\n        )\n\n        # Targeted module should have changed.\n        for name in target_modules:\n            assert after_checksums.get(name) != before_checksums.get(name), (\n                f\"Targeted module '{name}' checksum should change after update\\n\"\n                f\"  before: {before_checksums.get(name)}\\n\"\n                f\"  after:  {after_checksums.get(name)}\"\n            )\n\n        # Non-targeted modules should be unchanged.\n        for name, cs in after_checksums.items():\n            if name in target_modules or cs == \"not_found\":\n                continue\n            assert cs == before_checksums.get(name), (\n                f\"Non-targeted module '{name}' should be unchanged\\n\"\n                f\"  before: {before_checksums.get(name)}\\n\"\n                f\"  after:  {cs}\"\n            )\n\n    def test_update_weights_nonexistent_model(self, diffusion_server_no_offload):\n        \"\"\"Nonexistent model path must fail (400). 
Server healthy, checksums == base disk.\"\"\"\n        ctx, default_model, _, _ = diffusion_server_no_offload\n        base_url = f\"http://localhost:{ctx.port}\"\n\n        self._update_weights(base_url, default_model)\n\n        result, status_code = self._update_weights(\n            base_url,\n            \"/nonexistent/path/to/model\",\n            timeout=60,\n        )\n        logger.info(f\"Update result for nonexistent model: {result}\")\n\n        assert status_code == 400, f\"Expected 400, got {status_code}\"\n        assert not result.get(\"success\", True), \"Should fail for nonexistent model\"\n        self._assert_server_matches_model(base_url, default_model)\n\n    def test_update_weights_missing_model_path(self, diffusion_server_no_offload):\n        \"\"\"Request without model_path must fail (400). Server healthy, checksums == base disk.\"\"\"\n        ctx, default_model, _, _ = diffusion_server_no_offload\n        base_url = f\"http://localhost:{ctx.port}\"\n\n        self._update_weights(base_url, default_model)\n\n        response = requests.post(\n            f\"{base_url}/update_weights_from_disk\",\n            json={},\n            timeout=30,\n        )\n\n        assert response.status_code == 400, f\"Expected 400, got {response.status_code}\"\n        result = response.json()\n        assert not result.get(\"success\", True), \"Should fail when model_path is missing\"\n        self._assert_server_matches_model(base_url, default_model)\n\n    def test_update_weights_nonexistent_module(self, diffusion_server_no_offload):\n        \"\"\"Nonexistent module must fail (400). 
Server healthy, checksums == base disk.\"\"\"\n        ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload\n        base_url = f\"http://localhost:{ctx.port}\"\n\n        self._update_weights(base_url, default_model)\n\n        result, status_code = self._update_weights(\n            base_url,\n            perturbed_model_dir,\n            target_modules=[\"nonexistent_module\"],\n            timeout=60,\n        )\n        logger.info(f\"Update nonexistent module result: {result}\")\n\n        assert status_code == 400, f\"Expected 400, got {status_code}\"\n        assert not result.get(\"success\", True), \"Should fail for nonexistent module\"\n        assert \"not found in pipeline\" in result.get(\"message\", \"\")\n        self._assert_server_matches_model(base_url, default_model)\n\n    def test_corrupted_weights_rollback(self, diffusion_server_no_offload):\n        ctx, default_model, perturbed_model_dir, corrupted_vae_model_dir = (\n            diffusion_server_no_offload\n        )\n        base_url = f\"http://localhost:{ctx.port}\"\n\n        # base → perturbed\n        self._update_weights(base_url, default_model)\n        base_checksums = self._get_weights_checksum(base_url)\n\n        result, status_code = self._update_weights(base_url, perturbed_model_dir)\n        assert status_code == 200 and result.get(\"success\")\n        perturbed_checksums = self._get_weights_checksum(base_url)\n\n        text_encoder_modules = sorted(\n            name\n            for name in perturbed_checksums\n            if _TEXT_ENCODER_MODULE_PREFIX in name\n            and perturbed_checksums.get(name) != \"not_found\"\n            and base_checksums.get(name) != \"not_found\"\n        )\n        assert (\n            text_encoder_modules\n        ), \"Expected at least one text encoder module checksum\"\n\n        # perturbed → corrupted (should fail and rollback)\n        rollback_targets = [_TRANSFORMER_MODULE, _VAE_MODULE]\n        result, 
status_code = self._update_weights(\n            base_url,\n            corrupted_vae_model_dir,\n            target_modules=rollback_targets,\n        )\n        assert (\n            status_code == 400\n        ), f\"Expected 400 on corrupted weights, got {status_code}\"\n        assert not result.get(\"success\", True)\n        message = result.get(\"message\", \"\")\n        assert \"rolled back\" in message.lower()\n        # The updater reports the first failing module in the error message.\n        # With ordered target_modules=[transformer, vae], this makes the\n        # failure point explicit: transformer is processed first, then vae fails.\n        assert (\n            \"Failed to update module 'vae'\" in message\n        ), f\"Expected vae to be the explicit failure point, got: {message}\"\n        rolled_back_checksums = self._get_weights_checksum(base_url)\n\n        # 1) transformer: server == perturbed != base\n        transformer_base = base_checksums.get(_TRANSFORMER_MODULE)\n        transformer_perturbed = perturbed_checksums.get(_TRANSFORMER_MODULE)\n        transformer_rolled_back = rolled_back_checksums.get(_TRANSFORMER_MODULE)\n        assert transformer_rolled_back == transformer_perturbed\n        assert transformer_rolled_back != transformer_base\n\n        # 2) vae: server == perturbed != base\n        vae_base = base_checksums.get(_VAE_MODULE)\n        vae_perturbed = perturbed_checksums.get(_VAE_MODULE)\n        vae_rolled_back = rolled_back_checksums.get(_VAE_MODULE)\n        assert vae_rolled_back == vae_perturbed\n        assert vae_rolled_back != vae_base\n\n        # 3) text encoder(s): server == base == perturbed\n        for name in text_encoder_modules:\n            assert rolled_back_checksums.get(name) == perturbed_checksums.get(\n                name\n            ), f\"Text encoder module '{name}' should stay equal to perturbed\"\n            assert rolled_back_checksums.get(name) == base_checksums.get(\n                
name\n            ), f\"Text encoder module '{name}' should stay equal to base\"\n\n\nclass TestUpdateWeightsFromDiskWithOffload(_UpdateWeightsApiMixin):\n    \"\"\"Test update_weights_from_disk with layerwise offload enabled.\"\"\"\n\n    @pytest.fixture(scope=\"class\", params=_ACTIVE_MODEL_PAIRS, ids=_PAIR_IDS)\n    def diffusion_server_with_offload(self, request):\n        default_model, source_model = request.param\n        port = get_dynamic_server_port()\n        wait_deadline = float(os.environ.get(\"SGLANG_TEST_WAIT_SECS\", \"600\"))\n\n        local_source = maybe_download_model(source_model)\n        perturbed_vae_model_dir = tempfile.mkdtemp(prefix=\"sglang_perturbed_vae_\")\n\n        clone_thread = threading.Thread(\n            target=_clone_model_with_modified_module,\n            args=(\n                local_source,\n                perturbed_vae_model_dir,\n                _VAE_MODULE,\n                _perturb_safetensor,\n            ),\n        )\n        clone_thread.start()\n\n        manager = ServerManager(\n            model=default_model,\n            port=port,\n            wait_deadline=wait_deadline,\n            extra_args=\"--num-gpus 1 --dit-layerwise-offload true\",\n        )\n\n        ctx = manager.start()\n        clone_thread.join()\n\n        try:\n            yield ctx, default_model, perturbed_vae_model_dir\n        finally:\n            ctx.cleanup()\n            shutil.rmtree(perturbed_vae_model_dir, ignore_errors=True)\n\n    def test_update_weights_with_offload_enabled(self, diffusion_server_with_offload):\n        ctx, _, perturbed_model_dir = diffusion_server_with_offload\n        base_url = f\"http://localhost:{ctx.port}\"\n\n        result, status_code = self._update_weights(base_url, perturbed_model_dir)\n        assert status_code == 200, f\"Expected 200, got {status_code}\"\n        assert result.get(\"success\", False), f\"Update failed: {result.get('message')}\"\n\n        message = result.get(\"message\", 
\"\")\n        assert \"Shape mismatch\" not in message, f\"Shape mismatch detected: {message}\"\n\n        self._assert_server_matches_model(base_url, perturbed_model_dir)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/server/testcase_configs.py",
    "content": "\"\"\"\nConfiguration and data structures for diffusion performance tests.\n\nUsage:\n\npytest python/sglang/multimodal_gen/test/server/test_server_a.py\n# for a single testcase, look for the name of the testcases in DIFFUSION_CASES\npytest python/sglang/multimodal_gen/test/server/test_server_a.py -k qwen_image_t2i\n\n\nTo add a new testcase:\n1. add your testcase with case-id: `my_new_test_case_id` to DIFFUSION_CASES\n2. run `SGLANG_GEN_BASELINE=1 pytest -s python/sglang/multimodal_gen/test/server/ -k my_new_test_case_id`\n3. insert or override the corresponding scenario in `scenarios` section of perf_baselines.json with the output baseline of step-2\n\n\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport statistics\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Sequence\n\nfrom sglang.multimodal_gen.runtime.platforms import current_platform\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import RequestPerfRecord\nfrom sglang.multimodal_gen.test.test_utils import (\n    DEFAULT_FLUX_1_DEV_MODEL_NAME_FOR_TEST,\n    DEFAULT_FLUX_2_DEV_MODEL_NAME_FOR_TEST,\n    DEFAULT_FLUX_2_KLEIN_4B_MODEL_NAME_FOR_TEST,\n    DEFAULT_QWEN_IMAGE_EDIT_2509_MODEL_NAME_FOR_TEST,\n    DEFAULT_QWEN_IMAGE_EDIT_2511_MODEL_NAME_FOR_TEST,\n    DEFAULT_QWEN_IMAGE_EDIT_MODEL_NAME_FOR_TEST,\n    DEFAULT_QWEN_IMAGE_LAYERED_MODEL_NAME_FOR_TEST,\n    DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST,\n    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,\n    DEFAULT_WAN_2_1_I2V_14B_480P_MODEL_NAME_FOR_TEST,\n    DEFAULT_WAN_2_1_I2V_14B_720P_MODEL_NAME_FOR_TEST,\n    DEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST,\n    DEFAULT_WAN_2_1_T2V_14B_MODEL_NAME_FOR_TEST,\n    DEFAULT_WAN_2_2_I2V_A14B_MODEL_NAME_FOR_TEST,\n    DEFAULT_WAN_2_2_T2V_A14B_MODEL_NAME_FOR_TEST,\n    DEFAULT_WAN_2_2_TI2V_5B_MODEL_NAME_FOR_TEST,\n)\n\n\n@dataclass\nclass ToleranceConfig:\n    \"\"\"Tolerance ratios for performance validation.\"\"\"\n\n    e2e: float\n   
 denoise_stage: float\n    non_denoise_stage: float\n    denoise_step: float\n    denoise_agg: float\n\n    @classmethod\n    def load_profile(cls, all_tolerances: dict, profile_name: str) -> ToleranceConfig:\n        \"\"\"Load a specific tolerance profile from a dictionary of profiles.\"\"\"\n        # Support both flat structure (backward compatibility) and profiled structure\n        if \"e2e\" in all_tolerances and not isinstance(all_tolerances[\"e2e\"], dict):\n            tol_data = all_tolerances\n            actual_profile = \"legacy/flat\"\n        else:\n            tol_data = all_tolerances.get(\n                profile_name, all_tolerances.get(\"pr_test\", {})\n            )\n            actual_profile = (\n                profile_name if profile_name in all_tolerances else \"pr_test\"\n            )\n\n        if not tol_data:\n            raise ValueError(\n                f\"No tolerance profile found for '{profile_name}' and no default 'pr_test' profile exists.\"\n            )\n\n        print(f\"--- Performance Tolerance Profile: {actual_profile} ---\")\n\n        return cls(\n            e2e=float(os.getenv(\"SGLANG_E2E_TOLERANCE\", tol_data[\"e2e\"])),\n            denoise_stage=float(\n                os.getenv(\"SGLANG_STAGE_TIME_TOLERANCE\", tol_data[\"denoise_stage\"])\n            ),\n            non_denoise_stage=float(\n                os.getenv(\n                    \"SGLANG_NON_DENOISE_STAGE_TIME_TOLERANCE\",\n                    tol_data[\"non_denoise_stage\"],\n                )\n            ),\n            denoise_step=float(\n                os.getenv(\"SGLANG_DENOISE_STEP_TOLERANCE\", tol_data[\"denoise_step\"])\n            ),\n            denoise_agg=float(\n                os.getenv(\"SGLANG_DENOISE_AGG_TOLERANCE\", tol_data[\"denoise_agg\"])\n            ),\n        )\n\n\n@dataclass\nclass ScenarioConfig:\n    \"\"\"Expected performance metrics for a test scenario.\"\"\"\n\n    stages_ms: dict[str, float]\n    
denoise_step_ms: dict[int, float]\n    expected_e2e_ms: float\n    expected_avg_denoise_ms: float\n    expected_median_denoise_ms: float\n\n\n@dataclass\nclass BaselineConfig:\n    \"\"\"Full baseline configuration.\"\"\"\n\n    scenarios: dict[str, ScenarioConfig]\n    step_fractions: Sequence[float]\n    tolerances: ToleranceConfig\n    improvement_threshold: float\n\n    @classmethod\n    def load(cls, path: Path) -> BaselineConfig:\n        \"\"\"Load baseline configuration from JSON file.\"\"\"\n        with path.open(\"r\", encoding=\"utf-8\") as fh:\n            data = json.load(fh)\n\n        # Get tolerance profile, defaulting to 'pr_test'\n        profile_name = \"pr_test\"\n        tolerances = ToleranceConfig.load_profile(\n            data.get(\"tolerances\", {}), profile_name\n        )\n\n        scenarios = {}\n        for name, cfg in data[\"scenarios\"].items():\n            scenarios[name] = ScenarioConfig(\n                stages_ms=cfg[\"stages_ms\"],\n                denoise_step_ms={int(k): v for k, v in cfg[\"denoise_step_ms\"].items()},\n                expected_e2e_ms=float(cfg[\"expected_e2e_ms\"]),\n                expected_avg_denoise_ms=float(cfg[\"expected_avg_denoise_ms\"]),\n                expected_median_denoise_ms=float(cfg[\"expected_median_denoise_ms\"]),\n            )\n\n        return cls(\n            scenarios=scenarios,\n            step_fractions=tuple(data[\"sampling\"][\"step_fractions\"]),\n            tolerances=tolerances,\n            improvement_threshold=data.get(\"improvement_reporting\", {}).get(\n                \"threshold\", 0.2\n            ),\n        )\n\n    def update(self, path: Path):\n        \"\"\"Merge scenarios from a baseline JSON file into this config, overriding existing entries.\"\"\"\n        with path.open(\"r\", encoding=\"utf-8\") as fh:\n            data = json.load(fh)\n\n        scenarios_new = {}\n        for name, cfg in data[\"scenarios\"].items():\n            scenarios_new[name] = ScenarioConfig(\n                
stages_ms=cfg[\"stages_ms\"],\n                denoise_step_ms={int(k): v for k, v in cfg[\"denoise_step_ms\"].items()},\n                expected_e2e_ms=float(cfg[\"expected_e2e_ms\"]),\n                expected_avg_denoise_ms=float(cfg[\"expected_avg_denoise_ms\"]),\n                expected_median_denoise_ms=float(cfg[\"expected_median_denoise_ms\"]),\n            )\n\n        self.scenarios.update(scenarios_new)\n        return self\n\n\n@dataclass\nclass DiffusionServerArgs:\n    \"\"\"Server launch arguments for a single test case.\"\"\"\n\n    model_path: str  # HF repo or local path\n    modality: str = \"image\"  # \"image\" or \"video\" or \"3d\"\n\n    custom_validator: str | None = None  # optional custom validator name\n    # resources\n    num_gpus: int = 1\n    tp_size: int | None = None\n    ulysses_degree: int | None = None\n    ring_degree: int | None = None\n    cfg_parallel: bool | None = None\n    # LoRA\n    lora_path: str | None = (\n        None  # LoRA adapter path (HF repo or local path, loaded at startup)\n    )\n    dynamic_lora_path: str | None = (\n        None  # LoRA path for dynamic loading test (loaded via set_lora after startup)\n    )\n    second_lora_path: str | None = (\n        None  # Second LoRA adapter path for multi-LoRA testing\n    )\n\n    dit_layerwise_offload: bool = False\n    dit_offload_prefetch_size: int | float | None = None\n    enable_cache_dit: bool = False\n    text_encoder_cpu_offload: bool = False\n\n    extras: list[str] = field(default_factory=lambda: [])\n\n    def __post_init__(self):\n        if self.modality == \"image\":\n            self.custom_validator = \"image\"\n        elif self.modality == \"video\":\n            self.custom_validator = \"video\"\n        elif self.modality == \"3d\":\n            self.custom_validator = \"mesh\"\n\n\n@dataclass(frozen=True)\nclass DiffusionSamplingParams:\n    \"\"\"Request-level sampling parameters for a single test case.\"\"\"\n\n    output_size: str = 
\"\"\n\n    # inputs and conditioning\n    prompt: str | None = None  # text prompt for generation\n    image_path: Path | str | None = None  # input image/video for editing (Path or URL)\n\n    # duration\n    seconds: int = 1  # for video: duration in seconds\n    num_frames: int | None = None  # for video: number of frames\n    fps: int | None = None  # for video: frames per second\n\n    # URL direct test flag - if True, don't pre-download URL images\n    direct_url_test: bool = False\n\n    # output format\n    output_format: str | None = None  # \"png\", \"jpeg\", \"mp4\", etc.\n\n    num_outputs_per_prompt: int = 1\n\n    # Additional request-level parameters (e.g. enable_teacache, enable_upscaling, …)\n    # merged directly into the OpenAI extra_body dict.\n    extras: dict = field(default_factory=dict)\n\n\n@dataclass(frozen=True)\nclass DiffusionTestCase:\n    \"\"\"Configuration for a single model/scenario test case.\"\"\"\n\n    id: str  # pytest test id and scenario name\n    server_args: DiffusionServerArgs\n    sampling_params: DiffusionSamplingParams\n    run_perf_check: bool = True\n\n\ndef sample_step_indices(\n    step_map: dict[int, float], fractions: Sequence[float]\n) -> list[int]:\n    if not step_map:\n        return []\n    max_idx = max(step_map.keys())\n    indices = set()\n    for fraction in fractions:\n        idx = min(max_idx, max(0, int(round(fraction * max_idx))))\n        if idx in step_map:\n            indices.add(idx)\n    return sorted(indices)\n\n\n@dataclass\nclass PerformanceSummary:\n    \"\"\"Summary of performance of a request, built from RequestPerfRecord\"\"\"\n\n    e2e_ms: float\n    avg_denoise_ms: float\n    median_denoise_ms: float\n    # { \"stage_1\": time_1, \"stage_2\": time_2 }\n    stage_metrics: dict[str, float]\n    step_metrics: list[float]\n    sampled_steps: dict[int, float]\n    all_denoise_steps: dict[int, float]\n    frames_per_second: float | None = None\n    total_frames: int | None = None\n    
avg_frame_time_ms: float | None = None\n\n    @staticmethod\n    def from_req_perf_record(\n        record: RequestPerfRecord, step_fractions: Sequence[float]\n    ):\n        \"\"\"Collect all performance metrics into a summary without validation.\"\"\"\n        e2e_ms = record.total_duration_ms\n\n        step_durations = record.steps\n        avg_denoise = 0.0\n        median_denoise = 0.0\n        if step_durations:\n            avg_denoise = sum(step_durations) / len(step_durations)\n            median_denoise = statistics.median(step_durations)\n\n        per_step = {index: s for index, s in enumerate(step_durations)}\n        sample_indices = sample_step_indices(per_step, step_fractions)\n        sampled_steps = {idx: per_step[idx] for idx in sample_indices}\n\n        # convert from list to dict\n        stage_metrics = {}\n        for item in record.stages:\n            if isinstance(item, dict) and \"name\" in item:\n                val = item.get(\"execution_time_ms\", 0.0)\n                stage_metrics[item[\"name\"]] = val\n\n        return PerformanceSummary(\n            e2e_ms=e2e_ms,\n            avg_denoise_ms=avg_denoise,\n            median_denoise_ms=median_denoise,\n            stage_metrics=stage_metrics,\n            step_metrics=step_durations,\n            sampled_steps=sampled_steps,\n            all_denoise_steps=per_step,\n        )\n\n\nT2I_sampling_params = DiffusionSamplingParams(\n    prompt=\"Doraemon is eating dorayaki\",\n    output_size=\"1024x1024\",\n)\n\nTI2I_sampling_params = DiffusionSamplingParams(\n    prompt=\"Convert 2D style to 3D style\",\n    image_path=\"https://github.com/lm-sys/lm-sys.github.io/releases/download/test/TI2I_Qwen_Image_Edit_Input.jpg\",\n)\n\nMULTI_IMAGE_TI2I_sampling_params = DiffusionSamplingParams(\n    prompt=\"The magician bear is on the left, the alchemist bear is on the right, facing each other in the central park square.\",\n    image_path=[\n        
\"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_1.jpg\",\n        \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_2.jpg\",\n    ],\n    direct_url_test=True,\n)\nMULTI_IMAGE_TI2I_UPLOAD_sampling_params = DiffusionSamplingParams(\n    prompt=\"The magician bear is on the left, the alchemist bear is on the right, facing each other in the central park square.\",\n    image_path=[\n        \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_1.jpg\",\n        \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_2.jpg\",\n    ],\n)\nMULTI_FRAME_I2I_sampling_params = DiffusionSamplingParams(\n    prompt=\"a high quality, cute halloween themed illustration, consistent style and lighting\",\n    image_path=[\n        \"https://raw.githubusercontent.com/QwenLM/Qwen-Image-Layered/main/assets/test_images/4.png\"\n    ],\n    num_frames=4,\n    direct_url_test=True,\n    output_format=\"png\",\n)\n\nT2V_PROMPT = \"A curious raccoon\"\n\nTI2V_sampling_params = DiffusionSamplingParams(\n    prompt=\"The man in the picture slowly turns his head, his expression enigmatic and otherworldly. The camera performs a slow, cinematic dolly out, focusing on his face. Moody lighting, neon signs glowing in the background, shallow depth of field.\",\n    image_path=\"https://is1-ssl.mzstatic.com/image/thumb/Music114/v4/5f/fa/56/5ffa56c2-ea1f-7a17-6bad-192ff9b6476d/825646124206.jpg/600x600bb.jpg\",\n    direct_url_test=True,\n)\n\nTURBOWAN_I2V_sampling_params = DiffusionSamplingParams(\n    prompt=\"The man in the picture slowly turns his head, his expression enigmatic and otherworldly. The camera performs a slow, cinematic dolly out, focusing on his face. 
Moody lighting, neon signs glowing in the background, shallow depth of field.\",\n    image_path=\"https://is1-ssl.mzstatic.com/image/thumb/Music114/v4/5f/fa/56/5ffa56c2-ea1f-7a17-6bad-192ff9b6476d/825646124206.jpg/600x600bb.jpg\",\n    direct_url_test=True,\n    output_size=\"960x960\",\n    num_frames=4,\n    fps=4,\n)\n\n# All test cases with clean default values\n# To test different models, simply add more DiffusionTestCase entries\nONE_GPU_CASES_A: list[DiffusionTestCase] = [\n    # === Text to Image (T2I) ===\n    DiffusionTestCase(\n        \"qwen_image_t2i\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n        ),\n        T2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"qwen_image_t2i_cache_dit_enabled\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n            enable_cache_dit=True,\n        ),\n        T2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"flux_image_t2i\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_FLUX_1_DEV_MODEL_NAME_FOR_TEST, modality=\"image\"\n        ),\n        T2I_sampling_params,\n    ),\n    # TODO: our FLUX modeling differs from the official implementation, so these weights can't be loaded\n    # consider opting for a different quantized HF repo\n    # DiffusionTestCase(\n    #     \"flux_image_t2i_override_transformer_weights_path_fp8\",\n    #     DiffusionServerArgs(\n    #         model_path=\"black-forest-labs/FLUX.1-dev\", modality=\"image\",\n    #         extras=[\"--transformer-weights-path black-forest-labs/FLUX.1-dev-FP8\"]\n    #     ),\n    #     T2I_sampling_params,\n    # ),\n    DiffusionTestCase(\n        \"flux_2_image_t2i\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_FLUX_2_DEV_MODEL_NAME_FOR_TEST, modality=\"image\"\n        ),\n        T2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        
\"flux_2_klein_image_t2i\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_FLUX_2_KLEIN_4B_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n        ),\n        T2I_sampling_params,\n    ),\n    # TODO: replace with a faster model to test --dit-layerwise-offload\n    # TODO: currently, we don't support sending more than one request in a test, and setting `num_outputs_per_prompt` to 2 doesn't guarantee that denoising is executed twice,\n    # so we do one warmup and send one request instead\n    DiffusionTestCase(\n        \"layerwise_offload\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n            dit_layerwise_offload=True,\n            dit_offload_prefetch_size=2,\n        ),\n        T2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"zimage_image_t2i\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, modality=\"image\"\n        ),\n        T2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"zimage_image_t2i_fp8\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n            extras=[\"--transformer-path MickJ/Z-Image-Turbo-fp8\"],\n        ),\n        T2I_sampling_params,\n    ),\n    # Multi-LoRA test case for Z-Image-Turbo\n    DiffusionTestCase(\n        \"zimage_image_t2i_multi_lora\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n            lora_path=\"reverentelusarca/elusarca-anime-style-lora-z-image-turbo\",\n            second_lora_path=\"tarn59/pixel_art_style_lora_z_image_turbo\",\n        ),\n        T2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"sana_image_t2i\",\n        DiffusionServerArgs(\n            model_path=\"Efficient-Large-Model/Sana_600M_1024px_diffusers\",\n            modality=\"image\",\n        ),\n      
  T2I_sampling_params,\n        run_perf_check=False,\n    ),\n    # === Text and Image to Image (TI2I) ===\n    DiffusionTestCase(\n        \"qwen_image_edit_ti2i\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_QWEN_IMAGE_EDIT_MODEL_NAME_FOR_TEST, modality=\"image\"\n        ),\n        TI2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"qwen_image_edit_2509_ti2i\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_QWEN_IMAGE_EDIT_2509_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n        ),\n        MULTI_IMAGE_TI2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"qwen_image_edit_2511_ti2i\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_QWEN_IMAGE_EDIT_2511_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n        ),\n        TI2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"qwen_image_layered_i2i\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_QWEN_IMAGE_LAYERED_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n        ),\n        MULTI_FRAME_I2I_sampling_params,\n    ),\n    # Upscaling (Real-ESRGAN 4×) for T2I\n    DiffusionTestCase(\n        \"flux_2_image_t2i_upscaling_4x\",\n        DiffusionServerArgs(\n            model_path=\"black-forest-labs/FLUX.2-dev\",\n            modality=\"image\",\n        ),\n        DiffusionSamplingParams(\n            prompt=\"Doraemon is eating dorayaki\",\n            output_size=\"1024x1024\",\n            extras={\"enable_upscaling\": True, \"upscaling_scale\": 4},\n        ),\n    ),\n]\n\nHUNYUAN3D_SHAPE_sampling_params = DiffusionSamplingParams(\n    prompt=\"\",\n    image_path=\"https://raw.githubusercontent.com/sgl-project/sgl-test-files/main/diffusion-ci/consistency_gt/1-gpu/hunyuan3d_2_0/hunyuan3d.png\",\n)\n\nONE_GPU_CASES_B: list[DiffusionTestCase] = [\n    # === Text to Video (T2V) ===\n    DiffusionTestCase(\n        \"wan2_1_t2v_1.3b\",\n        DiffusionServerArgs(\n            
model_path=DEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            custom_validator=\"video\",\n        ),\n        DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n        ),\n    ),\n    DiffusionTestCase(\n        \"wan2_1_t2v_1.3b_text_encoder_cpu_offload\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            custom_validator=\"video\",\n            text_encoder_cpu_offload=True,\n        ),\n        DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n        ),\n    ),\n    # TeaCache acceleration test for Wan video model\n    DiffusionTestCase(\n        \"wan2_1_t2v_1.3b_teacache_enabled\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            custom_validator=\"video\",\n        ),\n        DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n            extras={\"enable_teacache\": True},\n        ),\n    ),\n    # Frame interpolation (2× / exp=1)\n    # Uses the same 1.3B model already in the suite.\n    DiffusionTestCase(\n        \"wan2_1_t2v_1.3b_frame_interp_2x\",\n        DiffusionServerArgs(\n            model_path=\"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\",\n            modality=\"video\",\n            custom_validator=\"video\",\n        ),\n        DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n            extras={\"enable_frame_interpolation\": True, \"frame_interpolation_exp\": 1},\n        ),\n    ),\n    # Upscaling (Real-ESRGAN 4×)\n    # Uses the same 1.3B model already in the suite.\n    DiffusionTestCase(\n        \"wan2_1_t2v_1.3b_upscaling_4x\",\n        DiffusionServerArgs(\n            model_path=\"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\",\n            modality=\"video\",\n            custom_validator=\"video\",\n        ),\n        DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n          
  extras={\"enable_upscaling\": True, \"upscaling_scale\": 4},\n        ),\n    ),\n    # Combined: Frame interpolation (2×) + Upscaling (4×)\n    # Verifies that both post-processing steps compose correctly.\n    DiffusionTestCase(\n        \"wan2_1_t2v_1.3b_frame_interp_2x_upscaling_4x\",\n        DiffusionServerArgs(\n            model_path=\"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\",\n            modality=\"video\",\n            custom_validator=\"video\",\n        ),\n        DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n            extras={\n                \"enable_frame_interpolation\": True,\n                \"frame_interpolation_exp\": 1,\n                \"enable_upscaling\": True,\n                \"upscaling_scale\": 4,\n            },\n        ),\n    ),\n    # LoRA test case for single transformer + merge/unmerge API test\n    # Note: Uses dynamic_lora_path instead of lora_path to test LayerwiseOffload + set_lora interaction\n    # Server starts WITHOUT LoRA, then set_lora is called after startup (Wan models auto-enable layerwise offload)\n    DiffusionTestCase(\n        \"wan2_1_t2v_1_3b_lora_1gpu\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            custom_validator=\"video\",\n            num_gpus=1,\n            dynamic_lora_path=\"Cseti/Wan-LoRA-Arcane-Jinx-v1\",\n        ),\n        DiffusionSamplingParams(\n            prompt=\"csetiarcane Nfj1nx with blue hair, a woman walking in a cyberpunk city at night\",\n        ),\n    ),\n    # NOTE(mick): flaky\n    # DiffusionTestCase(\n    #     \"hunyuan_video\",\n    #     DiffusionServerArgs(\n    #         model_path=\"hunyuanvideo-community/HunyuanVideo\",\n    #         modality=\"video\",\n    #     ),\n    #     DiffusionSamplingParams(\n    #         prompt=T2V_PROMPT,\n    #     ),\n    # ),\n    DiffusionTestCase(\n        \"flux_2_ti2i\",\n        DiffusionServerArgs(\n            
model_path=DEFAULT_FLUX_2_DEV_MODEL_NAME_FOR_TEST, modality=\"image\"\n        ),\n        TI2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"flux_2_t2i_customized_vae_path\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_FLUX_2_DEV_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n            extras=[\"--vae-path=fal/FLUX.2-Tiny-AutoEncoder\"],\n        ),\n        T2I_sampling_params,\n        run_perf_check=False,\n    ),\n    DiffusionTestCase(\n        \"fast_hunyuan_video\",\n        DiffusionServerArgs(\n            model_path=\"FastVideo/FastHunyuan-diffusers\",\n            modality=\"video\",\n            custom_validator=\"video\",\n        ),\n        DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n        ),\n    ),\n    # === Text and Image to Video (TI2V) ===\n    DiffusionTestCase(\n        \"wan2_2_ti2v_5b\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_WAN_2_2_TI2V_5B_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            custom_validator=\"video\",\n        ),\n        TI2V_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"fastwan2_2_ti2v_5b\",\n        DiffusionServerArgs(\n            model_path=\"FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers\",\n            modality=\"video\",\n            custom_validator=\"video\",\n        ),\n        TI2V_sampling_params,\n    ),\n    # === Helios T2V ===\n    DiffusionTestCase(\n        \"helios_base_t2v\",\n        DiffusionServerArgs(\n            model_path=\"BestWishYsh/Helios-Base\",\n            modality=\"video\",\n        ),\n        DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n            output_size=\"640x384\",\n            num_frames=33,\n        ),\n    ),\n    DiffusionTestCase(\n        \"helios_mid_t2v\",\n        DiffusionServerArgs(\n            model_path=\"BestWishYsh/Helios-Mid\",\n            modality=\"video\",\n        ),\n        DiffusionSamplingParams(\n            
prompt=T2V_PROMPT,\n            output_size=\"640x384\",\n            num_frames=33,\n        ),\n    ),\n    DiffusionTestCase(\n        \"helios_distilled_t2v\",\n        DiffusionServerArgs(\n            model_path=\"BestWishYsh/Helios-Distilled\",\n            modality=\"video\",\n        ),\n        DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n            output_size=\"640x384\",\n            num_frames=33,\n        ),\n    ),\n]\n\n# Skip hunyuan3d on AMD: marching_cubes surface extraction produces invalid SDF on ROCm.\nif not current_platform.is_hip():\n    ONE_GPU_CASES_B.append(\n        DiffusionTestCase(\n            \"hunyuan3d_shape_gen\",\n            DiffusionServerArgs(\n                model_path=\"tencent/Hunyuan3D-2\",\n                modality=\"3d\",\n            ),\n            HUNYUAN3D_SHAPE_sampling_params,\n        ),\n    )\n\nTWO_GPU_CASES_A = [\n    DiffusionTestCase(\n        \"wan2_2_i2v_a14b_2gpu\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_WAN_2_2_I2V_A14B_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            custom_validator=\"video\",\n        ),\n        TI2V_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"wan2_2_t2v_a14b_2gpu\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_WAN_2_2_T2V_A14B_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            custom_validator=\"video\",\n            num_gpus=2,\n        ),\n        
DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n        ),\n    ),\n    # TeaCache smoke test for Wan2.2 T2V A14B — verifies enable_teacache=True\n    # doesn't crash. Perf check disabled because Wan2.2-specific TeaCache\n    # coefficients are not yet calibrated (teacache_params=None, so no speedup).\n    DiffusionTestCase(\n        \"wan2_2_t2v_a14b_teacache_2gpu\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_WAN_2_2_T2V_A14B_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            custom_validator=\"video\",\n            num_gpus=2,\n        ),\n        DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n            extras={\"enable_teacache\": True},\n        ),\n        run_perf_check=False,\n    ),\n    # LoRA test case for transformer_2 support\n    DiffusionTestCase(\n        \"wan2_2_t2v_a14b_lora_2gpu\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_WAN_2_2_T2V_A14B_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            custom_validator=\"video\",\n            num_gpus=2,\n            lora_path=\"Cseti/wan2.2-14B-Arcane_Jinx-lora-v1\",\n        ),\n        DiffusionSamplingParams(\n            prompt=\"Nfj1nx with blue hair, a woman walking in a cyberpunk city at night\",\n        ),\n    ),\n    DiffusionTestCase(\n        \"wan2_1_t2v_14b_2gpu\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_WAN_2_1_T2V_14B_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            num_gpus=2,\n            custom_validator=\"video\",\n        ),\n        DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n            output_size=\"832x480\",\n        ),\n    ),\n    DiffusionTestCase(\n        \"wan2_1_t2v_1.3b_cfg_parallel\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            custom_validator=\"video\",\n            num_gpus=2,\n            cfg_parallel=True,\n        ),\n    
    DiffusionSamplingParams(\n            prompt=T2V_PROMPT,\n        ),\n    ),\n    DiffusionTestCase(\n        \"fsdp-inference\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n            num_gpus=2,\n            extras=[\"--use-fsdp-inference\"],\n        ),\n        T2I_sampling_params,\n    ),\n]\n\nTWO_GPU_CASES_B = [\n    DiffusionTestCase(\n        \"wan2_1_i2v_14b_480P_2gpu\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_WAN_2_1_I2V_14B_480P_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            custom_validator=\"video\",\n            num_gpus=2,\n        ),\n        TI2V_sampling_params,\n    ),\n    # I2V LoRA test case\n    DiffusionTestCase(\n        \"wan2_1_i2v_14b_lora_2gpu\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_WAN_2_1_I2V_14B_720P_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            custom_validator=\"video\",\n            num_gpus=2,\n            lora_path=\"starsfriday/Wan2.1-Divine-Power-LoRA\",\n        ),\n        TI2V_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"wan2_1_i2v_14b_720P_2gpu\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_WAN_2_1_I2V_14B_720P_MODEL_NAME_FOR_TEST,\n            modality=\"video\",\n            custom_validator=\"video\",\n            num_gpus=2,\n        ),\n        TI2V_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"qwen_image_t2i_2_gpus\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n            num_gpus=2,\n            # test ring attn\n            ulysses_degree=1,\n            ring_degree=2,\n        ),\n        T2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"zimage_image_t2i_2_gpus\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n            
num_gpus=2,\n            ulysses_degree=2,\n        ),\n        T2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"flux_image_t2i_2_gpus\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_FLUX_1_DEV_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n            num_gpus=2,\n        ),\n        T2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"flux_2_image_t2i_2_gpus\",\n        DiffusionServerArgs(\n            model_path=DEFAULT_FLUX_2_DEV_MODEL_NAME_FOR_TEST,\n            modality=\"image\",\n            num_gpus=2,\n            tp_size=2,\n        ),\n        T2I_sampling_params,\n    ),\n    DiffusionTestCase(\n        \"flux_2_klein_ti2i_2_gpus\",\n        DiffusionServerArgs(\n            model_path=\"black-forest-labs/FLUX.2-klein-4B\",\n            modality=\"image\",\n            num_gpus=2,\n        ),\n        TI2I_sampling_params,\n    ),\n]\n\nif not current_platform.is_hip():\n    # Flux2 multi-image edit with cache-dit, regression test\n    ONE_GPU_CASES_B.append(\n        DiffusionTestCase(\n            \"flux_2_ti2i_multi_image_cache_dit\",\n            DiffusionServerArgs(\n                model_path=\"black-forest-labs/FLUX.2-dev\",\n                modality=\"image\",\n                enable_cache_dit=True,\n            ),\n            MULTI_IMAGE_TI2I_UPLOAD_sampling_params,\n        )\n    )\n    # Skip turbowan on AMD: Triton requires 81920 bytes of shared memory, but AMD GPUs only have 65536.\n    ONE_GPU_CASES_B.append(\n        DiffusionTestCase(\n            \"turbo_wan2_1_t2v_1.3b\",\n            DiffusionServerArgs(\n                model_path=\"IPostYellow/TurboWan2.1-T2V-1.3B-Diffusers\",\n                modality=\"video\",\n                custom_validator=\"video\",\n            ),\n            DiffusionSamplingParams(\n                prompt=T2V_PROMPT,\n            ),\n        )\n    )\n    # Skip turbowan on AMD: Triton requires 81920 bytes of shared memory, but AMD GPUs only have 65536.\n    
TWO_GPU_CASES_A.append(\n        DiffusionTestCase(\n            \"turbo_wan2_2_i2v_a14b_2gpu\",\n            DiffusionServerArgs(\n                model_path=\"IPostYellow/TurboWan2.2-I2V-A14B-Diffusers\",\n                modality=\"video\",\n                custom_validator=\"video\",\n                num_gpus=2,\n                tp_size=2,\n            ),\n            TURBOWAN_I2V_sampling_params,\n        )\n    )\n\n# Load global configuration\nBASELINE_CONFIG = BaselineConfig.load(\n    Path(__file__).with_name(\"perf_baselines.json\")\n).update(Path(__file__).parent / \"ascend\" / \"perf_baselines_npu.json\")\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/slack_utils.py",
    "content": "\"\"\"\nUpload media generated by the diffusion nightly test to the SGLang Slack channel.\n\"\"\"\n\nimport logging\nimport os\nimport tempfile\nfrom datetime import datetime\nfrom typing import List, Union\nfrom urllib.parse import urlparse\nfrom urllib.request import urlopen\n\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import get_git_commit_hash\n\nlogging.basicConfig(level=logging.INFO)\nlogger = logging.getLogger(__name__)\n\nimport inspect\n\ntry:\n    import sglang.multimodal_gen.test.server.testcase_configs as configs\n    from sglang.multimodal_gen.test.server.testcase_configs import DiffusionTestCase\n\n    ALL_CASES = []\n    for name, value in inspect.getmembers(configs):\n        if name.endswith(\"_CASES\") or \"_CASES_\" in name:\n            if (\n                isinstance(value, list)\n                and len(value) > 0\n                and isinstance(value[0], DiffusionTestCase)\n            ):\n                ALL_CASES.extend(value)\n            elif isinstance(value, list) and len(value) == 0:\n                # Assume empty list with matching name is a valid case list container\n                pass\n\n    # Deduplicate cases by ID\n    seen_ids = set()\n    unique_cases = []\n    for c in ALL_CASES:\n        if c.id not in seen_ids:\n            seen_ids.add(c.id)\n            unique_cases.append(c)\n    ALL_CASES = unique_cases\n\nexcept Exception as e:\n    logger.warning(f\"Failed to import test cases: {e}\")\n    ALL_CASES = []\n\n\ndef _get_status_message(run_id, current_case_id, thread_messages=None):\n    date_str = datetime.now().strftime(\"%d/%m\")\n    base_header = f\"\"\"🧵 for nightly test of {date_str}\n*Git Revision:* {get_git_commit_hash()}\n*GitHub Run ID:* {run_id}\n*Total Tasks:* {len(ALL_CASES)}\n\"\"\"\n\n    if not ALL_CASES:\n        return base_header\n\n    default_emoji_for_case_in_progress = \"⏳\"\n    status_map = {c.id: default_emoji_for_case_in_progress for c in ALL_CASES}\n\n  
  if thread_messages:\n        for msg in thread_messages:\n            text = msg.get(\"text\", \"\")\n            # Look for case_id in the message (format: *Case ID:* `case_id`)\n            for c in ALL_CASES:\n                if f\"*Case ID:* `{c.id}`\" in text:\n                    status_map[c.id] = \"✅\"\n\n    if current_case_id:\n        status_map[current_case_id] = \"✅\"\n\n    lines = [base_header, \"\", \"*Tasks Status:*\"]\n\n    # Calculate padding\n    max_len = max(len(c.id) for c in ALL_CASES) if ALL_CASES else 10\n    max_len = max(max_len, len(\"Case ID\"))\n\n    # Build markdown table inside a code block\n    table_lines = [\"```\"]\n    table_lines.append(f\"| {'Case ID'.ljust(max_len)} | Status |\")\n    table_lines.append(f\"| {'-' * max_len} | :----: |\")\n\n    for c in ALL_CASES:\n        mark = status_map.get(c.id, default_emoji_for_case_in_progress)\n        table_lines.append(f\"| {c.id.ljust(max_len)} |   {mark}   |\")\n\n    table_lines.append(\"```\")\n\n    lines.extend(table_lines)\n\n    return \"\\n\".join(lines)\n\n\ndef upload_file_to_slack(\n    case_id: str = None,\n    model: str = None,\n    prompt: str = None,\n    file_path: str = None,\n    origin_file_path: Union[str, List[str]] = None,\n) -> bool:\n    temp_paths = []\n    try:\n        from slack_sdk import WebClient\n\n        run_id = os.getenv(\"GITHUB_RUN_ID\", \"local\")\n\n        token = os.environ.get(\"SGLANG_DIFFUSION_SLACK_TOKEN\")\n        if not token:\n            logger.info(f\"Slack upload failed: no token\")\n            return False\n\n        if not file_path or not os.path.exists(file_path):\n            logger.info(f\"Slack upload failed: no file path\")\n            return False\n\n        origin_paths = []\n        if isinstance(origin_file_path, str):\n            if origin_file_path:\n                origin_paths.append(origin_file_path)\n        elif isinstance(origin_file_path, list):\n            origin_paths = [p for p in 
origin_file_path if p]\n\n        final_origin_paths = []\n        for path in origin_paths:\n            if path.startswith((\"http\", \"https\")):\n                try:\n                    suffix = os.path.splitext(urlparse(path).path)[1] or \".tmp\"\n                    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:\n                        with urlopen(path) as response:\n                            tf.write(response.read())\n                    temp_paths.append(tf.name)\n                    final_origin_paths.append(tf.name)\n                except Exception as e:\n                    logger.warning(f\"Failed to download {path}: {e}\")\n            else:\n                final_origin_paths.append(path)\n\n        uploads = []\n        for i, path in enumerate(final_origin_paths):\n            if os.path.exists(path):\n                title = (\n                    \"Original Image\"\n                    if len(final_origin_paths) == 1\n                    else f\"Original Image {i+1}\"\n                )\n                uploads.append({\"file\": path, \"title\": title})\n\n        uploads.append({\"file\": file_path, \"title\": \"Generated Image\"})\n\n        message = (\n            f\"*Case ID:* `{case_id}`\\n\" f\"*Model:* `{model}`\\n\" f\"*Prompt:* {prompt}\"\n        )\n\n        client = WebClient(token=token)\n        channel_id = \"C0A02NDF7UY\"\n        thread_ts = None\n\n        parent_msg_text = None\n        try:\n            history = client.conversations_history(channel=channel_id, limit=100)\n            for msg in history.get(\"messages\", []):\n                if f\"*GitHub Run ID:* {run_id}\" in msg.get(\"text\", \"\"):\n                    # Use thread_ts if it exists (msg is a reply), otherwise use ts (msg is a parent)\n                    thread_ts = msg.get(\"thread_ts\") or msg.get(\"ts\")\n                    parent_msg_text = msg.get(\"text\", \"\")\n                    logger.info(f\"Found thread_ts: 
{thread_ts}\")\n                    break\n        except Exception as e:\n            logger.warning(f\"Failed to search slack history: {e}\")\n\n        if not thread_ts:\n            try:\n                text = _get_status_message(run_id, case_id)\n                response = client.chat_postMessage(channel=channel_id, text=text)\n                thread_ts = response[\"ts\"]\n            except Exception as e:\n                logger.warning(f\"Failed to create parent thread: {e}\")\n\n        # Upload first so this case's message is included in the thread replies fetched below\n        client.files_upload_v2(\n            channel=channel_id,\n            file_uploads=uploads,\n            initial_comment=message,\n            thread_ts=thread_ts,\n        )\n\n        # Then update status based on thread replies\n        if thread_ts:\n            try:\n                replies = client.conversations_replies(\n                    channel=channel_id, ts=thread_ts, limit=200\n                )\n                messages = replies.get(\"messages\", [])\n                new_text = _get_status_message(run_id, case_id, messages)\n\n                # Update the parent status message only if its text actually changed\n                if new_text != parent_msg_text:\n                    client.chat_update(channel=channel_id, ts=thread_ts, text=new_text)\n            except Exception as e:\n                logger.warning(f\"Failed to update parent message: {e}\")\n\n        logger.info(f\"File uploaded successfully: {os.path.basename(file_path)}\")\n        return True\n\n    except Exception as e:\n        logger.info(f\"Slack upload failed: {e}\")\n        return False\n    finally:\n        for p in temp_paths:\n            if os.path.exists(p):\n                os.remove(p)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/test_files/launch_flux.json",
    "content": "{\n    \"model_path\": \"black-forest-labs/FLUX.1-dev\",\n    \"prompt\": \"A beautiful woman in a red dress walking down a street\",\n    \"text_encoder_cpu_offload\": true,\n    \"pin_cpu_memory\": true,\n    \"save_output\": true,\n    \"width\": 720,\n    \"height\": 720,\n    \"output_path\": \"outputs\",\n    \"output_file_name\": \"FLUX.1-dev, single gpu\"\n}\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/test_files/launch_wan.json",
    "content": "{\n    \"model_path\": \"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\",\n    \"prompt\": \"A beautiful woman in a red dress walking down a street\",\n    \"text_encoder_cpu_offload\": true,\n    \"pin_cpu_memory\": true,\n    \"save_output\": true,\n    \"width\": 720,\n    \"height\": 720,\n    \"output_path\": \"outputs\",\n    \"output_file_name\": \"Wan2.1-T2V-1.3B-Diffusers, single gpu\"\n}\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/test_utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\nimport base64\nimport io\nimport json\nimport os\nimport socket\nimport subprocess\nimport tempfile\nimport time\nfrom pathlib import Path\nfrom urllib.parse import urljoin\n\nimport cv2\nimport httpx\nimport numpy as np\nfrom PIL import Image\n\nfrom sglang.multimodal_gen.runtime.utils.common import get_bool_env_var\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_logger\nfrom sglang.multimodal_gen.runtime.utils.perf_logger import (\n    RequestPerfRecord,\n    get_diffusion_perf_log_dir,\n)\n\nlogger = init_logger(__name__)\n\n# ---------------------------------------------------------------------------\n# Common model IDs for diffusion tests\n#\n# Centralised here so every test file references the same constants instead\n# of scattering hard-coded strings. When adding a new model that will be\n# reused across tests, define it here.\n# ---------------------------------------------------------------------------\n\nDEFAULT_SMALL_MODEL_NAME_FOR_TEST = \"Tongyi-MAI/Z-Image-Turbo\"\n\n# Qwen image generation models\nDEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST = \"Qwen/Qwen-Image\"\nDEFAULT_QWEN_IMAGE_2512_MODEL_NAME_FOR_TEST = \"Qwen/Qwen-Image-2512\"\nDEFAULT_QWEN_IMAGE_EDIT_MODEL_NAME_FOR_TEST = \"Qwen/Qwen-Image-Edit\"\nDEFAULT_QWEN_IMAGE_EDIT_2509_MODEL_NAME_FOR_TEST = \"Qwen/Qwen-Image-Edit-2509\"\nDEFAULT_QWEN_IMAGE_EDIT_2511_MODEL_NAME_FOR_TEST = \"Qwen/Qwen-Image-Edit-2511\"\nDEFAULT_QWEN_IMAGE_LAYERED_MODEL_NAME_FOR_TEST = \"Qwen/Qwen-Image-Layered\"\n\n# FLUX image generation models\nDEFAULT_FLUX_1_DEV_MODEL_NAME_FOR_TEST = \"black-forest-labs/FLUX.1-dev\"\nDEFAULT_FLUX_2_DEV_MODEL_NAME_FOR_TEST = \"black-forest-labs/FLUX.2-dev\"\nDEFAULT_FLUX_2_KLEIN_4B_MODEL_NAME_FOR_TEST = \"black-forest-labs/FLUX.2-klein-4B\"\nDEFAULT_FLUX_2_KLEIN_BASE_4B_MODEL_NAME_FOR_TEST = (\n    \"black-forest-labs/FLUX.2-klein-base-4B\"\n)\n\n# Wan video generation 
models\nDEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST = \"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\"\nDEFAULT_WAN_2_1_T2V_14B_MODEL_NAME_FOR_TEST = \"Wan-AI/Wan2.1-T2V-14B-Diffusers\"\nDEFAULT_WAN_2_1_I2V_14B_480P_MODEL_NAME_FOR_TEST = (\n    \"Wan-AI/Wan2.1-I2V-14B-480P-Diffusers\"\n)\nDEFAULT_WAN_2_1_I2V_14B_720P_MODEL_NAME_FOR_TEST = (\n    \"Wan-AI/Wan2.1-I2V-14B-720P-Diffusers\"\n)\nDEFAULT_WAN_2_2_TI2V_5B_MODEL_NAME_FOR_TEST = \"Wan-AI/Wan2.2-TI2V-5B-Diffusers\"\nDEFAULT_WAN_2_2_T2V_A14B_MODEL_NAME_FOR_TEST = \"Wan-AI/Wan2.2-T2V-A14B-Diffusers\"\nDEFAULT_WAN_2_2_I2V_A14B_MODEL_NAME_FOR_TEST = \"Wan-AI/Wan2.2-I2V-A14B-Diffusers\"\n\n\ndef print_value_formatted(description: str, value: int | float | str):\n    \"\"\"Helper function to print a metric value formatted.\"\"\"\n    if isinstance(value, int):\n        if value >= 1e6:\n            value_str = f\"{value / 1e6:<30.2f}M\"\n        elif value >= 1e3:\n            value_str = f\"{value / 1e3:<30.2f}K\"\n        else:\n            value_str = f\"{value:<30}\"\n    elif isinstance(value, float):\n        value_str = f\"{value:<30.2f}\"\n    else:\n        value_str = f\"{value:<30}\"\n\n    print(f\"{description:<45} {value_str}\")\n\n\ndef print_divider(length: int, char: str = \"-\"):\n    \"\"\"Helper function to print a divider line.\"\"\"\n    print(char * length)\n\n\ndef is_image_url(image_path: str | Path | None) -> bool:\n    \"\"\"Check if image_path is a URL.\"\"\"\n    if image_path is None:\n        return False\n    return isinstance(image_path, str) and (\n        image_path.startswith(\"http://\") or image_path.startswith(\"https://\")\n    )\n\n\ndef probe_port(host=\"127.0.0.1\", port=30010, timeout=2.0) -> bool:\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.settimeout(timeout)\n        try:\n            s.connect((host, port))\n            return True\n        except OSError:\n            return False\n\n\ndef is_in_ci() -> bool:\n    return 
get_bool_env_var(\"SGLANG_IS_IN_CI\")\n\n\ndef get_dynamic_server_port() -> int:\n    cuda_devices = os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"0\")\n    if not cuda_devices:\n        cuda_devices = \"0\"\n    try:\n        # Parse the whole first entry so multi-digit indices (e.g. \"10\") are not truncated\n        first_device_id = int(cuda_devices.split(\",\")[0].strip())\n    except (ValueError, IndexError):\n        first_device_id = 0\n\n    if is_in_ci():\n        base_port = 10000 + first_device_id * 2000\n    else:\n        base_port = 20000 + first_device_id * 1000\n\n    return base_port + 1000\n\n\ndef find_free_port(host: str = \"127.0.0.1\") -> int:\n    \"\"\"Bind to port 0 and let the OS assign an available port.\"\"\"\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.bind((host, 0))\n        return s.getsockname()[1]\n\n\ndef wait_for_server_health(\n    base_url: str,\n    path: str = \"/health\",\n    timeout: float = 180.0,\n    interval: float = 1.0,\n) -> None:\n    \"\"\"Poll ``GET <base_url><path>`` until it returns HTTP 200.\"\"\"\n    deadline = time.time() + timeout\n    last_err: httpx.RequestError | None = None\n    last_status: int | None = None\n    while time.time() < deadline:\n        try:\n            r = httpx.get(urljoin(base_url, path), timeout=5.0)\n            last_status = r.status_code\n            if r.status_code == 200:\n                return\n        except httpx.RequestError as e:\n            last_err = e\n        time.sleep(interval)\n    raise TimeoutError(\n        f\"Server at {urljoin(base_url, path)} not healthy after {timeout}s. 
\"\n        f\"{last_status=} {last_err=}\"\n    )\n\n\ndef post_json(\n    base_url: str,\n    path: str,\n    payload: dict,\n    timeout: float = 300.0,\n) -> httpx.Response:\n    \"\"\"POST JSON to ``<base_url><path>`` and return the response.\"\"\"\n    return httpx.post(urljoin(base_url, path), json=payload, timeout=timeout)\n\n\n# ---------------------------------------------------------------------------\n# GPU memory helpers (nvidia-smi)\n# ---------------------------------------------------------------------------\n\n\ndef query_gpu_mem_used_mib(gpu_index: int = 0, required: bool = False) -> int | None:\n    \"\"\"Return GPU memory usage in MiB via ``nvidia-smi``, or *None* on failure.\n\n    When *required* is ``True`` the function raises instead of returning ``None``.\n    \"\"\"\n    try:\n        out = subprocess.check_output(\n            [\n                \"nvidia-smi\",\n                f\"--id={gpu_index}\",\n                \"--query-gpu=memory.used\",\n                \"--format=csv,noheader,nounits\",\n            ],\n            text=True,\n        ).strip()\n        return int(out.splitlines()[0].strip())\n    except Exception as e:\n        logger.warning(f\"nvidia-smi memory query failed: {type(e).__name__}: {e}\")\n        assert not required, (\n            \"nvidia-smi memory query is unavailable; \"\n            \"cannot enforce GPU memory assertions.\"\n        )\n        return None\n\n\ndef require_gpu_mem_query(gpu_index: int = 0) -> int:\n    \"\"\"Same as :func:`query_gpu_mem_used_mib` but asserts availability.\n\n    Raises ``AssertionError`` when ``nvidia-smi`` is unavailable instead of\n    returning ``None``, so callers can rely on a valid ``int`` result.\n    \"\"\"\n    mem = query_gpu_mem_used_mib(gpu_index, required=True)\n    assert mem is not None\n    return mem\n\n\ndef assert_gpu_mem_changed(\n    label: str,\n    before_mib: int,\n    after_mib: int,\n    min_delta_mib: int,\n) -> None:\n    \"\"\"Assert that GPU 
memory changed by at least *min_delta_mib* MiB.\"\"\"\n    delta = abs(after_mib - before_mib)\n    logger.debug(\n        f\"[MEM] {label}: before={before_mib} MiB  after={after_mib} MiB  |delta|={delta} MiB\"\n    )\n    assert delta >= min_delta_mib, (\n        f\"GPU memory change too small for '{label}': \"\n        f\"|after-before|={delta} MiB < {min_delta_mib} MiB \"\n        f\"(before={before_mib} MiB, after={after_mib} MiB)\"\n    )\n\n\ndef is_mp4(data: bytes) -> bool:\n    \"\"\"Check if data represents a valid MP4 file by magic bytes.\"\"\"\n    if len(data) < 8:\n        return False\n    return data[4:8] == b\"ftyp\"\n\n\ndef is_jpeg(data: bytes) -> bool:\n    # JPEG files start with: FF D8 FF\n    return data.startswith(b\"\\xff\\xd8\\xff\")\n\n\ndef is_png(data):\n    # PNG files start with: 89 50 4E 47 0D 0A 1A 0A\n    return data.startswith(b\"\\x89PNG\\r\\n\\x1a\\n\")\n\n\ndef is_webp(data: bytes) -> bool:\n    # WebP files start with: RIFF....WEBP\n    return data[:4] == b\"RIFF\" and data[8:12] == b\"WEBP\"\n\n\ndef detect_image_format(data: bytes) -> str:\n    \"\"\"Detect image format from bytes (magic). 
Returns 'png'|'jpeg'|'webp'; default 'png'.\"\"\"\n    if len(data) < 12:\n        return \"png\"\n    if is_png(data):\n        return \"png\"\n    if is_jpeg(data):\n        return \"jpeg\"\n    if is_webp(data):\n        return \"webp\"\n    return \"png\"\n\n\ndef get_expected_image_format(\n    output_format: str | None = None,\n    background: str | None = None,\n) -> str:\n    \"\"\"Infer expected image format based on request parameters.\n    Args:\n        output_format: The output_format parameter from the request (png/jpeg/webp/jpg)\n        background: The background parameter from the request (transparent/opaque/auto)\n    Returns:\n        Expected file extension: \"jpg\", \"png\", or \"webp\"\n    \"\"\"\n    fmt = (output_format or \"\").lower()\n    if fmt in {\"png\", \"webp\", \"jpeg\", \"jpg\"}:\n        return \"jpg\" if fmt == \"jpeg\" else fmt\n    if (background or \"auto\").lower() == \"transparent\":\n        return \"png\"\n    return \"jpg\"  # Default\n\n\ndef wait_for_port(host=\"127.0.0.1\", port=30010, deadline=300.0, interval=0.5):\n    end = time.time() + deadline\n    last_err = None\n    while time.time() < end:\n        if probe_port(host, port, timeout=interval):\n            return True\n        time.sleep(interval)\n    raise TimeoutError(f\"Port {host}:{port} not ready. 
Last error: {last_err}\")\n\n\ndef check_image_size(ut, image, width, height):\n    # check image size\n    ut.assertEqual(image.size, (width, height))\n\n\ndef get_perf_log_dir() -> Path:\n    \"\"\"Gets the performance log directory from the centralized sglang utility.\"\"\"\n    log_dir_str = get_diffusion_perf_log_dir()\n    if not log_dir_str:\n        raise RuntimeError(\n            \"Performance logging is disabled (SGLANG_PERF_LOG_DIR is empty), \"\n            \"but a test tried to access the log directory.\"\n        )\n    return Path(log_dir_str)\n\n\ndef _ensure_log_path(log_dir: Path) -> Path:\n    log_dir.mkdir(parents=True, exist_ok=True)\n    return log_dir / \"performance.log\"\n\n\ndef clear_perf_log(log_dir: Path) -> Path:\n    \"\"\"Delete the perf log file so tests can watch for fresh entries.\"\"\"\n    log_path = _ensure_log_path(log_dir)\n    if log_path.exists():\n        log_path.unlink()\n    logger.info(\"[server-test] Monitoring perf log at %s\", log_path.as_posix())\n    return log_path\n\n\ndef prepare_perf_log() -> tuple[Path, Path]:\n    \"\"\"Convenience helper to resolve and clear the perf log in one call.\"\"\"\n    log_dir = get_perf_log_dir()\n    log_path = clear_perf_log(log_dir)\n    return log_dir, log_path\n\n\ndef read_perf_logs(log_path: Path) -> list[RequestPerfRecord]:\n    if not log_path.exists():\n        return []\n    records: list[RequestPerfRecord] = []\n    with log_path.open(\"r\", encoding=\"utf-8\") as fh:\n        for line in fh:\n            line = line.strip()\n            if not line:\n                continue\n            try:\n                record_dict = json.loads(line)\n                records.append(RequestPerfRecord(**record_dict))\n            except json.JSONDecodeError:\n                continue\n    return records\n\n\ndef wait_for_req_perf_record(\n    request_id: str,\n    log_path: Path,\n    timeout: float = 30.0,\n) -> RequestPerfRecord | None:\n    \"\"\"\n    the stage metrics of 
this request should be in the performance_log file with {request-id}\n    \"\"\"\n    logger.info(f\"Waiting for req perf record with request id: {request_id}\")\n    deadline = time.time() + timeout\n    while time.time() < deadline:\n        records = read_perf_logs(log_path)\n        for record in records:\n            if record.request_id == request_id:\n                return record\n\n        time.sleep(0.5)\n\n    if os.environ.get(\"SGLANG_GEN_BASELINE\", \"0\") == \"1\":\n        return None\n\n    logger.error(f\"record: {records}\")\n    raise AssertionError(f\"Timeout waiting for stage metrics for request {request_id} \")\n\n\ndef validate_image(b64_json: str) -> None:\n    \"\"\"Decode and validate that image is PNG or JPEG.\"\"\"\n    image_bytes = base64.b64decode(b64_json)\n    assert is_png(image_bytes) or is_jpeg(image_bytes), \"Image must be PNG or JPEG\"\n\n\ndef validate_video(b64_json: str) -> None:\n    \"\"\"Decode and validate that video is a valid format.\"\"\"\n    video_bytes = base64.b64decode(b64_json)\n    is_webm = video_bytes[:4] == b\"\\x1a\\x45\\xdf\\xa3\"\n    assert is_mp4(video_bytes) or is_webm, \"Video must be MP4 or WebM\"\n\n\ndef validate_openai_video(video_bytes: bytes) -> None:\n    \"\"\"Validate that video is MP4 or WebM by magic bytes.\"\"\"\n    is_webm = video_bytes.startswith(b\"\\x1a\\x45\\xdf\\xa3\")\n    assert is_mp4(video_bytes) or is_webm, \"Video must be MP4 or WebM\"\n\n\ndef validate_image_file(\n    file_path: str,\n    expected_filename: str,\n    expected_width: int | None = None,\n    expected_height: int | None = None,\n    output_format: str | None = None,\n    background: str | None = None,\n) -> None:\n    \"\"\"Validate image output file: existence, extension, size, filename, format, dimensions.\"\"\"\n    # Infer expected format from request parameters\n    expected_ext = get_expected_image_format(output_format, background)\n\n    # 1. 
File existence\n    assert os.path.exists(file_path), f\"Image file does not exist: {file_path}\"\n\n    # 2. Extension check\n    assert file_path.endswith(\n        f\".{expected_ext}\"\n    ), f\"Expected .{expected_ext} extension, got: {file_path}\"\n\n    # 3. File size > 0\n    file_size = os.path.getsize(file_path)\n    assert file_size > 0, f\"Image file is empty: {file_path}\"\n\n    # 4. Filename validation\n    actual_filename = os.path.basename(file_path)\n    assert (\n        actual_filename == expected_filename\n    ), f\"Filename mismatch: expected '{expected_filename}', got '{actual_filename}'\"\n\n    # 5. Image format validation (magic bytes check based on expected format)\n    with open(file_path, \"rb\") as f:\n        header = f.read(12)  # Read enough bytes for webp detection\n        if expected_ext == \"png\":\n            assert is_png(header), f\"File is not a valid PNG: {file_path}\"\n        elif expected_ext == \"jpg\":\n            assert is_jpeg(header), f\"File is not a valid JPEG: {file_path}\"\n        elif expected_ext == \"webp\":\n            assert is_webp(header), f\"File is not a valid WebP: {file_path}\"\n\n    # 6. 
Image dimension validation (reuse PIL)\n    if expected_width is not None and expected_height is not None:\n        with Image.open(file_path) as img:\n            width, height = img.size\n            assert (\n                width == expected_width\n            ), f\"Width mismatch: expected {expected_width}, got {width}\"\n            assert (\n                height == expected_height\n            ), f\"Height mismatch: expected {expected_height}, got {height}\"\n\n\ndef _get_video_dimensions_from_metadata(\n    cap: cv2.VideoCapture,\n) -> tuple[int, int] | None:\n    \"\"\"Get video dimensions from metadata properties.\n\n    Args:\n        cap: OpenCV VideoCapture object\n\n    Returns:\n        Tuple of (width, height) if successful, None if metadata is invalid\n    \"\"\"\n    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)\n    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)\n\n    if width == 0 or height == 0:\n        return None\n\n    return int(width), int(height)\n\n\ndef _get_video_dimensions_from_frame(cap: cv2.VideoCapture) -> tuple[int, int]:\n    \"\"\"Get video dimensions by reading the first frame.\n\n    Args:\n        cap: OpenCV VideoCapture object\n\n    Returns:\n        Tuple of (width, height)\n\n    \"\"\"\n    ret, frame = cap.read()\n    if not ret or frame is None:\n        raise ValueError(\"Unable to read video frame to get dimensions\")\n\n    # frame.shape is (height, width, channels)\n    height, width = frame.shape[:2]\n    return int(width), int(height)\n\n\ndef get_video_dimensions(file_path: str) -> tuple[int, int]:\n    \"\"\"Get video dimensions (width, height) from a video file.\n\n    Tries to get dimensions from metadata first, falls back to reading first frame.\n\n    Returns:\n        Tuple of (width, height)\n\n    \"\"\"\n    cap = cv2.VideoCapture(file_path)\n    try:\n        # Try to get dimensions from metadata first\n        dimensions = _get_video_dimensions_from_metadata(cap)\n        if dimensions is not None:\n    
        return dimensions\n\n        # Fall back to reading first frame\n        return _get_video_dimensions_from_frame(cap)\n    finally:\n        cap.release()\n\n\ndef get_video_frame_count(file_path: str) -> int:\n    \"\"\"Return the number of frames in a video file using OpenCV.\"\"\"\n    cap = cv2.VideoCapture(file_path)\n    try:\n        count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n        if count > 0:\n            return count\n        # Fallback: count frames manually\n        n = 0\n        while cap.read()[0]:\n            n += 1\n        return n\n    finally:\n        cap.release()\n\n\ndef validate_video_file(\n    file_path: str,\n    expected_filename: str,\n    expected_width: int | None = None,\n    expected_height: int | None = None,\n) -> None:\n    \"\"\"Validate video output file: existence, extension, size, filename, format, dimensions.\"\"\"\n    # 1. File existence\n    assert os.path.exists(file_path), f\"Video file does not exist: {file_path}\"\n\n    # 2. Extension check\n    assert file_path.endswith(\".mp4\"), f\"Expected .mp4 extension, got: {file_path}\"\n\n    # 3. File size > 0\n    file_size = os.path.getsize(file_path)\n    assert file_size > 0, f\"Video file is empty: {file_path}\"\n\n    # 4. Filename validation\n    actual_filename = os.path.basename(file_path)\n    assert (\n        actual_filename == expected_filename\n    ), f\"Filename mismatch: expected '{expected_filename}', got '{actual_filename}'\"\n\n    # 5. Video format validation (reuse is_mp4)\n    with open(file_path, \"rb\") as f:\n        header = f.read(32)\n        assert is_mp4(header), f\"File is not a valid MP4: {file_path}\"\n\n    # 6. 
Video dimension validation (using OpenCV)\n    if expected_width is not None and expected_height is not None:\n        actual_width, actual_height = get_video_dimensions(file_path)\n        assert (\n            actual_width == expected_width\n        ), f\"Video width mismatch: expected {expected_width}, got {actual_width}\"\n        assert (\n            actual_height == expected_height\n        ), f\"Video height mismatch: expected {expected_height}, got {actual_height}\"\n\n\ndef output_format_to_ext(output_format: str | None) -> str:\n    \"\"\"Map output_format to file extension. Used by GT naming and consistency check.\"\"\"\n    if not output_format:\n        return \"png\"\n    of = output_format.lower()\n    if of == \"jpeg\":\n        return \"jpg\"\n    if of in (\"png\", \"webp\", \"jpg\"):\n        return of\n    return \"png\"\n\n\ndef _consistency_gt_filenames(\n    case_id: str, num_gpus: int, is_video: bool, output_format: str | None = None\n) -> list[str]:\n    \"\"\"Return the list of GT image filenames for a case. 
Reused by GT generation and consistency check.\"\"\"\n    n = num_gpus\n    if is_video:\n        return [\n            f\"{case_id}_{n}gpu_frame_0.png\",\n            f\"{case_id}_{n}gpu_frame_mid.png\",\n            f\"{case_id}_{n}gpu_frame_last.png\",\n        ]\n    ext = output_format_to_ext(output_format)\n    return [f\"{case_id}_{n}gpu.{ext}\"]\n\n\ndef extract_key_frames_from_video(\n    video_bytes: bytes,\n    num_frames: int | None = None,\n) -> list[np.ndarray]:\n    \"\"\"\n    Extract key frames (first, middle, last) from video bytes.\n\n    Args:\n        video_bytes: Raw video bytes (MP4 format)\n        num_frames: Total number of frames (if known), used for validation\n\n    Returns:\n        List of numpy arrays [first_frame, middle_frame, last_frame].\n    \"\"\"\n    with tempfile.NamedTemporaryFile(suffix=\".mp4\", delete=False) as tmp:\n        tmp.write(video_bytes)\n        tmp_path = tmp.name\n\n    try:\n        cap = cv2.VideoCapture(tmp_path)\n        if not cap.isOpened():\n            raise ValueError(\"Failed to open video file\")\n\n        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n        if total_frames < 1:\n            raise ValueError(\"Video has no frames\")\n\n        first_idx = 0\n        mid_idx = total_frames // 2\n        last_idx = total_frames - 1\n        key_indices = [first_idx, mid_idx, last_idx]\n\n        frames = []\n        for idx in key_indices:\n            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)\n            ret, frame = cap.read()\n            if not ret:\n                raise ValueError(f\"Failed to read frame at index {idx}\")\n            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n            frames.append(frame_rgb)\n\n        cap.release()\n        logger.info(\n            f\"Extracted {len(frames)} key frames from video \"\n            f\"(total: {total_frames}, indices: {key_indices})\"\n        )\n        return frames\n\n    finally:\n        os.unlink(tmp_path)\n\n\ndef 
image_bytes_to_numpy(image_bytes: bytes) -> np.ndarray:\n    \"\"\"Convert image bytes to numpy array.\"\"\"\n    img = Image.open(io.BytesIO(image_bytes)).convert(\"RGB\")\n    return np.array(img)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/unit/test_lora_format_adapter.py",
    "content": "\"\"\"\ntest_lora_format_adapter.py\n\nSmall regression test for the LoRA format adapter.\n\nIt downloads several public LoRA checkpoints from Hugging Face, runs\nformat detection and normalization, and prints a compact summary table.\n\"\"\"\n\nimport logging\nimport os\nimport tempfile\nfrom typing import Dict, List\n\nimport torch\nfrom huggingface_hub import hf_hub_download\nfrom safetensors.torch import load_file\n\nfrom sglang.multimodal_gen.runtime.pipelines_core.lora_format_adapter import (\n    LoRAFormat,\n    detect_lora_format_from_state_dict,\n    normalize_lora_state_dict,\n)\n\nlogging.basicConfig(level=logging.INFO, force=True)\nlogger = logging.getLogger(\"lora_test\")\n\nROOT_DIR = os.path.join(tempfile.gettempdir(), \"sglang_lora_tests\")\nos.makedirs(ROOT_DIR, exist_ok=True)\n\n\ndef download_lora(\n    repo_id: str,\n    filename: str,\n    local_name: str,\n) -> str:\n    \"\"\"\n    Download a LoRA safetensors file into ROOT_DIR and return its local path.\n    \"\"\"\n    print(f\"=== Downloading LoRA from {repo_id} ({filename}) ===\")\n    path = hf_hub_download(\n        repo_id=repo_id,\n        filename=filename,\n        local_dir=ROOT_DIR,\n        local_dir_use_symlinks=False,\n    )\n    dst = os.path.join(ROOT_DIR, local_name)\n    if os.path.abspath(path) != os.path.abspath(dst):\n        try:\n            import shutil\n\n            shutil.copy2(path, dst)\n        except Exception:\n            dst = path\n    print(f\"Saved to: {dst}\")\n    return dst\n\n\ndef is_diffusers_style_keys(\n    sd: Dict[str, torch.Tensor],\n    debug_name: str = \"\",\n) -> bool:\n    \"\"\"\n    Relaxed structural check that a state_dict looks like diffusers-style LoRA.\n\n    The check verifies:\n    1) No known non-diffusers prefixes.\n    2) No non-diffusers suffixes such as alpha / dora_scale / magnitude vectors.\n    3) Most top-level roots match common diffusers module namespaces.\n    \"\"\"\n    if not sd:\n        
print(f\"[{debug_name}] diffusers-style check: EMPTY state_dict\")\n        return False\n\n    keys: List[str] = list(sd.keys())\n    total = len(keys)\n\n    banned_prefixes = (\n        \"lora_unet_\",\n        \"lora_te_\",\n        \"lora_te1_\",\n        \"lora_te2_\",\n        \"lora_unet_double_blocks_\",\n        \"lora_unet_single_blocks_\",\n    )\n    bad_prefix_keys = [k for k in keys if k.startswith(banned_prefixes)]\n    cond1 = len(bad_prefix_keys) == 0\n\n    banned_suffixes = (\n        \".alpha\",\n        \".dora_scale\",\n        \".lora_magnitude_vector\",\n    )\n    bad_suffix_keys = [k for k in keys if k.endswith(banned_suffixes)]\n    cond2 = len(bad_suffix_keys) == 0\n\n    allowed_roots = {\n        \"unet\",\n        \"text_encoder\",\n        \"text_encoder_2\",\n        \"transformer\",\n        \"prior\",\n        \"image_encoder\",\n        \"vae\",\n        \"diffusion_model\",\n    }\n    root_names = [k.split(\".\", 1)[0] for k in keys]\n    root_ok_count = sum(r in allowed_roots for r in root_names)\n    cond3 = root_ok_count >= 0.6 * total\n\n    ok = cond1 and cond2 and cond3\n\n    if not ok:\n        print(f\"[{debug_name}] diffusers-style check FAILED (relaxed):\")\n        print(f\"  total keys = {total}\")\n        print(\n            f\"  cond1(no banned prefixes)  = {cond1}, bad_prefix_keys={len(bad_prefix_keys)}\"\n        )\n        if not cond1 and bad_prefix_keys:\n            print(\"    example bad prefix key:\", bad_prefix_keys[0])\n        print(\n            f\"  cond2(no banned suffixes)  = {cond2}, bad_suffix_keys={len(bad_suffix_keys)}\"\n        )\n        if not cond2 and bad_suffix_keys:\n            print(\"    example bad suffix key:\", bad_suffix_keys[0])\n        print(f\"  cond3(allowed roots>=60%)  = {cond3}, root_ok_count={root_ok_count}\")\n    return ok\n\n\ndef run_single_test(\n    name: str,\n    repo_id: str,\n    filename: str,\n    local_name: str,\n    expected_before: LoRAFormat,\n    
expected_after: LoRAFormat = LoRAFormat.STANDARD,\n):\n    \"\"\"\n    Run a single end-to-end test for one LoRA checkpoint.\n\n    Steps:\n    1) Download.\n    2) Detect format on raw keys.\n    3) Normalize via lora_format_adapter.\n    4) Detect again on the normalized dict.\n    5) Optionally check for diffusers-style key structure.\n    \"\"\"\n    logger.info(f\"=== Running test: {name} ===\")\n    local_path = download_lora(repo_id, filename, local_name)\n    raw_state = load_file(local_path)\n\n    detected_before = detect_lora_format_from_state_dict(raw_state)\n    norm_state = normalize_lora_state_dict(raw_state, logger=logger)\n    detected_after = detect_lora_format_from_state_dict(norm_state)\n    standard_like = is_diffusers_style_keys(norm_state, debug_name=name)\n\n    passed = detected_before == expected_before and detected_after == expected_after\n\n    return {\n        \"name\": name,\n        \"expected_before\": expected_before.value,\n        \"detected_before\": detected_before.value,\n        \"expected_after\": expected_after.value,\n        \"detected_after\": detected_after.value,\n        \"standard_like_keys\": standard_like,\n        \"pass\": passed,\n        \"num_keys_raw\": len(raw_state),\n        \"num_keys_norm\": len(norm_state),\n    }\n\n\ndef _run_all_tests() -> List[Dict]:\n    results: List[Dict] = []\n\n    # SDXL LoRA that is already in diffusers/PEFT format.\n    results.append(\n        run_single_test(\n            name=\"HF standard SDXL LoRA\",\n            repo_id=\"jbilcke-hf/sdxl-cinematic-1\",\n            filename=\"pytorch_lora_weights.safetensors\",\n            local_name=\"sdxl_cinematic1_pytorch_lora_weights.safetensors\",\n            expected_before=LoRAFormat.STANDARD,\n            expected_after=LoRAFormat.STANDARD,\n        )\n    )\n\n    # XLabs FLUX LoRA (non-diffusers → diffusers).\n    results.append(\n        run_single_test(\n            name=\"XLabs FLUX Realism LoRA\",\n            
repo_id=\"XLabs-AI/flux-RealismLora\",\n            filename=\"lora.safetensors\",\n            local_name=\"flux_realism_lora.safetensors\",\n            expected_before=LoRAFormat.XLABS_FLUX,\n            expected_after=LoRAFormat.STANDARD,\n        )\n    )\n\n    # Kohya-style FLUX LoRA (sd-scripts flux_lora.py → diffusers).\n    results.append(\n        run_single_test(\n            name=\"Kohya-style Flux LoRA\",\n            repo_id=\"kohya-ss/misc-models\",\n            filename=\"flux-hasui-lora-d4-sigmoid-raw-gs1.0.safetensors\",\n            local_name=\"flux_hasui_lora_d4_sigmoid_raw_gs1_0.safetensors\",\n            expected_before=LoRAFormat.KOHYA_FLUX,\n            expected_after=LoRAFormat.STANDARD,\n        )\n    )\n\n    # AI-Toolkit Flux LoRA (non-diffusers → diffusers).\n    results.append(\n        run_single_test(\n            name=\"AI-Toolkit Flux LoRA\",\n            repo_id=\"fal/flux-2-klein-4b-spritesheet-lora\",\n            filename=\"flux-spritesheet-lora.safetensors\",\n            local_name=\"flux_spritesheet_lora.safetensors\",\n            expected_before=LoRAFormat.AI_TOOLKIT_FLUX,\n            expected_after=LoRAFormat.STANDARD,\n        )\n    )\n\n    # Classic Kohya/A1111 SD LoRA (non-diffusers SD → diffusers).\n    results.append(\n        run_single_test(\n            name=\"Kohya-style SD LoRA\",\n            repo_id=\"kohya-ss/misc-models\",\n            filename=\"fp-1f-chibi-1024.safetensors\",\n            local_name=\"fp_1f_chibi_1024.safetensors\",\n            expected_before=LoRAFormat.NON_DIFFUSERS_SD,\n            expected_after=LoRAFormat.STANDARD,\n        )\n    )\n\n    # Wan2.1 Fun Reward LoRA (ComfyUI format → diffusers).\n    results.append(\n        run_single_test(\n            name=\"Wan2.1 Fun Reward LoRA (Comfy)\",\n            repo_id=\"alibaba-pai/Wan2.1-Fun-Reward-LoRAs\",\n            filename=\"Wan2.1-Fun-1.3B-InP-MPS.safetensors\",\n            
local_name=\"wan21_fun_1_3b_inp_mps.safetensors\",\n            expected_before=LoRAFormat.NON_DIFFUSERS_SD,\n            expected_after=LoRAFormat.STANDARD,\n        )\n    )\n\n    # Qwen-Image EVA LoRA (already diffusers/PEFT-style).\n    results.append(\n        run_single_test(\n            name=\"Qwen-Image EVA LoRA\",\n            repo_id=\"starsfriday/Qwen-Image-EVA-LoRA\",\n            filename=\"qwen_image_eva.safetensors\",\n            local_name=\"qwen_image_eva.safetensors\",\n            expected_before=LoRAFormat.STANDARD,\n            expected_after=LoRAFormat.STANDARD,\n        )\n    )\n\n    # Qwen-Image Lightning LoRA (non-diffusers Qwen → diffusers).\n    results.append(\n        run_single_test(\n            name=\"Qwen-Image Lightning LoRA\",\n            repo_id=\"lightx2v/Qwen-Image-Lightning\",\n            filename=\"Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors\",\n            local_name=\"qwen_image_lightning_4steps_v1_bf16.safetensors\",\n            expected_before=LoRAFormat.NON_DIFFUSERS_SD,\n            expected_after=LoRAFormat.STANDARD,\n        )\n    )\n\n    # Classic Painting Z-Image Turbo LoRA (Z-Image family).\n    results.append(\n        run_single_test(\n            name=\"Classic Painting Z-Image LoRA\",\n            repo_id=\"renderartist/Classic-Painting-Z-Image-Turbo-LoRA\",\n            filename=\"Classic_Painting_Z_Image_Turbo_v1_renderartist_1750.safetensors\",\n            local_name=\"classic_painting_z_image_turbo_v1_renderartist_1750.safetensors\",\n            expected_before=LoRAFormat.STANDARD,\n            expected_after=LoRAFormat.STANDARD,\n        )\n    )\n\n    return results\n\n\ndef _print_summary(results: List[Dict]) -> None:\n    print(\"\\n================ LoRA format adapter test ================\")\n\n    header = (\n        f\"{'Test Name':30} \"\n        f\"{'Exp(b)':12} \"\n        f\"{'Act(b)':12} \"\n        f\"{'Exp(a)':12} \"\n        f\"{'Act(a)':12} \"\n        f\"{'StdLike':8} 
\"\n        f\"{'#Raw':7} \"\n        f\"{'#Norm':7} \"\n        f\"{'PASS':5}\"\n    )\n    print(header)\n    print(\"-\" * len(header))\n\n    for r in results:\n        print(\n            f\"{r['name'][:30]:30} \"\n            f\"{r['expected_before'][:12]:12} \"\n            f\"{r['detected_before'][:12]:12} \"\n            f\"{r['expected_after'][:12]:12} \"\n            f\"{r['detected_after'][:12]:12} \"\n            f\"{str(r['standard_like_keys']):8} \"\n            f\"{r['num_keys_raw']:7d} \"\n            f\"{r['num_keys_norm']:7d} \"\n            f\"{str(r['pass']):5}\"\n        )\n\n    print(\"=========================================================\\n\")\n\n\ndef main() -> None:\n    results = _run_all_tests()\n    _print_summary(results)\n\n    if not all(r[\"pass\"] for r in results):\n        raise SystemExit(1)\n\n\nclass TestLoRAFormatAdapter:\n    def test_lora_format_adapter_all_formats(self):\n        results = _run_all_tests()\n        assert all(\n            r[\"pass\"] for r in results\n        ), \"At least one LoRA format adapter case failed\"\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/unit/test_sampling_params.py",
    "content": "import argparse\nimport math\nimport unittest\n\nfrom sglang.multimodal_gen.configs.sample.diffusers_generic import (\n    DiffusersGenericSamplingParams,\n)\nfrom sglang.multimodal_gen.configs.sample.flux import FluxSamplingParams\nfrom sglang.multimodal_gen.configs.sample.qwenimage import QwenImageSamplingParams\nfrom sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams\n\n\nclass TestSamplingParamsValidate(unittest.TestCase):\n    def test_prompt_path_suffix(self):\n        with self.assertRaisesRegex(ValueError, r\"prompt_path\"):\n            SamplingParams(prompt_path=\"bad.png\")\n\n    def test_num_outputs_per_prompt_must_be_positive(self):\n        with self.assertRaisesRegex(ValueError, r\"num_outputs_per_prompt\"):\n            SamplingParams(num_outputs_per_prompt=0)\n\n    def test_fps_must_be_positive_int(self):\n        with self.assertRaisesRegex(ValueError, r\"\\bfps\\b\"):\n            SamplingParams(fps=0)\n        with self.assertRaisesRegex(ValueError, r\"\\bfps\\b\"):\n            SamplingParams(fps=None)  # type: ignore[arg-type]\n\n    def test_num_inference_steps_optional_but_if_set_must_be_positive(self):\n        SamplingParams(num_inference_steps=None)\n        with self.assertRaisesRegex(ValueError, r\"num_inference_steps\"):\n            SamplingParams(num_inference_steps=-1)\n\n    def test_guidance_scale_must_be_finite_non_negative_if_set(self):\n        SamplingParams(guidance_scale=None)\n        with self.assertRaisesRegex(ValueError, r\"guidance_scale\"):\n            SamplingParams(guidance_scale=math.nan)\n        with self.assertRaisesRegex(ValueError, r\"guidance_scale\"):\n            SamplingParams(guidance_scale=-0.1)\n\n    def test_guidance_rescale_must_be_finite_non_negative(self):\n        with self.assertRaisesRegex(ValueError, r\"guidance_rescale\"):\n            SamplingParams(guidance_rescale=-1.0)\n        with self.assertRaisesRegex(ValueError, r\"guidance_rescale\"):\n       
     SamplingParams(guidance_rescale=math.inf)\n\n    def test_boundary_ratio_range(self):\n        SamplingParams(boundary_ratio=None)\n        with self.assertRaisesRegex(ValueError, r\"boundary_ratio\"):\n            SamplingParams(boundary_ratio=1.5)\n        with self.assertRaisesRegex(ValueError, r\"boundary_ratio\"):\n            SamplingParams(boundary_ratio=math.nan)\n\n\nclass TestSamplingParamsSubclass(unittest.TestCase):\n    def test_flux_defaults_resolution_when_not_provided(self):\n        params = FluxSamplingParams()\n\n        self.assertEqual(params.height, 1024)\n        self.assertEqual(params.width, 1024)\n\n    def test_flux_preserves_user_resolution(self):\n        params = FluxSamplingParams(height=640, width=768)\n\n        self.assertEqual(params.height, 640)\n        self.assertEqual(params.width, 768)\n\n    def test_diffusers_generic_calls_base_post_init(self):\n        with self.assertRaises(AssertionError):\n            DiffusersGenericSamplingParams(num_frames=0)\n\n\nclass TestSamplingParamsCliArgs(unittest.TestCase):\n    def _parse_cli_kwargs(self, argv: list[str]) -> dict:\n        parser = argparse.ArgumentParser()\n        SamplingParams.add_cli_args(parser)\n        args = parser.parse_args(argv)\n        return SamplingParams.get_cli_args(args)\n\n    def _make_qwen_image_params(self, argv: list[str]) -> QwenImageSamplingParams:\n        return QwenImageSamplingParams(**self._parse_cli_kwargs(argv))\n\n    def test_get_cli_args_drops_unset_sampling_params(self):\n        self.assertEqual(self._parse_cli_kwargs([]), {})\n\n    def test_get_cli_args_keeps_explicit_sampling_params(self):\n        kwargs = self._parse_cli_kwargs(\n            [\n                \"--guidance-scale\",\n                str(SamplingParams.guidance_scale),\n                \"--negative-prompt\",\n                SamplingParams.negative_prompt,\n                \"--save-output\",\n            ]\n        )\n\n        
self.assertEqual(kwargs[\"guidance_scale\"], SamplingParams.guidance_scale)\n        self.assertEqual(kwargs[\"negative_prompt\"], SamplingParams.negative_prompt)\n        self.assertTrue(kwargs[\"save_output\"])\n\n    def test_qwen_image_cli_path_preserves_model_defaults(self):\n        params = self._make_qwen_image_params([])\n\n        self.assertEqual(params.negative_prompt, \" \")\n        self.assertEqual(params.guidance_scale, 4.0)\n\n    def test_qwen_image_cli_path_allows_explicit_override_to_base_defaults(self):\n        params = self._make_qwen_image_params(\n            [\n                \"--guidance-scale\",\n                str(SamplingParams.guidance_scale),\n                \"--negative-prompt\",\n                SamplingParams.negative_prompt,\n            ]\n        )\n\n        self.assertEqual(params.guidance_scale, SamplingParams.guidance_scale)\n        self.assertEqual(params.negative_prompt, SamplingParams.negative_prompt)\n\n    def test_merge_allows_explicit_field_matching_base_default(self):\n        target = DiffusersGenericSamplingParams()\n        user = SamplingParams(negative_prompt=SamplingParams.negative_prompt)\n\n        target._merge_with_user_params(user, explicit_fields={\"negative_prompt\"})\n\n        self.assertEqual(target.negative_prompt, SamplingParams.negative_prompt)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/unit/test_server_args.py",
    "content": "import os\nimport sys\nimport unittest\nfrom unittest.mock import patch\n\nfrom sglang.multimodal_gen.configs.pipeline_configs.base import PipelineConfig\nfrom sglang.multimodal_gen.configs.pipeline_configs.qwen_image import (\n    QwenImagePipelineConfig,\n)\nfrom sglang.multimodal_gen.registry import _get_config_info\nfrom sglang.multimodal_gen.runtime.server_args import ServerArgs\nfrom sglang.multimodal_gen.utils import FlexibleArgumentParser\n\n\nclass TestServerArgsPathExpansion(unittest.TestCase):\n    def _from_dict_without_model_resolution(self, kwargs):\n        with patch.object(\n            PipelineConfig, \"from_kwargs\", return_value=QwenImagePipelineConfig()\n        ):\n            return ServerArgs.from_dict(kwargs)\n\n    def test_tilde_model_path_is_expanded(self):\n        args = self._from_dict_without_model_resolution(\n            {\"model_path\": \"~/fake/local/model\"}\n        )\n        expected = os.path.expanduser(\"~/fake/local/model\")\n        self.assertEqual(args.model_path, expected)\n        self.assertFalse(args.model_path.startswith(\"~\"))\n\n    def test_absolute_path_is_unchanged(self):\n        args = self._from_dict_without_model_resolution(\n            {\"model_path\": \"/data/my-model\"}\n        )\n        self.assertEqual(args.model_path, \"/data/my-model\")\n\n    def test_component_paths_are_expanded_before_pipeline_resolution(self):\n        args = self._from_dict_without_model_resolution(\n            {\n                \"model_path\": \"/data/my-model\",\n                \"component_paths\": {\"vae\": \"~/fake/local/vae\"},\n            }\n        )\n\n        self.assertEqual(\n            args.component_paths[\"vae\"], os.path.expanduser(\"~/fake/local/vae\")\n        )\n\n\nclass TestModelIdResolution(unittest.TestCase):\n    def setUp(self):\n        _get_config_info.cache_clear()\n\n    def test_model_id_overrides_arbitrary_local_path(self):\n        # a local path whose directory name does 
not match any HF repo name;\n        # --model-id tells the engine which config to use\n        info = _get_config_info(\"/data/my-custom-qwen\", model_id=\"Qwen-Image\")\n        self.assertIsNotNone(info)\n\n        self.assertIs(info.pipeline_config_cls, QwenImagePipelineConfig)\n\n    def test_model_id_works_after_tilde_expansion(self):\n        # simulate the full flow: user passes ~/..., engine expands and resolves\n        expanded = os.path.expanduser(\"~/.cache/huggingface/hub/bbb/snapshots/ccc\")\n        _get_config_info.cache_clear()\n        info = _get_config_info(expanded, model_id=\"Qwen-Image\")\n        self.assertIsNotNone(info)\n\n    def test_model_id_unknown_falls_back_without_crash(self):\n        # unrecognized model_id: should warn and fall back to path-based detection\n        # with an unresolvable path, expect RuntimeError from the detector step\n        with self.assertRaises((RuntimeError, Exception)):\n            _get_config_info(\"/data/no-such-model\", model_id=\"NonExistentModelXYZ\")\n\n\nclass TestPipelineResolutionCliOverride(unittest.TestCase):\n    def setUp(self):\n        _get_config_info.cache_clear()\n\n    def test_resolution_flag_overrides_qwen_image_layered_pipeline_config(self):\n        parser = FlexibleArgumentParser()\n        ServerArgs.add_cli_args(parser)\n        argv = [\n            \"--model-path\",\n            \"Qwen/Qwen-Image-Layered\",\n            \"--resolution\",\n            \"768\",\n        ]\n\n        with patch.object(sys, \"argv\", [\"sglang\"] + argv):\n            args, unknown_args = parser.parse_known_args(argv)\n            server_args = ServerArgs.from_cli_args(args, unknown_args)\n\n        self.assertEqual(server_args.pipeline_config.resolution, 768)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/sglang/multimodal_gen/test/unit/test_storage.py",
    "content": "\"\"\"\nTest suite for S3 CloudStorage integration.\n\nTests verify file upload, cleanup, URL generation, and error handling.\n\"\"\"\n\nimport asyncio\nimport importlib\nimport os\nfrom types import SimpleNamespace\n\nimport pytest\n\nimport sglang.multimodal_gen.runtime.entrypoints.openai.storage as storage_mod\nfrom sglang.multimodal_gen.runtime.entrypoints.openai.storage import CloudStorage\n\n\ndef _create_temp_file(tmp_path, name=\"test.png\", content=b\"\\x89PNG\\r\\n\\x1a\\nfake\"):\n    \"\"\"Create a temporary test file.\"\"\"\n    p = tmp_path / name\n    p.write_bytes(content)\n    return str(p)\n\n\n# UNIT TESTS\n\n\ndef test_upload_file_success(tmp_path):\n    \"\"\"Test successful upload with correct URL generation.\"\"\"\n    file_path = _create_temp_file(tmp_path, \"image.png\")\n\n    storage_mod.cloud_storage.enabled = True\n    storage_mod.cloud_storage.bucket_name = \"my-bucket\"\n    storage_mod.cloud_storage.endpoint_url = \"https://s3.example.com\"\n    storage_mod.cloud_storage.region_name = None\n\n    called = {}\n\n    def fake_upload(local_path, bucket, key, ExtraArgs=None):\n        called[\"local_path\"] = local_path\n        called[\"bucket\"] = bucket\n        called[\"key\"] = key\n        called[\"extra\"] = ExtraArgs\n\n    storage_mod.cloud_storage.client = SimpleNamespace(upload_file=fake_upload)\n\n    url = asyncio.run(storage_mod.cloud_storage.upload_file(file_path, \"image.png\"))\n\n    assert url == \"https://s3.example.com/my-bucket/image.png\"\n    assert called[\"local_path\"] == file_path\n    assert called[\"bucket\"] == \"my-bucket\"\n    assert called[\"key\"] == \"image.png\"\n    assert called[\"extra\"][\"ContentType\"] == \"image/png\"\n\n\ndef test_upload_and_cleanup(tmp_path):\n    \"\"\"Test that local file is deleted after successful upload.\"\"\"\n    file_path = _create_temp_file(tmp_path, \"cleanup.png\")\n\n    storage_mod.cloud_storage.enabled = True\n    
storage_mod.cloud_storage.bucket_name = \"my-bucket\"\n    storage_mod.cloud_storage.endpoint_url = \"https://s3.example.com\"\n    storage_mod.cloud_storage.client = SimpleNamespace(\n        upload_file=lambda *args, **kwargs: None\n    )\n\n    assert os.path.exists(file_path)\n\n    url = asyncio.run(storage_mod.cloud_storage.upload_and_cleanup(file_path))\n\n    assert url == \"https://s3.example.com/my-bucket/cleanup.png\"\n    assert not os.path.exists(file_path)\n\n\ndef test_upload_failure_preserves_file(tmp_path):\n    \"\"\"Test that file is preserved when upload fails.\"\"\"\n    file_path = _create_temp_file(tmp_path, \"preserve.png\")\n\n    storage_mod.cloud_storage.enabled = True\n    storage_mod.cloud_storage.bucket_name = \"my-bucket\"\n    storage_mod.cloud_storage.endpoint_url = \"https://s3.example.com\"\n\n    def fake_upload_raises(*args, **kwargs):\n        raise RuntimeError(\"simulated failure\")\n\n    storage_mod.cloud_storage.client = SimpleNamespace(upload_file=fake_upload_raises)\n\n    result = asyncio.run(storage_mod.cloud_storage.upload_and_cleanup(file_path))\n\n    assert result is None\n    assert os.path.exists(file_path)\n\n\ndef test_disabled_storage_returns_none(tmp_path):\n    \"\"\"Test that disabled storage returns None.\"\"\"\n    file_path = _create_temp_file(tmp_path, \"test.png\")\n\n    prev_enabled = storage_mod.cloud_storage.enabled\n    storage_mod.cloud_storage.enabled = False\n\n    try:\n        result = asyncio.run(\n            storage_mod.cloud_storage.upload_file(file_path, \"test.png\")\n        )\n        assert result is None\n    finally:\n        storage_mod.cloud_storage.enabled = prev_enabled\n\n\ndef test_aws_url_with_region(tmp_path):\n    \"\"\"Test AWS S3 URL generation with specific region.\"\"\"\n    file_path = _create_temp_file(tmp_path, \"aws.png\")\n\n    storage_mod.cloud_storage.enabled = True\n    storage_mod.cloud_storage.bucket_name = \"aws-bucket\"\n    
storage_mod.cloud_storage.endpoint_url = None\n    storage_mod.cloud_storage.region_name = \"us-west-2\"\n    storage_mod.cloud_storage.client = SimpleNamespace(\n        upload_file=lambda *args, **kwargs: None\n    )\n\n    url = asyncio.run(storage_mod.cloud_storage.upload_file(file_path, \"aws.png\"))\n\n    assert url == \"https://aws-bucket.s3.us-west-2.amazonaws.com/aws.png\"\n\n\ndef test_aws_url_default_region(tmp_path):\n    \"\"\"Test AWS S3 URL defaults to us-east-1 when region not specified.\"\"\"\n    file_path = _create_temp_file(tmp_path, \"default.png\")\n\n    storage_mod.cloud_storage.enabled = True\n    storage_mod.cloud_storage.bucket_name = \"default-bucket\"\n    storage_mod.cloud_storage.endpoint_url = None\n    storage_mod.cloud_storage.region_name = None\n    storage_mod.cloud_storage.client = SimpleNamespace(\n        upload_file=lambda *args, **kwargs: None\n    )\n\n    url = asyncio.run(storage_mod.cloud_storage.upload_file(file_path, \"default.png\"))\n\n    assert url == \"https://default-bucket.s3.us-east-1.amazonaws.com/default.png\"\n\n\ndef test_custom_endpoint_url(tmp_path):\n    \"\"\"Test URL generation with custom endpoint (MinIO/OSS/COS).\"\"\"\n    file_path = _create_temp_file(tmp_path, \"custom.png\")\n\n    storage_mod.cloud_storage.enabled = True\n    storage_mod.cloud_storage.bucket_name = \"custom-bucket\"\n    storage_mod.cloud_storage.endpoint_url = \"https://minio.example.com/\"\n    storage_mod.cloud_storage.region_name = None\n    storage_mod.cloud_storage.client = SimpleNamespace(\n        upload_file=lambda *args, **kwargs: None\n    )\n\n    url = asyncio.run(storage_mod.cloud_storage.upload_file(file_path, \"custom.png\"))\n\n    # Verify trailing slash is stripped\n    assert url == \"https://minio.example.com/custom-bucket/custom.png\"\n\n\ndef test_content_type_detection(tmp_path):\n    \"\"\"Test Content-Type header for different file extensions.\"\"\"\n    storage_mod.cloud_storage.enabled = True\n    
storage_mod.cloud_storage.bucket_name = \"test-bucket\"\n    storage_mod.cloud_storage.endpoint_url = \"https://s3.test\"\n\n    test_cases = [\n        (\"image.png\", \"image/png\"),\n        (\"image.jpg\", \"image/jpeg\"),\n        (\"image.jpeg\", \"image/jpeg\"),\n        (\"image.webp\", \"image/webp\"),\n        (\"video.mp4\", \"video/mp4\"),\n        (\"file.bin\", \"application/octet-stream\"),\n    ]\n\n    for filename, expected_type in test_cases:\n        called = {}\n\n        def fake_upload(local_path, bucket, key, ExtraArgs=None):\n            called[\"content_type\"] = ExtraArgs.get(\"ContentType\")\n\n        storage_mod.cloud_storage.client = SimpleNamespace(upload_file=fake_upload)\n\n        file_path = _create_temp_file(tmp_path, filename)\n        asyncio.run(storage_mod.cloud_storage.upload_file(file_path, filename))\n\n        assert called[\"content_type\"] == expected_type\n\n\n# requires moto and boto3\nhas_moto = (\n    importlib.util.find_spec(\"moto\") is not None\n    and importlib.util.find_spec(\"boto3\") is not None\n)\n\n\n@pytest.mark.skipif(not has_moto, reason=\"moto/boto3 not installed\")\ndef test_integration_with_moto(tmp_path):\n    \"\"\"Integration test using moto to mock real S3 service.\"\"\"\n    import boto3\n    from moto import mock_aws\n\n    os.environ[\"SGLANG_CLOUD_STORAGE_TYPE\"] = \"s3\"\n    os.environ[\"SGLANG_S3_BUCKET_NAME\"] = \"integration-test\"\n    os.environ[\"SGLANG_S3_REGION_NAME\"] = \"us-east-1\"\n\n    with mock_aws():\n        s3 = boto3.client(\"s3\", region_name=\"us-east-1\")\n        s3.create_bucket(Bucket=\"integration-test\")\n\n        storage = CloudStorage()\n        assert storage.is_enabled()\n\n        file_path = _create_temp_file(tmp_path, \"integration.png\", b\"test_data\")\n\n        url = asyncio.run(storage.upload_and_cleanup(file_path))\n\n        assert url is not None\n        assert \"integration-test\" in url\n        assert \"integration.png\" in url\n        
assert not os.path.exists(file_path)\n\n        obj = s3.get_object(Bucket=\"integration-test\", Key=\"integration.png\")\n        assert obj[\"Body\"].read() == b\"test_data\"\n\n    for key in [\n        \"SGLANG_CLOUD_STORAGE_TYPE\",\n        \"SGLANG_S3_BUCKET_NAME\",\n        \"SGLANG_S3_REGION_NAME\",\n    ]:\n        os.environ.pop(key, None)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/third_party/__init__.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n"
  },
  {
    "path": "python/sglang/multimodal_gen/third_party/pynvml.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# copied from https://pypi.org/project/nvidia-ml-py\n# version 12.570.86\n\n#####\n# Copyright (c) 2011-2023, NVIDIA Corporation.  All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n#    * Redistributions of source code must retain the above copyright notice,\n#      this list of conditions and the following disclaimer.\n#    * Redistributions in binary form must reproduce the above copyright\n#      notice, this list of conditions and the following disclaimer in the\n#      documentation and/or other materials provided with the distribution.\n#    * Neither the name of the NVIDIA Corporation nor the names of its\n#      contributors may be used to endorse or promote products derived from\n#      this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE\n# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF\n# THE POSSIBILITY OF SUCH DAMAGE.\n#####\n\nimport os\nimport string\nimport sys\nimport threading\n\n##\n# Python bindings for the NVML library\n##\nfrom ctypes import *\nfrom functools import wraps\n\n## C Type mappings ##\n## Enums\n_nvmlEnableState_t = c_uint\nNVML_FEATURE_DISABLED = 0\nNVML_FEATURE_ENABLED = 1\n\n_nvmlBrandType_t = c_uint\nNVML_BRAND_UNKNOWN = 0\nNVML_BRAND_QUADRO = 1\nNVML_BRAND_TESLA = 2\nNVML_BRAND_NVS = 3\nNVML_BRAND_GRID = (\n    4  # Deprecated from API reporting. Keeping definition for backward compatibility.\n)\nNVML_BRAND_GEFORCE = 5\nNVML_BRAND_TITAN = 6\nNVML_BRAND_NVIDIA_VAPPS = 7  # NVIDIA Virtual Applications\nNVML_BRAND_NVIDIA_VPC = 8  # NVIDIA Virtual PC\nNVML_BRAND_NVIDIA_VCS = 9  # NVIDIA Virtual Compute Server\nNVML_BRAND_NVIDIA_VWS = 10  # NVIDIA RTX Virtual Workstation\nNVML_BRAND_NVIDIA_CLOUD_GAMING = 11  # NVIDIA Cloud Gaming\nNVML_BRAND_NVIDIA_VGAMING = NVML_BRAND_NVIDIA_CLOUD_GAMING  # Deprecated from API reporting. 
Keeping definition for backward compatibility.\nNVML_BRAND_QUADRO_RTX = 12\nNVML_BRAND_NVIDIA_RTX = 13\nNVML_BRAND_NVIDIA = 14\nNVML_BRAND_GEFORCE_RTX = 15  # Unused\nNVML_BRAND_TITAN_RTX = 16  # Unused\nNVML_BRAND_COUNT = 17\n\n_nvmlTemperatureThresholds_t = c_uint\nNVML_TEMPERATURE_THRESHOLD_SHUTDOWN = 0\nNVML_TEMPERATURE_THRESHOLD_SLOWDOWN = 1\nNVML_TEMPERATURE_THRESHOLD_MEM_MAX = 2\nNVML_TEMPERATURE_THRESHOLD_GPU_MAX = 3\nNVML_TEMPERATURE_THRESHOLD_ACOUSTIC_MIN = 4\nNVML_TEMPERATURE_THRESHOLD_ACOUSTIC_CURR = 5\nNVML_TEMPERATURE_THRESHOLD_ACOUSTIC_MAX = 6\nNVML_TEMPERATURE_THRESHOLD_GPS_CURR = 7\nNVML_TEMPERATURE_THRESHOLD_COUNT = 8\n\n_nvmlTemperatureSensors_t = c_uint\nNVML_TEMPERATURE_GPU = 0\nNVML_TEMPERATURE_COUNT = 1\n\n\n_nvmlComputeMode_t = c_uint\nNVML_COMPUTEMODE_DEFAULT = 0\nNVML_COMPUTEMODE_EXCLUSIVE_THREAD = 1  ## Support Removed\nNVML_COMPUTEMODE_PROHIBITED = 2\nNVML_COMPUTEMODE_EXCLUSIVE_PROCESS = 3\nNVML_COMPUTEMODE_COUNT = 4\n\n_nvmlMemoryLocation_t = c_uint\nNVML_MEMORY_LOCATION_L1_CACHE = 0\nNVML_MEMORY_LOCATION_L2_CACHE = 1\nNVML_MEMORY_LOCATION_DEVICE_MEMORY = 2\nNVML_MEMORY_LOCATION_DRAM = 2\nNVML_MEMORY_LOCATION_REGISTER_FILE = 3\nNVML_MEMORY_LOCATION_TEXTURE_MEMORY = 4\nNVML_MEMORY_LOCATION_TEXTURE_SHM = 5\nNVML_MEMORY_LOCATION_CBU = 6\nNVML_MEMORY_LOCATION_SRAM = 7\nNVML_MEMORY_LOCATION_COUNT = 8\n\nNVML_NVLINK_MAX_LINKS = 18\n\n# For backwards compatibility, maintain the incorrectly-named \"LANES\" define\nNVML_NVLINK_MAX_LANES = NVML_NVLINK_MAX_LINKS\n\n_nvmlNvLinkErrorCounter_t = c_uint\nNVML_NVLINK_ERROR_DL_REPLAY = 0\nNVML_NVLINK_ERROR_DL_RECOVERY = 1\nNVML_NVLINK_ERROR_DL_CRC_FLIT = 2\nNVML_NVLINK_ERROR_DL_CRC_DATA = 3\nNVML_NVLINK_ERROR_DL_ECC_DATA = 4\nNVML_NVLINK_ERROR_COUNT = 5\n\n_nvmlNvLinkEccLaneErrorCounter_t = c_uint\nNVML_NVLINK_ERROR_DL_ECC_LANE0 = 0\nNVML_NVLINK_ERROR_DL_ECC_LANE1 = 1\nNVML_NVLINK_ERROR_DL_ECC_LANE2 = 2\nNVML_NVLINK_ERROR_DL_ECC_LANE3 = 3\nNVML_NVLINK_ERROR_DL_ECC_COUNT = 5\n\n_nvmlNvLinkCapability_t = 
c_uint\nNVML_NVLINK_CAP_P2P_SUPPORTED = 0\nNVML_NVLINK_CAP_SYSMEM_ACCESS = 1\nNVML_NVLINK_CAP_P2P_ATOMICS = 2\nNVML_NVLINK_CAP_SYSMEM_ATOMICS = 3\nNVML_NVLINK_CAP_SLI_BRIDGE = 4\nNVML_NVLINK_CAP_VALID = 5\nNVML_NVLINK_CAP_COUNT = 6\n\n_nvmlNvLinkUtilizationCountPktTypes_t = c_uint\nNVML_NVLINK_COUNTER_PKTFILTER_NOP = 0x1\nNVML_NVLINK_COUNTER_PKTFILTER_READ = 0x2\nNVML_NVLINK_COUNTER_PKTFILTER_WRITE = 0x4\nNVML_NVLINK_COUNTER_PKTFILTER_RATOM = 0x8\nNVML_NVLINK_COUNTER_PKTFILTER_NRATOM = 0x10\nNVML_NVLINK_COUNTER_PKTFILTER_FLUSH = 0x20\nNVML_NVLINK_COUNTER_PKTFILTER_RESPDATA = 0x40\nNVML_NVLINK_COUNTER_PKTFILTER_RESPNODATA = 0x80\nNVML_NVLINK_COUNTER_PKTFILTER_ALL = 0xFF\n\n_nvmlNvLinkUtilizationCountUnits_t = c_uint\nNVML_NVLINK_COUNTER_UNIT_CYCLES = 0\nNVML_NVLINK_COUNTER_UNIT_PACKETS = 1\nNVML_NVLINK_COUNTER_UNIT_BYTES = 2\nNVML_NVLINK_COUNTER_UNIT_RESERVED = 3\nNVML_NVLINK_COUNTER_UNIT_COUNT = 4\n\n_nvmlNvLinkDeviceType_t = c_uint\nNVML_NVLINK_DEVICE_TYPE_GPU = 0x00\nNVML_NVLINK_DEVICE_TYPE_IBMNPU = 0x01\nNVML_NVLINK_DEVICE_TYPE_SWITCH = 0x02\nNVML_NVLINK_DEVICE_TYPE_UNKNOWN = 0xFF\n\n# These are deprecated, instead use _nvmlMemoryErrorType_t\n_nvmlEccBitType_t = c_uint\nNVML_SINGLE_BIT_ECC = 0\nNVML_DOUBLE_BIT_ECC = 1\nNVML_ECC_ERROR_TYPE_COUNT = 2\n\n_nvmlEccCounterType_t = c_uint\nNVML_VOLATILE_ECC = 0\nNVML_AGGREGATE_ECC = 1\nNVML_ECC_COUNTER_TYPE_COUNT = 2\n\n_nvmlMemoryErrorType_t = c_uint\nNVML_MEMORY_ERROR_TYPE_CORRECTED = 0\nNVML_MEMORY_ERROR_TYPE_UNCORRECTED = 1\nNVML_MEMORY_ERROR_TYPE_COUNT = 2\n\n_nvmlClockType_t = c_uint\nNVML_CLOCK_GRAPHICS = 0\nNVML_CLOCK_SM = 1\nNVML_CLOCK_MEM = 2\nNVML_CLOCK_VIDEO = 3\nNVML_CLOCK_COUNT = 4\n\n_nvmlClockId_t = c_uint\nNVML_CLOCK_ID_CURRENT = 0\nNVML_CLOCK_ID_APP_CLOCK_TARGET = 1\nNVML_CLOCK_ID_APP_CLOCK_DEFAULT = 2\nNVML_CLOCK_ID_CUSTOMER_BOOST_MAX = 3\nNVML_CLOCK_ID_COUNT = 4\n\n_nvmlDriverModel_t = c_uint\nNVML_DRIVER_WDDM = 0\nNVML_DRIVER_WDM = 1\nNVML_DRIVER_MCDM = 2\n\nNVML_MAX_GPU_PERF_PSTATES = 
16\n\n_nvmlPstates_t = c_uint\nNVML_PSTATE_0 = 0\nNVML_PSTATE_1 = 1\nNVML_PSTATE_2 = 2\nNVML_PSTATE_3 = 3\nNVML_PSTATE_4 = 4\nNVML_PSTATE_5 = 5\nNVML_PSTATE_6 = 6\nNVML_PSTATE_7 = 7\nNVML_PSTATE_8 = 8\nNVML_PSTATE_9 = 9\nNVML_PSTATE_10 = 10\nNVML_PSTATE_11 = 11\nNVML_PSTATE_12 = 12\nNVML_PSTATE_13 = 13\nNVML_PSTATE_14 = 14\nNVML_PSTATE_15 = 15\nNVML_PSTATE_UNKNOWN = 32\n\n_nvmlInforomObject_t = c_uint\nNVML_INFOROM_OEM = 0\nNVML_INFOROM_ECC = 1\nNVML_INFOROM_POWER = 2\nNVML_INFOROM_DEN = 3\nNVML_INFOROM_COUNT = 4\n\n_nvmlReturn_t = c_uint\nNVML_SUCCESS = 0\nNVML_ERROR_UNINITIALIZED = 1\nNVML_ERROR_INVALID_ARGUMENT = 2\nNVML_ERROR_NOT_SUPPORTED = 3\nNVML_ERROR_NO_PERMISSION = 4\nNVML_ERROR_ALREADY_INITIALIZED = 5\nNVML_ERROR_NOT_FOUND = 6\nNVML_ERROR_INSUFFICIENT_SIZE = 7\nNVML_ERROR_INSUFFICIENT_POWER = 8\nNVML_ERROR_DRIVER_NOT_LOADED = 9\nNVML_ERROR_TIMEOUT = 10\nNVML_ERROR_IRQ_ISSUE = 11\nNVML_ERROR_LIBRARY_NOT_FOUND = 12\nNVML_ERROR_FUNCTION_NOT_FOUND = 13\nNVML_ERROR_CORRUPTED_INFOROM = 14\nNVML_ERROR_GPU_IS_LOST = 15\nNVML_ERROR_RESET_REQUIRED = 16\nNVML_ERROR_OPERATING_SYSTEM = 17\nNVML_ERROR_LIB_RM_VERSION_MISMATCH = 18\nNVML_ERROR_IN_USE = 19\nNVML_ERROR_MEMORY = 20\nNVML_ERROR_NO_DATA = 21\nNVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22\nNVML_ERROR_INSUFFICIENT_RESOURCES = 23\nNVML_ERROR_FREQ_NOT_SUPPORTED = 24\nNVML_ERROR_ARGUMENT_VERSION_MISMATCH = 25\nNVML_ERROR_DEPRECATED = 26\nNVML_ERROR_NOT_READY = 27\nNVML_ERROR_GPU_NOT_FOUND = 28\nNVML_ERROR_INVALID_STATE = 29\nNVML_ERROR_UNKNOWN = 999\n\n_nvmlFanState_t = c_uint\nNVML_FAN_NORMAL = 0\nNVML_FAN_FAILED = 1\n\n_nvmlFanControlPolicy_t = c_uint\nNVML_FAN_POLICY_TEMPERATURE_CONTINOUS_SW = 0\nNVML_FAN_POLICY_MANUAL = 1\n\n_nvmlLedColor_t = c_uint\nNVML_LED_COLOR_GREEN = 0\nNVML_LED_COLOR_AMBER = 1\n\n_nvmlGpuOperationMode_t = c_uint\nNVML_GOM_ALL_ON = 0\nNVML_GOM_COMPUTE = 1\nNVML_GOM_LOW_DP = 2\n\n_nvmlPageRetirementCause_t = c_uint\nNVML_PAGE_RETIREMENT_CAUSE_MULTIPLE_SINGLE_BIT_ECC_ERRORS = 
0\nNVML_PAGE_RETIREMENT_CAUSE_DOUBLE_BIT_ECC_ERROR = 1\nNVML_PAGE_RETIREMENT_CAUSE_COUNT = 2\n\n_nvmlRestrictedAPI_t = c_uint\nNVML_RESTRICTED_API_SET_APPLICATION_CLOCKS = 0\nNVML_RESTRICTED_API_SET_AUTO_BOOSTED_CLOCKS = 1\nNVML_RESTRICTED_API_COUNT = 2\n\n_nvmlBridgeChipType_t = c_uint\nNVML_BRIDGE_CHIP_PLX = 0\nNVML_BRIDGE_CHIP_BRO4 = 1\nNVML_MAX_PHYSICAL_BRIDGE = 128\n\n_nvmlValueType_t = c_uint\nNVML_VALUE_TYPE_DOUBLE = 0\nNVML_VALUE_TYPE_UNSIGNED_INT = 1\nNVML_VALUE_TYPE_UNSIGNED_LONG = 2\nNVML_VALUE_TYPE_UNSIGNED_LONG_LONG = 3\nNVML_VALUE_TYPE_SIGNED_LONG_LONG = 4\nNVML_VALUE_TYPE_SIGNED_INT = 5\nNVML_VALUE_TYPE_UNSIGNED_SHORT = 6\nNVML_VALUE_TYPE_COUNT = 7\n\n_nvmlNvlinkVersion_t = c_uint\nNVML_NVLINK_VERSION_INVALID = 0\nNVML_NVLINK_VERSION_1_0 = 1\nNVML_NVLINK_VERSION_2_0 = 2\nNVML_NVLINK_VERSION_2_2 = 3\nNVML_NVLINK_VERSION_3_0 = 4\nNVML_NVLINK_VERSION_3_1 = 5\nNVML_NVLINK_VERSION_4_0 = 6\nNVML_NVLINK_VERSION_5_0 = 7\n\n_nvmlPerfPolicyType_t = c_uint\nNVML_PERF_POLICY_POWER = 0\nNVML_PERF_POLICY_THERMAL = 1\nNVML_PERF_POLICY_SYNC_BOOST = 2\nNVML_PERF_POLICY_BOARD_LIMIT = 3\nNVML_PERF_POLICY_LOW_UTILIZATION = 4\nNVML_PERF_POLICY_RELIABILITY = 5\nNVML_PERF_POLICY_TOTAL_APP_CLOCKS = 10\nNVML_PERF_POLICY_TOTAL_BASE_CLOCKS = 11\nNVML_PERF_POLICY_COUNT = 12\n\n_nvmlEncoderQueryType_t = c_uint\nNVML_ENCODER_QUERY_H264 = 0\nNVML_ENCODER_QUERY_HEVC = 1\nNVML_ENCODER_QUERY_AV1 = 2\nNVML_ENCODER_QUERY_UNKNOWN = 255\n\n_nvmlFBCSessionType_t = c_uint\nNVML_FBC_SESSION_TYPE_UNKNOWN = 0\nNVML_FBC_SESSION_TYPE_TOSYS = 1\nNVML_FBC_SESSION_TYPE_CUDA = 2\nNVML_FBC_SESSION_TYPE_VID = 3\nNVML_FBC_SESSION_TYPE_HWENC = 4\n\n_nvmlDetachGpuState_t = c_uint\nNVML_DETACH_GPU_KEEP = 0\nNVML_DETACH_GPU_REMOVE = 1\n\n_nvmlPcieLinkState_t = c_uint\nNVML_PCIE_LINK_KEEP = 0\nNVML_PCIE_LINK_SHUT_DOWN = 1\n\n_nvmlSamplingType_t = c_uint\nNVML_TOTAL_POWER_SAMPLES = 0\nNVML_GPU_UTILIZATION_SAMPLES = 1\nNVML_MEMORY_UTILIZATION_SAMPLES = 2\nNVML_ENC_UTILIZATION_SAMPLES = 
3\nNVML_DEC_UTILIZATION_SAMPLES = 4\nNVML_PROCESSOR_CLK_SAMPLES = 5\nNVML_MEMORY_CLK_SAMPLES = 6\nNVML_MODULE_POWER_SAMPLES = 7\nNVML_JPG_UTILIZATION_SAMPLES = 8\nNVML_OFA_UTILIZATION_SAMPLES = 9\nNVML_SAMPLINGTYPE_COUNT = 10\n\n_nvmlPcieUtilCounter_t = c_uint\nNVML_PCIE_UTIL_TX_BYTES = 0\nNVML_PCIE_UTIL_RX_BYTES = 1\nNVML_PCIE_UTIL_COUNT = 2\n\n_nvmlGpuTopologyLevel_t = c_uint\nNVML_TOPOLOGY_INTERNAL = 0\nNVML_TOPOLOGY_SINGLE = 10\nNVML_TOPOLOGY_MULTIPLE = 20\nNVML_TOPOLOGY_HOSTBRIDGE = 30\nNVML_TOPOLOGY_NODE = 40\nNVML_TOPOLOGY_CPU = NVML_TOPOLOGY_NODE\nNVML_TOPOLOGY_SYSTEM = 50\n\n_nvmlGpuP2PCapsIndex_t = c_uint\nNVML_P2P_CAPS_INDEX_READ = (0,)\nNVML_P2P_CAPS_INDEX_WRITE = 1\nNVML_P2P_CAPS_INDEX_NVLINK = 2\nNVML_P2P_CAPS_INDEX_ATOMICS = 3\n#\n# NVML_P2P_CAPS_INDEX_PROP is deprecated.\n# Use NVML_P2P_CAPS_INDEX_PCI instead.\n#\nNVML_P2P_CAPS_INDEX_PROP = 4\nNVML_P2P_CAPS_INDEX_PCI = 4\nNVML_P2P_CAPS_INDEX_UNKNOWN = 5\n\n_nvmlGpuP2PStatus_t = c_uint\nNVML_P2P_STATUS_OK = 0\nNVML_P2P_STATUS_CHIPSET_NOT_SUPPORED = 1\nNVML_P2P_STATUS_CHIPSET_NOT_SUPPORTED = NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED\nNVML_P2P_STATUS_GPU_NOT_SUPPORTED = 2\nNVML_P2P_STATUS_IOH_TOPOLOGY_NOT_SUPPORTED = 3\nNVML_P2P_STATUS_DISABLED_BY_REGKEY = 4\nNVML_P2P_STATUS_NOT_SUPPORTED = 5\nNVML_P2P_STATUS_UNKNOWN = 6\n\n_nvmlDeviceArchitecture_t = c_uint\nNVML_DEVICE_ARCH_KEPLER = 2\nNVML_DEVICE_ARCH_MAXWELL = 3\nNVML_DEVICE_ARCH_PASCAL = 4\nNVML_DEVICE_ARCH_VOLTA = 5\nNVML_DEVICE_ARCH_TURING = 6\nNVML_DEVICE_ARCH_AMPERE = 7\nNVML_DEVICE_ARCH_ADA = 8\nNVML_DEVICE_ARCH_HOPPER = 9\nNVML_DEVICE_ARCH_BLACKWELL = 10\nNVML_DEVICE_ARCH_T23X = 11\nNVML_DEVICE_ARCH_UNKNOWN = 0xFFFFFFFF\n\n# PCI bus Types\n_nvmlBusType_t = c_uint\nNVML_BUS_TYPE_UNKNOWN = 0\nNVML_BUS_TYPE_PCI = 1\nNVML_BUS_TYPE_PCIE = 2\nNVML_BUS_TYPE_FPCI = 3\nNVML_BUS_TYPE_AGP = 4\n\n_nvmlPowerSource_t = c_uint\nNVML_POWER_SOURCE_AC = 0x00000000\nNVML_POWER_SOURCE_BATTERY = 0x00000001\nNVML_POWER_SOURCE_UNDERSIZED = 
0x00000002\n\n_nvmlAdaptiveClockInfoStatus_t = c_uint\nNVML_ADAPTIVE_CLOCKING_INFO_STATUS_DISABLED = 0x00000000\nNVML_ADAPTIVE_CLOCKING_INFO_STATUS_ENABLED = 0x00000001\n\n_nvmlClockLimitId_t = c_uint\nNVML_CLOCK_LIMIT_ID_RANGE_START = 0xFFFFFF00\nNVML_CLOCK_LIMIT_ID_TDP = 0xFFFFFF01\nNVML_CLOCK_LIMIT_ID_UNLIMITED = 0xFFFFFF02\n\n_nvmlPcieLinkMaxSpeed_t = c_uint\nNVML_PCIE_LINK_MAX_SPEED_INVALID = 0x00000000\nNVML_PCIE_LINK_MAX_SPEED_2500MBPS = 0x00000001\nNVML_PCIE_LINK_MAX_SPEED_5000MBPS = 0x00000002\nNVML_PCIE_LINK_MAX_SPEED_8000MBPS = 0x00000003\nNVML_PCIE_LINK_MAX_SPEED_16000MBPS = 0x00000004\nNVML_PCIE_LINK_MAX_SPEED_32000MBPS = 0x00000005\nNVML_PCIE_LINK_MAX_SPEED_64000MBPS = 0x00000006\n\n_nvmlPcieAtomicsCapability_t = c_uint\nNVML_PCIE_ATOMICS_CAP_FETCHADD32 = 0x01\nNVML_PCIE_ATOMICS_CAP_FETCHADD64 = 0x02\nNVML_PCIE_ATOMICS_CAP_SWAP32 = 0x04\nNVML_PCIE_ATOMICS_CAP_SWAP64 = 0x08\nNVML_PCIE_ATOMICS_CAP_CAS32 = 0x10\nNVML_PCIE_ATOMICS_CAP_CAS64 = 0x20\nNVML_PCIE_ATOMICS_CAP_CAS128 = 0x40\nNVML_PCIE_ATOMICS_OPS_MAX = 7\n\n_nvmlAffinityScope_t = c_uint\nNVML_AFFINITY_SCOPE_NODE = 0\nNVML_AFFINITY_SCOPE_SOCKET = 1\n\n_nvmlDeviceGpuRecoveryAction_t = c_uint\nNVML_GPU_RECOVERY_ACTION_NONE = 0\nNVML_GPU_RECOVERY_ACTION_GPU_RESET = 1\nNVML_GPU_RECOVERY_ACTION_NODE_REBOOT = 2\nNVML_GPU_RECOVERY_ACTION_DRAIN_P2P = 3\nNVML_GPU_RECOVERY_ACTION_DRAIN_AND_RESET = 4\n\n# C preprocessor defined values\nnvmlFlagDefault = 0\nnvmlFlagForce = 1\nNVML_INIT_FLAG_NO_GPUS = 1\nNVML_INIT_FLAG_NO_ATTACH = 2\n\nNVML_MAX_GPC_COUNT = 32\n\n# buffer size\nNVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE = 16\nNVML_DEVICE_UUID_BUFFER_SIZE = 80\nNVML_DEVICE_UUID_V2_BUFFER_SIZE = 96\nNVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE = 80\nNVML_SYSTEM_NVML_VERSION_BUFFER_SIZE = 80\nNVML_DEVICE_NAME_BUFFER_SIZE = 64\nNVML_DEVICE_NAME_V2_BUFFER_SIZE = 96\nNVML_DEVICE_SERIAL_BUFFER_SIZE = 30\nNVML_DEVICE_PART_NUMBER_BUFFER_SIZE = 80\nNVML_DEVICE_GPU_PART_NUMBER_BUFFER_SIZE = 
80\nNVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE = 32\nNVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE = 32\nNVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE = 16\nNVML_GRID_LICENSE_BUFFER_SIZE = 128\nNVML_VGPU_NAME_BUFFER_SIZE = 64\nNVML_GRID_LICENSE_FEATURE_MAX_COUNT = 3\nNVML_VGPU_METADATA_OPAQUE_DATA_SIZE = sizeof(c_uint) + 256\nNVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE = 256\nNVML_DEVICE_GPU_FRU_PART_NUMBER_BUFFER_SIZE = (\n    0x14  # NV2080_GPU_MAX_PRODUCT_PART_NUMBER_LENGTH\n)\nNVML_PERF_MODES_BUFFER_SIZE = 2048\n\n# Format strings\nNVML_DEVICE_PCI_BUS_ID_LEGACY_FMT = \"%04X:%02X:%02X.0\"\nNVML_DEVICE_PCI_BUS_ID_FMT = \"%08X:%02X:%02X.0\"\n\nNVML_VALUE_NOT_AVAILABLE_ulonglong = c_ulonglong(-1)\nNVML_VALUE_NOT_AVAILABLE_uint = c_uint(-1)\n\n\"\"\"\n Field Identifiers.\n\n All Identifiers pertain to a device. Each ID is only used once and is guaranteed never to change.\n\"\"\"\nNVML_FI_DEV_ECC_CURRENT = 1  # Current ECC mode. 1=Active. 0=Inactive\nNVML_FI_DEV_ECC_PENDING = 2  # Pending ECC mode. 1=Active. 0=Inactive\n\n# ECC Count Totals\nNVML_FI_DEV_ECC_SBE_VOL_TOTAL = 3  # Total single bit volatile ECC errors\nNVML_FI_DEV_ECC_DBE_VOL_TOTAL = 4  # Total double bit volatile ECC errors\nNVML_FI_DEV_ECC_SBE_AGG_TOTAL = 5  # Total single bit aggregate (persistent) ECC errors\nNVML_FI_DEV_ECC_DBE_AGG_TOTAL = 6  # Total double bit aggregate (persistent) ECC errors\n# Individual ECC locations\nNVML_FI_DEV_ECC_SBE_VOL_L1 = 7  # L1 cache single bit volatile ECC errors\nNVML_FI_DEV_ECC_DBE_VOL_L1 = 8  # L1 cache double bit volatile ECC errors\nNVML_FI_DEV_ECC_SBE_VOL_L2 = 9  # L2 cache single bit volatile ECC errors\nNVML_FI_DEV_ECC_DBE_VOL_L2 = 10  # L2 cache double bit volatile ECC errors\nNVML_FI_DEV_ECC_SBE_VOL_DEV = 11  # Device memory single bit volatile ECC errors\nNVML_FI_DEV_ECC_DBE_VOL_DEV = 12  # Device memory double bit volatile ECC errors\nNVML_FI_DEV_ECC_SBE_VOL_REG = 13  # Register file single bit volatile ECC errors\nNVML_FI_DEV_ECC_DBE_VOL_REG = 14  # Register file double bit 
volatile ECC errors\nNVML_FI_DEV_ECC_SBE_VOL_TEX = 15  # Texture memory single bit volatile ECC errors\nNVML_FI_DEV_ECC_DBE_VOL_TEX = 16  # Texture memory double bit volatile ECC errors\nNVML_FI_DEV_ECC_DBE_VOL_CBU = 17  # CBU double bit volatile ECC errors\nNVML_FI_DEV_ECC_SBE_AGG_L1 = 18  # L1 cache single bit aggregate (persistent) ECC errors\nNVML_FI_DEV_ECC_DBE_AGG_L1 = 19  # L1 cache double bit aggregate (persistent) ECC errors\nNVML_FI_DEV_ECC_SBE_AGG_L2 = 20  # L2 cache single bit aggregate (persistent) ECC errors\nNVML_FI_DEV_ECC_DBE_AGG_L2 = 21  # L2 cache double bit aggregate (persistent) ECC errors\nNVML_FI_DEV_ECC_SBE_AGG_DEV = (\n    22  # Device memory single bit aggregate (persistent) ECC errors\n)\nNVML_FI_DEV_ECC_DBE_AGG_DEV = (\n    23  # Device memory double bit aggregate (persistent) ECC errors\n)\nNVML_FI_DEV_ECC_SBE_AGG_REG = (\n    24  # Register File single bit aggregate (persistent) ECC errors\n)\nNVML_FI_DEV_ECC_DBE_AGG_REG = (\n    25  # Register File double bit aggregate (persistent) ECC errors\n)\nNVML_FI_DEV_ECC_SBE_AGG_TEX = (\n    26  # Texture memory single bit aggregate (persistent) ECC errors\n)\nNVML_FI_DEV_ECC_DBE_AGG_TEX = (\n    27  # Texture memory double bit aggregate (persistent) ECC errors\n)\nNVML_FI_DEV_ECC_DBE_AGG_CBU = 28  # CBU double bit aggregate ECC errors\n\n# Page Retirement\nNVML_FI_DEV_RETIRED_SBE = 29  # Number of retired pages because of single bit errors\nNVML_FI_DEV_RETIRED_DBE = 30  # Number of retired pages because of double bit errors\nNVML_FI_DEV_RETIRED_PENDING = 31  # If any pages are pending retirement. 1=yes. 
0=no.\n\n# NvLink Flit Error Counters\nNVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L0 = (\n    32  # NVLink flow control CRC  Error Counter for Lane 0\n)\nNVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L1 = (\n    33  # NVLink flow control CRC  Error Counter for Lane 1\n)\nNVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L2 = (\n    34  # NVLink flow control CRC  Error Counter for Lane 2\n)\nNVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L3 = (\n    35  # NVLink flow control CRC  Error Counter for Lane 3\n)\nNVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L4 = (\n    36  # NVLink flow control CRC  Error Counter for Lane 4\n)\nNVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L5 = (\n    37  # NVLink flow control CRC  Error Counter for Lane 5\n)\nNVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_TOTAL = (\n    38  # NVLink flow control CRC  Error Counter total for all Lanes\n)\n\n# NvLink CRC Data Error Counters\nNVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L0 = (\n    39  # NVLink data CRC Error Counter for Lane 0\n)\nNVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L1 = (\n    40  # NVLink data CRC Error Counter for Lane 1\n)\nNVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L2 = (\n    41  # NVLink data CRC Error Counter for Lane 2\n)\nNVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L3 = (\n    42  # NVLink data CRC Error Counter for Lane 3\n)\nNVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L4 = (\n    43  # NVLink data CRC Error Counter for Lane 4\n)\nNVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L5 = (\n    44  # NVLink data CRC Error Counter for Lane 5\n)\nNVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_TOTAL = (\n    45  # NvLink data CRC Error Counter total for all Lanes\n)\n\n# NvLink Replay Error Counters\nNVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L0 = 46  # NVLink Replay Error Counter for Lane 0\nNVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L1 = 47  # NVLink Replay Error Counter for Lane 1\nNVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L2 = 48  # NVLink Replay Error Counter for Lane 2\nNVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L3 = 49  # NVLink Replay Error Counter for Lane 
3\nNVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L4 = 50  # NVLink Replay Error Counter for Lane 4\nNVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L5 = 51  # NVLink Replay Error Counter for Lane 5\nNVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_TOTAL = (\n    52  # NVLink Replay Error Counter total for all Lanes\n)\n\n# NvLink Recovery Error Counters\nNVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L0 = (\n    53  # NVLink Recovery Error Counter for Lane 0\n)\nNVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L1 = (\n    54  # NVLink Recovery Error Counter for Lane 1\n)\nNVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L2 = (\n    55  # NVLink Recovery Error Counter for Lane 2\n)\nNVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L3 = (\n    56  # NVLink Recovery Error Counter for Lane 3\n)\nNVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L4 = (\n    57  # NVLink Recovery Error Counter for Lane 4\n)\nNVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L5 = (\n    58  # NVLink Recovery Error Counter for Lane 5\n)\nNVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_TOTAL = (\n    59  # NVLink Recovery Error Counter total for all Lanes\n)\n\n# NvLink Bandwidth Counters\nNVML_FI_DEV_NVLINK_BANDWIDTH_C0_L0 = (\n    60  # NVLink Bandwidth Counter for Counter Set 0, Lane 0\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C0_L1 = (\n    61  # NVLink Bandwidth Counter for Counter Set 0, Lane 1\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C0_L2 = (\n    62  # NVLink Bandwidth Counter for Counter Set 0, Lane 2\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C0_L3 = (\n    63  # NVLink Bandwidth Counter for Counter Set 0, Lane 3\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C0_L4 = (\n    64  # NVLink Bandwidth Counter for Counter Set 0, Lane 4\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C0_L5 = (\n    65  # NVLink Bandwidth Counter for Counter Set 0, Lane 5\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C0_TOTAL = (\n    66  # NVLink Bandwidth Counter Total for Counter Set 0, All Lanes\n)\n\n# NvLink Bandwidth Counters\nNVML_FI_DEV_NVLINK_BANDWIDTH_C1_L0 = (\n    67  # NVLink Bandwidth Counter for Counter Set 1, Lane 
0\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C1_L1 = (\n    68  # NVLink Bandwidth Counter for Counter Set 1, Lane 1\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C1_L2 = (\n    69  # NVLink Bandwidth Counter for Counter Set 1, Lane 2\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C1_L3 = (\n    70  # NVLink Bandwidth Counter for Counter Set 1, Lane 3\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C1_L4 = (\n    71  # NVLink Bandwidth Counter for Counter Set 1, Lane 4\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C1_L5 = (\n    72  # NVLink Bandwidth Counter for Counter Set 1, Lane 5\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C1_TOTAL = (\n    73  # NVLink Bandwidth Counter Total for Counter Set 1, All Lanes\n)\n\n# Perf Policy Counters\nNVML_FI_DEV_PERF_POLICY_POWER = 74  # Perf Policy Counter for Power Policy\nNVML_FI_DEV_PERF_POLICY_THERMAL = 75  # Perf Policy Counter for Thermal Policy\nNVML_FI_DEV_PERF_POLICY_SYNC_BOOST = 76  # Perf Policy Counter for Sync boost Policy\nNVML_FI_DEV_PERF_POLICY_BOARD_LIMIT = 77  # Perf Policy Counter for Board Limit\nNVML_FI_DEV_PERF_POLICY_LOW_UTILIZATION = (\n    78  # Perf Policy Counter for Low GPU Utilization Policy\n)\nNVML_FI_DEV_PERF_POLICY_RELIABILITY = 79  # Perf Policy Counter for Reliability Policy\nNVML_FI_DEV_PERF_POLICY_TOTAL_APP_CLOCKS = (\n    80  # Perf Policy Counter for Total App Clock Policy\n)\nNVML_FI_DEV_PERF_POLICY_TOTAL_BASE_CLOCKS = (\n    81  # Perf Policy Counter for Total Base Clocks Policy\n)\n\n# Memory temperatures\nNVML_FI_DEV_MEMORY_TEMP = 82  # Memory temperature for the device\n\n# Energy Counter\nNVML_FI_DEV_TOTAL_ENERGY_CONSUMPTION = (\n    83  # Total energy consumption for the GPU in mJ since the driver was last reloaded\n)\n\n# NVLink Speed\nNVML_FI_DEV_NVLINK_SPEED_MBPS_L0 = 84\nNVML_FI_DEV_NVLINK_SPEED_MBPS_L1 = 85\nNVML_FI_DEV_NVLINK_SPEED_MBPS_L2 = 86\nNVML_FI_DEV_NVLINK_SPEED_MBPS_L3 = 87\nNVML_FI_DEV_NVLINK_SPEED_MBPS_L4 = 88\nNVML_FI_DEV_NVLINK_SPEED_MBPS_L5 = 89\nNVML_FI_DEV_NVLINK_SPEED_MBPS_COMMON = 90\n\n# NVLink Link 
Count\nNVML_FI_DEV_NVLINK_LINK_COUNT = 91\n\n# Page Retirement pending fields\nNVML_FI_DEV_RETIRED_PENDING_SBE = 92\nNVML_FI_DEV_RETIRED_PENDING_DBE = 93\n\n# PCIe replay and replay rollover counters\nNVML_FI_DEV_PCIE_REPLAY_COUNTER = 94\nNVML_FI_DEV_PCIE_REPLAY_ROLLOVER_COUNTER = 95\n\n# NvLink Flit Error Counters\nNVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L6 = (\n    96  # NVLink flow control CRC  Error Counter for Lane 6\n)\nNVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L7 = (\n    97  # NVLink flow control CRC  Error Counter for Lane 7\n)\nNVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L8 = (\n    98  # NVLink flow control CRC  Error Counter for Lane 8\n)\nNVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L9 = (\n    99  # NVLink flow control CRC  Error Counter for Lane 9\n)\nNVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L10 = (\n    100  # NVLink flow control CRC  Error Counter for Lane 10\n)\nNVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L11 = (\n    101  # NVLink flow control CRC  Error Counter for Lane 11\n)\n\n# NvLink CRC Data Error Counters\nNVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L6 = (\n    102  # NVLink data CRC Error Counter for Lane 6\n)\nNVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L7 = (\n    103  # NVLink data CRC Error Counter for Lane 7\n)\nNVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L8 = (\n    104  # NVLink data CRC Error Counter for Lane 8\n)\nNVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L9 = (\n    105  # NVLink data CRC Error Counter for Lane 9\n)\nNVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L10 = (\n    106  # NVLink data CRC Error Counter for Lane 10\n)\nNVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L11 = (\n    107  # NVLink data CRC Error Counter for Lane 11\n)\n\n# NvLink Replay Error Counters\nNVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L6 = 108  # NVLink Replay Error Counter for Lane 6\nNVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L7 = 109  # NVLink Replay Error Counter for Lane 7\nNVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L8 = 110  # NVLink Replay Error Counter for Lane 
8\nNVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L9 = 111  # NVLink Replay Error Counter for Lane 9\nNVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L10 = (\n    112  # NVLink Replay Error Counter for Lane 10\n)\nNVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L11 = (\n    113  # NVLink Replay Error Counter for Lane 11\n)\n\n# NvLink Recovery Error Counters\nNVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L6 = (\n    114  # NVLink Recovery Error Counter for Lane 6\n)\nNVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L7 = (\n    115  # NVLink Recovery Error Counter for Lane 7\n)\nNVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L8 = (\n    116  # NVLink Recovery Error Counter for Lane 8\n)\nNVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L9 = (\n    117  # NVLink Recovery Error Counter for Lane 9\n)\nNVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L10 = (\n    118  # NVLink Recovery Error Counter for Lane 10\n)\nNVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L11 = (\n    119  # NVLink Recovery Error Counter for Lane 11\n)\n\n# NvLink Bandwidth Counters\nNVML_FI_DEV_NVLINK_BANDWIDTH_C0_L6 = (\n    120  # NVLink Bandwidth Counter for Counter Set 0, Lane 6\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C0_L7 = (\n    121  # NVLink Bandwidth Counter for Counter Set 0, Lane 7\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C0_L8 = (\n    122  # NVLink Bandwidth Counter for Counter Set 0, Lane 8\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C0_L9 = (\n    123  # NVLink Bandwidth Counter for Counter Set 0, Lane 9\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C0_L10 = (\n    124  # NVLink Bandwidth Counter for Counter Set 0, Lane 10\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C0_L11 = (\n    125  # NVLink Bandwidth Counter for Counter Set 0, Lane 11\n)\n\n# NvLink Bandwidth Counters\nNVML_FI_DEV_NVLINK_BANDWIDTH_C1_L6 = (\n    126  # NVLink Bandwidth Counter for Counter Set 1, Lane 6\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C1_L7 = (\n    127  # NVLink Bandwidth Counter for Counter Set 1, Lane 7\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C1_L8 = (\n    128  # NVLink Bandwidth Counter for Counter Set 1, Lane 
8\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C1_L9 = (\n    129  # NVLink Bandwidth Counter for Counter Set 1, Lane 9\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C1_L10 = (\n    130  # NVLink Bandwidth Counter for Counter Set 1, Lane 10\n)\nNVML_FI_DEV_NVLINK_BANDWIDTH_C1_L11 = (\n    131  # NVLink Bandwidth Counter for Counter Set 1, Lane 11\n)\n\n# NVLink Speed\nNVML_FI_DEV_NVLINK_SPEED_MBPS_L6 = 132\nNVML_FI_DEV_NVLINK_SPEED_MBPS_L7 = 133\nNVML_FI_DEV_NVLINK_SPEED_MBPS_L8 = 134\nNVML_FI_DEV_NVLINK_SPEED_MBPS_L9 = 135\nNVML_FI_DEV_NVLINK_SPEED_MBPS_L10 = 136\nNVML_FI_DEV_NVLINK_SPEED_MBPS_L11 = 137\n\n# NVLink Throughput Counters\nNVML_FI_DEV_NVLINK_THROUGHPUT_DATA_TX = 138  # NVLink TX Data throughput in KiB\nNVML_FI_DEV_NVLINK_THROUGHPUT_DATA_RX = 139  # NVLink RX Data throughput in KiB\nNVML_FI_DEV_NVLINK_THROUGHPUT_RAW_TX = 140  # NVLink TX Data + protocol overhead in KiB\nNVML_FI_DEV_NVLINK_THROUGHPUT_RAW_RX = 141  # NVLink RX Data + protocol overhead in KiB\n\n# Row Remapper\nNVML_FI_DEV_REMAPPED_COR = 142\nNVML_FI_DEV_REMAPPED_UNC = 143\nNVML_FI_DEV_REMAPPED_PENDING = 144\nNVML_FI_DEV_REMAPPED_FAILURE = 145\n\n# Remote device NVLink ID\nNVML_FI_DEV_NVLINK_REMOTE_NVLINK_ID = 146\n\n# Number of NVLinks connected to NVSwitch\nNVML_FI_DEV_NVSWITCH_CONNECTED_LINK_COUNT = 147\n\n# NvLink ECC Data Error Counters\nNVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L0 = (\n    148  # < NVLink data ECC Error Counter for Link 0\n)\nNVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L1 = (\n    149  # < NVLink data ECC Error Counter for Link 1\n)\nNVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L2 = (\n    150  # < NVLink data ECC Error Counter for Link 2\n)\nNVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L3 = (\n    151  # < NVLink data ECC Error Counter for Link 3\n)\nNVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L4 = (\n    152  # < NVLink data ECC Error Counter for Link 4\n)\nNVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L5 = (\n    153  # < NVLink data ECC Error Counter for Link 5\n)\nNVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L6 = 
(\n    154  # < NVLink data ECC Error Counter for Link 6\n)\nNVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L7 = (\n    155  # < NVLink data ECC Error Counter for Link 7\n)\nNVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L8 = (\n    156  # < NVLink data ECC Error Counter for Link 8\n)\nNVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L9 = (\n    157  # < NVLink data ECC Error Counter for Link 9\n)\nNVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L10 = (\n    158  # < NVLink data ECC Error Counter for Link 10\n)\nNVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L11 = (\n    159  # < NVLink data ECC Error Counter for Link 11\n)\nNVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_TOTAL = (\n    160  # < NvLink data ECC Error Counter total for all Links\n)\n\nNVML_FI_DEV_NVLINK_ERROR_DL_REPLAY = 161\nNVML_FI_DEV_NVLINK_ERROR_DL_RECOVERY = 162\nNVML_FI_DEV_NVLINK_ERROR_DL_CRC = 163\nNVML_FI_DEV_NVLINK_GET_SPEED = 164\nNVML_FI_DEV_NVLINK_GET_STATE = 165\nNVML_FI_DEV_NVLINK_GET_VERSION = 166\n\nNVML_FI_DEV_NVLINK_GET_POWER_STATE = 167\nNVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD = 168\n\nNVML_FI_DEV_PCIE_L0_TO_RECOVERY_COUNTER = 169\n\nNVML_FI_DEV_C2C_LINK_COUNT = 170\nNVML_FI_DEV_C2C_LINK_GET_STATUS = 171\nNVML_FI_DEV_C2C_LINK_GET_MAX_BW = 172\n\nNVML_FI_DEV_PCIE_COUNT_CORRECTABLE_ERRORS = 173\nNVML_FI_DEV_PCIE_COUNT_NAKS_RECEIVED = 174\nNVML_FI_DEV_PCIE_COUNT_RECEIVER_ERROR = 175\nNVML_FI_DEV_PCIE_COUNT_BAD_TLP = 176\nNVML_FI_DEV_PCIE_COUNT_NAKS_SENT = 177\nNVML_FI_DEV_PCIE_COUNT_BAD_DLLP = 178\nNVML_FI_DEV_PCIE_COUNT_NON_FATAL_ERROR = 179\nNVML_FI_DEV_PCIE_COUNT_FATAL_ERROR = 180\nNVML_FI_DEV_PCIE_COUNT_UNSUPPORTED_REQ = 181\nNVML_FI_DEV_PCIE_COUNT_LCRC_ERROR = 182\nNVML_FI_DEV_PCIE_COUNT_LANE_ERROR = 183\n\nNVML_FI_DEV_IS_RESETLESS_MIG_SUPPORTED = 184\n\nNVML_FI_DEV_POWER_AVERAGE = 185\nNVML_FI_DEV_POWER_INSTANT = 186\nNVML_FI_DEV_POWER_MIN_LIMIT = 187\nNVML_FI_DEV_POWER_MAX_LIMIT = 188\nNVML_FI_DEV_POWER_DEFAULT_LIMIT = 189\nNVML_FI_DEV_POWER_CURRENT_LIMIT = 190\nNVML_FI_DEV_ENERGY = 191\nNVML_FI_DEV_POWER_REQUESTED_LIMIT 
= 192\n\nNVML_FI_DEV_TEMPERATURE_SHUTDOWN_TLIMIT = 193\nNVML_FI_DEV_TEMPERATURE_SLOWDOWN_TLIMIT = 194\nNVML_FI_DEV_TEMPERATURE_MEM_MAX_TLIMIT = 195\nNVML_FI_DEV_TEMPERATURE_GPU_MAX_TLIMIT = 196\n\nNVML_FI_DEV_PCIE_COUNT_TX_BYTES = 197\nNVML_FI_DEV_PCIE_COUNT_RX_BYTES = 198\n\nNVML_FI_DEV_IS_MIG_MODE_INDEPENDENT_MIG_QUERY_CAPABLE = 199\n\nNVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_MAX = 200\n\nNVML_FI_DEV_NVLINK_COUNT_XMIT_PACKETS = 201\nNVML_FI_DEV_NVLINK_COUNT_XMIT_BYTES = 202\nNVML_FI_DEV_NVLINK_COUNT_RCV_PACKETS = 203\nNVML_FI_DEV_NVLINK_COUNT_RCV_BYTES = 204\nNVML_FI_DEV_NVLINK_COUNT_VL15_DROPPED = 205  # Deprecated, do not use\nNVML_FI_DEV_NVLINK_COUNT_MALFORMED_PACKET_ERRORS = 206\nNVML_FI_DEV_NVLINK_COUNT_BUFFER_OVERRUN_ERRORS = 207\nNVML_FI_DEV_NVLINK_COUNT_RCV_ERRORS = 208\nNVML_FI_DEV_NVLINK_COUNT_RCV_REMOTE_ERRORS = 209\nNVML_FI_DEV_NVLINK_COUNT_RCV_GENERAL_ERRORS = 210\nNVML_FI_DEV_NVLINK_COUNT_LOCAL_LINK_INTEGRITY_ERRORS = 211\nNVML_FI_DEV_NVLINK_COUNT_XMIT_DISCARDS = 212\n\nNVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_SUCCESSFUL_EVENTS = 213\nNVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_FAILED_EVENTS = 214\nNVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_EVENTS = 215\n\nNVML_FI_DEV_NVLINK_COUNT_RAW_BER_LANE0 = 216  # Deprecated, do not use\nNVML_FI_DEV_NVLINK_COUNT_RAW_BER_LANE1 = 217  # Deprecated, do not use\nNVML_FI_DEV_NVLINK_COUNT_RAW_BER = 218  # Deprecated, do not use\nNVML_FI_DEV_NVLINK_COUNT_EFFECTIVE_ERRORS = 219\nNVML_FI_DEV_NVLINK_COUNT_EFFECTIVE_BER = 220\nNVML_FI_DEV_NVLINK_COUNT_SYMBOL_ERRORS = 221\nNVML_FI_DEV_NVLINK_COUNT_SYMBOL_BER = 222\n\nNVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_MIN = 223\nNVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS = (\n    224  # Values are in the form NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_*\n)\nNVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_SUPPORTED = 225\n\nNVML_FI_DEV_RESET_STATUS = (\n    226  # Deprecated use NVML_FI_DEV_GET_GPU_RECOVERY_ACTION instead\n)\nNVML_FI_DEV_DRAIN_AND_RESET_STATUS = (\n    227  # Deprecated use 
NVML_FI_DEV_GET_GPU_RECOVERY_ACTION instead\n)\nNVML_FI_DEV_PCIE_OUTBOUND_ATOMICS_MASK = 228\nNVML_FI_DEV_PCIE_INBOUND_ATOMICS_MASK = 229\nNVML_FI_DEV_GET_GPU_RECOVERY_ACTION = 230\n\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_0 = 235\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_1 = 236\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_2 = 237\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_3 = 238\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_4 = 239\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_5 = 240\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_6 = 241\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_7 = 242\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_8 = 243\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_9 = 244\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_10 = 245\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_11 = 246\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_12 = 247\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_13 = 248\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_14 = 249\nNVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_15 = 250\nNVML_FI_PWR_SMOOTHING_ENABLED = 251  # Enablement (0/DISABLED or 1/ENABLED)\nNVML_FI_PWR_SMOOTHING_PRIV_LVL = 252  # Current privilege level\nNVML_FI_PWR_SMOOTHING_IMM_RAMP_DOWN_ENABLED = (\n    253  # Immediate ramp down enablement (0/DISABLED or 1/ENABLED)\n)\nNVML_FI_PWR_SMOOTHING_APPLIED_TMP_CEIL = 254  # Applied TMP ceiling value\nNVML_FI_PWR_SMOOTHING_APPLIED_TMP_FLOOR = 255  # Applied TMP floor value\nNVML_FI_PWR_SMOOTHING_MAX_PERCENT_TMP_FLOOR_SETTING = 256  # Max % TMP Floor value\nNVML_FI_PWR_SMOOTHING_MIN_PERCENT_TMP_FLOOR_SETTING = 257  # Min % TMP Floor value\nNVML_FI_PWR_SMOOTHING_HW_CIRCUITRY_PERCENT_LIFETIME_REMAINING = (\n    258  # HW Circuitry % lifetime remaining\n)\nNVML_FI_PWR_SMOOTHING_MAX_NUM_PRESET_PROFILES = 259  # Max number of preset profiles\nNVML_FI_PWR_SMOOTHING_PROFILE_PERCENT_TMP_FLOOR = 260  # % TMP floor for a given profile\nNVML_FI_PWR_SMOOTHING_PROFILE_RAMP_UP_RATE = (\n    261  # Ramp up rate in mW/s for a given profile\n)\nNVML_FI_PWR_SMOOTHING_PROFILE_RAMP_DOWN_RATE = (\n    262  # Ramp down rate in mW/s for a given 
profile\n)\nNVML_FI_PWR_SMOOTHING_PROFILE_RAMP_DOWN_HYST_VAL = (\n    263  # Ramp down hysteresis value in ms for a given profile\n)\nNVML_FI_PWR_SMOOTHING_ACTIVE_PRESET_PROFILE = 264  # Active preset profile number\nNVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_PERCENT_TMP_FLOOR = (\n    265  # % TMP floor for a given profile\n)\nNVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_UP_RATE = (\n    266  # Ramp up rate in mW/s for a given profile\n)\nNVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_DOWN_RATE = (\n    267  # Ramp down rate in mW/s for a given profile\n)\nNVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_DOWN_HYST_VAL = (\n    268  # Ramp down hysteresis value in ms for a given profile\n)\n\nNVML_FI_MAX = 269  # One greater than the largest field ID defined above\n\n# NVML_FI_DEV_NVLINK_GET_STATE state enums\nNVML_NVLINK_STATE_INACTIVE = 0x0\nNVML_NVLINK_STATE_ACTIVE = 0x1\nNVML_NVLINK_STATE_SLEEP = 0x2\n\nNVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_100US = (\n    0  # NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS\n)\nNVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_50US = (\n    1  # NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS\n)\n\n## Enums needed for the method nvmlDeviceGetVirtualizationMode and nvmlDeviceSetVirtualizationMode\nNVML_GPU_VIRTUALIZATION_MODE_NONE = 0  # Represents Bare Metal GPU\nNVML_GPU_VIRTUALIZATION_MODE_PASSTHROUGH = (\n    1  # Device is associated with GPU-Passthrough\n)\nNVML_GPU_VIRTUALIZATION_MODE_VGPU = (\n    2  # Device is associated with vGPU inside virtual machine.\n)\nNVML_GPU_VIRTUALIZATION_MODE_HOST_VGPU = (\n    3  # Device is associated with VGX hypervisor in vGPU mode\n)\nNVML_GPU_VIRTUALIZATION_MODE_HOST_VSGA = (\n    4  # Device is associated with VGX hypervisor in vSGA mode\n)\n\n## Lib loading ##\nnvmlLib = None\nlibLoadLock = threading.Lock()\n_nvmlLib_refcount = 0  # Incremented on each nvmlInit and decremented on nvmlShutdown\n\n## vGPU Management\n_nvmlVgpuTypeId_t = c_uint\n_nvmlVgpuInstance_t = c_uint\n\n_nvmlVgpuVmIdType_t = 
c_uint\nNVML_VGPU_VM_ID_DOMAIN_ID = 0\nNVML_VGPU_VM_ID_UUID = 1\n\n_nvmlGridLicenseFeatureCode_t = c_uint\nNVML_GRID_LICENSE_FEATURE_CODE_UNKNOWN = 0\nNVML_GRID_LICENSE_FEATURE_CODE_VGPU = 1\nNVML_GRID_LICENSE_FEATURE_CODE_NVIDIA_RTX = 2\nNVML_GRID_LICENSE_FEATURE_CODE_VWORKSTATION = (\n    2  # deprecated, use NVML_GRID_LICENSE_FEATURE_CODE_NVIDIA_RTX.\n)\nNVML_GRID_LICENSE_FEATURE_CODE_GAMING = 3\nNVML_GRID_LICENSE_FEATURE_CODE_COMPUTE = 4\n\n_nvmlGridLicenseExpiryStatus_t = c_uint8\nNVML_GRID_LICENSE_EXPIRY_NOT_AVAILABLE = (0,)  # Expiry information not available\nNVML_GRID_LICENSE_EXPIRY_INVALID = (1,)  # Invalid expiry or error fetching expiry\nNVML_GRID_LICENSE_EXPIRY_VALID = (2,)  # Valid expiry\nNVML_GRID_LICENSE_EXPIRY_NOT_APPLICABLE = (3,)  # Expiry not applicable\nNVML_GRID_LICENSE_EXPIRY_PERMANENT = (4,)  # Permanent expiry\n\n_nvmlVgpuCapability_t = c_uint\nNVML_VGPU_CAP_NVLINK_P2P = 0  # vGPU P2P over NVLink is supported\nNVML_VGPU_CAP_GPUDIRECT = 1  # GPUDirect capability is supported\nNVML_VGPU_CAP_MULTI_VGPU_EXCLUSIVE = (\n    2  # vGPU profile cannot be mixed with other vGPU profiles in same VM\n)\nNVML_VGPU_CAP_EXCLUSIVE_TYPE = (\n    3  # vGPU profile cannot run on a GPU alongside other profiles of different type\n)\nNVML_VGPU_CAP_EXCLUSIVE_SIZE = (\n    4  # vGPU profile cannot run on a GPU alongside other profiles of different size\n)\nNVML_VGPU_CAP_COUNT = 5\n\n_nvmlVgpuDriverCapability_t = c_uint\nNVML_VGPU_DRIVER_CAP_HETEROGENEOUS_MULTI_VGPU = (\n    0  # Supports mixing of different vGPU profiles within one guest VM\n)\nNVML_VGPU_DRIVER_CAP_WARM_UPDATE = 1  # Supports FSR and warm update of vGPU host driver without terminating the running guest VM\nNVML_VGPU_DRIVER_CAP_COUNT = 2\n\n_nvmlDeviceVgpuCapability_t = c_uint\nNVML_DEVICE_VGPU_CAP_FRACTIONAL_MULTI_VGPU = 0  # Query whether the fractional vGPU profiles on this GPU can be used in multi-vGPU configurations\nNVML_DEVICE_VGPU_CAP_HETEROGENEOUS_TIMESLICE_PROFILES = 1  # Query whether 
the GPU supports concurrent execution of timesliced vGPU profiles of differing types\nNVML_DEVICE_VGPU_CAP_HETEROGENEOUS_TIMESLICE_SIZES = 2  # Query whether the GPU supports concurrent execution of timesliced vGPU profiles of differing framebuffer sizes\nNVML_DEVICE_VGPU_CAP_READ_DEVICE_BUFFER_BW = 3  # Query the GPU's read_device_buffer expected bandwidth capacity in megabytes per second\nNVML_DEVICE_VGPU_CAP_WRITE_DEVICE_BUFFER_BW = 4  # Query the GPU's write_device_buffer expected bandwidth capacity in megabytes per second\nNVML_DEVICE_VGPU_CAP_DEVICE_STREAMING = (\n    5  # Query whether the vGPU profiles on the GPU supports migration data streaming\n)\nNVML_DEVICE_VGPU_CAP_MINI_QUARTER_GPU = (\n    6  # Set/Get support of mini-quarter vGPU profiles\n)\nNVML_DEVICE_VGPU_CAP_COMPUTE_MEDIA_ENGINE_GPU = (\n    7  # Set/Get support for compute media engine vGPU profiles\n)\nNVML_DEVICE_VGPU_CAP_WARM_UPDATE = (\n    8  # Query whether the GPU supports FSR and warm update\n)\nNVML_DEVICE_VGPU_CAP_HOMOGENEOUS_PLACEMENTS = 9  # Query whether the GPU supports reporting of placements of timesliced vGPU profiles with identical framebuffer sizes\nNVML_DEVICE_VGPU_CAP_COUNT = 10\n\n_nvmlVgpuGuestInfoState_t = c_uint\nNVML_VGPU_INSTANCE_GUEST_INFO_STATE_UNINITIALIZED = 0\nNVML_VGPU_INSTANCE_GUEST_INFO_STATE_INITIALIZED = 1\n\n_nvmlVgpuVmCompatibility_t = c_uint\nNVML_VGPU_VM_COMPATIBILITY_NONE = 0x0\nNVML_VGPU_VM_COMPATIBILITY_COLD = 0x1\nNVML_VGPU_VM_COMPATIBILITY_HIBERNATE = 0x2\nNVML_VGPU_VM_COMPATIBILITY_SLEEP = 0x4\nNVML_VGPU_VM_COMPATIBILITY_LIVE = 0x8\n\n_nvmlVgpuPgpuCompatibilityLimitCode_t = c_uint\nNVML_VGPU_COMPATIBILITY_LIMIT_NONE = 0x0\nNVML_VGPU_COMPATIBILITY_LIMIT_HOST_DRIVER = 0x1\nNVML_VGPU_COMPATIBILITY_LIMIT_GUEST_DRIVER = 0x2\nNVML_VGPU_COMPATIBILITY_LIMIT_GPU = 0x4\nNVML_VGPU_COMPATIBILITY_LIMIT_OTHER = 0x80000000\n\n_nvmlHostVgpuMode_t = c_uint\nNVML_HOST_VGPU_MODE_NON_SRIOV = 0\nNVML_HOST_VGPU_MODE_SRIOV = 1\n\n_nvmlConfComputeGpusReadyState_t = 
c_uint\nNVML_CC_ACCEPTING_CLIENT_REQUESTS_FALSE = 0\nNVML_CC_ACCEPTING_CLIENT_REQUESTS_TRUE = 1\n\n_nvmlConfComputeGpuCaps_t = c_uint\nNVML_CC_SYSTEM_GPUS_CC_NOT_CAPABLE = 0\nNVML_CC_SYSTEM_GPUS_CC_CAPABLE = 1\n\n_nvmlConfComputeCpuCaps_t = c_uint\nNVML_CC_SYSTEM_CPU_CAPS_NONE = 0\nNVML_CC_SYSTEM_CPU_CAPS_AMD_SEV = 1\nNVML_CC_SYSTEM_CPU_CAPS_INTEL_TDX = 2\nNVML_CC_SYSTEM_CPU_CAPS_AMD_SEV_SNP = 3\nNVML_CC_SYSTEM_CPU_CAPS_AMD_SNP_VTOM = 4\n\n_nvmlConfComputeDevToolsMode_t = c_uint\nNVML_CC_SYSTEM_DEVTOOLS_MODE_OFF = 0\nNVML_CC_SYSTEM_DEVTOOLS_MODE_ON = 1\n\nNVML_CC_SYSTEM_MULTIGPU_NONE = 0\nNVML_CC_SYSTEM_MULTIGPU_PROTECTED_PCIE = 1\n\nNVML_CC_SYSTEM_ENVIRONMENT_UNAVAILABLE = 0\nNVML_CC_SYSTEM_ENVIRONMENT_SIM = 1\nNVML_CC_SYSTEM_ENVIRONMENT_PROD = 2\n\n_nvmlConfComputeCcFeature_t = c_uint\nNVML_CC_SYSTEM_FEATURE_DISABLED = 0\nNVML_CC_SYSTEM_FEATURE_ENABLED = 1\n\n_nvmlConfComputeCcKeyRotationThreshAttackerAdv_t = c_uint\nNVML_CC_KEY_ROTATION_THRESH_ATTACKER_ADVANTAGE_MIN = 50\nNVML_CC_KEY_ROTATION_THRESH_ATTACKER_ADVANTAGE_MAX = 65\n\n# GSP firmware\nNVML_GSP_FIRMWARE_VERSION_BUF_SIZE = 0x40\n\n\nclass NVMLLibraryMismatchError(Exception):\n    pass\n\n\n## Error Checking ##\nclass NVMLError(Exception):\n    _valClassMapping = dict()\n    # List of currently known error codes\n    _errcode_to_string = {\n        NVML_ERROR_UNINITIALIZED: \"Uninitialized\",\n        NVML_ERROR_INVALID_ARGUMENT: \"Invalid Argument\",\n        NVML_ERROR_NOT_SUPPORTED: \"Not Supported\",\n        NVML_ERROR_NO_PERMISSION: \"Insufficient Permissions\",\n        NVML_ERROR_ALREADY_INITIALIZED: \"Already Initialized\",\n        NVML_ERROR_NOT_FOUND: \"Not Found\",\n        NVML_ERROR_INSUFFICIENT_SIZE: \"Insufficient Size\",\n        NVML_ERROR_INSUFFICIENT_POWER: \"Insufficient External Power\",\n        NVML_ERROR_DRIVER_NOT_LOADED: \"Driver Not Loaded\",\n        NVML_ERROR_TIMEOUT: \"Timeout\",\n        NVML_ERROR_IRQ_ISSUE: \"Interrupt Request Issue\",\n        
NVML_ERROR_LIBRARY_NOT_FOUND: \"NVML Shared Library Not Found\",\n        NVML_ERROR_FUNCTION_NOT_FOUND: \"Function Not Found\",\n        NVML_ERROR_CORRUPTED_INFOROM: \"Corrupted infoROM\",\n        NVML_ERROR_GPU_IS_LOST: \"GPU is lost\",\n        NVML_ERROR_RESET_REQUIRED: \"GPU requires restart\",\n        NVML_ERROR_OPERATING_SYSTEM: \"The operating system has blocked the request.\",\n        NVML_ERROR_LIB_RM_VERSION_MISMATCH: \"RM has detected an NVML/RM version mismatch.\",\n        NVML_ERROR_MEMORY: \"Insufficient Memory\",\n        NVML_ERROR_UNKNOWN: \"Unknown Error\",\n    }\n\n    def __new__(typ, value):\n        \"\"\"\n        Maps value to a proper subclass of NVMLError.\n        See _extractNVMLErrorsAsClasses function for more details\n        \"\"\"\n        if typ == NVMLError:\n            typ = NVMLError._valClassMapping.get(value, typ)\n        obj = Exception.__new__(typ)\n        obj.value = value\n        return obj\n\n    def __str__(self):\n        try:\n            if self.value not in NVMLError._errcode_to_string:\n                NVMLError._errcode_to_string[self.value] = str(\n                    nvmlErrorString(self.value)\n                )\n            return NVMLError._errcode_to_string[self.value]\n        except NVMLError:\n            return \"NVML Error with code %d\" % self.value\n\n    def __eq__(self, other):\n        return self.value == other.value\n\n\ndef nvmlExceptionClass(nvmlErrorCode):\n    if nvmlErrorCode not in NVMLError._valClassMapping:\n        raise ValueError(\"nvmlErrorCode %s is not valid\" % nvmlErrorCode)\n    return NVMLError._valClassMapping[nvmlErrorCode]\n\n\ndef _extractNVMLErrorsAsClasses():\n    \"\"\"\n    Generates a hierarchy of classes on top of NVMLError class.\n\n    Each NVML Error gets a new NVMLError subclass. This way try/except blocks can filter appropriate\n    exceptions more easily.\n\n    NVMLError is a parent class. Each NVML_ERROR_* gets its own subclass.\n    e.g. 
NVML_ERROR_ALREADY_INITIALIZED will be turned into NVMLError_AlreadyInitialized\n    \"\"\"\n    this_module = sys.modules[__name__]\n    nvmlErrorsNames = [x for x in dir(this_module) if x.startswith(\"NVML_ERROR_\")]\n    for err_name in nvmlErrorsNames:\n        # e.g. Turn NVML_ERROR_ALREADY_INITIALIZED into NVMLError_AlreadyInitialized\n        class_name = \"NVMLError_\" + string.capwords(\n            err_name.replace(\"NVML_ERROR_\", \"\"), \"_\"\n        ).replace(\"_\", \"\")\n        err_val = getattr(this_module, err_name)\n\n        def gen_new(val):\n            def new(typ):\n                obj = NVMLError.__new__(typ, val)\n                return obj\n\n            return new\n\n        new_error_class = type(class_name, (NVMLError,), {\"__new__\": gen_new(err_val)})\n        new_error_class.__module__ = __name__\n        setattr(this_module, class_name, new_error_class)\n        NVMLError._valClassMapping[err_val] = new_error_class\n\n\n_extractNVMLErrorsAsClasses()\n\n\ndef _nvmlCheckReturn(ret):\n    if ret != NVML_SUCCESS:\n        raise NVMLError(ret)\n    return ret\n\n\n## Function access ##\n_nvmlGetFunctionPointer_cache = (\n    dict()\n)  # function pointers are cached to prevent unnecessary libLoadLock locking\n\n\ndef _nvmlGetFunctionPointer(name):\n    global nvmlLib\n\n    if name in _nvmlGetFunctionPointer_cache:\n        return _nvmlGetFunctionPointer_cache[name]\n\n    libLoadLock.acquire()\n    try:\n        # ensure library was loaded\n        if nvmlLib is None:\n            raise NVMLError(NVML_ERROR_UNINITIALIZED)\n        try:\n            _nvmlGetFunctionPointer_cache[name] = getattr(nvmlLib, name)\n            return _nvmlGetFunctionPointer_cache[name]\n        except AttributeError:\n            raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND)\n    finally:\n        # lock is always freed\n        libLoadLock.release()\n\n\n## Alternative object\n# Allows the object to be printed\n# Allows mismatched types to be assigned\n#  
- like None when the Structure variant requires c_uint\nclass nvmlFriendlyObject(object):\n    def __init__(self, dictionary):\n        for x in dictionary:\n            setattr(self, x, dictionary[x])\n\n    def __str__(self):\n        return self.__dict__.__str__()\n\n\ndef nvmlStructToFriendlyObject(struct):\n    d = {}\n    for x in struct._fields_:\n        key = x[0]\n        value = getattr(struct, key)\n        # only need to convert from bytes if bytes, no need to check python version.\n        d[key] = value.decode() if isinstance(value, bytes) else value\n    obj = nvmlFriendlyObject(d)\n    return obj\n\n\n# pack the object so it can be passed to the NVML library\ndef nvmlFriendlyObjectToStruct(obj, model):\n    for x in model._fields_:\n        key = x[0]\n        value = obj.__dict__[key]\n        # any c_char_p in python3 needs to be bytes, default encoding works fine.\n        if sys.version_info >= (3,):\n            setattr(model, key, value.encode())\n        else:\n            setattr(model, key, value)\n    return model\n\n\n## Unit structures\nclass struct_c_nvmlUnit_t(Structure):\n    pass  # opaque handle\n\n\nc_nvmlUnit_t = POINTER(struct_c_nvmlUnit_t)\n\n\nclass _PrintableStructure(Structure):\n    \"\"\"\n    Abstract class that produces nicer __str__ output than ctypes.Structure.\n    e.g. instead of:\n      >>> print str(obj)\n      <class_name object at 0x7fdf82fef9e0>\n    this class will print\n      class_name(field_name: formatted_value, field_name: formatted_value)\n\n    _fmt_ dictionary of <str _field_ name> -> <str format>\n    e.g. 
class that has _field_ 'hex_value', c_uint could be formatted with\n      _fmt_ = {\"hex_value\" : \"%08X\"}\n    to produce nicer output.\n    Default formatting string for all fields can be set with key \"<default>\" like:\n      _fmt_ = {\"<default>\" : \"%d MHz\"} # e.g. all values are numbers in MHz.\n    If not set, it's assumed to be just \"%s\"\n\n    Exact format of returned str from this class is subject to change in the future.\n    \"\"\"\n\n    _fmt_ = {}\n\n    def __str__(self):\n        result = []\n        for x in self._fields_:\n            key = x[0]\n            value = getattr(self, key)\n            fmt = \"%s\"\n            if key in self._fmt_:\n                fmt = self._fmt_[key]\n            elif \"<default>\" in self._fmt_:\n                fmt = self._fmt_[\"<default>\"]\n            result.append((\"%s: \" + fmt) % (key, value))\n        return self.__class__.__name__ + \"(\" + \", \".join(result) + \")\"\n\n    def __getattribute__(self, name):\n        res = super(_PrintableStructure, self).__getattribute__(name)\n        # Bytes must be decoded to str on Python 3; this is not needed on Python 2.\n        # Python 2 strings are both str and bytes.\n        # Python 3 strings are not of type bytes.\n        # ctypes should convert everything else to the correct values.\n        if isinstance(res, bytes):\n            if isinstance(res, str):\n                return res\n            return res.decode()\n        return res\n\n    def __setattr__(self, name, value):\n        if isinstance(value, str):\n            # encoding a python2 string returns the same value, since python2 strings are bytes already\n            # bytes passed in python3 will be ignored.\n            value = value.encode()\n        super(_PrintableStructure, self).__setattr__(name, value)\n\n\nclass c_nvmlUnitInfo_t(_PrintableStructure):\n    _fields_ = [\n        (\"name\", c_char * 96),\n        (\"id\", c_char * 96),\n        (\"serial\", c_char * 96),\n        
(\"firmwareVersion\", c_char * 96),\n    ]\n\n\nclass c_nvmlC2cModeInfo_v1_t(_PrintableStructure):\n    _fields_ = [(\"isC2cEnabled\", c_uint)]\n\n\nnvmlC2cModeInfo_v1 = 0x1000008\n\n\nclass c_nvmlLedState_t(_PrintableStructure):\n    _fields_ = [\n        (\"cause\", c_char * 256),\n        (\"color\", _nvmlLedColor_t),\n    ]\n\n\nclass c_nvmlPSUInfo_t(_PrintableStructure):\n    _fields_ = [\n        (\"state\", c_char * 256),\n        (\"current\", c_uint),\n        (\"voltage\", c_uint),\n        (\"power\", c_uint),\n    ]\n\n\nclass c_nvmlUnitFanInfo_t(_PrintableStructure):\n    _fields_ = [\n        (\"speed\", c_uint),\n        (\"state\", _nvmlFanState_t),\n    ]\n\n\nclass c_nvmlUnitFanSpeeds_t(_PrintableStructure):\n    _fields_ = [(\"fans\", c_nvmlUnitFanInfo_t * 24), (\"count\", c_uint)]\n\n\n## Device structures\nclass struct_c_nvmlDevice_t(Structure):\n    pass  # opaque handle\n\n\nc_nvmlDevice_t = POINTER(struct_c_nvmlDevice_t)\n\n\nclass nvmlPciInfoExt_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"domain\", c_uint),\n        (\"bus\", c_uint),\n        (\"device\", c_uint),\n        (\"pciDeviceId\", c_uint),\n        (\"pciSubSystemId\", c_uint),\n        (\"baseClass\", c_uint),\n        (\"subClass\", c_uint),\n        (\"busId\", c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE),\n    ]\n    _fmt_ = {\n        \"version\": \"0x%04X\",\n        \"domain\": \"0x%04X\",\n        \"bus\": \"0x%02X\",\n        \"device\": \"0x%02X\",\n        \"pciDeviceId\": \"0x%08X\",\n        \"pciSubSystemId\": \"0x%08X\",\n        \"baseClass\": \"0x%01X\",\n        \"subClass\": \"0x%01X\",\n    }\n\n\nnvmlPciInfoExt_v1 = 0x1000040\n\n\n# Legacy pciInfo used for _v1 and _v2\nclass nvmlPciInfo_v2_t(_PrintableStructure):\n    _fields_ = [\n        (\"busId\", c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE),\n        (\"domain\", c_uint),\n        (\"bus\", c_uint),\n        (\"device\", c_uint),\n        (\"pciDeviceId\", 
c_uint),\n        # Added in 2.285\n        (\"pciSubSystemId\", c_uint),\n        (\"reserved0\", c_uint),\n        (\"reserved1\", c_uint),\n        (\"reserved2\", c_uint),\n        (\"reserved3\", c_uint),\n    ]\n    _fmt_ = {\n        \"domain\": \"0x%04X\",\n        \"bus\": \"0x%02X\",\n        \"device\": \"0x%02X\",\n        \"pciDeviceId\": \"0x%08X\",\n        \"pciSubSystemId\": \"0x%08X\",\n    }\n\n\nclass nvmlPciInfo_t(_PrintableStructure):\n    _fields_ = [\n        # Moved to the new busId location below\n        (\"busIdLegacy\", c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE),\n        (\"domain\", c_uint),\n        (\"bus\", c_uint),\n        (\"device\", c_uint),\n        (\"pciDeviceId\", c_uint),\n        # Added in 2.285\n        (\"pciSubSystemId\", c_uint),\n        # New busId replaced the long deprecated and reserved fields with a\n        # field of the same size in 9.0\n        (\"busId\", c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE),\n    ]\n    _fmt_ = {\n        \"domain\": \"0x%08X\",\n        \"bus\": \"0x%02X\",\n        \"device\": \"0x%02X\",\n        \"pciDeviceId\": \"0x%08X\",\n        \"pciSubSystemId\": \"0x%08X\",\n    }\n\n\nclass c_nvmlSystemDriverBranchInfo_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"branch\", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE),\n    ]\n\n\nSystemDriverBranchInfo_v1 = 0x1000054\n\n\nclass c_nvmlExcludedDeviceInfo_t(_PrintableStructure):\n    _fields_ = [(\"pci\", nvmlPciInfo_t), (\"uuid\", c_char * NVML_DEVICE_UUID_BUFFER_SIZE)]\n\n\nclass nvmlNvLinkUtilizationControl_t(_PrintableStructure):\n    _fields_ = [\n        (\"units\", _nvmlNvLinkUtilizationCountUnits_t),\n        (\"pktfilter\", _nvmlNvLinkUtilizationCountPktTypes_t),\n    ]\n\n\nclass c_nvmlMemory_t(_PrintableStructure):\n    _fields_ = [\n        (\"total\", c_ulonglong),\n        (\"free\", c_ulonglong),\n        (\"used\", c_ulonglong),\n    ]\n    _fmt_ = {\"<default>\": \"%d 
B\"}\n\n\nclass c_nvmlMemory_v2_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"total\", c_ulonglong),\n        (\"reserved\", c_ulonglong),\n        (\"free\", c_ulonglong),\n        (\"used\", c_ulonglong),\n    ]\n    _fmt_ = {\"<default>\": \"%d B\"}\n\n\nnvmlMemory_v2 = 0x02000028\n\n\nclass c_nvmlBAR1Memory_t(_PrintableStructure):\n    _fields_ = [\n        (\"bar1Total\", c_ulonglong),\n        (\"bar1Free\", c_ulonglong),\n        (\"bar1Used\", c_ulonglong),\n    ]\n    _fmt_ = {\"<default>\": \"%d B\"}\n\n\nclass nvmlClkMonFaultInfo_t(Structure):\n    _fields_ = [(\"clkApiDomain\", c_uint), (\"clkDomainFaultMask\", c_uint)]\n\n\nMAX_CLK_DOMAINS = 32\n\n\nclass nvmlClkMonStatus_t(Structure):\n    _fields_ = [\n        (\"bGlobalStatus\", c_uint),\n        (\"clkMonListSize\", c_uint),\n        (\"clkMonList\", nvmlClkMonFaultInfo_t * MAX_CLK_DOMAINS),\n    ]\n\n\n# On Windows with the WDDM driver, usedGpuMemory is reported as None\n# Code that processes this structure should check for None, I.E.\n#\n# if (info.usedGpuMemory is None):\n#     # TODO handle the error\n#     pass\n# else:\n#    print(\"Using %d MiB of memory\" % (info.usedGpuMemory / 1024 / 1024))\n# endif\n#\n# See NVML documentation for more information\nclass c_nvmlProcessInfo_v2_t(_PrintableStructure):\n    _fields_ = [\n        (\"pid\", c_uint),\n        (\"usedGpuMemory\", c_ulonglong),\n        (\"gpuInstanceId\", c_uint),\n        (\"computeInstanceId\", c_uint),\n    ]\n    _fmt_ = {\"usedGpuMemory\": \"%d B\"}\n\n\nc_nvmlProcessInfo_v3_t = c_nvmlProcessInfo_v2_t\n\nc_nvmlProcessInfo_t = c_nvmlProcessInfo_v3_t\n\n_nvmlProcessMode_t = c_uint\nNVML_PROCESS_MODE_COMPUTE = 0\nNVML_PROCESS_MODE_GRAPHICS = 1\nNVML_PROCESS_MODE_MPS = 2\n\n\nclass c_nvmlProcessDetail_v1_t(Structure):\n    _fields_ = [\n        (\"pid\", c_uint),\n        (\"usedGpuMemory\", c_ulonglong),\n        (\"gpuInstanceId\", c_uint),\n        (\"computeInstanceId\", c_uint),\n    
    (\"usedGpuCcProtectedMemory\", c_ulonglong),\n    ]\n\n\nclass c_nvmlProcessDetailList_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"mode\", _nvmlProcessMode_t),\n        (\"numProcArrayEntries\", c_uint),\n        (\"procArray\", POINTER(c_nvmlProcessDetail_v1_t)),\n    ]\n    _fmt_ = {\"numProcArrayEntries\": \"%d B\"}\n\n\nc_nvmlProcessDetailList_t = c_nvmlProcessDetailList_v1_t\n\nnvmlProcessDetailList_v1 = 0x1000018\n\n\nclass c_nvmlBridgeChipInfo_t(_PrintableStructure):\n    _fields_ = [\n        (\"type\", _nvmlBridgeChipType_t),\n        (\"fwVersion\", c_uint),\n    ]\n\n\nclass c_nvmlBridgeChipHierarchy_t(_PrintableStructure):\n    _fields_ = [\n        (\"bridgeCount\", c_uint),\n        (\"bridgeChipInfo\", c_nvmlBridgeChipInfo_t * 128),\n    ]\n\n\nclass c_nvmlEccErrorCounts_t(_PrintableStructure):\n    _fields_ = [\n        (\"l1Cache\", c_ulonglong),\n        (\"l2Cache\", c_ulonglong),\n        (\"deviceMemory\", c_ulonglong),\n        (\"registerFile\", c_ulonglong),\n    ]\n\n\nclass c_nvmlUtilization_t(_PrintableStructure):\n    _fields_ = [\n        (\"gpu\", c_uint),\n        (\"memory\", c_uint),\n    ]\n    _fmt_ = {\"<default>\": \"%d %%\"}\n\n\n# Added in 2.285\nclass c_nvmlHwbcEntry_t(_PrintableStructure):\n    _fields_ = [\n        (\"hwbcId\", c_uint),\n        (\"firmwareVersion\", c_char * 32),\n    ]\n\n\nclass c_nvmlValue_t(Union):\n    _fields_ = [\n        (\"dVal\", c_double),\n        (\"uiVal\", c_uint),\n        (\"ulVal\", c_ulong),\n        (\"ullVal\", c_ulonglong),\n        (\"sllVal\", c_longlong),\n        (\"siVal\", c_int),\n        (\"usVal\", c_ushort),\n    ]\n\n\nclass c_nvmlSample_t(_PrintableStructure):\n    _fields_ = [\n        (\"timeStamp\", c_ulonglong),\n        (\"sampleValue\", c_nvmlValue_t),\n    ]\n\n\nclass c_nvmlViolationTime_t(_PrintableStructure):\n    _fields_ = [\n        (\"referenceTime\", c_ulonglong),\n        (\"violationTime\", c_ulonglong),\n   
 ]\n\n\nclass c_nvmlFieldValue_t(_PrintableStructure):\n    _fields_ = [\n        (\"fieldId\", c_uint32),\n        (\"scopeId\", c_uint32),\n        (\"timestamp\", c_int64),\n        (\"latencyUsec\", c_int64),\n        (\"valueType\", _nvmlValueType_t),\n        (\"nvmlReturn\", _nvmlReturn_t),\n        (\"value\", c_nvmlValue_t),\n    ]\n\n\nNVML_NVLINK_TOTAL_SUPPORTED_BW_MODES = 23\n\nnvmlNvlinkSupportedBwModes_v1 = 0x100001C\n\n\nclass c_nvmlNvlinkSupportedBwModes_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"bwModes\", c_uint8 * NVML_NVLINK_TOTAL_SUPPORTED_BW_MODES),\n        (\"totalBwModes\", c_uint8),\n    ]\n\n    def __init__(self):\n        super(c_nvmlNvlinkSupportedBwModes_v1_t, self).__init__(\n            version=nvmlNvlinkSupportedBwModes_v1\n        )\n\n\nnvmlNvlinkGetBwMode_v1 = 0x100000C\n\n\nclass c_nvmlNvlinkGetBwMode_v1_t(_PrintableStructure):\n    _fields_ = [(\"version\", c_uint), (\"bIsBest\", c_uint), (\"bwMode\", c_uint8)]\n\n    def __init__(self):\n        super(c_nvmlNvlinkGetBwMode_v1_t, self).__init__(version=nvmlNvlinkGetBwMode_v1)\n\n\nnvmlNvlinkSetBwMode_v1 = 0x100000C\n\n\nclass c_nvmlNvlinkSetBwMode_v1_t(_PrintableStructure):\n    _fields_ = [(\"version\", c_uint), (\"bSetBest\", c_uint), (\"bwMode\", c_uint8)]\n\n    def __init__(self):\n        super(c_nvmlNvlinkSetBwMode_v1_t, self).__init__(version=nvmlNvlinkSetBwMode_v1)\n\n\nclass c_nvmlVgpuHeterogeneousMode_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"mode\", c_uint),\n    ]\n\n\nVgpuHeterogeneousMode_v1 = 0x1000008\n\n\nclass c_nvmlVgpuPlacementId_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"placementId\", c_uint),\n    ]\n\n\nVgpuPlacementId_v1 = 0x1000008\n\n\nclass c_nvmlVgpuPlacementList_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"count\", c_uint),\n        (\"placementSize\", c_uint),\n     
   (\"placementIds\", POINTER(c_uint)),\n    ]\n\n\nVgpuPlacementList_v1 = 0x1000018\n\nNVML_VGPU_PGPU_HETEROGENEOUS_MODE = 0\nNVML_VGPU_PGPU_HOMOGENEOUS_MODE = 1\n\n\nclass c_nvmlVgpuPlacementList_v2_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"placementSize\", c_uint),\n        (\"count\", c_uint),\n        (\"placementIds\", POINTER(c_uint)),\n        (\"mode\", c_uint),\n    ]\n\n\nVgpuPlacementList_v2 = 0x2000020\n\n\nclass c_nvmlVgpuTypeBar1Info_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"bar1Size\", c_ulonglong),\n    ]\n\n\nVgpuTypeBar1Info_v1 = 0x1000010\n\n\nclass c_nvmlVgpuInstanceUtilizationSample_t(_PrintableStructure):\n    _fields_ = [\n        (\"vgpuInstance\", _nvmlVgpuInstance_t),\n        (\"timeStamp\", c_ulonglong),\n        (\"smUtil\", c_nvmlValue_t),\n        (\"memUtil\", c_nvmlValue_t),\n        (\"encUtil\", c_nvmlValue_t),\n        (\"decUtil\", c_nvmlValue_t),\n    ]\n\n\nclass c_nvmlVgpuInstanceUtilizationInfo_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"timeStamp\", c_ulonglong),\n        (\"vgpuInstance\", _nvmlVgpuInstance_t),\n        (\"smUtil\", c_nvmlValue_t),\n        (\"memUtil\", c_nvmlValue_t),\n        (\"encUtil\", c_nvmlValue_t),\n        (\"decUtil\", c_nvmlValue_t),\n        (\"jpgUtil\", c_nvmlValue_t),\n        (\"ofaUtil\", c_nvmlValue_t),\n    ]\n\n\nclass c_nvmlVgpuInstancesUtilizationInfo_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"sampleValType\", _nvmlValueType_t),\n        (\"vgpuInstanceCount\", c_uint),\n        (\"lastSeenTimeStamp\", c_ulonglong),\n        (\"vgpuUtilArray\", POINTER(c_nvmlVgpuInstanceUtilizationInfo_v1_t)),\n    ]\n\n\nVgpuInstancesUtilizationInfo_v1 = 0x01000020\n\n\nclass c_nvmlVgpuProcessUtilizationSample_t(_PrintableStructure):\n    _fields_ = [\n        (\"vgpuInstance\", _nvmlVgpuInstance_t),\n        (\"pid\", c_uint),\n        
(\"processName\", c_char * NVML_VGPU_NAME_BUFFER_SIZE),\n        (\"timeStamp\", c_ulonglong),\n        (\"smUtil\", c_uint),\n        (\"memUtil\", c_uint),\n        (\"encUtil\", c_uint),\n        (\"decUtil\", c_uint),\n    ]\n\n\nclass c_nvmlVgpuProcessUtilizationInfo_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"processName\", c_char * NVML_VGPU_NAME_BUFFER_SIZE),\n        (\"timeStamp\", c_ulonglong),\n        (\"vgpuInstance\", _nvmlVgpuInstance_t),\n        (\"pid\", c_uint),\n        (\"smUtil\", c_uint),\n        (\"memUtil\", c_uint),\n        (\"encUtil\", c_uint),\n        (\"decUtil\", c_uint),\n        (\"jpgUtil\", c_uint),\n        (\"ofaUtil\", c_uint),\n    ]\n\n\nclass c_nvmlVgpuProcessesUtilizationInfo_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"vgpuProcessCount\", c_uint),\n        (\"lastSeenTimeStamp\", c_ulonglong),\n        (\"vgpuProcUtilArray\", POINTER(c_nvmlVgpuProcessUtilizationInfo_v1_t)),\n    ]\n\n\nVgpuProcessesUtilizationInfo_v1 = 0x01000018\n\n\nclass nvmlVgpuRuntimeState_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"size\", c_ulonglong),\n    ]\n\n\nVgpuRuntimeState_v1 = 0x1000010\n\n\nclass c_nvmlVgpuLicenseExpiry_t(_PrintableStructure):\n    _fields_ = [\n        (\"year\", c_uint32),\n        (\"month\", c_uint16),\n        (\"day\", c_uint16),\n        (\"hour\", c_uint16),\n        (\"min\", c_uint16),\n        (\"sec\", c_uint16),\n        (\"status\", c_uint8),\n    ]\n\n\nNVML_GRID_LICENSE_STATE_UNKNOWN = 0\nNVML_GRID_LICENSE_STATE_UNINITIALIZED = 1\nNVML_GRID_LICENSE_STATE_UNLICENSED_UNRESTRICTED = 2\nNVML_GRID_LICENSE_STATE_UNLICENSED_RESTRICTED = 3\nNVML_GRID_LICENSE_STATE_UNLICENSED = 4\nNVML_GRID_LICENSE_STATE_LICENSED = 5\n\n\nclass c_nvmlVgpuLicenseInfo_t(_PrintableStructure):\n    _fields_ = [\n        (\"isLicensed\", c_uint8),\n        (\"licenseExpiry\", c_nvmlVgpuLicenseExpiry_t),\n        (\"currentState\", 
c_uint),\n    ]\n\n\nclass c_nvmlEncoderSession_t(_PrintableStructure):\n    _fields_ = [\n        (\"sessionId\", c_uint),\n        (\"pid\", c_uint),\n        (\"vgpuInstance\", _nvmlVgpuInstance_t),\n        (\"codecType\", c_uint),\n        (\"hResolution\", c_uint),\n        (\"vResolution\", c_uint),\n        (\"averageFps\", c_uint),\n        (\"encodeLatency\", c_uint),\n    ]\n\n\nclass c_nvmlProcessUtilizationSample_t(_PrintableStructure):\n    _fields_ = [\n        (\"pid\", c_uint),\n        (\"timeStamp\", c_ulonglong),\n        (\"smUtil\", c_uint),\n        (\"memUtil\", c_uint),\n        (\"encUtil\", c_uint),\n        (\"decUtil\", c_uint),\n    ]\n\n\nclass c_nvmlProcessUtilizationInfo_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"timeStamp\", c_ulonglong),\n        (\"pid\", c_uint),\n        (\"smUtil\", c_uint),\n        (\"memUtil\", c_uint),\n        (\"encUtil\", c_uint),\n        (\"decUtil\", c_uint),\n        (\"jpgUtil\", c_uint),\n        (\"ofaUtil\", c_uint),\n    ]\n\n\nclass c_nvmlProcessesUtilizationInfo_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"processSamplesCount\", c_uint),\n        (\"lastSeenTimeStamp\", c_ulonglong),\n        (\"procUtilArray\", POINTER(c_nvmlProcessUtilizationInfo_v1_t)),\n    ]\n\n\nProcessesUtilizationInfo_v1 = 0x01000018\n\n\nclass c_nvmlGridLicenseExpiry_t(_PrintableStructure):\n    _fields_ = [\n        (\"year\", c_uint32),\n        (\"month\", c_uint16),\n        (\"day\", c_uint16),\n        (\"hour\", c_uint16),\n        (\"min\", c_uint16),\n        (\"sec\", c_uint16),\n        (\"status\", c_uint8),\n    ]\n\n\nclass c_nvmlGridLicensableFeature_v4_t(_PrintableStructure):\n    _fields_ = [\n        (\"featureCode\", _nvmlGridLicenseFeatureCode_t),\n        (\"featureState\", c_uint),\n        (\"licenseInfo\", c_char * NVML_GRID_LICENSE_BUFFER_SIZE),\n        (\"productName\", c_char * NVML_GRID_LICENSE_BUFFER_SIZE),\n        
(\"featureEnabled\", c_uint),\n        (\"licenseExpiry\", c_nvmlGridLicenseExpiry_t),\n    ]\n\n\nclass c_nvmlGridLicensableFeatures_v4_t(_PrintableStructure):\n    _fields_ = [\n        (\"isGridLicenseSupported\", c_int),\n        (\"licensableFeaturesCount\", c_uint),\n        (\n            \"gridLicensableFeatures\",\n            c_nvmlGridLicensableFeature_v4_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT,\n        ),\n    ]\n\n\nclass c_nvmlGridLicensableFeature_v3_t(_PrintableStructure):\n    _fields_ = [\n        (\"featureCode\", _nvmlGridLicenseFeatureCode_t),\n        (\"featureState\", c_uint),\n        (\"licenseInfo\", c_char * NVML_GRID_LICENSE_BUFFER_SIZE),\n        (\"productName\", c_char * NVML_GRID_LICENSE_BUFFER_SIZE),\n        (\"featureEnabled\", c_uint),\n    ]\n\n\nclass c_nvmlGridLicensableFeatures_v3_t(_PrintableStructure):\n    _fields_ = [\n        (\"isGridLicenseSupported\", c_int),\n        (\"licensableFeaturesCount\", c_uint),\n        (\n            \"gridLicensableFeatures\",\n            c_nvmlGridLicensableFeature_v3_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT,\n        ),\n    ]\n\n\nclass c_nvmlGridLicensableFeature_v2_t(_PrintableStructure):\n    _fields_ = [\n        (\"featureCode\", _nvmlGridLicenseFeatureCode_t),\n        (\"featureState\", c_uint),\n        (\"licenseInfo\", c_char * NVML_GRID_LICENSE_BUFFER_SIZE),\n        (\"productName\", c_char * NVML_GRID_LICENSE_BUFFER_SIZE),\n    ]\n\n\nclass c_nvmlGridLicensableFeatures_v2_t(_PrintableStructure):\n    _fields_ = [\n        (\"isGridLicenseSupported\", c_int),\n        (\"licensableFeaturesCount\", c_uint),\n        (\n            \"gridLicensableFeatures\",\n            c_nvmlGridLicensableFeature_v2_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT,\n        ),\n    ]\n\n\nclass c_nvmlGridLicensableFeature_t(_PrintableStructure):\n    _fields_ = [\n        (\"featureCode\", _nvmlGridLicenseFeatureCode_t),\n        (\"featureState\", c_uint),\n        (\"licenseInfo\", c_char * 
NVML_GRID_LICENSE_BUFFER_SIZE),\n    ]\n\n\nclass c_nvmlGridLicensableFeatures_t(_PrintableStructure):\n    _fields_ = [\n        (\"isGridLicenseSupported\", c_int),\n        (\"licensableFeaturesCount\", c_uint),\n        (\n            \"gridLicensableFeatures\",\n            c_nvmlGridLicensableFeature_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT,\n        ),\n    ]\n\n\nclass c_nvmlMarginTemperature_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"marginTemperature\", c_int),\n    ]\n\n\nnvmlMarginTemperature_v1 = 0x1000008\n\n\n## Event structures\nclass struct_c_nvmlEventSet_t(Structure):\n    pass  # opaque handle\n\n\nc_nvmlEventSet_t = POINTER(struct_c_nvmlEventSet_t)\n\nnvmlEventTypeSingleBitEccError = 0x0000000000000001\nnvmlEventTypeDoubleBitEccError = 0x0000000000000002\nnvmlEventTypePState = 0x0000000000000004\nnvmlEventTypeXidCriticalError = 0x0000000000000008\nnvmlEventTypeClock = 0x0000000000000010\nnvmlEventTypePowerSourceChange = 0x0000000000000080\nnvmlEventMigConfigChange = 0x0000000000000100\nnvmlEventTypeSingleBitEccErrorStorm = 0x0000000000000200\nnvmlEventTypeDramRetirementEvent = 0x0000000000000400\nnvmlEventTypeDramRetirementFailure = 0x0000000000000800\nnvmlEventTypeNonFatalPoisonError = 0x0000000000001000\nnvmlEventTypeFatalPoisonError = 0x0000000000002000\nnvmlEventTypeGpuUnavailableError = 0x0000000000004000\nnvmlEventTypeGpuRecoveryAction = 0x0000000000008000\nnvmlEventTypeNone = 0x0000000000000000\nnvmlEventTypeAll = (\n    nvmlEventTypeNone\n    | nvmlEventTypeSingleBitEccError\n    | nvmlEventTypeDoubleBitEccError\n    | nvmlEventTypePState\n    | nvmlEventTypeClock\n    | nvmlEventTypePowerSourceChange\n    | nvmlEventTypeXidCriticalError\n    | nvmlEventMigConfigChange\n    | nvmlEventTypeSingleBitEccErrorStorm\n    | nvmlEventTypeDramRetirementEvent\n    | nvmlEventTypeDramRetirementFailure\n    | nvmlEventTypeNonFatalPoisonError\n    | nvmlEventTypeFatalPoisonError\n    | 
nvmlEventTypeGpuUnavailableError\n    | nvmlEventTypeGpuRecoveryAction\n)\n\n## Clock Event Reasons defines\nnvmlClocksEventReasonGpuIdle = 0x0000000000000001\nnvmlClocksEventReasonApplicationsClocksSetting = 0x0000000000000002\nnvmlClocksEventReasonUserDefinedClocks = nvmlClocksEventReasonApplicationsClocksSetting  # deprecated, use nvmlClocksEventReasonApplicationsClocksSetting\nnvmlClocksEventReasonSwPowerCap = 0x0000000000000004\nnvmlClocksEventReasonHwSlowdown = 0x0000000000000008\nnvmlClocksEventReasonSyncBoost = 0x0000000000000010\nnvmlClocksEventReasonSwThermalSlowdown = 0x0000000000000020\nnvmlClocksEventReasonHwThermalSlowdown = 0x0000000000000040\nnvmlClocksEventReasonHwPowerBrakeSlowdown = 0x0000000000000080\nnvmlClocksEventReasonDisplayClockSetting = 0x0000000000000100\nnvmlClocksEventReasonNone = 0x0000000000000000\nnvmlClocksEventReasonAll = (\n    nvmlClocksEventReasonNone\n    | nvmlClocksEventReasonGpuIdle\n    | nvmlClocksEventReasonApplicationsClocksSetting\n    | nvmlClocksEventReasonSwPowerCap\n    | nvmlClocksEventReasonHwSlowdown\n    | nvmlClocksEventReasonSyncBoost\n    | nvmlClocksEventReasonSwThermalSlowdown\n    | nvmlClocksEventReasonHwThermalSlowdown\n    | nvmlClocksEventReasonHwPowerBrakeSlowdown\n    | nvmlClocksEventReasonDisplayClockSetting\n)\n\n## Following have been deprecated\nnvmlClocksThrottleReasonGpuIdle = 0x0000000000000001\nnvmlClocksThrottleReasonApplicationsClocksSetting = 0x0000000000000002\nnvmlClocksThrottleReasonUserDefinedClocks = nvmlClocksThrottleReasonApplicationsClocksSetting  # deprecated, use nvmlClocksThrottleReasonApplicationsClocksSetting\nnvmlClocksThrottleReasonSwPowerCap = 0x0000000000000004\nnvmlClocksThrottleReasonHwSlowdown = 0x0000000000000008\nnvmlClocksThrottleReasonSyncBoost = 0x0000000000000010\nnvmlClocksThrottleReasonSwThermalSlowdown = 0x0000000000000020\nnvmlClocksThrottleReasonHwThermalSlowdown = 0x0000000000000040\nnvmlClocksThrottleReasonHwPowerBrakeSlowdown = 
0x0000000000000080\nnvmlClocksThrottleReasonDisplayClockSetting = 0x0000000000000100\nnvmlClocksThrottleReasonNone = 0x0000000000000000\nnvmlClocksThrottleReasonAll = (\n    nvmlClocksThrottleReasonNone\n    | nvmlClocksThrottleReasonGpuIdle\n    | nvmlClocksThrottleReasonApplicationsClocksSetting\n    | nvmlClocksThrottleReasonSwPowerCap\n    | nvmlClocksThrottleReasonHwSlowdown\n    | nvmlClocksThrottleReasonSyncBoost\n    | nvmlClocksThrottleReasonSwThermalSlowdown\n    | nvmlClocksThrottleReasonHwThermalSlowdown\n    | nvmlClocksThrottleReasonHwPowerBrakeSlowdown\n    | nvmlClocksThrottleReasonDisplayClockSetting\n)\n\n\nclass c_nvmlEventData_t(_PrintableStructure):\n    _fields_ = [\n        (\"device\", c_nvmlDevice_t),\n        (\"eventType\", c_ulonglong),\n        (\"eventData\", c_ulonglong),\n        (\"gpuInstanceId\", c_uint),\n        (\"computeInstanceId\", c_uint),\n    ]\n    _fmt_ = {\"eventType\": \"0x%08X\"}\n\n\nclass c_nvmlAccountingStats_t(_PrintableStructure):\n    _fields_ = [\n        (\"gpuUtilization\", c_uint),\n        (\"memoryUtilization\", c_uint),\n        (\"maxMemoryUsage\", c_ulonglong),\n        (\"time\", c_ulonglong),\n        (\"startTime\", c_ulonglong),\n        (\"isRunning\", c_uint),\n        (\"reserved\", c_uint * 5),\n    ]\n\n\nclass c_nvmlVgpuVersion_t(Structure):\n    _fields_ = [(\"minVersion\", c_uint), (\"maxVersion\", c_uint)]\n\n\nclass c_nvmlVgpuMetadata_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"revision\", c_uint),\n        (\"guestInfoState\", _nvmlVgpuGuestInfoState_t),\n        (\"guestDriverVersion\", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE),\n        (\"hostDriverVersion\", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE),\n        (\"reserved\", c_uint * 6),\n        (\"vgpuVirtualizationCaps\", c_uint),\n        (\"guestVgpuVersion\", c_uint),\n        (\"opaqueDataSize\", c_uint),\n        (\"opaqueData\", c_char * 
NVML_VGPU_METADATA_OPAQUE_DATA_SIZE),\n    ]\n\n\nclass c_nvmlVgpuPgpuMetadata_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"revision\", c_uint),\n        (\"hostDriverVersion\", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE),\n        (\"pgpuVirtualizationCaps\", c_uint),\n        (\"reserved\", c_uint * 5),\n        (\"hostSupportedVgpuRange\", c_nvmlVgpuVersion_t),\n        (\"opaqueDataSize\", c_uint),\n        (\"opaqueData\", c_char * NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE),\n    ]\n\n\nclass c_nvmlVgpuPgpuCompatibility_t(Structure):\n    _fields_ = [\n        (\"vgpuVmCompatibility\", _nvmlVgpuVmCompatibility_t),\n        (\"compatibilityLimitCode\", _nvmlVgpuPgpuCompatibilityLimitCode_t),\n    ]\n\n\n## vGPU scheduler policy defines\nNVML_VGPU_SCHEDULER_POLICY_UNKNOWN = 0\nNVML_VGPU_SCHEDULER_POLICY_BEST_EFFORT = 1\nNVML_VGPU_SCHEDULER_POLICY_EQUAL_SHARE = 2\nNVML_VGPU_SCHEDULER_POLICY_FIXED_SHARE = 3\n\n## Supported vGPU scheduler policy count\nNVML_SUPPORTED_VGPU_SCHEDULER_POLICY_COUNT = 3\n\nNVML_SCHEDULER_SW_MAX_LOG_ENTRIES = 200\n\nNVML_VGPU_SCHEDULER_ARR_DEFAULT = 0\nNVML_VGPU_SCHEDULER_ARR_DISABLE = 1\nNVML_VGPU_SCHEDULER_ARR_ENABLE = 2\n\n\nclass c_nvmlVgpuSchedDataWithARR_t(_PrintableStructure):\n    _fields_ = [\n        (\"avgFactor\", c_uint),\n        (\"timeslice\", c_uint),\n    ]\n\n\nclass c_nvmlVgpuSchedData_t(_PrintableStructure):\n    _fields_ = [\n        (\"timeslice\", c_uint),\n    ]\n\n\nclass c_nvmlVgpuSchedulerParams_t(Union):\n    _fields_ = [\n        (\"vgpuSchedDataWithARR\", c_nvmlVgpuSchedDataWithARR_t),\n        (\"vgpuSchedData\", c_nvmlVgpuSchedData_t),\n    ]\n\n\nclass c_nvmlVgpuSchedulerLogEntry_t(_PrintableStructure):\n    _fields_ = [\n        (\"timestamp\", c_ulonglong),\n        (\"timeRunTotal\", c_ulonglong),\n        (\"timeRun\", c_ulonglong),\n        (\"swRunlistId\", c_uint),\n        (\"targetTimeSlice\", c_ulonglong),\n        (\"cumulativePreemptionTime\", 
c_ulonglong),\n    ]\n\n\nclass c_nvmlVgpuSchedulerLog_t(_PrintableStructure):\n    _fields_ = [\n        (\"engineId\", c_uint),\n        (\"schedulerPolicy\", c_uint),\n        (\"arrMode\", c_uint),\n        (\"schedulerParams\", c_nvmlVgpuSchedulerParams_t),\n        (\"entriesCount\", c_uint),\n        (\n            \"logEntries\",\n            c_nvmlVgpuSchedulerLogEntry_t * NVML_SCHEDULER_SW_MAX_LOG_ENTRIES,\n        ),\n    ]\n\n\nclass c_nvmlVgpuSchedulerGetState_t(_PrintableStructure):\n    _fields_ = [\n        (\"schedulerPolicy\", c_uint),\n        (\"arrMode\", c_uint),\n        (\"schedulerParams\", c_nvmlVgpuSchedulerParams_t),\n    ]\n\n\nclass c_nvmlVgpuSchedSetDataWithARR_t(_PrintableStructure):\n    _fields_ = [\n        (\"avgFactor\", c_uint),\n        (\"frequency\", c_uint),\n    ]\n\n\nclass c_nvmlVgpuSchedSetData_t(_PrintableStructure):\n    _fields_ = [\n        (\"timeslice\", c_uint),\n    ]\n\n\nclass c_nvmlVgpuSchedulerSetParams_t(Union):\n    _fields_ = [\n        (\"vgpuSchedDataWithARR\", c_nvmlVgpuSchedSetDataWithARR_t),\n        (\"vgpuSchedData\", c_nvmlVgpuSchedSetData_t),\n    ]\n\n\nclass c_nvmlVgpuSchedulerSetState_t(_PrintableStructure):\n    _fields_ = [\n        (\"schedulerPolicy\", c_uint),\n        (\"enableARRMode\", c_uint),\n        (\"schedulerParams\", c_nvmlVgpuSchedulerSetParams_t),\n    ]\n\n\nclass c_nvmlVgpuSchedulerCapabilities_t(_PrintableStructure):\n    _fields_ = [\n        (\"supportedSchedulers\", c_uint * NVML_SUPPORTED_VGPU_SCHEDULER_POLICY_COUNT),\n        (\"maxTimeslice\", c_uint),\n        (\"minTimeslice\", c_uint),\n        (\"isArrModeSupported\", c_uint),\n        (\"maxFrequencyForARR\", c_uint),\n        (\"minFrequencyForARR\", c_uint),\n        (\"maxAvgFactorForARR\", c_uint),\n        (\"minAvgFactorForARR\", c_uint),\n    ]\n\n\nclass c_nvmlFBCStats_t(Structure):\n    _fields_ = [\n        (\"sessionsCount\", c_uint),\n        (\"averageFPS\", c_uint),\n        (\"averageLatency\", 
c_uint),\n    ]\n\n\nclass c_nvmlFBCSession_t(_PrintableStructure):\n    _fields_ = [\n        (\"sessionId\", c_uint),\n        (\"pid\", c_uint),\n        (\"vgpuInstance\", _nvmlVgpuInstance_t),\n        (\"displayOrdinal\", c_uint),\n        (\"sessionType\", c_uint),\n        (\"sessionFlags\", c_uint),\n        (\"hMaxResolution\", c_uint),\n        (\"vMaxResolution\", c_uint),\n        (\"hResolution\", c_uint),\n        (\"vResolution\", c_uint),\n        (\"averageFPS\", c_uint),\n        (\"averageLatency\", c_uint),\n    ]\n\n\nNVML_DEVICE_MIG_DISABLE = 0x0\nNVML_DEVICE_MIG_ENABLE = 0x1\n\nNVML_GPU_INSTANCE_PROFILE_1_SLICE = 0x0\nNVML_GPU_INSTANCE_PROFILE_2_SLICE = 0x1\nNVML_GPU_INSTANCE_PROFILE_3_SLICE = 0x2\nNVML_GPU_INSTANCE_PROFILE_4_SLICE = 0x3\nNVML_GPU_INSTANCE_PROFILE_7_SLICE = 0x4\nNVML_GPU_INSTANCE_PROFILE_8_SLICE = 0x5\nNVML_GPU_INSTANCE_PROFILE_6_SLICE = 0x6\nNVML_GPU_INSTANCE_PROFILE_1_SLICE_REV1 = 0x7\nNVML_GPU_INSTANCE_PROFILE_2_SLICE_REV1 = 0x8\nNVML_GPU_INSTANCE_PROFILE_1_SLICE_REV2 = 0x9\nNVML_GPU_INSTANCE_PROFILE_1_SLICE_GFX = 0xA\nNVML_GPU_INSTANCE_PROFILE_2_SLICE_GFX = 0xB\nNVML_GPU_INSTANCE_PROFILE_4_SLICE_GFX = 0xC\nNVML_GPU_INSTANCE_PROFILE_COUNT = 0xD\n\n\nclass c_nvmlGpuInstancePlacement_t(Structure):\n    _fields_ = [(\"start\", c_uint), (\"size\", c_uint)]\n\n\nclass c_nvmlGpuInstanceProfileInfo_t(Structure):\n    _fields_ = [\n        (\"id\", c_uint),\n        (\"isP2pSupported\", c_uint),\n        (\"sliceCount\", c_uint),\n        (\"instanceCount\", c_uint),\n        (\"multiprocessorCount\", c_uint),\n        (\"copyEngineCount\", c_uint),\n        (\"decoderCount\", c_uint),\n        (\"encoderCount\", c_uint),\n        (\"jpegCount\", c_uint),\n        (\"ofaCount\", c_uint),\n        (\"memorySizeMB\", c_ulonglong),\n    ]\n\n\nnvmlGpuInstanceProfileInfo_v2 = 0x02000098\n\n\nclass c_nvmlGpuInstanceProfileInfo_v2_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"id\", c_uint),\n     
   (\"isP2pSupported\", c_uint),\n        (\"sliceCount\", c_uint),\n        (\"instanceCount\", c_uint),\n        (\"multiprocessorCount\", c_uint),\n        (\"copyEngineCount\", c_uint),\n        (\"decoderCount\", c_uint),\n        (\"encoderCount\", c_uint),\n        (\"jpegCount\", c_uint),\n        (\"ofaCount\", c_uint),\n        (\"memorySizeMB\", c_ulonglong),\n        (\"name\", c_char * NVML_DEVICE_NAME_V2_BUFFER_SIZE),\n    ]\n\n    def __init__(self):\n        super(c_nvmlGpuInstanceProfileInfo_v2_t, self).__init__(\n            version=nvmlGpuInstanceProfileInfo_v2\n        )\n\n\nclass c_nvmlGpuInstanceInfo_t(Structure):\n    _fields_ = [\n        (\"device\", c_nvmlDevice_t),\n        (\"id\", c_uint),\n        (\"profileId\", c_uint),\n        (\"placement\", c_nvmlGpuInstancePlacement_t),\n    ]\n\n\nclass struct_c_nvmlGpuInstance_t(Structure):\n    pass  # opaque handle\n\n\nc_nvmlGpuInstance_t = POINTER(struct_c_nvmlGpuInstance_t)\n\nNVML_COMPUTE_INSTANCE_PROFILE_1_SLICE = 0x0\nNVML_COMPUTE_INSTANCE_PROFILE_2_SLICE = 0x1\nNVML_COMPUTE_INSTANCE_PROFILE_3_SLICE = 0x2\nNVML_COMPUTE_INSTANCE_PROFILE_4_SLICE = 0x3\nNVML_COMPUTE_INSTANCE_PROFILE_7_SLICE = 0x4\nNVML_COMPUTE_INSTANCE_PROFILE_8_SLICE = 0x5\nNVML_COMPUTE_INSTANCE_PROFILE_6_SLICE = 0x6\nNVML_COMPUTE_INSTANCE_PROFILE_1_SLICE_REV1 = 0x7\nNVML_COMPUTE_INSTANCE_PROFILE_COUNT = 0x8\n\nNVML_COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED = 0x0\nNVML_COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT = 0x1\n\n\nclass c_nvmlComputeInstancePlacement_t(Structure):\n    _fields_ = [(\"start\", c_uint), (\"size\", c_uint)]\n\n\nclass c_nvmlComputeInstanceProfileInfo_t(Structure):\n    _fields_ = [\n        (\"id\", c_uint),\n        (\"sliceCount\", c_uint),\n        (\"instanceCount\", c_uint),\n        (\"multiprocessorCount\", c_uint),\n        (\"sharedCopyEngineCount\", c_uint),\n        (\"sharedDecoderCount\", c_uint),\n        (\"sharedEncoderCount\", c_uint),\n        (\"sharedJpegCount\", c_uint),\n        
(\"sharedOfaCount\", c_uint),\n    ]\n\n\nnvmlComputeInstanceProfileInfo_v2 = 0x02000088\n\n\nclass c_nvmlComputeInstanceProfileInfo_v2_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"id\", c_uint),\n        (\"sliceCount\", c_uint),\n        (\"instanceCount\", c_uint),\n        (\"multiprocessorCount\", c_uint),\n        (\"sharedCopyEngineCount\", c_uint),\n        (\"sharedDecoderCount\", c_uint),\n        (\"sharedEncoderCount\", c_uint),\n        (\"sharedJpegCount\", c_uint),\n        (\"sharedOfaCount\", c_uint),\n        (\"name\", c_char * NVML_DEVICE_NAME_V2_BUFFER_SIZE),\n    ]\n\n    def __init__(self):\n        super(c_nvmlComputeInstanceProfileInfo_v2_t, self).__init__(\n            version=nvmlComputeInstanceProfileInfo_v2\n        )\n\n\nclass c_nvmlComputeInstanceInfo_t(Structure):\n    _fields_ = [\n        (\"device\", c_nvmlDevice_t),\n        (\"gpuInstance\", c_nvmlGpuInstance_t),\n        (\"id\", c_uint),\n        (\"profileId\", c_uint),\n        (\"placement\", c_nvmlComputeInstancePlacement_t),\n    ]\n\n\nNVML_MAX_GPU_UTILIZATIONS = 8\nNVML_GPU_UTILIZATION_DOMAIN_GPU = 0\nNVML_GPU_UTILIZATION_DOMAIN_FB = 1\nNVML_GPU_UTILIZATION_DOMAIN_VID = 2\nNVML_GPU_UTILIZATION_DOMAIN_BUS = 3\n\n\nclass c_nvmlGpuDynamicPstatesUtilization_t(Structure):\n    _fields_ = [\n        (\"bIsPresent\", c_uint, 1),\n        (\"percentage\", c_uint),\n        (\"incThreshold\", c_uint),\n        (\"decThreshold\", c_uint),\n    ]\n\n\nclass c_nvmlGpuDynamicPstatesInfo_t(Structure):\n    _fields_ = [\n        (\"flags\", c_uint),\n        (\n            \"utilization\",\n            c_nvmlGpuDynamicPstatesUtilization_t * NVML_MAX_GPU_UTILIZATIONS,\n        ),\n    ]\n\n\nNVML_MAX_THERMAL_SENSORS_PER_GPU = 3\n\nNVML_THERMAL_TARGET_NONE = 0\nNVML_THERMAL_TARGET_GPU = 1\nNVML_THERMAL_TARGET_MEMORY = 2\nNVML_THERMAL_TARGET_POWER_SUPPLY = 4\nNVML_THERMAL_TARGET_BOARD = 8\nNVML_THERMAL_TARGET_VCD_BOARD = 
9\nNVML_THERMAL_TARGET_VCD_INLET = 10\nNVML_THERMAL_TARGET_VCD_OUTLET = 11\nNVML_THERMAL_TARGET_ALL = 15\nNVML_THERMAL_TARGET_UNKNOWN = -1\n\nNVML_THERMAL_CONTROLLER_NONE = 0\nNVML_THERMAL_CONTROLLER_GPU_INTERNAL = 1\nNVML_THERMAL_CONTROLLER_ADM1032 = 2\nNVML_THERMAL_CONTROLLER_ADT7461 = 3\nNVML_THERMAL_CONTROLLER_MAX6649 = 4\nNVML_THERMAL_CONTROLLER_MAX1617 = 5\nNVML_THERMAL_CONTROLLER_LM99 = 6\nNVML_THERMAL_CONTROLLER_LM89 = 7\nNVML_THERMAL_CONTROLLER_LM64 = 8\nNVML_THERMAL_CONTROLLER_G781 = 9\nNVML_THERMAL_CONTROLLER_ADT7473 = 10\nNVML_THERMAL_CONTROLLER_SBMAX6649 = 11\nNVML_THERMAL_CONTROLLER_VBIOSEVT = 12\nNVML_THERMAL_CONTROLLER_OS = 13\nNVML_THERMAL_CONTROLLER_NVSYSCON_CANOAS = 14\nNVML_THERMAL_CONTROLLER_NVSYSCON_E551 = 15\nNVML_THERMAL_CONTROLLER_MAX6649R = 16\nNVML_THERMAL_CONTROLLER_ADT7473S = 17\nNVML_THERMAL_CONTROLLER_UNKNOWN = -1\n\n\nclass c_nvmlGpuThermalSensor_t(Structure):\n    _fields_ = [\n        (\"controller\", c_int),\n        (\"defaultMinTemp\", c_int),\n        (\"defaultMaxTemp\", c_int),\n        (\"currentTemp\", c_int),\n        (\"target\", c_int),\n    ]\n\n\nclass c_nvmlGpuThermalSettings_t(Structure):\n    _fields_ = [\n        (\"count\", c_uint),\n        (\"sensor\", c_nvmlGpuThermalSensor_t * NVML_MAX_THERMAL_SENSORS_PER_GPU),\n    ]\n\n\n_nvmlCoolerControl_t = c_uint\nNVML_THERMAL_COOLER_SIGNAL_NONE = 0\nNVML_THERMAL_COOLER_SIGNAL_TOGGLE = 1\nNVML_THERMAL_COOLER_SIGNAL_VARIABLE = 2\nNVML_THERMAL_COOLER_SIGNAL_COUNT = 3\n\n_nvmlCoolerTarget_t = c_uint\nNVML_THERMAL_COOLER_TARGET_NONE = 1 << 0\nNVML_THERMAL_COOLER_TARGET_GPU = 1 << 1\nNVML_THERMAL_COOLER_TARGET_MEMORY = 1 << 2\nNVML_THERMAL_COOLER_TARGET_POWER_SUPPLY = 1 << 3\nNVML_THERMAL_COOLER_TARGET_GPU_RELATED = (\n    NVML_THERMAL_COOLER_TARGET_GPU\n    | NVML_THERMAL_COOLER_TARGET_MEMORY\n    | NVML_THERMAL_COOLER_TARGET_POWER_SUPPLY\n)\n\n\nclass c_nvmlCoolerInfo_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"index\", c_uint),\n   
     (\"coolerControlType\", _nvmlCoolerControl_t),\n        (\"coolerTarget\", _nvmlCoolerTarget_t),\n    ]\n\n\nnvmlCoolerInfo_v1 = 0x1000010\n\n\ndef nvmlDeviceGetCoolerInfo(handle):\n    c_coolerInfo = c_nvmlCoolerInfo_t()\n    c_coolerInfo.version = nvmlCoolerInfo_v1\n    c_coolerInfo.index = 0\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetCoolerInfo\")\n    ret = fn(handle, byref(c_coolerInfo))\n    _nvmlCheckReturn(ret)\n    return [c_coolerInfo.coolerControlType, c_coolerInfo.coolerTarget]\n\n\nclass struct_c_nvmlComputeInstance_t(Structure):\n    pass  # opaque handle\n\n\nc_nvmlComputeInstance_t = POINTER(struct_c_nvmlComputeInstance_t)\n\n\nclass c_nvmlDeviceAttributes(Structure):\n    _fields_ = [\n        (\"multiprocessorCount\", c_uint),\n        (\"sharedCopyEngineCount\", c_uint),\n        (\"sharedDecoderCount\", c_uint),\n        (\"sharedEncoderCount\", c_uint),\n        (\"sharedJpegCount\", c_uint),\n        (\"sharedOfaCount\", c_uint),\n        (\"gpuInstanceSliceCount\", c_uint),\n        (\"computeInstanceSliceCount\", c_uint),\n        (\"memorySizeMB\", c_ulonglong),\n    ]\n\n\nclass c_nvmlRowRemapperHistogramValues(Structure):\n    _fields_ = [\n        (\"max\", c_uint),\n        (\"high\", c_uint),\n        (\"partial\", c_uint),\n        (\"low\", c_uint),\n        (\"none\", c_uint),\n    ]\n\n\nNVML_GPU_CERT_CHAIN_SIZE = 0x1000\nNVML_GPU_ATTESTATION_CERT_CHAIN_SIZE = 0x1400\nNVML_CC_GPU_CEC_NONCE_SIZE = 0x20\nNVML_CC_GPU_ATTESTATION_REPORT_SIZE = 0x2000\nNVML_CC_GPU_CEC_ATTESTATION_REPORT_SIZE = 0x1000\nNVML_CC_CEC_ATTESTATION_REPORT_NOT_PRESENT = 0\nNVML_CC_CEC_ATTESTATION_REPORT_PRESENT = 1\n\n\nclass c_nvmlConfComputeSystemState_t(Structure):\n    _fields_ = [\n        (\"environment\", c_uint),\n        (\"ccFeature\", c_uint),\n        (\"devToolsMode\", c_uint),\n    ]\n\n\nnvmlSystemConfComputeSettings_v1 = 0x1000014\n\n\nclass c_nvmlSystemConfComputeSettings_v1_t(Structure):\n    _fields_ = [\n        (\"version\", 
c_uint),\n        (\"environment\", c_uint),\n        (\"ccFeature\", c_uint),\n        (\"devToolsMode\", c_uint),\n        (\"multiGpuMode\", c_uint),\n    ]\n\n    def __init__(self):\n        super(c_nvmlSystemConfComputeSettings_v1_t, self).__init__(\n            version=nvmlSystemConfComputeSettings_v1\n        )\n\n\nclass c_nvmlConfComputeSystemCaps_t(Structure):\n    _fields_ = [\n        (\"cpuCaps\", c_uint),\n        (\"gpusCaps\", c_uint),\n    ]\n\n\nclass c_nvmlConfComputeMemSizeInfo_t(Structure):\n    _fields_ = [\n        (\"protectedMemSizeKib\", c_ulonglong),\n        (\"unprotectedMemSizeKib\", c_ulonglong),\n    ]\n\n\nclass c_nvmlConfComputeGpuCertificate_t(Structure):\n    _fields_ = [\n        (\"certChainSize\", c_uint),\n        (\"attestationCertChainSize\", c_uint),\n        (\"certChain\", c_uint8 * NVML_GPU_CERT_CHAIN_SIZE),\n        (\"attestationCertChain\", c_uint8 * NVML_GPU_ATTESTATION_CERT_CHAIN_SIZE),\n    ]\n\n\nclass c_nvmlConfComputeGpuAttestationReport_t(Structure):\n    _fields_ = [\n        (\"isCecAttestationReportPresent\", c_uint),\n        (\"attestationReportSize\", c_uint),\n        (\"cecAttestationReportSize\", c_uint),\n        (\"nonce\", c_uint8 * NVML_CC_GPU_CEC_NONCE_SIZE),\n        (\"attestationReport\", c_uint8 * NVML_CC_GPU_ATTESTATION_REPORT_SIZE),\n        (\"cecAttestationReport\", c_uint8 * NVML_CC_GPU_CEC_ATTESTATION_REPORT_SIZE),\n    ]\n\n\nclass c_nvmlConfComputeSetKeyRotationThresholdInfo_t(Structure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"maxAttackerAdvantage\", c_ulong),\n    ]\n\n\nConfComputeSetKeyRotationThresholdInfo_v1 = 0x1000010\n\n\nclass c_nvmlConfComputeGetKeyRotationThresholdInfo_t(Structure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"attackerAdvantage\", c_ulong),\n    ]\n\n\nConfComputeGetKeyRotationThresholdInfo_v1 = 0x1000010\n\n\n## string/bytes conversion for ease of use\ndef convertStrBytes(func):\n    \"\"\"\n    In python 3, strings 
are unicode instead of bytes, and need to be converted for ctypes\n    Args from caller: (1, 'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF>)\n    Args passed to function: (1, b'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF>)\n    ----\n    Returned from function: b'returned string'\n    Returned to caller: 'returned string'\n    \"\"\"\n\n    @wraps(func)\n    def wrapper(*args, **kwargs):\n        # encoding a str returns bytes in python 2 and 3\n        args = [arg.encode() if isinstance(arg, str) else arg for arg in args]\n        res = func(*args, **kwargs)\n        # In python 2, str and bytes are the same\n        # In python 3, str is unicode and should be decoded.\n        # Ctypes handles most conversions; this only affects c_char and char arrays.\n        if isinstance(res, bytes):\n            if isinstance(res, str):\n                return res\n            return res.decode()\n        return res\n\n    if sys.version_info >= (3,):\n        return wrapper\n    return func\n\n\ndef throwOnVersionMismatch(func):\n    @wraps(func)\n    def wrapper(*args, **kwargs):\n        try:\n            return func(*args, **kwargs)\n        except NVMLError_FunctionNotFound:\n            raise NVMLLibraryMismatchError(\n                \"Unversioned function called and the \"\n                \"pyNVML version does not match the NVML lib version. 
\"\n                \"Either use matching pyNVML and NVML lib versions or \"\n                \"use a versioned function such as \" + func.__name__ + \"_v2\"\n            )\n\n    return wrapper\n\n\n## C function wrappers ##\ndef nvmlInitWithFlags(flags):\n    _LoadNvmlLibrary()\n\n    #\n    # Initialize the library\n    #\n    fn = _nvmlGetFunctionPointer(\"nvmlInitWithFlags\")\n    ret = fn(flags)\n    _nvmlCheckReturn(ret)\n\n    # Atomically update refcount\n    global _nvmlLib_refcount\n    libLoadLock.acquire()\n    _nvmlLib_refcount += 1\n    libLoadLock.release()\n    return None\n\n\ndef nvmlInit():\n    nvmlInitWithFlags(0)\n    return None\n\n\ndef _LoadNvmlLibrary():\n    \"\"\"\n    Load the library if it isn't loaded already\n    \"\"\"\n    global nvmlLib\n\n    if nvmlLib is None:\n        # lock to ensure only one caller loads the library\n        libLoadLock.acquire()\n\n        try:\n            # ensure the library still isn't loaded\n            if nvmlLib is None:\n                try:\n                    if sys.platform[:3] == \"win\":\n                        # cdecl calling convention\n                        try:\n                            # Check for nvml.dll in System32 first for DCH drivers\n                            nvmlLib = CDLL(\n                                os.path.join(\n                                    os.getenv(\"WINDIR\", \"C:/Windows\"),\n                                    \"System32/nvml.dll\",\n                                )\n                            )\n                        except OSError as ose:\n                            # If nvml.dll is not found in System32, it should be in ProgramFiles\n                            # load nvml.dll from %ProgramFiles%/NVIDIA Corporation/NVSMI/nvml.dll\n                            nvmlLib = CDLL(\n                                os.path.join(\n                                    os.getenv(\"ProgramFiles\", \"C:/Program Files\"),\n                                    
\"NVIDIA Corporation/NVSMI/nvml.dll\",\n                                )\n                            )\n                    else:\n                        # assume linux\n                        nvmlLib = CDLL(\"libnvidia-ml.so.1\")\n                except OSError as ose:\n                    _nvmlCheckReturn(NVML_ERROR_LIBRARY_NOT_FOUND)\n                if nvmlLib is None:\n                    _nvmlCheckReturn(NVML_ERROR_LIBRARY_NOT_FOUND)\n        finally:\n            # lock is always freed\n            libLoadLock.release()\n\n\ndef nvmlShutdown():\n    #\n    # Leave the library loaded, but shutdown the interface\n    #\n    fn = _nvmlGetFunctionPointer(\"nvmlShutdown\")\n    ret = fn()\n    _nvmlCheckReturn(ret)\n\n    # Atomically update refcount\n    global _nvmlLib_refcount\n    libLoadLock.acquire()\n    if 0 < _nvmlLib_refcount:\n        _nvmlLib_refcount -= 1\n    libLoadLock.release()\n    return None\n\n\n# Added in 2.285\n@convertStrBytes\ndef nvmlErrorString(result):\n    fn = _nvmlGetFunctionPointer(\"nvmlErrorString\")\n    fn.restype = c_char_p  # otherwise return is an int\n    ret = fn(result)\n    return ret\n\n\n# Added in 2.285\n@convertStrBytes\ndef nvmlSystemGetNVMLVersion():\n    c_version = create_string_buffer(NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemGetNVMLVersion\")\n    ret = fn(c_version, c_uint(NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE))\n    _nvmlCheckReturn(ret)\n    return c_version.value\n\n\ndef nvmlSystemGetCudaDriverVersion():\n    c_cuda_version = c_int()\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemGetCudaDriverVersion\")\n    ret = fn(byref(c_cuda_version))\n    _nvmlCheckReturn(ret)\n    return c_cuda_version.value\n\n\ndef nvmlSystemGetCudaDriverVersion_v2():\n    c_cuda_version = c_int()\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemGetCudaDriverVersion_v2\")\n    ret = fn(byref(c_cuda_version))\n    _nvmlCheckReturn(ret)\n    return c_cuda_version.value\n\n\n# Added in 
2.285\n@convertStrBytes\ndef nvmlSystemGetProcessName(pid):\n    c_name = create_string_buffer(1024)\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemGetProcessName\")\n    ret = fn(c_uint(pid), c_name, c_uint(1024))\n    _nvmlCheckReturn(ret)\n    return c_name.value\n\n\n@convertStrBytes\ndef nvmlSystemGetDriverVersion():\n    c_version = create_string_buffer(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemGetDriverVersion\")\n    ret = fn(c_version, c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE))\n    _nvmlCheckReturn(ret)\n    return c_version.value\n\n\n# Added in 2.285\ndef nvmlSystemGetHicVersion():\n    c_count = c_uint(0)\n    hics = None\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemGetHicVersion\")\n\n    # get the count\n    ret = fn(byref(c_count), None)\n\n    # this should only fail with insufficient size\n    if (ret != NVML_SUCCESS) and (ret != NVML_ERROR_INSUFFICIENT_SIZE):\n        raise NVMLError(ret)\n\n    # If there are no hics\n    if c_count.value == 0:\n        return []\n\n    hic_array = c_nvmlHwbcEntry_t * c_count.value\n    hics = hic_array()\n    ret = fn(byref(c_count), hics)\n    _nvmlCheckReturn(ret)\n    return hics\n\n\ndef nvmlSystemGetDriverBranch():\n    c_branchInfo = c_nvmlSystemDriverBranchInfo_v1_t(0)\n    c_branchInfo.version = SystemDriverBranchInfo_v1\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemGetDriverBranch\")\n    ret = fn(byref(c_branchInfo), c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE))\n    _nvmlCheckReturn(ret)\n    return c_branchInfo\n\n\n## Unit get functions\ndef nvmlUnitGetCount():\n    c_count = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlUnitGetCount\")\n    ret = fn(byref(c_count))\n    _nvmlCheckReturn(ret)\n    return c_count.value\n\n\ndef nvmlUnitGetHandleByIndex(index):\n    c_index = c_uint(index)\n    unit = c_nvmlUnit_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlUnitGetHandleByIndex\")\n    ret = fn(c_index, byref(unit))\n    _nvmlCheckReturn(ret)\n    
return unit\n\n\ndef nvmlUnitGetUnitInfo(unit):\n    c_info = c_nvmlUnitInfo_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlUnitGetUnitInfo\")\n    ret = fn(unit, byref(c_info))\n    _nvmlCheckReturn(ret)\n    return c_info\n\n\ndef nvmlUnitGetLedState(unit):\n    c_state = c_nvmlLedState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlUnitGetLedState\")\n    ret = fn(unit, byref(c_state))\n    _nvmlCheckReturn(ret)\n    return c_state\n\n\ndef nvmlUnitGetPsuInfo(unit):\n    c_info = c_nvmlPSUInfo_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlUnitGetPsuInfo\")\n    ret = fn(unit, byref(c_info))\n    _nvmlCheckReturn(ret)\n    return c_info\n\n\ndef nvmlUnitGetTemperature(unit, type):\n    c_temp = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlUnitGetTemperature\")\n    ret = fn(unit, c_uint(type), byref(c_temp))\n    _nvmlCheckReturn(ret)\n    return c_temp.value\n\n\ndef nvmlUnitGetFanSpeedInfo(unit):\n    c_speeds = c_nvmlUnitFanSpeeds_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlUnitGetFanSpeedInfo\")\n    ret = fn(unit, byref(c_speeds))\n    _nvmlCheckReturn(ret)\n    return c_speeds\n\n\n# added to API\ndef nvmlUnitGetDeviceCount(unit):\n    c_count = c_uint(0)\n    # query the unit to determine device count\n    fn = _nvmlGetFunctionPointer(\"nvmlUnitGetDevices\")\n    ret = fn(unit, byref(c_count), None)\n    if ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        ret = NVML_SUCCESS\n    _nvmlCheckReturn(ret)\n    return c_count.value\n\n\ndef nvmlUnitGetDevices(unit):\n    c_count = c_uint(nvmlUnitGetDeviceCount(unit))\n    device_array = c_nvmlDevice_t * c_count.value\n    c_devices = device_array()\n    fn = _nvmlGetFunctionPointer(\"nvmlUnitGetDevices\")\n    ret = fn(unit, byref(c_count), c_devices)\n    _nvmlCheckReturn(ret)\n    return c_devices\n\n\n## Device get functions\ndef nvmlDeviceGetCount():\n    c_count = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetCount_v2\")\n    ret = fn(byref(c_count))\n    _nvmlCheckReturn(ret)\n    return 
c_count.value\n\n\ndef nvmlDeviceGetHandleByIndex(index):\n    c_index = c_uint(index)\n    device = c_nvmlDevice_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetHandleByIndex_v2\")\n    ret = fn(c_index, byref(device))\n    _nvmlCheckReturn(ret)\n    return device\n\n\n@convertStrBytes\ndef nvmlDeviceGetHandleBySerial(serial):\n    c_serial = c_char_p(serial)\n    device = c_nvmlDevice_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetHandleBySerial\")\n    ret = fn(c_serial, byref(device))\n    _nvmlCheckReturn(ret)\n    return device\n\n\n@convertStrBytes\ndef nvmlDeviceGetHandleByUUID(uuid):\n    c_uuid = c_char_p(uuid)\n    device = c_nvmlDevice_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetHandleByUUID\")\n    ret = fn(c_uuid, byref(device))\n    _nvmlCheckReturn(ret)\n    return device\n\n\n@convertStrBytes\ndef nvmlDeviceGetHandleByPciBusId(pciBusId):\n    c_busId = c_char_p(pciBusId)\n    device = c_nvmlDevice_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetHandleByPciBusId_v2\")\n    ret = fn(c_busId, byref(device))\n    _nvmlCheckReturn(ret)\n    return device\n\n\n@convertStrBytes\ndef nvmlDeviceGetName(handle):\n    c_name = create_string_buffer(NVML_DEVICE_NAME_V2_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetName\")\n    ret = fn(handle, c_name, c_uint(NVML_DEVICE_NAME_V2_BUFFER_SIZE))\n    _nvmlCheckReturn(ret)\n    return c_name.value\n\n\nclass c_nvmlDevicePerfModes_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"str\", c_char * NVML_PERF_MODES_BUFFER_SIZE),\n    ]\n\n\nnvmlDevicePerfModes_v1 = 0x1000804\n\n\n@convertStrBytes\ndef nvmlDeviceGetPerformanceModes(handle):\n    perfModes = c_nvmlDevicePerfModes_v1_t()\n    perfModes.version = nvmlDevicePerfModes_v1\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPerformanceModes\")\n    ret = fn(handle, byref(perfModes))\n    _nvmlCheckReturn(ret)\n    return perfModes.str\n\n\nclass 
c_nvmlDeviceCurrentClockFreqs_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"str\", c_char * NVML_PERF_MODES_BUFFER_SIZE),\n    ]\n\n\nnvmlDeviceCurrentClockFreqs_v1 = 0x1000804\n\n\n@convertStrBytes\ndef nvmlDeviceGetCurrentClockFreqs(handle):\n    currentClockFreqs = c_nvmlDeviceCurrentClockFreqs_v1_t()\n    currentClockFreqs.version = nvmlDeviceCurrentClockFreqs_v1\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetCurrentClockFreqs\")\n    ret = fn(handle, byref(currentClockFreqs))\n    _nvmlCheckReturn(ret)\n    return currentClockFreqs.str\n\n\ndef nvmlDeviceGetBoardId(handle):\n    c_id = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetBoardId\")\n    ret = fn(handle, byref(c_id))\n    _nvmlCheckReturn(ret)\n    return c_id.value\n\n\ndef nvmlDeviceGetMultiGpuBoard(handle):\n    c_multiGpu = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMultiGpuBoard\")\n    ret = fn(handle, byref(c_multiGpu))\n    _nvmlCheckReturn(ret)\n    return c_multiGpu.value\n\n\ndef nvmlDeviceGetBrand(handle):\n    c_type = _nvmlBrandType_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetBrand\")\n    ret = fn(handle, byref(c_type))\n    _nvmlCheckReturn(ret)\n    return c_type.value\n\n\ndef nvmlDeviceGetC2cModeInfoV1(handle):\n    c_info = c_nvmlC2cModeInfo_v1_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetC2cModeInfoV\")\n    ret = fn(handle, byref(c_info))\n    _nvmlCheckReturn(ret)\n    return c_info\n\n\ndef nvmlDeviceGetC2cModeInfoV(handle):\n    return nvmlDeviceGetC2cModeInfoV1(handle)\n\n\n@convertStrBytes\ndef nvmlDeviceGetBoardPartNumber(handle):\n    c_part_number = create_string_buffer(NVML_DEVICE_PART_NUMBER_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetBoardPartNumber\")\n    ret = fn(handle, c_part_number, c_uint(NVML_DEVICE_PART_NUMBER_BUFFER_SIZE))\n    _nvmlCheckReturn(ret)\n    return c_part_number.value\n\n\n@convertStrBytes\ndef nvmlDeviceGetSerial(handle):\n    c_serial = 
create_string_buffer(NVML_DEVICE_SERIAL_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetSerial\")\n    ret = fn(handle, c_serial, c_uint(NVML_DEVICE_SERIAL_BUFFER_SIZE))\n    _nvmlCheckReturn(ret)\n    return c_serial.value\n\n\ndef nvmlDeviceGetModuleId(handle, moduleId=c_uint()):\n    isReference = type(moduleId) is not c_uint\n    moduleIdRef = moduleId if isReference else byref(moduleId)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetModuleId\")\n    ret = fn(handle, moduleIdRef)\n    if isReference:\n        return ret\n    else:\n        _nvmlCheckReturn(ret)\n        return moduleId.value\n\n\ndef nvmlDeviceGetMemoryAffinity(handle, nodeSetSize, scope):\n    affinity_array = c_ulonglong * nodeSetSize\n    c_affinity = affinity_array()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMemoryAffinity\")\n    ret = fn(handle, nodeSetSize, byref(c_affinity), _nvmlAffinityScope_t(scope))\n    _nvmlCheckReturn(ret)\n    return c_affinity\n\n\ndef nvmlDeviceGetCpuAffinityWithinScope(handle, cpuSetSize, scope):\n    affinity_array = c_ulonglong * cpuSetSize\n    c_affinity = affinity_array()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetCpuAffinityWithinScope\")\n    ret = fn(handle, cpuSetSize, byref(c_affinity), _nvmlAffinityScope_t(scope))\n    _nvmlCheckReturn(ret)\n    return c_affinity\n\n\ndef nvmlDeviceGetCpuAffinity(handle, cpuSetSize):\n    affinity_array = c_ulonglong * cpuSetSize\n    c_affinity = affinity_array()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetCpuAffinity\")\n    ret = fn(handle, cpuSetSize, byref(c_affinity))\n    _nvmlCheckReturn(ret)\n    return c_affinity\n\n\ndef nvmlDeviceSetCpuAffinity(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetCpuAffinity\")\n    ret = fn(handle)\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceClearCpuAffinity(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceClearCpuAffinity\")\n    ret = fn(handle)\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef 
nvmlDeviceGetNumaNodeId(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetNumaNodeId\")\n    node = c_int()\n    ret = fn(handle, byref(node))\n    _nvmlCheckReturn(ret)\n    return node.value\n\n\ndef nvmlDeviceGetMinorNumber(handle):\n    c_minor_number = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMinorNumber\")\n    ret = fn(handle, byref(c_minor_number))\n    _nvmlCheckReturn(ret)\n    return c_minor_number.value\n\n\n@convertStrBytes\ndef nvmlDeviceGetUUID(handle):\n    c_uuid = create_string_buffer(NVML_DEVICE_UUID_V2_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetUUID\")\n    ret = fn(handle, c_uuid, c_uint(NVML_DEVICE_UUID_V2_BUFFER_SIZE))\n    _nvmlCheckReturn(ret)\n    return c_uuid.value\n\n\n@convertStrBytes\ndef nvmlDeviceGetInforomVersion(handle, infoRomObject):\n    c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetInforomVersion\")\n    ret = fn(\n        handle,\n        _nvmlInforomObject_t(infoRomObject),\n        c_version,\n        c_uint(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE),\n    )\n    _nvmlCheckReturn(ret)\n    return c_version.value\n\n\n# Added in 4.304\n@convertStrBytes\ndef nvmlDeviceGetInforomImageVersion(handle):\n    c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetInforomImageVersion\")\n    ret = fn(handle, c_version, c_uint(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE))\n    _nvmlCheckReturn(ret)\n    return c_version.value\n\n\n# Added in 4.304\ndef nvmlDeviceGetInforomConfigurationChecksum(handle):\n    c_checksum = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetInforomConfigurationChecksum\")\n    ret = fn(handle, byref(c_checksum))\n    _nvmlCheckReturn(ret)\n    return c_checksum.value\n\n\n# Added in 4.304\ndef nvmlDeviceValidateInforom(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceValidateInforom\")\n    ret = fn(handle)\n   
 _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceGetLastBBXFlushTime(handle):\n    c_timestamp = c_ulonglong()\n    c_durationUs = c_ulong()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetLastBBXFlushTime\")\n    ret = fn(handle, byref(c_timestamp), byref(c_durationUs))\n    _nvmlCheckReturn(ret)\n    return [c_timestamp.value, c_durationUs.value]\n\n\ndef nvmlDeviceGetDisplayMode(handle):\n    c_mode = _nvmlEnableState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetDisplayMode\")\n    ret = fn(handle, byref(c_mode))\n    _nvmlCheckReturn(ret)\n    return c_mode.value\n\n\ndef nvmlDeviceGetDisplayActive(handle):\n    c_mode = _nvmlEnableState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetDisplayActive\")\n    ret = fn(handle, byref(c_mode))\n    _nvmlCheckReturn(ret)\n    return c_mode.value\n\n\ndef nvmlDeviceGetPersistenceMode(handle):\n    c_state = _nvmlEnableState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPersistenceMode\")\n    ret = fn(handle, byref(c_state))\n    _nvmlCheckReturn(ret)\n    return c_state.value\n\n\ndef nvmlDeviceGetPciInfoExt(handle, c_info):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPciInfoExt\")\n    ret = fn(handle, c_info)\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceGetPciInfo_v3(handle):\n    c_info = nvmlPciInfo_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPciInfo_v3\")\n    ret = fn(handle, byref(c_info))\n    _nvmlCheckReturn(ret)\n    return c_info\n\n\ndef nvmlDeviceGetPciInfo(handle):\n    return nvmlDeviceGetPciInfo_v3(handle)\n\n\ndef nvmlDeviceGetClockInfo(handle, type):\n    c_clock = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetClockInfo\")\n    ret = fn(handle, _nvmlClockType_t(type), byref(c_clock))\n    _nvmlCheckReturn(ret)\n    return c_clock.value\n\n\n# Added in 2.285\ndef nvmlDeviceGetMaxClockInfo(handle, type):\n    c_clock = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMaxClockInfo\")\n    ret = fn(handle, 
_nvmlClockType_t(type), byref(c_clock))\n    _nvmlCheckReturn(ret)\n    return c_clock.value\n\n\n# Added in 4.304\ndef nvmlDeviceGetApplicationsClock(handle, type):\n    c_clock = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetApplicationsClock\")\n    ret = fn(handle, _nvmlClockType_t(type), byref(c_clock))\n    _nvmlCheckReturn(ret)\n    return c_clock.value\n\n\ndef nvmlDeviceGetMaxCustomerBoostClock(handle, type):\n    c_clock = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMaxCustomerBoostClock\")\n    ret = fn(handle, _nvmlClockType_t(type), byref(c_clock))\n    _nvmlCheckReturn(ret)\n    return c_clock.value\n\n\ndef nvmlDeviceGetClock(handle, type, id):\n    c_clock = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetClock\")\n    ret = fn(handle, _nvmlClockType_t(type), _nvmlClockId_t(id), byref(c_clock))\n    _nvmlCheckReturn(ret)\n    return c_clock.value\n\n\n# Added in 5.319\ndef nvmlDeviceGetDefaultApplicationsClock(handle, type):\n    c_clock = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetDefaultApplicationsClock\")\n    ret = fn(handle, _nvmlClockType_t(type), byref(c_clock))\n    _nvmlCheckReturn(ret)\n    return c_clock.value\n\n\n# Added in 4.304\ndef nvmlDeviceGetSupportedMemoryClocks(handle):\n    # first call to get the size\n    c_count = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetSupportedMemoryClocks\")\n    ret = fn(handle, byref(c_count), None)\n\n    if ret == NVML_SUCCESS:\n        # special case, no clocks\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        clocks_array = c_uint * c_count.value\n        c_clocks = clocks_array()\n\n        # make the call again\n        ret = fn(handle, byref(c_count), c_clocks)\n        _nvmlCheckReturn(ret)\n\n        procs = []\n        for i in range(c_count.value):\n            procs.append(c_clocks[i])\n\n        return procs\n    else:\n        # error case\n        raise 
NVMLError(ret)\n\n\n# Added in 4.304\ndef nvmlDeviceGetSupportedGraphicsClocks(handle, memoryClockMHz):\n    # first call to get the size\n    c_count = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetSupportedGraphicsClocks\")\n    ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), None)\n\n    if ret == NVML_SUCCESS:\n        # special case, no clocks\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        clocks_array = c_uint * c_count.value\n        c_clocks = clocks_array()\n\n        # make the call again\n        ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), c_clocks)\n        _nvmlCheckReturn(ret)\n\n        procs = []\n        for i in range(c_count.value):\n            procs.append(c_clocks[i])\n\n        return procs\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlDeviceGetFanSpeed(handle):\n    c_speed = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetFanSpeed\")\n    ret = fn(handle, byref(c_speed))\n    _nvmlCheckReturn(ret)\n    return c_speed.value\n\n\ndef nvmlDeviceGetFanSpeed_v2(handle, fan):\n    c_speed = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetFanSpeed_v2\")\n    ret = fn(handle, fan, byref(c_speed))\n    _nvmlCheckReturn(ret)\n    return c_speed.value\n\n\nclass c_nvmlFanSpeedInfo_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"fan\", c_uint),\n        (\"speed\", c_uint),\n    ]\n\n\nnvmlFanSpeedInfo_v1 = 0x100000C\n\n\ndef nvmlDeviceGetFanSpeedRPM(handle):\n    c_fanSpeed = c_nvmlFanSpeedInfo_t()\n    c_fanSpeed.fan = 0\n    c_fanSpeed.version = nvmlFanSpeedInfo_v1\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetFanSpeedRPM\")\n    ret = fn(handle, byref(c_fanSpeed))\n    _nvmlCheckReturn(ret)\n    return c_fanSpeed.speed\n\n\ndef nvmlDeviceGetTargetFanSpeed(handle, fan):\n    c_speed = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetTargetFanSpeed\")\n    ret = 
fn(handle, fan, byref(c_speed))\n    _nvmlCheckReturn(ret)\n    return c_speed.value\n\n\ndef nvmlDeviceGetNumFans(device):\n    c_numFans = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetNumFans\")\n    ret = fn(device, byref(c_numFans))\n    _nvmlCheckReturn(ret)\n    return c_numFans.value\n\n\ndef nvmlDeviceSetDefaultFanSpeed_v2(handle, index):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetDefaultFanSpeed_v2\")\n    ret = fn(handle, index)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceGetMinMaxFanSpeed(handle, minSpeed=c_uint(), maxSpeed=c_uint()):\n    isReference = (type(minSpeed) is not c_uint) or (type(maxSpeed) is not c_uint)\n    minSpeedRef = minSpeed if isReference else byref(minSpeed)\n    maxSpeedRef = maxSpeed if isReference else byref(maxSpeed)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMinMaxFanSpeed\")\n    ret = fn(handle, minSpeedRef, maxSpeedRef)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS if isReference else [minSpeed.value, maxSpeed.value]\n\n\ndef nvmlDeviceGetFanControlPolicy_v2(handle, fan, fanControlPolicy=c_uint()):\n    isReference = type(fanControlPolicy) is not c_uint\n    fanControlPolicyRef = fanControlPolicy if isReference else byref(fanControlPolicy)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetFanControlPolicy_v2\")\n    ret = fn(handle, fan, fanControlPolicyRef)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS if isReference else fanControlPolicy.value\n\n\ndef nvmlDeviceSetFanControlPolicy(handle, fan, fanControlPolicy):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetFanControlPolicy\")\n    ret = fn(handle, fan, _nvmlFanControlPolicy_t(fanControlPolicy))\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\nclass c_nvmlTemperature_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"sensorType\", _nvmlTemperatureSensors_t),\n        (\"temperature\", c_int),\n    ]\n\n\nnvmlTemperature_v1 = 0x100000C\n\n\ndef 
nvmlDeviceGetTemperatureV1(handle, sensor):\n    c_temp = c_nvmlTemperature_v1_t()\n    c_temp.version = nvmlTemperature_v1\n    c_temp.sensorType = _nvmlTemperatureSensors_t(sensor)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetTemperatureV\")\n    ret = fn(handle, byref(c_temp))\n    _nvmlCheckReturn(ret)\n    return c_temp.temperature\n\n\ndef nvmlDeviceGetTemperatureV(handle, sensor, version=nvmlTemperature_v1):\n    if version == nvmlTemperature_v1:\n        return nvmlDeviceGetTemperatureV1(handle, sensor)\n    else:\n        raise NVMLError(NVML_ERROR_ARGUMENT_VERSION_MISMATCH)\n\n\n# DEPRECATED use nvmlDeviceGetTemperatureV instead\ndef nvmlDeviceGetTemperature(handle, sensor):\n    c_temp = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetTemperature\")\n    ret = fn(handle, _nvmlTemperatureSensors_t(sensor), byref(c_temp))\n    _nvmlCheckReturn(ret)\n    return c_temp.value\n\n\ndef nvmlDeviceGetTemperatureThreshold(handle, threshold):\n    c_temp = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetTemperatureThreshold\")\n    ret = fn(handle, _nvmlTemperatureThresholds_t(threshold), byref(c_temp))\n    _nvmlCheckReturn(ret)\n    return c_temp.value\n\n\ndef nvmlDeviceSetTemperatureThreshold(handle, threshold, temp):\n    c_temp = c_uint()\n    c_temp.value = temp\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetTemperatureThreshold\")\n    ret = fn(handle, _nvmlTemperatureThresholds_t(threshold), byref(c_temp))\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceGetMarginTemperature(handle):\n    c_marginTempInfo = c_nvmlMarginTemperature_v1_t()\n    c_marginTempInfo.version = nvmlMarginTemperature_v1\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMarginTemperature\")\n    ret = fn(handle, byref(c_marginTempInfo))\n    _nvmlCheckReturn(ret)\n    return c_marginTempInfo.marginTemperature\n\n\n# DEPRECATED use nvmlDeviceGetPerformanceState\ndef nvmlDeviceGetPowerState(handle):\n    c_pstate = _nvmlPstates_t()\n    fn = 
_nvmlGetFunctionPointer(\"nvmlDeviceGetPowerState\")\n    ret = fn(handle, byref(c_pstate))\n    _nvmlCheckReturn(ret)\n    return c_pstate.value\n\n\ndef nvmlDeviceGetPerformanceState(handle):\n    c_pstate = _nvmlPstates_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPerformanceState\")\n    ret = fn(handle, byref(c_pstate))\n    _nvmlCheckReturn(ret)\n    return c_pstate.value\n\n\ndef nvmlDeviceGetPowerManagementMode(handle):\n    c_pcapMode = _nvmlEnableState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPowerManagementMode\")\n    ret = fn(handle, byref(c_pcapMode))\n    _nvmlCheckReturn(ret)\n    return c_pcapMode.value\n\n\ndef nvmlDeviceGetPowerManagementLimit(handle):\n    c_limit = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPowerManagementLimit\")\n    ret = fn(handle, byref(c_limit))\n    _nvmlCheckReturn(ret)\n    return c_limit.value\n\n\n# Added in 4.304\ndef nvmlDeviceGetPowerManagementLimitConstraints(handle):\n    c_minLimit = c_uint()\n    c_maxLimit = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPowerManagementLimitConstraints\")\n    ret = fn(handle, byref(c_minLimit), byref(c_maxLimit))\n    _nvmlCheckReturn(ret)\n    return [c_minLimit.value, c_maxLimit.value]\n\n\n# Added in 4.304\ndef nvmlDeviceGetPowerManagementDefaultLimit(handle):\n    c_limit = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPowerManagementDefaultLimit\")\n    ret = fn(handle, byref(c_limit))\n    _nvmlCheckReturn(ret)\n    return c_limit.value\n\n\n# Added in 331\ndef nvmlDeviceGetEnforcedPowerLimit(handle):\n    c_limit = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetEnforcedPowerLimit\")\n    ret = fn(handle, byref(c_limit))\n    _nvmlCheckReturn(ret)\n    return c_limit.value\n\n\ndef nvmlDeviceGetPowerUsage(handle):\n    c_watts = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPowerUsage\")\n    ret = fn(handle, byref(c_watts))\n    _nvmlCheckReturn(ret)\n    return c_watts.value\n\n\ndef 
nvmlDeviceGetTotalEnergyConsumption(handle):\n    c_millijoules = c_uint64()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetTotalEnergyConsumption\")\n    ret = fn(handle, byref(c_millijoules))\n    _nvmlCheckReturn(ret)\n    return c_millijoules.value\n\n\n# Added in 4.304\ndef nvmlDeviceGetGpuOperationMode(handle):\n    c_currState = _nvmlGpuOperationMode_t()\n    c_pendingState = _nvmlGpuOperationMode_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGpuOperationMode\")\n    ret = fn(handle, byref(c_currState), byref(c_pendingState))\n    _nvmlCheckReturn(ret)\n    return [c_currState.value, c_pendingState.value]\n\n\n# Added in 4.304\ndef nvmlDeviceGetCurrentGpuOperationMode(handle):\n    return nvmlDeviceGetGpuOperationMode(handle)[0]\n\n\n# Added in 4.304\ndef nvmlDeviceGetPendingGpuOperationMode(handle):\n    return nvmlDeviceGetGpuOperationMode(handle)[1]\n\n\ndef nvmlDeviceGetMemoryInfo(handle, version=None):\n    if not version:\n        c_memory = c_nvmlMemory_t()\n        fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMemoryInfo\")\n    else:\n        c_memory = c_nvmlMemory_v2_t()\n        c_memory.version = version\n        fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMemoryInfo_v2\")\n    ret = fn(handle, byref(c_memory))\n    _nvmlCheckReturn(ret)\n    return c_memory\n\n\ndef nvmlDeviceGetBAR1MemoryInfo(handle):\n    c_bar1_memory = c_nvmlBAR1Memory_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetBAR1MemoryInfo\")\n    ret = fn(handle, byref(c_bar1_memory))\n    _nvmlCheckReturn(ret)\n    return c_bar1_memory\n\n\ndef nvmlDeviceGetComputeMode(handle):\n    c_mode = _nvmlComputeMode_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetComputeMode\")\n    ret = fn(handle, byref(c_mode))\n    _nvmlCheckReturn(ret)\n    return c_mode.value\n\n\ndef nvmlDeviceGetCudaComputeCapability(handle):\n    c_major = c_int()\n    c_minor = c_int()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetCudaComputeCapability\")\n    ret = fn(handle, byref(c_major), 
byref(c_minor))\n    _nvmlCheckReturn(ret)\n    return (c_major.value, c_minor.value)\n\n\ndef nvmlDeviceGetEccMode(handle):\n    c_currState = _nvmlEnableState_t()\n    c_pendingState = _nvmlEnableState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetEccMode\")\n    ret = fn(handle, byref(c_currState), byref(c_pendingState))\n    _nvmlCheckReturn(ret)\n    return [c_currState.value, c_pendingState.value]\n\n\n# added to API\ndef nvmlDeviceGetCurrentEccMode(handle):\n    return nvmlDeviceGetEccMode(handle)[0]\n\n\n# added to API\ndef nvmlDeviceGetPendingEccMode(handle):\n    return nvmlDeviceGetEccMode(handle)[1]\n\n\ndef nvmlDeviceGetDefaultEccMode(handle):\n    c_defaultState = _nvmlEnableState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetDefaultEccMode\")\n    ret = fn(handle, byref(c_defaultState))\n    _nvmlCheckReturn(ret)\n    return [c_defaultState.value]\n\n\ndef nvmlDeviceGetTotalEccErrors(handle, errorType, counterType):\n    c_count = c_ulonglong()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetTotalEccErrors\")\n    ret = fn(\n        handle,\n        _nvmlMemoryErrorType_t(errorType),\n        _nvmlEccCounterType_t(counterType),\n        byref(c_count),\n    )\n    _nvmlCheckReturn(ret)\n    return c_count.value\n\n\n# This is deprecated, instead use nvmlDeviceGetMemoryErrorCounter\ndef nvmlDeviceGetDetailedEccErrors(handle, errorType, counterType):\n    c_counts = c_nvmlEccErrorCounts_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetDetailedEccErrors\")\n    ret = fn(\n        handle,\n        _nvmlMemoryErrorType_t(errorType),\n        _nvmlEccCounterType_t(counterType),\n        byref(c_counts),\n    )\n    _nvmlCheckReturn(ret)\n    return c_counts\n\n\n# Added in 4.304\ndef nvmlDeviceGetMemoryErrorCounter(handle, errorType, counterType, locationType):\n    c_count = c_ulonglong()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMemoryErrorCounter\")\n    ret = fn(\n        handle,\n        _nvmlMemoryErrorType_t(errorType),\n   
     _nvmlEccCounterType_t(counterType),\n        _nvmlMemoryLocation_t(locationType),\n        byref(c_count),\n    )\n    _nvmlCheckReturn(ret)\n    return c_count.value\n\n\ndef nvmlDeviceGetUtilizationRates(handle):\n    c_util = c_nvmlUtilization_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetUtilizationRates\")\n    ret = fn(handle, byref(c_util))\n    _nvmlCheckReturn(ret)\n    return c_util\n\n\ndef nvmlDeviceGetEncoderUtilization(handle):\n    c_util = c_uint()\n    c_samplingPeriod = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetEncoderUtilization\")\n    ret = fn(handle, byref(c_util), byref(c_samplingPeriod))\n    _nvmlCheckReturn(ret)\n    return [c_util.value, c_samplingPeriod.value]\n\n\ndef nvmlDeviceGetDecoderUtilization(handle):\n    c_util = c_uint()\n    c_samplingPeriod = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetDecoderUtilization\")\n    ret = fn(handle, byref(c_util), byref(c_samplingPeriod))\n    _nvmlCheckReturn(ret)\n    return [c_util.value, c_samplingPeriod.value]\n\n\ndef nvmlDeviceGetJpgUtilization(handle):\n    c_util = c_uint()\n    c_samplingPeriod = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetJpgUtilization\")\n    ret = fn(handle, byref(c_util), byref(c_samplingPeriod))\n    _nvmlCheckReturn(ret)\n    return [c_util.value, c_samplingPeriod.value]\n\n\ndef nvmlDeviceGetOfaUtilization(handle):\n    c_util = c_uint()\n    c_samplingPeriod = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetOfaUtilization\")\n    ret = fn(handle, byref(c_util), byref(c_samplingPeriod))\n    _nvmlCheckReturn(ret)\n    return [c_util.value, c_samplingPeriod.value]\n\n\ndef nvmlDeviceGetPcieReplayCounter(handle):\n    c_replay = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPcieReplayCounter\")\n    ret = fn(handle, byref(c_replay))\n    _nvmlCheckReturn(ret)\n    return c_replay.value\n\n\ndef nvmlDeviceGetDriverModel(handle):\n    c_currModel = _nvmlDriverModel_t()\n    c_pendingModel 
= _nvmlDriverModel_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetDriverModel\")\n    ret = fn(handle, byref(c_currModel), byref(c_pendingModel))\n    _nvmlCheckReturn(ret)\n    return [c_currModel.value, c_pendingModel.value]\n\n\n# added to API\ndef nvmlDeviceGetCurrentDriverModel(handle):\n    return nvmlDeviceGetDriverModel(handle)[0]\n\n\n# added to API\ndef nvmlDeviceGetPendingDriverModel(handle):\n    return nvmlDeviceGetDriverModel(handle)[1]\n\n\n# Added in 2.285\n@convertStrBytes\ndef nvmlDeviceGetVbiosVersion(handle):\n    c_version = create_string_buffer(NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetVbiosVersion\")\n    ret = fn(handle, c_version, c_uint(NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE))\n    _nvmlCheckReturn(ret)\n    return c_version.value\n\n\n# Added in 2.285\ndef nvmlDeviceGetComputeRunningProcesses_v2(handle):\n    # first call to get the size\n    c_count = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetComputeRunningProcesses_v2\")\n    ret = fn(handle, byref(c_count), None)\n    if ret == NVML_SUCCESS:\n        # special case, no running processes\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        # oversize the array in case more processes are created\n        c_count.value = c_count.value * 2 + 5\n        proc_array = c_nvmlProcessInfo_v2_t * c_count.value\n        c_procs = proc_array()\n        # make the call again\n        ret = fn(handle, byref(c_count), c_procs)\n        _nvmlCheckReturn(ret)\n        procs = []\n        for i in range(c_count.value):\n            # use an alternative struct for this object\n            obj = nvmlStructToFriendlyObject(c_procs[i])\n            if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value:\n                # special case for WDDM on Windows, see comment above\n                obj.usedGpuMemory = None\n            procs.append(obj)\n        return procs\n    else:\n        
# error case\n        raise NVMLError(ret)\n\n\n# Added in 2.285\ndef nvmlDeviceGetComputeRunningProcesses_v3(handle):\n    # first call to get the size\n    c_count = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetComputeRunningProcesses_v3\")\n    ret = fn(handle, byref(c_count), None)\n\n    if ret == NVML_SUCCESS:\n        # special case, no running processes\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        # oversize the array in case more processes are created\n        c_count.value = c_count.value * 2 + 5\n        proc_array = c_nvmlProcessInfo_v3_t * c_count.value\n        c_procs = proc_array()\n\n        # make the call again\n        ret = fn(handle, byref(c_count), c_procs)\n        _nvmlCheckReturn(ret)\n\n        procs = []\n        for i in range(c_count.value):\n            # use an alternative struct for this object\n            obj = nvmlStructToFriendlyObject(c_procs[i])\n            if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value:\n                # special case for WDDM on Windows, see comment above\n                obj.usedGpuMemory = None\n            procs.append(obj)\n\n        return procs\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\n@throwOnVersionMismatch\ndef nvmlDeviceGetComputeRunningProcesses(handle):\n    return nvmlDeviceGetComputeRunningProcesses_v3(handle)\n\n\ndef nvmlDeviceGetGraphicsRunningProcesses_v2(handle):\n    # first call to get the size\n    c_count = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGraphicsRunningProcesses_v2\")\n    ret = fn(handle, byref(c_count), None)\n    if ret == NVML_SUCCESS:\n        # special case, no running processes\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        # oversize the array in case more processes are created\n        c_count.value = c_count.value * 2 + 5\n        proc_array = c_nvmlProcessInfo_v2_t * c_count.value\n    
    c_procs = proc_array()\n        # make the call again\n        ret = fn(handle, byref(c_count), c_procs)\n        _nvmlCheckReturn(ret)\n        procs = []\n        for i in range(c_count.value):\n            # use an alternative struct for this object\n            obj = nvmlStructToFriendlyObject(c_procs[i])\n            if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value:\n                # special case for WDDM on Windows, see comment above\n                obj.usedGpuMemory = None\n            procs.append(obj)\n        return procs\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlDeviceGetGraphicsRunningProcesses_v3(handle):\n    # first call to get the size\n    c_count = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGraphicsRunningProcesses_v3\")\n    ret = fn(handle, byref(c_count), None)\n\n    if ret == NVML_SUCCESS:\n        # special case, no running processes\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        # oversize the array in case more processes are created\n        c_count.value = c_count.value * 2 + 5\n        proc_array = c_nvmlProcessInfo_v3_t * c_count.value\n        c_procs = proc_array()\n\n        # make the call again\n        ret = fn(handle, byref(c_count), c_procs)\n        _nvmlCheckReturn(ret)\n\n        procs = []\n        for i in range(c_count.value):\n            # use an alternative struct for this object\n            obj = nvmlStructToFriendlyObject(c_procs[i])\n            if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value:\n                # special case for WDDM on Windows, see comment above\n                obj.usedGpuMemory = None\n            procs.append(obj)\n\n        return procs\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\n@throwOnVersionMismatch\ndef nvmlDeviceGetGraphicsRunningProcesses(handle):\n    return 
nvmlDeviceGetGraphicsRunningProcesses_v3(handle)\n\n\n@throwOnVersionMismatch\ndef nvmlDeviceGetMPSComputeRunningProcesses(handle):\n    return nvmlDeviceGetMPSComputeRunningProcesses_v3(handle)\n\n\ndef nvmlDeviceGetMPSComputeRunningProcesses_v2(handle):\n    # first call to get the size\n    c_count = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMPSComputeRunningProcesses_v2\")\n    ret = fn(handle, byref(c_count), None)\n\n    if ret == NVML_SUCCESS:\n        # special case, no running processes\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        # oversize the array in case more processes are created\n        c_count.value = c_count.value * 2 + 5\n        proc_array = c_nvmlProcessInfo_v2_t * c_count.value\n        c_procs = proc_array()\n\n        # make the call again\n        ret = fn(handle, byref(c_count), c_procs)\n        _nvmlCheckReturn(ret)\n\n        procs = []\n        for i in range(c_count.value):\n            # use an alternative struct for this object\n            obj = nvmlStructToFriendlyObject(c_procs[i])\n            if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value:\n                # special case for WDDM on Windows, see comment above\n                obj.usedGpuMemory = None\n            procs.append(obj)\n\n        return procs\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlDeviceGetMPSComputeRunningProcesses_v3(handle):\n    # first call to get the size\n    c_count = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMPSComputeRunningProcesses_v3\")\n    ret = fn(handle, byref(c_count), None)\n\n    if ret == NVML_SUCCESS:\n        # special case, no running processes\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        # oversize the array in case more processes are created\n        c_count.value = c_count.value * 2 + 5\n        proc_array = c_nvmlProcessInfo_v3_t * 
c_count.value\n        c_procs = proc_array()\n\n        # make the call again\n        ret = fn(handle, byref(c_count), c_procs)\n        _nvmlCheckReturn(ret)\n\n        procs = []\n        for i in range(c_count.value):\n            # use an alternative struct for this object\n            obj = nvmlStructToFriendlyObject(c_procs[i])\n            if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value:\n                # special case for WDDM on Windows, see comment above\n                obj.usedGpuMemory = None\n            procs.append(obj)\n\n        return procs\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlDeviceGetRunningProcessDetailList(handle, version, mode):\n    c_processDetailList = c_nvmlProcessDetailList_t()\n    c_processDetailList.version = version\n    c_processDetailList.mode = mode\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetRunningProcessDetailList\")\n\n    # first call to get the size\n    ret = fn(handle, byref(c_processDetailList))\n    if ret == NVML_SUCCESS:\n        # special case, no running processes\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        c_procs = c_nvmlProcessDetail_v1_t * c_processDetailList.numProcArrayEntries\n        c_processDetailList.procArray = cast(\n            (c_procs)(), POINTER(c_nvmlProcessDetail_v1_t)\n        )\n\n        # make the call again\n        ret = fn(handle, byref(c_processDetailList))\n        _nvmlCheckReturn(ret)\n\n        procs = []\n        for i in range(c_processDetailList.numProcArrayEntries):\n            # use an alternative struct for this object\n            obj = c_processDetailList.procArray[i]\n            if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value:\n                obj.usedGpuMemory = None\n            if obj.usedGpuCcProtectedMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value:\n                obj.usedGpuCcProtectedMemory = None\n            procs.append(obj)\n\n        return procs\n 
   else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlDeviceGetAutoBoostedClocksEnabled(handle):\n    c_isEnabled = _nvmlEnableState_t()\n    c_defaultIsEnabled = _nvmlEnableState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetAutoBoostedClocksEnabled\")\n    ret = fn(handle, byref(c_isEnabled), byref(c_defaultIsEnabled))\n    _nvmlCheckReturn(ret)\n    return [c_isEnabled.value, c_defaultIsEnabled.value]\n    # Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks\n\n\n## Set functions\ndef nvmlUnitSetLedState(unit, color):\n    fn = _nvmlGetFunctionPointer(\"nvmlUnitSetLedState\")\n    ret = fn(unit, _nvmlLedColor_t(color))\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceSetPersistenceMode(handle, mode):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetPersistenceMode\")\n    ret = fn(handle, _nvmlEnableState_t(mode))\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceSetComputeMode(handle, mode):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetComputeMode\")\n    ret = fn(handle, _nvmlComputeMode_t(mode))\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceSetEccMode(handle, mode):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetEccMode\")\n    ret = fn(handle, _nvmlEnableState_t(mode))\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceClearEccErrorCounts(handle, counterType):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceClearEccErrorCounts\")\n    ret = fn(handle, _nvmlEccCounterType_t(counterType))\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceSetDriverModel(handle, model):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetDriverModel\")\n    ret = fn(handle, _nvmlDriverModel_t(model))\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceSetAutoBoostedClocksEnabled(handle, enabled):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetAutoBoostedClocksEnabled\")\n    ret = fn(handle, _nvmlEnableState_t(enabled))\n    
_nvmlCheckReturn(ret)\n    return None\n    # Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks\n\n\ndef nvmlDeviceSetDefaultAutoBoostedClocksEnabled(handle, enabled, flags):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetDefaultAutoBoostedClocksEnabled\")\n    ret = fn(handle, _nvmlEnableState_t(enabled), c_uint(flags))\n    _nvmlCheckReturn(ret)\n    return None\n    # Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks\n\n\ndef nvmlDeviceSetGpuLockedClocks(handle, minGpuClockMHz, maxGpuClockMHz):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetGpuLockedClocks\")\n    ret = fn(handle, c_uint(minGpuClockMHz), c_uint(maxGpuClockMHz))\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceResetGpuLockedClocks(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceResetGpuLockedClocks\")\n    ret = fn(handle)\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceSetMemoryLockedClocks(handle, minMemClockMHz, maxMemClockMHz):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetMemoryLockedClocks\")\n    ret = fn(handle, c_uint(minMemClockMHz), c_uint(maxMemClockMHz))\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceResetMemoryLockedClocks(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceResetMemoryLockedClocks\")\n    ret = fn(handle)\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceGetClkMonStatus(handle, c_clkMonInfo=nvmlClkMonStatus_t()):\n    isReference = type(c_clkMonInfo) is not nvmlClkMonStatus_t\n    c_clkMonInfoRef = c_clkMonInfo if isReference else byref(c_clkMonInfo)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetClkMonStatus\")\n    ret = fn(handle, c_clkMonInfoRef)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS if isReference else c_clkMonInfo\n\n\n# Added in 4.304\ndef nvmlDeviceSetApplicationsClocks(handle, maxMemClockMHz, maxGraphicsClockMHz):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetApplicationsClocks\")\n    ret 
= fn(handle, c_uint(maxMemClockMHz), c_uint(maxGraphicsClockMHz))\n    _nvmlCheckReturn(ret)\n    return None\n\n\n# Added in 4.304\ndef nvmlDeviceResetApplicationsClocks(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceResetApplicationsClocks\")\n    ret = fn(handle)\n    _nvmlCheckReturn(ret)\n    return None\n\n\n# Added in 4.304\ndef nvmlDeviceSetPowerManagementLimit(handle, limit):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetPowerManagementLimit\")\n    ret = fn(handle, c_uint(limit))\n    _nvmlCheckReturn(ret)\n    return None\n\n\n# Added in 4.304\ndef nvmlDeviceSetGpuOperationMode(handle, mode):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetGpuOperationMode\")\n    ret = fn(handle, _nvmlGpuOperationMode_t(mode))\n    _nvmlCheckReturn(ret)\n    return None\n\n\n# Added in 2.285\ndef nvmlEventSetCreate():\n    fn = _nvmlGetFunctionPointer(\"nvmlEventSetCreate\")\n    eventSet = c_nvmlEventSet_t()\n    ret = fn(byref(eventSet))\n    _nvmlCheckReturn(ret)\n    return eventSet\n\n\n# Added in 2.285\ndef nvmlDeviceRegisterEvents(handle, eventTypes, eventSet):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceRegisterEvents\")\n    ret = fn(handle, c_ulonglong(eventTypes), eventSet)\n    _nvmlCheckReturn(ret)\n    return None\n\n\n# Added in 2.285\ndef nvmlDeviceGetSupportedEventTypes(handle):\n    c_eventTypes = c_ulonglong()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetSupportedEventTypes\")\n    ret = fn(handle, byref(c_eventTypes))\n    _nvmlCheckReturn(ret)\n    return c_eventTypes.value\n\n\n# raises NVML_ERROR_TIMEOUT exception on timeout\ndef nvmlEventSetWait_v2(eventSet, timeoutms):\n    fn = _nvmlGetFunctionPointer(\"nvmlEventSetWait_v2\")\n    data = c_nvmlEventData_t()\n    ret = fn(eventSet, byref(data), c_uint(timeoutms))\n    _nvmlCheckReturn(ret)\n    return data\n\n\ndef nvmlEventSetWait(eventSet, timeoutms):\n    return nvmlEventSetWait_v2(eventSet, timeoutms)\n\n\n# Added in 2.285\ndef nvmlEventSetFree(eventSet):\n    fn = 
_nvmlGetFunctionPointer(\"nvmlEventSetFree\")\n    ret = fn(eventSet)\n    _nvmlCheckReturn(ret)\n    return None\n\n\n# Added in 3.295\ndef nvmlDeviceOnSameBoard(handle1, handle2):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceOnSameBoard\")\n    onSameBoard = c_int()\n    ret = fn(handle1, handle2, byref(onSameBoard))\n    _nvmlCheckReturn(ret)\n    return onSameBoard.value != 0\n\n\n# Added in 3.295\ndef nvmlDeviceGetCurrPcieLinkGeneration(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetCurrPcieLinkGeneration\")\n    gen = c_uint()\n    ret = fn(handle, byref(gen))\n    _nvmlCheckReturn(ret)\n    return gen.value\n\n\n# Added in 3.295\ndef nvmlDeviceGetMaxPcieLinkGeneration(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMaxPcieLinkGeneration\")\n    gen = c_uint()\n    ret = fn(handle, byref(gen))\n    _nvmlCheckReturn(ret)\n    return gen.value\n\n\n# Added in 3.295\ndef nvmlDeviceGetCurrPcieLinkWidth(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetCurrPcieLinkWidth\")\n    width = c_uint()\n    ret = fn(handle, byref(width))\n    _nvmlCheckReturn(ret)\n    return width.value\n\n\n# Added in 3.295\ndef nvmlDeviceGetMaxPcieLinkWidth(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMaxPcieLinkWidth\")\n    width = c_uint()\n    ret = fn(handle, byref(width))\n    _nvmlCheckReturn(ret)\n    return width.value\n\n\ndef nvmlDeviceGetGpuMaxPcieLinkGeneration(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGpuMaxPcieLinkGeneration\")\n    gen = c_uint()\n    ret = fn(handle, byref(gen))\n    _nvmlCheckReturn(ret)\n    return gen.value\n\n\n# Added in 4.304\ndef nvmlDeviceGetSupportedClocksThrottleReasons(handle):\n    c_reasons = c_ulonglong()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetSupportedClocksThrottleReasons\")\n    ret = fn(handle, byref(c_reasons))\n    _nvmlCheckReturn(ret)\n    return c_reasons.value\n\n\ndef nvmlDeviceGetSupportedClocksEventReasons(handle):\n    c_reasons = c_ulonglong()\n    fn = 
_nvmlGetFunctionPointer(\"nvmlDeviceGetSupportedClocksEventReasons\")\n    ret = fn(handle, byref(c_reasons))\n    _nvmlCheckReturn(ret)\n    return c_reasons.value\n\n\n# Added in 4.304\ndef nvmlDeviceGetCurrentClocksThrottleReasons(handle):\n    c_reasons = c_ulonglong()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetCurrentClocksThrottleReasons\")\n    ret = fn(handle, byref(c_reasons))\n    _nvmlCheckReturn(ret)\n    return c_reasons.value\n\n\ndef nvmlDeviceGetCurrentClocksEventReasons(handle):\n    c_reasons = c_ulonglong()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetCurrentClocksEventReasons\")\n    ret = fn(handle, byref(c_reasons))\n    _nvmlCheckReturn(ret)\n    return c_reasons.value\n\n\n# Added in 5.319\ndef nvmlDeviceGetIndex(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetIndex\")\n    c_index = c_uint()\n    ret = fn(handle, byref(c_index))\n    _nvmlCheckReturn(ret)\n    return c_index.value\n\n\n# Added in 5.319\ndef nvmlDeviceGetAccountingMode(handle):\n    c_mode = _nvmlEnableState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetAccountingMode\")\n    ret = fn(handle, byref(c_mode))\n    _nvmlCheckReturn(ret)\n    return c_mode.value\n\n\ndef nvmlDeviceSetAccountingMode(handle, mode):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetAccountingMode\")\n    ret = fn(handle, _nvmlEnableState_t(mode))\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceClearAccountingPids(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceClearAccountingPids\")\n    ret = fn(handle)\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceGetAccountingStats(handle, pid):\n    stats = c_nvmlAccountingStats_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetAccountingStats\")\n    ret = fn(handle, c_uint(pid), byref(stats))\n    _nvmlCheckReturn(ret)\n    if stats.maxMemoryUsage == NVML_VALUE_NOT_AVAILABLE_ulonglong.value:\n        # special case for WDDM on Windows, see comment above\n        stats.maxMemoryUsage = None\n  
  return stats\n\n\ndef nvmlDeviceGetAccountingPids(handle):\n    count = c_uint(nvmlDeviceGetAccountingBufferSize(handle))\n    pids = (c_uint * count.value)()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetAccountingPids\")\n    ret = fn(handle, byref(count), pids)\n    _nvmlCheckReturn(ret)\n    return list(map(int, pids[0 : count.value]))\n\n\ndef nvmlDeviceGetAccountingBufferSize(handle):\n    bufferSize = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetAccountingBufferSize\")\n    ret = fn(handle, byref(bufferSize))\n    _nvmlCheckReturn(ret)\n    return int(bufferSize.value)\n\n\ndef nvmlDeviceGetRetiredPages(device, sourceFilter):\n    c_source = _nvmlPageRetirementCause_t(sourceFilter)\n    c_count = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetRetiredPages\")\n\n    # First call will get the size\n    ret = fn(device, c_source, byref(c_count), None)\n\n    # this should only fail with insufficient size\n    if (ret != NVML_SUCCESS) and (ret != NVML_ERROR_INSUFFICIENT_SIZE):\n        raise NVMLError(ret)\n\n    # call again with a buffer\n    # oversize the array for the rare cases where additional pages\n    # are retired between NVML calls\n    c_count.value = c_count.value * 2 + 5\n    page_array = c_ulonglong * c_count.value\n    c_pages = page_array()\n    ret = fn(device, c_source, byref(c_count), c_pages)\n    _nvmlCheckReturn(ret)\n    return list(map(int, c_pages[0 : c_count.value]))\n\n\ndef nvmlDeviceGetRetiredPages_v2(device, sourceFilter):\n    c_source = _nvmlPageRetirementCause_t(sourceFilter)\n    c_count = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetRetiredPages_v2\")\n\n    # First call will get the size\n    ret = fn(device, c_source, byref(c_count), None)\n\n    # this should only fail with insufficient size\n    if (ret != NVML_SUCCESS) and (ret != NVML_ERROR_INSUFFICIENT_SIZE):\n        raise NVMLError(ret)\n\n    # call again with a buffer\n    # oversize the array for the rare cases where 
additional pages\n    # are retired between NVML calls\n    c_count.value = c_count.value * 2 + 5\n    page_array = c_ulonglong * c_count.value\n    c_pages = page_array()\n    times_array = c_ulonglong * c_count.value\n    c_times = times_array()\n    ret = fn(device, c_source, byref(c_count), c_pages, c_times)\n    _nvmlCheckReturn(ret)\n    return [\n        {\"address\": int(c_pages[i]), \"timestamp\": int(c_times[i])}\n        for i in range(c_count.value)\n    ]\n\n\ndef nvmlDeviceGetRetiredPagesPendingStatus(device):\n    c_pending = _nvmlEnableState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetRetiredPagesPendingStatus\")\n    ret = fn(device, byref(c_pending))\n    _nvmlCheckReturn(ret)\n    return int(c_pending.value)\n\n\ndef nvmlDeviceGetAPIRestriction(device, apiType):\n    c_permission = _nvmlEnableState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetAPIRestriction\")\n    ret = fn(device, _nvmlRestrictedAPI_t(apiType), byref(c_permission))\n    _nvmlCheckReturn(ret)\n    return int(c_permission.value)\n\n\ndef nvmlDeviceSetAPIRestriction(handle, apiType, isRestricted):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetAPIRestriction\")\n    ret = fn(handle, _nvmlRestrictedAPI_t(apiType), _nvmlEnableState_t(isRestricted))\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceGetBridgeChipInfo(handle):\n    bridgeHierarchy = c_nvmlBridgeChipHierarchy_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetBridgeChipInfo\")\n    ret = fn(handle, byref(bridgeHierarchy))\n    _nvmlCheckReturn(ret)\n    return bridgeHierarchy\n\n\ndef nvmlDeviceGetSamples(device, sampling_type, timeStamp):\n    c_sampling_type = _nvmlSamplingType_t(sampling_type)\n    c_time_stamp = c_ulonglong(timeStamp)\n    c_sample_count = c_uint(0)\n    c_sample_value_type = _nvmlValueType_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetSamples\")\n\n    ## First Call gets the size\n    ret = fn(\n        device,\n        c_sampling_type,\n        
c_time_stamp,\n        byref(c_sample_value_type),\n        byref(c_sample_count),\n        None,\n    )\n\n    # Stop if this fails\n    if ret != NVML_SUCCESS:\n        raise NVMLError(ret)\n\n    sampleArray = c_sample_count.value * c_nvmlSample_t\n    c_samples = sampleArray()\n    ret = fn(\n        device,\n        c_sampling_type,\n        c_time_stamp,\n        byref(c_sample_value_type),\n        byref(c_sample_count),\n        c_samples,\n    )\n    _nvmlCheckReturn(ret)\n    return (c_sample_value_type.value, c_samples[0 : c_sample_count.value])\n\n\ndef nvmlDeviceGetViolationStatus(device, perfPolicyType):\n    c_perfPolicy_type = _nvmlPerfPolicyType_t(perfPolicyType)\n    c_violTime = c_nvmlViolationTime_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetViolationStatus\")\n\n    ## Invoke the method to get violation time\n    ret = fn(device, c_perfPolicy_type, byref(c_violTime))\n    _nvmlCheckReturn(ret)\n    return c_violTime\n\n\ndef nvmlDeviceGetPcieThroughput(device, counter):\n    c_util = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPcieThroughput\")\n    ret = fn(device, _nvmlPcieUtilCounter_t(counter), byref(c_util))\n    _nvmlCheckReturn(ret)\n    return c_util.value\n\n\ndef nvmlSystemGetTopologyGpuSet(cpuNumber):\n    c_count = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemGetTopologyGpuSet\")\n\n    # First call will get the size\n    ret = fn(cpuNumber, byref(c_count), None)\n\n    if ret != NVML_SUCCESS:\n        raise NVMLError(ret)\n    # call again with a buffer\n    device_array = c_nvmlDevice_t * c_count.value\n    c_devices = device_array()\n    ret = fn(cpuNumber, byref(c_count), c_devices)\n    _nvmlCheckReturn(ret)\n    return list(c_devices[0 : c_count.value])\n\n\ndef nvmlDeviceGetTopologyNearestGpus(device, level):\n    c_count = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetTopologyNearestGpus\")\n\n    # First call will get the size\n    ret = fn(device, level, byref(c_count), 
None)\n\n    if ret != NVML_SUCCESS:\n        raise NVMLError(ret)\n\n    # call again with a buffer\n    device_array = c_nvmlDevice_t * c_count.value\n    c_devices = device_array()\n    ret = fn(device, level, byref(c_count), c_devices)\n    _nvmlCheckReturn(ret)\n    return list(c_devices[0 : c_count.value])\n\n\ndef nvmlDeviceGetTopologyCommonAncestor(device1, device2):\n    c_level = _nvmlGpuTopologyLevel_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetTopologyCommonAncestor\")\n    ret = fn(device1, device2, byref(c_level))\n    _nvmlCheckReturn(ret)\n    return c_level.value\n\n\ndef nvmlDeviceGetNvLinkUtilizationCounter(device, link, counter):\n    c_rxcounter = c_ulonglong()\n    c_txcounter = c_ulonglong()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetNvLinkUtilizationCounter\")\n    ret = fn(device, link, counter, byref(c_rxcounter), byref(c_txcounter))\n    _nvmlCheckReturn(ret)\n    return (c_rxcounter.value, c_txcounter.value)\n\n\ndef nvmlDeviceFreezeNvLinkUtilizationCounter(device, link, counter, freeze):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceFreezeNvLinkUtilizationCounter\")\n    ret = fn(device, link, counter, freeze)\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceResetNvLinkUtilizationCounter(device, link, counter):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceResetNvLinkUtilizationCounter\")\n    ret = fn(device, link, counter)\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceSetNvLinkUtilizationControl(device, link, counter, control, reset):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetNvLinkUtilizationControl\")\n    ret = fn(device, link, counter, byref(control), reset)\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceGetNvLinkUtilizationControl(device, link, counter):\n    c_control = nvmlNvLinkUtilizationControl_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetNvLinkUtilizationControl\")\n    ret = fn(device, link, counter, byref(c_control))\n    _nvmlCheckReturn(ret)\n   
 return c_control\n\n\ndef nvmlDeviceGetNvLinkCapability(device, link, capability):\n    c_capResult = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetNvLinkCapability\")\n    ret = fn(device, link, capability, byref(c_capResult))\n    _nvmlCheckReturn(ret)\n    return c_capResult.value\n\n\ndef nvmlDeviceGetNvLinkErrorCounter(device, link, counter):\n    c_result = c_ulonglong()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetNvLinkErrorCounter\")\n    ret = fn(device, link, counter, byref(c_result))\n    _nvmlCheckReturn(ret)\n    return c_result.value\n\n\ndef nvmlDeviceResetNvLinkErrorCounters(device, link):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceResetNvLinkErrorCounters\")\n    ret = fn(device, link)\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceGetNvLinkRemotePciInfo(device, link):\n    c_pci = nvmlPciInfo_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetNvLinkRemotePciInfo_v2\")\n    ret = fn(device, link, byref(c_pci))\n    _nvmlCheckReturn(ret)\n    return c_pci\n\n\ndef nvmlDeviceGetNvLinkRemoteDeviceType(handle, link):\n    c_type = _nvmlNvLinkDeviceType_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetNvLinkRemoteDeviceType\")\n    ret = fn(handle, link, byref(c_type))\n    _nvmlCheckReturn(ret)\n    return c_type.value\n\n\ndef nvmlDeviceGetNvLinkState(device, link):\n    c_isActive = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetNvLinkState\")\n    ret = fn(device, link, byref(c_isActive))\n    _nvmlCheckReturn(ret)\n    return c_isActive.value\n\n\ndef nvmlDeviceGetNvLinkVersion(device, link):\n    c_version = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetNvLinkVersion\")\n    ret = fn(device, link, byref(c_version))\n    _nvmlCheckReturn(ret)\n    return c_version.value\n\n\ndef nvmlDeviceModifyDrainState(pciInfo, newState):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceModifyDrainState\")\n    ret = fn(pointer(pciInfo), newState)\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef 
nvmlDeviceQueryDrainState(pciInfo):\n    c_newState = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceQueryDrainState\")\n    ret = fn(pointer(pciInfo), byref(c_newState))\n    _nvmlCheckReturn(ret)\n    return c_newState.value\n\n\ndef nvmlDeviceRemoveGpu(pciInfo):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceRemoveGpu\")\n    ret = fn(pointer(pciInfo))\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceDiscoverGpus(pciInfo):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceDiscoverGpus\")\n    ret = fn(pointer(pciInfo))\n    _nvmlCheckReturn(ret)\n    return None\n\n\ndef nvmlDeviceGetFieldValues(handle, fieldIds):\n    values_arr = c_nvmlFieldValue_t * len(fieldIds)\n    values = values_arr()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetFieldValues\")\n\n    for i, fieldId in enumerate(fieldIds):\n        try:\n            values[i].fieldId, values[i].scopeId = fieldId\n        except TypeError:\n            values[i].fieldId = fieldId\n\n    ret = fn(handle, c_int32(len(fieldIds)), byref(values))\n    _nvmlCheckReturn(ret)\n    return values\n\n\ndef nvmlDeviceClearFieldValues(handle, fieldIds):\n    values_arr = c_nvmlFieldValue_t * len(fieldIds)\n    values = values_arr()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceClearFieldValues\")\n\n    for i, fieldId in enumerate(fieldIds):\n        try:\n            values[i].fieldId, values[i].scopeId = fieldId\n        except TypeError:\n            values[i].fieldId = fieldId\n\n    ret = fn(handle, c_int32(len(fieldIds)), byref(values))\n    _nvmlCheckReturn(ret)\n    return values\n\n\ndef nvmlDeviceGetVirtualizationMode(handle):\n    c_virtualization_mode = c_ulonglong()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetVirtualizationMode\")\n    ret = fn(handle, byref(c_virtualization_mode))\n    _nvmlCheckReturn(ret)\n    return c_virtualization_mode.value\n\n\ndef nvmlDeviceSetVirtualizationMode(handle, virtualization_mode):\n    fn = 
_nvmlGetFunctionPointer(\"nvmlDeviceSetVirtualizationMode\")\n    return fn(handle, virtualization_mode)\n\n\ndef nvmlDeviceGetVgpuHeterogeneousMode(handle):\n    c_vgpuHeterogeneousMode = c_nvmlVgpuHeterogeneousMode_v1_t(0)\n    c_vgpuHeterogeneousMode.version = VgpuHeterogeneousMode_v1\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetVgpuHeterogeneousMode\")\n    ret = fn(handle, byref(c_vgpuHeterogeneousMode))\n    _nvmlCheckReturn(ret)\n    return c_vgpuHeterogeneousMode.mode\n\n\ndef nvmlDeviceSetVgpuHeterogeneousMode(handle, heterogeneous_mode):\n    c_vgpuHeterogeneousMode = c_nvmlVgpuHeterogeneousMode_v1_t(0)\n    c_vgpuHeterogeneousMode.version = VgpuHeterogeneousMode_v1\n    c_vgpuHeterogeneousMode.mode = heterogeneous_mode\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetVgpuHeterogeneousMode\")\n    ret = fn(handle, byref(c_vgpuHeterogeneousMode))\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlVgpuInstanceGetPlacementId(vgpuInstance):\n    c_placement = c_nvmlVgpuPlacementId_v1_t(0)\n    c_placement.version = VgpuPlacementId_v1\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetPlacementId\")\n    ret = fn(vgpuInstance, byref(c_placement))\n    _nvmlCheckReturn(ret)\n    return c_placement.placementId\n\n\ndef nvmlDeviceGetVgpuTypeSupportedPlacements(handle, vgpuTypeId, mode=0, version=1):\n    c_max_instances = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetMaxInstances\")\n    ret = fn(handle, vgpuTypeId, byref(c_max_instances))\n    _nvmlCheckReturn(ret)\n\n    if version == 2:\n        c_vgpu_placements = c_nvmlVgpuPlacementList_v2_t()\n        c_vgpu_placements.version = VgpuPlacementList_v2\n        c_vgpu_placements.count = c_max_instances.value\n        c_vgpu_placements.mode = mode\n    elif version == 1:\n        c_vgpu_placements = c_nvmlVgpuPlacementList_v1_t()\n        c_vgpu_placements.version = VgpuPlacementList_v1\n    else:\n        raise NVMLError(NVML_ERROR_ARGUMENT_VERSION_MISMATCH)\n\n    
c_placements = c_uint * c_max_instances.value\n    c_vgpu_placements.placementIds = c_placements()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetVgpuTypeSupportedPlacements\")\n    ret = fn(handle, vgpuTypeId, byref(c_vgpu_placements))\n    _nvmlCheckReturn(ret)\n    return c_vgpu_placements\n\n\ndef nvmlDeviceGetVgpuTypeCreatablePlacements(handle, vgpuTypeId, version=1):\n    c_max_instances = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetMaxInstances\")\n    ret = fn(handle, vgpuTypeId, byref(c_max_instances))\n    _nvmlCheckReturn(ret)\n\n    if version == 2:\n        c_vgpu_placements = c_nvmlVgpuPlacementList_v2_t()\n        c_vgpu_placements.version = VgpuPlacementList_v2\n        c_vgpu_placements.count = c_max_instances.value\n    elif version == 1:\n        c_vgpu_placements = c_nvmlVgpuPlacementList_v1_t()\n        c_vgpu_placements.version = VgpuPlacementList_v1\n    else:\n        raise NVMLError(NVML_ERROR_ARGUMENT_VERSION_MISMATCH)\n\n    c_placements = c_uint * c_max_instances.value\n    c_vgpu_placements.placementIds = c_placements()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetVgpuTypeCreatablePlacements\")\n    ret = fn(handle, vgpuTypeId, byref(c_vgpu_placements))\n    _nvmlCheckReturn(ret)\n    return c_vgpu_placements\n\n\ndef nvmlGetVgpuDriverCapabilities(capability):\n    c_capResult = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlGetVgpuDriverCapabilities\")\n    ret = fn(_nvmlVgpuDriverCapability_t(capability), byref(c_capResult))\n    _nvmlCheckReturn(ret)\n    return c_capResult.value\n\n\ndef nvmlDeviceGetVgpuCapabilities(handle, capability):\n    c_capResult = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetVgpuCapabilities\")\n    ret = fn(handle, _nvmlDeviceVgpuCapability_t(capability), byref(c_capResult))\n    _nvmlCheckReturn(ret)\n    return c_capResult.value\n\n\ndef nvmlDeviceSetVgpuCapabilities(handle, capability, state):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetVgpuCapabilities\")\n    ret = fn(handle, _nvmlDeviceVgpuCapability_t(capability), state)\n    
_nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceGetSupportedVgpus(handle):\n    # first call to get the size\n    c_vgpu_count = c_uint(0)\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetSupportedVgpus\")\n    ret = fn(handle, byref(c_vgpu_count), None)\n\n    if ret == NVML_SUCCESS:\n        # special case, no supported vGPUs\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value\n        c_vgpu_type_ids = vgpu_type_ids_array()\n\n        # make the call again\n        ret = fn(handle, byref(c_vgpu_count), c_vgpu_type_ids)\n        _nvmlCheckReturn(ret)\n        vgpus = []\n        for i in range(c_vgpu_count.value):\n            vgpus.append(c_vgpu_type_ids[i])\n        return vgpus\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlDeviceGetCreatableVgpus(handle):\n    # first call to get the size\n    c_vgpu_count = c_uint(0)\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetCreatableVgpus\")\n    ret = fn(handle, byref(c_vgpu_count), None)\n\n    if ret == NVML_SUCCESS:\n        # special case, no supported vGPUs\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value\n        c_vgpu_type_ids = vgpu_type_ids_array()\n\n        # make the call again\n        ret = fn(handle, byref(c_vgpu_count), c_vgpu_type_ids)\n        _nvmlCheckReturn(ret)\n        vgpus = []\n        for i in range(c_vgpu_count.value):\n            vgpus.append(c_vgpu_type_ids[i])\n        return vgpus\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlVgpuTypeGetGpuInstanceProfileId(vgpuTypeId):\n    c_profile_id = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetGpuInstanceProfileId\")\n    ret = fn(vgpuTypeId, byref(c_profile_id))\n    _nvmlCheckReturn(ret)\n    return 
c_profile_id.value\n\n\n@convertStrBytes\ndef nvmlVgpuTypeGetClass(vgpuTypeId):\n    c_class = create_string_buffer(NVML_DEVICE_NAME_BUFFER_SIZE)\n    c_buffer_size = c_uint(NVML_DEVICE_NAME_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetClass\")\n    ret = fn(vgpuTypeId, c_class, byref(c_buffer_size))\n    _nvmlCheckReturn(ret)\n    return c_class.value\n\n\n@convertStrBytes\ndef nvmlVgpuTypeGetName(vgpuTypeId):\n    c_name = create_string_buffer(NVML_DEVICE_NAME_BUFFER_SIZE)\n    c_buffer_size = c_uint(NVML_DEVICE_NAME_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetName\")\n    ret = fn(vgpuTypeId, c_name, byref(c_buffer_size))\n    _nvmlCheckReturn(ret)\n    return c_name.value\n\n\ndef nvmlVgpuTypeGetDeviceID(vgpuTypeId):\n    c_device_id = c_ulonglong(0)\n    c_subsystem_id = c_ulonglong(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetDeviceID\")\n    ret = fn(vgpuTypeId, byref(c_device_id), byref(c_subsystem_id))\n    _nvmlCheckReturn(ret)\n    return (c_device_id.value, c_subsystem_id.value)\n\n\ndef nvmlVgpuTypeGetFramebufferSize(vgpuTypeId):\n    c_fb_size = c_ulonglong(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetFramebufferSize\")\n    ret = fn(vgpuTypeId, byref(c_fb_size))\n    _nvmlCheckReturn(ret)\n    return c_fb_size.value\n\n\ndef nvmlVgpuTypeGetNumDisplayHeads(vgpuTypeId):\n    c_num_heads = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetNumDisplayHeads\")\n    ret = fn(vgpuTypeId, byref(c_num_heads))\n    _nvmlCheckReturn(ret)\n    return c_num_heads.value\n\n\ndef nvmlVgpuTypeGetResolution(vgpuTypeId):\n    c_xdim = c_uint(0)\n    c_ydim = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetResolution\")\n    ret = fn(vgpuTypeId, 0, byref(c_xdim), byref(c_ydim))\n    _nvmlCheckReturn(ret)\n    return (c_xdim.value, c_ydim.value)\n\n\n@convertStrBytes\ndef nvmlVgpuTypeGetLicense(vgpuTypeId):\n    c_license = create_string_buffer(NVML_GRID_LICENSE_BUFFER_SIZE)\n    
c_buffer_size = c_uint(NVML_GRID_LICENSE_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetLicense\")\n    ret = fn(vgpuTypeId, c_license, c_buffer_size)\n    _nvmlCheckReturn(ret)\n    return c_license.value\n\n\ndef nvmlVgpuTypeGetFrameRateLimit(vgpuTypeId):\n    c_frl_config = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetFrameRateLimit\")\n    ret = fn(vgpuTypeId, byref(c_frl_config))\n    _nvmlCheckReturn(ret)\n    return c_frl_config.value\n\n\ndef nvmlVgpuTypeGetGspHeapSize(vgpuTypeId):\n    c_gsp_heap = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetGspHeapSize\")\n    ret = fn(vgpuTypeId, byref(c_gsp_heap))\n    _nvmlCheckReturn(ret)\n    return c_gsp_heap.value\n\n\ndef nvmlVgpuTypeGetFbReservation(vgpuTypeId):\n    c_fb_reservation = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetFbReservation\")\n    ret = fn(vgpuTypeId, byref(c_fb_reservation))\n    _nvmlCheckReturn(ret)\n    return c_fb_reservation.value\n\n\ndef nvmlVgpuInstanceGetRuntimeStateSize(vgpuInstance):\n    c_runtime_state = nvmlVgpuRuntimeState_v1_t()\n    c_runtime_state.version = VgpuRuntimeState_v1\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetRuntimeStateSize\")\n    ret = fn(vgpuInstance, byref(c_runtime_state))\n    _nvmlCheckReturn(ret)\n    return c_runtime_state\n\n\ndef nvmlVgpuTypeGetMaxInstances(handle, vgpuTypeId):\n    c_max_instances = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetMaxInstances\")\n    ret = fn(handle, vgpuTypeId, byref(c_max_instances))\n    _nvmlCheckReturn(ret)\n    return c_max_instances.value\n\n\ndef nvmlVgpuTypeGetMaxInstancesPerVm(vgpuTypeId):\n    c_max_instances_per_vm = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetMaxInstancesPerVm\")\n    ret = fn(vgpuTypeId, byref(c_max_instances_per_vm))\n    _nvmlCheckReturn(ret)\n    return c_max_instances_per_vm.value\n\n\ndef nvmlVgpuTypeGetBAR1Info(vgpuTypeId):\n    c_bar1Info = c_nvmlVgpuTypeBar1Info_v1_t(0)\n  
  c_bar1Info.version = VgpuTypeBar1Info_v1\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetBAR1Info\")\n    ret = fn(vgpuTypeId, byref(c_bar1Info))\n    _nvmlCheckReturn(ret)\n    return c_bar1Info\n\n\ndef nvmlDeviceGetActiveVgpus(handle):\n    # first call to get the size\n    c_vgpu_count = c_uint(0)\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetActiveVgpus\")\n    ret = fn(handle, byref(c_vgpu_count), None)\n\n    if ret == NVML_SUCCESS:\n        # special case, no active vGPUs\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        vgpu_instance_array = _nvmlVgpuInstance_t * c_vgpu_count.value\n        c_vgpu_instances = vgpu_instance_array()\n\n        # make the call again\n        ret = fn(handle, byref(c_vgpu_count), c_vgpu_instances)\n        _nvmlCheckReturn(ret)\n        vgpus = []\n        for i in range(c_vgpu_count.value):\n            vgpus.append(c_vgpu_instances[i])\n        return vgpus\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\n@convertStrBytes\ndef nvmlVgpuInstanceGetVmID(vgpuInstance):\n    c_vm_id = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE)\n    c_buffer_size = c_uint(NVML_GRID_LICENSE_BUFFER_SIZE)\n    c_vm_id_type = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetVmID\")\n    ret = fn(vgpuInstance, byref(c_vm_id), c_buffer_size, byref(c_vm_id_type))\n    _nvmlCheckReturn(ret)\n    return (c_vm_id.value, c_vm_id_type.value)\n\n\n@convertStrBytes\ndef nvmlVgpuInstanceGetUUID(vgpuInstance):\n    c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE)\n    c_buffer_size = c_uint(NVML_DEVICE_UUID_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetUUID\")\n    ret = fn(vgpuInstance, byref(c_uuid), c_buffer_size)\n    _nvmlCheckReturn(ret)\n    return c_uuid.value\n\n\n@convertStrBytes\ndef nvmlVgpuInstanceGetMdevUUID(vgpuInstance):\n    c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE)\n    c_buffer_size = 
c_uint(NVML_DEVICE_UUID_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetMdevUUID\")\n    ret = fn(vgpuInstance, byref(c_uuid), c_buffer_size)\n    _nvmlCheckReturn(ret)\n    return c_uuid.value\n\n\n@convertStrBytes\ndef nvmlVgpuInstanceGetVmDriverVersion(vgpuInstance):\n    c_driver_version = create_string_buffer(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE)\n    c_buffer_size = c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetVmDriverVersion\")\n    ret = fn(vgpuInstance, byref(c_driver_version), c_buffer_size)\n    _nvmlCheckReturn(ret)\n    return c_driver_version.value\n\n\ndef nvmlVgpuInstanceGetLicenseStatus(vgpuInstance):\n    c_license_status = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetLicenseStatus\")\n    ret = fn(vgpuInstance, byref(c_license_status))\n    _nvmlCheckReturn(ret)\n    return c_license_status.value\n\n\ndef nvmlVgpuInstanceGetLicenseInfo_v2(vgpuInstance):\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetLicenseInfo_v2\")\n    c_license_info = c_nvmlVgpuLicenseInfo_t()\n    ret = fn(vgpuInstance, byref(c_license_info))\n    _nvmlCheckReturn(ret)\n    return c_license_info\n\n\ndef nvmlVgpuInstanceGetLicenseInfo(vgpuInstance):\n    return nvmlVgpuInstanceGetLicenseInfo_v2(vgpuInstance)\n\n\ndef nvmlVgpuInstanceGetFrameRateLimit(vgpuInstance):\n    c_frl = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetFrameRateLimit\")\n    ret = fn(vgpuInstance, byref(c_frl))\n    _nvmlCheckReturn(ret)\n    return c_frl.value\n\n\ndef nvmlVgpuInstanceGetEccMode(vgpuInstance):\n    c_mode = _nvmlEnableState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetEccMode\")\n    ret = fn(vgpuInstance, byref(c_mode))\n    _nvmlCheckReturn(ret)\n    return c_mode.value\n\n\ndef nvmlVgpuInstanceGetType(vgpuInstance):\n    c_vgpu_type = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetType\")\n    ret = fn(vgpuInstance, 
byref(c_vgpu_type))\n    _nvmlCheckReturn(ret)\n    return c_vgpu_type.value\n\n\ndef nvmlVgpuInstanceGetEncoderCapacity(vgpuInstance):\n    c_encoder_capacity = c_ulonglong(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetEncoderCapacity\")\n    ret = fn(vgpuInstance, byref(c_encoder_capacity))\n    _nvmlCheckReturn(ret)\n    return c_encoder_capacity.value\n\n\ndef nvmlVgpuInstanceSetEncoderCapacity(vgpuInstance, encoder_capacity):\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceSetEncoderCapacity\")\n    return fn(vgpuInstance, encoder_capacity)\n\n\ndef nvmlVgpuInstanceGetFbUsage(vgpuInstance):\n    c_fb_usage = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetFbUsage\")\n    ret = fn(vgpuInstance, byref(c_fb_usage))\n    _nvmlCheckReturn(ret)\n    return c_fb_usage.value\n\n\ndef nvmlVgpuTypeGetCapabilities(vgpuTypeId, capability):\n    c_cap_result = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuTypeGetCapabilities\")\n    ret = fn(vgpuTypeId, _nvmlVgpuCapability_t(capability), byref(c_cap_result))\n    _nvmlCheckReturn(ret)\n    return c_cap_result.value\n\n\ndef nvmlVgpuInstanceGetGpuInstanceId(vgpuInstance):\n    c_id = c_uint(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetGpuInstanceId\")\n    ret = fn(vgpuInstance, byref(c_id))\n    _nvmlCheckReturn(ret)\n    return c_id.value\n\n\n@convertStrBytes\ndef nvmlVgpuInstanceGetGpuPciId(vgpuInstance):\n    c_vgpuPciId = create_string_buffer(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetGpuPciId\")\n    ret = fn(\n        vgpuInstance, c_vgpuPciId, byref(c_uint(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE))\n    )\n    _nvmlCheckReturn(ret)\n    return c_vgpuPciId.value\n\n\ndef nvmlDeviceGetVgpuUtilization(handle, timeStamp):\n    # first call to get the size\n    c_vgpu_count = c_uint(0)\n    c_time_stamp = c_ulonglong(timeStamp)\n    c_sample_value_type = _nvmlValueType_t()\n\n    fn = 
_nvmlGetFunctionPointer(\"nvmlDeviceGetVgpuUtilization\")\n    ret = fn(\n        handle, c_time_stamp, byref(c_sample_value_type), byref(c_vgpu_count), None\n    )\n\n    if ret == NVML_SUCCESS:\n        # special case, no active vGPUs\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        sampleArray = c_vgpu_count.value * c_nvmlVgpuInstanceUtilizationSample_t\n        c_samples = sampleArray()\n\n        # make the call again\n        ret = fn(\n            handle,\n            c_time_stamp,\n            byref(c_sample_value_type),\n            byref(c_vgpu_count),\n            c_samples,\n        )\n        _nvmlCheckReturn(ret)\n\n        return c_samples[0 : c_vgpu_count.value]\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlDeviceGetVgpuInstancesUtilizationInfo(handle, timeStamp):\n    # first call to get the size\n    c_time_stamp = c_ulonglong(timeStamp)\n    c_vgpuUtilInfo = c_nvmlVgpuInstancesUtilizationInfo_v1_t(0)\n    c_vgpuUtilInfo.version = VgpuInstancesUtilizationInfo_v1\n    c_vgpuUtilInfo.sampleValType = _nvmlValueType_t()\n    c_vgpuUtilInfo.vgpuInstanceCount = c_uint(0)\n    c_vgpuUtilInfo.lastSeenTimeStamp = c_time_stamp\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetVgpuInstancesUtilizationInfo\")\n    ret = fn(handle, byref(c_vgpuUtilInfo))\n\n    if ret == NVML_SUCCESS:\n        # special case, no active vGPUs\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        sampleArray = (\n            c_vgpuUtilInfo.vgpuInstanceCount * c_nvmlVgpuInstanceUtilizationInfo_v1_t\n        )\n        c_samples = sampleArray()\n        c_vgpuUtilInfo.vgpuUtilArray = c_samples\n\n        # make the call again\n        ret = fn(handle, byref(c_vgpuUtilInfo))\n        _nvmlCheckReturn(ret)\n\n        return c_samples[0 : c_vgpuUtilInfo.vgpuInstanceCount]\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef 
nvmlDeviceGetP2PStatus(device1, device2, p2pIndex):\n    c_p2pstatus = _nvmlGpuP2PStatus_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetP2PStatus\")\n    ret = fn(device1, device2, p2pIndex, byref(c_p2pstatus))\n    _nvmlCheckReturn(ret)\n    return c_p2pstatus.value\n\n\ndef nvmlDeviceGetGridLicensableFeatures_v4(handle):\n    c_get_grid_licensable_features = c_nvmlGridLicensableFeatures_v4_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGridLicensableFeatures_v4\")\n    ret = fn(handle, byref(c_get_grid_licensable_features))\n    _nvmlCheckReturn(ret)\n\n    return c_get_grid_licensable_features\n\n\ndef nvmlDeviceGetGridLicensableFeatures(handle):\n    return nvmlDeviceGetGridLicensableFeatures_v4(handle)\n\n\ndef nvmlDeviceGetGspFirmwareVersion(handle, version=None):\n    isUserDefined = version is not None\n    if not isUserDefined:\n        version = (c_char * NVML_GSP_FIRMWARE_VERSION_BUF_SIZE)()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGspFirmwareVersion\")\n    ret = fn(handle, version)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS if isUserDefined else version.value\n\n\ndef nvmlDeviceGetGspFirmwareMode(handle, isEnabled=c_uint(), defaultMode=c_uint()):\n    isReference = type(isEnabled) is not c_uint\n    isEnabledRef = isEnabled if isReference else byref(isEnabled)\n    defaultModeRef = defaultMode if isReference else byref(defaultMode)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGspFirmwareMode\")\n    ret = fn(handle, isEnabledRef, defaultModeRef)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS if isReference else [isEnabled.value, defaultMode.value]\n\n\ndef nvmlDeviceGetEncoderCapacity(handle, encoderQueryType):\n    c_encoder_capacity = c_ulonglong(0)\n    c_encoderQuery_type = _nvmlEncoderQueryType_t(encoderQueryType)\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetEncoderCapacity\")\n    ret = fn(handle, c_encoderQuery_type, byref(c_encoder_capacity))\n    _nvmlCheckReturn(ret)\n    return 
c_encoder_capacity.value\n\n\ndef nvmlDeviceGetVgpuProcessUtilization(handle, timeStamp):\n    # first call to get the size\n    c_vgpu_count = c_uint(0)\n    c_time_stamp = c_ulonglong(timeStamp)\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetVgpuProcessUtilization\")\n    ret = fn(handle, c_time_stamp, byref(c_vgpu_count), None)\n\n    if ret == NVML_SUCCESS:\n        # special case, no active vGPUs\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        sampleArray = c_vgpu_count.value * c_nvmlVgpuProcessUtilizationSample_t\n        c_samples = sampleArray()\n\n        # make the call again\n        ret = fn(handle, c_time_stamp, byref(c_vgpu_count), c_samples)\n        _nvmlCheckReturn(ret)\n\n        return c_samples[0 : c_vgpu_count.value]\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlDeviceGetVgpuProcessesUtilizationInfo(handle, timeStamp):\n    # first call to get the size\n    c_time_stamp = c_ulonglong(timeStamp)\n    c_vgpuProcUtilInfo = c_nvmlVgpuProcessesUtilizationInfo_v1_t(0)\n    c_vgpuProcUtilInfo.version = VgpuProcessesUtilizationInfo_v1\n    c_vgpuProcUtilInfo.vgpuProcessCount = c_uint(0)\n    c_vgpuProcUtilInfo.lastSeenTimeStamp = c_time_stamp\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetVgpuProcessesUtilizationInfo\")\n    ret = fn(handle, byref(c_vgpuProcUtilInfo))\n\n    if ret == NVML_SUCCESS:\n        # special case, no active vGPUs\n        return []\n    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        sampleArray = (\n            c_vgpuProcUtilInfo.vgpuProcessCount * c_nvmlVgpuProcessUtilizationInfo_v1_t\n        )\n        c_samples = sampleArray()\n        c_vgpuProcUtilInfo.vgpuProcUtilArray = c_samples\n\n        # make the call again\n        ret = fn(handle, byref(c_vgpuProcUtilInfo))\n        _nvmlCheckReturn(ret)\n\n        return c_samples[0 : c_vgpuProcUtilInfo.vgpuProcessCount]\n    else:\n        # error case\n        
raise NVMLError(ret)\n\n\ndef nvmlDeviceGetEncoderStats(handle):\n    c_encoderCount = c_ulonglong(0)\n    c_encodeFps = c_ulonglong(0)\n    c_encoderLatency = c_ulonglong(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetEncoderStats\")\n    ret = fn(handle, byref(c_encoderCount), byref(c_encodeFps), byref(c_encoderLatency))\n    _nvmlCheckReturn(ret)\n    return (c_encoderCount.value, c_encodeFps.value, c_encoderLatency.value)\n\n\ndef nvmlDeviceGetEncoderSessions(handle):\n    # first call to get the size\n    c_session_count = c_uint(0)\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetEncoderSessions\")\n    ret = fn(handle, byref(c_session_count), None)\n\n    if ret == NVML_SUCCESS:\n        if c_session_count.value != 0:\n            # typical case\n            session_array = c_nvmlEncoderSession_t * c_session_count.value\n            c_sessions = session_array()\n\n            # make the call again\n            ret = fn(handle, byref(c_session_count), c_sessions)\n            _nvmlCheckReturn(ret)\n            sessions = []\n            for i in range(c_session_count.value):\n                sessions.append(c_sessions[i])\n            return sessions\n        else:\n            return []  # no active sessions\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlDeviceGetFBCStats(handle):\n    c_fbcStats = c_nvmlFBCStats_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetFBCStats\")\n    ret = fn(handle, byref(c_fbcStats))\n    _nvmlCheckReturn(ret)\n    return c_fbcStats\n\n\ndef nvmlDeviceGetFBCSessions(handle):\n    # first call to get the size\n    c_session_count = c_uint(0)\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetFBCSessions\")\n    ret = fn(handle, byref(c_session_count), None)\n\n    if ret == NVML_SUCCESS:\n        if c_session_count.value != 0:\n            # typical case\n            session_array = c_nvmlFBCSession_t * c_session_count.value\n            c_sessions = session_array()\n\n            # make 
the call again\n            ret = fn(handle, byref(c_session_count), c_sessions)\n            _nvmlCheckReturn(ret)\n            sessions = []\n            for i in range(c_session_count.value):\n                sessions.append(c_sessions[i])\n            return sessions\n        else:\n            return []  # no active sessions\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlVgpuInstanceGetEncoderStats(vgpuInstance):\n    c_encoderCount = c_ulonglong(0)\n    c_encodeFps = c_ulonglong(0)\n    c_encoderLatency = c_ulonglong(0)\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetEncoderStats\")\n    ret = fn(\n        vgpuInstance, byref(c_encoderCount), byref(c_encodeFps), byref(c_encoderLatency)\n    )\n    _nvmlCheckReturn(ret)\n    return (c_encoderCount.value, c_encodeFps.value, c_encoderLatency.value)\n\n\ndef nvmlVgpuInstanceGetEncoderSessions(vgpuInstance):\n    # first call to get the size\n    c_session_count = c_uint(0)\n\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetEncoderSessions\")\n    ret = fn(vgpuInstance, byref(c_session_count), None)\n\n    if ret == NVML_SUCCESS:\n        if c_session_count.value != 0:\n            # typical case\n            session_array = c_nvmlEncoderSession_t * c_session_count.value\n            c_sessions = session_array()\n\n            # make the call again\n            ret = fn(vgpuInstance, byref(c_session_count), c_sessions)\n            _nvmlCheckReturn(ret)\n            sessions = []\n            for i in range(c_session_count.value):\n                sessions.append(c_sessions[i])\n            return sessions\n        else:\n            return []  # no active sessions\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlVgpuInstanceGetFBCStats(vgpuInstance):\n    c_fbcStats = c_nvmlFBCStats_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetFBCStats\")\n    ret = fn(vgpuInstance, byref(c_fbcStats))\n    _nvmlCheckReturn(ret)\n    return 
c_fbcStats\n\n\ndef nvmlVgpuInstanceGetFBCSessions(vgpuInstance):\n    # first call to get the size\n    c_session_count = c_uint(0)\n\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetFBCSessions\")\n    ret = fn(vgpuInstance, byref(c_session_count), None)\n\n    if ret == NVML_SUCCESS:\n        if c_session_count.value != 0:\n            # typical case\n            session_array = c_nvmlFBCSession_t * c_session_count.value\n            c_sessions = session_array()\n\n            # make the call again\n            ret = fn(vgpuInstance, byref(c_session_count), c_sessions)\n            _nvmlCheckReturn(ret)\n            sessions = []\n            for i in range(c_session_count.value):\n                sessions.append(c_sessions[i])\n            return sessions\n        else:\n            return []  # no active sessions\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlDeviceGetProcessUtilization(handle, timeStamp):\n    # first call to get the size\n    c_count = c_uint(0)\n    c_time_stamp = c_ulonglong(timeStamp)\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetProcessUtilization\")\n    ret = fn(handle, None, byref(c_count), c_time_stamp)\n\n    if ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        sampleArray = c_count.value * c_nvmlProcessUtilizationSample_t\n        c_samples = sampleArray()\n\n        # make the call again\n        ret = fn(handle, c_samples, byref(c_count), c_time_stamp)\n        _nvmlCheckReturn(ret)\n\n        return c_samples[0 : c_count.value]\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlDeviceGetProcessesUtilizationInfo(handle, timeStamp):\n    # first call to get the size\n    c_time_stamp = c_ulonglong(timeStamp)\n    c_processesUtilInfo = c_nvmlProcessesUtilizationInfo_v1_t(0)\n    c_processesUtilInfo.version = ProcessesUtilizationInfo_v1\n    c_processesUtilInfo.processSamplesCount = c_uint(0)\n    c_processesUtilInfo.lastSeenTimeStamp = 
c_time_stamp\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetProcessesUtilizationInfo\")\n    ret = fn(handle, byref(c_processesUtilInfo))\n\n    if ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        # typical case\n        sampleArray = (\n            c_processesUtilInfo.processSamplesCount * c_nvmlProcessUtilizationInfo_v1_t\n        )\n        c_samples = sampleArray()\n        c_processesUtilInfo.procUtilArray = c_samples\n\n        # make the call again\n        ret = fn(handle, byref(c_processesUtilInfo))\n        _nvmlCheckReturn(ret)\n\n        return c_samples[0 : c_processesUtilInfo.processSamplesCount]\n    else:\n        # error case\n        raise NVMLError(ret)\n\n\ndef nvmlVgpuInstanceGetMetadata(vgpuInstance):\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetMetadata\")\n    c_vgpuMetadata = c_nvmlVgpuMetadata_t()\n    c_bufferSize = c_uint(0)\n    # Make the first NVML API call to get the c_bufferSize value.\n    # We have already allocated required buffer above.\n    ret = fn(vgpuInstance, byref(c_vgpuMetadata), byref(c_bufferSize))\n    if ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        ret = fn(vgpuInstance, byref(c_vgpuMetadata), byref(c_bufferSize))\n        _nvmlCheckReturn(ret)\n    else:\n        raise NVMLError(ret)\n    return c_vgpuMetadata\n\n\ndef nvmlDeviceGetVgpuMetadata(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetVgpuMetadata\")\n    c_vgpuPgpuMetadata = c_nvmlVgpuPgpuMetadata_t()\n    c_bufferSize = c_uint(0)\n    # Make the first NVML API call to get the c_bufferSize value.\n    # We have already allocated required buffer above.\n    ret = fn(handle, byref(c_vgpuPgpuMetadata), byref(c_bufferSize))\n    if ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        ret = fn(handle, byref(c_vgpuPgpuMetadata), byref(c_bufferSize))\n        _nvmlCheckReturn(ret)\n    else:\n        raise NVMLError(ret)\n    return c_vgpuPgpuMetadata\n\n\ndef nvmlGetVgpuCompatibility(vgpuMetadata, pgpuMetadata):\n    fn = 
_nvmlGetFunctionPointer(\"nvmlGetVgpuCompatibility\")\n    c_vgpuPgpuCompatibility = c_nvmlVgpuPgpuCompatibility_t()\n    ret = fn(byref(vgpuMetadata), byref(pgpuMetadata), byref(c_vgpuPgpuCompatibility))\n    _nvmlCheckReturn(ret)\n    return c_vgpuPgpuCompatibility\n\n\n@convertStrBytes\ndef nvmlDeviceGetPgpuMetadataString(handle):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPgpuMetadataString\")\n    c_pgpuMetadata = create_string_buffer(NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE)\n    c_bufferSize = c_uint(0)\n    # Make the first NVML API call to get the c_bufferSize value.\n    # We have already allocated required buffer above.\n    ret = fn(handle, byref(c_pgpuMetadata), byref(c_bufferSize))\n    if ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        ret = fn(handle, byref(c_pgpuMetadata), byref(c_bufferSize))\n        _nvmlCheckReturn(ret)\n    else:\n        raise NVMLError(ret)\n    return (c_pgpuMetadata.value, c_bufferSize.value)\n\n\ndef nvmlDeviceGetVgpuSchedulerLog(handle):\n    c_vgpu_sched_log = c_nvmlVgpuSchedulerLog_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetVgpuSchedulerLog\")\n    ret = fn(handle, byref(c_vgpu_sched_log))\n    _nvmlCheckReturn(ret)\n    return c_vgpu_sched_log\n\n\ndef nvmlDeviceGetVgpuSchedulerState(handle):\n    c_vgpu_sched_state = c_nvmlVgpuSchedulerGetState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetVgpuSchedulerState\")\n    ret = fn(handle, byref(c_vgpu_sched_state))\n    _nvmlCheckReturn(ret)\n    return c_vgpu_sched_state\n\n\ndef nvmlDeviceGetVgpuSchedulerCapabilities(handle):\n    c_vgpu_sched_caps = c_nvmlVgpuSchedulerCapabilities_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetVgpuSchedulerCapabilities\")\n    ret = fn(handle, byref(c_vgpu_sched_caps))\n    _nvmlCheckReturn(ret)\n    return c_vgpu_sched_caps\n\n\ndef nvmlDeviceSetVgpuSchedulerState(handle, sched_state):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetVgpuSchedulerState\")\n    ret = fn(handle, byref(sched_state))\n    
_nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlSetVgpuVersion(vgpuVersion):\n    fn = _nvmlGetFunctionPointer(\"nvmlSetVgpuVersion\")\n    ret = fn(byref(vgpuVersion))\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlGetVgpuVersion(supported=None, current=None):\n    isUserDefined = (supported is not None) or (current is not None)\n    if not isUserDefined:\n        supported = c_nvmlVgpuVersion_t()\n        current = c_nvmlVgpuVersion_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlGetVgpuVersion\")\n    ret = fn(byref(supported), byref(current))\n    _nvmlCheckReturn(ret)\n    return (\n        NVML_SUCCESS\n        if isUserDefined\n        else [\n            (supported.minVersion, supported.maxVersion),\n            (current.minVersion, current.maxVersion),\n        ]\n    )\n\n\ndef nvmlVgpuInstanceGetAccountingMode(vgpuInstance):\n    c_mode = _nvmlEnableState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetAccountingMode\")\n    ret = fn(vgpuInstance, byref(c_mode))\n    _nvmlCheckReturn(ret)\n    return c_mode.value\n\n\ndef nvmlVgpuInstanceGetAccountingPids(vgpuInstance):\n    c_pidCount = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetAccountingPids\")\n    ret = fn(vgpuInstance, byref(c_pidCount), None)\n    if ret == NVML_ERROR_INSUFFICIENT_SIZE:\n        sampleArray = c_pidCount.value * c_uint\n        c_pidArray = sampleArray()\n        ret = fn(vgpuInstance, byref(c_pidCount), byref(c_pidArray))\n        _nvmlCheckReturn(ret)\n    else:\n        raise NVMLError(ret)\n    return (c_pidCount, c_pidArray)\n\n\ndef nvmlVgpuInstanceGetAccountingStats(vgpuInstance, pid):\n    c_accountingStats = c_nvmlAccountingStats_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlVgpuInstanceGetAccountingStats\")\n    ret = fn(vgpuInstance, pid, byref(c_accountingStats))\n    _nvmlCheckReturn(ret)\n    return c_accountingStats\n\n\ndef nvmlVgpuInstanceClearAccountingPids(vgpuInstance):\n    fn = 
_nvmlGetFunctionPointer(\"nvmlVgpuInstanceClearAccountingPids\")\n    ret = fn(vgpuInstance)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlGetExcludedDeviceCount():\n    c_count = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlGetExcludedDeviceCount\")\n    ret = fn(byref(c_count))\n    _nvmlCheckReturn(ret)\n    return c_count.value\n\n\ndef nvmlGetExcludedDeviceInfoByIndex(index):\n    c_index = c_uint(index)\n    info = c_nvmlExcludedDeviceInfo_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlGetExcludedDeviceInfoByIndex\")\n    ret = fn(c_index, byref(info))\n    _nvmlCheckReturn(ret)\n    return info\n\n\ndef nvmlDeviceGetHostVgpuMode(handle):\n    c_host_vgpu_mode = _nvmlHostVgpuMode_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetHostVgpuMode\")\n    ret = fn(handle, byref(c_host_vgpu_mode))\n    _nvmlCheckReturn(ret)\n    return c_host_vgpu_mode.value\n\n\ndef nvmlDeviceSetMigMode(device, mode):\n    c_activationStatus = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetMigMode\")\n    ret = fn(device, mode, byref(c_activationStatus))\n    _nvmlCheckReturn(ret)\n    return c_activationStatus.value\n\n\ndef nvmlDeviceGetMigMode(device):\n    c_currentMode = c_uint()\n    c_pendingMode = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMigMode\")\n    ret = fn(device, byref(c_currentMode), byref(c_pendingMode))\n    _nvmlCheckReturn(ret)\n    return [c_currentMode.value, c_pendingMode.value]\n\n\ndef nvmlDeviceGetGpuInstanceProfileInfo(device, profile, version=2):\n    if version == 2:\n        c_info = c_nvmlGpuInstanceProfileInfo_v2_t()\n        fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGpuInstanceProfileInfoV\")\n    elif version == 1:\n        c_info = c_nvmlGpuInstanceProfileInfo_t()\n        fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGpuInstanceProfileInfo\")\n    else:\n        raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND)\n    ret = fn(device, profile, byref(c_info))\n    _nvmlCheckReturn(ret)\n    return 
c_info\n\n\n# Define function alias for the API exposed by NVML\nnvmlDeviceGetGpuInstanceProfileInfoV = nvmlDeviceGetGpuInstanceProfileInfo\n\n\ndef nvmlDeviceGetGpuInstanceRemainingCapacity(device, profileId):\n    c_count = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGpuInstanceRemainingCapacity\")\n    ret = fn(device, profileId, byref(c_count))\n    _nvmlCheckReturn(ret)\n    return c_count.value\n\n\ndef nvmlDeviceGetGpuInstancePossiblePlacements(\n    device, profileId, placementsRef, countRef\n):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGpuInstancePossiblePlacements_v2\")\n    ret = fn(device, profileId, placementsRef, countRef)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceCreateGpuInstance(device, profileId):\n    c_instance = c_nvmlGpuInstance_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceCreateGpuInstance\")\n    ret = fn(device, profileId, byref(c_instance))\n    _nvmlCheckReturn(ret)\n    return c_instance\n\n\ndef nvmlDeviceCreateGpuInstanceWithPlacement(device, profileId, placement):\n    c_instance = c_nvmlGpuInstance_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceCreateGpuInstanceWithPlacement\")\n    ret = fn(device, profileId, placement, byref(c_instance))\n    _nvmlCheckReturn(ret)\n    return c_instance\n\n\ndef nvmlGpuInstanceDestroy(gpuInstance):\n    fn = _nvmlGetFunctionPointer(\"nvmlGpuInstanceDestroy\")\n    ret = fn(gpuInstance)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceGetGpuInstances(device, profileId, gpuInstancesRef, countRef):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGpuInstances\")\n    ret = fn(device, profileId, gpuInstancesRef, countRef)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceGetGpuInstanceById(device, gpuInstanceId):\n    c_instance = c_nvmlGpuInstance_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGpuInstanceById\")\n    ret = fn(device, gpuInstanceId, byref(c_instance))\n    _nvmlCheckReturn(ret)\n  
  return c_instance\n\n\ndef nvmlGpuInstanceGetInfo(gpuInstance):\n    c_info = c_nvmlGpuInstanceInfo_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlGpuInstanceGetInfo\")\n    ret = fn(gpuInstance, byref(c_info))\n    _nvmlCheckReturn(ret)\n    return c_info\n\n\ndef nvmlGpuInstanceGetComputeInstanceProfileInfo(\n    device, profile, engProfile, version=2\n):\n    if version == 2:\n        c_info = c_nvmlComputeInstanceProfileInfo_v2_t()\n        fn = _nvmlGetFunctionPointer(\"nvmlGpuInstanceGetComputeInstanceProfileInfoV\")\n    elif version == 1:\n        c_info = c_nvmlComputeInstanceProfileInfo_t()\n        fn = _nvmlGetFunctionPointer(\"nvmlGpuInstanceGetComputeInstanceProfileInfo\")\n    else:\n        raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND)\n    ret = fn(device, profile, engProfile, byref(c_info))\n    _nvmlCheckReturn(ret)\n    return c_info\n\n\n# Define function alias for the API exposed by NVML\nnvmlGpuInstanceGetComputeInstanceProfileInfoV = (\n    nvmlGpuInstanceGetComputeInstanceProfileInfo\n)\n\n\ndef nvmlGpuInstanceGetComputeInstanceRemainingCapacity(gpuInstance, profileId):\n    c_count = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlGpuInstanceGetComputeInstanceRemainingCapacity\")\n    ret = fn(gpuInstance, profileId, byref(c_count))\n    _nvmlCheckReturn(ret)\n    return c_count.value\n\n\ndef nvmlGpuInstanceGetComputeInstancePossiblePlacements(\n    gpuInstance, profileId, placementsRef, countRef\n):\n    fn = _nvmlGetFunctionPointer(\"nvmlGpuInstanceGetComputeInstancePossiblePlacements\")\n    ret = fn(gpuInstance, profileId, placementsRef, countRef)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlGpuInstanceCreateComputeInstance(gpuInstance, profileId):\n    c_instance = c_nvmlComputeInstance_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlGpuInstanceCreateComputeInstance\")\n    ret = fn(gpuInstance, profileId, byref(c_instance))\n    _nvmlCheckReturn(ret)\n    return c_instance\n\n\ndef 
nvmlGpuInstanceCreateComputeInstanceWithPlacement(\n    gpuInstance, profileId, placement\n):\n    c_instance = c_nvmlComputeInstance_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlGpuInstanceCreateComputeInstanceWithPlacement\")\n    ret = fn(gpuInstance, profileId, placement, byref(c_instance))\n    _nvmlCheckReturn(ret)\n    return c_instance\n\n\ndef nvmlComputeInstanceDestroy(computeInstance):\n    fn = _nvmlGetFunctionPointer(\"nvmlComputeInstanceDestroy\")\n    ret = fn(computeInstance)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlGpuInstanceGetComputeInstances(\n    gpuInstance, profileId, computeInstancesRef, countRef\n):\n    fn = _nvmlGetFunctionPointer(\"nvmlGpuInstanceGetComputeInstances\")\n    ret = fn(gpuInstance, profileId, computeInstancesRef, countRef)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlGpuInstanceGetComputeInstanceById(gpuInstance, computeInstanceId):\n    c_instance = c_nvmlComputeInstance_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlGpuInstanceGetComputeInstanceById\")\n    ret = fn(gpuInstance, computeInstanceId, byref(c_instance))\n    _nvmlCheckReturn(ret)\n    return c_instance\n\n\ndef nvmlComputeInstanceGetInfo_v2(computeInstance):\n    c_info = c_nvmlComputeInstanceInfo_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlComputeInstanceGetInfo_v2\")\n    ret = fn(computeInstance, byref(c_info))\n    _nvmlCheckReturn(ret)\n    return c_info\n\n\ndef nvmlComputeInstanceGetInfo(computeInstance):\n    return nvmlComputeInstanceGetInfo_v2(computeInstance)\n\n\ndef nvmlDeviceIsMigDeviceHandle(device):\n    c_isMigDevice = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceIsMigDeviceHandle\")\n    ret = fn(device, byref(c_isMigDevice))\n    _nvmlCheckReturn(ret)\n    return c_isMigDevice\n\n\ndef nvmlDeviceGetGpuInstanceId(device):\n    c_gpuInstanceId = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGpuInstanceId\")\n    ret = fn(device, byref(c_gpuInstanceId))\n    _nvmlCheckReturn(ret)\n 
   return c_gpuInstanceId.value\n\n\ndef nvmlDeviceGetComputeInstanceId(device):\n    c_computeInstanceId = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetComputeInstanceId\")\n    ret = fn(device, byref(c_computeInstanceId))\n    _nvmlCheckReturn(ret)\n    return c_computeInstanceId.value\n\n\ndef nvmlDeviceGetMaxMigDeviceCount(device):\n    c_count = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMaxMigDeviceCount\")\n    ret = fn(device, byref(c_count))\n    _nvmlCheckReturn(ret)\n    return c_count.value\n\n\ndef nvmlDeviceGetMigDeviceHandleByIndex(device, index):\n    c_index = c_uint(index)\n    migDevice = c_nvmlDevice_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMigDeviceHandleByIndex\")\n    ret = fn(device, c_index, byref(migDevice))\n    _nvmlCheckReturn(ret)\n    return migDevice\n\n\ndef nvmlDeviceGetDeviceHandleFromMigDeviceHandle(migDevice):\n    device = c_nvmlDevice_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetDeviceHandleFromMigDeviceHandle\")\n    ret = fn(migDevice, byref(device))\n    _nvmlCheckReturn(ret)\n    return device\n\n\ndef nvmlDeviceGetAttributes_v2(device):\n    c_attrs = c_nvmlDeviceAttributes()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetAttributes_v2\")\n    ret = fn(device, byref(c_attrs))\n    _nvmlCheckReturn(ret)\n    return c_attrs\n\n\ndef nvmlDeviceGetAttributes(device):\n    return nvmlDeviceGetAttributes_v2(device)\n\n\ndef nvmlDeviceGetRemappedRows(device):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetRemappedRows\")\n    c_corr = c_uint()\n    c_unc = c_uint()\n    c_bpending = c_uint()\n    c_bfailure = c_uint()\n    ret = fn(device, byref(c_corr), byref(c_unc), byref(c_bpending), byref(c_bfailure))\n    _nvmlCheckReturn(ret)\n    return (c_corr.value, c_unc.value, c_bpending.value, c_bfailure.value)\n\n\ndef nvmlDeviceGetRowRemapperHistogram(device):\n    c_vals = c_nvmlRowRemapperHistogramValues()\n    fn = 
_nvmlGetFunctionPointer(\"nvmlDeviceGetRowRemapperHistogram\")\n    ret = fn(device, byref(c_vals))\n    _nvmlCheckReturn(ret)\n    return c_vals\n\n\ndef nvmlDeviceGetArchitecture(device):\n    arch = _nvmlDeviceArchitecture_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetArchitecture\")\n    ret = fn(device, byref(arch))\n    _nvmlCheckReturn(ret)\n    return arch.value\n\n\ndef nvmlDeviceGetBusType(device):\n    c_busType = _nvmlBusType_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetBusType\")\n    ret = fn(device, byref(c_busType))\n    _nvmlCheckReturn(ret)\n    return c_busType.value\n\n\ndef nvmlDeviceGetIrqNum(device):\n    c_irqNum = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetIrqNum\")\n    ret = fn(device, byref(c_irqNum))\n    _nvmlCheckReturn(ret)\n    return c_irqNum.value\n\n\ndef nvmlDeviceGetNumGpuCores(device):\n    c_numCores = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetNumGpuCores\")\n    ret = fn(device, byref(c_numCores))\n    _nvmlCheckReturn(ret)\n    return c_numCores.value\n\n\ndef nvmlDeviceGetPowerSource(device):\n    c_powerSource = _nvmlPowerSource_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPowerSource\")\n    ret = fn(device, byref(c_powerSource))\n    _nvmlCheckReturn(ret)\n    return c_powerSource.value\n\n\ndef nvmlDeviceGetMemoryBusWidth(device):\n    c_memBusWidth = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMemoryBusWidth\")\n    ret = fn(device, byref(c_memBusWidth))\n    _nvmlCheckReturn(ret)\n    return c_memBusWidth.value\n\n\ndef nvmlDeviceGetPcieLinkMaxSpeed(device):\n    c_speed = _nvmlPcieLinkMaxSpeed_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPcieLinkMaxSpeed\")\n    ret = fn(device, byref(c_speed))\n    _nvmlCheckReturn(ret)\n    return c_speed.value\n\n\ndef nvmlDeviceGetAdaptiveClockInfoStatus(device):\n    c_adaptiveClockInfoStatus = _nvmlAdaptiveClockInfoStatus_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetAdaptiveClockInfoStatus\")\n 
   ret = fn(device, byref(c_adaptiveClockInfoStatus))\n    _nvmlCheckReturn(ret)\n    return c_adaptiveClockInfoStatus.value\n\n\ndef nvmlDeviceGetPcieSpeed(device):\n    c_speed = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPcieSpeed\")\n    ret = fn(device, byref(c_speed))\n    _nvmlCheckReturn(ret)\n    return c_speed.value\n\n\ndef nvmlDeviceGetDynamicPstatesInfo(\n    device, c_dynamicpstatesinfo=c_nvmlGpuDynamicPstatesInfo_t()\n):\n    isReference = type(c_dynamicpstatesinfo) is not c_nvmlGpuDynamicPstatesInfo_t\n    dynamicpstatesinfoRef = (\n        c_dynamicpstatesinfo if isReference else byref(c_dynamicpstatesinfo)\n    )\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetDynamicPstatesInfo\")\n    ret = fn(device, dynamicpstatesinfoRef)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS if isReference else c_dynamicpstatesinfo\n\n\ndef nvmlDeviceSetFanSpeed_v2(handle, index, speed):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetFanSpeed_v2\")\n    ret = fn(handle, index, speed)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceGetThermalSettings(\n    device, sensorindex, c_thermalsettings=c_nvmlGpuThermalSettings_t()\n):\n    isReference = type(c_thermalsettings) is not c_nvmlGpuThermalSettings_t\n    thermalsettingsRef = c_thermalsettings if isReference else byref(c_thermalsettings)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetThermalSettings\")\n    ret = fn(device, sensorindex, thermalsettingsRef)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS if isReference else c_thermalsettings.sensor[:]\n\n\ndef nvmlDeviceGetMinMaxClockOfPState(\n    device, clockType, pstate, minClockMHz=c_uint(), maxClockMHz=c_uint()\n):\n    isReference = (type(minClockMHz) is not c_uint) or (type(maxClockMHz) is not c_uint)\n    minClockMHzRef = minClockMHz if isReference else byref(minClockMHz)\n    maxClockMHzRef = maxClockMHz if isReference else byref(maxClockMHz)\n    fn = 
_nvmlGetFunctionPointer(\"nvmlDeviceGetMinMaxClockOfPState\")\n    ret = fn(\n        device,\n        _nvmlClockType_t(clockType),\n        _nvmlPstates_t(pstate),\n        minClockMHzRef,\n        maxClockMHzRef,\n    )\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS if isReference else (minClockMHz.value, maxClockMHz.value)\n\n\nclass c_nvmlClockOffset_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"type\", _nvmlClockType_t),\n        (\"pstate\", _nvmlPstates_t),\n        (\"clockOffsetMHz\", c_int),\n        (\"minClockOffsetMHz\", c_int),\n        (\"maxClockOffsetMHz\", c_int),\n    ]\n\n\nnvmlClockOffset_v1 = 0x1000018\n\n\ndef nvmlDeviceGetClockOffsets(device, info):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetClockOffsets\")\n    ret = fn(device, info)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceSetClockOffsets(device, info):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetClockOffsets\")\n    ret = fn(device, info)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceGetSupportedPerformanceStates(device):\n    pstates = []\n    c_count = c_uint(NVML_MAX_GPU_PERF_PSTATES)\n    c_size = sizeof(c_uint) * c_count.value\n\n    # NOTE: use 'c_uint' to represent the size of the nvmlPstate_t enumeration.\n    pstates_array = _nvmlPstates_t * c_count.value\n    c_pstates = pstates_array()\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetSupportedPerformanceStates\")\n    ret = fn(device, c_pstates, c_size)\n    _nvmlCheckReturn(ret)\n\n    for value in c_pstates:\n        if value != NVML_PSTATE_UNKNOWN:\n            pstates.append(value)\n\n    return pstates\n\n\ndef nvmlDeviceGetGpcClkVfOffset(device):\n    offset = c_int32()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGpcClkVfOffset\")\n    ret = fn(device, byref(offset))\n    _nvmlCheckReturn(ret)\n    return offset.value\n\n\ndef nvmlDeviceSetGpcClkVfOffset(device, offset):\n    c_offset = c_int32(offset)\n    fn = 
_nvmlGetFunctionPointer(\"nvmlDeviceSetGpcClkVfOffset\")\n    ret = fn(device, c_offset)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceGetGpcClkMinMaxVfOffset(device, minOffset=c_int(), maxOffset=c_int()):\n    isReference = (type(minOffset) is not c_int) or (type(maxOffset) is not c_int)\n    minOffsetRef = minOffset if isReference else byref(minOffset)\n    maxOffsetRef = maxOffset if isReference else byref(maxOffset)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGpcClkMinMaxVfOffset\")\n    ret = fn(device, minOffsetRef, maxOffsetRef)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS if isReference else (minOffset.value, maxOffset.value)\n\n\ndef nvmlDeviceGetMemClkVfOffset(device):\n    offset = c_int32()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMemClkVfOffset\")\n    ret = fn(device, byref(offset))\n    _nvmlCheckReturn(ret)\n    return offset.value\n\n\ndef nvmlDeviceSetMemClkVfOffset(device, offset):\n    c_offset = c_int32(offset)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetMemClkVfOffset\")\n    ret = fn(device, c_offset)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceGetMemClkMinMaxVfOffset(device, minOffset=c_int(), maxOffset=c_int()):\n    isReference = (type(minOffset) is not c_int) or (type(maxOffset) is not c_int)\n    minOffsetRef = minOffset if isReference else byref(minOffset)\n    maxOffsetRef = maxOffset if isReference else byref(maxOffset)\n\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetMemClkMinMaxVfOffset\")\n    ret = fn(device, minOffsetRef, maxOffsetRef)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS if isReference else (minOffset.value, maxOffset.value)\n\n\ndef nvmlSystemSetConfComputeGpusReadyState(state):\n    c_state = c_uint(state)\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemSetConfComputeGpusReadyState\")\n    ret = fn(c_state)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlSystemGetConfComputeGpusReadyState():\n    c_state = c_uint()\n    
fn = _nvmlGetFunctionPointer(\"nvmlSystemGetConfComputeGpusReadyState\")\n    ret = fn(byref(c_state))\n    _nvmlCheckReturn(ret)\n    return c_state.value\n\n\ndef nvmlSystemGetConfComputeCapabilities():\n    c_ccSysCaps = c_nvmlConfComputeSystemCaps_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemGetConfComputeCapabilities\")\n    ret = fn(byref(c_ccSysCaps))\n    _nvmlCheckReturn(ret)\n    return c_ccSysCaps\n\n\ndef nvmlSystemGetConfComputeState():\n    c_state = c_nvmlConfComputeSystemState_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemGetConfComputeState\")\n    ret = fn(byref(c_state))\n    _nvmlCheckReturn(ret)\n    return c_state\n\n\ndef nvmlSystemGetConfComputeSettings(settings):\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemGetConfComputeSettings\")\n    return fn(settings)\n\n\ndef nvmlDeviceSetConfComputeUnprotectedMemSize(device, c_ccMemSize):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetConfComputeUnprotectedMemSize\")\n    ret = fn(device, c_ccMemSize)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceGetConfComputeMemSizeInfo(device):\n    c_ccMemSize = c_nvmlConfComputeMemSizeInfo_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetConfComputeMemSizeInfo\")\n    ret = fn(device, byref(c_ccMemSize))\n    _nvmlCheckReturn(ret)\n    return c_ccMemSize\n\n\ndef nvmlDeviceGetConfComputeProtectedMemoryUsage(device):\n    c_memory = c_nvmlMemory_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetConfComputeProtectedMemoryUsage\")\n    ret = fn(device, byref(c_memory))\n    _nvmlCheckReturn(ret)\n    return c_memory\n\n\ndef nvmlDeviceGetConfComputeGpuCertificate(device):\n    c_cert = c_nvmlConfComputeGpuCertificate_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetConfComputeGpuCertificate\")\n    ret = fn(device, byref(c_cert))\n    _nvmlCheckReturn(ret)\n    return c_cert\n\n\ndef nvmlDeviceGetConfComputeGpuAttestationReport(device, c_nonce):\n    c_attestReport = c_nvmlConfComputeGpuAttestationReport_t()\n    
c_nonce_arr = (c_uint8 * len(c_nonce))(*(c_nonce))\n    setattr(c_attestReport, \"nonce\", c_nonce_arr)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetConfComputeGpuAttestationReport\")\n    ret = fn(device, byref(c_attestReport))\n    _nvmlCheckReturn(ret)\n    return c_attestReport\n\n\ndef nvmlSystemSetConfComputeKeyRotationThresholdInfo(max_atk_adv):\n    c_keyRotationThrInfo = c_nvmlConfComputeSetKeyRotationThresholdInfo_t(0)\n    c_keyRotationThrInfo.version = ConfComputeSetKeyRotationThresholdInfo_v1\n    c_keyRotationThrInfo.maxAttackerAdvantage = max_atk_adv\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemSetConfComputeKeyRotationThresholdInfo\")\n    ret = fn(byref(c_keyRotationThrInfo))\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlSystemGetConfComputeKeyRotationThresholdInfo():\n    c_keyRotationThrInfo = c_nvmlConfComputeGetKeyRotationThresholdInfo_t(0)\n    c_keyRotationThrInfo.version = ConfComputeGetKeyRotationThresholdInfo_v1\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemGetConfComputeKeyRotationThresholdInfo\")\n    ret = fn(byref(c_keyRotationThrInfo))\n    _nvmlCheckReturn(ret)\n    return c_keyRotationThrInfo\n\n\n## GPM ##\n#########\n\n## Enums/defines\n\n#### GPM Metric Identifiers\nNVML_GPM_METRIC_GRAPHICS_UTIL = (\n    1  # Percentage of time any compute/graphics app was active on the GPU. 0.0 - 100.0\n)\nNVML_GPM_METRIC_SM_UTIL = 2  # Percentage of SMs that were busy. 0.0 - 100.0\nNVML_GPM_METRIC_SM_OCCUPANCY = (\n    3  # Percentage of warps that were active vs theoretical maximum. 0.0 - 100.0\n)\nNVML_GPM_METRIC_INTEGER_UTIL = (\n    4  # Percentage of time the GPU's SMs were doing integer operations. 0.0 - 100.0\n)\nNVML_GPM_METRIC_ANY_TENSOR_UTIL = (\n    5  # Percentage of time the GPU's SMs were doing ANY tensor operations. 0.0 - 100.0\n)\nNVML_GPM_METRIC_DFMA_TENSOR_UTIL = (\n    6  # Percentage of time the GPU's SMs were doing DFMA tensor operations. 
0.0 - 100.0\n)\nNVML_GPM_METRIC_HMMA_TENSOR_UTIL = (\n    7  # Percentage of time the GPU's SMs were doing HMMA tensor operations. 0.0 - 100.0\n)\nNVML_GPM_METRIC_IMMA_TENSOR_UTIL = (\n    9  # Percentage of time the GPU's SMs were doing IMMA tensor operations. 0.0 - 100.0\n)\nNVML_GPM_METRIC_DRAM_BW_UTIL = (\n    10  # Percentage of DRAM bw used vs theoretical maximum. 0.0 - 100.0\n)\nNVML_GPM_METRIC_FP64_UTIL = (\n    11  # Percentage of time the GPU's SMs were doing non-tensor FP64 math. 0.0 - 100.0\n)\nNVML_GPM_METRIC_FP32_UTIL = (\n    12  # Percentage of time the GPU's SMs were doing non-tensor FP32 math. 0.0 - 100.0\n)\nNVML_GPM_METRIC_FP16_UTIL = (\n    13  # Percentage of time the GPU's SMs were doing non-tensor FP16 math. 0.0 - 100.0\n)\nNVML_GPM_METRIC_PCIE_TX_PER_SEC = 20  # PCIe traffic from this GPU in MiB/sec\nNVML_GPM_METRIC_PCIE_RX_PER_SEC = 21  # PCIe traffic to this GPU in MiB/sec\nNVML_GPM_METRIC_NVDEC_0_UTIL = 30  # Percent utilization of NVDEC 0. 0.0 - 100.0\nNVML_GPM_METRIC_NVDEC_1_UTIL = 31  # Percent utilization of NVDEC 1. 0.0 - 100.0\nNVML_GPM_METRIC_NVDEC_2_UTIL = 32  # Percent utilization of NVDEC 2. 0.0 - 100.0\nNVML_GPM_METRIC_NVDEC_3_UTIL = 33  # Percent utilization of NVDEC 3. 0.0 - 100.0\nNVML_GPM_METRIC_NVDEC_4_UTIL = 34  # Percent utilization of NVDEC 4. 0.0 - 100.0\nNVML_GPM_METRIC_NVDEC_5_UTIL = 35  # Percent utilization of NVDEC 5. 0.0 - 100.0\nNVML_GPM_METRIC_NVDEC_6_UTIL = 36  # Percent utilization of NVDEC 6. 0.0 - 100.0\nNVML_GPM_METRIC_NVDEC_7_UTIL = 37  # Percent utilization of NVDEC 7. 0.0 - 100.0\nNVML_GPM_METRIC_NVJPG_0_UTIL = 40  # Percent utilization of NVJPG 0. 0.0 - 100.0\nNVML_GPM_METRIC_NVJPG_1_UTIL = 41  # Percent utilization of NVJPG 1. 0.0 - 100.0\nNVML_GPM_METRIC_NVJPG_2_UTIL = 42  # Percent utilization of NVJPG 2. 0.0 - 100.0\nNVML_GPM_METRIC_NVJPG_3_UTIL = 43  # Percent utilization of NVJPG 3. 0.0 - 100.0\nNVML_GPM_METRIC_NVJPG_4_UTIL = 44  # Percent utilization of NVJPG 4. 
0.0 - 100.0\nNVML_GPM_METRIC_NVJPG_5_UTIL = 45  # Percent utilization of NVJPG 5. 0.0 - 100.0\nNVML_GPM_METRIC_NVJPG_6_UTIL = 46  # Percent utilization of NVJPG 6. 0.0 - 100.0\nNVML_GPM_METRIC_NVJPG_7_UTIL = 47  # Percent utilization of NVJPG 7. 0.0 - 100.0\nNVML_GPM_METRIC_NVOFA_0_UTIL = 50  # Percent utilization of NVOFA 0. 0.0 - 100.0\nNVML_GPM_METRIC_NVOFA_1_UTIL = 51  # Percent utilization of NVOFA 1. 0.0 - 100.0\nNVML_GPM_METRIC_NVLINK_TOTAL_RX_PER_SEC = (\n    60  # NvLink read bandwidth for all links in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_TOTAL_TX_PER_SEC = (\n    61  # NvLink write bandwidth for all links in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L0_RX_PER_SEC = 62  # NvLink read bandwidth for link 0 in MiB/sec\nNVML_GPM_METRIC_NVLINK_L0_TX_PER_SEC = (\n    63  # NvLink write bandwidth for link 0 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L1_RX_PER_SEC = 64  # NvLink read bandwidth for link 1 in MiB/sec\nNVML_GPM_METRIC_NVLINK_L1_TX_PER_SEC = (\n    65  # NvLink write bandwidth for link 1 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L2_RX_PER_SEC = 66  # NvLink read bandwidth for link 2 in MiB/sec\nNVML_GPM_METRIC_NVLINK_L2_TX_PER_SEC = (\n    67  # NvLink write bandwidth for link 2 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L3_RX_PER_SEC = 68  # NvLink read bandwidth for link 3 in MiB/sec\nNVML_GPM_METRIC_NVLINK_L3_TX_PER_SEC = (\n    69  # NvLink write bandwidth for link 3 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L4_RX_PER_SEC = 70  # NvLink read bandwidth for link 4 in MiB/sec\nNVML_GPM_METRIC_NVLINK_L4_TX_PER_SEC = (\n    71  # NvLink write bandwidth for link 4 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L5_RX_PER_SEC = 72  # NvLink read bandwidth for link 5 in MiB/sec\nNVML_GPM_METRIC_NVLINK_L5_TX_PER_SEC = (\n    73  # NvLink write bandwidth for link 5 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L6_RX_PER_SEC = 74  # NvLink read bandwidth for link 6 in MiB/sec\nNVML_GPM_METRIC_NVLINK_L6_TX_PER_SEC = (\n    75  # NvLink write bandwidth for link 6 in 
MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L7_RX_PER_SEC = 76  # NvLink read bandwidth for link 7 in MiB/sec\nNVML_GPM_METRIC_NVLINK_L7_TX_PER_SEC = (\n    77  # NvLink write bandwidth for link 7 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L8_RX_PER_SEC = 78  # NvLink read bandwidth for link 8 in MiB/sec\nNVML_GPM_METRIC_NVLINK_L8_TX_PER_SEC = (\n    79  # NvLink write bandwidth for link 8 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L9_RX_PER_SEC = 80  # NvLink read bandwidth for link 9 in MiB/sec\nNVML_GPM_METRIC_NVLINK_L9_TX_PER_SEC = (\n    81  # NvLink write bandwidth for link 9 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L10_RX_PER_SEC = (\n    82  # NvLink read bandwidth for link 10 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L10_TX_PER_SEC = (\n    83  # NvLink write bandwidth for link 10 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L11_RX_PER_SEC = (\n    84  # NvLink read bandwidth for link 11 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L11_TX_PER_SEC = (\n    85  # NvLink write bandwidth for link 11 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L12_RX_PER_SEC = (\n    86  # NvLink read bandwidth for link 12 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L12_TX_PER_SEC = (\n    87  # NvLink write bandwidth for link 12 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L13_RX_PER_SEC = (\n    88  # NvLink read bandwidth for link 13 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L13_TX_PER_SEC = (\n    89  # NvLink write bandwidth for link 13 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L14_RX_PER_SEC = (\n    90  # NvLink read bandwidth for link 14 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L14_TX_PER_SEC = (\n    91  # NvLink write bandwidth for link 14 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L15_RX_PER_SEC = (\n    92  # NvLink read bandwidth for link 15 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L15_TX_PER_SEC = (\n    93  # NvLink write bandwidth for link 15 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L16_RX_PER_SEC = (\n    94  # NvLink read bandwidth for link 16 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L16_TX_PER_SEC = (\n    95  # NvLink write bandwidth for 
link 16 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L17_RX_PER_SEC = (\n    96  # NvLink read bandwidth for link 17 in MiB/sec\n)\nNVML_GPM_METRIC_NVLINK_L17_TX_PER_SEC = (\n    97  # NvLink write bandwidth for link 17 in MiB/sec\n)\nNVML_GPM_METRIC_MAX = 98\n\n## Structs\n\n\nclass c_nvmlUnitInfo_t(_PrintableStructure):\n    _fields_ = [\n        (\"name\", c_char * 96),\n        (\"id\", c_char * 96),\n        (\"serial\", c_char * 96),\n        (\"firmwareVersion\", c_char * 96),\n    ]\n\n\nclass struct_c_nvmlGpmSample_t(Structure):\n    pass  # opaque handle\n\n\nc_nvmlGpmSample_t = POINTER(struct_c_nvmlGpmSample_t)\n\n\nclass c_metricInfo_t(Structure):\n    _fields_ = [\n        (\"shortName\", c_char_p),\n        (\"longName\", c_char_p),\n        (\"unit\", c_char_p),\n    ]\n\n\nclass c_nvmlGpmMetric_t(_PrintableStructure):\n    _fields_ = [\n        (\"metricId\", c_uint),\n        (\"nvmlReturn\", _nvmlReturn_t),\n        (\"value\", c_double),\n        (\"metricInfo\", c_metricInfo_t),\n    ]\n\n\nclass c_nvmlGpmMetricsGet_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"numMetrics\", c_uint),\n        (\"sample1\", c_nvmlGpmSample_t),\n        (\"sample2\", c_nvmlGpmSample_t),\n        (\"metrics\", c_nvmlGpmMetric_t * NVML_GPM_METRIC_MAX),\n    ]\n\n\nNVML_GPM_METRICS_GET_VERSION = 1\n\n\nclass c_nvmlGpmSupport_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"isSupportedDevice\", c_uint),\n    ]\n\n\nNVML_GPM_SUPPORT_VERSION = 1\n\n## Functions\n\n\ndef nvmlGpmMetricsGet(metricsGet):\n    fn = _nvmlGetFunctionPointer(\"nvmlGpmMetricsGet\")\n    ret = fn(byref(metricsGet))\n    _nvmlCheckReturn(ret)\n    return metricsGet\n\n\ndef nvmlGpmSampleFree(gpmSample):\n    fn = _nvmlGetFunctionPointer(\"nvmlGpmSampleFree\")\n    ret = fn(gpmSample)\n    _nvmlCheckReturn(ret)\n    return\n\n\ndef nvmlGpmSampleAlloc():\n    gpmSample = c_nvmlGpmSample_t()\n    fn = 
_nvmlGetFunctionPointer(\"nvmlGpmSampleAlloc\")\n    ret = fn(byref(gpmSample))\n    _nvmlCheckReturn(ret)\n    return gpmSample\n\n\ndef nvmlGpmSampleGet(device, gpmSample):\n    fn = _nvmlGetFunctionPointer(\"nvmlGpmSampleGet\")\n    ret = fn(device, gpmSample)\n    _nvmlCheckReturn(ret)\n    return gpmSample\n\n\ndef nvmlGpmMigSampleGet(device, gpuInstanceId, gpmSample):\n    fn = _nvmlGetFunctionPointer(\"nvmlGpmMigSampleGet\")\n    ret = fn(device, gpuInstanceId, gpmSample)\n    _nvmlCheckReturn(ret)\n    return gpmSample\n\n\ndef nvmlGpmQueryDeviceSupport(device):\n    gpmSupport = c_nvmlGpmSupport_t()\n    gpmSupport.version = NVML_GPM_SUPPORT_VERSION\n    fn = _nvmlGetFunctionPointer(\"nvmlGpmQueryDeviceSupport\")\n    ret = fn(device, byref(gpmSupport))\n    _nvmlCheckReturn(ret)\n    return gpmSupport\n\n\ndef nvmlGpmSetStreamingEnabled(device, state):\n    c_state = c_uint(state)\n    fn = _nvmlGetFunctionPointer(\"nvmlGpmSetStreamingEnabled\")\n    ret = fn(device, c_state)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlGpmQueryIfStreamingEnabled(device):\n    c_state = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlGpmQueryIfStreamingEnabled\")\n    ret = fn(device, byref(c_state))\n    _nvmlCheckReturn(ret)\n    return c_state.value\n\n\n# Low Power Structure and Function\n\nNVML_NVLINK_POWER_STATE_HIGH_SPEED = 0x0\nNVML_NVLINK_POWER_STATE_LOW = 0x1\n\nNVML_NVLINK_LOW_POWER_THRESHOLD_MIN = 0x1\nNVML_NVLINK_LOW_POWER_THRESHOLD_MAX = 0x1FFF\nNVML_NVLINK_LOW_POWER_THRESHOLD_RESET = 0xFFFFFFFF\nNVML_NVLINK_LOW_POWER_THRESHOLD_DEFAULT = NVML_NVLINK_LOW_POWER_THRESHOLD_RESET\n\n\nclass c_nvmlNvLinkPowerThres_t(Structure):\n    _fields_ = [\n        (\"lowPwrThreshold\", c_uint),\n    ]\n\n\ndef nvmlDeviceSetNvLinkDeviceLowPowerThreshold(device, l1threshold):\n    c_info = c_nvmlNvLinkPowerThres_t()\n    c_info.lowPwrThreshold = l1threshold\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetNvLinkDeviceLowPowerThreshold\")\n    ret = 
fn(device, byref(c_info))\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\nNVML_GPU_FABRIC_UUID_LEN = 16\n\n_nvmlGpuFabricState_t = c_uint\nNVML_GPU_FABRIC_STATE_NOT_SUPPORTED = 0\nNVML_GPU_FABRIC_STATE_NOT_STARTED = 1\nNVML_GPU_FABRIC_STATE_IN_PROGRESS = 2\nNVML_GPU_FABRIC_STATE_COMPLETED = 3\n\n\nclass c_nvmlGpuFabricInfo_t(_PrintableStructure):\n    _fields_ = [\n        (\"clusterUuid\", c_char * NVML_DEVICE_UUID_BUFFER_SIZE),\n        (\"status\", _nvmlReturn_t),\n        (\"cliqueId\", c_uint32),\n        (\"state\", _nvmlGpuFabricState_t),\n    ]\n\n\nNVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_NOT_SUPPORTED = 0\nNVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_TRUE = 1\nNVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_FALSE = 2\nNVML_GPU_FABRIC_HEALTH_MASK_SHIFT_DEGRADED_BW = 0\nNVML_GPU_FABRIC_HEALTH_MASK_WIDTH_DEGRADED_BW = 0x11\n\nNVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_NOT_SUPPORTED = 0\nNVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_TRUE = 1\nNVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_FALSE = 2\nNVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ROUTE_RECOVERY = 2\nNVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ROUTE_RECOVERY = 0x11\n\nNVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_NOT_SUPPORTED = 0\nNVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_TRUE = 1\nNVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_FALSE = 2\nNVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ROUTE_UNHEALTHY = 4\nNVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ROUTE_UNHEALTHY = 0x11\n\nNVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_NOT_SUPPORTED = 0\nNVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_TRUE = 1\nNVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_FALSE = 2\nNVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ACCESS_TIMEOUT_RECOVERY = 6\nNVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ACCESS_TIMEOUT_RECOVERY = 0x11\n\nnvmlGpuFabricInfo_v2 = 0x02000024\n\n\nclass c_nvmlGpuFabricInfoV_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"clusterUuid\", c_char * NVML_GPU_FABRIC_UUID_LEN),\n        (\"status\", _nvmlReturn_t),\n        
(\"cliqueId\", c_uint32),\n        (\"state\", _nvmlGpuFabricState_t),\n        (\"healthMask\", c_uint32),\n    ]\n\n    def __init__(self):\n        super(c_nvmlGpuFabricInfoV_t, self).__init__(version=nvmlGpuFabricInfo_v2)\n\n\ndef nvmlDeviceGetGpuFabricInfo(device, gpuFabricInfo):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGpuFabricInfo\")\n    ret = fn(device, gpuFabricInfo)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceGetGpuFabricInfoV(device, gpuFabricInfo):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetGpuFabricInfoV\")\n    ret = fn(device, gpuFabricInfo)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\n######################\n## Enums/defines\n#### NVML GPU NVLINK BW MODE\nNVML_GPU_NVLINK_BW_MODE_FULL = 0x0\nNVML_GPU_NVLINK_BW_MODE_OFF = 0x1\nNVML_GPU_NVLINK_BW_MODE_MIN = 0x2\nNVML_GPU_NVLINK_BW_MODE_HALF = 0x3\nNVML_GPU_NVLINK_BW_MODE_3QUARTER = 0x4\nNVML_GPU_NVLINK_BW_MODE_COUNT = 0x5\n\n\ndef nvmlSystemSetNvlinkBwMode(mode):\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemSetNvlinkBwMode\")\n    ret = fn(mode)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlSystemGetNvlinkBwMode():\n    mode = c_uint()\n    fn = _nvmlGetFunctionPointer(\"nvmlSystemGetNvlinkBwMode\")\n    ret = fn(byref(mode))\n    _nvmlCheckReturn(ret)\n    return mode.value\n\n\n_nvmlPowerScopeType_t = c_uint\nNVML_POWER_SCOPE_GPU = 0\nNVML_POWER_SCOPE_MODULE = 1\nNVML_POWER_SCOPE_MEMORY = 2\n\n\nclass c_nvmlPowerValue_v2_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"powerScope\", _nvmlPowerScopeType_t),\n        (\"powerValueMw\", c_uint),\n    ]\n    _fmt_ = {\"<default>\": \"%d B\"}\n\n\nnvmlPowerValue_v2 = 0x0200000C\n\n\ndef nvmlDeviceSetPowerManagementLimit_v2(\n    device, powerScope, powerLimit, version=nvmlPowerValue_v2\n):\n    c_powerScope = _nvmlPowerScopeType_t(powerScope)\n    c_powerValue = c_nvmlPowerValue_v2_t()\n    c_powerValue.version = c_uint(version)\n    
c_powerValue.powerScope = c_powerScope\n    c_powerValue.powerValueMw = c_uint(powerLimit)\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetPowerManagementLimit_v2\")\n    ret = fn(device, byref(c_powerValue))\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\nclass c_nvmlEccSramErrorStatus_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"aggregateUncParity\", c_ulonglong),\n        (\"aggregateUncSecDed\", c_ulonglong),\n        (\"aggregateCor\", c_ulonglong),\n        (\"volatileUncParity\", c_ulonglong),\n        (\"volatileUncSecDed\", c_ulonglong),\n        (\"volatileCor\", c_ulonglong),\n        (\"aggregateUncBucketL2\", c_ulonglong),\n        (\"aggregateUncBucketSm\", c_ulonglong),\n        (\"aggregateUncBucketPcie\", c_ulonglong),\n        (\"aggregateUncBucketMcu\", c_ulonglong),\n        (\"aggregateUncBucketOther\", c_ulonglong),\n        (\"bThresholdExceeded\", c_uint),\n    ]\n\n    def __init__(self):\n        super(c_nvmlEccSramErrorStatus_v1_t, self).__init__(\n            version=nvmlEccSramErrorStatus_v1\n        )\n\n\nnvmlEccSramErrorStatus_v1 = 0x1000068\n\n\ndef nvmlDeviceGetSramEccErrorStatus(device, status):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetSramEccErrorStatus\")\n    ret = fn(device, status)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\nNVML_DEV_CAP_EGM = 1 << 0\nnvmlDeviceCapabilities_v1 = 0x1000008\n\n\nclass c_nvmlDeviceCapabilities_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"capMask\", c_uint),\n    ]\n\n    def __init__(self):\n        super(c_nvmlDeviceCapabilities_v1_t, self).__init__(\n            version=nvmlDeviceCapabilities_v1\n        )\n\n\ndef nvmlDeviceGetCapabilities(device, caps):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetCapabilities\")\n    return fn(device, caps)\n\n\nclass c_nvmlPlatformInfo_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"ibGuid\", c_char * 16),\n        
(\"rackGuid\", c_char * 16),\n        (\"chassisPhysicalSlotNumber\", c_char),\n        (\"computeSlotIndex\", c_char),\n        (\"nodeIndex\", c_char),\n        (\"peerType\", c_char),\n        (\"moduleId\", c_char),\n    ]\n\n    def __init__(self):\n        super(c_nvmlPlatformInfo_v1_t, self).__init__(version=nvmlPlatformInfo_v1)\n\n\nnvmlPlatformInfo_v1 = 0x100002C\n\n\ndef nvmlDeviceGetPlatformInfo(device, platformInfo):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetPlatformInfo\")\n    ret = fn(device, platformInfo)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\nclass c_nvmlMask255_t(_PrintableStructure):\n    _fields_ = [\n        (\"mask\", c_uint * 8),\n    ]\n\n\nNVML_WORKLOAD_POWER_MAX_PROFILES = 255\nNVML_POWER_PROFILE_MAX_P = 0\nNVML_POWER_PROFILE_MAX_Q = 1\nNVML_POWER_PROFILE_COMPUTE = 2\nNVML_POWER_PROFILE_MEMORY_BOUND = 3\nNVML_POWER_PROFILE_NETWORK = 4\nNVML_POWER_PROFILE_BALANCED = 5\nNVML_POWER_PROFILE_LLM_INFERENCE = 6\nNVML_POWER_PROFILE_LLM_TRAINING = 7\nNVML_POWER_PROFILE_RBM = 8\nNVML_POWER_PROFILE_DCPCIE = 9\nNVML_POWER_PROFILE_HMMA_SPARSE = 10\nNVML_POWER_PROFILE_HMMA_DENSE = 11\nNVML_POWER_PROFILE_SYNC_BALANCED = 12\nNVML_POWER_PROFILE_HPC = 13\nNVML_POWER_PROFILE_MIG = 14\nNVML_POWER_PROFILE_MAX = 15\n\nnvmlWorkloadPowerProfileInfo_v1 = 0x100002C\n\n\nclass c_nvmlWorkloadPowerProfileInfo_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"profileId\", c_uint),\n        (\"priority\", c_uint),\n        (\"conflictingmask\", c_nvmlMask255_t),\n    ]\n\n    def __init__(self):\n        super(c_nvmlWorkloadPowerProfileInfo_v1_t, self).__init__(\n            version=nvmlWorkloadPowerProfileInfo_v1\n        )\n\n\nnvmlWorkloadPowerProfileProfilesInfo_v1 = 0x1002BF8\n\n\nclass c_nvmlWorkloadPowerProfileProfilesInfo_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"perfProfilesMask\", c_nvmlMask255_t),\n        (\n            \"perfProfile\",\n            
c_nvmlWorkloadPowerProfileInfo_v1_t * NVML_WORKLOAD_POWER_MAX_PROFILES,\n        ),\n    ]\n\n    def __init__(self):\n        super(c_nvmlWorkloadPowerProfileProfilesInfo_v1_t, self).__init__(\n            version=nvmlWorkloadPowerProfileProfilesInfo_v1\n        )\n\n\nnvmlWorkloadPowerProfileCurrentProfiles_v1 = 0x1000064\n\n\nclass c_nvmlWorkloadPowerProfileCurrentProfiles_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"perfProfilesMask\", c_nvmlMask255_t),\n        (\"requestedProfilesMask\", c_nvmlMask255_t),\n        (\"enforcedProfilesMask\", c_nvmlMask255_t),\n    ]\n\n    def __init__(self):\n        super(c_nvmlWorkloadPowerProfileCurrentProfiles_v1_t, self).__init__(\n            version=nvmlWorkloadPowerProfileCurrentProfiles_v1\n        )\n\n\nnvmlWorkloadPowerProfileRequestedProfiles_v1 = 0x1000024\n\n\nclass c_nvmlWorkloadPowerProfileRequestedProfiles_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"requestedProfilesMask\", c_nvmlMask255_t),\n    ]\n\n    def __init__(self):\n        super(c_nvmlWorkloadPowerProfileRequestedProfiles_v1_t, self).__init__(\n            version=nvmlWorkloadPowerProfileRequestedProfiles_v1\n        )\n\n\ndef nvmlDeviceWorkloadPowerProfileGetProfilesInfo(device, profilesInfo):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceWorkloadPowerProfileGetProfilesInfo\")\n    ret = fn(device, profilesInfo)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceWorkloadPowerProfileGetCurrentProfiles(device, currentProfiles):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceWorkloadPowerProfileGetCurrentProfiles\")\n    ret = fn(device, currentProfiles)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceWorkloadPowerProfileSetRequestedProfiles(device, requestedProfiles):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceWorkloadPowerProfileSetRequestedProfiles\")\n    ret = fn(device, requestedProfiles)\n    
_nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceWorkloadPowerProfileClearRequestedProfiles(device, requestedProfiles):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceWorkloadPowerProfileClearRequestedProfiles\")\n    ret = fn(device, requestedProfiles)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceGetNvlinkSupportedBwModes(device, supportedBwModes):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetNvlinkSupportedBwModes\")\n    ret = fn(device, supportedBwModes)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceGetNvlinkBwMode(device, getBwMode):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetNvlinkBwMode\")\n    ret = fn(device, getBwMode)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\ndef nvmlDeviceSetNvlinkBwMode(device, setBwMode):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetNvlinkBwMode\")\n    ret = fn(device, setBwMode)\n    _nvmlCheckReturn(ret)\n    return NVML_SUCCESS\n\n\nnvmlDramEncryptionInfo_v1 = 0x01000008\n\n\nclass c_nvmlDramEncryptionInfo_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"encryptionState\", _nvmlEnableState_t),\n    ]\n\n    def __init__(self):\n        super(c_nvmlDramEncryptionInfo_t, self).__init__(\n            version=nvmlDramEncryptionInfo_v1\n        )\n\n\ndef nvmlDeviceGetDramEncryptionMode(handle):\n    c_currState = c_nvmlDramEncryptionInfo_t()\n    c_pendingState = c_nvmlDramEncryptionInfo_t()\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceGetDramEncryptionMode\")\n    ret = fn(handle, byref(c_currState), byref(c_pendingState))\n    _nvmlCheckReturn(ret)\n    return [c_currState.encryptionState, c_pendingState.encryptionState]\n\n\n# added to API\ndef nvmlDeviceGetCurrentDramEncryptionMode(handle):\n    return nvmlDeviceGetDramEncryptionMode(handle)[0]\n\n\n# added to API\ndef nvmlDeviceGetPendingDramEncryptionMode(handle):\n    return nvmlDeviceGetDramEncryptionMode(handle)[1]\n\n\ndef 
nvmlDeviceSetDramEncryptionMode(handle, mode):\n    fn = _nvmlGetFunctionPointer(\"nvmlDeviceSetDramEncryptionMode\")\n    c_dramEncryptionMode = c_nvmlDramEncryptionInfo_t()\n    c_dramEncryptionMode.encryptionState = mode\n    ret = fn(handle, byref(c_dramEncryptionMode))\n    _nvmlCheckReturn(ret)\n    return None\n\n\n# Power Smoothing defines\nNVML_POWER_SMOOTHING_MAX_NUM_PROFILES = 5\nNVML_POWER_SMOOTHING_ADMIN_OVERRIDE_NOT_SET = 0xFFFFFFFF\nNVML_POWER_SMOOTHING_PROFILE_PARAM_PERCENT_TMP_FLOOR = 0\nNVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_UP_RATE = 1\nNVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_DOWN_RATE = 2\nNVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_DOWN_HYSTERESIS = 3\n\nnvmlPowerSmoothingState_v1 = 0x1000008\n\n\nclass c_nvmlPowerSmoothingState_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"state\", c_uint),\n    ]\n\n    def __init__(self):\n        super(c_nvmlPowerSmoothingState_v1_t, self).__init__(\n            version=nvmlPowerSmoothingState_v1\n        )\n\n\nnvmlPowerSmoothingProfile_v1 = 0x1000018\n\n\nclass c_nvmlPowerSmoothingProfile_v1_t(_PrintableStructure):\n    _fields_ = [\n        (\"version\", c_uint),\n        (\"profileId\", c_uint),\n        (\"paramId\", c_uint),\n        (\"value\", c_double),\n    ]\n\n    def __init__(self):\n        super(c_nvmlPowerSmoothingProfile_v1_t, self).__init__(\n            version=nvmlPowerSmoothingProfile_v1\n        )\n\n\ndef nvmlDevicePowerSmoothingActivatePresetProfile(device, profile):\n    fn = _nvmlGetFunctionPointer(\"nvmlDevicePowerSmoothingActivatePresetProfile\")\n    ret = fn(device, profile)\n    _nvmlCheckReturn(ret)\n\n\ndef nvmlDevicePowerSmoothingUpdatePresetProfileParam(device, profile):\n    fn = _nvmlGetFunctionPointer(\"nvmlDevicePowerSmoothingUpdatePresetProfileParam\")\n    ret = fn(device, profile)\n    _nvmlCheckReturn(ret)\n\n\ndef nvmlDevicePowerSmoothingSetState(device, state):\n    fn = 
_nvmlGetFunctionPointer(\"nvmlDevicePowerSmoothingSetState\")\n    ret = fn(device, state)\n    _nvmlCheckReturn(ret)\n"
  },
  {
    "path": "python/sglang/multimodal_gen/tools/convert_hf_to_fp8.py",
    "content": "# copied and adapted from Slime\n\"\"\"\nConvert HuggingFace safetensors model to FP8 format for efficient inference.\n\nExample usage:\n    # convert FLUX.1-dev transformer to FP8\n    python -m sglang.multimodal_gen.tools.convert_hf_to_fp8 \\\n        --model-dir /path/to/FLUX.1-dev/transformer \\\n        --save-dir /path/to/FLUX.1-dev/transformer-FP8 \\\n        --strategy block \\\n        --block-size 128 128\n\nOptions:\n    --model-dir MODEL_DIR\n                        path to the directory of the HF safetensors model (e.g., transformer subfolder)\n    --save-dir SAVE_DIR\n                        path to the directory to save the converted FP8 model\n    --strategy {block,channel,tensor}\n                        quantization strategy (default: block)\n    --block-size [BLOCK_SIZE ...]\n                        block size for block quantization, e.g., --block-size 128 128\n    --max-workers MAX_WORKERS\n                        number of worker threads for parallel processing (default: 1)\n\"\"\"\n\nimport argparse\nimport gc\nimport json\nimport os\nimport shutil\nimport threading\nfrom concurrent.futures import ThreadPoolExecutor\n\nimport safetensors\nimport safetensors.torch\nimport torch\nimport torch.nn.functional as F\nfrom tqdm import tqdm\n\nFP8_INFO = torch.finfo(torch.float8_e4m3fn)\nFP8_MAX, FP8_MIN = FP8_INFO.max, FP8_INFO.min\n\n\ndef ceildiv(a, b):\n    return -(-a // b)\n\n\ndef block_fp8(weight, block_size):\n\n    # per block quant\n    block_n, block_k = block_size[0], block_size[1]\n\n    shape_0, shape_1 = weight.shape\n\n    n_tiles = ceildiv(shape_0, block_n)\n    k_tiles = ceildiv(shape_1, block_k)\n\n    q_weight = F.pad(\n        weight,\n        (0, k_tiles * block_k - shape_1, 0, n_tiles * block_n - shape_0),\n        mode=\"constant\",\n        value=0.0,\n    )\n\n    qweight = q_weight.reshape(n_tiles, block_n, k_tiles, block_k)\n    block_max = torch.max(torch.abs(qweight), dim=1, keepdim=True)[0]\n    block_max 
= torch.max(block_max, dim=3, keepdim=True)[0]\n\n    scale = block_max.to(torch.float32) / FP8_MAX\n    qweight = (\n        (qweight / scale)\n        .clamp(min=FP8_MIN, max=FP8_MAX)\n        .reshape((n_tiles * block_n, k_tiles * block_k))\n        .to(torch.float8_e4m3fn)\n    )\n    qweight = qweight[:shape_0, :shape_1].clone().detach()\n    scale = scale.squeeze()\n\n    return qweight, scale\n\n\ndef channel_fp8(weight):\n    channel_max = torch.max(weight.abs(), dim=-1, keepdim=True)[0]\n    scale = channel_max.clamp(min=1e-12).to(torch.float32) / FP8_MAX\n    qweight = (weight / scale).clamp(min=FP8_MIN, max=FP8_MAX)\n    qweight = qweight.to(torch.float8_e4m3fn)\n    return qweight, scale\n\n\ndef tensor_fp8(weight):\n    scale = weight.abs().max().clamp(min=1e-12).to(torch.float32) / FP8_MAX\n    qweight = (weight / scale).clamp(min=FP8_MIN, max=FP8_MAX)\n    qweight = qweight.to(torch.float8_e4m3fn)\n    scale = scale.view(1)\n    return qweight, scale\n\n\ndef quant_fp8(weight, strategy, block_size=None):\n    if strategy == \"tensor\":\n        return tensor_fp8(weight)\n    elif strategy == \"channel\":\n        return channel_fp8(weight)\n    else:\n        return block_fp8(weight, block_size)\n\n\nclass ConversionResult:\n    def __init__(self):\n        self.lock = threading.Lock()\n        self.weight_map = {}\n        self.param_count = 0\n        self.modules_to_not_convert = []\n\n    def add_result(self, filename, q_weights, module_names):\n        with self.lock:\n            for k, v in q_weights.items():\n                self.weight_map[k] = filename\n                self.param_count += v.numel()\n            self.modules_to_not_convert.extend(module_names)\n\n\ndef process_file(\n    input_path, output_path, filename, strategy, block_size, result_collector\n):\n    if not filename.endswith(\".safetensors\"):\n        return\n\n    print(f\"Processing {filename}, memory usage: {torch.cuda.memory_allocated()}\")\n    weights = {}\n    
q_weights = {}\n\n    with safetensors.safe_open(\n        os.path.join(input_path, filename), framework=\"pt\", device=\"cuda\"\n    ) as f:\n        for k in f.keys():\n            weights[k] = f.get_tensor(k)\n\n    modules_to_not_convert = []\n    for key in weights.keys():\n        if (\n            \"weight\" in key\n            and \"layernorm\" not in key\n            and \"embed\" not in key\n            and \"router\" not in key\n            and \"mlp.gate.\" not in key\n            and \"norm\" not in key\n            and \"lm_head\" not in key\n            and \"eh_proj\" not in key\n            and \"net\" not in key\n            and \"txt_mod\" not in key\n            and \"img_mod\" not in key\n            and \"modulation\" not in key\n            and \"img_in\" not in key\n            and \"txt_in\" not in key\n            and \"time_in\" not in key\n            and \"vector_in\" not in key\n            and \"adaLN_modulation\" not in key\n            and \"all_final_layer\" not in key\n            and \"feed_forward\" not in key\n            and \"proj_out.weight\" != key\n        ):\n            qw, s = quant_fp8(weights[key], strategy, block_size)\n            q_weights[key] = qw\n            if block_size:\n                scale_name = key.replace(\".weight\", \".weight_scale_inv\")\n            else:\n                scale_name = key.replace(\".weight\", \".weight_scale\")\n            q_weights[scale_name] = s\n        else:\n            modules_to_not_convert.append(key.replace(\".weight\", \"\"))\n            q_weights[key] = weights[key]\n\n    safetensors.torch.save_file(\n        q_weights, os.path.join(output_path, filename), metadata={\"format\": \"pt\"}\n    )\n\n    result_collector.add_result(filename, q_weights, modules_to_not_convert)\n\n\ndef convert_fp8(input_path, output_path, strategy, block_size=None, max_workers=4):\n    input_path = os.path.abspath(input_path)\n    os.makedirs(output_path, exist_ok=True)\n\n    for filename 
in os.listdir(input_path):\n        if not filename.endswith(\".safetensors\") and not os.path.isdir(\n            os.path.join(input_path, filename)\n        ):\n            shutil.copyfile(\n                os.path.join(input_path, filename), os.path.join(output_path, filename)\n            )\n\n    safetensors_files = [\n        f for f in os.listdir(input_path) if f.endswith(\".safetensors\")\n    ]\n\n    result_collector = ConversionResult()\n\n    with ThreadPoolExecutor(max_workers=max_workers) as executor:\n        futures = []\n        for filename in safetensors_files:\n            future = executor.submit(\n                process_file,\n                input_path,\n                output_path,\n                filename,\n                strategy,\n                block_size,\n                result_collector,\n            )\n            futures.append(future)\n\n        for future in tqdm(futures, desc=\"Processing files\"):\n            future.result()\n\n    if strategy == \"block\" or strategy == \"tensor\":\n        quantization_config = {\n            \"activation_scheme\": \"dynamic\",\n            \"fmt\": \"e4m3\",\n            \"quant_method\": \"fp8\",\n        }\n        if block_size:\n            quantization_config[\"weight_block_size\"] = block_size\n        if len(result_collector.modules_to_not_convert) > 0:\n            quantization_config[\"modules_to_not_convert\"] = list(\n                set(result_collector.modules_to_not_convert)\n            )\n    else:\n        quant_group = {\n            \"group_0\": {\n                \"input_activations\": {\n                    \"actorder\": None,\n                    \"block_structure\": None,\n                    \"dynamic\": True,\n                    \"group_size\": None,\n                    \"num_bits\": 8,\n                    \"observer\": None,\n                    \"observer_kwargs\": {},\n                    \"strategy\": \"token\",\n                    \"symmetric\": True,\n  
                  \"type\": \"float\",\n                },\n                \"output_activations\": None,\n                \"targets\": [\"Linear\"],\n                \"weights\": {\n                    \"actorder\": None,\n                    \"block_structure\": None,\n                    \"dynamic\": False,\n                    \"group_size\": None,\n                    \"num_bits\": 8,\n                    \"observer\": \"minmax\",\n                    \"observer_kwargs\": {},\n                    \"strategy\": strategy,\n                    \"symmetric\": True,\n                    \"type\": \"float\",\n                },\n            },\n        }\n        quantization_config = {\n            \"config_groups\": quant_group,\n            \"format\": \"float-quantized\",\n            \"ignore\": list(set(result_collector.modules_to_not_convert)),\n            \"quant_method\": \"compressed-tensors\",\n            \"quantization_status\": \"compressed\",\n        }\n\n    config_path = os.path.join(input_path, \"config.json\")\n    if os.path.exists(config_path):\n        cfg = json.load(open(config_path))\n        cfg[\"quantization_config\"] = quantization_config\n        json.dump(cfg, open(os.path.join(output_path, \"config.json\"), \"w\"), indent=2)\n\n    index_dict = {\n        \"weight_map\": result_collector.weight_map,\n        \"metadata\": {\"total_size\": result_collector.param_count},\n    }\n    json.dump(\n        index_dict,\n        open(os.path.join(output_path, \"model.safetensors.index.json\"), \"w\"),\n        indent=2,\n    )\n\n    gc.collect()\n    torch.cuda.empty_cache()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--model-dir\",\n        type=str,\n        help=\"Path to the directory of the HF safetensors model.\",\n    )\n    parser.add_argument(\n        \"--save-dir\",\n        type=str,\n        help=\"Path to the directory to save the converted model.\",\n    )\n  
  parser.add_argument(\n        \"--strategy\", type=str, default=\"block\", choices=[\"block\", \"channel\", \"tensor\"]\n    )\n    parser.add_argument(\n        \"--block-size\",\n        type=int,\n        nargs=\"*\",\n        default=None,\n        help=\"e.g. --block-size 32 32\",\n    )\n    parser.add_argument(\n        \"--max-workers\",\n        type=int,\n        default=8,\n        help=\"Number of worker threads for parallel processing\",\n    )\n    args = parser.parse_args()\n\n    if args.strategy == \"block\" and (not args.block_size or len(args.block_size) != 2):\n        raise ValueError(\n            \"--strategy block requires two values for --block-size, \"\n            \"e.g. --block-size 128 128.\"\n        )\n\n    if not os.path.exists(args.save_dir):\n        print(f\"Creating directory {args.save_dir}\")\n        os.makedirs(args.save_dir)\n    elif not os.path.isdir(args.save_dir):\n        raise ValueError(\"The save_dir should be a directory.\")\n\n    convert_fp8(\n        args.model_dir, args.save_dir, args.strategy, args.block_size, args.max_workers\n    )\n"
  },
  {
    "path": "python/sglang/multimodal_gen/tools/wan_repack.py",
    "content": "### Based on https://github.com/huggingface/diffusers/blob/main/scripts/convert_wan_to_diffusers.py\r\n\r\nimport argparse\r\nimport json\r\nimport pathlib\r\nfrom typing import Any, Dict, Tuple\r\n\r\nfrom safetensors.torch import load_file, save_file\r\n\r\nTRANSFORMER_KEYS_RENAME_DICT = {\r\n    \"time_embedding.0\": \"condition_embedder.time_embedder.linear_1\",\r\n    \"time_embedding.2\": \"condition_embedder.time_embedder.linear_2\",\r\n    \"text_embedding.0\": \"condition_embedder.text_embedder.linear_1\",\r\n    \"text_embedding.2\": \"condition_embedder.text_embedder.linear_2\",\r\n    \"time_projection.1\": \"condition_embedder.time_proj\",\r\n    \"head.modulation\": \"scale_shift_table\",\r\n    \"head.head\": \"proj_out\",\r\n    \"modulation\": \"scale_shift_table\",\r\n    \"ffn.0\": \"ffn.net.0.proj\",\r\n    \"ffn.2\": \"ffn.net.2\",\r\n    # Hack to swap the layer names\r\n    # The original model calls the norms in following order: norm1, norm3, norm2\r\n    # We convert it to: norm1, norm2, norm3\r\n    \"norm2\": \"norm__placeholder\",\r\n    \"norm3\": \"norm2\",\r\n    \"norm__placeholder\": \"norm3\",\r\n    # For the I2V model\r\n    \"img_emb.proj.0\": \"condition_embedder.image_embedder.norm1\",\r\n    \"img_emb.proj.1\": \"condition_embedder.image_embedder.ff.net.0.proj\",\r\n    \"img_emb.proj.3\": \"condition_embedder.image_embedder.ff.net.2\",\r\n    \"img_emb.proj.4\": \"condition_embedder.image_embedder.norm2\",\r\n    # for the FLF2V model\r\n    \"img_emb.emb_pos\": \"condition_embedder.image_embedder.pos_embed\",\r\n    # Add attention component mappings\r\n    \"self_attn.q\": \"attn1.to_q\",\r\n    \"self_attn.k\": \"attn1.to_k\",\r\n    \"self_attn.v\": \"attn1.to_v\",\r\n    \"self_attn.o\": \"attn1.to_out.0\",\r\n    \"self_attn.norm_q\": \"attn1.norm_q\",\r\n    \"self_attn.norm_k\": \"attn1.norm_k\",\r\n    \"cross_attn.q\": \"attn2.to_q\",\r\n    \"cross_attn.k\": \"attn2.to_k\",\r\n    \"cross_attn.v\": 
\"attn2.to_v\",\r\n    \"cross_attn.o\": \"attn2.to_out.0\",\r\n    \"cross_attn.norm_q\": \"attn2.norm_q\",\r\n    \"cross_attn.norm_k\": \"attn2.norm_k\",\r\n    \"attn2.to_k_img\": \"attn2.add_k_proj\",\r\n    \"attn2.to_v_img\": \"attn2.add_v_proj\",\r\n    \"attn2.norm_k_img\": \"attn2.norm_added_k\",\r\n}\r\n\r\n\r\ndef get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:\r\n    if model_type == \"Wan-T2V-14B\":\r\n        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT\r\n    return RENAME_DICT\r\n\r\n\r\ndef update_dict_(dict: Dict[str, Any], old_key: str, new_key: str) -> None:\r\n    dict[new_key] = dict.pop(old_key)\r\n\r\n\r\ndef load_sharded_safetensors(path: pathlib.Path) -> Dict[str, Any]:\r\n    # `path` may end in a glob pattern such as \"*model*.safetensors\";\r\n    # merge every matching shard into a single state dict.\r\n    file_paths = sorted(path.parent.glob(path.name))\r\n    if not file_paths:\r\n        raise FileNotFoundError(f\"No safetensors files match {path}\")\r\n    state_dict = {}\r\n    for file_path in file_paths:\r\n        state_dict.update(load_file(str(file_path)))\r\n    return state_dict\r\n\r\n\r\ndef convert_transformer(model_type: str, model_dir: str, output_dir: str):\r\n    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)\r\n    RENAME_DICT = get_transformer_config(model_type)\r\n\r\n    original_state_dict = load_sharded_safetensors(\r\n        pathlib.Path(model_dir, \"*model*.safetensors\")\r\n    )\r\n    # open() does not expand wildcards, so resolve the description file via glob.\r\n    quant_config_paths = sorted(\r\n        pathlib.Path(model_dir).glob(\"*quant_model_description*.json\")\r\n    )\r\n    if not quant_config_paths:\r\n        raise FileNotFoundError(f\"No quant_model_description JSON found in {model_dir}\")\r\n    with open(quant_config_paths[0]) as f:\r\n        original_quant_config = json.load(f)\r\n\r\n    for key in list(original_state_dict.keys()):\r\n        new_key = key[:]\r\n        for replace_key, rename_key in RENAME_DICT.items():\r\n            new_key = new_key.replace(replace_key, rename_key)\r\n        update_dict_(original_state_dict, key, new_key)\r\n        update_dict_(original_quant_config, key, new_key)\r\n\r\n    save_file(\r\n        original_state_dict,\r\n        pathlib.Path(output_dir, \"diffusion_pytorch_model.safetensors\"),\r\n    )\r\n\r\n    with open(pathlib.Path(output_dir, \"quant_model_description.json\"), \"w\") as f:\r\n        json.dump(original_quant_config, f)\r\n\r\n\r\ndef get_args():\r\n    parser = argparse.ArgumentParser()\r\n    parser.add_argument(\"--input-path\", type=str, required=True)\r\n    parser.add_argument(\"--output-path\", type=str, required=True)\r\n    return parser.parse_args()\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    args = get_args()\r\n\r\n    convert_transformer(\r\n        \"Wan-T2V-14B\",\r\n        model_dir=pathlib.Path(args.input_path, \"high_noise_model\"),\r\n        output_dir=pathlib.Path(args.output_path, \"transformer\"),\r\n    )\r\n    convert_transformer(\r\n        \"Wan-T2V-14B\",\r\n        model_dir=pathlib.Path(args.input_path, \"low_noise_model\"),\r\n        output_dir=pathlib.Path(args.output_path, \"transformer_2\"),\r\n    )\r\n"
  },
  {
    "path": "python/sglang/multimodal_gen/utils.py",
    "content": "# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo\n\n# SPDX-License-Identifier: Apache-2.0\n# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/utils.py\n\nimport argparse\nimport ctypes\nimport importlib\nimport importlib.util\nimport inspect\nimport math\nimport os\nimport signal\nimport sys\nimport threading\nimport traceback\nfrom collections.abc import Callable\nfrom dataclasses import dataclass, fields, is_dataclass\nfrom functools import lru_cache, partial, wraps\nfrom typing import Any, TypeVar, cast\n\nimport cloudpickle\nimport torch\nimport yaml\nfrom torch.distributed.fsdp import MixedPrecisionPolicy\n\nimport sglang.multimodal_gen.envs as envs\nfrom sglang.multimodal_gen.runtime.utils.logging_utils import (\n    SortedHelpFormatter,\n    init_logger,\n)\n\nlogger = init_logger(__name__)\n\nT = TypeVar(\"T\")\n\n\ndef _expand_path_value(field_name: str, value: Any) -> Any:\n    eu = os.path.expanduser\n    if field_name.endswith(\"_path\") and isinstance(value, str):\n        return eu(value)\n    if field_name.endswith(\"_path\") and isinstance(value, list):\n        return [eu(x) if isinstance(x, str) else x for x in value]\n    if field_name.endswith(\"_paths\") and isinstance(value, dict):\n        return {k: eu(p) if isinstance(p, str) else p for k, p in value.items()}\n    return value\n\n\ndef expand_path_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:\n    return {key: _expand_path_value(key, value) for key, value in kwargs.items()}\n\n\ndef expand_path_fields(obj) -> None:\n    \"\"\"In-place expanduser on all dataclass fields whose name ends with '_path' or '_paths'.\"\"\"\n    for f in fields(obj):\n        setattr(obj, f.name, _expand_path_value(f.name, getattr(obj, f.name)))\n\n\n# TODO(will): used to convert server_args.precision to torch.dtype. 
Find a\n# cleaner way to do this.\nPRECISION_TO_TYPE = {\n    \"fp32\": torch.float32,\n    \"fp16\": torch.float16,\n    \"bf16\": torch.bfloat16,\n}\n\nSTR_BACKEND_ENV_VAR: str = \"SGLANG_DIFFUSION_ATTENTION_BACKEND\"\nSTR_ATTN_CONFIG_ENV_VAR: str = \"SGLANG_DIFFUSION_ATTENTION_CONFIG\"\n\n\ndef find_nccl_library() -> str:\n    \"\"\"\n    We either use the library file specified by the\n    `SGLANG_DIFFUSION_NCCL_SO_PATH` environment variable, or we fall back to\n    the library shipped with PyTorch.\n    After importing `torch`, `libnccl.so.2`, `librccl.so.1` or `libmccl.so.2`\n    can be found by `ctypes` automatically.\n    \"\"\"\n    so_file = envs.SGLANG_DIFFUSION_NCCL_SO_PATH\n\n    # resolve which nccl library file to use\n    if so_file:\n        logger.info(\n            \"Found nccl from environment variable SGLANG_DIFFUSION_NCCL_SO_PATH=%s\",\n            so_file,\n        )\n    else:\n        if torch.version.cuda is not None:\n            so_file = \"libnccl.so.2\"\n        elif torch.version.hip is not None:\n            so_file = \"librccl.so.1\"\n        elif hasattr(torch.version, \"musa\") and torch.version.musa is not None:\n            so_file = \"libmccl.so.2\"\n        else:\n            raise ValueError(\"NCCL only supports CUDA, ROCm and MUSA backends.\")\n        logger.info(\"Found nccl from library %s\", so_file)\n    return str(so_file)\n\n\nprev_set_stream = torch.cuda.set_stream\n\n_current_stream = None\n\n\ndef _patched_set_stream(stream: torch.cuda.Stream | None) -> None:\n    global _current_stream\n    _current_stream = stream\n    if stream is not None:\n        prev_set_stream(stream)\n\n\ntorch.cuda.set_stream = _patched_set_stream\n\n\ndef current_stream() -> torch.cuda.Stream | None:\n    \"\"\"\n    Replace `torch.cuda.current_stream()` with `sglang.multimodal_gen.utils.current_stream()`.\n    It turns out that `torch.cuda.current_stream()` is quite expensive,\n    as it will construct a new stream object at each call.\n    Here we patch 
`torch.cuda.set_stream` to keep track of the current stream\n    directly, so that we can avoid calling `torch.cuda.current_stream()`.\n\n    the underlying hypothesis is that we do not call `torch._C._cuda_setStream`\n    from C/C++ code.\n    \"\"\"\n    from sglang.multimodal_gen.runtime.platforms import current_platform\n\n    # For non-CUDA platforms, return None\n    if not current_platform.is_cuda_alike():\n        return None\n\n    global _current_stream\n    if _current_stream is None:\n        # when this function is called before any stream is set,\n        # we return the default stream.\n        # On ROCm using the default 0 stream in combination with RCCL\n        # is hurting performance. Therefore creating a dedicated stream\n        # per process\n        _current_stream = (\n            torch.cuda.Stream()\n            if current_platform.is_rocm()\n            else torch.cuda.current_stream()\n        )\n    return _current_stream\n\n\nclass StoreBoolean(argparse.Action):\n\n    def __init__(self, option_strings, dest, default=False, required=False, help=None):\n        super().__init__(\n            option_strings=option_strings,\n            dest=dest,\n            nargs=\"?\",\n            const=True,\n            default=default,\n            required=required,\n            help=help,\n        )\n\n    def __call__(self, parser, namespace, values, option_string=None):\n        if values is None:\n            setattr(namespace, self.dest, True)\n        elif isinstance(values, str):\n            if values.lower() == \"true\":\n                setattr(namespace, self.dest, True)\n            elif values.lower() == \"false\":\n                setattr(namespace, self.dest, False)\n            else:\n                raise ValueError(\n                    f\"Invalid boolean value: {values}. 
\" \"Expected 'true' or 'false'.\"\n                )\n        else:\n            setattr(namespace, self.dest, bool(values))\n\n\nclass FlexibleArgumentParser(argparse.ArgumentParser):\n    \"\"\"ArgumentParser that allows both underscore and dash in names.\"\"\"\n\n    def __init__(self, *args, **kwargs) -> None:\n        # Set the default 'formatter_class' to SortedHelpFormatter\n        if \"formatter_class\" not in kwargs:\n            kwargs[\"formatter_class\"] = SortedHelpFormatter\n        super().__init__(*args, **kwargs)\n\n    def parse_args(  # type: ignore[override]\n        self, args=None, namespace=None\n    ) -> argparse.Namespace:\n        if args is None:\n            args = sys.argv[1:]\n\n        if any(arg.startswith(\"--config\") for arg in args):\n            args = self._pull_args_from_config(args)\n\n        # Convert underscores to dashes and vice versa in argument names\n        processed_args = []\n        for arg in args:\n            if arg.startswith(\"--\"):\n                if \"=\" in arg:\n                    key, value = arg.split(\"=\", 1)\n                    key = \"--\" + key[len(\"--\") :].replace(\"_\", \"-\")\n                    processed_args.append(f\"{key}={value}\")\n                else:\n                    processed_args.append(\"--\" + arg[len(\"--\") :].replace(\"_\", \"-\"))\n            elif arg.startswith(\"-O\") and arg != \"-O\" and len(arg) == 2:\n                # allow -O flag to be used without space, e.g. 
-O3\n                processed_args.append(\"-O\")\n                processed_args.append(arg[2:])\n            else:\n                processed_args.append(arg)\n\n        namespace = super().parse_args(processed_args, namespace)\n\n        # Track which arguments were explicitly provided\n        namespace._provided = set()\n\n        i = 0\n        while i < len(args):\n            arg = args[i]\n            if arg.startswith(\"--\"):\n                # Handle --key=value format\n                if \"=\" in arg:\n                    key = arg.split(\"=\")[0][2:].replace(\"-\", \"_\")\n                    namespace._provided.add(key)\n                    i += 1\n                # Handle --key value format\n                else:\n                    key = arg[2:].replace(\"-\", \"_\")\n                    namespace._provided.add(key)\n                    # Skip the value if there is one\n                    if i + 1 < len(args) and not args[i + 1].startswith(\"-\"):\n                        i += 2\n                    else:\n                        i += 1\n            else:\n                i += 1\n\n        return namespace  # type: ignore[no-any-return]\n\n    def _pull_args_from_config(self, args: list[str]) -> list[str]:\n        \"\"\"Method to pull arguments specified in the config file\n        into the command-line args variable.\n\n        The arguments in config file will be inserted between\n        the argument list.\n\n        example:\n        ```yaml\n            port: 12323\n            tensor-parallel-size: 4\n        ```\n        ```python\n        $: vllm {serve,chat,complete} \"facebook/opt-12B\" \\\n            --config config.yaml -tp 2\n        $: args = [\n            \"serve,chat,complete\",\n            \"facebook/opt-12B\",\n            '--config', 'config.yaml',\n            '-tp', '2'\n        ]\n        $: args = [\n            \"serve,chat,complete\",\n            \"facebook/opt-12B\",\n            '--port', '12323',\n            
'--tensor-parallel-size', '4',\n            '-tp', '2'\n            ]\n        ```\n\n        Please note how the config args are inserted after the sub command.\n        This way the order of priorities is maintained when these args are\n        parsed by super().\n        \"\"\"\n        index = -1\n        config_arg = None\n        for i, arg in enumerate(args):\n            if arg.startswith(\"--config\"):\n                if index != -1:\n                    raise ValueError(\"More than one config file specified!\")\n                index = i\n                config_arg = arg\n\n        if config_arg is None:\n            return args\n        args_before_config = args[:index]\n        if \"=\" in config_arg:\n            file_path = config_arg.split(\"=\", 1)[1]\n            args_after_config = args[index + 1 :]\n        else:\n            if index == len(args) - 1:\n                raise ValueError(\n                    \"No config file specified! \"\n                    \"Please check your command-line arguments.\"\n                )\n            file_path = args[index + 1]\n            args_after_config = args[index + 2 :]\n\n        config_args = self._load_config_file(file_path)\n\n        # 0th index is for {serve,chat,complete}\n        # followed by model_tag (only for serve)\n        # followed by config args\n        # followed by rest of cli args.\n        # maintaining this order will enforce the precedence\n        # of cli > config > defaults\n        if args[0] == \"serve\":\n            if index == 1:\n                raise ValueError(\n                    \"No model_tag specified! 
Please check your command-line\"\n                    \" arguments.\"\n                )\n            command = args_before_config[0]\n            model_tag = args_before_config[1]\n            other_args_before = args_before_config[2:]\n            args = (\n                [command, model_tag]\n                + config_args\n                + other_args_before\n                + args_after_config\n            )\n        else:\n            command = args_before_config[0]\n            other_args_before = args_before_config[1:]\n            args = [command] + config_args + other_args_before + args_after_config\n\n        return args\n\n    def _load_config_file(self, file_path: str) -> list[str]:\n        \"\"\"Loads a yaml file and returns the key value pairs as a\n        flattened list with argparse-like pattern\n        ```yaml\n            port: 12323\n            tensor-parallel-size: 4\n            vae_config:\n                load_encoder: false\n                load_decoder: true\n        ```\n        returns:\n            processed_args: list[str] = [\n                '--port', '12323',\n                '--tensor-parallel-size', '4',\n                '--vae-config.load-encoder', 'false',\n                '--vae-config.load-decoder', 'true'\n            ]\n        \"\"\"\n\n        extension: str = file_path.split(\".\")[-1]\n        if extension not in (\"yaml\", \"yml\", \"json\"):\n            raise ValueError(\n                f\"Config file must be of a yaml/yml/json type. {extension} supplied\"\n            )\n\n        processed_args: list[str] = []\n\n        config: dict[str, Any] = {}\n        try:\n            with open(file_path) as config_file:\n                config = yaml.safe_load(config_file)\n        except Exception as ex:\n            logger.error(\n                \"Unable to read the config file at %s. 
\\\n                Make sure path is correct\",\n                file_path,\n            )\n            raise ex\n\n        store_boolean_arguments = [\n            action.dest for action in self._actions if isinstance(action, StoreBoolean)\n        ]\n\n        def process_dict(prefix: str, d: dict[str, Any]):\n            for key, value in d.items():\n                full_key = f\"{prefix}.{key}\" if prefix else key\n\n                if isinstance(value, bool) and full_key not in store_boolean_arguments:\n                    if value:\n                        processed_args.append(\"--\" + full_key)\n                    else:\n                        processed_args.append(\"--\" + full_key)\n                        processed_args.append(\"false\")\n                elif isinstance(value, list):\n                    processed_args.append(\"--\" + full_key)\n                    for item in value:\n                        processed_args.append(str(item))\n                elif isinstance(value, dict):\n                    process_dict(full_key, value)\n                else:\n                    processed_args.append(\"--\" + full_key)\n                    processed_args.append(str(value))\n\n        process_dict(\"\", config)\n\n        return processed_args\n\n\ndef warn_for_unimplemented_methods(cls: type[T]) -> type[T]:\n    \"\"\"\n    A replacement for `abc.ABC`.\n    When we use `abc.ABC`, subclasses will fail to instantiate\n    if they do not implement all abstract methods.\n    Here, we only require `raise NotImplementedError` in the\n    base class, and log a warning if the method is not implemented\n    in the subclass.\n    \"\"\"\n\n    original_init = cls.__init__\n\n    def find_unimplemented_methods(self: object):\n        unimplemented_methods = []\n        for attr_name in dir(self):\n            # bypass inner method\n            if attr_name.startswith(\"_\"):\n                continue\n\n            try:\n                attr = getattr(self, 
attr_name)\n                # skip non-callable attributes; bound methods expose __func__\n                if not callable(attr):\n                    continue\n                attr_func = attr.__func__\n            except AttributeError:\n                continue\n            src = inspect.getsource(attr_func)\n            if \"NotImplementedError\" in src:\n                unimplemented_methods.append(attr_name)\n        if unimplemented_methods:\n            method_names = \",\".join(unimplemented_methods)\n            msg = f\"Methods {method_names} not implemented in {self}\"\n            logger.warning(msg)\n\n    @wraps(original_init)\n    def wrapped_init(self, *args, **kwargs) -> None:\n        original_init(self, *args, **kwargs)\n        find_unimplemented_methods(self)\n\n    type.__setattr__(cls, \"__init__\", wrapped_init)\n    return cls\n\n\ndef align_to(value: int, alignment: int) -> int:\n    \"\"\"align height, width according to alignment\n\n    Args:\n        value (int): height or width\n        alignment (int): target alignment factor\n\n    Returns:\n        int: the aligned value\n    \"\"\"\n    return int(math.ceil(value / alignment) * alignment)\n\n\ndef resolve_obj_by_qualname(qualname: str) -> Any:\n    \"\"\"\n    Resolve an object by its fully qualified name.\n    \"\"\"\n    module_name, obj_name = qualname.rsplit(\".\", 1)\n    module = importlib.import_module(module_name)\n    return getattr(module, obj_name)\n\n\n# From vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/utils.py\ndef import_pynvml():\n    \"\"\"\n    Historical comments:\n\n    libnvml.so is the library behind nvidia-smi, and\n    pynvml is a Python wrapper around it. We use it to get GPU\n    status without initializing CUDA context in the current process.\n    Historically, there are two packages that provide pynvml:\n    - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official\n        wrapper. 
It is a dependency of sglang-diffusion, and is installed when users\n        install sglang-diffusion. It provides a Python module named `pynvml`.\n    - `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.\n        Prior to version 12.0, it also provides a Python module `pynvml`,\n        and therefore conflicts with the official one which is a standalone Python file.\n        This causes errors when both of them are installed.\n        Starting from version 12.0, it migrates to a new module\n        named `pynvml_utils` to avoid the conflict.\n    It is so confusing that many packages in the community use the\n    unofficial one by mistake, and we have to handle this case.\n    For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial\n    one, and it will cause errors, see the issue\n    https://github.com/vllm-project/vllm/issues/12847 for example.\n    After all the troubles, we decided to copy the official `pynvml`\n    module to our codebase, and use it directly.\n    \"\"\"\n    import sglang.multimodal_gen.third_party.pynvml as pynvml\n\n    return pynvml\n\n\ndef update_environment_variables(envs: dict[str, str]):\n    for k, v in envs.items():\n        if k in os.environ and os.environ[k] != v:\n            logger.warning(\n                \"Overwriting environment variable %s \" \"from '%s' to '%s'\",\n                k,\n                os.environ[k],\n                v,\n            )\n        os.environ[k] = v\n\n\ndef run_method(\n    obj: Any, method: str | bytes | Callable, args: tuple[Any], kwargs: dict[str, Any]\n) -> Any:\n    \"\"\"\n    Run a method of an object with the given arguments and keyword arguments.\n    If the method is a string, it will be converted to a method using getattr.\n    If the method is serialized bytes, it will be deserialized using\n    cloudpickle.\n    If the method is a callable, it will be called directly.\n    \"\"\"\n    if isinstance(method, bytes):\n        func = 
partial(cloudpickle.loads(method), obj)\n    elif isinstance(method, str):\n        try:\n            func = getattr(obj, method)\n        except AttributeError:\n            raise NotImplementedError(\n                f\"Method {method!r} is not\" \" implemented.\"\n            ) from None\n    else:\n        func = partial(method, obj)  # type: ignore\n    return func(*args, **kwargs)\n\n\ndef shallow_asdict(obj) -> dict[str, Any]:\n    if not is_dataclass(obj):\n        raise TypeError(\"Expected dataclass instance\")\n    return {f.name: getattr(obj, f.name) for f in fields(obj)}\n\n\n# TODO: validate that this is fine\ndef kill_itself_when_parent_died() -> None:\n    # if sys.platform == \"linux\":\n    # sigkill this process when parent worker manager dies\n    PR_SET_PDEATHSIG = 1\n    import platform\n\n    if platform.system() == \"Linux\":\n        libc = ctypes.CDLL(\"libc.so.6\")\n        libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)\n    # elif platform.system() == \"Darwin\":\n    #     libc = ctypes.CDLL(\"libc.dylib\")\n    #     logger.warning(\"kill_itself_when_parent_died is only supported in linux.\")\n    else:\n        logger.warning(\"kill_itself_when_parent_died is only supported in linux.\")\n\n\ndef get_exception_traceback() -> str:\n    etype, value, tb = sys.exc_info()\n    err_str = \"\".join(traceback.format_exception(etype, value, tb))\n    return err_str\n\n\nclass TypeBasedDispatcher:\n\n    def __init__(self, mapping: list[tuple[type, Callable]]):\n        self._mapping = mapping\n\n    def __call__(self, obj: Any):\n        for ty, fn in self._mapping:\n            if isinstance(obj, ty):\n                return fn(obj)\n        raise ValueError(f\"Invalid object: {obj}\")\n\n\n@dataclass\nclass MixedPrecisionState:\n    param_dtype: torch.dtype | None = None\n    reduce_dtype: torch.dtype | None = None\n    output_dtype: torch.dtype | None = None\n    compute_dtype: torch.dtype | None = None\n    mp_policy: MixedPrecisionPolicy | 
None = None\n\n\n# Thread-local storage for mixed precision state\n_mixed_precision_state = threading.local()\n\n\ndef get_mixed_precision_state() -> MixedPrecisionState:\n    \"\"\"Get the current mixed precision state.\"\"\"\n    if not hasattr(_mixed_precision_state, \"state\"):\n        raise ValueError(\"Mixed precision state not set\")\n    return cast(MixedPrecisionState, _mixed_precision_state.state)\n\n\ndef set_mixed_precision_policy(\n    param_dtype: torch.dtype,\n    reduce_dtype: torch.dtype,\n    output_dtype: torch.dtype | None = None,\n    mp_policy: MixedPrecisionPolicy | None = None,\n):\n    \"\"\"Set mixed precision policy globally.\n\n    Args:\n        param_dtype: Parameter dtype used for training\n        reduce_dtype: Reduction dtype used for gradients\n        output_dtype: Optional output dtype\n    \"\"\"\n    state = MixedPrecisionState(\n        param_dtype=param_dtype,\n        reduce_dtype=reduce_dtype,\n        output_dtype=output_dtype,\n        mp_policy=mp_policy,\n    )\n    _mixed_precision_state.state = state\n\n\ndef get_compute_dtype() -> torch.dtype:\n    \"\"\"Get the current compute dtype from mixed precision policy.\"\"\"\n    if not hasattr(_mixed_precision_state, \"state\"):\n        return torch.get_default_dtype()\n    else:\n        state = get_mixed_precision_state()\n        return state.param_dtype\n\n\ndef dict_to_3d_list(\n    mask_strategy: dict[str, Any] | None = None,\n    t_max: int | None = None,\n    l_max: int | None = None,\n    h_max: int | None = None,\n) -> list[list[list[torch.Tensor | None]]]:\n    \"\"\"\n    Convert a dictionary of mask indices to a 3D list of tensors.\n    Args:\n        mask_strategy: keys are \"t_l_h\", values are torch.Tensor masks.\n        t_max, l_max, h_max: if provided (all three), force the output shape to (t_max, l_max, h_max).\n                            If all three are None, infer shape from the data.\n    \"\"\"\n    # Case 1: no data, but fixed shape requested\n 
   if mask_strategy is None:\n        assert (\n            t_max is not None and l_max is not None and h_max is not None\n        ), \"If mask_strategy is None, you must provide t_max, l_max, and h_max\"\n        return [\n            [[None for _ in range(h_max)] for _ in range(l_max)] for _ in range(t_max)\n        ]\n\n    # Parse all keys into integer tuples\n    indices = [tuple(map(int, key.split(\"_\"))) for key in mask_strategy]\n\n    # Decide on dimensions\n    if t_max is None and l_max is None and h_max is None:\n        # fully dynamic: infer from data\n        max_timesteps_idx = max(t for t, _, _ in indices) + 1\n        max_layer_idx = max(l for _, l, _ in indices) + 1  # noqa: E741\n        max_head_idx = max(h for _, _, h in indices) + 1\n    else:\n        # require all three to be provided\n        assert t_max is not None and l_max is not None and h_max is not None, (\n            \"Either supply none of (t_max, l_max, h_max) to infer dimensions, \"\n            \"or supply all three to fix the shape.\"\n        )\n        max_timesteps_idx = t_max\n        max_layer_idx = l_max\n        max_head_idx = h_max\n\n    # Preallocate\n    result = [\n        [[None for _ in range(max_head_idx)] for _ in range(max_layer_idx)]\n        for _ in range(max_timesteps_idx)\n    ]\n\n    # Fill in, skipping any out-of-bounds entries\n    for key, value in mask_strategy.items():\n        t, l, h = map(int, key.split(\"_\"))  # noqa: E741\n        if (\n            0 <= t < max_timesteps_idx\n            and 0 <= l < max_layer_idx\n            and 0 <= h < max_head_idx\n        ):\n            result[t][l][h] = value\n        # else: silently ignore any key that doesn't fit\n\n    return result\n\n\ndef set_random_seed(seed: int) -> None:\n    from sglang.multimodal_gen.runtime.platforms import current_platform\n\n    current_platform.seed_everything(seed)\n\n\n@lru_cache(maxsize=1)\ndef is_vsa_available() -> bool:\n    return 
importlib.util.find_spec(\"vsa\") is not None\n\n\n@lru_cache(maxsize=1)\ndef is_vmoba_available() -> bool:\n    if importlib.util.find_spec(\"kernel.csrc.attn.vmoba_attn.vmoba\") is None:\n        return False\n    try:\n        import flash_attn\n        from packaging.version import Version\n\n        # compare parsed versions; plain string comparison misorders e.g. \"2.10.0\" < \"2.7.4\"\n        return Version(flash_attn.__version__) >= Version(\"2.7.4\")\n    except Exception:\n        return False\n\n\n# adapted from: https://github.com/Wan-Video/Wan2.2/blob/main/wan/utils/utils.py\ndef masks_like(\n    tensors, zero=False, generator=None, p=0.2\n) -> tuple[list[torch.Tensor], list[torch.Tensor]]:\n    \"\"\"\n    Generate binary masks for Text-to-Image-to-Video (TI2V) tasks.\n\n    Creates masks to control which frames should be preserved vs replaced.\n    Primarily used to fix the first frame to the input image while generating other frames.\n\n    Args:\n        tensors: List of tensors with shape [C, T, H, W]\n        zero: If True, set first frame (dim 1, index 0) to zero. Default: False\n        generator: Optional random generator for stochastic masking\n        p: Probability of applying special noise when generator is provided. 
Default: 0.2\n\n    Returns:\n        Tuple of two lists of tensors:\n        - When zero=False: Both lists contain all-ones tensors\n        - When zero=True (no generator): First frame set to 0, others to 1\n        - When zero=True (with generator): First frame set to small random values with probability p\n\n    Example:\n        >>> latent = torch.randn(48, 69, 96, 160)  # [C, T, H, W]\n        >>> _, mask = masks_like([latent], zero=True)\n        >>> # mask[0][:, 0] == 0 (first frame)\n        >>> # mask[0][:, 1:] == 1 (other frames)\n        >>> blended = (1.0 - mask[0]) * image + mask[0] * latent\n        >>> # Result: first frame = image, other frames = latent\n    \"\"\"\n    assert isinstance(tensors, list)\n    out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensors]\n\n    out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensors]\n\n    if zero:\n        if generator is not None:\n            for u, v in zip(out1, out2, strict=False):\n                random_num = torch.rand(\n                    1, generator=generator, device=generator.device\n                ).item()\n                if random_num < p:\n                    u[:, 0] = (\n                        torch.normal(\n                            mean=-3.5,\n                            std=0.5,\n                            size=(1,),\n                            device=u.device,\n                            generator=generator,\n                        )\n                        .expand_as(u[:, 0])\n                        .exp()\n                    )\n                    v[:, 0] = torch.zeros_like(v[:, 0])\n                else:\n                    u[:, 0] = u[:, 0]\n                    v[:, 0] = v[:, 0]\n\n        else:\n            for u, v in zip(out1, out2, strict=False):\n                u[:, 0] = torch.zeros_like(u[:, 0])\n                v[:, 0] = torch.zeros_like(v[:, 0])\n\n    return out1, out2\n\n\n# adapted from: 
https://github.com/Wan-Video/Wan2.2/blob/main/wan/utils/utils.py\ndef best_output_size(w, h, dw, dh, expected_area):\n    # float output size\n    ratio = w / h\n    ow = (expected_area * ratio) ** 0.5\n    oh = expected_area / ow\n\n    # process width first\n    ow1 = int(ow // dw * dw)\n    oh1 = int(expected_area / ow1 // dh * dh)\n    assert ow1 % dw == 0 and oh1 % dh == 0 and ow1 * oh1 <= expected_area\n    ratio1 = ow1 / oh1\n\n    # process height first\n    oh2 = int(oh // dh * dh)\n    ow2 = int(expected_area / oh2 // dw * dw)\n    assert oh2 % dh == 0 and ow2 % dw == 0 and ow2 * oh2 <= expected_area\n    ratio2 = ow2 / oh2\n\n    # compare ratios\n    if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, ratio2 / ratio):\n        return ow1, oh1\n    else:\n        return ow2, oh2\n\n\ndef calculate_dimensions(target_area, ratio):\n    width = math.sqrt(target_area * ratio)\n    height = width / ratio\n\n    width = round(width / 32) * 32\n    height = round(height / 32) * 32\n\n    return width, height, None\n"
  },
  {
    "path": "python/sglang/profiler.py",
    "content": "\"\"\"\nRun live profiling.\n\nUsage:\npython3 -m sglang.profiler\n\"\"\"\n\nimport argparse\nimport json\nimport os\nimport time\nfrom argparse import ArgumentParser\nfrom pathlib import Path\nfrom typing import List, Optional\n\nimport requests\n\nPROFILER_DIR = os.getenv(\"SGLANG_TORCH_PROFILER_DIR\", \"/tmp\")\n\n\ndef run_profile(\n    url: Optional[str],\n    num_steps: int,\n    activities: List[str],\n    output_dir: Optional[str] = None,\n    profile_by_stage: bool = False,\n    merge_profiles: bool = False,\n    profile_prefix: Optional[str] = None,\n    start_step: Optional[int] = None,\n) -> str:\n    if output_dir is None:\n        output_dir = PROFILER_DIR\n\n    output_dir = Path(os.path.abspath(os.path.normpath(output_dir))) / str(time.time())\n    output_dir.mkdir(exist_ok=True, parents=True)\n\n    print(f\"Dump profiling traces to {output_dir}\")\n    print(\n        f\"Waiting for {num_steps} steps and the trace to be flushed.... ({profile_by_stage=})\"\n    )\n\n    # Dump server args.\n    file_path = Path(output_dir) / \"server_args.json\"\n    if not file_path.exists():\n        response = requests.get(url + \"/get_server_info\")\n        response.raise_for_status()\n        server_args_data = response.json()\n        with open(file_path, \"w\") as file:\n            file.write(json.dumps(server_args_data))\n\n    # Start profiler. 
The API replies when all steps are processed\n    # and files are generated.\n    json_data = {\n        \"output_dir\": str(output_dir),\n        \"num_steps\": str(num_steps),\n        \"activities\": activities,\n        \"profile_by_stage\": profile_by_stage,\n        \"merge_profiles\": merge_profiles,\n        \"profile_prefix\": profile_prefix,\n    }\n    if start_step is not None:\n        json_data[\"start_step\"] = str(start_step)\n\n    response = requests.post(url=url + \"/start_profile\", json=json_data)\n    response.raise_for_status()\n\n    trace_link = str(output_dir)\n    return trace_link\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser(description=\"Run live profiling on a running SGLang server.\")\n    parser.add_argument(\n        \"--url\",\n        type=str,\n        default=\"http://localhost:30000\",\n        help=\"Server or API base url if not using http host and port.\",\n    )\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        default=None,\n        help=\"Profile directory to dump profile traces.\",\n    )\n    parser.add_argument(\n        \"--num-steps\",\n        type=int,\n        default=5,\n        help=\"The number of forward steps to profile.\",\n    )\n    parser.add_argument(\n        \"--profile-by-stage\",\n        action=argparse.BooleanOptionalAction,\n        type=bool,\n        default=False,\n        help=\"Whether to profile prefill and decode separately\",\n    )\n    parser.add_argument(\n        \"--profile-prefix\",\n        type=str,\n        help=\"The prefix of this profiler file.\",\n    )\n    parser.add_argument(\n        \"--cpu\",\n        action=argparse.BooleanOptionalAction,\n        type=bool,\n        default=True,\n        help=\"Whether to profile CPU activity\",\n    )\n    parser.add_argument(\n        \"--gpu\",\n        action=argparse.BooleanOptionalAction,\n        type=bool,\n        default=True,\n        help=\"Whether to profile GPU activity\",\n    
)\n    parser.add_argument(\n        \"--mem\",\n        action=argparse.BooleanOptionalAction,\n        type=bool,\n        default=False,\n        help=\"Whether to profile memory usage (https://pytorch.org/memory_viz)\",\n    )\n    parser.add_argument(\n        \"--rpd\",\n        action=argparse.BooleanOptionalAction,\n        type=bool,\n        default=False,\n        help=\"Whether to use ROCM rpd profiler (https://github.com/ROCm/rocmProfileData)\",\n    )\n    parser.add_argument(\n        \"--merge-profiles\",\n        action=argparse.BooleanOptionalAction,\n        type=bool,\n        default=False,\n        help=\"Whether to merge profiles from all ranks into a single trace file\",\n    )\n\n    args = parser.parse_args()\n    activities = []\n    if args.cpu:\n        activities.append(\"CPU\")\n    if args.gpu:\n        activities.append(\"GPU\")\n    if args.mem:\n        activities.append(\"MEM\")\n    if args.rpd:\n        activities.append(\"RPD\")\n\n    run_profile(\n        url=args.url,\n        num_steps=args.num_steps,\n        activities=activities,\n        output_dir=args.output_dir,\n        profile_by_stage=args.profile_by_stage,\n        profile_prefix=args.profile_prefix,\n        merge_profiles=args.merge_profiles,\n    )\n"
  },
  {
    "path": "python/sglang/srt/batch_invariant_ops/__init__.py",
    "content": "# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/__init__.py\n\nfrom .batch_invariant_ops import (\n    AttentionBlockSize,\n    disable_batch_invariant_mode,\n    enable_batch_invariant_mode,\n    get_batch_invariant_attention_block_size,\n    is_batch_invariant_mode_enabled,\n    log_softmax,\n    matmul_persistent,\n    mean_dim,\n    rms_norm_batch_invariant,\n    set_batch_invariant_mode,\n)\n\n__version__ = \"0.1.0\"\n\n__all__ = [\n    \"set_batch_invariant_mode\",\n    \"is_batch_invariant_mode_enabled\",\n    \"disable_batch_invariant_mode\",\n    \"enable_batch_invariant_mode\",\n    \"matmul_persistent\",\n    \"log_softmax\",\n    \"mean_dim\",\n    \"get_batch_invariant_attention_block_size\",\n    \"AttentionBlockSize\",\n    \"rms_norm_batch_invariant\",\n]\n"
  },
  {
    "path": "python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py",
    "content": "# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/batch_invariant_ops.py\n\nimport contextlib\nfrom collections import namedtuple\nfrom collections.abc import Callable\nfrom typing import Any, Dict\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM\nfrom sglang.srt.utils.common import calc_diff, get_bool_env_var\n\nif ENABLE_JIT_DEEPGEMM:\n    import deep_gemm\n\n_ENABLE_MM_DEEPGEMM = get_bool_env_var(\n    \"SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_DEEPGEMM\", \"1\"\n)\n# If true, allow falling back to a batch-variant GEMM when the shape cannot be run in DeepGEMM\n_ENABLE_MM_FALLBACK_VARIANT = get_bool_env_var(\n    \"SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_FALLBACK_VARIANT\", \"0\"\n)\n_ENABLE_MM_COMPARISON_TEST = get_bool_env_var(\n    \"SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_COMPARISON_TEST\"\n)\n\nif not _ENABLE_MM_DEEPGEMM:\n    print(\"DeepGEMM is disabled in batch invariant ops. 
Performance may be suboptimal.\")\n\n__all__ = [\n    \"set_batch_invariant_mode\",\n    \"is_batch_invariant_mode_enabled\",\n    \"disable_batch_invariant_mode\",\n    \"enable_batch_invariant_mode\",\n]\n\n\ndef _matmul_launch_metadata(\n    grid: Callable[..., Any], kernel: Any, args: Dict[str, Any]\n) -> Dict[str, Any]:\n    ret = {}\n    m, n, k = args[\"M\"], args[\"N\"], args[\"K\"]\n    ret[\"name\"] = f\"{kernel.name} [M={m}, N={n}, K={k}]\"\n    if \"tiles_per_update\" in args:\n        ret[\"name\"] = (\n            f\"{kernel.name} [M={m}, N={n}, K={k}, tiles_per_update={args['tiles_per_update']:02}]\"\n        )\n    if \"c_ptr\" in args:\n        bytes_per_elem = args[\"c_ptr\"].element_size()\n    else:\n        bytes_per_elem = 1 if args[\"FP8_OUTPUT\"] else 2\n    ret[f\"flops{bytes_per_elem * 8}\"] = 2.0 * m * n * k\n    ret[\"bytes\"] = bytes_per_elem * (m * k + n * k + m * n)\n    return ret\n\n\n@triton.jit\ndef _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):\n    group_id = tile_id // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + (tile_id % group_size_m)\n    pid_n = (tile_id % num_pid_in_group) // group_size_m\n    return pid_m, pid_n\n\n\n@triton.jit(launch_metadata=_matmul_launch_metadata)\ndef matmul_kernel_persistent(\n    a_ptr,\n    b_ptr,\n    c_ptr,  #\n    bias_ptr,\n    M,\n    N,\n    K,  #\n    stride_am,\n    stride_ak,\n    stride_bk,\n    stride_bn,\n    stride_cm,\n    stride_cn,\n    BLOCK_SIZE_M: tl.constexpr,  #\n    BLOCK_SIZE_N: tl.constexpr,  #\n    BLOCK_SIZE_K: tl.constexpr,  #\n    GROUP_SIZE_M: tl.constexpr,  #\n    NUM_SMS: tl.constexpr,  #\n    A_LARGE: tl.constexpr,\n    B_LARGE: tl.constexpr,\n    C_LARGE: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n):\n    start_pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n    
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)\n    num_tiles = num_pid_m * num_pid_n\n\n    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n\n    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):\n        pid_m, pid_n = _compute_pid(\n            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS\n        )\n        start_m = pid_m * BLOCK_SIZE_M\n        start_n = pid_n * BLOCK_SIZE_N\n        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)\n        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)\n        if A_LARGE:\n            offs_am = offs_am.to(tl.int64)\n        if B_LARGE:\n            offs_bn = offs_bn.to(tl.int64)\n        offs_am = tl.where(offs_am < M, offs_am, 0)\n        offs_bn = tl.where(offs_bn < N, offs_bn, 0)\n        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)\n        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)\n\n        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for ki in range(k_tiles):\n            if A_LARGE or B_LARGE:\n                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)\n            else:\n                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n            a_ptrs = a_ptr + (\n                offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n            )\n            b_ptrs = b_ptr + (\n                offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn\n            )\n\n            a = tl.load(\n                a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0\n            )\n            b = tl.load(\n                b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0\n            )\n            accumulator = tl.dot(a, b, accumulator)\n\n        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n        offs_cn = pid_n * BLOCK_SIZE_N + 
tl.arange(0, BLOCK_SIZE_N)\n        if C_LARGE:\n            offs_cm = offs_cm.to(tl.int64)\n            offs_cn = offs_cn.to(tl.int64)\n        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n        if HAS_BIAS:\n            bias_ptrs = bias_ptr + offs_cn\n            bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)\n            accumulator += bias\n        if c_ptr.dtype.element_ty == tl.float8e4nv:\n            c = accumulator.to(tl.float8e4nv)\n        elif c_ptr.dtype.element_ty == tl.bfloat16:\n            c = accumulator.to(tl.bfloat16)\n        elif c_ptr.dtype.element_ty == tl.float32:\n            c = accumulator.to(tl.float32)\n        else:\n            c = accumulator.to(tl.float16)\n        tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef _matmul_persistent_triton(\n    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None\n):\n    # Check constraints.\n    assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n    assert a.dtype == b.dtype, \"Incompatible dtypes\"\n    assert (\n        bias is None or bias.dim() == 1\n    ), \"Currently assuming bias is 1D, let Horace know if you run into this\"\n    NUM_SMS = torch.cuda.get_device_properties(\"cuda\").multi_processor_count\n    M, K = a.shape\n    K, N = b.shape\n    dtype = a.dtype\n    # Allocates output.\n    c = torch.empty((M, N), device=a.device, dtype=dtype)\n\n    # 1D launch kernel where each block gets its own program.\n    def grid(META):\n        return (\n            min(\n                NUM_SMS,\n                triton.cdiv(M, META[\"BLOCK_SIZE_M\"])\n                * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n            ),\n        )\n\n    configs = {\n        torch.bfloat16: {\n            \"BLOCK_SIZE_M\": 128,\n            \"BLOCK_SIZE_N\": 128,\n            \"BLOCK_SIZE_K\": 64,\n            \"GROUP_SIZE_M\": 8,\n            \"num_stages\": 3,\n     
       \"num_warps\": 8,\n        },\n        torch.float16: {\n            \"BLOCK_SIZE_M\": 128,\n            \"BLOCK_SIZE_N\": 256,\n            \"BLOCK_SIZE_K\": 64,\n            \"GROUP_SIZE_M\": 8,\n            \"num_stages\": 3,\n            \"num_warps\": 8,\n        },\n        torch.float32: {\n            \"BLOCK_SIZE_M\": 128,\n            \"BLOCK_SIZE_N\": 128,\n            \"BLOCK_SIZE_K\": 32,\n            \"GROUP_SIZE_M\": 8,\n            \"num_stages\": 3,\n            \"num_warps\": 8,\n        },\n    }\n    # print(a.device, b.device, c.device)\n    matmul_kernel_persistent[grid](\n        a,\n        b,\n        c,  #\n        bias,\n        M,\n        N,\n        K,  #\n        a.stride(0),\n        a.stride(1),  #\n        b.stride(0),\n        b.stride(1),  #\n        c.stride(0),\n        c.stride(1),  #\n        NUM_SMS=NUM_SMS,  #\n        A_LARGE=a.numel() > 2**31,\n        B_LARGE=b.numel() > 2**31,\n        C_LARGE=c.numel() > 2**31,\n        HAS_BIAS=bias is not None,\n        **configs[dtype],\n    )\n    return c\n\n\ndef _matmul_persistent_deepgemm(\n    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None\n):\n    M, K = a.shape\n    K, N = b.shape\n    dtype = a.dtype\n    out = torch.empty((M, N), device=a.device, dtype=dtype)\n\n    try:\n        deep_gemm.bf16_gemm_nn(a, b, out)\n    except RuntimeError as e:\n        raise RuntimeError(\n            f\"DeepGEMM failed for matrix shapes M={M}, N={N}, K={K}. \"\n            f\"This typically occurs when dimensions are too small for DeepGEMM's TMA descriptors. \"\n            f\"Consider increasing MIN_DEEPGEMM_DIM in matmul_persistent() or disabling DeepGEMM \"\n            f\"for small matrices. 
Original error: {e}\"\n        ) from e\n\n    # TODO can this be put in DeepGEMM's `c`?\n    if bias is not None:\n        out += bias\n\n    return out\n\n\ndef matmul_persistent(\n    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None\n):\n    K, N = b.shape\n\n    # DeepGEMM has minimum dimension requirements for TMA descriptors\n    MIN_DEEPGEMM_DIM = 16\n\n    if (\n        _ENABLE_MM_DEEPGEMM\n        and ENABLE_JIT_DEEPGEMM\n        and (a.dtype == torch.bfloat16)\n        and (b.dtype == torch.bfloat16)\n        and a.is_contiguous()\n        and b.transpose(0, 1).is_contiguous()\n        and N >= MIN_DEEPGEMM_DIM\n    ):\n        if _ENABLE_MM_COMPARISON_TEST:\n            out_triton = _matmul_persistent_triton(a=a, b=b, bias=bias)\n            out_deepgemm = _matmul_persistent_deepgemm(a=a, b=b, bias=bias)\n            diff = calc_diff(out_triton, out_deepgemm)\n            assert diff < 0.0001, f\"{diff=} {out_triton=} {out_deepgemm=}\"\n            # can be enabled for debugging\n            # print(\n            #     f\"{diff=} \"\n            #     f\"{(out_triton - out_deepgemm).abs().mean()=} \"\n            #     f\"{(out_triton - out_deepgemm).abs().sum()=} \"\n            #     f\"{torch.sum(out_triton != out_deepgemm)=} \"\n            # )\n            # print(f\"{a=} {b=} {bias=} {out_triton=} {out_deepgemm=}\")\n            return out_deepgemm\n\n        return _matmul_persistent_deepgemm(a=a, b=b, bias=bias)\n\n    if _ENABLE_MM_FALLBACK_VARIANT:\n        out = torch.einsum(\"ik,kj->ij\", a, b)\n        if bias is not None:\n            out += bias\n        return out\n\n    return _matmul_persistent_triton(a=a, b=b, bias=bias)\n\n\n@triton.jit\ndef _log_softmax_kernel(\n    input_ptr,\n    output_ptr,\n    input_row_stride: tl.constexpr,\n    output_row_stride: tl.constexpr,\n    n_cols: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    Compute log_softmax along the last dimension of a 2D tensor.\n    Each 
block handles one row of the input tensor.\n    \"\"\"\n    # Get the row index for this block\n    row_idx = tl.program_id(0).to(tl.int64)\n\n    # Compute base pointers for input and output rows\n    row_start_ptr = input_ptr + row_idx * input_row_stride\n    output_row_start_ptr = output_ptr + row_idx * output_row_stride\n\n    # Step 1: Find maximum value in the row for numerical stability\n    # Load first block to infer dtype and initialize max_val with correct type\n    col_idx_init = tl.arange(0, BLOCK_SIZE)\n    mask_init = col_idx_init < n_cols\n    vals_init = tl.load(\n        row_start_ptr + col_idx_init, mask=mask_init, other=-float(\"inf\")\n    )\n    max_val = tl.max(vals_init)\n\n    # Continue with remaining blocks\n    for col_offset in range(BLOCK_SIZE, n_cols, BLOCK_SIZE):\n        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)\n        mask = col_idx < n_cols\n\n        # Load values\n        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float(\"inf\"))\n\n        # Update maximum\n        max_val = tl.max(tl.maximum(vals, max_val))\n\n    # Step 2: Compute sum of exp(x - max_val)\n    # Initialize sum_exp with correct dtype by using tl.sum on a zero vector\n    sum_exp = tl.sum(tl.zeros([1], dtype=max_val.dtype))\n\n    for col_offset in range(0, n_cols, BLOCK_SIZE):\n        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)\n        mask = col_idx < n_cols\n\n        # Load values\n        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)\n\n        # Compute exp(x - max_val) and accumulate\n        exp_vals = tl.exp(vals - max_val)\n        sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0))\n\n    # Compute log(sum_exp)\n    log_sum_exp = tl.log(sum_exp)\n\n    # Step 3: Compute final log_softmax values: x - max_val - log_sum_exp\n    for col_offset in range(0, n_cols, BLOCK_SIZE):\n        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)\n        mask = col_idx < n_cols\n\n        # Load values\n        vals = 
tl.load(row_start_ptr + col_idx, mask=mask)\n\n        # Compute log_softmax\n        output = vals - max_val - log_sum_exp\n\n        # Store results\n        tl.store(output_row_start_ptr + col_idx, output, mask=mask)\n\n\ndef log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:\n    \"\"\"\n    Compute log_softmax using Triton kernel.\n\n    Args:\n        input: Input tensor\n        dim: Dimension along which to compute log_softmax (only -1 or last dim supported)\n\n    Returns:\n        Tensor with log_softmax applied along the specified dimension\n    \"\"\"\n    if dim != -1 and dim != input.ndim - 1:\n        raise ValueError(\n            \"This implementation only supports log_softmax along the last dimension\"\n        )\n\n    # Flatten all dimensions except the last one\n    original_shape = input.shape\n    input_2d = input.reshape(-1, input.shape[-1])\n    input_2d = input_2d.contiguous()\n\n    n_rows, n_cols = input_2d.shape\n\n    # Allocate output tensor\n    output = torch.empty_like(input_2d)\n\n    # Use a fixed block size; the kernel loops over each row in BLOCK_SIZE chunks\n    BLOCK_SIZE = 1024\n\n    # Launch kernel with one block per row\n    grid = (n_rows,)\n    _log_softmax_kernel[grid](\n        input_2d,\n        output,\n        input_2d.stride(0),\n        output.stride(0),\n        n_cols,\n        BLOCK_SIZE=BLOCK_SIZE,\n    )\n    # Reshape output back to original shape\n    return output.reshape(original_shape)\n\n\n@triton.jit\ndef mean_kernel(\n    input_ptr,\n    output_ptr,\n    input_stride0,\n    input_stride1,\n    input_stride2,\n    output_stride0,\n    output_stride1,\n    M,  # size before reduction dim\n    N,  # size of reduction dim\n    K,  # size after reduction dim\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    Kernel for computing mean along a single dimension.\n    Input is viewed as (M, N, K) where N is the dimension being reduced.\n    \"\"\"\n    # Program ID gives us which output element 
we're computing\n    pid = tl.program_id(0)\n\n    # Compute output indices\n    m_idx = pid // K\n    k_idx = pid % K\n\n    # Bounds check\n    if m_idx >= M or k_idx >= K:\n        return\n\n    # Accumulate sum across reduction dimension\n    acc = 0.0\n    for n_start in range(0, N, BLOCK_SIZE):\n        n_offsets = n_start + tl.arange(0, BLOCK_SIZE)\n        mask = n_offsets < N\n\n        # Calculate input indices\n        input_idx = (\n            m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2\n        )\n\n        # Load and accumulate\n        vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)\n        acc += tl.sum(vals)\n\n    # Compute mean and store\n    mean_val = acc / N\n    output_idx = m_idx * output_stride0 + k_idx * output_stride1\n    tl.store(output_ptr + output_idx, mean_val)\n\n\ndef mean_dim(\n    input: torch.Tensor,\n    dim: int,\n    keepdim: bool = False,\n    dtype: torch.dtype | None = None,\n) -> torch.Tensor:\n    \"\"\"\n    Triton implementation of torch.mean with single dimension reduction.\n\n    Args:\n        input: Input tensor\n        dim: Single dimension along which to compute mean\n        keepdim: Whether to keep the reduced dimension\n        dtype: Output dtype. 
If None, uses input dtype (or float32 for integer inputs)\n\n    Returns:\n        Tensor with mean values along specified dimension\n    \"\"\"\n    # Validate inputs\n    assert input.is_cuda, \"Input must be a CUDA tensor\"\n    assert (\n        -input.ndim <= dim < input.ndim\n    ), f\"Invalid dimension {dim} for tensor with {input.ndim} dimensions\"\n\n    # Handle negative dim\n    if dim < 0:\n        dim = dim + input.ndim\n\n    # Handle dtype\n    if dtype is None:\n        if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:\n            dtype = torch.float32\n        else:\n            dtype = input.dtype\n\n    # Convert input to appropriate dtype if needed\n    if input.dtype != dtype:\n        input = input.to(dtype)\n\n    # Get input shape and strides\n    shape = list(input.shape)\n\n    # Calculate dimensions for kernel\n    M = 1\n    for i in range(dim):\n        M *= shape[i]\n\n    N = shape[dim]\n\n    K = 1\n    for i in range(dim + 1, len(shape)):\n        K *= shape[i]\n\n    # Reshape input to 3D view (M, N, K)\n    input_3d = input.reshape(M, N, K)\n\n    # Create output shape\n    if keepdim:\n        output_shape = shape.copy()\n        output_shape[dim] = 1\n    else:\n        output_shape = shape[:dim] + shape[dim + 1 :]\n\n    # Create output tensor\n    output = torch.empty(output_shape, dtype=dtype, device=input.device)\n\n    # Reshape output for kernel\n    if keepdim:\n        output_2d = output.reshape(M, 1, K).squeeze(1)\n    else:\n        output_2d = output.reshape(M, K)\n\n    # Launch kernel\n    grid = (M * K,)\n    BLOCK_SIZE = 1024\n\n    mean_kernel[grid](\n        input_3d,\n        output_2d,\n        input_3d.stride(0),\n        input_3d.stride(1),\n        input_3d.stride(2),\n        output_2d.stride(0),\n        output_2d.stride(1) if output_2d.ndim > 1 else 0,\n        M,\n        N,\n        K,\n        BLOCK_SIZE,\n    )\n\n    return output\n\n\ndef mm_batch_invariant(a, b):\n    return 
matmul_persistent(a, b)\n\n\ndef addmm_batch_invariant(bias, a, b):\n    return matmul_persistent(a, b, bias=bias)\n\n\ndef _log_softmax_batch_invariant(input, dim, _half_to_float):\n    assert not _half_to_float, \"not implemented\"\n    return log_softmax(input, dim=dim)\n\n\ndef mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):\n    assert dtype is None or dtype == torch.float32, f\"unsupported dtype: {dtype}\"\n    if len(dim) == 1:\n        return mean_dim(input, dim[0], keepdim=keepdim)\n    else:\n        assert input.dtype in {\n            torch.float16,\n            torch.bfloat16,\n            torch.float32,\n        }, \"only float types supported for now\"\n        n_elems = 1\n        for d in dim:\n            n_elems *= input.shape[d]\n        return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems\n\n\n@triton.jit\ndef bmm_kernel_persistent(\n    a_ptr,\n    b_ptr,\n    c_ptr,  #\n    B,\n    M,\n    N,\n    K,  #\n    stride_ab,\n    stride_am,\n    stride_ak,\n    stride_bb,\n    stride_bk,\n    stride_bn,\n    stride_cb,\n    stride_cm,\n    stride_cn,\n    BLOCK_SIZE_M: tl.constexpr,  #\n    BLOCK_SIZE_N: tl.constexpr,  #\n    BLOCK_SIZE_K: tl.constexpr,  #\n    GROUP_SIZE_M: tl.constexpr,  #\n    NUM_SMS: tl.constexpr,  #\n    A_LARGE: tl.constexpr,\n    B_LARGE: tl.constexpr,\n    C_LARGE: tl.constexpr,\n):\n    \"\"\"\n    Batched matrix multiplication kernel that processes batches in parallel.\n    Each tile processes a (BLOCK_SIZE_M, BLOCK_SIZE_N) output block for a specific batch.\n    \"\"\"\n    start_pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)\n    num_tiles_per_batch = num_pid_m * num_pid_n\n    num_tiles_total = B * num_tiles_per_batch\n\n    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n\n    # Process tiles in a 
deterministic order: batch-major ordering\n    for tile_id in tl.range(start_pid, num_tiles_total, NUM_SMS, flatten=True):\n        # Decompose tile_id into batch and within-batch tile\n        batch_idx = tile_id // num_tiles_per_batch\n        tile_in_batch = tile_id % num_tiles_per_batch\n\n        pid_m, pid_n = _compute_pid(\n            tile_in_batch, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS\n        )\n        start_m = pid_m * BLOCK_SIZE_M\n        start_n = pid_n * BLOCK_SIZE_N\n        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)\n        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)\n        if A_LARGE:\n            offs_am = offs_am.to(tl.int64)\n        if B_LARGE:\n            offs_bn = offs_bn.to(tl.int64)\n        offs_am = tl.where(offs_am < M, offs_am, 0)\n        offs_bn = tl.where(offs_bn < N, offs_bn, 0)\n        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)\n        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)\n\n        # Add batch offset\n        if A_LARGE or B_LARGE:\n            batch_idx_typed = batch_idx.to(tl.int64)\n        else:\n            batch_idx_typed = batch_idx\n\n        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for ki in range(k_tiles):\n            if A_LARGE or B_LARGE:\n                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)\n            else:\n                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n\n            a_ptrs = a_ptr + (\n                batch_idx_typed * stride_ab\n                + offs_am[:, None] * stride_am\n                + offs_k[None, :] * stride_ak\n            )\n            b_ptrs = b_ptr + (\n                batch_idx_typed * stride_bb\n                + offs_k[:, None] * stride_bk\n                + offs_bn[None, :] * stride_bn\n            )\n\n            a = tl.load(\n                a_ptrs, mask=offs_k_for_mask[None, :] < K - 
ki * BLOCK_SIZE_K, other=0.0\n            )\n            b = tl.load(\n                b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0\n            )\n            accumulator = tl.dot(a, b, accumulator)\n\n        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n        if C_LARGE:\n            offs_cm = offs_cm.to(tl.int64)\n            offs_cn = offs_cn.to(tl.int64)\n        c_ptrs = (\n            c_ptr\n            + batch_idx_typed * stride_cb\n            + stride_cm * offs_cm[:, None]\n            + stride_cn * offs_cn[None, :]\n        )\n        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n\n        if c_ptr.dtype.element_ty == tl.float8e4nv:\n            c = accumulator.to(tl.float8e4nv)\n        elif c_ptr.dtype.element_ty == tl.bfloat16:\n            c = accumulator.to(tl.bfloat16)\n        elif c_ptr.dtype.element_ty == tl.float32:\n            c = accumulator.to(tl.float32)\n        else:\n            c = accumulator.to(tl.float16)\n        tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef bmm_batch_invariant(a, b, *, out=None):\n    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)\n    # Process batches in parallel with our persistent kernel\n    if a.ndim == 3 and b.ndim == 3:\n        # Check constraints\n        assert a.shape[0] == b.shape[0], \"Batch sizes must match\"\n        assert a.shape[2] == b.shape[1], \"Incompatible dimensions\"\n        assert a.dtype == b.dtype, \"Incompatible dtypes\"\n\n        B = a.shape[0]\n        M = a.shape[1]\n        K = a.shape[2]\n        N = b.shape[2]\n        dtype = a.dtype\n\n        # Allocate output\n        if out is None:\n            c = torch.empty((B, M, N), device=a.device, dtype=dtype)\n        else:\n            c = out\n\n        NUM_SMS = torch.cuda.get_device_properties(\"cuda\").multi_processor_count\n\n        # Use fixed kernel configuration for determinism\n      
  configs = {\n            torch.bfloat16: {\n                \"BLOCK_SIZE_M\": 128,\n                \"BLOCK_SIZE_N\": 128,\n                \"BLOCK_SIZE_K\": 64,\n                \"GROUP_SIZE_M\": 8,\n                \"num_stages\": 3,\n                \"num_warps\": 8,\n            },\n            torch.float16: {\n                \"BLOCK_SIZE_M\": 128,\n                \"BLOCK_SIZE_N\": 256,\n                \"BLOCK_SIZE_K\": 64,\n                \"GROUP_SIZE_M\": 8,\n                \"num_stages\": 3,\n                \"num_warps\": 8,\n            },\n            torch.float32: {\n                \"BLOCK_SIZE_M\": 128,\n                \"BLOCK_SIZE_N\": 128,\n                \"BLOCK_SIZE_K\": 32,\n                \"GROUP_SIZE_M\": 8,\n                \"num_stages\": 3,\n                \"num_warps\": 8,\n            },\n        }\n\n        config = configs.get(dtype)\n        if config is None:\n            raise ValueError(\n                f\"Unsupported dtype {dtype} for bmm_batch_invariant. 
\"\n                f\"Supported dtypes are: {list(configs.keys())}\"\n            )\n\n        # Grid: limit by NUM_SMS for persistent kernel approach\n        num_tiles_per_batch = triton.cdiv(M, config[\"BLOCK_SIZE_M\"]) * triton.cdiv(\n            N, config[\"BLOCK_SIZE_N\"]\n        )\n        num_tiles_total = B * num_tiles_per_batch\n        grid = (min(NUM_SMS, num_tiles_total),)\n\n        bmm_kernel_persistent[grid](\n            a,\n            b,\n            c,  #\n            B,\n            M,\n            N,\n            K,  #\n            a.stride(0),\n            a.stride(1),\n            a.stride(2),  #\n            b.stride(0),\n            b.stride(1),\n            b.stride(2),  #\n            c.stride(0),\n            c.stride(1),\n            c.stride(2),  #\n            NUM_SMS=NUM_SMS,  #\n            A_LARGE=a.numel() > 2**31,\n            B_LARGE=b.numel() > 2**31,\n            C_LARGE=c.numel() > 2**31,\n            **config,\n        )\n\n        return c\n    else:\n        raise ValueError(\n            f\"bmm_batch_invariant expects 3D tensors, \"\n            f\"got shapes {a.shape} and {b.shape}\"\n        )\n\n\n@triton.jit\ndef _rms_norm_kernel(\n    input_ptr,\n    weight_ptr,\n    output_ptr,\n    input_row_stride: tl.constexpr,\n    output_row_stride: tl.constexpr,\n    n_cols: tl.constexpr,\n    eps,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    Compute RMS normalization along the last dimension of a 2D tensor.\n    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight\n    Each block handles one row of the input tensor.\n    \"\"\"\n    row_idx = tl.program_id(0).to(tl.int64)\n    row_start_ptr = input_ptr + row_idx * input_row_stride\n    output_row_start_ptr = output_ptr + row_idx * output_row_stride\n\n    # Step 1: Compute sum of squares in float32 to avoid overflow\n    sum_sq = tl.zeros([1], dtype=tl.float32)\n    for col_offset in range(0, n_cols, BLOCK_SIZE):\n        col_idx = col_offset + tl.arange(0, 
BLOCK_SIZE)\n        mask = col_idx < n_cols\n\n        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)\n        # Convert to float32 for accumulation to prevent overflow\n        vals_f32 = vals.to(tl.float32)\n        sq_vals = vals_f32 * vals_f32\n        sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))\n\n    # Step 2: Compute RMS (root mean square) in float32\n    mean_sq = sum_sq / n_cols\n    rms = tl.sqrt(mean_sq + eps)\n    inv_rms = 1.0 / rms\n\n    # Step 3: Normalize and apply weight\n    for col_offset in range(0, n_cols, BLOCK_SIZE):\n        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)\n        mask = col_idx < n_cols\n        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)\n        weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)\n        # Compute in float32 then convert back to input dtype\n        vals_f32 = vals.to(tl.float32)\n        weight_f32 = weight.to(tl.float32)\n        output_f32 = vals_f32 * inv_rms * weight_f32\n        output = output_f32.to(vals.dtype)\n        tl.store(output_row_start_ptr + col_idx, output, mask=mask)\n\n\ndef rms_norm(\n    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6\n) -> torch.Tensor:\n    \"\"\"\n    Compute RMS normalization using Triton kernel.\n\n    RMS Norm normalizes the input by the root mean square and scales by weight:\n    output = input / sqrt(mean(input^2) + eps) * weight\n\n    Args:\n        input: Input tensor of shape (..., hidden_size)\n        weight: Weight tensor of shape (hidden_size,)\n        eps: Small constant for numerical stability\n\n    Returns:\n        Tensor with RMS normalization applied along the last dimension\n    \"\"\"\n    assert weight.dim() == 1, \"Weight must be 1-dimensional\"\n    assert input.shape[-1] == weight.shape[0], (\n        f\"Input last dimension ({input.shape[-1]}) must match \"\n        f\"weight dimension ({weight.shape[0]})\"\n    )\n\n    # Flatten all dimensions except the last one\n    
original_shape = input.shape\n    input_2d = input.reshape(-1, input.shape[-1])\n    input_2d = input_2d.contiguous()\n    weight = weight.contiguous()\n\n    n_rows, n_cols = input_2d.shape\n\n    output = torch.empty_like(input_2d)\n    BLOCK_SIZE = 1024\n    grid = (n_rows,)\n    _rms_norm_kernel[grid](\n        input_2d,\n        weight,\n        output,\n        input_2d.stride(0),\n        output.stride(0),\n        n_cols,\n        eps,\n        BLOCK_SIZE=BLOCK_SIZE,\n    )\n    return output.reshape(original_shape)\n\n\ndef rms_norm_batch_invariant(\n    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6\n) -> torch.Tensor:\n    \"\"\"\n    Batch-invariant wrapper for RMS normalization.\n\n    This function provides a deterministic, batch-invariant implementation\n    of RMS normalization for use with the batch_invariant mode.\n\n    Adapted from @https://github.com/vllm-project/vllm/blob/66a168a197ba214a5b70a74fa2e713c9eeb3251a/vllm/model_executor/layers/batch_invariant.py#L649\n\n    Args:\n        input: Input tensor of shape (..., hidden_size)\n        weight: Weight tensor of shape (hidden_size,)\n        eps: Small constant for numerical stability\n\n    Returns:\n        RMS normalized tensor\n    \"\"\"\n    return rms_norm(input, weight, eps=eps)\n\n\n_batch_invariant_MODE = False\n_batch_invariant_LIB = None\n_original_torch_bmm = None\n\n\ndef is_batch_invariant_mode_enabled():\n    return _batch_invariant_MODE\n\n\ndef enable_batch_invariant_mode(\n    enable_bmm: bool = True,\n):\n    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm\n    if _batch_invariant_MODE:\n        return\n\n    _batch_invariant_MODE = True\n    _batch_invariant_LIB = torch.library.Library(\"aten\", \"IMPL\")\n    _batch_invariant_LIB.impl(\"aten::mm\", mm_batch_invariant, \"CUDA\")\n    _batch_invariant_LIB.impl(\"aten::addmm\", addmm_batch_invariant, \"CUDA\")\n    _batch_invariant_LIB.impl(\n        \"aten::_log_softmax\", 
_log_softmax_batch_invariant, \"CUDA\"\n    )\n    _batch_invariant_LIB.impl(\"aten::mean.dim\", mean_batch_invariant, \"CUDA\")\n\n    if enable_bmm:\n        _batch_invariant_LIB.impl(\"aten::bmm\", bmm_batch_invariant, \"CUDA\")\n\n        # Also monkeypatch torch.bmm directly as a fallback\n        _original_torch_bmm = torch.bmm\n        torch.bmm = bmm_batch_invariant\n\n\ndef disable_batch_invariant_mode():\n    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm\n    if _batch_invariant_LIB is not None:\n        _batch_invariant_LIB._destroy()\n    if _original_torch_bmm is not None:\n        torch.bmm = _original_torch_bmm\n        _original_torch_bmm = None\n    _batch_invariant_MODE = False\n    _batch_invariant_LIB = None\n\n\n@contextlib.contextmanager\ndef set_batch_invariant_mode(enabled: bool = True):\n    global _batch_invariant_MODE, _batch_invariant_LIB\n    old_data = (_batch_invariant_MODE, _batch_invariant_LIB)\n    if enabled:\n        enable_batch_invariant_mode()\n    else:\n        disable_batch_invariant_mode()\n    yield\n    if _batch_invariant_LIB is not None:\n        _batch_invariant_LIB._destroy()\n    _batch_invariant_MODE, _batch_invariant_LIB = old_data\n\n\nAttentionBlockSize = namedtuple(\"AttentionBlockSize\", [\"block_m\", \"block_n\"])\n\n\ndef get_batch_invariant_attention_block_size() -> AttentionBlockSize:\n    return AttentionBlockSize(block_m=16, block_n=16)\n"
  },
  {
    "path": "python/sglang/srt/batch_overlap/operations.py",
    "content": "from __future__ import annotations\n\nimport os\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union\n\nimport torch\n\nfrom sglang.srt.layers.dp_attention import set_dp_buffer_len\n\nif TYPE_CHECKING:\n    from sglang.srt.model_executor.forward_batch_info import ForwardBatch\n\n_ENABLE_PROFILE = bool(int(os.environ.get(\"SGLANG_OPERATIONS_ENABLE_PROFILE\", \"0\")))\n\nif _ENABLE_PROFILE:\n    import nvtx\n\n\ndef execute_operations(inputs, operations):\n    stages = _convert_operations_to_stages(operations)\n    executor = _StageExecutor(\"primary\", stages, inputs=inputs)\n    for _ in range(executor.num_stages):\n        executor.next()\n    assert executor.done\n    return executor.output\n\n\ndef execute_overlapped_operations(\n    inputs_arr: Sequence,\n    operations_arr: Sequence,\n    delta_stages: Sequence[int],\n) -> Sequence:\n    # Make it explicit for clarity; if we need multi-batch overlap, this can be generalized\n    inputs_a, inputs_b = inputs_arr\n    operations_a, operations_b = operations_arr\n    delta_stage_a, delta_stage_b = delta_stages\n    assert delta_stage_a == 0\n    delta_stage = delta_stage_b\n\n    stages_a = _convert_operations_to_stages(operations_a)\n    stages_b = _convert_operations_to_stages(operations_b)\n    executor_a = _StageExecutor(\"a\", stages_a, inputs=inputs_a)\n    executor_b = _StageExecutor(\"b\", stages_b, inputs=inputs_b)\n\n    for _ in range(delta_stage):\n        executor_a.next()\n\n    for _ in range(executor_a.num_stages - delta_stage):\n        executor_a.next()\n        executor_b.next()\n\n    for _ in range(delta_stage):\n        executor_b.next()\n\n    assert executor_a.done and executor_b.done\n    return [executor_a.output, executor_b.output]\n\n\nclass YieldOperation:\n    pass\n\n\n@dataclass\nclass ExecutionOperation:\n    debug_name: str\n    fn: Callable\n\n\nOperation 
= Union[YieldOperation, ExecutionOperation, Callable]\nStage = List[ExecutionOperation]\n\n\nclass _StageExecutor:\n    def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):\n        self._debug_name = debug_name\n        self._stages = stages\n        self._index = 0\n        self._stage_state = _StateDict()\n        self._stage_output = inputs\n\n        # handling DP attention\n        forward_batch: ForwardBatch = inputs[\"forward_batch\"]\n        self._global_dp_buffer_len = forward_batch.global_dp_buffer_len\n        self._local_dp_buffer_len = forward_batch.tbo_padded_len\n        self._global_num_tokens = forward_batch.global_num_tokens_cpu\n        self._is_dp_max_padding = forward_batch.dp_padding_mode.is_max_len()\n\n    def next(self):\n        assert not self.done\n\n        stage = self._stages[self._index]\n\n        # TODO: We currently always call set_dp_buffer_len here because sub-batches\n        # may have different padded lengths. It can likely be removed after TBO slice &\n        # pad logic is refactored.\n        set_dp_buffer_len(\n            self._global_dp_buffer_len,\n            self._local_dp_buffer_len,\n            self._is_dp_max_padding,\n            self._global_num_tokens,\n        )\n\n        with _annotate_region(debug_name=f\"{self._debug_name}{self._index}\"):\n            for op in stage:\n                with _annotate_region(debug_name=op.debug_name):\n                    self._stage_output = op.fn(\n                        state=self._stage_state,\n                        **(\n                            self._stage_output if self._stage_output is not None else {}\n                        ),\n                    )\n\n        self._index += 1\n\n    @property\n    def output(self):\n        assert self.done\n        return self._stage_output\n\n    @property\n    def done(self):\n        return self._index >= self.num_stages\n\n    @property\n    def num_stages(self):\n        return 
len(self._stages)\n\n\n@contextmanager\ndef _annotate_region(debug_name):\n    if _ENABLE_PROFILE:\n        with torch.autograd.profiler.record_function(debug_name):\n            with nvtx.annotate(debug_name):\n                yield\n    else:\n        yield\n\n\nclass _StateDict:\n    def __init__(self):\n        self._data = {}\n\n    def __setattr__(self, key, value):\n        if key == \"_data\":\n            super().__setattr__(key, value)\n            return\n        assert (\n            key not in self._data\n        ), f\"`{key}` already exists, are you sure you want to override it?\"\n        self._data[key] = value\n\n    def __getattr__(self, item):\n        return self._data[item]\n\n    def __delattr__(self, item):\n        del self._data[item]\n\n    def pop(self, item):\n        return self._data.pop(item)\n\n    def update(self, values: Dict[str, Any]):\n        for k, v in values.items():\n            setattr(self, k, v)\n\n    def get(self, item):\n        return self._data.get(item)\n\n    def clear(self, expect_keys: Sequence[str]):\n        if set(self._data.keys()) != set(expect_keys):\n            raise Exception(\n                f\"Unexpected keys when clearing. This may indicate you do not release memory early enough but leave it until here. 
{list(self._data.keys())=} {expect_keys=}\"\n            )\n\n        self._data.clear()\n\n\ndef _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:\n    operations = _decorate_operations(operations)\n    operation_chunks = list(\n        _chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation))\n    )\n    assert all(len(chunk) > 0 for chunk in operation_chunks)\n    return operation_chunks\n\n\ndef _chunk_by_separator(\n    items: List[Any], is_separator: Callable[[Any], bool]\n) -> Generator[List[Any], None, None]:\n    pending_items = []\n    for item in items:\n        if is_separator(item):\n            yield pending_items\n            pending_items = []\n        else:\n            pending_items.append(item)\n    if len(pending_items) > 0:\n        yield pending_items\n\n\ndef _decorate_operations(operations: List[Operation], debug_name_prefix: str = \"\"):\n    return [_decorate_operation(op, debug_name_prefix) for op in operations]\n\n\ndef _decorate_operation(operation: Operation, debug_name_prefix: str):\n    if isinstance(operation, YieldOperation):\n        return operation\n    return ExecutionOperation(\n        debug_name=debug_name_prefix\n        + getattr(operation, \"__name__\", \"unknown\").replace(\"op_\", \"\"),\n        fn=operation,\n    )\n"
  },
  {
    "path": "python/sglang/srt/batch_overlap/operations_strategy.py",
    "content": "from dataclasses import dataclass\nfrom typing import List, Optional\n\nimport torch\n\nfrom sglang.srt.batch_overlap import operations\nfrom sglang.srt.batch_overlap.operations import Operation\nfrom sglang.srt.layers.moe.token_dispatcher import DeepEPConfig\nfrom sglang.srt.model_executor.forward_batch_info import ForwardMode\nfrom sglang.srt.utils import is_hip\n\n_is_hip = is_hip()\n\n\n@dataclass\nclass OperationsStrategy:\n    operations: List[Operation]\n    deep_gemm_num_sms: Optional[int] = None\n    tbo_delta_stages: Optional[int] = None\n\n    @classmethod\n    def concat(cls, items: List[\"OperationsStrategy\"]) -> \"OperationsStrategy\":\n        return OperationsStrategy(\n            operations=[x for item in items for x in item.operations],\n            deep_gemm_num_sms=_assert_all_same(\n                [item.deep_gemm_num_sms for item in items]\n            ),\n            tbo_delta_stages=_assert_all_same(\n                [item.tbo_delta_stages for item in items]\n            ),\n        )\n\n    @staticmethod\n    def init_new_tbo(\n        layers: torch.nn.ModuleList,\n        forward_mode: ForwardMode,\n    ) -> \"OperationsStrategy\":\n        layer_name = layers[0].__class__.__name__\n        if layer_name == \"DeepseekV2DecoderLayer\":\n            return OperationsStrategy.concat(\n                [\n                    _compute_moe_deepseek_layer_operations_strategy_tbo(\n                        layer, forward_mode\n                    )\n                    for layer in layers\n                ]\n            )\n        elif layer_name == \"Qwen3MoeDecoderLayer\":\n            return OperationsStrategy.concat(\n                [\n                    _compute_moe_qwen3_layer_operations_strategy_tbo(\n                        layer, forward_mode\n                    )\n                    for layer in layers\n                ]\n            )\n        elif layer_name == \"MiMoV2DecoderLayer\":\n            return 
OperationsStrategy.concat(\n                [\n                    _compute_moe_mimov2_layer_operations_strategy_tbo(\n                        layer, forward_mode\n                    )\n                    for layer in layers\n                ]\n            )\n        else:\n            raise NotImplementedError\n\n\ndef _assert_all_same(items: List):\n    assert all(item == items[0] for item in items)\n    return items[0]\n\n\n# -------------------------------- Strategy for DeepSeek ---------------------------------------\n\n\n# TODO can refactor to make it more fancy if we have more complex strategies\ndef _compute_moe_deepseek_layer_operations_strategy_tbo(\n    layer: torch.nn.Module,\n    forward_mode: ForwardMode,\n) -> OperationsStrategy:\n    assert layer.is_layer_sparse, \"dense layer TBO not yet implemented\"\n    if forward_mode == ForwardMode.EXTEND:\n        return _compute_moe_deepseek_blog_prefill(layer)\n    elif (\n        forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY\n    ):\n        return _compute_moe_deepseek_blog_decode(layer)\n    else:\n        raise NotImplementedError(f\"Unsupported {forward_mode=}\")\n\n\ndef _compute_moe_deepseek_blog_prefill(layer):\n    device_properties = torch.cuda.get_device_properties(device=\"cuda\")\n    total_num_sms = device_properties.multi_processor_count\n    deep_gemm_num_sms = None\n    if not _is_hip:\n        deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms\n\n    return OperationsStrategy(\n        deep_gemm_num_sms=deep_gemm_num_sms,\n        tbo_delta_stages=0,\n        operations=[\n            layer.op_comm_prepare_attn,\n            layer.self_attn.op_prepare,\n            layer.self_attn.op_core,\n            layer.op_comm_prepare_mlp,\n            layer.mlp.op_gate,\n            layer.mlp.op_select_experts,\n            layer.mlp.op_dispatch_a,\n            operations.YieldOperation(),\n            layer.mlp.op_dispatch_b,\n            
layer.mlp.op_experts,\n            layer.mlp.op_combine_a,\n            operations.YieldOperation(),\n            layer.mlp.op_shared_experts,\n            layer.mlp.op_combine_b,\n            layer.mlp.op_output,\n            layer.op_comm_postprocess_layer,\n        ],\n    )\n\n\ndef _compute_moe_deepseek_blog_decode(layer):\n    return OperationsStrategy(\n        deep_gemm_num_sms=None,\n        tbo_delta_stages=2,\n        operations=[\n            layer.op_comm_prepare_attn,\n            layer.self_attn.op_prepare,\n            operations.YieldOperation(),\n            layer.self_attn.op_core,\n            layer.op_comm_prepare_mlp,\n            layer.mlp.op_gate,\n            layer.mlp.op_select_experts,\n            operations.YieldOperation(),\n            layer.mlp.op_dispatch_a,\n            layer.mlp.op_shared_experts,\n            operations.YieldOperation(),\n            layer.mlp.op_dispatch_b,\n            layer.mlp.op_experts,\n            layer.mlp.op_combine_a,\n            operations.YieldOperation(),\n            layer.mlp.op_combine_b,\n            operations.YieldOperation(),\n            layer.mlp.op_output,\n            layer.op_comm_postprocess_layer,\n        ],\n    )\n\n\n# -------------------------------- Strategy for Qwen3 ---------------------------------------\n\n\n# TODO: unstable; the current strategy is almost the same as DeepSeek's, but we keep\n# this redundant code here for convenience when adjusting the strategy\ndef _compute_moe_qwen3_layer_operations_strategy_tbo(\n    layer: torch.nn.Module,\n    forward_mode: ForwardMode,\n) -> OperationsStrategy:\n    assert layer.is_layer_sparse, \"Qwen3 MoE only supports sparse layers\"\n    if forward_mode == ForwardMode.EXTEND:\n        return _compute_moe_qwen3_prefill(layer)\n    elif (\n        forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY\n    ):\n        return _compute_moe_qwen3_decode(layer)\n    else:\n        raise NotImplementedError(f\"Unsupported 
{forward_mode=}\")\n\n\ndef _compute_moe_qwen3_prefill(layer):\n    device_properties = torch.cuda.get_device_properties(device=\"cuda\")\n    total_num_sms = device_properties.multi_processor_count\n    deep_gemm_num_sms = None\n    if not _is_hip:\n        deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms\n\n    return OperationsStrategy(\n        deep_gemm_num_sms=deep_gemm_num_sms,\n        tbo_delta_stages=0,\n        operations=[\n            layer.op_comm_prepare_attn,\n            layer.self_attn.op_prepare,\n            layer.self_attn.op_core,\n            layer.op_comm_prepare_mlp,\n            layer.mlp.op_gate,\n            layer.mlp.op_select_experts,\n            layer.mlp.op_dispatch_a,\n            operations.YieldOperation(),\n            layer.mlp.op_dispatch_b,\n            layer.mlp.op_experts,\n            layer.mlp.op_combine_a,\n            operations.YieldOperation(),\n            layer.mlp.op_combine_b,\n            layer.mlp.op_output,\n            layer.op_comm_postprocess_layer,\n        ],\n    )\n\n\ndef _compute_moe_qwen3_decode(layer):\n    return OperationsStrategy(\n        deep_gemm_num_sms=None,\n        tbo_delta_stages=2,\n        operations=[\n            layer.op_comm_prepare_attn,\n            layer.self_attn.op_prepare,\n            operations.YieldOperation(),\n            layer.self_attn.op_core,\n            layer.op_comm_prepare_mlp,\n            layer.mlp.op_gate,\n            layer.mlp.op_select_experts,\n            operations.YieldOperation(),\n            layer.mlp.op_dispatch_a,\n            operations.YieldOperation(),\n            layer.mlp.op_dispatch_b,\n            layer.mlp.op_experts,\n            layer.mlp.op_combine_a,\n            operations.YieldOperation(),\n            layer.mlp.op_combine_b,\n            layer.mlp.op_output,\n            layer.op_comm_postprocess_layer,\n            operations.YieldOperation(),\n        ],\n    )\n\n\n# -------------------------------- Strategy 
for MiMoV2DecoderLayer ---------------------------------------\n\n\n# TODO: unstable; current strategy matches DeepSeek for the common operations (MiMoV2 has no op_shared_experts),\n# so we keep this redundant code here for convenience when adjusting the strategy\ndef _compute_moe_mimov2_layer_operations_strategy_tbo(\n    layer: torch.nn.Module,\n    forward_mode: ForwardMode,\n) -> OperationsStrategy:\n    assert layer.is_layer_sparse, \"MiMoV2DecoderLayer MoE only supports sparse layers\"\n    if forward_mode == ForwardMode.EXTEND:\n        return _compute_moe_mimov2_prefill(layer)\n    elif (\n        forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY\n    ):\n        return _compute_moe_mimov2_decode(layer)\n    else:\n        raise NotImplementedError(f\"Unsupported {forward_mode=}\")\n\n\ndef _compute_moe_mimov2_prefill(layer):\n    device_properties = torch.cuda.get_device_properties(device=\"cuda\")\n    total_num_sms = device_properties.multi_processor_count\n    deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms\n\n    return OperationsStrategy(\n        deep_gemm_num_sms=deep_gemm_num_sms,\n        tbo_delta_stages=0,\n        operations=[\n            layer.op_comm_prepare_attn,\n            layer.self_attn.op_prepare,\n            layer.self_attn.op_core,\n            layer.op_comm_prepare_mlp,\n            layer.mlp.op_gate,\n            layer.mlp.op_select_experts,\n            layer.mlp.op_dispatch_a,\n            operations.YieldOperation(),\n            layer.mlp.op_dispatch_b,\n            layer.mlp.op_experts,\n            layer.mlp.op_combine_a,\n            operations.YieldOperation(),\n            layer.mlp.op_combine_b,\n            layer.mlp.op_output,\n            layer.op_comm_postprocess_layer,\n        ],\n    )\n\n\ndef _compute_moe_mimov2_decode(layer):\n    return OperationsStrategy(\n        deep_gemm_num_sms=None,\n        tbo_delta_stages=2,\n        operations=[\n
            layer.op_comm_prepare_attn,\n            layer.self_attn.op_prepare,\n            operations.YieldOperation(),\n            layer.self_attn.op_core,\n            layer.op_comm_prepare_mlp,\n            layer.mlp.op_gate,\n            layer.mlp.op_select_experts,\n            operations.YieldOperation(),\n            layer.mlp.op_dispatch_a,\n            operations.YieldOperation(),\n            layer.mlp.op_dispatch_b,\n            layer.mlp.op_experts,\n            layer.mlp.op_combine_a,\n            operations.YieldOperation(),\n            layer.mlp.op_combine_b,\n            layer.mlp.op_output,\n            layer.op_comm_postprocess_layer,\n            operations.YieldOperation(),\n        ],\n    )\n"
  },
  {
    "path": "python/sglang/srt/batch_overlap/single_batch_overlap.py",
    "content": "# Copyright 2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\n\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.moe import get_moe_runner_backend\nfrom sglang.srt.layers.moe.utils import is_sbo_enabled\nfrom sglang.srt.utils import is_blackwell\n\n\nclass SboFlags:\n    # TODO may have: \"enable_dispatch_gateup_gemm_two_stream_overlap\", ...\n\n    @classmethod\n    def enable_combine_down_gemm_two_stream_overlap(cls):\n        return (\n            is_sbo_enabled()\n            # currently only cutedsl backend supports it\n            and (\n                get_moe_runner_backend().is_flashinfer_cutedsl()\n                or (get_moe_runner_backend().is_deep_gemm() and not is_blackwell())\n            )\n        )\n\n    @classmethod\n    def enable_combine_shared_two_stream_overlap(cls):\n        return (\n            is_sbo_enabled()\n            and not cls.enable_dispatch_shared_one_stream_overlap()\n            and not envs.SGLANG_BLACKWELL_OVERLAP_SHARED_EXPERTS_OUTSIDE_SBO.get()\n        )\n\n    @classmethod\n    def enable_dispatch_shared_one_stream_overlap(cls):\n        return is_sbo_enabled() and not is_blackwell()\n\n    @classmethod\n    def fuse_shared_experts_inside_sbo(cls):\n        return (\n            
cls.enable_combine_shared_two_stream_overlap()\n            or cls.enable_dispatch_shared_one_stream_overlap()\n        )\n\n\n@dataclass\nclass CombineOverlapArgs:\n    # this \"overlap\" flag means overlapping with down gemm, not the general two-stream overlap\n    overlap: bool\n    stream: torch.cuda.Stream\n    wait_event: torch.cuda.Event\n    num_sms: Optional[int] = None\n    signal: Optional[torch.Tensor] = None\n    block_m: Optional[int] = 64\n    threshold: Optional[int] = 0\n\n\n@dataclass\nclass DownGemmOverlapArgs:\n    num_sms: int\n    signal: torch.Tensor\n    start_event: torch.cuda.Event\n\n\ndef compute_overlap_args(dispatch_output, alt_stream):\n    if not (\n        SboFlags.enable_combine_down_gemm_two_stream_overlap()\n        or SboFlags.enable_combine_shared_two_stream_overlap()\n    ):\n        return None, None, {}\n\n    hidden_states = dispatch_output.hidden_states\n\n    num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape\n\n    total_num_sms = torch.cuda.get_device_properties(\n        device=\"cuda\"\n    ).multi_processor_count\n\n    if envs.SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS.is_set():\n        communicate_num_sms = envs.SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS.get()\n    else:\n        communicate_num_sms = 32 if is_blackwell() else 3\n    compute_num_sms = total_num_sms - communicate_num_sms\n\n    assert alt_stream is not None\n    combine_wait_event = torch.cuda.Event()\n    combine_overlap_args = CombineOverlapArgs(\n        overlap=False,\n        num_sms=communicate_num_sms,\n        stream=alt_stream,\n        wait_event=combine_wait_event,\n    )\n    meta_overlap_args = dict(\n        compute_num_sms=compute_num_sms,\n    )\n    down_gemm_overlap_args = None\n\n    if SboFlags.enable_combine_down_gemm_two_stream_overlap():\n        # TODO use zero_allocator to remove this `torch.zeros` call\n        # NOTE: our v2 currently uses uint32, not int32\n        if is_blackwell():\n            combine_signal = 
torch.zeros(\n                num_local_experts, dtype=torch.uint32, device=hidden_states.device\n            )\n        else:\n            MIN_BLOCK_M = 64\n            combine_signal_size = num_local_experts * (\n                (num_tokens_static + MIN_BLOCK_M - 1) // MIN_BLOCK_M\n            )\n            combine_signal = torch.zeros(\n                combine_signal_size, dtype=torch.int32, device=hidden_states.device\n            )\n\n        down_gemm_overlap_args = DownGemmOverlapArgs(\n            signal=combine_signal,\n            start_event=combine_wait_event,\n            num_sms=compute_num_sms,\n        )\n        combine_overlap_args.overlap = True\n        combine_overlap_args.signal = combine_signal\n        combine_overlap_args.threshold = compute_num_sms\n    else:\n        meta_overlap_args |= dict(\n            record_event_after_down=combine_wait_event,\n        )\n\n    return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args\n"
  },
  {
    "path": "python/sglang/srt/batch_overlap/two_batch_overlap.py",
    "content": "from __future__ import annotations\n\nimport copy\nimport dataclasses\nimport logging\nfrom dataclasses import replace\nfrom typing import TYPE_CHECKING, Dict, List, Optional, Sequence\n\nimport torch\n\nfrom sglang.srt.batch_overlap.operations import (\n    execute_operations,\n    execute_overlapped_operations,\n)\nfrom sglang.srt.batch_overlap.operations_strategy import OperationsStrategy\nfrom sglang.srt.layers import deep_gemm_wrapper\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.communicator import (\n    CommunicateContext,\n    CommunicateSummableTensorPairFn,\n    ScatterMode,\n)\nfrom sglang.srt.layers.dp_attention import get_attention_tp_size\nfrom sglang.srt.layers.moe import (\n    get_deepep_mode,\n    get_moe_a2a_backend,\n    get_tbo_token_distribution_threshold,\n    is_tbo_enabled,\n)\nfrom sglang.srt.layers.moe.token_dispatcher import (\n    DeepEPDispatcher,\n    MooncakeEPDispatcher,\n    MoriEPDispatcher,\n    NixlEPDispatcher,\n)\nfrom sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher\nfrom sglang.srt.managers.schedule_batch import ScheduleBatch\nfrom sglang.srt.model_executor.forward_batch_info import (\n    ForwardBatch,\n    ForwardMode,\n    compute_position,\n)\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.speculative.spec_info import SpecInput\nfrom sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip\n\nif TYPE_CHECKING:\n    from sglang.srt.batch_overlap.single_batch_overlap import CombineOverlapArgs\n    from sglang.srt.layers.moe.token_dispatcher import DispatchOutput\n    from sglang.srt.speculative.eagle_info import EagleVerifyInput\n\n_is_hip = is_hip()\n\n_tbo_debug = get_bool_env_var(\"SGLANG_TBO_DEBUG\")\n\nlogger = logging.getLogger(__name__)\n\n\n# -------------------------------- Compute Basic Info ---------------------------------------\n\n\ndef get_token_num_per_seq(\n    
forward_mode: ForwardMode,\n    spec_info: Optional[SpecInput] = None,\n):\n    if forward_mode.is_target_verify():\n        return spec_info.draft_token_num\n    elif forward_mode.is_decode():\n        return 1\n    elif forward_mode.is_idle():\n        return 0\n    else:\n        # For extend, we should not use `token_num_per_seq`.\n        return None\n\n\n# TODO: may smartly disable TBO when batch size is too small b/c it will slow down\ndef compute_split_seq_index(\n    forward_mode: ForwardMode,\n    num_tokens: int,\n    extend_lens: Optional[Sequence[int]],\n    token_num_per_seq: Optional[int],\n) -> Optional[int]:\n    if forward_mode == ForwardMode.EXTEND:\n        assert extend_lens is not None\n        return _split_extend_seqs(extend_lens)\n    elif forward_mode.is_target_verify() or forward_mode.is_decode():\n        assert token_num_per_seq is not None\n        return (num_tokens // token_num_per_seq) // 2\n    elif forward_mode.is_idle() or forward_mode.is_prebuilt():\n        assert num_tokens == 0\n        return 0\n    else:\n        raise NotImplementedError()\n\n\ndef _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:\n    if extend_lens is None:\n        return False\n\n    vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)\n    left_sum = sum(extend_lens[:vanilla_split_seq_index])\n    overall_sum = sum(extend_lens)\n    threshold = get_tbo_token_distribution_threshold()\n    assert threshold <= 0.5, f\"{threshold=}\"\n    return left_sum < overall_sum * threshold or left_sum > overall_sum * (\n        1 - threshold\n    )\n\n\ndef _split_extend_seqs(arr: Sequence[int]) -> int:\n    if _is_two_chunk_split_enabled(arr):\n        return _split_array_by_cum_less_than_half(arr)\n\n    return _split_array_by_balanced_sum(arr)\n\n\ndef _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int:\n    left_sum = 0\n    overall_sum = sum(arr)\n    half_sum = overall_sum // 2\n    chosen_index = 0\n\n    for i in 
range(len(arr)):\n        left_sum += arr[i]\n        if left_sum > half_sum:\n            chosen_index = i\n            break\n\n    return chosen_index\n\n\ndef _split_array_by_balanced_sum(arr: Sequence[int]) -> int:\n    overall_sum = sum(arr)\n    left_sum = 0\n    min_diff = float(\"inf\")\n    best_index = 0\n\n    for i in range(1, len(arr)):\n        left_sum += arr[i - 1]\n        right_sum = overall_sum - left_sum\n        diff = abs(left_sum - right_sum)\n        if diff <= min_diff:\n            min_diff = diff\n            best_index = i\n        else:\n            break\n\n    return best_index\n\n\ndef _update_device_and_sum_field_from_cpu_field(\n    batch: ForwardBatch, cpu_field: str, device_field: str, sum_field: str = None\n):\n    cpu_value = getattr(batch, cpu_field, None)\n    old_device_value = getattr(batch, device_field, None)\n    if (\n        cpu_value is None\n        or old_device_value is None\n        or not (isinstance(cpu_value, torch.Tensor) or isinstance(cpu_value, list))\n    ):\n        return\n\n    new_device_value = (\n        cpu_value\n        if isinstance(cpu_value, torch.Tensor)\n        else torch.tensor(cpu_value, dtype=old_device_value.dtype)\n    ).to(device=get_global_server_args().device, non_blocking=True)\n    setattr(batch, device_field, new_device_value)\n\n    if sum_field is not None:\n        sum_value = (\n            cpu_value.sum().item()\n            if isinstance(cpu_value, torch.Tensor)\n            else sum(cpu_value)\n        )\n        setattr(batch, sum_field, sum_value)\n\n\ndef _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:\n    if seq_index == 0:\n        return 0\n\n    offset = 0\n    max_seq_len = min(seq_index, spec_info.seq_lens_cpu.shape[0])\n    for i in range(max_seq_len):\n        offset += (\n            spec_info.seq_lens_cpu[i] + spec_info.draft_token_num\n        ) * spec_info.draft_token_num\n    return offset\n\n\ndef split_spec_info(\n    
spec_info: Optional[EagleVerifyInput],\n    start_seq_index: int,\n    end_seq_index: int,\n    start_token_index: int,\n    end_token_index: int,\n):\n    if spec_info is None:\n        return None\n    if spec_info.draft_token is not None:\n        draft_token = spec_info.draft_token[start_token_index:end_token_index]\n    else:\n        draft_token = None\n    if spec_info.custom_mask is not None and spec_info.draft_token is not None:\n        custom_mask_start = _compute_mask_offset(start_seq_index, spec_info)\n        if end_seq_index == spec_info.seq_lens_cpu.shape[0]:\n            custom_mask_end = spec_info.custom_mask.shape[0]\n        else:\n            custom_mask_end = _compute_mask_offset(end_seq_index, spec_info)\n\n        if custom_mask_end > custom_mask_start:\n            custom_mask = spec_info.custom_mask[custom_mask_start:custom_mask_end]\n        else:\n            custom_mask = spec_info.custom_mask\n    else:\n        custom_mask = spec_info.custom_mask\n    if spec_info.positions is not None:\n        positions = spec_info.positions[start_token_index:end_token_index]\n    else:\n        positions = None\n    if spec_info.retrive_index is not None:\n        retrive_index = spec_info.retrive_index[start_seq_index:end_seq_index]\n    else:\n        retrive_index = None\n    if spec_info.retrive_next_token is not None:\n        retrive_next_token = spec_info.retrive_next_token[start_seq_index:end_seq_index]\n    else:\n        retrive_next_token = None\n    if spec_info.retrive_next_sibling is not None:\n        retrive_next_sibling = spec_info.retrive_next_sibling[\n            start_seq_index:end_seq_index\n        ]\n    else:\n        retrive_next_sibling = None\n    if spec_info.retrive_cum_len is not None:\n        retrive_cum_len = spec_info.retrive_cum_len[start_seq_index:end_seq_index]\n    else:\n        retrive_cum_len = None\n\n    if spec_info.seq_lens_cpu is not None:\n        seq_lens_cpu = 
spec_info.seq_lens_cpu[start_seq_index:end_seq_index]\n    else:\n        seq_lens_cpu = None\n    if seq_lens_cpu is not None:\n        seq_lens_sum = seq_lens_cpu.sum()\n    else:\n        seq_lens_sum = None\n    output_spec_info = replace(\n        spec_info,\n        custom_mask=custom_mask,\n        draft_token=draft_token,\n        positions=positions,\n        retrive_index=retrive_index,\n        retrive_next_token=retrive_next_token,\n        retrive_next_sibling=retrive_next_sibling,\n        retrive_cum_len=retrive_cum_len,\n        seq_lens_cpu=seq_lens_cpu,\n        seq_lens_sum=seq_lens_sum,\n    )\n    return output_spec_info\n\n\ndef compute_split_token_index(\n    split_seq_index: int,\n    forward_mode: \"ForwardMode\",\n    extend_seq_lens: Optional[Sequence[int]],\n    token_num_per_seq: Optional[int],\n) -> int:\n    if forward_mode == ForwardMode.EXTEND:\n        assert extend_seq_lens is not None\n        if _is_two_chunk_split_enabled(extend_seq_lens):\n            return sum(extend_seq_lens) // 2\n        return sum(extend_seq_lens[:split_seq_index])\n    elif forward_mode.is_target_verify() or forward_mode.is_decode():\n        assert token_num_per_seq is not None\n        return split_seq_index * token_num_per_seq\n    elif forward_mode.is_idle():\n        assert split_seq_index == 0\n        return 0\n    else:\n        raise NotImplementedError\n\n\ndef compute_split_indices_for_cuda_graph_replay(\n    forward_mode: ForwardMode,\n    cuda_graph_num_tokens: int,\n    spec_info: Optional[SpecInput],\n):\n    forward_mode_for_tbo_split = (\n        forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE\n    )\n    token_num_per_seq = get_token_num_per_seq(\n        forward_mode=forward_mode, spec_info=spec_info\n    )\n    tbo_split_seq_index = compute_split_seq_index(\n        forward_mode=forward_mode_for_tbo_split,\n        num_tokens=cuda_graph_num_tokens,\n        extend_lens=None,\n        
token_num_per_seq=token_num_per_seq,\n    )\n    tbo_split_token_index = compute_split_token_index(\n        split_seq_index=tbo_split_seq_index,\n        forward_mode=forward_mode_for_tbo_split,\n        extend_seq_lens=None,\n        token_num_per_seq=token_num_per_seq,\n    )\n    return tbo_split_seq_index, tbo_split_token_index\n\n\n# -------------------------------- Preparation ---------------------------------------\n\n\nclass TboCudaGraphRunnerPlugin:\n    def __init__(self):\n        self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)\n\n    def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):\n        if not is_tbo_enabled():\n            return\n        token_num_per_seq = get_token_num_per_seq(\n            forward_mode=batch.forward_mode, spec_info=batch.spec_info\n        )\n\n        batch.tbo_split_seq_index = compute_split_seq_index(\n            forward_mode=batch.forward_mode,\n            num_tokens=num_tokens,\n            extend_lens=None,\n            token_num_per_seq=token_num_per_seq,\n        )\n        # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true\n        assert batch.tbo_split_seq_index is not None, f\"{num_tokens=}\"\n\n        self._tbo_children_num_token_non_padded[...] 
= (\n            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded(batch)\n        )\n\n        TboForwardBatchPreparer.prepare_raw(\n            batch,\n            tbo_children_num_token_non_padded=self._tbo_children_num_token_non_padded,\n        )\n\n    def replay_prepare(\n        self,\n        forward_mode: ForwardMode,\n        bs: int,\n        num_token_non_padded: int,\n        spec_info: Optional[SpecInput],\n    ):\n        token_num_per_seq = get_token_num_per_seq(\n            forward_mode=forward_mode, spec_info=spec_info\n        )\n        tbo_split_seq_index, tbo_split_token_index = (\n            compute_split_indices_for_cuda_graph_replay(\n                forward_mode=forward_mode,\n                cuda_graph_num_tokens=bs * token_num_per_seq,\n                spec_info=spec_info,\n            )\n        )\n\n        self._tbo_children_num_token_non_padded[...] = (\n            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded_raw(\n                tbo_split_token_index=tbo_split_token_index,\n                num_token_non_padded=num_token_non_padded,\n            )\n        )\n\n\nclass TboDPAttentionPreparer:\n    def prepare_all_gather(\n        self,\n        local_batch: ScheduleBatch,\n    ):\n\n        deepep_mode = get_deepep_mode()\n        enable_a2a_moe = not get_moe_a2a_backend().is_none()\n        enable_two_batch_overlap = is_tbo_enabled()\n\n        self.enable_two_batch_overlap = enable_two_batch_overlap\n\n        if local_batch is not None:\n            token_num_per_seq = get_token_num_per_seq(\n                forward_mode=local_batch.forward_mode, spec_info=local_batch.spec_info\n            )\n\n            if (\n                local_batch.forward_mode.is_target_verify()\n                or local_batch.forward_mode.is_decode()\n            ):\n                num_tokens = local_batch.batch_size() * token_num_per_seq\n            elif local_batch.forward_mode.is_prebuilt():\n               
 num_tokens = 0\n            else:\n                num_tokens = local_batch.extend_num_tokens\n            self.local_tbo_split_seq_index = compute_split_seq_index(\n                forward_mode=local_batch.forward_mode,\n                num_tokens=num_tokens,\n                extend_lens=local_batch.extend_lens,\n                token_num_per_seq=token_num_per_seq,\n            )\n            resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)\n            local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (\n                (\n                    local_batch.forward_mode.is_extend()\n                    and not local_batch.forward_mode.is_target_verify()\n                )\n                and enable_a2a_moe\n                and (resolved_deepep_mode.is_low_latency())\n            )\n        else:\n            self.local_tbo_split_seq_index = 0\n            local_can_run_tbo = True\n\n        local_forward_mode = self._compute_local_forward_mode(local_batch)\n\n        return local_can_run_tbo, local_forward_mode\n\n    def compute_output(self, partial_global_info):\n        # Perform only one Device-to-Host (D2H) memory copy\n        cpu_data = partial_global_info[:, :2].cpu()\n        local_can_run_tbo_aggregated = min(cpu_data[:, 0].tolist())\n        forward_modes = cpu_data[:, 1].tolist()\n\n        global_forward_mode, forward_mode_agree = self._compute_global_forward_mode(\n            forward_modes\n        )\n\n        can_run_tbo = (\n            self.enable_two_batch_overlap\n            and local_can_run_tbo_aggregated\n            and forward_mode_agree\n        )\n\n        tbo_split_seq_index = self.local_tbo_split_seq_index if can_run_tbo else None\n        global_forward_mode = global_forward_mode if can_run_tbo else None\n        return tbo_split_seq_index, global_forward_mode\n\n    @staticmethod\n    def _compute_local_forward_mode(local_batch):\n        return (\n            
local_batch.forward_mode if local_batch is not None else ForwardMode.IDLE\n        ).value\n\n    @staticmethod\n    def _compute_global_forward_mode(forward_modes):\n        forward_modes_excluding_idle_and_prebuilt = [\n            x\n            for x in forward_modes\n            if x != ForwardMode.IDLE.value and x != ForwardMode.PREBUILT.value\n        ]\n\n        if not forward_modes_excluding_idle_and_prebuilt:\n            return ForwardMode.IDLE, False\n\n        forward_mode_agree = TboDPAttentionPreparer._is_all_same(\n            forward_modes_excluding_idle_and_prebuilt\n        )\n\n        global_forward_mode = (\n            ForwardMode(forward_modes_excluding_idle_and_prebuilt[0])\n            if forward_mode_agree\n            else None\n        )\n        return global_forward_mode, forward_mode_agree\n\n    @staticmethod\n    def _is_all_same(x):\n        return all(value == x[0] for value in x)\n\n\nclass TboForwardBatchPreparer:\n    @classmethod\n    def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False):\n        if batch.tbo_split_seq_index is None or is_draft_worker:\n            return\n\n        tbo_children_num_token_non_padded = (\n            cls.compute_tbo_children_num_token_non_padded(batch)\n        )\n        cls.prepare_raw(\n            batch, tbo_children_num_token_non_padded=tbo_children_num_token_non_padded\n        )\n\n    @classmethod\n    def prepare_raw(\n        cls, batch: ForwardBatch, tbo_children_num_token_non_padded: torch.Tensor\n    ):\n        from sglang.srt.layers.attention.tbo_backend import TboAttnBackend\n\n        tbo_split_token_index = cls._compute_split_token_index(batch)\n\n        is_enable_two_chunk = (\n            batch.forward_mode == ForwardMode.EXTEND\n            and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu)\n        )\n\n        if _tbo_debug:\n            logger.info(\n                f\"TboForwardBatchPreparer.prepare \"\n                
f\"is_enable_two_chunk={is_enable_two_chunk} \"\n                f\"tbo_split_seq_index={batch.tbo_split_seq_index} \"\n                f\"tbo_split_token_index={tbo_split_token_index} \"\n                f\"extend_seq_lens={batch.extend_seq_lens_cpu} \"\n                f\"bs={batch.batch_size} \"\n                f\"forward_mode={batch.forward_mode}\"\n            )\n\n        assert isinstance(batch.attn_backend, TboAttnBackend)\n        attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children\n\n        [out_num_token_non_padded_a, out_num_token_non_padded_b] = (\n            tbo_children_num_token_non_padded\n        )\n\n        child_a = cls.filter_batch(\n            batch,\n            start_token_index=0,\n            end_token_index=tbo_split_token_index,\n            start_seq_index=0,\n            end_seq_index=(\n                batch.tbo_split_seq_index + 1\n                if is_enable_two_chunk\n                else batch.tbo_split_seq_index\n            ),\n            output_attn_backend=attn_backend_child_a,\n            out_num_token_non_padded=out_num_token_non_padded_a,\n        )\n        child_b = cls.filter_batch(\n            batch,\n            start_token_index=tbo_split_token_index,\n            end_token_index=batch.input_ids.shape[0],\n            start_seq_index=batch.tbo_split_seq_index,\n            end_seq_index=batch.batch_size,\n            output_attn_backend=attn_backend_child_b,\n            out_num_token_non_padded=out_num_token_non_padded_b,\n        )\n\n        if is_enable_two_chunk:\n            cls.derive_fields_related_to_seq_len_for_two_chunk(\n                batch,\n                child_a=child_a,\n                child_b=child_b,\n                tbo_split_seq_index=batch.tbo_split_seq_index,\n            )\n\n        assert batch.tbo_children is None\n        batch.tbo_children = [child_a, child_b]\n\n    @classmethod\n    def derive_fields_related_to_seq_len_for_two_chunk(\n        cls,\n       
 batch: ForwardBatch,\n        *,\n        child_a: ForwardBatch,\n        child_b: ForwardBatch,\n        tbo_split_seq_index: int,\n    ):\n        extend_seq_lens_cpu = batch.extend_seq_lens_cpu\n        overall_seq_lens_sum = sum(extend_seq_lens_cpu)\n        half_seq_lens_sum = overall_seq_lens_sum // 2\n        left_last_seq_token_num = half_seq_lens_sum - sum(\n            extend_seq_lens_cpu[:tbo_split_seq_index]\n        )\n        right_first_seq_token_num = (\n            extend_seq_lens_cpu[tbo_split_seq_index] - left_last_seq_token_num\n        )\n\n        # making deepcopy to be extra safe\n        child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)\n        child_a.extend_seq_lens_cpu[-1] = left_last_seq_token_num\n        child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)\n        child_b.extend_seq_lens_cpu[0] = right_first_seq_token_num\n        for child in [child_a, child_b]:\n            _update_device_and_sum_field_from_cpu_field(\n                batch=child,\n                cpu_field=\"extend_seq_lens_cpu\",\n                device_field=\"extend_seq_lens\",\n                sum_field=\"extend_num_tokens\",\n            )\n\n        assert (\n            child_a.extend_num_tokens == half_seq_lens_sum\n        ), f\"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}\"\n\n        child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)\n        child_a.seq_lens_cpu[-1] = (\n            child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]\n        )\n        _update_device_and_sum_field_from_cpu_field(\n            batch=child_a,\n            cpu_field=\"seq_lens_cpu\",\n            device_field=\"seq_lens\",\n            sum_field=\"seq_lens_sum\",\n        )\n\n        child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)\n        child_b.extend_prefix_lens_cpu[0] += left_last_seq_token_num\n        _update_device_and_sum_field_from_cpu_field(\n            
batch=child_b,\n            cpu_field=\"extend_prefix_lens_cpu\",\n            device_field=\"extend_prefix_lens\",\n            sum_field=None,\n        )\n        _, child_b.extend_start_loc = compute_position(\n            get_global_server_args().attention_backend,\n            child_b.extend_prefix_lens,\n            child_b.extend_seq_lens,\n            child_b.extend_num_tokens,\n        )\n\n    @classmethod\n    def filter_batch(\n        cls,\n        batch: ForwardBatch,\n        *,\n        start_token_index: int,\n        end_token_index: int,\n        start_seq_index: int,\n        end_seq_index: int,\n        output_attn_backend: AttentionBackend,\n        out_num_token_non_padded: torch.Tensor,\n    ):\n        assert (\n            end_token_index >= start_token_index\n        ), f\"{end_token_index=}, {start_token_index=}, batch={batch}\"\n        num_tokens = batch.input_ids.shape[0]\n        num_seqs = batch.batch_size\n\n        output_dict = dict()\n\n        for key in [\n            \"input_ids\",\n            \"positions\",\n            \"out_cache_loc\",\n        ]:\n            old_value = getattr(batch, key)\n            assert (\n                old_value.shape[0] == num_tokens\n            ), f\"{key=} {old_value=} {num_tokens=} {batch=}\"\n            output_dict[key] = old_value[start_token_index:end_token_index]\n\n        attention_tp_size = get_attention_tp_size()\n        output_dict[\"tbo_padded_len\"] = (\n            (end_token_index - start_token_index - 1) // attention_tp_size + 1\n        ) * attention_tp_size\n\n        for key in [\n            \"req_pool_indices\",\n            \"seq_lens\",\n            \"seq_lens_cpu\",\n            \"extend_seq_lens\",\n            \"extend_prefix_lens\",\n            \"extend_start_loc\",\n            \"extend_prefix_lens_cpu\",\n            \"extend_seq_lens_cpu\",\n            \"extend_logprob_start_lens_cpu\",\n            \"lora_ids\",\n            \"rids\",\n        ]:\n         
   old_value = getattr(batch, key)\n            if old_value is None:\n                continue\n            elif batch.forward_mode.is_target_verify() and (\n                key == \"extend_seq_lens\"\n                or key == \"extend_prefix_lens\"\n                or key == \"extend_start_loc\"\n                or key == \"extend_prefix_lens_cpu\"\n                or key == \"extend_seq_lens_cpu\"\n                or key == \"extend_logprob_start_lens_cpu\"\n            ):\n                output_dict[key] = None\n                continue\n            assert (\n                len(old_value) == num_seqs\n            ), f\"{key=} {old_value=} {num_seqs=} {batch=}\"\n            output_dict[key] = old_value[start_seq_index:end_seq_index]\n\n        spec_info = getattr(batch, \"spec_info\")\n        output_spec_info = split_spec_info(\n            spec_info=spec_info,\n            start_token_index=start_token_index,\n            end_token_index=end_token_index,\n            start_seq_index=start_seq_index,\n            end_seq_index=end_seq_index,\n        )\n        output_dict[\"spec_info\"] = output_spec_info\n        for key in [\n            \"forward_mode\",\n            \"is_extend_in_batch\",\n            \"all_extend_in_batch\",\n            \"return_logprob\",\n            \"req_to_token_pool\",\n            \"token_to_kv_pool\",\n            \"can_run_dp_cuda_graph\",\n            \"dp_padding_mode\",\n            \"global_forward_mode\",\n            \"is_prefill_only\",\n            \"spec_algorithm\",\n            \"capture_hidden_mode\",\n            \"padded_static_len\",\n            \"mrope_positions\",  # only used by qwen2-vl, thus not care\n            \"split_index\",  # for split prefill\n            \"orig_seq_lens\",  # only used by qwen-1m, thus not care\n        ]:\n            output_dict[key] = getattr(batch, key)\n        if not batch.forward_mode.is_target_verify():\n            assert (\n                
_compute_extend_num_tokens(batch.input_ids, batch.forward_mode)\n                == batch.extend_num_tokens\n            ), f\"{batch=}\"\n        extend_num_tokens = _compute_extend_num_tokens(\n            output_dict[\"input_ids\"], output_dict[\"forward_mode\"]\n        )\n\n        # TODO improve, e.g. unify w/ `init_raw`\n        if (\n            get_global_server_args().moe_dense_tp_size == 1\n            and batch.global_dp_buffer_len is not None\n        ):\n            sum_len = end_token_index - start_token_index\n            global_dp_buffer_len = sum_len\n        else:\n            global_dp_buffer_len = None\n\n        output_dict.update(\n            dict(\n                batch_size=end_seq_index - start_seq_index,\n                seq_lens_sum=(\n                    output_dict[\"seq_lens_cpu\"].sum()\n                    if \"seq_lens_cpu\" in output_dict\n                    else None\n                ),\n                extend_num_tokens=extend_num_tokens,\n                attn_backend=output_attn_backend,\n                num_token_non_padded=out_num_token_non_padded,\n                # TODO: handle it when we need TBO + DeepSeek V3.2\n                num_token_non_padded_cpu=None,\n                tbo_split_seq_index=None,\n                tbo_parent_token_range=(start_token_index, end_token_index),\n                tbo_children=None,\n                original_global_num_tokens_cpu=None,\n                global_num_tokens_gpu=None,\n                global_num_tokens_cpu=None,\n                global_dp_buffer_len=global_dp_buffer_len,\n                global_num_tokens_for_logprob_gpu=None,\n                global_num_tokens_for_logprob_cpu=None,\n                sampling_info=None,\n                # Only used for logits/logprobs post-processing, so these are not needed here\n                temp_scaled_logprobs=False,\n                temperature=None,\n                top_p_normalized_logprobs=False,\n                top_p=None,\n                
mm_inputs=None,\n                top_logprobs_nums=None,\n                token_ids_logprobs=None,\n                next_token_logits_buffer=None,\n                return_hidden_states_before_norm=False,\n            )\n        )\n\n        errors = []\n        for field in dataclasses.fields(ForwardBatch):\n            if getattr(batch, field.name) is not None and field.name not in output_dict:\n                errors.append(\n                    f\"Field {field.name} has a value, but is not yet supported (value={getattr(batch, field.name)} batch={batch})\"\n                )\n        if len(errors) > 0:\n            raise Exception(f\"{len(errors)} errors occurred:\\n\" + \"\\n\\n\".join(errors))\n\n        return ForwardBatch(**output_dict)\n\n    @classmethod\n    def compute_tbo_children_num_token_non_padded(cls, batch: ForwardBatch):\n        return cls.compute_tbo_children_num_token_non_padded_raw(\n            tbo_split_token_index=cls._compute_split_token_index(batch),\n            num_token_non_padded=len(batch.input_ids),\n        )\n\n    @classmethod\n    def compute_tbo_children_num_token_non_padded_raw(\n        cls, tbo_split_token_index: int, num_token_non_padded: int\n    ):\n        # TODO: we may pad both sub-batches to make them slightly more balanced\n        value_a = min(tbo_split_token_index, num_token_non_padded)\n        value_b = max(0, num_token_non_padded - tbo_split_token_index)\n        return torch.tensor([value_a, value_b], dtype=torch.int32).to(\n            device=get_global_server_args().device, non_blocking=True\n        )\n\n    @classmethod\n    def _compute_split_token_index(cls, batch: ForwardBatch):\n        token_num_per_seq = get_token_num_per_seq(\n            forward_mode=batch.forward_mode, spec_info=batch.spec_info\n        )\n        return compute_split_token_index(\n            split_seq_index=batch.tbo_split_seq_index,\n            forward_mode=batch.forward_mode,\n            
extend_seq_lens=batch.extend_seq_lens_cpu,\n            token_num_per_seq=token_num_per_seq,\n        )\n\n\ndef _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):\n    if (\n        forward_mode.is_decode()\n        or forward_mode.is_idle()\n        or forward_mode.is_target_verify()\n    ):\n        return None\n    elif forward_mode.is_extend():\n        return input_ids.shape[0]\n    raise NotImplementedError\n\n\n# -------------------------------- Execution ---------------------------------------\n\n\ndef model_forward_maybe_tbo(\n    layers,\n    enable_tbo: bool,\n    positions: torch.Tensor,\n    forward_batch: ForwardBatch,\n    hidden_states: torch.Tensor,\n    input_data_scatter_mode: ScatterMode,\n    residual: Optional[torch.Tensor],\n    zero_allocator: Optional[BumpAllocator] = None,\n):\n    inputs = dict(\n        positions=positions,\n        hidden_states=hidden_states,\n        forward_batch=forward_batch,\n        residual=residual,\n        zero_allocator=zero_allocator,\n    )\n    layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode\n    operations_strategy = OperationsStrategy.init_new_tbo(\n        layers, forward_batch.global_forward_mode\n    )\n    if enable_tbo:\n        return _model_forward_tbo(\n            inputs=inputs,\n            operations_strategy=operations_strategy,\n            input_data_scatter_mode=input_data_scatter_mode,\n            layer_input_scatter_mode=layer_input_scatter_mode,\n        )\n    else:\n        return _model_forward_non_tbo(inputs, operations_strategy)\n\n\ndef _model_forward_tbo(\n    inputs,\n    operations_strategy: OperationsStrategy,\n    input_data_scatter_mode: ScatterMode,\n    layer_input_scatter_mode: ScatterMode,\n):\n    inputs_arr = _model_forward_tbo_split_inputs(\n        **inputs,\n        input_data_scatter_mode=input_data_scatter_mode,\n        layer_input_scatter_mode=layer_input_scatter_mode,\n    )\n    original_hidden_states_len = 
inputs[\"hidden_states\"].shape[0]\n    del inputs\n\n    context = (\n        empty_context()\n        if _is_hip\n        else deep_gemm_wrapper.configure_deep_gemm_num_sms(\n            operations_strategy.deep_gemm_num_sms\n        )\n    )\n\n    with context:\n        outputs_arr = execute_overlapped_operations(\n            inputs_arr=inputs_arr,\n            operations_arr=[operations_strategy.operations] * 2,\n            delta_stages=[0, operations_strategy.tbo_delta_stages],\n        )\n\n    return _model_forward_tbo_merge_outputs(*outputs_arr, original_hidden_states_len)\n\n\ndef _model_forward_non_tbo(inputs, operations_strategy: OperationsStrategy):\n    outputs = execute_operations(inputs, operations_strategy.operations)\n    return outputs[\"hidden_states\"], outputs[\"residual\"]\n\n\ndef _model_forward_tbo_split_inputs(\n    hidden_states: torch.Tensor,\n    residual: torch.Tensor,\n    positions: torch.Tensor,\n    forward_batch: ForwardBatch,\n    zero_allocator: Optional[BumpAllocator],\n    input_data_scatter_mode: ScatterMode,\n    layer_input_scatter_mode: ScatterMode,\n) -> List[Dict]:\n    tbo_splitter_scatter_mode = ScatterMode.TP_ATTN_FULL\n    context = CommunicateContext.init_new()\n\n    hidden_states, residual = CommunicateSummableTensorPairFn.execute(\n        hidden_states_input_mode=input_data_scatter_mode,\n        residual_input_mode=input_data_scatter_mode,\n        output_mode=tbo_splitter_scatter_mode,\n        hidden_states=hidden_states,\n        residual=residual,\n        forward_batch=forward_batch,\n        context=context,\n    )\n\n    inputs_arr = _model_forward_tbo_split_inputs_raw(\n        hidden_states=hidden_states,\n        residual=residual,\n        positions=positions,\n        forward_batch=forward_batch,\n        zero_allocator=zero_allocator,\n    )\n\n    def _post_transform(hidden_states, residual, forward_batch, **kwargs):\n        hidden_states, residual = CommunicateSummableTensorPairFn.execute(\n   
         hidden_states_input_mode=tbo_splitter_scatter_mode,\n            residual_input_mode=tbo_splitter_scatter_mode,\n            output_mode=layer_input_scatter_mode,\n            hidden_states=hidden_states,\n            residual=residual,\n            forward_batch=forward_batch,\n            context=context,\n        )\n        return dict(\n            hidden_states=hidden_states,\n            residual=residual,\n            forward_batch=forward_batch,\n            **kwargs,\n        )\n\n    return [_post_transform(**inputs) for inputs in inputs_arr]\n\n\ndef _model_forward_tbo_split_inputs_raw(\n    hidden_states: torch.Tensor,\n    residual: torch.Tensor,\n    positions: torch.Tensor,\n    forward_batch: ForwardBatch,\n    zero_allocator: Optional[BumpAllocator],\n) -> List[Dict]:\n    return [\n        dict(\n            **_model_forward_filter_inputs(\n                hidden_states=hidden_states,\n                residual=residual,\n                positions=positions,\n                output_forward_batch=output_forward_batch,\n                tbo_subbatch_index=tbo_subbatch_index,\n            ),\n            **(\n                dict(zero_allocator=zero_allocator)\n                if zero_allocator is not None\n                else {}\n            ),\n        )\n        for tbo_subbatch_index, output_forward_batch in enumerate(\n            forward_batch.tbo_children\n        )\n    ]\n\n\ndef _model_forward_filter_inputs(\n    hidden_states: torch.Tensor,\n    residual: torch.Tensor,\n    positions: torch.Tensor,\n    output_forward_batch: ForwardBatch,\n    tbo_subbatch_index: int,\n) -> Dict:\n    token_slice = slice(*output_forward_batch.tbo_parent_token_range)\n    hidden_states = hidden_states[token_slice]\n    residual = None if residual is None else residual[token_slice]\n    positions = positions[token_slice]\n\n    assert output_forward_batch.tbo_padded_len is not None\n    padded_len = output_forward_batch.tbo_padded_len\n\n    def 
_pad(x):\n        nonlocal padded_len\n        if x is None:\n            return None\n        if x.shape[0] == padded_len:\n            return x\n        res = torch.zeros((padded_len, *x.shape[1:]), dtype=x.dtype, device=x.device)\n        res[: x.shape[0]] = x\n        return res\n\n    return dict(\n        hidden_states=_pad(hidden_states),\n        residual=_pad(residual),\n        positions=_pad(positions),\n        forward_batch=output_forward_batch,\n        tbo_subbatch_index=tbo_subbatch_index,\n    )\n\n\ndef _model_forward_tbo_merge_outputs(output_a, output_b, original_len):\n    def _handle_key(name):\n        value_a = output_a[name]\n        value_b = output_b[name]\n        assert (value_a is None) == (value_b is None)\n        if value_a is None:\n            return None\n        s0, t0 = output_a[\"forward_batch\"].tbo_parent_token_range\n        s1, t1 = output_b[\"forward_batch\"].tbo_parent_token_range\n        res = torch.zeros(\n            (original_len, *value_a.shape[1:]),\n            dtype=value_a.dtype,\n            device=value_a.device,\n        )\n        res[slice(s0, t0)] = value_a[: t0 - s0]\n        res[slice(s1, t1)] = value_b[: t1 - s1]\n        return res\n\n    return _handle_key(\"hidden_states\"), _handle_key(\"residual\")\n\n\n# -------------------------------- Utilities and wrappers ---------------------------------------\n\n\nclass MaybeTboDeepEPDispatcher(BaseDispatcher):\n    def __init__(self, **kwargs):\n        super().__init__()\n        num_inner_dispatchers = 2 if is_tbo_enabled() else 1\n        if get_moe_a2a_backend().is_deepep():\n            self._inners = [\n                DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)\n            ]\n        elif get_moe_a2a_backend().is_mooncake():\n            self._inners = [\n                MooncakeEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)\n            ]\n        elif get_moe_a2a_backend().is_mori():\n            self._inners = 
[\n                MoriEPDispatcher(instance_id=i, **kwargs)\n                for i in range(num_inner_dispatchers)\n            ]\n        elif get_moe_a2a_backend().is_nixl():\n            self._inners = [\n                NixlEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)\n            ]\n\n    def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):\n        return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)\n\n    def dispatch(self, **kwargs) -> DispatchOutput:\n        return self._execute(\"dispatch\", **kwargs)\n\n    def dispatch_a(self, **kwargs):\n        return self._execute(\"dispatch_a\", **kwargs)\n\n    def dispatch_b(self, **kwargs):\n        return self._execute(\"dispatch_b\", **kwargs)\n\n    def combine(self, **kwargs) -> torch.Tensor:\n        return self._execute(\"combine\", **kwargs)\n\n    def combine_a(self, **kwargs):\n        return self._execute(\"combine_a\", **kwargs)\n\n    def combine_b(self, **kwargs):\n        return self._execute(\"combine_b\", **kwargs)\n\n    def register_deepep_dispatch_hook(self, hook):\n        handle_list = []\n        for inner in self._inners:\n            handle_list.append(inner.register_deepep_dispatch_hook(hook))\n        return handle_list\n\n    def set_quant_config(self, quant_config: dict):\n        super().set_quant_config(quant_config)\n        for inner in self._inners:\n            inner.set_quant_config(quant_config)\n\n    def set_overlap_args(\n        self, combine_overlap_args: CombineOverlapArgs, meta_overlap_args: dict\n    ):\n        super().set_overlap_args(combine_overlap_args, meta_overlap_args)\n        for inner in self._inners:\n            inner.set_overlap_args(combine_overlap_args, meta_overlap_args)\n\n    def clear_overlap_args(self):\n        super().clear_overlap_args()\n        for inner in self._inners:\n            inner.clear_overlap_args()\n"
  },
  {
    "path": "python/sglang/srt/checkpoint_engine/__init__.py",
    "content": "\"\"\"\nCheckpoint engine module for SGLang.\n\nThis module provides functionality for updating model weights via checkpoint engine.\n\"\"\"\n\nfrom sglang.srt.checkpoint_engine.update import main\n\n__all__ = [\"main\"]\n"
  },
  {
    "path": "python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"\nCheckpoint-engine integration for SGLang.\nThis module provides weight update functionality via IPC for checkpoint-engine compatibility.\n\"\"\"\n\nimport logging\nfrom typing import Callable, Dict, Optional\n\nimport torch\nimport zmq\n\ntry:\n    from checkpoint_engine.worker import update_weights_from_ipc\nexcept ImportError:\n    raise ImportError(\n        \"checkpoint-engine is not installed. 
\"\n        \"Please install it with: pip install sglang[checkpoint-engine]\"\n    )\n\nlogger = logging.getLogger(__name__)\n\n\nclass SGLangCheckpointEngineWorkerExtension:\n    \"\"\"\n    Worker extension for SGLang to support checkpoint-engine IPC weight updates.\n    This class provides the interface needed for checkpoint-engine integration.\n    \"\"\"\n\n    def __init__(self):\n        self._zmq_ctx: Optional[zmq.Context] = None\n\n    def get_device_uuid(self) -> str:\n        \"\"\"Get the UUID of current device.\"\"\"\n        # We need to implement this to get the device UUID\n        # This will be overridden when integrated into SGLang's worker\n        raise NotImplementedError(\n            \"This method should be overridden by SGLang integration\"\n        )\n\n    def get_device_id(self) -> int:\n        \"\"\"Get the device ID.\"\"\"\n        raise NotImplementedError(\n            \"This method should be overridden by SGLang integration\"\n        )\n\n    def get_model_loader(self) -> Callable:\n        \"\"\"Get the model weight loader function.\"\"\"\n        raise NotImplementedError(\n            \"This method should be overridden by SGLang integration\"\n        )\n\n    def get_post_hook(self) -> Optional[Callable]:\n        \"\"\"Get the post-processing hook after weight loading.\"\"\"\n        return None\n\n    def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):\n        \"\"\"\n        Update weights from IPC communication.\n        Args:\n            zmq_handles: Dict mapping device UUID to ZMQ socket path\n        \"\"\"\n        if self._zmq_ctx is None:\n            self._zmq_ctx = zmq.Context()\n        device_uuid = self.get_device_uuid()\n        device_id = self.get_device_id()\n        if device_uuid not in zmq_handles:\n            raise ValueError(\n                f\"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}\"\n            )\n        update_weights_from_ipc(\n            
self._zmq_ctx,\n            zmq_handles[device_uuid],\n            device_id=device_id,\n            run=self.get_model_loader(),\n            post_hook=self.get_post_hook(),\n        )\n\n\nclass SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):\n    \"\"\"\n    Implementation of SGLangCheckpointEngineWorkerExtension that integrates with SGLang's model runner.\n    This class provides the concrete implementation for checkpoint-engine IPC weight updates.\n    \"\"\"\n\n    def __init__(self, model_runner):\n        super().__init__()\n        self.model_runner = model_runner\n\n    def get_device_uuid(self) -> str:\n        \"\"\"Get the UUID of current device.\"\"\"\n        # Get device UUID for current device\n        device_id = torch.cuda.current_device()\n        try:\n            return f\"GPU-{torch.cuda.get_device_properties(device_id).uuid!s}\"\n        except AssertionError as e:\n            raise ValueError(f\"Failed to get GPU UUID for device {device_id}\") from e\n\n    def get_device_id(self) -> int:\n        \"\"\"Get the device ID.\"\"\"\n        return torch.cuda.current_device()\n\n    def get_model_loader(self) -> Callable:\n        \"\"\"Get the model weight loader function.\"\"\"\n        return self.model_runner.model.load_weights\n\n    def get_post_hook(self) -> Optional[Callable]:\n        \"\"\"Get the post-processing hook after weight loading.\"\"\"\n\n        def post_hook():\n            # Perform post-processing after weight loading similar to DefaultModelLoader\n            try:\n                from sglang.srt.model_loader.loader import device_loading_context\n\n                # Process quantization methods after loading weights\n                for _, module in self.model_runner.model.named_modules():\n                    quant_method = getattr(module, \"quant_method\", None)\n                    if quant_method is not None:\n                        # Move parameters to device if needed for 
quantization processing\n                        target_device = torch.device(\n                            \"cuda\", torch.cuda.current_device()\n                        )\n                        with device_loading_context(module, target_device):\n                            quant_method.process_weights_after_loading(module)\n                # Call model-specific post-loading hook if available\n                if hasattr(self.model_runner.model, \"post_load_weights\"):\n                    self.model_runner.model.post_load_weights()\n            except Exception as e:\n                logger.warning(f\"Post-hook processing failed: {e}\")\n\n        return post_hook\n"
  },
  {
    "path": "python/sglang/srt/checkpoint_engine/update.py",
    "content": "\"\"\"\nUsage:\n1) Launch the server with wait-for-initial-weights option in one terminal:\n   python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7\n\n2) Torchrun this script in another terminal:\n    torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/  --inference-parallel-size 2\n\nOr use the integrated entry point:\n    python -m sglang.srt.checkpoint_engine.update --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/  --inference-parallel-size 2\n\"\"\"\n\nimport argparse\nimport json\nimport os\nimport pickle\nimport subprocess\nimport sys\nimport time\nfrom collections import defaultdict\nfrom collections.abc import Callable\nfrom contextlib import contextmanager\nfrom typing import Literal\n\nimport httpx\nimport torch\nimport torch.distributed as dist\nfrom safetensors import safe_open\n\ntry:\n    from checkpoint_engine.ps import ParameterServer\n    from loguru import logger\nexcept ImportError:\n    # Fallback for when checkpoint_engine is not available\n    ParameterServer = None\n    import logging\n\n    logger = logging.getLogger(__name__)\n\n\n@contextmanager\ndef timer(msg: str):\n    start = time.perf_counter()\n    yield\n    end = time.perf_counter()\n    logger.info(f\"{msg} duration: {end - start:.2f} seconds\")\n\n\ndef check_sglang_ready(\n    endpoint: str, inference_parallel_size: int, uds: str | None = None\n):\n    rank = int(os.getenv(\"RANK\", 0))\n    if rank != rank // inference_parallel_size * inference_parallel_size:\n        return\n    retry_num = 0\n    transport = None\n    if uds is not None:\n        transport = httpx.HTTPTransport(uds=uds)\n    with httpx.Client(transport=transport) as client:\n        while True:\n            try:\n                response = client.get(f\"{endpoint}/ping\", 
timeout=10)\n                response.raise_for_status()\n                break\n            except (httpx.ConnectError, httpx.HTTPStatusError) as e:\n                if retry_num % 10 == 0:\n                    logger.warning(\n                        f\"failed to check whether sglang is ready after {retry_num} retries, error: {e}\"\n                    )\n                retry_num += 1\n                time.sleep(0.1)\n\n\ndef split_checkpoint_files(\n    checkpoint_path: str, rank: int, world_size: int\n) -> list[str]:\n    checkpoint_files = [\n        os.path.join(checkpoint_path, f)\n        for f in filter(\n            lambda x: x.endswith(\".safetensors\"), os.listdir(checkpoint_path)\n        )\n    ]\n    files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size\n    return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank]\n\n\ndef split_tensors(\n    checkpoint_path: str, rank: int, world_size: int\n) -> dict[str, torch.Tensor]:\n    index_fn = os.path.join(checkpoint_path, \"model.safetensors.index.json\")\n    with open(index_fn) as f:\n        weight_map: dict[str, str] = json.load(f)[\"weight_map\"]\n    weights_per_rank = (len(weight_map) + world_size - 1) // world_size\n    fn_tensors: dict[str, list[str]] = defaultdict(list)\n    weight_keys = list(weight_map.items())\n    for name, file in weight_keys[\n        rank * weights_per_rank : (rank + 1) * weights_per_rank\n    ]:\n        fn_tensors[file].append(name)\n    named_tensors = {}\n    for file, names in fn_tensors.items():\n        with safe_open(os.path.join(checkpoint_path, file), framework=\"pt\") as f:\n            for name in names:\n                named_tensors[name] = f.get_tensor(name)\n    return named_tensors\n\n\ndef req_inference(\n    endpoint: str,\n    inference_parallel_size: int,\n    timeout: float = 300.0,\n    uds: str | None = None,\n    weight_version: str | None = None,\n) -> Callable[[list[tuple[str, str]]], None]:\n    rank = 
int(os.getenv(\"RANK\", 0))\n    src = rank // inference_parallel_size * inference_parallel_size\n\n    def req_func(socket_paths: list[tuple[str, str]]):\n        if rank == src:\n            with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:\n                resp = client.post(\n                    f\"{endpoint}/update_weights_from_ipc\",\n                    json={\n                        \"zmq_handles\": dict(\n                            socket_paths[src : src + inference_parallel_size]\n                        ),\n                        \"flush_cache\": True,\n                        \"weight_version\": weight_version,\n                    },\n                    timeout=timeout,\n                )\n                resp.raise_for_status()\n\n    return req_func\n\n\ndef update_weights(\n    ps,\n    checkpoint_name: str,\n    checkpoint_files: list[str],\n    named_tensors: dict[str, torch.Tensor],\n    req_func: Callable[[list[tuple[str, str]]], None],\n    inference_parallel_size: int,\n    endpoint: str,\n    save_metas_file: str | None = None,\n    update_method: Literal[\"broadcast\", \"p2p\", \"all\"] = \"broadcast\",\n    uds: str | None = None,\n):\n    ps.register_checkpoint(\n        checkpoint_name, files=checkpoint_files, named_tensors=named_tensors\n    )\n    ps.init_process_group()\n    check_sglang_ready(endpoint, inference_parallel_size, uds)\n    dist.barrier()\n    with timer(\"Gather metas\"):\n        ps.gather_metas(checkpoint_name)\n    if save_metas_file and int(os.getenv(\"RANK\")) == 0:\n        with open(save_metas_file, \"wb\") as f:\n            pickle.dump(ps.get_metas(), f)\n\n    if update_method == \"broadcast\" or update_method == \"all\":\n        with timer(\"Update weights without setting ranks\"):\n            ps.update(checkpoint_name, req_func)\n\n    if update_method == \"p2p\" or update_method == \"all\":\n        if update_method:\n            # sleep 2s to wait destroy process group\n            
time.sleep(2)\n        with timer(\"Update weights with setting ranks\"):\n            ps.update(\n                checkpoint_name, req_func, ranks=list(range(inference_parallel_size))\n            )\n\n\ndef join(\n    ps: ParameterServer,\n    checkpoint_name: str,\n    load_metas_file: str,\n    req_func: Callable[[list[tuple[str, str]]], None],\n    inference_parallel_size: int,\n    endpoint: str,\n    uds: str | None = None,\n):\n    assert load_metas_file, \"load_metas_file is required\"\n    with open(load_metas_file, \"rb\") as f:\n        metas = pickle.load(f)\n    ps.init_process_group()\n    check_sglang_ready(endpoint, inference_parallel_size, uds)\n    dist.barrier()\n    with timer(\"Gather metas before join\"):\n        ps.gather_metas(checkpoint_name)\n    ps.load_metas(metas)\n    with timer(\n        f\"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p\"\n    ):\n        ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size)))\n\n\ndef run_with_torchrun():\n    \"\"\"Run the update script with torchrun automatically.\"\"\"\n    # Parse inference_parallel_size from command line arguments to determine nproc-per-node\n    inference_parallel_size = 8  # default\n    args = sys.argv[1:]  # Skip the script name\n\n    # Look for --inference-parallel-size in arguments\n    for i, arg in enumerate(args):\n        if arg == \"--inference-parallel-size\" and i + 1 < len(args):\n            try:\n                inference_parallel_size = int(args[i + 1])\n            except ValueError:\n                pass\n            break\n        elif arg.startswith(\"--inference-parallel-size=\"):\n            try:\n                inference_parallel_size = int(arg.split(\"=\", 1)[1])\n            except ValueError:\n                pass\n            break\n\n    # Build torchrun command\n    cmd = [\"torchrun\", f\"--nproc-per-node={inference_parallel_size}\", __file__] + args\n\n    print(f\"Running: {' 
'.join(cmd)}\", file=sys.stderr)\n\n    # Execute torchrun with the original script\n    try:\n        result = subprocess.run(cmd, check=False)\n        sys.exit(result.returncode)\n    except FileNotFoundError:\n        print(\n            \"Error: torchrun command not found. Please ensure PyTorch is installed.\",\n            file=sys.stderr,\n        )\n        sys.exit(1)\n    except KeyboardInterrupt:\n        print(\"\\nInterrupted by user\", file=sys.stderr)\n        sys.exit(130)\n\n\ndef main():\n    # Check if we're running under torchrun or need to invoke it\n    if os.getenv(\"RANK\") is None:\n        # Not running under torchrun, so invoke it\n        run_with_torchrun()\n        return\n\n    # Running under torchrun, proceed with normal execution\n    parser = argparse.ArgumentParser(description=\"Update weights example\")\n    parser.add_argument(\"--checkpoint-path\", type=str, default=None)\n    parser.add_argument(\"--save-metas-file\", type=str, default=None)\n    parser.add_argument(\"--load-metas-file\", type=str, default=None)\n    parser.add_argument(\"--sleep-time\", type=int, default=0)\n    parser.add_argument(\"--endpoint\", type=str, default=\"http://localhost:19730\")\n    parser.add_argument(\"--inference-parallel-size\", type=int, default=8)\n    parser.add_argument(\"--checkpoint-name\", type=str, default=\"my-checkpoint-iter-0\")\n    parser.add_argument(\"--update-method\", type=str, default=\"broadcast\")\n    parser.add_argument(\"--uds\", type=str, default=None)\n    parser.add_argument(\"--weight-version\", type=str, default=None)\n    args = parser.parse_args()\n\n    # Get rank and world_size from environment (set by torchrun)\n    rank = int(os.getenv(\"RANK\", 0))\n    world_size = int(os.getenv(\"WORLD_SIZE\", 1))\n\n    req_func = req_inference(\n        args.endpoint,\n        args.inference_parallel_size,\n        uds=args.uds,\n        weight_version=args.weight_version,\n    )\n\n    if ParameterServer is None:\n   
     print(\"Error: checkpoint_engine package not available\", file=sys.stderr)\n        sys.exit(1)\n\n    ps = ParameterServer(auto_pg=True)\n    ps._p2p_store = None\n    if args.load_metas_file:\n        join(\n            ps,\n            args.checkpoint_name,\n            args.load_metas_file,\n            req_func,\n            args.inference_parallel_size,\n            args.endpoint,\n            args.uds,\n        )\n    else:\n        if args.checkpoint_path and os.path.exists(\n            os.path.join(args.checkpoint_path, \"model.safetensors.index.json\")\n        ):\n            named_tensors = split_tensors(args.checkpoint_path, rank, world_size)\n            checkpoint_files = []\n        else:\n            checkpoint_files = (\n                split_checkpoint_files(args.checkpoint_path, rank, world_size)\n                if args.checkpoint_path\n                else []\n            )\n            named_tensors = {}\n        update_weights(\n            ps,\n            args.checkpoint_name,\n            checkpoint_files,\n            named_tensors,\n            req_func,\n            args.inference_parallel_size,\n            args.endpoint,\n            args.save_metas_file,\n            args.update_method,\n            args.uds,\n        )\n    time.sleep(args.sleep_time)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/sglang/srt/compilation/backend.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py\n\n\nimport ast\nimport dataclasses\nimport logging\nimport os\nimport pprint\nimport time\nfrom collections.abc import Sequence\nfrom contextlib import contextmanager\nfrom typing import Any, Callable, Optional\n\nimport torch\nimport torch.fx as fx\nfrom torch._dispatch.python import enable_python_dispatcher\n\nfrom sglang.srt.compilation.compilation_config import CompilationConfig\nfrom sglang.srt.compilation.compilation_counter import compilation_counter\nfrom sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor\nfrom sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend\nfrom sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend\nfrom sglang.srt.compilation.pass_manager import PostGradPassManager\nfrom sglang.srt.utils.common import is_npu\n\nlogger = logging.getLogger(__name__)\n\n\ndef make_compiler(config: CompilationConfig):\n    if config.compiler == \"eager\":\n        return EagerAdapter()\n    elif config.compiler == \"inductor\":\n        return InductorAdaptor()\n    else:\n        raise ValueError(f\"Unknown compiler: {config.compiler}\")\n\n\ndef make_backend(\n    graph: fx.GraphModule,\n    compile_config: CompilationConfig,\n    inductor_config: dict[str, Any],\n    graph_pool: Any,\n    piecewise_compile_index: int,\n    total_piecewise_compiles: int,\n    sym_shape_indices: list[int],\n    compiled_graph_for_general_shape: Callable,\n    sglang_backend,\n):\n\n    backend_cls = CUDAPiecewiseBackend if not is_npu() else NPUPiecewiseBackend\n    return backend_cls(\n        graph,\n        compile_config,\n        inductor_config,\n        graph_pool,\n        piecewise_compile_index,\n        total_piecewise_compiles,\n        sym_shape_indices,\n        compiled_graph_for_general_shape,\n        sglang_backend,\n    )\n\n\nclass CompilerManager:\n    def __init__(\n        
self,\n        config: CompilationConfig,\n    ):\n        self.cache = dict()\n        self.is_cache_updated = False\n        self.compiler = make_compiler(config)\n\n    def compute_hash(self):\n        return self.compiler.compute_hash()\n\n    def initialize_cache(\n        self, cache_dir: str, disable_cache: bool = False, prefix: str = \"\"\n    ):\n        self.disable_cache = disable_cache\n        self.cache_dir = cache_dir\n        self.cache_file_path = os.path.join(cache_dir, \"sglang_compile_cache.py\")\n\n        if not disable_cache and os.path.exists(self.cache_file_path):\n            with open(self.cache_file_path) as f:\n                self.cache = ast.literal_eval(f.read())\n\n        self.compiler.initialize_cache(\n            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix\n        )\n\n    def save_to_file(self):\n        if self.disable_cache or not self.is_cache_updated:\n            return\n        printer = pprint.PrettyPrinter(indent=4)\n        data = printer.pformat(self.cache)\n        with open(self.cache_file_path, \"w\") as f:\n            f.write(data)\n\n    def load(\n        self,\n        graph: fx.GraphModule,\n        example_inputs: list[Any],\n        graph_index: int,\n        runtime_shape: Optional[int] = None,\n    ) -> Optional[Callable]:\n        handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]\n        compiled_graph = self.compiler.load(\n            handle, graph, example_inputs, graph_index, runtime_shape\n        )\n        if runtime_shape is None:\n            logger.debug(\n                \"Directly load the %s-th graph for dynamic shape from %s via \"\n                \"handle %s\",\n                graph_index,\n                self.compiler.name,\n                handle,\n            )\n        else:\n            logger.debug(\n                \"Directly load the %s-th graph for shape %s from %s via \" \"handle %s\",\n                graph_index,\n                
str(runtime_shape),\n                self.compiler.name,\n                handle,\n            )\n        return compiled_graph\n\n    def compile(\n        self,\n        graph: fx.GraphModule,\n        example_inputs,\n        inductor_config: dict[str, Any],\n        graph_index: int = 0,\n        num_graphs: int = 1,\n        runtime_shape: Optional[int] = None,\n    ) -> Any:\n        if graph_index == 0:\n            # before compiling the first graph, record the start time\n            global compilation_start_time\n            compilation_start_time = time.time()\n\n        compilation_counter.num_backend_compilations += 1\n\n        compiled_graph = None\n\n        # TODO(Yuwei): support cache loading\n\n        # no compiler cached the graph, or the cache is disabled,\n        # we need to compile it\n        if isinstance(self.compiler, InductorAdaptor):\n            maybe_key = None\n        else:\n            maybe_key = f\"artifact_shape_{runtime_shape}_subgraph_{graph_index}\"\n        compiled_graph, handle = self.compiler.compile(\n            graph, example_inputs, inductor_config, runtime_shape, maybe_key\n        )\n\n        assert compiled_graph is not None, \"Failed to compile the graph\"\n\n        # store the artifact in the cache\n        if handle is not None:\n            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle\n            compilation_counter.num_cache_entries_updated += 1\n            self.is_cache_updated = True\n            if graph_index == 0:\n                # adds some info logging for the first graph\n                if runtime_shape is None:\n                    logger.info(\"Cache the graph for dynamic shape for later use\")\n                else:\n                    logger.info(\n                        \"Cache the graph of shape %s for later use\", str(runtime_shape)\n                    )\n            if runtime_shape is None:\n                logger.debug(\n                    \"Store the 
%s-th graph for dynamic shape from %s via \" \"handle %s\",\n                    graph_index,\n                    self.compiler.name,\n                    handle,\n                )\n            else:\n                logger.debug(\n                    \"Store the %s-th graph for shape %s from %s via handle %s\",\n                    graph_index,\n                    str(runtime_shape),\n                    self.compiler.name,\n                    handle,\n                )\n\n        # after compiling the last graph, record the end time\n        if graph_index == num_graphs - 1:\n            now = time.time()\n            elapsed = now - compilation_start_time\n            if runtime_shape is None:\n                logger.info(\"Compiling a graph for dynamic shape takes %.2f s\", elapsed)\n            else:\n                logger.info(\n                    \"Compiling a graph for shape %s takes %.2f s\",\n                    runtime_shape,\n                    elapsed,\n                )\n\n        return compiled_graph\n\n\n@dataclasses.dataclass\nclass SplitItem:\n    submod_name: str\n    graph_id: int\n    is_splitting_graph: bool\n    graph: fx.GraphModule\n\n\ndef split_graph(\n    graph: fx.GraphModule, ops: list[str]\n) -> tuple[fx.GraphModule, list[SplitItem]]:\n    # split graph by ops\n    subgraph_id = 0\n    node_to_subgraph_id = {}\n    split_op_graphs = []\n    for node in graph.graph.nodes:\n        if node.op in (\"output\", \"placeholder\"):\n            continue\n        if node.op == \"call_function\" and str(node.target) in ops:\n            subgraph_id += 1\n            node_to_subgraph_id[node] = subgraph_id\n            split_op_graphs.append(subgraph_id)\n            subgraph_id += 1\n        else:\n            node_to_subgraph_id[node] = subgraph_id\n\n    # `keep_original_order` is important!\n    # otherwise pytorch might reorder the nodes and\n    # the semantics of the graph will change when we\n    # have mutations in the graph\n   
 split_gm = torch.fx.passes.split_module.split_module(\n        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True\n    )\n\n    outputs = []\n\n    names = [name for (name, module) in split_gm.named_modules()]\n\n    for name in names:\n        if \".\" in name or name == \"\":\n            # recursive child module or the root module\n            continue\n\n        module = getattr(split_gm, name)\n\n        graph_id = int(name.replace(\"submod_\", \"\"))\n        outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))\n\n    # sort by integer graph_id, rather than string name\n    outputs.sort(key=lambda x: x.graph_id)\n\n    return split_gm, outputs\n\n\n# we share the global graph pool among all the backends\nglobal_graph_pool = None\n\ncompilation_start_time = 0.0\n\n\nclass PiecewiseCompileInterpreter(torch.fx.Interpreter):\n    def __init__(\n        self,\n        module: torch.fx.GraphModule,\n        compile_submod_names: list[str],\n        inductor_config: dict[str, Any],\n        graph_pool,\n        compile_config: CompilationConfig,\n        sglang_backend: \"SGLangBackend\",\n    ):\n        super().__init__(module)\n        from torch._guards import detect_fake_mode\n\n        self.fake_mode = detect_fake_mode()\n        self.compile_submod_names = compile_submod_names\n        self.graph_pool = graph_pool\n        self.sglang_backend = sglang_backend\n        # When True, it annoyingly dumps the torch.fx.Graph on errors.\n        self.extra_traceback = False\n        self.inductor_config = inductor_config\n        self.compile_config = compile_config\n\n    def run(self, *args):\n        fake_args = [\n            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t\n            for t in args\n        ]\n        with self.fake_mode, enable_python_dispatcher():\n            return super().run(*fake_args)\n\n    def call_module(\n        self,\n        target: 
torch.fx.node.Target,\n        args: tuple[torch.fx.node.Argument, ...],\n        kwargs: dict[str, Any],\n    ) -> Any:\n        assert isinstance(target, str)\n        output = super().call_module(target, args, kwargs)\n\n        if target in self.compile_submod_names:\n            index = self.compile_submod_names.index(target)\n            submod = self.fetch_attr(target)\n            sym_shape_indices = [\n                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)\n            ]\n            global compilation_start_time\n            compiled_graph_for_dynamic_shape = (\n                self.sglang_backend.compiler_manager.compile(\n                    submod,\n                    args,\n                    self.inductor_config,\n                    graph_index=index,\n                    num_graphs=len(self.compile_submod_names),\n                    runtime_shape=None,\n                )\n            )\n\n            self.module.__dict__[target] = make_backend(\n                submod,\n                self.compile_config,\n                self.inductor_config,\n                self.graph_pool,\n                index,\n                len(self.compile_submod_names),\n                sym_shape_indices,\n                compiled_graph_for_dynamic_shape,\n                self.sglang_backend,\n            )\n\n            compilation_counter.num_piecewise_capturable_graphs_seen += 1\n\n        return output\n\n\nmodel_tag: str = \"backbone\"\n\n\n@contextmanager\ndef set_model_tag(tag: str):\n    \"\"\"Context manager to set the model tag.\"\"\"\n    global model_tag\n    assert (\n        tag != model_tag\n    ), f\"Model tag {tag} is the same as the current tag {model_tag}.\"\n    old_tag = model_tag\n    model_tag = tag\n    try:\n        yield\n    finally:\n        model_tag = old_tag\n\n\nclass SGLangBackend:\n\n    graph_pool: Any\n    _called: bool = False\n    # the graph we compiled\n    graph: fx.GraphModule\n    # the stitching graph 
module for all the piecewise graphs\n    split_gm: fx.GraphModule\n    piecewise_graphs: list[SplitItem]\n    returned_callable: Callable\n    # Inductor passes to run on the graph pre-defunctionalization\n    post_grad_passes: Sequence[Callable]\n    sym_tensor_indices: list[int]\n    input_buffers: list[torch.Tensor]\n    compiler_manager: CompilerManager\n\n    def __init__(\n        self,\n        config: CompilationConfig,\n        graph_pool: Any,\n    ):\n        assert graph_pool is not None\n        self.graph_pool = graph_pool\n\n        self.post_grad_pass_manager = PostGradPassManager()\n        self.sym_tensor_indices = []\n        self.input_buffers = []\n\n        self.compiler_manager = CompilerManager(config)\n        self.inductor_config = {\n            \"enable_auto_functionalized_v2\": False,\n        }\n        self.compile_config = config\n\n    def configure_post_pass(self):\n        self.post_grad_pass_manager.configure()\n        self.inductor_config[\"post_grad_custom_post_pass\"] = self.post_grad_pass_manager\n\n    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:\n        base_cache_dir = os.path.expanduser(\n            os.getenv(\"SGLANG_CACHE_DIR\", \"~/.cache/sglang/\")\n        )\n\n        cache_hash = self.compiler_manager.compute_hash()\n        cache_dir = os.path.join(\n            base_cache_dir,\n            \"torch_compile_cache\",\n            cache_hash,\n        )\n\n        os.makedirs(cache_dir, exist_ok=True)\n        rank = 0\n        dp_rank = 0\n        local_cache_dir = os.path.join(cache_dir, f\"rank_{rank}_{dp_rank}\", model_tag)\n        os.makedirs(local_cache_dir, exist_ok=True)\n        self.compiler_manager.initialize_cache(\n            local_cache_dir, disable_cache=False, prefix=\"\"\n        )\n        compilation_counter.num_graphs_seen += 1\n\n        assert not self._called, \"SGLangBackend can only be called once\"\n\n        self.graph = graph\n        
self.configure_post_pass()\n\n        self.split_gm, self.piecewise_graphs = split_graph(\n            graph,\n            self.compile_config.split_ops,\n        )\n        from torch._dynamo.utils import lazy_format_graph_code\n\n        # depyf will hook lazy_format_graph_code and dump the graph\n        # for debugging, no need to print the graph here\n        lazy_format_graph_code(\"before split\", self.graph)\n        lazy_format_graph_code(\"after split\", self.split_gm)\n\n        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)\n\n        submod_names_to_compile = [\n            item.submod_name\n            for item in self.piecewise_graphs\n            if not item.is_splitting_graph\n        ]\n\n        PiecewiseCompileInterpreter(\n            self.split_gm,\n            submod_names_to_compile,\n            self.inductor_config,\n            self.graph_pool,\n            self.compile_config,\n            self,\n        ).run(*example_inputs)\n\n        rank = torch.distributed.get_rank()\n\n        if rank == 0:\n            graph_path = os.path.join(\n                local_cache_dir, f\"computation_graph_{time.time()}.py\"\n            )\n            if not os.path.exists(graph_path):\n                # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa\n                # use `print_readable` because it can include submodules\n                src = (\n                    \"from __future__ import annotations\\nimport torch\\n\"\n                    + self.split_gm.print_readable(print_output=False)\n                )\n                src = src.replace(\"<lambda>\", \"GraphModule\")\n                with open(graph_path, \"w\") as f:\n                    f.write(src)\n\n        self._called = True\n        return self.split_gm\n"
  },
  {
    "path": "python/sglang/srt/compilation/compilation_config.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py\n\nfrom typing import Callable, List, Optional\n\nSPLIT_OPS = []\n\n\ndef register_split_op(op_name: Optional[str] = None):\n    def decorator(op_func: Callable):\n        name = op_name or op_func.__name__\n        SPLIT_OPS.append(f\"sglang.{name}\")\n        return op_func\n\n    return decorator\n\n\n# TODO(Yuwei): support better compile config support\nclass CompilationConfig:\n    def __init__(\n        self,\n        capture_sizes: List[int],\n        compiler: str = \"eager\",\n        enable_debug_mode: bool = False,\n    ):\n        self.traced_files = set()\n        self.capture_sizes = capture_sizes\n        self.compiler = compiler\n        self.enable_debug_mode = enable_debug_mode\n        self.split_ops = []\n        self.split_ops.extend(SPLIT_OPS)\n\n    def add_split_op(self, op: str):\n        self.split_ops.append(op)\n\n    def add_traced_file(self, file_path: str):\n        self.traced_files.add(file_path)\n\n    def get_traced_files(self):\n        return self.traced_files\n\n    def get_capture_sizes(self):\n        return self.capture_sizes\n\n    def get_enable_debug_mode(self):\n        return self.enable_debug_mode\n"
  },
  {
    "path": "python/sglang/srt/compilation/compilation_counter.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_counter.py\n\nimport copy\nimport dataclasses\nfrom contextlib import contextmanager\n\n\n@dataclasses.dataclass\nclass CompilationCounter:\n    num_models_seen: int = 0\n    num_graphs_seen: int = 0\n    # including the splitting ops\n    num_piecewise_graphs_seen: int = 0\n    # not including the splitting ops\n    num_piecewise_capturable_graphs_seen: int = 0\n    num_backend_compilations: int = 0\n    # Number of gpu_model_runner attempts to trigger CUDAGraphs capture\n    num_gpu_runner_capture_triggers: int = 0\n    # Number of CUDAGraphs captured\n    num_cudagraph_captured: int = 0\n    # InductorAdaptor.compile calls\n    num_inductor_compiles: int = 0\n    # EagerAdapter.compile calls\n    num_eager_compiles: int = 0\n    # The number of times a compiler cache entry was updated\n    num_cache_entries_updated: int = 0\n    # The number of standalone_compile compiled artifacts saved\n    num_compiled_artifacts_saved: int = 0\n    # Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS\n    dynamo_as_is_count: int = 0\n\n    def clone(self) -> \"CompilationCounter\":\n        return copy.deepcopy(self)\n\n    @contextmanager\n    def expect(self, **kwargs):\n        old = self.clone()\n        yield\n        for k, v in kwargs.items():\n            assert getattr(self, k) - getattr(old, k) == v, (\n                f\"{k} not as expected, before it is {getattr(old, k)}\"\n                f\", after it is {getattr(self, k)}, \"\n                f\"expected diff is {v}\"\n            )\n\n\ncompilation_counter = CompilationCounter()\n"
  },
  {
    "path": "python/sglang/srt/compilation/compile.py",
    "content": "import inspect\nimport logging\nimport os\nimport sys\nimport types\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, Optional, Union\n\nimport torch\n\nfrom sglang.srt.compilation.compilation_config import CompilationConfig\nfrom sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass IntermediateTensors:\n    \"\"\"For all pipeline stages except the last, we need to return the hidden\n    states and residuals to be sent to the next stage. This data structure\n    contains the hidden states and residuals for a request.\n\n    Each stage also needs to handle its own finished_sending and\n    finished_recving in case of kv transfer.\n    \"\"\"\n\n    tensors: dict[str, torch.Tensor]\n    # [req_ids]\n    finished_sending: Optional[set[str]] = None\n    finished_recving: Optional[set[str]] = None\n\n    def __init__(self, tensors):\n        # manually define this function, so that\n        # Dynamo knows `IntermediateTensors()` comes from this file.\n        # Otherwise, dataclass will generate this function by evaluating\n        # a string, and we will lose the information about the source file.\n        self.tensors = tensors\n\n    def __getitem__(self, key: Union[str, slice]):\n        if isinstance(key, str):\n            return self.tensors[key]\n        elif isinstance(key, slice):\n            return self.__class__({k: v[key] for k, v in self.tensors.items()})\n\n    def __setitem__(self, key: str, value: torch.Tensor):\n        self.tensors[key] = value\n\n    def items(self):\n        return self.tensors.items()\n\n    def __len__(self):\n        return len(self.tensors)\n\n    def __eq__(self, other: object):\n        return isinstance(other, self.__class__) and self\n\n    def __repr__(self) -> str:\n        return f\"IntermediateTensors(tensors={self.tensors})\"\n\n\ndef _normalize_dims(dims, ndim: int):\n    dims = [dims] if 
isinstance(dims, int) else list(dims)\n    return [d if d >= 0 else ndim + d for d in dims]\n\n\nclass _MaybeIntermediateTensors:\n    \"\"\"Duck-typed check to support your IntermediateTensors without importing.\"\"\"\n\n    def __init__(self, obj):\n        self.is_intermediate = hasattr(obj, \"tensors\") and isinstance(\n            getattr(obj, \"tensors\"), dict\n        )\n        self.obj = obj\n\n\ndef _mark_dynamic_on_value(val, dims):\n    if isinstance(val, torch.Tensor):\n        torch._dynamo.maybe_mark_dynamic(val, _normalize_dims(dims, val.ndim))\n    else:\n        mit = _MaybeIntermediateTensors(val)\n        if mit.is_intermediate:\n            for t in mit.obj.tensors.values():\n                torch._dynamo.maybe_mark_dynamic(t, _normalize_dims(dims, t.ndim))\n        # else: ignore (None or non-tensor)\n\n\ndef _infer_dynamic_arg_dims_from_annotations(forward_fn):\n    sig = inspect.signature(forward_fn)\n    dyn = {}\n    for name, p in sig.parameters.items():\n        ann = p.annotation\n        # Accept torch.Tensor / Optional[torch.Tensor] / your IntermediateTensors types by name\n        if (\n            ann is torch.Tensor\n            or getattr(getattr(ann, \"__args__\", [None])[0], \"__name__\", \"\") == \"Tensor\"\n        ):\n            dyn[name] = 0\n        elif getattr(ann, \"__name__\", \"\") in (\"IntermediateTensors\",) or any(\n            getattr(a, \"__name__\", \"\") == \"IntermediateTensors\"\n            for a in getattr(ann, \"__args__\", [])\n        ):\n            dyn[name] = 0\n        elif ann == \"torch.Tensor\" or ann == \"Optional[torch.Tensor]\":\n            # For future import annotations (e.g. 
from __future__ import annotations), the annotation is a string\n            dyn[name] = 0\n    if not dyn:\n        raise ValueError(\"No dynamic dims inferred; pass dynamic_arg_dims explicitly.\")\n    return dyn\n\n\ndef install_torch_compiled(\n    module: torch.nn.Module,\n    *,\n    dynamic_arg_dims: dict[str, Union[int, list[int]]] | None = None,\n    backend_factory: Optional[Callable[[torch.fx.GraphModule, list], Callable]] = None,\n    compile_config: CompilationConfig = None,\n    fullgraph: bool = True,\n    graph_pool: Any = None,\n):\n    unbound_fwd = module.__class__.forward\n    if not callable(unbound_fwd):\n        raise TypeError(\"module.__class__.forward must be callable\")\n    original_code = unbound_fwd.__code__\n\n    dyn_map = dynamic_arg_dims or _infer_dynamic_arg_dims_from_annotations(unbound_fwd)\n\n    if backend_factory is None:\n        from sglang.srt.compilation.backend import SGLangBackend\n\n        backend_factory = lambda gm, ex: SGLangBackend(compile_config, graph_pool)(\n            gm, ex\n        )\n\n    compiled_codes: list[type(original_code)] = []\n    state = {\"compiled\": False, \"compiled_callable\": None}\n\n    def bytecode_hook(old_code, new_code):\n        if old_code is not original_code:\n            return\n        frame = sys._getframe()\n        while frame and frame.f_back:\n            frame = frame.f_back\n            if (\n                frame.f_code.co_name == \"_compile\"\n                and os.path.basename(frame.f_code.co_filename) == \"convert_frame.py\"\n            ):\n                break\n        try:\n            dynamo_frame = frame.f_locals[\"frame\"]\n        except Exception:\n            return\n        if dynamo_frame.f_code is not old_code:\n            return\n        if dynamo_frame.f_locals.get(\"self\") is not module:\n            return\n        compiled_codes.append(new_code)\n\n    torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)\n\n    def 
_ensure_compiled(self, *args, **kwargs):\n        \"\"\"Compile on first use (with flag ON).\"\"\"\n        if state[\"compiled\"]:\n            return\n        # Mark dynamic dims only when we are about to compile\n        sig = inspect.signature(unbound_fwd)\n        ba = sig.bind(self, *args, **kwargs)\n        ba.apply_defaults()\n        for name, dims in (dyn_map or {}).items():\n            if name in ba.arguments:\n                val = ba.arguments[name]\n                if val is not None:\n                    _mark_dynamic_on_value(val, dims)\n\n        # Avoid cross-instance cache reuse\n        torch._dynamo.eval_frame.remove_from_cache(unbound_fwd.__code__)\n\n        bound = types.MethodType(unbound_fwd, self)\n        compiled_callable = torch.compile(\n            bound, fullgraph=fullgraph, backend=backend_factory\n        )\n\n        # Trigger Dynamo so bytecode hook can capture\n        compiled_callable(*args, **kwargs)\n\n        state[\"compiled\"] = True\n        state[\"compiled_callable\"] = compiled_callable\n\n    def trampoline(self, *args, **kwargs):\n        use_compiled = is_in_piecewise_cuda_graph()\n        if use_compiled:\n            if not state[\"compiled\"]:\n                _ensure_compiled(self, *args, **kwargs)\n\n            compiled_callable = state[\"compiled_callable\"]\n            return compiled_callable(*args, **kwargs)\n        else:\n            # Explicitly run the original uncompiled forward\n            return unbound_fwd(self, *args, **kwargs)\n\n    module.forward = types.MethodType(trampoline, module)\n    return module\n"
  },
  {
    "path": "python/sglang/srt/compilation/compiler_interface.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compiler_interface.py\n\nimport contextlib\nimport copy\nimport hashlib\nimport os\nfrom contextlib import ExitStack\nfrom typing import Any, Callable, Optional\nfrom unittest.mock import patch\n\nimport torch\nimport torch._inductor.compile_fx\nimport torch.fx as fx\n\nfrom sglang.srt.compilation.compilation_counter import compilation_counter\nfrom sglang.srt.compilation.inductor_pass import pass_context\nfrom sglang.srt.utils.common import torch_release\n\n\nclass CompilerInterface:\n    \"\"\"\n    The interface for a compiler that can be used by vLLM.\n    \"\"\"\n\n    # The name of the compiler, e.g. inductor.\n    # This is a class-level attribute.\n    name: str\n\n    def initialize_cache(\n        self, cache_dir: str, disable_cache: bool = False, prefix: str = \"\"\n    ):\n        \"\"\"\n        when the vLLM process uses `cache_dir` as the cache directory,\n        the compiler should initialize itself with the cache directory,\n        e.g. by re-directing its own cache directory to a sub-directory.\n\n        prefix can be used in combination with cache_dir to figure out the base\n        cache directory, e.g. there are multiple parts of the model being compiled,\n        but we want to share the same cache directory for all of them.\n\n        e.g.\n        cache_dir = \"/path/to/dir/backbone\", prefix = \"backbone\"\n        cache_dir = \"/path/to/dir/eagle_head\", prefix = \"eagle_head\"\n        \"\"\"\n        pass\n\n    def compute_hash(self) -> str:\n        \"\"\"\n        Gather all the relevant information from the vLLM config,\n        to compute a hash so that we can cache the compiled model.\n\n        See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]\n        to check what information\n        is already considered by default. 
This function should only\n        consider the information that is specific to the compiler.\n        \"\"\"\n        return \"\"\n\n    def compile(\n        self,\n        graph: fx.GraphModule,\n        example_inputs: list[Any],\n        compiler_config: dict[str, Any],\n        runtime_shape: Optional[int] = None,\n        key: Optional[str] = None,\n    ) -> tuple[Optional[Callable], Optional[Any]]:\n        \"\"\"\n        Compile the graph with the given example inputs and compiler config,\n        with a runtime shape. If the `runtime_shape` is None, it means\n        the `example_inputs` have a dynamic shape. Otherwise, the\n        `runtime_shape` specifies the shape of the inputs. Right now we only\n        support one variable shape for all inputs, which is the batchsize\n        (number of tokens) during inference.\n\n        Dynamo will make sure `graph(*example_inputs)` is valid.\n\n        The function should return a compiled callable function, as well as\n        a handle that can be used to directly load the compiled function.\n\n        The handle should be a plain Python object, preferably a string or a\n        file path for readability.\n\n        If the compiler doesn't support caching, it should return None for the\n        handle. If the compiler fails to compile the graph, it should return\n        None for the compiled function as well.\n\n        `key` is required for StandaloneInductorAdapter, it specifies where to\n        save the compiled artifact. 
The compiled artifact gets saved to\n        `cache_dir/key`.\n        \"\"\"\n        return None, None\n\n    def load(\n        self,\n        handle: Any,\n        graph: fx.GraphModule,\n        example_inputs: list[Any],\n        graph_index: int,\n        runtime_shape: Optional[int] = None,\n    ) -> Callable:\n        \"\"\"\n        Load the compiled function from the handle.\n        Raises an error if the handle is invalid.\n\n        The handle is the second return value of the `compile` function.\n        \"\"\"\n        raise NotImplementedError(\"caching is not supported\")\n\n\ndef get_inductor_factors() -> list[Any]:\n    factors: list[Any] = []\n    # summarize system state\n    from torch._inductor.codecache import CacheBase\n\n    system_factors = CacheBase.get_system()\n    factors.append(system_factors)\n\n    # summarize pytorch state\n    from torch._inductor.codecache import torch_key\n\n    torch_factors = torch_key()\n    factors.append(torch_factors)\n    return factors\n\n\nclass AlwaysHitShapeEnv:\n    \"\"\"\n    Why do we need this class:\n\n    For normal `torch.compile` usage, every compilation will have\n    one Dynamo bytecode compilation and one Inductor compilation.\n    The Inductor compilation happens under the context of the\n    Dynamo bytecode compilation, and that context is used to\n    determine the dynamic shape information, etc.\n\n    For our use case, we only run Dynamo bytecode compilation once,\n    and run Inductor compilation multiple times with different shapes\n    plus a general shape. The compilation for specific shapes happens\n    outside of the context of the Dynamo bytecode compilation. 
At that\n    time, we don't have shape environment to provide to Inductor, and\n    it will fail the Inductor code cache lookup.\n\n    By providing a dummy shape environment that always hits, we can\n    make the Inductor code cache lookup always hit, and we can\n    compile the graph for different shapes as needed.\n\n    The following dummy methods are obtained by trial-and-error\n    until it works.\n    \"\"\"\n\n    def __init__(self) -> None:\n        self.guards: list[Any] = []\n\n    def evaluate_guards_expression(self, *args, **kwargs):\n        return True\n\n    def get_pruned_guards(self, *args, **kwargs):\n        return []\n\n    def produce_guards_expression(self, *args, **kwargs):\n        return \"\"\n\n\nclass InductorAdaptor(CompilerInterface):\n    \"\"\"\n    The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.\n    \"\"\"\n\n    name = \"inductor\"\n\n    def compute_hash(self) -> str:\n        factors = get_inductor_factors()\n        hash_str = hashlib.md5(\n            str(factors).encode(), usedforsecurity=False\n        ).hexdigest()[:10]\n        return hash_str\n\n    def initialize_cache(\n        self, cache_dir: str, disable_cache: bool = False, prefix: str = \"\"\n    ):\n        self.cache_dir = cache_dir\n        self.prefix = prefix\n        self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir\n        if disable_cache:\n            return\n        # redirect the cache directory to a sub-directory\n        # set flags so that Inductor and Triton store their cache\n        # in the cache_dir, then users only need to copy the cache_dir\n        # to another machine to reuse the cache.\n        inductor_cache = os.path.join(self.base_cache_dir, \"inductor_cache\")\n        os.makedirs(inductor_cache, exist_ok=True)\n        os.environ[\"TORCHINDUCTOR_CACHE_DIR\"] = inductor_cache\n        triton_cache = os.path.join(self.base_cache_dir, \"triton_cache\")\n        os.makedirs(triton_cache, 
exist_ok=True)\n        os.environ[\"TRITON_CACHE_DIR\"] = triton_cache\n\n    def compile(\n        self,\n        graph: fx.GraphModule,\n        example_inputs: list[Any],\n        compiler_config: dict[str, Any],\n        runtime_shape: Optional[int] = None,\n        key: Optional[str] = None,\n    ) -> tuple[Optional[Callable], Optional[Any]]:\n        compilation_counter.num_inductor_compiles += 1\n        from torch._inductor.compile_fx import compile_fx\n\n        current_config = {}\n        if compiler_config is not None:\n            current_config.update(compiler_config)\n\n        # disable remote cache\n        current_config[\"fx_graph_cache\"] = True\n        current_config[\"fx_graph_remote_cache\"] = False\n\n        set_inductor_config(current_config, runtime_shape)\n\n        # inductor can inplace modify the graph, so we need to copy it\n        # see https://github.com/pytorch/pytorch/issues/138980\n        graph = copy.deepcopy(graph)\n\n        # it's the first time we compile this graph\n        # the assumption is that we don't have nested Inductor compilation.\n        # compiled_fx_graph_hash will only be called once, and we can hook\n        # it to get the hash of the compiled graph directly.\n\n        hash_str, file_path = None, None\n        from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash\n\n        if torch_release[:2] == (2, 5):\n            original_load = FxGraphCache.load\n            original_load_name = \"torch._inductor.codecache.FxGraphCache.load\"\n\n            def hijack_load(*args, **kwargs):\n                inductor_compiled_graph = original_load(*args, **kwargs)\n                nonlocal file_path\n                compiled_fn = inductor_compiled_graph.current_callable\n                file_path = compiled_fn.__code__.co_filename  # noqa\n                if not file_path.startswith(self.base_cache_dir):\n                    # hooked in the align_inputs_from_check_idxs function\n              
      # in torch/_inductor/utils.py\n                    for cell in compiled_fn.__closure__:\n                        if not callable(cell.cell_contents):\n                            continue\n                        if cell.cell_contents.__code__.co_filename.startswith(\n                            self.base_cache_dir\n                        ):\n                            # this is the real file path compiled from Inductor\n                            file_path = cell.cell_contents.__code__.co_filename\n                            break\n                return inductor_compiled_graph\n\n            hijacked_compile_fx_inner = (\n                torch._inductor.compile_fx.compile_fx_inner\n            )  # noqa\n        elif torch_release >= (2, 6):\n            # function renamed in 2.6\n            original_load_name = None\n\n            def hijacked_compile_fx_inner(*args, **kwargs):\n                output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)\n                nonlocal hash_str\n                inductor_compiled_graph = output\n                if inductor_compiled_graph is not None:\n                    nonlocal file_path\n                    compiled_fn = inductor_compiled_graph.current_callable\n                    file_path = compiled_fn.__code__.co_filename  # noqa\n                    if not file_path.startswith(self.base_cache_dir):\n                        # hooked in the align_inputs_from_check_idxs function\n                        # in torch/_inductor/utils.py\n                        for cell in compiled_fn.__closure__:\n                            if not callable(cell.cell_contents):\n                                continue\n                            code = cell.cell_contents.__code__\n                            if code.co_filename.startswith(self.base_cache_dir):\n                                # this is the real file path\n                                # compiled from Inductor\n                                
file_path = code.co_filename\n                                break\n                    hash_str = inductor_compiled_graph._fx_graph_cache_key\n                return output\n\n        def hijack_compiled_fx_graph_hash(*args, **kwargs):\n            out = compiled_fx_graph_hash(*args, **kwargs)\n            nonlocal hash_str\n            hash_str = out[0]\n            return out\n\n        def _check_can_cache(*args, **kwargs):\n            # no error means it can be cached.\n            # Inductor refuses to cache the graph outside of Dynamo\n            # tracing context, and also disables caching for graphs\n            # with high-order ops.\n            # For vLLM, in either case, we want to cache the graph.\n            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa\n            return\n\n        def _get_shape_env() -> AlwaysHitShapeEnv:\n            return AlwaysHitShapeEnv()\n\n        with ExitStack() as stack:\n            # hijack to get the compiled graph itself\n            if original_load_name is not None:\n                stack.enter_context(patch(original_load_name, hijack_load))\n\n            # for hijacking the hash of the compiled graph\n            stack.enter_context(\n                patch(\n                    \"torch._inductor.codecache.compiled_fx_graph_hash\",\n                    hijack_compiled_fx_graph_hash,\n                )\n            )\n\n            # for providing a dummy shape environment\n            stack.enter_context(\n                patch(\n                    \"torch._inductor.codecache.FxGraphCache._get_shape_env\",\n                    _get_shape_env,\n                )\n            )\n\n            from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache\n\n            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache\n            if hasattr(AOTAutogradCache, \"_get_shape_env\"):\n                
stack.enter_context(\n                    patch(\n                        \"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env\",\n                        _get_shape_env,\n                    )\n                )\n\n            # for forcing the graph to be cached\n            stack.enter_context(\n                patch(\n                    \"torch._inductor.codecache.FxGraphCache._check_can_cache\",\n                    _check_can_cache,\n                )\n            )\n\n            # Dynamo metrics context, see method for more details.\n            stack.enter_context(self.metrics_context())\n\n            # Disable remote caching. When these are on, on remote cache-hit,\n            # the monkey-patched functions never actually get called.\n            # vLLM today assumes and requires the monkey-patched functions to\n            # get hit.\n            # TODO(zou3519): we're going to replace this all with\n            # standalone_compile sometime.\n\n            stack.enter_context(\n                torch._inductor.config.patch(fx_graph_remote_cache=False)\n            )\n            # InductorAdaptor (unfortunately) requires AOTAutogradCache\n            # to be turned off to run. 
It will fail to acquire the hash_str\n            # and error if not.\n            # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.\n            stack.enter_context(\n                torch._functorch.config.patch(enable_autograd_cache=False)\n            )\n            stack.enter_context(\n                torch._functorch.config.patch(enable_remote_autograd_cache=False)\n            )\n\n            with pass_context(runtime_shape):\n                compiled_graph = compile_fx(\n                    graph,\n                    example_inputs,\n                    inner_compile=hijacked_compile_fx_inner,\n                    config_patches=current_config,\n                )\n        return compiled_graph, (hash_str, file_path)\n\n    def load(\n        self,\n        handle: Any,\n        graph: fx.GraphModule,\n        example_inputs: list[Any],\n        graph_index: int,\n        runtime_shape: Optional[int] = None,\n    ) -> Callable:\n        assert isinstance(handle, tuple)\n        assert isinstance(handle[0], str)\n        assert isinstance(handle[1], str)\n        hash_str = handle[0]\n\n        from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache\n        from torch._inductor.codecache import FxGraphCache\n\n        with ExitStack() as exit_stack:\n            exit_stack.enter_context(\n                patch(\n                    \"torch._inductor.codecache.FxGraphCache._get_shape_env\",\n                    lambda *args, **kwargs: AlwaysHitShapeEnv(),\n                )\n            )\n            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache\n            if hasattr(AOTAutogradCache, \"_get_shape_env\"):\n                exit_stack.enter_context(\n                    patch(\n                        \"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env\",\n                        lambda *args, **kwargs: AlwaysHitShapeEnv(),\n                    )\n                )\n\n            # 
Dynamo metrics context, see method for more details.\n            exit_stack.enter_context(self.metrics_context())\n\n            if torch_release[:2] == (2, 5):\n                inductor_compiled_graph = FxGraphCache._lookup_graph(\n                    hash_str, example_inputs, True, False\n                )\n                assert inductor_compiled_graph is not None, (\n                    \"Inductor cache lookup failed. Please remove \"\n                    f\"the cache directory and try again.\"  # noqa\n                )\n            elif torch_release >= (2, 6):\n                from torch._inductor.output_code import CompiledFxGraphConstantsWithGm\n\n                constants = CompiledFxGraphConstantsWithGm(graph)\n                inductor_compiled_graph, _ = FxGraphCache._lookup_graph(\n                    hash_str, example_inputs, True, None, constants\n                )\n                assert inductor_compiled_graph is not None, (\n                    \"Inductor cache lookup failed. 
Please remove \"\n                    f\"the cache directory and try again.\"  # noqa\n                )\n\n        # Inductor calling convention (function signature):\n        # f(list) -> tuple\n        # Dynamo calling convention (function signature):\n        # f(*args) -> Any\n\n        # need to know if the graph returns a tuple\n        from torch._inductor.compile_fx import graph_returns_tuple\n\n        returns_tuple = graph_returns_tuple(graph)\n\n        # this is the callable we return to Dynamo to run\n        def compiled_graph(*args):\n            # convert args to list\n            list_args = list(args)\n            graph_output = inductor_compiled_graph(list_args)\n            # unpack the tuple if needed\n            if returns_tuple:\n                return graph_output\n            else:\n                return graph_output[0]\n\n        return compiled_graph\n\n    def metrics_context(self) -> contextlib.AbstractContextManager:\n        \"\"\"\n        This method returns the Dynamo metrics context (if it exists,\n        otherwise a null context). It is used by various compile components.\n        Present in torch>=2.6, it's used inside FxGraphCache in\n        torch==2.6 (but not after). It might also be used in various other\n        torch.compile internal functions.\n\n        Because it is re-entrant, we always set it (even if entering via Dynamo\n        and the context was already entered). We might want to revisit if it\n        should be set at a different level of compilation.\n\n        This is likely a bug in PyTorch: public APIs should not rely on\n        manually setting up internal contexts. 
But we also rely on non-public\n        APIs which might not provide these guarantees.\n        \"\"\"\n        import torch._dynamo.utils\n\n        return torch._dynamo.utils.get_metrics_context()\n\n\ndef set_inductor_config(config, runtime_shape):\n    if isinstance(runtime_shape, int):\n        # for a specific batchsize, tuning triton kernel parameters\n        # can be beneficial\n        config[\"max_autotune\"] = True\n        config[\"coordinate_descent_tuning\"] = True\n\n\nclass EagerAdapter(CompilerInterface):\n    name = \"eager\"\n\n    def compile(\n        self,\n        graph: fx.GraphModule,\n        example_inputs: list[Any],\n        compiler_config: dict[str, Any],\n        runtime_shape: Optional[int] = None,\n        key: Optional[str] = None,\n        num_graphs: int = 1,\n    ) -> tuple[Optional[Callable], Optional[Any]]:\n        return graph, None\n\n    def load(\n        self,\n        handle: Any,\n        graph: fx.GraphModule,\n        example_inputs: list[Any],\n        graph_index: int,\n        runtime_shape: Optional[int] = None,\n        num_graphs: int = 1,\n    ) -> Callable:\n        raise NotImplementedError(\"eager compilation is not supported\")\n"
  },
  {
    "path": "python/sglang/srt/compilation/cuda_piecewise_backend.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/cuda_piecewise_backend.py\n\nimport dataclasses\nimport logging\nfrom contextlib import ExitStack\nfrom typing import Any, Callable, Optional\nfrom unittest.mock import patch\n\nimport torch\nimport torch.fx as fx\n\nfrom sglang.srt.compilation.compilation_config import CompilationConfig\nfrom sglang.srt.compilation.compilation_counter import compilation_counter\nfrom sglang.srt.compilation.piecewise_context_manager import (\n    get_pcg_capture_stream,\n    is_in_pcg_torch_compile,\n)\nfrom sglang.srt.compilation.weak_ref_tensor import weak_ref_tensors\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass ConcreteSizeEntry:\n    runtime_shape: int\n    need_to_compile: bool  # the size is in compile_sizes\n    use_cudagraph: bool  # the size is in cudagraph_capture_sizes\n\n    compiled: bool = False\n    runnable: Callable = None  # type: ignore\n    num_finished_warmup: int = 0\n    cudagraph: Optional[torch.cuda.CUDAGraph] = None\n    output: Optional[Any] = None\n\n    # for cudagraph debugging, track the input addresses\n    # during capture, and check if they are the same during replay\n    input_addresses: Optional[list[int]] = None\n\n\nclass CUDAPiecewiseBackend:\n\n    def __init__(\n        self,\n        graph: fx.GraphModule,\n        compile_config: CompilationConfig,\n        inductor_config: dict[str, Any],\n        graph_pool: Any,\n        piecewise_compile_index: int,\n        total_piecewise_compiles: int,\n        sym_shape_indices: list[int],\n        compiled_graph_for_general_shape: Callable,\n        sglang_backend,\n    ):\n        \"\"\"\n        The backend for piecewise compilation.\n        It mainly handles the compilation and cudagraph capturing.\n\n        We will compile `self.graph` once for the general shape,\n        and then compile for different shapes specified in\n        
`compilation_config.compile_sizes`.\n\n        Independently, we will capture cudagraph for different shapes.\n\n        If a shape needs both compilation and cudagraph, we will\n        compile it first, and then capture cudagraph.\n        \"\"\"\n        self.graph = graph\n        self.inductor_config = inductor_config\n        self.graph_pool = graph_pool\n        self.piecewise_compile_index = piecewise_compile_index\n        self.total_piecewise_compiles = total_piecewise_compiles\n        self.sglang_backend = sglang_backend\n\n        self.is_first_graph = piecewise_compile_index == 0\n        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1\n\n        self.compile_sizes: set[int] = set([])\n        self.compile_config = compile_config\n        self.cudagraph_capture_sizes: set[int] = set(compile_config.get_capture_sizes())\n\n        self.first_run_finished = False\n\n        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa\n\n        self.sym_shape_indices = sym_shape_indices\n\n        # the entries for different shapes that we need to either\n        # compile or capture cudagraph\n        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}\n\n        # to_be_compiled_sizes tracks the remaining sizes to compile,\n        # and updates during the compilation process, so we need to copy it\n        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()\n        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):\n            self.concrete_size_entries[shape] = ConcreteSizeEntry(\n                runtime_shape=shape,\n                need_to_compile=shape in self.compile_sizes,\n                use_cudagraph=shape in self.cudagraph_capture_sizes,\n            )\n\n    def check_for_ending_compilation(self):\n        if self.is_last_graph and not self.to_be_compiled_sizes:\n            # no specific sizes to compile\n            # save the hash of the inductor 
graph for the next run\n            self.sglang_backend.compiler_manager.save_to_file()\n\n    def __call__(self, *args) -> Any:\n        if not self.first_run_finished:\n            self.first_run_finished = True\n            self.check_for_ending_compilation()\n            return self.compiled_graph_for_general_shape(*args)\n\n        if len(self.sym_shape_indices) == 0:\n            return self.compiled_graph_for_general_shape(*args)\n\n        runtime_shape = args[self.sym_shape_indices[0]]\n        if runtime_shape not in self.concrete_size_entries:\n            # we don't need to do anything for this shape\n            return self.compiled_graph_for_general_shape(*args)\n\n        entry = self.concrete_size_entries[runtime_shape]\n\n        if entry.runnable is None:\n            entry.runnable = self.compiled_graph_for_general_shape\n\n        if entry.need_to_compile and not entry.compiled:\n            entry.compiled = True\n            self.to_be_compiled_sizes.remove(runtime_shape)\n            # args are real arguments\n            entry.runnable = self.sglang_backend.compiler_manager.compile(\n                self.graph,\n                args,\n                self.inductor_config,\n                graph_index=self.piecewise_compile_index,\n                num_graphs=self.total_piecewise_compiles,\n                runtime_shape=runtime_shape,\n            )\n\n            # finished compilations for all required shapes\n            if self.is_last_graph and not self.to_be_compiled_sizes:\n                self.check_for_ending_compilation()\n\n        if is_in_pcg_torch_compile():\n            return entry.runnable(*args)\n\n        if entry.cudagraph is None:\n            if entry.num_finished_warmup < 1:  # noqa\n                entry.num_finished_warmup += 1\n                return entry.runnable(*args)\n\n            if self.compile_config.get_enable_debug_mode():\n                input_addresses = [\n                    x.data_ptr() for x in args 
if isinstance(x, torch.Tensor)\n                ]\n                entry.input_addresses = input_addresses\n            cudagraph = torch.cuda.CUDAGraph()\n\n            with ExitStack() as stack:\n                if not self.is_first_graph:\n                    # during every model forward, we will capture\n                    # many pieces of cudagraphs (roughly one per layer).\n                    # running gc again and again across layers will\n                    # make the cudagraph capture very slow.\n                    # therefore, we only run gc for the first graph,\n                    # and disable gc for the rest of the graphs.\n                    stack.enter_context(patch(\"gc.collect\", lambda: None))\n                    stack.enter_context(patch(\"torch.cuda.empty_cache\", lambda: None))\n                # mind-exploding: carefully manage the reference and memory.\n                stream = get_pcg_capture_stream()\n                assert (\n                    stream is not None\n                ), \"PCG capture stream is not set, please check if runtime recompilation happened\"\n                with torch.cuda.graph(cudagraph, pool=self.graph_pool, stream=stream):\n                    # `output` is managed by pytorch's cudagraph pool\n                    output = entry.runnable(*args)\n                    if self.is_last_graph:\n                        # by converting it to weak ref,\n                        # the original `output` will immediately be released\n                        # to save memory. 
It is only safe to do this for\n                        # the last graph, because the output of the last graph\n                        # will not be used by any other cuda graph.\n                        output = weak_ref_tensors(output)\n\n            # here we always use weak ref for the output\n            # to save memory\n            entry.output = weak_ref_tensors(output)\n            entry.cudagraph = cudagraph\n\n            compilation_counter.num_cudagraph_captured += 1\n\n            # important: we need to return the output, rather than\n            # the weak ref of the output, so that pytorch can correctly\n            # manage the memory during cuda graph capture\n            return output\n\n        if self.compile_config.get_enable_debug_mode():\n            # check if the input addresses are the same\n            new_input_addresses = [\n                x.data_ptr() for x in args if isinstance(x, torch.Tensor)\n            ]\n            assert new_input_addresses == entry.input_addresses, (\n                \"Input addresses for cudagraphs are different during replay.\"\n                f\" Expected {entry.input_addresses}, got {new_input_addresses}\"\n            )\n        entry.cudagraph.replay()\n        return entry.output\n"
  },
  {
    "path": "python/sglang/srt/compilation/fix_functionalization.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fix_functionalization.py\n\nimport logging\nimport operator\nfrom collections.abc import Iterable\nfrom typing import Optional, Union\n\nimport torch\nfrom torch._higher_order_ops.auto_functionalize import auto_functionalized\n\nfrom sglang.srt.compilation.fx_utils import is_func\nfrom sglang.srt.compilation.inductor_pass import SGLangInductorPass\n\nlogger = logging.getLogger(__name__)\n\n\nclass FixFunctionalizationPass(SGLangInductorPass):\n    \"\"\"\n    This pass defunctionalizes certain nodes to avoid redundant tensor copies.\n    After this pass, DCE (dead-code elimination) should never be run,\n    as de-functionalized nodes may appear as dead code.\n\n    To add new nodes to defunctionalize, add to the if-elif chain in __call__.\n    \"\"\"\n\n    def __call__(self, graph: torch.fx.Graph):\n        self.begin()\n        self.dump_graph(graph, \"before_fix_functionalization\")\n\n        self.nodes_to_remove: list[torch.fx.Node] = []\n        count = 0\n        for node in graph.nodes:\n            if not is_func(node, auto_functionalized):\n                continue  # Avoid deep if-elif nesting\n            count += 1\n\n        self.dump_graph(graph, \"before_fix_functionalization_cleanup\")\n\n        # Remove the nodes all at once\n        count_removed = len(self.nodes_to_remove)\n        for node in self.nodes_to_remove:\n            graph.erase_node(node)\n\n        logger.debug(\n            \"De-functionalized %s nodes, removed %s nodes\", count, count_removed\n        )\n        self.dump_graph(graph, \"after_fix_functionalization\")\n        self.end_and_log()\n\n    def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]):\n        \"\"\"\n        Stage a node (or nodes) for removal at the end of the pass.\n        \"\"\"\n        if isinstance(node_or_nodes, torch.fx.Node):\n            
self.nodes_to_remove.append(node_or_nodes)\n        else:\n            self.nodes_to_remove.extend(node_or_nodes)\n\n    def defunctionalize(\n        self,\n        graph: torch.fx.Graph,\n        node: torch.fx.Node,\n        mutated_args: dict[int, Union[torch.fx.Node, str]],\n        args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,\n    ):\n        \"\"\"\n        De-functionalize a node by replacing it with a call to the original.\n        It also replaces the getitem users with the mutated arguments.\n        See replace_users_with_mutated_args and insert_defunctionalized.\n        \"\"\"\n        self.replace_users_with_mutated_args(node, mutated_args)\n        self.insert_defunctionalized(graph, node, args=args)\n        self._remove(node)\n\n    def replace_users_with_mutated_args(\n        self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]]\n    ):\n        \"\"\"\n        Replace all getitem users of the auto-functionalized node with the\n        mutated arguments.\n        :param node: The auto-functionalized node\n        :param mutated_args: The mutated arguments, indexed by getitem index.\n        If the value of an arg is a string, `node.kwargs[arg]` is used.\n        \"\"\"\n        for idx, user in self.getitem_users(node).items():\n            arg = mutated_args[idx]\n            arg = node.kwargs[arg] if isinstance(arg, str) else arg\n            user.replace_all_uses_with(arg)\n            self._remove(user)\n\n    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:\n        \"\"\"\n        Returns the operator.getitem users of the auto-functionalized node,\n        indexed by the index they are getting.\n        \"\"\"\n        users = {}\n        for user in node.users:\n            if is_func(user, operator.getitem):\n                idx = user.args[1]\n                users[idx] = user\n        return users\n\n    def insert_defunctionalized(\n        self,\n        graph: 
torch.fx.Graph,\n        node: torch.fx.Node,\n        args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,\n    ):\n        \"\"\"\n        Insert a new defunctionalized node into the graph before node.\n        If one of the kwargs is 'out', provide args directly,\n        as node.kwargs cannot be used.\n        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351\n\n        :param graph: Graph to insert the defunctionalized node into\n        :param node: The auto-functionalized node to defunctionalize\n        :param args: If we cannot use kwargs, specify args directly.\n        If an arg is a string, `node.kwargs[arg]` is used.\n        \"\"\"  # noqa: E501\n        assert is_func(\n            node, auto_functionalized\n        ), f\"node must be auto-functionalized, is {node} instead\"\n\n        # Create a new call to the original function\n        with graph.inserting_before(node):\n            function = node.args[0]\n            if args is None:\n                graph.call_function(function, kwargs=node.kwargs)\n            else:\n                # Args passed as strings refer to items in node.kwargs\n                args = tuple(\n                    node.kwargs[arg] if isinstance(arg, str) else arg for arg in args\n                )\n                graph.call_function(function, args=args)\n"
  },
  {
    "path": "python/sglang/srt/compilation/fx_utils.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fx_utils.py\n\nimport operator\nfrom collections.abc import Iterable, Iterator\nfrom typing import Optional\n\nfrom torch import fx\nfrom torch._higher_order_ops.auto_functionalize import auto_functionalized\nfrom torch._ops import OpOverload\n\n\ndef is_func(node: fx.Node, target) -> bool:\n    return node.op == \"call_function\" and node.target == target\n\n\ndef is_auto_func(node: fx.Node, op: OpOverload) -> bool:\n    return is_func(node, auto_functionalized) and node.args[0] == op\n\n\n# Returns the first specified node with the given op (if it exists)\ndef find_specified_fn_maybe(\n    nodes: Iterable[fx.Node], op: OpOverload\n) -> Optional[fx.Node]:\n    for node in nodes:\n        if node.target == op:\n            return node\n    return None\n\n\n# Returns the first specified node with the given op\ndef find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:\n    node = find_specified_fn_maybe(nodes, op)\n    assert node is not None, f\"Could not find {op} in nodes {nodes}\"\n    return node\n\n\n# Returns the first auto_functionalized node with the given op (if it exists)\ndef find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]:\n    for node in nodes:\n        if is_func(node, auto_functionalized) and node.args[0] == op:  # noqa\n            return node\n    return None\n\n\n# Returns the first auto_functionalized node with the given op\ndef find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:\n    node = find_auto_fn_maybe(nodes, op)\n    assert node is not None, f\"Could not find {op} in nodes {nodes}\"\n    return node\n\n\n# Returns the getitem node that extracts the idx-th element from node\n# (if it exists)\ndef find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:\n    for user in node.users:\n        if is_func(user, operator.getitem) and user.args[1] == idx:\n            return 
user\n    return None\n\n\n# Returns the getitem node that extracts the idx-th element from node\ndef find_getitem(node: fx.Node, idx: int) -> fx.Node:\n    ret = find_getitem_maybe(node, idx)\n    assert ret is not None, f\"Could not find getitem {idx} in node {node}\"\n    return ret\n\n\n# An auto-functionalization-aware utility for finding nodes with a specific op\ndef find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:\n    if not op._schema.is_mutable:\n        yield from graph.find_nodes(op=\"call_function\", target=op)\n\n    for n in graph.find_nodes(op=\"call_function\", target=auto_functionalized):\n        if n.args[0] == op:\n            yield n\n\n\n# Asserts that the node only has one user and returns it\n# Even if a node has only 1 user, it might share storage with another node,\n# which might need to be taken into account.\ndef get_only_user(node: fx.Node) -> fx.Node:\n    assert len(node.users) == 1\n    return next(iter(node.users))\n"
  },
  {
    "path": "python/sglang/srt/compilation/inductor_pass.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/inductor_pass.py\n\nimport hashlib\nimport inspect\nimport json\nimport logging\nimport time\nimport types\nfrom contextlib import contextmanager\nfrom typing import Any, Callable, Optional, Union\n\nimport torch\nfrom torch import fx\nfrom torch._dynamo.utils import lazy_format_graph_code\nfrom torch._inductor.custom_graph_pass import CustomGraphPass\n\nlogger = logging.getLogger(__name__)\n\n_pass_context = None\n\n\nclass PassContext:\n\n    def __init__(self, runtime_shape: Optional[int]):\n        self.runtime_shape = runtime_shape\n\n\ndef get_pass_context() -> PassContext:\n    \"\"\"Get the current pass context.\"\"\"\n    assert _pass_context is not None\n    return _pass_context\n\n\n@contextmanager\ndef pass_context(runtime_shape: Optional[int]):\n    \"\"\"A context manager that stores the current pass context,\n    usually it is a list of sizes to specialize.\n    \"\"\"\n    global _pass_context\n    prev_context = _pass_context\n    _pass_context = PassContext(runtime_shape)\n    try:\n        yield\n    finally:\n        _pass_context = prev_context\n\n\nclass InductorPass(CustomGraphPass):\n    \"\"\"\n    A custom graph pass that uses a hash of its source as the UUID.\n    This is defined as a convenience and should work in most cases.\n    \"\"\"\n\n    def uuid(self) -> Any:\n        \"\"\"\n        Provide a unique identifier for the pass, used in Inductor code cache.\n        This should depend on the pass implementation, so that changes to the\n        pass result in recompilation.\n        By default, the object source is hashed.\n        \"\"\"\n        return InductorPass.hash_source(self)\n\n    @staticmethod\n    def hash_source(*srcs: Union[str, Any]):\n        \"\"\"\n        Utility method to hash the sources of functions or objects.\n        :param srcs: strings or objects to add to the hash.\n        Objects and functions have their 
source inspected.\n        :return:\n        \"\"\"\n        hasher = hashlib.sha256()\n        for src in srcs:\n            if isinstance(src, str):\n                src_str = src\n            elif isinstance(src, types.FunctionType):\n                src_str = inspect.getsource(src)\n            else:\n                src_str = inspect.getsource(src.__class__)\n            hasher.update(src_str.encode(\"utf-8\"))\n        return hasher.hexdigest()\n\n    @staticmethod\n    def hash_dict(dict_: dict[Any, Any]):\n        \"\"\"\n        Utility method to hash a dictionary, can alternatively be used for uuid.\n        :return: A sha256 hash of the json rep of the dictionary.\n        \"\"\"\n        encoded = json.dumps(dict_, sort_keys=True).encode(\"utf-8\")\n        return hashlib.sha256(encoded).hexdigest()\n\n    def is_applicable_for_shape(self, shape: Optional[int]):\n        return True\n\n\nclass CallableInductorPass(InductorPass):\n    \"\"\"\n    This class is a wrapper for a callable that automatically provides an\n    implementation of the UUID.\n    \"\"\"\n\n    def __init__(\n        self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None\n    ):\n        self.callable = callable\n        self._uuid = self.hash_source(callable) if uuid is None else uuid\n\n    def __call__(self, graph: torch.fx.Graph):\n        self.callable(graph)\n\n    def uuid(self) -> Any:\n        return self._uuid\n\n\nclass SGLangInductorPass(InductorPass):\n\n    def __init__(\n        self,\n    ):\n        self.pass_name = self.__class__.__name__\n\n    def dump_graph(self, graph: torch.fx.Graph, stage: str):\n        lazy_format_graph_code(stage, graph.owning_module)\n\n    def begin(self):\n        self._start_time = time.perf_counter_ns()\n\n    def end_and_log(self):\n        self._end_time = time.perf_counter_ns()\n        duration_ms = float(self._end_time - self._start_time) / 1.0e6\n        logger.debug(\"%s completed in %.1f ms\", self.pass_name, 
duration_ms)\n\n\nclass PrinterInductorPass(SGLangInductorPass):\n\n    def __init__(self, name: str):\n        super().__init__()\n        self.name = name\n\n    def __call__(self, graph: torch.fx.Graph):\n        self.dump_graph(graph, self.name)\n"
  },
  {
    "path": "python/sglang/srt/compilation/npu_piecewise_backend.py",
    "content": "from contextlib import ExitStack\nfrom typing import Any, Callable\nfrom unittest.mock import patch\n\nimport torch\nimport torch.fx as fx\n\nfrom sglang.srt.compilation.compilation_config import CompilationConfig\nfrom sglang.srt.compilation.compilation_counter import compilation_counter\nfrom sglang.srt.compilation.cuda_piecewise_backend import (\n    CUDAPiecewiseBackend,\n    weak_ref_tensors,\n)\n\n\nclass NPUPiecewiseBackend(CUDAPiecewiseBackend):\n    def __init__(\n        self,\n        graph: fx.GraphModule,\n        compile_config: CompilationConfig,\n        inductor_config: dict[str, Any],\n        graph_pool: Any,\n        piecewise_compile_index: int,\n        total_piecewise_compiles: int,\n        sym_shape_indices: list[int],\n        compiled_graph_for_general_shape: Callable,\n        sglang_backend,\n    ):\n        super().__init__(\n            graph,\n            compile_config,\n            inductor_config,\n            graph_pool,\n            piecewise_compile_index,\n            total_piecewise_compiles,\n            sym_shape_indices,\n            compiled_graph_for_general_shape,\n            sglang_backend,\n        )\n\n    def __call__(self, *args):\n        runtime_shape = args[self.sym_shape_indices[0]]\n        if runtime_shape not in self.concrete_size_entries:\n            # we don't need to do anything for this shape\n            return self.compiled_graph_for_general_shape(*args)\n\n        entry = self.concrete_size_entries[runtime_shape]\n\n        if entry.runnable is None:\n            entry.runnable = self.compiled_graph_for_general_shape\n\n        if entry.cudagraph is None:\n            if entry.num_finished_warmup < 1:  # noqa\n                entry.num_finished_warmup += 1\n                return entry.runnable(*args)\n\n            if self.compile_config.get_enable_debug_mode():\n                input_addresses = [\n                    x.data_ptr() for x in args if isinstance(x, torch.Tensor)\n      
          ]\n                entry.input_addresses = input_addresses\n            npugraph = torch.npu.NPUGraph()\n\n            with ExitStack() as stack:\n                if not self.is_first_graph:\n                    # during every model forward, we will capture\n                    # many pieces of cudagraphs (roughly one per layer).\n                    # running gc again and again across layers will\n                    # make the cudagraph capture very slow.\n                    # therefore, we only run gc for the first graph,\n                    # and disable gc for the rest of the graphs.\n                    stack.enter_context(patch(\"gc.collect\", lambda: None))\n                    stack.enter_context(patch(\"torch.npu.empty_cache\", lambda: None))\n\n                # mind-exploding: carefully manage the reference and memory.\n                with torch.npu.graph(npugraph, pool=self.graph_pool):\n                    # `output` is managed by pytorch's cudagraph pool\n                    output = entry.runnable(*args)\n                    if self.is_last_graph:\n                        # by converting it to weak ref,\n                        # the original `output` will immediately be released\n                        # to save memory. 
It is only safe to do this for\n                        # the last graph, because the output of the last graph\n                        # will not be used by any other cuda graph.\n                        output = weak_ref_tensors(output)\n\n            # here we always use weak ref for the output\n            # to save memory\n            entry.output = weak_ref_tensors(output)\n            entry.cudagraph = npugraph\n\n            compilation_counter.num_cudagraph_captured += 1\n\n            # important: we need to return the output, rather than\n            # the weak ref of the output, so that pytorch can correctly\n            # manage the memory during cuda graph capture\n            return output\n\n        if self.compile_config.get_enable_debug_mode():\n            # check if the input addresses are the same\n            new_input_addresses = [\n                x.data_ptr() for x in args if isinstance(x, torch.Tensor)\n            ]\n            assert new_input_addresses == entry.input_addresses, (\n                \"Input addresses for cudagraphs are different during replay.\"\n                f\" Expected {entry.input_addresses}, got {new_input_addresses}\"\n            )\n        entry.cudagraph.replay()\n        return entry.output\n"
  },
  {
    "path": "python/sglang/srt/compilation/pass_manager.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/pass_manager.py\n\nimport logging\n\nfrom torch import fx as fx\n\nfrom sglang.srt.compilation.fix_functionalization import FixFunctionalizationPass\nfrom sglang.srt.compilation.inductor_pass import (\n    CustomGraphPass,\n    InductorPass,\n    SGLangInductorPass,\n    get_pass_context,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass PostGradPassManager(CustomGraphPass):\n    \"\"\"\n    The pass manager for post-grad passes.\n    It handles configuration, adding custom passes, and running passes.\n    It supports uuid for the Inductor code cache. That includes torch<2.6\n    support using pickling (in .inductor_pass.CustomGraphPass).\n\n    The order of the post-grad post-passes is:\n    1. passes (constructor parameter)\n    2. default passes (NoopEliminationPass, FusionPass)\n    3. config[\"post_grad_custom_post_pass\"] (if it exists)\n    4. fix_functionalization\n    This way, all passes operate on a functionalized graph.\n    \"\"\"\n\n    def __init__(self):\n        self.passes: list[SGLangInductorPass] = []\n\n    def __call__(self, graph: fx.Graph):\n        shape = get_pass_context().runtime_shape\n        for pass_ in self.passes:\n            if pass_.is_applicable_for_shape(shape):\n                pass_(graph)\n\n        # always run fix_functionalization last\n        self.fix_functionalization(graph)\n\n    def configure(\n        self,\n    ):\n        self.pass_config = dict()\n        self.fix_functionalization = FixFunctionalizationPass()\n\n    def add(self, pass_: InductorPass):\n        assert isinstance(pass_, InductorPass)\n        self.passes.append(pass_)\n\n    def uuid(self):\n        \"\"\"\n        The PostGradPassManager is set as a custom pass in the Inductor and\n        affects compilation caching. Its uuid depends on the UUIDs of all\n        dependent passes and the pass config. 
See InductorPass for more info.\n        \"\"\"\n        pass_manager_uuid = \"fshdakhsa\"\n        state = {\"pass_config\": pass_manager_uuid, \"passes\": []}\n        for pass_ in self.passes:\n            state[\"passes\"].append(pass_.uuid())\n        state[\"passes\"].append(self.fix_functionalization.uuid())\n        return InductorPass.hash_dict(state)\n"
  },
  {
    "path": "python/sglang/srt/compilation/piecewise_context_manager.py",
    "content": "from __future__ import annotations\n\nimport logging\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Any, List, Optional\n\nimport torch\n\nlogger = logging.getLogger(__name__)\n\n\nif TYPE_CHECKING:\n    from sglang.srt.model_executor.forward_batch_info import ForwardBatch\n\n_in_piecewise_cuda_graph = False\n_in_pcg_torch_compile = False\n_pcg_capture_stream = None\n\n\ndef is_in_piecewise_cuda_graph():\n    return _in_piecewise_cuda_graph\n\n\ndef is_in_pcg_torch_compile():\n    return _in_pcg_torch_compile\n\n\ndef get_pcg_capture_stream():\n    return _pcg_capture_stream\n\n\n@contextmanager\ndef enable_piecewise_cuda_graph_compile():\n    global _in_pcg_torch_compile\n    _in_pcg_torch_compile = True\n    try:\n        yield\n    finally:\n        _in_pcg_torch_compile = False\n\n\n@contextmanager\ndef enable_piecewise_cuda_graph():\n    global _in_piecewise_cuda_graph\n    _in_piecewise_cuda_graph = True\n    try:\n        yield\n    except Exception as e:\n        logger.error(\n            \"Piecewise CUDA Graph failed with error: %s\\n%s\",\n            e,\n            PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG,\n        )\n        raise\n    finally:\n        _in_piecewise_cuda_graph = False\n\n\n@contextmanager\ndef set_pcg_capture_stream(stream: torch.cuda.Stream):\n    global _pcg_capture_stream\n    _pcg_capture_stream = stream\n    try:\n        yield\n    finally:\n        _pcg_capture_stream = None\n\n\n@dataclass\nclass ForwardContext:\n    def __init__(self):\n        self.forward_batch = None\n        self.attention_layers = None\n        self.quant_config = None\n        self.moe_layers = None\n        self.moe_fusions = None\n\n    def set_forward_batch(self, forward_batch: ForwardBatch):\n        self.forward_batch = forward_batch\n\n    def set_attention_layers(self, layers: List[Any]):\n        self.attention_layers = layers\n\n    def set_quant_config(self, quant_config: Any):\n        self.quant_config = quant_config\n\n    def 
set_moe_layers(self, layers: List[Any]):\n        self.moe_layers = layers\n\n    def set_moe_fusions(self, fusions: List[Any]):\n        self.moe_fusions = fusions\n\n\n_forward_context: Optional[ForwardContext] = None\n\n\ndef get_forward_context() -> Optional[ForwardContext]:\n    if _forward_context is None:\n        return None\n    return _forward_context\n\n\n@contextmanager\ndef set_forward_context(\n    forward_batch: ForwardBatch,\n    attention_layers: List[Any],\n    quant_config: Any,\n    moe_layers: List[Any],\n    moe_fusions: List[Any],\n):\n    global _forward_context\n    _forward_context = ForwardContext()\n    _forward_context.set_forward_batch(forward_batch)\n    _forward_context.set_attention_layers(attention_layers)\n    _forward_context.set_quant_config(quant_config)\n    _forward_context.set_moe_layers(moe_layers)\n    _forward_context.set_moe_fusions(moe_fusions)\n    try:\n        yield\n    finally:\n        _forward_context = None\n\n\nPIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG = (\n    \"Piecewise CUDA Graph is enabled by default as an experimental feature.\\n\"\n    \"To work around this error, add --disable-piecewise-cuda-graph to your launch command.\\n\"\n    \"Please report this issue at https://github.com/sgl-project/sglang/issues/new/choose\"\n)\n"
  },
  {
    "path": "python/sglang/srt/compilation/weak_ref_tensor.py",
    "content": "from typing import Any, Union\n\nimport torch\n\nfrom sglang.srt.utils.common import is_cuda, is_hip, is_npu\n\nif is_cuda() or is_hip():\n    from sgl_kernel import weak_ref_tensor\nelif is_npu():\n    from torch_npu._C import _weak_ref_tensor as weak_ref_tensor\nelse:\n    raise NotImplementedError(\n        \"weak_ref_tensor is implemented only for CUDA, HIP and NPU.\"\n    )\n\n\ndef weak_ref_tensors(\n    tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]],\n) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:\n    \"\"\"\n    Convenience function to create weak references to tensors,\n    for single tensor, list of tensors or tuple of tensors.\n    \"\"\"\n    if isinstance(tensors, torch.Tensor):\n        return weak_ref_tensor(tensors)\n    if isinstance(tensors, list):\n        return [weak_ref_tensor(t) for t in tensors]\n    if isinstance(tensors, tuple):\n        return tuple(weak_ref_tensor(t) for t in tensors)\n    raise ValueError(\"Invalid type for tensors\")\n"
  },
  {
    "path": "python/sglang/srt/configs/__init__.py",
    "content": "from sglang.srt.configs.afmoe import AfmoeConfig\nfrom sglang.srt.configs.bailing_hybrid import BailingHybridConfig\nfrom sglang.srt.configs.chatglm import ChatGLMConfig\nfrom sglang.srt.configs.dbrx import DbrxConfig\nfrom sglang.srt.configs.deepseekvl2 import DeepseekVL2Config\nfrom sglang.srt.configs.dots_ocr import DotsOCRConfig\nfrom sglang.srt.configs.dots_vlm import DotsVLMConfig\nfrom sglang.srt.configs.exaone import ExaoneConfig\nfrom sglang.srt.configs.falcon_h1 import FalconH1Config\nfrom sglang.srt.configs.granitemoehybrid import GraniteMoeHybridConfig\nfrom sglang.srt.configs.janus_pro import MultiModalityConfig\nfrom sglang.srt.configs.jet_nemotron import JetNemotronConfig\nfrom sglang.srt.configs.jet_vlm import JetVLMConfig\nfrom sglang.srt.configs.kimi_k25 import KimiK25Config\nfrom sglang.srt.configs.kimi_linear import KimiLinearConfig\nfrom sglang.srt.configs.kimi_vl import KimiVLConfig\nfrom sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig\nfrom sglang.srt.configs.lfm2 import Lfm2Config\nfrom sglang.srt.configs.lfm2_moe import Lfm2MoeConfig\nfrom sglang.srt.configs.longcat_flash import LongcatFlashConfig\nfrom sglang.srt.configs.nano_nemotron_vl import NemotronH_Nano_VL_V2_Config\nfrom sglang.srt.configs.nemotron_h import NemotronHConfig\nfrom sglang.srt.configs.olmo3 import Olmo3Config\nfrom sglang.srt.configs.qwen3_5 import Qwen3_5Config, Qwen3_5MoeConfig\nfrom sglang.srt.configs.qwen3_next import Qwen3NextConfig\nfrom sglang.srt.configs.step3_vl import (\n    Step3TextConfig,\n    Step3VisionEncoderConfig,\n    Step3VLConfig,\n)\nfrom sglang.srt.configs.step3p5 import Step3p5Config\n\n__all__ = [\n    \"AfmoeConfig\",\n    \"BailingHybridConfig\",\n    \"ExaoneConfig\",\n    \"ChatGLMConfig\",\n    \"DbrxConfig\",\n    \"DeepseekVL2Config\",\n    \"LongcatFlashConfig\",\n    \"MultiModalityConfig\",\n    \"KimiVLConfig\",\n    \"MoonViTConfig\",\n    \"Step3VLConfig\",\n    \"Step3TextConfig\",\n    
\"Step3VisionEncoderConfig\",\n    \"Olmo3Config\",\n    \"KimiLinearConfig\",\n    \"KimiK25Config\",\n    \"Qwen3NextConfig\",\n    \"Qwen3_5Config\",\n    \"Qwen3_5MoeConfig\",\n    \"DotsVLMConfig\",\n    \"DotsOCRConfig\",\n    \"FalconH1Config\",\n    \"GraniteMoeHybridConfig\",\n    \"Lfm2Config\",\n    \"Lfm2MoeConfig\",\n    \"NemotronHConfig\",\n    \"NemotronH_Nano_VL_V2_Config\",\n    \"JetNemotronConfig\",\n    \"JetVLMConfig\",\n    \"Step3p5Config\",\n]\n"
  },
  {
    "path": "python/sglang/srt/configs/afmoe.py",
    "content": "from typing import List, Optional\n\nfrom transformers import PretrainedConfig\n\n\nclass AfmoeConfig(PretrainedConfig):\n    model_type = \"afmoe\"\n\n    def __init__(\n        self,\n        vocab_size: int = 32000,\n        hidden_size: int = 4096,\n        intermediate_size: int = 11008,\n        moe_intermediate_size: int = 256,\n        num_hidden_layers: int = 32,\n        num_attention_heads: int = 32,\n        num_key_value_heads: Optional[int] = None,\n        head_dim: Optional[int] = None,\n        hidden_act: str = \"silu\",\n        max_position_embeddings: int = 131072,\n        initializer_range: float = 0.02,\n        rms_norm_eps: float = 1e-5,\n        use_cache: bool = True,\n        pad_token_id: Optional[int] = None,\n        bos_token_id: int = 1,\n        eos_token_id: int = 2,\n        tie_word_embeddings: bool = False,\n        rope_theta: float = 10000.0,\n        rope_scaling: Optional[dict] = None,\n        attention_bias: bool = False,\n        attention_dropout: float = 0.0,\n        # MoE parameters\n        num_experts: Optional[int] = None,\n        num_experts_per_tok: Optional[int] = None,\n        num_shared_experts: int = 0,\n        num_dense_layers: int = 0,\n        # Routing parameters\n        score_func: str = \"sigmoid\",\n        route_norm: bool = True,\n        route_scale: float = 1.0,\n        n_group: int = 1,\n        topk_group: int = 1,\n        # Attention parameters\n        sliding_window: Optional[int] = None,\n        layer_types: Optional[List[str]] = None,\n        global_attn_every_n_layers: int = 4,\n        # muP scaling\n        mup_enabled: bool = False,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n  
      self.num_attention_heads = num_attention_heads\n\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.head_dim = (\n            head_dim if head_dim is not None else hidden_size // num_attention_heads\n        )\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        # MoE parameters\n        self.num_experts = num_experts\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_shared_experts = num_shared_experts\n        self.num_dense_layers = num_dense_layers\n\n        # Routing parameters\n        self.score_func = score_func\n        self.route_norm = route_norm\n        self.route_scale = route_scale\n        self.n_group = n_group\n        self.topk_group = topk_group\n\n        # Attention parameters\n        self.sliding_window = sliding_window\n        self.layer_types = layer_types\n        self.global_attn_every_n_layers = global_attn_every_n_layers\n\n        # muP scaling\n        self.mup_enabled = mup_enabled\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n"
  },
  {
    "path": "python/sglang/srt/configs/bailing_hybrid.py",
    "content": "# coding=utf-8\n# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"BailingHybrid model configuration\"\"\"\n\nimport enum\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nfrom sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape\n\nlogger = logging.get_logger(__name__)\n\n\nclass HybridLayerType(enum.Enum):\n    full_attention = \"attention\"\n    linear_attention = \"linear_attention\"\n\n\nclass BailingHybridConfig(PretrainedConfig):\n\n    model_type = \"bailing_hybrid\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=157184,\n        hidden_size=2048,\n        intermediate_size=5120,\n        num_hidden_layers=20,\n        num_attention_heads=16,\n        num_key_value_heads=4,\n        hidden_act=\"silu\",\n        use_qkv_bias=False,  # bailing only\n        use_bias=False,  # bailing only\n        rms_norm_eps=1e-06,\n        tie_word_embeddings=False,  # PretrainedConfig key, here change default value.\n        embedding_dropout=0.0,\n        attention_dropout=0.0,\n        output_dropout=0.0,\n        initializer_range=0.02,\n        max_position_embeddings=32768,\n        rope_theta=600000.0,\n        use_cache=True,\n        max_window_layers=20,\n        rope_scaling=None,\n        
pad_token_id=156892,\n        eos_token_id=156892,\n        num_experts=256,\n        num_shared_experts=1,\n        num_experts_per_tok=8,\n        n_group=8,\n        topk_group=4,\n        moe_intermediate_size=512,\n        first_k_dense_replace=1,\n        head_dim=128,\n        output_router_logits=False,\n        use_qk_norm=True,\n        num_nextn_predict_layers=0,\n        mtp_loss_scaling_factor=0,\n        moe_router_enable_expert_bias=True,\n        routed_scaling_factor=1.0,\n        layer_group_size=1,\n        group_norm_size=1,\n        linear_silu=False,\n        kv_lora_rank=512,\n        q_lora_rank=None,\n        qk_rope_head_dim=64,\n        v_head_dim=128,\n        qk_nope_head_dim=128,\n        rope_interleave=True,\n        **kwargs,\n    ):\n        self.num_hidden_layers = num_hidden_layers\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_attention_heads = num_attention_heads\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.use_qkv_bias = use_qkv_bias\n        self.use_bias = use_bias\n        self.rms_norm_eps = rms_norm_eps\n        self.embedding_dropout = embedding_dropout\n        self.attention_dropout = attention_dropout\n        self.output_dropout = output_dropout\n        self.num_nextn_predict_layers = num_nextn_predict_layers\n        self.mtp_loss_scaling_factor = mtp_loss_scaling_factor\n        self.initializer_range = initializer_range\n        self.max_position_embeddings = max_position_embeddings\n        self.rope_theta = rope_theta\n        self.use_cache = use_cache\n        self.max_window_layers = max_window_layers\n        self.head_dim = head_dim or self.hidden_size // self.num_attention_heads\n        self.rope_scaling = rope_scaling\n        self.use_qk_norm = use_qk_norm\n        self.moe_router_enable_expert_bias = moe_router_enable_expert_bias\n        
self.routed_scaling_factor = routed_scaling_factor\n\n        # MoE configs\n        self.num_experts = num_experts\n        self.num_shared_experts = num_shared_experts\n        self.num_experts_per_tok = num_experts_per_tok\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.moe_intermediate_size = moe_intermediate_size\n        self.first_k_dense_replace = first_k_dense_replace\n        self.output_router_logits = output_router_logits\n\n        # Linear configs\n        self.layer_group_size = layer_group_size\n        self.group_norm_size = group_norm_size\n        self.linear_silu = linear_silu\n        self.num_linear_key_value_heads = num_attention_heads\n        # mla\n        self.kv_lora_rank = kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim\n        self.rope_interleave = rope_interleave\n        self.for_nextn_model = False\n        super().__init__(\n            pad_token_id=pad_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    @property\n    def layers_block_type(self):\n        if self.for_nextn_model:\n            return [HybridLayerType.full_attention.value]\n\n        layer_type_list = []\n\n        for l in range(self.num_hidden_layers):\n            if (l + 1) % self.layer_group_size == 0:\n                layer_type_list.append(HybridLayerType.full_attention.value)\n            else:\n                layer_type_list.append(HybridLayerType.linear_attention.value)\n\n        return layer_type_list\n\n    @property\n    def linear_layer_ids(self):\n        return [\n            i\n            for i, type_value in enumerate(self.layers_block_type)\n            if type_value == HybridLayerType.linear_attention.value\n    
    ]\n\n    @property\n    def full_attention_layer_ids(self):\n        return [\n            i\n            for i, type_value in enumerate(self.layers_block_type)\n            if type_value == HybridLayerType.full_attention.value\n        ]\n\n    @property\n    def mamba2_cache_params(self) -> Mamba2CacheParams:\n        from sglang.srt.layers.dp_attention import get_attention_tp_size\n\n        shape = Mamba2StateShape.create(\n            tp_world_size=get_attention_tp_size(),\n            intermediate_size=0,\n            n_groups=0,\n            num_heads=self.num_linear_key_value_heads,\n            head_dim=self.head_dim,\n            state_size=self.head_dim,\n            conv_kernel=1,\n        )\n\n        return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)\n"
  },
  {
    "path": "python/sglang/srt/configs/chatglm.py",
    "content": "# Adapted from\n# https://github.com/THUDM/ChatGLM2-6B\n# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/chatglm.py\n\n# ChatGLM2 and ChatGLM3 share the same config.\n# ChatGLM4 is officially supported by Huggingface\n# transformers >= 4.46.0 is required\n# https://huggingface.co/docs/transformers/en/model_doc/glm\nfrom transformers import PretrainedConfig\n\n\nclass ChatGLMConfig(PretrainedConfig):\n    model_type = \"chatglm\"\n    attribute_map = {\n        \"num_hidden_layers\": \"num_layers\",\n        \"n_head_kv\": \"multi_query_group_num\",\n    }\n\n    def __init__(\n        self,\n        num_layers=28,\n        padded_vocab_size=65024,\n        hidden_size=4096,\n        ffn_hidden_size=13696,\n        kv_channels=128,\n        num_attention_heads=32,\n        seq_length=2048,\n        hidden_dropout=0.0,\n        attention_dropout=0.0,\n        layernorm_epsilon=1e-5,\n        rmsnorm=True,\n        apply_residual_connection_post_layernorm=False,\n        post_layer_norm=True,\n        add_bias_linear=False,\n        add_qkv_bias=False,\n        interleaved_qkv=False,\n        bias_dropout_fusion=True,\n        multi_query_attention=False,\n        multi_query_group_num=1,\n        apply_query_key_layer_scaling=True,\n        attention_softmax_in_fp32=True,\n        fp32_residual_connection=False,\n        quantization_bit=0,\n        pre_seq_len=None,\n        prefix_projection=False,\n        **kwargs\n    ):\n        self.num_layers = num_layers\n        self.vocab_size = padded_vocab_size\n        self.padded_vocab_size = padded_vocab_size\n        self.hidden_size = hidden_size\n        self.ffn_hidden_size = ffn_hidden_size\n        self.kv_channels = kv_channels\n        self.num_attention_heads = num_attention_heads\n        self.seq_length = seq_length\n        # It is to be compatible with long lora.\n        self.max_position_embeddings = seq_length\n        self.hidden_dropout = 
hidden_dropout\n        self.attention_dropout = attention_dropout\n        self.layernorm_epsilon = layernorm_epsilon\n        self.rmsnorm = rmsnorm\n        self.apply_residual_connection_post_layernorm = (\n            apply_residual_connection_post_layernorm\n        )\n        self.post_layer_norm = post_layer_norm\n        self.add_bias_linear = add_bias_linear\n        self.add_qkv_bias = add_qkv_bias\n        self.bias_dropout_fusion = bias_dropout_fusion\n        self.multi_query_attention = multi_query_attention\n        self.multi_query_group_num = multi_query_group_num\n        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling\n        self.attention_softmax_in_fp32 = attention_softmax_in_fp32\n        self.fp32_residual_connection = fp32_residual_connection\n        self.quantization_bit = quantization_bit\n        self.pre_seq_len = pre_seq_len\n        self.prefix_projection = prefix_projection\n        self.interleaved_qkv = interleaved_qkv\n        super().__init__(**kwargs)\n"
  },
  {
    "path": "python/sglang/srt/configs/dbrx.py",
    "content": "# Adapted from\n# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py\n# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/dbrx.py\n\"\"\"Dbrx configuration.\"\"\"\n\nfrom typing import Any, Optional\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nDBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}  # type: ignore\n\n\nclass DbrxAttentionConfig(PretrainedConfig):\n    \"\"\"Configuration class for Dbrx Attention.\n\n    [`DbrxAttention`] class. It is used to instantiate attention layers\n    according to the specified arguments, defining the layers architecture.\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n    Args:\n        attn_pdrop (`float`, *optional*, defaults to 0.0):\n            The dropout probability for the attention layers.\n        clip_qkv (`float`, *optional*, defaults to None):\n            If not `None`, clip the queries, keys, and values in the attention layer to this value.\n        kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.\n        rope_theta (float): The base frequency for rope.\n    \"\"\"\n\n    def __init__(\n        self,\n        attn_pdrop: float = 0,\n        clip_qkv: Optional[float] = None,\n        kv_n_heads: int = 1,\n        rope_theta: float = 10000.0,\n        **kwargs: Any,\n    ):\n        super().__init__(**kwargs)\n        self.attn_pdrop = attn_pdrop\n        self.clip_qkv = clip_qkv\n        self.kv_n_heads = kv_n_heads\n        self.rope_theta = rope_theta\n\n        for k in [\"model_type\"]:\n            if k in kwargs:\n                kwargs.pop(k)\n        if len(kwargs) != 0:\n            raise ValueError(f\"Found unknown {kwargs=}\")\n\n    @classmethod\n    
def from_pretrained(\n        cls, pretrained_model_name_or_path: str, **kwargs: Any\n    ) -> \"PretrainedConfig\":\n        cls._set_token_in_kwargs(kwargs)\n\n        config_dict, kwargs = cls.get_config_dict(\n            pretrained_model_name_or_path, **kwargs\n        )\n\n        if config_dict.get(\"model_type\") == \"dbrx\":\n            config_dict = config_dict[\"attn_config\"]\n\n        if (\n            \"model_type\" in config_dict\n            and hasattr(cls, \"model_type\")\n            and config_dict[\"model_type\"] != cls.model_type\n        ):\n            logger.warning(\n                \"You are using a model of type %s to instantiate a model of \"\n                \"type %s. This is not supported for all configurations of \"\n                \"models and can yield errors.\",\n                config_dict[\"model_type\"],\n                cls.model_type,\n            )\n\n        return cls.from_dict(config_dict, **kwargs)\n\n\nclass DbrxFFNConfig(PretrainedConfig):\n    \"\"\"Configuration class for Dbrx FFN.\n\n    [`DbrxFFN`] class. It is used to instantiate feedforward layers according to\n    the specified arguments, defining the layers architecture.\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n    Args:\n        ffn_act_fn (dict, optional): A dict specifying activation function for the FFN.\n            The dict should have a key 'name' with the value being the name of\n            the activation function along with any additional keyword arguments.\n        ffn_hidden_size (int, optional): The hidden size of the feedforward network.\n        moe_num_experts (int, optional): The number of experts in the mixture of experts layer.\n        moe_top_k (int, optional): The number of experts to use in the mixture of experts layer.\n        moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer.\n        moe_loss_weight (float, optional): The loss weight for the mixture of experts layer.\n        moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights.\n        uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment.\n            This should only be used for benchmarking purposes.\n    \"\"\"\n\n    def __init__(\n        self,\n        ffn_act_fn: Optional[dict] = None,\n        ffn_hidden_size: int = 3584,\n        moe_num_experts: int = 4,\n        moe_top_k: int = 1,\n        moe_jitter_eps: Optional[float] = None,\n        moe_loss_weight: float = 0.01,\n        moe_normalize_expert_weights: Optional[float] = 1,\n        uniform_expert_assignment: bool = False,\n        **kwargs: Any,\n    ):\n        super().__init__()\n        if ffn_act_fn is None:\n            ffn_act_fn = {\"name\": \"silu\"}\n        self.ffn_act_fn = ffn_act_fn\n        self.ffn_hidden_size = ffn_hidden_size\n        self.moe_num_experts = moe_num_experts\n        self.moe_top_k = moe_top_k\n        self.moe_jitter_eps = moe_jitter_eps\n        self.moe_loss_weight = moe_loss_weight\n        self.moe_normalize_expert_weights = moe_normalize_expert_weights\n        self.uniform_expert_assignment = 
uniform_expert_assignment\n\n        for k in [\"model_type\"]:\n            if k in kwargs:\n                kwargs.pop(k)\n        if len(kwargs) != 0:\n            raise ValueError(f\"Found unknown {kwargs=}\")\n\n    @classmethod\n    def from_pretrained(\n        cls, pretrained_model_name_or_path: str, **kwargs: Any\n    ) -> \"PretrainedConfig\":\n        cls._set_token_in_kwargs(kwargs)\n\n        config_dict, kwargs = cls.get_config_dict(\n            pretrained_model_name_or_path, **kwargs\n        )\n\n        if config_dict.get(\"model_type\") == \"dbrx\":\n            config_dict = config_dict[\"ffn_config\"]\n\n        if (\n            \"model_type\" in config_dict\n            and hasattr(cls, \"model_type\")\n            and config_dict[\"model_type\"] != cls.model_type\n        ):\n            logger.warning(\n                \"You are using a model of type %s to instantiate a model of \"\n                \"type %s. This is not supported for all \"\n                \"configurations of models and can yield errors.\",\n                config_dict[\"model_type\"],\n                cls.model_type,\n            )\n\n        return cls.from_dict(config_dict, **kwargs)\n\n\nclass DbrxConfig(PretrainedConfig):\n    \"\"\"Configuration class for Dbrx.\n\n    [`DbrxModel`]. It is used to instantiate a Dbrx model according to the\n    specified arguments, defining the model architecture.\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        d_model (`int`, *optional*, defaults to 6144):\n            Dimensionality of the embeddings and hidden states.\n        n_heads (`int`, *optional*, defaults to 48):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        n_layers (`int`, *optional*, defaults to 40):\n            Number of hidden layers in the Transformer encoder.\n        max_seq_len (`int`, *optional*, defaults to 32768):\n            The maximum sequence length of the model.\n        vocab_size (`int`, *optional*, defaults to 100352):\n            Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by\n            the `inputs_ids` passed when calling [`DbrxModel`].\n        resid_pdrop (`float`, *optional*, defaults to 0.0):\n            The dropout probability applied to the attention output before combining with residual.\n        emb_pdrop (`float`, *optional*, defaults to 0.0):\n            The dropout probability for the embedding layer.\n        attn_config (`dict`, *optional*):\n            A dictionary used to configure the model's attention module.\n        ffn_config (`dict`, *optional*):\n            A dictionary used to configure the model's FFN module.\n        use_cache (`bool`, *optional*, defaults to `False`):\n            Whether or not the model should return the last key/values attentions (not used by all models).\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        output_router_logits (`bool`, *optional*, defaults to `False`):\n            Whether or not the router logits should be returned by the model. Enabling this will also\n            allow the model to output the auxiliary loss. 
See [here]() for more details\n        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):\n            The aux loss factor for the total loss.\n\n\n    Example:\n    ```python\n    >>> from transformers import DbrxConfig, DbrxModel\n\n    >>> # Initializing a Dbrx configuration\n    >>> configuration = DbrxConfig()\n\n    >>> # Initializing a model (with random weights) from the configuration\n    >>> model = DbrxModel(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\n    \"\"\"\n\n    model_type = \"dbrx\"\n    attribute_map = {\n        \"num_attention_heads\": \"n_heads\",\n        \"hidden_size\": \"d_model\",\n        \"num_hidden_layers\": \"n_layers\",\n        \"max_position_embeddings\": \"max_seq_len\",\n    }\n\n    def __init__(\n        self,\n        d_model: int = 2048,\n        n_heads: int = 16,\n        n_layers: int = 24,\n        max_seq_len: int = 2048,\n        vocab_size: int = 32000,\n        resid_pdrop: float = 0.0,\n        emb_pdrop: float = 0.0,\n        attn_config: Optional[DbrxAttentionConfig] = None,\n        ffn_config: Optional[DbrxFFNConfig] = None,\n        use_cache: bool = True,\n        initializer_range: float = 0.02,\n        output_router_logits: bool = False,\n        router_aux_loss_coef: float = 0.05,\n        **kwargs: Any,\n    ):\n        if attn_config is None:\n            self.attn_config = DbrxAttentionConfig()\n        elif isinstance(attn_config, dict):\n            self.attn_config = DbrxAttentionConfig(**attn_config)\n        else:\n            self.attn_config = attn_config\n\n        if ffn_config is None:\n            self.ffn_config = DbrxFFNConfig()\n        elif isinstance(ffn_config, dict):\n            self.ffn_config = DbrxFFNConfig(**ffn_config)\n        else:\n            self.ffn_config = ffn_config\n\n        self.d_model = d_model\n        self.n_heads = n_heads\n        self.n_layers = n_layers\n        self.max_seq_len = 
max_seq_len\n        self.vocab_size = vocab_size\n        self.resid_pdrop = resid_pdrop\n        self.emb_pdrop = emb_pdrop\n        self.use_cache = use_cache\n        self.initializer_range = initializer_range\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = router_aux_loss_coef\n\n        tie_word_embeddings = kwargs.pop(\"tie_word_embeddings\", False)\n        if tie_word_embeddings:\n            raise ValueError(\"tie_word_embeddings is not supported for Dbrx models.\")\n\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n"
  },
  {
    "path": "python/sglang/srt/configs/deepseek_ocr.py",
    "content": "import math\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport torch\nfrom PIL import Image, ImageOps\nfrom transformers import (\n    AutoProcessor,\n    LlamaTokenizerFast,\n    PretrainedConfig,\n    ProcessorMixin,\n)\n\nfrom sglang.srt.multimodal.customized_mm_processor_utils import (\n    register_customized_processor,\n)\nfrom sglang.srt.sampling.custom_logit_processor import (\n    DeepseekOCRNoRepeatNGramLogitProcessor,\n)\n\nBASE_SIZE = 1024\nIMAGE_SIZE = 640\nCROP_MODE = True\nMIN_CROPS = 2\nMAX_CROPS = 6  # max:9; If your GPU memory is small, it is recommended to set it to 6.\nMAX_CONCURRENCY = 100  # If you have limited GPU memory, lower the concurrency count.\nNUM_WORKERS = 64  # image pre-process (resize/padding) workers\nPRINT_NUM_VIS_TOKENS = False\nSKIP_REPEAT = True\nMODEL_PATH = \"deepseek-ai/DeepSeek-OCR\"  # change to your model path\n\nNGRAM_NO_REPEAT_SIZE = 30\nNGRAM_NO_REPEAT_WINDOW = 90\n# Whitelist `<td>` and `</td>` token ids to allow table structures.\nNGRAM_NO_REPEAT_WHITELIST = (128821, 128822)\n\nDEFAULT_CUSTOM_LOGIT_PROCESSOR = DeepseekOCRNoRepeatNGramLogitProcessor.to_str()\n\n\ndef get_default_ngram_custom_params() -> Dict[str, Any]:\n    \"\"\"Return default custom params for the DeepSeek-OCR n-gram no repeat processor.\"\"\"\n\n    return {\n        \"ngram_size\": NGRAM_NO_REPEAT_SIZE,\n        \"window_size\": NGRAM_NO_REPEAT_WINDOW,\n        \"whitelist_token_ids\": list(NGRAM_NO_REPEAT_WHITELIST),\n    }\n\n\nPROMPT = \"<image>\\n<|grounding|>Convert the document to markdown.\"\n\n\nclass DictOutput(object):\n    def items(self):\n        return self.__dict__.items()\n\n    def keys(self):\n        return self.__dict__.keys()\n\n    def __getitem__(self, item):\n        return self.__dict__[item]\n\n    def __contains__(self, key):\n        return key in self.__dict__\n\n    def __setitem__(self, key, value):\n        self.__dict__[key] = 
value\n\n\n@dataclass\nclass VLChatProcessorOutput(DictOutput):\n    input_ids: torch.LongTensor\n    target_ids: torch.LongTensor\n    images_crop: torch.LongTensor\n    pixel_values: (\n        torch.Tensor\n    )  # rename from \"images\" to \"pixel_values\" for compatibility\n    images_seq_mask: torch.BoolTensor\n    images_spatial_crop: torch.LongTensor\n\n    def __len__(self):\n        return len(self.input_ids)\n\n\nclass ImageTransform(object):\n    def __init__(\n        self,\n        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),\n        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),\n        normalize: bool = True,\n    ):\n        self.mean = mean\n        self.std = std\n        self.normalize = normalize\n\n        # only load torchvision.transforms when needed\n        try:\n            import torchvision.transforms as T\n\n            # FIXME: add version check for gguf\n        except ImportError as err:\n            raise ImportError(\n                \"Please install torchvision via `pip install torchvision` to use Deepseek-VL2.\"\n            ) from err\n\n        transform_pipelines = [T.ToTensor()]\n\n        if normalize:\n            transform_pipelines.append(T.Normalize(mean, std))\n\n        self.transform = T.Compose(transform_pipelines)\n\n    def __call__(self, pil_img: Image.Image):\n        x = self.transform(pil_img)\n        return x\n\n\ndef find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):\n    best_ratio_diff = float(\"inf\")\n    best_ratio = (1, 1)\n    area = width * height\n    for ratio in target_ratios:\n        target_aspect_ratio = ratio[0] / ratio[1]\n        ratio_diff = abs(aspect_ratio - target_aspect_ratio)\n        if ratio_diff < best_ratio_diff:\n            best_ratio_diff = ratio_diff\n            best_ratio = ratio\n        elif ratio_diff == best_ratio_diff:\n            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:\n           
     best_ratio = ratio\n    return best_ratio\n\n\ndef dynamic_preprocess(\n    image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False\n):\n    orig_width, orig_height = image.size\n    aspect_ratio = orig_width / orig_height\n\n    # calculate the existing image aspect ratio\n    target_ratios = set(\n        (i, j)\n        for n in range(min_num, max_num + 1)\n        for i in range(1, n + 1)\n        for j in range(1, n + 1)\n        if i * j <= max_num and i * j >= min_num\n    )\n    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])\n\n    # find the closest aspect ratio to the target\n    target_aspect_ratio = find_closest_aspect_ratio(\n        aspect_ratio, target_ratios, orig_width, orig_height, image_size\n    )\n\n    # calculate the target width and height\n    target_width = image_size * target_aspect_ratio[0]\n    target_height = image_size * target_aspect_ratio[1]\n    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]\n\n    # resize the image\n    resized_img = image.resize((target_width, target_height))\n    processed_images = []\n    for i in range(blocks):\n        box = (\n            (i % (target_width // image_size)) * image_size,\n            (i // (target_width // image_size)) * image_size,\n            ((i % (target_width // image_size)) + 1) * image_size,\n            ((i // (target_width // image_size)) + 1) * image_size,\n        )\n        # split the image\n        split_img = resized_img.crop(box)\n        processed_images.append(split_img)\n    assert len(processed_images) == blocks\n    if use_thumbnail and len(processed_images) != 1:\n        thumbnail_img = image.resize((image_size, image_size))\n        processed_images.append(thumbnail_img)\n    return processed_images, target_aspect_ratio\n\n\nclass DeepseekOCRProcessor(ProcessorMixin):\n    tokenizer_class = (\"LlamaTokenizer\", \"LlamaTokenizerFast\")\n    attributes = [\"tokenizer\"]\n\n    def __init__(\n        self,\n      
  tokenizer: LlamaTokenizerFast,\n        candidate_resolutions: Tuple[Tuple[int, int]],\n        patch_size: int,\n        downsample_ratio: int,\n        image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),\n        image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),\n        normalize: bool = True,\n        image_token: str = \"<image>\",\n        pad_token: str = \"<｜▁pad▁｜>\",\n        add_special_token: bool = False,\n        sft_format: str = \"deepseek\",\n        mask_prompt: bool = True,\n        ignore_id: int = -100,\n        ocr2_mode: bool = False,\n        **kwargs,\n    ):\n\n        self.candidate_resolutions = candidate_resolutions\n        self.image_size = candidate_resolutions[0][0]\n        self.patch_size = patch_size\n        self.image_mean = image_mean\n        self.image_std = image_std\n        self.normalize = normalize\n        self.downsample_ratio = downsample_ratio\n        self.base_size = BASE_SIZE\n        self.image_transform = ImageTransform(\n            mean=image_mean, std=image_std, normalize=normalize\n        )\n        self.tokenizer = tokenizer\n        # must set this，padding side with make a difference in batch inference\n        self.tokenizer.padding_side = \"left\"\n\n        # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'\n        if tokenizer.pad_token is None:\n            self.tokenizer.add_special_tokens({\"pad_token\": pad_token})\n\n        # add image token\n        image_token_id = self.tokenizer.vocab.get(image_token)\n        if image_token_id is None:\n            special_tokens = [image_token]\n            special_tokens_dict = {\"additional_special_tokens\": special_tokens}\n            self.tokenizer.add_special_tokens(special_tokens_dict)\n        self.image_token_id = self.tokenizer.vocab.get(image_token)\n\n        # add five special tokens for grounding-related tasks\n        # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>\n        
special_tokens = [\"<|ref|>\", \"<|/ref|>\", \"<|det|>\", \"<|/det|>\", \"<|grounding|>\"]\n        special_tokens_dict = {\"additional_special_tokens\": special_tokens}\n        self.tokenizer.add_special_tokens(special_tokens_dict)\n\n        # add special tokens for SFT data\n        special_tokens = [\"<|User|>\", \"<|Assistant|>\"]\n        special_tokens_dict = {\"additional_special_tokens\": special_tokens}\n        self.tokenizer.add_special_tokens(special_tokens_dict)\n\n        self.image_token = image_token\n        self.pad_token = pad_token\n        self.add_special_token = add_special_token\n        self.sft_format = sft_format\n        self.mask_prompt = mask_prompt\n        self.ignore_id = ignore_id\n        self.ocr2_mode = ocr2_mode\n\n        super().__init__(\n            tokenizer,\n            **kwargs,\n        )\n\n    def format_messages_v2(self, messages: str, pil_images, max_req_input_len=-1):\n        \"\"\"play the role of format_messages_v2 and get_images_info in the last version\"\"\"\n        tokenized_data = []\n        masked_tokenized_data = []  # labels\n        images_list = []\n        images_seq_mask = []\n        images_spatial_crop = []\n\n        image_index = 0\n        image_token_cnt = messages.count(self.image_token)\n        (\n            input_ids,\n            images,\n            images_crop,\n            seq_mask,\n            spatial_crop,\n            num_image_tokens,\n            image_shapes,\n        ) = self.tokenize_with_images(\n            messages,\n            pil_images[image_index : image_index + image_token_cnt],\n            bos=True,\n            eos=True,\n            cropping=len(pil_images) <= 2,\n        )\n\n        image_index = image_token_cnt\n        images_list += images\n        images_seq_mask += seq_mask\n        images_spatial_crop = spatial_crop\n\n        return (\n            input_ids,\n            masked_tokenized_data,\n            images_list,\n            images_seq_mask,\n  
          images_spatial_crop,\n            images_crop,\n        )\n\n    @property\n    def bos_id(self):\n        return self.tokenizer.bos_token_id\n\n    @property\n    def eos_id(self):\n        return self.tokenizer.eos_token_id\n\n    @property\n    def pad_id(self):\n        return self.tokenizer.pad_token_id\n\n    def encode(self, text: str, bos: bool = True, eos: bool = False):\n        t = self.tokenizer.encode(text, add_special_tokens=False)\n\n        if bos:\n            t = [self.bos_id] + t\n        if eos:\n            t = t + [self.eos_id]\n\n        return t\n\n    def decode(self, t: List[int], **kwargs) -> str:\n        return self.tokenizer.decode(t, **kwargs)\n\n    def process_one(\n        self,\n        prompt: str = None,\n        conversations: List[Dict[str, str]] = None,\n        images: List[Image.Image] = None,\n        apply_sft_format: bool = False,\n        inference_mode: bool = True,\n        system_prompt: str = \"\",\n        max_req_input_len: int = -1,\n        cropping: bool = True,\n        **kwargs,\n    ):\n        \"\"\"\n\n        Args:\n            prompt (str): the formatted prompt;\n            conversations (List[Dict]): conversations with a list of messages;\n            images (List[ImageType]): the list of images;\n            apply_sft_format (bool): if prompt is not None, then apply the SFT format to prompt;\n                if conversations is not None, then it will always apply the SFT format to conversations;\n            inference_mode (bool): if True, then remove the last eos token;\n            system_prompt (str): the system prompt;\n            **kwargs:\n\n        Returns:\n            outputs (BaseProcessorOutput): the output of the processor,\n                - input_ids (torch.LongTensor): [N + image tokens]\n                - target_ids (torch.LongTensor): [N + image tokens]\n                - images (torch.FloatTensor): [n_images, 3, H, W]\n                - image_id (int): the id of the image 
token\n                - num_image_tokens (List[int]): the number of image tokens\n        \"\"\"\n\n        prompt = conversations or prompt\n        (\n            input_ids,\n            masked_tokenized_str,\n            images_list,\n            images_seq_mask,\n            images_spatial_crop,\n            images_crop,\n        ) = self.format_messages_v2(prompt, images, max_req_input_len)\n\n        target_ids = torch.LongTensor(masked_tokenized_str)\n\n        has_images = len(images_list) > 0\n        has_local_crops = False\n        if len(images_spatial_crop) > 0:\n            has_local_crops = any(\n                crop[0] > 1 or crop[1] > 1 for crop in images_spatial_crop\n            )\n\n        if len(images_list) == 0:\n            images = torch.zeros((1, 3, self.image_size, self.image_size))\n        else:\n            images = torch.stack(images_list, dim=0)\n\n        images_spatial_crop = torch.stack(\n            [images_spatial_crop], dim=0\n        )  # stack the tensor to make it a batch of 1\n\n        prepare = VLChatProcessorOutput(\n            input_ids=input_ids,\n            target_ids=target_ids,\n            images_crop=images_crop,\n            pixel_values=images,\n            images_seq_mask=images_seq_mask,\n            images_spatial_crop=images_spatial_crop,\n        )\n        prepare.has_images = has_images\n        prepare.has_local_crops = has_local_crops\n\n        return prepare\n\n    def __call__(\n        self,\n        *,\n        prompt: str = None,\n        conversations: List[Dict[str, str]] = None,\n        images: List[Image.Image] = None,\n        apply_sft_format: bool = False,\n        inference_mode: bool = True,\n        system_prompt: str = \"\",\n        max_req_input_len: int = -1,\n        text: list[str] = None,\n        **kwargs,\n    ):\n        assert text is None or isinstance(text, list)\n        if text is not None:\n            text = text[0]\n        prepare = self.process_one(\n            
prompt=prompt or text,\n            conversations=conversations,\n            images=images,\n            apply_sft_format=apply_sft_format,\n            inference_mode=inference_mode,\n            system_prompt=system_prompt,\n            max_req_input_len=max_req_input_len,\n        )\n\n        return prepare\n\n    def find_all_indices(self, messages, target_value):\n        indices = []\n        for index, item in enumerate(messages):\n            if item == target_value:\n                indices.append(index)\n        return indices\n\n    def tokenize_with_images(\n        self,\n        conversation: str,\n        images: List[Image.Image],\n        bos: bool = True,\n        eos: bool = True,\n        cropping: bool = True,\n    ):\n        \"\"\"Tokenize text with <image> tags.\"\"\"\n\n        conversation = conversation\n        assert conversation.count(self.image_token) == len(images)\n        text_splits = conversation.split(self.image_token)\n        images_list, images_crop_list, images_seq_mask, images_spatial_crop = (\n            [],\n            [],\n            [],\n            [],\n        )\n        image_shapes = []\n        num_image_tokens = []\n        tokenized_str = []\n        for text_sep, image in zip(text_splits, images):\n            \"\"\"encode text_sep\"\"\"\n            tokenized_sep = self.encode(text_sep, bos=False, eos=False)\n\n            tokenized_str += tokenized_sep\n            images_seq_mask += [False] * len(tokenized_sep)\n\n            image_shapes.append(image.size)\n\n            if image.size[0] <= 640 and image.size[1] <= 640:\n                crop_ratio = [1, 1]\n            else:\n                if cropping:\n                    images_crop_raw, crop_ratio = dynamic_preprocess(\n                        image, image_size=IMAGE_SIZE\n                    )\n                else:\n                    crop_ratio = [1, 1]\n\n            \"\"\"process the global view\"\"\"\n            if self.image_size <= 640 
and not cropping:\n                image = image.resize((self.image_size, self.image_size))\n\n            global_view = ImageOps.pad(\n                image,\n                (self.base_size, self.base_size),\n                color=tuple(int(x * 255) for x in self.image_transform.mean),\n            )\n            images_list.append(self.image_transform(global_view))\n\n            num_width_tiles, num_height_tiles = crop_ratio\n            images_spatial_crop.append([num_width_tiles, num_height_tiles])\n\n            if num_width_tiles > 1 or num_height_tiles > 1:\n                for i in range(len(images_crop_raw)):\n                    images_crop_list.append(self.image_transform(images_crop_raw[i]))\n\n            \"\"\"add image tokens\"\"\"\n            num_queries = math.ceil(\n                (self.image_size // self.patch_size) / self.downsample_ratio\n            )\n            num_queries_base = math.ceil(\n                (self.base_size // self.patch_size) / self.downsample_ratio\n            )\n\n            if self.ocr2_mode:\n                tokenized_image = []\n                if num_width_tiles > 1 or num_height_tiles > 1:\n                    tokenized_image += [self.image_token_id] * (\n                        num_queries * num_width_tiles * num_queries * num_height_tiles\n                    )\n                tokenized_image += [self.image_token_id] * (\n                    num_queries_base * num_queries_base\n                )\n                # One extra token for the view separator.\n                tokenized_image += [self.image_token_id]\n            else:\n                tokenized_image = (\n                    [self.image_token_id] * num_queries_base + [self.image_token_id]\n                ) * num_queries_base\n                tokenized_image += [self.image_token_id]\n                if num_width_tiles > 1 or num_height_tiles > 1:\n                    tokenized_image += (\n                        [self.image_token_id] * 
(num_queries * num_width_tiles)\n                        + [self.image_token_id]\n                    ) * (num_queries * num_height_tiles)\n            tokenized_str += tokenized_image\n\n            images_seq_mask += [True] * len(tokenized_image)\n            num_image_tokens.append(len(tokenized_image))\n\n        \"\"\"process the last text split\"\"\"\n        tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)\n\n        tokenized_str += tokenized_sep\n        images_seq_mask += [False] * len(tokenized_sep)\n\n        \"\"\"add the bos and eos tokens\"\"\"\n        if bos:\n            tokenized_str = [self.bos_id] + tokenized_str\n            images_seq_mask = [False] + images_seq_mask\n        if eos:\n            tokenized_str = tokenized_str + [self.eos_id]\n            images_seq_mask = images_seq_mask + [False]\n\n        assert len(tokenized_str) == len(\n            images_seq_mask\n        ), f\"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to images_seq_mask's length {len(images_seq_mask)}\"\n\n        masked_tokenized_str = []\n        for token_index in tokenized_str:\n            if token_index != self.image_token_id:\n                masked_tokenized_str.append(token_index)\n            else:\n                masked_tokenized_str.append(self.ignore_id)\n\n        assert (\n            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)\n        ), (\n            f\"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, \"\n            f\"images_seq_mask's length {len(images_seq_mask)}, are not equal\"\n        )\n        input_ids = torch.LongTensor(tokenized_str)\n        target_ids = torch.LongTensor(masked_tokenized_str)\n        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)\n\n        # set input_ids < 0 | input_ids == self.image_token_id as ignore_id\n        target_ids[(input_ids < 0) | (input_ids == 
self.image_token_id)] = (\n            self.ignore_id\n        )\n        input_ids[input_ids < 0] = self.pad_id\n\n        inference_mode = True\n\n        if inference_mode:\n            # Remove the ending eos token\n            assert input_ids[-1] == self.eos_id\n            input_ids = input_ids[:-1]\n            target_ids = target_ids[:-1]\n            images_seq_mask = images_seq_mask[:-1]\n\n        if len(images_list) == 0:\n            pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))\n            images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)\n            images_crop = torch.zeros(\n                (1, 3, self.image_size, self.image_size)\n            ).unsqueeze(0)\n        else:\n            pixel_values = torch.stack(images_list, dim=0)\n            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)\n            if images_crop_list:\n                images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)\n            else:\n                images_crop = torch.zeros(\n                    (1, 3, self.image_size, self.image_size)\n                ).unsqueeze(0)\n\n        input_ids = input_ids.unsqueeze(0)\n        return (\n            input_ids,\n            pixel_values,\n            images_crop,\n            images_seq_mask,\n            images_spatial_crop,\n            num_image_tokens,\n            image_shapes,\n        )\n\n\nclass VisionEncoderConfig(PretrainedConfig):\n    model_type: str = \"vision\"\n\n    model_name: str = \"vit_so400m_patch14_siglip_384.webli\"\n    image_size: int = 384\n    patch_size: int = 16\n    width: int = 1024\n    layers: int = 24\n    heads: int = 16\n    mlp_ratio: int = 4\n    global_pool: str = \"map\"\n    ignore_head: bool = True\n    class_token: bool = False\n    num_classes: int = 0\n    use_checkpoint: bool = False\n    weight_init: str = \"skip\"\n    deterministic: bool = False\n    num_recomputing_layers: int = 0\n\n    def __init__(\n    
    self,\n        model_name: str = \"vit_so400m_patch14_siglip_384.webli\",\n        image_size: int = 384,\n        patch_size: int = 16,\n        width: int = 1024,\n        layers: int = 24,\n        heads: int = 16,\n        mlp_ratio: int = 4,\n        global_pool: str = \"map\",\n        ignore_head: bool = True,\n        class_token: bool = False,\n        num_classes: int = 0,\n        use_checkpoint: bool = False,\n        **kwargs,\n    ):\n        self.model_name = model_name\n        self.image_size = image_size\n        self.patch_size = patch_size\n        self.width = width\n        self.layers = layers\n        self.heads = heads\n        self.mlp_ratio = mlp_ratio\n        self.global_pool = global_pool\n        self.ignore_head = ignore_head\n        self.class_token = class_token\n        self.num_classes = num_classes\n        self.use_checkpoint = use_checkpoint\n\n        super().__init__(**kwargs)\n\n\nclass MlpProjectorConfig(PretrainedConfig):\n    model_type = \"mlp_projector\"\n    projector_type: str = \"downsample_mlp_gelu\"\n    input_dim: int = 1152\n    n_embed: int = 2048\n    depth: int = 2\n    mlp_ratio: int = 1\n    downsample_ratio: int = 2\n    token_pooling: bool = False\n\n    def __init__(\n        self,\n        projector_type: str = \"downsample_mlp_gelu\",\n        input_dim: int = 1152,\n        n_embed: int = 2048,\n        depth: int = 2,\n        mlp_ratio: int = 1,\n        downsample_ratio: int = 2,\n        **kwargs,\n    ):\n        self.projector_type = projector_type\n        self.input_dim = input_dim\n        self.n_embed = n_embed\n        self.depth = depth\n        self.mlp_ratio = mlp_ratio\n        self.downsample_ratio = downsample_ratio\n\n        super().__init__(**kwargs)\n\n\nclass DeepseekV2Config(PretrainedConfig):\n    model_type = \"deepseek_v2\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=102400,\n        hidden_size=4096,\n 
       intermediate_size=11008,\n        moe_intermediate_size=1407,\n        num_hidden_layers=30,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        n_shared_experts=None,\n        n_routed_experts=None,\n        ep_size=1,\n        routed_scaling_factor=1.0,\n        kv_lora_rank=512,\n        q_lora_rank=1536,\n        qk_rope_head_dim=64,\n        v_head_dim=128,\n        qk_nope_head_dim=128,\n        topk_method=\"gready\",\n        n_group=None,\n        topk_group=None,\n        num_experts_per_tok=None,\n        moe_layer_freq=1,\n        first_k_dense_replace=0,\n        norm_topk_prob=False,\n        scoring_func=\"softmax\",\n        aux_loss_alpha=0.001,\n        seq_aux=True,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=100000,\n        eos_token_id=100001,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        use_mla=True,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.n_shared_experts = n_shared_experts\n        self.n_routed_experts = n_routed_experts\n        self.ep_size = ep_size\n        self.routed_scaling_factor = routed_scaling_factor\n        self.kv_lora_rank = kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.topk_method 
= topk_method\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.num_experts_per_tok = num_experts_per_tok\n        self.moe_layer_freq = moe_layer_freq\n        self.first_k_dense_replace = first_k_dense_replace\n        self.norm_topk_prob = norm_topk_prob\n        self.scoring_func = scoring_func\n        self.aux_loss_alpha = aux_loss_alpha\n        self.seq_aux = seq_aux\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = float(rms_norm_eps)\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        self.use_mla = use_mla\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\n@register_customized_processor(processor_class=DeepseekOCRProcessor)\nclass DeepseekVLV2Config(PretrainedConfig):\n    # model_type = \"deepseek_vl_v2\"\n    model_type = \"deepseek-ocr\"\n    vision_config: VisionEncoderConfig = None\n    projector_config: MlpProjectorConfig = None\n\n    tile_tag: str = \"2D\"\n    global_view_pos: str = \"head\"\n    candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),)\n    customized_processor_type: type[Any] = DeepseekOCRProcessor\n\n    def __init__(\n        self,\n        tile_tag: str = \"tile_tag\",\n        global_view_pos: str = \"head\",\n        candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),),\n        **kwargs,\n    ):\n        
super().__init__(**kwargs)\n\n        vision_config = kwargs.get(\"vision_config\", {})\n        self.vision_config = VisionEncoderConfig(**vision_config)\n\n        projector_config = kwargs.get(\"projector_config\", {})\n        self.projector_config = MlpProjectorConfig(**projector_config)\n\n        language_config = kwargs.get(\"language_config\", {})\n        self.text_config = DeepseekV2Config(**language_config)\n\n        self.tile_tag = tile_tag\n        self.global_view_pos = global_view_pos\n        self.candidate_resolutions = candidate_resolutions\n        self.vocab_size = self.text_config.vocab_size\n        self.hidden_size = self.text_config.hidden_size\n\n\nAutoProcessor.register(DeepseekVLV2Config, DeepseekOCRProcessor)\n"
  },
  {
    "path": "python/sglang/srt/configs/deepseekvl2.py",
    "content": "import math\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Optional, Tuple\n\nimport torch\nfrom PIL import Image, ImageOps\nfrom transformers import (\n    AutoProcessor,\n    LlamaTokenizerFast,\n    PretrainedConfig,\n    ProcessorMixin,\n)\n\n\ndef select_best_resolution(image_size, candidate_resolutions):\n    # used for cropping\n    original_width, original_height = image_size\n    best_fit = None\n    max_effective_resolution = 0\n    min_wasted_resolution = float(\"inf\")\n\n    for width, height in candidate_resolutions:\n        scale = min(width / original_width, height / original_height)\n        downscaled_width, downscaled_height = int(original_width * scale), int(\n            original_height * scale\n        )\n        effective_resolution = min(\n            downscaled_width * downscaled_height, original_width * original_height\n        )\n        wasted_resolution = (width * height) - effective_resolution\n\n        if effective_resolution > max_effective_resolution or (\n            effective_resolution == max_effective_resolution\n            and wasted_resolution < min_wasted_resolution\n        ):\n            max_effective_resolution = effective_resolution\n            min_wasted_resolution = wasted_resolution\n            best_fit = (width, height)\n\n    return best_fit\n\n\nclass DictOutput(object):\n    def items(self):\n        return self.__dict__.items()\n\n    def keys(self):\n        return self.__dict__.keys()\n\n    def __getitem__(self, item):\n        return self.__dict__[item]\n\n    def __contains__(self, key):\n        return key in self.__dict__\n\n    def __setitem__(self, key, value):\n        self.__dict__[key] = value\n\n\n@dataclass\nclass VLChatProcessorOutput(DictOutput):\n    input_ids: torch.LongTensor\n    target_ids: torch.LongTensor\n    pixel_values: (\n        torch.Tensor\n    )  # rename from \"images\" to \"pixel_values\" for compatibility\n    images_seq_mask: 
torch.BoolTensor\n    images_spatial_crop: torch.LongTensor\n\n    def __len__(self):\n        return len(self.input_ids)\n\n\nclass ImageTransform(object):\n    def __init__(\n        self,\n        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),\n        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),\n        normalize: bool = True,\n    ):\n        self.mean = mean\n        self.std = std\n        self.normalize = normalize\n\n        # only load torchvision.transforms when needed\n        try:\n            import torchvision.transforms as T\n\n            # FIXME: add version check for gguf\n        except ImportError as err:\n            raise ImportError(\n                \"Please install torchvision via `pip install torchvision` to use Deepseek-VL2.\"\n            ) from err\n\n        transform_pipelines = [T.ToTensor()]\n\n        if normalize:\n            transform_pipelines.append(T.Normalize(mean, std))\n\n        self.transform = T.Compose(transform_pipelines)\n\n    def __call__(self, pil_img: Image.Image):\n        x = self.transform(pil_img)\n        return x\n\n\nclass DeepseekVLV2Processor(ProcessorMixin):\n    tokenizer_class = (\"LlamaTokenizer\", \"LlamaTokenizerFast\")\n    attributes = [\"tokenizer\"]\n\n    def __init__(\n        self,\n        tokenizer: LlamaTokenizerFast,\n        candidate_resolutions: Tuple[Tuple[int, int]],\n        patch_size: int,\n        downsample_ratio: int,\n        image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),\n        image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),\n        normalize: bool = True,\n        image_token: str = \"<image>\",\n        pad_token: str = \"<｜▁pad▁｜>\",\n        add_special_token: bool = False,\n        sft_format: str = \"deepseek\",\n        mask_prompt: bool = True,\n        ignore_id: int = -100,\n        **kwargs,\n    ):\n\n        self.candidate_resolutions = candidate_resolutions\n        self.image_size = 
candidate_resolutions[0][0]\n        self.patch_size = patch_size\n        self.image_mean = image_mean\n        self.image_std = image_std\n        self.normalize = normalize\n        self.downsample_ratio = downsample_ratio\n\n        self.image_transform = ImageTransform(\n            mean=image_mean, std=image_std, normalize=normalize\n        )\n        self.tokenizer = tokenizer\n        # must set this; padding side will make a difference in batch inference\n        self.tokenizer.padding_side = \"left\"\n\n        # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'\n        if tokenizer.pad_token is None:\n            self.tokenizer.add_special_tokens({\"pad_token\": pad_token})\n\n        # add image token\n        image_token_id = self.tokenizer.vocab.get(image_token)\n        if image_token_id is None:\n            special_tokens = [image_token]\n            special_tokens_dict = {\"additional_special_tokens\": special_tokens}\n            self.tokenizer.add_special_tokens(special_tokens_dict)\n        self.image_token_id = self.tokenizer.vocab.get(image_token)\n\n        # add five special tokens for grounding-related tasks\n        # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>\n        special_tokens = [\"<|ref|>\", \"<|/ref|>\", \"<|det|>\", \"<|/det|>\", \"<|grounding|>\"]\n        special_tokens_dict = {\"additional_special_tokens\": special_tokens}\n        self.tokenizer.add_special_tokens(special_tokens_dict)\n\n        # add special tokens for SFT data\n        special_tokens = [\"<|User|>\", \"<|Assistant|>\"]\n        special_tokens_dict = {\"additional_special_tokens\": special_tokens}\n        self.tokenizer.add_special_tokens(special_tokens_dict)\n\n        self.image_token = image_token\n        self.pad_token = pad_token\n        self.add_special_token = add_special_token\n        self.sft_format = sft_format\n        self.mask_prompt = mask_prompt\n        self.ignore_id = ignore_id\n\n        super().__init__(\n            tokenizer,\n            **kwargs,\n        )\n\n    def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):\n        \"\"\"play the role of format_messages_v2 and get_images_info in the last version\"\"\"\n        tokenized_data = []\n        masked_tokenized_data = []  # labels\n        images_list = []\n        images_seq_mask = []\n        images_spatial_crop = []\n\n        image_index = 0\n        image_token_cnt = messages.count(self.image_token)\n        tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(\n            messages,\n            pil_images[image_index : image_index + image_token_cnt],\n            bos=True,\n            eos=True,\n            cropping=len(pil_images) <= 2,\n            max_req_input_len=max_req_input_len,\n        )\n\n        image_index = image_token_cnt\n        tokenized_data += tokenized_str\n        if self.mask_prompt:\n            masked_tokenized_data += [self.ignore_id] * len(tokenized_str)\n        else:\n            masked_tokenized_data += tokenized_str\n        images_list += images\n        images_seq_mask += seq_mask\n        images_spatial_crop += spatial_crop\n\n        assert len(tokenized_data) == len(\n            images_seq_mask\n        ), f\"format_messages_v2: tokenized_data's length {len(tokenized_data)} is not equal to images_seq_mask's length {len(images_seq_mask)}\"\n\n        return (\n            tokenized_data,\n            masked_tokenized_data,\n            images_list,\n            images_seq_mask,\n            images_spatial_crop,\n        )\n\n    @property\n    def bos_id(self):\n        return self.tokenizer.bos_token_id\n\n    @property\n    def eos_id(self):\n        return self.tokenizer.eos_token_id\n\n    @property\n    def pad_id(self):\n        return self.tokenizer.pad_token_id\n\n    def encode(self, text: str, bos: bool = True, eos: bool = False):\n        t = self.tokenizer.encode(text, 
add_special_tokens=False)\n\n        if bos:\n            t = [self.bos_id] + t\n        if eos:\n            t = t + [self.eos_id]\n\n        return t\n\n    def decode(self, t: List[int], **kwargs) -> str:\n        return self.tokenizer.decode(t, **kwargs)\n\n    def process_one(\n        self,\n        prompt: str = None,\n        conversations: List[Dict[str, str]] = None,\n        images: List[Image.Image] = None,\n        apply_sft_format: bool = False,\n        inference_mode: bool = True,\n        system_prompt: str = \"\",\n        max_req_input_len: int = -1,\n        **kwargs,\n    ):\n        \"\"\"\n\n        Args:\n            prompt (str): the formatted prompt;\n            conversations (List[Dict]): conversations with a list of messages;\n            images (List[ImageType]): the list of images;\n            apply_sft_format (bool): if prompt is not None, then apply the SFT format to prompt;\n                if conversations is not None, then it will always apply the SFT format to conversations;\n            inference_mode (bool): if True, then remove the last eos token;\n            system_prompt (str): the system prompt;\n            **kwargs:\n\n        Returns:\n            outputs (BaseProcessorOutput): the output of the processor,\n                - input_ids (torch.LongTensor): [N + image tokens]\n                - target_ids (torch.LongTensor): [N + image tokens]\n                - images (torch.FloatTensor): [n_images, 3, H, W]\n                - image_id (int): the id of the image token\n                - num_image_tokens (List[int]): the number of image tokens\n        \"\"\"\n\n        assert (\n            prompt is None or conversations is None\n        ), \"prompt and conversations cannot be used at the same time.\"\n\n        (\n            tokenized_str,\n            masked_tokenized_str,\n            images_list,\n            images_seq_mask,\n            images_spatial_crop,\n        ) = self.format_messages_v2(conversations, 
images, max_req_input_len)\n\n        assert (\n            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)\n        ), (\n            f\"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, \"\n            f\"images_seq_mask's length {len(images_seq_mask)}, are not equal\"\n        )\n\n        input_ids = torch.LongTensor(tokenized_str)\n        target_ids = torch.LongTensor(masked_tokenized_str)\n        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)\n\n        # set input_ids < 0 | input_ids == self.image_token_id as ignore_id\n        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (\n            self.ignore_id\n        )\n        input_ids[input_ids < 0] = self.pad_id\n\n        if inference_mode:\n            assert input_ids[-1] == self.eos_id\n            input_ids = input_ids[:-1]\n            target_ids = target_ids[:-1]\n            images_seq_mask = images_seq_mask[:-1]\n\n        if len(images_list) == 0:\n            images = torch.zeros((1, 3, self.image_size, self.image_size))\n            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)\n        else:\n            images = torch.stack(images_list, dim=0)\n            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)\n\n        images_spatial_crop = torch.stack(\n            [images_spatial_crop], dim=0\n        )  # stack the tensor to make it a batch of 1\n\n        prepare = VLChatProcessorOutput(\n            input_ids=input_ids,\n            target_ids=target_ids,\n            pixel_values=images,\n            images_seq_mask=images_seq_mask,\n            images_spatial_crop=images_spatial_crop,\n        )\n\n        return prepare\n\n    def __call__(\n        self,\n        *,\n        prompt: str = None,\n        conversations: List[Dict[str, str]] = None,\n        images: List[Image.Image] = None,\n        apply_sft_format: bool = False,\n        
inference_mode: bool = True,\n        system_prompt: str = \"\",\n        max_req_input_len: int = -1,\n        **kwargs,\n    ):\n        prepare = self.process_one(\n            prompt=prompt,\n            conversations=conversations,\n            images=images,\n            apply_sft_format=apply_sft_format,\n            inference_mode=inference_mode,\n            system_prompt=system_prompt,\n            max_req_input_len=max_req_input_len,\n        )\n\n        return prepare\n\n    def find_all_indices(self, messages, target_value):\n        indices = []\n        for index, item in enumerate(messages):\n            if item == target_value:\n                indices.append(index)\n        return indices\n\n    def tokenize_with_images(\n        self,\n        conversation: str,\n        images: List[Image.Image],\n        bos: bool = True,\n        eos: bool = True,\n        cropping: bool = True,\n        max_req_input_len: int = -1,\n    ):\n        \"\"\"Tokenize text with <image> tags.\"\"\"\n        images_list, images_seq_mask, images_spatial_crop = [], [], []\n        text_splits = conversation.split(self.image_token)\n        tokenized_str = []\n        for text_sep, image in zip(text_splits, images):\n            \"\"\"encode text_sep\"\"\"\n            tokenized_sep = self.encode(text_sep, bos=False, eos=False)\n            tokenized_str += tokenized_sep\n            images_seq_mask += [False] * len(tokenized_sep)\n\n            \"\"\"select best resolution for anyres\"\"\"\n            if cropping:\n                best_width, best_height = select_best_resolution(\n                    image.size, self.candidate_resolutions\n                )\n            else:\n                best_width, best_height = self.image_size, self.image_size\n            # print(image.size, (best_width, best_height)) # check the select_best_resolutions func\n\n            \"\"\"process the global view\"\"\"\n            global_view = ImageOps.pad(\n                image,\n  
              (self.image_size, self.image_size),\n                color=tuple(int(x * 255) for x in self.image_transform.mean),\n            )\n            images_list.append(self.image_transform(global_view))\n\n            \"\"\"process the local views\"\"\"\n            local_view = ImageOps.pad(\n                image,\n                (best_width, best_height),\n                color=tuple(int(x * 255) for x in self.image_transform.mean),\n            )\n            for i in range(0, best_height, self.image_size):\n                for j in range(0, best_width, self.image_size):\n                    images_list.append(\n                        self.image_transform(\n                            local_view.crop(\n                                (j, i, j + self.image_size, i + self.image_size)\n                            )\n                        )\n                    )\n\n            \"\"\"record height / width crop num\"\"\"\n            num_width_tiles, num_height_tiles = (\n                best_width // self.image_size,\n                best_height // self.image_size,\n            )\n            images_spatial_crop.append([num_width_tiles, num_height_tiles])\n\n            \"\"\"add image tokens\"\"\"\n            h = w = math.ceil(\n                (self.image_size // self.patch_size) / self.downsample_ratio\n            )\n            # global views tokens h * (w + 1), 1 is for line separator\n            tokenized_image = [self.image_token_id] * h * (w + 1)\n            # add a separator between global and local views\n            tokenized_image += [self.image_token_id]\n            # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)\n            tokenized_image += (\n                [self.image_token_id]\n                * (num_height_tiles * h)\n                * (num_width_tiles * w + 1)\n            )\n\n            tokenized_str += tokenized_image\n            images_seq_mask += [True] * len(tokenized_image)\n            # 
print(width_crop_num, height_crop_num, len(tokenized_image)) # test the correctness of the number of image-related tokens\n\n        \"\"\"process the last text split\"\"\"\n        tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)\n        # deal with video, limit with request len\n        if max_req_input_len > -1:\n            if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:\n                rest = max_req_input_len - len(tokenized_sep) - 1 - 1024\n                tokenized_str = tokenized_str[:rest]\n                images_seq_mask = images_seq_mask[:rest]\n        tokenized_str += tokenized_sep\n        images_seq_mask += [False] * len(tokenized_sep)\n\n        \"\"\"add the bos and eos tokens\"\"\"\n        if bos:\n            tokenized_str = [self.bos_id] + tokenized_str\n            images_seq_mask = [False] + images_seq_mask\n        if eos:\n            tokenized_str = tokenized_str + [self.eos_id]\n            images_seq_mask = images_seq_mask + [False]\n\n        assert len(tokenized_str) == len(\n            images_seq_mask\n        ), f\"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to images_seq_mask's length {len(images_seq_mask)}\"\n\n        return tokenized_str, images_list, images_seq_mask, images_spatial_crop\n\n\nclass DeepseekVL2VisionEncoderConfig(PretrainedConfig):\n    model_type: str = \"vision\"\n\n    model_name: str = \"siglip_large_patch16_384\"\n    image_size: int = 384\n    patch_size: int = 16\n    width: int = 1024\n    layers: int = 24\n    heads: int = 16\n    mlp_ratio: int = 4\n    global_pool: str = \"map\"\n    ignore_head: bool = True\n    class_token: bool = False\n    num_classes: int = 0\n    use_checkpoint: bool = False\n    weight_init: str = \"skip\"\n    deterministic: bool = False\n    num_recomputing_layers: int = 0\n\n    def __init__(\n        self,\n        model_name: str = \"siglip_large_patch16_384\",\n        image_size: int = 384,\n  
      patch_size: int = 16,\n        width: int = 1024,\n        layers: int = 24,\n        heads: int = 16,\n        mlp_ratio: int = 4,\n        global_pool: str = \"map\",\n        ignore_head: bool = True,\n        class_token: bool = False,\n        num_classes: int = 0,\n        use_checkpoint: bool = False,\n        **kwargs,\n    ):\n        self.model_name = model_name\n        self.image_size = image_size\n        self.patch_size = patch_size\n        self.width = width\n        self.layers = layers\n        self.heads = heads\n        self.mlp_ratio = mlp_ratio\n        self.global_pool = global_pool\n        self.ignore_head = ignore_head\n        self.class_token = class_token\n        self.num_classes = num_classes\n        self.use_checkpoint = use_checkpoint\n\n        super().__init__(**kwargs)\n\n\nclass DeepseekVL2MlpProjectorConfig(PretrainedConfig):\n    model_type = \"mlp_projector\"\n    projector_type: str = \"downsample_mlp_gelu\"\n    input_dim: int = 1152\n    n_embed: int = 2048\n    depth: int = 2\n    mlp_ratio: int = 1\n    downsample_ratio: int = 2\n    token_pooling: bool = False\n\n    def __init__(\n        self,\n        projector_type: str = \"downsample_mlp_gelu\",\n        input_dim: int = 1152,\n        n_embed: int = 2048,\n        depth: int = 2,\n        mlp_ratio: int = 1,\n        downsample_ratio: int = 2,\n        **kwargs,\n    ):\n        self.projector_type = projector_type\n        self.input_dim = input_dim\n        self.n_embed = n_embed\n        self.depth = depth\n        self.mlp_ratio = mlp_ratio\n        self.downsample_ratio = downsample_ratio\n\n        super().__init__(**kwargs)\n\n\nclass DeepseekV2Config(PretrainedConfig):\n\n    model_type = \"deepseek_v2\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=102400,\n        hidden_size=4096,\n        intermediate_size=11008,\n        moe_intermediate_size=1407,\n        
num_hidden_layers=30,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        n_shared_experts=None,\n        n_routed_experts=None,\n        ep_size=1,\n        routed_scaling_factor=1.0,\n        kv_lora_rank=512,\n        q_lora_rank=1536,\n        qk_rope_head_dim=64,\n        v_head_dim=128,\n        qk_nope_head_dim=128,\n        topk_method=\"gready\",\n        n_group=None,\n        topk_group=None,\n        num_experts_per_tok=None,\n        moe_layer_freq=1,\n        first_k_dense_replace=0,\n        norm_topk_prob=False,\n        scoring_func=\"softmax\",\n        aux_loss_alpha=0.001,\n        seq_aux=True,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=100000,\n        eos_token_id=100001,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        use_mla=True,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.n_shared_experts = n_shared_experts\n        self.n_routed_experts = n_routed_experts\n        self.ep_size = ep_size\n        self.routed_scaling_factor = routed_scaling_factor\n        self.kv_lora_rank = kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.topk_method = topk_method\n        self.n_group = n_group\n        self.topk_group = 
topk_group\n        self.num_experts_per_tok = num_experts_per_tok\n        self.moe_layer_freq = moe_layer_freq\n        self.first_k_dense_replace = first_k_dense_replace\n        self.norm_topk_prob = norm_topk_prob\n        self.scoring_func = scoring_func\n        self.aux_loss_alpha = aux_loss_alpha\n        self.seq_aux = seq_aux\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = float(rms_norm_eps)\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        self.use_mla = use_mla\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\nclass DeepseekVL2Config(PretrainedConfig):\n    model_type = \"deepseek_vl_v2\"\n    vision_config: DeepseekVL2VisionEncoderConfig = None\n    projector_config: DeepseekVL2MlpProjectorConfig = None\n    language_config: DeepseekV2Config = None\n\n    tile_tag: str = \"2D\"\n    global_view_pos: str = \"head\"\n    candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),)\n\n    def __init__(\n        self,\n        tile_tag: str = \"tile_tag\",\n        global_view_pos: str = \"head\",\n        candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),),\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        vision_config = kwargs.get(\"vision_config\", {})\n        self.vision_config = DeepseekVL2VisionEncoderConfig(**vision_config)\n\n        
projector_config = kwargs.get(\"projector_config\", {})\n        self.projector_config = DeepseekVL2MlpProjectorConfig(**projector_config)\n\n        language_config = kwargs.get(\"language_config\", {})\n        if isinstance(language_config, DeepseekV2Config):\n            self.language_config = language_config\n        else:\n            self.language_config = DeepseekV2Config(**language_config)\n\n        self.tile_tag = tile_tag\n        self.global_view_pos = global_view_pos\n        self.candidate_resolutions = candidate_resolutions\n        self.architectures = [\"DeepseekVL2ForCausalLM\"]\n\n\nAutoProcessor.register(DeepseekVL2Config, DeepseekVLV2Processor)\n"
  },
  {
    "path": "python/sglang/srt/configs/device_config.py",
    "content": "import logging\nfrom typing import Optional\n\nimport torch\n\nlogger = logging.getLogger(__name__)\n\nSUPPORTED_DEVICES = [\"cuda\", \"xpu\", \"hpu\", \"cpu\", \"npu\", \"musa\", \"mps\"]\n\n\nclass DeviceConfig:\n    device: Optional[torch.device]\n    gpu_id: Optional[int]\n\n    def __init__(self, device: str = \"cuda\", gpu_id: int = -1) -> None:\n        if device in SUPPORTED_DEVICES:\n            self.device_type = device\n        else:\n            raise RuntimeError(f\"Not supported device type: {device}\")\n        self.device = torch.device(self.device_type)\n        self.gpu_id = gpu_id\n"
  },
  {
    "path": "python/sglang/srt/configs/dots_ocr.py",
    "content": "from typing import Optional\n\nfrom transformers import AutoProcessor, Qwen2_5_VLProcessor\nfrom transformers.image_processing_utils import BaseImageProcessor\nfrom transformers.models.qwen2 import Qwen2Config\n\nfrom sglang.srt.configs.dots_vlm import DotsVisionConfig\n\n\nclass DotsOCRConfig(Qwen2Config):\n    model_type = \"dots_ocr\"\n\n    def __init__(\n        self,\n        image_token_id=151665,\n        video_token_id=151656,\n        vision_config: Optional[dict] = None,\n        *args,\n        **kwargs\n    ):\n        super().__init__(*args, **kwargs)\n        self.image_token_id = image_token_id\n        self.video_token_id = video_token_id\n        self.vision_config = DotsVisionConfig(**(vision_config or {}))\n\n    def save_pretrained(self, save_directory, **kwargs):\n        self._auto_class = None\n        super().save_pretrained(save_directory, **kwargs)\n\n\nclass DummyVideoProcessor(BaseImageProcessor):\n    model_input_names = [\"pixel_values\"]\n\n    def __call__(self, *args, **kwargs):\n        return None\n\n\nclass DotsVLProcessor(Qwen2_5_VLProcessor):\n    def __init__(\n        self,\n        image_processor=None,\n        tokenizer=None,\n        video_processor=None,\n        chat_template=None,\n        **kwargs\n    ):\n        if video_processor is None:\n            video_processor = DummyVideoProcessor()\n        super().__init__(\n            image_processor, tokenizer, video_processor, chat_template=chat_template\n        )\n        self.image_token = (\n            \"<|imgpad|>\"\n            if not hasattr(tokenizer, \"image_token\")\n            else tokenizer.image_token\n        )\n        self.image_token_id = (\n            tokenizer.image_token_id\n            if getattr(tokenizer, \"image_token_id\", None) is not None\n            else tokenizer.convert_tokens_to_ids(self.image_token)\n        )\n\n\nAutoProcessor.register(DotsOCRConfig, DotsVLProcessor)\n"
  },
  {
    "path": "python/sglang/srt/configs/dots_vlm.py",
    "content": "from transformers import AutoProcessor, PretrainedConfig\nfrom transformers.processing_utils import ProcessingKwargs\n\ntry:\n    from transformers import Qwen2_5_VLProcessor\nexcept ImportError:\n    raise ImportError(\n        \"Qwen2_5_VLProcessor can not be found. Please upgrade your transformers version.\"\n    )\n\nfrom sglang.srt.configs.deepseekvl2 import DeepseekV2Config\n\n\nclass DotsVisionConfig(PretrainedConfig):\n    model_type: str = \"dots_vit\"\n\n    def __init__(\n        self,\n        embed_dim: int = 1536,  # vision encoder embed size\n        hidden_size: int = 1536,  # after merger hidden size\n        intermediate_size: int = 4224,\n        num_hidden_layers: int = 42,\n        num_attention_heads: int = 12,\n        num_channels: int = 3,\n        patch_size: int = 14,\n        spatial_merge_size: int = 2,\n        temporal_patch_size: int = 1,\n        rms_norm_eps: float = 1e-5,\n        use_bias: bool = False,\n        attn_implementation=\"flash_attention_2\",  # \"eager\",\"sdpa\",\"flash_attention_2\"\n        initializer_range=0.02,\n        init_merger_std=0.02,\n        is_causal=False,  # ve causal forward\n        post_norm=True,\n        gradient_checkpointing=False,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.embed_dim = embed_dim\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.num_channels = num_channels\n        self.patch_size = patch_size\n        self.spatial_merge_size = spatial_merge_size\n        self.temporal_patch_size = temporal_patch_size\n        self.rms_norm_eps = rms_norm_eps\n        self.use_bias = use_bias\n        self.attn_implementation = attn_implementation\n        self.initializer_range = initializer_range\n        self.init_merger_std = init_merger_std\n        self.is_causal = is_causal\n  
      self.post_norm = post_norm\n        self.gradient_checkpointing = gradient_checkpointing\n\n\nclass DotsVLMConfig(PretrainedConfig):\n    model_type = \"dots_vlm\"\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        vision_config = kwargs.get(\"vision_config\", {})\n        self.im_span_id = kwargs.get(\"image_token_id\", 128815)\n        self.video_span_id = kwargs.get(\"video_token_id\", 128836)\n        self.vision_config = DotsVisionConfig(**vision_config)\n        self.language_config = DeepseekV2Config(**kwargs)\n        self.architectures = [\"DotsVLMForCausalLM\"]\n\n\nclass DotsVLMProcessorKwargs(ProcessingKwargs, total=False):\n    _defaults = {\n        \"text_kwargs\": {\n            \"padding\": False,\n        },\n    }\n\n\nclass DotsVLMProcessor(Qwen2_5_VLProcessor):\n    r\"\"\"\n    Constructs a DotsVLM processor which derives from Qwen2_5_VLProcessor, but overrides the image and video token ids.\n    Besides, its tokenizer is a LlamaTokenizerFast instead of Qwen2TokenizerFast.\n    [`DotsVLMProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`LlamaTokenizerFast`]. 
See the\n    [`~DotsVLMProcessor.__call__`] and [`~DotsVLMProcessor.decode`] for more information.\n    Args:\n        image_processor ([`Qwen2VLImageProcessor`], *optional*):\n            The image processor is a required input.\n        tokenizer ([`LlamaTokenizerFast`], *optional*):\n            The tokenizer is a required input.\n        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages\n            in a chat into a tokenizable string.\n    \"\"\"\n\n    attributes = [\"image_processor\", \"tokenizer\"]\n\n    valid_kwargs = [\"chat_template\"]\n\n    tokenizer_class = (\"LlamaTokenizer\", \"LlamaTokenizerFast\")\n\n    def __init__(\n        self, image_processor=None, tokenizer=None, chat_template=None, **kwargs\n    ):\n        super().__init__(image_processor, tokenizer, chat_template=chat_template)\n        self.image_token = (\n            \"<|imgpad|>\"\n            if not hasattr(tokenizer, \"image_token\")\n            else tokenizer.image_token\n        )\n        self.video_token = (\n            \"<|video_pad|>\"\n            if not hasattr(tokenizer, \"video_token\")\n            else tokenizer.video_token\n        )\n        self.img_token = (\n            \"<|img|>\" if not hasattr(tokenizer, \"img_token\") else tokenizer.img_token\n        )\n        self.endofimg_token = (\n            \"<|endofimg|>\"\n            if not hasattr(tokenizer, \"endofimg_token\")\n            else tokenizer.endofimg_token\n        )\n        self.image_token_id = (\n            tokenizer.image_token_id\n            if getattr(tokenizer, \"image_token_id\", None)\n            else tokenizer.encode(self.image_token)[0]\n        )\n        self.video_token_id = (\n            tokenizer.video_token_id\n            if getattr(tokenizer, \"video_token_id\", None)\n            else tokenizer.encode(self.video_token)[0]\n        )\n\n\nAutoProcessor.register(DotsVLMConfig, DotsVLMProcessor)\n"
  },
  {
    "path": "python/sglang/srt/configs/exaone.py",
    "content": "# coding=utf-8\n# Copyright 2024 The LG AI Research EXAONE Lab. All rights reserved.\n# Copyright 2024 The LG CNS AI Engineering Team.\n# Copyright 2023-2024 SGLang Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"EXAONE model configuration\"\"\"\n\nfrom typing import Any, Dict\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nEXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, Any] = {}\n\n\n# ruff: noqa: E501\nclass ExaoneConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a :class:`~transformers.ExaoneModel`. It is used to\n    instantiate a EXAONE model according to the specified arguments, defining the model architecture. Instantiating a\n    configuration with the defaults will yield a similar configuration to that of the Exaone\n\n    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model\n    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.\n\n\n    Args:\n        vocab_size (:obj:`int`, `optional`, defaults to 102400):\n            Vocabulary size of the EXAONE model. Defines the number of different tokens that can be represented by the\n            :obj:`inputs_ids` passed when calling :class:`~transformers.ExaoneModel`. 
\n        max_position_embeddings (:obj:`int`, `optional`, defaults to 2048):\n            The maximum sequence length that this model might ever be used with. Typically set this to something large\n            just in case (e.g., 512 or 1024 or 2048).\n        hidden_size (:obj:`int`, `optional`, defaults to 2048):\n            Dimensionality of the hidden representations.\n        num_layers (:obj:`int`, `optional`, defaults to 32):\n            Number of hidden layers in the Transformer decoder.\n        num_attention_heads (:obj:`int`, `optional`, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        num_key_value_heads (:obj:`int`, `optional`):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA), otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details, check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it is not specified, will default to\n            `num_attention_heads`.\n        intermediate_size (:obj:`int`, `optional`, defaults to `hidden_size * 4`):\n            Dimensionality of the \"intermediate\" (i.e., feed-forward) layer in the Transformer encoder.\n        activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        rope_theta (:obj:`float`, `optional`, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (:obj:`Dict`, `optional`):\n            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type\n            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value\n            accordingly.\n            Expected contents:\n                `rope_type` (:obj:`str`):\n                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',\n                    'llama3'], with 'default' being the original RoPE implementation.\n                `factor` (:obj:`float`, `optional`):\n                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In\n                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *\n                    original maximum pre-trained length.\n                `original_max_position_embeddings` (:obj:`int`, `optional`):\n                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during\n                    pretraining.\n                `attention_factor` (:obj:`float`, `optional`):\n                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention\n                    computation. 
If unspecified, it defaults to the value recommended by the implementation, using the\n                    `factor` field to infer the suggested value.\n                `beta_fast` (:obj:`float`, `optional`):\n                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 32.\n                `beta_slow` (:obj:`float`, `optional`):\n                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 1.\n                `short_factor` (:obj:`List[float]`, `optional`):\n                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `long_factor` (:obj:`List[float]`, `optional`):\n                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `low_freq_factor` (:obj:`float`, `optional`):\n                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.\n                `high_freq_factor` (:obj:`float`, `optional`):\n                    Only used with 'llama3'. 
Scaling factor applied to high frequency components of the RoPE.\n        embed_dropout (:obj:`float`, `optional`, defaults to 0.0):\n            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.\n        attention_dropout (:obj:`float`, `optional`, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5):\n            The epsilon used by the layer normalization layers.\n        initializer_range (:obj:`float`, `optional`, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if ``config.is_decoder=True``.\n        bos_token_id (:obj:`int`, `optional`, defaults to 0):\n            Beginning of stream token id.\n        eos_token_id (:obj:`int`, `optional`, defaults to 2):\n            End of stream token id.\n        tie_word_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):\n            Whether to tie weight embeddings.\n        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):\n            If True, use gradient checkpointing to save memory at the expense of slower backward pass.\n\n        Example::\n\n            >>> from transformers import EXAONEModel, ExaoneConfig\n\n            >>> # Initializing an EXAONE configuration\n            >>> configuration = ExaoneConfig()\n\n            >>> # Initializing a model from configuration\n            >>> model = EXAONEModel(configuration)\n\n            >>> # Accessing the model configuration\n            >>> configuration = model.config\n    \"\"\"\n\n    model_type = \"exaone\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n    attribute_map = 
{\"num_hidden_layers\": \"num_layers\"}\n\n    def __init__(\n        self,\n        vocab_size=102400,\n        max_position_embeddings=2048,\n        hidden_size=2048,\n        num_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=None,\n        intermediate_size=None,\n        activation_function=\"silu\",\n        rope_theta=10000.0,\n        rope_scaling=None,\n        embed_dropout=0.0,\n        attention_dropout=0.0,\n        layer_norm_epsilon=1e-5,\n        initializer_range=0.02,\n        use_cache=True,\n        bos_token_id=0,\n        eos_token_id=2,\n        tie_word_embeddings=True,\n        **kwargs\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.num_attention_heads = num_attention_heads\n        self.num_hidden_layers = num_layers\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n        self.num_key_value_heads = num_key_value_heads\n        if intermediate_size:\n            self.intermediate_size = intermediate_size\n        else:\n            self.intermediate_size = hidden_size * 4\n        self.activation_function = activation_function\n        self.embed_dropout = embed_dropout\n        self.attention_dropout = attention_dropout\n        self.layer_norm_epsilon = layer_norm_epsilon\n        self.initializer_range = initializer_range\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n\n        self.bos_token_id = bos_token_id\n        self.eos_token_id = eos_token_id\n\n        super().__init__(\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs\n        )\n"
  },
  {
    "path": "python/sglang/srt/configs/falcon_h1.py",
    "content": "# coding=utf-8\n# Copyright 2024 TII and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Falcon-H1 model configuration\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nfrom sglang.srt.configs.mamba_utils import (\n    Mamba2CacheParams,\n    Mamba2StateShape,\n    mamba2_state_dtype,\n)\n\nlogger = logging.get_logger(__name__)\n\n\nclass FalconH1Config(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`FalconH1Model`]. It is used to instantiate a\n    FalconH1Model model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with defaults taken from [ibm-fms/FalconH1-9.8b-2.2T-hf](https://huggingface.co/ibm-fms/FalconH1-9.8b-2.2T-hf).\n    The FalconH1Model is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU.\n    The checkpoints are  jointly trained by IBM, Princeton, and UIUC.\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n    Args:\n        vocab_size (`int`, *optional*, defaults to 128000):\n            Vocabulary size of the FalconH1 model. 
Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`FalconH1Model`]\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the\n            model has a output word embedding layer.\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 14336):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        num_key_value_heads (`int`, *optional*, defaults to 8):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details, check out [this\n            paper](https://huggingface.co/papers/2305.13245). 
If it is not specified, will default to `8`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-05):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):\n            Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an\n            integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the\n            logits of the last prompt token are needed for generation. 
For long sequences, the logits for the entire\n            sequence may use a lot of memory, so setting `num_logits_to_keep=1` will reduce memory footprint\n            significantly.\n        pad_token_id (`int`, *optional*, defaults to 0):\n            The id of the padding token.\n        bos_token_id (`int`, *optional*, defaults to 1):\n            The id of the \"beginning-of-sequence\" token.\n        eos_token_id (`int`, *optional*, defaults to 2):\n            The id of the \"end-of-sequence\" token.\n        max_position_embeddings (`int`, *optional*, defaults to 8192):\n            Max cached sequence length for the model.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        mamba_d_ssm (`int`, *optional*, defaults to 1024):\n            The dimension of the SSM state space latents.\n        mamba_n_heads (`int`, *optional*, defaults to 128):\n            The number of mamba heads used in the v2 implementation.\n        mamba_d_head (`int`, *optional*, defaults to `\"auto\"`):\n            Head embedding dimension size.\n        mamba_n_groups (`int`, *optional*, defaults to 1):\n            The number of the mamba groups used in the v2 implementation.\n        mamba_d_state (`int`, *optional*, defaults to 256):\n            The dimension of the mamba state space latents.\n        mamba_d_conv (`int`, *optional*, defaults to 4):\n            The size of the mamba convolution kernel.\n        mamba_expand (`int`, *optional*, defaults to 2):\n            Expanding factor (relative to hidden_size) used to determine the mamba intermediate size.\n        mamba_chunk_size (`int`, *optional*, defaults to 256):\n            The size of the chunks into which the sequence is split during prefill/training.\n        mamba_conv_bias (`bool`, *optional*, defaults to `True`):\n            Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.\n        
mamba_proj_bias (`bool`, *optional*, defaults to `False`):\n            Flag indicating whether or not to use bias in the input and output projections ([\"in_proj\", \"out_proj\"]) of the mamba mixer block\n        mamba_norm_before_gate (`bool`, *optional*, defaults to `True`):\n            Whether to use RMSNorm before the gate in the Mamba block\n        mamba_rms_norm (`bool`, *optional*, defaults to `False`):\n            Whether to use RMSNorm instead of LayerNorm in the Mamba block\n        projectors_bias (`bool`, *optional*, defaults to `False`):\n            Flag indicating whether or not to use bias in the input and output projections ([\"in_proj\", \"out_proj\"]) of the attention block\n        rope_theta (`float`, *optional*, defaults to 100000.0):\n            The theta value used for the RoPE embeddings.\n        rope_scaling (`float`, *optional*):\n            The scaling value used for the RoPE embeddings. If `None`, no scaling is applied.\n        lm_head_multiplier (`float`, *optional*, defaults to 1.0):\n            The multiplier for the LM head. This is used to scale the output of the LM head.\n        embedding_multiplier (`float`, *optional*, defaults to 1.0):\n            The multiplier for the embedding layer. This is used to scale the output of the embedding layer.\n        mlp_multipliers (`list[float]`, *optional*):\n            The multipliers for the MLP layers. This is used to scale the output of the MLP layers. The first value is\n            the multiplier of gate layer, the second value is the multiplier of the down_proj layer.\n        key_multiplier (`float`, *optional*):\n            The multiplier for the key layer. This is used to scale the output of the key layer.\n        attention_out_multiplier (`float`, *optional*):\n            The multiplier for the attention output layer. 
This is used to scale the output of the attention output layer.\n        attention_in_multiplier (`float`, *optional*):\n            The multiplier for the attention input layer. This is used to scale the output of the attention input layer.\n        ssm_multipliers (`list[float]`, *optional*):\n            The multipliers for the SSM layers. This is used to scale the output of the SSM layers.\n        ssm_in_multiplier (`float`, *optional*):\n            The multiplier for the SSM input layer. This is used to scale the output of the SSM input layer.\n        ssm_out_multiplier (`float`, *optional*):\n            The multiplier for the SSM output layer. This is used to scale the output of the SSM output layer.\n    \"\"\"\n\n    model_type = \"falcon_h1\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=128000,\n        tie_word_embeddings=False,\n        hidden_size=4096,\n        intermediate_size=14336,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=8,\n        hidden_act=\"silu\",\n        initializer_range=0.02,\n        rms_norm_eps=1e-5,\n        use_cache=True,\n        num_logits_to_keep=1,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        max_position_embeddings=8192,\n        attention_dropout=0.0,\n        mamba_d_ssm=1024,\n        mamba_n_heads=128,\n        mamba_d_head=\"auto\",\n        mamba_n_groups=1,\n        mamba_d_state=256,\n        mamba_d_conv=4,\n        mamba_expand=2,\n        mamba_chunk_size=256,\n        mamba_conv_bias=True,\n        mamba_proj_bias=False,\n        mamba_norm_before_gate=True,\n        mamba_rms_norm=False,\n        projectors_bias=False,\n        rope_theta=100000.0,\n        rope_scaling=None,\n        lm_head_multiplier=1.0,\n        embedding_multiplier=1.0,\n        mlp_multipliers=None,\n        key_multiplier=None,\n        attention_out_multiplier=None,\n        
attention_in_multiplier=None,\n        ssm_multipliers=None,\n        ssm_in_multiplier=None,\n        ssm_out_multiplier=None,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.max_position_embeddings = max_position_embeddings\n        self.attention_dropout = attention_dropout\n        self.attention_bias = False\n        self.mlp_bias = False\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n\n        self.use_cache = use_cache\n        self.num_logits_to_keep = num_logits_to_keep\n\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.projectors_bias = projectors_bias\n        self.mamba_intermediate = mamba_intermediate = (\n            mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm\n        )\n\n        if mamba_intermediate % mamba_n_heads != 0:\n            raise ValueError(\n                \"mamba_n_heads must divide the mamba intermediate size \"\n                \"(mamba_d_ssm, or mamba_expand * hidden_size if mamba_d_ssm is None)\"\n            )\n\n        # for the mamba_v2, must satisfy the following\n        if mamba_d_head == \"auto\":\n            mamba_d_head = mamba_intermediate // mamba_n_heads\n\n        if mamba_d_head * mamba_n_heads != mamba_intermediate:\n            raise ValueError(\n                \"The dimensions for the Mamba head state do not match the mamba intermediate size\"\n            )\n\n        self.mamba_d_ssm = mamba_d_ssm\n        self.mamba_n_heads = mamba_n_heads\n        self.mamba_d_head = mamba_d_head\n        self.mamba_n_groups = 
mamba_n_groups\n        self.mamba_d_state = mamba_d_state\n        self.mamba_d_conv = mamba_d_conv\n        self.mamba_expand = mamba_expand\n        self.mamba_chunk_size = mamba_chunk_size\n        self.mamba_conv_bias = mamba_conv_bias\n        self.mamba_proj_bias = mamba_proj_bias\n\n        self.mamba_norm_before_gate = mamba_norm_before_gate\n        self.mamba_rms_norm = mamba_rms_norm\n\n        self.lm_head_multiplier = lm_head_multiplier\n        self.embedding_multiplier = embedding_multiplier\n\n        if mlp_multipliers is not None:\n            self.mlp_multipliers = mlp_multipliers\n        else:\n            self.mlp_multipliers = [1.0, 1.0]\n\n        if attention_out_multiplier is not None:\n            self.attention_out_multiplier = attention_out_multiplier\n        else:\n            self.attention_out_multiplier = 1.0\n\n        if attention_in_multiplier is not None:\n            self.attention_in_multiplier = attention_in_multiplier\n        else:\n            self.attention_in_multiplier = 1.0\n\n        if key_multiplier is not None:\n            self.key_multiplier = key_multiplier\n        else:\n            self.key_multiplier = 1.0\n\n        if ssm_multipliers is not None:\n            self.ssm_multipliers = ssm_multipliers\n        else:\n            self.ssm_multipliers = [1.0, 1.0, 1.0, 1.0, 1.0]\n\n        if ssm_in_multiplier is not None:\n            self.ssm_in_multiplier = ssm_in_multiplier\n        else:\n            self.ssm_in_multiplier = 1.0\n\n        if ssm_out_multiplier is not None:\n            self.ssm_out_multiplier = ssm_out_multiplier\n        else:\n            self.ssm_out_multiplier = 1.0\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    @property\n    def layers_block_type(self):\n        return [\"falcon_h1\" for 
i in range(self.num_hidden_layers)]\n\n    @property\n    def full_attention_layer_ids(self):\n        # For Falcon-H1, we do have attention on all layers\n        return range(self.num_hidden_layers)\n\n    @property\n    def linear_layer_ids(self):\n        # For Falcon-H1, we do have mamba on all layers\n        return range(self.num_hidden_layers)\n\n    @property\n    def mamba2_cache_params(self):\n        from sglang.srt.layers.dp_attention import get_attention_tp_size\n\n        shape = Mamba2StateShape.create(\n            tp_world_size=get_attention_tp_size(),\n            intermediate_size=self.mamba_intermediate,\n            n_groups=self.mamba_n_groups,\n            num_heads=self.mamba_n_heads,\n            head_dim=self.mamba_d_head,\n            state_size=self.mamba_d_state,\n            conv_kernel=self.mamba_d_conv,\n        )\n        return Mamba2CacheParams(\n            shape=shape, layers=self.linear_layer_ids, dtype=mamba2_state_dtype(self)\n        )\n"
  },
  {
    "path": "python/sglang/srt/configs/granitemoehybrid.py",
    "content": "# coding=utf-8\n# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"GraniteMoeHybrid model configuration\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nfrom sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape\n\nlogger = logging.get_logger(__name__)\n\nMAMBA = \"mamba\"\nATTENTION = \"attention\"\n\n\nclass GraniteMoeHybridConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`GraniteMoeHybridModel`]. It is used to instantiate a\n    GraniteMoeHybrid model according to the specified arguments, defining the model architecture. The GraniteMoeHybrid is a\n    hybrid architecture combining Mamba2 layers with attention layers, developed by IBM.\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 100352):\n            Vocabulary size of the GraniteMoeHybrid model. 
Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`GraniteMoeHybridModel`]\n        tie_word_embeddings (`bool`, *optional*, defaults to `True`):\n            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the\n            model has an output word embedding layer.\n        hidden_size (`int`, *optional*, defaults to 2048):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 8192):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 40):\n            Number of hidden layers in the model.\n        layer_types (`list[str]`, *optional*):\n            List of layer types for each layer. Each element should be either \"mamba\" or \"attention\".\n            If not provided, every 6th layer (indices 5, 11, 17, ...) is an attention layer and all other layers are mamba layers.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer.\n        num_key_value_heads (`int`, *optional*, defaults to 8):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        initializer_range (`float`, *optional*, defaults to 0.1):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-05):\n            The epsilon used by the rms normalization layers.\n        normalization_function (`str`, *optional*, defaults to `\"rmsnorm\"`):\n            The normalization function to use. Currently only \"rmsnorm\" is supported.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        pad_token_id (`int`, *optional*, defaults to 100256):\n            The id of the padding token.\n        bos_token_id (`int`, *optional*, defaults to 100257):\n            The id of the \"beginning-of-sequence\" token.\n        eos_token_id (`int`, *optional*, defaults to 100257):\n            The id of the \"end-of-sequence\" token.\n        max_position_embeddings (`int`, *optional*, defaults to 131072):\n            Max cached sequence length for the model\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use bias in attention layers.\n        position_embedding_type (`str`, *optional*, defaults to `\"nope\"`):\n            Type of position embedding. 
Can be \"nope\" (no position embedding) or \"rope\".\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The theta value used for the RoPE embeddings.\n        rope_scaling (`dict`, *optional*):\n            The scaling configuration for the RoPE embeddings. If `None`, no scaling is applied.\n        mamba_d_state (`int`, *optional*, defaults to 128):\n            The dimension of the mamba state space latents\n        mamba_d_conv (`int`, *optional*, defaults to 4):\n            The size of the mamba convolution kernel\n        mamba_expand (`int`, *optional*, defaults to 2):\n            Expanding factor (relative to hidden_size) used to determine the mamba intermediate size\n        mamba_d_head (`int`, *optional*, defaults to 64):\n            Head embedding dimension size for Mamba\n        mamba_n_heads (`int`, *optional*, defaults to 64):\n            The number of mamba heads\n        mamba_n_groups (`int`, *optional*, defaults to 1):\n            The number of the mamba groups\n        mamba_chunk_size (`int`, *optional*, defaults to 256):\n            The chunks in which to break the sequence when doing prefill/training\n        mamba_conv_bias (`bool`, *optional*, defaults to `True`):\n            Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.\n        mamba_proj_bias (`bool`, *optional*, defaults to `False`):\n            Flag indicating whether or not to use bias in the input and output projections of the mamba mixer block\n        embedding_multiplier (`float`, *optional*, defaults to 12.0):\n            The multiplier for the embedding layer. 
This is used to scale the output of the embedding layer.\n        logits_scaling (`float`, *optional*, defaults to 8.0):\n            The scaling factor for the logits.\n        attention_multiplier (`float`, *optional*, defaults to 0.015625):\n            The multiplier for the attention layers.\n        residual_multiplier (`float`, *optional*, defaults to 0.22):\n            The multiplier for residual connections.\n        num_local_experts (`int`, *optional*, defaults to 0):\n            Number of local experts in MoE layers.\n        num_experts_per_tok (`int`, *optional*, defaults to 0):\n            Number of experts to use per token in MoE layers.\n        shared_intermediate_size (`int`, *optional*, defaults to 8192):\n            Intermediate size for shared experts.\n        output_router_logits (`bool`, *optional*, defaults to `False`):\n            Whether to output router logits.\n        router_aux_loss_coef (`float`, *optional*, defaults to 0.01):\n            Auxiliary loss coefficient for the router.\n    \"\"\"\n\n    model_type = \"granitemoehybrid\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=100352,\n        tie_word_embeddings=True,\n        hidden_size=2048,\n        intermediate_size=8192,\n        num_hidden_layers=40,\n        layer_types=None,\n        num_attention_heads=32,\n        num_key_value_heads=8,\n        hidden_act=\"silu\",\n        initializer_range=0.1,\n        rms_norm_eps=1e-5,\n        normalization_function=\"rmsnorm\",\n        use_cache=True,\n        pad_token_id=100256,\n        bos_token_id=100257,\n        eos_token_id=100257,\n        max_position_embeddings=131072,\n        attention_dropout=0.0,\n        attention_bias=False,\n        position_embedding_type=\"nope\",\n        rope_theta=10000.0,\n        rope_scaling=None,\n        mamba_d_state=128,\n        mamba_d_conv=4,\n        mamba_expand=2,\n        mamba_d_head=64,\n        
mamba_n_heads=64,\n        mamba_n_groups=1,\n        mamba_chunk_size=256,\n        mamba_conv_bias=True,\n        mamba_proj_bias=False,\n        embedding_multiplier=12.0,\n        logits_scaling=8.0,\n        attention_multiplier=0.015625,\n        residual_multiplier=0.22,\n        num_local_experts=0,\n        num_experts_per_tok=0,\n        shared_intermediate_size=8192,\n        output_router_logits=False,\n        router_aux_loss_coef=0.01,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n\n        # Set layer types - if not provided, create default pattern\n        if layer_types is None:\n            # Default pattern: mamba layers with attention every 6th layer (roughly)\n            self.layer_types = []\n            for i in range(num_hidden_layers):\n                if (i + 1) % 6 == 0:\n                    self.layer_types.append(ATTENTION)\n                else:\n                    self.layer_types.append(MAMBA)\n        else:\n            self.layer_types = layer_types\n\n        # Validate layer_types\n        if len(self.layer_types) != self.num_hidden_layers:\n            raise ValueError(\n                f\"layer_types must have length equal to num_hidden_layers ({num_hidden_layers}), \"\n                f\"but got {len(self.layer_types)}\"\n            )\n\n        for layer_type in self.layer_types:\n            if layer_type not in [MAMBA, ATTENTION]:\n                raise ValueError(\n                    f\"Each element in layer_types must be either '{MAMBA}' or '{ATTENTION}', \"\n                    f\"but got '{layer_type}'\"\n                )\n\n        self.num_attention_heads = num_attention_heads\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = 
rms_norm_eps\n        self.normalization_function = normalization_function\n\n        self.use_cache = use_cache\n        self.max_position_embeddings = max_position_embeddings\n        self.attention_dropout = attention_dropout\n        self.attention_bias = attention_bias\n\n        self.position_embedding_type = position_embedding_type\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n\n        # Mamba configuration\n        self.mamba_d_state = mamba_d_state\n        self.mamba_d_conv = mamba_d_conv\n        self.mamba_expand = mamba_expand\n        self.mamba_d_head = mamba_d_head\n        self.mamba_n_heads = mamba_n_heads\n        self.mamba_n_groups = mamba_n_groups\n        self.mamba_chunk_size = mamba_chunk_size\n        self.mamba_conv_bias = mamba_conv_bias\n        self.mamba_proj_bias = mamba_proj_bias\n\n        # Calculate mamba intermediate size\n        self.mamba_intermediate_size = mamba_expand * hidden_size\n\n        # Validate mamba configuration\n        if self.mamba_intermediate_size % mamba_n_heads != 0:\n            raise ValueError(\n                f\"mamba_intermediate_size ({self.mamba_intermediate_size}) must be divisible by \"\n                f\"mamba_n_heads ({mamba_n_heads})\"\n            )\n\n        if mamba_d_head * mamba_n_heads != self.mamba_intermediate_size:\n            raise ValueError(\n                f\"mamba_d_head ({mamba_d_head}) * mamba_n_heads ({mamba_n_heads}) must equal \"\n                f\"mamba_intermediate_size ({self.mamba_intermediate_size})\"\n            )\n\n        # Scaling factors\n        self.embedding_multiplier = embedding_multiplier\n        self.logits_scaling = logits_scaling\n        self.attention_multiplier = attention_multiplier\n        self.residual_multiplier = residual_multiplier\n\n        # MoE configuration\n        self.num_local_experts = num_local_experts\n        self.num_experts_per_tok = num_experts_per_tok\n        
self.shared_intermediate_size = shared_intermediate_size\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = router_aux_loss_coef\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    @property\n    def mamba_layer_ids(self):\n        \"\"\"Returns the indices of layers that are Mamba layers.\"\"\"\n        return [\n            i for i in range(self.num_hidden_layers) if self.layer_types[i] == MAMBA\n        ]\n\n    @property\n    def attention_layer_ids(self):\n        \"\"\"Returns the indices of layers that are attention layers.\"\"\"\n        return [\n            i for i in range(self.num_hidden_layers) if self.layer_types[i] == ATTENTION\n        ]\n\n    @property\n    def full_attention_layer_ids(self):\n        \"\"\"Alias for attention_layer_ids for compatibility.\"\"\"\n        return self.attention_layer_ids\n\n    @property\n    def mamba2_cache_params(self):\n        \"\"\"Returns the Mamba2 cache parameters for this configuration.\"\"\"\n        from sglang.srt.layers.dp_attention import get_attention_tp_size\n\n        shape = Mamba2StateShape.create(\n            tp_world_size=get_attention_tp_size(),\n            intermediate_size=self.mamba_intermediate_size,\n            n_groups=self.mamba_n_groups,\n            num_heads=self.mamba_n_heads,\n            head_dim=self.mamba_d_head,\n            state_size=self.mamba_d_state,\n            conv_kernel=self.mamba_d_conv,\n        )\n        return Mamba2CacheParams(shape=shape, layers=self.mamba_layer_ids)\n"
  },
  {
    "path": "python/sglang/srt/configs/internvl.py",
    "content": "import copy\nimport os\nfrom shutil import copyfile\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport sentencepiece as spm\nfrom transformers import (\n    TOKENIZER_MAPPING,\n    GptOssConfig,\n    LlamaConfig,\n    PretrainedConfig,\n    PreTrainedTokenizer,\n    Qwen2Config,\n    Qwen3Config,\n    Qwen3MoeConfig,\n)\n\nfrom sglang.utils import logger\n\n# Copied from: https://github.com/OpenGVLab/InternVL/blob/34a81000402bf8f716bab8c9b57aff1f6b436bd0/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py#L21\n\n\nVOCAB_FILES_NAMES = {\"vocab_file\": \"./tokenizer.model\"}\n\nPRETRAINED_VOCAB_FILES_MAP = {}\n\n\n# Modified from transformers.models.llama.configuration_llama.LlamaConfig\nclass InternLM2Config(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of an [`InternLM2Model`]. It is used to instantiate\n    an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a\n    configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 103168):\n            Vocabulary size of the InternLM2 model. 
Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`InternLM2Model`]\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 11008):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer decoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        num_key_value_heads (`int`, *optional*):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to\n            `num_attention_heads`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 2048):\n            The maximum sequence length that this model might ever be used with. 
Typically set this to something large\n            just in case (e.g., 512 or 1024 or 2048).\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-6):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether to tie weight embeddings.\n    \"\"\"\n\n    model_type = \"internlm2\"\n    _auto_class = \"AutoConfig\"\n\n    def __init__(  # pylint: disable=W0102\n        self,\n        vocab_size=103168,\n        hidden_size=4096,\n        intermediate_size=11008,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=None,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=False,\n        bias=True,\n        rope_theta=10000,\n        rope_scaling=None,\n        attn_implementation=\"eager\",\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.bias = bias\n\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n        self.num_key_value_heads = num_key_value_heads\n\n        
self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self._rope_scaling_validation()\n\n        self.attn_implementation = attn_implementation\n        if self.attn_implementation is None:\n            self.attn_implementation = \"eager\"\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    def _rope_scaling_validation(self):\n        \"\"\"\n        Validate the `rope_scaling` configuration.\n        \"\"\"\n        if self.rope_scaling is None:\n            return\n\n        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:\n            raise ValueError(\n                \"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, \"\n                f\"got {self.rope_scaling}\"\n            )\n        rope_scaling_type = self.rope_scaling.get(\"type\", None)\n        rope_scaling_factor = self.rope_scaling.get(\"factor\", None)\n        if rope_scaling_type is None or rope_scaling_type not in [\"linear\", \"dynamic\"]:\n            raise ValueError(\n                f\"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}\"\n            )\n        if (\n            rope_scaling_factor is None\n            or not isinstance(rope_scaling_factor, (float, int))\n            or rope_scaling_factor < 1.0\n        ):\n            raise ValueError(\n                f\"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}\"\n            )\n        if isinstance(rope_scaling_factor, int):\n            rope_scaling_factor = 
float(rope_scaling_factor)\n\n\nclass InternVisionConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of an [`InternVisionModel`]. It is used to\n    instantiate a vision encoder according to the specified arguments, defining the model architecture.\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n    Args:\n        num_channels (`int`, *optional*, defaults to 3):\n            Number of color channels in the input images (e.g., 3 for RGB).\n        patch_size (`int`, *optional*, defaults to 14):\n            The size (resolution) of each patch.\n        image_size (`int`, *optional*, defaults to 224):\n            The size (resolution) of each image.\n        qkv_bias (`bool`, *optional*, defaults to `False`):\n            Whether to add a bias to the queries, keys, and values in the self-attention layers.\n        hidden_size (`int`, *optional*, defaults to 3200):\n            Dimensionality of the encoder layers and the pooler layer.\n        num_attention_heads (`int`, *optional*, defaults to 25):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        intermediate_size (`int`, *optional*, defaults to 12800):\n            Dimensionality of the \"intermediate\" (i.e., feed-forward) layer in the Transformer encoder.\n        qk_normalization (`bool`, *optional*, defaults to `True`):\n            Whether to normalize the queries and keys in the self-attention layers.\n        num_hidden_layers (`int`, *optional*, defaults to 48):\n            Number of hidden layers in the Transformer encoder.\n        use_flash_attn (`bool`, *optional*, defaults to `True`):\n            Whether to use flash attention mechanism.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"gelu\"`):\n            The non-linear activation function 
(function or string) in the encoder and pooler. If string, `\"gelu\"`,\n            `\"relu\"`, `\"selu\"` and `\"gelu_new\"` are supported.\n        layer_norm_eps (`float`, *optional*, defaults to 1e-6):\n            The epsilon used by the layer normalization layers.\n        dropout (`float`, *optional*, defaults to 0.0):\n            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.\n        drop_path_rate (`float`, *optional*, defaults to 0.0):\n            Dropout rate for stochastic depth.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        initializer_factor (`float`, *optional*, defaults to 0.1):\n            A factor for layer scale.\n    \"\"\"\n\n    model_type = \"intern_vit_6b\"\n\n    def __init__(\n        self,\n        num_channels=3,\n        patch_size=14,\n        image_size=224,\n        qkv_bias=False,\n        hidden_size=3200,\n        num_attention_heads=25,\n        intermediate_size=12800,\n        qk_normalization=True,\n        num_hidden_layers=48,\n        use_flash_attn=True,\n        hidden_act=\"gelu\",\n        layer_norm_eps=1e-6,\n        dropout=0.0,\n        drop_path_rate=0.0,\n        attention_dropout=0.0,\n        initializer_range=0.02,\n        initializer_factor=0.1,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.dropout = dropout\n        self.drop_path_rate = drop_path_rate\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.num_channels = num_channels\n        self.patch_size = patch_size\n        
self.image_size = image_size\n        self.initializer_range = initializer_range\n        self.initializer_factor = initializer_factor\n        self.attention_dropout = attention_dropout\n        self.layer_norm_eps = layer_norm_eps\n        self.hidden_act = hidden_act\n        self.qkv_bias = qkv_bias\n        self.qk_normalization = qk_normalization\n        self.use_flash_attn = use_flash_attn\n\n    @classmethod\n    def from_pretrained(\n        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs\n    ) -> \"PretrainedConfig\":\n        config_dict, kwargs = cls.get_config_dict(\n            pretrained_model_name_or_path, **kwargs\n        )\n\n        if \"vision_config\" in config_dict:\n            config_dict = config_dict[\"vision_config\"]\n\n        if (\n            \"model_type\" in config_dict\n            and hasattr(cls, \"model_type\")\n            and config_dict[\"model_type\"] != cls.model_type\n        ):\n            logger.warning(\n                f\"You are using a model of type {config_dict['model_type']} to instantiate a model of type \"\n                f\"{cls.model_type}. 
This is not supported for all configurations of models and can yield errors.\"\n            )\n\n        return cls.from_dict(config_dict, **kwargs)\n\n\nclass InternVLChatConfig(PretrainedConfig):\n    model_type = \"internvl_chat\"\n    is_composition = True\n\n    def __init__(\n        self,\n        vision_config=None,\n        llm_config=None,\n        use_backbone_lora=0,\n        use_llm_lora=0,\n        pad2square=False,\n        select_layer=-1,\n        force_image_size=None,\n        downsample_ratio=0.5,\n        template=None,\n        dynamic_image_size=False,\n        use_thumbnail=False,\n        ps_version=\"v1\",\n        min_dynamic_patch=1,\n        max_dynamic_patch=6,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        if vision_config is None:\n            vision_config = {\"architectures\": [\"InternVisionModel\"]}\n            logger.info(\n                \"vision_config is None. Initializing the InternVisionConfig with default values.\"\n            )\n\n        if llm_config is None:\n            llm_config = {\"architectures\": [\"InternLM2ForCausalLM\"]}\n            logger.info(\n                \"llm_config is None. 
Initializing the LlamaConfig config with default values (`LlamaConfig`).\"\n            )\n\n        self.vision_config = InternVisionConfig(**vision_config)\n        if llm_config.get(\"architectures\")[0] == \"LlamaForCausalLM\":\n            self.llm_config = LlamaConfig(**llm_config)\n        elif llm_config.get(\"architectures\")[0] == \"InternLM2ForCausalLM\":\n            self.llm_config = InternLM2Config(**llm_config)\n        elif llm_config.get(\"architectures\")[0] == \"Qwen2ForCausalLM\":\n            self.llm_config = Qwen2Config(**llm_config)\n        elif llm_config.get(\"architectures\")[0] == \"Qwen3MoeForCausalLM\":\n            self.llm_config = Qwen3MoeConfig(**llm_config)\n        elif llm_config.get(\"architectures\")[0] == \"Qwen3ForCausalLM\":\n            self.llm_config = Qwen3Config(**llm_config)\n        elif llm_config.get(\"architectures\")[0] == \"GptOssForCausalLM\":\n            self.llm_config = GptOssConfig(**llm_config)\n        else:\n            raise ValueError(\n                \"Unsupported architecture: {}\".format(\n                    llm_config.get(\"architectures\")[0]\n                )\n            )\n\n        self.use_backbone_lora = use_backbone_lora\n        self.use_llm_lora = use_llm_lora\n        self.pad2square = pad2square\n        self.select_layer = select_layer\n        self.force_image_size = force_image_size\n        self.downsample_ratio = downsample_ratio\n        self.template = template\n        self.dynamic_image_size = dynamic_image_size\n        self.use_thumbnail = use_thumbnail\n        self.ps_version = ps_version  # pixel shuffle version\n        self.min_dynamic_patch = min_dynamic_patch\n        self.max_dynamic_patch = max_dynamic_patch\n\n        self.hidden_size = self.llm_config.hidden_size\n        # By default, we use tie_word_embeddings=False for models of all sizes.\n        self.tie_word_embeddings = False\n        self.llm_config.tie_word_embeddings = self.tie_word_embeddings\n\n   
 def to_dict(self):\n        \"\"\"\n        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].\n\n        Returns:\n            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,\n        \"\"\"\n        output = copy.deepcopy(self.__dict__)\n        output[\"vision_config\"] = self.vision_config.to_dict()\n        output[\"llm_config\"] = self.llm_config.to_dict()\n        output[\"model_type\"] = self.__class__.model_type\n        output[\"use_backbone_lora\"] = self.use_backbone_lora\n        output[\"use_llm_lora\"] = self.use_llm_lora\n        output[\"select_layer\"] = self.select_layer\n        output[\"force_image_size\"] = self.force_image_size\n        output[\"downsample_ratio\"] = self.downsample_ratio\n        output[\"template\"] = self.template\n        output[\"dynamic_image_size\"] = self.dynamic_image_size\n        output[\"use_thumbnail\"] = self.use_thumbnail\n        output[\"ps_version\"] = self.ps_version\n        output[\"min_dynamic_patch\"] = self.min_dynamic_patch\n        output[\"max_dynamic_patch\"] = self.max_dynamic_patch\n\n        return output\n\n\n# # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast\n# class InternLM2TokenizerFast(PreTrainedTokenizerFast):\n#     vocab_files_names = VOCAB_FILES_NAMES\n#     slow_tokenizer_class = InternLM2Tokenizer\n#     padding_side = 'left'\n#     model_input_names = ['input_ids', 'attention_mask']\n#     _auto_class = 'AutoTokenizer'\n#\n#     def __init__(\n#         self,\n#         vocab_file,\n#         unk_token='<unk>',\n#         bos_token='<s>',\n#         eos_token='</s>',\n#         pad_token='</s>',\n#         sp_model_kwargs: Optional[Dict[str, Any]] = None,\n#         add_bos_token=True,\n#         add_eos_token=False,\n#         decode_with_prefix_space=False,\n#         clean_up_tokenization_spaces=False,\n#         **kwargs,\n# 
    ):\n#         super().__init__(\n#             vocab_file=vocab_file,\n#             unk_token=unk_token,\n#             bos_token=bos_token,\n#             eos_token=eos_token,\n#             pad_token=pad_token,\n#             sp_model_kwargs=sp_model_kwargs,\n#             add_bos_token=add_bos_token,\n#             add_eos_token=add_eos_token,\n#             decode_with_prefix_space=decode_with_prefix_space,\n#             clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n#             **kwargs,\n#         )\n#         self._add_bos_token = add_bos_token\n#         self._add_eos_token = add_eos_token\n#         self.update_post_processor()\n#         self.vocab_file = vocab_file\n#\n#     @property\n#     def can_save_slow_tokenizer(self) -> bool:\n#         return os.path.isfile(self.vocab_file) if self.vocab_file else False\n#\n#     def update_post_processor(self):\n#         \"\"\"\n#         Updates the underlying post processor with the current `bos_token` and `eos_token`.\n#         \"\"\"\n#         bos = self.bos_token\n#         bos_token_id = self.bos_token_id\n#         if bos is None and self.add_bos_token:\n#             raise ValueError('add_bos_token = True but bos_token = None')\n#\n#         eos = self.eos_token\n#         eos_token_id = self.eos_token_id\n#         if eos is None and self.add_eos_token:\n#             raise ValueError('add_eos_token = True but eos_token = None')\n#\n#         single = f\"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}\"\n#         pair = f\"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}\"\n#\n#         special_tokens = []\n#         if self.add_bos_token:\n#             special_tokens.append((bos, bos_token_id))\n#         if self.add_eos_token:\n#             special_tokens.append((eos, eos_token_id))\n#         self._tokenizer.post_processor = processors.TemplateProcessing(\n# 
            single=single, pair=pair, special_tokens=special_tokens\n#         )\n#\n#     @property\n#     def add_eos_token(self):\n#         return self._add_eos_token\n#\n#     @property\n#     def add_bos_token(self):\n#         return self._add_bos_token\n#\n#     @add_eos_token.setter\n#     def add_eos_token(self, value):\n#         self._add_eos_token = value\n#         self.update_post_processor()\n#\n#     @add_bos_token.setter\n#     def add_bos_token(self, value):\n#         self._add_bos_token = value\n#         self.update_post_processor()\n#\n#     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:\n#         if not self.can_save_slow_tokenizer:\n#             raise ValueError(\n#                 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '\n#                 'tokenizer.'\n#             )\n#\n#         if not os.path.isdir(save_directory):\n#             logger.error(f'Vocabulary path ({save_directory}) should be a directory')\n#             return\n#         out_vocab_file = os.path.join(\n#             save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']\n#         )\n#\n#         if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):\n#             copyfile(self.vocab_file, out_vocab_file)\n#\n#         return (out_vocab_file,)\n\n\n# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer\nclass InternLM2Tokenizer(PreTrainedTokenizer):\n    \"\"\"\n    Construct an InternLM2 tokenizer. 
Based on byte-level Byte-Pair-Encoding.\n\n    Args:\n        vocab_file (`str`):\n            Path to the vocabulary file.\n    \"\"\"\n\n    vocab_files_names = VOCAB_FILES_NAMES\n    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP\n    model_input_names = [\"input_ids\", \"attention_mask\"]\n    _auto_class = \"AutoTokenizer\"\n\n    def __init__(\n        self,\n        vocab_file,\n        unk_token=\"<unk>\",\n        bos_token=\"<s>\",\n        eos_token=\"</s>\",\n        pad_token=\"</s>\",\n        sp_model_kwargs: Optional[Dict[str, Any]] = None,\n        add_bos_token=True,\n        add_eos_token=False,\n        decode_with_prefix_space=False,\n        clean_up_tokenization_spaces=False,\n        **kwargs,\n    ):\n        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs\n        self.vocab_file = vocab_file\n        self.add_bos_token = add_bos_token\n        self.add_eos_token = add_eos_token\n        self.decode_with_prefix_space = decode_with_prefix_space\n        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)\n        self.sp_model.Load(vocab_file)\n        self._no_prefix_space_tokens = None\n        super().__init__(\n            bos_token=bos_token,\n            eos_token=eos_token,\n            unk_token=unk_token,\n            pad_token=pad_token,\n            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n            **kwargs,\n        )\n\n    @property\n    def no_prefix_space_tokens(self):\n        if self._no_prefix_space_tokens is None:\n            vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))\n            self._no_prefix_space_tokens = {\n                i for i, tok in enumerate(vocab) if not tok.startswith(\"▁\")\n            }\n        return self._no_prefix_space_tokens\n\n    @property\n    def vocab_size(self):\n        \"\"\"Returns vocab size\"\"\"\n        return self.sp_model.get_piece_size()\n\n    
@property\n    def bos_token_id(self) -> Optional[int]:\n        return self.sp_model.bos_id()\n\n    @property\n    def eos_token_id(self) -> Optional[int]:\n        return self.sp_model.eos_id()\n\n    def get_vocab(self):\n        \"\"\"Returns vocab as a dict\"\"\"\n        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}\n        vocab.update(self.added_tokens_encoder)\n        return vocab\n\n    def _tokenize(self, text):\n        \"\"\"Returns a tokenized string.\"\"\"\n        return self.sp_model.encode(text, out_type=str)\n\n    def _convert_token_to_id(self, token):\n        \"\"\"Converts a token (str) in an id using the vocab.\"\"\"\n        return self.sp_model.piece_to_id(token)\n\n    def _convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n        token = self.sp_model.IdToPiece(index)\n        return token\n\n    def _maybe_add_prefix_space(self, tokens, decoded):\n        if tokens and tokens[0] not in self.no_prefix_space_tokens:\n            return \" \" + decoded\n        else:\n            return decoded\n\n    def convert_tokens_to_string(self, tokens):\n        \"\"\"Converts a sequence of tokens (string) in a single string.\"\"\"\n        current_sub_tokens = []\n        out_string = \"\"\n        prev_is_special = False\n        for token in tokens:\n            # make sure that special tokens are not decoded using sentencepiece model\n            if token in self.all_special_tokens:\n                if not prev_is_special:\n                    out_string += \" \"\n                out_string += self.sp_model.decode(current_sub_tokens) + token\n                prev_is_special = True\n                current_sub_tokens = []\n            else:\n                current_sub_tokens.append(token)\n                prev_is_special = False\n        out_string += self.sp_model.decode(current_sub_tokens)\n        out_string = 
self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)\n        return out_string[1:]\n\n    def save_vocabulary(\n        self, save_directory, filename_prefix: Optional[str] = None\n    ) -> Tuple[str]:\n        \"\"\"\n        Save the vocabulary and special tokens file to a directory.\n\n        Args:\n            save_directory (`str`):\n                The directory in which to save the vocabulary.\n\n        Returns:\n            `Tuple(str)`: Paths to the files saved.\n        \"\"\"\n        if not os.path.isdir(save_directory):\n            logger.error(f\"Vocabulary path ({save_directory}) should be a directory\")\n            return\n        out_vocab_file = os.path.join(\n            save_directory,\n            (filename_prefix + \"-\" if filename_prefix else \"\")\n            + VOCAB_FILES_NAMES[\"vocab_file\"],\n        )\n\n        if os.path.abspath(self.vocab_file) != os.path.abspath(\n            out_vocab_file\n        ) and os.path.isfile(self.vocab_file):\n            copyfile(self.vocab_file, out_vocab_file)\n        elif not os.path.isfile(self.vocab_file):\n            with open(out_vocab_file, \"wb\") as fi:\n                content_spiece_model = self.sp_model.serialized_model_proto()\n                fi.write(content_spiece_model)\n\n        return (out_vocab_file,)\n\n    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):\n        if self.add_bos_token:\n            bos_token_ids = [self.bos_token_id]\n        else:\n            bos_token_ids = []\n\n        output = bos_token_ids + token_ids_0\n\n        if token_ids_1 is not None:\n            output = output + token_ids_1\n\n        if self.add_eos_token:\n            output = output + [self.eos_token_id]\n\n        return output\n\n    def get_special_tokens_mask(\n        self,\n        token_ids_0: List[int],\n        token_ids_1: Optional[List[int]] = None,\n        already_has_special_tokens: bool = False,\n    ) -> List[int]:\n        \"\"\"\n 
       Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding\n        special tokens using the tokenizer `prepare_for_model` method.\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n            already_has_special_tokens (`bool`, *optional*, defaults to `False`):\n                Whether or not the token list is already formatted with special tokens for the model.\n\n        Returns:\n            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.\n        \"\"\"\n        if already_has_special_tokens:\n            return super().get_special_tokens_mask(\n                token_ids_0=token_ids_0,\n                token_ids_1=token_ids_1,\n                already_has_special_tokens=True,\n            )\n\n        if token_ids_1 is None:\n            return [1] + ([0] * len(token_ids_0)) + [1]\n        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]\n\n    def create_token_type_ids_from_sequences(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Create a mask from the two sequences passed to be used in a sequence-pair classification task. 
InternLM2 does not make\n        use of token type ids, therefore a list of zeros is returned.\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n\n        Returns:\n            `List[int]`: List of zeros.\n        \"\"\"\n        eos = [self.eos_token_id]\n\n        if token_ids_1 is None:\n            return len(token_ids_0 + eos) * [0]\n        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]\n\n\nTOKENIZER_MAPPING.register(\n    InternVLChatConfig, (InternLM2Tokenizer, None), exist_ok=True\n)\n"
  },
  {
    "path": "python/sglang/srt/configs/janus_pro.py",
    "content": "# Adapted from:\n# https://github.com/deepseek-ai/Janus/tree/main/janus/models\n\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Tuple, Union\n\nimport numpy as np\nimport PIL\nimport torch\nfrom PIL.Image import Image\nfrom transformers import (\n    BaseImageProcessor,\n    BatchFeature,\n    LlamaConfig,\n    LlamaTokenizerFast,\n    PretrainedConfig,\n    ProcessorMixin,\n)\nfrom transformers.image_utils import to_numpy_array\n\nfrom sglang.srt.configs.utils import register_image_processor, register_processor\nfrom sglang.srt.multimodal.mm_utils import expand2square\n\n\nclass DictToObject(dict):\n    def __init__(self, dictionary):\n        super().__init__(dictionary)\n\n        for key, value in dictionary.items():\n            if isinstance(value, dict):\n                value = DictToObject(value)\n            setattr(self, key, value)\n\n\nclass VisionConfig(PretrainedConfig):\n    model_type = \"vision\"\n    cls: str = \"\"\n    params = {}\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n        self.cls = kwargs.get(\"cls\", \"\")\n        if not isinstance(self.cls, str):\n            self.cls = self.cls.__name__\n\n        self.params = kwargs.get(\"params\", {})\n\n\nclass GenAlignerConfig(PretrainedConfig):\n    model_type = \"gen_aligner\"\n    cls: str = \"\"\n    params = {}\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n        self.cls = kwargs.get(\"cls\", \"\")\n        if not isinstance(self.cls, str):\n            self.cls = self.cls.__name__\n\n        self.params = kwargs.get(\"params\", {})\n\n\nclass GenHeadConfig(PretrainedConfig):\n    model_type = \"gen_head\"\n    cls: str = \"\"\n    params = {}\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n        self.cls = kwargs.get(\"cls\", \"\")\n        if not isinstance(self.cls, str):\n            self.cls = self.cls.__name__\n\n        self.params = 
kwargs.get(\"params\", {})\n\n\nclass AlignerConfig(PretrainedConfig):\n    model_type = \"aligner\"\n    cls: str = \"\"\n    params = {}\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n        self.cls = kwargs.get(\"cls\", \"\")\n        if not isinstance(self.cls, str):\n            self.cls = self.cls.__name__\n\n        self.params = kwargs.get(\"params\", {})\n\n\nclass GenVisionConfig(PretrainedConfig):\n    model_type = \"gen_vision\"\n    cls: str = \"\"\n    params = {}\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n        self.cls = kwargs.get(\"cls\", \"\")\n        if not isinstance(self.cls, str):\n            self.cls = self.cls.__name__\n\n        self.params = kwargs.get(\"params\", {})\n\n\n@dataclass\nclass SigLIPVisionCfg:\n    width: int = 1152\n    layers: Union[Tuple[int, int, int, int], int] = 27\n    heads: int = 16\n    patch_size: int = 14\n    image_size: Union[Tuple[int, int], int] = 336\n    global_pool: str = \"map\"\n    mlp_ratio: float = 3.7362\n    class_token: bool = False\n    num_classes: int = 0\n    use_checkpoint: bool = False\n\n\nclass MultiModalityConfig(PretrainedConfig):\n    model_type = \"multi_modality\"\n    vision_config: VisionConfig = None\n    aligner_config: AlignerConfig = None\n\n    gen_vision_config: GenVisionConfig = None\n    gen_aligner_config: GenAlignerConfig = None\n    gen_head_config: GenHeadConfig = None\n\n    language_config: LlamaConfig = None\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        vision_config = kwargs.get(\"vision_config\", {})\n        self.vision_config = VisionConfig(**vision_config)\n\n        aligner_config = kwargs.get(\"aligner_config\", {})\n        self.aligner_config = AlignerConfig(**aligner_config)\n\n        gen_vision_config = kwargs.get(\"gen_vision_config\", {})\n        self.gen_vision_config = GenVisionConfig(**gen_vision_config)\n\n        gen_aligner_config = 
kwargs.get(\"gen_aligner_config\", {})\n        self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)\n\n        gen_head_config = kwargs.get(\"gen_head_config\", {})\n        self.gen_head_config = GenHeadConfig(**gen_head_config)\n\n        language_config = kwargs.get(\"language_config\", {})\n        if isinstance(language_config, LlamaConfig):\n            self.language_config = language_config\n        else:\n            self.language_config = LlamaConfig(**language_config)\n\n\nclass VLMImageProcessor(BaseImageProcessor):\n    model_input_names = [\"pixel_values\"]\n\n    def __init__(\n        self,\n        image_size: int,\n        min_size: int = 14,\n        image_mean: Union[Tuple[float, float, float], List[float]] = (\n            0.48145466,\n            0.4578275,\n            0.40821073,\n        ),\n        image_std: Union[Tuple[float, float, float], List[float]] = (\n            0.26862954,\n            0.26130258,\n            0.27577711,\n        ),\n        rescale_factor: float = 1.0 / 255.0,\n        do_normalize: bool = True,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        self.image_size = image_size\n        self.rescale_factor = rescale_factor\n        self.image_mean = image_mean\n        self.image_std = image_std\n        self.min_size = min_size\n        self.do_normalize = do_normalize\n\n        if image_mean is None:\n            self.background_color = (127, 127, 127)\n        else:\n            self.background_color = tuple([int(x * 255) for x in image_mean])\n\n    def resize(self, pil_img: Image) -> np.ndarray:\n        \"\"\"\n\n        Args:\n            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB\n\n        Returns:\n            x (np.ndarray): [3, self.image_size, self.image_size]\n        \"\"\"\n\n        width, height = pil_img.size\n        max_size = max(width, height)\n\n        size = [\n            max(int(height / max_size * self.image_size), self.min_size),\n           
 max(int(width / max_size * self.image_size), self.min_size),\n        ]\n\n        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:\n            # print(f\"orig size = {pil_img.size}, new size = {size}\")\n            raise ValueError(\"Invalid size!\")\n\n        def resize(\n            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True\n        ):\n            if isinstance(size, int):\n                w, h = pil_img.size\n                if (w <= h and w == size) or (h <= w and h == size):\n                    return pil_img\n                if w < h:\n                    ow = size\n                    oh = int(size * h / w)\n                else:\n                    oh = size\n                    ow = int(size * w / h)\n                size = (ow, oh)\n            else:\n                size = (size[1], size[0])\n\n            return pil_img.resize(\n                size, resample=interpolation, reducing_gap=None if antialias else 3.0\n            )\n\n        pil_img = resize(\n            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True\n        )\n\n        pil_img = expand2square(pil_img, self.background_color)\n        x = to_numpy_array(pil_img)\n\n        # [H, W, 3] -> [3, H, W]\n        x = np.transpose(x, (2, 0, 1))\n\n        return x\n\n    def preprocess(self, images, return_tensors: str = \"pt\", **kwargs) -> BatchFeature:\n        # resize and pad to [self.image_size, self.image_size]\n        # then convert from [H, W, 3] to [3, H, W]\n        if not isinstance(images, list):\n            images = [images]\n        images: List[np.ndarray] = [self.resize(image) for image in images]\n        images = [image[:3, ...] 
for image in images]\n\n        # rescale from [0, 255] -> [0, 1]\n        images = [\n            self.rescale(\n                image=image,\n                scale=self.rescale_factor,\n                input_data_format=\"channels_first\",\n            )\n            for image in images\n        ]\n\n        # normalize\n        if self.do_normalize:\n            images = [\n                self.normalize(\n                    image=image,\n                    mean=self.image_mean,\n                    std=self.image_std,\n                    input_data_format=\"channels_first\",\n                )\n                for image in images\n            ]\n        data = {\"pixel_values\": images}\n        return BatchFeature(data=data, tensor_type=return_tensors)\n\n    @property\n    def default_shape(self):\n        return [3, self.image_size, self.image_size]\n\n\nclass DictOutput(object):\n    def items(self):\n        return self.__dict__.items()\n\n    def keys(self):\n        return self.__dict__.keys()\n\n    def __getitem__(self, item):\n        return self.__dict__[item]\n\n    def __contains__(self, key):\n        return key in self.__dict__\n\n    def __setitem__(self, key, value):\n        self.__dict__[key] = value\n\n\n@dataclass\nclass VLChatProcessorOutput(DictOutput):\n    sft_format: str\n    input_ids: torch.Tensor\n    pixel_values: torch.Tensor\n    num_image_tokens: torch.IntTensor\n\n    def __len__(self):\n        return len(self.input_ids)\n\n\n@dataclass\nclass BatchedVLChatProcessorOutput(DictOutput):\n    sft_format: List[str]\n    input_ids: torch.Tensor\n    pixel_values: torch.Tensor\n    attention_mask: torch.Tensor\n    images_seq_mask: torch.BoolTensor\n    images_emb_mask: torch.BoolTensor\n\n\n# FIXME: had to place Official Processor here, since image_processor module would not be imported in all threads,\n# hence AutoProcessor registration would not be effective in some cases\nclass VLChatProcessor(ProcessorMixin):\n    
image_processor_class = \"AutoImageProcessor\"\n    tokenizer_class = (\"LlamaTokenizer\", \"LlamaTokenizerFast\")\n\n    attributes = [\"image_processor\", \"tokenizer\"]\n\n    def __init__(\n        self,\n        image_processor: VLMImageProcessor,\n        tokenizer: LlamaTokenizerFast,\n        image_tag: str = \"<image_placeholder>\",\n        image_start_tag: str = \"<begin_of_image>\",\n        image_end_tag: str = \"<end_of_image>\",\n        pad_tag: str = \"<｜▁pad▁｜>\",\n        num_image_tokens: int = 576,\n        add_special_token: bool = False,\n        sft_format: str = \"deepseek\",\n        mask_prompt: bool = True,\n        ignore_id: int = -100,\n        **kwargs,\n    ):\n        self.image_processor = image_processor\n        self.tokenizer = tokenizer\n\n        image_id = self.tokenizer.vocab.get(image_tag)\n        if image_id is None:\n            special_tokens = [image_tag]\n            special_tokens_dict = {\"additional_special_tokens\": special_tokens}\n            self.tokenizer.add_special_tokens(special_tokens_dict)\n            # print(f\"Add image tag = {image_tag} to the tokenizer\")\n\n        self.image_tag = image_tag\n        self.image_start_tag = image_start_tag\n        self.image_end_tag = image_end_tag\n        self.pad_tag = pad_tag\n\n        self.num_image_tokens = num_image_tokens\n        self.add_special_token = add_special_token\n        self.sft_format = sft_format\n        self.ignore_id = ignore_id\n\n        super().__init__(\n            image_processor,\n            tokenizer,\n            **kwargs,\n        )\n\n    @property\n    def image_token(self):\n        return self.image_tag\n\n    @property\n    def image_id(self) -> int:\n        image_id = self.tokenizer.vocab.get(self.image_tag)\n        return image_id\n\n    @property\n    def image_start_id(self):\n        image_start_id = self.tokenizer.vocab.get(self.image_start_tag)\n        return image_start_id\n\n    @property\n    def 
image_end_id(self):\n        image_end_id = self.tokenizer.vocab.get(self.image_end_tag)\n        return image_end_id\n\n    @property\n    def image_start_token(self):\n        return self.image_start_tag\n\n    @property\n    def image_end_token(self):\n        return self.image_end_tag\n\n    @property\n    def pad_id(self):\n        pad_id = self.tokenizer.vocab.get(self.pad_tag)\n        return pad_id\n\n    def add_image_token(\n        self,\n        image_indices: List[int],\n        input_ids: torch.LongTensor,\n    ):\n        \"\"\"\n\n        Args:\n            image_indices (List[int]): [index_0, index_1, ..., index_j]\n            input_ids (torch.LongTensor): [N]\n\n        Returns:\n            input_ids (torch.LongTensor): [N + image tokens]\n            num_image_tokens (torch.IntTensor): [n_images]\n        \"\"\"\n\n        input_slices = []\n\n        start = 0\n        for index in image_indices:\n            if self.add_special_token:\n                end = index + 1\n            else:\n                end = index\n\n            # original text tokens\n            input_slices.append(input_ids[start:end])\n\n            # add boi, image tokens, eoi and set the mask as False\n            input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))\n            input_slices.append(\n                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)\n            )\n            input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))\n            start = index + 1\n\n        # the left part\n        input_slices.append(input_ids[start:])\n\n        # concat all slices\n        input_ids = torch.cat(input_slices, dim=0)\n        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))\n\n        return input_ids, num_image_tokens\n\n    def process_one(\n        self,\n        prompt: str = None,\n        images: List[Image] = None,\n        **kwargs,\n    ):\n        
\"\"\"\n\n        Args:\n            prompt (str): the formatted prompt;\n            images (List[ImageType]): the list of images;\n            **kwargs:\n\n        Returns:\n            outputs (BaseProcessorOutput): the output of the processor,\n                - input_ids (torch.LongTensor): [N + image tokens]\n                - target_ids (torch.LongTensor): [N + image tokens]\n                - images (torch.FloatTensor): [n_images, 3, H, W]\n                - image_id (int): the id of the image token\n                - num_image_tokens (List[int]): the number of image tokens\n        \"\"\"\n\n        sft_format = prompt\n        # tokenize\n        input_ids = self.tokenizer.encode(sft_format)\n        input_ids = torch.LongTensor(input_ids)\n\n        # add image tokens to the input_ids\n        image_token_mask: torch.Tensor = (input_ids == self.image_id).to(torch.bool)\n        image_indices = image_token_mask.nonzero()\n        input_ids, num_image_tokens = self.add_image_token(\n            image_indices=image_indices,\n            input_ids=input_ids,\n        )\n\n        # load images\n        images_outputs = self.image_processor(images, return_tensors=\"pt\")\n\n        prepare = VLChatProcessorOutput(\n            sft_format=sft_format,\n            input_ids=input_ids,\n            pixel_values=images_outputs.pixel_values,\n            num_image_tokens=num_image_tokens,\n        )\n\n        return prepare\n\n    def __call__(\n        self,\n        *,\n        prompt: str = None,\n        conversations: List[Dict[str, str]] = None,\n        images: List[Image] = None,\n        force_batchify: bool = True,\n        **kwargs,\n    ):\n        \"\"\"\n\n        Args:\n            prompt (str): the formatted prompt;\n            conversations (List[Dict]): conversations with a list of messages;\n            images (List[ImageType]): the list of images;\n            force_batchify (bool): force batchify the inputs;\n            **kwargs:\n\n        
Returns:\n            outputs (BaseProcessorOutput): the output of the processor,\n                - input_ids (torch.LongTensor): [N + image tokens]\n                - images (torch.FloatTensor): [n_images, 3, H, W]\n                - image_id (int): the id of the image token\n                - num_image_tokens (List[int]): the number of image tokens\n        \"\"\"\n\n        prepare = self.process_one(\n            prompt=prompt, conversations=conversations, images=images\n        )\n\n        if force_batchify:\n            prepare = self.batchify([prepare])\n\n        return prepare\n\n    def batchify(\n        self, prepare_list: List[VLChatProcessorOutput]\n    ) -> BatchedVLChatProcessorOutput:\n        \"\"\"\n        Preprocesses the inputs for multimodal inference.\n\n        Args:\n            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.\n\n        Returns:\n            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.\n        \"\"\"\n\n        batch_size = len(prepare_list)\n        sft_format = []\n        n_images = []\n        seq_lens = []\n        for prepare in prepare_list:\n            n_images.append(len(prepare.num_image_tokens))\n            seq_lens.append(len(prepare))\n\n        input_token_max_len = max(seq_lens)\n        max_n_images = max(1, max(n_images))\n\n        batched_input_ids = torch.full(\n            (batch_size, input_token_max_len), self.pad_id\n        ).long()  # FIXME\n        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()\n        batched_pixel_values = torch.zeros(\n            (batch_size, max_n_images, *self.image_processor.default_shape)\n        ).float()\n        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()\n        batched_images_emb_mask = torch.zeros(\n            (batch_size, max_n_images, self.num_image_tokens)\n        ).bool()\n\n        for i, prepare in 
enumerate(prepare_list):\n            input_ids = prepare.input_ids\n            seq_len = len(prepare)\n            n_image = len(prepare.num_image_tokens)\n            # left-padding\n            batched_attention_mask[i, -seq_len:] = 1\n            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)\n            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id\n\n            if n_image > 0:\n                batched_pixel_values[i, :n_image] = prepare.pixel_values\n                for j, n_image_tokens in enumerate(prepare.num_image_tokens):\n                    batched_images_emb_mask[i, j, :n_image_tokens] = True\n\n            sft_format.append(prepare.sft_format)\n\n        batched_prepares = BatchedVLChatProcessorOutput(\n            input_ids=batched_input_ids,\n            attention_mask=batched_attention_mask,\n            pixel_values=batched_pixel_values,\n            images_seq_mask=batched_images_seq_mask,\n            images_emb_mask=batched_images_emb_mask,\n            sft_format=sft_format,\n        )\n\n        return batched_prepares\n\n\nclass VLMImageProcessorConfig(PretrainedConfig):\n    model_type = \"deepseek_vlm\"\n    image_size: int = None\n    min_size: int = None\n    image_mean: Union[Tuple[float, float, float], List[float]] = None\n    image_std: Union[Tuple[float, float, float], List[float]] = None\n    rescale_factor: float = None\n    do_normalize: bool = None\n\n    def __init__(\n        self,\n        image_size: int,\n        min_size: int = 14,\n        image_mean: Union[Tuple[float, float, float], List[float]] = (\n            0.48145466,\n            0.4578275,\n            0.40821073,\n        ),\n        image_std: Union[Tuple[float, float, float], List[float]] = (\n            0.26862954,\n            0.26130258,\n            0.27577711,\n        ),\n        rescale_factor: float = 1.0 / 255.0,\n        do_normalize: bool = True,\n        **kwargs,\n    ):\n        self.image_size = 
image_size\n        self.min_size = min_size\n        self.image_mean = image_mean\n        self.image_std = image_std\n        self.rescale_factor = rescale_factor\n        self.do_normalize = do_normalize\n\n        super().__init__(**kwargs)\n\n\nregister_processor(MultiModalityConfig, VLChatProcessor)\nregister_image_processor(MultiModalityConfig, VLMImageProcessor)\n"
  },
  {
    "path": "python/sglang/srt/configs/jet_nemotron.py",
    "content": "from dataclasses import dataclass\nfrom typing import Any\n\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom sglang.srt.configs.mamba_utils import (\n    Mamba2CacheParams,\n    Mamba2StateShape,\n    mamba2_state_dtype,\n)\n\n\n@dataclass\nclass JetBlockConfig:\n    mode: str\n    expand_v: float\n    num_heads: int\n    head_dim: int\n    norm_eps: str\n    conv_size: int\n    dconv_generator_reduction: int\n    dconv_implementation: str\n\n\nclass JetNemotronConfig(PretrainedConfig):\n    model_type: str = \"jet_nemotron\"\n\n    efficient_attention_config: dict[str, dict[str, Any]] = None\n    hidden_act: str = None\n    hidden_size: int = None\n    initializer_range: float = None\n    intermediate_size: int = None\n    layer_types: list[str] = None\n    max_position_embeddings: int = None\n    num_attention_heads: int = None\n    num_key_value_heads: int = None\n    rms_norm_eps: float = None\n    rope_scaling: None = None\n    rope_theta: float = None\n\n    @property\n    def full_attention_layer_ids(self) -> list[int]:\n        return [\n            idx\n            for idx, layer_type in enumerate(self.layer_types)\n            if layer_type in (\"attn\", \"swa\")\n        ]\n\n    @property\n    def linear_layer_ids(self) -> list[int]:\n        return [\n            idx\n            for idx, layer_type in enumerate(self.layer_types)\n            if layer_type == \"jet\"\n        ]\n\n    @property\n    def mamba2_cache_params(self) -> Mamba2CacheParams:\n        from sglang.srt.layers.dp_attention import get_attention_tp_size\n\n        jet_block_config = JetBlockConfig(**self.efficient_attention_config[\"jet\"])\n\n        num_heads = jet_block_config.num_heads\n        head_k_dim = jet_block_config.head_dim\n        head_v_dim = int(head_k_dim * jet_block_config.expand_v)\n        total_v_dim = num_heads * head_v_dim\n\n        shape = Mamba2StateShape.create(\n            tp_world_size=get_attention_tp_size(),\n    
        intermediate_size=total_v_dim,\n            n_groups=num_heads,\n            num_heads=num_heads,\n            head_dim=head_v_dim,\n            state_size=head_k_dim,\n            conv_kernel=jet_block_config.conv_size,\n        )\n\n        return Mamba2CacheParams(\n            shape=shape, layers=self.linear_layer_ids, dtype=mamba2_state_dtype(self)\n        )\n"
  },
  {
    "path": "python/sglang/srt/configs/jet_vlm.py",
    "content": "from typing import Any\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.models.siglip import SiglipVisionConfig\n\nfrom sglang.srt.configs.jet_nemotron import JetNemotronConfig\nfrom sglang.srt.configs.mamba_utils import Mamba2CacheParams\n\n\nclass JetVLMConfig(PretrainedConfig):\n    model_type = \"jet_vlm\"\n    sub_configs = {\n        \"text_config\": JetNemotronConfig,\n        \"vision_config\": SiglipVisionConfig,\n    }\n    _auto_class = \"AutoConfig\"\n\n    def __init__(\n        self,\n        *,\n        text_config: dict[str, Any] | None = None,\n        vision_config: dict[str, Any] | None = None,\n        image_token_id: int | None = None,\n        video_token_id: int | None = None,\n        **kwargs,\n    ):\n        self.text_config = (\n            JetNemotronConfig(**text_config)\n            if text_config is not None\n            else JetNemotronConfig()\n        )\n        self.vision_config = (\n            SiglipVisionConfig(**vision_config)\n            if vision_config is not None\n            else SiglipVisionConfig()\n        )\n\n        self.image_token_id = image_token_id if image_token_id is not None else -1\n        self.video_token_id = video_token_id if video_token_id is not None else -1\n\n        super().__init__(**kwargs)\n\n    @property\n    def full_attention_layer_ids(self) -> list[int]:\n        return self.text_config.full_attention_layer_ids\n\n    @property\n    def linear_layer_ids(self) -> list[int]:\n        return self.text_config.linear_layer_ids\n\n    @property\n    def mamba2_cache_params(self) -> Mamba2CacheParams:\n        return self.text_config.mamba2_cache_params\n"
  },
  {
    "path": "python/sglang/srt/configs/kimi_k25.py",
    "content": "\"\"\"\nKimi K25 Model Configuration.\n\"\"\"\n\nfrom transformers import DeepseekV3Config\nfrom transformers.configuration_utils import PretrainedConfig\n\n\nclass KimiK25VisionConfig(PretrainedConfig):\n    \"\"\"Vision configuration for K2-VL (vision tower + mm projector).\n\n    Args:\n        Vision Tower Parameters:\n            patch_size: Patch size for vision tower.\n            init_pos_emb_height: Initial position embedding height.\n            init_pos_emb_width: Initial position embedding width.\n            init_pos_emb_time: Initial position embedding time dimension.\n            pos_emb_type: Type of position embedding.\n            num_attention_heads: Number of attention heads in vision tower.\n            num_hidden_layers: Number of hidden layers in vision tower.\n            hidden_size: Hidden size of vision tower.\n            intermediate_size: Intermediate size in vision tower FFN.\n            merge_kernel_size: Kernel size for spatial patch merging.\n            video_attn_type: Type of video attention.\n            merge_type: Type of merge operation.\n\n        MM Projector Parameters:\n            mm_projector_type: Type of multimodal projector.\n            mm_hidden_size: Hidden size for projector (defaults to hidden_size).\n            projector_hidden_act: Activation function for projector.\n            projector_ln_eps: Layer norm epsilon for projector.\n    \"\"\"\n\n    model_type = \"kimi_k25\"\n\n    def __init__(\n        self,\n        # Vision Tower\n        patch_size: int = 14,\n        init_pos_emb_height: int = 64,\n        init_pos_emb_width: int = 64,\n        init_pos_emb_time: int = 4,\n        pos_emb_type: str = \"divided_fixed\",\n        num_attention_heads: int = 16,\n        num_hidden_layers: int = 27,\n        hidden_size: int = 1152,\n        intermediate_size: int = 4304,\n        merge_kernel_size: tuple[int, int] = (2, 2),\n        video_attn_type: str = \"spatial_temporal\",\n        
merge_type: str = \"sd2_tpool\",\n        # MM Projector\n        mm_projector_type: str = \"patchmerger\",\n        mm_hidden_size: int | None = None,\n        projector_hidden_act: str = \"gelu\",\n        projector_ln_eps: float = 1e-5,\n        text_hidden_size: int = 7168,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        # Vision Tower\n        self.patch_size = patch_size\n        self.init_pos_emb_height = init_pos_emb_height\n        self.init_pos_emb_width = init_pos_emb_width\n        self.init_pos_emb_time = init_pos_emb_time\n        self.pos_emb_type = pos_emb_type\n        self.num_attention_heads = num_attention_heads\n        self.num_hidden_layers = num_hidden_layers\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.merge_kernel_size = merge_kernel_size\n        self.video_attn_type = video_attn_type\n        self.merge_type = merge_type\n        # MM Projector\n        self.mm_projector_type = mm_projector_type\n        if mm_hidden_size is not None:\n            self.mm_hidden_size = mm_hidden_size\n        else:\n            self.mm_hidden_size = hidden_size\n        self.projector_hidden_act = projector_hidden_act\n        self.projector_ln_eps = projector_ln_eps\n        self.text_hidden_size = text_hidden_size\n\n\nclass KimiK25Config(PretrainedConfig):\n    \"\"\"K2-VL model configuration.\n\n    K2-VL extends Kimi-VL with video support using video-chunks.\n    A video-chunk consists of multiple consecutive frames (default: 4)\n    that are processed together with temporal pooling.\n\n    Args:\n        text_config: Configuration for the text model (DeepseekV3).\n\n        Vision Tower Parameters:\n            patch_size: Patch size for vision tower.\n            init_pos_emb_height: Initial position embedding height.\n            init_pos_emb_width: Initial position embedding width.\n            init_pos_emb_time: Initial position embedding time dimension.\n      
      pos_emb_type: Type of position embedding.\n            vt_num_attention_heads: Number of attention heads in vision tower.\n            vt_num_hidden_layers: Number of hidden layers in vision tower.\n            vt_hidden_size: Hidden size of vision tower.\n            vt_intermediate_size: Intermediate size in vision tower FFN.\n            merge_kernel_size: Kernel size for spatial patch merging.\n            video_attn_type: Type of video attention.\n            merge_type: Type of merge operation.\n\n        Video-Chunk Parameters:\n            temporal_merge_kernel_size: Number of frames per video chunk.\n                Default is 4, meaning 4 frames are merged into 1 chunk.\n            sample_fps: Video sampling frame rate.\n            timestamp_mode: Format for chunk timestamps.\n\n        MM Projector Parameters:\n            mm_projector_type: Type of multimodal projector.\n            mm_hidden_size: Hidden size from vision tower.\n            projector_hidden_act: Activation function for projector.\n            projector_ln_eps: Layer norm epsilon for projector.\n\n        Other Parameters:\n            ignore_index: The ignore index for the loss function.\n            media_placeholder_token_id: The token ID for media placeholders.\n            pad_token_id: The token ID for padding.\n    \"\"\"\n\n    model_type = \"kimi_k25\"\n\n    def __init__(\n        self,\n        text_config: dict | DeepseekV3Config | None = None,\n        vision_config: dict | KimiK25VisionConfig | None = None,\n        # Other parameters\n        ignore_index: int = -100,\n        media_placeholder_token_id: int = 163605,\n        pad_token_id: int = 0,\n        use_unified_vision_chunk: bool = False,\n        video_placeholder: str = \"<|kimi_k25_video_placeholder|>\",\n        **kwargs,\n    ):\n        if text_config is None:\n            text_config = DeepseekV3Config()\n        elif isinstance(text_config, dict):\n            text_config = 
DeepseekV3Config(**text_config)\n\n        if vision_config is None:\n            vision_config = KimiK25VisionConfig()\n        elif isinstance(vision_config, dict):\n            vision_config = KimiK25VisionConfig(**vision_config)\n        self.vision_config = vision_config\n        self.text_config = text_config\n        # Other config\n        self.ignore_index = ignore_index\n        self.media_placeholder_token_id = media_placeholder_token_id\n        self.use_unified_vision_chunk = use_unified_vision_chunk\n        self.video_placeholder = video_placeholder\n\n        # Propagate quantization config from text model\n        if getattr(self.text_config, \"quantization_config\", None) is not None:\n            self.quantization_config = self.text_config.quantization_config\n\n        super().__init__(pad_token_id=pad_token_id, **kwargs)\n\n    @property\n    def hidden_size(self) -> int:\n        \"\"\"Get hidden size from text config for compatibility.\"\"\"\n        return self.text_config.hidden_size\n\n    @property\n    def vocab_size(self) -> int:\n        \"\"\"Get vocab size from text config for compatibility.\"\"\"\n        return self.text_config.vocab_size\n"
  },
  {
    "path": "python/sglang/srt/configs/kimi_linear.py",
    "content": "# Adapted from: https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/transformers_utils/configs/kimi_linear.py\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom sglang.srt.configs.mamba_utils import KimiLinearCacheParams, KimiLinearStateShape\n\n\nclass KimiLinearConfig(PretrainedConfig):\n    model_type = \"kimi_linear\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        model_type=\"kimi_linear\",\n        vocab_size=163840,\n        hidden_size=4096,\n        head_dim=None,\n        intermediate_size=11008,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=None,\n        hidden_act=\"silu\",\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        tie_word_embeddings=False,\n        moe_intermediate_size: int | None = None,\n        moe_renormalize: bool = True,\n        moe_router_activation_func: str = \"sigmoid\",\n        num_experts: int | None = None,\n        num_experts_per_token: int | None = None,\n        num_shared_experts: int = 0,\n        routed_scaling_factor: float = 1.0,\n        first_k_dense_replace: int = 0,\n        moe_layer_freq: int = 1,\n        use_grouped_topk: bool = True,\n        num_expert_group: int = 1,\n        topk_group: int = 1,\n        q_lora_rank: int | None = None,\n        kv_lora_rank: int | None = None,\n        qk_nope_head_dim: int | None = None,\n        qk_rope_head_dim: int | None = None,\n        v_head_dim: int | None = None,\n        mla_use_nope: bool | None = False,\n        num_nextn_predict_layers: int = 0,\n        linear_attn_config: dict | None = None,\n        **kwargs,\n    ):\n        self.model_type = model_type\n        self.vocab_size = vocab_size\n        
self.hidden_size = hidden_size\n        self.head_dim = (\n            head_dim if head_dim is not None else hidden_size // num_attention_heads\n        )\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n\n        self.q_lora_rank = q_lora_rank\n        self.kv_lora_rank = kv_lora_rank\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.mla_use_nope = mla_use_nope\n        # moe config\n        self.n_routed_experts = self.num_experts = num_experts\n        self.num_experts_per_token = num_experts_per_token\n        self.moe_renormalize = moe_renormalize\n        self.num_shared_experts = num_shared_experts\n        self.routed_scaling_factor = routed_scaling_factor\n        self.moe_router_activation_func = moe_router_activation_func\n        assert self.moe_router_activation_func in (\"softmax\", \"sigmoid\")\n        self.moe_intermediate_size = moe_intermediate_size\n        self.first_k_dense_replace = first_k_dense_replace\n        self.moe_layer_freq = moe_layer_freq\n        self.use_grouped_topk = use_grouped_topk\n        self.num_expert_group = num_expert_group\n        self.topk_group = topk_group\n        self.num_nextn_predict_layers = num_nextn_predict_layers\n\n        if linear_attn_config is not None:\n            assert linear_attn_config[\"kda_layers\"] is not None\n            assert 
linear_attn_config[\"full_attn_layers\"] is not None\n        self.linear_attn_config = linear_attn_config\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    @property\n    def is_mla(self):\n        return (\n            self.q_lora_rank is not None\n            or self.kv_lora_rank is not None\n            or self.qk_nope_head_dim is not None\n            or self.qk_rope_head_dim is not None\n            or self.v_head_dim is not None\n            or self.mla_use_nope is True\n        )\n\n    @property\n    def is_moe(self):\n        return self.num_experts is not None\n\n    @property\n    def is_linear_attn(self) -> bool:\n        return not (\n            self.linear_attn_config is None\n            or (\n                isinstance(self.linear_attn_config, dict)\n                and self.linear_attn_config[\"kda_layers\"] is not None\n                and len(self.linear_attn_config[\"kda_layers\"]) == 0\n            )\n        )\n\n    def is_kda_layer(self, layer_idx: int):\n        return (\n            self.linear_attn_config is not None\n            and (layer_idx + 1) in self.linear_attn_config[\"kda_layers\"]\n        )\n\n    @property\n    def linear_layer_ids(self):\n        return [i for i in range(self.num_hidden_layers) if self.is_kda_layer(i)]\n\n    @property\n    def full_attention_layer_ids(self):\n        return [i for i in range(self.num_hidden_layers) if not self.is_kda_layer(i)]\n\n    @property\n    def mamba2_cache_params(self) -> KimiLinearCacheParams:\n        from sglang.srt.layers.dp_attention import get_attention_tp_size\n\n        shape = KimiLinearStateShape.create(\n            tp_world_size=get_attention_tp_size(),\n            num_heads=self.linear_attn_config[\"num_heads\"],\n            head_dim=self.linear_attn_config[\"head_dim\"],\n  
          conv_kernel_size=self.linear_attn_config[\"short_conv_kernel_size\"],\n        )\n\n        return KimiLinearCacheParams(shape=shape, layers=self.linear_layer_ids)\n"
  },
  {
    "path": "python/sglang/srt/configs/kimi_vl.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py\nfrom typing import Optional, Union\n\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom sglang.srt.configs.deepseekvl2 import DeepseekV2Config\nfrom sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig\n\n\nclass KimiVLConfig(PretrainedConfig):\n    model_type = \"kimi_vl\"\n\n    def __init__(\n        self,\n        vision_config: Optional[Union[dict, MoonViTConfig]] = None,\n        text_config: Optional[Union[dict, DeepseekV2Config]] = None,\n        ignore_index: int = -100,\n        media_placeholder_token_id: int = 163605,\n        pad_token_id: int = 0,\n        **kwargs\n    ):\n        if vision_config is None:\n            vision_config = MoonViTConfig()\n        elif isinstance(vision_config, dict):\n            vision_config = MoonViTConfig(**vision_config)\n        self.vision_config = vision_config\n\n        if text_config is None:\n            text_config = DeepseekV2Config()\n        elif isinstance(text_config, dict):\n            text_config = DeepseekV2Config(**text_config)\n        self.text_config = text_config\n\n        self.ignore_index = ignore_index\n        self.media_placeholder_token_id = media_placeholder_token_id\n\n        super().__init__(pad_token_id=pad_token_id, **kwargs)\n"
  },
  {
    "path": "python/sglang/srt/configs/kimi_vl_moonvit.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py\nfrom transformers.configuration_utils import PretrainedConfig\n\n\nclass MoonViTConfig(PretrainedConfig):\n    model_type = \"moonvit\"\n\n    def __init__(\n        self,\n        patch_size: int = 14,\n        init_pos_emb_height: int = 64,\n        init_pos_emb_width: int = 64,\n        num_attention_heads: int = 16,\n        num_hidden_layers: int = 27,\n        hidden_size: int = 1152,\n        intermediate_size: int = 4304,\n        merge_kernel_size: tuple[int, int] = (2, 2),\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.patch_size = patch_size\n        # Positional embedding config\n        self.init_pos_emb_height = init_pos_emb_height\n        self.init_pos_emb_width = init_pos_emb_width\n        # Transformer config\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        # Patch merger config\n        self.merge_kernel_size = merge_kernel_size\n"
  },
  {
    "path": "python/sglang/srt/configs/lfm2.py",
    "content": "# coding=utf-8\n# Copyright 2024 Liquid AI and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"LFM2 (Liquid Foundation Model 2) configuration\"\"\"\n\nfrom typing import List, Optional\n\nfrom transformers import CONFIG_MAPPING\nfrom transformers import Lfm2Config as HFLfm2Config\nfrom transformers.utils import logging\n\nfrom sglang.srt.configs.mamba_utils import (\n    Mamba2CacheParams,\n    Mamba2StateShape,\n    mamba2_state_dtype,\n)\n\nlogger = logging.get_logger(__name__)\n\n\nclass Lfm2Config(HFLfm2Config):\n    \"\"\"\n    SGLang configuration for LFM2 models.\n\n    Extends HuggingFace's Lfm2Config with hybrid model properties needed by SGLang.\n    LFM2 uses a hybrid architecture mixing full attention and ShortConv layers.\n    \"\"\"\n\n    @property\n    def full_attention_layer_ids(self) -> List[int]:\n        \"\"\"Return indices of attention layers for KV cache.\"\"\"\n        return [i for i, lt in enumerate(self.layer_types) if lt == \"full_attention\"]\n\n    @property\n    def linear_layer_ids(self) -> List[int]:\n        \"\"\"Return indices of conv layers for conv state cache.\"\"\"\n        return [\n            i for i, lt in enumerate(self.layer_types) if lt in (\"conv\", \"short_conv\")\n        ]\n\n    @property\n    def mamba_chunk_size(self) -> int:\n        \"\"\"Return chunk size for Mamba2 backend. 
LFM2 doesn't use chunking, return 1.\"\"\"\n        return 1\n\n    @property\n    def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]:\n        \"\"\"\n        Get cache params for HybridReqToTokenPool initialization.\n\n        LFM2 uses ShortConv layers with a small fixed-size cache (kernel_size - 1).\n        Unlike full Mamba2 models, LFM2 only uses the conv state, not SSM temporal state.\n        \"\"\"\n        from sglang.srt.layers.dp_attention import get_attention_tp_size\n\n        conv_layer_ids = self.linear_layer_ids\n        if not conv_layer_ids:\n            return None\n\n        hidden_size = self.hidden_size\n        conv_kernel = int(self.conv_L_cache)\n\n        # get_attention_tp_size() requires initialization, default to 1 if not available\n        try:\n            tp_size = get_attention_tp_size()\n        except (AssertionError, RuntimeError):\n            tp_size = 1\n\n        # For ShortConv layers, we use a simplified Mamba2StateShape\n        # LFM2 doesn't use SSM state (state_size=0), only conv state\n        # We pass num_heads=tp_size so divide(tp_size, tp_size)=1 always works.\n        # Since state_size=0, the temporal state shape has zero elements anyway.\n        shape = Mamba2StateShape.create(\n            tp_world_size=tp_size,\n            intermediate_size=hidden_size,\n            n_groups=1,  # ShortConv doesn't use grouping\n            num_heads=tp_size,  # Ensures divide works; temporal state is empty anyway\n            head_dim=hidden_size,  # Conv operates on full hidden dim\n            state_size=0,  # No SSM temporal state for ShortConv\n            conv_kernel=conv_kernel,\n        )\n\n        return Mamba2CacheParams(\n            shape=shape,\n            layers=conv_layer_ids,\n            dtype=mamba2_state_dtype(self),\n        )\n\n\n# Override HuggingFace's Lfm2Config with our extended version\n# Cannot use .register() because lfm2 is already registered by transformers\n# Directly modify the 
internal _extra_content dict instead\nCONFIG_MAPPING._extra_content[\"lfm2\"] = Lfm2Config\n"
  },
  {
    "path": "python/sglang/srt/configs/lfm2_moe.py",
    "content": "# Copyright 2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"LFM2-MoE (Liquid Foundation Model 2 - Mixture of Experts) configuration\n\nNote: HF transformers has Lfm2MoeConfig in v5.0.0rc2 (unreleased).\nOnce released, we could inherit from it like Lfm2Config does with HFLfm2Config.\nFor now, we define a standalone config to support the model immediately.\n\"\"\"\n\nfrom typing import List, Optional\n\nfrom transformers import CONFIG_MAPPING\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape\n\n\nclass Lfm2MoeConfig(PretrainedConfig):\n    \"\"\"\n    Configuration for LFM2-MoE models (e.g., LiquidAI/LFM2-8B-A1B).\n\n    LFM2-MoE is a hybrid architecture with:\n    - Attention layers and ShortConv layers (like dense LFM2)\n    - MoE (Mixture of Experts) FFN layers with sigmoid routing\n\n    Key MoE specifics:\n    - First `num_dense_layers` use dense MLP, rest use MoE\n    - Sigmoid routing (not softmax) with expert_bias for load balancing\n    - expert_bias is fp32 for numerical stability\n    \"\"\"\n\n    model_type = \"lfm2_moe\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size: int = 65536,\n        hidden_size: int = 2048,\n        intermediate_size: int = 7168,\n        moe_intermediate_size: int = 1792,\n        num_hidden_layers: int = 32,\n        
num_attention_heads: int = 32,\n        num_key_value_heads: int = 8,\n        max_position_embeddings: int = 128000,\n        initializer_range: float = 0.02,\n        norm_eps: float = 1e-5,\n        use_cache: bool = True,\n        pad_token_id: int = 0,\n        bos_token_id: int = 1,\n        eos_token_id: int = 2,\n        tie_word_embeddings: bool = True,\n        rope_parameters: Optional[dict] = None,\n        conv_bias: bool = False,\n        conv_L_cache: int = 3,\n        # MoE-specific parameters\n        num_dense_layers: int = 2,\n        num_experts: int = 32,\n        num_experts_per_tok: int = 4,\n        use_expert_bias: bool = True,\n        routed_scaling_factor: float = 1.0,\n        norm_topk_prob: bool = True,\n        # Layer types\n        layer_types: Optional[List[str]] = None,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.num_key_value_heads = num_key_value_heads\n        self.max_position_embeddings = max_position_embeddings\n        self.initializer_range = initializer_range\n        self.norm_eps = norm_eps\n        self.use_cache = use_cache\n\n        # Conv parameters\n        self.conv_bias = conv_bias\n        self.conv_L_cache = conv_L_cache\n\n        # MoE parameters\n        self.num_dense_layers = num_dense_layers\n        self.num_experts = num_experts\n        self.num_experts_per_tok = num_experts_per_tok\n        self.use_expert_bias = use_expert_bias\n        self.routed_scaling_factor = routed_scaling_factor\n        self.norm_topk_prob = norm_topk_prob\n\n        # Layer types (attention vs conv)\n        self.layer_types = layer_types\n\n        # RoPE parameters\n        self.rope_parameters = rope_parameters\n\n        # 
Validate layer_types length matches num_hidden_layers\n        if layer_types is not None and len(layer_types) != num_hidden_layers:\n            raise ValueError(\n                f\"layer_types length ({len(layer_types)}) must match \"\n                f\"num_hidden_layers ({num_hidden_layers})\"\n            )\n\n        # Handle tie_embedding alias from original config\n        tie_word_embeddings = kwargs.pop(\"tie_embedding\", tie_word_embeddings)\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    @property\n    def full_attention_layer_ids(self) -> List[int]:\n        \"\"\"Return indices of attention layers for KV cache.\"\"\"\n        if self.layer_types is None:\n            return []\n        return [i for i, lt in enumerate(self.layer_types) if lt == \"full_attention\"]\n\n    @property\n    def linear_layer_ids(self) -> List[int]:\n        \"\"\"Return indices of conv layers for conv state cache.\"\"\"\n        if self.layer_types is None:\n            return []\n        return [\n            i for i, lt in enumerate(self.layer_types) if lt in (\"conv\", \"short_conv\")\n        ]\n\n    @property\n    def mamba_chunk_size(self) -> int:\n        \"\"\"Return chunk size for Mamba2 backend. 
LFM2 doesn't use chunking.\"\"\"\n        return 1\n\n    @property\n    def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]:\n        \"\"\"\n        Get cache params for HybridReqToTokenPool initialization.\n\n        LFM2-MoE uses ShortConv layers with a small fixed-size cache.\n        \"\"\"\n        from sglang.srt.layers.dp_attention import get_attention_tp_size\n\n        conv_layer_ids = self.linear_layer_ids\n        if not conv_layer_ids:\n            return None\n\n        hidden_size = self.hidden_size\n        # conv_L_cache in config is kernel_size (e.g., 3)\n        conv_kernel = int(self.conv_L_cache)\n        # actual cache size is kernel_size - 1 (e.g., 2 for kernel=3)\n\n        try:\n            tp_size = get_attention_tp_size()\n        except (AssertionError, RuntimeError):\n            tp_size = 1\n\n        shape = Mamba2StateShape.create(\n            tp_world_size=tp_size,\n            intermediate_size=hidden_size,\n            n_groups=1,\n            num_heads=tp_size,  # Ensures divide works; temporal state is empty anyway\n            head_dim=hidden_size,\n            state_size=0,\n            conv_kernel=conv_kernel,\n        )\n\n        # Uses default mamba2_state_dtype() which reads SGLANG_MAMBA_CONV_DTYPE env var\n        # (defaults to bfloat16). Set SGLANG_MAMBA_CONV_DTYPE=float16 for fp16 inference.\n        return Mamba2CacheParams(\n            shape=shape,\n            layers=conv_layer_ids,\n        )\n\n\n# Register with transformers CONFIG_MAPPING so AutoConfig.from_pretrained()\n# can instantiate our config class when loading models with model_type=\"lfm2_moe\"\ntry:\n    CONFIG_MAPPING.register(\"lfm2_moe\", Lfm2MoeConfig)\nexcept Exception:\n    # Already registered or registration failed - use direct assignment\n    CONFIG_MAPPING._extra_content[\"lfm2_moe\"] = Lfm2MoeConfig\n"
  },
  {
    "path": "python/sglang/srt/configs/load_config.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py\nimport enum\nimport logging\nfrom dataclasses import dataclass, field\nfrom typing import Any, List, Optional, Union\n\nimport orjson\n\nfrom sglang.srt.configs.modelopt_config import ModelOptConfig\nfrom sglang.srt.utils import is_hip\n\nlogger = logging.getLogger(__name__)\n\n\nclass LoadFormat(str, enum.Enum):\n    AUTO = \"auto\"\n    PT = \"pt\"\n    SAFETENSORS = \"safetensors\"\n    NPCACHE = \"npcache\"\n    DUMMY = \"dummy\"\n    SHARDED_STATE = \"sharded_state\"\n    GGUF = \"gguf\"\n    BITSANDBYTES = \"bitsandbytes\"\n    MISTRAL = \"mistral\"\n    LAYERED = \"layered\"\n    FLASH_RL = \"flash_rl\"  # For RL training with quantized models\n    JAX = \"jax\"\n    REMOTE = \"remote\"\n    REMOTE_INSTANCE = \"remote_instance\"\n    RDMA = \"rdma\"\n    LOCAL_CACHED = \"local_cached\"\n    FASTSAFETENSORS = \"fastsafetensors\"\n    PRIVATE = \"private\"\n\n\n@dataclass\nclass LoadConfig:\n    \"\"\"\n    download_dir: Directory to download and load the weights; defaults to the\n        default cache directory of Hugging Face.\n    load_format: The format of the model weights to load:\n        \"auto\" will try to load the weights in the safetensors format and\n            fall back to the pytorch bin format if safetensors format is\n            not available.\n        \"pt\" will load the weights in the pytorch bin format.\n        \"safetensors\" will load the weights in the safetensors format.\n        \"npcache\" will load the weights in pytorch format and store\n            a numpy cache to speed up the loading.\n        \"dummy\" will initialize the weights with random values, which is\n            mainly for profiling.\n        \"bitsandbytes\" will load nf4 type weights.\n        \"flash_rl\" will load weights with support for RL training\n            with quantized models, enabling efficient weight reloading.\n    ignore_patterns: The list of patterns to ignore when loading the model.\n        Defaults to \"original/**/*\" to avoid repeated loading of llama's\n        checkpoints.\n    decryption_key_file: If set, decrypts the output files with a password read\n        from this file (after PBKDF2).\n    decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.\n    modelopt_checkpoint_restore_path, modelopt_checkpoint_save_path,\n        modelopt_export_path: ModelOpt-specific loading options, used to build\n        modelopt_config when it is not provided.\n    \"\"\"\n\n    load_format: Union[str, LoadFormat] = LoadFormat.AUTO\n    download_dir: Optional[str] = None\n    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)\n    ignore_patterns: Optional[Union[List[str], str]] = None\n    decryption_key_file: Optional[str] = None\n    decrypt_max_concurrency: int = -1\n    tp_rank: Optional[int] = None\n    remote_instance_weight_loader_seed_instance_ip: Optional[str] = None\n    remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None\n    remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None\n    remote_instance_weight_loader_backend: Optional[str] = None\n    remote_instance_weight_loader_transfer_engine: Optional[Any] = None\n    modelexpress_url: Optional[str] = None\n    modelexpress_model_name: Optional[str] = None\n\n    # ModelOpt-specific loading options\n    modelopt_checkpoint_restore_path: Optional[str] = None\n    modelopt_checkpoint_save_path: Optional[str] = None\n    modelopt_export_path: Optional[str] = None\n\n    # ModelOpt configuration object\n    modelopt_config: Optional[ModelOptConfig] = None\n\n    # QuantizedRL-specific options (for FlashRL-style quantization)\n    rl_quant_profile: Optional[str] = (\n        None  # Path to rollout quantization profile (e.g., /root/profile.7b.pt)\n    )\n\n    # For multi-layer MTP\n    draft_model_idx: Optional[int] = None\n\n    def __post_init__(self):\n        model_loader_extra_config = self.model_loader_extra_config or {}\n        if isinstance(model_loader_extra_config, str):\n            self.model_loader_extra_config = orjson.loads(model_loader_extra_config)\n        self._verify_load_format()\n\n        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:\n            logger.info(\n                \"Ignoring the following patterns when downloading weights: %s\",\n                self.ignore_patterns,\n            )\n        else:\n            self.ignore_patterns = [\"original/**/*\"]\n\n        # Create ModelOptConfig if not provided\n        if self.modelopt_config is None:\n            self.modelopt_config = ModelOptConfig(\n                checkpoint_restore_path=self.modelopt_checkpoint_restore_path,\n                checkpoint_save_path=self.modelopt_checkpoint_save_path,\n                export_path=self.modelopt_export_path,\n            )\n\n    def _verify_load_format(self) -> None:\n        if not isinstance(self.load_format, str):\n            return\n\n        load_format = self.load_format.lower()\n        self.load_format = LoadFormat(load_format)\n\n        rocm_not_supported_load_format: List[str] = []\n        if is_hip() and load_format in rocm_not_supported_load_format:\n            rocm_supported_load_format = [\n                f\n                for f in LoadFormat.__members__\n                if (f not in rocm_not_supported_load_format)\n            ]\n            raise ValueError(\n                f\"load format '{load_format}' is not supported in ROCm. \"\n                f\"Supported load formats are \"\n                f\"{rocm_supported_load_format}\"\n            )\n"
  },
  {
    "path": "python/sglang/srt/configs/longcat_flash.py",
    "content": "from transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nFLASH_PRETRAINED_CONFIG_ARCHIVE_MAP = {}\n\n\nclass LongcatFlashConfig(PretrainedConfig):\n    model_type = \"longcat_flash\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=131072,\n        hidden_size=6144,\n        intermediate_size=None,\n        ffn_hidden_size=12288,\n        expert_ffn_hidden_size=2048,\n        num_layers=28,\n        num_hidden_layers=None,\n        num_attention_heads=64,\n        ep_size=1,\n        kv_lora_rank=512,\n        q_lora_rank=1536,\n        qk_rope_head_dim=128,\n        qk_nope_head_dim=128,\n        v_head_dim=128,\n        n_routed_experts=512,\n        moe_topk=12,\n        norm_topk_prob=False,\n        max_position_embeddings=131072,\n        rms_norm_eps=1e-05,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=1,\n        eos_token_id=2,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        mla_scale_q_lora=True,\n        mla_scale_kv_lora=True,\n        torch_dtype=\"bfloat16\",\n        params_dtype=\"bfloat16\",\n        rounter_params_dtype=\"float32\",\n        router_bias=False,\n        topk_method=None,\n        routed_scaling_factor=6.0,\n        zero_expert_num=256,\n        zero_expert_type=\"identity\",\n        nextn_use_scmoe=False,\n        num_nextn_predict_layers=1,\n        ngram_vocab_size_ratio=None,\n        emb_neighbor_num=None,\n        emb_split_num=None,\n        **kwargs,\n    ):\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            
torch_dtype=torch_dtype,\n            params_dtype=params_dtype,\n            rounter_params_dtype=rounter_params_dtype,\n            topk_method=topk_method,\n            router_bias=router_bias,\n            nextn_use_scmoe=nextn_use_scmoe,\n            num_nextn_predict_layers=num_nextn_predict_layers,\n            **kwargs,\n        )\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.num_hidden_layers = (\n            num_hidden_layers if num_hidden_layers is not None else num_layers\n        )\n        self.intermediate_size = (\n            intermediate_size if intermediate_size is not None else ffn_hidden_size\n        )\n        self.moe_intermediate_size = expert_ffn_hidden_size\n        self.num_attention_heads = num_attention_heads\n        self.ep_size = ep_size\n        self.kv_lora_rank = kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.n_routed_experts = n_routed_experts\n        self.moe_topk = moe_topk\n        self.norm_topk_prob = norm_topk_prob\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        self.mla_scale_q_lora = mla_scale_q_lora\n        self.mla_scale_kv_lora = mla_scale_kv_lora\n        self.zero_expert_num = zero_expert_num\n        self.zero_expert_type = zero_expert_type\n        self.routed_scaling_factor = routed_scaling_factor\n        self.hidden_act = \"silu\"\n        self.use_ngram_embedding = ngram_vocab_size_ratio is not None\n        if self.use_ngram_embedding:\n            self.ngram_embedding_m = 
int(ngram_vocab_size_ratio * vocab_size)\n            self.ngram_embedding_n = emb_neighbor_num\n            self.ngram_embedding_k = emb_split_num\n"
  },
  {
    "path": "python/sglang/srt/configs/mamba_utils.py",
    "content": "# Copyright 2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, LFM2, etc.\"\"\"\n\nimport logging\nfrom abc import ABC\nfrom dataclasses import dataclass, field\nfrom typing import List, Optional\n\nimport numpy as np\nimport torch\n\nfrom sglang.srt.distributed.utils import divide\nfrom sglang.srt.environ import envs\n\nlogger = logging.getLogger(__name__)\n\n\ndef extra_groups_for_head_shards(ngroups: int, tp_size: int):\n    \"\"\"Compute the increase in group numbers to account for\n    replication in order to accompany the head shards.\"\"\"\n\n    # in the case ngroups % tp_size == 0, this will be zero\n    if ngroups % tp_size == 0:\n        return 0\n\n    # for ngroups == 1, this is exactly tp_size - ngroups\n    return tp_size - ngroups\n\n\n@dataclass(kw_only=True, frozen=True)\nclass Mamba2StateDType:\n    conv: torch.dtype\n    temporal: torch.dtype\n\n\ndef mamba2_state_dtype(config=None) -> Mamba2StateDType:\n    \"\"\"\n    Get mamba2 state dtype from config or environment variable.\n\n    Priority (from highest to lowest):\n    1. Environment variable SGLANG_MAMBA_SSM_DTYPE\n    2. Config file (config.mamba_ssm_dtype or config.text_config.mamba_ssm_dtype)\n    3. Default \"float32\"\n\n    Args:\n        config: Optional config object (PretrainedConfig). If provided, will read\n                mamba_ssm_dtype from it. For VL models, reads from text_config.\n\n    Returns:\n        Mamba2StateDType with conv and temporal dtypes\n    \"\"\"\n    dtype_map = {\n        \"float32\": torch.float32,\n        \"bfloat16\": torch.bfloat16,\n        \"float16\": torch.float16,\n    }\n    conv_dtype = dtype_map.get(envs.SGLANG_MAMBA_CONV_DTYPE.get(), torch.bfloat16)\n\n    # Get SSM dtype: default -> config -> env var\n    ssm_dtype = torch.float32  # Step 1: Default value\n\n    # Step 2: Try to read from config\n    if config is not None:\n        config_dtype = None\n        if hasattr(config, \"text_config\") and hasattr(\n            config.text_config, \"mamba_ssm_dtype\"\n        ):\n            # VL model: read from text_config\n            config_dtype = config.text_config.mamba_ssm_dtype\n        elif hasattr(config, \"mamba_ssm_dtype\"):\n            # Text model: read from root config\n            config_dtype = config.mamba_ssm_dtype\n\n        if config_dtype is not None:\n            if config_dtype not in dtype_map:\n                logger.warning(\n                    f\"Invalid mamba_ssm_dtype '{config_dtype}' in config. \"\n                    f\"Must be one of {list(dtype_map.keys())}. Using default 'float32'.\"\n                )\n            else:\n                ssm_dtype = dtype_map[config_dtype]\n\n    # Step 3: Check environment variable, if not None, override\n    env_ssm_dtype = envs.SGLANG_MAMBA_SSM_DTYPE.get()\n    if env_ssm_dtype is not None:\n        if env_ssm_dtype not in dtype_map:\n            logger.warning(\n                f\"Invalid mamba_ssm_dtype '{env_ssm_dtype}' from environment variable. \"\n                f\"Must be one of {list(dtype_map.keys())}. Keeping {ssm_dtype}.\"\n            )\n        else:\n            ssm_dtype = dtype_map[env_ssm_dtype]\n\n    logger.debug(f\"Mamba2 state dtype: conv_dtype={conv_dtype}, ssm_dtype={ssm_dtype}\")\n\n    return Mamba2StateDType(conv=conv_dtype, temporal=ssm_dtype)\n\n\n@dataclass(kw_only=True, frozen=True)\nclass BaseLinearStateParams(ABC):\n    dtype: Mamba2StateDType = field(default_factory=lambda: mamba2_state_dtype(None))\n    layers: list[int]\n\n    @property\n    def mamba_cache_per_req(self) -> int:\n        conv_numel = int(\n            np.sum([np.prod(conv_shape) for conv_shape in self.shape.conv])\n        )\n\n        ssm_numel = int(np.prod(self.shape.temporal))\n        return (\n            conv_numel * self.dtype.conv.itemsize\n            + ssm_numel * self.dtype.temporal.itemsize\n        ) * len(self.layers)\n\n\n@dataclass(kw_only=True, frozen=True)\nclass Mamba2StateShape:\n    conv: list[tuple[int, int]]\n    temporal: tuple[int, int, int]\n\n    intermediate_size: int\n    conv_dim: int\n    ssm_state_size: int\n    num_heads: int\n    head_dim: int\n    state_size: int\n    conv_kernel: int\n\n    @staticmethod\n    def create(\n        *,\n        tp_world_size: int,\n        intermediate_size: int,\n        n_groups: int,\n        num_heads: int,\n        head_dim: int,\n        state_size: int,\n        conv_kernel: int,\n    ) -> \"Mamba2StateShape\":\n        # if n_groups is not divisible by world_size, need to extend the shards\n        # to ensure all groups needed by a head are sharded along with it\n        if n_groups % tp_world_size != 0:\n            extra_groups = extra_groups_for_head_shards(n_groups, tp_world_size)\n            n_groups += extra_groups\n        # heads and n_groups are TP-ed\n        conv_dim = intermediate_size + 2 * n_groups * state_size\n\n        # contiguous along 'dim' axis\n        conv_state_shape = divide(conv_dim, tp_world_size), conv_kernel - 1\n\n        # These are not TP-ed as they depend on A, dt_bias, D\n        # - they are typically small\n        #   e.g., Qwen3-Next: (32, 128, 128)\n        temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)\n        return Mamba2StateShape(\n            conv=[conv_state_shape],\n            temporal=temporal_state_shape,\n            intermediate_size=intermediate_size,\n            conv_dim=conv_dim,\n            ssm_state_size=state_size,\n            num_heads=num_heads,\n            head_dim=head_dim,\n            state_size=state_size,\n            conv_kernel=conv_kernel,\n        )\n\n\n@dataclass(kw_only=True, frozen=True)\nclass Mamba2CacheParams(BaseLinearStateParams):\n    shape: Mamba2StateShape\n\n\n@dataclass(kw_only=True, frozen=True)\nclass KimiLinearStateShape:\n    conv: List[tuple[int, int]]\n    temporal: tuple[int, int, int]\n\n    num_heads: int\n    head_dim: int\n    num_k_heads: int\n    head_k_dim: int\n    conv_kernel: int\n    num_spec: int\n\n    @staticmethod\n    def create(\n        *,\n        tp_world_size: int,\n        num_heads: int,\n        head_dim: int,\n        num_k_heads: Optional[int] = None,\n        head_k_dim: Optional[int] = None,\n        conv_kernel_size: int = 4,\n        num_spec: int = 0,\n    ) -> \"KimiLinearStateShape\":\n        if num_k_heads is None:\n            num_k_heads = num_heads\n        if head_k_dim is None:\n            head_k_dim = head_dim\n\n        proj_size = num_heads * head_dim\n        proj_k_size = num_k_heads * head_k_dim\n\n        conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)\n        conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)\n        temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)\n\n        conv_state_shape = (\n            conv_state_shape[1],\n            conv_state_shape[0] + conv_state_k_shape[0] * 2,\n        )\n\n        return KimiLinearStateShape(\n            conv=[conv_state_shape],\n            temporal=temporal_state_shape,\n            num_heads=num_heads,\n            head_dim=head_dim,\n            num_k_heads=num_k_heads,\n            head_k_dim=head_k_dim,\n            conv_kernel=conv_kernel_size,\n            num_spec=num_spec,\n        )\n\n\n@dataclass(kw_only=True, frozen=True)\nclass KimiLinearCacheParams(BaseLinearStateParams):\n    shape: KimiLinearStateShape\n"
  },
  {
    "path": "python/sglang/srt/configs/model_config.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\nimport json\nimport logging\nimport math\nimport os\nfrom enum import Enum, IntEnum, auto\nfrom pathlib import Path\nfrom typing import Any, List, Optional, Set, Union\n\nimport torch\nfrom transformers import PretrainedConfig\n\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.quantization import QUANTIZATION_METHODS\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils import is_hip, is_sm100_supported, retry\nfrom sglang.srt.utils.hf_transformers_utils import (\n    get_config,\n    get_context_length,\n    get_generation_config,\n    get_hf_text_config,\n    get_sparse_attention_config,\n)\nfrom sglang.utils import is_in_ci\n\nlogger = logging.getLogger(__name__)\n\n\nclass AttentionArch(IntEnum):\n    MLA = auto()\n    MHA = auto()\n\n\nclass ModelImpl(str, Enum):\n    AUTO = \"auto\"\n    SGLANG = \"sglang\"\n    TRANSFORMERS = \"transformers\"\n    MINDSPORE = \"mindspore\"\n\n\ndef is_deepseek_nsa(config) -> bool:\n    architectures = (\n        config.get(\"architectures\")\n        if isinstance(config, dict)\n        else getattr(config, \"architectures\", None)\n    )\n    index_topk = (\n        config.get(\"index_topk\")\n        if isinstance(config, dict)\n        else getattr(config, \"index_topk\", None)\n    )\n    return (\n        
architectures is not None\n        and architectures[0]\n        in [\n            \"DeepseekV3ForCausalLM\",\n            \"DeepseekV32ForCausalLM\",\n            \"DeepseekV3ForCausalLMNextN\",\n            \"MistralLarge3ForCausalLM\",\n            \"PixtralForConditionalGeneration\",\n            \"GlmMoeDsaForCausalLM\",\n        ]\n        and index_topk is not None\n    )\n\n\ndef get_nsa_index_head_dim(config: PretrainedConfig) -> int:\n    assert is_deepseek_nsa(config)\n    return config.index_head_dim\n\n\ndef get_nsa_index_topk(config: PretrainedConfig) -> int:\n    assert is_deepseek_nsa(config)\n    return config.index_topk\n\n\ndef get_nsa_index_n_heads(config: PretrainedConfig) -> int:\n    assert is_deepseek_nsa(config)\n    return config.index_n_heads\n\n\nclass ModelConfig:\n    def __init__(\n        self,\n        model_path: str,\n        trust_remote_code: bool = True,\n        revision: Optional[str] = None,\n        context_length: Optional[int] = None,\n        model_override_args: str = \"{}\",\n        is_embedding: Optional[bool] = None,\n        enable_multimodal: Optional[bool] = None,\n        dtype: str = \"auto\",\n        quantization: Optional[str] = None,\n        override_config_file: Optional[str] = None,\n        is_draft_model: bool = False,\n        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,\n        sampling_defaults: str = \"openai\",\n        quantize_and_serve: bool = False,\n        is_multi_layer_eagle: bool = False,\n        encoder_only: bool = False,\n        language_only: bool = False,\n        disable_hybrid_swa_memory: bool = False,\n    ) -> None:\n        # Parse args\n        self.model_path = model_path\n        self.revision = revision\n        self.quantization = quantization\n        self.is_draft_model = is_draft_model\n        self.model_impl = model_impl\n        self.sampling_defaults = sampling_defaults\n        self.quantize_and_serve = quantize_and_serve\n        self.is_multi_layer_eagle 
= is_multi_layer_eagle\n        self.disable_hybrid_swa_memory = disable_hybrid_swa_memory\n\n        # Validate quantize_and_serve configuration\n        self._validate_quantize_and_serve_config()\n\n        # Get hf config\n        self._maybe_pull_model_tokenizer_from_remote()\n        self.model_override_args = json.loads(model_override_args)\n        kwargs = {}\n        if override_config_file and override_config_file.strip():\n            kwargs[\"_configuration_file\"] = override_config_file.strip()\n        self.hf_config = get_config(\n            self.model_path,\n            trust_remote_code=trust_remote_code,\n            revision=revision,\n            model_override_args=self.model_override_args,\n            **kwargs,\n        )\n        self.hf_text_config = get_hf_text_config(self.hf_config)\n        self.hf_generation_config = get_generation_config(\n            self.model_path,\n            trust_remote_code=trust_remote_code,\n            revision=revision,\n            **kwargs,\n        )\n\n        # Set enable_multimodal\n        if enable_multimodal is None:\n            mm_disabled_models = [\n                \"Gemma3ForConditionalGeneration\",\n                \"Llama4ForConditionalGeneration\",\n                \"Step3VLForConditionalGeneration\",\n            ]\n            if self.hf_config.architectures[0] in mm_disabled_models:\n                enable_multimodal = False\n                logger.info(\n                    f\"Multimodal is disabled for {self.hf_config.model_type}. 
To enable it, set --enable-multimodal.\"\n                )\n            else:\n                enable_multimodal = True\n\n        # Config draft model\n        self._config_draft_model()\n\n        # Check model type\n        self.attention_chunk_size = getattr(\n            self.hf_text_config, \"attention_chunk_size\", None\n        )\n        self.sliding_window_size = self._get_sliding_window_size()\n        self.is_generation = is_generation_model(\n            self.hf_config.architectures, is_embedding\n        )\n        self.is_multimodal = enable_multimodal and is_multimodal_model(\n            self.hf_config.architectures\n        )\n        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(\n            self.hf_config.architectures\n        )\n        self.is_image_gen = enable_multimodal and is_image_gen_model(\n            self.hf_config.architectures\n        )\n        self.is_audio_model = enable_multimodal and is_audio_model(\n            self.hf_config.architectures\n        )\n        # TODO: requires further polishing\n        self.is_image_understandable_model = enable_multimodal and hasattr(\n            self.hf_config, \"vision_config\"\n        )\n        self.is_audio_understandable_model = enable_multimodal and hasattr(\n            self.hf_config, \"audio_config\"\n        )\n\n        self.is_multimodal_chunked_prefill_supported = (\n            enable_multimodal\n            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)\n        )\n        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)\n        self.is_local_attention_model = is_local_attention_model(\n            self.hf_config.architectures\n        )\n        self.use_ngram_embedding = getattr(self.hf_config, \"use_ngram_embedding\", False)\n        self.is_piecewise_cuda_graph_disabled_model = (\n            is_piecewise_cuda_graph_disabled_model(self.hf_config.architectures)\n            or 
is_deepseek_nsa(self.hf_text_config)\n        )\n        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)\n\n        # Derive context length and model shapes\n        self._derive_context_length(context_length)\n        self._derive_model_shapes()\n\n        # Update hybrid model\n        self._derive_hybrid_model()\n\n        # Verify quantization\n        self._verify_quantization()\n\n        self._verify_transformers_version()\n\n        # Verify dual-chunk attention config\n        self._verify_dual_chunk_attention_config()\n\n        # Cache attributes\n        self.hf_eos_token_id = self._get_hf_eos_token_id()\n\n        # multimodal\n        self.image_token_id = getattr(\n            self.hf_config, \"image_token_id\", None\n        ) or getattr(self.hf_config, \"image_token_index\", None)\n\n        self.hf_config.encoder_only = encoder_only\n        self.hf_config.language_only = language_only\n\n        # matryoshka embeddings\n        self.matryoshka_dimensions = getattr(\n            self.hf_config, \"matryoshka_dimensions\", None\n        )\n        self.is_matryoshka = self.matryoshka_dimensions or getattr(\n            self.hf_config, \"is_matryoshka\", False\n        )\n\n    @staticmethod\n    def from_server_args(\n        server_args: ServerArgs,\n        model_path: str = None,\n        model_revision: str = None,\n        is_draft_model: bool = False,\n        **kwargs,\n    ):\n        quantization = (\n            server_args.speculative_draft_model_quantization\n            if is_draft_model\n            else server_args.quantization\n        )\n        override_config_file = (\n            server_args.decrypted_draft_config_file\n            if is_draft_model\n            else server_args.decrypted_config_file\n        )\n        return ModelConfig(\n            model_path=model_path or server_args.model_path,\n            trust_remote_code=server_args.trust_remote_code,\n            revision=model_revision or 
server_args.revision,\n            context_length=server_args.context_length,\n            model_override_args=server_args.json_model_override_args,\n            is_embedding=server_args.is_embedding,\n            enable_multimodal=server_args.enable_multimodal,\n            dtype=server_args.dtype,\n            quantization=quantization,\n            model_impl=server_args.model_impl,\n            sampling_defaults=server_args.sampling_defaults,\n            quantize_and_serve=server_args.quantize_and_serve,\n            override_config_file=override_config_file,\n            is_multi_layer_eagle=server_args.enable_multi_layer_eagle,\n            language_only=server_args.language_only,\n            encoder_only=server_args.encoder_only,\n            is_draft_model=is_draft_model,\n            disable_hybrid_swa_memory=server_args.disable_hybrid_swa_memory,\n            **kwargs,\n        )\n\n    def _config_draft_model(self):\n        is_draft_model = self.is_draft_model\n\n        if is_draft_model and self.hf_config.architectures[0] in [\n            \"DeepseekV3ForCausalLM\",\n            \"GlmMoeDsaForCausalLM\",\n        ]:\n            self.hf_config.architectures[0] = \"DeepseekV3ForCausalLMNextN\"\n\n        if is_draft_model and self.hf_config.architectures[0] in [\n            \"Glm4MoeForCausalLM\",\n            \"Glm4MoeLiteForCausalLM\",\n        ]:\n            self.hf_config.architectures[0] = \"Glm4MoeForCausalLMNextN\"\n\n        if is_draft_model and self.hf_config.architectures[0] in [\n            \"GlmOcrForConditionalGeneration\",\n        ]:\n            self.hf_config.architectures[0] = \"GlmOcrForConditionalGenerationNextN\"\n\n        if (\n            is_draft_model\n            and self.hf_config.architectures[0] == \"LongcatFlashForCausalLM\"\n        ):\n            self.hf_config.architectures[0] = \"LongcatFlashForCausalLMNextN\"\n            self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers\n\n        if 
is_draft_model and self.hf_config.architectures[0] == \"MiMoForCausalLM\":\n            self.hf_config.architectures[0] = \"MiMoMTP\"\n        if (\n            is_draft_model\n            and self.hf_config.architectures[0] == \"MiMoV2FlashForCausalLM\"\n        ):\n            self.hf_config.architectures[0] = \"MiMoV2MTP\"\n        if is_draft_model and self.hf_config.architectures[0] == \"Step3p5ForCausalLM\":\n            self.hf_config.architectures[0] = \"Step3p5MTP\"\n        if is_draft_model and self.hf_config.architectures[0] in [\n            \"BailingMoeV2ForCausalLM\",\n            \"BailingMoeForCausalLM\",\n            \"BailingMoeV2_5ForCausalLM\",\n        ]:\n            self.hf_config.architectures[0] = \"BailingMoeForCausalLMNextN\"\n        if (\n            is_draft_model\n            and self.hf_config.architectures[0] == \"Ernie4_5_MoeForCausalLM\"\n        ):\n            self.hf_config.architectures[0] = \"Ernie4_5_MoeForCausalLMMTP\"\n\n        if is_draft_model and self.hf_config.architectures[0] == \"Qwen3NextForCausalLM\":\n            self.hf_config.architectures[0] = \"Qwen3NextForCausalLMMTP\"\n            self.hf_config.num_nextn_predict_layers = 1\n\n        if is_draft_model and self.hf_config.architectures[0] in [\n            \"Qwen3_5ForConditionalGeneration\",\n            \"Qwen3_5MoeForConditionalGeneration\",\n        ]:\n            self.hf_config.architectures[0] = \"Qwen3_5ForCausalLMMTP\"\n            self.hf_config.num_nextn_predict_layers = 1\n\n        if is_draft_model and self.hf_config.architectures[0] == \"ExaoneMoEForCausalLM\":\n            self.hf_config.architectures[0] = \"ExaoneMoEForCausalLMMTP\"\n            self.hf_config.num_nextn_predict_layers = 1\n\n        if is_draft_model and self.hf_config.architectures[0] == \"NemotronHForCausalLM\":\n            self.hf_config.architectures[0] = \"NemotronHForCausalLMMTP\"\n            self.hf_config.num_nextn_predict_layers = 1\n\n    def 
_derive_hybrid_model(self):\n        # Use self.context_len after it has been initialized to prevent using context_len which may be None.\n        self.is_hybrid_swa = (\n            is_hybrid_swa_model(self.hf_config.architectures)\n            and not self.disable_hybrid_swa_memory\n        )\n\n        if self.is_hybrid_swa:\n            self.swa_attention_layer_ids, self.full_attention_layer_ids = (\n                get_hybrid_layer_ids(\n                    self.hf_config.architectures,\n                    self.hf_text_config,\n                )\n            )\n\n        self.is_hybrid_swa_compress = self.hf_config.architectures[0] in [\n            \"MiMoV2FlashForCausalLM\",\n            \"MiMoV2MTP\",\n        ]\n\n    def _derive_context_length(self, context_length: int):\n        is_draft_model = self.is_draft_model\n        derived_context_len = get_context_length(self.hf_text_config)\n\n        if context_length is not None:\n            if context_length > derived_context_len:\n                reason = \"Target model's\" if is_draft_model else \"User-specified\"\n                msg = (\n                    f\"Warning: {reason} context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). \"\n                    f\"This may lead to incorrect model outputs or CUDA errors. 
Note that the derived context_length may differ from max_position_embeddings in the model's config.\"\n                )\n                if (\n                    envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get()\n                    or is_in_ci()  # FIXME: fix this special case\n                ):\n                    logger.warning(msg)\n                    self.context_len = context_length\n                    if is_draft_model:\n                        self.hf_text_config.max_position_embeddings = context_length\n                        logger.warning(\n                            f\"Overriding the draft model's max_position_embeddings to {context_length}.\"\n                        )\n                else:\n                    raise ValueError(\n                        f\"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1\"\n                    )\n            else:\n                self.context_len = context_length\n        else:\n            self.context_len = derived_context_len\n\n        # Transfer context_len to HuggingFace config so models can access it\n        self.hf_config.context_len = self.context_len\n\n    def _derive_model_shapes(self):\n        # Unify the config keys for hf_text_config\n        self.head_dim = getattr(\n            self.hf_text_config,\n            \"head_dim\",\n            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,\n        )\n        self.v_head_dim = getattr(\n            self.hf_text_config,\n            \"v_head_dim\",\n            self.head_dim,\n        )\n\n        self.swa_head_dim = getattr(\n            self.hf_text_config,\n            \"swa_head_dim\",\n            self.head_dim,\n        )\n        self.swa_v_head_dim = getattr(\n            self.hf_text_config,\n            \"swa_v_head_dim\",\n            self.v_head_dim,\n        )\n        # FIXME: temporary special judge for MLA architecture\n        if (\n            
\"DeepseekV2ForCausalLM\" in self.hf_config.architectures\n            or \"DeepseekV32ForCausalLM\" in self.hf_config.architectures\n            or \"DeepseekV3ForCausalLM\" in self.hf_config.architectures\n            or \"DeepseekV3ForCausalLMNextN\" in self.hf_config.architectures\n            or \"Glm4MoeLiteForCausalLM\" in self.hf_config.architectures\n            or \"GlmMoeDsaForCausalLM\" in self.hf_config.architectures\n            or \"LongcatFlashForCausalLM\" in self.hf_config.architectures\n            or \"LongcatFlashForCausalLMNextN\" in self.hf_config.architectures\n            or \"DotsVLMForCausalLM\" in self.hf_config.architectures\n            or \"MistralLarge3ForCausalLM\" in self.hf_config.architectures\n            or \"PixtralForConditionalGeneration\" in self.hf_config.architectures\n            or \"MistralLarge3ForCausalLMEagle\" in self.hf_config.architectures\n            or \"KimiK25ForConditionalGeneration\" in self.hf_config.architectures\n        ):\n            self.head_dim = 256\n            self.attention_arch = AttentionArch.MLA\n            self.kv_lora_rank = self.hf_text_config.kv_lora_rank\n            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim\n            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim\n            self.v_head_dim = self.hf_text_config.v_head_dim\n            self.index_head_dim = (\n                get_nsa_index_head_dim(self.hf_text_config)\n                if is_deepseek_nsa(self.hf_text_config)\n                else None\n            )\n            # Handle rope scaling\n            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)\n            # in transformers v5, rope_scaling is just rope_parameters for backward compatibility\n            rope_scaling = self.hf_text_config.rope_scaling\n            if rope_scaling:\n                # v5 uses \"rope_type\", v4 uses \"type\"\n                rope_type = (\n                    
rope_scaling.get(\"rope_type\")\n                    or rope_scaling.get(\"type\")\n                    or \"default\"\n                )\n                if rope_type != \"default\":\n                    self.scaling = compute_mla_mscale_scaling(\n                        rope_scaling, self.scaling\n                    )\n        elif \"MiniCPM3ForCausalLM\" in self.hf_config.architectures:\n            self.head_dim = 128\n            self.attention_arch = AttentionArch.MLA\n            self.kv_lora_rank = self.hf_config.kv_lora_rank\n            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim\n        elif \"DeepseekVL2ForCausalLM\" in self.hf_config.architectures and getattr(\n            self.hf_text_config, \"use_mla\", True\n        ):\n            self.head_dim = 256\n            self.attention_arch = AttentionArch.MLA\n            self.kv_lora_rank = self.hf_text_config.kv_lora_rank\n            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim\n            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim\n        elif \"KimiVLForConditionalGeneration\" in self.hf_config.architectures:\n            self.head_dim = 256\n            self.attention_arch = AttentionArch.MLA\n            self.kv_lora_rank = self.hf_text_config.kv_lora_rank\n            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim\n            self.v_head_dim = self.hf_text_config.v_head_dim\n            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim\n        elif \"KimiLinearForCausalLM\" in self.hf_config.architectures:\n            self.head_dim = 72\n            self.attention_arch = AttentionArch.MLA\n            self.kv_lora_rank = self.hf_config.kv_lora_rank\n            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim\n            self.v_head_dim = self.hf_config.v_head_dim\n            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim\n        elif (\n            \"BailingMoeV2_5ForCausalLM\" in 
self.hf_config.architectures\n            or \"BailingMoeForCausalLMNextN\" in self.hf_config.architectures\n        ):\n            self.head_dim = self.hf_text_config.head_dim\n            self.attention_arch = AttentionArch.MLA\n            self.kv_lora_rank = self.hf_text_config.kv_lora_rank\n            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim\n            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim\n            self.v_head_dim = self.hf_config.v_head_dim\n            # Handle rope scaling with yarn\n            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)\n            if self.hf_config.rope_scaling:\n                self.scaling = compute_mla_mscale_scaling(\n                    self.hf_config.rope_scaling, self.scaling\n                )\n        elif \"SarvamMLAForCausalLM\" in self.hf_config.architectures:\n            self.head_dim = (\n                self.hf_config.qk_nope_head_dim + self.hf_config.qk_rope_head_dim\n            )\n            self.attention_arch = AttentionArch.MLA\n            self.kv_lora_rank = self.hf_config.kv_lora_rank\n            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim\n            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim\n            self.v_head_dim = self.hf_config.v_head_dim\n            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)\n            if self.hf_config.rope_scaling:\n                self.scaling = compute_mla_mscale_scaling(\n                    self.hf_config.rope_scaling, self.scaling\n                )\n        else:\n            if (\n                \"MistralModel\" in self.hf_config.architectures\n                or \"MixtralForCausalLM\" in self.hf_config.architectures\n                or \"MistralForCausalLM\" in self.hf_config.architectures\n            ):\n                if getattr(self, \"head_dim\", None) is None:\n                    self.head_dim = (\n                       
 self.hf_config.hidden_size // self.hf_config.num_attention_heads\n                    )\n                    # In transformers==4.52.3, the head_dim is null in MistralConfig\n                    if (\n                        not hasattr(self.hf_text_config, \"head_dim\")\n                        or self.hf_text_config.head_dim is None\n                    ):\n                        setattr(self.hf_text_config, \"head_dim\", self.head_dim)\n\n            elif \"BaichuanForCausalLM\" in self.hf_config.architectures:\n                self.use_alibi = self.hf_config.hidden_size != 4096\n\n            self.attention_arch = AttentionArch.MHA\n\n        self.num_attention_heads = self.hf_text_config.num_attention_heads\n        self.num_key_value_heads = getattr(\n            self.hf_text_config, \"num_key_value_heads\", None\n        )\n\n        # for Dbrx and MPT models\n        if self.hf_config.model_type in [\"dbrx\", \"mpt\"]:\n            self.num_key_value_heads = getattr(\n                self.hf_config.attn_config, \"kv_n_heads\", None\n            )\n\n        if self.num_key_value_heads is None:\n            self.num_key_value_heads = self.num_attention_heads\n        self.hidden_size = self.hf_text_config.hidden_size\n        self.num_hidden_layers = self.hf_text_config.num_hidden_layers\n        self.num_attention_layers = self.num_hidden_layers\n        if \"LongcatFlashForCausalLM\" in self.hf_config.architectures:\n            self.num_attention_layers = self.num_hidden_layers * 2\n        if \"IQuestLoopCoderForCausalLM\" in self.hf_config.architectures:\n            loop_num = getattr(self.hf_text_config, \"loop_num\", 1)\n            self.num_attention_layers = int(self.num_hidden_layers * int(loop_num))\n        if \"WhisperForConditionalGeneration\" in self.hf_config.architectures:\n            # Whisper has unique layer ID scheme:\n            # - Encoder self-attention: 0 to encoder_layers-1 (no KV cache)\n            # - Decoder self-attention: 
encoder_layers to encoder_layers+decoder_layers-1 (uses KV cache)\n            # - Decoder cross-attention: encoder_layers+decoder_layers to encoder_layers+2*decoder_layers-1\n            # Even though cross-attention doesn't save KV cache, attention backend needs buffer to exist\n            encoder_layers = getattr(self.hf_text_config, \"encoder_layers\", 0)\n            decoder_layers = getattr(\n                self.hf_text_config, \"decoder_layers\", self.num_hidden_layers\n            )\n            self.num_attention_layers = encoder_layers + 2 * decoder_layers\n        self.num_nextn_predict_layers = getattr(\n            self.hf_text_config, \"num_nextn_predict_layers\", None\n        )\n        self.vocab_size = self.hf_text_config.vocab_size\n\n    def get_total_num_attention_heads(self) -> int:\n        return self.num_attention_heads\n\n    def get_num_attention_heads(self, tensor_parallel_size) -> int:\n        total_num_attention_heads = self.num_attention_heads\n        return max(1, total_num_attention_heads // tensor_parallel_size)\n\n    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289\n    def get_total_num_kv_heads(self) -> int:\n        \"\"\"Returns the total number of KV heads.\"\"\"\n        # For GPTBigCode & Falcon:\n        # NOTE: for falcon, when new_decoder_architecture is True, the\n        # multi_query flag is ignored and we use n_head_kv for the number of\n        # KV heads.\n        falcon_model_types = [\"falcon\", \"RefinedWeb\", \"RefinedWebModel\"]\n        new_decoder_arch_falcon = (\n            self.hf_config.model_type in falcon_model_types\n            and getattr(self.hf_config, \"new_decoder_architecture\", False)\n        )\n        if not new_decoder_arch_falcon and getattr(\n            self.hf_text_config, \"multi_query\", False\n        ):\n            # Multi-query attention, only one KV head.\n            # Currently, tensor parallelism is not supported in this case.\n           
 return 1\n\n        # For DBRX and MPT\n        if self.hf_config.model_type in [\"mpt\"]:\n            if \"kv_n_heads\" in self.hf_config.attn_config:\n                return self.hf_config.attn_config[\"kv_n_heads\"]\n            return self.hf_config.num_attention_heads\n        if self.hf_config.model_type in [\"dbrx\"]:\n            return getattr(\n                self.hf_config.attn_config,\n                \"kv_n_heads\",\n                self.hf_config.num_attention_heads,\n            )\n        if self.hf_config.model_type in [\"nemotron-nas\"]:\n            nkvh = {\n                self.hf_config.num_attention_heads // block.attention.n_heads_in_group\n                for block in self.hf_config.block_configs\n                if not block.attention.no_op\n            }\n            if len(nkvh) == 0:\n                raise RuntimeError(\"Couldn't determine number of kv heads\")\n            if len(nkvh) > 1:\n                raise ValueError(\n                    \"Variable GQA (VGQA) is not yet supported for nemotron-nas in sglang\"\n                )\n            return next(iter(nkvh))\n\n        attributes = [\n            # For Falcon:\n            \"n_head_kv\",\n            \"num_kv_heads\",\n            # For LLaMA-2:\n            \"num_key_value_heads\",\n            # For ChatGLM:\n            \"multi_query_group_num\",\n            # For Step3\n            \"num_attention_groups\",\n        ]\n        for attr in attributes:\n            num_kv_heads = getattr(self.hf_text_config, attr, None)\n            if num_kv_heads is not None:\n                return num_kv_heads\n\n        # For non-grouped-query attention models, the number of KV heads is\n        # equal to the number of attention heads.\n        return self.hf_text_config.num_attention_heads\n\n    def get_num_kv_heads(self, tensor_parallel_size) -> int:\n        \"\"\"Returns the number of KV heads per GPU.\"\"\"\n        total_num_kv_heads = self.get_total_num_kv_heads()\n     
   # If tensor parallelism is used, we divide the number of KV heads by\n        # the tensor parallel size. We will replicate the KV heads in the\n        # case where the number of KV heads is smaller than the tensor\n        # parallel size so each GPU has at least one KV head.\n        return max(1, total_num_kv_heads // tensor_parallel_size)\n\n    def get_swa_num_kv_heads(self, tensor_parallel_size) -> int:\n        \"\"\"Similar to get_num_kv_heads(), but for SWA.\"\"\"\n        if hasattr(self.hf_text_config, \"swa_num_key_value_heads\"):\n            total_num_kv_heads = self.hf_text_config.swa_num_key_value_heads\n            return max(1, total_num_kv_heads // tensor_parallel_size)\n        elif hasattr(self.hf_text_config, \"attention_other_setting\"):  # For step3p5\n            total_num_kv_heads = self.hf_text_config.attention_other_setting.get(\n                \"num_attention_groups\"\n            )\n            return max(1, total_num_kv_heads // tensor_parallel_size)\n        else:\n            return self.get_num_kv_heads(tensor_parallel_size)\n\n    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py\n    def _parse_quant_hf_config(self):\n        quant_cfg = getattr(self.hf_config, \"quantization_config\", None)\n        if quant_cfg is not None and not isinstance(quant_cfg, dict):\n            quant_cfg = quant_cfg.to_dict()\n        if quant_cfg is not None:\n            # Identify modelopt quantization\n            if (\n                \"quant_method\" not in quant_cfg\n                or quant_cfg[\"quant_method\"] == \"modelopt\"\n            ):\n                parsed_cfg = self._parse_modelopt_quant_config(\n                    {\"quantization\": quant_cfg}\n                )\n                if parsed_cfg:\n                    quant_cfg.update(parsed_cfg)\n\n        if quant_cfg is None:\n            # compressed-tensors uses a \"compression_config\" key\n            quant_cfg = getattr(self.hf_config, 
\"compression_config\", None)\n        if quant_cfg is None:\n            # Check whether this is a ModelOpt or mixed-precision model. Neither has a corresponding field\n            # in the HF `config.json`; instead, each has a standalone `hf_quant_config.json` in the root directory\n            # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main\n            # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main\n            is_local = os.path.exists(self.model_path)\n            if not is_local:\n                # Import huggingface_hub unconditionally: the offline check and the except clauses\n                # below use it, and importing it in only one branch makes the name function-local,\n                # so the ModelScope path would raise UnboundLocalError.\n                import huggingface_hub\n\n                # Conditional import based on SGLANG_USE_MODELSCOPE environment variable\n                if envs.SGLANG_USE_MODELSCOPE.get():\n\n                    from modelscope import HubApi, model_file_download\n\n                    hf_api = HubApi()\n                else:\n                    from huggingface_hub import HfApi, hf_hub_download\n\n                    hf_api = HfApi()\n                try:\n                    # In offline mode, skip file_exists check to avoid OfflineModeIsEnabled error\n                    # Instead, directly try to download/read from cache with local_files_only\n                    file_exists = False  # Initialize to avoid UnboundLocalError\n                    if not huggingface_hub.constants.HF_HUB_OFFLINE:\n                        # Online mode: check if file exists before attempting download (optimization)\n                        file_exists = retry(\n                            lambda: hf_api.file_exists(\n                                self.model_path, \"hf_quant_config.json\"\n                            ),\n                            max_retry=2,\n                            initial_delay=1.0,\n                            max_delay=5.0,\n                        )\n                        if not file_exists:\n                            # File doesn't exist on hub, no need to try downloading\n                            return quant_cfg  # None\n\n  
                  # Download (online mode) or read from cache (offline mode)\n                    if envs.SGLANG_USE_MODELSCOPE.get():\n                        quant_config_file = model_file_download(\n                            model_id=self.model_path,\n                            file_path=\"hf_quant_config.json\",\n                            revision=self.revision,\n                        )\n                    else:\n                        quant_config_file = hf_hub_download(\n                            repo_id=self.model_path,\n                            filename=\"hf_quant_config.json\",\n                            revision=self.revision,\n                            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,\n                        )\n                    with open(quant_config_file) as f:\n                        quant_config_dict = json.load(f)\n                    quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)\n                except huggingface_hub.errors.LocalEntryNotFoundError:\n                    # Offline mode and file not in cache - this is normal for non-quantized models\n                    logger.debug(\n                        f\"hf_quant_config.json not found in cache for {self.model_path} \"\n                        \"(offline mode, normal for non-quantized models)\"\n                    )\n                except huggingface_hub.errors.OfflineModeIsEnabled:\n                    # Should not reach here after our changes, but keep for safety\n                    logger.warning(\n                        \"Offline mode is enabled, skipping hf_quant_config.json check\"\n                    )\n                except Exception as e:\n                    logger.warning(\n                        \"Failed to load hf_quant_config.json for model %s: %s\",\n                        self.model_path,\n                        e,\n                    )\n            elif os.path.exists(os.path.join(self.model_path, 
\"hf_quant_config.json\")):\n                quant_config_file = os.path.join(\n                    self.model_path, \"hf_quant_config.json\"\n                )\n                with open(quant_config_file) as f:\n                    quant_config_dict = json.load(f)\n                quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)\n        return quant_cfg\n\n    def _find_quant_modelslim_config(self):\n        quant_config_file = Path(self.model_path, \"quant_model_description.json\")\n        quant_cfg = None\n        if quant_config_file.is_file():\n            with open(quant_config_file) as f:\n                quant_cfg = json.load(f)\n            # This field is required for flagless model loading but is not present in\n            # modelslim model description, so we're adding it here manually.\n            quant_cfg[\"quant_method\"] = \"modelslim\"\n\n        return quant_cfg\n\n    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> Optional[dict]:\n        \"\"\"Parse ModelOpt quantization config and return the appropriate quant_method.\"\"\"\n        json_quant_configs = quant_config_dict[\"quantization\"]\n        quant_algo = json_quant_configs.get(\"quant_algo\", None)\n\n        if quant_algo == \"MIXED_PRECISION\":\n            architectures = getattr(self.hf_config, \"architectures\", []) or []\n            if getattr(self.hf_config, \"model_type\", None) == \"nemotron_h\" or any(\n                arch.startswith(\"NemotronH\") for arch in architectures\n            ):\n                return {\"quant_method\": \"modelopt_mixed\", \"quant_algo\": quant_algo}\n            return {\"quant_method\": \"w4afp8\", \"quant_algo\": quant_algo}\n        elif quant_algo and (\"FP4\" in quant_algo or \"NVFP4\" in quant_algo):\n            return {\"quant_method\": \"modelopt_fp4\", \"quant_algo\": quant_algo}\n        elif quant_algo and \"FP8\" in quant_algo:\n            return {\"quant_method\": \"modelopt_fp8\", 
\"quant_algo\": quant_algo}\n        else:\n            return None\n\n    def get_quantization_config_log_str(self) -> Optional[str]:\n        \"\"\"\n        Get a concise string representation of the quantization config for logging.\n        Returns something like \"quant=fp8, fmt=e4m3\" or \"quant=gptq, bits=4\".\n        \"\"\"\n        try:\n            quant_cfg = self._parse_quant_hf_config()\n            if not quant_cfg:\n                return None\n\n            quant_method = quant_cfg.get(\"quant_method\", \"quantized\")\n            log_str = f\"quant={quant_method}\"\n\n            # Append interesting fields if they exist\n            for field in [\"bits\", \"quant_algo\", \"fmt\"]:\n                if field in quant_cfg:\n                    log_str += f\", {field}={quant_cfg[field]}\"\n\n            return log_str\n        except Exception:\n            return None\n\n    def _is_already_quantized(self) -> bool:\n        \"\"\"Check if the model is already quantized based on config files.\"\"\"\n        # Check for quantization in hf_config (config.json)\n        if getattr(self.hf_config, \"quantization_config\", None) or getattr(\n            self.hf_config, \"compression_config\", None\n        ):\n            return True\n\n        # Check for HuggingFace quantization config\n        from sglang.srt.utils import has_hf_quant_config\n\n        return has_hf_quant_config(self.model_path)\n\n    def _get_modelopt_quant_type(self) -> str:\n        \"\"\"Extract ModelOpt quantization type from unified quantization flag.\"\"\"\n        if self.quantization == \"modelopt_fp8\":\n            return \"fp8\"\n        elif self.quantization == \"modelopt_fp4\":\n            return \"nvfp4\"\n        elif self.quantization == \"modelopt_mixed\":\n            raise ValueError(\n                \"modelopt_mixed is only supported for pre-quantized checkpoints.\"\n            )\n        elif self.quantization == \"modelopt\":\n            # Auto-detect from 
model config\n            quant_cfg = self._parse_quant_hf_config()\n            if quant_cfg:\n                quant_method = quant_cfg.get(\"quant_method\", \"\").lower()\n                if \"fp4\" in quant_method:\n                    return \"fp4\"\n                elif \"fp8\" in quant_method:\n                    return \"fp8\"\n            # Default to fp8 if can't detect\n            return \"fp8\"\n        else:\n            return \"fp8\"  # Default fallback\n\n    def _get_sliding_window_size(self) -> Optional[int]:\n        sliding_window_size = getattr(self.hf_text_config, \"sliding_window_size\", None)\n        if sliding_window_size is None:\n            sliding_window_size = getattr(self.hf_text_config, \"sliding_window\", None)\n        return sliding_window_size\n\n    def _validate_quantize_and_serve_config(self):\n        \"\"\"Validate quantize_and_serve configuration.\"\"\"\n        if not self.quantize_and_serve:\n            return\n\n        # Check if ModelOpt quantization is specified\n        _MODELOPT_QUANTIZATION_METHODS = [\n            \"modelopt\",\n            \"modelopt_fp8\",\n            \"modelopt_fp4\",\n            \"modelopt_mixed\",\n        ]\n        modelopt_quantization_specified = (\n            self.quantization in _MODELOPT_QUANTIZATION_METHODS\n        )\n\n        if not modelopt_quantization_specified:\n            raise ValueError(\n                \"quantize_and_serve requires ModelOpt quantization (set with --quantization \"\n                f\"{{{', '.join(sorted(_MODELOPT_QUANTIZATION_METHODS))}}})\"\n            )\n\n        # quantize_and_serve is disabled due to compatibility issues\n        raise NotImplementedError(\n            \"quantize_and_serve functionality is currently disabled due to compatibility issues. \"\n            \"Please use the separate quantize-then-deploy workflow instead. \"\n            \"Step 1: Quantize and export model. 
\"\n            \"Step 2: Deploy the exported model.\"\n        )\n\n    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py\n    def _verify_quantization(self) -> None:\n        supported_quantization = [*QUANTIZATION_METHODS]\n        rocm_supported_quantization = [\n            \"awq\",\n            \"gptq\",\n            \"fp8\",\n            \"compressed_tensors\",\n            \"compressed-tensors\",\n            \"fbgemm_fp8\",\n            \"w8a8_fp8\",\n            \"petit_nvfp4\",\n            \"quark\",\n            \"mxfp4\",\n            \"auto-round\",\n            \"quark_int4fp8_moe\",\n        ]\n        optimized_quantization_methods = [\n            \"fp8\",\n            \"marlin\",\n            \"modelopt_fp8\",\n            \"modelopt_fp4\",\n            \"modelopt_mixed\",\n            \"gptq_marlin_24\",\n            \"gptq_marlin\",\n            \"awq_marlin\",\n            \"fbgemm_fp8\",\n            \"compressed_tensors\",\n            \"compressed-tensors\",\n            \"experts_int8\",\n            \"w8a8_int8\",\n            \"w8a8_fp8\",\n            \"moe_wna16\",\n            \"qoq\",\n            \"w4afp8\",\n            \"petit_nvfp4\",\n            \"quark\",\n            \"modelslim\",\n        ]\n        compatible_quantization_methods = {\n            \"modelopt_fp8\": [\"modelopt\"],\n            \"modelopt_fp4\": [\"modelopt\"],\n            \"modelopt_mixed\": [\"modelopt\"],\n            \"petit_nvfp4\": [\"modelopt\"],\n            \"w8a8_int8\": [\"compressed-tensors\", \"compressed_tensors\"],\n            \"w8a8_fp8\": [\"compressed-tensors\", \"compressed_tensors\"],\n        }\n        if self.quantization is not None:\n            self.quantization = self.quantization.lower()\n\n        # Parse quantization method from the HF and ModelSlim model config, if available.\n        # Only one function should return config, other should return None.\n        cfg_list = []\n        
hf_config = self._parse_quant_hf_config()\n        modelslim_config = self._find_quant_modelslim_config()\n        quant_config = modelslim_config or hf_config\n        if quant_config is not None:\n            cfg_list.append(quant_config)\n\n        # Filter out None values\n        cfg_list = [item for item in cfg_list if item is not None]\n        if len(cfg_list) > 1:\n            raise ValueError(\n                \"Config list contains configs from 2 methods, must be only 1\"\n            )\n        quant_cfg = cfg_list[0] if cfg_list else None\n\n        if quant_cfg is not None:\n            quant_method = quant_cfg.get(\n                \"quant_method\", \"\" if not self.quantization else self.quantization\n            ).lower()\n\n            # Detect which checkpoint is it\n            for _, method in QUANTIZATION_METHODS.items():\n                quantization_override = method.override_quantization_method(\n                    quant_cfg, self.quantization\n                )\n                if quantization_override:\n                    quant_method = quantization_override\n                    self.quantization = quantization_override\n                    break\n\n            # Verify quantization configurations.\n            if self.quantization is None:\n                self.quantization = quant_method\n            elif self.quantization != quant_method:\n                # Check if the CLI-specified quantization is compatible with HF config's quant_method\n                is_compatible = (\n                    self.quantization in compatible_quantization_methods\n                    and quant_method\n                    in compatible_quantization_methods[self.quantization]\n                )\n                if is_compatible:\n                    # Keep the CLI-specified quantization (e.g., modelopt_fp4) even if\n                    # HF config says \"modelopt\" - they are compatible\n                    logger.info(\n                        
f\"Using CLI-specified quantization ({self.quantization}) which is \"\n                        f\"compatible with HF config quant_method ({quant_method}).\"\n                    )\n                elif self.is_draft_model:\n                    # Allow auto-detection of quantization from checkpoint for draft model\n                    # only if the CLI quantization is not compatible\n                    logger.info(\n                        f\"Draft model quantization ({quant_method}) differs from \"\n                        f\"main model quantization ({self.quantization}). \"\n                        f\"Using draft model's detected quantization: {quant_method}\"\n                    )\n                    self.quantization = quant_method\n                else:\n                    raise ValueError(\n                        \"Quantization method specified in the model config \"\n                        f\"({quant_method}) does not match the quantization \"\n                        f\"method specified in the `quantization` argument \"\n                        f\"({self.quantization}).\"\n                    )\n\n            # Check if the scale_fmt is ue8m0, and warn user if deepgemm is enabled for non-ue8m0 models on blackwell\n            self.use_scale_ue8m0 = quant_cfg.get(\"scale_fmt\", None) == \"ue8m0\"\n            from sglang.srt.layers import deep_gemm_wrapper\n\n            if not self.use_scale_ue8m0 and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:\n                logger.warning(\n                    \"DeepGemm is enabled but the scale_fmt of checkpoint is not ue8m0. This might cause accuracy degradation on Blackwell.\"\n                )\n\n        if self.quantization is not None:\n            if self.quantization not in supported_quantization:\n                raise ValueError(\n                    f\"Unknown quantization method: {self.quantization}. 
Must \"\n                    f\"be one of {supported_quantization}.\"\n                )\n            if is_hip() and self.quantization not in rocm_supported_quantization:\n                raise ValueError(\n                    f\"{self.quantization} quantization is currently not \"\n                    f\"supported in ROCm.\"\n                )\n            if self.quantization not in optimized_quantization_methods:\n                # Don't warn for MXFP4 on SM100 since it has optimized kernels\n                if not (self.quantization == \"mxfp4\" and is_sm100_supported()):\n                    logger.warning(\n                        \"%s quantization is not fully \"\n                        \"optimized yet. The speed can be slower than \"\n                        \"non-quantized models.\",\n                        self.quantization,\n                    )\n\n    def _verify_dual_chunk_attention_config(self) -> None:\n        if hasattr(self.hf_config, \"dual_chunk_attention_config\"):\n            # Try loading the sparse attention config\n            sparse_attn_config = get_sparse_attention_config(self.model_path)\n            if not sparse_attn_config:\n                return\n            self.hf_config.dual_chunk_attention_config[\"sparse_attention_config\"] = (\n                sparse_attn_config\n            )\n            if (\n                \"sparse_attention_enabled\"\n                not in self.hf_config.dual_chunk_attention_config\n            ):\n                self.hf_config.dual_chunk_attention_config[\n                    \"sparse_attention_enabled\"\n                ] = True\n\n    def _verify_transformers_version(self):\n        import transformers\n        from packaging import version\n\n        tf_version_str = getattr(transformers, \"__version__\", None)\n        if tf_version_str is None:\n            return\n\n        vision_config = getattr(self.hf_config, \"vision_config\", None)\n        is_glm_46vmoe = \"glm-4.6v\" in 
self.model_path.lower() or (\n            vision_config is not None\n            and getattr(vision_config, \"model_type\", None) == \"glm4v_moe_vision\"\n            # The vision config model type for GLM-4.5v is 'glm4v_moe',\n            # while for GLM-4.6v, it is 'glm4v_moe_vision'.\n        )\n        needs_tf_v5 = is_glm_46vmoe\n\n        tf_version = version.parse(tf_version_str)\n        required_version = version.parse(\"5.0.0dev0\")\n\n        if tf_version < required_version:\n            if needs_tf_v5:\n                raise ValueError(\n                    f\"Transformers version {tf_version_str} is not supported for model {self.model_path} \"\n                    f\"or model type {self.hf_config.model_type}. \"\n                    \"Please upgrade transformers to >= 5.0.0.\"\n                )\n        elif not needs_tf_v5:\n            logger.warning(\n                f\"Transformers version {tf_version_str} is used for model type {self.hf_config.model_type}. \"\n                \"If you experience issues related to RoPE parameters, \"\n                \"they may be due to incompatibilities between Transformers >=5.0.0 and some models. 
\"\n                \"You can try downgrading to transformers==4.57.1 as a workaround.\"\n            )\n\n    def _get_hf_eos_token_id(self) -> Optional[Set[int]]:\n        eos_ids = getattr(self.hf_config, \"eos_token_id\", None)\n        if eos_ids is not None:\n            # it can be either int or list of int\n            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)\n        if eos_ids is None:\n            eos_ids = set()\n        if self.hf_generation_config:\n            generation_eos_ids = getattr(\n                self.hf_generation_config, \"eos_token_id\", None\n            )\n            if generation_eos_ids:\n                generation_eos_ids = (\n                    {generation_eos_ids}\n                    if isinstance(generation_eos_ids, int)\n                    else set(generation_eos_ids)\n                )\n                eos_ids = eos_ids | generation_eos_ids\n        return eos_ids\n\n    def get_default_sampling_params(self) -> dict[str, Any]:\n        \"\"\"\n        Get default sampling parameters from the model's generation config.\n\n        This method returns non-default sampling parameters from the model's\n        generation_config.json when sampling_defaults is set to \"model\".\n\n        Returns:\n            A dictionary containing the non-default sampling parameters.\n        \"\"\"\n        if self.sampling_defaults != \"model\":\n            return {}\n\n        if self.hf_generation_config is None:\n            return {}\n\n        config = self.hf_generation_config.to_dict()\n\n        available_params = [\n            \"repetition_penalty\",\n            \"temperature\",\n            \"top_k\",\n            \"top_p\",\n            \"min_p\",\n        ]\n\n        default_sampling_params = {\n            p: config.get(p) for p in available_params if config.get(p) is not None\n        }\n\n        return default_sampling_params\n\n    def _maybe_pull_model_tokenizer_from_remote(self) -> None:\n       
 \"\"\"\n        Pull the model config files to a temporary\n        directory in case of remote.\n\n        Args:\n            model: The model name or path.\n\n        \"\"\"\n        from sglang.srt.connector import create_remote_connector\n        from sglang.srt.utils import is_remote_url\n\n        if is_remote_url(self.model_path):\n            logger.info(\"Pulling model configs from remote...\")\n            # BaseConnector implements __del__() to clean up the local dir.\n            # Since config files need to exist all the time, so we DO NOT use\n            # with statement to avoid closing the client.\n            client = create_remote_connector(self.model_path)\n            if is_remote_url(self.model_path):\n                client.pull_files(allow_pattern=[\"*config.json\"])\n                self.model_weights = self.model_path\n                self.model_path = client.get_local_dir()\n\n\n# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py\n_STR_DTYPE_TO_TORCH_DTYPE = {\n    \"half\": torch.float16,\n    \"float16\": torch.float16,\n    \"float\": torch.float32,\n    \"float32\": torch.float32,\n    \"bfloat16\": torch.bfloat16,\n}\n\n\n# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py\ndef _get_and_verify_dtype(\n    config: PretrainedConfig,\n    dtype: Union[str, torch.dtype],\n) -> torch.dtype:\n    # NOTE: getattr(config, \"torch_dtype\", torch.float32) is not correct\n    # because config.torch_dtype can be None.\n    if isinstance(config, dict):\n        config_dtype = config.get(\"dtype\", None) or config.get(\"torch_dtype\", None)\n        model_type = config.get(\"model_type\", \"\")\n    else:\n        config_dtype = getattr(config, \"dtype\", None)\n        model_type = getattr(config, \"model_type\", \"\")\n    if isinstance(config_dtype, str):\n        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)\n    if config_dtype is None:\n        config_dtype 
= torch.float32\n\n    if isinstance(dtype, str):\n        dtype = dtype.lower()\n        if dtype == \"auto\":\n            if config_dtype == torch.float32:\n                if model_type.startswith(\"gemma\"):\n                    if model_type == \"gemma\":\n                        gemma_version = \"\"\n                    else:\n                        gemma_version = model_type[5]\n                    logger.info(\n                        f\"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead \"\n                        \"of float16 by default. Please specify `dtype` if you \"\n                        \"want to use float16.\"\n                    )\n                    torch_dtype = torch.bfloat16\n                else:\n                    # Following the common practice, we use float16 for float32\n                    # models.\n                    torch_dtype = torch.float16\n            else:\n                torch_dtype = config_dtype\n        else:\n            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:\n                raise ValueError(f\"Unknown dtype: {dtype}\")\n            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]\n    elif isinstance(dtype, torch.dtype):\n        torch_dtype = dtype\n    else:\n        raise ValueError(f\"Unknown dtype: {dtype}\")\n\n    # Verify the dtype.\n    if torch_dtype != config_dtype:\n        if torch_dtype == torch.float32:\n            # Upcasting to float32 is allowed.\n            logger.info(\"Upcasting %s to %s.\", config_dtype, torch_dtype)\n            pass\n        elif config_dtype == torch.float32:\n            # Downcasting from float32 to float16 or bfloat16 is allowed.\n            logger.info(\"Downcasting %s to %s.\", config_dtype, torch_dtype)\n            pass\n        else:\n            # Casting between float16 and bfloat16 is allowed with a warning.\n            logger.warning(\"Casting %s to %s.\", config_dtype, torch_dtype)\n\n    return torch_dtype\n\n\ndef 
is_generation_model(model_architectures: List[str], is_embedding: bool = False):\n    # We have two ways to determine whether a model is a generative model.\n    # 1. Check the model architecture\n    # 2. check the `is_embedding` server args\n\n    if (\n        \"LlamaEmbeddingModel\" in model_architectures\n        or \"MistralModel\" in model_architectures\n        or \"LlamaForSequenceClassification\" in model_architectures\n        or \"LlamaForSequenceClassificationWithNormal_Weights\" in model_architectures\n        or \"InternLM2ForRewardModel\" in model_architectures\n        or \"Qwen2ForRewardModel\" in model_architectures\n        or \"Qwen3ForRewardModel\" in model_architectures\n        or \"Qwen2ForSequenceClassification\" in model_architectures\n        or \"Qwen3ForSequenceClassification\" in model_architectures\n        or \"CLIPModel\" in model_architectures\n        or \"BertModel\" in model_architectures\n        or \"Contriever\" in model_architectures\n        or \"BertForSequenceClassification\" in model_architectures\n        or \"XLMRobertaModel\" in model_architectures\n        or \"XLMRobertaForSequenceClassification\" in model_architectures\n        or \"Gemma2ForSequenceClassification\" in model_architectures\n    ):\n        return False\n    else:\n        return not is_embedding\n\n\nmultimodal_model_archs = [\n    \"CLIPModel\",\n    \"DeepseekVL2ForCausalLM\",\n    \"Ernie4_5_VLMoeForConditionalGeneration\",\n    \"Gemma3ForConditionalGeneration\",\n    \"Gemma3nForConditionalGeneration\",\n    \"Glm4vForConditionalGeneration\",\n    \"Glm4vMoeForConditionalGeneration\",\n    \"GlmOcrForConditionalGeneration\",\n    \"GlmAsrForConditionalGeneration\",\n    \"Grok1VForCausalLM\",\n    \"Grok1AForCausalLM\",\n    \"LlavaLlamaForCausalLM\",\n    \"Llama4ForConditionalGeneration\",\n    \"LlavaMistralForCausalLM\",\n    \"LlavaQwenForCausalLM\",\n    \"LlavaForConditionalGeneration\",\n    \"LlavaVidForCausalLM\",\n    
\"LightOnOCRForConditionalGeneration\",\n    \"MiniCPMO\",\n    \"MiniCPMV\",\n    \"Mistral3ForConditionalGeneration\",\n    \"MultiModalityCausalLM\",\n    \"MllamaForConditionalGeneration\",\n    \"NemotronH_Nano_VL_V2\",\n    \"PixtralForConditionalGeneration\",\n    \"Qwen2AudioForConditionalGeneration\",\n    \"Qwen2VLForConditionalGeneration\",\n    \"Qwen2_5_VLForConditionalGeneration\",\n    \"Qwen3VLForConditionalGeneration\",\n    \"Qwen3VLMoeForConditionalGeneration\",\n    \"Qwen3_5ForConditionalGeneration\",\n    \"Qwen3_5MoeForConditionalGeneration\",\n    \"Qwen3OmniMoeForConditionalGeneration\",\n    \"KimiVLForConditionalGeneration\",\n    \"InternVLChatModel\",\n    \"InternS1ForConditionalGeneration\",\n    \"InternS1ProForConditionalGeneration\",\n    \"Phi4MMForCausalLM\",\n    \"WhisperForConditionalGeneration\",\n    \"Step3VLForConditionalGeneration\",\n    \"POINTSV15ChatModel\",\n    \"DotsVLMForCausalLM\",\n    \"DotsOCRForCausalLM\",\n    \"Sarashina2VisionForCausalLM\",\n    \"NVILAForConditionalGeneration\",\n    \"NVILALiteForConditionalGeneration\",\n    \"DeepseekOCRForCausalLM\",\n    \"JetVLMForConditionalGeneration\",\n    \"PaddleOCRVLForConditionalGeneration\",\n    \"MiDashengLMModel\",\n    \"StepVLForConditionalGeneration\",\n    \"KimiK25ForConditionalGeneration\",\n]\n\npiecewise_cuda_graph_disabled_model_archs = [\n    \"DeepseekV32ForCausalLM\",\n    \"Qwen3NextForCausalLM\",\n    \"GlmMoeDsaForCausalLM\",\n    \"BailingMoeV2_5ForCausalLM\",\n    \"LLaDAModelLM\",\n]\n\nif external_mm_model_arch := envs.SGLANG_EXTERNAL_MM_MODEL_ARCH.get():\n    multimodal_model_archs.append(external_mm_model_arch)\n\n\ndef is_multimodal_model(model_architectures: List[str]):\n    if any(\n        multi_model_arch in model_architectures\n        for multi_model_arch in multimodal_model_archs\n    ):\n        return True\n    else:\n        return False\n\n\ndef is_multimodal_gen_model(model_architectures: List[str]):\n    return 
False\n\n\ndef is_image_gen_model(model_architectures: List[str]):\n    return False\n\n\ndef is_audio_model(model_architectures: List[str]):\n    models = [\n        \"WhisperForConditionalGeneration\",\n    ]\n    return any(model in model_architectures for model in models)\n\n\ndef is_encoder_decoder_model(model_architectures: List[str]):\n    models = [\n        \"WhisperForConditionalGeneration\",\n        \"MllamaForConditionalGeneration\",\n    ]\n    return any(model in model_architectures for model in models)\n\n\ndef is_local_attention_model(model_architectures: List[str]):\n    return \"Llama4ForConditionalGeneration\" in model_architectures\n\n\ndef is_multimodal_chunked_prefill_supported(model_architectures: List[str]):\n    \"\"\"Check if chunked prefill is supported for a MultiModal model.\"\"\"\n    unsupported = [\n        \"Grok1VForCausalLM\",\n        \"Grok1AForCausalLM\",\n        \"LlavaLlamaForCausalLM\",\n        \"MllamaForConditionalGeneration\",\n        \"CLIPModel\",\n    ]\n    if any(multi_model_arch in unsupported for multi_model_arch in model_architectures):\n        return False\n    else:\n        return True\n\n\ndef is_piecewise_cuda_graph_disabled_model(model_architectures: List[str]):\n    return any(\n        arch in piecewise_cuda_graph_disabled_model_archs\n        for arch in model_architectures\n    )\n\n\ndef yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 1.0\n\n\ndef compute_mla_mscale_scaling(rope_scaling: dict, base_scaling: float) -> float:\n    \"\"\"Compute MLA attention scaling factor from rope_scaling with mscale.\n\n    Used by DeepSeek, BailingMoe, SarvamMLA and similar MLA models.\n    Warns if 'factor' is missing from rope_scaling (common in v5 configs).\n    \"\"\"\n    mscale_all_dim = rope_scaling.get(\"mscale_all_dim\", False)\n    if \"factor\" not in rope_scaling:\n        logger.warning(\n            
\"rope_scaling missing 'factor', defaulting to 1.0. \"\n            \"Check model accuracy.\",\n        )\n    scaling_factor = rope_scaling.get(\"factor\", 1.0)\n    mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))\n    return base_scaling * mscale * mscale\n\n\ndef is_hybrid_swa_model(model_architectures: List[str]):\n\n    hybrid_swa_archs = {\n        \"Llama4ForConditionalGeneration\",\n        \"GptOssForCausalLM\",\n        \"MiMoV2FlashForCausalLM\",\n        \"MiMoV2MTP\",\n        \"Step3p5ForCausalLM\",\n        \"Step3p5MTP\",\n    }\n    return any(arch in hybrid_swa_archs for arch in model_architectures)\n\n\ndef get_hybrid_layer_ids(\n    model_architectures: List[str],\n    hf_text_config: PretrainedConfig,\n):\n    num_hidden_layers = hf_text_config.num_hidden_layers\n    if \"Llama4ForConditionalGeneration\" in model_architectures:\n        swa_attention_layer_ids = [\n            i for i in range(num_hidden_layers) if (i + 1) % 4 != 0\n        ]\n        full_attention_layer_ids = [\n            i for i in range(num_hidden_layers) if (i + 1) % 4 == 0\n        ]\n    elif \"GptOssForCausalLM\" in model_architectures:\n        layer_types = getattr(hf_text_config, \"layer_types\", None)\n        swa_attention_layer_ids = [\n            i for i, x in enumerate(layer_types) if x == \"sliding_attention\"\n        ]\n        full_attention_layer_ids = [\n            i for i, x in enumerate(layer_types) if x == \"full_attention\"\n        ]\n    elif \"MiMoV2FlashForCausalLM\" in model_architectures:\n        hybrid_layer_pattern = getattr(hf_text_config, \"hybrid_layer_pattern\", None)\n        swa_attention_layer_ids = [\n            i for i in range(num_hidden_layers) if hybrid_layer_pattern[i] == 1\n        ]\n        full_attention_layer_ids = [\n            i for i in range(num_hidden_layers) if hybrid_layer_pattern[i] == 0\n        ]\n    elif \"MiMoV2MTP\" in model_architectures:\n        swa_attention_layer_ids = [0]\n        
full_attention_layer_ids = []\n    elif \"Step3p5ForCausalLM\" in model_architectures:\n        layer_types = hf_text_config.layer_types\n        swa_attention_layer_ids = [\n            i\n            for i, x in enumerate(layer_types)\n            if x == \"sliding_attention\" and i < num_hidden_layers\n        ]\n        full_attention_layer_ids = [\n            i\n            for i, x in enumerate(layer_types)\n            if x == \"full_attention\" and i < num_hidden_layers\n        ]\n    elif \"Step3p5MTP\" in model_architectures:\n        swa_attention_layer_ids = [0]\n        full_attention_layer_ids = []\n    else:\n        swa_attention_layer_ids = None\n        full_attention_layer_ids = None\n    return swa_attention_layer_ids, full_attention_layer_ids\n"
  },
  {
    "path": "python/sglang/srt/configs/modelopt_config.py",
    "content": "# Configuration for NVIDIA ModelOpt quantization integration\nfrom dataclasses import dataclass\nfrom typing import Optional\n\n\n@dataclass\nclass ModelOptConfig:\n    \"\"\"Configuration for NVIDIA ModelOpt quantization operations.\n\n    This configuration class holds parameters for ModelOpt quantization,\n    checkpoint management, and model export operations.\n\n    Args:\n        quant: Quantization method/type (e.g., \"fp8\", \"fp4\")\n        checkpoint_restore_path: Path to restore ModelOpt checkpoint from\n        checkpoint_save_path: Path to save ModelOpt checkpoint to\n        export_path: Path to export quantized model in HuggingFace format\n        quantize_and_serve: Whether to quantize and serve in one step\n    \"\"\"\n\n    quant: Optional[str] = None\n    checkpoint_restore_path: Optional[str] = None\n    checkpoint_save_path: Optional[str] = None\n    export_path: Optional[str] = None\n    quantize_and_serve: bool = False\n\n    def __post_init__(self):\n        \"\"\"Validate configuration after initialization.\"\"\"\n        # Add any validation logic if needed\n        pass\n"
  },
  {
    "path": "python/sglang/srt/configs/nano_nemotron_vl.py",
    "content": "# Copyright 2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n# Adapted from https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16/blob/cb5a65ff10232128389d882d805fa609427544f1/configuration.py\n\nfrom typing import Any\n\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom sglang.srt.configs.nemotron_h import NemotronHConfig\nfrom sglang.srt.configs.radio import RadioConfig\nfrom sglang.srt.multimodal.internvl_utils import IMAGENET_MEAN, IMAGENET_STD\n\n\ndef float_triplet(seq: Any):\n    a, b, c = tuple(seq)\n    assert (\n        isinstance(a, float) and isinstance(b, float) and isinstance(c, float)\n    ), \"expected three floats\"\n    return a, b, c\n\n\nclass NemotronH_Nano_VL_V2_Config(PretrainedConfig):\n    model_type = \"NemotronH_Nano_VL_V2\"\n    is_composition = True\n\n    def __init__(\n        self,\n        vision_config=None,\n        llm_config=None,\n        force_image_size: int = 512,\n        patch_size: int = 16,\n        downsample_ratio=0.5,\n        template=None,\n        ps_version=\"v2\",\n        image_tag_type=\"internvl\",\n        projector_hidden_size=4096,\n        vit_hidden_size=1280,\n        video_pruning_rate: float = 0.0,\n        video_context_token: str = \"<video>\",\n        img_context_token: str = \"<image>\",\n        img_start_token: str = \"<img>\",\n        
img_end_token: str = \"</img>\",\n        norm_mean: tuple[float, float, float] | list[float] = IMAGENET_MEAN,\n        norm_std: tuple[float, float, float] | list[float] = IMAGENET_STD,\n        use_thumbnail: bool = True,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        # Handle both cases: when loading from JSON (llm_config is dict) and when called internally by transformers (llm_config; vision_config are None)\n        if llm_config is not None:\n            self.llm_config = NemotronHConfig(**llm_config)\n            assert isinstance(vision_config, dict), \"vision_config must be a dictionary\"\n            self.raw_vision_config = vision_config\n        else:\n            assert vision_config is None\n            self.llm_config = NemotronHConfig()\n            self.raw_vision_config = {}\n\n        # Assign configuration values\n        vision_image_size = self.raw_vision_config.get(\"image_size\", force_image_size)\n        vision_patch_size = self.raw_vision_config.get(\"patch_size\", patch_size)\n        self.image_size = int(\n            vision_image_size[0]\n            if isinstance(vision_image_size, list)\n            else vision_image_size\n        )\n        self.patch_size = int(\n            vision_patch_size[0]\n            if isinstance(vision_patch_size, list)\n            else vision_patch_size\n        )\n\n        self.downsample_ratio = downsample_ratio\n        self.video_context_token = video_context_token\n        self.img_context_token = img_context_token\n        self.template = template  # TODO move out of here and into the tokenizer\n        self.ps_version = ps_version  # Pixel shuffle version\n        self.image_tag_type = image_tag_type  # TODO: into the tokenizer too?\n        self.projector_hidden_size = projector_hidden_size\n        self.vit_hidden_size = vit_hidden_size\n        self.video_pruning_rate = video_pruning_rate\n\n        self.norm_mean = float_triplet(norm_mean)\n        self.norm_std = 
float_triplet(norm_std)\n        self.use_thumbnail = use_thumbnail\n        self.img_start_token = img_start_token\n        self.img_end_token = img_end_token\n\n    def create_radio_config(self):\n        config = self.raw_vision_config\n        model_name = config[\"args\"][\"model\"]\n        reg_tokens = config[\"args\"].get(\"register_multiple\")\n        image_size = config.get(\"preferred_resolution\", [224])[0]\n        radio_config = RadioConfig(\n            patch_size=self.patch_size,\n            norm_mean=self.norm_mean,\n            norm_std=self.norm_std,\n            model_name=model_name,\n            reg_tokens=reg_tokens,\n            image_size=image_size,\n        )\n        return radio_config\n"
  },
  {
    "path": "python/sglang/srt/configs/nemotron_h.py",
    "content": "# Copyright 2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/nemotron_h.py\n\n\"\"\"NemotronH model configuration\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nfrom sglang.srt.configs.mamba_utils import (\n    Mamba2CacheParams,\n    Mamba2StateShape,\n    mamba2_state_dtype,\n)\n\nlogger = logging.get_logger(__name__)\n\nMAMBA = \"M\"\nATTENTION = \"*\"\nMLP = \"-\"\nMOE = \"E\"\nDEFAULT_LAYERS_BLOCK_TYPE = [\"mamba\", \"moe\", \"attention\", \"moe\"]\nDEFAULT_MTP_LAYERS_BLOCK_TYPE = [\"attention\", \"moe\"]\nDEFAULT_MAMBA_CHUNK_SIZE = 256\n\n\nclass NemotronHConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a\n    [`NemotronHModel`]. It is used to instantiate a NemotronH model according\n    to the specified arguments, defining the model architecture. Instantiating\n    a configuration with the defaults will yield a similar configuration to\n    that of the NemotronH-v0.1 model.\n    Args:\n        vocab_size (`int`, *optional*, defaults to 131072):\n            Vocabulary size of the NemotronH model. 
Defines the number of\n            different tokens that can be represented by the `inputs_ids`\n            passed when calling [`NemotronHModel`]\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be\n            tied. Note that this is only relevant if the model has an output\n            word embedding layer.\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 21504):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*):\n            Deprecated. Kept only for backward compatibility. The effective\n            layer count is derived from `layers_block_type`.\n        hybrid_override_pattern (`str`, *optional*, defaults to\n            `\"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-\"`):\n            Deprecated compatibility field. Pattern string where each\n            character represents Mamba2 (`M`), Attention (`*`), MLP (`-`),\n            or MoE (`E`).\n        layers_block_type (`list[str]`, *optional*):\n            Canonical layer layout. Each entry is one of:\n            `\"mamba\"`, `\"attention\"`, `\"mlp\"`, `\"moe\"`.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the\n            Transformer encoder.\n        attention_head_dim (`int`, *optional*, defaults to 128):\n            Dimension of each attention head.\n        num_key_value_heads (`int`, *optional*, defaults to 8):\n            This is the number of key_value heads that should be used to\n            implement Grouped Query Attention. 
If\n            `num_key_value_heads=num_attention_heads`, the model will use\n            Multi Head Attention (MHA), if `num_key_value_heads=1` the model\n            will use Multi Query Attention (MQA) otherwise GQA is used.\n        mlp_hidden_act (`str`, *optional*, defaults to \"relu2\"):\n            The non-linear activation function in the MLP layers.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use bias in attention layers.\n        mlp_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use bias in MLP layers.\n        use_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use bias in the model.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for\n            initializing all weight matrices.\n        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):\n            The epsilon used by the layer normalization layers.\n        residual_in_fp32 (`bool`, *optional*, defaults to `False`):\n            Whether or not residuals should be in `float32`. If set to `False`\n            residuals will keep the same `dtype` as the rest of the model.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values\n            attentions (not used by all models). Only relevant if\n            `config.is_decoder=True`.\n        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):\n            Number of prompt logits to calculate during generation. If `None`,\n            all logits will be calculated. 
If an integer value, only last\n            `num_logits_to_keep` logits will be calculated.\n        pad_token_id (`int`, *optional*, defaults to 0):\n            The id of the padding token.\n        bos_token_id (`int`, *optional*, defaults to 1):\n            The id of the \"beginning-of-sequence\" token.\n        eos_token_id (`int`, *optional*, defaults to 2):\n            The id of the \"end-of-sequence\" token.\n        sliding_window (`int`, *optional*, defaults to None):\n            Sliding window attention window size.\n        max_position_embeddings (`int`, *optional*, defaults to 4096):\n            The maximum sequence length that this model might ever be used\n            with.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        hidden_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the hidden states.\n        use_mamba_kernels (`bool`, *optional*, defaults to `True`):\n            Flag indicating whether or not to use the fast mamba kernels.\n            These are available only if `mamba-ssm` and `causal-conv1d`\n            are installed, and the mamba modules are running on a CUDA device.\n        ssm_state_size (`int`, *optional*, defaults to 128):\n            The dimension of the mamba state space latents.\n        mamba_num_heads (`int`, *optional*, defaults to 128):\n            Number of heads in Mamba layers.\n        mamba_n_groups (`int`, *optional*, defaults to 8):\n            Number of groups in Mamba layers.\n        mamba_head_dim (`int`, *optional*, defaults to 64):\n            Dimension of each Mamba head.\n        mamba_d_conv (`int`, *optional*, defaults to 4):\n            The size of the mamba convolution kernel.\n        mamba_expand (`int`, *optional*, defaults to 2):\n            Expanding factor used to determine the mamba intermediate size.\n        mamba_hidden_act (`str`, *optional*, defaults to 
\"silu\"):\n            The non-linear activation function in the Mamba layers.\n        mamba_dt_min (`float`, *optional*, defaults to 0.001):\n            Minimum value for the time step in Mamba.\n        mamba_dt_max (`float`, *optional*, defaults to 0.1):\n            Maximum value for the time step in Mamba.\n        mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float(\"inf\"))):\n            Limits for the time step in Mamba.\n        mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):\n            Floor value for time step initialization in Mamba.\n        mamba_conv_bias (`bool`, *optional*, defaults to `True`):\n            Whether to use bias in the convolution layer of the mamba mixer\n            block.\n        mamba_proj_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use bias in the input and output projections of the\n            mamba mixer block.\n        mamba_chunk_size (`int`, *optional*, defaults to 256):\n            Size of chunks for Mamba processing.\n        rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):\n            Whether to rescale the pre-normalization residual connections.\n    \"\"\"\n\n    model_type = \"nemotron_h\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    @staticmethod\n    def _validate_layers_block_type(\n        layers_block_type, expected_length=None, param_name=\"layers_block_type\"\n    ):\n        \"\"\"\n        Validate layers_block_type list.\n        Args:\n            layers_block_type: List of layer types to validate.\n            expected_length: If provided, validate the list has this length.\n            param_name: Parameter name for error messages.\n        Raises:\n            ValueError: If validation fails.\n        \"\"\"\n        if not isinstance(layers_block_type, list):\n            raise ValueError(\n                f\"{param_name} must be a list of strings. 
Got type: {type(layers_block_type)}\"\n            )\n        if expected_length is not None and len(layers_block_type) != expected_length:\n            raise ValueError(\n                f\"{param_name} must have length {expected_length}. Got length {len(layers_block_type)}.\"\n            )\n        valid_types = {\"mamba\", \"attention\", \"mlp\", \"moe\"}\n        if not all(block_type in valid_types for block_type in layers_block_type):\n            invalid = set(layers_block_type) - valid_types\n            raise ValueError(\n                f\"{param_name} contains invalid types: {invalid}. Must be one of: {valid_types}\"\n            )\n\n    @staticmethod\n    def _resolve_layers_block_type(\n        layers_block_type, hybrid_override_pattern, kwargs\n    ) -> list[str]:\n        \"\"\"Resolve canonical layers_block_type from new and legacy config fields.\"\"\"\n        # Prefer explicit kwargs override first (legacy HF path), otherwise use\n        # the function argument value from config fields.\n        pattern = kwargs.pop(\"hybrid_override_pattern\", hybrid_override_pattern)\n        if layers_block_type is None:\n            if pattern is not None:\n                layers_block_type = NemotronHConfig._pattern_to_list(pattern)\n            else:\n                # Last-resort fallback to preserve compatibility when neither\n                # canonical nor legacy pattern fields are provided.\n                layers_block_type = DEFAULT_LAYERS_BLOCK_TYPE\n        return layers_block_type\n\n    @staticmethod\n    def _resolve_mtp_layers_block_type(mtp_layers_block_type, kwargs) -> list[str]:\n        \"\"\"Resolve canonical mtp_layers_block_type from new and legacy config fields.\"\"\"\n        if \"mtp_hybrid_override_pattern\" in kwargs:\n            pattern = kwargs.pop(\"mtp_hybrid_override_pattern\")\n            if mtp_layers_block_type is None or mtp_layers_block_type == [\n                \"attention\",\n                \"moe\",\n            
]:\n                mtp_layers_block_type = NemotronHConfig._pattern_to_list(pattern)\n        return mtp_layers_block_type\n\n    @staticmethod\n    def _resolve_mamba_chunk_size(mamba_chunk_size, kwargs) -> int:\n        \"\"\"Resolve canonical mamba_chunk_size from new and legacy config fields.\"\"\"\n        chunk_size = kwargs.pop(\"chunk_size\", None)\n        if (\n            mamba_chunk_size is not None\n            and chunk_size is not None\n            and mamba_chunk_size != chunk_size\n        ):\n            logger.warning(\n                \"Both chunk_size=%s and mamba_chunk_size=%s were provided. \"\n                \"Using mamba_chunk_size.\",\n                chunk_size,\n                mamba_chunk_size,\n            )\n\n        if mamba_chunk_size is None:\n            mamba_chunk_size = chunk_size\n        if mamba_chunk_size is None:\n            mamba_chunk_size = DEFAULT_MAMBA_CHUNK_SIZE\n        return mamba_chunk_size\n\n    def __init__(\n        self,\n        vocab_size=131072,\n        tie_word_embeddings=False,\n        hidden_size=4096,\n        intermediate_size=21504,\n        num_hidden_layers=None,  # Deprecated, only for backward compatibility\n        hybrid_override_pattern=\"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-\",\n        layers_block_type=None,\n        num_attention_heads=32,\n        head_dim=128,\n        num_key_value_heads=8,  # nemo: num_query_groups\n        mlp_hidden_act=\"relu2\",\n        attention_bias=False,\n        mlp_bias=False,\n        use_bias=False,\n        initializer_range=0.02,  # nemo: init_method_std\n        layer_norm_epsilon=1e-5,  # nemo: layernorm_epsilon\n        residual_in_fp32=False,  #  Megatron Core default value\n        use_cache=True,\n        num_logits_to_keep=1,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        sliding_window=None,\n        max_position_embeddings=4096,\n        attention_dropout=0.0,\n        
hidden_dropout=0.0,  # * ADDED\n        use_mamba_kernels=True,\n        ssm_state_size=128,  # mamba_state_size\n        mamba_num_heads=128,\n        mamba_n_groups=8,  # nemo: mamba_ssm_ngroups = num_heads\n        mamba_head_dim=64,\n        mamba_d_conv=4,\n        mamba_expand=2,\n        mamba_hidden_act=\"silu\",\n        mamba_dt_min=0.001,\n        mamba_dt_max=0.1,\n        mamba_dt_limit=(0.0, float(\"inf\")),\n        mamba_dt_init_floor=1e-4,\n        mamba_conv_bias=True,\n        mamba_proj_bias=False,\n        mamba_chunk_size=None,\n        rescale_prenorm_residual=True,\n        n_routed_experts=8,\n        n_shared_experts=1,\n        moe_intermediate_size=7688,\n        moe_shared_expert_intermediate_size=7688,\n        moe_latent_size=None,\n        num_experts_per_tok=2,\n        routed_scaling_factor=1.0,\n        n_group=1,\n        topk_group=1,\n        norm_topk_prob=True,\n        num_nextn_predict_layers=0,\n        mtp_layers_block_type=DEFAULT_MTP_LAYERS_BLOCK_TYPE,\n        **kwargs,\n    ):\n        mamba_chunk_size = self._resolve_mamba_chunk_size(mamba_chunk_size, kwargs)\n\n        # Compatibility parsing: normalize legacy pattern fields into canonical list fields.\n        layers_block_type = self._resolve_layers_block_type(\n            layers_block_type, hybrid_override_pattern, kwargs\n        )\n        mtp_layers_block_type = self._resolve_mtp_layers_block_type(\n            mtp_layers_block_type, kwargs\n        )\n\n        # num_hidden_layers is deprecated and ignored as a source of truth.\n        if (\n            num_hidden_layers is not None\n            and len(layers_block_type) != num_hidden_layers\n        ):\n            logger.warning(\n                f\"num_hidden_layers ({num_hidden_layers}) is deprecated and doesn't match \"\n                f\"layers_block_type length ({len(layers_block_type)}). 
Using layers_block_type length.\"\n            )\n\n        # Core model attributes.\n        self.vocab_size = vocab_size\n        self.tie_word_embeddings = tie_word_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_attention_heads = num_attention_heads\n        self.head_dim = head_dim\n        self.sliding_window = sliding_window\n        self.max_position_embeddings = max_position_embeddings\n        self.attention_dropout = attention_dropout\n        self.hidden_dropout = hidden_dropout\n\n        self._validate_layers_block_type(\n            layers_block_type, expected_length=None, param_name=\"layers_block_type\"\n        )\n        self.layers_block_type = layers_block_type\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.mlp_hidden_act = mlp_hidden_act\n        self.attention_bias = attention_bias\n        self.mlp_bias = mlp_bias\n        self.use_bias = use_bias\n        self.initializer_range = initializer_range\n        self.layer_norm_epsilon = layer_norm_epsilon\n        self.residual_in_fp32 = residual_in_fp32\n\n        self.use_cache = use_cache\n        self.num_logits_to_keep = num_logits_to_keep\n\n        # Mamba attributes.\n        self.use_mamba_kernels = use_mamba_kernels\n        self.mamba_n_groups = mamba_n_groups\n        self.mamba_head_dim = mamba_head_dim\n        self.ssm_state_size = ssm_state_size\n        self.mamba_num_heads = mamba_num_heads\n        self.conv_kernel = mamba_d_conv\n        self.expand = mamba_expand\n        self.mamba_hidden_act = mamba_hidden_act\n        self.time_step_min = mamba_dt_min\n        self.time_step_max = mamba_dt_max\n        self.time_step_limit = mamba_dt_limit\n        self.time_step_floor = mamba_dt_init_floor\n        self.use_conv_bias = mamba_conv_bias\n      
  self.mamba_proj_bias = mamba_proj_bias\n        self.mamba_chunk_size = mamba_chunk_size\n        self.rescale_prenorm_residual = rescale_prenorm_residual\n        # MoE attributes.\n        self.n_routed_experts = n_routed_experts\n        self.n_shared_experts = n_shared_experts\n        self.moe_intermediate_size = moe_intermediate_size\n        self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size\n        self.moe_latent_size = moe_latent_size\n        self.num_experts_per_tok = num_experts_per_tok\n        self.routed_scaling_factor = routed_scaling_factor\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.norm_topk_prob = norm_topk_prob\n        # MTP attributes.\n        self.num_nextn_predict_layers = num_nextn_predict_layers\n\n        if self.num_nextn_predict_layers > 0:\n            if mtp_layers_block_type is None:\n                raise ValueError(\n                    \"mtp_layers_block_type is required when num_nextn_predict_layers > 0. \"\n                    \"Please provide an explicit list of layer types for MTP layers. 
\"\n                    \"Example: mtp_layers_block_type=['attention', 'moe']\"\n                )\n            self._validate_layers_block_type(\n                mtp_layers_block_type, None, \"mtp_layers_block_type\"\n            )\n        self.mtp_layers_block_type = mtp_layers_block_type\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    @property\n    def mamba_layer_ids(self):\n        return [\n            i\n            for i in range(self.num_hidden_layers)\n            if self.hybrid_override_pattern[i] == MAMBA\n        ]\n\n    @property\n    def full_attention_layer_ids(self):\n        return [\n            i\n            for i in range(self.num_hidden_layers)\n            if self.hybrid_override_pattern[i] == ATTENTION\n        ]\n\n    @property\n    def mamba2_cache_params(self) -> Mamba2CacheParams:\n        from sglang.srt.layers.dp_attention import get_attention_tp_size\n\n        shape = Mamba2StateShape.create(\n            tp_world_size=get_attention_tp_size(),\n            intermediate_size=self.mamba_num_heads * self.mamba_head_dim,\n            n_groups=self.mamba_n_groups,\n            num_heads=self.mamba_num_heads,\n            head_dim=self.mamba_head_dim,\n            state_size=self.ssm_state_size,\n            conv_kernel=self.conv_kernel,\n        )\n\n        return Mamba2CacheParams(\n            shape=shape, layers=self.mamba_layer_ids, dtype=mamba2_state_dtype(self)\n        )\n\n    @property\n    def num_hidden_layers(self) -> int:\n        \"\"\"\n        Number of hidden layers derived from the length of layers_block_type.\n        This property replaces the deprecated num_hidden_layers parameter.\n        \"\"\"\n        return len(self.layers_block_type)\n\n    @num_hidden_layers.setter\n    def num_hidden_layers(self, value):\n        
\"\"\"\n        Setter for backward compatibility when loading configs.\n        The value is ignored since num_hidden_layers is computed from layers_block_type.\n        \"\"\"\n        pass\n\n    @property\n    def hybrid_override_pattern(self) -> str:\n        \"\"\"\n        Backward compatibility property.\n        Returns the pattern string representation of layers_block_type.\n        \"\"\"\n        return self._list_to_pattern(self.layers_block_type)\n\n    @hybrid_override_pattern.setter\n    def hybrid_override_pattern(self, value):\n        \"\"\"\n        Setter for backward compatibility when loading configs.\n        \"\"\"\n        self.layers_block_type = self._pattern_to_list(value)\n\n    @property\n    def mtp_hybrid_override_pattern(self) -> str:\n        \"\"\"\n        Backward compatibility property.\n        Returns the pattern string representation of mtp_layers_block_type.\n        \"\"\"\n        return self._list_to_pattern(self.mtp_layers_block_type)\n\n    @mtp_hybrid_override_pattern.setter\n    def mtp_hybrid_override_pattern(self, value):\n        \"\"\"Setter for backward compatibility when loading configs.\"\"\"\n        self.mtp_layers_block_type = self._pattern_to_list(value)\n\n    @staticmethod\n    def _list_to_pattern(layers_list: list[str]) -> str:\n        \"\"\"Convert list of layer types back to pattern string (for backward compatibility).\"\"\"\n        reverse_mapping = {\n            \"mamba\": MAMBA,\n            \"moe\": MOE,\n            \"attention\": ATTENTION,\n            \"mlp\": MLP,\n        }\n        return \"\".join(reverse_mapping[layer_type] for layer_type in layers_list)\n\n    @staticmethod\n    def _pattern_to_list(pattern: str) -> list[str]:\n        \"\"\"Convert pattern string to list of layer types (for backward compatibility).\"\"\"\n        if any(char not in {MAMBA, MOE, ATTENTION, MLP} for char in pattern):\n            raise ValueError(\n                \"Pattern must only contain 
characters 'M', '*', '-' or 'E'. \"\n                f\"Got: {pattern}\"\n            )\n        pattern_mapping = {\n            MAMBA: \"mamba\",\n            MOE: \"moe\",\n            ATTENTION: \"attention\",\n            MLP: \"mlp\",\n        }\n        return [pattern_mapping[char] for char in pattern]\n"
  },
  {
    "path": "python/sglang/srt/configs/olmo3.py",
    "content": "# coding=utf-8\n# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Olmo3 model configuration\"\"\"\n\nimport enum\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\n\nclass Olmo3LayerType(enum.Enum):\n    full_attention = \"full_attention\"\n    sliding_attention = \"sliding_attention\"\n\n\nclass Olmo3Config(PretrainedConfig):\n\n    model_type = \"olmo3\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=50304,\n        hidden_size=4096,\n        intermediate_size=11008,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=None,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        use_cache=True,\n        pad_token_id=1,\n        bos_token_id=None,\n        eos_token_id=50279,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        rms_norm_eps=1e-5,\n        sliding_window=4096,\n        layer_types=None,\n        **kwargs,\n    ):\n        # This model uses Olmo3ForCausalLM in transformers but Olmo2ForCausalLM\n        # in sglang.\n        if \"architectures\" not in kwargs:\n    
        kwargs[\"architectures\"] = [\"Olmo2ForCausalLM\"]\n        elif \"Olmo3ForCausalLM\" in kwargs[\"architectures\"]:\n            kwargs[\"architectures\"].remove(\"Olmo3ForCausalLM\")\n            kwargs[\"architectures\"].append(\"Olmo2ForCausalLM\")\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        self.rms_norm_eps = rms_norm_eps\n\n        self.sliding_window = sliding_window\n        self.layer_types = layer_types\n        if self.layer_types is None:\n            self.layer_types = [\n                \"sliding_attention\" if (i + 1) % 4 != 0 else \"full_attention\"\n                for i in range(self.num_hidden_layers)\n            ]\n"
  },
  {
    "path": "python/sglang/srt/configs/points_v15_chat.py",
    "content": "from typing import Optional, Union\n\nfrom transformers import PretrainedConfig, Qwen2Config\nfrom transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig\n\n\nclass POINTSV15ChatConfig(PretrainedConfig):\n    model_type = \"pointsv1.5_chat\"\n\n    def __init__(\n        self,\n        vision_config: Optional[Union[dict, Qwen2VLVisionConfig]] = None,\n        llm_config: Optional[Union[dict, Qwen2Config]] = None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        if vision_config is None:\n            vision_config = Qwen2VLVisionConfig()\n        elif isinstance(vision_config, dict):\n            vision_config = Qwen2VLVisionConfig(**vision_config)\n        self.vision_config = vision_config\n\n        if llm_config is None:\n            llm_config = Qwen2Config()\n        elif isinstance(llm_config, dict):\n            llm_config = Qwen2Config(**llm_config)\n\n        self.llm_config = llm_config\n        self.hidden_size = self.llm_config.hidden_size\n"
  },
  {
    "path": "python/sglang/srt/configs/qwen3_5.py",
    "content": "from transformers import PretrainedConfig\n\nfrom sglang.srt.configs.qwen3_next import Qwen3NextConfig\nfrom sglang.srt.configs.qwen3_vl import Qwen3VLVisionConfig\n\n\nclass Qwen3_5VisionConfig(Qwen3VLVisionConfig):\n    model_type = \"qwen3_5\"\n    base_config_key = \"vision_config\"\n\n\nclass Qwen3_5TextConfig(Qwen3NextConfig):\n    model_type = \"qwen3_5_text\"\n    base_config_key = \"text_config\"\n\n    def __init__(\n        self,\n        **kwargs,\n    ):\n        # HF Qwen3.5 checkpoints may provide RoPE settings under rope_parameters.\n        # Normalize it before parent init so downstream code sees the expected values.\n        rope_parameters = kwargs.pop(\"rope_parameters\", None)\n        if kwargs.get(\"rope_scaling\") is None and rope_parameters is not None:\n            kwargs[\"rope_scaling\"] = rope_parameters\n\n        super().__init__(**kwargs)\n        if self.rope_scaling is None:\n            self.rope_scaling = rope_parameters or {}\n\n        # Keep both names for compatibility with model code paths that read either.\n        self.rope_parameters = rope_parameters or self.rope_scaling\n\n\nclass Qwen3_5Config(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Qwen3_5Model`]. It is used to instantiate a\n    Qwen3.5 model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of\n    Qwen3.5.\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5TextConfig`):\n            The config object or dictionary of the text backbone.\n        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5VisionConfig`):\n            The config object or dictionary of the vision backbone.\n        image_token_id (`int`, *optional*, defaults to 151655):\n            The image token index to encode the image prompt.\n        video_token_id (`int`, *optional*, defaults to 151656):\n            The video token index to encode the video prompt.\n        vision_start_token_id (`int`, *optional*, defaults to 151652):\n            The start token index to encode the image prompt.\n        vision_end_token_id (`int`, *optional*, defaults to 151653):\n            The end token index to encode the image prompt.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether to tie the word embeddings.\n\n    ```python\n    >>> from transformers import Qwen3_5ForConditionalGeneration, Qwen3_5Config\n\n    >>> # Initializing a Qwen3.5 style configuration\n    >>> configuration = Qwen3_5Config()\n\n    >>> # Initializing a model from the Qwen3.5 style configuration\n    >>> model = Qwen3_5ForConditionalGeneration(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"qwen3_5\"\n    sub_configs = {\n        \"vision_config\": Qwen3_5VisionConfig,\n        \"text_config\": Qwen3_5TextConfig,\n    }\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        text_config=None,\n        vision_config=None,\n        image_token_id=151655,\n        video_token_id=151656,\n        vision_start_token_id=151652,\n        vision_end_token_id=151653,\n        tie_word_embeddings=False,\n        
**kwargs,\n    ):\n        if isinstance(vision_config, dict):\n            self.vision_config = self.sub_configs[\"vision_config\"](**vision_config)\n        elif vision_config is None:\n            self.vision_config = self.sub_configs[\"vision_config\"]()\n        else:\n            self.vision_config = vision_config\n\n        if isinstance(text_config, dict):\n            self.text_config = self.sub_configs[\"text_config\"](**text_config)\n        elif text_config is None:\n            self.text_config = self.sub_configs[\"text_config\"]()\n        else:\n            self.text_config = text_config\n\n        self.image_token_id = image_token_id\n        self.video_token_id = video_token_id\n        self.vision_start_token_id = vision_start_token_id\n        self.vision_end_token_id = vision_end_token_id\n        super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)\n\n\nclass Qwen3_5MoeVisionConfig(Qwen3_5VisionConfig):\n    model_type = \"qwen3_5_moe\"\n\n\nclass Qwen3_5MoeTextConfig(Qwen3_5TextConfig):\n    model_type = \"qwen3_5_moe_text\"\n\n\nclass Qwen3_5MoeConfig(Qwen3_5Config):\n    model_type = \"qwen3_5_moe\"\n    sub_configs = {\n        \"vision_config\": Qwen3_5MoeVisionConfig,\n        \"text_config\": Qwen3_5MoeTextConfig,\n    }\n"
  },
  {
    "path": "python/sglang/srt/configs/qwen3_next.py",
    "content": "# coding=utf-8\n# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Qwen3Hybrid model configuration\"\"\"\n\nimport enum\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nfrom sglang.srt.configs.mamba_utils import (\n    Mamba2CacheParams,\n    Mamba2StateShape,\n    mamba2_state_dtype,\n)\nfrom sglang.srt.configs.update_config import adjust_tp_num_heads_if_necessary\nfrom sglang.srt.utils import is_cpu\n\nlogger = logging.get_logger(__name__)\n_is_cpu = is_cpu()\n\n\nclass HybridLayerType(enum.Enum):\n    full_attention = \"attention\"\n    linear_attention = \"linear_attention\"\n\n\nclass Qwen3NextConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a\n    Qwen3-Next model according to the specified arguments, defining the model architecture.\n    Instantiating a configuration with the defaults will yield a similar configuration to that of\n    Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 151936):\n            Vocabulary size of the model. Defines the number of different tokens that can be represented by the\n            `input_ids`.\n        hidden_size (`int`, *optional*, defaults to 2048):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 5632):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 48):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 16):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        num_key_value_heads (`int`, *optional*, defaults to 2):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by mean-pooling all the original heads within that group. For more details, check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it is not specified, will default to `32`.\n        hidden_act (`str`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 32768):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type\n            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value\n            accordingly.\n            Expected contents:\n                `rope_type` (`str`):\n                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',\n                    'llama3'], with 'default' being the original RoPE implementation.\n                `factor` (`float`, *optional*):\n                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. 
In\n                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *\n                    original maximum pre-trained length.\n                `original_max_position_embeddings` (`int`, *optional*):\n                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during\n                    pretraining.\n                `attention_factor` (`float`, *optional*):\n                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention\n                    computation. If unspecified, it defaults to the value recommended by the implementation, using the\n                    `factor` field to infer the suggested value.\n                `beta_fast` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 32.\n                `beta_slow` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 1.\n                `short_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `long_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>\n                    `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `low_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.\n                `high_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.\n        partial_rotary_factor (`float`, *optional*, defaults to 0.25):\n            Fraction of each query and key head dimension to which rotary embedding is applied.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        head_dim (`int`, *optional*, defaults to 256):\n            Projection weights dimension in multi-head attention.\n        linear_conv_kernel_dim (`int`, *optional*, defaults to 4):\n            Kernel size of the convolution used in linear attention layers.\n        linear_key_head_dim (`int`, *optional*, defaults to 128):\n            Dimension of each key head in linear attention.\n        linear_value_head_dim (`int`, *optional*, defaults to 128):\n            Dimension of each value head in linear attention.\n        linear_num_key_heads (`int`, *optional*, defaults to 16):\n            Number of key heads used in linear attention layers.\n        linear_num_value_heads (`int`, *optional*, defaults to 32):\n            Number of value heads used in linear attention layers.\n        decoder_sparse_step (`int`, *optional*, defaults to 1):\n            The frequency of the MoE layer.\n        moe_intermediate_size (`int`, *optional*, defaults to 512):\n            Intermediate size of the routed expert.\n        
shared_expert_intermediate_size (`int`, *optional*, defaults to 512):\n            Intermediate size of the shared expert.\n        num_experts_per_tok (`int`, *optional*, defaults to 10):\n            Number of experts selected per token.\n        num_experts (`int`, *optional*, defaults to 512):\n            Number of routed experts.\n        norm_topk_prob (`bool`, *optional*, defaults to `True`):\n            Whether to normalize the top-k probabilities.\n        output_router_logits (`bool`, *optional*, defaults to `False`):\n            Whether or not the router logits should be returned by the model. Enabling this will also\n            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.\n        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):\n            The coefficient of the auxiliary loss in the total loss.\n        mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):\n            Indices of the layers that use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock. Each index\n            lies in the range [0, num_hidden_layers - 1].\n            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.\n        layer_types (`list[str]`, *optional*, defaults to None):\n            Type of each layer (full attention or linear attention).\n\n    ```python\n    >>> from transformers import Qwen3NextModel, Qwen3NextConfig\n\n    >>> # Initializing a Qwen3Next style configuration\n    >>> configuration = Qwen3NextConfig()\n\n    >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration\n    >>> model = Qwen3NextModel(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\n    \"\"\"\n\n    model_type = \"qwen3_next\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=151936,\n        hidden_size=2048,\n        intermediate_size=5632,\n        
num_hidden_layers=48,\n        num_attention_heads=16,\n        num_key_value_heads=2,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        partial_rotary_factor=0.25,\n        attention_bias=False,\n        attention_dropout=0.0,\n        head_dim=256,\n        linear_conv_kernel_dim=4,\n        linear_key_head_dim=128,\n        linear_value_head_dim=128,\n        linear_num_key_heads=16,\n        linear_num_value_heads=32,\n        decoder_sparse_step=1,\n        moe_intermediate_size=512,\n        shared_expert_intermediate_size=512,\n        num_experts_per_tok=10,\n        num_experts=512,\n        norm_topk_prob=True,\n        output_router_logits=False,\n        router_aux_loss_coef=0.001,\n        mlp_only_layers=[],\n        layer_types=None,\n        **kwargs,\n    ):\n        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.partial_rotary_factor = partial_rotary_factor\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        self.head_dim = head_dim\n\n        # linear attention parameters (currently the Gated DeltaNet / GDN part)\n        self.linear_conv_kernel_dim = linear_conv_kernel_dim\n        
self.linear_key_head_dim = linear_key_head_dim\n        self.linear_value_head_dim = linear_value_head_dim\n        self.linear_num_key_heads = linear_num_key_heads\n        self.linear_num_value_heads = linear_num_value_heads\n\n        # MoE arguments\n        self.decoder_sparse_step = decoder_sparse_step\n        self.moe_intermediate_size = moe_intermediate_size\n        self.shared_expert_intermediate_size = shared_expert_intermediate_size\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_experts = num_experts\n        self.norm_topk_prob = norm_topk_prob\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = router_aux_loss_coef\n        self.mlp_only_layers = mlp_only_layers\n\n    @property\n    def layers_block_type(self):\n        layer_type_list = []\n\n        for l in range(self.num_hidden_layers):\n            if (l + 1) % self.full_attention_interval == 0:\n                layer_type_list.append(HybridLayerType.full_attention.value)\n            else:\n                layer_type_list.append(HybridLayerType.linear_attention.value)\n\n        return layer_type_list\n\n    @property\n    def linear_layer_ids(self):\n        return [\n            i\n            for i, type_value in enumerate(self.layers_block_type)\n            if type_value == HybridLayerType.linear_attention.value\n        ]\n\n    @property\n    def full_attention_layer_ids(self):\n        return [\n            i\n            for i, type_value in enumerate(self.layers_block_type)\n            if type_value == HybridLayerType.full_attention.value\n        ]\n\n    @property\n    def mamba2_cache_params(self) -> Mamba2CacheParams:\n        from sglang.srt.layers.dp_attention import get_attention_tp_size\n\n        if _is_cpu:\n            world_size = get_attention_tp_size()\n            adjust_tp_num_heads_if_necessary(self, world_size, False)\n\n        shape = Mamba2StateShape.create(\n            
tp_world_size=get_attention_tp_size(),\n            intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,\n            n_groups=self.linear_num_key_heads,\n            num_heads=self.linear_num_value_heads,\n            head_dim=self.linear_value_head_dim,\n            state_size=self.linear_key_head_dim,\n            conv_kernel=self.linear_conv_kernel_dim,\n        )\n\n        return Mamba2CacheParams(\n            shape=shape, layers=self.linear_layer_ids, dtype=mamba2_state_dtype(self)\n        )\n"
  },
  {
    "path": "python/sglang/srt/configs/qwen3_omni.py",
    "content": "from transformers import PretrainedConfig\nfrom transformers.configuration_utils import layer_type_validation\n\nfrom sglang.utils import logger\n\n\nclass Qwen3OmniMoeAudioEncoderConfig(PretrainedConfig):\n    model_type = \"qwen3_omni_moe_audio_encoder\"\n\n    def __init__(\n        self,\n        num_mel_bins=128,\n        encoder_layers=32,\n        encoder_attention_heads=20,\n        encoder_ffn_dim=5120,\n        d_model=1280,\n        dropout=0,\n        attention_dropout=0,\n        activation_function=\"gelu\",\n        activation_dropout=0,\n        scale_embedding=False,\n        initializer_range=0.02,\n        max_source_positions=1500,\n        n_window=100,\n        output_dim=3584,\n        n_window_infer=400,\n        conv_chunksize=500,\n        downsample_hidden_size=480,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        self.num_mel_bins = num_mel_bins\n        self.d_model = d_model\n        self.encoder_layers = encoder_layers\n        self.encoder_attention_heads = encoder_attention_heads\n        self.encoder_ffn_dim = encoder_ffn_dim\n        self.dropout = dropout\n        self.attention_dropout = attention_dropout\n        self.activation_function = activation_function\n        self.activation_dropout = activation_dropout\n        self.num_hidden_layers = encoder_layers\n        self.initializer_range = initializer_range\n        self.scale_embedding = (\n            scale_embedding  # scale factor will be sqrt(d_model) if True\n        )\n        self.max_source_positions = max_source_positions\n        self.n_window = n_window\n        self.output_dim = output_dim\n        self.n_window_infer = n_window_infer\n        self.conv_chunksize = conv_chunksize\n        self.downsample_hidden_size = downsample_hidden_size\n\n\nclass Qwen3OmniMoeVisionEncoderConfig(PretrainedConfig):\n    model_type = \"qwen3_omni_moe_vision_encoder\"\n    base_config_key = \"vision_config\"\n\n    def __init__(\n       
 self,\n        depth=27,\n        hidden_size=1152,\n        hidden_act=\"gelu_pytorch_tanh\",\n        intermediate_size=4304,\n        num_heads=16,\n        in_channels=3,\n        patch_size=16,\n        spatial_merge_size=2,\n        temporal_patch_size=2,\n        out_hidden_size=3584,\n        num_position_embeddings=2304,\n        deepstack_visual_indexes=[8, 16, 24],\n        initializer_range=0.02,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        self.depth = depth\n        self.hidden_size = hidden_size\n        self.hidden_act = hidden_act\n        self.intermediate_size = intermediate_size\n        self.num_heads = num_heads\n        self.in_channels = in_channels\n        self.patch_size = patch_size\n        self.spatial_merge_size = spatial_merge_size\n        self.temporal_patch_size = temporal_patch_size\n        self.out_hidden_size = out_hidden_size\n        self.num_position_embeddings = num_position_embeddings\n        self.initializer_range = initializer_range\n        self.deepstack_visual_indexes = deepstack_visual_indexes\n\n\nclass Qwen3OmniMoeTextConfig(PretrainedConfig):\n    model_type = \"qwen3_omni_moe_text\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    # Default tensor parallel plan for base model `Qwen3OmniMoeText`\n    base_model_tp_plan = {\n        \"layers.*.self_attn.q_proj\": \"colwise\",\n        \"layers.*.self_attn.k_proj\": \"colwise\",\n        \"layers.*.self_attn.v_proj\": \"colwise\",\n        \"layers.*.self_attn.o_proj\": \"rowwise\",\n        \"layers.*.mlp.experts.*.gate_proj\": \"colwise\",\n        \"layers.*.mlp.experts.*.up_proj\": \"colwise\",\n        \"layers.*.mlp.experts.*.down_proj\": \"rowwise\",\n        \"layers.*.mlp.gate_proj\": \"colwise\",\n        \"layers.*.mlp.up_proj\": \"colwise\",\n        \"layers.*.mlp.down_proj\": \"rowwise\",\n    }\n    base_model_pp_plan = {\n        \"embed_tokens\": ([\"input_ids\"], [\"inputs_embeds\"]),\n        
\"layers\": ([\"hidden_states\", \"attention_mask\"], [\"hidden_states\"]),\n        \"norm\": ([\"hidden_states\"], [\"hidden_states\"]),\n    }\n\n    def __init__(\n        self,\n        vocab_size=3584,\n        hidden_size=2048,\n        intermediate_size=18944,\n        num_hidden_layers=28,\n        num_attention_heads=28,\n        num_key_value_heads=4,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=1000000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        sliding_window=None,\n        attention_dropout=0,\n        decoder_sparse_step=1,\n        moe_intermediate_size=768,\n        num_experts_per_tok=8,\n        num_experts=128,\n        norm_topk_prob=True,\n        output_router_logits=False,\n        router_aux_loss_coef=0.001,\n        mlp_only_layers=None,\n        **kwargs,\n    ):\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.sliding_window = sliding_window\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        # Validate the correctness of rotary position embeddings parameters\n        # BC: if there is a 'type' field, move it to 
'rope_type'.\n        if self.rope_scaling is not None and \"type\" in self.rope_scaling:\n            self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n\n        # MoE arguments\n        self.decoder_sparse_step = decoder_sparse_step\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_experts = num_experts\n        self.norm_topk_prob = norm_topk_prob\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = router_aux_loss_coef\n        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers\n\n\nclass Qwen3OmniMoeThinkerConfig(PretrainedConfig):\n    model_type = \"qwen3_omni_moe_thinker\"\n    attribute_map = {\n        \"image_token_id\": \"image_token_index\",\n        \"video_token_id\": \"video_token_index\",\n        \"audio_token_id\": \"audio_token_index\",\n    }\n    sub_configs = {\n        \"audio_config\": Qwen3OmniMoeAudioEncoderConfig,\n        \"vision_config\": Qwen3OmniMoeVisionEncoderConfig,\n        \"text_config\": Qwen3OmniMoeTextConfig,\n    }\n\n    def __init__(\n        self,\n        audio_config=None,\n        vision_config=None,\n        text_config=None,\n        audio_token_id=151646,\n        image_token_id=151655,\n        video_token_id=151656,\n        position_id_per_seconds=25,\n        audio_start_token_id=151647,\n        user_token_id=872,\n        initializer_range=0.02,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.user_token_id = user_token_id\n        self.position_id_per_seconds = position_id_per_seconds\n        self.audio_start_token_id = audio_start_token_id\n        self.initializer_range = initializer_range\n\n        if isinstance(vision_config, dict):\n            vision_config = Qwen3OmniMoeVisionEncoderConfig(**vision_config)\n        elif vision_config is None:\n            vision_config = Qwen3OmniMoeVisionEncoderConfig()\n      
  self.vision_config = vision_config\n\n        if isinstance(audio_config, dict):\n            audio_config = Qwen3OmniMoeAudioEncoderConfig(**audio_config)\n        elif audio_config is None:\n            audio_config = Qwen3OmniMoeAudioEncoderConfig()\n        self.audio_config = audio_config\n\n        if isinstance(text_config, dict):\n            text_config = Qwen3OmniMoeTextConfig(**text_config)\n        elif text_config is None:\n            text_config = Qwen3OmniMoeTextConfig()\n        self.text_config = text_config\n        self.audio_token_id = audio_token_id\n        self.image_token_id = image_token_id\n        self.video_token_id = video_token_id\n\n\nclass Qwen3OmniMoeTalkerCodePredictorConfig(PretrainedConfig):\n\n    model_type = \"qwen3_omni_moe_talker_code_predictor\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    # Default tensor parallel plan for base model `Qwen3OmniMoeTalkerCodePredictor`\n    base_model_tp_plan = {\n        \"layers.*.self_attn.q_proj\": \"colwise\",\n        \"layers.*.self_attn.k_proj\": \"colwise\",\n        \"layers.*.self_attn.v_proj\": \"colwise\",\n        \"layers.*.self_attn.o_proj\": \"rowwise\",\n        \"layers.*.mlp.gate_proj\": \"colwise\",\n        \"layers.*.mlp.up_proj\": \"colwise\",\n        \"layers.*.mlp.down_proj\": \"rowwise\",\n    }\n    base_model_pp_plan = {\n        \"embed_tokens\": ([\"input_ids\"], [\"inputs_embeds\"]),\n        \"layers\": ([\"hidden_states\", \"attention_mask\"], [\"hidden_states\"]),\n        \"norm\": ([\"hidden_states\"], [\"hidden_states\"]),\n    }\n\n    def __init__(\n        self,\n        vocab_size=2048,\n        hidden_size=1024,\n        intermediate_size=3072,\n        num_hidden_layers=5,\n        num_attention_heads=16,\n        num_key_value_heads=8,\n        head_dim=128,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=0.000001,\n        use_cache=True,\n        
tie_word_embeddings=False,\n        rope_theta=10000,\n        rope_scaling=None,\n        attention_bias=False,\n        sliding_window=None,\n        layer_types=None,\n        attention_dropout=0,\n        num_code_groups=32,\n        **kwargs,\n    ):\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.sliding_window = sliding_window\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.head_dim = head_dim\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        # Validate the correctness of rotary position embeddings parameters\n        # BC: if there is a 'type' field, move it to 'rope_type'.\n        if self.rope_scaling is not None and \"type\" in self.rope_scaling:\n            self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n\n        self.layer_types = layer_types\n        if self.layer_types is None:\n            self.layer_types = [\n                (\n                    \"sliding_attention\"\n                    if self.sliding_window is not None and i >= self.max_window_layers\n                    else \"full_attention\"\n                )\n                for i in range(self.num_hidden_layers)\n            ]\n        
layer_type_validation(self.layer_types, self.num_hidden_layers)\n        self.num_code_groups = num_code_groups\n\n\nclass Qwen3OmniMoeTalkerTextConfig(PretrainedConfig):\n\n    model_type = \"qwen3_omni_moe_talker_text\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    # Default tensor parallel plan for base model `Qwen3OmniMoeTalkerText`\n    base_model_tp_plan = {\n        \"layers.*.self_attn.q_proj\": \"colwise\",\n        \"layers.*.self_attn.k_proj\": \"colwise\",\n        \"layers.*.self_attn.v_proj\": \"colwise\",\n        \"layers.*.self_attn.o_proj\": \"rowwise\",\n        \"layers.*.mlp.experts.*.gate_proj\": \"colwise\",\n        \"layers.*.mlp.experts.*.up_proj\": \"colwise\",\n        \"layers.*.mlp.experts.*.down_proj\": \"rowwise\",\n        \"layers.*.mlp.gate_proj\": \"colwise\",\n        \"layers.*.mlp.up_proj\": \"colwise\",\n        \"layers.*.mlp.down_proj\": \"rowwise\",\n    }\n    base_model_pp_plan = {\n        \"embed_tokens\": ([\"input_ids\"], [\"inputs_embeds\"]),\n        \"layers\": ([\"hidden_states\", \"attention_mask\"], [\"hidden_states\"]),\n        \"norm\": ([\"hidden_states\"], [\"hidden_states\"]),\n    }\n\n    def __init__(\n        self,\n        vocab_size=3072,\n        hidden_size=1024,\n        intermediate_size=2048,\n        num_hidden_layers=20,\n        num_attention_heads=16,\n        num_key_value_heads=2,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=0.000001,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=10000,\n        rope_scaling=None,\n        attention_bias=False,\n        sliding_window=None,\n        attention_dropout=0,\n        decoder_sparse_step=1,\n        moe_intermediate_size=384,\n        num_experts_per_tok=8,\n        num_experts=128,\n        norm_topk_prob=False,\n        output_router_logits=False,\n        router_aux_loss_coef=0.001,\n        
mlp_only_layers=None,\n        **kwargs,\n    ):\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.sliding_window = sliding_window\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        # Validate the correctness of rotary position embeddings parameters\n        # BC: if there is a 'type' field, move it to 'rope_type'.\n        if self.rope_scaling is not None and \"type\" in self.rope_scaling:\n            self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n\n        # MoE arguments\n        self.decoder_sparse_step = decoder_sparse_step\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_experts = num_experts\n        self.norm_topk_prob = norm_topk_prob\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = router_aux_loss_coef\n        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers\n\n\nclass Qwen3OmniMoeTalkerConfig(PretrainedConfig):\n\n    sub_configs = {\n        \"code_predictor_config\": Qwen3OmniMoeTalkerCodePredictorConfig,\n        \"text_config\": Qwen3OmniMoeTalkerTextConfig,\n    }\n\n    def __init__(\n        self,\n        code_predictor_config=None,\n        
text_config=None,\n        num_code_groups=32,\n        thinker_hidden_size=2048,\n        codec_eos_token_id=4198,\n        accept_hidden_layer=18,\n        codec_nothink_id=4203,\n        codec_think_bos_id=4204,\n        codec_think_eos_id=4205,\n        codec_pad_id=4196,\n        codec_bos_id=4197,\n        audio_token_id=151646,\n        image_token_id=151655,\n        video_token_id=151656,\n        vision_start_token_id=151652,\n        position_id_per_seconds=25,\n        audio_start_token_id=151669,\n        speaker_id=None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        if code_predictor_config is None:\n            code_predictor_config = {}\n            self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig()\n            logger.info(\n                \"code_predictor_config is None. Initializing code_predictor_config model with default values\"\n            )\n        elif isinstance(code_predictor_config, Qwen3OmniMoeTalkerCodePredictorConfig):\n            self.code_predictor_config = code_predictor_config\n        else:\n            self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig(\n                **code_predictor_config\n            )\n\n        if text_config is None:\n            text_config = {}\n            self.text_config = Qwen3OmniMoeTalkerTextConfig()\n            logger.info(\n                \"talker text_config is None. 
Initializing talker text model with default values\"\n            )\n        elif isinstance(text_config, Qwen3OmniMoeTalkerTextConfig):\n            self.text_config = text_config\n        else:\n            self.text_config = Qwen3OmniMoeTalkerTextConfig(**text_config)\n        self.num_code_groups = num_code_groups\n        self.thinker_hidden_size = thinker_hidden_size\n        self.codec_eos_token_id = codec_eos_token_id\n        self.accept_hidden_layer = accept_hidden_layer\n        self.codec_nothink_id = codec_nothink_id\n        self.codec_think_bos_id = codec_think_bos_id\n        self.codec_think_eos_id = codec_think_eos_id\n        self.codec_pad_id = codec_pad_id\n        self.codec_bos_id = codec_bos_id\n        self.audio_token_id = audio_token_id\n        self.image_token_id = image_token_id\n        self.video_token_id = video_token_id\n        self.position_id_per_seconds = position_id_per_seconds\n        self.audio_start_token_id = audio_start_token_id\n        self.vision_start_token_id = vision_start_token_id\n        self.speaker_id = speaker_id\n\n\nclass Qwen3OmniMoeCode2WavConfig(PretrainedConfig):\n\n    def __init__(\n        self,\n        codebook_size=2048,\n        hidden_size=1024,\n        max_position_embeddings=8000,\n        rope_theta=10000,\n        num_attention_heads=16,\n        num_key_value_heads=16,\n        attention_bias=False,\n        sliding_window=72,\n        intermediate_size=3072,\n        hidden_act=\"silu\",\n        layer_scale_initial_scale=0.01,\n        rms_norm_eps=1e-5,\n        num_hidden_layers=8,\n        num_quantizers=16,\n        upsample_rates=(8, 5, 4, 3),\n        upsampling_ratios=(2, 2),\n        decoder_dim=1536,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.codebook_size = codebook_size\n        self.hidden_size = hidden_size\n        self.max_position_embeddings = max_position_embeddings\n        self.rope_theta = rope_theta\n  
      self.num_attention_heads = num_attention_heads\n        self.num_key_value_heads = num_key_value_heads\n        self.attention_bias = attention_bias\n        self.sliding_window = sliding_window\n        self.intermediate_size = intermediate_size\n        self.hidden_act = hidden_act\n        self.layer_scale_initial_scale = layer_scale_initial_scale\n        self.rms_norm_eps = rms_norm_eps\n        self.num_hidden_layers = num_hidden_layers\n        self.num_quantizers = num_quantizers\n        self.upsample_rates = upsample_rates\n        self.upsampling_ratios = upsampling_ratios\n        self.decoder_dim = decoder_dim\n        self.attention_dropout = attention_dropout\n\n    @property\n    def layer_types(self):\n        \"\"\"\n        All layers in code2wav use sliding attention.\n        \"\"\"\n        return [\"sliding_attention\"] * self.num_hidden_layers\n\n\nclass Qwen3OmniMoeConfig(PretrainedConfig):\n\n    model_type = \"qwen3_omni_moe\"\n    sub_configs = {\n        \"thinker_config\": Qwen3OmniMoeThinkerConfig,\n        \"talker_config\": Qwen3OmniMoeTalkerConfig,\n        \"code2wav_config\": Qwen3OmniMoeCode2WavConfig,\n    }\n\n    def __init__(\n        self,\n        thinker_config=None,\n        talker_config=None,\n        code2wav_config=None,\n        enable_audio_output=True,\n        im_start_token_id=151644,\n        im_end_token_id=151645,\n        tts_pad_token_id=151671,\n        tts_bos_token_id=151672,\n        tts_eos_token_id=151673,\n        system_token_id=8948,\n        user_token_id=872,\n        assistant_token_id=77091,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        if thinker_config is None:\n            thinker_config = {}\n            logger.info(\n                \"thinker_config is None. 
Initializing thinker model with default values\"\n            )\n\n        if talker_config is None:\n            talker_config = {}\n            logger.info(\n                \"talker_config is None. Initializing talker model with default values\"\n            )\n\n        if code2wav_config is None:\n            code2wav_config = {}\n            logger.info(\n                \"code2wav_config is None. Initializing code2wav model with default values\"\n            )\n\n        self.thinker_config = Qwen3OmniMoeThinkerConfig(**thinker_config)\n        self.talker_config = Qwen3OmniMoeTalkerConfig(**talker_config)\n        self.code2wav_config = Qwen3OmniMoeCode2WavConfig(**code2wav_config)\n        self.enable_audio_output = enable_audio_output\n        self.im_start_token_id = im_start_token_id\n        self.im_end_token_id = im_end_token_id\n        self.tts_pad_token_id = tts_pad_token_id\n        self.tts_bos_token_id = tts_bos_token_id\n        self.tts_eos_token_id = tts_eos_token_id\n        self.system_token_id = system_token_id\n        self.user_token_id = user_token_id\n        self.assistant_token_id = assistant_token_id\n\n    def get_text_config(self, decoder=False) -> \"PretrainedConfig\":\n        \"\"\"\n        Returns the config that is meant to be used with text IO. On most models, it is the original config instance\n        itself. On specific composite models, it is under a set of valid names.\n\n        Args:\n            decoder (`Optional[bool]`, *optional*, defaults to `False`):\n                If set to `True`, then only search for decoder config names.\n        \"\"\"\n        # Overridden for deeply nested configs like Qwen2-Omni. We don't have any omni model\n        # except for Qwen yet. This has to be generalized if more deeply nested configs are\n        # added. NOTE: this method is currently used only by vLLM.\n        return self.thinker_config.get_text_config()\n"
  },
  {
    "path": "python/sglang/srt/configs/qwen3_vl.py",
    "content": "from transformers import PretrainedConfig\n\n\nclass Qwen3VLVisionConfig(PretrainedConfig):\n    model_type = \"qwen3_vl\"\n    base_config_key = \"vision_config\"\n\n    def __init__(\n        self,\n        depth=27,\n        hidden_size=1152,\n        hidden_act=\"gelu_pytorch_tanh\",\n        intermediate_size=4304,\n        num_heads=16,\n        in_channels=3,\n        patch_size=16,\n        spatial_merge_size=2,\n        temporal_patch_size=2,\n        out_hidden_size=3584,\n        num_position_embeddings=2304,\n        deepstack_visual_indexes=[8, 16, 24],\n        initializer_range=0.02,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        self.depth = depth\n        self.hidden_size = hidden_size\n        self.hidden_act = hidden_act\n        self.intermediate_size = intermediate_size\n        self.num_heads = num_heads\n        self.in_channels = in_channels\n        self.patch_size = patch_size\n        self.spatial_merge_size = spatial_merge_size\n        self.temporal_patch_size = temporal_patch_size\n        self.out_hidden_size = out_hidden_size\n        self.num_position_embeddings = num_position_embeddings\n        self.initializer_range = initializer_range\n        self.deepstack_visual_indexes = deepstack_visual_indexes\n\n\nclass Qwen3VLTextConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a\n    Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of\n    Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
 Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 151936):\n            Vocabulary size of the Qwen3VL model. Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`Qwen3VLModel`].\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 22016):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        num_key_value_heads (`int`, *optional*, defaults to 32):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if\n            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by mean-pooling all the original heads within that group. For more details, check out [this\n            paper](https://huggingface.co/papers/2305.13245). If set to `None`, it defaults to `num_attention_heads`.\n        head_dim (`int`, *optional*, defaults to 128):\n            The dimension of the head.
\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 128000):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied.\n        rope_theta (`float`, *optional*, defaults to 5000000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type\n            and you expect the model to work on longer `max_position_embeddings`, we recommend updating this value\n            accordingly.\n            Expected contents:\n                `rope_type` (`str`):\n                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',\n                    'llama3'], with 'default' being the original RoPE implementation.\n                `factor` (`float`, *optional*):\n                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings.
 In\n                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *\n                    original maximum pre-trained length.\n                `original_max_position_embeddings` (`int`, *optional*):\n                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during\n                    pretraining.\n                `attention_factor` (`float`, *optional*):\n                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention\n                    computation. If unspecified, it defaults to the value recommended by the implementation, using the\n                    `factor` field to infer the suggested value.\n                `beta_fast` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 32.\n                `beta_slow` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 1.\n                `short_factor` (`list[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `long_factor` (`list[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>\n                    `original_max_position_embeddings`).
 Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `low_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.\n                `high_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n\n    ```python\n    >>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig\n\n    >>> # Initializing a Qwen3VL style configuration\n    >>> configuration = Qwen3VLTextConfig()\n\n    >>> # Initializing a model from the Qwen3-VL-7B style configuration\n    >>> model = Qwen3VLTextModel(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"qwen3_vl_text\"\n    base_config_key = \"text_config\"\n\n    def __init__(\n        self,\n        vocab_size=151936,\n        hidden_size=4096,\n        intermediate_size=22016,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        head_dim=128,\n        hidden_act=\"silu\",\n        max_position_embeddings=128000,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=5000000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.head_dim = head_dim\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)\n\n\nclass Qwen3VLConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a\n    Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of\n    Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
 Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n    Args:\n        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`):\n            The config object or dictionary of the text backbone.\n        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`):\n            The config object or dictionary of the vision backbone.\n        image_token_id (`int`, *optional*, defaults to 151655):\n            The image token index to encode the image prompt.\n        video_token_id (`int`, *optional*, defaults to 151656):\n            The video token index to encode the video prompt.\n        vision_start_token_id (`int`, *optional*, defaults to 151652):\n            The token index that marks the start of a vision (image or video) segment.\n        vision_end_token_id (`int`, *optional*, defaults to 151653):\n            The token index that marks the end of a vision (image or video) segment.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether to tie the word embeddings.\n\n    ```python\n    >>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig\n\n    >>> # Initializing a Qwen3-VL style configuration\n    >>> configuration = Qwen3VLConfig()\n\n    >>> # Initializing a model from the Qwen3-VL-4B style configuration\n    >>> model = Qwen3VLForConditionalGeneration(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"qwen3_vl\"\n    sub_configs = {\n        \"vision_config\": Qwen3VLVisionConfig,\n        \"text_config\": Qwen3VLTextConfig,\n    }\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        text_config=None,\n        vision_config=None,\n        image_token_id=151655,\n        video_token_id=151656,\n        vision_start_token_id=151652,\n        vision_end_token_id=151653,\n        tie_word_embeddings=False,\n        **kwargs,\n    ):\n        if isinstance(vision_config, dict):\n            self.vision_config = self.sub_configs[\"vision_config\"](**vision_config)\n        elif vision_config is None:\n            self.vision_config = self.sub_configs[\"vision_config\"]()\n\n        if isinstance(text_config, dict):\n            self.text_config = self.sub_configs[\"text_config\"](**text_config)\n        elif text_config is None:\n            self.text_config = self.sub_configs[\"text_config\"]()\n\n        self.image_token_id = image_token_id\n        self.video_token_id = video_token_id\n        self.vision_start_token_id = vision_start_token_id\n        self.vision_end_token_id = vision_end_token_id\n        super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)\n\n\nclass Qwen3VLMoeTextConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a\n    Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of\n    Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 151936):\n            Vocabulary size of the Qwen3VLMoe model.
 Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`Qwen3VLMoeTextModel`].\n        hidden_size (`int`, *optional*, defaults to 2048):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 5632):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 24):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 16):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        num_key_value_heads (`int`, *optional*, defaults to 16):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if\n            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by mean-pooling all the original heads within that group. For more details, check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf).
 If set to `None`, it defaults to `num_attention_heads`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 128000):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied.\n        rope_theta (`float`, *optional*, defaults to 5000000.0):\n            The base period of the RoPE embeddings.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        decoder_sparse_step (`int`, *optional*, defaults to 1):\n            The frequency of the MoE layer.\n        moe_intermediate_size (`int`, *optional*, defaults to 1408):\n            Intermediate size of the routed expert.\n        num_experts_per_tok (`int`, *optional*, defaults to 4):\n            Number of selected experts.\n        num_experts (`int`, *optional*, defaults to 60):\n            Number of routed experts.\n        norm_topk_prob (`bool`, *optional*, defaults to `True`):\n            Whether to normalize the topk probabilities.\n        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):\n            Indicates which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock.\n            The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.\n            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type\n            and you expect the model to work on longer `max_position_embeddings`, we recommend updating this value\n            accordingly.\n            Expected contents:\n                `rope_type` (`str`):\n                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',\n                    'llama3'], with 'default' being the original RoPE implementation.\n                `factor` (`float`, *optional*):\n                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In\n                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *\n                    original maximum pre-trained length.\n                `original_max_position_embeddings` (`int`, *optional*):\n                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during\n                    pretraining.\n                `attention_factor` (`float`, *optional*):\n                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention\n                    computation.
 If unspecified, it defaults to the value recommended by the implementation, using the\n                    `factor` field to infer the suggested value.\n                `beta_fast` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 32.\n                `beta_slow` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 1.\n                `short_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `long_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `low_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.\n                `high_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.\n        head_dim (`int`, *optional*):\n            The dimension of the head.
If not specified, will default to `hidden_size // num_attention_heads`.\n\n    ```python\n    >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig\n\n    >>> # Initializing a Qwen3VLMoe style configuration\n    >>> configuration = Qwen3VLMoeConfig()\n\n    >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration\n    >>> model = Qwen3VLMoeForConditionalGeneration(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"qwen3_vl_moe_text\"\n    base_config_key = \"text_config\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n    # Default tensor parallel plan for base model `Qwen3VLMoe`\n    base_model_tp_plan = {\n        \"layers.*.self_attn.q_proj\": \"colwise\",\n        \"layers.*.self_attn.k_proj\": \"colwise\",\n        \"layers.*.self_attn.v_proj\": \"colwise\",\n        \"layers.*.self_attn.o_proj\": \"rowwise\",\n        \"layers.*.mlp.gate_proj\": \"colwise\",\n        \"layers.*.mlp.up_proj\": \"colwise\",\n        \"layers.*.mlp.down_proj\": \"rowwise\",\n    }\n    base_model_pp_plan = {\n        \"embed_tokens\": ([\"input_ids\"], [\"inputs_embeds\"]),\n        \"layers\": ([\"hidden_states\", \"attention_mask\"], [\"hidden_states\"]),\n        \"norm\": ([\"hidden_states\"], [\"hidden_states\"]),\n    }\n\n    def __init__(\n        self,\n        vocab_size=151936,\n        hidden_size=2048,\n        intermediate_size=5632,\n        num_hidden_layers=24,\n        num_attention_heads=16,\n        num_key_value_heads=16,\n        hidden_act=\"silu\",\n        max_position_embeddings=128000,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=5000000.0,\n        attention_bias=False,\n        attention_dropout=0.0,\n        decoder_sparse_step=1,\n        moe_intermediate_size=1408,\n        num_experts_per_tok=4,\n        
num_experts=60,\n        norm_topk_prob=True,\n        mlp_only_layers=None,\n        rope_scaling=None,\n        head_dim=None,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        self.rope_scaling = rope_scaling\n        self.head_dim = head_dim or hidden_size // num_attention_heads\n\n        # MoE arguments\n        self.decoder_sparse_step = decoder_sparse_step\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_experts = num_experts\n        self.norm_topk_prob = norm_topk_prob\n        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers\n\n        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)\n\n\nclass Qwen3VLMoeVisionConfig(PretrainedConfig):\n    model_type = \"qwen3_vl_moe\"\n    base_config_key = \"vision_config\"\n\n    def __init__(\n        self,\n        depth=27,\n        hidden_size=1152,\n        hidden_act=\"gelu_pytorch_tanh\",\n        intermediate_size=4304,\n        num_heads=16,\n        in_channels=3,\n        patch_size=16,\n        spatial_merge_size=2,\n        temporal_patch_size=2,\n        out_hidden_size=3584,\n        
num_position_embeddings=2304,\n        deepstack_visual_indexes=[8, 16, 24],\n        initializer_range=0.02,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        self.depth = depth\n        self.hidden_size = hidden_size\n        self.hidden_act = hidden_act\n        self.intermediate_size = intermediate_size\n        self.num_heads = num_heads\n        self.in_channels = in_channels\n        self.patch_size = patch_size\n        self.spatial_merge_size = spatial_merge_size\n        self.temporal_patch_size = temporal_patch_size\n        self.out_hidden_size = out_hidden_size\n        self.num_position_embeddings = num_position_embeddings\n        self.initializer_range = initializer_range\n        self.deepstack_visual_indexes = deepstack_visual_indexes\n\n\nclass Qwen3VLMoeConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a\n    Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of\n    Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
 Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n    Args:\n        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`):\n            The config object or dictionary of the text backbone.\n        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`):\n            The config object or dictionary of the vision backbone.\n        image_token_id (`int`, *optional*, defaults to 151655):\n            The image token index to encode the image prompt.\n        video_token_id (`int`, *optional*, defaults to 151656):\n            The video token index to encode the video prompt.\n        vision_start_token_id (`int`, *optional*, defaults to 151652):\n            The token index that marks the start of a vision (image or video) segment.\n        vision_end_token_id (`int`, *optional*, defaults to 151653):\n            The token index that marks the end of a vision (image or video) segment.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether to tie the word embeddings.\n\n    ```python\n    >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig\n\n    >>> # Initializing a Qwen3-VL-MOE style configuration\n    >>> configuration = Qwen3VLMoeConfig()\n\n    >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration\n    >>> model = Qwen3VLMoeForConditionalGeneration(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"qwen3_vl_moe\"\n    sub_configs = {\n        \"vision_config\": Qwen3VLMoeVisionConfig,\n        \"text_config\": Qwen3VLMoeTextConfig,\n    }\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        text_config=None,\n        vision_config=None,\n        image_token_id=151655,\n        video_token_id=151656,\n        vision_start_token_id=151652,\n        vision_end_token_id=151653,\n        tie_word_embeddings=False,\n        **kwargs,\n    ):\n        if isinstance(vision_config, dict):\n            self.vision_config = self.sub_configs[\"vision_config\"](**vision_config)\n        elif vision_config is None:\n            self.vision_config = self.sub_configs[\"vision_config\"]()\n\n        if isinstance(text_config, dict):\n            self.text_config = self.sub_configs[\"text_config\"](**text_config)\n        elif text_config is None:\n            self.text_config = self.sub_configs[\"text_config\"]()\n\n        self.image_token_id = image_token_id\n        self.video_token_id = video_token_id\n        self.vision_start_token_id = vision_start_token_id\n        self.vision_end_token_id = vision_end_token_id\n        super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)\n"
  },
  {
    "path": "python/sglang/srt/configs/radio.py",
    "content": "# Copyright 2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/radio.py\n\n\"\"\"Radio vision model configuration\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nVIT_TIMM_DIM_BY_NAME: dict[str, tuple[int, int, int, int]] = {\n    \"vit_small_patch16_224\": (384, 12, 6, 1536),\n    \"vit_base_patch16_224\": (768, 12, 12, 3072),\n    \"vit_large_patch16_224\": (1024, 24, 16, 4096),\n    \"vit_huge_patch16_224\": (1280, 32, 16, 5120),\n}\n\nOPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)\nOPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)\n\n\nclass RadioConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a Radio\n    vision model. It is used to instantiate a Radio model according to the\n    specified arguments, defining the model architecture.\n\n    Args:\n        model_name: Name of the vision transformer model\n            (e.g., \"vit_base_patch16_224\"). 
 Used to determine architecture\n            dimensions from `VIT_TIMM_DIM_BY_NAME`.\n        image_size: The size (resolution) of each image.\n        patch_size: The size (resolution) of each patch.\n        qkv_bias: Whether to add a bias to the queries, keys and values.\n        qk_normalization: Whether to apply normalization to queries and keys.\n        norm_type: The normalization type to use.\n        layer_norm_eps: The epsilon used by the layer normalization layers.\n        initializer_factor: A factor for initializing all weight matrices.\n        hidden_act: The non-linear activation function in the encoder.\n        max_img_size: Maximum image size for position embeddings.\n        norm_mean: Mean values for image normalization (RGB channels).\n            Defaults to (0.48145466, 0.4578275, 0.40821073).\n        norm_std: Standard deviation values for image normalization\n            (RGB channels). Defaults to (0.26862954, 0.26130258, 0.27577711).\n        reg_tokens: Number of register tokens to use.\n    \"\"\"\n\n    model_type = \"radio\"\n\n    def __init__(\n        self,\n        model_name: str,\n        image_size: int = 224,\n        patch_size: int = 16,\n        qkv_bias: bool = True,\n        qk_normalization: bool = False,\n        norm_type: str = \"layer_norm\",\n        layer_norm_eps: float = 1e-6,\n        initializer_factor: float = 1.0,\n        hidden_act: str = \"gelu\",\n        max_img_size: int = 2048,\n        norm_mean: tuple[float, float, float] | list = OPENAI_CLIP_MEAN,\n        norm_std: tuple[float, float, float] | list = OPENAI_CLIP_STD,\n        reg_tokens: int | None = None,\n        drop_path_rate: float = 0.0,\n        dropout: float = 0.0,\n        **kwargs,\n    ):\n        self.model_name = model_name\n        (\n            self.hidden_size,\n            self.num_hidden_layers,\n            self.num_attention_heads,\n            self.intermediate_size,\n        ) = VIT_TIMM_DIM_BY_NAME[model_name]\n        self.image_size = image_size\n        self.patch_size = patch_size\n        self.qkv_bias = qkv_bias\n        self.qk_normalization = qk_normalization\n        self.norm_type = norm_type\n        self.layer_norm_eps = layer_norm_eps\n        self.initializer_factor = initializer_factor\n        self.hidden_act = hidden_act\n        self.max_img_size = max_img_size\n        self.norm_mean = (\n            list(norm_mean) if isinstance(norm_mean, (tuple, list)) else norm_mean\n        )\n        self.norm_std = (\n            list(norm_std) if isinstance(norm_std, (tuple, list)) else norm_std\n        )\n        self.reg_tokens = reg_tokens\n        self.drop_path_rate = drop_path_rate\n        self.dropout = dropout\n        super().__init__(**kwargs)\n"
  },
  {
    "path": "python/sglang/srt/configs/step3_vl.py",
    "content": "from typing import Any, Optional, Union\n\nfrom transformers.configuration_utils import PretrainedConfig\n\n\nclass Step3VisionEncoderConfig(PretrainedConfig):\n    model_type = \"step3_vision_encoder\"\n\n    def __init__(\n        self,\n        hidden_size=1792,\n        intermediate_size=3072,\n        output_hidden_size=4096,\n        num_hidden_layers=63,\n        num_attention_heads=16,\n        num_channels=3,\n        image_size=728,\n        patch_size=14,\n        hidden_act=\"quick_gelu\",\n        layer_norm_eps=1e-5,\n        **kwargs,\n    ):\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.output_hidden_size = output_hidden_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.num_channels = num_channels\n        self.patch_size = patch_size\n        self.image_size = image_size\n        self.layer_norm_eps = layer_norm_eps\n        self.hidden_act = hidden_act\n        super().__init__(**kwargs)\n\n\nclass Step3TextConfig(PretrainedConfig):\n    model_type = \"step3_text\"\n    architectures = [\"Step3TextForCausalLM\"]\n\n    def __init__(\n        self,\n        hidden_size: int = 7168,\n        intermediate_size: int = 18432,\n        num_attention_heads: int = 64,\n        num_attention_groups: int = 1,\n        num_hidden_layers: int = 61,\n        max_seq_len: int = 65536,\n        vocab_size: int = 128815,\n        rms_norm_eps: float = 1e-5,\n        moe_intermediate_size: int = 5120,\n        moe_num_experts: int = 48,\n        moe_top_k: int = 3,\n        rope_theta: float = 500000,\n        rope_scaling: Optional[dict[str, Any]] = None,\n        max_position_embedding: int = 65536,\n        share_expert_dim: int = 5120,\n        share_q_dim: int = 2048,\n        head_dim: int = 256,\n        norm_expert_weight: bool = False,\n        moe_layers_enum: tuple[int] = (\n            4,\n            
5,\n            6,\n            7,\n            8,\n            9,\n            10,\n            11,\n            12,\n            13,\n            14,\n            15,\n            16,\n            17,\n            18,\n            19,\n            20,\n            21,\n            22,\n            23,\n            24,\n            25,\n            26,\n            27,\n            28,\n            29,\n            30,\n            31,\n            32,\n            33,\n            34,\n            35,\n            36,\n            37,\n            38,\n            39,\n            40,\n            41,\n            42,\n            43,\n            44,\n            45,\n            46,\n            47,\n            48,\n            49,\n            50,\n            51,\n            52,\n            53,\n            54,\n            55,\n            56,\n            57,\n            58,\n            59,\n        ),\n        **kwargs,\n    ) -> None:\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_attention_heads = num_attention_heads\n        self.num_attention_groups = num_attention_groups\n        self.num_hidden_layers = num_hidden_layers\n        self.max_seq_len = max_seq_len\n        self.vocab_size = vocab_size\n        self.rms_norm_eps = rms_norm_eps\n        self.moe_intermediate_size = moe_intermediate_size\n        self.moe_num_experts = moe_num_experts\n        self.moe_top_k = moe_top_k\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.max_position_embedding = max_position_embedding\n        self.share_expert_dim = share_expert_dim\n        self.share_q_dim = share_q_dim\n        self.head_dim = head_dim\n        self.norm_expert_weight = norm_expert_weight\n        self.moe_layers_enum = moe_layers_enum\n\n        super().__init__(**kwargs)\n\n\nclass Step3VLConfig(PretrainedConfig):\n    model_type = \"step3_vl\"\n\n    def __init__(\n        
self,\n        vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,\n        text_config: Optional[Union[dict, Step3TextConfig]] = None,\n        understand_projector_stride: int = 1,\n        projector_bias: bool = True,\n        image_token_id: int = 128001,\n        **kwargs,\n    ) -> None:\n        if vision_config is None:\n            vision_config = Step3VisionEncoderConfig()\n        elif isinstance(vision_config, dict):\n            vision_config = Step3VisionEncoderConfig(**vision_config)\n        self.vision_config = vision_config\n\n        if text_config is None:\n            text_config = Step3TextConfig()\n        elif isinstance(text_config, dict):\n            text_config = Step3TextConfig(**text_config)\n        self.text_config = text_config\n\n        self.understand_projector_stride = understand_projector_stride\n        self.projector_bias = projector_bias\n        self.hidden_size = text_config.hidden_size\n        self.image_token_id = image_token_id\n\n        super().__init__(**kwargs)\n"
  },
  {
    "path": "python/sglang/srt/configs/step3p5.py",
    "content": "from typing import Any, Optional\n\nfrom transformers.configuration_utils import PretrainedConfig\n\n\nclass Step3p5Config(PretrainedConfig):\n    model_type = \"step3p5\"\n    architectures = [\"Step3p5ForCausalLM\"]\n\n    def __init__(\n        self,\n        hidden_size: int = 4096,\n        intermediate_size: int = 11264,\n        num_attention_heads: int = 64,\n        num_attention_groups: int = 8,\n        num_hidden_layers: int = 45,\n        max_seq_len: int = 128000,\n        vocab_size: int = 128815,\n        rms_norm_eps: float = 1e-5,\n        moe_intermediate_size: int = 1280,\n        moe_num_experts: int = 288,\n        moe_top_k: int = 8,\n        rope_theta: float = 10000,\n        rope_scaling: Optional[dict[str, Any]] = None,\n        max_position_embeddings: int = 128000,\n        share_expert_dims: int = 1280,\n        head_dim: int = 128,\n        norm_expert_weight: bool = True,\n        layer_types: list[str] = None,\n        sliding_window: Optional[int] = None,\n        moe_layers_enum: tuple[int] = (\n            3,\n            4,\n            5,\n            6,\n            7,\n            8,\n            9,\n            10,\n            11,\n            12,\n            13,\n            14,\n            15,\n            16,\n            17,\n            18,\n            19,\n            20,\n            21,\n            22,\n            23,\n            24,\n            25,\n            26,\n            27,\n            28,\n            29,\n            30,\n            31,\n            32,\n            33,\n            34,\n            35,\n            36,\n            37,\n            38,\n            39,\n            40,\n            41,\n            42,\n            43,\n            44,\n        ),\n        **kwargs,\n    ) -> None:\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_attention_heads = num_attention_heads\n        self.num_attention_groups = 
num_attention_groups\n        self.num_hidden_layers = num_hidden_layers\n        self.max_seq_len = max_seq_len\n        self.vocab_size = vocab_size\n        self.rms_norm_eps = rms_norm_eps\n        self.moe_intermediate_size = moe_intermediate_size\n        self.moe_num_experts = moe_num_experts\n        self.moe_top_k = moe_top_k\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.max_position_embeddings = max_position_embeddings\n        self.share_expert_dim = share_expert_dims\n        self.head_dim = head_dim\n        self.norm_expert_weight = norm_expert_weight\n        self.moe_layers_enum = moe_layers_enum\n        self.layer_types = layer_types\n        self.sliding_window = sliding_window\n        super().__init__(**kwargs)\n"
  },
  {
    "path": "python/sglang/srt/configs/update_config.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nDEFAULT_MOE_PADDING_SIZE = 32\n\n\nif TYPE_CHECKING:\n    from sglang.srt.configs.load_config import LoadConfig\n    from sglang.srt.configs.model_config import ModelConfig\n\n\ndef may_get_weight_block_size(model_config, load_config):\n    from sglang.srt.model_loader.loader import _get_quantization_config\n\n    quant_config = _get_quantization_config(model_config, load_config)\n\n    if quant_config is not None and hasattr(quant_config, \"weight_block_size\"):\n        return getattr(quant_config, \"weight_block_size\")\n    return None\n\n\ndef get_moe_padding_size(weight_block_size):\n    if weight_block_size is not None:\n        # See NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.\n        assert (\n            len(weight_block_size) == 2\n        ), \"Only len(weight_block_size) == 2 is supported\"\n        assert (\n            weight_block_size[0] == weight_block_size[1]\n        ), \"Only weight_block_size[0] == weight_block_size[1] is supported\"\n\n        return weight_block_size[0]\n\n    return DEFAULT_MOE_PADDING_SIZE\n\n\ndef get_num_heads_padding_size(tp_size, weight_block_size, head_dim):\n    pad_size = tp_size\n\n    if weight_block_size is not None and head_dim % weight_block_size[0] != 0:\n        import math\n\n        pad_size = tp_size * (\n            math.lcm(head_dim, weight_block_size[0]) // weight_block_size[0]\n        )\n\n    return pad_size\n\n\ndef adjust_tp_num_heads_if_necessary(model_config, tp_size, is_post_update):\n    # is_post_update: whether to update an existing config\n    from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size\n\n    # Linear attn check logic\n    if hasattr(model_config, \"linear_num_key_heads\") and hasattr(\n        model_config, \"linear_num_value_heads\"\n    ):\n    
    if (\n            model_config.linear_num_key_heads % tp_size != 0\n            or model_config.linear_num_value_heads % tp_size != 0\n        ):\n            pad_size = tp_size\n            linear_num_key_heads_cpu = pad_vocab_size(\n                model_config.linear_num_key_heads, pad_size\n            )\n            linear_num_value_heads_cpu = (\n                linear_num_key_heads_cpu\n                * model_config.linear_num_value_heads\n                // model_config.linear_num_key_heads\n            )\n            if is_post_update:\n                model_config.linear_num_key_heads_cpu = linear_num_key_heads_cpu\n                model_config.linear_num_value_heads_cpu = linear_num_value_heads_cpu\n            else:\n                model_config.linear_num_key_heads = linear_num_key_heads_cpu\n                model_config.linear_num_value_heads = linear_num_value_heads_cpu\n\n        else:\n            if is_post_update:\n                model_config.linear_num_key_heads_cpu = (\n                    model_config.linear_num_key_heads\n                )\n                model_config.linear_num_value_heads_cpu = (\n                    model_config.linear_num_value_heads\n                )\n\n\ndef update_intermediate_size(model_config, attr_name, intermediate_padding_size):\n    attr_value = intermediate_padding_size\n    if hasattr(model_config, \"hf_config\") and hasattr(\n        model_config.hf_config, attr_name\n    ):\n        attr_value = getattr(model_config.hf_config, attr_name)\n    elif hasattr(model_config, attr_name):\n        attr_value = getattr(model_config, attr_name)\n\n    if attr_value % intermediate_padding_size != 0:\n        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size\n\n        attr_value = pad_vocab_size(attr_value, intermediate_padding_size)\n        if hasattr(model_config, \"hf_config\"):\n            setattr(model_config.hf_config, attr_name, attr_value)\n            if hasattr(model_config, 
\"hf_text_config\"):\n                setattr(model_config.hf_text_config, attr_name, attr_value)\n        else:\n            setattr(model_config, attr_name, attr_value)\n\n    return model_config\n\n\ndef adjust_config_with_unaligned_cpu_tp(\n    model_config: ModelConfig, load_config: LoadConfig, tp_size: int\n) -> ModelConfig:\n    # Support the case where the num_attention_heads is not divisible by the TP size.\n    weight_block_size = may_get_weight_block_size(model_config, load_config)\n\n    model_config.hf_config.original_num_attention_heads = (\n        model_config.num_attention_heads\n    )\n    model_config.hf_text_config.original_num_attention_heads = (\n        model_config.num_attention_heads\n    )\n\n    model_config.hf_config.original_total_num_kv_heads = (\n        model_config.get_total_num_kv_heads()\n    )\n    model_config.hf_text_config.original_total_num_kv_heads = (\n        model_config.get_total_num_kv_heads()\n    )\n\n    if (\n        model_config.num_attention_heads % tp_size != 0\n        or model_config.get_total_num_kv_heads() % tp_size != 0\n    ):\n        # Compute the head_dim using the model_config.num_attention_heads before padding\n        if not hasattr(model_config.hf_config, \"head_dim\"):\n            model_config.hf_config.head_dim = (\n                model_config.hidden_size // model_config.num_attention_heads\n            )\n        if hasattr(model_config.hf_config, \"qk_nope_head_dim\") and hasattr(\n            model_config.hf_config, \"qk_rope_head_dim\"\n        ):\n            model_config.hf_config.qk_head_dim = (\n                model_config.hf_config.qk_nope_head_dim\n                + model_config.hf_config.qk_rope_head_dim\n            )\n\n        query_heads_per_kv = (\n            model_config.num_attention_heads // model_config.get_total_num_kv_heads()\n        )\n        total_kv_heads = model_config.get_total_num_kv_heads()\n        from sglang.srt.layers.vocab_parallel_embedding import 
pad_vocab_size\n\n        head_dim = (\n            model_config.hf_config.qk_head_dim\n            if hasattr(model_config.hf_config, \"qk_head_dim\")\n            else model_config.hf_config.head_dim\n        )\n        pad_size = get_num_heads_padding_size(tp_size, weight_block_size, head_dim)\n        num_key_value_heads = pad_vocab_size(total_kv_heads, pad_size)\n\n        model_config.num_key_value_heads = num_key_value_heads\n        model_config.hf_config.num_key_value_heads = num_key_value_heads\n        model_config.hf_text_config.num_key_value_heads = num_key_value_heads\n\n        num_attention_heads = num_key_value_heads * query_heads_per_kv\n        model_config.num_attention_heads = num_attention_heads\n        model_config.hf_config.num_attention_heads = num_attention_heads\n        model_config.hf_text_config.num_attention_heads = num_attention_heads\n\n    adjust_tp_num_heads_if_necessary(model_config.hf_config, tp_size, True)\n\n    intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size)\n    model_config = update_intermediate_size(\n        model_config, \"moe_intermediate_size\", intermediate_padding_size\n    )\n    model_config = update_intermediate_size(\n        model_config, \"intermediate_size\", intermediate_padding_size\n    )\n    model_config = update_intermediate_size(\n        model_config, \"intermediate_size_mlp\", intermediate_padding_size\n    )\n    model_config = update_intermediate_size(\n        model_config, \"shared_expert_intermediate_size\", intermediate_padding_size\n    )\n    if (\n        hasattr(model_config.hf_config, \"vision_config\")\n        and model_config.hf_config.vision_config.model_type == \"siglip_vision_model\"\n    ):\n        model_config.hf_config.vision_config.original_num_attention_heads = (\n            model_config.hf_config.vision_config.num_attention_heads\n        )\n        if model_config.hf_config.vision_config.num_attention_heads % tp_size != 0:\n            
model_config.hf_config.vision_config.head_dim = (\n                model_config.hf_config.vision_config.hidden_size\n                // model_config.hf_config.vision_config.num_attention_heads\n            )\n            from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size\n\n            pad_size = get_num_heads_padding_size(\n                tp_size,\n                weight_block_size,\n                model_config.hf_config.vision_config.head_dim,\n            )\n            model_config.hf_config.vision_config.num_attention_heads = pad_vocab_size(\n                model_config.hf_config.vision_config.num_attention_heads, pad_size\n            )\n        model_config.hf_config.vision_config = update_intermediate_size(\n            model_config.hf_config.vision_config,\n            \"intermediate_size\",\n            intermediate_padding_size,\n        )\n\n    return model_config\n"
  },
  {
    "path": "python/sglang/srt/configs/utils.py",
    "content": "from typing import Type\n\nfrom transformers import (\n    AutoImageProcessor,\n    AutoProcessor,\n    BaseImageProcessor,\n    PretrainedConfig,\n    ProcessorMixin,\n)\n\n\ndef register_image_processor(\n    config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]\n):\n    \"\"\"\n    Register a customized HF image processor for `config`, replacing any\n    existing HF registration for that config.\n    \"\"\"\n    AutoImageProcessor.register(\n        config, slow_image_processor_class=image_processor, exist_ok=True\n    )\n\n\ndef register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):\n    \"\"\"\n    Register a customized HF processor for `config`, replacing any existing\n    HF registration for that config.\n    \"\"\"\n    AutoProcessor.register(config, processor, exist_ok=True)\n"
  },
  {
    "path": "python/sglang/srt/connector/__init__.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nimport enum\nimport logging\n\nfrom sglang.srt.connector.base_connector import (\n    BaseConnector,\n    BaseFileConnector,\n    BaseKVConnector,\n)\nfrom sglang.srt.connector.redis import RedisConnector\nfrom sglang.srt.connector.remote_instance import RemoteInstanceConnector\nfrom sglang.srt.connector.s3 import S3Connector\nfrom sglang.srt.utils import parse_connector_type\n\nlogger = logging.getLogger(__name__)\n\n\nclass ConnectorType(str, enum.Enum):\n    FS = \"filesystem\"\n    KV = \"KV\"\n    INSTANCE = \"instance\"\n\n\ndef create_remote_connector(url, device=None, **kwargs) -> BaseConnector:\n    connector_type = parse_connector_type(url)\n    if connector_type == \"redis\":\n        return RedisConnector(url)\n    elif connector_type == \"s3\":\n        return S3Connector(url)\n    elif connector_type == \"instance\":\n        return RemoteInstanceConnector(url, device)\n    else:\n        raise ValueError(f\"Invalid connector type: {url}\")\n\n\ndef get_connector_type(client: BaseConnector) -> ConnectorType:\n    if isinstance(client, BaseKVConnector):\n        return ConnectorType.KV\n    if isinstance(client, BaseFileConnector):\n        return ConnectorType.FS\n    if isinstance(client, RemoteInstanceConnector):\n        return ConnectorType.INSTANCE\n\n    raise ValueError(f\"Invalid connector type: {client}\")\n\n\n__all__ = [\n    \"BaseConnector\",\n    \"BaseFileConnector\",\n    \"BaseKVConnector\",\n    \"RedisConnector\",\n    \"RemoteInstanceConnector\",\n    \"S3Connector\",\n    \"ConnectorType\",\n    \"create_remote_connector\",\n    \"get_connector_type\",\n]\n"
  },
  {
    "path": "python/sglang/srt/connector/base_connector.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nimport os\nimport shutil\nimport signal\nimport tempfile\nfrom abc import ABC, abstractmethod\nfrom typing import Generator, List, Optional, Tuple\n\nimport torch\n\n\nclass BaseConnector(ABC):\n    \"\"\"\n    For fs connector such as s3:\n    <connector_type>://<path>/<filename>\n\n    For kv connector such as redis:\n    <connector_type>://<host>:<port>/<model_name>/keys/<key>\n    <connector_type>://<host>:<port>/<model_name>/files/<filename>\n    \"\"\"\n\n    def __init__(self, url: str):\n        self.url = url\n        self.closed = False\n        self.local_dir = tempfile.mkdtemp()\n        for sig in (signal.SIGINT, signal.SIGTERM):\n            existing_handler = signal.getsignal(sig)\n            signal.signal(sig, self._close_by_signal(existing_handler))\n\n    def get_local_dir(self):\n        return self.local_dir\n\n    @abstractmethod\n    def weight_iterator(\n        self, rank: int = 0\n    ) -> Generator[Tuple[str, torch.Tensor], None, None]:\n        raise NotImplementedError()\n\n    @abstractmethod\n    def pull_files(\n        self,\n        allow_pattern: Optional[List[str]] = None,\n        ignore_pattern: Optional[List[str]] = None,\n    ) -> None:\n        raise NotImplementedError()\n\n    def close(self):\n        if self.closed:\n            return\n\n        self.closed = True\n        if os.path.exists(self.local_dir):\n            shutil.rmtree(self.local_dir)\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.close()\n\n    def __del__(self):\n        self.close()\n\n    def _close_by_signal(self, existing_handler=None):\n\n        def new_handler(signum, frame):\n            self.close()\n            if existing_handler:\n                existing_handler(signum, frame)\n\n        return new_handler\n\n\nclass BaseKVConnector(BaseConnector):\n\n    @abstractmethod\n    def get(self, key: str) -> 
Optional[torch.Tensor]:\n        raise NotImplementedError()\n\n    @abstractmethod\n    def getstr(self, key: str) -> Optional[str]:\n        raise NotImplementedError()\n\n    @abstractmethod\n    def set(self, key: str, obj: torch.Tensor) -> None:\n        raise NotImplementedError()\n\n    @abstractmethod\n    def setstr(self, key: str, obj: str) -> None:\n        raise NotImplementedError()\n\n    @abstractmethod\n    def list(self, prefix: str) -> List[str]:\n        raise NotImplementedError()\n\n\nclass BaseFileConnector(BaseConnector):\n\n    @abstractmethod\n    def glob(self, allow_pattern: Optional[List[str]] = None) -> List[str]:\n        \"\"\"\n        List full file names from the remote fs path and filter them by allow pattern.\n\n        Args:\n            allow_pattern: A list of patterns selecting which files to list.\n\n        Returns:\n            list[str]: List of full paths allowed by the patterns\n        \"\"\"\n        raise NotImplementedError()\n"
  },
  {
    "path": "python/sglang/srt/connector/redis.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nimport logging\nfrom typing import Generator, List, Optional, Tuple\nfrom urllib.parse import urlparse\n\nimport torch\n\nfrom sglang.srt.connector import BaseKVConnector\nfrom sglang.srt.connector.serde import create_serde\nfrom sglang.srt.connector.utils import pull_files_from_db\n\nlogger = logging.getLogger(__name__)\n\n\nclass RedisConnector(BaseKVConnector):\n\n    def __init__(self, url: str):\n        import redis\n\n        super().__init__(url)\n        parsed_url = urlparse(url)\n        self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)\n        self.model_name = parsed_url.path.lstrip(\"/\")\n        # TODO: more serde options\n        self.s, self.d = create_serde(\"safe\")\n\n    def get(self, key: str) -> Optional[torch.Tensor]:\n        val = self.connection.get(key)\n\n        if val is None:\n            logger.error(\"Key %s not found\", key)\n            return None\n\n        return self.d.from_bytes(val)\n\n    def getstr(self, key: str) -> Optional[str]:\n        val = self.connection.get(key)\n        if val is None:\n            logger.error(\"Key %s not found\", key)\n            return None\n\n        return val.decode(\"utf-8\")\n\n    def set(self, key: str, tensor: torch.Tensor) -> None:\n        assert tensor is not None\n        self.connection.set(key, self.s.to_bytes(tensor))\n\n    def setstr(self, key: str, obj: str) -> None:\n        self.connection.set(key, obj)\n\n    def list(self, prefix: str) -> List[str]:\n        cursor = 0\n        all_keys: List[bytes] = []\n\n        while True:\n            ret: Tuple[int, List[bytes]] = self.connection.scan(\n                cursor=cursor, match=f\"{prefix}*\"\n            )  # type: ignore\n            cursor, keys = ret\n            all_keys.extend(keys)\n            if cursor == 0:\n                break\n\n        return [key.decode(\"utf-8\") for key in all_keys]\n\n    def weight_iterator(\n    
    self, rank: int = 0\n    ) -> Generator[Tuple[str, torch.Tensor], None, None]:\n        keys = self.list(f\"{self.model_name}/keys/rank_{rank}/\")\n        for key in keys:\n            val = self.get(key)\n            key = key.removeprefix(f\"{self.model_name}/keys/rank_{rank}/\")\n            yield key, val\n\n    def pull_files(\n        self,\n        allow_pattern: Optional[List[str]] = None,\n        ignore_pattern: Optional[List[str]] = None,\n    ) -> None:\n        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)\n\n    def close(self):\n        self.connection.close()\n        super().close()\n"
  },
  {
    "path": "python/sglang/srt/connector/remote_instance.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nimport logging\nfrom typing import Generator, Optional, Tuple\nfrom urllib.parse import urlparse\n\nimport torch\nimport torch.distributed as dist\n\nfrom sglang.srt.connector import BaseConnector\nfrom sglang.srt.utils import init_custom_process_group\n\nlogger = logging.getLogger(__name__)\n\n\nclass RemoteInstanceConnector(BaseConnector):\n\n    def __init__(self, url: str, device: torch.device = \"cpu\"):\n        assert (\n            device.type == \"cuda\" or device.type == \"npu\"\n        ), \"RemoteInstanceConnector only supports cuda and npu devices.\"\n        super().__init__(url)\n        self.url = url\n        self.device = device\n\n    def build_group(\n        self,\n        gpu_id: int = -1,\n        tp_rank: int = -1,\n        instance_ip: Optional[str] = None,\n        group_rank: int = 1,\n        world_size: int = 2,\n    ):\n        assert (\n            self.device.type == \"cuda\" or self.device.type == \"npu\"\n        ), \"RemoteInstanceConnector only supports cuda and npu devices.\"\n        assert (\n            gpu_id != -1 and tp_rank != -1\n        ), \"gpu_id and tp_rank must be specified for RemoteInstanceConnector. 
\"\n\n        self.device_id = torch.device(self.device.type, gpu_id)\n\n        parsed_url = urlparse(self.url)\n        master_address = parsed_url.hostname\n        master_port = parsed_url.port\n        group_name = f\"send_weights_{instance_ip}_{master_port}_{tp_rank}\"\n        backend = \"nccl\"\n\n        logger.info(\n            f\"init custom process group: master_address={master_address}, master_port={master_port}, \"\n            f\"rank_offset={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}\"\n        )\n\n        try:\n            self._model_update_group = init_custom_process_group(\n                backend=backend,\n                init_method=f\"tcp://{master_address}:{master_port}\",\n                world_size=world_size,\n                rank=group_rank,\n                group_name=group_name,\n                device_id=self.device_id,\n            )\n            dist.barrier(group=self._model_update_group)\n            return True, \"Succeeded to initialize custom process group.\"\n        except Exception as e:\n            message = f\"Failed to initialize custom process group: {e}.\"\n            logger.error(message)\n            return False, message\n\n    # Implemented as a no-op to make BaseConnector interface consistent.\n    def pull_files(\n        self,\n        allow_pattern: Optional[list[str]] = None,\n        ignore_pattern: Optional[list[str]] = None,\n    ) -> None:\n        return\n\n    # Implemented as a no-op to make BaseConnector interface consistent.\n    def weight_iterator(\n        self, rank: int = 0\n    ) -> Generator[Tuple[str, torch.Tensor], None, None]:\n        return\n"
  },
  {
    "path": "python/sglang/srt/connector/s3.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nimport fnmatch\nimport os\nfrom pathlib import Path\nfrom typing import Generator, Optional, Tuple\n\nimport torch\n\nfrom sglang.srt.connector import BaseFileConnector\n\n\ndef _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:\n    return [\n        path\n        for path in paths\n        if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)\n    ]\n\n\ndef _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:\n    return [\n        path\n        for path in paths\n        if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)\n    ]\n\n\ndef list_files(\n    s3,\n    path: str,\n    allow_pattern: Optional[list[str]] = None,\n    ignore_pattern: Optional[list[str]] = None,\n) -> tuple[str, str, list[str]]:\n    \"\"\"\n    List files from S3 path and filter by pattern.\n\n    Args:\n        s3: S3 client to use.\n        path: The S3 path to list from.\n        allow_pattern: A list of patterns of which files to pull.\n        ignore_pattern: A list of patterns of which files not to pull.\n\n    Returns:\n        tuple[str, str, list[str]]: A tuple where:\n            - The first element is the bucket name\n            - The second element is the key prefix within the bucket\n              (the part of the path after the bucket name)\n            - The third element is the list of object keys that pass\n              the allow and ignore patterns\n    \"\"\"\n    parts = path.removeprefix(\"s3://\").split(\"/\")\n    prefix = \"/\".join(parts[1:])\n    bucket_name = parts[0]\n\n    objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)\n    paths = [obj[\"Key\"] for obj in objects.get(\"Contents\", [])]\n\n    paths = _filter_ignore(paths, [\"*/\"])\n    if allow_pattern is not None:\n        paths = _filter_allow(paths, allow_pattern)\n\n    if ignore_pattern is not None:\n        paths = _filter_ignore(paths, ignore_pattern)\n\n    return bucket_name, prefix, 
paths\n\n\nclass S3Connector(BaseFileConnector):\n\n    def __init__(self, url: str) -> None:\n        import boto3\n\n        super().__init__(url)\n        self.client = boto3.client(\"s3\")\n\n    def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:\n        bucket_name, _, paths = list_files(\n            self.client, path=self.url, allow_pattern=allow_pattern\n        )\n        return [f\"s3://{bucket_name}/{path}\" for path in paths]\n\n    def pull_files(\n        self,\n        allow_pattern: Optional[list[str]] = None,\n        ignore_pattern: Optional[list[str]] = None,\n    ) -> None:\n        \"\"\"\n        Pull files from S3 storage (at `self.url`) into the temporary directory.\n\n        Args:\n            allow_pattern: A list of patterns of which files to pull.\n            ignore_pattern: A list of patterns of which files not to pull.\n\n        \"\"\"\n        bucket_name, base_dir, files = list_files(\n            self.client, self.url, allow_pattern, ignore_pattern\n        )\n        if len(files) == 0:\n            return\n\n        for file in files:\n            # Strip the leading \"/\" so os.path.join does not treat the key as absolute.\n            relative_path = file.removeprefix(base_dir).lstrip(\"/\")\n            destination_file = os.path.join(self.local_dir, relative_path)\n            local_dir = Path(destination_file).parent\n            os.makedirs(local_dir, exist_ok=True)\n            self.client.download_file(bucket_name, file, destination_file)\n\n    def weight_iterator(\n        self, rank: int = 0\n    ) -> Generator[Tuple[str, torch.Tensor], None, None]:\n        from sglang.srt.model_loader.weight_utils import (\n            runai_safetensors_weights_iterator,\n        )\n\n        # Only safetensors files are supported for now.\n        hf_weights_files = self.glob(allow_pattern=[\"*.safetensors\"])\n        return runai_safetensors_weights_iterator(hf_weights_files)\n\n    def close(self):\n        self.client.close()\n        super().close()\n"
  },
  {
    "path": "python/sglang/srt/connector/serde/__init__.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\n# inspired by LMCache\nfrom typing import Optional, Tuple\n\nimport torch\n\nfrom sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer\nfrom sglang.srt.connector.serde.serde import Deserializer, Serializer\n\n\ndef create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:\n    s: Optional[Serializer] = None\n    d: Optional[Deserializer] = None\n\n    if serde_type == \"safe\":\n        s = SafeSerializer()\n        d = SafeDeserializer()\n    else:\n        raise ValueError(f\"Unknown serde type: {serde_type}\")\n\n    return s, d\n\n\n__all__ = [\n    \"Serializer\",\n    \"Deserializer\",\n    \"SafeSerializer\",\n    \"SafeDeserializer\",\n    \"create_serde\",\n]\n"
  },
  {
    "path": "python/sglang/srt/connector/serde/safe_serde.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nfrom typing import Union\n\nimport torch\nfrom safetensors.torch import load, save\n\nfrom sglang.srt.connector.serde.serde import Deserializer, Serializer\n\n\nclass SafeSerializer(Serializer):\n\n    def __init__(self):\n        super().__init__()\n\n    def to_bytes(self, t: torch.Tensor) -> bytes:\n        return save({\"tensor_bytes\": t.cpu().contiguous()})\n\n\nclass SafeDeserializer(Deserializer):\n\n    def __init__(self):\n        # TODO: dtype options\n        super().__init__(torch.float32)\n\n    def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:\n        return load(bytes(b))[\"tensor_bytes\"]\n\n    def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:\n        return self.from_bytes_normal(b)\n"
  },
  {
    "path": "python/sglang/srt/connector/serde/serde.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nimport abc\nfrom abc import ABC, abstractmethod\n\nimport torch\n\n\nclass Serializer(ABC):\n\n    @abstractmethod\n    def to_bytes(self, t: torch.Tensor) -> bytes:\n        \"\"\"\n        Serialize a PyTorch tensor to bytes. The serialized bytes should contain\n        both the data and the metadata (shape, dtype, etc.) of the tensor.\n\n        Input:\n            t: the input PyTorch tensor; it can be on any device, have any shape,\n               and use any dtype\n\n        Returns:\n            bytes: the serialized bytes\n        \"\"\"\n        raise NotImplementedError\n\n\nclass Deserializer(metaclass=abc.ABCMeta):\n\n    def __init__(self, dtype):\n        self.dtype = dtype\n\n    @abstractmethod\n    def from_bytes(self, bs: bytes) -> torch.Tensor:\n        \"\"\"\n        Deserialize a PyTorch tensor from bytes.\n\n        Input:\n            bs: a stream of bytes\n\n        Returns:\n            torch.Tensor: the deserialized PyTorch tensor\n        \"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "python/sglang/srt/connector/utils.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nimport os\nfrom pathlib import Path\nfrom typing import Optional\nfrom urllib.parse import urlparse\n\nfrom sglang.srt.connector import BaseConnector\n\n\ndef parse_model_name(url: str) -> str:\n    \"\"\"\n    Parse the model name from the URL.\n    Only used by the DB connector.\n    \"\"\"\n    parsed_url = urlparse(url)\n    return parsed_url.path.lstrip(\"/\")\n\n\ndef pull_files_from_db(\n    connector: BaseConnector,\n    model_name: str,\n    allow_pattern: Optional[list[str]] = None,\n    ignore_pattern: Optional[list[str]] = None,\n) -> None:\n    # Note: allow_pattern and ignore_pattern are currently not applied here.\n    prefix = f\"{model_name}/files/\"\n    local_dir = connector.get_local_dir()\n    files = connector.list(prefix)\n\n    for file in files:\n        destination_file = os.path.join(local_dir, file.removeprefix(prefix))\n        # Use a separate name so the base local_dir is not overwritten\n        # between iterations.\n        parent_dir = Path(destination_file).parent\n        os.makedirs(parent_dir, exist_ok=True)\n        with open(destination_file, \"wb\") as f:\n            f.write(connector.getstr(file).encode(\"utf-8\"))\n"
  },
  {
    "path": "python/sglang/srt/constants.py",
    "content": "# GPU Memory Types\nGPU_MEMORY_TYPE_KV_CACHE = \"kv_cache\"\nGPU_MEMORY_TYPE_WEIGHTS = \"weights\"\nGPU_MEMORY_TYPE_CUDA_GRAPH = \"cuda_graph\"\n\nGPU_MEMORY_ALL_TYPES = [\n    GPU_MEMORY_TYPE_KV_CACHE,\n    GPU_MEMORY_TYPE_WEIGHTS,\n    GPU_MEMORY_TYPE_CUDA_GRAPH,\n]\n"
  },
  {
    "path": "python/sglang/srt/constrained/base_grammar_backend.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"The baseclass of a backend for grammar-guided constrained decoding.\"\"\"\n\nimport logging\nimport time\nfrom concurrent.futures import Future, ThreadPoolExecutor\nfrom dataclasses import dataclass, field\nfrom typing import Dict, List, Optional, Tuple\n\nimport torch\n\nfrom sglang.srt.server_args import ServerArgs\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass GrammarStats:\n    compilation_time: Optional[float] = None\n    schema_count: Optional[int] = None\n    ebnf_size: Optional[int] = None\n    is_cache_hit: bool = False\n    is_grammar_aborted: bool = False\n    tree_traversal_time: List[float] = field(default_factory=list)\n    dispatch_type: Optional[str] = None\n    num_timeout: int = 0\n\n\nclass BaseGrammarObject:\n\n    def __init__(self):\n        self._finished = False\n        self.grammar_stats = None\n        self.current_token = None\n\n    def maybe_init_reasoning(self, reasoning: bool):\n        pass\n\n    def accept_token(self, token: int) -> None:\n        \"\"\"\n        Accept a token in the grammar.\n        \"\"\"\n        raise NotImplementedError()\n\n    def rollback(self, k: int):\n        raise NotImplementedError()\n\n    def is_terminated(self):\n        return False\n\n    def allocate_vocab_mask(\n        self, vocab_size: 
int, batch_size: int, device\n    ) -> torch.Tensor:\n        raise NotImplementedError()\n\n    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:\n        raise NotImplementedError()\n\n    @staticmethod\n    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:\n        raise NotImplementedError()\n\n    @staticmethod\n    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:\n        raise NotImplementedError()\n\n    def copy(self) -> \"BaseGrammarObject\":\n        return self\n\n    @property\n    def finished(self):\n        return self._finished\n\n    @finished.setter\n    def finished(self, finished):\n        self._finished = finished\n\n    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:\n        \"\"\"\n        Try to jump forward in the grammar.\n\n        Returns:\n            A jump forward helper which may be used in `jump_forward_str_state`.\n            None if the jump forward is not possible.\n        \"\"\"\n        raise NotImplementedError()\n\n    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:\n        \"\"\"\n        Jump forward for the grammar.\n\n        Returns:\n            A tuple of the jump forward string and the next state of the grammar\n            (which can be used in `jump_and_retokenize` if needed).\n        \"\"\"\n        raise NotImplementedError()\n\n    def jump_and_retokenize(\n        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int\n    ) -> None:\n        \"\"\"\n        Update the grammar state after a jump forward, if needed.\n        \"\"\"\n        raise NotImplementedError()\n\n\nclass InvalidGrammarObject(BaseGrammarObject):\n    \"\"\"Represents a grammar that failed to compile, carrying the original error message.\"\"\"\n\n    def __init__(self, error_message: str = \"Unknown grammar error\"):\n        super().__init__()\n        self.error_message = 
error_message\n\n    def __repr__(self):\n        return f\"InvalidGrammarObject(error_message={self.error_message!r})\"\n\n\nclass BaseGrammarBackend:\n    def __init__(self):\n        self.executor = ThreadPoolExecutor()\n        self.cache: Dict[Tuple[str, str], BaseGrammarObject] = {}\n\n    def _not_supported(self, key_type: str, key_string: str) -> BaseGrammarObject:\n        logger.warning(f\"Skip unsupported {key_type=}, {key_string=}\")\n        return InvalidGrammarObject()\n\n    def dispatch_fallback(self, key_type: str, key_string: str) -> BaseGrammarObject:\n        \"\"\"\n        This function should not be reached in any case.\n        \"\"\"\n        raise ValueError(f\"Invalid key_type: {key_type}={key_string}\")\n\n    def dispatch_json(self, key_string: str) -> BaseGrammarObject:\n        return self._not_supported(\"json\", key_string)\n\n    def dispatch_regex(self, key_string: str) -> BaseGrammarObject:\n        return self._not_supported(\"regex\", key_string)\n\n    def dispatch_ebnf(self, key_string: str) -> BaseGrammarObject:\n        return self._not_supported(\"ebnf\", key_string)\n\n    def dispatch_structural_tag(self, key_string: str) -> BaseGrammarObject:\n        return self._not_supported(\"structural_tag\", key_string)\n\n    def _init_value_dispatch(\n        self, key: Tuple[str, str], require_reasoning: bool\n    ) -> BaseGrammarObject:\n        s = time.perf_counter()\n        key_type, key_string = key\n        if key_type == \"json\":\n            grammar = self.dispatch_json(key_string)\n        elif key_type == \"regex\":\n            grammar = self.dispatch_regex(key_string)\n        elif key_type == \"ebnf\":\n            grammar = self.dispatch_ebnf(key_string)\n        elif key_type == \"structural_tag\":\n            grammar = self.dispatch_structural_tag(key_string)\n        else:\n            grammar = self.dispatch_fallback(key_type, key_string)\n\n        if grammar is not None and grammar.grammar_stats is not 
None:\n            grammar.grammar_stats.compilation_time = time.perf_counter() - s\n        return grammar\n\n    def get_cached_or_future_value(\n        self, key: Tuple[str, str], require_reasoning: bool\n    ) -> Tuple[BaseGrammarObject | Future[BaseGrammarObject], bool]:\n        value = self.cache.get(key)\n        if value:\n            copied_value = value.copy()\n            copied_value.maybe_init_reasoning(require_reasoning)\n            return copied_value, True\n        value = self.executor.submit(self._init_value_dispatch, key, require_reasoning)\n        return value, False\n\n    def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):\n        self.cache[key] = value\n\n    def reset(self):\n        self.cache.clear()\n\n\nGRAMMAR_BACKEND_REGISTRY = {}\n\n\ndef register_grammar_backend(name, init_func):\n    GRAMMAR_BACKEND_REGISTRY[name] = init_func\n\n\ndef create_grammar_backend(\n    server_args: ServerArgs,\n    tokenizer,\n    vocab_size: int,\n    eos_token_ids: Optional[set] = None,\n) -> Optional[BaseGrammarBackend]:\n    name = server_args.grammar_backend\n\n    # Custom grammar backend has the highest priority\n    if name in GRAMMAR_BACKEND_REGISTRY:\n        return GRAMMAR_BACKEND_REGISTRY[name](\n            server_args, tokenizer, vocab_size, eos_token_ids\n        )\n\n    # Default grammar backends\n    if name == \"outlines\":\n        from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend\n\n        grammar_backend = OutlinesGrammarBackend(\n            tokenizer,\n            whitespace_pattern=server_args.constrained_json_whitespace_pattern,\n        )\n    elif name == \"xgrammar\":\n        from sglang.srt.constrained.xgrammar_backend import (\n            TokenizerNotSupportedError,\n            XGrammarGrammarBackend,\n        )\n\n        # Convert Set[int] to List[int] if needed\n        eos_list = list(eos_token_ids) if eos_token_ids else None\n\n        try:\n            grammar_backend 
= XGrammarGrammarBackend(\n                tokenizer,\n                vocab_size=vocab_size,\n                model_eos_token_ids=eos_list,\n                any_whitespace=not server_args.constrained_json_disable_any_whitespace,\n            )\n        except TokenizerNotSupportedError as e:\n            logger.warning(\n                f\"Grammar backend disabled because tokenizer is not supported by XGrammar: {e}. \"\n                \"Falling back to grammar_backend='none'. \"\n                \"Structured outputs (JSON schema, regex, EBNF) will not be available.\"\n            )\n            server_args.grammar_backend = \"none\"\n            return None\n    elif name == \"llguidance\":\n        from sglang.srt.constrained.llguidance_backend import GuidanceBackend\n\n        grammar_backend = GuidanceBackend(\n            tokenizer=tokenizer,\n            any_whitespace=not server_args.constrained_json_disable_any_whitespace,\n            whitespace_pattern=server_args.constrained_json_whitespace_pattern,\n        )\n    elif name == \"none\":\n        return None\n    else:\n        raise ValueError(f\"Invalid grammar backend: {name}\")\n\n    if server_args.reasoning_parser and hasattr(tokenizer, \"think_end_id\"):\n        from sglang.srt.constrained.reasoner_grammar_backend import (\n            ReasonerGrammarBackend,\n        )\n\n        grammar_backend = ReasonerGrammarBackend(\n            grammar_backend, tokenizer.think_end_id\n        )\n\n    return grammar_backend\n"
  },
  {
    "path": "python/sglang/srt/constrained/grammar_manager.py",
    "content": "from __future__ import annotations\n\nimport logging\nimport time\nfrom concurrent import futures\nfrom typing import TYPE_CHECKING, List\n\nimport torch\n\nfrom sglang.srt.constrained.base_grammar_backend import (\n    InvalidGrammarObject,\n    create_grammar_backend,\n)\nfrom sglang.srt.environ import envs\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.io_struct import AbortReq\n    from sglang.srt.managers.schedule_batch import Req\n    from sglang.srt.managers.scheduler import Scheduler\n\nlogger = logging.getLogger(__name__)\n\n\nclass GrammarManager:\n    def __init__(self, scheduler: Scheduler):\n        self.scheduler = scheduler\n        self.server_args = scheduler.server_args\n        self.grammar_queue: List[Req] = []\n        if not self.server_args.skip_tokenizer_init:\n            self.grammar_backend = create_grammar_backend(\n                self.server_args,\n                scheduler.tokenizer,\n                scheduler.model_config.vocab_size,\n                scheduler.model_config.hf_eos_token_id,\n            )\n        else:\n            self.grammar_backend = None\n\n        self.grammar_sync_group = scheduler.dp_tp_cpu_group\n        self.grammar_sync_size = scheduler.dp_tp_group.world_size\n        self.grammar_sync_entry = scheduler.dp_tp_group.first_rank\n        self.is_grammar_sync_entry = scheduler.dp_tp_group.is_first_rank\n\n        self.SGLANG_GRAMMAR_POLL_INTERVAL = envs.SGLANG_GRAMMAR_POLL_INTERVAL.get()\n        self.SGLANG_GRAMMAR_MAX_POLL_ITERATIONS = (\n            envs.SGLANG_GRAMMAR_MAX_POLL_ITERATIONS.get()\n        )\n\n    def __len__(self):\n        return len(self.grammar_queue)\n\n    def clear(self):\n        if self.grammar_backend:\n            self.grammar_backend.reset()\n\n    def has_waiting_grammars(self) -> bool:\n        return len(self.grammar_queue) > 0\n\n    def abort_requests(self, recv_req: AbortReq):\n        for req in self.grammar_queue:\n            if recv_req.abort_all or 
req.rid.startswith(recv_req.rid):\n                logger.debug(f\"Abort grammar queue request. {req.rid=}\")\n                if isinstance(req.grammar, futures.Future) and req.grammar:\n                    req.grammar.cancel()\n                req.set_finish_with_abort(\"Aborted by AbortReq.\")\n\n    def process_req_with_grammar(self, req: Req) -> bool:\n        # Init grammar cache for this request\n        add_to_grammar_queue = False\n        if (\n            req.sampling_params.json_schema is not None\n            or req.sampling_params.regex is not None\n            or req.sampling_params.ebnf is not None\n            or req.sampling_params.structural_tag is not None\n        ):\n            if self.grammar_backend is None:\n                error_msg = \"Grammar-based generation (json_schema, regex, ebnf, structural_tag) is not supported when the server is launched with --grammar-backend none\"\n                req.set_finish_with_abort(error_msg)\n            else:\n                if req.sampling_params.json_schema is not None:\n                    key = (\"json\", req.sampling_params.json_schema)\n                elif req.sampling_params.regex is not None:\n                    key = (\"regex\", req.sampling_params.regex)\n                elif req.sampling_params.ebnf is not None:\n                    key = (\"ebnf\", req.sampling_params.ebnf)\n                elif req.sampling_params.structural_tag:\n                    key = (\"structural_tag\", req.sampling_params.structural_tag)\n\n                value, cache_hit = self.grammar_backend.get_cached_or_future_value(\n                    key, req.require_reasoning\n                )\n                req.grammar = value\n\n                if not cache_hit:\n                    req.grammar_key = key\n                    add_to_grammar_queue = True\n                else:\n                    if isinstance(\n                        value, InvalidGrammarObject\n                    ):  # We hit a cached 
invalid grammar.\n                        error_msg = (\n                            f\"Failed to compile {key[0]} grammar: {value.error_message}\"\n                        )\n                        req.set_finish_with_abort(error_msg)\n\n        if add_to_grammar_queue:\n            self.grammar_queue.append(req)\n\n        return add_to_grammar_queue\n\n    def get_ready_grammar_requests(self) -> List[Req]:\n        \"\"\"\n        Move requests whose grammar objects are ready from grammar_queue to waiting_queue.\n\n        Rank i returns two sets ready_reqs_i, failed_reqs_i\n        ready_reqs_all = all_gather(ready_reqs_i)\n        failed_reqs_all = all_gather(failed_reqs_i)\n\n        ready_reqs = intersect(ready_reqs_all)\n        failed_reqs = union(failed_reqs_all)\n        \"\"\"\n        assert self.grammar_backend\n        ready_req_idxs: set[int] = set()\n        failed_req_idxs: set[int] = set()\n\n        # Poll for ready requests\n        start_time = time.perf_counter()\n        while time.perf_counter() - start_time < self.SGLANG_GRAMMAR_POLL_INTERVAL:\n            for i, req in enumerate(self.grammar_queue):\n                if i in ready_req_idxs:\n                    continue\n\n                if req.finished() or req.grammar is None:  # It is aborted by AbortReq\n                    ready_req_idxs.add(i)\n                    continue\n\n                assert isinstance(req.grammar, futures.Future), f\"{req=}\"\n                if req.grammar.done():\n                    ready_req_idxs.add(i)\n\n            # Sleep a bit to avoid busy waiting\n            time.sleep(self.SGLANG_GRAMMAR_POLL_INTERVAL / 10)\n\n        # Check failed requests\n        for i, req in enumerate(self.grammar_queue):\n            if i not in ready_req_idxs:\n                self.grammar_queue[i].grammar_wait_ct += 1\n                if (\n                    self.grammar_queue[i].grammar_wait_ct\n                    >= self.SGLANG_GRAMMAR_MAX_POLL_ITERATIONS\n        
        ):\n                    # Timeout after max poll iterations\n                    # The actual waiting time is SGLANG_GRAMMAR_MAX_POLL_ITERATIONS * max(SGLANG_GRAMMAR_POLL_INTERVAL, GPU_forward_batch_latency)\n                    failed_req_idxs.add(i)\n\n        # Sync ready and failed requests across all ranks\n        if self.grammar_sync_size == 1:\n            synced_ready_req_idxs = ready_req_idxs\n            synced_failed_req_idxs = failed_req_idxs\n        else:\n            all_gather_output = [None] * self.grammar_sync_size\n            torch.distributed.all_gather_object(\n                all_gather_output,\n                (ready_req_idxs, failed_req_idxs),\n                group=self.grammar_sync_group,\n            )\n            synced_ready_req_idxs = set.intersection(*[x[0] for x in all_gather_output])\n            synced_failed_req_idxs = set.union(*[x[1] for x in all_gather_output])\n\n        # Return ready requests\n        return_reqs: List[Req] = []\n        for i in synced_ready_req_idxs:\n            req = self.grammar_queue[i]\n            return_reqs.append(req)\n            if req.finished() or req.grammar is None:  # It is aborted by AbortReq\n                continue\n\n            assert isinstance(req.grammar, futures.Future) and req.grammar_key\n            req.grammar = req.grammar.result()\n            self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())\n            if isinstance(req.grammar, InvalidGrammarObject):\n                error_msg = f\"Failed to compile {req.grammar_key[0]} grammar: {req.grammar.error_message}\"\n                req.set_finish_with_abort(error_msg)\n\n        # Return failed requests\n        for i in synced_failed_req_idxs:\n            req = self.grammar_queue[i]\n            return_reqs.append(req)\n\n            assert isinstance(req.grammar, futures.Future) and req.grammar_key\n            req.grammar.cancel()\n            self.grammar_backend.set_cache(\n                
req.grammar_key, InvalidGrammarObject(\"Grammar preprocessing timed out\")\n            )\n            error_msg = f\"Grammar preprocessing timed out: {req.grammar_key=}\"\n            req.set_finish_with_abort(error_msg)\n\n        # Remove finished requests from grammar_queue\n        self.grammar_queue = [\n            req\n            for i, req in enumerate(self.grammar_queue)\n            if i not in synced_ready_req_idxs and i not in synced_failed_req_idxs\n        ]\n        return return_reqs\n"
  },
  {
    "path": "python/sglang/srt/constrained/llguidance_backend.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Constrained decoding with llguidance backend.\"\"\"\n\nimport json\nimport logging\nimport os\nfrom typing import List, Optional, Tuple\n\nimport torch\nfrom llguidance import LLMatcher, LLTokenizer, StructTag, grammar_from\nfrom llguidance.hf import from_tokenizer\nfrom llguidance.torch import (\n    allocate_token_bitmask,\n    apply_token_bitmask_inplace,\n    fill_next_token_bitmask,\n)\n\nfrom sglang.srt.constrained.base_grammar_backend import (\n    BaseGrammarBackend,\n    BaseGrammarObject,\n    InvalidGrammarObject,\n)\nfrom sglang.srt.constrained.utils import is_legacy_structural_tag\n\nlogger = logging.getLogger(__name__)\n\n\nclass GuidanceGrammar(BaseGrammarObject):\n\n    def __init__(self, llguidance_tokenizer: LLTokenizer, serialized_grammar: str):\n        super().__init__()\n        self.llguidance_tokenizer = llguidance_tokenizer\n        self.serialized_grammar = serialized_grammar\n\n        self.ll_matcher = LLMatcher(\n            self.llguidance_tokenizer,\n            self.serialized_grammar,\n            log_level=int(os.environ.get(\"LLGUIDANCE_LOG_LEVEL\", \"1\")),\n        )\n        self._check_err()\n\n        self.bitmask = None\n        self.eos_token = self.llguidance_tokenizer.eos_token\n\n    def accept_token(self, token: int):\n        if 
self.finished:\n            return\n        if self.ll_matcher.is_stopped() and token == self.eos_token:\n            self.finished = True\n            return\n        self.ll_matcher.consume_token(token)\n        self._check_err()\n\n    def rollback(self, num_tokens: int) -> None:\n        if num_tokens <= 0:\n            return\n        if self.finished:\n            self.finished = False\n            # EOS token after stop isn't tracked in ll_matcher\n            num_tokens -= 1\n        self.ll_matcher.rollback(num_tokens)\n        self._check_err()\n\n    def is_terminated(self):\n        return self.finished\n\n    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:\n        fill_next_token_bitmask(self.ll_matcher, vocab_mask, idx)\n        self._check_err()\n\n    def allocate_vocab_mask(\n        self, vocab_size: int, batch_size: int, device\n    ) -> torch.Tensor:\n        if self.bitmask is None or self.bitmask.shape[0] < batch_size:\n            # only create bitmask when batch gets larger\n            self.bitmask = allocate_token_bitmask(\n                batch_size, self.llguidance_tokenizer.vocab_size\n            )\n            bitmask = self.bitmask\n        else:\n            bitmask = self.bitmask[:batch_size]\n\n        return bitmask\n\n    @staticmethod\n    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:\n        return vocab_mask.to(device, non_blocking=True)\n\n    @staticmethod\n    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:\n        apply_token_bitmask_inplace(logits, vocab_mask)\n\n    def copy(self):\n        return GuidanceGrammar(\n            llguidance_tokenizer=self.llguidance_tokenizer,\n            serialized_grammar=self.serialized_grammar,\n        )\n\n    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:\n        ff_tokens = self.ll_matcher.compute_ff_tokens()\n        if ff_tokens:\n            return ff_tokens, \"\"\n      
  else:\n            return None\n\n    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:\n        return \"\", -1\n\n    def jump_and_retokenize(\n        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int\n    ):\n        pass\n\n    def _check_err(self) -> None:\n        if self.ll_matcher.is_error():\n            raise ValueError(self.ll_matcher.get_error())\n\n\nclass GuidanceBackend(BaseGrammarBackend):\n\n    def __init__(\n        self,\n        tokenizer,\n        any_whitespace: bool = True,\n        whitespace_pattern: Optional[str] = None,\n        n_vocab: Optional[int] = None,\n    ):\n        super().__init__()\n\n        self.tokenizer = tokenizer\n        self.any_whitespace = any_whitespace\n        self.whitespace_pattern = whitespace_pattern\n        self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab)\n\n    def _from_serialized(self, serialized_grammar) -> BaseGrammarObject:\n        try:\n            return GuidanceGrammar(\n                llguidance_tokenizer=self.llguidance_tokenizer,\n                serialized_grammar=serialized_grammar,\n            )\n        except Exception as e:\n            logger.error(f\"Hit invalid grammar: {serialized_grammar=}, {e=}\")\n            return InvalidGrammarObject(str(e))\n\n    def dispatch_json(self, key_string: str) -> BaseGrammarObject:\n        try:\n            serialized_grammar = LLMatcher.grammar_from_json_schema(\n                key_string,\n                defaults={\n                    \"whitespace_flexible\": self.any_whitespace,\n                    \"whitespace_pattern\": self.whitespace_pattern,\n                },\n            )\n        except Exception as e:\n            logger.error(f\"Hit invalid json_schema: {key_string=}, {e=}\")\n            return InvalidGrammarObject(str(e))\n        return self._from_serialized(serialized_grammar)\n\n    def dispatch_regex(self, key_string: str) -> 
BaseGrammarObject:\n        serialized_grammar = grammar_from(\"regex\", key_string)\n        return self._from_serialized(serialized_grammar)\n\n    def dispatch_ebnf(self, key_string: str) -> BaseGrammarObject:\n        try:\n            serialized_grammar = grammar_from(\"ebnf\", key_string)\n            return self._from_serialized(serialized_grammar)\n        except ValueError as e:\n            logger.error(f\"Hit invalid ebnf: {key_string=}, {e=}\")\n            return InvalidGrammarObject(str(e))\n\n    def dispatch_structural_tag(self, key_string: str) -> BaseGrammarObject:\n        try:\n            structural_tag = json.loads(key_string)\n            assert is_legacy_structural_tag(structural_tag)\n            tags = [\n                StructTag(\n                    begin=structure[\"begin\"],\n                    grammar=structure[\"schema\"],\n                    end=structure[\"end\"],\n                    trigger=structural_tag[\"triggers\"][0],  # TODO?\n                )\n                for structure in structural_tag[\"structures\"]\n            ]\n            g = StructTag.to_grammar(tags)\n            return self._from_serialized(g)\n        except Exception as e:\n            logger.error(f\"Hit invalid structural_tag: {key_string=}, {e=}\")\n            return InvalidGrammarObject(str(e))\n"
  },
  {
    "path": "python/sglang/srt/constrained/outlines_backend.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Constrained decoding with outlines backend.\"\"\"\n\nimport json\nimport logging\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport interegular\nimport torch\nfrom outlines.fsm.guide import RegexGuide\nfrom outlines.models.transformers import TransformerTokenizer\nfrom pydantic import BaseModel\n\nfrom sglang.srt.constrained.base_grammar_backend import (\n    BaseGrammarBackend,\n    BaseGrammarObject,\n    InvalidGrammarObject,\n)\nfrom sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap\n\ntry:\n    from outlines.fsm.json_schema import build_regex_from_schema\nexcept ImportError:\n    from outlines_core.fsm.json_schema import build_regex_from_schema\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass OutlinesGrammar(BaseGrammarObject):\n    def __init__(\n        self,\n        guide: RegexGuide,\n        jump_forward_map: Union[OutlinesJumpForwardMap, None],\n    ) -> None:\n        super().__init__()\n        self.guide = guide\n        self.jump_forward_map = jump_forward_map\n        self.state = 0\n\n    def accept_token(self, token: int):\n        self.state = self.guide.get_next_state(self.state, token)\n\n    def allocate_vocab_mask(\n        self, vocab_size: int, batch_size: int, device\n    ) -> torch.Tensor:\n        return 
torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)\n\n    @staticmethod\n    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:\n        return vocab_mask\n\n    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:\n        tokens = torch.tensor(\n            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64\n        ).to(vocab_mask.device, non_blocking=True)\n        vocab_mask = vocab_mask[idx]\n        vocab_mask.fill_(1)\n        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))\n\n    @staticmethod\n    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):\n        logits.masked_fill_(vocab_mask, float(\"-inf\"))\n\n    def copy(self):\n        return OutlinesGrammar(self.guide, self.jump_forward_map)\n\n    def try_jump_forward(self, tokenizer) -> Optional[Tuple]:\n        if not self.jump_forward_map:\n            return None\n\n        jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state)\n        if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:\n            return None\n\n        # preprocess the jump forward string\n        suffix_bytes = []\n        continuation_range = range(0x80, 0xC0)\n        cur_state = self.state\n        while (\n            len(jump_forward_bytes) and jump_forward_bytes[0][0] in continuation_range\n        ):\n            # continuation bytes\n            byte_edge = jump_forward_bytes.pop(0)\n            suffix_bytes.append(byte_edge[0])\n            cur_state = byte_edge[1]\n\n        suffix_tokens = [f\"<0x{hex(b)[2:].upper()}>\" for b in suffix_bytes]\n        suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)\n        return suffix_ids, cur_state\n\n    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:\n        _, cur_state = helper\n        return self.jump_forward_map.jump_forward_symbol(cur_state)\n\n    def jump_and_retokenize(\n   
     self, old_output_ids: List[int], new_output_ids: List[int], next_state: int\n    ):\n        self.state = next_state\n\n\nclass OutlinesGrammarBackend(BaseGrammarBackend):\n    def __init__(\n        self,\n        tokenizer,\n        whitespace_pattern: str | None,\n    ):\n        super().__init__()\n\n        try:\n            self.outlines_tokenizer = TransformerTokenizer(tokenizer)\n        except AttributeError:\n            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)\n            origin_pad_token_id = tokenizer.pad_token_id\n\n            def fset(self, value):\n                self._value = value\n\n            type(tokenizer).pad_token_id = property(\n                fget=type(tokenizer).pad_token_id.fget, fset=fset\n            )\n            self.outlines_tokenizer = TransformerTokenizer(tokenizer)\n            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id\n            self.outlines_tokenizer.pad_token_id = origin_pad_token_id\n            self.outlines_tokenizer.pad_token = (\n                self.outlines_tokenizer.tokenizer.pad_token\n            )\n            self.outlines_tokenizer.vocabulary = (\n                self.outlines_tokenizer.tokenizer.get_vocab()\n            )\n        self.whitespace_pattern = whitespace_pattern\n\n    def _compile_regex(self, regex: str) -> BaseGrammarObject:\n        try:\n            if hasattr(RegexGuide, \"from_regex\"):\n                # outlines >= 0.1.1\n                guide = RegexGuide.from_regex(regex, self.outlines_tokenizer)\n            else:\n                # outlines <= 0.0.46\n                guide = RegexGuide(regex, self.outlines_tokenizer)\n        except interegular.patterns.InvalidSyntax as e:\n            logger.error(f\"Hit invalid regex schema: {regex=}, {e=}\")\n            return InvalidGrammarObject(str(e))\n\n        jump_forward_map = None\n        return OutlinesGrammar(guide, jump_forward_map)\n\n    def dispatch_ebnf(self, key_string: 
str):\n        return super().dispatch_ebnf(key_string)\n\n    def dispatch_structural_tag(self, key_string: str):\n        return super().dispatch_structural_tag(key_string)\n\n    def dispatch_json(self, key_string: str):\n        try:\n            regex = build_regex_from_object(\n                key_string,\n                whitespace_pattern=self.whitespace_pattern,\n            )\n        except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:\n            logger.error(f\"Hit invalid json_schema: {key_string=}, {e=}\")\n            return InvalidGrammarObject(str(e))\n        return self._compile_regex(regex)\n\n    def dispatch_regex(self, key_string: str):\n        return self._compile_regex(key_string)\n\n\ndef build_regex_from_object(\n    object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None\n):\n    if isinstance(object, type(BaseModel)):\n        schema = json.dumps(object.model_json_schema())\n    elif isinstance(object, Dict):\n        schema = json.dumps(object)\n    else:\n        schema = object\n    return build_regex_from_schema(schema, whitespace_pattern)\n"
  },
  {
    "path": "python/sglang/srt/constrained/outlines_jump_forward.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"\nFaster constrained decoding with jump forward decoding / compressed finite state machine.\nReference: https://lmsys.org/blog/2024-02-05-compressed-fsm/\n\"\"\"\n\nimport dataclasses\nimport logging\nfrom collections import defaultdict\nfrom typing import Optional\n\nimport interegular\nfrom interegular import InvalidSyntax\nfrom outlines.caching import cache\n\nfrom sglang.srt.utils import get_bool_env_var\n\ntry:\n    # outlines >= 0.1.0\n    from outlines_core.fsm.outlines_core_rs import FSMInfo\n    from outlines_core.fsm.regex import make_byte_level_fsm, make_deterministic_fsm\nexcept ImportError:\n    # outlines <= 0.0.46\n    from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm\n\nIP_REGEX = r\"((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\"\n\n# Env var was set in sglang.srt.server_args.ServerArgs.__post_init__\nDISABLE_DISK_CACHE = get_bool_env_var(\"SGLANG_DISABLE_OUTLINES_DISK_CACHE\", \"true\")\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass JumpEdge:\n    symbol: str = None\n    symbol_next_state: int = None\n    byte: int = None\n    byte_next_state: int = None\n\n\ndef disk_cache(expire: Optional[float] = None, typed=False, ignore=()):\n    if not DISABLE_DISK_CACHE:\n        
return cache(expire, typed, ignore)\n    else:\n        # Disk cache disabled: apply no decorator and return the function unchanged\n        return lambda fn: fn\n\n\n@disk_cache()\ndef init_state_to_jump_forward(regex_string):\n    try:\n        regex_pattern = interegular.parse_pattern(regex_string)\n    except InvalidSyntax as e:\n        logger.warning(f\"skip invalid regex: {regex_string}, {e=}\")\n        # Return an empty map so callers can still use `in` lookups on it\n        return {}\n\n    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)\n    regex_fsm, _ = make_deterministic_fsm(byte_fsm)\n\n    fsm_info: FSMInfo = regex_fsm.fsm_info\n\n    symbol_to_id = fsm_info.alphabet_symbol_mapping\n    id_to_symbol = {}\n    for symbol, id_ in symbol_to_id.items():\n        id_to_symbol.setdefault(id_, []).append(symbol)\n\n    transitions = fsm_info.transitions\n\n    outgoings_ct = defaultdict(int)\n    # NOTE(lsyin): Final states can terminate, so they naturally have one outgoing edge\n    for s in fsm_info.finals:\n        outgoings_ct[s] = 1\n\n    state_to_jump_forward = {}\n    for (state, id_), next_state in transitions.items():\n        if id_ == fsm_info.alphabet_anything_value:\n            # A transition on an arbitrary symbol cannot be used for jump forward\n            continue\n\n        symbols = id_to_symbol[id_]\n        for c in symbols:\n            if len(c) > 1:\n                # Skip byte-level transitions like c = \"5E\"\n                continue\n\n            outgoings_ct[state] += 1\n            if outgoings_ct[state] > 1:\n                if state in state_to_jump_forward:\n                    del state_to_jump_forward[state]\n                break\n\n            state_to_jump_forward[state] = JumpEdge(\n                symbol=c,\n                symbol_next_state=next_state,\n            )\n\n    # Process the byte-level jump forward\n    outgoings_ct = defaultdict(int)\n    for s in fsm_info.finals:\n        outgoings_ct[s] = 1\n\n    for (state, id_), next_state in transitions.items():\n        if id_ == fsm_info.alphabet_anything_value:\n            continue\n
        symbols = id_to_symbol[id_]\n        for c in symbols:\n            byte_ = None\n            if len(c) == 1 and ord(c) < 0x80:\n                # ASCII character\n                byte_ = ord(c)\n            elif len(c) > 1:\n                # FIXME: This logic is due to the leading \\x00\n                # https://github.com/outlines-dev/outlines/pull/930\n                byte_ = int(symbols[0][1:], 16)\n\n            if byte_ is not None:\n                outgoings_ct[state] += 1\n                if outgoings_ct[state] > 1:\n                    if state in state_to_jump_forward:\n                        del state_to_jump_forward[state]\n                    break\n                e = state_to_jump_forward.get(state, JumpEdge())\n                e.byte = byte_\n                e.byte_next_state = next_state\n                state_to_jump_forward[state] = e\n\n    return state_to_jump_forward\n\n\nclass OutlinesJumpForwardMap:\n    def __init__(self, regex_string):\n        self.state_to_jump_forward = init_state_to_jump_forward(regex_string)\n\n    def jump_forward_symbol(self, state):\n        jump_forward_str = \"\"\n        next_state = state\n        while state in self.state_to_jump_forward:\n            e = self.state_to_jump_forward[state]\n            if e.symbol is None:\n                break\n            jump_forward_str += e.symbol\n            next_state = e.symbol_next_state\n            state = next_state\n\n        return jump_forward_str, next_state\n\n    def jump_forward_byte(self, state):\n        if state not in self.state_to_jump_forward:\n            return None\n\n        jump_forward_bytes = []\n        next_state = None\n        while state in self.state_to_jump_forward:\n            e = self.state_to_jump_forward[state]\n            assert e.byte is not None and e.byte_next_state is not None\n            jump_forward_bytes.append((e.byte, e.byte_next_state))\n            next_state = e.byte_next_state\n            state = 
next_state\n\n        return jump_forward_bytes\n\n    def is_jump_forward_symbol_state(self, state):\n        return (\n            state in self.state_to_jump_forward\n            and self.state_to_jump_forward[state].symbol is not None\n        )\n\n\ndef test_main(regex_string):\n    jump_forward_map = OutlinesJumpForwardMap(regex_string)\n    for state, e in jump_forward_map.state_to_jump_forward.items():\n        if e.symbol is not None:\n            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)\n            print(f\"{state} -> {next_state}\", jump_forward_str)\n        bytes_ = jump_forward_map.jump_forward_byte(state)\n        print(f\"{state} -> {bytes_[-1][1]}\", [hex(b) for b, _ in bytes_])\n\n\nif __name__ == \"__main__\":\n    import outlines\n\n    outlines.caching.clear_cache()\n    test_main(r\"Google's DNS server address is \" + IP_REGEX)\n    test_main(r\"霍格沃茨特快列车|霍比特人比尔博\")\n    # 霍格: \\xe9\\x9c\\x8d \\xe6\\xa0\\xbc ...\n    # 霍比: \\xe9\\x9c\\x8d \\xe6\\xaf\\x94 ...\n\n    test_main(r\"[-+]?[0-9]+[ ]*\")\n"
  },
  {
    "path": "python/sglang/srt/constrained/reasoner_grammar_backend.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"A wrapper backend that applies grammar-guided constrained decoding only after a reasoning model has finished thinking.\"\"\"\n\nfrom typing import List, Optional, Tuple\n\nimport torch\n\nfrom .base_grammar_backend import (\n    BaseGrammarBackend,\n    BaseGrammarObject,\n    InvalidGrammarObject,\n)\n\n\nclass ReasonerGrammarObject(BaseGrammarObject):\n    def __init__(self, grammar: BaseGrammarObject, think_end_id: int):\n        super().__init__()\n        self.grammar = grammar\n        self.think_end_id = think_end_id\n        # -1: thinking has not ended yet\n        # 0:  the last token ended thinking\n        # >0: number of tokens generated after thinking ended\n        self.tokens_after_think_end = -1\n\n    def maybe_init_reasoning(self, reasoning: bool):\n        self.tokens_after_think_end = -1 if reasoning else 0\n\n    def transfer_state(self, token: int) -> None:\n        if self.tokens_after_think_end == -1 and token == self.think_end_id:\n            self.tokens_after_think_end = 0\n        elif self.tokens_after_think_end >= 0:\n            self.tokens_after_think_end += 1\n\n    def rollback_state(self):\n        if self.tokens_after_think_end == 0:\n            self.tokens_after_think_end = -1\n        elif self.tokens_after_think_end > 0:\n            self.tokens_after_think_end 
-= 1\n\n    def accept_token(self, token: int):\n        if self.tokens_after_think_end >= 0:\n            self.grammar.accept_token(token)\n        self.transfer_state(token)\n\n    def is_terminated(self):\n        return self.grammar.is_terminated()\n\n    def rollback(self, k):\n        steps_after_think = min(k, self.tokens_after_think_end)\n        if steps_after_think > 0:\n            self.grammar.rollback(steps_after_think)\n\n        for _ in range(k):\n            self.rollback_state()\n\n    def allocate_vocab_mask(\n        self, vocab_size: int, batch_size: int, device\n    ) -> torch.Tensor:\n        return self.grammar.allocate_vocab_mask(vocab_size, batch_size, device)\n\n    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:\n        if self.tokens_after_think_end >= 0:\n            self.grammar.fill_vocab_mask(vocab_mask, idx)\n\n    def move_vocab_mask(self, vocab_mask: torch.Tensor, device) -> torch.Tensor:\n        return self.grammar.move_vocab_mask(vocab_mask, device)\n\n    @property\n    def apply_vocab_mask(self):\n        return self.grammar.apply_vocab_mask\n\n    def copy(self) -> BaseGrammarObject:\n        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)\n\n    @property\n    def finished(self):\n        return self.grammar.finished\n\n    @finished.setter\n    def finished(self, finished):\n        self.grammar.finished = finished\n\n    def try_jump_forward(self, tokenizer):\n        return self.grammar.try_jump_forward(tokenizer)\n\n    def jump_forward_str_state(self, helper):\n        return self.grammar.jump_forward_str_state(helper)\n\n    def jump_and_retokenize(\n        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int\n    ):\n        return self.grammar.jump_and_retokenize(\n            old_output_ids, new_output_ids, next_state\n        )\n\n\nclass ReasonerGrammarBackend(BaseGrammarBackend):\n    def __init__(self, grammar_backend: BaseGrammarBackend, 
think_end_id):\n        super().__init__()\n        self.grammar_backend = grammar_backend\n        self.think_end_id = think_end_id\n\n    def _init_value_dispatch(\n        self, key: Tuple[str, str], reasoning: bool\n    ) -> Optional[BaseGrammarObject]:\n        ret = self.grammar_backend._init_value_dispatch(key, reasoning)\n        # avoid wrapping invalid grammar, so that the scheduler can detect it\n        if ret is None or isinstance(ret, InvalidGrammarObject):\n            return ret\n        obj = ReasonerGrammarObject(ret, self.think_end_id)\n        obj.maybe_init_reasoning(reasoning)\n        return obj\n"
  },
  {
    "path": "python/sglang/srt/constrained/triton_ops/bitmask_ops.py",
    "content": "# Adapted from\n# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py\n\nfrom typing import List, Optional, Union\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.utils import get_device_core_count\n\n\n@triton.jit\ndef apply_token_bitmask_inplace_kernel(\n    logits_ptr,\n    bitmask_ptr,\n    indices_ptr,\n    num_rows,\n    vocab_size,\n    logits_strides,\n    bitmask_strides,\n    NUM_SMS: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"Apply a bitmask to logits in-place using Triton. The bitmask is a bit-packed tensor of 0/1 flags\n    (32 tokens per int32), where 0 means the token is masked and 1 means the token is not masked.\n    After the bitmask is applied, the masked logits are set to -inf.\n\n    Parameters\n    ----------\n    logits_ptr : tl.tensor\n        Pointer to the logits tensor to apply the bitmask to.\n\n    bitmask_ptr : tl.tensor\n        Pointer to the bitmask tensor to apply.\n\n    indices_ptr : Optional[tl.tensor]\n        Optional pointer to an indices tensor specifying which rows to apply the mask to.\n\n    num_rows : int\n        Number of rows to process. If indices_ptr is provided, this is the number of indices.\n\n    vocab_size : int\n        Size of the vocabulary dimension. If the logits tensor has no vocab padding, this is the\n        same as the logits' second dimension. 
Otherwise, this is the actual size of the vocabulary.\n\n    logits_strides : int\n        Stride between rows in the logits tensor.\n\n    bitmask_strides : int\n        Stride between rows in the bitmask tensor.\n\n    NUM_SMS : int\n        Number of streaming multiprocessors to use.\n\n    BLOCK_SIZE : int\n        Size of processing blocks.\n    \"\"\"\n\n    pid = tl.program_id(0)\n    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)\n    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):\n        row_id = work_id // num_blocks\n        block_offset = (work_id % num_blocks) * BLOCK_SIZE\n        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)\n        offsets = block_offset + tl.arange(0, BLOCK_SIZE)\n        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)\n        vocab_mask = offsets < vocab_size\n        packed_bitmask_mask = bitmask_offsets < bitmask_strides\n        packed_bitmask = tl.load(\n            bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,\n            packed_bitmask_mask,\n        )\n        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0\n        bitmask = bitmask.reshape(BLOCK_SIZE)\n\n        tl.store(\n            logits_ptr + batch_id * logits_strides + offsets,\n            -float(\"inf\"),\n            vocab_mask & bitmask,\n        )\n\n\ndef apply_token_bitmask_inplace_triton(\n    logits: torch.Tensor,\n    bitmask: torch.Tensor,\n    indices: Optional[Union[List[int], torch.Tensor]] = None,\n):\n    NUM_SMS = get_device_core_count()\n    BLOCK_SIZE = 4096\n    BITS_PER_BLOCK = 32\n\n    # Check input dtype\n    assert bitmask.dtype == torch.int32, \"bitmask must be of type int32\"\n\n    # Check input tensor shapes.\n    logits_shape = logits.shape\n    bitmask_shape = bitmask.shape\n    if logits.ndim == 1:\n        logits_shape = (1, logits_shape[0])\n    if bitmask.ndim == 1:\n        bitmask_shape = (1, bitmask_shape[0])\n\n    
required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK\n    assert required_bitmask_width >= bitmask_shape[1], (\n        f\"Bitmask width too large: expected at most {required_bitmask_width} int32s for \"\n        f\"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}\"\n    )\n\n    vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)\n\n    num_rows = None\n    if isinstance(indices, (list, torch.Tensor)):\n        # as_tensor accepts both lists and tensors without the copy-construct warning torch.tensor emits for tensors\n        indices = torch.as_tensor(indices, dtype=torch.int32, device=logits.device)\n        num_rows = indices.shape[0]\n    else:\n        assert (\n            logits_shape[0] == bitmask_shape[0]\n        ), f\"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}\"\n        num_rows = logits_shape[0]\n\n    if NUM_SMS > 0:\n        grid = (NUM_SMS,)\n    else:\n        num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)\n        grid = (num_rows * num_blocks,)\n        NUM_SMS = triton.next_power_of_2(grid[0])\n\n    apply_token_bitmask_inplace_kernel[grid](\n        logits,\n        bitmask,\n        indices,\n        num_rows,\n        vocab_size,\n        logits_shape[1],\n        bitmask_shape[1],\n        NUM_SMS,\n        BLOCK_SIZE,\n        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),\n        num_stages=3,\n    )\n"
  },
  {
    "path": "python/sglang/srt/constrained/utils.py",
    "content": "from typing import Dict\n\n\ndef is_legacy_structural_tag(obj: Dict) -> bool:\n    # test whether an object is a legacy structural tag\n    # see `StructuralTagResponseFormat` at `sglang.srt.entrypoints.openai.protocol`\n    if obj.get(\"structures\", None) is not None:\n        assert obj.get(\"triggers\", None) is not None\n        return True\n    else:\n        assert obj.get(\"format\", None) is not None\n        return False\n"
  },
  {
    "path": "python/sglang/srt/constrained/xgrammar_backend.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Constrained decoding with xgrammar backend.\"\"\"\n\nimport dataclasses\nimport json\nimport logging\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom xgrammar import (\n    CompiledGrammar,\n    GrammarCompiler,\n    GrammarMatcher,\n    StructuralTagItem,\n    TokenizerInfo,\n    allocate_token_bitmask,\n)\n\nfrom sglang.srt.constrained.base_grammar_backend import (\n    BaseGrammarBackend,\n    BaseGrammarObject,\n    GrammarStats,\n    InvalidGrammarObject,\n)\nfrom sglang.srt.constrained.utils import is_legacy_structural_tag\nfrom sglang.srt.utils import is_hip\n\n_is_hip = is_hip()\nif _is_hip:\n    from sgl_kernel import apply_token_bitmask_inplace_cuda\nelse:\n    from sglang.srt.constrained.triton_ops.bitmask_ops import (\n        apply_token_bitmask_inplace_triton,\n    )\n\n\nlogger = logging.getLogger(__name__)\nMAX_ROLLBACK_TOKENS = 200\n\n\nclass XGrammarGrammar(BaseGrammarObject):\n\n    def __init__(\n        self,\n        matcher: GrammarMatcher,\n        vocab_size: int,\n        ctx: CompiledGrammar,\n        override_stop_tokens: Optional[Union[List[int], int]],\n        key_string: Optional[str] = None,  # TODO (sk): for debugging, remove later\n        grammar_stats: Optional[GrammarStats] = GrammarStats(),\n    ) -> None:\n        
super().__init__()\n        self.matcher = matcher\n        self.vocab_size = vocab_size\n        self.ctx = ctx\n        self.override_stop_tokens = override_stop_tokens\n        self.accepted_tokens = []\n        # Last token passed to accept_token; initialized so __repr__ works before any token is accepted\n        self.current_token = None\n        self.key_string = key_string\n        self.grammar_stats = grammar_stats\n\n    def accept_token(self, token: int):\n        if not self.is_terminated():\n            self.current_token = token\n            accepted = self.matcher.accept_token(token)\n            if not accepted:\n                # Include the accepted tokens and grammar key in the error for debugging\n                raise ValueError(\n                    f\"Token not accepted: {token}\\n\"\n                    f\"Accepted tokens: {self.accepted_tokens}\\n\"\n                    f\"Key string: {self.key_string}\"\n                )\n            else:\n                self.accepted_tokens.append(token)\n\n    def rollback(self, k: int):\n        self.matcher.rollback(k)\n        if k > 0:\n            # Guard k == 0: accepted_tokens[:-0] would empty the whole list\n            self.accepted_tokens = self.accepted_tokens[:-k]\n\n    def is_terminated(self):\n        return self.matcher.is_terminated()\n\n    def allocate_vocab_mask(\n        self, vocab_size: int, batch_size: int, device\n    ) -> torch.Tensor:\n        return allocate_token_bitmask(batch_size, vocab_size)\n\n    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:\n        self.matcher.fill_next_token_bitmask(vocab_mask, idx)\n\n    @staticmethod\n    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:\n        return vocab_mask.to(device, non_blocking=True)\n\n    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:\n        if (\n            logits.device.type == \"cuda\"\n            or logits.device.type == \"npu\"\n            or logits.device.type == \"xpu\"\n        ):\n            if _is_hip:\n                apply_token_bitmask_inplace_cuda(logits, vocab_mask)\n            else:\n                apply_token_bitmask_inplace_triton(logits, vocab_mask)\n        else:\n            raise RuntimeError(f\"Unsupported device: {logits.device.type}\")\n
\n    def copy(self):\n        matcher = GrammarMatcher(\n            self.ctx,\n            max_rollback_tokens=MAX_ROLLBACK_TOKENS,\n            override_stop_tokens=self.override_stop_tokens,\n        )\n        if grammar_stats := self.grammar_stats:\n            grammar_stats = dataclasses.replace(\n                grammar_stats, is_cache_hit=True, tree_traversal_time=[]\n            )\n        return XGrammarGrammar(\n            matcher,\n            self.vocab_size,\n            self.ctx,\n            self.override_stop_tokens,\n            self.key_string,\n            grammar_stats,\n        )\n\n    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:\n        s = self.matcher.find_jump_forward_string()\n        if s:\n            return [], s\n        return None\n\n    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:\n        _, data = helper\n        return data, -1\n\n    def jump_and_retokenize(\n        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int\n    ):\n        # k = length of the common prefix; zip avoids an IndexError when new_output_ids is shorter\n        k = 0\n        for i, (old_id, new_id) in enumerate(zip(old_output_ids, new_output_ids)):\n            if old_id == new_id:\n                k = i + 1\n            else:\n                break\n\n        # Roll back to the last token that is the same in both sequences\n        if k < len(old_output_ids):\n            self.matcher.rollback(len(old_output_ids) - k)\n\n        for token in new_output_ids[k:]:\n            # Call accept_token outside the assert so the matcher still advances under `python -O`\n            accepted = self.matcher.accept_token(token)\n            assert accepted, f\"Token not accepted after retokenization: {token}\"\n\n    def __repr__(self):\n        return f\"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=}, {self.current_token=})\"\n\n\nclass TokenizerNotSupportedError(Exception):\n    \"\"\"Raised when the tokenizer is not supported by the XGrammar backend.\"\"\"\n\n    pass\n\n\nclass XGrammarGrammarBackend(BaseGrammarBackend):\n    def __init__(\n        self,\n        tokenizer,\n        vocab_size: int,\n      
  model_eos_token_ids: Optional[List[int]] = None,\n        any_whitespace: bool = True,\n    ):\n        super().__init__()\n\n        if hasattr(tokenizer, \"init_xgrammar\"):\n            # For special tokenizer\n            tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()\n\n            if tokenizer_info is None:\n                # Not supported tokenizer\n                raise TokenizerNotSupportedError(\n                    f\"Tokenizer type {type(tokenizer).__name__} is not supported by XGrammar\"\n                )\n        else:\n            # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens\n            # This ensures consistency between what the model considers EOS and what XGrammar uses\n            try:\n                tokenizer_info = TokenizerInfo.from_huggingface(\n                    tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids\n                )\n                override_stop_tokens = None\n            except Exception as e:\n                raise TokenizerNotSupportedError(\n                    f\"Failed to create XGrammar TokenizerInfo from tokenizer: {e}\"\n                )\n\n        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)\n        self.vocab_size = vocab_size\n        self.override_stop_tokens = override_stop_tokens\n        self.any_whitespace = any_whitespace\n\n    @staticmethod\n    def _sanitize_structural_format(structural_format):\n        \"\"\"Recursively replace missing json_schema fields with an empty schema.\"\"\"\n        if not isinstance(structural_format, dict):\n            return\n\n        fmt_type = structural_format.get(\"type\")\n        if fmt_type in {\"json_schema\", \"qwen_xml_parameter\"}:\n            if structural_format.get(\"json_schema\") is None:\n                structural_format[\"json_schema\"] = {}\n\n        if fmt_type == \"tag\":\n            XGrammarGrammarBackend._sanitize_structural_format(\n         
       structural_format.get(\"content\")\n            )\n        elif fmt_type in {\"sequence\", \"or\"}:\n            for element in structural_format.get(\"elements\", []):\n                XGrammarGrammarBackend._sanitize_structural_format(element)\n        elif fmt_type in {\"triggered_tags\", \"tags_with_separator\"}:\n            for tag in structural_format.get(\"tags\", []):\n                XGrammarGrammarBackend._sanitize_structural_format(tag)\n\n    @staticmethod\n    def _sanitize_structural_tag_structures(structural_tag: Dict) -> None:\n        for structure in structural_tag.get(\"structures\", []):\n            if structure.get(\"schema\") is None:\n                structure[\"schema\"] = {}\n\n    def _from_context(\n        self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats\n    ) -> XGrammarGrammar:\n        matcher = GrammarMatcher(\n            ctx,\n            max_rollback_tokens=MAX_ROLLBACK_TOKENS,\n            override_stop_tokens=self.override_stop_tokens,\n        )\n        return XGrammarGrammar(\n            matcher,\n            self.vocab_size,\n            ctx,\n            self.override_stop_tokens,\n            key_string,\n            grammar_stats,\n        )\n\n    def dispatch_json(self, key_string: str) -> BaseGrammarObject:\n        try:\n            if key_string == \"$$ANY$$\":\n                # Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root)\n                ctx = self.grammar_compiler.compile_builtin_json_grammar()\n            else:\n                ctx = self.grammar_compiler.compile_json_schema(\n                    schema=key_string, any_whitespace=self.any_whitespace\n                )\n\n        except (RuntimeError, json.decoder.JSONDecodeError, UnicodeDecodeError) as e:\n            logger.error(f\"Hit invalid json_schema: {key_string=}, {e=}\")\n            return InvalidGrammarObject(str(e))\n        return self._from_context(ctx, 
key_string, GrammarStats(dispatch_type=\"json\"))\n\n    def dispatch_ebnf(self, key_string: str) -> BaseGrammarObject:\n        try:\n            ctx = self.grammar_compiler.compile_grammar(key_string)\n        except RuntimeError as e:\n            logger.error(f\"Hit invalid ebnf: {key_string=}, {e=}\")\n            return InvalidGrammarObject(str(e))\n        return self._from_context(ctx, key_string, GrammarStats(dispatch_type=\"ebnf\"))\n\n    def dispatch_regex(self, key_string: str) -> BaseGrammarObject:\n        try:\n            ctx = self.grammar_compiler.compile_regex(key_string)\n        except RuntimeError as e:\n            logger.error(f\"Hit invalid regex: {key_string=}, {e=}\")\n            return InvalidGrammarObject(str(e))\n        return self._from_context(ctx, key_string, GrammarStats(dispatch_type=\"regex\"))\n\n    def dispatch_structural_tag(self, key_string: str) -> BaseGrammarObject:\n        try:\n            # TODO(dark): it's REALLY stupid to construct object from string and decode it again\n            structural_tag = json.loads(key_string)\n            if is_legacy_structural_tag(structural_tag):\n                self._sanitize_structural_tag_structures(structural_tag)\n                tags = [\n                    StructuralTagItem(\n                        begin=structure[\"begin\"],\n                        schema=json.dumps(structure[\"schema\"]),\n                        end=structure[\"end\"],\n                    )\n                    for structure in structural_tag[\"structures\"]\n                ]\n                ctx = self.grammar_compiler.compile_structural_tag(\n                    tags, structural_tag[\"triggers\"]\n                )\n            else:\n                format_dict = structural_tag.get(\"format\")\n                if isinstance(format_dict, dict):\n                    self._sanitize_structural_format(format_dict)\n                    structural_tag[\"format\"] = format_dict\n                    
key_string = json.dumps(structural_tag)\n                ctx = self.grammar_compiler.compile_structural_tag(key_string)\n        except (RuntimeError, json.decoder.JSONDecodeError) as e:\n            logger.error(f\"Hit invalid structural_tag: {key_string=}, {e=}\")\n            return InvalidGrammarObject(str(e))\n        return self._from_context(\n            ctx, key_string, GrammarStats(dispatch_type=\"structural_tag\")\n        )\n\n    def reset(self):\n        self.grammar_compiler.clear_cache()\n\n\ndef demo_test():\n    from transformers import AutoConfig, AutoTokenizer\n\n    from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST\n\n    tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL_NAME_FOR_TEST)\n    hf_config = AutoConfig.from_pretrained(DEFAULT_MODEL_NAME_FOR_TEST)\n\n    # Should use vocab size from model config\n    vocab_size = hf_config.vocab_size\n    eos_token_id = tokenizer.eos_token_id\n\n    backend = XGrammarGrammarBackend(\n        tokenizer, vocab_size=vocab_size, model_eos_token_ids=[eos_token_id]\n    )\n    regex = r\"hello (world|there)\"\n    grammar = backend.dispatch_regex(regex)\n    tokens = [\n        tokenizer.encode(t, add_special_tokens=False)[0] for t in [\"hello\", \" world\"]\n    ]\n\n    # Test termination\n    grammar.accept_token(tokens[0])  # accept \"hello\"\n    grammar.accept_token(tokens[1])  # accept \" world\"\n    grammar.accept_token(eos_token_id)  # accept EOS\n    assert grammar.is_terminated()\n\n    # Test rollback the terminated state\n    grammar.rollback(1)\n    assert not grammar.is_terminated()\n\n\nif __name__ == \"__main__\":\n    demo_test()\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/__init__.py",
    "content": "from sglang.srt.debug_utils.comparator.aligner.entrypoint.traced_types import (  # noqa: F401\n    TracedAlignerPlan,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.entrypoint.types import (  # noqa: F401\n    AlignerPlan,\n)\nfrom sglang.srt.debug_utils.comparator.output_types import ComparisonTensorRecord\n\nComparisonTensorRecord.model_rebuild()\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/__main__.py",
    "content": "from sglang.srt.debug_utils.comparator.entrypoint import main\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/axis_aligner.py",
    "content": "from __future__ import annotations\n\nfrom typing import Optional\n\nimport torch\nfrom einops import rearrange\n\nfrom sglang.srt.debug_utils.comparator.dims_spec import (\n    _FUSED_NAME_SEP,\n    DimSpec,\n    _SingletonDimUtil,\n    parse_dims,\n)\nfrom sglang.srt.debug_utils.comparator.log_sink import log_sink\nfrom sglang.srt.debug_utils.comparator.utils import Pair, _FrozenBase\n\n# --- types ---\n\n\nclass AxisAlignerPlan(_FrozenBase):\n    pattern: Pair[Optional[str]]  # einops pattern per side, None = no-op\n\n\n# --- planner ---\n\n\ndef compute_axis_aligner_plan(\n    dims_str_pair: Pair[Optional[str]],\n) -> Optional[AxisAlignerPlan]:\n    if dims_str_pair.x is None or dims_str_pair.y is None:\n        return None\n\n    dims_pair: Pair[str] = Pair(x=dims_str_pair.x, y=dims_str_pair.y)\n    specs_pair: Pair[list[DimSpec]] = dims_pair.map(lambda s: parse_dims(s).dims)\n\n    if not _semantic_names_match(specs_pair):\n        return None\n\n    # Canonical dim order follows y; fused groups stay fused (flatten, not unflatten).\n    canonical_order: Optional[list[str]] = _build_canonical_order(specs_pair)\n    if canonical_order is None:\n        return None\n\n    pattern: Pair[Optional[str]] = specs_pair.map(\n        lambda specs: _build_side_pattern(specs=specs, canonical_order=canonical_order)\n    )\n\n    if pattern.x is None and pattern.y is None:\n        return None\n\n    return AxisAlignerPlan(pattern=pattern)\n\n\ndef _semantic_names_match(specs_pair: Pair[list[DimSpec]]) -> bool:\n    \"\"\"Check that both sides share the same semantic name set (ignoring squeeze dims).\"\"\"\n    names_pair: Pair[list[str]] = specs_pair.map(_expand_and_skip_squeeze)\n\n    if set(names_pair.x) == set(names_pair.y):\n        return True\n\n    # Local import to avoid circular dependency:\n    # output_types -> aligner/entrypoint/types -> axis_aligner -> output_types\n    from sglang.srt.debug_utils.comparator.output_types import ErrorLog\n\n   
 log_sink.add(\n        ErrorLog(\n            category=\"axis_aligner_dim_mismatch\",\n            message=(\n                f\"AxisAligner: dim name sets differ (x={names_pair.x}, y={names_pair.y}), \"\n                f\"skipping axis swap\"\n            ),\n        )\n    )\n    return False\n\n\ndef _expand_and_skip_squeeze(specs: list[DimSpec]) -> list[str]:\n    \"\"\"Expand DimSpecs to flat semantic names, skipping squeeze dims.\"\"\"\n    return [\n        name\n        for spec in specs\n        if not _SingletonDimUtil.is_squeeze(spec)\n        for name in spec.sub_dims\n    ]\n\n\ndef _build_canonical_order(specs_pair: Pair[list[DimSpec]]) -> Optional[list[str]]:\n    \"\"\"Build canonical dim order following y, preferring fused representation.\n\n    Each element is either a plain name (``\"c\"``) or a fused placeholder (``\"a___b\"``).\n    Fused groups from *either* side are merged — the separate side must flatten.\n    Squeeze dims are excluded.\n\n    Returns ``None`` if the two sides have overlapping but incompatible fused groups\n    (e.g. 
x fuses ``(a*b)`` while y fuses ``(b*c)``).\n    \"\"\"\n    # Map each sub-dim name → (placeholder, siblings) from both sides\n    fused_lookup: dict[str, tuple[str, frozenset[str]]] = {}\n    for spec in (*specs_pair.x, *specs_pair.y):\n        if spec.is_fused:\n            placeholder: str = spec.sanitized_name\n            siblings: frozenset[str] = frozenset(spec.sub_dims)\n            for sub_name in spec.sub_dims:\n                existing: Optional[tuple[str, frozenset[str]]] = fused_lookup.get(\n                    sub_name\n                )\n                if existing is not None and existing[1] != siblings:\n                    from sglang.srt.debug_utils.comparator.output_types import ErrorLog\n\n                    log_sink.add(\n                        ErrorLog(\n                            category=\"axis_aligner_fused_conflict\",\n                            message=(\n                                f\"AxisAligner: overlapping fused groups for sub-dim {sub_name!r} \"\n                                f\"({existing[0]} vs {placeholder}), skipping axis alignment\"\n                            ),\n                        )\n                    )\n                    return None\n                fused_lookup.setdefault(sub_name, (placeholder, siblings))\n\n    result: list[str] = []\n    consumed: set[str] = set()\n\n    for spec in specs_pair.y:\n        if _SingletonDimUtil.is_squeeze(spec):\n            continue\n\n        names: list[str] = spec.sub_dims\n        if any(n in consumed for n in names):\n            continue\n\n        entry: Optional[tuple[str, frozenset[str]]] = fused_lookup.get(names[0])\n        if entry is not None:\n            fused_placeholder, sibs = entry\n            result.append(fused_placeholder)\n            consumed.update(sibs)\n        else:\n            result.append(spec.name)\n            consumed.update(names)\n\n    return result\n\n\ndef _build_side_pattern(\n    *, specs: list[DimSpec], canonical_order: 
list[str]\n) -> Optional[str]:\n    \"\"\"Build an einops pattern for one side to reach ``canonical_order``.\n\n    Fused specs become their placeholder; separate specs that belong to a fused group\n    stay as individual names on the LHS and become ``(a b)`` on the RHS (einops flatten).\n    Squeeze dims (``1``) appear on the LHS but are dropped from the RHS.\n    \"\"\"\n    source_tokens: list[str] = [spec.sanitized_name for spec in specs]\n\n    # Build per-side target: replace fused placeholders with ``(a b)`` only if this side\n    # has the sub-dims as separate (non-fused) names in the source\n    fused_placeholders: set[str] = {\n        spec.sanitized_name for spec in specs if spec.is_fused\n    }\n    target_tokens: list[str] = [\n        (\n            f\"({t.replace(_FUSED_NAME_SEP, ' ')})\"\n            if _FUSED_NAME_SEP in t and t not in fused_placeholders\n            else t\n        )\n        for t in canonical_order\n    ]\n\n    if source_tokens == target_tokens:\n        return None\n\n    return f\"{' '.join(source_tokens)} -> {' '.join(target_tokens)}\"\n\n\n# --- executor ---\n\n\ndef execute_axis_aligner_plan(\n    tensor: torch.Tensor, plan: AxisAlignerPlan, *, side: str\n) -> torch.Tensor:\n    if side not in (\"x\", \"y\"):\n        raise ValueError(f\"side must be 'x' or 'y', got {side!r}\")\n\n    pattern: Optional[str] = plan.pattern.x if side == \"x\" else plan.pattern.y\n\n    if pattern is not None:\n        tensor = rearrange(tensor.rename(None), pattern)\n\n    return tensor\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/entrypoint/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/entrypoint/executor.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass, field\nfrom typing import NamedTuple, Optional\n\nimport torch\n\nfrom sglang.srt.debug_utils.comparator.aligner.axis_aligner import (\n    execute_axis_aligner_plan,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.entrypoint.traced_types import (\n    TracedAlignerPlan,\n    TracedSidePlan,\n    TracedStepPlan,\n    TracedSubPlan,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.entrypoint.types import (\n    AlignerPerStepPlan,\n    AlignerPerStepSubPlan,\n    AlignerPlan,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.reorderer.executor import (\n    execute_reorderer_plan,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.reorderer.types import ReordererPlan\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.concat_steps import (\n    execute_token_aligner_concat_steps,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.executor import (\n    execute_token_aligner,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.unsharder.executor import (\n    UnsharderResult,\n    execute_unsharder_plan,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.unsharder.types import UnsharderPlan\nfrom sglang.srt.debug_utils.comparator.output_types import (\n    ReplicatedCheckResult,\n    ShapeSnapshot,\n)\nfrom sglang.srt.debug_utils.comparator.utils import Pair\n\n\nclass StepPlansResult(NamedTuple):\n    tensors: dict[int, torch.Tensor]\n    checks: list[ReplicatedCheckResult]\n    traced_side: TracedSidePlan\n\n\nclass SubPlansResult(NamedTuple):\n    tensor: Optional[torch.Tensor]\n    checks: list[ReplicatedCheckResult]\n    snapshots: list[ShapeSnapshot]\n\n\n@dataclass(frozen=True)\nclass AlignerResult:\n    tensors: Optional[Pair[torch.Tensor]]\n    failed_side_xy: Optional[str]  # \"x\" or \"y\"; None if success\n    replicated_checks: list[ReplicatedCheckResult] = field(default_factory=list)\n    traced_plan: Optional[TracedAlignerPlan] = 
None\n\n\ndef execute_aligner_plan(\n    *,\n    tensors_pair: Pair[list[torch.Tensor]],\n    plan: AlignerPlan,\n) -> AlignerResult:\n    \"\"\"Execute unified unshard/reorder + token-align.\"\"\"\n    all_checks: list[ReplicatedCheckResult] = []\n\n    # Per-side: unshard + reorder -> dict[step, tensor]\n    result_x: StepPlansResult = _execute_step_plans(\n        tensors=tensors_pair.x, step_plans=plan.per_step_plans.x\n    )\n    all_checks.extend(result_x.checks)\n\n    result_y: StepPlansResult = _execute_step_plans(\n        tensors=tensors_pair.y, step_plans=plan.per_step_plans.y\n    )\n    all_checks.extend(result_y.checks)\n\n    traced_plan: TracedAlignerPlan = TracedAlignerPlan(\n        plan=plan,\n        per_side=Pair(x=result_x.traced_side, y=result_y.traced_side),\n    )\n\n    if not result_x.tensors or not result_y.tensors:\n        failed_side_xy: str = \"x\" if not result_x.tensors else \"y\"\n        return AlignerResult(\n            tensors=None,\n            failed_side_xy=failed_side_xy,\n            replicated_checks=all_checks,\n            traced_plan=traced_plan,\n        )\n\n    # Cross-side: token alignment (or direct extraction for single-step)\n    step_pair: Pair[dict[int, torch.Tensor]] = Pair(\n        x=result_x.tensors, y=result_y.tensors\n    )\n    combined: Pair[torch.Tensor]\n    if plan.token_aligner_mode == \"concat_steps\":\n        combined = execute_token_aligner_concat_steps(tensor_of_step_pair=step_pair)\n    elif plan.token_aligner_mode == \"smart\":\n        assert plan.token_aligner_plan is not None\n        combined = execute_token_aligner(\n            plan=plan.token_aligner_plan,\n            tensor_of_step_pair=step_pair,\n        )\n    else:\n        assert len(result_x.tensors) == 1 and len(result_y.tensors) == 1\n        combined = Pair(\n            x=list(result_x.tensors.values())[0],\n            y=list(result_y.tensors.values())[0],\n        )\n\n    # Cross-side: axis alignment (squeeze 
singletons + rearrange dim order)\n    if (aligner_plan := plan.axis_aligner_plan) is not None:\n        combined = Pair(\n            x=execute_axis_aligner_plan(tensor=combined.x, plan=aligner_plan, side=\"x\"),\n            y=execute_axis_aligner_plan(tensor=combined.y, plan=aligner_plan, side=\"y\"),\n        )\n\n    return AlignerResult(\n        tensors=combined,\n        failed_side_xy=None,\n        replicated_checks=all_checks,\n        traced_plan=traced_plan,\n    )\n\n\ndef _execute_step_plans(\n    tensors: list[torch.Tensor],\n    step_plans: list[AlignerPerStepPlan],\n) -> StepPlansResult:\n    result: dict[int, torch.Tensor] = {}\n    all_checks: list[ReplicatedCheckResult] = []\n    traced_steps: list[TracedStepPlan] = []\n\n    for step_plan in step_plans:\n        step_tensors: list[torch.Tensor] = [\n            tensors[i] for i in step_plan.input_object_indices\n        ]\n        sub_result: SubPlansResult = execute_sub_plans(\n            tensors=step_tensors, plans=step_plan.sub_plans\n        )\n        all_checks.extend(sub_result.checks)\n\n        traced_subs: list[TracedSubPlan] = [\n            TracedSubPlan(plan=sub_plan, snapshot=snapshot)\n            for sub_plan, snapshot in zip(step_plan.sub_plans, sub_result.snapshots)\n        ]\n        traced_steps.append(\n            TracedStepPlan(\n                step=step_plan.step,\n                input_object_indices=step_plan.input_object_indices,\n                sub_plans=traced_subs,\n            )\n        )\n\n        if sub_result.tensor is not None:\n            result[step_plan.step] = sub_result.tensor\n\n    return StepPlansResult(\n        tensors=result,\n        checks=all_checks,\n        traced_side=TracedSidePlan(step_plans=traced_steps),\n    )\n\n\ndef execute_sub_plans(\n    tensors: list[torch.Tensor],\n    plans: list[AlignerPerStepSubPlan],\n) -> SubPlansResult:\n    if not tensors:\n        return SubPlansResult(tensor=None, checks=[], snapshots=[])\n\n    if 
not plans:\n        if len(tensors) != 1:\n            return SubPlansResult(tensor=None, checks=[], snapshots=[])\n        return SubPlansResult(tensor=tensors[0], checks=[], snapshots=[])\n\n    current: list[torch.Tensor] = tensors\n    all_checks: list[ReplicatedCheckResult] = []\n    all_snapshots: list[ShapeSnapshot] = []\n    for plan in plans:\n        input_shapes: list[list[int]] = [list(t.shape) for t in current]\n        current, checks = execute_sub_plan(tensors=current, plan=plan)\n        output_shapes: list[list[int]] = [list(t.shape) for t in current]\n        all_checks.extend(checks)\n        all_snapshots.append(\n            ShapeSnapshot(\n                input_shapes=input_shapes,\n                output_shapes=output_shapes,\n            )\n        )\n\n    assert len(current) == 1\n    return SubPlansResult(tensor=current[0], checks=all_checks, snapshots=all_snapshots)\n\n\ndef execute_sub_plan(\n    tensors: list[torch.Tensor],\n    plan: AlignerPerStepSubPlan,\n) -> tuple[list[torch.Tensor], list[ReplicatedCheckResult]]:\n    if isinstance(plan, UnsharderPlan):\n        unsharder_result: UnsharderResult = execute_unsharder_plan(plan, tensors)\n        return unsharder_result.tensors, unsharder_result.replicated_checks\n    elif isinstance(plan, ReordererPlan):\n        return execute_reorderer_plan(plan, tensors), []\n    else:\n        raise NotImplementedError(f\"Unknown {plan=}\")\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/entrypoint/planner.py",
    "content": "from __future__ import annotations\n\nfrom typing import Any, Optional\n\nfrom sglang.srt.debug_utils.comparator.aligner.axis_aligner import (\n    AxisAlignerPlan,\n    compute_axis_aligner_plan,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.entrypoint.types import (\n    AlignerPerStepPlan,\n    AlignerPerStepSubPlan,\n    AlignerPlan,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.reorderer.planner import (\n    compute_reorderer_plans,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.types import (\n    TokenAlignerPlan,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.unsharder.parallel_info import (\n    normalize_parallel_info,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.unsharder.planner import (\n    compute_unsharder_plan,\n)\nfrom sglang.srt.debug_utils.comparator.dims_spec import (\n    DimSpec,\n    DimsSpec,\n    ParallelAxis,\n    _SingletonDimUtil,\n    parse_dims,\n)\nfrom sglang.srt.debug_utils.comparator.utils import Pair\n\n\ndef compute_aligner_plan(\n    *,\n    metas_pair: Pair[list[dict[str, Any]]],\n    token_aligner_mode: Optional[str],\n    token_aligner_plan: Optional[TokenAlignerPlan],\n    thd_seq_lens_by_step_pair: Pair[Optional[dict[int, list[int]]]] = Pair(\n        x=None, y=None\n    ),\n) -> AlignerPlan:\n    dims_str_pair: Pair[Optional[str]] = metas_pair.map(\n        lambda metas: metas[0].get(\"dims\") if metas else None\n    )\n    axis_aligner_plan: Optional[AxisAlignerPlan] = compute_axis_aligner_plan(\n        dims_str_pair=dims_str_pair\n    )\n\n    return AlignerPlan(\n        per_step_plans=Pair(\n            x=_compute_per_step_plans(\n                metas=metas_pair.x,\n                thd_seq_lens_by_step=thd_seq_lens_by_step_pair.x,\n            ),\n            y=_compute_per_step_plans(\n                metas=metas_pair.y,\n                thd_seq_lens_by_step=thd_seq_lens_by_step_pair.y,\n            ),\n        ),\n        
token_aligner_mode=token_aligner_mode,\n        token_aligner_plan=token_aligner_plan,\n        axis_aligner_plan=axis_aligner_plan,\n    )\n\n\ndef _compute_per_step_plans(\n    metas: list[dict[str, Any]],\n    *,\n    thd_seq_lens_by_step: Optional[dict[int, list[int]]] = None,\n) -> list[AlignerPerStepPlan]:\n    step_to_input_indices: dict[int, list[int]] = {}\n    for i, meta in enumerate(metas):\n        step: int = int(meta[\"step\"])\n        step_to_input_indices.setdefault(step, []).append(i)\n\n    result: list[AlignerPerStepPlan] = []\n    for step in sorted(step_to_input_indices):\n        input_indices: list[int] = step_to_input_indices[step]\n        step_metas: list[dict[str, Any]] = [metas[idx] for idx in input_indices]\n        step_seq_lens: Optional[list[int]] = (\n            thd_seq_lens_by_step.get(step) if thd_seq_lens_by_step is not None else None\n        )\n        plans: list[AlignerPerStepSubPlan] = compute_per_step_sub_plans(\n            metas=step_metas,\n            thd_global_seq_lens=step_seq_lens,\n        )\n        result.append(\n            AlignerPerStepPlan(\n                step=step, input_object_indices=input_indices, sub_plans=plans\n            )\n        )\n\n    return result\n\n\ndef compute_per_step_sub_plans(\n    metas: list[dict[str, Any]],\n    *,\n    thd_global_seq_lens: Optional[list[int]] = None,\n) -> list[AlignerPerStepSubPlan]:\n    if not metas or len(metas) == 1:\n        return []\n\n    dims_str = metas[0].get(\"dims\")\n    if dims_str is None:\n        return []\n\n    dims_spec: DimsSpec = parse_dims(dims_str)\n    dim_specs: list[DimSpec] = _SingletonDimUtil.filter_out(dims_spec.dims)\n    replicated_axes: frozenset[ParallelAxis] = dims_spec.replicated_axes\n    parallel_infos = [normalize_parallel_info(meta) for meta in metas]\n\n    unsharder_plans = compute_unsharder_plan(\n        dim_specs=dim_specs,\n        parallel_infos=parallel_infos,\n        
explicit_replicated_axes=replicated_axes,\n        thd_global_seq_lens=thd_global_seq_lens,\n    )\n    reorderer_plans = compute_reorderer_plans(\n        dim_specs=dim_specs,\n        parallel_infos=parallel_infos,\n        thd_global_seq_lens=thd_global_seq_lens,\n    )\n    return [*unsharder_plans, *reorderer_plans]\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/entrypoint/traced_types.py",
    "content": "\"\"\"Traced wrapper types that embed execution traces (ShapeSnapshots) into plan nodes.\n\nThese types are created *after* execution, pairing each sub-plan with its\nobserved shape snapshot so that downstream formatters never need to manually\nzip plan + trace by index.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Optional\n\nfrom sglang.srt.debug_utils.comparator.aligner.entrypoint.types import (\n    AlignerPerStepSubPlan,\n    AlignerPlan,\n)\nfrom sglang.srt.debug_utils.comparator.output_types import ShapeSnapshot\nfrom sglang.srt.debug_utils.comparator.utils import Pair, _StrictBase\n\n\nclass TracedSubPlan(_StrictBase):\n    plan: AlignerPerStepSubPlan\n    snapshot: Optional[ShapeSnapshot] = None\n\n\nclass TracedStepPlan(_StrictBase):\n    step: int\n    input_object_indices: list[int]\n    sub_plans: list[TracedSubPlan]\n\n\nclass TracedSidePlan(_StrictBase):\n    step_plans: list[TracedStepPlan]\n\n\nclass TracedAlignerPlan(_StrictBase):\n    plan: AlignerPlan\n    per_side: Pair[TracedSidePlan]\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/entrypoint/types.py",
    "content": "from __future__ import annotations\n\nfrom typing import Annotated, Optional, Union\n\nfrom pydantic import Discriminator\n\nfrom sglang.srt.debug_utils.comparator.aligner.axis_aligner import AxisAlignerPlan\nfrom sglang.srt.debug_utils.comparator.aligner.reorderer.types import ReordererPlan\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.types import (\n    TokenAlignerPlan,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.unsharder.types import UnsharderPlan\nfrom sglang.srt.debug_utils.comparator.utils import Pair, _FrozenBase\n\nAlignerPerStepSubPlan = Annotated[\n    Union[UnsharderPlan, ReordererPlan],\n    Discriminator(\"type\"),\n]\n\n\nclass AlignerPerStepPlan(_FrozenBase):\n    step: int\n    input_object_indices: list[int]\n    sub_plans: list[AlignerPerStepSubPlan]\n\n\nclass AlignerPlan(_FrozenBase):\n    per_step_plans: Pair[list[AlignerPerStepPlan]]\n    token_aligner_mode: Optional[str] = None  # \"concat_steps\" | \"smart\" | None\n    token_aligner_plan: Optional[TokenAlignerPlan] = None\n    axis_aligner_plan: Optional[AxisAlignerPlan] = None\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/reorderer/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/reorderer/executor.py",
    "content": "from typing import Optional\n\nimport torch\n\nfrom sglang.srt.debug_utils.comparator.aligner.reorderer.types import (\n    ReordererPlan,\n    ZigzagToNaturalParams,\n    ZigzagToNaturalThdParams,\n)\nfrom sglang.srt.debug_utils.comparator.dims_spec import (\n    resolve_dim_by_name,\n    strip_dim_names,\n)\n\n\ndef execute_reorderer_plan(\n    plan: ReordererPlan,\n    tensors: list[torch.Tensor],\n) -> list[torch.Tensor]:\n    if isinstance(plan.params, ZigzagToNaturalThdParams):\n        thd_dim: int = resolve_dim_by_name(tensors[0], plan.params.dim_name)\n        return [\n            _reorder_zigzag_to_natural_thd(\n                tensor,\n                dim=thd_dim,\n                cp_size=plan.params.cp_size,\n                seq_lens=plan.params.seq_lens,\n            )\n            for tensor in tensors\n        ]\n\n    if isinstance(plan.params, ZigzagToNaturalParams):\n        dim: int = resolve_dim_by_name(tensors[0], plan.params.dim_name)\n        return [\n            _reorder_zigzag_to_natural(tensor, dim=dim, cp_size=plan.params.cp_size)\n            for tensor in tensors\n        ]\n\n    raise ValueError(f\"Unsupported reorderer params type: {type(plan.params).__name__}\")\n\n\ndef _reorder_zigzag_to_natural_thd(\n    tensor: torch.Tensor, *, dim: int, cp_size: int, seq_lens: list[int]\n) -> torch.Tensor:\n    \"\"\"Undo CP zigzag interleaving for THD (packed-seq) format.\n\n    Each seq in seq_lens is independently reordered from zigzag to natural order\n    along the given dim.\n    \"\"\"\n    stripped: torch.Tensor = strip_dim_names(tensor)\n    names: tuple[Optional[str], ...] 
= tensor.names\n\n    split_sizes: list[int] = list(seq_lens)\n    remainder: int = stripped.shape[dim] - sum(split_sizes)\n    if remainder < 0:\n        raise ValueError(\n            f\"sum(seq_lens)={sum(split_sizes)} exceeds tensor dim size \"\n            f\"{stripped.shape[dim]} along dim={dim}\"\n        )\n    if remainder > 0:\n        split_sizes.append(remainder)\n\n    segments: list[torch.Tensor] = list(stripped.split(split_sizes, dim=dim))\n\n    reordered_segments: list[torch.Tensor] = [\n        _reorder_zigzag_to_natural(seg, dim=dim, cp_size=cp_size)\n        for seg in segments[: len(seq_lens)]\n    ]\n\n    # Tail padding — pass through unchanged\n    if remainder > 0:\n        reordered_segments.append(segments[-1])\n\n    result: torch.Tensor = torch.cat(reordered_segments, dim=dim)\n\n    if names[0] is not None:\n        result = result.refine_names(*names)\n    return result\n\n\ndef _reorder_zigzag_to_natural(\n    tensor: torch.Tensor, *, dim: int, cp_size: int\n) -> torch.Tensor:\n    \"\"\"Undo CP zigzag interleaving, restoring natural chunk order.\n\n    Generalized from Megatron-LM _undo_attention_load_balancing\n    (megatron/core/ssm/mamba_context_parallel.py:360-373).\n    \"\"\"\n    stripped: torch.Tensor = strip_dim_names(tensor)\n    names: tuple[Optional[str], ...] = tensor.names\n\n    num_chunks: int = cp_size * 2\n    chunks: tuple[torch.Tensor, ...] = stripped.chunk(num_chunks, dim=dim)\n    # Zigzag CP assigns chunks (r, num_chunks - 1 - r) to rank r, so for cp_size=2\n    # the input is [c0, c3, c1, c2]; order=[0, 2, 3, 1] restores [c0, c1, c2, c3].\n    order: list[int] = [2 * i for i in range(cp_size)] + [\n        num_chunks - 2 * i - 1 for i in range(cp_size)\n    ]\n    result: torch.Tensor = torch.cat([chunks[i] for i in order], dim=dim)\n\n    if names[0] is not None:\n        result = result.refine_names(*names)\n    return result\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/reorderer/planner.py",
    "content": "from typing import Optional\n\nfrom sglang.srt.debug_utils.comparator.aligner.reorderer.types import (\n    ReordererPlan,\n    ZigzagToNaturalParams,\n    ZigzagToNaturalThdParams,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.unsharder.types import AxisInfo\nfrom sglang.srt.debug_utils.comparator.dims_spec import (\n    SEQ_DIM_NAME,\n    TOKEN_DIM_NAME,\n    DimSpec,\n    Ordering,\n    ParallelAxis,\n)\n\n_ALLOWED_ZIGZAG_DIM_NAMES: set[str] = {SEQ_DIM_NAME, TOKEN_DIM_NAME}\n\n\ndef compute_reorderer_plans(\n    dim_specs: list[DimSpec],\n    parallel_infos: list[dict[ParallelAxis, AxisInfo]],\n    *,\n    thd_global_seq_lens: Optional[list[int]] = None,\n) -> list[ReordererPlan]:\n    plans: list[ReordererPlan] = []\n\n    for spec in dim_specs:\n        for modifier in spec.parallel_modifiers:\n            if modifier.ordering is None or modifier.ordering == Ordering.NATURAL:\n                continue\n\n            if spec.name not in _ALLOWED_ZIGZAG_DIM_NAMES:\n                raise ValueError(\n                    f\"Zigzag ordering is only supported on sequence dims \"\n                    f\"(dim name must be one of \"\n                    f\"{sorted(_ALLOWED_ZIGZAG_DIM_NAMES)}), \"\n                    f\"but got dim name {spec.name!r} in {spec}\"\n                )\n\n            if modifier.ordering != Ordering.ZIGZAG:\n                raise ValueError(\n                    f\"Unsupported ordering {modifier.ordering!r} for dim {spec.name!r}\"\n                )\n            axis_size: int = parallel_infos[0][modifier.axis].axis_size\n\n            if spec.name == TOKEN_DIM_NAME:\n                if thd_global_seq_lens is None:\n                    raise ValueError(\n                        \"thd_global_seq_lens is required for zigzag reorder on 't' dimension\"\n                    )\n                params = ZigzagToNaturalThdParams(\n                    dim_name=spec.name,\n                    cp_size=axis_size,\n                  
  seq_lens=thd_global_seq_lens,\n                )\n            elif spec.name == SEQ_DIM_NAME:\n                params = ZigzagToNaturalParams(dim_name=spec.name, cp_size=axis_size)\n            else:\n                raise ValueError(\n                    f\"Unsupported zigzag dim name {spec.name!r}, \"\n                    f\"expected one of {sorted(_ALLOWED_ZIGZAG_DIM_NAMES)}\"\n                )\n\n            plans.append(ReordererPlan(params=params))\n\n    return plans\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/reorderer/types.py",
    "content": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field\n\nfrom sglang.srt.debug_utils.comparator.utils import _FrozenBase\n\n\nclass ZigzagToNaturalParams(_FrozenBase):\n    op: Literal[\"zigzag_to_natural\"] = \"zigzag_to_natural\"\n    dim_name: str\n    cp_size: int\n\n\nclass ZigzagToNaturalThdParams(_FrozenBase):\n    op: Literal[\"zigzag_to_natural_thd\"] = \"zigzag_to_natural_thd\"\n    dim_name: str\n    cp_size: int\n    seq_lens: list[int]  # unsharded per-seq token counts, e.g. [100, 64, 92]\n\n\nReordererParams = Annotated[\n    Union[ZigzagToNaturalParams, ZigzagToNaturalThdParams],\n    Field(discriminator=\"op\"),\n]\n\n\nclass ReordererPlan(_FrozenBase):\n    type: Literal[\"reorderer\"] = \"reorderer\"\n    params: ReordererParams\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/token_aligner/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/token_aligner/concat_steps/__init__.py",
    "content": "from sglang.srt.debug_utils.comparator.aligner.token_aligner.concat_steps.executor import (\n    execute_token_aligner_concat_steps,\n)\n\n__all__ = [\n    \"execute_token_aligner_concat_steps\",\n]\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/token_aligner/concat_steps/executor.py",
    "content": "from __future__ import annotations\n\nfrom typing import Optional\n\nimport torch\n\nfrom sglang.srt.debug_utils.comparator.dims_spec import (\n    SEQ_DIM_NAME,\n    TOKEN_DIM_NAME,\n)\nfrom sglang.srt.debug_utils.comparator.utils import Pair\n\n_UNNAMED_TOKEN_DIM_FALLBACK: int = 0\n\n\ndef execute_token_aligner_concat_steps(\n    tensor_of_step_pair: Pair[dict[int, torch.Tensor]],\n) -> Pair[torch.Tensor]:\n    \"\"\"Concat all steps in order, then truncate to min(total_x, total_y) tokens.\"\"\"\n    some_tensor: torch.Tensor = next(iter(tensor_of_step_pair.x.values()))\n    token_dim: int = _resolve_token_dim(some_tensor)\n\n    concatenated: Pair[torch.Tensor] = tensor_of_step_pair.map(\n        lambda d: _concat_steps(d, dim=token_dim)\n    )\n    common: int = min(concatenated.x.shape[token_dim], concatenated.y.shape[token_dim])\n    return concatenated.map(lambda t: t.narrow(dim=token_dim, start=0, length=common))\n\n\ndef _resolve_token_dim(tensor: torch.Tensor) -> int:\n    \"\"\"Find the token/seq dim index. Falls back to dim 0 for unnamed tensors or\n    tensors without a recognised token/seq dim.\"\"\"\n    if tensor.names[0] is None:\n        return _UNNAMED_TOKEN_DIM_FALLBACK\n\n    names: tuple[Optional[str], ...] = tensor.names\n    for candidate in (TOKEN_DIM_NAME, SEQ_DIM_NAME):\n        if candidate in names:\n            return list(names).index(candidate)\n\n    return _UNNAMED_TOKEN_DIM_FALLBACK\n\n\ndef _concat_steps(tensor_of_step: dict[int, torch.Tensor], *, dim: int) -> torch.Tensor:\n    return torch.cat([tensor_of_step[s] for s in sorted(tensor_of_step)], dim=dim)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/token_aligner/concat_steps/thd_seq_lens_loader.py",
    "content": "from __future__ import annotations\n\nfrom pathlib import Path\nfrom typing import Optional\n\nimport polars as pl\n\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.aux_loader import (\n    _detect_plugin,\n    _load_and_align_aux_tensor,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.aux_plugins import (\n    _AuxFrameworkPlugin,\n)\n\n\ndef load_thd_seq_lens_only(\n    dump_path: Path, df: pl.DataFrame\n) -> Optional[dict[int, list[int]]]:\n    plugin: Optional[_AuxFrameworkPlugin] = _detect_plugin(df, dump_path=dump_path)\n    if plugin is None or not plugin.cp_sharded_names:\n        return None\n\n    non_cp_tensor_names: set[str] = (\n        set(df[\"name\"].unique().to_list()) & plugin.tensor_names\n    ) - plugin.cp_sharded_names\n    steps: list[int] = sorted(df[\"step\"].unique().to_list())\n\n    result: dict[int, list[int]] = {}\n    for step in steps:\n        step_data: dict[str, object] = {}\n        for name in non_cp_tensor_names:\n            tensor = _load_and_align_aux_tensor(\n                name=name, step=step, df=df, dump_path=dump_path, plugin=plugin\n            )\n            if tensor is not None:\n                step_data[name] = tensor\n\n        seq_lens: Optional[list[int]] = plugin.extract_global_seq_lens(step_data)\n        if seq_lens is not None:\n            result[step] = seq_lens\n\n    return result or None\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/token_aligner/entrypoint.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Literal, Optional\n\nimport polars as pl\n\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.concat_steps.thd_seq_lens_loader import (\n    load_thd_seq_lens_only,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.aux_loader import (\n    has_aux_tensors,\n    load_and_normalize_aux,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.planner import (\n    compute_token_aligner_plan,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.seq_info_builder import (\n    build_seqs_info,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.types import (\n    TokenAlignerGlobalAux,\n    TokenAlignerPlan,\n    TokenAlignerSeqsInfo,\n)\nfrom sglang.srt.debug_utils.comparator.log_sink import log_sink\nfrom sglang.srt.debug_utils.comparator.output_types import InfoLog\nfrom sglang.srt.debug_utils.comparator.utils import Pair\n\n_NONE_THD: Pair[Optional[dict[int, list[int]]]] = Pair(x=None, y=None)\n\n\nTokenAlignerMode = Literal[\"concat_steps\", \"smart\"]\n\n\n@dataclass(frozen=True)\nclass TokenAlignerResult:\n    \"\"\"Result of token aligner computation, bundling mode + plan with THD metadata.\"\"\"\n\n    mode: Optional[TokenAlignerMode]\n    plan: Optional[TokenAlignerPlan]\n    thd_seq_lens_by_step_pair: Pair[Optional[dict[int, list[int]]]]\n\n\ndef compute_maybe_token_aligner_result(\n    *,\n    dir_pair: Pair[Path],\n    dfs: Pair[pl.DataFrame],\n    token_aligner_mode: Optional[TokenAlignerMode],\n) -> TokenAlignerResult:\n    if token_aligner_mode is None:\n        return TokenAlignerResult(\n            mode=None, plan=None, thd_seq_lens_by_step_pair=_NONE_THD\n        )\n\n    if token_aligner_mode == \"concat_steps\":\n        thd_pair: Pair[Optional[dict[int, list[int]]]] = _load_thd_seq_lens_pair(\n            dir_pair=dir_pair, dfs=dfs\n 
       )\n        return TokenAlignerResult(\n            mode=\"concat_steps\", plan=None, thd_seq_lens_by_step_pair=thd_pair\n        )\n    elif token_aligner_mode == \"smart\":\n        if not (has_aux_tensors(dfs.x) and has_aux_tensors(dfs.y)):\n            log_sink.add(\n                InfoLog(\n                    category=\"aux_tensors_missing\",\n                    message=\"Aux tensors missing, skipping token alignment\",\n                )\n            )\n            return TokenAlignerResult(\n                mode=None, plan=None, thd_seq_lens_by_step_pair=_NONE_THD\n            )\n\n        return _build_smart_result(dir_pair=dir_pair, dfs=dfs)\n    else:\n        raise NotImplementedError(f\"Unknown {token_aligner_mode=}\")\n\n\ndef _build_smart_result(\n    *,\n    dir_pair: Pair[Path],\n    dfs: Pair[pl.DataFrame],\n) -> TokenAlignerResult:\n    \"\"\"Load aux tensors, build token indices, and compute the alignment plan.\"\"\"\n    aux_pair: Pair[Optional[TokenAlignerGlobalAux]] = Pair(\n        x=load_and_normalize_aux(dump_path=dir_pair.x, df=dfs.x),\n        y=load_and_normalize_aux(dump_path=dir_pair.y, df=dfs.y),\n    )\n\n    thd_seq_lens_by_step_pair: Pair[Optional[dict[int, list[int]]]] = aux_pair.map(\n        lambda aux: aux.thd_seq_lens_by_step if aux is not None else None\n    )\n\n    if aux_pair.x is None or aux_pair.y is None:\n        log_sink.add(\n            InfoLog(\n                category=\"framework_detection_failed\",\n                message=\"Framework detection failed, skipping token alignment\",\n            )\n        )\n        return TokenAlignerResult(\n            mode=None,\n            plan=None,\n            thd_seq_lens_by_step_pair=thd_seq_lens_by_step_pair,\n        )\n\n    global_aux: Pair[TokenAlignerGlobalAux] = Pair(x=aux_pair.x, y=aux_pair.y)\n\n    seqs_info: Pair[TokenAlignerSeqsInfo] = global_aux.map(build_seqs_info)\n\n    plan: Optional[TokenAlignerPlan] = compute_token_aligner_plan(\n        
seqs_info_pair=seqs_info\n    )\n    return TokenAlignerResult(\n        mode=\"smart\",\n        plan=plan,\n        thd_seq_lens_by_step_pair=thd_seq_lens_by_step_pair,\n    )\n\n\ndef _load_thd_seq_lens_pair(\n    *,\n    dir_pair: Pair[Path],\n    dfs: Pair[pl.DataFrame],\n) -> Pair[Optional[dict[int, list[int]]]]:\n    \"\"\"Load only thd_seq_lens for each side (lightweight, no full aux loading).\"\"\"\n    return Pair(\n        x=load_thd_seq_lens_only(dump_path=dir_pair.x, df=dfs.x),\n        y=load_thd_seq_lens_only(dump_path=dir_pair.y, df=dfs.y),\n    )\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/aux_loader.py",
    "content": "from __future__ import annotations\n\nfrom pathlib import Path\nfrom typing import Any, Optional\n\nimport polars as pl\nimport torch\n\nfrom sglang.srt.debug_utils.comparator.aligner.entrypoint.executor import (\n    execute_sub_plans,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.entrypoint.planner import (\n    compute_per_step_sub_plans,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.aux_plugins import (\n    AUX_NAMES,\n    _AuxFrameworkPlugin,\n    _plugins,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.types import (\n    TokenAlignerGlobalAux,\n    TokenAlignerStepAux,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.unsharder.parallel_info import (\n    normalize_parallel_info,\n)\nfrom sglang.srt.debug_utils.comparator.dims_spec import (\n    ParallelAxis,\n    TokenLayout,\n    apply_dim_names,\n    resolve_dim_names,\n)\nfrom sglang.srt.debug_utils.comparator.dp_utils import filter_to_non_empty_dp_rank\nfrom sglang.srt.debug_utils.comparator.log_sink import log_sink\nfrom sglang.srt.debug_utils.comparator.output_types import ErrorLog, InfoLog\nfrom sglang.srt.debug_utils.dump_loader import ValueWithMeta, filter_rows\n\n# re-export for existing callers\n__all__ = [\n    \"AUX_NAMES\",\n    \"has_aux_tensors\",\n    \"load_and_normalize_aux\",\n]\n\n\ndef load_and_normalize_aux(\n    dump_path: Path, df: pl.DataFrame\n) -> Optional[TokenAlignerGlobalAux]:\n    \"\"\"Bootstrap: load, unshard, and normalize auxiliary tensors for one side.\"\"\"\n    plugin: Optional[_AuxFrameworkPlugin] = _detect_plugin(df, dump_path=dump_path)\n    if plugin is None:\n        return None\n\n    available_names: set[str] = set(df[\"name\"].unique().to_list()) & plugin.all_names\n    steps: list[int] = sorted(df[\"step\"].unique().to_list())\n    tensor_names: set[str] = available_names & plugin.tensor_names\n    non_tensor_names: set[str] = available_names & plugin.non_tensor_names\n\n    steps_data: 
dict[int, dict[str, object]] = {}\n    thd_seq_lens_by_step: dict[int, list[int]] = {}\n    for step in steps:\n        step_data, thd_seq_lens = _load_step_data(\n            step=step,\n            tensor_names=tensor_names,\n            non_tensor_names=non_tensor_names,\n            df=df,\n            dump_path=dump_path,\n            plugin=plugin,\n        )\n        if step_data:\n            steps_data[step] = step_data\n        if thd_seq_lens is not None:\n            thd_seq_lens_by_step[step] = thd_seq_lens\n\n    layout: TokenLayout = plugin.detect_layout(steps_data)\n\n    step_auxs: dict[int, TokenAlignerStepAux] = {\n        step: plugin.compute_step_aux(step_data, layout=layout, step=step)\n        for step, step_data in steps_data.items()\n    }\n\n    return TokenAlignerGlobalAux(\n        step_auxs=step_auxs,\n        framework=plugin.name,\n        layout=layout,\n        thd_seq_lens_by_step=thd_seq_lens_by_step or None,\n    )\n\n\ndef has_aux_tensors(df: pl.DataFrame) -> bool:\n    \"\"\"Check if the DataFrame contains the minimum auxiliary tensors for alignment.\"\"\"\n    names: set[str] = set(df[\"name\"].unique().to_list())\n    return any(plugin.has_required_names(names) for plugin in _plugins)\n\n\ndef _detect_plugin(df: pl.DataFrame, dump_path: Path) -> Optional[_AuxFrameworkPlugin]:\n    names: set[str] = set(df[\"name\"].unique().to_list())\n\n    for plugin in _plugins:\n        if names & plugin.discriminating_names:\n            return plugin\n\n    first_row: dict = df.row(0, named=True)\n    value: ValueWithMeta = ValueWithMeta.load(dump_path / first_row[\"filename\"])\n\n    for plugin in _plugins:\n        if f\"{plugin.name}_parallel_info\" in value.meta:\n            return plugin\n\n    return None\n\n\ndef _load_step_data(\n    *,\n    step: int,\n    tensor_names: set[str],\n    non_tensor_names: set[str],\n    df: pl.DataFrame,\n    dump_path: Path,\n    plugin: _AuxFrameworkPlugin,\n) -> tuple[dict[str, object], 
Optional[list[int]]]:\n    \"\"\"Load all tensor and non-tensor aux values for a single step.\n\n    Two-pass loading: non-CP-sharded tensors first (to obtain cu_seqlens_q\n    for seq_lens), then CP-sharded tensors with seq_lens for THD unshard/reorder.\n\n    Returns (step_data, thd_global_seq_lens).\n    \"\"\"\n    result: dict[str, object] = {}\n\n    # Pass 0: non-tensor values\n    for name in non_tensor_names:\n        value = _load_non_tensor_aux(name=name, step=step, df=df, dump_path=dump_path)\n        if value is not None:\n            result[name] = value\n\n    # Pass 1: non-CP-sharded tensors (e.g. cu_seqlens_q, seq_lens)\n    non_cp_tensor_names: set[str] = tensor_names - plugin.cp_sharded_names\n    cp_tensor_names: set[str] = tensor_names & plugin.cp_sharded_names\n\n    for name in non_cp_tensor_names:\n        tensor = _load_and_align_aux_tensor(\n            name=name, step=step, df=df, dump_path=dump_path, plugin=plugin\n        )\n        if tensor is not None:\n            result[name] = tensor\n\n    # Derive global seq_lens for THD unshard (framework-specific extraction)\n    thd_global_seq_lens: Optional[list[int]] = plugin.extract_global_seq_lens(result)\n\n    # Pass 2: CP-sharded tensors (input_ids, position_ids, etc.)\n    for name in cp_tensor_names:\n        tensor = _load_and_align_aux_tensor(\n            name=name,\n            step=step,\n            df=df,\n            dump_path=dump_path,\n            plugin=plugin,\n            thd_global_seq_lens=thd_global_seq_lens,\n        )\n        if tensor is not None:\n            result[name] = tensor\n\n    return result, thd_global_seq_lens\n\n\ndef _load_non_tensor_aux(\n    *, name: str, step: int, df: pl.DataFrame, dump_path: Path\n) -> Optional[object]:\n    \"\"\"Load a non-tensor auxiliary value for a step, validating consistency across ranks.\"\"\"\n    rows = filter_rows(df, conditions={\"name\": name, \"step\": step})\n    if not rows:\n        return None\n\n    loaded: 
list[ValueWithMeta] = [\n        ValueWithMeta.load(dump_path / r[\"filename\"]) for r in rows\n    ]\n    loaded = filter_to_non_empty_dp_rank(loaded)\n\n    if len(loaded) > 1:\n        first_value = loaded[0].value\n        for i, item in enumerate(loaded[1:], start=1):\n            if item.value != first_value:\n                log_sink.add(\n                    ErrorLog(\n                        category=f\"{name}_mismatch\",\n                        message=(\n                            f\"{name} mismatch across ranks: rank 0 has {first_value}, \"\n                            f\"rank {i} has {item.value}\"\n                        ),\n                    )\n                )\n                break\n\n    return loaded[0].value\n\n\ndef _load_and_align_aux_tensor(\n    *,\n    name: str,\n    step: int,\n    df: pl.DataFrame,\n    dump_path: Path,\n    plugin: _AuxFrameworkPlugin,\n    thd_global_seq_lens: Optional[list[int]] = None,\n) -> Optional[torch.Tensor]:\n    \"\"\"Load an auxiliary tensor for (name, step), align if needed.\"\"\"\n    rows = filter_rows(df, conditions={\"name\": name, \"step\": step})\n    if not rows:\n        return None\n\n    loaded: list[ValueWithMeta] = [\n        ValueWithMeta.load(dump_path / r[\"filename\"]) for r in rows\n    ]\n    loaded = filter_to_non_empty_dp_rank(loaded)\n\n    tensors: list[torch.Tensor] = [\n        item.value for item in loaded if isinstance(item.value, torch.Tensor)\n    ]\n    if not tensors:\n        return None\n\n    if len(tensors) == 1:\n        return tensors[0]\n\n    metas: list[dict[str, Any]] = [item.meta for item in loaded]\n    metas = _ensure_dims_in_metas(\n        name=name, plugin=plugin, metas=metas, ndim=tensors[0].ndim\n    )\n\n    sub_plans = compute_per_step_sub_plans(\n        metas=metas,\n        thd_global_seq_lens=(\n            thd_global_seq_lens if name in plugin.cp_sharded_names else None\n        ),\n    )\n    if sub_plans:\n        dims_str: Optional[str] = 
metas[0].get(\"dims\")\n        if dims_str is not None:\n            dim_names: list[str] = resolve_dim_names(dims_str)\n            tensors = [apply_dim_names(t, dim_names) for t in tensors]\n\n        sub_result = execute_sub_plans(tensors=tensors, plans=sub_plans)\n        assert sub_result.tensor is not None\n        return sub_result.tensor.rename(\n            None\n        )  # strip named dims before returning to plugin\n\n    log_sink.add(\n        InfoLog(\n            category=\"aux_no_dims\",\n            message=(\n                f\"aux tensor '{name}' has {len(tensors)} ranks \"\n                f\"but no dims metadata, using rank 0 only\"\n            ),\n        )\n    )\n    return tensors[0]\n\n\ndef _ensure_dims_in_metas(\n    *,\n    name: str,\n    plugin: _AuxFrameworkPlugin,\n    metas: list[dict[str, Any]],\n    ndim: int,\n) -> list[dict[str, Any]]:\n    \"\"\"Inject inferred dims into metas if not already present.\n\n    Returns metas unchanged if dims is already set, or a new list with dims\n    injected if inference succeeds for CP-sharded tensors.\n    \"\"\"\n    if metas[0].get(\"dims\") is not None:\n        return metas\n\n    parallel_infos = [normalize_parallel_info(m) for m in metas]\n    has_cp: bool = any(ParallelAxis.CP in info for info in parallel_infos)\n    if not has_cp:\n        return metas\n\n    if name in plugin.cp_sharded_names:\n        inferred_dims: str = plugin.infer_cp_sharded_dims(name=name, ndim=ndim)\n        return [{**m, \"dims\": inferred_dims} for m in metas]\n\n    return metas\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/aux_plugins.py",
    "content": "from __future__ import annotations\n\nfrom abc import ABC, abstractmethod\nfrom typing import Optional\n\nimport torch\n\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.types import (\n    PositionalSeqId,\n    SeqId,\n    SGLangSeqId,\n    TokenAlignerStepAux,\n)\nfrom sglang.srt.debug_utils.comparator.dims_spec import TokenLayout\nfrom sglang.srt.debug_utils.comparator.log_sink import log_sink\nfrom sglang.srt.debug_utils.comparator.output_types import InfoLog\n\n# ── plugin ABC ─────────────────────────────────────────────────────\n\n\nclass _AuxFrameworkPlugin(ABC):\n    @property\n    @abstractmethod\n    def name(self) -> str: ...\n\n    @property\n    @abstractmethod\n    def tensor_names(self) -> frozenset[str]: ...\n\n    @property\n    @abstractmethod\n    def non_tensor_names(self) -> frozenset[str]: ...\n\n    @property\n    def cp_sharded_names(self) -> frozenset[str]:\n        return frozenset()\n\n    @property\n    def discriminating_names(self) -> frozenset[str]:\n        \"\"\"Field names unique to this framework (excluding shared names like input_ids).\"\"\"\n        return frozenset()\n\n    @abstractmethod\n    def detect_layout(self, raw: dict[int, dict[str, object]]) -> TokenLayout: ...\n\n    @abstractmethod\n    def compute_step_aux(\n        self, step_data: dict[str, object], *, layout: TokenLayout, step: int\n    ) -> TokenAlignerStepAux: ...\n\n    @abstractmethod\n    def has_required_names(self, names: set[str]) -> bool:\n        \"\"\"Whether the minimum set of aux names needed for alignment is present.\"\"\"\n        ...\n\n    @property\n    def all_names(self) -> frozenset[str]:\n        return self.tensor_names | self.non_tensor_names\n\n    def extract_global_seq_lens(\n        self, step_data: dict[str, object]\n    ) -> Optional[list[int]]:\n        \"\"\"Extract per-seq token counts from loaded step data.\n\n        Returns None if this framework doesn't support THD / no relevant data 
available.\n        \"\"\"\n        return None\n\n    def infer_cp_sharded_dims(self, name: str, ndim: int) -> str:\n        \"\"\"Infer dims string for a CP-sharded aux tensor based on its ndim.\"\"\"\n        raise NotImplementedError(\n            f\"infer_cp_sharded_dims not implemented for {type(self).__name__}\"\n        )\n\n\n# ── sglang plugin ─────────────────────────────────────────────────\n\n\nclass _SGLangPlugin(_AuxFrameworkPlugin):\n    @property\n    def name(self) -> str:\n        return \"sglang\"\n\n    @property\n    def tensor_names(self) -> frozenset[str]:\n        return frozenset({\"input_ids\", \"positions\", \"seq_lens\", \"req_pool_indices\"})\n\n    @property\n    def non_tensor_names(self) -> frozenset[str]:\n        return frozenset({\"rids\"})\n\n    @property\n    def cp_sharded_names(self) -> frozenset[str]:\n        return frozenset({\"input_ids\", \"positions\"})\n\n    @property\n    def discriminating_names(self) -> frozenset[str]:\n        return frozenset({\"seq_lens\", \"positions\", \"req_pool_indices\", \"rids\"})\n\n    def has_required_names(self, names: set[str]) -> bool:\n        return \"input_ids\" in names and \"seq_lens\" in names\n\n    def detect_layout(self, raw: dict[int, dict[str, object]]) -> TokenLayout:\n        return TokenLayout.T\n\n    def extract_global_seq_lens(\n        self, step_data: dict[str, object]\n    ) -> Optional[list[int]]:\n        if not self.cp_sharded_names:\n            return None\n\n        seq_lens = step_data.get(\"seq_lens\")\n        if not isinstance(seq_lens, torch.Tensor):\n            return None\n\n        return seq_lens.tolist()\n\n    def infer_cp_sharded_dims(self, name: str, ndim: int) -> str:\n        \"\"\"Infer dims for CP-sharded aux tensors.\n\n        NOTE: assumes zigzag ordering — natural-order CP without explicit dims\n        will be mishandled. 
Callers should set dims explicitly for non-zigzag CP.\n        \"\"\"\n        if ndim == 1:\n            return \"t[cp:zigzag]\"\n        raise ValueError(\n            f\"SGLang: cannot infer dims for CP-sharded '{name}' with ndim={ndim}\"\n        )\n\n    def compute_step_aux(\n        self, step_data: dict[str, object], *, layout: TokenLayout, step: int\n    ) -> TokenAlignerStepAux:\n        input_ids = step_data[\"input_ids\"]\n        positions = step_data[\"positions\"]\n        seq_lens = step_data[\"seq_lens\"]\n        rids_raw = step_data.get(\"rids\")\n\n        assert isinstance(\n            input_ids, torch.Tensor\n        ), f\"input_ids: expected Tensor, got {type(input_ids)}\"\n        assert isinstance(\n            positions, torch.Tensor\n        ), f\"positions: expected Tensor, got {type(positions)}\"\n        assert isinstance(\n            seq_lens, torch.Tensor\n        ), f\"seq_lens: expected Tensor, got {type(seq_lens)}\"\n\n        seq_lens_list: list[int] = seq_lens.tolist()\n        num_seqs: int = len(seq_lens_list)\n\n        seq_ids: list[SeqId]\n        if rids_raw is not None and isinstance(rids_raw, (list, tuple)):\n            seq_ids = [SGLangSeqId(rid=str(r)) for r in rids_raw]\n        else:\n            seq_ids = [PositionalSeqId(step=step, seq_index=i) for i in range(num_seqs)]\n\n        return TokenAlignerStepAux(\n            input_ids=input_ids.tolist(),\n            positions=positions.tolist(),\n            seq_lens=seq_lens_list,\n            seq_ids=seq_ids,\n        )\n\n\n# ── megatron plugin ───────────────────────────────────────────────\n\n\nclass _MegatronPlugin(_AuxFrameworkPlugin):\n    @property\n    def name(self) -> str:\n        return \"megatron\"\n\n    @property\n    def tensor_names(self) -> frozenset[str]:\n        return frozenset({\"input_ids\", \"position_ids\", \"cu_seqlens_q\", \"cu_seqlens_kv\"})\n\n    @property\n    def non_tensor_names(self) -> frozenset[str]:\n        return 
frozenset({\"qkv_format\"})\n\n    @property\n    def cp_sharded_names(self) -> frozenset[str]:\n        return frozenset({\"input_ids\", \"position_ids\"})\n\n    @property\n    def discriminating_names(self) -> frozenset[str]:\n        return frozenset({\"cu_seqlens_q\", \"cu_seqlens_kv\", \"qkv_format\"})\n\n    def has_required_names(self, names: set[str]) -> bool:\n        return \"input_ids\" in names\n\n    def extract_global_seq_lens(\n        self, step_data: dict[str, object]\n    ) -> Optional[list[int]]:\n        if not self.cp_sharded_names:\n            return None\n\n        cu_seqlens_q = step_data.get(\"cu_seqlens_q\")\n        if not isinstance(cu_seqlens_q, torch.Tensor):\n            return None\n\n        return (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist()\n\n    def infer_cp_sharded_dims(self, name: str, ndim: int) -> str:\n        \"\"\"Infer dims for CP-sharded aux tensors.\n\n        NOTE: assumes zigzag ordering — natural-order CP without explicit dims\n        will be mishandled. 
Callers should set dims explicitly for non-zigzag CP.\n        \"\"\"\n        if ndim == 1:\n            return \"t[cp:zigzag]\"\n        if ndim == 2:\n            return \"b s[cp:zigzag]\"\n        raise ValueError(\n            f\"Megatron: cannot infer dims for CP-sharded '{name}' with ndim={ndim}\"\n        )\n\n    def detect_layout(self, raw: dict[int, dict[str, object]]) -> TokenLayout:\n        for step_data in raw.values():\n            if (qkv_format := step_data.get(\"qkv_format\")) is not None:\n                fmt = qkv_format if isinstance(qkv_format, str) else str(qkv_format)\n                if \"bshd\" in fmt.lower():\n                    return TokenLayout.BS\n                return TokenLayout.T\n\n            input_ids = step_data.get(\"input_ids\")\n            if isinstance(input_ids, torch.Tensor) and input_ids.ndim == 2:\n                return TokenLayout.BS\n\n        log_sink.add(\n            InfoLog(\n                category=\"layout_detection_fallback\",\n                message=(\n                    \"Megatron layout detection: no qkv_format or 2D input_ids found, \"\n                    \"falling back to T\"\n                ),\n            )\n        )\n        return TokenLayout.T\n\n    def compute_step_aux(\n        self, step_data: dict[str, object], *, layout: TokenLayout, step: int\n    ) -> TokenAlignerStepAux:\n        input_ids: torch.Tensor = step_data[\"input_ids\"]\n        is_bshd: bool = layout == TokenLayout.BS\n\n        # BSHD [B, S] → flat [B*S]; THD [T] stays as-is\n        flat_ids: list[int] = input_ids.reshape(-1).tolist()\n\n        if (cu_seqlens_q := step_data.get(\"cu_seqlens_q\")) is not None:\n            seq_lens_list: list[int] = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist()\n        elif is_bshd:\n            seq_lens_list = [input_ids.shape[1]] * input_ids.shape[0]\n        else:\n            seq_lens_list = [input_ids.shape[0]]\n\n        if (position_ids := step_data.get(\"position_ids\")) is 
not None:\n            flat_positions: list[int] = position_ids.reshape(-1).tolist()\n        elif is_bshd:\n            flat_positions = list(range(input_ids.shape[1])) * input_ids.shape[0]\n        else:\n            flat_positions = _infer_positions(\n                seq_lens=torch.tensor(seq_lens_list)\n            ).tolist()\n\n        num_seqs: int = len(seq_lens_list)\n        seq_ids: list[SeqId] = [\n            PositionalSeqId(step=step, seq_index=seq_index)\n            for seq_index in range(num_seqs)\n        ]\n\n        return TokenAlignerStepAux(\n            input_ids=flat_ids,\n            positions=flat_positions,\n            seq_lens=seq_lens_list,\n            seq_ids=seq_ids,\n        )\n\n\n# ── plugin registry ───────────────────────────────────────────────\n\n_plugins: list[_AuxFrameworkPlugin] = [_SGLangPlugin(), _MegatronPlugin()]\n\nAUX_NAMES: frozenset[str] = frozenset().union(*(p.all_names for p in _plugins))\n\n\n# ── helpers ────────────────────────────────────────────────────────\n\n\ndef _infer_positions(*, seq_lens: torch.Tensor) -> torch.Tensor:\n    \"\"\"Infer positions when position_ids is missing (THD only).\"\"\"\n    return torch.cat([torch.arange(int(slen.item())) for slen in seq_lens])\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/executor.py",
    "content": "from __future__ import annotations\n\nimport torch\nfrom einops import rearrange\n\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.types import (\n    TokenAlignerPlan,\n    TokenLocator,\n)\nfrom sglang.srt.debug_utils.comparator.dims_spec import (\n    BATCH_DIM_NAME,\n    SEQ_DIM_NAME,\n    TOKEN_DIM_NAME,\n    TokenLayout,\n    resolve_dim_by_name,\n    strip_dim_names,\n)\nfrom sglang.srt.debug_utils.comparator.utils import Pair\n\n_UNNAMED_TOKEN_DIM_FALLBACK: int = 0\n\n\ndef execute_token_aligner(\n    plan: TokenAlignerPlan,\n    tensor_of_step_pair: Pair[dict[int, torch.Tensor]],\n) -> Pair[torch.Tensor]:\n    flat_pair: Pair[dict[int, torch.Tensor]] = Pair(\n        x=_collapse_bs_to_t(\n            tensor_of_step=tensor_of_step_pair.x, layout=plan.layouts.x\n        ),\n        y=_collapse_bs_to_t(\n            tensor_of_step=tensor_of_step_pair.y, layout=plan.layouts.y\n        ),\n    )\n\n    if not plan.locators.x.steps:\n        return Pair(\n            x=_make_empty(tensor_of_step=flat_pair.x),\n            y=_make_empty(tensor_of_step=flat_pair.y),\n        )\n\n    return Pair(\n        x=_extract_and_stack_tokens(\n            tensor_of_step=flat_pair.x, locator=plan.locators.x\n        ),\n        y=_extract_and_stack_tokens(\n            tensor_of_step=flat_pair.y, locator=plan.locators.y\n        ),\n    )\n\n\n# ── BS → T preprocessing ─────────────────────────────────────────\n\n\ndef _collapse_bs_to_t(\n    *,\n    tensor_of_step: dict[int, torch.Tensor],\n    layout: TokenLayout,\n) -> dict[int, torch.Tensor]:\n    \"\"\"Collapse B and S dims into a single flat token dim (always batch-major).\n\n    Handles both ``b s`` and ``s b`` orderings correctly via einops rearrange.\n    Returns the original tensors unchanged if layout is T.\n    \"\"\"\n    if layout != TokenLayout.BS:\n        return tensor_of_step\n\n    some_tensor: torch.Tensor = next(iter(tensor_of_step.values()))\n    batch_dim: int = 
_resolve_dim_or_fallback(some_tensor, BATCH_DIM_NAME)\n    seq_dim: int = _resolve_dim_or_fallback(some_tensor, SEQ_DIM_NAME)\n\n    if abs(batch_dim - seq_dim) != 1:\n        raise ValueError(\n            f\"BS dims must be adjacent: \"\n            f\"{BATCH_DIM_NAME}={batch_dim}, \"\n            f\"{SEQ_DIM_NAME}={seq_dim}\"\n        )\n\n    lhs_pattern, rhs_pattern, new_names = _build_bs_collapse_pattern(\n        names=list(some_tensor.names),\n        batch_dim=batch_dim,\n        seq_dim=seq_dim,\n    )\n\n    result: dict[int, torch.Tensor] = {}\n    for step, tensor in tensor_of_step.items():\n        plain: torch.Tensor = strip_dim_names(tensor)\n        collapsed: torch.Tensor = rearrange(plain, f\"{lhs_pattern} -> {rhs_pattern}\")\n        result[step] = collapsed.refine_names(*new_names)\n\n    return result\n\n\ndef _build_bs_collapse_pattern(\n    *,\n    names: list[str | None],\n    batch_dim: int,\n    seq_dim: int,\n) -> tuple[str, str, list[str | None]]:\n    \"\"\"Build einops lhs/rhs patterns and output dim names for BS→T collapse.\n\n    Always produces batch-major order ``(b s)`` regardless of input ordering.\n    Uses the tensor's own dim names as einops axis names.\n    \"\"\"\n    lo: int = min(batch_dim, seq_dim)\n    hi: int = max(batch_dim, seq_dim)\n\n    lhs: str = \" \".join(names)  # type: ignore[arg-type]\n\n    rhs_names: list[str] = list(names[:lo]) + [f\"({BATCH_DIM_NAME} {SEQ_DIM_NAME})\"] + list(names[hi + 1 :])  # type: ignore[misc]\n    rhs: str = \" \".join(rhs_names)\n\n    new_names: list[str | None] = (\n        list(names[:lo]) + [TOKEN_DIM_NAME] + list(names[hi + 1 :])\n    )\n\n    return lhs, rhs, new_names\n\n\n# ── core logic (T layout only) ───────────────────────────────────\n\n\ndef _resolve_dim_or_fallback(tensor: torch.Tensor, name: str) -> int:\n    if tensor.names[0] is None:\n        return _UNNAMED_TOKEN_DIM_FALLBACK\n    return resolve_dim_by_name(tensor, name)\n\n\ndef _make_empty(*, tensor_of_step: 
dict[int, torch.Tensor]) -> torch.Tensor:\n    dummy: torch.Tensor = next(iter(tensor_of_step.values()))\n    token_dim: int = _resolve_dim_or_fallback(dummy, TOKEN_DIM_NAME)\n    shape: list[int] = list(dummy.shape)\n    shape[token_dim] = 0\n    return torch.empty(shape, dtype=dummy.dtype, device=dummy.device)\n\n\ndef _extract_and_stack_tokens(\n    *,\n    tensor_of_step: dict[int, torch.Tensor],\n    locator: TokenLocator,\n) -> torch.Tensor:\n    some_tensor: torch.Tensor = next(iter(tensor_of_step.values()))\n    token_dim: int = _resolve_dim_or_fallback(some_tensor, TOKEN_DIM_NAME)\n\n    tokens: list[torch.Tensor] = [\n        strip_dim_names(tensor_of_step[s]).select(dim=token_dim, index=i)\n        for s, i in zip(locator.steps, locator.token_index_in_step)\n    ]\n    return torch.stack(tokens, dim=token_dim)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/planner.py",
    "content": "from __future__ import annotations\n\nfrom collections import defaultdict\nfrom typing import NamedTuple, Optional\n\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.types import (\n    SeqId,\n    TokenAlignerPlan,\n    TokenAlignerSeqInfo,\n    TokenAlignerSeqsInfo,\n    TokenLocator,\n)\nfrom sglang.srt.debug_utils.comparator.utils import Pair\n\n\ndef compute_token_aligner_plan(\n    seqs_info_pair: Pair[TokenAlignerSeqsInfo],\n) -> TokenAlignerPlan:\n    \"\"\"Compute a token alignment plan from the per-side (x, y) sequence infos.\"\"\"\n    matched_pairs: list[tuple[SeqId, SeqId]] = _match_sequences(\n        seqs=Pair(x=seqs_info_pair.x.sequences, y=seqs_info_pair.y.sequences)\n    )\n\n    _empty = TokenLocator(steps=[], token_index_in_step=[])\n    locator_x: TokenLocator = _empty\n    locator_y: TokenLocator = _empty\n\n    for seq_id_x, seq_id_y in matched_pairs:\n        rec: Pair[TokenAlignerSeqInfo] = Pair(\n            x=seqs_info_pair.x.sequences[seq_id_x],\n            y=seqs_info_pair.y.sequences[seq_id_y],\n        )\n\n        # positions is validated to be [0, 1, ..., N-1], so position == index\n        # and the common range is simply [0, min(len_x, len_y)).\n        common_len: int = min(len(rec.x.positions), len(rec.y.positions))\n\n        x_ids = rec.x.input_ids[:common_len]\n        y_ids = rec.y.input_ids[:common_len]\n        assert x_ids == y_ids, f\"{seq_id_x=} {seq_id_y=} {x_ids=} {y_ids=}\"\n\n        locator_x = locator_x + TokenLocator(\n            steps=rec.x.locator.steps[:common_len],\n            token_index_in_step=rec.x.locator.token_index_in_step[:common_len],\n        )\n        locator_y = locator_y + TokenLocator(\n            steps=rec.y.locator.steps[:common_len],\n            token_index_in_step=rec.y.locator.token_index_in_step[:common_len],\n        )\n\n    return TokenAlignerPlan(\n        locators=Pair(x=locator_x, y=locator_y),\n        layouts=seqs_info_pair.map(lambda s: 
s.layout),\n    )\n\n\n# -------------------- Sequence matcher --------------------\n\n\ndef _match_sequences(\n    seqs: Pair[dict[SeqId, TokenAlignerSeqInfo]],\n) -> list[tuple[SeqId, SeqId]]:\n    \"\"\"For each y (target) sequence, find a matching x (baseline) sequence.\n\n    Two-pass: exact match first, then prefix match for remaining.\n    \"\"\"\n    x_lookup: dict[tuple[int, ...], list[SeqId]] = defaultdict(list)\n    for seq_id, rec in seqs.x.items():\n        x_lookup[tuple(rec.input_ids)].append(seq_id)\n\n    claimed_x_ids: set[SeqId] = set()\n    matched_seq_id_pairs: list[tuple[SeqId, SeqId]] = []\n\n    for seq_id_y in sorted(seqs.y.keys()):\n        seq_y: TokenAlignerSeqInfo = seqs.y[seq_id_y]\n\n        matched_x: Optional[SeqId] = _find_matching_x_exact(\n            seq_y=seq_y, x_lookup=x_lookup, claimed_x_ids=claimed_x_ids\n        )\n        if matched_x is None:\n            matched_x = _find_matching_x_prefix(\n                seq_y=seq_y, x_seqs=seqs.x, claimed_x_ids=claimed_x_ids\n            )\n\n        if matched_x is not None:\n            matched_seq_id_pairs.append((matched_x, seq_id_y))\n            claimed_x_ids.add(matched_x)\n\n    return matched_seq_id_pairs\n\n\ndef _find_matching_x_exact(\n    *,\n    seq_y: TokenAlignerSeqInfo,\n    x_lookup: dict[tuple[int, ...], list[SeqId]],\n    claimed_x_ids: set[SeqId],\n) -> Optional[SeqId]:\n    \"\"\"Find an x sequence with identical input_ids.\"\"\"\n    ids_y_key: tuple[int, ...] 
= tuple(seq_y.input_ids)\n    candidates: list[SeqId] = x_lookup.get(ids_y_key, [])\n    for candidate in candidates:\n        if candidate not in claimed_x_ids:\n            return candidate\n    return None\n\n\nclass _PrefixCandidate(NamedTuple):\n    seq_id_x: SeqId\n    overlap_len: int\n\n\ndef _find_matching_x_prefix(\n    *,\n    seq_y: TokenAlignerSeqInfo,\n    x_seqs: dict[SeqId, TokenAlignerSeqInfo],\n    claimed_x_ids: set[SeqId],\n) -> Optional[SeqId]:\n    \"\"\"Find the x sequence with the longest prefix relationship to y.\"\"\"\n    ids_y: list[int] = seq_y.input_ids\n    candidates: list[_PrefixCandidate] = [\n        _PrefixCandidate(\n            seq_id_x=seq_id_x, overlap_len=min(len(seq_x.input_ids), len(ids_y))\n        )\n        for seq_id_x, seq_x in x_seqs.items()\n        if seq_id_x not in claimed_x_ids and _is_prefix_pair(seq_x.input_ids, ids_y)\n    ]\n    if not candidates:\n        return None\n    return max(candidates, key=lambda c: c.overlap_len).seq_id_x\n\n\ndef _is_prefix_pair(a: list[int], b: list[int]) -> bool:\n    \"\"\"True if a is a prefix of b, or b is a prefix of a.\"\"\"\n    shorter_len: int = min(len(a), len(b))\n    return a[:shorter_len] == b[:shorter_len]\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/seq_info_builder.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass, field\n\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.types import (\n    SeqId,\n    TokenAlignerGlobalAux,\n    TokenAlignerSeqInfo,\n    TokenAlignerSeqsInfo,\n    TokenAlignerStepAux,\n    TokenLocator,\n)\n\n\n@dataclass\nclass _SeqInfoAccumulator:\n    \"\"\"Mutable accumulator for building TokenAlignerSeqInfo without per-step validation.\"\"\"\n\n    input_ids: list[int] = field(default_factory=list)\n    positions: list[int] = field(default_factory=list)\n    steps: list[int] = field(default_factory=list)\n    token_index_in_step: list[int] = field(default_factory=list)\n\n    def extend(\n        self,\n        *,\n        input_ids: list[int],\n        positions: list[int],\n        steps: list[int],\n        token_index_in_step: list[int],\n    ) -> None:\n        self.input_ids.extend(input_ids)\n        self.positions.extend(positions)\n        self.steps.extend(steps)\n        self.token_index_in_step.extend(token_index_in_step)\n\n    def build(self) -> TokenAlignerSeqInfo:\n        return TokenAlignerSeqInfo(\n            input_ids=self.input_ids,\n            positions=self.positions,\n            locator=TokenLocator(\n                steps=self.steps,\n                token_index_in_step=self.token_index_in_step,\n            ),\n        )\n\n\ndef build_seqs_info(global_aux: TokenAlignerGlobalAux) -> TokenAlignerSeqsInfo:\n    \"\"\"Build sequence info for one side from its auxiliary tensors.\"\"\"\n    return TokenAlignerSeqsInfo(\n        sequences=_build_token_aligner_seq_infos(global_aux),\n        layout=global_aux.layout,\n    )\n\n\ndef _build_token_aligner_seq_infos(\n    global_aux: TokenAlignerGlobalAux,\n) -> dict[SeqId, TokenAlignerSeqInfo]:\n    \"\"\"Build token index for any framework/layout using seq_ids for identity tracking.\"\"\"\n    accum: dict[SeqId, _SeqInfoAccumulator] = {}\n\n    for step in 
sorted(global_aux.step_auxs.keys()):\n        aux: TokenAlignerStepAux = global_aux.step_auxs[step]\n\n        offset: int = 0\n        for seq_index, seq_len in enumerate(aux.seq_lens):\n            seq_id: SeqId = aux.seq_ids[seq_index]\n\n            if seq_id not in accum:\n                accum[seq_id] = _SeqInfoAccumulator()\n\n            accum[seq_id].extend(\n                input_ids=aux.input_ids[offset : offset + seq_len],\n                positions=aux.positions[offset : offset + seq_len],\n                steps=[step] * seq_len,\n                token_index_in_step=list(range(offset, offset + seq_len)),\n            )\n\n            offset += seq_len\n\n    return {seq_id: acc.build() for seq_id, acc in accum.items()}\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/types.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass, field\nfrom typing import NamedTuple, Optional, Union\n\nfrom pydantic import model_validator\n\nfrom sglang.srt.debug_utils.comparator.dims_spec import TokenLayout\nfrom sglang.srt.debug_utils.comparator.utils import (\n    Pair,\n    _check_equal_lengths,\n    _FrozenBase,\n)\n\n\nclass SGLangSeqId(NamedTuple):\n    rid: str\n\n\nclass PositionalSeqId(NamedTuple):\n    step: int\n    seq_index: int\n\n\nSeqId = Union[SGLangSeqId, PositionalSeqId]\n\n\n@dataclass(frozen=True)\nclass TokenAlignerStepAux:\n    \"\"\"Normalized auxiliary tensors for a single step (framework-agnostic).\"\"\"\n\n    input_ids: list[int]  # [num_tokens]\n    positions: list[int]  # [num_tokens]\n    seq_lens: list[int]  # [num_seqs]\n    seq_ids: list[SeqId]  # [num_seqs] — sequence identity\n\n    def __post_init__(self) -> None:\n        _check_equal_lengths(input_ids=self.input_ids, positions=self.positions)\n        _check_equal_lengths(seq_lens=self.seq_lens, seq_ids=self.seq_ids)\n\n        token_count: int = sum(self.seq_lens)\n        if token_count != len(self.input_ids):\n            raise ValueError(\n                f\"sum(seq_lens)={token_count} != len(input_ids)={len(self.input_ids)}\"\n            )\n\n\n@dataclass(frozen=True)\nclass TokenAlignerGlobalAux:\n    \"\"\"Auxiliary tensors for one side across all steps + side-level metadata.\"\"\"\n\n    step_auxs: dict[int, TokenAlignerStepAux]\n    framework: str  # \"sglang\" | \"megatron\"\n    layout: TokenLayout\n    thd_seq_lens_by_step: Optional[dict[int, list[int]]] = field(default=None)\n\n\nclass TokenLocator(_FrozenBase):\n    \"\"\"Locates tokens within a multi-step tensor store.\n\n    token i is at tensor_of_step[steps[i]][token_index_in_step[i]].\n    \"\"\"\n\n    steps: list[int]\n    token_index_in_step: list[int]\n\n    def __add__(self, other: TokenLocator) -> TokenLocator:\n        return TokenLocator(\n            
steps=self.steps + other.steps,\n            token_index_in_step=self.token_index_in_step + other.token_index_in_step,\n        )\n\n\nclass TokenAlignerSeqInfo(_FrozenBase):\n    \"\"\"Information for a sequence, containing information to locate all the tokens inside the sequence.\"\"\"\n\n    # All these fields are of shape (num_tokens_in_seq,)\n    input_ids: list[int]\n    positions: list[int]\n    locator: TokenLocator\n\n    @model_validator(mode=\"after\")\n    def _validate_fields(self) -> TokenAlignerSeqInfo:\n        n: int = len(self.input_ids)\n        _check_equal_lengths(\n            input_ids=self.input_ids,\n            positions=self.positions,\n            locator_steps=self.locator.steps,\n            locator_token_index_in_step=self.locator.token_index_in_step,\n        )\n\n        if self.positions != list(range(n)):\n            raise ValueError(\n                f\"positions must be [0, 1, ..., {n - 1}], got {self.positions}\"\n            )\n\n        return self\n\n    def __add__(self, other: TokenAlignerSeqInfo) -> TokenAlignerSeqInfo:\n        return TokenAlignerSeqInfo(\n            input_ids=self.input_ids + other.input_ids,\n            positions=self.positions + other.positions,\n            locator=self.locator + other.locator,\n        )\n\n\nclass TokenAlignerSeqsInfo(_FrozenBase):\n    \"\"\"All sequences for one side across all steps.\"\"\"\n\n    sequences: dict[SeqId, TokenAlignerSeqInfo]\n    layout: TokenLayout\n\n\nclass TokenAlignerPlan(_FrozenBase):\n    \"\"\"Token alignment plan. 
locators.x[i] and locators.y[i] correspond to the same logical token.\"\"\"\n\n    locators: Pair[TokenLocator]\n    layouts: Pair[TokenLayout]\n\n    @model_validator(mode=\"after\")\n    def _validate_fields(self) -> TokenAlignerPlan:\n        _check_equal_lengths(\n            locators_x_steps=self.locators.x.steps,\n            locators_x_token_index_in_step=self.locators.x.token_index_in_step,\n            locators_y_steps=self.locators.y.steps,\n            locators_y_token_index_in_step=self.locators.y.token_index_in_step,\n        )\n        return self\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/unsharder/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/unsharder/executor.py",
    "content": "from dataclasses import dataclass, field\nfrom typing import Optional\n\nimport torch\n\nfrom sglang.srt.debug_utils.comparator.aligner.unsharder.types import (\n    ConcatParams,\n    CpThdConcatParams,\n    PickParams,\n    ReduceSumParams,\n    UnsharderParams,\n    UnsharderPlan,\n)\nfrom sglang.srt.debug_utils.comparator.dims_spec import (\n    ParallelAxis,\n    resolve_dim_by_name,\n)\nfrom sglang.srt.debug_utils.comparator.output_types import ReplicatedCheckResult\nfrom sglang.srt.debug_utils.comparator.tensor_comparator.comparator import compute_diff\n\n_REPLICATED_ATOL: float = 1e-6\n\n\n@dataclass(frozen=True)\nclass UnsharderResult:\n    tensors: list[torch.Tensor]\n    replicated_checks: list[ReplicatedCheckResult] = field(default_factory=list)\n\n\ndef execute_unsharder_plan(\n    plan: UnsharderPlan,\n    tensors: list[torch.Tensor],\n) -> UnsharderResult:\n    result_tensors: list[torch.Tensor] = []\n    all_checks: list[ReplicatedCheckResult] = []\n\n    for group_idx, group in enumerate(plan.groups):\n        group_tensors = [tensors[i] for i in group]\n        tensor, checks = _apply_unshard(\n            plan.params,\n            group_tensors,\n            axis=plan.axis,\n            group_index=group_idx,\n        )\n        result_tensors.append(tensor)\n        all_checks.extend(checks)\n\n    return UnsharderResult(tensors=result_tensors, replicated_checks=all_checks)\n\n\ndef _apply_unshard(\n    params: UnsharderParams,\n    ordered_tensors: list[torch.Tensor],\n    *,\n    axis: ParallelAxis,\n    group_index: int,\n) -> tuple[torch.Tensor, list[ReplicatedCheckResult]]:\n    if isinstance(params, PickParams):\n        checks: list[ReplicatedCheckResult] = _verify_replicated_group(\n            ordered_tensors,\n            axis=axis,\n            group_index=group_index,\n        )\n        return ordered_tensors[0], checks\n\n    if isinstance(params, ConcatParams):\n        dim: int = 
resolve_dim_by_name(ordered_tensors[0], params.dim_name)\n        return torch.cat(ordered_tensors, dim=dim), []\n\n    if isinstance(params, CpThdConcatParams):\n        thd_dim: int = resolve_dim_by_name(ordered_tensors[0], params.dim_name)\n        return (\n            _thd_concat(\n                ordered_tensors,\n                dim=thd_dim,\n                seq_lens_per_rank=params.seq_lens_per_rank,\n            ),\n            [],\n        )\n\n    if isinstance(params, ReduceSumParams):\n        stripped: list[torch.Tensor] = [t.rename(None) for t in ordered_tensors]\n        result: torch.Tensor = torch.stack(stripped).sum(dim=0)\n        names: tuple[Optional[str], ...] = ordered_tensors[0].names\n        if names[0] is not None:\n            result = result.refine_names(*names)\n        return result, []\n\n    raise ValueError(f\"Unsupported unshard operation: {type(params).__name__}\")\n\n\ndef _verify_replicated_group(\n    ordered_tensors: list[torch.Tensor],\n    *,\n    axis: ParallelAxis,\n    group_index: int,\n) -> list[ReplicatedCheckResult]:\n    baseline: torch.Tensor = ordered_tensors[0].rename(None).float()\n\n    return [\n        _check_replicated_pair(\n            baseline=baseline,\n            other=ordered_tensors[i],\n            axis=axis,\n            group_index=group_index,\n            compared_index=i,\n        )\n        for i in range(1, len(ordered_tensors))\n    ]\n\n\ndef _check_replicated_pair(\n    *,\n    baseline: torch.Tensor,\n    other: torch.Tensor,\n    axis: ParallelAxis,\n    group_index: int,\n    compared_index: int,\n) -> ReplicatedCheckResult:\n    other_float: torch.Tensor = other.rename(None).float()\n\n    if baseline.shape != other_float.shape:\n        passed = False\n        diff_info = None\n    else:\n        diff_info = compute_diff(\n            x_baseline=baseline,\n            x_target=other_float,\n            diff_threshold=_REPLICATED_ATOL,\n        )\n        passed = 
diff_info.max_abs_diff <= _REPLICATED_ATOL\n\n    return ReplicatedCheckResult(\n        axis=axis.value,\n        group_index=group_index,\n        compared_index=compared_index,\n        baseline_index=0,\n        passed=passed,\n        atol=_REPLICATED_ATOL,\n        diff=diff_info,\n    )\n\n\ndef _thd_concat(\n    ordered_tensors: list[torch.Tensor],\n    *,\n    dim: int,\n    seq_lens_per_rank: list[int],\n) -> torch.Tensor:\n    \"\"\"Per-seq concat across ranks for THD format.\n\n    Each rank holds segments of each seq packed contiguously:\n      rank_data = [seq0_tokens | seq1_tokens | ... | pad_tokens]\n\n    This function splits each rank by seq_lens, then interleaves across ranks\n    per-seq: [seqA_r0 + seqA_r1 + ... | seqB_r0 + seqB_r1 + ... | tail_pad].\n    \"\"\"\n    names: tuple[Optional[str], ...] = ordered_tensors[0].names\n    stripped: list[torch.Tensor] = [t.rename(None) for t in ordered_tensors]\n\n    # Split each rank into [seq0, seq1, ..., tail_remainder]\n    split_sizes: list[int] = list(seq_lens_per_rank)\n    remainder: int = stripped[0].shape[dim] - sum(split_sizes)\n    if remainder < 0:\n        raise ValueError(\n            f\"sum(seq_lens_per_rank)={sum(split_sizes)} exceeds tensor dim size \"\n            f\"{stripped[0].shape[dim]} along dim={dim}\"\n        )\n    if remainder > 0:\n        split_sizes.append(remainder)\n    per_rank_splits: list[tuple[torch.Tensor, ...]] = [\n        t.split(split_sizes, dim=dim) for t in stripped\n    ]\n\n    # Per-seq concat across ranks, then concatenate all seqs\n    result: torch.Tensor = torch.cat(\n        [torch.cat(rank_parts, dim=dim) for rank_parts in zip(*per_rank_splits)],\n        dim=dim,\n    )\n\n    if names[0] is not None:\n        result = result.refine_names(*names)\n    return result\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/unsharder/parallel_info.py",
    "content": "from typing import Optional\n\nfrom sglang.srt.debug_utils.comparator.aligner.unsharder.types import AxisInfo\nfrom sglang.srt.debug_utils.comparator.dims_spec import ParallelAxis\n\n_PARALLEL_INFO_KEYS = (\"sglang_parallel_info\", \"megatron_parallel_info\")\n\n\ndef _is_error_sentinel(value: dict) -> bool:\n    \"\"\"Check if a parallel_info dict is an error sentinel (e.g. {'megatron_error': True}).\"\"\"\n    return any(k.endswith(\"_error\") for k in value)\n\n\ndef normalize_parallel_info(meta: dict) -> dict[ParallelAxis, AxisInfo]:\n    \"\"\"Extract unified parallel info from dump meta.\"\"\"\n    info: Optional[dict] = None\n    for key in _PARALLEL_INFO_KEYS:\n        value = meta.get(key)\n        if isinstance(value, dict) and value and not _is_error_sentinel(value):\n            if info is not None:\n                raise ValueError(\n                    f\"Meta contains multiple parallel_info keys among {_PARALLEL_INFO_KEYS}\"\n                )\n            info = value\n\n    if info is None:\n        info = {}\n\n    result: dict[ParallelAxis, AxisInfo] = {}\n    for axis in ParallelAxis:\n        axis_rank = info.get(f\"{axis.value}_rank\")\n        axis_size = info.get(f\"{axis.value}_size\")\n\n        # Recompute pseudo-axis lives at top-level meta, not inside parallel_info\n        if axis_rank is None:\n            axis_rank = meta.get(f\"{axis.value}_rank\")\n            axis_size = meta.get(f\"{axis.value}_size\")\n\n        if axis_rank is not None and axis_size is not None and axis_size > 1:\n            result[axis] = AxisInfo(\n                axis_rank=axis_rank,\n                axis_size=axis_size,\n            )\n\n    return result\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/unsharder/planner.py",
    "content": "from collections import defaultdict\nfrom typing import NamedTuple, Optional\n\nfrom sglang.srt.debug_utils.comparator.aligner.unsharder.types import (\n    AxisInfo,\n    ConcatParams,\n    CpThdConcatParams,\n    PickParams,\n    ReduceSumParams,\n    UnsharderParams,\n    UnsharderPlan,\n)\nfrom sglang.srt.debug_utils.comparator.dims_spec import (\n    TOKEN_DIM_NAME,\n    DimSpec,\n    ParallelAxis,\n    ParallelModifier,\n)\n\n# _CoordsList[tensor_index][axis] =\n#     the axis_rank (shard position) of the tensor_index-th tensor along `axis`\n#     (e.g. coords[2] = {TP: 3} means tensor 2 is the 3rd shard in TP axis)\n_CoordsList = list[dict[ParallelAxis, int]]\n\n\nclass _GroupResult(NamedTuple):\n    groups: list[list[int]]\n    projected_coords: _CoordsList\n\n\ndef compute_unsharder_plan(\n    dim_specs: list[DimSpec],\n    parallel_infos: list[dict[ParallelAxis, AxisInfo]],\n    *,\n    explicit_replicated_axes: frozenset[ParallelAxis] = frozenset(),\n    thd_global_seq_lens: Optional[list[int]] = None,\n) -> list[UnsharderPlan]:\n    if not parallel_infos:\n        raise ValueError(\"parallel_infos must not be empty\")\n\n    # Within each dim spec, reverse modifier order: innermost shard (rightmost) unshards first.\n    reversed_sharded_modifiers: list[tuple[str, ParallelModifier]] = [\n        (spec.sanitized_name, m)\n        for spec in dim_specs\n        for m in reversed(spec.parallel_modifiers)\n    ]\n\n    sharded_axes_raw: set[ParallelAxis] = {\n        m.axis for _, m in reversed_sharded_modifiers\n    }\n    all_axes: set[ParallelAxis] = {axis for info in parallel_infos for axis in info}\n\n    # axis annotated in dims but absent from all parallel_infos -> axis_size=1, skip\n    sharded_axes: set[ParallelAxis] = sharded_axes_raw & all_axes\n    reversed_sharded_modifiers = [\n        (name, m) for name, m in reversed_sharded_modifiers if m.axis in sharded_axes\n    ]\n\n    # RECOMPUTE_PSEUDO is always implicitly replicated 
(system-injected, not user-facing)\n    auto_replicated: frozenset[ParallelAxis] = frozenset(\n        {ParallelAxis.RECOMPUTE_PSEUDO} & all_axes\n    )\n    effective_replicated: frozenset[ParallelAxis] = (\n        explicit_replicated_axes | auto_replicated\n    )\n\n    _validate_explicit_replicated(\n        explicit_replicated_axes=effective_replicated,\n        sharded_axes=sharded_axes,\n        all_axes=all_axes,\n    )\n    replicated_axes: frozenset[ParallelAxis] = effective_replicated\n\n    if not sharded_axes and not replicated_axes:\n        return []\n\n    _validate(\n        axes_to_validate=sharded_axes | replicated_axes,\n        parallel_infos=parallel_infos,\n    )\n\n    current_coords: _CoordsList = [\n        {axis: info[axis].axis_rank for axis in sharded_axes | replicated_axes}\n        for info in parallel_infos\n    ]\n\n    axis_and_params: list[tuple[ParallelAxis, UnsharderParams]] = [\n        (axis, PickParams()) for axis in sorted(replicated_axes, key=lambda a: a.value)\n    ] + [\n        (\n            modifier.axis,\n            _resolve_unshard_params(\n                modifier=modifier,\n                dim_name=dim_name,\n                parallel_infos=parallel_infos,\n                thd_global_seq_lens=thd_global_seq_lens,\n            ),\n        )\n        for dim_name, modifier in reversed_sharded_modifiers\n    ]\n\n    plans: list[UnsharderPlan] = []\n    for axis, params in axis_and_params:\n        result = _group_and_project(\n            current_coords=current_coords,\n            target_axis=axis,\n        )\n        plans.append(UnsharderPlan(axis=axis, params=params, groups=result.groups))\n        current_coords = result.projected_coords\n\n    return plans\n\n\ndef _validate_explicit_replicated(\n    *,\n    explicit_replicated_axes: frozenset[ParallelAxis],\n    sharded_axes: set[ParallelAxis],\n    all_axes: set[ParallelAxis],\n) -> None:\n    \"\"\"Validate explicit replicated declarations against sharded 
axes and parallel_infos.\"\"\"\n    invalid: frozenset[ParallelAxis] = explicit_replicated_axes - all_axes\n    if invalid:\n        invalid_names: str = \", \".join(sorted(a.value for a in invalid))\n        raise ValueError(\n            f\"Declared replicated axes {{{invalid_names}}} not found in parallel_infos \"\n            f\"(active axes: {{{', '.join(sorted(a.value for a in all_axes))}}})\"\n        )\n\n    conflict: set[ParallelAxis] = explicit_replicated_axes & sharded_axes\n    if conflict:\n        conflict_names: str = \", \".join(sorted(a.value for a in conflict))\n        raise ValueError(\n            f\"Axes {{{conflict_names}}} declared as both sharded and replicated\"\n        )\n\n    undeclared: set[ParallelAxis] = all_axes - sharded_axes - explicit_replicated_axes\n    if undeclared:\n        undeclared_names: str = \", \".join(sorted(a.value for a in undeclared))\n        raise ValueError(\n            f\"Axes {{{undeclared_names}}} are active (axis_size > 1) but not declared \"\n            f\"in dims. 
Annotate as sharded in dim spec or as '# axis:replicated'.\"\n        )\n\n\ndef _validate(\n    *,\n    axes_to_validate: set[ParallelAxis],\n    parallel_infos: list[dict[ParallelAxis, AxisInfo]],\n) -> None:\n    \"\"\"Check that every rank has all axes, sizes are consistent, and ranks are complete.\"\"\"\n    axis_sizes: dict[ParallelAxis, int] = {}\n\n    for world_rank, parallel_info in enumerate(parallel_infos):\n        for axis in axes_to_validate:\n            if axis not in parallel_info:\n                raise ValueError(\n                    f\"world_rank={world_rank} missing parallel_info for \"\n                    f\"axis {axis.value!r}\"\n                )\n\n            axis_info = parallel_info[axis]\n            if axis not in axis_sizes:\n                axis_sizes[axis] = axis_info.axis_size\n            elif axis_info.axis_size != axis_sizes[axis]:\n                raise ValueError(\n                    f\"Inconsistent axis_size for {axis.value}: \"\n                    f\"expected {axis_sizes[axis]}, got {axis_info.axis_size} \"\n                    f\"at world_rank={world_rank}\"\n                )\n\n    for axis, expected_size in axis_sizes.items():\n        seen_ranks = {info[axis].axis_rank for info in parallel_infos}\n        if seen_ranks != set(range(expected_size)):\n            raise ValueError(\n                f\"axis_rank coverage for {axis.value} is incomplete: \"\n                f\"got {sorted(seen_ranks)}, expected 0..{expected_size - 1}\"\n            )\n\n\ndef _group_and_project(\n    *,\n    current_coords: _CoordsList,\n    target_axis: ParallelAxis,\n) -> _GroupResult:\n    \"\"\"Group tensors by other-axes coords, sort within group by target_axis rank.\"\"\"\n    # buckets[coords_excluding_target] = [(axis_rank, tensor_index), ...]\n    # e.g. 
when target_axis=CP: buckets[{(TP,0)}] = [(0, 1), (1, 3)]\n    #   means tensor 1 (CP rank 0) and tensor 3 (CP rank 1) share TP rank 0\n    buckets: dict[frozenset, list[tuple[int, int]]] = defaultdict(list)\n\n    for idx, coords in enumerate(current_coords):\n        key = frozenset((k, v) for k, v in coords.items() if k != target_axis)\n        buckets[key].append((coords[target_axis], idx))\n\n    groups: list[list[int]] = []\n    projected: _CoordsList = []\n    for key in sorted(buckets, key=lambda k: sorted((a.value, v) for a, v in k)):\n        entries = sorted(buckets[key])\n        groups.append([idx for _, idx in entries])\n        projected.append(dict(key))\n\n    return _GroupResult(groups=groups, projected_coords=projected)\n\n\ndef _resolve_unshard_params(\n    *,\n    modifier: ParallelModifier,\n    dim_name: str,\n    parallel_infos: list[dict[ParallelAxis, AxisInfo]],\n    thd_global_seq_lens: Optional[list[int]] = None,\n) -> UnsharderParams:\n    if modifier.reduction is not None:\n        return ReduceSumParams()\n\n    if (\n        dim_name == TOKEN_DIM_NAME\n        and modifier.axis == ParallelAxis.CP\n        and thd_global_seq_lens is not None\n    ):\n        axis_size: int = parallel_infos[0][modifier.axis].axis_size\n        for s in thd_global_seq_lens:\n            if s % axis_size != 0:\n                raise ValueError(\n                    f\"THD seq_len {s} is not divisible by cp_size {axis_size}. \"\n                    f\"Sequences must be padded to a multiple of cp_size for CP zigzag.\"\n                )\n        seq_lens_per_rank: list[int] = [s // axis_size for s in thd_global_seq_lens]\n        return CpThdConcatParams(dim_name=dim_name, seq_lens_per_rank=seq_lens_per_rank)\n\n    return ConcatParams(dim_name=dim_name)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/aligner/unsharder/types.py",
    "content": "from __future__ import annotations\n\nfrom typing import Annotated, Literal, Union\n\nfrom pydantic import Field, model_validator\n\nfrom sglang.srt.debug_utils.comparator.dims_spec import ParallelAxis\nfrom sglang.srt.debug_utils.comparator.utils import _FrozenBase\n\n\nclass AxisInfo(_FrozenBase):\n    axis_rank: int\n    axis_size: int\n\n    @model_validator(mode=\"after\")\n    def _validate_bounds(self) -> AxisInfo:\n        if self.axis_size <= 0:\n            raise ValueError(f\"axis_size must be > 0, got {self.axis_size}\")\n        if not (0 <= self.axis_rank < self.axis_size):\n            raise ValueError(\n                f\"axis_rank must be in [0, {self.axis_size}), got {self.axis_rank}\"\n            )\n        return self\n\n\nclass ConcatParams(_FrozenBase):\n    op: Literal[\"concat\"] = \"concat\"\n    dim_name: str\n\n\nclass CpThdConcatParams(_FrozenBase):\n    op: Literal[\"cp_thd_concat\"] = \"cp_thd_concat\"\n    dim_name: str\n    seq_lens_per_rank: list[int]  # per-seq token count on each rank, e.g. [50, 32, 46]\n\n\nclass PickParams(_FrozenBase):\n    op: Literal[\"pick\"] = \"pick\"\n\n\nclass ReduceSumParams(_FrozenBase):\n    op: Literal[\"reduce_sum\"] = \"reduce_sum\"\n\n\nUnsharderParams = Annotated[\n    Union[ConcatParams, CpThdConcatParams, PickParams, ReduceSumParams],\n    Field(discriminator=\"op\"),\n]\n\n\nclass UnsharderPlan(_FrozenBase):\n    type: Literal[\"unsharder\"] = \"unsharder\"\n    axis: ParallelAxis\n    params: UnsharderParams\n    # groups[i] = indices into the input tensor list whose tensors are combined (e.g. concatenated) into the i-th output tensor.\n    #\n    # Multistep example (CP=2, TP=2, 4 input tensors):\n    #   plan[0] (CP): groups=[[0,2],[1,3]]  — 4 tensors → 2 tensors\n    #   plan[1] (TP): groups=[[0,1]]        — 2 tensors → 1 tensor\n    groups: list[list[int]]\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/bundle_comparator.py",
    "content": "\"\"\"Compare two tensor bundles.\"\"\"\n\nfrom __future__ import annotations\n\nfrom pathlib import Path\nfrom typing import Any, Optional, Union\n\nimport torch\n\nfrom sglang.srt.debug_utils.comparator.aligner.entrypoint.executor import (\n    AlignerResult,\n    execute_aligner_plan,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.entrypoint.planner import (\n    compute_aligner_plan,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.entrypoint.types import AlignerPlan\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.types import (\n    TokenAlignerPlan,\n)\nfrom sglang.srt.debug_utils.comparator.dims_spec import (\n    SEQ_DIM_NAME,\n    TOKEN_DIM_NAME,\n    apply_dim_names,\n    parse_dims,\n    resolve_dim_names,\n)\nfrom sglang.srt.debug_utils.comparator.dp_utils import filter_to_non_empty_dp_rank\nfrom sglang.srt.debug_utils.comparator.log_sink import log_sink\nfrom sglang.srt.debug_utils.comparator.meta_overrider import MetaOverrider\nfrom sglang.srt.debug_utils.comparator.output_types import (\n    BundleFileInfo,\n    BundleSideInfo,\n    ComparisonNonTensorRecord,\n    ComparisonSkipRecord,\n    ComparisonTensorRecord,\n    ErrorLog,\n    _split_logs,\n)\nfrom sglang.srt.debug_utils.comparator.tensor_comparator.comparator import (\n    compare_tensor_pair,\n)\nfrom sglang.srt.debug_utils.comparator.utils import Pair\nfrom sglang.srt.debug_utils.dump_loader import LOAD_FAILED, ValueWithMeta\n\n_FAILED_SIDE_MAP: dict[str, str] = {\"x\": \"baseline\", \"y\": \"target\"}\n\n\ndef _collect_bundle_side_info(\n    items: list[ValueWithMeta],\n    metas: list[dict[str, Any]],\n) -> BundleSideInfo:\n    from sglang.srt.debug_utils.comparator.display import (\n        _PARALLEL_INFO_KEYS,\n        extract_parallel_info,\n    )\n\n    files: list[BundleFileInfo] = []\n    for item, meta in zip(items, metas):\n        assert isinstance(item.value, torch.Tensor)\n        tensor: torch.Tensor = item.value\n\n        
parallel_info: dict[str, str] = {}\n        for key in _PARALLEL_INFO_KEYS:\n            extract_parallel_info(row_data=parallel_info, info=meta.get(key, {}))\n\n        files.append(\n            BundleFileInfo(\n                shape=list(tensor.shape),\n                dtype=str(tensor.dtype),\n                rank=meta.get(\"rank\"),\n                parallel_info=parallel_info if parallel_info else None,\n            )\n        )\n\n    dims: Optional[str] = metas[0].get(\"dims\") if metas else None\n    return BundleSideInfo(num_files=len(files), files=files, dims=dims)\n\n\ndef compare_bundle_pair(\n    *,\n    name: str,\n    filenames_pair: Pair[list[str]],\n    dir_pair: Pair[Path],\n    token_aligner_mode: Optional[str],\n    token_aligner_plan: Optional[TokenAlignerPlan],\n    diff_threshold: float,\n    thd_seq_lens_by_step_pair: Pair[Optional[dict[int, list[int]]]] = Pair(\n        x=None, y=None\n    ),\n    viz_output_dir: Optional[Path] = None,\n    compute_per_token: bool = False,\n    meta_overrider: Optional[MetaOverrider] = None,\n) -> Union[ComparisonTensorRecord, ComparisonSkipRecord, ComparisonNonTensorRecord]:\n    with log_sink.context() as collected_logs:\n        result = _compare_bundle_pair_inner(\n            name=name,\n            filenames_pair=filenames_pair,\n            dir_pair=dir_pair,\n            token_aligner_mode=token_aligner_mode,\n            token_aligner_plan=token_aligner_plan,\n            diff_threshold=diff_threshold,\n            thd_seq_lens_by_step_pair=thd_seq_lens_by_step_pair,\n            viz_output_dir=viz_output_dir,\n            compute_per_token=compute_per_token,\n            meta_overrider=meta_overrider,\n        )\n\n    errors, infos = _split_logs(collected_logs)\n    return result.model_copy(update={\"errors\": errors, \"infos\": infos})\n\n\ndef _compare_bundle_pair_inner(\n    *,\n    name: str,\n    filenames_pair: Pair[list[str]],\n    dir_pair: Pair[Path],\n    token_aligner_mode: 
Optional[str],\n    token_aligner_plan: Optional[TokenAlignerPlan],\n    diff_threshold: float,\n    thd_seq_lens_by_step_pair: Pair[Optional[dict[int, list[int]]]] = Pair(\n        x=None, y=None\n    ),\n    viz_output_dir: Optional[Path] = None,\n    compute_per_token: bool = False,\n    meta_overrider: Optional[MetaOverrider] = None,\n) -> Union[ComparisonTensorRecord, ComparisonSkipRecord, ComparisonNonTensorRecord]:\n    # 1. Load all successfully loaded values\n    all_pair: Pair[list[ValueWithMeta]] = Pair(\n        x=_load_all_values(filenames=filenames_pair.x, base_path=dir_pair.x),\n        y=_load_all_values(filenames=filenames_pair.y, base_path=dir_pair.y),\n    )\n\n    if not all_pair.x or not all_pair.y:\n        reason = \"baseline_load_failed\" if not all_pair.x else \"target_load_failed\"\n        return ComparisonSkipRecord(name=name, reason=reason)\n\n    # 1b. Dims override: patch meta[\"dims\"] before DP filter reads it\n    # (--override-dims may add ``# dp:=moe_dp``, so it must run first)\n    if meta_overrider is not None and not meta_overrider.is_empty:\n        _apply = meta_overrider.apply_to_meta\n        all_pair = Pair(\n            x=[\n                ValueWithMeta(\n                    value=v.value, meta=_apply(name=name, meta=v.meta, side=\"baseline\")\n                )\n                for v in all_pair.x\n            ],\n            y=[\n                ValueWithMeta(\n                    value=v.value, meta=_apply(name=name, meta=v.meta, side=\"target\")\n                )\n                for v in all_pair.y\n            ],\n        )\n\n    # 1c. DP filter: keep only the non-empty dp_rank\n    all_pair = all_pair.map(\n        lambda items: filter_to_non_empty_dp_rank(\n            items, dp_group_alias=_extract_dp_alias_from_items(items)\n        )\n    )\n\n    # 2. 
Check if any side has non-tensor values → non-tensor display path\n    has_non_tensor: bool = any(\n        not isinstance(it.value, torch.Tensor) for it in [*all_pair.x, *all_pair.y]\n    )\n    if has_non_tensor:\n        return _compare_bundle_pair_non_tensor_type(name=name, value_pair=all_pair)\n\n    # 3. All values are tensors → tensor comparison path\n    return _compare_bundle_pair_tensor_type(\n        name=name,\n        valid_pair=all_pair,\n        token_aligner_mode=token_aligner_mode,\n        token_aligner_plan=token_aligner_plan,\n        diff_threshold=diff_threshold,\n        thd_seq_lens_by_step_pair=thd_seq_lens_by_step_pair,\n        viz_output_dir=viz_output_dir,\n        compute_per_token=compute_per_token,\n    )\n\n\ndef _extract_dp_alias_from_items(items: list[ValueWithMeta]) -> Optional[str]:\n    \"\"\"Extract dp group alias from the first item's ``meta[\"dims\"]``.\"\"\"\n    if not items:\n        return None\n    dims_str: Optional[str] = items[0].meta.get(\"dims\")\n    if dims_str is None:\n        return None\n    return parse_dims(dims_str).dp_group_alias\n\n\ndef _compare_bundle_pair_tensor_type(\n    *,\n    name: str,\n    valid_pair: Pair[list[ValueWithMeta]],\n    token_aligner_mode: Optional[str],\n    token_aligner_plan: Optional[TokenAlignerPlan],\n    diff_threshold: float,\n    thd_seq_lens_by_step_pair: Pair[Optional[dict[int, list[int]]]] = Pair(\n        x=None, y=None\n    ),\n    viz_output_dir: Optional[Path] = None,\n    compute_per_token: bool = False,\n) -> Union[ComparisonTensorRecord, ComparisonSkipRecord]:\n    if not valid_pair.x or not valid_pair.y:\n        reason = \"baseline_load_failed\" if not valid_pair.x else \"target_load_failed\"\n        return ComparisonSkipRecord(name=name, reason=reason)\n\n    # Plan (meta only, no tensor)\n    metas_pair: Pair[list[dict[str, Any]]] = valid_pair.map(\n        lambda items: [it.meta for it in items]\n    )\n    plan: AlignerPlan = compute_aligner_plan(\n        
metas_pair=metas_pair,\n        token_aligner_mode=token_aligner_mode,\n        token_aligner_plan=token_aligner_plan,\n        thd_seq_lens_by_step_pair=thd_seq_lens_by_step_pair,\n    )\n\n    # Collect raw bundle info before alignment\n    raw_bundle_info: Pair[BundleSideInfo] = Pair(\n        x=_collect_bundle_side_info(items=valid_pair.x, metas=metas_pair.x),\n        y=_collect_bundle_side_info(items=valid_pair.y, metas=metas_pair.y),\n    )\n\n    # Apply dim names to tensors, then execute\n    tensors_pair: Pair[list[torch.Tensor]] = Pair(\n        x=_apply_dim_names_from_meta(\n            tensors=[it.value for it in valid_pair.x],\n            metas=metas_pair.x,\n        ),\n        y=_apply_dim_names_from_meta(\n            tensors=[it.value for it in valid_pair.y],\n            metas=metas_pair.y,\n        ),\n    )\n    aligner_result: AlignerResult = execute_aligner_plan(\n        tensors_pair=tensors_pair, plan=plan\n    )\n    replicated_checks = aligner_result.replicated_checks\n\n    if aligner_result.tensors is None:\n        assert aligner_result.failed_side_xy is not None\n        side_name: str = _FAILED_SIDE_MAP[aligner_result.failed_side_xy]\n        reason: str = f\"{side_name}_load_failed\"\n        return ComparisonSkipRecord(name=name, reason=reason)\n\n    # Resolve seq_dim for per-token computation\n    seq_dim: Optional[int] = (\n        _resolve_seq_dim(aligner_result.tensors.y) if compute_per_token else None\n    )\n\n    # Compare\n    aligned_baseline: torch.Tensor = aligner_result.tensors.x.rename(None)\n    aligned_target: torch.Tensor = aligner_result.tensors.y.rename(None)\n\n    info = compare_tensor_pair(\n        x_baseline=aligned_baseline,\n        x_target=aligned_target,\n        name=name,\n        diff_threshold=diff_threshold,\n        seq_dim=seq_dim,\n    )\n    record = ComparisonTensorRecord(\n        **info.model_dump(),\n        traced_plan=aligner_result.traced_plan,\n        
replicated_checks=replicated_checks,\n        raw_bundle_info=raw_bundle_info,\n    )\n\n    if viz_output_dir is not None:\n        _try_generate_viz(\n            baseline=aligned_baseline,\n            target=aligned_target,\n            name=name,\n            viz_output_dir=viz_output_dir,\n        )\n\n    return record\n\n\ndef _try_generate_viz(\n    *,\n    baseline: torch.Tensor,\n    target: torch.Tensor,\n    name: str,\n    viz_output_dir: Path,\n) -> None:\n    from sglang.srt.debug_utils.comparator.visualizer import (\n        generate_comparison_figure,\n    )\n    from sglang.srt.debug_utils.comparator.visualizer.preprocessing import (\n        _sanitize_filename,\n    )\n\n    filename: str = _sanitize_filename(name) + \".png\"\n    output_path: Path = viz_output_dir / filename\n\n    try:\n        generate_comparison_figure(\n            baseline=baseline,\n            target=target,\n            name=name,\n            output_path=output_path,\n        )\n    except Exception as exc:\n        log_sink.add(\n            ErrorLog(\n                category=\"visualizer\",\n                message=f\"Visualization failed for {name}: {exc}\",\n            )\n        )\n\n\ndef _resolve_seq_dim(tensor: torch.Tensor) -> Optional[int]:\n    \"\"\"Find the token/seq dimension index from the tensor's named dims.\"\"\"\n    if tensor.names[0] is None:\n        return None\n\n    names: tuple[Optional[str], ...] 
= tensor.names\n    for target_name in (TOKEN_DIM_NAME, SEQ_DIM_NAME):\n        if target_name in names:\n            return list(names).index(target_name)\n\n    return None\n\n\ndef _compare_bundle_pair_non_tensor_type(\n    *,\n    name: str,\n    value_pair: Pair[list[ValueWithMeta]],\n) -> ComparisonNonTensorRecord:\n    baseline_value: Any = value_pair.x[0].value\n    target_value: Any = value_pair.y[0].value\n\n    try:\n        values_equal: bool = bool(baseline_value == target_value)\n    except Exception:\n        values_equal = False\n\n    return ComparisonNonTensorRecord(\n        name=name,\n        baseline_value=repr(baseline_value),\n        target_value=repr(target_value),\n        baseline_type=type(baseline_value).__name__,\n        target_type=type(target_value).__name__,\n        values_equal=values_equal,\n    )\n\n\ndef _apply_dim_names_from_meta(\n    *,\n    tensors: list[torch.Tensor],\n    metas: list[dict[str, Any]],\n) -> list[torch.Tensor]:\n    if not metas:\n        return tensors\n\n    dims_str: Optional[str] = metas[0].get(\"dims\")\n    if dims_str is None:\n        return tensors\n\n    dim_names: list[str] = resolve_dim_names(dims_str)\n    return [apply_dim_names(t, dim_names) for t in tensors]\n\n\ndef _load_all_values(filenames: list[str], base_path: Path) -> list[ValueWithMeta]:\n    result: list[ValueWithMeta] = []\n    for f in filenames:\n        item: ValueWithMeta = ValueWithMeta.load(base_path / f)\n        if item.value is LOAD_FAILED:\n            log_sink.add(\n                ErrorLog(\n                    category=\"load_failed\",\n                    message=f\"Failed to load tensor file: {f}\",\n                )\n            )\n            continue\n        result.append(item)\n    return result\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/bundle_matcher.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nfrom dataclasses import dataclass\nfrom typing import Any\n\nimport polars as pl\n\nfrom sglang.srt.debug_utils.comparator.utils import Pair\nfrom sglang.srt.debug_utils.dump_loader import filter_rows\n\n\n@dataclass(frozen=True)\nclass TensorFileInfo:\n    filename: str\n    name: str\n    step: int\n\n\nTensorBundleInfo = list[TensorFileInfo]\n\n\ndef match_bundles(\n    *,\n    dfs: Pair[pl.DataFrame],\n    skip_keys: set[str],\n) -> list[Pair[TensorBundleInfo]]:\n    match_key_cols: list[str] = [c for c in dfs.y.columns if c not in skip_keys]\n    unique_keys: pl.DataFrame = dfs.y.select(match_key_cols).unique(maintain_order=True)\n\n    results: list[Pair[TensorBundleInfo]] = []\n    for key_values in unique_keys.iter_rows(named=True):\n        result = dfs.map(\n            lambda df: _rows_to_tensor_infos(filter_rows(df, conditions=key_values))\n        )\n        results.append(result)\n\n    return results\n\n\ndef _rows_to_tensor_infos(rows: list[dict[str, Any]]) -> list[TensorFileInfo]:\n    tensor_info_fields: set[str] = {f.name for f in dataclasses.fields(TensorFileInfo)}\n    return [\n        TensorFileInfo(**{k: v for k, v in row.items() if k in tensor_info_fields})\n        for row in rows\n    ]\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/dims_spec/__init__.py",
    "content": "from sglang.srt.debug_utils.comparator.dims_spec.dim_parser import parse_dim\nfrom sglang.srt.debug_utils.comparator.dims_spec.dims_parser import (\n    _SingletonDimUtil,\n    parse_dims,\n    resolve_dim_names,\n)\nfrom sglang.srt.debug_utils.comparator.dims_spec.tensor_naming import (\n    apply_dim_names,\n    find_dim_index,\n    resolve_dim_by_name,\n    strip_dim_names,\n)\nfrom sglang.srt.debug_utils.comparator.dims_spec.types import (\n    _FUSED_NAME_SEP,\n    BATCH_DIM_NAME,\n    SEQ_DIM_NAME,\n    SQUEEZE_DIM_NAME,\n    TOKEN_DIM_NAME,\n    DimSpec,\n    DimsSpec,\n    Ordering,\n    ParallelAxis,\n    ParallelModifier,\n    Reduction,\n    TokenLayout,\n)\n\n__all__ = [\n    \"BATCH_DIM_NAME\",\n    \"SEQ_DIM_NAME\",\n    \"SQUEEZE_DIM_NAME\",\n    \"TOKEN_DIM_NAME\",\n    \"DimsSpec\",\n    \"DimSpec\",\n    \"Ordering\",\n    \"ParallelAxis\",\n    \"ParallelModifier\",\n    \"Reduction\",\n    \"TokenLayout\",\n    \"_FUSED_NAME_SEP\",\n    \"_SingletonDimUtil\",\n    \"apply_dim_names\",\n    \"find_dim_index\",\n    \"parse_dim\",\n    \"parse_dims\",\n    \"resolve_dim_by_name\",\n    \"resolve_dim_names\",\n    \"strip_dim_names\",\n]\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/dims_spec/comment_parser.py",
    "content": "from __future__ import annotations\n\nimport re\nfrom typing import NamedTuple, Optional\n\nfrom sglang.srt.debug_utils.comparator.dims_spec.types import (\n    _AXIS_LOOKUP,\n    ParallelAxis,\n)\n\n_DP_ALIAS_PATTERN = re.compile(r\"^dp:=(\\w+)$\")\n_REPLICATED_PATTERN = re.compile(r\"^(\\w+):replicated$\")\n\n\nclass _CommentSuffix(NamedTuple):\n    dp_group_alias: Optional[str] = None\n    replicated_axes: frozenset[ParallelAxis] = frozenset()\n\n\ndef _parse_comment_suffix(declaration_part: str) -> _CommentSuffix:\n    \"\"\"Parse the ``#`` comment section for dp alias and replicated declarations.\"\"\"\n    dp_group_alias: Optional[str] = None\n    replicated_axes: set[ParallelAxis] = set()\n\n    for token in declaration_part.strip().split():\n        dp_match = _DP_ALIAS_PATTERN.match(token)\n        if dp_match is not None:\n            if dp_group_alias is not None:\n                raise ValueError(\n                    f\"Duplicate dp alias declaration: already have {dp_group_alias!r}, \"\n                    f\"got {dp_match.group(1)!r}\"\n                )\n            dp_group_alias = dp_match.group(1)\n            continue\n\n        repl_match = _REPLICATED_PATTERN.match(token)\n        if repl_match is not None:\n            axis_str: str = repl_match.group(1)\n            axis: Optional[ParallelAxis] = _AXIS_LOOKUP.get(axis_str)\n            if axis is None:\n                raise ValueError(\n                    f\"Unknown axis {axis_str!r} in replicated declaration: {token!r}\"\n                )\n            if axis in replicated_axes:\n                raise ValueError(\n                    f\"Duplicate replicated declaration for axis {axis_str!r}\"\n                )\n            replicated_axes.add(axis)\n            continue\n\n        raise ValueError(\n            f\"Unrecognized token {token!r} in # comment section. 
\"\n            f\"Expected 'dp:=<group>' or '<axis>:replicated'.\"\n        )\n\n    return _CommentSuffix(\n        dp_group_alias=dp_group_alias,\n        replicated_axes=frozenset(replicated_axes),\n    )\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/dims_spec/dim_parser.py",
    "content": "from __future__ import annotations\n\nimport re\nfrom typing import Optional\n\nfrom sglang.srt.debug_utils.comparator.dims_spec.modifier_parser import (\n    _parse_modifiers,\n)\nfrom sglang.srt.debug_utils.comparator.dims_spec.types import (\n    SQUEEZE_DIM_NAME,\n    DimSpec,\n    ParallelModifier,\n)\n\n_DIM_PATTERN = re.compile(r\"^(?P<name>[a-zA-Z_]\\w*)(?:\\[(?P<modifiers>[^\\]]+)\\])?$\")\n\n_FUSED_DIM_PATTERN = re.compile(r\"^\\((?P<inner>[^)]+)\\)(?:\\[(?P<modifiers>[^\\]]+)\\])?$\")\n\n_SUB_DIM_NAME_PATTERN = re.compile(r\"^[a-zA-Z_]\\w*$\")\n\n\ndef parse_dim(token: str) -> DimSpec:\n    if token == SQUEEZE_DIM_NAME:\n        return DimSpec(name=SQUEEZE_DIM_NAME)\n\n    fused_match = _FUSED_DIM_PATTERN.match(token)\n    if fused_match is not None:\n        return _parse_fused_dim(token=token, fused_match=fused_match)\n\n    return _parse_single_dim(token)\n\n\ndef _parse_single_dim(token: str) -> DimSpec:\n    match = _DIM_PATTERN.match(token)\n    if match is None:\n        raise ValueError(f\"Invalid dim token: {token!r}\")\n\n    name: str = match.group(\"name\")\n    modifiers: list[ParallelModifier] = _parse_modifiers(\n        modifiers_str=match.group(\"modifiers\"), dim_token=token\n    )\n    return DimSpec(name=name, parallel_modifiers=modifiers)\n\n\ndef _parse_fused_dim(*, token: str, fused_match: re.Match[str]) -> DimSpec:\n    inner: str = fused_match.group(\"inner\")\n    modifiers_str: Optional[str] = fused_match.group(\"modifiers\")\n\n    sub_names: list[str] = [s.strip() for s in inner.split(\"*\")]\n    for sub_name in sub_names:\n        if not _SUB_DIM_NAME_PATTERN.match(sub_name):\n            raise ValueError(\n                f\"Invalid sub-dim {sub_name!r} in fused dim token: {token!r}\"\n            )\n\n    if len(sub_names) != len(set(sub_names)):\n        raise ValueError(f\"Duplicate sub-dim names in fused dim token: {token!r}\")\n\n    if len(sub_names) < 2:\n        raise ValueError(\n            
f\"Fused dim must have at least 2 sub-dims, got {len(sub_names)} in: {token!r}\"\n        )\n\n    fused_name: str = \"*\".join(sub_names)\n    modifiers: list[ParallelModifier] = _parse_modifiers(\n        modifiers_str=modifiers_str, dim_token=token\n    )\n    return DimSpec(name=fused_name, parallel_modifiers=modifiers)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/dims_spec/dims_parser.py",
    "content": "from __future__ import annotations\n\nfrom typing import Optional\n\nfrom sglang.srt.debug_utils.comparator.dims_spec.comment_parser import (\n    _CommentSuffix,\n    _parse_comment_suffix,\n)\nfrom sglang.srt.debug_utils.comparator.dims_spec.dim_parser import parse_dim\nfrom sglang.srt.debug_utils.comparator.dims_spec.types import (\n    SQUEEZE_DIM_NAME,\n    DimSpec,\n    DimsSpec,\n    ParallelAxis,\n)\n\n\nclass _SingletonDimUtil:\n    \"\"\"Utilities for squeeze dims (name=\"1\") and their singleton tensor-name mapping.\"\"\"\n\n    PREFIX: str = \"singleton\"\n\n    @staticmethod\n    def is_squeeze(spec: DimSpec) -> bool:\n        return spec.name == SQUEEZE_DIM_NAME\n\n    @staticmethod\n    def filter_out(dim_specs: list[DimSpec]) -> list[DimSpec]:\n        return [s for s in dim_specs if not _SingletonDimUtil.is_squeeze(s)]\n\n    @staticmethod\n    def make_name(index: int) -> str:\n        return f\"{_SingletonDimUtil.PREFIX}{index}\"\n\n    @staticmethod\n    def is_singleton_name(name: str) -> bool:\n        return (\n            name.startswith(_SingletonDimUtil.PREFIX)\n            and name[len(_SingletonDimUtil.PREFIX) :].isdigit()\n        )\n\n    @staticmethod\n    def sanitize_names(names: list[str]) -> list[str]:\n        \"\"\"Replace '1' with 'singleton0', 'singleton1', ... 
for named tensor compatibility.\"\"\"\n        result: list[str] = []\n        sq_idx: int = 0\n\n        for name in names:\n            if name == SQUEEZE_DIM_NAME:\n                result.append(_SingletonDimUtil.make_name(sq_idx))\n                sq_idx += 1\n            else:\n                result.append(name)\n\n        return result\n\n\ndef parse_dims(dims_str: str) -> DimsSpec:\n    \"\"\"Parse ``\"b s[cp:zigzag] h[tp] d # dp:=moe_dp ep:replicated\"`` → :class:`DimsSpec`.\n\n    The shape part (before ``#``) produces :pyattr:`DimsSpec.dims`.\n    The declaration part (after ``#``) is scanned for:\n    - ``dp:=<group>`` → :pyattr:`DimsSpec.dp_group_alias`\n    - ``axis:replicated`` → :pyattr:`DimsSpec.replicated_axes`\n    \"\"\"\n    parts: list[str] = dims_str.split(\"#\", maxsplit=1)\n    raw: str = parts[0]\n\n    if not raw.strip():\n        raise ValueError(\"dims string must not be empty\")\n\n    dims: list[DimSpec] = [parse_dim(token) for token in raw.strip().split()]\n\n    # Collect all semantic names (expanding fused sub-dims) for duplicate detection\n    semantic_names: list[str] = []\n    for spec in dims:\n        if _SingletonDimUtil.is_squeeze(spec):\n            continue\n        semantic_names.extend(spec.sub_dims)\n\n    if len(semantic_names) != len(set(semantic_names)):\n        duplicates = sorted({n for n in semantic_names if semantic_names.count(n) > 1})\n        raise ValueError(f\"Duplicate dim names: {duplicates}\")\n\n    comment_suffix: _CommentSuffix = (\n        _parse_comment_suffix(parts[1]) if len(parts) > 1 else _CommentSuffix()\n    )\n    dp_group_alias: Optional[str] = comment_suffix.dp_group_alias\n    replicated_axes: frozenset[ParallelAxis] = comment_suffix.replicated_axes\n\n    sharded_axes: set[ParallelAxis] = {\n        m.axis for spec in dims for m in spec.parallel_modifiers\n    }\n    conflict: frozenset[ParallelAxis] = replicated_axes & sharded_axes\n    if conflict:\n        conflict_names: str = \", 
\".join(sorted(a.value for a in conflict))\n        raise ValueError(\n            f\"Axes declared as both sharded (in dim spec) and replicated \"\n            f\"(in # declaration): {conflict_names}\"\n        )\n\n    return DimsSpec(\n        dims=dims,\n        dp_group_alias=dp_group_alias,\n        replicated_axes=replicated_axes,\n    )\n\n\ndef resolve_dim_names(dims_str: str) -> list[str]:\n    \"\"\"Parse dims string and return tensor-compatible names ('1' → 'singleton0', ...).\"\"\"\n    specs: list[DimSpec] = parse_dims(dims_str).dims\n    names: list[str] = [spec.sanitized_name for spec in specs]\n    return _SingletonDimUtil.sanitize_names(names)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/dims_spec/modifier_parser.py",
    "content": "from __future__ import annotations\n\nfrom typing import Optional\n\nfrom sglang.srt.debug_utils.comparator.dims_spec.types import (\n    _AXIS_LOOKUP,\n    _QUALIFIER_LOOKUP,\n    Ordering,\n    ParallelAxis,\n    ParallelModifier,\n    Reduction,\n)\n\n\ndef _parse_modifier_token(modifier_token: str, dim_token: str) -> ParallelModifier:\n    \"\"\"Parse 'sp', 'cp:zigzag', 'tp:partial', or 'cp:zigzag+partial' → ParallelModifier.\n\n    Format: ``axis`` or ``axis:qual`` or ``axis:qual+qual``.\n    Colon separates axis from qualifiers; ``+`` separates multiple qualifiers.\n    \"\"\"\n    axis_str: str\n    qualifiers_str: str\n    if \":\" in modifier_token:\n        axis_str, qualifiers_str = modifier_token.split(\":\", maxsplit=1)\n    else:\n        axis_str, qualifiers_str = modifier_token, \"\"\n\n    axis_str = axis_str.strip()\n    axis: Optional[ParallelAxis] = _AXIS_LOOKUP.get(axis_str)\n    if axis is None:\n        raise ValueError(\n            f\"Unknown axis {axis_str!r} in modifier {modifier_token!r} \"\n            f\"of dim spec: {dim_token!r}\"\n        )\n\n    ordering: Optional[Ordering] = None\n    reduction: Optional[Reduction] = None\n\n    for q_str in (q.strip() for q in qualifiers_str.split(\"+\") if q.strip()):\n        if q_str == \"sharded\":\n            continue\n        qualifier: Optional[Ordering | Reduction] = _QUALIFIER_LOOKUP.get(q_str)\n        if qualifier is None:\n            raise ValueError(\n                f\"Unknown qualifier {q_str!r} in modifier \"\n                f\"{modifier_token!r} of dim spec: {dim_token!r}\"\n            )\n        if isinstance(qualifier, Ordering):\n            if ordering is not None:\n                raise ValueError(\n                    f\"Multiple ordering values in modifier \"\n                    f\"{modifier_token!r} of dim spec: {dim_token!r}\"\n                )\n            ordering = qualifier\n        else:\n            if reduction is not None:\n                
raise ValueError(\n                    f\"Multiple reduction values in modifier \"\n                    f\"{modifier_token!r} of dim spec: {dim_token!r}\"\n                )\n            reduction = qualifier\n\n    return ParallelModifier(axis=axis, ordering=ordering, reduction=reduction)\n\n\ndef _parse_modifiers(\n    *, modifiers_str: Optional[str], dim_token: str\n) -> list[ParallelModifier]:\n    if modifiers_str is None:\n        return []\n\n    modifiers: list[ParallelModifier] = []\n    seen_axes: set[ParallelAxis] = set()\n\n    for modifier_token in (p.strip() for p in modifiers_str.split(\",\")):\n        modifier: ParallelModifier = _parse_modifier_token(modifier_token, dim_token)\n        if modifier.axis in seen_axes:\n            raise ValueError(\n                f\"Duplicate axis {modifier.axis.value!r} in dim spec: {dim_token!r}\"\n            )\n        seen_axes.add(modifier.axis)\n        modifiers.append(modifier)\n\n    return modifiers\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/dims_spec/tensor_naming.py",
    "content": "from __future__ import annotations\n\nfrom typing import Optional\n\nimport torch\n\nfrom sglang.srt.debug_utils.comparator.dims_spec.types import DimSpec\n\n\ndef find_dim_index(dim_specs: list[DimSpec], name: str) -> Optional[int]:\n    \"\"\"Find index by name. Accepts both ``*``-form and ``___``-form for fused dims.\"\"\"\n    for i, spec in enumerate(dim_specs):\n        if spec.name == name or spec.sanitized_name == name:\n            return i\n    return None\n\n\ndef resolve_dim_by_name(tensor: torch.Tensor, name: str) -> int:\n    if tensor.names[0] is None:\n        raise ValueError(f\"Tensor has no names, cannot resolve {name!r}\")\n\n    names: tuple[Optional[str], ...] = tensor.names\n    try:\n        return list(names).index(name)\n    except ValueError:\n        raise ValueError(f\"Dim name {name!r} not in tensor names {names}\")\n\n\ndef apply_dim_names(tensor: torch.Tensor, dim_names: list[str]) -> torch.Tensor:\n    if tensor.ndim != len(dim_names):\n        raise ValueError(\n            f\"dims metadata mismatch: tensor has {tensor.ndim} dims (shape {list(tensor.shape)}) \"\n            f\"but dims string specifies {len(dim_names)} names {dim_names}. \"\n            f\"Please fix the dims string in the dumper.dump() call to match the actual tensor shape.\"\n        )\n    return tensor.refine_names(*dim_names)\n\n\ndef strip_dim_names(tensor: torch.Tensor) -> torch.Tensor:\n    return tensor.rename(None)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/dims_spec/types.py",
    "content": "from __future__ import annotations\n\nfrom enum import Enum\nfrom typing import Optional\n\nfrom sglang.srt.debug_utils.comparator.utils import _FrozenBase\n\nTOKEN_DIM_NAME: str = \"t\"\nBATCH_DIM_NAME: str = \"b\"\nSEQ_DIM_NAME: str = \"s\"\nSQUEEZE_DIM_NAME: str = \"1\"\n\n\nclass TokenLayout(Enum):\n    T = \"t\"  # single flat token dim\n    BS = \"bs\"  # separate batch + seq dims, need collapse\n\n\nclass ParallelAxis(Enum):\n    TP = \"tp\"\n    CP = \"cp\"\n    EP = \"ep\"\n    SP = \"sp\"\n    RECOMPUTE_PSEUDO = \"recompute_pseudo\"\n\n\nclass Ordering(Enum):\n    ZIGZAG = \"zigzag\"\n    NATURAL = \"natural\"\n\n\nclass Reduction(Enum):\n    PARTIAL = \"partial\"\n\n\nclass ParallelModifier(_FrozenBase):\n    axis: ParallelAxis\n    ordering: Optional[Ordering] = None\n    reduction: Optional[Reduction] = None\n\n\n_AXIS_LOOKUP: dict[str, ParallelAxis] = {m.value: m for m in ParallelAxis}\n_QUALIFIER_LOOKUP: dict[str, Ordering | Reduction] = {\n    **{m.value: m for m in Ordering},\n    **{m.value: m for m in Reduction},\n}\n\n_FUSED_NAME_SEP: str = \"___\"\n\n\nclass DimSpec(_FrozenBase):\n    name: str\n    parallel_modifiers: list[ParallelModifier] = []\n\n    @property\n    def sub_dims(self) -> list[str]:\n        \"\"\"Sub-dim names. Fused: ``[\"num_heads\", \"head_dim\"]``; plain: ``[\"h\"]``.\"\"\"\n        return self.name.split(\"*\")\n\n    @property\n    def is_fused(self) -> bool:\n        return len(self.sub_dims) > 1\n\n    @property\n    def sanitized_name(self) -> str:\n        \"\"\"Name safe for PyTorch named tensors (``*`` → ``___``).\"\"\"\n        if self.is_fused:\n            return _FUSED_NAME_SEP.join(self.sub_dims)\n        return self.name\n\n\nclass DimsSpec(_FrozenBase):\n    \"\"\"Parsed result of a full dims string like ``\"b s h[tp] # dp:=moe_dp\"``.\"\"\"\n\n    dims: list[DimSpec]\n    dp_group_alias: Optional[str] = None\n    replicated_axes: frozenset[ParallelAxis] = frozenset()\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/display.py",
    "content": "from __future__ import annotations\n\nfrom collections import defaultdict\nfrom io import StringIO\nfrom pathlib import Path\nfrom typing import Any, Optional\n\nimport polars as pl\n\nfrom sglang.srt.debug_utils.comparator.output_types import (\n    InputIdsRecord,\n    RankInfoRecord,\n)\nfrom sglang.srt.debug_utils.comparator.report_sink import report_sink\nfrom sglang.srt.debug_utils.dump_loader import LOAD_FAILED, ValueWithMeta\n\nPARALLEL_INFO_KEYS: list[str] = [\"sglang_parallel_info\", \"megatron_parallel_info\"]\n\n\ndef emit_display_records(\n    *,\n    df: pl.DataFrame,\n    dump_dir: Path,\n    label: str,\n    tokenizer: Any,\n) -> None:\n    rank_rows: Optional[list[dict[str, Any]]] = _collect_rank_info(\n        df, dump_dir=dump_dir\n    )\n    if rank_rows is not None:\n        report_sink.add(RankInfoRecord(label=label, rows=rank_rows))\n\n    input_ids_rows: Optional[list[dict[str, Any]]] = _collect_input_ids_and_positions(\n        df, dump_dir=dump_dir, tokenizer=tokenizer\n    )\n    if input_ids_rows is not None:\n        report_sink.add(InputIdsRecord(label=label, rows=input_ids_rows))\n\n\ndef _render_polars_as_text(df: pl.DataFrame, *, title: Optional[str] = None) -> str:\n    from rich.console import Console\n    from rich.table import Table\n\n    table = Table(title=title)\n    for col in df.columns:\n        table.add_column(col)\n    for row in df.iter_rows():\n        table.add_row(*[str(v) for v in row])\n\n    buf = StringIO()\n    Console(file=buf, force_terminal=False, width=200).print(table)\n    return buf.getvalue().rstrip(\"\\n\")\n\n\ndef _collect_rank_info(\n    df: pl.DataFrame, dump_dir: Path\n) -> Optional[list[dict[str, Any]]]:\n    unique_rows: pl.DataFrame = (\n        df.filter(pl.col(\"name\") == \"input_ids\")\n        .sort(\"rank\")\n        .unique(subset=[\"rank\"], keep=\"first\")\n    )\n    if unique_rows.is_empty():\n        return None\n\n    table_rows: list[dict[str, Any]] = []\n    for 
row in unique_rows.to_dicts():\n        meta: dict[str, Any] = ValueWithMeta.load(dump_dir / row[\"filename\"]).meta\n\n        row_data: dict[str, Any] = {\"rank\": row[\"rank\"]}\n        for key in PARALLEL_INFO_KEYS:\n            _extract_parallel_info(row_data=row_data, info=meta.get(key, {}))\n        table_rows.append(row_data)\n\n    return table_rows or None\n\n\ndef _collect_input_ids_and_positions(\n    df: pl.DataFrame,\n    dump_dir: Path,\n    *,\n    tokenizer: Any = None,\n) -> Optional[list[dict[str, Any]]]:\n    filtered: pl.DataFrame = df.filter(pl.col(\"name\").is_in([\"input_ids\", \"positions\"]))\n    if filtered.is_empty():\n        return None\n\n    data_by_step_rank: dict[tuple[int, int], dict[str, Any]] = defaultdict(dict)\n    for row in filtered.to_dicts():\n        key: tuple[int, int] = (row[\"step\"], row[\"rank\"])\n        item: ValueWithMeta = ValueWithMeta.load(dump_dir / row[\"filename\"])\n        if item.value is not LOAD_FAILED:\n            data_by_step_rank[key][row[\"name\"]] = item.value\n\n    table_rows: list[dict[str, Any]] = []\n    for (step, rank), data in sorted(data_by_step_rank.items()):\n        ids = data.get(\"input_ids\")\n        pos = data.get(\"positions\")\n\n        ids_list: Optional[list[int]] = (\n            ids.flatten().tolist() if ids is not None else None\n        )\n\n        row_data: dict[str, Any] = {\n            \"step\": step,\n            \"rank\": rank,\n            \"num_tokens\": len(ids_list) if ids_list is not None else None,\n            \"input_ids\": str(ids_list) if ids_list is not None else \"N/A\",\n            \"positions\": str(pos.flatten().tolist()) if pos is not None else \"N/A\",\n        }\n\n        if tokenizer is not None and ids_list is not None:\n            row_data[\"decoded_text\"] = repr(\n                tokenizer.decode(ids_list, skip_special_tokens=False)\n            )\n\n        table_rows.append(row_data)\n\n    return table_rows or None\n\n\ndef 
_extract_parallel_info(row_data: dict[str, Any], info: dict[str, Any]) -> None:\n    if not info or info.get(\"error\"):\n        return\n\n    for key in sorted(info.keys()):\n        if key.endswith(\"_rank\"):\n            base: str = key[:-5]\n            size_key: str = f\"{base}_size\"\n            if size_key in info:\n                row_data[base] = f\"{info[key]}/{info[size_key]}\"\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/dp_utils.py",
    "content": "\"\"\"DP filtering: keep only the non-empty dp_rank items.\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections import defaultdict\nfrom typing import Optional\n\nimport torch\n\nfrom sglang.srt.debug_utils.dump_loader import ValueWithMeta\n\n_PARALLEL_INFO_KEYS = (\"sglang_parallel_info\", \"megatron_parallel_info\")\n\n_DP_RANK_FIELD = \"dp_rank\"\n_DP_SIZE_FIELD = \"dp_size\"\n\n\ndef filter_to_non_empty_dp_rank(\n    items: list[ValueWithMeta],\n    *,\n    dp_group_alias: Optional[str] = None,\n) -> list[ValueWithMeta]:\n    \"\"\"Filter items to the single non-empty dp_rank.\n\n    - dp_size <= 1: return items unchanged.\n    - dp_size > 1: group by dp_rank, assert exactly one group has non-empty\n      tensors, return that group.\n\n    When *dp_group_alias* is set (e.g. ``\"moe_dp\"``), the function looks\n    for ``<alias>_rank`` / ``<alias>_size`` instead of the default\n    ``dp_rank`` / ``dp_size``.  If the aliased fields are absent the\n    filter is a noop (items returned unchanged).\n    \"\"\"\n    if not items:\n        return items\n\n    dp_info: Optional[tuple[int, int]] = _extract_dp_info(\n        items[0].meta, dp_group_alias=dp_group_alias\n    )\n    if dp_info is None:\n        return items\n\n    _dp_rank, dp_size = dp_info\n    if dp_size <= 1:\n        return items\n\n    has_any_tensor: bool = any(isinstance(item.value, torch.Tensor) for item in items)\n    if not has_any_tensor:\n        return items\n\n    groups: dict[int, list[ValueWithMeta]] = defaultdict(list)\n    for item in items:\n        item_dp: Optional[tuple[int, int]] = _extract_dp_info(\n            item.meta, dp_group_alias=dp_group_alias\n        )\n        rank: int = item_dp[0] if item_dp is not None else 0\n        groups[rank].append(item)\n\n    non_empty_ranks: list[int] = [\n        rank for rank, group in groups.items() if _group_has_data(group)\n    ]\n\n    assert len(non_empty_ranks) == 1, (\n        f\"Expected exactly 1 
non-empty dp_rank, got {len(non_empty_ranks)}: \"\n        f\"ranks={non_empty_ranks}\"\n    )\n\n    return groups[non_empty_ranks[0]]\n\n\ndef _extract_dp_info(\n    meta: dict,\n    *,\n    dp_group_alias: Optional[str] = None,\n) -> Optional[tuple[int, int]]:\n    \"\"\"Extract (dp_rank, dp_size) from meta's parallel_info block.\n\n    When *dp_group_alias* is given, look for ``<alias>_rank``/``<alias>_size``\n    instead of the default ``dp_rank``/``dp_size``.\n    \"\"\"\n    rank_field: str = f\"{dp_group_alias}_rank\" if dp_group_alias else _DP_RANK_FIELD\n    size_field: str = f\"{dp_group_alias}_size\" if dp_group_alias else _DP_SIZE_FIELD\n\n    for key in _PARALLEL_INFO_KEYS:\n        info = meta.get(key)\n        if not isinstance(info, dict) or not info:\n            continue\n\n        dp_rank = info.get(rank_field)\n        dp_size = info.get(size_field)\n        if dp_rank is not None and dp_size is not None:\n            return (int(dp_rank), int(dp_size))\n\n    return None\n\n\ndef _group_has_data(group: list[ValueWithMeta]) -> bool:\n    \"\"\"Check if any tensor in the group is non-empty (numel > 0).\"\"\"\n    return any(\n        isinstance(item.value, torch.Tensor) and item.value.numel() > 0\n        for item in group\n    )\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/entrypoint.py",
    "content": "from __future__ import annotations\n\nimport argparse\nimport sys\nimport traceback as _traceback_module\nfrom pathlib import Path\nfrom typing import Any, Iterator, Optional, Union\n\nimport polars as pl\n\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.entrypoint import (\n    TokenAlignerResult,\n    compute_maybe_token_aligner_result,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.aux_loader import (\n    AUX_NAMES,\n)\nfrom sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.types import (\n    TokenAlignerPlan,\n)\nfrom sglang.srt.debug_utils.comparator.bundle_comparator import compare_bundle_pair\nfrom sglang.srt.debug_utils.comparator.bundle_matcher import (\n    TensorBundleInfo,\n    match_bundles,\n)\nfrom sglang.srt.debug_utils.comparator.display import emit_display_records\nfrom sglang.srt.debug_utils.comparator.meta_overrider import MetaOverrider\nfrom sglang.srt.debug_utils.comparator.output_types import (\n    ComparisonErrorRecord,\n    ComparisonNonTensorRecord,\n    ComparisonSkipRecord,\n    ComparisonTensorRecord,\n    ConfigRecord,\n    RecordLocation,\n    SummaryRecord,\n)\nfrom sglang.srt.debug_utils.comparator.per_token_visualizer import (\n    generate_per_token_heatmap,\n)\nfrom sglang.srt.debug_utils.comparator.preset import PRESETS, expand_preset\nfrom sglang.srt.debug_utils.comparator.report_sink import report_sink\nfrom sglang.srt.debug_utils.comparator.utils import (\n    Pair,\n    auto_descend_dir,\n    compute_exit_code,\n)\nfrom sglang.srt.debug_utils.dump_loader import read_meta, read_tokenizer_path\n\n_DEFAULT_SKIP_KEYS: set[str] = {\"dump_index\", \"filename\"}\n\n\ndef main() -> None:\n    args = parse_args(sys.argv[1:])\n    sys.exit(run(args))\n\n\ndef run(args: argparse.Namespace) -> int:\n    report_sink.configure(\n        output_format=args.output_format,\n        report_path=None,\n        verbosity=args.verbosity,\n    )\n\n    dir_pair: Pair[Path] = 
Pair(\n        x=auto_descend_dir(Path(args.baseline_path), label=\"baseline_path\"),\n        y=auto_descend_dir(Path(args.target_path), label=\"target_path\"),\n    )\n    viz_output_dir: Optional[Path] = (\n        Path(args.viz_output_dir) if args.viz_bundle_details else None\n    )\n    visualize_per_token: Optional[Path] = (\n        Path(args.visualize_per_token) if args.visualize_per_token else None\n    )\n    override_config: Optional[Path] = (\n        Path(args.override_config) if args.override_config else None\n    )\n\n    report_path: Optional[Path] = _resolve_report_path(\n        target_path=dir_pair.y,\n        report_path_arg=args.report_path,\n    )\n    report_sink.configure(\n        output_format=args.output_format,\n        report_path=report_path,\n        verbosity=args.verbosity,\n    )\n\n    try:\n        report_sink.add(ConfigRecord(config=vars(args)))\n\n        dfs: Pair[pl.DataFrame] = _read_df(\n            dir_pair=dir_pair,\n            start_step=args.start_step,\n            end_step=args.end_step,\n            filter_pattern=args.filter,\n        )\n\n        tokenizer: Any = _maybe_load_tokenizer(\n            tokenizer_arg=args.tokenizer, dir_pair=dir_pair\n        )\n        for label, df, dump_dir in [\n            (\"baseline\", dfs.x, dir_pair.x),\n            (\"target\", dfs.y, dir_pair.y),\n        ]:\n            emit_display_records(\n                df=df, dump_dir=dump_dir, label=label, tokenizer=tokenizer\n            )\n\n        ta_result: TokenAlignerResult = compute_maybe_token_aligner_result(\n            dir_pair=dir_pair,\n            dfs=dfs,\n            token_aligner_mode=args.token_aligner,\n        )\n\n        if ta_result.mode == \"smart\":\n            dfs = dfs.map(lambda df: df.filter(~pl.col(\"name\").is_in(AUX_NAMES)))\n\n        skip_keys: set[str] = _DEFAULT_SKIP_KEYS | set(args.grouping_skip_keys or [])\n        bundle_info_pairs: list[Pair[TensorBundleInfo]] = match_bundles(\n            
dfs=dfs, skip_keys=skip_keys\n        )\n\n        meta_overrider: MetaOverrider = MetaOverrider.from_args_and_config(\n            override_dims=args.override_dims,\n            override_baseline_dims=args.override_baseline_dims,\n            override_target_dims=args.override_target_dims,\n            override_config=override_config,\n        )\n\n        comparison_records = _compare_bundle_pairs(\n            bundle_info_pairs=bundle_info_pairs,\n            dir_pair=dir_pair,\n            token_aligner_mode=ta_result.mode,\n            token_aligner_plan=ta_result.plan,\n            diff_threshold=args.diff_threshold,\n            thd_seq_lens_by_step_pair=ta_result.thd_seq_lens_by_step_pair,\n            viz_output_dir=viz_output_dir,\n            compute_per_token=visualize_per_token is not None,\n            meta_overrider=meta_overrider,\n        )\n        summary, skipped_names, failed_names, errored_names = (\n            _consume_comparison_records(\n                comparison_records=comparison_records,\n                visualize_per_token=visualize_per_token,\n            )\n        )\n        return compute_exit_code(\n            summary,\n            allow_skipped_pattern=args.allow_skipped_pattern,\n            skipped_names=skipped_names,\n            allow_failed_pattern=args.allow_failed_pattern,\n            failed_names=failed_names,\n            errored_names=errored_names,\n        )\n    finally:\n        report_sink.close()\n        if report_path is not None:\n            print(f\"Report: {report_path}\", file=sys.stderr)\n\n\ndef _resolve_report_path(\n    *, target_path: Path, report_path_arg: Optional[str]\n) -> Optional[Path]:\n    if report_path_arg is not None:\n        return Path(report_path_arg) if report_path_arg else None\n    return target_path / \"comparator_report.jsonl\"\n\n\ndef _maybe_load_tokenizer(*, tokenizer_arg: Optional[str], dir_pair: Pair[Path]) -> Any:\n    tokenizer_path: Optional[str] = tokenizer_arg\n\n    
if tokenizer_path is None:\n        for directory in [dir_pair.x, dir_pair.y]:\n            tokenizer_path = read_tokenizer_path(directory)\n            if tokenizer_path is not None:\n                break\n\n    if tokenizer_path is None:\n        return None\n\n    try:\n        from transformers import AutoTokenizer\n\n        return AutoTokenizer.from_pretrained(tokenizer_path)\n    except Exception:\n        return None\n\n\ndef _read_df(\n    *,\n    dir_pair: Pair[Path],\n    start_step: int,\n    end_step: int,\n    filter_pattern: Optional[str],\n) -> Pair[pl.DataFrame]:\n    df_baseline = read_meta(dir_pair.x)\n\n    df_target = read_meta(dir_pair.y)\n    df_target = df_target.filter(\n        (pl.col(\"step\") >= start_step) & (pl.col(\"step\") <= end_step)\n    )\n    if filter_pattern:\n        df_target = df_target.filter(pl.col(\"filename\").str.contains(filter_pattern))\n    assert all(c in df_target.columns for c in [\"rank\", \"step\", \"dump_index\", \"name\"])\n\n    return Pair(x=df_baseline, y=df_target)\n\n\ndef _compare_bundle_pairs(\n    *,\n    bundle_info_pairs: list[Pair[TensorBundleInfo]],\n    dir_pair: Pair[Path],\n    token_aligner_mode: Optional[str],\n    token_aligner_plan: Optional[TokenAlignerPlan],\n    diff_threshold: float,\n    thd_seq_lens_by_step_pair: Pair[Optional[dict[int, list[int]]]],\n    viz_output_dir: Optional[Path] = None,\n    compute_per_token: bool = False,\n    meta_overrider: Optional[MetaOverrider] = None,\n) -> Iterator[\n    Union[\n        ComparisonTensorRecord,\n        ComparisonSkipRecord,\n        ComparisonNonTensorRecord,\n        ComparisonErrorRecord,\n    ]\n]:\n    for bundle_info_pair in bundle_info_pairs:\n        if not bundle_info_pair.y:\n            continue\n\n        name: str = bundle_info_pair.y[0].name\n        filenames_pair: Pair[list[str]] = bundle_info_pair.map(\n            lambda infos: [info.filename for info in infos]\n        )\n\n        record: Union[\n            
ComparisonTensorRecord,\n            ComparisonSkipRecord,\n            ComparisonNonTensorRecord,\n            ComparisonErrorRecord,\n        ]\n        try:\n            record = compare_bundle_pair(\n                name=name,\n                filenames_pair=filenames_pair,\n                dir_pair=dir_pair,\n                token_aligner_mode=token_aligner_mode,\n                token_aligner_plan=token_aligner_plan,\n                diff_threshold=diff_threshold,\n                thd_seq_lens_by_step_pair=thd_seq_lens_by_step_pair,\n                viz_output_dir=viz_output_dir,\n                compute_per_token=compute_per_token,\n                meta_overrider=meta_overrider,\n            )\n        except Exception as exc:\n            record = ComparisonErrorRecord(\n                name=name,\n                exception_type=type(exc).__name__,\n                traceback_str=_traceback_module.format_exc(),\n            )\n\n        target_steps: set[int] = {info.step for info in bundle_info_pair.y}\n        step: Optional[int] = target_steps.pop() if len(target_steps) == 1 else None\n        if step is not None:\n            record = record.model_copy(update={\"location\": RecordLocation(step=step)})\n\n        yield record\n\n\ndef _consume_comparison_records(\n    *,\n    comparison_records: Iterator[\n        Union[\n            ComparisonTensorRecord,\n            ComparisonSkipRecord,\n            ComparisonNonTensorRecord,\n            ComparisonErrorRecord,\n        ]\n    ],\n    visualize_per_token: Optional[Path] = None,\n) -> tuple[SummaryRecord, list[str], list[str], list[str]]:\n    counts: dict[str, int] = {\"passed\": 0, \"failed\": 0, \"skipped\": 0, \"errored\": 0}\n    collected_comparisons: list[ComparisonTensorRecord] = []\n    skipped_names: list[str] = []\n    failed_names: list[str] = []\n    errored_names: list[str] = []\n\n    for record in comparison_records:\n        counts[record.category] += 1\n        
report_sink.add(record)\n        if isinstance(record, ComparisonSkipRecord) and record.category == \"skipped\":\n            skipped_names.append(record.name)\n        if record.category == \"failed\":\n            failed_names.append(record.name)\n        if isinstance(record, ComparisonErrorRecord):\n            errored_names.append(record.name)\n        if visualize_per_token is not None and isinstance(\n            record, ComparisonTensorRecord\n        ):\n            collected_comparisons.append(record)\n\n    summary: SummaryRecord = SummaryRecord(total=sum(counts.values()), **counts)\n    report_sink.add(summary)\n\n    if visualize_per_token is not None and collected_comparisons:\n        generate_per_token_heatmap(\n            records=collected_comparisons,\n            output_path=visualize_per_token,\n        )\n\n    return summary, skipped_names, failed_names, errored_names\n\n\ndef parse_args(argv: list[str]) -> argparse.Namespace:\n    \"\"\"Parse CLI arguments from an argv list. 
Applies preset expansion.\"\"\"\n    argv = expand_preset(argv, presets=PRESETS)\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--baseline-path\", type=str)\n    parser.add_argument(\"--target-path\", type=str)\n    parser.add_argument(\"--start-step\", type=int, default=0)\n    parser.add_argument(\"--end-step\", type=int, default=1000000)\n    parser.add_argument(\"--diff-threshold\", type=float, default=1e-3)\n    parser.add_argument(\n        \"--filter\", type=str, default=None, help=\"Regex to filter filenames (include)\"\n    )\n    parser.add_argument(\n        \"--output-format\",\n        type=str,\n        choices=[\"text\", \"json\"],\n        default=\"text\",\n        help=\"Output format: text (default) or json (JSONL, one JSON object per line)\",\n    )\n    parser.add_argument(\n        \"--verbosity\",\n        type=str,\n        choices=[\"minimal\", \"normal\", \"verbose\"],\n        default=\"normal\",\n        help=\"Output verbosity: minimal (1 line per tensor), normal (compact lifecycle), \"\n        \"verbose (full detail). Default: normal\",\n    )\n    parser.add_argument(\n        \"--preset\",\n        type=str,\n        choices=list(PRESETS.keys()),\n        default=None,\n        help=\"Preset configuration (expanded before parsing). \"\n        f\"Available: {list(PRESETS.keys())}\",\n    )\n    parser.add_argument(\n        \"--grouping-skip-keys\",\n        nargs=\"*\",\n        default=None,\n        help=\"Metadata keys to skip when grouping bundles (additive on top of \"\n        \"always-skipped dump_index and filename). \"\n        \"E.g. '--grouping-skip-keys rank step' skips rank and step.\",\n    )\n    parser.add_argument(\n        \"--token-aligner\",\n        type=str,\n        choices=[\"smart\", \"concat_steps\"],\n        default=None,\n        help=\"Token aligner mode: concat_steps (BS=1, no aux needed) or smart (BS>1, sequence matching). 
\"\n        \"Default None (per-step comparison).\",\n    )\n    parser.add_argument(\n        \"--tokenizer\",\n        type=str,\n        default=None,\n        help=\"Tokenizer path for decoding input_ids (auto-discovered from dump metadata if not set)\",\n    )\n    parser.add_argument(\n        \"--viz-bundle-details\",\n        action=\"store_true\",\n        default=False,\n        help=\"Generate comparison heatmap/histogram PNG for each compared tensor\",\n    )\n    parser.add_argument(\n        \"--viz-output-dir\",\n        type=str,\n        default=\"/tmp/comparator_viz/\",\n        help=\"Output directory for visualization PNGs (default: /tmp/comparator_viz/)\",\n    )\n    parser.add_argument(\n        \"--visualize-per-token\",\n        type=str,\n        default=None,\n        help=\"Output path for per-token relative difference heatmap PNG\",\n    )\n\n    # Dims override\n    parser.add_argument(\n        \"--override-dims\",\n        action=\"append\",\n        default=[],\n        help=\"Override dims for both sides: 'name:dims_string' (repeatable)\",\n    )\n    parser.add_argument(\n        \"--override-baseline-dims\",\n        action=\"append\",\n        default=[],\n        help=\"Override dims for baseline only: 'name:dims_string' (repeatable)\",\n    )\n    parser.add_argument(\n        \"--override-target-dims\",\n        action=\"append\",\n        default=[],\n        help=\"Override dims for target only: 'name:dims_string' (repeatable)\",\n    )\n    parser.add_argument(\n        \"--override-config\",\n        type=str,\n        default=None,\n        help=\"Path to YAML override config file (dims overrides, etc.)\",\n    )\n    parser.add_argument(\n        \"--allow-skipped-pattern\",\n        type=str,\n        default=\".*\",\n        help=\"Regex pattern for tensor names allowed to be skipped. \"\n        \"Default '.*' allows all skips. 
Use '^$' to forbid all skips.\",\n    )\n    parser.add_argument(\n        \"--allow-failed-pattern\",\n        type=str,\n        default=None,\n        help=\"Regex pattern for tensor names allowed to fail without affecting exit code. \"\n        \"Default None (all failures affect exit code).\",\n    )\n\n    # Report output\n    parser.add_argument(\n        \"--report-path\",\n        type=str,\n        default=None,\n        help=\"Path for JSONL report (default: <target-path>/comparator_report.jsonl). \"\n        \"Pass empty string '' to disable.\",\n    )\n\n    return parser.parse_args(argv)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/log_sink.py",
    "content": "from __future__ import annotations\n\nfrom contextlib import contextmanager\nfrom typing import Generator\n\nfrom sglang.srt.debug_utils.comparator.output_types import BaseLog\n\n\nclass LogSink:\n    def __init__(self) -> None:\n        self._stack: list[list[BaseLog]] = []\n\n    @contextmanager\n    def context(self) -> Generator[list[BaseLog], None, None]:\n        bucket: list[BaseLog] = []\n        self._stack.append(bucket)\n        try:\n            yield bucket\n        finally:\n            popped = self._stack.pop()\n            assert popped is bucket\n\n    def add(self, log: BaseLog) -> None:\n        if self._stack:\n            self._stack[-1].append(log)\n        else:\n            from sglang.srt.debug_utils.comparator.output_types import (\n                LogRecord,\n                _split_logs,\n            )\n            from sglang.srt.debug_utils.comparator.report_sink import report_sink\n\n            errors, infos = _split_logs([log])\n            report_sink.add(LogRecord(errors=errors, infos=infos))\n\n\nlog_sink = LogSink()\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/meta_overrider.py",
    "content": "\"\"\"Meta overrider: replace metadata fields without re-running dumps.\n\nCurrently only overrides 'dims', but the design supports overriding\nadditional meta fields (e.g. parallel_info) in the future.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport re\nfrom pathlib import Path\nfrom typing import Any, Literal, Optional\n\nimport yaml\n\nfrom sglang.srt.debug_utils.comparator.utils import _StrictBase\n\n\nclass MetaOverrideRule(_StrictBase):\n    \"\"\"Single override rule: regex match on tensor name → replacement meta field(s).\n\n    Currently only 'dims' is supported; more fields may be added in the future.\n    \"\"\"\n\n    match: str\n    dims: str\n    side: Literal[\"both\", \"baseline\", \"target\"] = \"both\"\n\n\nclass MetaOverrideConfig(_StrictBase):\n    \"\"\"YAML top-level config for overriding comparator behavior.\"\"\"\n\n    overrides: list[MetaOverrideRule] = []\n\n\nclass MetaOverrider:\n    \"\"\"Holds override rules and applies first-match-wins replacement.\"\"\"\n\n    def __init__(self, rules: list[MetaOverrideRule]) -> None:\n        self._rules: list[MetaOverrideRule] = rules\n\n    @property\n    def is_empty(self) -> bool:\n        return len(self._rules) == 0\n\n    @classmethod\n    def from_args_and_config(\n        cls,\n        *,\n        override_dims: list[str],\n        override_baseline_dims: list[str],\n        override_target_dims: list[str],\n        override_config: Optional[Path],\n    ) -> \"MetaOverrider\":\n        per_side_args: list[tuple[list[str], Literal[\"both\", \"baseline\", \"target\"]]] = [\n            (override_dims, \"both\"),\n            (override_baseline_dims, \"baseline\"),\n            (override_target_dims, \"target\"),\n        ]\n        cli_rules: list[MetaOverrideRule] = [\n            MetaOverrideRule(match=name, dims=dims_str, side=side)\n            for raw_args, side in per_side_args\n            for name, dims_str in [_parse_cli_override_arg(raw) for raw in 
raw_args]\n        ]\n\n        yaml_rules: list[MetaOverrideRule] = (\n            _load_yaml_rules(override_config) if override_config is not None else []\n        )\n\n        return cls(rules=cli_rules + yaml_rules)\n\n    def apply_to_meta(\n        self,\n        *,\n        name: str,\n        meta: dict[str, Any],\n        side: Literal[\"baseline\", \"target\"],\n    ) -> dict[str, Any]:\n        \"\"\"First-match-wins: return meta with dims replaced by the first matching rule for this side.\"\"\"\n        for rule in self._rules:\n            if rule.side not in (\"both\", side):\n                continue\n            if re.search(rule.match, name):\n                return {**meta, \"dims\": rule.dims}\n\n        return meta\n\n\ndef _parse_cli_override_arg(raw: str) -> tuple[str, str]:\n    \"\"\"Parse 'name:dims_string' from a CLI --override-* argument.\"\"\"\n    parts: list[str] = raw.split(\":\", maxsplit=1)\n    if len(parts) != 2 or not parts[0].strip() or not parts[1].strip():\n        raise ValueError(\n            f\"Invalid override format: {raw!r}; expected 'name:dims_string'\"\n        )\n    return parts[0].strip(), parts[1].strip()\n\n\ndef _load_yaml_rules(path: Path) -> list[MetaOverrideRule]:\n    \"\"\"Load override rules from a YAML config file.\"\"\"\n    with open(path) as f:\n        raw_data: Any = yaml.safe_load(f)\n\n    if raw_data is None:\n        return []\n\n    config: MetaOverrideConfig = MetaOverrideConfig.model_validate(raw_data)\n    return config.overrides\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/output_formatter.py",
    "content": "\"\"\"Formatting functions for comparator output records.\n\nExtracted from output_types.py to separate data-structure definitions\nfrom rendering / formatting logic.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Literal\n\nfrom rich.console import Group\nfrom rich.markup import escape\nfrom rich.panel import Panel\n\nfrom sglang.srt.debug_utils.comparator.tensor_comparator.formatter import (\n    format_comparison,\n    format_replicated_checks,\n)\n\nif TYPE_CHECKING:\n    from rich.console import RenderableType\n\n    from sglang.srt.debug_utils.comparator.aligner.entrypoint.traced_types import (\n        TracedAlignerPlan,\n        TracedSubPlan,\n    )\n    from sglang.srt.debug_utils.comparator.aligner.entrypoint.types import AlignerPlan\n    from sglang.srt.debug_utils.comparator.output_types import (\n        ComparisonErrorRecord,\n        ComparisonNonTensorRecord,\n        ComparisonSkipRecord,\n        ComparisonTensorRecord,\n        ConfigRecord,\n        ErrorLog,\n        InfoLog,\n        LogRecord,\n        SummaryRecord,\n        _OutputRecord,\n        _TableRecord,\n    )\n\nVerbosity = Literal[\"minimal\", \"normal\", \"verbose\"]\n\n\n# ── Record-level rendering (body + logs) ─────────────────────────────\n\n\ndef _render_record_rich(\n    record: _OutputRecord, *, verbosity: Verbosity = \"normal\"\n) -> RenderableType:\n    body: RenderableType = record._format_rich_body(verbosity=verbosity)\n\n    log_lines: list[str] = _format_log_lines_rich(\n        errors=record.errors, infos=record.infos\n    )\n\n    if not log_lines:\n        return body\n\n    log_block: str = \"\\n\".join(log_lines)\n    if isinstance(body, str):\n        return body + \"\\n\" + log_block\n    return Group(body, log_block)\n\n\ndef _render_record_text(record: _OutputRecord) -> str:\n    body: str = record._format_body()\n\n    log_suffix: str = _format_log_lines_text(errors=record.errors, infos=record.infos)\n\n  
  if log_suffix:\n        body += \"\\n\" + log_suffix\n\n    return body\n\n\ndef _format_log_lines_rich(\n    *, errors: list[ErrorLog], infos: list[InfoLog]\n) -> list[str]:\n    lines: list[str] = []\n\n    if errors:\n        lines.extend(f\"  [red]✗ {e.to_text()}[/]\" for e in errors)\n    if infos:\n        lines.extend(f\"  [dim]ℹ {i.to_text()}[/]\" for i in infos)\n\n    return lines\n\n\ndef _format_log_lines_text(*, errors: list[ErrorLog], infos: list[InfoLog]) -> str:\n    lines: list[str] = []\n\n    if errors:\n        lines.extend(f\"  ✗ {e.to_text()}\" for e in errors)\n    if infos:\n        lines.extend(f\"  ℹ {i.to_text()}\" for i in infos)\n\n    return \"\\n\".join(lines)\n\n\n# ── ConfigRecord ──────────────────────────────────────────────────────\n\n\ndef _format_config_body(record: ConfigRecord) -> str:\n    return f\"Config: {record.config}\"\n\n\ndef _format_config_rich_body(\n    record: ConfigRecord, verbosity: Verbosity = \"normal\"\n) -> RenderableType:\n    lines: list[str] = [f\"  [bold]{k}[/] : {v}\" for k, v in record.config.items()]\n    return Panel(\"\\n\".join(lines), title=\"Comparator Config\", border_style=\"cyan\")\n\n\n# ── ComparisonSkipRecord ─────────────────────────────────────────────\n\n\ndef _format_skip_body(record: ComparisonSkipRecord) -> str:\n    return f\"Skip: {record.name}{record._format_location_suffix()} ({record.reason})\"\n\n\ndef _format_skip_rich_body(\n    record: ComparisonSkipRecord, verbosity: Verbosity = \"normal\"\n) -> RenderableType:\n    suffix: str = record._format_location_suffix()\n    return (\n        f\"[dim]⊘ {escape(record.name)}{suffix} ── skipped ({escape(record.reason)})[/]\"\n    )\n\n\n# ── ComparisonErrorRecord ────────────────────────────────────────────\n\n\ndef _format_error_body(record: ComparisonErrorRecord) -> str:\n    prefix: str = record._format_location_prefix()\n    return (\n        f\"{prefix}Error: {record.name} ({record.exception_type})\\n\"\n        
f\"{record.traceback_str}\"\n    )\n\n\ndef _format_error_rich_body(\n    record: ComparisonErrorRecord, verbosity: Verbosity = \"normal\"\n) -> RenderableType:\n    prefix: str = record._format_location_prefix_rich()\n    name: str = escape(record.name)\n    header: str = (\n        f\"{prefix}[bold red]{name} ── errored ({escape(record.exception_type)})[/]\"\n    )\n    if verbosity == \"minimal\":\n        return header\n    return header + f\"\\n[dim]{escape(record.traceback_str)}[/]\"\n\n\n# ── _TableRecord ─────────────────────────────────────────────────────\n\n\ndef _format_table_body(record: _TableRecord) -> str:\n    import polars as pl\n\n    from sglang.srt.debug_utils.comparator.display import _render_polars_as_text\n\n    return _render_polars_as_text(\n        pl.DataFrame(record.rows), title=record._table_title()\n    )\n\n\ndef _format_table_rich_body(\n    record: _TableRecord, verbosity: Verbosity = \"normal\"\n) -> RenderableType:\n    import polars as pl\n\n    from sglang.srt.debug_utils.comparator.display import (\n        _render_polars_as_rich_table,\n    )\n\n    return _render_polars_as_rich_table(\n        pl.DataFrame(record.rows), title=record._table_title()\n    )\n\n\n# ── ComparisonTensorRecord ───────────────────────────────────────────\n\n\ndef _format_tensor_comparison_body(record: ComparisonTensorRecord) -> str:\n    body: str = record._format_location_prefix() + format_comparison(record)\n    if record.replicated_checks:\n        body += \"\\n\" + format_replicated_checks(record.replicated_checks)\n    if record.traced_plan is not None:\n        body += \"\\n\" + _format_aligner_plan(record.traced_plan)\n    return body\n\n\ndef _format_tensor_comparison_rich_body(\n    record: ComparisonTensorRecord, verbosity: Verbosity = \"normal\"\n) -> RenderableType:\n    from sglang.srt.debug_utils.comparator.tensor_comparator.formatter import (\n        format_comparison_rich,\n    )\n\n    return record._format_location_prefix_rich() + 
format_comparison_rich(\n        record=record, verbosity=verbosity\n    )\n\n\n# ── ComparisonNonTensorRecord ────────────────────────────────────────\n\n\ndef _format_non_tensor_body(record: ComparisonNonTensorRecord) -> str:\n    suffix: str = record._format_location_suffix()\n    if record.values_equal:\n        return f\"NonTensor: {record.name}{suffix} = {record.baseline_value} ({record.baseline_type}) [equal]\"\n    return (\n        f\"NonTensor: {record.name}{suffix}\\n\"\n        f\"  baseline = {record.baseline_value} ({record.baseline_type})\\n\"\n        f\"  target   = {record.target_value} ({record.target_type})\"\n    )\n\n\ndef _format_non_tensor_rich_body(\n    record: ComparisonNonTensorRecord, verbosity: Verbosity = \"normal\"\n) -> RenderableType:\n    suffix: str = record._format_location_suffix()\n    name: str = escape(record.name)\n    baseline_val: str = escape(record.baseline_value)\n    target_val: str = escape(record.target_value)\n\n    if record.values_equal:\n        return (\n            f\"═ {name}{suffix} = {baseline_val} \"\n            f\"({record.baseline_type}) [green]✓[/]\"\n        )\n    return (\n        f\"═ [bold red]{name}{suffix}[/]\\n\"\n        f\"  baseline = {baseline_val} ({record.baseline_type})\\n\"\n        f\"  target   = {target_val} ({record.target_type})\"\n    )\n\n\n# ── SummaryRecord ────────────────────────────────────────────────────\n\n\ndef _format_summary_body(record: SummaryRecord) -> str:\n    text: str = (\n        f\"Summary: {record.passed} passed, {record.failed} failed, \"\n        f\"{record.skipped} skipped (total {record.total})\"\n    )\n    if record.errored > 0:\n        text += f\", {record.errored} errored\"\n    return text\n\n\ndef _format_summary_rich_body(\n    record: SummaryRecord, verbosity: Verbosity = \"normal\"\n) -> RenderableType:\n    text: str = (\n        f\"[bold green]{record.passed} passed[/] │ \"\n        f\"[bold red]{record.failed} failed[/] │ \"\n        
f\"[yellow]{record.skipped} skipped[/] │ \"\n        f\"{record.total} total\"\n    )\n    if record.errored > 0:\n        text += f\" │ [bold red]{record.errored} errored[/]\"\n    return Panel(text, title=\"SUMMARY\", border_style=\"bold\")\n\n\n# ── LogRecord ────────────────────────────────────────────────────────\n\n\ndef _format_log_body(record: LogRecord) -> str:\n    return \"\"\n\n\n# ── Standalone helpers ───────────────────────────────────────────────\n\n\ndef _format_aligner_plan(traced_plan: TracedAlignerPlan) -> str:\n    lines: list[str] = [\"Aligner Plan:\"]\n\n    for side_label, traced_side in [\n        (\"baseline\", traced_plan.per_side.x),\n        (\"target\", traced_plan.per_side.y),\n    ]:\n        if not traced_side.step_plans:\n            lines.append(f\"  {side_label}: (no steps)\")\n            continue\n\n        step_summaries: list[str] = []\n        for traced_step in traced_side.step_plans:\n            sub_strs: list[str] = [\n                _format_sub_plan_text(traced_sub)\n                for traced_sub in traced_step.sub_plans\n            ]\n            summary: str = \", \".join(sub_strs) if sub_strs else \"passthrough\"\n            step_summaries.append(f\"step={traced_step.step}: {summary}\")\n        lines.append(f\"  {side_label}: [{'; '.join(step_summaries)}]\")\n\n    lines.extend(_format_cross_side_plan_text(traced_plan.plan))\n    return \"\\n\".join(lines)\n\n\ndef _format_sub_plan_text(traced_sub: TracedSubPlan) -> str:\n    sub_desc: str = f\"{traced_sub.plan.type}\"\n\n    if traced_sub.snapshot is not None:\n        snap = traced_sub.snapshot\n        in_count: int = len(snap.input_shapes)\n        out_count: int = len(snap.output_shapes)\n        in_shape: str = str(snap.input_shapes[0]) if snap.input_shapes else \"?\"\n        out_shape: str = str(snap.output_shapes[0]) if snap.output_shapes else \"?\"\n        sub_desc += f\" {in_count}x{in_shape} -> {out_count}x{out_shape}\"\n\n    return 
sub_desc\n\n\ndef _format_cross_side_plan_text(plan: AlignerPlan) -> list[str]:\n    lines: list[str] = []\n\n    if plan.token_aligner_plan is not None:\n        num_tokens: int = len(plan.token_aligner_plan.locators.x.steps)\n        lines.append(f\"  token_aligner: {num_tokens} tokens aligned\")\n\n    if plan.axis_aligner_plan is not None:\n        parts: list[str] = []\n        if plan.axis_aligner_plan.pattern.x:\n            parts.append(f\"x: {plan.axis_aligner_plan.pattern.x}\")\n        if plan.axis_aligner_plan.pattern.y:\n            parts.append(f\"y: {plan.axis_aligner_plan.pattern.y}\")\n        lines.append(f\"  axis_aligner: {', '.join(parts)}\")\n\n    return lines\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/output_types.py",
    "content": "from __future__ import annotations\n\nfrom abc import abstractmethod\nfrom typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union\n\nfrom pydantic import ConfigDict, Discriminator, Field, TypeAdapter, model_validator\nfrom rich.console import RenderableType\nfrom rich.markup import escape\n\nfrom sglang.srt.debug_utils.comparator.output_formatter import (  # noqa: F401 — re-export\n    _format_aligner_plan as _format_aligner_plan,\n)\nfrom sglang.srt.debug_utils.comparator.output_formatter import (\n    _format_config_body,\n    _format_config_rich_body,\n    _format_error_body,\n    _format_error_rich_body,\n    _format_log_body,\n    _format_non_tensor_body,\n    _format_non_tensor_rich_body,\n    _format_skip_body,\n    _format_skip_rich_body,\n    _format_summary_body,\n    _format_summary_rich_body,\n    _format_table_body,\n    _format_table_rich_body,\n    _format_tensor_comparison_body,\n    _format_tensor_comparison_rich_body,\n    _render_record_rich,\n    _render_record_text,\n)\nfrom sglang.srt.debug_utils.comparator.tensor_comparator.types import (\n    DiffInfo,\n    TensorComparisonInfo,\n)\nfrom sglang.srt.debug_utils.comparator.utils import Pair, _StrictBase\n\nif TYPE_CHECKING:\n    from sglang.srt.debug_utils.comparator.aligner.entrypoint.traced_types import (\n        TracedAlignerPlan,\n    )\n    from sglang.srt.debug_utils.comparator.report_sink import Verbosity\n\n\nclass BaseLog(_StrictBase):\n    category: str\n    message: str\n\n    def to_text(self) -> str:\n        return self.message\n\n\nclass ErrorLog(BaseLog):\n    kind: Literal[\"error\"] = \"error\"\n\n\nclass InfoLog(BaseLog):\n    kind: Literal[\"info\"] = \"info\"\n\n\nAnyLog = Annotated[Union[ErrorLog, InfoLog], Discriminator(\"kind\")]\n\n\ndef _split_logs(logs: list[BaseLog]) -> tuple[list[ErrorLog], list[InfoLog]]:\n    errors: list[ErrorLog] = [log for log in logs if isinstance(log, ErrorLog)]\n    infos: list[InfoLog] = [log for log in logs 
if isinstance(log, InfoLog)]\n    return errors, infos\n\n\nclass ReplicatedCheckResult(_StrictBase):\n    axis: str\n    group_index: int\n    compared_index: int\n    baseline_index: int\n    passed: bool\n    atol: float\n    diff: Optional[DiffInfo] = None\n\n\nclass BundleFileInfo(_StrictBase):\n    \"\"\"Per-file info within a bundle (one rank's raw tensor).\"\"\"\n\n    shape: list[int]\n    dtype: str\n    rank: Optional[int] = None\n    parallel_info: Optional[dict[str, str]] = None  # e.g. {\"tp\": \"0/4\", \"ep\": \"1/2\"}\n\n\nclass BundleSideInfo(_StrictBase):\n    num_files: int\n    files: list[BundleFileInfo]\n    dims: Optional[str] = None  # e.g. \"b s h(tp) d\"\n\n\nclass ShapeSnapshot(_StrictBase):\n    input_shapes: list[list[int]]\n    output_shapes: list[list[int]]\n\n\nclass _OutputRecord(_StrictBase):\n    errors: list[ErrorLog] = Field(default_factory=list)\n    infos: list[InfoLog] = Field(default_factory=list)\n\n    @abstractmethod\n    def _format_body(self) -> str: ...\n\n    def _format_rich_body(self, verbosity: Verbosity = \"normal\") -> RenderableType:\n        return self._format_body()\n\n    def to_rich(self, verbosity: Verbosity = \"normal\") -> RenderableType:\n        return _render_record_rich(self, verbosity=verbosity)\n\n    def to_text(self) -> str:\n        return _render_record_text(self)\n\n\nclass RecordLocation(_StrictBase):\n    step: Optional[int] = None\n\n\nclass _BaseComparisonRecord(_OutputRecord):\n    location: RecordLocation = Field(default_factory=RecordLocation)\n\n    def _format_location_prefix(self) -> str:\n        if self.location.step is not None:\n            return f\"[step={self.location.step}] \"\n        return \"\"\n\n    def _format_location_prefix_rich(self) -> str:\n        if self.location.step is not None:\n            return escape(f\"[step={self.location.step}]\") + \" \"\n        return \"\"\n\n    def _format_location_suffix(self) -> str:\n        if self.location.step is not None:\n  
          return f\" (step={self.location.step})\"\n        return \"\"\n\n\nclass ConfigRecord(_OutputRecord):\n    type: Literal[\"config\"] = \"config\"\n    config: dict[str, Any]\n\n    def _format_body(self) -> str:\n        return _format_config_body(self)\n\n    def _format_rich_body(self, verbosity: Verbosity = \"normal\") -> RenderableType:\n        return _format_config_rich_body(self, verbosity=verbosity)\n\n\nclass ComparisonSkipRecord(_BaseComparisonRecord):\n    type: Literal[\"comparison_skip\"] = \"comparison_skip\"\n    name: str\n    reason: str\n\n    @property\n    def category(self) -> str:\n        if self.errors:\n            return \"failed\"\n        return \"skipped\"\n\n    def _format_body(self) -> str:\n        return _format_skip_body(self)\n\n    def _format_rich_body(self, verbosity: Verbosity = \"normal\") -> RenderableType:\n        return _format_skip_rich_body(self, verbosity=verbosity)\n\n\nclass ComparisonErrorRecord(_BaseComparisonRecord):\n    type: Literal[\"comparison_error\"] = \"comparison_error\"\n    name: str\n    exception_type: str\n    traceback_str: str\n\n    @property\n    def category(self) -> str:\n        return \"errored\"\n\n    def _format_body(self) -> str:\n        return _format_error_body(self)\n\n    def _format_rich_body(self, verbosity: Verbosity = \"normal\") -> RenderableType:\n        return _format_error_rich_body(self, verbosity=verbosity)\n\n\nclass _TableRecord(_OutputRecord):\n    label: str\n    rows: list[dict[str, Any]]\n\n    @abstractmethod\n    def _table_title(self) -> str: ...\n\n    def _format_body(self) -> str:\n        return _format_table_body(self)\n\n    def _format_rich_body(self, verbosity: Verbosity = \"normal\") -> RenderableType:\n        return _format_table_rich_body(self, verbosity=verbosity)\n\n\nclass RankInfoRecord(_TableRecord):\n    type: Literal[\"rank_info\"] = \"rank_info\"\n\n    def _table_title(self) -> str:\n        return f\"{self.label} ranks\"\n\n\nclass 
InputIdsRecord(_TableRecord):\n    type: Literal[\"input_ids\"] = \"input_ids\"\n\n    def _table_title(self) -> str:\n        return f\"{self.label} input_ids & positions\"\n\n\nclass ComparisonTensorRecord(TensorComparisonInfo, _BaseComparisonRecord):\n    model_config = ConfigDict(extra=\"forbid\", defer_build=True)\n\n    type: Literal[\"comparison_tensor\"] = \"comparison_tensor\"\n    traced_plan: Optional[TracedAlignerPlan] = None\n    replicated_checks: list[ReplicatedCheckResult] = Field(default_factory=list)\n    raw_bundle_info: Optional[Pair[BundleSideInfo]] = None\n\n    @property\n    def category(self) -> str:\n        if self.errors:\n            return \"failed\"\n        if any(not check.passed for check in self.replicated_checks):\n            return \"failed\"\n        return \"passed\" if self.diff is not None and self.diff.passed else \"failed\"\n\n    def _format_body(self) -> str:\n        return _format_tensor_comparison_body(self)\n\n    def _format_rich_body(self, verbosity: Verbosity = \"normal\") -> RenderableType:\n        return _format_tensor_comparison_rich_body(self, verbosity=verbosity)\n\n\nclass ComparisonNonTensorRecord(_BaseComparisonRecord):\n    type: Literal[\"comparison_non_tensor\"] = \"comparison_non_tensor\"\n    name: str\n    baseline_value: str\n    target_value: str\n    baseline_type: str\n    target_type: str\n    values_equal: bool\n\n    @property\n    def category(self) -> str:\n        if self.errors:\n            return \"failed\"\n        return \"passed\" if self.values_equal else \"failed\"\n\n    def _format_body(self) -> str:\n        return _format_non_tensor_body(self)\n\n    def _format_rich_body(self, verbosity: Verbosity = \"normal\") -> RenderableType:\n        return _format_non_tensor_rich_body(self, verbosity=verbosity)\n\n\nclass SummaryRecord(_OutputRecord):\n    type: Literal[\"summary\"] = \"summary\"\n    total: int\n    passed: int\n    failed: int\n    skipped: int\n    errored: int = 
0\n\n    @model_validator(mode=\"after\")\n    def _validate_totals(self) -> \"SummaryRecord\":\n        expected: int = self.passed + self.failed + self.skipped + self.errored\n        if self.total != expected:\n            raise ValueError(\n                f\"total={self.total} != passed({self.passed}) + failed({self.failed}) \"\n                f\"+ skipped({self.skipped}) + errored({self.errored}) = {expected}\"\n            )\n        return self\n\n    def _format_body(self) -> str:\n        return _format_summary_body(self)\n\n    def _format_rich_body(self, verbosity: Verbosity = \"normal\") -> RenderableType:\n        return _format_summary_rich_body(self, verbosity=verbosity)\n\n\nclass LogRecord(_OutputRecord):\n    type: Literal[\"log\"] = \"log\"\n\n    def _format_body(self) -> str:\n        return _format_log_body(self)\n\n\nAnyRecord = Annotated[\n    Union[\n        ConfigRecord,\n        RankInfoRecord,\n        InputIdsRecord,\n        ComparisonSkipRecord,\n        ComparisonErrorRecord,\n        ComparisonTensorRecord,\n        ComparisonNonTensorRecord,\n        SummaryRecord,\n        LogRecord,\n    ],\n    Discriminator(\"type\"),\n]\n\n\ndef _get_any_record_adapter() -> TypeAdapter:\n    return TypeAdapter(AnyRecord)\n\n\ndef parse_record_json(json_str: str | bytes) -> AnyRecord:\n    return _get_any_record_adapter().validate_json(json_str)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/per_token_visualizer.py",
    "content": "\"\"\"Per-token relative difference heatmap generator.\n\nProduces a single PNG with rows = tensor names, columns = token positions,\ncolor = log10(rel_diff).\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom pathlib import Path\nfrom typing import Optional\n\nfrom sglang.srt.debug_utils.comparator.output_types import ComparisonTensorRecord\n\n\ndef generate_per_token_heatmap(\n    *,\n    records: list[ComparisonTensorRecord],\n    output_path: Path,\n) -> Optional[Path]:\n    \"\"\"Generate a per-token relative difference heatmap PNG.\n\n    Returns the output path if a file was written, or None if no data was available.\n    \"\"\"\n    rows_data: list[tuple[str, list[float]]] = _collect_per_token_data(records=records)\n    if not rows_data:\n        return None\n\n    _render_heatmap(rows_data=rows_data, output_path=output_path)\n    return output_path\n\n\ndef _collect_per_token_data(\n    *,\n    records: list[ComparisonTensorRecord],\n) -> list[tuple[str, list[float]]]:\n    rows: list[tuple[str, list[float]]] = []\n    for record in records:\n        if record.diff is None or record.diff.per_token_rel_diff is None:\n            continue\n        rows.append((record.name, record.diff.per_token_rel_diff))\n    return rows\n\n\ndef _render_heatmap(\n    *,\n    rows_data: list[tuple[str, list[float]]],\n    output_path: Path,\n) -> None:\n    import matplotlib\n    import numpy as np\n\n    matplotlib.use(\"Agg\")\n    import matplotlib.pyplot as plt\n\n    max_len: int = max(len(vals) for _, vals in rows_data)\n    labels: list[str] = [label for label, _ in rows_data]\n\n    matrix: np.ndarray = np.full((len(rows_data), max_len), np.nan, dtype=np.float64)\n    for i, (_, vals) in enumerate(rows_data):\n        matrix[i, : len(vals)] = vals\n\n    fig_width: float = max(12.0, max_len * 0.15)\n    fig_height: float = max(6.0, len(rows_data) * 0.3)\n    fig, ax = plt.subplots(figsize=(fig_width, fig_height))\n\n    im = ax.imshow(\n        
np.log10(matrix + 1e-10), aspect=\"auto\", cmap=\"hot\", interpolation=\"nearest\"\n    )\n\n    ax.set_xlabel(\"Token Position\")\n    ax.set_ylabel(\"Tensor\")\n    ax.set_yticks(range(len(labels)))\n    ax.set_yticklabels(labels, fontsize=8)\n\n    colorbar = fig.colorbar(im, ax=ax)\n    colorbar.set_label(\"log10(rel_diff)\")\n\n    ax.set_title(\"Per-Token Relative Difference Heatmap\")\n    fig.tight_layout()\n\n    output_path.parent.mkdir(parents=True, exist_ok=True)\n    fig.savefig(str(output_path), dpi=150)\n    plt.close(fig)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/preset.py",
    "content": "from __future__ import annotations\n\nPRESETS: dict[str, list[str]] = {\n    \"raw\": [\n        \"--grouping-skip-keys\",\n    ],\n    \"sglang_dev\": [\n        \"--grouping-skip-keys\",\n        \"rank\",\n    ],\n    \"sglang_megatron\": [\n        \"--grouping-skip-keys\",\n        \"rank\",\n        \"step\",\n        \"--token-aligner\",\n        \"concat_steps\",\n    ],\n}\n\nDEFAULT_PRESET: str = \"sglang_dev\"\n\n\ndef expand_preset(argv: list[str], presets: dict[str, list[str]]) -> list[str]:\n    \"\"\"Expand ``--preset <name>`` into the corresponding argv fragment.\n\n    If ``--preset`` is absent **and** ``--grouping-skip-keys`` is also absent,\n    the DEFAULT_PRESET is applied automatically.\n    \"\"\"\n    if (expanded := _expand_flag(argv, \"--preset\", presets)) is not None:\n        return expanded\n\n    if \"--grouping-skip-keys\" not in argv:\n        return presets[DEFAULT_PRESET] + argv\n\n    return argv\n\n\ndef _expand_flag(\n    argv: list[str], flag: str, mapping: dict[str, list[str]]\n) -> list[str] | None:\n    \"\"\"Replace ``flag <name>`` in *argv* with the corresponding argv fragment from *mapping*.\"\"\"\n    if flag not in argv:\n        return None\n\n    idx: int = argv.index(flag)\n    name: str = argv[idx + 1]\n    if name not in mapping:\n        raise ValueError(\n            f\"Unknown value for {flag}: {name}. Available: {list(mapping.keys())}\"\n        )\n\n    return argv[:idx] + mapping[name] + argv[idx + 2 :]\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/report_sink.py",
    "content": "from __future__ import annotations\n\nimport sys\nfrom pathlib import Path\nfrom typing import IO, Literal, Optional\n\nfrom rich.console import Console\n\nfrom sglang.srt.debug_utils.comparator.output_types import _OutputRecord\n\nVerbosity = Literal[\"minimal\", \"normal\", \"verbose\"]\n\n\nclass ReportSink:\n    \"\"\"Unified entry point for all record output.\"\"\"\n\n    def __init__(self) -> None:\n        self._output_format: str = \"text\"\n        self._verbosity: Verbosity = \"normal\"\n        self._report_file: Optional[IO[str]] = None\n        self._report_path: Optional[Path] = None\n        self._console: Optional[Console] = None\n\n    @property\n    def verbosity(self) -> Verbosity:\n        return self._verbosity\n\n    def configure(\n        self,\n        *,\n        output_format: str = \"text\",\n        report_path: Optional[Path] = None,\n        verbosity: Verbosity = \"normal\",\n    ) -> None:\n        self._output_format = output_format\n        self._verbosity = verbosity\n\n        if report_path is not None:\n            try:\n                report_path.parent.mkdir(parents=True, exist_ok=True)\n                self._report_file = open(report_path, \"w\", encoding=\"utf-8\")\n                self._report_path = report_path\n            except OSError as exc:\n                print(\n                    f\"Warning: cannot open report file {report_path}: {exc}\",\n                    file=sys.stderr,\n                )\n\n    def add(self, record: _OutputRecord) -> None:\n        self._print_to_stdout(record)\n\n        if self._report_file is not None:\n            self._report_file.write(record.model_dump_json())\n            self._report_file.write(\"\\n\")\n            self._report_file.flush()\n\n    def close(self) -> None:\n        if self._report_file is not None:\n            self._report_file.close()\n            self._report_file = None\n\n    @property\n    def report_path(self) -> Optional[Path]:\n        
return self._report_path\n\n    def _reset(self) -> None:\n        self.close()\n        self._output_format = \"text\"\n        self._verbosity = \"normal\"\n        self._report_path = None\n        self._console = None\n\n    def _get_console(self) -> Console:\n        if self._console is None:\n            self._console = Console()\n        return self._console\n\n    def _print_to_stdout(self, record: _OutputRecord) -> None:\n        if self._output_format == \"json\":\n            print(record.model_dump_json())\n        else:\n            console: Console = self._get_console()\n            console.print(record.to_rich(verbosity=self._verbosity))\n            console.print()  # blank line between records\n\n\nreport_sink = ReportSink()\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/tensor_comparator/__init__.py",
    "content": "from sglang.srt.debug_utils.comparator.tensor_comparator.comparator import (\n    compare_tensor_pair,\n)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/tensor_comparator/comparator.py",
    "content": "from typing import Optional\n\nimport torch\n\nfrom sglang.srt.debug_utils.comparator.tensor_comparator.types import (\n    DEFAULT_PERCENTILES,\n    DiffInfo,\n    TensorComparisonInfo,\n    TensorInfo,\n    TensorStats,\n)\nfrom sglang.srt.debug_utils.comparator.utils import (\n    Pair,\n    argmax_coord,\n    calc_per_token_rel_diff,\n    calc_rel_diff,\n    compute_smaller_dtype,\n    try_unify_shape,\n)\nfrom sglang.srt.debug_utils.dumper import get_truncated_value\n\nQUANTILE_NUMEL_THRESHOLD = 10_000_000\nSAMPLE_DIFF_THRESHOLD = 1e-3\n\n\ndef compare_tensor_pair(\n    x_baseline: torch.Tensor,\n    x_target: torch.Tensor,\n    name: str = \"\",\n    diff_threshold: float = 1e-3,\n    seq_dim: Optional[int] = None,\n) -> TensorComparisonInfo:\n    baseline_info = TensorInfo(\n        shape=list(x_baseline.shape),\n        dtype=str(x_baseline.dtype),\n        stats=_compute_tensor_stats(x_baseline.float()),\n    )\n    target_info = TensorInfo(\n        shape=list(x_target.shape),\n        dtype=str(x_target.dtype),\n        stats=_compute_tensor_stats(x_target.float()),\n    )\n\n    x_baseline = try_unify_shape(x_baseline, target_shape=x_target.shape)\n    unified_shape = list(x_baseline.shape)\n\n    baseline_original_dtype = x_baseline.dtype\n    target_original_dtype = x_target.dtype\n\n    x_baseline_f = x_baseline.float()\n    x_target_f = x_target.float()\n\n    shape_mismatch = x_baseline_f.shape != x_target_f.shape\n\n    diff: Optional[DiffInfo] = None\n    diff_downcast: Optional[DiffInfo] = None\n    downcast_dtype: Optional[torch.dtype] = None\n\n    if not shape_mismatch:\n        diff = compute_diff(\n            x_baseline=x_baseline_f,\n            x_target=x_target_f,\n            diff_threshold=diff_threshold,\n            seq_dim=seq_dim,\n        )\n\n        needs_sample = diff.max_abs_diff > SAMPLE_DIFF_THRESHOLD\n        if needs_sample:\n            baseline_info.sample = str(get_truncated_value(x_baseline_f))\n       
     target_info.sample = str(get_truncated_value(x_target_f))\n\n        if baseline_original_dtype != target_original_dtype:\n            downcast_dtype = compute_smaller_dtype(\n                Pair(x=baseline_original_dtype, y=target_original_dtype)\n            )\n            if downcast_dtype is not None:\n                diff_downcast = compute_diff(\n                    x_baseline=x_baseline_f.to(downcast_dtype),\n                    x_target=x_target_f.to(downcast_dtype),\n                    diff_threshold=diff_threshold,\n                )\n\n    return TensorComparisonInfo(\n        name=name,\n        baseline=baseline_info,\n        target=target_info,\n        unified_shape=unified_shape,\n        shape_mismatch=shape_mismatch,\n        diff=diff,\n        diff_downcast=diff_downcast,\n        downcast_dtype=str(downcast_dtype) if downcast_dtype is not None else None,\n    )\n\n\ndef _compute_tensor_stats(x: torch.Tensor) -> TensorStats:\n    if x.numel() == 0:\n        return TensorStats(\n            mean=0.0,\n            abs_mean=0.0,\n            std=0.0,\n            min=0.0,\n            max=0.0,\n            percentiles={},\n        )\n\n    include_quantiles: bool = x.numel() < QUANTILE_NUMEL_THRESHOLD\n    return TensorStats(\n        mean=torch.mean(x).item(),\n        abs_mean=torch.mean(x.abs()).item(),\n        std=torch.std(x).item(),\n        min=torch.min(x).item(),\n        max=torch.max(x).item(),\n        percentiles=_compute_percentiles(x, include=include_quantiles),\n    )\n\n\ndef _compute_percentiles(x: torch.Tensor, *, include: bool) -> dict[int, float]:\n    if not include:\n        return {}\n    x_float: torch.Tensor = x.float()\n    return {p: torch.quantile(x_float, p / 100.0).item() for p in DEFAULT_PERCENTILES}\n\n\ndef compute_diff(\n    x_baseline: torch.Tensor,\n    x_target: torch.Tensor,\n    diff_threshold: float = 1e-3,\n    seq_dim: Optional[int] = None,\n) -> DiffInfo:\n    if x_baseline.numel() == 0:\n        
return DiffInfo(\n            rel_diff=0.0,\n            max_abs_diff=0.0,\n            mean_abs_diff=0.0,\n            abs_diff_percentiles={},\n            max_diff_coord=[],\n            baseline_at_max=0.0,\n            target_at_max=0.0,\n            diff_threshold=diff_threshold,\n            passed=True,\n        )\n\n    raw_abs_diff = (x_target - x_baseline).abs()\n    max_diff_coord = argmax_coord(raw_abs_diff)\n\n    rel_diff = calc_rel_diff(x_target, x_baseline).item()\n    max_abs_diff = raw_abs_diff.max().item()\n    mean_abs_diff = raw_abs_diff.mean().item()\n\n    include_quantiles: bool = raw_abs_diff.numel() < QUANTILE_NUMEL_THRESHOLD\n\n    per_token_rel_diff: Optional[list[float]] = None\n    if seq_dim is not None and x_baseline.dim() > seq_dim:\n        per_token_rel_diff = calc_per_token_rel_diff(\n            x_baseline, x_target, seq_dim=seq_dim\n        ).tolist()\n\n    return DiffInfo(\n        rel_diff=rel_diff,\n        max_abs_diff=max_abs_diff,\n        mean_abs_diff=mean_abs_diff,\n        abs_diff_percentiles=_compute_percentiles(\n            raw_abs_diff, include=include_quantiles\n        ),\n        max_diff_coord=list(max_diff_coord),\n        baseline_at_max=x_baseline[max_diff_coord].item(),\n        target_at_max=x_target[max_diff_coord].item(),\n        diff_threshold=diff_threshold,\n        passed=rel_diff <= diff_threshold,\n        per_token_rel_diff=per_token_rel_diff,\n    )\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/tensor_comparator/formatter.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Literal, Optional\n\nfrom rich.markup import escape\n\nfrom sglang.srt.debug_utils.comparator.aligner.unsharder.types import UnsharderPlan\nfrom sglang.srt.debug_utils.comparator.tensor_comparator.types import (\n    DiffInfo,\n    TensorComparisonInfo,\n    TensorInfo,\n    TensorStats,\n)\n\nif TYPE_CHECKING:\n    from sglang.srt.debug_utils.comparator.aligner.entrypoint.traced_types import (\n        TracedAlignerPlan,\n        TracedSubPlan,\n    )\n    from sglang.srt.debug_utils.comparator.aligner.entrypoint.types import AlignerPlan\n    from sglang.srt.debug_utils.comparator.output_types import (\n        BundleSideInfo,\n        ComparisonTensorRecord,\n        ReplicatedCheckResult,\n        ShapeSnapshot,\n    )\n    from sglang.srt.debug_utils.comparator.utils import Pair\n\nVerbosity = Literal[\"minimal\", \"normal\", \"verbose\"]\n\n\ndef _esc_shape(shape: Optional[list[int]]) -> str:\n    return escape(str(shape))\n\n\ndef _strip_torch_prefix(dtype: str) -> str:\n    return dtype.replace(\"torch.\", \"\")\n\n\n# ---------------------------------------------------------------------------\n# Number formatting\n# ---------------------------------------------------------------------------\n\n\ndef _fmt_val(value: float) -> str:\n    return f\"{value:.2e}\"\n\n\ndef _fmt_diff_colored(diff: float, *, threshold: float = 1e-2) -> str:\n    formatted: str = f\"{diff:+.2e}\"\n    if abs(diff) >= threshold:\n        return f\"[yellow]{formatted}[/]\"\n    return f\"[dim]{formatted}[/]\"\n\n\n# ---------------------------------------------------------------------------\n# Passed / color / marker helper\n# ---------------------------------------------------------------------------\n\n\ndef _category_marker(category: str) -> tuple[bool, str, str]:\n    passed: bool = category == \"passed\"\n    color: str = \"green\" if passed else \"red\"\n    marker: str = f\"[{color}]✅[/]\" 
if passed else f\"[{color}]❌[/]\"\n    return passed, color, marker\n\n\n# ---------------------------------------------------------------------------\n# Stats formatting helpers (shared between compact / verbose)\n# ---------------------------------------------------------------------------\n\n\ndef _format_stat_line(stat_name: str, val_b: float, val_t: float, diff: float) -> str:\n    return (\n        f\"      [blue]{stat_name:10s}[/] {val_b:>10.4f} vs {val_t:>10.4f}\"\n        f\"  Δ {_fmt_diff_colored(diff)}\"\n    )\n\n\n# ---------------------------------------------------------------------------\n# Old text-only formatters (kept for to_text() backward compatibility)\n# ---------------------------------------------------------------------------\n\n\ndef format_comparison(info: TensorComparisonInfo) -> str:\n    lines: list[str] = []\n    baseline = info.baseline\n    target = info.target\n\n    dtype_marker = \"\" if baseline.dtype == target.dtype else \"🟠\"\n    lines.append(\n        f\"Raw \"\n        f\"[shape] {baseline.shape} vs {target.shape}\\t\"\n        f\"[{dtype_marker}dtype] {baseline.dtype} vs {target.dtype}\"\n    )\n\n    if info.unified_shape != baseline.shape:\n        lines.append(\n            f\"Unify shape: {baseline.shape} -> {info.unified_shape} \"\n            f\"(to match {target.shape})\"\n        )\n\n    lines.append(\n        f\"After unify \"\n        f\"[shape] {info.unified_shape} vs {target.shape}\\t\"\n        f\"[dtype] {baseline.dtype} vs {target.dtype}\"\n    )\n\n    lines.extend(_format_stats_comparison(baseline=baseline.stats, target=target.stats))\n\n    if info.shape_mismatch:\n        lines.append(\"⚠️ Shape mismatch\")\n        return \"\\n\".join(lines)\n\n    if info.diff is not None:\n        lines.extend(_format_diff(diff=info.diff))\n\n    if info.diff_downcast is not None and info.downcast_dtype is not None:\n        lines.extend(\n            _format_diff(\n                diff=info.diff_downcast,\n         
       prefix_text=f\"When downcast to {info.downcast_dtype}: \",\n            )\n        )\n\n    if baseline.sample is not None:\n        lines.append(f\"x_baseline(sample)={baseline.sample}\")\n    if target.sample is not None:\n        lines.append(f\"x_target(sample)={target.sample}\")\n\n    return \"\\n\".join(lines)\n\n\ndef format_replicated_checks(checks: list[ReplicatedCheckResult]) -> str:\n    lines: list[str] = [\"Replicated checks:\"]\n\n    for check in checks:\n        marker: str = \"✅\" if check.passed else \"❌\"\n\n        if check.diff is not None:\n            detail: str = (\n                f\"rel_diff={check.diff.rel_diff:.6e} \"\n                f\"max_abs_diff={check.diff.max_abs_diff:.6e} \"\n                f\"mean_abs_diff={check.diff.mean_abs_diff:.6e}\"\n            )\n        else:\n            detail = \"n/a diff\"\n\n        lines.append(\n            f\"  {marker} axis={check.axis} group={check.group_index} \"\n            f\"idx={check.compared_index} vs {check.baseline_index}: \"\n            f\"{detail}\"\n        )\n\n    return \"\\n\".join(lines)\n\n\ndef _format_stats_comparison(baseline: TensorStats, target: TensorStats) -> list[str]:\n    lines: list[str] = []\n\n    for stat_name in TensorStats.model_fields:\n        if stat_name == \"percentiles\":\n            continue\n        value_baseline: float = getattr(baseline, stat_name)\n        value_target: float = getattr(target, stat_name)\n        lines.append(\n            f\"[{stat_name}] {value_baseline:.4f} vs {value_target:.4f} \"\n            f\"(diff: {value_target - value_baseline:.4f})\"\n        )\n\n    for p in sorted(set(baseline.percentiles) & set(target.percentiles)):\n        value_baseline = baseline.percentiles[p]\n        value_target = target.percentiles[p]\n        lines.append(\n            f\"[p{p}] {value_baseline:.4f} vs {value_target:.4f} \"\n            f\"(diff: {value_target - value_baseline:.4f})\"\n        )\n\n    return lines\n\n\ndef 
_format_diff(diff: DiffInfo, prefix_text: str = \"\") -> list[str]:\n    rel_diff_marker: str = \"❌\" if diff.rel_diff > diff.diff_threshold else \"✅\"\n    lines: list[str] = [\n        prefix_text\n        + f\"{rel_diff_marker} rel_diff={diff.rel_diff}\\t\"\n        + f\"max_abs_diff={diff.max_abs_diff}\\t\"\n        + f\"mean_abs_diff={diff.mean_abs_diff}\",\n        f\"max_abs_diff happens at coord={diff.max_diff_coord} with \"\n        f\"baseline={diff.baseline_at_max} \"\n        f\"target={diff.target_at_max}\",\n    ]\n\n    if diff.abs_diff_percentiles:\n        quantile_parts: list[str] = [\n            f\"p{p}={value:.4f}\"\n            for p, value in sorted(diff.abs_diff_percentiles.items())\n        ]\n        lines.append(\"[abs_diff] \" + \" \".join(quantile_parts))\n\n    return lines\n\n\n# ---------------------------------------------------------------------------\n# New Rich markup formatters\n# ---------------------------------------------------------------------------\n\n\ndef format_comparison_rich(\n    record: ComparisonTensorRecord,\n    verbosity: Verbosity = \"normal\",\n) -> str:\n    if verbosity == \"minimal\":\n        return _format_comparison_minimal(record)\n\n    return _format_comparison_normal_or_verbose(\n        record=record,\n        verbose=(verbosity == \"verbose\"),\n    )\n\n\ndef _format_comparison_minimal(record: ComparisonTensorRecord) -> str:\n    passed, color, marker = _category_marker(record.category)\n\n    name_part: str = f\"[bold {color}]{escape(record.name):30s}[/]\"\n    if record.diff is not None:\n        return f\"{marker} {name_part} rel_diff={_fmt_val(record.diff.rel_diff)}\"\n    elif record.shape_mismatch:\n        return f\"{marker} {name_part} [yellow]shape mismatch[/]\"\n    else:\n        return f\"{marker} {name_part}\"\n\n\ndef _format_comparison_normal_or_verbose(\n    *,\n    record: ComparisonTensorRecord,\n    verbose: bool,\n) -> str:\n    passed, color, marker = 
_category_marker(record.category)\n\n    baseline: TensorInfo = record.baseline\n    target: TensorInfo = record.target\n    aligned_shape: str = _esc_shape(record.unified_shape)\n    dtype_str: str = _strip_torch_prefix(baseline.dtype)\n\n    lines: list[str] = []\n\n    # L0: Header\n    lines.append(\n        f\"{marker} [bold {color}]{escape(record.name)}[/] \"\n        f\"[dim cyan]── {dtype_str}  {aligned_shape}[/]\"\n    )\n\n    # L1: Key metrics\n    if record.diff is not None:\n        diff: DiffInfo = record.diff\n        rel_style: str = f\"bold {color}\" if not passed else color\n        lines.append(\n            f\"   [{rel_style}]rel_diff={_fmt_val(diff.rel_diff)}[/]\"\n            f\"  max_abs={_fmt_val(diff.max_abs_diff)}\"\n            f\"  mean_abs={_fmt_val(diff.mean_abs_diff)}\"\n        )\n\n        if not passed:\n            lines.append(\n                f\"   max_abs @ {_esc_shape(diff.max_diff_coord)}: \"\n                f\"baseline={diff.baseline_at_max}  target={diff.target_at_max}\"\n            )\n    elif record.shape_mismatch:\n        lines.append(\"   [yellow]⚠ Shape mismatch[/]\")\n\n    # Downcast info\n    if record.diff_downcast is not None and record.downcast_dtype is not None:\n        dc: DiffInfo = record.diff_downcast\n        dc_marker: str = \"[green]✅[/]\" if dc.passed else \"[red]❌[/]\"\n        lines.append(\n            f\"   {dc_marker} downcast to {record.downcast_dtype}: \"\n            f\"rel_diff={_fmt_val(dc.rel_diff)}\"\n        )\n\n    # Bundle section\n    if record.raw_bundle_info is not None:\n        lines.append(\"   [dim]Bundle[/]\")\n        lines.extend(\n            _format_bundle_section(bundle_info=record.raw_bundle_info, verbose=verbose)\n        )\n\n    # Plan section\n    if record.traced_plan is not None:\n        lines.append(\"   [dim]Plan[/]\")\n        lines.extend(\n            _format_plan_section_rich(\n                traced_plan=record.traced_plan,\n                
verbose=verbose,\n            )\n        )\n\n    # Aligned section\n    lines.append(\"   [dim]Aligned[/]\")\n    lines.append(\n        f\"      {_esc_shape(record.unified_shape)} vs {_esc_shape(target.shape)}\"\n        f\"   {baseline.dtype} vs {target.dtype}\"\n    )\n\n    # Stats section\n    lines.append(\"   [dim]Stats[/]\")\n    lines.extend(\n        _format_stats_rich(\n            baseline=baseline.stats, target=target.stats, verbose=verbose\n        )\n    )\n\n    show_detail: bool = verbose or not passed\n\n    # Abs diff percentiles\n    if show_detail and record.diff is not None and record.diff.abs_diff_percentiles:\n        lines.append(\"   [dim]Abs Diff Percentiles[/]\")\n        lines.append(\"      \" + _format_abs_diff_percentiles_rich(record.diff))\n\n    # Samples\n    if show_detail and baseline.sample is not None:\n        lines.append(\"   [dim]Samples[/]\")\n        lines.append(f\"      baseline  {escape(baseline.sample)}\")\n        if target.sample is not None:\n            lines.append(f\"      target    {escape(target.sample)}\")\n\n    # Replicated checks\n    if show_detail and record.replicated_checks:\n        lines.append(\"   [dim]Replicated Checks[/]\")\n        for check in record.replicated_checks:\n            chk_marker: str = \"[green]✅[/]\" if check.passed else \"[red]❌[/]\"\n            if check.diff is not None:\n                lines.append(\n                    f\"      {chk_marker} axis={check.axis}  group={check.group_index}\"\n                    f\"  idx={check.compared_index} vs {check.baseline_index}\"\n                    f\"  rel_diff={_fmt_val(check.diff.rel_diff)}\"\n                    f\"  max_abs={_fmt_val(check.diff.max_abs_diff)}\"\n                )\n            else:\n                lines.append(\n                    f\"      {chk_marker} axis={check.axis}  group={check.group_index}\"\n                    f\"  idx={check.compared_index} vs {check.baseline_index}: n/a\"\n                )\n\n    
return \"\\n\".join(lines)\n\n\ndef _format_bundle_section(\n    bundle_info: Pair[BundleSideInfo], *, verbose: bool = False\n) -> list[str]:\n    lines: list[str] = []\n\n    for label, side in [(\"baseline\", bundle_info.x), (\"target\", bundle_info.y)]:\n        if not side.files:\n            lines.append(f\"      {label}  [dim](no files)[/]\")\n            continue\n\n        dtype_desc: str = _strip_torch_prefix(side.files[0].dtype)\n\n        if verbose:\n            dims_part: str = f\"  dims: {side.dims}\" if side.dims else \"\"\n            lines.append(\n                f\"      {label}  [cyan]{side.num_files} files[/]\"\n                f\" {dtype_desc}{dims_part}\"\n            )\n\n            for idx, f in enumerate(side.files):\n                rank_part: str = f\"rank={f.rank}\" if f.rank is not None else \"\"\n                par_part: str = \"\"\n                if f.parallel_info:\n                    par_part = \" \" + \" \".join(\n                        f\"{k}={v}\" for k, v in f.parallel_info.items()\n                    )\n                lines.append(\n                    f\"         [{idx}] {_esc_shape(f.shape)}  {rank_part}{par_part}\"\n                )\n        else:\n            shapes: list[list[int]] = [f.shape for f in side.files]\n            unique_shapes: set[str] = {str(s) for s in shapes}\n            shape_desc: str\n            if len(unique_shapes) == 1:\n                shape_desc = _esc_shape(shapes[0])\n            else:\n                shape_desc = \"mixed shapes\"\n\n            dims_part = f\"  [dim]dims: {side.dims}[/]\" if side.dims else \"\"\n            lines.append(\n                f\"      {label}  [cyan]{side.num_files} files[/]\"\n                f\" × {shape_desc} {dtype_desc}{dims_part}\"\n            )\n\n    return lines\n\n\ndef _format_plan_section_rich(\n    *,\n    traced_plan: TracedAlignerPlan,\n    verbose: bool = False,\n) -> list[str]:\n    lines: list[str] = []\n\n    for side_label, 
traced_side in [\n        (\"baseline\", traced_plan.per_side.x),\n        (\"target\", traced_plan.per_side.y),\n    ]:\n        if not traced_side.step_plans:\n            lines.append(f\"      {side_label}  [dim](passthrough)[/]\")\n            continue\n\n        parts: list[str] = [\n            _format_sub_plan_rich(traced_sub)\n            for traced_step in traced_side.step_plans\n            for traced_sub in traced_step.sub_plans\n        ]\n        lines.append(f\"      {side_label}  \" + \" → \".join(parts))\n\n    lines.extend(_format_cross_side_plan_rich(traced_plan.plan))\n    return lines\n\n\ndef _format_sub_plan_rich(traced_sub: TracedSubPlan) -> str:\n    sub = traced_sub.plan\n    snapshot: Optional[ShapeSnapshot] = traced_sub.snapshot\n\n    op_name: str = sub.type\n    axis_str: str = \"\"\n    if isinstance(sub, UnsharderPlan):\n        axis_str = f\"({sub.axis})\"\n\n    shape_change: str = \"\"\n    if snapshot:\n        in_count: int = len(snapshot.input_shapes)\n        out_count: int = len(snapshot.output_shapes)\n        in_shape: str = (\n            _esc_shape(snapshot.input_shapes[0]) if snapshot.input_shapes else \"?\"\n        )\n        out_shape: str = (\n            _esc_shape(snapshot.output_shapes[0]) if snapshot.output_shapes else \"?\"\n        )\n        shape_change = f\" {in_count}×{in_shape} → {out_count}×{out_shape}\"\n\n    return f\"[magenta]{op_name}{axis_str}[/]{shape_change}\"\n\n\ndef _format_cross_side_plan_rich(plan: AlignerPlan) -> list[str]:\n    lines: list[str] = []\n\n    if plan.token_aligner_plan is not None:\n        num_tokens: int = len(plan.token_aligner_plan.locators.x.steps)\n        lines.append(f\"      token_aligner  [dim]{num_tokens} tokens[/]\")\n\n    if plan.axis_aligner_plan is not None:\n        parts: list[str] = []\n        if plan.axis_aligner_plan.pattern.x:\n            parts.append(f\"x={plan.axis_aligner_plan.pattern.x}\")\n        if plan.axis_aligner_plan.pattern.y:\n            
parts.append(f\"y={plan.axis_aligner_plan.pattern.y}\")\n        if parts:\n            lines.append(f\"      axis_aligner  [dim]{', '.join(parts)}[/]\")\n        else:\n            lines.append(\"      axis_aligner  [dim](no-op)[/]\")\n\n    return lines\n\n\ndef _format_stats_rich(\n    *,\n    baseline: TensorStats,\n    target: TensorStats,\n    verbose: bool = False,\n) -> list[str]:\n    lines: list[str] = []\n\n    if verbose:\n        # All stat fields\n        for stat_name in TensorStats.model_fields:\n            if stat_name == \"percentiles\":\n                continue\n            val_b: float = getattr(baseline, stat_name)\n            val_t: float = getattr(target, stat_name)\n            lines.append(_format_stat_line(stat_name, val_b, val_t, val_t - val_b))\n\n        # Percentiles\n        for p in sorted(set(baseline.percentiles) & set(target.percentiles)):\n            val_b = baseline.percentiles[p]\n            val_t = target.percentiles[p]\n            lines.append(_format_stat_line(f\"p{p}\", val_b, val_t, val_t - val_b))\n    else:\n        # Compact: mean, std, range (min/max combined)\n        for stat_name in (\"mean\", \"std\"):\n            val_b = getattr(baseline, stat_name)\n            val_t = getattr(target, stat_name)\n            lines.append(_format_stat_line(stat_name, val_b, val_t, val_t - val_b))\n\n        # Range line: combine min/max (escape brackets to avoid Rich markup)\n        range_baseline: str = escape(f\"[{baseline.min:.4f}, {baseline.max:.4f}]\")\n        range_target: str = escape(f\"[{target.min:.4f}, {target.max:.4f}]\")\n        lines.append(f\"      [blue]{'range':10s}[/] {range_baseline} vs {range_target}\")\n\n    return lines\n\n\ndef _format_abs_diff_percentiles_rich(diff: DiffInfo) -> str:\n    parts: list[str] = []\n    for p, value in sorted(diff.abs_diff_percentiles.items()):\n        formatted: str = f\"p{p}={_fmt_val(value)}\"\n        if p >= 99 and value > 0.1:\n            formatted = 
f\"[yellow]{formatted}[/]\"\n        parts.append(formatted)\n    return \"  \".join(parts)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/tensor_comparator/types.py",
    "content": "from typing import Optional\n\nfrom sglang.srt.debug_utils.comparator.utils import _StrictBase\n\nDEFAULT_PERCENTILES: tuple[int, ...] = (1, 5, 50, 95, 99)\n\n\nclass TensorStats(_StrictBase):\n    mean: float\n    abs_mean: float\n    std: float\n    min: float\n    max: float\n    percentiles: dict[int, float] = {}\n\n\nclass TensorInfo(_StrictBase):\n    shape: list[int]\n    dtype: str\n    stats: TensorStats\n    sample: Optional[str] = None\n\n\nclass DiffInfo(_StrictBase):\n    rel_diff: float\n    max_abs_diff: float\n    mean_abs_diff: float\n    abs_diff_percentiles: dict[int, float] = {}\n    max_diff_coord: list[int]\n    baseline_at_max: float\n    target_at_max: float\n    diff_threshold: float\n    passed: bool\n    per_token_rel_diff: Optional[list[float]] = None\n\n\nclass TensorComparisonInfo(_StrictBase):\n    name: str\n    baseline: TensorInfo\n    target: TensorInfo\n    unified_shape: Optional[list[int]]\n    shape_mismatch: bool\n    diff: Optional[DiffInfo] = None\n    diff_downcast: Optional[DiffInfo] = None\n    downcast_dtype: Optional[str] = None\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/utils.py",
    "content": "from __future__ import annotations\n\nimport functools\nimport re\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Callable, Generic, Optional, Tuple, TypeVar\n\nimport torch\nfrom pydantic import BaseModel, ConfigDict\n\n_T = TypeVar(\"_T\")\n_U = TypeVar(\"_U\")\n\n\ndef _check_equal_lengths(**named_lists: list) -> None:\n    lengths: dict[str, int] = {name: len(lst) for name, lst in named_lists.items()}\n    unique: set[int] = set(lengths.values())\n    if len(unique) > 1:\n        details: str = \", \".join(f\"{name}={length}\" for name, length in lengths.items())\n        raise ValueError(f\"Length mismatch: {details}\")\n\n\ndef auto_descend_dir(directory: Path, label: str) -> Path:\n    \"\"\"If directory has no .pt files but exactly one subdirectory does, descend into it.\n\n    Raises ValueError when the layout is ambiguous (>=2 subdirs with .pt)\n    or when no .pt data is found at all.\n    \"\"\"\n    if any(directory.glob(\"*.pt\")):\n        return directory\n\n    candidates: list[Path] = [\n        sub for sub in directory.iterdir() if sub.is_dir() and any(sub.glob(\"*.pt\"))\n    ]\n\n    if len(candidates) >= 2:\n        names: str = \", \".join(sorted(c.name for c in candidates))\n        raise ValueError(\n            f\"{label}: directory {directory} has no .pt files at top level \"\n            f\"and multiple subdirectories contain data ({names}). 
\"\n            f\"Please specify the exact subdirectory.\"\n        )\n\n    if len(candidates) == 0:\n        raise ValueError(\n            f\"{label}: no .pt files found in {directory} or any of its subdirectories.\"\n        )\n\n    resolved: Path = candidates[0]\n\n    from sglang.srt.debug_utils.comparator.log_sink import log_sink\n    from sglang.srt.debug_utils.comparator.output_types import InfoLog\n\n    log_sink.add(\n        InfoLog(\n            category=\"auto_descend\",\n            message=f\"auto-descend {label}: {directory} -> {resolved}\",\n        )\n    )\n    return resolved\n\n\nclass _StrictBase(BaseModel):\n    model_config = ConfigDict(extra=\"forbid\")\n\n\nclass _FrozenBase(BaseModel):\n    model_config = ConfigDict(frozen=True, extra=\"forbid\")\n\n\nclass Pair(_FrozenBase, Generic[_T]):\n    x: _T\n    y: _T\n\n    def map(self, fn: Callable[[_T], _U]) -> Pair[_U]:\n        return Pair(x=fn(self.x), y=fn(self.y))\n\n\ndef argmax_coord(x: torch.Tensor) -> Tuple[int, ...]:\n    flat_idx = x.argmax()\n    return tuple(idx.item() for idx in torch.unravel_index(flat_idx, x.shape))\n\n\ndef compute_smaller_dtype(\n    dtypes: Pair[torch.dtype],\n) -> Optional[torch.dtype]:\n    info_dict = {\n        (torch.float32, torch.bfloat16): torch.bfloat16,\n        # ... 
add more ...\n    }\n    return info_dict.get((dtypes.x, dtypes.y)) or info_dict.get((dtypes.y, dtypes.x))\n\n\ndef try_unify_shape(x: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:\n    x_shape = x.shape\n    num_dim_to_remove = len(x_shape) - len(target_shape)\n    if (x_shape[num_dim_to_remove:] == target_shape) and all(\n        val == 1 for val in x_shape[:num_dim_to_remove]\n    ):\n        return functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x)\n\n    return x\n\n\n# Copied from DeepGEMM\ndef calc_rel_diff(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n    x, y = x.double(), y.double()\n    denominator = (x * x + y * y).sum()\n    sim = 2 * (x * y).sum() / denominator\n    return 1 - sim\n\n\ndef calc_per_token_rel_diff(\n    x: torch.Tensor, y: torch.Tensor, *, seq_dim: int\n) -> torch.Tensor:\n    \"\"\"Cosine-distance-like metric per token position.\n\n    Sums over all dims except seq_dim.\n    \"\"\"\n    x, y = x.double(), y.double()\n    other_dims: list[int] = [d for d in range(x.dim()) if d != seq_dim]\n\n    if other_dims:\n        denominator: torch.Tensor = (x * x + y * y).sum(dim=other_dims)\n        sim: torch.Tensor = 2 * (x * y).sum(dim=other_dims) / (denominator + 1e-10)\n    else:\n        denominator = x * x + y * y\n        sim = 2 * (x * y) / (denominator + 1e-10)\n\n    return (1 - sim).float()\n\n\nif TYPE_CHECKING:\n    from sglang.srt.debug_utils.comparator.output_types import SummaryRecord\n\n\ndef compute_exit_code(\n    summary: SummaryRecord,\n    *,\n    allow_skipped_pattern: str,\n    skipped_names: list[str],\n    allow_failed_pattern: Optional[str],\n    failed_names: list[str],\n    errored_names: Optional[list[str]] = None,\n) -> int:\n    if summary.passed == 0:\n        return 1\n\n    if errored_names:\n        return 1\n\n    if not _is_all_match_pattern(pattern=allow_failed_pattern, strings=failed_names):\n        return 1\n\n    if not 
_is_all_match_pattern(pattern=allow_skipped_pattern, strings=skipped_names):\n        return 1\n\n    return 0\n\n\ndef _is_all_match_pattern(*, pattern: Optional[str], strings: list[str]) -> bool:\n    if pattern is None:\n        return len(strings) == 0\n    compiled: re.Pattern[str] = re.compile(pattern)\n    return all(compiled.fullmatch(s) for s in strings)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/visualizer/__init__.py",
    "content": "from sglang.srt.debug_utils.comparator.visualizer.figure import (  # noqa: F401\n    generate_comparison_figure,\n)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/visualizer/figure.py",
    "content": "\"\"\"Main orchestration logic for comparison figure generation.\"\"\"\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Callable, Optional\n\nimport numpy as np\nimport torch\n\nfrom sglang.srt.debug_utils.comparator.visualizer.preprocessing import (\n    _preprocess_tensor,\n)\n\n\n@dataclass(frozen=True)\nclass _PanelContext:\n    baseline_2d: torch.Tensor\n    target_2d: torch.Tensor\n    diff: Optional[torch.Tensor]  # None when shapes differ\n    name: str\n\n\n@dataclass(frozen=True)\nclass _Panel:\n    label: str\n    requires_diff: bool\n    draw: Callable[[np.ndarray, int, _PanelContext], Optional[str]]\n\n\ndef _build_panels() -> list[_Panel]:\n    from sglang.srt.debug_utils.comparator.visualizer.panels import (\n        _draw_baseline_heatmap,\n        _draw_diff_heatmap,\n        _draw_diff_histogram,\n        _draw_hist2d,\n        _draw_sampled,\n        _draw_target_heatmap,\n    )\n\n    return [\n        _Panel(\n            label=\"Baseline Heatmap\", requires_diff=False, draw=_draw_baseline_heatmap\n        ),\n        _Panel(label=\"Target Heatmap\", requires_diff=False, draw=_draw_target_heatmap),\n        _Panel(label=\"Abs Diff Heatmap\", requires_diff=True, draw=_draw_diff_heatmap),\n        _Panel(label=\"Abs Diff Hist\", requires_diff=True, draw=_draw_diff_histogram),\n        _Panel(label=\"Hist2D\", requires_diff=True, draw=_draw_hist2d),\n        _Panel(label=\"Sampled\", requires_diff=True, draw=_draw_sampled),\n    ]\n\n\ndef generate_comparison_figure(\n    *,\n    baseline: torch.Tensor,\n    target: torch.Tensor,\n    name: str,\n    output_path: Path,\n) -> None:\n    \"\"\"Generate a multi-panel comparison PNG for a baseline/target tensor pair.\n\n    Panels (6 rows x 2 cols, left=normal, right=log10):\n      Row 0: Baseline heatmap\n      Row 1: Target heatmap\n      Row 2: Abs Diff heatmap\n      Row 3: Abs Diff histogram\n      Row 4: 
Hist2D scatter (baseline vs target density)\n      Row 5: Sampled scatter (10k sampled mini-heatmap)\n    \"\"\"\n    import matplotlib.pyplot as plt\n\n    baseline_f: torch.Tensor = baseline.detach().cpu().float()\n    target_f: torch.Tensor = target.detach().cpu().float()\n\n    can_diff: bool = baseline_f.shape == target_f.shape\n\n    baseline_2d: torch.Tensor = _preprocess_tensor(baseline_f)\n    target_2d: torch.Tensor = _preprocess_tensor(target_f)\n\n    diff: Optional[torch.Tensor] = (baseline_2d - target_2d).abs() if can_diff else None\n\n    ctx = _PanelContext(\n        baseline_2d=baseline_2d,\n        target_2d=target_2d,\n        diff=diff,\n        name=name,\n    )\n\n    panels: list[_Panel] = _build_panels()\n    active: list[_Panel] = [p for p in panels if not p.requires_diff or can_diff]\n\n    nrows: int = len(active)\n    ncols: int = 2\n    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 3.5 * nrows))\n    if nrows == 1:\n        axes = axes.reshape(1, -1)\n\n    stats_lines: list[str] = []\n    for i, panel in enumerate(active):\n        stats_line: Optional[str] = panel.draw(axes, i, ctx)\n        if stats_line is not None:\n            stats_lines.append(stats_line)\n\n    num_stats: int = len(stats_lines)\n    title_height: float = 0.015 * num_stats + 0.015\n    fig.suptitle(\n        \"\\n\".join(stats_lines),\n        fontsize=9,\n        family=\"monospace\",\n        y=1 - title_height / 2,\n    )\n    plt.tight_layout(rect=[0, 0, 1, 1 - title_height])\n    output_path.parent.mkdir(parents=True, exist_ok=True)\n    plt.savefig(str(output_path), dpi=150, bbox_inches=\"tight\")\n    plt.close(fig)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/visualizer/panels.py",
    "content": "\"\"\"Panel draw functions for tensor comparison visualization.\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Optional\n\nimport numpy as np\nimport torch\n\nfrom sglang.srt.debug_utils.comparator.visualizer.figure import _PanelContext\nfrom sglang.srt.debug_utils.comparator.visualizer.preprocessing import (\n    _SCATTER_SAMPLE_SIZE,\n    _format_log_ticks,\n    _format_stats,\n    _maybe_downsample_numpy,\n    _safe_hist,\n    _to_log10,\n)\n\n\ndef _draw_baseline_heatmap(\n    axes: np.ndarray, row_idx: int, ctx: _PanelContext\n) -> Optional[str]:\n    _draw_heatmap_pair(\n        axes, row_idx=row_idx, t=ctx.baseline_2d, title=f\"{ctx.name} Baseline\"\n    )\n    return _format_stats(\"Baseline\", ctx.baseline_2d)\n\n\ndef _draw_target_heatmap(\n    axes: np.ndarray, row_idx: int, ctx: _PanelContext\n) -> Optional[str]:\n    _draw_heatmap_pair(\n        axes, row_idx=row_idx, t=ctx.target_2d, title=f\"{ctx.name} Target\"\n    )\n    return _format_stats(\"Target\", ctx.target_2d)\n\n\ndef _draw_diff_heatmap(\n    axes: np.ndarray, row_idx: int, ctx: _PanelContext\n) -> Optional[str]:\n    assert ctx.diff is not None\n    _draw_heatmap_pair(axes, row_idx=row_idx, t=ctx.diff, title=f\"{ctx.name} Abs Diff\")\n    return _format_stats(\"Abs Diff\", ctx.diff)\n\n\ndef _draw_diff_histogram(\n    axes: np.ndarray, row_idx: int, ctx: _PanelContext\n) -> Optional[str]:\n    assert ctx.diff is not None\n    _draw_histogram_pair(\n        axes, row_idx=row_idx, diff=ctx.diff, label=f\"{ctx.name} Abs Diff\"\n    )\n    return None\n\n\ndef _draw_hist2d(axes: np.ndarray, row_idx: int, ctx: _PanelContext) -> Optional[str]:\n    _draw_scatter_hist2d(\n        axes,\n        row_idx=row_idx,\n        baseline=ctx.baseline_2d,\n        target=ctx.target_2d,\n        label=ctx.name,\n    )\n    return None\n\n\ndef _draw_sampled(axes: np.ndarray, row_idx: int, ctx: _PanelContext) -> Optional[str]:\n    _draw_scatter_sampled(\n        axes,\n   
     row_idx=row_idx,\n        baseline=ctx.baseline_2d,\n        target=ctx.target_2d,\n        label=ctx.name,\n    )\n    return None\n\n\n# ────────────────────── internal drawing helpers ──────────────────────\n\n\ndef _draw_heatmap_pair(\n    axes: np.ndarray,\n    *,\n    row_idx: int,\n    t: torch.Tensor,\n    title: str,\n) -> None:\n    import matplotlib.pyplot as plt\n\n    ax_normal = axes[row_idx, 0]\n    ax_log = axes[row_idx, 1]\n\n    im = ax_normal.imshow(t.numpy(), aspect=\"auto\", cmap=\"viridis\")\n    ax_normal.set_title(title)\n    plt.colorbar(im, ax=ax_normal)\n\n    im_log = ax_log.imshow(_to_log10(t).numpy(), aspect=\"auto\", cmap=\"viridis\")\n    ax_log.set_title(f\"{title} (Log10)\")\n    cbar = plt.colorbar(im_log, ax=ax_log)\n    _format_log_ticks(cbar.ax, axis=\"y\")\n\n\ndef _draw_histogram_pair(\n    axes: np.ndarray,\n    *,\n    row_idx: int,\n    diff: torch.Tensor,\n    label: str,\n) -> None:\n\n    ax_normal = axes[row_idx, 0]\n    ax_log = axes[row_idx, 1]\n\n    diff_flat: np.ndarray = _maybe_downsample_numpy(diff.flatten())\n\n    _safe_hist(ax_normal, diff_flat, bins=100, edgecolor=\"none\")\n    ax_normal.set_title(f\"{label} Histogram\")\n    ax_normal.set_xlabel(\"Abs Diff\")\n    ax_normal.set_ylabel(\"Count\")\n\n    log_flat: np.ndarray = np.log10(np.abs(diff_flat) + 1e-10)\n    _safe_hist(ax_log, log_flat, bins=100, edgecolor=\"none\")\n    ax_log.set_title(f\"{label} Histogram (Log10)\")\n    ax_log.set_xlabel(\"Abs Diff\")\n    ax_log.set_ylabel(\"Count\")\n    _format_log_ticks(ax_log, axis=\"x\")\n\n\ndef _draw_scatter_hist2d(\n    axes: np.ndarray,\n    *,\n    row_idx: int,\n    baseline: torch.Tensor,\n    target: torch.Tensor,\n    label: str,\n) -> None:\n    import matplotlib.pyplot as plt\n\n    ax_normal = axes[row_idx, 0]\n    ax_log = axes[row_idx, 1]\n\n    b_flat: np.ndarray = _maybe_downsample_numpy(baseline.flatten())\n    t_flat: np.ndarray = _maybe_downsample_numpy(target.flatten())\n    
min_len: int = min(len(b_flat), len(t_flat))\n    b_flat = b_flat[:min_len]\n    t_flat = t_flat[:min_len]\n\n    # Normal scale\n    lim: float = float(max(np.abs(b_flat).max(), np.abs(t_flat).max())) * 1.05\n    if lim == 0:\n        lim = 1.0\n    _h, _xe, _ye, im = ax_normal.hist2d(\n        b_flat,\n        t_flat,\n        bins=200,\n        range=[[-lim, lim], [-lim, lim]],\n        cmap=\"viridis\",\n        norm=\"log\",\n    )\n    ax_normal.plot([-lim, lim], [-lim, lim], \"r--\", linewidth=0.5)\n    ax_normal.set_title(f\"{label} Hist2D\")\n    ax_normal.set_xlabel(\"Baseline\")\n    ax_normal.set_ylabel(\"Target\")\n    ax_normal.set_aspect(\"equal\")\n    plt.colorbar(im, ax=ax_normal)\n\n    # Log scale\n    b_log: np.ndarray = np.log10(np.abs(b_flat) + 1e-10)\n    t_log: np.ndarray = np.log10(np.abs(t_flat) + 1e-10)\n    vmin: float = float(min(b_log.min(), t_log.min())) - 0.5\n    vmax: float = float(max(b_log.max(), t_log.max())) + 0.5\n    _h2, _xe2, _ye2, im2 = ax_log.hist2d(\n        b_log,\n        t_log,\n        bins=200,\n        range=[[vmin, vmax], [vmin, vmax]],\n        cmap=\"viridis\",\n        norm=\"log\",\n    )\n    ax_log.plot([vmin, vmax], [vmin, vmax], \"r--\", linewidth=0.5)\n    ax_log.set_title(f\"{label} Hist2D (Log10 Abs)\")\n    ax_log.set_xlabel(\"Baseline\")\n    ax_log.set_ylabel(\"Target\")\n    ax_log.set_aspect(\"equal\")\n    plt.colorbar(im2, ax=ax_log)\n    _format_log_ticks(ax_log, axis=\"both\")\n\n\ndef _draw_scatter_sampled(\n    axes: np.ndarray,\n    *,\n    row_idx: int,\n    baseline: torch.Tensor,\n    target: torch.Tensor,\n    label: str,\n) -> None:\n    import matplotlib.pyplot as plt\n\n    ax_baseline = axes[row_idx, 0]\n    ax_target = axes[row_idx, 1]\n\n    b_flat: np.ndarray = baseline.flatten().numpy()\n    t_flat: np.ndarray = target.flatten().numpy()\n\n    n_samples: int = min(_SCATTER_SAMPLE_SIZE, len(b_flat))\n    rng: np.random.Generator = np.random.default_rng(seed=42)\n    indices: 
np.ndarray = np.sort(rng.choice(len(b_flat), n_samples, replace=False))\n    b_sampled: np.ndarray = b_flat[indices]\n    t_sampled: np.ndarray = t_flat[indices]\n\n    side: int = int(np.sqrt(n_samples))\n    n_use: int = side * side\n    b_2d: np.ndarray = b_sampled[:n_use].reshape(side, side)\n    t_2d: np.ndarray = t_sampled[:n_use].reshape(side, side)\n\n    vmin: float = float(min(b_2d.min(), t_2d.min()))\n    vmax: float = float(max(b_2d.max(), t_2d.max()))\n\n    im_b = ax_baseline.imshow(b_2d, aspect=\"auto\", cmap=\"viridis\", vmin=vmin, vmax=vmax)\n    ax_baseline.set_title(f\"{label} Baseline (10k sampled)\")\n    plt.colorbar(im_b, ax=ax_baseline)\n\n    im_t = ax_target.imshow(t_2d, aspect=\"auto\", cmap=\"viridis\", vmin=vmin, vmax=vmax)\n    ax_target.set_title(f\"{label} Target (10k sampled)\")\n    plt.colorbar(im_t, ax=ax_target)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/comparator/visualizer/preprocessing.py",
    "content": "\"\"\"Tensor preprocessing and utility functions for visualization.\"\"\"\n\nfrom __future__ import annotations\n\nimport math\nimport re\n\nimport numpy as np\nimport torch\n\n_DOWNSAMPLE_THRESHOLD: int = 10_000_000\n_SCATTER_SAMPLE_SIZE: int = 10_000\n\n\ndef _preprocess_tensor(tensor: torch.Tensor) -> torch.Tensor:\n    t: torch.Tensor = tensor.squeeze()\n\n    while t.ndim < 2:\n        t = t.unsqueeze(0)\n    if t.ndim > 2:\n        t = t.reshape(-1, t.shape[-1])\n\n    t = _reshape_to_balanced_aspect(t)\n    return t\n\n\ndef _reshape_to_balanced_aspect(\n    t: torch.Tensor, max_ratio: float = 5.0\n) -> torch.Tensor:\n    assert t.ndim == 2\n\n    h, w = t.shape\n    ratio: float = h / w if w > 0 else float(\"inf\")\n\n    if 1 / max_ratio <= ratio <= max_ratio:\n        return t\n\n    total: int = h * w\n    target_side: int = int(math.sqrt(total))\n\n    for new_h in range(target_side, 0, -1):\n        if total % new_h == 0:\n            new_w: int = total // new_h\n            new_ratio: float = new_h / new_w\n            if 1 / max_ratio <= new_ratio <= max_ratio:\n                return t.reshape(new_h, new_w)\n\n    return t.reshape(1, -1)\n\n\n# ────────────────────── utility ──────────────────────\n\n\ndef _to_log10(t: torch.Tensor) -> torch.Tensor:\n    return t.abs().clamp(min=1e-10).log10()\n\n\ndef _format_log_ticks(ax: object, axis: str = \"both\") -> None:\n    from matplotlib.ticker import FuncFormatter\n\n    formatter = FuncFormatter(\n        lambda x, _: f\"1e{int(x)}\" if x == int(x) else f\"1e{x:.1f}\"\n    )\n    if axis in (\"x\", \"both\"):\n        ax.xaxis.set_major_formatter(formatter)\n    if axis in (\"y\", \"both\"):\n        ax.yaxis.set_major_formatter(formatter)\n\n\ndef _format_stats(name: str, t: torch.Tensor) -> str:\n    return (\n        f\"{name}: shape={tuple(t.shape)}, \"\n        f\"min={t.min().item():.4g}, max={t.max().item():.4g}, \"\n        f\"mean={t.mean().item():.4g}, 
std={t.std().item():.4g}\"\n    )\n\n\ndef _safe_hist(\n    ax: object, data: np.ndarray, *, bins: int = 100, **kwargs: object\n) -> None:\n    data_f64: np.ndarray = data.astype(np.float64)\n    try:\n        ax.hist(data_f64, bins=bins, **kwargs)\n    except ValueError:\n        ax.hist(data_f64, bins=max(1, len(np.unique(data_f64[:1000]))), **kwargs)\n\n\ndef _maybe_downsample_numpy(\n    t: torch.Tensor,\n    max_elements: int = _DOWNSAMPLE_THRESHOLD,\n) -> np.ndarray:\n    if t.numel() <= max_elements:\n        return t.numpy()\n\n    rng: np.random.Generator = np.random.default_rng(seed=0)\n    indices: np.ndarray = rng.choice(t.numel(), max_elements, replace=False)\n    return t.numpy()[indices]\n\n\ndef _sanitize_filename(name: str) -> str:\n    return re.sub(r\"[/\\.\\s]+\", \"_\", name).strip(\"_\")\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/cuda_coredump.py",
    "content": "\"\"\"CUDA coredump helpers.\n\nWhen SGLANG_CUDA_COREDUMP=1, this module injects CUDA coredump environment\nvariables into the current process so that GPU exceptions (e.g. illegal\nmemory access) produce lightweight coredump files for post-mortem analysis\nwith cuda-gdb.\n\nThe injection happens at module import time via _inject_env() on a\nbest-effort basis.  If any CUDA_* variable is already present in the\nenvironment (e.g. set by the user in the shell), injection is skipped for\nthat variable and a warning is printed.  For strict guarantees, set the\nCUDA_* env vars in the shell before launching Python.\n\"\"\"\n\nimport glob\nimport os\nimport warnings\n\nfrom sglang.srt.environ import envs\n\n_CUDA_COREDUMP_FLAGS = (\n    \"skip_nonrelocated_elf_images,skip_global_memory,\"\n    \"skip_shared_memory,skip_local_memory,skip_constbank_memory\"\n)\n\n\ndef is_enabled() -> bool:\n    return envs.SGLANG_CUDA_COREDUMP.get()\n\n\ndef get_dump_dir() -> str:\n    return envs.SGLANG_CUDA_COREDUMP_DIR.get()\n\n\ndef _inject_env():\n    \"\"\"Inject CUDA coredump environment variables into the current process.\n    If a CUDA_* variable is already present, skip it and log a warning.\"\"\"\n    dump_dir = get_dump_dir()\n    os.makedirs(dump_dir, exist_ok=True)\n\n    env_vars = {\n        \"CUDA_ENABLE_COREDUMP_ON_EXCEPTION\": \"1\",\n        \"CUDA_COREDUMP_SHOW_PROGRESS\": \"1\",\n        \"CUDA_COREDUMP_GENERATION_FLAGS\": _CUDA_COREDUMP_FLAGS,\n        \"CUDA_COREDUMP_FILE\": f\"{dump_dir}/cuda_coredump_%h.%p.%t\",\n    }\n    for key, value in env_vars.items():\n        if key in os.environ:\n            warnings.warn(\n                f\"CUDA coredump env var {key} is already set to \"\n                f\"'{os.environ[key]}', skipping injection of '{value}'.\",\n                stacklevel=2,\n            )\n        else:\n            os.environ[key] = value\n\n\ndef cleanup_dump_dir():\n    \"\"\"Remove stale coredump files from the dump 
directory.\"\"\"\n    dump_dir = get_dump_dir()\n    for f in glob.glob(os.path.join(dump_dir, \"cuda_coredump_*\")):\n        os.remove(f)\n\n\ndef report():\n    \"\"\"Log any CUDA coredump files found after a test failure.\"\"\"\n    dump_dir = get_dump_dir()\n    coredump_files = glob.glob(os.path.join(dump_dir, \"cuda_coredump_*\"))\n    if not coredump_files:\n        return\n\n    print(f\"\\n{'='*60}\")\n    print(f\"CUDA coredump(s) detected ({len(coredump_files)} file(s)):\")\n    for f in coredump_files:\n        size_mb = os.path.getsize(f) / (1024 * 1024)\n        print(f\"  {f} ({size_mb:.1f} MB)\")\n    print(\"Use cuda-gdb to analyze: cuda-gdb -c <coredump_file>\")\n\n    run_id = os.environ.get(\"GITHUB_RUN_ID\")\n    if run_id:\n        repo = os.environ.get(\"GITHUB_REPOSITORY\", \"sgl-project/sglang\")\n        print(f\"Download from CI: gh run download {run_id} --repo {repo}\")\n\n    print(f\"{'='*60}\\n\")\n\n\n# Auto-inject CUDA coredump env vars at import time.\n# The sentinel env var is inherited by child processes, so injection only\n# happens once in the top-level process.\n_SENTINEL = \"_SGLANG_CUDA_COREDUMP_INJECTED\"\n\nif is_enabled() and _SENTINEL not in os.environ:\n    os.environ[_SENTINEL] = \"1\"\n    print(f\"Injecting CUDA coredump env vars (pid={os.getpid()})\")\n    _inject_env()\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/dump_comparator.py",
    "content": "\"\"\"Simplified dump comparator — a self-contained single-file script for comparing\ntwo dump directories tensor-by-tensor.\n\nFor advanced features (unshard, token alignment, per-dimension annotations), see the\nfull ``comparator/`` package: ``python -m sglang.srt.debug_utils.comparator``.\n\"\"\"\n\nimport argparse\nimport functools\nimport re\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom sglang.srt.debug_utils.dumper import get_truncated_value\n\n\ndef main(args):\n    import polars as pl\n\n    from sglang.srt.debug_utils.dump_loader import find_row, read_meta\n\n    df_target = read_meta(args.target_path)\n    df_target = df_target.filter(\n        (pl.col(\"step\") >= args.start_step) & (pl.col(\"step\") <= args.end_step)\n    )\n    if args.filter:\n        df_target = df_target.filter(pl.col(\"filename\").str.contains(args.filter))\n    assert all(c in df_target.columns for c in [\"rank\", \"step\", \"dump_index\", \"name\"])\n\n    df_baseline = read_meta(args.baseline_path)\n    print(\"df_target\", df_target)\n    print(\"df_baseline\", df_baseline)\n\n    tensor_dim_descs: List[TensorDimDesc] = _get_tensor_dim_descs()\n\n    for row in df_target.iter_rows(named=True):\n        path_target = Path(args.target_path) / row[\"filename\"]\n\n        tensor_dim_desc: Optional[TensorDimDesc] = None\n        if tensor_dim_descs:\n            matched: list[TensorDimDesc] = [\n                desc\n                for desc in tensor_dim_descs\n                if re.search(desc.pattern, row[\"filename\"]) is not None\n            ]\n            if matched:\n                tensor_dim_desc = matched[0]\n\n        row_baseline = find_row(\n            df_baseline,\n            conditions=dict(\n                step=row[\"step\"],\n                **{\n                    k: v\n                    for k, v in row.items()\n                    if k not in [\"step\", 
\"dump_index\", \"filename\"]\n                },\n            ),\n        )\n\n        if row_baseline is None:\n            print(f\"Skip: target={str(path_target)} since no baseline\")\n            x_target = _load_object(path_target)\n            if x_target is not None:\n                print(f\"x_target(sample)={get_truncated_value(x_target)}\")\n            continue\n\n        path_baseline = Path(args.baseline_path) / row_baseline[\"filename\"]\n        print(\n            f\"Check:\\n\"\n            f\"target={str(path_target)} (duplicate_index={row['duplicate_index']})\\n\"\n            f\"baseline={str(path_baseline)} (duplicate_index={row_baseline['duplicate_index']})\"\n        )\n        check_tensor_pair(\n            path_baseline=path_baseline,\n            path_target=path_target,\n            diff_threshold=args.diff_threshold,\n            name=row[\"name\"],\n            tensor_dim_desc=tensor_dim_desc,\n        )\n        print()\n\n\ndef check_tensor_pair(\n    path_baseline,\n    path_target,\n    diff_threshold: float = 1e-3,\n    name=\"\",\n    tensor_dim_desc: Optional[\"TensorDimDesc\"] = None,\n):\n    x_baseline = _load_object(path_baseline)\n    x_target = _load_object(path_target)\n\n    if x_baseline is None or x_target is None:\n        print(\n            f\"Skip comparison because of None: x_baseline={x_baseline}, x_target={x_target}\"\n        )\n        return\n\n    print(\n        f\"Raw \"\n        f\"[shape] {x_baseline.shape} vs {x_target.shape}\\t\"\n        f\"[{'' if x_baseline.dtype == x_target.dtype else '🟠'}dtype] {x_baseline.dtype} vs {x_target.dtype}\"\n    )\n\n    if tensor_dim_desc is not None:\n        import einops\n\n        x_baseline = einops.rearrange(\n            x_baseline,\n            tensor_dim_desc.baseline_desc + \" -> \" + tensor_dim_desc.target_desc,\n        )\n        if tensor_dim_desc.baseline_cropper is not None:\n            print(\"Apply baseline_cropper\")\n            x_baseline = 
tensor_dim_desc.baseline_cropper(x_baseline)\n\n    x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)\n    x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)\n\n    print(\n        f\"After preprocessor \"\n        f\"[shape] {x_baseline.shape} vs {x_target.shape}\\t\"\n        f\"[dtype] {x_baseline.dtype} vs {x_target.dtype}\"\n    )\n\n    x_baseline_original_dtype = x_baseline.dtype\n    x_target_original_dtype = x_target.dtype\n\n    x_target = x_target.float()\n    x_baseline = x_baseline.float()\n\n    for stat_name, fn in [\n        (\"mean\", torch.mean),\n        (\"std\", torch.std),\n        (\"min\", torch.min),\n        (\"max\", torch.max),\n        *(\n            [\n                (\"p1\", functools.partial(torch.quantile, q=0.01)),\n                (\"p5\", functools.partial(torch.quantile, q=0.05)),\n                (\"p95\", functools.partial(torch.quantile, q=0.95)),\n                (\"p99\", functools.partial(torch.quantile, q=0.99)),\n            ]\n            if x_baseline.numel() < 10_000_000\n            else []\n        ),\n    ]:\n        value_baseline = fn(x_baseline).item()\n        value_target = fn(x_target).item()\n        print(\n            f\"[{stat_name}] {value_baseline:.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})\"\n        )\n\n    if x_baseline.shape != x_target.shape:\n        print(\"⚠️ Shape mismatch\")\n        return\n\n    diff_info = _compute_and_print_diff(\n        x_baseline=x_baseline,\n        x_target=x_target,\n        diff_threshold=diff_threshold,\n    )\n    needs_print = diff_info[\"max_abs_diff\"] > 1e-3\n\n    if (x_baseline_original_dtype != x_target_original_dtype) and (\n        (\n            downcast_dtype := _compute_smaller_dtype(\n                x_baseline_original_dtype, x_target_original_dtype\n            )\n        )\n        is not None\n    ):\n        _compute_and_print_diff(\n            
x_baseline=x_baseline.to(downcast_dtype),\n            x_target=x_target.to(downcast_dtype),\n            diff_threshold=diff_threshold,\n            prefix_text=f\"When downcast to {downcast_dtype}: \",\n        )\n\n    if needs_print:\n        print(f\"x_baseline(sample)={get_truncated_value(x_baseline)}\")\n        print(f\"x_target(sample)={get_truncated_value(x_target)}\")\n\n\ndef _compute_and_print_diff(\n    x_baseline, x_target, diff_threshold: float, prefix_text=\"\"\n):\n    raw_abs_diff = (x_target - x_baseline).abs()\n\n    max_abs_diff = raw_abs_diff.max().item()\n    mean_abs_diff = raw_abs_diff.mean().item()\n    rel_diff = _calc_rel_diff(x_target, x_baseline)\n\n    rel_diff_marker: str = \"❌\" if rel_diff > diff_threshold else \"✅\"\n    print(\n        prefix_text\n        + f\"{rel_diff_marker} rel_diff={rel_diff}\\t\"\n        + f\"max_abs_diff={max_abs_diff}\\t\"\n        + f\"mean_abs_diff={mean_abs_diff}\"\n    )\n\n    max_diff_coord = _argmax_coord(raw_abs_diff)\n    print(\n        f\"max_abs_diff happens at coord={max_diff_coord} with \"\n        f\"baseline={x_baseline[max_diff_coord].item()} \"\n        f\"target={x_target[max_diff_coord].item()}\"\n    )\n\n    return dict(max_abs_diff=max_abs_diff)\n\n\ndef _argmax_coord(x: torch.Tensor) -> tuple:\n    flat_idx = x.argmax()\n    return tuple(idx.item() for idx in torch.unravel_index(flat_idx, x.shape))\n\n\ndef _compute_smaller_dtype(dtype_a, dtype_b):\n    info_dict = {\n        (torch.float32, torch.bfloat16): torch.bfloat16,\n        # ... 
add more ...\n    }\n    return info_dict.get((dtype_a, dtype_b)) or info_dict.get((dtype_b, dtype_a))\n\n\ndef _try_unify_shape(x: torch.Tensor, target_shape):\n    x_shape = x.shape\n    num_dim_to_remove = len(x_shape) - len(target_shape)\n    if num_dim_to_remove > 0 and (x_shape[num_dim_to_remove:] == target_shape) and all(\n        val == 1 for val in x_shape[:num_dim_to_remove]\n    ):\n        out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x)\n        print(f\"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})\")\n        return out\n\n    return x\n\n\n# Copied from DeepGEMM\ndef _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):\n    x, y = x.double(), y.double()\n    denominator = (x * x + y * y).sum()\n    sim = 2 * (x * y).sum() / denominator\n    return 1 - sim\n\n\ndef _load_object(path):\n    try:\n        x = torch.load(path, weights_only=False)\n    except Exception as e:\n        print(f\"Skip load {path} since error {e}\")\n        return None\n\n    if isinstance(x, dict) and \"value\" in x:\n        x = x[\"value\"]\n\n    if not isinstance(x, torch.Tensor):\n        print(f\"Skip load {path} since {type(x)=} is not a Tensor ({x=})\")\n        return None\n    return x.cuda()\n\n\ndef _comparison_preprocessor(x_baseline, x_target, name):\n    \"\"\"Customization endpoint. Insert arbitrary ad hoc preprocessing logic here.\"\"\"\n    return x_baseline, x_target\n\n\n@dataclass\nclass TensorDimDesc:\n    pattern: str\n    baseline_desc: str\n    target_desc: str\n    baseline_cropper: Optional[Callable[[torch.Tensor], torch.Tensor]] = None\n\n\ndef _get_tensor_dim_descs() -> List[TensorDimDesc]:\n    \"\"\"Customization endpoint. Return a list of TensorDimDesc to rearrange baseline\n    dimensions to match target layout via einops before comparison.\"\"\"\n    return []\n\n\nif __name__ == \"__main__\":\n    # python -m sglang.srt.debug_utils.dump_comparator --baseline-path ... 
--target-path ...\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--baseline-path\", type=str)\n    parser.add_argument(\"--target-path\", type=str)\n    parser.add_argument(\"--start-step\", type=int, default=0)\n    parser.add_argument(\"--end-step\", type=int, default=1000000)\n    parser.add_argument(\"--diff-threshold\", type=float, default=1e-3)\n    parser.add_argument(\n        \"--filter\", type=str, default=None, help=\"Regex to filter filenames\"\n    )\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/dump_loader.py",
    "content": "import functools\nimport os\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any, Callable, Dict, Optional, Tuple\n\nimport polars as pl\nimport torch\n\nLOAD_FAILED: object = object()\n\n\ndef parse_meta_from_filename(path: Path) -> Dict[str, Any]:\n    stem = Path(path).stem\n    result: Dict[str, Any] = {}\n    for kv in stem.split(\"___\"):\n        if \"=\" in kv:\n            k, v = kv.split(\"=\", 1)\n            result[k] = v\n    for field_name, converter in _TYPED_FIELDS:\n        if field_name in result:\n            result[field_name] = converter(result[field_name])\n    return result\n\n\n@dataclass\nclass ValueWithMeta:\n    value: Any\n    meta: Dict[str, Any]\n\n    @staticmethod\n    def load(path: Path) -> \"ValueWithMeta\":\n        path = Path(path)\n        meta_from_filename = parse_meta_from_filename(path)\n\n        try:\n            raw = torch.load(path, weights_only=False, map_location=\"cpu\")\n        except Exception as e:\n            print(f\"Skip load {path} since error {e}\")\n            return ValueWithMeta(\n                value=LOAD_FAILED, meta={**meta_from_filename, \"filename\": path.name}\n            )\n\n        value, meta_from_embedded = _unwrap_dict_format(raw)\n        return ValueWithMeta(\n            value=value,\n            meta={**meta_from_filename, **meta_from_embedded, \"filename\": path.name},\n        )\n\n\ndef _unwrap_dict_format(obj: Any) -> Tuple[Any, Dict[str, Any]]:\n    if isinstance(obj, dict) and \"value\" in obj:\n        meta = obj.get(\"meta\", {})\n        assert isinstance(meta, dict), f\"Expected meta to be dict, got {type(meta)}\"\n        return obj[\"value\"], meta\n    return obj, {}\n\n\nclass DumpLoader:\n    def __init__(self):\n        directory = os.environ.get(\"SGLANG_DUMP_LOADER_DIR\")\n\n        self._enable = directory is not None\n        if self._enable:\n            self._directory = Path(directory)\n            self._df = 
read_meta(directory)\n\n    @property\n    def enable(self):\n        return self._enable\n\n    def load(self, name, **kwargs):\n        assert self._enable, \"Please call DumpLoader.load only when it is enabled\"\n\n        from sglang.srt.debug_utils.dumper import dumper\n\n        step = dumper._state.step\n        conditions = dict(name=name, step=step, **kwargs)\n        row = find_row(self._df, conditions=conditions)\n        assert (\n            row is not None\n        ), f\"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}\"\n\n        path = self._directory / row[\"filename\"]\n        output = torch.load(path, weights_only=False)\n        if isinstance(output, dict) and \"value\" in output:\n            output = output[\"value\"]\n\n        print(\n            f\"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})\"\n        )\n        return output\n\n\ndef read_meta(directory):\n    directory = Path(directory)\n    assert directory.is_dir(), f\"{directory=} should be a directory\"\n\n    rows = []\n    for p in directory.glob(\"*.pt\"):\n        try:\n            full_kwargs = parse_meta_from_filename(p)\n            rows.append(\n                {\n                    \"filename\": str(p.name),\n                    **full_kwargs,\n                }\n            )\n        except Exception as e:\n            print(f\"[DumpLoader] skip loading {p} due to error {e}\")\n\n    df = pl.DataFrame(rows)\n    df = df.with_columns(\n        pl.col(\"step\").cast(int),\n        pl.col(\"rank\").cast(int),\n        pl.col(\"dump_index\").cast(int),\n    )\n    df = _add_duplicate_index(df)\n    df = df.sort(\"rank\", \"dump_index\")\n    return df\n\n\ndef _add_duplicate_index(df: pl.DataFrame) -> pl.DataFrame:\n    group_cols = [c for c in df.columns if c not in [\"filename\", \"dump_index\"]]\n    df = df.sort(group_cols + [\"dump_index\"])\n    df = df.with_columns(\n        
pl.cum_count(\"dump_index\").over(group_cols).sub(1).alias(\"duplicate_index\")\n    )\n    return df\n\n\ndef filter_rows(df: pl.DataFrame, conditions: Dict[str, Any]) -> list[dict]:\n    filter_exprs = [\n        (\n            pl.col(col) == _cast_to_polars_dtype(conditions[col], df.schema[col])\n            if conditions[col] is not None\n            else pl.col(col).is_null()\n        )\n        for col in conditions\n        if col in df.columns\n    ]\n    if not filter_exprs:\n        return []\n    return df.filter(functools.reduce(lambda a, b: a & b, filter_exprs)).to_dicts()\n\n\ndef find_row(df: pl.DataFrame, conditions: Dict[str, Any]):\n    rows = filter_rows(df, conditions)\n    if len(rows) > 1:\n        print(f\"find_row found ambiguous results: {rows=}\")\n        return None\n    return rows[0] if rows else None\n\n\ndef _cast_to_polars_dtype(value, target_dtype):\n    if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):\n        return int(value)\n    elif target_dtype in (pl.Float64, pl.Float32):\n        return float(value)\n    elif target_dtype == pl.Boolean:\n        return bool(value)\n    elif target_dtype == pl.String:\n        return str(value)\n    else:\n        return value\n\n\ndef read_tokenizer_path(directory: Path) -> Optional[str]:\n    \"\"\"Read tokenizer_path from any .pt file's embedded metadata in a dump directory.\"\"\"\n    for p in directory.glob(\"*.pt\"):\n        item: ValueWithMeta = ValueWithMeta.load(p)\n        tokenizer_path: Optional[str] = item.meta.get(\"tokenizer_path\")\n        if tokenizer_path is not None:\n            return str(tokenizer_path)\n    return None\n\n\n_TYPED_FIELDS: list[tuple[str, Callable[[str], Any]]] = [\n    (\"rank\", int),\n]\n\n\ndump_loader = DumpLoader()\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/dumper.py",
    "content": "import enum\nimport functools\nimport json\nimport os\nimport random\nimport re\nimport socket\nimport threading\nimport time\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Callable\nfrom contextlib import contextmanager\nfrom copy import deepcopy\nfrom dataclasses import asdict, dataclass, field, fields, replace\nfrom functools import cached_property\nfrom http.server import BaseHTTPRequestHandler, HTTPServer\nfrom pathlib import Path\nfrom typing import Any, List, Literal, Optional, Union, get_args, get_type_hints\n\nimport torch\nimport torch.distributed as dist\n\n# -------------------------------------- config base ------------------------------------------\n\n\n@dataclass(frozen=True)\nclass _BaseConfig(ABC):\n    def __post_init__(self) -> None:\n        self._verify_types()\n\n    def _verify_types(self) -> None:\n        hints = get_type_hints(type(self))\n        cls_name = type(self).__name__\n        for f in fields(self):\n            value = getattr(self, f.name)\n            if value is None:\n                continue\n            expected = self._unwrap_type(hints[f.name])\n            if not isinstance(value, expected):\n                raise TypeError(\n                    f\"{cls_name}.{f.name}: expected {expected.__name__}, \"\n                    f\"got {type(value).__name__}\"\n                )\n\n    @classmethod\n    @abstractmethod\n    def _env_prefix(cls) -> str: ...\n\n    @classmethod\n    def _env_name(cls, field_name: str) -> str:\n        return f\"{cls._env_prefix()}{field_name.upper()}\"\n\n    @classmethod\n    def from_env(cls) -> \"_BaseConfig\":\n        return cls(\n            **{\n                f.name: cls._parse_env_field(cls._env_name(f.name), f.default)\n                for f in fields(cls)\n            }\n        )\n\n    def with_defaults(self, **kwargs) -> \"_BaseConfig\":\n        cls = type(self)\n        actual = {\n            key: value\n            for key, value in 
kwargs.items()\n            if os.getenv(cls._env_name(key)) is None\n        }\n        return replace(self, **actual) if actual else self\n\n    @staticmethod\n    def _unwrap_type(hint) -> type:\n        args = get_args(hint)\n        if args:\n            return next(a for a in args if a is not type(None))\n        return hint\n\n    @classmethod\n    def _parse_env_field(cls, env_name: str, default):\n        return cls._parse_env_value(os.getenv(env_name), default)\n\n    @staticmethod\n    def _parse_env_value(raw, default):\n        if raw is None or not raw.strip():\n            return default\n        if isinstance(default, bool):\n            return raw.lower() in (\"true\", \"1\")\n        if isinstance(default, int):\n            return int(raw)\n        return raw\n\n    @classmethod\n    def from_kv_pairs(cls, pairs: Optional[List[str]]) -> \"_BaseConfig\":\n        return cls(**cls._kv_pairs_to_dict(pairs))\n\n    @classmethod\n    def _kv_pairs_to_dict(cls, pairs: Optional[List[str]]) -> dict:\n        if not pairs:\n            return {}\n\n        missing = object()\n        defaults = {f.name: f.default for f in fields(cls)}\n        result: dict = {}\n\n        for pair in pairs:\n            key, sep, value = pair.partition(\"=\")\n            if not sep:\n                raise ValueError(f\"Invalid config pair (missing '='): {pair!r}\")\n            default = defaults.get(key, missing)\n            if default is missing:\n                raise ValueError(\n                    f\"Unknown config key {key!r}. 
Valid keys: {sorted(defaults)}\"\n                )\n            try:\n                result[key] = cls._parse_env_value(value, default)\n            except (ValueError, TypeError) as exc:\n                field_type = type(default).__name__\n                raise TypeError(f\"{key}: expected {field_type}, got {value!r}\") from exc\n\n        return result\n\n\n_DEFAULT_EXP_NAME_PREFIX = \"dump_\"\n\n\n@dataclass(frozen=True)\nclass DumperConfig(_BaseConfig):\n    enable: bool = False\n    filter: Optional[str] = None\n    dir: str = \"/tmp/dumper\"\n    enable_output_file: bool = True\n    enable_output_console: bool = True\n    enable_value: bool = True\n    enable_grad: bool = False\n    enable_model_value: bool = False\n    enable_model_grad: bool = False\n    exp_name: Optional[str] = None\n    cleanup_previous: bool = False\n    collective_timeout: int = 60\n    server_port: str = \"-1\"\n    non_intrusive_mode: str = \"core\"\n    source_patcher_config: Optional[str] = None\n\n    @classmethod\n    def _env_prefix(cls) -> str:\n        # NOTE: should not be `SGLANG_DUMPER_`, otherwise it is weird when dumping Megatron in Miles\n        return \"DUMPER_\"\n\n    @property\n    def server_port_parsed(self) -> Optional[Union[int, Literal[\"reuse\"]]]:\n        raw = self.server_port\n        if raw == \"reuse\":\n            return \"reuse\"\n        port = int(raw)\n        if port <= 0:\n            return None\n        return port\n\n\n# -------------------------------------- dumper core ------------------------------------------\n\n\n@dataclass\nclass _DumperState:\n    dump_index: int = 0\n    step: int = 0\n    global_ctx: dict = field(default_factory=dict)\n    captured_output_data: Optional[dict] = None\n    cleanup_previous_handled: bool = False\n\n\nclass _Dumper:\n    \"\"\"Utility to dump tensors, useful for checking models by comparing their dumps.\n\n    Example usage:\n    dumper.dump(\"layer_start__hidden_states\", hidden_states, 
layer_id=self.layer_id)\n    dumper.step()\n\n    Import from non-SGLang system:\n    ```\n    import sys\n    sys.path.append(\"/YOUR_PATH/sglang/python/sglang/srt/debug_utils\")\n    from dumper import dumper\n    ```\n\n    Then run the program:\n    `DUMPER_ENABLE=1 python ...`\n\n    Auto-cleanup old dumps before first write:\n    `DUMPER_CLEANUP_PREVIOUS=1 python ...`\n\n    Alternatively, disable at startup and configure via HTTP:\n    1. `python ...`\n    2. sglang mode:  `curl -X POST http://localhost:30000/dumper/configure -d '{\"enable\": true}'`\n       standalone:   `curl -X POST http://localhost:40000/dumper/configure -d '{\"enable\": true}'`\n    3. `curl -X POST http://localhost:30000/dumper/configure -d '{\"enable\": true, \"filter\": \"layer_id=[0-3]\"}'`\n    4. `curl -X POST http://localhost:30000/dumper/reset`\n\n    Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison\n    \"\"\"\n\n    def __init__(self, *, config: DumperConfig):\n        self._config = config\n        self._state = _DumperState()\n        self._non_intrusives: list[\"_NonIntrusiveDumper\"] = []\n\n    # ------------------------------- public :: core ---------------------------------\n\n    @property\n    def may_enable(self) -> bool:\n        return self._config.enable or self._config.server_port_parsed is not None\n\n    def step(self):\n        \"\"\"This should be called on all ranks at the end of each iteration.\"\"\"\n\n        self._http_manager  # noqa: B018\n\n        if not self._config.enable:\n            return\n\n        # Users may want to `dump` only on some ranks, thus determine name here\n        self._ensure_exp_name()\n\n        self._state.step += 1\n        print(f\"[Dumper] [{time.time()}] step={self._state.step}\")\n\n    def dump(\n        self,\n        name: str,\n        value,\n        save: bool = True,\n        dims: Optional[str] = None,\n        dims_grad: Optional[str] = None,\n        **kwargs,\n    ) -> None:\n        
value_meta: dict = {}\n        grad_meta: dict = {}\n        if dims is not None:\n            value_meta[\"dims\"] = dims\n            grad_meta[\"dims\"] = dims\n        if dims_grad is not None:\n            value_meta[\"dims_grad\"] = dims_grad\n            grad_meta[\"dims\"] = dims_grad\n\n        self._dump_inner(\n            name=name,\n            value=value,\n            extra_kwargs=kwargs,\n            save=save,\n            enable_value=self._config.enable_value,\n            enable_curr_grad=False,\n            enable_future_grad=self._config.enable_grad,\n            value_tag=\"Dumper.Value\",\n            grad_tag=\"Dumper.Grad\",\n            value_meta_only_fields=value_meta,\n            grad_meta_only_fields=grad_meta,\n        )\n\n    def dump_model(\n        self,\n        model: \"torch.nn.Module\",\n        name_prefix: str = \"param\",\n        save: bool = True,\n        **kwargs,\n    ) -> None:\n        for param_name, param in model.named_parameters():\n            self._dump_inner(\n                name=f\"{name_prefix}__{param_name}\",\n                value=param,\n                extra_kwargs=kwargs,\n                save=save,\n                enable_value=self._config.enable_model_value,\n                enable_curr_grad=self._config.enable_model_grad,\n                enable_future_grad=False,\n                value_tag=\"Dumper.ParamValue\",\n                grad_tag=\"Dumper.ParamGrad\",\n            )\n\n    def dump_dict(self, name_prefix, data, save: bool = True, **kwargs):\n        data = _obj_to_dict(data)\n        for name, value in data.items():\n            self.dump(f\"{name_prefix}_{name}\", value, save=save, **kwargs)\n\n    def set_ctx(self, **kwargs):\n        \"\"\"\n        Example:\n\n        dumper.configure_default(filter='layer_id=[0-3]')\n        dumper.set_ctx(layer_id=self.layer_id)\n        ...\n        dumper.set_ctx(layer_id=None)\n        \"\"\"\n        self._state.global_ctx = {\n            k: 
v for k, v in (self._state.global_ctx | kwargs).items() if v is not None\n        }\n\n    def ctx(\n        self,\n        _extractor: Optional[Callable[..., dict]] = None,\n        **static_ctx: Any,\n    ) -> Callable:\n        \"\"\"Decorator that sets context before calling the wrapped function and clears it after.\n\n        Two forms:\n            @dumper.ctx(lambda self: dict(layer_id=self.layer_id))\n            def forward(self, x): ...\n\n            @dumper.ctx(phase=\"decode\")\n            def decode_step(self, x): ...\n        \"\"\"\n        if _extractor is not None and static_ctx:\n            raise ValueError(\"cannot mix lambda extractor with static kwargs\")\n        if _extractor is None and not static_ctx:\n            raise ValueError(\"must provide either a lambda or static kwargs\")\n\n        def decorator(fn: Callable) -> Callable:\n            @functools.wraps(fn)\n            def wrapper(*args: Any, **kwargs: Any) -> Any:\n                ctx_dict: dict = _extractor(args[0]) if _extractor else static_ctx\n                self.set_ctx(**ctx_dict)\n                try:\n                    return fn(*args, **kwargs)\n                finally:\n                    self.set_ctx(**{k: None for k in ctx_dict})\n\n            return wrapper\n\n        return decorator\n\n    def apply_source_patches(self) -> None:\n        \"\"\"Apply source patches from DUMPER_SOURCE_PATCHER_CONFIG if set.\n\n        Automatically injects ``from sglang.srt.debug_utils.dumper import dumper``\n        into every replacement block so users don't need to write it in YAML.\n        \"\"\"\n        config_path = self._config.source_patcher_config\n        if not config_path:\n            return\n\n        from sglang.srt.debug_utils.source_patcher import apply_patches_from_config\n\n        yaml_content: str = Path(config_path).read_text()\n        print(f\"[source_patcher] loading config from {config_path}\")\n        apply_patches_from_config(\n            
yaml_content,\n            extra_imports=[\"from sglang.srt.debug_utils.dumper import dumper\"],\n        )\n\n    def register_non_intrusive_dumper(\n        self,\n        model: \"torch.nn.Module\",\n    ) -> Optional[\"_NonIntrusiveDumper\"]:\n        self._http_manager  # noqa: B018\n        mode = self._config.non_intrusive_mode\n        if mode == \"off\":\n            return None\n        non_intrusive = _NonIntrusiveDumper(dumper=self, model=model, mode=mode)\n        self._non_intrusives.append(non_intrusive)\n        return non_intrusive\n\n    # ------------------------------- public :: secondary ---------------------------------\n\n    def configure(self, **kwargs) -> None:\n        self._config = replace(self._config, **kwargs)\n\n    def configure_default(self, **kwargs) -> None:\n        self._config = self._config.with_defaults(**kwargs)\n\n    def reset(self) -> None:\n        for non_intrusive in self._non_intrusives:\n            non_intrusive.remove()\n        self._non_intrusives.clear()\n        self._state = _DumperState()\n\n    @contextmanager\n    def capture_output(self):\n        assert self._state.captured_output_data is None\n        self._state.captured_output_data = {}\n        try:\n            yield self._state.captured_output_data\n        finally:\n            self._state.captured_output_data = None\n\n    def get_state(self) -> dict:\n        return {\n            \"config\": asdict(self._config),\n            \"dump_index\": self._state.dump_index,\n            \"step\": self._state.step,\n        }\n\n    @cached_property\n    def _http_manager(self) -> Optional[\"_DumperHttpManager\"]:\n        if self._config.server_port_parsed is None:\n            return None\n        return _DumperHttpManager(self)\n\n    # ------------------------- private :: related to dump -----------------------------\n\n    def _dump_inner(\n        self,\n        *,\n        name: str,\n        value,\n        extra_kwargs: dict,\n        save: 
bool,\n        enable_value: bool,\n        enable_curr_grad: bool,\n        enable_future_grad: bool,\n        value_tag: str,\n        grad_tag: str,\n        value_meta_only_fields: Optional[dict] = None,\n        grad_meta_only_fields: Optional[dict] = None,\n    ) -> None:\n        self._http_manager  # noqa: B018\n\n        if not self._config.enable:\n            return\n\n        recompute_status = _detect_recompute_status()\n        tags = dict(\n            name=name,\n            recompute_status=recompute_status.value,\n            **extra_kwargs,\n            **self._state.global_ctx,\n        )\n\n        if (f := self._config.filter) is not None and not _evaluate_filter(f, tags):\n            return\n\n        if not (enable_value or enable_curr_grad or enable_future_grad):\n            return\n\n        recompute_meta = recompute_status.to_pseudo_parallel_meta()\n        value = _materialize_value(value)\n\n        if enable_value:\n            self._dump_single(\n                tag=value_tag,\n                tags=tags,\n                value=value,\n                save=save,\n                meta_only_fields={**(value_meta_only_fields or {}), **recompute_meta},\n            )\n\n        if (\n            enable_curr_grad\n            and isinstance(value, torch.Tensor)\n            and (g := value.grad) is not None\n        ):\n            self._dump_single(\n                tag=grad_tag,\n                tags={**tags, \"name\": f\"grad__{name}\"},\n                value=g,\n                save=save,\n                meta_only_fields={**(grad_meta_only_fields or {}), **recompute_meta},\n            )\n\n        if enable_future_grad:\n            self._register_dump_grad_hook(\n                name=name,\n                tensor=value,\n                extra_kwargs=extra_kwargs,\n                save=save,\n                meta_only_fields=grad_meta_only_fields or {},\n            )\n\n    def _register_dump_grad_hook(\n        self,\n        
*,\n        name: str,\n        tensor,\n        extra_kwargs: dict,\n        save: bool,\n        meta_only_fields: Optional[dict] = None,\n    ) -> None:\n        if not isinstance(tensor, torch.Tensor):\n            return\n        if not tensor.requires_grad:\n            return\n\n        captured_step = self._state.step\n        captured_tags = dict(\n            name=f\"grad__{name}\",\n            **deepcopy(extra_kwargs),\n        )\n        captured_meta_only = meta_only_fields or {}\n\n        def grad_hook(grad: torch.Tensor) -> None:\n            self._dump_single(\n                tag=\"Dumper.Grad\",\n                tags=captured_tags,\n                value=grad,\n                save=save,\n                step=captured_step,\n                meta_only_fields=captured_meta_only,\n            )\n\n        tensor.register_hook(grad_hook)\n\n    def _dump_single(\n        self,\n        *,\n        tag: str,\n        tags: dict,\n        value,\n        save: bool,\n        step: Optional[int] = None,\n        meta_only_fields: Optional[dict] = None,\n    ) -> None:\n        self._ensure_exp_name()\n        self._state.dump_index += 1\n\n        rank = _get_rank()\n        full_kwargs = dict(\n            step=(step if step is not None else self._state.step),\n            rank=rank,\n            dump_index=self._state.dump_index,\n            **tags,\n        )\n        full_filename = _format_tags(full_kwargs) + \".pt\"\n        path = Path(self._config.dir) / self._config.exp_name / full_filename\n\n        if self._config.enable_output_console:\n            print(\n                f\"[{tag}] [{rank}, {time.time()}] {path} \"\n                f\"type={type(value)} \"\n                f\"shape={value.shape if isinstance(value, torch.Tensor) else None} \"\n                f\"dtype={value.dtype if isinstance(value, torch.Tensor) else None} \"\n                f\"device={value.device if isinstance(value, torch.Tensor) else None} \"\n                
f\"id={id(value)} \"\n                f\"sample_value={get_truncated_value(value)}\"\n            )\n\n        capturing = self._state.captured_output_data is not None\n        if save and (self._config.enable_output_file or capturing):\n            output_data = {\n                \"value\": value,\n                \"meta\": dict(\n                    **full_kwargs,\n                    **self._static_meta,\n                    **(meta_only_fields or {}),\n                ),\n            }\n\n            if capturing:\n                output_data[\"value\"] = _deepcopy_or_clone(output_data[\"value\"])\n                self._state.captured_output_data[tags[\"name\"]] = output_data\n            else:\n                if (\n                    not self._state.cleanup_previous_handled\n                    and self._config.cleanup_previous\n                ):\n                    self._state.cleanup_previous_handled = True\n                    _cleanup_old_dumps(\n                        Path(self._config.dir), exp_name=self._config.exp_name\n                    )\n\n                path.parent.mkdir(parents=True, exist_ok=True)\n                _torch_save(output_data, str(path))\n\n    # ------------------------------- private :: misc ---------------------------------\n\n    @cached_property\n    def _static_meta(self) -> dict:\n        return _compute_static_meta()\n\n    def _ensure_exp_name(self):\n        if self._config.exp_name is None:\n            name = _get_default_exp_name(\n                timeout_seconds=self._config.collective_timeout\n            )\n            self.configure(exp_name=name)\n            print(f\"[Dumper] Choose exp_name={name}\")\n\n\n# -------------------------------------- hook dumper ------------------------------------------\n\n\nclass _NonIntrusiveDumper:\n    _NAME_PREFIX = \"non_intrusive__\"\n    _LAYER_NAME_RE = re.compile(r\"(?:.+\\.)?layers\\.(\\d+)$\")\n\n    def __init__(\n        self,\n        dumper: _Dumper,\n        
model: \"torch.nn.Module\",\n        mode: str,\n    ):\n        self._dumper = dumper\n        self._mode = mode\n        self._handles: list = []\n        self._core_fields: frozenset[str] = frozenset().union(\n            *(p.core_fields() for p in _plugins)\n        )\n\n        for module_name, module in model.named_modules():\n            if ctx := self._detect_module_ctx(module_name, module):\n                self._register_ctx_hooks(module, ctx=ctx)\n\n            is_root = module_name == \"\"\n            pre_hook = self._make_forward_pre_hook(\n                module_name=module_name, is_root=is_root\n            )\n            hook = self._make_forward_hook(module_name=module_name, is_root=is_root)\n            self._handles += _register_forward_hook_or_replace_fn(\n                module,\n                pre_hook=pre_hook,\n                hook=hook,\n                mode=\"replace_fn\" if is_root else \"hook\",\n            )\n\n    def remove(self) -> None:\n        for handle in self._handles:\n            handle.remove()\n        self._handles.clear()\n\n    @classmethod\n    def _detect_module_ctx(\n        cls, module_name: str, module: \"torch.nn.Module\"\n    ) -> Optional[dict]:\n        match = cls._LAYER_NAME_RE.fullmatch(module_name)\n        if match:\n            for plugin in _plugins:\n                layer_id = plugin.detect_layer_id(module)\n                if layer_id is not None:\n                    return {\"layer_id\": layer_id}\n            return {\"layer_id\": int(match.group(1))}\n        return None\n\n    def _register_ctx_hooks(self, module: \"torch.nn.Module\", *, ctx: dict) -> None:\n        clear_ctx = {k: None for k in ctx}\n        self._handles.append(\n            module.register_forward_pre_hook(\n                lambda _mod, _input, _ctx=ctx: self._dumper.set_ctx(**_ctx)\n            )\n        )\n        self._handles.append(\n            module.register_forward_hook(\n                lambda _mod, _input, 
_output, _clear=clear_ctx: self._dumper.set_ctx(\n                    **_clear\n                )\n            )\n        )\n\n    def _make_forward_pre_hook(self, *, module_name: str, is_root: bool):\n        def _hook(_module, args, kwargs):\n            for i, item in enumerate(args):\n                self._dump_value(\n                    module_name, item, sub_name=f\"inputs.{i}\", is_root=is_root\n                )\n            for name, value in kwargs.items():\n                self._dump_value(\n                    module_name,\n                    value,\n                    sub_name=f\"inputs.{name}\",\n                    is_root=is_root,\n                )\n\n        return _hook\n\n    def _make_forward_hook(self, *, module_name: str, is_root: bool):\n        def _hook(_module, input, output):\n            if output is not None:\n                self._dump_value(module_name, output, sub_name=\"output\", is_root=False)\n\n        return _hook\n\n    def _dump_value(\n        self, module_name: str, value: Any, sub_name: str, *, is_root: bool\n    ) -> None:\n        for key, item in self._convert_value(\n            value, skip_forward_batch=(not is_root)\n        ).items():\n            effective_key = key or sub_name.rsplit(\".\", 1)[-1]\n            if effective_key in self._core_fields:\n                self._dumper.dump(effective_key, item)\n            elif self._mode == \"all\":\n                parts = [p for p in (module_name, sub_name, key) if p]\n                self._dumper.dump(self._NAME_PREFIX + \".\".join(parts), item)\n\n    @staticmethod\n    def _convert_value(value, *, skip_forward_batch: bool = False) -> dict[str, Any]:\n        if isinstance(value, torch.Tensor):\n            return {\"\": value}\n\n        if isinstance(value, (tuple, list)):\n            tensors = [t for t in value if isinstance(t, torch.Tensor)]\n            if len(tensors) == 1:\n                return {\"\": tensors[0]}\n            return {str(i): t for i, t 
in enumerate(tensors)}\n\n        for plugin in _plugins:\n            result = plugin.convert_value(value, skip_forward_batch=skip_forward_batch)\n            if result is not None:\n                return result\n\n        return {}\n\n\ndef _register_forward_hook_or_replace_fn(\n    module: \"torch.nn.Module\",\n    *,\n    pre_hook,\n    hook,\n    mode: str,\n) -> list:\n    \"\"\"Attach pre/post forward hooks to *module*.\n\n    mode=\"hook\"       — standard ``register_forward_pre_hook`` / ``register_forward_hook``\n                        (fires only via ``__call__``).\n    mode=\"replace_fn\" — monkey-patch ``module.forward`` so hooks fire even when\n                        callers invoke ``.forward()`` directly (as sglang does for the\n                        root model).\n\n    Returns a list of handle objects with a ``.remove()`` method that undoes\n    the registration.\n    \"\"\"\n    if mode == \"hook\":\n        return [\n            module.register_forward_pre_hook(pre_hook, with_kwargs=True),\n            module.register_forward_hook(hook),\n        ]\n    elif mode == \"replace_fn\":\n        original_forward = module.forward\n\n        @functools.wraps(original_forward)\n        def _wrapped(*args, **kwargs):\n            pre_hook(module, args, kwargs)\n            output = original_forward(*args, **kwargs)\n            hook(module, args, output)\n            return output\n\n        module.forward = _wrapped\n\n        class _Handle:\n            def remove(self) -> None:\n                assert module.forward is _wrapped\n                module.forward = original_forward\n\n        return [_Handle()]\n    else:\n        raise ValueError(f\"Unknown mode {mode!r}\")\n\n\n# -------------------------------------- util fn ------------------------------------------\n\n\ndef _torch_save(value, path: str):\n    value = _clone_if_view(value)\n    try:\n        try:\n            return torch.save(value, path)\n        except RuntimeError as e:\n        
    if \"not pickleable\" in str(e):\n                stripped = _strip_parameter(value)\n                if stripped is not value:\n                    print(f\"[Dumper] Observe error={e} and try pickling .data\")\n                    return _torch_save(stripped, path)\n            raise\n    except Exception as e:\n        print(f\"[Dumper] Observe error={e} when saving data, skip the tensor\")\n\n\ndef _map_tensor(value, fn: Callable[[torch.Tensor], torch.Tensor]):\n    if isinstance(value, dict):\n        return {k: _map_tensor(v, fn) for k, v in value.items()}\n    if isinstance(value, torch.Tensor):\n        return fn(value)\n    return value\n\n\ndef _clone_if_view(value):\n    def _fn(t: torch.Tensor) -> torch.Tensor:\n        if t.untyped_storage().nbytes() > t.nelement() * t.element_size():\n            return t.clone()\n        return t\n\n    return _map_tensor(value, _fn)\n\n\ndef _strip_parameter(value):\n    def _fn(t: torch.Tensor) -> torch.Tensor:\n        if isinstance(t, torch.nn.Parameter):\n            return t.data\n        return t\n\n    return _map_tensor(value, _fn)\n\n\ndef _collective_with_timeout(fn, operation_name: str, timeout_seconds: int = 60):\n    completed = threading.Event()\n\n    def watchdog():\n        if not completed.wait(timeout=timeout_seconds):\n            print(\n                f\"\\n[Dumper] WARNING: '{operation_name}' has not completed after \"\n                f\"{timeout_seconds}s. 
This usually means not all ranks are \"\n                f\"participating in this collective operation.\\n\",\n                flush=True,\n            )\n\n    thread = threading.Thread(target=watchdog, daemon=True)\n    thread.start()\n    try:\n        return fn()\n    finally:\n        completed.set()\n\n\ndef _get_default_exp_name(timeout_seconds: int = 60):\n    rank = _get_rank()\n    now = time.time()\n    ms = int((now % 1) * 1000)\n    rand_suffix = random.randint(0, 999)\n    object_list = [\n        (\n            (\n                f\"{_DEFAULT_EXP_NAME_PREFIX}\"\n                f\"{time.strftime('%Y%m%d_%H%M%S', time.gmtime(now))}\"\n                f\"_{ms:03d}{rand_suffix:03d}\"\n            )\n            if rank == 0\n            else None\n        )\n    ]\n\n    if dist.is_initialized():\n        _collective_with_timeout(\n            lambda: dist.broadcast_object_list(object_list, device=\"cuda\"),\n            operation_name=\"broadcast_object_list in _get_default_exp_name\",\n            timeout_seconds=timeout_seconds,\n        )\n\n    return object_list[0]\n\n\ndef _cleanup_old_dumps(base_dir: Path, exp_name: Optional[str] = None) -> None:\n    import shutil\n\n    if _get_rank() == 0:\n        targets = {entry for entry in base_dir.glob(f\"{_DEFAULT_EXP_NAME_PREFIX}*\")}\n        if exp_name:\n            targets.add(base_dir / exp_name)\n        targets = {d for d in targets if d.is_dir()}\n\n        for entry in targets:\n            shutil.rmtree(entry)\n            print(f\"[Dumper] Cleaned up {entry}\")\n\n    if dist.is_initialized():\n        _collective_with_timeout(\n            dist.barrier,\n            operation_name=\"barrier in _cleanup_old_dumps\",\n        )\n\n\ndef _get_rank():\n    if dist.is_initialized():\n        return dist.get_rank()\n    else:\n        return 0\n\n\ndef _get_world_size():\n    if dist.is_initialized():\n        return dist.get_world_size()\n    else:\n        return 1\n\n\ndef 
_obj_to_dict(obj):\n    if isinstance(obj, dict):\n        return obj\n    ret = {}\n    for k in dir(obj):\n        if k.startswith(\"__\") and k.endswith(\"__\"):\n            continue\n        try:\n            v = getattr(obj, k)\n            if not callable(v):\n                ret[k] = v\n        except Exception:\n            # Skip attributes that raise an exception on access\n            continue\n    return ret\n\n\ndef _materialize_value(value):\n    if callable(value):\n        value = value()\n    return value\n\n\ndef _format_tags(kwargs: dict) -> str:\n    return \"___\".join(f\"{k}={v}\" for k, v in kwargs.items())\n\n\nclass _DefaultNoneDict(dict):\n    \"\"\"dict subclass that returns None for missing keys, for filter expression eval.\"\"\"\n\n    def __missing__(self, key: str):\n        return None\n\n\n_FILTER_BUILTINS: dict[str, Any] = {\"search\": re.search, \"match\": re.match}\n\n\ndef _evaluate_filter(filter_expr: str, tags: dict[str, Any]) -> bool:\n    \"\"\"Evaluate a Python filter expression against the tags dict.\n\n    Unknown tag keys resolve to None, so `layer_id is None` works when layer_id is absent.\n    `re.search` and `re.match` are available as `search()` and `match()`.\n    \"\"\"\n    namespace = _DefaultNoneDict(tags)\n    namespace.update(_FILTER_BUILTINS)\n    return bool(eval(filter_expr, {\"__builtins__\": {}}, namespace))\n\n\ndef _deepcopy_or_clone(x):\n    if isinstance(x, torch.Tensor):\n        return x.clone()\n    return deepcopy(x)\n\n\n# -------------------------------------- static meta ------------------------------------------\n\n\ndef _compute_static_meta():\n    result = {\n        \"world_rank\": _get_rank(),\n        \"world_size\": _get_world_size(),\n    }\n\n    for plugin in _plugins:\n        if info := plugin.collect_parallel_info():\n            result[f\"{plugin.name}_parallel_info\"] = info\n\n    for plugin in _plugins:\n        tokenizer_path: Optional[str] = plugin.get_tokenizer_path()\n     
   if tokenizer_path is not None:\n            result[\"tokenizer_path\"] = tokenizer_path\n            break\n\n    return result\n\n\n# -------------------------------------- http manager ------------------------------------------\n\n\nclass _DumperHttpManager:\n    def __init__(self, dumper: \"_Dumper\"):\n        self._dumper = dumper\n        http_port = self._dumper._config.server_port_parsed\n\n        rpc_broadcast = _create_zmq_rpc_broadcast(\n            self,\n            timeout_seconds=self._dumper._config.collective_timeout,\n        )\n\n        if _get_rank() == 0:\n            assert rpc_broadcast is not None\n            self._rpc_broadcast = rpc_broadcast\n\n            if http_port == \"reuse\":\n                print(\n                    \"[Dumper] Standalone HTTP server disabled, reusing existing ports\"\n                )\n            else:\n                _start_http_server(prefix=\"/dumper/\", target=self, http_port=http_port)\n                print(f\"[Dumper] HTTP server started on port {http_port}\")\n\n    # ------------------------------- public ---------------------------------\n\n    def handle_request(self, *, method: str, body: dict[str, Any]) -> list[dict]:\n        return self._rpc_broadcast._handle_request_inner(method=method, body=body)\n\n    # ------------------------------- private ---------------------------------\n\n    def _handle_request_inner(self, *, method: str, body: dict[str, Any]) -> dict:\n        if method == \"get_state\":\n            return self._dumper.get_state()\n        elif method == \"configure\":\n            self._dumper.configure(**body)\n            return {}\n        elif method == \"reset\":\n            self._dumper.reset()\n            return {}\n        else:\n            raise ValueError(f\"Unknown dumper control method: {method!r}\")\n\n\n# -------------------------------------- http control server ------------------------------------------\n\n\ndef _start_http_server(*, prefix: str, target: 
object, http_port: int):\n    handler_class = _make_http_handler(prefix=prefix, target=target)\n    server = HTTPServer((\"0.0.0.0\", http_port), handler_class)\n    thread = threading.Thread(target=server.serve_forever, daemon=True)\n    thread.start()\n\n\ndef _make_http_handler(*, prefix: str, target):\n    class _HTTPHandler(BaseHTTPRequestHandler):\n        def do_POST(self):\n            if not self.path.startswith(prefix):\n                self.send_error(404)\n                return\n            method = self.path[len(prefix) :]\n            try:\n                req_body = self._get_request_body()\n                print(f\"[Dumper#{_get_rank()}] HTTP {self.path} {req_body=}\")\n                result = target.handle_request(method=method, body=req_body)\n                resp_body = json.dumps(result).encode()\n                self.send_response(200)\n                self.send_header(\"Content-Type\", \"application/json\")\n                self.send_header(\"Content-Length\", str(len(resp_body)))\n                self.end_headers()\n                self.wfile.write(resp_body)\n            except Exception as e:\n                self.send_error(400, str(e))\n\n        def _get_request_body(self) -> dict:\n            content_length = int(self.headers.get(\"Content-Length\", 0))\n            if content_length == 0:\n                return {}\n            return json.loads(self.rfile.read(content_length))\n\n    return _HTTPHandler\n\n\n# -------------------------------------- zmq rpc ------------------------------------------\n\n\ndef _create_zmq_rpc_broadcast(\n    handler, timeout_seconds: int = 60\n) -> Optional[\"_ZmqRpcBroadcast\"]:\n    \"\"\"A general-purpose minimal RPC to support broadcasting executions to multi processes\"\"\"\n    import zmq\n\n    rank = _get_rank()\n    world_size = dist.get_world_size() if dist.is_initialized() else 1\n\n    ctx = zmq.Context()\n    sock = ctx.socket(zmq.REP)\n    sock.bind(\"tcp://*:0\")\n    bound_port = 
int(sock.getsockopt_string(zmq.LAST_ENDPOINT).rsplit(\":\", 1)[1])\n    local_addr = f\"tcp://{_get_local_ip_by_remote()}:{bound_port}\"\n\n    def serve_loop():\n        while True:\n            try:\n                req = sock.recv_pyobj()\n                result = getattr(handler, req[\"method\"])(*req[\"args\"], **req[\"kwargs\"])\n                resp = {\"result\": result, \"error\": None}\n            except Exception as e:\n                print(f\"[Dumper.ZmqRpc] error inside handler: {e}\")\n                resp = {\"result\": None, \"error\": str(e)}\n            sock.send_pyobj(resp)\n\n    thread = threading.Thread(target=serve_loop, daemon=True)\n    thread.start()\n    print(f\"[Dumper.ZmqRpc] rank={rank} server started at {local_addr}\")\n\n    if dist.is_initialized():\n        all_addresses = [None] * world_size\n        _collective_with_timeout(\n            lambda: dist.all_gather_object(all_addresses, local_addr),\n            operation_name=\"all_gather_object in _create_zmq_rpc_broadcast\",\n            timeout_seconds=timeout_seconds,\n        )\n    else:\n        all_addresses = [local_addr]\n    print(f\"[Dumper.ZmqRpc] rank={rank} all_addresses={all_addresses}\")\n\n    if rank == 0:\n        handles = []\n        for i, addr in enumerate(all_addresses):\n            req_socket = ctx.socket(zmq.REQ)\n            req_socket.connect(addr)\n            handles.append(_ZmqRpcHandle(req_socket, debug_name=f\"rank-{i}\"))\n        return _ZmqRpcBroadcast(handles)\n    else:\n        return None\n\n\nclass _ZmqRpcHandle:\n    \"\"\"Proxy object to call remote handler methods via ZMQ.\"\"\"\n\n    def __init__(self, socket, debug_name: str):\n        self._socket = socket\n        self._debug_name = debug_name\n\n    def __getattr__(self, method_name: str):\n        def call(*args, **kwargs):\n            self._socket.send_pyobj(\n                {\n                    \"method\": method_name,\n                    \"args\": args,\n               
     \"kwargs\": kwargs,\n                }\n            )\n            response = self._socket.recv_pyobj()\n            if response[\"error\"]:\n                raise RuntimeError(\n                    f\"RPC error on {self._debug_name}: {response['error']}\"\n                )\n            return response[\"result\"]\n\n        return call\n\n\nclass _RpcBroadcastBase:\n    \"\"\"Base for broadcasting method calls to dumper instance(s).\"\"\"\n\n    def __getattr__(self, method_name: str):\n        raise NotImplementedError\n\n    def __init__(self, handles: List[_ZmqRpcHandle]):\n        self._handles = handles\n\n\nclass _ZmqRpcBroadcast(_RpcBroadcastBase):\n    \"\"\"Broadcasts method calls to all ZMQ RPC handles.\n\n    Returns a list of results, one per rank (ordered by rank).\n    \"\"\"\n\n    def __init__(self, handles: List[_ZmqRpcHandle]):\n        self._handles = handles\n\n    def __getattr__(self, method_name: str):\n        def call(*args, **kwargs):\n            return [\n                getattr(handle, method_name)(*args, **kwargs)\n                for handle in self._handles\n            ]\n\n        return call\n\n\n# --------------------------------- copied code (avoid dependency) --------------------------------------\n\n\ndef _get_local_ip_by_remote() -> Optional[str]:\n    # try ipv4\n    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)\n    try:\n        s.connect((\"8.8.8.8\", 80))  # Doesn't need to be reachable\n        return s.getsockname()[0]\n    except Exception:\n        pass\n\n    try:\n        hostname = socket.gethostname()\n        ip = socket.gethostbyname(hostname)\n        if ip and ip != \"127.0.0.1\" and ip != \"0.0.0.0\":\n            return ip\n    except Exception:\n        pass\n\n    # try ipv6\n    try:\n        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)\n        # Google's public DNS server, see\n        # https://developers.google.com/speed/public-dns/docs/using#addresses\n        
s.connect((\"2001:4860:4860::8888\", 80))  # Doesn't need to be reachable\n        return s.getsockname()[0]\n    except Exception:\n        print(\"Can not get local ip by remote\")\n    return None\n\n\n# -------------------------------------- framework plugins ------------------------------------------\n\n\nclass _RecomputeStatus(enum.Enum):\n    DISABLED = \"disabled\"\n    ORIGINAL = \"original\"  # inside checkpoint, original forward\n    RECOMPUTE = \"recompute\"  # inside checkpoint, recompute forward\n\n    def to_pseudo_parallel_meta(self) -> dict[str, Any]:\n        if self == _RecomputeStatus.DISABLED:\n            return {}\n        return {\n            \"recompute_pseudo_rank\": 1 if self == _RecomputeStatus.RECOMPUTE else 0,\n            \"recompute_pseudo_size\": 2,\n        }\n\n\nclass _FrameworkPlugin(ABC):\n    @property\n    @abstractmethod\n    def name(self) -> str: ...\n\n    @abstractmethod\n    def collect_parallel_info(self) -> dict: ...\n\n    @abstractmethod\n    def convert_value(\n        self, value: Any, *, skip_forward_batch: bool\n    ) -> Optional[dict[str, Any]]:\n        \"\"\"Return converted dict, or None if this plugin doesn't handle the value.\"\"\"\n        ...\n\n    @abstractmethod\n    def detect_layer_id(self, module: \"torch.nn.Module\") -> Optional[int]:\n        \"\"\"Return 0-indexed layer_id, or None if not detectable.\"\"\"\n        ...\n\n    def core_fields(self) -> frozenset[str]:\n        return frozenset()\n\n    def get_tokenizer_path(self) -> Optional[str]:\n        return None\n\n    def detect_recompute_status(self) -> _RecomputeStatus:\n        return _RecomputeStatus.DISABLED\n\n\nclass _SGLangPlugin(_FrameworkPlugin):\n    _available = True\n    try:\n        from sglang.srt import distributed as _dist\n        from sglang.srt.layers import dp_attention as _dp_attn\n        from sglang.srt.layers.logits_processor import LogitsProcessorOutput\n        from sglang.srt.model_executor.forward_batch_info 
import (\n            ForwardBatch,\n            PPProxyTensors,\n        )\n    except ImportError:\n        _available = False\n\n    @property\n    def name(self) -> str:\n        return \"sglang\"\n\n    def collect_parallel_info(self) -> dict:\n        if not self._available:\n            return {}\n\n        info = {}\n\n        try:\n            info[\"tp_rank\"] = self._dist.get_tensor_model_parallel_rank()\n            info[\"tp_size\"] = self._dist.get_tensor_model_parallel_world_size()\n            info[\"pp_rank\"] = self._dist.get_pipeline_model_parallel_rank()\n            info[\"pp_size\"] = self._dist.get_pipeline_model_parallel_world_size()\n            info[\"moe_ep_rank\"] = self._dist.get_moe_expert_parallel_rank()\n            info[\"moe_ep_size\"] = self._dist.get_moe_expert_parallel_world_size()\n            info[\"moe_tp_rank\"] = self._dist.get_moe_tensor_parallel_rank()\n            info[\"moe_tp_size\"] = self._dist.get_moe_tensor_parallel_world_size()\n            info[\"moe_dp_rank\"] = self._dist.get_moe_data_parallel_rank()\n            info[\"moe_dp_size\"] = self._dist.get_moe_data_parallel_world_size()\n        except (AttributeError, AssertionError):\n            info[\"distributed_error\"] = True\n\n        try:\n            info[\"enable_dp_attention\"] = self._dp_attn.is_dp_attention_enabled()\n            info[\"attn_tp_rank\"] = self._dp_attn.get_attention_tp_rank()\n            info[\"attn_tp_size\"] = self._dp_attn.get_attention_tp_size()\n            info[\"attn_dp_rank\"] = self._dp_attn.get_attention_dp_rank()\n            info[\"attn_dp_size\"] = self._dp_attn.get_attention_dp_size()\n            info[\"local_attn_dp_rank\"] = self._dp_attn.get_local_attention_dp_rank()\n            info[\"local_attn_dp_size\"] = self._dp_attn.get_local_attention_dp_size()\n            info[\"attn_cp_rank\"] = self._dp_attn.get_attention_cp_rank()\n            info[\"attn_cp_size\"] = self._dp_attn.get_attention_cp_size()\n        
except (AttributeError, AssertionError):\n            info[\"dp_attention_error\"] = True\n\n        return info\n\n    def convert_value(\n        self, value: Any, *, skip_forward_batch: bool\n    ) -> Optional[dict[str, Any]]:\n        if not self._available:\n            return None\n\n        if isinstance(value, self.LogitsProcessorOutput):\n            return {\"next_token_logits\": value.next_token_logits}\n        if isinstance(value, self.ForwardBatch):\n            if skip_forward_batch:\n                return {}\n            result = {\n                \"input_ids\": value.input_ids,\n                \"seq_lens\": value.seq_lens,\n                \"positions\": value.positions,\n                \"req_pool_indices\": value.req_pool_indices,\n            }\n            if value.rids is not None:\n                result[\"rids\"] = value.rids\n            return result\n        if isinstance(value, self.PPProxyTensors):\n            return {k: v for k, v in value.tensors.items()}\n\n        return None\n\n    def detect_layer_id(self, module: \"torch.nn.Module\") -> Optional[int]:\n        if hasattr(module, \"layer_id\"):\n            return module.layer_id\n        return None\n\n    def core_fields(self) -> frozenset[str]:\n        return frozenset(\n            {\"input_ids\", \"positions\", \"seq_lens\", \"req_pool_indices\", \"rids\"}\n        )\n\n    def get_tokenizer_path(self) -> Optional[str]:\n        if not self._available:\n            return None\n\n        try:\n            from sglang.srt.server_args import get_global_server_args\n\n            args = get_global_server_args()\n            if args is None:\n                return None\n\n            return args.tokenizer_path\n        except Exception:\n            return None\n\n\nclass _MegatronPlugin(_FrameworkPlugin):\n    _available = True\n    try:\n        from megatron.core import parallel_state as _mpu\n        from megatron.core.packed_seq_params import PackedSeqParams\n    
except ImportError:\n        _available = False\n\n    @property\n    def name(self) -> str:\n        return \"megatron\"\n\n    def collect_parallel_info(self) -> dict:\n        if not self._available:\n            return {}\n\n        info = {}\n        try:\n            info[\"tp_rank\"] = self._mpu.get_tensor_model_parallel_rank()\n            info[\"tp_size\"] = self._mpu.get_tensor_model_parallel_world_size()\n            info[\"pp_rank\"] = self._mpu.get_pipeline_model_parallel_rank()\n            info[\"pp_size\"] = self._mpu.get_pipeline_model_parallel_world_size()\n            info[\"dp_rank\"] = self._mpu.get_data_parallel_rank()\n            info[\"dp_size\"] = self._mpu.get_data_parallel_world_size()\n            info[\"cp_rank\"] = self._mpu.get_context_parallel_rank()\n            info[\"cp_size\"] = self._mpu.get_context_parallel_world_size()\n            info[\"vpp_rank\"] = self._mpu.get_virtual_pipeline_model_parallel_rank()\n            info[\"vpp_size\"] = (\n                self._mpu.get_virtual_pipeline_model_parallel_world_size()\n            )\n            info[\"ep_rank\"] = self._mpu.get_expert_model_parallel_rank()\n            info[\"ep_size\"] = self._mpu.get_expert_model_parallel_world_size()\n            info[\"etp_rank\"] = self._mpu.get_expert_tensor_parallel_rank()\n            info[\"etp_size\"] = self._mpu.get_expert_tensor_parallel_world_size()\n            info[\"edp_rank\"] = self._mpu.get_expert_data_parallel_rank()\n            info[\"edp_size\"] = self._mpu.get_expert_data_parallel_world_size()\n            info[\"tcp_rank\"] = self._mpu.get_tensor_and_context_parallel_rank()\n            info[\"tcp_size\"] = self._mpu.get_tensor_and_context_parallel_world_size()\n            info[\"etmp_rank\"] = self._mpu.get_expert_tensor_and_model_parallel_rank()\n            info[\"etmp_size\"] = (\n                self._mpu.get_expert_tensor_and_model_parallel_world_size()\n            )\n            info[\"tp_src_rank\"] = 
self._mpu.get_tensor_model_parallel_src_rank()\n            info[\"mp_src_rank\"] = self._mpu.get_model_parallel_src_rank()\n            info[\"dp_src_rank\"] = self._mpu.get_data_parallel_src_rank()\n        except (AttributeError, AssertionError):\n            info[\"megatron_error\"] = True\n\n        # Megatron sequence parallel reuses the TP group (no dedicated parallel state API).\n        # When sequence_parallel=True, inject sp_rank/sp_size for the comparator unsharder.\n        try:\n            from megatron.training.global_vars import get_args\n\n            args = get_args()\n            if getattr(args, \"sequence_parallel\", False) and \"tp_rank\" in info:\n                info[\"sp_rank\"] = info[\"tp_rank\"]\n                info[\"sp_size\"] = info[\"tp_size\"]\n        except (ImportError, AssertionError, AttributeError):\n            pass\n\n        return info\n\n    def convert_value(\n        self, value: Any, *, skip_forward_batch: bool\n    ) -> Optional[dict[str, Any]]:\n        if not self._available:\n            return None\n        if isinstance(value, self.PackedSeqParams):\n            return {\n                \"cu_seqlens_q\": value.cu_seqlens_q,\n                \"cu_seqlens_kv\": value.cu_seqlens_kv,\n                \"qkv_format\": value.qkv_format,\n            }\n        return None\n\n    def detect_layer_id(self, module: \"torch.nn.Module\") -> Optional[int]:\n        if hasattr(module, \"layer_number\"):\n            return module.layer_number - 1\n        return None\n\n    def core_fields(self) -> frozenset[str]:\n        return frozenset(\n            {\"input_ids\", \"position_ids\", \"cu_seqlens_q\", \"cu_seqlens_kv\", \"qkv_format\"}\n        )\n\n    def detect_recompute_status(self) -> _RecomputeStatus:\n        if not self._available:\n            return _RecomputeStatus.DISABLED\n        try:\n            from megatron.core.tensor_parallel.random import is_checkpointing\n\n            if not is_checkpointing():\n   
             return _RecomputeStatus.DISABLED\n            if torch.is_grad_enabled():\n                return _RecomputeStatus.RECOMPUTE\n            return _RecomputeStatus.ORIGINAL\n        except (ImportError, AttributeError):\n            return _RecomputeStatus.DISABLED\n\n\n_plugins: list[_FrameworkPlugin] = [_SGLangPlugin(), _MegatronPlugin()]\n\n\ndef _detect_recompute_status() -> _RecomputeStatus:\n    for plugin in _plugins:\n        info = plugin.detect_recompute_status()\n        if info != _RecomputeStatus.DISABLED:\n            return info\n    return _RecomputeStatus.DISABLED\n\n\n# -------------------------------------- singleton ------------------------------------------\n\n\ndumper = _Dumper(config=DumperConfig.from_env())\n\n\n# -------------------------------------- other utility functions ------------------------------------------\n\n\ndef get_truncated_value(value):\n    if value is None:\n        return None\n\n    if isinstance(value, tuple):\n        return [get_truncated_value(x) for x in value]\n\n    if not isinstance(value, torch.Tensor):\n        return value\n\n    if value.numel() < 200:\n        return value\n\n    slices = [slice(0, 5) if dim_size > 50 else slice(None) for dim_size in value.shape]\n    return value[tuple(slices)]\n\n\ndef get_tensor_info(x):\n    \"\"\"\n    from sglang.srt.debug_utils.dumper import get_tensor_info\n    \"\"\"\n    if not isinstance(x, torch.Tensor):\n        return f\"type={type(x)} value={x}\"\n    min = x.float().min() if x.numel() > 0 else None\n    max = x.float().max() if x.numel() > 0 else None\n    mean = x.float().mean() if x.numel() > 0 else None\n    torch.set_printoptions(precision=10)\n    x_sample_head = str(x.flatten()[:5])\n    x_sample_tail = str(x.flatten()[-5:])\n    torch.set_printoptions(precision=4)\n    return (\n        f\"type={type(x)} \"\n        f\"shape={x.shape} \"\n        f\"dtype={x.dtype} \"\n        f\"device={x.device} \"\n        f\"stride={x.stride()} \"\n     
   f\"req_grad={x.requires_grad} \"\n        f\"min={min} \"\n        f\"max={max} \"\n        f\"mean={mean} \"\n        f\"x_sample_head={x_sample_head} \"\n        f\"x_sample_tail={x_sample_tail}\"\n    )\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/log_parser.py",
    "content": "_PATTERN_DECODE = (\n    r\"(\\(\\w+ pid=(?P<pid>\\d+)(?:,\\s*ip=(?P<ip>[\\d\\.]+))?\\))?\\s*\"\n    r\"\\[(?P<time>\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2})\"\n    r\"(?:\\s+DP(?P<dp_rank>\\d+))?\"\n    r\"(?:\\s+TP(?P<tp_rank>\\d+))?\"\n    r\"(?:\\s+EP(?P<ep_rank>\\d+))?\"\n    r\"(?:\\s+PP(?P<pp_rank>\\d+))?\"\n    r\"\\]\\s+\"\n    r\"Decode batch( \\[\\d+\\])?,\\s+\"\n    r\"#running-req:\\s*(?P<num_running_req>\\d+),\\s+\"\n    r\"#token:\\s*(?P<num_token>\\d+),\\s+\"\n    r\"token usage:\\s*(?P<token_usage>[0-9.]+),\\s+\"\n    r\".*?\"\n    r\"gen throughput \\(token/s\\):\\s*(?P<gen_throughput>[0-9.]+),\\s+\"\n    r\"#queue-req:\\s*(?P<queue_req>\\d+),\"\n)\n\n\ndef parse(lines):\n    import polars as pl\n\n    df = pl.DataFrame(dict(line=lines.splitlines()))\n    df = df.with_columns(info=pl.col(\"line\").str.extract_groups(_PATTERN_DECODE))\n    df = df.unnest(\"info\")\n    df = df.filter(pl.col(\"gen_throughput\").is_not_null())\n\n    df = df.with_columns(\n        pl.col(\"time\").str.strptime(pl.Datetime, \"%Y-%m-%d %H:%M:%S\"),\n        *[\n            pl.col(col).cast(dtype)\n            for col, dtype in [\n                (\"pid\", pl.Int64),\n                (\"dp_rank\", pl.Int64),\n                (\"tp_rank\", pl.Int64),\n                (\"ep_rank\", pl.Int64),\n                (\"pp_rank\", pl.Int64),\n                (\"num_running_req\", pl.Int64),\n                (\"num_token\", pl.Int64),\n                (\"token_usage\", pl.Float64),\n                (\"gen_throughput\", pl.Float64),\n                (\"queue_req\", pl.Int64),\n            ]\n            if col in df.columns\n        ],\n    )\n    return df\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/model_truncator.py",
    "content": "# This file also references Slime :: fp8_cast_bf16.py\nimport json\nimport os\nimport re\nfrom argparse import ArgumentParser\nfrom pathlib import Path\nfrom typing import Dict\n\nimport torch\nfrom huggingface_hub import snapshot_download\nfrom safetensors.torch import load_file, save_file\n\n\ndef main(args):\n    dir_input = Path(_maybe_snapshot_download(args.input))\n    dir_output = Path(args.output)\n    print(f\"{dir_input=} {dir_output=}\")\n\n    dir_output.mkdir(parents=True, exist_ok=True)\n\n    for pattern in [\"generation_config.json\", \"*.py\", \"tokenizer*\"]:\n        os.system(f\"cp -rf {dir_input}/{pattern} {dir_output}\")\n\n    _transform_json(\n        dir_input,\n        dir_output,\n        \"config.json\",\n        lambda data: _transform_config(args, data),\n    )\n\n    safetensors_index = _transform_json(\n        dir_input,\n        dir_output,\n        \"model.safetensors.index.json\",\n        lambda data: _transform_safetensors_index(args, data),\n    )\n\n    for path_input_safetensors in sorted(list(dir_input.glob(\"*.safetensors\"))):\n        path_output_safetensors = dir_output / path_input_safetensors.relative_to(\n            dir_input\n        )\n\n        state_dict = load_file(path_input_safetensors)\n        _transform_safetensors_file(\n            state_dict, safetensors_index, debug_name=str(path_output_safetensors)\n        )\n        if len(state_dict) > 0:\n            print(f\"Save {len(state_dict)} tensors to {path_output_safetensors}\")\n            save_file(state_dict, path_output_safetensors)\n        else:\n            print(f\"Skip saving {path_output_safetensors} since it is empty\")\n\n\ndef _maybe_snapshot_download(path):\n    if Path(path).exists():\n        return path\n    return snapshot_download(path)\n\n\ndef _transform_json(dir_input, dir_output, filename, fn):\n    data = json.loads((dir_input / filename).read_text())\n    fn(data)\n    (dir_output / 
filename).write_text(json.dumps(data, indent=4))\n    return data\n\n\ndef _transform_config(args, config_json):\n    config_json[\"num_hidden_layers\"] = args.keep_num_layers\n\n\ndef _transform_safetensors_index(args, safetensors_index):\n    weight_map = safetensors_index[\"weight_map\"]\n    weight_map = {\n        name: loc for name, loc in weight_map.items() if _filter_tensor_name(args, name)\n    }\n    safetensors_index[\"weight_map\"] = weight_map\n\n\ndef _transform_safetensors_file(\n    state_dict: Dict[str, torch.Tensor], safetensors_index, debug_name: str\n):\n    names_to_remove = set(state_dict) - set(safetensors_index[\"weight_map\"])\n    print(f\"Remove {list(names_to_remove)} in {debug_name}\")\n    for name in names_to_remove:\n        del state_dict[name]\n\n\ndef _filter_tensor_name(args, tensor_name: str):\n    # We focus on DeepSeek-like names currently, but can be easily extended to more kinds of models\n    m = re.match(r\"^model.layers.(\\d+).*\", tensor_name)\n    if m is None:\n        return True\n\n    layer_id = int(m.group(1))\n    return layer_id < args.keep_num_layers\n\n\nif __name__ == \"__main__\":\n    \"\"\"\n    Example:\n    python -m sglang.srt.debug_utils.model_truncator --input deepseek-ai/DeepSeek-V3-0324 --output /tmp/DeepSeek-V3-0324-5layer\n    hf upload my_name/DeepSeek-V3-0324-5layer /tmp/DeepSeek-V3-0324-5layer\n\n    Alternatively, the following may be used on-the-fly.\n    But this may not be useful to test RL frameworks, and sometimes it may have issues.\n        --json-model-override-args '{\"num_hidden_layers\": 5}'\n    \"\"\"\n    parser = ArgumentParser(description=\"Create truncated model for fast debugging.\")\n    parser.add_argument(\"--input\", type=str, required=True)\n    parser.add_argument(\"--output\", type=str, required=True)\n    parser.add_argument(\"--keep-num-layers\", type=int, default=5)\n    main(parser.parse_args())\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/__init__.py",
    "content": "from sglang.srt.debug_utils.schedule_simulator.data_source import (\n    generate_gsp_requests,\n    generate_random_requests,\n    load_from_request_logger,\n)\nfrom sglang.srt.debug_utils.schedule_simulator.entrypoint import create_arg_parser, main\nfrom sglang.srt.debug_utils.schedule_simulator.gpu_state import GPUState, StepRecord\nfrom sglang.srt.debug_utils.schedule_simulator.metrics import (\n    AttentionComputeBalancednessRecorder,\n    AvgBatchSizeRecorder,\n    BatchSizeBalancednessRecorder,\n    MetricRecorder,\n)\nfrom sglang.srt.debug_utils.schedule_simulator.request import SimRequest\nfrom sglang.srt.debug_utils.schedule_simulator.routers import (\n    RandomRouter,\n    RoundRobinRouter,\n    RouterPolicy,\n    StickyRouter,\n)\nfrom sglang.srt.debug_utils.schedule_simulator.schedulers import (\n    FIFOScheduler,\n    SchedulerPolicy,\n)\nfrom sglang.srt.debug_utils.schedule_simulator.simulator import (\n    SimulationResult,\n    Simulator,\n)\n\n__all__ = [\n    \"SimRequest\",\n    \"GPUState\",\n    \"Simulator\",\n    \"SimulationResult\",\n    \"StepRecord\",\n    \"RouterPolicy\",\n    \"RandomRouter\",\n    \"RoundRobinRouter\",\n    \"StickyRouter\",\n    \"SchedulerPolicy\",\n    \"FIFOScheduler\",\n    \"MetricRecorder\",\n    \"BatchSizeBalancednessRecorder\",\n    \"AttentionComputeBalancednessRecorder\",\n    \"AvgBatchSizeRecorder\",\n    \"load_from_request_logger\",\n    \"generate_random_requests\",\n    \"generate_gsp_requests\",\n    \"create_arg_parser\",\n    \"main\",\n]\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/__main__.py",
    "content": "from sglang.srt.debug_utils.schedule_simulator.entrypoint import create_arg_parser, main\n\nif __name__ == \"__main__\":\n    parser = create_arg_parser()\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/data_source/__init__.py",
    "content": "from sglang.srt.debug_utils.schedule_simulator.data_source.data_loader import (\n    load_from_request_logger,\n)\nfrom sglang.srt.debug_utils.schedule_simulator.data_source.data_synthesis import (\n    generate_gsp_requests,\n    generate_random_requests,\n)\n\n__all__ = [\n    \"load_from_request_logger\",\n    \"generate_random_requests\",\n    \"generate_gsp_requests\",\n]\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/data_source/data_loader.py",
    "content": "import json\nfrom pathlib import Path\nfrom typing import List, Union\n\nfrom sglang.srt.debug_utils.schedule_simulator.request import SimRequest\n\n\ndef load_from_request_logger(file_path: Union[str, Path]) -> List[SimRequest]:\n    requests = []\n    file_path = Path(file_path)\n\n    with file_path.open(encoding=\"utf-8\") as f:\n        for line_num, line in enumerate(f):\n            line = line.strip()\n            if not line or not line.startswith(\"{\"):\n                continue\n\n            data = json.loads(line)\n\n            if data.get(\"event\") != \"request.finished\":\n                continue\n\n            rid = data.get(\"rid\", f\"req_{line_num}\")\n            meta_info = data[\"out\"][\"meta_info\"]\n\n            requests.append(\n                SimRequest(\n                    request_id=rid,\n                    input_len=meta_info[\"prompt_tokens\"],\n                    output_len=meta_info[\"completion_tokens\"],\n                )\n            )\n\n    return requests\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/data_source/data_synthesis.py",
    "content": "import random\nfrom typing import List, Optional\n\nfrom sglang.srt.debug_utils.schedule_simulator.request import SimRequest\n\n\ndef generate_random_requests(\n    num_requests: int,\n    input_len: int,\n    output_len: int,\n    range_ratio: float = 1.0,\n    seed: Optional[int] = None,\n) -> List[SimRequest]:\n    if seed is not None:\n        random.seed(seed)\n\n    requests = []\n    for i in range(num_requests):\n        isl = _random_len(input_len, range_ratio)\n        osl = _random_len(output_len, range_ratio)\n        requests.append(\n            SimRequest(\n                request_id=f\"syn{i}\",\n                input_len=isl,\n                output_len=osl,\n            )\n        )\n\n    print(\n        f\"Generated {len(requests)} random requests \"\n        f\"(input_len={input_len}, output_len={output_len}, range_ratio={range_ratio})\"\n    )\n    return requests\n\n\ndef generate_gsp_requests(\n    num_groups: int,\n    prompts_per_group: int,\n    system_prompt_len: int,\n    question_len: int,\n    output_len: int,\n    range_ratio: float = 1.0,\n    seed: Optional[int] = None,\n) -> List[SimRequest]:\n    if seed is not None:\n        random.seed(seed)\n\n    requests = []\n    idx = 0\n    for group_idx in range(num_groups):\n        group_id = f\"g{group_idx}\"\n        prefix_len = _random_len(system_prompt_len, range_ratio)\n        for _ in range(prompts_per_group):\n            q_len = _random_len(question_len, range_ratio)\n            osl = _random_len(output_len, range_ratio)\n            requests.append(\n                SimRequest(\n                    request_id=f\"gsp{idx}\",\n                    input_len=prefix_len + q_len,\n                    output_len=osl,\n                    group_id=group_id,\n                    prefix_len=prefix_len,\n                )\n            )\n            idx += 1\n\n    random.shuffle(requests)\n    print(\n        f\"Generated {len(requests)} GSP requests \"\n        
f\"({num_groups} groups x {prompts_per_group} prompts, \"\n        f\"system_prompt_len={system_prompt_len}, question_len={question_len}, \"\n        f\"output_len={output_len})\"\n    )\n    return requests\n\n\ndef _random_len(full_len: int, range_ratio: float) -> int:\n    min_len = max(int(full_len * range_ratio), 1)\n    return random.randint(min_len, full_len)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/entrypoint.py",
    "content": "import argparse\nimport json\nimport random\nfrom typing import List\n\nfrom sglang.srt.debug_utils.schedule_simulator.data_source.data_loader import (\n    load_from_request_logger,\n)\nfrom sglang.srt.debug_utils.schedule_simulator.data_source.data_synthesis import (\n    generate_gsp_requests,\n    generate_random_requests,\n)\nfrom sglang.srt.debug_utils.schedule_simulator.metrics import (\n    AttentionComputeBalancednessRecorder,\n    AvgBatchSizeRecorder,\n    BatchSizeBalancednessRecorder,\n)\nfrom sglang.srt.debug_utils.schedule_simulator.request import SimRequest\nfrom sglang.srt.debug_utils.schedule_simulator.routers import (\n    RandomRouter,\n    RoundRobinRouter,\n    StickyRouter,\n)\nfrom sglang.srt.debug_utils.schedule_simulator.schedulers import FIFOScheduler\nfrom sglang.srt.debug_utils.schedule_simulator.simulator import (\n    SimulationResult,\n    Simulator,\n)\n\n\ndef create_arg_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(\n        description=\"Schedule Simulator for analyzing request scheduling across GPUs\"\n    )\n\n    data_group = parser.add_mutually_exclusive_group(required=True)\n    data_group.add_argument(\n        \"--input\", type=str, help=\"Path to request_logger JSON file\"\n    )\n    data_group.add_argument(\n        \"--synthetic\", action=\"store_true\", help=\"Use synthetic data generation\"\n    )\n    data_group.add_argument(\n        \"--synth-gsp\",\n        action=\"store_true\",\n        help=\"Use generated-shared-prefix (GSP) data generation\",\n    )\n\n    # Shared synthetic arguments\n    parser.add_argument(\"--synth-seed\", type=int, default=None)\n\n    # Random dataset arguments (aligned with bench_serving.py --random-* options)\n    parser.add_argument(\"--synth-random-num-requests\", type=int, default=1000)\n    parser.add_argument(\"--synth-random-input-len\", type=int, default=1024)\n    parser.add_argument(\"--synth-random-output-len\", type=int, 
default=256)\n    parser.add_argument(\"--synth-random-range-ratio\", type=float, default=0.0)\n\n    # GSP dataset arguments (aligned with bench_serving.py --gsp-* options)\n    parser.add_argument(\"--synth-gsp-num-groups\", type=int, default=64)\n    parser.add_argument(\"--synth-gsp-prompts-per-group\", type=int, default=16)\n    parser.add_argument(\"--synth-gsp-system-prompt-len\", type=int, default=2048)\n    parser.add_argument(\"--synth-gsp-question-len\", type=int, default=128)\n    parser.add_argument(\"--synth-gsp-output-len\", type=int, default=256)\n    parser.add_argument(\"--synth-gsp-range-ratio\", type=float, default=1.0)\n\n    parser.add_argument(\"--num-gpus-per-engine\", type=int, default=8)\n    parser.add_argument(\"--num-engines\", type=int, default=1)\n    parser.add_argument(\n        \"--router\",\n        type=str,\n        choices=[\"random\", \"round_robin\", \"sticky\"],\n        default=\"round_robin\",\n    )\n    parser.add_argument(\"--scheduler\", type=str, choices=[\"fifo\"], default=\"fifo\")\n    parser.add_argument(\"--max-total-tokens\", type=int, default=100000)\n    parser.add_argument(\n        \"--stop-criteria\",\n        type=str,\n        choices=[\"all_done\", \"exist_no_pending\"],\n        default=\"all_done\",\n        help=\"all_done: run until all requests complete; exist_no_pending: stop when any GPU has no pending requests\",\n    )\n    parser.add_argument(\"--max-steps\", type=int, default=None)\n    parser.add_argument(\"--output\", type=str, default=None)\n    parser.add_argument(\"--log-level\", type=int, choices=[0, 1, 2], default=0)\n\n    return parser\n\n\ndef _load_requests(args: argparse.Namespace) -> List[SimRequest]:\n    if args.input:\n        requests = load_from_request_logger(args.input)\n        print(f\"Loaded {len(requests)} requests from {args.input}\")\n    elif args.synth_gsp:\n        requests = generate_gsp_requests(\n            num_groups=args.synth_gsp_num_groups,\n            
prompts_per_group=args.synth_gsp_prompts_per_group,\n            system_prompt_len=args.synth_gsp_system_prompt_len,\n            question_len=args.synth_gsp_question_len,\n            output_len=args.synth_gsp_output_len,\n            range_ratio=args.synth_gsp_range_ratio,\n            seed=args.synth_seed,\n        )\n    else:\n        requests = generate_random_requests(\n            num_requests=args.synth_random_num_requests,\n            input_len=args.synth_random_input_len,\n            output_len=args.synth_random_output_len,\n            range_ratio=args.synth_random_range_ratio,\n            seed=args.synth_seed,\n        )\n    return requests\n\n\ndef _create_router(name: str, total_gpus: int):\n    if name == \"random\":\n        return RandomRouter(total_gpus)\n    if name == \"round_robin\":\n        return RoundRobinRouter(total_gpus)\n    if name == \"sticky\":\n        return StickyRouter(total_gpus)\n    raise ValueError(f\"Unknown router: {name}\")\n\n\ndef _create_scheduler(name: str):\n    if name == \"fifo\":\n        return FIFOScheduler()\n    raise ValueError(f\"Unknown scheduler: {name}\")\n\n\ndef main(args: argparse.Namespace) -> SimulationResult:\n    if args.synth_seed is not None:\n        random.seed(args.synth_seed)\n    requests = _load_requests(args)\n    total_gpus = args.num_gpus_per_engine * args.num_engines\n    router = _create_router(args.router, total_gpus)\n    scheduler = _create_scheduler(args.scheduler)\n\n    sim = Simulator(\n        num_gpus_per_engine=args.num_gpus_per_engine,\n        router=router,\n        scheduler=scheduler,\n        recorders=[\n            BatchSizeBalancednessRecorder(),\n            AttentionComputeBalancednessRecorder(),\n            AvgBatchSizeRecorder(),\n        ],\n        log_level=args.log_level,\n        max_total_tokens=args.max_total_tokens,\n        stop_criteria=args.stop_criteria,\n        max_steps=args.max_steps,\n    )\n\n    print(\n        f\"Running simulation with 
{args.num_gpus_per_engine} GPUs/engine x {args.num_engines} engines, router={args.router}, scheduler={args.scheduler}\"\n    )\n    result = sim.run(requests)\n\n    print(\"\\n=== Summary ===\")\n    for key, value in result.summary.items():\n        print(f\"{key}: {value:.4f}\" if isinstance(value, float) else f\"{key}: {value}\")\n\n    if args.output:\n        with open(args.output, \"w\") as f:\n            json.dump(result.summary, f, indent=2)\n        print(f\"\\nSummary saved to {args.output}\")\n\n    return result\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/gpu_state.py",
    "content": "from dataclasses import dataclass, field\nfrom typing import List, Optional\n\nfrom sglang.srt.debug_utils.schedule_simulator.request import SimRequest\n\n\n@dataclass\nclass StepRecord:\n    step: int\n    gpu_id: int\n    running_count: int\n    pending_count: int\n    total_seq_len: int\n    running_req_ids: List[str] = field(default_factory=list)\n    pending_req_ids: List[str] = field(default_factory=list)\n\n\n@dataclass\nclass GPUState:\n    gpu_id: int\n    max_total_tokens: int\n    pending_requests: List[SimRequest] = field(default_factory=list)\n    running_requests: List[SimRequest] = field(default_factory=list)\n\n    def batch_size(self) -> int:\n        return len(self.running_requests)\n\n    def total_attention_compute(self) -> int:\n        return sum(req.seq_len() for req in self.running_requests)\n\n    def total_seq_len(self, extra_reqs: Optional[List[SimRequest]] = None) -> int:\n        seen_groups = set()\n        total = 0\n        for req in self.running_requests + (extra_reqs or []):\n            is_shared = req.group_id is not None and req.group_id in seen_groups\n            total += req.seq_len() - (req.prefix_len if is_shared else 0)\n            if req.group_id is not None:\n                seen_groups.add(req.group_id)\n        return total\n\n    def is_valid(self) -> bool:\n        return self.total_seq_len() <= self.max_total_tokens\n\n    def start_request(self, req: SimRequest) -> None:\n        assert req in self.pending_requests\n        self.pending_requests.remove(req)\n        self.running_requests.append(req)\n\n    def evict_request(self, req: SimRequest) -> None:\n        assert req in self.running_requests\n        self.running_requests.remove(req)\n        self.pending_requests.insert(0, req)\n\n    def execute_step(self) -> None:\n        for req in self.running_requests:\n            req.decoded_tokens += 1\n        self.running_requests = [\n            r for r in self.running_requests if not 
r.is_finished()\n        ]\n\n    def get_step_record(self, step: int) -> StepRecord:\n        return StepRecord(\n            step=step,\n            gpu_id=self.gpu_id,\n            running_count=len(self.running_requests),\n            pending_count=len(self.pending_requests),\n            total_seq_len=self.total_seq_len(),\n            running_req_ids=[r.request_id for r in self.running_requests],\n            pending_req_ids=[r.request_id for r in self.pending_requests],\n        )\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/metrics.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import Any, Callable, Dict, List\n\nfrom sglang.srt.debug_utils.schedule_simulator.gpu_state import GPUState\n\n\nclass MetricRecorder(ABC):\n    @abstractmethod\n    def on_step_end(self, step: int, gpu_states: List[GPUState]) -> None: ...\n\n    @abstractmethod\n    def get_summary(self) -> Dict[str, Any]: ...\n\n\nclass BalancednessRecorder(MetricRecorder):\n    def __init__(self, name: str, value_fn: Callable[[GPUState], float]):\n        self._name = name\n        self._value_fn = value_fn\n        self._history: List[float] = []\n\n    def on_step_end(self, step: int, gpu_states: List[GPUState]) -> None:\n        values = [self._value_fn(gpu) for gpu in gpu_states]\n        max_val = max(values) if values else 0\n        mean_val = sum(values) / len(values) if values else 0\n        balancedness = mean_val / max_val if max_val > 0 else 1.0\n        self._history.append(balancedness)\n\n    def get_summary(self) -> Dict[str, Any]:\n        if not self._history:\n            return {f\"{self._name}_mean\": 0.0}\n        return {\n            f\"{self._name}_mean\": sum(self._history) / len(self._history),\n            f\"{self._name}_min\": min(self._history),\n            f\"{self._name}_max\": max(self._history),\n        }\n\n\ndef BatchSizeBalancednessRecorder() -> BalancednessRecorder:\n    return BalancednessRecorder(\"batch_size_balancedness\", lambda gpu: gpu.batch_size())\n\n\ndef AttentionComputeBalancednessRecorder() -> BalancednessRecorder:\n    return BalancednessRecorder(\n        \"attention_compute_balancedness\", lambda gpu: gpu.total_attention_compute()\n    )\n\n\nclass AvgBatchSizeRecorder(MetricRecorder):\n    def __init__(self):\n        self._total_running = 0\n        self._num_records = 0\n\n    def on_step_end(self, step: int, gpu_states: List[GPUState]) -> None:\n        for gpu in gpu_states:\n            self._total_running += gpu.batch_size()\n            
self._num_records += 1\n\n    def get_summary(self) -> Dict[str, Any]:\n        avg = self._total_running / self._num_records if self._num_records else 0.0\n        return {\"avg_batch_size\": avg}\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/request.py",
    "content": "from dataclasses import dataclass\nfrom typing import Optional\n\n\n@dataclass\nclass SimRequest:\n    request_id: str\n    input_len: int\n    output_len: int\n    decoded_tokens: int = 0\n    group_id: Optional[str] = None\n    prefix_len: int = 0\n\n    def seq_len(self) -> int:\n        return self.input_len + self.decoded_tokens\n\n    def is_finished(self) -> bool:\n        return self.decoded_tokens >= self.output_len\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/routers/__init__.py",
    "content": "from sglang.srt.debug_utils.schedule_simulator.routers.base import RouterPolicy\nfrom sglang.srt.debug_utils.schedule_simulator.routers.random_router import RandomRouter\nfrom sglang.srt.debug_utils.schedule_simulator.routers.round_robin_router import (\n    RoundRobinRouter,\n)\nfrom sglang.srt.debug_utils.schedule_simulator.routers.sticky_router import StickyRouter\n\n__all__ = [\"RouterPolicy\", \"RandomRouter\", \"RoundRobinRouter\", \"StickyRouter\"]\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/routers/base.py",
    "content": "from abc import ABC, abstractmethod\n\nfrom sglang.srt.debug_utils.schedule_simulator.request import SimRequest\n\n\nclass RouterPolicy(ABC):\n    @abstractmethod\n    def route(self, incoming_request: SimRequest) -> int: ...\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/routers/random_router.py",
    "content": "import random\n\nfrom sglang.srt.debug_utils.schedule_simulator.request import SimRequest\nfrom sglang.srt.debug_utils.schedule_simulator.routers.base import RouterPolicy\n\n\nclass RandomRouter(RouterPolicy):\n    def __init__(self, num_gpus: int):\n        self._num_gpus = num_gpus\n\n    def route(self, incoming_request: SimRequest) -> int:\n        return random.randint(0, self._num_gpus - 1)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/routers/round_robin_router.py",
    "content": "from sglang.srt.debug_utils.schedule_simulator.request import SimRequest\nfrom sglang.srt.debug_utils.schedule_simulator.routers.base import RouterPolicy\n\n\nclass RoundRobinRouter(RouterPolicy):\n    def __init__(self, num_gpus: int):\n        self._num_gpus = num_gpus\n        self._counter = 0\n\n    def route(self, incoming_request: SimRequest) -> int:\n        gpu_id = self._counter % self._num_gpus\n        self._counter += 1\n        return gpu_id\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/routers/sticky_router.py",
    "content": "import random\nfrom collections import defaultdict\n\nfrom sglang.srt.debug_utils.schedule_simulator.request import SimRequest\nfrom sglang.srt.debug_utils.schedule_simulator.routers.base import RouterPolicy\n\n\nclass StickyRouter(RouterPolicy):\n    def __init__(self, num_gpus: int):\n        self._num_gpus = num_gpus\n        self._group_to_gpu = defaultdict(self._assign_gpu)\n\n    def _assign_gpu(self) -> int:\n        return random.randint(0, self._num_gpus - 1)\n\n    def route(self, incoming_request: SimRequest) -> int:\n        group_id = incoming_request.group_id\n        if group_id is None:\n            return random.randint(0, self._num_gpus - 1)\n        return self._group_to_gpu[group_id]\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/schedulers/__init__.py",
    "content": "from sglang.srt.debug_utils.schedule_simulator.schedulers.base import SchedulerPolicy\nfrom sglang.srt.debug_utils.schedule_simulator.schedulers.fifo_scheduler import (\n    FIFOScheduler,\n)\n\n__all__ = [\"SchedulerPolicy\", \"FIFOScheduler\"]\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/schedulers/base.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import TYPE_CHECKING\n\nif TYPE_CHECKING:\n    from sglang.srt.debug_utils.schedule_simulator.gpu_state import GPUState\n\n\nclass SchedulerPolicy(ABC):\n    @abstractmethod\n    def schedule(self, gpu_state: \"GPUState\") -> None: ...\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/schedulers/fifo_scheduler.py",
    "content": "from typing import TYPE_CHECKING\n\nfrom sglang.srt.debug_utils.schedule_simulator.schedulers.base import SchedulerPolicy\n\nif TYPE_CHECKING:\n    from sglang.srt.debug_utils.schedule_simulator.gpu_state import GPUState\n\n\nclass FIFOScheduler(SchedulerPolicy):\n    def schedule(self, gpu_state: \"GPUState\") -> None:\n        while not gpu_state.is_valid() and gpu_state.running_requests:\n            gpu_state.evict_request(gpu_state.running_requests[-1])\n\n        for req in list(gpu_state.pending_requests):\n            if gpu_state.total_seq_len(extra_reqs=[req]) <= gpu_state.max_total_tokens:\n                gpu_state.start_request(req)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/schedule_simulator/simulator.py",
    "content": "from dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional\n\nfrom sglang.srt.debug_utils.schedule_simulator.gpu_state import GPUState, StepRecord\nfrom sglang.srt.debug_utils.schedule_simulator.metrics import MetricRecorder\nfrom sglang.srt.debug_utils.schedule_simulator.request import SimRequest\nfrom sglang.srt.debug_utils.schedule_simulator.routers.base import RouterPolicy\nfrom sglang.srt.debug_utils.schedule_simulator.schedulers.base import SchedulerPolicy\n\n\n@dataclass\nclass SimulationResult:\n    step_records: List[StepRecord]\n    summary: Dict[str, Any]\n\n\nclass Simulator:\n    def __init__(\n        self,\n        num_gpus_per_engine: int,\n        router: RouterPolicy,\n        scheduler: SchedulerPolicy,\n        recorders: Optional[List[MetricRecorder]] = None,\n        log_level: int = 0,\n        max_total_tokens: int = 100000,\n        stop_criteria: str = \"all_done\",\n        max_steps: Optional[int] = None,\n    ):\n        self.num_gpus_per_engine = num_gpus_per_engine\n        self.router = router\n        self.scheduler = scheduler\n        self.recorders = recorders or []\n        self.log_level = log_level\n        self.max_total_tokens = max_total_tokens\n        self.stop_criteria = stop_criteria\n        self.max_steps = max_steps\n        self.gpu_states: List[GPUState] = []\n        self.step = 0\n\n    def run(self, requests: List[SimRequest]) -> SimulationResult:\n        self.gpu_states = [\n            GPUState(gpu_id=i, max_total_tokens=self.max_total_tokens)\n            for i in range(self.num_gpus_per_engine)\n        ]\n        self.step = 0\n        step_records: List[StepRecord] = []\n        incoming_requests = list(requests)\n\n        while True:\n            self._route_requests(incoming_requests)\n            incoming_requests.clear()\n            self._schedule_all_gpus()\n            if self._should_stop():\n                break\n            self._execute_step()\n            
step_records.extend(\n                gpu.get_step_record(self.step) for gpu in self.gpu_states\n            )\n            self._log_step()\n            self._record_metrics()\n            self.step += 1\n\n        return SimulationResult(step_records=step_records, summary=self._get_summary())\n\n    def _should_stop(self) -> bool:\n        if self.max_steps is not None and self.step >= self.max_steps:\n            return True\n        if self.stop_criteria == \"exist_no_pending\":\n            return any(not gpu.pending_requests for gpu in self.gpu_states)\n        if self.stop_criteria == \"all_done\":\n            return not any(\n                gpu.pending_requests or gpu.running_requests for gpu in self.gpu_states\n            )\n        raise ValueError(f\"Unknown stop criteria: {self.stop_criteria}\")\n\n    def _route_requests(self, incoming_requests: List[SimRequest]) -> None:\n        for req in incoming_requests:\n            gpu_id = self.router.route(req)\n            if gpu_id < self.num_gpus_per_engine:\n                self.gpu_states[gpu_id].pending_requests.append(req)\n\n    def _schedule_all_gpus(self) -> None:\n        for gpu in self.gpu_states:\n            self.scheduler.schedule(gpu)\n            assert gpu.is_valid(), (\n                f\"GPU{gpu.gpu_id} invalid after scheduling \"\n                f\"({gpu.total_seq_len()=}, {gpu.max_total_tokens=})\"\n            )\n\n    def _execute_step(self) -> None:\n        for gpu in self.gpu_states:\n            gpu.execute_step()\n\n    def _log_step(self) -> None:\n        if self.log_level == 0 and self.step % 100 != 0:\n            return\n        parts = [f\"step={self.step:<4}\"]\n        for gpu in self.gpu_states:\n            r, q = len(gpu.running_requests), len(gpu.pending_requests)\n            if self.log_level <= 1:\n                parts.append(f\"GPU{gpu.gpu_id}[R={r:<3} Q={q:<3}]\")\n            else:\n                run_ids = _format_ids(gpu.running_requests)\n               
 queue_ids = _format_ids(gpu.pending_requests)\n                parts.append(f\"GPU{gpu.gpu_id}[R={r}:{run_ids} Q={q}:{queue_ids}]\")\n        print(\" | \".join(parts))\n\n    def _record_metrics(self) -> None:\n        for recorder in self.recorders:\n            recorder.on_step_end(self.step, self.gpu_states)\n\n    def _get_summary(self) -> Dict[str, Any]:\n        return {k: v for r in self.recorders for k, v in r.get_summary().items()}\n\n\ndef _format_ids(requests: List[SimRequest], limit: int = 5) -> str:\n    if not requests:\n        return \"-\"\n    ids = \",\".join(r.request_id for r in requests[:limit])\n    if len(requests) > limit:\n        ids += f\"...+{len(requests) - limit}\"\n    return ids\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/source_patcher/__init__.py",
    "content": "from sglang.srt.debug_utils.source_patcher.code_patcher import (\n    CodePatcher,\n    apply_patches_from_config,\n    patch_function,\n)\nfrom sglang.srt.debug_utils.source_patcher.types import (\n    EditSpec,\n    PatchApplicationError,\n    PatchConfig,\n    PatchSpec,\n    PatchState,\n)\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/source_patcher/code_patcher.py",
    "content": "import importlib\nimport inspect\nimport textwrap\nimport types\nfrom collections.abc import Callable\nfrom typing import Any, Optional\n\nimport yaml\n\nfrom sglang.srt.debug_utils.source_patcher.source_editor import apply_edits\nfrom sglang.srt.debug_utils.source_patcher.types import (\n    EditSpec,\n    PatchConfig,\n    PatchSpec,\n    PatchState,\n)\n\n\ndef apply_patches_from_config(\n    yaml_content: str,\n    *,\n    extra_imports: Optional[list[str]] = None,\n) -> list[PatchState]:\n    \"\"\"Parse a YAML config string and apply all patches.\n\n    Args:\n        yaml_content: YAML string with patch specifications.\n        extra_imports: Import lines inserted once at the top of each patched\n            function body (e.g. [\"from pkg import foo\"]).  The caller (dumper)\n            uses this so users don't have to write boilerplate in YAML.\n    \"\"\"\n    raw: dict[str, Any] = yaml.safe_load(yaml_content)\n    config: PatchConfig = PatchConfig(**raw)\n\n    if extra_imports:\n        config = _inject_preamble(config=config, extra_imports=extra_imports)\n\n    return _apply_specs(config.patches)\n\n\nclass CodePatcher:\n    \"\"\"Context manager that patches functions on enter and restores on exit.\"\"\"\n\n    def __init__(self, *, patches: list[PatchSpec]) -> None:\n        self._patches = patches\n        self._states: list[PatchState] = []\n\n    def __enter__(self) -> \"CodePatcher\":\n        self._states = _apply_specs(self._patches)\n        return self\n\n    def __exit__(\n        self,\n        exc_type: Optional[type],\n        exc_val: Optional[BaseException],\n        exc_tb: Optional[Any],\n    ) -> None:\n        for state in reversed(self._states):\n            state.restore()\n        self._states.clear()\n\n\ndef patch_function(\n    *,\n    target: Callable[..., Any],\n    edits: list[EditSpec],\n    preamble: str = \"\",\n) -> PatchState:\n    \"\"\"Patch a function by modifying its source and replacing 
__code__.\n\n    1. inspect.getsource -> get original source\n    2. apply_edits -> modify source text\n    3. optionally prepend preamble (e.g. import lines) inside the function body\n    4. compile + exec -> get new code object\n    5. replace target.__code__\n\n    Returns PatchState that can restore the original code.\n    \"\"\"\n    original_code: types.CodeType = target.__code__\n\n    source: str = inspect.getsource(target)\n    modified_source: str = apply_edits(source=source, edits=edits)\n    modified_source = textwrap.dedent(modified_source)\n\n    if preamble.strip():\n        modified_source = _insert_preamble(source=modified_source, preamble=preamble)\n\n    code: types.CodeType = compile(modified_source, inspect.getfile(target), \"exec\")\n    temp_namespace: dict[str, Any] = {}\n    exec(code, target.__globals__, temp_namespace)\n\n    new_fn: Any = temp_namespace[target.__name__]\n    target.__code__ = new_fn.__code__\n\n    return PatchState(target_fn=target, original_code=original_code)\n\n\n# --------------------------------- private ---------------------------------\n\n\ndef _apply_specs(specs: list[PatchSpec]) -> list[PatchState]:\n    states: list[PatchState] = []\n    for spec in specs:\n        target_fn: Callable[..., Any] = _resolve_target(spec.target)\n        print(f\"[source_patcher] patching {spec.target}\")\n        state: PatchState = patch_function(\n            target=target_fn, edits=spec.edits, preamble=spec.preamble\n        )\n        states.append(state)\n    return states\n\n\ndef _inject_preamble(*, config: PatchConfig, extra_imports: list[str]) -> PatchConfig:\n    \"\"\"Set preamble on every PatchSpec so imports are inserted once at function top.\"\"\"\n    import_block: str = \"\\n\".join(extra_imports)\n    new_patches: list[PatchSpec] = []\n\n    for spec in config.patches:\n        existing: str = spec.preamble\n        combined: str = (\n            import_block + \"\\n\" + existing if existing.strip() else 
import_block\n        )\n        new_patches.append(\n            PatchSpec(target=spec.target, edits=spec.edits, preamble=combined)\n        )\n\n    return PatchConfig(patches=new_patches)\n\n\ndef _insert_preamble(*, source: str, preamble: str) -> str:\n    \"\"\"Insert preamble lines right after the function signature (and optional docstring).\"\"\"\n    lines: list[str] = source.splitlines()\n\n    signature_end: int = _find_signature_end(lines)\n\n    body_start: int = signature_end + 1\n    body_indent: str = \"\"\n    for i in range(body_start, len(lines)):\n        if lines[i].strip():\n            body_indent = \" \" * (len(lines[i]) - len(lines[i].lstrip()))\n            body_start = i\n            break\n\n    preamble_lines: list[str] = [\n        body_indent + pl for pl in preamble.strip().splitlines()\n    ]\n    return \"\\n\".join(lines[:body_start] + preamble_lines + lines[body_start:])\n\n\ndef _find_signature_end(lines: list[str]) -> int:\n    \"\"\"Find the line index where the function signature ends (the line with trailing colon).\"\"\"\n    for i, line in enumerate(lines):\n        if line.rstrip().endswith(\":\"):\n            return i\n    return 0\n\n\ndef _resolve_target(qualified_name: str) -> Callable[..., Any]:\n    \"\"\"Resolve 'pkg.mod.Class.method' to the actual function object.\n\n    Tries progressively shorter module paths from right to left,\n    then uses getattr for the remaining attribute chain.\n    \"\"\"\n    parts: list[str] = qualified_name.split(\".\")\n\n    target: Any = None\n    for split_idx in range(len(parts), 0, -1):\n        module_path: str = \".\".join(parts[:split_idx])\n        try:\n            target = importlib.import_module(module_path)\n            attr_parts: list[str] = parts[split_idx:]\n            break\n        except ImportError:\n            continue\n    else:\n        raise ImportError(f\"could not import any module prefix of '{qualified_name}'\")\n\n    for attr_name in attr_parts:\n       
 target = getattr(target, attr_name)\n\n    if isinstance(target, classmethod):\n        target = target.__func__\n    if not callable(target):\n        raise TypeError(\n            f\"resolved target '{qualified_name}' is not callable: {type(target)}\"\n        )\n\n    return target\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/source_patcher/source_editor.py",
    "content": "from sglang.srt.debug_utils.source_patcher.types import EditSpec, PatchApplicationError\n\n\ndef apply_edits(*, source: str, edits: list[EditSpec]) -> str:\n    \"\"\"Apply a sequence of match/replacement edits to source text.\n\n    Each edit is applied sequentially so later edits see the result of earlier ones.\n    \"\"\"\n    result: str = source\n    for edit in edits:\n        result = _apply_single_edit(source=result, edit=edit)\n    return result\n\n\ndef _apply_single_edit(*, source: str, edit: EditSpec) -> str:\n    \"\"\"Apply a single match/replacement edit to the source text.\"\"\"\n    match_text: str = edit.match.strip()\n    if not match_text:\n        raise PatchApplicationError(\"empty match text\")\n\n    source_lines: list[str] = source.splitlines()\n    match_lines: list[str] = match_text.splitlines()\n\n    start_idx: int = _find_match(source_lines=source_lines, match_lines=match_lines)\n    match_len: int = len(match_lines)\n\n    original_indent: int = _leading_spaces(source_lines[start_idx])\n\n    effective_replacement: str = _resolve_replacement(edit=edit, match_text=match_text)\n    replacement_lines: list[str] = (\n        effective_replacement.splitlines() if effective_replacement else []\n    )\n    aligned: list[str] = _realign_replacement(\n        replacement_lines=replacement_lines, original_indent=original_indent\n    )\n    new_lines: list[str] = (\n        source_lines[:start_idx] + aligned + source_lines[start_idx + match_len :]\n    )\n\n    trailing_newline: str = \"\\n\" if source.endswith(\"\\n\") else \"\"\n    return \"\\n\".join(new_lines) + trailing_newline\n\n\ndef _resolve_replacement(*, edit: EditSpec, match_text: str) -> str:\n    \"\"\"Return the effective replacement text, handling replacement, prepend, and append modes.\"\"\"\n    if edit.prepend.strip():\n        return edit.prepend.strip() + \"\\n\" + match_text\n    if edit.append.strip():\n        return match_text + \"\\n\" + 
edit.append.strip()\n    return edit.replacement.strip()\n\n\ndef _find_match(*, source_lines: list[str], match_lines: list[str]) -> int:\n    \"\"\"Find the start index of match_lines in source_lines (strip-compared).\n\n    Returns the index of the first matching line.\n    Raises PatchApplicationError if not found or found multiple times.\n    \"\"\"\n    stripped_source: list[str] = [line.strip() for line in source_lines]\n    stripped_match: list[str] = [line.strip() for line in match_lines]\n    match_len: int = len(stripped_match)\n\n    found_indices: list[int] = [\n        i\n        for i in range(len(stripped_source) - match_len + 1)\n        if stripped_source[i : i + match_len] == stripped_match\n    ]\n\n    if len(found_indices) == 0:\n        preview: str = \"\\n\".join(match_lines)\n        raise PatchApplicationError(f\"match text not found in source:\\n{preview}\")\n    if len(found_indices) > 1:\n        preview = \"\\n\".join(match_lines)\n        raise PatchApplicationError(\n            f\"match text found multiple times ({len(found_indices)} occurrences) in source:\\n{preview}\"\n        )\n\n    return found_indices[0]\n\n\ndef _realign_replacement(\n    *, replacement_lines: list[str], original_indent: int\n) -> list[str]:\n    \"\"\"Realign replacement lines to the original indentation level.\n\n    Strategy:\n    - Take the leading spaces of the first non-empty replacement line as base_indent\n    - For each replacement line: remove base_indent, add original_indent\n    \"\"\"\n    non_empty: list[str] = [line for line in replacement_lines if line.strip()]\n    if not non_empty:\n        return []\n\n    base_indent: int = _leading_spaces(non_empty[0])\n    result: list[str] = []\n\n    for line in replacement_lines:\n        if not line.strip():\n            result.append(\"\")\n        else:\n            stripped = line[min(base_indent, len(line) - len(line.lstrip())) :]\n            result.append(\" \" * original_indent + 
stripped)\n\n    return result\n\n\ndef _leading_spaces(line: str) -> int:\n    return len(line) - len(line.lstrip(\" \"))\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/source_patcher/types.py",
    "content": "import types\nfrom collections.abc import Callable\nfrom typing import Any\n\nfrom pydantic import BaseModel, ConfigDict, model_validator\n\n\nclass PatchApplicationError(Exception):\n    \"\"\"match text not found or not unique in source.\"\"\"\n\n\nclass _StrictBase(BaseModel):\n    model_config = ConfigDict(extra=\"forbid\")\n\n\nclass EditSpec(_StrictBase):\n    \"\"\"Specify one edit: replace, prepend before, or append after the matched text.\n\n    Use ``replacement`` to substitute the matched text (empty string = delete).\n    Use ``prepend`` to keep the matched text and add lines before it.\n    Use ``append`` to keep the matched text and add lines after it.\n    Only one of ``replacement``, ``prepend``, and ``append`` may be set.\n    \"\"\"\n\n    match: str\n    replacement: str = \"\"\n    prepend: str = \"\"\n    append: str = \"\"\n\n    @model_validator(mode=\"after\")\n    def _check_modes_mutually_exclusive(self) -> \"EditSpec\":\n        active: list[str] = [\n            name\n            for name in (\"replacement\", \"prepend\", \"append\")\n            if getattr(self, name).strip()\n        ]\n        if len(active) > 1:\n            raise ValueError(\n                f\"only one of 'replacement', 'prepend', 'append' may be set, \"\n                f\"got: {', '.join(active)}\"\n            )\n        return self\n\n\nclass PatchSpec(_StrictBase):\n    target: str\n    edits: list[EditSpec]\n    preamble: str = \"\"\n\n\nclass PatchConfig(_StrictBase):\n    patches: list[PatchSpec]\n\n\nclass PatchState:\n    def __init__(\n        self, *, target_fn: Callable[..., Any], original_code: types.CodeType\n    ) -> None:\n        self.target_fn = target_fn\n        self.original_code = original_code\n\n    def restore(self) -> None:\n        self.target_fn.__code__ = self.original_code\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/tensor_dump_forward_hook.py",
    "content": "\"\"\"\nThis file provides a function `register_forward_hook_for_model` that registers a forward hook on every operator of the model.\nAfter registration, during model inference, all tensors generated throughout the forward pass will be recorded.\n\nUsage:\nSpecify the output directory for dumping tensors using the argument `--debug-tensor-dump-output-folder`.\nA separate directory will be created for each GPU rank, named in the format `f\"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{pid}\"`.\nEach complete forward pass of the model generates a `.pt` file named `f\"Pass{pass_num}.pt\"`, which can be loaded using `torch.load`.\nThe file contains a series of key-value pairs, where the keys correspond to operator names in the model\n(similar to those in model.safetensors.index.json), and the values are the outputs produced by the respective operators.\n\"\"\"\n\nimport logging\nimport os\nfrom pathlib import Path\nfrom typing import List, Optional\n\nimport torch\n\nfrom sglang.srt.layers.logits_processor import LogitsProcessorOutput\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors\n\nlogger = logging.getLogger(__name__)\n\n\nclass TensorDumper:\n    def __init__(\n        self,\n        dump_dir: str,\n        dump_layers: Optional[List[int]],\n        tp_size: int,\n        tp_rank: int,\n        pp_rank: int,\n    ):\n        self._dump_layers = dump_layers\n        self._forward_pass_id = 0\n        self._pid = os.getpid()\n        self._current_tensors = {}\n        self._base_dir = Path(dump_dir)\n        rank = tp_size * pp_rank + tp_rank\n        self._process_dir = (\n            self._base_dir / f\"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{self._pid}\"\n        )\n        self._process_dir.mkdir(parents=True, exist_ok=True)\n\n    def get_dump_dir(self):\n        return str(self._process_dir)\n\n    def add_tensor(self, name, tensor_item):\n        if isinstance(tensor_item, (tuple, list)):\n            tensors = 
[t.cpu() for t in tensor_item if t is not None]\n            if len(tensors) == 1:\n                self._current_tensors[name] = tensors[0]\n            else:\n                self._current_tensors[name] = tensors\n        elif isinstance(tensor_item, torch.Tensor):\n            self._current_tensors[name] = tensor_item.cpu()\n        elif isinstance(tensor_item, LogitsProcessorOutput):\n            self._current_tensors[name] = tensor_item.next_token_logits.cpu()\n        elif isinstance(tensor_item, ForwardBatch):\n            self._current_tensors[name + \".forward_batch_info.input_ids\"] = (\n                tensor_item.input_ids.cpu()\n            )\n            self._current_tensors[name + \".forward_batch_info.seq_lens\"] = (\n                tensor_item.seq_lens.cpu()\n            )\n            self._current_tensors[name + \".forward_batch_info.positions\"] = (\n                tensor_item.positions.cpu()\n            )\n        elif isinstance(tensor_item, PPProxyTensors):\n            for tensor_name in tensor_item.tensors.keys():\n                self._current_tensors[name + \".pp_proxy_tensors.\" + tensor_name] = (\n                    tensor_item.tensors[tensor_name].cpu()\n                )\n        else:\n            logger.warning(f\"Unsupported type: {type(tensor_item)}: {tensor_item}\")\n\n    def dump_current_tensors(self):\n        if len(self._current_tensors) == 0:\n            return\n        tensor_file_for_pass = self._process_dir / f\"Pass{self._forward_pass_id:05d}.pt\"\n        logger.info(\n            f\"Dump {self._forward_pass_id:05d}th pass to {tensor_file_for_pass}\"\n        )\n        torch.save(self._current_tensors, str(tensor_file_for_pass))\n        self._current_tensors = {}\n        self._forward_pass_id += 1\n\n    def _add_hook_recursive(\n        self, model, prefix, top_level_module_name, layers_module_name\n    ):\n        model_top_level_module_matched = False\n        layers_prefix = top_level_module_name + \".\" + 
layers_module_name\n        for name, module in model._modules.items():\n            top_level_model = False\n            if len(prefix) == 0:\n                cur_name = name\n                if cur_name == top_level_module_name:\n                    model_top_level_module_matched = True\n                    top_level_model = True\n            else:\n                cur_name = prefix + \".\" + name\n            if (\n                self._dump_layers is not None\n                and name.isdigit()\n                and prefix == layers_prefix\n            ):\n                # If we only need n layers, skip the remaining layers.\n                # Most models' layout is like model.layers.0.\n                cur_layer = int(name)\n                if cur_layer not in self._dump_layers:\n                    continue\n            if module is not None:\n                _, sub_count = self._add_hook_recursive(\n                    module, cur_name, top_level_module_name, layers_module_name\n                )\n                if sub_count == 0 or top_level_model:\n                    # Avoid duplicated output hooks, e.g. self_attn may contain:\n                    # self_attn.qkv_proj, self_attn.attn & self_attn.o_proj.\n                    # Therefore, we do not need to add output hooks for self_attn,\n                    # since the output of self_attn should be the same as self_attn.o_proj.\n                    module.register_forward_hook(\n                        self._dump_hook(cur_name, top_level_model)\n                    )\n        return model_top_level_module_matched, len(model._modules.items())\n\n    def _dump_hook(self, tensor_name, do_dump):\n        def inner_dump_hook(module, input, output):\n            if do_dump:\n                # This is the top-level model, so we will record the input for it.\n                for item in input:\n                    if isinstance(item, ForwardBatch):\n                        self.add_tensor(tensor_name, item)\n                self.dump_current_tensors()\n            if output is not None:\n                self.add_tensor(tensor_name, output)\n\n        return inner_dump_hook\n\n\ndef register_forward_hook_for_model(\n    model,\n    dump_dir: str,\n    dump_layers: Optional[List[int]],\n    tp_size: int,\n    tp_rank: int,\n    pp_rank: int,\n):\n    tensor_dumper = TensorDumper(dump_dir, dump_layers, tp_size, tp_rank, pp_rank)\n    # Most models have a layout like:\n    # XxxxForCausalLM\n    #     (model): XxxxModel\n    #         (layers): ModuleList\n    # If the model is not constructed with this layout,\n    # environment variables can be used to specify the module names.\n    top_level_module_name = os.getenv(\"TENSOR_DUMP_TOP_LEVEL_MODULE_NAME\", \"model\")\n    layers_module_name = os.getenv(\"TENSOR_DUMP_LAYERS_MODULE_NAME\", \"layers\")\n    model_top_level_module_matched, _ = tensor_dumper._add_hook_recursive(\n        model, \"\", top_level_module_name, layers_module_name\n    )\n    assert (\n        model_top_level_module_matched\n    ), f\"model should have a module named {top_level_module_name}\"\n    return tensor_dumper\n"
  },
  {
    "path": "python/sglang/srt/debug_utils/text_comparator.py",
    "content": "import argparse\nimport hashlib\nimport json\nfrom pathlib import Path\n\nimport polars as pl\n\n_DESCRIPTION = \"\"\"Compare and find differences to benchmark outputs.\n\nSupported inputs:\n* The samples jsonl from `lm_eval --log_samples --output_path FOLDER_NAME`\n* The output from `gsm8k/bench_sglang.py --raw-result-file FILE_NAME` (or mmlu)\n\"\"\"\n\n\ndef main(args):\n    if args.data_type == \"simple_evals\":\n        df_input = _compute_df_input_mode_simple_evals(args)\n    else:\n        df_input = _transform_df_input(_compute_df_raw(args))\n\n    assert all(\n        c in df_input.columns\n        for c in [\"category\", \"trial_index\", \"prompt_id\", \"prompt\", \"output\", \"correct\"]\n    )\n\n    df_meta = _compute_df_meta(df_input)\n\n    df_correctness_per_trial = df_input.group_by(\n        \"category\", \"trial_index\", maintain_order=True\n    ).agg(pl.col(\"correct\").mean())\n    df_correctness_delta = (\n        df_meta.group_by(\"correctness_delta\").len().sort(\"correctness_delta\")\n    )\n    df_good_to_bad = df_meta.filter(pl.col(\"correctness_delta\") < 0)\n    df_bad_to_good = df_meta.filter(pl.col(\"correctness_delta\") > 0)\n\n    print(f\"Dump output to {args.output_path}\")\n    Path(args.output_path).write_text(\n        json.dumps(\n            dict(\n                df_meta=df_meta.to_dicts(),\n                df_good_to_bad=df_good_to_bad.to_dicts(),\n                df_bad_to_good=df_bad_to_good.to_dicts(),\n            ),\n            indent=4,\n        ),\n    )\n\n    if not args.disable_print_details:\n        with pl.Config(\n            fmt_str_lengths=10000,\n            tbl_cols=-1,\n            tbl_rows=-1,\n            tbl_width_chars=-1,\n            tbl_formatting=\"UTF8_FULL\",\n        ):\n            print(\"====== Correctness per trial ======\")\n            print(df_correctness_per_trial)\n\n            print(\n                \"====== Correctness Delta (-1.0 means all-right becomes all-wrong) 
======\"\n            )\n            print(df_correctness_delta)\n\n            for name, df in [\n                (\"Good->Bad\", df_good_to_bad),\n                (\"Bad->Good\", df_bad_to_good),\n            ]:\n                print(f\"====== Concrete Examples: {name} ======\")\n                print(df)\n\n\ndef _compute_df_input_mode_simple_evals(args):\n    return pl.concat(\n        [\n            _compute_df_input_one_mode_simple_evals(**info)\n            for info in _get_file_infos(args=args)\n        ]\n    )\n\n\ndef _compute_df_input_one_mode_simple_evals(path, category, trial_index):\n    data = json.loads(Path(path).read_text())\n    rows = []\n\n    for single_eval_result in data[\"metadata\"][\"single_eval_results\"]:\n        prompt = single_eval_result[\"example_level_metadata\"][\n            \"actual_queried_prompt_messages\"\n        ]\n        score = single_eval_result[\"score\"]\n        assert score in {0.0, 1.0}, f\"{score=}\"\n\n        row = dict(\n            category=category,\n            trial_index=trial_index,\n            prompt_id=_compute_id_from_object(prompt),\n            prompt=json.dumps(prompt),\n            output=single_eval_result[\"example_level_metadata\"][\"response_text\"],\n            correct=score == 1.0,\n        )\n        rows.append(row)\n\n    return pl.DataFrame(rows)\n\n\ndef _compute_id_from_object(obj):\n    if isinstance(obj, pl.Series):\n        obj = obj.to_list()\n    json_str = json.dumps(obj, sort_keys=True, ensure_ascii=False)\n    return hashlib.sha256(json_str.encode(\"utf-8\")).hexdigest()\n\n\ndef _compute_df_raw(args):\n    return pl.concat(\n        [\n            _read_df_raw(\n                path=info[\"path\"],\n                category=info[\"category\"],\n                trial_index=info[\"trial_index\"],\n            )\n            for info in _get_file_infos(args=args)\n        ]\n    )\n\n\ndef _get_file_infos(args):\n    return [\n        dict(path=path, category=category, 
trial_index=trial_index)\n        for category, paths in [\n            (\"baseline\", args.baseline_path),\n            (\"target\", args.target_path),\n        ]\n        for trial_index, path in enumerate(paths)\n    ]\n\n\ndef _read_df_raw(path: str, category: str, trial_index: int):\n    return pl.read_ndjson(path).with_columns(\n        category=pl.lit(category), trial_index=trial_index\n    )\n\n\ndef _transform_df_input(df: pl.DataFrame):\n    if \"doc_id\" in df.columns:\n        print(\"Transform mode: lm_eval\")\n\n        filter_names = df[\"filter\"].unique(maintain_order=True).to_list()\n        if len(filter_names) > 1:\n            filter_name = filter_names[0]\n            print(f\"Choose {filter_name=} among {filter_names}\")\n            df = df.filter(pl.col(\"filter\") == filter_name)\n\n        df = df.select(\n            pl.col(\"category\"),\n            pl.col(\"trial_index\"),\n            prompt_id=pl.col(\"doc_id\"),\n            prompt=pl.col(\"arguments\").struct.field(\"gen_args_0\").struct.field(\"arg_0\"),\n            output=pl.col(\"resps\").list.get(0).list.get(0),\n            correct=pl.col(\"exact_match\").cast(bool),\n        )\n\n        return df\n    elif \"prompt_id\" in df.columns:\n        print(\"Transform mode: SGLang bench\")\n        return df\n    else:\n        raise Exception(\n            f\"Unknown data: {df.columns}. You may need to set `--data-type` if using e.g. 
simple_evals.\"\n        )\n\n\ndef _compute_df_meta(df_input: pl.DataFrame):\n    df_input = df_input.sort(\"prompt_id\", \"category\", \"trial_index\")\n    df_meta = pl.DataFrame(\n        [\n            _handle_one_prompt(df_one_prompt)\n            for df_one_prompt in df_input.partition_by(\"prompt_id\", maintain_order=True)\n        ]\n    )\n    df_meta = df_meta.with_columns(\n        correctness_delta=pl.col(\"correctness_target\") - pl.col(\"correctness_baseline\"),\n    )\n    df_meta = df_meta.sort(\"correctness_delta\", \"output_same_prefix_len\")\n    return df_meta\n\n\ndef _handle_one_prompt(df_one_prompt: pl.DataFrame):\n    assert (\n        len(set(_compute_id_from_object(obj) for obj in df_one_prompt[\"prompt\"])) == 1\n    )\n\n    df_baseline = df_one_prompt.filter(pl.col(\"category\") == \"baseline\")\n    df_target = df_one_prompt.filter(pl.col(\"category\") == \"target\")\n\n    outputs_baseline = df_baseline[\"output\"].to_list()\n    outputs_target = df_target[\"output\"].to_list()\n\n    output_same_prefix_len = max(\n        _compute_str_prefix_len(output_baseline, output_target)\n        for output_baseline in outputs_baseline\n        for output_target in outputs_target\n    )\n\n    return dict(\n        prompt_id=df_one_prompt[0, \"prompt_id\"],\n        correctness_baseline=df_baseline[\"correct\"].mean(),\n        correctness_target=df_target[\"correct\"].mean(),\n        output_same_prefix_len=output_same_prefix_len,\n        prompt=df_one_prompt[0, \"prompt\"],\n        outputs_baseline=outputs_baseline,\n        outputs_target=outputs_target,\n    )\n\n\ndef _compute_str_prefix_len(a: str, b: str) -> int:\n    min_len = min(len(a), len(b))\n    for i in range(min_len):\n        if a[i] != b[i]:\n            return i\n    return min_len\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=_DESCRIPTION)\n    parser.add_argument(\"--data-type\", type=str, default=\"auto\")\n    
parser.add_argument(\"--baseline-path\", type=str, nargs=\"+\")\n    parser.add_argument(\"--target-path\", type=str, nargs=\"+\")\n    parser.add_argument(\n        \"--output-path\", type=str, default=\"/tmp/text_comparator_output.json\"\n    )\n    parser.add_argument(\"--disable-print-details\", action=\"store_true\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/ascend/__init__.py",
    "content": "from sglang.srt.disaggregation.ascend.conn import (\n    AscendKVBootstrapServer,\n    AscendKVManager,\n    AscendKVReceiver,\n    AscendKVSender,\n)\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/ascend/conn.py",
    "content": "import concurrent.futures\nimport logging\nfrom typing import List, Tuple\n\nimport numpy as np\nimport numpy.typing as npt\n\nfrom sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine\nfrom sglang.srt.disaggregation.common.utils import group_concurrent_contiguous\nfrom sglang.srt.disaggregation.mooncake.conn import (\n    MooncakeKVBootstrapServer,\n    MooncakeKVManager,\n    MooncakeKVReceiver,\n    MooncakeKVSender,\n)\nfrom sglang.srt.utils.network import get_local_ip_auto\n\nlogger = logging.getLogger(__name__)\n\n\nclass AscendKVManager(MooncakeKVManager):\n    def init_engine(self):\n        # TransferEngine initialized on ascend.\n        local_ip = get_local_ip_auto()\n        self.engine = AscendTransferEngine(\n            hostname=local_ip,\n            npu_id=self.kv_args.gpu_id,\n            disaggregation_mode=self.disaggregation_mode,\n        )\n\n    def register_buffer_to_engine(self):\n        self.engine.batch_register(self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens)\n        # The Ascend backend optimize batch registration for small memory blocks.\n        self.engine.batch_register(\n            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens\n        )\n        # Batch register state/extra pool data buffers\n        if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:\n            self.engine.batch_register(\n                self.kv_args.state_data_ptrs, self.kv_args.state_data_lens\n            )\n\n    def send_kvcache(\n        self,\n        mooncake_session_id: str,\n        prefill_kv_indices: npt.NDArray[np.int32],\n        dst_kv_ptrs: list[int],\n        dst_kv_indices: npt.NDArray[np.int32],\n        executor: concurrent.futures.ThreadPoolExecutor,\n    ):\n        # Group by indices\n        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(\n            prefill_kv_indices, dst_kv_indices\n        )\n\n        if self.pp_size > 1:\n            
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (\n                self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)\n            )\n\n            layers_params = [\n                (\n                    src_k_ptrs[layer_id],\n                    dst_k_ptrs[layer_id],\n                    self.kv_args.kv_item_lens[layer_id],\n                )\n                for layer_id in range(layers_current_pp_stage)\n            ] + [\n                (\n                    src_v_ptrs[layer_id],\n                    dst_v_ptrs[layer_id],\n                    self.kv_args.kv_item_lens[layers_current_pp_stage + layer_id],\n                )\n                for layer_id in range(layers_current_pp_stage)\n            ]\n        else:\n            num_layers = len(self.kv_args.kv_data_ptrs)\n            layers_params = [\n                (\n                    self.kv_args.kv_data_ptrs[layer_id],\n                    dst_kv_ptrs[layer_id],\n                    self.kv_args.kv_item_lens[layer_id],\n                )\n                for layer_id in range(num_layers)\n            ]\n\n        def set_transfer_blocks(\n            src_ptr: int, dst_ptr: int, item_len: int\n        ) -> List[Tuple[int, int, int]]:\n            transfer_blocks = []\n            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):\n                src_addr = src_ptr + int(prefill_index[0]) * item_len\n                dst_addr = dst_ptr + int(decode_index[0]) * item_len\n                length = item_len * len(prefill_index)\n                transfer_blocks.append((src_addr, dst_addr, length))\n            return transfer_blocks\n\n        # Worker function for processing a single layer\n        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:\n            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)\n            return self._transfer_data(mooncake_session_id, transfer_blocks)\n\n        # Worker 
function for processing all layers in a batch\n        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:\n            transfer_blocks = []\n            for src_ptr, dst_ptr, item_len in layers_params:\n                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))\n            return self._transfer_data(mooncake_session_id, transfer_blocks)\n\n        if self.enable_custom_mem_pool:\n            futures = [\n                executor.submit(\n                    process_layer,\n                    src_ptr,\n                    dst_ptr,\n                    item_len,\n                )\n                for (src_ptr, dst_ptr, item_len) in layers_params\n            ]\n            for future in concurrent.futures.as_completed(futures):\n                status = future.result()\n                if status != 0:\n                    for f in futures:\n                        f.cancel()\n                    return status\n        else:\n            # Combining all layers' params in one batch transfer is more efficient\n            # compared to using multiple threads\n            return process_layers(layers_params)\n\n        return 0\n\n\nclass AscendKVSender(MooncakeKVSender):\n    pass\n\n\nclass AscendKVReceiver(MooncakeKVReceiver):\n    pass\n\n\nclass AscendKVBootstrapServer(MooncakeKVBootstrapServer):\n    pass\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/ascend/transfer_engine.py",
    "content": "import logging\nimport os\nfrom typing import List\n\nimport torch\n\nfrom sglang.srt.disaggregation.utils import DisaggregationMode\nfrom sglang.srt.distributed.device_communicators.mooncake_transfer_engine import (\n    MooncakeTransferEngine,\n)\nfrom sglang.srt.utils.network import NetworkAddress\n\ntry:\n    from memfabric_hybrid import TransferEngine\n\n    import_error = None\nexcept ImportError as e:\n    import_error = e\n    pass\n\nlogger = logging.getLogger(__name__)\n\n\nclass AscendTransferEngine(MooncakeTransferEngine):\n\n    def __init__(\n        self,\n        hostname: str,\n        npu_id: int,\n        disaggregation_mode: DisaggregationMode,\n    ):\n        if import_error is not None:\n            logger.warning(\n                \"Please install memfabric_hybrid, for details, see docs/backend/pd_disaggregation.md\"\n            )\n            raise import_error\n\n        self.engine = TransferEngine()\n        self.hostname = hostname\n        self.npu_id = npu_id\n\n        # Centralized storage address of the AscendTransferEngine\n        self.store_url = os.getenv(\"ASCEND_MF_STORE_URL\")\n        if disaggregation_mode == DisaggregationMode.PREFILL:\n            self.role = \"Prefill\"\n        elif disaggregation_mode == DisaggregationMode.DECODE:\n            self.role = \"Decode\"\n        else:\n            logger.error(f\"Unsupported DisaggregationMode: {disaggregation_mode}\")\n            raise ValueError(f\"Unsupported DisaggregationMode: {disaggregation_mode}\")\n        self.session_id = NetworkAddress(\n            self.hostname, self.engine.get_rpc_port()\n        ).to_host_port_str()\n        self.initialize()\n\n    def initialize(self) -> None:\n        from sglang.srt.distributed.parallel_state import (\n            get_world_group,\n            get_world_size,\n        )\n\n        transfer_protocol = self._get_transfer_protocol()\n        if transfer_protocol is None or transfer_protocol == 
\"sdma\":\n            trans_op_type = TransferEngine.TransDataOpType.SDMA\n        else:\n            trans_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA\n            \"\"\"with device RDMA for PD transfer\"\"\"\n            tmp_tensor = torch.zeros(1, device=\"npu\")\n            output_tensor_list = [\n                torch.empty_like(tmp_tensor) for _ in range(get_world_size())\n            ]\n            # Initialize hccl in advance through all_gather to avoid conflicts with rdma initialization.\n            torch.distributed.all_gather(\n                output_tensor_list, tmp_tensor, group=get_world_group().device_group\n            )\n        \"\"\"Initialize the ascend transfer instance.\"\"\"\n        ret_value = self.engine.initialize(\n            self.store_url, self.session_id, self.role, self.npu_id, trans_op_type\n        )\n        if ret_value != 0:\n            logger.error(\"Ascend Transfer Engine initialization failed.\")\n            raise RuntimeError(\"Ascend Transfer Engine initialization failed.\")\n\n    def batch_register(self, ptrs: List[int], lengths: List[int]):\n        try:\n            ret_value = self.engine.batch_register_memory(ptrs, lengths)\n        except Exception:\n            # Mark register as failed\n            ret_value = -1\n        if ret_value != 0:\n            logger.debug(f\"Ascend memory registration for ptr {ptrs} failed.\")\n\n    @staticmethod\n    def _get_transfer_protocol():\n        protocol = os.getenv(\"ASCEND_MF_TRANSFER_PROTOCOL\")\n        allowed_protocols = {\"device_rdma\", \"sdma\"}\n        if protocol and protocol.lower() in allowed_protocols:\n            return protocol.lower()\n        else:\n            logger.warning(\n                \"Invalid or no transfer protocol specified, using default protocol.\"\n            )\n            return None\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/base/__init__.py",
    "content": "from sglang.srt.disaggregation.base.conn import (\n    BaseKVBootstrapServer,\n    BaseKVManager,\n    BaseKVReceiver,\n    BaseKVSender,\n    KVArgs,\n    KVPoll,\n)\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/base/conn.py",
    "content": "from __future__ import annotations\n\nfrom abc import ABC, abstractmethod\nfrom typing import TYPE_CHECKING, List, Optional\n\nimport numpy as np\nimport numpy.typing as npt\n\nfrom sglang.srt.server_args import ServerArgs\n\nif TYPE_CHECKING:\n    from sglang.srt.disaggregation.utils import DisaggregationMode\n\n\nclass KVArgs:\n    engine_rank: int\n    kv_data_ptrs: List[int]\n    kv_data_lens: List[int]\n    kv_item_lens: List[int]\n    aux_data_ptrs: List[int]\n    aux_data_lens: List[int]\n    aux_item_lens: List[int]\n    state_data_ptrs: List[int]\n    state_data_lens: List[int]\n    state_item_lens: List[int]\n    state_type: str  # \"none\", \"mamba\", \"swa\"\n    # for mamba state different tp slice transfer\n    state_dim_per_tensor: List[int]  # dimension to slice for each state tensor\n    ib_device: str\n    ib_traffic_class: str\n    gpu_id: int\n    kv_head_num: int\n    total_kv_head_num: int\n    page_size: int\n    # for pp prefill\n    pp_rank: int\n    prefill_start_layer: int\n    # for system dp\n    system_dp_rank: int\n\n\nclass KVPoll:\n    Failed = 0\n    Bootstrapping = 1\n    WaitingForInput = 2\n    Transferring = 3\n    Success = 4\n\n\nclass BaseKVManager(ABC):\n    \"\"\"Base class for managing transfer states\"\"\"\n\n    @abstractmethod\n    def __init__(\n        self,\n        args: KVArgs,\n        disaggregation_mode: DisaggregationMode,\n        server_args: ServerArgs,\n        is_mla_backend: Optional[bool] = False,\n    ): ...\n\n    @abstractmethod\n    def register_to_bootstrap(self):\n        \"\"\"Register prefill server info to the bootstrap server.\"\"\"\n        ...\n\n\nclass BaseKVSender(ABC):\n\n    @abstractmethod\n    def __init__(\n        self,\n        mgr: BaseKVManager,\n        bootstrap_addr: str,\n        bootstrap_room: int,\n        dest_tp_ranks: List[int],\n        pp_rank: int,\n    ): ...\n\n    @abstractmethod\n    def init(self, num_kv_indices: int, aux_index: Optional[int] = 
None):\n        \"\"\"\n        Set req's index metadata locally or notify the decoder server about the kv indices length and aux index.\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def send(\n        self,\n        kv_indices: npt.NDArray[np.int32],\n        state_indices: Optional[List[int]] = None,\n    ):\n        \"\"\"\n        Send the kv cache at the given kv indices and the extra cache/state at the given indices to the decoder server.\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def poll(self) -> KVPoll:\n        \"\"\"\n        Check the status of the kv cache transfer.\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def failure_exception(self):\n        \"\"\"\n        Raise an exception if the kv cache transfer fails.\n        \"\"\"\n        ...\n\n\nclass BaseKVReceiver(ABC):\n\n    @abstractmethod\n    def __init__(\n        self,\n        mgr: BaseKVManager,\n        bootstrap_addr: str,\n        bootstrap_room: Optional[int] = None,\n    ): ...\n\n    @abstractmethod\n    def init(\n        self,\n        kv_indices: npt.NDArray[np.int32],\n        aux_index: Optional[int] = None,\n        state_indices: Optional[List[int]] = None,\n    ):\n        \"\"\"\n        Set req's index metadata locally or notify the prefill server about the kv indices, aux index, and state_indices.\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def poll(self) -> KVPoll:\n        \"\"\"\n        Check the status of the kv cache transfer.\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def failure_exception(self):\n        \"\"\"\n        Raise an exception if the kv cache transfer fails.\n        \"\"\"\n        ...\n\n    def clear(self):\n        \"\"\"\n        Clear any internal states.\n        \"\"\"\n        pass\n\n    def abort(self):\n        \"\"\"\n        Abort the current transfer.\n        \"\"\"\n        pass\n\n\nclass BaseKVBootstrapServer(ABC):\n    @abstractmethod\n    def __init__(self, host: str, 
port: int): ...\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/common/__init__.py",
    "content": "from sglang.srt.disaggregation.common.conn import (\n    CommonKVBootstrapServer,\n    CommonKVManager,\n    CommonKVReceiver,\n)\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/common/conn.py",
    "content": "from __future__ import annotations\n\nimport asyncio\nimport dataclasses\nimport logging\nimport threading\nimport time\nfrom collections import defaultdict\nfrom functools import cache\nfrom typing import Dict, List, Optional, Set, Tuple, Union\n\nimport numpy as np\nimport numpy.typing as npt\nimport requests\nimport zmq\nfrom aiohttp import web\n\nfrom sglang.srt.disaggregation.base.conn import (\n    BaseKVBootstrapServer,\n    BaseKVManager,\n    BaseKVReceiver,\n    BaseKVSender,\n    KVArgs,\n    KVPoll,\n)\nfrom sglang.srt.disaggregation.utils import DisaggregationMode\nfrom sglang.srt.distributed import get_pp_group\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.dp_attention import (\n    get_attention_cp_rank,\n    get_attention_cp_size,\n    get_attention_dp_rank,\n    get_attention_dp_size,\n    get_attention_tp_rank,\n    get_attention_tp_size,\n)\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils.network import (\n    NetworkAddress,\n    get_local_ip_auto,\n    get_zmq_socket_on_host,\n)\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass PrefillServerInfo:\n    # Topology fields (fetched from bootstrap server)\n    attn_tp_size: int\n    attn_cp_size: int\n    dp_size: int\n    pp_size: int\n    page_size: Optional[int]\n    kv_cache_dtype: Optional[str]\n    follow_bootstrap_room: bool\n\n    # Pre-computed rank mapping (set by try_ensure_parallel_info on decode side)\n    target_tp_rank: Optional[int] = None\n    target_tp_ranks: Optional[List[int]] = None\n    target_cp_ranks: Optional[List[int]] = None\n    target_pp_ranks: Optional[List[int]] = None\n    required_dst_info_num: Optional[int] = None\n    required_prefill_response_num: Optional[int] = None\n\n    def __post_init__(self):\n        self.attn_tp_size = int(self.attn_tp_size)\n        self.attn_cp_size = int(self.attn_cp_size)\n        self.dp_size = int(self.dp_size)\n        self.pp_size = int(self.pp_size)\n    
    self.page_size = int(self.page_size) if self.page_size is not None else None\n        self.kv_cache_dtype = (\n            str(self.kv_cache_dtype) if self.kv_cache_dtype is not None else None\n        )\n        self.follow_bootstrap_room = bool(self.follow_bootstrap_room)\n\n\n@dataclasses.dataclass\nclass PrefillRankInfo:\n    rank_ip: str\n    rank_port: int\n\n    def __post_init__(self):\n        self.rank_ip = str(self.rank_ip)\n        self.rank_port = int(self.rank_port)\n\n\nclass CommonKVManager(BaseKVManager):\n    def __init__(\n        self,\n        args: KVArgs,\n        disaggregation_mode: DisaggregationMode,\n        server_args: ServerArgs,\n        is_mla_backend: Optional[bool] = False,\n    ):\n        self.kv_args = args\n        self.is_mla_backend = is_mla_backend\n        self.disaggregation_mode = disaggregation_mode\n        self.server_args = server_args\n        # for p/d multi node infer\n        self.bootstrap_host = server_args.host\n        self.bootstrap_port = server_args.disaggregation_bootstrap_port\n        self.dist_init_addr = server_args.dist_init_addr\n        self.attn_tp_size = get_attention_tp_size()\n        self.attn_tp_rank = get_attention_tp_rank()\n        self.attn_cp_size = get_attention_cp_size()\n        self.attn_cp_rank = get_attention_cp_rank()\n        self.attn_dp_size = get_attention_dp_size()\n        self.attn_dp_rank = get_attention_dp_rank()\n        self.system_dp_size = (\n            1 if server_args.enable_dp_attention else server_args.dp_size\n        )\n        self.system_dp_rank = (\n            self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0\n        )\n        self.pp_size = server_args.pp_size\n        self.pp_rank = self.kv_args.pp_rank\n        self.local_ip = get_local_ip_auto()\n        self.enable_all_cp_ranks_for_transfer = (\n            envs.SGLANG_DISAGGREGATION_ALL_CP_RANKS_TRANSFER.get()\n        )\n\n        # bind zmq socket\n        context = 
zmq.Context()\n        self.rank_port, self.server_socket = get_zmq_socket_on_host(\n            context, zmq.PULL, host=self.local_ip\n        )\n        logger.debug(f\"kv manager bound to {self.local_ip}:{self.rank_port}\")\n\n        self.request_status: Dict[int, KVPoll] = {}\n        self.failure_records: Dict[int, str] = {}\n        self.failure_lock = threading.Lock()\n\n        if self.disaggregation_mode == DisaggregationMode.PREFILL:\n            # When SGLANG_DISAGGREGATION_ALL_CP_RANKS_TRANSFER is True, all CP ranks\n            # participate in KV transfer; otherwise only CP rank 0 sends.\n            self.is_dummy_cp_rank = (\n                not self.enable_all_cp_ranks_for_transfer\n                and self.attn_cp_size > 1\n                and self.attn_cp_rank != 0\n            )\n            self.register_to_bootstrap()\n            self.transfer_infos = {}\n            self.decode_kv_args_table = {}\n            self.pp_group = get_pp_group()\n            # If a timeout happens on the prefill side, it means prefill instances\n            # fail to receive the KV indices from the decode instance of this request.\n            # These timeout requests should be aborted to release the tree cache.\n            self.bootstrap_timeout = envs.SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT.get()\n        elif self.disaggregation_mode == DisaggregationMode.DECODE:\n            self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}\n            self.connection_lock = threading.Lock()\n            self.required_prefill_response_num_table: Dict[int, int] = {}\n            self.prefill_info_table: Dict[str, PrefillServerInfo] = {}\n            self.heartbeat_failures: Dict[str, int] = {}\n            self.session_pool: Dict = defaultdict(requests.Session)\n            self.session_pool_lock = threading.Lock()\n            self.addr_to_rooms_tracker: Dict[str, Set[int]] = defaultdict(set)\n            self.prefill_response_tracker: Dict[int, Set[int]] = 
defaultdict(set)\n            # Heartbeat interval should be at least 2 seconds\n            self.heartbeat_interval = max(\n                envs.SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL.get(), 2.0\n            )\n            # Heartbeat failure should be at least 1\n            self.max_failures = max(\n                envs.SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE.get(), 1\n            )\n            # If a timeout happens on the decode side, it means decode instances\n            # fail to receive the KV Cache transfer done signal after bootstrapping.\n            # These timeout requests should be aborted to release the tree cache.\n            self.waiting_timeout = envs.SGLANG_DISAGGREGATION_WAITING_TIMEOUT.get()\n        else:\n            raise ValueError(\n                f\"Unsupported DisaggregationMode: {self.disaggregation_mode}\"\n            )\n\n    def check_status(self, bootstrap_room: int) -> KVPoll:\n        return self.request_status[bootstrap_room]\n\n    def update_status(self, bootstrap_room: int, status: KVPoll):\n        if bootstrap_room not in self.request_status:\n            self.request_status[bootstrap_room] = status\n        else:\n            if status == KVPoll.Failed:\n                self.request_status[bootstrap_room] = KVPoll.Failed\n            else:\n                self.request_status[bootstrap_room] = max(\n                    self.request_status[bootstrap_room], status\n                )\n\n    def record_failure(self, bootstrap_room: int, failure_reason: str):\n        with self.failure_lock:\n            self.failure_records[bootstrap_room] = failure_reason\n\n    def try_ensure_parallel_info(self, bootstrap_addr: str) -> bool:\n        \"\"\"Single non-blocking attempt to fetch and cache prefill parallel info.\n        Returns True if info is available (cached or freshly fetched).\"\"\"\n        if bootstrap_addr in self.prefill_info_table:\n            return True\n\n        info: PrefillServerInfo = None\n        
try:\n            url = f\"http://{bootstrap_addr}/route?prefill_dp_rank={-1}&prefill_cp_rank={-1}&target_tp_rank={-1}&target_pp_rank={-1}\"\n            response = requests.get(url, timeout=5)\n            if response.status_code == 200:\n                data = response.json()\n                info = PrefillServerInfo(**data)\n            else:\n                logger.error(\n                    f\"Failed to get prefill server info: {response.status_code}, {response.text}\"\n                )\n                return False\n        except Exception as e:\n            logger.error(f\"Error fetching prefill server info from bootstrap: {e}\")\n            return False\n\n        # Sanity checks\n        if info.page_size is not None and info.page_size != self.kv_args.page_size:\n            raise RuntimeError(\n                f\"Page size mismatch: prefill server has page_size={info.page_size}, \"\n                f\"but decode server has page_size={self.kv_args.page_size}. \"\n                f\"Both servers must use the same --page-size value.\"\n            )\n\n        if (\n            info.kv_cache_dtype is not None\n            and info.kv_cache_dtype != self.server_args.kv_cache_dtype\n        ):\n            raise RuntimeError(\n                f\"KV cache dtype mismatch: prefill server has kv_cache_dtype={info.kv_cache_dtype}, \"\n                f\"but decode server has kv_cache_dtype={self.server_args.kv_cache_dtype}. 
\"\n                f\"Both servers must use the same --kv-cache-dtype value.\"\n            )\n\n        self._resolve_rank_mapping(info)\n        self.prefill_info_table[bootstrap_addr] = info\n        logger.debug(f\"Prefill parallel info for [{bootstrap_addr}]: {info}\")\n        return True\n\n    def _resolve_rank_mapping(self, info: PrefillServerInfo) -> None:\n        \"\"\"Compute TP/CP/PP rank mapping and store on the PrefillServerInfo object.\n        Deterministic for a given (bootstrap_addr, decode engine) pair.\"\"\"\n        # TP rank mapping\n        if self.attn_tp_size == info.attn_tp_size:\n            target_tp_rank = self.kv_args.engine_rank % self.attn_tp_size\n            required_dst_info_num = 1\n            required_prefill_response_num = 1\n            target_tp_ranks = [target_tp_rank]\n        elif self.attn_tp_size > info.attn_tp_size:\n            if not self.is_mla_backend:\n                logger.warning_once(\n                    \"Performance is NOT guaranteed when using different TP sizes for non-MLA models. \"\n                )\n            target_tp_rank = (self.kv_args.engine_rank % self.attn_tp_size) // (\n                self.attn_tp_size // info.attn_tp_size\n            )\n            required_dst_info_num = self.attn_tp_size // info.attn_tp_size\n            required_prefill_response_num = 1\n            target_tp_ranks = [target_tp_rank]\n        else:\n            if not self.is_mla_backend:\n                logger.warning_once(\n                    \"Performance is NOT guaranteed when using different TP sizes for non-MLA models. 
\"\n                )\n            # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks\n            target_tp_ranks = list(\n                range(\n                    (self.kv_args.engine_rank % self.attn_tp_size)\n                    * (info.attn_tp_size // self.attn_tp_size),\n                    (self.kv_args.engine_rank % self.attn_tp_size + 1)\n                    * (info.attn_tp_size // self.attn_tp_size),\n                )\n            )\n            # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain\n            # multiple connections in the connection pool and have to send dummy requests to other prefill ranks,\n            # or the KVPoll will never be set correctly\n            target_tp_rank = target_tp_ranks[0]\n            required_dst_info_num = 1\n            if self.is_mla_backend:\n                required_prefill_response_num = 1\n            else:\n                required_prefill_response_num = info.attn_tp_size // self.attn_tp_size\n\n        # CP rank mapping — decode cp size should be equal to 1\n        assert self.attn_cp_size == 1, (\n            f\"Decode cp size ({self.attn_cp_size}) should be equal to 1\",\n        )\n        if self.attn_cp_size == info.attn_cp_size:\n            assert info.attn_cp_size == 1, (\n                f\"When prefill cp size is 1, attn cp size should be 1, but got {self.attn_cp_size}\",\n            )\n            target_cp_ranks = [self.attn_cp_rank]\n        else:\n            target_cp_ranks = list(range(info.attn_cp_size))\n            if not self.enable_all_cp_ranks_for_transfer:\n                # Only retrieve from prefill CP rank 0 when not using all ranks\n                target_cp_ranks = target_cp_ranks[:1]\n                required_prefill_response_num *= 1\n            else:\n                required_prefill_response_num *= info.attn_cp_size // self.attn_cp_size\n\n        # PP rank mapping — decode pp 
size should be equal to prefill pp size or 1\n        assert self.pp_size == info.pp_size or self.pp_size == 1, (\n            f\"Decode pp size ({self.pp_size}) should be equal to prefill pp size ({info.pp_size}) or 1\"\n        )\n        if info.pp_size == self.pp_size:\n            target_pp_ranks = [self.pp_rank]\n        else:\n            target_pp_ranks = list(range(info.pp_size))\n            required_prefill_response_num *= info.pp_size // self.pp_size\n\n        info.target_tp_rank = target_tp_rank\n        info.target_tp_ranks = target_tp_ranks\n        info.target_cp_ranks = target_cp_ranks\n        info.target_pp_ranks = target_pp_ranks\n        info.required_dst_info_num = required_dst_info_num\n        info.required_prefill_response_num = required_prefill_response_num\n\n    def register_to_bootstrap(self):\n        \"\"\"Register prefill server info to bootstrap server via HTTP PUT.\"\"\"\n        if self.dist_init_addr:\n            # Multi-node case: bootstrap server's host is dist_init_addr\n            host = NetworkAddress.parse(self.dist_init_addr).resolved().host\n        else:\n            # Single-node case: bootstrap server's host is the same as http server's host\n            host = self.bootstrap_host\n\n        bootstrap_na = NetworkAddress(host, self.bootstrap_port)\n        bootstrap_server_url = bootstrap_na.to_host_port_str()\n        url = f\"{bootstrap_na.to_url()}/route\"\n        payload = {\n            \"attn_tp_size\": self.attn_tp_size,\n            \"attn_tp_rank\": self.attn_tp_rank,\n            \"attn_cp_size\": self.attn_cp_size,\n            \"attn_cp_rank\": self.attn_cp_rank,\n            \"attn_dp_size\": self.attn_dp_size,\n            \"attn_dp_rank\": self.attn_dp_rank,\n            \"pp_size\": self.pp_size,\n            \"pp_rank\": self.pp_rank,\n            \"system_dp_size\": self.system_dp_size,\n            \"system_dp_rank\": self.system_dp_rank,\n            \"rank_ip\": self.local_ip,\n            
\"rank_port\": self.rank_port,\n            \"page_size\": self.kv_args.page_size,\n            \"kv_cache_dtype\": self.server_args.kv_cache_dtype,\n            \"load_balance_method\": self.server_args.load_balance_method,\n        }\n\n        try:\n            response = requests.put(url, json=payload, timeout=5)\n            if response.status_code == 200:\n                logger.debug(\"Prefill successfully registered to bootstrap server.\")\n            else:\n                logger.error(\n                    f\"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}\"\n                )\n        except Exception as e:\n            logger.error(\n                f\"Prefill instance failed to register to bootstrap server: {e}\"\n            )\n\n    @cache\n    def _connect(self, endpoint: str, is_ipv6: bool = False):\n        socket = zmq.Context().socket(zmq.PUSH)\n        if is_ipv6:\n            socket.setsockopt(zmq.IPV6, 1)\n        socket.connect(endpoint)\n        return socket\n\n    def get_mha_kv_ptrs_with_pp(\n        self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]\n    ) -> Tuple[List[int], List[int], List[int], List[int], int]:\n        start_layer = self.kv_args.prefill_start_layer\n        num_kv_layers = len(src_kv_ptrs) // 2\n        end_layer = start_layer + num_kv_layers\n        dst_num_total_layers = len(dst_kv_ptrs) // 2\n        src_k_ptrs = src_kv_ptrs[:num_kv_layers]\n        src_v_ptrs = src_kv_ptrs[num_kv_layers:]\n        if num_kv_layers == dst_num_total_layers:\n            dst_k_ptrs = dst_kv_ptrs[:dst_num_total_layers]\n            dst_v_ptrs = dst_kv_ptrs[dst_num_total_layers:]\n        elif (\n            num_kv_layers < dst_num_total_layers\n            and dst_num_total_layers % num_kv_layers != 0\n        ):\n            # Case: Decode has draft model KV while Prefill is deployed without speculative decoding\n            # dst_kv_ptrs layout: [K_main..., V_main..., draft_K..., 
draft_V...]\n            multiplier_ratio = dst_num_total_layers // num_kv_layers\n            dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]\n            v_ptr_offset = num_kv_layers * multiplier_ratio\n            dst_v_ptrs = dst_kv_ptrs[\n                v_ptr_offset + start_layer : v_ptr_offset + end_layer\n            ]\n        else:\n            # Decode pp size should be equal to prefill pp size or 1\n            dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]\n            dst_v_ptrs = dst_kv_ptrs[\n                dst_num_total_layers + start_layer : dst_num_total_layers + end_layer\n            ]\n        layers_current_pp_stage = len(src_k_ptrs)\n        return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage\n\n    def get_mla_kv_ptrs_with_pp(\n        self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]\n    ) -> Tuple[List[int], List[int], int]:\n        start_layer = self.kv_args.prefill_start_layer\n        end_layer = start_layer + len(src_kv_ptrs)\n        if len(src_kv_ptrs) == len(dst_kv_ptrs):\n            sliced_dst_kv_ptrs = dst_kv_ptrs\n        else:\n            # Decode pp size should be equal to prefill pp size or 1\n            sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]\n        layers_current_pp_stage = len(src_kv_ptrs)\n        return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage\n\n\nclass CommonKVSender(BaseKVSender):\n    def __init__(\n        self,\n        mgr: CommonKVManager,\n        bootstrap_addr: str,\n        bootstrap_room: int,\n        dest_tp_ranks: List[int],\n        pp_rank: int,\n    ):\n        self.kv_mgr = mgr\n        self.bootstrap_room = bootstrap_room\n        self.aux_index = None\n        self.bootstrap_server_url = bootstrap_addr\n        # inner state\n        self.curr_idx = 0\n        if self.kv_mgr.is_dummy_cp_rank:\n            # Non-authoritative CP ranks are dummy participants.\n            self.kv_mgr.update_status(self.bootstrap_room, 
KVPoll.WaitingForInput)\n            return\n\n        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)\n        if (\n            self.kv_mgr.server_args.dp_size > 1\n            and self.kv_mgr.server_args.load_balance_method != \"follow_bootstrap_room\"\n        ):\n            self._register_prefill_dp_rank()\n\n    def _register_prefill_dp_rank(self):\n        \"\"\"Register this request's prefill dp_rank to the bootstrap server.\"\"\"\n        url = f\"http://{self.bootstrap_server_url}/register_dp_rank\"\n        payload = {\n            \"bootstrap_room\": self.bootstrap_room,\n            \"dp_rank\": self.kv_mgr.attn_dp_rank,\n        }\n        try:\n            response = requests.post(url, json=payload, timeout=5)\n            if response.status_code != 200:\n                logger.error(\n                    f\"Failed to register prefill dp_rank: {response.status_code}, {response.text}\"\n                )\n        except Exception as e:\n            logger.error(f\"Failed to register prefill dp_rank: {e}\")\n\n    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):\n        self.num_kv_indices = num_kv_indices\n        self.aux_index = aux_index\n        logger.debug(\n            f\"CommonKVSender init with num_kv_indices: {num_kv_indices} and aux_index: {aux_index}\"\n        )\n\n    def send(\n        self,\n        kv_indices: npt.NDArray[np.int32],\n        state_indices: Optional[List[int]] = None,\n    ):\n        pass\n\n    def poll(self) -> KVPoll:\n        pass\n\n    def failure_exception(self):\n        raise Exception(\"Fake KVSender Exception\")\n\n\nclass CommonKVReceiver(BaseKVReceiver):\n    _ctx = zmq.Context()\n    _socket_cache = {}\n    _socket_locks = {}\n    _global_lock = threading.Lock()\n\n    def __init__(\n        self,\n        mgr: CommonKVManager,\n        bootstrap_addr: str,\n        bootstrap_room: Optional[int] = None,\n        prefill_dp_rank: Optional[int] = None,\n    
):\n        self.bootstrap_room = bootstrap_room\n        self.bootstrap_addr = bootstrap_addr\n        self.kv_mgr = mgr\n        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)\n\n        if self.bootstrap_addr not in self.kv_mgr.prefill_info_table:\n            self.kv_mgr.record_failure(\n                self.bootstrap_room,\n                f\"Prefill server with bootstrap_addr: {self.bootstrap_addr} was healthy before, but now it is down. Request (bootstrap_room: {self.bootstrap_room}) has been marked as failed.\",\n            )\n            self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)\n            self.bootstrap_infos = None\n            return\n\n        # Read pre-computed rank mapping from prefill_info (computed in try_ensure_parallel_info)\n        self.prefill_info = self.kv_mgr.prefill_info_table[self.bootstrap_addr]\n        self.target_tp_rank = self.prefill_info.target_tp_rank\n        self.target_tp_ranks = self.prefill_info.target_tp_ranks\n        self.target_cp_ranks = self.prefill_info.target_cp_ranks\n        self.target_pp_ranks = self.prefill_info.target_pp_ranks\n        self.required_dst_info_num = self.prefill_info.required_dst_info_num\n        self.required_prefill_response_num = (\n            self.prefill_info.required_prefill_response_num\n        )\n\n        self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (\n            self.required_prefill_response_num\n        )\n\n        assert (\n            prefill_dp_rank is not None\n        ), \"prefill_dp_rank must be resolved before creating receiver\"\n        self.prefill_dp_rank = prefill_dp_rank\n        self._setup_bootstrap_infos()\n\n    def _setup_bootstrap_infos(self):\n        all_bootstrap_infos = []\n        # NOTE: key distinguished by bootstrap_addr, prefill_dp_rank, prefill_cp_rank, and target_tp_rank\n        for target_cp_rank in self.target_cp_ranks:\n            bootstrap_key = 
f\"{self.bootstrap_addr}_{self.prefill_dp_rank}_{target_cp_rank}_{self.target_tp_rank}\"\n\n            if bootstrap_key not in self.kv_mgr.connection_pool:\n                bootstrap_infos = []\n                for target_tp_rank in self.target_tp_ranks:\n                    # Enable higher PP ranks to be bootstrapped earlier to make PP PD requests bootstrap more robust\n                    for target_pp_rank in reversed(self.target_pp_ranks):\n                        bootstrap_info = self._get_bootstrap_info_from_server(\n                            self.prefill_dp_rank,\n                            target_cp_rank,\n                            target_tp_rank,\n                            target_pp_rank,\n                        )\n                        if bootstrap_info is not None:\n                            if self.kv_mgr.is_mla_backend:\n                                # For MLA: target_tp_rank is the selected real rank, others are dummy ranks\n                                bootstrap_info[\"is_dummy\"] = not bool(\n                                    target_tp_rank == self.target_tp_rank\n                                    or self.target_tp_rank is None\n                                )\n                            else:\n                                # For non-MLA: all target_tp_ranks are selected real ranks\n                                bootstrap_info[\"is_dummy\"] = False\n                            logger.debug(\n                                f\"Fetched bootstrap info: {bootstrap_info} for DP {self.prefill_dp_rank} CP {target_cp_rank} TP {target_tp_rank} PP {target_pp_rank}\"\n                            )\n                            bootstrap_infos.append(bootstrap_info)\n                        else:\n                            self.kv_mgr.record_failure(\n                                self.bootstrap_room,\n                                f\"Could not fetch bootstrap info for: prefill_dp_rank: {self.prefill_dp_rank} prefill_cp_rank: 
{target_cp_rank} target_tp_rank: {target_tp_rank} and target_pp_rank {target_pp_rank}\",\n                            )\n                            self.kv_mgr.update_status(\n                                self.bootstrap_room, KVPoll.Failed\n                            )\n                            return\n\n                self.bootstrap_infos = bootstrap_infos\n                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos\n\n                # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server\n                self._register_kv_args()\n            else:\n                self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]\n\n            assert len(self.bootstrap_infos) > 0\n            all_bootstrap_infos.extend(self.bootstrap_infos)\n\n        self.bootstrap_infos = all_bootstrap_infos\n\n    def _get_bootstrap_info_from_server(\n        self, prefill_dp_rank, prefill_cp_rank, target_tp_rank, target_pp_rank\n    ):\n        \"\"\"Fetch the bootstrap info from the bootstrap server.\"\"\"\n        try:\n            url = f\"http://{self.bootstrap_addr}/route?prefill_dp_rank={prefill_dp_rank}&prefill_cp_rank={prefill_cp_rank}&target_tp_rank={target_tp_rank}&target_pp_rank={target_pp_rank}\"\n            response = requests.get(url, timeout=5)\n            if response.status_code == 200:\n                bootstrap_info = response.json()\n                return bootstrap_info\n            else:\n                logger.error(\n                    f\"Failed to get prefill server info: {response.status_code}, {response.text}\"\n                )\n                return None\n        except Exception as e:\n            logger.error(f\"Error fetching prefill info from bootstrap: {e}\")\n            return None\n\n    @staticmethod\n    def query_prefill_dp_ranks(\n        bootstrap_addr: str, bootstrap_rooms: List[int]\n    ) -> Dict[str, int]:\n        \"\"\"Batch query prefill 
dp_ranks for given bootstrap_rooms.\"\"\"\n        try:\n            url = f\"http://{bootstrap_addr}/query_dp_ranks\"\n            response = requests.post(\n                url,\n                json={\"bootstrap_rooms\": bootstrap_rooms},\n                timeout=5,\n            )\n            if response.status_code == 200:\n                return response.json()\n            else:\n                logger.error(\n                    f\"Failed to query dp_ranks: {response.status_code}, {response.text}\"\n                )\n                return {}\n        except Exception as e:\n            logger.error(f\"Error querying dp_ranks from bootstrap: {e}\")\n            return {}\n\n    @classmethod\n    def _connect(cls, endpoint: str, is_ipv6: bool = False):\n        with cls._global_lock:\n            if endpoint not in cls._socket_cache:\n                sock = cls._ctx.socket(zmq.PUSH)\n                if is_ipv6:\n                    sock.setsockopt(zmq.IPV6, 1)\n                sock.connect(endpoint)\n                cls._socket_cache[endpoint] = sock\n                cls._socket_locks[endpoint] = threading.Lock()\n            return cls._socket_cache[endpoint], cls._socket_locks[endpoint]\n\n    @classmethod\n    def _connect_to_bootstrap_server(cls, bootstrap_info: dict):\n        ip_address = bootstrap_info[\"rank_ip\"]\n        port = bootstrap_info[\"rank_port\"]\n        na = NetworkAddress(ip_address, port)\n        sock, lock = cls._connect(na.to_tcp(), is_ipv6=na.is_ipv6)\n        return sock, lock\n\n    def _register_kv_args(self):\n        pass\n\n    def failure_exception(self):\n        raise Exception(\"Fake KVReceiver Exception\")\n\n\nclass CommonKVBootstrapServer(BaseKVBootstrapServer):\n    def __init__(self, host: str, port: int):\n        self.host = host\n        self.port = port\n        self.app = web.Application()\n        self.store = dict()\n        self.lock = asyncio.Lock()\n        self._setup_routes()\n        self.pp_size = 
None\n        self.attn_tp_size = None\n        self.attn_cp_size = None\n        self.dp_size = None\n        self.page_size = None\n        self.kv_cache_dtype: Optional[str] = None\n        self.follow_bootstrap_room: Optional[bool] = None\n        self.prefill_port_table: Dict[\n            int, Dict[int, Dict[int, Dict[int, PrefillRankInfo]]]\n        ] = {}\n        self.room_to_dp_rank: Dict[int, Dict[str, Union[int, float]]] = {}\n        self._registered_count = 0\n        self.entry_cleanup_interval = (\n            envs.SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL.get()\n        )\n\n        # Start bootstrap server\n        self.thread = threading.Thread(target=self._run_server, daemon=True)\n        self.run()\n\n    def run(self):\n        self.thread.start()\n\n    def _is_ready(self) -> bool:\n        if (\n            self.attn_tp_size is None\n            or self.attn_cp_size is None\n            or self.pp_size is None\n            or self.dp_size is None\n        ):\n            return False\n        expected = self.dp_size * self.attn_cp_size * self.attn_tp_size * self.pp_size\n        logger.debug(\n            f\"Expected {expected} prefill servers to be registered, {self._registered_count} registered so far\"\n        )\n        return self._registered_count >= expected\n\n    def _setup_routes(self):\n        self.app.router.add_route(\"*\", \"/route\", self._handle_route)\n        self.app.router.add_post(\"/register_dp_rank\", self._handle_register_dp_rank)\n        self.app.router.add_post(\"/query_dp_ranks\", self._handle_query_dp_ranks)\n        self.app.router.add_get(\"/health\", self._handle_health_check)\n\n    async def _handle_health_check(self, request):\n        return web.Response(text=\"OK\", status=200)\n\n    async def _handle_route(self, request: web.Request):\n        method = request.method\n        if method == \"PUT\":\n            return await self._handle_route_put(request)\n        elif method == 
\"GET\":\n            return await self._handle_route_get(request)\n        else:\n            return web.Response(\n                text=\"Method not allowed\", status=405, content_type=\"application/json\"\n            )\n\n    async def _handle_route_put(self, request: web.Request):\n        data = await request.json()\n        attn_tp_size = data[\"attn_tp_size\"]\n        attn_tp_rank = data[\"attn_tp_rank\"]\n        attn_cp_size = data[\"attn_cp_size\"]\n        attn_cp_rank = data[\"attn_cp_rank\"]\n        attn_dp_size = data[\"attn_dp_size\"]\n        attn_dp_rank = data[\"attn_dp_rank\"]\n        pp_size = data[\"pp_size\"]\n        pp_rank = data[\"pp_rank\"]\n        system_dp_size = data[\"system_dp_size\"]\n        system_dp_rank = data[\"system_dp_rank\"]\n        rank_ip = data[\"rank_ip\"]\n        rank_port = int(data[\"rank_port\"])\n        page_size = int(data[\"page_size\"])\n        kv_cache_dtype = data[\"kv_cache_dtype\"]\n\n        if self.attn_tp_size is None:\n            self.attn_tp_size = attn_tp_size\n\n        if self.attn_cp_size is None:\n            self.attn_cp_size = attn_cp_size\n\n        if self.dp_size is None:\n            self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size\n\n        if self.pp_size is None:\n            self.pp_size = pp_size\n\n        if self.page_size is None and page_size is not None:\n            self.page_size = page_size\n\n        if self.kv_cache_dtype is None and kv_cache_dtype is not None:\n            self.kv_cache_dtype = kv_cache_dtype\n\n        if self.follow_bootstrap_room is None:\n            load_balance_method = data.get(\n                \"load_balance_method\", \"follow_bootstrap_room\"\n            )\n            self.follow_bootstrap_room = load_balance_method == \"follow_bootstrap_room\"\n\n        if system_dp_size == 1:\n            dp_group = attn_dp_rank\n        else:\n            dp_group = system_dp_rank\n\n        # Add lock to make sure thread-safe\n 
       async with self.lock:\n            dp_group_table = self.prefill_port_table.setdefault(dp_group, {})\n            cp_group_table = dp_group_table.setdefault(attn_cp_rank, {})\n            tp_group_table = cp_group_table.setdefault(attn_tp_rank, {})\n\n            tp_group_table[pp_rank] = PrefillRankInfo(\n                rank_ip=rank_ip,\n                rank_port=rank_port,\n            )\n\n            self._registered_count += 1\n\n        expected = self.dp_size * self.attn_cp_size * self.attn_tp_size * self.pp_size\n        logger.debug(\n            f\"Register prefill bootstrap: DP{dp_group} CP{attn_cp_rank} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}\"\n            f\" ({self._registered_count}/{expected} registered)\"\n        )\n\n        return web.Response(text=\"OK\", status=200)\n\n    async def _handle_route_get(self, request: web.Request):\n        prefill_dp_rank = request.query.get(\"prefill_dp_rank\")\n        prefill_cp_rank = request.query.get(\"prefill_cp_rank\")\n        target_tp_rank = request.query.get(\"target_tp_rank\")\n        target_pp_rank = request.query.get(\"target_pp_rank\")\n        if (\n            not prefill_dp_rank\n            or not prefill_cp_rank\n            or not target_tp_rank\n            or not target_pp_rank\n        ):\n            return web.Response(text=\"Missing inputs for bootstrap server.\", status=400)\n\n        if (\n            int(prefill_dp_rank) == -1\n            and int(prefill_cp_rank) == -1\n            and int(target_tp_rank) == -1\n            and int(target_pp_rank) == -1\n        ):\n            if not self._is_ready():\n                return web.Response(\n                    text=f\"Prefill server not fully registered yet\"\n                    f\" ({self._registered_count} workers registered).\",\n                    status=503,\n                )\n            info = PrefillServerInfo(\n                attn_tp_size=self.attn_tp_size,\n          
      attn_cp_size=self.attn_cp_size,\n                dp_size=self.dp_size,\n                pp_size=self.pp_size,\n                page_size=self.page_size,\n                kv_cache_dtype=self.kv_cache_dtype,\n                follow_bootstrap_room=(\n                    self.follow_bootstrap_room\n                    if self.follow_bootstrap_room is not None\n                    else True\n                ),\n            )\n            return web.json_response(dataclasses.asdict(info), status=200)\n\n        if not self._is_ready():\n            return web.Response(\n                text=f\"Prefill server not fully registered yet\"\n                f\" ({self._registered_count} workers registered).\",\n                status=503,\n            )\n\n        # Find corresponding prefill info\n        try:\n            async with self.lock:\n                bootstrap_info = self.prefill_port_table[int(prefill_dp_rank)][\n                    int(prefill_cp_rank)\n                ][int(target_tp_rank)][int(target_pp_rank)]\n        except KeyError:\n            return web.Response(\n                text=f\"Bootstrap info not found for dp_rank={prefill_dp_rank} cp_rank={prefill_cp_rank} \"\n                f\"tp_rank={target_tp_rank} pp_rank={target_pp_rank}\",\n                status=404,\n            )\n\n        return web.json_response(dataclasses.asdict(bootstrap_info), status=200)\n\n    async def _handle_register_dp_rank(self, request: web.Request):\n        data = await request.json()\n        bootstrap_room = int(data[\"bootstrap_room\"])\n        dp_rank = int(data[\"dp_rank\"])\n        async with self.lock:\n            self.room_to_dp_rank[bootstrap_room] = {\n                \"dp_rank\": dp_rank,\n                \"timestamp\": time.time(),\n            }\n        logger.debug(f\"Registered dp_rank={dp_rank} for {bootstrap_room=}\")\n        return web.Response(text=\"OK\", status=200)\n\n    async def _handle_query_dp_ranks(self, request: web.Request):\n 
       data = await request.json()\n        bootstrap_rooms = data[\"bootstrap_rooms\"]\n        result = {}\n        async with self.lock:\n            for room in bootstrap_rooms:\n                room_int = int(room)\n                if room_int in self.room_to_dp_rank:\n                    result[str(room_int)] = self.room_to_dp_rank[room_int][\"dp_rank\"]\n        return web.json_response(result, status=200)\n\n    async def _cleanup_expired_entries(self):\n        \"\"\"Remove entries older than cleanup interval from room_to_dp_rank.\"\"\"\n        while True:\n            await asyncio.sleep(self.entry_cleanup_interval)\n            current_time = time.time()\n            async with self.lock:\n                expired_keys = [\n                    key\n                    for key, value in self.room_to_dp_rank.items()\n                    if current_time - value[\"timestamp\"] > self.entry_cleanup_interval\n                ]\n                for key in expired_keys:\n                    del self.room_to_dp_rank[key]\n            if expired_keys:\n                logger.debug(\n                    f\"Cleaned up {len(expired_keys)} expired entries from room_to_dp_rank\"\n                )\n\n    def _run_server(self):\n        try:\n            # Event Loop\n            self._loop = asyncio.new_event_loop()\n            asyncio.set_event_loop(self._loop)\n\n            self._loop.create_task(self._cleanup_expired_entries())\n\n            access_log = None\n            if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:\n                access_log = self.app.logger\n\n            self._runner = web.AppRunner(self.app, access_log=access_log)\n            self._loop.run_until_complete(self._runner.setup())\n\n            site = web.TCPSite(self._runner, host=self.host, port=self.port)\n            self._loop.run_until_complete(site.start())\n            logger.info(\n                f\"CommonKVBootstrapServer started successfully on 
{self.host}:{self.port}\"\n            )\n            self._loop.run_forever()\n        except Exception as e:\n            logger.error(f\"Server error: {str(e)}\", exc_info=True)\n        finally:\n            # Release the aiohttp runner's resources and close the event loop\n            self._loop.run_until_complete(self._runner.cleanup())\n            self._loop.close()\n\n    def close(self):\n        \"\"\"Stop the server event loop and wait up to 2 seconds for the server thread to exit.\"\"\"\n        # _loop is only set once the server thread enters _run_server, so guard against an early close()\n        loop = getattr(self, \"_loop\", None)\n        if loop is not None and loop.is_running():\n            loop.call_soon_threadsafe(loop.stop)\n            logger.info(\"Stopping server loop...\")\n\n        if self.thread.is_alive():\n            self.thread.join(timeout=2)\n            logger.info(\"Server thread stopped\")\n\n    def poll(self) -> KVPoll: ...\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/common/utils.py",
    "content": "import threading\nfrom collections import deque\nfrom typing import List, Tuple\n\nimport numpy as np\nimport numpy.typing as npt\n\n\nclass FastQueue:\n    def __init__(self):\n        self._buf = deque()\n        self._cond = threading.Condition()\n\n    def put(self, item):\n        with self._cond:\n            self._buf.append(item)\n            # Wake up one thread blocked in get()\n            self._cond.notify()\n\n    def get(self):\n        with self._cond:\n            # If the queue is empty, block until put() notifies us\n            while not self._buf:\n                self._cond.wait()\n            return self._buf.popleft()\n\n\ndef group_concurrent_contiguous(\n    src_indices: npt.NDArray[np.int32], dst_indices: npt.NDArray[np.int32]\n) -> Tuple[List[List[int]], List[List[int]]]:\n    \"\"\"Split paired index arrays into runs where both src and dst increase by exactly 1.\n\n    Vectorised NumPy implementation; each group is returned as a Python list of ints.\n    \"\"\"\n    if src_indices.size == 0:\n        return [], []\n\n    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1\n    src_groups = np.split(src_indices, brk)\n    dst_groups = np.split(dst_indices, brk)\n\n    src_groups = [g.tolist() for g in src_groups]\n    dst_groups = [g.tolist() for g in dst_groups]\n\n    return src_groups, dst_groups\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/decode.py",
    "content": "\"\"\"\nLife cycle of a request in the decode server\n\n1. PreallocQueue:\n    a. Initialize a receiver for each request\n    b. The request handshakes first, and pre-allocate kv once there is available kv.\n    c. Move the request to TransferQueue.\n\n2. TransferQueue:\n    a. Poll the receiver to check the transfer state\n    b. If the transfer has finished, move the request to waiting queue\n\n3. WaitingQueue:\n    a. Use the requests in the queue to construct a PrebuiltExtendBatch\n    b. Skip the prefill forward but only populate metadata\n\n4. RunningBatch:\n    a. Merge the resolved PrebuiltExtendBatch into running batch to run decoding\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nimport time\nfrom collections import deque\nfrom dataclasses import dataclass\nfrom http import HTTPStatus\nfrom typing import TYPE_CHECKING, Dict, List, Optional, Tuple\n\nimport torch\nfrom torch.distributed import ProcessGroup\n\nfrom sglang.srt.configs.mamba_utils import Mamba2CacheParams\nfrom sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE\nfrom sglang.srt.disaggregation.base import KVPoll\nfrom sglang.srt.disaggregation.common.conn import CommonKVManager, CommonKVReceiver\nfrom sglang.srt.disaggregation.utils import (\n    FAKE_BOOTSTRAP_HOST,\n    DisaggregationMode,\n    KVClassType,\n    MetadataBuffers,\n    ReqToMetadataIdxAllocator,\n    TransferBackend,\n    get_kv_class,\n    is_mla_backend,\n    kv_to_page_indices,\n    poll_and_all_reduce,\n    prepare_abort,\n)\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.dp_attention import get_attention_tp_size\nfrom sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch\nfrom sglang.srt.managers.utils import GenerationBatchResult\nfrom sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator\nfrom sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache\nfrom sglang.srt.mem_cache.common import release_kv_cache\nfrom 
sglang.srt.mem_cache.memory_pool import (\n    HybridLinearKVPool,\n    HybridReqToTokenPool,\n    KVCache,\n    NSATokenToKVPool,\n    ReqToTokenPool,\n)\nfrom sglang.srt.mem_cache.swa_memory_pool import SWAKVPool\nfrom sglang.srt.observability.req_time_stats import (\n    set_schedule_time_batch,\n    set_time_batch,\n)\nfrom sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter\n\nlogger = logging.getLogger(__name__)\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.schedule_batch import Req\n    from sglang.srt.managers.scheduler import Scheduler\n    from sglang.srt.server_args import ServerArgs\n\nCLIP_MAX_NEW_TOKEN = envs.SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION.get()\n\n\ndef _is_fake_transfer(req: Req, server_args: ServerArgs) -> bool:\n    return req.bootstrap_host == FAKE_BOOTSTRAP_HOST or (\n        req.bootstrap_host is None\n        and server_args.disaggregation_transfer_backend == \"fake\"\n    )\n\n\ndef _bootstrap_addr(req: Req) -> str:\n    # FIXME: make a property of a req\n    return f\"{req.bootstrap_host}:{req.bootstrap_port}\"\n\n\nclass DecodeReqToTokenPool:\n    \"\"\"\n    The difference of DecodeReqToTokenPool and ReqToTokenPool is that\n    DecodeReqToTokenPool subscribes memory for pre-allocated requests.\n\n    In ReqToTokenPool, if `--max-running-requests` is 8,\n    #pre-allocated + #transfer + #running <= 8, but there are in fact more memory can carry pre-allocated requests.\n\n    In DecodeReqToTokenPool, if `--max-running-requests` is 8,\n    #running <= 8, #pre-allocated + #transfer <= pre_alloc_size, so we can use the free memory to pre-allocate requests to unblock prefill.\n    \"\"\"\n\n    def __init__(\n        self,\n        size: int,\n        max_context_len: int,\n        device: str,\n        enable_memory_saver: bool,\n        pre_alloc_size: int,\n    ):\n        memory_saver_adapter = TorchMemorySaverAdapter.create(\n            enable=enable_memory_saver\n        )\n\n        self.size = size\n   
     self.max_context_len = max_context_len\n        self.device = device\n        self.pre_alloc_size = pre_alloc_size\n        with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):\n            self.req_to_token = torch.zeros(\n                (size + pre_alloc_size, max_context_len),\n                dtype=torch.int32,\n                device=device,\n            )\n\n        self.free_slots = list(range(size + pre_alloc_size))\n\n    def write(self, indices, values):\n        self.req_to_token[indices] = values\n\n    def available_size(self):\n        return len(self.free_slots)\n\n    def alloc(self, reqs: List[\"Req\"]) -> Optional[List[int]]:\n        # Indices of reqs that already have a req_pool_idx and will reuse\n        # their existing slot (e.g. chunked prefill continuing across chunks).\n        reusing = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None]\n        assert (\n            len(reusing) <= 1\n        ), \"only one chunked request may reuse req_pool_idx in a batch\"\n        assert all(\n            reqs[i].is_chunked > 0 or reqs[i].kv_committed_len > 0 for i in reusing\n        ), \"reusing request must be chunked or have committed KV\"\n\n        need_size = len(reqs) - len(reusing)\n        if need_size > len(self.free_slots):\n            return None\n        select_index = self.free_slots[:need_size]\n        self.free_slots = self.free_slots[need_size:]\n        offset = 0\n        for r in reqs:\n            if r.req_pool_idx is None:\n                r.req_pool_idx = select_index[offset]\n                offset += 1\n        return [r.req_pool_idx for r in reqs]\n\n    def free(self, req: \"Req\"):\n        assert req.req_pool_idx is not None, \"request must have req_pool_idx\"\n        self.free_slots.append(req.req_pool_idx)\n        req.req_pool_idx = None\n\n    def clear(self):\n        self.free_slots = list(range(self.size + self.pre_alloc_size))\n\n\nclass 
HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):\n\n    def __init__(\n        self,\n        size: int,\n        max_context_len: int,\n        device: str,\n        enable_memory_saver: bool,\n        cache_params: \"Mamba2CacheParams\",\n        speculative_num_draft_tokens: int,\n        enable_mamba_extra_buffer: bool,\n        pre_alloc_size: int,\n        enable_overlap_schedule: bool,\n        mamba_size: int = None,\n    ):\n        DecodeReqToTokenPool.__init__(\n            self,\n            size=size,\n            max_context_len=max_context_len,\n            device=device,\n            enable_memory_saver=enable_memory_saver,\n            pre_alloc_size=pre_alloc_size,\n        )\n\n        self.mamba_ping_pong_track_buffer_size = 2 if enable_overlap_schedule else 1\n        self.enable_mamba_extra_buffer = enable_mamba_extra_buffer\n        self.enable_memory_saver = enable_memory_saver\n        effective_mamba_size = (\n            mamba_size if mamba_size is not None else size\n        ) + pre_alloc_size\n        self._init_mamba_pool(\n            size=effective_mamba_size,\n            mamba_spec_state_size=size + pre_alloc_size,\n            cache_params=cache_params,\n            device=device,\n            enable_mamba_extra_buffer=self.enable_mamba_extra_buffer,\n            speculative_num_draft_tokens=speculative_num_draft_tokens,\n        )\n\n    def clear(self):\n        self.free_slots = list(range(self.size + self.pre_alloc_size))\n        self.mamba_pool.clear()\n\n\n@dataclass\nclass DecodeRequest:\n    req: Req\n    kv_receiver: CommonKVReceiver\n    waiting_for_input: bool = False\n    metadata_buffer_index: int = -1\n\n    @property\n    def seqlen(self) -> int:\n        return self.req.seqlen\n\n\nclass DecodePreallocQueue:\n    \"\"\"\n    Store the requests that are preallocating.\n    \"\"\"\n\n    def __init__(\n        self,\n        req_to_token_pool: ReqToTokenPool,\n        token_to_kv_pool_allocator: 
BaseTokenToKVPoolAllocator,\n        draft_token_to_kv_pool: Optional[KVCache],\n        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,\n        metadata_buffers: MetadataBuffers,\n        scheduler: Scheduler,\n        transfer_queue: DecodeTransferQueue,\n        tree_cache: BasePrefixCache,\n        gloo_group: ProcessGroup,\n        tp_rank: int,\n        tp_size: int,\n        dp_size: int,\n        gpu_id: int,\n        bootstrap_port: int,\n        max_total_num_tokens: int,\n        pp_rank: int,\n        num_reserved_decode_tokens: int,\n        transfer_backend: TransferBackend,\n    ):\n        self.req_to_token_pool = req_to_token_pool\n        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator\n        self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()\n        self.draft_token_to_kv_pool = draft_token_to_kv_pool\n        self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)\n        self.metadata_buffers = metadata_buffers\n        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator\n        self.scheduler = scheduler\n        self.transfer_queue = transfer_queue\n        self.tree_cache = tree_cache  # this is always a chunk cache\n        self.gloo_group = gloo_group\n        self.tp_rank = tp_rank\n        self.tp_size = tp_size\n        self.dp_size = dp_size\n        self.gpu_id = gpu_id\n        self.bootstrap_port = bootstrap_port\n        self.max_total_num_tokens = max_total_num_tokens\n        self.pp_rank = pp_rank\n        self.num_reserved_decode_tokens = num_reserved_decode_tokens\n        self.transfer_backend = transfer_backend\n        # Queue for requests pending pre-allocation\n        self.queue: List[DecodeRequest] = []\n        self.retracted_queue: List[Req] = []\n        self.pending_reqs: List[Req] = []\n        self._ensure_retry_count: Dict[str, int] = {}\n        self._max_ensure_retries: int = 20  # scheduling cycles\n        
self._ensure_last_attempt_time: Dict[str, float] = {}\n        self._ensure_retry_interval: float = 1.0  # seconds\n        self.kv_manager = self._init_kv_manager()\n\n        if self.scheduler.tp_worker.is_hybrid_swa:\n            # FIXME: current SWA allocation allocate full kv cache size in prefill\n            self.max_total_num_tokens = min(\n                self.max_total_num_tokens,\n                self.scheduler.tp_worker.model_runner.swa_max_total_num_tokens,\n            )\n\n    def _init_kv_manager(self) -> CommonKVManager:\n        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)\n        kv_args = kv_args_class()\n\n        attn_tp_size = get_attention_tp_size()\n        kv_args.engine_rank = self.tp_rank % (attn_tp_size)\n\n        kv_args.pp_rank = self.pp_rank\n        kv_args.system_dp_rank = self.scheduler.dp_rank\n        kv_data_ptrs, kv_data_lens, kv_item_lens = (\n            self.token_to_kv_pool.get_contiguous_buf_infos()\n        )\n        if self.draft_token_to_kv_pool is not None:\n            # We should also transfer draft model kv cache. 
The indices are\n            # always shared with a target model.\n            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (\n                self.draft_token_to_kv_pool.get_contiguous_buf_infos()\n            )\n            kv_data_ptrs += draft_kv_data_ptrs\n            kv_data_lens += draft_kv_data_lens\n            kv_item_lens += draft_kv_item_lens\n\n        kv_args.kv_data_ptrs = kv_data_ptrs\n        kv_args.kv_data_lens = kv_data_lens\n        kv_args.kv_item_lens = kv_item_lens\n        kv_args.page_size = self.token_to_kv_pool.page_size\n\n        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (\n            self.metadata_buffers.get_buf_infos()\n        )\n\n        if hasattr(self.token_to_kv_pool, \"get_state_buf_infos\"):\n            state_data_ptrs, state_data_lens, state_item_lens = (\n                self.token_to_kv_pool.get_state_buf_infos()\n            )\n            kv_args.state_data_ptrs = state_data_ptrs\n            kv_args.state_data_lens = state_data_lens\n            kv_args.state_item_lens = state_item_lens\n\n            if isinstance(self.token_to_kv_pool, SWAKVPool):\n                kv_args.state_type = \"swa\"\n            elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):\n                kv_args.state_type = \"mamba\"\n                # Get state dimension info for cross-TP slice transfer\n                if hasattr(self.token_to_kv_pool, \"get_state_dim_per_tensor\"):\n                    kv_args.state_dim_per_tensor = (\n                        self.token_to_kv_pool.get_state_dim_per_tensor()\n                    )\n            elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):\n                kv_args.state_type = \"nsa\"\n            else:\n                kv_args.state_type = \"none\"\n        else:\n            kv_args.state_data_ptrs = []\n            kv_args.state_data_lens = []\n            kv_args.state_item_lens = []\n            kv_args.state_type = 
\"none\"\n\n        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device\n        kv_args.gpu_id = self.scheduler.gpu_id\n        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)\n        kv_manager = kv_manager_class(\n            kv_args,\n            DisaggregationMode.DECODE,\n            self.scheduler.server_args,\n            self.is_mla_backend,\n        )\n        return kv_manager\n\n    def add(self, req: Req, is_retracted: bool = False) -> None:\n        \"\"\"Add a request to the pending queue.\"\"\"\n        if self._check_if_req_exceed_kv_capacity(req):\n            return\n\n        if is_retracted:\n            req.retraction_mb_id = None\n            self.retracted_queue.append(req)\n        else:\n            # NOTE: fake transfer does not need to resolve prefill dp rank in the pending queue\n            if _is_fake_transfer(req, self.scheduler.server_args):\n                self._create_receiver_and_enqueue(req, 0)\n                return\n\n            # Fast path: cache-only lookup, no network calls\n            prefill_dp_rank = self._resolve_prefill_dp_rank(req)\n            if prefill_dp_rank is not None:\n                self._create_receiver_and_enqueue(req, prefill_dp_rank)\n            else:\n                self.pending_reqs.append(req)\n\n    def _resolve_prefill_dp_rank(self, req: Req) -> Optional[int]:\n        if req.disagg_prefill_dp_rank is not None:\n            return req.disagg_prefill_dp_rank\n\n        prefill_info = self.kv_manager.prefill_info_table.get(_bootstrap_addr(req))\n        if prefill_info is None:\n            return None\n\n        if prefill_info.dp_size == 1:\n            return 0\n\n        if prefill_info.follow_bootstrap_room:\n            return req.bootstrap_room % prefill_info.dp_size\n\n        return None\n\n    def _create_receiver_and_enqueue(self, req: Req, prefill_dp_rank: int) -> None:\n        backend = (\n            TransferBackend.FAKE\n           
 if _is_fake_transfer(req, self.scheduler.server_args)\n            else self.transfer_backend\n        )\n        kv_receiver_class = get_kv_class(backend, KVClassType.RECEIVER)\n\n        kv_receiver = kv_receiver_class(\n            mgr=self.kv_manager,\n            bootstrap_addr=_bootstrap_addr(req),\n            bootstrap_room=req.bootstrap_room,\n            prefill_dp_rank=prefill_dp_rank,\n        )\n\n        self.queue.append(\n            DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)\n        )\n\n    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:\n        if len(req.origin_input_ids) > self.max_total_num_tokens:\n            message = f\"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}\"\n            logger.error(message)\n            prepare_abort(req, message, status_code=HTTPStatus.BAD_REQUEST)\n            self.scheduler.stream_output([req], req.return_logprob)\n            return True\n        return False\n\n    def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:\n        \"\"\"Add a request to the pending queue.\"\"\"\n        for req in reqs:\n            self.add(req, is_retracted=is_retracted)\n\n    def resume_retracted_reqs(\n        self, rids_to_check: Optional[List[str]] = None\n    ) -> List[Req]:\n        # TODO refactor the scheduling part, reuse with the unified engine logic as much as possible\n\n        # allocate memory\n        resumed_reqs = []\n        indices_to_remove = set()\n        allocatable_tokens = self._allocatable_tokens(count_retracted=False)\n\n        for i, req in enumerate(self.retracted_queue):\n            if rids_to_check is not None and req.rid not in rids_to_check:\n                continue\n\n            if self.req_to_token_pool.available_size() <= 0:\n                break\n\n            required_tokens_for_request = (\n                len(req.origin_input_ids)\n                + 
len(req.output_ids)\n                + self.num_reserved_decode_tokens\n            )\n            if required_tokens_for_request > allocatable_tokens:\n                break\n\n            resumed_reqs.append(req)\n            indices_to_remove.add(i)\n            req.is_retracted = False\n            self._pre_alloc(req)\n            allocatable_tokens -= required_tokens_for_request\n\n            # load from cpu, release the cpu copy\n            req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)\n\n        self.retracted_queue = [\n            entry\n            for i, entry in enumerate(self.retracted_queue)\n            if i not in indices_to_remove\n        ]\n\n        return resumed_reqs\n\n    def _update_handshake_waiters(\n        self, rids_to_check: Optional[List[str]] = None\n    ) -> None:\n        if not self.queue:\n            return\n\n        if all(decode_req.waiting_for_input for decode_req in self.queue):\n            return\n\n        polls = poll_and_all_reduce(\n            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group\n        )\n\n        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):\n            if rids_to_check is not None and decode_req.req.rid not in rids_to_check:\n                continue\n\n            if poll == KVPoll.Bootstrapping:\n                pass\n            elif poll == KVPoll.WaitingForInput:\n                decode_req.waiting_for_input = True\n                decode_req.req.time_stats.set_bootstrap_done_time()\n            elif poll == KVPoll.Failed:\n                error_message = f\"Decode handshake failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}\"\n                try:\n                    decode_req.kv_receiver.failure_exception()\n                except Exception as e:\n                    error_message += f\" with exception {e}\"\n                logger.error(error_message)\n                
prepare_abort(\n                    decode_req.req,\n                    error_message,\n                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR,\n                )\n                if self.scheduler.enable_metrics:\n                    self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()\n            else:\n                raise ValueError(f\"Unexpected poll case: {poll}\")\n\n    def _ensure_prefill_info(\n        self, addr_to_reqs: Dict[str, List[Req]]\n    ) -> Tuple[Dict[str, List[Req]], List[Req]]:\n        \"\"\"Non-blocking ensure parallel info for each addr.\n        Returns (ready_addrs, remaining_reqs).\"\"\"\n        ready: Dict[str, List[Req]] = {}\n        remaining: List[Req] = []\n\n        now = time.monotonic()\n        for bootstrap_addr, reqs in addr_to_reqs.items():\n            last_attempt = self._ensure_last_attempt_time.get(bootstrap_addr)\n            if last_attempt is not None and (\n                now - last_attempt < self._ensure_retry_interval\n            ):\n                remaining.extend(reqs)\n                continue\n\n            self._ensure_last_attempt_time[bootstrap_addr] = now\n\n            if self.kv_manager.try_ensure_parallel_info(bootstrap_addr):\n                if bootstrap_addr in self._ensure_retry_count:\n                    del self._ensure_retry_count[bootstrap_addr]\n                if bootstrap_addr in self._ensure_last_attempt_time:\n                    del self._ensure_last_attempt_time[bootstrap_addr]\n                ready[bootstrap_addr] = reqs\n                continue\n\n            count = self._ensure_retry_count.get(bootstrap_addr, 0) + 1\n            self._ensure_retry_count[bootstrap_addr] = count\n\n            if count >= self._max_ensure_retries:\n                error_msg = f\"Could not fetch prefill parallel info from {bootstrap_addr} after {count} attempts\"\n                logger.error(error_msg)\n                for req in reqs:\n                    
prepare_abort(\n                        req, error_msg, status_code=HTTPStatus.INTERNAL_SERVER_ERROR\n                    )\n                    if self.scheduler.enable_metrics:\n                        self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()\n                    self.scheduler.stream_output([req], req.return_logprob)\n                del self._ensure_retry_count[bootstrap_addr]\n                del self._ensure_last_attempt_time[bootstrap_addr]\n            else:\n                remaining.extend(reqs)\n\n        return ready, remaining\n\n    def _resolve_pending_reqs(self) -> None:\n        \"\"\"Batch-resolve prefill_dp_ranks for pending requests and create receivers.\"\"\"\n        if not self.pending_reqs:\n            return\n\n        # Group pending requests by bootstrap_addr\n        addr_to_reqs: Dict[str, List[Req]] = {}\n        for req in self.pending_reqs:\n            addr = _bootstrap_addr(req)\n            addr_to_reqs.setdefault(addr, []).append(req)\n\n        # Pass 1: ensure parallel info for each addr\n        ready_addrs, remaining = self._ensure_prefill_info(addr_to_reqs)\n\n        # Pass 2: resolve dp rank for addrs whose info is available\n        resolved = []\n        for bootstrap_addr, reqs in ready_addrs.items():\n            need_query: List[Req] = []\n            for req in reqs:\n                prefill_dp_rank = self._resolve_prefill_dp_rank(req)\n                if prefill_dp_rank is not None:\n                    resolved.append((req, prefill_dp_rank))\n                else:\n                    need_query.append(req)\n\n            if need_query:\n                rooms = [req.bootstrap_room for req in need_query]\n                room_to_rank = CommonKVReceiver.query_prefill_dp_ranks(\n                    bootstrap_addr, rooms\n                )\n                for req in need_query:\n                    prefill_dp_rank = room_to_rank.get(str(req.bootstrap_room))\n                    if 
prefill_dp_rank is not None:\n                        resolved.append((req, int(prefill_dp_rank)))\n                    else:\n                        remaining.append(req)\n\n        self.pending_reqs = remaining\n\n        for req, prefill_dp_rank in resolved:\n            self._create_receiver_and_enqueue(req, prefill_dp_rank)\n\n    def pop_preallocated(\n        self, rids_to_check: Optional[List[str]] = None\n    ) -> Tuple[List[DecodeRequest], List[DecodeRequest]]:\n        \"\"\"Pop the preallocated requests from the pending queue (FIFO).\"\"\"\n        self._resolve_pending_reqs()\n        self._update_handshake_waiters(rids_to_check)\n\n        failed_reqs = []\n        preallocated_reqs = []\n        indices_to_remove = set()\n\n        # We need to make sure that the sum of inflight tokens and allocatable tokens is greater than the maximum input+output length of each inflight request.\n        # Otherwise, one request could run out of memory during decode while all other requests sit in the transfer queue, where they cannot be retracted.\n        retractable_tokens = sum(\n            len(r.origin_input_ids) + len(r.output_ids)\n            for r in self.scheduler.running_batch.reqs\n        )\n        allocatable_tokens = self._allocatable_tokens(\n            retractable_tokens=retractable_tokens, count_retracted=True\n        )\n        # First, remove all failed requests from the queue\n        for i, decode_req in enumerate(self.queue):\n            if rids_to_check is not None and decode_req.req.rid not in rids_to_check:\n                continue\n            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):\n                self.scheduler.stream_output(\n                    [decode_req.req], decode_req.req.return_logprob\n                )\n                failed_reqs.append(decode_req)\n                indices_to_remove.add(i)\n\n        # Then, preallocate the remaining requests if possible\n        for i, decode_req in 
enumerate(self.queue):\n            if rids_to_check is not None and decode_req.req.rid not in rids_to_check:\n                continue\n\n            if i in indices_to_remove:\n                continue\n\n            if not decode_req.waiting_for_input:\n                continue\n\n            if self.req_to_token_pool.available_size() <= 0:\n                break\n\n            if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:\n                break\n\n            # Memory estimation: don't add if the projected memory cannot be met\n            # TODO: add new_token ratio\n            origin_input_len = len(decode_req.req.origin_input_ids)\n            required_tokens_for_request = (\n                origin_input_len + self.num_reserved_decode_tokens\n            )\n\n            if (\n                max(\n                    required_tokens_for_request,\n                    origin_input_len\n                    + min(\n                        decode_req.req.sampling_params.max_new_tokens,\n                        CLIP_MAX_NEW_TOKEN,\n                    )\n                    - retractable_tokens,\n                )\n                > allocatable_tokens\n            ):\n                break\n            if required_tokens_for_request > allocatable_tokens:\n                break\n\n            allocatable_tokens -= required_tokens_for_request\n            self._pre_alloc(decode_req.req)\n\n            kv_indices = (\n                self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][\n                    : len(decode_req.req.origin_input_ids)\n                ]\n                .cpu()\n                .numpy()\n            )\n            page_size = self.token_to_kv_pool_allocator.page_size\n\n            # Prepare extra pool indices for hybrid models\n            if isinstance(self.token_to_kv_pool, HybridLinearKVPool):\n                # Mamba hybrid model: single mamba state index\n                state_indices = [\n          
          self.req_to_token_pool.req_index_to_mamba_index_mapping[\n                        decode_req.req.req_pool_idx\n                    ]\n                    .cpu()\n                    .numpy()\n                ]\n            elif isinstance(self.token_to_kv_pool, SWAKVPool):\n                # SWA hybrid model: send decode-side SWA window indices\n                seq_len = len(decode_req.req.origin_input_ids)\n                window_size = self.scheduler.sliding_window_size\n\n                window_start = max(0, seq_len - window_size)\n                window_start = (window_start // page_size) * page_size\n                window_kv_indices_full = self.req_to_token_pool.req_to_token[\n                    decode_req.req.req_pool_idx, window_start:seq_len\n                ]\n\n                # Translate to SWA pool indices\n                window_kv_indices_swa = (\n                    self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(\n                        window_kv_indices_full\n                    )\n                )\n                state_indices = window_kv_indices_swa.cpu().numpy()\n                state_indices = kv_to_page_indices(state_indices, page_size)\n            elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):\n                seq_len = len(decode_req.req.origin_input_ids)\n                kv_indices_full = self.req_to_token_pool.req_to_token[\n                    decode_req.req.req_pool_idx, :seq_len\n                ]\n                state_indices = kv_indices_full.cpu().numpy()\n                state_indices = kv_to_page_indices(state_indices, page_size)\n            else:\n                state_indices = None\n\n            decode_req.metadata_buffer_index = (\n                self.req_to_metadata_buffer_idx_allocator.alloc()\n            )\n            assert decode_req.metadata_buffer_index is not None\n            page_indices = kv_to_page_indices(kv_indices, page_size)\n            
decode_req.kv_receiver.init(\n                page_indices, decode_req.metadata_buffer_index, state_indices\n            )\n            preallocated_reqs.append(decode_req)\n            indices_to_remove.add(i)\n            decode_req.req.time_stats.set_decode_transfer_queue_entry_time()\n\n        self.queue = [\n            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove\n        ]\n\n        return preallocated_reqs, failed_reqs\n\n    @property\n    def num_tokens_pre_allocated(self):\n        return sum(\n            len(decode_req.req.fill_ids) for decode_req in self.transfer_queue.queue\n        )\n\n    def _allocatable_tokens(\n        self, retractable_tokens: Optional[int] = None, count_retracted: bool = True\n    ) -> int:\n        need_space_for_single_req = (\n            max(\n                [\n                    min(x.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKEN)\n                    + len(x.origin_input_ids)\n                    - retractable_tokens\n                    for x in self.scheduler.running_batch.reqs\n                ]\n            )\n            if retractable_tokens is not None\n            and len(self.scheduler.running_batch.reqs) > 0\n            else 0\n        )\n        available_size = self.token_to_kv_pool_allocator.available_size()\n        allocatable_tokens = available_size - max(\n            # preserve some space for future decode\n            self.num_reserved_decode_tokens\n            * (\n                len(self.scheduler.running_batch.reqs)\n                + len(self.transfer_queue.queue)\n                + len(self.scheduler.waiting_queue)\n            ),\n            # make sure each request can finish if reach max_tokens with all other requests retracted\n            need_space_for_single_req,\n        )\n\n        # Note: if the last prebuilt extend just finishes, and we enter `pop_preallocated` immediately in the next iteration\n        #       the extend batch is not in any 
queue, so we need to explicitly reserve its token slots here\n        if (\n            self.scheduler.last_batch\n            and self.scheduler.last_batch.forward_mode.is_prebuilt()\n        ):\n            allocatable_tokens -= self.num_reserved_decode_tokens * len(\n                self.scheduler.last_batch.reqs\n            )\n\n        if count_retracted:\n            allocatable_tokens -= sum(\n                [\n                    len(req.origin_input_ids)\n                    + len(req.output_ids)\n                    + self.num_reserved_decode_tokens\n                    for req in self.retracted_queue\n                ]\n            )\n        return allocatable_tokens\n\n    def _pre_alloc(self, req: Req) -> torch.Tensor:\n        \"\"\"Pre-allocate the memory for req_to_token and token_kv_pool\"\"\"\n        req_pool_indices = self.req_to_token_pool.alloc([req])\n\n        assert (\n            req_pool_indices is not None\n        ), \"req_pool_indices is full! There is a bug in memory estimation.\"\n\n        # Allocate all tokens for the prebuilt req (except for the reserved input token for decoding)\n        fill_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)\n        req.kv_allocated_len = fill_len\n        req.kv_committed_len = fill_len\n        if self.token_to_kv_pool_allocator.page_size == 1:\n            kv_loc = self.token_to_kv_pool_allocator.alloc(fill_len)\n        else:\n            device = self.token_to_kv_pool_allocator.device\n            kv_loc = self.token_to_kv_pool_allocator.alloc_extend(\n                prefix_lens=torch.tensor([0], dtype=torch.int64, device=device),\n                prefix_lens_cpu=torch.tensor([0], dtype=torch.int64),\n                seq_lens=torch.tensor([fill_len], dtype=torch.int64, device=device),\n                seq_lens_cpu=torch.tensor([fill_len], dtype=torch.int64),\n                last_loc=torch.tensor([-1], dtype=torch.int64, device=device),\n                
extend_num_tokens=fill_len,\n            )\n\n        assert (\n            kv_loc is not None\n        ), \"KV cache is full! There is a bug in memory estimation.\"\n\n        self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)\n\n        # populate metadata\n        req.fill_ids = req.origin_input_ids + req.output_ids\n        req.set_extend_input_len(len(req.fill_ids))\n\n        return kv_loc\n\n\nclass DecodeTransferQueue:\n    \"\"\"\n    Store the requests that are polling for their KV cache transfer.\n    \"\"\"\n\n    def __init__(\n        self,\n        gloo_group: ProcessGroup,\n        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,\n        tp_rank: int,\n        metadata_buffers: MetadataBuffers,\n        scheduler: Scheduler,\n        tree_cache: BasePrefixCache,\n    ):\n        self.queue: List[DecodeRequest] = []\n        self.gloo_group = gloo_group\n        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator\n        self.tp_rank = tp_rank\n        self.metadata_buffers = metadata_buffers\n        self.scheduler = scheduler\n        self.tree_cache = tree_cache\n        self.spec_algorithm = scheduler.spec_algorithm\n\n    def add(self, decode_req: DecodeRequest) -> None:\n        self.queue.append(decode_req)\n\n    def extend(self, decode_reqs: List[DecodeRequest]) -> None:\n        self.queue.extend(decode_reqs)\n\n    def _commit_transfer_to_req(self, decode_req: DecodeRequest) -> bool:\n        \"\"\"\n        Returns:\n            True if the request should be removed from the queue (success or corruption)\n            False if the metadata is not ready yet (keep in queue for next poll)\n        \"\"\"\n        idx = decode_req.metadata_buffer_index\n        (\n            output_id,\n            cached_tokens,\n            output_token_logprobs_val,\n            output_token_logprobs_idx,\n            output_top_logprobs_val,\n            output_top_logprobs_idx,\n            output_topk_p,\n       
     output_topk_index,\n            output_hidden_states,\n            output_bootstrap_room,\n        ) = self.metadata_buffers.get_buf(idx)\n\n        # Validate bootstrap_room to detect context corruption\n        actual_room = output_bootstrap_room[0].item()\n        expected_room = (\n            decode_req.req.bootstrap_room\n            if decode_req.req.bootstrap_room is not None\n            else 0\n        )\n\n        if _is_fake_transfer(decode_req.req, self.scheduler.server_args):\n            pass\n        elif actual_room == 0:\n            # Case 1: Metadata not ready yet (actual_room == 0)\n            # Keep request in queue and wait for next poll\n            return False\n        elif actual_room != expected_room:\n            # Case 2: Real corruption detected (mismatch)\n            # Abort the request and remove from the queue\n            error_msg = (\n                f\"Context corruption detected: Request {decode_req.req.rid} \"\n                f\"(bootstrap_room={expected_room}) received metadata from \"\n                f\"bootstrap_room={actual_room}. \"\n                f\"Metadata buffer index: {idx}. 
\"\n                f\"This indicates metadata buffer index collision.\"\n            )\n            logger.error(error_msg)\n            prepare_abort(\n                decode_req.req,\n                \"Metadata corruption detected - bootstrap_room mismatch\",\n                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,\n            )\n            decode_req.kv_receiver.clear()\n            decode_req.kv_receiver = None\n            return True\n\n        # Case 3: Success - commit the transfer\n        decode_req.req.output_ids.append(output_id[0].item())\n        decode_req.req.cached_tokens = cached_tokens[0].item()\n        if not self.spec_algorithm.is_none():\n            decode_req.req.output_topk_p = output_topk_p\n            decode_req.req.output_topk_index = output_topk_index\n            decode_req.req.hidden_states_tensor = output_hidden_states\n\n        if decode_req.req.return_logprob:\n            decode_req.req.output_token_logprobs_val.append(\n                output_token_logprobs_val[0].item()\n            )\n            decode_req.req.output_token_logprobs_idx.append(\n                output_token_logprobs_idx[0].item()\n            )\n            decode_req.req.output_top_logprobs_val.append(\n                output_top_logprobs_val[: decode_req.req.top_logprobs_num].tolist()\n            )\n            decode_req.req.output_top_logprobs_idx.append(\n                output_top_logprobs_idx[: decode_req.req.top_logprobs_num].tolist()\n            )\n\n        decode_req.kv_receiver.clear()\n        decode_req.kv_receiver = None\n        decode_req.req.time_stats.set_wait_queue_entry_time()\n        return True\n\n    def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req]:\n        if not self.queue:\n            return []\n        polls = poll_and_all_reduce(\n            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group\n        )\n\n        transferred_reqs = []\n        indices_to_remove = 
set()\n        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):\n            if rids_to_check is not None and decode_req.req.rid not in rids_to_check:\n                continue\n            if poll == KVPoll.Failed:\n                error_message = f\"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}\"\n                try:\n                    decode_req.kv_receiver.failure_exception()\n                except Exception as e:\n                    error_message += f\" with exception {e}\"\n                logger.error(error_message)\n                prepare_abort(\n                    decode_req.req,\n                    error_message,\n                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR,\n                )\n                self.scheduler.stream_output(\n                    [decode_req.req], decode_req.req.return_logprob\n                )\n                # release pre-allocated kv cache, but don't insert into the tree since it's failed\n                release_kv_cache(decode_req.req, self.tree_cache, is_insert=False)\n                indices_to_remove.add(i)\n                if self.scheduler.enable_metrics:\n                    self.scheduler.metrics_collector.increment_transfer_failed_reqs()\n                continue\n            elif poll == KVPoll.Success:\n                should_remove = self._commit_transfer_to_req(decode_req)\n                if should_remove:\n                    indices_to_remove.add(i)\n                    # Check if request was aborted due to corruption\n                    if isinstance(decode_req.req.finished_reason, FINISH_ABORT):\n                        self.scheduler.stream_output(\n                            [decode_req.req], decode_req.req.return_logprob\n                        )\n                        release_kv_cache(\n                            decode_req.req, self.tree_cache, is_insert=False\n                        )\n              
          if self.scheduler.enable_metrics:\n                            self.scheduler.metrics_collector.increment_transfer_failed_reqs()\n                    else:\n                        transferred_reqs.append(decode_req.req)\n            elif poll in [\n                KVPoll.Bootstrapping,\n                KVPoll.WaitingForInput,\n                KVPoll.Transferring,\n            ]:\n                pass\n            else:\n                raise ValueError(f\"Unexpected poll case: {poll}\")\n\n        for i in indices_to_remove:\n            idx = self.queue[i].metadata_buffer_index\n            assert idx != -1\n            self.req_to_metadata_buffer_idx_allocator.free(idx)\n\n        self.queue = [\n            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove\n        ]\n\n        return transferred_reqs\n\n\nclass SchedulerDisaggregationDecodeMixin:\n\n    @torch.no_grad()\n    def event_loop_normal_disagg_decode(self: Scheduler):\n        \"\"\"A normal scheduler loop for decode worker in disaggregation mode.\"\"\"\n\n        while True:\n            # Receive requests\n            recv_reqs = self.recv_requests()\n            self.process_input_requests(recv_reqs)\n            # polling and allocating kv cache\n            self.process_decode_queue()\n\n            # Get the next batch to run\n            batch = self.get_next_disagg_decode_batch_to_run()\n            self.cur_batch = batch\n\n            # Launch the current batch\n            if batch:\n                result = self.run_batch(batch)\n                self.process_batch_result(batch, result)\n            else:\n                # When the server is idle, do self-check and re-init some states\n                self.self_check_during_idle()\n\n            # Update last_batch\n            self.last_batch = batch\n\n    @torch.no_grad()\n    def event_loop_overlap_disagg_decode(self: Scheduler):\n        self.result_queue = deque()\n        self.last_batch: 
Optional[ScheduleBatch] = None\n\n        while True:\n            # Receive requests\n            recv_reqs = self.recv_requests()\n            self.process_input_requests(recv_reqs)\n            # polling and allocating kv cache\n            self.process_decode_queue()\n\n            # Get the next batch to run\n            batch = self.get_next_disagg_decode_batch_to_run()\n            self.cur_batch = batch\n\n            # Launch the current batch\n            if batch:\n                batch_result = self.run_batch(batch)\n                self.result_queue.append((batch.copy(), batch_result))\n            else:\n                batch_result = None\n\n            # Process the last batch\n            if self.last_batch:\n                tmp_batch, tmp_result = self.result_queue.popleft()\n                self.process_batch_result(tmp_batch, tmp_result)\n            elif batch is None:\n                self.self_check_during_idle()\n\n            # Run sample of the current batch\n            # It depends on the result of the last batch (e.g., grammar), so we run it after the last batch is processed.\n            self.launch_batch_sample_if_needed(batch_result)\n\n            # Update last_batch\n            self.last_batch = batch\n\n    def _run_batch_prebuilt(\n        self: Scheduler, batch: ScheduleBatch\n    ) -> GenerationBatchResult:\n        if batch.inner_idle_batch is not None:\n            idle_batch = batch.inner_idle_batch\n            # Reset the inner idle batch to avoid reusing it.\n            batch.inner_idle_batch = None\n            return self.run_batch(idle_batch)\n\n        return GenerationBatchResult()\n\n    def get_next_disagg_decode_batch_to_run(\n        self: Scheduler,\n    ) -> Optional[ScheduleBatch]:\n        \"\"\"Process prebuilt batch and schedule the next decode batch.\"\"\"\n        # Process pending prebuilt batch: output processing + filter + merge\n        new_prebuilt_batch = self.get_new_prebuilt_batch()\n        if 
new_prebuilt_batch:\n            assert self.chunked_req is None\n            self.process_batch_result_prebuilt(new_prebuilt_batch)\n            new_prebuilt_batch.filter_batch()\n            if not new_prebuilt_batch.is_empty():\n                if self.running_batch.is_empty():\n                    self.running_batch = new_prebuilt_batch\n                else:\n                    self.running_batch.merge_batch(new_prebuilt_batch)\n\n        # Schedule decode batch\n        if self.running_batch.is_empty():\n            ret = None\n        else:\n            self.running_batch = self.update_running_batch(self.running_batch)\n            ret = self.running_batch if not self.running_batch.is_empty() else None\n\n        ret = self.maybe_prepare_mlp_sync_batch(ret)\n        if ret:\n            set_schedule_time_batch(ret)\n        return ret\n\n    def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:\n        \"\"\"Create a ScheduleBatch for the fake completed prefill.\"\"\"\n        if self.grammar_manager.has_waiting_grammars():\n            ready_grammar_requests = self.grammar_manager.get_ready_grammar_requests()\n            for req in ready_grammar_requests:\n                self._add_request_to_queue(req)\n\n        if len(self.waiting_queue) == 0:\n            return None\n\n        curr_batch_size = self.running_batch.batch_size()\n\n        batch_size = min(self.req_to_token_pool.size, self.max_running_requests)\n\n        num_not_used_batch = batch_size - curr_batch_size\n\n        # pop req from waiting queue\n        can_run_list: List[Req] = []\n        waiting_queue: List[Req] = []\n\n        for i in range(len(self.waiting_queue)):\n            req = self.waiting_queue[i]\n            # we can add at most `num_not_used_batch` new requests to the running batch\n            if i < num_not_used_batch:\n                can_run_list.append(req)\n                req.init_next_round_input(self.tree_cache)\n            else:\n              
  waiting_queue.append(req)\n\n        self.waiting_queue = waiting_queue\n        if len(can_run_list) == 0:\n            return None\n\n        set_time_batch(can_run_list, \"set_forward_entry_time\")\n\n        # construct a schedule batch with those requests and mark as decode\n        new_batch = ScheduleBatch.init_new(\n            can_run_list,\n            self.req_to_token_pool,\n            self.token_to_kv_pool_allocator,\n            self.tree_cache,\n            self.model_config,\n            self.enable_overlap,\n            self.spec_algorithm,\n        )\n\n        # construct fake completed prefill\n        new_batch.prepare_for_prebuilt()\n        new_batch.process_prebuilt(self.server_args, self.future_map)\n\n        return new_batch\n\n    def process_decode_queue(self: Scheduler):\n        if self.server_args.disaggregation_decode_enable_offload_kvcache:\n            self.decode_offload_manager.check_offload_progress()\n\n        # try to resume retracted requests if there is enough space for another `num_reserved_decode_tokens` decode steps\n        resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()\n        self.waiting_queue.extend(resumed_reqs)\n        if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:\n            # if there are still retracted requests, we do not allocate new requests\n            return\n\n        if not hasattr(self, \"polling_count\"):\n            self.polling_count = 0\n            self.polling_interval = (\n                self.server_args.disaggregation_decode_polling_interval\n            )\n\n        self.polling_count = (self.polling_count + 1) % self.polling_interval\n\n        if self.polling_count % self.polling_interval == 0:\n            req_conns, _ = self.disagg_decode_prealloc_queue.pop_preallocated()\n            self.disagg_decode_transfer_queue.extend(req_conns)\n            transferred_reqs = (\n                
self.disagg_decode_transfer_queue.pop_transferred()\n            )  # requests whose KV cache has arrived\n            self.waiting_queue.extend(transferred_reqs)\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py",
    "content": "from __future__ import annotations\n\nimport json\nimport logging\nimport threading\nimport time\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.srt.disaggregation.kv_events import OffloadedState\nfrom sglang.srt.environ import envs\nfrom sglang.srt.managers.cache_controller import HiCacheController\nfrom sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator\nfrom sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache\nfrom sglang.srt.mem_cache.memory_pool import (\n    MHATokenToKVPool,\n    MLATokenToKVPool,\n    ReqToTokenPool,\n)\nfrom sglang.srt.mem_cache.memory_pool_host import (\n    MHATokenToKVPoolHost,\n    MLATokenToKVPoolHost,\n)\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils.common import ceil_align\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.schedule_batch import Req\n\nlogger = logging.getLogger(__name__)\n\n\nclass DecodeKVCacheOffloadManager:\n    \"\"\"Manage decode-side KV cache offloading lifecycle and operations.\"\"\"\n\n    def __init__(\n        self,\n        req_to_token_pool: ReqToTokenPool,\n        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,\n        tp_group: torch.distributed.ProcessGroup,\n        tree_cache: BasePrefixCache,\n        server_args: ServerArgs,\n    ) -> None:\n        self.req_to_token_pool = req_to_token_pool\n        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator\n        self.page_size = server_args.page_size\n        self.server_args = server_args\n        self.request_counter = 0\n        self.tree_cache = tree_cache\n        env_stride = envs.SGLANG_HICACHE_DECODE_OFFLOAD_STRIDE.get()\n        if env_stride is None or env_stride <= 0:\n            self.offload_stride = self.page_size\n        else:\n            self.offload_stride = max(\n                self.page_size, (env_stride // self.page_size) * self.page_size\n            )\n        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()\n        
if isinstance(kv_cache, MHATokenToKVPool):\n            self.decode_host_mem_pool = MHATokenToKVPoolHost(\n                kv_cache,\n                server_args.hicache_ratio,\n                server_args.hicache_size,\n                self.page_size,\n                server_args.hicache_mem_layout,\n            )\n        elif isinstance(kv_cache, MLATokenToKVPool):\n            self.decode_host_mem_pool = MLATokenToKVPoolHost(\n                kv_cache,\n                server_args.hicache_ratio,\n                server_args.hicache_size,\n                self.page_size,\n                server_args.hicache_mem_layout,\n            )\n        else:\n            raise ValueError(\"Unsupported KV cache type for decode offload\")\n\n        self.tp_group = tp_group\n        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)\n\n        hicache_storage_backend_extra_config = {}\n        if server_args.hicache_storage_backend_extra_config:\n            try:\n                hicache_storage_backend_extra_config = json.loads(\n                    server_args.hicache_storage_backend_extra_config\n                )\n            except json.JSONDecodeError as e:\n                raise ValueError(\n                    f\"Invalid hicache storage backend extra config JSON: {e}\"\n                )\n\n        self.cache_controller = HiCacheController(\n            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,\n            mem_pool_host=self.decode_host_mem_pool,\n            page_size=self.page_size,\n            tp_group=tp_group,\n            io_backend=server_args.hicache_io_backend,\n            load_cache_event=threading.Event(),\n            storage_backend=server_args.hicache_storage_backend,\n            model_name=server_args.served_model_name,\n            storage_backend_extra_config=hicache_storage_backend_extra_config,\n        )\n\n        self.ongoing_offload = {}\n        self.ongoing_backup = {}\n        
self.offloaded_state = {}\n        logger.info(\"Enable offload kv cache for decode side\")\n\n    def offload_kv_cache(self, req) -> bool:\n        \"\"\"Offload incremental KV cache for decode side.\"\"\"\n\n        if self.cache_controller is None or self.decode_host_mem_pool is None:\n            return False\n\n        if req.req_pool_idx == -1 or len(req.output_ids) == 0:\n            return False\n\n        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]\n        if token_indices.dim() == 0 or token_indices.numel() == 0:\n            return False\n\n        # Prefill side offloads page-aligned origin_input_ids, decode side offloads the incremental part\n        all_tokens = req.origin_input_ids + req.output_ids[:-1]\n        prefill_offloaded_len = (\n            len(req.origin_input_ids) // self.page_size * self.page_size\n        )\n        state = self.offloaded_state.get(req.rid)\n        if state is None:\n            prefill_hashes = self._compute_prefix_hash(\n                req.origin_input_ids[:prefill_offloaded_len]\n            )\n            last_prefill_hash = (\n                prefill_hashes[-1] if prefill_offloaded_len > 0 else None\n            )\n            state = OffloadedState(\n                prefill_len=prefill_offloaded_len,\n                inc_len=0,\n                last_hash=last_prefill_hash,\n            )\n            self.offloaded_state[req.rid] = state\n        incremental_total = len(all_tokens) - state.prefill_len\n        incremental_new = incremental_total - state.inc_len\n        incremental_aligned_len = (\n            incremental_new // self.offload_stride * self.offload_stride\n        )\n\n        if incremental_aligned_len == 0:\n            return False\n\n        # Extract incremental tokens and indices for the newly available chunk\n        start = state.prefill_len + state.inc_len\n        end = start + incremental_aligned_len\n        incremental_tokens = all_tokens[start:end]\n        
incremental_indices = token_indices[start:end]\n\n        # Early free prefill-offloaded GPU memory\n        if state.prefill_len > 0 and state.inc_len == 0:\n            self.token_to_kv_pool_allocator.free(token_indices[: state.prefill_len])\n\n        # Asynchronously offload incremental KV cache from device to host\n        self.request_counter += 1\n        ack_id = self.request_counter\n        host_indices = self.cache_controller.write(\n            device_indices=incremental_indices.long(),\n            node_id=ack_id,\n        )\n        if host_indices is None:\n            logger.error(f\"Not enough host memory for request {req.rid}\")\n            return False\n\n        self.ongoing_offload[ack_id] = (\n            req,\n            host_indices,\n            incremental_tokens,\n            time.time(),\n            start,\n            end,\n        )\n        state.inc_len += incremental_aligned_len\n        return True\n\n    def check_offload_progress(self):\n        \"\"\"Check the progress of offload from device to host and backup from host to storage.\"\"\"\n        cc = self.cache_controller\n\n        qsizes = torch.tensor(\n            [\n                len(cc.ack_write_queue),\n                cc.ack_backup_queue.qsize(),\n            ],\n            dtype=torch.int,\n        )\n        if self.tp_world_size > 1:\n            torch.distributed.all_reduce(\n                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group\n            )\n\n        n_write, n_backup = map(int, qsizes.tolist())\n        self._check_offload_progress(n_write)\n        self._check_backup_progress(n_backup)\n\n    def _check_offload_progress(self, finish_count):\n        \"\"\"Check the progress of offload from device to host.\"\"\"\n        while finish_count > 0:\n            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)\n            finish_event.synchronize()\n            for ack_id in ack_list:\n                (\n       
             req,\n                    host_indices,\n                    incremental_tokens,\n                    start_time,\n                    start,\n                    end,\n                ) = self.ongoing_offload.pop(ack_id)\n\n                if req.finished():\n                    self._release_finished_req(req, start)\n                else:\n                    kv_indices = self.req_to_token_pool.req_to_token[\n                        req.req_pool_idx, start:end\n                    ]\n                    self.token_to_kv_pool_allocator.free(kv_indices)\n\n                prior_hash = (\n                    self.offloaded_state[req.rid].last_hash\n                    if req.rid in self.offloaded_state\n                    else None\n                )\n                last_hash = self._trigger_backup(\n                    req, host_indices, incremental_tokens, start_time, prior_hash\n                )\n                if req.rid in self.offloaded_state:\n                    self.offloaded_state[req.rid].last_hash = last_hash\n            finish_count -= 1\n\n    def _release_finished_req(self, req: Req, start_offset: int):\n        kv_committed_len = req.pop_committed_kv_cache()\n        start = start_offset\n        end = kv_committed_len\n        # Free the incremental part of the request (NSA-aware)\n        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, start:end]\n        self.token_to_kv_pool_allocator.free(kv_indices)\n\n        # Free over-allocated KV cache slots (e.g. 
from speculative decoding v2).\n        # Without spec v2, start_p == end_p so this is a no-op.\n        start_p, end_p = req.pop_overallocated_kv_cache()\n        if self.page_size > 1:\n            start_p = ceil_align(start_p, self.page_size)\n        if start_p < end_p:\n            overalloc_indices = self.req_to_token_pool.req_to_token[\n                req.req_pool_idx, start_p:end_p\n            ]\n            self.token_to_kv_pool_allocator.free(overalloc_indices)\n\n        self.req_to_token_pool.free(req)\n        self.tree_cache.protected_size_ -= len(req.prefix_indices)\n        if req.rid in self.offloaded_state:\n            del self.offloaded_state[req.rid]\n\n    def _check_backup_progress(self, finish_count):\n        \"\"\"Check the progress of backup from host to storage.\"\"\"\n        for _ in range(finish_count):\n            storage_operation = self.cache_controller.ack_backup_queue.get()\n            ack_id = storage_operation.id\n            req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)\n\n            # Release host memory\n            self.decode_host_mem_pool.free(host_indices)\n\n            logger.debug(\n                f\"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds.\"\n            )\n\n    def _trigger_backup(\n        self, req, host_indices, incremental_tokens, start_time, prior_hash\n    ):\n        \"\"\"Trigger async backup from host to storage.\"\"\"\n        page_hashes = self._compute_prefix_hash(incremental_tokens, prior_hash)\n        ack_id = self.cache_controller.write_storage(\n            host_indices,\n            incremental_tokens,\n            hash_value=page_hashes,\n        )\n        self.ongoing_backup[ack_id] = (req.rid, host_indices, start_time)\n        return page_hashes[-1] if len(page_hashes) > 0 else prior_hash\n\n    def _compute_prefix_hash(self, tokens, prior_hash=\"\"):\n        page_hashes = []\n    
    last_hash = prior_hash\n        for offset in range(0, len(tokens), self.page_size):\n            page_tokens = tokens[offset : offset + self.page_size]\n            last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)\n            page_hashes.append(last_hash)\n        return page_hashes\n\n    def finalize_release_on_finish(self, req: Req):\n        \"\"\"Free any remaining tail KV that was not offloaded due to non-aligned length.\"\"\"\n        if req.req_pool_idx == -1:\n            return\n        state = self.offloaded_state.get(req.rid)\n        if state is None:\n            prefill_len = len(req.origin_input_ids) // self.page_size * self.page_size\n            inc_len = 0\n        else:\n            prefill_len = state.prefill_len\n            inc_len = state.inc_len\n        # If no incremental offload ever happened, the prefill-aligned part was never freed.\n        # Free the prefill portion on request finish to avoid leaks.\n        if prefill_len > 0 and inc_len == 0:\n            token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]\n            self.token_to_kv_pool_allocator.free(token_indices[:prefill_len])\n            logger.info(\n                f\"Finalize release: freed prefill-aligned KV for req {req.rid}, len:{prefill_len}\"\n            )\n        start_offset = prefill_len + inc_len\n        self._release_finished_req(req, start_offset)\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py",
    "content": "from __future__ import annotations\n\nimport logging\nfrom http import HTTPStatus\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode\nfrom sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo\n\nlogger = logging.getLogger(__name__)\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.overlap_utils import FutureMap\n    from sglang.srt.managers.schedule_batch import ScheduleBatch\n    from sglang.srt.server_args import ServerArgs\n\n\nclass ScheduleBatchDisaggregationDecodeMixin:\n\n    def prepare_for_prebuilt(self: ScheduleBatch):\n        \"\"\"\n        Prepare a prebuilt extend by populating metadata.\n        Adapted from .prepare_for_extend().\n        \"\"\"\n\n        self.forward_mode = ForwardMode.PREBUILT\n        reqs = self.reqs\n        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]\n        extend_num_tokens = sum(len(ids) for ids in input_ids)\n        seq_lens = []\n        pre_lens = []\n        req_pool_indices = []\n\n        # Pre-calculate total size\n        total_size = sum(req.extend_input_len for req in reqs)\n        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)\n\n        # Fill the tensor in one pass\n        offset = 0\n        for i, req in enumerate(reqs):\n            req_pool_indices.append(req.req_pool_idx)\n\n            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][\n                : req.extend_input_len\n            ]\n            assert (\n                offset + req.extend_input_len <= total_size\n            ), f\"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}\"\n            out_cache_loc[offset : offset + req.extend_input_len] = chunk\n            offset += req.extend_input_len\n\n            pre_len = len(req.prefix_indices)\n            seq_len = len(req.origin_input_ids) + max(0, 
len(req.output_ids) - 1)\n            seq_lens.append(seq_len)\n            if len(req.output_ids) == 0:\n                assert (\n                    seq_len - pre_len == req.extend_input_len\n                ), f\"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}\"\n\n            if not req.retracted_stain:\n                req.cached_tokens += pre_len - req.already_computed\n                req.already_computed = seq_len\n            req.is_retracted = False\n            pre_lens.append(pre_len)\n            req.extend_logprob_start_len = 0\n\n        extend_input_logprob_token_ids = None\n\n        # Set fields\n        self.input_ids = torch.tensor(\n            sum(input_ids, []), dtype=torch.int32, device=self.device\n        )\n        self.req_pool_indices = torch.tensor(\n            req_pool_indices, dtype=torch.int64, device=self.device\n        )\n        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)\n        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)\n        self.orig_seq_lens = torch.tensor(\n            seq_lens, dtype=torch.int32, device=self.device\n        )\n        self.out_cache_loc = out_cache_loc\n        self.seq_lens_sum = sum(seq_lens)\n\n        if self.return_logprob:\n            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]\n            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]\n\n        self.extend_num_tokens = extend_num_tokens\n        self.prefix_lens = [len(r.prefix_indices) for r in reqs]\n        self.extend_lens = [r.extend_input_len for r in reqs]\n        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]\n        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids\n        self.multimodal_inputs = [r.multimodal_inputs for r in reqs]\n\n        # Build sampling info\n        self.sampling_info = SamplingBatchInfo.from_schedule_batch(\n            self,\n            
self.model_config.vocab_size,\n        )\n\n    def process_prebuilt(\n        self: ScheduleBatch,\n        server_args: ServerArgs,\n        future_map: FutureMap,\n    ):\n        \"\"\"Assign the buffered last input id to schedule batch\"\"\"\n        self.output_ids = []\n        for req in self.reqs:\n            self.output_ids.append(req.output_ids[-1])\n            self.tree_cache.cache_unfinished_req(req)\n            if req.grammar is not None:\n                # FIXME: this try-except block is for handling unexpected xgrammar issue.\n                try:\n                    # if it is not None, then the grammar is from a retracted request, and we should not\n                    # accept the token as it's already accepted\n                    if req.grammar.current_token is None:\n                        req.grammar.accept_token(req.output_ids[-1])\n                except ValueError as e:\n                    from sglang.srt.managers.schedule_batch import FINISH_ABORT\n\n                    # Grammar accept_token can raise ValueError if the token is not in the grammar.\n                    # This can happen if the grammar is not set correctly or the token is invalid.\n                    # Use to_finish (not finished_reason) so that process_batch_result_prebuilt\n                    # handles the release via check_finished -> release_kv_cache in one place.\n                    error_message = f\"Grammar accept_token failed for req {req.rid} with token {req.output_ids[-1]}: {e}\"\n                    req.to_finish = FINISH_ABORT(\n                        error_message, HTTPStatus.INTERNAL_SERVER_ERROR\n                    )\n                req.grammar.finished = req.finished()\n        self.output_ids = torch.tensor(self.output_ids, device=self.device)\n\n        # Simulate the eagle run.\n        if self.spec_algorithm.is_eagle():\n            num_states = server_args.speculative_eagle_topk\n            if server_args.enable_multi_layer_eagle:\n        
        num_states *= server_args.speculative_num_steps\n            topk_p = torch.stack(\n                [\n                    torch.as_tensor(\n                        req.output_topk_p[:num_states],\n                        device=self.device,\n                        dtype=torch.float32,\n                    )\n                    for req in self.reqs\n                ],\n                dim=0,\n            )\n            topk_index = torch.stack(\n                [\n                    torch.as_tensor(\n                        req.output_topk_index[:num_states],\n                        device=self.device,\n                        dtype=torch.int64,\n                    )\n                    for req in self.reqs\n                ],\n                dim=0,\n            )\n\n            hidden_states_list = [req.hidden_states_tensor for req in self.reqs]\n            hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)\n\n            # local import to avoid circular import\n            from sglang.srt.speculative.eagle_info import EagleDraftInput\n\n            spec_info = EagleDraftInput(\n                topk_p=topk_p,\n                topk_index=topk_index,\n                hidden_states=hidden_states,\n                verified_id=self.output_ids,\n                new_seq_lens=self.seq_lens,\n            )\n            spec_info.prepare_for_extend(self)\n            spec_info.capture_hidden_mode = CaptureHiddenMode.LAST\n            if self.enable_overlap:\n                spec_info.future_indices = future_map.alloc_future_indices(\n                    len(self.seq_lens)\n                )\n                future_map.store_to_map_for_new_batch(\n                    spec_info.future_indices, spec_info\n                )\n            self.spec_info = spec_info\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/encode_grpc_server.py",
    "content": "\"\"\"\ngRPC Encoder Server for SGLang EPD (Encode-Prefill-Decode) mode.\n\nThis server provides gRPC-based encoding for multimodal inputs.\n\nUsage:\n    python -m sglang.launch_server --model-path <model> --encoder-only --grpc-mode\n\"\"\"\n\nimport asyncio\nimport logging\nimport multiprocessing as mp\nimport traceback\nfrom concurrent import futures\nfrom typing import List\n\nimport grpc\nimport zmq\nimport zmq.asyncio\nfrom grpc_health.v1 import health_pb2, health_pb2_grpc\nfrom grpc_reflection.v1alpha import reflection\nfrom smg_grpc_proto import sglang_encoder_pb2, sglang_encoder_pb2_grpc\n\nfrom sglang.srt.disaggregation.encode_server import (\n    MMEncoder,\n    handle_scheduler_receive_url_request,\n    launch_encoder,\n)\nfrom sglang.srt.managers.schedule_batch import Modality\nfrom sglang.srt.server_args import PortArgs, ServerArgs\nfrom sglang.srt.utils import random_uuid\nfrom sglang.srt.utils.network import NetworkAddress, get_zmq_socket\n\nlogger = logging.getLogger(__name__)\nSGLangEncoderServicer = sglang_encoder_pb2_grpc.SglangEncoderServicer\nadd_SGLangEncoderServicer_to_server = (\n    sglang_encoder_pb2_grpc.add_SglangEncoderServicer_to_server\n)\n\n\nclass EncoderHealthServicer(health_pb2_grpc.HealthServicer):\n    \"\"\"\n    Standard gRPC health check service for encoder server.\n    Implements grpc.health.v1.Health for Kubernetes probes.\n    \"\"\"\n\n    OVERALL_SERVER = \"\"\n    ENCODER_SERVICE = \"sglang.grpc.encoder.SglangEncoder\"\n\n    def __init__(self):\n        self._serving = False\n\n    def set_serving(self):\n        self._serving = True\n\n    def set_not_serving(self):\n        self._serving = False\n\n    async def Check(self, request, context) -> health_pb2.HealthCheckResponse:\n        if self._serving:\n            return health_pb2.HealthCheckResponse(\n                status=health_pb2.HealthCheckResponse.SERVING\n            )\n        return health_pb2.HealthCheckResponse(\n            
status=health_pb2.HealthCheckResponse.NOT_SERVING\n        )\n\n    async def Watch(self, request, context):\n        yield await self.Check(request, context)\n\n\nclass SGLangEncoderServer(SGLangEncoderServicer):\n    \"\"\"\n    gRPC service implementation for SGLang encoder.\n    \"\"\"\n\n    def __init__(\n        self,\n        encoder: MMEncoder,\n        send_sockets: List[zmq.Socket],\n        server_args: ServerArgs,\n    ):\n        self.encoder = encoder\n        self.send_sockets = send_sockets\n        self.server_args = server_args\n\n    async def Encode(\n        self, request: sglang_encoder_pb2.EncodeRequest, context\n    ) -> sglang_encoder_pb2.EncodeResponse:\n        try:\n            request_dict = {\n                \"mm_items\": list(request.mm_items),\n                \"req_id\": request.req_id,\n                \"num_parts\": request.num_parts,\n                \"part_idx\": request.part_idx,\n            }\n            for socket in self.send_sockets:\n                await socket.send_pyobj(request_dict)\n\n            # gRPC encode is image-only; encoder.encode() requires modality\n            (\n                nbytes,\n                embedding_len,\n                embedding_dim,\n                error_msg,\n                error_code,\n            ) = await self.encoder.encode(\n                mm_items=list(request.mm_items),\n                modality=Modality.IMAGE,\n                req_id=request.req_id,\n                num_parts=request.num_parts,\n                part_idx=request.part_idx,\n            )\n            if error_msg is not None:\n                context.set_code(grpc.StatusCode.INTERNAL)\n                context.set_details(error_msg)\n                return sglang_encoder_pb2.EncodeResponse()\n\n            if self.server_args.encoder_transfer_backend == \"mooncake\":\n                return sglang_encoder_pb2.EncodeResponse(\n                    embedding_size=nbytes,\n                    
embedding_len=embedding_len,\n                    embedding_dim=embedding_dim,\n                )\n            elif self.server_args.encoder_transfer_backend == \"zmq_to_scheduler\":\n                embedding_ports = list(request.embedding_port)\n                logger.info(f\"embedding_port = {embedding_ports}\")\n                if not embedding_ports:\n                    await self.encoder.send_with_url(req_id=request.req_id)\n                else:\n                    tasks = []\n                    for embedding_port in embedding_ports:\n                        tasks.append(\n                            self.encoder.send(\n                                req_id=request.req_id,\n                                prefill_host=request.prefill_host,\n                                embedding_port=embedding_port,\n                            )\n                        )\n                    await asyncio.gather(*tasks)\n                    self.encoder.embedding_to_send.pop(request.req_id, None)\n                return sglang_encoder_pb2.EncodeResponse()\n            elif self.server_args.encoder_transfer_backend == \"zmq_to_tokenizer\":\n                embedding_port = (\n                    request.embedding_port[0] if request.embedding_port else 0\n                )\n                await self.encoder.send(\n                    req_id=request.req_id,\n                    prefill_host=request.prefill_host,\n                    embedding_port=embedding_port,\n                )\n                self.encoder.embedding_to_send.pop(request.req_id, None)\n                return sglang_encoder_pb2.EncodeResponse()\n\n            return sglang_encoder_pb2.EncodeResponse()\n\n        except Exception as e:\n            logger.error(f\"Encode error: {e}\")\n            traceback.print_exc()\n            context.set_code(grpc.StatusCode.INTERNAL)\n            context.set_details(str(e))\n            return sglang_encoder_pb2.EncodeResponse()\n\n    async def Send(\n        
self, request: sglang_encoder_pb2.SendRequest, context\n    ) -> sglang_encoder_pb2.SendResponse:\n        try:\n            await self.encoder.send(\n                req_id=request.req_id,\n                prefill_host=request.prefill_host,\n                embedding_port=request.embedding_port,\n                session_id=request.session_id if request.session_id else None,\n                buffer_address=(\n                    request.buffer_address if request.buffer_address else None\n                ),\n            )\n            self.encoder.embedding_to_send.pop(request.req_id, None)\n            return sglang_encoder_pb2.SendResponse()\n\n        except Exception as e:\n            logger.error(f\"Send error: {e}\")\n            traceback.print_exc()\n            context.set_code(grpc.StatusCode.INTERNAL)\n            context.set_details(str(e))\n            return sglang_encoder_pb2.SendResponse()\n\n    async def SchedulerReceiveUrl(\n        self, request: sglang_encoder_pb2.SchedulerReceiveUrlRequest, context\n    ) -> sglang_encoder_pb2.SchedulerReceiveUrlResponse:\n        try:\n            await handle_scheduler_receive_url_request(\n                {\n                    \"req_id\": request.req_id,\n                    \"receive_count\": request.receive_count,\n                    \"receive_url\": request.receive_url,\n                }\n            )\n            return sglang_encoder_pb2.SchedulerReceiveUrlResponse()\n\n        except Exception as e:\n            logger.error(f\"SchedulerReceiveUrl error: {e}\")\n            traceback.print_exc()\n            context.set_code(grpc.StatusCode.INTERNAL)\n            context.set_details(str(e))\n            return sglang_encoder_pb2.SchedulerReceiveUrlResponse()\n\n\nasync def serve_grpc_encoder(server_args: ServerArgs):\n    ctx = mp.get_context(\"spawn\")\n    zmq_ctx = zmq.asyncio.Context(10)\n    ipc_path_prefix = random_uuid()\n    port_args = PortArgs.init_new(server_args)\n\n    if 
server_args.dist_init_addr:\n        na = NetworkAddress.parse(server_args.dist_init_addr)\n        dist_init_method = na.to_tcp()\n    else:\n        dist_init_method = NetworkAddress(\n            server_args.host or \"127.0.0.1\", port_args.nccl_port\n        ).to_tcp()\n\n    send_sockets: List[zmq.Socket] = []\n    for rank in range(1, server_args.tp_size):\n        schedule_path = f\"ipc:///tmp/{ipc_path_prefix}_schedule_{rank}\"\n        send_sockets.append(\n            get_zmq_socket(zmq_ctx, zmq.PUSH, schedule_path, bind=False)\n        )\n        ctx.Process(\n            target=launch_encoder,\n            args=(server_args, schedule_path, dist_init_method, rank),\n            daemon=True,\n        ).start()\n\n    encoder = MMEncoder(server_args, dist_init_method=dist_init_method)\n\n    server = grpc.aio.server(\n        futures.ThreadPoolExecutor(max_workers=10),\n        options=[\n            (\"grpc.max_send_message_length\", 1024 * 1024 * 256),\n            (\"grpc.max_receive_message_length\", 1024 * 1024 * 256),\n        ],\n    )\n\n    health_servicer = EncoderHealthServicer()\n    health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)\n\n    encoder_servicer = SGLangEncoderServer(\n        encoder=encoder,\n        send_sockets=send_sockets,\n        server_args=server_args,\n    )\n    add_SGLangEncoderServicer_to_server(encoder_servicer, server)\n\n    SERVICE_NAMES = (\n        sglang_encoder_pb2.DESCRIPTOR.services_by_name[\"SglangEncoder\"].full_name,\n        \"grpc.health.v1.Health\",\n        reflection.SERVICE_NAME,\n    )\n    reflection.enable_server_reflection(SERVICE_NAMES, server)\n\n    listen_addr = f\"{server_args.host}:{server_args.port}\"\n    server.add_insecure_port(listen_addr)\n\n    await server.start()\n    logger.info(f\"gRPC encoder server listening on {listen_addr}\")\n\n    health_servicer.set_serving()\n\n    try:\n        await server.wait_for_termination()\n    except KeyboardInterrupt:\n        
logger.info(\"Shutting down gRPC encoder server...\")\n        health_servicer.set_not_serving()\n        await server.stop(grace=5)\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/encode_receiver.py",
    "content": "import asyncio\nimport itertools\nimport logging\nimport pickle\nimport random\nimport threading\nimport time\nimport uuid\nfrom abc import ABC, abstractmethod\nfrom collections import OrderedDict, defaultdict\nfrom enum import IntEnum\nfrom http import HTTPStatus\nfrom typing import TYPE_CHECKING, Dict, List, Optional\n\nimport aiohttp\nimport torch\nimport zmq\nimport zmq.asyncio\nfrom transformers import PretrainedConfig\n\nfrom sglang.srt.distributed.parallel_state import (\n    GroupCoordinator,\n    get_mooncake_transfer_engine,\n)\nfrom sglang.srt.environ import envs\nfrom sglang.srt.managers.io_struct import GenerateReqInput, TokenizedGenerateReqInput\nfrom sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors\nfrom sglang.srt.managers.schedule_batch import Modality, Req\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils import ImageData\nfrom sglang.srt.utils.hf_transformers_utils import get_processor\nfrom sglang.srt.utils.network import get_local_ip_auto, get_zmq_socket_on_host\n\nlogger = logging.getLogger(__name__)\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.scheduler import Scheduler\n\n\ndef _grpc_target(url: str) -> str:\n    if url.startswith(\"grpc://\"):\n        return url[len(\"grpc://\") :]\n    if url.startswith(\"grpcs://\"):\n        raise ValueError(\"grpcs:// is not supported; use grpc://\")\n    return url\n\n\ndef _normalize_embedding_ports(embedding_port):\n    if embedding_port is None:\n        return []\n    if isinstance(embedding_port, list):\n        return embedding_port\n    return [embedding_port]\n\n\ndef _grpc_scheduler_receive_url(target, req_id, receive_url, receive_count):\n    import grpc\n    from smg_grpc_proto import sglang_encoder_pb2, sglang_encoder_pb2_grpc\n\n    timeout_secs = envs.SGLANG_ENCODER_GRPC_TIMEOUT_SECS.get()\n    channel = grpc.insecure_channel(target)\n    stub = sglang_encoder_pb2_grpc.SglangEncoderStub(channel)\n    try:\n      
  stub.SchedulerReceiveUrl(\n            sglang_encoder_pb2.SchedulerReceiveUrlRequest(\n                req_id=req_id,\n                receive_url=receive_url,\n                receive_count=receive_count,\n            ),\n            timeout=timeout_secs,\n        )\n    finally:\n        channel.close()\n\n\ndef _grpc_encode_request(target, encode_request):\n    import grpc\n    from smg_grpc_proto import sglang_encoder_pb2, sglang_encoder_pb2_grpc\n\n    timeout_secs = envs.SGLANG_ENCODER_GRPC_TIMEOUT_SECS.get()\n    channel = grpc.insecure_channel(target)\n    stub = sglang_encoder_pb2_grpc.SglangEncoderStub(channel)\n    try:\n        response = stub.Encode(\n            sglang_encoder_pb2.EncodeRequest(\n                mm_items=encode_request[\"mm_items\"],\n                req_id=encode_request[\"req_id\"],\n                num_parts=encode_request[\"num_parts\"],\n                part_idx=encode_request[\"part_idx\"],\n                prefill_host=encode_request[\"prefill_host\"],\n                embedding_port=_normalize_embedding_ports(\n                    encode_request[\"embedding_port\"]\n                ),\n            ),\n            timeout=timeout_secs,\n        )\n        return response\n    finally:\n        channel.close()\n\n\ndef _grpc_send_request(target, request_json):\n    import grpc\n    from smg_grpc_proto import sglang_encoder_pb2, sglang_encoder_pb2_grpc\n\n    timeout_secs = envs.SGLANG_ENCODER_GRPC_TIMEOUT_SECS.get()\n    channel = grpc.insecure_channel(target)\n    stub = sglang_encoder_pb2_grpc.SglangEncoderStub(channel)\n    try:\n        stub.Send(\n            sglang_encoder_pb2.SendRequest(\n                req_id=request_json[\"req_id\"],\n                prefill_host=request_json[\"prefill_host\"],\n                embedding_port=request_json[\"embedding_port\"],\n                session_id=request_json[\"session_id\"],\n                buffer_address=request_json[\"buffer_address\"],\n            ),\n            
timeout=timeout_secs,\n        )\n    finally:\n        channel.close()\n\n\nclass EmbeddingData:\n    def __init__(\n        self,\n        req_id,\n        num_parts,\n        part_idx,\n        grid_dim,\n        modality,\n        embedding=None,\n        embedding_shape=None,\n        error_msg=None,\n        error_code=None,\n        **kwargs,\n    ):\n        self.req_id = req_id\n        self.num_parts = num_parts\n        self.part_idx = part_idx\n        self.grid_dim = grid_dim\n        self.modality = modality\n        self.embedding = embedding\n        self.send_time = None\n        self.dtype = embedding.dtype if embedding is not None else None\n        if embedding_shape is not None:\n            self.shape = embedding_shape\n        else:\n            self.shape = list(embedding.shape) if embedding is not None else None\n        self.error_msg = error_msg\n        self.error_code = error_code\n        # Store additional metadata (e.g., video_timestamps for qwen3_vl)\n        for key, value in kwargs.items():\n            setattr(self, key, value)\n\n    def get_grid(self):\n        \"\"\"Get the grid dimension of the embedding, used for image/video/audio.\"\"\"\n        return self.grid_dim\n\n    def get_embedding(self):\n        return self.embedding\n\n    def __repr__(self):\n        return f\"EmbeddingData(req_id={self.req_id}, num_parts={self.num_parts}, part_idx={self.part_idx}, error_msg={self.error_msg})\"\n\n    def copy_without_embedding(self):\n        new_data = EmbeddingData(\n            req_id=self.req_id,\n            num_parts=self.num_parts,\n            part_idx=self.part_idx,\n            grid_dim=self.grid_dim,\n            modality=self.modality,\n            embedding=None,\n            embedding_shape=self.shape,\n            error_msg=self.error_msg,\n            error_code=self.error_code,\n        )\n        for key, value in self.__dict__.items():\n            if key.startswith(\"_\") or key == \"embedding\":\n                continue\n            setattr(new_data, key, value)\n        return new_data\n\n\n# Modality -> (list attr name, whether to flatten grid for that list)\n_MODALITY_GRID_ATTRS = {\n    Modality.IMAGE: (\"img_grid_thw\", False),\n    Modality.VIDEO: (\"video_grid_thw\", False),\n    Modality.AUDIO: (\"audio_feature_lens\", True),\n}\n_VIDEO_META_ATTRS = (\"video_timestamps\", \"second_per_grid_ts\")\n\n\ndef _cat_grid(dims, flatten_items=False):\n    \"\"\"Concatenate non-None tensors from a list; optionally flatten each before cat.\"\"\"\n    valid = (\n        [g.flatten() for g in dims if g is not None]\n        if flatten_items\n        else [g for g in dims if g is not None]\n    )\n    return torch.cat(valid, dim=0) if valid else None\n\n\nclass MultiModalEmbeddingData(EmbeddingData):\n    def __init__(\n        self,\n        part_idx,\n        num_parts,\n        req_id,\n        grid_dim,\n        modality,\n        embedding,\n        embedding_shape,\n        **kwargs,\n    ):\n        super().__init__(\n            req_id,\n            num_parts,\n            part_idx,\n            grid_dim,\n            modality,\n            embedding,\n            embedding_shape,\n            **kwargs,\n        )\n        self.img_grid_thw = [None] * num_parts\n        self.video_grid_thw = [None] * num_parts\n        self.audio_feature_lens = [None] * num_parts\n        self.modality_list = [\n            modality if part_idx == i else None for i in range(num_parts)\n        ]\n        self.ready_list = [i == part_idx for i in range(num_parts)]\n        self.embedding_list = [\n            embedding if i == part_idx else None for i in range(num_parts)\n        ]\n        self.embedding_shape_list = [\n            embedding_shape if i == part_idx else None for i in range(num_parts)\n        ]\n        self.video_timestamps = [None] * num_parts\n        self.second_per_grid_ts = [None] * num_parts\n\n        self._set_part_grid(part_idx, modality, 
self.get_grid())\n        if modality == Modality.VIDEO:\n            self._set_video_meta_for_part(part_idx, kwargs)\n\n    def _set_part_grid(self, part_idx, modality, grid):\n        \"\"\"Set the grid for one part according to modality (IMAGE/VIDEO/AUDIO).\"\"\"\n        spec = _MODALITY_GRID_ATTRS.get(modality)\n        if spec is None:\n            raise ValueError(f\"Invalid modality: {modality}\")\n        attr_name, flatten = spec\n        value = grid.flatten() if flatten else grid\n        getattr(self, attr_name)[part_idx] = value\n\n    def _set_video_meta_for_part(self, part_idx, source):\n        \"\"\"Copy video_timestamps and second_per_grid_ts from source (dict or object).\"\"\"\n        for attr_name in _VIDEO_META_ATTRS:\n            val = (\n                source.get(attr_name)\n                if isinstance(source, dict)\n                else getattr(source, attr_name, None)\n            )\n            if val is not None:\n                getattr(self, attr_name)[part_idx] = val\n\n    @classmethod\n    def from_embedding_data(cls, embedding_data: EmbeddingData):\n        \"\"\"Create MultiModalEmbeddingData from an EmbeddingData instance.\"\"\"\n        # Only forward known optional attrs (e.g. 
video metadata) so they land on the instance\n        extra = {}\n        for attr in _VIDEO_META_ATTRS:\n            val = getattr(embedding_data, attr, None)\n            if val is not None:\n                extra[attr] = val\n        mm_data = cls(\n            part_idx=embedding_data.part_idx,\n            num_parts=embedding_data.num_parts,\n            req_id=embedding_data.req_id,\n            grid_dim=embedding_data.grid_dim,\n            modality=embedding_data.modality,\n            embedding=embedding_data.embedding,\n            embedding_shape=embedding_data.shape,\n            **extra,\n        )\n        mm_data.send_time = embedding_data.send_time\n        return mm_data\n\n    def __repr__(self):\n        return f\"MultiModalEmbeddingData(req_id={self.req_id}, num_parts={self.num_parts}, part_idx={self.part_idx}, modality={self.modality})\"\n\n    def get_embedding(self, is_concat=False):\n        if is_concat:\n            groups = defaultdict(list)\n            for i, e in enumerate(self.embedding_list):\n                if e is not None:\n                    groups[self.modality_list[i]].append(e.cuda())\n            return {\n                mod: torch.concat(tensors).to(\"cpu\", non_blocking=True)\n                for mod, tensors in groups.items()\n            }\n        return self.embedding_list\n\n    @property\n    def ready(self):\n        return sum(self.ready_list) == self.num_parts\n\n    def get_mm_extra_meta(self):\n        \"\"\"Build kwargs for mm_processor.get_mm_data() from grid and optional video meta.\"\"\"\n        kwargs = {\n            \"img_grid_thw\": _cat_grid(self.img_grid_thw),\n            \"video_grid_thw\": _cat_grid(self.video_grid_thw),\n            \"audio_feature_lens\": _cat_grid(\n                self.audio_feature_lens, flatten_items=True\n            ),\n        }\n        for attr in _VIDEO_META_ATTRS:\n            lst = getattr(self, attr, None)\n            if not lst:\n                continue\n         
   valid = [a for a in lst if a is not None]\n            if valid:\n                kwargs[attr] = list(itertools.chain(*valid))\n        return kwargs\n\n    def add(self, embedding_data: EmbeddingData):\n        assert self.req_id == embedding_data.req_id\n        assert not self.ready_list[embedding_data.part_idx]\n        pid = embedding_data.part_idx\n        self.ready_list[pid] = True\n        self.modality_list[pid] = embedding_data.modality\n        self.embedding_list[pid] = embedding_data.get_embedding()\n        self.embedding_shape_list[pid] = embedding_data.shape\n        self._set_part_grid(pid, embedding_data.modality, embedding_data.get_grid())\n        if embedding_data.modality == Modality.VIDEO:\n            self._set_video_meta_for_part(pid, embedding_data)\n\n\nclass WaitingImageRequestStatus(IntEnum):\n    FAIL = -1\n    PENDING = 0\n    SUCCESS = 1\n    TIMEOUT = -2\n\n\ndef create_part_req_id(original_req_id: str, part_idx: int) -> str:\n    \"\"\"Create a unique part request ID by appending part index suffix.\"\"\"\n    return f\"{original_req_id}_local_part_{part_idx}\"\n\n\ndef extract_original_req_id(part_req_id: str) -> str:\n    \"\"\"Extract the original request ID from a part request ID.\"\"\"\n    if \"_local_part_\" in part_req_id:\n        return part_req_id.rsplit(\"_local_part_\", 1)[0]\n    return part_req_id\n\n\ndef calculate_modality_num_parts(modalities, num_items_assigned):\n    \"\"\"\n    Calculate total number of parts and number of parts per modality.\n\n    Args:\n        modalities: List of modalities in order\n        num_items_assigned: Dictionary mapping modality to list of assignment counts per encoder\n\n    Returns:\n        Tuple of (total_num_parts, modality_num_parts_dict)\n        - total_num_parts: Total number of parts across all modalities\n        - modality_num_parts: Dictionary mapping modality to number of parts for that modality\n    \"\"\"\n    total_num_parts = 0\n    modality_num_parts = {}\n   
 for modality in modalities:\n        num_items_assigned_modality = num_items_assigned.get(modality)\n        num_parts = sum(1 for x in num_items_assigned_modality if x != 0)\n        modality_num_parts[modality] = num_parts\n        total_num_parts += num_parts\n    return total_num_parts, modality_num_parts\n\n\n# For zmq_to_scheduler\nclass WaitingImageRequest:\n    def __init__(\n        self,\n        rid: str,\n        recv_req: TokenizedGenerateReqInput,\n        mm_processor,\n        encoder_urls,\n        host_name,\n        receive_count,\n    ):\n        self.rid = rid\n        self.recv_req = recv_req\n        self.mm_inputs = None\n        self.error = None\n        self.thread = None\n        self.mm_processor = mm_processor\n        self.encoder_urls = encoder_urls\n        self.host_name = host_name\n        self.receive_count = receive_count\n        self.num_items_assigned = recv_req.num_items_assigned\n        self.embedding_port, self.recv_socket = get_zmq_socket_on_host(\n            zmq.Context(), zmq.PULL\n        )\n        logger.info(f\"Waiting for input {self.embedding_port = }\")\n        self.recv_embedding_data = None\n        # ok=1 pending=0 fail=-1\n        self.status = WaitingImageRequestStatus.PENDING\n        self.error_msg = None\n        self.error_code = None\n        self.start_time = time.time()\n\n    def send_encode_request(self):\n        async def _send_single_request(session, url, payload):\n            try:\n                async with session.post(url, json=payload) as response:\n                    response.raise_for_status()\n                    return await response.text()\n            except Exception as e:\n                logger.error(f\"Failed to send request to {url}: {e}\")\n                raise\n\n        async def send_embedding_port(req_id, receive_count, host_name, embedding_port):\n            async with aiohttp.ClientSession(\n                timeout=aiohttp.ClientTimeout(total=1800)\n            ) 
as session:\n                tasks = []\n                logger.info(f\"{self.num_items_assigned = } \")\n\n                # Calculate part_idx_offset similar to encode() method\n                modalities = list(self.num_items_assigned.keys())\n                _, modality_num_parts = calculate_modality_num_parts(\n                    modalities, self.num_items_assigned\n                )\n\n                part_idx_offset = 0\n                for modality in modalities:\n                    assigned_nums = self.num_items_assigned[modality]\n                    num_parts = modality_num_parts[modality]\n                    cum_idx = 0\n                    for idx, assigned_num in enumerate(assigned_nums):\n                        if assigned_num == 0:\n                            continue\n                        part_idx = part_idx_offset + cum_idx\n                        part_req_id = create_part_req_id(req_id, part_idx)\n                        encoder_url = self.encoder_urls[idx]\n                        target_url = f\"{encoder_url}/scheduler_receive_url\"\n                        payload = {\n                            \"req_id\": part_req_id,  # use part_req_id to match encode request\n                            \"receive_count\": receive_count,\n                            \"receive_url\": f\"{host_name}:{embedding_port}\",\n                            \"modality\": modality.name,\n                        }\n                        logger.info(\n                            f\"Preparing to send to {target_url} with part_req_id={part_req_id}\"\n                        )\n                        task = _send_single_request(session, target_url, payload)\n                        tasks.append(task)\n                        cum_idx += 1\n                    part_idx_offset += num_parts\n\n                if not tasks:\n                    logger.info(\"No tasks to send.\")\n                    return\n                logger.info(f\"Concurrently sending 
{len(tasks)} requests...\")\n                results = await asyncio.gather(*tasks, return_exceptions=True)\n\n                for i, result in enumerate(results):\n                    if isinstance(result, Exception):\n                        logger.error(f\"Request {i} failed: {result}\")\n                    else:\n                        logger.debug(f\"Request {i} succeeded.\")\n\n        asyncio.run(\n            send_embedding_port(\n                self.recv_req.rid,\n                self.receive_count,\n                self.host_name,\n                self.embedding_port,\n            )\n        )\n\n    def _try_recv_mm_data(self):\n        if self.status != WaitingImageRequestStatus.PENDING:\n            return\n        while self.recv_embedding_data is None or not self.recv_embedding_data.ready:\n            try:\n                parts = self.recv_socket.recv_multipart(flags=zmq.NOBLOCK, copy=False)\n            except zmq.Again:\n                # No data available yet, wait a bit and retry\n                return\n            recv_obj: EmbeddingData = pickle.loads(parts[0])\n            if getattr(recv_obj, \"error_msg\", None) is not None:\n                logger.warning(\n                    f\"Received error signal from encoder for {self.rid}: {recv_obj.error_msg} {recv_obj.error_code = }\"\n                )\n                self.error_msg = recv_obj.error_msg\n                self.error_code = recv_obj.error_code\n                self.status = WaitingImageRequestStatus.FAIL\n                self.recv_socket.close()\n                return\n\n            buffer = parts[1].buffer if hasattr(parts[1], \"buffer\") else parts[1]\n            recv_obj.embedding = (\n                torch.frombuffer(buffer, dtype=recv_obj.dtype)\n                .reshape(recv_obj.shape)\n                .clone()\n            )\n\n            # Extract original req_id from part_req_id\n            part_req_id = recv_obj.req_id\n            original_req_id = 
extract_original_req_id(part_req_id)\n            # Update recv_obj.req_id to original for aggregation\n            recv_obj.req_id = original_req_id\n\n            if self.recv_embedding_data is None:\n                self.recv_embedding_data = MultiModalEmbeddingData.from_embedding_data(\n                    recv_obj\n                )\n            else:\n                self.recv_embedding_data.add(recv_obj)\n\n        recv_embedding = self.recv_embedding_data.get_embedding(is_concat=True)\n        mm_inputs = self.mm_processor.get_mm_data(\n            self.recv_req.input_text,\n            recv_embedding,\n            **self.recv_embedding_data.get_mm_extra_meta(),\n        )\n        self.recv_req.mm_inputs = mm_inputs\n        self.recv_req.input_ids = mm_inputs[\"input_ids\"]\n        self.status = WaitingImageRequestStatus.SUCCESS\n        self.recv_socket.close()\n\n\nclass WaitingImageRequestGrpc(WaitingImageRequest):\n    def send_encode_request(self):\n        async def send_embedding_port(req_id, receive_count, host_name, embedding_port):\n            tasks = []\n            # gRPC image-only: flatten modality dict to flat list\n            assigned = list(self.num_items_assigned.values())[0]\n            logger.info(f\"num_items_assigned={assigned}\")\n\n            for idx, assigned_num in enumerate(assigned):\n                if assigned_num == 0:\n                    continue\n                encoder_url = self.encoder_urls[idx]\n                receive_url = f\"{host_name}:{embedding_port}\"\n                target_url = f\"{encoder_url}/SchedulerReceiveUrl\"\n                logger.info(f\"Preparing to send to {target_url}\")\n                tasks.append(\n                    asyncio.to_thread(\n                        _grpc_scheduler_receive_url,\n                        _grpc_target(encoder_url),\n                        req_id,\n                        receive_url,\n                        receive_count,\n                    )\n              
  )\n\n            if not tasks:\n                logger.info(\"No tasks to send.\")\n                return\n            logger.info(f\"Concurrently sending {len(tasks)} requests...\")\n            results = await asyncio.gather(*tasks, return_exceptions=True)\n\n            for i, result in enumerate(results):\n                if isinstance(result, Exception):\n                    logger.error(f\"Request {i} failed: {result}\")\n                else:\n                    logger.debug(f\"Request {i} succeeded.\")\n\n        asyncio.run(\n            send_embedding_port(\n                self.recv_req.rid,\n                self.receive_count,\n                self.host_name,\n                self.embedding_port,\n            )\n        )\n\n\ndef _determine_tensor_transport_mode(server_args):\n    is_cross_node = server_args.dist_init_addr\n\n    if is_cross_node:\n        # Fallback to default CPU transport for multi-node\n        return \"default\"\n    else:\n        return \"cuda_ipc\"\n\n\nclass MMReceiverBase(ABC):\n    def __init__(\n        self,\n        server_args: ServerArgs,\n        dtype: Optional[torch.dtype] = None,\n        hf_config: Optional[PretrainedConfig] = None,\n        pp_rank: Optional[int] = None,\n        tp_rank: Optional[int] = None,\n        tp_group: Optional[GroupCoordinator] = None,\n        scheduler: Optional[\"Scheduler\"] = None,\n    ):\n        self.context = zmq.asyncio.Context(20)\n        self.encoder_transfer_backend = server_args.encoder_transfer_backend\n        self.encode_urls = server_args.encoder_urls\n        self.host = get_local_ip_auto(server_args.host)\n        if self.encoder_transfer_backend == \"mooncake\":\n            self.dtype = dtype\n            self.embeddings_engine = get_mooncake_transfer_engine()\n            if self.embeddings_engine is None:\n                from sglang.srt.distributed.device_communicators.mooncake_transfer_engine import (\n                    init_mooncake_transfer_engine,\n   
             )\n\n                self.embeddings_engine = init_mooncake_transfer_engine(\n                    hostname=self.host,\n                    ib_device=(\n                        server_args.disaggregation_ib_device\n                        or server_args.mooncake_ib_device\n                    ),\n                )\n            self.embeddings_buffer = dict()\n        elif self.encoder_transfer_backend == \"zmq_to_scheduler\":\n            self.pp_rank = pp_rank\n            self.tp_rank = tp_rank\n            self.tp_size = server_args.tp_size\n            self.tp_group = tp_group\n            self.nnodes = server_args.nnodes\n            self.hostname = get_local_ip_auto()\n            self.waiting_list: List[WaitingImageRequest] = []\n            self.scheduler = scheduler\n            self.wait_timeout = envs.SGLANG_ENCODER_RECV_TIMEOUT.get()\n            if hf_config is not None:\n                transport_mode = _determine_tensor_transport_mode(server_args)\n                import_processors(\"sglang.srt.multimodal.processors\")\n                _processor = None\n                try:\n                    _processor = get_processor(\n                        server_args.tokenizer_path,\n                        tokenizer_mode=server_args.tokenizer_mode,\n                        trust_remote_code=server_args.trust_remote_code,\n                        revision=server_args.revision,\n                        use_fast=not server_args.disable_fast_image_processor,\n                    )\n                except ValueError as e:\n                    error_message = str(e)\n                    if \"does not have a slow version\" in error_message:\n                        logger.info(\n                            f\"Processor {server_args.tokenizer_path} does not have a slow version. 
Automatically use fast version\"\n                        )\n                        _processor = get_processor(\n                            server_args.tokenizer_path,\n                            tokenizer_mode=server_args.tokenizer_mode,\n                            trust_remote_code=server_args.trust_remote_code,\n                            revision=server_args.revision,\n                            use_fast=True,\n                        )\n                    else:\n                        raise e\n\n                # Skip mm_pool if not adaptive dispatch to encoder\n                enable_adaptive_dispatch_to_encoder = (\n                    server_args.enable_adaptive_dispatch_to_encoder\n                )\n                self.mm_processor = get_mm_processor(\n                    hf_config,\n                    server_args,\n                    _processor,\n                    transport_mode,\n                    skip_mm_pool=not enable_adaptive_dispatch_to_encoder,\n                )\n\n    @abstractmethod\n    def process_waiting_requests(self, recv_reqs):\n        pass\n\n    async def recv_mm_data(\n        self, request_obj, mm_processor, prompt, need_wait_for_mm_inputs=True\n    ):\n        req_id = None\n        try:\n            if len(self.encode_urls) == 0 or not need_wait_for_mm_inputs:\n                return None\n            req_id = uuid.uuid4().hex\n            embedding_port, recv_socket = get_zmq_socket_on_host(self.context, zmq.PULL)\n            mm_data = self._extract_url_data(request_obj)\n            asyncio.create_task(\n                self.encode(req_id, mm_data, embedding_port, \"encode\", \"send\")\n            )\n            return await asyncio.wait_for(\n                self._recv_mm_data(req_id, recv_socket, mm_processor, prompt),\n                timeout=20,\n            )\n        except asyncio.TimeoutError:\n            logger.warning(f\"Embedding recv timeout for request {req_id}\")\n            if req_id is not 
None:\n                self._cleanup_mooncake_buffer(req_id)\n            return None\n\n    def _cleanup_mooncake_buffer(self, req_id):\n        if self.encoder_transfer_backend != \"mooncake\":\n            return\n        if not hasattr(self, \"embeddings_buffer\"):\n            return\n        embeddings = self.embeddings_buffer.pop(req_id, None)\n        if embeddings is None:\n            return\n        try:\n            self.embeddings_engine.deregister(embeddings.data_ptr())\n        except Exception:\n            logger.exception(\n                \"mooncake: failed to deregister buffer for req_id=%s\", req_id\n            )\n\n    async def _recv_mm_data(self, req_id, recv_socket, mm_processor, prompt):\n        if req_id is None:\n            return None\n\n        recv_embedding = None\n\n        recv_embedding_data: MultiModalEmbeddingData = None\n\n        try:\n            while recv_embedding_data is None or not recv_embedding_data.ready:\n                parts = await recv_socket.recv_multipart(copy=False)\n                if not parts:\n                    continue\n                recv_obj: EmbeddingData = pickle.loads(parts[0])\n                if getattr(recv_obj, \"error_msg\", None) is not None:\n                    logger.warning(\n                        f\"Encoder error for req_id={req_id}: {recv_obj.error_msg} \"\n                        f\"error_code={getattr(recv_obj, 'error_code', None)}\"\n                    )\n                    self._cleanup_mooncake_buffer(req_id)\n                    return None\n                logger.debug(\"recv_obj=%s\", recv_obj)\n                # Extract original req_id from part_req_id\n                part_req_id = recv_obj.req_id\n                original_req_id = extract_original_req_id(part_req_id)\n                # Update recv_obj.req_id to original for aggregation\n                recv_obj.req_id = original_req_id\n                if self.encoder_transfer_backend == \"zmq_to_tokenizer\":\n        
            if len(parts) < 2:\n                        logger.error(\n                            \"zmq_to_tokenizer expected 2-part message, got %d parts\",\n                            len(parts),\n                        )\n                        return None\n                    buffer = (\n                        parts[1].buffer if hasattr(parts[1], \"buffer\") else parts[1]\n                    )\n                    # Clone so we don't depend on ZMQ buffer after next recv.\n                    recv_obj.embedding = (\n                        torch.frombuffer(buffer, dtype=recv_obj.dtype)\n                        .reshape(recv_obj.shape)\n                        .clone()\n                    )\n                if recv_embedding_data is None:\n                    recv_embedding_data = MultiModalEmbeddingData.from_embedding_data(\n                        recv_obj\n                    )\n                else:\n                    recv_embedding_data.add(recv_obj)\n\n            if self.encoder_transfer_backend == \"mooncake\":\n                if req_id not in self.embeddings_buffer:\n                    logger.error(\n                        \"mooncake: embeddings_buffer missing req_id=%s\", req_id\n                    )\n                    return None\n                raw_buffer = self.embeddings_buffer.pop(req_id)\n                self.embeddings_engine.deregister(raw_buffer.data_ptr())\n                byte_offset = 0\n                for i in range(recv_embedding_data.num_parts):\n                    shape = recv_embedding_data.embedding_shape_list[i]\n                    if shape is None:\n                        continue\n                    part_bytes = (\n                        shape[0]\n                        * shape[1]\n                        * torch.tensor([], dtype=self.dtype).element_size()\n                    )\n                    recv_embedding_data.embedding_list[i] = (\n                        raw_buffer[byte_offset : byte_offset + 
part_bytes]\n                        .view(self.dtype)\n                        .reshape(shape)\n                    )\n                    byte_offset += part_bytes\n\n            recv_embedding = recv_embedding_data.get_embedding(is_concat=True)\n\n            mm_inputs = mm_processor.get_mm_data(\n                prompt,\n                recv_embedding,\n                **recv_embedding_data.get_mm_extra_meta(),\n            )\n            return mm_inputs\n        finally:\n            recv_socket.close()\n\n    def send_encode_request(self, obj):\n        self._send_encode_request(obj)\n\n    def _send_encode_request(self, obj):\n        mm_data = self._extract_url_data(obj)\n        if obj.rid is None:\n            obj.rid = uuid.uuid4().hex\n        if mm_data and self.encode_urls:\n            logger.info(f\"Processing {len(mm_data)} mm items for request {obj.rid}\")\n            obj.need_wait_for_mm_inputs = True\n\n            num_items_assigned = self._assign_items_by_modality(\n                mm_data, len(self.encode_urls)\n            )\n            obj.num_items_assigned = num_items_assigned\n            encode_thread = threading.Thread(\n                target=self._run_encode_in_thread,\n                args=(\n                    obj.rid,\n                    mm_data,\n                    \"encode\",\n                    num_items_assigned,\n                    None,\n                ),\n                daemon=True,\n            )\n            encode_thread.start()\n\n    # For zmq_to_scheduler\n    def _process_waiting_requests(self, recv_reqs, waiting_cls):\n        new_recv_reqs = []\n        for recv_req in recv_reqs:\n            if (\n                isinstance(recv_req, TokenizedGenerateReqInput)\n                and recv_req.need_wait_for_mm_inputs is True\n            ):\n                waiting_req = waiting_cls(\n                    rid=recv_req.rid,\n                    recv_req=recv_req,\n                    
mm_processor=self.mm_processor,\n                    encoder_urls=self.encode_urls,\n                    host_name=self.hostname,\n                    receive_count=self.tp_size,\n                )\n                waiting_req.send_encode_request()\n                self.waiting_list.append(waiting_req)\n            else:\n                new_recv_reqs.append(recv_req)\n\n        if len(self.waiting_list) == 0:\n            return new_recv_reqs, []\n\n        current_time = time.time()\n        local_status = []\n        for waiting_req in self.waiting_list:\n            waiting_req._try_recv_mm_data()\n            # Only pending requests can time out; keep SUCCESS/FAIL results intact\n            if (\n                waiting_req.status == WaitingImageRequestStatus.PENDING\n                and current_time - waiting_req.start_time > self.wait_timeout\n            ):\n                waiting_req.status = WaitingImageRequestStatus.TIMEOUT\n            local_status.append(waiting_req.status)\n\n        local_status = torch.tensor(local_status, device=\"cpu\", dtype=torch.int32)\n\n        torch.distributed.all_reduce(\n            local_status,\n            op=torch.distributed.ReduceOp.MIN,\n            group=self.tp_group.cpu_group,\n        )\n\n        new_waiting = []\n        abort_reqs = []\n        for i, waiting_req in enumerate(self.waiting_list):\n            status_value = local_status[i].item()\n            if status_value == WaitingImageRequestStatus.SUCCESS:\n                new_recv_reqs.append(waiting_req.recv_req)\n            elif status_value == WaitingImageRequestStatus.FAIL:\n                logger.error(\n                    f\"Waiting request {waiting_req.rid} failed: {waiting_req.error_msg} {waiting_req.error_code = }\"\n                )\n                abort_reqs.append(\n                    (\n                        self.create_req(waiting_req.recv_req),\n                        waiting_req.error_msg,\n                        waiting_req.error_code,\n                    )\n                )\n            elif status_value == WaitingImageRequestStatus.TIMEOUT:\n                logger.error(\n                    f\"Timed out waiting 
for image embeddings for request {waiting_req.rid}\"\n                )\n                abort_reqs.append(\n                    (\n                        self.create_req(waiting_req.recv_req),\n                        f\"Timeout waiting for image embedding after {self.wait_timeout}s\",\n                        HTTPStatus.REQUEST_TIMEOUT,\n                    )\n                )\n            else:  # status_value == WaitingImageRequestStatus.PENDING\n                new_waiting.append(waiting_req)\n\n        self.waiting_list = new_waiting\n        return new_recv_reqs, abort_reqs\n\n    def _run_encode_in_thread(\n        self, req_id, mm_data, endpoint_encode, num_items_assigned, embedding_port\n    ):\n        try:\n            asyncio.run(\n                self.encode(\n                    req_id=req_id,\n                    mm_data=mm_data,\n                    embedding_port=embedding_port,\n                    endpoint_encode=endpoint_encode,\n                    endpoint_send=None,\n                    num_items_assigned=num_items_assigned,\n                )\n            )\n        except Exception as e:\n            logger.error(f\"Encode failed for request {req_id}: {e}\", exc_info=True)\n\n    def create_req(self, recv_req: TokenizedGenerateReqInput):\n        req = Req(\n            recv_req.rid,\n            recv_req.input_text,\n            recv_req.input_ids,\n            recv_req.sampling_params,\n            return_logprob=recv_req.return_logprob,\n            top_logprobs_num=recv_req.top_logprobs_num,\n            token_ids_logprob=recv_req.token_ids_logprob,\n            stream=recv_req.stream,\n            lora_id=recv_req.lora_id,\n            input_embeds=recv_req.input_embeds,\n            custom_logit_processor=recv_req.custom_logit_processor,\n            require_reasoning=recv_req.require_reasoning,\n            return_hidden_states=recv_req.return_hidden_states,\n            return_routed_experts=recv_req.return_routed_experts,\n      
      eos_token_ids=self.scheduler.model_config.hf_eos_token_id,\n            bootstrap_host=recv_req.bootstrap_host,\n            bootstrap_port=recv_req.bootstrap_port,\n            bootstrap_room=recv_req.bootstrap_room,\n            disagg_mode=self.scheduler.disaggregation_mode,\n            routed_dp_rank=recv_req.routed_dp_rank,\n            disagg_prefill_dp_rank=recv_req.disagg_prefill_dp_rank,\n            vocab_size=self.scheduler.model_config.vocab_size,\n            priority=recv_req.priority,\n            metrics_collector=(\n                self.scheduler.metrics_collector\n                if self.scheduler.enable_metrics\n                else None\n            ),\n            http_worker_ipc=recv_req.http_worker_ipc,\n            dllm_config=self.scheduler.dllm_config,\n        )\n        req.tokenizer = self.scheduler.tokenizer\n        return req\n\n    async def allocate_embedding_buffer(self, req_id, total_bytes):\n        embeddings = torch.empty(total_bytes, dtype=torch.uint8)\n        self.embeddings_engine.register(\n            embeddings.data_ptr(),\n            embeddings.nbytes,\n        )\n        self.embeddings_buffer[req_id] = embeddings\n        return embeddings.data_ptr()\n\n    def _assign_items_by_modality(\n        self, mm_data, encoder_num, random_shuffle=True\n    ) -> Dict:\n        \"\"\"\n        Assign multimodal items across encoders by modality with cross-modality load balancing.\n\n        Args:\n            mm_data: List of multimodal data items, each with a \"modality\" key\n            encoder_num: Number of encoders\n            random_shuffle: Whether to shuffle the encoder indices\n\n        Returns:\n            Dictionary mapping modality to list of assignment counts per encoder\n            Format: {modality: [count_for_encoder_0, count_for_encoder_1, ...]}\n        \"\"\"\n        encode_idx = list(range(encoder_num))\n        if random_shuffle:\n            random.shuffle(encode_idx)\n        # Get unique 
modalities with order preserved\n        modalities = list(dict.fromkeys(mm_item.get(\"modality\") for mm_item in mm_data))\n        # Use OrderedDict to explicitly maintain modality order\n        num_items_assigned = OrderedDict()\n        current_offset = 0\n\n        for modality in modalities:\n            mm_data_modality = [\n                mm_item for mm_item in mm_data if mm_item.get(\"modality\") == modality\n            ]\n            num_items = len(mm_data_modality)\n            if num_items == 0:\n                continue\n\n            base = num_items // len(encode_idx)\n            remainder = num_items % len(encode_idx)\n            # Rotate assignments based on current_offset to balance load across modalities\n            assignments = [0] * len(encode_idx)\n            for i in range(len(encode_idx)):\n                # keep shuffle order when assigning items to encoders\n                pos_in_shuffled = (current_offset + i) % len(encode_idx)\n                actual_encoder_idx = encode_idx[pos_in_shuffled]\n                assignments[actual_encoder_idx] = base + (1 if i < remainder else 0)\n            num_items_assigned[modality] = assignments\n            current_offset = (current_offset + remainder) % len(encode_idx)\n\n        return num_items_assigned\n\n    def _extract_url_data(self, request_obj) -> List[Dict]:\n        mm_data = []\n        for attr, modality in [\n            (\"image_data\", Modality.IMAGE),\n            (\"video_data\", Modality.VIDEO),\n            (\"audio_data\", Modality.AUDIO),\n        ]:\n            mm_items = getattr(request_obj, attr, None)\n            if mm_items:\n                if not isinstance(mm_items, list):\n                    mm_items = [mm_items]\n                for mm_item in mm_items:\n                    mm_data.append(\n                        {\n                            \"url\": (\n                                mm_item.url\n                                if isinstance(mm_item, 
ImageData)\n                                else mm_item\n                            ),\n                            \"modality\": modality,\n                        }\n                    )\n        return mm_data\n\n\nclass MMReceiverHTTP(MMReceiverBase):\n    def __init__(\n        self,\n        server_args: ServerArgs,\n        dtype: Optional[torch.dtype] = None,\n        hf_config: Optional[PretrainedConfig] = None,\n        pp_rank: Optional[int] = None,\n        tp_rank: Optional[int] = None,\n        tp_group: Optional[GroupCoordinator] = None,\n        scheduler: Optional[\"Scheduler\"] = None,\n    ):\n        super().__init__(\n            server_args,\n            dtype=dtype,\n            hf_config=hf_config,\n            pp_rank=pp_rank,\n            tp_rank=tp_rank,\n            tp_group=tp_group,\n            scheduler=scheduler,\n        )\n\n    # For zmq_to_scheduler\n    def process_waiting_requests(self, recv_reqs):\n        return self._process_waiting_requests(recv_reqs, WaitingImageRequest)\n\n    async def encode(\n        self,\n        req_id,\n        mm_data,\n        embedding_port,\n        endpoint_encode,\n        endpoint_send,\n        num_items_assigned=None,\n    ):\n        if len(mm_data) == 0:\n            return\n\n        # get unique modalities with order preserved\n        modalities = [mm_item.get(\"modality\") for mm_item in mm_data]\n        modalities = list(dict.fromkeys(modalities))\n        encode_requests = []\n\n        if num_items_assigned is None:\n            num_items_assigned = self._assign_items_by_modality(\n                mm_data, len(self.encode_urls)\n            )\n\n        # Calculate total num_parts across all modalities\n        total_num_parts, modality_num_parts = calculate_modality_num_parts(\n            modalities, num_items_assigned\n        )\n\n        part_idx_offset = 0\n        for modality in modalities:\n            num_items_assigned_modality = num_items_assigned.get(modality)\n  
          mm_data_modality = [\n                mm_item for mm_item in mm_data if mm_item.get(\"modality\") == modality\n            ]\n\n            num_parts = modality_num_parts[modality]\n            cum_num_items = 0\n            cum_idx = 0\n            for idx, assigned_num in enumerate(num_items_assigned_modality):\n                if assigned_num == 0:\n                    continue\n                part_idx = part_idx_offset + cum_idx\n                part_req_id = create_part_req_id(req_id, part_idx)\n                encode_requests.append(\n                    {\n                        \"encoder_idx\": idx,\n                        \"mm_items\": [\n                            mm_item.get(\"url\")\n                            for mm_item in mm_data_modality[\n                                cum_num_items : cum_num_items + assigned_num\n                            ]\n                        ],\n                        \"num_parts\": total_num_parts,\n                        \"part_idx\": part_idx,\n                        \"req_id\": part_req_id,  # use part_req_id to avoid key collision\n                        \"modality\": modality.name,  # convert enum to string for json serialization\n                        \"prefill_host\": self.host,\n                        \"embedding_port\": embedding_port,\n                    }\n                )\n                cum_idx += 1\n                cum_num_items += assigned_num\n            part_idx_offset += num_parts\n\n        async with aiohttp.ClientSession(\n            timeout=aiohttp.ClientTimeout(\n                total=1800\n            )  # Add timeout for request reliability\n        ) as session:\n            # Send encode requests\n\n            tasks = [\n                session.post(\n                    f\"{self.encode_urls[encode_request['encoder_idx']]}/{endpoint_encode}\",\n                    json=encode_request,\n                )\n                for encode_request in encode_requests\n        
    ]\n\n            responses = await asyncio.gather(*tasks)\n            for response in responses:\n                if response.status != 200:\n                    try:\n                        err_data = await response.json()\n                        msg = err_data.get(\"message\", \"Unknown encoder error\")\n                    except:\n                        msg = await response.text()\n\n                    logger.error(f\"Encoder returned error {response.status}: {msg}\")\n                    return\n            response_json_list_unsort = [\n                await response.json() for response in responses\n            ]\n\n            # zmq backend: return is None\n            if None in response_json_list_unsort:\n                return\n\n            # mooncake backend: send bootstrap info\n\n            embedding_size_list_sort = [None for _ in range(total_num_parts)]\n            response_json_list_sort = [None for _ in range(total_num_parts)]\n            for response_json in response_json_list_unsort:\n                idx = response_json[\"part_idx\"]\n                embedding_size_list_sort[idx] = response_json[\"embedding_size\"]\n                response_json_list_sort[idx] = response_json\n\n            total_embedding_bytes = sum(\n                s for s in embedding_size_list_sort if s is not None\n            )\n            offset = 0\n            metadata_tasks = []\n            buffer_address = await self.allocate_embedding_buffer(\n                req_id,\n                total_embedding_bytes,\n            )\n            for idx in range(len(tasks)):\n                response_json = response_json_list_sort[idx]\n                buffer_address_adjust = offset + buffer_address\n                response_json.update(\n                    {\n                        \"session_id\": self.embeddings_engine.session_id,\n                        \"buffer_address\": buffer_address_adjust,\n                    }\n                )\n                
metadata_tasks.append(\n                    session.post(\n                        f\"{self.encode_urls[response_json['encoder_idx']]}/{endpoint_send}\",\n                        json=response_json,\n                    )\n                )\n                offset += embedding_size_list_sort[idx]\n            await asyncio.gather(*metadata_tasks)\n\n\nclass MMReceiverGrpc(MMReceiverBase):\n    def __init__(\n        self,\n        server_args: ServerArgs,\n        dtype: Optional[torch.dtype] = None,\n        hf_config: Optional[PretrainedConfig] = None,\n        pp_rank: Optional[int] = None,\n        tp_rank: Optional[int] = None,\n        tp_group: Optional[GroupCoordinator] = None,\n        scheduler: Optional[\"Scheduler\"] = None,\n    ):\n        super().__init__(\n            server_args,\n            dtype=dtype,\n            hf_config=hf_config,\n            pp_rank=pp_rank,\n            tp_rank=tp_rank,\n            tp_group=tp_group,\n            scheduler=scheduler,\n        )\n\n    def build_and_send_encode_request(self, image_urls, rid):\n        encode_req = GenerateReqInput(\n            image_data=[ImageData(url=url) for url in image_urls],\n            rid=rid,\n        )\n        self.send_encode_request(encode_req)\n        return encode_req\n\n    # For zmq_to_scheduler\n    def process_waiting_requests(self, recv_reqs):\n        return self._process_waiting_requests(recv_reqs, WaitingImageRequestGrpc)\n\n    async def encode(\n        self,\n        req_id,\n        mm_data,\n        embedding_port,\n        endpoint_encode,\n        endpoint_send,\n        num_items_assigned=None,\n    ):\n        if not mm_data:\n            return\n\n        # gRPC currently only supports image; flatten new dict formats to simple lists\n        if mm_data and isinstance(mm_data[0], dict):\n            non_image = [\n                item.get(\"modality\")\n                for item in mm_data\n                if item.get(\"modality\") != Modality.IMAGE\n    
        ]\n            if non_image:\n                raise NotImplementedError(\n                    f\"gRPC encode only supports IMAGE modality, got: {non_image}\"\n                )\n            img_data = [item.get(\"url\") for item in mm_data]\n        else:\n            img_data = mm_data\n        if isinstance(num_items_assigned, dict):\n            num_items_assigned = list(num_items_assigned.values())[0]\n\n        encode_requests = []\n        if num_items_assigned is None:\n            encode_idx = list(range(len(self.encode_urls)))\n            random.shuffle(encode_idx)\n            num_items_assigned = [\n                (idx + len(img_data)) // len(self.encode_urls) for idx in encode_idx\n            ]\n        num_parts = sum(1 for x in num_items_assigned if x != 0)\n        cum_num_items = 0\n        cum_idx = 0\n        for idx, assigned_num in enumerate(num_items_assigned):\n            if assigned_num == 0:\n                continue\n            start = cum_num_items\n            end = cum_num_items + assigned_num\n            encode_requests.append(\n                {\n                    \"encoder_idx\": idx,\n                    \"mm_items\": img_data[start:end],\n                    \"num_parts\": num_parts,\n                    \"part_idx\": cum_idx,\n                    \"req_id\": req_id,\n                    \"prefill_host\": self.host,\n                    \"embedding_port\": embedding_port,\n                }\n            )\n            cum_idx += 1\n            cum_num_items += assigned_num\n\n        grpc_tasks = [\n            asyncio.to_thread(\n                _grpc_encode_request,\n                _grpc_target(self.encode_urls[encode_request[\"encoder_idx\"]]),\n                encode_request,\n            )\n            for encode_request in encode_requests\n        ]\n        grpc_responses = await asyncio.gather(*grpc_tasks)\n        response_json_unsorted = []\n        for encode_request, response in zip(encode_requests, 
grpc_responses):\n            if self.encoder_transfer_backend == \"zmq_to_scheduler\":\n                response_json_unsorted.append(None)\n                continue\n            response_json_unsorted.append(\n                {\n                    \"req_id\": encode_request[\"req_id\"],\n                    \"prefill_host\": encode_request[\"prefill_host\"],\n                    \"embedding_port\": encode_request[\"embedding_port\"],\n                    \"encoder_idx\": encode_request[\"encoder_idx\"],\n                    \"part_idx\": encode_request[\"part_idx\"],\n                    \"embedding_size\": response.embedding_size,\n                    \"embedding_len\": response.embedding_len,\n                    \"embedding_dim\": response.embedding_dim,\n                }\n            )\n\n        if None in response_json_unsorted:\n            return\n\n        embedding_size_by_part = [None for _ in range(num_parts)]\n        response_json_sorted = [None for _ in range(num_parts)]\n        for response_json in response_json_unsorted:\n            idx = response_json[\"part_idx\"]\n            embedding_size_by_part[idx] = response_json[\"embedding_size\"]\n            response_json_sorted[idx] = response_json\n\n        total_embedding_bytes = sum(s for s in embedding_size_by_part if s is not None)\n        offset = 0\n        buffer_address = await self.allocate_embedding_buffer(\n            req_id,\n            total_embedding_bytes,\n        )\n        grpc_metadata_tasks = []\n        for response_json in response_json_sorted:\n            response_json.update(\n                {\n                    \"session_id\": self.embeddings_engine.session_id,\n                    \"buffer_address\": offset + buffer_address,\n                }\n            )\n            grpc_metadata_tasks.append(\n                asyncio.to_thread(\n                    _grpc_send_request,\n                    _grpc_target(self.encode_urls[response_json[\"encoder_idx\"]]),\n   
                 response_json,\n                )\n            )\n            offset += embedding_size_by_part[response_json[\"part_idx\"]]\n\n        if grpc_metadata_tasks:\n            await asyncio.gather(*grpc_metadata_tasks)\n\n\ndef _validate_transport_mode(transport_mode: str, encoder_urls):\n    if transport_mode == \"grpc\":\n        invalid_prefix = \"http://\"\n        error_msg = (\n            \"EPD MMReceiver: grpc mode requires grpc:// encoder URLs. \"\n            \"Set SGLANG_ENCODER_MM_RECEIVER_MODE=http for http:// URLs.\"\n        )\n    elif transport_mode == \"http\":\n        invalid_prefix = \"grpc://\"\n        error_msg = (\n            \"EPD MMReceiver: http mode requires http:// encoder URLs. \"\n            \"Set SGLANG_ENCODER_MM_RECEIVER_MODE=grpc for grpc:// URLs.\"\n        )\n    else:\n        return\n\n    if any(url.startswith(invalid_prefix) for url in encoder_urls):\n        raise ValueError(error_msg)\n\n\n_MM_RECEIVER_BY_MODE = {\n    \"grpc\": MMReceiverGrpc,\n    \"http\": MMReceiverHTTP,\n}\n\n\ndef create_mm_receiver(\n    server_args: ServerArgs,\n    dtype: Optional[torch.dtype] = None,\n    hf_config: Optional[PretrainedConfig] = None,\n    pp_rank: Optional[int] = None,\n    tp_rank: Optional[int] = None,\n    tp_group: Optional[GroupCoordinator] = None,\n    scheduler: Optional[\"Scheduler\"] = None,\n    transport_mode: Optional[str] = None,\n):\n    if transport_mode is None:\n        transport_mode = envs.SGLANG_ENCODER_MM_RECEIVER_MODE.get()\n        logger.debug(f\"MMReceiver transport_mode from env: {transport_mode}\")\n\n    _validate_transport_mode(transport_mode, server_args.encoder_urls)\n    logger.info(f\"EPD MMReceiver: using transport_mode={transport_mode}\")\n\n    receiver_cls = _MM_RECEIVER_BY_MODE.get(transport_mode)\n    if receiver_cls is None:\n        raise ValueError(f\"Unsupported transport_mode: {transport_mode}\")\n    return receiver_cls(\n        server_args,\n        dtype=dtype,\n     
   hf_config=hf_config,\n        pp_rank=pp_rank,\n        tp_rank=tp_rank,\n        tp_group=tp_group,\n        scheduler=scheduler,\n    )\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/encode_server.py",
    "content": "import asyncio\nimport concurrent.futures\nimport ctypes\nimport logging\nimport multiprocessing as mp\nimport os\nimport pickle\nimport time\nimport traceback\nfrom http import HTTPStatus\nfrom typing import Dict, List, Optional, Set, Tuple, Union\n\nimport aiohttp\nimport numpy as np\nimport torch\nimport uvicorn\nimport zmq\nimport zmq.asyncio\nfrom fastapi import FastAPI\nfrom fastapi.responses import ORJSONResponse, Response\nfrom transformers import AutoProcessor\n\nfrom sglang.srt.configs.device_config import DeviceConfig\nfrom sglang.srt.configs.load_config import LoadConfig\nfrom sglang.srt.configs.model_config import ModelConfig\nfrom sglang.srt.disaggregation.encode_receiver import EmbeddingData\nfrom sglang.srt.distributed.parallel_state import (\n    get_default_distributed_backend,\n    get_mooncake_transfer_engine,\n    get_tp_group,\n    init_distributed_environment,\n    initialize_model_parallel,\n)\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.dp_attention import initialize_dp_attention\nfrom sglang.srt.managers.io_struct import ProfileReq, ProfileReqInput, ProfileReqType\nfrom sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem\nfrom sglang.srt.mem_cache.multimodal_cache import EmbeddingResult, MultiModalStaticCache\nfrom sglang.srt.model_loader import get_model\nfrom sglang.srt.multimodal.processors.qwen_vl import preprocess_video\nfrom sglang.srt.server_args import (\n    PortArgs,\n    ServerArgs,\n    set_global_server_args_for_scheduler,\n)\nfrom sglang.srt.utils import (\n    load_audio,\n    load_image,\n    load_video,\n    random_uuid,\n)\nfrom sglang.srt.utils.network import (\n    NetworkAddress,\n    config_socket,\n    get_local_ip_auto,\n    get_zmq_socket,\n)\n\nlogger = logging.getLogger(__name__)\n\nrid_lock = asyncio.Lock()\nrid_to_receive_endpoint: Dict[str, List[str]] = dict()\nrid_to_receive_count: Dict[str, int] = dict()\nrid_to_err_msg: Dict[str, str] = 
dict()\ncond_dict_lock = asyncio.Lock()\nrid_to_cond: Dict[str, asyncio.Condition] = {}\n\nuse_image_processor_gpu = (\n    int(os.getenv(\"SGLANG_ENCODER_IMAGE_PROCESSOR_USE_GPU\", \"0\")) == 1\n)\n\n\nclass MMError(Exception):\n    def __init__(self, message, code=HTTPStatus.INTERNAL_SERVER_ERROR):\n        self.message = message\n        self.code = code\n        super().__init__(self.message)\n\n\nclass BadRequestError(MMError):\n    def __init__(self, message):\n        super().__init__(message, code=HTTPStatus.BAD_REQUEST)\n\n\nclass InternalError(MMError):\n    def __init__(self, message):\n        super().__init__(message, code=HTTPStatus.INTERNAL_SERVER_ERROR)\n\n\nclass TensorWrapper:\n    \"\"\"Wrapper to keep tensor alive while exposing buffer for zero-copy.\"\"\"\n\n    def __init__(self, tensor):\n        # Ensure tensor is on CPU and contiguous\n        if tensor.is_cuda:\n            tensor = tensor.cpu()\n        if not tensor.is_contiguous():\n            tensor = tensor.contiguous()\n\n        # Keep tensor reference\n        self.tensor = tensor\n        self.shape = list(tensor.shape)\n        self.dtype = tensor.dtype\n\n    def __buffer__(self):\n        data_ptr = self.tensor.data_ptr()\n        total_bytes = self.tensor.numel() * self.tensor.element_size()\n        c_obj = (ctypes.c_char * total_bytes).from_address(data_ptr)\n        c_obj._keep_alive_ref = self\n        return memoryview(c_obj)\n\n\ndef _convert(data):\n    if isinstance(data, torch.Tensor):\n        return data\n    elif isinstance(data, np.ndarray):\n        return torch.tensor(data)\n    elif isinstance(data, list) and isinstance(data[0], np.ndarray):\n        return torch.tensor(np.array(data))\n    elif isinstance(data, list) and isinstance(data[0], (int, float)):\n        return torch.tensor(data)\n    else:\n        return data\n\n\n_mm_grid_attrs = {\n    Modality.IMAGE: [\"image_grid_thw\", \"image_grid_hws\"],\n    Modality.VIDEO: [\"video_grid_thw\"],\n    
Modality.AUDIO: [\"audio_feature_lens_raw\"],\n}\n\n_mm_feature_attrs = {\n    Modality.IMAGE: [\"pixel_values\"],\n    Modality.VIDEO: [\"pixel_values_videos\"],\n    Modality.AUDIO: [\"input_features\"],\n}\n\n\ndef _get_mm_grid_dim(mm_inputs, modality):\n    for attr in _mm_grid_attrs[modality]:\n        if attr in mm_inputs:\n            return mm_inputs[attr]\n    raise ValueError(f\"Grid dim ({_mm_grid_attrs[modality]}) not found in {mm_inputs}\")\n\n\ndef _get_mm_feature(mm_inputs, modality):\n    for attr in _mm_feature_attrs[modality]:\n        if attr in mm_inputs:\n            return mm_inputs[attr]\n    raise ValueError(\n        f\"Feature attrs ({_mm_feature_attrs[modality]}) not found in {mm_inputs}\"\n    )\n\n\ndef _build_mm_aux_data(mm_inputs):\n    \"\"\"\n    Build auxiliary data for video modality.\n    \"\"\"\n    aux_data = {\n        \"video_timestamps\": mm_inputs.get(\"video_timestamps\", None),\n        \"second_per_grid_ts\": mm_inputs.get(\"second_per_grid_ts\", None),\n    }\n    return aux_data\n\n\nclass MMEncoder:\n    def __init__(\n        self,\n        server_args: ServerArgs,\n        schedule_path=None,\n        dist_init_method=None,\n        rank: int = 0,\n    ):\n        logger.info(f\"init MMEncoder {rank}/{server_args.tp_size}\")\n        self.server_args = server_args\n        set_global_server_args_for_scheduler(server_args)\n        self.rank = rank\n        self.profiler = EncoderProfiler(rank)\n        self._load_mm_processor(server_args)\n\n        self.model_config = ModelConfig.from_server_args(\n            server_args,\n        )\n        self.load_config = LoadConfig(\n            load_format=server_args.load_format,\n            download_dir=server_args.download_dir,\n            model_loader_extra_config=server_args.model_loader_extra_config,\n            remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,\n            
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,\n            remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,\n        )\n        self.model_type = getattr(\n            self.model_config.hf_config, \"model_type\", \"unknown\"\n        ).lower()\n\n        self.device = server_args.device\n        self.gpu_id = server_args.base_gpu_id + rank\n\n        self.device_config = DeviceConfig(\n            device=self.device,\n            gpu_id=self.gpu_id,\n        )\n\n        torch.get_device_module(self.device).set_device(self.gpu_id)\n\n        self.use_image_processor_gpu = (\n            use_image_processor_gpu and not server_args.disable_fast_image_processor\n        )\n        self._build_vision_config(server_args.mm_process_config)\n\n        init_distributed_environment(\n            backend=get_default_distributed_backend(self.device),\n            world_size=server_args.tp_size,\n            rank=rank,\n            distributed_init_method=dist_init_method,\n            local_rank=rank,\n        )\n        initialize_model_parallel(tensor_model_parallel_size=server_args.tp_size)\n        initialize_dp_attention(server_args, self.model_config)\n\n        self.model = get_model(\n            model_config=self.model_config,\n            load_config=self.load_config,\n            device_config=self.device_config,\n        )\n\n        self.context = zmq.asyncio.Context(2)\n        self.sync_context = zmq.Context()  # Reuse sync context for thread pool\n        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)\n\n        embedding_cache_size = int(os.environ.get(\"SGLANG_VLM_CACHE_SIZE_MB\", \"4096\"))\n        self.mm_cache = MultiModalStaticCache(embedding_cache_size * 1024 * 1024)\n        self.mm_cache_lock = asyncio.Lock()\n\n        self.io_executor = 
concurrent.futures.ThreadPoolExecutor(\n            max_workers=int(os.environ.get(\"SGLANG_ENCODER_MM_LOAD_WORKERS\", 4))\n        )\n        self.send_timeout = envs.SGLANG_ENCODER_SEND_TIMEOUT.get()\n\n        if schedule_path is not None:\n            self.schedule_socket = get_zmq_socket(\n                self.context, zmq.PULL, schedule_path, True\n            )\n        self.background_tasks: Set[asyncio.Task] = set()\n\n        if self.server_args.enable_mm_global_cache:\n            from sglang.srt.mem_cache.storage.mooncake_store.embedding_cache_controller import (\n                EmbeddingCacheController,\n            )\n\n            hidden_dims = self._infer_embedding_dims()\n            self.mm_global_cache = EmbeddingCacheController(\n                rank,\n                server_args.tp_size,\n                hidden_dims=hidden_dims,\n                tp_group=get_tp_group().cpu_group,\n                all_rank_get=False,\n            )\n        else:\n            self.mm_global_cache = None\n\n        if self.rank == 0:\n            logger.info(\n                f\"Using transfer backend: {self.server_args.encoder_transfer_backend}\"\n            )\n\n            if self.server_args.encoder_transfer_backend == \"mooncake\":\n                self.local_ip = get_local_ip_auto()\n\n                self.engine = get_mooncake_transfer_engine()\n                if self.engine is None:\n                    from sglang.srt.distributed.device_communicators.mooncake_transfer_engine import (\n                        init_mooncake_transfer_engine,\n                    )\n\n                    self.engine = init_mooncake_transfer_engine(\n                        hostname=self.local_ip,\n                        gpu_id=self.gpu_id,\n                        ib_device=(\n                            self.server_args.disaggregation_ib_device\n                            or self.server_args.mooncake_ib_device\n                        ),\n                    )\n\n      
      self.embedding_to_send = dict()\n\n        logger.info(f\"rank {rank} init finish \")\n\n    def _infer_embedding_dims(self) -> dict:\n        \"\"\"Infer per-modality embedding dimensions from hf_config at init time.\"\"\"\n        default = self.model_config.hidden_size\n        hf_cfg = self.model_config.hf_config\n        thinker_cfg = getattr(hf_cfg, \"thinker_config\", None)\n        dims = {\n            Modality.IMAGE: default,\n            Modality.VIDEO: default,\n            Modality.AUDIO: default,\n        }\n\n        vision_cfg = getattr(thinker_cfg, \"vision_config\", None) or getattr(\n            hf_cfg, \"vision_config\", None\n        )\n        if vision_cfg is not None:\n            out_hs = getattr(vision_cfg, \"out_hidden_size\", None)\n            if out_hs is not None:\n                ds = getattr(vision_cfg, \"deepstack_visual_indexes\", None)\n                vis_dim = (\n                    out_hs * (1 + len(ds))\n                    if isinstance(ds, (list, tuple)) and ds\n                    else out_hs\n                )\n                dims[Modality.IMAGE] = vis_dim\n                dims[Modality.VIDEO] = vis_dim\n\n        audio_cfg = getattr(thinker_cfg, \"audio_config\", None) or getattr(\n            hf_cfg, \"audio_config\", None\n        )\n        if audio_cfg is not None:\n            for attr in (\"output_dim\", \"d_model\"):\n                val = getattr(audio_cfg, attr, None)\n                if val and int(val) > 0:\n                    dims[Modality.AUDIO] = int(val)\n                    break\n\n        logger.info(f\"Global cache embedding dims: {dims}\")\n        return dims\n\n    def _build_vision_config(self, mm_process_config):\n        \"\"\"\n        Validate vision config, used for image/video/audio.\n        If not provided, keep default values.\n        \"\"\"\n        self.vision_config = (\n            mm_process_config.get(\"vision_config\", {})\n            if mm_process_config is not None\n     
       else {}\n        )\n        for modality_str in [\"image\", \"video\", \"audio\"]:\n            if not self.vision_config.get(modality_str, None):\n                self.vision_config[modality_str] = {}\n            if self.use_image_processor_gpu:\n                self.vision_config[modality_str][\"device\"] = self.device\n\n            if modality_str == \"video\":\n                video_defaults = {\"fps\": 2.0, \"max_frames\": 768, \"min_frames\": 4}\n                for k, v in video_defaults.items():\n                    self.vision_config[\"video\"].setdefault(k, v)\n\n            if modality_str == \"audio\":\n                if \"return_attention_mask\" not in self.vision_config[\"audio\"]:\n                    self.vision_config[\"audio\"][\"return_attention_mask\"] = True\n                if \"padding\" not in self.vision_config[\"audio\"]:\n                    if self.model_type == \"qwen2_audio\":\n                        # For Qwen2Audio, use padding=\"max_length\"\n                        # (same as https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_audio/processing_qwen2_audio.py#L93)\n                        self.vision_config[\"audio\"][\"padding\"] = \"max_length\"\n                    else:\n                        self.vision_config[\"audio\"][\"padding\"] = True\n                if \"truncation\" not in self.vision_config[\"audio\"]:\n                    # keep same logic as base_processor.py\n                    if (\n                        hasattr(self, \"audio_processor\")\n                        and self.audio_processor is not None\n                    ):\n                        if self.audio_processor.__class__.__name__ in {\n                            \"Gemma3nProcessor\",\n                            \"GlmAsrProcessor\",\n                            \"Qwen2AudioProcessor\",\n                            \"Qwen3OmniMoeProcessor\",\n                        }:\n                            
self.vision_config[\"audio\"][\"truncation\"] = False\n\n    def _load_mm_processor(self, server_args: ServerArgs):\n        \"\"\"\n        Load image/video/audio processor separately,\n        avoid issues with AutoProcessor not recognizing certain models\n        \"\"\"\n        from transformers import AutoImageProcessor, AutoVideoProcessor\n\n        try:\n            self.image_processor = AutoImageProcessor.from_pretrained(\n                server_args.tokenizer_path or server_args.model_path,\n                trust_remote_code=server_args.trust_remote_code,\n                revision=server_args.revision,\n                use_fast=not server_args.disable_fast_image_processor,\n            )\n        except Exception as e:\n            logger.warning(f\"Failed to load image processor: {e}\")\n            self.image_processor = None\n\n        try:\n            self.video_processor = AutoVideoProcessor.from_pretrained(\n                server_args.tokenizer_path or server_args.model_path,\n                trust_remote_code=server_args.trust_remote_code,\n                revision=server_args.revision,\n                use_fast=not server_args.disable_fast_image_processor,\n            )\n        except Exception as e:\n            logger.warning(f\"Failed to load video processor: {e}\")\n            self.video_processor = None\n\n        try:\n            # Note: AutoProcessor is used for audio processor\n            _audio_proc = AutoProcessor.from_pretrained(\n                server_args.tokenizer_path or server_args.model_path,\n                trust_remote_code=server_args.trust_remote_code,\n                revision=server_args.revision,\n                use_fast=not server_args.disable_fast_image_processor,\n            )\n            if not hasattr(_audio_proc, \"feature_extractor\"):\n                logger.warning(\n                    \"Loaded AutoProcessor has no feature_extractor attribute, \"\n                    \"audio processing will be 
unavailable.\"\n                )\n                self.audio_processor = None\n            else:\n                self.audio_processor = _audio_proc\n        except Exception as e:\n            logger.warning(f\"Failed to load audio processor: {e}\")\n            self.audio_processor = None\n\n    def _load_single_item(\n        self,\n        data,\n        modality: Modality,\n        frame_count_limit=None,\n        audio_sample_rate: Optional[int] = None,\n        discard_alpha_channel=True,\n    ):\n        \"\"\"\n        Load a single multimodal data.\n        If data is precomputed, returns directly.\n        Static method that can be pickled for multiprocessing\"\"\"\n        if isinstance(data, dict):\n            return data\n        try:\n            if modality == Modality.IMAGE:\n                img, _ = load_image(data)\n                if discard_alpha_channel and img.mode != \"RGB\":\n                    img = img.convert(\"RGB\")\n                return img\n            elif modality == Modality.VIDEO:\n                return load_video(data, frame_count_limit)\n            elif modality == Modality.AUDIO:\n                return load_audio(data, audio_sample_rate)\n\n        except Exception as e:\n            raise RuntimeError(f\"Error while loading data {data}: {e}\")\n\n    def submit_data_loading_tasks(self, items, modalities):\n        futures = []\n        task_info = []\n\n        for data, modality in zip(items, modalities):\n            if modality is not None:\n                futures.append(\n                    self.io_executor.submit(\n                        self._load_single_item,\n                        data,\n                        modality,\n                    )\n                )\n                task_info.append((modality, data))\n        return futures, task_info\n\n    def _get_feat_extract_output_lengths(self, feature_lens):\n        \"\"\"\n        Computes the output length of the convolutional layers and the output 
length of the audio encoder\n        \"\"\"\n        # qwen2_audio/qwen2.5_omni\n        if self.model_type in [\"qwen2_audio\", \"qwen2_5_omni\"]:\n            input_length = (feature_lens - 1) // 2 + 1\n            return (input_length - 2) // 2 + 1\n        # qwen3_omni_moe\n        elif self.model_type == \"qwen3_omni_moe\":\n            input_lengths_leave = feature_lens % 100\n            feat_lengths = (input_lengths_leave - 1) // 2 + 1\n            output_lengths = (\n                ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (feature_lens // 100) * 13\n            )\n            return output_lengths\n        else:\n            # fallback to original HF audio sample logic for other models\n            logger.warning(\n                f\"Fallback to original HF audio sample logic for {self.model_type}\"\n            )\n            input_length = (feature_lens - 1) // 2 + 1\n            return (input_length - 2) // 2 + 1\n\n    async def _flatten_and_load_videos(self, mm_items):\n        if not isinstance(mm_items, (list, tuple)):\n            mm_items = [mm_items]\n\n        futures, _ = self.submit_data_loading_tasks(\n            mm_items, [Modality.VIDEO] * len(mm_items)\n        )\n        async_futures = [asyncio.wrap_future(f) for f in futures]\n        video_items = await asyncio.gather(*async_futures)\n\n        video_processor_kwargs = {}\n        if \"qwen\" in self.model_type:\n            # for qwen-series model, do sample frames before preprocess\n            video_processed = [\n                await preprocess_video(\n                    video, video_config=self.vision_config.get(\"video\", {})\n                )\n                for video in video_items\n            ]\n            videos, video_metadata = map(list, zip(*video_processed))\n            video_processor_kwargs[\"do_sample_frames\"] = False\n            if video_metadata:\n                video_processor_kwargs[\"video_metadata\"] = video_metadata\n            return videos, 
video_processor_kwargs\n        else:\n            raise NotImplementedError(\n                f\"Video processing is not supported for {self.model_type} model.\"\n            )\n\n    async def _flatten_and_load_data_by_modality(self, mm_items, modality):\n        \"\"\"\n        Flatten mm_items structure, load multimodal data concurrently, and restore original structure.\n\n        Returns:\n            Same structure as load_mm_items would return, support for image/audio\n        \"\"\"\n        # Handle single mm_item (not a list)\n        if not isinstance(mm_items, (list, tuple)):\n            futures, _ = self.submit_data_loading_tasks([mm_items], [modality])\n            return await asyncio.wrap_future(futures[0])\n\n        # Handle nested list (list of lists)\n        if len(mm_items) > 0 and isinstance(mm_items[0], (list, tuple)):\n            # Flatten nested structure\n            flat_data = []\n            flat_indices = []  # Track which group each item belongs to\n            for group_idx, item_group in enumerate(mm_items):\n                for item in item_group:\n                    flat_data.append(item)\n                    flat_indices.append(group_idx)\n\n            # Submit all tasks concurrently\n            futures, _ = self.submit_data_loading_tasks(\n                flat_data, [modality] * len(flat_data)\n            )\n\n            # Wait for all tasks to complete asynchronously\n            async_futures = [asyncio.wrap_future(f) for f in futures]\n            results = await asyncio.gather(*async_futures)\n\n            # Restore nested structure\n            nested_results = [[] for _ in range(len(mm_items))]\n            for idx, result in zip(flat_indices, results):\n                nested_results[idx].append(result)\n\n            return nested_results\n\n        # Handle simple list\n        else:\n            futures, _ = self.submit_data_loading_tasks(\n                mm_items, [modality] * len(mm_items)\n            )\n  
          # Wait for all tasks to complete asynchronously\n            async_futures = [asyncio.wrap_future(f) for f in futures]\n            return await asyncio.gather(*async_futures)\n\n    def get_num_patches(\n        self, grid: Union[torch.Tensor, List[int]], modality: Modality\n    ) -> int:\n        \"\"\"Calculate number of raw patches (before merge/sampling). Used for pixel_values slicing.\"\"\"\n        if modality == Modality.AUDIO:\n            return int(grid.item())\n        else:\n            return int(grid[0] * grid[1] * grid[2])\n\n    def get_num_tokens(\n        self, grid: Union[torch.Tensor, List[int]], modality: Modality\n    ) -> int:\n        \"\"\"Calculate number of tokens (after 2x2 merge). Used for mm_embedding slicing.\"\"\"\n        if modality == Modality.AUDIO:\n            input_length = self.get_num_patches(grid, modality)\n            return self._get_feat_extract_output_lengths(input_length)\n        else:\n            merge_size = getattr(self.image_processor, \"merge_size\", 2)\n            return self.get_num_patches(grid, modality) // (merge_size**2)\n\n    def slice_embedding(\n        self, mm_embedding: torch.Tensor, grid_thw: List, modality: Modality\n    ) -> List[torch.Tensor]:\n        \"\"\"Slice a concatenated embedding tensor into individual image embeddings.\"\"\"\n        slices, offset = [], 0\n        for grid in grid_thw:\n            count = self.get_num_tokens(grid, modality)\n            slices.append(mm_embedding[offset : offset + count])\n            offset += count\n        return slices\n\n    def _calculate_hashes_from_features(\n        self, mm_feature: torch.Tensor, grid_thw: List, modality: Modality\n    ) -> List[str]:\n        \"\"\"CPU Task: Compute hashes based on processed feature patches.\"\"\"\n        hashes, offset = [], 0\n        logger.info(f\"{mm_feature.shape=} with {modality=}\")\n        for grid in grid_thw:\n            num_patches = self.get_num_patches(grid, modality)\n        
    feature_slice = mm_feature[offset : offset + num_patches]\n            tmp_item = MultimodalDataItem(modality=modality, feature=feature_slice)\n            tmp_item.set_pad_value()\n            hashes.append(tmp_item.hash)\n            offset += num_patches\n        return hashes\n\n    async def _encode_missing(\n        self,\n        mm_feature: torch.Tensor,\n        mm_inputs: dict,\n        indices: List[int],\n        modality: Modality = Modality.IMAGE,\n        get_feature_fn=None,\n    ) -> List[torch.Tensor]:\n        \"\"\"\n        GPU Task: Run ViT inference ONLY on the subset of mm items missing from the cache.\n        \"\"\"\n        grid_thw = _get_mm_grid_dim(mm_inputs, modality)\n\n        # 1. Slice mm_feature to get only the patches for missing mm items\n        sub_feature_list = []\n        offsets = [0]\n        curr = 0\n        for g in grid_thw:\n            curr += self.get_num_patches(g, modality)\n            offsets.append(curr)\n\n        for idx in indices:\n            sub_feature_list.append(mm_feature[offsets[idx] : offsets[idx + 1]])\n\n        sub_feature = torch.cat(sub_feature_list, dim=0)\n\n        mm_item = MultimodalDataItem.from_dict(\n            {\n                \"modality\": modality,\n                \"feature\": _convert(sub_feature),\n            }\n        )\n\n        for k, v in mm_inputs.items():\n            if k in _mm_feature_attrs.get(modality, []):\n                continue\n            val = _convert(v)\n            if k in _mm_grid_attrs.get(modality, []):\n                mm_item.set(k, val[indices])\n            else:\n                mm_item.set(k, val)\n\n        with torch.inference_mode():\n            new_embeddings = get_feature_fn([mm_item]).cpu()\n            if new_embeddings.ndim != 2:\n                new_embeddings = new_embeddings.reshape(-1, new_embeddings.shape[-1])\n\n        sub_grids = [grid_thw[i] for i in indices]\n        return self.slice_embedding(new_embeddings, 
sub_grids, modality)\n\n    async def encode_with_global_cache(\n        self,\n        mm_items,\n        modality: Modality,\n        req_id: str,\n        num_parts: int,\n        part_idx: int,\n        hashes: Optional[List[str]] = None,\n    ) -> torch.Tensor:\n        mm_inputs, get_feature_fn = await self._process_mm_items(mm_items, modality)\n        grid_thw = _get_mm_grid_dim(mm_inputs, modality)\n        mm_feature = _convert(_get_mm_feature(mm_inputs, modality))\n        num_items = len(grid_thw)\n\n        # Step 1: Rank 0 checks global cache and broadcasts hit/miss mask to all ranks.\n        if self.rank == 0:\n            if hashes is None:\n                mm_hashes = self._calculate_hashes_from_features(\n                    mm_feature, grid_thw, modality\n                )\n            else:\n                mm_hashes = hashes\n            exist_mask = await self.mm_global_cache.batch_is_exist(mm_hashes)\n            mask_tensor = torch.tensor(\n                [1 if e else 0 for e in exist_mask], dtype=torch.int32\n            )\n        else:\n            mm_hashes = None\n            mask_tensor = torch.zeros(num_items, dtype=torch.int32)\n\n        if self.server_args.tp_size > 1:\n            torch.distributed.broadcast(\n                mask_tensor,\n                src=0,\n                group=self.mm_global_cache.prefetch_tp_group,\n            )\n\n        exist_mask = [m.item() == 1 for m in mask_tensor]\n        missing_indices = [i for i, e in enumerate(exist_mask) if not e]\n        hit_indices = [i for i, e in enumerate(exist_mask) if e]\n\n        # Step 2: All ranks run ViT together on cache-miss mm items.\n        new_slices = []\n        if missing_indices:\n            new_slices = await self._encode_missing(\n                mm_feature, mm_inputs, missing_indices, modality, get_feature_fn\n            )\n\n        # Step 3: Rank 0 prefetches cache-hit embeddings from global cache.\n        prefetch_status = torch.tensor([1], 
dtype=torch.int32)\n\n        if self.rank == 0:\n            if hit_indices:\n                hit_hashes = [mm_hashes[i] for i in hit_indices]\n                hit_tokens = [\n                    self.get_num_tokens(grid_thw[i], modality) for i in hit_indices\n                ]\n                self.mm_global_cache.prefetch(req_id, hit_hashes, hit_tokens, modality)\n\n                try:\n\n                    async def _wait_prefetch():\n                        while not self.mm_global_cache.check_prefetch_progress(req_id):\n                            await asyncio.sleep(0.005)\n\n                    await asyncio.wait_for(_wait_prefetch(), timeout=60.0)\n                except (asyncio.TimeoutError, Exception) as e:\n                    logger.error(\n                        f\"Prefetch failed for req {req_id}: {e}. \"\n                        f\"Falling back to ViT for {len(hit_indices)} hit items.\"\n                    )\n                    prefetch_status[0] = 0\n\n        # Step 4: Broadcast prefetch result to all ranks so they stay in sync.\n        if self.server_args.tp_size > 1:\n            torch.distributed.broadcast(\n                prefetch_status,\n                src=0,\n                group=self.mm_global_cache.prefetch_tp_group,\n            )\n\n        # Step 5: If prefetch failed, all ranks fallback to ViT for the hit mm items.\n        if prefetch_status.item() == 0 and hit_indices:\n            logger.info(\n                f\"Req {req_id}: Prefetch failed, all ranks running ViT fallback \"\n                f\"for {len(hit_indices)} mm items.\"\n            )\n            fallback_slices = await self._encode_missing(\n                mm_feature, mm_inputs, hit_indices, modality, get_feature_fn\n            )\n        else:\n            fallback_slices = None\n\n        # Step 6: Rank 0 assembles final embedding and prepares for sending.\n        if self.rank == 0:\n            final_slices = [None] * num_items\n\n            for i, idx 
in enumerate(missing_indices):\n                final_slices[idx] = new_slices[i]\n\n            # Fill in cache-hit embeddings (from prefetch or fallback)\n            if prefetch_status.item() == 1 and hit_indices:\n                cached_slices = self.mm_global_cache.get_embeddings(\n                    [mm_hashes[i] for i in hit_indices]\n                )\n                for i, idx in enumerate(hit_indices):\n                    final_slices[idx] = cached_slices[i]\n            elif fallback_slices is not None:\n                for i, idx in enumerate(hit_indices):\n                    final_slices[idx] = fallback_slices[i]\n\n            mm_embedding = torch.cat(final_slices, dim=0)\n\n            # Background insert: store newly computed embeddings into global cache.\n            # Includes both original misses and fallback-recomputed hits.\n            all_new_hashes = [mm_hashes[i] for i in missing_indices]\n            all_new_slices = list(new_slices)\n            if fallback_slices is not None:\n                all_new_hashes += [mm_hashes[i] for i in hit_indices]\n                all_new_slices += list(fallback_slices)\n\n            if all_new_hashes:\n\n                async def _background_insert():\n                    await asyncio.to_thread(\n                        self.mm_global_cache.insert_batch,\n                        all_new_hashes,\n                        all_new_slices,\n                    )\n\n                task = asyncio.create_task(_background_insert())\n                self.background_tasks.add(task)\n                task.add_done_callback(self.background_tasks.discard)\n\n            aux_data = _build_mm_aux_data(mm_inputs)\n            self.embedding_to_send[req_id] = EmbeddingData(\n                req_id,\n                num_parts,\n                part_idx,\n                grid_thw,\n                modality,\n                mm_embedding,\n                **aux_data,\n            )\n            return (\n                
mm_embedding.nbytes,\n                mm_embedding.shape[0],\n                mm_embedding.shape[1],\n                None,\n                None,\n            )\n        else:\n            return (0, 0, 0, None, None)\n\n    async def _flatten_and_load_audios(self, mm_items):\n        \"\"\"\n        Flatten mm_items structure, load audios concurrently, and restore original structure.\n        \"\"\"\n        return await self._flatten_and_load_data_by_modality(mm_items, Modality.AUDIO)\n\n    async def _flatten_and_load_images(self, mm_items):\n        \"\"\"\n        Flatten mm_items structure, load images concurrently, and restore original structure.\n        \"\"\"\n        return await self._flatten_and_load_data_by_modality(mm_items, Modality.IMAGE)\n\n    def _calculate_timestamps(self, indices, video_fps: float, merge_size: int = 2):\n        \"\"\"Calculate timestamps for video frames, used for qwen3_vl models.\"\"\"\n        # refer to https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl/processing_qwen3_vl.py#L255\n        if not isinstance(indices, list):\n            indices = indices.tolist()\n        if len(indices) % merge_size != 0:\n            indices.extend(\n                indices[-1] for _ in range(merge_size - len(indices) % merge_size)\n            )\n        timestamps = [idx / video_fps for idx in indices]\n        # Frames are merged by merge_size, so we need to average the timestamps\n        # between the first/last frame within the temporal patch\n        timestamps = [\n            (timestamps[i] + timestamps[i + merge_size - 1]) / 2\n            for i in range(0, len(timestamps), merge_size)\n        ]\n        return timestamps\n\n    async def _process_mm_items(self, mm_items, modality):\n        if modality == Modality.IMAGE and self.image_processor:\n            images = await self._flatten_and_load_images(mm_items)\n            image_config = self.vision_config.get(\"image\", {})\n            
processor_input = self.image_processor(images=images, **image_config)\n            feature = processor_input[\"pixel_values\"]\n            if hasattr(self.model, \"thinker\"):  # for omni models\n                get_feature_method = self.model.thinker.get_image_feature\n            else:\n                get_feature_method = self.model.get_image_feature\n        elif modality == Modality.VIDEO and self.video_processor:\n            videos, video_processor_kwargs = await self._flatten_and_load_videos(\n                mm_items\n            )\n            processor_input = self.video_processor(\n                videos=videos, **video_processor_kwargs\n            )\n            # Get additional video metadata\n            if (\n                self.model_type in [\"qwen3_vl\", \"qwen3_vl_moe\"]\n                and video_processor_kwargs.get(\"video_metadata\", None) is not None\n            ):\n                # For qwen3-vl models, we need to store the video timestamps\n                video_metadata = video_processor_kwargs[\"video_metadata\"]\n                try:\n                    merge_size = (\n                        self.model_config.hf_config.vision_config.spatial_merge_size\n                    )\n                except (AttributeError, KeyError):\n                    merge_size = 2  # Default merge_size\n\n                video_timestamps = []\n                for metadata in video_metadata:\n                    video_fps = metadata.get(\"fps\", None) or 24  # original video fps\n                    frames_indices = metadata.get(\"frames_indices\", None)\n                    timestamps = self._calculate_timestamps(\n                        frames_indices, video_fps, merge_size\n                    )\n                    video_timestamps.append(timestamps)\n                processor_input[\"video_timestamps\"] = video_timestamps\n            elif (\n                self.model_type in [\"qwen2_5_vl\", \"qwen2_5_omni\", \"qwen3_omni_moe\"]\n              
  and processor_input.get(\"video_grid_thw\", None) is not None\n            ):\n                # For omni/qwen2_5_vl models, calculate second_per_grid_ts for rotary embedding\n                video_grid_thw = processor_input[\"video_grid_thw\"]\n                try:\n                    temporal_patch_size = self.video_processor.temporal_patch_size\n                except AttributeError:\n                    temporal_patch_size = 2  # Default temporal_patch_size\n                # get sampled fps, default: 2\n                fps_list = [\n                    self.vision_config.get(\"video\", {}).get(\"fps\", None) or 2\n                ] * len(video_grid_thw)\n                second_per_grid_ts = [(temporal_patch_size / fps) for fps in fps_list]\n                second_per_grid_ts_tensor = torch.tensor(\n                    second_per_grid_ts, dtype=torch.float32\n                )\n                processor_input[\"second_per_grid_ts\"] = second_per_grid_ts_tensor\n\n            feature = processor_input[\"pixel_values_videos\"]\n            if hasattr(self.model, \"thinker\"):  # for omni models\n                get_feature_method = self.model.thinker.get_video_feature\n            else:\n                get_feature_method = self.model.get_video_feature\n        elif modality == Modality.AUDIO and self.audio_processor:\n            audios = await self._flatten_and_load_audios(mm_items)\n            audio_config = self.vision_config.get(\"audio\", {})\n            processor_input = self.audio_processor.feature_extractor(\n                audios, **audio_config\n            )\n            processor_input[\"feature_attention_mask\"] = processor_input.pop(\n                \"attention_mask\"\n            )\n            # convert to same format as image/video\n            input_lengths = torch.tensor(\n                processor_input[\"feature_attention_mask\"].sum(-1), dtype=torch.long\n            )\n            processor_input[\"audio_feature_lens_raw\"] = 
input_lengths\n            output_lengths = self._get_feat_extract_output_lengths(input_lengths)\n            processor_input[\"audio_feature_lens\"] = output_lengths\n            feature = processor_input[\"input_features\"]\n            if hasattr(self.model, \"thinker\"):  # for omni models\n                get_feature_method = self.model.thinker.get_audio_feature\n            else:\n                get_feature_method = self.model.get_audio_feature\n        else:\n            raise ValueError(\n                f\"Only image, video and audio modalities are currently supported; no processor is available for the {modality} modality.\"\n            )\n\n        return processor_input, get_feature_method\n\n    async def _encode(self, mm_items, modality: Modality) -> torch.Tensor:\n        try:\n            mm_inputs, get_feature_fn = await self._process_mm_items(mm_items, modality)\n        except NotImplementedError as e:\n            raise InternalError(f\"Not implemented error: {str(e)}\")\n        except Exception as e:\n            raise BadRequestError(f\"Failed to process mm items: {str(e)}\")\n        try:\n            # support mm_cache\n            mm_embedding = None\n            mm_hash = None\n\n            mm_item = MultimodalDataItem.from_dict(\n                {\n                    \"modality\": modality,\n                    \"feature\": _convert(_get_mm_feature(mm_inputs, modality)),\n                }\n            )\n            for k, v in mm_inputs.items():\n                if k in _mm_feature_attrs[modality]:\n                    continue\n                mm_item.set(k, _convert(v))\n\n            if self.server_args.enable_prefix_mm_cache:\n                mm_item.set_pad_value()\n                mm_hash = MultiModalStaticCache.combine_hashes([mm_item.hash])\n                async with self.mm_cache_lock:\n                    mm_cache = self.mm_cache.get([mm_item.hash])\n                    if mm_cache is not None:\n                        mm_embedding = 
mm_cache.embedding\n\n            if mm_embedding is None:\n                with torch.inference_mode():\n                    mm_embedding: torch.Tensor = get_feature_fn([mm_item])\n                    mm_embedding = mm_embedding.cpu()\n                if len(mm_embedding.shape) != 2:\n                    mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1])\n\n            if self.server_args.enable_prefix_mm_cache:\n                async with self.mm_cache_lock:\n                    self.mm_cache.set(mm_hash, EmbeddingResult(embedding=mm_embedding))\n            if self.profiler is not None:\n                self.profiler.step()\n\n            aux_data = _build_mm_aux_data(mm_inputs)\n            return _get_mm_grid_dim(mm_inputs, modality), mm_embedding, aux_data\n        except BadRequestError as e:\n            raise BadRequestError(f\"Bad request error: {str(e)}\")\n        except Exception as e:\n            raise InternalError(f\"Internal encoding error: {str(e)}\")\n\n    async def _send(\n        self,\n        embedding: torch.Tensor,\n        mm_data: EmbeddingData,\n        session_id=None,\n        buffer_address=None,\n        prefill_host=None,\n        embedding_port=None,\n        url=None,\n    ):\n        if self.server_args.encoder_transfer_backend == \"mooncake\":\n            self.engine.register(embedding.data_ptr(), embedding.nbytes)\n            self.engine.transfer_sync(\n                session_id, embedding.data_ptr(), buffer_address, embedding.nbytes\n            )\n            self.engine.deregister(embedding.data_ptr())\n\n            mm_data.embedding = None\n\n        # Send ack/data\n        if url is not None:\n            endpoint = NetworkAddress.parse(url).to_tcp()\n        else:\n            endpoint = NetworkAddress(prefill_host, embedding_port).to_tcp()\n        logger.info(f\"{endpoint = }\")\n\n        # Serialize data\n        if self.server_args.encoder_transfer_backend == \"mooncake\":\n            
serialized_data = pickle.dumps(mm_data)\n            buffer = None\n        else:\n            new_mm_data = mm_data.copy_without_embedding()\n            if new_mm_data.error_msg is not None:\n                buffer = None\n                serialized_data = pickle.dumps(new_mm_data)\n            else:\n                embedding_tensor = TensorWrapper(mm_data.embedding)\n                serialized_data = pickle.dumps(new_mm_data)\n                buffer = embedding_tensor.__buffer__()\n\n        # Use thread pool executor for parallel ZMQ send operations\n        def send_with_socket():\n            sock = self.sync_context.socket(zmq.PUSH)\n            config_socket(sock, zmq.PUSH)\n            try:\n                sock.connect(endpoint)\n                if buffer is not None:\n                    sock.send_multipart([serialized_data, buffer], copy=False)\n                else:\n                    sock.send_multipart([serialized_data], copy=False)\n            finally:\n                sock.close()\n\n        await asyncio.get_event_loop().run_in_executor(self.executor, send_with_socket)\n\n    async def encode(self, mm_items, modality: Modality, req_id, num_parts, part_idx):\n        try:\n            grid_dim, mm_embedding, aux_data = await self._encode(mm_items, modality)\n\n            if self.rank == 0:\n                mm_data = EmbeddingData(\n                    req_id,\n                    num_parts,\n                    part_idx,\n                    grid_dim,\n                    modality,\n                    mm_embedding,\n                    **aux_data,\n                )\n                self.embedding_to_send[req_id] = mm_data\n            return (\n                mm_embedding.nbytes,\n                mm_embedding.shape[0],\n                mm_embedding.shape[1],\n                None,\n                None,\n            )\n        except Exception as e:\n            error_code = getattr(e, \"code\", HTTPStatus.INTERNAL_SERVER_ERROR)\n           
 error_msg = str(e)\n            logger.error(f\"Rank {self.rank} encode failed: {error_msg} {error_code = }\")\n            if self.rank == 0:\n                mm_data = EmbeddingData(\n                    req_id,\n                    num_parts,\n                    part_idx,\n                    None,\n                    modality,\n                    error_msg=error_msg,\n                    error_code=error_code,\n                )\n                self.embedding_to_send[req_id] = mm_data\n                logger.debug(f\"Created error EmbeddingData: {mm_data}\")\n            return 0, 0, 0, error_msg, error_code\n\n    # For zmq_to_tokenizer zmq_to_scheduler and mooncake\n    async def send(\n        self, req_id, prefill_host, embedding_port, session_id=None, buffer_address=None\n    ):\n        mm_data: EmbeddingData = self.embedding_to_send[req_id]\n        await self._send(\n            mm_data.embedding,\n            mm_data,\n            session_id=session_id,\n            buffer_address=buffer_address,\n            prefill_host=prefill_host,\n            embedding_port=embedding_port,\n        )\n\n    # For zmq_to_scheduler\n    async def send_with_url(\n        self,\n        req_id,\n    ):\n        mm_data = self.embedding_to_send.get(req_id)\n        if not mm_data:\n            return\n        sent_urls: Set[str] = set()\n        all_tasks: List[Tuple[asyncio.Task, str]] = []\n        start_time = asyncio.get_running_loop().time()\n        timeout = self.send_timeout\n        cond = await get_condition(req_id)\n\n        try:\n            while True:\n                async with rid_lock:\n                    current_targets = rid_to_receive_endpoint.get(req_id, set()).copy()\n                    expected_count = rid_to_receive_count.get(req_id)\n\n                new_targets = current_targets - sent_urls\n\n                if new_targets:\n                    logger.info(\n                        f\"Found {len(new_targets)} new endpoints for 
{req_id}. Starting tasks...\"\n                    )\n                    for url in new_targets:\n                        task = asyncio.create_task(\n                            self._send(\n                                mm_data.embedding,\n                                mm_data,\n                                url=url,\n                            )\n                        )\n                        all_tasks.append((task, url))\n                        sent_urls.add(url)  # Mark as handled immediately\n                if expected_count is not None and len(sent_urls) >= expected_count:\n                    logger.info(\n                        f\"All {expected_count} endpoints initiated for {req_id}. Breaking loop.\"\n                    )\n                    break\n                remaining = timeout - (asyncio.get_running_loop().time() - start_time)\n                if remaining <= 0:\n                    logger.error(\n                        f\"[{req_id}] Timeout! Sent {len(sent_urls)}/{expected_count}\"\n                    )\n                    break\n\n                async with cond:\n                    try:\n                        await asyncio.wait_for(cond.wait(), timeout=remaining)\n                    except asyncio.TimeoutError:\n                        continue\n\n            if all_tasks:\n                logger.info(\n                    f\"Loop finished. 
Awaiting completion of {len(all_tasks)} sending tasks...\"\n                )\n                tasks_only = [t[0] for t in all_tasks]\n                results = await asyncio.gather(*tasks_only, return_exceptions=True)\n\n                # Process results and log errors\n                for i, result in enumerate(results):\n                    url = all_tasks[i][1]  # Retrieve URL associated with the task\n                    if isinstance(result, Exception):\n                        logger.error(f\"Failed to send to {url}: {result}\")\n                    else:\n                        logger.debug(f\"Successfully sent to {url}\")\n\n            logger.info(f\"All tasks completed for req_id: {req_id}\")\n\n        finally:\n            logger.info(f\"Cleaning up resources for req_id {req_id}\")\n            async with rid_lock:\n                rid_to_receive_endpoint.pop(req_id, None)\n                rid_to_receive_count.pop(req_id, None)\n            async with cond_dict_lock:\n                rid_to_cond.pop(req_id, None)\n            self.embedding_to_send.pop(req_id, None)\n\n    async def get_embedding_port(self, prefill_url):\n        async with aiohttp.ClientSession(\n            timeout=aiohttp.ClientTimeout(total=1800)\n        ) as session:\n            response = await session.post(\n                f\"{prefill_url}/embedding_bootstrap\",\n                json={\"embedding_port\": None},\n            )\n            response_json = await response.json()\n            return response_json[\"embedding_port\"]\n\n\nclass EncoderProfiler:\n    def __init__(self, rank: int):\n        self.rank = rank\n        self.profiler = None\n        self.steps_left = None\n        self.output_dir = None\n        self.prefix = None\n        self.profile_id = None\n\n    def start(self, obj: ProfileReq):\n        if self.profiler is not None:\n            return False, \"profiling already running\"\n\n        output_dir = obj.output_dir or 
os.getenv(\"SGLANG_TORCH_PROFILER_DIR\", \"/tmp\")\n        os.makedirs(output_dir, exist_ok=True)\n        self.output_dir = output_dir\n        self.prefix = obj.profile_prefix or \"encoder\"\n        self.profile_id = str(time.time())\n\n        activities = obj.activities or [\"CPU\", \"GPU\"]\n        torch_activities = []\n        if \"CPU\" in activities:\n            torch_activities.append(torch.profiler.ProfilerActivity.CPU)\n        if \"GPU\" in activities:\n            torch_activities.append(torch.profiler.ProfilerActivity.CUDA)\n\n        profile_memory = \"MEM\" in activities\n        if not torch_activities and not profile_memory:\n            return False, \"no supported activities\"\n\n        self.profiler = torch.profiler.profile(\n            activities=torch_activities,\n            with_stack=True if obj.with_stack is None else obj.with_stack,\n            record_shapes=False if obj.record_shapes is None else obj.record_shapes,\n            profile_memory=profile_memory,\n        )\n        self.profiler.start()\n        self.steps_left = obj.num_steps\n        logger.info(\n            f\"Encoder profiling started. 
output_dir={self.output_dir} profile_id={self.profile_id}\"\n        )\n        return True, None\n\n    def step(self):\n        if self.profiler is None:\n            return\n        self.profiler.step()\n        if self.steps_left is not None:\n            self.steps_left -= 1\n            if self.steps_left <= 0:\n                self.stop()\n\n    def stop(self):\n        if self.profiler is None:\n            return False, \"profiling not running\"\n        self.profiler.stop()\n        filename = f\"{self.prefix}-rank{self.rank}-{self.profile_id}.trace.json\"\n        trace_path = os.path.join(self.output_dir, filename)\n        self.profiler.export_chrome_trace(trace_path)\n        logger.info(\"Encoder profiling saved to: %s\", trace_path)\n        self.profiler = None\n        self.steps_left = None\n        return True, None\n\n\napp = FastAPI()\nencoder: Optional[MMEncoder] = None\nsend_sockets: List[zmq.Socket] = []\n\n\nasync def run_encoder(\n    server_args: ServerArgs, schedule_path, dist_init_method, rank: int\n):\n    encoder = MMEncoder(server_args, schedule_path, dist_init_method, rank)\n    while True:\n        request = await encoder.schedule_socket.recv_pyobj()\n        if isinstance(request, ProfileReq):\n            if request.type == ProfileReqType.START_PROFILE:\n                if encoder.profiler is None:\n                    encoder.profiler = EncoderProfiler(encoder.rank)\n                encoder.profiler.start(request)\n            else:\n                encoder.profiler.stop()\n        else:\n            if encoder.mm_global_cache is not None:\n                await encoder.encode_with_global_cache(\n                    mm_items=request[\"mm_items\"],\n                    modality=Modality.from_str(request[\"modality\"]),\n                    req_id=request[\"req_id\"],\n                    num_parts=request[\"num_parts\"],\n                    part_idx=request[\"part_idx\"],\n                    hashes=request.get(\"hashes\", 
None),\n                )\n            else:\n                await encoder.encode(\n                    mm_items=request[\"mm_items\"],\n                    modality=Modality.from_str(request[\"modality\"]),\n                    req_id=request[\"req_id\"],\n                    num_parts=request[\"num_parts\"],\n                    part_idx=request[\"part_idx\"],\n                )\n\n\ndef launch_encoder(server_args, schedule_path, dist_init_method, rank):\n    try:\n        asyncio.run(run_encoder(server_args, schedule_path, dist_init_method, rank))\n    except KeyboardInterrupt:\n        logger.info(f\"Exit rank {rank}\")\n    except Exception:\n        traceback.print_exc()\n\n\ndef launch_server(server_args: ServerArgs):\n    global encoder\n    ctx = mp.get_context(\"spawn\")\n    zmq_ctx = zmq.Context(10)\n    ipc_path_prefix = random_uuid()\n    port_args = PortArgs.init_new(server_args)\n    if server_args.dist_init_addr:\n        na = NetworkAddress.parse(server_args.dist_init_addr)\n        dist_init_method = na.to_tcp()\n    else:\n        dist_init_method = NetworkAddress(\n            server_args.host or \"127.0.0.1\", port_args.nccl_port\n        ).to_tcp()\n    for rank in range(1, server_args.tp_size):\n        schedule_path = f\"ipc:///tmp/{ipc_path_prefix}_schedule_{rank}\"\n        send_sockets.append(\n            get_zmq_socket(zmq_ctx, zmq.PUSH, schedule_path, bind=False)\n        )\n        ctx.Process(\n            target=launch_encoder,\n            args=(server_args, schedule_path, dist_init_method, rank),\n            daemon=True,\n        ).start()\n    encoder = MMEncoder(server_args, dist_init_method=dist_init_method)\n    uvicorn.run(app, host=server_args.host, port=server_args.port)\n\n\nasync def get_condition(rid):\n    async with cond_dict_lock:\n        if rid not in rid_to_cond:\n            rid_to_cond[rid] = asyncio.Condition()\n        return rid_to_cond[rid]\n\n\n@app.post(\"/encode\")\nasync def 
handle_encode_request(request: dict):\n    req_id = request[\"req_id\"]\n    try:\n\n        def start_background_send(req_id):\n            task = asyncio.create_task(encoder.send_with_url(req_id=req_id))\n            encoder.background_tasks.add(task)\n            task.add_done_callback(encoder.background_tasks.discard)\n\n        # broadcast request\n        request.update({\"enter_time\": time.time()})\n        for socket in send_sockets:\n            socket.send_pyobj(request)\n        if encoder.mm_global_cache is not None:\n            nbytes, embedding_len, embedding_dim, error_msg, error_code = (\n                await encoder.encode_with_global_cache(\n                    mm_items=request[\"mm_items\"],\n                    modality=Modality.from_str(request[\"modality\"]),\n                    req_id=request[\"req_id\"],\n                    num_parts=request[\"num_parts\"],\n                    part_idx=request[\"part_idx\"],\n                    hashes=request.get(\"hashes\", None),\n                )\n            )\n        else:\n            nbytes, embedding_len, embedding_dim, error_msg, error_code = (\n                await encoder.encode(\n                    mm_items=request[\"mm_items\"],\n                    modality=Modality.from_str(request[\"modality\"]),\n                    req_id=request[\"req_id\"],\n                    num_parts=request[\"num_parts\"],\n                    part_idx=request[\"part_idx\"],\n                )\n            )\n\n        if error_msg:\n            if encoder.server_args.encoder_transfer_backend == \"zmq_to_scheduler\":\n                if request[\"embedding_port\"] is None:\n                    start_background_send(req_id)\n                else:\n                    for port in request[\"embedding_port\"]:\n                        await encoder.send(\n                            req_id=req_id,\n                            prefill_host=request[\"prefill_host\"],\n                            
embedding_port=port,\n                        )\n            return ORJSONResponse(\n                status_code=error_code,\n                content={\"status\": \"error\", \"message\": error_msg, \"req_id\": req_id},\n            )\n        if encoder.server_args.encoder_transfer_backend == \"mooncake\":\n            del request[\"mm_items\"]\n            request.update(\n                {\n                    \"embedding_size\": nbytes,\n                    \"embedding_len\": embedding_len,\n                    \"embedding_dim\": embedding_dim,\n                }\n            )\n            return ORJSONResponse(content=request)\n        elif encoder.server_args.encoder_transfer_backend == \"zmq_to_scheduler\":\n            logger.info(f\"{request['embedding_port'] = }\")\n            if request[\"embedding_port\"] is None:\n                await encoder.send_with_url(\n                    req_id=request[\"req_id\"],\n                )\n            else:\n                assert type(request[\"embedding_port\"]) == list\n                tasks = []\n                for embedding_port in request[\"embedding_port\"]:\n                    tasks.append(\n                        encoder.send(\n                            req_id=request[\"req_id\"],\n                            prefill_host=request[\"prefill_host\"],\n                            embedding_port=embedding_port,\n                        )\n                    )\n                await asyncio.gather(*tasks)\n                encoder.embedding_to_send.pop(request[\"req_id\"], None)\n            return ORJSONResponse(content=None)\n        elif encoder.server_args.encoder_transfer_backend == \"zmq_to_tokenizer\":\n            await encoder.send(\n                req_id=request[\"req_id\"],\n                prefill_host=request[\"prefill_host\"],\n                embedding_port=request[\"embedding_port\"],\n            )\n            encoder.embedding_to_send.pop(request[\"req_id\"], None)\n            return 
ORJSONResponse(content=None)\n    except Exception as e:\n        error_msg = str(e)\n        logger.error(f\"Unexpected error in encoder logic for {req_id}: {error_msg}\")\n        rid_to_err_msg[req_id] = error_msg\n        return ORJSONResponse(\n            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,\n            content={\n                \"status\": \"error\",\n                \"message\": error_msg,\n                \"req_id\": req_id,\n            },\n        )\n\n\n@app.post(\"/send\")\nasync def handle_send_request(request: dict):\n    # mooncake backend\n    await encoder.send(\n        req_id=request[\"req_id\"],\n        prefill_host=request[\"prefill_host\"],\n        embedding_port=request[\"embedding_port\"],\n        session_id=request[\"session_id\"],\n        buffer_address=request[\"buffer_address\"],\n    )\n    encoder.embedding_to_send.pop(request[\"req_id\"], None)\n    return ORJSONResponse(content=None)\n\n\n@app.post(\"/scheduler_receive_url\")\nasync def handle_scheduler_receive_url_request(request: dict):\n    rid = request[\"req_id\"]\n    async with rid_lock:\n        global rid_to_receive_endpoint\n        if rid not in rid_to_receive_endpoint:\n            rid_to_receive_endpoint[rid] = set()\n            rid_to_receive_count[rid] = request[\"receive_count\"]\n        assert rid_to_receive_count[rid] == request[\"receive_count\"]\n        rid_to_receive_endpoint[rid].add(request[\"receive_url\"])\n    cond = await get_condition(rid)\n    async with cond:\n        cond.notify_all()\n\n\n@app.get(\"/health\")\n@app.get(\"/health_generate\")\nasync def health_generate():\n    \"\"\"\n    Health check endpoint for the encoder server.\n    Returns 200 if the encoder is initialized and ready.\n    \"\"\"\n    if encoder is None:\n        return Response(status_code=503)\n    return Response(status_code=200)\n\n\n@app.api_route(\"/start_profile\", methods=[\"GET\", \"POST\"])\nasync def start_profile_async(obj: Optional[ProfileReqInput] 
= None):\n    if encoder is None:\n        return Response(content=\"encoder not ready\\n\", status_code=503)\n    req = None\n    if obj is None:\n        req = ProfileReq(ProfileReqType.START_PROFILE)\n    else:\n        req = ProfileReq(\n            type=ProfileReqType.START_PROFILE,\n            output_dir=obj.output_dir,\n            start_step=obj.start_step,\n            num_steps=obj.num_steps,\n            activities=obj.activities,\n            with_stack=obj.with_stack,\n            record_shapes=obj.record_shapes,\n            profile_by_stage=obj.profile_by_stage,\n            profile_id=str(time.time()),\n            merge_profiles=obj.merge_profiles,\n            profile_prefix=obj.profile_prefix,\n            profile_stages=obj.profile_stages,\n        )\n    for socket in send_sockets:\n        socket.send_pyobj(req)\n    if encoder.profiler is None:\n        encoder.profiler = EncoderProfiler(encoder.rank)\n    ok, msg = encoder.profiler.start(req)\n    if ok:\n        detail = (\n            f\"Start profiling. 
output_dir={encoder.profiler.output_dir} \"\n            f\"profile_id={encoder.profiler.profile_id}\\n\"\n        )\n        return Response(content=detail, status_code=200)\n    return Response(\n        content=(msg or \"Start profiling failed.\\n\"), status_code=HTTPStatus.BAD_REQUEST\n    )\n\n\n@app.api_route(\"/stop_profile\", methods=[\"GET\", \"POST\"])\nasync def stop_profile_async():\n    if encoder is None:\n        return Response(content=\"encoder not ready\\n\", status_code=503)\n    if encoder.profiler is None:\n        return Response(\n            content=\"profiling not initialized\\n\", status_code=HTTPStatus.BAD_REQUEST\n        )\n    req = ProfileReq(ProfileReqType.STOP_PROFILE)\n    for socket in send_sockets:\n        socket.send_pyobj(req)\n    ok, msg = encoder.profiler.stop()\n    if ok:\n        return Response(content=\"Stop profiling.\\n\", status_code=200)\n    return Response(\n        content=(msg or \"Stop profiling failed.\\n\"), status_code=HTTPStatus.BAD_REQUEST\n    )\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/fake/__init__.py",
    "content": "from sglang.srt.disaggregation.fake.conn import (\n    FakeKVManager,\n    FakeKVReceiver,\n    FakeKVSender,\n)\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/fake/conn.py",
    "content": "import logging\nfrom typing import List, Optional\n\nimport numpy as np\nimport numpy.typing as npt\n\nfrom sglang.srt.disaggregation.base.conn import (\n    BaseKVManager,\n    BaseKVReceiver,\n    BaseKVSender,\n    KVArgs,\n    KVPoll,\n)\nfrom sglang.srt.disaggregation.utils import DisaggregationMode\nfrom sglang.srt.server_args import ServerArgs\n\nlogger = logging.getLogger(__name__)\n\n\n# Warmup requests skip KV transfer, so they use the fake manager, sender, and receiver\nclass FakeKVManager(BaseKVManager):\n    def __init__(\n        self,\n        args: KVArgs,\n        disaggregation_mode: DisaggregationMode,\n        server_args: ServerArgs,\n        is_mla_backend: Optional[bool] = False,\n    ):\n        super().__init__(args, disaggregation_mode, server_args, is_mla_backend)\n\n    def register_to_bootstrap(self):\n        pass\n\n\nclass FakeKVSender(BaseKVSender):\n    def __init__(\n        self,\n        mgr: BaseKVManager,\n        bootstrap_addr: str,\n        bootstrap_room: int,\n        dest_tp_ranks: List[int],\n        pp_rank: int,\n    ):\n        self.has_sent = False\n\n    def poll(self) -> KVPoll:\n        if self.has_sent is False:\n            # Assume handshake completed instantly\n            return KVPoll.WaitingForInput\n        else:\n            # Assume transfer completed instantly\n            logger.debug(\"FakeKVSender poll success\")\n            return KVPoll.Success\n\n    def init(\n        self,\n        kv_indices: list[int],\n        aux_index: Optional[int] = None,\n    ):\n        logger.debug(\n            f\"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}\"\n        )\n\n    def send(\n        self,\n        kv_indices: npt.NDArray[np.int32],\n        state_indices: Optional[List[int]] = None,\n    ):\n        self.has_sent = True\n        logger.debug(\n            f\"FakeKVSender send with kv_indices: {kv_indices}, state_indices: {state_indices}\"\n        
)\n\n    def failure_exception(self):\n        raise Exception(\"Fake KVSender Exception\")\n\n\nclass FakeKVReceiver(BaseKVReceiver):\n    def __init__(\n        self,\n        mgr: BaseKVManager,\n        bootstrap_addr: str,\n        bootstrap_room: Optional[int] = None,\n        prefill_dp_rank: Optional[int] = None,\n    ):\n        self.has_init = False\n\n    def poll(self) -> KVPoll:\n        if self.has_init is False:\n            # Assume handshake completed instantly\n            return KVPoll.WaitingForInput\n        else:\n            # Assume transfer completed instantly\n            logger.debug(\"FakeKVReceiver poll success\")\n            return KVPoll.Success\n\n    def init(\n        self,\n        kv_indices: list[int],\n        aux_index: Optional[int] = None,\n        state_indices: Optional[List[int]] = None,\n    ):\n        self.has_init = True\n        logger.debug(\n            f\"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}\"\n        )\n\n    def failure_exception(self):\n        raise Exception(\"Fake KVReceiver Exception\")\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/kv_events.py",
    "content": "\"\"\"\nCopyright 2025 SGLang Team\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\"\"\"\nKV caching events\n\"\"\"\n\nimport atexit\nimport logging\nimport queue\nimport threading\nimport time\nfrom abc import ABC, abstractmethod\nfrom collections import deque\nfrom itertools import count\nfrom queue import Queue\nfrom typing import Any, Callable, Optional, Union\n\nimport msgspec\nimport zmq\nfrom pydantic import BaseModel\n\nlogger = logging.getLogger(__name__)\n\n\nclass EventBatch(\n    msgspec.Struct,\n    array_like=True,  # type: ignore[call-arg]\n    omit_defaults=True,  # type: ignore[call-arg]\n    gc=False,  # type: ignore[call-arg]\n):\n    ts: float\n    events: list[Any]\n    attn_dp_rank: Optional[int] = None\n\n\nclass KVCacheEvent(\n    msgspec.Struct,\n    array_like=True,  # type: ignore[call-arg]\n    omit_defaults=True,  # type: ignore[call-arg]\n    gc=False,  # type: ignore[call-arg]\n    tag=True,\n):\n    \"\"\"Base class for all KV cache-related events\"\"\"\n\n\n# Medium values for hicache storage tiers\nMEDIUM_GPU = \"GPU\"\nMEDIUM_CPU = \"CPU_PINNED\"\n\n\nclass OffloadedState:\n    \"\"\"\n    OffloadedState represents the state of a KV cache block offloaded to the hicache.\n\n    - prefill_len (int): The length of the prefill part of the KV cache block.\n    - inc_len (int): The length of the incremental part of the KV cache block.\n    - last_hash (Optional[str]): The hash of the last token in the KV cache block.\n    
\"\"\"\n\n    def __init__(\n        self, prefill_len: int, inc_len: int = 0, last_hash: Optional[str] = None\n    ):\n        self.prefill_len = prefill_len\n        self.inc_len = inc_len\n        self.last_hash = last_hash\n\n\nclass BlockStored(KVCacheEvent):\n    block_hashes: list[int]\n    parent_block_hash: Optional[int]\n    token_ids: list[int]\n    block_size: int\n    lora_id: Optional[int]\n    medium: Optional[str] = None\n\n\nclass BlockRemoved(KVCacheEvent):\n    block_hashes: list[int]\n    medium: Optional[str] = None\n\n\nclass AllBlocksCleared(KVCacheEvent):\n    pass\n\n\nclass KVEventBatch(EventBatch):\n    events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]\n\n\nclass EventPublisher(ABC):\n    \"\"\"\n    Lightweight publisher for EventBatch batches with\n    support for DP attention.\n\n    With DP attention, each rank has its own Scheduler and\n    KV cache instance. To avoid duplicate events and ensure\n    proper event attribution, our implementation works as follows:\n\n    - Each DP rank has its own EventPublisher\n    - Publishers annotate events with the DP rank\n    - This allows consumers to distinguish events from different DP ranks\n    \"\"\"\n\n    def __init__(self, attn_dp_rank: int = 0):\n        self._attn_dp_rank = attn_dp_rank\n\n    @abstractmethod\n    def publish(self, events: EventBatch) -> None:\n        \"\"\"Emit events in order.\n\n        Implementations should guarantee at-least-once delivery and\n        monotonic ordering (e.g., via sequence numbers).\n        \"\"\"\n\n    @abstractmethod\n    def shutdown(self) -> None:\n        \"\"\"Shut down the publisher.\"\"\"\n\n\nclass NullEventPublisher(EventPublisher):\n    \"\"\"No-op implementation (default when disabled).\"\"\"\n\n    def publish(self, events) -> None:\n        return\n\n    def shutdown(self) -> None:\n        return\n\n\nclass ZmqEventPublisher(EventPublisher):\n    \"\"\"Reliable PUB/ROUTER publisher with an in-memory replay buffer.\n\n    
Spawns a separate thread to handle publishing from a queue.\n\n    Parameters\n    ----------\n    endpoint:\n        PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to\n        connect.\n    replay_endpoint:\n        Optional ROUTER address for replay requests. When given, subscribers can\n        request missed batches by sending the starting sequence number as an\n        8-byte big-endian integer.\n    buffer_steps:\n        Number of past batches to keep for replay.\n    hwm:\n        ZeroMQ high-water-mark for PUB socket.\n    max_queue_size:\n        Maximum number of events to buffer in memory.\n    topic:\n        Topic to publish events to.\n    \"\"\"\n\n    SHUTDOWN_TIMEOUT: float = 1.0\n    END_SEQ = (-1).to_bytes(8, \"big\", signed=True)\n\n    def __init__(\n        self,\n        attn_dp_rank: int,\n        endpoint: str = \"tcp://*:5557\",\n        replay_endpoint: Optional[str] = None,\n        buffer_steps: int = 10_000,\n        hwm: int = 100_000,\n        max_queue_size: int = 100_000,\n        topic: str = \"\",\n    ) -> None:\n        # Storage\n        super().__init__(attn_dp_rank)\n        self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)\n        self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)\n\n        # ZMQ sockets\n        self._ctx = zmq.Context.instance()\n        self._pub: Optional[zmq.Socket] = None\n        self._replay: Optional[zmq.Socket] = None\n        self._dp_rank = attn_dp_rank\n        self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)\n        self._replay_endpoint = self.offset_endpoint_port(\n            replay_endpoint, self._dp_rank\n        )\n        self._hwm = hwm\n        self._socket_setup()\n\n        # Payload\n        self._seq_gen = count()\n        self._topic_bytes = topic.encode(\"utf-8\")\n\n        # Thread\n        self._running = True\n        logger.info(\"Starting ZMQ publisher thread\")\n\n        self._thread = 
threading.Thread(\n            target=self._publisher_thread, daemon=True, name=\"zmq-publisher\"\n        )\n        self._thread.start()\n\n        atexit.register(self.shutdown)\n\n    def publish(self, events: EventBatch) -> None:\n        if not self._running:\n            raise RuntimeError(\"Publisher is closed\")\n        if events.attn_dp_rank is None:\n            events.attn_dp_rank = self._dp_rank\n        self._event_queue.put(events)\n\n    def shutdown(self) -> None:\n        \"\"\"Stop the publisher thread and clean up resources.\"\"\"\n        self._running = False\n        self._event_queue.put_nowait(None)\n\n        start = time.time()\n        pending_items = True\n        while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):\n            pending_items = not self._event_queue.empty()\n            if pending_items:\n                time.sleep(0.1)\n\n        if pending_items:\n            logger.warning(\n                \"Warning: Queue still has %s items after %s seconds timeout\",\n                self._event_queue.qsize(),\n                self.SHUTDOWN_TIMEOUT,\n            )\n\n        if self._thread.is_alive():\n            self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)\n\n        # Clean up ZMQ resources\n        try:\n            if self._pub is not None:\n                self._pub.close(linger=0)\n            if self._replay is not None:\n                self._replay.close(linger=0)\n        finally:\n            pass  # Do not terminate context; other sockets may use it\n\n    def _socket_setup(self) -> None:\n        \"\"\"Initialize sockets\n        https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety\n        \"\"\"\n        if self._pub is None:\n            self._pub = self._ctx.socket(zmq.PUB)\n            self._pub.set_hwm(self._hwm)\n            # Heuristic: bind if wildcard / * present, else connect.\n            # bind stable, connect volatile convention\n            if (\n          
      \"*\" in self._endpoint\n                or \"::\" in self._endpoint\n                or self._endpoint.startswith(\"ipc://\")\n                or self._endpoint.startswith(\"inproc://\")\n            ):\n                logger.debug(\n                    f\"ZmqEventPublisher socket publisher_endpoint bind to {self._endpoint}\"\n                )\n                self._pub.bind(self._endpoint)\n            else:\n                self._pub.connect(self._endpoint)\n\n        # Set up replay socket: use ROUTER\n        # 1) handles multiple REQ clients (identities)\n        # 2) lets us send back one request → many replies (streamed events)\n        # 3) works in our non‑blocking poll loop alongside PUB\n        if self._replay_endpoint is not None:\n            self._replay = self._ctx.socket(zmq.ROUTER)\n            logger.debug(\n                f\"ZmqEventPublisher socket replay_endpoint bind to {self._replay_endpoint}\"\n            )\n            self._replay.bind(self._replay_endpoint)\n\n    def _publisher_thread(self) -> None:\n        \"\"\"Background thread that processes the event queue.\"\"\"\n        self._pack = msgspec.msgpack.Encoder()\n\n        assert self._pub is not None  # narrows type for mypy\n\n        while self._running or self._event_queue.qsize() > 0:\n            # --- replay (non-critical) ---------------------------------\n            if self._replay is not None and self._replay.poll(0):\n                try:\n                    self._service_replay()\n                except Exception as e:\n                    logger.exception(\"Error in replay: %s\", e)\n\n            # --- main queue (critical) ---------------------------------\n            try:\n                event = self._event_queue.get(timeout=0.1)\n                if event is None:\n                    break  # Sentinel received, exit thread\n            except queue.Empty:\n                continue\n\n            try:\n                seq = next(self._seq_gen)\n\n      
          payload = self._pack.encode(event)\n                seq_bytes = seq.to_bytes(8, \"big\")\n                self._pub.send_multipart((self._topic_bytes, seq_bytes, payload))\n\n                self._buffer.append((seq, payload))\n                self._event_queue.task_done()\n\n            except Exception as e:\n                # Publishing failed; back off a bit to avoid a tight error loop\n                logger.exception(\"Error in publisher thread: %s\", e)\n                time.sleep(0.1)\n\n    def _service_replay(self) -> None:\n        \"\"\"If a replay request is waiting, send buffered batches.\"\"\"\n        assert self._replay is not None  # narrows type for mypy\n\n        frame = self._replay.recv_multipart()\n        if len(frame) != 3:\n            logger.warning(\"Invalid replay request: %s\", frame)\n            return\n        client_id, _, start_seq_bytes = frame\n        start_seq = int.from_bytes(start_seq_bytes, \"big\")\n\n        for seq, buf in self._buffer:\n            if seq >= start_seq:\n                # [identity, empty_delim, seq_bytes, payload]\n                # (identity, empty_delim) are stripped off by the router\n                # receiving payload is (seq_bytes, payload)\n                self._replay.send_multipart(\n                    (client_id, b\"\", seq.to_bytes(8, \"big\"), buf)\n                )\n        # Send end of sequence marker\n        # receiving payload is (-1, b\"\")\n        self._replay.send_multipart((client_id, b\"\", self.END_SEQ, b\"\"))\n\n    @staticmethod\n    def offset_endpoint_port(\n        endpoint: Optional[str], data_parallel_rank: int\n    ) -> Optional[str]:\n        \"\"\"Helper function to offset the port in an endpoint by\n            the data parallel rank.\n\n        Args:\n            endpoint: The endpoint string\n                (e.g., \"tcp://*:5557\" or \"inproc://cache\")\n            data_parallel_rank: The data parallel rank to offset by\n\n        Returns:\n       
     The endpoint with the port offset by data_parallel_rank\n                or suffix appended\n        \"\"\"\n        # Do nothing if input is None or data_parallel_rank is 0\n        if not endpoint or data_parallel_rank == 0:\n            return endpoint\n\n        if \"inproc\" in endpoint:\n            return f\"{endpoint}_dp{data_parallel_rank}\"\n        if \"tcp\" in endpoint:\n            if endpoint and \":\" in endpoint:\n                # Get everything after the last colon (the port)\n                last_colon_idx = endpoint.rfind(\":\")\n                base_addr = endpoint[:last_colon_idx]\n                base_port = int(endpoint[last_colon_idx + 1 :])\n                new_port = base_port + data_parallel_rank\n                return f\"{base_addr}:{new_port}\"\n            return endpoint\n        raise ValueError(\"Invalid endpoint: must contain 'inproc' or 'tcp'\")\n\n\nclass KVEventsConfig(BaseModel):\n    \"\"\"Configuration for KV event publishing.\"\"\"\n\n    publisher: str = \"null\"\n    \"\"\"The publisher to use for publishing kv events. Can be \"null\", \"zmq\".\n    \"\"\"\n\n    endpoint: str = \"tcp://*:5557\"\n    \"\"\"The zmq endpoint to use for publishing kv events.\n    \"\"\"\n\n    replay_endpoint: Optional[str] = None\n    \"\"\"The zmq endpoint to use for replaying kv events.\n    \"\"\"\n\n    buffer_steps: int = 10_000\n    \"\"\"The number of steps to cache for replay endpoint. Will only save\n    events from the last N steps for the replay endpoint.\n    \"\"\"\n\n    hwm: int = 100_000\n    \"\"\"The zmq high water mark for the event publisher. After queueing N events,\n    events will start dropping if the consumer is not keeping up.\n    \"\"\"\n\n    max_queue_size: int = 100_000\n    \"\"\"The maximum number of events to queue while waiting for publishing.\n    \"\"\"\n\n    topic: str = \"\"\n    \"\"\"The topic to use for the event publisher. 
Consumers can subscribe to\n    this topic to receive events.\n    \"\"\"\n\n    @classmethod\n    def from_cli(cls, cli_value: str) -> \"KVEventsConfig\":\n        \"\"\"Parse the CLI value for the event publisher config.\"\"\"\n        return KVEventsConfig.model_validate_json(cli_value)\n\n\nclass EventPublisherFactory:\n    _registry: dict[str, Callable[..., EventPublisher]] = {\n        \"null\": NullEventPublisher,\n        \"zmq\": ZmqEventPublisher,\n    }\n\n    @classmethod\n    def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None:\n        if name in cls._registry:\n            raise KeyError(f\"publisher '{name}' already registered\")\n        cls._registry[name] = ctor\n\n    @classmethod\n    def create(cls, config: Optional[str], attn_dp_rank: int = 0) -> EventPublisher:\n        \"\"\"Create publisher from a config mapping.\"\"\"\n        if not config:\n            return NullEventPublisher()\n        config = KVEventsConfig.from_cli(config)\n        config_dict = config.model_dump()\n\n        kind = config_dict.pop(\"publisher\", \"null\")\n        try:\n            constructor = cls._registry[kind]\n        except KeyError as exc:\n            raise ValueError(f\"Unknown event publisher '{kind}'\") from exc\n        return constructor(attn_dp_rank=attn_dp_rank, **config_dict)\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/mooncake/__init__.py",
    "content": "from sglang.srt.disaggregation.mooncake.conn import (\n    MooncakeKVBootstrapServer,\n    MooncakeKVManager,\n    MooncakeKVReceiver,\n    MooncakeKVSender,\n)\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/mooncake/conn.py",
    "content": "from __future__ import annotations\n\nimport concurrent.futures\nimport ctypes\nimport dataclasses\nimport logging\nimport os\nimport struct\nimport threading\nimport time\nfrom collections import defaultdict\nfrom typing import List, Optional, Tuple\n\nimport numpy as np\nimport numpy.typing as npt\n\nfrom sglang.srt.disaggregation.base.conn import KVArgs, KVPoll\nfrom sglang.srt.disaggregation.common.conn import (\n    CommonKVBootstrapServer,\n    CommonKVManager,\n    CommonKVReceiver,\n    CommonKVSender,\n)\nfrom sglang.srt.disaggregation.common.utils import (\n    FastQueue,\n    group_concurrent_contiguous,\n)\nfrom sglang.srt.disaggregation.mooncake.utils import (\n    check_mooncake_custom_mem_pool_enabled,\n)\nfrom sglang.srt.disaggregation.utils import (\n    DisaggregationMode,\n    filter_kv_indices_for_cp_rank,\n)\nfrom sglang.srt.distributed.parallel_state import get_mooncake_transfer_engine\nfrom sglang.srt.environ import envs\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils.network import NetworkAddress\n\nlogger = logging.getLogger(__name__)\n\n\nclass KVTransferError(Exception):\n    def __init__(self, bootstrap_room: int, failure_reason: str):\n        super().__init__(failure_reason)\n        self.bootstrap_room = bootstrap_room\n        self.failure_reason = failure_reason\n\n    def __str__(self):\n        return f\"KVTransferError(bootstrap_room={self.bootstrap_room}): {self.failure_reason}\"\n\n\n# prefill\n@dataclasses.dataclass\nclass TransferKVChunk:\n    room: int\n    prefill_kv_indices: npt.NDArray[np.int32]\n    index_slice: slice\n    is_last_chunk: bool\n    prefill_aux_index: Optional[int]\n    state_indices: Optional[List[int]]\n\n\n# decode\n@dataclasses.dataclass\nclass TransferInfo:\n    room: int\n    endpoint: str\n    dst_port: int\n    mooncake_session_id: str\n    dst_kv_indices: npt.NDArray[np.int32]\n    dst_aux_index: int\n    dst_state_indices: List[int]\n    
required_dst_info_num: int\n    is_dummy: bool\n\n    @classmethod\n    def from_zmq(cls, msg: List[bytes]):\n        if msg[4] == b\"\" and msg[5] == b\"\":\n            is_dummy = True\n            dst_kv_indices = np.array([], dtype=np.int32)\n            dst_aux_index = None\n            dst_state_indices = []\n        else:\n            dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)\n            dst_aux_index = int(msg[5].decode(\"ascii\"))\n            if msg[6] == b\"\":\n                dst_state_indices = []\n            else:\n                dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32))\n            is_dummy = False\n        return cls(\n            room=int(msg[0].decode(\"ascii\")),\n            endpoint=msg[1].decode(\"ascii\"),\n            dst_port=int(msg[2].decode(\"ascii\")),\n            mooncake_session_id=msg[3].decode(\"ascii\"),\n            dst_kv_indices=dst_kv_indices,\n            dst_aux_index=dst_aux_index,\n            dst_state_indices=dst_state_indices,\n            required_dst_info_num=int(msg[7].decode(\"ascii\")),\n            is_dummy=is_dummy,\n        )\n\n\n# decode\n@dataclasses.dataclass\nclass KVArgsRegisterInfo:\n    room: str\n    endpoint: str\n    dst_port: int\n    mooncake_session_id: str\n    dst_kv_ptrs: list[int]\n    dst_aux_ptrs: list[int]\n    dst_state_data_ptrs: list[int]\n    dst_tp_rank: int\n    dst_attn_tp_size: int\n    dst_kv_item_len: int\n    # for mamba state different tp slice transfer\n    dst_state_item_lens: list[int]\n    dst_state_dim_per_tensor: list[int]\n\n    @classmethod\n    def from_zmq(cls, msg: List[bytes]):\n        return cls(\n            room=str(msg[0].decode(\"ascii\")),\n            endpoint=msg[1].decode(\"ascii\"),\n            dst_port=int(msg[2].decode(\"ascii\")),\n            mooncake_session_id=msg[3].decode(\"ascii\"),\n            dst_kv_ptrs=list(struct.unpack(f\"{len(msg[4])//8}Q\", msg[4])),\n            
dst_aux_ptrs=list(struct.unpack(f\"{len(msg[5])//8}Q\", msg[5])),\n            dst_state_data_ptrs=list(struct.unpack(f\"{len(msg[6])//8}Q\", msg[6])),\n            dst_tp_rank=int(msg[7].decode(\"ascii\")),\n            dst_attn_tp_size=int(msg[8].decode(\"ascii\")),\n            dst_kv_item_len=int(msg[9].decode(\"ascii\")),\n            dst_state_item_lens=(\n                list(struct.unpack(f\"{len(msg[10])//4}I\", msg[10]))\n                if len(msg) > 10 and len(msg[10]) > 0\n                else []\n            ),\n            dst_state_dim_per_tensor=(\n                list(struct.unpack(f\"{len(msg[11])//4}I\", msg[11]))\n                if len(msg) > 11 and len(msg[11]) > 0\n                else []\n            ),\n        )\n\n\nclass AuxDataCodec:\n    \"\"\"Handles serialization and deserialization of auxiliary data buffers\"\"\"\n\n    @staticmethod\n    def serialize_data_from_buffer(src_addr, data_length):\n        \"\"\"Serialize data from memory buffer to bytes\"\"\"\n        buffer = (ctypes.c_byte * data_length).from_address(src_addr)\n        return bytes(buffer)\n\n    @staticmethod\n    def deserialize_data_to_buffer(kv_args, buffer_index, aux_index, data):\n        \"\"\"Deserialize bytes into target memory buffer\"\"\"\n        dst_aux_ptr = kv_args.aux_data_ptrs[buffer_index]\n        item_len = kv_args.aux_item_lens[buffer_index]\n        dst_addr = dst_aux_ptr + item_len * aux_index\n        buffer = (ctypes.c_byte * len(data)).from_address(dst_addr)\n        buffer[:] = data\n        return\n\n\nclass MooncakeKVManager(CommonKVManager):\n    AUX_DATA_HEADER = b\"AUX_DATA\"\n\n    def __init__(\n        self,\n        args: KVArgs,\n        disaggregation_mode: DisaggregationMode,\n        server_args: ServerArgs,\n        is_mla_backend: Optional[bool] = False,\n    ):\n        super().__init__(args, disaggregation_mode, server_args, is_mla_backend)\n        self.init_engine()\n        self.register_buffer_to_engine()\n        if 
self.disaggregation_mode == DisaggregationMode.PREFILL:\n            self.start_prefill_thread()\n            self.session_failures = defaultdict(int)\n            self.failed_sessions = set()\n            self.session_lock = threading.Lock()\n            # Determine the number of threads to use for kv sender\n            cpu_count = os.cpu_count()\n            transfer_thread_pool_size = (\n                envs.SGLANG_DISAGGREGATION_THREAD_POOL_SIZE.get()\n            )\n            if transfer_thread_pool_size is None:\n                transfer_thread_pool_size = min(max(4, int(0.5 * cpu_count) // 8), 12)\n            transfer_queue_size = envs.SGLANG_DISAGGREGATION_QUEUE_SIZE.get()\n            self.transfer_queues: List[FastQueue] = [\n                FastQueue() for _ in range(transfer_queue_size)\n            ]\n            assert transfer_thread_pool_size >= transfer_queue_size, (\n                f\"The environment variable SGLANG_DISAGGREGATION_THREAD_POOL_SIZE={transfer_thread_pool_size} must be \"\n                f\"greater than or equal to SGLANG_DISAGGREGATION_QUEUE_SIZE={transfer_queue_size}.\"\n            )\n            self.executors = [\n                concurrent.futures.ThreadPoolExecutor(\n                    transfer_thread_pool_size // transfer_queue_size\n                )\n                for _ in range(transfer_queue_size)\n            ]\n            for queue, executor in zip(self.transfer_queues, self.executors):\n                threading.Thread(\n                    target=self.transfer_worker, args=(queue, executor), daemon=True\n                ).start()\n            self.enable_custom_mem_pool, self.custom_mem_pool_type = (\n                check_mooncake_custom_mem_pool_enabled()\n            )\n        elif self.disaggregation_mode == DisaggregationMode.DECODE:\n            self.start_decode_thread()\n\n    def init_engine(self):\n        self.engine = get_mooncake_transfer_engine()\n\n    def register_buffer_to_engine(self):\n   
     # Batch register KV data buffers\n        if self.kv_args.kv_data_ptrs and self.kv_args.kv_data_lens:\n            self.engine.batch_register(\n                self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens\n            )\n\n        # Batch register auxiliary data buffers\n        if self.kv_args.aux_data_ptrs and self.kv_args.aux_data_lens:\n            self.engine.batch_register(\n                self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens\n            )\n\n        # Batch register state/extra pool data buffers\n        if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:\n            self.engine.batch_register(\n                self.kv_args.state_data_ptrs, self.kv_args.state_data_lens\n            )\n\n    def _transfer_data(self, mooncake_session_id, transfer_blocks):\n        if not transfer_blocks:\n            return 0\n\n        src_addrs, dst_addrs, lengths = zip(*transfer_blocks)\n        return self.engine.batch_transfer_sync(\n            mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)\n        )\n\n    def _send_kvcache_generic(\n        self,\n        mooncake_session_id: str,\n        src_data_ptrs: list[int],\n        dst_data_ptrs: list[int],\n        item_lens: list[int],\n        prefill_data_indices: npt.NDArray[np.int32],\n        dst_data_indices: npt.NDArray[np.int32],\n        executor: concurrent.futures.ThreadPoolExecutor,\n    ) -> int:\n        \"\"\"\n        Generic KV cache transfer supporting both MHA and MLA architectures.\n        This method is used by both send_kvcache (full pool) and maybe_send_extra.\n        \"\"\"\n        # Group by indices for optimization\n        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(\n            prefill_data_indices, dst_data_indices\n        )\n\n        layers_params = None\n\n        # Decode pp size should be equal to prefill pp size or 1\n        if self.is_mla_backend:\n            src_kv_ptrs, dst_kv_ptrs, 
layers_current_pp_stage = (\n                self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)\n            )\n            layers_params = [\n                (\n                    src_kv_ptrs[layer_id],\n                    dst_kv_ptrs[layer_id],\n                    item_lens[layer_id],\n                )\n                for layer_id in range(layers_current_pp_stage)\n            ]\n        else:\n            src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (\n                self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)\n            )\n            # item_lens structure: [k_layer0, k_layer1, ..., k_layerN, v_layer0, v_layer1, ..., v_layerN]\n            # Use correct item lengths for K and V separately\n            if layers_current_pp_stage > len(dst_k_ptrs):\n                logger.error(\n                    \"Prefill transfer kvcache error, layers_current_pp_stage is out of range: \"\n                    f\"layers_current_pp_stage={layers_current_pp_stage}, len(dst_k_ptrs)={len(dst_k_ptrs)}\"\n                )\n                return -1\n            layers_params = [\n                (\n                    src_k_ptrs[layer_id],\n                    dst_k_ptrs[layer_id],\n                    item_lens[layer_id],  # K item length\n                )\n                for layer_id in range(layers_current_pp_stage)\n            ] + [\n                (\n                    src_v_ptrs[layer_id],\n                    dst_v_ptrs[layer_id],\n                    item_lens[layers_current_pp_stage + layer_id],  # V item length\n                )\n                for layer_id in range(layers_current_pp_stage)\n            ]\n        assert layers_params is not None\n\n        def set_transfer_blocks(\n            src_ptr: int, dst_ptr: int, item_len: int\n        ) -> List[Tuple[int, int, int]]:\n            transfer_blocks = []\n            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):\n          
      src_addr = src_ptr + int(prefill_index[0]) * item_len\n                dst_addr = dst_ptr + int(decode_index[0]) * item_len\n                length = item_len * len(prefill_index)\n                transfer_blocks.append((src_addr, dst_addr, length))\n            return transfer_blocks\n\n        # Worker function for processing a single layer\n        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:\n            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)\n            return self._transfer_data(mooncake_session_id, transfer_blocks)\n\n        # Worker function for processing all layers in a batch\n        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:\n            transfer_blocks = []\n            for src_ptr, dst_ptr, item_len in layers_params:\n                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))\n            return self._transfer_data(mooncake_session_id, transfer_blocks)\n\n        if self.enable_custom_mem_pool:\n            futures = [\n                executor.submit(\n                    process_layer,\n                    src_ptr,\n                    dst_ptr,\n                    item_len,\n                )\n                for (src_ptr, dst_ptr, item_len) in layers_params\n            ]\n            for future in concurrent.futures.as_completed(futures):\n                status = future.result()\n                if status != 0:\n                    for f in futures:\n                        f.cancel()\n                    return status\n            return 0\n        else:\n            # Combining all layers' params in one batch transfer is more efficient\n            # compared to using multiple threads\n            return process_layers(layers_params)\n\n    def send_kvcache(\n        self,\n        mooncake_session_id: str,\n        prefill_kv_indices: npt.NDArray[np.int32],\n        dst_kv_ptrs: list[int],\n        dst_kv_indices: 
npt.NDArray[np.int32],\n        executor: concurrent.futures.ThreadPoolExecutor,\n    ):\n        return self._send_kvcache_generic(\n            mooncake_session_id=mooncake_session_id,\n            src_data_ptrs=self.kv_args.kv_data_ptrs,\n            dst_data_ptrs=dst_kv_ptrs,\n            item_lens=self.kv_args.kv_item_lens,\n            prefill_data_indices=prefill_kv_indices,\n            dst_data_indices=dst_kv_indices,\n            executor=executor,\n        )\n\n    def send_kvcache_slice(\n        self,\n        mooncake_session_id: str,\n        prefill_kv_indices: npt.NDArray[np.int32],\n        dst_kv_ptrs: list[int],\n        dst_kv_indices: npt.NDArray[np.int32],\n        dst_tp_rank: int,\n        dst_attn_tp_size: int,\n        dst_kv_item_len: int,\n        executor: concurrent.futures.ThreadPoolExecutor,\n    ):\n        \"\"\"\n        Sends KV cache slices from this Prefill rank to a target Decode rank,\n        supporting generic M-to-N TP size configurations.\n\n        NOTE: This implementation calls the transfer engine for each token slot within\n        each page to ensure correctness for any page_size and head-slicing configuration.\n        This may introduce performance overhead (increased TTFT) for long sequences.\n        \"\"\"\n        # Extract configuration\n        local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size\n        src_kv_item_len = self.kv_args.kv_item_lens[0]\n        dst_tp_rank_in_group = dst_tp_rank % dst_attn_tp_size\n        page_size = self.kv_args.page_size\n\n        # Use total KV head count (not per-rank) for correct head distribution.\n        # Per-rank kv_head_num is max(1, total//tp) which loses info when total < tp.\n        total_kv_heads = getattr(self.kv_args, \"total_kv_head_num\", 0)\n        if total_kv_heads <= 0:\n            total_kv_heads = self.kv_args.kv_head_num * self.attn_tp_size\n\n        src_heads_per_rank = max(1, total_kv_heads // self.attn_tp_size)\n        
dst_heads_per_rank = max(1, total_kv_heads // dst_attn_tp_size)\n        bytes_per_head_slice_to_send = (\n            dst_kv_item_len // page_size // dst_heads_per_rank\n        )\n\n        # GQA replication: how many prefill ranks share the same KV head\n        src_replication = max(1, self.attn_tp_size // total_kv_heads)\n\n        # Determine slicing parameters based on TP configuration\n        if self.attn_tp_size > dst_attn_tp_size:\n            # Send KVCache from multiple prefill instances to 1 decode instance\n            src_head_start_offset = 0\n            num_heads_to_send = src_heads_per_rank\n            unique_head_idx = local_tp_rank_in_group // src_replication\n            dst_head_start_offset = (\n                unique_head_idx * src_heads_per_rank\n            ) % dst_heads_per_rank\n        else:\n            # Send KVCache from 1 prefill instance to multiple decode instances\n            src_head_start_offset = (\n                dst_tp_rank_in_group * dst_heads_per_rank\n            ) % src_heads_per_rank\n            num_heads_to_send = dst_heads_per_rank\n            dst_head_start_offset = 0\n\n        src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (\n            self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)\n        )\n\n        # Calculate precise byte offset and length for the sub-slice within the token\n        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send\n        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send\n        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send\n\n        # Sanity check: The data sub-slice to be sent should fit into the dst buffer.\n        # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)\n        if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):\n            logger.error(\n                f\"[{mooncake_session_id}] slice 
size ({heads_bytes_per_token_to_send}) exceeds \"\n                f\"target token slot size ({dst_kv_item_len // page_size})\"\n            )\n            return -1\n\n        prefill_page_indices = prefill_kv_indices.reshape(-1, 1).astype(np.int64)\n        decode_page_indices = dst_kv_indices.reshape(-1, 1).astype(np.int64)\n        tokens_per_page = np.arange(page_size, dtype=np.int64).reshape(1, -1)\n        bytes_per_token_on_prefill = src_kv_item_len // page_size\n        bytes_per_token_on_decode = dst_kv_item_len // page_size\n        src_token_slot_offsets = (\n            tokens_per_page * bytes_per_token_on_prefill + src_head_slice_offset\n        )\n        dst_token_slot_offsets = (\n            tokens_per_page * bytes_per_token_on_decode + dst_head_slice_offset\n        )\n\n        def process_layer_tp_aware(src_layer_ptr, dst_layer_ptr):\n            src_page_base_addrs = src_layer_ptr + prefill_page_indices * src_kv_item_len\n            dst_page_base_addrs = dst_layer_ptr + decode_page_indices * dst_kv_item_len\n            src_slice_addrs = src_page_base_addrs + src_token_slot_offsets\n            dst_slice_addrs = dst_page_base_addrs + dst_token_slot_offsets\n\n            src_addr_list = src_slice_addrs.reshape(-1).tolist()\n            if not src_addr_list:\n                # Nothing to transfer for this layer.\n                return 0\n            dst_addr_list = dst_slice_addrs.reshape(-1).tolist()\n            total_slices = len(src_addr_list)\n            length_list = [heads_bytes_per_token_to_send] * total_slices\n            return self.engine.batch_transfer_sync(\n                mooncake_session_id, src_addr_list, dst_addr_list, length_list\n            )\n\n        futures = []\n        for i in range(layers_current_pp_stage):\n            futures.append(\n                executor.submit(process_layer_tp_aware, src_k_ptrs[i], dst_k_ptrs[i])\n            )\n        for i in range(layers_current_pp_stage):\n            
futures.append(\n                executor.submit(process_layer_tp_aware, src_v_ptrs[i], dst_v_ptrs[i])\n            )\n\n        for future in concurrent.futures.as_completed(futures):\n            status = future.result()\n            if status != 0:\n                for f in futures:\n                    f.cancel()\n                return status\n\n        return 0\n\n    def send_aux(\n        self,\n        req: TransferInfo,\n        prefill_aux_index: int,\n        dst_aux_ptrs: list[int],\n    ):\n        # TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free\n        if (\n            self.enable_custom_mem_pool and self.custom_mem_pool_type == \"NVLINK\"\n        ) or envs.SGLANG_MOONCAKE_SEND_AUX_TCP.get():\n            return self.send_aux_tcp(req, prefill_aux_index, dst_aux_ptrs)\n\n        transfer_blocks = []\n        prefill_aux_ptrs = self.kv_args.aux_data_ptrs\n        prefill_aux_item_lens = self.kv_args.aux_item_lens\n\n        for i, dst_aux_ptr in enumerate(dst_aux_ptrs):\n            length = prefill_aux_item_lens[i]\n            src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index\n            dst_addr = dst_aux_ptrs[i] + length * req.dst_aux_index\n            transfer_blocks.append((src_addr, dst_addr, length))\n\n        return self._transfer_data(req.mooncake_session_id, transfer_blocks)\n\n    def send_aux_tcp(\n        self,\n        req: TransferInfo,\n        prefill_aux_index: int,\n        dst_aux_ptrs: list[int],\n    ):\n        prefill_aux_ptrs = self.kv_args.aux_data_ptrs\n        prefill_aux_item_lens = self.kv_args.aux_item_lens\n\n        for i in range(len(prefill_aux_ptrs)):\n            length = prefill_aux_item_lens[i]\n            src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index\n            data = AuxDataCodec.serialize_data_from_buffer(src_addr, length)\n\n            self.send_aux_data_to_endpoint(\n                remote=req.endpoint,\n                dst_port=req.dst_port,\n        
        room=req.room,\n                buffer_index=i,\n                aux_index=req.dst_aux_index,\n                data=data,\n            )\n\n        return 0\n\n    def send_aux_data_to_endpoint(\n        self,\n        remote: str,\n        dst_port: int,\n        room: int,\n        buffer_index: int,\n        aux_index: int,\n        data: bytes,\n    ):\n        na = NetworkAddress(remote, dst_port)\n        socket = self._connect(na.to_tcp(), is_ipv6=na.is_ipv6)\n\n        socket.send_multipart(\n            [\n                MooncakeKVManager.AUX_DATA_HEADER,\n                str(room).encode(\"ascii\"),\n                str(buffer_index).encode(\"ascii\"),\n                str(aux_index).encode(\"ascii\"),\n                struct.pack(\">I\", len(data)),\n                data,\n            ]\n        )\n\n    def _handle_aux_data(self, msg: List[bytes]):\n        \"\"\"Handle AUX_DATA messages received by the decode thread.\"\"\"\n        room = int(msg[1].decode(\"ascii\"))\n        buffer_index = int(msg[2].decode(\"ascii\"))\n        aux_index = int(msg[3].decode(\"ascii\"))\n        data_length = struct.unpack(\">I\", msg[4])[0]\n        data = msg[5]\n\n        if len(data) != data_length:\n            logger.error(f\"AUX_DATA length mismatch for bootstrap_room {room}\")\n            return\n\n        AuxDataCodec.deserialize_data_to_buffer(\n            self.kv_args, buffer_index, aux_index, data\n        )\n\n        logger.debug(\n            f\"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}\"\n        )\n\n    def maybe_send_extra(\n        self,\n        req: TransferInfo,\n        prefill_state_indices: list[int],\n        dst_state_data_ptrs: list[int],\n        executor: concurrent.futures.ThreadPoolExecutor,\n        target_rank_registration_info: Optional[KVArgsRegisterInfo] = None,\n    ):\n        \"\"\"Send state or extra pool data with type-specific handling.\"\"\"\n        state_type = getattr(self.kv_args, 
\"state_type\", \"none\")\n\n        if state_type == \"mamba\":\n            # Check if we need slice transfer for different TP sizes\n            if (\n                target_rank_registration_info is not None\n                and self.attn_tp_size != target_rank_registration_info.dst_attn_tp_size\n            ):\n                return self._send_mamba_state_slice(\n                    req,\n                    prefill_state_indices,\n                    dst_state_data_ptrs,\n                    target_rank_registration_info.dst_state_item_lens,\n                    target_rank_registration_info.dst_state_dim_per_tensor,\n                    target_rank_registration_info.dst_tp_rank,\n                    target_rank_registration_info.dst_attn_tp_size,\n                )\n            else:\n                return self._send_mamba_state(\n                    req,\n                    prefill_state_indices,\n                    dst_state_data_ptrs,\n                )\n        elif state_type in [\"swa\", \"nsa\"]:\n            # SWA and NSA hybrid models do not support different TP sizes yet\n            if (\n                target_rank_registration_info is not None\n                and not self.is_mla_backend\n                and self.attn_tp_size != target_rank_registration_info.dst_attn_tp_size\n            ):\n                raise RuntimeError(\n                    f\"PD Disaggregation does NOT support PD different TP sizes for non-MLA {state_type.upper()} hybrid models yet.\"\n                )\n            if len(prefill_state_indices) < len(req.dst_state_indices):\n                logger.warning(\n                    f\"len(prefill_state_indices) = {len(prefill_state_indices)}, len(dst_state_indices) = {len(req.dst_state_indices)}\"\n                )\n                prefill_state_indices = prefill_state_indices[\n                    : len(req.dst_state_indices)\n                ]\n            # Reuse _send_kvcache_generic interface to send extra pool 
data\n            prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32)\n            dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32)\n            return self._send_kvcache_generic(\n                mooncake_session_id=req.mooncake_session_id,\n                src_data_ptrs=self.kv_args.state_data_ptrs,\n                dst_data_ptrs=dst_state_data_ptrs,\n                item_lens=self.kv_args.state_item_lens,\n                prefill_data_indices=prefill_state_indices,\n                dst_data_indices=dst_state_indices,\n                executor=executor,\n            )\n        else:\n            return 0\n\n    def _send_mamba_state(\n        self,\n        req: TransferInfo,\n        prefill_mamba_index: list[int],\n        dst_state_data_ptrs: list[int],\n    ):\n        \"\"\"Transfer Mamba states.\"\"\"\n        assert len(prefill_mamba_index) == 1, \"Mamba should have single state index\"\n\n        transfer_blocks = []\n        prefill_state_data_ptrs = self.kv_args.state_data_ptrs\n        prefill_state_item_lens = self.kv_args.state_item_lens\n\n        for i, dst_state_ptr in enumerate(dst_state_data_ptrs):\n            length = prefill_state_item_lens[i]\n            src_addr = prefill_state_data_ptrs[i] + length * int(prefill_mamba_index[0])\n            dst_addr = dst_state_ptr + length * int(req.dst_state_indices[0])\n            transfer_blocks.append((src_addr, dst_addr, length))\n\n        return self._transfer_data(req.mooncake_session_id, transfer_blocks)\n\n    def _send_mamba_state_slice(\n        self,\n        req: TransferInfo,\n        prefill_mamba_index: list[int],\n        dst_state_data_ptrs: list[int],\n        dst_state_item_lens: list[int],\n        dst_state_dim_per_tensor: list[int],\n        dst_tp_rank: int,\n        dst_attn_tp_size: int,\n    ):\n        \"\"\"Transfer Mamba states with TP slice support.\n\n        Mamba state layout:\n        - conv_state: [num_layers, size+1, 
conv_dim/tp, conv_kernel-1]\n        - temporal_state: [num_layers, size+1, num_heads/tp, head_dim, state_size]\n\n        The 3rd dimension is sliced by TP. When prefill and decode have different\n        attn_tp_size, we need to slice the state accordingly.\n        \"\"\"\n        logger.warning_once(\n            \"Using Mamba state slice transfer for different TP sizes between prefill and decode. \"\n            f\"Prefill attn_tp_size={self.attn_tp_size}, Decode attn_tp_size={dst_attn_tp_size}. \"\n            \"Performance may be affected.\"\n        )\n        assert len(prefill_mamba_index) == 1, \"Mamba should have single state index\"\n\n        transfer_blocks = []\n        prefill_state_data_ptrs = self.kv_args.state_data_ptrs\n        prefill_state_item_lens = self.kv_args.state_item_lens\n        src_state_dim_per_tensor = getattr(self.kv_args, \"state_dim_per_tensor\", [])\n\n        # If no dimension info available, fall back to regular transfer\n        if not src_state_dim_per_tensor or not dst_state_dim_per_tensor:\n            return self._send_mamba_state(req, prefill_mamba_index, dst_state_data_ptrs)\n\n        local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size\n        dst_tp_rank_in_group = dst_tp_rank % dst_attn_tp_size\n\n        for i, dst_state_ptr in enumerate(dst_state_data_ptrs):\n            src_item_len = prefill_state_item_lens[i]\n            dst_item_len = dst_state_item_lens[i]\n            src_dim = src_state_dim_per_tensor[i]\n            dst_dim = dst_state_dim_per_tensor[i]\n\n            # Calculate bytes per dimension slice\n            # item_len = dim * trailing_dims_size, so trailing_dims_size = item_len / dim\n            src_bytes_per_dim = src_item_len // src_dim\n            dst_bytes_per_dim = dst_item_len // dst_dim\n\n            # Determine slicing parameters based on TP configuration\n            if self.attn_tp_size > dst_attn_tp_size:\n                # Multiple prefill ranks send to 1 
decode rank\n                # Each prefill sends all its dims to the appropriate offset in decode\n                src_dim_start = 0\n                num_dims_to_send = src_dim\n                writers_per_decode = self.attn_tp_size // dst_attn_tp_size\n                local_writer_idx = local_tp_rank_in_group % writers_per_decode\n                dst_dim_start = local_writer_idx * src_dim\n            else:\n                # 1 prefill rank sends to multiple decode ranks\n                # Prefill sends a slice of its dims to each decode rank\n                src_dim_start = (dst_tp_rank_in_group * dst_dim) % src_dim\n                num_dims_to_send = dst_dim\n                dst_dim_start = 0\n\n            # Calculate byte offsets\n            src_dim_offset = src_dim_start * src_bytes_per_dim\n            dst_dim_offset = dst_dim_start * dst_bytes_per_dim\n            bytes_to_send = num_dims_to_send * src_bytes_per_dim\n\n            # Calculate addresses for this state tensor\n            src_addr = (\n                prefill_state_data_ptrs[i]\n                + src_item_len * int(prefill_mamba_index[0])\n                + src_dim_offset\n            )\n            dst_addr = (\n                dst_state_ptr\n                + dst_item_len * int(req.dst_state_indices[0])\n                + dst_dim_offset\n            )\n\n            transfer_blocks.append((src_addr, dst_addr, bytes_to_send))\n\n        return self._transfer_data(req.mooncake_session_id, transfer_blocks)\n\n    def sync_status_to_decode_endpoint(\n        self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int\n    ):\n        na = NetworkAddress(remote, dst_port)\n        self._connect(na.to_tcp(), is_ipv6=na.is_ipv6).send_multipart(\n            [\n                str(room).encode(\"ascii\"),\n                str(status).encode(\"ascii\"),\n                str(prefill_rank).encode(\"ascii\"),\n            ]\n        )\n\n    def transfer_worker(\n        self, queue: 
FastQueue, executor: concurrent.futures.ThreadPoolExecutor\n    ):\n        while True:\n            try:\n                kv_chunk: TransferKVChunk = queue.get()\n                reqs_to_be_processed = (\n                    self.transfer_infos[kv_chunk.room].values()\n                    if kv_chunk.room in self.transfer_infos\n                    else []\n                )\n                polls = []\n                dst_ranks_infos = []\n                # Unique id per prefill sender so decode's response set size matches expected_response_num.\n                prefill_unique_rank = (\n                    self.attn_tp_rank * (self.pp_size * self.attn_cp_size)\n                    + self.pp_rank * self.attn_cp_size\n                    + self.attn_cp_rank\n                )\n                for req in reqs_to_be_processed:\n                    if not req.is_dummy:\n                        # Early exit if the request has failed\n                        with self.session_lock:\n                            if req.mooncake_session_id in self.failed_sessions:\n                                self.record_failure(\n                                    kv_chunk.room,\n                                    f\"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive\",\n                                )\n                                self.update_status(kv_chunk.room, KVPoll.Failed)\n                                self.sync_status_to_decode_endpoint(\n                                    req.endpoint,\n                                    req.dst_port,\n                                    req.room,\n                                    KVPoll.Failed,\n                                    prefill_unique_rank,\n                                )\n                                break\n\n                        chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]\n\n                        # NOTE: This is temporarily a workaround to 
deal with the case where the prefill_kv_indices\n                        # is mismatched with the dst_kv_indices when page size > 1, this should never happen.\n                        if len(chunked_dst_kv_indice) < len(\n                            kv_chunk.prefill_kv_indices\n                        ):\n                            logger.warning(\n                                f\"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}\"\n                            )\n                            kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[\n                                : len(chunked_dst_kv_indice)\n                            ]\n\n                        target_rank_registration_info: KVArgsRegisterInfo = (\n                            self.decode_kv_args_table[req.mooncake_session_id]\n                        )\n                        if self.is_mla_backend or (\n                            self.attn_tp_size\n                            == target_rank_registration_info.dst_attn_tp_size\n                        ):\n                            ret = self.send_kvcache(\n                                req.mooncake_session_id,\n                                kv_chunk.prefill_kv_indices,\n                                target_rank_registration_info.dst_kv_ptrs,\n                                chunked_dst_kv_indice,\n                                executor,\n                            )\n                        else:\n                            ret = self.send_kvcache_slice(\n                                req.mooncake_session_id,\n                                kv_chunk.prefill_kv_indices,\n                                target_rank_registration_info.dst_kv_ptrs,\n                                chunked_dst_kv_indice,\n                                target_rank_registration_info.dst_tp_rank,\n                                
target_rank_registration_info.dst_attn_tp_size,\n                                target_rank_registration_info.dst_kv_item_len,\n                                executor,\n                            )\n                        if ret != 0:\n                            with self.session_lock:\n                                self.session_failures[req.mooncake_session_id] += 1\n                                # Failures should never happen while the session is alive, so mark the session as failed after its first failure\n                                if self.session_failures[req.mooncake_session_id] >= 1:\n                                    self.failed_sessions.add(req.mooncake_session_id)\n                                    logger.error(\n                                        f\"Session {req.mooncake_session_id} failed.\"\n                                    )\n                            self.record_failure(\n                                kv_chunk.room,\n                                f\"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}\",\n                            )\n                            self.update_status(kv_chunk.room, KVPoll.Failed)\n                            self.sync_status_to_decode_endpoint(\n                                req.endpoint,\n                                req.dst_port,\n                                req.room,\n                                KVPoll.Failed,\n                                prefill_unique_rank,\n                            )\n                            break\n\n                        if kv_chunk.is_last_chunk:\n                            if kv_chunk.state_indices is not None:\n                                self.maybe_send_extra(\n                                    req,\n                                    kv_chunk.state_indices,\n                                    target_rank_registration_info.dst_state_data_ptrs,\n                                    executor,\n
                                    target_rank_registration_info,\n                                )\n\n                            # Only the last chunk needs to send the aux data\n                            ret = self.send_aux(\n                                req,\n                                kv_chunk.prefill_aux_index,\n                                target_rank_registration_info.dst_aux_ptrs,\n                            )\n                            polls.append(True if ret == 0 else False)\n                            dst_ranks_infos.append(\n                                (req.endpoint, req.dst_port, req.room)\n                            )\n\n                            # Only sync status when all the dst ranks have received the kvcache\n                            if len(polls) == req.required_dst_info_num:\n                                status = KVPoll.Success if all(polls) else KVPoll.Failed\n                                self.update_status(req.room, status)\n                                for endpoint, dst_port, room in dst_ranks_infos:\n                                    self.sync_status_to_decode_endpoint(\n                                        endpoint,\n                                        dst_port,\n                                        room,\n                                        status,\n                                        prefill_unique_rank,\n                                    )\n                    else:\n                        # Dummy request means the decode instance is not used, so its status can be marked as success directly\n                        # Dummy request does not need to sync status to decode endpoint\n                        if kv_chunk.is_last_chunk and req.room in self.request_status:\n                            self.update_status(req.room, KVPoll.Success)\n\n                if (\n                    kv_chunk.room not in self.request_status\n                    or self.check_status(kv_chunk.room) == 
KVPoll.Success\n                ):\n                    if kv_chunk.room in self.transfer_infos:\n                        self.transfer_infos.pop(kv_chunk.room)\n\n            except Exception as e:\n                # NOTE(shangming): Remove this when we make sure the transfer thread is bug-free\n                raise RuntimeError(\n                    f\"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead.\"\n                )\n\n    def start_prefill_thread(self):\n        def bootstrap_thread():\n            \"\"\"This thread recvs pre-alloc notification from the decode engine\"\"\"\n            # KVPoll.Bootstrapping -> KVPoll.WaitingForInput\n            while True:\n                waiting_req_bytes = self.server_socket.recv_multipart()\n                room = waiting_req_bytes[0].decode(\"ascii\")\n                mooncake_session_id = waiting_req_bytes[3].decode(\"ascii\")\n                if room == \"None\":\n                    self.decode_kv_args_table[mooncake_session_id] = (\n                        KVArgsRegisterInfo.from_zmq(waiting_req_bytes)\n                    )\n                    with self.session_lock:\n                        if mooncake_session_id in self.failed_sessions:\n                            self.failed_sessions.remove(mooncake_session_id)\n                        if mooncake_session_id in self.session_failures:\n                            del self.session_failures[mooncake_session_id]\n                    logger.debug(\n                        f\"Register KVArgs from {mooncake_session_id} successfully\"\n                    )\n                    continue\n                else:\n                    required_dst_info_num = int(waiting_req_bytes[7].decode(\"ascii\"))\n                    room = int(room)\n                    if room not in self.transfer_infos:\n                        self.transfer_infos[room] = {}\n\n                    
self.transfer_infos[room][mooncake_session_id] = (\n                        TransferInfo.from_zmq(waiting_req_bytes)\n                    )\n                    # NOTE: after bootstrapping we can mark the req as waiting for input\n                    if len(self.transfer_infos[room]) == required_dst_info_num:\n                        self.update_status(room, KVPoll.WaitingForInput)\n\n        threading.Thread(target=bootstrap_thread).start()\n\n    def start_decode_thread(self):\n        def decode_thread():\n            while True:\n                msg = self.server_socket.recv_multipart()\n                if msg[0] == MooncakeKVManager.AUX_DATA_HEADER:\n                    self._handle_aux_data(msg)\n                    continue\n\n                bootstrap_room, status, prefill_rank = msg\n                status = int(status.decode(\"ascii\"))\n                bootstrap_room = int(bootstrap_room.decode(\"ascii\"))\n                prefill_rank = int(prefill_rank.decode(\"ascii\"))\n\n                if status == KVPoll.Success:\n                    if bootstrap_room in self.request_status:\n                        self.prefill_response_tracker[bootstrap_room].add(prefill_rank)\n                        expected_response_num = (\n                            self.required_prefill_response_num_table[bootstrap_room]\n                        )\n                        arrived_response_num = len(\n                            self.prefill_response_tracker[bootstrap_room]\n                        )\n                        if arrived_response_num == expected_response_num:\n                            self.update_status(bootstrap_room, KVPoll.Success)\n                elif status == KVPoll.Failed:\n                    self.record_failure(\n                        bootstrap_room,\n                        \"Failed to get kvcache from prefill instance, it might be dead\",\n                    )\n                    self.update_status(bootstrap_room, status)\n\n        def 
heartbeat_checker():\n            while True:\n                time.sleep(self.heartbeat_interval)\n                with self.connection_lock:\n                    addresses = list(self.prefill_info_table.keys())\n\n                for bootstrap_addr in addresses:\n                    session = None\n                    try:\n                        with self.session_pool_lock:\n                            session = self.session_pool[bootstrap_addr]\n                        response = session.get(\n                            f\"http://{bootstrap_addr}/health\",\n                            timeout=(2, 3),\n                            headers={\"Connection\": \"keep-alive\"},\n                        )\n                        if response.status_code == 200:\n                            self.heartbeat_failures[bootstrap_addr] = 0\n\n                            current_rooms = self.addr_to_rooms_tracker[\n                                bootstrap_addr\n                            ].copy()\n\n                            for bootstrap_room in current_rooms:\n                                # Remove KVPoll.Success requests from the tracker\n                                if bootstrap_room not in self.request_status:\n                                    self.addr_to_rooms_tracker[bootstrap_addr].discard(\n                                        bootstrap_room\n                                    )\n                        else:\n                            logger.info(\n                                f\"Attempting to reconnect to {bootstrap_addr}...\"\n                            )\n                            self.heartbeat_failures[bootstrap_addr] = (\n                                self.heartbeat_failures.get(bootstrap_addr, 0) + 1\n                            )\n                            with self.session_pool_lock:\n                                if bootstrap_addr in self.session_pool:\n                                    del 
self.session_pool[bootstrap_addr]\n                    except Exception:\n                        logger.info(f\"Attempting to reconnect to {bootstrap_addr}...\")\n                        self.heartbeat_failures[bootstrap_addr] = (\n                            self.heartbeat_failures.get(bootstrap_addr, 0) + 1\n                        )\n\n                    if (\n                        self.heartbeat_failures.get(bootstrap_addr, 0)\n                        >= self.max_failures\n                    ):\n                        self._handle_node_failure(bootstrap_addr)\n                        with self.session_pool_lock:\n                            if bootstrap_addr in self.session_pool:\n                                del self.session_pool[bootstrap_addr]\n\n        threading.Thread(target=decode_thread).start()\n        threading.Thread(target=heartbeat_checker).start()\n\n    def add_transfer_request(\n        self,\n        bootstrap_room: int,\n        kv_indices: npt.NDArray[np.int32],\n        index_slice: slice,\n        is_last_chunk: bool,\n        aux_index: Optional[int] = None,\n        state_indices: Optional[List[int]] = None,\n    ):\n        assert self.disaggregation_mode == DisaggregationMode.PREFILL\n        assert not is_last_chunk or (is_last_chunk and aux_index is not None)\n\n        if (\n            bootstrap_room not in self.request_status\n            or self.check_status(bootstrap_room) == KVPoll.Failed\n        ):\n            logger.debug(\n                \"Request with bootstrap_room=%s already failed\", bootstrap_room\n            )\n            return\n\n        if bootstrap_room not in self.transfer_infos:\n            # This means that the current rank is a dummy rank for this request,\n            # and it has already been marked as success, so there is no need to\n            # add further chunks into the transfer queue.\n            return\n\n        # NOTE(shangming): sharding according to the dst_infos to make sure\n     
   # requests with the same dst_sessions will be added into the same\n        # queue, which enables early abort with failed sessions.\n        dst_infos = self.transfer_infos[bootstrap_room].keys()\n        session_port_sum = sum(int(session.rsplit(\":\", 1)[1]) for session in dst_infos)\n        shard_idx = session_port_sum % len(self.transfer_queues)\n\n        self.transfer_queues[shard_idx].put(\n            TransferKVChunk(\n                room=bootstrap_room,\n                prefill_kv_indices=kv_indices,\n                index_slice=index_slice,\n                is_last_chunk=is_last_chunk,\n                prefill_aux_index=aux_index,\n                state_indices=state_indices,\n            )\n        )\n\n    def get_session_id(self):\n        return self.engine.get_session_id()\n\n    def _handle_node_failure(self, failed_bootstrap_addr):\n        with self.connection_lock:\n            keys_to_remove = [\n                k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)\n            ]\n            for k in keys_to_remove:\n                del self.connection_pool[k]\n\n            possible_affected_rooms = self.addr_to_rooms_tracker.get(\n                failed_bootstrap_addr, []\n            )\n            self.prefill_info_table.pop(failed_bootstrap_addr, None)\n            self.addr_to_rooms_tracker.pop(failed_bootstrap_addr, None)\n\n        # Report the requests associated with the failed bootstrap addr and mark their status as KVPoll.Failed\n        affected_rooms = []\n        for room in possible_affected_rooms:\n            if (\n                room in self.request_status\n                and self.check_status(room) != KVPoll.Success\n            ):\n                self.record_failure(\n                    room,\n                    f\"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr})\",\n                )\n                self.update_status(room, KVPoll.Failed)\n                
affected_rooms.append(room)\n        logger.error(\n            f\"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected\"\n        )\n\n\nclass MooncakeKVSender(CommonKVSender):\n\n    def __init__(\n        self,\n        mgr: MooncakeKVManager,\n        bootstrap_addr: str,\n        bootstrap_room: int,\n        dest_tp_ranks: List[int],\n        pp_rank: int,\n    ):\n        super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)\n        self.conclude_state = None\n        self.init_time = time.time()\n\n    def send(\n        self,\n        kv_indices: npt.NDArray[np.int32],\n        state_indices: Optional[List[int]] = None,\n    ):\n        index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))\n        self.curr_idx += len(kv_indices)\n        is_last_chunk = self.curr_idx == self.num_kv_indices\n\n        # Special handling for cp\n        if self.kv_mgr.enable_all_cp_ranks_for_transfer:\n            kv_indices, index_slice = filter_kv_indices_for_cp_rank(\n                self.kv_mgr,\n                kv_indices,\n                index_slice,\n            )\n        elif self.kv_mgr.is_dummy_cp_rank:\n            if not is_last_chunk:\n                return\n            else:\n                self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Success)\n                return\n\n        if not is_last_chunk:\n            self.kv_mgr.add_transfer_request(\n                self.bootstrap_room,\n                kv_indices,\n                index_slice,\n                False,\n            )\n        else:\n            self.kv_mgr.add_transfer_request(\n                self.bootstrap_room,\n                kv_indices,\n                index_slice,\n                True,\n                aux_index=self.aux_index,\n                state_indices=state_indices,\n            )\n\n    def poll(self) -> KVPoll:\n        if self.conclude_state is None:\n       
     status = self.kv_mgr.check_status(self.bootstrap_room)\n            if status in (KVPoll.Success, KVPoll.Failed):\n                self.conclude_state = status\n            elif status == KVPoll.Bootstrapping:\n                if self.init_time is not None:\n                    now = time.time()\n                    elapsed = now - self.init_time\n                    if elapsed >= self.kv_mgr.bootstrap_timeout:\n                        logger.warning_once(\n                            \"Some requests timed out when bootstrapping, \"\n                            \"which means prefill instances fail to receive the KV indices from the decode instance of this request. \"\n                            \"If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600' (10 minutes) to relax the timeout condition. \"\n                        )\n                        self.kv_mgr.record_failure(\n                            self.bootstrap_room,\n                            f\"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.Bootstrapping\",\n                        )\n                        self.conclude_state = KVPoll.Failed\n                        return KVPoll.Failed\n\n            return status\n        else:\n            return self.conclude_state\n\n    def clear(self) -> None:\n        if self.bootstrap_room in self.kv_mgr.request_status:\n            self.kv_mgr.request_status.pop(self.bootstrap_room)\n\n    def failure_exception(self):\n        # Explicitly set the status to failure since this request has failed in another rank\n        if self.conclude_state is None:\n            self.conclude_state = KVPoll.Failed\n\n        self.clear()\n\n        with self.kv_mgr.failure_lock:\n            failure_reason = self.kv_mgr.failure_records.pop(\n                self.bootstrap_room, \"Failed due to an unknown reason from another rank\"\n            )\n        raise KVTransferError(self.bootstrap_room, 
failure_reason)\n\n    def abort(self):\n        self.kv_mgr.record_failure(\n            self.bootstrap_room,\n            \"Aborted by AbortReq.\",\n        )\n        # Explicitly set the status to failure since this request has been aborted\n        self.conclude_state = KVPoll.Failed\n\n\nclass MooncakeKVReceiver(CommonKVReceiver):\n    def __init__(\n        self,\n        mgr: MooncakeKVManager,\n        bootstrap_addr: str,\n        bootstrap_room: Optional[int] = None,\n        prefill_dp_rank: Optional[int] = None,\n    ):\n        self.session_id = mgr.get_session_id()\n        self.conclude_state = None\n        self.init_time = None\n        super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)\n\n        self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)\n        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)\n\n    def _register_kv_args(self):\n        for bootstrap_info in self.bootstrap_infos:\n            packed_kv_data_ptrs = b\"\".join(\n                struct.pack(\"Q\", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs\n            )\n            packed_aux_data_ptrs = b\"\".join(\n                struct.pack(\"Q\", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs\n            )\n            packed_state_data_ptrs = b\"\".join(\n                struct.pack(\"Q\", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs\n            )\n            # Pack state_item_lens and state_dim_per_tensor for mamba state slice transfer\n            packed_state_item_lens = b\"\".join(\n                struct.pack(\"I\", item_len)\n                for item_len in self.kv_mgr.kv_args.state_item_lens\n            )\n            state_dim_per_tensor = getattr(\n                self.kv_mgr.kv_args, \"state_dim_per_tensor\", []\n            )\n            packed_state_dim_per_tensor = b\"\".join(\n                struct.pack(\"I\", dim) for dim in state_dim_per_tensor\n            )\n            # 
Note(shangming): No need to add pp rank here since decode pp size should be equal to prefill pp size or 1\n            tp_rank = self.kv_mgr.kv_args.engine_rank\n            kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]\n            dst_tp_rank = str(tp_rank).encode(\"ascii\")\n            dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode(\"ascii\")\n            dst_kv_item_len = str(kv_item_len).encode(\"ascii\")\n\n            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)\n            with lock:\n                sock.send_multipart(\n                    [\n                        \"None\".encode(\"ascii\"),\n                        self.kv_mgr.local_ip.encode(\"ascii\"),\n                        str(self.kv_mgr.rank_port).encode(\"ascii\"),\n                        self.session_id.encode(\"ascii\"),\n                        packed_kv_data_ptrs,\n                        packed_aux_data_ptrs,\n                        packed_state_data_ptrs,\n                        dst_tp_rank,\n                        dst_attn_tp_size,\n                        dst_kv_item_len,\n                        packed_state_item_lens,\n                        packed_state_dim_per_tensor,\n                    ]\n                )\n\n    def init(\n        self,\n        kv_indices: npt.NDArray[np.int32],\n        aux_index: Optional[int] = None,\n        state_indices: Optional[List[int]] = None,\n    ):\n        if self.bootstrap_infos is None:\n            self.kv_mgr.record_failure(\n                self.bootstrap_room,\n                f\"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}\",\n            )\n            self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)\n            return\n\n        for bootstrap_info in self.bootstrap_infos:\n            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)\n            is_dummy = bootstrap_info[\"is_dummy\"]\n\n            with lock:\n                
sock.send_multipart(\n                    [\n                        str(self.bootstrap_room).encode(\"ascii\"),\n                        self.kv_mgr.local_ip.encode(\"ascii\"),\n                        str(self.kv_mgr.rank_port).encode(\"ascii\"),\n                        self.session_id.encode(\"ascii\"),\n                        kv_indices.tobytes() if not is_dummy else b\"\",\n                        str(aux_index).encode(\"ascii\") if not is_dummy else b\"\",\n                        (\n                            np.array(\n                                state_indices,\n                                dtype=np.int32,\n                            ).tobytes()\n                            if not is_dummy and state_indices is not None\n                            else b\"\"\n                        ),\n                        str(self.required_dst_info_num).encode(\"ascii\"),\n                    ]\n                )\n        self.init_time = time.time()\n\n    def poll(self) -> KVPoll:\n        if self.conclude_state is None:\n            status = self.kv_mgr.check_status(self.bootstrap_room)\n            if status in (KVPoll.Success, KVPoll.Failed):\n                self.conclude_state = status\n            elif status == KVPoll.WaitingForInput:\n                if self.init_time is not None:\n                    now = time.time()\n                    elapsed = now - self.init_time\n                    if elapsed >= self.kv_mgr.waiting_timeout:\n                        logger.warning_once(\n                            \"Some requests fail to receive KV Cache transfer done signal after bootstrapping. \"\n                            \"If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_WAITING_TIMEOUT=600' (10 minutes) to relax the timeout condition. 
\"\n                        )\n                        self.kv_mgr.record_failure(\n                            self.bootstrap_room,\n                            f\"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.WaitingForInput\",\n                        )\n                        self.conclude_state = KVPoll.Failed\n                        return KVPoll.Failed\n\n            return status\n\n        else:\n            return self.conclude_state\n\n    def clear(self) -> None:\n        if self.bootstrap_room in self.kv_mgr.request_status:\n            self.kv_mgr.request_status.pop(self.bootstrap_room)\n\n        if self.bootstrap_room in self.kv_mgr.required_prefill_response_num_table:\n            self.kv_mgr.required_prefill_response_num_table.pop(self.bootstrap_room)\n\n        if self.bootstrap_room in self.kv_mgr.prefill_response_tracker:\n            self.kv_mgr.prefill_response_tracker.pop(self.bootstrap_room)\n\n    def failure_exception(self):\n        # Explicitly set the status to failure since this request has failed in another rank\n        if self.conclude_state is None:\n            self.conclude_state = KVPoll.Failed\n\n        self.clear()\n\n        with self.kv_mgr.failure_lock:\n            failure_reason = self.kv_mgr.failure_records.pop(\n                self.bootstrap_room, \"Failed due to an unknown reason from another rank\"\n            )\n        raise KVTransferError(self.bootstrap_room, failure_reason)\n\n    def abort(self):\n        self.kv_mgr.record_failure(\n            self.bootstrap_room,\n            \"Aborted by AbortReq.\",\n        )\n        # Explicitly set the status to failure since this request has been aborted\n        self.conclude_state = KVPoll.Failed\n\n\nclass MooncakeKVBootstrapServer(CommonKVBootstrapServer):\n    pass\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/mooncake/utils.py",
    "content": "# Copyright 2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Mooncake-specific utilities for custom memory pool management.\"\"\"\n\nimport logging\nfrom typing import Any, Optional, Tuple\n\nimport torch\n\nfrom sglang.srt.environ import envs\n\nlogger = logging.getLogger(__name__)\n\n# Global constants for custom memory pool types\nSUPPORTED_MOONCAKE_CUSTOM_MEM_POOL_TYPES = [\"NVLINK\", \"BAREX\", \"INTRA_NODE_NVLINK\"]\n\n\ndef init_mooncake_custom_mem_pool(\n    device: str,\n) -> Tuple[bool, Optional[Any], Optional[str]]:\n    \"\"\"\n    Initialize custom memory pool based on environment variable.\n\n    Args:\n        device: The device to allocate memory on\n\n    Returns:\n        Tuple of (enable_custom_mem_pool, custom_mem_pool, custom_mem_pool_type)\n    \"\"\"\n    enable_custom_mem_pool, custom_mem_pool_type = (\n        check_mooncake_custom_mem_pool_enabled()\n    )\n\n    custom_mem_pool = None\n\n    if enable_custom_mem_pool:\n        try:\n            # TODO(shangming): abstract custom allocator class for more backends\n            if custom_mem_pool_type == \"NVLINK\":\n                from mooncake.allocator import NVLinkAllocator\n\n                allocator = NVLinkAllocator.get_allocator(device)\n            elif custom_mem_pool_type == \"BAREX\":\n                from mooncake.allocator import BarexAllocator\n\n
                allocator = BarexAllocator.get_allocator(device)\n            elif custom_mem_pool_type == \"INTRA_NODE_NVLINK\":\n                return False, None, None\n            else:\n                # This should not happen due to the enable_custom_mem_pool check above\n                raise ValueError(\n                    f\"Unsupported custom mem pool type: {custom_mem_pool_type}\"\n                )\n\n            custom_mem_pool = torch.cuda.MemPool(allocator.allocator())\n            logger.debug(\n                f\"Initialized custom memory pool: {custom_mem_pool_type} on device {device}\"\n            )\n        except ImportError as e:\n            logger.warning(\n                f\"Failed to import mooncake allocator for {custom_mem_pool_type}: {e}. \"\n                f\"Falling back to default memory pool.\"\n            )\n            enable_custom_mem_pool = False\n            custom_mem_pool = None\n            custom_mem_pool_type = None\n        except Exception as e:\n            logger.error(\n                f\"Failed to initialize custom memory pool {custom_mem_pool_type}: {e}.
 \"\n                f\"Falling back to default memory pool.\"\n            )\n            enable_custom_mem_pool = False\n            custom_mem_pool = None\n            custom_mem_pool_type = None\n    else:\n        return False, None, None\n\n    return enable_custom_mem_pool, custom_mem_pool, custom_mem_pool_type\n\n\ndef check_mooncake_custom_mem_pool_enabled() -> Tuple[bool, Optional[str]]:\n    \"\"\"\n    Check if custom memory pool is enabled without importing allocators.\n\n    Returns:\n        Tuple of (enable_custom_mem_pool, custom_mem_pool_type)\n    \"\"\"\n    custom_mem_pool_type = envs.SGLANG_MOONCAKE_CUSTOM_MEM_POOL.get()\n\n    if custom_mem_pool_type is not None:\n        # Handle boolean True as NVLINK\n        if custom_mem_pool_type.lower() == \"true\":\n            custom_mem_pool_type = \"NVLINK\"\n        enable_custom_mem_pool = (\n            custom_mem_pool_type in SUPPORTED_MOONCAKE_CUSTOM_MEM_POOL_TYPES\n        )\n    else:\n        enable_custom_mem_pool = False\n        custom_mem_pool_type = None\n\n    return enable_custom_mem_pool, custom_mem_pool_type\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/mori/__init__.py",
    "content": "from sglang.srt.disaggregation.mori.conn import (\n    MoriKVBootstrapServer,\n    MoriKVManager,\n    MoriKVReceiver,\n    MoriKVSender,\n)\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/mori/conn.py",
    "content": "from __future__ import annotations\n\nimport ctypes\nimport dataclasses\nimport logging\nimport os\nimport struct\nimport threading\nimport time\nfrom typing import Dict, List, Optional, Tuple\n\nimport msgspec\nimport numpy as np\nimport numpy.typing as npt\nfrom mori.cpp import TransferStatus\nfrom mori.io import (\n    BackendType,\n    EngineDesc,\n    IOEngine,\n    IOEngineConfig,\n    MemoryDesc,\n    MemoryLocationType,\n    PollCqMode,\n    RdmaBackendConfig,\n)\n\nfrom sglang.srt.disaggregation.base.conn import KVArgs, KVPoll\nfrom sglang.srt.disaggregation.common.conn import (\n    CommonKVBootstrapServer,\n    CommonKVManager,\n    CommonKVReceiver,\n    CommonKVSender,\n)\nfrom sglang.srt.disaggregation.common.utils import group_concurrent_contiguous\nfrom sglang.srt.disaggregation.utils import (\n    DisaggregationMode,\n    filter_kv_indices_for_cp_rank,\n)\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils.common import get_int_env_var\nfrom sglang.srt.utils.network import NetworkAddress, get_local_ip_auto\n\nlogger = logging.getLogger(__name__)\nMORI_GUARD = b\"MoriMsgGuard\"\n\n\ndef _pack_mem_desc_list(mems: List[MemoryDesc]) -> bytes:\n    if not mems:\n        return b\"\"\n    packed_descs = [mem.pack() for mem in mems]\n    return msgspec.msgpack.encode(packed_descs)\n\n\ndef _unpack_mem_desc_list(blob: bytes) -> List[MemoryDesc]:\n    if not blob:\n        return []\n    desc_blobs = msgspec.msgpack.decode(blob)\n    return [MemoryDesc.unpack(b) for b in desc_blobs]\n\n\n@dataclasses.dataclass\nclass TransferInfo:\n    room: int\n    endpoint: str\n    dst_port: int\n    engine_key: str\n    dst_kv_indices: npt.NDArray[np.int32]\n    dst_aux_index: int\n    required_dst_info_num: int\n    is_dummy: bool\n\n    @classmethod\n    def from_zmq(cls, payload: List[bytes]) -> TransferInfo:\n        room = int(payload[0].decode(\"ascii\"))\n        endpoint = payload[1].decode(\"ascii\")\n        dst_port = 
int(payload[2].decode(\"ascii\"))\n        engine_key = payload[3].decode(\"ascii\")\n\n        if payload[4]:\n            dst_kv_indices = np.frombuffer(payload[4], dtype=np.int32)\n        else:\n            dst_kv_indices = np.array([], dtype=np.int32)\n\n        if payload[5]:\n            dst_aux_index = int(payload[5].decode(\"ascii\"))\n        else:\n            dst_aux_index = -1\n\n        required_dst_info_num = (\n            int(payload[7].decode(\"ascii\")) if len(payload) > 7 else 1\n        )\n        is_dummy = dst_kv_indices.size == 0 and dst_aux_index < 0\n        return cls(\n            room=room,\n            endpoint=endpoint,\n            dst_port=dst_port,\n            engine_key=engine_key,\n            dst_kv_indices=dst_kv_indices,\n            dst_aux_index=dst_aux_index,\n            required_dst_info_num=required_dst_info_num,\n            is_dummy=is_dummy,\n        )\n\n\n@dataclasses.dataclass\nclass KVArgsRegisterInfo:\n    endpoint: str\n    dst_port: int\n    engine_desc: EngineDesc\n    dst_kv_mem_descs: List[MemoryDesc]\n    dst_aux_mem_descs: List[MemoryDesc]\n    dst_state_mem_descs: List[MemoryDesc]\n    gpu_id: int\n    decode_tp_size: int\n    decode_tp_rank: int\n    dst_kv_item_len: int\n\n    @property\n    def engine_key(self) -> str:\n        return self.engine_desc.key\n\n    @classmethod\n    def from_zmq(cls, payload: List[bytes]) -> KVArgsRegisterInfo:\n        endpoint = payload[1].decode(\"ascii\")\n        dst_port = int(payload[2].decode(\"ascii\"))\n        engine_desc = EngineDesc.unpack(payload[3])\n        dst_kv_mem_descs = _unpack_mem_desc_list(payload[4])\n        dst_aux_mem_descs = _unpack_mem_desc_list(payload[5])\n        dst_state_mem_descs = _unpack_mem_desc_list(payload[6])\n        gpu_id = int(payload[7].decode(\"ascii\"))\n        decode_tp_size = int(payload[8].decode(\"ascii\"))\n        decode_tp_rank = int(payload[9].decode(\"ascii\"))\n        dst_kv_item_len = 
int(payload[10].decode(\"ascii\"))\n        return cls(\n            endpoint=endpoint,\n            dst_port=dst_port,\n            engine_desc=engine_desc,\n            dst_kv_mem_descs=dst_kv_mem_descs,\n            dst_aux_mem_descs=dst_aux_mem_descs,\n            dst_state_mem_descs=dst_state_mem_descs,\n            gpu_id=gpu_id,\n            decode_tp_size=decode_tp_size,\n            decode_tp_rank=decode_tp_rank,\n            dst_kv_item_len=dst_kv_item_len,\n        )\n\n\nclass AuxDataCodec:\n    @staticmethod\n    def serialize_data_from_buffer(src_addr, data_length):\n        buffer = (ctypes.c_byte * data_length).from_address(src_addr)\n        return bytes(buffer)\n\n    @staticmethod\n    def deserialize_data_to_buffer(kv_args, buffer_index, aux_index, data):\n        dst_aux_ptr = kv_args.aux_data_ptrs[buffer_index]\n        item_len = kv_args.aux_item_lens[buffer_index]\n        dst_addr = dst_aux_ptr + item_len * aux_index\n        buffer = (ctypes.c_byte * len(data)).from_address(dst_addr)\n        buffer[:] = data\n        return\n\n\n@dataclasses.dataclass\nclass TPSliceConfig:\n    page_size: int\n    src_item_len: int\n    dst_item_len: int\n    bytes_per_token_src: int\n    bytes_per_token_dst: int\n    src_head_slice_offset: int\n    dst_head_slice_offset: int\n    heads_bytes_per_token_to_send: int\n\n\nclass MoriKVManager(CommonKVManager):\n    AUX_DATA_HEADER = b\"AUX_DATA\"\n\n    def __init__(\n        self,\n        args: KVArgs,\n        disaggregation_mode: DisaggregationMode,\n        server_args: ServerArgs,\n        is_mla_backend: Optional[bool] = False,\n    ):\n        super().__init__(args, disaggregation_mode, server_args, is_mla_backend)\n        self.engine = self._init_engine()\n        self.engine_desc = self.engine.get_engine_desc()\n        self.kv_mem_descs: List[MemoryDesc] = []\n        self.aux_mem_descs: List[MemoryDesc] = []\n        self.state_mem_descs: List[MemoryDesc] = []\n        self.transfer_lock = 
threading.Lock()\n        self._register_local_buffers()\n        if self.disaggregation_mode == DisaggregationMode.PREFILL:\n            self._start_bootstrap_thread()\n        elif self.disaggregation_mode == DisaggregationMode.DECODE:\n            self.room_to_bootstrap_addr: Dict[int, str] = {}\n            self._start_decode_thread()\n\n    def _init_engine(self) -> IOEngine:\n        if self.kv_args.ib_device:\n            os.environ[\"MORI_RDMA_DEVICES\"] = self.kv_args.ib_device\n\n        self.local_ip = get_local_ip_auto()\n        config = IOEngineConfig(host=self.local_ip, port=0)\n\n        engine_key = (\n            f\"io-{self.disaggregation_mode.value}-\"\n            f\"dp{self.system_dp_rank}-tp{self.attn_tp_rank}-\"\n            f\"pid{os.getpid()}-{self.local_ip}\"\n        )\n\n        engine = IOEngine(engine_key, config)\n        poll_mode = PollCqMode.POLLING\n\n        # Number of RDMA Queue Pairs (QPs) used per transfer operation.\n        # Higher values can increase parallelism and bandwidth utilization.\n        # Default: 1\n        qp_per_transfer = get_int_env_var(\"SGLANG_MORI_QP_PER_TRANSFER\", 1)\n\n        # Number of RDMA work requests posted in a single batch to each QP.\n        # Larger batch sizes reduce per-operation overhead and improve throughput\n        # at the cost of higher latency. 
Use -1 for automatic sizing based on\n        # the number of merged work requests and available endpoints.\n        # Default: -1 (automatic)\n        post_batch_size = get_int_env_var(\"SGLANG_MORI_POST_BATCH_SIZE\", -1)\n\n        # Number of worker threads in the RDMA executor thread pool.\n        # Each worker handles RDMA operations on a separate CPU core (with affinity).\n        # More workers can improve parallelism for large batch transfers across\n        # multiple QPs, but excessive threads may cause contention.\n        # Default: 1\n        num_worker_threads = get_int_env_var(\"SGLANG_MORI_NUM_WORKERS\", 1)\n\n        rdma_cfg = RdmaBackendConfig(\n            qp_per_transfer,\n            post_batch_size,\n            num_worker_threads,\n            poll_mode,\n            False,\n        )\n        engine.create_backend(BackendType.RDMA, rdma_cfg)\n        actual_port = engine.get_engine_desc().port\n        assert actual_port > 0, f\"Failed to bind port for engine {engine_key}\"\n        logger.debug(\n            \"Initialized Mori IOEngine %s at %s:%s (qp_per_transfer=%s, workers=%s, poll_mode=%s)\",\n            engine_key,\n            self.local_ip,\n            actual_port,\n            qp_per_transfer,\n            num_worker_threads,\n            poll_mode.name,\n        )\n        return engine\n\n    def _register_local_buffers(self) -> None:\n        for ptr, length in zip(self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens):\n            mem_desc = self.engine.register_memory(\n                ptr,\n                length,\n                self.kv_args.gpu_id,\n                MemoryLocationType.GPU,\n            )\n            self.kv_mem_descs.append(mem_desc)\n        for ptr, length in zip(self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens):\n            desc = self.engine.register_memory(\n                ptr,\n                length,\n                -1,\n                MemoryLocationType.CPU,\n            )\n            
self.aux_mem_descs.append(desc)\n        for ptr, length in zip(\n            self.kv_args.state_data_ptrs, getattr(self.kv_args, \"state_data_lens\", [])\n        ):\n            desc = self.engine.register_memory(\n                ptr,\n                length,\n                self.kv_args.gpu_id,\n                MemoryLocationType.GPU,\n            )\n            self.state_mem_descs.append(desc)\n\n    def _handle_register_message(self, payload: List[bytes]) -> None:\n        try:\n            register_info = KVArgsRegisterInfo.from_zmq(payload)\n            self._add_remote_peer(register_info)\n        except Exception:\n            logger.exception(\"Failed to register remote peer\")\n\n    def _handle_transfer_message(self, payload: List[bytes]) -> None:\n        try:\n            transfer_info = TransferInfo.from_zmq(payload)\n            infos = self.transfer_infos.setdefault(transfer_info.room, {})\n            infos[transfer_info.engine_key] = transfer_info\n\n            if len(infos) >= transfer_info.required_dst_info_num:\n                logger.debug(\n                    \"Bootstrap room %s got enough transfer info (%s)\",\n                    transfer_info.room,\n                    len(infos),\n                )\n                self.update_status(transfer_info.room, KVPoll.WaitingForInput)\n        except Exception:\n            logger.exception(\"Failed to parse transfer info message\")\n\n    def _validate_message(self, msg: List[bytes]) -> Optional[List[bytes]]:\n        if not msg or msg[0] != MORI_GUARD:\n            logger.warning(\"Received malformed bootstrap message\")\n            return None\n        payload = msg[1:]\n        if not payload:\n            return None\n        return payload\n\n    def _start_bootstrap_thread(self) -> None:\n        def bootstrap_worker():\n            while True:\n                try:\n                    msg = self.server_socket.recv_multipart()\n                    payload = 
self._validate_message(msg)\n                    if payload is None:\n                        continue\n                    room = payload[0].decode(\"ascii\")\n\n                    if room == \"None\":\n                        self._handle_register_message(payload)\n                    else:\n                        self._handle_transfer_message(payload)\n                except Exception:\n                    logger.exception(\"Bootstrap worker failed\")\n\n        threading.Thread(target=bootstrap_worker, daemon=True).start()\n\n    def _cleanup_room_tracking(self, bootstrap_room: int) -> None:\n        bootstrap_addr = self.room_to_bootstrap_addr.pop(bootstrap_room, None)\n        if bootstrap_addr is not None:\n            rooms = self.addr_to_rooms_tracker.get(bootstrap_addr)\n            if rooms is not None:\n                rooms.discard(bootstrap_room)\n                if not rooms:\n                    self.addr_to_rooms_tracker.pop(bootstrap_addr, None)\n\n    def _start_decode_thread(self) -> None:\n        def decode_worker():\n            while True:\n                try:\n                    msg = self.server_socket.recv_multipart()\n                    if msg and msg[0] == MoriKVManager.AUX_DATA_HEADER:\n                        self._handle_aux_data(msg)\n                        continue\n\n                    if not msg or msg[0] != MORI_GUARD:\n                        logger.warning(\n                            \"Received malformed status message on decode worker\"\n                        )\n                        continue\n                    payload = msg[1:]\n                    if len(payload) < 3:\n                        logger.warning(\"Incomplete status payload received\")\n                        continue\n                    bootstrap_room = int(payload[0].decode(\"ascii\"))\n                    status_code = int(payload[1].decode(\"ascii\"))\n                    prefill_rank = int(payload[2].decode(\"ascii\"))\n                    
failure_reason = (\n                        payload[3].decode(\"utf-8\")\n                        if len(payload) > 3 and payload[3]\n                        else None\n                    )\n\n                    if status_code == KVPoll.Success:\n                        tracker = self.prefill_response_tracker[bootstrap_room]\n                        tracker.add(prefill_rank)\n                        expected = self.required_prefill_response_num_table.get(\n                            bootstrap_room, 1\n                        )\n                        if len(tracker) >= expected:\n                            self.prefill_response_tracker.pop(bootstrap_room, None)\n                            self.update_status(bootstrap_room, KVPoll.Success)\n                            self._cleanup_room_tracking(bootstrap_room)\n                    elif status_code == KVPoll.Failed:\n                        if failure_reason:\n                            self.record_failure(bootstrap_room, failure_reason)\n                        self.prefill_response_tracker.pop(bootstrap_room, None)\n                        self.update_status(bootstrap_room, KVPoll.Failed)\n                        self._cleanup_room_tracking(bootstrap_room)\n                    else:\n                        logger.warning(\n                            \"Unknown status code %s received for room %s\",\n                            status_code,\n                            bootstrap_room,\n                        )\n                except Exception:\n                    logger.exception(\"Decode status worker failed\")\n\n        threading.Thread(target=decode_worker, daemon=True).start()\n\n    def notify_decode_status(\n        self,\n        infos: List[TransferInfo],\n        bootstrap_room: int,\n        status: KVPoll,\n        failure_reason: Optional[str] = None,\n    ) -> None:\n        if not infos:\n            return\n        payload = [\n            MORI_GUARD,\n            
str(bootstrap_room).encode(\"ascii\"),\n            str(int(status)).encode(\"ascii\"),\n            str(self.attn_tp_rank * self.pp_size + self.pp_rank).encode(\"ascii\"),\n            failure_reason.encode(\"utf-8\") if failure_reason else b\"\",\n        ]\n        for info in infos:\n            try:\n                na = NetworkAddress(info.endpoint, info.dst_port)\n                socket = self._connect(na.to_tcp(), is_ipv6=na.is_ipv6)\n                socket.send_multipart(payload)\n            except Exception:\n                logger.exception(\n                    \"Failed to sync status %s to decode endpoint %s:%s for room %s\",\n                    status,\n                    info.endpoint,\n                    info.dst_port,\n                    bootstrap_room,\n                )\n\n    def _add_remote_peer(self, register_info: KVArgsRegisterInfo) -> None:\n        engine_key = register_info.engine_key\n        if engine_key in self.decode_kv_args_table:\n            logger.debug(\"Remote peer %s already registered. 
Skipping.\", engine_key)\n            return\n        self.engine.register_remote_engine(register_info.engine_desc)\n        self.decode_kv_args_table[engine_key] = register_info\n        logger.debug(\n            \"Registered decode peer %s (%s:%s)\",\n            engine_key,\n            register_info.endpoint,\n            register_info.dst_port,\n        )\n\n    def _get_mha_mem_desc_slices(\n        self, dst_mem_descs: List[MemoryDesc]\n    ) -> tuple[\n        List[MemoryDesc], List[MemoryDesc], List[MemoryDesc], List[MemoryDesc], int\n    ]:\n        src_descs = self.kv_mem_descs\n        if not src_descs:\n            raise RuntimeError(\"KV memory descriptors are empty on prefill side\")\n\n        num_local_layers = len(src_descs) // 2\n        src_k_descs = src_descs[:num_local_layers]\n        src_v_descs = src_descs[num_local_layers:]\n\n        start_layer = self.kv_args.prefill_start_layer\n        end_layer = start_layer + num_local_layers\n        dst_total_layers = len(dst_mem_descs) // 2\n        if len(dst_mem_descs) < 2 or end_layer > dst_total_layers:\n            raise ValueError(\n                \"Destination KV descriptors do not match prefill pp configuration\"\n            )\n        dst_k_descs = dst_mem_descs[start_layer:end_layer]\n        dst_v_descs = dst_mem_descs[\n            dst_total_layers + start_layer : dst_total_layers + end_layer\n        ]\n        return src_k_descs, src_v_descs, dst_k_descs, dst_v_descs, num_local_layers\n\n    def _get_mla_mem_desc_slices(\n        self, dst_mem_descs: List[MemoryDesc]\n    ) -> tuple[List[MemoryDesc], List[MemoryDesc], int]:\n        src_descs = self.kv_mem_descs\n        num_local_layers = len(src_descs)\n        start_layer = self.kv_args.prefill_start_layer\n        end_layer = start_layer + num_local_layers\n        if end_layer > len(dst_mem_descs):\n            raise ValueError(\n                \"Destination MLA KV descriptors do not match prefill pp configuration\"\n        
    )\n        dst_slice = dst_mem_descs[start_layer:end_layer]\n        return src_descs, dst_slice, num_local_layers\n\n    def _issue_layer_transfers(\n        self,\n        src_desc: MemoryDesc,\n        dst_desc: MemoryDesc,\n        kv_item_len: int,\n        src_groups: List[List[int]],\n        dst_groups: List[List[int]],\n    ) -> List[TransferStatus]:\n        if not src_groups:\n            return []\n        local_offsets = [int(src_group[0]) * kv_item_len for src_group in src_groups]\n        remote_offsets = [int(dst_group[0]) * kv_item_len for dst_group in dst_groups]\n        sizes = [len(src_group) * kv_item_len for src_group in src_groups]\n\n        transfer_uid = self.engine.allocate_transfer_uid()\n\n        statuses = self.engine.batch_write(\n            [src_desc],\n            [local_offsets],\n            [dst_desc],\n            [remote_offsets],\n            [sizes],\n            [transfer_uid],\n        )\n        return statuses\n\n    def _build_tp_slice_config(self, peer_info: KVArgsRegisterInfo) -> TPSliceConfig:\n        page_size = self.kv_args.page_size\n\n        src_item_len = self.kv_args.kv_item_lens[0]\n        dst_item_len = peer_info.dst_kv_item_len\n\n        bytes_per_token_src = src_item_len // page_size\n        bytes_per_token_dst = dst_item_len // page_size\n\n        prefill_tp_size = self.attn_tp_size\n        decode_tp_size = peer_info.decode_tp_size\n\n        num_kv_heads = self.kv_args.kv_head_num\n        src_heads_per_rank = num_kv_heads\n        dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size\n        if dst_heads_per_rank == 0:\n            raise ValueError(\"Destination heads per rank evaluates to zero\")\n\n        bytes_per_head_slice = bytes_per_token_dst // dst_heads_per_rank\n        if bytes_per_head_slice == 0:\n            raise ValueError(\"Head slice size evaluates to zero\")\n\n        local_tp_rank = self.kv_args.engine_rank % prefill_tp_size\n        dst_tp_rank = 
peer_info.decode_tp_rank % decode_tp_size\n\n        if prefill_tp_size > decode_tp_size:\n            src_head_start = 0\n            num_heads_to_send = src_heads_per_rank\n            dst_head_start = local_tp_rank * src_heads_per_rank\n        else:\n            src_head_start = (dst_tp_rank * dst_heads_per_rank) % src_heads_per_rank\n            num_heads_to_send = dst_heads_per_rank\n            dst_head_start = 0\n\n        src_head_slice_offset = src_head_start * bytes_per_head_slice\n        dst_head_slice_offset = dst_head_start * bytes_per_head_slice\n        heads_bytes_per_token = num_heads_to_send * bytes_per_head_slice\n\n        if heads_bytes_per_token > bytes_per_token_dst:\n            raise ValueError(\n                \"Slice size exceeds destination token capacity for TP slice transfer\"\n            )\n\n        return TPSliceConfig(\n            page_size=page_size,\n            src_item_len=src_item_len,\n            dst_item_len=dst_item_len,\n            bytes_per_token_src=bytes_per_token_src,\n            bytes_per_token_dst=bytes_per_token_dst,\n            src_head_slice_offset=src_head_slice_offset,\n            dst_head_slice_offset=dst_head_slice_offset,\n            heads_bytes_per_token_to_send=heads_bytes_per_token,\n        )\n\n    def _issue_tp_slice_transfers(\n        self,\n        src_desc: MemoryDesc,\n        dst_desc: MemoryDesc,\n        kv_indices: npt.NDArray[np.int32],\n        dst_indices: npt.NDArray[np.int32],\n        tp_cfg: TPSliceConfig,\n    ) -> List[TransferStatus]:\n        if kv_indices.size == 0 or dst_indices.size == 0:\n            return []\n\n        limit = min(kv_indices.size, dst_indices.size)\n        if not limit:\n            return []\n\n        src_pages = kv_indices[:limit].astype(np.int64)\n        dst_pages = dst_indices[:limit].astype(np.int64)\n        token_slots = np.arange(tp_cfg.page_size, dtype=np.int64)\n\n        src_page_bases = src_pages * tp_cfg.src_item_len\n        
dst_page_bases = dst_pages * tp_cfg.dst_item_len\n\n        src_token_offsets = token_slots * tp_cfg.bytes_per_token_src\n        dst_token_offsets = token_slots * tp_cfg.bytes_per_token_dst\n\n        local_offsets = (\n            (\n                src_page_bases[:, np.newaxis]\n                + src_token_offsets\n                + tp_cfg.src_head_slice_offset\n            )\n            .flatten()\n            .tolist()\n        )\n        remote_offsets = (\n            (\n                dst_page_bases[:, np.newaxis]\n                + dst_token_offsets\n                + tp_cfg.dst_head_slice_offset\n            )\n            .flatten()\n            .tolist()\n        )\n\n        num_transfers = limit * tp_cfg.page_size\n        sizes = [tp_cfg.heads_bytes_per_token_to_send] * num_transfers\n\n        if not local_offsets:\n            return []\n\n        transfer_uid = self.engine.allocate_transfer_uid()\n        statuses = self.engine.batch_write(\n            [src_desc],\n            [local_offsets],\n            [dst_desc],\n            [remote_offsets],\n            [sizes],\n            [transfer_uid],\n        )\n        return statuses\n\n    def send_kvcache(\n        self,\n        peer_info: KVArgsRegisterInfo,\n        prefill_kv_indices: npt.NDArray[np.int32],\n        dst_kv_indices: npt.NDArray[np.int32],\n    ) -> List[TransferStatus]:\n        src_groups, dst_groups = group_concurrent_contiguous(\n            prefill_kv_indices, dst_kv_indices\n        )\n        statuses = []\n        kv_item_len = self.kv_args.kv_item_lens[0]\n        if self.is_mla_backend:\n            (\n                src_descs,\n                dst_descs,\n                layers_current_pp_stage,\n            ) = self._get_mla_mem_desc_slices(peer_info.dst_kv_mem_descs)\n            for layer_id in range(layers_current_pp_stage):\n                statuses.extend(\n                    self._issue_layer_transfers(\n                        src_descs[layer_id],\n     
                   dst_descs[layer_id],\n                        kv_item_len,\n                        src_groups,\n                        dst_groups,\n                    )\n                )\n        else:\n            tp_mismatch = peer_info.decode_tp_size != self.attn_tp_size\n            (\n                src_k_descs,\n                src_v_descs,\n                dst_k_descs,\n                dst_v_descs,\n                layers_current_pp_stage,\n            ) = self._get_mha_mem_desc_slices(peer_info.dst_kv_mem_descs)\n\n            if tp_mismatch:\n                tp_cfg = self._build_tp_slice_config(peer_info)\n                for layer_id in range(layers_current_pp_stage):\n                    statuses.extend(\n                        self._issue_tp_slice_transfers(\n                            src_k_descs[layer_id],\n                            dst_k_descs[layer_id],\n                            prefill_kv_indices,\n                            dst_kv_indices,\n                            tp_cfg,\n                        )\n                    )\n                    statuses.extend(\n                        self._issue_tp_slice_transfers(\n                            src_v_descs[layer_id],\n                            dst_v_descs[layer_id],\n                            prefill_kv_indices,\n                            dst_kv_indices,\n                            tp_cfg,\n                        )\n                    )\n            else:\n                src_groups, dst_groups = group_concurrent_contiguous(\n                    prefill_kv_indices, dst_kv_indices\n                )\n                for layer_id in range(layers_current_pp_stage):\n                    statuses.extend(\n                        self._issue_layer_transfers(\n                            src_k_descs[layer_id],\n                            dst_k_descs[layer_id],\n                            kv_item_len,\n                            src_groups,\n                            
dst_groups,\n                        )\n                    )\n                    statuses.extend(\n                        self._issue_layer_transfers(\n                            src_v_descs[layer_id],\n                            dst_v_descs[layer_id],\n                            kv_item_len,\n                            src_groups,\n                            dst_groups,\n                        )\n                    )\n\n        return statuses\n\n    def send_aux(\n        self,\n        peer_info: KVArgsRegisterInfo,\n        prefill_aux_index: int,\n        dst_aux_index: int,\n        room: int,\n    ) -> List[TransferStatus]:\n        return self.send_aux_tcp(peer_info, prefill_aux_index, dst_aux_index, room)\n\n    def send_aux_tcp(\n        self,\n        peer_info: KVArgsRegisterInfo,\n        prefill_aux_index: int,\n        dst_aux_index: int,\n        room: int,\n    ) -> List[TransferStatus]:\n        prefill_aux_ptrs = self.kv_args.aux_data_ptrs\n        prefill_aux_item_lens = self.kv_args.aux_item_lens\n\n        for i in range(len(prefill_aux_ptrs)):\n            length = prefill_aux_item_lens[i]\n            src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index\n            data = AuxDataCodec.serialize_data_from_buffer(src_addr, length)\n\n            self.send_aux_data_to_endpoint(\n                remote=peer_info.endpoint,\n                dst_port=peer_info.dst_port,\n                room=room,\n                buffer_index=i,\n                aux_index=dst_aux_index,\n                data=data,\n            )\n\n        return []\n\n    def send_aux_data_to_endpoint(\n        self,\n        remote: str,\n        dst_port: int,\n        room: int,\n        buffer_index: int,\n        aux_index: int,\n        data: bytes,\n    ):\n        na = NetworkAddress(remote, dst_port)\n        socket = self._connect(na.to_tcp(), is_ipv6=na.is_ipv6)\n\n        socket.send_multipart(\n            [\n                
MoriKVManager.AUX_DATA_HEADER,\n                str(room).encode(\"ascii\"),\n                str(buffer_index).encode(\"ascii\"),\n                str(aux_index).encode(\"ascii\"),\n                struct.pack(\">I\", len(data)),\n                data,\n            ]\n        )\n\n    def _handle_aux_data(self, msg: List[bytes]):\n        \"\"\"Handle AUX_DATA messages received by the decode thread.\"\"\"\n        room = int(msg[1].decode(\"ascii\"))\n        buffer_index = int(msg[2].decode(\"ascii\"))\n        aux_index = int(msg[3].decode(\"ascii\"))\n        data_length = struct.unpack(\">I\", msg[4])[0]\n        data = msg[5]\n\n        if len(data) != data_length:\n            logger.error(f\"AUX_DATA length mismatch for bootstrap_room {room}\")\n            return\n\n        AuxDataCodec.deserialize_data_to_buffer(\n            self.kv_args, buffer_index, aux_index, data\n        )\n\n        logger.debug(\n            f\"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}\"\n        )\n\n    def add_transfer_request(\n        self,\n        bootstrap_room: int,\n        kv_indices: npt.NDArray[np.int32],\n        index_slice: slice,\n        is_last: bool,\n        aux_index: Optional[int] = None,\n        state_indices: Optional[npt.NDArray[np.int32]] = None,\n    ) -> Tuple[List[TransferStatus], Optional[List[TransferInfo]]]:\n        assert self.disaggregation_mode == DisaggregationMode.PREFILL\n        transfer_infos = self.transfer_infos.get(bootstrap_room)\n        if not transfer_infos:\n            raise RuntimeError(\n                f\"No transfer info found for bootstrap_room={bootstrap_room}\"\n            )\n        result_statuses = []\n        target_infos_snapshot: Optional[List[TransferInfo]] = None\n        with self.transfer_lock:\n            self.update_status(bootstrap_room, KVPoll.Transferring)\n            for info in transfer_infos.values():\n                peer_info = 
self.decode_kv_args_table.get(info.engine_key)\n                if not peer_info:\n                    self.record_failure(\n                        bootstrap_room,\n                        f\"Peer info missing for engine {info.engine_key}\",\n                    )\n                    raise RuntimeError(\n                        f\"Missing decode peer info for {info.engine_key}\"\n                    )\n                if not info.is_dummy:\n                    dst_indices_chunk = info.dst_kv_indices[index_slice]\n                    statuses = self.send_kvcache(\n                        peer_info, kv_indices, dst_indices_chunk\n                    )\n                    result_statuses.extend(statuses)\n                if (\n                    is_last\n                    and aux_index is not None\n                    and info.dst_aux_index >= 0\n                    and self.pp_group.is_last_rank\n                ):\n                    result_statuses.extend(\n                        self.send_aux(\n                            peer_info, aux_index, info.dst_aux_index, bootstrap_room\n                        )\n                    )\n            if is_last:\n                self.update_status(bootstrap_room, KVPoll.Success)\n                target_infos_snapshot = list(transfer_infos.values())\n                self.transfer_infos.pop(bootstrap_room, None)\n        return result_statuses, target_infos_snapshot\n\n\nclass MoriKVSender(CommonKVSender):\n    def __init__(\n        self,\n        mgr: MoriKVManager,\n        bootstrap_addr: str,\n        bootstrap_room: int,\n        dest_tp_ranks: List[int],\n        pp_rank: int,\n    ):\n        super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)\n        self.transfer_statuses: List[TransferStatus] = []\n        self.pending_infos: Optional[List[TransferInfo]] = None\n        self.sent_last_chunk = False\n        self.conclude_state: Optional[KVPoll] = None\n        self.status_notified 
= False\n        self.init_time = time.time()\n\n    def send(\n        self,\n        kv_indices: npt.NDArray[np.int32],\n        state_indices: Optional[List[int]] = None,\n    ):\n        index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))\n        self.curr_idx += len(kv_indices)\n        is_last = self.curr_idx == self.num_kv_indices\n\n        # Special handling for cp\n        if self.kv_mgr.enable_all_cp_ranks_for_transfer:\n            kv_indices, index_slice = filter_kv_indices_for_cp_rank(\n                self.kv_mgr,\n                kv_indices,\n                index_slice,\n            )\n        elif self.kv_mgr.is_dummy_cp_rank:\n            if not is_last:\n                return\n            else:\n                self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Success)\n                return\n        statuses, infos = self.kv_mgr.add_transfer_request(\n            self.bootstrap_room,\n            kv_indices,\n            index_slice,\n            is_last,\n            aux_index=self.aux_index if is_last else None,\n        )\n        self.transfer_statuses.extend(statuses)\n        if infos is not None:\n            self.pending_infos = infos\n            self.sent_last_chunk = True\n\n    def poll(self) -> KVPoll:\n        if self.conclude_state is not None:\n            return self.conclude_state\n\n        status = self.kv_mgr.check_status(self.bootstrap_room)\n        if status == KVPoll.Bootstrapping:\n            elapsed = time.time() - self.init_time\n            if elapsed >= self.kv_mgr.bootstrap_timeout:\n                reason = (\n                    f\"Request {self.bootstrap_room} timed out after {elapsed:.1f}s \"\n                    \"waiting for decode handshake\"\n                )\n                self.kv_mgr.record_failure(self.bootstrap_room, reason)\n                self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)\n                self._finalize_failure(reason)\n                return 
KVPoll.Failed\n            return status\n\n        if status == KVPoll.Failed:\n            self._finalize_failure()\n            return KVPoll.Failed\n\n        transfers_done = self._all_transfers_finished()\n        if transfers_done:\n            if self._has_transfer_error():\n                reason = self._collect_failure_reason()\n                self.kv_mgr.record_failure(self.bootstrap_room, reason)\n                self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)\n                self._finalize_failure(reason)\n                return KVPoll.Failed\n            self._notify_decode(KVPoll.Success)\n            self.conclude_state = KVPoll.Success\n            return KVPoll.Success\n        return KVPoll.Transferring if status == KVPoll.Success else status\n\n    def _all_transfers_finished(self) -> bool:\n        if not self.sent_last_chunk:\n            return False\n        if not self.transfer_statuses:\n            return True\n        return all(not status.InProgress() for status in self.transfer_statuses)\n\n    def _has_transfer_error(self) -> bool:\n        return any(status.Failed() for status in self.transfer_statuses)\n\n    def _collect_failure_reason(self) -> str:\n        for status in self.transfer_statuses:\n            if status.Failed():\n                return f\"KV transfer failed: {status.Message()}\"\n        return \"KV transfer failed due to unknown reason\"\n\n    def _notify_decode(\n        self, status: KVPoll, failure_reason: Optional[str] = None\n    ) -> None:\n        if self.status_notified:\n            return\n        if self.pending_infos:\n            self.kv_mgr.notify_decode_status(\n                self.pending_infos, self.bootstrap_room, status, failure_reason\n            )\n        self.status_notified = True\n\n    def _finalize_failure(self, failure_reason: Optional[str] = None) -> None:\n        if self.conclude_state == KVPoll.Failed:\n            return\n        if failure_reason is None:\n       
     failure_reason = self.kv_mgr.failure_records.get(\n                self.bootstrap_room, \"KV transfer failed\"\n            )\n        self._notify_decode(KVPoll.Failed, failure_reason)\n        self.conclude_state = KVPoll.Failed\n\n    def clear(self) -> None:\n        self.kv_mgr.request_status.pop(self.bootstrap_room, None)\n\n    def failure_exception(self):\n        if self.conclude_state is None:\n            self._finalize_failure()\n        self.clear()\n        with self.kv_mgr.failure_lock:\n            failure_reason = self.kv_mgr.failure_records.pop(\n                self.bootstrap_room, \"KV transfer failed\"\n            )\n        raise RuntimeError(failure_reason)\n\n    def abort(self):\n        reason = \"Aborted by AbortReq.\"\n        self.kv_mgr.record_failure(self.bootstrap_room, reason)\n        self._notify_decode(KVPoll.Failed, reason)\n        self.conclude_state = KVPoll.Failed\n\n\nclass MoriKVReceiver(CommonKVReceiver):\n\n    def __init__(\n        self,\n        mgr: MoriKVManager,\n        bootstrap_addr: str,\n        bootstrap_room: Optional[int] = None,\n        prefill_dp_rank: Optional[int] = None,\n    ):\n        super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)\n        self.conclude_state: Optional[KVPoll] = None\n        self.init_time: Optional[float] = None\n        if self.bootstrap_room is None or self.bootstrap_infos is None:\n            return\n        self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)\n        self.kv_mgr.room_to_bootstrap_addr[self.bootstrap_room] = self.bootstrap_addr\n        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)\n        self._register_kv_args()\n\n    def _register_kv_args(self):\n        if self.bootstrap_infos is None:\n            return\n        engine_desc_blob = self.kv_mgr.engine_desc.pack()\n        packed_kv_descs = _pack_mem_desc_list(self.kv_mgr.kv_mem_descs)\n        packed_aux_descs = 
_pack_mem_desc_list(self.kv_mgr.aux_mem_descs)\n        packed_state_descs = _pack_mem_desc_list(self.kv_mgr.state_mem_descs)\n        gpu_id = str(self.kv_mgr.kv_args.gpu_id).encode(\"ascii\")\n        decode_tp_size = str(self.kv_mgr.attn_tp_size).encode(\"ascii\")\n        decode_tp_rank = str(self.kv_mgr.kv_args.engine_rank).encode(\"ascii\")\n        kv_item_len = str(self.kv_mgr.kv_args.kv_item_lens[0]).encode(\"ascii\")\n\n        for bootstrap_info in self.bootstrap_infos:\n            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)\n            with lock:\n                sock.send_multipart(\n                    [\n                        MORI_GUARD,\n                        \"None\".encode(\"ascii\"),\n                        self.kv_mgr.local_ip.encode(\"ascii\"),\n                        str(self.kv_mgr.rank_port).encode(\"ascii\"),\n                        engine_desc_blob,\n                        packed_kv_descs,\n                        packed_aux_descs,\n                        packed_state_descs,\n                        gpu_id,\n                        decode_tp_size,\n                        decode_tp_rank,\n                        kv_item_len,\n                    ]\n                )\n\n    def init(\n        self,\n        kv_indices: npt.NDArray[np.int32],\n        aux_index: Optional[int] = None,\n        state_indices: Optional[List[int]] = None,\n    ):\n        if self.bootstrap_infos is None or self.bootstrap_room is None:\n            return\n\n        kv_indices_bytes = (\n            np.asarray(kv_indices, dtype=np.int32).tobytes() if kv_indices.size else b\"\"\n        )\n        aux_bytes = str(aux_index).encode(\"ascii\") if aux_index is not None else b\"\"\n        state_bytes = b\"\"\n\n        for bootstrap_info in self.bootstrap_infos:\n            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)\n            is_dummy = bootstrap_info.get(\"is_dummy\", False)\n            with lock:\n              
  sock.send_multipart(\n                    [\n                        MORI_GUARD,\n                        str(self.bootstrap_room).encode(\"ascii\"),\n                        self.kv_mgr.local_ip.encode(\"ascii\"),\n                        str(self.kv_mgr.rank_port).encode(\"ascii\"),\n                        self.kv_mgr.engine_desc.key.encode(\"ascii\"),\n                        kv_indices_bytes if not is_dummy else b\"\",\n                        aux_bytes if not is_dummy else b\"\",\n                        state_bytes,\n                        str(self.required_dst_info_num).encode(\"ascii\"),\n                    ]\n                )\n        self.init_time = time.time()\n\n    def poll(self) -> KVPoll:\n        if self.conclude_state is not None:\n            return self.conclude_state\n\n        status = self.kv_mgr.check_status(self.bootstrap_room)\n        if status in (KVPoll.Success, KVPoll.Failed):\n            self.conclude_state = status\n            return status\n\n        if status == KVPoll.WaitingForInput and self.init_time is not None:\n            elapsed = time.time() - self.init_time\n            if elapsed >= self.kv_mgr.waiting_timeout:\n                reason = f\"Request {self.bootstrap_room} timed out after {elapsed:.1f}s waiting for KV transfer\"\n                self.kv_mgr.record_failure(self.bootstrap_room, reason)\n                self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)\n                self.conclude_state = KVPoll.Failed\n                return KVPoll.Failed\n\n        return status\n\n    def clear(self) -> None:\n        if self.bootstrap_room is None:\n            return\n        self.kv_mgr.request_status.pop(self.bootstrap_room, None)\n        self.kv_mgr.required_prefill_response_num_table.pop(self.bootstrap_room, None)\n        self.kv_mgr.prefill_response_tracker.pop(self.bootstrap_room, None)\n        self.kv_mgr._cleanup_room_tracking(self.bootstrap_room)\n\n    def failure_exception(self):\n       
 if self.conclude_state is None:\n            self.conclude_state = KVPoll.Failed\n\n        self.clear()\n        with self.kv_mgr.failure_lock:\n            failure_reason = self.kv_mgr.failure_records.pop(\n                self.bootstrap_room, \"KV transfer failed\"\n            )\n        raise RuntimeError(failure_reason)\n\n    def abort(self):\n        if self.bootstrap_room is None:\n            return\n        reason = \"Aborted by AbortReq.\"\n        self.kv_mgr.record_failure(self.bootstrap_room, reason)\n        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)\n        self.conclude_state = KVPoll.Failed\n        self.clear()\n\n\nclass MoriKVBootstrapServer(CommonKVBootstrapServer):\n    pass\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/nixl/__init__.py",
    "content": "from sglang.srt.disaggregation.nixl.conn import (\n    NixlKVBootstrapServer,\n    NixlKVManager,\n    NixlKVReceiver,\n    NixlKVSender,\n)\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/nixl/conn.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nimport logging\nimport struct\nimport threading\nimport time\nimport uuid\nfrom collections import defaultdict\nfrom typing import Dict, List, Optional, Set\n\nimport numpy as np\nimport numpy.typing as npt\n\nfrom sglang.srt.disaggregation.base.conn import KVArgs, KVPoll\nfrom sglang.srt.disaggregation.common.conn import (\n    CommonKVBootstrapServer,\n    CommonKVManager,\n    CommonKVReceiver,\n    CommonKVSender,\n)\nfrom sglang.srt.disaggregation.common.utils import group_concurrent_contiguous\nfrom sglang.srt.disaggregation.utils import (\n    DisaggregationMode,\n    filter_kv_indices_for_cp_rank,\n)\nfrom sglang.srt.environ import envs\nfrom sglang.srt.server_args import ServerArgs\n\nlogger = logging.getLogger(__name__)\n\nGUARD = \"NixlMsgGuard\".encode(\"ascii\")\n\n\n@dataclasses.dataclass\nclass TransferInfo:\n    \"\"\"Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread.\"\"\"\n\n    room: int\n    endpoint: str\n    dst_port: int\n    agent_name: str\n    dst_kv_indices: npt.NDArray[np.int32]\n    dst_aux_index: int\n    required_dst_info_num: int\n    dst_state_indices: List[int]\n\n    def is_dummy(self):\n        return self.dst_kv_indices.size == 0\n\n    @classmethod\n    def from_zmq(cls, msg: List[bytes]):\n        # Parse state_indices from msg[7] if present\n        if len(msg) > 7 and msg[7] != b\"\":\n            dst_state_indices = list(np.frombuffer(msg[7], dtype=np.int32))\n        else:\n            dst_state_indices = []\n\n        return cls(\n            room=int(msg[0].decode(\"ascii\")),\n            endpoint=msg[1].decode(\"ascii\"),\n            dst_port=int(msg[2].decode(\"ascii\")),\n            agent_name=msg[3].decode(\"ascii\"),\n            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32),\n            dst_aux_index=int(msg[5].decode(\"ascii\")),\n            
required_dst_info_num=int(msg[6].decode(\"ascii\")),\n            dst_state_indices=dst_state_indices,\n        )\n\n\n@dataclasses.dataclass\nclass KVArgsRegisterInfo:\n    \"\"\"Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread.\"\"\"\n\n    room: str\n    endpoint: str\n    dst_port: int\n    agent_name: str\n    agent_metadata: bytes\n    dst_kv_ptrs: list[int]\n    dst_aux_ptrs: list[int]\n    dst_state_data_ptrs: list[int]\n    gpu_id: int\n    decode_tp_size: int\n    decode_tp_rank: int\n    dst_kv_item_len: int\n\n    @classmethod\n    def from_zmq(cls, msg: List[bytes]):\n        # Parse state_data_ptrs from msg[7] if present\n        if len(msg) > 7 and msg[7] != b\"\":\n            dst_state_data_ptrs = list(struct.unpack(f\"{len(msg[7]) // 8}Q\", msg[7]))\n        else:\n            dst_state_data_ptrs = []\n\n        return cls(\n            room=str(msg[0].decode(\"ascii\")),\n            endpoint=msg[1].decode(\"ascii\"),\n            dst_port=int(msg[2].decode(\"ascii\")),\n            agent_name=msg[3].decode(\"ascii\"),\n            agent_metadata=msg[4],\n            dst_kv_ptrs=list(struct.unpack(f\"{len(msg[5]) // 8}Q\", msg[5])),\n            dst_aux_ptrs=list(struct.unpack(f\"{len(msg[6]) // 8}Q\", msg[6])),\n            dst_state_data_ptrs=dst_state_data_ptrs,\n            gpu_id=int(msg[8].decode(\"ascii\")),\n            decode_tp_size=int(msg[9].decode(\"ascii\")),\n            decode_tp_rank=int(msg[10].decode(\"ascii\")),\n            dst_kv_item_len=int(msg[11].decode(\"ascii\")),\n        )\n\n\n@dataclasses.dataclass\nclass TransferStatus:\n    \"\"\"Used by KV Receiver to know when a transfer is done.\"\"\"\n\n    # KV chunks received per pp_rank: {pp_rank: set of chunk_ids}\n    received_kvs_per_pp: Dict[int, Set[int]] = dataclasses.field(\n        default_factory=lambda: defaultdict(set)\n    )\n    # Expected chunk count per pp_rank (set when is_last=True): 
{pp_rank: expected_count}\n    expected_kvs_per_pp: Dict[int, int] = dataclasses.field(default_factory=dict)\n    # Number of PP ranks expected to send data.\n    num_pp_ranks_expected: Optional[int] = None\n    # Whether aux data has been received.\n    received_aux: bool = False\n    # PP ranks that have sent state data (state is layer-specific, each PP rank sends its portion).\n    received_state_per_pp: Set[int] = dataclasses.field(default_factory=set)\n    # Whether state data is expected (set based on state_type).\n    expects_state: bool = False\n    # Mark as failed\n    is_failure: bool = False\n\n    def is_done(self):\n        if self.is_failure:\n            return True\n        if self.num_pp_ranks_expected is None or not self.received_aux:\n            return False\n        # If state data is expected, check all PP ranks have sent it\n        if (\n            self.expects_state\n            and len(self.received_state_per_pp) < self.num_pp_ranks_expected\n        ):\n            return False\n        # All PP ranks must have reported their expected count\n        if len(self.expected_kvs_per_pp) < self.num_pp_ranks_expected:\n            return False\n        # Each PP rank must have received all expected chunks\n        for pp_rank, expected in self.expected_kvs_per_pp.items():\n            if len(self.received_kvs_per_pp[pp_rank]) != expected:\n                return False\n        return True\n\n    def is_failed(self):\n        return self.is_failure\n\n\nclass NixlKVManager(CommonKVManager):\n    def __init__(\n        self,\n        args: KVArgs,\n        disaggregation_mode: DisaggregationMode,\n        server_args: ServerArgs,\n        is_mla_backend: Optional[bool] = False,\n    ):\n        super().__init__(args, disaggregation_mode, server_args, is_mla_backend)\n        try:\n            from nixl._api import nixl_agent, nixl_agent_config\n        except ImportError as e:\n            raise ImportError(\n                \"Please install 
NIXL by following the instructions at \"\n                \"https://github.com/ai-dynamo/nixl/blob/main/README.md \"\n                \"to run SGLang with NixlTransferEngine.\"\n            ) from e\n\n        backend = envs.SGLANG_DISAGGREGATION_NIXL_BACKEND.get()\n        agent_config = nixl_agent_config(\n            backends=[backend],\n            num_threads=(8 if disaggregation_mode == DisaggregationMode.PREFILL else 0),\n        )\n        self.agent = nixl_agent(str(uuid.uuid4()), agent_config)\n\n        available_plugins = self.agent.get_plugin_list()\n        if backend not in available_plugins:\n            raise ValueError(\n                f\"NIXL backend '{backend}' not found. Available: {available_plugins}. \"\n                f\"Please install the required NIXL plugin or choose from: {available_plugins}\"\n            )\n        logger.info(f\"NIXL KVManager initialized with backend: {backend}\")\n\n        self.register_buffer_to_engine()\n\n        if self.disaggregation_mode == DisaggregationMode.PREFILL:\n            self._start_bootstrap_thread()\n        elif self.disaggregation_mode == DisaggregationMode.DECODE:\n            self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(\n                TransferStatus\n            )\n            self._start_heartbeat_checker_thread()\n        else:\n            raise ValueError(\n                f\"Unsupported DisaggregationMode: {self.disaggregation_mode}\"\n            )\n\n    def _start_heartbeat_checker_thread(self):\n        \"\"\"\n        Start the heartbeat checker thread for Decode worker.\n        TODO (smor): unite nixl heartbeat checker with mooncake's.\n        \"\"\"\n\n        def heartbeat_checker():\n            while True:\n                time.sleep(self.heartbeat_interval)\n                with self.connection_lock:\n                    addresses = list(self.prefill_info_table.keys())\n\n                for bootstrap_addr in addresses:\n                    session = 
None\n                    try:\n                        with self.session_pool_lock:\n                            session = self.session_pool[bootstrap_addr]\n                        response = session.get(\n                            f\"http://{bootstrap_addr}/health\",\n                            timeout=(2, 3),\n                            headers={\"Connection\": \"keep-alive\"},\n                        )\n                        if response.status_code == 200:\n                            self.heartbeat_failures[bootstrap_addr] = 0\n\n                        else:\n                            logger.info(\n                                f\"Attempting to reconnect to {bootstrap_addr}...\"\n                            )\n                            self.heartbeat_failures[bootstrap_addr] = (\n                                self.heartbeat_failures.get(bootstrap_addr, 0) + 1\n                            )\n                            with self.session_pool_lock:\n                                if bootstrap_addr in self.session_pool:\n                                    del self.session_pool[bootstrap_addr]\n                    except Exception:\n                        logger.info(f\"Attempting to reconnect to {bootstrap_addr}...\")\n                        self.heartbeat_failures[bootstrap_addr] = (\n                            self.heartbeat_failures.get(bootstrap_addr, 0) + 1\n                        )\n\n                    if (\n                        self.heartbeat_failures.get(bootstrap_addr, 0)\n                        >= self.max_failures\n                    ):\n                        self._handle_node_failure(bootstrap_addr)\n                        with self.session_pool_lock:\n                            if bootstrap_addr in self.session_pool:\n                                del self.session_pool[bootstrap_addr]\n\n        threading.Thread(target=heartbeat_checker, daemon=True).start()\n\n    def _handle_node_failure(self, 
failed_bootstrap_addr):\n        \"\"\"Handle failure of a prefill node.\"\"\"\n        with self.connection_lock:\n            keys_to_remove = [\n                k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)\n            ]\n            for k in keys_to_remove:\n                del self.connection_pool[k]\n            self.prefill_info_table.pop(failed_bootstrap_addr, None)\n\n            possible_affected_rooms = self.addr_to_rooms_tracker.get(\n                failed_bootstrap_addr, []\n            )\n            self.addr_to_rooms_tracker.pop(failed_bootstrap_addr, None)\n\n        # Mark all pending transfers associated with the failed node as failed\n        affected_rooms = []\n        for room in possible_affected_rooms:\n            if (\n                room in self.transfer_statuses\n                and not self.transfer_statuses[room].is_done()\n            ):\n                # Mark the transfer as failed\n                self.transfer_statuses[room].is_failure = True\n                affected_rooms.append(room)\n\n        logger.error(\n            f\"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), \"\n            f\"{len(affected_rooms)} transfers affected\"\n        )\n        for room in possible_affected_rooms:\n            logger.error(f\"Let room {room} be failed due to prefill down\")\n            self.update_status(room, KVPoll.Failed)\n\n    def register_buffer_to_engine(self):\n        kv_addrs = []\n        for kv_data_ptr, kv_data_len in zip(\n            self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens\n        ):\n            kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, \"\"))\n        self.kv_descs = self.agent.register_memory(kv_addrs, \"VRAM\")\n        logger.debug(f\"Register kv tensors, len(kv_addrs)= {len(kv_addrs)}\")\n        if not self.kv_descs:\n            raise Exception(\"NIXL memory registration failed for kv tensors\")\n        aux_addrs = 
[]\n        for aux_data_ptr, aux_data_len in zip(\n            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens\n        ):\n            aux_addrs.append((aux_data_ptr, aux_data_len, 0, \"\"))\n        self.aux_descs = self.agent.register_memory(aux_addrs, \"DRAM\")\n        logger.debug(f\"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}\")\n        if not self.aux_descs:\n            raise Exception(\"NIXL memory registration failed for aux tensors\")\n\n        # Register state/extra pool data buffers if present\n        if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:\n            state_addrs = []\n            for state_data_ptr, state_data_len in zip(\n                self.kv_args.state_data_ptrs, self.kv_args.state_data_lens\n            ):\n                state_addrs.append(\n                    (state_data_ptr, state_data_len, self.kv_args.gpu_id, \"\")\n                )\n            self.state_descs = self.agent.register_memory(state_addrs, \"VRAM\")\n            logger.debug(\n                f\"Register state tensors, len(state_addrs)= {len(state_addrs)}\"\n            )\n            if not self.state_descs:\n                raise Exception(\"NIXL memory registration failed for state tensors\")\n\n    def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):\n        agent_name = decode_kv_args.agent_name\n        if agent_name in self.decode_kv_args_table:\n            logger.info(f\"Peer {agent_name} was already registered, ignoring.\")\n            return\n        self.decode_kv_args_table[agent_name] = decode_kv_args\n        self.agent.add_remote_agent(decode_kv_args.agent_metadata)\n\n    def _send_kvcache_generic(\n        self,\n        peer_name: str,\n        src_data_ptrs: list[int],\n        dst_data_ptrs: list[int],\n        item_lens: list[int],\n        prefill_data_indices: npt.NDArray[np.int32],\n        dst_data_indices: npt.NDArray[np.int32],\n        dst_gpu_id: int,\n        notif: str,\n    ):\n 
       \"\"\"Generic KV cache transfer supporting both MHA and MLA architectures.\n        Used by both send_kvcache and maybe_send_extra.\"\"\"\n        # group by indices\n        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(\n            prefill_data_indices, dst_data_indices\n        )\n\n        logger.debug(f\"sending kvcache to {peer_name} with notif {notif}\")\n        # Make descs\n        if self.is_mla_backend:\n            src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (\n                self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)\n            )\n            layers_params = [\n                (\n                    src_kv_ptrs[layer_id],\n                    dst_kv_ptrs[layer_id],\n                    item_lens[layer_id],\n                )\n                for layer_id in range(layers_current_pp_stage)\n            ]\n        else:\n            src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (\n                self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)\n            )\n\n            layers_params = [\n                (\n                    src_k_ptrs[layer_id],\n                    dst_k_ptrs[layer_id],\n                    item_lens[layer_id],\n                )\n                for layer_id in range(layers_current_pp_stage)\n            ] + [\n                (\n                    src_v_ptrs[layer_id],\n                    dst_v_ptrs[layer_id],\n                    item_lens[layer_id],\n                )\n                for layer_id in range(layers_current_pp_stage)\n            ]\n\n        src_addrs = []\n        src_lens = []\n        dst_addrs = []\n        dst_lens = []\n\n        # Precompute block starts/lengths to reduce Python-level loops.\n        prefill_starts = np.fromiter(\n            (block[0] for block in prefill_kv_blocks), dtype=np.int64\n        )\n        dst_starts = np.fromiter((block[0] for block in dst_kv_blocks), dtype=np.int64)\n        
block_lens = np.fromiter(\n            (len(block) for block in prefill_kv_blocks), dtype=np.int64\n        )\n\n        for src_ptr, dst_ptr, item_len in layers_params:\n            lengths = item_len * block_lens\n            src_addrs.append(src_ptr + prefill_starts * item_len)\n            src_lens.append(lengths)\n            dst_addrs.append(dst_ptr + dst_starts * item_len)\n            dst_lens.append(lengths)\n\n        def make_req_array(addr_chunks, len_chunks, gpu):\n            if not addr_chunks:\n                return np.empty((0, 3), dtype=np.int64)\n            flat_addrs = np.concatenate(addr_chunks)\n            flat_lens = np.concatenate(len_chunks)\n            return np.column_stack(\n                (\n                    flat_addrs,\n                    flat_lens,\n                    np.full_like(flat_addrs, gpu),\n                )\n            )\n\n        src_reqs = make_req_array(src_addrs, src_lens, self.kv_args.gpu_id)\n        dst_reqs = make_req_array(dst_addrs, dst_lens, dst_gpu_id)\n\n        logger.debug(\n            f\"len(src_addrs): before group: {len(prefill_data_indices)}, after group: {len(src_addrs)}\"\n        )\n        src_descs = self.agent.get_xfer_descs(src_reqs, \"VRAM\")\n        dst_descs = self.agent.get_xfer_descs(dst_reqs, \"VRAM\")\n        # Transfer data\n        xfer_handle = self.agent.initialize_xfer(\n            \"WRITE\",\n            src_descs,\n            dst_descs,\n            peer_name,\n            notif.encode(\"ascii\"),  # type: ignore\n        )\n        if not xfer_handle:\n            raise Exception(\"KVSender failed to create transfer\")\n        state = self.agent.transfer(xfer_handle)\n        if state == \"ERR\":\n            raise Exception(\"KVSender failed to post transfer\")\n        return xfer_handle\n\n    def send_kvcache(\n        self,\n        peer_name: str,\n        prefill_kv_indices: npt.NDArray[np.int32],\n        dst_kv_ptrs: list[int],\n        dst_kv_indices: 
npt.NDArray[np.int32],\n        dst_gpu_id: int,\n        notif: str,\n    ):\n        return self._send_kvcache_generic(\n            peer_name=peer_name,\n            src_data_ptrs=self.kv_args.kv_data_ptrs,\n            dst_data_ptrs=dst_kv_ptrs,\n            item_lens=self.kv_args.kv_item_lens,\n            prefill_data_indices=prefill_kv_indices,\n            dst_data_indices=dst_kv_indices,\n            dst_gpu_id=dst_gpu_id,\n            notif=notif,\n        )\n\n    def send_kvcache_slice(\n        self,\n        peer_name: str,\n        prefill_kv_indices: npt.NDArray[np.int32],\n        dst_kv_ptrs: list[int],\n        dst_kv_indices: npt.NDArray[np.int32],\n        dst_gpu_id: int,\n        notif: str,\n        prefill_tp_size: int,\n        decode_tp_size: int,\n        decode_tp_rank: int,\n        dst_kv_item_len: int,\n    ):\n        # Get configuration from kv_args\n        local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size\n        dst_tp_rank_in_group = decode_tp_rank % decode_tp_size\n        num_kv_heads = self.kv_args.kv_head_num\n\n        # Calculate head distribution\n        src_heads_per_rank = num_kv_heads\n        dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size\n\n        src_kv_item_len = self.kv_args.kv_item_lens[0]\n        page_size = self.kv_args.page_size\n\n        bytes_per_head_slice_to_send = (\n            dst_kv_item_len // page_size // dst_heads_per_rank\n        )\n\n        # Determine which heads to send\n        if prefill_tp_size > decode_tp_size:\n            # Multiple prefill ranks to one decode rank\n            src_head_start_offset = 0\n            num_heads_to_send = src_heads_per_rank\n            dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank\n        else:\n            # Send KVCache from 1 prefill instance to multiple decode instances\n            src_head_start_offset = (\n                dst_tp_rank_in_group * dst_heads_per_rank\n            ) % 
src_heads_per_rank\n            num_heads_to_send = dst_heads_per_rank\n            dst_head_start_offset = 0\n\n        src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (\n            self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)\n        )\n        # Calculate precise byte offset and length for the sub-slice within the token\n        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send\n        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send\n        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send\n\n        src_dst_ptr_pairs = [\n            (\n                src_k_ptrs[layer_id],\n                dst_k_ptrs[layer_id],\n            )\n            for layer_id in range(layers_current_pp_stage)\n        ] + [\n            (\n                src_v_ptrs[layer_id],\n                dst_v_ptrs[layer_id],\n            )\n            for layer_id in range(layers_current_pp_stage)\n        ]\n\n        prefill_indices = np.asarray(prefill_kv_indices, dtype=np.int64)\n        dst_indices = np.asarray(dst_kv_indices, dtype=np.int64)\n        bytes_per_token_prefill = src_kv_item_len // page_size\n        bytes_per_token_decode = dst_kv_item_len // page_size\n        token_offsets = np.arange(page_size, dtype=np.int64)\n\n        src_addrs = []\n        dst_addrs = []\n\n        for src_ptr, dst_ptr in src_dst_ptr_pairs:\n            src_page_bases = src_ptr + prefill_indices * src_kv_item_len\n            dst_page_bases = dst_ptr + dst_indices * dst_kv_item_len\n\n            src_all = (\n                src_page_bases[:, None]\n                + token_offsets[None, :] * bytes_per_token_prefill\n                + src_head_slice_offset\n            ).ravel()\n            dst_all = (\n                dst_page_bases[:, None]\n                + token_offsets[None, :] * bytes_per_token_decode\n                + dst_head_slice_offset\n            
).ravel()\n\n            src_addrs.append(src_all)\n            dst_addrs.append(dst_all)\n\n        def make_req_array(addr_chunks, size, gpu):\n            if not addr_chunks:\n                return np.empty((0, 3), dtype=np.int64)\n            flat_addrs = np.concatenate(addr_chunks)\n            return np.column_stack(\n                (\n                    flat_addrs,\n                    np.full_like(flat_addrs, size),\n                    np.full_like(flat_addrs, gpu),\n                )\n            )\n\n        src_reqs = make_req_array(\n            src_addrs, heads_bytes_per_token_to_send, self.kv_args.gpu_id\n        )\n        dst_reqs = make_req_array(dst_addrs, heads_bytes_per_token_to_send, dst_gpu_id)\n\n        # Use NIXL agent for transfer\n        src_descs = self.agent.get_xfer_descs(src_reqs, \"VRAM\")\n        dst_descs = self.agent.get_xfer_descs(dst_reqs, \"VRAM\")\n\n        xfer_handle = self.agent.initialize_xfer(\n            \"WRITE\", src_descs, dst_descs, peer_name, notif.encode(\"ascii\")\n        )\n        if not xfer_handle:\n            raise Exception(\"Failed to create sliced KV transfer\")\n\n        state = self.agent.transfer(xfer_handle)\n        if state == \"ERR\":\n            raise Exception(\"Failed to post sliced KV transfer\")\n\n        return xfer_handle\n\n    def send_aux(\n        self,\n        peer_name: str,\n        prefill_aux_index: int,\n        dst_aux_ptrs: list[int],\n        dst_aux_index: int,\n        notif: str,\n    ):\n        src_addrs = []\n        dst_addrs = []\n\n        prefill_aux_ptrs = self.kv_args.aux_data_ptrs\n        prefill_aux_item_lens = self.kv_args.aux_item_lens\n\n        for i, _ in enumerate(dst_aux_ptrs):\n            length = prefill_aux_item_lens[i]\n            src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index\n            dst_addr = dst_aux_ptrs[i] + length * dst_aux_index\n            src_addrs.append((src_addr, length, 0))\n            
dst_addrs.append((dst_addr, length, 0))\n\n        src_descs = self.agent.get_xfer_descs(src_addrs, \"DRAM\")\n        dst_descs = self.agent.get_xfer_descs(dst_addrs, \"DRAM\")\n        # Transfer data\n        xfer_handle = self.agent.initialize_xfer(\n            \"WRITE\",\n            src_descs,\n            dst_descs,\n            peer_name,\n            notif.encode(\"ascii\"),  # type: ignore\n        )\n        if not xfer_handle:\n            raise Exception(\"KVSender failed to create transfer\")\n        state = self.agent.transfer(xfer_handle)\n        if state == \"ERR\":\n            raise Exception(\"KVSender failed to post transfer\")\n        return xfer_handle\n\n    def _send_mamba_state(\n        self,\n        peer_name: str,\n        prefill_state_indices: List[int],\n        dst_state_data_ptrs: list[int],\n        dst_state_indices: List[int],\n        dst_gpu_id: int,\n        notif: str,\n    ):\n        \"\"\"Transfer Mamba states via RDMA.\"\"\"\n        assert len(prefill_state_indices) == 1, \"Mamba should have single state index\"\n        assert len(dst_state_indices) == len(\n            prefill_state_indices\n        ), \"State indices count mismatch between Prefill and Decode\"\n\n        src_addrs = []\n        dst_addrs = []\n\n        prefill_state_data_ptrs = self.kv_args.state_data_ptrs\n        prefill_state_item_lens = self.kv_args.state_item_lens\n\n        for i, dst_state_ptr in enumerate(dst_state_data_ptrs):\n            length = prefill_state_item_lens[i]\n            src_addr = prefill_state_data_ptrs[i] + length * int(\n                prefill_state_indices[0]\n            )\n            dst_addr = dst_state_ptr + length * int(dst_state_indices[0])\n            src_addrs.append((src_addr, length, self.kv_args.gpu_id))\n            dst_addrs.append((dst_addr, length, dst_gpu_id))\n\n        src_descs = self.agent.get_xfer_descs(src_addrs, \"VRAM\")\n        dst_descs = self.agent.get_xfer_descs(dst_addrs, 
\"VRAM\")\n\n        xfer_handle = self.agent.initialize_xfer(\n            \"WRITE\",\n            src_descs,\n            dst_descs,\n            peer_name,\n            notif.encode(\"ascii\"),\n        )\n        if not xfer_handle:\n            raise Exception(\"Failed to create Mamba state transfer\")\n        state = self.agent.transfer(xfer_handle)\n        if state == \"ERR\":\n            raise Exception(\"Failed to post Mamba state transfer\")\n        return xfer_handle\n\n    def maybe_send_extra(\n        self,\n        peer_name: str,\n        prefill_state_indices: List[int],\n        dst_state_data_ptrs: list[int],\n        dst_state_indices: List[int],\n        dst_gpu_id: int,\n        notif: str,\n        decode_tp_size: int,\n    ):\n        \"\"\"Send state or extra pool data with type-specific handling.\"\"\"\n        state_type = getattr(self.kv_args, \"state_type\", \"none\")\n\n        if state_type == \"mamba\":\n            if self.attn_tp_size != decode_tp_size:\n                raise RuntimeError(\n                    \"PD Disaggregation does NOT support PD different TP sizes for hybrid mamba models yet.\"\n                )\n            return self._send_mamba_state(\n                peer_name,\n                prefill_state_indices,\n                dst_state_data_ptrs,\n                dst_state_indices,\n                dst_gpu_id,\n                notif,\n            )\n        elif state_type in [\"swa\", \"nsa\"]:\n            if not self.is_mla_backend and self.attn_tp_size != decode_tp_size:\n                raise RuntimeError(\n                    f\"PD Disaggregation does NOT support PD different TP sizes for non-MLA {state_type.upper()} hybrid models yet.\"\n                )\n            if len(prefill_state_indices) != len(dst_state_indices):\n                raise RuntimeError(\n                    f\"State index length mismatch: prefill={len(prefill_state_indices)}, \"\n                    
f\"dst={len(dst_state_indices)}\"\n                )\n            return self._send_kvcache_generic(\n                peer_name=peer_name,\n                src_data_ptrs=self.kv_args.state_data_ptrs,\n                dst_data_ptrs=dst_state_data_ptrs,\n                item_lens=self.kv_args.state_item_lens,\n                prefill_data_indices=np.array(prefill_state_indices, dtype=np.int32),\n                dst_data_indices=np.array(dst_state_indices, dtype=np.int32),\n                dst_gpu_id=dst_gpu_id,\n                notif=notif,\n            )\n        else:\n            if state_type != \"none\":\n                raise RuntimeError(\n                    f\"PD Disaggregation via NIXL does NOT support {state_type} hybrid models yet.\"\n                )\n            return None\n\n    def add_transfer_request(\n        self,\n        bootstrap_room: int,\n        kv_indices: npt.NDArray[np.int32],\n        index_slice: slice,\n        is_last: bool,\n        chunk_id: int,\n        aux_index: Optional[int] = None,\n        state_indices: Optional[List[int]] = None,\n    ):\n        assert self.disaggregation_mode == DisaggregationMode.PREFILL\n        assert not is_last or (is_last and aux_index is not None)\n\n        reqs_to_be_processed = self.transfer_infos[bootstrap_room].values()\n        handles = []\n        for req in reqs_to_be_processed:\n            assert bootstrap_room == req.room\n            if req.is_dummy():\n                continue\n\n            chunked_dst_kv_indice = req.dst_kv_indices[index_slice]\n            assert len(chunked_dst_kv_indice) == len(kv_indices)\n            assert req.agent_name in self.decode_kv_args_table\n\n            notif = f\"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.pp_rank}\"\n            decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size\n\n            if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):\n                kv_xfer_handle = self.send_kvcache(\n  
                  req.agent_name,\n                    kv_indices,\n                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,\n                    chunked_dst_kv_indice,\n                    self.decode_kv_args_table[req.agent_name].gpu_id,\n                    notif,\n                )\n            else:\n                kv_xfer_handle = self.send_kvcache_slice(\n                    req.agent_name,\n                    kv_indices,\n                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,\n                    chunked_dst_kv_indice,\n                    self.decode_kv_args_table[req.agent_name].gpu_id,\n                    notif,\n                    prefill_tp_size=self.attn_tp_size,\n                    decode_tp_size=decode_tp_size,\n                    decode_tp_rank=self.decode_kv_args_table[\n                        req.agent_name\n                    ].decode_tp_rank,\n                    dst_kv_item_len=self.decode_kv_args_table[\n                        req.agent_name\n                    ].dst_kv_item_len,\n                )\n\n            handles.append(kv_xfer_handle)\n            # Only the last chunk needs to send the state and aux data.\n            if is_last:\n                if state_indices is not None:\n                    dst_info = self.decode_kv_args_table[req.agent_name]\n                    state_xfer_handle = self.maybe_send_extra(\n                        req.agent_name,\n                        state_indices,\n                        dst_info.dst_state_data_ptrs,\n                        req.dst_state_indices,\n                        dst_info.gpu_id,\n                        f\"{req.room}_state_{self.kv_args.pp_rank}\",\n                        decode_tp_size,\n                    )\n                    if state_xfer_handle is not None:\n                        handles.append(state_xfer_handle)\n\n                assert aux_index is not None\n                aux_xfer_handle = self.send_aux(\n                    
req.agent_name,\n                    aux_index,\n                    self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,\n                    req.dst_aux_index,\n                    f\"{req.room}_aux\",\n                )\n                handles.append(aux_xfer_handle)\n        if is_last:\n            del self.transfer_infos[bootstrap_room]\n        return handles\n\n    def update_transfer_status(self):\n        # Process notifications from received transfers.\n        notif_map = self.agent.get_new_notifs()\n        for peer_name, messages in notif_map.items():\n            # We could also check that self.bootstrap_info['agent_name'] matches\n            # the message sender. But the bootstrap room alone should be\n            # sufficient to map the status.\n            for msg in messages:\n                components = msg.decode(\"ascii\").split(\"_\", 4)\n                room = int(components[0])\n                if components[1] == \"kv\":\n                    chunk_id = int(components[2])\n                    is_last = bool(int(components[3]))\n                    pp_rank = int(components[4]) if len(components) > 4 else 0\n                    # Track received chunks per pp_rank\n                    self.transfer_statuses[room].received_kvs_per_pp[pp_rank].add(\n                        chunk_id\n                    )\n                    if is_last:\n                        # Record expected chunk count for this pp_rank\n                        self.transfer_statuses[room].expected_kvs_per_pp[pp_rank] = (\n                            chunk_id + 1\n                        )\n                        # Set num_pp_ranks_expected from table (or default to 1)\n                        if self.transfer_statuses[room].num_pp_ranks_expected is None:\n                            self.transfer_statuses[room].num_pp_ranks_expected = (\n                                self.required_prefill_response_num_table.get(room, 1)\n                            )\n              
  elif components[1] == \"aux\":\n                    self.transfer_statuses[room].received_aux = True\n                elif components[1] == \"state\":\n                    pp_rank = int(components[2]) if len(components) > 2 else 0\n                    self.transfer_statuses[room].received_state_per_pp.add(pp_rank)\n\n    def check_transfer_done(self, room: int):\n        if room not in self.transfer_statuses:\n            return False\n        return self.transfer_statuses[room].is_done()\n\n    def _start_bootstrap_thread(self):\n        def bootstrap_thread():\n            \"\"\"This thread receives transfer info from the decode engine\"\"\"\n            while True:\n                waiting_req_bytes = self.server_socket.recv_multipart()\n                logger.debug(\n                    f\"Received multipart with total byte size {sum(len(x) for x in waiting_req_bytes)}\"\n                )\n                assert (\n                    waiting_req_bytes[0] == GUARD\n                ), f\"First message should be {GUARD}. 
Foreign traffic?\"\n                waiting_req_bytes = waiting_req_bytes[1:]\n                room = waiting_req_bytes[0].decode(\"ascii\")\n                agent_name = waiting_req_bytes[3].decode(\"ascii\")\n                if room == \"None\":\n                    # Register new peer and save KV base pointers.\n                    self._add_remote_peer(\n                        KVArgsRegisterInfo.from_zmq(waiting_req_bytes)\n                    )\n                    logger.debug(f\"Register KVArgs from {agent_name} successfully\")\n                    continue\n                room = int(room)\n                if room not in self.transfer_infos:\n                    self.transfer_infos[room] = {}\n                self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(\n                    waiting_req_bytes\n                )\n                required_dst_info_num = self.transfer_infos[room][\n                    agent_name\n                ].required_dst_info_num\n                logger.debug(f\"got info {room=} {agent_name=} {required_dst_info_num=}\")\n                if len(self.transfer_infos[room]) == required_dst_info_num:\n                    logger.debug(f\"{room=} is bootstrapped\")\n                    self.update_status(room, KVPoll.WaitingForInput)\n\n        threading.Thread(target=bootstrap_thread).start()\n\n\nclass NixlKVSender(CommonKVSender):\n    def __init__(\n        self,\n        mgr: NixlKVManager,\n        bootstrap_addr: str,\n        bootstrap_room: int,\n        dest_tp_ranks: List[int],\n        pp_rank: int,\n    ):\n        super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)\n        self.xfer_handles = []\n        self.has_sent = False\n        self.chunk_id = 0\n\n    def send(\n        self,\n        kv_indices: npt.NDArray[np.int32],\n        state_indices: Optional[List[int]] = None,\n    ):\n        index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))\n        self.curr_idx += 
len(kv_indices)\n        is_last = self.curr_idx == self.num_kv_indices\n\n        # Special handling for cp\n        if self.kv_mgr.enable_all_cp_ranks_for_transfer:\n            kv_indices, index_slice = filter_kv_indices_for_cp_rank(\n                self.kv_mgr,\n                kv_indices,\n                index_slice,\n            )\n        elif self.kv_mgr.is_dummy_cp_rank:\n            if not is_last:\n                return\n            else:\n                self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Success)\n                return\n\n        new_xfer_handles = self.kv_mgr.add_transfer_request(\n            self.bootstrap_room,\n            kv_indices,\n            index_slice,\n            is_last,\n            self.chunk_id,\n            self.aux_index,\n            state_indices,\n        )\n        self.xfer_handles.extend(new_xfer_handles)\n        self.chunk_id += 1\n        if is_last:\n            self.has_sent = True\n            del self.kv_mgr.request_status[self.bootstrap_room]\n\n    def poll(self) -> KVPoll:\n        if not self.has_sent:\n            return self.kv_mgr.check_status(self.bootstrap_room)\n        states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]\n        if all([x == \"DONE\" for x in states]):\n            return KVPoll.Success  # type: ignore\n        if any([x == \"ERR\" for x in states]):\n            raise Exception(\"KVSender transfer encountered an error.\")\n        return KVPoll.WaitingForInput  # type: ignore\n\n    def failure_exception(self):\n        raise RuntimeError(\"NIXL KVSender Exception\")\n\n\nclass NixlKVReceiver(CommonKVReceiver):\n    def __init__(\n        self,\n        mgr: NixlKVManager,\n        bootstrap_addr: str,\n        bootstrap_room: Optional[int] = None,\n        prefill_dp_rank: Optional[int] = None,\n    ):\n        self.started_transfer = False\n        self.conclude_state = None\n        super().__init__(mgr, bootstrap_addr, bootstrap_room, 
prefill_dp_rank)\n\n        # Track this room with its bootstrap address for heartbeat monitoring\n        if hasattr(self.kv_mgr, \"addr_to_rooms_tracker\"):\n            self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(\n                self.bootstrap_room\n            )\n        self.init_time = None\n\n    def init(\n        self,\n        kv_indices: npt.NDArray[np.int32],\n        aux_index: Optional[int] = None,\n        state_indices: Optional[List[int]] = None,\n    ):\n        if self.bootstrap_infos is None:\n            logger.error(\n                f\"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}\",\n            )\n            self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)\n            return\n\n        for bootstrap_info in self.bootstrap_infos:\n            logger.debug(\n                f\"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}\"\n            )\n            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)\n            is_dummy = bootstrap_info[\"is_dummy\"]\n            logger.debug(\n                f\"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}\"\n            )\n            with lock:\n                sock.send_multipart(\n                    [\n                        GUARD,\n                        str(self.bootstrap_room).encode(\"ascii\"),\n                        self.kv_mgr.local_ip.encode(\"ascii\"),\n                        str(self.kv_mgr.rank_port).encode(\"ascii\"),\n                        self.kv_mgr.agent.name.encode(\"ascii\"),\n                        kv_indices.tobytes() if not is_dummy else b\"\",\n                        str(aux_index).encode(\"ascii\"),\n                        str(self.required_dst_info_num).encode(\"ascii\"),\n                        (\n                            np.array(state_indices, dtype=np.int32).tobytes()\n                            if not 
is_dummy and state_indices is not None\n                            else b\"\"\n                        ),\n                    ]\n                )\n\n        # Mark that we expect state data if state_indices was provided\n        if state_indices is not None:\n            self.kv_mgr.transfer_statuses[self.bootstrap_room].expects_state = True\n\n        self.started_transfer = True\n        self.init_time = time.time()\n\n    def poll(self) -> KVPoll:\n        if self.conclude_state is not None:\n            return self.conclude_state\n        status = self.kv_mgr.check_status(self.bootstrap_room)\n        if status in (KVPoll.Success, KVPoll.Failed):\n            self.conclude_state = status\n            return status\n        if not self.started_transfer:\n            return KVPoll.WaitingForInput  # type: ignore\n\n        now = time.time()\n        elapsed = now - self.init_time\n\n        if elapsed >= self.kv_mgr.waiting_timeout:\n            logger.error(f\"Request {self.bootstrap_room} waiting_timeout\")\n            self.kv_mgr.record_failure(\n                self.bootstrap_room,\n                f\"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.WaitingForInput\",\n            )\n            self.conclude_state = KVPoll.Failed\n            return KVPoll.Failed\n\n        self.kv_mgr.update_transfer_status()\n        if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore\n            self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].discard(\n                self.bootstrap_room\n            )\n            # Check if the transfer failed\n            if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():\n                self.conclude_state = KVPoll.Failed\n                logger.error(\n                    f\"Transfer for room {self.bootstrap_room} failed due to node failure\"\n                )\n            else:\n                self.conclude_state = KVPoll.Success\n            del 
self.kv_mgr.transfer_statuses[self.bootstrap_room]\n            return self.conclude_state  # type: ignore\n        return KVPoll.WaitingForInput  # type: ignore\n\n    def _register_kv_args(self):\n        for bootstrap_info in self.bootstrap_infos:\n            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)\n            packed_kv_data_ptrs = b\"\".join(\n                struct.pack(\"Q\", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs\n            )\n            packed_aux_data_ptrs = b\"\".join(\n                struct.pack(\"Q\", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs\n            )\n            packed_state_data_ptrs = b\"\".join(\n                struct.pack(\"Q\", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs\n            )\n\n            with lock:\n                sock.send_multipart(\n                    [\n                        GUARD,\n                        \"None\".encode(\"ascii\"),\n                        self.kv_mgr.local_ip.encode(\"ascii\"),\n                        str(self.kv_mgr.rank_port).encode(\"ascii\"),\n                        self.kv_mgr.agent.name.encode(\"ascii\"),\n                        self.kv_mgr.agent.get_agent_metadata(),\n                        packed_kv_data_ptrs,\n                        packed_aux_data_ptrs,\n                        packed_state_data_ptrs,\n                        str(self.kv_mgr.kv_args.gpu_id).encode(\"ascii\"),\n                        str(self.kv_mgr.attn_tp_size).encode(\"ascii\"),\n                        str(self.kv_mgr.kv_args.engine_rank).encode(\"ascii\"),\n                        str(self.kv_mgr.kv_args.kv_item_lens[0]).encode(\"ascii\"),\n                    ]\n                )\n\n    def failure_exception(self):\n        raise RuntimeError(\"NIXL KVReceiver Exception\")\n\n\nclass NixlKVBootstrapServer(CommonKVBootstrapServer):\n    pass\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/prefill.py",
    "content": "\"\"\"\nLife cycle of a request in the prefill server\n\n1. Bootstrap Queue\n    a. Initialize a sender for each request\n    b. Use the queue to store requests whose bootstrap (handshake and preallocation) has not finished\n    c. Poll senders to check bootstrap state\n    d. Once bootstrap is complete, move request to Waiting Queue\n\n2. Waiting Queue\n    a. Use PrefillAdder to pop requests\n    b. Run forward\n    c. Add the request to Inflight Queue\n\n3. Inflight Queue\n    a. Poll (non-blocking) the sender of the request\n    b. Once the transfer has finished, return the request\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nfrom collections import deque\nfrom http import HTTPStatus\nfrom typing import TYPE_CHECKING, List, Optional\n\nimport torch\n\nfrom sglang.srt.disaggregation.base import KVPoll\nfrom sglang.srt.disaggregation.common.conn import CommonKVManager\nfrom sglang.srt.disaggregation.utils import (\n    FAKE_BOOTSTRAP_HOST,\n    DisaggregationMode,\n    KVClassType,\n    MetadataBuffers,\n    ReqToMetadataIdxAllocator,\n    TransferBackend,\n    get_kv_class,\n    is_mla_backend,\n    kv_to_page_indices,\n    kv_to_page_num,\n    poll_and_all_reduce_attn_cp_tp_group,\n    prepare_abort,\n)\nfrom sglang.srt.managers.schedule_batch import (\n    FINISH_ABORT,\n    FINISH_LENGTH,\n    Req,\n    ScheduleBatch,\n)\nfrom sglang.srt.mem_cache.common import release_kv_cache\nfrom sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, NSATokenToKVPool\nfrom sglang.srt.mem_cache.swa_memory_pool import SWAKVPool\nfrom sglang.srt.observability.req_time_stats import set_schedule_time_batch\n\nif TYPE_CHECKING:\n    from torch.distributed import ProcessGroup\n\n    from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler\n    from sglang.srt.mem_cache.memory_pool import KVCache\n\nlogger = logging.getLogger(__name__)\n\n\ndef release_req_to_metadata_buffer(\n    req: Req, allocator: 
ReqToMetadataIdxAllocator\n) -> None:\n    \"\"\"\n    Release the metadata buffer index allocated for a request in prefill disaggregation mode.\n\n    This function safely releases the metadata buffer index if it was allocated.\n\n    Args:\n        req: The request object that may have a metadata_buffer_index allocated\n        allocator: The ReqToMetadataIdxAllocator instance to free the index\n    \"\"\"\n    if (\n        hasattr(req, \"metadata_buffer_index\")\n        and req.metadata_buffer_index is not None\n        and req.metadata_buffer_index >= 0\n    ):\n        allocator.free(req.metadata_buffer_index)\n        req.metadata_buffer_index = -1\n\n\nclass PrefillBootstrapQueue:\n    \"\"\"\n    Store the requests in bootstrapping\n    \"\"\"\n\n    def __init__(\n        self,\n        token_to_kv_pool: KVCache,\n        draft_token_to_kv_pool: Optional[KVCache],\n        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,\n        metadata_buffers: MetadataBuffers,\n        tp_rank: int,\n        tp_size: int,\n        gpu_id: int,\n        bootstrap_port: int,\n        gloo_group: ProcessGroup,\n        max_total_num_tokens: int,\n        scheduler: Scheduler,\n        pp_rank: int,\n        pp_size: int,\n        transfer_backend: TransferBackend,\n    ):\n        self.token_to_kv_pool = token_to_kv_pool\n        self.draft_token_to_kv_pool = draft_token_to_kv_pool\n        self.is_mla_backend = is_mla_backend(token_to_kv_pool)\n        self.metadata_buffers = metadata_buffers\n        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator\n        self.tp_rank = tp_rank\n        self.tp_size = tp_size\n        self.pp_rank = pp_rank\n        self.pp_size = pp_size\n        self.gpu_id = gpu_id\n        self.bootstrap_port = bootstrap_port\n        self.queue: List[Req] = []\n        self.gloo_group = gloo_group\n        self.max_total_num_tokens = max_total_num_tokens\n        self.scheduler = scheduler\n       
 self.transfer_backend = transfer_backend\n        self.kv_manager = self._init_kv_manager()\n\n        if self.scheduler.tp_worker.is_hybrid_swa:\n            # FIXME: current SWA allocation allocates the full KV cache size in prefill\n            self.max_total_num_tokens = min(\n                self.max_total_num_tokens,\n                self.scheduler.tp_worker.model_runner.swa_max_total_num_tokens,\n            )\n\n    def _init_kv_manager(self) -> CommonKVManager:\n        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)\n        kv_args = kv_args_class()\n        kv_args.engine_rank = self.tp_rank\n        kv_args.pp_rank = self.pp_rank\n        kv_args.system_dp_rank = self.scheduler.dp_rank\n        kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer\n        kv_data_ptrs, kv_data_lens, kv_item_lens = (\n            self.token_to_kv_pool.get_contiguous_buf_infos()\n        )\n\n        if self.draft_token_to_kv_pool is not None:\n            # We should also transfer draft model kv cache. 
The indices are\n            # always shared with a target model.\n            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (\n                self.draft_token_to_kv_pool.get_contiguous_buf_infos()\n            )\n            kv_data_ptrs += draft_kv_data_ptrs\n            kv_data_lens += draft_kv_data_lens\n            kv_item_lens += draft_kv_item_lens\n\n        kv_args.kv_data_ptrs = kv_data_ptrs\n        kv_args.kv_data_lens = kv_data_lens\n        kv_args.kv_item_lens = kv_item_lens\n        if not self.is_mla_backend:\n            kv_args.kv_head_num = self.token_to_kv_pool.head_num\n            kv_args.total_kv_head_num = (\n                self.scheduler.model_config.get_total_num_kv_heads()\n            )\n        kv_args.page_size = self.token_to_kv_pool.page_size\n\n        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (\n            self.metadata_buffers.get_buf_infos()\n        )\n        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device\n        kv_args.gpu_id = self.scheduler.gpu_id\n\n        if hasattr(self.token_to_kv_pool, \"get_state_buf_infos\"):\n            state_data_ptrs, state_data_lens, state_item_lens = (\n                self.token_to_kv_pool.get_state_buf_infos()\n            )\n            kv_args.state_data_ptrs = state_data_ptrs\n            kv_args.state_data_lens = state_data_lens\n            kv_args.state_item_lens = state_item_lens\n\n            if isinstance(self.token_to_kv_pool, SWAKVPool):\n                kv_args.state_type = \"swa\"\n            elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):\n                kv_args.state_type = \"mamba\"\n                # Get state dimension info for cross-TP slice transfer\n                if hasattr(self.token_to_kv_pool, \"get_state_dim_per_tensor\"):\n                    kv_args.state_dim_per_tensor = (\n                        self.token_to_kv_pool.get_state_dim_per_tensor()\n                    )\n        
    elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):\n                kv_args.state_type = \"nsa\"\n            else:\n                kv_args.state_type = \"none\"\n        else:\n            kv_args.state_data_ptrs = []\n            kv_args.state_data_lens = []\n            kv_args.state_item_lens = []\n            kv_args.state_type = \"none\"\n\n        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)\n        kv_manager = kv_manager_class(\n            kv_args,\n            DisaggregationMode.PREFILL,\n            self.scheduler.server_args,\n            self.is_mla_backend,\n        )\n        return kv_manager\n\n    def add(self, req: Req, num_kv_heads: int) -> None:\n        if self._check_if_req_exceed_kv_capacity(req):\n            return\n\n        backend = (\n            TransferBackend.FAKE\n            if req.bootstrap_host == FAKE_BOOTSTRAP_HOST\n            else self.transfer_backend\n        )\n        kv_sender_class = get_kv_class(backend, KVClassType.SENDER)\n\n        dest_tp_ranks = [self.tp_rank]\n\n        req.disagg_kv_sender = kv_sender_class(\n            mgr=self.kv_manager,\n            bootstrap_addr=f\"{req.bootstrap_host}:{self.bootstrap_port}\",\n            bootstrap_room=req.bootstrap_room,\n            dest_tp_ranks=dest_tp_ranks,\n            pp_rank=self.pp_rank,\n        )\n        self._process_req(req)\n        self.queue.append(req)\n\n    def extend(self, reqs: List[Req], num_kv_heads: int) -> None:\n        for req in reqs:\n            self.add(req, num_kv_heads)\n\n    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:\n        if len(req.origin_input_ids) > self.max_total_num_tokens:\n            message = f\"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}\"\n            logger.error(message)\n            req.time_stats.trace_ctx.abort(abort_info={\"reason\": message})\n            prepare_abort(req, 
message, status_code=HTTPStatus.BAD_REQUEST)\n            self.scheduler.stream_output([req], req.return_logprob)\n            return True\n        return False\n\n    def _process_req(self, req: Req) -> None:\n        \"\"\"\n        Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate\n        \"\"\"\n        req.sampling_params.max_new_tokens = 1\n\n    def pop_bootstrapped(\n        self,\n        return_failed_reqs: bool = False,\n        rids_to_check: Optional[List[str]] = None,\n    ) -> List[Req]:\n        \"\"\"\n        Pop the requests that have finished bootstrapping.\n\n        return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank\n        rids_to_check: For PP, on rank > 0, only process the rids reported by the previous rank, so this rank stays consistent with it.\n        \"\"\"\n\n        bootstrapped_reqs = []\n        failed_reqs = []\n        indices_to_remove = set()\n\n        if len(self.queue) == 0:\n            if return_failed_reqs is False:\n                return []\n            else:\n                return [], []\n\n        polls = poll_and_all_reduce_attn_cp_tp_group(\n            [req.disagg_kv_sender for req in self.queue],\n            self.scheduler.attn_cp_cpu_group,\n            self.scheduler.attn_tp_cpu_group,\n        )\n\n        for i, (req, poll) in enumerate(zip(self.queue, polls)):\n            if rids_to_check is not None:\n                # Skip requests whose rid is not in rids_to_check\n                if req.rid not in rids_to_check:\n                    continue\n\n            if poll == KVPoll.Bootstrapping:\n                continue\n            elif poll == KVPoll.Failed:\n                error_message = f\"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}\"\n                try:\n                    req.disagg_kv_sender.failure_exception()\n                except Exception as e:\n                    error_message += f\" with exception {e}\"\n        
        logger.error(error_message)\n                req.time_stats.trace_ctx.abort(abort_info={\"reason\": error_message})\n                prepare_abort(\n                    req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR\n                )\n                self.scheduler.stream_output([req], req.return_logprob)\n                indices_to_remove.add(i)\n                failed_reqs.append(req)\n                if self.scheduler.enable_metrics:\n                    self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()\n                if self.scheduler.enable_hicache_storage:\n                    # to release prefetch events associated with the request\n                    self.scheduler.tree_cache.release_aborted_request(req.rid)\n                continue\n\n            # KV.WaitingForInput - init here\n            req.time_stats.set_bootstrap_done_time()\n            num_kv_indices = len(req.origin_input_ids)\n            if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:\n                break\n\n            req.metadata_buffer_index = (\n                self.req_to_metadata_buffer_idx_allocator.alloc()\n            )\n            assert req.metadata_buffer_index is not None\n\n            num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)\n            req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)\n\n            bootstrapped_reqs.append(req)\n            indices_to_remove.add(i)\n            req.time_stats.set_wait_queue_entry_time()\n\n        self.queue = [\n            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove\n        ]\n\n        if return_failed_reqs is False:\n            return bootstrapped_reqs\n        else:\n            return bootstrapped_reqs, failed_reqs\n\n\nclass SchedulerDisaggregationPrefillMixin:\n    \"\"\"\n    Mixin for Scheduler to handle disaggregation prefill\n    \"\"\"\n\n    def 
get_next_disagg_prefill_batch_to_run(\n        self: Scheduler,\n    ) -> Optional[ScheduleBatch]:\n        # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it\n        # Otherwise, it hangs under high concurrency\n        self.running_batch.batch_is_full = False\n\n        self.process_prefill_chunk()\n\n        batch = self.get_new_batch_prefill()\n        batch = self.maybe_prepare_mlp_sync_batch(batch)\n\n        if batch:\n            set_schedule_time_batch(batch)\n\n        return batch\n\n    @torch.no_grad()\n    def event_loop_normal_disagg_prefill(self: Scheduler) -> None:\n        \"\"\"A normal scheduler loop for prefill worker in disaggregation mode.\"\"\"\n\n        while True:\n            # Receive requests\n            recv_reqs = self.recv_requests()\n            self.process_input_requests(recv_reqs)\n            self.waiting_queue.extend(\n                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()\n            )\n\n            # Get the next batch to run\n            batch = self.get_next_disagg_prefill_batch_to_run()\n            self.cur_batch = batch\n\n            # Launch the current batch\n            if batch:\n                result = self.run_batch(batch)\n                self.process_batch_result(batch, result)\n            else:\n                self.self_check_during_idle()\n\n            self.process_disagg_prefill_inflight_queue()\n\n            # Update last_batch\n            self.last_batch = batch\n\n    @torch.no_grad()\n    def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:\n        self.result_queue = deque()\n\n        while True:\n            # Receive requests\n            recv_reqs = self.recv_requests()\n            self.process_input_requests(recv_reqs)\n            self.waiting_queue.extend(\n                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()\n            )\n\n            # Get the next batch to run\n            batch 
= self.get_next_disagg_prefill_batch_to_run()\n            self.cur_batch = batch\n\n            # Launch the current batch\n            if batch:\n                batch_result = self.run_batch(batch)\n                self.result_queue.append((batch.copy(), batch_result))\n            else:\n                batch_result = None\n\n            # Process the last batch\n            if self.last_batch:\n                tmp_batch, tmp_result = self.result_queue.popleft()\n                self.process_batch_result(tmp_batch, tmp_result)\n            elif batch is None:\n                # When the server is idle, do self-check and re-init some states\n                self.self_check_during_idle()\n\n            self.process_disagg_prefill_inflight_queue()\n\n            # Run sample of the current batch\n            # It depends on the result of the last batch (e.g., grammar), so we run it after the last batch is processed.\n            self.launch_batch_sample_if_needed(batch_result)\n\n            # Update last_batch\n            self.last_batch = batch\n\n    def process_batch_result_disagg_prefill(\n        self: Scheduler,\n        batch: ScheduleBatch,\n        result: GenerationBatchResult,\n    ) -> None:\n        \"\"\"\n        Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue\n        Adapted from process_batch_result_prefill\n        \"\"\"\n        (\n            logits_output,\n            next_token_ids,\n            extend_input_len_per_req,\n            extend_logprob_start_len_per_req,\n            copy_done,\n        ) = (\n            result.logits_output,\n            result.next_token_ids,\n            result.extend_input_len_per_req,\n            result.extend_logprob_start_len_per_req,\n            result.copy_done,\n        )\n\n        if copy_done is not None:\n            copy_done.synchronize()\n\n        logprob_pt = 0\n        # Transfer kv for prefill completed requests and add it into 
disagg_prefill_inflight_queue\n        next_token_ids = result.next_token_ids.tolist()\n        if batch.return_logprob:\n            if logits_output.next_token_logprobs is not None:\n                logits_output.next_token_logprobs = (\n                    logits_output.next_token_logprobs.tolist()\n                )\n            if logits_output.input_token_logprobs is not None:\n                logits_output.input_token_logprobs = tuple(\n                    logits_output.input_token_logprobs.tolist()\n                )\n\n        for i, (req, next_token_id) in enumerate(\n            zip(batch.reqs, next_token_ids, strict=True)\n        ):\n            if req.is_chunked <= 0:\n                req.time_stats.set_prefill_finished_time()\n\n                # There is no output_ids for prefill\n                req.output_ids.append(next_token_id)\n                self.tree_cache.cache_unfinished_req(req)  # update the tree and lock\n                self.disagg_prefill_inflight_queue.append(req)\n                if self.spec_algorithm.is_eagle() and batch.spec_info is not None:\n                    req.output_topk_p = batch.spec_info.topk_p[i]\n                    req.output_topk_index = batch.spec_info.topk_index[i]\n                    req.hidden_states_tensor = (\n                        batch.spec_info.hidden_states[i].cpu().clone()\n                    )\n                else:\n                    req.hidden_states_tensor = None\n                if req.return_logprob:\n                    assert extend_logprob_start_len_per_req is not None\n                    assert extend_input_len_per_req is not None\n                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]\n                    extend_input_len = extend_input_len_per_req[i]\n                    num_input_logprobs = extend_input_len - extend_logprob_start_len\n                    self.add_logprob_return_values(\n                        i,\n                        req,\n             
           logprob_pt,\n                        next_token_ids,\n                        num_input_logprobs,\n                        logits_output,\n                    )\n                    logprob_pt += num_input_logprobs\n                self.send_kv_chunk(req, last_chunk=True)\n                req.time_stats.set_prefill_transfer_queue_entry_time()\n\n                if req.grammar is not None:\n                    # FIXME: this try-except block is for handling unexpected xgrammar issue.\n                    try:\n                        req.grammar.accept_token(next_token_id)\n                    except ValueError as e:\n                        # Grammar accept_token can raise ValueError if the token is not in the grammar.\n                        # This can happen if the grammar is not set correctly or the token is invalid.\n                        error_message = f\"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}\"\n                        release_kv_cache(req, self.tree_cache)\n                        prepare_abort(\n                            req,\n                            error_message,\n                            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,\n                        )\n                    req.grammar.finished = req.finished()\n            else:\n                # being chunked reqs' prefill is not finished\n                req.is_chunked -= 1\n\n                if req.return_logprob:\n                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]\n                    extend_input_len = extend_input_len_per_req[i]\n                    if extend_logprob_start_len < extend_input_len:\n                        # Update input logprobs.\n                        num_input_logprobs = extend_input_len - extend_logprob_start_len\n                        self.add_input_logprob_return_values(\n                            i,\n                            req,\n                            
logits_output,\n                            logprob_pt,\n                            num_input_logprobs,\n                            last_prefill_chunk=False,\n                        )\n                        logprob_pt += num_input_logprobs\n\n                if self.enable_overlap:\n                    self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)\n                req.time_stats.set_last_chunked_prefill_finish_time()\n\n        can_run_cuda_graph = getattr(result, \"can_run_cuda_graph\", False)\n        self.report_prefill_stats(\n            prefill_stats=batch.prefill_stats,\n            can_run_cuda_graph=can_run_cuda_graph,\n            dp_cooperation_info=batch.dp_cooperation_info,\n        )\n\n    def process_disagg_prefill_inflight_queue(\n        self: Scheduler, rids_to_check: Optional[List[str]] = None\n    ) -> List[Req]:\n        \"\"\"\n        Poll the requests that are in the middle of a transfer and return the ones that are done.\n        rids_to_check: For PP, on rank > 0, check whether the rids from the previous rank have reached consensus with the current rank.\n        \"\"\"\n        if len(self.disagg_prefill_inflight_queue) == 0:\n            return []\n\n        done_reqs = []\n\n        polls = poll_and_all_reduce_attn_cp_tp_group(\n            [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],\n            self.attn_cp_cpu_group,\n            self.attn_tp_cpu_group,\n        )\n\n        undone_reqs: List[Req] = []\n        # Check .poll() for the reqs in disagg_prefill_inflight_queue. 
If Success, respond to the client and remove it from the queue\n        for req, poll in zip(self.disagg_prefill_inflight_queue, polls):\n\n            if rids_to_check is not None:\n                if req.rid not in rids_to_check:\n                    undone_reqs.append(req)\n                    continue\n\n                # In PP mode, the previous rank may have reached a terminal\n                # state (Success/Failed) while this rank's local poll is still\n                # in a transient state due to clock skew or propagation delay.\n                # Treat non-terminal states as undone instead of crashing.\n                if poll not in (\n                    KVPoll.Success,\n                    KVPoll.Failed,\n                ):\n                    logger.warning(\n                        f\"PP rank {self.pp_rank}: unexpected poll state {poll} for rid {req.rid} \"\n                        f\"from consensus; treating as undone\"\n                    )\n                    undone_reqs.append(req)\n                    continue\n\n            if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:\n                undone_reqs.append(req)\n            elif poll == KVPoll.Success:  # transfer done\n                release_kv_cache(req, self.tree_cache)  # unlock the tree\n                req.finished_reason = FINISH_LENGTH(length=0)\n                # FIXME: clean up req's data in transfer engine\n                if hasattr(req.disagg_kv_sender, \"clear\"):\n                    req.disagg_kv_sender.clear()\n                done_reqs.append(req)\n                req.time_stats.set_prefill_kv_transfer_finish_time()\n            elif poll == KVPoll.Failed:\n                error_message = f\"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}\"\n                try:\n                    req.disagg_kv_sender.failure_exception()\n                except Exception as e:\n                    error_message += f\" with exception 
{e}\"\n                logger.warning(error_message)\n                req.time_stats.trace_ctx.abort(abort_info={\"reason\": error_message})\n                release_kv_cache(req, self.tree_cache)  # unlock the tree\n                prepare_abort(\n                    req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR\n                )\n                done_reqs.append(req)\n                if self.enable_metrics:\n                    self.metrics_collector.increment_transfer_failed_reqs()\n            else:\n                logger.warning(\n                    f\"Unexpected polling state {poll} for rid {req.rid} in inflight queue; \"\n                    f\"treating as undone\"\n                )\n                undone_reqs.append(req)\n\n        for req in done_reqs:\n            req.time_stats.set_completion_time()\n\n        page_size = self.token_to_kv_pool_allocator.page_size\n        kv_item_lens = (\n            self.disagg_prefill_bootstrap_queue.kv_manager.kv_args.kv_item_lens\n        )\n        bytes_per_page_all_layers = sum(kv_item_lens)\n\n        for req in done_reqs:\n            if isinstance(req.finished_reason, FINISH_ABORT):\n                continue\n            metrics = req.time_stats.compute_and_observe_kv_transfer_metrics(\n                num_tokens=len(req.origin_input_ids),\n                page_size=page_size,\n                bytes_per_page_all_layers=bytes_per_page_all_layers,\n            )\n            if metrics:\n                # Update last-value for REST API\n                if \"latency_ms\" in metrics:\n                    self.kv_transfer_latency_ms = metrics[\"latency_ms\"]\n                if \"speed_gb_s\" in metrics:\n                    self.kv_transfer_speed_gb_s = metrics[\"speed_gb_s\"]\n\n        # Stream requests which have finished transfer\n        self.stream_output(\n            done_reqs,\n            any(req.return_logprob for req in done_reqs),\n            None,\n        )\n        for req 
in done_reqs:\n            req: Req\n\n            release_req_to_metadata_buffer(\n                req, self.req_to_metadata_buffer_idx_allocator\n            )\n\n        self.disagg_prefill_inflight_queue = undone_reqs\n\n        return done_reqs\n\n    def get_transferred_rids(self: Scheduler) -> List[str]:\n        \"\"\"\n        Used by PP: get the transferred rids but **do not pop** them.\n        \"\"\"\n        polls = poll_and_all_reduce_attn_cp_tp_group(\n            [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],\n            self.attn_cp_cpu_group,\n            self.attn_tp_cpu_group,\n        )\n\n        transferred_rids: List[str] = []\n\n        for req, poll in zip(self.disagg_prefill_inflight_queue, polls):\n            if poll == KVPoll.Success or poll == KVPoll.Failed:\n                transferred_rids.append(req.rid)\n\n        return transferred_rids\n\n    def process_prefill_chunk(self: Scheduler) -> None:\n        chunked_req_to_exclude = set()\n        if self.chunked_req:\n            chunked_req_to_exclude.add(self.chunked_req)\n            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)\n            if self.enable_overlap:\n                # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved\n                self.chunked_req.tmp_end_idx = min(\n                    len(self.chunked_req.fill_ids),\n                    len(self.chunked_req.origin_input_ids),\n                )\n            else:\n                self.send_kv_chunk(self.chunked_req)\n            self.running_batch.batch_is_full = False\n\n        if self.last_batch and self.last_batch.forward_mode.is_extend():\n            if self.last_batch.chunked_req:\n                # In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks an outdated chunked_req.\n                # We need to discard it.\n                
chunked_req_to_exclude.add(self.last_batch.chunked_req)\n\n            last_bs = self.last_batch.batch_size()\n            self.last_batch.filter_batch(\n                chunked_req_to_exclude=list(chunked_req_to_exclude)\n            )\n            if self.last_batch.batch_size() < last_bs:\n                self.running_batch.batch_is_full = False\n\n    def send_kv_chunk(\n        self: Scheduler,\n        req: Req,\n        last_chunk: bool = False,\n        end_idx: Optional[int] = None,\n    ) -> None:\n        \"\"\"\n        Send a prefilled chunk to the decode server\n        \"\"\"\n        page_size = self.token_to_kv_pool_allocator.page_size\n        start_idx = req.start_send_idx\n        end_idx = (\n            end_idx\n            if end_idx is not None\n            else min(len(req.fill_ids), len(req.origin_input_ids))\n        )\n\n        if not last_chunk:\n            # if not the last chunk and the last page is partial, delay the last partial page to the next send\n            end_idx = end_idx - end_idx % page_size\n\n        kv_indices = (\n            self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]\n            .cpu()\n            .numpy()\n        )\n        req.start_send_idx = end_idx\n        state_indices = None\n        if last_chunk:\n            self.disagg_metadata_buffers.set_buf(req)\n\n            # Prepare extra pool indices for hybrid models\n            if isinstance(\n                self.token_to_kv_pool_allocator.get_kvcache(), HybridLinearKVPool\n            ):\n                # Mamba hybrid model: send a single mamba state index\n                state_indices = [\n                    self.req_to_token_pool.req_index_to_mamba_index_mapping[\n                        req.req_pool_idx\n                    ]\n                    .cpu()\n                    .numpy()\n                ]\n            elif isinstance(self.token_to_kv_pool_allocator.get_kvcache(), SWAKVPool):\n                # SWA hybrid 
model: send last window KV indices\n                seq_len = len(req.fill_ids)\n                window_size = self.sliding_window_size\n                window_start = max(0, seq_len - window_size)\n                window_start = (window_start // page_size) * page_size\n\n                window_kv_indices_full = self.req_to_token_pool.req_to_token[\n                    req.req_pool_idx, window_start:seq_len\n                ]\n\n                # Translate to SWA pool indices\n                window_kv_indices_swa = (\n                    self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(\n                        window_kv_indices_full\n                    )\n                )\n                state_indices = window_kv_indices_swa.cpu().numpy()\n                state_indices = kv_to_page_indices(state_indices, page_size)\n            elif isinstance(\n                self.token_to_kv_pool_allocator.get_kvcache(), NSATokenToKVPool\n            ):\n                seq_len = len(req.fill_ids)\n                kv_indices_full = self.req_to_token_pool.req_to_token[\n                    req.req_pool_idx, :seq_len\n                ]\n                state_indices = kv_indices_full.cpu().numpy()\n                state_indices = kv_to_page_indices(state_indices, page_size)\n\n        page_indices = kv_to_page_indices(kv_indices, page_size)\n        if len(page_indices) == 0:\n            logger.info(\n                f\"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty\"\n            )\n            return\n        req.disagg_kv_sender.send(page_indices, state_indices)\n"
  },
  {
    "path": "python/sglang/srt/disaggregation/utils.py",
    "content": "from __future__ import annotations\n\nimport os\nimport random\nfrom collections import deque\nfrom contextlib import nullcontext\nfrom enum import Enum\nfrom typing import TYPE_CHECKING, Literal, Optional, Tuple, Type, overload\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\nfrom sglang.srt.environ import envs\nfrom sglang.srt.utils import is_npu\n\nif TYPE_CHECKING:\n    from sglang.srt.disaggregation.base.conn import KVArgs\n    from sglang.srt.disaggregation.common.conn import (\n        CommonKVBootstrapServer,\n        CommonKVManager,\n        CommonKVReceiver,\n        CommonKVSender,\n    )\n    from sglang.srt.managers.schedule_batch import Req\n\n#########################\n# Constants & Enums\n#########################\nFAKE_BOOTSTRAP_HOST = \"2.2.2.2\"\n\n\nclass DisaggregationMode(Enum):\n    NULL = \"null\"\n    PREFILL = \"prefill\"\n    DECODE = \"decode\"\n\n\n#########################\n# Synchronization\n#########################\n\n# env var for testing failure, convert to float explicitly\nFAILURE_PROB = float(os.getenv(\"DISAGGREGATION_TEST_FAILURE_PROB\", 0))\n\n\ndef poll_and_all_reduce(pollers, gloo_group: dist.ProcessGroup):\n    # at a certain prob, the poll is failed to simulate failure\n    if FAILURE_PROB > 0:\n        from sglang.srt.disaggregation.base import KVPoll\n\n        polls = [\n            int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll())\n            for poller in pollers\n        ]\n    else:\n        polls = [int(poller.poll()) for poller in pollers]\n    tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device=\"cpu\")\n    dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)\n    return tensor_to_reduce.tolist()\n\n\ndef poll_and_all_reduce_attn_cp_tp_group(\n    pollers,\n    attn_cp_cpu_group: dist.ProcessGroup,\n    attn_tp_cpu_group: dist.ProcessGroup,\n):\n    # First sync across attn-tp ranks so all TP participants for 
a given (dp, cp)\n    # shard observe the same status transitions.\n    polls = poll_and_all_reduce(pollers, attn_tp_cpu_group)\n\n    # Then sync across attn-cp ranks, so all TPxCP participants in one DP shard\n    # converge to the same global status.\n    tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device=\"cpu\")\n    dist.all_reduce(\n        tensor_to_reduce,\n        op=dist.ReduceOp.MIN,\n        group=attn_cp_cpu_group,\n    )\n    return tensor_to_reduce.tolist()\n\n\n#########################\n# Metadata Buffers\n#########################\n\n\nclass ReqToMetadataIdxAllocator:\n    \"\"\"A memory pool that maps a request to its first output token location.\"\"\"\n\n    def __init__(\n        self,\n        size: int,\n    ):\n        self.size = size\n        self.free_slots = deque(list(range(size)))\n\n    def available_size(self):\n        return len(self.free_slots)\n\n    def alloc(self) -> Optional[int]:\n        if len(self.free_slots) == 0:\n            return None\n\n        return self.free_slots.popleft()\n\n    def free(self, free_index: int):\n        self.free_slots.append(free_index)\n\n\nclass MetadataBuffers:\n    def __init__(\n        self,\n        size: int,\n        hidden_size: int,\n        hidden_states_dtype: torch.dtype,\n        max_top_logprobs_num: int = 128,\n        custom_mem_pool: torch.cuda.MemPool = None,\n    ):\n        self.custom_mem_pool = custom_mem_pool\n        bootstrap_room_dtype = torch.uint64\n        device = \"cpu\"\n        if is_npu():\n            # For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.\n            device = \"npu\"\n            # TODO: Fix me when npu backend supports torch.uint64\n            bootstrap_room_dtype = torch.int64\n        elif self.custom_mem_pool:\n            # TODO(shangming): Fix me (use 'cuda') when nvlink_transport of Mooncake is bug-free\n            device = \"cpu\"\n        elif 
envs.SGLANG_MOONCAKE_CUSTOM_MEM_POOL.get() == \"INTRA_NODE_NVLINK\":\n            device = \"cuda\"\n        with (\n            torch.cuda.use_mem_pool(self.custom_mem_pool)\n            if self.custom_mem_pool\n            else nullcontext()\n        ):\n            # TODO: abort top_logprobs_num > 128 in PD\n\n            # We transfer the metadata of first output token to decode\n            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes\n            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)\n            self.cached_tokens = torch.zeros(\n                (size, 16), dtype=torch.int32, device=device\n            )\n            self.output_token_logprobs_val = torch.zeros(\n                (size, 16), dtype=torch.float32, device=device\n            )\n            self.output_token_logprobs_idx = torch.zeros(\n                (size, 16), dtype=torch.int32, device=device\n            )\n            self.output_top_logprobs_val = torch.zeros(\n                (size, max_top_logprobs_num), dtype=torch.float32, device=device\n            )\n            self.output_top_logprobs_idx = torch.zeros(\n                (size, max_top_logprobs_num), dtype=torch.int32, device=device\n            )\n            # For PD + spec decode\n            self.output_topk_p = torch.zeros(\n                (size, 16), dtype=torch.float32, device=device\n            )\n            self.output_topk_index = torch.zeros(\n                (size, 16), dtype=torch.int64, device=device\n            )\n            self.output_hidden_states = torch.zeros(\n                (size, hidden_size), dtype=hidden_states_dtype, device=device\n            )\n            # Request validation: store bootstrap_room to detect metadata corruption\n            self.bootstrap_room = torch.zeros(\n                (size, 8), dtype=bootstrap_room_dtype, device=device\n            )\n\n    def get_buf_infos(self):\n        ptrs = [\n            
self.output_ids.data_ptr(),\n            self.cached_tokens.data_ptr(),\n            self.output_token_logprobs_val.data_ptr(),\n            self.output_token_logprobs_idx.data_ptr(),\n            self.output_top_logprobs_val.data_ptr(),\n            self.output_top_logprobs_idx.data_ptr(),\n            self.output_topk_p.data_ptr(),\n            self.output_topk_index.data_ptr(),\n            self.output_hidden_states.data_ptr(),\n            self.bootstrap_room.data_ptr(),\n        ]\n        data_lens = [\n            self.output_ids.nbytes,\n            self.cached_tokens.nbytes,\n            self.output_token_logprobs_val.nbytes,\n            self.output_token_logprobs_idx.nbytes,\n            self.output_top_logprobs_val.nbytes,\n            self.output_top_logprobs_idx.nbytes,\n            self.output_topk_p.nbytes,\n            self.output_topk_index.nbytes,\n            self.output_hidden_states.nbytes,\n            self.bootstrap_room.nbytes,\n        ]\n        item_lens = [\n            self.output_ids[0].nbytes,\n            self.cached_tokens[0].nbytes,\n            self.output_token_logprobs_val[0].nbytes,\n            self.output_token_logprobs_idx[0].nbytes,\n            self.output_top_logprobs_val[0].nbytes,\n            self.output_top_logprobs_idx[0].nbytes,\n            self.output_topk_p[0].nbytes,\n            self.output_topk_index[0].nbytes,\n            self.output_hidden_states[0].nbytes,\n            self.bootstrap_room[0].nbytes,\n        ]\n        return ptrs, data_lens, item_lens\n\n    def get_buf(self, idx: int):\n        return (\n            self.output_ids[idx],\n            self.cached_tokens[idx],\n            self.output_token_logprobs_val[idx],\n            self.output_token_logprobs_idx[idx],\n            self.output_top_logprobs_val[idx],\n            self.output_top_logprobs_idx[idx],\n            self.output_topk_p[idx],\n            self.output_topk_index[idx],\n            self.output_hidden_states[idx],\n            
self.bootstrap_room[idx],\n        )\n\n    def set_buf(self, req: Req):\n\n        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]\n        self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens\n        if req.return_logprob:\n            if req.output_token_logprobs_val:  # not none or empty list\n                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (\n                    req.output_token_logprobs_val[0]\n                )\n            if req.output_token_logprobs_idx:  # not none or empty list\n                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (\n                    req.output_token_logprobs_idx[0]\n                )\n\n            if req.output_top_logprobs_val:  # not none or empty list\n                self.output_top_logprobs_val[req.metadata_buffer_index][\n                    : len(req.output_top_logprobs_val[0])\n                ] = torch.tensor(\n                    req.output_top_logprobs_val[0], dtype=torch.float32, device=\"cpu\"\n                )\n            if req.output_top_logprobs_idx:  # not none or empty list\n                self.output_top_logprobs_idx[req.metadata_buffer_index][\n                    : len(req.output_top_logprobs_idx[0])\n                ] = torch.tensor(\n                    req.output_top_logprobs_idx[0], dtype=torch.int32, device=\"cpu\"\n                )\n        # For PD + spec decode\n        if req.hidden_states_tensor is not None:\n            # speculative_eagle_topk should not be greater than 16 currently\n            topk = req.output_topk_p.size(0)\n\n            self.output_topk_p[req.metadata_buffer_index, :topk].copy_(\n                req.output_topk_p\n            )\n            self.output_topk_index[req.metadata_buffer_index, :topk].copy_(\n                req.output_topk_index\n            )\n            self.output_hidden_states[req.metadata_buffer_index].copy_(\n                req.hidden_states_tensor\n            
)\n        # Store bootstrap_room for validation on decode side\n        self.bootstrap_room[req.metadata_buffer_index, 0] = (\n            req.bootstrap_room if req.bootstrap_room is not None else 0\n        )\n\n\n#########################\n# Transfer Backend\n#########################\n\n\nclass TransferBackend(Enum):\n    MOONCAKE = \"mooncake\"\n    MORI = \"mori\"\n    NIXL = \"nixl\"\n    ASCEND = \"ascend\"\n    FAKE = \"fake\"\n\n\nclass KVClassType(Enum):\n    KVARGS = \"kvargs\"\n    MANAGER = \"manager\"\n    SENDER = \"sender\"\n    RECEIVER = \"receiver\"\n    BOOTSTRAP_SERVER = \"bootstrap_server\"\n\n\n@overload\ndef get_kv_class(\n    transfer_backend: TransferBackend, class_type: Literal[KVClassType.KVARGS]\n) -> Type[KVArgs]: ...\n@overload\ndef get_kv_class(\n    transfer_backend: TransferBackend, class_type: Literal[KVClassType.MANAGER]\n) -> Type[CommonKVManager]: ...\n@overload\ndef get_kv_class(\n    transfer_backend: TransferBackend, class_type: Literal[KVClassType.SENDER]\n) -> Type[CommonKVSender]: ...\n@overload\ndef get_kv_class(\n    transfer_backend: TransferBackend, class_type: Literal[KVClassType.RECEIVER]\n) -> Type[CommonKVReceiver]: ...\n@overload\ndef get_kv_class(\n    transfer_backend: TransferBackend, class_type: Literal[KVClassType.BOOTSTRAP_SERVER]\n) -> Type[CommonKVBootstrapServer]: ...\n\n\ndef get_kv_class(\n    transfer_backend: TransferBackend, class_type: KVClassType\n) -> Optional[Type]:\n    from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender\n\n    if transfer_backend == TransferBackend.MOONCAKE:\n        from sglang.srt.disaggregation.base import KVArgs\n        from sglang.srt.disaggregation.mooncake import (\n            MooncakeKVBootstrapServer,\n            MooncakeKVManager,\n            MooncakeKVReceiver,\n            MooncakeKVSender,\n        )\n\n        class_mapping = {\n            KVClassType.KVARGS: KVArgs,\n            KVClassType.MANAGER: MooncakeKVManager,\n            
KVClassType.SENDER: MooncakeKVSender,\n            KVClassType.RECEIVER: (MooncakeKVReceiver),\n            KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,\n        }\n        return class_mapping.get(class_type)\n    elif transfer_backend == TransferBackend.MORI:\n        from sglang.srt.disaggregation.base import KVArgs\n        from sglang.srt.disaggregation.mori import (\n            MoriKVBootstrapServer,\n            MoriKVManager,\n            MoriKVReceiver,\n            MoriKVSender,\n        )\n\n        class_mapping = {\n            KVClassType.KVARGS: KVArgs,\n            KVClassType.MANAGER: MoriKVManager,\n            KVClassType.SENDER: MoriKVSender,\n            KVClassType.RECEIVER: (MoriKVReceiver),\n            KVClassType.BOOTSTRAP_SERVER: MoriKVBootstrapServer,\n        }\n        return class_mapping.get(class_type)\n    elif transfer_backend == TransferBackend.ASCEND:\n        from sglang.srt.disaggregation.ascend import (\n            AscendKVBootstrapServer,\n            AscendKVManager,\n            AscendKVReceiver,\n            AscendKVSender,\n        )\n        from sglang.srt.disaggregation.base import KVArgs\n\n        class_mapping = {\n            KVClassType.KVARGS: KVArgs,\n            KVClassType.MANAGER: AscendKVManager,\n            KVClassType.SENDER: AscendKVSender,\n            KVClassType.RECEIVER: (AscendKVReceiver),\n            KVClassType.BOOTSTRAP_SERVER: AscendKVBootstrapServer,\n        }\n        return class_mapping.get(class_type)\n    elif transfer_backend == TransferBackend.NIXL:\n        from sglang.srt.disaggregation.base import KVArgs\n        from sglang.srt.disaggregation.nixl import (\n            NixlKVBootstrapServer,\n            NixlKVManager,\n            NixlKVReceiver,\n            NixlKVSender,\n        )\n\n        class_mapping = {\n            KVClassType.KVARGS: KVArgs,\n            KVClassType.MANAGER: NixlKVManager,\n            KVClassType.SENDER: NixlKVSender,\n            
KVClassType.RECEIVER: (NixlKVReceiver),\n            KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,\n        }\n        return class_mapping.get(class_type)\n    elif transfer_backend == TransferBackend.FAKE:\n        from sglang.srt.disaggregation.base import KVArgs\n        from sglang.srt.disaggregation.fake import (\n            FakeKVManager,\n            FakeKVReceiver,\n            FakeKVSender,\n        )\n\n        class_mapping = {\n            KVClassType.KVARGS: KVArgs,\n            KVClassType.MANAGER: FakeKVManager,\n            KVClassType.SENDER: FakeKVSender,\n            KVClassType.RECEIVER: (FakeKVReceiver),\n        }\n        return class_mapping.get(class_type)\n\n    raise ValueError(f\"Unsupported transfer backend: {transfer_backend}\")\n\n\n#########################\n# KV Pages\n#########################\n\n\ndef kv_to_page_indices(kv_indices: np.ndarray, page_size: int):\n    # 1. The page is guaranteed to be full except the last page.\n    # 2. page index = kv_index // page_size\n    # The return vector is kv_indices[::page_size] // page_size\n    if page_size == 1:  # shortcut\n        return kv_indices\n\n    return kv_indices[::page_size] // page_size\n\n\ndef kv_to_page_num(num_kv_indices: int, page_size: int):\n    # ceil(num_kv_indices / page_size)\n    return (num_kv_indices + page_size - 1) // page_size\n\n\ndef page_indices_to_cp_rank_page_indices(\n    page_indices: np.ndarray,\n    total_pages: int,\n    cp_rank: int,\n    cp_size: int,\n) -> np.ndarray:\n    \"\"\"\n    Filter page_indices (which are *global* page ids in the KV pool) to those\n    belonging to the given CP rank for this request.\n\n    For a single request, its pages occupy a contiguous global range\n    [first_page, first_page + total_pages). 
We first compute the local\n    split [0, total_pages) across cp_size ranks, then shift that local\n    range by first_page back into the global page id space and take\n    the intersection with page_indices.\n\n    Returns:\n        Subset of page_indices that fall in this rank's global\n        [start_page, end_page) slice for the given CP rank.\n    \"\"\"\n    if cp_size <= 1:\n        return page_indices\n\n    if page_indices.size == 0:\n        return np.asarray(page_indices)\n\n    first_page = int(page_indices.min())\n    base = total_pages // cp_size\n    rem = total_pages % cp_size\n\n    if rem == 0:\n        local_start = cp_rank * base\n        local_end = local_start + base\n    else:\n        local_start = cp_rank * base + min(cp_rank, rem)\n        n_pages = base + (1 if cp_rank < rem else 0)\n        local_end = local_start + n_pages\n\n    # Map back to global page ids.\n    start_page = first_page + local_start\n    end_page = first_page + local_end\n\n    mask = (page_indices >= start_page) & (page_indices < end_page)\n    return np.asarray(page_indices)[mask]\n\n\ndef filter_kv_indices_for_cp_rank(\n    kv_mgr: CommonKVManager, kv_indices: np.ndarray, index_slice: slice\n) -> Tuple[np.ndarray, slice]:\n    \"\"\"Filters kv_indices and index_slice for the current CP rank.\"\"\"\n    total_pages = len(kv_indices)\n    cp_rank = kv_mgr.attn_cp_rank\n    cp_size = kv_mgr.attn_cp_size\n\n    rank_page_indices = page_indices_to_cp_rank_page_indices(\n        page_indices=kv_indices,\n        total_pages=total_pages,\n        cp_rank=cp_rank,\n        cp_size=cp_size,\n    )\n\n    if rank_page_indices.size == 0:\n        new_kv_indices = kv_indices[:0]\n        new_index_slice = slice(index_slice.start, index_slice.start)\n    else:\n        mask = np.isin(kv_indices, rank_page_indices)\n        if not mask.any():\n            new_kv_indices = kv_indices[:0]\n            new_index_slice = slice(index_slice.start, index_slice.start)\n        else:\n  
          first_pos = int(mask.argmax())\n            last_pos = len(mask) - int(mask[::-1].argmax())\n\n            new_kv_indices = kv_indices[first_pos:last_pos]\n            new_index_slice = slice(\n                index_slice.start + first_pos,\n                index_slice.start + last_pos,\n            )\n    return new_kv_indices, new_index_slice\n\n\n#########################\n# Misc\n#########################\n\n\ndef is_mla_backend(target_kv_pool) -> bool:\n    from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool\n\n    return isinstance(target_kv_pool, MLATokenToKVPool)\n\n\ndef prepare_abort(req: Req, error_message: str, status_code=None):\n    from sglang.srt.managers.schedule_batch import FINISH_ABORT\n\n    # populate finish metadata and stream output\n    req.finished_reason = FINISH_ABORT(error_message, status_code)\n\n    if req.return_logprob:\n        req.input_token_logprobs_val = []\n        req.input_token_logprobs_idx = []\n        req.input_top_logprobs_val = []\n        req.input_top_logprobs_idx = []\n        req.input_token_ids_logprobs_val = []\n        req.input_token_ids_logprobs_idx = []\n"
  },
  {
    "path": "python/sglang/srt/distributed/__init__.py",
    "content": "from sglang.srt.distributed.communication_op import *\nfrom sglang.srt.distributed.parallel_state import *\nfrom sglang.srt.distributed.utils import *\n"
  },
  {
    "path": "python/sglang/srt/distributed/communication_op.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py\n\nfrom typing import Any, Dict, Optional, Tuple, Union\n\nimport torch\nimport torch.distributed\n\nfrom .parallel_state import get_tp_group\n\n\ndef tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:\n    \"\"\"All-reduce the input tensor across model parallel group.\"\"\"\n    return get_tp_group().all_reduce(input_)\n\n\ndef tensor_model_parallel_fused_allreduce_rmsnorm(\n    input_: torch.Tensor,\n    residual_inp_: torch.Tensor,\n    weight_: torch.Tensor,\n    eps: float,\n) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:\n    \"\"\"Fused TP all-reduce + RMSNorm.\n\n    Policy and backend selection are owned by GroupCoordinator:\n    it may dispatch to communicator-native fused APIs, custom fused kernels,\n    or return None so callers can run generic fallback paths.\n    \"\"\"\n    return get_tp_group().fused_allreduce_rmsnorm(input_, residual_inp_, weight_, eps)\n\n\ndef tensor_model_parallel_all_gather(\n    input_: torch.Tensor, dim: int = -1\n) -> torch.Tensor:\n    \"\"\"All-gather the input tensor across model parallel group.\"\"\"\n    return get_tp_group().all_gather(input_, dim)\n\n\ndef tensor_model_parallel_gather(\n    input_: torch.Tensor, dst: int = 0, dim: int = -1\n) -> Optional[torch.Tensor]:\n    \"\"\"Gather the input tensor across model parallel group.\"\"\"\n    return get_tp_group().gather(input_, dst, dim)\n\n\ndef broadcast_tensor_dict(\n    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0\n):\n    if not torch.distributed.is_initialized():\n        return tensor_dict\n    return get_tp_group().broadcast_tensor_dict(tensor_dict, src)\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/all_reduce_utils.py",
    "content": "MiB = 1024 * 1024\n\nTORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES = {\n    9: {\n        2: 64 * MiB,  # 64 MB\n        4: 64 * MiB,  # 64 MB\n        6: 128 * MiB,  # 128 MB\n        8: 128 * MiB,  # 128 MB\n    },\n    10: {\n        2: 64 * MiB,  # 64 MB\n        4: 64 * MiB,  # 64 MB\n        6: 128 * MiB,  # 128 MB\n        8: 128 * MiB,  # 128 MB\n    },\n}\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/cuda_wrapper.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py\n\n\"\"\"This file is a pure Python wrapper for the cudart library.\nIt avoids the need to compile a separate shared library, and is\nconvenient for use when we just need to call a few functions.\n\"\"\"\n\nimport ctypes\nimport logging\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional\n\n# this line makes it possible to directly load `libcudart.so` using `ctypes`\nimport torch  # noqa\n\nfrom sglang.srt.utils import is_musa\n\n_is_musa = is_musa()\n\nlogger = logging.getLogger(__name__)\n\n# === export types and functions from cudart to Python ===\n# for the original cudart definition, please check\n# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html\n\ncudaError_t = ctypes.c_int\ncudaMemcpyKind = ctypes.c_int\n\n\nclass cudaIpcMemHandle_t(ctypes.Structure):\n    _fields_ = [(\"internal\", ctypes.c_byte * 128)]\n\n\n@dataclass\nclass Function:\n    name: str\n    restype: Any\n    argtypes: List[Any]\n\n\ndef find_loaded_library(lib_name) -> Optional[str]:\n    \"\"\"\n    According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,\n    the file `/proc/self/maps` contains the memory maps of the process, which includes the\n    shared libraries loaded by the process. 
We can use this file to find the path of\n    a loaded library.\n    \"\"\"  # noqa\n    found = False\n    with open(\"/proc/self/maps\") as f:\n        for line in f:\n            if lib_name in line:\n                found = True\n                break\n    if not found:\n        # the library is not loaded in the current process\n        return None\n    # if lib_name is libcudart, we need to match a line with:\n    # address /path/to/libcudart-hash.so.11.0\n    start = line.index(\"/\")\n    path = line[start:].strip()\n    filename = path.split(\"/\")[-1]\n    assert filename.rpartition(\".so\")[0].startswith(\n        lib_name\n    ), f\"Unexpected filename: {filename} for library {lib_name}\"\n    return path\n\n\nclass CudaRTLibrary:\n    exported_functions = [\n        # cudaError_t cudaSetDevice ( int  device )\n        Function(\"cudaSetDevice\", cudaError_t, [ctypes.c_int]),\n        # cudaError_t \tcudaDeviceSynchronize ( void )\n        Function(\"cudaDeviceSynchronize\", cudaError_t, []),\n        # cudaError_t cudaDeviceReset ( void )\n        Function(\"cudaDeviceReset\", cudaError_t, []),\n        # const char* \tcudaGetErrorString ( cudaError_t error )\n        Function(\"cudaGetErrorString\", ctypes.c_char_p, [cudaError_t]),\n        # cudaError_t \tcudaMalloc ( void** devPtr, size_t size )\n        Function(\n            \"cudaMalloc\",\n            cudaError_t,\n            [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],\n        ),\n        # cudaError_t \tcudaFree ( void* devPtr )\n        Function(\"cudaFree\", cudaError_t, [ctypes.c_void_p]),\n        # cudaError_t cudaMemset ( void* devPtr, int  value, size_t count )\n        Function(\n            \"cudaMemset\", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]\n        ),\n        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa\n        Function(\n            \"cudaMemcpy\",\n            cudaError_t,\n         
   [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],\n        ),\n        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa\n        Function(\n            \"cudaIpcGetMemHandle\",\n            cudaError_t,\n            [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],\n        ),\n        # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int  flags ) # noqa\n        Function(\n            \"cudaIpcOpenMemHandle\",\n            cudaError_t,\n            [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint],\n        ),\n    ]\n\n    # class attribute to store the mapping from the path to the library\n    # to avoid loading the same library multiple times\n    path_to_library_cache: Dict[str, Any] = {}\n\n    # class attribute to store the mapping from library path\n    #  to the corresponding dictionary\n    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}\n\n    def __init__(self, so_file: Optional[str] = None):\n        if so_file is None:\n            so_file = find_loaded_library(\"libcudart\" if not _is_musa else \"libmusart\")\n            assert so_file is not None, \"libcudart is not loaded in the current process\"\n        if so_file not in CudaRTLibrary.path_to_library_cache:\n            lib = ctypes.CDLL(so_file)\n            CudaRTLibrary.path_to_library_cache[so_file] = lib\n        self.lib = CudaRTLibrary.path_to_library_cache[so_file]\n\n        if so_file not in CudaRTLibrary.path_to_dict_mapping:\n            _funcs = {}\n            for func in CudaRTLibrary.exported_functions:\n                f = getattr(self.lib, func.name)\n                f.restype = func.restype\n                f.argtypes = func.argtypes\n                _funcs[func.name] = f\n            CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs\n        self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]\n\n    def CUDART_CHECK(self, result: cudaError_t) 
-> None:\n        if result != 0:\n            error_str = self.cudaGetErrorString(result)\n            raise RuntimeError(f\"CUDART error: {error_str}\")\n\n    def cudaGetErrorString(self, error: cudaError_t) -> str:\n        return self.funcs[\"cudaGetErrorString\"](error).decode(\"utf-8\")\n\n    def cudaSetDevice(self, device: int) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaSetDevice\"](device))\n\n    def cudaDeviceSynchronize(self) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaDeviceSynchronize\"]())\n\n    def cudaDeviceReset(self) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaDeviceReset\"]())\n\n    def cudaMalloc(self, size: int) -> ctypes.c_void_p:\n        devPtr = ctypes.c_void_p()\n        self.CUDART_CHECK(self.funcs[\"cudaMalloc\"](ctypes.byref(devPtr), size))\n        return devPtr\n\n    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaFree\"](devPtr))\n\n    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaMemset\"](devPtr, value, count))\n\n    def cudaMemcpy(\n        self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int\n    ) -> None:\n        cudaMemcpyDefault = 4\n        kind = cudaMemcpyDefault\n        self.CUDART_CHECK(self.funcs[\"cudaMemcpy\"](dst, src, count, kind))\n\n    def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:\n        handle = cudaIpcMemHandle_t()\n        self.CUDART_CHECK(\n            self.funcs[\"cudaIpcGetMemHandle\"](ctypes.byref(handle), devPtr)\n        )\n        return handle\n\n    def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:\n        cudaIpcMemLazyEnablePeerAccess = 1\n        devPtr = ctypes.c_void_p()\n        self.CUDART_CHECK(\n            self.funcs[\"cudaIpcOpenMemHandle\"](\n                ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess\n            )\n        )\n        
return devPtr\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/custom_all_reduce.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce.py\n\nimport ctypes\nimport logging\nimport os\nfrom contextlib import contextmanager\nfrom functools import partial\nfrom typing import Any, List, Optional, Union\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup\n\nimport sglang.srt.distributed.device_communicators.custom_all_reduce_ops as ops\nfrom sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph\nfrom sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary\nfrom sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (\n    gpu_p2p_access_check,\n    is_full_nvlink,\n    is_weak_contiguous,\n)\nfrom sglang.srt.distributed.parallel_state import in_the_same_node_as\nfrom sglang.srt.environ import envs\nfrom sglang.srt.utils import (\n    get_bool_env_var,\n    is_cuda,\n    is_hip,\n    is_musa,\n    log_info_on_rank0,\n)\n\n_is_cuda = is_cuda()\n_is_hip = is_hip()\n_is_musa = is_musa()\n\nlogger = logging.getLogger(__name__)\n\n\ndef _can_p2p(rank: int, world_size: int) -> bool:\n    # SGLANG_SKIP_P2P_CHECK can be set to False in sglang\n    SGLANG_SKIP_P2P_CHECK = os.getenv(\"SGLANG_SKIP_P2P_CHECK\", \"0\") == \"1\"\n    for i in range(world_size):\n        if i == rank:\n            continue\n        if SGLANG_SKIP_P2P_CHECK:\n            logger.info(\"Skipping P2P check and trusting the driver's P2P report.\")\n            return torch.cuda.can_device_access_peer(rank, i)\n        if not gpu_p2p_access_check(rank, i):\n            return False\n    return True\n\n\nclass CustomAllreduce:\n    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]\n    _MAX_CAR_SIZE = 8192 * 1024\n    if _is_hip:\n        # crossover is at 16MB buffer size for ROCm\n        _MAX_CAR_SIZE = 2 * 8192 * 1024\n    if _is_musa:\n        # crossover is at 128MB buffer size for MUSA\n        
_MAX_CAR_SIZE = 16 * 8192 * 1024\n\n    # max_size: max supported allreduce size\n    def __init__(\n        self,\n        group: ProcessGroup,\n        device: Union[int, str, torch.device],\n        max_size=_MAX_CAR_SIZE,\n    ) -> None:\n        \"\"\"\n        Args:\n            group: the process group to work on. If None, it will use the\n                default process group.\n            device: the device to bind the CustomAllreduce to. If None,\n                it will be bound to f\"cuda:{local_rank}\".\n        It is the caller's responsibility to make sure each communicator\n        is bound to a unique device, and all communicators in this group\n        are in the same node.\n        \"\"\"\n        self._IS_CAPTURING = False\n        self.disabled = True  # This can be modified in-place by context manager in piecewise cuda graph runner\n        self.original_disabled = True  # To store the original state\n\n        if not ops.IS_CUSTOM_AR_AVAILABLE:\n            # disable because of missing custom allreduce library\n            # e.g. 
in a non-cuda environment\n            return\n\n        self.group = group\n\n        assert (\n            dist.get_backend(group) != dist.Backend.NCCL\n        ), \"CustomAllreduce should be attached to a non-NCCL group.\"\n\n        if not all(in_the_same_node_as(group, source_rank=0)):\n            # No need to initialize custom allreduce for multi-node case.\n            logger.warning(\n                \"Custom allreduce is disabled because this process group\"\n                \" spans across nodes.\"\n            )\n            return\n\n        rank = dist.get_rank(group=self.group)\n        world_size = dist.get_world_size(group=self.group)\n        if world_size == 1:\n            # No need to initialize custom allreduce for single GPU case.\n            return\n\n        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:\n            logger.warning(\n                \"Custom allreduce is disabled due to an unsupported world\"\n                \" size: %d. Supported world sizes: %s. 
To silence this \"\n                \"warning, specify disable_custom_all_reduce=True explicitly.\",\n                world_size,\n                str(CustomAllreduce._SUPPORTED_WORLD_SIZES),\n            )\n            return\n\n        if isinstance(device, int):\n            device = torch.device(f\"cuda:{device}\")\n        elif isinstance(device, str):\n            device = torch.device(device)\n        # now `device` is a `torch.device` object\n        assert isinstance(device, torch.device)\n        self.device = device\n\n        cuda_visible_devices = os.environ.get(\"CUDA_VISIBLE_DEVICES\", None)\n        if cuda_visible_devices:\n            device_ids = list(map(int, cuda_visible_devices.split(\",\")))\n        else:\n            device_ids = list(range(torch.cuda.device_count()))\n\n        physical_device_id = device_ids[device.index]\n        tensor = torch.tensor([physical_device_id], dtype=torch.int, device=\"cpu\")\n        gather_list = [\n            torch.tensor([0], dtype=torch.int, device=\"cpu\") for _ in range(world_size)\n        ]\n        dist.all_gather(gather_list, tensor, group=self.group)\n        physical_device_ids = [t.item() for t in gather_list]\n\n        # test nvlink first, this will filter out most of the cases\n        # where custom allreduce is not supported\n        # this checks hardware and driver support for NVLink\n        if _is_cuda or _is_hip or _is_musa:\n            full_nvlink = is_full_nvlink(physical_device_ids, world_size)\n\n        if world_size > 2 and not full_nvlink:\n            logger.warning(\n                \"Custom allreduce is disabled because it's not supported on\"\n                \" more than two PCIe-only GPUs. 
To silence this warning, \"\n                \"specify disable_custom_all_reduce=True explicitly.\"\n            )\n            return\n        # test P2P capability, this checks software/CUDA runtime support\n        # this is expensive to compute the first time\n        # then we cache the result\n        # On AMD GPU, p2p is always enabled between XGMI connected GPUs\n        if not _is_hip and not _can_p2p(rank, world_size):\n            logger.warning(\n                \"Custom allreduce is disabled because your platform lacks \"\n                \"GPU P2P capability or P2P test failed. To silence this \"\n                \"warning, specify disable_custom_all_reduce=True explicitly.\"\n            )\n            return\n\n        self.max_size = max_size\n        self.rank = rank\n        self.world_size = world_size\n        self.full_nvlink = full_nvlink\n\n        if not _is_hip:\n            # Buffer memory is owned by this Python class and passed to C++.\n            # Metadata consists of two parts: metadata for synchronization and a\n            # temporary buffer for storing intermediate allreduce results.\n            self.meta_ptrs = self.create_shared_buffer(\n                ops.meta_size() + max_size, group=group\n            )\n            # This is a pre-registered IPC buffer. In eager mode, input tensors\n            # are first copied into this buffer before allreduce is performed\n            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)\n            # This is a buffer for storing the tuples of pointers pointing to\n            # IPC buffers from all ranks. Each registered tuple has size of\n            # 8*world_size bytes where world_size is at most 8. Allocating 8MB\n            # is enough for 131072 such tuples. 
The largest model I've seen only\n            # needs less than 10000 of registered tuples.\n            self.rank_data = torch.empty(\n                max_size, dtype=torch.uint8, device=self.device\n            )\n            self._ptr = ops.init_custom_ar(\n                self.meta_ptrs, self.rank_data, rank, self.full_nvlink\n            )\n            ops.register_buffer(self._ptr, self.buffer_ptrs)\n        else:\n            # meta data buffers need to be \"uncached\" for signal on MI200\n            self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)\n            self.buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)\n            handle = ops.get_meta_buffer_ipc_handle(self.meta)\n            shard_data = (\n                bytes(handle),  # ipc handle to base ptr\n                0,  # offset of base ptr\n            )\n            handles, offsets = self._gather_ipc_meta(shard_data)\n            self.rank_data = torch.empty(\n                max_size, dtype=torch.uint8, device=self.device\n            )\n            self._ptr = ops.init_custom_ar(\n                self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink\n            )\n            self.register_buffer(self.buffer)\n\n        self.disabled = False\n        self.original_disabled = False  # Ensure original_disabled == disabled\n        self.tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get()\n\n    @staticmethod\n    def create_shared_buffer(\n        size_in_bytes: int, group: Optional[ProcessGroup] = None\n    ) -> List[int]:\n        \"\"\"\n        Creates a shared buffer and returns a list of pointers\n        representing the buffer on all processes in the group.\n        \"\"\"\n        lib = CudaRTLibrary()\n        pointer = lib.cudaMalloc(size_in_bytes)\n        if _is_musa:\n            lib.cudaMemset(pointer, 0, size_in_bytes)\n        handle = lib.cudaIpcGetMemHandle(pointer)\n        world_size = 
dist.get_world_size(group=group)\n        rank = dist.get_rank(group=group)\n        handles = [None] * world_size\n        dist.all_gather_object(handles, handle, group=group)\n\n        pointers: List[int] = []\n        for i, h in enumerate(handles):\n            if i == rank:\n                pointers.append(pointer.value)  # type: ignore\n            else:\n                pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore\n\n        return pointers\n\n    @staticmethod\n    def free_shared_buffer(\n        pointers: List[int], group: Optional[ProcessGroup] = None\n    ) -> None:\n        rank = dist.get_rank(group=group)\n        lib = CudaRTLibrary()\n        lib.cudaFree(ctypes.c_void_p(pointers[rank]))\n\n    @contextmanager\n    def capture(self):\n        \"\"\"\n        The main responsibility of this context manager is the\n        `register_graph_buffers` call at the end of the context.\n        It records all the buffer addresses used in the CUDA graph.\n        \"\"\"\n        try:\n            self._IS_CAPTURING = True\n            yield\n        finally:\n            self._IS_CAPTURING = False\n            if not self.disabled:\n                self.register_graph_buffers()\n\n    def _get_ipc_meta(self, inp: torch.Tensor):\n        # _share_cuda_() doesn't accept meta buffer not allocated from\n        # PyTorch cache allocator, use direct HIP call to get IPC handle\n        handle = ops.get_meta_buffer_ipc_handle(inp)\n        shard_data = (\n            bytes(handle),  # ipc handle to base ptr\n            0,  # offset of base ptr\n        )\n        return self._gather_ipc_meta(shard_data)\n\n    def _gather_ipc_meta(self, shard_data):\n        # Note: don't use `[[None]] * self.world_size` here\n        # because it will create a list of the same reference\n        all_data: List[Optional[Any]] = [[None] for i in range(self.world_size)]\n        all_data[self.rank][0] = shard_data\n\n        ranks = 
dist.get_process_group_ranks(group=self.group)\n        ranks.sort()\n        for i, rank in enumerate(ranks):\n            dist.broadcast_object_list(\n                all_data[i], src=rank, group=self.group, device=\"cpu\"\n            )\n\n        # we cannot directly use `dist.all_gather_object` here\n        # because it is incompatible with `gloo` backend under inference mode.\n        # see https://github.com/pytorch/pytorch/issues/126032 for details.\n\n        handles = []\n        offsets = []\n        for i in range(len(all_data)):\n            handles.append(all_data[i][0][0])  # type: ignore\n            offsets.append(all_data[i][0][1])  # type: ignore\n        return handles, offsets\n\n    def register_buffer(self, inp: torch.Tensor):\n        handles, offsets = self._get_ipc_meta(inp)\n        ops.register_buffer(self._ptr, inp, handles, offsets)\n\n    def register_graph_buffers(self):\n        if _is_hip:\n            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)\n            handles, offsets = self._gather_ipc_meta((bytes(handle), offset))\n            log_info_on_rank0(logger, f\"Registering {len(offset)} cuda graph addresses\")\n            ops.register_graph_buffers(self._ptr, handles, offsets)\n        else:\n            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)\n            log_info_on_rank0(logger, f\"Registering {len(offset)} cuda graph addresses\")\n            # We cannot directly use `dist.all_gather_object` here\n            # because it is incompatible with `gloo` backend under inference mode.\n            # see https://github.com/pytorch/pytorch/issues/126032 for details.\n            all_data = [\n                [None, None] for _ in range(dist.get_world_size(group=self.group))\n            ]\n            all_data[self.rank] = [handle, offset]\n            ranks = sorted(dist.get_process_group_ranks(group=self.group))\n            for i, rank in enumerate(ranks):\n                
dist.broadcast_object_list(\n                    all_data[i], src=rank, group=self.group, device=\"cpu\"\n                )\n            # Unpack list of tuples to tuple of lists.\n            handles = [d[0] for d in all_data]  # type: ignore\n            offsets = [d[1] for d in all_data]  # type: ignore\n            ops.register_graph_buffers(self._ptr, handles, offsets)\n\n    def should_custom_ar(self, inp: torch.Tensor):\n        if self.disabled:\n            return False\n        inp_size = inp.numel() * inp.element_size()\n        # custom allreduce requires input byte size to be multiples of 16\n        if inp_size % 16 != 0:\n            return False\n        if not is_weak_contiguous(inp):\n            return False\n        # for 4 or more non NVLink-capable GPUs, custom allreduce provides\n        # little performance improvement over NCCL.\n        if not _is_hip:\n            if self.world_size == 2 or self.full_nvlink:\n                return inp_size <= self.max_size\n            return False\n\n        if _is_hip:\n            if self.full_nvlink:\n                return inp_size <= self.max_size\n            return False\n\n        return False\n\n    # all reduce, assuming inp tensor is IPC registered with register_buffer,\n    # or, in the context of cuda graphs, register_graph_buffers\n    def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):\n        if out is None:\n            out = torch.empty_like(inp)\n        ops.all_reduce_reg(self._ptr, inp, out)\n        return out\n\n    # all reduce, assuming inp tensor is NOT IPC registered\n    def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):\n        if out is None:\n            out = torch.empty_like(inp)\n        ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)\n        return out\n\n    def all_reduce(\n        self,\n        inp: torch.Tensor,\n        *,\n        out: torch.Tensor = None,\n        registered: bool = False,\n    ):\n        
\"\"\"Performs an out-of-place all reduce.\n\n        If registered is True, this assumes inp's pointer is already\n        IPC-registered. Otherwise, inp is first copied into a pre-registered\n        buffer.\n        \"\"\"\n        if out is None:\n            out = torch.empty_like(inp)\n        if registered:\n            ops.all_reduce(self._ptr, inp, out, 0, 0)\n        else:\n            ops.all_reduce(\n                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size\n            )\n        return out\n\n    def deterministic_all_reduce(\n        self,\n        inp: torch.Tensor,\n        *,\n        out: torch.Tensor = None,\n        registered: bool = False,\n    ):\n        \"\"\"Deterministic all-reduce using 1-stage kernel with fixed ordering (AMD only).\"\"\"\n        if out is None:\n            out = torch.empty_like(inp)\n        if registered:\n            ops.deterministic_all_reduce_reg(self._ptr, inp, out)\n        else:\n            reg_buffer = self.buffer.view(inp.dtype)[: inp.numel()]\n            ops.deterministic_all_reduce_unreg(self._ptr, inp, reg_buffer, out)\n        return out\n\n    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:\n        \"\"\"The main allreduce API that provides support for cuda graph.\"\"\"\n        # When custom allreduce is disabled, this will be None.\n        if self.disabled or not self.should_custom_ar(input):\n            return None\n        if self._IS_CAPTURING:\n            if torch.cuda.is_current_stream_capturing():\n                if _is_hip:\n                    if self.tms_cudagraph:\n                        return self.all_reduce_unreg(input)\n                    return self.all_reduce_reg(input)\n                else:\n                    return self.all_reduce(input, registered=not self.tms_cudagraph)\n            else:\n                # Could be warmup OR piecewise cuda graph split op execution.\n                # In piecewise cuda graph, split ops 
run eagerly outside the graph\n                # but _IS_CAPTURING is still True. We need to do real all-reduce.\n                if is_in_piecewise_cuda_graph():\n                    # Split op execution - do real all-reduce\n                    if _is_hip:\n                        return self.all_reduce_unreg(input)\n                    else:\n                        return self.all_reduce(input, registered=False)\n                else:\n                    # True warmup - mimic the allocation pattern since custom\n                    # allreduce is out-of-place.\n                    return torch.zeros_like(input)\n        else:\n            if _is_hip:\n                # note: outside of cuda graph context,\n                # custom allreduce incurs a cost of cudaMemcpy, which should\n                # be small(<=1% of overall latency) compared to the performance\n                # gains of using custom kernels\n                return self.all_reduce_unreg(input)\n            else:\n                return self.all_reduce(input, registered=False)\n\n    def close(self):\n        if not self.disabled and self._ptr:\n            ops.dispose(self._ptr)\n            if _is_cuda:\n                self.free_shared_buffer(self.meta_ptrs)\n                self.free_shared_buffer(self.buffer_ptrs)\n            self._ptr = 0\n\n    def __del__(self):\n        self.close()\n\n\ndef dispatch_custom_allreduce():\n    \"\"\"Return the CustomAllreduce class to use (aiter on ROCm if enabled).\n\n    On AMD with 1-stage AR enabled, use sglang's CustomAllreduce (has deterministic_all_reduce method).\n    Otherwise use AiterCustomAllreduce if available.\n    \"\"\"\n    if _is_cuda or _is_musa:\n        return CustomAllreduce\n\n    assert _is_hip\n\n    if envs.SGLANG_USE_1STAGE_ALLREDUCE.is_set():\n        if envs.SGLANG_USE_1STAGE_ALLREDUCE.get():\n            logger.debug(\n                \"[AR] All-reduce: 1-stage kernel (SGLANG_USE_1STAGE_ALLREDUCE=1)\"\n            )\n      
  else:\n            logger.debug(\"[AR] All-reduce: default (SGLANG_USE_1STAGE_ALLREDUCE=0)\")\n    elif envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get():\n        logger.debug(\n            \"[AR] All-reduce: 1-stage kernel (deterministic inference enabled)\"\n        )\n    else:\n        logger.debug(\"[AR] All-reduce: default\")\n\n    # Check if 1-stage AR should be used\n    if envs.SGLANG_USE_1STAGE_ALLREDUCE.is_set():\n        use_1stage = envs.SGLANG_USE_1STAGE_ALLREDUCE.get()\n    else:\n        use_1stage = envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get()\n\n    # On AMD with 1-stage AR, use sglang's CustomAllreduce\n    # (AiterCustomAllreduce doesn't have deterministic_all_reduce method)\n    if use_1stage:\n        return CustomAllreduce\n\n    if get_bool_env_var(\"SGLANG_USE_AITER_AR\", default=\"true\"):\n        try:\n            from aiter.dist.device_communicators.custom_all_reduce import (\n                CustomAllreduce as AiterCustomAllreduce,\n            )\n\n            logger.info(\"[AR] Using AiterCustomAllreduce (AMD default)\")\n            tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get()\n            return partial(\n                AiterCustomAllreduce,\n                enable_register_for_capturing=not tms_cudagraph,\n            )\n        except ImportError as e:\n            logger.warning(\n                \"[AR] Aiter custom all-reduce not available; \"\n                \"falling back to sglang CustomAllreduce. Details: %s\",\n                e,\n            )\n            return CustomAllreduce\n\n    return CustomAllreduce\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/custom_all_reduce_ops.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py\nimport logging\nfrom typing import List, Optional, Tuple\n\nimport torch\n\nfrom sglang.srt.utils import is_cuda, is_hip, is_musa\n\nlogger = logging.getLogger(__name__)\n\n_is_cuda = is_cuda()\n_is_hip = is_hip()\n_is_musa = is_musa()\n\nIS_CUSTOM_AR_AVAILABLE = _is_cuda or _is_hip or _is_musa\nIS_QUICK_AR_AVAILABLE = _is_hip\n# TODO(zyksir): mscclpp is untested on AMD and therefore disabled.\nIS_MSCCLPP_AR_AVAILABLE = _is_cuda\n\ntry:\n    import sgl_kernel.allreduce as _custom_ar\nexcept ImportError as e:\n    if _is_cuda or _is_hip:\n        logger.warning(\"Failed to import from custom_ar with %r\", e)\n    IS_CUSTOM_AR_AVAILABLE = False\n    IS_QUICK_AR_AVAILABLE = False\n    IS_MSCCLPP_AR_AVAILABLE = False\n\n# region IS_CUSTOM_AR_AVAILABLE\n\nif not IS_CUSTOM_AR_AVAILABLE:\n    pass\n\nelif _is_cuda or _is_musa:\n    # CUDA custom allreduce\n\n    def init_custom_ar(\n        ipc_tensors: List[torch.Tensor],\n        rank_data: torch.Tensor,\n        rank: int,\n        full_nvlink: bool,\n    ) -> int:\n        return _custom_ar.init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)\n\n    def all_reduce(\n        fa: int,\n        inp: torch.Tensor,\n        out: torch.Tensor,\n        reg_buffer: int,\n        reg_buffer_sz_bytes: int,\n    ) -> None:\n        _custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)\n\n    def dispose(fa: int) -> None:\n        _custom_ar.dispose(fa)\n\n    def meta_size() -> int:\n        return _custom_ar.meta_size()\n\n    def register_buffer(fa: int, ipc_tensors: List[int]) -> None:\n        return _custom_ar.register_buffer(fa, ipc_tensors)\n\n    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:\n        return _custom_ar.get_graph_buffer_ipc_meta(fa)\n\n    def register_graph_buffers(\n        fa: int, handles: List[List[int]], offsets: List[List[int]]\n    ) -> None:\n     
   _custom_ar.register_graph_buffers(fa, handles, offsets)\n\nelif _is_hip:\n    # ROCM custom allreduce\n\n    def init_custom_ar(\n        meta: torch.Tensor,\n        rank_data: torch.Tensor,\n        handles: List[str],\n        offsets: List[int],\n        rank: int,\n        full_nvlink: bool,\n    ) -> int:\n        return _custom_ar.init_custom_ar(\n            meta, rank_data, handles, offsets, rank, full_nvlink\n        )\n\n    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:\n        _custom_ar.all_reduce_reg(fa, inp, out)\n\n    def all_reduce_unreg(\n        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor\n    ) -> None:\n        _custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)\n\n    def deterministic_all_reduce_reg(\n        fa: int, inp: torch.Tensor, out: torch.Tensor\n    ) -> None:\n        _custom_ar.deterministic_all_reduce_reg(fa, inp, out)\n\n    def deterministic_all_reduce_unreg(\n        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor\n    ) -> None:\n        _custom_ar.deterministic_all_reduce_unreg(fa, inp, reg_buffer, out)\n\n    def dispose(fa: int) -> None:\n        _custom_ar.dispose(fa)\n\n    def meta_size() -> int:\n        return _custom_ar.meta_size()\n\n    def register_buffer(\n        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]\n    ) -> None:\n        return _custom_ar.register_buffer(fa, t, handles, offsets)\n\n    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:\n        return _custom_ar.get_graph_buffer_ipc_meta(fa)\n\n    def register_graph_buffers(\n        fa: int, handles: List[str], offsets: List[List[int]]\n    ) -> None:\n        _custom_ar.register_graph_buffers(fa, handles, offsets)\n\n    def allocate_meta_buffer(size: int) -> torch.Tensor:\n        return _custom_ar.allocate_meta_buffer(size)\n\n    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:\n        return 
_custom_ar.get_meta_buffer_ipc_handle(inp)\n\n\n# endregion\n\n# region IS_QUICK_AR_AVAILABLE\n\nif not IS_QUICK_AR_AVAILABLE:\n    pass\n\nelif _is_hip:\n    # ROCM custom quick allreduce\n\n    def init_custom_qr(\n        rank: int, world_size: int, qr_max_size: Optional[int] = None\n    ) -> int:\n        return _custom_ar.init_custom_qr(world_size, rank, qr_max_size)\n\n    def qr_get_handle(fa: int) -> torch.Tensor:\n        return _custom_ar.qr_get_handle(fa)\n\n    def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:\n        _custom_ar.qr_open_handles(fa, handles)\n\n    def qr_all_reduce(\n        fa: int,\n        inp: torch.Tensor,\n        out: torch.Tensor,\n        quant_level: int,\n        cast_bf2half: bool,\n    ) -> None:\n        _custom_ar.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)\n\n    def qr_destroy(fa: int) -> None:\n        _custom_ar.qr_destroy(fa)\n\n    def qr_max_size() -> int:\n        return _custom_ar.qr_max_size()\n\n\n# endregion\n\n# region IS_MSCCLPP_AR_AVAILABLE\n\nif not IS_MSCCLPP_AR_AVAILABLE:\n    pass\n\nelif _is_cuda:\n\n    def mscclpp_generate_unique_id() -> bytes:\n        return _custom_ar.mscclpp_generate_unique_id()\n\n    def mscclpp_init_context(\n        unique_id: bytes,\n        rank: int,\n        world_size: int,\n        scratch: torch.Tensor,\n        put_buffer: torch.Tensor,\n        nranks_per_node: int,\n        rank_to_node: List[int],\n        rank_to_ib: List[int],\n        context_selection: int,\n    ) -> int:\n        return _custom_ar.mscclpp_init_context(\n            unique_id,\n            rank,\n            world_size,\n            scratch,\n            put_buffer,\n            nranks_per_node,\n            rank_to_node,\n            rank_to_ib,\n            context_selection,\n        )\n\n    def mscclpp_allreduce(\n        context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int\n    ) -> None:\n        return 
_custom_ar.mscclpp_allreduce(context, inp, out, nthreads, nblocks)\n\n\n# endregion\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce_utils.py\n\nimport ctypes\nimport json\nimport logging\nimport os\nimport pickle\nimport subprocess\nimport sys\nimport tempfile\nfrom functools import wraps\nfrom itertools import product\nfrom typing import Callable, Dict, List, Optional, Sequence, TypeVar\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nfrom typing_extensions import ParamSpec\n\nfrom sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary\nfrom sglang.srt.utils import is_cuda, is_hip, is_musa\n\nlogger = logging.getLogger(__name__)\n\n_is_cuda = is_cuda()\n_is_hip = is_hip()\n_is_musa = is_musa()\n\nif _is_cuda:\n    try:\n        import pynvml\n    except ImportError as e:\n        logger.warning(\"Failed to import pynvml with %r\", e)\n\nif _is_musa:\n    try:\n        import pymtml as pynvml\n    except ImportError as e:\n        logger.warning(\"Failed to import pymtml with %r\", e)\n\nif _is_hip:\n    try:\n        from amdsmi import (\n            AmdSmiException,\n            amdsmi_get_processor_handles,\n            amdsmi_init,\n            amdsmi_shut_down,\n            amdsmi_topo_get_link_type,\n        )\n    except ImportError as e:\n        logger.warning(\"Failed to import amdsmi with %r\", e)\n\n_P = ParamSpec(\"_P\")\n_R = TypeVar(\"_R\")\n\n\ndef update_environment_variables(envs: Dict[str, str]):\n    for k, v in envs.items():\n        if k in os.environ and os.environ[k] != v:\n            logger.warning(\n                \"Overwriting environment variable %s \" \"from '%s' to '%s'\",\n                k,\n                os.environ[k],\n                v,\n            )\n        os.environ[k] = v\n\n\ndef producer(\n    batch_src: Sequence[int],\n    producer_queue,\n    consumer_queue,\n    result_queue,\n    cuda_visible_devices: Optional[str] = None,\n):\n    if 
cuda_visible_devices is not None:\n        update_environment_variables({\"CUDA_VISIBLE_DEVICES\": cuda_visible_devices})\n\n    lib = CudaRTLibrary()\n    for i in batch_src:\n        lib.cudaSetDevice(i)\n        pointer = lib.cudaMalloc(1024)\n        lib.cudaMemset(pointer, 1, 1024)\n        lib.cudaDeviceSynchronize()\n        handle = lib.cudaIpcGetMemHandle(pointer)\n        producer_queue.put(handle)\n        open_success = consumer_queue.get()\n        if open_success:\n            # use two queues to simulate barrier\n            producer_queue.put(0)\n            consumer_queue.get()\n            # check if the memory is modified\n            host_data = (ctypes.c_char * 1024)()\n            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore\n            for i in range(1024):\n                if ord(host_data[i]) != 2:\n                    open_success = False\n                    break\n        result_queue.put(open_success)\n        lib.cudaDeviceReset()\n\n\ndef consumer(\n    batch_tgt: Sequence[int],\n    producer_queue,\n    consumer_queue,\n    result_queue,\n    cuda_visible_devices: Optional[str] = None,\n):\n    if cuda_visible_devices is not None:\n        update_environment_variables({\"CUDA_VISIBLE_DEVICES\": cuda_visible_devices})\n\n    lib = CudaRTLibrary()\n    for j in batch_tgt:\n        lib.cudaSetDevice(j)\n        handle = producer_queue.get()\n        open_success = False\n        try:\n            pointer = lib.cudaIpcOpenMemHandle(handle)  # type: ignore\n            open_success = True\n        except RuntimeError:\n            # cannot error out here, because the producer process\n            # is still waiting for the response.\n            pass\n        consumer_queue.put(open_success)\n        if open_success:\n            # modify the memory\n            lib.cudaMemset(pointer, 2, 1024)\n            lib.cudaDeviceSynchronize()\n            # use two queues to simulate barrier\n            producer_queue.get()\n        
    consumer_queue.put(0)\n            # check if the memory is modified\n            host_data = (ctypes.c_char * 1024)()\n            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore\n            for i in range(1024):\n                if ord(host_data[i]) != 2:\n                    open_success = False\n                    break\n        result_queue.put(open_success)\n        lib.cudaDeviceReset()\n\n\ndef can_actually_p2p(\n    batch_src: Sequence[int],\n    batch_tgt: Sequence[int],\n) -> Sequence[bool]:\n    \"\"\"\n    Usually, checking if P2P access is enabled can be done by\n    `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes\n    the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`\n    returns `True` even if P2P access is not actually possible.\n    See https://github.com/vllm-project/vllm/issues/2728 and\n    https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10\n    Therefore, we have to perform a real P2P access to check if it is actually\n    possible.\n\n    Note on p2p and cuda IPC:\n    Usually, one process uses one GPU:\n    GPU src --> cuda context src --> tensor src --> process src\n\n    We need to combine p2p and cuda IPC, so that:\n    GPU src --> cuda context src --> tensor src --> process src\n                                      |shared|\n    GPU tgt --> cuda context tgt --> tensor tgt --> process tgt\n    That is to say, process src creates a tensor in GPU src, passes IPC handle to\n    process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the\n    tensor in process tgt will be reflected in the tensor in process src, because\n    they are the same memory segment.\n    It is important to note that process tgt accesses the tensor in GPU tgt, not\n    GPU src. That's why we need p2p access.\n\n    The most time-consuming part is the process creation. 
To avoid creating\n    processes for every pair of GPUs, we use batched testing. We create two\n    processes for testing all pairs of GPUs in batch. The trick is to reset\n    the device after each test (which is not available in PyTorch).\n    \"\"\"  # noqa\n    cuda_visible_devices = os.environ.get(\"CUDA_VISIBLE_DEVICES\", None)\n    # pass the CUDA_VISIBLE_DEVICES to the child process\n    # to make sure they see the same set of GPUs\n\n    # make sure the processes are spawned\n    smp = mp.get_context(\"spawn\")\n    producer_queue = smp.Queue()\n    consumer_queue = smp.Queue()\n    result_queue = smp.Queue()\n    p_src = smp.Process(\n        target=producer,\n        args=(\n            batch_src,\n            producer_queue,\n            consumer_queue,\n            result_queue,\n            cuda_visible_devices,\n        ),\n    )\n    p_tgt = smp.Process(\n        target=consumer,\n        args=(\n            batch_tgt,\n            producer_queue,\n            consumer_queue,\n            result_queue,\n            cuda_visible_devices,\n        ),\n    )\n    p_src.start()\n    p_tgt.start()\n    p_src.join()\n    p_tgt.join()\n    assert p_src.exitcode == 0 and p_tgt.exitcode == 0\n    result: List[bool] = []\n    for src, tgt in zip(batch_src, batch_tgt):\n        a = result_queue.get()\n        b = result_queue.get()\n        if a != b:\n            logger.warning(\n                \"Two processes do not agree on the P2P access\"\n                \" status on %d -> %d, treat as disabled.\",\n                src,\n                tgt,\n            )\n            result.append(False)\n        else:\n            result.append(a)\n    return result\n\n\n# why do we need this cache?\n# we are testing peer-to-peer (p2p) access between GPUs,across processes.\n# if we test it every time, it will be very slow, because we need to create\n#  N * N * 2 processes, where N is the world size. 
This is very slow.\n# to reduce the time, we use a cache file to store the p2p access status.\n# the cache file is generated by the master process if it does not exist.\n# then all the processes can read the cache file to check the p2p access status.\n# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we\n#  can have different cache files for different CUDA_VISIBLE_DEVICES settings,\n#  e.g. used by different vllm engines. The device id in the cache file is a\n#  **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number\n#  of visible devices in the vllm engine.\n_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None\n\n\ndef gpu_p2p_access_check(src: int, tgt: int) -> bool:\n    \"\"\"Check if GPU src can access GPU tgt.\"\"\"\n\n    # if the cache variable is already calculated,\n    # read from the cache instead of checking it again\n    global _gpu_p2p_access_cache\n    if _gpu_p2p_access_cache is not None:\n        return _gpu_p2p_access_cache[f\"{src}->{tgt}\"]\n\n    is_distributed = dist.is_initialized()\n\n    num_dev = torch.cuda.device_count()\n    cuda_visible_devices = os.environ.get(\"CUDA_VISIBLE_DEVICES\", None)\n    if cuda_visible_devices is None:\n        cuda_visible_devices = \",\".join(str(i) for i in range(num_dev))\n\n    # VLLM_CACHE_ROOT -> SGLANG_CACHE_ROOT\n    # \"~/.cache/vllm\" -> \"~/.cache/sglang\"\n    SGLANG_CACHE_ROOT = os.path.expanduser(\"~/.cache/sglang\")\n    path = os.path.join(\n        SGLANG_CACHE_ROOT, f\"gpu_p2p_access_cache_for_{cuda_visible_devices}.json\"\n    )\n    os.makedirs(os.path.dirname(path), exist_ok=True)\n    from sglang.srt.distributed.parallel_state import get_world_group\n\n    if (not is_distributed or get_world_group().local_rank == 0) and (\n        not os.path.exists(path)\n    ):\n        # only the local master process (with local_rank == 0) can\n        #  enter this block to calculate the cache\n        logger.info(\"generating GPU P2P access cache in 
%s\", path)\n        cache: Dict[str, bool] = {}\n        ids = list(range(num_dev))\n        # batch of all pairs of GPUs\n        batch_src, batch_tgt = zip(*list(product(ids, ids)))\n        # NOTE: we use `subprocess` rather than `multiprocessing` here\n        # because the caller might not have `if __name__ == \"__main__\":`,\n        # in that case we cannot use spawn method in multiprocessing.\n        # However, `can_actually_p2p` requires spawn method.\n        # The fix is, we use `subprocess` to call the function,\n        # where we have `if __name__ == \"__main__\":` in this file.\n\n        # use a temporary file to store the result\n        # we don't use the output of the subprocess directly,\n        # because the subprocess might produce logging output\n        with tempfile.NamedTemporaryFile() as output_file:\n            input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))\n            returned = subprocess.run(\n                [sys.executable, __file__], input=input_bytes, capture_output=True\n            )\n            # check if the subprocess is successful\n            try:\n                returned.check_returncode()\n            except Exception as e:\n                # wrap raised exception to provide more information\n                raise RuntimeError(\n                    f\"Error happened when batch testing \"\n                    f\"peer-to-peer access from {batch_src} to {batch_tgt}:\\n\"\n                    f\"{returned.stderr.decode()}\"\n                ) from e\n            with open(output_file.name, \"rb\") as f:\n                result = pickle.load(f)\n        for _i, _j, r in zip(batch_src, batch_tgt, result):\n            cache[f\"{_i}->{_j}\"] = r\n        with open(path, \"w\") as f:\n            json.dump(cache, f, indent=4)\n    if is_distributed:\n        get_world_group().barrier()\n    logger.info(\"reading GPU P2P access cache from %s\", path)\n    with open(path) as f:\n        cache = 
json.load(f)\n    _gpu_p2p_access_cache = cache\n    return _gpu_p2p_access_cache[f\"{src}->{tgt}\"]\n\n\ndef with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:\n    @wraps(fn)\n    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:\n        if _is_hip:\n            try:\n                amdsmi_init()\n                return fn(*args, **kwargs)\n            finally:\n                amdsmi_shut_down()\n        else:\n            pynvml.nvmlInit()\n            try:\n                return fn(*args, **kwargs)\n            finally:\n                pynvml.nvmlShutdown()\n\n    return wrapper\n\n\n@with_nvml_context\ndef is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:\n    if _is_hip:\n        \"\"\"\n        query if the set of gpus are fully connected by xgmi (1 hop)\n        \"\"\"\n        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]\n        for i, handle in enumerate(handles):\n            for j, peer_handle in enumerate(handles):\n                if i < j:\n                    try:\n                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)\n                        # type is 2 for XGMI\n                        if link_type[\"hops\"] != 1 or link_type[\"type\"] != 2:\n                            return False\n                    except AmdSmiException as error:\n                        logger.error(\"AMD 1 hop XGMI detection failed.\", exc_info=error)\n                        return False\n        return True\n    else:\n        \"\"\"\n        query if the set of gpus are fully connected by nvlink (1 hop)\n        \"\"\"\n        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]\n        for i, handle in enumerate(handles):\n            for j, peer_handle in enumerate(handles):\n                if i < j:\n                    try:\n                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(\n                            handle, 
peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK\n                        )\n                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:\n                            return False\n                    except pynvml.NVMLError:\n                        logger.exception(\n                            \"NVLink detection failed. This is normal if your\"\n                            \" machine has no NVLink equipped.\"\n                        )\n                        return False\n        return True\n\n\ndef is_weak_contiguous(inp: torch.Tensor):\n    return inp.is_contiguous() or (\n        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()\n        == inp.numel() * inp.element_size()\n    )\n\n\n__all__ = [\"gpu_p2p_access_check\"]\n\nif __name__ == \"__main__\":\n    batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())\n    result = can_actually_p2p(batch_src, batch_tgt)\n    with open(output_file, \"wb\") as f:\n        f.write(pickle.dumps(result))\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/hpu_communicator.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup\n\nfrom sglang.srt.utils import is_hpu\n\nif is_hpu():\n    import habana_frameworks.torch as htorch  # noqa: F401\n\n\nclass HpuCommunicator:\n\n    def __init__(self, group: ProcessGroup):\n        if not is_hpu():\n            self.disabled = True\n            return\n        self.disabled = False\n        self.group = group\n        self.world_size = dist.get_world_size(self.group)\n\n    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:\n        # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge\n        # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used\n        # (which is required for tensor parallel HPUGraph inference)\n        htorch.core.mark_step()\n        dist.all_reduce(x, group=self.group)\n        return x\n\n    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:\n        world_size = self.world_size\n        if dim < 0:\n            # Convert negative dim to positive.\n            dim += x.dim()\n        input_size = x.size()\n        # Allocate output tensor.\n        output_tensor = torch.empty(\n            (world_size,) + input_size, dtype=x.dtype, device=x.device\n        )\n        # All-gather.\n        htorch.core.mark_step()\n        dist.all_gather_into_tensor(output_tensor, x, group=self.group)\n        # Reshape\n        output_tensor = output_tensor.movedim(0, dim)\n        output_tensor = output_tensor.reshape(\n            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]\n        )\n        return output_tensor\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/mooncake_transfer_engine.py",
    "content": "import json\nimport logging\nimport os\nfrom typing import List, Optional\n\nfrom sglang.srt.environ import envs\nfrom sglang.srt.utils.network import NetworkAddress, get_free_port\n\nlogger = logging.getLogger(__name__)\n\n# Module-level shared engine instance, set by init_mooncake_transfer_engine().\n_mooncake_transfer_engine: Optional[\"MooncakeTransferEngine\"] = None\n\n\ndef get_ib_devices_for_gpu(ib_device_str: Optional[str], gpu_id: int) -> Optional[str]:\n    \"\"\"\n    Parse IB device string and get IB devices for a specific GPU ID.\n\n    Supports all the following formats:\n    1. Old format: \"ib0, ib1, ib2\"\n    2. New format: {0: \"ib0, ib1\", 1: \"ib2, ib3\", 2: \"ib4\"}\n    3. JSON file: path to a JSON file containing the mapping\n\n    Args:\n        ib_device_str: The original IB device string or path to JSON file\n        gpu_id: The GPU ID to get devices for\n\n    Returns:\n        IB devices string for the GPU, or None if not available\n    \"\"\"\n    if ib_device_str is None or not ib_device_str.strip():\n        return None\n\n    ib_device_str = ib_device_str.strip()\n\n    # Check if it's a JSON file first and load its content\n    is_json_file = ib_device_str.endswith(\".json\")\n    if is_json_file:\n        try:\n            if os.path.isfile(ib_device_str):\n                with open(ib_device_str, \"r\") as f:\n                    ib_device_str = f.read()\n            else:\n                # File doesn't exist, treat as old format\n                raise RuntimeError(f\"File {ib_device_str} does not exist.\")\n        except (IOError, OSError) as e:\n            # File reading failed, raise exception\n            raise RuntimeError(f\"Failed to read JSON file {ib_device_str}: {e}\") from e\n\n    # Check if it's JSON format (new format)\n    try:\n        parsed_json = json.loads(ib_device_str)\n        if isinstance(parsed_json, dict):\n            # Validate format - keys should be integers (or string rep), 
values should be strings\n            gpu_mapping = {}\n            for gpu_key, ib_devices in parsed_json.items():\n                if (\n                    isinstance(gpu_key, str)\n                    and gpu_key.isdigit()\n                    and isinstance(ib_devices, str)\n                ):\n                    gpu_mapping[int(gpu_key)] = ib_devices.strip()\n                elif isinstance(gpu_key, int) and isinstance(ib_devices, str):\n                    gpu_mapping[gpu_key] = ib_devices.strip()\n                else:\n                    raise ValueError(\n                        \"Invalid format: keys must be integers (or string \"\n                        \"representations of integers) and values must be strings\"\n                    )\n\n            if not gpu_mapping:\n                raise ValueError(\"No valid GPU mappings found in JSON\")\n\n            # Return devices for specific GPU\n            if gpu_id in gpu_mapping:\n                return gpu_mapping[gpu_id]\n            else:\n                raise ValueError(\n                    f\"No IB devices configured for GPU {gpu_id}. 
\"\n                    f\"Available GPUs: {list(gpu_mapping.keys())}\"\n                )\n\n    except json.JSONDecodeError:\n        if is_json_file:\n            # It was supposed to be a JSON file but failed to parse\n            raise RuntimeError(\n                f\"Failed to parse JSON content from file {ib_device_str}\"\n            )\n        # Not JSON format, treat as old format - return same devices for all GPUs\n        return ib_device_str\n\n\nclass MooncakeTransferEngine:\n    \"\"\"Shared Mooncake transfer engine for RDMA/transfer operations.\"\"\"\n\n    def __init__(\n        self,\n        hostname: str,\n        gpu_id: Optional[int] = None,\n        ib_device: Optional[str] = None,\n    ):\n        try:\n            from mooncake.engine import TransferEngine\n        except ImportError as e:\n            raise ImportError(\n                \"Please install mooncake by following the instructions at \"\n                \"https://kvcache-ai.github.io/Mooncake/getting_started/build.html \"\n                \"to run SGLang with MooncakeTransferEngine.\"\n            ) from e\n\n        self.engine = TransferEngine()\n        self.hostname = hostname\n        self.gpu_id = gpu_id if gpu_id is not None else 0\n        self.ib_device = get_ib_devices_for_gpu(ib_device, self.gpu_id)\n\n        self.initialize(\n            hostname=self.hostname,\n            device_name=self.ib_device,\n        )\n        self.session_id = NetworkAddress(\n            self.hostname, self.engine.get_rpc_port()\n        ).to_host_port_str()\n\n    def register(self, ptr, length):\n        try:\n            ret_value = self.engine.register_memory(ptr, length)\n        except Exception:\n            # Mark register as failed\n            ret_value = -1\n\n        if ret_value != 0:\n            logger.debug(\"Mooncake memory registration %s failed.\", ptr)\n\n    def deregister(self, ptr):\n        try:\n            ret_value = self.engine.unregister_memory(ptr)\n       
 except Exception:\n            # Mark deregister as failed\n            ret_value = -1\n\n        if ret_value != 0:\n            logger.debug(\"Mooncake memory deregistration %s failed.\", ptr)\n\n    def batch_register(self, ptrs: List[int], lengths: List[int]) -> int:\n        \"\"\"Batch register multiple memory regions.\"\"\"\n        try:\n            ret_value = self.engine.batch_register_memory(ptrs, lengths)\n        except Exception:\n            # Mark batch register as failed\n            ret_value = -1\n            if not hasattr(self.engine, \"batch_register_memory\"):\n                raise RuntimeError(\n                    \"Mooncake's batch register requires a newer version of \"\n                    \"mooncake-transfer-engine. Please upgrade Mooncake.\"\n                )\n\n        if ret_value != 0:\n            logger.debug(\"Mooncake batch memory registration failed.\")\n        return ret_value\n\n    def batch_deregister(self, ptrs: List[int]) -> int:\n        \"\"\"Batch deregister multiple memory regions.\"\"\"\n        try:\n            ret_value = self.engine.batch_unregister_memory(ptrs)\n        except Exception:\n            # Mark batch deregister as failed\n            ret_value = -1\n\n        if ret_value != 0:\n            logger.debug(\"Mooncake batch memory deregistration failed.\")\n        return ret_value\n\n    def initialize(\n        self,\n        hostname: str,\n        device_name: Optional[str],\n    ) -> None:\n        \"\"\"Initialize the mooncake instance.\"\"\"\n        if envs.ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE.get():\n            npu_phy_id = envs.ASCEND_NPU_PHY_ID.get()\n            if npu_phy_id == -1:\n                hostname += f\":{get_free_port()}:npu_{self.gpu_id}\"\n            else:\n                hostname += f\":{get_free_port()}:npu_{npu_phy_id}\"\n            ret_value = self.engine.initialize(\n                hostname,\n                \"P2PHANDSHAKE\",\n                \"ascend\",\n         
       device_name if device_name is not None else \"\",\n            )\n        else:\n            ret_value = self.engine.initialize(\n                hostname,\n                \"P2PHANDSHAKE\",\n                \"rdma\",\n                device_name if device_name is not None else \"\",\n            )\n        if ret_value != 0:\n            logger.error(\"Mooncake Transfer Engine initialization failed.\")\n            raise RuntimeError(\"Mooncake Transfer Engine initialization failed.\")\n\n    def transfer_sync(\n        self, session_id: str, buffer: int, peer_buffer_address: int, length: int\n    ) -> int:\n        \"\"\"Synchronously transfer data to the specified address.\"\"\"\n        try:\n            ret = self.engine.transfer_sync_write(\n                session_id, buffer, peer_buffer_address, length\n            )\n        except Exception:\n            ret = -1\n\n        if ret < 0:\n            logger.debug(\n                \"Failed to transfer data from %s to %s - %s.\",\n                buffer,\n                session_id,\n                peer_buffer_address,\n            )\n\n        return ret\n\n    def batch_transfer_sync(\n        self,\n        session_id: str,\n        buffers: List[int],\n        peer_buffer_addresses: List[int],\n        lengths: List[int],\n    ) -> int:\n        \"\"\"Synchronously transfer data to the specified addresses in batches.\"\"\"\n        try:\n            ret = self.engine.batch_transfer_sync_write(\n                session_id, buffers, peer_buffer_addresses, lengths\n            )\n        except Exception:\n            ret = -1\n            if not hasattr(self.engine, \"batch_transfer_sync_write\"):\n                raise RuntimeError(\n                    \"Mooncake's batch transfer requires mooncake-transfer-engine \"\n                    \">= 0.3.4.post2. 
Please upgrade Mooncake by \"\n                    \"'pip install mooncake-transfer-engine --upgrade'\"\n                )\n\n        if ret < 0:\n            logger.debug(\n                \"Failed to batch transfer data. Buffers: %s, Session: %s, \"\n                \"Peer addresses: %s\",\n                buffers,\n                session_id,\n                peer_buffer_addresses,\n            )\n        return ret\n\n    def get_session_id(self):\n        return self.session_id\n\n    def get_engine(self):\n        return self.engine.get_engine()\n\n    def get_ib_device(self):\n        return self.ib_device\n\n\ndef init_mooncake_transfer_engine(\n    hostname: str,\n    gpu_id: Optional[int] = None,\n    ib_device: Optional[str] = None,\n) -> MooncakeTransferEngine:\n    \"\"\"\n    Initialize the shared MooncakeTransferEngine. Note: if already\n    initialized with the same (hostname, gpu_id, ib_device), returns existing\n    instance. Call from parallel_state when model parallel is set up and\n    mooncake transfer is needed.\n    \"\"\"\n    global _mooncake_transfer_engine\n    if _mooncake_transfer_engine is not None:\n        return _mooncake_transfer_engine\n    _mooncake_transfer_engine = MooncakeTransferEngine(\n        hostname=hostname, gpu_id=gpu_id, ib_device=ib_device\n    )\n    return _mooncake_transfer_engine\n\n\ndef get_mooncake_transfer_engine() -> Optional[MooncakeTransferEngine]:\n    \"\"\"Return the shared MooncakeTransferEngine if initialized, else None.\"\"\"\n    return _mooncake_transfer_engine\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/npu_communicator.py",
    "content": "import torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup\n\nfrom sglang.srt.utils import is_npu\n\n\nclass NpuCommunicator:\n\n    def __init__(self, group: ProcessGroup):\n        if not is_npu():\n            self.disabled = True\n            return\n        self.disabled = False\n        self.group = group\n        self.world_size = dist.get_world_size(self.group)\n\n    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:\n        dist.all_reduce(x, group=self.group)\n        return x\n\n    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:\n        world_size = self.world_size\n        if dim < 0:\n            # Convert negative dim to positive.\n            dim += x.dim()\n        input_size = x.size()\n        output_size = (input_size[0] * world_size,) + input_size[1:]\n        # Allocate output tensor.\n        output_tensor = torch.empty(output_size, dtype=x.dtype, device=x.device)\n        # All-gather.\n        dist.all_gather_into_tensor(output_tensor, x, group=self.group)\n        # Reshape\n        output_tensor = output_tensor.reshape((world_size,) + input_size)\n        output_tensor = output_tensor.movedim(0, dim)\n        output_tensor = output_tensor.reshape(\n            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]\n        )\n        return output_tensor\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/pymscclpp.py",
    "content": "import bisect\nimport logging\nimport math\nimport os\nfrom contextlib import contextmanager\nfrom enum import IntEnum\nfrom typing import Optional, Union\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup, ReduceOp\n\nimport sglang.srt.distributed.device_communicators.custom_all_reduce_ops as ops\nfrom sglang.srt.utils import is_hip\n\nlogger = logging.getLogger(__name__)\n\n_is_hip = is_hip()\n\n\nclass MscclContextSelection(IntEnum):\n    MSCCL1SHOT1NODELL = 1\n    MSCCL1SHOT2NODELL = 2\n\n\ndef mscclpp_is_weak_contiguous(inp: torch.Tensor):\n    return inp.is_contiguous() or (\n        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()\n        == inp.numel() * inp.element_size()\n    )\n\n\ndef mscclpp_convert_to_bytes(size_str):\n    \"\"\"\n    Converts a human-readable size string (e.g., \"1MB\", \"2.5kb\", \"3 GB\")\n    into the equivalent number of bytes using binary units.\n\n    Args:\n        size_str (str): A string representing size with unit (KB, MB, GB).\n\n    Returns:\n        int: Number of bytes.\n    \"\"\"\n    size_str = size_str.strip().lower()\n\n    if not size_str:\n        raise ValueError(\"Empty input string\")\n\n    # Extract numeric part and unit\n    for i in range(len(size_str)):\n        if not size_str[i].isdigit() and size_str[i] != \".\":\n            break\n    num_str = size_str[:i]\n    unit = size_str[i:].strip()\n\n    try:\n        num = float(num_str)\n    except ValueError:\n        raise ValueError(f\"Invalid numeric value in '{size_str}'\")\n\n    # Conversion factors\n    if unit == \"b\":\n        return int(num)\n    elif unit == \"kb\":\n        return int(num * 1024)\n    elif unit == \"mb\":\n        return int(num * 1024 * 1024)\n    elif unit == \"gb\":\n        return int(num * 1024 * 1024 * 1024)\n    else:\n        raise ValueError(f\"Unsupported unit: {unit}, support B, KB, MB, GB only\")\n\n\ndef mscclpp_bench_time(func, 
test_niter: int = 10, warmup_niter: int = 2):\n    # warmup\n    for _ in range(warmup_niter):\n        func()\n    start_event = torch.cuda.Event(enable_timing=True)\n    end_event = torch.cuda.Event(enable_timing=True)\n    torch.cuda.synchronize()\n    dist.barrier()\n    start_event.record()\n    for _ in range(test_niter):\n        func()\n    end_event.record()\n    end_event.synchronize()\n    func_cost_us = start_event.elapsed_time(end_event) / test_niter * 1000\n    return func_cost_us\n\n\nclass PyMscclppCommunicator:\n    _SUPPORTED_WORLD_SIZES = [8, 16]\n    _MAX_BYTES = mscclpp_convert_to_bytes(os.getenv(\"SGLANG_MSCCLPP_MAX_BYTES\", \"1MB\"))\n    _SUPPORTED_DTYPE = [torch.float, torch.float16, torch.bfloat16]\n\n    # max_bytes: max supported mscclpp allreduce size\n    # in A100 mscclpp is faster than nccl only under condition of msg size smaller than1MB\n    def __init__(\n        self,\n        group: ProcessGroup,\n        device: Union[int, str, torch.device],\n        max_bytes=_MAX_BYTES,\n    ) -> None:\n        \"\"\"\n        Args:\n            group: the process group to work on. If None, it will use the\n                default process group.\n            device: the device to bind the CustomAllreduce to. If None,\n                it will be bind to f\"cuda:{local_rank}\".\n        It is the caller's responsibility to make sure each communicator\n        is bind to a unique device, and all communicators in this group\n        are in the same node.\n        \"\"\"\n        self._IS_CAPTURING = False\n        self.disabled = True\n\n        if not ops.IS_MSCCLPP_AR_AVAILABLE:\n            # disable because of missing mscclpp library\n            # e.g. 
in a non-cuda environment\n            return\n\n        self.group = group\n\n        assert (\n            dist.get_backend(group) != dist.Backend.NCCL\n        ), \"CustomAllreduce should be attached to a non-NCCL group.\"\n\n        rank = dist.get_rank(group=self.group)\n        world_size = dist.get_world_size(group=self.group)\n        if world_size == 1:\n            # No need to initialize mscclpp for single GPU case.\n            return\n\n        if world_size not in PyMscclppCommunicator._SUPPORTED_WORLD_SIZES:\n            logger.warning(\n                \"PyMscclpp is disabled due to an unsupported world\"\n                \" size: %d. Supported world sizes: %s. To silence this \"\n                \"warning, specify disable_mscclpp=True explicitly.\",\n                world_size,\n                str(PyMscclppCommunicator._SUPPORTED_WORLD_SIZES),\n            )\n            return\n\n        self.ranks = torch.distributed.get_process_group_ranks(group)\n        self.nranks_per_node = torch.cuda.device_count()\n        # for now mscclpp with stride in the communicator is not tested\n        if not (abs(self.ranks[-1] - self.ranks[0]) == world_size - 1):\n            logger.warning(\n                \"PyMscclpp is disabled due to an unsupported group %s.\"\n                \"Please ensure all ranks in the group are consecutive.\"\n                \"To silence this warning, specify disable_mscclpp=True explicitly.\",\n                str(self.ranks),\n            )\n            return\n\n        if isinstance(device, int):\n            device = torch.device(f\"cuda:{device}\")\n        elif isinstance(device, str):\n            device = torch.device(device)\n        # now `device` is a `torch.device` object\n        assert isinstance(device, torch.device)\n        self.device = device\n\n        self.max_bytes = max_bytes\n        self.rank = rank\n        self.world_size = world_size\n\n        if dist.get_rank(group) == 0:\n            unique_id = 
[ops.mscclpp_generate_unique_id()]\n        else:\n            unique_id = [None]\n        dist.broadcast_object_list(unique_id, src=self.ranks[0], group=self.group)\n        self.unique_id = unique_id[0]\n        self.rank_to_node, self.rank_to_ib = list(range(world_size)), list(\n            range(world_size)\n        )\n        for r in range(world_size):\n            self.rank_to_node[r] = r // 8\n            self.rank_to_ib[r] = self.rank % 8\n\n        self._context = None\n        self.context_selection = None\n        self.msg_size_for_finetune = [\n            2**i for i in range(10, math.floor(math.log2(self.max_bytes)) + 1)\n        ]\n        self.msg_size2best_config = {}\n        if world_size == 8:\n            self.context_selection = MscclContextSelection.MSCCL1SHOT1NODELL\n        elif world_size == 16:\n            self.context_selection = MscclContextSelection.MSCCL1SHOT2NODELL\n        if not _is_hip:\n            self.scratch = torch.empty(\n                self.max_bytes * 8,\n                dtype=torch.uint8,\n                device=self.device,\n            )\n            self.put_buffer = torch.empty(\n                self.max_bytes * 8 // self.nranks_per_node,\n                dtype=torch.uint8,\n                device=self.device,\n            )\n            self._context = ops.mscclpp_init_context(\n                self.unique_id,\n                self.rank,\n                self.world_size,\n                self.scratch,\n                self.put_buffer,\n                self.nranks_per_node,\n                self.rank_to_node,\n                self.rank_to_ib,\n                int(self.context_selection),\n            )\n        else:\n            raise NotImplementedError(\"HIP Mscclpp is not supported yet.\")\n\n        self.msg_size2best_config = {}\n        self.pre_tune_config()\n        if dist.get_rank(group) == 0:\n            msg_size2best_config = [self.msg_size2best_config]\n        else:\n            msg_size2best_config 
= [None]\n        dist.broadcast_object_list(\n            msg_size2best_config, src=self.ranks[0], group=self.group\n        )\n        self.msg_size2best_config = msg_size2best_config[0]\n\n        # PyMscclpp is enabled only in cuda graph\n        self.disabled = True\n\n    def pre_tune_config(self, dtype=torch.bfloat16) -> bool:\n        logger.debug(f\"start to pre-tune configs for rank {self.rank}\")\n        nthreads_to_try = [256, 512, 1024]\n        nblocks_to_try = [21, 42, 84]\n        inp_randn = torch.ones(\n            self.msg_size_for_finetune[-1] // dtype.itemsize, dtype=dtype, device=\"cuda\"\n        )\n        oup_randn = torch.empty_like(inp_randn)\n        for msg_size in self.msg_size_for_finetune:\n            mock_inp, mock_outp = (\n                inp_randn[: msg_size // dtype.itemsize],\n                oup_randn[: msg_size // dtype.itemsize],\n            )\n            best_config, best_time = None, None\n            for nthreads in nthreads_to_try:\n                for nblocks in nblocks_to_try:\n                    cur_cost = mscclpp_bench_time(\n                        lambda: ops.mscclpp_allreduce(\n                            self._context, mock_inp, mock_outp, nthreads, nblocks\n                        )\n                    )\n                    if best_time is None or cur_cost < best_time:\n                        best_config = (nthreads, nblocks)\n                        best_time = cur_cost\n            self.msg_size2best_config[msg_size] = best_config\n            if self.rank == 0:\n                logger.debug(\n                    f\"for msg_size {msg_size}, best_config: {best_config}, best_time: {best_time}us\"\n                )\n\n    def should_mscclpp_allreduce(\n        self, inp: torch.Tensor, op: ReduceOp = ReduceOp.SUM\n    ) -> bool:\n        if self.disabled or self._context is None:\n            return False\n        if inp.dtype not in PyMscclppCommunicator._SUPPORTED_DTYPE:\n            return False\n      
  if not mscclpp_is_weak_contiguous(inp):\n            return False\n        # only support sum op\n        if op != ReduceOp.SUM:\n            return False\n        if inp.numel() * inp.element_size() > self.max_bytes:\n            return False\n        return True\n\n    def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM):\n        if self._IS_CAPTURING:\n            if torch.cuda.is_current_stream_capturing():\n                self.graph_input_set.add((tensor.dtype, tensor.numel()))\n        msg_size = tensor.numel() * tensor.itemsize\n        index = bisect.bisect_left(self.msg_size_for_finetune, msg_size)\n        msg_size_finetune = self.msg_size_for_finetune[index]\n        nthreads, nblocks = self.msg_size2best_config[msg_size_finetune]\n        result = torch.empty_like(tensor)\n        ops.mscclpp_allreduce(self._context, tensor, result, nthreads, nblocks)\n        return result\n\n    @contextmanager\n    def change_state(\n        self,\n        enable: Optional[bool] = None,\n    ):\n        if enable is None:\n            # guess a default value when not specified\n            enable = self.available\n\n        old_disable = self.disabled\n        self.disabled = not enable\n\n        yield\n\n        self.disabled = old_disable\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/pynccl.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py\n\nimport logging\nfrom contextlib import contextmanager\nfrom typing import Optional, Union\n\n# ===================== import region =====================\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup, ReduceOp\n\nfrom sglang.srt.distributed.device_communicators.pynccl_wrapper import (\n    NCCLLibrary,\n    buffer_type,\n    cudaStream_t,\n    ncclComm_t,\n    ncclDataTypeEnum,\n    ncclRedOpTypeEnum,\n    ncclUniqueId,\n)\nfrom sglang.srt.distributed.utils import StatelessProcessGroup\nfrom sglang.srt.utils.common import get_current_device_stream_fast\n\nlogger = logging.getLogger(__name__)\n\n\nclass PyNcclCommunicator:\n\n    def __init__(\n        self,\n        group: Union[ProcessGroup, StatelessProcessGroup],\n        device: Union[int, str, torch.device],\n        library_path: Optional[str] = None,\n        use_current_stream: bool = False,\n    ):\n        \"\"\"\n        Args:\n            group: the process group to work on. If None, it will use the\n                default process group.\n            device: the device to bind the PyNcclCommunicator to. If None,\n                it will be bind to f\"cuda:{local_rank}\".\n            library_path: the path to the NCCL library. 
If None, it will\n                use the default library path.\n        It is the caller's responsibility to make sure each communicator\n        is bind to a unique device.\n        \"\"\"\n        if not isinstance(group, StatelessProcessGroup):\n            assert dist.is_initialized()\n            assert (\n                dist.get_backend(group) != dist.Backend.NCCL\n            ), \"PyNcclCommunicator should be attached to a non-NCCL group.\"\n            # note: this rank is the rank in the group\n            self.rank = dist.get_rank(group)\n            self.world_size = dist.get_world_size(group)\n        else:\n            self.rank = group.rank\n            self.world_size = group.world_size\n\n        self.group = group\n\n        # if world_size == 1, no need to create communicator\n        if self.world_size == 1:\n            self.available = False\n            self.disabled = True\n            self.stream = None\n            return\n        try:\n            self.nccl = NCCLLibrary(library_path)\n        except Exception:\n            # disable because of missing NCCL library\n            # e.g. 
in a non-GPU environment\n            self.available = False\n            self.disabled = True\n            self.stream = None\n            return\n\n        self.available = True\n        self.disabled = False\n        self.use_current_stream = use_current_stream\n\n        self.nccl_version = self.nccl.ncclGetRawVersion()\n        if self.rank == 0:\n            logger.info(\"sglang is using nccl==%s\", self.nccl.ncclGetVersion())\n\n        if self.rank == 0:\n            # get the unique id from NCCL\n            self.unique_id = self.nccl.ncclGetUniqueId()\n        else:\n            # construct an empty unique id\n            self.unique_id = ncclUniqueId()\n\n        if not isinstance(group, StatelessProcessGroup):\n            tensor = torch.ByteTensor(list(self.unique_id.internal))\n            ranks = dist.get_process_group_ranks(group)\n            # arg `src` in `broadcast` is the global rank\n            dist.broadcast(tensor, src=ranks[0], group=group)\n            byte_list = tensor.tolist()\n            for i, byte in enumerate(byte_list):\n                self.unique_id.internal[i] = byte\n        else:\n            self.unique_id = group.broadcast_obj(self.unique_id, src=0)\n        if isinstance(device, int):\n            device = torch.device(f\"cuda:{device}\")\n        elif isinstance(device, str):\n            device = torch.device(device)\n        # now `device` is a `torch.device` object\n        assert isinstance(device, torch.device)\n        self.device = device\n        # nccl communicator and stream will use this device\n        # `torch.cuda.device` is a context manager that changes the\n        # current cuda device to the specified one\n        with torch.cuda.device(device):\n            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(\n                self.world_size, self.unique_id, self.rank\n            )\n            self.stream = torch.cuda.Stream()\n\n            # A small all_reduce for warmup.\n            data = 
torch.zeros(1, device=device)\n            self.all_reduce(data)\n            self.stream.synchronize()\n            del data\n\n        # by default it is disabled, e.g. in profiling models and prefill phase.\n        # to use it, use under `with obj.change_state(enable=True)`, usually\n        # when we are using CUDA graph.\n        self.disabled = True\n\n    def _resolve_stream(self, stream: Optional[torch.cuda.Stream]):\n        \"\"\"Return the stream to use for NCCL calls.\n\n        Behavior mirrors the previous inline logic:\n        - if an explicit stream is provided, return it\n        - if stream is None and self.use_current_stream is True, return\n          torch.cuda.current_stream()\n        - otherwise return the communicator's default stream (self.stream)\n        \"\"\"\n        if stream is not None:\n            return stream\n        if self.use_current_stream:\n            return get_current_device_stream_fast()\n        return self.stream\n\n    def all_reduce(\n        self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None\n    ):\n        if self.disabled:\n            return\n        # nccl communicator created on a specific device\n        # will only work on tensors on the same device\n        # otherwise it will cause \"illegal memory access\"\n        assert tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {tensor.device}\"\n        )\n        stream = self._resolve_stream(stream)\n        self.nccl.ncclAllReduce(\n            buffer_type(tensor.data_ptr()),\n            buffer_type(tensor.data_ptr()),\n            tensor.numel(),\n            ncclDataTypeEnum.from_torch(tensor.dtype),\n            ncclRedOpTypeEnum.from_torch(op),\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    def outplace_all_reduce(\n        self,\n        in_tensor: torch.Tensor,\n        out_tensor: 
Optional[torch.Tensor] = None,\n        op: ReduceOp = ReduceOp.SUM,\n        stream=None,\n    ) -> Optional[torch.Tensor]:\n        if self.disabled:\n            return None\n        assert in_tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {in_tensor.device}\"\n        )\n\n        if out_tensor is None:\n            out_tensor = torch.empty_like(in_tensor)\n\n        stream = self._resolve_stream(stream)\n        self.nccl.ncclAllReduce(\n            buffer_type(in_tensor.data_ptr()),  # sendbuff\n            buffer_type(out_tensor.data_ptr()),  # recvbuff - DIFFERENT pointer\n            in_tensor.numel(),\n            ncclDataTypeEnum.from_torch(in_tensor.dtype),\n            ncclRedOpTypeEnum.from_torch(op),\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n        return out_tensor\n\n    def all_gather(\n        self,\n        output_tensor: torch.Tensor,\n        input_tensor: torch.Tensor,\n        stream=None,\n        sizes: Optional[list[int]] = None,\n    ):\n        if self.disabled:\n            return\n        # nccl communicator created on a specific device\n        # will only work on tensors on the same device\n        # otherwise it will cause \"illegal memory access\"\n        assert input_tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {input_tensor.device}\"\n        )\n        stream = self._resolve_stream(stream)\n\n        if sizes is not None:\n            split_offset = 0\n\n            self.nccl.ncclGroupStart()\n            for root, split_size in enumerate(sizes):\n                dst_slice = output_tensor[split_offset : split_offset + split_size]\n                self.nccl.ncclBroadcast(\n                    buffer_type(input_tensor.data_ptr()),\n                    
buffer_type(dst_slice.data_ptr()),\n                    dst_slice.numel(),\n                    ncclDataTypeEnum.from_torch(input_tensor.dtype),\n                    root,\n                    self.comm,\n                    cudaStream_t(stream.cuda_stream),\n                )\n                split_offset += split_size\n            self.nccl.ncclGroupEnd()\n        else:\n            self.nccl.ncclAllGather(\n                buffer_type(input_tensor.data_ptr()),\n                buffer_type(output_tensor.data_ptr()),\n                input_tensor.numel(),\n                ncclDataTypeEnum.from_torch(input_tensor.dtype),\n                self.comm,\n                cudaStream_t(stream.cuda_stream),\n            )\n\n    def cp_all_gather_into_tensor(\n        self,\n        output_tensor: torch.Tensor,\n        input_tensor: torch.Tensor,\n        stream=None,\n        sizes: Optional[list[int]] = None,\n    ):\n        \"\"\"\n        Currently, it is mainly used in context parallelism,\n        primarily leveraging pynccl to implement non-blocking allgather communication.\n        \"\"\"\n        # nccl communicator created on a specific device\n        # will only work on tensors on the same device\n        # otherwise it will cause \"illegal memory access\"\n        assert input_tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {input_tensor.device}\"\n        )\n        stream = self._resolve_stream(stream)\n        self.nccl.ncclAllGather(\n            buffer_type(input_tensor.data_ptr()),\n            buffer_type(output_tensor.data_ptr()),\n            input_tensor.numel(),\n            ncclDataTypeEnum.from_torch(input_tensor.dtype),\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    def reduce_scatter(\n        self,\n        output_tensor: torch.Tensor,\n        input_tensor: torch.Tensor,\n        op: ReduceOp = 
ReduceOp.SUM,\n        stream=None,\n        sizes: Optional[list[int]] = None,\n    ):\n        if self.disabled:\n            return\n        # nccl communicator created on a specific device\n        # will only work on tensors on the same device\n        # otherwise it will cause \"illegal memory access\"\n        assert input_tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {input_tensor.device}\"\n        )\n        stream = self._resolve_stream(stream)\n\n        if sizes is not None:\n            split_offset = 0\n            self.nccl.ncclGroupStart()\n            for root, split_size in enumerate(sizes):\n                chunk = input_tensor[split_offset : split_offset + split_size, ...]\n\n                self.nccl.ncclReduce(\n                    buffer_type(chunk.data_ptr()),\n                    buffer_type(output_tensor.data_ptr()),\n                    chunk.numel(),\n                    ncclDataTypeEnum.from_torch(input_tensor.dtype),\n                    ncclRedOpTypeEnum.from_torch(op),\n                    root,\n                    self.comm,\n                    cudaStream_t(stream.cuda_stream),\n                )\n                split_offset += split_size\n            self.nccl.ncclGroupEnd()\n        else:\n            self.nccl.ncclReduceScatter(\n                buffer_type(input_tensor.data_ptr()),\n                buffer_type(output_tensor.data_ptr()),\n                output_tensor.numel(),\n                ncclDataTypeEnum.from_torch(input_tensor.dtype),\n                ncclRedOpTypeEnum.from_torch(op),\n                self.comm,\n                cudaStream_t(stream.cuda_stream),\n            )\n\n    def send(self, tensor: torch.Tensor, dst: int, stream=None):\n        if self.disabled:\n            return\n        assert tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, 
\"\n            f\"but the input tensor is on {tensor.device}\"\n        )\n        stream = self._resolve_stream(stream)\n        self.nccl.ncclSend(\n            buffer_type(tensor.data_ptr()),\n            tensor.numel(),\n            ncclDataTypeEnum.from_torch(tensor.dtype),\n            dst,\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    def recv(self, tensor: torch.Tensor, src: int, stream=None):\n        if self.disabled:\n            return\n        assert tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {tensor.device}\"\n        )\n        stream = self._resolve_stream(stream)\n        self.nccl.ncclRecv(\n            buffer_type(tensor.data_ptr()),\n            tensor.numel(),\n            ncclDataTypeEnum.from_torch(tensor.dtype),\n            src,\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):\n        if self.disabled:\n            return\n        assert tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {tensor.device}\"\n        )\n        stream = self._resolve_stream(stream)\n\n        if src == self.rank:\n            sendbuff = buffer_type(tensor.data_ptr())\n            # NCCL requires the sender also to have a receive buffer\n            recvbuff = buffer_type(tensor.data_ptr())\n        else:\n            sendbuff = buffer_type()\n            recvbuff = buffer_type(tensor.data_ptr())\n        self.nccl.ncclBroadcast(\n            sendbuff,\n            recvbuff,\n            tensor.numel(),\n            ncclDataTypeEnum.from_torch(tensor.dtype),\n            src,\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    def register_comm_window_raw(self, ptr: 
int, size: int):\n        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)\n\n    def deregister_comm_window(self, window):\n        return self.nccl.ncclCommWindowDeregister(self.comm, window)\n\n    def group_start(self):\n        self.nccl.ncclGroupStart()\n\n    def group_end(self):\n        self.nccl.ncclGroupEnd()\n\n    @contextmanager\n    def change_state(\n        self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None\n    ):\n        \"\"\"\n        A context manager to change the state of the communicator.\n        \"\"\"\n        if enable is None:\n            # guess a default value when not specified\n            enable = self.available\n\n        if stream is None:\n            stream = self.stream\n\n        old_disable = self.disabled\n        old_stream = self.stream\n\n        self.stream = stream\n        self.disabled = not enable\n        yield\n\n        self.disabled = old_disable\n        self.stream = old_stream\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/pynccl_allocator.py",
    "content": "import os\nimport tempfile\nfrom contextlib import nullcontext\n\nimport torch\nfrom torch.cuda.memory import CUDAPluggableAllocator\n\nfrom sglang.srt.distributed.parallel_state import GroupCoordinator\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.utils.common import torch_release\n\nafter_2_8_0 = torch_release >= (2, 8)\n\nnccl_allocator_source = \"\"\"\n\n#include <cuda_runtime.h>\n\nextern \"C\" {\n\n// copy from https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in\ntypedef enum { ncclSuccess                 =  0,\n               ncclUnhandledCudaError      =  1,\n               ncclSystemError             =  2,\n               ncclInternalError           =  3,\n               ncclInvalidArgument         =  4,\n               ncclInvalidUsage            =  5,\n               ncclRemoteError             =  6,\n               ncclInProgress              =  7,\n               ncclNumResults              =  8 } ncclResult_t;\ntypedef struct ncclComm* ncclComm_t;\ntypedef struct ncclWindow_vidmem* ncclWindow_t;\nncclResult_t  ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags);\n#define NCCL_WIN_COLL_SYMMETRIC 0x01\n\nncclResult_t  ncclMemAlloc(void** ptr, size_t size);\nncclResult_t  ncclMemFree(void *ptr);\nconst char*  ncclGetErrorString(ncclResult_t result);\n\n#define NCCLCHECK(cmd) do {                                               \\\n  ncclResult_t res = cmd;                                                 \\\n  if (res != ncclSuccess) {                                               \\\n    fprintf(stderr, \"ERROR: NCCL symmetric memory allocation failed. Most likely out of device memory. 
'%s'\\\\n\", \\\n           ncclGetErrorString(res));                       \\\n    return NULL;                                                        \\\n  }                                                                       \\\n} while(0)\n\nvoid* nccl_alloc_plug(size_t size, int device, void* stream) {\n  void* ptr;\n  NCCLCHECK(ncclMemAlloc(&ptr, size));\n\n  const char *str_val = getenv(\"SGLANG_TMP_NCCL_COMM_VALUE\");\n  char *endptr;\n  void* int_val = (void *)strtoull(str_val, &endptr, 0);\n\n  ncclComm_t comm = (ncclComm_t)(int_val);\n  ncclWindow_t win;\n  NCCLCHECK(ncclCommWindowRegister(comm, ptr, size, &win, NCCL_WIN_COLL_SYMMETRIC));\n\n  return ptr;\n}\n\nvoid nccl_free_plug(void* ptr, size_t size, int device, void* stream) {\n  ncclResult_t err = ncclMemFree(ptr);\n}\n\n}\n\"\"\"\n\n_allocator = None\n_mem_pool = None\n_graph_pool_id = None\n_cur_device = None\n_active_symmetric_memory_context = None\n\n\ndef is_symmetric_memory_enabled():\n    try:\n        return get_global_server_args().enable_symm_mem\n    except ValueError:\n        return False\n\n\ndef set_graph_pool_id(graph_pool_id):\n    global _graph_pool_id\n    _graph_pool_id = graph_pool_id\n\n\ndef disable_symmetric_memory_context():\n    if _active_symmetric_memory_context is None:\n        return None\n    saved_context = _active_symmetric_memory_context\n    saved_context.__exit__(None, None, None)\n    return saved_context\n\n\ndef restore_symmetric_memory_context(saved_context):\n    if saved_context is not None:\n        saved_context.__enter__()\n\n\ndef get_nccl_mem_pool():\n    global _allocator, _mem_pool, _cur_device\n    if _mem_pool is None:\n        import torch.utils.cpp_extension\n\n        out_dir = os.path.join(tempfile.gettempdir(), \"symm_allocator\")\n        os.makedirs(out_dir, exist_ok=True)\n        # Make sure to clean up leftover pytorch lock files\n        # from previous runs and synchronize across processes\n        # right after\n        try:\n       
     os.remove(os.path.join(out_dir, \"lock\"))\n        except FileNotFoundError:\n            pass\n        torch.distributed.barrier()\n\n        nccl_allocator_libname = \"nccl_allocator\"\n        torch.utils.cpp_extension.load_inline(\n            name=nccl_allocator_libname,\n            cpp_sources=nccl_allocator_source,\n            with_cuda=True,\n            extra_ldflags=[\"-lnccl\"],\n            verbose=True,\n            is_python_module=False,\n            build_directory=out_dir,\n        )\n        _allocator = CUDAPluggableAllocator(\n            f\"{out_dir}/{nccl_allocator_libname}.so\",\n            \"nccl_alloc_plug\",\n            \"nccl_free_plug\",\n        ).allocator()\n        _mem_pool = torch.cuda.MemPool(_allocator)\n        _cur_device = torch.cuda.current_device()\n    return _mem_pool\n\n\nclass SymmetricMemoryContext:\n    \"\"\"\n    Context manager for using symmetric memory with pynccl.\n\n    To Utilize the symmetric memory feature in NCCL, the buffers need to be allocated\n    by `ncclMemAlloc` and registered by `ncclCommWindowRegister`. Due to this, we introduce\n    this context manager. 
All tensors created under this context will be correctly\n    allocated and registered with a custom allocator.\n    \"\"\"\n\n    def __init__(\n        self,\n        group_coordinator: GroupCoordinator,\n    ):\n        self.group_coordinator = group_coordinator\n        self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())\n        self.is_graph_capture = torch.cuda.is_current_stream_capturing()\n        self.exited = False\n\n    def __enter__(self):\n        assert (\n            self.group_coordinator.pynccl_comm is not None\n        ), f\"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'\"\n\n        if self.is_graph_capture:\n            assert (\n                _graph_pool_id is not None\n            ), \"graph_pool_id is not set under graph capture\"\n            # Pause graph memory pool to use symmetric memory with cuda graph\n            if after_2_8_0:\n                torch._C._cuda_endAllocateToPool(_cur_device, _graph_pool_id)\n            else:\n                torch._C._cuda_endAllocateCurrentStreamToPool(\n                    _cur_device, _graph_pool_id\n                )\n\n        if self.exited:\n            # mempool ctx (@contextlib.contextmanager) is not re-entrant\n            self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())\n            self.exited = False\n        self._mem_pool_ctx.__enter__()\n\n        # Set the env var to pass this argument to the C functions.\n        os.environ[\"SGLANG_TMP_NCCL_COMM_VALUE\"] = str(\n            self.group_coordinator.pynccl_comm.comm.value\n        )\n\n        global _active_symmetric_memory_context\n        _active_symmetric_memory_context = self\n\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)\n\n        if self.is_graph_capture:\n            if after_2_8_0:\n                torch._C._cuda_beginAllocateCurrentThreadToPool(\n      
              _cur_device, _graph_pool_id\n                )\n            else:\n                torch._C._cuda_beginAllocateToPool(_cur_device, _graph_pool_id)\n\n        global _active_symmetric_memory_context\n        _active_symmetric_memory_context = None\n\n        self.exited = True\n\n\ndef use_symmetric_memory(group_coordinator: GroupCoordinator, disabled: bool = False):\n    disabled = (\n        not is_symmetric_memory_enabled()\n        or disabled\n        or group_coordinator.world_size == 1\n    )\n    return SymmetricMemoryContext(group_coordinator) if not disabled else nullcontext()\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py\n\n# This file is a pure Python wrapper for the NCCL library.\n# The main purpose is to use NCCL combined with CUDA graph.\n# Before writing this script, we tried the following approaches:\n# 1. We tried to use `cupy`. It calls NCCL correctly, but `cupy` itself\n#  often gets stuck when initializing the NCCL communicator.\n# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`\n#  may invoke other CUDA APIs that are not allowed during\n#  CUDA graph capture. For further details, please check\n# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .\n#\n# Another rejected idea is to write a C/C++ binding for NCCL. It is usually\n# doable, but we often encounter issues related to NCCL versions, and need\n# to switch between different versions of NCCL. See\n# https://github.com/NVIDIA/nccl/issues/1234 for more details.\n# A C/C++ binding is not flexible enough to handle this. It requires\n# recompilation of the code every time we want to switch between different\n# versions. This current implementation, with a **pure** Python wrapper, is\n# more flexible. 
We can easily switch between different versions of NCCL by\n# changing the environment variable `SGLANG_NCCL_SO_PATH`, or the `so_file`\n# variable in the code.\n\nimport ctypes\nimport logging\nimport os\nimport platform\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional\n\nimport torch\nfrom torch.distributed import ReduceOp\n\nlogger = logging.getLogger(__name__)\n\n\ndef find_nccl_library() -> str:\n    \"\"\"\n    We either use the library file specified by the `SGLANG_NCCL_SO_PATH`\n    environment variable, or we find the library file brought by PyTorch.\n    After importing `torch`, `libnccl.so.2`, `librccl.so.1` or `libmccl.so.2`\n    can be found by `ctypes` automatically.\n    \"\"\"\n\n    # so_file can be set to None in sglang\n    so_file = os.environ.get(\"SGLANG_NCCL_SO_PATH\", None)\n\n    # manually load the nccl library\n    if so_file:\n        logger.info(\n            \"Found nccl from environment variable SGLANG_NCCL_SO_PATH=%s\", so_file\n        )\n    else:\n        if torch.version.cuda is not None:\n            so_file = \"libnccl.so.2\"\n        elif torch.version.hip is not None:\n            so_file = \"librccl.so.1\"\n        elif hasattr(torch.version, \"musa\") and torch.version.musa is not None:\n            so_file = \"libmccl.so.2\"\n        else:\n            raise ValueError(\"NCCL only supports CUDA, ROCm and MUSA backends.\")\n        logger.debug(\"Found nccl from library %s\", so_file)\n    return so_file\n\n\n# === export types and functions from nccl to Python ===\n# for the original nccl definition, please check\n# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in\n\nncclResult_t = ctypes.c_int\nncclComm_t = ctypes.c_void_p\nncclWindow_t = ctypes.c_void_p\n\n\nclass ncclUniqueId(ctypes.Structure):\n    _fields_ = [(\"internal\", ctypes.c_byte * 128)]\n\n\ncudaStream_t = ctypes.c_void_p\nbuffer_type = ctypes.c_void_p\n\nncclDataType_t = ctypes.c_int\n\n\nclass ncclDataTypeEnum:\n    
ncclInt8 = 0\n    ncclChar = 0\n    ncclUint8 = 1\n    ncclInt32 = 2\n    ncclInt = 2\n    ncclUint32 = 3\n    ncclInt64 = 4\n    ncclUint64 = 5\n    ncclFloat16 = 6\n    ncclHalf = 6\n    ncclFloat32 = 7\n    ncclFloat = 7\n    ncclFloat64 = 8\n    ncclDouble = 8\n    ncclBfloat16 = 9\n    ncclNumTypes = 10\n\n    @classmethod\n    def from_torch(cls, dtype: torch.dtype) -> int:\n        if dtype == torch.int8:\n            return cls.ncclInt8\n        if dtype == torch.uint8:\n            return cls.ncclUint8\n        if dtype == torch.int32:\n            return cls.ncclInt32\n        if dtype == torch.int64:\n            return cls.ncclInt64\n        if dtype == torch.float16:\n            return cls.ncclFloat16\n        if dtype == torch.float32:\n            return cls.ncclFloat32\n        if dtype == torch.float64:\n            return cls.ncclFloat64\n        if dtype == torch.bfloat16:\n            return cls.ncclBfloat16\n        raise ValueError(f\"Unsupported dtype: {dtype}\")\n\n\nncclRedOp_t = ctypes.c_int\n\n\nclass ncclRedOpTypeEnum:\n    ncclSum = 0\n    ncclProd = 1\n    ncclMax = 2\n    ncclMin = 3\n    ncclAvg = 4\n    ncclNumOps = 5\n\n    @classmethod\n    def from_torch(cls, op: ReduceOp) -> int:\n        if op == ReduceOp.SUM:\n            return cls.ncclSum\n        if op == ReduceOp.PRODUCT:\n            return cls.ncclProd\n        if op == ReduceOp.MAX:\n            return cls.ncclMax\n        if op == ReduceOp.MIN:\n            return cls.ncclMin\n        if op == ReduceOp.AVG:\n            return cls.ncclAvg\n        raise ValueError(f\"Unsupported op: {op}\")\n\n\n@dataclass\nclass Function:\n    name: str\n    restype: Any\n    argtypes: List[Any]\n\n\nclass NCCLLibrary:\n    exported_functions = [\n        # const char* ncclGetErrorString(ncclResult_t result)\n        Function(\"ncclGetErrorString\", ctypes.c_char_p, [ncclResult_t]),\n        # ncclResult_t  ncclGetVersion(int *version);\n        Function(\"ncclGetVersion\", 
ncclResult_t, [ctypes.POINTER(ctypes.c_int)]),\n        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);\n        Function(\"ncclGetUniqueId\", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]),\n        # ncclResult_t  ncclCommInitRank(\n        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);\n        # note that ncclComm_t is a pointer type, so the first argument\n        # is a pointer to a pointer\n        Function(\n            \"ncclCommInitRank\",\n            ncclResult_t,\n            [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int],\n        ),\n        # ncclResult_t  ncclAllReduce(\n        #   const void* sendbuff, void* recvbuff, size_t count,\n        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,\n        #   cudaStream_t stream);\n        # note that cudaStream_t is a pointer type, so the last argument\n        # is a pointer\n        Function(\n            \"ncclAllReduce\",\n            ncclResult_t,\n            [\n                buffer_type,\n                buffer_type,\n                ctypes.c_size_t,\n                ncclDataType_t,\n                ncclRedOp_t,\n                ncclComm_t,\n                cudaStream_t,\n            ],\n        ),\n        # ncclResult_t  ncclAllGather(\n        #   const void* sendbuff, void* recvbuff, size_t count,\n        #   ncclDataType_t datatype, ncclComm_t comm,\n        #   cudaStream_t stream);\n        # note that cudaStream_t is a pointer type, so the last argument\n        # is a pointer\n        Function(\n            \"ncclAllGather\",\n            ncclResult_t,\n            [\n                buffer_type,\n                buffer_type,\n                ctypes.c_size_t,\n                ncclDataType_t,\n                ncclComm_t,\n                cudaStream_t,\n            ],\n        ),\n        # ncclResult_t  ncclReduce(\n        #   const void* sendbuff, void* recvbuff, size_t count,\n        #   ncclDataType_t datatype, 
ncclRedOp_t op, int root,\n        #   ncclComm_t comm,  cudaStream_t stream);\n        # note that cudaStream_t is a pointer type, so the last argument\n        # is a pointer\n        Function(\n            \"ncclReduce\",\n            ncclResult_t,\n            [\n                buffer_type,\n                buffer_type,\n                ctypes.c_size_t,\n                ncclDataType_t,\n                ncclRedOp_t,\n                ctypes.c_int,\n                ncclComm_t,\n                cudaStream_t,\n            ],\n        ),\n        # ncclResult_t  ncclReduceScatter(\n        #   const void* sendbuff, void* recvbuff, size_t count,\n        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,\n        #   cudaStream_t stream);\n        # note that cudaStream_t is a pointer type, so the last argument\n        # is a pointer\n        Function(\n            \"ncclReduceScatter\",\n            ncclResult_t,\n            [\n                buffer_type,\n                buffer_type,\n                ctypes.c_size_t,\n                ncclDataType_t,\n                ncclRedOp_t,\n                ncclComm_t,\n                cudaStream_t,\n            ],\n        ),\n        # ncclResult_t  ncclSend(\n        #   const void* sendbuff, size_t count, ncclDataType_t datatype,\n        #   int dest, ncclComm_t comm, cudaStream_t stream);\n        Function(\n            \"ncclSend\",\n            ncclResult_t,\n            [\n                buffer_type,\n                ctypes.c_size_t,\n                ncclDataType_t,\n                ctypes.c_int,\n                ncclComm_t,\n                cudaStream_t,\n            ],\n        ),\n        # ncclResult_t  ncclRecv(\n        #   void* recvbuff, size_t count, ncclDataType_t datatype,\n        #   int src, ncclComm_t comm, cudaStream_t stream);\n        Function(\n            \"ncclRecv\",\n            ncclResult_t,\n            [\n                buffer_type,\n                ctypes.c_size_t,\n          
      ncclDataType_t,\n                ctypes.c_int,\n                ncclComm_t,\n                cudaStream_t,\n            ],\n        ),\n        # ncclResult_t ncclBroadcast(\n        #   const void* sendbuff, void* recvbuff, size_t count,\n        #   ncclDataType_t datatype, int root, ncclComm_t comm,\n        #   cudaStream_t stream);\n        Function(\n            \"ncclBroadcast\",\n            ncclResult_t,\n            [\n                buffer_type,\n                buffer_type,\n                ctypes.c_size_t,\n                ncclDataType_t,\n                ctypes.c_int,\n                ncclComm_t,\n                cudaStream_t,\n            ],\n        ),\n        # be cautious! this is a collective call, it will block until all\n        # processes in the communicator have called this function.\n        # because Python object destruction can happen in random order,\n        # it is better not to call it at all.\n        # ncclResult_t  ncclCommDestroy(ncclComm_t comm);\n        Function(\"ncclCommDestroy\", ncclResult_t, [ncclComm_t]),\n        # ncclResult_t ncclGroupStart();\n        Function(\"ncclGroupStart\", ncclResult_t, []),\n        # ncclResult_t ncclGroupEnd();\n        Function(\"ncclGroupEnd\", ncclResult_t, []),\n    ]\n\n    exported_functions_symm_mem = [\n        # ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags);\n        Function(\n            \"ncclCommWindowRegister\",\n            ncclResult_t,\n            [\n                ncclComm_t,\n                buffer_type,\n                ctypes.c_size_t,\n                ctypes.POINTER(ncclWindow_t),\n                ctypes.c_int,\n            ],\n        ),\n        # ncclResult_t ncclCommWindowDeregister(ncclComm_t comm, ncclWindow_t win);\n        Function(\"ncclCommWindowDeregister\", ncclResult_t, [ncclComm_t, ncclWindow_t]),\n    ]\n\n    # class attribute to store the mapping from the path to the library\n   
 # to avoid loading the same library multiple times\n    path_to_library_cache: Dict[str, Any] = {}\n\n    # class attribute to store the mapping from library path\n    #  to the corresponding dictionary\n    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}\n\n    def __init__(self, so_file: Optional[str] = None):\n\n        so_file = so_file or find_nccl_library()\n\n        try:\n            if so_file not in NCCLLibrary.path_to_dict_mapping:\n                lib = ctypes.CDLL(so_file)\n                NCCLLibrary.path_to_library_cache[so_file] = lib\n            self.lib = NCCLLibrary.path_to_library_cache[so_file]\n        except Exception as e:\n            logger.error(\n                \"Failed to load NCCL library from %s . \"\n                \"This is expected if you are not running on NVIDIA/AMD/MTHREADS GPUs. \"\n                \"Otherwise, the nccl library might not exist, be corrupted \"\n                \"or it does not support the current platform %s. \"\n                \"If you already have the library, please set the \"\n                \"environment variable SGLANG_NCCL_SO_PATH\"\n                \" to point to the correct nccl library path.\",\n                so_file,\n                platform.platform(),\n            )\n            raise e\n\n        if so_file not in NCCLLibrary.path_to_dict_mapping:\n            _funcs: Dict[str, Any] = {}\n            # copy the list so that extending it below does not mutate the\n            # shared class attribute across different library paths\n            exported_functions = list(NCCLLibrary.exported_functions)\n            if hasattr(self.lib, \"ncclCommWindowRegister\"):\n                exported_functions.extend(NCCLLibrary.exported_functions_symm_mem)\n            for func in exported_functions:\n                f = getattr(self.lib, func.name)\n                f.restype = func.restype\n                f.argtypes = func.argtypes\n                _funcs[func.name] = f\n            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs\n        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]\n\n    def ncclGetErrorString(self, result: 
ncclResult_t) -> str:\n        return self._funcs[\"ncclGetErrorString\"](result).decode(\"utf-8\")\n\n    def NCCL_CHECK(self, result: ncclResult_t) -> None:\n        if result != 0:\n            error_str = self.ncclGetErrorString(result)\n            raise RuntimeError(f\"NCCL error: {error_str}\")\n\n    def ncclGetRawVersion(self) -> int:\n        version = ctypes.c_int()\n        self.NCCL_CHECK(self._funcs[\"ncclGetVersion\"](ctypes.byref(version)))\n        # something like 21903\n        return version.value\n\n    def ncclGetVersion(self) -> str:\n        version_str = str(self.ncclGetRawVersion())\n        # something like 21903 --> \"2.19.3\"\n        major = version_str[0].lstrip(\"0\")\n        minor = version_str[1:3].lstrip(\"0\")\n        patch = version_str[3:].lstrip(\"0\")\n        return f\"{major}.{minor}.{patch}\"\n\n    def ncclGetUniqueId(self) -> ncclUniqueId:\n        unique_id = ncclUniqueId()\n        self.NCCL_CHECK(self._funcs[\"ncclGetUniqueId\"](ctypes.byref(unique_id)))\n        return unique_id\n\n    def ncclCommInitRank(\n        self, world_size: int, unique_id: ncclUniqueId, rank: int\n    ) -> ncclComm_t:\n        comm = ncclComm_t()\n        self.NCCL_CHECK(\n            self._funcs[\"ncclCommInitRank\"](\n                ctypes.byref(comm), world_size, unique_id, rank\n            )\n        )\n        return comm\n\n    def ncclAllReduce(\n        self,\n        sendbuff: buffer_type,\n        recvbuff: buffer_type,\n        count: int,\n        datatype: int,\n        op: int,\n        comm: ncclComm_t,\n        stream: cudaStream_t,\n    ) -> None:\n        # `datatype` actually should be `ncclDataType_t`\n        # and `op` should be `ncclRedOp_t`\n        # both are aliases of `ctypes.c_int`\n        # when we pass int to a function, it will be converted to `ctypes.c_int`\n        # by ctypes automatically\n        self.NCCL_CHECK(\n            self._funcs[\"ncclAllReduce\"](\n                sendbuff, recvbuff, count, 
datatype, op, comm, stream\n            )\n        )\n\n    def ncclReduce(\n        self,\n        sendbuff: buffer_type,\n        recvbuff: buffer_type,\n        count: int,\n        datatype: int,\n        op: int,\n        root: int,\n        comm: ncclComm_t,\n        stream: cudaStream_t,\n    ) -> None:\n        # `datatype` actually should be `ncclDataType_t`\n        # and `op` should be `ncclRedOp_t`\n        # both are aliases of `ctypes.c_int`\n        # when we pass int to a function, it will be converted to `ctypes.c_int`\n        # by ctypes automatically\n        self.NCCL_CHECK(\n            self._funcs[\"ncclReduce\"](\n                sendbuff, recvbuff, count, datatype, op, root, comm, stream\n            )\n        )\n\n    def ncclReduceScatter(\n        self,\n        sendbuff: buffer_type,\n        recvbuff: buffer_type,\n        count: int,\n        datatype: int,\n        op: int,\n        comm: ncclComm_t,\n        stream: cudaStream_t,\n    ) -> None:\n        # `datatype` actually should be `ncclDataType_t`\n        # and `op` should be `ncclRedOp_t`\n        # both are aliases of `ctypes.c_int`\n        # when we pass int to a function, it will be converted to `ctypes.c_int`\n        # by ctypes automatically\n        self.NCCL_CHECK(\n            self._funcs[\"ncclReduceScatter\"](\n                sendbuff, recvbuff, count, datatype, op, comm, stream\n            )\n        )\n\n    def ncclAllGather(\n        self,\n        sendbuff: buffer_type,\n        recvbuff: buffer_type,\n        count: int,\n        datatype: int,\n        comm: ncclComm_t,\n        stream: cudaStream_t,\n    ) -> None:\n        # `datatype` actually should be `ncclDataType_t`\n        # which is an alias of `ctypes.c_int`\n        # when we pass int to a function, it will be converted to `ctypes.c_int`\n        # by ctypes automatically\n        self.NCCL_CHECK(\n            self._funcs[\"ncclAllGather\"](\n                sendbuff, recvbuff, count, 
datatype, comm, stream\n            )\n        )\n\n    def ncclSend(\n        self,\n        sendbuff: buffer_type,\n        count: int,\n        datatype: int,\n        dest: int,\n        comm: ncclComm_t,\n        stream: cudaStream_t,\n    ) -> None:\n        self.NCCL_CHECK(\n            self._funcs[\"ncclSend\"](sendbuff, count, datatype, dest, comm, stream)\n        )\n\n    def ncclRecv(\n        self,\n        recvbuff: buffer_type,\n        count: int,\n        datatype: int,\n        src: int,\n        comm: ncclComm_t,\n        stream: cudaStream_t,\n    ) -> None:\n        self.NCCL_CHECK(\n            self._funcs[\"ncclRecv\"](recvbuff, count, datatype, src, comm, stream)\n        )\n\n    def ncclBroadcast(\n        self,\n        sendbuff: buffer_type,\n        recvbuff: buffer_type,\n        count: int,\n        datatype: int,\n        root: int,\n        comm: ncclComm_t,\n        stream: cudaStream_t,\n    ) -> None:\n        self.NCCL_CHECK(\n            self._funcs[\"ncclBroadcast\"](\n                sendbuff, recvbuff, count, datatype, root, comm, stream\n            )\n        )\n\n    def ncclCommDestroy(self, comm: ncclComm_t) -> None:\n        self.NCCL_CHECK(self._funcs[\"ncclCommDestroy\"](comm))\n\n    def ncclCommWindowRegister(\n        self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int\n    ) -> ncclWindow_t:\n        window = ncclWindow_t()\n        self.NCCL_CHECK(\n            self._funcs[\"ncclCommWindowRegister\"](\n                comm, buff, size, ctypes.byref(window), win_flags\n            )\n        )\n        return window\n\n    def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:\n        self.NCCL_CHECK(self._funcs[\"ncclCommWindowDeregister\"](comm, window))\n\n    def ncclGroupStart(self) -> None:\n        self.NCCL_CHECK(self._funcs[\"ncclGroupStart\"]())\n\n    def ncclGroupEnd(self) -> None:\n        self.NCCL_CHECK(self._funcs[\"ncclGroupEnd\"]())\n\n\n__all__ = 
[\n    \"NCCLLibrary\",\n    \"ncclDataTypeEnum\",\n    \"ncclRedOpTypeEnum\",\n    \"ncclUniqueId\",\n    \"ncclComm_t\",\n    \"cudaStream_t\",\n    \"buffer_type\",\n]\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/quick_all_reduce.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\nimport logging\nimport os\nfrom enum import Enum\nfrom functools import cache\nfrom typing import Union\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup\n\nimport sglang.srt.distributed.device_communicators.custom_all_reduce_ops as ops\nfrom sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (\n    is_full_nvlink,\n    is_weak_contiguous,\n)\nfrom sglang.srt.distributed.parallel_state import in_the_same_node_as\nfrom sglang.srt.utils import is_cuda, is_hip\n\nlogger = logging.getLogger(__name__)\n\n_is_cuda = is_cuda()\n_is_hip = is_hip()\n\n\n@cache\ndef qr_rocm_arch_available():\n    if not _is_hip:\n        return False\n    try:\n        props = torch.cuda.get_device_properties(0)\n        gcn_arch = getattr(props, \"gcnArchName\", \"\")\n        supported_archs = [\"gfx94\", \"gfx95\"]\n        return any(gfx in gcn_arch for gfx in supported_archs)\n    except Exception as e:\n        logger.warning(\"Failed to determine ROCm for quick allreduce: %s\", e)\n        return False\n\n\nclass QuickReduceRegime(Enum):\n    FP = 0\n    INT8 = 1\n    INT6 = 2\n    INT4 = 3\n    NONE = 4\n\n\nMB = 1024 * 1024\n\n\nclass QuickAllReduce:\n\n    _SUPPORTED_WORLD_SIZES = [2, 4, 8]\n    _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]\n    # The following data is based on kernel tests.\n    # In this order [FP, INT8, INT6, INT4].\n    _QR_MIN_SIZE = {\n        (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],\n        (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB],\n        (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],\n        (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],\n        (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB],\n        (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],\n    }\n\n    def __init__(\n        self, group: ProcessGroup, device: Union[int, str, torch.device]\n    ) 
-> None:\n        \"\"\"\n        Custom allreduce provides non-destructive acceleration and is\n        available for CUDA and ROCm MI300 series.\n        Custom quick allreduce leverages quantization for further\n        acceleration on ROCm. It currently supports Q8, Q6, and Q4\n        quantization formats and FP (float16, bfloat16).\n        Quick allreduce is designed as a complement to custom allreduce.\n        Its initialization requires even stricter conditions.\n        Only the ROCm MI300 series is supported for quick allreduce at\n        this time.\n        Args:\n            group: the process group to work on. If None, it will use the\n                default process group.\n            device: the device to bind the QuickAllReduce to. If None,\n                it will be bound to f\"cuda:{local_rank}\".\n        It is the caller's responsibility to make sure each communicator\n        is bound to a unique device, and all communicators in this group\n        are in the same node.\n        \"\"\"\n        self.disabled = True\n        if not qr_rocm_arch_available():\n            logger.debug(\n                \"Custom quick allreduce is only supported on ROCm MI300 series.\"\n            )\n            return\n\n        if not ops.IS_QUICK_AR_AVAILABLE:\n            # disable because of missing quick reduce library\n            # e.g. 
in a cuda environment\n            logger.info(\n                \"Custom quick allreduce is disabled because \"\n                \"of missing custom quick allreduce library\"\n            )\n            return\n\n        self.group = group\n        assert (\n            dist.get_backend(group) != dist.Backend.NCCL\n        ), \"Custom quick allreduce should be attached to a non-NCCL group.\"\n        if not all(in_the_same_node_as(group, source_rank=0)):\n            # No need to initialize custom quick allreduce for\n            # multi-node case.\n            logger.warning(\n                \"Custom quick allreduce is disabled because this \"\n                \"process group spans across nodes.\"\n            )\n            return\n        rank = dist.get_rank(group=self.group)\n        world_size = dist.get_world_size(group=self.group)\n        self.rank = rank\n        self.world_size = world_size\n        if world_size == 1:\n            # No need to initialize QuickReduce for single GPU case.\n            return\n\n        if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:\n            logger.warning(\n                \"Custom quick allreduce is disabled due to an \"\n                \"unsupported world size: %d. 
Supported world sizes: %s.\",\n                world_size,\n                str(QuickAllReduce._SUPPORTED_WORLD_SIZES),\n            )\n            return\n\n        if isinstance(device, int):\n            device = torch.device(f\"cuda:{device}\")\n        elif isinstance(device, str):\n            device = torch.device(device)\n        assert isinstance(device, torch.device)\n        self.device = device\n\n        cuda_visible_devices = os.environ.get(\"CUDA_VISIBLE_DEVICES\", None)\n        if cuda_visible_devices:\n            device_ids = list(map(int, cuda_visible_devices.split(\",\")))\n        else:\n            device_ids = list(range(torch.cuda.device_count()))\n        physical_device_id = device_ids[device.index]\n        tensor = torch.tensor([physical_device_id], dtype=torch.int, device=\"cpu\")\n        gather_list = [\n            torch.tensor([0], dtype=torch.int, device=\"cpu\")\n            for _ in range(self.world_size)\n        ]\n        dist.all_gather(gather_list, tensor, group=self.group)\n        physical_device_ids = [t.item() for t in gather_list]\n\n        # test NVLink first; this will filter out most of the cases\n        # where custom quick allreduce is not supported\n        # this checks hardware and driver support for NVLink\n        if _is_cuda or _is_hip:\n            self.fully_connected = is_full_nvlink(physical_device_ids, self.world_size)\n        if self.world_size > 2 and not self.fully_connected:\n            logger.debug(\n                \"Custom quick allreduce is disabled because it's not supported \"\n                \"on more than two PCIe-only GPUs. \"\n            )\n            return\n\n        self.init_quick_all_reduce()\n\n    def init_quick_all_reduce(self):\n        # On ROCm, bfloat16 kernels are slower than fp16\n        # due to slower math operations.\n        # If the environment variable is set to 1 (the default),\n        # inputs are converted to fp16.\n        self.use_fp16_kernels = int(\n            os.environ.get(\"ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16\", 1)\n        )\n        regime_str = os.environ.get(\"ROCM_QUICK_REDUCE_QUANTIZATION\", \"NONE\")\n        if regime_str not in QuickReduceRegime.__members__:\n            logger.warning(\n                \"Custom quick allreduce: invalid quantization level: %s. \"\n                \"Supported levels: %s\",\n                regime_str,\n                list(QuickReduceRegime.__members__.keys()),\n            )\n            return\n\n        if regime_str == \"NONE\":\n            logger.debug(\n                \"Custom quick allreduce is disabled based \"\n                \"on env variable \"\n                \"ROCM_QUICK_REDUCE_QUANTIZATION='NONE'\"\n            )\n            return\n        self.qr_quant_level = QuickReduceRegime[regime_str]\n\n        # TODO: If the dtype is neither bfloat16 nor float16,\n        # quickallreduce should not be created.\n\n        # ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB\n        qr_max_size = int(os.environ.get(\"ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB\", 0))\n        if qr_max_size > 0:\n            if qr_max_size < 1:\n                logger.info(\n                    \"You should not set a max_size smaller than 1MB, which can \"\n                    \"lead to error or degradation to custom allreduce or rccl.\"\n                )\n            qr_max_size = qr_max_size * MB\n        # If qr_max_size is 0, the library default (2GB) is used.\n        self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)\n        self.qr_max_size = qr_max_size if qr_max_size > 0 else ops.qr_max_size()\n        self.create_shared_buffer()\n        self.disabled = False\n\n    def create_shared_buffer(self):\n        \"\"\"\n        Creates the shared IPC buffer for quick allreduce.\n        Must be called after init_custom_qr.\n        \"\"\"\n        handle = ops.qr_get_handle(self._ptr)\n        world_size = dist.get_world_size(group=self.group)\n        handles = [None] * world_size\n        dist.all_gather_object(handles, handle, group=self.group)\n        ops.qr_open_handles(self._ptr, handles)\n\n    def should_quick_allreduce(self, inp: torch.Tensor):\n        \"\"\"\n        Check whether quick allreduce can be used for this input.\n        \"\"\"\n        if self.disabled:\n            return False\n        if inp.dtype not in self._SUPPORTED_DTYPES:\n            return False\n        inp_size = inp.numel() * inp.element_size()\n        # custom quick allreduce requires the input byte size to be\n        # a multiple of 16\n        if inp_size % 16 != 0:\n            return False\n        if not is_weak_contiguous(inp):\n            return False\n        dtype = inp.dtype\n        if self.use_fp16_kernels:\n            dtype = torch.float16\n        return (\n            inp_size <= self.qr_max_size\n            and inp_size\n            >= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value]\n        )\n\n    def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):\n        \"\"\"Performs an out-of-place custom quick all reduce.\"\"\"\n        # quick allreduce doesn't require a separate graph mode,\n        # as QR uses a static IPC buffer.\n        if out is None:\n            out = torch.empty_like(inp)\n        ops.qr_all_reduce(\n            self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels\n        )\n        return out\n\n    def close(self):\n        if not self.disabled and getattr(self, \"_ptr\", None):\n            if ops is not None:\n                ops.qr_destroy(self._ptr)\n            self._ptr = 0\n            self.disabled = True\n\n    def __del__(self):\n        self.close()\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/shm_broadcast.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/shm_broadcast.py\n\nimport logging\nimport os\nimport pickle\nimport time\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass, field\nfrom multiprocessing import shared_memory\nfrom typing import List, Optional\nfrom unittest.mock import patch\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup\nfrom zmq import IPV6  # type: ignore\nfrom zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore\n\nfrom sglang.srt.utils.network import NetworkAddress, get_local_ip_auto, get_open_port\n\n# Seconds between warnings while waiting for a free ring-buffer block (default: 60)\nSGLANG_RINGBUFFER_WARNING_INTERVAL = int(\n    os.environ.get(\"SGLANG_RINGBUFFER_WARNING_INTERVAL\", \"60\")\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass ShmRingBuffer:\n\n    def __init__(\n        self,\n        n_reader: int,\n        max_chunk_bytes: int,\n        max_chunks: int,\n        name: Optional[str] = None,\n    ):\n        \"\"\"\n        A shared memory ring buffer implementation for broadcast communication.\n        Essentially, it is a queue where a single writer will `enqueue` and multiple\n        readers will `dequeue`. The max size of each item, together with the max number\n        of items that can be stored in the buffer, is known in advance.\n        In this case, we don't need to synchronize access to\n        the buffer.\n\n        Buffer memory layout:\n                  data                                 metadata\n                    |                                      |\n                    | (current_idx)                        | (current_idx)\n                    v                                      v\n        +-------------------------------+----------------------------------------+\n        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |\n        +-------------------------------+----------------------------------------+\n        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |\n\n        metadata memory layout: each byte is a flag; the first byte is the written\n        flag, and the rest are reader flags. The flags are set to 0 by default.\n        +--------------+--------------+--------------+-----+--------------+\n        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |\n        +--------------+--------------+--------------+-----+--------------+\n\n        The state of metadata is as follows:\n\n        (case 1) 0???...???: the block is not written yet, cannot read, can write\n        (case 2) 1000...000: the block is just written, can read, cannot write\n        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write\n        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write\n\n        State transition for readers:\n\n        When a reader finds a block that it can read (case 2 or 3), it can yield the block for the caller to read.\n        Only after the caller finishes reading the block can the reader mark the block as read.\n        Readers only mark the block as read (reader flag from 0 to 1); the writer marks the block as ready to read (reader flags reset from 1 to 0).\n\n        State transition for writer:\n\n        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case\n        to case 1. Then it can yield the block for the caller to write. After the caller finishes writing the block, the writer\n        can reset the reader flags to 0, and mark the block as written (from 0 to 1).\n        NOTE: the order is important here: first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. 
If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.\n\n        During creation, `name` is None and the buffer is created. We can pass the\n        created object to other processes by pickling it. The other processes will\n        get the name of the shared memory and open it, so that they can access the\n        same shared memory buffer.\n        \"\"\"  # noqa\n        self.n_reader = n_reader\n        self.metadata_size = 1 + n_reader\n        self.max_chunk_bytes = max_chunk_bytes\n        self.max_chunks = max_chunks\n        self.total_bytes_of_buffer = (\n            self.max_chunk_bytes + self.metadata_size\n        ) * self.max_chunks\n        self.data_offset = 0\n        self.metadata_offset = self.max_chunk_bytes * self.max_chunks\n\n        if name is None:\n            # we are creating a buffer\n            self.is_creator = True\n            self.shared_memory = shared_memory.SharedMemory(\n                create=True, size=self.total_bytes_of_buffer\n            )\n            # initialize the metadata section to 0\n            with memoryview(\n                self.shared_memory.buf[self.metadata_offset :]\n            ) as metadata_buffer:\n                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)\n        else:\n            # we are opening an existing buffer\n            self.is_creator = False\n            # fix to https://stackoverflow.com/q/62748654/9191338\n            # Python incorrectly tracks shared memory even if it is not\n            # created by the process. 
The following patch is a workaround.\n            with patch(\n                \"multiprocessing.resource_tracker.register\",\n                lambda *args, **kwargs: None,\n            ):\n                try:\n                    self.shared_memory = shared_memory.SharedMemory(name=name)\n                    assert self.shared_memory.size == self.total_bytes_of_buffer\n                except FileNotFoundError:\n                    # we might deserialize the object in a different node\n                    # in this case, this object is not used,\n                    # and we should suppress the error\n                    pass\n\n    def __reduce__(self):\n        return (\n            self.__class__,\n            (\n                self.n_reader,\n                self.max_chunk_bytes,\n                self.max_chunks,\n                self.shared_memory.name,\n            ),\n        )\n\n    def __del__(self):\n        if hasattr(self, \"shared_memory\"):\n            self.shared_memory.close()\n            if self.is_creator:\n                self.shared_memory.unlink()\n\n    @contextmanager\n    def get_data(self, current_idx: int):\n        start = self.data_offset + current_idx * self.max_chunk_bytes\n        end = start + self.max_chunk_bytes\n        with memoryview(self.shared_memory.buf[start:end]) as buf:\n            yield buf\n\n    @contextmanager\n    def get_metadata(self, current_idx: int):\n        start = self.metadata_offset + current_idx * self.metadata_size\n        end = start + self.metadata_size\n        with memoryview(self.shared_memory.buf[start:end]) as buf:\n            yield buf\n\n\n@dataclass\nclass Handle:\n    connect_ip: str\n    local_reader_ranks: List[int] = field(default_factory=list)\n\n    buffer: Optional[ShmRingBuffer] = None\n    local_subscribe_port: Optional[int] = None\n    remote_subscribe_port: Optional[int] = None\n\n\nclass MessageQueue:\n\n    def __init__(\n        self,\n        n_reader,  # number of all 
readers\n        n_local_reader,  # number of local readers through shared memory\n        local_reader_ranks: Optional[List[int]] = None,\n        max_chunk_bytes: int = 1024 * 1024 * 10,\n        max_chunks: int = 10,\n        connect_ip: Optional[str] = None,\n    ):\n        if local_reader_ranks is None:\n            local_reader_ranks = list(range(n_local_reader))\n        else:\n            assert len(local_reader_ranks) == n_local_reader\n        self.n_local_reader = n_local_reader\n        n_remote_reader = n_reader - n_local_reader\n        self.n_remote_reader = n_remote_reader\n\n        if connect_ip is None:\n            connect_ip = (\n                get_local_ip_auto(\"0.0.0.0\") if n_remote_reader > 0 else \"127.0.0.1\"\n            )\n\n        context = Context()\n\n        if n_local_reader > 0:\n            # for local readers, we will:\n            # 1. create a shared memory ring buffer to communicate small data\n            # 2. create a publish-subscribe socket to communicate large data\n            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks)\n\n            # XPUB is very similar to PUB,\n            # except that it can receive subscription messages\n            # to confirm the number of subscribers\n            self.local_socket = context.socket(XPUB)\n            # set the verbose option so that we can receive every subscription\n            # message. 
otherwise, we will only receive the first subscription\n            # see http://api.zeromq.org/3-3:zmq-setsockopt for more details\n            self.local_socket.setsockopt(XPUB_VERBOSE, True)\n            local_subscribe_port = get_open_port()\n            socket_addr = f\"tcp://127.0.0.1:{local_subscribe_port}\"\n            logger.debug(\"Binding to %s\", socket_addr)\n            self.local_socket.bind(socket_addr)\n            self.current_idx = 0\n\n        else:\n            self.buffer = None  # type: ignore\n            local_subscribe_port = None\n            self.local_socket = None\n            self.current_idx = -1\n\n        if n_remote_reader > 0:\n            # for remote readers, we will:\n            # create a publish-subscribe socket to communicate large data\n            self.remote_socket = context.socket(XPUB)\n            self.remote_socket.setsockopt(XPUB_VERBOSE, True)\n            remote_subscribe_port = get_open_port()\n            na = NetworkAddress(connect_ip, remote_subscribe_port)\n            if na.is_ipv6:\n                self.remote_socket.setsockopt(IPV6, 1)\n            address = na.to_tcp()\n            logger.debug(f\"class MessageQueue: Binding remote socket to {address=}\")\n            self.remote_socket.bind(address)\n\n        else:\n            remote_subscribe_port = None\n            self.remote_socket = None\n\n        self._is_writer = True\n        self._is_local_reader = False\n        self.local_reader_rank = -1\n        # rank does not matter for remote readers\n        self._is_remote_reader = False\n\n        self.handle = Handle(\n            connect_ip=connect_ip,\n            local_reader_ranks=local_reader_ranks,\n            buffer=self.buffer,\n            local_subscribe_port=local_subscribe_port,\n            remote_subscribe_port=remote_subscribe_port,\n        )\n\n        logger.debug(\"Message queue communication handle: %s\", self.handle)\n\n    def export_handle(self) -> Handle:\n        return 
self.handle\n\n    @staticmethod\n    def create_from_handle(handle: Handle, rank) -> \"MessageQueue\":\n        self = MessageQueue.__new__(MessageQueue)\n        self.handle = handle\n        self._is_writer = False\n\n        context = Context()\n\n        if rank in handle.local_reader_ranks:\n            assert handle.buffer is not None\n            self.buffer = handle.buffer\n            self.current_idx = 0\n            self.local_reader_rank = handle.local_reader_ranks.index(rank)\n            self._is_local_reader = True\n            self._is_remote_reader = False\n\n            self.local_socket = context.socket(SUB)\n            self.local_socket.setsockopt_string(SUBSCRIBE, \"\")\n            socket_addr = f\"tcp://127.0.0.1:{handle.local_subscribe_port}\"\n            logger.debug(\"Connecting to %s\", socket_addr)\n            self.local_socket.connect(socket_addr)\n\n            self.remote_socket = None\n        else:\n            self.buffer = None  # type: ignore\n            self.current_idx = -1\n            self.local_reader_rank = -1\n            self._is_local_reader = False\n            self._is_remote_reader = True\n\n            self.local_socket = None\n\n            self.remote_socket = context.socket(SUB)\n            self.remote_socket.setsockopt_string(SUBSCRIBE, \"\")\n            na = NetworkAddress(handle.connect_ip, handle.remote_subscribe_port)\n            if na.is_ipv6:\n                self.remote_socket.setsockopt(IPV6, 1)\n            socket_addr = na.to_tcp()\n            logger.debug(\"Connecting to %s\", socket_addr)\n            self.remote_socket.connect(socket_addr)\n\n        return self\n\n    def wait_until_ready(self):\n        \"\"\"This is a collective operation. 
All processes (including the\n        readers and the writer) should call this function.\n        \"\"\"\n        if self._is_writer:\n            # wait for all readers to connect\n\n            # local readers\n            for i in range(self.n_local_reader):\n                # wait for subscription messages from all local readers\n                self.local_socket.recv()\n            if self.n_local_reader > 0:\n                # send a message to all local readers\n                # to make sure the publish channel is working\n                self.local_socket.send(b\"READY\")\n\n            # remote readers\n            for i in range(self.n_remote_reader):\n                # wait for subscription messages from all remote readers\n                self.remote_socket.recv()\n            if self.n_remote_reader > 0:\n                # send a message to all remote readers\n                # to make sure the publish channel is working\n                self.remote_socket.send(b\"READY\")\n        elif self._is_local_reader:\n            # wait for the writer to send a message\n            recv = self.local_socket.recv()\n            assert recv == b\"READY\"\n        elif self._is_remote_reader:\n            # wait for the writer to send a message\n            recv = self.remote_socket.recv()\n            assert recv == b\"READY\"\n\n    @contextmanager\n    def acquire_write(self):\n        assert self._is_writer, \"Only writers can acquire write\"\n        start_time = time.monotonic()\n        n_warning = 1\n        while True:\n            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:\n                read_count = sum(metadata_buffer[1:])\n                written_flag = metadata_buffer[0]\n                if written_flag and read_count != self.buffer.n_reader:\n                    # this block is written and not read by all readers\n                    # for writers, `self.current_idx` is the next block to write\n                    # if 
this block is not ready to write,\n                    # we need to wait until it is read by all readers\n\n                    # Release the processor to other threads\n                    os.sched_yield()\n\n                    # if we wait for a long time, we should warn the user\n                    if (\n                        time.monotonic() - start_time\n                        > SGLANG_RINGBUFFER_WARNING_INTERVAL * n_warning\n                    ):\n                        logger.warning(\n                            \"No available block found in %s seconds. \",\n                            SGLANG_RINGBUFFER_WARNING_INTERVAL,\n                        )\n                        n_warning += 1\n\n                    continue\n                # found a block that is either\n                # (1) not written\n                # (2) read by all readers\n\n                # mark the block as not written\n                metadata_buffer[0] = 0\n                # let caller write to the buffer\n                with self.buffer.get_data(self.current_idx) as buf:\n                    yield buf\n\n                # caller has written to the buffer\n                # NOTE: order is important here\n                # first set the read flags to 0\n                # then set the written flag to 1\n                # otherwise, the readers may think they already read the block\n                for i in range(1, self.buffer.n_reader + 1):\n                    # set read flag to 0, meaning it is not read yet\n                    metadata_buffer[i] = 0\n                # mark the block as written\n                metadata_buffer[0] = 1\n                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks\n                break\n\n    @contextmanager\n    def acquire_read(self):\n        assert self._is_local_reader, \"Only readers can acquire read\"\n        start_time = time.monotonic()\n        n_warning = 1\n        while True:\n            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:\n                read_flag = metadata_buffer[self.local_reader_rank + 1]\n                written_flag = metadata_buffer[0]\n                if not written_flag or read_flag:\n                    # this block is either\n                    # (1) not written\n                    # (2) already read by this reader\n\n                    # for readers, `self.current_idx` is the next block to read\n                    # if this block is not ready,\n                    # we need to wait until it is written\n\n                    # Release the processor to other threads\n                    os.sched_yield()\n\n                    # if we wait for a long time, we should warn the user\n                    if (\n                        time.monotonic() - start_time\n                        > SGLANG_RINGBUFFER_WARNING_INTERVAL * n_warning\n                    ):\n                        logger.warning(\n                            \"No available block found in %s seconds. 
\",\n                            SGLANG_RINGBUFFER_WARNING_INTERVAL,\n                        )\n                        n_warning += 1\n\n                    continue\n                # found a block that is not read by this reader\n                # let caller read from the buffer\n                with self.buffer.get_data(self.current_idx) as buf:\n                    yield buf\n\n                # caller has read from the buffer\n                # set the read flag\n                metadata_buffer[self.local_reader_rank + 1] = 1\n                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks\n                break\n\n    def enqueue(self, obj):\n        assert self._is_writer, \"Only writers can enqueue\"\n        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)\n        if self.n_local_reader > 0:\n            if len(serialized_obj) >= self.buffer.max_chunk_bytes:\n                with self.acquire_write() as buf:\n                    buf[0] = 1  # overflow\n                self.local_socket.send(serialized_obj)\n            else:\n                with self.acquire_write() as buf:\n                    buf[0] = 0  # not overflow\n                    buf[1 : len(serialized_obj) + 1] = serialized_obj\n        if self.n_remote_reader > 0:\n            self.remote_socket.send(serialized_obj)\n\n    def dequeue(self):\n        if self._is_local_reader:\n            with self.acquire_read() as buf:\n                overflow = buf[0] == 1\n                if not overflow:\n                    # no need to know the size of serialized object\n                    # pickle format contains the size information internally\n                    # see https://docs.python.org/3/library/pickle.html\n                    obj = pickle.loads(buf[1:])\n            if overflow:\n                recv = self.local_socket.recv()\n                obj = pickle.loads(recv)\n        elif self._is_remote_reader:\n            recv = 
self.remote_socket.recv()\n            obj = pickle.loads(recv)\n        else:\n            raise RuntimeError(\"Only readers can dequeue\")\n        return obj\n\n    def broadcast_object(self, obj=None):\n        if self._is_writer:\n            self.enqueue(obj)\n            return obj\n        else:\n            return self.dequeue()\n\n    @staticmethod\n    def create_from_process_group(\n        pg: ProcessGroup, max_chunk_bytes, max_chunks, writer_rank=0\n    ) -> \"MessageQueue\":\n        group_rank = dist.get_rank(pg)\n        group_world_size = dist.get_world_size(pg)\n        global_ranks = dist.get_process_group_ranks(pg)\n\n        from sglang.srt.distributed.parallel_state import in_the_same_node_as\n\n        status = in_the_same_node_as(pg, source_rank=writer_rank)\n        same_node_ranks = [i for i, s in enumerate(status) if s]\n        n_reader = group_world_size - 1\n        n_local_reader = len(same_node_ranks) - 1\n        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]\n        buffer_io: MessageQueue\n        if group_rank == writer_rank:\n            buffer_io = MessageQueue(\n                n_reader=n_reader,\n                n_local_reader=n_local_reader,\n                local_reader_ranks=local_reader_ranks,\n                max_chunk_bytes=max_chunk_bytes,\n                max_chunks=max_chunks,\n            )\n            handle = buffer_io.export_handle()\n            dist.broadcast_object_list(\n                [handle], src=global_ranks[writer_rank], group=pg\n            )\n        else:\n            recv = [None]\n            dist.broadcast_object_list(recv, src=global_ranks[writer_rank], group=pg)\n            handle = recv[0]  # type: ignore\n            buffer_io = MessageQueue.create_from_handle(handle, group_rank)\n        buffer_io.wait_until_ready()\n        return buffer_io\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/torch_symm_mem.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/bf214ca22625e311a2c4c0dfbf7af19128f4919c/vllm/distributed/device_communicators/symm_mem.py\nimport logging\nfrom typing import Optional, Union\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup\n\nfrom sglang.srt.distributed.device_communicators.all_reduce_utils import (\n    TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES,\n)\nfrom sglang.srt.utils import is_cuda, is_hip\n\ntry:\n    import torch.distributed._symmetric_memory as torch_symm_mem\n\n    _is_cuda = is_cuda()\n    _is_hip = is_hip()\n\n    torch_symm_mem_available = False\n    if _is_cuda:\n        torch_symm_mem_available = True\nexcept ImportError:\n    torch_symm_mem_available = False\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass TorchSymmMemCommunicator:\n    \"\"\"\n    Thin wrapper around torch-symmetric-memory collectives.\n\n    This communicator:\n      - Validates device capability and world size.\n      - Allocates a shared symmetric buffer.\n      - Chooses between 'multimem' and 'two-shot' all-reduce kernels.\n      - Exposes a fast-path all_reduce() compatible with bfloat16 inputs.\n\n    If any prerequisite is not met, the instance remains disabled and will\n    decline to perform symmetric-memory all-reduce.\n    \"\"\"\n\n    # Mapping: compute capability major -> supported world sizes for multimem\n    # If the current (cc_major, world_size) is not listed, we fall back\n    # to the two-shot path.\n    _WORLD_SIZES_MULTIMEM = {\n        9: [4, 6, 8],\n        10: [6, 8],\n    }\n\n    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]):\n        \"\"\"\n        Args:\n            group: Torch process group used for rendezvous and naming.\n            device: Target CUDA device (index, 'cuda:X', or torch.device).\n        \"\"\"\n\n        self.disabled = True\n\n        if not torch_symm_mem_available:\n            return\n\n        if 
isinstance(device, int):\n            device = torch.device(f\"cuda:{device}\")\n        elif isinstance(device, str):\n            device = torch.device(device)\n        torch.cuda.set_device(device)\n        self.dtype = torch.bfloat16\n        self.device = device\n        self.group = group\n        self.world_size = dist.get_world_size(self.group)\n        self.device_capability = torch.cuda.get_device_capability(device)[0]\n        if self.device_capability < 9:\n            logger.warning(\n                \"TorchSymmMemCommunicator: Device capability %s not supported, \"\n                \"communicator is not available.\",\n                self.device_capability,\n            )\n            return\n        if (\n            self.world_size\n            not in TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]\n        ):\n            logger.warning(\n                \"TorchSymmMemCommunicator: World size %d not supported, \"\n                \"communicator is not available.\",\n                self.world_size,\n            )\n            return\n        self.max_size = TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][\n            self.world_size\n        ]\n        self.buffer = torch_symm_mem.empty(\n            self.max_size // self.dtype.itemsize,\n            device=self.device,\n            dtype=self.dtype,\n        )\n        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)\n        if handle.multicast_ptr == 0:\n            logger.warning(\n                \"TorchSymmMemCommunicator: torch symmetric memory \"\n                \"multicast operations are not supported.\"\n            )\n            self.buffer = None\n            self.disabled = True\n            return\n        self.disabled = False\n\n    def should_torch_symm_mem_allreduce(self, inp: torch.Tensor):\n        \"\"\"\n        Fast-path eligibility check for a given tensor.\n\n        Conditions:\n          - Communicator must be enabled.\n    
      - dtype must be bfloat16 (matches kernel + buffer dtype).\n          - Total byte size must be 4-byte aligned (hardware requirement).\n          - Payload must be smaller than the symmetric-memory max size.\n\n        Returns:\n            True if the symmetric-memory path can handle this tensor.\n        \"\"\"\n        if self.disabled:\n            return False\n        if inp.dtype != self.dtype:\n            return False\n        inp_size = inp.numel() * inp.element_size()\n        # enforce 4-byte alignment\n        if inp_size % 4 != 0:\n            return False\n        return inp_size < self.max_size\n\n    def all_reduce(\n        self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None\n    ) -> Optional[torch.Tensor]:\n        \"\"\"\n        Perform a sum all-reduce via torch symmetric memory.\n\n        Args:\n            inp: Input tensor on the target CUDA device (bfloat16).\n            out: Optional output tensor; if omitted, a new tensor is allocated.\n\n        Returns:\n            The reduced tensor (same shape as inp). This method does not\n            check whether the communicator is disabled; callers should call\n            should_torch_symm_mem_allreduce() first.\n\n        Implementation details:\n            - Stages 'inp' into the symmetric buffer.\n            - Reduces in place on the symmetric buffer, using the 'multimem'\n              kernel when (compute capability, world size) is listed in\n              _WORLD_SIZES_MULTIMEM and the 'two_shot' kernel otherwise.\n            - Copies the result into 'out' and returns it.\n        \"\"\"\n        if out is None:\n            out = torch.empty_like(inp)\n        self.buffer[: inp.numel()].copy_(inp.view(-1))\n        if self.world_size in self._WORLD_SIZES_MULTIMEM.get(\n            self.device_capability, []\n        ):\n            torch.ops.symm_mem.multimem_all_reduce_(\n                self.buffer[: inp.numel()], \"sum\", self.group.group_name\n            )\n        else:\n            torch.ops.symm_mem.two_shot_all_reduce_(\n                self.buffer[: inp.numel()], \"sum\", self.group.group_name\n            )\n        out.copy_(self.buffer[: inp.numel()].view(out.shape))\n        return out\n"
  },
  {
    "path": "python/sglang/srt/distributed/device_communicators/xpu_communicator.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/xpu_communicator.py\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup\n\nfrom sglang.srt.utils import is_xpu\n\n\nclass XpuCommunicator:\n\n    def __init__(self, group: ProcessGroup):\n        if not is_xpu():\n            self.disabled = True\n            return\n        self.disabled = False\n        self.group = group\n        self.world_size = dist.get_world_size(self.group)\n\n    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:\n        dist.all_reduce(x, group=self.group)\n        return x\n\n    def gather(\n        self, input_: torch.Tensor, rank_in_group: int, dst: int = 0, dim: int = -1\n    ):\n        # For xpu path, gather doesn't work properly together with ray\n        # cluster so we use all_gather instead for now.\n        input_size = input_.size()\n        # Allocate output tensor.\n        output_tensor = torch.empty(\n            (self.world_size,) + input_size, dtype=input_.dtype, device=input_.device\n        )\n        # All-gather.\n        torch.distributed.all_gather_into_tensor(\n            output_tensor, input_, group=self.group\n        )\n        if rank_in_group == dst:\n            # Reshape\n            output_tensor = output_tensor.movedim(0, dim)\n            output_tensor = output_tensor.reshape(\n                input_size[:dim]\n                + (self.world_size * input_size[dim],)\n                + input_size[dim + 1 :]\n            )\n        else:\n            output_tensor = None\n        return output_tensor\n"
  },
  {
    "path": "python/sglang/srt/distributed/naive_distributed.py",
    "content": "import pickle\nimport time\nfrom pathlib import Path\nfrom typing import Any, List, Optional\n\nimport pybase64\nimport torch\n\nfrom sglang.srt.utils import MultiprocessingSerializer\n\n\nclass NaiveDistributed:\n    def __init__(self, rank: int, world_size: int, rendezvous: str):\n        self._rank = rank\n        self._world_size = world_size\n        self._operation_index = 0\n        self._directory = Path(rendezvous)\n        self._directory.mkdir(parents=True, exist_ok=True)\n        assert 0 <= rank < world_size\n\n        # barrier here, both to be safe and as a sanity check\n        self.barrier()\n\n    def get_rank(self):\n        return self._rank\n\n    def get_world_size(self):\n        return self._world_size\n\n    def scatter(\n        self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0\n    ):\n        if self._rank == src:\n            assert len(scatter_list) == self._world_size\n        else:\n            assert scatter_list is None\n\n        gathered_objects = self.all_gather_object(\n            dict(\n                serialized_scatter_list=[\n                    (\n                        None\n                        if item_rank == src\n                        else MultiprocessingSerializer.serialize(item)\n                    )\n                    for item_rank, item in enumerate(scatter_list)\n                ]\n            )\n            if self._rank == src\n            else dict()\n        )\n\n        remote_serialized_tensor = gathered_objects[src][\"serialized_scatter_list\"][\n            self._rank\n        ]\n        if self._rank == src:\n            assert remote_serialized_tensor is None\n            remote_tensor = scatter_list[self._rank]\n        else:\n            remote_tensor = MultiprocessingSerializer.deserialize(\n                remote_serialized_tensor\n            )\n        tensor.copy_(remote_tensor)\n\n        # keep the src tensor from being deleted too early\n        
self.barrier()\n\n    def all_gather_object(self, obj: Any) -> List[Any]:\n        self._operation_index += 1\n\n        text_postfix = \"\\n\"\n\n        def _get_path(interesting_rank: int):\n            return (\n                self._directory\n                / f\"rank{interesting_rank}_op{self._operation_index}.txt\"\n            )\n\n        _get_path(self._rank).write_text(\n            pybase64.b64encode(pickle.dumps(obj)).decode(\"utf-8\") + text_postfix\n        )\n\n        def _read_one(interesting_rank: int):\n            p = _get_path(interesting_rank)\n            while True:\n                if p.exists() and (text := p.read_text()).endswith(text_postfix):\n                    return pickle.loads(\n                        pybase64.b64decode(text[: -len(text_postfix)], validate=True)\n                    )\n                time.sleep(0.001)\n\n        return [\n            _read_one(interesting_rank) for interesting_rank in range(self._world_size)\n        ]\n\n    def barrier(self):\n        actual_objs = self.all_gather_object(self._rank)\n        assert actual_objs == list(range(self._world_size)), f\"{actual_objs=}\"\n\n\n# Multiple instances can be created if needed\n_instance: Optional[NaiveDistributed] = None\n\n\ndef get_naive_distributed():\n    assert _instance is not None\n    return _instance\n\n\ndef set_naive_distributed(instance: NaiveDistributed):\n    global _instance\n    assert _instance is None\n    _instance = instance\n"
  },
  {
    "path": "python/sglang/srt/distributed/parallel_state.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/parallel_state.py\n\n# Copyright 2023 The vLLM team.\n# Adapted from\n# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\n\"\"\"Distributed state.\nIt takes over the control of the distributed environment from PyTorch.\nThe typical workflow is:\n\n- call `init_distributed_environment` to initialize the distributed environment.\n- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to\n initialize the model parallel groups.\n\n- any code dealing with the distributed stuff\n\n- call `destroy_model_parallel` to destroy the model parallel groups.\n- call `destroy_distributed_environment` to destroy the distributed environment.\n\nIf you only need to use the distributed environment without model/pipeline\n parallelism, you can skip the model parallel initialization and destruction\n steps.\n\"\"\"\n\nimport contextlib\nimport gc\nimport logging\nimport os\nimport pickle\nimport weakref\nfrom collections import namedtuple\nfrom contextlib import contextmanager, nullcontext\nfrom dataclasses import dataclass\nfrom datetime import timedelta\nfrom multiprocessing import shared_memory\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\nfrom unittest.mock import patch\n\nimport torch\nimport torch.distributed\nfrom torch.distributed import Backend, ProcessGroup\n\nfrom sglang.srt.compilation.compilation_config import register_split_op\nfrom sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph\nfrom sglang.srt.distributed.utils import set_global_tcp_store\nfrom sglang.srt.environ import envs\nfrom sglang.srt.utils import (\n    get_bool_env_var,\n    get_current_device_stream_fast,\n    get_int_env_var,\n    is_cpu,\n    is_cuda_alike,\n    is_hip,\n    is_musa,\n    is_npu,\n    is_shm_available,\n    
is_xpu,\n)\nfrom sglang.srt.utils.custom_op import register_custom_op\nfrom sglang.srt.utils.network import get_local_ip_auto\n\n_is_npu = is_npu()\n_is_cpu = is_cpu()\n_is_xpu = is_xpu()\n_is_musa = is_musa()\n\nTensorMetadata = namedtuple(\"TensorMetadata\", [\"device\", \"dtype\", \"size\"])\n\n# use int value instead of ReduceOp.SUM to support torch compile\nREDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)\n\n\ndef get_torch_distributed_pg_options(group_name=None):\n    if not _is_npu:\n        return None\n\n    # Only create HCCL options for default group or MoE-related groups\n    if group_name is not None and \"moe\" not in group_name:\n        return None\n\n    import torch_npu\n\n    options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()\n    hccl_buffer_size = int(\n        os.environ.get(\"DEEPEP_HCCL_BUFFSIZE\") or os.environ.get(\"HCCL_BUFFSIZE\") or 200\n    )\n    options.hccl_config = {\"hccl_buffer_size\": hccl_buffer_size}\n    return options\n\n\n@dataclass\nclass GraphCaptureContext:\n    stream: torch.get_device_module().Stream\n\n\n@dataclass\nclass P2PWork:\n    work: Optional[torch.distributed.Work]\n    payload: Optional[torch.Tensor]\n\n\ndef _split_tensor_dict(\n    tensor_dict: Dict[str, Union[torch.Tensor, Any]],\n) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:\n    \"\"\"Split the tensor dictionary into two parts:\n    1. A list of (key, value) pairs. If the value is a tensor, it is replaced\n         by its metadata.\n    2. A list of tensors.\n    \"\"\"\n    metadata_list: List[Tuple[str, Any]] = []\n    tensor_list: List[torch.Tensor] = []\n    for key, value in tensor_dict.items():\n        if isinstance(value, torch.Tensor):\n            # Note: we cannot use `value.device` here,\n            # because it contains not only the device type but also the device\n            # index (e.g. \"cuda:0\"). 
We only need the device type.\n            # receiving side will set the device index.\n            device = value.device.type\n            metadata_list.append(\n                (key, TensorMetadata(device, value.dtype, value.size()))\n            )\n            tensor_list.append(value)\n        else:\n            metadata_list.append((key, value))\n    return metadata_list, tensor_list\n\n\n_group_name_counter: Dict[str, int] = {}\n\n\ndef _get_unique_name(name: str) -> str:\n    \"\"\"Get a unique name for the group.\n    Example:\n    _get_unique_name(\"tp\") -> \"tp:0\"\n    _get_unique_name(\"tp\") -> \"tp:1\"\n    \"\"\"\n    if name not in _group_name_counter:\n        _group_name_counter[name] = 0\n    newname = f\"{name}:{_group_name_counter[name]}\"\n    _group_name_counter[name] += 1\n    return newname\n\n\n_groups: Dict[str, Callable[[], Optional[\"GroupCoordinator\"]]] = {}\n\n\ndef _register_group(group: \"GroupCoordinator\") -> None:\n    _groups[group.unique_name] = weakref.ref(group)\n\n\n@register_custom_op(mutates_args=[\"tensor\"])\n@register_split_op()\ndef inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:\n    assert group_name in _groups, f\"Group {group_name} is not found.\"\n    group = _groups[group_name]()\n    if group is None:\n        raise ValueError(f\"Group {group_name} is destroyed.\")\n    group._all_reduce_in_place(tensor)\n\n\n@register_custom_op(out_shape=\"tensor\")\ndef outplace_all_reduce(\n    tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str\n) -> torch.Tensor:\n    assert group_name in _groups, f\"Group {group_name} is not found.\"\n    group = _groups[group_name]()\n    if group is None:\n        raise ValueError(f\"Group {group_name} is destroyed.\")\n    return group._all_reduce_out_place(tensor, outplace_all_reduce_method)\n\n\n@register_custom_op(mutates_args=[\"output\"])\ndef reg_all_gather_into_tensor(\n    output: torch.Tensor, input: torch.Tensor, group_name: str\n) -> 
None:\n    assert group_name in _groups, f\"Group {group_name} is not found.\"\n    group = _groups[group_name]()\n    if group is None:\n        raise ValueError(f\"Group {group_name} is destroyed.\")\n    group._all_gather_into_tensor(output, input)\n\n\n@register_custom_op(mutates_args=[\"output\"])\ndef reg_reduce_scatter_tensor(\n    output: torch.Tensor, input: torch.Tensor, group_name: str\n) -> None:\n    assert group_name in _groups, f\"Group {group_name} is not found.\"\n    group = _groups[group_name]()\n    if group is None:\n        raise ValueError(f\"Group {group_name} is destroyed.\")\n    group._reduce_scatter_tensor(output, input)\n\n\nclass GroupCoordinator:\n    \"\"\"\n    PyTorch ProcessGroup wrapper for a group of processes.\n    PyTorch ProcessGroup is bound to one specific communication backend,\n        e.g. NCCL, Gloo, MPI, etc.\n    GroupCoordinator takes charge of all the communication operations among\n        the processes in the group. It can route the communication to\n        a specific implementation (e.g. 
switch allreduce implementation\n        based on the tensor size and cuda graph mode).\n    \"\"\"\n\n    # available attributes:\n    rank: int  # global rank\n    ranks: List[int]  # global ranks in the group\n    world_size: int  # size of the group\n    # difference between `local_rank` and `rank_in_group`:\n    # if we have a group of size 4 across two nodes:\n    # Process | Node | Rank | Local Rank | Rank in Group\n    #   0     |   0  |  0   |     0      |       0\n    #   1     |   0  |  1   |     1      |       1\n    #   2     |   1  |  2   |     0      |       2\n    #   3     |   1  |  3   |     1      |       3\n    local_rank: int  # local rank used to assign devices\n    rank_in_group: int  # rank inside the group\n    cpu_group: ProcessGroup  # group for CPU communication\n    device_group: ProcessGroup  # group for device communication\n    use_pynccl: bool  # a hint of whether to use PyNccl\n    use_pymscclpp: bool  # a hint of whether to use PyMsccl\n    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce\n    use_torch_symm_mem_all_reduce: (\n        bool  # a hint of whether to use TorchSymmMemAllReduce\n    )\n    use_message_queue_broadcaster: (\n        bool  # a hint of whether to use message queue broadcaster\n    )\n    # communicators are only created for world size > 1\n    pynccl_comm: Optional[Any]  # PyNccl communicator\n    ca_comm: Optional[Any]  # Custom allreduce communicator\n    torch_symm_mem_comm: Optional[Any]  # Torch symm mem communicator\n    mq_broadcaster: Optional[Any]  # shared memory broadcaster\n\n    def __init__(\n        self,\n        group_ranks: List[List[int]],\n        local_rank: int,\n        torch_distributed_backend: Union[str, Backend],\n        use_pynccl: bool,\n        use_pymscclpp: bool,\n        use_custom_allreduce: bool,\n        use_torch_symm_mem_all_reduce: bool,\n        use_hpu_communicator: bool,\n        use_xpu_communicator: bool,\n        use_npu_communicator: 
bool,\n        use_message_queue_broadcaster: bool = False,\n        group_name: Optional[str] = None,\n        pynccl_use_current_stream: bool = False,\n        gloo_timeout: timedelta = timedelta(seconds=120 * 60),\n    ):\n        # Set group info\n        group_name = group_name or \"anonymous\"\n        self.unique_name = _get_unique_name(group_name)\n        _register_group(self)\n\n        # Set rank info\n        self.rank = torch.distributed.get_rank()\n        self.local_rank = local_rank\n        self.device_group = None\n        self.cpu_group = None\n        self.local_size = get_int_env_var(\"LOCAL_SIZE\", 0)\n\n        if is_cuda_alike():\n            device_id = (\n                0 if envs.SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS.get() else local_rank\n            )\n            self.device = torch.device(f\"cuda:{device_id}\")\n        elif _is_npu:\n            self.device = torch.device(f\"npu:{local_rank}\")\n        elif _is_xpu:\n            self.device = torch.device(f\"xpu:{local_rank}\")\n        elif _is_musa:\n            self.device = torch.device(f\"musa:{local_rank}\")\n        else:\n            self.device = torch.device(\"cpu\")\n        self.device_module = torch.get_device_module(self.device)\n\n        for ranks in group_ranks:\n            active_ranks = torch.ones(len(ranks), dtype=torch.int32, device=self.device)\n            active_ranks_cpu = torch.ones(len(ranks), dtype=torch.int32)\n            if \"mooncake\" in torch_distributed_backend:\n                from mooncake.ep import MooncakeBackendOptions\n\n                device_group = torch.distributed.new_group(\n                    ranks,\n                    backend=\"mooncake\",\n                    pg_options=MooncakeBackendOptions(active_ranks),\n                )\n                cpu_group = torch.distributed.new_group(\n                    ranks,\n                    backend=\"mooncake-cpu\",\n                    
pg_options=MooncakeBackendOptions(active_ranks_cpu),\n                )\n            else:\n                pg_options = get_torch_distributed_pg_options(group_name)\n                device_group = torch.distributed.new_group(\n                    ranks, backend=torch_distributed_backend, pg_options=pg_options\n                )\n                # a group with `gloo` backend, to allow direct coordination\n                # between processes through the CPU.\n                cpu_group = torch.distributed.new_group(\n                    ranks, backend=\"gloo\", timeout=gloo_timeout\n                )\n            if self.rank in ranks:\n                self.ranks = ranks\n                self.world_size = len(ranks)\n                self.rank_in_group = ranks.index(self.rank)\n                self.device_group = device_group\n                self.cpu_group = cpu_group\n                self.active_ranks = active_ranks\n                self.active_ranks_cpu = active_ranks_cpu\n\n        assert self.cpu_group is not None\n        assert self.device_group is not None\n\n        # Import communicators\n        self.use_pynccl = use_pynccl\n        self.pynccl_use_current_stream = pynccl_use_current_stream\n        self.use_pymscclpp = use_pymscclpp\n        self.use_custom_allreduce = use_custom_allreduce\n        self.use_torch_symm_mem_all_reduce = use_torch_symm_mem_all_reduce\n        self.use_hpu_communicator = use_hpu_communicator\n        self.use_xpu_communicator = use_xpu_communicator\n        self.use_npu_communicator = use_npu_communicator\n        self.use_message_queue_broadcaster = use_message_queue_broadcaster\n\n        # Lazy import to avoid documentation build error\n        from sglang.srt.distributed.device_communicators.custom_all_reduce import (\n            dispatch_custom_allreduce,\n        )\n        from sglang.srt.distributed.device_communicators.pymscclpp import (\n            PyMscclppCommunicator,\n        )\n        from 
sglang.srt.distributed.device_communicators.pynccl import (\n            PyNcclCommunicator,\n        )\n        from sglang.srt.distributed.device_communicators.pynccl_allocator import (\n            is_symmetric_memory_enabled,\n            use_symmetric_memory,\n        )\n        from sglang.srt.distributed.device_communicators.torch_symm_mem import (\n            TorchSymmMemCommunicator,\n        )\n        from sglang.srt.layers.dp_attention import is_allocation_symmetric\n\n        self.is_symmetric_memory_enabled = is_symmetric_memory_enabled\n        self.use_symmetric_memory = use_symmetric_memory\n        self.is_allocation_symmetric = is_allocation_symmetric\n        if is_hip():\n            from sglang.srt.distributed.device_communicators.quick_all_reduce import (\n                QuickAllReduce,\n                qr_rocm_arch_available,\n            )\n\n        self.pynccl_comm: Optional[PyNcclCommunicator] = None\n        if use_pynccl and self.world_size > 1:\n            self.pynccl_comm = PyNcclCommunicator(\n                group=self.cpu_group,\n                device=self.device,\n                use_current_stream=pynccl_use_current_stream,\n            )\n\n        self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None\n        if use_pymscclpp and self.world_size > 1:\n            self.pymscclpp_comm = PyMscclppCommunicator(\n                group=self.cpu_group,\n                device=self.device,\n            )\n\n        self.ca_comm: Optional[Any] = None\n        self.qr_comm: Optional[QuickAllReduce] = None\n        if use_custom_allreduce and self.world_size > 1:\n            # Initialize a custom fast all-reduce implementation.\n            try:\n                CAClass = dispatch_custom_allreduce()\n                self.ca_comm = CAClass(\n                    group=self.cpu_group,\n                    device=self.device,\n                )\n            except Exception as e:\n                logger.warning(\n                  
  f\"Setup Custom allreduce failed with {e}. To silence this \"\n                    \"warning, specify --disable-custom-all-reduce explicitly.\"\n                )\n\n            if is_hip():\n                try:\n                    # Initialize a custom quick all-reduce implementation for AMD\n                    # when rocm >= gfx942. Quick reduce is designed as a\n                    # complement to custom allreduce.\n                    # Based on quickreduce (https://github.com/mk1-project/quickreduce).\n                    if qr_rocm_arch_available():\n                        self.qr_comm = QuickAllReduce(\n                            group=self.cpu_group, device=self.device\n                        )\n                except Exception as e:\n                    logger.warning(f\"Failed to initialize QuickAllReduce: {e}\")\n        elif self.world_size > 1 and is_hip():\n            logger.info(\"[AR] All-reduce call path: NCCL (custom AR disabled)\")\n\n        self.torch_symm_mem_comm: Optional[TorchSymmMemCommunicator] = None\n        if self.use_torch_symm_mem_all_reduce and self.world_size > 1:\n            self.torch_symm_mem_comm = TorchSymmMemCommunicator(\n                group=self.cpu_group,\n                device=self.device,\n            )\n\n        # Create communicator for other hardware backends\n        from sglang.srt.distributed.device_communicators.hpu_communicator import (\n            HpuCommunicator,\n        )\n        from sglang.srt.distributed.device_communicators.npu_communicator import (\n            NpuCommunicator,\n        )\n        from sglang.srt.distributed.device_communicators.xpu_communicator import (\n            XpuCommunicator,\n        )\n\n        self.hpu_communicator: Optional[HpuCommunicator] = None\n        if use_hpu_communicator and self.world_size > 1:\n            self.hpu_communicator = HpuCommunicator(group=self.device_group)\n\n        self.xpu_communicator: Optional[XpuCommunicator] = None\n        if 
use_xpu_communicator and self.world_size > 1:\n            self.xpu_communicator = XpuCommunicator(group=self.device_group)\n\n        self.npu_communicator: Optional[NpuCommunicator] = None\n        if use_npu_communicator and self.world_size > 1:\n            self.npu_communicator = NpuCommunicator(group=self.device_group)\n\n        # Create message queue\n        from sglang.srt.distributed.device_communicators.shm_broadcast import (\n            MessageQueue,\n        )\n\n        self.mq_broadcaster: Optional[MessageQueue] = None\n        if use_message_queue_broadcaster and self.world_size > 1:\n            self.mq_broadcaster = MessageQueue.create_from_process_group(\n                self.cpu_group, 1 << 22, 6\n            )\n\n    def __repr__(self):\n        return (\n            f\"ranks={self.ranks} rank={self.rank} local_rank={self.local_rank} use_pynccl={self.use_pynccl} \"\n            f\"device_group={self.device_group} cpu_group={self.cpu_group} unique_name={self.unique_name} \"\n            f\"world_size={self.world_size} rank_in_group={self.rank_in_group}\"\n        )\n\n    @property\n    def first_rank(self):\n        \"\"\"Return the global rank of the first process in the group\"\"\"\n        return self.ranks[0]\n\n    @property\n    def last_rank(self):\n        \"\"\"Return the global rank of the last process in the group\"\"\"\n        return self.ranks[-1]\n\n    @property\n    def is_first_rank(self):\n        \"\"\"Return whether the caller is the first process in the group\"\"\"\n        return self.rank == self.first_rank\n\n    @property\n    def is_last_rank(self):\n        \"\"\"Return whether the caller is the last process in the group\"\"\"\n        return self.rank == self.last_rank\n\n    @property\n    def next_rank(self):\n        \"\"\"Return the global rank of the process that follows the caller\"\"\"\n        rank_in_group = self.rank_in_group\n        world_size = self.world_size\n        return self.ranks[(rank_in_group 
+ 1) % world_size]\n\n    @property\n    def prev_rank(self):\n        \"\"\"Return the global rank of the process that precedes the caller\"\"\"\n        rank_in_group = self.rank_in_group\n        world_size = self.world_size\n        return self.ranks[(rank_in_group - 1) % world_size]\n\n    @contextmanager\n    def graph_capture(\n        self,\n        graph_capture_context: Optional[GraphCaptureContext] = None,\n        stream: Optional[torch.cuda.Stream] = None,\n    ):\n        if graph_capture_context is None:\n            if stream is None:\n                stream = self.device_module.Stream()\n            graph_capture_context = GraphCaptureContext(stream)\n        else:\n            stream = graph_capture_context.stream\n        # We don't need the context of custom quick allreduce because the ipc access\n        # is already collected in init() and we can capture the quick allreduce directly.\n        ca_comm = self.ca_comm\n        maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()\n\n        # ensure all initialization operations complete before attempting to\n        # capture the graph on another stream\n        curr_stream = get_current_device_stream_fast()\n        if curr_stream != stream:\n            stream.wait_stream(curr_stream)\n\n        with self.device_module.stream(stream), maybe_ca_context:\n            # In graph mode, we have to be very careful about the collective\n            # operations. 
The current status is:\n            #     allreduce \\ Mode   |  Eager  |  Graph  |\n            # --------------------------------------------\n            # quick allreduce        | enabled | enabled |\n            # custom allreduce       | enabled | enabled |\n            # PyNccl                 | disabled| enabled |\n            # PyMscclpp              | disabled| enabled |\n            # TorchSymmMem           | disabled| enabled |\n            # torch.distributed      | enabled | disabled|\n            #\n            # Note: When custom quick allreduce is enabled, a runtime check\n            #  will be performed. If the tensor size is too small, it will\n            #  automatically fall back to the next available option.\n            # Note that custom allreduce also performs a runtime check: if the\n            #  tensor size is too large, it will fall back to the next\n            #  available option.\n            # Note that PyMscclpp needs to register the tensor ahead of time,\n            #  which introduces large overhead in the eager case;\n            #  therefore it is only supported in the graph case.\n            # In summary: We select the appropriate allreduce method for\n            #  each mode based on the algorithm order in the table and\n            #  their usage conditions.\n            pynccl_comm = self.pynccl_comm\n            maybe_pynccl_context: Any\n            if not pynccl_comm:\n                maybe_pynccl_context = nullcontext()\n            else:\n                maybe_pynccl_context = pynccl_comm.change_state(\n                    enable=True, stream=get_current_device_stream_fast()\n                )\n\n            pymscclpp_comm = self.pymscclpp_comm\n            maybe_pymscclpp_context: Any\n            if not pymscclpp_comm:\n                maybe_pymscclpp_context = nullcontext()\n            else:\n                maybe_pymscclpp_context = pymscclpp_comm.change_state(enable=True)\n            with maybe_pynccl_context, 
maybe_pymscclpp_context:\n                yield graph_capture_context\n\n    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        User-facing all-reduce function before we actually call the\n        all-reduce operation.\n\n        We need this because Dynamo does not support passing an arbitrary\n        object (`self` in this case) to a custom op. We need to pass the\n         group name as a string, and then look up the group coordinator from\n         the group name, dispatch the all-reduce operation to the group\n         coordinator.\n\n        In addition, PyTorch custom ops do not support mutation or returning\n        a new tensor in the same op. So we need to figure out if the op is\n        in-place or out-of-place ahead of time.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return input_\n\n        # On AMD, use the deterministic 1-stage kernel when:\n        # - SGLANG_USE_1STAGE_ALLREDUCE=1 (explicitly enabled), OR\n        # - SGLANG_USE_1STAGE_ALLREDUCE not set AND --enable-deterministic-inference is on\n        if envs.SGLANG_USE_1STAGE_ALLREDUCE.is_set():\n            use_1stage_ar = envs.SGLANG_USE_1STAGE_ALLREDUCE.get()\n        else:\n            use_1stage_ar = envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get()\n        use_deterministic_ar = is_hip() and use_1stage_ar\n        if use_deterministic_ar:\n            if not input_.is_cpu and self.ca_comm is not None:\n                inp_size = input_.numel() * input_.element_size()\n                # Try unregistered mode first (faster for smaller tensors)\n                if inp_size < self.ca_comm.max_size:\n                    return self.ca_comm.deterministic_all_reduce(\n                        input_, registered=False\n                    )\n                # Use registered mode for larger tensors\n                self.ca_comm.register_buffer(input_)\n                return 
self.ca_comm.deterministic_all_reduce(input_, registered=True)\n\n        if input_.is_cpu:\n            if is_shm_available(input_.dtype, self.world_size, self.local_size):\n                torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)\n            else:\n                torch.distributed.all_reduce(input_, group=self.device_group)\n            return input_\n\n        if self.hpu_communicator is not None and not self.hpu_communicator.disabled:\n            return self.hpu_communicator.all_reduce(input_)\n\n        if self.xpu_communicator is not None and not self.xpu_communicator.disabled:\n            return self.xpu_communicator.all_reduce(input_)\n\n        if self.npu_communicator is not None and not self.npu_communicator.disabled:\n            return self.npu_communicator.all_reduce(input_)\n\n        if self.pynccl_comm is not None and self.is_symmetric_memory_enabled():\n            with self.pynccl_comm.change_state(\n                enable=True, stream=get_current_device_stream_fast()\n            ):\n                self.pynccl_comm.all_reduce(input_)\n                return input_\n\n        outplace_all_reduce_method = None\n        if (\n            self.ca_comm is not None\n            and not self.ca_comm.disabled\n            and self.ca_comm.should_custom_ar(input_)\n        ):\n            outplace_all_reduce_method = \"ca\"\n        elif (\n            self.qr_comm is not None\n            and not self.qr_comm.disabled\n            and self.qr_comm.should_quick_allreduce(input_)\n        ):\n            outplace_all_reduce_method = \"qr\"\n        elif (\n            self.pymscclpp_comm is not None\n            and not self.pymscclpp_comm.disabled\n            and self.pymscclpp_comm.should_mscclpp_allreduce(input_)\n        ):\n            outplace_all_reduce_method = \"pymscclpp\"\n        elif (\n            self.torch_symm_mem_comm is not None\n            and not self.torch_symm_mem_comm.disabled\n            and 
self.torch_symm_mem_comm.should_torch_symm_mem_allreduce(input_)\n        ):\n            outplace_all_reduce_method = \"torch_symm_mem\"\n        elif is_in_piecewise_cuda_graph():\n            # For piecewise cuda graph, we use pynccl outplace allreduce\n            outplace_all_reduce_method = \"pynccl\"\n        if outplace_all_reduce_method is not None:\n            return outplace_all_reduce(\n                input_,\n                group_name=self.unique_name,\n                outplace_all_reduce_method=outplace_all_reduce_method,\n            )\n        else:\n            inplace_all_reduce(input_, group_name=self.unique_name)\n            return input_\n\n    def fused_allreduce_rmsnorm(\n        self,\n        input_: torch.Tensor,\n        residual_inp_: torch.Tensor,\n        weight_: torch.Tensor,\n        eps: float,\n    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"Attempt fused all-reduce + RMSNorm via custom all-reduce communicator.\"\"\"\n        ca_comm = self.ca_comm\n        if ca_comm is None or getattr(ca_comm, \"disabled\", True):\n            return None\n\n        # Prefer communicator-native fused API when provided.\n        if hasattr(ca_comm, \"fused_allreduce_rmsnorm\"):\n            try:\n                return ca_comm.fused_allreduce_rmsnorm(\n                    input_, residual_inp_, weight_, eps\n                )\n            except Exception:\n                # Fall back to custom_fused_ar_rms path below.\n                pass\n\n        if not hasattr(ca_comm, \"custom_fused_ar_rms\"):\n            return None\n\n        # 1-stage policy for fused AR+RMSNorm:\n        # 1) Explicit env override wins.\n        # 2) Deterministic inference forces 1-stage for reproducibility.\n        # 3) Otherwise follow AITER's heuristic (small payloads only).\n        if envs.SGLANG_USE_1STAGE_ALLREDUCE.is_set():\n            use_1stage_ar = envs.SGLANG_USE_1STAGE_ALLREDUCE.get()\n        elif 
envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get():\n            use_1stage_ar = True\n        else:\n            total_bytes = input_.numel() * input_.element_size()\n            hidden_dim = input_.shape[-1]\n            use_1stage_ar = total_bytes <= 128 * 1024 and hidden_dim in {\n                512,\n                1024,\n                2048,\n                4096,\n            }\n\n        fused_outputs = ca_comm.custom_fused_ar_rms(\n            input_,\n            residual_inp_,\n            weight_,\n            eps,\n            use_1stage_ar,\n        )\n        return fused_outputs\n\n    def _all_reduce_out_place(\n        self, input_: torch.Tensor, outplace_all_reduce_method: str\n    ) -> torch.Tensor:\n        ca_comm = self.ca_comm\n        qr_comm = self.qr_comm\n        pymscclpp_comm = self.pymscclpp_comm\n        torch_symm_mem_comm = self.torch_symm_mem_comm\n        pynccl_comm = self.pynccl_comm\n        assert any([qr_comm, ca_comm, pymscclpp_comm, torch_symm_mem_comm, pynccl_comm])\n        # Stays None for an unrecognized method so the final assert fails cleanly.\n        out = None\n        if outplace_all_reduce_method == \"ca\":\n            assert not ca_comm.disabled\n            out = ca_comm.custom_all_reduce(input_)\n        elif outplace_all_reduce_method == \"qr\":\n            assert not qr_comm.disabled\n            out = qr_comm.quick_all_reduce(input_)\n        elif outplace_all_reduce_method == \"torch_symm_mem\":\n            assert not torch_symm_mem_comm.disabled\n            out = torch_symm_mem_comm.all_reduce(input_)\n        elif outplace_all_reduce_method == \"pymscclpp\":\n            assert not pymscclpp_comm.disabled\n            out = pymscclpp_comm.all_reduce(input_)\n        elif outplace_all_reduce_method == \"pynccl\":\n            with pynccl_comm.change_state(\n                enable=True, stream=get_current_device_stream_fast()\n            ):\n                out = pynccl_comm.outplace_all_reduce(input_)\n        assert out is not None\n        return out\n\n    def _all_reduce_in_place(self, input_: 
torch.Tensor) -> None:\n        pynccl_comm = self.pynccl_comm\n        torch_symm_mem_comm = self.torch_symm_mem_comm\n        if pynccl_comm is not None and not pynccl_comm.disabled:\n            pynccl_comm.all_reduce(input_)\n        elif torch_symm_mem_comm is not None and not torch_symm_mem_comm.disabled:\n            torch_symm_mem_comm.all_reduce(input_)\n        else:\n            torch.distributed.all_reduce(input_, group=self.device_group)\n\n    def _reduce_scatter_tensor(\n        self,\n        output: torch.Tensor,\n        input: torch.Tensor,\n    ) -> torch.Tensor:\n        pynccl_comm = self.pynccl_comm\n        if pynccl_comm is not None and (\n            not pynccl_comm.disabled or self.is_symmetric_memory_enabled()\n        ):\n            with pynccl_comm.change_state(\n                enable=True, stream=get_current_device_stream_fast()\n            ):\n                pynccl_comm.reduce_scatter(output, input)\n        else:\n            torch.distributed.reduce_scatter_tensor(\n                output, input, group=self.device_group\n            )\n        return output\n\n    def reduce_scatter_tensor(self, output: torch.Tensor, input: torch.Tensor):\n        if _is_npu:\n            self._reduce_scatter_tensor(output, input)\n        else:\n            reg_reduce_scatter_tensor(output, input, group_name=self.unique_name)\n\n    def reduce_scatter(\n        self,\n        output: torch.Tensor,\n        input_list: List[torch.Tensor],\n    ) -> torch.Tensor:\n        # TODO(ch-wan): support other backends\n        torch.distributed.reduce_scatter(output, input_list, group=self.device_group)\n        return output\n\n    def reduce_scatterv(\n        self,\n        input_: torch.Tensor,\n        output: Optional[torch.Tensor] = None,\n        sizes: Optional[List[int]] = None,\n    ) -> torch.Tensor:\n        world_size = self.world_size\n        pynccl_comm = self.pynccl_comm\n\n        with pynccl_comm.change_state(\n            enable=True, 
stream=get_current_device_stream_fast()\n        ):\n            assert (\n                pynccl_comm is not None and not pynccl_comm.disabled\n            ), \"pynccl is required for reduce_scatterv\"\n\n            if sizes is not None:\n                assert len(sizes) == world_size\n                assert input_.shape[0] == sum(sizes)\n                chunk_size = sizes[self.rank_in_group]\n            else:\n                assert input_.shape[0] % world_size == 0\n                chunk_size = input_.shape[0] // world_size\n            output_shape = (chunk_size,) + input_.shape[1:]\n\n            if output is None:\n                output = torch.empty(\n                    output_shape, dtype=input_.dtype, device=input_.device\n                )\n            else:\n                assert output.shape == output_shape\n\n            pynccl_comm.reduce_scatter(output, input_, sizes=sizes)\n            return output\n\n    def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):\n        pynccl_comm = self.pynccl_comm\n        if pynccl_comm is not None and (\n            not pynccl_comm.disabled or self.is_symmetric_memory_enabled()\n        ):\n            with pynccl_comm.change_state(\n                enable=True, stream=get_current_device_stream_fast()\n            ):\n                pynccl_comm.all_gather(output, input)\n        else:\n            torch.distributed.all_gather_into_tensor(\n                output, input, group=self.device_group\n            )\n\n    def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):\n        if _is_npu or _is_xpu:\n            self._all_gather_into_tensor(output, input)\n        else:\n            reg_all_gather_into_tensor(output, input, group_name=self.unique_name)\n\n    def cp_all_gather_into_tensor_async(\n        self, output: torch.Tensor, input: torch.Tensor, stream=None\n    ):\n        \"\"\"\n        Implement an asynchronous `allgather` operation on a specified 
stream.\n        The default `torch.distributed.all_gather_into_tensor` triggers event synchronization,\n        which blocks CPU-side kernel launches. This implementation uses the pynccl interface,\n        which skips that event synchronization and so removes the blocking.\n        \"\"\"\n        assert (\n            stream is not None\n        ), f\"Invalid stream ({stream}). Please specify the stream to use when calling cp_all_gather_into_tensor_async.\"\n        pynccl_comm = self.pynccl_comm\n        if pynccl_comm is None or pynccl_comm.disabled:\n            self.all_gather_into_tensor(output, input)\n        else:\n            pynccl_comm.cp_all_gather_into_tensor(output, input, stream=stream)\n\n    def all_gather(\n        self,\n        input_: torch.Tensor,\n        dim: int = -1,\n        output_tensor_list: Optional[List[torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        world_size = self.world_size\n        # Bypass the function if we are using only 1 GPU.\n        if world_size == 1:\n            if output_tensor_list is not None:\n                logger.warning(\n                    \"Performing in-place all-gather with a group size of 1. 
\"\n                    \"This may be unnecessary; consider bypassing it for better efficiency.\"\n                )\n                output_tensor_list[0].copy_(input_)\n                return None\n            else:\n                return input_\n\n        if output_tensor_list is not None:\n            # TODO(ch-wan): support other backends\n            return torch.distributed.all_gather(\n                output_tensor_list, input_, group=self.device_group\n            )\n\n        assert (\n            -input_.dim() <= dim < input_.dim()\n        ), f\"Invalid dim ({dim}) for input tensor with shape {input_.size()}\"\n\n        # For HPUs, use HPU communicator.\n        hpu_comm = self.hpu_communicator\n        if hpu_comm is not None and not hpu_comm.disabled:\n            return hpu_comm.all_gather(input_, dim)\n\n        # For NPUs, use NPU communicator.\n        npu_comm = self.npu_communicator\n        if npu_comm is not None and not npu_comm.disabled:\n            return npu_comm.all_gather(input_, dim)\n\n        if dim < 0:\n            # Convert negative dim to positive.\n            dim += input_.dim()\n        input_size = input_.size()\n        # NOTE: we have to use concat-style all-gather here;\n        # stack-style all-gather has compatibility issues with\n        # torch.compile; 
see https://github.com/pytorch/pytorch/issues/138795\n        output_size = (input_size[0] * world_size,) + input_size[1:]\n        # Allocate output tensor.\n        with self.use_symmetric_memory(\n            self, disabled=not self.is_allocation_symmetric()\n        ):\n            output_tensor = torch.empty(\n                output_size, dtype=input_.dtype, device=input_.device\n            )\n\n        # All-gather.\n        if input_.is_cpu:\n            if is_shm_available(input_.dtype, self.world_size, self.local_size):\n                return torch.ops.sgl_kernel.shm_allgather(input_, dim)\n            else:\n                torch.distributed.all_gather_into_tensor(\n                    output_tensor, input_, group=self.device_group\n                )\n        else:\n            self.all_gather_into_tensor(output_tensor, input_)\n\n        # Reshape\n        output_tensor = output_tensor.reshape((world_size,) + input_size)\n        output_tensor = output_tensor.movedim(0, dim)\n        output_tensor = output_tensor.reshape(\n            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]\n        )\n        return output_tensor\n\n    def all_gatherv(\n        self,\n        input_: Union[torch.Tensor, List[torch.Tensor]],\n        sizes: Optional[List[int]] = None,\n    ) -> Union[torch.Tensor, List[torch.Tensor]]:\n        \"\"\"\n        Supports varying sizes per rank and a list of input tensors.\n        `sizes`: a list of length world_size giving the number of items to gather from each rank.\n        \"\"\"\n        world_size = self.world_size\n        pynccl_comm = self.pynccl_comm\n\n        with pynccl_comm.change_state(\n            enable=True, stream=get_current_device_stream_fast()\n        ):\n            assert (\n                pynccl_comm is not None and not pynccl_comm.disabled\n            ), \"pynccl is required for all_gatherv\"\n\n            def _all_gather_allocate_output(\n                input_: torch.Tensor, sizes: 
Optional[List[int]] = None\n            ):\n                input_size = input_.size()\n                if sizes is not None:\n                    assert len(sizes) == world_size\n                    assert input_.shape[0] == sizes[self.rank_in_group]\n                    output_size = (sum(sizes),) + input_size[1:]\n                    # 'sizes' is not needed if all inputs in the same group have the same shape\n                    if all(s == sizes[0] for s in sizes):\n                        sizes = None\n                else:\n                    output_size = (input_size[0] * world_size,) + input_size[1:]\n                # Allocate output tensor.\n                with self.use_symmetric_memory(self, disabled=sizes is not None):\n                    output_tensor = torch.empty(\n                        output_size, dtype=input_.dtype, device=input_.device\n                    )\n                return output_tensor, sizes\n\n            if isinstance(input_, torch.Tensor):\n                input_ = [input_]\n\n            output_list = []\n            size_list = []\n            for inp in input_:\n                output_tensor, s = _all_gather_allocate_output(inp, sizes=sizes)\n                output_list.append(output_tensor)\n                size_list.append(s)\n\n            pynccl_comm.group_start()\n            for i, inp in enumerate(input_):\n                pynccl_comm.all_gather(output_list[i], inp, sizes=size_list[i])\n            pynccl_comm.group_end()\n\n            return output_list\n\n    def gather(\n        self, input_: torch.Tensor, dst: int = 0, dim: int = -1\n    ) -> Optional[torch.Tensor]:\n        \"\"\"\n        NOTE: We assume that the input tensor is on the same device across\n        all the ranks.\n        NOTE: `dst` is the local rank of the destination rank.\n        \"\"\"\n        world_size = self.world_size\n        # Bypass the function if we are using only 1 GPU.\n        if world_size == 1:\n            return input_\n    
    assert (\n            -input_.dim() <= dim < input_.dim()\n        ), f\"Invalid dim ({dim}) for input tensor with shape {input_.size()}\"\n        if dim < 0:\n            # Convert negative dim to positive.\n            dim += input_.dim()\n        if self.xpu_communicator is not None and not self.xpu_communicator.disabled:\n            return self.xpu_communicator.gather(input_, self.rank_in_group, dst, dim)\n        # Allocate output tensor.\n        if self.rank_in_group == dst:\n            gather_list = [torch.empty_like(input_) for _ in range(world_size)]\n        else:\n            gather_list = None\n        # Gather.\n        torch.distributed.gather(\n            input_, gather_list, dst=self.ranks[dst], group=self.device_group\n        )\n        if self.rank_in_group == dst:\n            output_tensor = torch.cat(gather_list, dim=dim)\n        else:\n            output_tensor = None\n        return output_tensor\n\n    def broadcast(self, input_: torch.Tensor, src: int = 0):\n        \"\"\"Broadcast the input tensor.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return input_\n        # Broadcast.\n        torch.distributed.broadcast(\n            input_, src=self.ranks[src], group=self.device_group\n        )\n        return input_\n\n    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):\n        \"\"\"Broadcast the input object.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return obj\n        if self.mq_broadcaster is not None:\n            assert src == 0, \"Message queue broadcaster only supports src=0\"\n            
return self.mq_broadcaster.broadcast_object(obj)\n        if self.rank_in_group == src:\n            torch.distributed.broadcast_object_list(\n                [obj], src=self.ranks[src], group=self.cpu_group\n            )\n            return obj\n        else:\n            recv = [None]\n            torch.distributed.broadcast_object_list(\n                recv, src=self.ranks[src], group=self.cpu_group\n            )\n            return recv[0]\n\n    def broadcast_object_list(\n        self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None\n    ):\n        \"\"\"Broadcast the input object list.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return obj_list\n        # Broadcast.\n        torch.distributed.broadcast_object_list(\n            obj_list, src=self.ranks[src], group=self.device_group\n        )\n        return obj_list\n\n    def all_gather_object(self, obj: Any) -> List[Any]:\n        objs = [None] * self.world_size\n        torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)\n        return objs\n\n    def send_object(\n        self,\n        obj: Any,\n        dst: int,\n        async_send: bool = False,\n    ) -> List[P2PWork]:\n        \"\"\"\n        Send the input object list to the destination rank.\n        This function uses the CPU group for all communications.\n\n        TODO: If you want to use GPU communication, please add a new argument (e.g., data_group, group),\n        use other functions (e.g., send), or implement a new function (e.g., send_object_device).\n\n        NOTE: `dst` is the local rank of the destination rank.\n        \"\"\"\n\n        assert dst < self.world_size, f\"Invalid dst rank ({dst})\"\n        assert dst != self.rank_in_group, (\n            \"Invalid destination rank. 
Destination rank is the same \"\n            \"as the current rank.\"\n        )\n        send_func = torch.distributed.isend if async_send else torch.distributed.send\n\n        # Serialize object to tensor and get the size as well\n        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)\n        size_tensor = torch.tensor(\n            [object_tensor.numel()], dtype=torch.long, device=\"cpu\"\n        )\n\n        # Send object size\n        p2p_work = []\n        size_work = send_func(\n            size_tensor,\n            self.ranks[dst],\n            group=self.cpu_group,\n        )\n        if async_send:\n            p2p_work.append(P2PWork(size_work, size_tensor))\n\n        object_work = send_func(\n            object_tensor,\n            self.ranks[dst],\n            group=self.cpu_group,\n        )\n        if async_send:\n            p2p_work.append(P2PWork(object_work, object_tensor))\n\n        return p2p_work\n\n    def recv_object(\n        self,\n        src: int,\n    ) -> Any:\n        \"\"\"Receive the input object list from the source rank.\"\"\"\n        \"\"\"NOTE: `src` is the local rank of the source rank.\"\"\"\n\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n        assert (\n            src != self.rank_in_group\n        ), \"Invalid source rank. 
Source rank is the same as the current rank.\"\n\n        size_tensor = torch.empty(1, dtype=torch.long, device=\"cpu\")\n\n        # Receive object size\n        # We have to use irecv here to make it work for both isend and send.\n        work = torch.distributed.irecv(\n            size_tensor, src=self.ranks[src], group=self.cpu_group\n        )\n        work.wait()\n\n        # Tensor to receive serialized objects into.\n        object_tensor: Any = torch.empty(  # type: ignore[call-overload]\n            size_tensor.item(),  # type: ignore[arg-type]\n            dtype=torch.uint8,\n            device=\"cpu\",\n        )\n\n        work = torch.distributed.irecv(\n            object_tensor, src=self.ranks[src], group=self.cpu_group\n        )\n        work.wait()\n\n        obj = pickle.loads(object_tensor.numpy())\n        return obj\n\n    def broadcast_tensor_dict(\n        self,\n        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,\n        src: int = 0,\n        group: Optional[ProcessGroup] = None,\n        metadata_group: Optional[ProcessGroup] = None,\n    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:\n        \"\"\"Broadcast the input tensor dictionary.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if not torch.distributed.is_initialized() or self.world_size == 1:\n            return tensor_dict\n\n        group = self.device_group\n        metadata_group = self.cpu_group\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        rank_in_group = self.rank_in_group\n        if rank_in_group == src:\n            metadata_list: List[Tuple[Any, Any]] = []\n            assert isinstance(\n                tensor_dict, dict\n            ), f\"Expecting a dictionary, got {type(tensor_dict)}\"\n            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)\n            # `metadata_list` lives in CPU 
memory.\n            # `broadcast_object_list` has serialization & deserialization,\n            # all happening on CPU. Therefore, we can use the CPU group.\n            self.broadcast_object(metadata_list, src=src)\n            async_handles = []\n            for tensor in tensor_list:\n                if tensor.numel() == 0:\n                    # Skip broadcasting empty tensors.\n                    continue\n                if tensor.is_cpu:\n                    # use metadata_group for CPU tensors\n                    handle = torch.distributed.broadcast(\n                        tensor, src=self.ranks[src], group=metadata_group, async_op=True\n                    )\n                else:\n                    # use group for GPU tensors\n                    handle = torch.distributed.broadcast(\n                        tensor, src=self.ranks[src], group=group, async_op=True\n                    )\n                async_handles.append(handle)\n            for async_handle in async_handles:\n                async_handle.wait()\n\n        else:\n            metadata_list = self.broadcast_object(None, src=src)\n            tensor_dict = {}\n            async_handles = []\n            for key, value in metadata_list:\n                if isinstance(value, TensorMetadata):\n                    tensor = torch.empty(\n                        value.size, dtype=value.dtype, device=value.device\n                    )\n                    if tensor.numel() == 0:\n                        # Skip broadcasting empty tensors.\n                        tensor_dict[key] = tensor\n                        continue\n                    if tensor.is_cpu:\n                        # use metadata_group for CPU tensors\n                        handle = torch.distributed.broadcast(\n                            tensor,\n                            src=self.ranks[src],\n                            group=metadata_group,\n                            async_op=True,\n                        )\n 
                   else:\n                        # use group for GPU tensors\n                        handle = torch.distributed.broadcast(\n                            tensor, src=self.ranks[src], group=group, async_op=True\n                        )\n                    async_handles.append(handle)\n                    tensor_dict[key] = tensor\n                else:\n                    tensor_dict[key] = value\n            for async_handle in async_handles:\n                async_handle.wait()\n        return tensor_dict\n\n    def send_tensor_dict(\n        self,\n        tensor_dict: Dict[str, Union[torch.Tensor, Any]],\n        dst: Optional[int] = None,\n        all_gather_group: Optional[\"GroupCoordinator\"] = None,\n        async_send: bool = False,\n    ) -> Optional[List[P2PWork]]:\n        \"\"\"Send the input tensor dictionary.\n        NOTE: `dst` is the local rank of the destination rank.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return tensor_dict\n\n        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size\n        all_gather_rank = (\n            0 if all_gather_group is None else all_gather_group.rank_in_group\n        )\n\n        group = self.device_group\n        metadata_group = self.cpu_group\n\n        if dst is None:\n            dst = (self.rank_in_group + 1) % self.world_size\n        assert dst < self.world_size, f\"Invalid dst rank ({dst})\"\n\n        assert isinstance(\n            tensor_dict, dict\n        ), f\"Expecting a dictionary, got {type(tensor_dict)}\"\n        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)\n        # Note: While switching to Device-to-Device (D2D) would introduce an extra\n        # Device-to-Host (D2H) memory copy overhead for serialization, our benchmarks\n        # show better overall transmission performance with D2D due to:\n        # 1. 
Superior D2D transfer bandwidth\n        # 2. Ability to overlap send and recv operations\n        # Thus the net performance gain justifies this approach.\n\n        send_func = torch.distributed.isend if async_send else torch.distributed.send\n        p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send)\n\n        for tensor in tensor_list:\n            if tensor.numel() == 0:\n                # Skip sending empty tensors.\n                continue\n\n            # send-allgather: send only a slice, then do allgather.\n            if all_gather_group is not None and tensor.numel() % all_gather_size == 0:\n                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]\n\n            comm_group = metadata_group if tensor.is_cpu else group\n            work = send_func(tensor, self.ranks[dst], group=comm_group)\n            if async_send:\n                p2p_works.append(P2PWork(work, tensor))\n        return p2p_works\n\n    def recv_tensor_dict(\n        self,\n        src: Optional[int] = None,\n        all_gather_group: Optional[\"GroupCoordinator\"] = None,\n    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:\n        \"\"\"Recv the input tensor dictionary.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if not torch.distributed.is_initialized() or self.world_size == 1:\n            return None\n\n        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size\n        all_gather_rank = (\n            0 if all_gather_group is None else all_gather_group.rank_in_group\n        )\n\n        group = self.device_group\n        metadata_group = self.cpu_group\n\n        if src is None:\n            src = (self.rank_in_group - 1) % self.world_size\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        recv_metadata_list = self.recv_object(src=src)\n        tensor_dict: Dict[str, Any] 
= {}\n        for key, value in recv_metadata_list:\n            if isinstance(value, TensorMetadata):\n                tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)\n                if tensor.numel() == 0:\n                    # Skip receiving empty tensors.\n                    tensor_dict[key] = tensor\n                    continue\n\n                # send-allgather: receive only a slice, then do allgather.\n                use_all_gather = (\n                    all_gather_group is not None\n                    and tensor.numel() % all_gather_size == 0\n                )\n\n                if use_all_gather:\n                    orig_shape = tensor.shape\n                    tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]\n\n                # We have to use irecv here to make it work for both isend and send.\n                comm_group = metadata_group if tensor.is_cpu else group\n                work = torch.distributed.irecv(\n                    tensor, src=self.ranks[src], group=comm_group\n                )\n                work.wait()\n\n                if use_all_gather:\n                    tensor = all_gather_group.all_gather(tensor, dim=0)\n                    tensor = tensor.reshape(orig_shape)\n\n                tensor_dict[key] = tensor\n            else:\n                tensor_dict[key] = value\n        return tensor_dict\n\n    def barrier(self):\n        \"\"\"Barrier synchronization among the group.\n        NOTE: don't use `device_group` here! `barrier` in NCCL is\n        terrible because it is internally a broadcast operation with\n        secretly created GPU tensors. It is easy to mess up the current\n        device. 
Use the CPU group instead.\n        \"\"\"\n        torch.distributed.barrier(group=self.cpu_group)\n\n    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:\n        \"\"\"Sends a tensor to the destination rank in a non-blocking way\"\"\"\n        \"\"\"NOTE: `dst` is the local rank of the destination rank.\"\"\"\n        if dst is None:\n            dst = (self.rank_in_group + 1) % self.world_size\n\n        pynccl_comm = self.pynccl_comm\n        if pynccl_comm is not None and not pynccl_comm.disabled:\n            pynccl_comm.send(tensor, dst)\n        else:\n            torch.distributed.send(tensor, self.ranks[dst], self.device_group)\n\n    def recv(\n        self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None\n    ) -> torch.Tensor:\n        \"\"\"Receives a tensor from the source rank.\"\"\"\n        \"\"\"NOTE: `src` is the local rank of the source rank.\"\"\"\n        if src is None:\n            src = (self.rank_in_group - 1) % self.world_size\n\n        tensor = torch.empty(size, dtype=dtype, device=self.device)\n        pynccl_comm = self.pynccl_comm\n        if pynccl_comm is not None and not pynccl_comm.disabled:\n            pynccl_comm.recv(tensor, src)\n        else:\n            torch.distributed.recv(tensor, self.ranks[src], self.device_group)\n        return tensor\n\n    def destroy(self):\n        if self.device_group is not None:\n            torch.distributed.destroy_process_group(self.device_group)\n            self.device_group = None\n        if self.cpu_group is not None:\n            torch.distributed.destroy_process_group(self.cpu_group)\n            self.cpu_group = None\n        if self.pynccl_comm is not None:\n            self.pynccl_comm = None\n        if self.ca_comm is not None:\n            self.ca_comm = None\n        if self.mq_broadcaster is not None:\n            self.mq_broadcaster = None\n\n\n_WORLD: Optional[GroupCoordinator] = None\n\n\ndef get_world_group() -> 
GroupCoordinator:\n    assert _WORLD is not None, \"world group is not initialized\"\n    return _WORLD\n\n\ndef init_world_group(\n    ranks: List[int], local_rank: int, backend: str\n) -> GroupCoordinator:\n    return GroupCoordinator(\n        group_ranks=[ranks],\n        local_rank=local_rank,\n        torch_distributed_backend=backend,\n        use_pynccl=False,\n        use_pymscclpp=False,\n        use_custom_allreduce=False,\n        use_torch_symm_mem_all_reduce=False,\n        use_hpu_communicator=False,\n        use_xpu_communicator=False,\n        use_npu_communicator=False,\n        group_name=\"world\",\n    )\n\n\ndef init_model_parallel_group(\n    group_ranks: List[List[int]],\n    local_rank: int,\n    backend: str,\n    use_pynccl: Optional[bool] = None,\n    use_custom_allreduce: Optional[bool] = None,\n    use_message_queue_broadcaster: bool = False,\n    group_name: Optional[str] = None,\n    use_mscclpp_allreduce: Optional[bool] = None,\n    pynccl_use_current_stream: bool = True,\n    use_torch_symm_mem_allreduce: Optional[bool] = None,\n) -> GroupCoordinator:\n    if use_custom_allreduce is None:\n        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE\n    if use_mscclpp_allreduce is None:\n        use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE\n    if use_torch_symm_mem_allreduce is None:\n        use_torch_symm_mem_allreduce = _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE\n    return GroupCoordinator(\n        group_ranks=group_ranks,\n        local_rank=local_rank,\n        torch_distributed_backend=backend,\n        use_pynccl=(\n            not (_is_npu or _is_xpu or backend == \"mooncake\")\n            if use_pynccl is None\n            else use_pynccl\n        ),\n        use_pymscclpp=use_mscclpp_allreduce,\n        use_custom_allreduce=use_custom_allreduce,\n        use_torch_symm_mem_all_reduce=use_torch_symm_mem_allreduce,\n        use_hpu_communicator=True,\n        use_xpu_communicator=True,\n        use_npu_communicator=True,\n    
    use_message_queue_broadcaster=use_message_queue_broadcaster,\n        group_name=group_name,\n        pynccl_use_current_stream=pynccl_use_current_stream,\n    )\n\n\n_TP: Optional[GroupCoordinator] = None\n_ATTN_TP: Optional[GroupCoordinator] = None\n_ATTN_CP: Optional[GroupCoordinator] = None\n\n# duplicate GroupCoordinator for prefill in PD-Multiplexing\n_PDMUX_PREFILL_TP_GROUP: Optional[GroupCoordinator] = None\n\n_ENABLE_PDMUX_P_TP: bool = False\n\n\ndef set_pdmux_status(enable_prefill_multiplexing: bool):\n    global _ENABLE_PDMUX_P_TP\n    _ENABLE_PDMUX_P_TP = enable_prefill_multiplexing\n\n\ndef get_tp_group() -> GroupCoordinator:\n    if _ENABLE_PDMUX_P_TP:\n        assert (\n            _PDMUX_PREFILL_TP_GROUP is not None\n        ), \"tensor model parallel group for PD-Multiplexing Prefill is not initialized\"\n        return _PDMUX_PREFILL_TP_GROUP\n    assert _TP is not None, \"tensor model parallel group is not initialized\"\n    return _TP\n\n\ndef get_attn_tp_group() -> GroupCoordinator:\n    assert (\n        _ATTN_TP is not None\n    ), \"attention tensor model parallel group is not initialized\"\n    return _ATTN_TP\n\n\ndef get_attn_cp_group() -> GroupCoordinator:\n    assert (\n        _ATTN_CP is not None\n    ), \"attention context model parallel group is not initialized\"\n    return _ATTN_CP\n\n\n_MOE_DP: Optional[GroupCoordinator] = None\n_MOE_EP: Optional[GroupCoordinator] = None\n_MOE_TP: Optional[GroupCoordinator] = None\n\n\ndef get_moe_dp_group() -> GroupCoordinator:\n    assert _MOE_DP is not None, \"moe data parallel group is not initialized\"\n    return _MOE_DP\n\n\ndef get_moe_ep_group() -> GroupCoordinator:\n    assert _MOE_EP is not None, \"expert model parallel group is not initialized\"\n    return _MOE_EP\n\n\ndef get_moe_tp_group() -> GroupCoordinator:\n    assert _MOE_TP is not None, \"moe tensor parallel group is not initialized\"\n    return _MOE_TP\n\n\n# kept for backward 
compatibility\nget_tensor_model_parallel_group = get_tp_group\n\n_PP: Optional[GroupCoordinator] = None\n\n\ndef get_pp_group() -> GroupCoordinator:\n    assert _PP is not None, \"pipeline model parallel group is not initialized\"\n    return _PP\n\n\n# kept for backward compatibility\nget_pipeline_model_parallel_group = get_pp_group\n\n\ndef get_mooncake_transfer_engine():\n    \"\"\"\n    Return the shared MooncakeTransferEngine if initialized in device_communicators,\n    else None. Used by disaggregation mooncake backend and mem_cache mooncake_store.\n    \"\"\"\n    from sglang.srt.distributed.device_communicators.mooncake_transfer_engine import (\n        get_mooncake_transfer_engine as _get_engine,\n    )\n\n    return _get_engine()\n\n\n@contextmanager\ndef graph_capture(stream: Optional[torch.cuda.Stream] = None):\n    \"\"\"\n    `graph_capture` is a context manager which should surround the code that\n    is capturing the CUDA graph. Its main purpose is to ensure that the\n    some operations will be run after the graph is captured, before the graph\n    is replayed. It returns a `GraphCaptureContext` object which contains the\n    necessary data for the graph capture. Currently, it only contains the\n    stream that the graph capture is running on. This stream is set to the\n    current CUDA stream when the context manager is entered and reset to the\n    default stream when the context manager is exited. 
This is to ensure that\n    the graph capture is running on a separate stream from the default stream,\n    in order to explicitly distinguish the kernels to capture\n    from other kernels possibly launched on background in the default stream.\n    \"\"\"\n    with get_tp_group().graph_capture(\n        stream=stream\n    ) as context, get_pp_group().graph_capture(context):\n        yield context\n\n\nlogger = logging.getLogger(__name__)\n\n_ENABLE_CUSTOM_ALL_REDUCE = True\n_ENABLE_MSCCLPP_ALL_REDUCE = False\n_ENABLE_TORCH_SYMM_MEM_ALL_REDUCE = False\n\n\ndef set_custom_all_reduce(enable: bool):\n    global _ENABLE_CUSTOM_ALL_REDUCE\n    _ENABLE_CUSTOM_ALL_REDUCE = enable\n\n\ndef set_mscclpp_all_reduce(enable: bool):\n    global _ENABLE_MSCCLPP_ALL_REDUCE\n    _ENABLE_MSCCLPP_ALL_REDUCE = enable\n\n\ndef set_torch_symm_mem_all_reduce(enable: bool):\n    global _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE\n    _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE = enable\n\n\n_DEVICE_TO_DISTRIBUTED_BACKEND = {\n    \"cuda\": \"nccl\",\n    \"xpu\": \"xccl\",\n    \"hpu\": \"hccl\",\n    \"cpu\": \"gloo\",\n    \"npu\": \"hccl\",\n    \"musa\": \"mccl\",\n}\n\n\ndef get_default_distributed_backend(device: str) -> str:\n    return _DEVICE_TO_DISTRIBUTED_BACKEND.get(device, \"gloo\")\n\n\ndef _create_global_tcp_store(rank: int, world_size: int) -> None:\n    \"\"\"Create a global TCPStore for coordination across ranks.\n\n    This function creates a TCPStore that all ranks can use for coordination\n    (e.g., for NIXL buffer setup).\n    \"\"\"\n    from torch.distributed import TCPStore\n\n    master_ip = os.environ.get(\"MASTER_ADDR\")\n\n    if not master_ip:\n        logger.warning(\n            \"Could not determine master IP for global TCPStore. 
\"\n            \"Broadcasting from rank 0 to all ranks.\"\n        )\n\n    base_store_port = envs.SGLANG_TCP_STORE_PORT.get()\n\n    # Rank 0 gets its local IP and broadcasts it to all ranks\n    # Use broadcast_object_list which works with any backend (handles CPU/GPU automatically)\n    if not master_ip:\n        if rank == 0:\n            master_ip = get_local_ip_auto()\n            ip_list = [master_ip]\n        else:\n            ip_list = [None]\n\n        torch.distributed.broadcast_object_list(ip_list, src=0)\n        master_ip = ip_list[0]\n\n    try:\n        tcp_store = TCPStore(\n            host_name=master_ip,\n            port=base_store_port,\n            world_size=world_size,\n            is_master=(rank == 0),\n        )\n        set_global_tcp_store(tcp_store)\n        logger.info(\n            \"Created global TCPStore at %s:%d (rank=%d, world_size=%d)\",\n            master_ip,\n            base_store_port,\n            rank,\n            world_size,\n        )\n    except Exception as e:\n        logger.warning(\n            \"Failed to create global TCPStore at %s:%d: %s. 
\"\n            \"Components requiring TCPStore (like NIXL) may not work.\",\n            master_ip,\n            base_store_port,\n            e,\n        )\n\n\ndef init_distributed_environment(\n    world_size: int = -1,\n    rank: int = -1,\n    distributed_init_method: str = \"env://\",\n    local_rank: int = -1,\n    backend: str = \"nccl\",\n    timeout: Optional[int] = None,\n    moe_a2a_backend: Optional[str] = None,\n):\n    logger.debug(\n        \"world_size=%d rank=%d local_rank=%d \" \"distributed_init_method=%s backend=%s\",\n        world_size,\n        rank,\n        local_rank,\n        distributed_init_method,\n        backend,\n    )\n    if \"mooncake\" in backend:\n        try:\n            from mooncake import ep as mooncake_ep\n        except ImportError as e:\n            raise ImportError(\n                \"Please install mooncake by following the instructions at \"\n                \"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md \"  # noqa: E501\n                \"to run SGLang with Mooncake Backend.\"\n            ) from e\n        mooncake_ep.set_host_ip(get_local_ip_auto())\n\n    if not torch.distributed.is_initialized():\n        assert distributed_init_method is not None, (\n            \"distributed_init_method must be provided when initializing \"\n            \"distributed environment\"\n        )\n        if timeout is not None:\n            assert isinstance(timeout, (int)), \"timeout must be a number\"\n            assert timeout > 0, \"timeout must be positive\"\n            timeout = timedelta(seconds=timeout)\n\n        pg_options = get_torch_distributed_pg_options()\n\n        # this backend is used for WORLD\n        torch.distributed.init_process_group(\n            backend=backend,\n            init_method=distributed_init_method,\n            world_size=world_size,\n            rank=rank,\n            timeout=timeout,\n            pg_options=pg_options,\n        )\n\n        # Create a global 
TCPStore for coordination (used by NIXL)\n        if moe_a2a_backend == \"nixl\":\n            _create_global_tcp_store(rank, world_size)\n\n    # set the local rank\n    # local_rank is not available in torch ProcessGroup,\n    # see https://github.com/pytorch/pytorch/issues/122816\n    if local_rank == -1:\n        # local rank not set, this usually happens in single-node\n        # setting, where we can use rank as local rank\n        if distributed_init_method == \"env://\":\n            local_rank = int(os.environ.get(\"LOCAL_RANK\", \"0\"))\n        else:\n            local_rank = rank\n    global _WORLD\n    if _WORLD is None:\n        ranks = list(range(torch.distributed.get_world_size()))\n        _WORLD = init_world_group(ranks, local_rank, backend)\n    else:\n        assert (\n            _WORLD.world_size == torch.distributed.get_world_size()\n        ), \"world group already initialized with a different world size\"\n\n\ndef initialize_model_parallel(\n    tensor_model_parallel_size: int = 1,\n    expert_model_parallel_size: int = 1,\n    pipeline_model_parallel_size: int = 1,\n    attention_data_parallel_size: int = 1,\n    attention_context_model_parallel_size: int = 1,\n    moe_data_model_parallel_size: int = 1,\n    backend: Optional[str] = None,\n    duplicate_tp_group: bool = False,\n) -> None:\n    \"\"\"\n    Initialize model parallel groups.\n\n    Arguments:\n        tensor_model_parallel_size: number of GPUs used for tensor model\n            parallelism.\n        expert_model_parallel_size: number of GPUs used for expert model\n            parallelism.\n        pipeline_model_parallel_size: number of GPUs used for pipeline model\n            parallelism.\n        attention_data_parallel_size: number of GPUs used for attention data\n            parallelism.\n        attention_context_model_parallel_size: number of GPUs used for attention context\n            parallelism.\n        moe_data_model_parallel_size: number of GPUs used for moe 
data\n            parallelism.\n\n    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we\n    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize\n    the model pipeline. The present function will\n    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:\n        4 tensor model-parallel groups:\n            [g0, g1], [g2, g3], [g4, g5], [g6, g7]\n        2 pipeline model-parallel groups:\n            [g0, g2, g4, g6], [g1, g3, g5, g7]\n\n    Let's say we use 2 GPUs for attention context parallelism (attn_cp_size=2) and 4 GPUs for\n    attention tensor parallelism (attn_tp_size=4). As for MoE part, we use 2 GPUs for moe data\n    parallelism (moe_dp_size=2) and 4 GPUs for moe expert parallelism (moe_ep_size=4). The present\n    function will create the following groups:\n        2 tensor model-parallel groups:\n            [g0, g1, g2, g3], [g4, g5, g6, g7]\n        4 attention context-parallel groups:\n            [g0, g4], [g1, g5], [g2, g6], [g3, g7]\n        2 moe expert-parallel groups:\n            [g0, g1, g2, g3], [g4, g5, g6, g7]\n        4 moe data-parallel groups:\n            [g0, g4], [g1, g5], [g2, g6], [g3, g7]\n\n    Note that for efficiency, the caller should make sure adjacent ranks\n    are on the same DGX box. For example if we are using 2 DGX-1 boxes\n    with a total of 16 GPUs, rank 0 to 7 belong to the first box and\n    ranks 8 to 15 belong to the second box.\n    \"\"\"\n    # Get world size and rank. 
Ensure some consistencies.\n    assert torch.distributed.is_initialized()\n    world_size: int = torch.distributed.get_world_size()\n    backend = backend or torch.distributed.get_backend(get_world_group().device_group)\n\n    if world_size != tensor_model_parallel_size * pipeline_model_parallel_size:\n        raise RuntimeError(\n            f\"world_size ({world_size}) is not equal to \"\n            f\"tensor_model_parallel_size ({tensor_model_parallel_size}) x \"\n            f\"pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n        )\n\n    # Build the tensor model-parallel groups.\n    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n    global _TP\n    assert _TP is None, \"tensor model parallel group is already initialized\"\n    group_ranks = []\n    for tp_group_idx in range(num_tensor_model_parallel_groups):\n        ranks = list(\n            range(\n                tp_group_idx * tensor_model_parallel_size,\n                (tp_group_idx + 1) * tensor_model_parallel_size,\n            )\n        )\n        group_ranks.append(ranks)\n\n    # message queue broadcaster is only used in tensor model parallel group\n    _TP = init_model_parallel_group(\n        group_ranks,\n        get_world_group().local_rank,\n        backend,\n        use_message_queue_broadcaster=get_bool_env_var(\n            \"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER\", \"true\"\n        ),\n        group_name=\"tp\",\n        pynccl_use_current_stream=duplicate_tp_group,\n    )\n\n    if duplicate_tp_group:\n        global _PDMUX_PREFILL_TP_GROUP\n        assert (\n            _PDMUX_PREFILL_TP_GROUP is None\n        ), \"tensor model parallel group for PD-Multiplexing Prefill is already initialized\"\n        _PDMUX_PREFILL_TP_GROUP = init_model_parallel_group(\n            group_ranks,\n            get_world_group().local_rank,\n            backend,\n            use_message_queue_broadcaster=get_bool_env_var(\n                
\"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER\", \"true\"\n            ),\n            group_name=\"pdmux_prefill_tp\",\n            pynccl_use_current_stream=True,\n        )\n        if _TP.pynccl_comm:\n            _TP.pynccl_comm.disabled = False\n            _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False\n\n    attn_dp_size = attention_data_parallel_size\n    attn_cp_size = attention_context_model_parallel_size\n    attn_tp_size = tensor_model_parallel_size // attn_cp_size // attn_dp_size\n\n    global _ATTN_CP\n    assert (\n        _ATTN_CP is None\n    ), \"attention context model parallel group is already initialized\"\n    if attn_cp_size == tensor_model_parallel_size:\n        _ATTN_CP = _TP\n    else:\n        group_ranks = []\n        for tp_group_idx in range(num_tensor_model_parallel_groups):\n            for dp_idx in range(attn_dp_size):\n                for attn_tp_idx in range(attn_tp_size):\n                    st = (\n                        tp_group_idx * tensor_model_parallel_size\n                        + dp_idx * attn_tp_size * attn_cp_size\n                        + attn_tp_idx\n                    )\n                    en = (\n                        tp_group_idx * tensor_model_parallel_size\n                        + (dp_idx + 1) * attn_tp_size * attn_cp_size\n                        + attn_tp_idx\n                    )\n                    ranks = list(range(st, en, attn_tp_size))\n                    group_ranks.append(ranks)\n        _ATTN_CP = init_model_parallel_group(\n            group_ranks,\n            get_world_group().local_rank,\n            backend,\n            group_name=\"attn_cp\",\n        )\n\n    from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP\n\n    global _ATTN_TP\n    assert (\n        _ATTN_TP is None\n    ), \"attention tensor model parallel group is already initialized\"\n    if attn_tp_size == tensor_model_parallel_size:\n        _ATTN_TP = _TP\n    else:\n        group_ranks = []\n        for 
tp_group_idx in range(num_tensor_model_parallel_groups):\n            for cp_dp_combined_idx in range(attn_cp_size * attn_dp_size):\n                st = (\n                    tp_group_idx * tensor_model_parallel_size\n                    + cp_dp_combined_idx * attn_tp_size\n                )\n                en = (\n                    tp_group_idx * tensor_model_parallel_size\n                    + (cp_dp_combined_idx + 1) * attn_tp_size\n                )\n                ranks = list(range(st, en))\n                group_ranks.append(ranks)\n        _ATTN_TP = init_model_parallel_group(\n            group_ranks,\n            get_world_group().local_rank,\n            backend,\n            use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,\n            use_mscclpp_allreduce=False,\n            use_custom_allreduce=False,\n            use_torch_symm_mem_allreduce=False,\n            group_name=\"attention_tp\",\n        )\n\n    moe_ep_size = expert_model_parallel_size\n    moe_dp_size = moe_data_model_parallel_size\n    moe_tp_size = tensor_model_parallel_size // moe_ep_size // moe_dp_size\n\n    global _MOE_DP\n    assert _MOE_DP is None, \"moe data parallel group is already initialized\"\n    # gpus_per_pp_stage = tensor_model_parallel_size * attention_context_model_parallel_size\n    if moe_dp_size == tensor_model_parallel_size:\n        _MOE_DP = _TP\n    else:\n        group_ranks = []\n        for tp_group_idx in range(num_tensor_model_parallel_groups):\n            for tp_ep_combined_idx in range(moe_tp_size * moe_ep_size):\n                st = tp_group_idx * tensor_model_parallel_size + tp_ep_combined_idx\n                en = (\n                    tp_group_idx + 1\n                ) * tensor_model_parallel_size + tp_ep_combined_idx\n                ranks = list(range(st, en, moe_tp_size * moe_ep_size))\n                group_ranks.append(ranks)\n        _MOE_DP = init_model_parallel_group(\n            group_ranks,\n            get_world_group().local_rank,\n      
      backend,\n            group_name=\"moe_dp\",\n        )\n\n    global _MOE_EP\n    assert _MOE_EP is None, \"expert model parallel group is already initialized\"\n    if moe_ep_size == tensor_model_parallel_size:\n        _MOE_EP = _TP\n    else:\n        # TODO(ch-wan): use split_group to save memory\n        group_ranks = []\n        for tp_group_idx in range(num_tensor_model_parallel_groups):\n            for moe_dp_idx in range(moe_dp_size):\n                for moe_tp_idx in range(moe_tp_size):\n                    st = (\n                        tp_group_idx * tensor_model_parallel_size\n                        + moe_dp_idx * moe_ep_size * moe_tp_size\n                        + moe_tp_idx\n                    )\n                    en = st + moe_ep_size * moe_tp_size\n                    ranks = list(range(st, en, moe_tp_size))\n                    group_ranks.append(ranks)\n        _MOE_EP = init_model_parallel_group(\n            group_ranks,\n            get_world_group().local_rank,\n            backend,\n            group_name=\"moe_ep\",\n        )\n\n    global _MOE_TP\n    assert _MOE_TP is None, \"expert model parallel group is already initialized\"\n    if moe_tp_size == tensor_model_parallel_size:\n        _MOE_TP = _TP\n    else:\n        # TODO(ch-wan): use split_group to save memory\n        group_ranks = []\n        for tp_group_idx in range(num_tensor_model_parallel_groups):\n            for ep_dp_combined_idx in range(moe_ep_size * moe_dp_size):\n                st = (\n                    tp_group_idx * tensor_model_parallel_size\n                    + ep_dp_combined_idx * moe_tp_size\n                )\n                en = (\n                    tp_group_idx * tensor_model_parallel_size\n                    + (ep_dp_combined_idx + 1) * moe_tp_size\n                )\n                ranks = list(range(st, en))\n                group_ranks.append(ranks)\n        _MOE_TP = init_model_parallel_group(\n            group_ranks,\n          
  get_world_group().local_rank,\n            backend,\n            group_name=\"moe_tp\",\n        )\n\n    # Build the pipeline model-parallel groups.\n    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n    global _PP\n    assert _PP is None, \"pipeline model parallel group is already initialized\"\n    group_ranks = []\n    for pp_group_idx in range(num_pipeline_model_parallel_groups):\n        ranks = list(\n            range(pp_group_idx, world_size, num_pipeline_model_parallel_groups)\n        )\n        group_ranks.append(ranks)\n    # pipeline parallel does not need custom allreduce\n    _PP = init_model_parallel_group(\n        group_ranks,\n        get_world_group().local_rank,\n        backend,\n        use_custom_allreduce=False,\n        group_name=\"pp\",\n    )\n\n\ndef create_custom_parallel_group(\n    group_ranks: List[int], backend: str = \"gloo\"\n) -> Optional[torch.distributed.ProcessGroup]:\n    \"\"\"\n    Create a custom parallel group based on the provided ranks.\n\n    Args:\n        group_ranks: The list of ranks that the CURRENT process wants to join.\n                     (e.g., Rank 0 passes [0...7], Rank 8 passes [8...15])\n        backend: The communication backend (default: \"gloo\").\n\n    Returns:\n        The ProcessGroup if the current rank is in group_ranks, else None.\n    \"\"\"\n    assert torch.distributed.is_initialized()\n\n    world_size = torch.distributed.get_world_size()\n    rank = torch.distributed.get_rank()\n\n    local_config = sorted(list(set(group_ranks)))\n    gathered_configs = [None for _ in range(world_size)]\n\n    torch.distributed.all_gather_object(gathered_configs, local_config)\n\n    unique_groups = []\n    seen_signatures = set()\n\n    for config in gathered_configs:\n        config_tuple = tuple(config)\n        if config_tuple not in seen_signatures:\n            seen_signatures.add(config_tuple)\n            unique_groups.append(list(config_tuple))\n\n    
unique_groups.sort(key=lambda x: x[0])\n\n    my_new_group = None\n\n    for g_ranks in unique_groups:\n        group = torch.distributed.new_group(ranks=g_ranks, backend=backend)\n\n        if set(g_ranks) == set(local_config):\n            my_new_group = group\n            logger.debug(\n                f\"Rank {rank} successfully created/joined custom group: {g_ranks}\"\n            )\n\n    return my_new_group\n\n\ndef ensure_model_parallel_initialized(\n    tensor_model_parallel_size: int,\n    expert_model_parallel_size: int,\n    pipeline_model_parallel_size: int,\n    backend: Optional[str] = None,\n) -> None:\n    \"\"\"Helper to initialize model parallel groups if they are not initialized,\n    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected\n    values if the model parallel groups are initialized.\n    \"\"\"\n    backend = backend or torch.distributed.get_backend(get_world_group().device_group)\n    if not model_parallel_is_initialized():\n        initialize_model_parallel(\n            tensor_model_parallel_size,\n            expert_model_parallel_size,\n            pipeline_model_parallel_size,\n            backend,\n        )\n        return\n\n    assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (\n        \"tensor parallel group already initialized, but of unexpected size: \"\n        f\"{get_tensor_model_parallel_world_size()=} vs. \"\n        f\"{tensor_model_parallel_size=}\"\n    )\n    pp_world_size = get_pp_group().world_size\n    assert pp_world_size == pipeline_model_parallel_size, (\n        \"pipeline parallel group already initialized, but of unexpected size: \"\n        f\"{pp_world_size=} vs. 
\"\n        f\"{pipeline_model_parallel_size=}\"\n    )\n\n\ndef model_parallel_is_initialized():\n    \"\"\"Check if tensor and pipeline parallel groups are initialized.\"\"\"\n    return _TP is not None and _PP is not None\n\n\n_TP_STATE_PATCHED = False\n\n\n@contextmanager\ndef patch_tensor_parallel_group(tp_group: GroupCoordinator):\n    \"\"\"Patch the tp group temporarily until this function ends.\n\n    This method is for draft workers of speculative decoding to run draft model\n    with different tp degree from that of target model workers.\n\n    Args:\n        tp_group (GroupCoordinator): the tp group coordinator\n    \"\"\"\n    global _TP_STATE_PATCHED\n    assert not _TP_STATE_PATCHED, \"Should not call when it's already patched\"\n\n    _TP_STATE_PATCHED = True\n    old_tp_group = get_tp_group()\n    global _TP\n    _TP = tp_group\n    try:\n        yield\n    finally:\n        # restore the original state\n        _TP_STATE_PATCHED = False\n        _TP = old_tp_group\n\n\ndef get_world_size():\n    \"\"\"Return world size for the world group.\"\"\"\n    return get_world_group().world_size\n\n\ndef get_world_rank():\n    \"\"\"Return my rank for the world group.\"\"\"\n    return get_world_group().rank_in_group\n\n\ndef get_tensor_model_parallel_world_size():\n    \"\"\"Return world size for the tensor model parallel group.\"\"\"\n    return get_tp_group().world_size\n\n\ndef get_tensor_model_parallel_rank():\n    \"\"\"Return my rank for the tensor model parallel group.\"\"\"\n    return get_tp_group().rank_in_group\n\n\n# ATTN_TP\ndef get_attn_tensor_model_parallel_world_size():\n    \"\"\"Return world size for the attention tensor model parallel group.\"\"\"\n    return get_attn_tp_group().world_size\n\n\ndef get_attn_tensor_model_parallel_rank():\n    \"\"\"Return my rank for the attention tensor model parallel group.\"\"\"\n    return get_attn_tp_group().rank_in_group\n\n\n# ATTN_CP\ndef get_attn_context_model_parallel_world_size():\n    
\"\"\"Return world size for the attention context model parallel group.\"\"\"\n    return get_attn_cp_group().world_size\n\n\ndef get_attn_context_model_parallel_rank():\n    \"\"\"Return my rank for the attention context model parallel group.\"\"\"\n    return get_attn_cp_group().rank_in_group\n\n\ndef get_pipeline_model_parallel_world_size():\n    \"\"\"Return world size for the pipeline model parallel group.\"\"\"\n    return get_pp_group().world_size\n\n\ndef get_pipeline_model_parallel_rank():\n    \"\"\"Return my rank for the pipeline model parallel group.\"\"\"\n    return get_pp_group().rank_in_group\n\n\n# MOE_DP\ndef get_moe_data_parallel_world_size():\n    \"\"\"Return world size for the moe data parallel group.\"\"\"\n    return get_moe_dp_group().world_size\n\n\ndef get_moe_data_parallel_rank():\n    \"\"\"Return my rank for the moe data parallel group.\"\"\"\n    return get_moe_dp_group().rank_in_group\n\n\n# MOE_EP\ndef get_moe_expert_parallel_world_size():\n    \"\"\"Return world size for the moe expert parallel group.\"\"\"\n    return get_moe_ep_group().world_size\n\n\ndef get_moe_expert_parallel_rank():\n    \"\"\"Return my rank for the moe expert parallel group.\"\"\"\n    return get_moe_ep_group().rank_in_group\n\n\n# MOE_TP\ndef get_moe_tensor_parallel_world_size():\n    \"\"\"Return world size for the moe tensor parallel group.\"\"\"\n    return get_moe_tp_group().world_size\n\n\ndef get_moe_tensor_parallel_rank():\n    \"\"\"Return my rank for the moe tensor parallel group.\"\"\"\n    return get_moe_tp_group().rank_in_group\n\n\ndef destroy_model_parallel():\n    \"\"\"Set the groups to none and destroy them.\"\"\"\n    global _TP\n    if _TP:\n        _TP.destroy()\n    _TP = None\n\n    global _PP\n    if _PP:\n        _PP.destroy()\n    _PP = None\n\n    global _MOE_EP\n    if _MOE_EP:\n        _MOE_EP.destroy()\n    _MOE_EP = None\n\n    global _MOE_TP\n    if _MOE_TP:\n        _MOE_TP.destroy()\n    _MOE_TP = None\n\n    global 
_ATTN_CP\n    if _ATTN_CP:\n        _ATTN_CP.destroy()\n    _ATTN_CP = None\n\n    global _ATTN_TP\n    if _ATTN_TP:\n        _ATTN_TP.destroy()\n    _ATTN_TP = None\n\n    global _MOE_DP\n    if _MOE_DP:\n        _MOE_DP.destroy()\n    _MOE_DP = None\n\n    global _PDMUX_PREFILL_TP_GROUP\n    if _PDMUX_PREFILL_TP_GROUP:  # type: ignore[union-attr]\n        _PDMUX_PREFILL_TP_GROUP.destroy()\n    _PDMUX_PREFILL_TP_GROUP = None\n\n\ndef destroy_distributed_environment():\n    global _WORLD\n    if _WORLD:\n        _WORLD.destroy()\n    _WORLD = None\n    if torch.distributed.is_initialized():\n        torch.distributed.destroy_process_group()\n\n\ndef cleanup_dist_env_and_memory(shutdown_ray: bool = False):\n    destroy_model_parallel()\n    destroy_distributed_environment()\n    with contextlib.suppress(AssertionError):\n        torch.distributed.destroy_process_group()\n    if shutdown_ray:\n        import ray  # Lazy import Ray\n\n        ray.shutdown()\n    gc.collect()\n    if not _is_cpu:\n        if hasattr(torch, \"cuda\") and torch.cuda.is_available():\n            torch.cuda.empty_cache()\n            if hasattr(torch._C, \"_host_emptyCache\"):\n                torch._C._host_emptyCache()\n            else:\n                logger.warning(\n                    \"torch._C._host_emptyCache() only available in Pytorch >=2.5\"\n                )\n        elif hasattr(torch, \"xpu\") and torch.xpu.is_available():\n            torch.xpu.empty_cache()\n        elif hasattr(torch, \"npu\") and torch.npu.is_available():\n            torch.npu.empty_cache()\n        elif hasattr(torch, \"musa\") and torch.musa.is_available():\n            torch.musa.empty_cache()\n\n\ndef in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:\n    \"\"\"\n    This is a collective operation that returns if each rank is in the same node\n    as the source rank. 
It tests if processes are attached to the same\n    memory system (shared access to shared memory).\n    \"\"\"\n    assert (\n        torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL\n    ), \"in_the_same_node_as should be tested with a non-NCCL group.\"\n    # local rank inside the group\n    rank = torch.distributed.get_rank(group=pg)\n    world_size = torch.distributed.get_world_size(group=pg)\n\n    # local tensor in each process to store the result\n    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)\n\n    # global ranks of the processes in the group\n    ranks = torch.distributed.get_process_group_ranks(pg)\n\n    magic_message = b\"magic_message\"\n    shm = None\n\n    try:\n        with contextlib.suppress(OSError):\n            if rank == source_rank:\n                # create a shared memory segment\n                shm = shared_memory.SharedMemory(create=True, size=128)\n                shm.buf[: len(magic_message)] = magic_message\n                torch.distributed.broadcast_object_list(\n                    [shm.name], src=ranks[source_rank], group=pg\n                )\n                is_in_the_same_node[rank] = 1\n            else:\n                # try to open the shared memory segment\n                recv = [None]\n                torch.distributed.broadcast_object_list(\n                    recv, src=ranks[source_rank], group=pg\n                )\n                name = recv[0]\n                # fix to https://stackoverflow.com/q/62748654/9191338\n                # Python incorrectly tracks shared memory even if it is not\n                # created by the process. 
The following patch is a workaround.\n                with patch(\n                    \"multiprocessing.resource_tracker.register\",\n                    lambda *args, **kwargs: None,\n                ):\n                    shm = shared_memory.SharedMemory(name=name)\n                if shm.buf[: len(magic_message)] == magic_message:\n                    is_in_the_same_node[rank] = 1\n    except Exception as e:\n        logger.error(\"Error ignored in is_in_the_same_node: %s\", e)\n    finally:\n        if shm:\n            shm.close()\n\n    torch.distributed.barrier(group=pg)\n\n    # clean up the shared memory segment\n    with contextlib.suppress(OSError):\n        if rank == source_rank and shm:\n            shm.unlink()\n    torch.distributed.all_reduce(is_in_the_same_node, group=pg)\n\n    return [x == 1 for x in is_in_the_same_node.tolist()]\n\n\nvllm_get_pp_group = None\nvllm_get_tp_group = None\nvllm_get_world_group = None\n\n\ndef monkey_patch_vllm_parallel_state(reverse: bool = False):\n    try:\n        import vllm.distributed.parallel_state as vllm_parallel_state\n    except ImportError:\n        return\n\n    global vllm_get_pp_group, vllm_get_tp_group, vllm_get_world_group\n    if vllm_get_pp_group is None:\n        vllm_get_pp_group = vllm_parallel_state.get_pp_group\n        vllm_get_tp_group = vllm_parallel_state.get_tp_group\n        vllm_get_world_group = vllm_parallel_state.get_world_group\n    if reverse:\n        setattr(vllm_parallel_state, \"get_pp_group\", vllm_get_pp_group)\n        setattr(vllm_parallel_state, \"get_tp_group\", vllm_get_tp_group)\n        setattr(vllm_parallel_state, \"get_world_group\", vllm_get_world_group)\n    else:\n        setattr(vllm_parallel_state, \"get_pp_group\", get_pp_group)\n        setattr(vllm_parallel_state, \"get_tp_group\", get_tp_group)\n        setattr(vllm_parallel_state, \"get_world_group\", get_world_group)\n"
  },
  {
    "path": "python/sglang/srt/distributed/utils.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/utils.py\n\n# Copyright 2023 The vLLM team.\n# Adapted from\n# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\nimport dataclasses\nimport logging\nimport os\nimport pickle\nimport time\nfrom collections import deque\nfrom typing import Any, Deque, Dict, Optional, Sequence, Tuple\n\nimport torch\nfrom torch.distributed import TCPStore\n\nlogger = logging.getLogger(__name__)\n\n# Global TCPStore that is created during distributed initialization\n# This is the single shared store that all components should use\n_global_tcp_store: Optional[TCPStore] = None\n\n\ndef set_global_tcp_store(store: TCPStore) -> None:\n    \"\"\"Set the global TCPStore instance.\n\n    This should be called during distributed initialization to make\n    the store available to all components that need it.\n    \"\"\"\n    global _global_tcp_store\n    _global_tcp_store = store\n    logger.info(\"Global TCPStore has been set\")\n\n\ndef get_global_tcp_store() -> Optional[TCPStore]:\n    \"\"\"Get the existing global TCPStore.\n\n    This function provides access to the shared TCPStore instance that was\n    created during distributed initialization. All components (like NIXL buffers)\n    should use this same store for coordination.\n\n    Returns:\n        The global TCPStore instance, or None if not initialized yet.\n    \"\"\"\n    global _global_tcp_store\n\n    if _global_tcp_store is None:\n        logger.warning(\n            \"Global TCPStore not found. 
Make sure init_distributed_environment \"\n            \"was called with a tcp:// init method.\"\n        )\n\n    return _global_tcp_store\n\n\ndef ensure_divisibility(numerator, denominator):\n    \"\"\"Ensure that numerator is divisible by the denominator.\"\"\"\n    assert numerator % denominator == 0, \"{} is not divisible by {}\".format(\n        numerator, denominator\n    )\n\n\ndef divide(numerator, denominator):\n    \"\"\"Ensure that numerator is divisible by the denominator and return\n    the division value.\"\"\"\n    ensure_divisibility(numerator, denominator)\n    return numerator // denominator\n\n\ndef split_tensor_along_last_dim(\n    tensor: torch.Tensor,\n    num_partitions: int,\n    contiguous_split_chunks: bool = False,\n) -> Sequence[torch.Tensor]:\n    \"\"\"Split a tensor along its last dimension.\n\n    Arguments:\n        tensor: input tensor.\n        num_partitions: number of partitions to split the tensor\n        contiguous_split_chunks: If True, make each chunk contiguous\n                                 in memory.\n\n    Returns:\n        A list of Tensors\n    \"\"\"\n    # Get the size and dimension.\n    last_dim = tensor.dim() - 1\n    last_dim_size = divide(tensor.size()[last_dim], num_partitions)\n    # Split.\n    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)\n    # NOTE: torch.split does not create contiguous tensors by default.\n    if contiguous_split_chunks:\n        return tuple(chunk.contiguous() for chunk in tensor_list)\n\n    return tensor_list\n\n\ndef get_pp_indices(\n    num_hidden_layers: int, pp_rank: int, pp_size: int\n) -> Tuple[int, int]:\n    \"\"\"Try to evenly distribute layers across partitions.\n    If the number of layers is not divisible by the number of partitions,\n    the last N partitions will have one extra layer, where N = remainder.\n    \"\"\"\n    # partition_list_str can be set to None in sglang\n    partition_list_str = os.getenv(\"SGLANG_PP_LAYER_PARTITION\", None)\n    
if partition_list_str is not None:\n        try:\n            partitions = [int(layer) for layer in partition_list_str.split(\",\")]\n        except ValueError as err:\n            raise ValueError(\n                \"Invalid partition string: {}\".format(partition_list_str)\n            ) from err\n        if len(partitions) != pp_size:\n            raise ValueError(f\"{len(partitions)=} does not match {pp_size=}.\")\n        if sum(partitions) != num_hidden_layers:\n            raise ValueError(f\"{sum(partitions)=} does not match {num_hidden_layers=}.\")\n        start_layer = sum(partitions[:pp_rank])\n        end_layer = start_layer + partitions[pp_rank]\n    else:\n        base_layers = num_hidden_layers // pp_size\n        remainder = num_hidden_layers % pp_size\n        # Distribute the extra layers to the last 'remainder' partitions\n        if pp_rank >= pp_size - remainder:\n            partitions_without_extra_layer = pp_size - remainder\n            # This partition gets one extra layer\n            start_layer = pp_rank * (base_layers + 1) - partitions_without_extra_layer\n            end_layer = start_layer + (base_layers + 1)\n        else:\n            # This partition gets only base layers\n            start_layer = pp_rank * base_layers\n            end_layer = start_layer + base_layers\n\n    return (start_layer, end_layer)\n\n\n@dataclasses.dataclass\nclass StatelessProcessGroup:\n    \"\"\"A dataclass to hold a metadata store, and the rank, world_size of the\n    group. 
Only use it to communicate metadata between processes.\n    For data-plane communication, create NCCL-related objects.\n    \"\"\"\n\n    rank: int\n    world_size: int\n    store: torch._C._distributed_c10d.Store\n    data_expiration_seconds: int = 3600  # 1 hour\n\n    # dst rank -> counter\n    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)\n    # src rank -> counter\n    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)\n    broadcast_send_counter: int = 0\n    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)\n\n    # A deque to store the data entries, with key and timestamp.\n    entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque)\n\n    def __post_init__(self):\n        assert self.rank < self.world_size\n        self.send_dst_counter = {i: 0 for i in range(self.world_size)}\n        self.recv_src_counter = {i: 0 for i in range(self.world_size)}\n        self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}\n\n    def send_obj(self, obj: Any, dst: int):\n        \"\"\"Send an object to a destination rank.\"\"\"\n        self.expire_data()\n        key = f\"send_to/{dst}/{self.send_dst_counter[dst]}\"\n        self.store.set(key, pickle.dumps(obj))\n        self.send_dst_counter[dst] += 1\n        self.entries.append((key, time.perf_counter()))\n\n    def expire_data(self):\n        \"\"\"Expire data that is older than `data_expiration_seconds` seconds.\"\"\"\n        while self.entries:\n            # check the oldest entry\n            key, timestamp = self.entries[0]\n            if time.perf_counter() - timestamp > self.data_expiration_seconds:\n                self.store.delete_key(key)\n                self.entries.popleft()\n            else:\n                break\n\n    def recv_obj(self, src: int) -> Any:\n        \"\"\"Receive an object from a source rank.\"\"\"\n        obj = pickle.loads(\n            
self.store.get(f\"send_to/{self.rank}/{self.recv_src_counter[src]}\")\n        )\n        self.recv_src_counter[src] += 1\n        return obj\n\n    def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:\n        \"\"\"Broadcast an object from a source rank to all other ranks.\n        It does not clean up after all ranks have received the object.\n        Use it for limited times, e.g., for initialization.\n        \"\"\"\n        if self.rank == src:\n            self.expire_data()\n            key = f\"broadcast_from/{src}/\" f\"{self.broadcast_send_counter}\"\n            self.store.set(key, pickle.dumps(obj))\n            self.broadcast_send_counter += 1\n            self.entries.append((key, time.perf_counter()))\n            return obj\n        else:\n            key = f\"broadcast_from/{src}/\" f\"{self.broadcast_recv_src_counter[src]}\"\n            recv_obj = pickle.loads(self.store.get(key))\n            self.broadcast_recv_src_counter[src] += 1\n            return recv_obj\n\n    def all_gather_obj(self, obj: Any) -> list[Any]:\n        \"\"\"All gather an object from all ranks.\"\"\"\n        gathered_objs = []\n        for i in range(self.world_size):\n            if i == self.rank:\n                gathered_objs.append(obj)\n                self.broadcast_obj(obj, src=self.rank)\n            else:\n                recv_obj = self.broadcast_obj(None, src=i)\n                gathered_objs.append(recv_obj)\n        return gathered_objs\n\n    def barrier(self):\n        \"\"\"A barrier to synchronize all ranks.\"\"\"\n        for i in range(self.world_size):\n            if i == self.rank:\n                self.broadcast_obj(None, src=self.rank)\n            else:\n                self.broadcast_obj(None, src=i)\n\n    @staticmethod\n    def create(\n        host: str,\n        port: int,\n        rank: int,\n        world_size: int,\n        data_expiration_seconds: int = 3600,\n    ) -> \"StatelessProcessGroup\":\n        \"\"\"A replacement 
for `torch.distributed.init_process_group` that does not\n        pollute the global state.\n\n        If processes A and B have called `torch.distributed.init_process_group`\n        to form a group, PyTorch does not allow forming another group with\n        processes A, B, C, and D, because A and B already belong to a group\n        that C and D cannot join. This function is a workaround for this issue.\n\n        `torch.distributed.init_process_group` is a global call, while this function\n        is stateless. It returns a `StatelessProcessGroup` object that can be\n        used for exchanging metadata. With this function, processes A and B\n        can call `StatelessProcessGroup.create` to form a group, and then processes\n        A, B, C, and D can call `StatelessProcessGroup.create` to form another group.\n        \"\"\"  # noqa\n        store = TCPStore(\n            host_name=host,\n            port=port,\n            world_size=world_size,\n            is_master=(rank == 0),\n        )\n\n        return StatelessProcessGroup(\n            rank=rank,\n            world_size=world_size,\n            store=store,\n            data_expiration_seconds=data_expiration_seconds,\n        )\n"
  },
  {
    "path": "python/sglang/srt/dllm/algorithm/__init__.py",
    "content": "import importlib\nimport logging\nimport pkgutil\n\nfrom sglang.srt.dllm.config import DllmConfig\n\nlogger = logging.getLogger(__name__)\n\n\ndef import_algorithms():\n    mapping = {}\n    package_name = \"sglang.srt.dllm.algorithm\"\n    package = importlib.import_module(package_name)\n    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + \".\"):\n        if ispkg:\n            continue\n        try:\n            module = importlib.import_module(name)\n        except Exception as e:\n            logger.warning(f\"Ignore import error when loading {name}: {e}\")\n            continue\n        if not hasattr(module, \"Algorithm\"):\n            continue\n\n        algo = module.Algorithm\n        mapping[algo.__name__] = algo\n\n    return mapping\n\n\ndef get_algorithm(config: DllmConfig):\n    name = config.algorithm\n    if name not in algo_name_to_cls:\n        raise RuntimeError(f\"Unknown diffusion LLM algorithm: {name}\")\n    # Instantiate outside of a try/except so that errors raised by the\n    # algorithm's constructor are not misreported as an unknown algorithm.\n    return algo_name_to_cls[name](config)\n\n\nalgo_name_to_cls = import_algorithms()\n"
  },
  {
    "path": "python/sglang/srt/dllm/algorithm/base.py",
    "content": "from sglang.srt.dllm.algorithm import get_algorithm\nfrom sglang.srt.dllm.config import DllmConfig\nfrom sglang.srt.server_args import ServerArgs\n\n\nclass DllmAlgorithm:\n\n    def __init__(\n        self,\n        config: DllmConfig,\n    ):\n        self.block_size = config.block_size\n        self.mask_id = config.mask_id\n\n    @staticmethod\n    def from_server_args(server_args: ServerArgs):\n        config = DllmConfig.from_server_args(server_args)\n        return get_algorithm(config)\n"
  },
  {
    "path": "python/sglang/srt/dllm/algorithm/joint_threshold.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom sglang.srt.dllm.algorithm.base import DllmAlgorithm\nfrom sglang.srt.dllm.config import DllmConfig\nfrom sglang.srt.layers.logits_processor import LogitsProcessorOutput\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\nfrom sglang.srt.model_executor.model_runner import ModelRunner\n\n\nclass JointThreshold(DllmAlgorithm):\n\n    def __init__(\n        self,\n        config: DllmConfig,\n    ):\n        super().__init__(config)\n        self.threshold = config.algorithm_config.get(\"threshold\", 0.5)\n        self.edit_threshold = config.algorithm_config.get(\"edit_threshold\", 0)\n        self.max_post_edit_steps = config.algorithm_config.get(\n            \"max_post_edit_steps\", 16\n        )\n        self.penalty_lambda = config.algorithm_config.get(\"penalty_lambda\", 0)\n\n    def run(\n        self,\n        model_runner: ModelRunner,\n        forward_batch: ForwardBatch,\n    ) -> tuple[LogitsProcessorOutput | torch.Tensor, list[torch.Tensor], bool]:\n        batch_size = forward_batch.batch_size\n        device = forward_batch.input_ids.device\n\n        mask_index = forward_batch.input_ids == self.mask_id\n        if not mask_index.any():\n            out = model_runner.forward(forward_batch, pp_proxy_tensors=None)\n            return out.logits_output, [], out.can_run_graph\n\n        start_list = []\n        prompt_masks = []\n        for i in range(batch_size):\n            block_start = i * self.block_size\n            block_end = block_start + self.block_size\n            block_input_ids = forward_batch.input_ids[block_start:block_end]\n\n            prompt_mask = block_input_ids != self.mask_id\n            prompt_masks.append(prompt_mask)\n            start_list.append(prompt_mask.sum().item())\n\n        post_edit_steps = torch.zeros(batch_size, dtype=torch.int32, device=device)\n\n        finished = torch.zeros(batch_size, 
dtype=torch.bool, device=device)\n        # Controls whether to perform an additional forward pass for KV cache persistence.\n        # For certain decoding rounds where the terminal step yields no state change,\n        # this can be set to False to bypass the overhead of an idle forward pass.\n        any_changed_in_last_step = False\n\n        max_iterations = self.block_size + self.max_post_edit_steps\n        for _ in range(max_iterations):\n            if finished.all():\n                break\n\n            out = model_runner.forward(forward_batch, pp_proxy_tensors=None)\n            logits_output, can_run_cuda_graph = out.logits_output, out.can_run_graph\n\n            any_changed_in_last_step = False\n\n            for i in range(batch_size):\n                if finished[i]:\n                    continue\n\n                block_start = i * self.block_size\n                block_end = block_start + self.block_size\n\n                curr_input_ids = forward_batch.input_ids[block_start:block_end]\n                curr_logits = logits_output.full_logits[block_start:block_end]\n                curr_prompt_mask = prompt_masks[i]\n\n                if self.penalty_lambda > 0:\n                    prev_ids = curr_input_ids[:-1]\n                    curr_logits[1:, :].scatter_(\n                        1, prev_ids.unsqueeze(-1), -self.penalty_lambda, reduce=\"add\"\n                    )\n\n                x = torch.argmax(curr_logits, dim=-1)\n                p = torch.squeeze(\n                    torch.gather(\n                        F.softmax(curr_logits, dim=-1),\n                        dim=-1,\n                        index=torch.unsqueeze(x, -1),\n                    ),\n                    -1,\n                )\n\n                mask_index = curr_input_ids == self.mask_id\n                has_mask = mask_index.any()\n\n                # Mask to token (M2T)\n                mask_transfer_index = torch.zeros_like(mask_index)\n                if 
has_mask:\n                    confidence = torch.where(mask_index, p, -np.inf)\n                    mask_transfer_index = confidence > self.threshold\n\n                    if not mask_transfer_index.any():\n                        _, select_index = torch.topk(confidence, k=1)\n                        mask_transfer_index[select_index] = True\n                else:\n                    post_edit_steps[i] += 1\n                    if post_edit_steps[i] > self.max_post_edit_steps:\n                        finished[i] = True\n                        continue\n\n                # Token to token (T2T)\n                edit_mask = ~mask_index & ~curr_prompt_mask\n                edit_transfer_index = (\n                    (p > self.edit_threshold) & (curr_input_ids != x) & edit_mask\n                )\n\n                transfer_index = mask_transfer_index | edit_transfer_index\n                if not transfer_index.any():\n                    finished[i] = True\n                    continue\n\n                curr_input_ids[transfer_index] = x[transfer_index]\n                any_changed_in_last_step = True\n\n        if any_changed_in_last_step:\n            out = model_runner.forward(forward_batch, pp_proxy_tensors=None)\n            logits_output, can_run_cuda_graph = out.logits_output, out.can_run_graph\n\n        next_token_ids = torch.reshape(forward_batch.input_ids, (batch_size, -1))\n        next_token_ids_list = [\n            next_token_ids[i, start_list[i] :] for i in range(batch_size)\n        ]\n\n        return logits_output, next_token_ids_list, can_run_cuda_graph\n\n\nAlgorithm = JointThreshold\n"
  },
  {
    "path": "python/sglang/srt/dllm/algorithm/low_confidence.py",
    "content": "from typing import List, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom sglang.srt.dllm.algorithm.base import DllmAlgorithm\nfrom sglang.srt.dllm.config import DllmConfig\nfrom sglang.srt.layers.logits_processor import LogitsProcessorOutput\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\nfrom sglang.srt.model_executor.model_runner import ModelRunner\n\n\nclass LowConfidence(DllmAlgorithm):\n\n    def __init__(\n        self,\n        config: DllmConfig,\n    ):\n        super().__init__(config)\n        self.threshold = config.algorithm_config.get(\"threshold\", 0.95)\n\n    def run(\n        self,\n        model_runner: ModelRunner,\n        forward_batch: ForwardBatch,\n    ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], List[torch.Tensor], bool]:\n        batch_size = forward_batch.batch_size\n        # Here, logits_output.full_logits covers all blocks in the batch,\n        # with shape [dllm_block_size * batch_size, vocab_size]\n        start_list = []\n        mask_index = forward_batch.input_ids == self.mask_id\n\n        # Fast path: if there is no mask token, forward and save kv cache\n        if torch.sum(mask_index).item() == 0:\n            out = model_runner.forward(forward_batch, pp_proxy_tensors=None)\n            logits_output, can_run_cuda_graph = out.logits_output, out.can_run_graph\n\n            next_token_ids = []\n            return logits_output, next_token_ids, can_run_cuda_graph\n\n        # Calculate start positions for each block\n        for block_id in range(batch_size):\n            block_start = block_id * self.block_size\n            block_end = block_start + self.block_size\n            block_input_ids = forward_batch.input_ids[block_start:block_end]\n            block_mask_index = block_input_ids == self.mask_id\n            start = self.block_size - torch.sum(block_mask_index).item()\n            start_list.append(start)\n\n        for _ in 
range(self.block_size):\n            mask_index = forward_batch.input_ids == self.mask_id\n            if torch.sum(mask_index).item() == 0:\n                break\n\n            out = model_runner.forward(forward_batch, pp_proxy_tensors=None)\n            logits_output, can_run_cuda_graph = out.logits_output, out.can_run_graph\n            assert batch_size == forward_batch.input_ids.shape[0] // self.block_size\n            for batch_id in range(batch_size):\n                curr_block_start = batch_id * self.block_size\n                curr_block_end = curr_block_start + self.block_size\n                block_input_ids = forward_batch.input_ids[\n                    curr_block_start:curr_block_end,\n                ]\n                block_mask_index = block_input_ids == self.mask_id\n                if torch.sum(block_mask_index).item() == 0:\n                    continue\n                curr_logits = logits_output.full_logits[\n                    curr_block_start:curr_block_end,\n                ]\n\n                x = torch.argmax(curr_logits, dim=-1)\n                p = torch.squeeze(\n                    torch.gather(\n                        F.softmax(curr_logits, dim=-1),\n                        dim=-1,\n                        index=torch.unsqueeze(x, -1),\n                    ),\n                    -1,\n                )\n                x = torch.where(block_mask_index, x, block_input_ids)\n                confidence = torch.where(block_mask_index, p, -np.inf)\n\n                transfer_index = confidence > self.threshold\n\n                if transfer_index.sum().item() == 0:\n                    _, select_index = torch.topk(confidence, k=1)\n                    transfer_index[select_index] = True\n\n                block_input_ids[transfer_index] = x[transfer_index]\n\n        out = model_runner.forward(forward_batch, pp_proxy_tensors=None)\n        logits_output, can_run_cuda_graph = out.logits_output, out.can_run_graph\n        # Here next 
token ids vary in length across requests, so they cannot be\n        # packed into one tensor; we return a list of tensors instead\n        next_token_ids = torch.reshape(forward_batch.input_ids, (batch_size, -1))\n        next_token_ids_list = [\n            next_token_ids[i, start_list[i] :] for i in range(batch_size)\n        ]\n\n        return logits_output, next_token_ids_list, can_run_cuda_graph\n\n\nAlgorithm = LowConfidence\n"
  },
  {
    "path": "python/sglang/srt/dllm/config.py",
    "content": "from typing import Any\n\nfrom sglang.srt.configs.model_config import ModelConfig\nfrom sglang.srt.server_args import ServerArgs\n\n\nclass DllmConfig:\n    def __init__(\n        self,\n        algorithm: str,\n        algorithm_config: dict[str, Any],\n        block_size: int,\n        mask_id: int,\n        max_running_requests: int,\n    ):\n        self.algorithm = algorithm\n        self.algorithm_config = algorithm_config\n        self.block_size = block_size\n        self.mask_id = mask_id\n        self.max_running_requests = max_running_requests\n\n    @staticmethod\n    def from_server_args(\n        server_args: ServerArgs,\n    ):\n        if server_args.dllm_algorithm is None:\n            return None\n\n        model_config = ModelConfig.from_server_args(\n            server_args,\n            model_path=server_args.model_path,\n            model_revision=server_args.revision,\n        )\n        DLLM_PARAMS = {\n            \"LLaDA2MoeModelLM\": {\"block_size\": 32, \"mask_id\": 156895},\n            \"SDARForCausalLM\": {\"block_size\": 4, \"mask_id\": 151669},\n            \"SDARMoeForCausalLM\": {\"block_size\": 4, \"mask_id\": 151669},\n        }\n\n        arch = model_config.hf_config.architectures[0]\n        if arch in DLLM_PARAMS:\n            params = DLLM_PARAMS[arch]\n            block_size = params[\"block_size\"]\n            mask_id = params[\"mask_id\"]\n        else:\n            raise RuntimeError(f\"Unknown diffusion LLM: {arch}\")\n\n        max_running_requests = (\n            1\n            if server_args.max_running_requests is None\n            else server_args.max_running_requests\n        )\n\n        algorithm_config = {}\n        if server_args.dllm_algorithm_config is not None:\n            try:\n                import yaml\n            except ImportError:\n                raise ImportError(\n                    \"Please install PyYAML to use YAML config files. 
\"\n                    \"`pip install pyyaml`\"\n                )\n            with open(server_args.dllm_algorithm_config, \"r\") as f:\n                # An empty YAML file loads as None; fall back to an empty dict.\n                algorithm_config = yaml.safe_load(f) or {}\n\n            # Parse common algorithm configurations\n            block_size = algorithm_config.get(\"block_size\", block_size)\n\n        return DllmConfig(\n            algorithm=server_args.dllm_algorithm,\n            algorithm_config=algorithm_config,\n            block_size=block_size,\n            mask_id=mask_id,\n            max_running_requests=max_running_requests,\n        )\n"
  },
  {
    "path": "python/sglang/srt/dllm/mixin/req.py",
    "content": "from __future__ import annotations\n\nimport enum\nfrom typing import TYPE_CHECKING, Optional\n\nfrom sglang.srt.dllm.config import DllmConfig\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.schedule_batch import Req\n\n\nclass DllmReqPhase(str, enum.Enum):\n    STAGING_PREFILL = \"staging_prefill\"\n    STAGING_DECODE = \"staging_decode\"\n    INCOMING_PREFILL = \"incoming_prefill\"\n    INCOMING_DECODE = \"incoming_decode\"\n\n\nclass ReqDllmMixin:\n    def init_diffusion_llm(self: Req, dllm_config: DllmConfig):\n        self.dllm_phase: Optional[DllmReqPhase] = None\n        self.dllm_block_offset = 0\n        self.dllm_config = dllm_config\n\n        if self.dllm_config is not None:\n            if len(self.origin_input_ids) < self.dllm_config.block_size:\n                self.dllm_phase = DllmReqPhase.INCOMING_DECODE\n            else:\n                self.dllm_phase = DllmReqPhase.INCOMING_PREFILL\n\n    def is_dllm(self: Req) -> bool:\n        return self.dllm_config is not None\n\n    def is_dllm_prefill(self: Req) -> bool:\n        return self.dllm_phase in [\n            DllmReqPhase.STAGING_PREFILL,\n            DllmReqPhase.INCOMING_PREFILL,\n        ]\n\n    def determine_dllm_phase(self: Req):\n        prefix_length = len(self.prefix_indices)\n        min_required_length = prefix_length + self.dllm_config.block_size\n\n        if len(self.fill_ids) < min_required_length:\n            # still incoming stage\n            return\n\n        input_block = self.fill_ids[prefix_length:min_required_length]\n        is_prefill_phase = self.dllm_config.mask_id not in input_block\n\n        if is_prefill_phase:\n            self.dllm_phase = DllmReqPhase.STAGING_PREFILL\n        else:\n            self.dllm_phase = DllmReqPhase.STAGING_DECODE\n\n    def _init_fill_ids_for_dllm(self: Req):\n        self.dllm_block_offset = (\n            0\n            if not self.fill_ids\n            else self.dllm_block_offset + 
self.dllm_config.block_size\n        )\n        self.fill_ids = (\n            self.origin_input_ids\n            + self.output_ids\n            + [self.dllm_config.mask_id] * self.dllm_config.block_size\n        )\n\n    def _update_block_offset_for_dllm(self):\n        prefix_len = len(self.prefix_indices)\n        assert (\n            prefix_len % self.dllm_config.block_size == 0\n        ), f\"Unexpected prefix len: {prefix_len}\"\n        if prefix_len > self.dllm_block_offset:\n            self.dllm_block_offset = prefix_len\n"
  },
  {
    "path": "python/sglang/srt/dllm/mixin/scheduler.py",
    "content": "from __future__ import annotations\n\nimport logging\nfrom typing import TYPE_CHECKING, List, Optional, Set, Union\n\nfrom sglang.srt.dllm.config import DllmConfig\nfrom sglang.srt.dllm.mixin.req import DllmReqPhase\nfrom sglang.srt.managers.schedule_batch import Req, ScheduleBatch\nfrom sglang.srt.managers.schedule_policy import AddReqResult, PrefillAdder\nfrom sglang.srt.mem_cache.common import release_kv_cache\nfrom sglang.srt.model_executor.forward_batch_info import ForwardMode\nfrom sglang.srt.observability.req_time_stats import set_time_batch\n\nlogger = logging.getLogger(__name__)\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler\n\n\nclass SchedulerDllmMixin:\n    def init_diffusion_llm(self: Scheduler):\n        self.dllm_config = (\n            DllmConfig.from_server_args(self.server_args)\n            if self.server_args.dllm_algorithm is not None\n            else None\n        )\n        self.dllm_manager = DllmManager(dllm_config=self.dllm_config)\n\n    def get_new_batch_dllm(self: Scheduler) -> Optional[ScheduleBatch]:\n        \"\"\"Generate a new batch for DLLM (Diffusion LLM) scheduling.\"\"\"\n        if self.enable_priority_preemption:\n            self.running_batch.batch_is_full = False\n\n        # Early exit if batch is full or no requests available\n        if self._should_skip_prefill():\n            return None\n\n        running_bs = len(self.running_batch.reqs)\n        self.policy.calc_priority(self.waiting_queue)\n\n        # Create prefill adder with resource constraints\n        adder = self._create_dllm_prefill_adder(running_bs)\n\n        # Initialize DLLM manager and transfer requests\n        self.dllm_manager.init_next_round()\n        self._fetch_waiting_reqs()\n\n        # Process batches\n        forward_mode = self._process_dllm_batches(adder)\n\n        can_run_list = adder.can_run_list\n        if not can_run_list:\n            return None\n\n        # 
Record metrics and update state\n        set_time_batch(can_run_list, \"set_forward_entry_time\")\n        self._update_state_for_batch(can_run_list, adder, running_bs)\n\n        # Create and prepare batch\n        new_batch = self._create_dllm_batch(can_run_list, forward_mode)\n        return new_batch\n\n    def process_batch_result_dllm(\n        self: Scheduler,\n        batch: ScheduleBatch,\n        result: GenerationBatchResult,\n    ):\n        if result.copy_done is not None:\n            result.copy_done.synchronize()\n\n        if result.next_token_ids:\n            self.token_to_kv_pool_allocator.free_group_begin()\n\n            for idx in range(batch.batch_size()):\n                req = batch.reqs[idx]\n\n                next_token_ids = result.next_token_ids[idx].tolist()\n                new_tokens = len(next_token_ids)\n                if new_tokens == 0:\n                    continue\n\n                req.fill_ids[-new_tokens:] = next_token_ids[:]\n                self.num_generated_tokens += new_tokens\n\n                req.output_ids.extend(next_token_ids)\n                req.check_finished(new_accepted_len=new_tokens)\n\n                if req.finished():\n                    release_kv_cache(req, self.tree_cache)\n                    req.time_stats.set_completion_time()\n\n            self.stream_output(batch.reqs, batch.return_logprob)\n            self.token_to_kv_pool_allocator.free_group_end()\n\n        can_run_cuda_graph = getattr(result, \"can_run_cuda_graph\", False)\n        self.report_prefill_stats(\n            prefill_stats=batch.prefill_stats,\n            can_run_cuda_graph=can_run_cuda_graph,\n            dp_cooperation_info=batch.dp_cooperation_info,\n        )\n\n    def _fetch_waiting_reqs(self: Scheduler):\n        # Calculate how many requests can be added to DLLM manager\n        max_dllm_capacity = self.dllm_config.max_running_requests - len(\n            self.dllm_manager.waiting_queue\n        )\n        
num_requests_to_add = min(max_dllm_capacity, len(self.waiting_queue))\n\n        if num_requests_to_add > 0:\n            requests_to_add = self.waiting_queue[:num_requests_to_add]\n            self.dllm_manager.add_waiting_reqs(requests_to_add)\n            self.waiting_queue = self.waiting_queue[num_requests_to_add:]\n\n    def _should_skip_prefill(self: Scheduler) -> bool:\n        \"\"\"Check if DLLM prefill should be skipped.\"\"\"\n        if (\n            self.running_batch.batch_is_full or not self.waiting_queue\n        ) and self.dllm_manager.is_empty():\n            return True\n\n        running_bs = len(self.running_batch.reqs)\n        if (\n            self.get_num_allocatable_reqs(running_bs) <= 0\n            and self.dllm_manager.is_empty()\n            and not self.enable_priority_preemption\n        ):\n            self.running_batch.batch_is_full = True\n            return True\n\n        return False\n\n    def _create_dllm_prefill_adder(self: Scheduler, running_bs: int) -> PrefillAdder:\n        \"\"\"Create a prefill adder configured for DLLM scheduling.\"\"\"\n        return PrefillAdder(\n            self.page_size,\n            self.tree_cache,\n            self.token_to_kv_pool_allocator,\n            self.running_batch,\n            self.new_token_ratio,\n            self.max_prefill_tokens,\n            self.chunked_prefill_size,\n            running_bs if self.is_mixed_chunk else 0,\n            self.priority_scheduling_preemption_threshold,\n            prefill_max_requests=self.server_args.prefill_max_requests,\n            dllm_config=self.dllm_config,\n        )\n\n    def _process_dllm_batches(self: Scheduler, adder: PrefillAdder) -> ForwardMode:\n        \"\"\"Process prefill or decode batches for DLLM.\"\"\"\n        forward_mode = ForwardMode.DLLM_EXTEND\n\n        # Try prefill batch first\n        prefill_reqs = self.dllm_manager.get_prefill_requests()\n        if prefill_reqs:\n            self._process_batch_by_phase(\n   
             adder,\n                prefill_reqs,\n                DllmReqPhase.STAGING_PREFILL,\n                DllmReqPhase.INCOMING_PREFILL,\n            )\n        else:\n            # Fall back to decode batch\n            decode_reqs = self.dllm_manager.get_decode_requests()\n            self._process_batch_by_phase(\n                adder,\n                decode_reqs,\n                DllmReqPhase.STAGING_DECODE,\n                DllmReqPhase.INCOMING_DECODE,\n            )\n\n        return forward_mode\n\n    def _process_batch_by_phase(\n        self,\n        adder: PrefillAdder,\n        batch: List[Req],\n        staging_phase: DllmReqPhase,\n        incoming_phase: DllmReqPhase,\n    ) -> None:\n        \"\"\"Process a batch, separating staging and incoming requests.\"\"\"\n        staging_reqs = [req for req in batch if req.dllm_phase == staging_phase]\n        if staging_reqs:\n            staging_result = self.process_dllm_staging_reqs(adder, staging_reqs)\n            if staging_result != AddReqResult.CONTINUE:\n                return\n\n        incoming_reqs = [req for req in batch if req.dllm_phase == incoming_phase]\n        if incoming_reqs:\n            self.process_dllm_incoming_reqs(adder, incoming_reqs)\n\n    def _update_state_for_batch(\n        self: Scheduler, can_run_list: List[Req], adder: PrefillAdder, running_bs: int\n    ) -> None:\n        \"\"\"Update state for the batch.\"\"\"\n\n        if adder.preempt_list:\n            for req in adder.preempt_list:\n                self._add_request_to_queue(req)\n\n        if can_run_list:\n            self.dllm_manager.add_staging_reqs(can_run_list)\n            self.dllm_manager.increment_chunked_count()\n\n        self.adder = adder\n        self.can_run_list = can_run_list\n        self.running_bs = len(self.running_batch.reqs)\n\n    def _create_dllm_batch(\n        self: Scheduler, can_run_list: List[Req], forward_mode: ForwardMode\n    ) -> ScheduleBatch:\n        \"\"\"Create 
and prepare a new DLLM batch.\"\"\"\n        new_batch = ScheduleBatch.init_new(\n            can_run_list,\n            self.req_to_token_pool,\n            self.token_to_kv_pool_allocator,\n            self.tree_cache,\n            self.model_config,\n            self.enable_overlap,\n            self.spec_algorithm,\n            dllm_config=self.dllm_config,\n        )\n        new_batch.prepare_for_extend()\n        new_batch.forward_mode = forward_mode\n        new_batch.decoding_reqs = None\n\n        # Record prefill stats for logging after forward\n        from sglang.srt.observability.scheduler_metrics_mixin import PrefillStats\n\n        new_batch.prefill_stats = PrefillStats.from_adder(\n            self.adder, self.running_batch.reqs, self.enable_priority_scheduling\n        )\n\n        return new_batch\n\n    def process_dllm_incoming_reqs(\n        self: Scheduler, adder: PrefillAdder, reqs: List[Req]\n    ) -> AddReqResult:\n        \"\"\"Process incoming DLLM requests with resource allocation and preemption.\"\"\"\n        res = AddReqResult.CONTINUE\n        for req in reqs:\n            # Check if batch is full\n            running_bs = len(self.running_batch.reqs)\n            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):\n                self.running_batch.batch_is_full = True\n\n            # Try preemption if batch is full\n            if self.running_batch.batch_is_full:\n                if (\n                    not self.enable_priority_preemption\n                    or not adder.preempt_to_schedule(req, self.server_args)\n                ):\n                    break\n\n            # Prepare and add request\n            req.init_next_round_input(self.tree_cache)\n            res = adder.add_one_req(\n                req,\n                has_chunked_req=True,\n                truncation_align_size=self.truncation_align_size,\n            )\n\n            if res != AddReqResult.CONTINUE:\n                if res 
== AddReqResult.NO_TOKEN:\n                    self.running_batch.batch_is_full = True\n                break\n\n        return res\n\n    def process_dllm_staging_reqs(\n        self: Scheduler, adder: PrefillAdder, reqs: List[Req]\n    ) -> AddReqResult:\n        \"\"\"Process staging DLLM requests with resource allocation.\"\"\"\n        for req in reqs:\n            res = adder.add_dllm_staging_req(req)\n            if res == AddReqResult.NO_TOKEN:\n                return res\n\n        return AddReqResult.CONTINUE\n\n\nclass DllmManager:\n    \"\"\"\n    Manager for Diffusion LLM request scheduling.\n\n    Maintains two queues:\n    - waiting_queue: The requests waiting to be scheduled with max running requests limit\n    - staging_queue: Requests allocated resources by PrefillAdder\n    \"\"\"\n\n    def __init__(self, dllm_config: Optional[DllmConfig] = None):\n        self.dllm_config = dllm_config\n        self.max_running_reqs = (\n            dllm_config.max_running_requests if dllm_config is not None else 1\n        )\n        self.waiting_queue: List[Req] = []\n        self.staging_queue: List[Req] = []\n\n    def get_prefill_requests(self) -> List[Req]:\n        \"\"\"Get all prefill requests from waiting queue.\"\"\"\n        return [req for req in self.waiting_queue if req.is_dllm_prefill()]\n\n    def get_decode_requests(self) -> List[Req]:\n        \"\"\"Get all decode requests from waiting queue.\"\"\"\n        return [req for req in self.waiting_queue if not req.is_dllm_prefill()]\n\n    def add_waiting_reqs(self, reqs: Union[Req, List[Req]]) -> None:\n        \"\"\"Add requests to waiting queue with redundancy check.\"\"\"\n        assert self.dllm_config is not None, \"Diffusion LLM config is not set.\"\n\n        reqs_to_add = reqs if isinstance(reqs, list) else [reqs]\n\n        # Check for duplicate request IDs\n        if self._has_duplicate_reqs(reqs_to_add):\n            raise RuntimeError(\"Redundant requests detected in dLLM 
requests.\")\n\n        self.waiting_queue.extend(reqs_to_add)\n\n    def add_staging_reqs(self, reqs: Union[Req, List[Req]]) -> None:\n        \"\"\"Add requests to staging queue (allocated by PrefillAdder).\"\"\"\n        reqs_to_add = reqs if isinstance(reqs, list) else [reqs]\n        self.staging_queue.extend(reqs_to_add)\n\n    def _has_duplicate_reqs(self, reqs: List[Req]) -> bool:\n        \"\"\"Check if any request ID already exists in waiting queue.\"\"\"\n        existing_rids: Set[str] = {r.rid for r in self.waiting_queue}\n        return any(req.rid in existing_rids for req in reqs)\n\n    def any_staging_reqs(self) -> bool:\n        \"\"\"Check if there are requests in staging queue.\"\"\"\n        return self.dllm_config is not None and len(self.staging_queue) > 0\n\n    def is_empty(self) -> bool:\n        \"\"\"Check if both queues are empty or DLLM is not configured.\"\"\"\n        if self.dllm_config is None:\n            return True\n        return len(self.waiting_queue) == 0\n\n    def increment_chunked_count(self) -> None:\n        \"\"\"Increment chunked count for all staging requests.\"\"\"\n        for req in self.staging_queue:\n            req.is_chunked += 1\n\n    def filter_finished_reqs(self) -> None:\n        \"\"\"Remove finished requests from both queues.\"\"\"\n        self.waiting_queue = [req for req in self.waiting_queue if not req.finished()]\n        self.staging_queue = [req for req in self.staging_queue if not req.finished()]\n\n    def init_next_round(self) -> None:\n        \"\"\"Initialize staging requests for next round and clear staging queue.\"\"\"\n        for req in self.staging_queue:\n            req.init_next_round_input()\n        self.staging_queue = []\n"
  },
  {
    "path": "python/sglang/srt/elastic_ep/elastic_ep.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\n\nfrom sglang.srt.managers.schedule_batch import ServerArgs\nfrom sglang.srt.utils import is_cpu, is_cuda\n\n\n@dataclass\nclass ElasticEPState:\n    active_ranks: Optional[torch.Tensor]\n    last_active_ranks: Optional[torch.Tensor]\n    active_ranks_cpu: Optional[torch.Tensor]\n\n    def is_active_equal_last(self) -> bool:\n        return torch.equal(self.active_ranks, self.last_active_ranks)\n\n    def sync_active_to_cpu(self):\n        if self.active_ranks is not None:\n            self.active_ranks_cpu = self.active_ranks.detach().cpu().clone()\n\n    def snapshot_active_to_last(self):\n        if self.active_ranks is not None:\n            self.last_active_ranks = self.active_ranks.clone()\n\n\nclass ElasticEPStateManager:\n    _instance: Optional[ElasticEPState] = None\n\n    @classmethod\n    def instance(cls) -> ElasticEPState:\n        return cls._instance\n\n    @classmethod\n    def init(cls, server_args: ServerArgs):\n        if cls._instance is not None:\n            return cls._instance\n\n        if server_args.elastic_ep_backend is not None:\n            cls._instance = cls._build_state(ep_size=None, device=None)\n        return cls._instance\n\n    @staticmethod\n    def _select_device() -> torch.device:\n        if is_cuda():\n            return torch.device(\"cuda\")\n        elif is_cpu():\n            return torch.device(\"cpu\")\n        else:\n            raise NotImplementedError(\"Only CUDA and CPU support elastic ep now.\")\n\n    @classmethod\n    def _build_state(\n        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None\n    ) -> ElasticEPState:\n        active = cls.healthy_rank_state(ep_size=ep_size, device=device)\n        return ElasticEPState(\n            active_ranks=active,\n            last_active_ranks=active.clone(),\n            
active_ranks_cpu=active.detach().cpu().clone(),\n        )\n\n    @classmethod\n    def healthy_rank_state(\n        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None\n    ) -> torch.Tensor:\n        size = ep_size if ep_size is not None else torch.distributed.get_world_size()\n        dev = device if device is not None else cls._select_device()\n\n        return torch.ones(size, dtype=torch.int32, device=dev)\n"
  },
  {
    "path": "python/sglang/srt/elastic_ep/expert_backup_client.py",
    "content": "import logging\nimport re\nimport threading\nimport time\n\nimport torch\nimport zmq\n\nfrom sglang.srt.distributed.parallel_state import (\n    get_world_group,\n    get_world_size,\n)\nfrom sglang.srt.environ import envs\nfrom sglang.srt.eplb.expert_location import get_global_expert_location_metadata\nfrom sglang.srt.managers.io_struct import UpdateExpertBackupReq\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils.network import get_local_ip_auto\n\nPORT_BASE = envs.SGLANG_BACKUP_PORT_BASE.get()\nlogger = logging.getLogger(__name__)\n\n\ndef extract_layer_and_expert_id(param_name):\n    pattern = r\"layers\\.(\\d+)\\.mlp\\.experts\\.(\\d+)\\.(.+?)\\.\"\n    match = re.search(pattern, param_name)\n    if match:\n        return int(match.group(1)), int(match.group(2)), match.group(3)\n    return -1, -1, \"\"\n\n\nclass ExpertBackupClient:\n    def __init__(self, server_args: ServerArgs, model_runner):\n        context = zmq.Context(2)\n        self.server_args = server_args\n        self.engine_num = server_args.nnodes\n        self.engine_rank = server_args.node_rank\n        self.recv_list = [None] * self.engine_num\n        self.ready_sockets = [None] * self.engine_num\n        self.model_runner = model_runner\n        self.moe_ep_size = model_runner.moe_ep_size\n        self.model_config = model_runner.model_config\n        self.moe_ep_rank = model_runner.moe_ep_rank\n        self.dram_map_list = [None] * self.engine_num\n        self.session_id_list = [None] * self.engine_num\n        self.transfer_engine = None\n        self.gpu_buffer = None\n        self.buffer_size = 0\n        self.use_backup = False\n        local_ip = get_local_ip_auto()\n        all_ips = [None] * get_world_size()\n        torch.distributed.all_gather_object(\n            all_ips, local_ip, group=get_world_group().cpu_group\n        )\n        logger.info(f\"all_ips: {all_ips}\")\n\n        for i in range(self.engine_num):\n            
self.recv_list[i] = context.socket(zmq.SUB)\n            self.recv_list[i].connect(\n                f\"tcp://{all_ips[i * get_world_size() // server_args.nnodes]}:{PORT_BASE + i * 2 + 1}\"\n            )\n            self.recv_list[i].setsockopt(zmq.SUBSCRIBE, b\"\")\n\n            # Synchronization channel to notify the manager when this client is ready.\n            self.ready_sockets[i] = context.socket(zmq.PUSH)\n            self.ready_sockets[i].connect(\n                f\"tcp://{all_ips[i * get_world_size() // server_args.nnodes]}:{PORT_BASE + i * 2}\"\n            )\n            self.ready_sockets[i].send_pyobj(UpdateExpertBackupReq())\n\n        self._receive_thread = threading.Thread(target=self._receive_loop, daemon=True)\n        self._receive_thread.start()\n\n    def _receive_loop(self):\n        cnt = 0\n        while cnt < self.engine_num:\n            response = self.recv_list[cnt].recv_pyobj()\n            self.dram_map_list[response.rank] = response.weight_pointer_map\n            self.session_id_list[response.rank] = response.session_id\n            self.buffer_size = max(self.buffer_size, response.buffer_size)\n            cnt += 1\n\n        self.use_backup = True\n        self.start_transfer_client()\n\n    def start_transfer_client(self):\n        from sglang.srt.distributed.parallel_state import get_mooncake_transfer_engine\n\n        self.transfer_engine = get_mooncake_transfer_engine()\n\n        self.params_dict = dict(self.model_runner.model.named_parameters())\n        for name, param in self.params_dict.items():\n            param_data = param.data\n            ret_value = self.transfer_engine.engine.register_memory(\n                param_data.data_ptr(), param_data.numel() * param_data.element_size()\n            )\n            if ret_value != 0:\n                self.use_backup = False\n                logger.warning(\"Register fails. 
Stop using expert weight backup!\")\n                break\n\n    def update_weights(self, weight_name_filter=None):\n        global_expert_location_metadata = get_global_expert_location_metadata()\n        num_experts = (\n            self.model_config.hf_config.n_routed_experts\n            + self.server_args.ep_num_redundant_experts\n        )\n        num_local_experts = num_experts // self.moe_ep_size\n        for i in range(self.engine_num):\n            server_ptr_list = []\n            local_ptr_list = []\n            weight_size_list = []\n\n            for name, weight_info in self.dram_map_list[i].items():\n                if weight_name_filter is not None and not weight_name_filter(name):\n                    continue\n                layer_id, expert_id, weight_name = extract_layer_and_expert_id(name)\n                if layer_id >= self.model_config.hf_config.num_hidden_layers:\n                    continue\n\n                if weight_name == \"gate_proj\":\n                    shard_id = \"w1\"\n                    param_name = \"experts.w13_\"\n                elif weight_name == \"down_proj\":\n                    shard_id = \"w2\"\n                    param_name = \"experts.w2_\"\n                elif weight_name == \"up_proj\":\n                    shard_id = \"w3\"\n                    param_name = \"experts.w13_\"\n                else:\n                    raise RuntimeError(f\"Unknown weight name {weight_name}\")\n\n                name = name.replace(f\"experts.{expert_id}.{weight_name}.\", param_name)\n                weight_param = self.params_dict[name]\n\n                physical_expert_ids = (\n                    global_expert_location_metadata.logical_to_all_physical(\n                        layer_id, expert_id\n                    )\n                )\n                for physical_expert_id in physical_expert_ids:\n                    if physical_expert_id not in range(\n                        num_local_experts * 
self.moe_ep_rank,\n                        num_local_experts * (self.moe_ep_rank + 1),\n                    ):\n                        continue\n                    param = weight_param[physical_expert_id % num_local_experts]\n                    if shard_id == \"w1\":\n                        param = param.narrow(0, 0, param.shape[0] // 2)\n                    elif shard_id == \"w3\":\n                        param = param.narrow(\n                            0, param.shape[0] // 2, param.shape[0] // 2\n                        )\n                    server_ptr_list.append(weight_info[\"weight_ptr\"])\n                    local_ptr_list.append(param.data_ptr())\n                    assert (\n                        param.numel() * param.element_size() == weight_info[\"byte_size\"]\n                    )\n                    weight_size_list.append(weight_info[\"byte_size\"])\n            before_transfer = time.time()\n            ret = self.transfer_engine.engine.batch_transfer_sync_read(\n                self.session_id_list[i],\n                local_ptr_list,\n                server_ptr_list,\n                weight_size_list,\n            )\n            after_transfer = time.time()\n            logger.info(f\"transfer time = {after_transfer - before_transfer} s\")\n\n            if ret != 0:\n                raise RuntimeError(\n                    f\"Failed to read weights from backup, error code: {ret}\"\n                )\n        return\n"
  },
  {
    "path": "python/sglang/srt/elastic_ep/expert_backup_manager.py",
    "content": "import logging\nimport multiprocessing as mp\nimport re\nimport signal\n\nimport torch\nimport zmq\n\nfrom sglang.srt.configs.load_config import LoadConfig\nfrom sglang.srt.configs.model_config import ModelConfig\nfrom sglang.srt.environ import envs\nfrom sglang.srt.managers.io_struct import BackupDramReq\nfrom sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader\nfrom sglang.srt.model_loader.utils import set_default_torch_dtype\nfrom sglang.srt.server_args import (\n    PortArgs,\n    ServerArgs,\n    set_global_server_args_for_scheduler,\n)\nfrom sglang.srt.utils.network import get_local_ip_auto\n\nPORT_BASE = envs.SGLANG_BACKUP_PORT_BASE.get()\nlogger = logging.getLogger(__name__)\n\n\ndef extract_expert_id(param_name):\n    pattern = r\"\\.experts\\.(\\d+)\\.\"\n    match = re.search(pattern, param_name)\n    if match:\n        return int(match.group(1))\n    return -1\n\n\nclass ExpertBackupManager:\n    def __init__(self, server_args: ServerArgs, port_args: PortArgs):\n        self.load_format = server_args.load_format\n        self.model_config = ModelConfig.from_server_args(server_args)\n        self.continuous_buffer = None\n        self.weight_pointer_map = {}\n        self.transfer_engine = None\n        self.session_id = None\n        self.engine_num = server_args.nnodes\n        self.engine_rank = server_args.node_rank\n        self.expert_num = self.model_config.hf_config.n_routed_experts\n        self.idmn = (self.expert_num // self.engine_num) * self.engine_rank\n        self.idmx = (self.expert_num // self.engine_num) * (self.engine_rank + 1)\n        context = zmq.Context(2)\n        # Synchronization socket to avoid PUB/SUB slow joiner issues.\n        self.recv_from_expert_backup_client = context.socket(zmq.PULL)\n        self.recv_from_expert_backup_client.bind(\n            f\"tcp://{get_local_ip_auto()}:{PORT_BASE + server_args.node_rank * 2}\"\n        )\n        self.send_to_expert_backup_client = 
context.socket(zmq.PUB)\n        self.send_to_expert_backup_client.bind(\n            f\"tcp://{get_local_ip_auto()}:{PORT_BASE + server_args.node_rank * 2 + 1}\"\n        )\n        self.backup_weights_from_disk()\n        self.start_transfer_server()\n\n        # Block until all expert backup clients have reported readiness, to avoid\n        # losing the initial PUB message due to slow joiners.\n        num_ready_clients = 0\n\n        while num_ready_clients < server_args.tp_size:\n            self.recv_from_expert_backup_client.recv_pyobj()\n            num_ready_clients += 1\n\n        back_req = BackupDramReq(\n            rank=self.engine_rank,\n            weight_pointer_map=self.weight_pointer_map,\n            session_id=self.session_id,\n            buffer_size=self.continuous_buffer.numel()\n            * self.continuous_buffer.element_size(),\n        )\n        self.send_to_expert_backup_client.send_pyobj(back_req)\n\n        # Keep the manager subprocess alive until signals\n        signal.pause()\n\n    def backup_weights_from_disk(self):\n        load_config = LoadConfig(load_format=self.load_format)\n        loader = get_model_loader(load_config, self.model_config)\n\n        with set_default_torch_dtype(self.model_config.dtype):\n            iter = loader._get_weights_iterator(\n                DefaultModelLoader.Source.init_new(self.model_config, None)\n            )\n\n            total_bytes = 0\n            weight_info_dict = {}\n\n            for name, weight in iter:\n                expert_id = extract_expert_id(name)\n                if expert_id < self.idmx and expert_id >= self.idmn:\n                    numel = weight.numel()\n                    element_size = weight.element_size()\n                    byte_size = numel * element_size\n                    weight_info_dict[name] = {\n                        \"name\": name,\n                        \"weight\": weight,\n                        \"numel\": numel,\n                        
\"shape\": weight.shape,\n                        \"dtype\": weight.dtype,\n                        \"element_size\": element_size,\n                        \"byte_size\": byte_size,\n                    }\n                    total_bytes += byte_size\n\n            if total_bytes == 0:\n                self.continuous_buffer = None\n                self.weight_pointer_map = {}\n                return\n\n            self.continuous_buffer = torch.empty(\n                total_bytes, dtype=torch.uint8, device=\"cpu\"\n            )\n            buffer_base_ptr = self.continuous_buffer.data_ptr()\n            self.weight_pointer_map = {}\n            current_byte_offset = 0\n\n            for name in sorted(weight_info_dict.keys()):\n                weight_info = weight_info_dict[name]\n                weight = weight_info[\"weight\"]\n                byte_size = weight_info[\"byte_size\"]\n                weight_flat = weight.flatten().contiguous()\n                weight_bytes = weight_flat.view(torch.uint8)\n                start_byte = current_byte_offset\n                end_byte = current_byte_offset + byte_size\n                weight_ptr = buffer_base_ptr + current_byte_offset\n                self.continuous_buffer[start_byte:end_byte].copy_(weight_bytes)\n                self.weight_pointer_map[name] = {\n                    \"name\": name,\n                    \"weight_ptr\": weight_ptr,\n                    \"shape\": weight_info[\"shape\"],\n                    \"numel\": weight_info[\"numel\"],\n                    \"dtype\": weight_info[\"dtype\"],\n                    \"element_size\": weight_info[\"element_size\"],\n                    \"byte_size\": byte_size,\n                }\n\n                current_byte_offset = end_byte\n\n    def start_transfer_server(self):\n        from sglang.srt.distributed.parallel_state import get_mooncake_transfer_engine\n\n        self.transfer_engine = get_mooncake_transfer_engine()\n        self.session_id = 
self.transfer_engine.session_id\n        server_ptr = self.continuous_buffer.data_ptr()\n        server_len = (\n            self.continuous_buffer.numel() * self.continuous_buffer.element_size()\n        )\n\n        ret_value = self.transfer_engine.engine.register_memory(server_ptr, server_len)\n        if ret_value != 0:\n            raise RuntimeError(\"Mooncake memory registration failed.\")\n\n\ndef run_expert_backup_manager_process(\n    server_args: ServerArgs,\n    port_args: PortArgs,\n):\n    set_global_server_args_for_scheduler(server_args)\n    from sglang.srt.distributed.device_communicators.mooncake_transfer_engine import (\n        init_mooncake_transfer_engine,\n    )\n\n    init_mooncake_transfer_engine(\n        hostname=get_local_ip_auto(),\n        gpu_id=0,\n        ib_device=(\n            server_args.disaggregation_ib_device or server_args.mooncake_ib_device\n        ),\n    )\n    manager = ExpertBackupManager(server_args, port_args)\n\n\ndef run_expert_backup_manager(\n    server_args: ServerArgs,\n    port_args: PortArgs,\n):\n    proc = mp.Process(\n        target=run_expert_backup_manager_process,\n        args=(server_args, port_args),\n    )\n    proc.start()\n    return proc\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/EngineBase.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import Dict, Iterator, List, Optional, Tuple, Union\n\nimport torch\n\n\nclass EngineBase(ABC):\n    \"\"\"\n    Abstract base class for engine interfaces that support generation, weight updating, and memory control.\n    This base class provides a unified API for both HTTP-based engines and engines.\n    \"\"\"\n\n    @abstractmethod\n    def generate(\n        self,\n        prompt: Optional[Union[List[str], str]] = None,\n        sampling_params: Optional[Union[List[Dict], Dict]] = None,\n        input_ids: Optional[Union[List[List[int]], List[int]]] = None,\n        image_data: Optional[Union[List[str], str]] = None,\n        return_logprob: Optional[Union[List[bool], bool]] = False,\n        logprob_start_len: Optional[Union[List[int], int]] = None,\n        top_logprobs_num: Optional[Union[List[int], int]] = None,\n        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,\n        lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,\n        custom_logit_processor: Optional[Union[List[str], str]] = None,\n        return_hidden_states: Optional[bool] = None,\n        stream: Optional[bool] = None,\n        bootstrap_host: Optional[Union[List[str], str]] = None,\n        bootstrap_port: Optional[Union[List[int], int]] = None,\n        bootstrap_room: Optional[Union[List[int], int]] = None,\n        routed_dp_rank: Optional[int] = None,\n        disagg_prefill_dp_rank: Optional[int] = None,\n        data_parallel_rank: Optional[int] = None,\n        rid: Optional[Union[List[str], str]] = None,\n        priority: Optional[int] = None,\n    ) -> Union[Dict, Iterator[Dict]]:\n        \"\"\"Generate outputs based on given inputs.\"\"\"\n        pass\n\n    @abstractmethod\n    def flush_cache(self):\n        \"\"\"Flush the cache of the engine.\"\"\"\n        pass\n\n    @abstractmethod\n    def update_weights_from_tensor(\n        self,\n        named_tensors: 
List[Tuple[str, torch.Tensor]],\n        load_format: Optional[str] = None,\n        flush_cache: bool = True,\n    ):\n        \"\"\"Update model weights with in-memory tensor data.\"\"\"\n        pass\n\n    def load_lora_adapter(self, lora_name: str, lora_path: str):\n        \"\"\"Load a new LoRA adapter without re-launching the engine.\"\"\"\n        pass\n\n    def unload_lora_adapter(self, lora_name: str):\n        \"\"\"Unload a LoRA adapter without re-launching the engine.\"\"\"\n        pass\n\n    @abstractmethod\n    def release_memory_occupation(self):\n        \"\"\"Release GPU memory occupation temporarily.\"\"\"\n        pass\n\n    @abstractmethod\n    def resume_memory_occupation(self):\n        \"\"\"Resume GPU memory occupation that was previously released.\"\"\"\n        pass\n\n    @abstractmethod\n    def shutdown(self):\n        \"\"\"Shut down the engine and clean up resources.\"\"\"\n        pass\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/anthropic/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/entrypoints/anthropic/protocol.py",
    "content": "\"\"\"Pydantic models for Anthropic Messages API protocol\"\"\"\n\nimport uuid\nfrom typing import Any, Literal, Optional\n\nfrom pydantic import BaseModel, Field, field_validator\n\n\nclass AnthropicError(BaseModel):\n    \"\"\"Error structure for Anthropic API\"\"\"\n\n    type: str\n    message: str\n\n\nclass AnthropicErrorResponse(BaseModel):\n    \"\"\"Error response structure for Anthropic API\"\"\"\n\n    type: Literal[\"error\"] = \"error\"\n    error: AnthropicError\n\n\nclass AnthropicUsage(BaseModel):\n    \"\"\"Token usage information\"\"\"\n\n    input_tokens: int\n    output_tokens: int\n    cache_creation_input_tokens: Optional[int] = None\n    cache_read_input_tokens: Optional[int] = None\n\n\nclass AnthropicContentBlock(BaseModel):\n    \"\"\"Content block in message\"\"\"\n\n    type: Literal[\n        \"text\", \"image\", \"tool_use\", \"tool_result\", \"thinking\", \"redacted_thinking\"\n    ]\n    text: Optional[str] = None\n    # For image content\n    source: Optional[dict[str, Any]] = None\n    # For tool use/result\n    id: Optional[str] = None\n    tool_use_id: Optional[str] = None\n    name: Optional[str] = None\n    input: Optional[dict[str, Any]] = None\n    content: Optional[str | list[dict[str, Any]]] = None\n    is_error: Optional[bool] = None\n    # For thinking content\n    thinking: Optional[str] = None\n    signature: Optional[str] = None\n\n\nclass AnthropicMessage(BaseModel):\n    \"\"\"Message structure\"\"\"\n\n    role: Literal[\"user\", \"assistant\"]\n    content: str | list[AnthropicContentBlock]\n\n\nclass AnthropicTool(BaseModel):\n    \"\"\"Tool definition\"\"\"\n\n    name: str\n    description: Optional[str] = None\n    input_schema: dict[str, Any]\n\n    @field_validator(\"input_schema\")\n    @classmethod\n    def validate_input_schema(cls, v):\n        if not isinstance(v, dict):\n            raise ValueError(\"input_schema must be a dictionary\")\n        if \"type\" not in v:\n            
v[\"type\"] = \"object\"\n        return v\n\n\nclass AnthropicToolChoice(BaseModel):\n    \"\"\"Tool Choice definition\"\"\"\n\n    type: Literal[\"auto\", \"any\", \"tool\", \"none\"]\n    name: Optional[str] = None\n\n\nclass AnthropicCountTokensRequest(BaseModel):\n    \"\"\"Anthropic Count Tokens API request\"\"\"\n\n    model: str\n    messages: list[AnthropicMessage]\n    system: Optional[str | list[AnthropicContentBlock]] = None\n    tool_choice: Optional[AnthropicToolChoice] = None\n    tools: Optional[list[AnthropicTool]] = None\n\n\nclass AnthropicCountTokensResponse(BaseModel):\n    \"\"\"Anthropic Count Tokens API response\"\"\"\n\n    input_tokens: int\n\n\nclass AnthropicMessagesRequest(BaseModel):\n    \"\"\"Anthropic Messages API request\"\"\"\n\n    model: str\n    messages: list[AnthropicMessage]\n    max_tokens: int\n    metadata: Optional[dict[str, Any]] = None\n    stop_sequences: Optional[list[str]] = None\n    stream: Optional[bool] = False\n    system: Optional[str | list[AnthropicContentBlock]] = None\n    temperature: Optional[float] = None\n    tool_choice: Optional[AnthropicToolChoice] = None\n    tools: Optional[list[AnthropicTool]] = None\n    top_k: Optional[int] = None\n    top_p: Optional[float] = None\n\n    @field_validator(\"model\")\n    @classmethod\n    def validate_model(cls, v):\n        if not v:\n            raise ValueError(\"Model is required\")\n        return v\n\n    @field_validator(\"max_tokens\")\n    @classmethod\n    def validate_max_tokens(cls, v):\n        if v <= 0:\n            raise ValueError(\"max_tokens must be positive\")\n        return v\n\n\nclass AnthropicDelta(BaseModel):\n    \"\"\"Delta for streaming responses\"\"\"\n\n    type: Optional[Literal[\"text_delta\", \"input_json_delta\"]] = None\n    text: Optional[str] = None\n    partial_json: Optional[str] = None\n\n    # Message delta fields\n    stop_reason: Optional[\n        Literal[\"end_turn\", \"max_tokens\", \"stop_sequence\", 
\"tool_use\"]\n    ] = None\n    stop_sequence: Optional[str] = None\n\n\nclass AnthropicStreamEvent(BaseModel):\n    \"\"\"Streaming event\"\"\"\n\n    type: Literal[\n        \"message_start\",\n        \"message_delta\",\n        \"message_stop\",\n        \"content_block_start\",\n        \"content_block_delta\",\n        \"content_block_stop\",\n        \"ping\",\n        \"error\",\n    ]\n    message: Optional[\"AnthropicMessagesResponse\"] = None\n    delta: Optional[AnthropicDelta] = None\n    content_block: Optional[AnthropicContentBlock] = None\n    index: Optional[int] = None\n    error: Optional[AnthropicError] = None\n    usage: Optional[AnthropicUsage] = None\n\n\nclass AnthropicMessagesResponse(BaseModel):\n    \"\"\"Anthropic Messages API response\"\"\"\n\n    id: str = Field(default_factory=lambda: f\"msg_{uuid.uuid4().hex}\")\n    type: Literal[\"message\"] = \"message\"\n    role: Literal[\"assistant\"] = \"assistant\"\n    content: list[AnthropicContentBlock]\n    model: str\n    stop_reason: Optional[\n        Literal[\"end_turn\", \"max_tokens\", \"stop_sequence\", \"tool_use\"]\n    ] = None\n    stop_sequence: Optional[str] = None\n    usage: Optional[AnthropicUsage] = None\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/anthropic/serving.py",
    "content": "\"\"\"Handler for Anthropic Messages API requests.\n\nConverts Anthropic requests to OpenAI ChatCompletion format, delegates to\nOpenAIServingChat for processing, and converts responses back to Anthropic format.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport logging\nimport time\nimport uuid\nfrom typing import TYPE_CHECKING, AsyncGenerator, Optional, Union\n\nfrom fastapi import Request\nfrom fastapi.responses import JSONResponse, StreamingResponse\n\nfrom sglang.srt.entrypoints.anthropic.protocol import (\n    AnthropicContentBlock,\n    AnthropicCountTokensRequest,\n    AnthropicCountTokensResponse,\n    AnthropicDelta,\n    AnthropicError,\n    AnthropicErrorResponse,\n    AnthropicMessagesRequest,\n    AnthropicMessagesResponse,\n    AnthropicStreamEvent,\n    AnthropicUsage,\n)\nfrom sglang.srt.entrypoints.openai.protocol import (\n    ChatCompletionRequest,\n    ChatCompletionResponse,\n    ChatCompletionStreamResponse,\n    StreamOptions,\n    Tool,\n    ToolChoice,\n    ToolChoiceFuncName,\n)\n\nif TYPE_CHECKING:\n    from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat\n\nlogger = logging.getLogger(__name__)\n\n# Map OpenAI finish reasons to Anthropic stop reasons\nSTOP_REASON_MAP = {\n    \"stop\": \"end_turn\",\n    \"length\": \"max_tokens\",\n    \"tool_calls\": \"tool_use\",\n}\n\n\ndef _wrap_sse_event(data: str, event_type: str) -> str:\n    \"\"\"Format an Anthropic SSE event with event type and data lines.\"\"\"\n    return f\"event: {event_type}\\ndata: {data}\\n\\n\"\n\n\nclass AnthropicServing:\n    \"\"\"Handler for Anthropic Messages API requests.\n\n    Acts as a translation layer between Anthropic's Messages API and SGLang's\n    OpenAI-compatible chat completion infrastructure.\n    \"\"\"\n\n    def __init__(self, openai_serving_chat: OpenAIServingChat):\n        self.openai_serving_chat = openai_serving_chat\n\n    async def handle_messages(\n        self,\n        request: 
AnthropicMessagesRequest,\n        raw_request: Request,\n    ) -> Union[JSONResponse, StreamingResponse]:\n        \"\"\"Main entry point for /v1/messages endpoint.\"\"\"\n        try:\n            chat_request = self._convert_to_chat_completion_request(request)\n        except Exception as e:\n            logger.exception(\"Error converting Anthropic request: %s\", e)\n            return self._error_response(\n                status_code=400,\n                error_type=\"invalid_request_error\",\n                message=str(e),\n            )\n\n        if request.stream:\n            return await self._handle_streaming(chat_request, request, raw_request)\n        else:\n            return await self._handle_non_streaming(chat_request, request, raw_request)\n\n    def _convert_to_chat_completion_request(\n        self, anthropic_request: AnthropicMessagesRequest\n    ) -> ChatCompletionRequest:\n        \"\"\"Convert an Anthropic Messages request to an OpenAI ChatCompletion request.\"\"\"\n        openai_messages = []\n\n        def _convert_anthropic_image_source_to_openai_part(\n            source: Optional[dict],\n        ) -> Optional[dict]:\n            if not isinstance(source, dict):\n                return None\n\n            source_type = source.get(\"type\")\n            if source_type == \"base64\":\n                media_type = source.get(\"media_type\", \"image/png\")\n                data = source.get(\"data\", \"\")\n                if not data:\n                    return None\n                return {\n                    \"type\": \"image_url\",\n                    \"image_url\": {\n                        \"url\": f\"data:{media_type};base64,{data}\",\n                    },\n                }\n\n            url = source.get(\"url\")\n            if url:\n                return {\n                    \"type\": \"image_url\",\n                    \"image_url\": {\n                        \"url\": url,\n                    },\n                
}\n\n            return None\n\n        def _convert_tool_result_content(\n            content: Optional[str | list[dict]],\n        ) -> tuple[str | list[dict], str]:\n            if isinstance(content, list):\n                tool_content_parts = []\n                tool_text_parts = []\n\n                for item in content:\n                    if not isinstance(item, dict):\n                        continue\n\n                    item_type = item.get(\"type\")\n                    if item_type == \"text\":\n                        text = item.get(\"text\", \"\")\n                        if text:\n                            tool_text_parts.append(text)\n                            tool_content_parts.append({\"type\": \"text\", \"text\": text})\n                    elif item_type == \"image\":\n                        image_part = _convert_anthropic_image_source_to_openai_part(\n                            item.get(\"source\")\n                        )\n                        if image_part is not None:\n                            tool_content_parts.append(image_part)\n\n                tool_text = \"\\n\".join(tool_text_parts)\n                if (\n                    len(tool_content_parts) == 1\n                    and tool_content_parts[0][\"type\"] == \"text\"\n                ):\n                    return tool_content_parts[0][\"text\"], tool_text\n                if tool_content_parts:\n                    return tool_content_parts, tool_text\n                return \"\", tool_text\n\n            tool_text = str(content) if content else \"\"\n            return tool_text, tool_text\n\n        # Add system message if provided\n        if anthropic_request.system:\n            if isinstance(anthropic_request.system, str):\n                openai_messages.append(\n                    {\"role\": \"system\", \"content\": anthropic_request.system}\n                )\n            else:\n                system_parts = []\n                for block in 
anthropic_request.system:\n                    if block.type == \"text\" and block.text:\n                        system_parts.append(block.text)\n                system_text = \"\\n\".join(system_parts)\n                openai_messages.append({\"role\": \"system\", \"content\": system_text})\n\n        # Convert messages\n        for msg in anthropic_request.messages:\n            if isinstance(msg.content, str):\n                openai_messages.append({\"role\": msg.role, \"content\": msg.content})\n                continue\n\n            # Complex content with blocks\n            openai_msg = {\"role\": msg.role}\n            content_parts = []\n            tool_calls = []\n\n            for block in msg.content:\n                if block.type == \"text\" and block.text:\n                    content_parts.append({\"type\": \"text\", \"text\": block.text})\n\n                elif block.type == \"image\" and block.source:\n                    image_part = _convert_anthropic_image_source_to_openai_part(\n                        block.source\n                    )\n                    if image_part is not None:\n                        content_parts.append(image_part)\n\n                elif block.type == \"tool_use\":\n                    tool_call = {\n                        \"id\": block.id or f\"call_{uuid.uuid4().hex}\",\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": block.name or \"\",\n                            \"arguments\": json.dumps(block.input or {}),\n                        },\n                    }\n                    tool_calls.append(tool_call)\n\n                elif block.type == \"tool_result\":\n                    tool_content, tool_text = _convert_tool_result_content(\n                        block.content\n                    )\n\n                    # Use tool_use_id (per spec) with fallback to id\n                    tool_call_id = block.tool_use_id or 
block.id or \"\"\n\n                    # Tool results from user become separate tool messages\n                    if msg.role == \"user\":\n                        openai_messages.append(\n                            {\n                                \"role\": \"tool\",\n                                \"tool_call_id\": tool_call_id,\n                                \"content\": tool_content,\n                            }\n                        )\n                    else:\n                        content_parts.append(\n                            {\n                                \"type\": \"text\",\n                                \"text\": f\"Tool result: {tool_text}\",\n                            }\n                        )\n\n            # Attach tool calls to assistant messages\n            if tool_calls:\n                openai_msg[\"tool_calls\"] = tool_calls\n\n            # Attach content\n            if content_parts:\n                if len(content_parts) == 1 and content_parts[0][\"type\"] == \"text\":\n                    openai_msg[\"content\"] = content_parts[0][\"text\"]\n                else:\n                    openai_msg[\"content\"] = content_parts\n            elif not tool_calls:\n                continue\n\n            openai_messages.append(openai_msg)\n\n        # Build ChatCompletionRequest\n        request_data = {\n            \"messages\": openai_messages,\n            \"model\": anthropic_request.model,\n            \"max_tokens\": anthropic_request.max_tokens,\n            \"stream\": anthropic_request.stream or False,\n        }\n\n        if anthropic_request.temperature is not None:\n            request_data[\"temperature\"] = anthropic_request.temperature\n        if anthropic_request.top_p is not None:\n            request_data[\"top_p\"] = anthropic_request.top_p\n        if anthropic_request.top_k is not None:\n            request_data[\"top_k\"] = anthropic_request.top_k\n        if anthropic_request.stop_sequences 
is not None:\n            request_data[\"stop\"] = anthropic_request.stop_sequences\n\n        # Enable usage in stream so we can report it\n        if anthropic_request.stream:\n            request_data[\"stream_options\"] = StreamOptions(include_usage=True)\n\n        chat_request = ChatCompletionRequest(**request_data)\n\n        # Convert tools\n        if anthropic_request.tools:\n            tools = []\n            for tool in anthropic_request.tools:\n                tools.append(\n                    Tool(\n                        type=\"function\",\n                        function={\n                            \"name\": tool.name,\n                            \"description\": tool.description or \"\",\n                            \"parameters\": tool.input_schema,\n                        },\n                    )\n                )\n            chat_request.tools = tools\n\n        # Convert tool choice\n        if anthropic_request.tool_choice is not None:\n            if anthropic_request.tool_choice.type == \"none\":\n                chat_request.tool_choice = \"none\"\n            elif anthropic_request.tool_choice.type == \"auto\":\n                chat_request.tool_choice = \"auto\"\n            elif anthropic_request.tool_choice.type == \"any\":\n                chat_request.tool_choice = \"required\"\n            elif anthropic_request.tool_choice.type == \"tool\":\n                chat_request.tool_choice = ToolChoice(\n                    type=\"function\",\n                    function=ToolChoiceFuncName(\n                        name=anthropic_request.tool_choice.name\n                    ),\n                )\n        elif anthropic_request.tools:\n            # Default to auto when tools are provided\n            chat_request.tool_choice = \"auto\"\n\n        return chat_request\n\n    async def _handle_non_streaming(\n        self,\n        chat_request: ChatCompletionRequest,\n        anthropic_request: AnthropicMessagesRequest,\n        
raw_request: Request,\n    ) -> JSONResponse:\n        \"\"\"Handle non-streaming Anthropic request by delegating to OpenAI handler.\"\"\"\n        received_time = time.time()\n        received_time_perf = time.perf_counter()\n\n        # Validate\n        error_msg = self.openai_serving_chat._validate_request(chat_request)\n        if error_msg:\n            return self._error_response(\n                status_code=400,\n                error_type=\"invalid_request_error\",\n                message=error_msg,\n            )\n\n        try:\n            # Convert to internal request\n            validation_time = time.perf_counter() - received_time_perf\n            adapted_request, processed_request = (\n                self.openai_serving_chat._convert_to_internal_request(\n                    chat_request, raw_request\n                )\n            )\n            adapted_request.validation_time = validation_time\n            adapted_request.received_time = received_time\n            adapted_request.received_time_perf = received_time_perf\n\n            # Get response from OpenAI handler\n            response = await self.openai_serving_chat._handle_non_streaming_request(\n                adapted_request, processed_request, raw_request\n            )\n        except Exception as e:\n            logger.exception(\"Error processing Anthropic request: %s\", e)\n            return self._error_response(\n                status_code=500,\n                error_type=\"internal_error\",\n                message=\"Internal server error\",\n            )\n\n        # Check for error responses from OpenAI handler\n        if not isinstance(response, ChatCompletionResponse):\n            # It's an error response (ORJSONResponse)\n            return self._error_response(\n                status_code=500,\n                error_type=\"internal_error\",\n                message=\"Internal processing error\",\n            )\n\n        # Convert to Anthropic response\n        
anthropic_response = self._convert_response(response)\n        return JSONResponse(content=anthropic_response.model_dump(exclude_none=True))\n\n    async def _handle_streaming(\n        self,\n        chat_request: ChatCompletionRequest,\n        anthropic_request: AnthropicMessagesRequest,\n        raw_request: Request,\n    ) -> Union[StreamingResponse, JSONResponse]:\n        \"\"\"Handle streaming Anthropic request.\"\"\"\n        received_time = time.time()\n        received_time_perf = time.perf_counter()\n\n        # Validate\n        error_msg = self.openai_serving_chat._validate_request(chat_request)\n        if error_msg:\n            return self._error_response(\n                status_code=400,\n                error_type=\"invalid_request_error\",\n                message=error_msg,\n            )\n\n        try:\n            validation_time = time.perf_counter() - received_time_perf\n            adapted_request, processed_request = (\n                self.openai_serving_chat._convert_to_internal_request(\n                    chat_request, raw_request\n                )\n            )\n            adapted_request.validation_time = validation_time\n            adapted_request.received_time = received_time\n            adapted_request.received_time_perf = received_time_perf\n        except Exception as e:\n            logger.exception(\"Error converting streaming request: %s\", e)\n            return self._error_response(\n                status_code=500,\n                error_type=\"internal_error\",\n                message=\"Internal server error\",\n            )\n\n        return StreamingResponse(\n            self._generate_anthropic_stream(\n                adapted_request,\n                processed_request,\n                anthropic_request,\n                raw_request,\n            ),\n            media_type=\"text/event-stream\",\n            background=self.openai_serving_chat.tokenizer_manager.create_abort_task(\n                
adapted_request\n            ),\n        )\n\n    async def _generate_anthropic_stream(\n        self,\n        adapted_request,\n        processed_request: ChatCompletionRequest,\n        anthropic_request: AnthropicMessagesRequest,\n        raw_request: Request,\n    ) -> AsyncGenerator[str, None]:\n        \"\"\"Convert OpenAI chat stream to Anthropic event stream.\"\"\"\n        openai_stream = self.openai_serving_chat._generate_chat_stream(\n            adapted_request, processed_request, raw_request\n        )\n\n        # State tracking\n        first_chunk = True\n        content_block_index = 0\n        content_block_open = False\n        finish_reason: Optional[str] = None\n        usage_info: Optional[dict] = None\n        message_id = f\"msg_{uuid.uuid4().hex}\"\n        model = anthropic_request.model\n\n        async for sse_line in openai_stream:\n            if not sse_line.startswith(\"data: \"):\n                continue\n\n            data_str = sse_line[6:].strip()\n\n            if data_str == \"[DONE]\":\n                # Close any open content block\n                if content_block_open:\n                    stop_event = AnthropicStreamEvent(\n                        type=\"content_block_stop\",\n                        index=content_block_index,\n                    )\n                    yield _wrap_sse_event(\n                        stop_event.model_dump_json(exclude_none=True),\n                        \"content_block_stop\",\n                    )\n\n                # Emit message_delta with stop_reason and usage\n                stop_reason = STOP_REASON_MAP.get(finish_reason or \"stop\", \"end_turn\")\n                delta_event = AnthropicStreamEvent(\n                    type=\"message_delta\",\n                    delta=AnthropicDelta(stop_reason=stop_reason),\n                    usage=AnthropicUsage(\n                        input_tokens=(\n                            usage_info.get(\"input_tokens\", 0) if usage_info else 0\n  
                      ),\n                        output_tokens=(\n                            usage_info.get(\"output_tokens\", 0) if usage_info else 0\n                        ),\n                    ),\n                )\n                yield _wrap_sse_event(\n                    delta_event.model_dump_json(exclude_none=True),\n                    \"message_delta\",\n                )\n\n                # Emit message_stop\n                stop_msg = AnthropicStreamEvent(type=\"message_stop\")\n                yield _wrap_sse_event(\n                    stop_msg.model_dump_json(exclude_none=True),\n                    \"message_stop\",\n                )\n                continue\n\n            # Parse the OpenAI chunk\n            try:\n                chunk = ChatCompletionStreamResponse.model_validate_json(data_str)\n            except Exception:\n                logger.debug(\"Failed to parse stream chunk: %s\", data_str)\n                error_event = AnthropicStreamEvent(\n                    type=\"error\",\n                    error=AnthropicError(\n                        type=\"api_error\", message=\"Stream processing error\"\n                    ),\n                )\n                yield _wrap_sse_event(\n                    error_event.model_dump_json(exclude_none=True), \"error\"\n                )\n                continue\n\n            # First chunk: emit message_start\n            if first_chunk:\n                first_chunk = False\n\n                start_event = AnthropicStreamEvent(\n                    type=\"message_start\",\n                    message=AnthropicMessagesResponse(\n                        id=message_id,\n                        content=[],\n                        model=model,\n                        usage=AnthropicUsage(\n                            input_tokens=(\n                                chunk.usage.prompt_tokens if chunk.usage else 0\n                            ),\n                            
output_tokens=0,\n                        ),\n                    ),\n                )\n                yield _wrap_sse_event(\n                    start_event.model_dump_json(exclude_none=True),\n                    \"message_start\",\n                )\n                # Skip if this was just the role chunk with empty content\n                if chunk.choices and chunk.choices[0].delta.content == \"\":\n                    continue\n\n            # Usage-only chunk (empty choices with usage info)\n            if not chunk.choices and chunk.usage:\n                usage_info = {\n                    \"input_tokens\": chunk.usage.prompt_tokens,\n                    \"output_tokens\": chunk.usage.completion_tokens or 0,\n                }\n                continue\n\n            if not chunk.choices:\n                continue\n\n            choice = chunk.choices[0]\n\n            # Capture finish reason\n            if choice.finish_reason is not None:\n                finish_reason = choice.finish_reason\n                continue\n\n            delta = choice.delta\n\n            # Handle tool call deltas\n            if delta.tool_calls:\n                for tc in delta.tool_calls:\n                    tc_id = tc.id\n                    tc_func = tc.function\n\n                    # New tool call: close previous block, start new one\n                    if tc_func and tc_func.name:\n                        # Close previous content block if open\n                        if content_block_open:\n                            stop_event = AnthropicStreamEvent(\n                                type=\"content_block_stop\",\n                                index=content_block_index,\n                            )\n                            yield _wrap_sse_event(\n                                stop_event.model_dump_json(exclude_none=True),\n                                \"content_block_stop\",\n                            )\n                            
content_block_index += 1\n\n                        # Start tool_use content block\n                        start_event = AnthropicStreamEvent(\n                            type=\"content_block_start\",\n                            index=content_block_index,\n                            content_block=AnthropicContentBlock(\n                                type=\"tool_use\",\n                                id=tc_id or f\"toolu_{uuid.uuid4().hex}\",\n                                name=tc_func.name,\n                                input={},\n                            ),\n                        )\n                        yield _wrap_sse_event(\n                            start_event.model_dump_json(exclude_none=True),\n                            \"content_block_start\",\n                        )\n                        content_block_open = True\n\n                        # Stream initial arguments if present\n                        if tc_func.arguments:\n                            delta_event = AnthropicStreamEvent(\n                                type=\"content_block_delta\",\n                                index=content_block_index,\n                                delta=AnthropicDelta(\n                                    type=\"input_json_delta\",\n                                    partial_json=tc_func.arguments,\n                                ),\n                            )\n                            yield _wrap_sse_event(\n                                delta_event.model_dump_json(exclude_none=True),\n                                \"content_block_delta\",\n                            )\n\n                    elif tc_func and tc_func.arguments:\n                        # Continuing arguments for current tool call\n                        delta_event = AnthropicStreamEvent(\n                            type=\"content_block_delta\",\n                            index=content_block_index,\n                            delta=AnthropicDelta(\n  
                              type=\"input_json_delta\",\n                                partial_json=tc_func.arguments,\n                            ),\n                        )\n                        yield _wrap_sse_event(\n                            delta_event.model_dump_json(exclude_none=True),\n                            \"content_block_delta\",\n                        )\n                continue\n\n            # Handle text content deltas\n            if delta.content is not None and delta.content != \"\":\n                # Start a text content block if needed\n                if not content_block_open:\n                    start_event = AnthropicStreamEvent(\n                        type=\"content_block_start\",\n                        index=content_block_index,\n                        content_block=AnthropicContentBlock(type=\"text\", text=\"\"),\n                    )\n                    yield _wrap_sse_event(\n                        start_event.model_dump_json(exclude_none=True),\n                        \"content_block_start\",\n                    )\n                    content_block_open = True\n\n                # Emit text delta\n                delta_event = AnthropicStreamEvent(\n                    type=\"content_block_delta\",\n                    index=content_block_index,\n                    delta=AnthropicDelta(\n                        type=\"text_delta\",\n                        text=delta.content,\n                    ),\n                )\n                yield _wrap_sse_event(\n                    delta_event.model_dump_json(exclude_none=True),\n                    \"content_block_delta\",\n                )\n\n    def _convert_response(\n        self, response: ChatCompletionResponse\n    ) -> AnthropicMessagesResponse:\n        \"\"\"Convert an OpenAI ChatCompletionResponse to an Anthropic Messages response.\"\"\"\n        if not response.choices:\n            return AnthropicMessagesResponse(\n                
content=[AnthropicContentBlock(type=\"text\", text=\"\")],\n                model=response.model,\n                stop_reason=\"end_turn\",\n                usage=AnthropicUsage(input_tokens=0, output_tokens=0),\n            )\n\n        choice = response.choices[0]\n        content: list[AnthropicContentBlock] = []\n\n        # Add text content\n        if choice.message.content:\n            content.append(\n                AnthropicContentBlock(type=\"text\", text=choice.message.content)\n            )\n\n        # Add tool calls\n        if choice.message.tool_calls:\n            for tool_call in choice.message.tool_calls:\n                try:\n                    tool_input = json.loads(tool_call.function.arguments)\n                except (json.JSONDecodeError, TypeError):\n                    tool_input = {}\n\n                content.append(\n                    AnthropicContentBlock(\n                        type=\"tool_use\",\n                        id=tool_call.id,\n                        name=tool_call.function.name,\n                        input=tool_input,\n                    )\n                )\n\n        # Map stop reason\n        stop_reason = STOP_REASON_MAP.get(choice.finish_reason or \"stop\", \"end_turn\")\n\n        return AnthropicMessagesResponse(\n            id=f\"msg_{uuid.uuid4().hex}\",\n            content=content,\n            model=response.model,\n            stop_reason=stop_reason,\n            usage=AnthropicUsage(\n                input_tokens=response.usage.prompt_tokens if response.usage else 0,\n                output_tokens=response.usage.completion_tokens if response.usage else 0,\n            ),\n        )\n\n    def _error_response(\n        self,\n        status_code: int,\n        error_type: str,\n        message: str,\n    ) -> JSONResponse:\n        \"\"\"Create an Anthropic-format error response.\"\"\"\n        error_resp = AnthropicErrorResponse(\n            error=AnthropicError(type=error_type, 
message=message)\n        )\n        return JSONResponse(\n            status_code=status_code,\n            content=error_resp.model_dump(),\n        )\n\n    async def handle_count_tokens(\n        self,\n        request: AnthropicCountTokensRequest,\n        raw_request: Request,\n    ) -> JSONResponse:\n        \"\"\"Handle /v1/messages/count_tokens endpoint.\n\n        Converts the request to a ChatCompletionRequest, applies the chat\n        template via the OpenAI handler to tokenize, and returns the count.\n        \"\"\"\n        try:\n            # Build a minimal AnthropicMessagesRequest so we can reuse conversion\n            messages_request = AnthropicMessagesRequest(\n                model=request.model,\n                messages=request.messages,\n                max_tokens=1,  # dummy, not used for counting\n                system=request.system,\n                tools=request.tools,\n                tool_choice=request.tool_choice,\n            )\n            chat_request = self._convert_to_chat_completion_request(messages_request)\n        except Exception as e:\n            logger.exception(\"Error converting count_tokens request: %s\", e)\n            return self._error_response(\n                status_code=400,\n                error_type=\"invalid_request_error\",\n                message=str(e),\n            )\n\n        try:\n            is_multimodal = (\n                self.openai_serving_chat.tokenizer_manager.model_config.is_multimodal\n            )\n            processed = self.openai_serving_chat._process_messages(\n                chat_request, is_multimodal\n            )\n\n            if isinstance(processed.prompt_ids, list):\n                input_tokens = len(processed.prompt_ids)\n            else:\n                # prompt_ids is a string (multimodal case) — tokenize it\n                tokenizer = self.openai_serving_chat.tokenizer_manager.tokenizer\n                input_tokens = 
len(tokenizer.encode(processed.prompt_ids))\n\n            return JSONResponse(\n                content=AnthropicCountTokensResponse(\n                    input_tokens=input_tokens\n                ).model_dump()\n            )\n        except Exception as e:\n            logger.exception(\"Error counting tokens: %s\", e)\n            return self._error_response(\n                status_code=500,\n                error_type=\"internal_error\",\n                message=\"Internal server error\",\n            )\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/context.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Copied from vLLM\nimport logging\nfrom abc import ABC, abstractmethod\nfrom typing import Union\n\nimport orjson\n\nlogger = logging.getLogger(__name__)\n\ntry:\n    from mcp import ClientSession\nexcept ImportError as e:\n    mcp = e\n\nfrom openai_harmony import Author, Message, Role, StreamState, TextContent\n\nfrom sglang.srt.entrypoints.harmony_utils import (\n    get_encoding,\n    get_streamable_parser_for_assistant,\n    render_for_completion,\n)\nfrom sglang.srt.entrypoints.tool import Tool\n\n\nclass ConversationContext(ABC):\n\n    @abstractmethod\n    def append_output(self, output) -> None:\n        pass\n\n    @abstractmethod\n    async def call_tool(self) -> list[Message]:\n        pass\n\n    @abstractmethod\n    def need_builtin_tool_call(self) -> bool:\n        pass\n\n    @abstractmethod\n    def render_for_completion(self) -> list[int]:\n        pass\n\n\nclass SimpleContext(ConversationContext):\n\n    def __init__(self):\n        self.last_output = None\n\n    def append_output(self, output) -> None:\n        self.last_output = output\n\n    def need_builtin_tool_call(self) -> bool:\n        return False\n\n    async def call_tool(self) -> list[Message]:\n        raise NotImplementedError(\"Should not be called.\")\n\n    def render_for_completion(self) -> list[int]:\n        raise NotImplementedError(\"Should not be called.\")\n\n\nclass HarmonyContext(ConversationContext):\n\n    def __init__(\n        self,\n        messages: list,\n        tool_sessions: dict[str, Union[\"ClientSession\", Tool]],\n    ):\n        # TODO: Remove the hack of Union[ClientSession, Tool] by using MCP\n        # when demo.\n        self._messages = messages\n        self.tool_sessions = tool_sessions\n\n        self.parser = get_streamable_parser_for_assistant()\n        self.num_init_messages = len(messages)\n        # TODO\n        self.num_prompt_tokens = 0\n        self.num_cached_tokens = 0\n        
self.num_output_tokens = 0\n        self.num_reasoning_tokens = 0\n\n    def append_output(self, output) -> None:\n        if isinstance(output, dict) and \"output_ids\" in output:\n            output_token_ids = output[\"output_ids\"]\n\n            for token_id in output_token_ids:\n                self.parser.process(token_id)\n            output_msgs = self.parser.messages\n\n            meta_info = output[\"meta_info\"]\n\n            if isinstance(meta_info, dict):\n                if \"prompt_tokens\" in meta_info:\n                    self.num_prompt_tokens = meta_info[\"prompt_tokens\"]\n                if \"cached_tokens\" in meta_info:\n                    self.num_cached_tokens = meta_info[\"cached_tokens\"]\n                if \"completion_tokens\" in meta_info:\n                    self.num_output_tokens += meta_info[\"completion_tokens\"]\n\n        else:\n            output_msgs = output\n\n        self._messages.extend(output_msgs)\n\n    @property\n    def messages(self) -> list:\n        return self._messages\n\n    def need_builtin_tool_call(self) -> bool:\n        if not self.messages:\n            return False\n        last_msg = self.messages[-1]\n        recipient = last_msg.recipient\n        return recipient is not None and (\n            recipient.startswith(\"browser.\") or recipient.startswith(\"python\")\n        )\n\n    async def call_tool(self) -> list[Message]:\n        if not self.messages:\n            return []\n        last_msg = self.messages[-1]\n        recipient = last_msg.recipient\n        if recipient is not None:\n            if recipient.startswith(\"browser.\"):\n                return await self.call_search_tool(\n                    self.tool_sessions[\"browser\"], last_msg\n                )\n            elif recipient.startswith(\"python\"):\n                return await self.call_python_tool(\n                    self.tool_sessions[\"python\"], last_msg\n                )\n        raise ValueError(\"No tool 
call found\")\n\n    def render_for_completion(self) -> list[int]:\n        return render_for_completion(self.messages)\n\n    async def call_search_tool(\n        self, tool_session: Union[\"ClientSession\", Tool], last_msg: Message\n    ) -> list[Message]:\n        if isinstance(tool_session, Tool):\n            return await tool_session.get_result(self)\n        tool_name = last_msg.recipient.split(\".\")[1]\n        args = orjson.loads(last_msg.content[0].text)\n        result = await tool_session.call_tool(tool_name, args)\n        result_str = result.content[0].text\n        content = TextContent(text=result_str)\n        author = Author(role=Role.TOOL, name=last_msg.recipient)\n        return [Message(author=author, content=[content], recipient=Role.ASSISTANT)]\n\n    async def call_python_tool(\n        self, tool_session: Union[\"ClientSession\", Tool], last_msg: Message\n    ) -> list[Message]:\n        if isinstance(tool_session, Tool):\n            return await tool_session.get_result(self)\n        param = {\n            \"code\": last_msg.content[0].text,\n        }\n        result = await tool_session.call_tool(\"python\", param)\n        result_str = result.content[0].text\n\n        content = TextContent(text=result_str)\n        author = Author(role=Role.TOOL, name=\"python\")\n\n        return [\n            Message(\n                author=author,\n                content=[content],\n                channel=last_msg.channel,\n                recipient=Role.ASSISTANT,\n            )\n        ]\n\n\nclass StreamingHarmonyContext(HarmonyContext):\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.last_output = None\n\n        self.parser = get_streamable_parser_for_assistant()\n        self.encoding = get_encoding()\n        self.last_tok = None\n        self.num_processed_tokens = 0\n\n    @property\n    def messages(self) -> list:\n        return self.parser.messages\n\n    def append_output(self, 
output) -> None:\n        if isinstance(output, dict) and \"output_ids\" in output:\n            # RequestOutput from SGLang with outputs\n            output_token_ids = output[\"output_ids\"]\n\n            # Check if we need to handle cumulative tokens\n            meta_info = output.get(\"meta_info\", {})\n            completion_tokens = meta_info.get(\"completion_tokens\")\n            if (\n                completion_tokens is not None\n                and len(output_token_ids) == completion_tokens\n            ):\n                # Case 1: When --incremental-streaming-output is not set.\n                # The output_ids contains all tokens generated so far.\n                # We only need to process the new tokens.\n                new_token_ids = output_token_ids[self.num_processed_tokens :]\n                self.num_processed_tokens = len(output_token_ids)\n            else:\n                # Case 2: When --incremental-streaming-output is set.\n                # The output_ids contains only the new tokens.\n                new_token_ids = output_token_ids\n                self.num_processed_tokens += len(output_token_ids)\n\n            for token_id in new_token_ids:\n                self.parser.process(token_id)\n\n        else:\n            # Handle the case of tool output in direct message format\n            assert len(output) == 1, \"Tool output should be a single message\"\n            msg = output[0]\n            # Sometimes the recipient is not set for tool messages,\n            # so we set it to \"assistant\"\n            if msg.author.role == Role.TOOL and msg.recipient is None:\n                msg.recipient = \"assistant\"\n            toks = self.encoding.render(msg)\n            for tok in toks:\n                self.parser.process(tok)\n            self.last_tok = toks[-1]\n\n    def is_expecting_start(self) -> bool:\n        return self.parser.state == StreamState.EXPECT_START\n\n    def is_assistant_action_turn(self) -> bool:\n        
return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()\n\n    def render_for_completion(self) -> list[int]:\n        # The rendered tokens end with the next turn's starting tokens\n        # (e.g. `<|start|>assistant`), which the parser has not seen yet.\n        # Feed every token after the last processed one to the parser.\n        rendered_tokens = super().render_for_completion()\n\n        last_n = -1\n        to_process = []\n        while rendered_tokens[last_n] != self.last_tok:\n            to_process.append(rendered_tokens[last_n])\n            last_n -= 1\n        for tok in reversed(to_process):\n            self.parser.process(tok)\n\n        return rendered_tokens\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/engine.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"\nThe entry point of inference server. (SRT = SGLang Runtime)\n\nThis file implements python APIs for the inference engine.\n\"\"\"\n\nimport asyncio\nimport atexit\nimport dataclasses\nimport logging\nimport multiprocessing as mp\nimport os\nimport random\nimport signal\nimport threading\nimport time\nfrom typing import (\n    Any,\n    AsyncIterator,\n    Callable,\n    Dict,\n    Iterator,\n    List,\n    Optional,\n    Tuple,\n    Union,\n)\n\n# Fix a bug of Python threading\nsetattr(threading, \"_register_atexit\", lambda *args, **kwargs: None)\n\nimport torch\nimport uvloop\nimport zmq\n\nfrom sglang.srt.elastic_ep.expert_backup_manager import run_expert_backup_manager\nfrom sglang.srt.entrypoints.EngineBase import EngineBase\nfrom sglang.srt.managers.data_parallel_controller import (\n    run_data_parallel_controller_process,\n)\nfrom sglang.srt.managers.detokenizer_manager import run_detokenizer_process\nfrom sglang.srt.managers.io_struct import (\n    CloseSessionReqInput,\n    DestroyWeightsUpdateGroupReqInput,\n    EmbeddingReqInput,\n    GenerateReqInput,\n    GetWeightsByNameReqInput,\n    InitWeightsUpdateGroupReqInput,\n    LoadLoRAAdapterFromTensorsReqInput,\n    LoadLoRAAdapterReqInput,\n    MultimodalDataInputFormat,\n    OpenSessionReqInput,\n    
ReleaseMemoryOccupationReqInput,\n    ResumeMemoryOccupationReqInput,\n    RpcReqInput,\n    RpcReqOutput,\n    UnloadLoRAAdapterReqInput,\n    UpdateWeightFromDiskReqInput,\n    UpdateWeightsFromDistributedReqInput,\n    UpdateWeightsFromIPCReqInput,\n    UpdateWeightsFromTensorReqInput,\n)\nfrom sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter\nfrom sglang.srt.managers.scheduler import run_scheduler_process\nfrom sglang.srt.managers.template_manager import TemplateManager\nfrom sglang.srt.managers.tokenizer_manager import TokenizerManager\nfrom sglang.srt.managers.tokenizer_manager_multiitem_mixin import ScoreResult\nfrom sglang.srt.model_loader.remote_instance_weight_loader_utils import (\n    parse_remote_instance_transfer_engine_info_from_scheduler_infos,\n)\nfrom sglang.srt.observability.trace import process_tracing_init, trace_set_thread_info\nfrom sglang.srt.server_args import PortArgs, ServerArgs\nfrom sglang.srt.utils import (\n    MultiprocessingSerializer,\n    assert_pkg_version,\n    configure_logger,\n    get_bool_env_var,\n    is_cuda,\n    kill_process_tree,\n    launch_dummy_health_check_server,\n    maybe_reindex_device_id,\n    numa_utils,\n    set_prometheus_multiproc_dir,\n    set_ulimit,\n)\nfrom sglang.srt.utils.network import get_zmq_socket\nfrom sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter\nfrom sglang.version import __version__\n\nlogger = logging.getLogger(__name__)\nasyncio.set_event_loop_policy(uvloop.EventLoopPolicy())\n\n_is_cuda = is_cuda()\n\n\n@dataclasses.dataclass\nclass SchedulerInitResult:\n    \"\"\"Result from launching schedulers.\"\"\"\n\n    scheduler_infos: List[Dict[str, Any]]\n    wait_for_ready: Callable[[], None] = lambda: None\n    wait_for_completion: Callable[[], None] = lambda: None\n\n\ndef init_tokenizer_manager(\n    server_args: ServerArgs,\n    port_args: PortArgs,\n    TokenizerManagerClass: Optional[TokenizerManager] = None,\n) -> Tuple[TokenizerManager, 
TemplateManager]:\n    # Launch tokenizer process\n    TokenizerManagerClass = TokenizerManagerClass or TokenizerManager\n    tokenizer_manager = TokenizerManagerClass(server_args, port_args)\n\n    # Initialize templates\n    template_manager = TemplateManager()\n    template_manager.initialize_templates(\n        tokenizer_manager=tokenizer_manager,\n        model_path=server_args.model_path,\n        chat_template=server_args.chat_template,\n        completion_template=server_args.completion_template,\n    )\n\n    return tokenizer_manager, template_manager\n\n\nclass Engine(EngineBase):\n    \"\"\"\n    The entry point to the inference engine.\n\n    - The engine consists of three components:\n        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.\n        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.\n        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.\n\n    Note:\n    1. The HTTP server, Engine, and TokenizerManager all run in the main process.\n    2. 
Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.\n    \"\"\"\n\n    # Some fields to allow people to override the server args\n    # and launch processes for their private forks.\n    server_args_class: ServerArgs = ServerArgs\n    init_tokenizer_manager_func: Callable = staticmethod(init_tokenizer_manager)\n    run_scheduler_process_func: Callable = staticmethod(run_scheduler_process)\n    run_detokenizer_process_func: Callable = staticmethod(run_detokenizer_process)\n\n    def __init__(self, **kwargs):\n        \"\"\"\n        The arguments of this function are the same as `sglang/srt/server_args.py::ServerArgs`.\n        Please refer to `ServerArgs` for the documentation.\n        \"\"\"\n\n        # Parse server_args\n        if \"server_args\" in kwargs:\n            # Directly load server_args\n            server_args = kwargs[\"server_args\"]\n        else:\n            # Construct server_args from kwargs\n            if \"log_level\" not in kwargs:\n                # Do not print logs by default\n                kwargs[\"log_level\"] = \"error\"\n            server_args = self.server_args_class(**kwargs)\n        self.server_args = server_args\n        logger.info(f\"{server_args=}\")\n\n        # Shut down the subprocesses automatically when the program exits\n        atexit.register(self.shutdown)\n\n        # Launch subprocesses\n        (\n            tokenizer_manager,\n            template_manager,\n            port_args,\n            scheduler_init_result,\n        ) = self._launch_subprocesses(\n            server_args=server_args,\n            init_tokenizer_manager_func=self.init_tokenizer_manager_func,\n            run_scheduler_process_func=self.run_scheduler_process_func,\n            run_detokenizer_process_func=self.run_detokenizer_process_func,\n        )\n        self.tokenizer_manager = tokenizer_manager\n        self.template_manager = template_manager\n        
self._scheduler_init_result = scheduler_init_result\n        self.port_args = port_args\n        self.remote_instance_transfer_engine_info = (\n            parse_remote_instance_transfer_engine_info_from_scheduler_infos(\n                scheduler_init_result.scheduler_infos\n            )\n        )\n\n        # Initialize ZMQ sockets\n        context = zmq.Context(2)\n        if self.server_args.node_rank == 0:\n            self.send_to_rpc = get_zmq_socket(\n                context, zmq.DEALER, self.port_args.rpc_ipc_name, True\n            )\n        else:\n            self.send_to_rpc = None\n\n        # Enable tracing\n        if server_args.enable_trace:\n            process_tracing_init(server_args.otlp_traces_endpoint, \"sglang\")\n            thread_label = \"Tokenizer\"\n            if server_args.disaggregation_mode == \"prefill\":\n                thread_label = \"Prefill Tokenizer\"\n            elif server_args.disaggregation_mode == \"decode\":\n                thread_label = \"Decode Tokenizer\"\n            trace_set_thread_info(thread_label)\n\n        try:\n            self.loop = asyncio.get_running_loop()\n        except RuntimeError:\n            self.loop = asyncio.new_event_loop()\n            asyncio.set_event_loop(self.loop)\n\n    def _resolve_routed_dp_rank(\n        self,\n        routed_dp_rank: Optional[int],\n        data_parallel_rank: Optional[int],\n    ) -> Optional[int]:\n        if data_parallel_rank is not None:\n            import warnings\n\n            warnings.warn(\n                \"'data_parallel_rank' is deprecated, use 'routed_dp_rank' instead.\",\n                DeprecationWarning,\n                stacklevel=3,\n            )\n            if routed_dp_rank is None:\n                routed_dp_rank = data_parallel_rank\n\n        if routed_dp_rank is not None:\n            dp_size = self.server_args.dp_size\n            if dp_size <= 1 and routed_dp_rank == 0:\n                logger.warning(\n                    
f\"routed_dp_rank={routed_dp_rank} is ignored because dp_size={dp_size}\"\n                )\n                return None\n            if routed_dp_rank < 0 or routed_dp_rank >= dp_size:\n                raise ValueError(\n                    f\"routed_dp_rank={routed_dp_rank} out of range [0, {dp_size})\"\n                )\n\n        logger.debug(f\"routed_dp_rank: {routed_dp_rank}\")\n        return routed_dp_rank\n\n    def generate(\n        self,\n        # The input prompt. It can be a single prompt or a batch of prompts.\n        prompt: Optional[Union[List[str], str]] = None,\n        sampling_params: Optional[Union[List[Dict], Dict]] = None,\n        # The token ids for text; one can either specify text or input_ids.\n        input_ids: Optional[Union[List[List[int]], List[int]]] = None,\n        # The image input. It can be an image instance, file name, URL, or base64 encoded string.\n        # Can be formatted as:\n        # - Single image for a single request\n        # - List of images (one per request in a batch)\n        # - List of lists of images (multiple images per request)\n        # - List of preprocessed outputs from a Huggingface processor, each as a dict containing `format`: 'processor_output' and other data\n        # - List of precomputed image embeddings, each as a dict containing field `format`: 'precomputed_embedding' and `feature`: the precomputed embedding\n        # See also python/sglang/srt/utils.py:load_image for more details.\n        image_data: Optional[MultimodalDataInputFormat] = None,\n        audio_data: Optional[MultimodalDataInputFormat] = None,\n        video_data: Optional[MultimodalDataInputFormat] = None,\n        return_logprob: Optional[Union[List[bool], bool]] = False,\n        logprob_start_len: Optional[Union[List[int], int]] = None,\n        top_logprobs_num: Optional[Union[List[int], int]] = None,\n        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,\n        lora_path: 
Optional[List[Optional[str]]] = None,\n        custom_logit_processor: Optional[Union[List[str], str]] = None,\n        return_hidden_states: bool = False,\n        return_routed_experts: bool = False,\n        stream: bool = False,\n        bootstrap_host: Optional[Union[List[str], str]] = None,\n        bootstrap_port: Optional[Union[List[int], int]] = None,\n        bootstrap_room: Optional[Union[List[int], int]] = None,\n        routed_dp_rank: Optional[int] = None,\n        disagg_prefill_dp_rank: Optional[int] = None,\n        # Deprecated: use routed_dp_rank instead\n        data_parallel_rank: Optional[int] = None,\n        external_trace_header: Optional[Dict] = None,\n        rid: Optional[Union[List[str], str]] = None,\n        session_params: Optional[Dict] = None,\n        priority: Optional[int] = None,\n    ) -> Union[Dict, Iterator[Dict]]:\n        \"\"\"\n        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.\n        Please refer to `GenerateReqInput` for the documentation.\n        \"\"\"\n        routed_dp_rank = self._resolve_routed_dp_rank(\n            routed_dp_rank, data_parallel_rank\n        )\n\n        obj = GenerateReqInput(\n            text=prompt,\n            input_ids=input_ids,\n            sampling_params=sampling_params,\n            image_data=image_data,\n            audio_data=audio_data,\n            video_data=video_data,\n            return_logprob=return_logprob,\n            logprob_start_len=logprob_start_len,\n            top_logprobs_num=top_logprobs_num,\n            token_ids_logprob=token_ids_logprob,\n            lora_path=lora_path,\n            custom_logit_processor=custom_logit_processor,\n            return_hidden_states=return_hidden_states,\n            return_routed_experts=return_routed_experts,\n            stream=stream,\n            bootstrap_host=bootstrap_host,\n            bootstrap_port=bootstrap_port,\n            
bootstrap_room=bootstrap_room,\n            routed_dp_rank=routed_dp_rank,\n            disagg_prefill_dp_rank=disagg_prefill_dp_rank,\n            external_trace_header=external_trace_header,\n            rid=rid,\n            session_params=session_params,\n            priority=priority,\n        )\n        generator = self.tokenizer_manager.generate_request(obj, None)\n\n        if stream:\n\n            def generator_wrapper():\n                while True:\n                    try:\n                        chunk = self.loop.run_until_complete(generator.__anext__())\n                        yield chunk\n                    except StopAsyncIteration:\n                        break\n\n            return generator_wrapper()\n        else:\n            ret = self.loop.run_until_complete(generator.__anext__())\n            return ret\n\n    async def async_generate(\n        self,\n        # The input prompt. It can be a single prompt or a batch of prompts.\n        prompt: Optional[Union[List[str], str]] = None,\n        sampling_params: Optional[Union[List[Dict], Dict]] = None,\n        # The token ids for text; one can either specify text or input_ids.\n        input_ids: Optional[Union[List[List[int]], List[int]]] = None,\n        # The image input. 
It can be an image instance, file name, URL, or base64 encoded string.\n        # Can be formatted as:\n        # - Single image for a single request\n        # - List of images (one per request in a batch)\n        # - List of lists of images (multiple images per request)\n        # - List of preprocessed outputs from a Huggingface processor, each as a dict containing `format`: 'processor_output' and other data\n        # - List of precomputed image embeddings, each as a dict containing field `format`: 'precomputed_embedding' and `feature`: the precomputed embedding\n        # See also python/sglang/srt/utils.py:load_image for more details.\n        image_data: Optional[MultimodalDataInputFormat] = None,\n        audio_data: Optional[MultimodalDataInputFormat] = None,\n        video_data: Optional[MultimodalDataInputFormat] = None,\n        return_logprob: Optional[Union[List[bool], bool]] = False,\n        logprob_start_len: Optional[Union[List[int], int]] = None,\n        top_logprobs_num: Optional[Union[List[int], int]] = None,\n        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,\n        lora_path: Optional[List[Optional[str]]] = None,\n        custom_logit_processor: Optional[Union[List[str], str]] = None,\n        return_hidden_states: bool = False,\n        return_routed_experts: bool = False,\n        stream: bool = False,\n        bootstrap_host: Optional[Union[List[str], str]] = None,\n        bootstrap_port: Optional[Union[List[int], int]] = None,\n        bootstrap_room: Optional[Union[List[int], int]] = None,\n        routed_dp_rank: Optional[int] = None,\n        disagg_prefill_dp_rank: Optional[int] = None,\n        # Deprecated: use routed_dp_rank instead\n        data_parallel_rank: Optional[int] = None,\n        external_trace_header: Optional[Dict] = None,\n        rid: Optional[Union[List[str], str]] = None,\n        session_params: Optional[Dict] = None,\n        priority: Optional[int] = None,\n    ) -> Union[Dict, 
AsyncIterator[Dict]]:\n        \"\"\"\n        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.\n        Please refer to `GenerateReqInput` for the documentation.\n        \"\"\"\n        routed_dp_rank = self._resolve_routed_dp_rank(\n            routed_dp_rank, data_parallel_rank\n        )\n\n        obj = GenerateReqInput(\n            text=prompt,\n            input_ids=input_ids,\n            sampling_params=sampling_params,\n            image_data=image_data,\n            audio_data=audio_data,\n            video_data=video_data,\n            return_logprob=return_logprob,\n            logprob_start_len=logprob_start_len,\n            top_logprobs_num=top_logprobs_num,\n            token_ids_logprob=token_ids_logprob,\n            lora_path=lora_path,\n            return_hidden_states=return_hidden_states,\n            return_routed_experts=return_routed_experts,\n            stream=stream,\n            custom_logit_processor=custom_logit_processor,\n            bootstrap_host=bootstrap_host,\n            bootstrap_port=bootstrap_port,\n            bootstrap_room=bootstrap_room,\n            routed_dp_rank=routed_dp_rank,\n            disagg_prefill_dp_rank=disagg_prefill_dp_rank,\n            external_trace_header=external_trace_header,\n            rid=rid,\n            session_params=session_params,\n            priority=priority,\n        )\n        generator = self.tokenizer_manager.generate_request(obj, None)\n\n        if stream is True:\n            return generator\n        else:\n            return await generator.__anext__()\n\n    def encode(\n        self,\n        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],\n        image_data: Optional[MultimodalDataInputFormat] = None,\n        audio_data: Optional[MultimodalDataInputFormat] = None,\n        video_data: Optional[MultimodalDataInputFormat] = None,\n        dimensions: Optional[int] = None,\n        lora_path: 
Optional[Union[List[Optional[str]], Optional[str]]] = None,\n        external_trace_header: Optional[Dict] = None,\n        rid: Optional[Union[List[str], str]] = None,\n    ) -> Dict:\n        \"\"\"\n        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.\n        Please refer to `EmbeddingReqInput` for the documentation.\n        \"\"\"\n        obj = EmbeddingReqInput(\n            text=prompt,\n            image_data=image_data,\n            audio_data=audio_data,\n            video_data=video_data,\n            dimensions=dimensions,\n            lora_path=lora_path,\n            external_trace_header=external_trace_header,\n            rid=rid,\n        )\n        generator = self.tokenizer_manager.generate_request(obj, None)\n        ret = self.loop.run_until_complete(generator.__anext__())\n        return ret\n\n    async def async_encode(\n        self,\n        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],\n        image_data: Optional[MultimodalDataInputFormat] = None,\n        audio_data: Optional[MultimodalDataInputFormat] = None,\n        video_data: Optional[MultimodalDataInputFormat] = None,\n        dimensions: Optional[int] = None,\n        lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,\n        external_trace_header: Optional[Dict] = None,\n        rid: Optional[Union[List[str], str]] = None,\n    ) -> Dict:\n        \"\"\"\n        Asynchronous version of the `encode` method.\n\n        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.\n        Please refer to `EmbeddingReqInput` for the documentation.\n        \"\"\"\n        obj = EmbeddingReqInput(\n            text=prompt,\n            image_data=image_data,\n            audio_data=audio_data,\n            video_data=video_data,\n            dimensions=dimensions,\n            lora_path=lora_path,\n            external_trace_header=external_trace_header,\n 
           rid=rid,\n        )\n        generator = self.tokenizer_manager.generate_request(obj, None)\n        return await generator.__anext__()\n\n    def rerank(\n        self,\n        prompt: List[List[str]],\n    ) -> Dict:\n        \"\"\"\n        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.\n        Please refer to `EmbeddingReqInput` for the documentation.\n        \"\"\"\n        obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True)\n        generator = self.tokenizer_manager.generate_request(obj, None)\n        ret = self.loop.run_until_complete(generator.__anext__())\n        return ret\n\n    @classmethod\n    def _launch_scheduler_processes(\n        cls,\n        server_args: ServerArgs,\n        port_args: PortArgs,\n        run_scheduler_process_func: Callable,\n    ) -> SchedulerInitResult:\n        \"\"\"Launch scheduler processes using multiprocessing.\n        Override in subclasses for different backends (e.g. 
Ray).\n        \"\"\"\n        scheduler_procs = []\n\n        if server_args.dp_size == 1:\n            # Launch tensor parallel scheduler processes\n            memory_saver_adapter = TorchMemorySaverAdapter.create(\n                enable=server_args.enable_memory_saver\n            )\n            scheduler_pipe_readers = []\n\n            pp_rank_range, tp_rank_range, pp_size_per_node, tp_size_per_node = (\n                _calculate_rank_ranges(\n                    server_args.nnodes,\n                    server_args.pp_size,\n                    server_args.tp_size,\n                    server_args.node_rank,\n                )\n            )\n\n            for pp_rank in pp_rank_range:\n                for tp_rank in tp_rank_range:\n                    reader, writer = mp.Pipe(duplex=False)\n                    gpu_id = (\n                        server_args.base_gpu_id\n                        + ((pp_rank % pp_size_per_node) * tp_size_per_node)\n                        + (tp_rank % tp_size_per_node) * server_args.gpu_id_step\n                    )\n                    attn_cp_rank, moe_dp_rank, moe_ep_rank = _compute_parallelism_ranks(\n                        server_args, tp_rank\n                    )\n\n                    with maybe_reindex_device_id(gpu_id) as gpu_id:\n                        proc = mp.Process(\n                            target=run_scheduler_process_func,\n                            args=(\n                                server_args,\n                                port_args,\n                                gpu_id,\n                                tp_rank,\n                                attn_cp_rank,\n                                moe_dp_rank,\n                                moe_ep_rank,\n                                pp_rank,\n                                None,\n                                writer,\n                            ),\n                        )\n                        with 
memory_saver_adapter.configure_subprocess(), numa_utils.configure_subprocess(\n                            server_args, gpu_id\n                        ):\n                            proc.start()\n\n                    scheduler_procs.append(proc)\n                    scheduler_pipe_readers.append(reader)\n        else:\n            # Launch the data parallel controller\n            reader, writer = mp.Pipe(duplex=False)\n            scheduler_pipe_readers = [reader]\n            proc = mp.Process(\n                target=run_data_parallel_controller_process,\n                kwargs=dict(\n                    server_args=server_args,\n                    port_args=port_args,\n                    pipe_writer=writer,\n                    run_scheduler_process_func=run_scheduler_process_func,\n                ),\n            )\n            proc.start()\n            scheduler_procs.append(proc)\n\n        scheduler_infos = []\n\n        def wait_for_ready():\n            infos = _wait_for_scheduler_ready(scheduler_pipe_readers, scheduler_procs)\n            scheduler_infos.extend(infos)\n\n        def wait_for_completion():\n            for proc in scheduler_procs:\n                proc.join()\n                logger.error(\n                    f\"Scheduler or DataParallelController {proc.pid} \"\n                    f\"terminated with {proc.exitcode}\"\n                )\n\n        return SchedulerInitResult(\n            scheduler_infos=scheduler_infos,\n            wait_for_ready=wait_for_ready,\n            wait_for_completion=wait_for_completion,\n        )\n\n    @classmethod\n    def _launch_subprocesses(\n        cls,\n        server_args: ServerArgs,\n        init_tokenizer_manager_func: Callable,\n        run_scheduler_process_func: Callable,\n        run_detokenizer_process_func: Callable,\n        port_args: Optional[PortArgs] = None,\n    ) -> Tuple[TokenizerManager, TemplateManager, PortArgs, SchedulerInitResult]:\n        \"\"\"Launch the 
TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.\n\n        Returns:\n            Tuple of (tokenizer_manager, template_manager, port_args, scheduler_init_result).\n        \"\"\"\n        # Configure global environment\n        configure_logger(server_args)\n        _set_envs_and_config(server_args)\n        server_args.check_server_args()\n\n        # Allocate ports for inter-process communications\n        if port_args is None:\n            port_args = PortArgs.init_new(server_args)\n        logger.info(f\"{server_args=}\")\n\n        # Launch scheduler processes\n        scheduler_init_result = cls._launch_scheduler_processes(\n            server_args, port_args, run_scheduler_process_func\n        )\n\n        if (\n            server_args.enable_elastic_expert_backup\n            and server_args.elastic_ep_backend is not None\n        ):\n            run_expert_backup_manager(server_args, port_args)\n\n        if server_args.node_rank >= 1:\n            # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,\n            # so they can just wait here.\n            scheduler_init_result.wait_for_ready()\n\n            if os.getenv(\"SGLANG_BLOCK_NONZERO_RANK_CHILDREN\") == \"0\":\n                # When using `Engine` as a Python API, we don't want to block here.\n                return (\n                    None,\n                    None,\n                    port_args,\n                    scheduler_init_result,\n                )\n\n            launch_dummy_health_check_server(\n                server_args.host, server_args.port, server_args.enable_metrics\n            )\n\n            scheduler_init_result.wait_for_completion()\n            return (\n                None,\n                None,\n                port_args,\n                scheduler_init_result,\n            )\n\n        # Launch detokenizer process\n        detoken_proc = mp.Process(\n    
        target=run_detokenizer_process_func,\n            args=(\n                server_args,\n                port_args,\n            ),\n        )\n        detoken_proc.start()\n\n        # Init tokenizer manager first, as the bootstrap server is initialized here\n        if server_args.tokenizer_worker_num == 1:\n            tokenizer_manager, template_manager = init_tokenizer_manager_func(\n                server_args, port_args\n            )\n        else:\n            # Launch multi-tokenizer router\n            tokenizer_manager = MultiTokenizerRouter(server_args, port_args)\n            template_manager = None\n\n        # Wait for the model to finish loading\n        scheduler_init_result.wait_for_ready()\n\n        # Get back some info from scheduler to tokenizer_manager\n        tokenizer_manager.max_req_input_len = scheduler_init_result.scheduler_infos[0][\n            \"max_req_input_len\"\n        ]\n\n        return (\n            tokenizer_manager,\n            template_manager,\n            port_args,\n            scheduler_init_result,\n        )\n\n    def shutdown(self):\n        \"\"\"Shutdown the engine\"\"\"\n        kill_process_tree(os.getpid(), include_parent=False)\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.shutdown()\n        return False\n\n    def flush_cache(self):\n        return self.loop.run_until_complete(self.tokenizer_manager.flush_cache())\n\n    def open_session(\n        self,\n        capacity_of_str_len: int,\n        session_id: Optional[str] = None,\n        streaming: bool = False,\n        timeout: Optional[float] = None,\n    ) -> str:\n        \"\"\"Open a session for multi-turn conversation with shared context.\n\n        Args:\n            capacity_of_str_len: Maximum string length capacity for the session.\n            session_id: Optional session ID. 
If not provided, a UUID will be generated.\n            streaming: Use low-overhead path for realtime streaming (append-only mode).\n            timeout: If set, the session is automatically closed after being inactive\n                for this many seconds. Inactivity is measured from session open or the\n                most recent request submission.\n\n        Returns:\n            The session ID (either the provided one or a newly generated UUID).\n        \"\"\"\n        obj = OpenSessionReqInput(\n            capacity_of_str_len=capacity_of_str_len,\n            session_id=session_id,\n            streaming=streaming,\n            timeout=timeout,\n        )\n        return self.loop.run_until_complete(\n            self.tokenizer_manager.open_session(obj, None)\n        )\n\n    def close_session(self, session_id: str) -> None:\n        \"\"\"Close a session and release its resources.\n\n        Args:\n            session_id: The session ID to close.\n        \"\"\"\n        obj = CloseSessionReqInput(session_id=session_id)\n        self.loop.run_until_complete(self.tokenizer_manager.close_session(obj, None))\n\n    def start_profile(self, **kwargs):\n        self.loop.run_until_complete(self.tokenizer_manager.start_profile(**kwargs))\n\n    def stop_profile(self):\n        self.loop.run_until_complete(self.tokenizer_manager.stop_profile())\n\n    def start_expert_distribution_record(self):\n        self.loop.run_until_complete(\n            self.tokenizer_manager.start_expert_distribution_record()\n        )\n\n    def stop_expert_distribution_record(self):\n        self.loop.run_until_complete(\n            self.tokenizer_manager.stop_expert_distribution_record()\n        )\n\n    def dump_expert_distribution_record(self):\n        self.loop.run_until_complete(\n            self.tokenizer_manager.dump_expert_distribution_record()\n        )\n\n    def get_server_info(self):\n        internal_states = self.loop.run_until_complete(\n            
self.tokenizer_manager.get_internal_state()\n        )\n        return {\n            **dataclasses.asdict(self.tokenizer_manager.server_args),\n            **self._scheduler_init_result.scheduler_infos[0],\n            \"internal_states\": internal_states,\n            \"version\": __version__,\n        }\n\n    def init_weights_update_group(\n        self,\n        master_address: str,\n        master_port: int,\n        rank_offset: int,\n        world_size: int,\n        group_name: str,\n        backend: str = \"nccl\",\n    ):\n        \"\"\"Initialize parameter update group.\"\"\"\n        obj = InitWeightsUpdateGroupReqInput(\n            master_address=master_address,\n            master_port=master_port,\n            rank_offset=rank_offset,\n            world_size=world_size,\n            group_name=group_name,\n            backend=backend,\n        )\n        return self.loop.run_until_complete(\n            self.tokenizer_manager.init_weights_update_group(obj, None)\n        )\n\n    def destroy_weights_update_group(\n        self,\n        group_name: str,\n    ):\n        \"\"\"Destroy parameter update group.\"\"\"\n        obj = DestroyWeightsUpdateGroupReqInput(\n            group_name=group_name,\n        )\n        return self.loop.run_until_complete(\n            self.tokenizer_manager.destroy_weights_update_group(obj, None)\n        )\n\n    def update_weights_from_distributed(\n        self,\n        names: list[str],\n        dtypes: list[str],\n        shapes: list[list[int]],\n        group_name: str = \"weight_update_group\",\n        flush_cache: bool = True,\n        load_format: Optional[str] = None,\n    ):\n        \"\"\"Update weights from distributed source.\"\"\"\n        obj = UpdateWeightsFromDistributedReqInput(\n            names=names,\n            dtypes=dtypes,\n            shapes=shapes,\n            group_name=group_name,\n            flush_cache=flush_cache,\n            load_format=load_format,\n        )\n        return 
self.loop.run_until_complete(\n            self.tokenizer_manager.update_weights_from_distributed(obj, None)\n        )\n\n    def update_weights_from_tensor(\n        self,\n        named_tensors: List[Tuple[str, torch.Tensor]],\n        load_format: Optional[str] = None,\n        flush_cache: bool = True,\n    ):\n        \"\"\"Update weights from the given named tensors. If more updates will follow, set `flush_cache` to False\n        to avoid repeated cache flushing.\"\"\"\n        if load_format == \"flattened_bucket\":\n            serialized_named_tensors = named_tensors\n        else:\n            serialized_named_tensors = [\n                MultiprocessingSerializer.serialize(named_tensors)\n                for _ in range(self.server_args.tp_size)\n            ]\n        obj = UpdateWeightsFromTensorReqInput(\n            serialized_named_tensors=serialized_named_tensors,\n            load_format=load_format,\n            flush_cache=flush_cache,\n        )\n        return self.loop.run_until_complete(\n            self.tokenizer_manager.update_weights_from_tensor(obj, None)\n        )\n\n    def update_weights_from_disk(\n        self,\n        model_path: str,\n        load_format: Optional[str] = None,\n    ):\n        \"\"\"Update the weights from disk in place without re-launching the engine.\n\n        This method allows updating the model weights from disk without restarting\n        the engine. 
It can be used to load a different model or update weights with\n        new training.\n        \"\"\"\n        obj = UpdateWeightFromDiskReqInput(\n            model_path=model_path,\n            load_format=load_format,\n        )\n\n        return self.loop.run_until_complete(\n            self.tokenizer_manager.update_weights_from_disk(obj, None)\n        )\n\n    def update_weights_from_ipc(\n        self,\n        zmq_handles: Dict[str, str],\n        flush_cache: bool = True,\n    ):\n        \"\"\"Update weights from IPC for checkpoint-engine integration.\"\"\"\n        obj = UpdateWeightsFromIPCReqInput(\n            zmq_handles=zmq_handles,\n            flush_cache=flush_cache,\n        )\n        return self.loop.run_until_complete(\n            self.tokenizer_manager.update_weights_from_ipc(obj, None)\n        )\n\n    def get_weights_by_name(self, name: str, truncate_size: int = 100):\n        \"\"\"Get weights by parameter name.\"\"\"\n        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)\n        return self.loop.run_until_complete(\n            self.tokenizer_manager.get_weights_by_name(obj, None)\n        )\n\n    def load_lora_adapter_from_tensors(\n        self,\n        lora_name: str,\n        tensors,\n        config_dict: Dict,\n        load_format: Optional[str] = None,\n    ):\n        if load_format == \"flattened_bucket\":\n            serialized_tensors = tensors\n        else:\n            serialized_tensors = MultiprocessingSerializer.serialize(\n                tensors, output_str=True\n            )\n        lora_req = LoadLoRAAdapterFromTensorsReqInput(\n            lora_name=lora_name,\n            config_dict=config_dict,\n            serialized_tensors=serialized_tensors,\n            load_format=load_format,\n        )\n        return self.loop.run_until_complete(\n            self.tokenizer_manager.load_lora_adapter_from_tensors(lora_req, None)\n        )\n\n    def load_lora_adapter(self, lora_name: 
str, lora_path: str, pinned: bool = False):\n        \"\"\"Load a new LoRA adapter without re-launching the engine.\"\"\"\n\n        obj = LoadLoRAAdapterReqInput(\n            lora_name=lora_name,\n            lora_path=lora_path,\n            pinned=pinned,\n        )\n\n        return self.loop.run_until_complete(\n            self.tokenizer_manager.load_lora_adapter(obj, None)\n        )\n\n    def unload_lora_adapter(self, lora_name: str):\n        \"\"\"Unload a LoRA adapter without re-launching the engine.\"\"\"\n\n        obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)\n\n        return self.loop.run_until_complete(\n            self.tokenizer_manager.unload_lora_adapter(obj, None)\n        )\n\n    async def async_load_lora_adapter(\n        self, lora_name: str, lora_path: str, pinned: bool = False\n    ):\n        \"\"\"\n        Asynchronous version of load_lora_adapter.\n\n        See load_lora_adapter() for detailed documentation.\n        \"\"\"\n\n        obj = LoadLoRAAdapterReqInput(\n            lora_name=lora_name,\n            lora_path=lora_path,\n            pinned=pinned,\n        )\n\n        return await self.tokenizer_manager.load_lora_adapter(obj, None)\n\n    async def async_unload_lora_adapter(self, lora_name: str):\n        \"\"\"\n        Asynchronous version of unload_lora_adapter.\n\n        See unload_lora_adapter() for detailed documentation.\n        \"\"\"\n\n        obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)\n\n        return await self.tokenizer_manager.unload_lora_adapter(obj, None)\n\n    def release_memory_occupation(self, tags: Optional[List[str]] = None):\n        obj = ReleaseMemoryOccupationReqInput(tags=tags)\n        return self.loop.run_until_complete(\n            self.tokenizer_manager.release_memory_occupation(obj, None)\n        )\n\n    def resume_memory_occupation(self, tags: Optional[List[str]] = None):\n        obj = ResumeMemoryOccupationReqInput(tags=tags)\n        return 
self.loop.run_until_complete(\n            self.tokenizer_manager.resume_memory_occupation(obj, None)\n        )\n\n    def freeze_gc(self):\n        \"\"\"\n        To maintain a high-performance server with low latency, we want to reduce the\n        stalls caused by the garbage collector scanning through a large number of objects.\n\n        It is usually helpful to start the server and warm it up with real requests to\n        initialize many of the long-lived objects that do not need to be garbage collected.\n\n        After sufficient warmup, we can call this function to freeze the garbage collector\n        so that all objects created before this point are considered out of scope for garbage\n        collection.\n        \"\"\"\n\n        self.loop.run_until_complete(self.tokenizer_manager.freeze_gc())\n\n    def collective_rpc(self, method: str, **kwargs):\n        \"\"\"Execute an RPC call on all scheduler processes.\"\"\"\n        obj = RpcReqInput(method=method, parameters=kwargs)\n        self.send_to_rpc.send_pyobj(obj)\n        recv_req = self.send_to_rpc.recv_pyobj(zmq.BLOCKY)\n        assert isinstance(recv_req, RpcReqOutput)\n        assert recv_req.success, recv_req.message\n\n    def save_remote_model(self, **kwargs):\n        self.collective_rpc(\"save_remote_model\", **kwargs)\n\n    def save_sharded_model(self, **kwargs):\n        self.collective_rpc(\"save_sharded_model\", **kwargs)\n\n    def score(\n        self,\n        query: Optional[Union[str, List[int]]] = None,\n        items: Optional[Union[str, List[str], List[List[int]]]] = None,\n        label_token_ids: Optional[List[int]] = None,\n        apply_softmax: bool = False,\n        item_first: bool = False,\n    ) -> ScoreResult:\n        \"\"\"\n        Score the probability of specified token IDs appearing after the given (query + item) pair. For example:\n        query = \"<|user|>Is the following city the capital of France? 
\"\n        items = [\"Paris <|assistant|>\", \"London <|assistant|>\", \"Berlin <|assistant|>\"]\n        label_token_ids = [2332, 1223] # Token IDs for \"Yes\" and \"No\"\n        item_first = False\n\n        This would pass the following prompts to the model:\n        \"<|user|>Is the following city the capital of France? Paris <|assistant|>\"\n        \"<|user|>Is the following city the capital of France? London <|assistant|>\"\n        \"<|user|>Is the following city the capital of France? Berlin <|assistant|>\"\n        The api would then return the probabilities of the model producing \"Yes\" and \"No\" as the next token.\n        The output would look like:\n        [[0.9, 0.1], [0.2, 0.8], [0.1, 0.9]]\n\n\n        Args:\n            query: The query text or pre-tokenized query token IDs. Must be provided.\n            items: The item text(s) or pre-tokenized item token IDs. Must be provided.\n            label_token_ids: List of token IDs to compute probabilities for. If None, no token probabilities will be computed.\n            apply_softmax: Whether to normalize probabilities using softmax.\n            item_first: If True, prepend items to query. 
Otherwise append items to query.\n\n        Returns:\n            ScoreResult with:\n                scores: List of lists containing probabilities for each item and each label token\n                prompt_tokens: The number of prompt tokens processed.\n\n        Raises:\n            ValueError: If query is not provided, or if items is not provided,\n                      or if token IDs are out of vocabulary, or if logprobs are not available for the specified tokens.\n        \"\"\"\n        return self.loop.run_until_complete(\n            self.tokenizer_manager.score_request(\n                query=query,\n                items=items,\n                label_token_ids=label_token_ids,\n                apply_softmax=apply_softmax,\n                item_first=item_first,\n                request=None,\n            )\n        )\n\n    async def async_score(\n        self,\n        query: Optional[Union[str, List[int]]] = None,\n        items: Optional[Union[str, List[str], List[List[int]]]] = None,\n        label_token_ids: Optional[List[int]] = None,\n        apply_softmax: bool = False,\n        item_first: bool = False,\n    ) -> ScoreResult:\n        \"\"\"\n        Asynchronous version of score method.\n\n        See score() for detailed documentation.\n        \"\"\"\n        return await self.tokenizer_manager.score_request(\n            query=query,\n            items=items,\n            label_token_ids=label_token_ids,\n            apply_softmax=apply_softmax,\n            item_first=item_first,\n            request=None,\n        )\n\n\ndef _set_envs_and_config(server_args: ServerArgs):\n    # Set global environments\n    if \"NCCL_CUMEM_ENABLE\" not in os.environ or server_args.enable_symm_mem:\n        os.environ[\"NCCL_CUMEM_ENABLE\"] = str(int(server_args.enable_symm_mem))\n    if (\n        \"NCCL_NVLS_ENABLE\" not in os.environ\n        or server_args.enable_nccl_nvls\n        or server_args.enable_symm_mem\n    ):\n        
os.environ[\"NCCL_NVLS_ENABLE\"] = str(\n            int(server_args.enable_nccl_nvls or server_args.enable_symm_mem)\n        )\n    os.environ[\"CUDA_DEVICE_MAX_CONNECTIONS\"] = \"8\"\n    os.environ[\"CUDA_MODULE_LOADING\"] = \"AUTO\"\n\n    if os.environ.get(\"TRTLLM_ENABLE_PDL\", \"1\") != \"0\":\n        # flashinfer uses this environment variable for various kernels from MoE to quant kernels\n        os.environ[\"TRTLLM_ENABLE_PDL\"] = \"1\"\n\n    if os.environ.get(\"CUTE_DSL_LOG_LEVEL\") is None:\n        # Default to warning level, to avoid too many logs\n        os.environ[\"CUTE_DSL_LOG_LEVEL\"] = \"30\"\n\n    if os.environ.get(\"CUTE_DSL_LOG_TO_CONSOLE\") is None:\n        # Need to set log to console, otherwise the log level won't take effect\n        os.environ[\"CUTE_DSL_LOG_TO_CONSOLE\"] = \"1\"\n\n    # Can also be passed as argument\n    os.environ[\"SGLANG_RUN_ID\"] = (\n        f\"sglang-run-{time.time()}-{random.randint(0, 100000000)}\"\n    )\n\n    # Set prometheus env vars\n    if server_args.enable_metrics:\n        set_prometheus_multiproc_dir()\n\n    # Set ulimit\n    set_ulimit()\n\n    # Check flashinfer version\n    if not get_bool_env_var(\"SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK\"):\n        if server_args.attention_backend == \"flashinfer\":\n            assert_pkg_version(\n                \"flashinfer_python\",\n                \"0.6.6\",\n                \"Please uninstall the old version and \"\n                \"reinstall the latest version by following the instructions \"\n                \"at https://docs.flashinfer.ai/installation.html.\",\n            )\n        if _is_cuda:\n            assert_pkg_version(\n                \"sglang-kernel\",\n                \"0.4.0\",\n                \"Please reinstall the latest version with `pip install sglang-kernel --force-reinstall`\",\n            )\n\n    # Signal handlers can only be registered from the main thread.\n    if threading.current_thread() is threading.main_thread():\n 
       if server_args.custom_sigquit_handler is None:\n            # Register the signal handler.\n            # The child processes send SIGQUIT to this process when any error happens,\n            # and this process then cleans up the whole process tree.\n            # Note: This sigquit handler is used in the launch phase, and may be replaced by\n            # the running_phase_sigquit_handler in the tokenizer manager after the grpc server is launched.\n            def launch_phase_sigquit_handler(signum, frame):\n                logger.error(\n                    \"Received sigquit from a child process. It usually means the child failed.\"\n                )\n                kill_process_tree(os.getpid())\n\n            signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)\n        else:\n            # Allow users to register a custom SIGQUIT handler for things like crash dumps\n            logger.error(\n                f\"Using custom SIGQUIT handler: {server_args.custom_sigquit_handler}\"\n            )\n            signal.signal(signal.SIGQUIT, server_args.custom_sigquit_handler)\n    else:\n        logger.warning(\n            \"Signal handler is not added because the engine is not in the \"\n            \"main thread. This disables the SIGQUIT handler for cleaning up \"\n            \"the process tree when a child process fails.\"\n        )\n\n    # Set mp start method\n    mp.set_start_method(\"spawn\", force=True)\n\n\ndef _wait_for_scheduler_ready(\n    scheduler_pipe_readers: List,\n    scheduler_procs: List,\n) -> List[Dict]:\n    \"\"\"Wait for the model to finish loading and return scheduler infos.\"\"\"\n    scheduler_infos = []\n    for i in range(len(scheduler_pipe_readers)):\n        try:\n            data = scheduler_pipe_readers[i].recv()\n        except EOFError:\n            logger.error(\n                f\"Rank {i} scheduler is dead. 
Please check if there are relevant logs.\"\n            )\n            scheduler_procs[i].join()\n            logger.error(f\"Exit code: {scheduler_procs[i].exitcode}\")\n            raise\n\n        if data[\"status\"] != \"ready\":\n            raise RuntimeError(\n                \"Initialization failed. Please see the error messages above.\"\n            )\n        scheduler_infos.append(data)\n    return scheduler_infos\n\n\ndef _calculate_rank_ranges(\n    nnodes: int, pp_size: int, tp_size: int, node_rank: int\n) -> Tuple[range, range, int, int]:\n    \"\"\"Calculate pp_rank_range and tp_rank_range for a given node.\n\n    Args:\n        nnodes: Total number of nodes.\n        pp_size: Pipeline parallel size.\n        tp_size: Tensor parallel size.\n        node_rank: The rank of the node to compute ranges for.\n\n    Returns:\n        A tuple of (pp_rank_range, tp_rank_range, pp_size_per_node, tp_size_per_node):\n        - pp_rank_range: range of pipeline-parallel ranks assigned to this node.\n        - tp_rank_range: range of tensor-parallel ranks assigned to this node.\n        - pp_size_per_node: number of PP ranks per node.\n        - tp_size_per_node: number of TP ranks per node.\n    \"\"\"\n    pp_size_per_node = max(pp_size // nnodes, 1)\n    nnodes_per_pp_rank = max(nnodes // pp_size, 1)\n    pp_rank_range = range(\n        pp_size_per_node * (node_rank // nnodes_per_pp_rank),\n        pp_size_per_node * (node_rank // nnodes_per_pp_rank + 1),\n    )\n\n    nnodes_per_tp_group = nnodes_per_pp_rank\n    tp_size_per_node = tp_size // nnodes_per_tp_group\n    tp_rank_range = range(\n        tp_size_per_node * (node_rank % nnodes_per_tp_group),\n        tp_size_per_node * (node_rank % nnodes_per_tp_group + 1),\n    )\n\n    return pp_rank_range, tp_rank_range, pp_size_per_node, tp_size_per_node\n\n\ndef _compute_parallelism_ranks(\n    server_args: ServerArgs, tp_rank: int\n) -> Tuple[int, int, int]:\n    \"\"\"Compute attention-CP, MoE-DP, and MoE-EP 
ranks for a TP rank.\"\"\"\n    attn_dp_size = server_args.dp_size if server_args.enable_dp_attention else 1\n\n    # Parallelism hierarchy (outermost to innermost):\n    # - Attention: Global(TP) -> DP -> ATTN_CP -> ATTN_TP (innermost)\n    # - MoE: Global(TP) -> MOE_DP -> EP -> MOE_TP (innermost)\n    attn_tp_size = server_args.tp_size // attn_dp_size // server_args.attn_cp_size\n    attn_cp_rank = (tp_rank // attn_tp_size) % server_args.attn_cp_size\n    moe_dp_rank = tp_rank // (server_args.tp_size // server_args.moe_dp_size)\n    moe_ep_rank = (\n        tp_rank\n        % (server_args.tp_size // server_args.moe_dp_size)\n        // (server_args.tp_size // server_args.moe_dp_size // server_args.ep_size)\n    )\n    return attn_cp_rank, moe_dp_rank, moe_ep_rank\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/grpc_server.py",
    "content": "\"\"\"\nThin gRPC server wrapper — delegates to smg-grpc-servicer package.\n\"\"\"\n\n\nasync def serve_grpc(server_args, model_info=None):\n    \"\"\"Start the standalone gRPC server with integrated scheduler.\"\"\"\n    try:\n        from smg_grpc_servicer.sglang.server import serve_grpc as _serve_grpc\n    except ImportError:\n        raise ImportError(\n            \"gRPC mode requires the smg-grpc-servicer package. \"\n            \"Install it with: pip install smg-grpc-servicer[sglang]\"\n        ) from None\n    await _serve_grpc(server_args, model_info)\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/harmony_utils.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n# Adapted from vLLM: https://github.com/vllm-project/vllm/blob/1b9902806915040ac9b3029f2ab7522ec505afc3/vllm/entrypoints/harmony_utils.py\n# Slight differences in processing chat messages\nimport datetime\nfrom collections.abc import Iterable\nfrom typing import Literal, Optional, Union\n\nimport orjson\nfrom openai.types.responses import (\n    ResponseOutputItem,\n    ResponseOutputMessage,\n    ResponseOutputText,\n    ResponseReasoningItem,\n)\nfrom openai.types.responses.response_function_tool_call import ResponseFunctionToolCall\nfrom openai.types.responses.response_function_web_search import (\n    ActionFind,\n    ActionOpenPage,\n    ActionSearch,\n    ResponseFunctionWebSearch,\n)\nfrom openai.types.responses.response_reasoning_item import (\n    Content as ResponseReasoningTextContent,\n)\nfrom openai.types.responses.tool import Tool\nfrom openai_harmony import (\n    Author,\n    Conversation,\n    DeveloperContent,\n    HarmonyEncodingName,\n    Message,\n    ReasoningEffort,\n    Role,\n    StreamableParser,\n    SystemContent,\n    TextContent,\n    ToolDescription,\n    load_harmony_encoding,\n)\n\nfrom sglang.srt.entrypoints.openai.protocol import ResponseInputOutputItem\nfrom sglang.srt.utils import random_uuid\n\nREASONING_EFFORT = {\n    \"high\": ReasoningEffort.HIGH,\n    \"medium\": ReasoningEffort.MEDIUM,\n    \"low\": ReasoningEffort.LOW,\n}\n\n_harmony_encoding = None\n\n\ndef get_encoding():\n    global _harmony_encoding\n    if _harmony_encoding is None:\n        _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n    return _harmony_encoding\n\n\ndef get_system_message(\n    model_identity: Optional[str] = None,\n    reasoning_effort: Optional[Literal[\"high\", \"medium\", \"low\"]] = None,\n    start_date: Optional[str] = None,\n    browser_description: Optional[str] = None,\n    
python_description: Optional[str] = None,\n) -> Message:\n    sys_msg_content = SystemContent.new()\n    if model_identity is not None:\n        sys_msg_content = sys_msg_content.with_model_identity(model_identity)\n    if reasoning_effort is not None:\n        sys_msg_content = sys_msg_content.with_reasoning_effort(\n            REASONING_EFFORT[reasoning_effort]\n        )\n    if start_date is None:\n        start_date = datetime.datetime.now().strftime(\"%Y-%m-%d\")\n    sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)\n    if browser_description is not None:\n        sys_msg_content = sys_msg_content.with_tools(browser_description)\n    if python_description is not None:\n        sys_msg_content = sys_msg_content.with_tools(python_description)\n    sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)\n    return sys_msg\n\n\ndef get_developer_message(\n    instructions: Optional[str] = None, tools: Optional[list[Tool]] = None\n) -> Message:\n    dev_msg_content = DeveloperContent.new()\n    if instructions is not None:\n        dev_msg_content = dev_msg_content.with_instructions(instructions)\n    if tools is not None:\n        function_tools = []\n        for tool in tools:\n            if tool.type in (\"web_search_preview\", \"code_interpreter\"):\n                # These are built-in tools that are added to the system message.\n                pass\n            elif tool.type == \"function\":\n                function_tools.append(tool)\n            else:\n                raise ValueError(f\"tool type {tool.type} not supported\")\n        if function_tools:\n            function_tool_descriptions = [\n                ToolDescription.new(\n                    name=tool.name,\n                    description=tool.description,\n                    parameters=tool.parameters,\n                )\n                for tool in function_tools\n            ]\n            dev_msg_content = 
dev_msg_content.with_function_tools(\n                function_tool_descriptions\n            )\n    dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)\n    return dev_msg\n\n\ndef get_user_message(content: str) -> Message:\n    return Message.from_role_and_content(Role.USER, content)\n\n\ndef parse_response_input(\n    response_msg: ResponseInputOutputItem,\n    prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]],\n) -> Message:\n    if not isinstance(response_msg, dict):\n        response_msg = response_msg.model_dump()\n    if \"type\" not in response_msg or response_msg[\"type\"] == \"message\":\n        role = response_msg[\"role\"]\n        content = response_msg[\"content\"]\n        if role == \"system\":\n            # User is trying to set a system message. Change it to:\n            # <|start|>developer<|message|># Instructions\n            # {instructions}<|end|>\n            role = \"developer\"\n            text_prefix = \"Instructions:\\n\"\n        else:\n            text_prefix = \"\"\n        if isinstance(content, str):\n            msg = Message.from_role_and_content(role, text_prefix + content)\n        else:\n            contents = [TextContent(text=text_prefix + c[\"text\"]) for c in content]\n            msg = Message.from_role_and_contents(role, contents)\n    elif response_msg[\"type\"] == \"function_call_output\":\n        call_id = response_msg[\"call_id\"]\n        call_response: Optional[ResponseFunctionToolCall] = None\n        for prev_response in reversed(prev_responses):\n            if (\n                isinstance(prev_response, ResponseFunctionToolCall)\n                and prev_response.call_id == call_id\n            ):\n                call_response = prev_response\n                break\n        if call_response is None:\n            raise ValueError(f\"No call message found for {call_id}\")\n        msg = Message.from_author_and_content(\n            Author.new(Role.TOOL, 
f\"functions.{call_response.name}\"),\n            response_msg[\"output\"],\n        )\n    elif response_msg[\"type\"] == \"reasoning\":\n        content = response_msg[\"content\"]\n        assert len(content) == 1\n        msg = Message.from_role_and_content(Role.ASSISTANT, content[0][\"text\"])\n    elif response_msg[\"type\"] == \"function_call\":\n        msg = Message.from_role_and_content(Role.ASSISTANT, response_msg[\"arguments\"])\n        msg = msg.with_channel(\"commentary\")\n        msg = msg.with_recipient(f\"functions.{response_msg['name']}\")\n        msg = msg.with_content_type(\"json\")\n    else:\n        raise ValueError(f\"Unknown input type: {response_msg['type']}\")\n    return msg\n\n\ndef parse_response_output(output: ResponseOutputItem) -> Message:\n    if isinstance(output, ResponseOutputMessage):\n        role = output.role\n        contents = [TextContent(text=c.text) for c in output.content]\n        msg = Message.from_role_and_contents(role, contents)\n        return msg\n    elif isinstance(output, ResponseFunctionToolCall):\n        msg = Message.from_role_and_content(Role.ASSISTANT, output.arguments)\n        msg = msg.with_channel(\"commentary\")\n        msg = msg.with_recipient(output.name)\n        msg = msg.with_content_type(\"json\")\n        return msg\n    else:\n        raise ValueError(f\"Unknown output type: {type(output)}\")\n\n\ndef parse_chat_input(chat_msg) -> Message:\n    role = chat_msg.role\n    content = chat_msg.content\n    if isinstance(content, str):\n        contents = [TextContent(text=content)]\n    else:\n        # TODO: Support refusal.\n        contents = [TextContent(text=c.text) for c in content]\n    msg = Message.from_role_and_contents(role, contents)\n    return msg\n\n\ndef render_for_completion(messages: list[Message]) -> list[int]:\n    conversation = Conversation.from_messages(messages)\n    token_ids = get_encoding().render_conversation_for_completion(\n        conversation, 
Role.ASSISTANT\n    )\n    return token_ids\n\n\ndef get_stop_tokens_for_assistant_actions() -> list[int]:\n    return get_encoding().stop_tokens_for_assistant_actions()\n\n\ndef get_streamable_parser_for_assistant() -> StreamableParser:\n    return StreamableParser(get_encoding(), role=Role.ASSISTANT)\n\n\ndef parse_output_message(message: Message):\n    if message.author.role != \"assistant\":\n        # This is a message from a tool to the assistant (e.g., search result).\n        # Don't include it in the final output for now. This aligns with\n        # OpenAI's behavior on models like o4-mini.\n        return []\n\n    output_items = []\n    recipient = message.recipient\n    if recipient is not None and recipient.startswith(\"browser.\"):\n        if len(message.content) != 1:\n            raise ValueError(\"Invalid number of contents in browser message\")\n        content = message.content[0]\n        browser_call = orjson.loads(content.text)\n        # TODO: translate to url properly!\n        if recipient == \"browser.search\":\n            action = ActionSearch(\n                query=f\"cursor:{browser_call.get('query', '')}\", type=\"search\"\n            )\n        elif recipient == \"browser.open\":\n            action = ActionOpenPage(\n                url=f\"cursor:{browser_call.get('url', '')}\", type=\"open_page\"\n            )\n        elif recipient == \"browser.find\":\n            action = ActionFind(\n                pattern=browser_call[\"pattern\"],\n                url=f\"cursor:{browser_call.get('url', '')}\",\n                type=\"find\",\n            )\n        else:\n            raise ValueError(f\"Unknown browser action: {recipient}\")\n        web_search_item = ResponseFunctionWebSearch(\n            id=f\"ws_{random_uuid()}\",\n            action=action,\n            status=\"completed\",\n            type=\"web_search_call\",\n        )\n        output_items.append(web_search_item)\n    elif message.channel == \"analysis\":\n   
     for content in message.content:\n            reasoning_item = ResponseReasoningItem(\n                id=f\"rs_{random_uuid()}\",\n                type=\"reasoning\",\n                summary=[],\n                content=[\n                    ResponseReasoningTextContent(\n                        text=content.text, type=\"reasoning_text\"\n                    )\n                ],\n                status=None,\n            )\n            output_items.append(reasoning_item)\n    elif message.channel == \"commentary\":\n        if message.recipient.startswith(\"functions.\"):\n            function_name = message.recipient.split(\".\")[-1]\n            for content in message.content:\n                random_id = random_uuid()\n                response_item = ResponseFunctionToolCall(\n                    arguments=content.text,\n                    call_id=f\"call_{random_id}\",\n                    type=\"function_call\",\n                    name=function_name,\n                    id=f\"ft_{random_id}\",\n                )\n                output_items.append(response_item)\n        elif message.recipient.startswith(\"python\") or message.recipient.startswith(\n            \"browser\"\n        ):\n            for content in message.content:\n                reasoning_item = ResponseReasoningItem(\n                    id=f\"rs_{random_uuid()}\",\n                    type=\"reasoning\",\n                    summary=[],\n                    content=[\n                        ResponseReasoningTextContent(\n                            text=content.text, type=\"reasoning_text\"\n                        )\n                    ],\n                    status=None,\n                )\n                output_items.append(reasoning_item)\n        else:\n            raise ValueError(f\"Unknown recipient: {message.recipient}\")\n    elif message.channel == \"final\":\n        contents = []\n        for content in message.content:\n            output_text = 
ResponseOutputText(\n                text=content.text,\n                annotations=[],  # TODO\n                type=\"output_text\",\n                logprobs=None,  # TODO\n            )\n            contents.append(output_text)\n        text_item = ResponseOutputMessage(\n            id=f\"msg_{random_uuid()}\",\n            content=contents,\n            role=message.author.role,\n            status=\"completed\",\n            type=\"message\",\n        )\n        output_items.append(text_item)\n    else:\n        raise ValueError(f\"Unknown channel: {message.channel}\")\n    return output_items\n\n\ndef parse_remaining_state(parser: StreamableParser):\n    if not parser.current_content:\n        return []\n    if parser.current_role != Role.ASSISTANT:\n        return []\n    current_recipient = parser.current_recipient\n    if current_recipient is not None and current_recipient.startswith(\"browser.\"):\n        return []\n\n    if parser.current_channel == \"analysis\":\n        reasoning_item = ResponseReasoningItem(\n            id=f\"rs_{random_uuid()}\",\n            type=\"reasoning\",\n            summary=[],\n            content=[\n                ResponseReasoningTextContent(\n                    text=parser.current_content, type=\"reasoning_text\"\n                )\n            ],\n            status=None,\n        )\n        return [reasoning_item]\n    elif parser.current_channel == \"final\":\n        output_text = ResponseOutputText(\n            text=parser.current_content,\n            annotations=[],  # TODO\n            type=\"output_text\",\n            logprobs=None,  # TODO\n        )\n        text_item = ResponseOutputMessage(\n            id=f\"msg_{random_uuid()}\",\n            content=[output_text],\n            role=\"assistant\",\n            status=\"completed\",\n            type=\"message\",\n        )\n        return [text_item]\n    return []\n\n\ndef parse_output_into_messages(token_ids: Iterable[int]):\n    parser = 
get_streamable_parser_for_assistant()\n    for token_id in token_ids:\n        parser.process(token_id)\n    return parser\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/http_server.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"\nThe entry point of inference server. (SRT = SGLang Runtime)\n\nThis file implements HTTP APIs for the inference engine via fastapi.\n\"\"\"\n\nimport asyncio\nimport dataclasses\nimport logging\nimport os\nimport tempfile\nimport threading\nimport time\nfrom contextlib import asynccontextmanager\nfrom http import HTTPStatus\nfrom typing import (\n    Any,\n    AsyncGenerator,\n    AsyncIterator,\n    Callable,\n    Dict,\n    List,\n    Optional,\n    Union,\n)\n\n# Fix a bug of Python threading\nsetattr(threading, \"_register_atexit\", lambda *args, **kwargs: None)\n\n\nimport numpy as np\nimport requests\nimport uvicorn\nimport uvloop\nfrom fastapi import (\n    Depends,\n    FastAPI,\n    File,\n    Form,\n    HTTPException,\n    Query,\n    Request,\n    UploadFile,\n)\nfrom fastapi.exceptions import RequestValidationError\nfrom fastapi.middleware.cors import CORSMiddleware\nfrom fastapi.responses import ORJSONResponse, Response, StreamingResponse\n\nfrom sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode\nfrom sglang.srt.entrypoints.anthropic.protocol import (\n    AnthropicCountTokensRequest,\n    AnthropicMessagesRequest,\n)\nfrom sglang.srt.entrypoints.anthropic.serving import AnthropicServing\nfrom sglang.srt.entrypoints.engine import (\n    
Engine,\n    init_tokenizer_manager,\n    run_detokenizer_process,\n    run_scheduler_process,\n)\nfrom sglang.srt.entrypoints.ollama.protocol import (\n    OllamaChatRequest,\n    OllamaGenerateRequest,\n    OllamaShowRequest,\n)\nfrom sglang.srt.entrypoints.ollama.serving import OllamaServing\nfrom sglang.srt.entrypoints.openai.protocol import (\n    ChatCompletionRequest,\n    ClassifyRequest,\n    CompletionRequest,\n    DetokenizeRequest,\n    EmbeddingRequest,\n    ErrorResponse,\n    ModelCard,\n    ModelList,\n    ResponsesRequest,\n    ScoringRequest,\n    TokenizeRequest,\n    V1RerankReqInput,\n)\nfrom sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat\nfrom sglang.srt.entrypoints.openai.serving_classify import OpenAIServingClassify\nfrom sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion\nfrom sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding\nfrom sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank\nfrom sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore\nfrom sglang.srt.entrypoints.openai.serving_tokenize import (\n    OpenAIServingDetokenize,\n    OpenAIServingTokenize,\n)\nfrom sglang.srt.entrypoints.openai.serving_transcription import (\n    OpenAIServingTranscription,\n)\nfrom sglang.srt.entrypoints.warmup import execute_warmups\nfrom sglang.srt.environ import envs\nfrom sglang.srt.function_call.function_call_parser import FunctionCallParser\nfrom sglang.srt.managers.io_struct import (\n    AbortReq,\n    AttachHiCacheStorageReqInput,\n    CheckWeightsReqInput,\n    CloseSessionReqInput,\n    ConfigureLoggingReq,\n    ContinueGenerationReqInput,\n    DestroyWeightsUpdateGroupReqInput,\n    DumperControlReqInput,\n    EmbeddingReqInput,\n    GenerateReqInput,\n    GetWeightsByNameReqInput,\n    InitWeightsSendGroupForRemoteInstanceReqInput,\n    InitWeightsUpdateGroupReqInput,\n    LoadLoRAAdapterFromTensorsReqInput,\n    
LoadLoRAAdapterReqInput,\n    OpenSessionReqInput,\n    ParseFunctionCallReq,\n    PauseGenerationReqInput,\n    PinPrefixReqInput,\n    ProfileReqInput,\n    ReleaseMemoryOccupationReqInput,\n    ResumeMemoryOccupationReqInput,\n    SendWeightsToRemoteInstanceReqInput,\n    SeparateReasoningReqInput,\n    SetInternalStateReq,\n    SlowDownReqInput,\n    UnloadLoRAAdapterReqInput,\n    UpdateWeightFromDiskReqInput,\n    UpdateWeightsFromDistributedReqInput,\n    UpdateWeightsFromIPCReqInput,\n    UpdateWeightsFromTensorReqInput,\n    UpdateWeightVersionReqInput,\n    VertexGenerateReqInput,\n)\nfrom sglang.srt.managers.multi_tokenizer_mixin import (\n    MultiTokenizerRouter,\n    TokenizerWorker,\n    get_main_process_id,\n    monkey_patch_uvicorn_multiprocessing,\n    read_from_shared_memory,\n    write_data_for_multi_tokenizer,\n)\nfrom sglang.srt.managers.template_manager import TemplateManager\nfrom sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager\nfrom sglang.srt.model_loader.remote_instance_weight_loader_utils import (\n    parse_remote_instance_transfer_engine_info_from_scheduler_infos,\n)\nfrom sglang.srt.observability.func_timer import enable_func_timer\nfrom sglang.srt.observability.trace import (\n    process_tracing_init,\n    set_global_trace_level,\n    trace_set_thread_info,\n)\nfrom sglang.srt.parser.reasoning_parser import ReasoningParser\nfrom sglang.srt.server_args import PortArgs, ServerArgs\nfrom sglang.srt.utils import (\n    add_prometheus_middleware,\n    add_prometheus_track_response_middleware,\n    delete_directory,\n    get_bool_env_var,\n    kill_process_tree,\n    set_uvicorn_logging_configs,\n)\nfrom sglang.srt.utils.auth import AuthLevel, app_has_admin_force_endpoints, auth_level\nfrom sglang.srt.utils.json_response import (\n    SGLangORJSONResponse,\n    dumps_json,\n    orjson_response,\n)\nfrom sglang.utils import get_exception_traceback\nfrom sglang.version import __version__\n\nlogger = 
logging.getLogger(__name__)\nasyncio.set_event_loop_policy(uvloop.EventLoopPolicy())\n\n# Global constants\nHEALTH_CHECK_TIMEOUT = int(os.getenv(\"SGLANG_HEALTH_CHECK_TIMEOUT\", 20))\nWAIT_WEIGHTS_READY_TIMEOUT = int(os.getenv(\"SGLANG_WAIT_WEIGHTS_READY_TIMEOUT\", 120))\n\n\n# Store global states\n@dataclasses.dataclass\nclass _GlobalState:\n    tokenizer_manager: Union[TokenizerManager, MultiTokenizerRouter, TokenizerWorker]\n    template_manager: TemplateManager\n    scheduler_info: Dict\n    # Dict{\n    #   rank: Tuple(\n    #           session_id,\n    #           Dict{\n    #               name: Tuple (d_ptr, numel, element_size)\n    #           }\n    #         )\n    # }\n    remote_instance_transfer_engine_info: Optional[Dict] = None\n\n\n_global_state: Optional[_GlobalState] = None\n\n\ndef set_global_state(global_state: _GlobalState):\n    global _global_state\n    _global_state = global_state\n\n\ndef get_global_state() -> _GlobalState:\n    return _global_state\n\n\nasync def init_multi_tokenizer() -> ServerArgs:\n    \"\"\"\n    Initialization function for multi-process tokenizer mode.\n    It reads the args from shared memory and initializes the tokenizer manager for the current process.\n    \"\"\"\n\n    # Read configuration from shared memory\n    main_pid = get_main_process_id()\n    port_args, server_args, scheduler_info = read_from_shared_memory(\n        f\"multi_tokenizer_args_{main_pid}\"\n    )\n    server_args: ServerArgs\n    port_args: PortArgs\n\n    # API key authentication is not supported in multi-tokenizer mode\n    assert (\n        server_args.api_key is None\n    ), \"API key is not supported in multi-tokenizer mode\"\n\n    # Create a new ipc name for the current process\n    port_args.tokenizer_ipc_name = (\n        f\"ipc://{tempfile.NamedTemporaryFile(delete=False).name}\"\n    )\n    logger.info(\n        f\"Start multi-tokenizer worker process {os.getpid()}, \"\n        
f\"ipc_name={port_args.tokenizer_ipc_name}\"\n    )\n\n    # Launch 
multi-tokenizer manager process\n    tokenizer_manager = TokenizerWorker(server_args, port_args)\n    template_manager = TemplateManager()\n    template_manager.initialize_templates(\n        tokenizer_manager=tokenizer_manager,\n        model_path=server_args.model_path,\n        chat_template=server_args.chat_template,\n        completion_template=server_args.completion_template,\n    )\n\n    tokenizer_manager.max_req_input_len = scheduler_info[\"max_req_input_len\"]\n\n    set_global_state(\n        _GlobalState(\n            tokenizer_manager=tokenizer_manager,\n            template_manager=template_manager,\n            scheduler_info=scheduler_info,\n        )\n    )\n\n    return server_args\n\n\n@asynccontextmanager\nasync def lifespan(fast_api_app: FastAPI):\n    if getattr(fast_api_app, \"is_single_tokenizer_mode\", False):\n        server_args = fast_api_app.server_args\n        warmup_thread_kwargs = fast_api_app.warmup_thread_kwargs\n        thread_label = \"Tokenizer\"\n    else:\n        # Initialize multi-tokenizer support for worker processes\n        server_args = await init_multi_tokenizer()\n        warmup_thread_kwargs = dict(server_args=server_args)\n        thread_label = f\"MultiTokenizer-{_global_state.tokenizer_manager.worker_id}\"\n\n    # Add prometheus middleware\n    if server_args.enable_metrics:\n        add_prometheus_middleware(app)\n        enable_func_timer()\n\n    # Init tracing\n    if server_args.enable_trace:\n        process_tracing_init(server_args.otlp_traces_endpoint, \"sglang\")\n        if server_args.disaggregation_mode == \"prefill\":\n            thread_label = \"Prefill\" + thread_label\n        elif server_args.disaggregation_mode == \"decode\":\n            thread_label = \"Decode\" + thread_label\n        trace_set_thread_info(thread_label)\n\n    # Initialize OpenAI serving handlers\n    fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(\n        _global_state.tokenizer_manager, 
_global_state.template_manager\n    )\n    fast_api_app.state.openai_serving_chat = OpenAIServingChat(\n        _global_state.tokenizer_manager, _global_state.template_manager\n    )\n    fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(\n        _global_state.tokenizer_manager, _global_state.template_manager\n    )\n    fast_api_app.state.openai_serving_classify = OpenAIServingClassify(\n        _global_state.tokenizer_manager, _global_state.template_manager\n    )\n    fast_api_app.state.openai_serving_score = OpenAIServingScore(\n        _global_state.tokenizer_manager\n    )\n    fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(\n        _global_state.tokenizer_manager, _global_state.template_manager\n    )\n    fast_api_app.state.openai_serving_tokenize = OpenAIServingTokenize(\n        _global_state.tokenizer_manager\n    )\n    fast_api_app.state.openai_serving_detokenize = OpenAIServingDetokenize(\n        _global_state.tokenizer_manager\n    )\n    fast_api_app.state.openai_serving_transcription = OpenAIServingTranscription(\n        _global_state.tokenizer_manager\n    )\n\n    # Initialize Ollama-compatible serving handler\n    fast_api_app.state.ollama_serving = OllamaServing(_global_state.tokenizer_manager)\n\n    # Initialize Anthropic-compatible serving handler\n    fast_api_app.state.anthropic_serving = AnthropicServing(\n        fast_api_app.state.openai_serving_chat\n    )\n\n    # Launch tool server\n    tool_server = None\n    if server_args.tool_server == \"demo\":\n        from sglang.srt.entrypoints.openai.tool_server import DemoToolServer\n\n        tool_server = DemoToolServer()\n    elif server_args.tool_server:\n        from sglang.srt.entrypoints.openai.tool_server import MCPToolServer\n\n        tool_server = MCPToolServer()\n        await tool_server.add_tool_server(server_args.tool_server)\n\n    try:\n        from sglang.srt.entrypoints.openai.serving_responses import (\n            
OpenAIServingResponses,\n        )\n\n        fast_api_app.state.openai_serving_responses = OpenAIServingResponses(\n            _global_state.tokenizer_manager,\n            _global_state.template_manager,\n            enable_prompt_tokens_details=True,\n            enable_force_include_usage=True,\n            tool_server=tool_server,\n        )\n    except Exception:\n        traceback = get_exception_traceback()\n        logger.warning(f\"Cannot initialize OpenAIServingResponses, error: {traceback}\")\n\n    # Execute custom warmups\n    if server_args.warmups is not None:\n        await execute_warmups(\n            server_args.disaggregation_mode,\n            server_args.warmups.split(\",\"),\n            _global_state.tokenizer_manager,\n        )\n        logger.info(\"Warmup ended\")\n\n    # Execute the general warmup\n    warmup_thread = threading.Thread(\n        target=_wait_and_warmup,\n        kwargs=warmup_thread_kwargs,\n    )\n    warmup_thread.start()\n\n    # Start the HTTP server\n    try:\n        yield\n    finally:\n        warmup_thread.join()\n\n\n# Fast API\napp = FastAPI(\n    lifespan=lifespan,\n    openapi_url=None if get_bool_env_var(\"DISABLE_OPENAPI_DOC\") else \"/openapi.json\",\n)\napp.add_middleware(\n    CORSMiddleware,\n    allow_origins=[\"*\"],\n    allow_credentials=True,\n    allow_methods=[\"*\"],\n    allow_headers=[\"*\"],\n)\n\n# Include routers\nfrom sglang.srt.entrypoints.v1_loads import router as v1_loads_router\n\napp.include_router(v1_loads_router)\n\n\n@app.exception_handler(HTTPException)\nasync def http_exception_handler(request: Request, exc: HTTPException):\n    \"\"\"Enrich HTTP exception with status code and other details.\n\n    For /v1/responses, emit OpenAI-style nested error envelope:\n    {\"error\": {\"message\": \"...\", \"type\": \"...\", \"param\": null, \"code\": <status>}}\n    \"\"\"\n    # adjust fmt for responses api\n    if 
request.url.path.startswith(\"/v1/responses\"):\n        
nested_error = {\n            \"message\": exc.detail,\n            \"type\": HTTPStatus(exc.status_code).phrase,\n            \"param\": None,\n            \"code\": exc.status_code,\n        }\n        return ORJSONResponse(\n            content={\"error\": nested_error}, status_code=exc.status_code\n        )\n\n    error = ErrorResponse(\n        object=\"error\",\n        message=exc.detail,\n        type=str(exc.status_code),\n        code=exc.status_code,\n    )\n    return ORJSONResponse(content=error.model_dump(), status_code=exc.status_code)\n\n\n# Custom exception handlers to change validation error status codes\n@app.exception_handler(RequestValidationError)\nasync def validation_exception_handler(request: Request, exc: RequestValidationError):\n    \"\"\"Override FastAPI's default 422 validation error with 400.\n\n    For /v1/responses, emit OpenAI-style nested error envelope; for other endpoints keep legacy format.\n    \"\"\"\n    exc_str = str(exc)\n    errors_str = str(exc.errors())\n\n    if errors_str and errors_str != exc_str:\n        message = f\"{exc_str} {errors_str}\"\n    else:\n        message = exc_str\n\n    if request.url.path.startswith(\"/v1/responses\"):\n        # adapt specially, for v1/responses API only (notice the error key is different)\n        nested_error = {\n            \"message\": message,\n            \"type\": HTTPStatus.BAD_REQUEST.phrase,\n            \"param\": None,\n            \"code\": HTTPStatus.BAD_REQUEST.value,\n        }\n        return ORJSONResponse(status_code=400, content={\"error\": nested_error})\n\n    err = ErrorResponse(\n        message=message,\n        type=HTTPStatus.BAD_REQUEST.phrase,\n        code=HTTPStatus.BAD_REQUEST.value,\n    )\n\n    return ORJSONResponse(\n        status_code=400,\n        content=err.model_dump(),\n    )\n\n\nasync def validate_json_request(raw_request: Request):\n    \"\"\"Validate that the request content-type is application/json.\"\"\"\n    content_type = 
raw_request.headers.get(\"content-type\", \"\").lower()\n    media_type = content_type.split(\";\", maxsplit=1)[0]\n    if media_type != \"application/json\":\n        raise RequestValidationError(\n            errors=[\n                {\n                    \"loc\": [\"header\", \"content-type\"],\n                    \"msg\": \"Unsupported Media Type: Only 'application/json' is allowed\",\n                    \"type\": \"value_error\",\n                }\n            ]\n        )\n\n\n##### Native API endpoints #####\n\n\n@app.get(\"/health\")\n@app.get(\"/health_generate\")\nasync def health_generate(request: Request) -> Response:\n    \"\"\"\n    Check the health of the inference server by sending a special request to generate one token.\n\n    If the server is running something, this request will be ignored, so it creates zero overhead.\n    If the server is not running anything, this request will be run, so we know whether the server is healthy.\n    \"\"\"\n\n    if _global_state.tokenizer_manager.gracefully_exit:\n        logger.info(\"Health check request received during shutdown. 
Returning 503.\")\n        return Response(status_code=503)\n\n    if _global_state.tokenizer_manager.server_status == ServerStatus.Starting:\n        return Response(status_code=503)\n\n    if (\n        not envs.SGLANG_ENABLE_HEALTH_ENDPOINT_GENERATION.get()\n        and request.url.path == \"/health\"\n    ):\n        return Response(status_code=200)\n\n    sampling_params = {\"max_new_tokens\": 1, \"temperature\": 0.0}\n    rid = f\"HEALTH_CHECK_{time.time()}\"\n\n    if _global_state.tokenizer_manager.is_image_gen:\n        gri = _global_state.tokenizer_manager.get_image_gen_health_check_request(\n            rid, sampling_params\n        )\n    elif _global_state.tokenizer_manager.is_generation:\n        gri = GenerateReqInput(\n            rid=rid,\n            input_ids=[0],\n            sampling_params=sampling_params,\n            log_metrics=False,\n        )\n        if (\n            _global_state.tokenizer_manager.server_args.disaggregation_mode\n            != DisaggregationMode.NULL.value\n        ):\n            gri.bootstrap_host = FAKE_BOOTSTRAP_HOST\n            gri.bootstrap_room = 0\n    else:\n        gri = EmbeddingReqInput(\n            rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False\n        )\n\n    async def gen():\n        async for _ in _global_state.tokenizer_manager.generate_request(gri, request):\n            break\n\n    task = asyncio.create_task(gen())\n\n    # As long as we receive any response from the detokenizer/scheduler, we consider the server healthy.\n    tic = time.time()\n    while time.time() < tic + HEALTH_CHECK_TIMEOUT:\n        await asyncio.sleep(1)\n        if _global_state.tokenizer_manager.last_receive_tstamp > tic:\n            task.cancel()\n            _global_state.tokenizer_manager.rid_to_state.pop(rid, None)\n            _global_state.tokenizer_manager.server_status = ServerStatus.Up\n            return Response(status_code=200)\n\n    task.cancel()\n    tic_time = 
time.strftime(\"%H:%M:%S\", time.localtime(tic))\n    last_receive_time = time.strftime(\n        \"%H:%M:%S\", time.localtime(_global_state.tokenizer_manager.last_receive_tstamp)\n    )\n    logger.error(\n        f\"Health check failed. Server couldn't get a response from detokenizer for last \"\n        f\"{HEALTH_CHECK_TIMEOUT} seconds. tic start time: {tic_time}. \"\n        f\"last_heartbeat time: {last_receive_time}\"\n    )\n    _global_state.tokenizer_manager.rid_to_state.pop(rid, None)\n    _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy\n    return Response(status_code=503)\n\n\n@app.get(\"/get_model_info\")\nasync def get_model_info():\n    \"\"\"Get the model information (deprecated - use /model_info instead).\"\"\"\n    logger.warning(\n        \"Endpoint '/get_model_info' is deprecated and will be removed in a future version. \"\n        \"Please use '/model_info' instead.\"\n    )\n    return await model_info()\n\n\n@app.get(\"/model_info\")\nasync def model_info():\n    \"\"\"Get the model information.\"\"\"\n    model_config = _global_state.tokenizer_manager.model_config\n    result = {\n        \"model_path\": _global_state.tokenizer_manager.model_path,\n        \"tokenizer_path\": _global_state.tokenizer_manager.server_args.tokenizer_path,\n        \"is_generation\": _global_state.tokenizer_manager.is_generation,\n        \"preferred_sampling_params\": _global_state.tokenizer_manager.server_args.preferred_sampling_params,\n        \"weight_version\": _global_state.tokenizer_manager.server_args.weight_version,\n        \"has_image_understanding\": model_config.is_image_understandable_model,\n        \"has_audio_understanding\": model_config.is_audio_understandable_model,\n        \"model_type\": getattr(model_config.hf_config, \"model_type\", None),\n        \"architectures\": getattr(model_config.hf_config, \"architectures\", None),\n  
      # \"hf_config\": model_config.hf_config.to_dict(),\n    }\n    return result\n\n\n@app.get(\"/get_weight_version\")\n@app.get(\"/weight_version\")\nasync def weight_version():\n    \"\"\"Get the current weight version.\"\"\"\n    raise HTTPException(\n        status_code=404,\n        detail=\"Endpoint '/get_weight_version' or '/weight_version' is deprecated. Please use '/model_info' instead.\",\n    )\n\n\n@app.get(\"/get_server_info\")\nasync def get_server_info():\n    \"\"\"Get the server information (deprecated - use /server_info instead).\"\"\"\n    logger.warning(\n        \"Endpoint '/get_server_info' is deprecated and will be removed in a future version. \"\n        \"Please use '/server_info' instead.\"\n    )\n    return await server_info()\n\n\n@app.get(\"/server_info\")\nasync def server_info():\n    \"\"\"Get the server information.\"\"\"\n    # Returns internal states per DP.\n    internal_states: List[Dict[Any, Any]] = (\n        await _global_state.tokenizer_manager.get_internal_state()\n    )\n\n    # This field is not serializable.\n    if hasattr(_global_state.tokenizer_manager.server_args, \"model_config\"):\n        del _global_state.tokenizer_manager.server_args.model_config\n\n    return {\n        **dataclasses.asdict(_global_state.tokenizer_manager.server_args),\n        **_global_state.scheduler_info,\n        \"internal_states\": internal_states,\n        \"version\": __version__,\n    }\n\n\n@app.get(\"/get_load\")\nasync def get_load():\n    \"\"\"Get load metrics (deprecated - use /v1/loads instead).\"\"\"\n    logger.warning(\n        \"Endpoint '/get_load' is deprecated and will be removed in a future version. 
\"\n        \"Please use '/v1/loads' instead.\"\n    )\n    return await _global_state.tokenizer_manager.get_load()\n\n\n# example usage:\n# curl -s -X POST http://localhost:30000/set_internal_state -H \"Content-Type: application/json\" -d '{\"server_args\": {\"pp_max_micro_batch_size\": 8}}'\n@app.api_route(\"/set_internal_state\", methods=[\"POST\", \"PUT\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def set_internal_state(obj: SetInternalStateReq, request: Request):\n    res = await _global_state.tokenizer_manager.set_internal_state(obj)\n    return res\n\n\n# Do not import `dumper.py` to avoid dependency\nif os.environ.get(\"DUMPER_SERVER_PORT\") == \"reuse\":\n\n    @app.api_route(\"/dumper/{method}\", methods=[\"POST\"])\n    @auth_level(AuthLevel.ADMIN_OPTIONAL)\n    async def _dumper_control_handler(method: str, request: Request):\n        body_bytes = await request.body()\n        body = await request.json() if body_bytes else {}\n        obj = DumperControlReqInput(method=method, body=body)\n        results = await _global_state.tokenizer_manager.dumper_control(obj)\n        if any(not r.success for r in results):\n            errors = [r.error for r in results if not r.success]\n            return ORJSONResponse(status_code=400, content={\"error\": errors})\n        return [x for result in results for x in result.response]\n\n\n# fastapi implicitly converts json in the request to obj (dataclass)\n@app.api_route(\n    \"/generate\",\n    methods=[\"POST\", \"PUT\"],\n    response_class=SGLangORJSONResponse,\n)\nasync def generate_request(obj: GenerateReqInput, request: Request):\n    \"\"\"Handle a generate request.\"\"\"\n    if obj.stream:\n\n        async def stream_results() -> AsyncIterator[bytes]:\n            try:\n                async for out in _global_state.tokenizer_manager.generate_request(\n                    obj, request\n                ):\n                    yield b\"data: \" + dumps_json(out) + b\"\\n\\n\"\n            except 
ValueError as e:\n                out = {\"error\": {\"message\": str(e)}}\n                logger.error(f\"[http_server] Error: {e}\")\n                yield b\"data: \" + dumps_json(out) + b\"\\n\\n\"\n            yield b\"data: [DONE]\\n\\n\"\n\n        return StreamingResponse(\n            stream_results(),\n            media_type=\"text/event-stream\",\n            background=_global_state.tokenizer_manager.create_abort_task(obj),\n        )\n    else:\n        try:\n            ret = await _global_state.tokenizer_manager.generate_request(\n                obj, request\n            ).__anext__()\n            return orjson_response(ret)\n        except ValueError as e:\n            logger.error(f\"[http_server] Error: {e}\")\n            return _create_error_response(e)\n\n\n@app.api_route(\"/encode\", methods=[\"POST\", \"PUT\"])\nasync def encode_request(obj: EmbeddingReqInput, request: Request):\n    \"\"\"Handle an embedding request.\"\"\"\n    try:\n        ret = await _global_state.tokenizer_manager.generate_request(\n            obj, request\n        ).__anext__()\n        return ret\n    except ValueError as e:\n        return _create_error_response(e)\n\n\n@app.api_route(\"/classify\", methods=[\"POST\", \"PUT\"])\nasync def classify_request(obj: EmbeddingReqInput, request: Request):\n    \"\"\"Handle a reward model request. Now the arguments and return values are the same as embedding models.\"\"\"\n    try:\n        ret = await _global_state.tokenizer_manager.generate_request(\n            obj, request\n        ).__anext__()\n        return ret\n    except ValueError as e:\n        return _create_error_response(e)\n\n\n@app.api_route(\"/flush_cache\", methods=[\"GET\", \"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def flush_cache():\n    \"\"\"Flush the radix cache.\"\"\"\n    ret = await _global_state.tokenizer_manager.flush_cache()\n    return Response(\n        content=\"Cache flushed.\\nPlease check backend logs for more details. 
\"\n        \"(When there are running or waiting requests, the operation will not be performed.)\\n\",\n        status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,\n    )\n\n\n@app.api_route(\"/clear_hicache_storage_backend\", methods=[\"GET\", \"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def clear_hicache_storage_backend_deprecated():\n    \"\"\"Deprecated: use POST /hicache/storage-backend/clear.\"\"\"\n    ret = await _global_state.tokenizer_manager.clear_hicache_storage()\n    return Response(\n        content=(\n            \"Deprecated endpoint. Use POST /hicache/storage-backend/clear.\\n\"\n            \"Hierarchical cache storage backend cleared.\\n\"\n        ),\n        status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,\n    )\n\n\n# example usage:\n# curl -s -X POST http://127.0.0.1:30000/clear_hicache_storage_backend\n@app.api_route(\"/hicache/storage-backend/clear\", methods=[\"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def clear_hicache_storage_backend():\n    \"\"\"Clear the hierarchical cache storage backend.\"\"\"\n    ret = await _global_state.tokenizer_manager.clear_hicache_storage()\n    return Response(\n        content=\"Hierarchical cache storage backend cleared.\\n\",\n        status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,\n    )\n\n\n# example usage:\n# curl -s -X PUT http://127.0.0.1:30000/hicache/storage-backend \\\n#  -H 'Content-Type: application/json' \\\n#   -d '{\n#     \"hicache_storage_backend\": \"file\",\n#     \"hicache_storage_backend_extra_config_json\": \"{}\",\n#     \"hicache_storage_prefetch_policy\": \"timeout\",\n#     \"hicache_write_policy\": \"write_through\"\n#   }'\n@app.api_route(\"/hicache/storage-backend\", methods=[\"PUT\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def attach_hicache_storage_backend(obj: AttachHiCacheStorageReqInput):\n    \"\"\"Attach (enable) HiCache storage backend at runtime.\n\n    Only allowed when there are NO running / queued 
requests.\n    \"\"\"\n    if not _global_state.tokenizer_manager.server_args.admin_api_key:\n        return _admin_api_key_missing_response()\n\n    ret = await _global_state.tokenizer_manager.attach_hicache_storage(\n        hicache_storage_backend=obj.hicache_storage_backend,\n        hicache_storage_backend_extra_config_json=obj.hicache_storage_backend_extra_config_json,\n        hicache_storage_prefetch_policy=obj.hicache_storage_prefetch_policy,\n        hicache_write_policy=obj.hicache_write_policy,\n    )\n    msg = getattr(ret, \"message\", \"\")\n    return Response(\n        content=(\n            (\n                \"HiCache storage backend attached.\\n\"\n                if ret.success\n                else \"Failed to attach HiCache storage backend.\\n\"\n            )\n            + (msg + \"\\n\" if msg else \"\")\n        ),\n        status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,\n    )\n\n\n# example usage:\n# curl -s -X DELETE http://127.0.0.1:30000/hicache/storage-backend\n@app.api_route(\"/hicache/storage-backend\", methods=[\"DELETE\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def detach_hicache_storage_backend():\n    \"\"\"Detach (disable) HiCache storage backend at runtime.\n\n    Only allowed when there are NO running / queued requests.\n    \"\"\"\n    if not _global_state.tokenizer_manager.server_args.admin_api_key:\n        return _admin_api_key_missing_response()\n\n    ret = await _global_state.tokenizer_manager.detach_hicache_storage()\n    msg = getattr(ret, \"message\", \"\")\n    return Response(\n        content=(\n            (\n                \"HiCache storage backend detached.\\n\"\n                if ret.success\n                else \"Failed to detach HiCache storage backend.\\n\"\n            )\n            + (msg + \"\\n\" if msg else \"\")\n        ),\n        status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,\n    )\n\n\n# example usage:\n# curl -s 
http://127.0.0.1:30000/hicache/storage-backend\n@app.get(\"/hicache/storage-backend\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def hicache_storage_backend_status():\n    \"\"\"Get current HiCache storage backend status (tokenizer-side view).\"\"\"\n    if not _global_state.tokenizer_manager.server_args.admin_api_key:\n        return _admin_api_key_missing_response()\n\n    return {\n        \"hicache_storage_backend\": _global_state.tokenizer_manager.server_args.hicache_storage_backend,\n        \"hicache_storage_backend_extra_config\": _global_state.tokenizer_manager.server_args.hicache_storage_backend_extra_config,\n        \"hicache_storage_prefetch_policy\": _global_state.tokenizer_manager.server_args.hicache_storage_prefetch_policy,\n        \"hicache_write_policy\": _global_state.tokenizer_manager.server_args.hicache_write_policy,\n    }\n\n\n@app.api_route(\"/hicache/pin_prefix\", methods=[\"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def pin_prefix(obj: PinPrefixReqInput):\n    \"\"\"Pin a prefix by token_ids to resist eviction.\"\"\"\n    if not _global_state.tokenizer_manager.server_args.admin_api_key:\n        return _admin_api_key_missing_response()\n    ret = await _global_state.tokenizer_manager.pin_prefix(\n        obj.token_ids, obj.ttl_seconds\n    )\n    return ORJSONResponse(\n        content={\n            \"status\": \"ok\" if ret.success else \"error\",\n            \"nodes_pinned\": ret.nodes_pinned,\n            \"message\": ret.message,\n        },\n        status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,\n    )\n\n\n@app.api_route(\"/start_profile\", methods=[\"GET\", \"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def start_profile_async(obj: Optional[ProfileReqInput] = None):\n    \"\"\"Start profiling.\"\"\"\n    if obj is None:\n        obj = ProfileReqInput()\n\n    await _global_state.tokenizer_manager.start_profile(\n        output_dir=obj.output_dir,\n        start_step=obj.start_step,\n        
num_steps=obj.num_steps,\n        activities=obj.activities,\n        with_stack=obj.with_stack,\n        record_shapes=obj.record_shapes,\n        profile_by_stage=obj.profile_by_stage,\n        merge_profiles=obj.merge_profiles,\n        profile_prefix=obj.profile_prefix,\n        profile_stages=obj.profile_stages,\n    )\n    return Response(\n        content=\"Start profiling.\\n\",\n        status_code=200,\n    )\n\n\n@app.api_route(\"/stop_profile\", methods=[\"GET\", \"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def stop_profile_async():\n    \"\"\"Stop profiling.\"\"\"\n    await _global_state.tokenizer_manager.stop_profile()\n    return Response(\n        content=\"Stop profiling. This will take some time.\\n\",\n        status_code=200,\n    )\n\n\n@app.api_route(\"/set_trace_level\", methods=[\"GET\", \"POST\"])\ndef set_trace_level(level: int = Query(..., ge=0)):\n    set_global_trace_level(level)\n\n    return Response(\n        content=\"success\",\n        status_code=200,\n    )\n\n\n@app.api_route(\"/freeze_gc\", methods=[\"GET\", \"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def freeze_gc_async():\n    \"\"\"\n    See engine.freeze_gc for more details.\n    \"\"\"\n    await _global_state.tokenizer_manager.freeze_gc()\n    return Response(\n        content=\"Garbage collection frozen.\\n\",\n        status_code=200,\n    )\n\n\n@app.api_route(\"/start_expert_distribution_record\", methods=[\"GET\", \"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def start_expert_distribution_record_async():\n    \"\"\"Start recording the expert distribution. 
Clear the previous record if any.\"\"\"\n    await _global_state.tokenizer_manager.start_expert_distribution_record()\n    return Response(\n        content=\"Start recording the expert distribution.\\n\",\n        status_code=200,\n    )\n\n\n@app.api_route(\"/stop_expert_distribution_record\", methods=[\"GET\", \"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def stop_expert_distribution_record_async():\n    \"\"\"Stop recording the expert distribution.\"\"\"\n    await _global_state.tokenizer_manager.stop_expert_distribution_record()\n    return Response(\n        content=\"Stop recording the expert distribution.\\n\",\n        status_code=200,\n    )\n\n\n@app.api_route(\"/dump_expert_distribution_record\", methods=[\"GET\", \"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def dump_expert_distribution_record_async():\n    \"\"\"Dump expert distribution record.\"\"\"\n    await _global_state.tokenizer_manager.dump_expert_distribution_record()\n    return Response(\n        content=\"Dump expert distribution record.\\n\",\n        status_code=200,\n    )\n\n\n@app.post(\"/update_weights_from_disk\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):\n    \"\"\"Update the weights from disk inplace without re-launching the server.\"\"\"\n    success, message, num_paused_requests = (\n        await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)\n    )\n\n    content = {\n        \"success\": success,\n        \"message\": message,\n        \"num_paused_requests\": num_paused_requests,\n    }\n    if success:\n        return ORJSONResponse(\n            content,\n            status_code=HTTPStatus.OK,\n        )\n    else:\n        return ORJSONResponse(\n            content,\n            status_code=HTTPStatus.BAD_REQUEST,\n        )\n\n\n@app.post(\"/init_weights_send_group_for_remote_instance\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def 
init_weights_send_group_for_remote_instance(\n    obj: InitWeightsSendGroupForRemoteInstanceReqInput, request: Request\n):\n    success, message = (\n        await _global_state.tokenizer_manager.init_weights_send_group_for_remote_instance(\n            obj, request\n        )\n    )\n    content = {\"success\": success, \"message\": message}\n    if success:\n        return ORJSONResponse(content, status_code=200)\n    else:\n        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)\n\n\n@app.post(\"/send_weights_to_remote_instance\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def send_weights_to_remote_instance(\n    obj: SendWeightsToRemoteInstanceReqInput, request: Request\n):\n    success, message = (\n        await _global_state.tokenizer_manager.send_weights_to_remote_instance(\n            obj, request\n        )\n    )\n    content = {\"success\": success, \"message\": message}\n    if success:\n        return ORJSONResponse(content, status_code=200)\n    else:\n        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)\n\n\n@app.get(\"/get_remote_instance_transfer_engine_info\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def get_remote_instance_transfer_engine_info(rank: int = None):\n    if rank is None or rank < 0:\n        return Response(status_code=HTTPStatus.BAD_REQUEST)\n\n    if (\n        _global_state.remote_instance_transfer_engine_info is None\n        or len(_global_state.remote_instance_transfer_engine_info) == 0\n    ):\n        return Response(status_code=HTTPStatus.BAD_REQUEST)\n\n    try:\n        result = {\n            \"rank\": rank,\n            \"remote_instance_transfer_engine_info\": _global_state.remote_instance_transfer_engine_info[\n                rank\n            ],\n        }\n        return result\n    except Exception as e:\n        logger.error(f\"Exception: {e}\")\n        return 
Response(status_code=HTTPStatus.BAD_REQUEST)\n\n\n@app.post(\"/init_weights_update_group\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def init_weights_update_group(\n    obj: InitWeightsUpdateGroupReqInput, request: Request\n):\n    \"\"\"Initialize the parameter update group.\"\"\"\n    success, message = await _global_state.tokenizer_manager.init_weights_update_group(\n        obj, request\n    )\n    content = {\"success\": success, \"message\": message}\n    if success:\n        return ORJSONResponse(content, status_code=200)\n    else:\n        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)\n\n\n@app.post(\"/destroy_weights_update_group\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def destroy_weights_update_group(\n    obj: DestroyWeightsUpdateGroupReqInput, request: Request\n):\n    \"\"\"Destroy the parameter update group.\"\"\"\n    success, message = (\n        await _global_state.tokenizer_manager.destroy_weights_update_group(obj, request)\n    )\n    content = {\"success\": success, \"message\": message}\n    return ORJSONResponse(\n        content, status_code=200 if success else HTTPStatus.BAD_REQUEST\n    )\n\n\n@app.post(\"/update_weights_from_tensor\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def update_weights_from_tensor(\n    obj: UpdateWeightsFromTensorReqInput, request: Request\n):\n    \"\"\"Update the weights from tensor inplace without re-launching the server.\n    Notes:\n    1. Ensure that the model is on the correct device (e.g., GPU) before calling this endpoint. If the model is moved to the CPU unexpectedly, it may cause performance issues or runtime errors.\n    2. HTTP will transmit only the metadata of the tensor, while the tensor itself will be directly copied to the model.\n    3. 
Any binary data in the named tensors should be base64 encoded.\n    \"\"\"\n\n    success, message = await _global_state.tokenizer_manager.update_weights_from_tensor(\n        obj, request\n    )\n\n    content = {\"success\": success, \"message\": message}\n    return ORJSONResponse(\n        content, status_code=200 if success else HTTPStatus.BAD_REQUEST\n    )\n\n\n@app.post(\"/update_weights_from_distributed\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def update_weights_from_distributed(\n    obj: UpdateWeightsFromDistributedReqInput, request: Request\n):\n    \"\"\"Update model parameter from distributed online.\"\"\"\n    success, message = (\n        await _global_state.tokenizer_manager.update_weights_from_distributed(\n            obj, request\n        )\n    )\n\n    content = {\"success\": success, \"message\": message}\n    if success:\n        return ORJSONResponse(content, status_code=200)\n    else:\n        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)\n\n\n@app.post(\"/update_weights_from_ipc\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Request):\n    \"\"\"Update the weights from IPC (Inter-Process Communication) for checkpoint-engine integration.\"\"\"\n    success, message = await _global_state.tokenizer_manager.update_weights_from_ipc(\n        obj, request\n    )\n\n    content = {\"success\": success, \"message\": message}\n    if success:\n        if _global_state.tokenizer_manager.initial_weights_loaded is False:\n            _global_state.tokenizer_manager.initial_weights_loaded = True\n        return ORJSONResponse(content)\n    else:\n        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)\n\n\n@app.post(\"/update_weight_version\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):\n    \"\"\"Update the weight version. 
This operation requires no active requests.\"\"\"\n    if obj.abort_all_requests:\n        _global_state.tokenizer_manager.abort_request(abort_all=True)\n\n    # Use a simple approach without the complex lock mechanism for now\n    # since weight_version update is a simple operation that doesn't affect model weights\n    try:\n        # Update the weight version in server args (the single source of truth)\n        _global_state.tokenizer_manager.server_args.weight_version = obj.new_version\n\n        return ORJSONResponse(\n            {\n                \"success\": True,\n                \"message\": f\"Weight version updated to {obj.new_version}\",\n                \"new_version\": obj.new_version,\n            },\n            status_code=HTTPStatus.OK,\n        )\n    except Exception as e:\n        return ORJSONResponse(\n            {\n                \"success\": False,\n                \"message\": f\"Failed to update weight version: {str(e)}\",\n            },\n            status_code=HTTPStatus.BAD_REQUEST,\n        )\n\n\n@app.api_route(\"/get_weights_by_name\", methods=[\"GET\", \"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):\n    \"\"\"Get model parameter by name.\"\"\"\n    try:\n        ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request)\n        if ret is None:\n            return _create_error_response(\"Get parameter by name failed\")\n        else:\n            return ORJSONResponse(ret, status_code=200)\n    except Exception as e:\n        return _create_error_response(e)\n\n\n@app.api_route(\"/release_memory_occupation\", methods=[\"GET\", \"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def release_memory_occupation(\n    obj: ReleaseMemoryOccupationReqInput, request: Request\n):\n    \"\"\"Release GPU memory occupation temporarily.\"\"\"\n    try:\n        await _global_state.tokenizer_manager.release_memory_occupation(obj, 
request)\n    except Exception as e:\n        return _create_error_response(e)\n\n\n@app.api_route(\"/resume_memory_occupation\", methods=[\"GET\", \"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def resume_memory_occupation(\n    obj: ResumeMemoryOccupationReqInput, request: Request\n):\n    \"\"\"Resume GPU memory occupation.\"\"\"\n    try:\n        await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)\n    except Exception as e:\n        return _create_error_response(e)\n\n\n@app.post(\"/weights_checker\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def check_weights(obj: CheckWeightsReqInput, request: Request):\n    success, message = await _global_state.tokenizer_manager.check_weights(obj, request)\n    return ORJSONResponse(\n        {\"success\": success, \"message\": message},\n        status_code=200 if success else HTTPStatus.BAD_REQUEST,\n    )\n\n\n@app.api_route(\"/slow_down\", methods=[\"GET\", \"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def slow_down(obj: SlowDownReqInput, request: Request):\n    \"\"\"Slow down the system deliberately. Only for testing. 
Example scenario:\n    when testing the performance of D in large-scale PD disaggregation without enough nodes for P,\n    we can use this to slow down D so that it accumulates enough running sequences, and then disable the slowdown\n    to let it run at full batch size.\n    \"\"\"\n    try:\n        await _global_state.tokenizer_manager.slow_down(obj, request)\n    except Exception as e:\n        return _create_error_response(e)\n\n\n@app.api_route(\"/load_lora_adapter\", methods=[\"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request):\n    \"\"\"Load a new LoRA adapter without re-launching the server.\"\"\"\n    result = await _global_state.tokenizer_manager.load_lora_adapter(obj, request)\n\n    if result.success:\n        return ORJSONResponse(\n            result,\n            status_code=HTTPStatus.OK,\n        )\n    else:\n        return ORJSONResponse(\n            result,\n            status_code=HTTPStatus.BAD_REQUEST,\n        )\n\n\n@app.api_route(\"/load_lora_adapter_from_tensors\", methods=[\"POST\"])\nasync def load_lora_adapter_from_tensors(\n    obj: LoadLoRAAdapterFromTensorsReqInput, request: Request\n):\n    \"\"\"Load a new LoRA adapter from tensors without re-launching the server.\"\"\"\n    result = await _global_state.tokenizer_manager.load_lora_adapter_from_tensors(\n        obj, request\n    )\n\n    if result.success:\n        return ORJSONResponse(result, status_code=HTTPStatus.OK)\n    else:\n        return ORJSONResponse(result, status_code=HTTPStatus.BAD_REQUEST)\n\n\n@app.api_route(\"/unload_lora_adapter\", methods=[\"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def unload_lora_adapter(obj: UnloadLoRAAdapterReqInput, request: Request):\n    \"\"\"Unload a LoRA adapter without re-launching the server.\"\"\"\n    result = await _global_state.tokenizer_manager.unload_lora_adapter(obj, request)\n\n    if result.success:\n        return ORJSONResponse(\n     
       result,\n            status_code=HTTPStatus.OK,\n        )\n    else:\n        return ORJSONResponse(\n            result,\n            status_code=HTTPStatus.BAD_REQUEST,\n        )\n\n\n@app.api_route(\"/open_session\", methods=[\"GET\", \"POST\"])\nasync def open_session(obj: OpenSessionReqInput, request: Request):\n    \"\"\"Open a session, and return its unique session id.\"\"\"\n    try:\n        session_id = await _global_state.tokenizer_manager.open_session(obj, request)\n        if session_id is None:\n            raise Exception(\n                \"Failed to open the session. Check if a session with the same id is still open.\"\n            )\n        return session_id\n    except Exception as e:\n        return _create_error_response(e)\n\n\n@app.api_route(\"/close_session\", methods=[\"GET\", \"POST\"])\nasync def close_session(obj: CloseSessionReqInput, request: Request):\n    \"\"\"Close the session.\"\"\"\n    try:\n        await _global_state.tokenizer_manager.close_session(obj, request)\n        return Response(status_code=200)\n    except Exception as e:\n        return _create_error_response(e)\n\n\n@app.api_route(\"/configure_logging\", methods=[\"GET\", \"POST\"])\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def configure_logging(obj: ConfigureLoggingReq, request: Request):\n    \"\"\"Configure the request logging options.\"\"\"\n    _global_state.tokenizer_manager.configure_logging(obj)\n    return Response(status_code=200)\n\n\n@app.post(\"/abort_request\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def abort_request(obj: AbortReq, request: Request):\n    \"\"\"Abort a request.\"\"\"\n    try:\n        _global_state.tokenizer_manager.abort_request(\n            rid=obj.rid, abort_all=obj.abort_all\n        )\n        return Response(status_code=200)\n    except Exception as e:\n        return _create_error_response(e)\n\n\n@app.post(\"/parse_function_call\")\nasync def parse_function_call_request(obj: ParseFunctionCallReq, request: 
Request):\n    \"\"\"\n    A native API endpoint to parse function calls from a text.\n    \"\"\"\n    # 1) Initialize the parser based on the request body\n    parser = FunctionCallParser(tools=obj.tools, tool_call_parser=obj.tool_call_parser)\n\n    # 2) Call the non-stream parsing method (non-stream)\n    normal_text, calls = parser.parse_non_stream(obj.text)\n\n    # 3) Organize the response content\n    response_data = {\n        \"normal_text\": normal_text,\n        \"calls\": [\n            call.model_dump() for call in calls\n        ],  # Convert pydantic objects to dictionaries\n    }\n\n    return ORJSONResponse(content=response_data, status_code=200)\n\n\n@app.post(\"/separate_reasoning\")\nasync def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Request):\n    \"\"\"\n    A native API endpoint to separate reasoning from a text.\n    \"\"\"\n    # 1) Initialize the parser based on the request body\n    parser = ReasoningParser(model_type=obj.reasoning_parser, request=request)\n\n    # 2) Call the non-stream parsing method (non-stream)\n    reasoning_text, normal_text = parser.parse_non_stream(obj.text)\n\n    # 3) Organize the response content\n    response_data = {\n        \"reasoning_text\": reasoning_text,\n        \"text\": normal_text,\n    }\n\n    return ORJSONResponse(content=response_data, status_code=200)\n\n\n@app.post(\"/pause_generation\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def pause_generation(obj: PauseGenerationReqInput, request: Request):\n    \"\"\"Pause generation.\"\"\"\n    await _global_state.tokenizer_manager.pause_generation(obj)\n    return ORJSONResponse(\n        content={\"message\": \"Generation paused successfully.\", \"status\": \"ok\"},\n        status_code=200,\n    )\n\n\n@app.post(\"/continue_generation\")\n@auth_level(AuthLevel.ADMIN_OPTIONAL)\nasync def continue_generation(obj: ContinueGenerationReqInput, request: Request):\n    \"\"\"Continue generation.\"\"\"\n    await 
_global_state.tokenizer_manager.continue_generation(obj)\n    return ORJSONResponse(\n        content={\"message\": \"Generation continued successfully.\", \"status\": \"ok\"},\n        status_code=200,\n    )\n\n\n##### OpenAI-compatible API endpoints #####\n\n\n@app.post(\"/v1/completions\", dependencies=[Depends(validate_json_request)])\nasync def openai_v1_completions(request: CompletionRequest, raw_request: Request):\n    \"\"\"OpenAI-compatible text completion endpoint.\"\"\"\n    return await raw_request.app.state.openai_serving_completion.handle_request(\n        request, raw_request\n    )\n\n\n@app.post(\"/v1/chat/completions\", dependencies=[Depends(validate_json_request)])\nasync def openai_v1_chat_completions(\n    request: ChatCompletionRequest, raw_request: Request\n):\n    \"\"\"OpenAI-compatible chat completion endpoint.\"\"\"\n    return await raw_request.app.state.openai_serving_chat.handle_request(\n        request, raw_request\n    )\n\n\n@app.post(\n    \"/v1/embeddings\",\n    response_class=ORJSONResponse,\n    dependencies=[Depends(validate_json_request)],\n)\nasync def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):\n    \"\"\"OpenAI-compatible embeddings endpoint.\"\"\"\n    return await raw_request.app.state.openai_serving_embedding.handle_request(\n        request, raw_request\n    )\n\n\n@app.post(\n    \"/v1/classify\",\n    response_class=ORJSONResponse,\n    dependencies=[Depends(validate_json_request)],\n)\nasync def openai_v1_classify(request: ClassifyRequest, raw_request: Request):\n    \"\"\"OpenAI-compatible classification endpoint.\"\"\"\n    return await raw_request.app.state.openai_serving_classify.handle_request(\n        request, raw_request\n    )\n\n\n@app.post(\n    \"/v1/tokenize\",\n    response_class=ORJSONResponse,\n    dependencies=[Depends(validate_json_request)],\n)\n@app.post(\n    \"/tokenize\",\n    response_class=ORJSONResponse,\n    dependencies=[Depends(validate_json_request)],\n    
include_in_schema=False,\n)\nasync def openai_v1_tokenize(request: TokenizeRequest, raw_request: Request):\n    \"\"\"OpenAI-compatible tokenization endpoint.\"\"\"\n    return await raw_request.app.state.openai_serving_tokenize.handle_request(\n        request, raw_request\n    )\n\n\n@app.post(\n    \"/v1/detokenize\",\n    response_class=ORJSONResponse,\n    dependencies=[Depends(validate_json_request)],\n)\n@app.post(\n    \"/detokenize\",\n    response_class=ORJSONResponse,\n    dependencies=[Depends(validate_json_request)],\n    include_in_schema=False,\n)\nasync def openai_v1_detokenize(request: DetokenizeRequest, raw_request: Request):\n    \"\"\"OpenAI-compatible detokenization endpoint.\"\"\"\n    return await raw_request.app.state.openai_serving_detokenize.handle_request(\n        request, raw_request\n    )\n\n\n@app.post(\"/v1/audio/transcriptions\")\nasync def openai_v1_audio_transcriptions(\n    raw_request: Request,\n    file: UploadFile = File(...),\n    model: str = Form(default=\"default\"),\n    language: Optional[str] = Form(default=None),\n    response_format: str = Form(default=\"json\"),\n    temperature: float = Form(default=0.0),\n    stream: bool = Form(default=False),\n):\n    \"\"\"OpenAI-compatible audio transcription endpoint.\"\"\"\n    if response_format not in [\"json\", \"text\"]:\n        return ORJSONResponse(\n            content={\"error\": {\"message\": \"Only 'json' and 'text' formats supported\"}},\n            status_code=400,\n        )\n\n    audio_data = await file.read()\n\n    return (\n        await raw_request.app.state.openai_serving_transcription.create_transcription(\n            audio_data=audio_data,\n            model=model,\n            language=language,\n            response_format=response_format,\n            temperature=temperature,\n            stream=stream,\n            raw_request=raw_request,\n        )\n    )\n\n\n@app.get(\"/v1/models\", response_class=ORJSONResponse)\nasync def 
available_models():\n    \"\"\"Show available models. OpenAI-compatible endpoint.\"\"\"\n    served_model_names = [_global_state.tokenizer_manager.served_model_name]\n    model_cards = []\n\n    # Add base model\n    for served_model_name in served_model_names:\n        model_cards.append(\n            ModelCard(\n                id=served_model_name,\n                root=served_model_name,\n                max_model_len=_global_state.tokenizer_manager.model_config.context_len,\n            )\n        )\n\n    # Add loaded LoRA adapters\n    if _global_state.tokenizer_manager.server_args.enable_lora:\n        lora_registry = _global_state.tokenizer_manager.lora_registry\n        for _, lora_ref in lora_registry.get_all_adapters().items():\n            model_cards.append(\n                ModelCard(\n                    id=lora_ref.lora_name,\n                    root=lora_ref.lora_path,\n                    parent=served_model_names[0],\n                    max_model_len=None,\n                )\n            )\n\n    return ModelList(data=model_cards)\n\n\n@app.get(\"/v1/models/{model:path}\", response_class=ORJSONResponse)\nasync def retrieve_model(model: str):\n    \"\"\"Retrieves a model instance, providing basic information about the model.\"\"\"\n    served_model_names = [_global_state.tokenizer_manager.served_model_name]\n\n    if model not in served_model_names:\n        return ORJSONResponse(\n            status_code=404,\n            content={\n                \"error\": {\n                    \"message\": f\"The model '{model}' does not exist\",\n                    \"type\": \"invalid_request_error\",\n                    \"param\": \"model\",\n                    \"code\": \"model_not_found\",\n                }\n            },\n        )\n\n    return ModelCard(\n        id=model,\n        root=model,\n        max_model_len=_global_state.tokenizer_manager.model_config.context_len,\n    )\n\n\n@app.post(\"/v1/score\", 
dependencies=[Depends(validate_json_request)])\nasync def v1_score_request(request: ScoringRequest, raw_request: Request):\n    \"\"\"Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation.\"\"\"\n    return await raw_request.app.state.openai_serving_score.handle_request(\n        request, raw_request\n    )\n\n\n@app.post(\"/v1/responses\", dependencies=[Depends(validate_json_request)])\nasync def v1_responses_request(request: dict, raw_request: Request):\n    \"\"\"Endpoint for the responses API with reasoning support.\"\"\"\n\n    request_obj = ResponsesRequest(**request)\n    result = await raw_request.app.state.openai_serving_responses.create_responses(\n        request_obj, raw_request\n    )\n\n    # Handle streaming responses\n    if isinstance(result, AsyncGenerator):\n        return StreamingResponse(\n            result,\n            media_type=\"text/event-stream\",\n            headers={\"Cache-Control\": \"no-cache\", \"Connection\": \"keep-alive\"},\n        )\n\n    return result\n\n\n@app.get(\"/v1/responses/{response_id}\")\nasync def v1_retrieve_responses(response_id: str, raw_request: Request):\n    \"\"\"Retrieve a response by ID.\"\"\"\n    return await raw_request.app.state.openai_serving_responses.retrieve_responses(\n        response_id\n    )\n\n\n@app.post(\"/v1/responses/{response_id}/cancel\")\nasync def v1_cancel_responses(response_id: str, raw_request: Request):\n    \"\"\"Cancel a background response.\"\"\"\n    return await raw_request.app.state.openai_serving_responses.cancel_responses(\n        response_id\n    )\n\n\n@app.api_route(\n    \"/v1/rerank\", methods=[\"POST\", \"PUT\"], dependencies=[Depends(validate_json_request)]\n)\nasync def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):\n    \"\"\"Endpoint for reranking documents based on query relevance.\"\"\"\n    return await raw_request.app.state.openai_serving_rerank.handle_request(\n        request, raw_request\n    
)\n\n\n##### Ollama-compatible API endpoints #####\n\n_ollama_root_route = os.environ.get(\"SGLANG_OLLAMA_ROOT_ROUTE\")\nif _ollama_root_route is not None:\n\n    @app.get(_ollama_root_route)\n    @app.head(_ollama_root_route)\n    async def ollama_root():\n        \"\"\"Ollama-compatible root endpoint.\"\"\"\n        return \"Ollama is running\"\n\nelse:\n\n    @app.get(\"/\")\n    @app.head(\"/\")\n    async def sglang_root():\n        \"\"\"Default root endpoint.\"\"\"\n        return \"SGLang is running\"\n\n\n@app.post(os.environ.get(\"SGLANG_OLLAMA_CHAT_ROUTE\", \"/api/chat\"))\nasync def ollama_chat(request: OllamaChatRequest, raw_request: Request):\n    \"\"\"Ollama-compatible chat endpoint.\"\"\"\n    return await raw_request.app.state.ollama_serving.handle_chat(request, raw_request)\n\n\n@app.post(os.environ.get(\"SGLANG_OLLAMA_GENERATE_ROUTE\", \"/api/generate\"))\nasync def ollama_generate(request: OllamaGenerateRequest, raw_request: Request):\n    \"\"\"Ollama-compatible generate endpoint.\"\"\"\n    return await raw_request.app.state.ollama_serving.handle_generate(\n        request, raw_request\n    )\n\n\n@app.get(os.environ.get(\"SGLANG_OLLAMA_TAGS_ROUTE\", \"/api/tags\"))\nasync def ollama_tags(raw_request: Request):\n    \"\"\"Ollama-compatible list models endpoint.\"\"\"\n    return raw_request.app.state.ollama_serving.get_tags()\n\n\n@app.post(os.environ.get(\"SGLANG_OLLAMA_SHOW_ROUTE\", \"/api/show\"))\nasync def ollama_show(request: OllamaShowRequest, raw_request: Request):\n    \"\"\"Ollama-compatible show model info endpoint.\"\"\"\n    return raw_request.app.state.ollama_serving.get_show(request.model)\n\n\n##### Anthropic-compatible API endpoints #####\n\n\n@app.post(\"/v1/messages\", dependencies=[Depends(validate_json_request)])\nasync def anthropic_v1_messages(\n    request: AnthropicMessagesRequest, raw_request: Request\n):\n    \"\"\"Anthropic-compatible Messages API endpoint.\"\"\"\n    return await 
raw_request.app.state.anthropic_serving.handle_messages(\n        request, raw_request\n    )\n\n\n@app.post(\"/v1/messages/count_tokens\", dependencies=[Depends(validate_json_request)])\nasync def anthropic_v1_count_tokens(\n    request: AnthropicCountTokensRequest, raw_request: Request\n):\n    \"\"\"Anthropic-compatible token counting endpoint.\"\"\"\n    return await raw_request.app.state.anthropic_serving.handle_count_tokens(\n        request, raw_request\n    )\n\n\n## SageMaker API\n@app.get(\"/ping\")\nasync def sagemaker_health() -> Response:\n    \"\"\"Check the health of the http server.\"\"\"\n    return Response(status_code=200)\n\n\n@app.post(\"/invocations\")\nasync def sagemaker_chat_completions(\n    request: ChatCompletionRequest, raw_request: Request\n):\n    \"\"\"OpenAI-compatible chat completion endpoint.\"\"\"\n    return await raw_request.app.state.openai_serving_chat.handle_request(\n        request, raw_request\n    )\n\n\n## Vertex AI API\n@app.post(os.environ.get(\"AIP_PREDICT_ROUTE\", \"/vertex_generate\"))\nasync def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):\n    if not vertex_req.instances:\n        return []\n    inputs = {}\n    for input_key in (\"text\", \"input_ids\", \"input_embeds\"):\n        if vertex_req.instances[0].get(input_key):\n            inputs[input_key] = [\n                instance.get(input_key) for instance in vertex_req.instances\n            ]\n            break\n    image_data = [\n        instance.get(\"image_data\")\n        for instance in vertex_req.instances\n        if instance.get(\"image_data\") is not None\n    ] or None\n    req = GenerateReqInput(\n        **inputs,\n        image_data=image_data,\n        **(vertex_req.parameters or {}),\n    )\n    ret = await generate_request(req, raw_request)\n    if isinstance(ret, Response):\n        return ret\n    return ORJSONResponse({\"predictions\": ret})\n\n\ndef _create_error_response(e):\n    return ORJSONResponse(\n  
      {\"error\": {\"message\": str(e)}}, status_code=HTTPStatus.BAD_REQUEST\n    )\n\n\n# FIXME: In theory we should configure ADMIN_FORCE for some entrypoints, but doing so\n# would currently cause all endpoints to go through add_api_key_middleware\n# (even when neither api-key nor admin-api-key is configured).\n#\n# For now, we simulate ADMIN_FORCE by explicitly checking the admin API key parameter.\n# Once the auth wiring is refactored so ADMIN_FORCE only affects the intended\n# admin endpoints, we should switch this logic to use ADMIN_FORCE directly.\ndef _admin_api_key_missing_response(\n    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,\n) -> ORJSONResponse:\n    return ORJSONResponse(\n        content={\n            \"error\": (\n                \"This endpoint requires admin API key, but this server was started \"\n                \"without one (admin-api-key). Restart with --admin-api-key to enable.\"\n            )\n        },\n        status_code=status_code,\n    )\n\n\n# Minimal 32x32 black PNG (base64, GLM4v requires at least 32x32 sized image)\nMINIMUM_PNG_PICTURE_BASE64 = \"iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAYAAABzenr0AAAACXBIWXMAAA7EAAAOxAGVKw4bAAAAbUlEQVRYhe3VsQ2AMAxE0Y/lIgNQULD/OqyCMgCihCKSG4yRuKuiNH6JLsoEbMACOGBcua9HOR7Y6w6swBwMy0qLTpkeI77qdEBpBFAHBBDAGH8WrwJKI4AAegUCfAKgEgpQDvh3CR3oQCuav58qlAw73kKCSgAAAABJRU5ErkJggg==\"\n\n\ndef _execute_server_warmup(server_args: ServerArgs):\n    headers = {}\n    url = server_args.url()\n    if server_args.api_key:\n        headers[\"Authorization\"] = f\"Bearer {server_args.api_key}\"\n\n    ssl_verify = server_args.ssl_verify()\n\n    # Wait until the server is launched\n    success = False\n    for _ in range(120):\n        time.sleep(1)\n        try:\n            res = requests.get(\n                url + \"/model_info\", timeout=5, headers=headers, verify=ssl_verify\n            )\n            assert res.status_code == 200, f\"{res=}, {res.text=}\"\n            success = True\n            break\n     
   except (AssertionError, requests.exceptions.RequestException):\n            last_traceback = get_exception_traceback()\n            pass\n\n    if not success:\n        logger.error(f\"Initialization failed. warmup error: {last_traceback}\")\n        kill_process_tree(os.getpid())\n        return success\n\n    model_info = res.json()\n\n    # Construct a warmup request\n    is_vlm = bool(model_info.get(\"has_image_understanding\", False))\n    if model_info[\"is_generation\"]:\n        if is_vlm and not server_args.skip_tokenizer_init:\n            request_name = \"/v1/chat/completions\"\n        else:\n            request_name = \"/generate\"\n    else:\n        request_name = \"/encode\"\n    max_new_tokens = 8 if model_info[\"is_generation\"] else 1\n    json_data = {\n        \"sampling_params\": {\n            \"temperature\": 0,\n            \"max_new_tokens\": max_new_tokens,\n        },\n    }\n    if server_args.skip_tokenizer_init:\n        json_data[\"input_ids\"] = [[10, 11, 12] for _ in range(server_args.dp_size)]\n        # TODO Workaround the bug that embedding errors for list of size 1\n        if server_args.dp_size == 1:\n            json_data[\"input_ids\"] = json_data[\"input_ids\"][0]\n    elif (\n        is_vlm\n        and server_args.disaggregation_mode == \"null\"\n        and model_info[\"is_generation\"]\n    ):\n        # TODO: ChatCompletionRequest does not have bootstrap info required by disaggregation mode, disable image-warmup for now\n        # Only use chat completions format for generation models, not embedding models\n        json_data = {\n            \"model\": _global_state.tokenizer_manager.served_model_name,\n            \"messages\": [\n                {\n                    \"role\": \"user\",\n                    \"content\": [\n                        {\n                            \"type\": \"image_url\",\n                            \"image_url\": {\n                                \"url\": 
f\"data:image/png;base64,{MINIMUM_PNG_PICTURE_BASE64}\"\n                            },\n                        },\n                        {\n                            \"type\": \"text\",\n                            \"text\": \"Describe the image.\",\n                        },\n                    ],\n                }\n            ],\n            \"max_tokens\": max_new_tokens,\n            \"stream\": False,\n            \"temperature\": 0.0,\n        }\n    else:\n        json_data[\"text\"] = [\"The capital city of France is\"] * server_args.dp_size\n        # TODO Workaround the bug that embedding errors for list of size 1\n        if server_args.dp_size == 1:\n            json_data[\"text\"] = json_data[\"text\"][0]\n\n    # Config debug dumping\n    if server_args.debug_tensor_dump_input_file:\n        json_data.pop(\"text\", None)\n        json_data[\"input_ids\"] = np.load(\n            server_args.debug_tensor_dump_input_file\n        ).tolist()\n        json_data[\"sampling_params\"][\"max_new_tokens\"] = 0\n\n    # Send a warmup request\n    warmup_timeout = envs.SGLANG_WARMUP_TIMEOUT.get()\n    try:\n        if server_args.disaggregation_mode == \"null\":\n            res = requests.post(\n                url + request_name,\n                json=json_data,\n                headers=headers,\n                timeout=warmup_timeout if warmup_timeout > 0 else 600,\n                verify=ssl_verify,\n            )\n            assert res.status_code == 200, f\"{res.text}\"\n            _global_state.tokenizer_manager.server_status = ServerStatus.Up\n\n        else:\n            logger.info(f\"Start of pd disaggregation warmup ...\")\n            json_data = {\n                \"sampling_params\": {\n                    \"temperature\": 0.0,\n                    \"max_new_tokens\": 8,\n                    \"ignore_eos\": True,\n                },\n                \"bootstrap_host\": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,\n                # This 
is a hack to ensure fake transfer is enabled during prefill warmup\n                # ensure each dp rank has a unique bootstrap_room during prefill warmup\n                \"bootstrap_room\": [\n                    i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)\n                    for i in range(server_args.dp_size)\n                ],\n                \"input_ids\": [[10, 11, 12, 13]] * server_args.dp_size,\n            }\n            res = requests.post(\n                url + request_name,\n                json=json_data,\n                headers=headers,\n                timeout=(\n                    warmup_timeout if warmup_timeout > 0 else 1800\n                ),  # because of deep gemm precache is very long if not precache.\n                verify=ssl_verify,\n            )\n            if res.status_code == 200:\n                logger.info(\n                    f\"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}\"\n                )\n                _global_state.tokenizer_manager.server_status = ServerStatus.Up\n            else:\n                logger.info(\n                    \"Prefill disaggregation mode warm Up Failed, status code: {}\".format(\n                        res.status_code\n                    )\n                )\n                _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy\n\n    except Exception:\n        last_traceback = get_exception_traceback()\n        logger.error(f\"Initialization failed. 
warmup error: {last_traceback}\")\n        kill_process_tree(os.getpid())\n        return False\n\n    # Debug print\n    # logger.info(f\"warmup request returns: {res.json()=}\")\n    return success\n\n\ndef _wait_and_warmup(\n    server_args: ServerArgs,\n    launch_callback: Optional[Callable[[], None]] = None,\n    execute_warmup_func: Callable = _execute_server_warmup,\n):\n    if server_args.checkpoint_engine_wait_weights_before_ready:\n        _wait_weights_ready()\n\n    # Send a warmup request\n    if not server_args.skip_server_warmup:\n        if not execute_warmup_func(server_args):\n            return\n    else:\n        _global_state.tokenizer_manager.server_status = ServerStatus.Up\n\n    # The server is ready for requests\n    logger.info(\"The server is fired up and ready to roll!\")\n\n    if server_args.delete_ckpt_after_loading:\n        delete_directory(server_args.model_path)\n\n    if server_args.debug_tensor_dump_input_file:\n        kill_process_tree(os.getpid())\n\n    if launch_callback is not None:\n        launch_callback()\n\n\ndef _wait_weights_ready():\n    \"\"\"Wait for weights to be ready within the specified timeout.\"\"\"\n    timeout = WAIT_WEIGHTS_READY_TIMEOUT\n    start_time = time.time()\n\n    for _ in range(timeout):\n        if _global_state.tokenizer_manager.initial_weights_loaded:\n            logger.info(\n                f\"Weights are ready after {time.time() - start_time:.2f} seconds\"\n            )\n            return\n        time.sleep(1)\n\n    # Timeout reached without weights being ready\n    logger.error(\n        f\"Weights are not ready after waiting {timeout} seconds. \"\n        f\"Consider increasing SGLANG_WAIT_WEIGHTS_READY_TIMEOUT environment variable. 
\"\n        f\"Current status: initial_weights_loaded={_global_state.tokenizer_manager.initial_weights_loaded}\"\n    )\n\n\ndef _setup_and_run_http_server(\n    server_args: ServerArgs,\n    tokenizer_manager,\n    template_manager,\n    port_args: PortArgs,\n    scheduler_infos: List[Dict],\n    execute_warmup_func: Callable = _execute_server_warmup,\n    launch_callback: Optional[Callable[[], None]] = None,\n):\n    \"\"\"Set up global state, configure middleware, and run uvicorn.\n\n    Called by launch_server after subprocesses have been launched.\n    \"\"\"\n    # Parse info got from the schedulers\n    remote_instance_transfer_engine_info = (\n        parse_remote_instance_transfer_engine_info_from_scheduler_infos(scheduler_infos)\n    )\n\n    # Set global states\n    set_global_state(\n        _GlobalState(\n            tokenizer_manager=tokenizer_manager,\n            template_manager=template_manager,\n            scheduler_info=scheduler_infos[0],\n            remote_instance_transfer_engine_info=remote_instance_transfer_engine_info,\n        )\n    )\n\n    if server_args.enable_metrics:\n        add_prometheus_track_response_middleware(app)\n\n    # Pass additional arguments to the lifespan function.\n    # They will be used for additional initialization setups.\n    if server_args.tokenizer_worker_num == 1:\n        # If it is single tokenizer mode, we can pass the arguments by attributes of the app object.\n        app.is_single_tokenizer_mode = True\n        app.server_args = server_args\n        app.warmup_thread_kwargs = dict(\n            server_args=server_args,\n            launch_callback=launch_callback,\n            execute_warmup_func=execute_warmup_func,\n        )\n\n        # Add api key authorization\n        # This is only supported in single tokenizer mode.\n        #\n        # Backward compatibility:\n        # - api_key only: behavior matches legacy (all endpoints require api_key)\n        # - no keys: legacy had no restriction; 
ADMIN_FORCE endpoints must still be rejected when\n        #   admin_api_key is not configured.\n        if (\n            server_args.api_key\n            or server_args.admin_api_key\n            or app_has_admin_force_endpoints(app)\n        ):\n            from sglang.srt.utils.auth import add_api_key_middleware\n\n            add_api_key_middleware(\n                app,\n                api_key=server_args.api_key,\n                admin_api_key=server_args.admin_api_key,\n            )\n    else:\n        # If it is multi-tokenizer mode, we need to write the arguments to shared memory\n        # for other worker processes to read.\n        app.is_single_tokenizer_mode = False\n        multi_tokenizer_args_shm = write_data_for_multi_tokenizer(\n            port_args, server_args, scheduler_infos[0]\n        )\n\n    try:\n        # Update logging configs\n        set_uvicorn_logging_configs(server_args)\n\n        if server_args.ssl_certfile:\n            logger.info(\n                f\"SSL enabled: certfile={server_args.ssl_certfile}, \"\n                f\"keyfile={server_args.ssl_keyfile}\"\n            )\n\n        # Listen for HTTP requests\n        if server_args.tokenizer_worker_num == 1:\n            if server_args.enable_ssl_refresh:\n                # Use Config/Server API for access to the SSLContext.\n                config = uvicorn.Config(\n                    app,\n                    host=server_args.host,\n                    port=server_args.port,\n                    root_path=server_args.fastapi_root_path,\n                    log_level=server_args.log_level_http or server_args.log_level,\n                    timeout_keep_alive=envs.SGLANG_TIMEOUT_KEEP_ALIVE.get(),\n                    loop=\"uvloop\",\n                    ssl_keyfile=server_args.ssl_keyfile,\n                    ssl_certfile=server_args.ssl_certfile,\n                    ssl_ca_certs=server_args.ssl_ca_certs,\n                    
ssl_keyfile_password=server_args.ssl_keyfile_password,\n                )\n                config.load()  # Creates the SSLContext\n\n                from sglang.srt.entrypoints.ssl_utils import SSLCertRefresher\n\n                server = uvicorn.Server(config)\n\n                async def _run_with_ssl_refresh():\n                    refresher = SSLCertRefresher(\n                        config.ssl,\n                        server_args.ssl_keyfile,\n                        server_args.ssl_certfile,\n                        server_args.ssl_ca_certs,\n                    )\n                    logger.info(\"SSL certificate auto-refresh enabled.\")\n                    try:\n                        await server.serve()\n                    finally:\n                        refresher.stop()\n\n                import asyncio\n\n                asyncio.run(_run_with_ssl_refresh())\n            else:\n                # Default case, one tokenizer process\n                uvicorn.run(\n                    app,\n                    host=server_args.host,\n                    port=server_args.port,\n                    root_path=server_args.fastapi_root_path,\n                    log_level=server_args.log_level_http or server_args.log_level,\n                    timeout_keep_alive=envs.SGLANG_TIMEOUT_KEEP_ALIVE.get(),\n                    loop=\"uvloop\",\n                    ssl_keyfile=server_args.ssl_keyfile,\n                    ssl_certfile=server_args.ssl_certfile,\n                    ssl_ca_certs=server_args.ssl_ca_certs,\n                    ssl_keyfile_password=server_args.ssl_keyfile_password,\n                )\n        else:\n            # Multiple tokenizer and http processes\n            from uvicorn.config import LOGGING_CONFIG\n\n            LOGGING_CONFIG[\"loggers\"][\"sglang.srt.entrypoints.http_server\"] = {\n                \"handlers\": [\"default\"],\n                \"level\": \"INFO\",\n                \"propagate\": False,\n            }\n        
    monkey_patch_uvicorn_multiprocessing()\n\n            if server_args.enable_ssl_refresh:\n                logger.warning(\n                    \"--enable-ssl-refresh is not supported with multiple \"\n                    \"tokenizer workers (--tokenizer-worker-num > 1). \"\n                    \"SSL refresh will be disabled.\"\n                )\n\n            uvicorn.run(\n                \"sglang.srt.entrypoints.http_server:app\",\n                host=server_args.host,\n                port=server_args.port,\n                root_path=server_args.fastapi_root_path,\n                log_level=server_args.log_level_http or server_args.log_level,\n                timeout_keep_alive=envs.SGLANG_TIMEOUT_KEEP_ALIVE.get(),\n                loop=\"uvloop\",\n                workers=server_args.tokenizer_worker_num,\n                ssl_keyfile=server_args.ssl_keyfile,\n                ssl_certfile=server_args.ssl_certfile,\n                ssl_ca_certs=server_args.ssl_ca_certs,\n                ssl_keyfile_password=server_args.ssl_keyfile_password,\n            )\n    finally:\n        if server_args.tokenizer_worker_num > 1:\n            if multi_tokenizer_args_shm is not None:\n                multi_tokenizer_args_shm.unlink()\n            if _global_state is not None:\n                _global_state.tokenizer_manager.socket_mapping.clear_all_sockets()\n\n\ndef launch_server(\n    server_args: ServerArgs,\n    init_tokenizer_manager_func: Callable = init_tokenizer_manager,\n    run_scheduler_process_func: Callable = run_scheduler_process,\n    run_detokenizer_process_func: Callable = run_detokenizer_process,\n    execute_warmup_func: Callable = _execute_server_warmup,\n    launch_callback: Optional[Callable[[], None]] = None,\n):\n    \"\"\"\n    Launch SRT (SGLang Runtime) Server.\n\n    The SRT server consists of an HTTP server and an SRT engine.\n\n    - HTTP server: A FastAPI server that routes requests to the engine.\n    - The engine consists of three 
components:\n        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.\n        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.\n        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.\n\n    Note:\n    1. The HTTP server, Engine, and TokenizerManager all run in the main process.\n    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.\n    \"\"\"\n    # Launch subprocesses\n    tokenizer_manager, template_manager, port_args, scheduler_init_result = (\n        Engine._launch_subprocesses(\n            server_args=server_args,\n            init_tokenizer_manager_func=init_tokenizer_manager_func,\n            run_scheduler_process_func=run_scheduler_process_func,\n            run_detokenizer_process_func=run_detokenizer_process_func,\n        )\n    )\n\n    _setup_and_run_http_server(\n        server_args,\n        tokenizer_manager,\n        template_manager,\n        port_args,\n        scheduler_init_result.scheduler_infos,\n        execute_warmup_func=execute_warmup_func,\n        launch_callback=launch_callback,\n    )\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/http_server_engine.py",
    "content": "import multiprocessing\nimport time\nfrom typing import List, Optional, Tuple\n\nimport requests\nimport torch\n\nfrom sglang.srt.entrypoints.EngineBase import EngineBase\nfrom sglang.srt.entrypoints.http_server import launch_server\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils import MultiprocessingSerializer, kill_process_tree\n\n\ndef launch_server_process(server_args: ServerArgs) -> multiprocessing.Process:\n\n    p = multiprocessing.Process(target=launch_server, args=(server_args,))\n    p.start()\n\n    base_url = server_args.url()\n    timeout = 300.0  # Increased timeout to 5 minutes for downloading large models\n    start_time = time.perf_counter()\n\n    ssl_verify = server_args.ssl_verify()\n\n    with requests.Session() as session:\n        while time.perf_counter() - start_time < timeout:\n            try:\n                headers = {\n                    \"Content-Type\": \"application/json; charset=utf-8\",\n                    \"Authorization\": f\"Bearer {server_args.api_key}\",\n                }\n                response = session.get(\n                    f\"{base_url}/health_generate\", headers=headers, verify=ssl_verify\n                )\n                if response.status_code == 200:\n                    return p\n            except requests.RequestException:\n                pass\n\n            if not p.is_alive():\n                raise Exception(\"Server process terminated unexpectedly.\")\n\n            time.sleep(2)\n\n    p.terminate()\n    raise TimeoutError(\"Server failed to start within the timeout period.\")\n\n\nclass HttpServerEngineAdapter(EngineBase):\n    \"\"\"\n    You can use this class to launch a server from a VerlEngine instance.\n    We recommend using this class only if you need the HTTP server.\n    Otherwise, you can use Engine directly.\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        self.server_args = ServerArgs(**kwargs)\n        print(\n            f\"Launch HttpServerEngineAdapter at: {self.server_args.host}:{self.server_args.port}\"\n        )\n        self.process = launch_server_process(self.server_args)\n\n    def _make_request(self, endpoint: str, payload: Optional[dict] = None):\n        \"\"\"Make a POST request to the specified endpoint with the given payload.\n        Args:\n            endpoint: The API endpoint to call\n            payload: The JSON payload to send (default: empty dict)\n        Returns:\n            The JSON response from the server\n        \"\"\"\n        url = f\"{self.server_args.url()}/{endpoint}\"\n        response = requests.post(\n            url, json=payload or {}, verify=self.server_args.ssl_verify()\n        )\n        response.raise_for_status()\n        return response.json()\n\n    def update_weights_from_tensor(\n        self,\n        named_tensors: List[Tuple[str, torch.Tensor]],\n        load_format: Optional[str] = None,\n        flush_cache: bool = False,\n    ):\n        \"\"\"\n        Update model weights from tensor data. The HTTP server will only post metadata, and the real weights will be copied directly from GPUs.\n        Note: The model should be on GPUs rather than CPU for this functionality to work properly.\n        If you encounter issues, ensure your model is loaded on GPU devices rather than CPU.\n        \"\"\"\n\n        return self._make_request(\n            \"update_weights_from_tensor\",\n            {\n                \"serialized_named_tensors\": [\n                    MultiprocessingSerializer.serialize(named_tensors, output_str=True)\n                    for _ in range(self.server_args.tp_size)\n                ],\n                \"load_format\": load_format,\n                \"flush_cache\": flush_cache,\n            },\n        )\n\n    def shutdown(self):\n        kill_process_tree(self.process.pid)\n\n    def generate(\n        self,\n        prompt=None,\n        sampling_params=None,\n        input_ids=None,\n        image_data=None,\n        return_logprob=False,\n        logprob_start_len=None,\n        top_logprobs_num=None,\n        token_ids_logprob=None,\n        lora_path=None,\n        custom_logit_processor=None,\n        priority=None,\n    ):\n        payload = {\n            \"text\": prompt,\n            \"sampling_params\": sampling_params,\n            \"input_ids\": input_ids,\n            \"image_data\": image_data,\n            \"return_logprob\": return_logprob,\n            \"logprob_start_len\": logprob_start_len,\n            \"top_logprobs_num\": top_logprobs_num,\n            \"token_ids_logprob\": token_ids_logprob,\n            \"lora_path\": lora_path,\n            \"custom_logit_processor\": custom_logit_processor,\n            \"priority\": priority,\n        }\n        # Filter out None values\n        payload = {k: v for k, v in payload.items() if v is not None}\n\n        return self._make_request(\"generate\", payload)\n\n    def release_memory_occupation(self):\n        return self._make_request(\"release_memory_occupation\")\n\n    def resume_memory_occupation(self):\n        return self._make_request(\"resume_memory_occupation\")\n\n    def flush_cache(self):\n        return self._make_request(\"flush_cache\")\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/ollama/README.md",
    "content": "# SGLang Ollama Integration\n\nOllama API compatibility for SGLang, plus a Smart Router for intelligent routing between local and remote models.\n\n## Features\n\n1. **Ollama-compatible API** - Use Ollama CLI/library with SGLang backend\n2. **Smart Router** - LLM-based routing between local and remote models\n\n## Ollama API\n\nFor basic Ollama API usage with SGLang (CLI and Python examples), see the [Ollama API documentation](https://sgl-project.github.io/basic_usage/ollama_api.html).\n\n## Smart Router\n\nThe Smart Router uses an LLM judge to route each request either to a local Ollama model or to a remote SGLang server.\n\n### Prerequisites\n\n```bash\npip install ollama\n```\n\n### How It Works\n\n```\nUser Request\n     │\n     ▼\n┌─────────────────────┐\n│     LLM Judge       │  Classifies as SIMPLE or COMPLEX\n│  (local model)      │\n└─────────────────────┘\n     │\n     ▼\n┌─────────────────────┐\n│  SIMPLE → Local     │  Fast response from local Ollama\n│  COMPLEX → Remote   │  Powerful response from SGLang\n└─────────────────────┘\n```\n\nThe LLM judge (running on local Ollama) analyzes each request and decides:\n- **SIMPLE**: Quick responses, greetings, factual questions, definitions, basic Q&A\n- **COMPLEX**: Deep reasoning, multi-step analysis, long explanations, creative writing\n\n### Setup\n\n**Terminal 1: Local Ollama**\n```bash\nollama pull <LOCAL_MODEL>  # e.g., llama3.2, mistral, phi3\nollama serve  # This will block the terminal\n```\n\n**Terminal 2: Remote SGLang (GPU)**\n```bash\nssh user@gpu-server\npython -m sglang.launch_server --model <REMOTE_MODEL> --port 30001 --host 0.0.0.0\n```\n\n**Terminal 3: Smart Router**\n```bash\nssh -L 30001:localhost:30001 user@gpu-server -N &\npython python/sglang/srt/entrypoints/ollama/smart_router.py\n```\n\n### Configuration\n\n```python\nfrom sglang.srt.entrypoints.ollama.smart_router import SmartRouter\n\nrouter = SmartRouter(\n    # Local Ollama\n    local_host=\"http://localhost:11434\",\n    local_model=\"llama3.2\",  # or any Ollama model\n\n    # Remote SGLang\n    remote_host=\"http://localhost:30001\",\n    remote_model=\"Qwen/Qwen2.5-1.5B-Instruct\",  # or any HuggingFace model\n\n    # LLM Judge (optional, defaults to local_model)\n    judge_model=\"llama3.2\",\n)\n```\n\n### Usage\n\n```python\n# Auto-routing via LLM judge\nresponse = router.chat(\"Hello!\", verbose=True)\n# [Router] LLM Judge: SIMPLE\n# [Router] -> Local Ollama | Model: llama3.2\n\nresponse = router.chat(\"Explain quantum computing in detail\", verbose=True)\n# [Router] LLM Judge: COMPLEX\n# [Router] -> Remote SGLang | Model: Qwen/Qwen2.5-1.5B-Instruct\n\n# Force routing (skip LLM judge)\nresponse = router.chat(\"question\", force_local=True)\nresponse = router.chat(\"question\", force_remote=True)\n\n# Streaming\nfor chunk in router.chat_stream(\"Tell me a story\"):\n    print(chunk['message']['content'], end='')\n```\n\n---\n\n## Value\n\n- **Ollama**: A simple CLI/API that developers already know\n- **SGLang**: High-performance inference\n- **Smart Router**: Intelligent routing - fast local for simple tasks, powerful remote for complex tasks\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/ollama/__init__.py",
    "content": "# Ollama-compatible API for SGLang\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/ollama/protocol.py",
    "content": "\"\"\"\nOllama-compatible API protocol definitions.\n\nThese models match the Ollama API format:\nhttps://github.com/ollama/ollama/blob/main/docs/api.md\n\"\"\"\n\nfrom typing import Any, Dict, List, Literal, Optional, Union\n\nfrom pydantic import BaseModel, Field\n\n\nclass OllamaMessage(BaseModel):\n    \"\"\"Ollama message format.\"\"\"\n\n    role: str\n    content: str\n    images: Optional[List[str]] = None\n\n\nclass OllamaChatRequest(BaseModel):\n    \"\"\"Ollama /api/chat request format.\"\"\"\n\n    model: str\n    messages: List[OllamaMessage]\n    stream: bool = True\n    format: Optional[Union[Literal[\"json\"], Dict[str, Any]]] = None\n    options: Optional[Dict[str, Any]] = None\n    keep_alive: Optional[Union[float, str]] = None\n    think: Optional[Union[bool, Literal[\"low\", \"medium\", \"high\"]]] = None\n\n\nclass OllamaChatResponse(BaseModel):\n    \"\"\"Ollama /api/chat response format (non-streaming).\"\"\"\n\n    model: str\n    created_at: str\n    message: OllamaMessage\n    done: bool = True\n    done_reason: Optional[str] = \"stop\"\n    total_duration: Optional[int] = None\n    load_duration: Optional[int] = None\n    prompt_eval_count: Optional[int] = None\n    prompt_eval_duration: Optional[int] = None\n    eval_count: Optional[int] = None\n    eval_duration: Optional[int] = None\n\n\nclass OllamaChatStreamResponse(BaseModel):\n    \"\"\"Ollama /api/chat streaming response chunk.\"\"\"\n\n    model: str\n    created_at: str\n    message: OllamaMessage\n    done: bool = False\n    done_reason: Optional[str] = None\n\n\nclass OllamaGenerateRequest(BaseModel):\n    \"\"\"Ollama /api/generate request format.\"\"\"\n\n    model: str\n    prompt: str\n    suffix: Optional[str] = None\n    system: Optional[str] = None\n    template: Optional[str] = None\n    context: Optional[List[int]] = None\n    stream: bool = True\n    raw: bool = False\n    format: Optional[Union[Literal[\"json\"], Dict[str, Any]]] = None\n    options: 
Optional[Dict[str, Any]] = None\n    keep_alive: Optional[Union[float, str]] = None\n    images: Optional[List[str]] = None\n    think: Optional[bool] = None\n\n\nclass OllamaGenerateResponse(BaseModel):\n    \"\"\"Ollama /api/generate response format (non-streaming).\"\"\"\n\n    model: str\n    created_at: str\n    response: str\n    done: bool = True\n    done_reason: Optional[str] = \"stop\"\n    context: Optional[List[int]] = None\n    total_duration: Optional[int] = None\n    load_duration: Optional[int] = None\n    prompt_eval_count: Optional[int] = None\n    prompt_eval_duration: Optional[int] = None\n    eval_count: Optional[int] = None\n    eval_duration: Optional[int] = None\n\n\nclass OllamaGenerateStreamResponse(BaseModel):\n    \"\"\"Ollama /api/generate streaming response chunk.\"\"\"\n\n    model: str\n    created_at: str\n    response: str\n    done: bool = False\n    done_reason: Optional[str] = None\n\n\nclass OllamaModelInfo(BaseModel):\n    \"\"\"Model information for /api/tags response.\"\"\"\n\n    name: str\n    model: str\n    modified_at: str\n    size: int\n    digest: str\n    details: Optional[Dict[str, Any]] = None\n\n\nclass OllamaTagsResponse(BaseModel):\n    \"\"\"Ollama /api/tags response format.\"\"\"\n\n    models: List[OllamaModelInfo]\n\n\nclass OllamaShowRequest(BaseModel):\n    \"\"\"Ollama /api/show request format.\"\"\"\n\n    model: str\n\n\nclass OllamaShowResponse(BaseModel):\n    \"\"\"Ollama /api/show response format.\"\"\"\n\n    license: str = \"\"\n    modelfile: str = \"\"\n    parameters: str = \"\"\n    template: str = \"\"\n    modified_at: str = \"\"\n    details: Dict[str, Any] = Field(default_factory=dict)\n    model_info: Dict[str, Any] = Field(default_factory=dict)\n    capabilities: List[str] = Field(default_factory=list)\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/ollama/serving.py",
    "content": "\"\"\"\nOllama-compatible API serving handlers.\n\nThis module provides handlers that convert Ollama API requests to SGLang's\ninternal format and return Ollama-compatible responses.\n\"\"\"\n\nimport time\nfrom datetime import datetime, timezone\nfrom typing import AsyncIterator, Union\n\nimport orjson\nfrom fastapi import Request\nfrom fastapi.responses import StreamingResponse\n\nfrom sglang.srt.entrypoints.ollama.protocol import (\n    OllamaChatRequest,\n    OllamaChatResponse,\n    OllamaChatStreamResponse,\n    OllamaGenerateRequest,\n    OllamaGenerateResponse,\n    OllamaGenerateStreamResponse,\n    OllamaMessage,\n    OllamaModelInfo,\n    OllamaShowResponse,\n    OllamaTagsResponse,\n)\nfrom sglang.srt.managers.io_struct import GenerateReqInput\n\n\nclass OllamaServing:\n    \"\"\"Handler for Ollama-compatible API endpoints.\"\"\"\n\n    def __init__(self, tokenizer_manager):\n        self.tokenizer_manager = tokenizer_manager\n\n    def _get_timestamp(self) -> str:\n        \"\"\"Get current timestamp in Ollama format.\"\"\"\n        return datetime.now(timezone.utc).strftime(\"%Y-%m-%dT%H:%M:%S.%f\")[:-3] + \"Z\"\n\n    def _convert_options_to_sampling_params(self, options: dict = None) -> dict:\n        \"\"\"Convert Ollama options to SGLang sampling params.\"\"\"\n        sampling_params = {}\n\n        if options:\n            # Map Ollama options to SGLang params\n            param_mapping = {\n                \"temperature\": \"temperature\",\n                \"top_p\": \"top_p\",\n                \"top_k\": \"top_k\",\n                \"num_predict\": \"max_new_tokens\",\n                \"stop\": \"stop\",\n                \"presence_penalty\": \"presence_penalty\",\n                \"frequency_penalty\": \"frequency_penalty\",\n                \"seed\": \"seed\",\n            }\n            for ollama_param, sglang_param in param_mapping.items():\n                if ollama_param in options:\n                    
sampling_params[sglang_param] = options[ollama_param]\n\n        # Set a reasonable default for max_new_tokens if not specified\n        # Ollama users typically expect longer responses than SGLang's default (128)\n        if \"max_new_tokens\" not in sampling_params:\n            sampling_params[\"max_new_tokens\"] = 2048\n\n        return sampling_params\n\n    async def handle_chat(\n        self, request: OllamaChatRequest, raw_request: Request\n    ) -> Union[OllamaChatResponse, StreamingResponse]:\n        \"\"\"Handle /api/chat endpoint.\"\"\"\n        model_name = self.tokenizer_manager.served_model_name\n\n        # Convert messages to SGLang format\n        messages = [\n            {\"role\": msg.role, \"content\": msg.content} for msg in request.messages\n        ]\n\n        # Apply chat template using tokenizer\n        prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template(\n            messages,\n            tokenize=True,\n            add_generation_prompt=True,\n        )\n\n        # Convert options to sampling params\n        sampling_params = self._convert_options_to_sampling_params(request.options)\n\n        # Create SGLang request with input_ids\n        gen_request = GenerateReqInput(\n            input_ids=prompt_ids,\n            sampling_params=sampling_params,\n            stream=request.stream,\n        )\n\n        if request.stream:\n            return await self._stream_chat_response(\n                gen_request, raw_request, model_name\n            )\n        else:\n            return await self._generate_chat_response(\n                gen_request, raw_request, model_name\n            )\n\n    async def _generate_chat_response(\n        self, gen_request: GenerateReqInput, raw_request: Request, model_name: str\n    ) -> OllamaChatResponse:\n        \"\"\"Generate non-streaming chat response.\"\"\"\n        start_time = time.time_ns()\n\n        # Get response from tokenizer manager\n        response = await 
self.tokenizer_manager.generate_request(\n            gen_request, raw_request\n        ).__anext__()\n\n        end_time = time.time_ns()\n        total_duration = end_time - start_time\n\n        output_text = response.get(\"text\", \"\")\n\n        return OllamaChatResponse(\n            model=model_name,\n            created_at=self._get_timestamp(),\n            message=OllamaMessage(role=\"assistant\", content=output_text),\n            done=True,\n            done_reason=\"stop\",\n            total_duration=total_duration,\n            prompt_eval_count=response.get(\"meta_info\", {}).get(\"prompt_tokens\", None),\n            eval_count=response.get(\"meta_info\", {}).get(\"completion_tokens\", None),\n        )\n\n    async def _stream_chat_response(\n        self, gen_request: GenerateReqInput, raw_request: Request, model_name: str\n    ) -> StreamingResponse:\n        \"\"\"Generate streaming chat response.\"\"\"\n\n        async def generate_stream() -> AsyncIterator[bytes]:\n            previous_text = \"\"\n            async for chunk in self.tokenizer_manager.generate_request(\n                gen_request, raw_request\n            ):\n                text = chunk.get(\"text\", \"\")\n                is_done = chunk.get(\"meta_info\", {}).get(\"finish_reason\") is not None\n\n                # Calculate delta (new text since last chunk)\n                delta = text[len(previous_text) :]\n                previous_text = text\n\n                if is_done:\n                    # Final chunk: it may still carry newly decoded text, so send the delta\n                    response = OllamaChatStreamResponse(\n                        model=model_name,\n                        created_at=self._get_timestamp(),\n                        message=OllamaMessage(role=\"assistant\", content=delta),\n                        done=True,\n                        done_reason=\"stop\",\n                    )\n                else:\n                    response = OllamaChatStreamResponse(\n                        
model=model_name,\n                        created_at=self._get_timestamp(),\n                        message=OllamaMessage(role=\"assistant\", content=delta),\n                        done=False,\n                    )\n\n                yield orjson.dumps(response.model_dump()) + b\"\\n\"\n\n        return StreamingResponse(\n            generate_stream(),\n            media_type=\"application/x-ndjson\",\n        )\n\n    async def handle_generate(\n        self, request: OllamaGenerateRequest, raw_request: Request\n    ) -> Union[OllamaGenerateResponse, StreamingResponse]:\n        \"\"\"Handle /api/generate endpoint.\"\"\"\n        model_name = self.tokenizer_manager.served_model_name\n\n        # Build prompt\n        prompt = request.prompt\n        if request.system:\n            prompt = f\"{request.system}\\n\\n{prompt}\"\n\n        # Handle empty prompt - Ollama CLI sends empty requests on initialization\n        if not prompt or not prompt.strip():\n            empty_response = OllamaGenerateResponse(\n                model=model_name,\n                created_at=self._get_timestamp(),\n                response=\"\",\n                done=True,\n                done_reason=\"stop\",\n            )\n            if request.stream:\n                # Return streaming response with done=True\n                async def empty_stream() -> AsyncIterator[bytes]:\n                    yield orjson.dumps(empty_response.model_dump()) + b\"\\n\"\n\n                return StreamingResponse(\n                    empty_stream(),\n                    media_type=\"application/x-ndjson\",\n                )\n            return empty_response\n\n        # Convert options to sampling params\n        sampling_params = self._convert_options_to_sampling_params(request.options)\n\n        # Create SGLang request\n        gen_request = GenerateReqInput(\n            text=prompt,\n            sampling_params=sampling_params,\n            stream=request.stream,\n        )\n\n       
 if request.stream:\n            return await self._stream_generate_response(\n                gen_request, raw_request, model_name\n            )\n        else:\n            return await self._generate_generate_response(\n                gen_request, raw_request, model_name\n            )\n\n    async def _generate_generate_response(\n        self, gen_request: GenerateReqInput, raw_request: Request, model_name: str\n    ) -> OllamaGenerateResponse:\n        \"\"\"Generate non-streaming generate response.\"\"\"\n        start_time = time.time_ns()\n\n        response = await self.tokenizer_manager.generate_request(\n            gen_request, raw_request\n        ).__anext__()\n\n        end_time = time.time_ns()\n        total_duration = end_time - start_time\n\n        output_text = response.get(\"text\", \"\")\n\n        return OllamaGenerateResponse(\n            model=model_name,\n            created_at=self._get_timestamp(),\n            response=output_text,\n            done=True,\n            done_reason=\"stop\",\n            total_duration=total_duration,\n            prompt_eval_count=response.get(\"meta_info\", {}).get(\"prompt_tokens\", None),\n            eval_count=response.get(\"meta_info\", {}).get(\"completion_tokens\", None),\n        )\n\n    async def _stream_generate_response(\n        self, gen_request: GenerateReqInput, raw_request: Request, model_name: str\n    ) -> StreamingResponse:\n        \"\"\"Generate streaming generate response.\"\"\"\n\n        async def generate_stream() -> AsyncIterator[bytes]:\n            previous_text = \"\"\n            async for chunk in self.tokenizer_manager.generate_request(\n                gen_request, raw_request\n            ):\n                text = chunk.get(\"text\", \"\")\n                is_done = chunk.get(\"meta_info\", {}).get(\"finish_reason\") is not None\n\n                # Calculate delta (new text since last chunk)\n                delta = text[len(previous_text) :]\n                
previous_text = text\n\n                if is_done:\n                    # Final chunk: it may still carry newly decoded text, so send the delta\n                    response = OllamaGenerateStreamResponse(\n                        model=model_name,\n                        created_at=self._get_timestamp(),\n                        response=delta,\n                        done=True,\n                        done_reason=\"stop\",\n                    )\n                else:\n                    response = OllamaGenerateStreamResponse(\n                        model=model_name,\n                        created_at=self._get_timestamp(),\n                        response=delta,\n                        done=False,\n                    )\n\n                yield orjson.dumps(response.model_dump()) + b\"\\n\"\n\n        return StreamingResponse(\n            generate_stream(),\n            media_type=\"application/x-ndjson\",\n        )\n\n    def get_tags(self) -> OllamaTagsResponse:\n        \"\"\"Handle /api/tags endpoint - list available models.\"\"\"\n        model_name = self.tokenizer_manager.served_model_name\n\n        model_info = OllamaModelInfo(\n            name=model_name,\n            model=model_name,\n            modified_at=self._get_timestamp(),\n            size=0,  # We don't track model size\n            digest=\"sha256:sglang0000000000000000000000000000000000000000000000000000000000\",\n            details={\n                \"format\": \"sglang\",\n                \"family\": (\n                    model_name.split(\"/\")[-1] if \"/\" in model_name else model_name\n                ),\n                \"parameter_size\": \"unknown\",\n            },\n        )\n\n        return OllamaTagsResponse(models=[model_info])\n\n    def get_show(self, model: str) -> OllamaShowResponse:\n        \"\"\"Handle /api/show endpoint - show model information.\"\"\"\n        model_config = self.tokenizer_manager.model_config\n\n        # Extract model family from model name\n        model_family = model.split(\"/\")[-1] if \"/\" in model else model\n        
# Remove common suffixes to get base family\n        for suffix in [\"-Instruct\", \"-Chat\", \"-Base\"]:\n            if model_family.endswith(suffix):\n                model_family = model_family[: -len(suffix)]\n                break\n\n        # Build context length info\n        context_len = model_config.context_len if model_config else 4096\n\n        return OllamaShowResponse(\n            license=\"\",  # License info not available from SGLang\n            modelfile=f\"FROM {model}\\nPARAMETER num_ctx {context_len}\\n\",\n            parameters=f\"num_ctx {context_len}\",\n            template=\"\",  # Template info not easily accessible\n            modified_at=self._get_timestamp(),\n            details={\n                \"parent_model\": \"\",\n                \"format\": \"sglang\",\n                \"family\": model_family,\n                \"families\": [model_family],\n                \"parameter_size\": \"unknown\",\n                \"quantization_level\": \"\",\n            },\n            model_info={\n                \"general.architecture\": model_family,\n                \"general.name\": model,\n                \"general.parameter_count\": 0,\n                f\"{model_family}.context_length\": context_len,\n                f\"{model_family}.block_count\": 0,\n                f\"{model_family}.embedding_length\": 0,\n                f\"{model_family}.attention.head_count\": 0,\n            },\n            capabilities=[\"completion\"],\n        )\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/ollama/smart_router.py",
    "content": "\"\"\"\nSmart Router: Automatically routes requests between local Ollama and remote SGLang.\n\nUses an LLM judge to classify tasks as simple or complex, then routes accordingly:\n- Simple tasks → Local Ollama (fast response)\n- Complex tasks → Remote SGLang (powerful model)\n\nUsage:\n    from sglang.srt.entrypoints.ollama.smart_router import SmartRouter\n\n    router = SmartRouter(\n        local_host=\"http://localhost:11434\",\n        remote_host=\"http://sglang-server:30001\",\n    )\n    response = router.chat(\"Hello!\")\n\"\"\"\n\nfrom typing import Optional\n\nimport ollama\n\n\nclass SmartRouter:\n    \"\"\"Routes requests between local Ollama and remote SGLang using LLM-based classification.\"\"\"\n\n    # Classification prompt for LLM judge\n    CLASSIFICATION_PROMPT = \"\"\"You are a task classifier. Classify the following user request into one of two categories.\n\nCategories:\n- SIMPLE: Quick responses, greetings, factual questions, definitions, translations, basic Q&A\n- COMPLEX: Tasks requiring deep reasoning, multi-step analysis, long explanations, creative writing, detailed research\n\nReply with ONLY one word: either SIMPLE or COMPLEX.\n\nUser request: \"{prompt}\"\n\nCategory:\"\"\"\n\n    def __init__(\n        self,\n        local_host: str = \"http://localhost:11434\",\n        remote_host: str = \"http://localhost:30001\",\n        local_model: str = \"llama3.2\",\n        remote_model: str = \"Qwen/Qwen2.5-1.5B-Instruct\",\n        judge_model: Optional[str] = None,\n        judge_host: Optional[str] = None,\n    ):\n        \"\"\"\n        Initialize the smart router.\n\n        Args:\n            local_host: URL of local Ollama server\n            remote_host: URL of remote SGLang server\n            local_model: Model name for local Ollama\n            remote_model: Model name for remote SGLang\n            judge_model: Model for LLM-based classification (default: same as local_model)\n            judge_host: Host for 
judge model (default: same as local_host)\n        \"\"\"\n        self.local_client = ollama.Client(host=local_host)\n        self.remote_client = ollama.Client(host=remote_host)\n        self.local_model = local_model\n        self.remote_model = remote_model\n\n        # Judge model configuration\n        self.judge_model = judge_model or local_model\n        self.judge_host = judge_host or local_host\n        self.judge_client = ollama.Client(host=self.judge_host)\n\n    def _classify_with_llm(\n        self, prompt: str, verbose: bool = False\n    ) -> tuple[bool, str]:\n        \"\"\"\n        Use LLM to classify the prompt.\n\n        Returns:\n            Tuple of (use_remote, reason)\n        \"\"\"\n        try:\n            classification_prompt = self.CLASSIFICATION_PROMPT.format(\n                prompt=prompt[:500]  # Limit prompt length for classification\n            )\n\n            response = self.judge_client.chat(\n                model=self.judge_model,\n                messages=[{\"role\": \"user\", \"content\": classification_prompt}],\n                options={\"temperature\": 0, \"num_predict\": 10},\n            )\n\n            result = response[\"message\"][\"content\"].strip().upper()\n\n            if verbose:\n                print(f\"[Router] LLM Judge: {result}\")\n\n            if \"COMPLEX\" in result:\n                return True, \"Complex task\"\n            else:\n                return False, \"Simple task\"\n\n        except Exception as e:\n            if verbose:\n                print(f\"[Router] LLM Judge failed: {e}, defaulting to local\")\n            return False, \"Judge failed, defaulting to local\"\n\n    def should_use_remote(self, prompt: str, verbose: bool = False) -> tuple[bool, str]:\n        \"\"\"\n        Determine if the prompt should be routed to remote SGLang.\n\n        Args:\n            prompt: User's input prompt\n            verbose: Print debug information\n\n        Returns:\n            Tuple of 
(should_use_remote, reason)\n        \"\"\"\n        return self._classify_with_llm(prompt, verbose)\n\n    def chat(\n        self,\n        prompt: str,\n        messages: Optional[list] = None,\n        verbose: bool = False,\n        force_local: bool = False,\n        force_remote: bool = False,\n    ) -> dict:\n        \"\"\"\n        Route the request and get response.\n\n        Args:\n            prompt: User's input (used if messages is None)\n            messages: Full message history (overrides prompt if provided)\n            verbose: Print routing decision\n            force_local: Force use of local model\n            force_remote: Force use of remote model\n\n        Returns:\n            Response dict with 'content', 'model', 'location', 'reason' keys\n        \"\"\"\n        # Build messages\n        if messages is None:\n            messages = [{\"role\": \"user\", \"content\": prompt}]\n            check_prompt = prompt\n        else:\n            # Use the last user message for routing decision\n            check_prompt = \"\"\n            for msg in reversed(messages):\n                if msg.get(\"role\") == \"user\":\n                    check_prompt = msg.get(\"content\", \"\")\n                    break\n\n        # Determine routing\n        if force_remote:\n            use_remote, reason = True, \"Forced remote\"\n        elif force_local:\n            use_remote, reason = False, \"Forced local\"\n        else:\n            use_remote, reason = self.should_use_remote(check_prompt, verbose)\n\n        if use_remote:\n            client = self.remote_client\n            model = self.remote_model\n            location = \"Remote SGLang\"\n        else:\n            client = self.local_client\n            model = self.local_model\n            location = \"Local Ollama\"\n\n        if verbose:\n            print(f\"[Router] -> {location} | Model: {model}\")\n\n        try:\n            response = client.chat(model=model, messages=messages)\n 
           return {\n                \"content\": response[\"message\"][\"content\"],\n                \"model\": model,\n                \"location\": location,\n                \"reason\": reason,\n            }\n        except Exception as e:\n            # Fallback to the other option\n            if verbose:\n                print(f\"[Router] {location} failed: {e}, falling back...\")\n\n            fallback_client = (\n                self.remote_client if not use_remote else self.local_client\n            )\n            fallback_model = self.remote_model if not use_remote else self.local_model\n            fallback_location = \"Remote SGLang\" if not use_remote else \"Local Ollama\"\n\n            response = fallback_client.chat(model=fallback_model, messages=messages)\n            return {\n                \"content\": response[\"message\"][\"content\"],\n                \"model\": fallback_model,\n                \"location\": fallback_location,\n                \"reason\": f\"Fallback from {location}\",\n            }\n\n    def chat_stream(\n        self,\n        prompt: str,\n        messages: Optional[list] = None,\n        verbose: bool = False,\n        force_local: bool = False,\n        force_remote: bool = False,\n    ):\n        \"\"\"\n        Route the request and stream response.\n\n        Yields:\n            Response chunks\n        \"\"\"\n        if messages is None:\n            messages = [{\"role\": \"user\", \"content\": prompt}]\n            check_prompt = prompt\n        else:\n            check_prompt = \"\"\n            for msg in reversed(messages):\n                if msg.get(\"role\") == \"user\":\n                    check_prompt = msg.get(\"content\", \"\")\n                    break\n\n        if force_remote:\n            use_remote, reason = True, \"Forced remote\"\n        elif force_local:\n            use_remote, reason = False, \"Forced local\"\n        else:\n            use_remote, reason = 
self.should_use_remote(check_prompt, verbose)\n\n        if use_remote:\n            client = self.remote_client\n            model = self.remote_model\n            location = \"Remote SGLang\"\n        else:\n            client = self.local_client\n            model = self.local_model\n            location = \"Local Ollama\"\n\n        if verbose:\n            print(f\"[Router] -> {location} | Model: {model}\")\n\n        for chunk in client.chat(model=model, messages=messages, stream=True):\n            yield chunk\n\n\ndef main():\n    \"\"\"Interactive demo of the smart router.\"\"\"\n    print(\"=\" * 60)\n    print(\"Smart Router: Local Ollama <-> Remote SGLang\")\n    print(\"=\" * 60)\n    print(\"\\nRouting strategy:\")\n    print(\"  LLM Judge classifies each request as SIMPLE or COMPLEX\")\n    print(\"  - SIMPLE tasks -> Local Ollama (fast)\")\n    print(\"  - COMPLEX tasks -> Remote SGLang (powerful)\")\n    print(\"\\nType 'quit' to exit\\n\")\n\n    router = SmartRouter(\n        local_host=\"http://localhost:11434\",\n        remote_host=\"http://localhost:30001\",\n        local_model=\"llama3.2\",\n        remote_model=\"Qwen/Qwen2.5-1.5B-Instruct\",\n    )\n\n    messages = []\n    while True:\n        try:\n            user_input = input(\"You: \").strip()\n            if user_input.lower() in [\"quit\", \"exit\", \"q\"]:\n                print(\"Goodbye!\")\n                break\n            if not user_input:\n                continue\n\n            messages.append({\"role\": \"user\", \"content\": user_input})\n\n            # Use streaming for real-time output\n            print(\"\\nAssistant: \", end=\"\", flush=True)\n            full_response = \"\"\n            for chunk in router.chat_stream(\n                prompt=user_input, messages=messages, verbose=True\n            ):\n                content = chunk.get(\"message\", {}).get(\"content\", \"\")\n                if content:\n                    print(content, end=\"\", 
flush=True)\n                    full_response += content\n            print(\"\\n\")\n\n            messages.append({\"role\": \"assistant\", \"content\": full_response})\n\n        except KeyboardInterrupt:\n            print(\"\\nGoodbye!\")\n            break\n        except Exception as e:\n            print(f\"Error: {e}\\n\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/encoding_dsv32.py",
    "content": "# Adapted from https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/encoding/encoding_dsv32.py\nimport copy\nimport json\nimport re\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\n\nclass DS32EncodingError(Exception):\n    pass\n\n\nTOOLS_SYSTEM_TEMPLATE = \"\"\"## Tools\nYou have access to a set of tools you can use to answer the user's question.\nYou can invoke functions by writing a \"<{dsml_token}function_calls>\" block like the following as part of your reply to the user:\n<{dsml_token}function_calls>\n<{dsml_token}invoke name=\"$FUNCTION_NAME\">\n<{dsml_token}parameter name=\"$PARAMETER_NAME\" string=\"true|false\">$PARAMETER_VALUE</{dsml_token}parameter>\n...\n</{dsml_token}invoke>\n<{dsml_token}invoke name=\"$FUNCTION_NAME2\">\n...\n</{dsml_token}invoke>\n</{dsml_token}function_calls>\nString and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The \"string\" attribute should be set to \"true\" for string type parameters and \"false\" for other types (numbers, booleans, arrays, objects).\nIf the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. 
Here is an example:\n<{dsml_token}function_calls>\n...\n</{dsml_token}function_calls>\n<function_results>\n...\n</function_results>\n{thinking_start_token}...thinking about results{thinking_end_token}\nHere are the functions available in JSONSchema format:\n<functions>\n{tool_schemas}\n</functions>\n\"\"\"\n\nbos_token: str = \"<｜begin▁of▁sentence｜>\"\neos_token: str = \"<｜end▁of▁sentence｜>\"\nthinking_start_token: str = \"<think>\"\nthinking_end_token: str = \"</think>\"\ndsml_token: str = \"｜DSML｜\"\nsystem_msg_template: str = \"{content}\"\nuser_msg_template: str = \"<｜User｜>{content}<｜Assistant｜>\"\nassistant_msg_template: str = \"{reasoning}{content}{tool_calls}<｜end▁of▁sentence｜>\"\nthinking_template = \"{reasoning_content}\"\n\nresponse_format_template: str = (\n    \"## Response Format:\\n\\nYou MUST strictly adhere to the following schema to reply:\\n{schema}\"\n)\ntool_call_template: str = (\n    '<{dsml_token}invoke name=\"{name}\">\\n{arguments}\\n</{dsml_token}invoke>'\n)\ntool_calls_template = (\n    \"<{dsml_token}function_calls>\\n{tool_calls}\\n</{dsml_token}function_calls>\"\n)\n\ntool_output_template: str = \"\\n<result>{content}</result>\"\n\n\ndef to_json(value: Any) -> str:\n    try:\n        return json.dumps(value, ensure_ascii=False)\n    except:\n        return json.dumps(value, ensure_ascii=True)\n\n\ndef tools_from_openai_format(tools):\n    return [tool[\"function\"] for tool in tools]\n\n\ndef tool_calls_from_openai_format(tool_calls):\n    return [\n        {\n            \"name\": tool_call[\"function\"][\"name\"],\n            \"arguments\": tool_call[\"function\"][\"arguments\"],\n        }\n        for tool_call in tool_calls\n    ]\n\n\ndef tool_calls_to_openai_format(tool_calls):\n    return [\n        {\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": tool_call[\"name\"],\n                \"arguments\": tool_call[\"arguments\"],\n            },\n        }\n        for tool_call in 
tool_calls\n    ]\n\n\ndef encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:\n    p_dsml_template = \"\"\"<{dsml_token}parameter name=\"{key}\" string=\"{is_str}\">{value}</{dsml_token}parameter>\"\"\"\n    P_dsml_strs = []\n\n    arguments = json.loads(tool_call[\"arguments\"])\n\n    for k, v in arguments.items():\n        p_dsml_str = p_dsml_template.format(\n            dsml_token=dsml_token,\n            key=k,\n            is_str=\"true\" if isinstance(v, str) else \"false\",\n            value=v if isinstance(v, str) else to_json(v),\n        )\n\n        P_dsml_strs.append(p_dsml_str)\n\n    return \"\\n\".join(P_dsml_strs)\n\n\ndef decode_dsml_to_arguments(\n    tool_name: str, tool_args: Dict[str, Tuple[str, str]]\n) -> Dict[str, str]:\n    def _decode_value(key: str, value: str, string: str):\n        if string == \"true\":\n            value = to_json(value)\n        return f\"{to_json(key)}: {value}\"\n\n    tool_args_json = (\n        \"{\"\n        + \", \".join(\n            [_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]\n        )\n        + \"}\"\n    )\n    return dict(name=tool_name, arguments=tool_args_json)\n\n\ndef render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:\n    tools_json = [to_json(t) for t in tools]\n\n    return TOOLS_SYSTEM_TEMPLATE.format(\n        tool_schemas=\"\\n\".join(tools_json),\n        dsml_token=dsml_token,\n        thinking_start_token=thinking_start_token,\n        thinking_end_token=thinking_end_token,\n    )\n\n\ndef find_last_user_index(messages: List[Dict[str, Any]]) -> int:\n    last_user_index = -1\n    for idx in range(len(messages) - 1, -1, -1):\n        if messages[idx].get(\"role\") in [\"user\", \"developer\"]:\n            last_user_index = idx\n            break\n    return last_user_index\n\n\ndef render_message(\n    index: int, messages: List[Dict[str, Any]], thinking_mode: str\n) -> str:\n    if not (0 <= index < len(messages)):\n        
raise DS32EncodingError(\n            f\"Index {index} out of range for messages list of length {len(messages)}\"\n        )\n    if thinking_mode not in [\"chat\", \"thinking\"]:\n        raise DS32EncodingError(f\"Invalid thinking_mode `{thinking_mode}`\")\n\n    prompt = \"\"\n    msg = messages[index]\n    last_user_idx = find_last_user_index(messages)\n\n    role = msg.get(\"role\")\n    content = msg.get(\"content\")\n    tools = msg.get(\"tools\")\n    response_format = msg.get(\"response_format\")\n    tool_calls = msg.get(\"tool_calls\")\n    reasoning_content = msg.get(\"reasoning_content\")\n\n    if tools:\n        tools = tools_from_openai_format(tools)\n    if tool_calls:\n        tool_calls = tool_calls_from_openai_format(tool_calls)\n\n    if role == \"system\":\n        prompt += system_msg_template.format(content=content or \"\")\n        if tools:\n            prompt += \"\\n\\n\" + render_tools(tools)\n\n        if response_format:\n            prompt += \"\\n\\n\" + response_format_template.format(\n                schema=to_json(response_format)\n            )\n\n    elif role == \"developer\":\n        if not content:\n            raise DS32EncodingError(f\"Invalid message for role `{role}`: {msg}\")\n        content_developer = \"\"\n        if tools:\n            content_developer += \"\\n\\n\" + render_tools(tools)\n\n        if response_format:\n            content_developer += \"\\n\\n\" + response_format_template.format(\n                schema=to_json(response_format)\n            )\n\n        content_developer += \"\\n\\n# The user's message is: {}\".format(content)\n\n        prompt += user_msg_template.format(content=content_developer)\n        if index == last_user_idx and thinking_mode == \"thinking\":\n            prompt += thinking_start_token\n        else:\n            prompt += thinking_end_token\n\n    elif role == \"user\":\n        prompt += user_msg_template.format(content=content)\n\n        if index == last_user_idx and 
thinking_mode == \"thinking\":\n            prompt += thinking_start_token\n        else:\n            prompt += thinking_end_token\n\n    elif role == \"tool\":\n        prev_assistant_idx = index - 1\n        assistant_msg = messages[prev_assistant_idx]\n        while prev_assistant_idx >= 0 and assistant_msg.get(\"role\") == \"tool\":\n            prev_assistant_idx -= 1\n            assistant_msg = messages[prev_assistant_idx]\n\n        if not (\n            index == 0\n            or (prev_assistant_idx >= 0 and assistant_msg.get(\"role\") == \"assistant\")\n        ):\n            raise DS32EncodingError(f\"Invalid messages at {index}:\\n{assistant_msg}\")\n\n        tool_call_order = index - prev_assistant_idx\n        assistant_tool_calls = assistant_msg.get(\"tool_calls\")\n        if not (assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order):\n            raise DS32EncodingError(\"No tool calls but found tool output\")\n\n        if tool_call_order == 1:\n            prompt += \"\\n\\n<function_results>\"\n\n        prompt += tool_output_template.format(content=content)\n\n        if tool_call_order == len(assistant_tool_calls):\n            prompt += \"\\n</function_results>\"\n\n            if index >= last_user_idx and thinking_mode == \"thinking\":\n                prompt += \"\\n\\n\" + thinking_start_token\n            else:\n                prompt += \"\\n\\n\" + thinking_end_token\n\n    elif role == \"assistant\":\n        prev_assistant_idx = index\n        thinking_part = \"\"\n\n        tool_calls_content = \"\"\n        if tool_calls:\n            tool_calls = [\n                tool_call_template.format(\n                    dsml_token=dsml_token,\n                    name=tool_call.get(\"name\"),\n                    arguments=encode_arguments_to_dsml(tool_call),\n                )\n                for tool_call in tool_calls\n            ]\n            tool_calls_content += \"\\n\\n\" + tool_calls_template.format(\n       
         dsml_token=dsml_token, tool_calls=\"\\n\".join(tool_calls)\n            )\n\n        summary_content = content or \"\"\n\n        if thinking_mode == \"thinking\" and index > last_user_idx:\n            if not (reasoning_content or tool_calls):\n                raise DS32EncodingError(\n                    f\"ThinkingMode: {thinking_mode}, invalid message without reasoning_content/tool_calls `{msg}` after last user message\"\n                )\n            thinking_part = (\n                thinking_template.format(reasoning_content=reasoning_content or \"\")\n                + thinking_end_token\n            )\n\n        prompt += assistant_msg_template.format(\n            reasoning=thinking_part,\n            content=summary_content,\n            tool_calls=tool_calls_content,\n        )\n    else:\n        raise NotImplementedError(f\"Unknown role: {role}\")\n\n    return prompt\n\n\ndef drop_thinking_messages(\n    messages: List[Dict[str, Any]], last_user_idx: Optional[int] = None\n) -> List[Dict[str, Any]]:\n    messages_wo_thinking: List[Dict[str, Any]] = []\n    last_user_idx = (\n        find_last_user_index(messages) if last_user_idx is None else last_user_idx\n    )\n    for idx, msg in enumerate(messages):\n        role = msg.get(\"role\")\n        if role in [\"user\", \"system\", \"tool\"] or idx >= last_user_idx:\n            messages_wo_thinking.append(msg)\n            continue\n\n        elif role == \"assistant\":\n            msg_wo_thinking = copy.copy(msg)\n            msg_wo_thinking.pop(\"reasoning_content\", None)\n            messages_wo_thinking.append(msg_wo_thinking)\n\n    return messages_wo_thinking\n\n\ndef encode_messages(\n    messages: List[Dict[str, Any]],\n    thinking_mode: str,\n    context: Optional[List[Dict[str, Any]]] = None,\n    drop_thinking: bool = True,\n    add_default_bos_token: bool = True,\n) -> str:\n    context = context if context else []\n    full_messages = context + messages\n\n    prompt = 
bos_token if add_default_bos_token and len(context) == 0 else \"\"\n\n    if thinking_mode == \"thinking\" and drop_thinking:\n        full_messages = drop_thinking_messages(full_messages)\n\n    for idx in range(len(messages)):\n        prompt += render_message(\n            idx + len(context), full_messages, thinking_mode=thinking_mode\n        )\n\n    return prompt\n\n\ndef _read_until_stop(\n    index: int, text: str, stop: List[str]\n) -> Tuple[int, str, Optional[str]]:\n    min_pos = len(text)\n    matched_stop = None\n\n    for s in stop:\n        pos = text.find(s, index)\n        if pos != -1 and pos < min_pos:\n            min_pos = pos\n            matched_stop = s\n\n    if matched_stop:\n        content = text[index:min_pos]\n        return min_pos + len(matched_stop), content, matched_stop\n    else:\n        content = text[index:]\n        return len(text), content, None\n\n\ndef parse_tool_calls(index: int, text: str):\n    tool_calls: List[Dict[str, Any]] = []\n    stop_token = None\n    tool_calls_end_token = f\"</{dsml_token}function_calls>\"\n\n    while index < len(text):\n        index, _, stop_token = _read_until_stop(\n            index, text, [f\"<{dsml_token}invoke\", tool_calls_end_token]\n        )\n        if _ != \">\\n\":\n            raise DS32EncodingError(\"Tool call format error\")\n\n        if stop_token == tool_calls_end_token:\n            break\n\n        if stop_token is None:\n            raise DS32EncodingError(\"Missing special token\")\n\n        index, tool_name_content, stop_token = _read_until_stop(\n            index, text, [f\"<{dsml_token}parameter\", f\"</{dsml_token}invoke\"]\n        )\n\n        p_tool_name = re.findall(\n            r'^\\s*name=\"(.*?)\">\\n$', tool_name_content, flags=re.DOTALL\n        )\n        if len(p_tool_name) != 1:\n            raise DS32EncodingError(\"Tool name format error\")\n        tool_name = p_tool_name[0]\n\n        tool_args: Dict[str, Tuple[str, str]] = {}\n        while 
stop_token == f\"<{dsml_token}parameter\":\n            index, param_content, stop_token = _read_until_stop(\n                index, text, [f\"/{dsml_token}parameter\"]\n            )\n\n            param_kv = re.findall(\n                r'^ name=\"(.*?)\" string=\"(true|false)\">(.*?)<$',\n                param_content,\n                flags=re.DOTALL,\n            )\n            if len(param_kv) != 1:\n                raise DS32EncodingError(\"Parameter format error\")\n            param_name, string, param_value = param_kv[0]\n\n            if param_name in tool_args:\n                raise DS32EncodingError(\"Duplicate parameter name\")\n            tool_args[param_name] = (param_value, string)\n\n            index, content, stop_token = _read_until_stop(\n                index, text, [f\"<{dsml_token}parameter\", f\"</{dsml_token}invoke\"]\n            )\n            if content != \">\\n\":\n                raise DS32EncodingError(\"Parameter format error\")\n\n        tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)\n        tool_calls.append(tool_call)\n\n    return index, stop_token, tool_calls\n\n\n# NOTE: This function is designed to parse only correctly formatted string and will not attempt to correct malformed output that may be generated by the model.\ndef parse_message_from_completion_text(text: str, thinking_mode: str):\n    summary_content, reasoning_content, tool_calls = \"\", \"\", []\n    index, stop_token = 0, None\n    tool_calls_start_token = f\"\\n\\n<{dsml_token}function_calls\"\n\n    is_thinking, is_tool_calling = thinking_mode == \"thinking\", False\n\n    if is_thinking:\n        index, content_delta, stop_token = _read_until_stop(\n            index, text, [thinking_end_token, tool_calls_start_token]\n        )\n        reasoning_content = content_delta\n        if stop_token != thinking_end_token:\n            raise DS32EncodingError(\"Invalid thinking format\")\n\n    index, content_delta, stop_token = 
_read_until_stop(\n        index, text, [eos_token, tool_calls_start_token]\n    )\n    summary_content = content_delta\n    if stop_token == tool_calls_start_token:\n        is_tool_calling = True\n    else:\n        if stop_token != eos_token:\n            raise DS32EncodingError(\"Invalid summary format\")\n\n    if is_tool_calling:\n        index, stop_token, tool_calls = parse_tool_calls(index, text)\n\n        index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])\n        if tool_ends_text:\n            raise DS32EncodingError(\"Unexpected content after tool calls\")\n\n    if not (len(text) == index and stop_token in [eos_token, None]):\n        raise DS32EncodingError(\"Unexpected content at end\")\n\n    for sp_token in [\n        bos_token,\n        eos_token,\n        thinking_start_token,\n        thinking_end_token,\n        dsml_token,\n    ]:\n        if sp_token in summary_content or sp_token in reasoning_content:\n            raise DS32EncodingError(\"Unexpected special token in content\")\n\n    return {\n        \"role\": \"assistant\",\n        \"content\": summary_content,\n        \"reasoning_content\": reasoning_content,\n        \"tool_calls\": tool_calls_to_openai_format(tool_calls),\n    }\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/protocol.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Pydantic models for OpenAI API protocol\"\"\"\n\nimport logging\nimport time\nimport uuid\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypeAlias, Union\n\nfrom openai.types.responses import (\n    ResponseFunctionToolCall,\n    ResponseInputItemParam,\n    ResponseOutputItem,\n    ResponseOutputMessage,\n    ResponseOutputText,\n    ResponseReasoningItem,\n)\nfrom openai.types.responses.response import ToolChoice\nfrom openai.types.responses.tool import Tool\nfrom pydantic import (\n    BaseModel,\n    Field,\n    field_validator,\n    model_serializer,\n    model_validator,\n)\nfrom typing_extensions import Literal\n\ntry:\n    from xgrammar import StructuralTag\nexcept:\n    StructuralTag = Any\n\nfrom sglang.utils import convert_json_schema_to_str\n\nlogger = logging.getLogger(__name__)\n\nDEFAULT_MODEL_NAME = \"default\"\n\n\nclass ModelCard(BaseModel):\n    \"\"\"Model cards.\"\"\"\n\n    id: str\n    object: str = \"model\"\n    created: int = Field(default_factory=lambda: int(time.time()))\n    owned_by: str = \"sglang\"\n    root: Optional[str] = None\n    parent: Optional[str] = None\n    max_model_len: Optional[int] = None\n\n\nclass ModelList(BaseModel):\n    \"\"\"Model list consists of model cards.\"\"\"\n\n    
object: str = \"list\"\n    data: List[ModelCard] = Field(default_factory=list)\n\n\nclass ErrorResponse(BaseModel):\n    object: str = \"error\"\n    message: str\n    type: str\n    param: Optional[str] = None\n    code: int\n\n\nclass LogProbs(BaseModel):\n    text_offset: List[int] = Field(default_factory=list)\n    token_logprobs: List[Optional[float]] = Field(default_factory=list)\n    tokens: List[str] = Field(default_factory=list)\n    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)\n\n\nclass TopLogprob(BaseModel):\n    token: str\n    bytes: List[int]\n    logprob: float\n\n\nclass ChatCompletionTokenLogprob(BaseModel):\n    token: str\n    bytes: List[int]\n    logprob: float\n    top_logprobs: List[TopLogprob]\n\n\nclass ChoiceLogprobs(BaseModel):\n    # build for v1/chat/completions response\n    content: List[ChatCompletionTokenLogprob]\n\n\nclass CachedTokensDetails(BaseModel):\n    \"\"\"Detailed breakdown of cached tokens by cache source.\"\"\"\n\n    device: int = 0  # Tokens from device cache (GPU)\n    host: int = 0  # Tokens from host cache (CPU memory)\n    # L3 storage fields are only present when storage backend is enabled\n    storage: Optional[int] = None  # Tokens from L3 storage backend\n    storage_backend: Optional[str] = None  # Type of storage backend used\n\n    @model_serializer(mode=\"wrap\")\n    def _serialize(self, handler):\n        data = handler(self)\n        # Remove None fields so they don't appear in response when L3 is disabled\n        if self.storage is None:\n            data.pop(\"storage\", None)\n        if self.storage_backend is None:\n            data.pop(\"storage_backend\", None)\n        return data\n\n\nclass PromptTokensDetails(BaseModel):\n    \"\"\"Details about prompt tokens.\"\"\"\n\n    cached_tokens: int = 0\n\n\nclass UsageInfo(BaseModel):\n    prompt_tokens: int = 0\n    total_tokens: int = 0\n    completion_tokens: Optional[int] = 0\n    # Used to return cached tokens 
info when --enable-cache-report is set\n    prompt_tokens_details: Optional[PromptTokensDetails] = None\n    reasoning_tokens: Optional[int] = 0\n\n\nclass StreamOptions(BaseModel):\n    include_usage: Optional[bool] = False\n    continuous_usage_stats: Optional[bool] = False\n\n\nclass JsonSchemaResponseFormat(BaseModel):\n    name: str\n    description: Optional[str] = None\n    # use alias to workaround pydantic conflict\n    schema_: Optional[Dict[str, object]] = Field(alias=\"schema\", default=None)\n    strict: Optional[bool] = False\n\n\nclass ResponseFormat(BaseModel):\n    type: Literal[\"text\", \"json_object\", \"json_schema\"]\n    json_schema: Optional[JsonSchemaResponseFormat] = None\n\n\nclass StructuresResponseFormat(BaseModel):\n    begin: str\n    schema_: Optional[Dict[str, object]] = Field(alias=\"schema\", default=None)\n    end: str\n\n\n# NOTE(dark): keep this for backward compatibility\nclass LegacyStructuralTagResponseFormat(BaseModel):\n    type: Literal[\"structural_tag\"]\n    structures: List[StructuresResponseFormat]\n    triggers: List[str]\n\n\nStructuralTagResponseFormat: TypeAlias = Union[\n    LegacyStructuralTagResponseFormat, StructuralTag\n]\n\nToolCallConstraint: TypeAlias = Union[\n    Tuple[Literal[\"structural_tag\"], StructuralTagResponseFormat],\n    Tuple[Literal[\"json_schema\"], Any],  # json_schema can be dict/str/None\n]\n\n\nclass FileRequest(BaseModel):\n    # https://platform.openai.com/docs/api-reference/files/create\n    file: bytes  # The File object (not file name) to be uploaded\n    purpose: str = (\n        \"batch\"  # The intended purpose of the uploaded file, default is \"batch\"\n    )\n\n\nclass FileResponse(BaseModel):\n    id: str\n    object: str = \"file\"\n    bytes: int\n    created_at: int\n    filename: str\n    purpose: str\n\n\nclass FileDeleteResponse(BaseModel):\n    id: str\n    object: str = \"file\"\n    deleted: bool\n\n\nclass BatchRequest(BaseModel):\n    input_file_id: (\n        str 
 # The ID of an uploaded file that contains requests for the new batch\n    )\n    endpoint: str  # The endpoint to be used for all requests in the batch\n    completion_window: str  # The time frame within which the batch should be processed\n    metadata: Optional[dict] = None  # Optional custom metadata for the batch\n\n\nclass BatchResponse(BaseModel):\n    id: str\n    object: str = \"batch\"\n    endpoint: str\n    errors: Optional[dict] = None\n    input_file_id: str\n    completion_window: str\n    status: str = \"validating\"\n    output_file_id: Optional[str] = None\n    error_file_id: Optional[str] = None\n    created_at: int\n    in_progress_at: Optional[int] = None\n    expires_at: Optional[int] = None\n    finalizing_at: Optional[int] = None\n    completed_at: Optional[int] = None\n    failed_at: Optional[int] = None\n    expired_at: Optional[int] = None\n    cancelling_at: Optional[int] = None\n    cancelled_at: Optional[int] = None\n    request_counts: Optional[dict] = None\n    metadata: Optional[dict] = None\n\n\ndef _migrate_deprecated_dp_rank(values: dict) -> dict:\n    if isinstance(values, dict) and values.get(\"data_parallel_rank\") is not None:\n        import warnings\n\n        warnings.warn(\n            \"'data_parallel_rank' is deprecated, use 'routed_dp_rank' instead.\",\n            DeprecationWarning,\n            stacklevel=2,\n        )\n        if values.get(\"routed_dp_rank\") is None:\n            values[\"routed_dp_rank\"] = values[\"data_parallel_rank\"]\n    return values\n\n\nclass CompletionRequest(BaseModel):\n    # Ordered by official OpenAI API documentation\n    # https://platform.openai.com/docs/api-reference/completions/create\n    model: str = Field(\n        default=DEFAULT_MODEL_NAME,\n        description=\"Model name. 
Supports LoRA adapters via 'base-model:adapter-name' syntax.\",\n    )\n    prompt: Union[List[int], List[List[int]], str, List[str]]\n    best_of: Optional[int] = None\n    echo: bool = False\n    frequency_penalty: float = 0.0\n    logit_bias: Optional[Dict[str, float]] = None\n    logprobs: Optional[int] = None\n    max_tokens: int = 16\n    n: int = 1\n    presence_penalty: float = 0.0\n    seed: Optional[int] = None\n    stop: Optional[Union[str, List[str]]] = None\n    stream: bool = False\n    stream_options: Optional[StreamOptions] = None\n    suffix: Optional[str] = None\n    temperature: float = 1.0\n    top_p: float = 1.0\n    user: Optional[str] = None\n    return_hidden_states: bool = False\n    return_routed_experts: bool = False\n    return_cached_tokens_details: bool = False\n\n    # Extra parameters for SRT backend only and will be ignored by OpenAI models.\n    top_k: int = -1\n    min_p: float = 0.0\n    min_tokens: int = 0\n    json_schema: Optional[str] = None\n    regex: Optional[str] = None\n    ebnf: Optional[str] = None\n    repetition_penalty: float = 1.0\n    stop_token_ids: Optional[List[int]] = None\n    stop_regex: Optional[Union[str, List[str]]] = None\n    no_stop_trim: bool = False\n    ignore_eos: bool = False\n    skip_special_tokens: bool = True\n    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None\n    session_params: Optional[Dict] = None\n    response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None\n    custom_params: Optional[Dict] = None\n    custom_logit_processor: Optional[str] = None\n\n    # For PD disaggregation\n    bootstrap_host: Optional[Union[List[str], str]] = None\n    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None\n    bootstrap_room: Optional[Union[List[int], int]] = None\n\n    # For DP routing — external router assigns a specific DP worker\n    routed_dp_rank: Optional[int] = None\n    # For PD disagg — hint telling decode which prefill DP 
worker has the KV cache\n    disagg_prefill_dp_rank: Optional[int] = None\n    # Deprecated: use routed_dp_rank instead\n    data_parallel_rank: Optional[int] = None\n\n    # For request id\n    rid: Optional[Union[List[str], str]] = None\n    # Extra key for classifying the request (e.g. cache_salt)\n    extra_key: Optional[Union[List[str], str]] = None\n    # Cache salt for request caching\n    cache_salt: Optional[Union[List[str], str]] = None\n    # Priority for the request\n    priority: Optional[int] = None\n\n    # For custom metric labels\n    custom_labels: Optional[Dict[str, str]] = None\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def _handle_deprecated_dp_rank(cls, values):\n        return _migrate_deprecated_dp_rank(values)\n\n    @field_validator(\"max_tokens\")\n    @classmethod\n    def validate_max_tokens_positive(cls, v):\n        if v is not None and v <= 0:\n            raise ValueError(\"max_tokens must be positive\")\n        return v\n\n\nclass SglExt(BaseModel):\n    \"\"\"SGLang extension fields for OpenAI-compatible responses.\n\n    Future SGLang-specific extensions to OpenAI-compatible response objects\n    should be added as fields here rather than directly on the choice object.\n    \"\"\"\n\n    routed_experts: Optional[str] = None\n    cached_tokens_details: Optional[CachedTokensDetails] = None\n\n    @model_serializer(mode=\"wrap\")\n    def _serialize(self, handler):\n        data = handler(self)\n        # Remove None fields to keep response clean\n        return {k: v for k, v in data.items() if v is not None}\n\n\nclass CompletionResponseChoice(BaseModel):\n    index: int\n    text: str\n    logprobs: Optional[LogProbs] = None\n    finish_reason: Optional[Literal[\"stop\", \"length\", \"content_filter\", \"abort\"]] = None\n    matched_stop: Union[None, int, str] = None\n    hidden_states: Optional[object] = None\n\n    @model_serializer(mode=\"wrap\")\n    def _serialize(self, handler):\n        data = 
handler(self)\n        if self.hidden_states is None:\n            data.pop(\"hidden_states\", None)\n        return data\n\n\nclass CompletionResponse(BaseModel):\n    id: str\n    object: str = \"text_completion\"\n    created: int = Field(default_factory=lambda: int(time.time()))\n    model: str\n    choices: List[CompletionResponseChoice]\n    usage: UsageInfo\n    metadata: Optional[Dict[str, Any]] = None\n    sglext: Optional[SglExt] = None\n\n    @model_serializer(mode=\"wrap\")\n    def _serialize(self, handler):\n        data = handler(self)\n        if self.sglext is None:\n            data.pop(\"sglext\", None)\n        return data\n\n\nclass CompletionResponseStreamChoice(BaseModel):\n    index: int\n    text: str\n    logprobs: Optional[LogProbs] = None\n    finish_reason: Optional[Literal[\"stop\", \"length\", \"content_filter\", \"abort\"]] = None\n    matched_stop: Union[None, int, str] = None\n    hidden_states: Optional[object] = None\n\n    @model_serializer(mode=\"wrap\")\n    def _serialize(self, handler):\n        data = handler(self)\n        if self.hidden_states is None:\n            data.pop(\"hidden_states\", None)\n        return data\n\n\nclass CompletionStreamResponse(BaseModel):\n    id: str\n    object: str = \"text_completion\"\n    created: int = Field(default_factory=lambda: int(time.time()))\n    model: str\n    choices: List[CompletionResponseStreamChoice]\n    usage: Optional[UsageInfo] = None\n    sglext: Optional[SglExt] = None\n\n    @model_serializer(mode=\"wrap\")\n    def _serialize(self, handler):\n        data = handler(self)\n        if self.sglext is None:\n            data.pop(\"sglext\", None)\n        return data\n\n\nclass ChatCompletionMessageContentTextPart(BaseModel):\n    type: Literal[\"text\"]\n    text: str\n\n\nclass ChatCompletionMessageContentImageURL(BaseModel):\n    url: str\n    detail: Optional[Literal[\"auto\", \"low\", \"high\"]] = \"auto\"\n    max_dynamic_patch: Optional[int] = None\n    
min_dynamic_patch: Optional[int] = None\n\n\nclass ChatCompletionMessageContentVideoURL(BaseModel):\n    url: str\n    max_dynamic_patch: Optional[int] = None\n    min_dynamic_patch: Optional[int] = None\n\n\nclass ChatCompletionMessageContentAudioURL(BaseModel):\n    url: str\n\n\nclass ChatCompletionMessageContentImagePart(BaseModel):\n    type: Literal[\"image_url\"]\n    image_url: ChatCompletionMessageContentImageURL\n    modalities: Optional[Literal[\"image\", \"multi-images\", \"video\"]] = \"image\"\n\n\nclass ChatCompletionMessageContentVideoPart(BaseModel):\n    type: Literal[\"video_url\"]\n    video_url: ChatCompletionMessageContentVideoURL\n\n\nclass ChatCompletionMessageContentAudioPart(BaseModel):\n    type: Literal[\"audio_url\"]\n    audio_url: ChatCompletionMessageContentAudioURL\n\n\nChatCompletionMessageContentPart = Union[\n    ChatCompletionMessageContentTextPart,\n    ChatCompletionMessageContentImagePart,\n    ChatCompletionMessageContentVideoPart,\n    ChatCompletionMessageContentAudioPart,\n]\n\n# Rerank content types for multimodal reranking (e.g., Qwen3-VL-Reranker)\n# Can be a simple string (text-only) or a list of multimodal content parts\nRerankContentPart = Union[\n    ChatCompletionMessageContentTextPart,\n    ChatCompletionMessageContentImagePart,\n    ChatCompletionMessageContentVideoPart,\n]\nRerankContent = Union[str, List[RerankContentPart]]\n\n\nclass FunctionResponse(BaseModel):\n    \"\"\"Function response.\"\"\"\n\n    name: Optional[str] = None\n    arguments: Optional[str | Dict[str, Any]] = None\n\n\nclass ToolCall(BaseModel):\n    \"\"\"Tool call response.\"\"\"\n\n    id: Optional[str] = None\n    index: Optional[int] = None\n    type: Literal[\"function\"] = \"function\"\n    function: FunctionResponse\n\n\nclass ChatCompletionMessageGenericParam(BaseModel):\n    role: Literal[\"system\", \"assistant\", \"tool\", \"function\", \"developer\"]\n    content: Union[str, List[ChatCompletionMessageContentPart], None] = 
Field(\n        default=None\n    )\n    tool_call_id: Optional[str] = None\n    name: Optional[str] = None\n    reasoning_content: Optional[str] = None\n    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])\n    tools: Optional[List[Tool]] = Field(default=None, examples=[None])\n\n    @field_validator(\"role\", mode=\"before\")\n    @classmethod\n    def _normalize_role(cls, v):\n        if isinstance(v, str):\n            v_lower = v.lower()\n            if v_lower not in {\"system\", \"assistant\", \"tool\", \"function\", \"developer\"}:\n                raise ValueError(\n                    \"'role' must be one of 'system', 'developer', 'assistant', 'tool', or 'function' (case-insensitive).\"\n                )\n            return v_lower\n        raise ValueError(\"'role' must be a string\")\n\n\nclass ChatCompletionMessageUserParam(BaseModel):\n    role: Literal[\"user\"]\n    content: Union[str, List[ChatCompletionMessageContentPart]]\n\n\nChatCompletionMessageParam = Union[\n    ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam\n]\n\n\nclass Function(BaseModel):\n    \"\"\"Function descriptions.\"\"\"\n\n    description: Optional[str] = Field(default=None, examples=[None])\n    name: str\n    parameters: Optional[object] = None\n    strict: bool = False\n\n\nclass Tool(BaseModel):\n    \"\"\"Function wrapper.\"\"\"\n\n    type: str = Field(default=\"function\", examples=[\"function\"])\n    function: Function\n\n\nclass ToolChoiceFuncName(BaseModel):\n    \"\"\"The name of tool choice function.\"\"\"\n\n    name: Optional[str] = None\n\n\nclass ToolChoice(BaseModel):\n    \"\"\"The tool choice definition.\"\"\"\n\n    function: ToolChoiceFuncName\n    type: Literal[\"function\"] = Field(default=\"function\", examples=[\"function\"])\n\n\nclass ChatCompletionRequest(BaseModel):\n    # Ordered by official OpenAI API documentation\n    # https://platform.openai.com/docs/api-reference/chat/create\n    messages: 
List[ChatCompletionMessageParam]\n    model: str = Field(\n        default=DEFAULT_MODEL_NAME,\n        description=\"Model name. Supports LoRA adapters via 'base-model:adapter-name' syntax.\",\n    )\n    frequency_penalty: float = 0.0\n    logit_bias: Optional[Dict[str, float]] = None\n    logprobs: bool = False\n    top_logprobs: Optional[int] = None\n    max_tokens: Optional[int] = Field(\n        default=None,\n        deprecated=\"max_tokens is deprecated in favor of the max_completion_tokens field\",\n        description=\"The maximum number of tokens that can be generated in the chat completion. \",\n    )\n    max_completion_tokens: Optional[int] = Field(\n        default=None,\n        description=\"The maximum number of completion tokens for a chat completion request, \"\n        \"including visible output tokens and reasoning tokens. Input tokens are not included. \",\n    )\n    n: int = 1\n    presence_penalty: float = 0.0\n    response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None\n    seed: Optional[int] = None\n    stop: Optional[Union[str, List[str]]] = None\n    stream: bool = False\n    stream_options: Optional[StreamOptions] = None\n    temperature: Optional[float] = None\n    top_p: Optional[float] = None\n    user: Optional[str] = None\n    tools: Optional[List[Tool]] = Field(default=None, examples=[None])\n    tool_choice: Union[ToolChoice, Literal[\"auto\", \"required\", \"none\"]] = Field(\n        default=\"auto\", examples=[\"none\"]\n    )  # noqa\n    return_hidden_states: bool = False\n    return_routed_experts: bool = False\n    return_cached_tokens_details: bool = False\n    reasoning_effort: Optional[Literal[\"none\", \"low\", \"medium\", \"high\"]] = Field(\n        default=None,\n        description=\"Constrains effort on reasoning for reasoning models. \"\n        \"'none' disables reasoning entirely, 'low' is the least effort, 'high' is the most effort. 
\"\n        \"Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning \"\n        \"in a response. 'none' defaults thinking and enable_thinking to false in \"\n        \"chat_template_kwargs (unless explicitly overridden). Not supported in the harmony path.\",\n    )\n\n    # Extra parameters for SRT backend only and will be ignored by OpenAI models.\n    top_k: Optional[int] = None\n    min_p: Optional[float] = None\n    min_tokens: int = 0\n    regex: Optional[str] = None\n    ebnf: Optional[str] = None\n    repetition_penalty: Optional[float] = None\n    stop_token_ids: Optional[List[int]] = None\n    stop_regex: Optional[Union[str, List[str]]] = None\n    no_stop_trim: bool = False\n    ignore_eos: bool = False\n    continue_final_message: bool = False\n    skip_special_tokens: bool = True\n    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None\n    session_params: Optional[Dict] = None\n    separate_reasoning: bool = True\n    stream_reasoning: bool = True\n    chat_template_kwargs: Optional[Dict] = None\n\n    # SGLang multimodal tiling controls (extensions)\n    max_dynamic_patch: Optional[int] = None\n    min_dynamic_patch: Optional[int] = None\n\n    # Custom logit processor for advanced sampling control\n    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None\n    custom_params: Optional[Dict] = None\n\n    # For request id\n    rid: Optional[Union[List[str], str]] = None\n    # Extra key for classifying the request (e.g. 
cache_salt)\n    extra_key: Optional[Union[List[str], str]] = None\n    # Cache salt for request caching\n    cache_salt: Optional[Union[List[str], str]] = None\n    # Priority for the request\n    priority: Optional[int] = None\n\n    # For PD disaggregation\n    bootstrap_host: Optional[Union[List[str], str]] = None\n    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None\n    bootstrap_room: Optional[Union[List[int], int]] = None\n\n    # For DP routing — external router assigns a specific DP worker\n    routed_dp_rank: Optional[int] = None\n    # For PD disagg — hint telling decode which prefill DP worker has the KV cache\n    disagg_prefill_dp_rank: Optional[int] = None\n    # Deprecated: use routed_dp_rank instead\n    data_parallel_rank: Optional[int] = None\n\n    # OpenAI/SGLang default sampling parameters\n    _DEFAULT_SAMPLING_PARAMS = {\n        \"temperature\": 1.0,\n        \"top_p\": 1.0,\n        \"top_k\": -1,\n        \"min_p\": 0.0,\n        \"repetition_penalty\": 1.0,\n    }\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def _handle_deprecated_dp_rank(cls, values):\n        return _migrate_deprecated_dp_rank(values)\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def set_tool_choice_default(cls, values):\n        if values.get(\"tool_choice\") is None:\n            if values.get(\"tools\") is None:\n                values[\"tool_choice\"] = \"none\"\n            else:\n                values[\"tool_choice\"] = \"auto\"\n        return values\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def normalize_reasoning_inputs(cls, values: Dict):\n        r = values.get(\"reasoning\")\n\n        if r is not None and isinstance(r, dict):\n            effort = r.get(\"effort\") or r.get(\"reasoning_effort\")\n            if effort in {\"none\", \"low\", \"medium\", \"high\"}:\n                values[\"reasoning_effort\"] = effort\n\n            enabled = (\n                r.get(\"enabled\")\n 
               if r.get(\"enabled\") is not None\n                else r.get(\"enable\", False)\n            )\n            if isinstance(enabled, str):\n                enabled = enabled.strip().lower() in {\"1\", \"true\", \"yes\", \"y\", \"on\"}\n            if enabled:\n                ctk = values.get(\"chat_template_kwargs\")\n                if not isinstance(ctk, dict):\n                    ctk = {}\n                ctk.setdefault(\"thinking\", True)\n                values[\"chat_template_kwargs\"] = ctk\n\n        if values.get(\"reasoning_effort\") == \"none\":\n            ctk = values.get(\"chat_template_kwargs\")\n            if not isinstance(ctk, dict):\n                ctk = {}\n            # different models check different keys:\n            # - \"thinking\" for deepseek-v3, kimi_k2\n            # - \"enable_thinking\" for qwen3, glm45, nemotron_3, interns1\n            ctk.setdefault(\"thinking\", False)\n            ctk.setdefault(\"enable_thinking\", False)\n            values[\"chat_template_kwargs\"] = ctk\n\n        return values\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def set_json_schema(cls, values):\n        response_format = values.get(\"response_format\")\n        if not response_format:\n            return values\n\n        if response_format.get(\"type\") != \"json_schema\":\n            return values\n\n        schema = response_format.pop(\"schema\", None)\n        json_schema = response_format.get(\"json_schema\")\n\n        if json_schema:\n            return values\n\n        if schema:\n            name_ = schema.get(\"title\", \"Schema\")\n            strict_ = False\n            if \"properties\" in schema and \"strict\" in schema[\"properties\"]:\n                item = schema[\"properties\"].pop(\"strict\", None)\n                if item and item.get(\"default\", False):\n                    strict_ = True\n\n            response_format[\"json_schema\"] = {\n                \"name\": name_,\n         
       \"schema\": schema,\n                \"strict\": strict_,\n            }\n\n        return values\n\n    def to_sampling_params(\n        self,\n        stop: List[str],\n        model_generation_config: Dict[str, Any],\n        tool_call_constraint: Optional[ToolCallConstraint] = None,\n    ) -> Dict[str, Any]:\n        \"\"\"\n        Convert request to sampling parameters.\n        Priority: user value > model generation_config > OpenAI defaults\n        \"\"\"\n\n        def get_param(param_name: str):\n            value = getattr(self, param_name)\n            if value is None:\n                return model_generation_config.get(\n                    param_name, self._DEFAULT_SAMPLING_PARAMS[param_name]\n                )\n            return value\n\n        # add per user request\n        spaces_between_special_tokens = (\n            True\n            if self.chat_template_kwargs is None\n            else self.chat_template_kwargs.get(\"spaces_between_special_tokens\", True)\n        )\n\n        sampling_params = {\n            \"temperature\": get_param(\"temperature\"),\n            \"max_new_tokens\": self.max_completion_tokens or self.max_tokens,\n            \"min_new_tokens\": self.min_tokens,\n            \"stop\": stop,\n            \"stop_token_ids\": self.stop_token_ids,\n            \"stop_regex\": self.stop_regex,\n            \"top_p\": get_param(\"top_p\"),\n            \"top_k\": get_param(\"top_k\"),\n            \"min_p\": get_param(\"min_p\"),\n            \"presence_penalty\": self.presence_penalty,\n            \"frequency_penalty\": self.frequency_penalty,\n            \"repetition_penalty\": get_param(\"repetition_penalty\"),\n            \"regex\": self.regex,\n            \"ebnf\": self.ebnf,\n            \"n\": self.n,\n            \"no_stop_trim\": self.no_stop_trim,\n            \"ignore_eos\": self.ignore_eos,\n            \"skip_special_tokens\": self.skip_special_tokens,\n            \"logit_bias\": self.logit_bias,\n    
        \"custom_params\": self.custom_params,\n            \"sampling_seed\": self.seed,\n            \"spaces_between_special_tokens\": spaces_between_special_tokens,\n        }\n\n        if self.response_format and self.response_format.type == \"json_schema\":\n            sampling_params[\"json_schema\"] = convert_json_schema_to_str(\n                self.response_format.json_schema.schema_\n            )\n        elif self.response_format and self.response_format.type == \"json_object\":\n            sampling_params[\"json_schema\"] = '{\"type\": \"object\"}'\n        elif self.response_format and self.response_format.type == \"structural_tag\":\n            sampling_params[\"structural_tag\"] = convert_json_schema_to_str(\n                self.response_format.model_dump(by_alias=True)\n            )\n\n        # Check if there are already existing output constraints\n        has_existing_constraints = (\n            sampling_params.get(\"regex\")\n            or sampling_params.get(\"ebnf\")\n            or sampling_params.get(\"structural_tag\")\n            or sampling_params.get(\"json_schema\")\n        )\n\n        if tool_call_constraint and has_existing_constraints:\n            logger.warning(\"Constrained decoding is not compatible with tool calls.\")\n        elif tool_call_constraint:\n            constraint_type, constraint_value = tool_call_constraint\n            if constraint_type == \"structural_tag\":\n                sampling_params[constraint_type] = convert_json_schema_to_str(\n                    constraint_value.model_dump(by_alias=True)\n                )\n            elif constraint_type == \"json_schema\":\n                sampling_params[constraint_type] = convert_json_schema_to_str(\n                    constraint_value  # type: ignore\n                )\n            else:\n                sampling_params[constraint_type] = constraint_value\n\n        return sampling_params\n\n\nclass ChatMessage(BaseModel):\n    role: 
Optional[str] = None\n    content: Optional[str] = None\n    reasoning_content: Optional[str] = None\n    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])\n\n\nclass ChatCompletionResponseChoice(BaseModel):\n    index: int\n    message: ChatMessage\n    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None\n    finish_reason: Optional[\n        Literal[\n            \"stop\", \"length\", \"tool_calls\", \"content_filter\", \"function_call\", \"abort\"\n        ]\n    ] = None\n    matched_stop: Union[None, int, str] = None\n    hidden_states: Optional[object] = None\n\n    @model_serializer(mode=\"wrap\")\n    def _serialize(self, handler):\n        data = handler(self)\n        if self.hidden_states is None:\n            data.pop(\"hidden_states\", None)\n        return data\n\n\nclass ChatCompletionResponse(BaseModel):\n    id: str\n    object: str = \"chat.completion\"\n    created: int = Field(default_factory=lambda: int(time.time()))\n    model: str\n    choices: List[ChatCompletionResponseChoice]\n    usage: UsageInfo\n    metadata: Optional[Dict[str, Any]] = None\n    sglext: Optional[SglExt] = None\n\n    @model_serializer(mode=\"wrap\")\n    def _serialize(self, handler):\n        data = handler(self)\n        if self.sglext is None:\n            data.pop(\"sglext\", None)\n        return data\n\n\nclass DeltaMessage(BaseModel):\n    role: Optional[str] = None\n    content: Optional[str] = None\n    reasoning_content: Optional[str] = None\n    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])\n    hidden_states: Optional[object] = None\n\n    @model_serializer(mode=\"wrap\")\n    def _serialize(self, handler):\n        data = handler(self)\n        if self.hidden_states is None:\n            data.pop(\"hidden_states\", None)\n        return data\n\n\nclass ChatCompletionResponseStreamChoice(BaseModel):\n    index: int\n    delta: DeltaMessage\n    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] 
= None\n    finish_reason: Optional[\n        Literal[\n            \"stop\", \"length\", \"tool_calls\", \"content_filter\", \"function_call\", \"abort\"\n        ]\n    ] = None\n    matched_stop: Union[None, int, str] = None\n\n\nclass ChatCompletionStreamResponse(BaseModel):\n    id: str\n    object: str = \"chat.completion.chunk\"\n    created: int = Field(default_factory=lambda: int(time.time()))\n    model: str\n    choices: List[ChatCompletionResponseStreamChoice]\n    usage: Optional[UsageInfo] = None\n    sglext: Optional[SglExt] = None\n\n    @model_serializer(mode=\"wrap\")\n    def _serialize(self, handler):\n        data = handler(self)\n        if self.sglext is None:\n            data.pop(\"sglext\", None)\n        return data\n\n\nclass MultimodalEmbeddingInput(BaseModel):\n    text: Optional[str] = None\n    image: Optional[str] = None\n    video: Optional[str] = None\n\n\nEmbeddingInput = Union[\n    List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]\n]\n\n\nclass EmbeddingRequest(BaseModel):\n    # Ordered by official OpenAI API documentation\n    # https://platform.openai.com/docs/api-reference/embeddings/create\n    input: EmbeddingInput\n    model: str = DEFAULT_MODEL_NAME\n    encoding_format: str = \"float\"\n    dimensions: Optional[int] = None\n    user: Optional[str] = None\n\n    # The request id.\n    rid: Optional[Union[List[str], str]] = None\n    # Priority for the request\n    priority: Optional[int] = None\n    # LoRA adapter path(s)\n    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None\n\n\nclass EmbeddingObject(BaseModel):\n    embedding: List[float]\n    index: int\n    object: str = \"embedding\"\n\n\nClassifyInput = Union[str, List[str], List[int]]\n\n\nclass ClassifyRequest(BaseModel):\n    # OpenAI-compatible classification request\n    model: str = DEFAULT_MODEL_NAME\n    input: ClassifyInput\n    user: Optional[str] = None\n\n    # The request id.\n    rid: 
Optional[Union[List[str], str]] = None\n    # Priority for the request\n    priority: Optional[int] = None\n\n\nclass ClassifyData(BaseModel):\n    index: int\n    label: str\n    probs: List[float]\n    num_classes: int\n\n\nclass ClassifyResponse(BaseModel):\n    id: str\n    object: str = \"list\"\n    created: int\n    model: str\n    data: List[ClassifyData]\n    usage: UsageInfo\n\n\nclass EmbeddingResponse(BaseModel):\n    data: List[EmbeddingObject]\n    model: str\n    object: str = \"list\"\n    usage: Optional[UsageInfo] = None\n\n\nclass ScoringRequest(BaseModel):\n    query: Optional[Union[str, List[int]]] = (\n        None  # Query text or pre-tokenized token IDs\n    )\n    items: Optional[Union[str, List[str], List[List[int]]]] = (\n        None  # Item text(s) or pre-tokenized token IDs\n    )\n    label_token_ids: Optional[List[int]] = (\n        None  # Token IDs to compute probabilities for\n    )\n    apply_softmax: bool = False\n    item_first: bool = False\n    model: str = DEFAULT_MODEL_NAME\n\n\nclass ScoringResponse(BaseModel):\n    scores: List[\n        List[float]\n    ]  # List of lists of probabilities, each in the order of label_token_ids\n    model: str\n    usage: Optional[UsageInfo] = None\n    object: str = \"scoring\"\n\n\nclass V1RerankReqInput(BaseModel):\n    query: RerankContent = Field(\n        ...,\n        description=\"The query to match against documents. Can be a string (text-only) \"\n        \"or a list of content parts for multimodal queries (text, image_url, video_url).\",\n    )\n    documents: List[RerankContent] = Field(\n        ...,\n        description=\"List of documents to rank. 
Each document can be a string (text-only) \"\n        \"or a list of content parts for multimodal documents (text, image_url, video_url).\",\n    )\n    instruct: Optional[str] = Field(\n        default=None,\n        description=\"The instruct to the reranker model.\",\n    )\n    top_n: Optional[int] = Field(\n        default=None,\n        description=\"Maximum number of documents to return. Defaults to returning all documents. \"\n        \"If specified value is greater than the total number of documents, all documents will be returned.\",\n    )\n    return_documents: bool = Field(\n        default=True,\n        description=\"Whether to return documents in the response. Only included when set to true.\",\n    )\n\n    @field_validator(\"top_n\")\n    @classmethod\n    def validate_top_n(cls, v):\n        if v is not None and v < 1:\n            raise ValueError(\"Value error, parameter top_n should be larger than 0.\")\n        return v\n\n    def is_multimodal(self) -> bool:\n        \"\"\"Check if the request contains any multimodal content.\"\"\"\n        if isinstance(self.query, list):\n            return True\n        for doc in self.documents:\n            if isinstance(doc, list):\n                return True\n        return False\n\n\nclass RerankResponse(BaseModel):\n    score: float\n    document: Optional[str] = None\n    index: int\n    meta_info: Optional[dict] = None\n\n    @model_serializer(mode=\"wrap\")\n    def _serialize(self, handler):\n        data = handler(self)\n        # Exclude document field if it's None\n        if self.document is None:\n            data.pop(\"document\", None)\n        return data\n\n\nclass TokenizeRequest(BaseModel):\n    \"\"\"Request schema for the /tokenize endpoint.\"\"\"\n\n    model: str = DEFAULT_MODEL_NAME\n    prompt: Union[str, List[str]]\n    add_special_tokens: bool = Field(\n        default=True,\n        description=\"whether to add model-specific special tokens (e.g. 
BOS/EOS) during encoding.\",\n    )\n\n\nclass TokenizeResponse(BaseModel):\n    \"\"\"Response schema for the /tokenize endpoint.\"\"\"\n\n    tokens: Union[List[int], List[List[int]]]\n    count: Union[int, List[int]]\n    max_model_len: int\n\n\nclass DetokenizeRequest(BaseModel):\n    \"\"\"Request schema for the /detokenize endpoint.\"\"\"\n\n    model: str = DEFAULT_MODEL_NAME\n    tokens: Union[List[int], List[List[int]]]\n    skip_special_tokens: bool = Field(\n        default=True,\n        description=\"whether to exclude special tokens (e.g. padding or EOS) during decoding.\",\n    )\n\n\nclass DetokenizeResponse(BaseModel):\n    \"\"\"Response schema for the /detokenize endpoint.\"\"\"\n\n    text: Union[str, List[str]]\n\n\nOpenAIServingRequest = Union[\n    ChatCompletionRequest,\n    CompletionRequest,\n    EmbeddingRequest,\n    ClassifyRequest,\n    ScoringRequest,\n    V1RerankReqInput,\n    TokenizeRequest,\n    DetokenizeRequest,\n]\n\n\n# Response API protocol definitions\nclass ResponseReasoningParam(BaseModel):\n    \"\"\"Reasoning parameters for responses.\"\"\"\n\n    effort: Optional[Literal[\"low\", \"medium\", \"high\"]] = Field(\n        default=\"medium\",\n        description=\"Constrains effort on reasoning for reasoning models.\",\n    )\n\n\nclass ResponseTool(BaseModel):\n    \"\"\"Tool definition for responses.\"\"\"\n\n    type: Literal[\"web_search_preview\", \"code_interpreter\"] = Field(\n        description=\"Type of tool to enable\"\n    )\n\n\nResponseInputOutputItem: TypeAlias = Union[\n    ResponseInputItemParam,\n    \"ResponseReasoningItem\",\n    ResponseFunctionToolCall,\n]\n\n\nclass ResponsesRequest(BaseModel):\n    \"\"\"Request body for v1/responses endpoint.\"\"\"\n\n    # Core OpenAI API fields (ordered by official documentation)\n    background: Optional[bool] = False\n    include: Optional[\n        List[\n            Literal[\n                \"code_interpreter_call.outputs\",\n                
\"computer_call_output.output.image_url\",\n                \"file_search_call.results\",\n                \"message.input_image.image_url\",\n                \"message.output_text.logprobs\",\n                \"reasoning.encrypted_content\",\n            ]\n        ]\n    ] = None\n    input: Union[str, List[ResponseInputOutputItem]]\n    instructions: Optional[str] = None\n    max_output_tokens: Optional[int] = None\n    max_tool_calls: Optional[int] = None\n    metadata: Optional[Dict[str, Any]] = None\n    model: Optional[str] = None  # Made optional to match vLLM\n    parallel_tool_calls: Optional[bool] = True\n    previous_response_id: Optional[str] = None\n    reasoning: Optional[ResponseReasoningParam] = None\n    service_tier: Literal[\"auto\", \"default\", \"flex\", \"scale\", \"priority\"] = \"auto\"\n    store: Optional[bool] = True\n    stream: Optional[bool] = False\n    temperature: Optional[float] = None\n    tool_choice: Literal[\"auto\", \"required\", \"none\"] = \"auto\"\n    tools: List[ResponseTool] = Field(default_factory=list)\n    top_logprobs: Optional[int] = 0\n    top_p: Optional[float] = None\n    truncation: Optional[Literal[\"auto\", \"disabled\"]] = \"disabled\"\n    user: Optional[str] = None\n\n    # Extra SGLang parameters\n    request_id: str = Field(\n        default_factory=lambda: f\"resp_{uuid.uuid4().hex}\",\n        description=\"The request_id related to this request. If the caller does not set it, a random uuid will be generated.\",\n    )\n    priority: int = Field(default=0, description=\"Request priority\")\n    extra_key: Optional[str] = Field(\n        default=None,\n        description=\"Extra key for classifying the request (e.g. 
cache_salt)\",\n    )\n    cache_salt: Optional[str] = Field(\n        default=None, description=\"Cache salt for request caching\"\n    )\n\n    # SGLang-specific sampling parameters\n    frequency_penalty: float = 0.0\n    presence_penalty: float = 0.0\n    stop: Optional[Union[str, List[str]]] = None\n    top_k: int = -1\n    min_p: float = 0.0\n    repetition_penalty: float = 1.0\n\n    # Default sampling parameters\n    _DEFAULT_SAMPLING_PARAMS = {\n        \"temperature\": 0.7,\n        \"top_p\": 1.0,\n        \"top_k\": -1,\n        \"min_p\": 0.0,\n        \"repetition_penalty\": 1.0,\n    }\n\n    def to_sampling_params(\n        self, default_max_tokens: int, default_params: Optional[Dict] = None\n    ) -> Dict[str, Any]:\n        \"\"\"Convert to sampling parameters for generation.\"\"\"\n        if default_params is None:\n            default_params = {}\n\n        # Use max_output_tokens (capped at default_max_tokens) if set; otherwise default_max_tokens\n        if self.max_output_tokens is not None:\n            max_tokens = min(self.max_output_tokens, default_max_tokens)\n        else:\n            max_tokens = default_max_tokens\n\n        # Subtract 2 tokens as a safety margin so generation does not exceed the context length\n        max_tokens -= 2\n\n        # Get parameters with defaults\n        temperature = self.temperature\n        if temperature is None:\n            temperature = default_params.get(\n                \"temperature\", self._DEFAULT_SAMPLING_PARAMS[\"temperature\"]\n            )\n\n        top_p = self.top_p\n        if top_p is None:\n            top_p = default_params.get(\"top_p\", self._DEFAULT_SAMPLING_PARAMS[\"top_p\"])\n\n        params = {\n            \"max_new_tokens\": max_tokens,\n            \"temperature\": temperature,\n            \"top_p\": top_p,\n            \"frequency_penalty\": self.frequency_penalty,\n            \"presence_penalty\": self.presence_penalty,\n            \"stop\": self.stop,\n            \"top_k\": self.top_k,\n            \"min_p\": self.min_p,\n            \"repetition_penalty\": self.repetition_penalty,\n        }\n\n        # Apply any additional default parameters\n        for key, value in 
default_params.items():\n            if key not in params or params[key] is None:\n                params[key] = value\n\n        return params\n\n\nclass PromptTokenUsageInfo(BaseModel):\n    \"\"\"Prompt token usage details.\"\"\"\n\n    cached_tokens: int = 0\n\n\nclass ResponsesResponse(BaseModel):\n    \"\"\"Response body for v1/responses endpoint.\"\"\"\n\n    id: str = Field(default_factory=lambda: f\"resp_{time.time()}\")\n    object: Literal[\"response\"] = \"response\"\n    created_at: int = Field(default_factory=lambda: int(time.time()))\n    model: str\n\n    output: List[\n        Union[ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall]\n    ] = Field(default_factory=list)\n    status: Literal[\"queued\", \"in_progress\", \"completed\", \"failed\", \"cancelled\"]\n    usage: Optional[UsageInfo] = None\n    parallel_tool_calls: bool = True\n    tool_choice: str = \"auto\"\n    tools: List[ResponseTool] = Field(default_factory=list)\n\n    # OpenAI compatibility fields. Not all are used at the moment.\n    # See https://platform.openai.com/docs/api-reference/responses\n    error: Optional[dict] = None\n    incomplete_details: Optional[dict] = None  # TODO(v) support this input\n    instructions: Optional[str] = None\n    max_output_tokens: Optional[int] = None\n    previous_response_id: Optional[str] = None\n    reasoning: Optional[dict] = (\n        # Unused. No model supports this. For GPT-oss, system prompt sets\n        # the field, not server args.\n        None  # {\"effort\": Optional[str], \"summary\": Optional[str]}\n    )\n    store: Optional[bool] = None\n    temperature: Optional[float] = None\n    text: Optional[dict] = None  # e.g. 
{\"format\": {\"type\": \"text\"}}\n    top_p: Optional[float] = None\n    truncation: Optional[str] = None\n    user: Optional[str] = None\n    metadata: Optional[Dict[str, Any]] = None\n\n    @classmethod\n    def from_request(\n        cls,\n        request: ResponsesRequest,\n        sampling_params: Any,\n        model_name: str,\n        created_time: int,\n        output: List[\n            Union[ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall]\n        ],\n        status: str,\n        usage: Optional[UsageInfo],\n    ) -> \"ResponsesResponse\":\n        \"\"\"Create a response from a request.\"\"\"\n\n        # Determine if the output is plain text only to set text.format\n        def _is_text_only(\n            items: List[\n                Union[\n                    ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall\n                ]\n            ],\n        ) -> bool:\n            if not items:\n                return False\n            for it in items:\n                # Reasoning items and tool calls are not pure text.\n                if isinstance(it, ResponseReasoningItem) or isinstance(\n                    it, ResponseFunctionToolCall\n                ):\n                    return False\n                try:\n                    if isinstance(it, ResponseOutputText):\n                        continue\n                    elif isinstance(it, ResponseOutputMessage):\n                        if not it.content:\n                            continue\n                        for c in it.content:\n                            if not isinstance(c, ResponseOutputText):\n                                return False\n                    else:\n                        # Unknown type, not considered text-only\n                        return False\n                except AttributeError:\n                    return False\n            return True\n\n        text_format = {\"format\": {\"type\": \"text\"}} if _is_text_only(output) else None\n\n        return cls(\n            id=request.request_id,\n            created_at=created_time,\n            model=model_name,\n            
output=output,\n            status=status,\n            usage=usage,\n            # Default to True only when the request leaves parallel_tool_calls unset\n            parallel_tool_calls=(\n                request.parallel_tool_calls\n                if request.parallel_tool_calls is not None\n                else True\n            ),\n            tool_choice=request.tool_choice,\n            tools=request.tools,\n            # fields for parity with v1/responses\n            error=None,\n            incomplete_details=None,\n            instructions=request.instructions,\n            max_output_tokens=request.max_output_tokens,\n            previous_response_id=request.previous_response_id,  # TODO(v): ensure this is propagated if retrieved from store\n            reasoning={\n                \"effort\": request.reasoning.effort if request.reasoning else None,\n                \"summary\": None,  # unused\n            },\n            store=request.store,\n            temperature=request.temperature,\n            text=text_format,  # TODO(v): Expand coverage per https://platform.openai.com/docs/api-reference/responses/list\n            top_p=request.top_p,\n            truncation=request.truncation,\n            user=request.user,\n            metadata=request.metadata or {},\n        )\n\n\nclass RequestResponseMetadata(BaseModel):\n    \"\"\"Metadata for request/response tracking.\"\"\"\n\n    request_id: str\n    final_usage_info: Optional[UsageInfo] = None\n\n\n@dataclass\nclass MessageProcessingResult:\n    \"\"\"Result of processing chat messages and applying templates.\n\n    This dataclass encapsulates all the outputs from message processing including\n    prompt generation, multimodal data extraction, and constraint preparation.\n    Used internally by OpenAIServingChat to pass processed data between methods.\n\n    Args:\n        prompt: The final text prompt after applying chat template\n        prompt_ids: Either the text prompt (str) or tokenized IDs (List[int])\n        image_data: Extracted image data from 
messages, if any\n        audio_data: Extracted audio data from messages, if any\n        video_data: Extracted video data from messages, if any\n        modalities: List of modality types present in the messages\n        stop: Combined stop strings from template and request\n        tool_call_constraint: Optional constraint for structured tool calls\n    \"\"\"\n\n    prompt: str\n    prompt_ids: Union[str, List[int]]\n    image_data: Optional[Any]\n    audio_data: Optional[Any]\n    video_data: Optional[Any]\n    modalities: List[str]\n    stop: List[str]\n    tool_call_constraint: Optional[ToolCallConstraint] = None\n\n\nclass ToolCallProcessingResult(NamedTuple):\n    \"\"\"Result of processing tool calls in a response.\"\"\"\n\n    tool_calls: Optional[\n        List[Any]\n    ]  # List of ToolCall objects or None if parsing failed\n    remaining_text: str  # Text remaining after parsing tool calls\n    finish_reason: Dict[str, Any]  # Updated finish reason dictionary\n\n\nclass ResponseReasoningTextContent(BaseModel):\n    text: str\n    type: Literal[\"reasoning_text\"] = \"reasoning_text\"\n\n\nResponseInputOutputItem: TypeAlias = Union[\n    ResponseInputItemParam, \"ResponseReasoningItem\", ResponseFunctionToolCall\n]\n\n\n# ================== Transcription API Protocol Definitions ==================\n\n\nclass TranscriptionRequest(BaseModel):\n    \"\"\"Request model for audio transcription (OpenAI-compatible).\"\"\"\n\n    model: str = DEFAULT_MODEL_NAME\n    language: Optional[str] = None\n    response_format: str = \"json\"\n    temperature: float = 0.0\n    stream: bool = False\n    # Internal fields (not from API)\n    audio_data: Optional[bytes] = None\n    audio_duration_s: float = 0.0\n\n\nclass TranscriptionUsage(BaseModel):\n    \"\"\"Usage info for transcription response (duration-based).\"\"\"\n\n    type: Literal[\"duration\"] = \"duration\"\n    seconds: int  # Audio duration in seconds (rounded up)\n\n\nclass TranscriptionResponse(BaseModel):\n    \"\"\"Non-streaming transcription response 
(OpenAI-compatible).\"\"\"\n\n    text: str\n    usage: Optional[TranscriptionUsage] = None\n\n\nclass TranscriptionStreamChoice(BaseModel):\n    \"\"\"Delta content for streaming transcription.\"\"\"\n\n    delta: DeltaMessage\n    finish_reason: Optional[str] = None\n\n\nclass TranscriptionStreamResponse(BaseModel):\n    \"\"\"Streaming transcription chunk (OpenAI-compatible).\"\"\"\n\n    id: str = Field(default_factory=lambda: f\"trsc-{uuid.uuid4().hex}\")\n    object: Literal[\"transcription.chunk\"] = \"transcription.chunk\"\n    created: int = Field(default_factory=lambda: int(time.time()))\n    model: str\n    choices: List[TranscriptionStreamChoice]\n    usage: Optional[UsageInfo] = None\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/serving_base.py",
    "content": "from __future__ import annotations\n\nimport json\nimport logging\nimport uuid\nfrom abc import ABC, abstractmethod\nfrom typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union\n\nimport orjson\nfrom fastapi import HTTPException, Request\nfrom fastapi.responses import ORJSONResponse, StreamingResponse\n\nfrom sglang.srt.entrypoints.openai.encoding_dsv32 import DS32EncodingError\nfrom sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest\nfrom sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput\nfrom sglang.srt.observability.req_time_stats import monotonic_time\nfrom sglang.srt.server_args import ServerArgs\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.tokenizer_manager import TokenizerManager\n\nlogger = logging.getLogger(__name__)\n\n\n# Base class for specific endpoint handlers\nclass OpenAIServingBase(ABC):\n    \"\"\"Abstract base class for OpenAI endpoint handlers\"\"\"\n\n    def __init__(self, tokenizer_manager: TokenizerManager):\n        self.tokenizer_manager = tokenizer_manager\n        self.allowed_custom_labels = (\n            set(\n                self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels\n            )\n            if isinstance(self.tokenizer_manager.server_args, ServerArgs)\n            and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels\n            else None\n        )\n\n    def _parse_model_parameter(self, model: str) -> Tuple[str, Optional[str]]:\n        \"\"\"Parse 'base-model:adapter-name' syntax to extract LoRA adapter.\n\n        Returns (base_model, adapter_name) or (model, None) if no colon present.\n        \"\"\"\n        if \":\" not in model:\n            return model, None\n\n        # Split on first colon only to handle model paths with multiple colons\n        parts = model.split(\":\", 1)\n        base_model = parts[0].strip()\n        adapter_name = parts[1].strip() or None\n\n        return 
base_model, adapter_name\n\n    def _resolve_lora_path(\n        self,\n        request_model: str,\n        explicit_lora_path: Optional[Union[str, List[Optional[str]]]],\n    ) -> Optional[Union[str, List[Optional[str]]]]:\n        \"\"\"Resolve LoRA adapter with priority: model parameter > explicit lora_path.\n\n        Returns adapter name or None. Supports both single values and lists (batches).\n        \"\"\"\n        _, adapter_from_model = self._parse_model_parameter(request_model)\n\n        # Model parameter adapter takes precedence\n        if adapter_from_model is not None:\n            return adapter_from_model\n\n        # Fall back to explicit lora_path\n        return explicit_lora_path\n\n    async def handle_request(\n        self, request: OpenAIServingRequest, raw_request: Request\n    ) -> Union[Any, StreamingResponse, ErrorResponse]:\n        \"\"\"Handle the specific request type with common pattern\n        If you want to override this method, you should be careful to record the validation time.\n        \"\"\"\n        received_time = monotonic_time()\n\n        try:\n            # Validate request\n            error_msg = self._validate_request(request)\n            if error_msg:\n                return self.create_error_response(error_msg)\n\n            # Log the raw OpenAI request payload before conversion to tokenized form.\n            request_logger = self.tokenizer_manager.request_logger\n            if request_logger.log_requests and request_logger.log_requests_level >= 2:\n                request_logger.log_openai_received_request(request, request=raw_request)\n\n            # Convert to internal format\n            adapted_request, processed_request = self._convert_to_internal_request(\n                request, raw_request\n            )\n\n            if isinstance(adapted_request, (GenerateReqInput, EmbeddingReqInput)):\n                # Only set timing fields if adapted_request supports them\n                
adapted_request.received_time = received_time\n\n            # Note(Xinyuan): raw_request below is only used for detecting the connection of the client\n            if hasattr(request, \"stream\") and request.stream:\n                return await self._handle_streaming_request(\n                    adapted_request, processed_request, raw_request\n                )\n            else:\n                return await self._handle_non_streaming_request(\n                    adapted_request, processed_request, raw_request\n                )\n        except HTTPException as e:\n            return self.create_error_response(\n                message=e.detail, err_type=str(e.status_code), status_code=e.status_code\n            )\n        except ValueError as e:\n            return self.create_error_response(\n                message=str(e),\n                err_type=\"BadRequest\",\n                status_code=400,\n            )\n        except DS32EncodingError as e:\n            logger.info(f\"DS32EncodingError: {e}\")\n            return self.create_error_response(\n                message=str(e),\n                err_type=\"BadRequest\",\n                status_code=400,\n            )\n        except Exception as e:\n            logger.exception(f\"Error in request: {e}\")\n            return self.create_error_response(\n                message=f\"Internal server error: {str(e)}\",\n                err_type=\"InternalServerError\",\n                status_code=500,\n            )\n\n    @abstractmethod\n    def _request_id_prefix(self) -> str:\n        \"\"\"Return the request ID prefix for this request type\"\"\"\n        pass\n\n    def _generate_request_id_base(self, request: OpenAIServingRequest) -> Optional[str]:\n        \"\"\"Generate request ID based on request type\"\"\"\n        return None\n\n        # TODO(chang): the rid is used in io_struct check and often violates `The rid should be a list` AssertionError\n        # Temporarily return None in this function until 
the rid logic is clear.\n        if rid := getattr(request, \"rid\", None):\n            return rid\n\n        return f\"{self._request_id_prefix()}{uuid.uuid4().hex}\"\n\n    def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]:\n        \"\"\"Compute the final extra_key by concatenating cache_salt and extra_key if both are provided.\"\"\"\n        parts = []\n        for key in [\"cache_salt\", \"extra_key\"]:\n            value = getattr(request, key, None)\n            if value:\n                if not isinstance(value, str):\n                    raise TypeError(\n                        f\"Value of {key} must be a string, but got {type(value).__name__}\"\n                    )\n                parts.append(value)\n        return \"\".join(parts) if parts else None\n\n    @abstractmethod\n    def _convert_to_internal_request(\n        self,\n        request: OpenAIServingRequest,\n        raw_request: Request = None,\n    ) -> tuple[GenerateReqInput, OpenAIServingRequest]:\n        \"\"\"Convert OpenAI request to internal format\"\"\"\n        pass\n\n    async def _handle_streaming_request(\n        self,\n        adapted_request: GenerateReqInput,\n        request: OpenAIServingRequest,\n        raw_request: Request,\n    ) -> Union[StreamingResponse, ErrorResponse, ORJSONResponse]:\n        \"\"\"Handle streaming request\n\n        Override this method in child classes that support streaming requests.\n        \"\"\"\n        return self.create_error_response(\n            message=f\"{self.__class__.__name__} does not support streaming requests\",\n            err_type=\"NotImplementedError\",\n            status_code=501,\n        )\n\n    async def _handle_non_streaming_request(\n        self,\n        adapted_request: GenerateReqInput,\n        request: OpenAIServingRequest,\n        raw_request: Request,\n    ) -> Union[Any, ErrorResponse, ORJSONResponse]:\n        \"\"\"Handle non-streaming request\n\n        Override this method 
in child classes that support non-streaming requests.\n        \"\"\"\n        return self.create_error_response(\n            message=f\"{self.__class__.__name__} does not support non-streaming requests\",\n            err_type=\"NotImplementedError\",\n            status_code=501,\n        )\n\n    def _validate_request(self, _: OpenAIServingRequest) -> Optional[str]:\n        \"\"\"Validate request\"\"\"\n        pass\n\n    def create_error_response(\n        self,\n        message: str,\n        err_type: str = \"BadRequestError\",\n        status_code: int = 400,\n        param: Optional[str] = None,\n    ) -> ORJSONResponse:\n        \"\"\"Create an error response\"\"\"\n        # TODO: remove fastapi dependency in openai and move response handling to the entrypoint\n        error = ErrorResponse(\n            object=\"error\",\n            message=message,\n            type=err_type,\n            param=param,\n            code=status_code,\n        )\n        return ORJSONResponse(content=error.model_dump(), status_code=status_code)\n\n    def create_streaming_error_response(\n        self,\n        message: str,\n        err_type: str = \"BadRequestError\",\n        status_code: int = 400,\n    ) -> str:\n        \"\"\"Create a streaming error response\"\"\"\n        error = ErrorResponse(\n            object=\"error\",\n            message=message,\n            type=err_type,\n            param=None,\n            code=status_code,\n        )\n        return json.dumps({\"error\": error.model_dump()})\n\n    def extract_custom_labels(self, raw_request):\n        if (\n            not self.allowed_custom_labels\n            or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header\n        ):\n            return None\n\n        custom_labels = None\n        header = (\n            self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header\n        )\n        try:\n            raw_labels = (\n                
orjson.loads(raw_request.headers.get(header))\n                if raw_request and raw_request.headers.get(header)\n                else None\n            )\n        except json.JSONDecodeError as e:\n            logger.exception(f\"Error in request: {e}\")\n            raw_labels = None\n\n        if isinstance(raw_labels, dict):\n            custom_labels = {\n                label: value\n                for label, value in raw_labels.items()\n                if label in self.allowed_custom_labels\n            }\n        return custom_labels\n\n    def extract_routing_key(self, raw_request):\n        if raw_request is None:\n            return None\n        return raw_request.headers.get(\"x-smg-routing-key\")\n\n    def extract_routed_dp_rank_from_header(\n        self, raw_request: Request, body_routed_dp_rank: Optional[int] = None\n    ) -> Optional[int]:\n        \"\"\"Extract routed_dp_rank from HTTP header, with higher priority than routed_dp_rank in body.\n\n        Header name: X-Data-Parallel-Rank (case-insensitive in HTTP/1.1/2)\n        \"\"\"\n        if raw_request is None:\n            return body_routed_dp_rank\n\n        header_value = raw_request.headers.get(\"x-data-parallel-rank\")\n        if header_value is not None:\n            try:\n                header_dp_rank = int(header_value)\n                if (\n                    body_routed_dp_rank is not None\n                    and header_dp_rank != body_routed_dp_rank\n                ):\n                    logger.debug(\n                        f\"X-Data-Parallel-Rank header ({header_dp_rank}) overrides \"\n                        f\"body routed_dp_rank ({body_routed_dp_rank})\"\n                    )\n                return header_dp_rank\n            except ValueError:\n                raise HTTPException(\n                    status_code=400,\n                    detail=f\"Invalid X-Data-Parallel-Rank header: must be an integer, got '{header_value}'\",\n                )\n\n        
return body_routed_dp_rank\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/serving_chat.py",
    "content": "from __future__ import annotations\n\nimport copy\nimport json\nimport logging\nimport time\nimport uuid\nfrom http import HTTPStatus\nfrom typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union\n\nimport jinja2\nimport orjson\nfrom fastapi import Request\nfrom fastapi.responses import ORJSONResponse, StreamingResponse\nfrom jsonschema import Draft202012Validator, SchemaError\n\nfrom sglang.srt.entrypoints.openai.encoding_dsv32 import encode_messages\nfrom sglang.srt.entrypoints.openai.protocol import (\n    ChatCompletionRequest,\n    ChatCompletionResponse,\n    ChatCompletionResponseChoice,\n    ChatCompletionResponseStreamChoice,\n    ChatCompletionStreamResponse,\n    ChatCompletionTokenLogprob,\n    ChatMessage,\n    ChoiceLogprobs,\n    DeltaMessage,\n    ErrorResponse,\n    FunctionResponse,\n    LogProbs,\n    MessageProcessingResult,\n    SglExt,\n    ToolCall,\n    ToolCallProcessingResult,\n    ToolChoice,\n    TopLogprob,\n)\nfrom sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase\nfrom sglang.srt.entrypoints.openai.usage_processor import UsageProcessor\nfrom sglang.srt.entrypoints.openai.utils import (\n    process_cached_tokens_details_from_ret,\n    process_hidden_states_from_ret,\n    process_routed_experts_from_ret,\n    to_openai_style_logprobs,\n)\nfrom sglang.srt.function_call.core_types import ToolCallItem\nfrom sglang.srt.function_call.function_call_parser import FunctionCallParser\nfrom sglang.srt.function_call.json_array_parser import JsonArrayParser\nfrom sglang.srt.function_call.utils import get_json_schema_constraint\nfrom sglang.srt.managers.io_struct import GenerateReqInput\nfrom sglang.srt.parser.conversation import generate_chat_conv\nfrom sglang.srt.parser.jinja_template_utils import process_content_for_template_format\nfrom sglang.srt.parser.reasoning_parser import ReasoningParser\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.template_manager import TemplateManager\n    
from sglang.srt.managers.tokenizer_manager import TokenizerManager\n\nlogger = logging.getLogger(__name__)\n\n\ndef _extract_max_dynamic_patch(request: ChatCompletionRequest):\n    img_vals = []\n    vid_vals = []\n    for msg in request.messages or []:\n        content = getattr(msg, \"content\", None)\n        if not isinstance(content, list):\n            continue\n        for part in content:\n            # pydantic object or dict type\n            if getattr(part, \"type\", None) == \"image_url\":\n                iu = getattr(part, \"image_url\", None)\n                mdp = getattr(iu, \"max_dynamic_patch\", None) if iu else None\n                if mdp is not None:\n                    img_vals.append(int(mdp))\n            elif getattr(part, \"type\", None) == \"video_url\":\n                vu = getattr(part, \"video_url\", None)\n                mdp = getattr(vu, \"max_dynamic_patch\", None) if vu else None\n                if mdp is not None:\n                    vid_vals.append(int(mdp))\n\n    # TODO(yuan-luo): per-item max_dynamic_patch for both image and video\n    img_max_dynamic_patch = min(img_vals) if img_vals else None\n    vid_max_dynamic_patch = min(vid_vals) if vid_vals else None\n    return img_max_dynamic_patch, vid_max_dynamic_patch\n\n\nclass OpenAIServingChat(OpenAIServingBase):\n    \"\"\"Handler for /v1/chat/completions requests\"\"\"\n\n    _default_sampling_params_logged = False\n\n    def __init__(\n        self,\n        tokenizer_manager: TokenizerManager,\n        template_manager: TemplateManager,\n    ):\n        super().__init__(tokenizer_manager)\n        self.template_manager = template_manager\n        self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser\n        self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser\n\n        # Get default sampling parameters from model's generation config\n        self.default_sampling_params = (\n            
self.tokenizer_manager.model_config.get_default_sampling_params()\n        )\n        if (\n            self.default_sampling_params\n            and not OpenAIServingChat._default_sampling_params_logged\n        ):\n            logger.info(\n                f\"Using default chat sampling params from model generation config: {self.default_sampling_params}\",\n            )\n            OpenAIServingChat._default_sampling_params_logged = True\n\n        # Check if the model is a GPT-OSS model\n        self.is_gpt_oss = (\n            hasattr(self.tokenizer_manager.model_config, \"hf_config\")\n            and hasattr(self.tokenizer_manager.model_config.hf_config, \"model_type\")\n            and self.tokenizer_manager.model_config.hf_config.model_type == \"gpt_oss\"\n        )\n\n        self.use_dpsk_v32_encoding = self._use_dpsk_v32_encoding()\n\n    def _handle_last_assistant_message(\n        self,\n        messages: List[Dict[str, Any]],\n        request: ChatCompletionRequest,\n    ) -> tuple[List[Dict[str, Any]], Optional[str]]:\n        \"\"\"\n        Handle continue_final_message feature: separate final assistant message.\n\n        If continue_final_message is enabled and the last message is from assistant,\n        extract its content and remove it from the message list.\n        If continue_final_message is False and the last message is from assistant,\n        convert it to a user message to ensure the last message is always from user.\n\n        Only processes text-based content (strings), ignoring multimodal content (lists).\n\n        Args:\n            messages: List of message dictionaries\n            request: ChatCompletionRequest with continue_final_message flag\n\n        Returns:\n            Tuple of (processed_messages, assistant_prefix)\n            - processed_messages: Messages with last assistant message handled appropriately\n            - assistant_prefix: Content of the last assistant message (string only), or None\n        \"\"\"\n  
      assistant_prefix = None\n        if messages and messages[-1].get(\"role\") == \"assistant\":\n            last_content = messages[-1].get(\"content\")\n            # Only process string content, ignore multimodal content (lists)\n            if isinstance(last_content, str):\n                if request.continue_final_message:\n                    # Extract content and remove the assistant message\n                    assistant_prefix = last_content\n                    messages = messages[:-1]\n                else:\n                    # Convert the last assistant message to user message\n                    messages[-1] = {\"role\": \"user\", \"content\": last_content}\n        return messages, assistant_prefix\n\n    def _append_assistant_prefix_to_prompt_ids(\n        self, prompt_ids: List[int], assistant_prefix: str\n    ) -> List[int]:\n        \"\"\"\n        Append assistant prefix to prompt_ids.\n\n        Args:\n            prompt_ids: Current prompt token IDs\n            assistant_prefix: Assistant message content to append\n\n        Returns:\n            Updated prompt_ids with assistant prefix appended\n        \"\"\"\n        encoded = self.tokenizer_manager.tokenizer.encode(assistant_prefix)\n        if encoded and encoded[0] == self.tokenizer_manager.tokenizer.bos_token_id:\n            encoded = encoded[1:]\n        return prompt_ids + encoded\n\n    def _use_dpsk_v32_encoding(self) -> bool:\n        has_chat_template = (\n            self.tokenizer_manager.tokenizer is not None\n            and self.tokenizer_manager.tokenizer.chat_template is not None\n        )\n        architectures = self.tokenizer_manager.model_config.hf_config.architectures\n        is_dpsk_v32 = \"DeepseekV3\" in architectures[0] if architectures else False\n        return not has_chat_template and is_dpsk_v32\n\n    def _request_id_prefix(self) -> str:\n        return \"chatcmpl-\"\n\n    def _validate_request(self, request: ChatCompletionRequest) -> 
Optional[str]:\n        \"\"\"Validate that the input is valid.\"\"\"\n        if not request.messages:\n            return \"Messages cannot be empty.\"\n\n        if (\n            isinstance(request.tool_choice, str)\n            and request.tool_choice.lower() == \"required\"\n            and not request.tools\n        ):\n            return \"Tools cannot be empty if tool choice is set to required.\"\n\n        if request.tool_choice is not None and not isinstance(request.tool_choice, str):\n            if not request.tools:\n                return \"Tools cannot be empty if tool choice is set to a specific tool.\"\n            tool_name = request.tool_choice.function.name\n            tool_exists = any(tool.function.name == tool_name for tool in request.tools)\n            if not tool_exists:\n                return f\"Tool '{tool_name}' not found in tools list.\"\n\n        # Validate tool definitions\n        for i, tool in enumerate(request.tools or []):\n            if tool.function.parameters is None:\n                continue\n            try:\n                Draft202012Validator.check_schema(tool.function.parameters)\n            except SchemaError as e:\n                return f\"Tool {i} function has invalid 'parameters' schema: {str(e)}\"\n\n        max_output_tokens = request.max_completion_tokens or request.max_tokens\n        server_context_length = self.tokenizer_manager.server_args.context_length\n        if (\n            max_output_tokens\n            and server_context_length\n            and max_output_tokens > server_context_length\n        ) and not self.tokenizer_manager.server_args.allow_auto_truncate:\n            return (\n                f\"max_completion_tokens is too large: {max_output_tokens}. \"\n                f\"This model supports at most {server_context_length} completion tokens.\"\n            )\n\n        if request.response_format and request.response_format.type == \"json_schema\":\n            schema = 
getattr(request.response_format.json_schema, \"schema_\", None)\n            if schema is None:\n                return \"schema_ is required for json_schema response format request.\"\n\n        return None\n\n    def _convert_to_internal_request(\n        self,\n        request: ChatCompletionRequest,\n        raw_request: Request = None,\n    ) -> tuple[GenerateReqInput, ChatCompletionRequest]:\n        \"\"\"Convert OpenAI chat completion request to internal format\"\"\"\n        reasoning_effort = (\n            request.chat_template_kwargs.pop(\"reasoning_effort\", None)\n            if request.chat_template_kwargs\n            else None\n        )\n        if self.is_gpt_oss and reasoning_effort == \"none\":\n            raise ValueError(\n                f\"Harmony does not support reasoning effort {reasoning_effort}\"\n            )\n\n        if reasoning_effort is not None:\n            request.reasoning_effort = reasoning_effort\n\n        is_multimodal = self.tokenizer_manager.model_config.is_multimodal\n\n        # Process messages and apply chat template\n        processed_messages = self._process_messages(request, is_multimodal)\n\n        # Build sampling parameters\n        sampling_params = request.to_sampling_params(\n            stop=processed_messages.stop,\n            model_generation_config=self.default_sampling_params,\n            tool_call_constraint=processed_messages.tool_call_constraint,\n        )\n\n        # Multimodal models take the prompt text; others take token IDs (or text if already a string)\n        if is_multimodal:\n            prompt_kwargs = {\"text\": processed_messages.prompt}\n        else:\n            if isinstance(processed_messages.prompt_ids, str):\n                prompt_kwargs = {\"text\": processed_messages.prompt_ids}\n            else:\n                prompt_kwargs = {\"input_ids\": processed_messages.prompt_ids}\n\n        # Extract custom labels from raw request headers\n        custom_labels = self.extract_custom_labels(raw_request)\n\n        # Extract 
routed_dp_rank from header (has higher priority than body)\n        effective_routed_dp_rank = self.extract_routed_dp_rank_from_header(\n            raw_request, request.routed_dp_rank\n        )\n\n        # Resolve LoRA adapter from model parameter or explicit lora_path\n        lora_path = self._resolve_lora_path(request.model, request.lora_path)\n        img_max_dynamic_patch, vid_max_dynamic_patch = _extract_max_dynamic_patch(\n            request\n        )\n        adapted_request = GenerateReqInput(\n            **prompt_kwargs,\n            image_data=processed_messages.image_data,\n            video_data=processed_messages.video_data,\n            audio_data=processed_messages.audio_data,\n            sampling_params=sampling_params,\n            return_logprob=request.logprobs,\n            logprob_start_len=-1,\n            top_logprobs_num=request.top_logprobs or 0,\n            stream=request.stream,\n            return_text_in_logprobs=True,\n            modalities=processed_messages.modalities,\n            lora_path=lora_path,\n            bootstrap_host=request.bootstrap_host,\n            bootstrap_port=request.bootstrap_port,\n            bootstrap_room=request.bootstrap_room,\n            routed_dp_rank=effective_routed_dp_rank,\n            disagg_prefill_dp_rank=request.disagg_prefill_dp_rank,\n            return_hidden_states=request.return_hidden_states,\n            return_routed_experts=request.return_routed_experts,\n            rid=request.rid,\n            extra_key=self._compute_extra_key(request),\n            require_reasoning=self._get_reasoning_from_request(request),\n            priority=request.priority,\n            routing_key=self.extract_routing_key(raw_request),\n            custom_labels=custom_labels,\n            custom_logit_processor=request.custom_logit_processor,\n            image_max_dynamic_patch=img_max_dynamic_patch,\n            video_max_dynamic_patch=vid_max_dynamic_patch,\n            
max_dynamic_patch=getattr(request, \"max_dynamic_patch\", None),\n        )\n\n        return adapted_request, request\n\n    def _process_messages(\n        self, request: ChatCompletionRequest, is_multimodal: bool\n    ) -> MessageProcessingResult:\n        \"\"\"Process chat messages and apply chat template\"\"\"\n        # GptOss model needs to keep special tokens for harmony parsing\n        if self.is_gpt_oss:\n            request.skip_special_tokens = False\n\n        self._patch_mistral_skip_special_tokens(request)\n\n        tool_call_constraint = None\n\n        # Apply chat template and its stop strings\n        tools = None\n        if request.tools and request.tool_choice != \"none\":\n            request.skip_special_tokens = False\n            if not isinstance(request.tool_choice, str):\n                tools = [\n                    item.model_dump()\n                    for item in request.tools\n                    if item.function.name == request.tool_choice.function.name\n                ]\n            else:\n                tools = [item.model_dump() for item in request.tools]\n            if self.tool_call_parser:\n                parser = FunctionCallParser(request.tools, self.tool_call_parser)\n                tool_call_constraint = parser.get_structure_constraint(\n                    request.tool_choice\n                )\n            # Handle JSON schema constraint directly for required or named tool choice\n            if request.tool_choice == \"required\" or isinstance(\n                request.tool_choice, ToolChoice\n            ):\n                json_schema = get_json_schema_constraint(\n                    request.tools, request.tool_choice\n                )\n                tool_call_constraint = (\"json_schema\", json_schema)\n\n        # Use chat template\n        if self.template_manager.chat_template_name is None:\n            result = self._apply_jinja_template(request, tools, is_multimodal)\n        else:\n            
result = self._apply_conversation_template(request, is_multimodal)\n\n        result.tool_call_constraint = tool_call_constraint\n        return result\n\n    def _apply_jinja_template(\n        self,\n        request: ChatCompletionRequest,\n        tools: Optional[List[Dict]],\n        is_multimodal: bool,\n    ) -> MessageProcessingResult:\n        \"\"\"Apply Jinja chat template\"\"\"\n        prompt = \"\"\n        prompt_ids = []\n        openai_compatible_messages = []\n        image_data = []\n        video_data = []\n        audio_data = []\n        modalities = []\n\n        template_content_format = self.template_manager.jinja_template_content_format\n\n        if self.use_dpsk_v32_encoding:\n            thinking_mode = (\n                \"thinking\"\n                if (request.chat_template_kwargs or {}).get(\"thinking\")\n                else \"chat\"\n            )\n            messages = request.messages\n            messages = [msg.model_dump() for msg in messages]\n\n            for msg in messages:\n                if msg.get(\"content\") is None:\n                    msg[\"content\"] = \"\"\n                processed_msg = process_content_for_template_format(\n                    msg,\n                    template_content_format,\n                    image_data,\n                    video_data,\n                    audio_data,\n                    modalities,\n                    use_dpsk_v32_encoding=self.use_dpsk_v32_encoding,\n                )\n                msg.update(processed_msg)\n\n            # Handle continue_final_message: separate final assistant message\n            messages, assistant_prefix = self._handle_last_assistant_message(\n                messages, request\n            )\n\n            if messages[0][\"role\"] != \"system\":\n                # insert an empty system prompt to help render tool system prompt\n                messages.insert(0, {\"role\": \"system\", \"content\": \"\"})\n            if request.tools:\n     
           messages[0][\"tools\"] = [tool.model_dump() for tool in request.tools]\n            real_input = encode_messages(messages, thinking_mode=thinking_mode)\n            prompt_ids = self.tokenizer_manager.tokenizer.encode(real_input)\n\n            # Append assistant prefix if continue_final_message is enabled\n            if assistant_prefix:\n                prompt_ids = self._append_assistant_prefix_to_prompt_ids(\n                    prompt_ids, assistant_prefix\n                )\n        else:\n            for message in request.messages:\n                if message.content is None:\n                    message.content = \"\"\n                msg_dict = message.model_dump()\n\n                # Process content based on detected template format\n                processed_msg = process_content_for_template_format(\n                    msg_dict,\n                    template_content_format,\n                    image_data,\n                    video_data,\n                    audio_data,\n                    modalities,\n                )\n\n                # per the Transformers docs & maintainers, tool call arguments in\n                # assistant-role messages with tool_calls need to be dicts not JSON str -\n                # this is how tool-use chat templates will expect them moving forwards\n                # so, for messages that have tool_calls, parse the string (which we get\n                # from openAI format) to dict\n                if (\n                    processed_msg[\"role\"] == \"assistant\"\n                    and \"tool_calls\" in processed_msg\n                    and isinstance(processed_msg[\"tool_calls\"], list)\n                ):\n                    for item in processed_msg[\"tool_calls\"]:\n                        if \"arguments\" in item[\"function\"] and isinstance(\n                            item[\"function\"][\"arguments\"], str\n                        ):\n                            
item[\"function\"][\"arguments\"] = orjson.loads(\n                                item[\"function\"][\"arguments\"]\n                            )\n\n                openai_compatible_messages.append(processed_msg)\n\n            # Handle continue_final_message: separate final assistant message\n            openai_compatible_messages, assistant_prefix = (\n                self._handle_last_assistant_message(openai_compatible_messages, request)\n            )\n\n            extra_template_kwargs = {}\n            if request.reasoning_effort is not None:\n                extra_template_kwargs[\"reasoning_effort\"] = request.reasoning_effort\n            if request.chat_template_kwargs:\n                extra_template_kwargs.update(request.chat_template_kwargs)\n\n            try:\n                prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template(\n                    openai_compatible_messages,\n                    tokenize=True,\n                    add_generation_prompt=True,\n                    tools=tools,\n                    return_dict=False,\n                    **extra_template_kwargs,\n                )\n            except Exception as e:\n                # If the first attempt fails, try with flat function-only format.\n                # Some templates (e.g. 
Mistral) expect tools without the OpenAI wrapper.\n                tools = (\n                    [t[\"function\"] if \"function\" in t else t for t in tools]\n                    if tools\n                    else None\n                )\n                try:\n                    prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template(\n                        openai_compatible_messages,\n                        tokenize=True,\n                        add_generation_prompt=True,\n                        tools=tools,\n                        return_dict=False,\n                        **extra_template_kwargs,\n                    )\n                except jinja2.TemplateError as template_error:\n                    # Template errors (e.g., from raise_exception in Jinja templates)\n                    # should be treated as client errors (400 BadRequest)\n                    raise ValueError(str(template_error)) from template_error\n\n            # Append assistant prefix if continue_final_message is enabled\n            if assistant_prefix:\n                prompt_ids = self._append_assistant_prefix_to_prompt_ids(\n                    prompt_ids, assistant_prefix\n                )\n\n            if is_multimodal:\n                prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids)\n\n        stop = request.stop\n        image_data = image_data if image_data else None\n        audio_data = audio_data if audio_data else None\n        video_data = video_data if video_data else None\n        modalities = modalities if modalities else []\n        return MessageProcessingResult(\n            prompt=prompt,\n            prompt_ids=prompt_ids,\n            image_data=image_data,\n            video_data=video_data,\n            audio_data=audio_data,\n            modalities=modalities,\n            stop=stop,\n        )\n\n    def _apply_conversation_template(\n        self,\n        request: ChatCompletionRequest,\n        is_multimodal: bool,\n    ) -> 
MessageProcessingResult:\n        \"\"\"Apply conversation template\"\"\"\n        prompt = \"\"\n        prompt_ids = []\n        conv = generate_chat_conv(request, self.template_manager.chat_template_name)\n\n        # If we should continue the final assistant message, adjust the conversation.\n        if (\n            request.continue_final_message\n            and request.messages\n            and request.messages[-1].role == \"assistant\"\n        ):\n            # Remove the auto-added blank assistant turn, if present.\n            if conv.messages and conv.messages[-1][1] is None:\n                conv.messages.pop()\n            # Rebuild the prompt from the conversation.\n            prompt = conv.get_prompt()\n            # Strip trailing stop tokens or separators that indicate end-of-assistant.\n            if isinstance(conv.stop_str, list):\n                for stop_token in conv.stop_str:\n                    if prompt.endswith(stop_token):\n                        prompt = prompt[: -len(stop_token)]\n            elif isinstance(conv.stop_str, str) and prompt.endswith(conv.stop_str):\n                prompt = prompt[: -len(conv.stop_str)]\n            if conv.sep and prompt.endswith(conv.sep):\n                prompt = prompt[: -len(conv.sep)]\n            if getattr(conv, \"sep2\", None) and prompt.endswith(conv.sep2):\n                prompt = prompt[: -len(conv.sep2)]\n        else:\n            prompt = conv.get_prompt()\n            if self._get_reasoning_from_request(\n                request\n            ) and self.reasoning_parser not in [\"qwen3\", \"qwen3-thinking\", \"glm4\"]:\n                # qwen3 and glm4 think internally without a leading <think> token\n                prompt += \"<think>\"  # Note(Xinyuan): hard code thinking token\n\n        image_data = conv.image_data if conv.image_data else None\n        video_data = conv.video_data if conv.video_data else None\n        audio_data = conv.audio_data if conv.audio_data else None\n 
       modalities = conv.modalities if conv.modalities else []\n        stop = copy.copy(conv.stop_str or [] if not request.ignore_eos else [])\n\n        if request.stop:\n            if isinstance(request.stop, str):\n                stop.append(request.stop)\n            else:\n                stop.extend(request.stop)\n\n        if not is_multimodal:\n            prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)\n\n        return MessageProcessingResult(\n            prompt=prompt,\n            prompt_ids=prompt_ids,\n            image_data=image_data,\n            video_data=video_data,\n            audio_data=audio_data,\n            modalities=modalities,\n            stop=stop,\n        )\n\n    async def _handle_streaming_request(\n        self,\n        adapted_request: GenerateReqInput,\n        request: ChatCompletionRequest,\n        raw_request: Request,\n    ) -> StreamingResponse:\n        \"\"\"Handle streaming chat completion request\"\"\"\n        return StreamingResponse(\n            self._generate_chat_stream(adapted_request, request, raw_request),\n            media_type=\"text/event-stream\",\n            background=self.tokenizer_manager.create_abort_task(adapted_request),\n        )\n\n    async def _generate_chat_stream(\n        self,\n        adapted_request: GenerateReqInput,\n        request: ChatCompletionRequest,\n        raw_request: Request,\n    ) -> AsyncGenerator[str, None]:\n        \"\"\"Generate streaming chat completion response\"\"\"\n        # Parsers for tool calls and reasoning\n        parser_dict = {}\n        reasoning_parser_dict = {}\n\n        # State tracking for streaming\n        is_firsts = {}\n        stream_buffers = {}\n        n_prev_tokens = {}\n        has_tool_calls = {}\n        finish_reasons = {}\n\n        # Usage tracking\n        prompt_tokens = {}\n        completion_tokens = {}\n        cached_tokens = {}\n        hidden_states = {}\n        routed_experts = {}\n\n        try:\n       
     async for content in self.tokenizer_manager.generate_request(\n                adapted_request, raw_request\n            ):\n                index = content.get(\"index\", 0)\n\n                prompt_tokens[index] = content[\"meta_info\"].get(\"prompt_tokens\", 0)\n                completion_tokens[index] = content[\"meta_info\"].get(\n                    \"completion_tokens\", 0\n                )\n                cached_tokens[index] = content[\"meta_info\"].get(\"cached_tokens\", 0)\n                hidden_states[index] = content[\"meta_info\"].get(\"hidden_states\", None)\n                routed_experts[index] = content[\"meta_info\"].get(\"routed_experts\", None)\n\n                # Handle logprobs\n                finish_reason = content[\"meta_info\"].get(\"finish_reason\", None)\n                choice_logprobs = None\n                if request.logprobs:\n                    n_prev_token = n_prev_tokens.get(index, 0)\n                    total_output_logprobs = len(\n                        content[\"meta_info\"][\"output_token_logprobs\"]\n                    )\n                    # When finish_reason is set and all logprobs have been sent,\n                    # any remaining text is just buffered text being flushed by the\n                    # detokenizer (it holds back text at word boundaries). 
Return None\n                    # for logprobs since no new tokens were generated for this text.\n                    if n_prev_token < total_output_logprobs or finish_reason is None:\n                        choice_logprobs = self._process_streaming_logprobs(\n                            content, n_prev_token\n                        )\n                    n_prev_tokens[index] = total_output_logprobs\n                finish_reason_type = finish_reason[\"type\"] if finish_reason else None\n\n                # Track finish_reason for each index\n                if finish_reason_type:\n                    # If the abort is from scheduler.\n                    if finish_reason_type == \"abort\":\n                        code = finish_reason.get(\n                            \"status_code\", HTTPStatus.INTERNAL_SERVER_ERROR\n                        )\n                        error = self.create_streaming_error_response(\n                            finish_reason.get(\"message\", \"Generation aborted.\"),\n                            code.name,\n                            code.value,\n                        )\n                        yield f\"data: {error}\\n\\n\"\n                        break\n                    else:\n                        finish_reasons[index] = finish_reason\n\n                # First chunk with role\n                if is_firsts.get(index, True):\n                    is_firsts[index] = False\n                    delta = DeltaMessage(role=\"assistant\", content=\"\")\n                    choice_data = ChatCompletionResponseStreamChoice(\n                        index=index,\n                        delta=delta,\n                        finish_reason=None,\n                        logprobs=None,\n                    )\n                    chunk = ChatCompletionStreamResponse(\n                        id=content[\"meta_info\"][\"id\"],\n                        created=int(time.time()),\n                        choices=[choice_data],\n           
             model=request.model,\n                    )\n                    yield f\"data: {chunk.model_dump_json()}\\n\\n\"\n\n                stream_buffer = stream_buffers.get(index, \"\")\n                delta = content[\"text\"][len(stream_buffer) :]\n                stream_buffers[index] = stream_buffer + delta\n\n                # Handle reasoning content\n                if self.reasoning_parser and request.separate_reasoning:\n                    reasoning_text, delta = self._process_reasoning_stream(\n                        index, delta, reasoning_parser_dict, content, request\n                    )\n                    if reasoning_text:\n                        choice_data = ChatCompletionResponseStreamChoice(\n                            index=index,\n                            delta=DeltaMessage(reasoning_content=reasoning_text),\n                            finish_reason=None,\n                        )\n                        chunk = ChatCompletionStreamResponse(\n                            id=content[\"meta_info\"][\"id\"],\n                            created=int(time.time()),\n                            choices=[choice_data],\n                            model=request.model,\n                        )\n\n                        # Add usage stats if continuous_usage_stats is enabled\n                        if (\n                            request.stream_options\n                            and request.stream_options.continuous_usage_stats\n                        ):\n                            chunk.usage = UsageProcessor.calculate_token_usage(\n                                prompt_tokens=prompt_tokens.get(index, 0),\n                                completion_tokens=completion_tokens.get(index, 0),\n                            )\n\n                        yield f\"data: {chunk.model_dump_json()}\\n\\n\"\n\n                # Handle tool calls\n                if (\n                    request.tool_choice != \"none\"\n                  
  and request.tools\n                    and self.tool_call_parser\n                ):\n                    async for chunk in self._process_tool_call_stream(\n                        index,\n                        delta,\n                        parser_dict,\n                        content,\n                        request,\n                        has_tool_calls,\n                    ):\n                        if chunk:\n                            yield chunk\n\n                    # Send any remaining tool call arguments when generation finishes\n                    if finish_reason_type is not None and index in parser_dict:\n                        parser = parser_dict[index]\n                        remaining_chunk = self._check_for_unstreamed_tool_args(\n                            parser, content, request, index\n                        )\n                        if remaining_chunk:\n                            yield remaining_chunk\n\n                else:\n                    # Regular content\n                    if delta:\n                        choice_data = ChatCompletionResponseStreamChoice(\n                            index=index,\n                            delta=DeltaMessage(content=delta),\n                            finish_reason=None,\n                            matched_stop=None,\n                            logprobs=choice_logprobs,\n                        )\n                        chunk = ChatCompletionStreamResponse(\n                            id=content[\"meta_info\"][\"id\"],\n                            created=int(time.time()),\n                            choices=[choice_data],\n                            model=request.model,\n                        )\n\n                        # Add usage stats if continuous_usage_stats is enabled\n                        if (\n                            request.stream_options\n                            and request.stream_options.continuous_usage_stats\n                        ):\n    
                        chunk.usage = UsageProcessor.calculate_token_usage(\n                                prompt_tokens=prompt_tokens.get(index, 0),\n                                completion_tokens=completion_tokens.get(index, 0),\n                            )\n\n                        yield f\"data: {chunk.model_dump_json()}\\n\\n\"\n\n            # Send finish_reason chunks for each index that completed\n            for idx, finish_reason_data in finish_reasons.items():\n                finish_reason_type = finish_reason_data[\"type\"]\n\n                # Change finish_reason to \"tool_calls\" if we had tool calls and stopped naturally\n                final_finish_reason = finish_reason_type\n                if has_tool_calls.get(idx, False) and finish_reason_type == \"stop\":\n                    final_finish_reason = \"tool_calls\"\n\n                finish_reason_chunk = ChatCompletionStreamResponse(\n                    id=content[\"meta_info\"][\n                        \"id\"\n                    ],  # NOTE: openai uses the same chatcmpl-id for all indices\n                    created=int(time.time()),\n                    choices=[\n                        ChatCompletionResponseStreamChoice(\n                            index=idx,\n                            delta=DeltaMessage(),\n                            finish_reason=final_finish_reason,\n                            matched_stop=(\n                                finish_reason_data[\"matched\"]\n                                if \"matched\" in finish_reason_data\n                                else None\n                            ),\n                        )\n                    ],\n                    model=request.model,\n                    usage=None,\n                )\n                yield f\"data: {finish_reason_chunk.model_dump_json()}\\n\\n\"\n\n            # Send hidden states if requested\n            if request.return_hidden_states and hidden_states:\n                for 
index, choice_hidden_states in hidden_states.items():\n                    if choice_hidden_states:\n                        last_token_hidden_states = (\n                            choice_hidden_states[-1]\n                            if len(choice_hidden_states) > 1\n                            else []\n                        )\n                        hidden_states_chunk = ChatCompletionStreamResponse(\n                            id=content[\"meta_info\"][\"id\"],\n                            created=int(time.time()),\n                            choices=[\n                                ChatCompletionResponseStreamChoice(\n                                    index=index,\n                                    delta=DeltaMessage(\n                                        hidden_states=last_token_hidden_states\n                                    ),\n                                    finish_reason=None,  # Hidden states don't need finish_reason\n                                )\n                            ],\n                            model=request.model,\n                        )\n                        yield f\"data: {hidden_states_chunk.model_dump_json()}\\n\\n\"\n\n            if request.return_routed_experts and routed_experts:\n                # Get first non-None routed_experts value\n                first_routed_experts = next(\n                    (v for v in routed_experts.values() if v is not None), None\n                )\n                if first_routed_experts is not None:\n                    routed_experts_chunk = ChatCompletionStreamResponse(\n                        id=content[\"meta_info\"][\"id\"],\n                        created=int(time.time()),\n                        choices=[],  # sglext is at response level\n                        model=request.model,\n                        sglext=SglExt(routed_experts=first_routed_experts),\n                    )\n                    yield f\"data: 
{routed_experts_chunk.model_dump_json()}\\n\\n\"\n\n            # Additional usage chunk\n            if request.stream_options and request.stream_options.include_usage:\n                usage = UsageProcessor.calculate_streaming_usage(\n                    prompt_tokens,\n                    completion_tokens,\n                    cached_tokens,\n                    n_choices=request.n,\n                    enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,\n                )\n                usage_chunk = ChatCompletionStreamResponse(\n                    id=content[\"meta_info\"][\"id\"],\n                    created=int(time.time()),\n                    choices=[],  # Empty choices array as per OpenAI spec\n                    model=request.model,\n                    usage=usage,\n                )\n                yield f\"data: {usage_chunk.model_dump_json()}\\n\\n\"\n\n        except ValueError as e:\n            error = self.create_streaming_error_response(str(e))\n            yield f\"data: {error}\\n\\n\"\n\n        yield \"data: [DONE]\\n\\n\"\n\n    async def _handle_non_streaming_request(\n        self,\n        adapted_request: GenerateReqInput,\n        request: ChatCompletionRequest,\n        raw_request: Request,\n    ) -> Union[ChatCompletionResponse, ErrorResponse, ORJSONResponse]:\n        \"\"\"Handle non-streaming chat completion request\"\"\"\n        try:\n            ret = await self.tokenizer_manager.generate_request(\n                adapted_request, raw_request\n            ).__anext__()\n        except ValueError as e:\n            return self.create_error_response(str(e))\n\n        if not isinstance(ret, list):\n            ret = [ret]\n\n        response = self._build_chat_response(\n            request,\n            ret,\n            int(time.time()),\n        )\n\n        return response\n\n    def _build_chat_response(\n        self,\n        request: ChatCompletionRequest,\n        ret: List[Dict[str, 
Any]],\n        created: int,\n    ) -> Union[ChatCompletionResponse, ORJSONResponse]:\n        \"\"\"Build chat completion response from generation results\"\"\"\n        choices = []\n\n        # Build sglext at response level (from first ret_item, as these are per-request)\n        first_ret = ret[0]\n        routed_experts = process_routed_experts_from_ret(first_ret, request)\n        cached_tokens_details = process_cached_tokens_details_from_ret(\n            first_ret, request\n        )\n        response_sglext = None\n        if routed_experts or cached_tokens_details:\n            response_sglext = SglExt(\n                routed_experts=routed_experts,\n                cached_tokens_details=cached_tokens_details,\n            )\n\n        for idx, ret_item in enumerate(ret):\n            # Process logprobs\n            choice_logprobs = None\n            if request.logprobs:\n                choice_logprobs = self._process_response_logprobs(ret_item)\n\n            # Handle hidden states\n            hidden_states = process_hidden_states_from_ret(ret_item, request)\n\n            finish_reason = ret_item[\"meta_info\"][\"finish_reason\"]\n            text = ret_item[\"text\"]\n\n            # Handle reasoning content\n            reasoning_text = None\n            reasoning_parser = self.reasoning_parser\n            if reasoning_parser and request.separate_reasoning:\n                is_force_reasoning = (\n                    self.template_manager.force_reasoning\n                    or self._get_reasoning_from_request(request)\n                )\n                try:\n                    parser = ReasoningParser(\n                        model_type=reasoning_parser,\n                        stream_reasoning=False,\n                        force_reasoning=is_force_reasoning,\n                        request=request,\n                    )\n                    reasoning_text, text = parser.parse_non_stream(text)\n                except Exception as e:\n  
                  logger.error(f\"Reasoning parsing error: {e}\")\n                    return self.create_error_response(\n                        \"Failed to parse reasoning content\",\n                        err_type=\"InternalServerError\",\n                        status_code=500,\n                    )\n\n            # Handle tool calls\n            tool_calls = None\n            if (\n                request.tool_choice != \"none\"\n                and request.tools\n                and self.tool_call_parser\n            ):\n                history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)\n                tool_calls, text, finish_reason = self._process_tool_calls(\n                    text,\n                    request.tools,\n                    finish_reason,\n                    request.tool_choice,\n                    history_tool_calls_cnt,\n                )\n\n            choice_data = ChatCompletionResponseChoice(\n                index=idx,\n                message=ChatMessage(\n                    role=\"assistant\",\n                    content=text if text else None,\n                    tool_calls=tool_calls,\n                    reasoning_content=reasoning_text if reasoning_text else None,\n                ),\n                logprobs=choice_logprobs,\n                finish_reason=finish_reason[\"type\"] if finish_reason else None,\n                matched_stop=(\n                    finish_reason[\"matched\"]\n                    if finish_reason and \"matched\" in finish_reason\n                    else None\n                ),\n                hidden_states=hidden_states,\n            )\n            choices.append(choice_data)\n\n        # Calculate usage\n        usage = UsageProcessor.calculate_response_usage(\n            ret,\n            n_choices=request.n,\n            enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,\n        )\n\n        return ChatCompletionResponse(\n            
id=ret[0][\"meta_info\"][\"id\"],\n            created=created,\n            model=request.model,\n            choices=choices,\n            usage=usage,\n            metadata={\"weight_version\": ret[0][\"meta_info\"][\"weight_version\"]},\n            sglext=response_sglext,\n        )\n\n    def _process_logprobs_tokens(\n        self, logprobs: LogProbs, use_token_index: bool = False\n    ) -> List[ChatCompletionTokenLogprob]:\n        \"\"\"Common helper to process logprobs tokens for both streaming and non-streaming\n\n        Args:\n            logprobs: LogProbs data from model\n            use_token_index: True for non-streaming (use token_idx), False for streaming (use index 0)\n        \"\"\"\n        token_logprobs = []\n\n        for token_idx, (token, logprob) in enumerate(\n            zip(logprobs.tokens, logprobs.token_logprobs)\n        ):\n            token_bytes = list(token.encode(\"utf-8\"))\n            top_logprobs = []\n            if logprobs.top_logprobs:\n                # - Non-streaming (use_token_index=True): uses token_idx for full data\n                # - Streaming (use_token_index=False): uses index 0 for pre-sliced data\n                top_logprobs_idx = token_idx if use_token_index else 0\n                for top_token, top_logprob in logprobs.top_logprobs[\n                    top_logprobs_idx\n                ].items():\n                    top_token_bytes = list(top_token.encode(\"utf-8\"))\n                    top_logprobs.append(\n                        TopLogprob(\n                            token=top_token,\n                            bytes=top_token_bytes,\n                            logprob=top_logprob,\n                        )\n                    )\n            token_logprobs.append(\n                ChatCompletionTokenLogprob(\n                    token=token,\n                    bytes=token_bytes,\n                    logprob=logprob,\n                    top_logprobs=top_logprobs,\n                )\n       
     )\n\n        return token_logprobs\n\n    def _process_response_logprobs(self, ret_item: Dict[str, Any]) -> ChoiceLogprobs:\n        \"\"\"Process logprobs for non-streaming response\"\"\"\n        logprobs = to_openai_style_logprobs(\n            output_token_logprobs=ret_item[\"meta_info\"][\"output_token_logprobs\"],\n            output_top_logprobs=ret_item[\"meta_info\"].get(\"output_top_logprobs\", None),\n        )\n\n        token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)\n        return ChoiceLogprobs(content=token_logprobs)\n\n    def _process_tool_call_id(\n        self,\n        call_item: ToolCallItem,\n        history_tool_calls_cnt: int,\n    ) -> str:\n        \"\"\"Generate a new, unique `tool_call_id`.\"\"\"\n        if self.tool_call_parser != \"kimi_k2\":\n            # A simple uuid is sufficient for all models except for Kimi-K2.\n            tool_call_id = f\"call_{uuid.uuid4().hex[:24]}\"\n            return tool_call_id\n        else:\n            # Align with Kimi-K2 format: functions.{name}:{index}\n            # Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.\n            # Therefore, the index must be corrected to `history_tool_calls_cnt + call_item.tool_index` so that IDs are globally unique and properly ordered.\n            tool_call_id = f\"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}\"\n            logger.debug(\n                f\"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}\"\n            )\n            return tool_call_id\n\n    def _process_tool_calls(\n        self,\n        text: str,\n        tools: List[Any],\n        finish_reason: Dict[str, Any],\n        tool_choice: Optional[Union[str, ToolChoice]] = None,\n        history_tool_calls_cnt: int = 0,\n    ) -> ToolCallProcessingResult:\n 
       \"\"\"Process tool calls in the response\"\"\"\n\n        # Handle required or named tool choice\n        if tool_choice == \"required\" or (\n            isinstance(tool_choice, ToolChoice) and tool_choice.type == \"function\"\n        ):\n            # Set finish reason to tool_calls since we're processing tool calls\n            if finish_reason[\"type\"] == \"stop\":\n                finish_reason[\"type\"] = \"tool_calls\"\n                finish_reason[\"matched\"] = None\n            try:\n                # For required tool choice, we expect a JSON array of tool calls\n                tool_call_data = orjson.loads(text)\n                tool_calls = []\n                for i, tool in enumerate(tool_call_data):\n                    # Create a ToolCallItem from the JSON data\n                    call_info = ToolCallItem(\n                        tool_index=i,  # Use the loop index as tool_index\n                        name=tool[\"name\"],\n                        parameters=json.dumps(tool[\"parameters\"], ensure_ascii=False),\n                    )\n                    tool_id = self._process_tool_call_id(\n                        call_info, history_tool_calls_cnt\n                    )\n                    tool_calls.append(\n                        ToolCall(\n                            id=tool_id,\n                            index=i,\n                            function=FunctionResponse(\n                                name=tool[\"name\"],\n                                arguments=json.dumps(\n                                    tool[\"parameters\"], ensure_ascii=False\n                                ),\n                            ),\n                        )\n                    )\n                return ToolCallProcessingResult(tool_calls, \"\", finish_reason)\n            except json.JSONDecodeError as e:\n                logger.error(f\"Tool call parsing error: {e}\")\n                return ToolCallProcessingResult(None, text, 
finish_reason)\n\n        # Use parser since output is not constrained by JSON schema\n        parser = FunctionCallParser(tools, self.tool_call_parser)\n        if parser.has_tool_call(text):\n            if finish_reason[\"type\"] == \"stop\":\n                finish_reason[\"type\"] = \"tool_calls\"\n                finish_reason[\"matched\"] = None\n            try:\n                text, call_info_list = parser.parse_non_stream(text)\n                tool_calls = []\n                for call_info in call_info_list:\n                    tool_id = self._process_tool_call_id(\n                        call_info, history_tool_calls_cnt\n                    )\n                    tool_calls.append(\n                        ToolCall(\n                            id=tool_id,\n                            index=getattr(call_info, \"tool_index\", None),\n                            function=FunctionResponse(\n                                name=call_info.name, arguments=call_info.parameters\n                            ),\n                        )\n                    )\n                return ToolCallProcessingResult(tool_calls, text, finish_reason)\n            except Exception as e:\n                logger.error(f\"Tool call parsing error: {e}\")\n                # Return error but don't fail the whole request\n                return ToolCallProcessingResult(None, text, finish_reason)\n\n        return ToolCallProcessingResult(None, text, finish_reason)\n\n    def _process_streaming_logprobs(\n        self, content: Dict[str, Any], n_prev_token: int\n    ) -> ChoiceLogprobs:\n        \"\"\"Process logprobs for streaming response\"\"\"\n        logprobs = to_openai_style_logprobs(\n            output_token_logprobs=content[\"meta_info\"][\"output_token_logprobs\"][\n                n_prev_token:\n            ],\n            output_top_logprobs=content[\"meta_info\"].get(\"output_top_logprobs\", [])[\n                n_prev_token:\n            ],\n        )\n\n        
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=False)\n        return ChoiceLogprobs(content=token_logprobs)\n\n    def _process_reasoning_stream(\n        self,\n        index: int,\n        delta: str,\n        reasoning_parser_dict: Dict[int, ReasoningParser],\n        content: Dict[str, Any],\n        request: ChatCompletionRequest,\n    ) -> tuple[Optional[str], str]:\n        \"\"\"Process reasoning content in streaming response\"\"\"\n        if index not in reasoning_parser_dict:\n            is_force_reasoning = (\n                self.template_manager.force_reasoning\n                or self._get_reasoning_from_request(request)\n            )\n            reasoning_parser_dict[index] = ReasoningParser(\n                self.reasoning_parser,\n                request.stream_reasoning,\n                is_force_reasoning,\n                request,\n            )\n        reasoning_parser = reasoning_parser_dict[index]\n        return reasoning_parser.parse_stream_chunk(delta)\n\n    def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:\n        \"\"\"Counts the number of tool calls in the request's message history.\n\n        NOTE: This method is only useful for models that embed an auto-incrementing\n        history tool call index in their tool call IDs, such as Kimi-K2.\n\n        Args:\n            request: The chat completion request object.\n\n        Returns:\n            The total number of tool calls in the history, or 0 if not applicable.\n        \"\"\"\n        messages = getattr(request, \"messages\", [])\n        idx = 0\n        for msg in messages:\n            if msg.role == \"assistant\":\n                tool_calls = getattr(msg, \"tool_calls\", None)\n                idx += len(list(tool_calls)) if tool_calls is not None else 0  # noqa\n        return idx\n\n    def _patch_mistral_skip_special_tokens(\n        self, request: ChatCompletionRequest\n    ) -> None:\n        \"\"\"Mistral uses special 
tokens ([THINK]/[/THINK]) for reasoning markers,\n        which get stripped when skip_special_tokens=True.\"\"\"\n        if (\n            self.reasoning_parser in [\"mistral\"]\n            and request.reasoning_effort is not None\n            and request.reasoning_effort != \"none\"\n        ):\n            request.skip_special_tokens = False\n\n    def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool:\n        \"\"\"Determine whether the request needs reasoning for hybrid reasoning models.\n        NOTE: This is predefined based on the model's chat template.\n        \"\"\"\n        if not self.reasoning_parser:\n            return False\n        if self.reasoning_parser in [\"deepseek-v3\"]:\n            # Models that require thinking to be explicitly enabled (thinking=True)\n            return (\n                request.chat_template_kwargs is not None\n                and request.chat_template_kwargs.get(\"thinking\") is True\n            )\n        if self.reasoning_parser in [\"kimi_k2\"]:\n            # Models that think by default; thinking can be disabled by setting thinking=False\n            return (\n                not request.chat_template_kwargs\n                or request.chat_template_kwargs.get(\"thinking\") is not False\n            )\n        if self.reasoning_parser in [\"qwen3\", \"glm45\", \"nemotron_3\", \"interns1\"]:\n            # Models that think by default; thinking can be disabled by setting enable_thinking=False\n            return (\n                not request.chat_template_kwargs\n                or request.chat_template_kwargs.get(\"enable_thinking\") is not False\n            )\n        if self.reasoning_parser in [\"mistral\"]:\n            # Mistral models only reason when reasoning_effort is explicitly\n            # set to a value other than None/\"none\" (typically \"high\").\n            return (\n                request.reasoning_effort is not None\n                and request.reasoning_effort != \"none\"\n          
  )\n        return True  # default\n\n    async def _process_tool_call_stream(\n        self,\n        index: int,\n        delta: str,\n        parser_dict: Dict[int, FunctionCallParser],\n        content: Dict[str, Any],\n        request: ChatCompletionRequest,\n        has_tool_calls: Dict[int, bool],\n    ):\n        \"\"\"Process tool calls in streaming response\"\"\"\n        if index not in parser_dict:\n            # Use JSON detector directly for required or named tool choice\n            if request.tool_choice == \"required\" or isinstance(\n                request.tool_choice, ToolChoice\n            ):\n                parser_dict[index] = JsonArrayParser()\n            else:\n                parser_dict[index] = FunctionCallParser(\n                    tools=request.tools,\n                    tool_call_parser=self.tool_call_parser,\n                )\n\n        parser = parser_dict[index]\n\n        # Handle both FunctionCallParser and JsonArrayParser\n        if isinstance(parser, JsonArrayParser):\n            result = parser.parse_streaming_increment(delta, request.tools)\n            normal_text, calls = result.normal_text, result.calls\n        else:\n            normal_text, calls = parser.parse_stream_chunk(delta)\n\n        # Yield normal text\n        if normal_text:\n            choice_data = ChatCompletionResponseStreamChoice(\n                index=index,\n                delta=DeltaMessage(content=normal_text),\n                finish_reason=None,\n            )\n            chunk = ChatCompletionStreamResponse(\n                id=content[\"meta_info\"][\"id\"],\n                created=int(time.time()),\n                choices=[choice_data],\n                model=request.model,\n            )\n\n            # Add usage stats if continuous_usage_stats is enabled\n            if request.stream_options and request.stream_options.continuous_usage_stats:\n                prompt_tokens = content[\"meta_info\"].get(\"prompt_tokens\", 0)\n   
             completion_tokens = content[\"meta_info\"].get(\"completion_tokens\", 0)\n                chunk.usage = UsageProcessor.calculate_token_usage(\n                    prompt_tokens=prompt_tokens,\n                    completion_tokens=completion_tokens,\n                )\n\n            yield f\"data: {chunk.model_dump_json()}\\n\\n\"\n\n        # Yield tool calls\n        history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)\n        for call_item in calls:\n            # Mark that this choice has tool calls\n            has_tool_calls[index] = True\n\n            # Tool call ID should be generated only once per tool call\n            if call_item.name:\n                # First chunk: include ID and function name\n                tool_call_id = self._process_tool_call_id(\n                    call_item, history_tool_calls_cnt\n                )\n                function_name = call_item.name\n            else:\n                # Subsequent chunks: null ID and name for argument deltas\n                tool_call_id = None\n                function_name = None\n\n            tool_call = ToolCall(\n                id=tool_call_id,\n                index=call_item.tool_index,\n                function=FunctionResponse(\n                    name=function_name,\n                    arguments=call_item.parameters,\n                ),\n            )\n\n            choice_data = ChatCompletionResponseStreamChoice(\n                index=index,\n                delta=DeltaMessage(tool_calls=[tool_call]),\n                finish_reason=None,\n            )\n            chunk = ChatCompletionStreamResponse(\n                id=content[\"meta_info\"][\"id\"],\n                created=int(time.time()),\n                choices=[choice_data],\n                model=request.model,\n            )\n\n            # Add usage stats if continuous_usage_stats is enabled\n            if request.stream_options and request.stream_options.continuous_usage_stats:\n      
          prompt_tokens = content[\"meta_info\"].get(\"prompt_tokens\", 0)\n                completion_tokens = content[\"meta_info\"].get(\"completion_tokens\", 0)\n                chunk.usage = UsageProcessor.calculate_token_usage(\n                    prompt_tokens=prompt_tokens,\n                    completion_tokens=completion_tokens,\n                )\n\n            yield f\"data: {chunk.model_dump_json()}\\n\\n\"\n\n    def _check_for_unstreamed_tool_args(\n        self,\n        parser: Union[FunctionCallParser, JsonArrayParser],\n        content: Dict[str, Any],\n        request: ChatCompletionRequest,\n        index: int,\n    ) -> Optional[str]:\n        \"\"\"\n        Check for any remaining tool call arguments that need to be streamed\n        when generation finishes. This ensures tool calls are properly completed\n        even if the model generates the final arguments in the last chunk.\n        \"\"\"\n        # Get the detector - either from FunctionCallParser or directly if json detector\n        detector = parser.detector if hasattr(parser, \"detector\") else parser\n\n        # Only check if we have tool calls and the detector has tracked data\n        if (\n            not hasattr(detector, \"prev_tool_call_arr\")\n            or not detector.prev_tool_call_arr\n        ):\n            return None\n\n        if (\n            not hasattr(detector, \"streamed_args_for_tool\")\n            or not detector.streamed_args_for_tool\n        ):\n            return None\n\n        # Get the last tool call that was being processed\n        tool_index = len(detector.prev_tool_call_arr) - 1\n        if tool_index < 0 or tool_index >= len(detector.streamed_args_for_tool):\n            return None\n\n        # Get expected vs actual arguments\n        expected_args = detector.prev_tool_call_arr[tool_index].get(\"arguments\", {})\n        expected_call = json.dumps(expected_args, ensure_ascii=False)\n        actual_call = 
detector.streamed_args_for_tool[tool_index]\n\n        # Check if there are remaining arguments to send\n        remaining_call = (\n            expected_call.replace(actual_call, \"\", 1)\n            if actual_call in expected_call\n            else \"\"\n        )\n\n        if remaining_call:\n            # Create tool call chunk with remaining arguments\n            tool_call = ToolCall(\n                id=None,  # No ID for argument deltas\n                index=tool_index,\n                function=FunctionResponse(\n                    name=None,  # No name for argument deltas\n                    arguments=remaining_call,\n                ),\n            )\n\n            choice_data = ChatCompletionResponseStreamChoice(\n                index=index,\n                delta=DeltaMessage(tool_calls=[tool_call]),\n                finish_reason=None,  # Don't send finish_reason with this chunk\n            )\n\n            chunk = ChatCompletionStreamResponse(\n                id=content[\"meta_info\"][\"id\"],\n                created=int(time.time()),\n                choices=[choice_data],\n                model=request.model,\n            )\n\n            return f\"data: {chunk.model_dump_json()}\\n\\n\"\n\n        return None\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/serving_classify.py",
    "content": "from __future__ import annotations\n\nimport logging\nimport time\nimport uuid\nfrom typing import TYPE_CHECKING, Any, Dict, List, Optional, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom fastapi import Request\nfrom fastapi.responses import ORJSONResponse\n\nfrom sglang.srt.entrypoints.openai.protocol import (\n    ClassifyRequest,\n    ClassifyResponse,\n    ErrorResponse,\n)\nfrom sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase\nfrom sglang.srt.managers.io_struct import EmbeddingReqInput\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.template_manager import TemplateManager\n    from sglang.srt.managers.tokenizer_manager import TokenizerManager\n\nlogger = logging.getLogger(__name__)\n\n\nclass OpenAIServingClassify(OpenAIServingBase):\n    \"\"\"Handler for v1/classify requests\"\"\"\n\n    def __init__(\n        self,\n        tokenizer_manager: TokenizerManager,\n        template_manager: TemplateManager,\n    ):\n        super().__init__(tokenizer_manager)\n        self.template_manager = template_manager\n        self.id2label = self._get_id2label_mapping()\n        self.model_name = (\n            self.tokenizer_manager.served_model_name\n            if self.tokenizer_manager.served_model_name\n            else self.tokenizer_manager.server_args.model_path\n        )\n        if not self.id2label:\n            raise ValueError(\"id2label mapping is missing\")\n\n    def _request_id_prefix(self) -> str:\n        return \"classify-\"\n\n    def _convert_to_internal_request(\n        self,\n        request: ClassifyRequest,\n        raw_request: Request = None,\n    ) -> tuple[EmbeddingReqInput, ClassifyRequest]:\n        \"\"\"Convert OpenAI embedding request to internal format\"\"\"\n        prompt = request.input\n\n        if isinstance(prompt, str):\n            # Single string input\n            prompt_kwargs = {\"text\": prompt}\n        elif isinstance(prompt, list):\n            if len(prompt) > 0 
and isinstance(prompt[0], str):\n                prompt_kwargs = {\"text\": prompt}\n            else:\n                # List of integers (token IDs) or empty list\n                prompt_kwargs = {\"input_ids\": prompt}\n        else:\n            # Other types (should not happen but handle gracefully)\n            prompt_kwargs = {\"input_ids\": prompt}\n\n        adapted_request = EmbeddingReqInput(\n            **prompt_kwargs,\n            rid=request.rid,\n            priority=request.priority,\n        )\n\n        return adapted_request, request\n\n    def _validate_request(self, request: ClassifyRequest) -> Optional[str]:\n        \"\"\"Validate that the input is not empty or whitespace only.\"\"\"\n        if not (input := request.input):\n            return \"Input cannot be empty\"\n\n        # Handle single string\n        if isinstance(input, str):\n            if not input.strip():\n                return \"Input cannot be empty or whitespace only\"\n            return None\n\n        # Handle list inputs\n        if isinstance(input, list):\n            # Check first element to determine type\n            first_item = input[0]\n\n            if isinstance(first_item, str):\n                # List of strings\n                for i, item in enumerate(input):\n                    if not isinstance(item, str):\n                        return f\"All items in input list must be strings\"\n                    if not item.strip():\n                        return f\"Input at index {i} cannot be empty or whitespace only\"\n            elif isinstance(first_item, int):\n                # List of integers (token IDs)\n                for i, item in enumerate(input):\n                    if not isinstance(item, int):\n                        return f\"All items in input list must be integers\"\n                    if item < 0:\n                        return f\"Token ID at index {i} must be non-negative\"\n        return None\n\n    def 
_get_id2label_mapping(self) -> Optional[Dict[int, str]]:\n        \"\"\"Get id2label mapping from model config.\"\"\"\n        try:\n            hf_config = self.tokenizer_manager.model_config.hf_config\n            # Check for id2label in hf_config\n            if hf_config.id2label:\n                return hf_config.id2label\n            # Check for num_labels and create default mapping if needed\n            if hasattr(hf_config, \"num_labels\") and hf_config.num_labels:\n                num_labels = hf_config.num_labels\n                # Create default mapping: {0: \"LABEL_0\", 1: \"LABEL_1\", ...}\n                return {i: f\"LABEL_{i}\" for i in range(num_labels)}\n\n        except Exception as e:\n            logger.warning(f\"Failed to get id2label mapping: {e}\")\n\n        return None\n\n    async def _handle_non_streaming_request(\n        self,\n        adapted_request: EmbeddingReqInput,\n        request: ClassifyRequest,\n        raw_request: Request,\n    ) -> Union[ClassifyResponse, ErrorResponse, ORJSONResponse]:\n        \"\"\"Handle non-streaming classification request.\"\"\"\n        # Generate request ID\n\n        try:\n            ret = await self.tokenizer_manager.generate_request(\n                adapted_request, raw_request\n            ).__anext__()\n        except ValueError as e:\n            return self.create_error_response(str(e))\n\n        if not isinstance(ret, list):\n            ret = [ret]\n\n        response = self._build_classify_response(ret)\n        return response\n\n    def _build_classify_response(self, ret: List[Dict[str, Any]]) -> ClassifyResponse:\n        request_id = f\"{self._request_id_prefix()}{uuid.uuid4().hex}\"\n        created_time = int(time.time())\n        classify_objects = []\n        prompt_tokens = 0\n        total_latency = 0.0\n\n        for i, item in enumerate(ret):\n            embedding = item.get(\"embedding\", [])\n            meta_info = item.get(\"meta_info\", {})\n\n            
prompt_tokens += meta_info.get(\"prompt_tokens\", 0)\n            total_latency += meta_info.get(\"e2e_latency\", 0.0)\n\n            if embedding:\n                try:\n                    embedding_tensor = torch.tensor(embedding, dtype=torch.float32)\n                    probs = F.softmax(embedding_tensor, dim=0).tolist()\n\n                    predicted_class = torch.argmax(embedding_tensor).item()\n\n                    label = self.id2label[predicted_class]\n\n                except Exception as e:\n                    logger.error(f\"Error processing embedding for item {i}: {e}\")\n                    probs = [1.0]\n                    label = \"Default\"\n            else:\n                probs = [1.0]\n                label = \"Default\"\n\n            classify_obj = {\n                \"index\": i,\n                \"label\": label,\n                \"probs\": probs,\n                \"num_classes\": len(probs),\n            }\n            classify_objects.append(classify_obj)\n\n        response = {\n            \"id\": request_id,\n            \"object\": \"list\",\n            \"created\": created_time,\n            \"model\": self.model_name,\n            \"data\": classify_objects,\n            \"usage\": {\n                \"prompt_tokens\": prompt_tokens,\n                \"total_tokens\": prompt_tokens,\n                \"completion_tokens\": 0,\n                \"prompt_tokens_details\": None,\n            },\n        }\n\n        return ClassifyResponse(**response)\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/serving_completions.py",
    "content": "from __future__ import annotations\n\nimport logging\nimport time\nfrom http import HTTPStatus\nfrom typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union\n\nfrom fastapi import Request\nfrom fastapi.responses import ORJSONResponse, StreamingResponse\n\nfrom sglang.srt.entrypoints.openai.protocol import (\n    CompletionRequest,\n    CompletionResponse,\n    CompletionResponseChoice,\n    CompletionResponseStreamChoice,\n    CompletionStreamResponse,\n    ErrorResponse,\n    SglExt,\n)\nfrom sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase\nfrom sglang.srt.entrypoints.openai.usage_processor import UsageProcessor\nfrom sglang.srt.entrypoints.openai.utils import (\n    process_cached_tokens_details_from_ret,\n    process_hidden_states_from_ret,\n    process_routed_experts_from_ret,\n    to_openai_style_logprobs,\n)\nfrom sglang.srt.managers.io_struct import GenerateReqInput\nfrom sglang.srt.parser.code_completion_parser import (\n    generate_completion_prompt_from_request,\n)\nfrom sglang.utils import convert_json_schema_to_str\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.template_manager import TemplateManager\n    from sglang.srt.managers.tokenizer_manager import TokenizerManager\n\nlogger = logging.getLogger(__name__)\n\n\nclass OpenAIServingCompletion(OpenAIServingBase):\n    \"\"\"Handler for /v1/completion requests\"\"\"\n\n    def __init__(\n        self,\n        tokenizer_manager: TokenizerManager,\n        template_manager: TemplateManager,\n    ):\n        super().__init__(tokenizer_manager)\n        self.template_manager = template_manager\n\n    def _request_id_prefix(self) -> str:\n        return \"cmpl-\"\n\n    def _validate_request(self, request: CompletionRequest) -> Optional[str]:\n        \"\"\"Validate that the input is valid.\"\"\"\n        prompt = request.prompt\n        if not prompt or (isinstance(prompt, list) and all(not p for p in prompt)):\n            return \"Prompt 
cannot be empty\"\n\n        return None\n\n    def _convert_to_internal_request(\n        self,\n        request: CompletionRequest,\n        raw_request: Request = None,\n    ) -> tuple[GenerateReqInput, CompletionRequest]:\n        \"\"\"Convert OpenAI completion request to internal format\"\"\"\n        # NOTE: with openai API, the prompt's logprobs are always not computed\n        if request.echo and request.logprobs:\n            logger.warning(\n                \"Echo is not compatible with logprobs. \"\n                \"To compute logprobs of input prompt, please use the native /generate API.\"\n            )\n        # Process prompt\n        prompt = request.prompt\n        if self.template_manager.completion_template_name is not None:\n            prompt = generate_completion_prompt_from_request(request)\n\n        # Set logprob start length based on echo and logprobs\n        if request.echo and request.logprobs:\n            logprob_start_len = 0\n        else:\n            logprob_start_len = -1\n\n        # Build sampling parameters\n        sampling_params = self._build_sampling_params(request)\n\n        # Determine prompt format\n        if isinstance(prompt, str) or (\n            isinstance(prompt, list) and isinstance(prompt[0], str)\n        ):\n            prompt_kwargs = {\"text\": prompt}\n        else:\n            prompt_kwargs = {\"input_ids\": prompt}\n\n        # Extract custom labels from raw request headers\n        custom_labels = self.extract_custom_labels(raw_request)\n\n        # Extract routed_dp_rank from header (has higher priority than body)\n        effective_routed_dp_rank = self.extract_routed_dp_rank_from_header(\n            raw_request, request.routed_dp_rank\n        )\n\n        # Resolve LoRA adapter from model parameter or explicit lora_path\n        lora_path = self._resolve_lora_path(request.model, request.lora_path)\n\n        adapted_request = GenerateReqInput(\n            **prompt_kwargs,\n            
sampling_params=sampling_params,\n            return_logprob=request.logprobs is not None,\n            top_logprobs_num=request.logprobs if request.logprobs is not None else 0,\n            logprob_start_len=logprob_start_len,\n            return_text_in_logprobs=True,\n            stream=request.stream,\n            lora_path=lora_path,\n            bootstrap_host=request.bootstrap_host,\n            bootstrap_port=request.bootstrap_port,\n            bootstrap_room=request.bootstrap_room,\n            routed_dp_rank=effective_routed_dp_rank,\n            disagg_prefill_dp_rank=request.disagg_prefill_dp_rank,\n            return_hidden_states=request.return_hidden_states,\n            return_routed_experts=request.return_routed_experts,\n            rid=request.rid,\n            extra_key=self._compute_extra_key(request),\n            priority=request.priority,\n            routing_key=self.extract_routing_key(raw_request),\n            custom_labels=custom_labels,\n            custom_logit_processor=request.custom_logit_processor,\n        )\n\n        return adapted_request, request\n\n    def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]:\n        \"\"\"Build sampling parameters for the request\"\"\"\n        # Start with common parameters\n        sampling_params = {\n            \"temperature\": request.temperature,\n            \"max_new_tokens\": request.max_tokens,\n            \"min_new_tokens\": request.min_tokens,\n            \"stop\": request.stop,\n            \"stop_token_ids\": request.stop_token_ids,\n            \"stop_regex\": request.stop_regex,\n            \"top_p\": request.top_p,\n            \"top_k\": request.top_k,\n            \"min_p\": request.min_p,\n            \"presence_penalty\": request.presence_penalty,\n            \"frequency_penalty\": request.frequency_penalty,\n            \"repetition_penalty\": request.repetition_penalty,\n            \"regex\": request.regex,\n            \"json_schema\": 
request.json_schema,\n            \"ebnf\": request.ebnf,\n            \"n\": request.n,\n            \"no_stop_trim\": request.no_stop_trim,\n            \"ignore_eos\": request.ignore_eos,\n            \"skip_special_tokens\": request.skip_special_tokens,\n            \"logit_bias\": request.logit_bias,\n            \"custom_params\": request.custom_params,\n            \"sampling_seed\": request.seed,\n        }\n\n        # Handle response_format constraints\n        if request.response_format and request.response_format.type == \"json_schema\":\n            sampling_params[\"json_schema\"] = convert_json_schema_to_str(\n                request.response_format.json_schema.schema_\n            )\n        elif request.response_format and request.response_format.type == \"json_object\":\n            sampling_params[\"json_schema\"] = '{\"type\": \"object\"}'\n        elif (\n            request.response_format and request.response_format.type == \"structural_tag\"\n        ):\n            sampling_params[\"structural_tag\"] = convert_json_schema_to_str(\n                request.response_format.model_dump(by_alias=True)\n            )\n\n        return sampling_params\n\n    async def _handle_streaming_request(\n        self,\n        adapted_request: GenerateReqInput,\n        request: CompletionRequest,\n        raw_request: Request,\n    ) -> StreamingResponse:\n        \"\"\"Handle streaming completion request\"\"\"\n        return StreamingResponse(\n            self._generate_completion_stream(adapted_request, request, raw_request),\n            media_type=\"text/event-stream\",\n            background=self.tokenizer_manager.create_abort_task(adapted_request),\n        )\n\n    async def _generate_completion_stream(\n        self,\n        adapted_request: GenerateReqInput,\n        request: CompletionRequest,\n        raw_request: Request,\n    ) -> AsyncGenerator[str, None]:\n        \"\"\"Generate streaming completion response\"\"\"\n        created = 
int(time.time())\n\n        # State tracking for streaming\n        stream_buffers = {}\n        n_prev_tokens = {}\n\n        # Usage tracking\n        prompt_tokens = {}\n        completion_tokens = {}\n        cached_tokens = {}\n        hidden_states = {}\n        routed_experts = {}\n\n        try:\n            async for content in self.tokenizer_manager.generate_request(\n                adapted_request, raw_request\n            ):\n                index = content.get(\"index\", 0)\n\n                text = content[\"text\"]\n                prompt_tokens[index] = content[\"meta_info\"].get(\"prompt_tokens\", 0)\n                completion_tokens[index] = content[\"meta_info\"].get(\n                    \"completion_tokens\", 0\n                )\n                cached_tokens[index] = content[\"meta_info\"].get(\"cached_tokens\", 0)\n                hidden_states[index] = content[\"meta_info\"].get(\"hidden_states\", None)\n                routed_experts[index] = content[\"meta_info\"].get(\"routed_experts\", None)\n\n                stream_buffer = stream_buffers.get(index, \"\")\n                # Handle echo for first chunk\n                if not stream_buffer:  # The first chunk\n                    if request.echo:\n                        echo_text = self._get_echo_text(request, index)\n                        text = echo_text + text\n\n                # Handle logprobs\n                logprobs = None\n                if request.logprobs is not None:\n                    # The first chunk and echo is enabled.\n                    if not stream_buffer and request.echo:\n                        input_token_logprobs = content[\"meta_info\"][\n                            \"input_token_logprobs\"\n                        ]\n                        input_top_logprobs = content[\"meta_info\"][\"input_top_logprobs\"]\n                    else:\n                        input_token_logprobs = None\n                        input_top_logprobs = None\n\n          
          n_prev_token = n_prev_tokens.get(index, 0)\n                    total_output_logprobs = len(\n                        content[\"meta_info\"][\"output_token_logprobs\"]\n                    )\n                    output_logprobs_slice = content[\"meta_info\"][\n                        \"output_token_logprobs\"\n                    ][n_prev_token:]\n                    finish_reason_for_logprobs = content[\"meta_info\"][\"finish_reason\"]\n\n                    # When finish_reason is set and all logprobs have been sent,\n                    # any remaining text is just buffered text being flushed by the\n                    # detokenizer (it holds back text at word boundaries). Return None\n                    # for logprobs since no new tokens were generated for this text.\n                    if (\n                        len(output_logprobs_slice) == 0\n                        and finish_reason_for_logprobs is not None\n                        and input_token_logprobs is None\n                    ):\n                        logprobs = None\n                    else:\n                        logprobs = to_openai_style_logprobs(\n                            input_token_logprobs=input_token_logprobs,\n                            input_top_logprobs=input_top_logprobs,\n                            output_token_logprobs=output_logprobs_slice,\n                            output_top_logprobs=content[\"meta_info\"].get(\n                                \"output_top_logprobs\", []\n                            )[n_prev_token:],\n                        )\n                    n_prev_tokens[index] = total_output_logprobs\n\n                # Generate delta\n                delta = text[len(stream_buffer) :]\n                stream_buffers[index] = stream_buffer + delta\n                finish_reason = content[\"meta_info\"].get(\"finish_reason\", None)\n                finish_reason_type = finish_reason[\"type\"] if finish_reason else None\n\n                # If 
the abort is from scheduler.\n                if finish_reason_type == \"abort\":\n                    code = finish_reason.get(\n                        \"status_code\", HTTPStatus.INTERNAL_SERVER_ERROR\n                    )\n                    error = self.create_streaming_error_response(\n                        finish_reason.get(\"message\", \"Generation aborted.\"),\n                        code.name,\n                        code.value,\n                    )\n                    yield f\"data: {error}\\n\\n\"\n                    break\n\n                choice_data = CompletionResponseStreamChoice(\n                    index=index,\n                    text=delta,\n                    logprobs=logprobs,\n                    finish_reason=finish_reason_type,\n                    matched_stop=(\n                        finish_reason[\"matched\"]\n                        if finish_reason and \"matched\" in finish_reason\n                        else None\n                    ),\n                )\n                chunk = CompletionStreamResponse(\n                    id=content[\"meta_info\"][\"id\"],\n                    created=created,\n                    object=\"text_completion\",\n                    choices=[choice_data],\n                    model=request.model,\n                )\n\n                # Add usage stats if continuous_usage_stats is enabled\n                if (\n                    request.stream_options\n                    and request.stream_options.continuous_usage_stats\n                ):\n                    chunk.usage = UsageProcessor.calculate_token_usage(\n                        prompt_tokens=prompt_tokens.get(index, 0),\n                        completion_tokens=completion_tokens.get(index, 0),\n                    )\n\n                yield f\"data: {chunk.model_dump_json()}\\n\\n\"\n\n            if request.return_hidden_states and hidden_states:\n                for index, choice_hidden_states in hidden_states.items():\n 
                   if choice_hidden_states:\n                        last_token_hidden_states = (\n                            choice_hidden_states[-1]\n                            if len(choice_hidden_states) > 1\n                            else []\n                        )\n                        hidden_states_chunk = CompletionStreamResponse(\n                            id=content[\"meta_info\"][\"id\"],\n                            created=created,\n                            object=\"text_completion\",\n                            choices=[\n                                CompletionResponseStreamChoice(\n                                    index=index,\n                                    text=\"\",\n                                    hidden_states=last_token_hidden_states,\n                                    finish_reason=None,\n                                )\n                            ],\n                            model=request.model,\n                        )\n                        yield f\"data: {hidden_states_chunk.model_dump_json()}\\n\\n\"\n\n            if request.return_routed_experts and routed_experts:\n                # Get first non-None routed_experts value\n                first_routed_experts = next(\n                    (v for v in routed_experts.values() if v is not None), None\n                )\n                if first_routed_experts is not None:\n                    routed_experts_chunk = CompletionStreamResponse(\n                        id=content[\"meta_info\"][\"id\"],\n                        created=created,\n                        object=\"text_completion\",\n                        choices=[],  # sglext is at response level\n                        model=request.model,\n                        sglext=SglExt(routed_experts=first_routed_experts),\n                    )\n                    yield f\"data: {routed_experts_chunk.model_dump_json()}\\n\\n\"\n\n            # Handle final usage chunk\n            if 
request.stream_options and request.stream_options.include_usage:\n                usage = UsageProcessor.calculate_streaming_usage(\n                    prompt_tokens,\n                    completion_tokens,\n                    cached_tokens,\n                    n_choices=request.n,\n                    enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,\n                )\n                final_usage_chunk = CompletionStreamResponse(\n                    id=content[\"meta_info\"][\"id\"],\n                    created=created,\n                    choices=[],\n                    model=request.model,\n                    usage=usage,\n                )\n                final_usage_data = final_usage_chunk.model_dump_json(exclude_none=True)\n                yield f\"data: {final_usage_data}\\n\\n\"\n\n        except Exception as e:\n            error = self.create_streaming_error_response(str(e))\n            yield f\"data: {error}\\n\\n\"\n\n        yield \"data: [DONE]\\n\\n\"\n\n    async def _handle_non_streaming_request(\n        self,\n        adapted_request: GenerateReqInput,\n        request: CompletionRequest,\n        raw_request: Request,\n    ) -> Union[CompletionResponse, ErrorResponse, ORJSONResponse]:\n        \"\"\"Handle non-streaming completion request\"\"\"\n        try:\n            generator = self.tokenizer_manager.generate_request(\n                adapted_request, raw_request\n            )\n            ret = await generator.__anext__()\n        except ValueError as e:\n            return self.create_error_response(str(e))\n\n        if not isinstance(ret, list):\n            ret = [ret]\n\n        response = self._build_completion_response(\n            request,\n            ret,\n            int(time.time()),\n        )\n\n        return response\n\n    def _build_completion_response(\n        self,\n        request: CompletionRequest,\n        ret: List[Dict[str, Any]],\n        created: int,\n    ) -> 
CompletionResponse:\n        \"\"\"Build completion response from generation results\"\"\"\n        choices = []\n        echo = False\n\n        # Prepare echo prompts if needed\n        echo_prompts = []\n        if request.echo:\n            echo_prompts = self._prepare_echo_prompts(request)\n            echo = True\n\n        # Build sglext at response level (from first ret_item, as these are per-request)\n        first_ret = ret[0]\n        routed_experts = process_routed_experts_from_ret(first_ret, request)\n        cached_tokens_details = process_cached_tokens_details_from_ret(\n            first_ret, request\n        )\n        response_sglext = None\n        if routed_experts or cached_tokens_details:\n            response_sglext = SglExt(\n                routed_experts=routed_experts,\n                cached_tokens_details=cached_tokens_details,\n            )\n\n        for idx, ret_item in enumerate(ret):\n            text = ret_item[\"text\"]\n\n            # Handle echo\n            if echo:\n                prompt_index = idx // request.n\n                text = echo_prompts[prompt_index] + text\n\n            # Handle logprobs\n            logprobs = None\n            if request.logprobs is not None:\n                if echo:\n                    input_token_logprobs = ret_item[\"meta_info\"][\"input_token_logprobs\"]\n                    input_top_logprobs = ret_item[\"meta_info\"][\"input_top_logprobs\"]\n                else:\n                    input_token_logprobs = None\n                    input_top_logprobs = None\n\n                logprobs = to_openai_style_logprobs(\n                    input_token_logprobs=input_token_logprobs,\n                    input_top_logprobs=input_top_logprobs,\n                    output_token_logprobs=ret_item[\"meta_info\"].get(\n                        \"output_token_logprobs\", []\n                    ),\n                    output_top_logprobs=ret_item[\"meta_info\"].get(\n                        
\"output_top_logprobs\", []\n                    ),\n                )\n\n            # Handle hidden states\n            hidden_states = process_hidden_states_from_ret(ret_item, request)\n\n            finish_reason = ret_item[\"meta_info\"][\"finish_reason\"]\n\n            choice_data = CompletionResponseChoice(\n                index=idx,\n                text=text,\n                logprobs=logprobs,\n                finish_reason=finish_reason[\"type\"] if finish_reason else None,\n                matched_stop=(\n                    finish_reason[\"matched\"]\n                    if finish_reason and \"matched\" in finish_reason\n                    else None\n                ),\n                hidden_states=hidden_states,\n            )\n            choices.append(choice_data)\n\n        # Calculate usage\n        cache_report = self.tokenizer_manager.server_args.enable_cache_report\n        usage = UsageProcessor.calculate_response_usage(\n            ret, n_choices=request.n, enable_cache_report=cache_report\n        )\n\n        return CompletionResponse(\n            id=ret[0][\"meta_info\"][\"id\"],\n            model=request.model,\n            created=created,\n            choices=choices,\n            usage=usage,\n            metadata={\"weight_version\": ret[0][\"meta_info\"][\"weight_version\"]},\n            sglext=response_sglext,\n        )\n\n    def _get_echo_text(self, request: CompletionRequest, index: int) -> str:\n        \"\"\"Get echo text for streaming response\"\"\"\n        if isinstance(request.prompt, str):\n            # for the case of single str prompts\n            return request.prompt\n        elif isinstance(request.prompt, list):\n            if isinstance(request.prompt[0], str):\n                # for the case of multiple str prompts\n                return request.prompt[index // request.n]\n            elif isinstance(request.prompt[0], int):\n                # for the case of single token ids prompt\n                
return self.tokenizer_manager.tokenizer.decode(\n                    request.prompt, skip_special_tokens=True\n                )\n            elif isinstance(request.prompt[0], list) and isinstance(\n                request.prompt[0][0], int\n            ):\n                # for the case of multiple token ids prompts\n                return self.tokenizer_manager.tokenizer.decode(\n                    request.prompt[index // request.n],\n                    skip_special_tokens=True,\n                )\n        return \"\"\n\n    def _prepare_echo_prompts(self, request: CompletionRequest) -> List[str]:\n        \"\"\"Prepare echo prompts for non-streaming response\"\"\"\n        # TODO: handle the case prompt is token ids\n        if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):\n            # for the case of multiple str prompts\n            return request.prompt\n        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):\n            # for the case of multiple token ids prompts\n            return [\n                self.tokenizer_manager.tokenizer.decode(\n                    prompt, skip_special_tokens=True\n                )\n                for prompt in request.prompt\n            ]\n        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):\n            # for the case of single token ids prompt\n            return [\n                self.tokenizer_manager.tokenizer.decode(\n                    request.prompt, skip_special_tokens=True\n                )\n            ]\n        else:\n            # for the case of single str prompt\n            return [request.prompt]\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/serving_embedding.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Any, Dict, List, Optional, Union\n\nfrom fastapi import Request\nfrom fastapi.responses import ORJSONResponse\n\nfrom sglang.srt.entrypoints.openai.protocol import (\n    EmbeddingObject,\n    EmbeddingRequest,\n    EmbeddingResponse,\n    ErrorResponse,\n    MultimodalEmbeddingInput,\n    UsageInfo,\n)\nfrom sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase\nfrom sglang.srt.managers.io_struct import EmbeddingReqInput\nfrom sglang.srt.parser.conversation import generate_embedding_convs\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.template_manager import TemplateManager\n    from sglang.srt.managers.tokenizer_manager import TokenizerManager\n\n\nclass OpenAIServingEmbedding(OpenAIServingBase):\n    \"\"\"Handler for v1/embeddings requests\"\"\"\n\n    def __init__(\n        self,\n        tokenizer_manager: TokenizerManager,\n        template_manager: TemplateManager,\n    ):\n        super().__init__(tokenizer_manager)\n        self.template_manager = template_manager\n\n    def _request_id_prefix(self) -> str:\n        return \"embd-\"\n\n    def _validate_request(self, request: EmbeddingRequest) -> Optional[str]:\n        \"\"\"Validate that the input is not empty or whitespace only.\"\"\"\n        if not (input := request.input):\n            return \"Input cannot be empty\"\n\n        # Handle single string\n        if isinstance(input, str):\n            if not input.strip():\n                return \"Input cannot be empty or whitespace only\"\n            return None\n\n        # Handle list inputs\n        if isinstance(input, list):\n            if len(input) == 0:\n                return \"Input cannot be empty\"\n\n            # Check first element to determine type\n            first_item = input[0]\n\n            if isinstance(first_item, str):\n                # List of strings\n                for i, item in enumerate(input):\n               
     if not isinstance(item, str):\n                        return f\"All items in input list must be strings\"\n                    if not item.strip():\n                        return f\"Input at index {i} cannot be empty or whitespace only\"\n            elif isinstance(first_item, int):\n                # List of integers (token IDs)\n                for i, item in enumerate(input):\n                    if not isinstance(item, int):\n                        return f\"All items in input list must be integers\"\n                    if item < 0:\n                        return f\"Token ID at index {i} must be non-negative\"\n        return None\n\n    def _convert_to_internal_request(\n        self,\n        request: EmbeddingRequest,\n        raw_request: Request = None,\n    ) -> tuple[EmbeddingReqInput, EmbeddingRequest]:\n        \"\"\"Convert OpenAI embedding request to internal format\"\"\"\n        prompt = request.input\n\n        if isinstance(prompt, str):\n            # Single string input\n            prompt_kwargs = {\"text\": prompt}\n        elif isinstance(prompt, list):\n            if len(prompt) > 0 and isinstance(prompt[0], str):\n                prompt_kwargs = {\"text\": prompt}\n            elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):\n                # Handle multimodal embedding inputs\n                texts = []\n                images = []\n                videos = []\n                for item in prompt:\n                    # Use padding for text if None - this could be improved\n                    texts.append(item.text if item.text is not None else \"padding\")\n                    images.append(item.image if item.image is not None else None)\n                    videos.append(item.video if item.video is not None else None)\n\n                generate_prompts = []\n                # Check if we have a chat template for multimodal embeddings\n                if self.template_manager.chat_template_name is 
not None:\n                    convs = generate_embedding_convs(\n                        texts, images, videos, self.template_manager.chat_template_name\n                    )\n                    for conv in convs:\n                        generate_prompts.append(conv.get_prompt())\n                else:\n                    generate_prompts = texts\n\n                if len(generate_prompts) == 1:\n                    prompt_kwargs = {\n                        \"text\": generate_prompts[0],\n                        \"image_data\": images[0],\n                        \"video_data\": videos[0],\n                    }\n                else:\n                    prompt_kwargs = {\n                        \"text\": generate_prompts,\n                        \"image_data\": images,\n                        \"video_data\": videos,\n                    }\n            else:\n                # List of integers (token IDs) or empty list\n                prompt_kwargs = {\"input_ids\": prompt}\n        else:\n            # Other types (should not happen but handle gracefully)\n            prompt_kwargs = {\"input_ids\": prompt}\n\n        # Resolve LoRA adapter from model parameter or explicit lora_path\n        lora_path = self._resolve_lora_path(request.model, request.lora_path)\n\n        adapted_request = EmbeddingReqInput(\n            **prompt_kwargs,\n            rid=request.rid,\n            priority=request.priority,\n            routing_key=self.extract_routing_key(raw_request),\n            dimensions=request.dimensions,\n            lora_path=lora_path,\n        )\n\n        return adapted_request, request\n\n    async def _handle_non_streaming_request(\n        self,\n        adapted_request: EmbeddingReqInput,\n        request: EmbeddingRequest,\n        raw_request: Request,\n    ) -> Union[EmbeddingResponse, ErrorResponse, ORJSONResponse]:\n        \"\"\"Handle the embedding request\"\"\"\n        try:\n            ret = await 
self.tokenizer_manager.generate_request(\n                adapted_request, raw_request\n            ).__anext__()\n        except ValueError as e:\n            return self.create_error_response(str(e))\n\n        if not isinstance(ret, list):\n            ret = [ret]\n\n        response = self._build_embedding_response(ret)\n        return response\n\n    def _build_embedding_response(self, ret: List[Dict[str, Any]]) -> EmbeddingResponse:\n        \"\"\"Build the embedding response\"\"\"\n        embedding_objects = []\n        prompt_tokens = 0\n\n        for idx, ret_item in enumerate(ret):\n            embedding_objects.append(\n                EmbeddingObject(\n                    embedding=ret_item[\"embedding\"],\n                    index=idx,\n                )\n            )\n            # Handle missing prompt_tokens gracefully\n            meta_info = ret_item.get(\"meta_info\", {})\n            prompt_tokens += meta_info.get(\"prompt_tokens\", 0)\n\n        return EmbeddingResponse(\n            data=embedding_objects,\n            model=self.tokenizer_manager.model_path,\n            usage=UsageInfo(\n                prompt_tokens=prompt_tokens,\n                total_tokens=prompt_tokens,\n            ),\n        )\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/serving_rerank.py",
    "content": "import logging\nfrom typing import Any, Dict, List, Optional, Union\n\nfrom fastapi import Request\nfrom fastapi.responses import ORJSONResponse\n\nfrom sglang.srt.entrypoints.openai.protocol import (\n    ChatCompletionMessageContentImagePart,\n    ChatCompletionMessageContentTextPart,\n    ChatCompletionMessageContentVideoPart,\n    ErrorResponse,\n    RerankContent,\n    RerankResponse,\n    V1RerankReqInput,\n)\nfrom sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase\nfrom sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput\n\nlogger = logging.getLogger(__name__)\n\n\ndef _get_yes_no_token_ids(tokenizer) -> tuple[int, int]:\n    \"\"\"Get token IDs for 'yes' and 'no' from the tokenizer.\n\n    Different model sizes may have different token IDs, so we look them up dynamically.\n    \"\"\"\n    # Try to encode 'yes' and 'no' to get their token IDs\n    # The tokenizer should return a single token for these common words\n    try:\n        yes_tokens = tokenizer.encode(\"yes\", add_special_tokens=False)\n        no_tokens = tokenizer.encode(\"no\", add_special_tokens=False)\n\n        if len(yes_tokens) == 1 and len(no_tokens) == 1:\n            return yes_tokens[0], no_tokens[0]\n\n        # Fallback: try convert_tokens_to_ids\n        yes_id = tokenizer.convert_tokens_to_ids(\"yes\")\n        no_id = tokenizer.convert_tokens_to_ids(\"no\")\n        if yes_id is not None and no_id is not None:\n            return yes_id, no_id\n\n    except Exception as e:\n        logger.warning(f\"Failed to get yes/no token IDs dynamically: {e}\")\n\n    # Fallback to known Qwen3 token IDs (may not work for all model sizes)\n    logger.warning(\"Using fallback token IDs for yes/no (9693/2152)\")\n    return 9693, 2152\n\n\ndef _is_qwen3_reranker_template(chat_template: str) -> bool:\n    \"\"\"Detect if the chat template is for Qwen3 text-only reranker.\"\"\"\n    if not chat_template:\n        return False\n    t = 
chat_template.lower()\n    return ('answer can only be \"yes\" or \"no\"' in t) or (\n        \"answer can only be\" in t and '\"yes\"' in t and '\"no\"' in t\n    )\n\n\ndef _is_qwen3_vl_reranker_template(chat_template: str) -> bool:\n    \"\"\"Detect if the chat template is for Qwen3-VL multimodal reranker.\n\n    VL reranker templates use `query` and `document` as jinja variables\n    and include vision token placeholders for image/video support.\n    \"\"\"\n    if not chat_template:\n        return False\n    t = chat_template.lower()\n    # Check for reranker phrase (yes/no judgment)\n    has_reranker_phrase = ('answer can only be \"yes\" or \"no\"' in t) or (\n        \"answer can only be\" in t and '\"yes\"' in t and '\"no\"' in t\n    )\n    # Check for vision token placeholders (unique to VL templates)\n    has_vision_tokens = \"<|vision_start|>\" in t or \"<|image_pad|>\" in t\n    return has_reranker_phrase and has_vision_tokens\n\n\ndef _is_qwen3_vl_model(model_path: str) -> bool:\n    \"\"\"Check if the model is a Qwen3-VL model based on model path.\"\"\"\n    if not model_path:\n        return False\n    model_lower = model_path.lower()\n    return \"qwen3-vl\" in model_lower or \"qwen3vl\" in model_lower\n\n\ndef _detect_rerank_backend(\n    *,\n    request: V1RerankReqInput,\n    chat_template: Optional[str],\n    model_path: str,\n) -> str:\n    \"\"\"\n    Unify rerank routing decisions used by both `_convert_to_internal_request` and\n    `_handle_non_streaming_request`.\n\n    Returns:\n        \"vl_decoder\" | \"text_decoder\" | \"cross_encoder\"\n    \"\"\"\n    is_multimodal = request.is_multimodal()\n    is_vl_model = _is_qwen3_vl_model(model_path)\n    is_vl_template = _is_qwen3_vl_reranker_template(chat_template)\n    is_text_template = _is_qwen3_reranker_template(chat_template)\n\n    # Prefer VL when template/model indicates VL, or request is multimodal with reranker template.\n    if is_vl_template or is_vl_model or (is_multimodal and 
is_text_template):\n        return \"vl_decoder\"\n    if is_text_template:\n        return \"text_decoder\"\n    return \"cross_encoder\"\n\n\ndef _qwen3_rerank_score(p_yes: float, p_no: float) -> float:\n    denom = p_yes + p_no\n    if denom <= 0.0:\n        return 0.0\n    return p_yes / denom\n\n\ndef _get_jinja_env():\n    try:\n        import jinja2  # Lazy import: server env should provide this dependency.\n    except ModuleNotFoundError as e:\n        raise ValueError(\n            \"Rendering Qwen3 reranker prompts requires `jinja2`. \"\n            \"Please install it in your runtime environment (e.g., `pip install jinja2`).\"\n        ) from e\n\n    return jinja2.Environment(\n        loader=jinja2.BaseLoader(),\n        autoescape=False,\n        undefined=jinja2.Undefined,\n    )\n\n\ndef _render_jinja_chat_template(\n    chat_template: str,\n    *,\n    query: RerankContent,\n    document: RerankContent,\n    instruct: Optional[str],\n) -> str:\n    \"\"\"Render a loaded Jinja chat template for Qwen3 reranker prompts (text-only).\"\"\"\n    env = _get_jinja_env()\n    template = env.from_string(chat_template)\n\n    # For text-only template, extract text content\n    query_text = query if isinstance(query, str) else _extract_text_from_content(query)\n    doc_text = (\n        document if isinstance(document, str) else _extract_text_from_content(document)\n    )\n\n    render_kwargs = {\n        \"messages\": [\n            {\"role\": \"user\", \"content\": query_text},\n            {\"role\": \"user\", \"content\": doc_text},\n        ]\n    }\n    # Only pass instruct when explicitly provided; template uses `default(...)`\n    # which works only when the variable is undefined (not None).\n    if instruct:\n        render_kwargs[\"instruct\"] = instruct\n    return template.render(**render_kwargs)\n\n\ndef _render_vl_jinja_template(\n    chat_template: str,\n    *,\n    query: List[Dict[str, Any]],\n    document: List[Dict[str, Any]],\n    instruct: 
Optional[str],\n) -> str:\n    \"\"\"Render a loaded Jinja chat template for Qwen3-VL reranker prompts (multimodal).\n\n    The template expects `query` and `document` as lists of content parts,\n    where each part has a `type` field (text, image, video) and corresponding data.\n    \"\"\"\n    env = _get_jinja_env()\n    template = env.from_string(chat_template)\n\n    render_kwargs = {\n        \"query\": query,\n        \"document\": document,\n    }\n    if instruct:\n        render_kwargs[\"instruct\"] = instruct\n    return template.render(**render_kwargs)\n\n\ndef _extract_text_from_content(content: RerankContent) -> str:\n    \"\"\"Extract text from multimodal content.\"\"\"\n    if isinstance(content, str):\n        return content\n    texts = []\n    for part in content:\n        if isinstance(part, ChatCompletionMessageContentTextPart):\n            texts.append(part.text)\n        elif isinstance(part, dict) and part.get(\"type\") == \"text\":\n            texts.append(part.get(\"text\", \"\"))\n    return \" \".join(texts)\n\n\nclass OpenAIServingRerank(OpenAIServingBase):\n    \"\"\"Handler for /v1/rerank requests\"\"\"\n\n    def __init__(self, tokenizer_manager, template_manager=None):\n        super().__init__(tokenizer_manager)\n        # TemplateManager is optional; rerank uses tokenizer.chat_template today.\n        # Keeping this explicit makes the dependency clear and supports future extensions.\n        self.template_manager = template_manager\n\n        # Cache yes/no token IDs for Qwen3 reranker scoring\n        self._yes_token_id, self._no_token_id = _get_yes_no_token_ids(\n            tokenizer_manager.tokenizer\n        )\n\n    # NOTE: /v1/rerank is not an official OpenAI endpoint. 
This module may be moved\n    # to another module in the future.\n\n    def _request_id_prefix(self) -> str:\n        return \"rerank-\"\n\n    def _validate_request(self, request: V1RerankReqInput) -> Optional[str]:\n        \"\"\"Validate rerank request format and content\"\"\"\n        if not request.query:\n            return \"Query cannot be empty\"\n\n        if isinstance(request.query, str):\n            if not request.query.strip():\n                return \"Query cannot be empty or whitespace only\"\n\n        if not request.documents:\n            return \"Documents cannot be empty\"\n\n        for doc in request.documents:\n            if not doc:\n                return \"Each document must be a non-empty string\"\n            if isinstance(doc, str) and not doc.strip():\n                return \"Each document cannot be empty or whitespace only\"\n\n        return None\n\n    def _convert_to_internal_request(\n        self,\n        request: V1RerankReqInput,\n        raw_request: Request = None,\n    ) -> tuple[Union[EmbeddingReqInput, V1RerankReqInput], V1RerankReqInput]:\n        \"\"\"\n        Convert OpenAI rerank request to internal format.\n\n        - For Qwen3-VL reranker (multimodal decoder-only): keep the request.\n        - For Qwen3 reranker (text-only decoder-only): keep the request and score via\n          `tokenizer_manager.score_prompts(...)` in the handler.\n        - For cross-encoder rerank models: adapt into `EmbeddingReqInput` pairs.\n        \"\"\"\n        chat_template = self.tokenizer_manager.tokenizer.chat_template\n        model_path = getattr(self.tokenizer_manager.model_config, \"model_path\", \"\")\n        backend = _detect_rerank_backend(\n            request=request,\n            chat_template=chat_template if isinstance(chat_template, str) else None,\n            model_path=model_path,\n        )\n        if backend in (\"vl_decoder\", \"text_decoder\"):\n            return request, request\n\n        # 
Cross-encoder rerank: Create pairs of [query, document] for each document.\n        # Note: Cross-encoder only supports text-only content\n        if request.is_multimodal():\n            # Extract text for cross-encoder (multimodal not supported)\n            query_text = _extract_text_from_content(request.query)\n            doc_texts = [_extract_text_from_content(doc) for doc in request.documents]\n            pairs = [[query_text, doc] for doc in doc_texts]\n        else:\n            pairs = [[request.query, doc] for doc in request.documents]\n\n        adapted_request = EmbeddingReqInput(text=pairs, is_cross_encoder_request=True)\n        return adapted_request, request\n\n    async def _handle_non_streaming_request(\n        self,\n        adapted_request: Union[EmbeddingReqInput, V1RerankReqInput],\n        request: V1RerankReqInput,\n        raw_request: Request,\n    ) -> Union[List[RerankResponse], ErrorResponse, ORJSONResponse]:\n        \"\"\"Handle the rerank request\"\"\"\n        chat_template = getattr(self.tokenizer_manager.tokenizer, \"chat_template\", None)\n        model_path = getattr(self.tokenizer_manager.model_config, \"model_path\", \"\")\n        rerank_ret = await self._handle_rerank_paths(\n            request=request,\n            raw_request=raw_request,\n            chat_template=chat_template,\n            model_path=model_path,\n        )\n        if rerank_ret is not None:\n            return rerank_ret\n\n        # Default cross-encoder rerank path (existing behavior).\n        try:\n            if not isinstance(adapted_request, EmbeddingReqInput):\n                raise ValueError(\n                    \"Invalid rerank request adaptation. 
\"\n                    \"If you are serving a decoder-only reranker (e.g., Qwen3-Reranker), \"\n                    \"please provide the corresponding --chat-template and launch without --is-embedding.\"\n                )\n            ret = await self.tokenizer_manager.generate_request(\n                adapted_request, raw_request\n            ).__anext__()\n        except ValueError as e:\n            return self.create_error_response(str(e))\n\n        if not isinstance(ret, list):\n            ret = [ret]\n\n        responses = self._build_rerank_response(ret, request)\n        return responses\n\n    async def _handle_rerank_paths(\n        self,\n        *,\n        request: V1RerankReqInput,\n        raw_request: Request,\n        chat_template: Optional[str],\n        model_path: str,\n    ) -> Optional[Union[List[RerankResponse], ErrorResponse, ORJSONResponse]]:\n        \"\"\"\n        Handle decoder-only rerank paths (VL/text) and return a response if matched.\n\n        Returns None if the request should fall back to cross-encoder rerank.\n        \"\"\"\n        backend = _detect_rerank_backend(\n            request=request,\n            chat_template=chat_template,\n            model_path=model_path,\n        )\n\n        # Qwen3-VL reranker path (decoder-only scoring with query/document template format)\n        if backend == \"vl_decoder\":\n            return await self._handle_vl_reranker_request(\n                request, raw_request, chat_template or \"\"\n            )\n\n        # Qwen3 text-only reranker path (decoder-only scoring).\n        if backend == \"text_decoder\":\n            return await self._handle_text_reranker_request(\n                request=request,\n                raw_request=raw_request,\n                chat_template=chat_template or \"\",\n            )\n\n        return None\n\n    async def _handle_text_reranker_request(\n        self,\n        *,\n        request: V1RerankReqInput,\n        raw_request: Request,\n  
      chat_template: str,\n    ) -> Union[List[RerankResponse], ErrorResponse]:\n        \"\"\"Handle text-only decoder reranker request via score_prompts().\"\"\"\n        # Qwen3 reranker relies on decoder-only logprobs. If the server is launched\n        # with --is-embedding, model_config.is_generation is typically False and\n        # logprob scoring is not supported.\n        if not self.tokenizer_manager.model_config.is_generation:\n            return self.create_error_response(\n                \"Detected Qwen3 reranker chat template, but the server is not in generation mode. \"\n                \"Please relaunch without --is-embedding for Qwen3-Reranker models.\"\n            )\n\n        try:\n            prompts = [\n                _render_jinja_chat_template(\n                    chat_template,\n                    query=request.query,\n                    document=doc,\n                    instruct=getattr(request, \"instruct\", None),\n                )\n                for doc in request.documents\n            ]\n\n            result = await self.tokenizer_manager.score_prompts(\n                prompts,\n                label_token_ids=[self._yes_token_id, self._no_token_id],\n                apply_softmax=False,\n                request=raw_request,\n            )\n            scores = [_qwen3_rerank_score(s[0], s[1]) for s in result.scores]\n        except ValueError as e:\n            return self.create_error_response(str(e))\n        except Exception as e:\n            # Includes template rendering errors from jinja2.\n            return self.create_error_response(str(e))\n\n        responses = self._build_rerank_response(scores, request)\n        return responses\n\n    async def _handle_vl_reranker_request(\n        self,\n        request: V1RerankReqInput,\n        raw_request: Request,\n        _chat_template: str,\n    ) -> Union[List[RerankResponse], ErrorResponse]:\n        \"\"\"Handle multimodal VL reranker request using chat 
completion with logprobs.\"\"\"\n        if not self.tokenizer_manager.model_config.is_generation:\n            return self.create_error_response(\n                \"Detected Qwen3-VL reranker, but the server is not in generation mode. \"\n                \"Please relaunch without --is-embedding for Qwen3-VL-Reranker models.\"\n            )\n\n        try:\n            scores = []\n            instruct = getattr(request, \"instruct\", None)\n\n            for doc in request.documents:\n                # Build multimodal content lists and render prompt using jinja template\n                query_content, doc_content, image_data, video_data = (\n                    self._build_vl_reranker_content(\n                        query=request.query,\n                        document=doc,\n                    )\n                )\n\n                # Render the chat template directly with query/document variables\n                prompt = _render_vl_jinja_template(\n                    chat_template=_chat_template,\n                    query=query_content,\n                    document=doc_content,\n                    instruct=instruct,\n                )\n\n                # Create generate request with logprobs\n                gen_request = GenerateReqInput(\n                    text=prompt,\n                    image_data=image_data if image_data else None,\n                    video_data=video_data if video_data else None,\n                    sampling_params={\n                        \"max_new_tokens\": 1,\n                        \"temperature\": 0,\n                    },\n                    return_logprob=True,\n                    top_logprobs_num=50,  # Get enough logprobs to find yes/no tokens\n                    logprob_start_len=0,\n                )\n\n                # Execute generation request\n                ret = await self.tokenizer_manager.generate_request(\n                    gen_request, raw_request\n                ).__anext__()\n\n            
    # Extract yes/no probabilities from logprobs\n                score = self._extract_score_from_logprobs(ret)\n                scores.append(score)\n\n            responses = self._build_rerank_response(scores, request)\n            return responses\n\n        except ValueError as e:\n            return self.create_error_response(str(e))\n        except Exception as e:\n            logger.exception(\"Error handling VL reranker request\")\n            return self.create_error_response(str(e))\n\n    def _build_vl_reranker_content(\n        self,\n        query: RerankContent,\n        document: RerankContent,\n    ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[str], List[str]]:\n        \"\"\"Build content lists for VL reranker request.\n\n        Returns:\n            Tuple of (query_content, document_content, image_data, video_data)\n            where query_content and document_content are lists suitable for jinja template.\n        \"\"\"\n        image_data = []\n        video_data = []\n\n        # Build query content list\n        query_content = self._content_to_template_list(query, image_data, video_data)\n\n        # Build document content list\n        doc_content = self._content_to_template_list(document, image_data, video_data)\n\n        return query_content, doc_content, image_data, video_data\n\n    def _content_to_template_list(\n        self,\n        content: RerankContent,\n        image_data: List[str],\n        video_data: List[str],\n    ) -> List[Dict[str, Any]]:\n        \"\"\"Convert RerankContent to a list format suitable for jinja template.\"\"\"\n        result = []\n\n        if isinstance(content, str):\n            result.append({\"type\": \"text\", \"text\": content})\n            return result\n\n        for part in content:\n            if isinstance(part, ChatCompletionMessageContentTextPart):\n                result.append({\"type\": \"text\", \"text\": part.text})\n            elif isinstance(part, 
ChatCompletionMessageContentImagePart):\n                if part.image_url:\n                    image_data.append(part.image_url.url)\n                    result.append({\"type\": \"image\"})\n            elif isinstance(part, ChatCompletionMessageContentVideoPart):\n                if part.video_url:\n                    video_data.append(part.video_url.url)\n                    result.append({\"type\": \"video\"})\n            elif isinstance(part, dict):\n                part_type = part.get(\"type\")\n                if part_type == \"text\":\n                    result.append({\"type\": \"text\", \"text\": part.get(\"text\", \"\")})\n                elif part_type == \"image_url\":\n                    image_url = part.get(\"image_url\", {})\n                    if isinstance(image_url, dict):\n                        url = image_url.get(\"url\")\n                    else:\n                        url = image_url\n                    if url:\n                        image_data.append(url)\n                        result.append({\"type\": \"image\"})\n                elif part_type == \"video_url\":\n                    video_url = part.get(\"video_url\", {})\n                    if isinstance(video_url, dict):\n                        url = video_url.get(\"url\")\n                    else:\n                        url = video_url\n                    if url:\n                        video_data.append(url)\n                        result.append({\"type\": \"video\"})\n\n        return result\n\n    def _extract_score_from_logprobs(self, ret: Dict[str, Any]) -> float:\n        \"\"\"Extract reranking score from generation response with logprobs.\"\"\"\n        import math\n\n        # Get logprobs from the response\n        meta_info = ret.get(\"meta_info\", {})\n        output_top_logprobs = meta_info.get(\"output_top_logprobs\", [])\n\n        # Use output_top_logprobs[0] - the model's prediction for the first generated token\n        top_logprobs = 
output_top_logprobs[0] if output_top_logprobs else []\n\n        # Find yes and no token probabilities\n        # Format: list of tuples (logprob, token_id, token_text)\n        p_yes = 0.0\n        p_no = 0.0\n\n        for item in top_logprobs:\n            logprob, token_id = item[0], item[1]\n            if token_id == self._yes_token_id:\n                p_yes = math.exp(logprob)\n            elif token_id == self._no_token_id:\n                p_no = math.exp(logprob)\n\n        return _qwen3_rerank_score(p_yes, p_no)\n\n    def _build_rerank_response(\n        self, ret: Union[List[Dict[str, Any]], List[float]], request: V1RerankReqInput\n    ) -> List[RerankResponse]:\n        \"\"\"Build the rerank response from generation results\"\"\"\n        responses = []\n        for idx, item in enumerate(ret):\n            if isinstance(item, dict):\n                score_val = item.get(\"embedding\")\n                # Some rerank/reward models return scalar score as embedding[0].\n                if isinstance(score_val, list):\n                    if len(score_val) == 0 or not isinstance(\n                        score_val[0], (int, float)\n                    ):\n                        raise ValueError(\n                            f\"Invalid embedding score for rerank at index {idx}: {score_val!r}\"\n                        )\n                    score_val = float(score_val[0])\n                responses.append(\n                    RerankResponse(\n                        score=float(score_val),\n                        document=(\n                            request.documents[idx] if request.return_documents else None\n                        ),\n                        index=idx,\n                        meta_info=item.get(\"meta_info\"),\n                    )\n                )\n            else:\n                responses.append(\n                    RerankResponse(\n                        score=float(item),\n                        document=(\n        
                    request.documents[idx] if request.return_documents else None\n                        ),\n                        index=idx,\n                    )\n                )\n\n        # Sort by score in descending order (highest relevance first)\n        responses.sort(key=lambda x: x.score, reverse=True)\n\n        # Apply top_n limit if specified\n        if request.top_n is not None and request.top_n > 0:\n            responses = responses[: request.top_n]\n\n        return responses\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/serving_responses.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Adapted from vLLM's OpenAIServingResponses\n\"\"\"Handler for /v1/responses requests\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport copy\nimport json\nimport logging\nimport time\nfrom contextlib import AsyncExitStack\nfrom http import HTTPStatus\nfrom typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Optional, Union\n\nimport jinja2\nimport openai.types.responses as openai_responses_types\nimport orjson\nfrom fastapi import Request\nfrom fastapi.responses import ORJSONResponse\nfrom openai.types.responses import (\n    ResponseOutputMessage,\n    ResponseOutputText,\n    ResponseReasoningItem,\n)\nfrom openai.types.responses.response_function_tool_call import ResponseFunctionToolCall\nfrom openai.types.responses.response_reasoning_item import (\n    Content as ResponseReasoningTextContent,\n)\nfrom openai_harmony import Message as OpenAIMessage\n\nfrom sglang.srt.entrypoints.context import (\n    ConversationContext,\n    HarmonyContext,\n    SimpleContext,\n    StreamingHarmonyContext,\n)\nfrom sglang.srt.entrypoints.harmony_utils import (\n    get_developer_message,\n    get_stop_tokens_for_assistant_actions,\n    get_system_message,\n    get_user_message,\n    parse_output_message,\n    parse_remaining_state,\n    parse_response_input,\n    render_for_completion,\n)\nfrom sglang.srt.entrypoints.openai.protocol import (\n    ChatCompletionMessageParam,\n    ChatCompletionRequest,\n    PromptTokenUsageInfo,\n    RequestResponseMetadata,\n    ResponsesRequest,\n    ResponsesResponse,\n    UsageInfo,\n)\nfrom sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat\nfrom sglang.srt.entrypoints.openai.tool_server import MCPToolServer, ToolServer\nfrom sglang.srt.managers.io_struct import GenerateReqInput\nfrom sglang.srt.parser.reasoning_parser import ReasoningParser\nfrom sglang.srt.utils import random_uuid\n\nif TYPE_CHECKING:\n    from 
sglang.srt.managers.template_manager import TemplateManager\n    from sglang.srt.managers.tokenizer_manager import TokenizerManager\n\nlogger = logging.getLogger(__name__)\n\n\nclass OpenAIServingResponses(OpenAIServingChat):\n    \"\"\"Handler for /v1/responses requests\"\"\"\n\n    def __init__(\n        self,\n        tokenizer_manager: TokenizerManager,\n        template_manager: TemplateManager,\n        *,\n        enable_prompt_tokens_details: bool = False,\n        enable_force_include_usage: bool = False,\n        tool_server: Optional[ToolServer] = None,\n    ) -> None:\n        super().__init__(tokenizer_manager, template_manager)\n\n        # template_manager is already set by parent class\n        self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser\n        self.enable_prompt_tokens_details = enable_prompt_tokens_details\n        self.enable_force_include_usage = enable_force_include_usage\n\n        # Get default sampling params from model config if available\n        self.default_sampling_params = {}\n\n        self.supports_browsing = (\n            tool_server.has_tool(\"browser\") if tool_server else False\n        )\n        self.supports_code_interpreter = (\n            tool_server.has_tool(\"python\") if tool_server else False\n        )\n        self.tool_server = tool_server\n        # Get from model config\n        self.use_harmony = (\n            self.tokenizer_manager.model_config.hf_config.model_type == \"gpt_oss\"\n        )\n\n        if self.use_harmony:\n            # OpenAI models have two EOS-like tokens: <|return|> and <|call|>.\n            # We need to add them to the stop token ids.\n            if \"stop_token_ids\" not in self.default_sampling_params:\n                self.default_sampling_params[\"stop_token_ids\"] = []\n            self.default_sampling_params[\"stop_token_ids\"].extend(\n                get_stop_tokens_for_assistant_actions()\n            )\n\n        # Response storage for 
background and retrieval operations\n        # Note: In production, this should use a proper storage backend (Redis, database)\n        # with TTL/expiration to prevent memory leaks\n        self.response_store: dict[str, ResponsesResponse] = {}\n        self.response_store_lock = asyncio.Lock()\n\n        # Message storage for conversation continuity\n        # Note: In production, this should use a proper storage backend (Redis, database)\n        # with TTL/expiration to prevent memory leaks\n        self.msg_store: dict[\n            str, Union[list[ChatCompletionMessageParam], list[\"OpenAIMessage\"]]\n        ] = {}\n\n        self.background_tasks: dict[str, asyncio.Task] = {}\n\n    # error helpers dedicated for v1/responses\n    def create_error_response(\n        self,\n        message: str,\n        err_type: str = \"invalid_request_error\",\n        status_code: int = 400,\n        param: Optional[str] = None,\n    ) -> ORJSONResponse:\n        nested_error = {\n            \"message\": message,\n            \"type\": err_type,\n            \"param\": param,\n            \"code\": status_code,\n        }\n        return ORJSONResponse(content={\"error\": nested_error}, status_code=status_code)\n\n    def create_streaming_error_response(\n        self,\n        message: str,\n        err_type: str = \"BadRequestError\",\n        status_code: int = 400,\n    ) -> str:\n        return json.dumps(\n            {\n                \"error\": {\n                    \"message\": message,\n                    \"type\": err_type,\n                    \"param\": None,\n                    \"code\": status_code,\n                }\n            }\n        )\n\n    def _request_id_prefix(self) -> str:\n        return \"resp_\"\n\n    async def create_responses(\n        self,\n        request: ResponsesRequest,\n        raw_request: Optional[Request] = None,\n    ) -> Union[AsyncGenerator[str, None], ResponsesResponse, ORJSONResponse]:\n        # Validate model\n     
   if not self.tokenizer_manager:\n            return self.create_error_response(\"Model not loaded\")\n\n        # FIXME: If the engine is dead, raise an error\n        # This is required for the streaming case\n\n        # Handle the previous response ID\n        prev_response_id = request.previous_response_id\n        if prev_response_id is not None:\n            if not prev_response_id.startswith(\"resp_\"):\n                return self._make_invalid_id_error(prev_response_id)\n            async with self.response_store_lock:\n                prev_response = self.response_store.get(prev_response_id)\n            if prev_response is None:\n                return self._make_not_found_error(prev_response_id)\n        else:\n            prev_response = None\n\n        try:\n            model_name = request.model\n            tokenizer = self.tokenizer_manager.tokenizer\n\n            if self.use_harmony:\n                messages, request_prompts, engine_prompts = (\n                    self._make_request_with_harmony(request, prev_response)\n                )\n            else:\n                messages, request_prompts, engine_prompts = await self._make_request(\n                    request, prev_response, tokenizer\n                )\n\n        except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:\n            logger.exception(\"Error in preprocessing prompt inputs\")\n            return self.create_error_response(f\"{e} {e.__cause__}\")\n\n        request_metadata = RequestResponseMetadata(request_id=request.request_id)\n        if raw_request:\n            raw_request.state.request_metadata = request_metadata\n\n        if (\n            self.tool_server is not None\n            and isinstance(self.tool_server, MCPToolServer)\n            and (request.background or request.stream)\n            and request.tools\n            and any(\n                tool.type in [\"web_search_preview\", \"code_interpreter\"]\n                for tool in 
request.tools\n            )\n        ):\n            return self.create_error_response(\n                \"MCP tool server is not supported in background mode and \"\n                \"streaming mode\"\n            )\n\n        # Schedule the request and get the result generator\n        generators: list[AsyncGenerator[Any, None]] = []\n        tool_list = []\n        if self.use_harmony:\n            if self.supports_browsing:\n                tool_list.append(\"browser\")\n            if self.supports_code_interpreter:\n                tool_list.append(\"python\")\n        async with AsyncExitStack() as exit_stack:\n            try:\n                if self.tool_server is not None:\n                    tool_session_ctxs: dict[str, Any] = {\n                        tool_name: exit_stack.enter_async_context(\n                            self.tool_server.get_tool_session(tool_name)\n                        )\n                        for tool_name in tool_list\n                    }\n                    tool_sessions = {}\n                    for tool_name in tool_list:\n                        tool_sessions[tool_name] = await tool_session_ctxs[tool_name]\n                else:\n                    assert len(tool_list) == 0\n                    tool_sessions = {}\n                for i, engine_prompt in enumerate(engine_prompts):\n                    # Calculate default max tokens from context length minus prompt length\n                    if hasattr(engine_prompt, \"__len__\"):\n                        prompt_length = len(engine_prompt)\n                    elif isinstance(engine_prompt, list):\n                        prompt_length = len(engine_prompt)\n                    else:\n                        prompt_length = 0\n\n                    context_len = (\n                        self.tokenizer_manager.model_config.context_len\n                        if hasattr(self.tokenizer_manager.model_config, \"context_len\")\n                        else 4096\n        
            )\n                    # Account for reserved tokens (e.g., EAGLE speculative decoding slots)\n                    # that the tokenizer_manager adds during validation\n                    num_reserved_tokens = self.tokenizer_manager.num_reserved_tokens\n                    default_max_tokens = max(\n                        context_len - prompt_length - num_reserved_tokens, 512\n                    )  # Ensure minimum 512 tokens\n                    sampling_params = request.to_sampling_params(\n                        default_max_tokens, self.default_sampling_params\n                    )\n\n                    context: ConversationContext\n                    if self.use_harmony:\n                        if request.stream:\n                            context = StreamingHarmonyContext(messages, tool_sessions)\n                        else:\n                            context = HarmonyContext(messages, tool_sessions)\n                    else:\n                        context = SimpleContext()\n\n                    # Create GenerateReqInput for SGLang\n                    adapted_request = GenerateReqInput(\n                        input_ids=engine_prompt,\n                        sampling_params=sampling_params,\n                        stream=request.stream,\n                        rid=request.request_id,\n                        extra_key=self._compute_extra_key(request),\n                        background=request.background,\n                    )\n\n                    generator = self._generate_with_builtin_tools(\n                        request.request_id,\n                        request_prompts[i],\n                        adapted_request,\n                        sampling_params,\n                        context,\n                        raw_request=raw_request,\n                        priority=request.priority,\n                    )\n                    generators.append(generator)\n            except ValueError as e:\n                
return self.create_error_response(str(e))\n\n            assert len(generators) == 1\n            (result_generator,) = generators\n\n            # Store the input messages\n            if request.store:\n                self.msg_store[request.request_id] = messages\n\n            if request.background:\n                created_time = int(time.time())\n                response = ResponsesResponse.from_request(\n                    request,\n                    sampling_params,\n                    model_name=model_name,\n                    created_time=created_time,\n                    output=[],\n                    status=\"queued\",\n                    usage=None,\n                )\n                async with self.response_store_lock:\n                    self.response_store[response.id] = response\n\n                # Run the request in the background\n                task = asyncio.create_task(\n                    self._run_background_request(\n                        request,\n                        sampling_params,\n                        result_generator,\n                        context,\n                        model_name,\n                        tokenizer,\n                        request_metadata,\n                        created_time,\n                    ),\n                    name=f\"create_{response.id}\",\n                )\n\n                # For cleanup\n                self.background_tasks[response.id] = task\n                task.add_done_callback(\n                    lambda _: self.background_tasks.pop(response.id, None)\n                )\n                return response\n\n            if request.stream:\n                return self.responses_stream_generator(\n                    request,\n                    sampling_params,\n                    result_generator,\n                    context,\n                    model_name,\n                    tokenizer,\n                    request_metadata,\n                )\n            
try:\n                result: Union[ORJSONResponse, ResponsesResponse] = (\n                    await self.responses_full_generator(\n                        request,\n                        sampling_params,\n                        result_generator,\n                        context,\n                        model_name,\n                        tokenizer,\n                        request_metadata,\n                    )\n                )\n                return result\n            except Exception as e:\n                return self.create_error_response(str(e))\n        return self.create_error_response(\"Unknown error\")\n\n    async def _make_request(\n        self,\n        request: ResponsesRequest,\n        prev_response: Optional[ResponsesResponse],\n        tokenizer: Any,\n    ):\n        # Construct the input messages\n        messages = self._construct_input_messages(request, prev_response)\n\n        # Follow SGLang's pattern: create a ChatCompletionRequest and process messages\n        try:\n            # Convert ResponsesRequest to ChatCompletionRequest for processing\n            chat_request = ChatCompletionRequest(\n                model=request.model,\n                messages=messages,\n                stream=request.stream,\n            )\n\n            # Follow SGLang's _process_messages pattern\n            is_multimodal = self.tokenizer_manager.model_config.is_multimodal\n            processed_messages = self._process_messages(chat_request, is_multimodal)\n\n            # Extract the results\n            if is_multimodal:\n                request_prompts = [processed_messages.prompt]\n                engine_prompts = [processed_messages.prompt]\n            else:\n                request_prompts = [processed_messages.prompt_ids]\n                engine_prompts = [processed_messages.prompt_ids]\n\n        except Exception as e:\n            logger.warning(f\"Chat processing failed, using fallback: {e}\")\n            # Fallback to simple 
encoding\n            prompt_text = \"\"\n            for msg in messages:\n                role = msg.get(\"role\", \"user\")\n                content = msg.get(\"content\", \"\")\n                prompt_text += f\"{role}: {content}\\n\"\n            prompt_ids = tokenizer.encode(prompt_text)\n            request_prompts = [prompt_ids]\n            engine_prompts = [prompt_ids]\n\n        return messages, request_prompts, engine_prompts\n\n    def _make_request_with_harmony(\n        self,\n        request: ResponsesRequest,\n        prev_response: Optional[ResponsesResponse],\n    ):\n        if request.tool_choice != \"auto\":\n            raise NotImplementedError(\n                \"Only 'auto' tool_choice is supported in \" \"response API\"\n            )\n        messages = self._construct_input_messages_with_harmony(request, prev_response)\n        prompt_token_ids = render_for_completion(messages)\n        engine_prompt = prompt_token_ids\n        return messages, [prompt_token_ids], [engine_prompt]\n\n    async def responses_full_generator(\n        self,\n        request: ResponsesRequest,\n        sampling_params: Any,\n        result_generator: AsyncIterator[Any],\n        context: ConversationContext,\n        model_name: str,\n        tokenizer: Any,\n        request_metadata: RequestResponseMetadata,\n        created_time: Optional[int] = None,\n    ) -> Union[ResponsesResponse, ORJSONResponse]:\n        if created_time is None:\n            created_time = int(time.time())\n\n        try:\n            async for _ in result_generator:\n                pass\n        except asyncio.CancelledError:\n            return self.create_error_response(\"Client disconnected\")\n        except ValueError as e:\n            return self.create_error_response(str(e))\n\n        if self.use_harmony:\n            assert isinstance(context, HarmonyContext)\n            output = self._make_response_output_items_with_harmony(context)\n            # TODO: these are all 0 
for now!\n            num_prompt_tokens = context.num_prompt_tokens\n            num_generated_tokens = context.num_output_tokens\n            num_cached_tokens = context.num_cached_tokens\n            num_reasoning_tokens = context.num_reasoning_tokens\n        else:\n            assert isinstance(context, SimpleContext)\n            final_res = context.last_output\n            assert final_res is not None\n\n            output = self._make_response_output_items(\n                request, final_res[\"text\"], tokenizer\n            )\n\n            # Calculate usage from actual output\n            if hasattr(final_res, \"meta_info\"):\n                num_prompt_tokens = final_res.meta_info.get(\"prompt_tokens\", 0)\n                num_generated_tokens = final_res.meta_info.get(\"completion_tokens\", 0)\n                num_cached_tokens = final_res.meta_info.get(\"cached_tokens\", 0)\n                num_reasoning_tokens = 0\n            elif hasattr(final_res, \"prompt_token_ids\") and hasattr(\n                final_res, \"outputs\"\n            ):\n                # Fallback calculation if meta_info not available\n                num_prompt_tokens = (\n                    len(final_res.prompt_token_ids) if final_res.prompt_token_ids else 0\n                )\n                num_generated_tokens = (\n                    len(final_res.outputs[0].token_ids)\n                    if final_res.outputs and final_res.outputs[0].token_ids\n                    else 0\n                )\n                num_cached_tokens = getattr(final_res, \"num_cached_tokens\", 0)\n                num_reasoning_tokens = 0\n            else:\n                # Final fallback\n                num_prompt_tokens = 0\n                num_generated_tokens = 0\n                num_cached_tokens = 0\n                num_reasoning_tokens = 0\n\n        usage = UsageInfo(\n            prompt_tokens=num_prompt_tokens,\n            completion_tokens=num_generated_tokens,\n            total_tokens=num_prompt_tokens + 
num_generated_tokens,\n            reasoning_tokens=num_reasoning_tokens,\n        )\n        if self.enable_prompt_tokens_details and num_cached_tokens:\n            usage.prompt_tokens_details = PromptTokenUsageInfo(\n                cached_tokens=num_cached_tokens\n            )\n        request_metadata.final_usage_info = usage\n\n        response = ResponsesResponse.from_request(\n            request,\n            sampling_params,\n            model_name=model_name,\n            created_time=created_time,\n            output=output,\n            status=\"completed\",\n            usage=usage,\n        )\n\n        if request.store:\n            async with self.response_store_lock:\n                stored_response = self.response_store.get(response.id)\n                # If the response is already cancelled, don't update it\n                if stored_response is None or stored_response.status != \"cancelled\":\n                    self.response_store[response.id] = response\n\n        return response\n\n    def _make_response_output_items(\n        self,\n        request: ResponsesRequest,\n        final_output: Any,\n        tokenizer: Any,\n    ):\n        # Handle reasoning parsing if enabled\n        if self.reasoning_parser:\n            # Use standard reasoning parser (openai maps to T4Detector internally)\n            reasoning_parser = ReasoningParser(\n                model_type=self.reasoning_parser,\n                stream_reasoning=False,\n                request=request,\n            )\n            reasoning_content, content = reasoning_parser.parse_non_stream(final_output)\n        else:\n            reasoning_content = None\n            content = final_output\n\n        output_items = []\n        if reasoning_content:\n            reasoning_item = ResponseReasoningItem(\n                id=f\"rs_{random_uuid()}\",\n                type=\"reasoning\",\n                summary=[],\n                content=[\n                    
ResponseReasoningTextContent(\n                        type=\"reasoning_text\", text=reasoning_content\n                    ),\n                ],\n                status=None,\n            )\n            output_items.append(reasoning_item)\n        if content:\n            output_text = ResponseOutputText(\n                text=content,\n                annotations=[],  # TODO\n                type=\"output_text\",\n                logprobs=None,  # TODO\n            )\n            message = ResponseOutputMessage(\n                id=f\"msg_{random_uuid()}\",\n                content=[output_text],\n                role=\"assistant\",\n                status=\"completed\",\n                type=\"message\",\n            )\n            output_items.append(message)\n        return output_items\n\n    def _make_response_output_items_with_harmony(\n        self,\n        context: HarmonyContext,\n    ):\n        output_items = []\n        num_init_messages = context.num_init_messages\n        for msg in context.messages[num_init_messages:]:\n            output_items.extend(parse_output_message(msg))\n        # Handle the generation stopped in the middle (if any).\n        last_items = parse_remaining_state(context.parser)\n        if last_items:\n            output_items.extend(last_items)\n        return output_items\n\n    def _construct_input_messages(\n        self,\n        request: ResponsesRequest,\n        prev_response: Optional[ResponsesResponse] = None,\n    ) -> list[ChatCompletionMessageParam]:\n        messages: list[ChatCompletionMessageParam] = []\n        if request.instructions:\n            messages.append(\n                {\n                    \"role\": \"system\",\n                    \"content\": request.instructions,\n                }\n            )\n\n        # Prepend the conversation history\n        if prev_response is not None:\n            # Add the previous messages\n            prev_msg = self.msg_store[prev_response.id]\n            
messages.extend(prev_msg)\n\n            # Add the previous output\n            for output_item in prev_response.output:\n                # NOTE: We skip the reasoning output of the previous response\n                if isinstance(output_item, ResponseReasoningItem):\n                    continue\n                for content in output_item.content:\n                    messages.append(\n                        {\n                            \"role\": \"assistant\",\n                            \"content\": content.text,\n                        }\n                    )\n\n        # Append the new input\n        # Responses API supports simple text inputs without chat format\n        if isinstance(request.input, str):\n            messages.append({\"role\": \"user\", \"content\": request.input})\n        else:\n            messages.extend(request.input)  # type: ignore\n        return messages\n\n    def _construct_input_messages_with_harmony(\n        self,\n        request: ResponsesRequest,\n        prev_response: Optional[ResponsesResponse],\n    ) -> list[\"OpenAIMessage\"]:\n        messages: list[\"OpenAIMessage\"] = []\n        if prev_response is None:\n            # New conversation.\n            reasoning_effort = request.reasoning.effort if request.reasoning else None\n            tool_types = [tool.type for tool in request.tools]\n            enable_browser = (\n                \"web_search_preview\" in tool_types and self.tool_server is not None\n            )\n            enable_code_interpreter = (\n                \"code_interpreter\" in tool_types and self.tool_server is not None\n            )\n            sys_msg = get_system_message(\n                reasoning_effort=reasoning_effort,\n                browser_description=(\n                    self.tool_server.get_tool_description(\"browser\")\n                    if self.tool_server and enable_browser\n                    else None\n                ),\n                
python_description=(\n                    self.tool_server.get_tool_description(\"python\")\n                    if self.tool_server and enable_code_interpreter\n                    else None\n                ),\n            )\n            messages.append(sys_msg)\n            dev_msg = get_developer_message(request.instructions, request.tools)\n            messages.append(dev_msg)\n        else:\n            # Continue the previous conversation.\n            # FIXME: Currently, request params like reasoning and\n            # instructions are ignored.\n            prev_msgs = self.msg_store[prev_response.id]\n            # Remove the previous chain-of-thoughts if there is a new \"final\"\n            # message.\n            if (\n                len(prev_msgs) > 0\n                and hasattr(prev_msgs[-1], \"channel\")\n                and prev_msgs[-1].channel == \"final\"\n            ):  # type: ignore[union-attr]\n                prev_final_msg_idx = -1\n                for i in range(len(prev_msgs) - 2, -1, -1):\n                    if (\n                        hasattr(prev_msgs[i], \"channel\")\n                        and prev_msgs[i].channel == \"final\"\n                    ):  # type: ignore[union-attr]\n                        prev_final_msg_idx = i\n                        break\n                recent_turn_msgs = prev_msgs[prev_final_msg_idx + 1 :]\n                del prev_msgs[prev_final_msg_idx + 1 :]\n                for msg in recent_turn_msgs:\n                    if (\n                        hasattr(msg, \"channel\") and msg.channel != \"analysis\"\n                    ):  # type: ignore[union-attr]\n                        prev_msgs.append(msg)\n            messages.extend(prev_msgs)\n        # Append the new input.\n        # Responses API supports simple text inputs without chat format.\n        if isinstance(request.input, str):\n            messages.append(get_user_message(request.input))\n        else:\n            if prev_response is 
not None:\n                prev_outputs = copy.copy(prev_response.output)\n            else:\n                prev_outputs = []\n            for response_msg in request.input:\n                messages.append(parse_response_input(response_msg, prev_outputs))\n                if isinstance(response_msg, ResponseFunctionToolCall):\n                    prev_outputs.append(response_msg)\n        return messages\n\n    async def _run_background_request(\n        self,\n        request: ResponsesRequest,\n        sampling_params: Any,\n        result_generator: AsyncIterator[Any],\n        context: ConversationContext,\n        model_name: str,\n        tokenizer: Any,\n        request_metadata: RequestResponseMetadata,\n        created_time: Optional[int] = None,\n        *args,\n        **kwargs,\n    ):\n        try:\n            # Update the status to \"in_progress\"\n            async with self.response_store_lock:\n                stored_response = self.response_store.get(request.request_id)\n                assert stored_response is not None\n                stored_response.status = \"in_progress\"\n\n            response = await self.responses_full_generator(\n                request,\n                sampling_params,\n                result_generator,\n                context,\n                model_name,\n                tokenizer,\n                request_metadata,\n                created_time,\n                *args,\n                **kwargs,\n            )\n        except Exception as e:\n            logger.exception(\"Background request failed for %s\", request.request_id)\n            response = self.create_error_response(str(e))\n\n        if isinstance(response, ORJSONResponse):\n            # If the request has failed, update the status to \"failed\"\n            response_id = request.request_id\n            async with self.response_store_lock:\n                stored_response = self.response_store.get(response_id)\n                assert stored_response 
is not None\n                if stored_response.status not in (\"completed\", \"cancelled\"):\n                    stored_response.status = \"failed\"\n\n    async def retrieve_responses(\n        self,\n        response_id: str,\n    ) -> Union[ResponsesResponse, ORJSONResponse]:\n        if not response_id.startswith(\"resp_\"):\n            return self._make_invalid_id_error(response_id)\n\n        async with self.response_store_lock:\n            response = self.response_store.get(response_id)\n\n        if response is None:\n            return self._make_not_found_error(response_id)\n        return response\n\n    async def cancel_responses(\n        self,\n        response_id: str,\n    ) -> Union[ResponsesResponse, ORJSONResponse]:\n        if not response_id.startswith(\"resp_\"):\n            return self._make_invalid_id_error(response_id)\n\n        async with self.response_store_lock:\n            response = self.response_store.get(response_id)\n            if response is None:\n                return self._make_not_found_error(response_id)\n\n            prev_status = response.status\n            if prev_status not in (\"queued\", \"in_progress\"):\n                return self.create_error_response(\n                    err_type=\"invalid_request_error\",\n                    message=\"Cannot cancel a synchronous response.\",\n                )\n\n            # Update the status to \"cancelled\"\n            response.status = \"cancelled\"\n\n        # The response_id is the same as the rid used when submitting the request\n        self.tokenizer_manager.abort_request(rid=response_id)\n\n        if task := self.background_tasks.get(response_id):\n            task.cancel()\n            try:\n                await task\n            except asyncio.CancelledError:\n                logger.exception(\"Background task for %s was cancelled\", response_id)\n        return response\n\n    def _make_invalid_id_error(self, response_id: str):\n        return 
self.create_error_response(\n            message=(\n                f\"Invalid 'response_id': '{response_id}'. \"\n                \"Expected an ID that begins with 'resp'.\"\n            ),\n            err_type=\"invalid_request_error\",\n            param=\"response_id\",\n        )\n\n    def _make_not_found_error(self, response_id: str):\n        return self.create_error_response(\n            message=f\"Response with id '{response_id}' not found.\",\n            err_type=\"invalid_request_error\",\n            status_code=HTTPStatus.NOT_FOUND,\n            param=\"response_id\",\n        )\n\n    async def responses_stream_generator(\n        self,\n        request: ResponsesRequest,\n        sampling_params: Any,\n        result_generator: AsyncIterator[StreamingHarmonyContext],\n        context: StreamingHarmonyContext,\n        model_name: str,\n        tokenizer: Any,\n        request_metadata: RequestResponseMetadata,\n        created_time: Optional[int] = None,\n    ) -> AsyncGenerator[str, None]:\n        # TODO:\n        # 1. 
Handle disconnect\n\n        created_time = created_time or int(time.time())\n\n        sequence_number = 0\n\n        def _send_event(event):\n            nonlocal sequence_number\n            # Set sequence_number if the event has this attribute\n            if hasattr(event, \"sequence_number\"):\n                event.sequence_number = sequence_number\n            sequence_number += 1\n            # Get event type from the event's type field if it exists\n            event_type = getattr(event, \"type\", \"unknown\")\n            return (\n                f\"event: {event_type}\\n\"\n                f\"data: {event.model_dump_json(indent=None)}\\n\\n\"\n            )\n\n        current_content_index = 0\n        current_output_index = 0\n        current_item_id = f\"item_{random_uuid()}\"\n        sent_output_item_added = False\n\n        initial_response = ResponsesResponse.from_request(\n            request,\n            sampling_params,\n            model_name=model_name,\n            created_time=created_time,\n            output=[],\n            status=\"in_progress\",\n            usage=None,\n        ).model_dump()\n        yield _send_event(\n            openai_responses_types.ResponseCreatedEvent(\n                type=\"response.created\",\n                sequence_number=-1,\n                response=initial_response,\n            )\n        )\n        yield _send_event(\n            openai_responses_types.ResponseInProgressEvent(\n                type=\"response.in_progress\",\n                sequence_number=-1,\n                response=initial_response,\n            )\n        )\n\n        async for ctx in result_generator:\n\n            # Only process context objects that implement the `is_expecting_start()` method,\n            # which indicates they support per-turn streaming (e.g., StreamingHarmonyContext).\n            # Contexts without this method are skipped, as they do not represent a new turn\n            # or are not compatible with 
per-turn handling in the /v1/responses endpoint.\n            if not hasattr(ctx, \"is_expecting_start\"):\n                continue\n\n            if ctx.is_expecting_start():\n                current_output_index += 1\n                sent_output_item_added = False\n\n                if len(ctx.parser.messages) > 0:\n                    previous_item = ctx.parser.messages[-1]\n                    if previous_item.recipient is not None:\n                        # Deal with tool call here\n                        pass\n                    elif previous_item.channel == \"analysis\":\n                        reasoning_item = ResponseReasoningItem(\n                            id=f\"rs_{random_uuid()}\",\n                            type=\"reasoning\",\n                            summary=[],\n                            content=[\n                                ResponseReasoningTextContent(\n                                    text=previous_item.content[0].text,\n                                    type=\"reasoning_text\",\n                                ),\n                            ],\n                            status=\"completed\",\n                        )\n                        yield _send_event(\n                            openai_responses_types.ResponseReasoningTextDoneEvent(\n                                type=\"response.reasoning_text.done\",\n                                item_id=current_item_id,\n                                sequence_number=-1,\n                                output_index=current_output_index,\n                                content_index=current_content_index,\n                                text=previous_item.content[0].text,\n                            )\n                        )\n                        yield _send_event(\n                            openai_responses_types.ResponseOutputItemDoneEvent(\n                                type=\"response.output_item.done\",\n                                
sequence_number=-1,\n                                output_index=current_output_index,\n                                item=reasoning_item,\n                            )\n                        )\n                    elif previous_item.channel == \"final\":\n                        text_content = openai_responses_types.ResponseOutputText(\n                            type=\"output_text\",\n                            text=previous_item.content[0].text,\n                            annotations=[],\n                        )\n                        yield _send_event(\n                            openai_responses_types.ResponseTextDoneEvent(\n                                type=\"response.output_text.done\",\n                                sequence_number=-1,\n                                output_index=current_output_index,\n                                content_index=current_content_index,\n                                text=previous_item.content[0].text,\n                                logprobs=[],\n                                item_id=current_item_id,\n                            )\n                        )\n                        yield _send_event(\n                            openai_responses_types.ResponseContentPartDoneEvent(\n                                type=\"response.content_part.done\",\n                                sequence_number=-1,\n                                item_id=current_item_id,\n                                output_index=current_output_index,\n                                content_index=current_content_index,\n                                part=text_content,\n                            )\n                        )\n                        yield _send_event(\n                            openai_responses_types.ResponseOutputItemDoneEvent(\n                                type=\"response.output_item.done\",\n                                sequence_number=-1,\n                                
output_index=current_output_index,\n                                item=openai_responses_types.ResponseOutputMessage(\n                                    id=current_item_id,\n                                    type=\"message\",\n                                    role=\"assistant\",\n                                    content=[text_content],\n                                    status=\"completed\",\n                                ),\n                            )\n                        )\n\n            if ctx.parser.last_content_delta:\n                if (\n                    ctx.parser.current_channel == \"final\"\n                    and ctx.parser.current_recipient is None\n                ):\n                    if not sent_output_item_added:\n                        sent_output_item_added = True\n                        yield _send_event(\n                            openai_responses_types.ResponseOutputItemAddedEvent(\n                                type=\"response.output_item.added\",\n                                sequence_number=-1,\n                                output_index=current_output_index,\n                                item=openai_responses_types.ResponseOutputMessage(\n                                    id=current_item_id,\n                                    type=\"message\",\n                                    role=\"assistant\",\n                                    content=[],\n                                    status=\"in_progress\",\n                                ),\n                            )\n                        )\n                        yield _send_event(\n                            openai_responses_types.ResponseContentPartAddedEvent(\n                                type=\"response.content_part.added\",\n                                sequence_number=-1,\n                                output_index=current_output_index,\n                                item_id=current_item_id,\n                         
       content_index=current_content_index,\n                                part=openai_responses_types.ResponseOutputText(\n                                    type=\"output_text\",\n                                    text=\"\",\n                                    annotations=[],\n                                    logprobs=None,\n                                ),\n                            )\n                        )\n                    yield _send_event(\n                        openai_responses_types.ResponseTextDeltaEvent(\n                            type=\"response.output_text.delta\",\n                            sequence_number=-1,\n                            content_index=current_content_index,\n                            output_index=current_output_index,\n                            item_id=current_item_id,\n                            delta=ctx.parser.last_content_delta,\n                            # TODO, use logprobs from ctx.last_request_output\n                            logprobs=[],\n                        )\n                    )\n                elif (\n                    ctx.parser.current_channel == \"analysis\"\n                    and ctx.parser.current_recipient is None\n                ):\n                    if not sent_output_item_added:\n                        sent_output_item_added = True\n                        yield _send_event(\n                            openai_responses_types.ResponseOutputItemAddedEvent(\n                                type=\"response.output_item.added\",\n                                sequence_number=-1,\n                                output_index=current_output_index,\n                                item=openai_responses_types.ResponseReasoningItem(\n                                    type=\"reasoning\",\n                                    id=current_item_id,\n                                    summary=[],\n                                    status=\"in_progress\",\n                  
              ),\n                            )\n                        )\n                        yield _send_event(\n                            openai_responses_types.ResponseContentPartAddedEvent(\n                                type=\"response.content_part.added\",\n                                sequence_number=-1,\n                                output_index=current_output_index,\n                                item_id=current_item_id,\n                                content_index=current_content_index,\n                                # TODO: migrate this to\n                                # ResponseReasoningTextContent for now\n                                part=openai_responses_types.ResponseOutputText(\n                                    type=\"output_text\",\n                                    text=\"\",\n                                    annotations=[],\n                                    logprobs=None,\n                                ),\n                            )\n                        )\n                    # TODO: migrate to OpenAI types once updated.\n                    yield _send_event(\n                        openai_responses_types.ResponseReasoningTextDeltaEvent(\n                            type=\"response.reasoning_text.delta\",\n                            item_id=current_item_id,\n                            output_index=current_output_index,\n                            content_index=current_content_index,\n                            delta=ctx.parser.last_content_delta,\n                            sequence_number=-1,\n                        )\n                    )\n\n            if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0:\n                previous_item = ctx.parser.messages[-1]\n                if (\n                    self.supports_browsing\n                    and previous_item.recipient is not None\n                    and previous_item.recipient.startswith(\"browser.\")\n              
  ):\n                    function_name = previous_item.recipient[len(\"browser.\") :]\n                    action = None\n                    parsed_args = orjson.loads(previous_item.content[0].text)\n                    if function_name == \"search\":\n                        action = openai_responses_types.response_function_web_search.ActionSearch(\n                            type=\"search\",\n                            query=parsed_args[\"query\"],\n                        )\n                    elif function_name == \"open\":\n                        action = openai_responses_types.response_function_web_search.ActionOpenPage(\n                            type=\"open_page\",\n                            # TODO: translate to url\n                            url=f\"cursor:{parsed_args.get('cursor', '')}\",\n                        )\n                    elif function_name == \"find\":\n                        action = openai_responses_types.response_function_web_search.ActionFind(\n                            type=\"find\",\n                            pattern=parsed_args[\"pattern\"],\n                            # TODO: translate to url\n                            url=f\"cursor:{parsed_args.get('cursor', '')}\",\n                        )\n                    else:\n                        raise ValueError(f\"Unknown function name: {function_name}\")\n\n                    yield _send_event(\n                        openai_responses_types.ResponseOutputItemAddedEvent(\n                            type=\"response.output_item.added\",\n                            sequence_number=-1,\n                            output_index=current_output_index,\n                            item=openai_responses_types.response_function_web_search.ResponseFunctionWebSearch(\n                                # TODO: generate a unique id for web search call\n                                type=\"web_search_call\",\n                                id=current_item_id,\n             
                   action=action,\n                                status=\"in_progress\",\n                            ),\n                        )\n                    )\n                    yield _send_event(\n                        openai_responses_types.ResponseWebSearchCallInProgressEvent(\n                            type=\"response.web_search_call.in_progress\",\n                            sequence_number=-1,\n                            output_index=current_output_index,\n                            item_id=current_item_id,\n                        )\n                    )\n                    yield _send_event(\n                        openai_responses_types.ResponseWebSearchCallSearchingEvent(\n                            type=\"response.web_search_call.searching\",\n                            sequence_number=-1,\n                            output_index=current_output_index,\n                            item_id=current_item_id,\n                        )\n                    )\n\n                    # enqueue\n                    yield _send_event(\n                        openai_responses_types.ResponseWebSearchCallCompletedEvent(\n                            type=\"response.web_search_call.completed\",\n                            sequence_number=-1,\n                            output_index=current_output_index,\n                            item_id=current_item_id,\n                        )\n                    )\n                    yield _send_event(\n                        openai_responses_types.ResponseOutputItemDoneEvent(\n                            type=\"response.output_item.done\",\n                            sequence_number=-1,\n                            output_index=current_output_index,\n                            item=openai_responses_types.ResponseFunctionWebSearch(\n                                type=\"web_search_call\",\n                                id=current_item_id,\n                                action=action,\n   
                             status=\"completed\",\n                            ),\n                        )\n                    )\n\n                if (\n                    self.supports_code_interpreter\n                    and previous_item.recipient is not None\n                    and previous_item.recipient.startswith(\"python\")\n                ):\n                    yield _send_event(\n                        openai_responses_types.ResponseOutputItemAddedEvent(\n                            type=\"response.output_item.added\",\n                            sequence_number=-1,\n                            output_index=current_output_index,\n                            item=openai_responses_types.ResponseCodeInterpreterToolCallParam(\n                                type=\"code_interpreter_call\",\n                                id=current_item_id,\n                                code=\"\",\n                                container_id=\"auto\",\n                                outputs=[],\n                                status=\"in_progress\",\n                            ),\n                        )\n                    )\n                    yield _send_event(\n                        openai_responses_types.ResponseCodeInterpreterCallInProgressEvent(\n                            type=\"response.code_interpreter_call.in_progress\",\n                            sequence_number=-1,\n                            output_index=current_output_index,\n                            item_id=current_item_id,\n                        )\n                    )\n                    # TODO: do we need to add delta event here?\n                    yield _send_event(\n                        openai_responses_types.ResponseCodeInterpreterCallCodeDoneEvent(\n                            type=\"response.code_interpreter_call_code.done\",\n                            sequence_number=-1,\n                            output_index=current_output_index,\n                        
    item_id=current_item_id,\n                            code=previous_item.content[0].text,\n                        )\n                    )\n                    yield _send_event(\n                        openai_responses_types.ResponseCodeInterpreterCallInterpretingEvent(\n                            type=\"response.code_interpreter_call.interpreting\",\n                            sequence_number=-1,\n                            output_index=current_output_index,\n                            item_id=current_item_id,\n                        )\n                    )\n                    yield _send_event(\n                        openai_responses_types.ResponseCodeInterpreterCallCompletedEvent(\n                            type=\"response.code_interpreter_call.completed\",\n                            sequence_number=-1,\n                            output_index=current_output_index,\n                            item_id=current_item_id,\n                        )\n                    )\n                    yield _send_event(\n                        openai_responses_types.ResponseOutputItemDoneEvent(\n                            type=\"response.output_item.done\",\n                            sequence_number=-1,\n                            output_index=current_output_index,\n                            item=openai_responses_types.ResponseCodeInterpreterToolCallParam(\n                                type=\"code_interpreter_call\",\n                                id=current_item_id,\n                                code=previous_item.content[0].text,\n                                container_id=\"auto\",\n                                # TODO: add outputs here\n                                outputs=[],\n                                status=\"completed\",\n                            ),\n                        )\n                    )\n\n        async def empty_async_generator():\n            if False:\n                yield\n\n        final_response = 
await self.responses_full_generator(\n            request,\n            sampling_params,\n            empty_async_generator(),\n            context,\n            model_name,\n            tokenizer,\n            request_metadata,\n            created_time=created_time,\n        )\n        # Convert final_response to the format expected by ResponseCompletedEvent\n        response_dict = final_response.model_dump()\n\n        # Convert UsageInfo to ResponseUsage format\n        if response_dict.get(\"usage\"):\n            usage_info = response_dict[\"usage\"]\n            response_dict[\"usage\"] = {\n                \"input_tokens\": usage_info.get(\"prompt_tokens\", 0),\n                \"input_tokens_details\": {\n                    \"cached_tokens\": usage_info.get(\"cached_tokens\", 0)\n                },\n                \"output_tokens\": usage_info.get(\"completion_tokens\", 0),\n                \"output_tokens_details\": {\n                    \"reasoning_tokens\": usage_info.get(\"reasoning_tokens\", 0)\n                },\n                \"total_tokens\": usage_info.get(\"total_tokens\", 0),\n            }\n\n        yield _send_event(\n            openai_responses_types.ResponseCompletedEvent(\n                type=\"response.completed\",\n                sequence_number=-1,\n                response=response_dict,\n            )\n        )\n\n    async def _generate_with_builtin_tools(\n        self,\n        request_id: str,\n        request_prompt: Any,\n        adapted_request: GenerateReqInput,\n        sampling_params: Any,\n        context: ConversationContext,\n        raw_request: Optional[Request] = None,\n        priority: Optional[int] = None,\n        **kwargs,\n    ) -> AsyncGenerator[Any, None]:\n        \"\"\"Generate with builtin tool support for harmony-based models.\"\"\"\n        orig_priority = priority or 0\n\n        while True:\n            # Generate using SGLang's tokenizer manager\n            generator = 
self.tokenizer_manager.generate_request(\n                adapted_request, raw_request\n            )\n\n            async for res in generator:\n                context.append_output(res)\n                # NOTE(woosuk): The stop condition is handled by the engine.\n                yield context\n\n            if not context.need_builtin_tool_call():\n                # The model did not ask for a tool call, so we're done.\n                break\n\n            # Call the tool and update the context with the result.\n            tool_output = await context.call_tool()\n            context.append_output(tool_output)\n\n            # Prepare for the next generation turn\n            # Render the updated conversation for the next completion\n            prompt_token_ids = context.render_for_completion()\n\n            # Update the adapted request with new prompt\n            adapted_request = GenerateReqInput(\n                input_ids=prompt_token_ids,\n                sampling_params=sampling_params,\n                stream=adapted_request.stream,\n                rid=request_id,\n                extra_key=adapted_request.extra_key,\n                return_logprob=adapted_request.return_logprob,\n                logprob_start_len=adapted_request.logprob_start_len,\n                top_logprobs_num=adapted_request.top_logprobs_num,\n                return_text_in_logprobs=adapted_request.return_text_in_logprobs,\n                return_hidden_states=adapted_request.return_hidden_states,\n                background=adapted_request.background,\n            )\n\n            # Update sampling params with reduced max_tokens\n            if hasattr(sampling_params, \"max_new_tokens\") or isinstance(\n                sampling_params, dict\n            ):\n                context_len = getattr(\n                    self.tokenizer_manager.model_config, \"context_len\", 4096\n                )\n                num_reserved_tokens = self.tokenizer_manager.num_reserved_tokens\n  
              remaining_tokens = (\n                    context_len - len(prompt_token_ids) - num_reserved_tokens\n                )\n\n                if isinstance(sampling_params, dict):\n                    sampling_params[\"max_new_tokens\"] = max(remaining_tokens, 1)\n                else:\n                    sampling_params.max_new_tokens = max(remaining_tokens, 1)\n\n            # Slightly reduce priority for subsequent tool calls\n            priority = orig_priority - 1\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/serving_score.py",
    "content": "import logging\nfrom typing import Union\n\nfrom fastapi import Request\n\nfrom sglang.srt.entrypoints.openai.protocol import (\n    ErrorResponse,\n    ScoringRequest,\n    ScoringResponse,\n    UsageInfo,\n)\nfrom sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase\n\nlogger = logging.getLogger(__name__)\n\n\nclass OpenAIServingScore(OpenAIServingBase):\n    \"\"\"Handler for /v1/score requests\"\"\"\n\n    # NOTE: /v1/score is not an official OpenAI endpoint. This module may be moved\n    # to another module in the future.\n\n    def _request_id_prefix(self) -> str:\n        return \"score-\"\n\n    def _convert_to_internal_request(\n        self,\n        request: ScoringRequest,\n        raw_request: Request = None,\n    ) -> tuple[ScoringRequest, ScoringRequest]:\n        \"\"\"Convert OpenAI scoring request to internal format\"\"\"\n        # For scoring, we pass the request directly as the tokenizer_manager\n        # has a specialized score_request method that doesn't use GenerateReqInput\n\n        return request, request\n\n    async def _handle_non_streaming_request(\n        self,\n        adapted_request: ScoringRequest,\n        request: ScoringRequest,\n        raw_request: Request,\n    ) -> Union[ScoringResponse, ErrorResponse]:\n        \"\"\"Handle the scoring request\"\"\"\n        try:\n            # Use tokenizer_manager's score_request method directly\n            result = await self.tokenizer_manager.score_request(\n                query=request.query,\n                items=request.items,\n                label_token_ids=request.label_token_ids,\n                apply_softmax=request.apply_softmax,\n                item_first=request.item_first,\n                request=raw_request,\n            )\n\n            response = ScoringResponse(\n                scores=result.scores,\n                model=request.model,\n                usage=UsageInfo(\n                    prompt_tokens=result.prompt_tokens,\n  
                  total_tokens=result.prompt_tokens,\n                ),\n            )\n            return response\n\n        except ValueError as e:\n            return self.create_error_response(str(e))\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/serving_tokenize.py",
    "content": "import logging\nfrom http import HTTPStatus\nfrom typing import List, Union\n\nfrom fastapi import Request\n\nfrom sglang.srt.entrypoints.openai.protocol import (\n    DetokenizeRequest,\n    DetokenizeResponse,\n    ErrorResponse,\n    TokenizeRequest,\n    TokenizeResponse,\n)\nfrom sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase\n\nlogger = logging.getLogger(__name__)\n\n\nclass OpenAIServingTokenize(OpenAIServingBase):\n    \"\"\"Handler for /v1/tokenize requests\"\"\"\n\n    def _request_id_prefix(self) -> str:\n        return \"tok-\"\n\n    def _convert_to_internal_request(\n        self, request: TokenizeRequest, raw_request: Request\n    ) -> tuple[TokenizeRequest, TokenizeRequest]:\n        return request, request\n\n    async def _handle_non_streaming_request(\n        self,\n        adapted_request: TokenizeRequest,\n        request: TokenizeRequest,\n        raw_request: Request,\n    ) -> Union[TokenizeResponse, ErrorResponse]:\n        try:\n            tokenizer = self.tokenizer_manager.tokenizer\n            max_model_len = getattr(tokenizer, \"model_max_length\", -1)\n\n            if isinstance(request.prompt, str):\n                token_ids = tokenizer.encode(\n                    request.prompt,\n                    add_special_tokens=request.add_special_tokens,\n                )\n                tokens = token_ids\n                count = len(token_ids)\n            elif isinstance(request.prompt, list):\n                token_ids_list = [\n                    tokenizer.encode(\n                        text, add_special_tokens=request.add_special_tokens\n                    )\n                    for text in request.prompt\n                ]\n                tokens = token_ids_list\n                count = [len(ids) for ids in token_ids_list]\n            else:\n                return self.create_error_response(\n                    f\"Invalid prompt type: {type(request.prompt)}. 
Expected str or List[str].\"\n                )\n\n            return TokenizeResponse(\n                tokens=tokens, count=count, max_model_len=max_model_len\n            )\n        except Exception as e:\n            logger.error(\"Error during tokenization\", exc_info=True)\n            return self.create_error_response(\n                f\"Internal server error during tokenization: {e}\",\n                err_type=\"InternalServerError\",\n                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,\n            )\n\n\nclass OpenAIServingDetokenize(OpenAIServingBase):\n    \"\"\"Handler for /v1/detokenize requests\"\"\"\n\n    def _request_id_prefix(self) -> str:\n        return \"detok-\"\n\n    def _convert_to_internal_request(\n        self, request: DetokenizeRequest, raw_request: Request\n    ) -> tuple[DetokenizeRequest, DetokenizeRequest]:\n        return request, request\n\n    async def _handle_non_streaming_request(\n        self,\n        adapted_request: DetokenizeRequest,\n        request: DetokenizeRequest,\n        raw_request: Request,\n    ) -> Union[DetokenizeResponse, ErrorResponse]:\n        try:\n            tokenizer = self.tokenizer_manager.tokenizer\n\n            if (\n                isinstance(request.tokens, list)\n                and request.tokens\n                and isinstance(request.tokens[0], int)\n            ):\n                if not all(isinstance(t, int) for t in request.tokens):\n                    return self.create_error_response(\n                        \"Invalid input: 'tokens' must be a list of integers.\"\n                    )\n                tokens_to_decode = [int(t) for t in request.tokens]\n                text = tokenizer.decode(\n                    tokens_to_decode, skip_special_tokens=request.skip_special_tokens\n                )\n                text_out: Union[str, List[str]] = text\n            elif (\n                isinstance(request.tokens, list)\n                and request.tokens\n          
      and isinstance(request.tokens[0], list)\n            ):\n                texts: List[str] = []\n                for token_list in request.tokens:\n                    if not all(isinstance(t, int) for t in token_list):\n                        return self.create_error_response(\n                            f\"Invalid input: Sublist in 'tokens' must contain only integers. Found: {token_list}\"\n                        )\n                    decoded_text = tokenizer.decode(\n                        [int(t) for t in token_list],\n                        skip_special_tokens=request.skip_special_tokens,\n                    )\n                    texts.append(decoded_text)\n                text_out = texts\n            elif isinstance(request.tokens, list) and not request.tokens:\n                text_out = \"\"\n            else:\n                return self.create_error_response(\n                    f\"Invalid tokens type: {type(request.tokens)}. Expected List[int] or List[List[int]].\"\n                )\n\n            return DetokenizeResponse(text=text_out)\n        except Exception as e:\n            logger.error(\"Error during detokenization\", exc_info=True)\n            if \"decode\" in str(e).lower():\n                return self.create_error_response(\n                    f\"Error decoding tokens: {e}. Input tokens might be invalid for the model.\",\n                    err_type=\"DecodeError\",\n                    status_code=HTTPStatus.BAD_REQUEST,\n                )\n            return self.create_error_response(\n                f\"Internal server error during detokenization: {e}\",\n                err_type=\"InternalServerError\",\n                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,\n            )\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/serving_transcription.py",
    "content": "# Copyright 2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"\nOpenAI-compatible transcription endpoint handler for Whisper models.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport io\nimport logging\nimport math\nimport time\nimport uuid\nfrom typing import TYPE_CHECKING, AsyncGenerator, Optional, Union\n\nfrom fastapi import Request\nfrom fastapi.responses import ORJSONResponse, Response, StreamingResponse\n\nfrom sglang.srt.entrypoints.openai.protocol import (\n    DeltaMessage,\n    ErrorResponse,\n    TranscriptionRequest,\n    TranscriptionResponse,\n    TranscriptionStreamChoice,\n    TranscriptionStreamResponse,\n    TranscriptionUsage,\n)\nfrom sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase\nfrom sglang.srt.managers.io_struct import GenerateReqInput\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.tokenizer_manager import TokenizerManager\n\nlogger = logging.getLogger(__name__)\n\n\nclass OpenAIServingTranscription(OpenAIServingBase):\n    \"\"\"Handler for /v1/audio/transcriptions requests\"\"\"\n\n    def __init__(self, tokenizer_manager: TokenizerManager):\n        super().__init__(tokenizer_manager)\n\n    def _request_id_prefix(self) -> str:\n        return \"trsc-\"\n\n    def _validate_request(self, request: TranscriptionRequest) -> Optional[str]:\n        \"\"\"Validate transcription 
request.\"\"\"\n        # Validation is done in the route handler for form data\n        return None\n\n    def _convert_to_internal_request(\n        self,\n        request: TranscriptionRequest,\n        raw_request: Request = None,\n    ) -> tuple[GenerateReqInput, TranscriptionRequest]:\n        \"\"\"Convert transcription request to internal format.\"\"\"\n        # Build sampling params - include language for WhisperProcessor\n        sampling_params = {\n            \"temperature\": request.temperature,\n            \"max_new_tokens\": 448,  # Whisper default max tokens\n            \"language\": request.language,  # Pass to WhisperProcessor for language-specific decoding\n        }\n\n        # For Whisper, we pass audio_data and let the processor handle it\n        adapted_request = GenerateReqInput(\n            text=\"\",  # Empty text - Whisper processor will set proper decoder tokens\n            audio_data=request.audio_data,\n            sampling_params=sampling_params,\n            stream=request.stream,\n            modalities=[\"audio\"],\n            routing_key=self.extract_routing_key(raw_request),\n        )\n\n        return adapted_request, request\n\n    def _get_audio_duration(self, audio_data: bytes) -> float:\n        \"\"\"Calculate audio duration in seconds.\"\"\"\n        try:\n            import soundfile as sf\n\n            audio_array, sr = sf.read(io.BytesIO(audio_data))\n            duration = len(audio_array) / sr\n            return duration\n        except Exception as e:\n            logger.warning(f\"Could not calculate audio duration: {e}\")\n            return 0.0\n\n    async def create_transcription(\n        self,\n        audio_data: bytes,\n        model: str,\n        language: Optional[str],\n        response_format: str,\n        temperature: float,\n        stream: bool,\n        raw_request: Request,\n    ) -> Union[TranscriptionResponse, StreamingResponse, Response, ORJSONResponse]:\n        \"\"\"Main entry 
point for transcription requests.\"\"\"\n        # Calculate audio duration for usage reporting\n        audio_duration_s = self._get_audio_duration(audio_data)\n\n        # Build request\n        request = TranscriptionRequest(\n            audio_data=audio_data,\n            model=model,\n            language=language,\n            response_format=response_format,\n            temperature=temperature,\n            stream=stream,\n            audio_duration_s=audio_duration_s,\n        )\n\n        # Use the base class handle_request pattern\n        return await self.handle_request(request, raw_request)\n\n    async def _handle_non_streaming_request(\n        self,\n        adapted_request: GenerateReqInput,\n        request: TranscriptionRequest,\n        raw_request: Request,\n    ) -> Union[TranscriptionResponse, ErrorResponse, ORJSONResponse, Response]:\n        \"\"\"Handle non-streaming transcription request.\"\"\"\n        try:\n            ret = await self.tokenizer_manager.generate_request(\n                adapted_request, raw_request\n            ).__anext__()\n        except ValueError as e:\n            return self.create_error_response(str(e))\n\n        text = ret.get(\"text\", \"\")\n\n        # Build response based on format\n        if request.response_format == \"text\":\n            return Response(content=text, media_type=\"text/plain\")\n\n        # JSON format\n        usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s)))\n\n        return TranscriptionResponse(text=text, usage=usage)\n\n    async def _handle_streaming_request(\n        self,\n        adapted_request: GenerateReqInput,\n        request: TranscriptionRequest,\n        raw_request: Request,\n    ) -> StreamingResponse:\n        \"\"\"Handle streaming transcription request.\"\"\"\n        return StreamingResponse(\n            self._generate_transcription_stream(adapted_request, request, raw_request),\n            media_type=\"text/event-stream\",\n      
      background=self.tokenizer_manager.create_abort_task(adapted_request),\n        )\n\n    async def _generate_transcription_stream(\n        self,\n        adapted_request: GenerateReqInput,\n        request: TranscriptionRequest,\n        raw_request: Request,\n    ) -> AsyncGenerator[str, None]:\n        \"\"\"Generate streaming transcription response.\"\"\"\n        created_time = int(time.time())\n        request_id = f\"{self._request_id_prefix()}{uuid.uuid4().hex}\"\n        model = request.model\n        stream_buffer = \"\"\n\n        try:\n            async for content in self.tokenizer_manager.generate_request(\n                adapted_request, raw_request\n            ):\n                finish_reason = content[\"meta_info\"][\"finish_reason\"]\n                finish_reason_type = finish_reason[\"type\"] if finish_reason else None\n\n                # Calculate delta (new text since last chunk)\n                current_text = content.get(\"text\", \"\")\n                delta = current_text[len(stream_buffer) :]\n                stream_buffer = current_text\n\n                # Send content delta if there's new text\n                if delta:\n                    choice_data = TranscriptionStreamChoice(\n                        delta=DeltaMessage(content=delta),\n                        finish_reason=None,\n                    )\n                    chunk = TranscriptionStreamResponse(\n                        id=request_id,\n                        created=created_time,\n                        model=model,\n                        choices=[choice_data],\n                    )\n                    yield f\"data: {chunk.model_dump_json()}\\n\\n\"\n\n                # Send finish reason when done\n                if finish_reason_type:\n                    choice_data = TranscriptionStreamChoice(\n                        delta=DeltaMessage(),\n                        finish_reason=finish_reason_type,\n                    )\n                    chunk 
= TranscriptionStreamResponse(\n                        id=request_id,\n                        created=created_time,\n                        model=model,\n                        choices=[choice_data],\n                    )\n                    yield f\"data: {chunk.model_dump_json()}\\n\\n\"\n\n        except ValueError as e:\n            error = self.create_streaming_error_response(str(e))\n            yield f\"data: {error}\\n\\n\"\n\n        yield \"data: [DONE]\\n\\n\"\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/tool_server.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\nimport logging\nfrom abc import ABC, abstractmethod\nfrom contextlib import AbstractAsyncContextManager, asynccontextmanager\nfrom typing import Any\n\ntry:\n    from mcp import ClientSession\n    from mcp.client.sse import sse_client\n    from mcp.types import ListToolsResult\nexcept ImportError as e:\n    ClientSession = sse_client = ListToolsResult = e\n\nfrom openai_harmony import ToolDescription, ToolNamespaceConfig\n\nlogger = logging.getLogger(__name__)\n\n\nasync def list_server_and_tools(server_url: str):\n\n    async with sse_client(url=server_url) as streams, ClientSession(\n        *streams\n    ) as session:\n        initialize_response = await session.initialize()\n        list_tools_response = await session.list_tools()\n        return initialize_response, list_tools_response\n\n\ndef trim_schema(schema: dict) -> dict:\n    # Turn the MCP-generated JSON Schema into Harmony's variant.\n    if \"title\" in schema:\n        del schema[\"title\"]\n    if \"default\" in schema and schema[\"default\"] is None:\n        del schema[\"default\"]\n    if \"anyOf\" in schema:\n        # Turn \"anyOf\": [{\"type\": \"type-1\"}, {\"type\": \"type-2\"}]\n        # into \"type\": [\"type-1\", \"type-2\"],\n        # dropping any \"null\" type since Harmony just ignores it\n        types = [\n            type_dict[\"type\"]\n            for type_dict in schema[\"anyOf\"]\n            if type_dict[\"type\"] != \"null\"\n        ]\n        schema[\"type\"] = types\n        del schema[\"anyOf\"]\n    if \"properties\" in schema:\n        schema[\"properties\"] = {\n            k: trim_schema(v) for k, v in schema[\"properties\"].items()\n        }\n    return schema\n\n\ndef post_process_tools_description(\n    list_tools_result: \"ListToolsResult\",\n) -> \"ListToolsResult\":\n    # Adapt the MCP tool 
result for Harmony\n    for tool in list_tools_result.tools:\n        tool.inputSchema = trim_schema(tool.inputSchema)\n\n    # Some tools schema don't need to be part of the prompt (e.g. simple text\n    # in text out for Python)\n    list_tools_result.tools = [\n        tool\n        for tool in list_tools_result.tools\n        if getattr(tool.annotations, \"include_in_prompt\", True)\n    ]\n\n    return list_tools_result\n\n\nclass ToolServer(ABC):\n\n    @abstractmethod\n    def has_tool(self, tool_name: str):\n        pass\n\n    @abstractmethod\n    def get_tool_description(self, tool_name: str):\n        pass\n\n    @abstractmethod\n    def get_tool_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]: ...\n\n\nclass MCPToolServer(ToolServer):\n\n    def __init__(self):\n        self.harmony_tool_descriptions = {}\n\n    async def add_tool_server(self, server_url: str):\n        tool_urls = server_url.split(\",\")\n        self.harmony_tool_descriptions = {}\n        self.urls: dict[str, str] = {}\n        for url in tool_urls:\n            url = f\"http://{url}/sse\"\n            initialize_response, list_tools_response = await list_server_and_tools(url)\n\n            list_tools_response = post_process_tools_description(list_tools_response)\n\n            tool_from_mcp = ToolNamespaceConfig(\n                name=initialize_response.serverInfo.name,\n                description=initialize_response.instructions,\n                tools=[\n                    ToolDescription.new(\n                        name=tool.name,\n                        description=tool.description,\n                        parameters=tool.inputSchema,\n                    )\n                    for tool in list_tools_response.tools\n                ],\n            )\n            self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp\n            if tool_from_mcp.name not in self.urls:\n                self.urls[tool_from_mcp.name] = url\n            else:\n  
              logger.warning(\n                    \"Tool %s already exists. Ignoring duplicate tool server %s\",\n                    tool_from_mcp.name,\n                    url,\n                )\n\n    def has_tool(self, tool_name: str):\n        return tool_name in self.harmony_tool_descriptions\n\n    def get_tool_description(self, tool_name: str):\n        return self.harmony_tool_descriptions.get(tool_name)\n\n    @asynccontextmanager\n    async def get_tool_session(self, tool_name: str):\n        url = self.urls.get(tool_name)\n        if url:\n            async with sse_client(url=url) as streams, ClientSession(\n                *streams\n            ) as session:\n                await session.initialize()\n                yield session\n        else:\n            logger.warning(\"Tool %s not found\", tool_name)\n\n\nclass DemoToolServer(ToolServer):\n\n    def __init__(self):\n        from sglang.srt.entrypoints.tool import (\n            HarmonyBrowserTool,\n            HarmonyPythonTool,\n            Tool,\n        )\n\n        self.tools: dict[str, Tool] = {}\n        browser_tool = HarmonyBrowserTool()\n        if browser_tool.enabled:\n            self.tools[\"browser\"] = browser_tool\n        python_tool = HarmonyPythonTool()\n        if python_tool.enabled:\n            self.tools[\"python\"] = python_tool\n\n    def has_tool(self, tool_name: str):\n        return tool_name in self.tools\n\n    def get_tool_description(self, tool_name: str):\n        if tool_name not in self.tools:\n            return None\n        if tool_name == \"browser\":\n            return ToolNamespaceConfig.browser()\n        elif tool_name == \"python\":\n            return ToolNamespaceConfig.python()\n        else:\n            raise ValueError(f\"Unknown tool {tool_name}\")\n\n    @asynccontextmanager\n    async def get_tool_session(self, tool_name: str):\n        yield self.tools[tool_name]\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/usage_processor.py",
    "content": "from __future__ import annotations\n\nfrom typing import Any, Dict, List, Mapping, Optional, final\n\nfrom sglang.srt.entrypoints.openai.protocol import PromptTokensDetails, UsageInfo\n\n\n@final\nclass UsageProcessor:\n    \"\"\"Stateless helpers that turn raw token counts into a UsageInfo.\"\"\"\n\n    @staticmethod\n    def _details_if_cached(count: int) -> Optional[PromptTokensDetails]:\n        \"\"\"Return PromptTokensDetails only when count > 0 (keeps JSON slim).\"\"\"\n        return PromptTokensDetails(cached_tokens=count) if count > 0 else None\n\n    @staticmethod\n    def calculate_response_usage(\n        responses: List[Dict[str, Any]],\n        n_choices: int = 1,\n        enable_cache_report: bool = False,\n    ) -> UsageInfo:\n        completion_tokens = sum(\n            r[\"meta_info\"].get(\"completion_tokens\", 0) for r in responses\n        )\n\n        prompt_tokens = sum(\n            responses[i][\"meta_info\"].get(\"prompt_tokens\", 0)\n            for i in range(0, len(responses), n_choices)\n        )\n\n        cached_details = None\n        if enable_cache_report:\n            cached_total = sum(\n                responses[i][\"meta_info\"].get(\"cached_tokens\", 0)\n                for i in range(0, len(responses), n_choices)\n            )\n            cached_details = UsageProcessor._details_if_cached(cached_total)\n\n        return UsageProcessor.calculate_token_usage(\n            prompt_tokens=prompt_tokens,\n            completion_tokens=completion_tokens,\n            cached_tokens=cached_details,\n        )\n\n    @staticmethod\n    def calculate_streaming_usage(\n        prompt_tokens: Mapping[int, int],\n        completion_tokens: Mapping[int, int],\n        cached_tokens: Mapping[int, int],\n        n_choices: int,\n        enable_cache_report: bool = False,\n    ) -> UsageInfo:\n        # index % n_choices == 0 marks the first choice of a prompt\n        total_prompt_tokens = sum(\n            tok for idx, 
tok in prompt_tokens.items() if idx % n_choices == 0\n        )\n        total_completion_tokens = sum(completion_tokens.values())\n\n        cached_details = (\n            UsageProcessor._details_if_cached(\n                sum(tok for idx, tok in cached_tokens.items() if idx % n_choices == 0)\n            )\n            if enable_cache_report\n            else None\n        )\n\n        return UsageProcessor.calculate_token_usage(\n            prompt_tokens=total_prompt_tokens,\n            completion_tokens=total_completion_tokens,\n            cached_tokens=cached_details,\n        )\n\n    @staticmethod\n    def calculate_token_usage(\n        prompt_tokens: int,\n        completion_tokens: int,\n        cached_tokens: Optional[PromptTokensDetails] = None,\n    ) -> UsageInfo:\n        \"\"\"Calculate token usage information\"\"\"\n        return UsageInfo(\n            prompt_tokens=prompt_tokens,\n            completion_tokens=completion_tokens,\n            total_tokens=prompt_tokens + completion_tokens,\n            prompt_tokens_details=cached_tokens,\n        )\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/openai/utils.py",
    "content": "import logging\nfrom typing import Any, Dict, List, Optional, Union\n\nfrom sglang.srt.entrypoints.openai.protocol import (\n    CachedTokensDetails,\n    ChatCompletionRequest,\n    CompletionRequest,\n    LogProbs,\n)\n\nlogger = logging.getLogger(__name__)\n\n\ndef to_openai_style_logprobs(\n    input_token_logprobs=None,\n    output_token_logprobs=None,\n    input_top_logprobs=None,\n    output_top_logprobs=None,\n):\n    ret_logprobs = LogProbs()\n\n    def append_token_logprobs(token_logprobs):\n        for logprob, _, token_text in token_logprobs:\n            ret_logprobs.tokens.append(token_text)\n            ret_logprobs.token_logprobs.append(logprob)\n\n            # Not supported yet\n            ret_logprobs.text_offset.append(-1)\n\n    def append_top_logprobs(top_logprobs):\n        for tokens in top_logprobs:\n            if tokens is not None:\n                ret_logprobs.top_logprobs.append(\n                    {token[2]: token[0] for token in tokens}\n                )\n            else:\n                ret_logprobs.top_logprobs.append(None)\n\n    if input_token_logprobs is not None:\n        append_token_logprobs(input_token_logprobs)\n    if output_token_logprobs is not None:\n        append_token_logprobs(output_token_logprobs)\n    if input_top_logprobs is not None:\n        append_top_logprobs(input_top_logprobs)\n    if output_top_logprobs is not None:\n        append_top_logprobs(output_top_logprobs)\n\n    return ret_logprobs\n\n\ndef process_hidden_states_from_ret(\n    ret_item: Dict[str, Any],\n    request: Union[\n        ChatCompletionRequest,\n        CompletionRequest,\n    ],\n) -> Optional[List]:\n    \"\"\"Process hidden states from a ret item in non-streaming response.\n\n    Args:\n        ret_item: Response item containing meta_info\n        request: The original request object\n\n    Returns:\n        Processed hidden states for the last token, or None\n    \"\"\"\n    if not 
request.return_hidden_states:\n        return None\n\n    hidden_states = ret_item[\"meta_info\"].get(\"hidden_states\", None)\n    if hidden_states is not None:\n        hidden_states = hidden_states[-1] if len(hidden_states) > 1 else []\n    return hidden_states\n\n\ndef process_routed_experts_from_ret(\n    ret_item: Dict[str, Any],\n    request: Union[\n        ChatCompletionRequest,\n        CompletionRequest,\n    ],\n) -> Optional[str]:\n    \"\"\"Process routed experts from a ret item in non-streaming response.\"\"\"\n    if not getattr(request, \"return_routed_experts\", False):\n        return None\n    return ret_item[\"meta_info\"].get(\"routed_experts\", None)\n\n\ndef process_cached_tokens_details_from_ret(\n    ret_item: Dict[str, Any],\n    request: Union[\n        ChatCompletionRequest,\n        CompletionRequest,\n    ],\n) -> Optional[CachedTokensDetails]:\n    \"\"\"Process cached tokens details from a ret item in non-streaming response.\"\"\"\n    if not getattr(request, \"return_cached_tokens_details\", False):\n        return None\n\n    details = ret_item[\"meta_info\"].get(\"cached_tokens_details\", None)\n    if details is None:\n        return None\n\n    # Check if L3 storage fields are present\n    if \"storage\" in details:\n        return CachedTokensDetails(\n            device=details.get(\"device\", 0),\n            host=details.get(\"host\", 0),\n            storage=details.get(\"storage\", 0),\n            storage_backend=details.get(\"storage_backend\"),\n        )\n    else:\n        return CachedTokensDetails(\n            device=details.get(\"device\", 0),\n            host=details.get(\"host\", 0),\n        )\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/ssl_utils.py",
    "content": "\"\"\"Utilities for SSL certificate hot-reloading.\"\"\"\n\nimport asyncio\nimport logging\nimport ssl\nfrom typing import Optional\n\nfrom watchfiles import awatch\n\nlogger = logging.getLogger(__name__)\n\n\nclass SSLCertRefresher:\n    \"\"\"Monitors SSL certificate files and reloads them when changed.\n\n    Uses ``watchfiles.awatch()`` for efficient inotify/kqueue-based\n    file monitoring.  On change the referenced :class:`ssl.SSLContext`\n    is updated in-place so that new TLS connections automatically pick\n    up the fresh certificates.\n    \"\"\"\n\n    def __init__(\n        self,\n        ssl_context: ssl.SSLContext,\n        key_path: str,\n        cert_path: str,\n        ca_path: Optional[str] = None,\n    ) -> None:\n        self._ssl_context = ssl_context\n        self._key_path = key_path\n        self._cert_path = cert_path\n        self._ca_path = ca_path\n        self._tasks: list[asyncio.Task] = []\n\n        loop = asyncio.get_running_loop()\n        self._tasks.append(\n            loop.create_task(self._watch_cert_key(), name=\"ssl-cert-key-watcher\")\n        )\n        if self._ca_path:\n            self._tasks.append(\n                loop.create_task(self._watch_ca(), name=\"ssl-ca-watcher\")\n            )\n\n    async def _watch_cert_key(self) -> None:\n        \"\"\"Watch cert and key files and reload on change.\"\"\"\n        try:\n            async for _changes in awatch(self._cert_path, self._key_path):\n                logger.info(\n                    \"SSL cert/key file change detected, reloading: \" \"cert=%s key=%s\",\n                    self._cert_path,\n                    self._key_path,\n                )\n                try:\n                    self._ssl_context.load_cert_chain(self._cert_path, self._key_path)\n                    logger.info(\"SSL cert/key reloaded successfully.\")\n                except Exception:\n                    logger.exception(\n                        \"Failed to reload 
SSL cert/key — continuing with \"\n                        \"previous certificates.\"\n                    )\n        except asyncio.CancelledError:\n            return\n\n    async def _watch_ca(self) -> None:\n        \"\"\"Watch CA file and reload on change.\"\"\"\n        assert self._ca_path is not None\n        try:\n            async for _changes in awatch(self._ca_path):\n                logger.info(\n                    \"SSL CA file change detected, reloading: ca=%s\",\n                    self._ca_path,\n                )\n                try:\n                    self._ssl_context.load_verify_locations(self._ca_path)\n                    logger.info(\"SSL CA certificates reloaded successfully.\")\n                except Exception:\n                    logger.exception(\n                        \"Failed to reload SSL CA certificates — continuing \"\n                        \"with previous CA bundle.\"\n                    )\n        except asyncio.CancelledError:\n            return\n\n    def stop(self) -> None:\n        \"\"\"Cancel all watching tasks.\"\"\"\n        for task in self._tasks:\n            task.cancel()\n        self._tasks.clear()\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/tool.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\nimport logging\nimport os\nfrom abc import ABC, abstractmethod\nfrom typing import TYPE_CHECKING, Any\n\nfrom sglang.srt.utils import print_info_once, print_warning_once\n\nif TYPE_CHECKING:\n    # Avoid circular import.\n    from sglang.srt.entrypoints.context import ConversationContext\n\nlogger = logging.getLogger(__name__)\n\n\nclass Tool(ABC):\n\n    @abstractmethod\n    async def get_result(self, context: \"ConversationContext\") -> Any:\n        pass\n\n\nclass HarmonyBrowserTool(Tool):\n\n    def __init__(self):\n        self.enabled = True\n        exa_api_key = os.getenv(\"EXA_API_KEY\")\n        if not exa_api_key:\n            self.enabled = False\n            print_warning_once(\"EXA_API_KEY is not set, browsing is disabled\")\n            return\n\n        try:\n            from gpt_oss.tools.simple_browser import SimpleBrowserTool\n            from gpt_oss.tools.simple_browser.backend import ExaBackend\n        except ImportError:\n            self.enabled = False\n            print_warning_once(\"gpt_oss is not installed, browsing is disabled\")\n            return\n\n        browser_backend = ExaBackend(source=\"web\", api_key=exa_api_key)\n        self.browser_tool = SimpleBrowserTool(backend=browser_backend)\n        print_info_once(\"Browser tool initialized\")\n\n    async def get_result(self, context: \"ConversationContext\") -> Any:\n        from sglang.srt.entrypoints.context import HarmonyContext\n\n        assert isinstance(context, HarmonyContext)\n        last_msg = context.messages[-1]\n        tool_output_msgs = []\n        async for msg in self.browser_tool.process(last_msg):\n            tool_output_msgs.append(msg)\n        return tool_output_msgs\n\n    @property\n    def tool_config(self) -> Any:\n        return self.browser_tool.tool_config\n\n\nclass HarmonyPythonTool(Tool):\n\n    def __init__(self):\n        self.enabled = True\n\n        try:\n            from 
gpt_oss.tools.python_docker.docker_tool import PythonTool\n        except ImportError:\n            self.enabled = False\n            print_warning_once(\"gpt_oss is not installed, code interpreter is disabled\")\n            return\n\n        self.python_tool = PythonTool()\n        print_info_once(\"Code interpreter tool initialized\")\n\n    async def get_result(self, context: \"ConversationContext\") -> Any:\n        from sglang.srt.entrypoints.context import HarmonyContext\n\n        assert isinstance(context, HarmonyContext)\n        last_msg = context.messages[-1]\n        tool_output_msgs = []\n        async for msg in self.python_tool.process(last_msg):\n            tool_output_msgs.append(msg)\n        return tool_output_msgs\n\n    @property\n    def tool_config(self) -> Any:\n        return self.python_tool.tool_config\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/v1_loads.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"\n/v1/loads API endpoint for comprehensive load metrics.\n\nThis module provides the /v1/loads endpoint which returns detailed scheduler\nmetrics for load balancing, monitoring, and capacity planning.\n\"\"\"\n\nimport dataclasses\nfrom datetime import datetime, timezone\nfrom typing import Optional\n\nfrom fastapi import APIRouter, Depends, HTTPException\nfrom fastapi.responses import Response\n\nfrom sglang.srt.managers.io_struct import (\n    DisaggregationMetrics,\n    GetLoadsReqOutput,\n    LoRAMetrics,\n    MemoryMetrics,\n    QueueMetrics,\n    SpeculativeMetrics,\n)\nfrom sglang.version import __version__\n\nrouter = APIRouter()\n\n_OPTIONAL_METRIC_SECTIONS = {\n    \"memory\": (\"memory\", MemoryMetrics),\n    \"speculative\": (\"spec\", SpeculativeMetrics),\n    \"lora\": (\"lora\", LoRAMetrics),\n    \"disaggregation\": (\"disagg\", DisaggregationMetrics),\n    \"queues\": (\"queues\", QueueMetrics),\n}\n\n\ndef _get_tokenizer_manager():\n    \"\"\"Dependency to get tokenizer_manager from global state.\"\"\"\n    from sglang.srt.entrypoints.http_server import get_global_state\n\n    return get_global_state().tokenizer_manager\n\n\ndef _loads_dict_factory(items):\n    \"\"\"Factory for dataclasses.asdict() that excludes None values and timestamp.\"\"\"\n    return {k: v 
for k, v in items if v is not None and k != \"timestamp\"}\n\n\ndef _compute_aggregate(load_dicts: list) -> dict:\n    \"\"\"Compute aggregate metrics from load dicts.\"\"\"\n    if not load_dicts:\n        return {\n            \"total_running_reqs\": 0,\n            \"total_waiting_reqs\": 0,\n            \"total_reqs\": 0,\n            \"avg_token_usage\": 0.0,\n            \"avg_throughput\": 0.0,\n            \"avg_utilization\": 0.0,\n        }\n\n    n = len(load_dicts)\n    return {\n        \"total_running_reqs\": sum(d[\"num_running_reqs\"] for d in load_dicts),\n        \"total_waiting_reqs\": sum(d[\"num_waiting_reqs\"] for d in load_dicts),\n        \"total_reqs\": sum(\n            d[\"num_running_reqs\"] + d[\"num_waiting_reqs\"] for d in load_dicts\n        ),\n        \"avg_token_usage\": round(sum(d[\"token_usage\"] for d in load_dicts) / n, 4),\n        \"avg_throughput\": round(sum(d[\"gen_throughput\"] for d in load_dicts) / n, 2),\n        \"avg_utilization\": round(sum(d[\"utilization\"] for d in load_dicts) / n, 4),\n    }\n\n\ndef _format_loads_prometheus(load_results) -> Response:\n    \"\"\"Format load metrics in Prometheus text exposition format.\n\n    Metrics are derived from dataclass field metadata, providing a single source of truth.\n    \"\"\"\n    lines = []\n\n    for f in dataclasses.fields(GetLoadsReqOutput):\n        if \"metric\" not in f.metadata:\n            continue\n        metric_type, description = f.metadata[\"metric\"]\n        metric_name = f\"sglang_{f.name}\"\n        lines.append(f\"# HELP {metric_name} {description}\")\n        lines.append(f\"# TYPE {metric_name} {metric_type}\")\n        for load in load_results:\n            value = getattr(load, f.name, None)\n            if value is not None:\n                lines.append(f'{metric_name}{{dp_rank=\"{load.dp_rank}\"}} {value}')\n\n    for attr_name, (prefix, dataclass_type) in _OPTIONAL_METRIC_SECTIONS.items():\n        if not any(getattr(load, attr_name, 
None) for load in load_results):\n            continue\n        for f in dataclasses.fields(dataclass_type):\n            if \"metric\" not in f.metadata:\n                continue\n            metric_type, description = f.metadata[\"metric\"]\n            metric_name = f\"sglang_{prefix}_{f.name}\"\n            lines.append(f\"# HELP {metric_name} {description}\")\n            lines.append(f\"# TYPE {metric_name} {metric_type}\")\n            for load in load_results:\n                section = getattr(load, attr_name, None)\n                if section:\n                    value = getattr(section, f.name, None)\n                    if value is not None:\n                        lines.append(\n                            f'{metric_name}{{dp_rank=\"{load.dp_rank}\"}} {value}'\n                        )\n\n    return Response(\n        content=\"\\n\".join(lines) + \"\\n\",\n        media_type=\"text/plain; version=0.0.4; charset=utf-8\",\n    )\n\n\n@router.get(\"/v1/loads\")\nasync def get_loads(\n    dp_rank: Optional[int] = None,\n    include: Optional[str] = None,\n    format: Optional[str] = None,\n    tokenizer_manager=Depends(_get_tokenizer_manager),\n):\n    \"\"\"\n    Get comprehensive load metrics for all DP ranks.\n\n    Query Parameters:\n        dp_rank: Filter to specific DP rank (optional)\n        include: Comma-separated sections to include (optional)\n                 Options: core, memory, spec, lora, disagg, queues, all\n                 Default: all\n        format: Response format - 'json' (default) or 'prometheus'\n\n    Returns:\n        JSON response with timestamp, version, dp_rank_count, per-DP-rank loads, and aggregates\n    \"\"\"\n    include_list = [s.strip() for s in include.split(\",\")] if include else None\n\n    try:\n        load_results = await tokenizer_manager.get_loads(\n            include=include_list,\n            dp_rank=dp_rank,\n        )\n    except ValueError as e:\n        raise HTTPException(status_code=400, 
detail=str(e))\n\n    if format == \"prometheus\":\n        return _format_loads_prometheus(load_results)\n\n    loads = []\n    for load in load_results:\n        d = dataclasses.asdict(load, dict_factory=_loads_dict_factory)\n        d[\"num_total_reqs\"] = d[\"num_running_reqs\"] + d[\"num_waiting_reqs\"]\n        loads.append(d)\n\n    return {\n        \"timestamp\": datetime.now(timezone.utc).isoformat(),\n        \"version\": __version__,\n        \"dp_rank_count\": len(loads),\n        \"loads\": loads,\n        \"aggregate\": _compute_aggregate(loads),\n    }\n"
  },
  {
    "path": "python/sglang/srt/entrypoints/warmup.py",
    "content": "from __future__ import annotations\n\nimport logging\nfrom typing import TYPE_CHECKING, List\n\nimport numpy as np\nimport tqdm\n\nfrom sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST\nfrom sglang.srt.managers.io_struct import GenerateReqInput\n\nif TYPE_CHECKING:\n    from sglang.srt.managers.tokenizer_manager import TokenizerManager\n\nlogger = logging.getLogger(__file__)\n\n_warmup_registry = {}\n\n\ndef warmup(name: str):\n    def decorator(fn):\n        _warmup_registry[name] = fn\n        return fn\n\n    return decorator\n\n\nasync def execute_warmups(\n    disaggregation_mode: str,\n    warmup_names: List[str],\n    tokenizer_manager: TokenizerManager,\n):\n    for warmup_name in warmup_names:\n        if warmup_name not in _warmup_registry:\n            logger.warning(f\"Could not find custom warmup {warmup_name}\")\n            continue\n        logger.info(f\"Running warmup {warmup_name}\")\n        await _warmup_registry[warmup_name](disaggregation_mode, tokenizer_manager)\n\n\n@warmup(\"voice_chat\")\nasync def voice_chat(disaggregation_mode: str, tokenizer_manager: TokenizerManager):\n    # this warms up the fused_moe triton kernels and caches them\n    # if we don't do this we break real time inference for voice chat\n    for i in tqdm.trange(1, 512):\n        size = i * 4\n        generate_req_input = GenerateReqInput(\n            input_ids=(np.random.randint(2**16, size=[size])).tolist(),\n            sampling_params={\n                \"max_new_tokens\": 30,\n                \"temperature\": 0.8,\n                \"stop_token_ids\": [1],\n                \"min_p\": 0.0,\n            },\n        )\n        if disaggregation_mode != \"null\":\n            generate_req_input.bootstrap_room = 0\n            generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST\n\n        await tokenizer_manager.generate_request(generate_req_input, None).__anext__()\n"
  },
  {
    "path": "python/sglang/srt/environ.py",
    "content": "import os\nimport subprocess\nimport warnings\nfrom contextlib import ExitStack, contextmanager\nfrom enum import IntEnum\nfrom typing import Any\n\n\n@contextmanager\ndef temp_set_env(*, allow_sglang: bool = False, **env_vars: Any):\n    \"\"\"Temporarily set environment variables, restoring originals on exit.\n\n    By default, SGLANG_*/SGL_* keys are rejected — use ``Envs`` descriptors\n    for those.  Pass ``allow_sglang=True`` only for special env vars that\n    intentionally bypass ``environ.py``.\n    \"\"\"\n    if not allow_sglang:\n        for key in env_vars:\n            if key.startswith(\"SGLANG_\") or key.startswith(\"SGL_\"):\n                raise ValueError(\"temp_set_env should not be used for sglang env vars\")\n\n    backup = {key: os.environ.get(key) for key in env_vars}\n    try:\n        for key, value in env_vars.items():\n            if value is None:\n                os.environ.pop(key, None)\n            else:\n                os.environ[key] = str(value)\n        yield\n    finally:\n        for key, value in backup.items():\n            if value is None:\n                os.environ.pop(key, None)\n            else:\n                os.environ[key] = value\n\n\nclass EnvField:\n    _allow_set_name = True\n\n    def __init__(self, default: Any):\n        self.default = default\n        # NOTE: environ can only accept str values, so we need a flag to indicate\n        # whether the env var is explicitly set to None.\n        self._set_to_none = False\n\n    def __set_name__(self, owner, name):\n        assert EnvField._allow_set_name, \"Usage like `a = envs.A` is not allowed\"\n        self.name = name\n\n    def parse(self, value: str) -> Any:\n        raise NotImplementedError()\n\n    def get(self) -> Any:\n        value = os.getenv(self.name)\n\n        # Explicitly set to None\n        if self._set_to_none:\n            assert value == str(None)\n            return None\n\n        # Not set, return default\n        if 
value is None:\n            return self.default\n\n        try:\n            return self.parse(value)\n        except ValueError as e:\n            warnings.warn(\n                f'Invalid value for {self.name}: {e}, using default \"{self.default}\"'\n            )\n            return self.default\n\n    def is_set(self):\n        return self.name in os.environ\n\n    def set(self, value: Any):\n        self._set_to_none = value is None\n        os.environ[self.name] = str(value)\n\n    @contextmanager\n    def override(self, value: Any):\n        backup_present = self.name in os.environ\n        backup_value = os.environ.get(self.name)\n        backup_set_to_none = self._set_to_none\n        self.set(value)\n        yield\n        if backup_present:\n            os.environ[self.name] = backup_value\n        else:\n            os.environ.pop(self.name, None)\n        self._set_to_none = backup_set_to_none\n\n    def clear(self):\n        os.environ.pop(self.name, None)\n        self._set_to_none = False\n\n    def __bool__(self):\n        raise RuntimeError(\n            \"Please use `envs.YOUR_FLAG.get()` instead of `envs.YOUR_FLAG`\"\n        )\n\n    def __len__(self):\n        raise RuntimeError(\n            \"Please use `envs.YOUR_FLAG.get()` instead of `envs.YOUR_FLAG`\"\n        )\n\n\nclass EnvTuple(EnvField):\n    def parse(self, value: str) -> tuple[str, ...]:\n        return tuple(s.strip() for s in value.split(\",\") if s.strip())\n\n\nclass EnvStr(EnvField):\n    def parse(self, value: str) -> str:\n        return value\n\n\nclass EnvBool(EnvField):\n    def parse(self, value: str) -> bool:\n        value = value.lower()\n        if value in [\"true\", \"1\", \"yes\", \"y\"]:\n            return True\n        if value in [\"false\", \"0\", \"no\", \"n\"]:\n            return False\n        raise ValueError(f'\"{value}\" is not a valid boolean value')\n\n\nclass EnvInt(EnvField):\n    def parse(self, value: str) -> int:\n        try:\n            
return int(value)\n        except ValueError:\n            raise ValueError(f'\"{value}\" is not a valid integer value')\n\n\nclass EnvFloat(EnvField):\n    def parse(self, value: str) -> float:\n        try:\n            return float(value)\n        except ValueError:\n            raise ValueError(f'\"{value}\" is not a valid float value')\n\n\nclass ToolStrictLevel(IntEnum):\n    \"\"\"\n    Defines the strictness levels for tool call parsing and validation.\n\n    OFF: No strict validation\n    FUNCTION: Enables structural tag constraints for all tools\n    PARAMETER: Enforces strict parameter validation for all tools\n    \"\"\"\n\n    OFF = 0\n    FUNCTION = 1\n    PARAMETER = 2\n\n\nclass Envs:\n    # fmt: off\n\n    # Model & File Download\n    SGLANG_USE_MODELSCOPE = EnvBool(False)\n    SGLANG_SORT_WEIGHT_FILES = EnvBool(False)\n    SGLANG_DISABLED_MODEL_ARCHS = EnvTuple(tuple())\n\n    # Logging Options\n    SGLANG_LOG_GC = EnvBool(False)\n    SGLANG_LOG_FORWARD_ITERS = EnvBool(False)\n    SGLANG_LOG_MS = EnvBool(False)\n    SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)\n    SGLANG_LOG_REQUEST_EXCEEDED_MS = EnvInt(-1)\n    SGLANG_LOG_REQUEST_HEADERS = EnvTuple(tuple())\n    SGLANG_LOG_SCHEDULER_STATUS_TARGET = EnvStr(\"\")\n    SGLANG_LOG_SCHEDULER_STATUS_INTERVAL = EnvFloat(60.0)\n\n    # SGLang CI\n    SGLANG_IS_IN_CI = EnvBool(False)\n    SGLANG_IS_IN_CI_AMD = EnvBool(False)\n    SGLANG_CUDA_COREDUMP = EnvBool(False)\n    SGLANG_CUDA_COREDUMP_DIR = EnvStr(\"/tmp/sglang_cuda_coredumps\")\n    SGLANG_TEST_MAX_RETRY = EnvInt(None)\n\n    # Constrained Decoding (Grammar)\n    SGLANG_GRAMMAR_POLL_INTERVAL = EnvFloat(0.005)\n    SGLANG_GRAMMAR_MAX_POLL_ITERATIONS = EnvInt(10000)\n    SGLANG_DISABLE_OUTLINES_DISK_CACHE = EnvBool(False)\n\n\n    # Test & Debug\n    SGLANG_DETECT_SLOW_RANK = EnvBool(False)\n    SGLANG_TEST_STUCK_DETOKENIZER = EnvFloat(0)\n    SGLANG_TEST_STUCK_DP_CONTROLLER = EnvFloat(0)\n    SGLANG_TEST_STUCK_SCHEDULER_INIT = EnvFloat(0)\n    
SGLANG_TEST_STUCK_TOKENIZER = EnvFloat(0)\n    SGLANG_TEST_CRASH_AFTER_STREAM_OUTPUTS = EnvInt(0)\n    IS_BLACKWELL = EnvBool(False)\n    IS_H200 = EnvBool(False)\n    SGLANG_SET_CPU_AFFINITY = EnvBool(False)\n    SGLANG_PROFILE_WITH_STACK = EnvBool(True)\n    SGLANG_PROFILE_RECORD_SHAPES = EnvBool(True)\n    SGLANG_PROFILE_V2 = EnvBool(False)\n    SGLANG_RECORD_STEP_TIME = EnvBool(False)\n    SGLANG_FORCE_SHUTDOWN = EnvBool(False)\n    SGLANG_DEBUG_MEMORY_POOL = EnvBool(False)\n    SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)\n    SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)\n    SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)\n    SGLANG_SIMULATE_ACC_METHOD = EnvStr(\"multinomial\")\n    SGLANG_TORCH_PROFILER_DIR = EnvStr(\"/tmp\")\n    SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS = EnvInt(500)\n    SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE = EnvInt(64)\n    SGLANG_NATIVE_MOVE_KV_CACHE = EnvBool(False)\n    SGLANG_ENABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(True)\n\n    # Scheduler: memory leak test\n    SGLANG_TEST_RETRACT = EnvBool(False)\n    SGLANG_TEST_RETRACT_INTERVAL = EnvInt(3)\n    SGLANG_TEST_RETRACT_NO_PREFILL_BS = EnvInt(2 ** 31)\n    SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY = EnvInt(0)\n    SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_IDLE = EnvBool(True)\n\n    # Scheduler: new token ratio hyperparameters\n    SGLANG_INIT_NEW_TOKEN_RATIO = EnvFloat(0.7)\n    SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR = EnvFloat(0.14)\n    SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS = EnvInt(600)\n    SGLANG_RETRACT_DECODE_STEPS = EnvInt(20)\n    SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION = EnvInt(4096)\n\n    # Scheduler: recv interval\n    SGLANG_SCHEDULER_RECV_SKIPPER_WEIGHT_DEFAULT = EnvInt(1000)\n    SGLANG_SCHEDULER_RECV_SKIPPER_WEIGHT_DECODE = EnvInt(1)\n    SGLANG_SCHEDULER_RECV_SKIPPER_WEIGHT_TARGET_VERIFY = EnvInt(1)\n    SGLANG_SCHEDULER_RECV_SKIPPER_WEIGHT_NONE = EnvInt(1)\n\n    # PD Disaggregation (runtime)\n    # NOTE: For SGLANG_DISAGGREGATION_THREAD_POOL_SIZE, the 
effective default is\n    # computed dynamically at runtime based on cpu_count; see disaggregation backends.\n    SGLANG_DISAGGREGATION_THREAD_POOL_SIZE = EnvInt(None)\n    SGLANG_DISAGGREGATION_QUEUE_SIZE = EnvInt(4)\n    SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT = EnvInt(300)\n    SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL = EnvFloat(5.0)\n    SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE = EnvInt(2)\n    SGLANG_DISAGGREGATION_WAITING_TIMEOUT = EnvInt(300)\n    SGLANG_DISAGGREGATION_NIXL_BACKEND = EnvStr(\"UCX\")\n    SGLANG_DISAGGREGATION_ALL_CP_RANKS_TRANSFER = EnvBool(False)\n\n    # Scheduler: others:\n    SGLANG_EMPTY_CACHE_INTERVAL = EnvFloat(-1)  # in seconds. Set if you observe high memory accumulation over a long serving period.\n    SGLANG_DISABLE_CONSECUTIVE_PREFILL_OVERLAP = EnvBool(False)\n    SGLANG_SCHEDULER_MAX_RECV_PER_POLL = EnvInt(-1)\n    SGLANG_EXPERIMENTAL_CPP_RADIX_TREE = EnvBool(False)\n    SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR = EnvFloat(0.75)\n    SGLANG_SCHEDULER_SKIP_ALL_GATHER = EnvBool(False)\n    SGLANG_SCHEDULER_DECREASE_PREFILL_IDLE = EnvBool(False)\n    SGLANG_PREFILL_DELAYER_MAX_DELAY_PASSES = EnvInt(None)\n    SGLANG_PREFILL_DELAYER_TOKEN_USAGE_LOW_WATERMARK = EnvFloat(None)\n    SGLANG_DATA_PARALLEL_BUDGET_INTERVAL = EnvInt(1)\n    SGLANG_REQ_WAITING_TIMEOUT = EnvFloat(-1)  # in seconds\n    SGLANG_NCCL_ALL_GATHER_IN_OVERLAP_SCHEDULER_SYNC_BATCH = EnvBool(False)\n    SGLANG_REQ_RUNNING_TIMEOUT = EnvFloat(-1)  # in seconds\n    SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL = EnvInt(120)\n\n    # Test: pd-disaggregation\n    SGLANG_TEST_PD_DISAGG_BACKEND = EnvStr(\"mooncake\")\n    SGLANG_TEST_PD_DISAGG_DEVICES = EnvStr(None)\n\n    # Model Parallel\n    SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)\n    SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS = EnvBool(False)\n    # Override the distributed init method used by torch.distributed.init_process_group.\n    # Set to \"env://\" to use an externally-created TCPStore via 
MASTER_ADDR/MASTER_PORT.\n    SGLANG_DISTRIBUTED_INIT_METHOD_OVERRIDE = EnvStr(None)\n    SGLANG_TCP_STORE_PORT = EnvInt(29600)\n\n    # Tool Calling\n    SGLANG_FORWARD_UNKNOWN_TOOLS = EnvBool(False)\n\n    # Hi-Cache\n    SGLANG_HICACHE_HF3FS_CONFIG_PATH = EnvStr(None)\n    SGLANG_HICACHE_DECODE_OFFLOAD_STRIDE = EnvInt(None)\n    SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR = EnvStr(None)\n    SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR = EnvStr(None)\n    # Max fraction of cache (by token count) that can be pinned; 0 = disable pinning.\n    SGLANG_HICACHE_MAX_PINNED_RATIO = EnvFloat(0.0)\n\n    # Mooncake KV Transfer\n    SGLANG_MOONCAKE_CUSTOM_MEM_POOL = EnvStr(None)\n    ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE = EnvBool(False)\n    ASCEND_NPU_PHY_ID = EnvInt(-1)\n    SGLANG_MOONCAKE_SEND_AUX_TCP = EnvBool(False)\n\n    # Mooncake Store\n    SGLANG_HICACHE_MOONCAKE_CONFIG_PATH = EnvStr(None)\n    SGLANG_HICACHE_MOONCAKE_REUSE_TE = EnvBool(True)\n    MOONCAKE_MASTER = EnvStr(None)\n    MOONCAKE_CLIENT = EnvStr(None)\n    MOONCAKE_LOCAL_HOSTNAME = EnvStr(\"localhost\")\n    MOONCAKE_TE_META_DATA_SERVER = EnvStr(\"P2PHANDSHAKE\")\n    MOONCAKE_GLOBAL_SEGMENT_SIZE = EnvStr(\"4gb\")\n    MOONCAKE_PROTOCOL = EnvStr(\"tcp\")\n    MOONCAKE_DEVICE = EnvStr(\"\")\n    MOONCAKE_MASTER_METRICS_PORT = EnvInt(9003)\n    MOONCAKE_CHECK_SERVER = EnvBool(False)\n    MOONCAKE_STANDALONE_STORAGE = EnvBool(False)\n\n    # AMD & ROCm\n    SGLANG_USE_AITER = EnvBool(False)\n    SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)\n    SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False)\n\n    # MPS (Apple Silicon)\n    SGLANG_USE_MLX = EnvBool(False)\n\n    # NPU\n    SGLANG_NPU_DISABLE_ACL_FORMAT_WEIGHT = EnvBool(False)\n    SGLANG_NPU_USE_MULTI_STREAM = EnvBool(False)\n    SGLANG_NPU_USE_MLAPO = EnvBool(False)\n    # Forward native implementation for activation gelu tanh for model Skywork-Reward-Gemma-2-27B-v0.2\n    SGLANG_NPU_FORWARD_NATIVE_GELUTANH = EnvBool(False)\n    # Forward native 
implementation for gemma rms norm for model Skywork-Reward-Gemma-2-27B-v0.2\n    SGLANG_NPU_FORWARD_NATIVE_GEMMA_RMS_NORM = EnvBool(False)\n    # Delay all-gather after qlora for better performance for Deepseek v3.2\n    SGLANG_USE_AG_AFTER_QLORA = EnvBool(False)\n    SGLANG_NPU_FUSED_MOE_MODE = EnvInt(1)\n\n    # Quantization\n    SGLANG_INT4_WEIGHT = EnvBool(False)\n    SGLANG_CPU_QUANTIZATION = EnvBool(False)\n    SGLANG_USE_DYNAMIC_MXFP4_LINEAR = EnvBool(False)\n    SGLANG_FORCE_FP8_MARLIN = EnvBool(False)\n    SGLANG_MOE_NVFP4_DISPATCH = EnvBool(False)\n    SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN = EnvBool(False)\n    SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2 = EnvBool(False)\n    SGLANG_NVFP4_CKPT_FP8_NEXTN_MOE = EnvBool(False)\n    SGLANG_QUANT_ALLOW_DOWNCASTING = EnvBool(False)\n    SGLANG_FP8_IGNORED_LAYERS = EnvStr(\"\")\n\n    # Flashinfer\n    SGLANG_IS_FLASHINFER_AVAILABLE = EnvBool(True)\n    SGLANG_ENABLE_FLASHINFER_FP8_GEMM = EnvBool(False)\n    # Default to the pick from flashinfer\n    SGLANG_FLASHINFER_FP4_GEMM_BACKEND = EnvStr(\"\")\n    SGLANG_FLASHINFER_WORKSPACE_SIZE = EnvInt(384 * 1024 * 1024)\n\n    # Triton\n    SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS = EnvBool(False)\n    SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE = EnvBool(False)\n\n    # Torch Compile\n    SGLANG_ENABLE_TORCH_COMPILE = EnvBool(False)\n\n    # EPLB\n    SGLANG_EXPERT_LOCATION_UPDATER_LOG_INPUT = EnvBool(False)\n    SGLANG_EXPERT_LOCATION_UPDATER_CANARY = EnvBool(False)\n    SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS = EnvBool(False)\n    SGLANG_LOG_EXPERT_LOCATION_METADATA = EnvBool(False)\n    SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR = EnvStr(\"/tmp\")\n    SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL = EnvInt(0)\n    SGLANG_ENABLE_EPLB_BALANCEDNESS_METRIC = EnvBool(False)\n\n    # TBO\n    SGLANG_TBO_DEBUG = EnvBool(False)\n\n    # DeepGemm\n    SGLANG_ENABLE_JIT_DEEPGEMM = EnvBool(True)\n    SGLANG_JIT_DEEPGEMM_PRECOMPILE = EnvBool(True)\n    SGLANG_JIT_DEEPGEMM_FAST_WARMUP = 
EnvBool(False)\n    SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS = EnvInt(4)\n    SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE = EnvBool(False)\n    SGLANG_DG_CACHE_DIR = EnvStr(os.path.expanduser(\"~/.cache/deep_gemm\"))\n    SGLANG_DG_USE_NVRTC = EnvBool(False)\n    SGLANG_USE_DEEPGEMM_BMM = EnvBool(False)\n\n    # DeepSeek MHA Optimization\n    SGLANG_CHUNKED_PREFIX_CACHE_THRESHOLD = EnvInt(8192)\n\n    # DeepEP\n    SGLANG_DEEPEP_BF16_DISPATCH = EnvBool(False)\n    SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK = EnvInt(128)\n    SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS = EnvInt(32)\n    SGLANG_BLACKWELL_OVERLAP_SHARED_EXPERTS_OUTSIDE_SBO = EnvBool(False)\n\n    # NIXL-EP\n    SGLANG_NIXL_EP_BF16_DISPATCH = EnvBool(False)\n    SGLANG_NIXL_EP_NUM_MAX_DISPATCH_TOKENS_PER_RANK = EnvInt(128)\n\n    # NSA Backend\n    SGLANG_NSA_FUSE_TOPK = EnvBool(True)\n    SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA = EnvBool(True)\n    SGLANG_USE_FUSED_METADATA_COPY = EnvBool(True)\n    SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD = EnvInt(2048)\n\n    # sgl-kernel\n    SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK = EnvBool(False)\n\n    # vLLM dependencies (TODO: they have been deprecated, we can remove them safely)\n    USE_VLLM_CUTLASS_W8A8_FP8_KERNEL = EnvBool(False)\n\n    USE_TRITON_W8A8_FP8_KERNEL = EnvBool(False)\n    SGLANG_RETURN_ORIGINAL_LOGPROB = EnvBool(False)\n    SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN = EnvBool(False)\n    SGLANG_MOE_PADDING = EnvBool(False)\n    SGLANG_CUTLASS_MOE = EnvBool(False)\n    HF_HUB_DISABLE_XET = EnvBool(False)\n    DISABLE_OPENAPI_DOC = EnvBool(False)\n    SGLANG_ENABLE_TORCH_INFERENCE_MODE = EnvBool(False)\n    SGLANG_IS_FIRST_RANK_ON_NODE = EnvBool(True)\n    SGLANG_SUPPORT_CUTLASS_BLOCK_FP8 = EnvBool(False)\n    SGLANG_SYNC_TOKEN_IDS_ACROSS_TP = EnvBool(False)\n    SGLANG_ENABLE_COLOCATED_BATCH_GEN = EnvBool(False)\n\n    # Deterministic inference\n    SGLANG_ENABLE_DETERMINISTIC_INFERENCE = EnvBool(False)\n    # Use 1-stage all-reduce kernel on AMD 
(deterministic, fixed accumulation order)\n    # If not set: auto (enabled when --enable-deterministic-inference is on)\n    # Set to 1: force enable (even without --enable-deterministic-inference)\n    # Set to 0: force disable (use default Aiter AR even with --enable-deterministic-inference)\n    SGLANG_USE_1STAGE_ALLREDUCE = EnvBool(False)\n    SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE = EnvInt(4096)\n    SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE = EnvInt(2048)\n    SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)\n    SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)\n\n    # RoPE cache configuration\n    SGLANG_SPEC_EXPANSION_SAFETY_FACTOR = EnvInt(2)\n    SGLANG_ROPE_CACHE_SAFETY_MARGIN = EnvInt(256)\n    SGLANG_ROPE_CACHE_ALIGN = EnvInt(128)\n\n    # Overlap Spec V2\n    SGLANG_ENABLE_SPEC_V2 = EnvBool(False)\n    SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)\n\n    # Spec Config\n    SGLANG_SPEC_ENABLE_STRICT_FILTER_CHECK = EnvBool(True)\n    SGLANG_SPEC_NAN_DETECTION = EnvBool(False)\n    SGLANG_SPEC_OOB_DETECTION = EnvBool(False)\n\n    # VLM\n    SGLANG_VLM_CACHE_SIZE_MB = EnvInt(100)\n    SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)\n    SGLANG_RESIZE_RESAMPLE = EnvStr(\"\")\n    SGLANG_MM_BUFFER_SIZE_MB = EnvInt(0)\n    SGLANG_MM_PRECOMPUTE_HASH = EnvBool(False)\n    SGLANG_VIT_ENABLE_CUDA_GRAPH = EnvBool(False)\n    SGLANG_MM_SKIP_COMPUTE_HASH = EnvBool(False)\n\n\n    # VLM Item CUDA IPC Transport\n    SGLANG_USE_CUDA_IPC_TRANSPORT = EnvBool(False)\n    SGLANG_MM_FEATURE_CACHE_MB = EnvInt(4 * 1024)\n    SGLANG_MM_ITEM_MEM_POOL_RECYCLE_INTERVAL_SEC = EnvFloat(0.05)\n\n    # MM splitting behavior control\n    SGLANG_ENABLE_MM_SPLITTING = EnvBool(False)\n\n    # Mamba\n    SGLANG_MAMBA_CONV_DTYPE = EnvStr(\"bfloat16\")\n    SGLANG_MAMBA_SSM_DTYPE = EnvStr(None)\n\n    # Release & Resume Memory\n    SGLANG_MEMORY_SAVER_CUDA_GRAPH = EnvBool(False)\n\n    # Sparse Embeddings\n    SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None)\n\n    # 
Logits processor\n    SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK = EnvBool(False)\n    SGLANG_LOGITS_PROCESSER_CHUNK_SIZE = EnvInt(2048)\n\n    # Tool-Call behavior\n    SGLANG_TOOL_STRICT_LEVEL = EnvInt(ToolStrictLevel.OFF)\n\n    # Ngram\n    SGLANG_NGRAM_FORCE_GREEDY_VERIFY = EnvBool(False)\n\n    # Warmup\n    SGLANG_WARMUP_TIMEOUT = EnvFloat(-1) # in seconds. If a warmup forward batch takes longer than this, the server will crash to prevent hanging. Recommend to increase warmup timeout to 1800 to accommodate some kernel JIT precache e.g. deep gemm\n\n    # HTTP Server\n    SGLANG_TIMEOUT_KEEP_ALIVE = EnvInt(5)\n\n    # Health Check\n    SGLANG_ENABLE_HEALTH_ENDPOINT_GENERATION = EnvBool(True)\n\n    # Encoder gRPC\n    SGLANG_ENCODER_GRPC_TIMEOUT_SECS = EnvInt(60)\n    # Encoder receiver selection: http|grpc (used by EPD paths).\n    SGLANG_ENCODER_MM_RECEIVER_MODE = EnvStr(\"http\")\n\n    # External models\n    SGLANG_EXTERNAL_MODEL_PACKAGE = EnvStr(\"\")\n    SGLANG_EXTERNAL_MM_MODEL_ARCH = EnvStr(\"\")\n    SGLANG_EXTERNAL_MM_PROCESSOR_PACKAGE = EnvStr(\"\")\n\n    # Numa\n    SGLANG_NUMA_BIND_V2 = EnvBool(True)\n    SGLANG_AUTO_NUMA_BIND = EnvBool(False)\n\n    # Metrics\n    SGLANG_ENABLE_METRICS_DEVICE_TIMER = EnvBool(False)\n    SGLANG_ENABLE_METRICS_DP_ATTENTION = EnvBool(False)\n\n    # Tokenizer\n    SGLANG_PATCH_TOKENIZER = EnvBool(False)  # TODO enable by default\n\n    # TokenizerManager\n    SGLANG_REQUEST_STATE_WAIT_TIMEOUT = EnvInt(4)\n\n    # Symmetric Memory\n    SGLANG_SYMM_MEM_PREALLOC_GB_SIZE = EnvInt(-1)\n\n    # Aiter\n    SGLANG_USE_AITER_FP8_PER_TOKEN = EnvBool(False)\n    # fmt: on\n\n    # EPD\n    SGLANG_ENCODER_RECV_TIMEOUT = EnvFloat(180.0)\n    SGLANG_ENCODER_SEND_TIMEOUT = EnvFloat(180.0)\n    SGLANG_ENCODER_DISPATCH_MIN_ITEMS = EnvInt(2)\n\n    # Elastic EP Backup Port\n    SGLANG_BACKUP_PORT_BASE = EnvInt(10000)\n\n\nenvs = Envs()\nEnvField._allow_set_name = False\n\n\ndef _print_deprecated_env(new_name: str, old_name: str):\n    
if old_name in os.environ:\n        warnings.warn(\n            f\"Environment variable {old_name} will be deprecated, please use {new_name} instead\"\n        )\n        os.environ[new_name] = os.environ[old_name]\n\n\ndef _warn_deprecated_env_to_cli_flag(env_name: str, suggestion: str):\n    \"\"\"Warn when a deprecated environment variable is used.\n\n    This is for env vars that are deprecated in favor of CLI flags.\n    \"\"\"\n    if env_name in os.environ:\n        warnings.warn(f\"Environment variable {env_name} is deprecated. {suggestion}\")\n\n\ndef _convert_SGL_to_SGLANG():\n    _print_deprecated_env(\"SGLANG_LOG_GC\", \"SGLANG_GC_LOG\")\n    _print_deprecated_env(\n        \"SGLANG_ENABLE_FLASHINFER_FP8_GEMM\", \"SGLANG_ENABLE_FLASHINFER_GEMM\"\n    )\n    _print_deprecated_env(\n        \"SGLANG_MOE_NVFP4_DISPATCH\", \"SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH\"\n    )\n    _print_deprecated_env(\n        \"SGLANG_ENABLE_TP_MEMORY_INBALANCE_CHECK\",\n        \"SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK\",\n    )\n    _deprecated_ms_to_s = {\n        \"SGLANG_QUEUED_TIMEOUT_MS\": \"SGLANG_REQ_WAITING_TIMEOUT\",\n        \"SGLANG_FORWARD_TIMEOUT_MS\": \"SGLANG_REQ_RUNNING_TIMEOUT\",\n    }\n    for old_name, new_name in _deprecated_ms_to_s.items():\n        if old_name in os.environ:\n            ms_val = os.environ[old_name]\n            warnings.warn(\n                f\"Environment variable {old_name} (in ms) is deprecated, \"\n                f\"please use {new_name} (in seconds) instead\"\n            )\n            os.environ[new_name] = str(float(ms_val) / 1000.0)\n\n    for key, value in os.environ.items():\n        if key.startswith(\"SGL_\"):\n            new_key = key.replace(\"SGL_\", \"SGLANG_\", 1)\n            warnings.warn(\n                f\"Environment variable {key} is deprecated, please use {new_key}\"\n            )\n            os.environ[new_key] = value\n\n\n_convert_SGL_to_SGLANG()\n\n_warn_deprecated_env_to_cli_flag(\n    
\"SGLANG_ENABLE_FLASHINFER_FP8_GEMM\",\n    \"It will be completely removed in 0.5.7. Please use '--fp8-gemm-backend=flashinfer_trtllm' instead.\",\n)\n_warn_deprecated_env_to_cli_flag(\n    \"SGLANG_ENABLE_FLASHINFER_GEMM\",\n    \"It will be completely removed in 0.5.7. Please use '--fp8-gemm-backend=flashinfer_trtllm' instead.\",\n)\n_warn_deprecated_env_to_cli_flag(\n    \"SGLANG_SUPPORT_CUTLASS_BLOCK_FP8\",\n    \"It will be completely removed in 0.5.7. Please use '--fp8-gemm-backend=cutlass' instead.\",\n)\n_warn_deprecated_env_to_cli_flag(\n    \"SGLANG_FLASHINFER_FP4_GEMM_BACKEND\",\n    \"It will be completely removed in 0.5.9. Please use '--fp4-gemm-backend' instead.\",\n)\n_warn_deprecated_env_to_cli_flag(\n    \"SGLANG_SCHEDULER_DECREASE_PREFILL_IDLE\",\n    \"Please use '--enable-prefill-delayer' instead.\",\n)\n_warn_deprecated_env_to_cli_flag(\n    \"SGLANG_PREFILL_DELAYER_MAX_DELAY_PASSES\",\n    \"Please use '--prefill-delayer-max-delay-passes' instead.\",\n)\n_warn_deprecated_env_to_cli_flag(\n    \"SGLANG_PREFILL_DELAYER_TOKEN_USAGE_LOW_WATERMARK\",\n    \"Please use '--prefill-delayer-token-usage-low-watermark' instead.\",\n)\n\n# Import cuda_coredump to trigger auto-injection of CUDA env vars\n# when SGLANG_CUDA_COREDUMP=1. 
Best-effort; for strict guarantees,\n# set CUDA_* env vars in the shell before launching Python.\nimport sglang.srt.debug_utils.cuda_coredump  # noqa: F401, E402\n\n\ndef example_with_exit_stack():\n    # Use this style of context manager in unit test\n    exit_stack = ExitStack()\n    exit_stack.enter_context(envs.SGLANG_TEST_RETRACT.override(False))\n    assert envs.SGLANG_TEST_RETRACT.get() is False\n    exit_stack.close()\n    assert envs.SGLANG_TEST_RETRACT.get() is None\n\n\ndef example_with_subprocess():\n    command = [\"python\", \"-c\", \"import os; print(os.getenv('SGLANG_TEST_RETRACT'))\"]\n    with envs.SGLANG_TEST_RETRACT.override(True):\n        process = subprocess.Popen(\n            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE\n        )\n        process.wait()\n        output = process.stdout.read().decode(\"utf-8\").strip()\n        assert output == \"True\"\n\n    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n    output = process.stdout.read().decode(\"utf-8\").strip()\n    assert output == \"None\"\n\n\ndef example_with_implicit_bool_avoidance():\n    @contextmanager\n    def assert_throws(message_matcher: str):\n        try:\n            yield\n        except Exception as e:\n            assert message_matcher in str(e), f\"{e=}\"\n            print(f\"assert_throws find expected error: {e}\")\n            return\n        raise AssertionError(f\"assert_throws do not see exceptions\")\n\n    with assert_throws(\"Please use `envs.YOUR_FLAG.get()` instead of `envs.YOUR_FLAG`\"):\n        if envs.SGLANG_TEST_RETRACT:\n            pass\n\n    with assert_throws(\"Please use `envs.YOUR_FLAG.get()` instead of `envs.YOUR_FLAG`\"):\n        if (1 != 1) or envs.SGLANG_TEST_RETRACT:\n            pass\n\n    with assert_throws(\"Please use `envs.YOUR_FLAG.get()` instead of `envs.YOUR_FLAG`\"):\n        if envs.SGLANG_TEST_RETRACT or (1 == 1):\n            pass\n\n\ndef examples():\n    # Example usage for 
envs\n    envs.SGLANG_TEST_RETRACT.clear()\n    assert envs.SGLANG_TEST_RETRACT.get() is False\n\n    envs.SGLANG_TEST_RETRACT.set(None)\n    assert envs.SGLANG_TEST_RETRACT.is_set() and envs.SGLANG_TEST_RETRACT.get() is None\n\n    envs.SGLANG_TEST_RETRACT.clear()\n    assert not envs.SGLANG_TEST_RETRACT.is_set()\n\n    envs.SGLANG_TEST_RETRACT.set(True)\n    assert envs.SGLANG_TEST_RETRACT.get() is True\n\n    with envs.SGLANG_TEST_RETRACT.override(None):\n        assert (\n            envs.SGLANG_TEST_RETRACT.is_set() and envs.SGLANG_TEST_RETRACT.get() is None\n        )\n\n    assert envs.SGLANG_TEST_RETRACT.get() is True\n\n    envs.SGLANG_TEST_RETRACT.set(None)\n    with envs.SGLANG_TEST_RETRACT.override(True):\n        assert envs.SGLANG_TEST_RETRACT.get() is True\n\n    assert envs.SGLANG_TEST_RETRACT.is_set() and envs.SGLANG_TEST_RETRACT.get() is None\n\n    example_with_exit_stack()\n    example_with_subprocess()\n    example_with_implicit_bool_avoidance()\n\n\nif __name__ == \"__main__\":\n    examples()\n"
  },
  {
    "path": "python/sglang/srt/eplb/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/eplb/eplb_algorithms/__init__.py",
    "content": "from enum import Enum, auto\nfrom typing import Optional\n\nimport torch\n\nfrom sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager\nfrom sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, elasticity_aware\n\n\nclass EplbAlgorithm(Enum):\n    deepseek = auto()\n    deepseek_hierarchical = auto()\n    deepseek_vec = auto()\n    deepseek_vec_hierarchical = auto()\n    elasticity_aware = auto()\n    elasticity_aware_hierarchical = auto()\n    # TODO may have more algorithm later\n\n\ndef rebalance_experts(\n    tokens_per_expert: torch.Tensor,\n    num_physical_experts: int,\n    num_local_physical_experts: int,\n    num_groups: Optional[int],\n    num_nodes: int,\n    algorithm: EplbAlgorithm,\n):\n    if algorithm in [EplbAlgorithm.deepseek, EplbAlgorithm.deepseek_hierarchical]:\n        return deepseek.rebalance_experts(\n            weight=tokens_per_expert.sum(dim=0),\n            num_replicas=num_physical_experts,\n            num_groups=num_groups,\n            num_nodes=num_nodes,\n            num_gpus=num_physical_experts // num_local_physical_experts,\n            enable_hierarchical=algorithm == EplbAlgorithm.deepseek_hierarchical,\n        )\n\n    if algorithm in [\n        EplbAlgorithm.deepseek_vec,\n        EplbAlgorithm.deepseek_vec_hierarchical,\n    ]:\n        return deepseek_vec.rebalance_experts(\n            tokens_per_expert=tokens_per_expert,\n            num_physical_experts=num_physical_experts,\n            num_local_physical_experts=num_local_physical_experts,\n            num_groups=num_groups,\n            num_nodes=num_nodes,\n            enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical,\n        )\n\n    if algorithm in [\n        EplbAlgorithm.elasticity_aware,\n        EplbAlgorithm.elasticity_aware_hierarchical,\n    ]:\n        return elasticity_aware.rebalance_experts(\n            weight=tokens_per_expert.sum(dim=0),\n            num_replicas=num_physical_experts,\n   
         num_groups=num_groups,\n            num_nodes=num_nodes,\n            num_gpus=num_physical_experts // num_local_physical_experts,\n            enable_hierarchical=(\n                algorithm == EplbAlgorithm.elasticity_aware_hierarchical\n            ),\n            active_ranks=(\n                ElasticEPStateManager.instance().active_ranks\n                if ElasticEPStateManager.instance() is not None\n                else ElasticEPStateManager.healthy_rank_state()\n            ),\n        )\n\n    raise NotImplementedError\n\n\ndef compute_algorithm(\n    raw_algorithm: str,\n    num_groups: Optional[int],\n    num_nodes: int,\n) -> EplbAlgorithm:\n    if raw_algorithm != \"auto\":\n        return EplbAlgorithm[raw_algorithm]\n\n    # TODO test on real scenarios and know which ones perform better\n    if (num_groups is not None) and (num_groups % num_nodes == 0):\n        return EplbAlgorithm.deepseek_hierarchical\n    else:\n        return EplbAlgorithm.deepseek\n"
  },
  {
    "path": "python/sglang/srt/eplb/eplb_algorithms/deepseek.py",
    "content": "# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package\nfrom typing import Tuple\n\nimport torch\n\n\ndef balanced_packing(\n    weight: torch.Tensor, num_packs: int\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs\n    are as balanced as possible.\n\n    Parameters:\n        weight: [X, n], the weight of each item\n        num_packs: number of packs\n\n    Returns:\n        pack_index: [X, n], the pack index of each item\n        rank_in_pack: [X, n], the rank of the item in the pack\n    \"\"\"\n    num_layers, num_groups = weight.shape\n    assert num_groups % num_packs == 0\n    groups_per_pack = num_groups // num_packs\n\n    if groups_per_pack == 1:\n        pack_index = torch.arange(\n            weight.size(-1), dtype=torch.int64, device=weight.device\n        ).expand(weight.shape)\n        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)\n        return pack_index, rank_in_pack\n\n    indices = weight.float().sort(-1, descending=True).indices.cpu()\n    pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device=\"cpu\")\n    rank_in_pack = torch.full_like(pack_index, fill_value=-1)\n    for i in range(num_layers):\n        pack_weights = [0] * num_packs\n        pack_items = [0] * num_packs\n        for group in indices[i]:\n            pack = min(\n                (i for i in range(num_packs) if pack_items[i] < groups_per_pack),\n                key=pack_weights.__getitem__,\n            )\n            assert pack_items[pack] < groups_per_pack\n            pack_index[i, group] = pack\n            rank_in_pack[i, group] = pack_items[pack]\n            pack_weights[pack] += weight[i, group]\n            pack_items[pack] += 1\n    return pack_index, rank_in_pack\n\n\ndef replicate_experts(\n    weight: torch.Tensor, num_phy: int\n) 
-> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.\n\n    Parameters:\n        weight: [X, num_log]\n        num_phy: total number of experts after replication\n\n    Returns:\n        phy2log: [X, num_phy], logical expert id of each physical expert\n        rank: [X, num_phy], the replica rank\n        logcnt: [X, num_log], number of replicas for each logical expert\n    \"\"\"\n    n, num_log = weight.shape\n    num_redundant = num_phy - num_log\n    assert num_redundant >= 0\n    device = weight.device\n    phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)\n    rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)\n    logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)\n    arangen = torch.arange(n, dtype=torch.int64, device=device)\n    for i in range(num_log, num_phy):\n        redundant_indices = (weight / logcnt).max(dim=-1).indices\n        phy2log[:, i] = redundant_indices\n        rank[:, i] = logcnt[arangen, redundant_indices]\n        logcnt[arangen, redundant_indices] += 1\n    return phy2log, rank, logcnt\n\n\ndef rebalance_experts_hierarchical(\n    weight: torch.Tensor,\n    num_physical_experts: int,\n    num_groups: int,\n    num_nodes: int,\n    num_gpus: int,\n):\n    \"\"\"\n    Parameters:\n        weight: [num_moe_layers, num_logical_experts]\n        num_physical_experts: number of physical experts after replication\n        num_groups: number of expert groups\n        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster\n        num_gpus: number of GPUs, must be a multiple of `num_nodes`\n\n    Returns:\n        physical_to_logical_map: [num_moe_layers, num_physical_experts]\n        logical_to_physical_map: [num_moe_layers, num_logical_experts, X]\n        logical_count: [num_moe_layers, num_logical_experts]\n    \"\"\"\n    
num_layers, num_logical_experts = weight.shape\n    assert num_logical_experts % num_groups == 0\n    group_size = num_logical_experts // num_groups\n    assert num_groups % num_nodes == 0\n    groups_per_node = num_groups // num_nodes\n    assert num_gpus % num_nodes == 0\n    assert num_physical_experts % num_gpus == 0\n    phy_experts_per_gpu = num_physical_experts // num_gpus\n\n    def inverse(perm: torch.Tensor) -> torch.Tensor:\n        inv = torch.empty_like(perm)\n        inv.scatter_(\n            1,\n            perm,\n            torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(\n                perm.shape\n            ),\n        )\n        return inv\n\n    # Step 1: pack groups to nodes\n    tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)\n    group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)\n    log2mlog = (\n        (\n            (group_pack_index * groups_per_node + group_rank_in_pack) * group_size\n        ).unsqueeze(-1)\n        + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)\n    ).flatten(-2)\n    mlog2log = inverse(log2mlog)\n\n    # Step 2: construct redundant experts within nodes\n    # [num_layers * num_nodes, num_logical_experts // num_nodes]\n    tokens_per_mlog = weight.gather(-1, mlog2log).view(\n        -1, num_logical_experts // num_nodes\n    )\n    phy2mlog, phyrank, mlogcnt = replicate_experts(\n        tokens_per_mlog, num_physical_experts // num_nodes\n    )\n\n    # Step 3: pack physical_experts to GPUs\n    # [num_layers * num_nodes, num_physical_experts // num_nodes]\n    tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)\n    pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)\n    phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack\n    pphy2phy = inverse(phy2pphy)\n\n    pphy2mlog = phy2mlog.gather(\n        -1, pphy2phy\n    )  # [num_layers * num_nodes, 
num_log_per_nodes]\n    pphy2mlog = (\n        pphy2mlog.view(num_layers, num_nodes, -1)\n        + torch.arange(\n            0,\n            num_logical_experts,\n            num_logical_experts // num_nodes,\n            device=group_pack_index.device,\n        ).view(1, -1, 1)\n    ).flatten(-2)\n    pphy2log = mlog2log.gather(-1, pphy2mlog)\n    pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)\n    logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)\n    return pphy2log, pphyrank, logcnt\n\n\ndef rebalance_experts(\n    weight: torch.Tensor,\n    num_replicas: int,\n    num_groups: int,\n    num_nodes: int,\n    num_gpus: int,\n    enable_hierarchical: bool,\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Entry point for expert-parallelism load balancer.\n\n    Parameters:\n        weight: [layers, num_logical_experts], the load statistics for all logical experts\n        num_replicas: number of physical experts, must be a multiple of `num_gpus`\n        num_groups: number of expert groups\n        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster\n        num_gpus: number of GPUs, must be a multiple of `num_nodes`\n\n    Returns:\n        physical_to_logical_map: [layers, num_replicas], the expert index of each replica\n        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert\n        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert\n    \"\"\"\n\n    num_layers, num_logical_experts = weight.shape\n    weight = weight.float().cpu()\n    if enable_hierarchical:\n        # use hierarchical load-balance policy\n        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(\n            weight, num_replicas, num_groups, num_nodes, num_gpus\n        )\n    else:\n        # use global load-balance policy\n        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(\n            
weight, num_replicas, 1, 1, num_gpus\n        )\n    maxlogcnt = logcnt.max().item()\n    log2phy: torch.Tensor = torch.full(\n        (num_layers, num_logical_experts, maxlogcnt),\n        -1,\n        dtype=torch.int64,\n        device=logcnt.device,\n    )\n    log2phy.view(num_layers, -1).scatter_(\n        -1,\n        phy2log * maxlogcnt + phyrank,\n        torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(\n            num_layers, -1\n        ),\n    )\n    return phy2log, log2phy, logcnt\n\n\n__all__ = [\"rebalance_experts\"]\n"
  },
  {
    "path": "python/sglang/srt/eplb/eplb_algorithms/deepseek_vec.py",
    "content": "# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package\nfrom typing import Optional, Tuple\n\nimport torch\n\n\ndef pack_groups(tokens_per_group: torch.Tensor, num_nodes: int) -> torch.Tensor:\n    num_layers, num_groups = tokens_per_group.shape\n    assert num_groups % num_nodes == 0\n    groups_per_rank = num_groups // num_nodes\n\n    indices = tokens_per_group.float().sort(-1, descending=True).indices.cpu()\n    ret = torch.full_like(\n        tokens_per_group, fill_value=-1, dtype=torch.int64, device=\"cpu\"\n    )\n    for layer in range(num_layers):\n        node_tokens = [0] * num_nodes\n        node_groups = [0] * num_nodes\n        for group in indices[layer]:\n\n            def key_func(rank: int) -> int:\n                if node_groups[rank] >= groups_per_rank:\n                    return 1, 0\n                else:\n                    return 0, node_tokens[rank]\n\n            rank = min(range(num_nodes), key=key_func)\n            assert node_groups[rank] < groups_per_rank\n            ret[layer, group] = rank * groups_per_rank + node_groups[rank]\n            node_tokens[rank] += tokens_per_group[layer, group]\n            node_groups[rank] += 1\n    return ret\n\n\ndef make_redundant_experts_chunkwise(\n    tokens_per_expert: torch.Tensor,\n    num_physical_experts: int,\n    num_local_physical_experts: int,\n    num_physical_experts_per_chunk: int,\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    num_steps, num_moe_layers, num_logical_experts = tokens_per_expert.shape\n    num_redundancy_experts = num_physical_experts - num_logical_experts\n\n    physical_to_logical_map = torch.empty(\n        num_moe_layers,\n        num_physical_experts,\n        dtype=torch.int,\n        device=tokens_per_expert.device,\n    )\n    logical_to_physical_map = torch.full(\n        (num_moe_layers, num_logical_experts, num_redundancy_experts + 1),\n        -1,\n        
dtype=torch.int,\n        device=tokens_per_expert.device,\n    )\n    logical_count = torch.ones(\n        num_moe_layers,\n        num_logical_experts,\n        dtype=torch.int,\n        device=tokens_per_expert.device,\n    )\n\n    assert num_physical_experts % num_physical_experts_per_chunk == 0\n    num_chunks = num_physical_experts // num_physical_experts_per_chunk\n    assert num_logical_experts % num_chunks == 0\n    num_logical_experts_per_group = num_logical_experts // num_chunks\n    assert num_redundancy_experts % num_chunks == 0\n    num_redundancy_experts_per_group = num_redundancy_experts // num_chunks\n\n    arange_num_moe_layers_num_groups = torch.arange(\n        num_moe_layers * num_chunks, dtype=torch.int, device=tokens_per_expert.device\n    )\n    arange_num_logical_experts = torch.arange(\n        num_logical_experts, dtype=torch.int, device=tokens_per_expert.device\n    )\n    arange_num_logical_experts_per_group = torch.arange(\n        num_logical_experts_per_group, dtype=torch.int, device=tokens_per_expert.device\n    )\n    arange_num_groups = torch.arange(\n        num_chunks, dtype=torch.int, device=tokens_per_expert.device\n    )\n    physical_to_logical_map.view(\n        num_moe_layers, num_chunks, num_physical_experts_per_chunk\n    )[:, :, :num_logical_experts_per_group] = arange_num_logical_experts.view(\n        num_chunks, num_logical_experts_per_group\n    )\n    logical_to_physical_map[:, :, 0] = (\n        arange_num_logical_experts_per_group.expand(\n            num_chunks, num_logical_experts_per_group\n        )\n        + arange_num_groups[:, None] * num_physical_experts_per_chunk\n    ).view(num_logical_experts)\n\n    tokens_per_expert_all_diff = tokens_per_expert + arange_num_logical_experts * 1e-4\n    for i in range(num_redundancy_experts_per_group):\n        score = (\n            tokens_per_expert_all_diff / logical_count\n        )  # NOTE: Values in score must be different from each other\n        score1 = 
tokens_per_expert / (logical_count + 1)\n        score = score.view(\n            num_steps, num_moe_layers, num_chunks, num_logical_experts_per_group\n        )\n        score1 = score1.view_as(score)\n        values, indices = score.max(-1, keepdim=True)\n        values = values.expand_as(score).contiguous()\n        score.scatter_(-1, indices, score1.gather(-1, indices))\n        values.scatter_(-1, indices, score.max(-1, keepdim=True).values)\n        redundancy_indices = values.sum(0).argmin(-1)\n        physical_to_logical_map.view(\n            num_moe_layers, num_chunks, num_physical_experts_per_chunk\n        )[:, :, num_logical_experts_per_group + i] = (\n            redundancy_indices + arange_num_groups * num_logical_experts_per_group\n        )\n        redundancy_count = (\n            logical_count.view(\n                num_moe_layers * num_chunks, num_logical_experts_per_group\n            )\n            .gather(-1, redundancy_indices.view(num_moe_layers * num_chunks, 1))\n            .squeeze(1)\n        )\n        physical_redundancy_indices = (\n            (\n                arange_num_groups * num_physical_experts_per_chunk\n                + num_logical_experts_per_group\n                + i\n            )\n            .expand(num_moe_layers, num_chunks)\n            .flatten()\n        )\n        logical_to_physical_map.view(\n            num_moe_layers * num_chunks,\n            num_logical_experts_per_group,\n            num_redundancy_experts + 1,\n        )[\n            arange_num_moe_layers_num_groups,\n            redundancy_indices.view(num_moe_layers * num_chunks),\n            redundancy_count,\n        ] = physical_redundancy_indices\n        logical_count.view(num_moe_layers * num_chunks, num_logical_experts_per_group)[\n            arange_num_moe_layers_num_groups,\n            redundancy_indices.view(num_moe_layers * num_chunks),\n        ] += 1\n\n    if num_local_physical_experts > 1:\n        # Load-balancing between GPUs\n  
      physical_to_logical_map_int64 = physical_to_logical_map.to(torch.int64)\n        counts = logical_count.gather(-1, physical_to_logical_map_int64)\n        score = tokens_per_expert.sum(0).gather(-1, physical_to_logical_map_int64)\n        score = score / counts\n        score = score.view(num_moe_layers, num_chunks, num_physical_experts_per_chunk)\n        indices = score.argsort(-1, descending=True)\n        indices += torch.arange(\n            0,\n            num_physical_experts,\n            num_physical_experts_per_chunk,\n            dtype=indices.dtype,\n            device=indices.device,\n        )[None, :, None]\n\n        assert num_physical_experts_per_chunk % num_local_physical_experts == 0\n        num_local_groups = num_physical_experts_per_chunk // num_local_physical_experts\n        indices = indices.view(\n            num_moe_layers, num_chunks, num_local_physical_experts, num_local_groups\n        )\n        indices[:, :, 1::2, :] = indices[:, :, 1::2, :].flip(-1)\n        indices = indices.transpose(2, 3)\n        indices = indices.reshape(num_moe_layers, num_physical_experts)\n        physical_to_logical_map = physical_to_logical_map.gather(-1, indices)\n        mask = logical_to_physical_map == -1\n        logical_to_physical_map[mask] = 0\n        logical_to_physical_map = (\n            indices.argsort(-1)\n            .gather(\n                -1, logical_to_physical_map.view(num_moe_layers, -1).to(torch.int64)\n            )\n            .view_as(logical_to_physical_map)\n            .to(torch.int)\n        )\n        logical_to_physical_map[mask] = -1\n\n    return physical_to_logical_map, logical_to_physical_map, logical_count\n\n\ndef decode_rebalance_experts(\n    tokens_per_expert: torch.Tensor,\n    num_physical_experts: int,\n    num_local_physical_experts: int,\n):\n    return make_redundant_experts_chunkwise(\n        tokens_per_expert,\n        num_physical_experts,\n        num_local_physical_experts,\n        
num_physical_experts,\n    )\n\n\ndef prefill_rebalance_experts(\n    tokens_per_expert: torch.Tensor,\n    num_physical_experts: int,\n    num_local_physical_experts: int,\n    num_groups: int,\n    num_nodes: int,\n):\n    tokens_per_expert = tokens_per_expert.float().cpu()\n\n    num_steps, _, num_logical_experts = tokens_per_expert.shape\n    assert num_logical_experts % num_groups == 0\n    group_size = num_logical_experts // num_groups\n    assert num_groups % num_nodes == 0, f\"{num_groups=} {num_nodes=}\"\n\n    tokens_per_group = tokens_per_expert.sum(0).unflatten(-1, (num_groups, -1)).sum(-1)\n    group_perm = pack_groups(\n        tokens_per_group, num_nodes\n    )  # [num_moe_layers, num_groups] => [num_moe_layers, num_nodes]\n\n    # log2mlog [layers, #logexp] -> [layers, #logexp]\n    log2mlog = (\n        (group_perm * group_size).unsqueeze(-1)\n        + torch.arange(group_size, dtype=torch.int64, device=group_perm.device)\n    ).flatten(-2)\n\n    # mlog2log [layers, #logexp] -> [layers, #logexp], inverse of log2mlog\n    mlog2log = torch.empty_like(log2mlog)\n    arange = torch.arange(\n        num_logical_experts, dtype=torch.int64, device=mlog2log.device\n    )\n    mlog2log.scatter_(1, log2mlog, arange.expand(log2mlog.size(0), -1))\n\n    # tokens_per_mlog[i][j][k] = tokens_per_expert[i][j][mlog2log[j][k]]\n    tokens_per_mlog = tokens_per_expert.gather(\n        2, mlog2log.unsqueeze(0).expand(num_steps, -1, -1)\n    )\n\n    phy2mlog, mlog2phy, mlog_count = make_redundant_experts_chunkwise(\n        tokens_per_mlog,\n        num_physical_experts,\n        num_local_physical_experts,\n        num_physical_experts // num_nodes,\n    )\n\n    # phy2log[i][j] = mlog2log[i][phy2mlog[i][j]]\n    phy2log = mlog2log.gather(1, phy2mlog.to(torch.int64))\n\n    # mlog2phy: [num_moe_layers, num_logical_experts, ...]\n    # log2phy[i][j][k] = mlog2phy[i][log2mlog[i][j]][k]\n    log2phy = mlog2phy.gather(\n        1, log2mlog.unsqueeze(-1).expand(-1, -1, 
mlog2phy.size(-1)).to(torch.int64)\n    )\n\n    # log_count[i][j] = mlog_count[i][log2mlog[i][j]]\n    log_count = mlog_count.gather(1, log2mlog)\n    return phy2log, log2phy, log_count\n\n\ndef rebalance_experts(\n    tokens_per_expert: torch.Tensor,\n    num_physical_experts: int,\n    num_local_physical_experts: int,\n    num_groups: Optional[int],\n    num_nodes: int,\n    enable_hierarchical: bool,\n):\n    if enable_hierarchical:\n        return prefill_rebalance_experts(\n            tokens_per_expert=tokens_per_expert,\n            num_physical_experts=num_physical_experts,\n            num_local_physical_experts=num_local_physical_experts,\n            num_groups=num_groups,\n            num_nodes=num_nodes,\n        )\n    else:\n        return decode_rebalance_experts(\n            tokens_per_expert=tokens_per_expert,\n            num_physical_experts=num_physical_experts,\n            num_local_physical_experts=num_local_physical_experts,\n        )\n"
  },
  {
    "path": "python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py",
    "content": "from typing import Tuple\n\nimport torch\n\nfrom sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical\n\n\ndef rebalance_experts(\n    weight: torch.Tensor,\n    num_replicas: int,\n    num_groups: int,\n    num_nodes: int,\n    num_gpus: int,\n    enable_hierarchical: bool,\n    active_ranks: torch.Tensor,\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Entry point for expert-parallelism load balancer.\n\n    Parameters:\n        weight: [layers, num_logical_experts], the load statistics for all logical experts\n        num_replicas: number of physical experts, must be a multiple of `num_gpus`\n        num_groups: number of expert groups\n        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster\n        num_gpus: number of GPUs, must be a multiple of `num_nodes`\n\n    Returns:\n        physical_to_logical_map: [layers, num_replicas], the expert index of each replica\n        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert\n        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert\n    \"\"\"\n\n    num_layers, num_logical_experts = weight.shape\n    weight = weight.float().cpu()\n    num_active_ranks = active_ranks.sum().item()\n    num_local_experts = num_replicas // num_gpus\n    if num_active_ranks < num_gpus:\n        # Must fall back to global load-balance policy\n        # and fix some params\n        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(\n            weight,\n            num_local_experts * num_active_ranks,\n            1,\n            1,\n            num_active_ranks,\n        )\n    elif enable_hierarchical:\n        # use hierarchical load-balance policy\n        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(\n            weight, num_replicas, num_groups, num_nodes, num_gpus\n        )\n    else:\n        # use global 
load-balance policy\n        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(\n            weight, num_replicas, 1, 1, num_gpus\n        )\n    maxlogcnt = logcnt.max().item()\n    log2phy: torch.Tensor = torch.full(\n        (num_layers, num_logical_experts, maxlogcnt),\n        -1,\n        dtype=torch.int64,\n        device=logcnt.device,\n    )\n    log2phy.view(num_layers, -1).scatter_(\n        -1,\n        phy2log * maxlogcnt + phyrank,\n        torch.arange(\n            num_local_experts * num_active_ranks,\n            dtype=torch.int64,\n            device=log2phy.device,\n        ).expand(num_layers, -1),\n    )\n    if num_active_ranks < num_gpus:\n        phy2log_slices = list(\n            phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1)\n        )\n        active_ranks_list = active_ranks.tolist()\n        for idx, active_rank in enumerate(active_ranks_list):\n            if not active_rank:\n                phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0]))\n                log2phy = torch.where(\n                    log2phy >= idx * num_local_experts,\n                    log2phy + num_local_experts,\n                    log2phy,\n                )\n        phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1)\n    return phy2log, log2phy, logcnt\n"
  },
  {
    "path": "python/sglang/srt/eplb/eplb_manager.py",
    "content": "import logging\nimport time\nfrom typing import TYPE_CHECKING, List\n\nimport torch.cuda\n\nfrom sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder\nfrom sglang.srt.eplb.expert_location import ExpertLocationMetadata\n\nif TYPE_CHECKING:\n    from sglang.srt.model_executor.model_runner import ModelRunner\n\nlogger = logging.getLogger(__name__)\n\n\nclass EPLBManager:\n    def __init__(self, model_runner: \"ModelRunner\"):\n        super().__init__()\n        self._model_runner = model_runner\n        self._server_args = model_runner.server_args\n        self._rebalance_layers_per_chunk = (\n            self._server_args.eplb_rebalance_layers_per_chunk\n        )\n        self._rebalance_num_iterations = self._server_args.eplb_rebalance_num_iterations\n\n        # Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented.\n        assert (\n            self._server_args.eplb_rebalance_num_iterations\n            >= self._server_args.expert_distribution_recorder_buffer_size\n        ), \"eplb_rebalance_num_iterations must be greater than or equal to expert_distribution_recorder_buffer_size\"\n\n        if not get_global_expert_distribution_recorder().recording:\n            get_global_expert_distribution_recorder().start_record()\n\n        logger.info(\n            f\"[EPLBManager] system started, will rebalance per {self._rebalance_num_iterations} iterations.\"\n        )\n\n        self._main_generator = self._entrypoint()\n\n    def on_forward_pass_end(self):\n        next(self._main_generator)\n\n    # can be more complex if needed\n    def _entrypoint(self):\n        while True:\n            for _ in range(self._rebalance_num_iterations):\n                yield\n\n            yield from self.rebalance()\n\n    def rebalance(self):\n        logger.info(\"[EPLBManager] rebalance start\")\n\n        enable_timing = self._rebalance_layers_per_chunk is None\n\n        if enable_timing:\n   
         torch.get_device_module().synchronize()\n            time_start = time.time()\n\n        dump_record_output = get_global_expert_distribution_recorder().dump_record(\n            output_mode=\"object\"\n        )\n        logical_count = dump_record_output[\"logical_count\"]\n        average_utilization_rate_over_window = dump_record_output[\n            \"average_utilization_rate_over_window\"\n        ]\n\n        # Check whether rebalancing is needed\n        if not self._check_rebalance_needed(average_utilization_rate_over_window):\n            return\n\n        expert_location_metadata = ExpertLocationMetadata.init_by_eplb(\n            self._server_args, self._model_runner.model_config, logical_count\n        )\n\n        update_layer_ids_chunks = self._compute_update_layer_ids_chunks()\n        for chunk_index, update_layer_ids in enumerate(update_layer_ids_chunks):\n            if len(update_layer_ids_chunks) > 1:\n                yield\n            self._model_runner.update_expert_location(\n                expert_location_metadata,\n                update_layer_ids=update_layer_ids,\n            )\n\n        msg = f\"[EPLBManager] rebalance end\"\n        if enable_timing:\n            torch.get_device_module().synchronize()\n            time_end = time.time()\n            msg += f\" time={time_end - time_start:.3f}s\"\n        logger.info(msg)\n\n    def _check_rebalance_needed(self, average_utilization_rate_over_window):\n        if average_utilization_rate_over_window is None:\n            return True\n\n        if (\n            average_utilization_rate_over_window\n            > self._server_args.eplb_min_rebalancing_utilization_threshold\n        ):\n            logger.info(\n                f\"[EPLBManager] Skipped ep rebalancing: current GPU utilization {average_utilization_rate_over_window:.2f} > minimum rebalance threshold {self._server_args.eplb_min_rebalancing_utilization_threshold:.2f}\"\n            )\n            return False\n\n    
    return True\n\n    def _compute_update_layer_ids_chunks(self) -> List[List[int]]:\n        all_layer_ids = sorted(\n            list(self._model_runner.model.routed_experts_weights_of_layer.keys())\n        )\n        chunk_size = self._rebalance_layers_per_chunk or 1000000\n        return list(_chunk_list(all_layer_ids, chunk_size=chunk_size))\n\n\ndef _chunk_list(items: List, chunk_size):\n    for start_index in range(0, len(items), chunk_size):\n        yield items[start_index : start_index + chunk_size]\n"
  },
  {
    "path": "python/sglang/srt/eplb/eplb_simulator/__init__.py",
    "content": "from . import reader\n"
  },
  {
    "path": "python/sglang/srt/eplb/eplb_simulator/reader.py",
    "content": "from collections import defaultdict\nfrom pathlib import Path\n\nimport torch\nfrom tqdm import tqdm\n\nfrom sglang.srt.eplb.expert_distribution import (\n    _convert_global_physical_count_to_logical_count,\n)\n\nconvert_global_physical_count_to_logical_count = (\n    _convert_global_physical_count_to_logical_count\n)\n\n\ndef read_mode_per_pass(dir_data: Path):\n    \"\"\"Read data from ExpertDistributionRecorder when recorded with mode `per_pass`\"\"\"\n\n    # gpc := global_physical_count\n    gpc_of_forward_pass_and_rank = defaultdict(lambda: defaultdict())\n    for path in tqdm(list(dir_data.glob(\"*.pt\"))):\n        data_pack = torch.load(path, weights_only=True)\n        last_physical_to_logical_map = data_pack[\"last_physical_to_logical_map\"]\n        for record in data_pack[\"records\"]:\n            forward_pass_id = record[\"forward_pass_id\"]\n            rank = record[\"rank\"]\n            assert (\n                gpc_of_forward_pass_and_rank[forward_pass_id].get(rank) is None\n            ), f\"Duplicated {forward_pass_id=} {rank=}\"\n            gpc_of_forward_pass_and_rank[forward_pass_id][rank] = record[\n                \"global_physical_count\"\n            ]\n\n    forward_pass_ids = sorted(gpc_of_forward_pass_and_rank.keys())\n    print(f\"Make {forward_pass_ids=} into array\")\n\n    items = []\n    for forward_pass_id, gpc_of_rank in sorted(gpc_of_forward_pass_and_rank.items()):\n        gpc_of_rank_tensor = torch.stack(\n            [gpc for rank, gpc in sorted(gpc_of_rank.items())]\n        ).sum(dim=0)\n        items.append(gpc_of_rank_tensor)\n\n    gpc_of_forward_pass = torch.stack(items)\n    print(f\"{gpc_of_forward_pass.shape=}\")\n\n    return dict(\n        global_physical_count_of_forward_pass=gpc_of_forward_pass,\n        last_physical_to_logical_map=last_physical_to_logical_map,\n        forward_pass_ids=forward_pass_ids,\n    )\n"
  },
  {
    "path": "python/sglang/srt/eplb/expert_distribution.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\nfrom __future__ import annotations\n\nimport logging\nimport math\nimport time\nfrom abc import ABC\nfrom collections import deque\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type\n\nimport einops\nimport torch\nimport torch.distributed\n\nfrom sglang.srt.environ import envs\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\nfrom sglang.srt.observability.metrics_collector import ExpertDispatchCollector\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils import Withable, get_int_env_var\n\nif TYPE_CHECKING:\n    from sglang.srt.eplb.expert_location import ExpertLocationMetadata\n\nlogger = logging.getLogger(__name__)\n\n# --------------------------------------- Entrypoint -----------------------------------------\n\n_OutputMode = Literal[\"file\", \"object\"]\n\n\n@dataclass\nclass ExpertDistributionMetrics:\n    eplb_balancedness: torch.Tensor\n\n    def copy_to_cpu(self):\n        self.eplb_balancedness = self.eplb_balancedness.to(\"cpu\", non_blocking=True)\n\n\nclass ExpertDistributionRecorder(ABC):\n    \"\"\"Global expert distribution recording\"\"\"\n\n    @staticmethod\n    def init_new(\n        
server_args: ServerArgs,\n        expert_location_metadata: ExpertLocationMetadata,\n        rank: int,\n    ):\n        if server_args.expert_distribution_recorder_mode is not None:\n            assert expert_location_metadata is not None, (\n                \"ExpertLocationMetadata is required for expert distribution recording. One possible \"\n                \"reason is that you are using a model that does not support expert distribution \"\n                \"recording. Try setting `get_model_config_for_expert_location` in your model.\"\n            )\n            return _ExpertDistributionRecorderReal(\n                server_args, expert_location_metadata, rank\n            )\n        else:\n            return _ExpertDistributionRecorderNoop()\n\n    @contextmanager\n    def with_current_layer(self, layer_idx):\n        yield\n\n    @contextmanager\n    def with_debug_name(self, debug_name):\n        yield\n\n    @contextmanager\n    def disable_this_region(self):\n        yield\n\n    @contextmanager\n    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):\n        yield {}\n\n    def on_select_experts(self, topk_ids: torch.Tensor):\n        pass\n\n    def on_deepep_dispatch_normal(\n        self,\n        local_physical_count_of_layer: List[int],\n        num_tokens_per_rank,\n        num_tokens_per_rdma_rank,\n        num_tokens_per_expert,\n    ):\n        pass\n\n    def on_deepep_dispatch_low_latency(\n        self, local_physical_count_of_layer: torch.Tensor\n    ):\n        pass\n\n    def start_record(self):\n        self._on_not_implemented()\n\n    def stop_record(self):\n        self._on_not_implemented()\n\n    def dump_record(self, output_mode: _OutputMode = \"file\"):\n        self._on_not_implemented()\n\n    @property\n    def recording(self):\n        return False\n\n    def _on_not_implemented(self):\n        raise Exception(\n            \"Please set ServerArgs.expert_distribution_recorder_mode to use 
ExpertDistributionRecorder.\"\n        )\n\n\nclass _ExpertDistributionRecorderNoop(ExpertDistributionRecorder):\n    pass\n\n\nclass _ExpertDistributionRecorderReal(ExpertDistributionRecorder):\n    def __init__(\n        self,\n        server_args: ServerArgs,\n        expert_location_metadata: ExpertLocationMetadata,\n        rank: int,\n    ):\n        self._server_args = server_args\n        self._expert_location_metadata = expert_location_metadata\n\n        self._recording = False\n        self._disable_all = False\n        self._current_forward_pass_id = Withable()\n        self._current_layer_idx = Withable()\n        self._current_debug_name = Withable()\n        self._accumulator = _Accumulator.init_new(\n            server_args, expert_location_metadata, rank\n        )\n        self._single_pass_gatherers = {\n            k: _SinglePassGatherer.init_new(server_args, expert_location_metadata, rank)\n            for k in self._accumulator.get_single_pass_gatherer_keys()\n        }\n\n        if server_args.enable_expert_distribution_metrics:\n            logger.info(\n                \"ExpertDistributionRecorder auto start record since enable_expert_distribution_metrics\"\n            )\n            self.start_record()\n\n    def with_current_layer(self, layer_idx):\n        return self._current_layer_idx.with_value(layer_idx)\n\n    def with_debug_name(self, debug_name):\n        return self._current_debug_name.with_value(debug_name)\n\n    @contextmanager\n    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):\n        outputs = {}\n        with self._current_forward_pass_id.with_value(forward_pass_id):\n            self._on_forward_pass_start(forward_batch)\n            try:\n                yield outputs\n            finally:\n                self._on_forward_pass_end(forward_pass_id, outputs)\n\n    @contextmanager\n    def disable_this_region(self):\n        \"\"\"Context manager to temporarily disable 
recording.\"\"\"\n        previous_disable_all = self._disable_all\n        self._disable_all = True\n        try:\n            yield\n        finally:\n            self._disable_all = previous_disable_all\n\n    def _on_forward_pass_start(self, forward_batch: ForwardBatch):\n        if not self._recording:\n            return\n        for gatherer_key, gatherer in self._single_pass_gatherers.items():\n            gatherer.reset()\n            gatherer.on_forward_pass_start(forward_batch)\n\n    def _on_forward_pass_end(self, forward_pass_id: int, outputs: Dict[str, Any]):\n        if not self._recording:\n            return\n        for gatherer_key, gatherer in self._single_pass_gatherers.items():\n            single_pass_data = gatherer.collect()\n            self._accumulator.append(\n                forward_pass_id, gatherer_key, single_pass_data, outputs\n            )\n\n    def on_select_experts(self, topk_ids: torch.Tensor):\n        self._on_hook(\"on_select_experts\", topk_ids=topk_ids)\n\n    def on_deepep_dispatch_normal(\n        self,\n        local_physical_count_of_layer: List[int],\n        num_tokens_per_rank,\n        num_tokens_per_rdma_rank,\n        num_tokens_per_expert,\n    ):\n        self._on_hook(\n            \"on_deepep_dispatch_normal\",\n            local_physical_count_of_layer=local_physical_count_of_layer,\n            num_tokens_per_rank=num_tokens_per_rank,\n            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,\n            num_tokens_per_expert=num_tokens_per_expert,\n        )\n\n    def on_deepep_dispatch_low_latency(\n        self, local_physical_count_of_layer: torch.Tensor\n    ):\n        self._on_hook(\n            \"on_deepep_dispatch_low_latency\",\n            local_physical_count_of_layer=local_physical_count_of_layer,\n        )\n\n    def _on_hook(self, hook_name: str, **kwargs):\n        if self._disable_all:\n            return\n        if not (\n            self._recording or 
torch.get_device_module().is_current_stream_capturing()\n        ):\n            return\n        gatherer = self._single_pass_gatherers[\n            self._accumulator.get_single_pass_gatherer_key(\n                self._current_debug_name.value\n            )\n        ]\n        getattr(gatherer, hook_name)(layer_idx=self._current_layer_idx.value, **kwargs)\n\n    def _reset(self):\n        \"\"\"Reset the expert distribution recorder.\"\"\"\n        logger.info(\"Resetting ExpertDistributionRecorder...\")\n        assert (\n            self._current_layer_idx.value is None\n        ), f\"{self._current_layer_idx.value=}\"\n        for gatherer in self._single_pass_gatherers.values():\n            gatherer.reset()\n        self._accumulator.reset()\n\n    def start_record(self):\n        \"\"\"Start recording the expert distribution.\"\"\"\n        if self._recording:\n            logger.warning(\n                \"SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?\"\n            )\n        self._reset()\n        self._recording = True\n\n    def stop_record(self):\n        \"\"\"Stop recording the expert distribution.\"\"\"\n        if not self._recording:\n            logger.warning(\n                \"SGLang server has not been recording expert ids. 
Did you forget to start recording by sending a request to the `/start_expert_distribution_record` endpoint?\"\n            )\n        self._recording = False\n\n    def dump_record(self, output_mode: _OutputMode = \"file\"):\n        \"\"\"Dump the expert distribution record and reset the recorder after dumping.\"\"\"\n        output = self._accumulator.dump(output_mode=output_mode)\n        self._reset()\n        return output\n\n    @property\n    def recording(self):\n        return self._recording\n\n\n_global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = (\n    _ExpertDistributionRecorderNoop()\n)\n\n\ndef get_global_expert_distribution_recorder():\n    return _global_expert_distribution_recorder\n\n\ndef set_global_expert_distribution_recorder(value):\n    global _global_expert_distribution_recorder\n    _global_expert_distribution_recorder = value\n\n\n# --------------------------------------- SinglePassGatherer -----------------------------------------\n\n\nclass _SinglePassGatherer(ABC):\n    @staticmethod\n    def init_new(\n        server_args: ServerArgs,\n        expert_location_metadata: ExpertLocationMetadata,\n        rank: int,\n    ) -> \"_SinglePassGatherer\":\n        if server_args.expert_distribution_recorder_mode == \"per_token\":\n            return _DetailSinglePassGatherer(\n                server_args, expert_location_metadata, rank\n            )\n\n        if server_args.expert_distribution_recorder_mode == \"stat_approx\":\n            if server_args.moe_a2a_backend != \"none\" and (\n                server_args.deepep_mode == \"normal\"\n            ):\n                return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)\n            else:\n                raise NotImplementedError\n\n        if server_args.moe_a2a_backend != \"none\":\n            if server_args.deepep_mode == \"normal\":\n                return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)\n            elif 
server_args.deepep_mode == \"low_latency\":\n                return _DeepepLowLatencySinglePassGatherer(\n                    expert_location_metadata, rank\n                )\n            else:\n                raise NotImplementedError\n\n        return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)\n\n    def __init__(self, expert_location_metadata: ExpertLocationMetadata, rank: int):\n        self._expert_location_metadata = expert_location_metadata\n        self._rank = rank\n\n    def on_forward_pass_start(self, forward_batch: ForwardBatch):\n        pass\n\n    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):\n        pass\n\n    def on_deepep_dispatch_normal(\n        self,\n        layer_idx: int,\n        local_physical_count_of_layer: List[int],\n        num_tokens_per_rank,\n        num_tokens_per_rdma_rank,\n        num_tokens_per_expert,\n    ):\n        pass\n\n    def on_deepep_dispatch_low_latency(\n        self, layer_idx: int, local_physical_count_of_layer: torch.Tensor\n    ):\n        pass\n\n    def reset(self):\n        raise NotImplementedError\n\n    def collect(self) -> Dict:\n        raise NotImplementedError\n\n\nclass _DetailSinglePassGatherer(_SinglePassGatherer):\n    # DeepSeek V3 has this value; should generalize later\n    _TOP_K_NUM = 8\n\n    def __init__(\n        self,\n        server_args: ServerArgs,\n        expert_location_metadata: ExpertLocationMetadata,\n        rank: int,\n    ):\n        super().__init__(expert_location_metadata, rank)\n        self._metadata: Optional[Dict[str, Any]] = None\n        self._topk_ids_of_layer = torch.zeros(\n            (\n                expert_location_metadata.num_layers,\n                # TODO determine the max number\n                server_args.chunked_prefill_size * 8,\n                self._TOP_K_NUM,\n            ),\n            dtype=torch.int32,\n            device=server_args.device,\n        )\n        self._misc_objects: List[Dict[str, 
Any]] = []\n        assert (\n            not server_args.enable_two_batch_overlap\n        ), \"DetailSinglePassGatherer does not support TBO yet\"\n        # TODO assert shared experts fusion is disabled, o/w data is wrong\n\n    def on_forward_pass_start(self, forward_batch: ForwardBatch):\n        assert self._metadata is None\n        self._metadata = dict(\n            # TODO pr-chain\n            # rids=forward_batch.rids,\n            input_ids=forward_batch.input_ids.cpu().tolist(),\n            positions=forward_batch.positions.cpu().tolist(),\n            extend_seq_lens=forward_batch.extend_seq_lens_cpu,\n            forward_mode=forward_batch.forward_mode.value,\n        )\n\n    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):\n        self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], : topk_ids.shape[1]] = (\n            topk_ids\n        )\n\n    def on_deepep_dispatch_normal(\n        self,\n        layer_idx: int,\n        local_physical_count_of_layer: List[int],\n        num_tokens_per_rank,\n        num_tokens_per_rdma_rank,\n        num_tokens_per_expert,\n    ):\n        self._misc_objects.append(\n            dict(\n                layer_id=layer_idx,\n                num_tokens_per_rank=num_tokens_per_rank.cpu().tolist(),\n                num_tokens_per_rdma_rank=num_tokens_per_rdma_rank.cpu().tolist(),\n                num_tokens_per_expert=num_tokens_per_expert.cpu().tolist(),\n            )\n        )\n\n    def reset(self):\n        self._topk_ids_of_layer[...] 
= -1\n        self._misc_objects.clear()\n        self._metadata = None\n\n    def collect(self) -> Dict:\n        num_tokens = len(self._metadata[\"input_ids\"])\n\n        global_physical_count = _convert_per_token_to_global_physical_count(\n            num_tokens,\n            num_layers=self._expert_location_metadata.num_layers,\n            num_physical_experts=self._expert_location_metadata.num_physical_experts,\n            _topk_ids_of_layer=self._topk_ids_of_layer,\n        )\n\n        return dict(\n            **self._metadata,\n            topk_ids_of_layer=self._topk_ids_of_layer[:, :num_tokens, :].clone().cpu(),\n            misc_objects=self._misc_objects,\n            global_physical_count=global_physical_count,\n        )\n\n\nclass _LayerBasedCpuSinglePassGatherer(_SinglePassGatherer):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._objects_of_layer = {}\n\n    def _on_layer_data(self, layer_idx: int, objects: List[int]):\n        assert 0 <= layer_idx < self._expert_location_metadata.num_layers\n        if layer_idx in self._objects_of_layer:\n            self._objects_of_layer[layer_idx] = _list_sum(\n                self._objects_of_layer[layer_idx], objects\n            )\n        else:\n            self._objects_of_layer[layer_idx] = objects\n\n    def reset(self):\n        self._objects_of_layer.clear()\n\n    def _collect_objects(self, pad_len: int) -> torch.Tensor:\n        data = [\n            self._objects_of_layer.get(layer_index) or ([0] * pad_len)\n            for layer_index in range(self._expert_location_metadata.num_layers)\n        ]\n        return torch.tensor(data)\n\n\ndef _list_sum(a: List, b: List) -> List:\n    return [x + y for x, y in zip(a, b, strict=True)]\n\n\nclass _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):\n    def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):\n        super().__init__(*args, **kwargs)\n        
self._enable_global_physical_experts = enable_global_physical_experts\n        self._data = torch.zeros(\n            (\n                self._expert_location_metadata.num_layers,\n                (\n                    self._expert_location_metadata.num_physical_experts\n                    if enable_global_physical_experts\n                    else self._expert_location_metadata.num_local_physical_experts\n                ),\n            ),\n            dtype=torch.int,\n            device=\"cuda\",\n        )\n\n    def reset(self):\n        self._data[...] = 0\n\n    def collect(self) -> Dict:\n        if self._enable_global_physical_experts:\n            global_physical_count = self._data\n        else:\n            # Can optimize if bottleneck\n            global_physical_count = _convert_local_to_global_physical_count(\n                self._data,\n                rank=self._rank,\n                num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,\n                num_physical_experts=self._expert_location_metadata.num_physical_experts,\n            )\n\n        return dict(global_physical_count=global_physical_count)\n\n\nclass _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs, enable_global_physical_experts=True)\n\n    # can optimize (e.g. 
fuse / compile)\n    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):\n        topk_ids = topk_ids.flatten()\n        mask = topk_ids != -1\n        self._data[layer_idx, :].scatter_add_(\n            dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()\n        )\n\n\nclass _DeepepNormalSinglePassGatherer(_LayerBasedCpuSinglePassGatherer):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        if torch.distributed.get_rank() == 0:\n            logger.info(\n                \"DeepepNormalSinglePassGatherer gathers approximate statistics. \"\n                \"If used with small batch size, consider using expert_distribution_recorder_mode=stat.\"\n            )\n\n    def on_deepep_dispatch_normal(\n        self,\n        layer_idx: int,\n        local_physical_count_of_layer: List[int],\n        num_tokens_per_rank,\n        num_tokens_per_rdma_rank,\n        num_tokens_per_expert,\n    ):\n        assert isinstance(local_physical_count_of_layer, list)\n        self._on_layer_data(layer_idx, local_physical_count_of_layer)\n\n    def collect(self) -> Dict:\n        local_physical_count = super()._collect_objects(\n            pad_len=self._expert_location_metadata.num_local_physical_experts\n        )\n        global_physical_count = _convert_local_to_global_physical_count(\n            local_physical_count,\n            rank=self._rank,\n            num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,\n            num_physical_experts=self._expert_location_metadata.num_physical_experts,\n        )\n        return dict(global_physical_count=global_physical_count)\n\n\nclass _DeepepLowLatencySinglePassGatherer(_LayerBasedGpuSinglePassGatherer):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs, enable_global_physical_experts=False)\n\n    def on_deepep_dispatch_low_latency(\n        self, layer_idx: int, 
local_physical_count_of_layer: torch.Tensor\n    ):\n        # Most naive implementation, can optimize later\n        self._data[layer_idx, :] += local_physical_count_of_layer\n\n\ndef _convert_per_token_to_global_physical_count(\n    num_tokens: int,\n    num_layers: int,\n    num_physical_experts: int,\n    _topk_ids_of_layer: torch.Tensor,\n) -> torch.Tensor:\n    topk_ids_layer_major = _topk_ids_of_layer[:, :num_tokens, :].reshape(num_layers, -1)\n    mask = topk_ids_layer_major != -1\n\n    index = topk_ids_layer_major.masked_fill(~mask, 0).long()\n    src = mask.int()\n\n    ans = torch.zeros(\n        (num_layers, num_physical_experts),\n        dtype=_topk_ids_of_layer.dtype,\n        device=_topk_ids_of_layer.device,\n    )\n    ans.scatter_add_(dim=1, index=index, src=src)\n    return ans\n\n\ndef _convert_local_to_global_physical_count(\n    local_physical_count: torch.Tensor,\n    rank: int,\n    num_local_physical_experts: int,\n    num_physical_experts: int,\n) -> torch.Tensor:\n    dtype = local_physical_count.dtype\n    device = local_physical_count.device\n    num_layers, _ = local_physical_count.shape\n\n    ans = torch.zeros((num_layers, num_physical_experts), dtype=dtype, device=device)\n    ans[\n        :, num_local_physical_experts * rank : num_local_physical_experts * (rank + 1)\n    ] = local_physical_count\n    return ans\n\n\n# --------------------------------------- Accumulator -----------------------------------------\n\n_SINGLE_PASS_GATHERER_KEY_PRIMARY = \"primary\"\n\n\nclass _Accumulator(ABC):\n    @staticmethod\n    def init_new(\n        server_args: ServerArgs,\n        expert_location_metadata: ExpertLocationMetadata,\n        rank: int,\n    ) -> \"_Accumulator\":\n        return _Accumulator.get_class(server_args)(\n            server_args, expert_location_metadata, rank\n        )\n\n    @staticmethod\n    def get_class(server_args: ServerArgs) -> Type[\"_Accumulator\"]:\n        return {\n            \"stat\": 
_StatAccumulator,\n            \"stat_approx\": _StatAccumulator,\n            \"per_pass\": _DetailAccumulator,\n            \"per_token\": _DetailAccumulator,\n        }[server_args.expert_distribution_recorder_mode]\n\n    def __init__(\n        self,\n        server_args: ServerArgs,\n        expert_location_metadata: ExpertLocationMetadata,\n        rank: int,\n    ):\n        self._server_args = server_args\n        self._expert_location_metadata = expert_location_metadata\n        self._rank = rank\n\n    def get_single_pass_gatherer_keys(self):\n        return [_SINGLE_PASS_GATHERER_KEY_PRIMARY]\n\n    def get_single_pass_gatherer_key(self, debug_name: Optional[str]):\n        return _SINGLE_PASS_GATHERER_KEY_PRIMARY\n\n    def append(\n        self,\n        forward_pass_id: int,\n        gatherer_key: str,\n        single_pass_data: Dict,\n        outputs: Dict[str, Any],\n    ):\n        pass\n\n    def reset(self):\n        pass\n\n    def dump(self, output_mode: _OutputMode):\n        pass\n\n\nclass _UtilizationRateAccumulatorMixin(_Accumulator):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        self._enable = self._server_args.enable_expert_distribution_metrics\n\n        if self._enable:\n            self.window_sizes = [10, 100, 1000]\n            self._history = _DequeCollection(maxlens=self.window_sizes)\n            self._rank = torch.distributed.get_rank()\n            self._expert_dispatch_collector = ExpertDispatchCollector(\n                self._expert_location_metadata.ep_size\n            )\n            self._metric_heatmap_collection_counter = 0\n\n    def append(\n        self,\n        forward_pass_id: int,\n        gatherer_key: str,\n        single_pass_data: Dict,\n        outputs: Dict[str, Any],\n    ):\n        super().append(forward_pass_id, gatherer_key, single_pass_data, outputs)\n        if self._enable:\n            return self._append_utilization_rate(\n                
forward_pass_id, single_pass_data[\"global_physical_count\"], outputs\n            )\n\n    def reset(self):\n        super().reset()\n        if self._enable:\n            self._history.clear()\n\n    def _append_utilization_rate(\n        self,\n        forward_pass_id: int,\n        single_pass_global_physical_count: torch.Tensor,\n        outputs: Dict[str, Any],\n    ):\n        gpu_physical_count = compute_gpu_physical_count(\n            single_pass_global_physical_count,\n            num_gpu=self._expert_location_metadata.ep_size,\n        )\n        gpu_physical_count = gpu_physical_count.to(self._server_args.device)\n        torch.distributed.reduce(\n            gpu_physical_count, dst=0, op=torch.distributed.ReduceOp.SUM\n        )\n\n        if self._rank == 0:\n            self._handle_metric_eplb_heatmap(gpu_physical_count)\n\n            utilization_rate_gpu = torch.mean(\n                compute_utilization_rate(gpu_physical_count)\n            )\n            if envs.SGLANG_ENABLE_EPLB_BALANCEDNESS_METRIC.get():\n                outputs[\"metrics\"] = ExpertDistributionMetrics(\n                    eplb_balancedness=utilization_rate_gpu,\n                )\n            else:\n                # TODO maybe refactor this part to also avoid a `.item()` gpu->cpu sync\n                utilization_rate_cpu = utilization_rate_gpu.item()\n                self._history.append(utilization_rate_cpu)\n\n                gpu_physical_count_sum = gpu_physical_count.sum().item()\n\n                logger.info(\n                    f\"[Expert Balancedness] \"\n                    f\"forward_pass_id={forward_pass_id} \"\n                    f\"current_pass_balancedness={utilization_rate_cpu:.03f} \"\n                    f\"{''.join(f'last_{size}_average_balancedness={value:.03f} ' for size, value in self._history.mean().items())} \"\n                    
f\"gpu_physical_count_sum={gpu_physical_count_sum}\"\n                    # f\"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}\"\n                )\n\n    # TODO refactor\n    def _handle_metric_eplb_heatmap(self, gpu_physical_count: torch.Tensor):\n        # sglang:eplb_gpu_physical_count metric is disabled if SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL <= 0\n        interval = get_int_env_var(\"SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL\", 0)\n        if interval > 0 and self._metric_heatmap_collection_counter % interval == 0:\n            for layer_idx in range(self._expert_location_metadata.num_layers):\n                count_of_layer = (\n                    self._expert_dispatch_collector.eplb_gpu_physical_count.labels(\n                        layer=str(layer_idx)\n                    )\n                )\n                # Exclude the +Inf bucket.\n                assert (\n                    self._expert_location_metadata.ep_size\n                    == len(count_of_layer._buckets) - 1\n                ), f\"{self._expert_location_metadata.ep_size=}, {len(count_of_layer._buckets)=}\"\n                for gpu_rank in range(self._expert_location_metadata.ep_size):\n                    count = gpu_physical_count[layer_idx, gpu_rank]\n                    if count > 0:\n                        count_of_layer._sum.inc(count * gpu_rank)\n                        count_of_layer._buckets[gpu_rank].inc(count)\n        self._metric_heatmap_collection_counter += 1\n\n\nclass _DequeCollection:\n    def __init__(self, maxlens: List[int]):\n        self._dequeues = [deque(maxlen=maxlen) for maxlen in maxlens]\n\n    def append(self, value):\n        for d in self._dequeues:\n            d.append(value)\n\n    def clear(self):\n        for d in self._dequeues:\n            d.clear()\n\n    def mean(self) -> Dict[int, float]:\n        return {d.maxlen: sum(d) / len(d) for d in self._dequeues}\n\n\nclass 
_DetailAccumulator(_UtilizationRateAccumulatorMixin):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._records = []\n\n    def get_single_pass_gatherer_keys(self):\n        if False:  # TODO `server_args.enable_two_batch_overlap`\n            return [_SINGLE_PASS_GATHERER_KEY_PRIMARY, \"child_a\", \"child_b\"]\n        return super().get_single_pass_gatherer_keys()\n\n    def get_single_pass_gatherer_key(self, debug_name: Optional[str]):\n        if False:  # TODO `server_args.enable_two_batch_overlap`\n            return debug_name or _SINGLE_PASS_GATHERER_KEY_PRIMARY\n        return super().get_single_pass_gatherer_key(debug_name)\n\n    def append(\n        self,\n        forward_pass_id: int,\n        gatherer_key: str,\n        single_pass_data: Dict,\n        outputs: Dict[str, Any],\n    ):\n        super().append(forward_pass_id, gatherer_key, single_pass_data, outputs)\n\n        def _process_object(obj):\n            if isinstance(obj, torch.Tensor):\n                return obj.cpu().clone()\n            return obj\n\n        single_pass_data_processed = {\n            k: _process_object(v) for k, v in single_pass_data.items()\n        }\n\n        self._records.append(\n            dict(\n                forward_pass_id=forward_pass_id,\n                rank=self._rank,\n                gatherer_key=gatherer_key,\n                **single_pass_data_processed,\n            )\n        )\n\n    def reset(self):\n        super().reset()\n        self._records.clear()\n\n    def dump(self, output_mode: _OutputMode):\n        assert output_mode == \"file\"\n        output = dict(\n            records=self._records,\n            # NOTE: This may change during recording, so here we say it is the \"last\" one\n            last_physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,\n        )\n        _dump_to_file(\n            
f\"expert_distribution_recorder_{time.time()}_{self._rank}.pt\", output\n        )\n\n\nclass _StatAccumulator(_UtilizationRateAccumulatorMixin):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._global_physical_count_of_buffered_step = _Buffer.init_new(\n            item_shape=(\n                self._expert_location_metadata.num_layers,\n                # Cannot use local_physical_count to support select_experts\n                self._expert_location_metadata.num_physical_experts,\n            ),\n            buffer_size=self._server_args.expert_distribution_recorder_buffer_size,\n            dtype=torch.int32,\n            device=self._server_args.device,\n        )\n        self._first_dump = True\n\n    def append(\n        self,\n        forward_pass_id: int,\n        gatherer_key: str,\n        single_pass_data: Dict,\n        outputs: Dict[str, Any],\n    ):\n        super().append(forward_pass_id, gatherer_key, single_pass_data, outputs)\n        # Can optimize if overhead here is large\n        self._global_physical_count_of_buffered_step.append(\n            single_pass_data[\"global_physical_count\"]\n        )\n\n    def reset(self):\n        super().reset()\n        self._global_physical_count_of_buffered_step.reset()\n\n    def dump(self, output_mode: _OutputMode):\n        logical_count_of_buffered_step = _convert_global_physical_count_to_logical_count(\n            self._global_physical_count_of_buffered_step.get_all(),\n            num_layers=self._expert_location_metadata.num_layers,\n            num_logical_experts=self._expert_location_metadata.num_logical_experts,\n            physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,\n        )\n\n        if self._first_dump:\n            self._first_dump = False\n            torch.get_device_module().empty_cache()\n\n        torch.distributed.all_reduce(\n            logical_count_of_buffered_step, 
op=torch.distributed.ReduceOp.SUM\n        )\n\n        output = dict(\n            rank=self._rank,\n            logical_count=logical_count_of_buffered_step,\n            average_utilization_rate_over_window=self._get_global_average_utilization_rate(),\n        )\n\n        if output_mode == \"file\":\n            if self._rank == 0:\n                _dump_to_file(f\"expert_distribution_recorder_{time.time()}.pt\", output)\n        elif output_mode == \"object\":\n            return output\n        else:\n            raise NotImplementedError\n\n    def _get_global_average_utilization_rate(self):\n        if not self._enable or math.isclose(\n            self._server_args.eplb_min_rebalancing_utilization_threshold, 1.0\n        ):\n            return None\n\n        if self._rank == 0:\n            utilization_mean_rates = self._history.mean()\n            window_index = self.window_sizes[-1]\n            average_utilization_rate_over_window = (\n                utilization_mean_rates[window_index]\n                if window_index in utilization_mean_rates\n                else 0\n            )\n\n            avg_rate_tensor = torch.tensor(\n                [average_utilization_rate_over_window],\n                dtype=torch.float32,\n                device=\"cuda\",\n            )\n        else:\n            avg_rate_tensor = torch.empty(1, dtype=torch.float32, device=\"cuda\")\n        torch.distributed.broadcast(avg_rate_tensor, src=0)\n        return avg_rate_tensor.item()\n\n\ndef _dump_to_file(name, data):\n    save_dir = Path(envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.get())\n    path_output = save_dir / name\n    logger.info(f\"Write expert distribution to {path_output}\")\n    if not save_dir.exists():\n        save_dir.mkdir(parents=True, exist_ok=True)\n    torch.save(data, str(path_output))\n\n\nclass _Buffer:\n    @staticmethod\n    def init_new(item_shape: Tuple, buffer_size: int, dtype, device):\n        if buffer_size < 0:\n            return 
_InfiniteBuffer(item_shape, dtype=dtype, device=device)\n        else:\n            return _CircularBuffer(item_shape, buffer_size, dtype=dtype, device=device)\n\n    def append(self, value: torch.Tensor):\n        raise NotImplementedError\n\n    def get_all(self) -> torch.Tensor:\n        raise NotImplementedError\n\n    def reset(self):\n        raise NotImplementedError\n\n\nclass _CircularBuffer(_Buffer):\n    def __init__(self, item_shape: Tuple, buffer_size: int, dtype, device):\n        self._buffer = torch.zeros(\n            (buffer_size, *item_shape), dtype=dtype, device=device\n        )\n        self._curr_index = 0\n\n    def append(self, value: torch.Tensor):\n        self._buffer[self._curr_index] = value\n        self._curr_index = (self._curr_index + 1) % len(self._buffer)\n\n    def get_all(self) -> torch.Tensor:\n        return self._buffer\n\n    def reset(self):\n        self._buffer[...] = 0\n\n\nclass _InfiniteBuffer(_Buffer):\n    def __init__(self, item_shape: Tuple, dtype, device):\n        self._item_shape = item_shape\n        self._buffer = torch.zeros((128, *item_shape), dtype=dtype, device=device)\n        self._size = 0\n\n    def append(self, value: torch.Tensor):\n        curr_buffer_size = len(self._buffer)\n        dtype = self._buffer.dtype\n        device = self._buffer.device\n\n        if self._size == curr_buffer_size:\n            new_buffer = torch.zeros(\n                (2 * curr_buffer_size, *self._item_shape), dtype=dtype, device=device\n            )\n            new_buffer[:curr_buffer_size] = self._buffer\n            self._buffer = new_buffer\n\n        self._buffer[self._size] = value\n        self._size += 1\n\n    def get_all(self) -> torch.Tensor:\n        return self._buffer[: self._size]\n\n    def reset(self):\n        self._buffer[...] 
= 0\n        self._size = 0\n\n\ndef _convert_global_physical_count_to_logical_count(\n    # (whatever, num_layers, num_physical_experts)\n    global_physical_count: torch.Tensor,\n    num_layers: int,\n    num_logical_experts: int,\n    physical_to_logical_map: torch.Tensor,\n):\n    dim_extra, _, _ = global_physical_count.shape\n    dtype = global_physical_count.dtype\n    device = global_physical_count.device\n    logical_count = torch.zeros(\n        (dim_extra, num_layers, num_logical_experts), dtype=dtype, device=device\n    )\n    logical_count.scatter_add_(\n        dim=2,\n        index=physical_to_logical_map.unsqueeze(0)\n        .expand(dim_extra, -1, -1)\n        .to(torch.int64),\n        src=global_physical_count,\n    )\n    return logical_count\n\n\ndef compute_gpu_physical_count(\n    physical_count_of_whatever: torch.Tensor,  # (..., num_layer, num_physical_expert)\n    num_gpu: int,\n):\n    \"\"\"output: gpu_physical_count_of_batch (..., num_layer, num_gpu)\"\"\"\n    return einops.reduce(\n        physical_count_of_whatever,\n        \"... num_layer (num_gpu num_expert_per_gpu) -> ... num_layer num_gpu\",\n        \"sum\",\n        num_gpu=num_gpu,\n    )\n\n\ndef compute_utilization_rate(\n    gpu_physical_count_of_batch: torch.Tensor,  # (..., num_layer, num_gpu)\n):\n    \"\"\"output: utilization_rate (..., num_layer)\"\"\"\n    gpu_physical_count_of_batch = gpu_physical_count_of_batch.float()\n    max_gpu_physical_count = einops.reduce(\n        gpu_physical_count_of_batch,\n        \"... num_layer num_gpu -> ... num_layer\",\n        \"max\",\n    )\n    avg_gpu_physical_count = einops.reduce(\n        gpu_physical_count_of_batch,\n        \"... num_layer num_gpu -> ... num_layer\",\n        \"mean\",\n    )\n    return (avg_gpu_physical_count + 1e-5) / (max_gpu_physical_count + 1e-5)\n"
  },
  {
    "path": "python/sglang/srt/eplb/expert_location.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\nfrom __future__ import annotations\n\nimport json\nimport logging\nimport random\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, List, Optional\n\nimport torch\nimport torch.distributed\nimport torch.nn.functional as F\n\nfrom sglang.srt.eplb import eplb_algorithms\nfrom sglang.srt.model_loader import get_model_architecture\n\nif TYPE_CHECKING:\n    from sglang.srt.configs.model_config import ModelConfig\n    from sglang.srt.server_args import ServerArgs\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass ExpertLocationMetadata:\n    physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)\n    physical_to_logical_map_cpu: torch.Tensor\n    logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)\n    logical_to_all_physical_map_cpu: torch.Tensor  # CPU copy for performance\n    logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)\n    # (layers, num_logical_experts)\n    logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]\n\n    # -------------------------------- properties ------------------------------------\n\n    @property\n    def num_layers(self) -> int:\n        return self.physical_to_logical_map.shape[0]\n\n    @property\n    def 
num_physical_experts(self) -> int:\n        return self.physical_to_logical_map.shape[1]\n\n    @property\n    def num_local_physical_experts(self) -> int:\n        ans, remainder = divmod(self.num_physical_experts, self.ep_size)\n        assert remainder == 0\n        return ans\n\n    @property\n    def num_logical_experts(self) -> int:\n        return self.logical_to_all_physical_map.shape[1]\n\n    @property\n    def ep_size(self):\n        # TODO change when EP size != world size\n        return torch.distributed.get_world_size()\n\n    def __post_init__(self):\n        num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape\n        num_layers_1, num_logical_experts_0, num_physical_experts_1 = (\n            self.logical_to_all_physical_map.shape\n        )\n        num_layers_2, num_logical_experts_1 = (\n            self.logical_to_all_physical_map_num_valid.shape\n        )\n        assert num_layers_0 == num_layers_1 == num_layers_2\n        assert num_logical_experts_0 == num_logical_experts_1\n        assert num_physical_experts_0 == num_physical_experts_1\n\n    # -------------------------------- construction ------------------------------------\n\n    @staticmethod\n    def init_trivial(\n        server_args: ServerArgs, model_config: ModelConfig, moe_ep_rank: int\n    ):\n        \"\"\"Trivial location - logical expert i corresponds to physical expert i\"\"\"\n        common = ExpertLocationMetadata._init_common(server_args, model_config)\n\n        if common is None:\n            return None\n\n        num_physical_experts = common[\"num_physical_experts\"]\n        model_config_for_expert_location = common[\"model_config_for_expert_location\"]\n        num_layers = model_config_for_expert_location.num_layers\n        num_logical_experts = model_config_for_expert_location.num_logical_experts\n\n        physical_to_logical_map = (\n            torch.arange(0, num_physical_experts).repeat(num_layers, 1)\n            % 
num_logical_experts\n        )\n\n        return ExpertLocationMetadata.init_by_mapping(\n            server_args,\n            model_config,\n            physical_to_logical_map=physical_to_logical_map,\n            moe_ep_rank=moe_ep_rank,\n        )\n\n    @staticmethod\n    def init_by_mapping(\n        server_args: ServerArgs,\n        model_config: ModelConfig,\n        physical_to_logical_map,\n        moe_ep_rank: int = None,\n    ):\n        if not isinstance(physical_to_logical_map, torch.Tensor):\n            physical_to_logical_map = torch.tensor(physical_to_logical_map)\n        physical_to_logical_map = physical_to_logical_map.to(server_args.device)\n\n        common = ExpertLocationMetadata._init_common(server_args, model_config)\n\n        if common is None:\n            return None\n\n        model_config_for_expert_location = common[\"model_config_for_expert_location\"]\n        logical_to_all_physical_map = _compute_logical_to_all_physical_map(\n            server_args=server_args,\n            physical_to_logical_map=physical_to_logical_map,\n            num_logical_experts=model_config_for_expert_location.num_logical_experts,\n            ep_size=common[\"ep_size\"],\n            moe_ep_rank=moe_ep_rank,\n        )\n\n        return ExpertLocationMetadata._init_raw(\n            server_args=server_args,\n            ep_size=common[\"ep_size\"],\n            physical_to_logical_map=physical_to_logical_map,\n            logical_to_all_physical_map=logical_to_all_physical_map,\n        )\n\n    @staticmethod\n    def init_by_eplb(\n        server_args: ServerArgs, model_config: ModelConfig, logical_count: torch.Tensor\n    ):\n        if not isinstance(logical_count, torch.Tensor):\n            logical_count = torch.tensor(logical_count)\n        if len(logical_count.shape) == 2:\n            logical_count = logical_count.unsqueeze(0)\n        logical_count = logical_count.to(server_args.device)\n\n        common = 
ExpertLocationMetadata._init_common(server_args, model_config)\n\n        if common is None:\n            return None\n\n        model_config_for_expert_location = common[\"model_config_for_expert_location\"]\n        num_physical_experts = common[\"num_physical_experts\"]\n        num_groups = model_config_for_expert_location.num_groups\n        num_nodes = server_args.nnodes\n\n        physical_to_logical_map, logical_to_all_physical_map, expert_count = (\n            eplb_algorithms.rebalance_experts(\n                tokens_per_expert=logical_count,\n                num_physical_experts=num_physical_experts,\n                num_local_physical_experts=num_physical_experts // common[\"ep_size\"],\n                num_groups=num_groups,\n                num_nodes=num_nodes,\n                algorithm=eplb_algorithms.compute_algorithm(\n                    raw_algorithm=server_args.eplb_algorithm,\n                    num_groups=num_groups,\n                    num_nodes=num_nodes,\n                ),\n            )\n        )\n\n        return ExpertLocationMetadata._init_raw(\n            server_args=server_args,\n            ep_size=common[\"ep_size\"],\n            physical_to_logical_map=physical_to_logical_map.to(server_args.device),\n            logical_to_all_physical_map=logical_to_all_physical_map.to(\n                server_args.device\n            ),\n        )\n\n    @staticmethod\n    def _init_common(server_args: ServerArgs, model_config: ModelConfig):\n        model_config_for_expert_location = (\n            ModelConfigForExpertLocation.from_model_config(model_config)\n        )\n\n        if model_config_for_expert_location is None:\n            return None\n\n        num_physical_experts = (\n            model_config_for_expert_location.num_logical_experts\n            + server_args.ep_num_redundant_experts\n        )\n        ep_size = server_args.ep_size\n        assert num_physical_experts % ep_size == 0\n        num_local_physical_experts = 
num_physical_experts // ep_size\n\n        return dict(\n            model_config_for_expert_location=model_config_for_expert_location,\n            num_physical_experts=num_physical_experts,\n            num_local_physical_experts=num_local_physical_experts,\n            ep_size=ep_size,\n        )\n\n    @staticmethod\n    def _init_raw(\n        server_args: ServerArgs,\n        ep_size: int,\n        physical_to_logical_map: torch.Tensor,\n        logical_to_all_physical_map: torch.Tensor,\n    ):\n        _, num_physical_experts = physical_to_logical_map.shape\n\n        logical_to_all_physical_map_padded = F.pad(\n            logical_to_all_physical_map,\n            (0, num_physical_experts - logical_to_all_physical_map.shape[-1]),\n            value=-1,\n        )\n\n        logical_to_all_physical_map_num_valid = torch.count_nonzero(\n            logical_to_all_physical_map != -1, dim=-1\n        )\n\n        return ExpertLocationMetadata(\n            physical_to_logical_map=physical_to_logical_map,\n            physical_to_logical_map_cpu=physical_to_logical_map.cpu(),\n            logical_to_all_physical_map=logical_to_all_physical_map_padded,\n            logical_to_all_physical_map_cpu=logical_to_all_physical_map_padded.cpu(),\n            logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,\n            logical_to_rank_dispatch_physical_map=(\n                compute_logical_to_rank_dispatch_physical_map(\n                    server_args=server_args,\n                    logical_to_all_physical_map=logical_to_all_physical_map,\n                    ep_size=ep_size,\n                    num_physical_experts=num_physical_experts,\n                    # TODO improve when we have real EP rank\n                    ep_rank=torch.distributed.get_rank() % ep_size,\n                )\n                if server_args.ep_dispatch_algorithm == \"static\"\n                else None\n            ),\n        )\n\n    # 
-------------------------------- mutation ------------------------------------\n\n    def update(\n        self,\n        other: \"ExpertLocationMetadata\",\n        update_layer_ids: List[int],\n    ):\n        for field in [\n            \"ep_size\",\n        ]:\n            assert getattr(self, field) == getattr(other, field)\n\n        for field in [\n            \"physical_to_logical_map\",\n            \"physical_to_logical_map_cpu\",\n            \"logical_to_all_physical_map\",\n            \"logical_to_all_physical_map_cpu\",\n            \"logical_to_all_physical_map_num_valid\",\n            \"logical_to_rank_dispatch_physical_map\",\n        ]:\n            other_field = getattr(other, field)\n            self_field = getattr(self, field)\n            assert (other_field is not None) == (self_field is not None)\n            if self_field is not None:\n                mask_update = torch.tensor(\n                    [i in update_layer_ids for i in range(self.num_layers)]\n                )\n                mask_update = mask_update.view(*([-1] + [1] * (self_field.dim() - 1)))\n                mask_update = mask_update.to(self_field.device, non_blocking=True)\n                self_field[...] 
= torch.where(mask_update, other_field, self_field)\n\n    # -------------------------------- usage ------------------------------------\n\n    def logical_to_all_physical(\n        self,\n        layer_id: int,\n        logical_expert_id: int,\n        require_global_experts: bool = False,\n    ) -> List[int]:\n        # Use CPU copy to avoid GPU→CPU sync on every call, which is expensive in update weights scenario\n        if require_global_experts:\n            num_physical_experts = self.logical_to_all_physical_map_cpu[layer_id].shape[\n                -1\n            ]\n            return list(\n                range(logical_expert_id, num_physical_experts, self.num_logical_experts)\n            )\n        return [\n            physical_expert_id\n            for physical_expert_id in self.logical_to_all_physical_map_cpu[\n                layer_id, logical_expert_id\n            ].tolist()\n            if physical_expert_id != -1\n        ]\n\n\n_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None\n\n\ndef get_global_expert_location_metadata():\n    return _global_expert_location_metadata\n\n\ndef set_global_expert_location_metadata(value):\n    global _global_expert_location_metadata\n    assert _global_expert_location_metadata is None\n    _global_expert_location_metadata = value\n\n\ndef _compute_logical_to_all_physical_map(\n    server_args: ServerArgs,\n    physical_to_logical_map: torch.Tensor,\n    num_logical_experts: int,\n    ep_size: int,\n    moe_ep_rank: int,\n):\n    # This is rarely called, so we use for loops for maximum clarity\n\n    num_layers, num_physical_experts = physical_to_logical_map.shape\n\n    logical_to_all_physical_map = [\n        [[] for _ in range(num_logical_experts)] for _ in range(num_layers)\n    ]\n\n    # Find out the candidate physical experts for each logical expert on each layer\n    for layer_id in range(num_layers):\n        for physical_expert_id in range(num_physical_experts):\n            
logical_expert_id = physical_to_logical_map[\n                layer_id, physical_expert_id\n            ].item()\n            logical_to_all_physical_map[layer_id][logical_expert_id].append(\n                physical_expert_id\n            )\n\n    # Replace by the physical expert on local GPU or node if possible\n    if moe_ep_rank is not None:\n        num_gpus_per_node = server_args.ep_size // server_args.nnodes\n        num_local_gpu_physical_experts = num_physical_experts // ep_size\n        num_local_node_physical_experts = (\n            num_local_gpu_physical_experts * num_gpus_per_node\n        )\n        for layer_id in range(num_layers):\n            for logical_expert_id in range(num_logical_experts):\n                # Try to find the nearest physical expert\n                nearest_expert = _find_nearest_expert(\n                    candidate_physical_expert_ids=logical_to_all_physical_map[layer_id][\n                        logical_expert_id\n                    ],\n                    num_local_gpu_physical_experts=num_local_gpu_physical_experts,\n                    moe_ep_rank=moe_ep_rank,\n                    num_gpus_per_node=num_gpus_per_node,\n                    num_local_node_physical_experts=num_local_node_physical_experts,\n                )\n\n                # Replace by the nearest physical expert\n                if nearest_expert != -1:\n                    logical_to_all_physical_map[layer_id][logical_expert_id] = [\n                        nearest_expert\n                    ]\n\n    logical_to_all_physical_map = _pad_nested_array(\n        logical_to_all_physical_map, pad_value=-1\n    )\n\n    return torch.tensor(\n        logical_to_all_physical_map, device=physical_to_logical_map.device\n    )\n\n\ndef _pad_nested_array(arr, pad_value):\n    max_len = max(len(inner) for outer in arr for inner in outer)\n    padded = [\n        [inner + [pad_value] * (max_len - len(inner)) for inner in outer]\n        for outer in arr\n    ]\n    
return padded\n\n\n# TODO optimize performance (rewrite and/or run in separate process with overlap)\ndef compute_logical_to_rank_dispatch_physical_map(\n    server_args: ServerArgs,\n    logical_to_all_physical_map: torch.Tensor,\n    ep_size: int,\n    num_physical_experts: int,\n    ep_rank: int,\n    seed: int = 42,\n):\n    r = random.Random(seed)\n\n    num_local_gpu_physical_experts = num_physical_experts // ep_size\n    num_gpus_per_node = server_args.ep_size // server_args.nnodes\n    num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node\n    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape\n    dtype = logical_to_all_physical_map.dtype\n\n    logical_to_rank_dispatch_physical_map = torch.full(\n        size=(ep_size, num_layers, num_logical_experts),\n        fill_value=-1,\n        dtype=dtype,\n    )\n\n    for layer_id in range(num_layers):\n        for logical_expert_id in range(num_logical_experts):\n            candidate_physical_expert_ids = _logical_to_all_physical_raw(\n                logical_to_all_physical_map, layer_id, logical_expert_id\n            )\n            output_partial = logical_to_rank_dispatch_physical_map[\n                :, layer_id, logical_expert_id\n            ]\n\n            for moe_ep_rank in range(ep_size):\n                # Fill with the nearest physical expert\n                output_partial[moe_ep_rank] = _find_nearest_expert(\n                    candidate_physical_expert_ids=candidate_physical_expert_ids,\n                    num_local_gpu_physical_experts=num_local_gpu_physical_experts,\n                    moe_ep_rank=moe_ep_rank,\n                    num_gpus_per_node=num_gpus_per_node,\n                    num_local_node_physical_experts=num_local_node_physical_experts,\n                )\n\n            # Fill remaining slots with fair random choices\n            num_remain = torch.sum(output_partial == -1).item()\n            output_partial[output_partial 
== -1] = torch.tensor(\n                _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),\n                dtype=dtype,\n            )\n\n    assert torch.all(logical_to_rank_dispatch_physical_map != -1)\n\n    device = logical_to_all_physical_map.device\n    return logical_to_rank_dispatch_physical_map[ep_rank, :, :].to(device)\n\n\ndef _logical_to_all_physical_raw(\n    logical_to_all_physical_map, layer_id: int, logical_expert_id: int\n) -> List[int]:\n    return [\n        physical_expert_id\n        for physical_expert_id in logical_to_all_physical_map[\n            layer_id, logical_expert_id\n        ].tolist()\n        if physical_expert_id != -1\n    ]\n\n\ndef _compute_gpu_id_of_physical_expert(\n    physical_expert_id: int, num_local_gpu_physical_experts: int\n) -> int:\n    return physical_expert_id // num_local_gpu_physical_experts\n\n\ndef _compute_node_id_of_physical_expert(\n    physical_expert_id: int, num_local_host_physical_experts: int\n) -> int:\n    return physical_expert_id // num_local_host_physical_experts\n\n\ndef _find_nearest_expert(\n    candidate_physical_expert_ids: List[int],\n    num_local_gpu_physical_experts: int,\n    moe_ep_rank: int,\n    num_gpus_per_node: int,\n    num_local_node_physical_experts: int,\n) -> int:\n    # 1. If only one candidate, return it directly\n    if len(candidate_physical_expert_ids) == 1:\n        return candidate_physical_expert_ids[0]\n\n    # 2. Prefer same-GPU experts\n    same_gpu_physical_expert_ids = [\n        physical_expert_id\n        for physical_expert_id in candidate_physical_expert_ids\n        if _compute_gpu_id_of_physical_expert(\n            physical_expert_id, num_local_gpu_physical_experts\n        )\n        == moe_ep_rank\n    ]\n    if len(same_gpu_physical_expert_ids) > 0:\n        return same_gpu_physical_expert_ids[0]\n\n    # 3. 
Otherwise, prefer same-node experts\n    node_rank = moe_ep_rank // num_gpus_per_node\n    same_node_physical_expert_ids = [\n        physical_expert_id\n        for physical_expert_id in candidate_physical_expert_ids\n        if _compute_node_id_of_physical_expert(\n            physical_expert_id, num_local_node_physical_experts\n        )\n        == node_rank\n    ]\n    if len(same_node_physical_expert_ids) > 0:\n        return same_node_physical_expert_ids[0]\n\n    # 4. At last, leave it as -1 to indicate not found.\n    return -1\n\n\ndef _fair_choices(arr: List, k: int, r: random.Random) -> List:\n    quotient, remainder = divmod(k, len(arr))\n    ans = arr * quotient + r.sample(arr, k=remainder)\n    r.shuffle(ans)\n    return ans\n\n\n@dataclass\nclass ModelConfigForExpertLocation:\n    num_layers: int\n    num_logical_experts: int\n    num_groups: Optional[int] = None\n\n    @staticmethod\n    def from_model_config(model_config: ModelConfig):\n        model_class, _ = get_model_architecture(model_config)\n        if hasattr(model_class, \"get_model_config_for_expert_location\"):\n            return model_class.get_model_config_for_expert_location(\n                model_config.hf_config\n            )\n        else:\n            return None\n\n\ndef compute_initial_expert_location_metadata(\n    server_args: ServerArgs,\n    model_config: ModelConfig,\n    moe_ep_rank: int,\n) -> Optional[ExpertLocationMetadata]:\n    data = server_args.init_expert_location\n    if data == \"trivial\":\n        return ExpertLocationMetadata.init_trivial(\n            server_args, model_config, moe_ep_rank\n        )\n\n    # TODO unify with the utils function\n    if data.endswith(\".pt\"):\n        data_dict = torch.load(data, weights_only=True)\n    elif data.endswith(\".json\"):\n        data_dict = json.loads(Path(data).read_text())\n    else:\n        data_dict = json.loads(data)\n\n    if \"physical_to_logical_map\" in data_dict:\n        logger.info(\n            
\"init_expert_location from init_by_mapping using ServerArgs.init_expert_location\"\n        )\n        return ExpertLocationMetadata.init_by_mapping(\n            server_args,\n            model_config,\n            **data_dict,\n            moe_ep_rank=moe_ep_rank,\n        )\n    elif \"logical_count\" in data_dict:\n        logger.info(\n            \"init_expert_location from init_by_eplb using ServerArgs.init_expert_location\"\n        )\n        return ExpertLocationMetadata.init_by_eplb(\n            server_args, model_config, logical_count=data_dict[\"logical_count\"]\n        )\n    else:\n        raise NotImplementedError(\n            f\"Unknown init_expert_location format ({list(data_dict.keys())=})\"\n        )\n"
  },
  {
    "path": "python/sglang/srt/eplb/expert_location_dispatch.py",
    "content": "# Copyright 2023-2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\nfrom dataclasses import dataclass\nfrom typing import Literal, Optional\n\nimport torch\n\nfrom sglang.srt.eplb.expert_location import get_global_expert_location_metadata\nfrom sglang.srt.server_args import get_global_server_args\n\n\n@dataclass\nclass ExpertLocationDispatchInfo:\n    ep_dispatch_algorithm: Literal[\"static\", \"random\"]\n    # (num_logical_experts,)\n    partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]\n    # (num_logical_experts, X)\n    partial_logical_to_all_physical_map: torch.Tensor\n    # (num_logical_experts,)\n    partial_logical_to_all_physical_map_num_valid: torch.Tensor\n    num_physical_experts: int\n\n    @classmethod\n    def init_new(cls, layer_id: int):\n        ep_dispatch_algorithm = get_global_server_args().ep_dispatch_algorithm\n        expert_location_metadata = get_global_expert_location_metadata()\n        assert expert_location_metadata is not None\n\n        if ep_dispatch_algorithm is None:\n            return None\n\n        return cls(\n            ep_dispatch_algorithm=ep_dispatch_algorithm,\n            partial_logical_to_rank_dispatch_physical_map=(\n                expert_location_metadata.logical_to_rank_dispatch_physical_map[\n                    layer_id, :\n                ]\n                if 
expert_location_metadata.logical_to_rank_dispatch_physical_map\n                is not None\n                else None\n            ),\n            partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[\n                layer_id, :\n            ],\n            partial_logical_to_all_physical_map_num_valid=expert_location_metadata.logical_to_all_physical_map_num_valid[\n                layer_id, :\n            ],\n            num_physical_experts=expert_location_metadata.num_physical_experts,\n        )\n\n\ndef transform_select_experts_inputs(\n    router_logits: torch.Tensor,\n    correction_bias: Optional[torch.Tensor],\n    info: Optional[ExpertLocationDispatchInfo],\n):\n    if (info is not None) and (info.ep_dispatch_algorithm == \"fake\"):\n        router_logits.uniform_(5, 10)\n        if correction_bias is not None:\n            correction_bias = torch.zeros_like(correction_bias)\n    return router_logits, correction_bias\n\n\ndef topk_ids_logical_to_physical(\n    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]\n) -> torch.Tensor:\n    if info is None:\n        return topk_ids\n\n    if info.ep_dispatch_algorithm == \"static\":\n        return _topk_ids_logical_to_physical_static(topk_ids, info)\n    if info.ep_dispatch_algorithm in [\"dynamic\", \"fake\"]:\n        return _topk_ids_logical_to_physical_dynamic(topk_ids, info)\n    raise NotImplementedError(f\"Unknown algorithm {info.ep_dispatch_algorithm}\")\n\n\ndef _topk_ids_logical_to_physical_static(\n    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]\n) -> torch.Tensor:\n    return info.partial_logical_to_rank_dispatch_physical_map[topk_ids]\n\n\ndef _topk_ids_logical_to_physical_dynamic(\n    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]\n) -> torch.Tensor:\n    topk_ids_original_shape = topk_ids.shape\n    device = topk_ids.device\n    topk_ids = topk_ids.flatten()\n\n    chosen_dispatch_index = (\n        
torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device)\n        % info.partial_logical_to_all_physical_map_num_valid[topk_ids]\n    )\n    topk_ids = info.partial_logical_to_all_physical_map[topk_ids, chosen_dispatch_index]\n\n    topk_ids = topk_ids.view(topk_ids_original_shape)\n    return topk_ids\n"
  },
  {
    "path": "python/sglang/srt/eplb/expert_location_updater.py",
    "content": "# Copyright 2023-2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\nimport logging\nfrom collections import defaultdict\nfrom typing import Dict, List, Optional, Tuple\n\nimport einops\nimport torch\nimport torch.distributed\nfrom torch.distributed import P2POp\n\nfrom sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager\nfrom sglang.srt.eplb.expert_location import (\n    ExpertLocationMetadata,\n    get_global_expert_location_metadata,\n)\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.utils import get_bool_env_var\n\nlogger = logging.getLogger(__name__)\n\n\n_LOG_INPUT = get_bool_env_var(\"SGLANG_EXPERT_LOCATION_UPDATER_LOG_INPUT\")\n\n\nclass ExpertLocationUpdater:\n    def __init__(self):\n        self._first_execution = True\n\n    def update(\n        self,\n        routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],\n        new_expert_location_metadata: ExpertLocationMetadata,\n        update_layer_ids: List[int],\n        nnodes: int,\n        rank: int,\n    ):\n        \"\"\"\n        Update experts' physical location after EPLB.\n\n        Returns a map of layer_id to expert_ids that are missing due to rank\n        failures during fault conditions when elastic EP is enabled.\n        \"\"\"\n        if self._first_execution:\n            self._first_execution = False\n            
torch.get_device_module().empty_cache()\n\n        old_expert_location_metadata = get_global_expert_location_metadata()\n        assert old_expert_location_metadata is not None\n\n        missing_logical_experts_by_layers = _update_expert_weights(\n            routed_experts_weights_of_layer=routed_experts_weights_of_layer,\n            old_expert_location_metadata=old_expert_location_metadata,\n            new_expert_location_metadata=new_expert_location_metadata,\n            update_layer_ids=update_layer_ids,\n            nnodes=nnodes,\n            rank=rank,\n        )\n        old_expert_location_metadata.update(\n            new_expert_location_metadata,\n            update_layer_ids=update_layer_ids,\n        )\n\n        return missing_logical_experts_by_layers\n\n\ndef _update_expert_weights(**kwargs):\n    if get_bool_env_var(\"SGLANG_EXPERT_LOCATION_UPDATER_CANARY\"):\n        return _update_expert_weights_with_canary(**kwargs)\n    else:\n        return _update_expert_weights_raw(**kwargs)\n\n\n# can add watchdog as well\ndef _update_expert_weights_with_canary(\n    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],\n    old_expert_location_metadata: ExpertLocationMetadata,\n    new_expert_location_metadata: ExpertLocationMetadata,\n    update_layer_ids: List[int],\n    nnodes: int,\n    rank: int,\n):\n    num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts\n\n    def _get_canary_value(meta: ExpertLocationMetadata, layer_id: int):\n        return meta.physical_to_logical_map_cpu[\n            layer_id,\n            num_local_physical_experts * rank : num_local_physical_experts * (rank + 1),\n        ]\n\n    routed_experts_weights_of_layer = {\n        k: [x for x in v] for k, v in routed_experts_weights_of_layer.items()\n    }\n    for layer_id in update_layer_ids:\n        canary_tensor = (\n            _get_canary_value(old_expert_location_metadata, layer_id)\n            .clone()\n            
.to(device=get_global_server_args().device, non_blocking=True)\n        )\n        routed_experts_weights_of_layer[layer_id].append(canary_tensor)\n\n    missing_logical_experts_by_layers = _update_expert_weights_raw(\n        routed_experts_weights_of_layer=routed_experts_weights_of_layer,\n        old_expert_location_metadata=old_expert_location_metadata,\n        new_expert_location_metadata=new_expert_location_metadata,\n        update_layer_ids=update_layer_ids,\n        nnodes=nnodes,\n        rank=rank,\n    )\n\n    for layer_id in update_layer_ids:\n        # can optimize speed if needed\n        expect_value = _get_canary_value(new_expert_location_metadata, layer_id)\n        actual_value = routed_experts_weights_of_layer[layer_id][-1].cpu()\n        assert torch.all(expect_value == actual_value), (\n            f\"{expect_value=} {actual_value=} {layer_id=} \"\n            f\"{old_expert_location_metadata.physical_to_logical_map_cpu.tolist()=} \"\n            f\"{new_expert_location_metadata.physical_to_logical_map_cpu.tolist()=} \"\n        )\n\n    return missing_logical_experts_by_layers\n\n\ndef _update_expert_weights_raw(\n    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],\n    old_expert_location_metadata: ExpertLocationMetadata,\n    new_expert_location_metadata: ExpertLocationMetadata,\n    update_layer_ids: List[int],\n    nnodes: int,\n    rank: int,\n):\n    log_metrics = get_bool_env_var(\"SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS\")\n\n    temp_buffers = create_temp_buffers(\n        routed_experts_weights_of_layer[update_layer_ids[0]]\n    )\n\n    world_size = torch.distributed.get_world_size()\n    num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts\n    num_gpu_per_node = world_size // nnodes\n\n    missing_logical_experts_by_layers: Dict[int, List[int]] = {}\n\n    for layer_id in update_layer_ids:\n        missing_logical_experts_info: List[int] = []\n        
update_expert_weights_single_layer(\n            routed_experts_weights=routed_experts_weights_of_layer[layer_id],\n            temp_buffers=temp_buffers,\n            old_physical_to_logical_map=old_expert_location_metadata.physical_to_logical_map_cpu[\n                layer_id\n            ].tolist(),\n            new_physical_to_logical_map=new_expert_location_metadata.physical_to_logical_map_cpu[\n                layer_id\n            ].tolist(),\n            num_local_physical_experts=num_local_physical_experts,\n            num_gpu_per_node=num_gpu_per_node,\n            rank=rank,\n            world_size=world_size,\n            missing_logical_experts_info=missing_logical_experts_info,\n            log_metrics=log_metrics,\n        )\n        if len(missing_logical_experts_info) > 0:\n            missing_logical_experts_by_layers[layer_id] = missing_logical_experts_info\n    return missing_logical_experts_by_layers\n\n\ndef create_temp_buffers(sample_tensors):\n    return [torch.empty_like(tensor) for tensor in sample_tensors]\n\n\ndef update_expert_weights_single_layer(\n    routed_experts_weights: List[torch.Tensor],\n    temp_buffers: List[torch.Tensor],\n    old_physical_to_logical_map: List[int],  # (num_physical_Experts,)\n    new_physical_to_logical_map: List[int],  # (num_physical_Experts,)\n    num_local_physical_experts: int,\n    num_gpu_per_node: int,\n    rank: int,\n    world_size: Optional[int] = None,\n    missing_logical_experts_info: Optional[List[int]] = None,\n    debug: bool = False,\n    log_metrics: bool = False,\n):\n    assert all(\n        tensor.shape[0] == num_local_physical_experts\n        for tensor in routed_experts_weights\n    ), f\"{num_local_physical_experts=} {[x.shape for x in routed_experts_weights]=}\"\n    assert isinstance(old_physical_to_logical_map, list)\n    assert isinstance(new_physical_to_logical_map, list)\n\n    if _LOG_INPUT:\n        logger.info(\n            \"update_expert_weights_single_layer \"\n      
      f\"{[x.shape for x in routed_experts_weights]=} \"\n            f\"{[x.shape for x in temp_buffers]=} \"\n            f\"{old_physical_to_logical_map=} \"\n            f\"{new_physical_to_logical_map=} \"\n            f\"{num_local_physical_experts=} \"\n            f\"{num_gpu_per_node=} \"\n            f\"{rank=} \"\n            f\"{world_size=} \"\n        )\n\n    output_logs = [] if debug else None\n\n    num_physical_experts = len(old_physical_to_logical_map)\n    num_tensors = len(routed_experts_weights)\n\n    self_node_id = rank // num_gpu_per_node\n\n    local_expert_location_range = (\n        rank * num_local_physical_experts,\n        (rank + 1) * num_local_physical_experts,\n    )\n\n    def _entrypoint():\n        # List[Tuple[logical_expert_id, List[P2POp]]]\n        p2p_op_infos: List[Tuple[int, List[P2POp]]] = []\n        # List[Tuple[temp_buffers_expert_location, routed_experts_weights_expert_location]]\n        buffer2weight_copy_infos: List[Tuple[int, int]] = []\n\n        _handle_recv(buffer2weight_copy_infos, p2p_op_infos)\n        _create_isend_ops(p2p_op_infos)\n        _filter_p2p_ops(p2p_op_infos)\n        _execute_p2p_ops(p2p_op_infos)\n        _execute_buffer2weight_copies(buffer2weight_copy_infos)\n\n        if log_metrics:\n            _log_p2p_op_metrics(\n                p2p_op_infos,\n                world_size=world_size,\n                num_gpu_per_node=num_gpu_per_node,\n                self_node_id=self_node_id,\n            )\n\n        if debug:\n            output_logs.append(f\"{p2p_op_infos=}\")\n            output_logs.append(f\"{buffer2weight_copy_infos=}\")\n\n    def _handle_recv(buffer2weight_copy_infos, p2p_op_infos):\n        for dst_expert_location in range(*local_expert_location_range):\n            _handle_recv_of_dst_expert_location(\n                dst_expert_location, buffer2weight_copy_infos, p2p_op_infos\n            )\n\n    def _handle_recv_of_dst_expert_location(\n        dst_expert_location: int, 
buffer2weight_copy_infos, p2p_op_infos\n    ):\n        logical_expert_id = new_physical_to_logical_map[dst_expert_location]\n\n        # case 1: unchanged\n        if old_physical_to_logical_map[dst_expert_location] == logical_expert_id:\n            if debug:\n                output_logs.append(\n                    f\"handle_recv_of_dst_expert_location {dst_expert_location=} case=unchanged\"\n                )\n            return\n\n        # case 2: same-gpu\n        for src_expert_location in range(*local_expert_location_range):\n            if old_physical_to_logical_map[src_expert_location] == logical_expert_id:\n                for i in range(num_tensors):\n                    _get_tensor(temp_buffers, i, dst_expert_location).copy_(\n                        _get_tensor(routed_experts_weights, i, src_expert_location)\n                    )\n                buffer2weight_copy_infos.append(\n                    (dst_expert_location, dst_expert_location)\n                )\n                if debug:\n                    output_logs.append(\n                        f\"handle_recv_of_dst_expert_location {dst_expert_location=} case=same-gpu {src_expert_location=}\"\n                    )\n                return\n\n        # case 3: free-rider\n        for src_expert_location in range(\n            rank * num_local_physical_experts, dst_expert_location\n        ):\n            if new_physical_to_logical_map[src_expert_location] == logical_expert_id:\n                buffer2weight_copy_infos.append(\n                    (src_expert_location, dst_expert_location)\n                )\n                if debug:\n                    output_logs.append(\n                        f\"handle_recv_of_dst_expert_location {dst_expert_location=} case=free-rider {src_expert_location=}\"\n                    )\n                return\n\n        same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks = (\n            
_compute_comm_info(logical_expert_id=logical_expert_id)\n        )\n\n        # case 4: same-node\n        if rank in need_comm_self_node_dst_ranks:\n            chosen_src_rank = same_node_mapping.chunk_value_from_element_value(\n                element_value=rank\n            )\n            _create_p2p_recv_and_buffer2weight_copy(\n                buffer2weight_copy_infos,\n                p2p_op_infos,\n                src_rank=chosen_src_rank,\n                logical_expert_id=logical_expert_id,\n                dst_expert_location=dst_expert_location,\n            )\n            if debug:\n                output_logs.append(\n                    f\"handle_recv_of_dst_expert_location {dst_expert_location=} case=same-node {chosen_src_rank=}\"\n                )\n            return\n\n        # case 5: cross-node\n        # Future work: this can be optimized when multiple ranks in the same dst node use the same logical expert\n        chosen_src_rank = cross_node_mapping.chunk_value_from_element_value(\n            element_value=rank\n        )\n        _create_p2p_recv_and_buffer2weight_copy(\n            buffer2weight_copy_infos,\n            p2p_op_infos,\n            src_rank=chosen_src_rank,\n            logical_expert_id=logical_expert_id,\n            dst_expert_location=dst_expert_location,\n        )\n        if debug:\n            output_logs.append(\n                f\"handle_recv_of_dst_expert_location {dst_expert_location=} case=cross-node {chosen_src_rank=}\"\n            )\n        return\n\n    def _create_p2p_recv_and_buffer2weight_copy(\n        buffer2weight_copy_infos,\n        p2p_op_infos,\n        *,\n        logical_expert_id: int,\n        src_rank: int,\n        dst_expert_location: int,\n    ):\n        p2p_op_infos.append(\n            (\n                logical_expert_id,\n                [\n                    P2POp(\n                        op=torch.distributed.irecv,\n                        tensor=_get_tensor(temp_buffers, 
i, dst_expert_location),\n                        peer=src_rank,\n                    )\n                    for i in range(num_tensors)\n                ],\n            )\n        )\n        buffer2weight_copy_infos.append((dst_expert_location, dst_expert_location))\n\n    def _create_isend_ops(p2p_op_infos):\n        handled_logical_expert_ids = set()\n        for src_expert_location in range(*local_expert_location_range):\n            logical_expert_id = old_physical_to_logical_map[src_expert_location]\n\n            if logical_expert_id in handled_logical_expert_ids:\n                continue\n            handled_logical_expert_ids.add(logical_expert_id)\n\n            _create_isend_ops_of_logical_expert_id(\n                logical_expert_id, src_expert_location, p2p_op_infos\n            )\n\n    def _create_isend_ops_of_logical_expert_id(\n        logical_expert_id, src_expert_location, p2p_op_infos\n    ):\n        same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks = (\n            _compute_comm_info(logical_expert_id=logical_expert_id)\n        )\n\n        same_node_dst_ranks = same_node_mapping.element_values_from_chunk_value(\n            chunk_value=rank\n        )\n        cross_node_dst_ranks = cross_node_mapping.element_values_from_chunk_value(\n            chunk_value=rank\n        )\n        all_dst_ranks = same_node_dst_ranks + cross_node_dst_ranks\n\n        if debug:\n            output_logs.append(\n                f\"create_isend_ops_of_logical_expert_id {logical_expert_id=} {src_expert_location=} {same_node_dst_ranks=} {cross_node_dst_ranks=}\"\n            )\n\n        p2p_op_infos.append(\n            (\n                logical_expert_id,\n                [\n                    P2POp(\n                        op=torch.distributed.isend,\n                        tensor=_get_tensor(\n                            routed_experts_weights, i, src_expert_location\n                        ),\n                        
peer=dst_rank,\n                    )\n                    for dst_rank in all_dst_ranks\n                    for i in range(num_tensors)\n                ],\n            )\n        )\n\n    def _compute_comm_info(logical_expert_id: int):\n        all_src_ranks = _deduplicate_ordered(\n            [\n                x // num_local_physical_experts\n                for x in range(num_physical_experts)\n                if old_physical_to_logical_map[x] == logical_expert_id\n            ]\n        )\n        all_src_nodes = [x // num_gpu_per_node for x in all_src_ranks]\n        self_node_src_ranks = [\n            x for x in all_src_ranks if x // num_gpu_per_node == self_node_id\n        ]\n\n        need_comm_dst_ranks = _deduplicate_ordered(\n            [\n                x // num_local_physical_experts\n                for x in range(num_physical_experts)\n                if new_physical_to_logical_map[x] == logical_expert_id\n                and x // num_local_physical_experts not in all_src_ranks\n            ]\n        )\n        need_comm_self_node_dst_ranks = (\n            [x for x in need_comm_dst_ranks if x // num_gpu_per_node == self_node_id]\n            if len(self_node_src_ranks) > 0\n            else []\n        )\n        need_comm_cross_node_dst_ranks = [\n            x\n            for x in need_comm_dst_ranks\n            if (x // num_gpu_per_node) not in all_src_nodes\n        ]\n\n        same_node_mapping = _ChunkUtils(\n            chunk_values=self_node_src_ranks,\n            element_values=need_comm_self_node_dst_ranks,\n        )\n\n        cross_node_mapping = _ChunkUtils(\n            chunk_values=all_src_ranks,\n            element_values=need_comm_cross_node_dst_ranks,\n        )\n\n        return same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks\n\n    def _filter_p2p_ops(p2p_op_infos):\n        elastic_ep_state = ElasticEPStateManager.instance()\n        if elastic_ep_state is not None and 
missing_logical_experts_info is not None:\n            # Filter out inactive P2P ops and record missing expert IDs in missing_logical_experts_info\n            is_active = elastic_ep_state.active_ranks_cpu\n            for i, (logical_expert_id, ops) in enumerate(p2p_op_infos):\n                has_isend = any(op.op == torch.distributed.isend for op in ops)\n                has_irecv = any(op.op == torch.distributed.irecv for op in ops)\n                assert not (has_isend and has_irecv), (\n                    \"Each p2p_op_infos entry is expected to contain only send \"\n                    \"or only recv ops.\"\n                )\n\n                if has_isend:\n                    p2p_op_infos[i] = (\n                        logical_expert_id,\n                        [op for op in ops if is_active[op.peer]],\n                    )\n                elif has_irecv:\n                    if any(not is_active[op.peer] for op in ops):\n                        missing_logical_experts_info.append(logical_expert_id)\n                        p2p_op_infos[i] = (logical_expert_id, [])\n\n    def _execute_p2p_ops(p2p_op_infos):\n        sorted_infos = sorted(p2p_op_infos, key=lambda info: info[0])\n        p2p_ops = [op for _, ops in sorted_infos for op in ops]\n        if len(p2p_ops) == 0:\n            return\n\n        reqs = torch.distributed.batch_isend_irecv(p2p_ops)\n        for req in reqs:\n            req.wait()\n\n    def _execute_buffer2weight_copies(buffer2weight_copy_infos):\n        for (\n            temp_buffers_expert_location,\n            routed_experts_weights_expert_location,\n        ) in buffer2weight_copy_infos:\n            for i in range(num_tensors):\n                _get_tensor(\n                    routed_experts_weights, i, routed_experts_weights_expert_location\n                ).copy_(_get_tensor(temp_buffers, i, temp_buffers_expert_location))\n\n    def _get_tensor(tensors, tensor_index: int, expert_location: int) -> torch.Tensor:\n     
   return tensors[tensor_index][_get_local_expert_location(expert_location)]\n\n    def _get_local_expert_location(expert_location: int) -> int:\n        assert (\n            local_expert_location_range[0]\n            <= expert_location\n            < local_expert_location_range[1]\n        )\n        return expert_location % num_local_physical_experts\n\n    _entrypoint()\n\n    return output_logs\n\n\nclass _ChunkUtils:\n    def __init__(self, *, chunk_values: List, element_values: List):\n        self.chunk_values = chunk_values\n        self.element_values = element_values\n\n    def chunk_value_from_element_value(self, element_value):\n        chunk_index = self._chunk_index_from_element_index(\n            num_elements=len(self.element_values),\n            num_chunks=len(self.chunk_values),\n            element_index=self.element_values.index(element_value),\n        )\n        return self.chunk_values[chunk_index]\n\n    def element_values_from_chunk_value(self, chunk_value) -> List:\n        if len(self.element_values) == 0:\n            return []\n        element_slice = self._element_slice_from_chunk_index(\n            num_elements=len(self.element_values),\n            num_chunks=len(self.chunk_values),\n            chunk_index=self.chunk_values.index(chunk_value),\n        )\n        return self.element_values[element_slice]\n\n    @staticmethod\n    def _chunk_index_from_element_index(\n        num_elements: int, num_chunks: int, element_index: int\n    ) -> int:\n        short_chunk_size, num_long_chunks = divmod(num_elements, num_chunks)\n        num_elements_for_long_chunks = num_long_chunks * (short_chunk_size + 1)\n        if element_index < num_elements_for_long_chunks:\n            return element_index // (short_chunk_size + 1)\n        else:\n            return (\n                num_long_chunks\n                + (element_index - num_elements_for_long_chunks) // short_chunk_size\n            )\n\n    @staticmethod\n    def 
_element_slice_from_chunk_index(\n        num_elements: int, num_chunks: int, chunk_index: int\n    ) -> slice:\n        short_chunk_size, num_long_chunks = divmod(num_elements, num_chunks)\n        start = chunk_index * short_chunk_size + min(chunk_index, num_long_chunks)\n        end = start + short_chunk_size + int(chunk_index < num_long_chunks)\n        return slice(start, end)\n\n\ndef _deduplicate_ordered(arr: List[int]):\n    output = []\n    for item in arr:\n        if len(output) == 0 or item != output[-1]:\n            output.append(item)\n    return output\n\n\ndef _log_p2p_op_metrics(\n    p2p_op_infos: List[Tuple[int, List[P2POp]]],\n    num_gpu_per_node: int,\n    world_size: int,\n    self_node_id: int,\n):\n    text = \"\"\n    all_ops = [op for _, ops in p2p_op_infos for op in ops]\n\n    for direction, ops in _group_by(all_ops, _get_direction_from_op).items():\n        nbytes_of_gpu = [0] * world_size\n        for op in ops:\n            nbytes_of_gpu[op.peer] += op.tensor.nbytes\n        nbytes_of_gpu = torch.tensor(nbytes_of_gpu, dtype=torch.int64)\n\n        nbytes_of_node = einops.reduce(\n            nbytes_of_gpu,\n            \"(num_nodes num_gpu_per_node) -> num_nodes\",\n            num_gpu_per_node=num_gpu_per_node,\n            reduction=\"sum\",\n        )\n\n        nbytes_curr_node = nbytes_of_node[self_node_id]\n        nbytes_cross_node = torch.sum(nbytes_of_node) - nbytes_curr_node\n\n        text += (\n            f\"{direction}_nbytes_of_gpu={nbytes_of_gpu.tolist()} \"\n            f\"{direction}_nbytes_of_node={nbytes_of_node.tolist()} \"\n            f\"{direction}_nbytes_curr_node={nbytes_curr_node.item()} \"\n            f\"{direction}_nbytes_cross_node={nbytes_cross_node.item()} \"\n        )\n\n    logger.info(f\"[ExpertLocationUpdater] {text}\")\n\n\ndef _get_direction_from_op(op: P2POp):\n    if op.op == torch.distributed.isend:\n        return \"isend\"\n    if op.op == torch.distributed.irecv:\n        return 
\"irecv\"\n    raise NotImplementedError\n\n\ndef _group_by(items, keyfunc):\n    ans = defaultdict(list)\n    for item in items:\n        ans[keyfunc(item)].append(item)\n    return dict(ans)\n"
  },
  {
    "path": "python/sglang/srt/function_call/base_format_detector.py",
    "content": "import json\nimport logging\nfrom abc import ABC, abstractmethod\nfrom typing import Any, Dict, List\n\nimport orjson\nfrom partial_json_parser.core.exceptions import MalformedJSON\nfrom partial_json_parser.core.options import Allow\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.environ import envs\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    ToolCallItem,\n    _GetInfoFunc,\n)\nfrom sglang.srt.function_call.utils import (\n    _find_common_prefix,\n    _is_complete_json,\n    _partial_json_loads,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass BaseFormatDetector(ABC):\n    \"\"\"Base class providing two sets of interfaces: one-time and streaming incremental.\"\"\"\n\n    def __init__(self):\n        # Streaming state management\n        # Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks\n        self._buffer = \"\"\n        # Stores complete tool call info (name and arguments) for each tool being parsed.\n        # Used by serving layer for completion handling when streaming ends.\n        # Format: [{\"name\": str, \"arguments\": dict}, ...]\n        self.prev_tool_call_arr: List[Dict] = []\n        # Index of currently streaming tool call. Starts at -1 (no active tool),\n        # increments as each tool completes. Tracks which tool's arguments are streaming.\n        self.current_tool_id: int = -1\n        # Flag for whether current tool's name has been sent to client.\n        # Tool names sent first with empty parameters, then arguments stream incrementally.\n        self.current_tool_name_sent: bool = False\n        # Tracks raw JSON string content streamed to client for each tool's arguments.\n        # Critical for serving layer to calculate remaining content when streaming ends.\n        # Each index corresponds to a tool_id. 
Example: ['{\"location\": \"San Francisco\"', '{\"temp\": 72']\n        self.streamed_args_for_tool: List[str] = []\n\n        # Token configuration (override in subclasses)\n        self.bot_token = \"\"\n        self.eot_token = \"\"\n        self.tool_call_separator = \", \"\n\n    def _get_tool_indices(self, tools: List[Tool]) -> Dict[str, int]:\n        \"\"\"\n        Get a mapping of tool names to their indices in the tools list.\n\n        This utility method creates a dictionary mapping function names to their\n        indices in the tools list, which is commonly needed for tool validation\n        and ToolCallItem creation.\n\n        Args:\n            tools: List of available tools\n\n        Returns:\n            Dictionary mapping tool names to their indices\n        \"\"\"\n        return {\n            tool.function.name: i for i, tool in enumerate(tools) if tool.function.name\n        }\n\n    def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:\n        tool_indices = self._get_tool_indices(tools)\n        if not isinstance(action, list):\n            action = [action]\n\n        results = []\n        for act in action:\n            name = act.get(\"name\")\n            if not (name and name in tool_indices):\n                logger.warning(f\"Model attempted to call undefined function: {name}\")\n                if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():\n                    continue  # Skip unknown tools (default legacy behavior)\n\n            results.append(\n                ToolCallItem(\n                    tool_index=tool_indices.get(name, -1),\n                    name=name,\n                    parameters=json.dumps(\n                        act.get(\"parameters\") or act.get(\"arguments\", {}),\n                        ensure_ascii=False,\n                    ),\n                )\n            )\n\n        return results\n\n    @abstractmethod\n    def detect_and_parse(self, text: str, tools: List[Tool]) 
-> StreamingParseResult:\n        \"\"\"\n        Parses the text in one go and returns a StreamingParseResult with the parsed tool calls.\n        Note that normal_text in the result represents \"content that this parser will not consume further\".\n        \"\"\"\n        action = orjson.loads(text)\n        return StreamingParseResult(calls=self.parse_base_json(action, tools))\n\n    def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:\n        \"\"\"\n        Check if buffer ends with a partial bot_token.\n        Return the length of the partial bot_token.\n\n        For some formats, the bot_token is not a single token in the model's vocabulary, such as\n        `[TOOL_CALLS] [` in Mistral.\n        \"\"\"\n        for i in range(1, min(len(buffer) + 1, len(bot_token))):\n            if bot_token.startswith(buffer[-i:]):\n                return i\n        return 0\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming incremental parsing with tool validation.\n\n        This base implementation works best with formats where:\n        1. bot_token is followed immediately by JSON (e.g., bot_token + JSON_array)\n        2. JSON can be parsed incrementally using partial_json_loads\n        3. Multiple tool calls are separated by \"; \" or \", \"\n\n        Examples of incompatible formats (need custom implementation, may reuse some logic from this class):\n        - Each tool call is wrapped in a separate block: See Qwen25Detector\n        - Multiple separate blocks: [TOOL_CALLS] [...] 
\\n [TOOL_CALLS] [...]\n        - Tool call is Pythonic style\n\n        For incompatible formats, detectors should override this method with custom logic.\n        \"\"\"\n        # Append new text to buffer\n        self._buffer += new_text\n        current_text = self._buffer\n\n        # The current_text has tool_call if it is the start of a new tool call sequence\n        # or it is the start of a new tool call after a tool call separator, when there is a previous tool call\n        if not (\n            self.has_tool_call(current_text)\n            or (\n                self.current_tool_id > 0\n                and current_text.startswith(self.tool_call_separator)\n            )\n        ):\n            # Only clear buffer if we're sure no tool call is starting\n            if not self._ends_with_partial_token(self._buffer, self.bot_token):\n                normal_text = self._buffer\n                self._buffer = \"\"\n                if self.eot_token in normal_text:\n                    normal_text = normal_text.replace(self.eot_token, \"\")\n                return StreamingParseResult(normal_text=normal_text)\n            else:\n                # Might be partial bot_token, keep buffering\n                return StreamingParseResult()\n\n        # Build tool indices if not already built\n        if not hasattr(self, \"_tool_indices\"):\n            self._tool_indices = self._get_tool_indices(tools)\n\n        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR\n\n        try:\n            try:\n                # Priority check: if we're processing a subsequent tool (current_tool_id > 0),\n                # first check if text starts with the tool separator. 
This is critical for\n                # parallel tool calls because the bot_token (e.g., '[') can also\n                # appear inside array parameters of the current tool, and we must not\n                # mistakenly identify that as the start of a new tool.\n                if self.current_tool_id > 0 and current_text.startswith(\n                    self.tool_call_separator\n                ):\n                    start_idx = len(self.tool_call_separator)\n                else:\n                    # Only search for bot_token if not processing subsequent tool\n                    tool_call_pos = current_text.find(self.bot_token)\n                    if tool_call_pos != -1:\n                        start_idx = tool_call_pos + len(self.bot_token)\n                    else:\n                        start_idx = 0\n\n                if start_idx >= len(current_text):\n                    return StreamingParseResult()\n\n                obj, end_idx = _partial_json_loads(current_text[start_idx:], flags)\n\n                is_current_complete = _is_complete_json(\n                    current_text[start_idx : start_idx + end_idx]\n                )\n\n                # Validate tool name if present\n                if \"name\" in obj and obj[\"name\"] not in self._tool_indices:\n                    # Invalid tool name - reset state\n                    self._buffer = \"\"\n                    self.current_tool_id = -1\n                    self.current_tool_name_sent = False\n                    if self.streamed_args_for_tool:\n                        self.streamed_args_for_tool.pop()\n                    return StreamingParseResult()\n\n                # Handle parameters/arguments consistency\n                # NOTE: we assume here that the obj is always partial of a single tool call\n                if \"parameters\" in obj:\n                    assert (\n                        \"arguments\" not in obj\n                    ), \"model generated both parameters and 
arguments\"\n                    obj[\"arguments\"] = obj[\"parameters\"]\n\n                current_tool_call = obj\n\n            except MalformedJSON:\n                return StreamingParseResult()\n\n            if not current_tool_call:\n                return StreamingParseResult()\n\n            # Case 1: Handle tool name streaming\n            # This happens when we encounter a tool but haven't sent its name yet\n            if not self.current_tool_name_sent:\n                function_name = current_tool_call.get(\"name\")\n\n                if function_name and function_name in self._tool_indices:\n                    # If this is a new tool (current_tool_id was -1), initialize it\n                    if self.current_tool_id == -1:\n                        self.current_tool_id = 0\n                        self.streamed_args_for_tool.append(\"\")\n                    # If this is a subsequent tool, ensure streamed_args_for_tool is large enough\n                    elif self.current_tool_id >= len(self.streamed_args_for_tool):\n                        while len(self.streamed_args_for_tool) <= self.current_tool_id:\n                            self.streamed_args_for_tool.append(\"\")\n\n                    # Send the tool name with empty parameters\n                    res = StreamingParseResult(\n                        calls=[\n                            ToolCallItem(\n                                tool_index=self.current_tool_id,\n                                name=function_name,\n                                parameters=\"\",\n                            )\n                        ],\n                    )\n                    self.current_tool_name_sent = True\n                else:\n                    res = StreamingParseResult()\n\n            # Case 2: Handle streaming arguments\n            # This happens when we've already sent the tool name and now need to stream arguments incrementally\n            else:\n                cur_arguments = 
current_tool_call.get(\"arguments\")\n                res = StreamingParseResult()\n\n                if cur_arguments:\n                    # Calculate how much of the arguments we've already streamed\n                    sent = len(self.streamed_args_for_tool[self.current_tool_id])\n                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)\n                    prev_arguments = None\n                    if self.current_tool_id < len(self.prev_tool_call_arr):\n                        prev_arguments = self.prev_tool_call_arr[\n                            self.current_tool_id\n                        ].get(\"arguments\")\n\n                    argument_diff = None\n\n                    # If the current tool's JSON is complete, send all remaining arguments\n                    if is_current_complete:\n                        argument_diff = cur_args_json[sent:]\n                        completing_tool_id = (\n                            self.current_tool_id\n                        )  # Save the ID of the tool that's completing\n\n                        # Only remove the processed portion, keep unprocessed content\n                        self._buffer = current_text[start_idx + end_idx :]\n\n                    # If the tool is still being parsed, send incremental changes\n                    elif prev_arguments:\n                        prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)\n                        if cur_args_json != prev_args_json:\n                            prefix = _find_common_prefix(prev_args_json, cur_args_json)\n                            argument_diff = prefix[sent:]\n\n                    # Update prev_tool_call_arr with current state\n                    if self.current_tool_id >= 0:\n                        # Ensure prev_tool_call_arr is large enough\n                        while len(self.prev_tool_call_arr) <= self.current_tool_id:\n                            self.prev_tool_call_arr.append({})\n    
                    self.prev_tool_call_arr[self.current_tool_id] = (\n                            current_tool_call\n                        )\n\n                    # Advance to next tool if complete\n                    if is_current_complete:\n                        self.current_tool_name_sent = False\n                        self.current_tool_id += 1\n\n                    # Send the argument diff if there's something new\n                    if argument_diff is not None:\n                        # Use the correct tool_index: completing_tool_id for completed tools, current_tool_id for ongoing\n                        tool_index_to_use = (\n                            completing_tool_id\n                            if is_current_complete\n                            else self.current_tool_id\n                        )\n                        res = StreamingParseResult(\n                            calls=[\n                                ToolCallItem(\n                                    tool_index=tool_index_to_use,\n                                    parameters=argument_diff,\n                                )\n                            ],\n                        )\n                        self.streamed_args_for_tool[tool_index_to_use] += argument_diff\n\n            return res\n\n        except Exception as e:\n            logger.error(f\"Error in parse_streaming_increment: {e}\")\n            return StreamingParseResult()\n\n    @abstractmethod\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"\n        Check if the given text contains function call markers specific to this format.\n        \"\"\"\n        raise NotImplementedError()\n\n    def supports_structural_tag(self) -> bool:\n        \"\"\"Return True if this detector supports structural tag format.\"\"\"\n        return True\n\n    @abstractmethod\n    def structure_info(self) -> _GetInfoFunc:\n        \"\"\"\n        Return a function that creates StructureInfo for constrained 
generation.\n\n        The returned function takes a tool name and returns a StructureInfo object\n        containing the begin/end patterns and trigger tokens needed for constrained\n        generation of function calls in this format.\n\n        Returns:\n            A function that takes a tool name (str) and returns StructureInfo\n        \"\"\"\n        raise NotImplementedError()\n"
  },
  {
    "path": "python/sglang/srt/function_call/core_types.py",
    "content": "from dataclasses import dataclass\nfrom typing import Callable, List, Optional\n\nfrom pydantic import BaseModel\n\n\nclass ToolCallItem(BaseModel):\n    \"\"\"Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts.\"\"\"\n\n    tool_index: int\n    name: Optional[str] = None\n    parameters: str  # JSON string\n\n\nclass StreamingParseResult(BaseModel):\n    \"\"\"Result of streaming incremental parsing.\"\"\"\n\n    normal_text: str = \"\"\n    calls: List[ToolCallItem] = []\n\n\n@dataclass\nclass StructureInfo:\n    begin: str\n    end: str\n    trigger: str\n\n\n\"\"\"\nHelper alias of function\nUsually it is a function that takes a name string and returns a StructureInfo object,\nwhich can be used to construct a structural_tag object\n\"\"\"\n_GetInfoFunc = Callable[[str], StructureInfo]\n"
  },
  {
    "path": "python/sglang/srt/function_call/deepseekv31_detector.py",
    "content": "import json\nimport logging\nimport re\nfrom typing import List\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    StructureInfo,\n    ToolCallItem,\n    _GetInfoFunc,\n)\nfrom sglang.srt.function_call.utils import _is_complete_json\n\nlogger = logging.getLogger(__name__)\n\n\nclass DeepSeekV31Detector(BaseFormatDetector):\n    \"\"\"\n    Detector for DeepSeek V3 model function call format.\n\n    The DeepSeek V3 format uses special Unicode tokens to delimit function calls\n    with JSON code blocks for arguments.\n\n    Format Structure:\n    ```\n    <｜tool▁calls▁begin｜><｜tool▁call▁begin｜>{function_name}<｜tool▁sep｜>{json_arguments}<｜tool▁calls▁end｜><｜end▁of▁sentence｜>\n    ```\n    Examples:\n    ```\n    <｜tool▁calls▁begin｜><｜tool▁call▁begin｜>get_current_weather<｜tool▁sep｜>{\"location\": \"Tokyo\"}<｜tool▁call▁end｜><｜tool▁call▁begin｜>get_current_weather<｜tool▁sep｜>{\"location\": \"Paris\"}<｜tool▁call▁end｜><｜tool▁calls▁end｜><｜end▁of▁sentence｜>\n    ```\n\n    Key Components:\n    - Tool Calls Section: Wrapped between `<｜tool▁calls▁begin｜>` and `<｜tool▁calls▁end｜>`\n    - Individual Tool Call: Wrapped between `<｜tool▁call▁begin｜>` and `<｜tool▁call▁end｜>`\n    - Function Declaration: `<｜tool▁call▁begin｜>{function_name}<｜tool▁sep｜>`\n    - Arguments: JSON code block between `<｜tool▁sep｜>` and `<｜tool▁call▁end｜>`\n    - Supports multiple tool calls\n\n    Reference: https://www.modelscope.cn/models/deepseek-ai/DeepSeek-V3.1\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.bot_token = \"<｜tool▁calls▁begin｜>\"\n        self.eot_token = \"<｜tool▁calls▁end｜>\"\n        self.func_call_regex = r\"<｜tool▁call▁begin｜>.*?<｜tool▁call▁end｜>\"\n        self.func_detail_regex = (\n            r\"<｜tool▁call▁begin｜>(.*)<｜tool▁sep｜>(.*)<｜tool▁call▁end｜>\"\n        )\n  
      self._last_arguments = \"\"\n        self.current_tool_id = -1\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if the text contains a deepseek format tool call.\"\"\"\n        return self.bot_token in text\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"\n        One-time parsing: Detects and parses tool calls in the provided text.\n\n        :param text: The complete text to parse.\n        :param tools: List of available tools.\n        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.\n        \"\"\"\n        idx = text.find(self.bot_token)\n        normal_text = text[:idx].strip() if idx != -1 else text\n        if self.bot_token not in text:\n            return StreamingParseResult(normal_text=normal_text, calls=[])\n        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)\n        calls = []\n        try:\n            for match_result in match_result_list:\n                # Get function name\n                func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)\n                func_name = func_detail.group(1)\n                func_args = func_detail.group(2)\n                func_args = json.loads(func_args)\n                # construct match_result for parse_base_json\n                match_result = {\"name\": func_name, \"parameters\": func_args}\n                calls.extend(self.parse_base_json(match_result, tools))\n            return StreamingParseResult(normal_text=normal_text, calls=calls)\n        except Exception as e:\n            logger.error(f\"Error in detect_and_parse: {e}\")\n            # return the normal text if parsing fails\n            return StreamingParseResult(normal_text=text)\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming incremental parsing tool calls 
for DeepSeekV3 format.\n        \"\"\"\n        self._buffer += new_text\n        current_text = self._buffer\n\n        # Check if we have a tool call (either the start token or individual tool call)\n        has_tool_call = (\n            self.bot_token in current_text or \"<｜tool▁call▁begin｜>\" in current_text\n        )\n\n        if not has_tool_call:\n            self._buffer = \"\"\n            for e_token in [self.eot_token, \"<｜tool▁call▁end｜>\"]:\n                if e_token in new_text:\n                    new_text = new_text.replace(e_token, \"\")\n            return StreamingParseResult(normal_text=new_text)\n\n        if not hasattr(self, \"_tool_indices\"):\n            self._tool_indices = self._get_tool_indices(tools)\n\n        calls: list[ToolCallItem] = []\n        try:\n            partial_match = re.search(\n                pattern=r\"<｜tool▁call▁begin｜>(.*)<｜tool▁sep｜>(.*?)(<｜tool▁call▁end｜>|$)\",\n                string=current_text,\n                flags=re.DOTALL,\n            )\n            if partial_match:\n                func_name = partial_match.group(1).strip()\n                func_args_raw = partial_match.group(2).strip()\n                is_tool_end = partial_match.group(3)\n\n                # Initialize state if this is the first tool call\n                if self.current_tool_id == -1:\n                    self.current_tool_id = 0\n                    self.prev_tool_call_arr = []\n                    self.streamed_args_for_tool = [\"\"]\n\n                # Ensure we have enough entries in our tracking arrays\n                while len(self.prev_tool_call_arr) <= self.current_tool_id:\n                    self.prev_tool_call_arr.append({})\n                while len(self.streamed_args_for_tool) <= self.current_tool_id:\n                    self.streamed_args_for_tool.append(\"\")\n\n                if not self.current_tool_name_sent:\n                    calls.append(\n                        ToolCallItem(\n                   
         tool_index=self.current_tool_id,\n                            name=func_name,\n                            parameters=\"\",\n                        )\n                    )\n                    self.current_tool_name_sent = True\n                    # Store the tool call info for serving layer completions endpoint\n                    self.prev_tool_call_arr[self.current_tool_id] = {\n                        \"name\": func_name,\n                        \"arguments\": {},\n                    }\n                else:\n                    argument_diff = (\n                        func_args_raw[len(self._last_arguments) :]\n                        if func_args_raw.startswith(self._last_arguments)\n                        else func_args_raw\n                    )\n\n                    if argument_diff:\n                        calls.append(\n                            ToolCallItem(\n                                tool_index=self.current_tool_id,\n                                name=None,\n                                parameters=argument_diff,\n                            )\n                        )\n                        self._last_arguments += argument_diff\n                        self.streamed_args_for_tool[\n                            self.current_tool_id\n                        ] += argument_diff\n\n                    if _is_complete_json(func_args_raw):\n                        # Update the stored arguments\n                        try:\n                            parsed_args = json.loads(func_args_raw)\n                            self.prev_tool_call_arr[self.current_tool_id][\n                                \"arguments\"\n                            ] = parsed_args\n                        except json.JSONDecodeError:\n                            pass\n\n                        # Find the end of the current tool call and remove only that part from buffer\n                        if is_tool_end:\n                            # Remove 
the completed tool call from buffer, keep any remaining content\n                            self._buffer = current_text[partial_match.end(3) :]\n                        else:\n                            self._buffer = \"\"\n\n                        result = StreamingParseResult(normal_text=\"\", calls=calls)\n                        self.current_tool_id += 1\n                        self._last_arguments = \"\"\n                        self.current_tool_name_sent = False\n                        return result\n\n            return StreamingParseResult(normal_text=\"\", calls=calls)\n\n        except Exception as e:\n            logger.error(f\"Error in parse_streaming_increment: {e}\")\n            return StreamingParseResult(normal_text=current_text)\n\n    def structure_info(self) -> _GetInfoFunc:\n        return lambda name: StructureInfo(\n            begin=\"<｜tool▁call▁begin｜>\" + name + \"<｜tool▁sep｜>\",\n            end=\"<｜tool▁call▁end｜>\",\n            trigger=\"<｜tool▁call▁begin｜>\",\n        )\n"
  },
  {
    "path": "python/sglang/srt/function_call/deepseekv32_detector.py",
    "content": "import json\nimport logging\nimport re\n\nfrom partial_json_parser.core.options import Allow\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    StructureInfo,\n    ToolCallItem,\n    _GetInfoFunc,\n)\nfrom sglang.srt.function_call.utils import _find_common_prefix, _partial_json_loads\n\nlogger = logging.getLogger(__name__)\n\n\nclass DeepSeekV32Detector(BaseFormatDetector):\n    \"\"\"\n    Detector for DeepSeek V3.2 model function call format.\n\n    The DeepSeek V3.2 format uses XML-like DSML tags to delimit function calls.\n    Supports two parameter formats:\n\n    Format 1 - XML Parameter Tags:\n    ```\n    <｜DSML｜function_calls>\n        <｜DSML｜invoke name=\"function_name\">\n        <｜DSML｜parameter name=\"param_name\" string=\"true\">value</｜DSML｜parameter>\n        ...\n    </｜DSML｜invoke>\n    </｜DSML｜function_calls>\n    ```\n\n    Format 2 - Direct JSON:\n    ```\n    <｜DSML｜function_calls>\n        <｜DSML｜invoke name=\"function_name\">\n        {\n            \"param_name\": \"value\"\n        }\n    </｜DSML｜invoke>\n    </｜DSML｜function_calls>\n    ```\n\n    Examples:\n    ```\n    <｜DSML｜function_calls>\n        <｜DSML｜invoke name=\"get_favorite_tourist_spot\">\n        <｜DSML｜parameter name=\"city\" string=\"true\">San Francisco</｜DSML｜parameter>\n    </｜DSML｜invoke>\n    </｜DSML｜function_calls>\n\n    <｜DSML｜function_calls>\n        <｜DSML｜invoke name=\"get_favorite_tourist_spot\">\n        { \"city\": \"San Francisco\" }\n    </｜DSML｜invoke>\n    </｜DSML｜function_calls>\n    ```\n\n    Key Components:\n    - Tool Calls Section: Wrapped between `<｜DSML｜function_calls>` and `</｜DSML｜function_calls>`\n    - Individual Tool Call: Wrapped between `<｜DSML｜invoke name=\"...\">` and `</｜DSML｜invoke>`\n    - Parameters: Either XML tags or direct JSON format\n    - Supports 
multiple tool calls\n\n    Reference: DeepSeek V3.2 format specification\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.bot_token = \"<｜DSML｜function_calls>\"\n        self.eot_token = \"</｜DSML｜function_calls>\"\n        self.invoke_end_token = \"</｜DSML｜invoke>\"\n        self.parameter_regex = r'<｜DSML｜parameter\\s+name=\"([^\"]+)\"\\s+string=\"([^\"]+)\"\\s*>(.*?)</｜DSML｜parameter>'\n        self.partial_parameter_regex = (\n            r'<｜DSML｜parameter\\s+name=\"([^\"]+)\"\\s+string=\"([^\"]+)\"\\s*>(.*)$'\n        )\n        self.function_calls_regex = (\n            r\"<｜DSML｜function_calls>(.*?)</｜DSML｜function_calls>\"\n        )\n        self.invoke_regex = (\n            r'<｜DSML｜invoke\\s+name=\"([^\"]+)\"\\s*>(.*?)(</｜DSML｜invoke>|$)'\n        )\n        self.prefix_parameter_end_call = [\"</\", \"｜DSML｜\", \"parameter\"]\n        self.prefix_invoke_end_call = [\"</\", \"｜DSML｜\", \"inv\", \"oke\"]\n        self.current_tool_id = -1\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if the text contains a deepseek v32 format tool call.\"\"\"\n        return self.bot_token in text or \"<｜DSML｜invoke\" in text\n\n    def _parse_parameters_from_xml(\n        self, invoke_content: str, allow_partial: bool = False\n    ) -> str:\n        \"\"\"\n        Parse parameters from either XML-like format or JSON format to str.\n\n        Supports two formats:\n        1. XML parameter tags: <｜DSML｜parameter name=\"...\" string=\"...\">value</｜DSML｜parameter>\n        2. 
Direct JSON: { \"key\": \"value\" }\n        \"\"\"\n        # First, try to parse as direct JSON (new format)\n        invoke_content_stripped = invoke_content.strip()\n        if invoke_content_stripped.startswith(\"{\"):\n            if allow_partial:\n                # Remove incomplete invoke end call prefix in case they are captured by param\n                for token in reversed(self.prefix_invoke_end_call):\n                    invoke_content_stripped = invoke_content_stripped.rstrip(token)\n                return invoke_content_stripped\n            elif invoke_content_stripped.endswith(\"}\"):\n                return invoke_content_stripped\n\n        # Fall back to XML parameter tag parsing (original format)\n        parameters = {}\n        # Find all complete parameter matches\n        param_matches = list(\n            re.finditer(self.parameter_regex, invoke_content, re.DOTALL)\n        )\n\n        last_match_end = 0\n        for match in param_matches:\n            param_name = match.group(1)\n            param_type = match.group(2)\n            param_value = match.group(3)\n            last_match_end = match.end()\n\n            # Convert value based on type\n            if param_type == \"true\":  # string type\n                parameters[param_name] = param_value.strip()\n            else:\n                # Try to parse as JSON for other types\n                try:\n                    parameters[param_name] = json.loads(param_value.strip())\n                except (json.JSONDecodeError, ValueError):\n                    parameters[param_name] = param_value.strip()\n\n        # If allowed, try to parse a partial parameter at the end\n        if allow_partial:\n            remaining_content = invoke_content[last_match_end:]\n\n            # Remove incomplete parameter_end_call prefix in case they are captured by param\n            for token in reversed(self.prefix_parameter_end_call):\n                remaining_content = 
remaining_content.rstrip(token)\n\n            # Match start of a parameter tag + value (potentially incomplete)\n            # Regex: <tag name=\"...\" string=\"...\">VALUE... (no end tag)\n            partial_match = re.search(\n                self.partial_parameter_regex, remaining_content, re.DOTALL\n            )\n\n            if partial_match and (param_value := partial_match.group(3)):\n                param_name = partial_match.group(1)\n                if partial_match.group(2) == \"true\":\n                    parameters[param_name] = param_value.strip()\n                else:\n                    try:\n                        parameters[param_name] = _partial_json_loads(\n                            param_value, Allow.ALL\n                        )[0]\n                    except json.JSONDecodeError:\n                        parameters[param_name] = param_value.strip()\n\n        return json.dumps(parameters, ensure_ascii=False)\n\n    def detect_and_parse(self, text: str, tools: list[Tool]) -> StreamingParseResult:\n        \"\"\"\n        One-time parsing: Detects and parses tool calls in the provided text.\n\n        :param text: The complete text to parse.\n        :param tools: List of available tools.\n        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.\n        \"\"\"\n        idx = text.find(self.bot_token)\n        normal_text = text[:idx].strip() if idx != -1 else text\n        if self.bot_token not in text:\n            return StreamingParseResult(normal_text=normal_text, calls=[])\n\n        calls = []\n        try:\n            # Extract content between function_calls tags\n            function_calls_match = re.search(\n                self.function_calls_regex,\n                text,\n                re.DOTALL,\n            )\n            if not function_calls_match:\n                return StreamingParseResult(normal_text=normal_text, calls=[])\n\n            function_calls_content 
= function_calls_match.group(1)\n\n            # Find all invoke blocks\n            invoke_matches = re.findall(\n                self.invoke_regex, function_calls_content, re.DOTALL\n            )\n\n            for func_name, invoke_content, _ in invoke_matches:\n                # Parse parameters from XML format\n                func_args = self._parse_parameters_from_xml(invoke_content)\n                # construct match_result for parse_base_json\n                match_result = {\"name\": func_name, \"parameters\": json.loads(func_args)}\n                calls.extend(self.parse_base_json(match_result, tools))\n\n            return StreamingParseResult(normal_text=normal_text, calls=calls)\n        except Exception as e:\n            logger.error(f\"Error in detect_and_parse: {e}\")\n            # return the normal text if parsing fails\n            return StreamingParseResult(normal_text=text)\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: list[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming incremental parsing tool calls for DeepSeekV32 format.\n        Supports multiple consecutive invoke blocks and argument streaming.\n        \"\"\"\n        self._buffer += new_text\n        current_text = self._buffer\n\n        # Check if buffer contains any DSML markers or ends with potential tag prefix\n        # This handles partial/streaming DSML content\n        dsml_markers = [\"｜DSML｜\", \"<｜\", \"</｜\"]\n        potentially_dsml = any(marker in current_text for marker in dsml_markers)\n\n        # Also check if text ends with start of a tag (to handle \"<\" arriving separately)\n        dsml_prefixes = [\"<\", \"<｜\", \"</\", \"</｜\"]\n        ends_with_prefix = any(\n            current_text.rstrip().endswith(prefix) for prefix in dsml_prefixes\n        )\n\n        if (\n            not self.has_tool_call(current_text)\n            and not potentially_dsml\n            and not ends_with_prefix\n        ):\n  
          self._buffer = \"\"\n            for e_token in [self.eot_token, self.invoke_end_token]:\n                if e_token in current_text:\n                    current_text = current_text.replace(e_token, \"\")\n            return StreamingParseResult(normal_text=current_text)\n\n        all_calls: list[ToolCallItem] = []\n        try:\n            # Loop to handle multiple consecutive invoke blocks\n            while True:\n                # Try to match an invoke block (may be partial)\n                invoke_match = re.search(\n                    pattern=self.invoke_regex,\n                    string=current_text,\n                    flags=re.DOTALL,\n                )\n                if not invoke_match:\n                    break\n\n                func_name = invoke_match.group(1).strip()\n                invoke_content = invoke_match.group(2)\n                # group(3) is either \"</｜DSML｜invoke>\" (complete) or \"\" (incomplete, matched with $)\n                is_tool_end = bool(invoke_match.group(3))\n\n                # Initialize state if this is the first tool call\n                if self.current_tool_id == -1:\n                    self.current_tool_id = 0\n                    self.prev_tool_call_arr = []\n                    self.streamed_args_for_tool = [\"\"]\n\n                # Ensure arrays are large enough for current tool\n                while len(self.prev_tool_call_arr) <= self.current_tool_id:\n                    self.prev_tool_call_arr.append({})\n                while len(self.streamed_args_for_tool) <= self.current_tool_id:\n                    self.streamed_args_for_tool.append(\"\")\n\n                # 1. 
Send tool name if not sent yet\n                if not self.current_tool_name_sent:\n                    all_calls.append(\n                        ToolCallItem(\n                            tool_index=self.current_tool_id,\n                            name=func_name,\n                            parameters=\"\",\n                        )\n                    )\n                    self.current_tool_name_sent = True\n\n                # 2. Parse current parameters (partial or complete)\n                current_params = self._parse_parameters_from_xml(\n                    invoke_content, allow_partial=not is_tool_end\n                )\n\n                # 3. Calculate and send incremental arguments\n                sent_len = len(self.streamed_args_for_tool[self.current_tool_id])\n                prev_params = self.prev_tool_call_arr[self.current_tool_id].get(\n                    \"arguments\"\n                )\n\n                argument_diff = None\n\n                if is_tool_end:\n                    # If complete, send everything remaining\n                    argument_diff = current_params[sent_len:]\n                elif prev_params is not None:\n                    # If partial, send stable prefix diff\n                    if current_params != prev_params:\n                        prefix = _find_common_prefix(current_params, prev_params)\n                        if len(prefix) > sent_len:\n                            argument_diff = prefix[sent_len:]\n\n                if argument_diff:\n                    all_calls.append(\n                        ToolCallItem(\n                            tool_index=self.current_tool_id,\n                            name=None,\n                            parameters=argument_diff,\n                        )\n                    )\n                    self.streamed_args_for_tool[self.current_tool_id] += argument_diff\n\n                # Update the stored arguments\n                
self.prev_tool_call_arr[self.current_tool_id] = {\n                    \"name\": func_name,\n                    \"arguments\": current_params,\n                }\n\n                # Check if tool call is complete (has closing tag)\n                if is_tool_end:\n                    # Remove the completed tool call from buffer\n                    self._buffer = current_text[invoke_match.end() :]\n                    current_text = self._buffer  # Update for next iteration\n\n                    # Move to next tool call\n                    self.current_tool_id += 1\n                    self.current_tool_name_sent = False\n\n                    # Continue loop to check for more invoke blocks\n                    continue\n                else:\n                    # Tool call not complete yet, don't return anything\n                    # Wait for more chunks until we see </｜DSML｜invoke>\n                    break\n\n            # No more invoke blocks found\n            return StreamingParseResult(normal_text=\"\", calls=all_calls)\n\n        except Exception as e:\n            logger.error(f\"Error in parse_streaming_increment: {e}\")\n            return StreamingParseResult(normal_text=current_text)\n\n    def structure_info(self) -> _GetInfoFunc:\n        return lambda name: StructureInfo(\n            begin=f'<｜DSML｜invoke name=\"{name}\">',\n            end=\"</｜DSML｜invoke>\",\n            trigger=\"<｜DSML｜invoke\",\n        )\n"
  },
  {
    "path": "python/sglang/srt/function_call/deepseekv3_detector.py",
    "content": "import json\nimport logging\nimport re\nfrom typing import List\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    StructureInfo,\n    ToolCallItem,\n    _GetInfoFunc,\n)\nfrom sglang.srt.function_call.utils import _is_complete_json\n\nlogger = logging.getLogger(__name__)\n\n\nclass DeepSeekV3Detector(BaseFormatDetector):\n    \"\"\"\n    Detector for DeepSeek V3 model function call format.\n\n    The DeepSeek V3 format uses special Unicode tokens to delimit function calls\n    with JSON code blocks for arguments.\n\n    Format Structure:\n    ```\n    <｜tool▁calls▁begin｜><｜tool▁call▁begin｜>function<｜tool▁sep｜>{function_name}\\n```json\\n{json_arguments}\\n```<｜tool▁calls▁end｜><｜end▁of▁sentence｜>\n    ```\n    Examples:\n    ```\n    <｜tool▁calls▁begin｜><｜tool▁call▁begin｜>function<｜tool▁sep｜>get_current_weather\\n```json\\n{\"location\": \"Tokyo\"}\\n```<｜tool▁call▁end｜>\\n<｜tool▁call▁begin｜>function<｜tool▁sep｜>get_current_weather\\n```json\\n{\"location\": \"Paris\"}\\n```<｜tool▁call▁end｜><｜tool▁calls▁end｜><｜end▁of▁sentence｜>\n    ```\n\n    Key Components:\n    - Tool Calls Section: Wrapped between `<｜tool▁calls▁begin｜>` and `<｜tool▁calls▁end｜>`\n    - Individual Tool Call: Wrapped between `<｜tool▁call▁begin｜>` and `<｜tool▁call▁end｜>`\n    - Function Declaration: `function<｜tool▁sep｜>{function_name}`\n    - Arguments: JSON code block between ````json` and ````\n    - Supports multiple tool calls\n\n    Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-0324?chat_template=default\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.bot_token = \"<｜tool▁calls▁begin｜>\"\n        self.eot_token = \"<｜tool▁calls▁end｜>\"\n        self.func_call_regex = r\"<｜tool▁call▁begin｜>.*?<｜tool▁call▁end｜>\"\n        self.func_detail_regex = 
r\"<｜tool▁call▁begin｜>(.*)<｜tool▁sep｜>(.*)\\n```json\\n(.*)\\n```<｜tool▁call▁end｜>\"\n        self._last_arguments = \"\"\n        self.current_tool_id = -1\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if the text contains a deepseek format tool call.\"\"\"\n        return self.bot_token in text\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"\n        One-time parsing: Detects and parses tool calls in the provided text.\n\n        :param text: The complete text to parse.\n        :param tools: List of available tools.\n        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.\n        \"\"\"\n        idx = text.find(self.bot_token)\n        normal_text = text[:idx].strip() if idx != -1 else text\n        if self.bot_token not in text:\n            return StreamingParseResult(normal_text=normal_text, calls=[])\n        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)\n        calls = []\n        try:\n            for match_result in match_result_list:\n                # Get function name\n                func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)\n                func_name = func_detail.group(2)\n                func_args = func_detail.group(3)\n                func_args = json.loads(func_args)\n                # construct match_result for parse_base_json\n                match_result = {\"name\": func_name, \"parameters\": func_args}\n                calls.extend(self.parse_base_json(match_result, tools))\n            return StreamingParseResult(normal_text=normal_text, calls=calls)\n        except Exception as e:\n            logger.error(f\"Error in detect_and_parse: {e}\")\n            # return the normal text if parsing fails\n            return StreamingParseResult(normal_text=text)\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> 
StreamingParseResult:\n        \"\"\"\n        Streaming incremental parsing tool calls for DeepSeekV3 format.\n        \"\"\"\n        self._buffer += new_text\n        current_text = self._buffer\n\n        # Check if we have a tool call (either the start token or individual tool call)\n        has_tool_call = (\n            self.bot_token in current_text or \"<｜tool▁call▁begin｜>\" in current_text\n        )\n\n        if not has_tool_call:\n            self._buffer = \"\"\n            for e_token in [self.eot_token, \"```\", \"<｜tool▁call▁end｜>\"]:\n                if e_token in new_text:\n                    new_text = new_text.replace(e_token, \"\")\n            return StreamingParseResult(normal_text=new_text)\n\n        if not hasattr(self, \"_tool_indices\"):\n            self._tool_indices = self._get_tool_indices(tools)\n\n        calls: list[ToolCallItem] = []\n        try:\n            partial_match = re.search(\n                pattern=r\"<｜tool▁call▁begin｜>(.*)<｜tool▁sep｜>(.*)\\n```json\\n(.*)\\n```.*\",\n                string=current_text,\n                flags=re.DOTALL,\n            )\n            if partial_match:\n                func_name = partial_match.group(2).strip()\n                func_args_raw = partial_match.group(3).strip()\n\n                # Initialize state if this is the first tool call\n                if self.current_tool_id == -1:\n                    self.current_tool_id = 0\n                    self.prev_tool_call_arr = []\n                    self.streamed_args_for_tool = [\"\"]\n\n                # Ensure we have enough entries in our tracking arrays\n                while len(self.prev_tool_call_arr) <= self.current_tool_id:\n                    self.prev_tool_call_arr.append({})\n                while len(self.streamed_args_for_tool) <= self.current_tool_id:\n                    self.streamed_args_for_tool.append(\"\")\n\n                if not self.current_tool_name_sent:\n                    calls.append(\n            
            ToolCallItem(\n                            tool_index=self.current_tool_id,\n                            name=func_name,\n                            parameters=\"\",\n                        )\n                    )\n                    self.current_tool_name_sent = True\n                    # Store the tool call info for serving layer completions endpoint\n                    self.prev_tool_call_arr[self.current_tool_id] = {\n                        \"name\": func_name,\n                        \"arguments\": {},\n                    }\n                else:\n                    argument_diff = (\n                        func_args_raw[len(self._last_arguments) :]\n                        if func_args_raw.startswith(self._last_arguments)\n                        else func_args_raw\n                    )\n\n                    if argument_diff:\n                        calls.append(\n                            ToolCallItem(\n                                tool_index=self.current_tool_id,\n                                name=None,\n                                parameters=argument_diff,\n                            )\n                        )\n                        self._last_arguments += argument_diff\n                        self.streamed_args_for_tool[\n                            self.current_tool_id\n                        ] += argument_diff\n\n                    if _is_complete_json(func_args_raw):\n                        # Update the stored arguments\n                        try:\n                            parsed_args = json.loads(func_args_raw)\n                            self.prev_tool_call_arr[self.current_tool_id][\n                                \"arguments\"\n                            ] = parsed_args\n                        except json.JSONDecodeError:\n                            pass\n\n                        # Find the end of the current tool call and remove only that part from buffer\n                        
tool_call_end_pattern = (\n                            r\"<｜tool▁call▁begin｜>.*?<｜tool▁call▁end｜>\"\n                        )\n                        match = re.search(\n                            tool_call_end_pattern, current_text, re.DOTALL\n                        )\n                        if match:\n                            # Remove the completed tool call from buffer, keep any remaining content\n                            self._buffer = current_text[match.end() :]\n                        else:\n                            self._buffer = \"\"\n\n                        result = StreamingParseResult(normal_text=\"\", calls=calls)\n                        self.current_tool_id += 1\n                        self._last_arguments = \"\"\n                        self.current_tool_name_sent = False\n                        return result\n\n            return StreamingParseResult(normal_text=\"\", calls=calls)\n\n        except Exception as e:\n            logger.error(f\"Error in parse_streaming_increment: {e}\")\n            return StreamingParseResult(normal_text=current_text)\n\n    def structure_info(self) -> _GetInfoFunc:\n        return lambda name: StructureInfo(\n            begin=\">\" + name + \"\\n```json\\n\",\n            end=\"\\n```<\",\n            trigger=\">\" + name + \"\\n```json\\n\",\n        )\n"
  },
  {
    "path": "python/sglang/srt/function_call/function_call_parser.py",
    "content": "import logging\nfrom typing import Dict, List, Literal, Optional, Set, Tuple, Type, Union\n\nfrom sglang.srt.entrypoints.openai.protocol import (\n    LegacyStructuralTagResponseFormat,\n    StructuresResponseFormat,\n    Tool,\n    ToolCallConstraint,\n    ToolChoice,\n)\nfrom sglang.srt.environ import ToolStrictLevel, envs\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import ToolCallItem\nfrom sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector\nfrom sglang.srt.function_call.deepseekv31_detector import DeepSeekV31Detector\nfrom sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector\nfrom sglang.srt.function_call.gigachat3_detector import GigaChat3Detector\nfrom sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector\nfrom sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector\nfrom sglang.srt.function_call.gpt_oss_detector import GptOssDetector\nfrom sglang.srt.function_call.hermes_detector import HermesDetector\nfrom sglang.srt.function_call.internlm_detector import InternlmDetector\nfrom sglang.srt.function_call.kimik2_detector import KimiK2Detector\nfrom sglang.srt.function_call.lfm2_detector import Lfm2Detector\nfrom sglang.srt.function_call.llama32_detector import Llama32Detector\nfrom sglang.srt.function_call.mimo_detector import MiMoDetector\nfrom sglang.srt.function_call.minimax_m2 import MinimaxM2Detector\nfrom sglang.srt.function_call.mistral_detector import MistralDetector\nfrom sglang.srt.function_call.pythonic_detector import PythonicDetector\nfrom sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector\nfrom sglang.srt.function_call.qwen25_detector import Qwen25Detector\nfrom sglang.srt.function_call.step3_detector import Step3Detector\nfrom sglang.srt.function_call.trinity_detector import TrinityDetector\nfrom sglang.srt.function_call.utils import get_json_schema_constraint\n\nlogger = 
logging.getLogger(__name__)\n\n\nclass FunctionCallParser:\n    \"\"\"\n    Parser for function/tool calls in model outputs.\n\n    This class handles both streaming and non-streaming parsing of function calls using a detector.\n    In streaming scenarios, each time new_text is received, it calls detector.parse_streaming_increment\n    and returns the resulting normal_text and calls to the upper layer (or SSE).\n    \"\"\"\n\n    ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {\n        \"deepseekv3\": DeepSeekV3Detector,\n        \"deepseekv31\": DeepSeekV31Detector,\n        \"deepseekv32\": DeepSeekV32Detector,\n        \"glm\": Glm4MoeDetector,\n        \"glm45\": Glm4MoeDetector,\n        \"glm47\": Glm47MoeDetector,\n        \"gpt-oss\": GptOssDetector,\n        \"kimi_k2\": KimiK2Detector,\n        \"lfm2\": Lfm2Detector,\n        \"llama3\": Llama32Detector,\n        \"mimo\": MiMoDetector,\n        \"mistral\": MistralDetector,\n        \"pythonic\": PythonicDetector,\n        \"qwen\": Qwen25Detector,\n        \"qwen25\": Qwen25Detector,\n        \"qwen3_coder\": Qwen3CoderDetector,\n        \"step3\": Step3Detector,\n        \"step3p5\": Qwen3CoderDetector,\n        \"minimax-m2\": MinimaxM2Detector,\n        \"trinity\": TrinityDetector,\n        \"interns1\": InternlmDetector,\n        \"hermes\": HermesDetector,\n        \"gigachat3\": GigaChat3Detector,\n    }\n\n    def __init__(self, tools: List[Tool], tool_call_parser: str):\n        detector_class = self.ToolCallParserEnum.get(tool_call_parser)\n        if detector_class:\n            detector = detector_class()\n        else:\n            raise ValueError(f\"Unsupported tool_call_parser: {tool_call_parser}\")\n\n        self.detector = detector\n        self.tools = tools\n        self.tool_strict_level = envs.SGLANG_TOOL_STRICT_LEVEL.get()\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"\n        Check if the given text contains a tool call in the format supported 
by this parser.\n        This delegates to the detector's implementation.\n\n        Args:\n            text: The text to check for tool calls\n\n        Returns:\n            True if the text contains a tool call, False otherwise\n        \"\"\"\n        if not self.tools:\n            return False\n        return self.detector.has_tool_call(text)\n\n    def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:\n        \"\"\"\n        One-time parsing of the full text to extract tool calls.\n\n        Args:\n            full_text: The complete text to parse\n\n        Returns:\n            A tuple containing:\n            - The remaining text after parsing that was not consumed by the detector (can be treated as normal text)\n            - A list of tool calls parsed from the text\n        \"\"\"\n        if not self.tools:\n            return full_text, []\n        parsed_result = self.detector.detect_and_parse(full_text, self.tools)\n        tool_call_list = parsed_result.calls\n        if tool_call_list:\n            return parsed_result.normal_text, tool_call_list\n        else:\n            return full_text, []\n\n    def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:\n        \"\"\"\n        Streaming incremental parsing of chunks of text as they arrive.\n\n        Args:\n            chunk_text: The new chunk of text to parse\n\n        Returns:\n            A tuple containing:\n            - The normal text that should be displayed to the user\n            - A list of tool calls parsed from the chunk\n        \"\"\"\n        if not self.tools:\n            return chunk_text, []\n        final_normal_text = \"\"\n        final_calls = []\n\n        sp_result = self.detector.parse_streaming_increment(chunk_text, self.tools)\n        if sp_result.normal_text:\n            final_normal_text = sp_result.normal_text\n        if sp_result.calls:\n            final_calls.extend(sp_result.calls)\n            
final_normal_text = sp_result.normal_text\n\n        return final_normal_text, final_calls\n\n    def get_structure_tag(self) -> LegacyStructuralTagResponseFormat:\n        \"\"\"\n        Generate a structural tag response format for all available tools.\n\n        This creates the necessary structural tags that guide the model's output format.\n        \"\"\"\n        tool_structures: List[StructuresResponseFormat] = list()\n        tool_trigger_set: Set[str] = set()\n\n        get_structure_info = self.detector.structure_info()\n        for tool in self.tools:\n            function = tool.function\n            name = function.name\n            assert name is not None\n            info = get_structure_info(name)\n\n            # accept all if not strict, otherwise only accept the schema\n            is_strict = (\n                function.strict or self.tool_strict_level >= ToolStrictLevel.PARAMETER\n            )\n            schema = function.parameters if is_strict else {}\n\n            tool_structures.append(\n                StructuresResponseFormat(\n                    begin=info.begin,\n                    schema=schema or {},  # type: ignore\n                    end=info.end,\n                )\n            )\n            tool_trigger_set.add(info.trigger)\n\n        # TODO(dark): move this into new structural tag format\n        # This requires all grammar backend support the new format\n        return LegacyStructuralTagResponseFormat(\n            type=\"structural_tag\",\n            structures=tool_structures,\n            triggers=list(tool_trigger_set),\n        )\n\n    def get_structure_constraint(\n        self, tool_choice: Union[ToolChoice, Literal[\"auto\", \"required\"]]\n    ) -> Optional[ToolCallConstraint]:\n        \"\"\"\n        Returns the appropriate structure constraint for tool calls based on the tool_choice.\n        The constraint is used to guide the model's output format.\n\n        Args:\n            tool_choice: The tool 
choice setting from the request\n\n        Returns:\n            A tuple of (constraint_type, constraint_value) to be added to sampling parameters,\n            or None if no constraint applies.\n        \"\"\"\n        # NOTE: structural_tag only supports JSON-compatible content between the begin and end.\n        # It cannot parse or validate function call Pythonic or XML-ish syntax.\n        if (\n            self.detector.supports_structural_tag()\n            and tool_choice == \"auto\"\n            and (\n                any(tool.function.strict for tool in self.tools)\n                or self.tool_strict_level >= ToolStrictLevel.FUNCTION\n            )\n        ):\n            tag = self.get_structure_tag()\n            return (\"structural_tag\", tag)\n        elif tool_choice == \"required\" or isinstance(tool_choice, ToolChoice):\n            json_schema = get_json_schema_constraint(self.tools, tool_choice)\n            return (\"json_schema\", json_schema)\n"
  },
  {
    "path": "python/sglang/srt/function_call/gigachat3_detector.py",
    "content": "import json\nimport logging\nimport re\nfrom typing import List\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    ToolCallItem,\n    _GetInfoFunc,\n)\n\nlogger = logging.getLogger(__name__)\n\nREGEX_FUNCTION_CALL = re.compile(\n    r\"(?:function call<\\|role_sep\\|>\\n|<\\|function_call\\|>)(.*)\",\n    re.DOTALL,\n)\n\nREGEX_CONTENT_PATTERN = re.compile(\n    r\"^(.*?)(?:<\\|message_sep\\|>|<\\|function_call\\|>)\",\n    re.DOTALL,\n)\n\nNAME_REGEX = re.compile(\n    r'\"name\"\\s*:\\s*\"([^\"]*)\"',\n    re.DOTALL,\n)\n\nARGS_REGEX = re.compile(\n    r'\"arguments\"\\s*:\\s*(.*)',\n    re.DOTALL,\n)\n\n\nclass GigaChat3Detector(BaseFormatDetector):\n    def __init__(self) -> None:\n        super().__init__()\n        self.tool_started: bool = False\n        self.tool_name_sent: bool = False\n        self.end_content: bool = False\n        self._buffer: str = \"\"\n        self.prev_tool_call_arr: list[dict] = []\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if text contains a tool call marker\"\"\"\n        return \"function call<|role_sep|>\\n\" in text or \"<|function_call|>\" in text\n\n    def detect_and_parse(\n        self,\n        text: str,\n        tools: List[Tool],\n    ) -> StreamingParseResult:\n        \"\"\"\n        Non-streaming parsing of complete model output.\n        Extracts tool calls and content from the full text.\n        \"\"\"\n        logger.debug(f\"[GigaChat3] detect_and_parse: {text}\")\n        model_output = text\n        function_call = None\n        content = None\n        if model_output.rstrip().endswith(\"</s>\"):\n            model_output = model_output[: model_output.rfind(\"</s>\")]\n        m_func = REGEX_FUNCTION_CALL.search(model_output)\n        if m_func:\n            try:\n                function_call = 
json.loads(m_func.group(1), strict=False)\n                if not (\n                    isinstance(function_call, dict)\n                    and \"name\" in function_call\n                    and \"arguments\" in function_call\n                ):\n                    function_call = None\n                elif not isinstance(function_call[\"arguments\"], dict):\n                    function_call = None\n            except json.JSONDecodeError as e:\n                logger.warning(f\"[GigaChat3] JSON decode error: {e}\")\n                return StreamingParseResult(\n                    normal_text=model_output,\n                    calls=[],\n                )\n        m_content = REGEX_CONTENT_PATTERN.search(model_output)\n        if m_content:\n            content = m_content.group(1)\n        else:\n            content = model_output\n        if not function_call:\n            return StreamingParseResult(normal_text=content, calls=[])\n        name = function_call[\"name\"]\n        args = function_call[\"arguments\"]\n        match_result = {\"name\": name, \"arguments\": args}\n        calls = self.parse_base_json(match_result, tools)\n        return StreamingParseResult(normal_text=content, calls=calls)\n\n    def parse_streaming_increment(\n        self,\n        new_text: str,\n        tools: List[Tool],\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming parser for incremental text chunks.\n        Maintains state across calls to build complete tool calls.\n        \"\"\"\n        if not new_text:\n            return StreamingParseResult()\n        logger.debug(f\"[GigaChat3] parse_streaming_increment: '{new_text}'\")\n        self._buffer += new_text\n        current_text = self._buffer\n        delta_text = new_text\n        content = None\n        func_name = None\n        cur_args = None\n        m_func = REGEX_FUNCTION_CALL.search(current_text)\n        if not self.tool_started:\n            m_content = 
REGEX_CONTENT_PATTERN.search(delta_text)\n            if m_content:\n                content = m_content.group(1)\n                self.end_content = True\n            else:\n                if not self.end_content:\n                    content = delta_text\n            if m_func:\n                self.tool_started = True\n                logger.debug(\"[GigaChat3] Tool call started\")\n            if content:\n                return StreamingParseResult(normal_text=content)\n        if not m_func:\n            return StreamingParseResult()\n        json_tail = m_func.group(1).strip()\n        name_match = NAME_REGEX.search(json_tail)\n        if name_match:\n            func_name = name_match.group(1)\n        args_match = ARGS_REGEX.search(json_tail)\n        if args_match:\n            cur_args = args_match.group(1).strip()\n            if cur_args.endswith(\"</s>\"):\n                cur_args = cur_args[: -len(\"</s>\")]\n            if cur_args.endswith(\"}\"):\n                try:\n                    candidate = cur_args[:-1].strip()\n                    json.loads(candidate, strict=False)\n                    cur_args = candidate\n                except json.JSONDecodeError:\n                    pass\n        calls: List[ToolCallItem] = []\n        if not self.prev_tool_call_arr:\n            self.prev_tool_call_arr.append({})\n        if not self.tool_name_sent:\n            if not func_name:\n                return StreamingParseResult()\n            self.tool_name_sent = True\n            self.prev_tool_call_arr[0][\"name\"] = func_name\n            logger.debug(f\"[GigaChat3] Sending tool name: {func_name}\")\n            calls.append(\n                ToolCallItem(\n                    tool_index=0,\n                    name=func_name,\n                    parameters=\"\",\n                )\n            )\n            return StreamingParseResult(calls=calls)\n        if cur_args is None:\n            return StreamingParseResult()\n        prev_args = 
self.prev_tool_call_arr[0].get(\"arguments_str\", \"\")\n        if not prev_args:\n            delta_args = cur_args\n        elif cur_args.startswith(prev_args):\n            delta_args = cur_args[len(prev_args) :]\n        else:\n            logger.warning(\n                f\"[GigaChat3] Arguments overlap mismatch. \"\n                f\"prev='{prev_args[:50]}...' cur='{cur_args[:50]}...'\"\n            )\n            return StreamingParseResult()\n        if not delta_args:\n            return StreamingParseResult()\n        self.prev_tool_call_arr[0][\"arguments_str\"] = cur_args\n        try:\n            args_dict = json.loads(cur_args, strict=False)\n            self.prev_tool_call_arr[0][\"arguments\"] = args_dict\n        except json.JSONDecodeError:\n            self.prev_tool_call_arr[0][\"arguments\"] = {}\n        logger.debug(f\"[GigaChat3] Sending args delta: '{delta_args[:100]}...'\")\n        calls.append(\n            ToolCallItem(\n                tool_index=0,\n                name=None,\n                parameters=delta_args,\n            )\n        )\n        return StreamingParseResult(calls=calls)\n\n    def supports_structural_tag(self) -> bool:\n        \"\"\"GigaChat3 does not use structural tags\"\"\"\n        return False\n\n    def structure_info(self) -> _GetInfoFunc:\n        \"\"\"Not applicable for GigaChat3\"\"\"\n        raise NotImplementedError(\n            \"GigaChat3Detector does not support structural_tag format.\"\n        )\n"
  },
  {
    "path": "python/sglang/srt/function_call/glm47_moe_detector.py",
    "content": "import ast\nimport json\nimport logging\nimport re\nfrom enum import Enum\nfrom typing import Any, Dict, List, Optional, Tuple\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    ToolCallItem,\n    _GetInfoFunc,\n)\nfrom sglang.srt.function_call.utils import infer_type_from_json_schema\n\nlogger = logging.getLogger(__name__)\n\n\nclass StreamState(str, Enum):\n    \"\"\"State machine states for XML to JSON streaming conversion.\"\"\"\n\n    INIT = \"INIT\"\n    BETWEEN = \"BETWEEN\"\n    IN_KEY = \"IN_KEY\"\n    WAITING_VALUE = \"WAITING_VALUE\"\n    IN_VALUE = \"IN_VALUE\"\n\n\ndef get_argument_type(\n    func_name: str, arg_key: str, defined_tools: List[Tool]\n) -> Optional[str]:\n    \"\"\"Get the expected type of a function argument from tool definitions.\n\n    Supports complex JSON Schema definitions including:\n    - Direct type field (including type arrays)\n    - anyOf/oneOf: parameter can be any of multiple types\n    - enum: parameter must be one of enum values\n    - allOf: parameter must satisfy all type definitions\n    - properties: inferred as object type\n    - items: inferred as array type\n\n    Args:\n        func_name: Name of the function/tool\n        arg_key: Name of the argument\n        defined_tools: List of available tools\n\n    Returns:\n        The type string (e.g., 'string', 'number', 'object') or None if not found\n    \"\"\"\n    name2tool = {tool.function.name: tool for tool in defined_tools}\n\n    # Check if function exists\n    tool = name2tool.get(func_name)\n    if not tool:\n        return None\n\n    # Get parameters safely using getattr\n    params = getattr(tool.function, \"parameters\", None)\n    if not isinstance(params, dict):\n        return None\n\n    # Navigate to the type using dict.get() for safe access\n    properties = 
params.get(\"properties\")\n    if not isinstance(properties, dict):\n        return None\n\n    arg_spec = properties.get(arg_key)\n    if isinstance(arg_spec, dict):\n        # Use the new type inference function for complex JSON Schema support\n        return infer_type_from_json_schema(arg_spec)\n\n    return None\n\n\ndef _convert_to_number(value: str) -> Any:\n    \"\"\"Convert string to appropriate number type (int or float).\n\n    Args:\n        value: String value to convert\n\n    Returns:\n        Converted number or original string if conversion fails\n    \"\"\"\n    try:\n        if \".\" in value or \"e\" in value.lower():\n            return float(value)\n        else:\n            return int(value)\n    except (ValueError, AttributeError):\n        return value\n\n\ndef parse_arguments(\n    json_value: str, arg_type: Optional[str] = None\n) -> Tuple[Any, bool]:\n    \"\"\"Parse argument value with multiple fallback strategies.\n\n    Args:\n        json_value: Raw string value to parse\n        arg_type: Expected type hint ('string', 'number', 'object', etc.)\n\n    Returns:\n        Tuple of (parsed_value, is_valid_json)\n    \"\"\"\n    # Strategy 1: Direct JSON parsing\n    try:\n        parsed_value = json.loads(json_value)\n\n        # Type coercion for number type\n        if arg_type == \"number\" and isinstance(parsed_value, str):\n            parsed_value = _convert_to_number(parsed_value)\n\n        return parsed_value, True\n    except (json.JSONDecodeError, ValueError):\n        pass\n\n    # Strategy 2: Unescape and parse\n    try:\n        wrapped = json.loads('{\"tmp\": \"' + json_value + '\"}')\n        parsed_value = json.loads(wrapped[\"tmp\"])\n\n        if arg_type == \"number\" and isinstance(parsed_value, str):\n            parsed_value = _convert_to_number(parsed_value)\n\n        return parsed_value, True\n    except (json.JSONDecodeError, ValueError, KeyError):\n        pass\n\n    # Strategy 3: ast.literal_eval\n    
try:\n        parsed_value = ast.literal_eval(json_value)\n        return parsed_value, True\n    except (ValueError, SyntaxError):\n        pass\n\n    # Strategy 4: Treat as string\n    try:\n        quoted_value = json.dumps(str(json_value))\n        return json.loads(quoted_value), True\n    except (json.JSONDecodeError, ValueError):\n        return json_value, False\n\n\nclass Glm47MoeDetector(BaseFormatDetector):\n    \"\"\"\n    Detector for GLM-4.7 and GLM-5 models.\n    Assumes function call format:\n      <tool_call>get_weather<arg_key>city</arg_key><arg_value>北京</arg_value><arg_key>date</arg_key><arg_value>2024-06-27</arg_value></tool_call><tool_call>get_weather<arg_key>city</arg_key><arg_value>上海</arg_value><arg_key>date</arg_key><arg_value>2024-06-27</arg_value></tool_call>\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.bot_token = \"<tool_call>\"\n        self.eot_token = \"</tool_call>\"\n        self.func_call_regex = r\"<tool_call>.*?</tool_call>\"\n        self.func_detail_regex = re.compile(\n            r\"<tool_call>(.*?)(<arg_key>.*?)?</tool_call>\", re.DOTALL\n        )\n        self.func_arg_regex = re.compile(\n            r\"<arg_key>(.*?)</arg_key>(?:\\\\n|\\s)*<arg_value>(.*?)</arg_value>\",\n            re.DOTALL,\n        )\n        self._last_arguments = \"\"\n        self.current_tool_id = -1\n        self.current_tool_name_sent = False\n        self._streamed_raw_length = 0\n        self._tool_call_completed = False  # Track if tool call has been completed\n        self._sent_empty_object = (\n            False  # Track if empty object has been sent for no-arg functions\n        )\n        self._reset_streaming_state()\n\n    def _reset_streaming_state(self) -> None:\n        \"\"\"Reset the streaming state machine for a new tool call.\"\"\"\n        self._stream_state = StreamState.INIT\n        self._current_key = \"\"\n        self._current_value = \"\"\n        self._xml_tag_buffer = \"\"\n      
  self._is_first_param = True\n        self._value_started = False\n        self._cached_value_type: Optional[str] = (\n            None  # Cache the value type for consistency\n        )\n        self._tool_call_completed = False  # Reset tool call completion status\n        self._sent_empty_object = False  # Reset empty object sent status\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if the text contains a GLM-4.7 / GLM-5 format tool call.\"\"\"\n        return self.bot_token in text\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"\n        One-time parsing: Detects and parses tool calls in the provided text.\n\n        :param text: The complete text to parse.\n        :param tools: List of available tools.\n        :return: StreamingParseResult holding the normal text outside the tool calls and the parsed calls.\n        \"\"\"\n        if self.bot_token not in text:\n            return StreamingParseResult(normal_text=text, calls=[])\n\n        # Extract all normal text (before, between, and after tool calls)\n        normal_text_parts = []\n        last_end = 0\n\n        # Find all tool call matches\n        for match in re.finditer(self.func_call_regex, text, re.DOTALL):\n            # Add text before this tool call\n            if match.start() > last_end:\n                normal_text_parts.append(text[last_end : match.start()])\n            last_end = match.end()\n\n        # Add any remaining text after the last tool call\n        if last_end < len(text):\n            normal_text_parts.append(text[last_end:])\n\n        # Combine all normal text parts\n        normal_text = \"\".join(normal_text_parts).strip()\n\n        # Parse tool calls\n        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)\n        calls = []\n        try:\n            for match_result in match_result_list:\n                # Get function name\n                func_detail = 
self.func_detail_regex.search(match_result)\n                if func_detail is None:\n                    continue\n                func_name = func_detail.group(1) if func_detail.group(1) else \"\"\n                func_args = func_detail.group(2) if func_detail.group(2) else \"\"\n                arguments = {}\n                if func_args:\n                    pairs = self.func_arg_regex.findall(func_args)\n                    # Parse arguments using shared method\n                    arguments = self._parse_argument_pairs(pairs, func_name, tools)\n\n                # construct match_result for parse_base_json\n                match_result = {\"name\": func_name, \"parameters\": arguments}\n                calls.extend(self.parse_base_json(match_result, tools))\n            return StreamingParseResult(normal_text=normal_text, calls=calls)\n        except Exception as e:\n            logger.error(f\"Error in detect_and_parse: {e}\", exc_info=True)\n            # return the normal text if parsing fails\n            return StreamingParseResult(normal_text=text)\n\n    def _get_value_type(self, func_name: str, key: str, tools: List[Tool]) -> str:\n        \"\"\"Get parameter type from tool definition, with fallback to auto-detection.\n\n        Args:\n            func_name: Name of the function\n            key: Parameter name\n            tools: List of available tools\n\n        Returns:\n            Type string: 'string', 'number', 'object', 'array', or 'boolean'\n        \"\"\"\n        arg_type = get_argument_type(func_name, key, tools)\n        if arg_type:\n            return arg_type\n\n        # Improved auto-detection type from value (best effort)\n        value_content = self._current_value.strip() if self._current_value else \"\"\n\n        if not value_content:\n            return \"string\"\n\n        # Try to parse as valid JSON first\n        try:\n            parsed = json.loads(value_content)\n            if isinstance(parsed, dict):\n             
   return \"object\"\n            elif isinstance(parsed, list):\n                return \"array\"\n            elif isinstance(parsed, bool):\n                return \"boolean\"\n            elif isinstance(parsed, (int, float)):\n                return \"number\"\n            # For string values, check if they look like numbers\n            elif isinstance(parsed, str):\n                if parsed.isdigit() or (\n                    parsed.startswith(\"-\") and parsed[1:].isdigit()\n                ):\n                    return \"number\"\n                return \"string\"\n        except json.JSONDecodeError:\n            # Not valid JSON, try heuristic detection\n            first_char = value_content[0] if value_content else \"\"\n\n            if first_char.isdigit() or first_char in [\"-\", \".\"]:\n                return \"number\"\n            elif first_char in [\"{\", \"[\"]:\n                return \"object\"\n            elif first_char in ['\"', \"'\"]:\n                return \"string\"\n\n        # Default to string (safest fallback)\n        return \"string\"\n\n    def _format_value_complete(self, value: str, value_type: str) -> str:\n        \"\"\"Format complete value based on type.\n\n        Args:\n            value: Raw value string\n            value_type: Expected type ('string', 'number', 'object')\n\n        Returns:\n            Properly formatted JSON value string\n        \"\"\"\n        if value_type == \"string\":\n            # Ensure proper JSON string formatting with quotes\n            return json.dumps(value, ensure_ascii=False)\n        elif value_type == \"number\":\n            num = _convert_to_number(value.strip() if value else \"\")\n            if isinstance(num, (int, float)):\n                return str(num)\n            else:\n                # _convert_to_number returns the input unchanged when it is not\n                # numeric, so fall back to a JSON string to keep the output valid\n     
           logger.warning(\n                    f\"Failed to parse '{value}' as number, treating as string\"\n               
 )\n                return json.dumps(str(value) if value else \"\", ensure_ascii=False)\n        else:\n            # For object/array types, return as-is (should already be valid JSON)\n            return value\n\n    def _process_xml_to_json_streaming(\n        self, raw_increment: str, func_name: str, tools: List[Tool]\n    ) -> str:\n        \"\"\"Convert XML increment to JSON streaming output using state machine.\n\n        This method processes XML fragments character by character and converts them\n        to JSON format incrementally. It maintains state across calls to handle\n        partial XML tags and values.\n\n        Args:\n            raw_increment: New XML content to process\n            func_name: Name of the function being called\n            tools: List of available tools for type inference\n\n        Returns:\n            JSON string increment to append to the output\n        \"\"\"\n        json_output = \"\"\n\n        for char in raw_increment:\n            self._xml_tag_buffer += char\n\n            if self._stream_state in [StreamState.INIT, StreamState.BETWEEN]:\n                if self._xml_tag_buffer.endswith(\"<arg_key>\"):\n                    self._stream_state = StreamState.IN_KEY\n                    self._current_key = \"\"\n                    self._xml_tag_buffer = \"\"\n                    json_output += \"{\" if self._is_first_param else \", \"\n                    self._is_first_param = False\n\n            elif self._stream_state == StreamState.IN_KEY:\n                if self._xml_tag_buffer.endswith(\"</arg_key>\"):\n                    self._current_key = self._xml_tag_buffer[:-10].strip()\n                    self._xml_tag_buffer = \"\"\n                    self._stream_state = StreamState.WAITING_VALUE\n                    json_output += (\n                        json.dumps(self._current_key, ensure_ascii=False) + \": \"\n                    )\n\n            elif self._stream_state == StreamState.WAITING_VALUE:\n      
          if self._xml_tag_buffer.endswith(\"<arg_value>\"):\n                    self._stream_state = StreamState.IN_VALUE\n                    self._current_value = \"\"\n                    self._xml_tag_buffer = \"\"\n                    self._value_started = False\n                    # Determine and cache the value type at the start\n                    self._cached_value_type = self._get_value_type(\n                        func_name, self._current_key, tools\n                    )\n\n            elif self._stream_state == StreamState.IN_VALUE:\n                if self._xml_tag_buffer.endswith(\"</arg_value>\"):\n                    final_value = self._xml_tag_buffer[:-12]\n                    self._current_value += final_value\n\n                    # Use cached value type for consistency\n                    value_type = self._cached_value_type or \"string\"\n\n                    if self._value_started:\n                        # Output any remaining content\n                        if final_value:\n                            if value_type == \"string\":\n                                json_output += json.dumps(\n                                    final_value, ensure_ascii=False\n                                )[1:-1]\n                            else:\n                                json_output += final_value\n                        # Always output closing quote for string type when value was started\n                        if value_type == \"string\":\n                            json_output += '\"'\n                    else:\n                        # Value was never started (empty or complete in one chunk)\n                        json_output += self._format_value_complete(\n                            self._current_value, value_type\n                        )\n\n                    self._xml_tag_buffer = \"\"\n                    self._stream_state = StreamState.BETWEEN\n                    self._current_value = \"\"\n                    
self._value_started = False\n                    self._cached_value_type = None  # Reset cached type\n                else:\n                    closing_tag = \"</arg_value>\"\n                    is_potential_closing = len(self._xml_tag_buffer) <= len(\n                        closing_tag\n                    ) and closing_tag.startswith(self._xml_tag_buffer)\n\n                    if not is_potential_closing:\n                        content = self._xml_tag_buffer\n                        # Use cached value type for consistency\n                        value_type = self._cached_value_type or \"string\"\n\n                        if value_type == \"string\":\n                            if not self._value_started:\n                                json_output += '\"'\n                                self._value_started = True\n                            if content:\n                                json_output += json.dumps(content, ensure_ascii=False)[\n                                    1:-1\n                                ]\n                                self._current_value += content\n                                self._xml_tag_buffer = \"\"\n                        elif value_type == \"number\":\n                            if content:\n                                if not self._value_started:\n                                    self._value_started = True\n                                json_output += content\n                                self._current_value += content\n                                self._xml_tag_buffer = \"\"\n                        else:\n                            # For object/array types, output as-is\n                            if content:\n                                if not self._value_started:\n                                    self._value_started = True\n                                json_output += content\n                                self._current_value += content\n                                
self._xml_tag_buffer = \"\"\n\n        return json_output\n\n    def _extract_match_groups(self, match: re.Match) -> tuple[str, str, str]:\n        \"\"\"Extract function name, arguments and end marker from regex match.\n\n        Args:\n            match: Regex match object\n\n        Returns:\n            (func_name, func_args_raw, is_tool_end)\n        \"\"\"\n        func_name = match.group(1).strip()\n        func_args_raw = match.group(2).strip() if match.group(2) else \"\"\n        is_tool_end = match.group(3) or \"\"\n        return func_name, func_args_raw, is_tool_end\n\n    def _send_tool_name_if_needed(\n        self, func_name: str, has_arg_key: bool, is_tool_end: str\n    ) -> Optional[ToolCallItem]:\n        \"\"\"Send tool name if needed.\n\n        Args:\n            func_name: Function name\n            has_arg_key: Whether current text contains <arg_key\n            is_tool_end: Whether end marker is encountered\n\n        Returns:\n            Tool call item or None\n        \"\"\"\n        if self.current_tool_name_sent:\n            return None\n\n        # Function name completeness check\n        is_func_name_complete = has_arg_key or is_tool_end == self.eot_token\n\n        if not is_func_name_complete:\n            return None\n\n        if not func_name:\n            logger.warning(\"Empty function name detected, skipping tool call\")\n            return None\n\n        # Send tool name\n        self.current_tool_name_sent = True\n        self._streamed_raw_length = 0\n        self._reset_streaming_state()\n\n        # Record tool info\n        self.prev_tool_call_arr[self.current_tool_id] = {\n            \"name\": func_name,\n            \"arguments\": {},\n        }\n\n        return ToolCallItem(\n            tool_index=self.current_tool_id,\n            name=func_name,\n            parameters=\"\",\n        )\n\n    def _process_arguments_streaming(\n        self, func_name: str, func_args_raw: str, tools: List[Tool]\n    ) -> 
Optional[ToolCallItem]:\n        \"\"\"Process streaming arguments.\n\n        Args:\n            func_name: Function name\n            func_args_raw: Raw argument string\n            tools: List of available tools\n\n        Returns:\n            Tool call item with parameter updates or None\n        \"\"\"\n        current_raw_length = len(func_args_raw)\n\n        if current_raw_length <= self._streamed_raw_length:\n            return None\n\n        # Get new raw XML content\n        raw_increment = func_args_raw[self._streamed_raw_length :]\n\n        # Convert XML to JSON using state machine\n        json_increment = self._process_xml_to_json_streaming(\n            raw_increment, func_name, tools\n        )\n\n        # CRITICAL: Update streamed length BEFORE early return\n        # Even if json_increment is empty, the input has been consumed by the state machine\n        self._streamed_raw_length = current_raw_length\n\n        if not json_increment:\n            return None\n\n        # Update state\n        self._last_arguments += json_increment\n        self.streamed_args_for_tool[self.current_tool_id] += json_increment\n\n        return ToolCallItem(\n            tool_index=self.current_tool_id,\n            name=None,\n            parameters=json_increment,\n        )\n\n    def _finalize_tool_call(\n        self,\n        func_name: str,\n        func_args_raw: str,\n        tools: List[Tool],\n        match_end_pos: int,\n        current_text: str,\n    ) -> List[ToolCallItem]:\n        \"\"\"Complete tool call processing.\n\n        Args:\n            func_name: Function name\n            func_args_raw: Raw argument string\n            tools: List of available tools\n            match_end_pos: Match end position\n            current_text: Current text\n\n        Returns:\n            List of tool call items to add\n        \"\"\"\n        calls = []\n\n        # Handle no-arg function or need to close braces\n        if self._is_first_param and not 
self._sent_empty_object:\n            # No-arg function\n            calls.append(\n                ToolCallItem(\n                    tool_index=self.current_tool_id,\n                    name=None,\n                    parameters=\"{}\",\n                )\n            )\n            self._last_arguments += \"{}\"\n            self.streamed_args_for_tool[self.current_tool_id] += \"{}\"\n            self._sent_empty_object = True\n        elif not self._last_arguments.endswith(\"}\") and not self._sent_empty_object:\n            # Need to close brace\n            calls.append(\n                ToolCallItem(\n                    tool_index=self.current_tool_id,\n                    name=None,\n                    parameters=\"}\",\n                )\n            )\n            self._last_arguments += \"}\"\n            self.streamed_args_for_tool[self.current_tool_id] += \"}\"\n            self._sent_empty_object = True\n\n        # Parse final arguments\n        if func_args_raw:\n            try:\n                pairs = self.func_arg_regex.findall(func_args_raw)\n                if pairs:\n                    arguments = self._parse_argument_pairs(pairs, func_name, tools)\n                    self.prev_tool_call_arr[self.current_tool_id][\n                        \"arguments\"\n                    ] = arguments\n            except Exception as e:\n                logger.debug(f\"Failed to parse arguments: {e}\", exc_info=True)\n\n        # Clean buffer\n        self._buffer = current_text[match_end_pos:]\n\n        # Reset state for next tool call\n        self._tool_call_completed = True\n        self.current_tool_id += 1\n        self._last_arguments = \"\"\n        self.current_tool_name_sent = False\n        self._streamed_raw_length = 0\n        self._reset_streaming_state()\n\n        return calls\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming 
incremental parsing tool calls for GLM-4.5 and GLM-4.6 format.\n        Uses a state machine to convert XML to JSON incrementally for true character-by-character streaming.\n        Outputs JSON increments immediately as XML data arrives.\n        \"\"\"\n        self._buffer += new_text\n        current_text = self._buffer\n\n        # Check if we have a tool call\n        has_tool_call = self.bot_token in current_text\n\n        if not has_tool_call:\n            # Check if buffer could be the start of a tool call\n            # Keep buffer if it could be a partial match of bot_token\n            is_potential_start = any(\n                self.bot_token.startswith(current_text[-i:])\n                for i in range(1, min(len(current_text), len(self.bot_token)) + 1)\n            )\n\n            if not is_potential_start:\n                # Not a potential tool call, return as normal text\n                # Must return the entire buffer (current_text), not just new_text,\n                # because buffer may contain previously accumulated characters like '<'\n                # that turned out not to be part of a tool call\n                output_text = current_text\n                self._buffer = \"\"\n                if self.eot_token in output_text:\n                    output_text = output_text.replace(self.eot_token, \"\")\n                return StreamingParseResult(normal_text=output_text)\n            else:\n                # Could be start of tool call, keep buffering\n                return StreamingParseResult(normal_text=\"\", calls=[])\n\n        # Extract any text before the first bot_token and return it as normal_text\n        normal_text = \"\"\n        first_bot_token_idx = current_text.find(self.bot_token)\n        if first_bot_token_idx > 0:\n            normal_text = current_text[:first_bot_token_idx]\n            current_text = current_text[first_bot_token_idx:]\n            # Update buffer to only include from the bot token onwards\n           
 self._buffer = current_text\n\n        if not hasattr(self, \"_tool_indices\"):\n            self._tool_indices = self._get_tool_indices(tools)\n\n        calls: list[ToolCallItem] = []\n        try:\n            # Try to match a partial or complete tool call\n            # Use a single flexible regex pattern that handles all cases\n            partial_match = re.search(\n                r\"<tool_call>(.*?)(?:(<arg_key.*?))?(?:(</tool_call>)|$)\",\n                current_text,\n                re.DOTALL,\n            )\n\n            if not partial_match:\n                return StreamingParseResult(normal_text=normal_text, calls=[])\n\n            # Extract match groups using helper method\n            func_name, func_args_raw, is_tool_end = self._extract_match_groups(\n                partial_match\n            )\n\n            # Initialize tool call state on the first tool call\n            if self.current_tool_id == -1:\n                self.current_tool_id = 0\n                self.prev_tool_call_arr = []\n                self.streamed_args_for_tool = [\"\"]\n                self._streamed_raw_length = 0\n                self.current_tool_name_sent = False  # Reset for new tool call\n                self._reset_streaming_state()\n            # current_tool_id is not incremented here: it advances in\n            # _finalize_tool_call once the previous tool call has completed\n\n            # Ensure the tracking arrays have an entry for the current tool\n            while len(self.prev_tool_call_arr) <= 
self.current_tool_id:\n                self.prev_tool_call_arr.append({})\n            while len(self.streamed_args_for_tool) <= self.current_tool_id:\n                self.streamed_args_for_tool.append(\"\")\n\n            # Determine if function name is complete by checking for <arg_key> in the full text\n            # This is important for streaming scenarios where args come in later chunks\n            has_arg_key = \"<arg_key\" in current_text\n\n            # Send tool name if needed\n            tool_name_item = self._send_tool_name_if_needed(\n                func_name, has_arg_key, is_tool_end\n            )\n            if tool_name_item:\n                calls.append(tool_name_item)\n\n            # Process streaming arguments if tool name has been sent\n            if self.current_tool_name_sent:\n                arg_item = self._process_arguments_streaming(\n                    func_name, func_args_raw, tools\n                )\n                if arg_item:\n                    calls.append(arg_item)\n\n                # Finalize tool call if end token is encountered\n                if is_tool_end == self.eot_token and not self._tool_call_completed:\n                    finalize_calls = self._finalize_tool_call(\n                        func_name,\n                        func_args_raw,\n                        tools,\n                        partial_match.end(),\n                        current_text,\n                    )\n                    calls.extend(finalize_calls)\n                    return StreamingParseResult(normal_text=normal_text, calls=calls)\n\n        except Exception as e:\n            logger.error(f\"Error in parse_streaming_increment: {e}\", exc_info=True)\n            return StreamingParseResult(normal_text=current_text)\n\n        return StreamingParseResult(normal_text=normal_text, calls=calls)\n\n    def _parse_argument_pairs(\n        self, pairs: List[Tuple[str, str]], func_name: str, tools: List[Tool]\n    ) -> Dict[str, 
Any]:\n        \"\"\"Parse argument key-value pairs with type coercion.\n\n        Args:\n            pairs: List of (key, value) tuples from regex matching\n            func_name: Name of the function\n            tools: List of available tools\n\n        Returns:\n            Dictionary of parsed arguments\n        \"\"\"\n        arguments = {}\n        for arg_key, arg_value in pairs:\n            arg_key = arg_key.strip()\n            arg_value = arg_value.strip()\n            arg_type = get_argument_type(func_name, arg_key, tools)\n            parsed_value, is_good_json = parse_arguments(arg_value, arg_type)\n\n            if arg_type == \"string\":\n                # Only convert to string if explicitly defined as string type\n                if isinstance(parsed_value, str):\n                    arguments[arg_key] = parsed_value\n                elif isinstance(parsed_value, (dict, list)):\n                    # If parsed as dict/list but schema says string, convert to JSON string\n                    arguments[arg_key] = json.dumps(parsed_value, ensure_ascii=False)\n                else:\n                    arguments[arg_key] = str(parsed_value)\n            elif arg_type is None:\n                # If type is not defined, keep the parsed value as-is\n                arguments[arg_key] = parsed_value if is_good_json else arg_value\n            else:\n                # For other types (number, object, array, etc.), use parsed value\n                arguments[arg_key] = parsed_value if is_good_json else arg_value\n\n        return arguments\n\n    def supports_structural_tag(self) -> bool:\n        return False\n\n    def structure_info(self) -> _GetInfoFunc:\n        raise NotImplementedError()\n"
  },
  {
    "path": "python/sglang/srt/function_call/glm4_moe_detector.py",
    "content": "import ast\nimport json\nimport logging\nimport re\nfrom enum import Enum\nfrom typing import Any, Dict, List, Optional, Tuple\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    ToolCallItem,\n    _GetInfoFunc,\n)\nfrom sglang.srt.function_call.utils import infer_type_from_json_schema\n\nlogger = logging.getLogger(__name__)\n\n\nclass StreamState(str, Enum):\n    \"\"\"State machine states for XML to JSON streaming conversion.\"\"\"\n\n    INIT = \"INIT\"\n    BETWEEN = \"BETWEEN\"\n    IN_KEY = \"IN_KEY\"\n    WAITING_VALUE = \"WAITING_VALUE\"\n    IN_VALUE = \"IN_VALUE\"\n\n\ndef get_argument_type(\n    func_name: str, arg_key: str, defined_tools: List[Tool]\n) -> Optional[str]:\n    \"\"\"Get the expected type of a function argument from tool definitions.\n\n    Supports complex JSON Schema definitions including:\n    - Direct type field (including type arrays)\n    - anyOf/oneOf: parameter can be any of multiple types\n    - enum: parameter must be one of enum values\n    - allOf: parameter must satisfy all type definitions\n    - properties: inferred as object type\n    - items: inferred as array type\n\n    Args:\n        func_name: Name of the function/tool\n        arg_key: Name of the argument\n        defined_tools: List of available tools\n\n    Returns:\n        The type string (e.g., 'string', 'number', 'object') or None if not found\n    \"\"\"\n    name2tool = {tool.function.name: tool for tool in defined_tools}\n    if func_name not in name2tool:\n        return None\n    tool = name2tool[func_name]\n    properties = (tool.function.parameters or {}).get(\"properties\", {})\n    if not isinstance(properties, dict):\n        properties = {}\n    if arg_key not in properties:\n        return None\n\n    # Use new type inference function for complex JSON Schema support\n  
  return infer_type_from_json_schema(properties[arg_key])\n\n\ndef _convert_to_number(value: str) -> Any:\n    \"\"\"Convert string to appropriate number type (int or float).\n\n    Args:\n        value: String value to convert\n\n    Returns:\n        Converted number or original string if conversion fails\n    \"\"\"\n    try:\n        if \".\" in value or \"e\" in value.lower():\n            return float(value)\n        else:\n            return int(value)\n    except (ValueError, AttributeError):\n        return value\n\n\ndef parse_arguments(\n    json_value: str, arg_type: Optional[str] = None\n) -> Tuple[Any, bool]:\n    \"\"\"Parse argument value with multiple fallback strategies.\n\n    Args:\n        json_value: Raw string value to parse\n        arg_type: Expected type hint ('string', 'number', 'object', etc.)\n\n    Returns:\n        Tuple of (parsed_value, is_valid_json)\n    \"\"\"\n    # Strategy 1: Direct JSON parsing\n    try:\n        parsed_value = json.loads(json_value)\n\n        # Type coercion for number type\n        if arg_type == \"number\" and isinstance(parsed_value, str):\n            parsed_value = _convert_to_number(parsed_value)\n\n        return parsed_value, True\n    except (json.JSONDecodeError, ValueError):\n        pass\n\n    # Strategy 2: Unescape and parse\n    try:\n        wrapped = json.loads('{\"tmp\": \"' + json_value + '\"}')\n        parsed_value = json.loads(wrapped[\"tmp\"])\n\n        if arg_type == \"number\" and isinstance(parsed_value, str):\n            parsed_value = _convert_to_number(parsed_value)\n\n        return parsed_value, True\n    except (json.JSONDecodeError, ValueError, KeyError):\n        pass\n\n    # Strategy 3: ast.literal_eval\n    try:\n        parsed_value = ast.literal_eval(json_value)\n        return parsed_value, True\n    except (ValueError, SyntaxError):\n        pass\n\n    # Strategy 4: Treat as string\n    try:\n        quoted_value = json.dumps(str(json_value))\n        return 
json.loads(quoted_value), True\n    except (json.JSONDecodeError, ValueError):\n        return json_value, False\n\n\nclass Glm4MoeDetector(BaseFormatDetector):\n    \"\"\"\n    Detector for GLM-4.5 and GLM-4.6 models.\n    Assumes function call format (with actual newlines):\n      <tool_call>get_weather\n      <arg_key>city</arg_key>\n      <arg_value>北京</arg_value>\n      <arg_key>date</arg_key>\n      <arg_value>2024-06-27</arg_value>\n      </tool_call>\n\n    Or with literal \\n characters (escaped as \\\\n in the output):\n      <tool_call>get_weather\\n<arg_key>city</arg_key>\\n<arg_value>北京</arg_value>\\n</tool_call>\n\n    Uses a streaming state machine to convert XML to JSON incrementally for maximum speed.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.bot_token = \"<tool_call>\"\n        self.eot_token = \"</tool_call>\"\n        self.func_call_regex = r\"<tool_call>.*?</tool_call>\"\n        self.func_detail_regex = re.compile(\n            r\"<tool_call>(.*?)(?:\\\\n|\\n)(.*)</tool_call>\", re.DOTALL\n        )\n        self.func_arg_regex = re.compile(\n            r\"<arg_key>(.*?)</arg_key>(?:\\\\n|\\s)*<arg_value>(.*?)</arg_value>\",\n            re.DOTALL,\n        )\n        self._last_arguments = \"\"\n        self.current_tool_id = -1\n        self.current_tool_name_sent = False\n        self._streamed_raw_length = 0\n        self._reset_streaming_state()\n\n    def _reset_streaming_state(self) -> None:\n        \"\"\"Reset the streaming state machine for a new tool call.\"\"\"\n        self._stream_state = StreamState.INIT\n        self._current_key = \"\"\n        self._current_value = \"\"\n        self._xml_tag_buffer = \"\"\n        self._is_first_param = True\n        self._value_started = False\n        self._cached_value_type: Optional[str] = (\n            None  # Cache the value type for consistency\n        )\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if the text 
contains a glm-4.5 / glm-4.6 format tool call.\"\"\"\n        return self.bot_token in text\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"\n        One-time parsing: Detects and parses tool calls in the provided text.\n\n        :param text: The complete text to parse.\n        :param tools: List of available tools.\n        :return: StreamingParseResult with the text preceding the first tool call as normal_text and the parsed tool calls as calls.\n        \"\"\"\n        idx = text.find(self.bot_token)\n        normal_text = text[:idx].strip() if idx != -1 else text\n        if self.bot_token not in text:\n            return StreamingParseResult(normal_text=normal_text, calls=[])\n        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)\n        calls = []\n        try:\n            for match_result in match_result_list:\n                # Get function name\n                func_detail = self.func_detail_regex.search(match_result)\n                if func_detail is None:\n                    continue\n                func_name = func_detail.group(1) if func_detail.group(1) else \"\"\n                func_args = func_detail.group(2) if func_detail.group(2) else \"\"\n                pairs = self.func_arg_regex.findall(func_args)\n\n                # Parse arguments using shared method\n                arguments = self._parse_argument_pairs(pairs, func_name, tools)\n\n                # construct match_result for parse_base_json\n                match_result = {\"name\": func_name, \"parameters\": arguments}\n                calls.extend(self.parse_base_json(match_result, tools))\n            return StreamingParseResult(normal_text=normal_text, calls=calls)\n        except Exception as e:\n            logger.error(f\"Error in detect_and_parse: {e}\", exc_info=True)\n            # return the normal text if parsing fails\n            return StreamingParseResult(normal_text=text)\n\n    def _get_value_type(self, 
func_name: str, key: str, tools: List[Tool]) -> str:\n        \"\"\"Get parameter type from tool definition, with fallback to auto-detection.\n\n        Args:\n            func_name: Name of the function\n            key: Parameter name\n            tools: List of available tools\n\n        Returns:\n            Type string: 'string', 'number', 'object', 'array', or 'boolean'\n        \"\"\"\n        arg_type = get_argument_type(func_name, key, tools)\n        if arg_type:\n            return arg_type\n\n        # Fall back to best-effort type detection from the value\n        value_content = self._current_value.strip() if self._current_value else \"\"\n\n        if not value_content:\n            return \"string\"\n\n        # Try to parse as valid JSON first\n        try:\n            parsed = json.loads(value_content)\n            if isinstance(parsed, dict):\n                return \"object\"\n            elif isinstance(parsed, list):\n                return \"array\"\n            elif isinstance(parsed, bool):\n                return \"boolean\"\n            elif isinstance(parsed, (int, float)):\n                return \"number\"\n            # For string values, check if they look like numbers\n            elif isinstance(parsed, str):\n                if parsed.isdigit() or (\n                    parsed.startswith(\"-\") and parsed[1:].isdigit()\n                ):\n                    return \"number\"\n                return \"string\"\n        except json.JSONDecodeError:\n            # Not valid JSON, try heuristic detection\n            first_char = value_content[0] if value_content else \"\"\n\n            if first_char.isdigit() or first_char in [\"-\", \".\"]:\n                return \"number\"\n            elif first_char in [\"{\", \"[\"]:\n                return \"object\"\n            elif first_char in ['\"', \"'\"]:\n                return \"string\"\n\n        # Default to string (safest fallback)\n        return \"string\"\n\n    def 
_format_value_complete(self, value: str, value_type: str) -> str:\n        \"\"\"Format complete value based on type.\n\n        Args:\n            value: Raw value string\n            value_type: Expected type ('string', 'number', 'object')\n\n        Returns:\n            Properly formatted JSON value string\n        \"\"\"\n        if value_type == \"string\":\n            # Ensure proper JSON string formatting with quotes\n            return json.dumps(value, ensure_ascii=False)\n        elif value_type == \"number\":\n            num = _convert_to_number(value.strip())\n            if isinstance(num, str):\n                # _convert_to_number returns its input unchanged when the value is not\n                # numeric, so emit it as a JSON string rather than a bare token\n                logger.warning(\n                    f\"Failed to parse '{value}' as number, treating as string\"\n                )\n                return json.dumps(str(value), ensure_ascii=False)\n            return str(num)\n        else:\n            # For object/array types, return as-is (should already be valid JSON)\n            return value\n\n    def _process_xml_to_json_streaming(\n        self, raw_increment: str, func_name: str, tools: List[Tool]\n    ) -> str:\n        \"\"\"Convert XML increment to JSON streaming output using state machine.\n\n        This method processes XML fragments character by character and converts them\n        to JSON format incrementally. 
It maintains state across calls to handle\n        partial XML tags and values.\n\n        Args:\n            raw_increment: New XML content to process\n            func_name: Name of the function being called\n            tools: List of available tools for type inference\n\n        Returns:\n            JSON string increment to append to the output\n        \"\"\"\n        json_output = \"\"\n\n        for char in raw_increment:\n            self._xml_tag_buffer += char\n\n            if self._stream_state in [StreamState.INIT, StreamState.BETWEEN]:\n                if self._xml_tag_buffer.endswith(\"<arg_key>\"):\n                    self._stream_state = StreamState.IN_KEY\n                    self._current_key = \"\"\n                    self._xml_tag_buffer = \"\"\n                    json_output += \"{\" if self._is_first_param else \", \"\n                    self._is_first_param = False\n\n            elif self._stream_state == StreamState.IN_KEY:\n                if self._xml_tag_buffer.endswith(\"</arg_key>\"):\n                    self._current_key = self._xml_tag_buffer[:-10].strip()\n                    self._xml_tag_buffer = \"\"\n                    self._stream_state = StreamState.WAITING_VALUE\n                    json_output += (\n                        json.dumps(self._current_key, ensure_ascii=False) + \": \"\n                    )\n\n            elif self._stream_state == StreamState.WAITING_VALUE:\n                if self._xml_tag_buffer.endswith(\"<arg_value>\"):\n                    self._stream_state = StreamState.IN_VALUE\n                    self._current_value = \"\"\n                    self._xml_tag_buffer = \"\"\n                    self._value_started = False\n                    # Determine and cache the value type at the start\n                    self._cached_value_type = self._get_value_type(\n                        func_name, self._current_key, tools\n                    )\n\n            elif self._stream_state == 
StreamState.IN_VALUE:\n                if self._xml_tag_buffer.endswith(\"</arg_value>\"):\n                    final_value = self._xml_tag_buffer[:-12]\n                    self._current_value += final_value\n\n                    # Use cached value type for consistency\n                    value_type = self._cached_value_type or \"string\"\n\n                    if self._value_started:\n                        # Output any remaining content\n                        if final_value:\n                            if value_type == \"string\":\n                                json_output += json.dumps(\n                                    final_value, ensure_ascii=False\n                                )[1:-1]\n                            else:\n                                json_output += final_value\n                        # Always output closing quote for string type when value was started\n                        if value_type == \"string\":\n                            json_output += '\"'\n                    else:\n                        # Value was never started (empty or complete in one chunk)\n                        json_output += self._format_value_complete(\n                            self._current_value, value_type\n                        )\n\n                    self._xml_tag_buffer = \"\"\n                    self._stream_state = StreamState.BETWEEN\n                    self._current_value = \"\"\n                    self._value_started = False\n                    self._cached_value_type = None  # Reset cached type\n                else:\n                    closing_tag = \"</arg_value>\"\n                    is_potential_closing = len(self._xml_tag_buffer) <= len(\n                        closing_tag\n                    ) and closing_tag.startswith(self._xml_tag_buffer)\n\n                    if not is_potential_closing:\n                        content = self._xml_tag_buffer\n                        # Use cached value type for consistency\n     
                   value_type = self._cached_value_type or \"string\"\n\n                        if value_type == \"string\":\n                            if not self._value_started:\n                                json_output += '\"'\n                                self._value_started = True\n                            if content:\n                                json_output += json.dumps(content, ensure_ascii=False)[\n                                    1:-1\n                                ]\n                                self._current_value += content\n                                self._xml_tag_buffer = \"\"\n                        elif value_type == \"number\":\n                            if content:\n                                if not self._value_started:\n                                    self._value_started = True\n                                json_output += content\n                                self._current_value += content\n                                self._xml_tag_buffer = \"\"\n                        else:\n                            # For object/array types, output as-is\n                            if content:\n                                if not self._value_started:\n                                    self._value_started = True\n                                json_output += content\n                                self._current_value += content\n                                self._xml_tag_buffer = \"\"\n\n        return json_output\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format.\n        Uses a state machine to convert XML to JSON incrementally for true character-by-character streaming.\n        Outputs JSON increments immediately as XML data arrives.\n        \"\"\"\n        self._buffer += new_text\n        current_text = self._buffer\n\n        # 
Check if we have a tool call\n        has_tool_call = self.bot_token in current_text\n\n        if not has_tool_call:\n            # Check if buffer could be the start of a tool call\n            # Keep buffer if it could be a partial match of bot_token\n            is_potential_start = any(\n                self.bot_token.startswith(current_text[-i:])\n                for i in range(1, min(len(current_text), len(self.bot_token)) + 1)\n            )\n\n            if not is_potential_start:\n                # Not a potential tool call, return as normal text\n                # Must return the entire buffer (current_text), not just new_text,\n                # because buffer may contain previously accumulated characters like '<'\n                # that turned out not to be part of a tool call\n                output_text = current_text\n                self._buffer = \"\"\n                if self.eot_token in output_text:\n                    output_text = output_text.replace(self.eot_token, \"\")\n                return StreamingParseResult(normal_text=output_text)\n            else:\n                # Could be start of tool call, keep buffering\n                return StreamingParseResult(normal_text=\"\", calls=[])\n\n        if not hasattr(self, \"_tool_indices\"):\n            self._tool_indices = self._get_tool_indices(tools)\n\n        calls: list[ToolCallItem] = []\n        try:\n            # Try to match a partial or complete tool call\n            partial_match = re.search(\n                pattern=r\"<tool_call>(.*?)(?:\\\\n|\\n)(.*?)(</tool_call>|$)\",\n                string=current_text,\n                flags=re.DOTALL,\n            )\n            if partial_match:\n                func_name_raw = partial_match.group(1)\n                func_args_raw = partial_match.group(2)\n                is_tool_end = partial_match.group(3)\n\n                # Only proceed if we have a non-empty function name\n                if func_name_raw is None or not 
func_name_raw.strip():\n                    # If we only have the start token without a function name,\n                    # continue buffering until we get more content\n                    return StreamingParseResult(normal_text=\"\", calls=[])\n\n                func_name = func_name_raw.strip()\n                func_args_raw = func_args_raw.strip() if func_args_raw else \"\"\n\n                # Initialize state if this is the first tool call\n                if self.current_tool_id == -1:\n                    self.current_tool_id = 0\n                    self.prev_tool_call_arr = []\n                    self.streamed_args_for_tool = [\"\"]\n                    self._streamed_raw_length = 0\n                    self.current_tool_name_sent = False\n                    self._reset_streaming_state()\n\n                # Ensure we have enough entries in our tracking arrays\n                while len(self.prev_tool_call_arr) <= self.current_tool_id:\n                    self.prev_tool_call_arr.append({})\n                while len(self.streamed_args_for_tool) <= self.current_tool_id:\n                    self.streamed_args_for_tool.append(\"\")\n\n                # Send tool name first if not sent yet\n                if not self.current_tool_name_sent:\n                    calls.append(\n                        ToolCallItem(\n                            tool_index=self.current_tool_id,\n                            name=func_name,\n                            parameters=\"\",\n                        )\n                    )\n                    self.current_tool_name_sent = True\n                    self._streamed_raw_length = 0\n                    self._reset_streaming_state()\n                    # Store the tool call info\n                    self.prev_tool_call_arr[self.current_tool_id] = {\n                        \"name\": func_name,\n                        \"arguments\": {},\n                    }\n                else:\n                    # Process XML 
to JSON streaming\n                    current_raw_length = len(func_args_raw)\n\n                    if current_raw_length > self._streamed_raw_length:\n                        # Get the new raw XML content\n                        raw_increment = func_args_raw[self._streamed_raw_length :]\n\n                        # Convert XML increment to JSON increment using state machine\n                        json_increment = self._process_xml_to_json_streaming(\n                            raw_increment, func_name, tools\n                        )\n\n                        # CRITICAL: Update streamed length BEFORE checking json_increment\n                        # Even if json_increment is empty, the input has been consumed by the state machine\n                        self._streamed_raw_length = current_raw_length\n\n                        if json_increment:\n                            calls.append(\n                                ToolCallItem(\n                                    tool_index=self.current_tool_id,\n                                    name=None,\n                                    parameters=json_increment,\n                                )\n                            )\n                            self._last_arguments += json_increment\n                            self.streamed_args_for_tool[\n                                self.current_tool_id\n                            ] += json_increment\n\n                    if is_tool_end == self.eot_token:\n                        if self._is_first_param:\n                            empty_object = \"{}\"\n                            calls.append(\n                                ToolCallItem(\n                                    tool_index=self.current_tool_id,\n                                    name=None,\n                                    parameters=empty_object,\n                                )\n                            )\n                            self._last_arguments += empty_object\n    
                    elif not self._last_arguments.endswith(\"}\"):\n                            closing_brace = \"}\"\n                            calls.append(\n                                ToolCallItem(\n                                    tool_index=self.current_tool_id,\n                                    name=None,\n                                    parameters=closing_brace,\n                                )\n                            )\n                            self._last_arguments += closing_brace\n                            self.streamed_args_for_tool[\n                                self.current_tool_id\n                            ] += closing_brace\n\n                        try:\n                            pairs = self.func_arg_regex.findall(func_args_raw)\n                            if pairs:\n                                arguments = self._parse_argument_pairs(\n                                    pairs, func_name, tools\n                                )\n                                self.prev_tool_call_arr[self.current_tool_id][\n                                    \"arguments\"\n                                ] = arguments\n                        except Exception as e:\n                            logger.debug(\n                                f\"Failed to parse arguments: {e}\", exc_info=True\n                            )\n\n                        # Remove the completed tool call from buffer\n                        self._buffer = current_text[partial_match.end(3) :]\n\n                        result = StreamingParseResult(normal_text=\"\", calls=calls)\n                        self.current_tool_id += 1\n                        self._last_arguments = \"\"\n                        self.current_tool_name_sent = False\n                        self._streamed_raw_length = 0\n                        self._reset_streaming_state()\n                        return result\n\n            return StreamingParseResult(normal_text=\"\", 
calls=calls)\n\n        except Exception as e:\n            logger.error(f\"Error in parse_streaming_increment: {e}\", exc_info=True)\n            return StreamingParseResult(normal_text=current_text)\n\n    def _parse_argument_pairs(\n        self, pairs: List[Tuple[str, str]], func_name: str, tools: List[Tool]\n    ) -> Dict[str, Any]:\n        \"\"\"Parse argument key-value pairs with type coercion.\n\n        Args:\n            pairs: List of (key, value) tuples from regex matching\n            func_name: Name of the function\n            tools: List of available tools\n\n        Returns:\n            Dictionary of parsed arguments\n        \"\"\"\n        arguments = {}\n        for arg_key, arg_value in pairs:\n            arg_key = arg_key.strip()\n            arg_value = arg_value.strip()\n            arg_type = get_argument_type(func_name, arg_key, tools)\n            parsed_value, is_good_json = parse_arguments(arg_value, arg_type)\n\n            if arg_type == \"string\":\n                # Only convert to string if explicitly defined as string type\n                if isinstance(parsed_value, str):\n                    arguments[arg_key] = parsed_value\n                elif isinstance(parsed_value, (dict, list)):\n                    # If parsed as dict/list but schema says string, convert to JSON string\n                    arguments[arg_key] = json.dumps(parsed_value, ensure_ascii=False)\n                else:\n                    arguments[arg_key] = str(parsed_value)\n            elif arg_type is None:\n                # If type is not defined, keep the parsed value as-is\n                arguments[arg_key] = parsed_value if is_good_json else arg_value\n            else:\n                # For other types (number, object, array, etc.), use parsed value\n                arguments[arg_key] = parsed_value if is_good_json else arg_value\n\n        return arguments\n\n    def supports_structural_tag(self) -> bool:\n        return False\n\n    def 
structure_info(self) -> _GetInfoFunc:\n        raise NotImplementedError()\n"
  },
  {
    "path": "python/sglang/srt/function_call/gpt_oss_detector.py",
    "content": "import json\nimport logging\nimport re\nfrom typing import List, Optional\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.environ import envs\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    ToolCallItem,\n    _GetInfoFunc,\n)\nfrom sglang.srt.parser.harmony_parser import HarmonyParser\n\nlogger = logging.getLogger(__name__)\n\n\nclass GptOssDetector(BaseFormatDetector):\n    \"\"\"\n    Detector for T4-style function calls using HarmonyParser.\n\n    Handles tool calls in the format:\n    <|channel|>commentary to={namespace.function}<|constrain|>json<|message|>{args}<|call|>\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.harmony_parser = HarmonyParser()\n        self.bot_token = \"<|start|>assistant<|channel|>commentary\"\n        self.eot_token = \"<|call|>\"\n\n        # Pattern to extract function name and JSON from tool_call event content\n        self.tool_extract_pattern = re.compile(\n            r\"to=([a-zA-Z_][a-zA-Z0-9_.-]*)\\s*<\\|constrain\\|>json<\\|message\\|>(.*?)(?:<\\|call\\|>|$)\",\n            re.DOTALL,\n        )\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if text contains TypeScript-style function call markers.\"\"\"\n        return self.bot_token in text\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"Parse TypeScript-style function calls from complete text.\"\"\"\n        if not self.has_tool_call(text):\n            return StreamingParseResult(normal_text=text, calls=[])\n\n        # Parse with HarmonyParser\n        events = self.harmony_parser.parse(text)\n        # Flush buffer for complete parsing\n        events += self.harmony_parser.parse(\"\")\n\n        tool_indices = self._get_tool_indices(tools)\n        calls = []\n        normal_parts = []\n        
tool_index = 0\n\n        for event in events:\n            if event.event_type == \"tool_call\":\n                # Extract tool call from event content\n                tool_call = self._extract_tool_call_from_event(\n                    event.raw_text if event.raw_text else event.content,\n                    tool_indices,\n                    tool_index,\n                )\n                if tool_call:\n                    calls.append(tool_call)\n                    tool_index += 1\n            elif event.event_type == \"normal\":\n                normal_parts.append(event.content)\n            # Ignore reasoning events in function call context\n\n        normal_text = \" \".join(normal_parts).strip()\n        return StreamingParseResult(normal_text=normal_text, calls=calls)\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"Parse incremental streaming text for TypeScript-style function calls.\"\"\"\n        self._buffer += new_text\n\n        # Always use HarmonyParser for parsing to ensure proper filtering\n        events = self.harmony_parser.parse(new_text)\n\n        # If there are no parsed events and the chunk contains no Harmony structural\n        # markers, treat it as plain text and pass it through. 
This fixes a bug where\n        # normal content was held in the buffer when tools were provided but not used.\n        if not events:\n            has_harmony_markers = any(\n                marker in self._buffer\n                for marker in (\n                    \"<|start|>\",\n                    \"<|channel|>\",\n                    \"<|message|>\",\n                    \"<|constrain|>\",\n                    \"<|end|>\",\n                    \"<|call|>\",\n                    \"<|return|>\",\n                    \"assistantfinal\",\n                )\n            )\n            if not has_harmony_markers:\n                # Plain text with no tool markers — emit as normal content\n                out = self._buffer\n                self._buffer = \"\"\n                return StreamingParseResult(normal_text=out, calls=[])\n\n        # Quick check if we might have tool calls\n        if (\n            \"<|channel|>commentary to=\" not in self._buffer\n            and not self.current_tool_name_sent\n        ):\n            # No tool calls detected, check for final content\n            if (\n                \"<|channel|>final\" in self._buffer\n                or \"assistantfinal\" in self._buffer.lower()\n            ):\n                # Extract normal text from events\n                normal_text = \"\".join(\n                    [e.content for e in events if e.event_type == \"normal\"]\n                )\n                if normal_text:\n                    self._buffer = \"\"\n                    return StreamingParseResult(normal_text=normal_text, calls=[])\n\n            # For other content, extract normal text from events (with filtering applied)\n            normal_text = \"\".join(\n                [e.content for e in events if e.event_type == \"normal\"]\n            )\n            if normal_text or events:\n                self._buffer = \"\"\n                return StreamingParseResult(normal_text=normal_text, calls=[])\n            else:\n      
          # No events processed, continue buffering\n                return StreamingParseResult(normal_text=\"\", calls=[])\n\n        if not events:\n            # No complete events yet\n            return StreamingParseResult(normal_text=\"\", calls=[])\n\n        # Initialize state if needed\n        if not hasattr(self, \"_tool_indices\"):\n            self._tool_indices = self._get_tool_indices(tools)\n\n        calls = []\n        normal_text = \"\"\n\n        for event in events:\n            if event.event_type == \"tool_call\":\n                # We got a complete tool call from HarmonyParser\n                tool_call_info = self._extract_tool_call_from_event(\n                    event.raw_text if event.raw_text else event.content,\n                    self._tool_indices,\n                    self.current_tool_id if self.current_tool_id >= 0 else 0,\n                )\n\n                if tool_call_info:\n                    # Initialize state if first tool\n                    if self.current_tool_id == -1:\n                        self.current_tool_id = 0\n                        self.prev_tool_call_arr = []\n                        self.streamed_args_for_tool = [\"\"]\n\n                    # Ensure arrays are large enough\n                    while len(self.prev_tool_call_arr) <= self.current_tool_id:\n                        self.prev_tool_call_arr.append({})\n                    while len(self.streamed_args_for_tool) <= self.current_tool_id:\n                        self.streamed_args_for_tool.append(\"\")\n\n                    # Store tool call info\n                    self.prev_tool_call_arr[self.current_tool_id] = {\n                        \"name\": tool_call_info.name,\n                        \"arguments\": json.loads(tool_call_info.parameters),\n                    }\n\n                    # Emit the complete tool call at once\n                    # (Could be modified to emit name first, then args, if needed)\n                    
calls.append(tool_call_info)\n\n                    # Mark as streamed\n                    self.streamed_args_for_tool[self.current_tool_id] = (\n                        tool_call_info.parameters\n                    )\n\n                    # Move to next tool\n                    self.current_tool_id += 1\n                    self.current_tool_name_sent = False\n\n            elif event.event_type == \"normal\":\n                normal_text += event.content\n\n        # Clear buffer since HarmonyParser handles buffering\n        self._buffer = \"\"\n\n        return StreamingParseResult(normal_text=normal_text, calls=calls)\n\n    def _extract_tool_call_from_event(\n        self, content: str, tool_indices: dict, tool_index: int\n    ) -> Optional[ToolCallItem]:\n        \"\"\"\n        Extract tool call information from HarmonyParser event content.\n\n        Content format: \"commentary to=functions.get_weather<|constrain|>json<|message|>{...}\"\n        \"\"\"\n        match = self.tool_extract_pattern.search(content)\n\n        if not match:\n            logger.debug(f\"Could not extract tool call from: {content[:100]}\")\n            return None\n\n        full_function_name = match.group(1)\n        json_content = match.group(2)\n\n        # Extract function name (last part after .)\n        function_name = (\n            full_function_name.split(\".\")[-1]\n            if \".\" in full_function_name\n            else full_function_name\n        )\n\n        # Check if tool exists\n        if function_name not in tool_indices:\n            logger.debug(f\"Function {function_name} not in available tools\")\n            if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():\n                return None  # Skip unknown tools (default legacy behavior)\n\n        # Parse JSON arguments\n        try:\n            arguments = json.loads(json_content) if json_content.strip() else {}\n        except json.JSONDecodeError as e:\n            logger.debug(f\"Failed to parse 
JSON arguments: {e}\")\n            return None\n\n        return ToolCallItem(\n            tool_index=tool_index,\n            name=function_name,\n            parameters=json.dumps(arguments, ensure_ascii=False),\n        )\n\n    def structure_info(self) -> _GetInfoFunc:\n        raise NotImplementedError(\"structure_info not used with HarmonyParser\")\n"
  },
  {
    "path": "python/sglang/srt/function_call/hermes_detector.py",
    "content": "import json\nimport logging\nimport re\nfrom typing import List\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    StructureInfo,\n    _GetInfoFunc,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass HermesDetector(BaseFormatDetector):\n    \"\"\"\n    Detector for Hermes tool call format.\n\n    Format:\n        <tool_call>{\"name\": \"...\", \"arguments\": {...}}</tool_call>\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.bot_token = \"<tool_call>\"\n        self.eot_token = \"</tool_call>\"\n        self.tool_call_regex = re.compile(\n            r\"<tool_call>(.*?)</tool_call>|<tool_call>(.*)\", re.DOTALL\n        )\n        self._normal_text_buffer = \"\"\n\n    def has_tool_call(self, text: str) -> bool:\n        return self.bot_token in text\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"\n        One-time parsing: Detects and parses tool calls in the provided text.\n        \"\"\"\n        idx = text.find(self.bot_token)\n        normal_text = text[:idx].strip() if idx != -1 else text\n        if self.bot_token not in text:\n            return StreamingParseResult(normal_text=normal_text, calls=[])\n\n        calls = []\n        try:\n            for match in self.tool_call_regex.findall(text):\n                raw = match[0] or match[1]\n                if not raw:\n                    continue\n                parsed = json.loads(raw.strip())\n                # parse_base_json accepts either a single call object or a list of calls\n                calls.extend(self.parse_base_json(parsed, tools))\n            return StreamingParseResult(normal_text=normal_text, calls=calls)\n        except Exception as e:\n            
logger.error(f\"Error in detect_and_parse: {e}\")\n            return StreamingParseResult(normal_text=text)\n\n    def _clean_normal_text(self, text: str) -> str:\n        if not text:\n            return text\n\n        self._normal_text_buffer += text\n\n        if self.eot_token in self._normal_text_buffer:\n            cleaned = self._normal_text_buffer.replace(self.eot_token, \"\")\n            self._normal_text_buffer = \"\"\n            return cleaned\n\n        partial_len = self._ends_with_partial_token(\n            self._normal_text_buffer, self.eot_token\n        )\n        if partial_len:\n            safe_text = self._normal_text_buffer[:-partial_len]\n            self._normal_text_buffer = self._normal_text_buffer[-partial_len:]\n            return safe_text\n\n        cleaned = self._normal_text_buffer\n        self._normal_text_buffer = \"\"\n        return cleaned\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming parsing: handle normal text, partial tags, and tool calls.\n        \"\"\"\n        self._buffer += new_text\n        current_text = self._buffer\n\n        if self.bot_token not in current_text:\n            partial_len = self._ends_with_partial_token(current_text, self.bot_token)\n            if partial_len:\n                safe_text = current_text[:-partial_len]\n                self._buffer = current_text[-partial_len:]\n            else:\n                safe_text = current_text\n                self._buffer = \"\"\n            return StreamingParseResult(normal_text=self._clean_normal_text(safe_text))\n\n        bot_pos = current_text.find(self.bot_token)\n        if bot_pos > 0:\n            normal_text = current_text[:bot_pos]\n            self._buffer = current_text[bot_pos:]\n            return StreamingParseResult(normal_text=normal_text)\n\n        result = super().parse_streaming_increment(new_text=\"\", tools=tools)\n     
   if result.normal_text:\n            result.normal_text = self._clean_normal_text(result.normal_text)\n        return result\n\n    def structure_info(self) -> _GetInfoFunc:\n        return lambda name: StructureInfo(\n            begin='<tool_call>{\"name\":\"' + name + '\", \"arguments\":',\n            end=\"}</tool_call>\",\n            trigger=\"<tool_call>\",\n        )\n"
  },
  {
    "path": "python/sglang/srt/function_call/internlm_detector.py",
    "content": "# modified from https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/serve/openai/tool_parser/internlm2_parser.py\n\nimport json\nimport logging\nimport re\nfrom typing import List\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.environ import envs\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    StructureInfo,\n    ToolCallItem,\n    _GetInfoFunc,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass InternlmDetector(BaseFormatDetector):\n    \"\"\"\n    Detector for InternLM2/Intern-S1 model function call format.\n\n    The InternLM format uses special tokens to delimit function calls\n    with JSON for arguments.\n\n    Format Structure:\n    ```\n    text<|action_start|> <|plugin|>\n    {json}<|action_end|>\n    ```\n\n    Examples:\n    ```\n    What's the weather like?<|action_start|> <|plugin|>\n    {\"name\": \"get_weather\", \"parameters\": {\"location\": \"Tokyo\"}}<|action_end|>\n    ```\n\n    Key Components:\n    - Tool Call Start: `<|action_start|> <|plugin|>`\n    - Tool Call End: `<|action_end|>`\n    - Arguments: JSON object with `name` and `parameters`/`arguments`\n    - Supports multiple sequential tool calls in both streaming and non-streaming modes\n\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.bot_token = \"<|action_start|> <|plugin|>\"\n        self.eot_token = \"<|action_end|>\"\n        self.position = 0\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if the text contains an InternLM format tool call.\"\"\"\n        has_call = self.bot_token in text\n        return has_call\n\n    def get_arguments(self, obj):\n        \"\"\"Extract arguments from object, supporting both 'parameters' and 'arguments' keys.\"\"\"\n        if \"parameters\" in obj:\n            return obj.get(\"parameters\")\n        elif \"arguments\" in obj:\n 
           return obj.get(\"arguments\")\n        return None\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"\n        One-time parsing: Detects and parses tool calls in the provided text.\n        Supports multiple tool calls in the format:\n        <|action_start|> <|plugin|>\\n{JSON}<|action_end|>\n\n        :param text: The complete text to parse.\n        :param tools: List of available tools.\n        :return: StreamingParseResult with normal text and parsed tool calls.\n        \"\"\"\n\n        # Find the first occurrence of tool call marker to extract normal text\n        idx = text.find(self.bot_token)\n        normal_text = text[:idx].strip() if idx != -1 else text\n\n        if self.bot_token not in text:\n            logger.warning(\"[InternLM Tool Call] No tool call markers found in text\")\n            return StreamingParseResult(normal_text=normal_text, calls=[])\n\n        # Use regex to find all tool call blocks\n        # Pattern matches: {self.bot_token}{...}{self.eot_token}\n        tool_call_pattern = (\n            rf\"{re.escape(self.bot_token)}\\s*(.*?){re.escape(self.eot_token)}\"\n        )\n        matches = re.findall(tool_call_pattern, text, re.DOTALL)\n\n        if not matches:\n            logger.warning(\"[InternLM Tool Call] No complete tool call blocks found\")\n            return StreamingParseResult(normal_text=text, calls=[])\n\n        logger.info(f\"[InternLM Tool Call] Found {len(matches)} tool call(s)\")\n\n        calls = []\n        tool_indices = self._get_tool_indices(tools)\n\n        try:\n            for idx, action_json in enumerate(matches):\n                action_json = action_json.strip()\n\n                try:\n                    # Parse the JSON\n                    action_dict = json.loads(action_json)\n                    name = action_dict.get(\"name\")\n                    parameters = self.get_arguments(action_dict)\n\n                    if not 
parameters:\n                        parameters = {}\n\n                    logger.info(\n                        f\"[InternLM Tool Call] Parsed tool call #{idx+1}: name={name}, \"\n                        f\"parameters={json.dumps(parameters, ensure_ascii=False)}\"\n                    )\n\n                    # Validate tool name\n                    if not (name and name in tool_indices):\n                        logger.warning(\n                            f\"[InternLM Tool Call] Model attempted to call undefined function: {name}, \"\n                            f\"available_tools={list(tool_indices.keys())}\"\n                        )\n                        if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():\n                            continue  # Skip this tool call\n\n                    # Create tool call item and add to list.\n                    # Forwarded unknown tools have no entry in tool_indices, so fall back\n                    # to -1 instead of indexing directly (which would raise KeyError).\n                    tool_call = ToolCallItem(\n                        tool_index=tool_indices.get(name, -1),\n                        name=name,\n                        parameters=json.dumps(parameters, ensure_ascii=False),\n                    )\n                    calls.append(tool_call)\n\n                except json.JSONDecodeError as e:\n                    logger.error(\n                        f\"[InternLM Tool Call] Failed to parse JSON for tool call #{idx+1}: {e}\"\n                    )\n                    continue\n\n            logger.info(\n                f\"[InternLM Tool Call] Successfully parsed {len(calls)} tool call(s), \"\n                f\"normal_text_length={len(normal_text)}\"\n            )\n            return StreamingParseResult(normal_text=normal_text, calls=calls)\n\n        except Exception as e:\n            logger.error(\n                f\"[InternLM Tool Call] Error in detect_and_parse: {e}\", exc_info=True\n            )\n            return StreamingParseResult(normal_text=text, calls=[])\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> 
StreamingParseResult:\n        \"\"\"\n        Streaming incremental parsing for InternLM format.\n\n        Emits at most one complete tool call per increment; any text after its\n        end marker stays buffered and is handled on the next increment.\n        \"\"\"\n        self._buffer += new_text\n        current_text = self._buffer\n\n        # Check if we don't have a tool call start marker\n        start = current_text.find(self.bot_token)\n        if start == -1:\n            # No tool call marker found\n            # If we've already processed tool calls, don't return text again\n            if self.current_tool_id > 0:\n                self._buffer = \"\"\n                return StreamingParseResult(normal_text=\"\")\n\n            # Check if buffer could be partial start of bot_token\n            if not self._ends_with_partial_token(current_text, self.bot_token):\n                # Not a partial match, return as normal text\n                normal_text = current_text\n                self._buffer = \"\"\n                # Clean up any stray end tokens\n                if self.eot_token in normal_text:\n                    normal_text = normal_text.replace(self.eot_token, \"\")\n                return StreamingParseResult(normal_text=normal_text)\n            else:\n                # Might be partial start token, keep buffering\n                return StreamingParseResult()\n\n        # Check if we have a complete tool call (with end marker)\n        end = current_text.find(self.eot_token)\n        if end != -1:\n            # We have a complete tool call\n            # Initialize state if this is the first tool call\n            if self.current_tool_id == -1:\n                self.current_tool_id = 0\n                self.prev_tool_call_arr = []\n                self.streamed_args_for_tool = [\"\"]\n\n            # Ensure we have enough entries in our tracking arrays\n            while len(self.prev_tool_call_arr) <= self.current_tool_id:\n                self.prev_tool_call_arr.append({})\n            while len(self.streamed_args_for_tool) <= 
self.current_tool_id:\n                self.streamed_args_for_tool.append(\"\")\n\n            # Use detect_and_parse on the complete tool call\n            complete_section = current_text[: end + len(self.eot_token)]\n            result = self.detect_and_parse(complete_section, tools=tools)\n\n            if result.calls:\n                # Update the tool call index\n                result.calls[0].tool_index = self.current_tool_id\n                # Store the parsed tool call for reference\n                self.prev_tool_call_arr[self.current_tool_id] = {\n                    \"name\": result.calls[0].name,\n                    \"arguments\": json.loads(result.calls[0].parameters),\n                }\n                self.streamed_args_for_tool[self.current_tool_id] = result.calls[\n                    0\n                ].parameters\n                # Increment tool ID for next tool call\n                self.current_tool_id += 1\n\n            # Remove the completed tool call from buffer\n            self._buffer = current_text[end + len(self.eot_token) :]\n            return result\n\n        # We have bot_token but no eot_token yet - handle partial tool call streaming\n        # Extract normal text before the tool call\n        normal_text = current_text[:start]\n        # Keep the tool call part in buffer\n        self._buffer = current_text[start:]\n        return StreamingParseResult(normal_text=normal_text)\n\n    def structure_info(self) -> _GetInfoFunc:\n        \"\"\"\n        Return structure information for constrained generation.\n\n        For InternLM format, the structure is:\n        - begin: <|action_start|> <|plugin|>\\n\n        - end: <|action_end|>\n        - trigger: the begin token\n        \"\"\"\n        return lambda name: StructureInfo(\n            begin='<|action_start|> <|plugin|>\\n{\"name\": \"'\n            + name\n            + '\", \"parameters\": ',\n            end=\"}<|action_end|>\",\n            
trigger=\"<|action_start|> <|plugin|>\",\n        )\n"
  },
  {
    "path": "python/sglang/srt/function_call/json_array_parser.py",
    "content": "from typing import List\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import StreamingParseResult\n\n\nclass JsonArrayParser(BaseFormatDetector):\n    \"\"\"\n    Parser for JSON array tool calls when JSON schema constraints are active.\n\n    This parser is used when tool_choice=\"required\" or a specific tool is named,\n    bypassing model-specific parsers in favor of direct JSON array parsing.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        # Configure for JSON array parsing\n        self.bot_token = \"[\"\n        self.eot_token = \"]\"\n        self.tool_call_separator = \",\"\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"\n        Check if the given text contains a JSON tool call (array or single object).\n        \"\"\"\n        return \"[\" in text or \"{\" in text\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"\n        Parse JSON tool calls using the base class implementation.\n        \"\"\"\n        raise NotImplementedError(\n            \"Detect and parse not supported for JSON schema constraints.\"\n        )\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming incremental parsing with tool validation.\n        \"\"\"\n        return super().parse_streaming_increment(new_text, tools)\n\n    def structure_info(self) -> callable:\n        \"\"\"\n        Return a function that creates StructureInfo for constrained generation.\n        This is not used for JSON schema constraints as they are handled\n        by the constraint backends directly.\n        \"\"\"\n        raise NotImplementedError(\"structure_info not used for JSON schema constraints\")\n"
  },
  {
    "path": "python/sglang/srt/function_call/kimik2_detector.py",
    "content": "import json\nimport logging\nimport re\nfrom typing import List\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    StructureInfo,\n    ToolCallItem,\n    _GetInfoFunc,\n)\nfrom sglang.srt.function_call.utils import _is_complete_json\n\nlogger = logging.getLogger(__name__)\n\n_KIMI_K2_SPECIAL_TOKENS = [\n    \"<|tool_calls_section_begin|>\",\n    \"<|tool_calls_section_end|>\",\n    \"<|tool_call_begin|>\",\n    \"<|tool_call_end|>\",\n    \"<|tool_call_argument_begin|>\",\n]\n\n\ndef _strip_special_tokens(text: str) -> str:\n    \"\"\"Remove all Kimi-K2 tool-call special tokens from text.\"\"\"\n    for token in _KIMI_K2_SPECIAL_TOKENS:\n        text = text.replace(token, \"\")\n    return text\n\n\nclass KimiK2Detector(BaseFormatDetector):\n    \"\"\"\n    Detector for Kimi K2 / K2.5 model function call format.\n\n    Format Structure:\n    ```\n    <|tool_calls_section_begin|>\n    <|tool_call_begin|>functions.{func_name}:{index}<|tool_call_argument_begin|>{json_args}<|tool_call_end|>\n    <|tool_calls_section_end|>\n    ```\n\n    Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n        self.bot_token: str = \"<|tool_calls_section_begin|>\"\n        self.eot_token: str = \"<|tool_calls_section_end|>\"\n\n        self.tool_call_start_token: str = \"<|tool_call_begin|>\"\n        self.tool_call_end_token: str = \"<|tool_call_end|>\"\n        self.tool_call_argument_begin_token: str = \"<|tool_call_argument_begin|>\"\n\n        # Support hyphenated function names (common in MCP tools, e.g. 
mcp__portal__search-documents)\n        self.tool_call_regex = re.compile(\n            r\"<\\|tool_call_begin\\|>\\s*(?P<tool_call_id>[\\w.\\-]+:\\d+)\\s*<\\|tool_call_argument_begin\\|>\\s*(?P<function_arguments>\\{.*?\\})\\s*<\\|tool_call_end\\|>\",\n            re.DOTALL,\n        )\n\n        self.stream_tool_call_portion_regex = re.compile(\n            r\"<\\|tool_call_begin\\|>\\s*(?P<tool_call_id>[\\w.\\-]+:\\d+)\\s*<\\|tool_call_argument_begin\\|>\\s*(?P<function_arguments>\\{.*)\",\n            re.DOTALL,\n        )\n\n        self._last_arguments = \"\"\n\n        # Robust parser for ids like \"functions.search:0\", \"functions.mcp__search-docs:0\", or fallback \"search:0\"\n        self.tool_call_id_regex = re.compile(\n            r\"^(?:functions\\.)?(?P<name>[\\w.\\-]+):(?P<index>\\d+)$\"\n        )\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if the text contains a KimiK2 format tool call.\"\"\"\n        return self.bot_token in text\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"\n        One-time parsing: Detects and parses tool calls in the provided text.\n\n        :param text: The complete text to parse.\n        :param tools: List of available tools.\n        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.\n        \"\"\"\n        if self.bot_token not in text:\n            return StreamingParseResult(normal_text=text, calls=[])\n        try:\n            # there are two possible captures - between tags, or between a\n            # tag and end-of-string so the result of\n            # findall is an array of tuples where one is a function call and\n            # the other is None\n            function_call_tuples = self.tool_call_regex.findall(text)\n\n            logger.debug(\"function_call_tuples: %s\", function_call_tuples)\n\n            tool_calls = []\n            for match in function_call_tuples:\n    
            function_id, function_args = match\n                m = self.tool_call_id_regex.match(function_id)\n                if not m:\n                    logger.warning(\"Unexpected tool_call_id format: %s\", function_id)\n                    continue\n                function_name = m.group(\"name\")\n                function_idx = int(m.group(\"index\"))\n\n                logger.debug(f\"function_name {function_name}\")\n\n                tool_calls.append(\n                    ToolCallItem(\n                        tool_index=function_idx,\n                        name=function_name,\n                        parameters=function_args,\n                    )\n                )\n\n            content = text[: text.find(self.bot_token)]\n            return StreamingParseResult(normal_text=content, calls=tool_calls)\n\n        except Exception as e:\n            logger.error(f\"Error in detect_and_parse: {e}\")\n            # return the normal text if parsing fails\n            return StreamingParseResult(normal_text=text)\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming incremental parsing tool calls for KimiK2 format.\n        \"\"\"\n        self._buffer += new_text\n        current_text = self._buffer\n\n        # Check if we have a tool call (either the start token or individual tool call)\n        has_tool_call = (\n            self.bot_token in current_text or self.tool_call_start_token in current_text\n        )\n\n        if not has_tool_call:\n            self._buffer = \"\"\n            normal_text = _strip_special_tokens(new_text)\n            return StreamingParseResult(normal_text=normal_text)\n\n        if not hasattr(self, \"_tool_indices\"):\n            self._tool_indices = self._get_tool_indices(tools)\n\n        calls: list[ToolCallItem] = []\n        try:\n            match = self.stream_tool_call_portion_regex.search(current_text)\n      
      if match:\n                function_id = match.group(\"tool_call_id\")\n                function_args = match.group(\"function_arguments\")\n\n                m = self.tool_call_id_regex.match(function_id)\n                if not m:\n                    logger.warning(\"Unexpected tool_call_id format: %s\", function_id)\n                    return StreamingParseResult(normal_text=\"\", calls=calls)\n                function_name = m.group(\"name\")\n\n                # Initialize state if this is the first tool call\n                if self.current_tool_id == -1:\n                    self.current_tool_id = 0\n                    self.prev_tool_call_arr = []\n                    self.streamed_args_for_tool = [\"\"]\n\n                # Ensure we have enough entries in our tracking arrays\n                while len(self.prev_tool_call_arr) <= self.current_tool_id:\n                    self.prev_tool_call_arr.append({})\n                while len(self.streamed_args_for_tool) <= self.current_tool_id:\n                    self.streamed_args_for_tool.append(\"\")\n\n                if not self.current_tool_name_sent:\n                    calls.append(\n                        ToolCallItem(\n                            tool_index=self.current_tool_id,\n                            name=function_name,\n                            parameters=\"\",\n                        )\n                    )\n                    self.current_tool_name_sent = True\n                    self.prev_tool_call_arr[self.current_tool_id] = {\n                        \"name\": function_name,\n                        \"arguments\": {},\n                    }\n                else:\n                    argument_diff = (\n                        function_args[len(self._last_arguments) :]\n                        if function_args.startswith(self._last_arguments)\n                        else function_args\n                    )\n\n                    parsed_args_diff = 
argument_diff.split(self.tool_call_end_token, 1)[\n                        0\n                    ]\n\n                    if parsed_args_diff:\n                        calls.append(\n                            ToolCallItem(\n                                tool_index=self.current_tool_id,\n                                name=None,\n                                parameters=parsed_args_diff,\n                            )\n                        )\n                        self._last_arguments += parsed_args_diff\n                        self.streamed_args_for_tool[\n                            self.current_tool_id\n                        ] += parsed_args_diff\n\n                    parsed_args = function_args.split(self.tool_call_end_token, 1)[0]\n                    if _is_complete_json(parsed_args):\n                        try:\n                            parsed_args = json.loads(parsed_args)\n                            self.prev_tool_call_arr[self.current_tool_id][\n                                \"arguments\"\n                            ] = parsed_args\n                        except json.JSONDecodeError:\n                            pass\n\n                        # Find the end of the current tool call and remove only that part from buffer\n                        tool_call_end_pattern = (\n                            r\"<\\|tool_call_begin\\|>.*?<\\|tool_call_end\\|>\"\n                        )\n                        end_match = re.search(\n                            tool_call_end_pattern, current_text, re.DOTALL\n                        )\n                        if end_match:\n                            self._buffer = current_text[end_match.end() :]\n                        else:\n                            self._buffer = \"\"\n\n                        result = StreamingParseResult(normal_text=\"\", calls=calls)\n                        self.current_tool_id += 1\n                        self._last_arguments = \"\"\n                        
self.current_tool_name_sent = False\n                        return result\n\n            return StreamingParseResult(normal_text=\"\", calls=calls)\n\n        except Exception as e:\n            logger.error(f\"Error in parse_streaming_increment: {e}\")\n            return StreamingParseResult(normal_text=_strip_special_tokens(current_text))\n\n    def structure_info(self) -> _GetInfoFunc:\n        \"\"\"Return function that creates StructureInfo for guided generation.\"\"\"\n\n        def get_info(name: str) -> StructureInfo:\n            return StructureInfo(\n                begin=f\"<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:0<|tool_call_argument_begin|>\",\n                end=\"<|tool_call_end|><|tool_calls_section_end|>\",\n                trigger=\"<|tool_calls_section_begin|>\",\n            )\n\n        return get_info\n"
  },
  {
    "path": "python/sglang/srt/function_call/lfm2_detector.py",
    "content": "\"\"\"\nDetector for LFM2 (Liquid Foundation Model 2) function call format.\n\nFormat Structure (Pythonic style):\n```\n<|tool_call_start|>[function_name(arg1=\"value1\", arg2=\"value2\")]<|tool_call_end|>\n```\n\nMultiple tool calls:\n```\n<|tool_call_start|>[func1(arg=\"val\"), func2(arg=\"val\")]<|tool_call_end|>\n```\n\nAlso supports JSON format:\n```\n<|tool_call_start|>[{\"name\": \"func_name\", \"arguments\": {...}}]<|tool_call_end|>\n```\n\"\"\"\n\nimport ast\nimport json\nimport logging\nimport re\nfrom typing import Any, Dict, List, Optional, Tuple\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.environ import envs\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    StructureInfo,\n    ToolCallItem,\n    _GetInfoFunc,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass Lfm2Detector(BaseFormatDetector):\n    \"\"\"\n    Detector for LFM2 (Liquid Foundation Model 2) function call format.\n\n    Supports both Pythonic and JSON formats:\n\n    Pythonic:\n    ```\n    <|tool_call_start|>[calculator(expression=\"5 * 7\")]<|tool_call_end|>\n    ```\n\n    JSON:\n    ```\n    <|tool_call_start|>[{\"name\": \"calculator\", \"arguments\": {\"expression\": \"5 * 7\"}}]<|tool_call_end|>\n    ```\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"\n        Initializes the detector with necessary state variables.\n        \"\"\"\n        super().__init__()\n        self.bot_token = \"<|tool_call_start|>\"\n        self.eot_token = \"<|tool_call_end|>\"\n        self.tool_call_separator = \"\"\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if the text contains an LFM2 format tool call.\"\"\"\n        return self.bot_token in text\n\n    def _get_parameter_value(self, val: ast.AST) -> Any:\n        \"\"\"\n        Extract Python literal value from AST node.\n\n        Handles constants, 
dicts, and lists recursively.\n        Reuses pattern from PythonicDetector.\n        \"\"\"\n        if isinstance(val, ast.Constant):\n            return val.value\n        elif isinstance(val, ast.Dict):\n            return {\n                self._get_parameter_value(k): self._get_parameter_value(v)\n                for k, v in zip(val.keys, val.values)\n                if k is not None  # Handle {**kwargs} case where key is None\n            }\n        elif isinstance(val, ast.List):\n            return [self._get_parameter_value(v) for v in val.elts]\n        elif isinstance(val, ast.Tuple):\n            return tuple(self._get_parameter_value(v) for v in val.elts)\n        elif isinstance(val, ast.Name):\n            # Handle True, False, None as names in older Python\n            if val.id == \"True\":\n                return True\n            elif val.id == \"False\":\n                return False\n            elif val.id == \"None\":\n                return None\n            else:\n                raise ValueError(f\"Unsupported name reference: {val.id}\")\n        elif isinstance(val, ast.UnaryOp) and isinstance(val.op, ast.USub):\n            # Handle negative numbers like -5\n            inner = self._get_parameter_value(val.operand)\n            if isinstance(inner, (int, float)):\n                return -inner\n            raise ValueError(f\"Cannot negate non-numeric value: {inner}\")\n        else:\n            raise ValueError(\n                f\"Tool call arguments must be literals, got: {type(val).__name__}\"\n            )\n\n    def _parse_pythonic_call(\n        self, call: ast.Call, call_index: int, tool_indices: Dict[str, int]\n    ) -> Optional[ToolCallItem]:\n        \"\"\"\n        Parse a single AST Call node into a ToolCallItem.\n\n        Args:\n            call: AST Call node representing a function call\n            call_index: Index of this call in the list of calls\n            tool_indices: Mapping of tool names to their 
indices\n\n        Returns:\n            ToolCallItem if successful, None if the call should be skipped\n        \"\"\"\n        if not isinstance(call.func, ast.Name):\n            logger.warning(\n                f\"Tool call function must be a simple name, got: {type(call.func).__name__}\"\n            )\n            return None\n\n        function_name = call.func.id\n\n        # Validate that the function exists in the tools\n        if function_name not in tool_indices:\n            logger.warning(\n                f\"Model attempted to call undefined function: {function_name}\"\n            )\n            if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():\n                return None  # Skip unknown tools (default legacy behavior)\n\n        # Parse arguments\n        arguments = {}\n        for keyword in call.keywords:\n            if keyword.arg is None:\n                # **kwargs unpacking - skip for now\n                logger.warning(\"Tool call with **kwargs unpacking is not supported\")\n                continue\n            try:\n                arguments[keyword.arg] = self._get_parameter_value(keyword.value)\n            except ValueError as e:\n                logger.warning(f\"Failed to parse argument {keyword.arg}: {e}\")\n                return None\n\n        return ToolCallItem(\n            tool_index=call_index,  # Use the call index in the response, not tool position\n            name=function_name,\n            parameters=json.dumps(arguments, ensure_ascii=False),\n        )\n\n    def _parse_pythonic_content(\n        self, content: str, tools: List[Tool]\n    ) -> Tuple[List[ToolCallItem], str]:\n        \"\"\"\n        Parse Pythonic format tool calls using AST.\n\n        Args:\n            content: The content between tool call tags (without the tags)\n            tools: List of available tools\n\n        Returns:\n            Tuple of (list of parsed calls, error message if any)\n        \"\"\"\n        content = content.strip()\n    
    tool_indices = self._get_tool_indices(tools)\n\n        try:\n            module = ast.parse(content)\n            parsed = getattr(module.body[0], \"value\", None) if module.body else None\n\n            if parsed is None:\n                return [], \"Empty or invalid Python expression\"\n\n            # Handle both single call and list of calls\n            if isinstance(parsed, ast.List):\n                call_nodes = parsed.elts\n            elif isinstance(parsed, ast.Call):\n                call_nodes = [parsed]\n            else:\n                return (\n                    [],\n                    f\"Expected function call or list, got: {type(parsed).__name__}\",\n                )\n\n            # Validate all elements are calls\n            if not all(isinstance(e, ast.Call) for e in call_nodes):\n                return [], \"Not all elements in list are function calls\"\n\n            calls = []\n            for call_index, call in enumerate(call_nodes):\n                item = self._parse_pythonic_call(call, call_index, tool_indices)\n                if item is not None:\n                    calls.append(item)\n\n            return calls, \"\"\n\n        except SyntaxError as e:\n            return [], f\"Python syntax error: {e}\"\n        except Exception as e:\n            logger.exception(\"Unexpected error in pythonic tool call parsing\")\n            return [], f\"Unexpected error: {e}\"\n\n    def _parse_json_content(\n        self, content: str, tools: List[Tool]\n    ) -> Tuple[List[ToolCallItem], str]:\n        \"\"\"\n        Parse JSON format tool calls.\n\n        Uses parse_base_json from BaseFormatDetector for consistent handling\n        of SGLANG_FORWARD_UNKNOWN_TOOLS and tool validation.\n\n        Args:\n            content: The content between tool call tags (without the tags)\n            tools: List of available tools\n\n        Returns:\n            Tuple of (list of parsed calls, error message if any)\n        \"\"\"\n     
   content = content.strip()\n\n        try:\n            parsed = json.loads(content)\n            # parse_base_json handles list/dict normalization, tool validation,\n            # and SGLANG_FORWARD_UNKNOWN_TOOLS consistently with other detectors\n            calls = self.parse_base_json(parsed, tools)\n            return calls, \"\"\n\n        except json.JSONDecodeError as e:\n            return [], f\"JSON parse error: {e}\"\n\n    def _parse_tool_calls_content(\n        self, content: str, tools: List[Tool]\n    ) -> List[ToolCallItem]:\n        \"\"\"\n        Parse the content between tool call tags.\n        Handles both JSON and Pythonic formats.\n        \"\"\"\n        content = content.strip()\n\n        # First, try JSON format (faster check)\n        if content.startswith(\"[{\") or content.startswith(\"{\"):\n            calls, error = self._parse_json_content(content, tools)\n            if calls:\n                return calls\n            # If JSON parsing failed but it looked like JSON, log the error\n            if error:\n                logger.debug(f\"JSON parsing failed: {error}, trying Pythonic format\")\n\n        # Try Pythonic format\n        calls, error = self._parse_pythonic_content(content, tools)\n        if calls:\n            return calls\n\n        if error:\n            logger.warning(f\"Failed to parse tool calls: {error}\")\n\n        return []\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"\n        One-time parsing: Detects and parses tool calls in the provided text.\n        \"\"\"\n        idx = text.find(self.bot_token)\n        normal_text = text[:idx].strip() if idx != -1 else text\n\n        if self.bot_token not in text:\n            return StreamingParseResult(normal_text=normal_text, calls=[])\n\n        # Find all <|tool_call_start|>...<|tool_call_end|> blocks\n        pattern = rf\"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}\"\n        
match_result_list = re.findall(pattern, text, re.DOTALL)\n\n        calls = []\n        for match_result in match_result_list:\n            parsed_calls = self._parse_tool_calls_content(match_result, tools)\n            calls.extend(parsed_calls)\n\n        return StreamingParseResult(normal_text=normal_text, calls=calls)\n\n    def _strip_special_tokens(self, text: str) -> str:\n        \"\"\"Remove special tokens from text.\"\"\"\n        return text.replace(self.bot_token, \"\").replace(self.eot_token, \"\")\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming incremental parsing for LFM2 tool calls.\n\n        This implementation properly handles Pythonic format by:\n        1. Buffering until we see complete <|tool_call_start|>[...]<|tool_call_end|>\n        2. Emitting normal text before tool calls immediately\n        3. Parsing complete tool call blocks using detect_and_parse\n\n        Based on PythonicDetector streaming logic.\n        \"\"\"\n        self._buffer += new_text\n\n        # Check for partial bot_token at the end\n        partial_bot = self._ends_with_partial_token(self._buffer, self.bot_token)\n        partial_eot = self._ends_with_partial_token(self._buffer, self.eot_token)\n\n        # Find bot_token position\n        bot_pos = self._buffer.find(self.bot_token)\n\n        if bot_pos == -1:\n            # No tool call start found\n            if partial_bot:\n                # Might be partial bot_token, hold back that part\n                safe_text = self._buffer[:-partial_bot]\n                self._buffer = self._buffer[-partial_bot:]\n                return StreamingParseResult(normal_text=safe_text)\n            else:\n                # No tool call, emit all as normal text\n                normal_text = self._strip_special_tokens(self._buffer)\n                self._buffer = \"\"\n                return 
StreamingParseResult(normal_text=normal_text)\n\n        # We have bot_token - extract any normal text before it\n        normal_text_before = self._buffer[:bot_pos] if bot_pos > 0 else \"\"\n\n        # Look for the end token\n        eot_pos = self._buffer.find(self.eot_token, bot_pos + len(self.bot_token))\n\n        if eot_pos == -1:\n            # No end token yet - check if we might have a partial one\n            if partial_eot:\n                # Hold back the partial token, but we need to keep buffering\n                # Just emit any normal text before the tool call\n                if normal_text_before:\n                    self._buffer = self._buffer[bot_pos:]\n                    return StreamingParseResult(normal_text=normal_text_before)\n                # Keep buffering\n                return StreamingParseResult(normal_text=\"\")\n\n            # No end token and no partial - keep buffering but emit normal text\n            if normal_text_before:\n                self._buffer = self._buffer[bot_pos:]\n                return StreamingParseResult(normal_text=normal_text_before)\n\n            # Just keep buffering\n            return StreamingParseResult(normal_text=\"\")\n\n        # We have a complete tool call block\n        tool_call_block = self._buffer[bot_pos : eot_pos + len(self.eot_token)]\n        remaining = self._buffer[eot_pos + len(self.eot_token) :]\n\n        # Parse the complete block\n        result = self.detect_and_parse(tool_call_block, tools)\n\n        # Update buffer with remaining text\n        self._buffer = remaining\n\n        # Add any normal text before the tool call\n        if normal_text_before:\n            result.normal_text = normal_text_before + (result.normal_text or \"\")\n\n        return result\n\n    def supports_structural_tag(self) -> bool:\n        \"\"\"\n        Return False because LFM2 uses Pythonic format which is not JSON-compatible.\n\n        structural_tag only supports JSON-compatible content 
between begin and end,\n        so it cannot parse Pythonic function call syntax like `func(arg=\"val\")`.\n        \"\"\"\n        return False\n\n    def structure_info(self) -> _GetInfoFunc:\n        \"\"\"\n        Return structure info for constrained generation.\n\n        Note: This is provided for completeness but won't be used since\n        supports_structural_tag() returns False.\n        \"\"\"\n        return lambda name: StructureInfo(\n            begin=\"<|tool_call_start|>[\" + name + \"(\",\n            end=\")]<|tool_call_end|>\",\n            trigger=\"<|tool_call_start|>\",\n        )\n"
  },
  {
    "path": "python/sglang/srt/function_call/llama32_detector.py",
    "content": "import ast\nimport json\nimport logging\nimport re\nfrom typing import List\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    StructureInfo,\n    _GetInfoFunc,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass Llama32Detector(BaseFormatDetector):\n    \"\"\"\n    Detector for Llama 3.2 models with json tool call format.\n\n    Format Structure:\n    ```\n    <python_tag>{\"name\":\"xxx\", \"arguments\":{...}}\n    ```\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.bot_token = \"<|python_tag|>\"\n        # NOTE: technically Llama3.2 doesn't support well with parallel tool calls\n        # They need specific prompt engineering to support parallel tool calls\n        # Here we use ';' as the separator, which might have compatibility issues\n        # if users define to use a different separator in their prompt\n        self.tool_call_separator = \";\"\n\n    def _convert_python_dict_to_json(self, text: str) -> str:\n        \"\"\"Convert Python dict strings to JSON format.\"\"\"\n        try:\n            parsed = ast.literal_eval(text.strip())\n            if isinstance(parsed, dict):\n                return json.dumps(parsed, ensure_ascii=False)\n        except:\n            pass\n        return text\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if the text contains a Llama 3.2 format tool call.\"\"\"\n        # depending on the prompt format the Llama model may or may not\n        # prefix the output with the <|python_tag|> token\n        return \"<|python_tag|>\" in text or text.startswith(\"{\")\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"Parse function calls from text, handling multiple JSON objects.\"\"\"\n        if \"<|python_tag|>\" not in text and not 
text.startswith(\"{\"):\n            return StreamingParseResult(normal_text=text, calls=[])\n\n        if \"<|python_tag|>\" in text:\n            normal_text, action_text = text.split(\"<|python_tag|>\", maxsplit=1)\n        else:\n            normal_text, action_text = \"\", text\n\n        decoder = json.JSONDecoder()\n        idx = 0\n        safe_idx = idx  # the index of the last valid JSON object\n        all_actions = []\n        action_text_len = len(action_text)\n        while idx < action_text_len:\n            try:\n                obj, end = decoder.raw_decode(action_text[idx:])\n                all_actions.append(obj)\n                idx += end + len(self.tool_call_separator)\n                safe_idx = idx\n            except json.JSONDecodeError:\n                # Try Python dict conversion as fallback\n                try:\n                    dict_end = idx\n                    brace_count = 0\n                    for i in range(idx, action_text_len):\n                        if action_text[i] == \"{\":\n                            brace_count += 1\n                        elif action_text[i] == \"}\":\n                            brace_count -= 1\n                            if brace_count == 0:\n                                dict_end = i + 1\n                                break\n\n                    if dict_end > idx:\n                        potential_dict = action_text[idx:dict_end]\n                        json_version = self._convert_python_dict_to_json(potential_dict)\n                        if json_version != potential_dict:\n                            obj, _ = decoder.raw_decode(json_version)\n                            all_actions.append(obj)\n                            idx = dict_end + len(self.tool_call_separator)\n                            safe_idx = idx\n                            continue\n                except:\n                    pass\n\n                next_obj_start = action_text.find('{\"name\":', idx + 1)\n    
            if next_obj_start == -1:\n                    break\n                idx = next_obj_start\n\n        # Only process if we found valid JSON objects\n        calls = self.parse_base_json(all_actions, tools) if all_actions else []\n        # Use safe_idx to avoid idx containing the last part of an invalid JSON object\n        trailing_text = (\n            action_text[safe_idx:].strip() if safe_idx < action_text_len else \"\"\n        )\n        return StreamingParseResult(\n            normal_text=normal_text + trailing_text, calls=calls\n        )\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"Override to handle Python dict format in streaming.\"\"\"\n        # First try with converted Python dict\n        self._buffer += new_text\n        converted_buffer = self._buffer\n\n        # Convert Python dict syntax to JSON\n        converted_buffer = re.sub(r\"'([^']*)':\", r'\"\\1\":', converted_buffer)\n        converted_buffer = re.sub(r\":\\s*'([^']*)'\", r': \"\\1\"', converted_buffer)\n\n        # Temporarily replace buffer for parsing\n        original_buffer = self._buffer\n        self._buffer = converted_buffer\n\n        try:\n            result = super().parse_streaming_increment(\"\", tools)\n            return result\n        except Exception:\n            # Fall back to the original buffer, which already includes new_text\n            self._buffer = original_buffer\n            return super().parse_streaming_increment(\"\", tools)\n\n    def structure_info(self) -> _GetInfoFunc:\n        return lambda name: StructureInfo(\n            begin='<|python_tag|>{\"name\":\"' + name + '\", \"arguments\":',\n            end=\"}\",\n            trigger=\"<|python_tag|>\",\n        )\n"
  },
  {
    "path": "python/sglang/srt/function_call/mimo_detector.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\nimport ast\nimport html\nimport json\nimport logging\nimport re\nfrom typing import Any, Dict, List\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.environ import envs\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import StreamingParseResult, _GetInfoFunc\n\nlogger = logging.getLogger(__name__)\n\n\ndef _get_param_type(func_name: str, param_name: str, tools: List[Tool]) -> str:\n    \"\"\"Get parameter type from tool schema.\"\"\"\n    for tool in tools:\n        if tool.function.name == func_name:\n            props = tool.function.parameters.get(\"properties\", {})\n            if param_name in props:\n                return props[param_name].get(\"type\", \"string\")\n    return \"string\"\n\n\ndef _convert_param_value(\n    param_value: str, param_name: str, func_name: str, tools: List[Tool]\n) -> Any:\n    \"\"\"\n    Convert parameter value based on its type in the schema.\n    Adapted from vllm-project/vllm (vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py)\n    \"\"\"\n    param_value = html.unescape(param_value)\n\n    # Handle null value for any type\n    if param_value.lower() == \"null\":\n        return None\n\n    param_type = _get_param_type(func_name, 
param_name, tools)\n\n    if param_type in [\"string\", \"str\", \"text\", \"varchar\", \"char\", \"enum\"]:\n        return param_value\n    elif (\n        param_type.startswith(\"int\")\n        or param_type.startswith(\"integer\")\n        or param_type.startswith(\"uint\")\n        or param_type.startswith(\"long\")\n        or param_type.startswith(\"short\")\n        or param_type.startswith(\"unsigned\")\n    ):\n        try:\n            return int(param_value)\n        except (ValueError, TypeError):\n            logger.warning(\n                \"Parsed value '%s' of parameter '%s' is not an \"\n                \"integer in tool '%s', degenerating to string.\",\n                param_value,\n                param_name,\n                func_name,\n            )\n            return param_value\n    elif param_type.startswith(\"num\") or param_type.startswith(\"float\"):\n        try:\n            float_param_value = float(param_value)\n            return (\n                float_param_value\n                if float_param_value - int(float_param_value) != 0\n                else int(float_param_value)\n            )\n        except (ValueError, TypeError):\n            logger.warning(\n                \"Parsed value '%s' of parameter '%s' is not a float \"\n                \"in tool '%s', degenerating to string.\",\n                param_value,\n                param_name,\n                func_name,\n            )\n            return param_value\n    elif param_type in [\"boolean\", \"bool\", \"binary\"]:\n        param_value = param_value.lower()\n        if param_value not in [\"true\", \"false\"]:\n            logger.warning(\n                \"Parsed value '%s' of parameter '%s' is not a boolean \"\n                \"(`true` or `false`) in tool '%s', degenerating to \"\n                \"false.\",\n                param_value,\n                param_name,\n                func_name,\n            )\n        return param_value == \"true\"\n    else:\n 
       if (\n            param_type in [\"object\", \"array\", \"arr\"]\n            or param_type.startswith(\"dict\")\n            or param_type.startswith(\"list\")\n        ):\n            try:\n                param_value = json.loads(param_value)\n                return param_value\n            except (json.JSONDecodeError, TypeError, ValueError):\n                logger.warning(\n                    \"Parsed value '%s' of parameter '%s' cannot be \"\n                    \"parsed with json.loads in tool '%s', will try \"\n                    \"other methods to parse it.\",\n                    param_value,\n                    param_name,\n                    func_name,\n                )\n        try:\n            param_value = ast.literal_eval(param_value)  # safer\n        except (ValueError, SyntaxError, TypeError):\n            logger.warning(\n                \"Parsed value '%s' of parameter '%s' cannot be \"\n                \"converted via Python `ast.literal_eval()` in tool \"\n                \"'%s', degenerating to string.\",\n                param_value,\n                param_name,\n                func_name,\n            )\n        return param_value\n\n\nclass MiMoDetector(BaseFormatDetector):\n    \"\"\"\n    Detector for MiMo function call format.\n\n    Format:\n        <tool_call>\n        <function=execute_bash>\n        <parameter=command>pwd && ls</parameter>\n        </function>\n        </tool_call>\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.bot_token = \"<tool_call>\"\n        self.eot_token = \"</tool_call>\"\n        self.tool_call_regex = re.compile(r\"<tool_call>(.*?)</tool_call>\", re.DOTALL)\n        self.func_regex = re.compile(r\"<function=([^>]+)>(.*?)</function>\", re.DOTALL)\n        self.param_regex = re.compile(\n            r\"<parameter=([^>]+)>(.*?)</parameter>\", re.DOTALL\n        )\n\n    def has_tool_call(self, text: str) -> bool:\n        return self.bot_token in text\n\n    
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"Parse complete text for tool calls.\"\"\"\n        idx = text.find(self.bot_token)\n        if idx == -1:\n            return StreamingParseResult(normal_text=text, calls=[])\n\n        normal_text = text[:idx]\n        tool_indices = self._get_tool_indices(tools)\n\n        calls = []\n        last_end = idx\n\n        for match in self.tool_call_regex.finditer(text):\n            tool_call_body = match.group(1)\n\n            parsed = self._parse_tool_call(tool_call_body, tools)\n\n            if parsed:\n                func_name = parsed.get(\"name\")\n                if func_name not in tool_indices:\n                    # Unknown function\n                    logger.warning(f\"Unknown function: {func_name}\")\n                    if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():\n                        # Return tool call block as normal text\n                        normal_text += text[last_end : match.end()]\n                        last_end = match.end()\n                        continue\n                calls.extend(self.parse_base_json(parsed, tools))\n\n            last_end = match.end()\n\n        return StreamingParseResult(normal_text=normal_text, calls=calls)\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming parsing: buffer until complete tool call block.\n        \"\"\"\n        self._buffer += new_text\n        current_text = self._buffer\n\n        start = current_text.find(self.bot_token)\n        if start == -1:\n            if self.current_tool_id > 0:\n                # Already processing tool calls, keep buffering\n                # (more tool calls might come, don't discard text yet)\n                return StreamingParseResult(normal_text=\"\")\n            else:\n                # No tool calls seen yet, return as normal text\n                
self._buffer = \"\"\n                return StreamingParseResult(normal_text=current_text)\n\n        # Find end token AFTER the start token\n        end = current_text.find(self.eot_token, start)\n        if end == -1:\n            # Incomplete tool call, return text before start and keep buffering\n            normal_text = current_text[:start]\n            self._buffer = current_text[start:]\n            return StreamingParseResult(normal_text=normal_text)\n\n        # Parse the complete tool call block\n        result = self.detect_and_parse(current_text[: end + len(self.eot_token)], tools)\n\n        if result.calls:\n            # Valid tool call - initialize tracking if first one\n            if self.current_tool_id == -1:\n                self.current_tool_id = 0\n                self.prev_tool_call_arr = []\n                self.streamed_args_for_tool = [\"\"]\n\n            while len(self.prev_tool_call_arr) <= self.current_tool_id:\n                self.prev_tool_call_arr.append({})\n            while len(self.streamed_args_for_tool) <= self.current_tool_id:\n                self.streamed_args_for_tool.append(\"\")\n\n            call = result.calls[0]\n            self.prev_tool_call_arr[self.current_tool_id] = {\n                \"name\": call.name,\n                \"arguments\": json.loads(call.parameters) if call.parameters else {},\n            }\n            self.streamed_args_for_tool[self.current_tool_id] = call.parameters\n            call.tool_index = self.current_tool_id\n            self.current_tool_id += 1\n\n        self._buffer = current_text[end + len(self.eot_token) :]\n        return result\n\n    def _parse_tool_call(\n        self, tool_call_body: str, tools: List[Tool]\n    ) -> Dict[str, Any]:\n        \"\"\"\n        Parse content inside <tool_call>...</tool_call>.\n\n        Structure:\n            tool_call_body contains: <function=name>...params...</function>\n        \"\"\"\n        # Match complete 
<function=name>body</function> block\n        func_match = self.func_regex.search(tool_call_body)\n        if not func_match:\n            return None\n\n        func_name = func_match.group(1).strip()\n        func_body = func_match.group(2)\n\n        params = {}\n        for param_match in self.param_regex.finditer(func_body):\n            param_name = param_match.group(1).strip()\n            param_value = param_match.group(2)\n            params[param_name] = _convert_param_value(\n                param_value, param_name, func_name, tools\n            )\n\n        return {\"name\": func_name, \"parameters\": params}\n\n    def supports_structural_tag(self) -> bool:\n        return False\n\n    def structure_info(self) -> _GetInfoFunc:\n        raise NotImplementedError\n"
  },
  {
    "path": "python/sglang/srt/function_call/minimax_m2.py",
    "content": "import json\nimport logging\nimport re\nfrom typing import Any, Dict, List, Tuple\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    ToolCallItem,\n    _GetInfoFunc,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass MinimaxM2Detector(BaseFormatDetector):\n    \"\"\"\n    Detector for MiniMax M2 models.\n    Assumes function call format:\n        <minimax:tool_call>\n        <invoke name=\"func1\">\n        <parameter name=\"param1\">value1</parameter>\n        <parameter name=\"param2\">value2</parameter>\n        </invoke>\n        </minimax:tool_call>\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.tool_call_start_token: str = \"<minimax:tool_call>\"\n        self.tool_call_end_token: str = \"</minimax:tool_call>\"\n        self.tool_call_prefix: str = '<invoke name=\"'\n        self.tool_call_function_end_token: str = \"</invoke>\"\n        self.tool_call_regex = re.compile(\n            r\"<minimax:tool_call>(.*?)</minimax:tool_call>|<minimax:tool_call>(.*?)$\",\n            re.DOTALL,\n        )\n        self.tool_call_function_regex = re.compile(\n            r\"<invoke name=\\\"(.*?)</invoke>|<invoke name=\\\"(.*)$\", re.DOTALL\n        )\n        self.tool_call_parameter_regex = re.compile(\n            r\"<parameter name=\\\"(.*?)</parameter>|<parameter name=\\\"(.*?)$\", re.DOTALL\n        )\n        self._buf: str = \"\"\n\n        # Streaming state variables\n        self._current_function_name: str = \"\"\n        self._current_parameters: Dict[str, Any] = {}\n        self._streamed_parameters: Dict[str, str] = (\n            {}\n        )  # Track what parameter content we've streamed\n        self._in_tool_call: bool = False\n        self._function_name_sent: bool = False\n\n    def has_tool_call(self, text: str) -> bool:\n        
return self.tool_call_start_token in text\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        normal, calls = self._extract(text, tools)\n        return StreamingParseResult(normal_text=normal, calls=calls)\n\n    def _convert_param_value(self, value: str, param_type: str) -> Any:\n        \"\"\"Convert parameter value to the correct type (legacy single-type version).\"\"\"\n        return self._convert_param_value_with_types(value, [param_type])\n\n    def _extract_types_from_schema(self, schema: Any) -> list[str]:\n        \"\"\"\n        Extract all possible types from a JSON schema definition.\n        Handles anyOf, oneOf, allOf, type arrays, and enum fields.\n\n        Args:\n            schema: The JSON schema definition for a parameter\n\n        Returns:\n            List of type strings (e.g., [\"string\", \"integer\", \"null\"])\n        \"\"\"\n        if schema is None:\n            return [\"string\"]\n\n        if not isinstance(schema, dict):\n            return [\"string\"]\n\n        types: set[str] = set()\n\n        # Handle direct \"type\" field\n        if \"type\" in schema:\n            type_value = schema[\"type\"]\n            if isinstance(type_value, str):\n                types.add(type_value)\n            elif isinstance(type_value, list):\n                for t in type_value:\n                    if isinstance(t, str):\n                        types.add(t)\n\n        # Handle enum - infer types from enum values\n        if \"enum\" in schema and isinstance(schema[\"enum\"], list) and schema[\"enum\"]:\n            for value in schema[\"enum\"]:\n                if value is None:\n                    types.add(\"null\")\n                elif isinstance(value, bool):\n                    types.add(\"boolean\")\n                elif isinstance(value, int):\n                    types.add(\"integer\")\n                elif isinstance(value, float):\n                    types.add(\"number\")\n   
             elif isinstance(value, str):\n                    types.add(\"string\")\n                elif isinstance(value, list):\n                    types.add(\"array\")\n                elif isinstance(value, dict):\n                    types.add(\"object\")\n\n        # Handle anyOf, oneOf, allOf - recursively extract types\n        for choice_field in (\"anyOf\", \"oneOf\", \"allOf\"):\n            if choice_field in schema and isinstance(schema[choice_field], list):\n                for choice in schema[choice_field]:\n                    extracted = self._extract_types_from_schema(choice)\n                    types.update(extracted)\n\n        # If no types found, default to string\n        if not types:\n            return [\"string\"]\n\n        return list(types)\n\n    def _convert_param_value_with_types(\n        self, value: str, param_types: list[str]\n    ) -> Any:\n        \"\"\"\n        Convert parameter value to the correct type based on a list of possible types.\n        Tries the allowed types in a fixed priority order until one succeeds.\n\n        Args:\n            value: The string value to convert\n            param_types: List of possible type strings\n\n        Returns:\n            The converted value\n        \"\"\"\n        if value.lower() == \"null\":\n            return None\n\n        # Normalize types\n        normalized_types = [t.lower() for t in param_types]\n\n        # Map null-like literals to None only when the schema allows null\n        if \"null\" in normalized_types and value.lower() in (\"null\", \"none\", \"nil\"):\n            return None\n\n        # Try each type in order of preference (most specific first, string as fallback)\n        # Priority: integer > number > boolean > object > array > string\n        type_priority = [\n            \"integer\",\n            \"int\",\n            \"number\",\n            \"float\",\n            \"boolean\",\n            \"bool\",\n            \"object\",\n            \"array\",\n            \"string\",\n            \"str\",\n  
           \"text\",\n        ]\n\n        for param_type in type_priority:\n            if param_type not in normalized_types:\n                continue\n\n            if param_type in [\"string\", \"str\", \"text\"]:\n                return value\n            elif param_type in [\"integer\", \"int\"]:\n                try:\n                    return int(value)\n                except (ValueError, TypeError):\n                    continue\n            elif param_type in [\"number\", \"float\"]:\n                try:\n                    val = float(value)\n                    # int() raises OverflowError for inf and ValueError for nan\n                    return val if val != int(val) else int(val)\n                except (ValueError, TypeError, OverflowError):\n                    continue\n            elif param_type in [\"boolean\", \"bool\"]:\n                lower_val = value.lower().strip()\n                if lower_val in [\"true\", \"1\", \"yes\", \"on\"]:\n                    return True\n                elif lower_val in [\"false\", \"0\", \"no\", \"off\"]:\n                    return False\n                continue\n            elif param_type in [\"object\", \"array\"]:\n                try:\n                    return json.loads(value)\n                except json.JSONDecodeError:\n                    continue\n\n        # Fallback: try JSON parse, then return as string\n        try:\n            return json.loads(value)\n        except json.JSONDecodeError:\n            return value\n\n    def _get_param_types_from_config(\n        self, param_name: str, param_config: dict\n    ) -> list[str]:\n        \"\"\"\n        Get parameter types from parameter configuration.\n        Handles anyOf, oneOf, allOf, and direct type definitions.\n\n        Args:\n            param_name: The name of the parameter\n            param_config: The properties dict from the tool schema\n\n        Returns:\n            List of type strings\n        \"\"\"\n        if param_name not in param_config:\n            return [\"string\"]\n\n        param_schema = 
param_config[param_name]\n        if not isinstance(param_schema, dict):\n            return [\"string\"]\n\n        return self._extract_types_from_schema(param_schema)\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        self._buf += new_text\n        normal = \"\"\n        calls: List[ToolCallItem] = []\n\n        # Build tool indices for validation\n        if not hasattr(self, \"_tool_indices\"):\n            self._tool_indices = self._get_tool_indices(tools)\n\n        while True:\n            # If we're not in a tool call and don't see a start token, return normal text\n            if not self._in_tool_call and self.tool_call_start_token not in self._buf:\n                normal += self._buf\n                self._buf = \"\"\n                break\n\n            # Look for tool call start\n            if not self._in_tool_call:\n                s = self._buf.find(self.tool_call_start_token)\n                if s == -1:\n                    normal += self._buf\n                    self._buf = \"\"\n                    break\n\n                normal += self._buf[:s]\n                self._buf = self._buf[s:]\n\n                self._in_tool_call = True\n                self._function_name_sent = False\n                self._current_function_name = \"\"\n                self._current_parameters = {}\n                self._streamed_parameters = {}\n\n                # Remove the start token\n                self._buf = self._buf[len(self.tool_call_start_token) :]\n                continue\n\n            # We're in a tool call, try to parse function name if not sent yet\n            if not self._function_name_sent:\n                # Look for function name pattern: <invoke name=name>\n                function_match = re.search(r\"<invoke name=\\\"([^>]+)\\\">\", self._buf)\n                if function_match:\n                    function_name = function_match.group(1).strip()\n\n           
         # Validate function name\n                    if function_name in self._tool_indices:\n                        self._current_function_name = function_name\n                        self._function_name_sent = True\n\n                        # Initialize tool call tracking\n                        if self.current_tool_id == -1:\n                            self.current_tool_id = 0\n\n                        # Ensure tracking arrays are large enough\n                        while len(self.prev_tool_call_arr) <= self.current_tool_id:\n                            self.prev_tool_call_arr.append({})\n                        while len(self.streamed_args_for_tool) <= self.current_tool_id:\n                            self.streamed_args_for_tool.append(\"\")\n\n                        # Store tool call info\n                        self.prev_tool_call_arr[self.current_tool_id] = {\n                            \"name\": function_name,\n                            \"arguments\": {},\n                        }\n\n                        # Send tool name with empty parameters\n                        calls.append(\n                            ToolCallItem(\n                                tool_index=self.current_tool_id,\n                                name=function_name,\n                                parameters=\"\",\n                            )\n                        )\n\n                        # Remove the processed function declaration\n                        self._buf = self._buf[function_match.end() :]\n                        continue\n                    else:\n                        # Invalid function name, reset state\n                        logger.warning(f\"Invalid function name: {function_name}\")\n                        self._reset_streaming_state()\n                        normal += self._buf\n                        self._buf = \"\"\n                        break\n                else:\n                    # Function name not complete yet, 
wait for more text\n                    break\n\n            # Parse parameters incrementally\n            if self._function_name_sent:\n                # Process parameters and get any calls to emit\n                parameter_calls = self._parse_and_stream_parameters(self._buf, tools)\n                calls.extend(parameter_calls)\n\n                # Check if tool call is complete\n                if self.tool_call_function_end_token in self._buf:\n                    end_pos = self._buf.find(self.tool_call_function_end_token)\n\n                    # Add closing brace to complete the JSON object\n                    current_streamed = self.streamed_args_for_tool[self.current_tool_id]\n                    if current_streamed:\n                        # Count opening and closing braces to check if JSON is complete\n                        open_braces = current_streamed.count(\"{\")\n                        close_braces = current_streamed.count(\"}\")\n                        if open_braces > close_braces:\n                            calls.append(\n                                ToolCallItem(\n                                    tool_index=self.current_tool_id,\n                                    name=None,\n                                    parameters=\"}\",\n                                )\n                            )\n                            self.streamed_args_for_tool[self.current_tool_id] = (\n                                current_streamed + \"}\"\n                            )\n\n                    # Complete the tool call\n                    self._buf = self._buf[\n                        end_pos + len(self.tool_call_function_end_token) :\n                    ]\n                    self._reset_streaming_state(True)\n                    self.current_tool_id += 1\n                    continue\n                else:\n                    # Tool call not complete yet, wait for more text\n                    break\n\n        return 
StreamingParseResult(normal_text=normal, calls=calls)\n\n    def _parse_and_stream_parameters(\n        self, text_to_parse: str, tools: List[Tool]\n    ) -> List[ToolCallItem]:\n        \"\"\"\n        Parse complete parameter blocks from text and return any tool call items to emit.\n\n        This method:\n        1. Finds all complete <parameter> blocks\n        2. Parses them into a dictionary\n        3. Compares with current parameters and generates diff if needed\n        4. Updates internal state\n\n        Args:\n            text_to_parse: The text to search for parameter blocks\n\n        Returns:\n            List of ToolCallItem objects to emit (may be empty)\n        \"\"\"\n        calls: List[ToolCallItem] = []\n\n        # Find all complete parameter patterns\n        param_matches = list(\n            re.finditer(\n                r\"<parameter name=\\\"([^>]+)\\\">(.*?)</parameter>\",\n                text_to_parse,\n                re.DOTALL,\n            )\n        )\n\n        # Build new parameters dictionary\n        new_params = {}\n        for match in param_matches:\n            param_name = match.group(1).strip()\n            param_value = match.group(2)\n            new_params[param_name] = self._parse_parameter(\n                self._current_function_name, param_name, param_value, tools\n            )\n\n        # Calculate parameter diff to stream with proper incremental JSON building\n        if new_params != self._current_parameters:\n            previous_args_json = self.streamed_args_for_tool[self.current_tool_id]\n\n            # Build incremental JSON properly\n            if not self._current_parameters:\n                # First parameter(s) - start JSON object but don't close it yet\n                items = []\n                for key, value in new_params.items():\n                    items.append(\n                        f\"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}\"\n                    
)\n                json_fragment = \"{\" + \", \".join(items)\n\n                calls.append(\n                    ToolCallItem(\n                        tool_index=self.current_tool_id,\n                        name=None,\n                        parameters=json_fragment,\n                    )\n                )\n                self.streamed_args_for_tool[self.current_tool_id] = json_fragment\n\n            else:\n                # Additional parameters - add them incrementally\n                new_keys = set(new_params.keys()) - set(self._current_parameters.keys())\n                if new_keys:\n                    # Build the continuation part (no closing brace yet)\n                    continuation_parts = []\n                    for key in new_keys:\n                        value = new_params[key]\n                        continuation_parts.append(\n                            f\"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}\"\n                        )\n\n                    json_fragment = \", \" + \", \".join(continuation_parts)\n\n                    calls.append(\n                        ToolCallItem(\n                            tool_index=self.current_tool_id,\n                            name=None,\n                            parameters=json_fragment,\n                        )\n                    )\n                    self.streamed_args_for_tool[self.current_tool_id] = (\n                        previous_args_json + json_fragment\n                    )\n\n            # Update current state\n            self._current_parameters = new_params\n            self.prev_tool_call_arr[self.current_tool_id][\"arguments\"] = new_params\n\n        return calls\n\n    def _reset_streaming_state(self, still_in_tool_call: bool = False):\n        \"\"\"Reset streaming state for the next tool call\"\"\"\n        self._in_tool_call = still_in_tool_call\n        self._function_name_sent = False\n        self._current_function_name 
= \"\"\n        self._current_parameters = {}\n        self._streamed_parameters = {}\n        self.current_tool_name_sent = False\n\n    def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:\n        normal_parts: List[str] = []\n        calls: List[ToolCallItem] = []\n        cursor = 0\n        while True:\n            s = text.find(self.tool_call_start_token, cursor)\n            if s == -1:\n                normal_parts.append(text[cursor:])\n                break\n            normal_parts.append(text[cursor:s])\n            e = text.find(self.tool_call_end_token, s)\n            if e == -1:\n                normal_parts.append(text[s:])\n                break\n            block = text[s : e + len(self.tool_call_end_token)]\n            cursor = e + len(self.tool_call_end_token)\n            calls.extend(self._parse_block(block, tools))\n        return \"\".join(normal_parts), calls\n\n    def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:\n        res: List[ToolCallItem] = []\n        for m in self.tool_call_function_regex.findall(block):\n            txt = m[0] if m[0] else m[1]\n            if '\">' not in txt:\n                continue\n            idx = txt.index('\">')\n            fname = txt[:idx].strip()\n            body = txt[idx + 2 :]\n            params: Dict[str, Any] = {}\n            for pm in self.tool_call_parameter_regex.findall(body):\n                ptxt = pm[0] if pm[0] else pm[1]\n                if '\">' not in ptxt:\n                    continue\n                pidx = ptxt.index('\">')\n                pname = ptxt[:pidx].strip()\n                pval = ptxt[pidx + 2 :].lstrip(\"\\n\").rstrip(\"\\n\")\n                params[pname] = self._parse_parameter(fname, pname, pval, tools)\n            raw = {\"name\": fname, \"arguments\": params}\n            try:\n                # TODO: fix idx in function call, the index for a function\n                # call will always be -1 in 
parse_base_json\n                res.extend(self.parse_base_json(raw, tools))\n            except Exception:\n                logger.warning(\"invalid tool call for %s dropped\", fname)\n        return res\n\n    def _parse_parameter(\n        self, fname: str, pname: str, pval: str, tools: List[Tool]\n    ) -> Any:\n        param_config = {}\n        for tool in tools:\n            if tool.function.name == fname and tool.function.parameters is not None:\n                parameters = tool.function.parameters\n                if isinstance(parameters, dict) and \"properties\" in parameters:\n                    param_config = parameters[\"properties\"]\n                    break\n\n        param_type = self._get_param_types_from_config(pname, param_config)\n        return self._convert_param_value_with_types(pval, param_type)\n\n    def supports_structural_tag(self) -> bool:\n        return False\n\n    def structure_info(self) -> _GetInfoFunc:\n        raise NotImplementedError\n"
  },
  {
    "path": "python/sglang/srt/function_call/mistral_detector.py",
    "content": "import json\nimport logging\nfrom typing import Any, List, Optional, Tuple\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    StructureInfo,\n    ToolCallItem,\n    _GetInfoFunc,\n)\nfrom sglang.srt.function_call.utils import _is_complete_json\n\nlogger = logging.getLogger(__name__)\n\n\nclass MistralDetector(BaseFormatDetector):\n    \"\"\"\n    Detector for Mistral tool/function call formats.\n\n    Supported formats:\n\n    1) JSON-array format:\n       `[TOOL_CALLS] [{\"name\": \"...\", \"arguments\": {...}}, ...]`\n\n    2) Compact format (common in newer templates/models, especially in streaming):\n       `[TOOL_CALLS]tool_name[ARGS]{...}`\n       (also tolerates missing delimiters like `]` after `[TOOL_CALLS` and/or `[ARGS]` while streaming)\n\n    Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3?chat_template=default\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"Initialize tokens and streaming state.\"\"\"\n        super().__init__()\n        # Canonical Mistral prefix for JSON-array tool calls.\n        self.bot_token = \"[TOOL_CALLS] [\"\n        # Common marker shared by both JSON-array and compact formats.\n        self._tool_calls_marker = \"[TOOL_CALLS\"\n        self.eot_token = \"]\"\n        self.tool_call_separator = \", \"\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Return True if the text contains either supported tool-call marker.\"\"\"\n        return self._tool_calls_marker in text\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"\n        One-time parsing: Detects and parses tool calls in the provided text.\n\n        :param text: The complete text to parse.\n        :param tools: List of available tools.\n        :return: ParseResult indicating success or 
failure, consumed text, leftover text, and parsed calls.\n        \"\"\"\n        marker_idx = text.find(self._tool_calls_marker)\n        if marker_idx == -1:\n            return StreamingParseResult(normal_text=text, calls=[])\n\n        normal_text = text[:marker_idx].strip()\n        tool_part = text[marker_idx:]\n\n        # Canonical: `[TOOL_CALLS] [{...}, ...]`\n        if self.bot_token in tool_part:\n            json_array_str = self._extract_json_array(tool_part)\n            if not json_array_str:\n                return StreamingParseResult(normal_text=normal_text, calls=[])\n\n            calls: list = []\n            try:\n                function_call_arr = json.loads(json_array_str)\n                if not isinstance(function_call_arr, list):\n                    function_call_arr = [function_call_arr]\n                calls = self.parse_base_json(function_call_arr, tools)\n            except json.JSONDecodeError as e:\n                logger.warning(\n                    f\"Failed to parse JSON part: {json_array_str}, JSON parse error: {str(e)}\"\n                )\n            json_pos = tool_part.find(json_array_str) if json_array_str else -1\n            trailing_text = (\n                tool_part[json_pos + len(json_array_str) :].strip()\n                if json_pos != -1\n                else \"\"\n            )\n            combined_normal = (\n                (normal_text + \" \" + trailing_text).strip()\n                if trailing_text\n                else normal_text\n            )\n            return StreamingParseResult(normal_text=combined_normal, calls=calls)\n\n        # Compact: `[TOOL_CALLS]tool_name[ARGS]{...}`\n        # Loop to extract all consecutive compact tool calls.\n        all_calls: list = []\n        remaining = tool_part\n        while remaining:\n            parsed = self._try_parse_compact_args_format(remaining)\n            if not parsed:\n                break\n            func_name, args_obj, consumed = parsed\n 
           new_calls = self.parse_base_json(\n                {\"name\": func_name, \"arguments\": args_obj}, tools\n            )\n            all_calls.extend(new_calls)\n            remaining = remaining[consumed:].strip()\n\n        if not all_calls:\n            return StreamingParseResult(normal_text=normal_text, calls=[])\n\n        combined_normal = (\n            (normal_text + \" \" + remaining).strip() if remaining else normal_text\n        )\n        return StreamingParseResult(normal_text=combined_normal, calls=all_calls)\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming parsing for both JSON-array and compact formats.\n\n        For the compact format, this buffers until the JSON arguments payload is complete,\n        then emits two items: tool name (with empty parameters) and a full arguments JSON\n        chunk (OpenAI streaming semantics).\n        \"\"\"\n        self._buffer += new_text\n        current_text = self._buffer\n\n        # No marker: either flush as normal text or keep buffering a partial marker.\n        if self._tool_calls_marker not in current_text:\n            if not self._ends_with_partial_token(self._buffer, self._tool_calls_marker):\n                normal_text = self._buffer\n                self._buffer = \"\"\n                if self.eot_token in normal_text:\n                    normal_text = normal_text.replace(self.eot_token, \"\")\n                return StreamingParseResult(normal_text=normal_text)\n            return StreamingParseResult()\n\n        # If there's leading normal text before the marker, stream it out first.\n        marker_pos = current_text.find(self._tool_calls_marker)\n        if marker_pos > 0:\n            normal_text = current_text[:marker_pos]\n            self._buffer = current_text[marker_pos:]\n            return StreamingParseResult(normal_text=normal_text)\n\n        # Build tool indices 
if not already built.\n        if not hasattr(self, \"_tool_indices\"):\n            self._tool_indices = self._get_tool_indices(tools)\n\n        # Try compact first; JSON-array requires `] [` and often arrives later in streaming.\n        compact = self._try_parse_compact_args_format(current_text)\n        if compact:\n            func_name, args_obj, consumed = compact\n            if func_name not in self._tool_indices:\n                # Unknown tool: treat as normal text and reset state.\n                normal_text = self._buffer\n                self._buffer = \"\"\n                return StreamingParseResult(normal_text=normal_text)\n\n            # Initialize state if this is the first tool call.\n            if self.current_tool_id == -1:\n                self.current_tool_id = 0\n                self.prev_tool_call_arr = []\n                self.streamed_args_for_tool = []\n\n            args_json = json.dumps(args_obj, ensure_ascii=False)\n            tool_id = self.current_tool_id\n\n            # Ensure arrays are large enough.\n            while len(self.prev_tool_call_arr) <= tool_id:\n                self.prev_tool_call_arr.append({})\n            while len(self.streamed_args_for_tool) <= tool_id:\n                self.streamed_args_for_tool.append(\"\")\n\n            self.prev_tool_call_arr[tool_id] = {\n                \"name\": func_name,\n                \"arguments\": args_obj,\n            }\n            self.streamed_args_for_tool[tool_id] = args_json\n\n            calls: List[ToolCallItem] = [\n                ToolCallItem(tool_index=tool_id, name=func_name, parameters=\"\"),\n                ToolCallItem(tool_index=tool_id, name=None, parameters=args_json),\n            ]\n\n            # Consume parsed content from buffer.\n            self._buffer = current_text[consumed:]\n            self.current_tool_id += 1\n            self.current_tool_name_sent = False\n            return StreamingParseResult(normal_text=\"\", calls=calls)\n\n  
      # Canonical format delegates to the BaseFormatDetector JSON streaming logic.\n        if self.bot_token in current_text:\n            return super().parse_streaming_increment(new_text=\"\", tools=tools)\n\n        # Otherwise, keep buffering.\n        return StreamingParseResult()\n\n    def _try_parse_compact_args_format(\n        self, text: str\n    ) -> Optional[Tuple[str, Any, int]]:\n        \"\"\"\n        Parse the compact tool call format:\n            `[TOOL_CALLS]tool_name[ARGS]{...}`\n\n        Tolerates common streaming variants where delimiters are missing:\n            `[TOOL_CALLStool_name[ARGS{...}`\n\n        Returns:\n            (tool_name, arguments_obj, consumed_end_index) if a complete JSON arguments\n            payload is present; otherwise None.\n        \"\"\"\n        start = text.find(self._tool_calls_marker)\n        if start == -1:\n            return None\n\n        i = start + len(self._tool_calls_marker)  # position after \"[TOOL_CALLS\"\n        if i < len(text) and text[i] == \"]\":\n            i += 1\n        while i < len(text) and text[i].isspace():\n            i += 1\n\n        args_marker = \"[ARGS\"\n        args_pos = text.find(args_marker, i)\n        if args_pos == -1:\n            return None\n\n        func_name = text[i:args_pos].strip()\n        if not func_name:\n            return None\n\n        j = args_pos + len(args_marker)\n        if j < len(text) and text[j] == \"]\":\n            j += 1\n        while j < len(text) and text[j].isspace():\n            j += 1\n\n        if j >= len(text) or text[j] not in \"{[\":\n            return None\n\n        json_str, end_idx = self._extract_json_value(text, j)\n        if not json_str:\n            return None\n        if not _is_complete_json(json_str):\n            return None\n\n        try:\n            args_obj = json.loads(json_str)\n        except json.JSONDecodeError:\n            return None\n\n        return func_name, args_obj, end_idx\n\n    def 
_extract_json_value(\n        self, text: str, json_start: int\n    ) -> Tuple[Optional[str], int]:\n        \"\"\"\n        Extract a JSON value (object or array) starting at json_start using bracket counting,\n        robust to nested braces/brackets inside strings.\n\n        Returns:\n            (json_str_or_None, end_index_exclusive)\n        \"\"\"\n        if json_start >= len(text) or text[json_start] not in \"{[\":\n            return None, json_start\n\n        opening = text[json_start]\n        closing = \"}\" if opening == \"{\" else \"]\"\n        depth = 0\n        in_string = False\n        escape_next = False\n\n        for k in range(json_start, len(text)):\n            ch = text[k]\n            if escape_next:\n                escape_next = False\n                continue\n            if ch == \"\\\\\":\n                escape_next = True\n                continue\n            if ch == '\"' and not escape_next:\n                in_string = not in_string\n                continue\n            if in_string:\n                continue\n            if ch == opening:\n                depth += 1\n            elif ch == closing:\n                depth -= 1\n                if depth == 0:\n                    return text[json_start : k + 1], k + 1\n\n        return None, json_start\n\n    def _extract_json_array(self, text: str) -> str:\n        \"\"\"\n        Extract the JSON array part using bracket counting to handle nested brackets.\n\n        :param text: The complete text containing [TOOL_CALLS] [...]\n        :return: The JSON array string or None if not found\n        \"\"\"\n        start_idx = text.find(self.bot_token)\n        if start_idx == -1:\n            return None\n\n        # Start from the opening bracket after [TOOL_CALLS]\n        json_start = (\n            start_idx + len(self.bot_token) - 1\n        )  # -1 to include the opening bracket\n        bracket_count = 0\n        in_string = False\n        escape_next = False\n\n       
 for i in range(json_start, len(text)):\n            char = text[i]\n\n            if escape_next:\n                escape_next = False\n                continue\n\n            if char == \"\\\\\":\n                escape_next = True\n                continue\n\n            if char == '\"' and not escape_next:\n                in_string = not in_string\n                continue\n\n            if not in_string:\n                if char == \"[\":\n                    bracket_count += 1\n                elif char == \"]\":\n                    bracket_count -= 1\n                    if bracket_count == 0:\n                        return text[json_start : i + 1]\n\n        return None\n\n    def structure_info(self) -> _GetInfoFunc:\n        return lambda name: StructureInfo(\n            begin='[TOOL_CALLS] [{\"name\":\"' + name + '\", \"arguments\":',\n            end=\"}]\",\n            trigger=\"[TOOL_CALLS]\",\n        )\n"
  },
  {
    "path": "python/sglang/srt/function_call/pythonic_detector.py",
    "content": "import ast\nimport json\nimport logging\nimport re\nfrom typing import List\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.environ import envs\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    ToolCallItem,\n    _GetInfoFunc,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass PythonicDetector(BaseFormatDetector):\n    \"\"\"\n    Detector for Llama-4 models with Pythonic tool call format.\n\n    The Pythonic format uses Python function call syntax within square brackets,\n    with arguments as Python literals rather than JSON.\n\n    Format Structure:\n    ```\n    [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]\n    ```\n\n    Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct?chat_template=default\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.tool_call_regex = re.compile(\n            r\"\\[([a-zA-Z]+\\w*\\(([a-zA-Z]+\\w*=.*,\\s*)*([a-zA-Z]+\\w*=.*\\s)?\\),\\s*)*([a-zA-Z]+\\w*\\(([a-zA-Z]+\\w*=.*,\\s*)*([a-zA-Z]+\\w*=.*\\s*)?\\)\\s*)+\\]\",\n            re.DOTALL,\n        )\n\n    @staticmethod\n    def _text_strip(text: str) -> str:\n        # Llama 4 model sometime will output <|python_start|> and <|python_end|> tokens\n        # remove those tokens\n        text = text.replace(\"<|python_start|>\", \"\")\n        text = text.replace(\"<|python_end|>\", \"\")\n        return text\n\n    def has_tool_call(self, text: str) -> bool:\n        return bool(self.tool_call_regex.search(self._text_strip(text.strip())))\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        # Try parsing the text as a Python list of function calls\n        text = text.strip()\n\n        # Remove unexpected <|python_start|> and <|python_end|> for llama4\n        text = self._text_strip(text)\n\n        match = 
self.tool_call_regex.search(text)\n        if match is None:\n            return StreamingParseResult(normal_text=text, calls=[])\n\n        # Extract the tool call part and any text before/after it\n        tool_call_start = match.start()\n        tool_call_end = match.end()\n\n        normal_text_before = text[:tool_call_start] if tool_call_start > 0 else \"\"\n        tool_call_text = text[tool_call_start:tool_call_end]\n        normal_text_after = text[tool_call_end:] if tool_call_end < len(text) else \"\"\n\n        # Combine normal text\n        normal_text = normal_text_before + normal_text_after\n\n        try:\n            module = ast.parse(tool_call_text)\n            parsed = getattr(module.body[0], \"value\", None)\n            if not (\n                isinstance(parsed, ast.List)\n                and all(isinstance(e, ast.Call) for e in parsed.elts)\n            ):\n                return StreamingParseResult(normal_text=normal_text, calls=[])\n\n            calls = []\n            tool_indices = self._get_tool_indices(tools)\n            for call_index, call in enumerate(parsed.elts):\n                if not isinstance(call.func, ast.Name):\n                    continue\n                function_name = call.func.id\n                # Validate that the function exists in the tools\n                if function_name not in tool_indices:\n                    logger.warning(\n                        f\"Model attempted to call undefined function: {function_name}\"\n                    )\n                    if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():\n                        continue  # Skip unknown tools (default legacy behavior)\n\n                arguments = {}\n                for keyword in call.keywords:\n                    arguments[keyword.arg] = self._get_parameter_value(keyword.value)\n                calls.append(\n                    ToolCallItem(\n                        tool_index=call_index,  # Use the call index in the response, not 
tool position\n                        name=function_name,\n                        parameters=json.dumps(arguments, ensure_ascii=False),\n                    )\n                )\n\n            return StreamingParseResult(normal_text=normal_text, calls=calls)\n        except Exception:\n            logger.exception(\"Error in pythonic tool call parsing.\")\n            return StreamingParseResult(normal_text=normal_text, calls=[])\n\n    def _find_matching_bracket(self, buffer: str, start: int) -> int:\n        \"\"\"\n        Find the matching closing bracket for the opening bracket at start position.\n        Properly handles nested brackets.\n\n        Args:\n            buffer: The text buffer to search in\n            start: Position of the opening bracket '['\n\n        Returns:\n            Position of the matching closing bracket ']', or -1 if not found\n        \"\"\"\n        bracket_count = 0\n        for i in range(start, len(buffer)):\n            if buffer[i] == \"[\":\n                bracket_count += 1\n            elif buffer[i] == \"]\":\n                bracket_count -= 1\n                if bracket_count == 0:\n                    return i\n        return -1  # No matching bracket found\n\n    def _strip_and_split_buffer(self, buffer: str) -> tuple[str, str]:\n        \"\"\"\n        Strip special tokens from buffer and split into safe_text and held_back_text.\n\n        Returns:\n            tuple of (safe_text_to_output, text_to_hold_in_buffer)\n        \"\"\"\n        # Check if original buffer ends with a partial token at the end\n        special_tokens = [\"<|python_start|>\", \"<|python_end|>\"]\n\n        for token in special_tokens:\n            partial_length = self._ends_with_partial_token(buffer, token)\n            if partial_length > 0:\n                # Split buffer: safe part + held back partial token\n                safe_text = buffer[:-partial_length]\n                held_back = buffer[-partial_length:]\n                # 
Strip complete special tokens from safe part only\n                safe_text = self._text_strip(safe_text)\n                return safe_text, held_back\n\n        # No partial tokens found, strip complete tokens from entire buffer\n        safe_text = self._text_strip(buffer)\n        return safe_text, \"\"\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming incremental parsing for pythonic tool calls.\n        Buffers input until a complete pythonic tool call (from [ to ]) is found,\n        then parses and emits any detected calls.\n        \"\"\"\n        self._buffer += new_text\n\n        # Strip special tokens from entire buffer and handle partial tokens\n        stripped_buffer, held_back = self._strip_and_split_buffer(self._buffer)\n\n        start = stripped_buffer.find(\"[\")\n\n        if start == -1:\n            # No tool call bracket found\n            self._buffer = held_back\n            return StreamingParseResult(normal_text=stripped_buffer)\n\n        normal_text = stripped_buffer[:start] if start > 0 else \"\"\n\n        end = self._find_matching_bracket(stripped_buffer, start)\n        if end != -1:\n            # Found complete tool call\n            call_text = stripped_buffer[start : end + 1]\n            result = self.detect_and_parse(call_text, tools)\n\n            # Update buffer with remaining text after tool call plus any held back text\n            remaining_text = stripped_buffer[end + 1 :] + held_back\n            self._buffer = remaining_text\n\n            # If we had normal text before the tool call, add it to the result\n            if normal_text:\n                result.normal_text = normal_text + (result.normal_text or \"\")\n\n            return result\n\n        # We have an opening bracket but no closing bracket yet\n        # Put back everything from the bracket onwards plus held back text\n        self._buffer = 
stripped_buffer[start:] + held_back\n\n        if normal_text:\n            return StreamingParseResult(normal_text=normal_text)\n\n        # Otherwise, we're still accumulating a potential tool call\n        return StreamingParseResult(normal_text=\"\")\n\n    def _get_parameter_value(self, val):\n        if isinstance(val, ast.Constant):\n            return val.value\n        elif isinstance(val, ast.Dict):\n            return {\n                k.value: self._get_parameter_value(v)\n                for k, v in zip(val.keys, val.values)\n            }\n        elif isinstance(val, ast.List):\n            return [self._get_parameter_value(v) for v in val.elts]\n        else:\n            raise ValueError(\"Tool call arguments must be literals\")\n\n    def supports_structural_tag(self) -> bool:\n        return False\n\n    def structure_info(self) -> _GetInfoFunc:\n        raise NotImplementedError\n"
  },
  {
    "path": "python/sglang/srt/function_call/qwen25_detector.py",
    "content": "import json\nimport logging\nimport re\nfrom typing import List\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    StructureInfo,\n    _GetInfoFunc,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass Qwen25Detector(BaseFormatDetector):\n    \"\"\"\n    Detector for Qwen 2.5 and Qwen 3 model function call format.\n\n    Format Structure:\n    ```\n    <tool_call>\\n{\"name\":\"func1\", \"arguments\":{...}}\\n</tool_call>\\n<tool_call>\\n{\"name\":\"func2\", \"arguments\":{...}}\\n</tool_call>\n    ```\n\n    Key Components:\n    - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call\n    - Function Call Object: JSON object with \"name\" and \"arguments\" fields\n\n    Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"\n        Initializes the detector with necessary state variables.\n        \"\"\"\n        super().__init__()\n        self.bot_token = \"<tool_call>\\n\"\n        self.eot_token = \"\\n</tool_call>\"\n        self.tool_call_separator = \"\\n\"\n        self._normal_text_buffer = \"\"  # Buffer for handling partial end tokens\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if the text contains a Qwen 2.5 format tool call.\"\"\"\n        return self.bot_token in text\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"\n        One-time parsing: Detects and parses tool calls in the provided text.\n\n        :param text: The complete text to parse.\n        :param tools: List of available tools.\n        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.\n        \"\"\"\n        idx = text.find(self.bot_token)\n        normal_text = 
text[:idx].strip() if idx != -1 else text\n        if self.bot_token not in text:\n            return StreamingParseResult(normal_text=normal_text, calls=[])\n\n        # Find all <tool_call>\\n...\\n</tool_call> blocks\n        pattern = rf\"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}\"\n        match_result_list = re.findall(pattern, text, re.DOTALL)\n        calls = []\n        for match_result in match_result_list:\n            try:\n                parsed_call = json.loads(match_result.strip())\n                calls.extend(self.parse_base_json(parsed_call, tools))\n            except json.JSONDecodeError as e:\n                logger.warning(\n                    f\"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}\"\n                )\n                continue\n        return StreamingParseResult(normal_text=normal_text, calls=calls)\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming incremental parsing for Qwen 2.5 tool calls.\n        Uses base class implementation with buffering to handle partial end tokens.\n        \"\"\"\n        result = super().parse_streaming_increment(new_text, tools)\n\n        # Handle partial end tokens that are streamed character by character\n        if result.normal_text:\n            self._normal_text_buffer += result.normal_text\n\n            # Check if buffer contains complete end token (without leading newline)\n            end_token_without_newline = self.eot_token[1:]  # \"</tool_call>\"\n            if end_token_without_newline in self._normal_text_buffer:\n                cleaned_text = self._normal_text_buffer.replace(\n                    end_token_without_newline, \"\"\n                )\n                self._normal_text_buffer = \"\"\n                result.normal_text = cleaned_text\n            else:\n                # Check if buffer might contain partial end token at the 
end\n                partial_match_len = self._ends_with_partial_token(\n                    self._normal_text_buffer, end_token_without_newline\n                )\n\n                if partial_match_len:\n                    # Keep potential partial match in buffer, return the rest\n                    result.normal_text = self._normal_text_buffer[:-partial_match_len]\n                    self._normal_text_buffer = self._normal_text_buffer[\n                        -partial_match_len:\n                    ]\n                else:\n                    # No partial match, return all buffered text\n                    result.normal_text = self._normal_text_buffer\n                    self._normal_text_buffer = \"\"\n\n        return result\n\n    def structure_info(self) -> _GetInfoFunc:\n        return lambda name: StructureInfo(\n            begin='<tool_call>\\n{\"name\":\"' + name + '\", \"arguments\":',\n            end=\"}\\n</tool_call>\",\n            trigger=\"<tool_call>\",\n        )\n"
  },
  {
    "path": "python/sglang/srt/function_call/qwen3_coder_detector.py",
    "content": "import ast\nimport json\nimport logging\nimport re\nfrom typing import Any, List, Optional\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    ToolCallItem,\n    _GetInfoFunc,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass Qwen3CoderDetector(BaseFormatDetector):\n    def __init__(self):\n        super().__init__()\n\n        # Sentinel tokens\n        self.tool_call_start_token: str = \"<tool_call>\"\n        self.tool_call_end_token: str = \"</tool_call>\"\n        self.tool_call_prefix: str = \"<function=\"\n        self.function_end_token: str = \"</function>\"\n        self.parameter_prefix: str = \"<parameter=\"\n        self.parameter_end_token: str = \"</parameter>\"\n\n        # Regex for non-streaming fallback\n        self.tool_call_regex = re.compile(r\"<tool_call>(.*?)</tool_call>\", re.DOTALL)\n        self.tool_call_function_regex = re.compile(\n            r\"<function=(.*?)</function>|<function=(.*)$\", re.DOTALL\n        )\n        self.tool_call_parameter_regex = re.compile(\n            r\"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)\",\n            re.DOTALL,\n        )\n\n        # Streaming State\n        # Base class already initializes _buffer, we just use it directly\n        # No need to check with hasattr - we control the lifecycle through inheritance\n\n        # Index pointing to the next character to be processed in buffer\n        self.parsed_pos: int = 0\n        # Parameter count inside the current tool being processed, used to determine whether to add comma\n        self.current_tool_param_count: int = 0\n        # Flag indicating whether current tool has already sent '{'\n        self.json_started: bool = False\n\n        # [FIX] New state flag: mark whether inside tool_call structure block\n        
self.is_inside_tool_call: bool = False\n\n        # Initialize attributes that were missing in the original PR\n        self.current_func_name: Optional[str] = None\n\n    def has_tool_call(self, text: str) -> bool:\n        return self.tool_call_start_token in text\n\n    def _get_arguments_config(\n        self, func_name: str, tools: Optional[list[Tool]]\n    ) -> dict:\n        \"\"\"Extract argument configuration for a function.\"\"\"\n        if tools is None:\n            return {}\n        for config in tools:\n            try:\n                config_type = config.type\n                config_function = config.function\n                config_function_name = config_function.name\n            except AttributeError:\n                continue\n\n            if config_type == \"function\" and config_function_name == func_name:\n                try:\n                    params = config_function.parameters\n                except AttributeError:\n                    return {}\n\n                if isinstance(params, dict) and \"properties\" in params:\n                    return params[\"properties\"]\n                elif isinstance(params, dict):\n                    return params\n                else:\n                    return {}\n        logger.warning(f\"Tool '{func_name}' is not defined in the tools list.\")\n        return {}\n\n    def _convert_param_value(\n        self, param_value: str, param_name: str, param_config: dict, func_name: str\n    ) -> Any:\n        \"\"\"Convert parameter value based on its type in the schema.\"\"\"\n        # Handle null value for any type\n        if param_value.lower() == \"null\":\n            return None\n\n        if param_name not in param_config:\n            if param_config != {}:\n                logger.warning(\n                    f\"Parsed parameter '{param_name}' is not defined in the tool \"\n                    f\"parameters for tool '{func_name}', directly returning the string value.\"\n                
)\n            return param_value\n\n        if (\n            isinstance(param_config[param_name], dict)\n            and \"type\" in param_config[param_name]\n        ):\n            param_type = str(param_config[param_name][\"type\"]).strip().lower()\n        else:\n            param_type = \"string\"\n        if param_type in [\"string\", \"str\", \"text\", \"varchar\", \"char\", \"enum\"]:\n            return param_value\n        elif (\n            param_type.startswith(\"int\")\n            or param_type.startswith(\"uint\")\n            or param_type.startswith(\"long\")\n            or param_type.startswith(\"short\")\n            or param_type.startswith(\"unsigned\")\n        ):\n            try:\n                param_value = int(param_value)\n            except Exception:\n                logger.warning(\n                    f\"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool \"\n                    f\"'{func_name}', degenerating to string.\"\n                )\n            return param_value\n        elif param_type.startswith(\"num\") or param_type.startswith(\"float\"):\n            try:\n                maybe_convert = (\n                    False if \".\" in param_value or \"e\" in param_value.lower() else True\n                )\n                param_value: float = float(param_value)\n                if maybe_convert and param_value.is_integer():\n                    param_value = int(param_value)\n            except Exception:\n                logger.warning(\n                    f\"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool \"\n                    f\"'{func_name}', degenerating to string.\"\n                )\n            return param_value\n        elif param_type in [\"boolean\", \"bool\", \"binary\"]:\n            param_value = param_value.lower()\n            if param_value not in [\"true\", \"false\"]:\n                logger.warning(\n                    f\"Parsed 
value '{param_value}' of parameter '{param_name}' is not a boolean (`true` or `false`) in tool '{func_name}', degenerating to false.\"\n                )\n            return param_value == \"true\"\n        else:\n            if (\n                param_type in [\"object\", \"array\", \"arr\"]\n                or param_type.startswith(\"dict\")\n                or param_type.startswith(\"list\")\n            ):\n                try:\n                    param_value = json.loads(param_value)\n                    return param_value\n                except Exception:\n                    logger.warning(\n                        f\"Parsed value '{param_value}' of parameter '{param_name}' cannot be parsed with json.loads in tool \"\n                        f\"'{func_name}', will try other methods to parse it.\"\n                    )\n            try:\n                param_value = ast.literal_eval(param_value)  # safer\n            except Exception:\n                logger.warning(\n                    f\"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted via Python `ast.literal_eval()` in tool '{func_name}', degenerating to string.\"\n                )\n            return param_value\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"One-shot parsing for non-streaming scenarios.\"\"\"\n        if self.tool_call_start_token not in text:\n            return StreamingParseResult(normal_text=text)\n\n        calls = []\n        try:\n            # Simple cleanup of the text to find tool calls\n            # Note: This is a simplified regex approach consistent with vLLM\n            raw_tool_calls = self.tool_call_regex.findall(text)\n            if not raw_tool_calls:\n                # Fallback: maybe the whole text is inside the tag or tags are stripped\n                if self.tool_call_prefix in text:\n                    raw_tool_calls = [text]\n\n            tool_idx = 0\n            for 
tool_content in raw_tool_calls:\n                # Find function calls\n                funcs = self.tool_call_function_regex.findall(tool_content)\n                for func_match in funcs:\n                    func_body = func_match[0] or func_match[1]\n                    if \">\" not in func_body:\n                        continue\n\n                    name_end = func_body.index(\">\")\n                    func_name = func_body[:name_end]\n                    params_str = func_body[name_end + 1 :]\n\n                    param_config = self._get_arguments_config(func_name, tools)\n                    parsed_params = {}\n\n                    for p_match in self.tool_call_parameter_regex.findall(params_str):\n                        if \">\" not in p_match:\n                            continue\n                        p_idx = p_match.index(\">\")\n                        p_name = p_match[:p_idx]\n                        p_val = p_match[p_idx + 1 :]\n                        # Strip a single leading and trailing \\n\n                        if p_val.startswith(\"\\n\"):\n                            p_val = p_val[1:]\n                        if p_val.endswith(\"\\n\"):\n                            p_val = p_val[:-1]\n\n                        parsed_params[p_name] = self._convert_param_value(\n                            p_val, p_name, param_config, func_name\n                        )\n\n                    calls.append(\n                        ToolCallItem(\n                            tool_index=tool_idx,\n                            name=func_name,\n                            parameters=json.dumps(parsed_params, ensure_ascii=False),\n                        )\n                    )\n                    tool_idx += 1\n\n            # Determine normal text (text before the first tool call)\n            start_idx = text.find(self.tool_call_start_token)\n            if start_idx == -1:\n                start_idx = text.find(self.tool_call_prefix)\n            
normal_text = text[:start_idx] if start_idx > 0 else \"\"\n\n            return StreamingParseResult(normal_text=normal_text, calls=calls)\n\n        except Exception as e:\n            logger.error(f\"Error in detect_and_parse: {e}\")\n            return StreamingParseResult(normal_text=text)\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Robust cursor-based streaming parser.\n        \"\"\"\n        self._buffer += new_text\n\n        # Guard against empty buffer\n        if not self._buffer:\n            return StreamingParseResult()\n\n        calls = []\n        normal_text_chunks = []\n\n        while True:\n            # Working text slice\n            current_slice = self._buffer[self.parsed_pos :]\n\n            # Nothing left to parse yet: wait for more input\n            if not current_slice:\n                break\n\n            # -------------------------------------------------------\n            # 1. Priority detection: check whether this is the start of a tool call\n            # -------------------------------------------------------\n            if current_slice.startswith(self.tool_call_start_token):\n                self.parsed_pos += len(self.tool_call_start_token)\n                self.is_inside_tool_call = True\n                continue\n\n            # -------------------------------------------------------\n            # 2. 
Function Name: <function=name>\n            # -------------------------------------------------------\n            if current_slice.startswith(self.tool_call_prefix):\n                end_angle = current_slice.find(\">\")\n                if end_angle != -1:\n                    func_name = current_slice[len(self.tool_call_prefix) : end_angle]\n\n                    self.current_tool_id += 1\n                    self.current_tool_name_sent = True\n                    self.current_tool_param_count = 0\n                    self.json_started = False\n                    self.current_func_name = func_name\n\n                    calls.append(\n                        ToolCallItem(\n                            tool_index=self.current_tool_id,\n                            name=func_name,\n                            parameters=\"\",\n                        )\n                    )\n\n                    self.parsed_pos += end_angle + 1\n                    continue\n                else:\n                    # Incomplete tag\n                    break\n\n            # -------------------------------------------------------\n            # 3. Parameter: <parameter=name>value...\n            # -------------------------------------------------------\n            if current_slice.startswith(self.parameter_prefix):\n                name_end = current_slice.find(\">\")\n                if name_end != -1:\n                    value_start_idx = name_end + 1\n                    rest_of_slice = current_slice[value_start_idx:]\n\n                    # A parameter can end in multiple ways:\n                    # 1. [Normal] Encounter </parameter>\n                    # 2. [Abnormal] Encounter next <parameter=\n                    # 3. 
[Abnormal] Encounter </function>\n                    # So we need to find the smallest one as the parameter end position.\n                    cand_end_param = rest_of_slice.find(self.parameter_end_token)\n                    cand_next_param = rest_of_slice.find(self.parameter_prefix)\n                    cand_end_func = rest_of_slice.find(self.function_end_token)\n\n                    candidates = []\n                    if cand_end_param != -1:\n                        candidates.append(\n                            (cand_end_param, len(self.parameter_end_token))\n                        )\n                    if cand_next_param != -1:\n                        candidates.append((cand_next_param, 0))\n                    if cand_end_func != -1:\n                        candidates.append((cand_end_func, 0))\n\n                    if candidates:\n                        best_cand = min(candidates, key=lambda x: x[0])\n                        end_pos = best_cand[0]\n                        end_token_len = best_cand[1]\n\n                        param_name = current_slice[\n                            len(self.parameter_prefix) : name_end\n                        ]\n                        raw_value = rest_of_slice[:end_pos]\n\n                        # Cleanup value\n                        if raw_value.startswith(\"\\n\"):\n                            raw_value = raw_value[1:]\n                        if raw_value.endswith(\"\\n\"):\n                            raw_value = raw_value[:-1]\n\n                        # JSON Construction\n                        if not self.json_started:\n                            calls.append(\n                                ToolCallItem(\n                                    tool_index=self.current_tool_id, parameters=\"{\"\n                                )\n                            )\n                            self.json_started = True\n\n                        param_config = self._get_arguments_config(\n                    
        self.current_func_name, tools\n                        )\n                        converted_val = self._convert_param_value(\n                            raw_value, param_name, param_config, self.current_func_name\n                        )\n\n                        # Construct JSON fragment: \"key\": value\n                        # Note: We must be careful with json.dumps to ensure valid JSON streaming\n                        json_key_val = f\"{json.dumps(param_name)}: {json.dumps(converted_val, ensure_ascii=False)}\"\n\n                        if self.current_tool_param_count > 0:\n                            fragment = f\", {json_key_val}\"\n                        else:\n                            fragment = json_key_val\n\n                        calls.append(\n                            ToolCallItem(\n                                tool_index=self.current_tool_id, parameters=fragment\n                            )\n                        )\n                        self.current_tool_param_count += 1\n\n                        # Advance cursor\n                        total_len = (name_end + 1) + end_pos + end_token_len\n                        self.parsed_pos += total_len\n                        continue\n\n                # Incomplete parameter tag or value\n                break\n\n            # -------------------------------------------------------\n            # 4. 
Function End: </function>\n            # -------------------------------------------------------\n            if current_slice.startswith(self.function_end_token):\n                if not self.json_started:\n                    calls.append(\n                        ToolCallItem(tool_index=self.current_tool_id, parameters=\"{\")\n                    )\n                    self.json_started = True\n\n                calls.append(\n                    ToolCallItem(tool_index=self.current_tool_id, parameters=\"}\")\n                )\n                self.parsed_pos += len(self.function_end_token)\n                self.current_func_name = None\n                continue\n\n            # -------------------------------------------------------\n            # 5. Tool Call End: </tool_call>\n            # -------------------------------------------------------\n            if current_slice.startswith(self.tool_call_end_token):\n                self.parsed_pos += len(self.tool_call_end_token)\n                self.is_inside_tool_call = False  # [FIX] Exit tool call region\n                continue\n\n            # -------------------------------------------------------\n            # 6. 
Handling content / whitespace / normal text\n            # -------------------------------------------------------\n            # If current position is not the start of a tag (i.e., doesn't start with <), it might be plain text,\n            # or a newline between two tags.\n            # But we need to be careful not to output truncated tags like \"<fun\" as text.\n\n            next_open_angle = current_slice.find(\"<\")\n\n            if next_open_angle == -1:\n                # This entire segment is plain text\n                if not self.is_inside_tool_call:\n                    normal_text_chunks.append(current_slice)\n                # [FIX] If inside tool call, discard this text (usually \\n), don't append\n                self.parsed_pos += len(current_slice)\n                continue\n\n            elif next_open_angle == 0:\n                # Looks like a Tag, but doesn't match any known Tag above\n\n                possible_tags = [\n                    self.tool_call_start_token,\n                    self.tool_call_end_token,\n                    self.tool_call_prefix,\n                    self.function_end_token,\n                    self.parameter_prefix,\n                    self.parameter_end_token,\n                ]\n\n                is_potential_tag = False\n                for tag in possible_tags:\n                    if tag.startswith(current_slice):\n                        is_potential_tag = True\n                        break\n\n                if is_potential_tag:\n                    break  # Wait for more\n                else:\n                    # Just a plain '<' symbol\n                    if not self.is_inside_tool_call:\n                        normal_text_chunks.append(\"<\")\n                    self.parsed_pos += 1\n                    continue\n\n            else:\n                # '<' is in the middle\n                text_segment = current_slice[:next_open_angle]\n                if not self.is_inside_tool_call:\n       
             normal_text_chunks.append(text_segment)\n                # [FIX] If inside tool call, discard whitespace/text before Tag\n                self.parsed_pos += next_open_angle\n                continue\n\n        # Memory Cleanup: Slice the buffer\n        # Keep unparsed part, discard parsed part\n        if self.parsed_pos > 0:\n            self._buffer = self._buffer[self.parsed_pos :]\n            self.parsed_pos = 0\n\n        normal_text = \"\".join(normal_text_chunks) if normal_text_chunks else \"\"\n        return StreamingParseResult(calls=calls, normal_text=normal_text)\n\n    def supports_structural_tag(self) -> bool:\n        return False\n\n    def structure_info(self) -> _GetInfoFunc:\n        raise NotImplementedError\n"
  },
  {
    "path": "python/sglang/srt/function_call/step3_detector.py",
    "content": "import ast\nimport json\nimport logging\nimport re\nfrom typing import Any, Dict, List\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.base_format_detector import BaseFormatDetector\nfrom sglang.srt.function_call.core_types import (\n    StreamingParseResult,\n    ToolCallItem,\n    _GetInfoFunc,\n)\n\nlogger = logging.getLogger(__name__)\n\n\ndef get_argument_type(func_name: str, arg_key: str, defined_tools: List[Tool]) -> str:\n    \"\"\"Get the expected type for a function argument from tool schema.\"\"\"\n    name2tool = {tool.function.name: tool for tool in defined_tools}\n    if func_name not in name2tool:\n        return None\n    tool = name2tool[func_name]\n    parameters = tool.function.parameters or {}\n    properties = parameters.get(\"properties\", {})\n    if arg_key not in properties:\n        return None\n    return properties[arg_key].get(\"type\", None)\n\n\ndef parse_arguments(value: str) -> tuple[Any, bool]:\n    \"\"\"Parse a string value to appropriate type. 
Returns (parsed_value, success).\"\"\"\n    try:\n        try:\n            parsed_value = json.loads(value)\n        except Exception:\n            parsed_value = ast.literal_eval(value)\n        return parsed_value, True\n    except Exception:\n        return value, False\n\n\nclass Step3Detector(BaseFormatDetector):\n    \"\"\"\n    Detector for Step3 model function call format.\n\n    The Step3 format uses special Unicode tokens to delimit function calls\n    with steptml XML format for invocations.\n\n    Format Structure:\n    ```\n    <｜tool_calls_begin｜>\n    <｜tool_call_begin｜>function<｜tool_sep｜><steptml:invoke name=\"function_name\">\n    <steptml:parameter name=\"param1\">value1</steptml:parameter>\n    <steptml:parameter name=\"param2\">value2</steptml:parameter>\n    </steptml:invoke><｜tool_call_end｜>\n    <｜tool_calls_end｜>\n    ```\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.bot_token = \"<｜tool_calls_begin｜>\"\n        self.eot_token = \"<｜tool_calls_end｜>\"\n        self.tool_call_begin = \"<｜tool_call_begin｜>\"\n        self.tool_call_end = \"<｜tool_call_end｜>\"\n        self.tool_sep = \"<｜tool_sep｜>\"\n\n        # Regex for parsing steptml invocations\n        self.invoke_regex = re.compile(\n            r'<steptml:invoke name=\"([^\"]+)\">(.+?)</steptml:invoke>', re.DOTALL\n        )\n        self.param_regex = re.compile(\n            r'<steptml:parameter name=\"([^\"]+)\">([^<]*)</steptml:parameter>', re.DOTALL\n        )\n\n        # Streaming state variables\n        self._in_tool_block: bool = False\n        self._tool_block_finished: bool = False\n        self._current_function_name: str = \"\"\n        self._current_parameters: Dict[str, Any] = {}\n        self._in_tool_call: bool = False\n        self._function_name_sent: bool = False\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if the text contains a Step3 format tool call.\"\"\"\n        return self.bot_token 
in text\n\n    def 
_parse_steptml_invoke(\n        self, text: str, tools: List[Tool] = None\n    ) -> tuple[str, dict]:\n        \"\"\"Parse steptml invoke format to extract function name and parameters.\"\"\"\n        invoke_match = self.invoke_regex.search(text)\n        if not invoke_match:\n            return None, {}\n\n        func_name = invoke_match.group(1)\n        params_text = invoke_match.group(2)\n\n        params = {}\n        for param_match in self.param_regex.finditer(params_text):\n            param_name = param_match.group(1)\n            param_value = param_match.group(2).strip()\n\n            # If tools provided, use schema-aware parsing\n            if tools:\n                arg_type = get_argument_type(func_name, param_name, tools)\n                if arg_type and arg_type != \"string\":\n                    parsed_value, _ = parse_arguments(param_value)\n                    params[param_name] = parsed_value\n                else:\n                    params[param_name] = param_value\n            else:\n                # Fallback to generic parsing if no tools provided\n                parsed_value, _ = parse_arguments(param_value)\n                params[param_name] = parsed_value\n\n        return func_name, params\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"\n        One-time parsing: Detects and parses tool calls in the provided text.\n        \"\"\"\n        if self.bot_token not in text:\n            return StreamingParseResult(normal_text=text, calls=[])\n\n        try:\n            pre_text, rest = text.split(self.bot_token, 1)\n\n            # If no end token, return everything as normal text\n            if self.eot_token not in rest:\n                return StreamingParseResult(normal_text=text, calls=[])\n\n            tool_section, post_text = rest.split(self.eot_token, 1)\n\n            # Find all individual tool calls using regex\n            calls = []\n            tool_call_pattern 
= (\n                f\"{re.escape(self.tool_call_begin)}(.*?){re.escape(self.tool_call_end)}\"\n            )\n\n            for match in re.finditer(tool_call_pattern, tool_section, re.DOTALL):\n                call_content = match.group(1)\n\n                # Check if it's a function call\n                if self.tool_sep not in call_content:\n                    continue\n\n                type_part, invoke_part = call_content.split(self.tool_sep, 1)\n                if type_part.strip() != \"function\":\n                    continue\n\n                func_name, params = self._parse_steptml_invoke(invoke_part, tools)\n                if func_name:\n                    # Use parse_base_json to create the ToolCallItem\n                    action = {\"name\": func_name, \"arguments\": params}\n                    calls.extend(self.parse_base_json(action, tools))\n\n            # Combine pre and post text\n            normal_text = pre_text + post_text\n\n            return StreamingParseResult(normal_text=normal_text, calls=calls)\n\n        except Exception as e:\n            logger.error(f\"Error in detect_and_parse: {e}\")\n            # Return the original text if parsing fails\n            return StreamingParseResult(normal_text=text)\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming incremental parsing for Step3 format.\n        \"\"\"\n        self._buffer += new_text\n\n        # Build tool indices for validation\n        if not hasattr(self, \"_tool_indices\"):\n            self._tool_indices = self._get_tool_indices(tools)\n\n        # If we've finished the tool block, everything is normal text\n        if self._tool_block_finished:\n            normal_text = self._buffer\n            self._buffer = \"\"\n            return StreamingParseResult(normal_text=normal_text)\n\n        # Check if tool block hasn't started yet\n        if not 
self._in_tool_block:\n            if self.bot_token in self._buffer:\n                idx = self._buffer.find(self.bot_token)\n                normal_text = self._buffer[:idx]\n                self._buffer = self._buffer[idx + len(self.bot_token) :]\n                self._in_tool_block = True\n                return StreamingParseResult(normal_text=normal_text)\n            else:\n                # Check if we might have a partial bot_token\n                partial_len = self._ends_with_partial_token(\n                    self._buffer, self.bot_token\n                )\n                if partial_len:\n                    return StreamingParseResult()  # Wait for more text\n                else:\n                    normal_text = self._buffer\n                    self._buffer = \"\"\n                    return StreamingParseResult(normal_text=normal_text)\n\n        # We're inside the tool block\n        calls: List[ToolCallItem] = []\n\n        # Check if tool block is ending\n        if self.eot_token in self._buffer:\n            idx = self._buffer.find(self.eot_token)\n\n            # If we're in the middle of a tool call, we need to handle it\n            if self._in_tool_call:\n                # The buffer before eot_token might contain the end of the current tool call\n                before_eot = self._buffer[:idx]\n                if self.tool_call_end in before_eot:\n                    # Parse this final tool call\n                    result = self._parse_partial_tool_call(tools)\n                    calls.extend(result.calls)\n                else:\n                    # Incomplete tool call - log warning\n                    logger.warning(\"Tool block ended with incomplete tool call\")\n\n            remaining = self._buffer[idx + len(self.eot_token) :]\n            self._buffer = \"\"\n            self._tool_block_finished = True\n\n            # Reset any partial tool call state\n            self._reset_streaming_state()\n\n            return 
StreamingParseResult(normal_text=remaining, calls=calls)\n\n        # Check if we're in a tool call or need to start one\n        if not self._in_tool_call:\n            if self.tool_call_begin in self._buffer:\n                idx = self._buffer.find(self.tool_call_begin)\n                # Remove any content before tool call begin (shouldn't happen but be safe)\n                self._buffer = self._buffer[idx + len(self.tool_call_begin) :]\n                self._in_tool_call = True\n                self._function_name_sent = False\n                self._current_function_name = \"\"\n                self._current_parameters = {}\n                # Fall through to parse the partial tool call\n            else:\n                # Wait for tool call to begin\n                return StreamingParseResult()\n\n        # Parse partial tool call\n        if self._in_tool_call:\n            return self._parse_partial_tool_call(tools)\n\n        return StreamingParseResult()\n\n    def _parse_partial_tool_call(self, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"Parse partial tool call for streaming scenarios.\"\"\"\n        calls = []\n\n        # Check if we have tool_sep (means we're past the type declaration)\n        if self.tool_sep not in self._buffer:\n            return StreamingParseResult(calls=calls)  # Wait for more text\n\n        type_part, invoke_part = self._buffer.split(self.tool_sep, 1)\n        if type_part.strip() != \"function\":\n            # Invalid tool type, skip this tool call\n            self._reset_streaming_state()\n            return StreamingParseResult(calls=calls)\n\n        # Try to extract function name if not sent yet\n        if not self._function_name_sent:\n            name_match = re.search(r'<steptml:invoke name=\"([^\"]+)\">', invoke_part)\n            if name_match:\n                func_name = name_match.group(1)\n\n                # Validate function name\n                if func_name in self._tool_indices:\n       
             self._current_function_name = func_name\n                    self._function_name_sent = True\n\n                    # Initialize tool tracking\n                    if self.current_tool_id == -1:\n                        self.current_tool_id = 0\n\n                    # Ensure tracking arrays are large enough\n                    while len(self.prev_tool_call_arr) <= self.current_tool_id:\n                        self.prev_tool_call_arr.append({})\n                    while len(self.streamed_args_for_tool) <= self.current_tool_id:\n                        self.streamed_args_for_tool.append(\"\")\n\n                    # Store tool call info\n                    self.prev_tool_call_arr[self.current_tool_id] = {\n                        \"name\": func_name,\n                        \"arguments\": {},\n                    }\n\n                    # Send tool name with empty parameters\n                    calls.append(\n                        ToolCallItem(\n                            tool_index=self.current_tool_id,\n                            name=func_name,\n                            parameters=\"\",\n                        )\n                    )\n                else:\n                    # Invalid function name\n                    logger.warning(f\"Invalid function name: {func_name}\")\n                    self._reset_streaming_state()\n                    return StreamingParseResult(calls=calls)\n            else:\n                # Function name not complete yet\n                return StreamingParseResult(calls=calls)\n\n        # Parse parameters incrementally\n        if self._function_name_sent:\n            # Extract all complete parameters\n            new_params = {}\n            for param_match in self.param_regex.finditer(invoke_part):\n                param_name = param_match.group(1)\n                param_value = param_match.group(2).strip()\n\n                # Use schema-aware parsing\n                arg_type = 
get_argument_type(\n                    self._current_function_name, param_name, tools\n                )\n                if arg_type and arg_type != \"string\":\n                    parsed_value, _ = parse_arguments(param_value)\n                    new_params[param_name] = parsed_value\n                else:\n                    new_params[param_name] = param_value\n\n            # Check if we have new parameters to stream\n            if new_params != self._current_parameters:\n                # Build the JSON content without the closing brace for streaming\n                if not self._current_parameters:\n                    # First parameters - send opening brace and content\n                    params_content = json.dumps(new_params, ensure_ascii=False)\n                    if len(params_content) > 2:  # More than just \"{}\"\n                        # Send everything except the closing brace\n                        diff = params_content[:-1]\n                    else:\n                        diff = \"{\"\n                else:\n                    # Subsequent parameters - calculate the incremental diff\n                    old_json = json.dumps(self._current_parameters, ensure_ascii=False)\n                    new_json = json.dumps(new_params, ensure_ascii=False)\n\n                    # Remove closing braces for comparison\n                    old_without_brace = old_json[:-1]\n                    new_without_brace = new_json[:-1]\n\n                    # The new content should extend the old content\n                    if new_without_brace.startswith(old_without_brace):\n                        diff = new_without_brace[len(old_without_brace) :]\n                    else:\n                        # Parameters changed in unexpected way - shouldn't happen in normal streaming\n                        diff = \"\"\n\n                if diff:\n                    calls.append(\n                        ToolCallItem(\n                            
tool_index=self.current_tool_id,\n                            parameters=diff,\n                        )\n                    )\n                    self.streamed_args_for_tool[self.current_tool_id] += diff\n\n                # Update current state\n                self._current_parameters = new_params\n                self.prev_tool_call_arr[self.current_tool_id][\"arguments\"] = new_params\n\n            # Check if tool call is complete\n            if self.tool_call_end in self._buffer:\n                # Send closing brace if we've sent any parameters\n                if self.streamed_args_for_tool[self.current_tool_id]:\n                    calls.append(\n                        ToolCallItem(\n                            tool_index=self.current_tool_id,\n                            parameters=\"}\",\n                        )\n                    )\n                    self.streamed_args_for_tool[self.current_tool_id] += \"}\"\n\n                # Find the end position\n                end_idx = self._buffer.find(self.tool_call_end)\n                # Remove the processed tool call from buffer\n                self._buffer = self._buffer[end_idx + len(self.tool_call_end) :]\n\n                # Reset state for next tool call\n                self._reset_streaming_state()\n                self.current_tool_id += 1\n\n        return StreamingParseResult(calls=calls)\n\n    def _reset_streaming_state(self):\n        \"\"\"Reset streaming state for the next tool call\"\"\"\n        self._in_tool_call = False\n        self._function_name_sent = False\n        self._current_function_name = \"\"\n        self._current_parameters = {}\n\n    def supports_structural_tag(self) -> bool:\n        \"\"\"Return True if this detector supports structural tag format.\"\"\"\n        return False\n\n    def structure_info(self) -> _GetInfoFunc:\n        raise NotImplementedError()\n"
  },
  {
    "path": "python/sglang/srt/function_call/trinity_detector.py",
    "content": "import logging\nfrom typing import List\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool\nfrom sglang.srt.function_call.core_types import StreamingParseResult\nfrom sglang.srt.function_call.qwen25_detector import Qwen25Detector\n\nlogger = logging.getLogger(__name__)\n\n\nclass TrinityDetector(Qwen25Detector):\n    \"\"\"\n    Detector for Trinity models using Qwen-style function call format.\n\n    This detector extends Qwen25Detector to handle tool calls that may appear\n    inside <think> sections by stripping the think tags before parsing.\n\n    Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default\n    \"\"\"\n\n    def _strip_think_tags(self, text: str) -> str:\n        \"\"\"Remove <think> and </think> tags, keeping the content inside.\"\"\"\n        return text.replace(\"<think>\", \"\").replace(\"</think>\", \"\")\n\n    def has_tool_call(self, text: str) -> bool:\n        \"\"\"Check if the text contains a tool call.\"\"\"\n        return super().has_tool_call(self._strip_think_tags(text))\n\n    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:\n        \"\"\"\n        One-time parsing: Detects and parses tool calls in the provided text.\n        \"\"\"\n        return super().detect_and_parse(self._strip_think_tags(text), tools)\n\n    def parse_streaming_increment(\n        self, new_text: str, tools: List[Tool]\n    ) -> StreamingParseResult:\n        \"\"\"\n        Streaming incremental parsing for tool calls.\n        \"\"\"\n        return super().parse_streaming_increment(\n            self._strip_think_tags(new_text), tools\n        )\n"
  },
  {
    "path": "python/sglang/srt/function_call/utils.py",
    "content": "from json import JSONDecodeError, JSONDecoder\nfrom json.decoder import WHITESPACE\nfrom typing import Any, Dict, List, Literal, Optional, Tuple, Union\n\nimport orjson\nimport partial_json_parser\nfrom partial_json_parser.core.options import Allow\n\nfrom sglang.srt.entrypoints.openai.protocol import Tool, ToolChoice\n\n\ndef _find_common_prefix(s1: str, s2: str) -> str:\n    prefix = \"\"\n    min_length = min(len(s1), len(s2))\n    for i in range(0, min_length):\n        if s1[i] == s2[i]:\n            prefix += s1[i]\n        else:\n            break\n    return prefix\n\n\ndef _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:\n    \"\"\"\n    Parse incomplete or partial JSON strings commonly encountered during streaming.\n\n    Args:\n        input_str (str): The potentially incomplete JSON string to parse.\n        flags (Allow): Bitwise flags controlling what types of partial data are allowed.\n            Common flags include:\n            - Allow.STR: Allow partial strings (e.g., '\"hello wo' -> 'hello wo')\n            - Allow.OBJ: Allow partial objects (e.g., '{\"key\":' -> {'key': None})\n            - Allow.ARR: Allow partial arrays (e.g., '[1, 2,' -> [1, 2])\n            - Allow.ALL: Allow all types of partial data\n\n    Returns:\n        Tuple[Any, int]: A tuple containing:\n            - parsed_object: The Python object parsed from the JSON\n            - consumed_length: Number of characters consumed from input_str\n    \"\"\"\n    try:\n        return (partial_json_parser.loads(input_str, flags), len(input_str))\n    except (JSONDecodeError, IndexError) as e:\n        msg = getattr(e, \"msg\", str(e))\n        if \"Extra data\" in msg or \"pop from empty list\" in msg:\n            start = WHITESPACE.match(input_str, 0).end()\n            obj, end = JSONDecoder().raw_decode(input_str, start)\n            return obj, end\n        raise\n\n\ndef _is_complete_json(input_str: str) -> bool:\n    try:\n        
orjson.loads(input_str)\n        return True\n    except JSONDecodeError:\n        return False\n\n\ndef _get_tool_schema_defs(tools: List[Tool]) -> dict:\n    \"\"\"\n    Get consolidated $defs from all tools, validating for conflicts.\n\n    Args:\n        tools: List of tools to process\n\n    Returns:\n        Dictionary of consolidated $defs from all tools\n\n    Raises:\n        ValueError: If conflicting $defs are found\n    \"\"\"\n    all_defs = {}\n    for tool in tools:\n        if tool.function.parameters is None:\n            continue\n        defs = tool.function.parameters.get(\"$defs\", {})\n        for def_name, def_schema in defs.items():\n            if def_name in all_defs and all_defs[def_name] != def_schema:\n                raise ValueError(\n                    f\"Tool definition '{def_name}' has \"\n                    \"multiple schemas, which is not \"\n                    \"supported.\"\n                )\n            else:\n                all_defs[def_name] = def_schema\n    return all_defs\n\n\ndef _get_tool_schema(tool: Tool) -> dict:\n    return {\n        \"properties\": {\n            \"name\": {\"type\": \"string\", \"enum\": [tool.function.name]},\n            \"parameters\": (\n                tool.function.parameters\n                if tool.function.parameters\n                else {\"type\": \"object\", \"properties\": {}}\n            ),\n        },\n        \"required\": [\"name\", \"parameters\"],\n    }\n\n\ndef infer_type_from_json_schema(schema: Dict[str, Any]) -> Optional[str]:\n    \"\"\"\n    Infer the primary type of a parameter from JSON Schema.\n\n    Supports complex JSON Schema structures including:\n    - Direct type field (including type arrays)\n    - anyOf/oneOf: parameter can be any of multiple types\n    - enum: parameter must be one of enum values\n    - allOf: parameter must satisfy all type definitions\n    - properties: inferred as object type\n    - items: inferred as array type\n\n    Args:\n        
schema: JSON Schema definition\n\n    Returns:\n        Inferred type ('string', 'number', 'object', 'array', etc.) or None\n    \"\"\"\n    if not isinstance(schema, dict):\n        return None\n\n    # Priority 1: Direct type field (including type arrays)\n    if \"type\" in schema:\n        type_value = schema[\"type\"]\n        if isinstance(type_value, str):\n            return type_value\n        elif isinstance(type_value, list) and type_value:\n            # Handle type arrays: return first non-null type\n            non_null_types = [t for t in type_value if t != \"null\"]\n            if non_null_types:\n                return non_null_types[0]\n            return \"string\"  # If only null, default to string\n\n    # Priority 2: Handle anyOf/oneOf\n    if \"anyOf\" in schema or \"oneOf\" in schema:\n        schemas = schema.get(\"anyOf\") or schema.get(\"oneOf\")\n        types = []\n\n        if isinstance(schemas, list):\n            for sub_schema in schemas:\n                inferred_type = infer_type_from_json_schema(sub_schema)\n                if inferred_type:\n                    types.append(inferred_type)\n\n            if types:\n                # If all types are the same, return unified type\n                if len(set(types)) == 1:\n                    return types[0]\n                # When types differ, prioritize string (safest)\n                if \"string\" in types:\n                    return \"string\"\n                # Otherwise return first type\n                return types[0]\n\n    # Priority 3: Handle enum (infer type from enum values)\n    if \"enum\" in schema and isinstance(schema[\"enum\"], list):\n        if not schema[\"enum\"]:\n            return \"string\"\n\n        # Infer type from enum values\n        enum_types = set()\n        for value in schema[\"enum\"]:\n            if value is None:\n                enum_types.add(\"null\")\n            elif isinstance(value, bool):\n                
enum_types.add(\"boolean\")\n            elif isinstance(value, int):\n                enum_types.add(\"integer\")\n            elif isinstance(value, float):\n                enum_types.add(\"number\")\n            elif isinstance(value, str):\n                enum_types.add(\"string\")\n            elif isinstance(value, list):\n                enum_types.add(\"array\")\n            elif isinstance(value, dict):\n                enum_types.add(\"object\")\n\n        # If type is uniform, return that type\n        if len(enum_types) == 1:\n            return enum_types.pop()\n        # Mixed types, prioritize string\n        return \"string\"\n\n    # Priority 4: Handle allOf (must satisfy all types)\n    if \"allOf\" in schema and isinstance(schema[\"allOf\"], list):\n        schemas = schema[\"allOf\"]\n        for sub_schema in schemas:\n            inferred_type = infer_type_from_json_schema(sub_schema)\n            if inferred_type and inferred_type != \"string\":\n                return inferred_type\n        return \"string\"\n\n    # Priority 5: Infer object type\n    if \"properties\" in schema:\n        return \"object\"\n\n    # Priority 6: Infer array type\n    if \"items\" in schema:\n        return \"array\"\n\n    return None\n\n\ndef get_json_schema_constraint(\n    tools: List[Tool], tool_choice: Union[ToolChoice, Literal[\"required\"]]\n) -> Optional[dict]:\n    \"\"\"\n    Get the JSON schema constraint for the specified tool choice.\n\n    Args:\n        tools: The available tools to build the schema from\n        tool_choice: The tool choice specification (a specific ToolChoice or \"required\")\n\n    Returns:\n        JSON schema dict, or None if no valid tools found\n    \"\"\"\n\n    if isinstance(tool_choice, ToolChoice):\n        # For a specific function choice, constrain the output to exactly one call of that tool\n        fn_name = tool_choice.function.name\n        for tool in tools:\n            if tool.function.name == fn_name:\n                return {\n                    \"type\": \"array\",\n                    \"minItems\": 1,\n          
          \"maxItems\": 1,\n                    \"items\": _get_tool_schema(tool),\n                }\n        return None\n    elif tool_choice == \"required\":\n        json_schema = {\n            \"type\": \"array\",\n            \"minItems\": 1,\n            \"items\": {\n                \"type\": \"object\",\n                \"anyOf\": [_get_tool_schema(tool) for tool in tools],\n            },\n        }\n        json_schema_defs = _get_tool_schema_defs(tools)\n        if json_schema_defs:\n            json_schema[\"$defs\"] = json_schema_defs\n        return json_schema\n\n    return None\n"
  },
  {
    "path": "python/sglang/srt/grpc/__init__.py",
    "content": "# SGLang gRPC module\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/allocator_npu.py",
    "content": "from typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.srt.mem_cache.allocator import (\n    PagedTokenToKVPoolAllocator,\n    alloc_extend_naive,\n)\nfrom sglang.srt.utils import get_num_new_pages, next_power_of_2\n\nif TYPE_CHECKING:\n    from sglang.srt.mem_cache.memory_pool import KVCache\n\n\nclass NPUPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):\n    def __init__(\n        self,\n        size: int,\n        page_size: int,\n        dtype: torch.dtype,\n        device: str,\n        kvcache: \"KVCache\",\n        need_sort: bool,\n    ):\n        super().__init__(size, page_size, dtype, device, kvcache, need_sort)\n        self.roundup = page_size - 1\n\n    def alloc_extend(\n        self,\n        prefix_lens: torch.Tensor,\n        prefix_lens_cpu: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: torch.Tensor,\n        last_loc: torch.Tensor,\n        extend_num_tokens: int,\n    ):\n        if self.debug_mode:\n            assert torch.all(\n                (last_loc + 1) % self.page_size == prefix_lens % self.page_size\n            )\n\n        num_new_pages = (\n            (seq_lens + self.roundup) // self.page_size\n            - (prefix_lens + self.roundup) // self.page_size\n        ).sum()\n        num_new_pages_item = num_new_pages.item()\n        if self.need_sort and num_new_pages_item > len(self.free_pages):\n            self.merge_and_sort_free()\n\n        if num_new_pages_item > len(self.free_pages):\n            return None\n\n        if num_new_pages_item < 200:\n            from sgl_kernel_npu.mem_cache.allocator import alloc_extend_kernel\n\n            out_indices = torch.empty(\n                (extend_num_tokens,),\n                dtype=torch.int64,\n                device=self.device,\n            )\n            max_num_extend_tokens = next_power_of_2(extend_num_tokens)\n            bs = prefix_lens.shape[0]\n            alloc_extend_kernel[(bs,)](\n                
prefix_lens,\n                seq_lens,\n                last_loc,\n                self.free_pages,\n                out_indices,\n                next_power_of_2(bs),\n                self.page_size,\n                max_num_extend_tokens,\n            )\n\n        else:\n            out_indices = torch.empty(\n                (extend_num_tokens,),\n                dtype=torch.int32,\n                device=self.device,\n            )\n            alloc_extend_naive(\n                prefix_lens,\n                seq_lens,\n                last_loc,\n                self.free_pages,\n                out_indices,\n                self.page_size,\n                self.device,\n            )\n\n        if self.debug_mode:\n            assert len(torch.unique(out_indices)) == len(out_indices)\n\n        self.free_pages = self.free_pages[num_new_pages_item:]\n        return out_indices.int()\n\n    def alloc_decode(\n        self,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: torch.Tensor,\n        last_loc: torch.Tensor,\n    ):\n        if self.debug_mode:\n            assert torch.all(\n                (last_loc + 2) % self.page_size == seq_lens % self.page_size\n            )\n\n        num_new_pages = get_num_new_pages(\n            seq_lens=seq_lens_cpu,\n            page_size=self.page_size,\n            decode=True,\n        )\n\n        if num_new_pages > len(self.free_pages):\n            self.merge_and_sort_free()\n\n        if num_new_pages > len(self.free_pages):\n            return None\n\n        need_new_pages = (seq_lens % self.page_size == 1).int()\n        end_new_pages = torch.cumsum(need_new_pages, 0)\n        start_new_pages = end_new_pages - need_new_pages\n        if num_new_pages == 0:\n            out_indices = last_loc + 1\n        else:\n            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[\n                start_new_pages\n            ] * self.page_size * need_new_pages\n\n        if 
self.debug_mode:\n            assert len(torch.unique(out_indices)) == len(out_indices)\n\n        self.free_pages = self.free_pages[num_new_pages:]\n        return out_indices.int()\n\n    def free(self, free_index: torch.Tensor):\n        if free_index.numel() == 0:\n            return\n\n        if self.is_not_in_free_group:\n            device = free_index.device\n            free_page_indices = torch.unique(free_index.cpu() // self.page_size)\n            free_page_indices = free_page_indices.to(device)\n            if self.need_sort:\n                self.release_pages = torch.cat((free_page_indices, self.release_pages))\n            else:\n                self.free_pages = torch.cat((free_page_indices, self.free_pages))\n        else:\n            self.free_group.append(free_index)\n\n        if self.debug_mode:\n            assert len(torch.unique(self.free_pages)) == len(self.free_pages)\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, List, Optional\n\nimport torch\nimport torch_npu\nfrom sgl_kernel_npu.attention.sinks_attention import (\n    attention_sinks_prefill_triton,\n    attention_sinks_triton,\n)\n\nfrom sglang.srt.configs.model_config import AttentionArch\nfrom sglang.srt.dllm.config import DllmConfig\nfrom sglang.srt.hardware_backend.npu.attention.ascend_torch_native_backend import (\n    AscendTorchNativeAttnBackend,\n)\nfrom sglang.srt.hardware_backend.npu.attention.mla_preprocess import (\n    is_fia_nz,\n    is_mla_preprocess_enabled,\n)\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp\nfrom sglang.srt.layers.radix_attention import AttentionType\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.speculative.spec_info import SpecInput\nfrom sglang.srt.utils import get_bool_env_var\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n\nimport logging\n\nimport numpy as np\n\n\ndef _reshape_kv_for_fia_nz(\n    tensor: torch.Tensor, num_heads: int, head_dim: int, page_size: int\n) -> torch.Tensor:\n    \"\"\"Reshapes a tensor for FIA NZ format.\"\"\"\n    return tensor.view(-1, 1, num_heads * head_dim // 16, page_size, 16)\n\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass ForwardMetadata:\n\n    # calculated map for kv positions [bs * maxseqlen]\n    block_tables: Optional[torch.Tensor] = None\n\n    # seq len inputs\n    extend_seq_lens_cpu_int: Optional[torch.Tensor] = None\n    seq_lens_cpu_int: Optional[torch.Tensor] = None\n    seq_lens_cpu_list: Optional[List[int]] = None\n    seq_lens_list_cumsum: Optional[List[int]] = None\n    seq_lens: Optional[torch.Tensor] = None\n    
actual_seq_lengths_q: Optional[torch.Tensor] = None\n    actual_seq_lengths_kv: Optional[torch.Tensor] = None\n\n    # prefix cache\n    prefix_lens: Optional[torch.Tensor] = None\n    flatten_prefix_block_tables: Optional[torch.Tensor] = None\n\n\nclass AscendAttnMaskBuilder:\n    def __init__(self, model_runner: ModelRunner, device, use_fia, use_mla):\n        \"\"\"\n        Initialize the AscendAttnMaskBuilder class.\n\n        :param model_runner: ModelRunner instance for model execution.\n        :param device: Device to run the model on (e.g., 'cuda', 'npu').\n        :param use_fia: Boolean flag to indicate if environment variable ASCEND_USE_FIA is set to 1.\n        :param use_mla: Boolean flag to indicate if the model uses MLA; if True, the RingMLA mask is also built.\n        \"\"\"\n        self.use_fia = use_fia\n        self.model_runner = model_runner\n        self.device = device\n\n        # Initialize mask\n        mask_len = 128\n        self.mask = self.generate_attn_mask(mask_len, \"norm\", model_runner.dtype).to(\n            self.device\n        )\n\n        # Initialize FIA mask\n        fia_mask_len = 2048\n        self.fia_mask = self.generate_mask_flag(fia_mask_len).to(self.device)\n\n        # Initialize MTP mask\n        mtp_mask_len = 2048\n        self.mtp_mask = self.generate_mask_flag(mtp_mask_len).to(self.device)\n\n        # Initialize mixed chunk mask cache\n        mixed_mask_len = 2048\n        self.mixed_chunk_attn_mask = self.get_splitfuse_attn_mask(mixed_mask_len)\n\n        if use_mla:\n            # Initialize RingMla mask\n            ringmla_mask_len = 512\n            self.ringmla_mask = self.generate_attn_mask(\n                ringmla_mask_len, \"norm\", torch.bfloat16\n            ).to(self.device)\n\n    @staticmethod\n    def generate_mask_flag(max_seq_len):\n        \"\"\"\n        Generate a mask flag for attention masks.\n\n        :param max_seq_len: Maximum sequence length for the mask.\n        :return: A boolean tensor representing the mask flag.\n        \"\"\"\n        # Construct lower triangle matrix.\n        
mask_flag = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool).tril_()\n        # Create upper triangle matrix used to mark mask positions.\n        mask_flag = ~mask_flag\n        return mask_flag\n\n    @staticmethod\n    def generate_attn_mask(max_seq_len, mode, dtype=torch.float16):\n        \"\"\"\n        Generate an attention mask.\n\n        :param max_seq_len: Maximum sequence length for the mask.\n        :param mode: Mode of the mask ('mix' or 'norm').\n        :param dtype: Data type of the mask tensor.\n        :return: A tensor representing the attention mask.\n        \"\"\"\n        mask_flag = AscendAttnMaskBuilder.generate_mask_flag(max_seq_len)\n        if mode == \"mix\":\n            mask_value = (\n                float(\"-inf\") if dtype in [torch.float16, torch.bfloat16] else 1\n            )\n        else:\n            mask_value = torch.finfo(torch.float32).min if dtype == torch.float16 else 1\n        attn_mask = (\n            torch.zeros(size=(max_seq_len, max_seq_len))\n            .masked_fill_(mask_flag, mask_value)\n            .to(dtype)\n        )\n        return attn_mask\n\n    @staticmethod\n    def get_attention_mask_id(seq_lens, extend_lens):\n        \"\"\"\n        Generate attention mask IDs based on sequence lengths and extended lengths.\n\n        :param seq_lens: Sequence lengths.\n        :param extend_lens: Extended lengths.\n        :return: A tensor containing the attention mask IDs.\n        \"\"\"\n        starts = seq_lens - extend_lens\n        ends = seq_lens\n\n        # Use torch.stack to stack the start and end indices together\n        ranges = torch.stack((starts, ends), dim=-1)\n\n        # Use list comprehension to generate tensors for each range and concatenate them\n        attn_mask_id = torch.cat([torch.arange(start, end) for start, end in ranges])\n        return attn_mask_id\n\n    def update_attn_cache(\n        self,\n        seqlen: int,\n        mask_cache: torch.Tensor,\n        
seq_len_cached: int,\n        dtype: torch.dtype,\n        mode,\n    ):\n        \"\"\"\n        Update the attention mask cache.\n\n        :param seqlen: Maximum sequence length.\n        :param mask_cache: Current attention mask cache.\n        :param seq_len_cached: Cached sequence length.\n        :param dtype: Data type of the mask tensor.\n        :param mode: Mode of the mask ('mix' or 'norm').\n        :return: Updated mask cache and sequence length cache.\n        \"\"\"\n        if seqlen > seq_len_cached:\n            seq_len_cached = seqlen\n            mask_cache = self.generate_attn_mask(seqlen, mode, dtype)\n        if mask_cache.dtype != dtype:\n            mask_cache = mask_cache.to(dtype)\n        return mask_cache, seq_len_cached\n\n    def get_splitfuse_attn_mask(\n        self,\n        seq_lens: torch.Tensor = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Generate a splitfuse attention mask.\n\n        :param seq_lens: Sequence lengths.\n        :return: A tensor representing the splitfuse attention mask.\n        \"\"\"\n        attn_mask = (\n            torch.triu(torch.ones(seq_lens, seq_lens), diagonal=1)\n            .to(torch.int8)\n            .to(self.device)\n        )\n        return attn_mask\n\n\nclass AscendAttnBackend(AttentionBackend):\n\n    def __init__(self, model_runner: ModelRunner):\n        super().__init__()\n        self.forward_metadata = None\n        self.device = model_runner.device\n        self.page_size = model_runner.page_size\n        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA\n        if self.use_mla:\n            self.kv_lora_rank = model_runner.model_config.kv_lora_rank\n            self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim\n            if (\n                \"MiniCPM3ForCausalLM\"\n                in model_runner.model_config.hf_config.architectures\n            ):\n                self.qk_nope_head_dim = (\n                    
model_runner.model_config.hf_config.qk_nope_head_dim\n                )\n            else:\n                self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim\n            self.q_head_dim = self.qk_rope_head_dim + self.qk_nope_head_dim\n        else:\n            self.use_alibi = getattr(model_runner.model_config, \"use_alibi\", False)\n            if (\n                \"Gemma2ForSequenceClassification\"\n                in model_runner.model_config.hf_config.architectures\n            ):\n                self.use_native_sdpa = True\n        self.native_attn = AscendTorchNativeAttnBackend()\n        self.graph_metadata = {}\n        self.max_context_len = model_runner.model_config.context_len\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.graph_mode = False\n        self.use_fia = get_bool_env_var(\"ASCEND_USE_FIA\", \"False\")\n        self.enable_torch_compile = model_runner.server_args.enable_torch_compile\n        self.speculative_num_draft_tokens = (\n            model_runner.server_args.speculative_num_draft_tokens\n        )\n        self.ascend_attn_mask_builder = AscendAttnMaskBuilder(\n            model_runner, self.device, self.use_fia, self.use_mla\n        )\n        self.mask, self.fia_mask, self.mtp_mask, self.mix_mask = (\n            self.ascend_attn_mask_builder.mask,\n            self.ascend_attn_mask_builder.fia_mask,\n            self.ascend_attn_mask_builder.mtp_mask,\n            self.ascend_attn_mask_builder.mixed_chunk_attn_mask,\n        )\n        if self.use_mla:\n            self.ringmla_mask = self.ascend_attn_mask_builder.ringmla_mask\n\n        # dllm model config\n        self.dllm_config = DllmConfig.from_server_args(model_runner.server_args)\n        self.is_dllm_model = False\n        if self.dllm_config is not None:\n            self.is_dllm_model = True\n            self.dllm_block_size = self.dllm_config.block_size\n\n    def 
get_verify_buffers_to_fill_after_draft(self):\n        \"\"\"\n        Return buffers for verify attention kernels that need to be filled after draft.\n\n        Typically, these are tree mask and position buffers.\n        \"\"\"\n        return [None, None]\n\n    def update_verify_buffers_to_fill_after_draft(\n        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]\n    ):\n        pass\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Init the metadata for a forward pass.\"\"\"\n        self.forward_metadata = ForwardMetadata()\n        seq_lens_max = forward_batch.seq_lens.max()\n        if forward_batch.forward_mode.is_target_verify():\n            seq_lens_max += self.speculative_num_draft_tokens\n        self.forward_metadata.block_tables = (\n            forward_batch.req_to_token_pool.req_to_token[\n                forward_batch.req_pool_indices, :seq_lens_max\n            ][:, :: self.page_size]\n            // self.page_size\n        )\n        if forward_batch.extend_seq_lens is not None:\n            self.forward_metadata.extend_seq_lens = forward_batch.extend_seq_lens\n            self.forward_metadata.extend_seq_lens_cpu_int = (\n                forward_batch.extend_seq_lens.cpu().int()\n            )\n        if forward_batch.seq_lens is not None:\n            self.forward_metadata.seq_lens = forward_batch.seq_lens.int()\n        else:\n            self.forward_metadata.seq_lens = forward_batch.seq_lens_cpu.to(\n                self.device\n            ).int()\n\n        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()\n        if (\n            not forward_batch.forward_mode.is_draft_extend_v2()\n            and not forward_batch.forward_mode.is_draft_extend()\n            and not forward_batch.forward_mode.is_target_verify()\n        ):\n            seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)\n            self.forward_metadata.seq_lens_list_cumsum = 
seq_lens_list_cumsum\n\n        if forward_batch.forward_mode.is_target_verify():\n            self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens\n\n        if (\n            self.use_mla\n            and forward_batch.forward_mode.is_extend()\n            and not forward_batch.forward_mode.is_draft_extend(include_v2=True)\n            and not forward_batch.forward_mode.is_target_verify()\n            and sum(forward_batch.extend_prefix_lens_cpu) > 0\n        ):\n            self.forward_metadata.prefix_lens = forward_batch.extend_prefix_lens.to(\n                \"cpu\"\n            )\n            seq_prefix_lens = self.forward_metadata.prefix_lens.tolist()\n            self.forward_metadata.flatten_prefix_block_tables = torch.empty(\n                0, dtype=torch.int32\n            ).to(self.device)\n            for req_idx, seq_len in zip(\n                forward_batch.req_pool_indices.tolist(), seq_prefix_lens\n            ):\n                req_indices = forward_batch.req_to_token_pool.req_to_token[req_idx]\n                req_prefix_block_tables = (\n                    req_indices[:seq_len][:: self.page_size] // self.page_size\n                )\n                self.forward_metadata.flatten_prefix_block_tables = torch.cat(\n                    (\n                        self.forward_metadata.flatten_prefix_block_tables,\n                        torch.flatten(req_prefix_block_tables),\n                    )\n                )\n\n        self.graph_mode = False\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        self.graph_metadata = {\n            \"block_tables\": torch.empty(\n                (max_bs, (self.max_context_len + self.page_size - 1) // self.page_size),\n                dtype=torch.int32,\n                device=self.device,\n            ),\n        }\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        
req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        metadata = ForwardMetadata()\n\n        metadata.block_tables = self.graph_metadata[\"block_tables\"][:bs, :]\n        if self.is_dllm_model:\n            max_len = int(seq_lens[:bs].max().item())\n            max_seq_pages = (max_len + self.page_size - 1) // self.page_size\n            metadata.block_tables[:bs, :max_seq_pages].copy_(\n                (\n                    self.req_to_token[req_pool_indices[:bs], :max_len][\n                        :, :: self.page_size\n                    ]\n                    // self.page_size\n                ).to(torch.int32)\n            )\n            metadata.block_tables[:bs, max_seq_pages:].fill_(0)\n            metadata.block_tables[bs:, :].fill_(0)\n\n        metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()\n        metadata.seq_lens = seq_lens\n        if (\n            forward_mode.is_target_verify()\n            or forward_mode.is_draft_extend_v2()\n            or forward_mode.is_draft_extend()\n        ):\n            metadata.actual_seq_lengths_q = torch.arange(\n                self.speculative_num_draft_tokens,\n                self.speculative_num_draft_tokens\n                + bs * self.speculative_num_draft_tokens,\n                self.speculative_num_draft_tokens,\n                dtype=torch.int32,\n                device=seq_lens.device,\n            )\n        else:\n            metadata.actual_seq_lengths_q = torch.tensor(\n                [1 + i * 1 for i in range(bs)],\n                dtype=torch.int32,\n                device=seq_lens.device,\n            )\n        if forward_mode.is_dllm_extend():\n            extend_seq_lens_cpu_int = torch.tensor(\n                [self.dllm_block_size for i in range(bs)],\n                dtype=torch.int32,\n                
device=seq_lens.device,\n            )\n            metadata.seq_lens_list_cumsum = (\n                torch.cumsum(extend_seq_lens_cpu_int, dim=0).int().tolist()\n            )\n\n        self.graph_metadata[bs] = metadata\n        self.forward_metadata = metadata\n\n        self.graph_mode = True\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        metadata = self.graph_metadata[bs]\n        max_len = seq_lens_cpu[:bs].max().item()\n        if forward_mode.is_target_verify():\n            max_len += self.speculative_num_draft_tokens\n        max_seq_pages = (max_len + self.page_size - 1) // self.page_size\n\n        metadata.block_tables[:bs, :max_seq_pages].copy_(\n            self.req_to_token[req_pool_indices[:bs], :max_len][:, :: self.page_size]\n            // self.page_size\n        )\n\n        metadata.block_tables[:bs, max_seq_pages:].fill_(0)\n        metadata.block_tables[bs:, :].fill_(0)\n\n        if forward_mode.is_target_verify():\n            seq_lens = seq_lens + self.speculative_num_draft_tokens\n        metadata.seq_lens[:bs].copy_(seq_lens[:bs])\n\n        self.forward_metadata = metadata\n\n        self.graph_mode = True\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        return 0\n\n    def _generate_alibi_bias(\n        self,\n        seq_len: int,\n        slopes: torch.Tensor,\n        num_heads: int,\n        device: torch.device,\n        dtype: torch.dtype = torch.bfloat16,\n    ) -> torch.Tensor:\n        position_point = (\n            torch.arange(seq_len).view(1, 1, -1).expand(num_heads, -1, -1).to(device)\n        )\n        alibi = slopes.view(-1, 1, 1) * position_point\n        alibi_bias = 
alibi.view(num_heads, 1, seq_len).to(device).to(dtype)\n        return alibi_bias\n\n    def generate_alibi_bias(\n        self,\n        q_seq_len: int,\n        kv_seq_len: int,\n        slopes: torch.Tensor,\n        num_heads: int,\n        device: torch.device,\n        is_extend: bool = True,\n        dtype: torch.dtype = torch.bfloat16,\n    ) -> torch.Tensor:\n        MAX_LEN_ALB = 5000\n        max_seq_len = max(kv_seq_len, q_seq_len, MAX_LEN_ALB)\n        if getattr(self, \"alibi_bias\", None) is None:\n            self.alibi_bias = self._generate_alibi_bias(\n                max_seq_len, slopes, num_heads, device, dtype\n            )\n\n        if getattr(self, \"super_mask\", None) is None:\n            super_mask = torch.ones(size=(1, max_seq_len, max_seq_len), dtype=dtype)\n            super_mask = super_mask.float().fill_(float(\"-inf\")).type_as(super_mask)\n            super_mask = torch.triu(super_mask, 1).to(device)\n            self.super_mask = super_mask\n        if is_extend:\n            return (\n                self.alibi_bias[:, :q_seq_len, :kv_seq_len]\n                + self.super_mask[:, :q_seq_len, :kv_seq_len]\n            )\n        else:\n            return self.alibi_bias[:, :q_seq_len, :kv_seq_len]\n\n    def attn_alibi(\n        self,\n        q,\n        k_cache,\n        v_cache,\n        block_tables,\n        seq_lens,\n        query_lens,\n        scale_value,\n        num_heads,\n        slopes,\n        is_extend,\n    ):\n        curr = 0\n        num_prompts = query_lens.shape[0]\n        head_size = k_cache.shape[3]\n        head_size_v = v_cache.shape[3]\n        block_size = k_cache.shape[1]\n        attn_output = []\n        for i in range(num_prompts):\n            seq_len = seq_lens[i].item()\n            block_table = block_tables[i]\n\n            j = torch.arange(seq_len, device=block_table.device)\n\n            block_number = block_table[j // block_size]\n            block_offset = j % block_size\n\n         
   k = k_cache[block_number, block_offset]\n            v = v_cache[block_number, block_offset]\n            k = k.view(seq_len, num_heads, head_size)\n            v = v.view(seq_len, num_heads, head_size_v)\n\n            if is_extend:\n                q_len = query_lens[i].item()\n                query = q[curr : curr + q_len]\n            else:\n                q_len = 1\n                query = q[curr : curr + 1]\n\n            query = query.to(torch.float32)\n            query = query * scale_value\n            query = query.permute(1, 0, 2)\n            k = k.permute(1, 2, 0)\n\n            score = torch.bmm(query, k)\n            score = score.to(torch.float32)\n            if slopes is not None:\n                alibi_bias = self.generate_alibi_bias(\n                    q_seq_len=q_len,\n                    kv_seq_len=seq_len,\n                    slopes=slopes,\n                    num_heads=num_heads,\n                    device=q.device,\n                    is_extend=is_extend,\n                    dtype=query.dtype,\n                )\n                score = score + alibi_bias\n            score = torch.max(score, torch.tensor(torch.finfo(score.dtype).min))\n            p = torch.nn.functional.softmax(score, dim=-1)\n            v = v.permute(1, 0, 2)\n            out = torch.bmm(p, v)\n            out = out.permute(1, 0, 2)\n            out = out.reshape(-1, num_heads * head_size_v)\n            attn_output.append(out)\n            curr += q_len\n        attn_output = torch.cat(attn_output, dim=0).to(q.dtype).to(q.device)\n        attn_output = attn_output.view(-1, num_heads * head_size)\n        return attn_output\n\n    def do_cp_balance_attn(\n        self,\n        q_nope,\n        k_nope,\n        q_pe,\n        k_pe,\n        topk_indices,\n        layer,\n        actual_seq_qlen,\n        actual_seq_lengths_kv,\n    ):\n        seq_len = q_nope.shape[0]\n        split_len = (seq_len + 1) // 2\n        q_nope_prev, q_nope_next = 
torch.split(q_nope, split_len, dim=0)\n        q_rope_prev, q_rope_next = torch.split(q_pe, split_len, dim=0)\n        q_nope_prev = q_nope_prev.contiguous()\n        q_nope_next = q_nope_next.contiguous()\n        q_rope_prev = q_rope_prev.contiguous()\n        q_rope_next = q_rope_next.contiguous()\n        topk_indices_prev, topk_indices_next = topk_indices\n\n        actual_seq_qlen_prev, actual_seq_qlen_next = actual_seq_qlen\n        actual_seq_lengths_kv_prev, actual_seq_lengths_kv_next = actual_seq_lengths_kv\n\n        attn_out_prev, _, _ = torch_npu.npu_sparse_flash_attention(\n            query=q_nope_prev,\n            key=k_nope,\n            value=k_nope,\n            query_rope=q_rope_prev,\n            key_rope=k_pe,\n            sparse_indices=topk_indices_prev,\n            scale_value=layer.scaling,\n            actual_seq_lengths_query=actual_seq_qlen_prev.to(\n                device=q_nope.device, dtype=torch.int32\n            ),\n            actual_seq_lengths_kv=actual_seq_lengths_kv_prev.to(\n                device=q_nope.device, dtype=torch.int32\n            ),\n            block_table=self.forward_metadata.block_tables,\n            sparse_block_size=1,\n            layout_query=\"TND\",\n            layout_kv=\"PA_BSND\",\n            sparse_mode=3,\n            attention_mode=2,\n            return_softmax_lse=False,\n        )\n        attn_out_next, _, _ = torch_npu.npu_sparse_flash_attention(\n            query=q_nope_next,\n            key=k_nope,\n            value=k_nope,\n            query_rope=q_rope_next,\n            key_rope=k_pe,\n            sparse_indices=topk_indices_next,\n            scale_value=layer.scaling,\n            actual_seq_lengths_query=actual_seq_qlen_next.to(\n                device=q_nope.device, dtype=torch.int32\n            ),\n            actual_seq_lengths_kv=actual_seq_lengths_kv_next.to(\n                device=q_nope.device, dtype=torch.int32\n            ),\n            
block_table=self.forward_metadata.block_tables,\n            sparse_block_size=1,\n            layout_query=\"TND\",\n            layout_kv=\"PA_BSND\",\n            sparse_mode=3,\n            attention_mode=2,\n            return_softmax_lse=False,\n        )\n        return torch.cat([attn_out_prev, attn_out_next], dim=0)\n\n    def forward_sparse(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        # For multi_head latent attention\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        topk_indices: torch.Tensor = None,\n    ):\n\n        is_prefill = (\n            forward_batch.forward_mode.is_extend()\n            and not forward_batch.forward_mode.is_draft_extend_v2()\n            and not forward_batch.forward_mode.is_draft_extend()\n            and not forward_batch.forward_mode.is_target_verify()\n        )\n\n        if save_kv_cache:\n            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)\n            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)\n            forward_batch.token_to_kv_pool.set_kv_buffer(\n                layer, forward_batch.out_cache_loc, k, k_rope\n            )\n        q_nope, q_pe = q, q_rope\n        k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)\n\n        if is_prefill:\n            if self.forward_metadata.actual_seq_lengths_q is not None:\n                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q\n            else:\n                actual_seq_qlen = torch.cumsum(forward_batch.extend_seq_lens, dim=0)\n        else:\n            if self.forward_metadata.actual_seq_lengths_q is None:\n                if (\n                    forward_batch.forward_mode.is_draft_extend_v2()\n                    or forward_batch.forward_mode.is_target_verify()\n   
             ):\n                    actual_seq_qlen = (\n                        torch.arange(\n                            self.speculative_num_draft_tokens,\n                            self.speculative_num_draft_tokens + q.shape[0],\n                            self.speculative_num_draft_tokens,\n                            dtype=torch.int32,\n                        )\n                        .to(q.device)\n                        .to(torch.int32)\n                    )\n                elif forward_batch.forward_mode.is_draft_extend():\n                    actual_seq_qlen = (\n                        forward_batch.extend_seq_lens.cumsum(dim=0)\n                        .to(q.device)\n                        .to(torch.int32)\n                    )\n                else:\n                    actual_seq_qlen = (\n                        torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)\n                    )\n            else:\n                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q\n\n        if self.forward_metadata.actual_seq_lengths_kv is not None:\n            actual_seq_lengths_kv = self.forward_metadata.actual_seq_lengths_kv\n        elif self.forward_metadata.seq_lens_cpu_int is not None:\n            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int\n        else:\n            actual_seq_lengths_kv = self.forward_metadata.seq_lens\n\n        if (\n            is_prefill\n            and is_nsa_enable_prefill_cp()\n            and forward_batch.nsa_cp_metadata is not None\n        ):\n            attn_out = self.do_cp_balance_attn(\n                q_nope,\n                k_nope,\n                q_pe,\n                k_pe,\n                topk_indices,\n                layer,\n                actual_seq_qlen,\n                actual_seq_lengths_kv,\n            )\n        else:\n            attn_out, _, _ = torch_npu.npu_sparse_flash_attention(\n                query=q_nope,\n                key=k_nope,\n        
        value=k_nope,\n                query_rope=q_pe,\n                key_rope=k_pe,\n                sparse_indices=topk_indices,\n                scale_value=layer.scaling,\n                actual_seq_lengths_query=actual_seq_qlen.to(\n                    device=q_nope.device, dtype=torch.int32\n                ),\n                actual_seq_lengths_kv=actual_seq_lengths_kv.to(\n                    device=q_nope.device, dtype=torch.int32\n                ),\n                block_table=self.forward_metadata.block_tables,\n                sparse_block_size=1,\n                layout_query=\"TND\",\n                layout_kv=\"PA_BSND\",\n                sparse_mode=3,\n                attention_mode=2,\n                return_softmax_lse=False,\n            )\n\n        return attn_out\n\n    def forward_extend(\n        self,\n        q,\n        k,\n        v,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        # For multi_head latent attention\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        topk_indices: Optional[torch.Tensor] = None,\n        sinks: Optional[torch.Tensor] = None,\n        slopes: Optional[torch.Tensor] = None,\n    ):\n        if is_mla_preprocess_enabled():\n            # MLAPO and MLAPROLOG do save kv_cache\n            save_kv_cache = False\n        if self.is_dllm_model:\n            return self.forward_dllm(\n                q,\n                k,\n                v,\n                layer,\n                forward_batch,\n                save_kv_cache,\n                q_rope=q_rope,\n                k_rope=k_rope,\n            )\n        if topk_indices is not None:\n            return self.forward_sparse(\n                q,\n                k,\n                v,\n                layer,\n                forward_batch,\n                save_kv_cache,\n                q_rope,\n                k_rope,\n              
  topk_indices,\n            )\n        if (\n            forward_batch.forward_mode.is_target_verify()\n            or forward_batch.forward_mode.is_draft_extend()\n            or forward_batch.forward_mode.is_draft_extend_v2()\n        ):\n            return self.forward_mtp(\n                q,\n                k,\n                v,\n                layer,\n                forward_batch,\n                save_kv_cache,\n                q_rope=q_rope,\n                k_rope=k_rope,\n            )\n\n        if not self.use_mla:\n            # In a cross-attention layer with no vision input, k and v are None\n            if save_kv_cache and k is not None and v is not None:\n                # support cross attention\n                cache_loc = (\n                    forward_batch.out_cache_loc\n                    if not layer.is_cross_attention\n                    else forward_batch.encoder_out_cache_loc\n                )\n                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)\n\n            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)\n\n            if sinks is not None:\n                attn_out = attention_sinks_prefill_triton(\n                    q,\n                    k_cache,\n                    v_cache,\n                    sinks,\n                    self.forward_metadata.extend_seq_lens,\n                    self.forward_metadata.block_tables,\n                    self.forward_metadata.seq_lens,\n                    layer.scaling,\n                    layer.sliding_window_size,\n                    layer.tp_q_head_num,\n                    layer.tp_k_head_num,\n                )\n                return attn_out\n\n            if self.use_fia:\n                \"\"\"FIA will support multi-bs in a later version of CANN\"\"\"\n                q = q.reshape(-1, layer.tp_q_head_num, 
layer.qk_head_dim)\n                attn_output = torch.empty(\n                    (q.size(0), layer.tp_q_head_num, layer.v_head_dim),\n                    device=q.device,\n                    dtype=q.dtype,\n                )\n                q_len_offset = 0\n                for q_len in forward_batch.extend_seq_lens_cpu:\n                    attn_output[q_len_offset : q_len_offset + q_len] = (\n                        torch.ops.npu.npu_fused_infer_attention_score(\n                            q[None, q_len_offset : q_len_offset + q_len],\n                            k[None, q_len_offset : q_len_offset + q_len],\n                            v[None, q_len_offset : q_len_offset + q_len],\n                            num_heads=layer.tp_q_head_num,\n                            num_key_value_heads=layer.tp_k_head_num,\n                            input_layout=\"BSND\",  # TODO: the TND layout does not support q_heads != k_heads\n                            atten_mask=self.fia_mask.unsqueeze(0),\n                            sparse_mode=3 if q_len != 1 else 0,\n                            scale=layer.scaling,\n                            next_tokens=0,\n                        )[0]\n                    )\n                    q_len_offset += q_len\n                attn_output = attn_output.view(\n                    -1, layer.tp_q_head_num * layer.v_head_dim\n                )\n\n            else:\n                causal = True\n                if (\n                    layer.is_cross_attention\n                    or layer.attn_type == AttentionType.ENCODER_ONLY\n                ):\n                    causal = False\n                # torch_npu._npu_flash_attention_qlens has accuracy issues in the cross-attention scenario.\n                # forward_batch.encoder_lens is not None for cross attention, so we fall back to native attention there.\n                # The skywork-reward-gemma2-2-27B model also shows precision anomalies, so the torch native backend is the better choice for it as well.\n                if (\n                    layer.qk_head_dim <= 128\n                    and causal\n                    and forward_batch.encoder_lens is None\n                    and layer.logit_cap == 0\n                    and not getattr(self, \"use_native_sdpa\", False)\n                ):\n                    if not self.use_alibi:\n                        query = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)\n                        attn_output = torch.empty(\n                            (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),\n                            dtype=query.dtype,\n                            device=query.device,\n                        )\n                        torch_npu._npu_flash_attention_qlens(\n                            query=query,\n                            key_cache=k_cache,\n                            value_cache=v_cache,\n                            mask=self.mask,\n                            block_table=self.forward_metadata.block_tables,\n                            seq_len=self.forward_metadata.extend_seq_lens_cpu_int,\n                            context_lens=self.forward_metadata.seq_lens_cpu_int,\n                            scale_value=layer.scaling,\n                            num_heads=layer.tp_q_head_num,\n                            num_kv_heads=layer.tp_k_head_num,\n                            out=attn_output,\n                        )\n                    else:\n                        attn_output = self.attn_alibi(\n                            q=q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim),\n                            k_cache=k_cache,\n                            v_cache=v_cache,\n                            block_tables=self.forward_metadata.block_tables,\n                            seq_lens=self.forward_metadata.seq_lens_cpu_int,\n                            query_lens=self.forward_metadata.extend_seq_lens_cpu_int,\n                            
scale_value=layer.scaling,\n                            num_heads=layer.tp_q_head_num,\n                            slopes=slopes,\n                            is_extend=True,\n                        )\n                else:\n                    if layer.qk_head_dim != layer.v_head_dim:\n                        attn_output = q.new_empty(\n                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)\n                        )\n                    else:\n                        attn_output = torch.empty_like(q)\n\n                    use_gqa = layer.tp_q_head_num != layer.tp_k_head_num\n\n                    q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)\n                    o_ = attn_output.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n\n                    # add forward_batch.encoder_lens and is_cross_attention arguments for cross attention scene\n                    attn_output = self.native_attn.run_sdpa_forward_extend(\n                        q_,\n                        o_,\n                        k_cache.view(-1, layer.tp_k_head_num, layer.qk_head_dim),\n                        v_cache.view(-1, layer.tp_v_head_num, layer.v_head_dim),\n                        forward_batch.req_to_token_pool.req_to_token,\n                        forward_batch.req_pool_indices,\n                        forward_batch.seq_lens,\n                        forward_batch.extend_prefix_lens,\n                        forward_batch.extend_seq_lens,\n                        forward_batch.encoder_lens,\n                        is_cross_attention=layer.is_cross_attention,\n                        scaling=layer.scaling,\n                        enable_gqa=use_gqa,\n                        causal=causal,\n                        logit_cap=layer.logit_cap,\n                        logit_capping_method=layer.logit_capping_method,\n                    )\n                    attn_output = attn_output.view(\n                        -1, layer.tp_q_head_num * 
layer.v_head_dim\n                    )\n        elif sum(forward_batch.extend_prefix_lens_cpu) > 0:\n            num_token_padding = q.shape[0]\n            q, k, v = [\n                data[: forward_batch.num_token_non_padded_cpu] for data in [q, k, v]\n            ]\n            q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)\n            k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)\n\n            # 1st, compute extend tokens to get attn_output and attn_lse\n            num_tokens = q_nope.size(0)\n            attn_output = torch.zeros(\n                num_tokens,\n                layer.tp_q_head_num,\n                layer.v_head_dim,\n                dtype=q_nope.dtype,\n                device=q_nope.device,\n            )\n            attn_lse = torch.zeros(\n                layer.tp_q_head_num,\n                num_tokens,\n                dtype=torch.float32,\n                device=q_nope.device,\n            )\n            torch_npu.atb.npu_ring_mla(\n                q_nope=q_nope,\n                q_rope=q_rope,\n                k_nope=k_nope,\n                k_rope=k_rope,\n                value=v,\n                mask=self.ringmla_mask,\n                seqlen=self.forward_metadata.extend_seq_lens_cpu_int,\n                head_num=layer.tp_q_head_num,\n                kv_head_num=layer.tp_k_head_num,\n                pre_out=None,\n                prev_lse=None,\n                qk_scale=layer.scaling,\n                kernel_type=\"kernel_type_high_precision\",\n                mask_type=\"mask_type_triu\",\n                calc_type=\"calc_type_first_ring\",\n                output=attn_output,\n                softmax_lse=attn_lse,\n            )\n\n            # 2nd, load history kvcache(kv_a and k_pe) and calculate k_nope\n            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n            v_buffer = 
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)\n            kv_cached = torch.index_select(\n                k_buffer, 0, self.forward_metadata.flatten_prefix_block_tables\n            )\n            k_rope_cached = torch.index_select(\n                v_buffer, 0, self.forward_metadata.flatten_prefix_block_tables\n            ).flatten(0, 1)\n\n            assert layer.kv_b_proj is not None\n            kv = layer.kv_b_proj(kv_cached)[0].view(\n                -1, layer.tp_k_head_num, self.qk_nope_head_dim + layer.v_head_dim\n            )\n            k_nope, v = kv.split([self.qk_nope_head_dim, layer.v_head_dim], dim=-1)\n\n            # 3rd, compute history kv to attn_out\n            k_rope = k_rope_cached.expand(-1, layer.tp_k_head_num, -1)\n            seq_len = torch.stack(\n                [\n                    self.forward_metadata.extend_seq_lens_cpu_int,\n                    self.forward_metadata.prefix_lens,\n                ]\n            )\n            torch_npu.atb.npu_ring_mla(\n                q_nope=q_nope,\n                q_rope=q_rope,\n                k_nope=k_nope,\n                k_rope=k_rope,\n                value=v,\n                mask=self.ringmla_mask,\n                seqlen=seq_len,\n                head_num=layer.tp_q_head_num,\n                kv_head_num=layer.tp_k_head_num,\n                pre_out=attn_output,\n                prev_lse=attn_lse,\n                qk_scale=layer.scaling,\n                kernel_type=\"kernel_type_high_precision\",\n                mask_type=\"no_mask\",\n                calc_type=\"calc_type_default\",\n                output=attn_output,\n                softmax_lse=attn_lse,\n            )\n            attn_output = attn_output.reshape(\n                [-1, layer.tp_q_head_num, layer.v_head_dim]\n            )\n            if num_token_padding != forward_batch.num_token_non_padded_cpu:\n                attn_output = torch.cat(\n                    [\n                       
 attn_output,\n                        attn_output.new_zeros(\n                            num_token_padding - attn_output.shape[0],\n                            *attn_output.shape[1:],\n                        ),\n                    ],\n                    dim=0,\n                )\n        else:\n            assert (\n                layer.qk_head_dim != layer.v_head_dim\n            ), \"FIA only supports qk_head_dim != v_head_dim\"\n            if layer.v_head_dim in [256]:\n                \"\"\"Currently, in NO_QUANT situation, qk_nope_head_dim == v_head_dim, and rope exists, v_head_dim only support 512 and 128\"\"\"\n                kv_lora_rank = k.shape[-1] - self.qk_rope_head_dim\n                kv_c, k_rope = k.split([kv_lora_rank, self.qk_rope_head_dim], dim=-1)\n                if save_kv_cache:\n                    forward_batch.token_to_kv_pool.set_kv_buffer(\n                        layer, forward_batch.out_cache_loc, kv_c, k_rope\n                    )\n                attn_output = q.new_empty(\n                    (q.shape[0], layer.tp_q_head_num, kv_lora_rank)\n                )\n                use_gqa = layer.tp_q_head_num != layer.tp_k_head_num\n\n                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(\n                    layer.layer_id\n                )\n                kv_cache = torch.cat([k_cache, v_cache], dim=-1)\n                attn_output = self.native_attn.run_sdpa_forward_extend(\n                    q,\n                    attn_output,\n                    kv_cache.view(-1, layer.tp_k_head_num, layer.qk_head_dim),\n                    k_cache.view(-1, layer.tp_v_head_num, layer.v_head_dim),\n                    forward_batch.req_to_token_pool.req_to_token,\n                    forward_batch.req_pool_indices,\n                    forward_batch.seq_lens,\n                    forward_batch.extend_prefix_lens,\n              
      forward_batch.extend_seq_lens,\n                    scaling=layer.scaling,\n                    enable_gqa=use_gqa,\n                    causal=True,\n                )\n            else:\n                num_token_padding = q.shape[0]\n                q, k, v = [\n                    data[: forward_batch.num_token_non_padded_cpu] for data in [q, k, v]\n                ]\n\n                q_nope, q_rope = q.split(\n                    [layer.v_head_dim, self.qk_rope_head_dim], dim=-1\n                )\n                k_nope, k_rope = k.split(\n                    [layer.v_head_dim, self.qk_rope_head_dim], dim=-1\n                )\n\n                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(\n                    q_nope,\n                    k_nope,\n                    v,\n                    query_rope=q_rope,\n                    key_rope=k_rope,\n                    num_heads=layer.tp_q_head_num,\n                    input_layout=\"TND\",\n                    atten_mask=self.fia_mask,\n                    sparse_mode=3,\n                    actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,\n                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,\n                    scale=layer.scaling,\n                    next_tokens=0,\n                )\n\n                attn_output = attn_output.reshape(\n                    -1, layer.tp_q_head_num, layer.v_head_dim\n                )\n                if num_token_padding != forward_batch.num_token_non_padded_cpu:\n                    attn_output = torch.cat(\n                        [\n                            attn_output,\n                            attn_output.new_zeros(\n                                num_token_padding - attn_output.shape[0],\n                                *attn_output.shape[1:],\n                            ),\n                        ],\n                        dim=0,\n                    )\n\n        return 
attn_output\n\n    def forward_dllm(\n        self,\n        q,\n        k,\n        v,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        # For multi_head latent attention\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        topk_indices: Optional[torch.Tensor] = None,\n    ):\n        if save_kv_cache:\n            forward_batch.token_to_kv_pool.set_kv_buffer(\n                layer, forward_batch.out_cache_loc, k, v\n            )\n\n        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)\n        query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)\n\n        if self.forward_metadata.seq_lens_cpu_int is None:\n            # capture\n            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_list\n        else:\n            # eagle\n            actual_seq_lengths_kv = (\n                self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()\n            )\n\n        if self.forward_metadata.extend_seq_lens_cpu_int is None:\n            # capture & replay\n            actual_seq_lengths = self.forward_metadata.seq_lens_list_cumsum\n        else:\n            actual_seq_lengths = (\n                torch.cumsum(self.forward_metadata.extend_seq_lens_cpu_int, dim=0)\n                .int()\n                .tolist()\n            )\n\n        attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(\n            query,\n            k_cache.view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim),\n            v_cache.view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim),\n            block_table=self.forward_metadata.block_tables,\n            block_size=self.page_size,\n            num_heads=layer.tp_q_head_num,\n            num_key_value_heads=layer.tp_k_head_num,\n            
input_layout=\"TND\",\n            atten_mask=None,\n            scale=layer.scaling,\n            actual_seq_lengths=actual_seq_lengths,\n            actual_seq_lengths_kv=actual_seq_lengths_kv,\n        )\n        attn_output = attn_output.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n\n        return attn_output\n\n    def forward_mtp(\n        self,\n        q,\n        k,\n        v,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool,\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n    ):\n        if save_kv_cache:\n            if self.use_mla:\n                k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)\n                k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer, forward_batch.out_cache_loc, k, k_rope\n                )\n            else:\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer, forward_batch.out_cache_loc, k, v\n                )\n\n        if not self.use_mla:\n            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(\n                layer.layer_id\n            ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)\n            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(\n                layer.layer_id\n            ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)\n            query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim).contiguous()\n            if not self.graph_mode:\n                num_token_padding = query.shape[0]\n                query = query[: forward_batch.num_token_non_padded_cpu]\n            if self.forward_metadata.seq_lens_cpu_int is None:\n                actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_list\n            else:\n                actual_seq_lengths_kv = (\n                
    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()\n                )\n            if forward_batch.forward_mode.is_draft_extend():\n                actual_seq_lengths = (\n                    np.array(forward_batch.extend_seq_lens_cpu).cumsum().tolist()\n                )\n            else:\n                actual_seq_lengths = np.arange(\n                    self.speculative_num_draft_tokens,\n                    self.speculative_num_draft_tokens + query.shape[0],\n                    self.speculative_num_draft_tokens,\n                )\n\n            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(\n                query,\n                k_cache,\n                v_cache,\n                block_table=self.forward_metadata.block_tables,\n                block_size=self.page_size,\n                num_heads=layer.tp_q_head_num,\n                num_key_value_heads=layer.tp_k_head_num,\n                input_layout=\"TND\",\n                atten_mask=self.mtp_mask,\n                scale=layer.scaling,\n                actual_seq_lengths=actual_seq_lengths,\n                actual_seq_lengths_kv=actual_seq_lengths_kv,\n                sparse_mode=3,\n            )\n            attn_output = attn_output.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n            if (\n                not self.graph_mode\n                and forward_batch.num_token_non_padded_cpu != num_token_padding\n            ):\n                attn_output = torch.cat(\n                    [\n                        attn_output,\n                        attn_output.new_zeros(\n                            num_token_padding - forward_batch.num_token_non_padded_cpu,\n                            *attn_output.shape[1:],\n                        ),\n                    ],\n                    dim=0,\n                )\n            return attn_output\n        else:\n            c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)\n            if 
is_fia_nz():\n                k_rope_cache = _reshape_kv_for_fia_nz(\n                    k_rope, layer.tp_k_head_num, self.qk_rope_head_dim, self.page_size\n                )\n                c_kv_cache = _reshape_kv_for_fia_nz(\n                    c_kv, layer.tp_v_head_num, self.kv_lora_rank, self.page_size\n                )\n            else:\n                k_rope_cache = k_rope.view(\n                    -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim\n                )\n                c_kv_cache = c_kv.view(\n                    -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank\n                )\n\n            q_nope = q.view(-1, layer.tp_q_head_num, self.kv_lora_rank).contiguous()\n            q_rope = q_rope.view(-1, layer.tp_q_head_num, self.qk_rope_head_dim)\n            if not self.graph_mode:\n                num_token_padding = q.shape[0]\n                q_nope = q_nope[: forward_batch.num_token_non_padded_cpu]\n                q_rope = q_rope[: forward_batch.num_token_non_padded_cpu]\n            if self.forward_metadata.seq_lens_cpu_int is None:\n                actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_list\n            else:\n                actual_seq_lengths_kv = (\n                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()\n                )\n            if forward_batch.forward_mode.is_draft_extend():\n                actual_seq_lengths = (\n                    np.array(forward_batch.extend_seq_lens_cpu).cumsum().tolist()\n                )\n            else:\n                actual_seq_lengths = np.arange(\n                    self.speculative_num_draft_tokens,\n                    self.speculative_num_draft_tokens + q_nope.shape[0],\n                    self.speculative_num_draft_tokens,\n                )\n\n            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(\n                q_nope,\n                c_kv_cache,\n                
c_kv_cache,\n                query_rope=q_rope,\n                key_rope=k_rope_cache,\n                num_heads=layer.tp_q_head_num,\n                num_key_value_heads=layer.tp_k_head_num,\n                input_layout=\"TND\",\n                scale=layer.scaling,\n                antiquant_mode=0,\n                antiquant_scale=None,\n                block_table=self.forward_metadata.block_tables,\n                block_size=self.page_size,\n                sparse_mode=3,\n                atten_mask=self.mtp_mask,\n                actual_seq_lengths=actual_seq_lengths,\n                actual_seq_lengths_kv=actual_seq_lengths_kv,\n            )\n            attn_output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)\n            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)\n            torch_npu.npu_fused_infer_attention_score.out(\n                q_nope,\n                c_kv_cache,\n                c_kv_cache,\n                query_rope=q_rope,\n                key_rope=k_rope_cache,\n                num_heads=layer.tp_q_head_num,\n                num_key_value_heads=layer.tp_k_head_num,\n                input_layout=\"TND\",\n                scale=layer.scaling,\n                antiquant_mode=0,\n                antiquant_scale=None,\n                block_table=self.forward_metadata.block_tables,\n                block_size=self.page_size,\n                sparse_mode=3,\n                atten_mask=self.mtp_mask,\n                actual_seq_lengths=actual_seq_lengths,\n                actual_seq_lengths_kv=actual_seq_lengths_kv,\n                workspace=workspace,\n                out=[attn_output, softmax_lse],\n            )\n            attn_output = attn_output.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n            if (\n                not self.graph_mode\n                and forward_batch.num_token_non_padded_cpu != num_token_padding\n            ):\n                attn_output = torch.cat(\n                 
   [\n                        attn_output,\n                        attn_output.new_zeros(\n                            num_token_padding - attn_output.shape[0],\n                            *attn_output.shape[1:],\n                        ),\n                    ],\n                    dim=0,\n                )\n            return attn_output\n\n    def forward_decode_graph(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        sinks: Optional[torch.Tensor] = None,\n    ):\n        if save_kv_cache:\n            if self.use_mla:\n                k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)\n                k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer, forward_batch.out_cache_loc, k, k_rope\n                )\n            else:\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer, forward_batch.out_cache_loc, k, v\n                )\n\n        if sinks is not None:\n            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)\n\n            attn_out = attention_sinks_triton(\n                q,\n                k_cache,\n                v_cache,\n                sinks,\n                self.forward_metadata.block_tables,\n                self.forward_metadata.seq_lens,\n                layer.scaling,\n                layer.sliding_window_size,\n                layer.tp_q_head_num,\n                layer.tp_k_head_num,\n            )\n            return attn_out\n\n        if not self.use_mla:\n            k_cache = 
forward_batch.token_to_kv_pool.get_key_buffer(\n                layer.layer_id\n            ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)\n            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(\n                layer.layer_id\n            ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)\n            query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)\n            if self.forward_metadata.seq_lens_cpu_int is None:\n                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list\n            else:\n                actual_seq_len_kv = (\n                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()\n                )\n            num_tokens = query.shape[0]\n            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(\n                query,\n                k_cache,\n                v_cache,\n                block_table=self.forward_metadata.block_tables,\n                block_size=self.page_size,\n                num_heads=layer.tp_q_head_num,\n                num_key_value_heads=layer.tp_k_head_num,\n                input_layout=\"BSH\",\n                scale=layer.scaling,\n                actual_seq_lengths_kv=actual_seq_len_kv,\n            )\n            output = torch.empty(\n                (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),\n                dtype=q.dtype,\n                device=q.device,\n            )\n            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)\n            torch_npu.npu_fused_infer_attention_score.out(\n                query,\n                k_cache,\n                v_cache,\n                block_table=self.forward_metadata.block_tables,\n                block_size=self.page_size,\n                num_heads=layer.tp_q_head_num,\n                num_key_value_heads=layer.tp_k_head_num,\n                input_layout=\"BSH\",\n                scale=layer.scaling,\n                
actual_seq_lengths_kv=actual_seq_len_kv,\n                workspace=workspace,\n                out=[output, softmax_lse],\n            )\n            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)\n        else:\n            c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)\n            if is_fia_nz():\n                k_rope_cache = _reshape_kv_for_fia_nz(\n                    k_rope, layer.tp_k_head_num, self.qk_rope_head_dim, self.page_size\n                )\n                c_kv_cache = _reshape_kv_for_fia_nz(\n                    c_kv, layer.tp_v_head_num, self.kv_lora_rank, self.page_size\n                )\n            else:\n                k_rope_cache = k_rope.view(\n                    -1, self.page_size, layer.tp_k_head_num * self.qk_rope_head_dim\n                )\n                c_kv_cache = c_kv.view(\n                    -1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank\n                )\n\n            q_nope = q.view(-1, 1, layer.tp_q_head_num, self.kv_lora_rank).contiguous()\n            q_rope = q_rope.view(-1, 1, layer.tp_q_head_num, self.qk_rope_head_dim)\n\n            if self.forward_metadata.seq_lens_cpu_int is None:\n                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list\n            else:\n                actual_seq_len_kv = (\n                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()\n                )\n\n            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(\n                q_nope,\n                c_kv_cache,\n                c_kv_cache,\n                query_rope=q_rope,\n                key_rope=k_rope_cache,\n                num_heads=layer.tp_q_head_num,\n                num_key_value_heads=layer.tp_k_head_num,\n                block_table=self.forward_metadata.block_tables,\n                block_size=self.page_size,\n                input_layout=\"BSND\",\n                
scale=layer.scaling,\n                actual_seq_lengths_kv=actual_seq_len_kv,\n                antiquant_mode=0,\n                antiquant_scale=None,\n                sparse_mode=0,\n            )\n            output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)\n            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)\n\n            torch_npu.npu_fused_infer_attention_score.out(\n                q_nope,\n                c_kv_cache,\n                c_kv_cache,\n                query_rope=q_rope,\n                key_rope=k_rope_cache,\n                num_heads=layer.tp_q_head_num,\n                num_key_value_heads=layer.tp_k_head_num,\n                block_table=self.forward_metadata.block_tables,\n                block_size=self.page_size,\n                input_layout=\"BSND\",\n                scale=layer.scaling,\n                actual_seq_lengths_kv=actual_seq_len_kv,\n                antiquant_mode=0,\n                antiquant_scale=None,\n                sparse_mode=0,\n                workspace=workspace,\n                out=[output, softmax_lse],\n            )\n            return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        # For multi-head latent attention\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        topk_indices: Optional[torch.Tensor] = None,\n        sinks: Optional[torch.Tensor] = None,\n        slopes: Optional[torch.Tensor] = None,\n    ):\n        if is_mla_preprocess_enabled():\n            # MLAPO does saving kv_cache\n            save_kv_cache = False\n        if topk_indices is not None:\n            return self.forward_sparse(\n                q,\n                k,\n                v,\n              
  layer,\n                forward_batch,\n                save_kv_cache,\n                q_rope,\n                k_rope,\n                topk_indices,\n            )\n\n        if self.graph_mode and (not self.enable_torch_compile):\n            return self.forward_decode_graph(\n                q,\n                k,\n                v,\n                layer,\n                forward_batch,\n                save_kv_cache,\n                q_rope=q_rope,\n                k_rope=k_rope,\n                sinks=sinks,\n            )\n\n        if not self.use_mla:\n            # In a cross-attention layer with no vision input, k and v are None\n            if save_kv_cache and k is not None and v is not None:\n                # support cross attention\n                cache_loc = (\n                    forward_batch.out_cache_loc\n                    if not layer.is_cross_attention\n                    else forward_batch.encoder_out_cache_loc\n                )\n                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)\n            num_tokens = q.shape[0]\n            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)\n\n            if sinks is not None:\n                attn_out = attention_sinks_triton(\n                    q,\n                    k_cache,\n                    v_cache,\n                    sinks,\n                    self.forward_metadata.block_tables,\n                    self.forward_metadata.seq_lens,\n                    layer.scaling,\n                    layer.sliding_window_size,\n                    layer.tp_q_head_num,\n                    layer.tp_k_head_num,\n                )\n                return attn_out\n\n            if self.use_fia:\n                if self.forward_metadata.seq_lens_cpu_int is None:\n                    actual_seq_len_kv = 
self.forward_metadata.seq_lens_cpu_list\n                else:\n                    actual_seq_len_kv = (\n                        self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()\n                    )\n                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(\n                    q.view(\n                        forward_batch.batch_size,\n                        -1,\n                        layer.tp_q_head_num,\n                        layer.qk_head_dim,\n                    ),\n                    k_cache.view(\n                        -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim\n                    ),\n                    v_cache.view(\n                        -1, self.page_size, layer.tp_v_head_num * layer.v_head_dim\n                    ),\n                    num_heads=layer.tp_q_head_num,\n                    num_key_value_heads=layer.tp_k_head_num,\n                    input_layout=\"BSND\",\n                    atten_mask=None,\n                    block_size=self.page_size,\n                    block_table=self.forward_metadata.block_tables,\n                    actual_seq_lengths_kv=actual_seq_len_kv,\n                    scale=layer.scaling,\n                )\n            # torch_npu._npu_flash_attention_qlens has accuracy issues in cross-attention scenarios.\n            # forward_batch.encoder_lens is not None for cross attention, so fall back to native attention there.\n            elif forward_batch.encoder_lens is None and layer.logit_cap == 0:\n                query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)\n                num_tokens = query.shape[0]\n                if not self.use_alibi:\n                    attn_output = torch.empty(\n                        (num_tokens, layer.tp_q_head_num, layer.v_head_dim),\n                        dtype=query.dtype,\n                        device=query.device,\n                    )\n\n                    
torch_npu._npu_paged_attention(\n                        query=query,\n                        key_cache=k_cache,\n                        value_cache=v_cache,\n                        num_heads=layer.tp_q_head_num,\n                        num_kv_heads=layer.tp_k_head_num,\n                        scale_value=layer.scaling,\n                        block_table=self.forward_metadata.block_tables,\n                        context_lens=self.forward_metadata.seq_lens_cpu_int,\n                        out=attn_output,\n                    )\n                else:\n                    attn_output = self.attn_alibi(\n                        q=query,\n                        k_cache=k_cache,\n                        v_cache=v_cache,\n                        block_tables=self.forward_metadata.block_tables,\n                        seq_lens=self.forward_metadata.seq_lens_cpu_int,\n                        query_lens=torch.ones(num_tokens, dtype=torch.int32),\n                        scale_value=layer.scaling,\n                        num_heads=layer.tp_q_head_num,\n                        slopes=slopes,\n                        is_extend=False,\n                    )\n            else:\n                if layer.qk_head_dim != layer.v_head_dim:\n                    attn_output = q.new_empty(\n                        (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)\n                    )\n                else:\n                    attn_output = torch.empty_like(q)\n\n                use_gqa = layer.tp_q_head_num != layer.tp_k_head_num\n\n                q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)\n                o_ = attn_output.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n\n                attn_output = self.native_attn.run_sdpa_forward_decode(\n                    q_,\n                    o_,\n                    k_cache.view(-1, layer.tp_k_head_num, layer.qk_head_dim),\n                    v_cache.view(-1, layer.tp_v_head_num, layer.v_head_dim),\n       
             forward_batch.req_to_token_pool.req_to_token,\n                    forward_batch.req_pool_indices,\n                    forward_batch.seq_lens,\n                    forward_batch.encoder_lens,\n                    is_cross_attention=layer.is_cross_attention,\n                    scaling=layer.scaling,\n                    enable_gqa=use_gqa,\n                    causal=False,\n                    logit_cap=layer.logit_cap,\n                    logit_capping_method=layer.logit_capping_method,\n                )\n            return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)\n        else:\n            if save_kv_cache:\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer, forward_batch.out_cache_loc, k, k_rope\n                )\n            num_tokens = q.shape[0]\n            kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n            k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)\n\n            if self.use_fia and (layer.tp_q_head_num // layer.tp_k_head_num) >= 8:\n                # Head ratios (tp_q_head_num // tp_k_head_num) below 8 will be supported in a later CANN version.\n                if is_fia_nz():\n                    kv_c = _reshape_kv_for_fia_nz(\n                        kv_c, layer.tp_k_head_num, self.kv_lora_rank, self.page_size\n                    )\n                    k_pe = _reshape_kv_for_fia_nz(\n                        k_pe, layer.tp_k_head_num, self.qk_rope_head_dim, self.page_size\n                    )\n                else:\n                    kv_c = kv_c.view(\n                        -1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank\n                    )\n                    k_pe = k_pe.view(\n                        -1, self.page_size, layer.tp_k_head_num * self.qk_rope_head_dim\n                    )\n                q = q.view(\n                    forward_batch.batch_size, -1, 
layer.tp_q_head_num, self.kv_lora_rank\n                )\n                q_rope = q_rope.view(\n                    forward_batch.batch_size,\n                    -1,\n                    layer.tp_q_head_num,\n                    self.qk_rope_head_dim,\n                )\n                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(\n                    q,\n                    kv_c,\n                    kv_c,\n                    query_rope=q_rope,\n                    key_rope=k_pe,\n                    num_heads=layer.tp_q_head_num,\n                    num_key_value_heads=layer.tp_k_head_num,\n                    input_layout=\"BSND\",\n                    atten_mask=None,\n                    sparse_mode=0,\n                    scale=layer.scaling,\n                    antiquant_mode=0,\n                    antiquant_scale=None,\n                    block_table=self.forward_metadata.block_tables,\n                    block_size=self.page_size,\n                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,\n                )\n            else:\n                assert (\n                    not self.graph_mode\n                )  # _npu_paged_attention_mla does not support graph mode\n                if q_rope is not None:\n                    q = torch.cat([q, q_rope], dim=-1)\n                query = q.view(-1, layer.tp_q_head_num, layer.head_dim)\n                kv_c_and_k_pe_cache = torch.cat([kv_c, k_pe], dim=-1)\n                kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(\n                    -1,\n                    self.page_size,\n                    layer.tp_k_head_num,\n                    self.kv_lora_rank + self.qk_rope_head_dim,\n                )\n                attn_output = torch.empty(\n                    [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],\n                    dtype=q.dtype,\n                    device=q.device,\n                )\n                
torch_npu._npu_paged_attention_mla(\n                    query=query,\n                    key_cache=kv_c_and_k_pe_cache,\n                    num_kv_heads=layer.tp_k_head_num,\n                    num_heads=layer.tp_q_head_num,\n                    scale_value=layer.scaling,\n                    block_table=self.forward_metadata.block_tables,\n                    context_lens=self.forward_metadata.seq_lens_cpu_int,\n                    mla_vheadsize=self.kv_lora_rank,\n                    out=attn_output,\n                )\n            return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)\n\n    def forward_mixed(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        topk_indices: Optional[torch.Tensor] = None,\n    ):\n        if (\n            topk_indices is not None\n            or self.use_mla\n            or (not self.use_fia and layer.qk_head_dim > 128)\n        ):\n            raise NotImplementedError(\n                \"The 'enable-mixed-chunk' feature is currently unsupported in the following scenarios: \"\n                \"1. When using the MLA backend on Ascend NPU devices, \"\n                \"2. When using the deepseekv3.2 model on Ascend NPU devices, \"\n                \"3. 
When the environment variable ASCEND_USE_FIA is set to 0 and qk_head_dim exceeds 128 on Ascend NPU devices.\"\n            )\n        if save_kv_cache:\n            forward_batch.token_to_kv_pool.set_kv_buffer(\n                layer, forward_batch.out_cache_loc, k, v\n            )\n        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)\n        num_block, block_size, _, _ = k_cache.shape\n        key = k_cache.view(num_block, block_size, -1)\n        value = v_cache.view(num_block, block_size, -1)\n\n        query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)\n\n        attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(\n            query,\n            key,\n            value,\n            num_heads=layer.tp_q_head_num,\n            num_key_value_heads=layer.tp_k_head_num,\n            input_layout=\"TND\",\n            block_size=block_size,\n            block_table=self.forward_metadata.block_tables,\n            atten_mask=self.mix_mask,\n            sparse_mode=3,\n            actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,\n            actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,\n            scale=layer.scaling,\n        )\n\n        return attn_output.view(\n            attn_output.shape[0], layer.tp_q_head_num * layer.v_head_dim\n        )\n\n\nclass AscendAttnMultiStepDraftBackend:\n    \"\"\"\n    Wrap multiple Ascend attention backends as one for multiple consecutive\n    draft decoding steps\n    \"\"\"\n\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        topk: int,\n        speculative_num_steps: int,\n    ):\n        self.topk = topk\n        self.speculative_num_steps = speculative_num_steps\n\n        self.attn_backends = []\n        for _ in range(self.speculative_num_steps):\n            self.attn_backends.append(AscendAttnBackend(model_runner))\n\n    def 
common_template(self, forward_batch: ForwardBatch, call_fn):\n        assert forward_batch.spec_info is not None\n\n        for i in range(self.speculative_num_steps - 1):\n            call_fn(i, forward_batch)\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        def call_fn(i, forward_batch):\n            assert forward_batch.spec_info is not None\n            self.attn_backends[i].init_forward_metadata(forward_batch)\n\n        self.common_template(forward_batch, call_fn)\n\n    def init_cuda_graph_state(self, max_bs, max_num_tokens):\n        for i in range(self.speculative_num_steps):\n            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)\n\n    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):\n        def call_fn(i, forward_batch):\n            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(\n                forward_batch.batch_size,\n                forward_batch.batch_size * self.topk,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                encoder_lens=None,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n            )\n\n        self.common_template(forward_batch, call_fn)\n\n    def init_forward_metadata_replay_cuda_graph(\n        self, forward_batch: ForwardBatch, bs: int\n    ):\n        def call_fn(i, forward_batch):\n            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(\n                bs,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                seq_lens_sum=-1,\n                encoder_lens=None,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n                seq_lens_cpu=forward_batch.seq_lens_cpu,\n            )\n\n        self.common_template(forward_batch, call_fn)\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/attention/ascend_torch_native_backend.py",
    "content": "from __future__ import annotations\n\nimport math\n\nimport torch\nfrom torch.nn.functional import scaled_dot_product_attention\n\n\nclass AscendTorchNativeAttnBackend:\n    def __init__(self):\n        pass\n\n    def scaled_dot_product_attention_with_softcapping(\n        self,\n        query,\n        key,\n        value,\n        attn_mask=None,\n        is_causal=False,\n        scale=None,\n        enable_gqa=False,\n        logit_cap=0.0,\n        logit_capping_method=\"tanh\",\n    ) -> torch.Tensor:\n        L, S = query.size(-2), key.size(-2)\n        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale\n        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)\n        if is_causal:\n            assert attn_mask is None\n            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(\n                diagonal=0\n            )\n            attn_bias.masked_fill_(temp_mask.logical_not(), float(\"-inf\"))\n            attn_bias.to(query.dtype)\n\n        if attn_mask is not None:\n            if attn_mask.dtype == torch.bool:\n                attn_bias.masked_fill_(attn_mask.logical_not(), float(\"-inf\"))\n            else:\n                attn_bias = attn_mask + attn_bias\n\n        if enable_gqa:\n            key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)\n            value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)\n\n        attn_weight = query @ key.transpose(-2, -1) * scale_factor\n\n        if logit_cap > 0:\n            if logit_capping_method == \"tanh\":\n                attn_weight = logit_cap * torch.tanh(attn_weight / logit_cap)\n\n        attn_weight += attn_bias\n        attn_weight = torch.softmax(attn_weight, dim=-1)\n        return attn_weight @ value\n\n    def run_sdpa_forward_extend(\n        self,\n        query: torch.Tensor,\n        output: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: 
torch.Tensor,\n        req_to_token: torch.Tensor,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        extend_prefix_lens: torch.Tensor,\n        extend_seq_lens: torch.Tensor,\n        encoder_lens: torch.Tensor = None,\n        is_cross_attention: bool = False,\n        scaling=None,\n        enable_gqa=False,\n        causal=False,\n        logit_cap: float = 0.0,\n        logit_capping_method: str = \"tanh\",\n    ):\n        \"\"\"Run the extend forward by using torch native sdpa op.\n\n        Args:\n            query: [num_tokens, num_heads, head_size]\n            output: [num_tokens, num_heads, head_size]\n            k_cache: [max_total_num_tokens, num_heads, head_size]\n            v_cache: [max_total_num_tokens, num_heads, head_size]\n            req_to_token: [max_num_reqs, max_context_len]\n            req_pool_indices: [num_seqs]\n            seq_lens: [num_seqs]\n            extend_prefix_lens: [num_seqs]\n            extend_seq_lens: [num_seqs]\n            encoder_lens: [num_seqs]\n            is_cross_attention: [bool]\n            scaling: float or None\n            enable_gqa: bool\n            causal: bool\n\n        Returns:\n            output: [num_tokens, num_heads, head_size]\n        \"\"\"\n\n        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]\n        assert seq_lens.shape[0] == extend_seq_lens.shape[0]\n\n        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]\n        query = query.movedim(0, query.dim() - 2)\n\n        start_q, start_kv = 0, 0\n        for seq_idx in range(seq_lens.shape[0]):\n            # TODO: optimize the performance of this per-request loop.\n\n            extend_seq_len_q = extend_seq_lens[seq_idx]\n            prefill_seq_len_q = extend_prefix_lens[seq_idx]\n\n            seq_len_kv = seq_lens[seq_idx]\n            end_q = start_q + extend_seq_len_q\n            end_kv = start_kv + seq_len_kv\n            atten_start_kv = 0\n            atten_end_kv = 
seq_lens[seq_idx]\n            # support cross attention\n            if encoder_lens is not None:\n                if is_cross_attention:\n                    atten_end_kv = encoder_lens[seq_idx]\n                else:\n                    atten_start_kv = encoder_lens[seq_idx]\n                    atten_end_kv = encoder_lens[seq_idx] + extend_seq_len_q\n\n            per_req_query = query[:, start_q:end_q, :]\n            per_req_query_redudant = torch.empty(\n                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),\n                dtype=per_req_query.dtype,\n                device=per_req_query.device,\n            )\n\n            per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query\n\n            # get key and value from cache. per_req_tokens contains the kv cache\n            # index for each token in the sequence.\n            req_pool_idx = req_pool_indices[seq_idx]\n            per_req_tokens = req_to_token[req_pool_idx, atten_start_kv:atten_end_kv]\n            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)\n            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)\n\n            if not (per_req_query.dtype == per_req_key.dtype == per_req_value.dtype):\n                # scaled_dot_product_attention() expects query, key, and value to have the same dtype\n                per_req_key = per_req_key.to(per_req_query.dtype)\n                per_req_value = per_req_value.to(per_req_query.dtype)\n\n            if logit_cap > 0:\n                per_req_out_redudant = (\n                    self.scaled_dot_product_attention_with_softcapping(\n                        per_req_query_redudant.unsqueeze(0),\n                        per_req_key.unsqueeze(0),\n                        per_req_value.unsqueeze(0),\n                        enable_gqa=enable_gqa,\n                        scale=scaling,\n                        is_causal=causal,\n                        logit_cap=logit_cap,\n           
             logit_capping_method=logit_capping_method,\n                    )\n                    .squeeze(0)\n                    .movedim(query.dim() - 2, 0)\n                )\n            else:\n                per_req_out_redudant = (\n                    scaled_dot_product_attention(\n                        per_req_query_redudant.unsqueeze(0),\n                        per_req_key.unsqueeze(0),\n                        per_req_value.unsqueeze(0),\n                        enable_gqa=enable_gqa,\n                        scale=scaling,\n                        is_causal=causal,\n                    )\n                    .squeeze(0)\n                    .movedim(query.dim() - 2, 0)\n                )\n            output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :]\n            start_q, start_kv = end_q, end_kv\n        return output\n\n    def run_sdpa_forward_decode(\n        self,\n        query: torch.Tensor,\n        output: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        req_to_token: torch.Tensor,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: torch.Tensor = None,\n        is_cross_attention: bool = False,\n        scaling=None,\n        enable_gqa=False,\n        causal=False,\n        logit_cap: float = 0.0,\n        logit_capping_method: str = \"tanh\",\n    ):\n        \"\"\"Run the decode forward by using torch native sdpa op.\n\n        Args:\n            query: [num_tokens, num_heads, head_size]\n            output: [num_tokens, num_heads, head_size]\n            k_cache: [max_total_num_tokens, num_heads, head_size]\n            v_cache: [max_total_num_tokens, num_heads, head_size]\n            req_to_token: [max_num_reqs, max_context_len]\n            req_pool_indices: [num_seqs]\n            seq_lens: [num_seqs]\n            encoder_lens: [num_seqs]\n            is_cross_attention: [bool]\n            scaling: float or None\n         
   enable_gqa: bool\n            causal: bool\n\n        Returns:\n            output: [num_tokens, num_heads, head_size]\n        \"\"\"\n\n        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]\n        query = query.movedim(0, query.dim() - 2)\n\n        start_q, start_kv = 0, 0\n        for seq_idx in range(seq_lens.shape[0]):\n            # Need optimize the performance later.\n\n            seq_len_q = 1\n            seq_len_kv = seq_lens[seq_idx]\n            end_q = start_q + seq_len_q\n            end_kv = start_kv + seq_len_kv\n            atten_start_kv = 0\n            atten_end_kv = seq_lens[seq_idx]\n            # support cross attention\n            if encoder_lens is not None:\n                if is_cross_attention:\n                    atten_end_kv = encoder_lens[seq_idx]\n                else:\n                    atten_start_kv = encoder_lens[seq_idx]\n                    atten_end_kv = encoder_lens[seq_idx] + seq_len_kv\n\n            per_req_query = query[:, start_q:end_q, :]\n\n            # get key and value from cache. 
per_req_tokens contains the kv cache\n            # index for each token in the sequence.\n            req_pool_idx = req_pool_indices[seq_idx]\n            per_req_tokens = req_to_token[req_pool_idx, atten_start_kv:atten_end_kv]\n            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)\n            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)\n\n            if not (per_req_query.dtype == per_req_key.dtype == per_req_value.dtype):\n                # scaled_dot_product_attention() expects query, key, and value to have the same dtype\n                per_req_key = per_req_key.to(per_req_query.dtype)\n                per_req_value = per_req_value.to(per_req_query.dtype)\n\n            if logit_cap > 0:\n                per_req_out = (\n                    self.scaled_dot_product_attention_with_softcapping(\n                        per_req_query.unsqueeze(0),\n                        per_req_key.unsqueeze(0),\n                        per_req_value.unsqueeze(0),\n                        enable_gqa=enable_gqa,\n                        scale=scaling,\n                        is_causal=causal,\n                        logit_cap=logit_cap,\n                        logit_capping_method=logit_capping_method,\n                    )\n                    .squeeze(0)\n                    .movedim(query.dim() - 2, 0)\n                )\n            else:\n                per_req_out = (\n                    scaled_dot_product_attention(\n                        per_req_query.unsqueeze(0),\n                        per_req_key.unsqueeze(0),\n                        per_req_value.unsqueeze(0),\n                        enable_gqa=enable_gqa,\n                        scale=scaling,\n                        is_causal=causal,\n                    )\n                    .squeeze(0)\n                    .movedim(query.dim() - 2, 0)\n                )\n            output[start_q:end_q, :, :] = per_req_out\n            start_q, start_kv = end_q, 
end_kv\n\n        return output\n\n    def support_triton(self):\n        return False\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/attention/mla_preprocess.py",
    "content": "import re\nfrom functools import lru_cache\nfrom typing import TYPE_CHECKING, Optional\n\nimport torch\nimport torch.nn.functional as F\n\nfrom sglang.srt.hardware_backend.npu.utils import npu_format_cast\nfrom sglang.srt.utils import get_bool_env_var\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.quantization.base_config import QuantizationConfig\n\n\n@lru_cache(maxsize=1)\ndef is_mla_preprocess_enabled() -> bool:\n    return get_bool_env_var(\"SGLANG_NPU_USE_MLAPO\")\n\n\n@lru_cache(maxsize=1)\ndef is_fia_nz() -> bool:\n    is_fia_nz_ = get_bool_env_var(\"SGLANG_USE_FIA_NZ\")\n    if is_fia_nz_:\n        assert (\n            is_mla_preprocess_enabled()\n        ), \"SGLANG_USE_FIA_NZ must be enable with SGLANG_NPU_USE_MLAPO\"\n    return is_fia_nz_\n\n\ndef round_up(val: int, align: int) -> int:\n    if align == 0:\n        return 0\n    return -(val // -align) * align\n\n\ndef transdata(nd_mat, block_size: tuple = (16, 16)):\n    r = round_up(nd_mat.shape[0], block_size[0])\n    c = round_up(nd_mat.shape[1], block_size[1])\n    r_pad = r - nd_mat.shape[0]\n    c_pad = c - nd_mat.shape[1]\n    nd_mat = F.pad(nd_mat, ((0, r_pad, 0, c_pad)))\n    nz_mat = torch.permute(\n        torch.reshape(\n            nd_mat,\n            (r // block_size[0], block_size[0], c // block_size[1], block_size[1]),\n        ),\n        [2, 0, 1, 3],\n    )\n    nz_mat = torch.reshape(\n        nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])\n    )\n    return nz_mat\n\n\ndef trans_rope_weight(weight, rope_dim):\n    weight_1 = weight[..., -rope_dim::2, :].contiguous()\n    weight_2 = weight[..., -rope_dim + 1 :: 2, :].contiguous()\n    weight[..., -rope_dim:, :] = torch.cat([weight_1, weight_2], dim=-2)\n\n    return weight.contiguous()\n\n\nclass NPUFusedMLAPreprocess(torch.nn.Module):\n    def __init__(\n        self,\n        fused_qkv_a_proj_with_mqa,\n        q_a_layernorm,\n        kv_a_layernorm,\n        q_b_proj,\n        
w_kc,\n        rotary_emb,\n        layer_id,\n        num_local_heads,\n        qk_nope_head_dim,\n        qk_rope_head_dim,\n        v_head_dim,\n        quant_config: Optional[\"QuantizationConfig\"] = None,\n    ):\n        super().__init__()\n        self.qkv_a_proj = fused_qkv_a_proj_with_mqa\n        self.q_a_layernorm = q_a_layernorm\n        self.kv_a_layernorm = kv_a_layernorm\n        self.q_b_proj = q_b_proj\n        self.w_kc = w_kc.contiguous()\n        self.rotary_emb = rotary_emb\n        self.layer_id = layer_id\n        self.quant_config = quant_config\n        self.has_preprocess_weights = False\n        self.dtype = None\n\n        self.q_lora_rank = self.q_b_proj.input_size  # 1536\n        self.kv_lora_rank = self.kv_a_layernorm.hidden_size  # 512\n        self.num_local_heads = num_local_heads  # tp\n        self.qk_nope_head_dim = qk_nope_head_dim  # 128\n        self.qk_rope_head_dim = qk_rope_head_dim  # 64\n        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.q_b_proj_weight_scale = self.q_b_proj.weight_scale.view(1, -1).to(\n            torch.float\n        )\n\n    def preprocess_weights(self, hidden_states):\n        self.dummy = torch.zeros(\n            (hidden_states.shape[-1]),\n            dtype=hidden_states.dtype,\n            device=hidden_states.device,\n        )\n        self.qkv_a_proj_input_offset = self.qkv_a_proj.input_offset.to(dtype=torch.int8)\n        self.q_b_proj_input_offset = self.q_b_proj.input_offset.to(dtype=torch.int8)\n\n        # matmul_0 weight [7168, 2112]\n        fused_qkv_a_proj_with_mqa_weight_q = self.qkv_a_proj.weight.data[\n            :, : self.q_lora_rank\n        ].clone()  # [7168, 1536]\n        fused_qkv_a_proj_with_mqa_weight_kv = self.qkv_a_proj.weight.data[\n            :, self.q_lora_rank :\n        ].clone()  # [7168, 576]\n        # rope fit\n        fused_qkv_a_proj_with_mqa_weight_kv_t = (\n            
fused_qkv_a_proj_with_mqa_weight_kv.t().contiguous()\n        )\n        fused_qkv_a_proj_with_mqa_weight_kv_t = trans_rope_weight(\n            fused_qkv_a_proj_with_mqa_weight_kv_t, self.qk_rope_head_dim\n        )\n        fused_qkv_a_proj_with_mqa_weight_kv = (\n            fused_qkv_a_proj_with_mqa_weight_kv_t.t().contiguous()\n        )\n        # cat nz\n        fused_qkv_a_proj_with_mqa_weight_new = torch.cat(\n            (fused_qkv_a_proj_with_mqa_weight_kv, fused_qkv_a_proj_with_mqa_weight_q),\n            dim=-1,\n        )\n        fused_qkv_a_proj_with_mqa_weight = (\n            fused_qkv_a_proj_with_mqa_weight_new.t().contiguous()\n        )\n        fused_qkv_a_proj_with_mqa_weight_nz = (\n            transdata(fused_qkv_a_proj_with_mqa_weight, block_size=(16, 32))\n            .unsqueeze(0)\n            .contiguous()\n        )\n        self.qkv_a_proj_weight_nz = npu_format_cast(fused_qkv_a_proj_with_mqa_weight_nz)\n\n        # matmul_0 deq_scale [2112]\n        fused_qkv_a_proj_with_mqa_deq_scale_q = self.qkv_a_proj.deq_scale.data[\n            : self.q_lora_rank\n        ].clone()  # [7168, 1536]\n        fused_qkv_a_proj_with_mqa_deq_scale_kv = self.qkv_a_proj.deq_scale.data[\n            self.q_lora_rank :\n        ].clone()  # [7168, 576]\n        # rope fit\n        fused_qkv_a_proj_with_mqa_deq_scale_kv = (\n            fused_qkv_a_proj_with_mqa_deq_scale_kv.reshape(\n                self.kv_lora_rank + self.qk_rope_head_dim, -1\n            ).contiguous()\n        )\n        fused_qkv_a_proj_with_mqa_deq_scale_kv = trans_rope_weight(\n            fused_qkv_a_proj_with_mqa_deq_scale_kv, self.qk_rope_head_dim\n        )\n        fused_qkv_a_proj_with_mqa_deq_scale_kv = (\n            fused_qkv_a_proj_with_mqa_deq_scale_kv.view(\n                self.kv_lora_rank + self.qk_rope_head_dim\n            ).contiguous()\n        )\n        self.qkv_a_proj_deq_scale_kvq = torch.cat(\n            (\n                
fused_qkv_a_proj_with_mqa_deq_scale_kv,\n                fused_qkv_a_proj_with_mqa_deq_scale_q,\n            ),\n            dim=-1,\n        )\n\n        # matmul_0 quant_bias [2112]\n        fused_qkv_a_proj_with_mqa_quant_bias_q = self.qkv_a_proj.quant_bias.data[\n            : self.q_lora_rank\n        ].clone()  # [7168, 1536]\n        fused_qkv_a_proj_with_mqa_quant_bias_kv = self.qkv_a_proj.quant_bias.data[\n            self.q_lora_rank :\n        ].clone()  # [7168, 576]\n        # rope fit\n        fused_qkv_a_proj_with_mqa_quant_bias_kv = (\n            fused_qkv_a_proj_with_mqa_quant_bias_kv.reshape(\n                self.kv_lora_rank + self.qk_rope_head_dim, -1\n            ).contiguous()\n        )\n        fused_qkv_a_proj_with_mqa_quant_bias_kv = trans_rope_weight(\n            fused_qkv_a_proj_with_mqa_quant_bias_kv, self.qk_rope_head_dim\n        )\n        fused_qkv_a_proj_with_mqa_quant_bias_kv = (\n            fused_qkv_a_proj_with_mqa_quant_bias_kv.view(\n                self.kv_lora_rank + self.qk_rope_head_dim\n            ).contiguous()\n        )\n        self.qkv_a_proj_quant_bias_kvq = torch.cat(\n            (\n                fused_qkv_a_proj_with_mqa_quant_bias_kv,\n                fused_qkv_a_proj_with_mqa_quant_bias_q,\n            ),\n            dim=-1,\n        )\n\n        # matmul_1 weight [1536, num_head * 192]\n        q_b_proj_weight = self.q_b_proj.weight.data.clone()\n        q_b_proj_weight = q_b_proj_weight.t().reshape(\n            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1\n        )\n        q_b_proj_weight = trans_rope_weight(q_b_proj_weight, self.qk_rope_head_dim)\n        q_b_proj_weight = q_b_proj_weight.reshape(\n            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1\n        )\n        q_b_proj_weight_nz = (\n            transdata(q_b_proj_weight, block_size=(16, 32)).unsqueeze(0).contiguous()\n        )\n        self.q_b_proj_weight_nz = 
npu_format_cast(q_b_proj_weight_nz)\n\n        # matmul_1 deq_scale [num_head * 192]\n        q_b_proj_deq_scale = self.q_b_proj.deq_scale.data.clone()\n        q_b_proj_deq_scale = q_b_proj_deq_scale.reshape(\n            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1\n        )\n        q_b_proj_deq_scale = trans_rope_weight(\n            q_b_proj_deq_scale, self.qk_rope_head_dim\n        )\n        self.q_b_proj_deq_scale = q_b_proj_deq_scale.reshape(\n            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)\n        )\n\n        # matmul_1 quant_bias [num_head * 192]\n        q_b_proj_quant_bias = self.q_b_proj.quant_bias.data.clone()\n        q_b_proj_quant_bias = q_b_proj_quant_bias.reshape(\n            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1\n        )\n        q_b_proj_quant_bias = trans_rope_weight(\n            q_b_proj_quant_bias, self.qk_rope_head_dim\n        )\n        self.q_b_proj_quant_bias = q_b_proj_quant_bias.reshape(\n            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)\n        )\n\n    def mlaprolog_preprocess_weight(self):\n        self.qkv_a_proj.weight.data = self.qkv_a_proj.weight.data.transpose(0, 1)\n        qkv_a_proj_weight_q = self.qkv_a_proj.weight.data[:, : self.q_lora_rank].clone()\n        qkv_a_proj_weight_kv = self.qkv_a_proj.weight.data[\n            :, self.q_lora_rank :\n        ].clone()\n        self.q_a_proj_weight = npu_format_cast(qkv_a_proj_weight_q)\n        self.kv_a_proj_weight = npu_format_cast(qkv_a_proj_weight_kv)\n\n    def get_sin_cos(self, positions):\n        cos_sin = self.rotary_emb.cos_sin_cache[positions]\n        cos, sin = cos_sin.chunk(2, dim=-1)\n        cos = cos.repeat(1, 2)\n        sin = sin.repeat(1, 2)\n        return cos, sin\n\n    def get_kv_cache_and_cache_idx(self, forward_batch):\n        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(self.layer_id)\n        
slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)\n        return k_cache, v_cache, slot_mapping\n\n    def forward_absorb_prepare_npu_rms_norm_cache(\n        self,\n        positions: torch.Tensor,\n        hidden_states: torch.Tensor,\n        forward_batch,\n        zero_allocator,\n    ):\n        bsz, _ = hidden_states.view(-1, hidden_states.shape[-1]).shape\n        self.dtype = hidden_states.dtype\n        if self.layer_id == 0:\n            self.cos, self.sin = self.get_sin_cos(positions)\n            self.rotary_emb.cos_cached, self.rotary_emb.sin_cache = self.cos, self.sin\n        else:\n            self.cos, self.sin = self.rotary_emb.cos_cached, self.rotary_emb.sin_cache\n\n        self.kvCache, self.kvCacheRope, self.slotmapping = (\n            self.get_kv_cache_and_cache_idx(forward_batch)\n        )\n\n        if not self.has_preprocess_weights:\n            self.has_preprocess_weights = True\n\n        cos, sin = self.cos, self.sin\n\n        if self.q_lora_rank is not None:\n            fused_qkv_a_proj_out = self.qkv_a_proj(hidden_states)[0]\n            q_lowrank, latent_cache = fused_qkv_a_proj_out.split(\n                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1\n            )\n            q = self.q_a_layernorm(q_lowrank)\n            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)\n        else:\n            q = self.q_proj(hidden_states)[0].view(\n                -1, self.num_local_heads, self.qk_head_dim\n            )\n            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]\n\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )  # b*s,n,d\n\n        q_nope = q_nope.view(-1, self.num_local_heads, self.qk_nope_head_dim)\n        q_nope = torch.matmul(q_nope.transpose(0, 1), self.w_kc).transpose(0, 1)\n\n        q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim)\n        cos = 
cos.view(-1, 1, 1, self.qk_rope_head_dim)\n        sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)\n        q_pe = torch.ops.npu.npu_interleave_rope(q_pe, cos, sin)  # (B,N,S,D)\n        q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim)\n\n        latent_cache = latent_cache.view(\n            -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim\n        )  # (B*S,N,1,D)\n\n        cache_mode = \"PA_NZ\" if is_fia_nz() else \"PA_BNSD\"\n        self.kvCache = self.kvCache.view(\n            -1,\n            forward_batch.attn_backend.page_size,\n            1,\n            forward_batch.attn_backend.kv_lora_rank,\n        )\n        self.kvCacheRope = self.kvCacheRope.view(\n            -1,\n            forward_batch.attn_backend.page_size,\n            1,\n            forward_batch.attn_backend.qk_rope_head_dim,\n        )\n        k_rope, k_nope, _, _ = torch.ops.npu.npu_kv_rmsnorm_rope_cache(\n            latent_cache,\n            self.kv_a_layernorm.weight,\n            cos,\n            sin,\n            self.slotmapping.to(torch.int64),\n            self.kvCacheRope,\n            self.kvCache,\n            epsilon=self.kv_a_layernorm.variance_epsilon,\n            cache_mode=cache_mode,\n        )\n\n        return (q_pe, k_rope, q_nope, k_nope, forward_batch, zero_allocator, positions)\n\n    def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator):\n        input_dtype = hidden_states.dtype\n        if not self.has_preprocess_weights:\n            self.preprocess_weights(hidden_states)\n            self.has_preprocess_weights = True\n            self.dtype = hidden_states.dtype\n\n        if self.layer_id == 0:\n            cos, sin = self.get_sin_cos(positions)\n            self.rotary_emb.cos_cached, self.rotary_emb.sin_cache = cos, sin\n        else:\n            cos, sin = self.rotary_emb.cos_cached, self.rotary_emb.sin_cache\n\n        k_cache, v_cache, slot_mapping = 
self.get_kv_cache_and_cache_idx(forward_batch)\n\n        q_nope_out = torch.empty(\n            (hidden_states.shape[0], self.w_kc.shape[0], k_cache.shape[-1]),\n            dtype=input_dtype,\n            device=hidden_states.device,\n        )\n        q_rope_out = torch.empty(\n            (hidden_states.shape[0], self.w_kc.shape[0], v_cache.shape[-1]),\n            dtype=input_dtype,\n            device=hidden_states.device,\n        )\n        if is_fia_nz():\n            kv_shape, kv_rope_shape = k_cache.shape, v_cache.shape\n            num_blocks, block_size, num_heads, _ = kv_shape\n            k_cache = k_cache.view(\n                num_blocks, num_heads * self.kv_lora_rank // 16, block_size, 16\n            )\n            v_cache = v_cache.view(\n                num_blocks, num_heads * self.qk_rope_head_dim // 16, block_size, 16\n            )\n        # TODO: dummy inputs to be removed\n        # https://github.com/sgl-project/sgl-kernel-npu/issues/78\n        if hasattr(self.q_a_layernorm, \"bias\"):\n            q_a_layernorm_bias = self.q_a_layernorm.bias\n        else:\n            q_a_layernorm_bias = self.dummy\n\n        torch.ops.npu.mla_preprocess(\n            hidden_states,\n            self.dummy,\n            self.dummy,\n            self.qkv_a_proj_weight_nz,\n            self.qkv_a_proj_deq_scale_kvq,\n            self.q_a_layernorm.weight,\n            q_a_layernorm_bias,\n            self.q_b_proj_weight_nz,\n            self.q_b_proj_deq_scale,\n            self.kv_a_layernorm.weight,\n            cos,\n            sin,\n            self.w_kc,\n            k_cache,\n            v_cache,\n            slot_mapping,\n            quant_scale0=self.qkv_a_proj.input_scale,\n            quant_offset0=self.qkv_a_proj_input_offset,\n            bias0=self.qkv_a_proj_quant_bias_kvq,\n            quant_scale1=self.q_b_proj.input_scale,\n            quant_offset1=self.q_b_proj_input_offset,\n            bias1=self.q_b_proj_quant_bias,\n          
  cache_mode=\"nzcache\" if is_fia_nz() else \"krope_ctkv\",\n            quant_mode=\"per_tensor_quant_asymm\",\n            q_out0=q_nope_out,\n            kv_cache_out0=k_cache,\n            q_out1=q_rope_out,\n            kv_cache_out1=v_cache,\n        )\n\n        if is_fia_nz():\n            k_cache = k_cache.view(kv_shape)\n            v_cache = v_cache.view(kv_rope_shape)\n\n        return (\n            q_rope_out,\n            v_cache,\n            q_nope_out,\n            k_cache,\n            forward_batch,\n            zero_allocator,\n            positions,\n        )\n\n    def forward_mlaprolog(self, positions, hidden_states, forward_batch):\n        if not self.has_preprocess_weights:\n            self.mlaprolog_preprocess_weight()\n            self.has_preprocess_weights = True\n        self.cos, self.sin = self.get_sin_cos(positions)\n        k_cache, v_cache, slot_mapping = self.get_kv_cache_and_cache_idx(forward_batch)\n        mla_prolog_input_args = {\n            \"token_x\": hidden_states,\n            \"weight_dq\": self.q_a_proj_weight,\n            \"weight_uq_qr\": self.q_b_proj.weight,\n            \"weight_uk\": self.w_kc,\n            \"weight_dkv_kr\": self.kv_a_proj_weight,\n            \"rmsnorm_gamma_cq\": self.q_a_layernorm.weight,\n            \"rmsnorm_gamma_ckv\": self.kv_a_layernorm.weight,\n            \"rope_sin\": self.sin,\n            \"rope_cos\": self.cos,\n            \"kv_cache\": k_cache,\n            \"kr_cache\": v_cache,\n            \"cache_index\": slot_mapping.to(dtype=torch.int64),\n            \"dequant_scale_w_uq_qr\": self.q_b_proj_weight_scale,\n            \"rmsnorm_epsilon_cq\": self.q_a_layernorm.variance_epsilon,\n            \"rmsnorm_epsilon_ckv\": self.kv_a_layernorm.variance_epsilon,\n            \"cache_mode\": \"PA_BSND\",\n            \"query_norm_flag\": True,\n            \"weight_quant_mode\": 1,  # 0:no quant; 1:uq_qr: quant; 2: weight_dq,weight_uq_qr,weight_dkv_kr: quant\n        }\n     
   q_nope, q_pe, dequant_scale_q_nope, qr, dequant_q_norm = (\n            torch.ops.custom.npu_mla_prolog_v3(**mla_prolog_input_args)\n        )\n        dequant_q_norm = dequant_q_norm.view(hidden_states.shape[0])\n        return (\n            q_pe,\n            v_cache,\n            q_nope,\n            k_cache,\n            qr,\n            forward_batch,\n            positions,\n            dequant_q_norm,\n        )\n\n    def forward(self, positions, hidden_states, forward_batch, zero_allocator):\n        # assert self.quant_config and self.quant_config.get_name() == \"modelslim\"\n        # route by `qkv_a_proj` quant type as MTP layers can be unquantized\n        _is_w8a8 = (\n            hasattr(self.qkv_a_proj.quant_method, \"quantization_config\")\n            and self.qkv_a_proj.quant_method.quantization_config.get_name()\n            == \"modelslim\"\n        )\n        # with the mlaprolog enabled, the kv_b_proj layers are unquantized\n        _is_mlaprolog = hasattr(self.quant_config, \"ignore\") and any(\n            re.fullmatch(r\".*kv_b_proj\", l) for l in self.quant_config.ignore\n        )\n        if _is_w8a8:\n            return self.forward_mlapo(\n                positions, hidden_states, forward_batch, zero_allocator\n            )\n        elif _is_mlaprolog:\n            return self.forward_mlaprolog(positions, hidden_states, forward_batch)\n        else:\n            return self.forward_absorb_prepare_npu_rms_norm_cache(\n                positions, hidden_states, forward_batch, zero_allocator\n            )\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/cmo.py",
    "content": "import torch\n\ncmo_stream = None\n\n\ndef get_cmo_stream():\n    \"\"\"\n    Return the Cache Management Operation (CMO) stream, or None if it has not\n    been created yet. The CMO stream prefetches matmul weights while other AIV\n    or communication kernels are running, so that memory access time overlaps\n    with their execution.\n    \"\"\"\n    global cmo_stream\n    return cmo_stream\n\n\ndef set_cmo_stream(stream):\n    global cmo_stream\n    cmo_stream = stream\n\n\ndef prepare_weight_cache(handle, cache, PREFETCH_MAX_SIZE=1000000000):\n    \"\"\"\n    PREFETCH_MAX_SIZE: maximum size (bytes) for each prefetch operation.\n    This bounds the time spent in each prefetch:\n        time ≈ PREFETCH_MAX_SIZE / system_bandwidth\n    \"\"\"\n    import torch_npu\n\n    stream = get_cmo_stream()\n    if stream is None:\n        stream = torch.npu.Stream()\n        set_cmo_stream(stream)\n    stream.wait_stream(torch.npu.current_stream())\n    with torch.npu.stream(stream):\n        if isinstance(cache, list):\n            for weight in cache:\n                torch_npu.npu_prefetch(\n                    weight,\n                    handle,\n                    PREFETCH_MAX_SIZE,\n                )\n        else:\n            torch_npu.npu_prefetch(\n                cache,\n                handle,\n                PREFETCH_MAX_SIZE,\n            )\n\n\ndef wait_cmo_stream():\n    stream = get_cmo_stream()\n    if stream is not None:\n        cur_stream = torch.npu.current_stream()\n        cur_stream.wait_stream(stream)\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/graph_runner/eagle_draft_extend_npu_graph_runner.py",
    "content": "# Copyright 2024-2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Run the model with npu graph and torch.compile.\"\"\"\n\nfrom __future__ import annotations\n\nimport threading\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.srt.configs.model_config import is_deepseek_nsa\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\nfrom sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (\n    EAGLEDraftExtendCudaGraphRunner,\n)\n\nif TYPE_CHECKING:\n    from sglang.srt.speculative.eagle_worker import EAGLEWorker\n\n\nclass EAGLEDraftExtendNpuGraphRunner(EAGLEDraftExtendCudaGraphRunner):\n    def __init__(self, eagle_worker: EAGLEWorker):\n        super().__init__(eagle_worker)\n\n    def _create_graph(self):\n        return torch.npu.NPUGraph()\n\n    def _cache_loc_dtype(self):\n        return torch.int32\n\n    def _capture_init(self, run_once_fn):\n        for _ in range(2):\n            torch.npu.synchronize()\n            self.model_runner.tp_group.barrier()\n            run_once_fn()\n\n    def _capture_graph(self, graph, pool, stream, run_once_fn):\n        with torch.npu.graph(\n            graph, pool=pool, stream=stream, auto_dispatch_capture=True\n        ):\n            out = run_once_fn()\n        return out\n\n    def _replay_update(self, seq_lens):\n        
self.graphs[self.bs].update(\n            cpu_update_input=[{\"actual_seq_lengths_kv\": seq_lens}]\n        )\n\n    def _replay(self, forward_batch: ForwardBatch):\n        if not is_deepseek_nsa(self.model_runner.model_config.hf_config):\n            seq_lens = forward_batch.seq_lens_cpu.tolist() + [0] * (\n                self.bs - self.raw_bs\n            )\n            thread = threading.Thread(target=self._replay_update, args=(seq_lens,))\n            thread.start()\n            self.graphs[self.bs].replay()\n            thread.join()\n        else:\n            self.graphs[self.bs].replay()\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/graph_runner/eagle_draft_npu_graph_runner.py",
    "content": "# Copyright 2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Run the model with npu graph and torch.compile\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nimport threading\nfrom typing import TYPE_CHECKING, Dict, Union\n\nimport numpy as np\nimport torch\n\nfrom sglang.srt.configs.model_config import AttentionArch, is_deepseek_nsa\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\nfrom sglang.srt.speculative.eagle_draft_cuda_graph_runner import (\n    EAGLEDraftCudaGraphRunner,\n)\n\nif TYPE_CHECKING:\n    from sglang.srt.speculative.eagle_worker import EAGLEWorker\n\nfrom sglang.srt.utils import is_npu\n\nlogger = logging.getLogger(__name__)\n\nif is_npu():\n    torch.cuda.CUDAGraph = torch.npu.NPUGraph\n    torch.cuda.synchronize = torch.npu.synchronize\n    torch.cuda.graph = torch.npu.graph\n    torch.cuda.stream = torch.npu.stream\n    torch.cuda.Stream = torch.npu.Stream\n    torch.cuda.current_stream = torch.npu.current_stream\n\n\nclass EAGLEDraftNpuGraphRunner(EAGLEDraftCudaGraphRunner):\n    def __init__(self, eagle_worker: EAGLEWorker):\n        super().__init__(eagle_worker)\n        self.update_attr_name = None\n        self.update_attr_type = None\n        self._init_arch_map()\n\n    def _init_arch_map(self):\n        self.attr_name: Dict[str, str] = {\n            AttentionArch.MLA: 
\"actual_seq_lengths_kv\",\n            AttentionArch.MHA: \"context_lens\",\n        }\n        self.attr_type: Dict[str, Union[list, torch.Tensor]] = {\n            AttentionArch.MLA: [],\n            AttentionArch.MHA: torch.Tensor(),\n        }\n\n    def _create_graph(self):\n        return torch.npu.NPUGraph()\n\n    def _capture_init(self, run_once_fn):\n        for _ in range(2):\n            torch.npu.synchronize()\n            self.model_runner.tp_group.barrier()\n            run_once_fn()\n\n    def _capture_graph(self, graph, pool, stream, run_once_fn):\n        with torch.npu.graph(\n            graph, pool=pool, stream=stream, auto_dispatch_capture=True\n        ):\n            out = run_once_fn()\n        return out\n\n    def _get_update_attr_name(self):\n        return self.attr_name[AttentionArch.MLA]\n\n    def _get_update_attr_type(self):\n        return self.attr_type[AttentionArch.MLA]\n\n    def _replay_update(self, seq_lens):\n        if isinstance(self.update_attr_type, torch.Tensor):\n            seq_lens = torch.from_numpy(np.array(seq_lens).astype(np.int32))\n\n        self.graphs[self.bs].update(\n            cpu_update_input=[{self.update_attr_name: seq_lens}]\n        )\n\n    def _replay(self, forward_batch: ForwardBatch):\n        self.update_attr_name = self._get_update_attr_name()\n        self.update_attr_type = self._get_update_attr_type()\n        if not is_deepseek_nsa(self.model_runner.model_config.hf_config):\n            seq_lens = forward_batch.seq_lens_cpu.tolist() + [0] * (\n                self.bs - self.raw_bs\n            )\n            thread = threading.Thread(target=self._replay_update, args=(seq_lens,))\n            thread.start()\n            self.graphs[self.bs].replay()\n            thread.join()\n        else:\n            self.graphs[self.bs].replay()\n\n    def _cache_loc_dtype(self):\n        return torch.int32\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Run the model with npu graph and torch.compile.\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport threading\nfrom contextlib import contextmanager\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Dict, Optional, Union\n\nimport numpy as np\nimport torch\n\nimport sglang\nfrom sglang.srt.configs.model_config import AttentionArch, is_deepseek_nsa\nfrom sglang.srt.distributed.parallel_state import GroupCoordinator\nfrom sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner\nfrom sglang.srt.utils import (\n    empty_context,\n    get_bool_env_var,\n    get_compiler_backend,\n    is_npu,\n)\n\nis_npu = is_npu()\n\nif is_npu:\n    import torch_npu\n    from torch_npu.profiler import ProfilerActivity, profile\n\nlogger = logging.getLogger(__name__)\n\nif TYPE_CHECKING:\n    from sglang.srt.model_executor.model_runner import ModelRunner\n\nfrom sglang.srt.layers.logits_processor import LogitsProcessorOutput\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors\n\n\n@contextmanager\ndef patch_model_npu(\n    model: torch.nn.Module,\n    enable_compile: bool,\n    num_tokens: int,\n    tp_group: GroupCoordinator,\n):\n    if enable_compile:\n        backend = get_compiler_backend(\"npugraph_ex\")\n      
  yield torch.compile(\n            torch.no_grad()(model.forward),\n            fullgraph=True,\n            dynamic=False,\n            backend=backend,\n        )\n    else:\n        yield model.forward\n\n\nclass NPUGraphRunner(CudaGraphRunner):\n    \"\"\"An NPUGraphRunner runs the forward pass of a model with NPU graph and torch.compile.\"\"\"\n\n    def __init__(self, model_runner: ModelRunner):\n        sglang.srt.model_executor.cuda_graph_runner.patch_model = patch_model_npu\n        super().__init__(model_runner)\n        self.update_attr_name = None\n        self.update_attr_type = None\n        self.model_runner = model_runner\n        self._init_arch_map()\n        self.use_fia = get_bool_env_var(\"ASCEND_USE_FIA\", \"False\")\n\n    def _init_arch_map(self):\n        if self.is_dllm:\n            self.attr_name: Dict[str, str] = {\n                AttentionArch.MLA: \"actual_seq_lengths_kv\",\n                AttentionArch.MHA: \"actual_seq_lengths_kv\",\n            }\n        else:\n            self.attr_name: Dict[str, str] = {\n                AttentionArch.MLA: \"actual_seq_lengths_kv\",\n                AttentionArch.MHA: \"context_lens\",\n            }\n        self.attr_type: Dict[str, Union[list, torch.Tensor]] = {\n            AttentionArch.MLA: [],\n            AttentionArch.MHA: torch.Tensor(),\n        }\n\n    def _create_device_graph(self):\n        return torch.npu.NPUGraph()\n\n    def _capture_graph(self, graph, pool, stream, run_once_fn):\n        if self.enable_torch_compile:\n            skip_guard_context = torch.compiler.set_stance(skip_guard_eval_unsafe=True)\n        else:\n            skip_guard_context = empty_context()\n\n        with skip_guard_context, torch.npu.graph(\n            graph,\n            pool=pool,\n            stream=stream,\n            auto_dispatch_capture=True,\n        ):\n            out = run_once_fn()\n        return out\n\n    def _get_update_attr_name(self):\n        return 
self.attr_name[AttentionArch.MLA]\n\n    def _get_update_attr_type(self):\n        return self.attr_type[AttentionArch.MLA]\n\n    def _update_inputs(self, seq_lens):\n        if isinstance(self.update_attr_type, torch.Tensor):\n            seq_lens = torch.from_numpy(np.array(seq_lens).astype(np.int32))\n\n        self.graphs[self.bs].update(\n            cpu_update_input=[{self.update_attr_name: seq_lens}]\n        )\n\n    def _cache_loc_dtype(self):\n        return torch.int32\n\n    def _init_profile_context_and_memory_record(self):\n        output_dir = os.path.join(\n            os.getenv(\"SGLANG_TORCH_PROFILER_DIR\", \"/tmp\"), \"graph_capture_profile\"\n        )\n        if not Path(output_dir).exists():\n            Path(output_dir).mkdir(parents=True, exist_ok=True)\n        logger.info(\n            f\"Profiling starts for graph capture for NPU. Traces will be saved to: {output_dir}\"\n        )\n        experimental_config = torch_npu.profiler._ExperimentalConfig(\n            export_type=[torch_npu.profiler.ExportType.Text],\n            profiler_level=torch_npu.profiler.ProfilerLevel.Level1,\n        )\n        profile_context = profile(\n            activities=[ProfilerActivity.CPU, ProfilerActivity.NPU],\n            record_shapes=True,\n            profile_memory=True,\n            on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(\n                output_dir, async_mode=True\n            ),\n            experimental_config=experimental_config,\n        )\n        return profile_context\n\n    def _post_process_after_profile(self, prof_context):\n        # for NPU, profile data will be saved to disk for further analysis.\n        pass\n\n    def replay(\n        self,\n        forward_batch: ForwardBatch,\n        skip_attn_backend_init: bool = False,\n        pp_proxy_tensors: Optional[PPProxyTensors] = None,\n    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:\n        if not skip_attn_backend_init:\n            
self.replay_prepare(forward_batch, pp_proxy_tensors)\n        else:\n            # In speculative decoding, these two fields are still needed.\n            self.buffers.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)\n            self.buffers.positions[: self.raw_num_token].copy_(forward_batch.positions)\n\n        self.update_attr_name = self._get_update_attr_name()\n        self.update_attr_type = self._get_update_attr_type()\n        # Replay\n        if not is_deepseek_nsa(self.model_runner.model_config.hf_config):\n            if forward_batch.forward_mode.is_target_verify():\n                seq_lens_cpu = forward_batch.seq_lens.cpu() + self.num_tokens_per_bs\n                seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs)\n            else:\n                seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (\n                    self.bs - self.raw_bs\n                )\n            thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))\n            thread.start()\n            self.graphs[self.bs].replay()\n            thread.join()\n        else:\n            self.graphs[self.bs].replay()\n\n        output = self.output_buffers[self.bs]\n        if isinstance(output, LogitsProcessorOutput):\n            if self.is_dllm:\n                next_token_logits = None\n                full_logits = output.full_logits[: self.raw_num_token]\n            else:\n                full_logits = None\n                next_token_logits = output.next_token_logits[: self.raw_num_token]\n            return LogitsProcessorOutput(\n                next_token_logits=next_token_logits,\n                full_logits=full_logits,\n                hidden_states=(\n                    output.hidden_states[: self.raw_num_token]\n                    if output.hidden_states is not None\n                    else None\n                ),\n            )\n        else:\n            assert isinstance(output, PPProxyTensors)\n            
return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/graph_runner/vit_npu_graph_runner.py",
    "content": "# Copyright 2023-2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\n\"\"\"ViT NPU Graph Runner class.\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Dict, Hashable, List, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch_npu\n\nfrom sglang.srt.distributed.device_communicators.pynccl_allocator import (\n    set_graph_pool_id,\n)\nfrom sglang.srt.layers.attention.vision import VisionAttention\nfrom sglang.srt.multimodal.vit_cuda_graph_runner import ViTCudaGraphRunner\nfrom sglang.srt.server_args import get_global_server_args\n\n\nclass ViTNpuGraphRunner(ViTCudaGraphRunner):\n    \"\"\"Generic ViT NPU Graph Runner.\n\n    This runner captures the \"blocks + merger + deepstack merger (optional)\" part\n    of a vision transformer into an NPU graph and replays it for identical shapes.\n\n    Optional for Qwen3 deepstack:\n      - vit.deepstack_vision_indexes: Sequence[int]\n      - vit.deepstack_merger_list: nn.ModuleList (same length as deepstack_vision_indexes)\n    \"\"\"\n\n    _graph_memory_pool = None\n\n    def __init__(\n        self,\n        vit: nn.Module,\n    ) -> None:\n        super().__init__(vit)\n        self.device_module = torch.get_device_module(self.device)\n        self.cu_seq_lens: Dict[Hashable, torch.Tensor] = {}\n\n        # rotary position buffers shared across graphs\n        
self.sin_cos_ws: Dict[Hashable, Tuple[torch.Tensor, torch.Tensor]] = {}\n\n    @property\n    def device(self) -> torch.device:\n        return self.vit.device\n\n    @property\n    def dtype(self) -> torch.dtype:\n        return self.vit.dtype\n\n    def _create_graph(\n        self,\n        graph_key: int,\n    ):\n\n        graph = torch_npu.npu.NPUGraph()\n        vit = self.vit\n\n        override_backend = get_global_server_args().mm_attention_backend\n        with torch_npu.npu.graph(graph, pool=ViTNpuGraphRunner._graph_memory_pool):\n            y = None\n            deepstack_outs: List[torch.Tensor] = []\n            deepstack_capture_idx = 0\n\n            for layer_num, blk in enumerate(vit.blocks):\n                if override_backend == \"ascend_attn\":\n                    cu_seq_lens = self.cu_seq_lens[graph_key]\n                else:\n                    raise RuntimeError(\"Not supported ViT attention backend\")\n\n                if layer_num == 0:\n                    y = blk(\n                        self.block_input[graph_key],\n                        cu_seqlens=cu_seq_lens,\n                        rotary_pos_emb_cos=self.sin_cos_ws[graph_key][0],\n                        rotary_pos_emb_sin=self.sin_cos_ws[graph_key][1],\n                        output_ws=self.block_ws[graph_key],\n                    )\n                else:\n                    y = blk(\n                        y,\n                        cu_seqlens=cu_seq_lens,\n                        rotary_pos_emb_cos=self.sin_cos_ws[graph_key][0],\n                        rotary_pos_emb_sin=self.sin_cos_ws[graph_key][1],\n                        output_ws=self.block_ws[graph_key],\n                    )\n\n                # Optional deepstack support (Qwen3-VL)\n                if (\n                    self._deepstack_visual_indexes\n                    and layer_num in self._deepstack_visual_indexes\n                ):\n                    if self._deepstack_merger_list is None:\n 
                       raise RuntimeError(\n                            \"deepstack_visual_indexes exists but deepstack_merger_list is missing.\"\n                        )\n                    deepstack_out = self._deepstack_merger_list[deepstack_capture_idx](\n                        y\n                    )\n                    deepstack_outs.append(deepstack_out)\n                    deepstack_capture_idx += 1\n\n            main_out = vit.merger(y)\n\n            if deepstack_outs:\n                self.block_output[graph_key] = torch.cat(\n                    [main_out] + deepstack_outs, dim=1\n                )\n            else:\n                self.block_output[graph_key] = main_out\n\n        self.block_graphs[graph_key] = graph\n\n    def create_graph(\n        self,\n        x_3d: torch.Tensor,  # [S, 1, H]\n        cu_seqlens: torch.Tensor,\n        rotary_pos_emb_cos: Optional[torch.Tensor] = None,\n        rotary_pos_emb_sin: Optional[torch.Tensor] = None,\n    ) -> int:\n        vit = self.vit\n        graph_key = self._get_graph_key(x_3d)\n\n        if graph_key in self.block_graphs:\n            return graph_key\n\n        if ViTNpuGraphRunner._graph_memory_pool is None:\n            ViTNpuGraphRunner._graph_memory_pool = (\n                self.device_module.graph_pool_handle()\n            )\n        # Set graph pool id globally to be able to use symmetric memory\n        set_graph_pool_id(ViTNpuGraphRunner._graph_memory_pool)\n\n        # pre-allocate workspace\n        attn_module: VisionAttention = vit.blocks[0].attn\n        num_heads = attn_module.num_attention_heads_per_partition\n        attn_head_dim = attn_module.head_size\n\n        if graph_key not in self.block_output:\n            self.block_output[graph_key] = x_3d\n            self.block_input[graph_key] = x_3d\n            self.block_ws[graph_key] = torch.empty(\n                graph_key,\n                num_heads,\n                attn_head_dim,\n                
device=self.device,\n                dtype=self.dtype,\n            )\n            if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:\n                self.sin_cos_ws[graph_key] = (rotary_pos_emb_cos, rotary_pos_emb_sin)\n\n        if graph_key not in self.cu_seq_lens:\n            seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]\n            self.cu_seq_lens[graph_key] = seq_lens.to(\"cpu\").to(torch.int32)\n\n        if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:\n            self._create_graph(\n                graph_key=graph_key,\n            )\n\n        return graph_key\n\n    def replay(\n        self,\n        graph_key: int,\n        x_3d: torch.Tensor,\n        rotary_pos_emb_cos: Optional[torch.Tensor] = None,\n        rotary_pos_emb_sin: Optional[torch.Tensor] = None,\n        output_indices: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:\n            # update rotary workspace content\n            self.sin_cos_ws[graph_key][0].copy_(rotary_pos_emb_cos)\n            self.sin_cos_ws[graph_key][1].copy_(rotary_pos_emb_sin)\n\n        # copy input\n        self.block_input[graph_key].copy_(x_3d)\n\n        # replay\n        self.block_graphs[graph_key].replay()\n\n        out = self.block_output[graph_key]\n\n        # Optional output reordering (Qwen2.5-VL window permutation inverse)\n        if output_indices is not None:\n            out = out.index_select(0, output_indices)\n\n        return out\n\n    def run(\n        self,\n        x: torch.Tensor,\n        cu_seqlens: torch.Tensor,\n        rotary_pos_emb_cos: Optional[torch.Tensor] = None,\n        rotary_pos_emb_sin: Optional[torch.Tensor] = None,\n        output_indices: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        # x: [seq_len, hidden] -> [S, B=1, H]\n        x_3d = x.unsqueeze(1)\n        graph_key = self._get_graph_key(x_3d)\n        if graph_key not 
in self.block_graphs:\n            self.create_graph(\n                x_3d=x_3d,\n                cu_seqlens=cu_seqlens,\n                rotary_pos_emb_cos=rotary_pos_emb_cos,\n                rotary_pos_emb_sin=rotary_pos_emb_sin,\n            )\n\n        return self.replay(\n            graph_key=graph_key,\n            x_3d=x_3d,\n            rotary_pos_emb_cos=rotary_pos_emb_cos,\n            rotary_pos_emb_sin=rotary_pos_emb_sin,\n            output_indices=output_indices,\n        )\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/memory_pool_npu.py",
    "content": "from typing import TYPE_CHECKING, Optional\n\nimport torch\nimport torch_npu\n\nfrom sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE\nfrom sglang.srt.mem_cache.memory_pool import (\n    MHATokenToKVPool,\n    MLATokenToKVPool,\n    get_tensor_size_bytes,\n)\nfrom sglang.srt.utils import get_bool_env_var\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n\n\nclass NPUMHATokenToKVPool(MHATokenToKVPool):\n\n    def __init__(\n        self,\n        size: int,\n        page_size: int,\n        dtype: torch.dtype,\n        head_num: int,\n        head_dim: int,\n        layer_num: int,\n        device: str,\n        enable_memory_saver: bool,\n        start_layer: Optional[int] = None,\n        end_layer: Optional[int] = None,\n        enable_alt_stream: bool = True,\n        enable_kv_cache_copy: bool = False,\n    ):\n        self.use_fia = get_bool_env_var(\"ASCEND_USE_FIA\", \"False\")\n        super().__init__(\n            size=size,\n            page_size=page_size,\n            dtype=dtype,\n            head_num=head_num,\n            head_dim=head_dim,\n            layer_num=layer_num,\n            device=device,\n            enable_memory_saver=enable_memory_saver,\n            start_layer=start_layer,\n            end_layer=end_layer,\n            enable_alt_stream=enable_alt_stream,\n            enable_kv_cache_copy=enable_kv_cache_copy,\n        )\n\n    def _create_buffers(self):\n        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):\n            # [2, layer_num, size // page_size + 1, page_size, head_num, head_dim]:\n            # K and V for all layers in a single allocation.\n            # The padded slot 0 is used for writing dummy outputs from padded tokens.\n            # Contiguous memory improves the efficiency of Ascend's transmission backend,\n            # while other backends remain unchanged.\n            self.kv_buffer = torch.zeros(\n                (\n                    2,\n                    self.layer_num,\n                    self.size // 
self.page_size + 1,\n                    self.page_size,\n                    self.head_num,\n                    self.head_dim,\n                ),\n                dtype=self.store_dtype,\n                device=self.device,\n            )\n            self.k_buffer = self.kv_buffer[0]\n            self.v_buffer = self.kv_buffer[1]\n\n            if self.use_fia:\n                self.k_buffer = []\n                self.v_buffer = []\n                for i in range(self.layer_num):\n                    k_buffer_layer = self.kv_buffer[0][i].view(\n                        -1, 1, self.head_num, self.head_dim\n                    )\n                    v_buffer_layer = self.kv_buffer[1][i].view(\n                        -1, 1, self.head_num, self.head_dim\n                    )\n                    self.k_buffer.append(k_buffer_layer)\n                    self.v_buffer.append(v_buffer_layer)\n\n    # for disagg\n    def get_contiguous_buf_infos(self):\n        # layer_num x [seq_len, head_num, head_dim]\n        # layer_num x [page_num, page_size, head_num, head_dim]\n        kv_data_ptrs = [\n            self.get_key_buffer(i).data_ptr()\n            for i in range(self.start_layer, self.start_layer + self.layer_num)\n        ] + [\n            self.get_value_buffer(i).data_ptr()\n            for i in range(self.start_layer, self.start_layer + self.layer_num)\n        ]\n        kv_data_lens = [\n            self.get_key_buffer(i).nbytes\n            for i in range(self.start_layer, self.start_layer + self.layer_num)\n        ] + [\n            self.get_value_buffer(i).nbytes\n            for i in range(self.start_layer, self.start_layer + self.layer_num)\n        ]\n        kv_item_lens = [\n            self.get_key_buffer(i)[0].nbytes\n            for i in range(self.start_layer, self.start_layer + self.layer_num)\n        ] + [\n            self.get_value_buffer(i)[0].nbytes\n            for i in range(self.start_layer, self.start_layer + self.layer_num)\n        
]\n        return kv_data_ptrs, kv_data_lens, kv_item_lens\n\n    def set_kv_buffer(\n        self,\n        layer: \"RadixAttention\",\n        loc: torch.Tensor,\n        cache_k: torch.Tensor,\n        cache_v: torch.Tensor,\n        k_scale: Optional[float] = None,\n        v_scale: Optional[float] = None,\n        layer_id_override: Optional[int] = None,\n    ):\n        if layer_id_override is not None:\n            layer_id = layer_id_override\n        else:\n            layer_id = layer.layer_id\n        if cache_k.dtype != self.dtype:\n            if k_scale is not None:\n                cache_k.div_(k_scale)\n            if v_scale is not None:\n                cache_v.div_(v_scale)\n            cache_k = cache_k.to(self.dtype)\n            cache_v = cache_v.to(self.dtype)\n\n        if self.store_dtype != self.dtype:\n            cache_k = cache_k.view(self.store_dtype)\n            cache_v = cache_v.view(self.store_dtype)\n\n        if self.use_fia:\n            k_buffer_layer = self.k_buffer[layer_id - self.start_layer]\n            v_buffer_layer = self.v_buffer[layer_id - self.start_layer]\n\n            torch_npu.npu_scatter_nd_update_(\n                k_buffer_layer,\n                loc.view(-1, 1),\n                cache_k.view(-1, 1, self.head_num, self.head_dim),\n            )\n            torch_npu.npu_scatter_nd_update_(\n                v_buffer_layer,\n                loc.view(-1, 1),\n                cache_v.view(-1, 1, self.head_num, self.head_dim),\n            )\n        else:\n            loc = loc.to(torch.int32)\n            torch_npu._npu_reshape_and_cache(\n                key=cache_k,\n                value=cache_v,\n                key_cache=self.k_buffer[layer_id - self.start_layer].view(\n                    -1, self.page_size, self.head_num, self.head_dim\n                ),\n                value_cache=self.v_buffer[layer_id - self.start_layer].view(\n                    -1, self.page_size, self.head_num, self.head_dim\n    
            ),\n                slot_indices=loc,\n            )\n\n\nclass NPUMLATokenToKVPool(MLATokenToKVPool):\n\n    def __init__(\n        self,\n        size: int,\n        page_size: int,\n        dtype: torch.dtype,\n        kv_lora_rank: int,\n        qk_rope_head_dim: int,\n        index_head_dim: Optional[int],\n        layer_num: int,\n        device: str,\n        enable_memory_saver: bool,\n        start_layer: Optional[int] = None,\n        end_layer: Optional[int] = None,\n    ):\n        super(MLATokenToKVPool, self).__init__(\n            size=size,\n            page_size=page_size,\n            dtype=dtype,\n            layer_num=layer_num,\n            device=device,\n            enable_memory_saver=enable_memory_saver,\n            start_layer=start_layer,\n            end_layer=end_layer,\n        )\n\n        self.kv_lora_rank = kv_lora_rank\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.index_head_dim = index_head_dim\n\n        self.custom_mem_pool = None\n\n        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):\n            # The padded slot 0 is used for writing dummy outputs from padded tokens.\n            self.k_buffer = torch.zeros(\n                (\n                    layer_num,\n                    self.size // self.page_size + 1,\n                    self.page_size,\n                    1,\n                    self.kv_lora_rank,\n                ),\n                dtype=self.store_dtype,\n                device=self.device,\n            )\n            self.v_buffer = torch.zeros(\n                (\n                    layer_num,\n                    self.size // self.page_size + 1,\n                    self.page_size,\n                    1,\n                    self.qk_rope_head_dim,\n                ),\n                dtype=self.store_dtype,\n                device=self.device,\n            )\n            self.index_k_buffer = None\n            if self.index_head_dim is not None:\n        
        self.index_k_buffer = torch.zeros(\n                    (\n                        layer_num,\n                        self.size // self.page_size + 1,\n                        self.page_size,\n                        1,\n                        self.index_head_dim,\n                    ),\n                    dtype=self.store_dtype,\n                    device=self.device,\n                )\n\n        self._finalize_allocation_log(size)\n\n    def get_kv_size_bytes(self):\n        assert hasattr(self, \"k_buffer\")\n        assert hasattr(self, \"v_buffer\")\n        kv_size_bytes = 0\n        for k_cache in self.k_buffer:\n            kv_size_bytes += get_tensor_size_bytes(k_cache)\n        for v_cache in self.v_buffer:\n            kv_size_bytes += get_tensor_size_bytes(v_cache)\n        if self.index_head_dim is not None:\n            assert hasattr(self, \"index_k_buffer\")\n            for index_k_cache in self.index_k_buffer:\n                kv_size_bytes += get_tensor_size_bytes(index_k_cache)\n        return kv_size_bytes\n\n    def get_kv_buffer(self, layer_id: int):\n        if self.layer_transfer_counter is not None:\n            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)\n        return (\n            self.k_buffer[layer_id - self.start_layer],\n            self.v_buffer[layer_id - self.start_layer],\n        )\n\n    def get_key_buffer(self, layer_id: int):\n        if self.layer_transfer_counter is not None:\n            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)\n\n        if self.store_dtype != self.dtype:\n            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)\n        return self.k_buffer[layer_id - self.start_layer]\n\n    def get_value_buffer(self, layer_id: int):\n        if self.layer_transfer_counter is not None:\n            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)\n\n        if self.store_dtype != self.dtype:\n            return 
self.v_buffer[layer_id - self.start_layer].view(self.dtype)\n        return self.v_buffer[layer_id - self.start_layer]\n\n    def get_index_k_buffer(self, layer_id: int):\n        if self.layer_transfer_counter is not None:\n            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)\n\n        if self.store_dtype != self.dtype:\n            return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)\n        return self.index_k_buffer[layer_id - self.start_layer]\n\n    # for disagg\n    def get_contiguous_buf_infos(self):\n        # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.\n        kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [\n            self.v_buffer[i].data_ptr() for i in range(self.layer_num)\n        ]\n        kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [\n            self.v_buffer[i].nbytes for i in range(self.layer_num)\n        ]\n        kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [\n            self.v_buffer[i][0].nbytes for i in range(self.layer_num)\n        ]\n        if self.index_head_dim is not None:\n            kv_data_ptrs += [\n                self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)\n            ]\n            kv_data_lens += [\n                self.index_k_buffer[i].nbytes for i in range(self.layer_num)\n            ]\n            kv_item_lens += [\n                self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)\n            ]\n        return kv_data_ptrs, kv_data_lens, kv_item_lens\n\n    def set_kv_buffer(\n        self,\n        layer: \"RadixAttention\",\n        loc: torch.Tensor,\n        cache_k: torch.Tensor,\n        cache_v: Optional[torch.Tensor],\n    ):\n        layer_id = layer.layer_id\n        # When cache_v is None, cache_k packs [kv_lora_rank | qk_rope_head_dim];\n        # split it before the dtype conversions below so both parts are converted.\n        if cache_v is None:\n            cache_k, cache_v = cache_k.split(\n                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n            )\n\n        if cache_k.dtype != self.dtype:\n            cache_k = cache_k.to(self.dtype)\n            cache_v = cache_v.to(self.dtype)\n\n
        if self.store_dtype != self.dtype:\n            cache_k = cache_k.view(self.store_dtype)\n            cache_v = cache_v.view(self.store_dtype)\n\n        torch_npu.npu_scatter_nd_update_(\n            self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),\n            loc.view(-1, 1),\n            cache_k.view(-1, 1, self.kv_lora_rank),\n        )\n        torch_npu.npu_scatter_nd_update_(\n            self.v_buffer[layer_id - self.start_layer].view(\n                -1, 1, self.qk_rope_head_dim\n            ),\n            loc.view(-1, 1),\n            cache_v.view(-1, 1, self.qk_rope_head_dim),\n        )\n\n    def set_index_k_buffer(\n        self,\n        layer_id: int,\n        loc: torch.Tensor,\n        index_k: torch.Tensor,\n    ):\n        if index_k.dtype != self.dtype:\n            index_k = index_k.to(self.dtype)\n\n        if self.store_dtype != self.dtype:\n            index_k = index_k.view(self.store_dtype)\n\n        torch_npu.npu_scatter_nd_update_(\n            self.index_k_buffer[layer_id - self.start_layer].view(\n                -1, 1, self.index_head_dim\n            ),\n            loc.view(-1, 1),\n            index_k.view(-1, 1, self.index_head_dim),\n        )\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py",
    "content": "import re\nfrom typing import TYPE_CHECKING\n\nimport torch\nimport torch_npu\n\nfrom sglang.srt.environ import envs\nfrom sglang.srt.hardware_backend.npu.attention.mla_preprocess import (\n    NPUFusedMLAPreprocess,\n    is_fia_nz,\n    is_mla_preprocess_enabled,\n)\nfrom sglang.srt.layers.attention.nsa.nsa_indexer import scattered_to_tp_attn_full\nfrom sglang.srt.layers.attention.nsa.utils import (\n    nsa_use_prefill_cp,\n)\nfrom sglang.srt.layers.communicator import ScatterMode, get_attn_tp_context\n\nif TYPE_CHECKING:\n    from sglang.srt.model_executor.forward_batch_info import ForwardBatch\n    from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA\n    from sglang.srt.utils import BumpAllocator\n_use_ag_after_qlora = envs.SGLANG_USE_AG_AFTER_QLORA.get()\n\n\n# region MHA\ndef forward_mha_prepare_npu(\n    m: \"DeepseekV2AttentionMLA\",\n    positions: torch.Tensor,\n    hidden_states: torch.Tensor,\n    forward_batch: \"ForwardBatch\",\n    zero_allocator: \"BumpAllocator\",\n    layer_scatter_modes,\n):\n    if m.q_lora_rank is not None:\n        q, latent_cache = (\n            get_attn_tp_context()\n            .fetch_qkv_latent()\n            .split(\n                [m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim],\n                dim=-1,\n            )\n        )\n\n        # NSA Indexer: cache quantized keys, auto-skip topk for sequences <= nsa_index_topk\n\n        if m.use_nsa:\n            q_lora = m.q_a_layernorm(q)\n            q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim)\n            _ = m.indexer(\n                x=hidden_states,\n                q_lora=q_lora,\n                positions=positions,\n                forward_batch=forward_batch,\n                layer_id=m.layer_id,\n                return_indices=False,\n            )\n\n        else:\n            q = m.q_a_layernorm(q)\n            if (\n                _use_ag_after_qlora\n                and 
layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED\n                and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL\n            ):\n                q = scattered_to_tp_attn_full(q, forward_batch)\n                latent_cache = scattered_to_tp_attn_full(latent_cache, forward_batch)\n            q = m.q_b_proj(q)[0].view(-1, m.num_local_heads, m.qk_head_dim)\n\n    else:\n        q = m.q_proj(hidden_states)[0].view(-1, m.num_local_heads, m.qk_head_dim)\n        latent_cache = m.kv_a_proj_with_mqa(hidden_states)[0]\n\n    _, q_pe = q.split([m.qk_nope_head_dim, m.qk_rope_head_dim], dim=-1)\n    kv_a, _ = latent_cache.split([m.kv_lora_rank, m.qk_rope_head_dim], dim=-1)\n    latent_cache = latent_cache.unsqueeze(1)\n\n    if m.use_deepseek_yarn_rope:\n        B, S = q.shape[0], 1\n        cos, sin = m.rotary_emb.get_cos_sin_cache(\n            positions, hidden_states.dtype, offsets=None\n        )\n        q_pe = torch_npu.npu_interleave_rope(\n            q_pe.reshape(B, -1, S, m.qk_rope_head_dim),\n            cos,\n            sin,\n        )\n        q_pe = q_pe.reshape(B, -1, m.qk_rope_head_dim)\n\n        ckv_cache, k_rope_cache = forward_batch.token_to_kv_pool.get_kv_buffer(\n            m.layer_id\n        )\n        _, _, k_pe, kv_a = torch_npu.npu_kv_rmsnorm_rope_cache(\n            latent_cache.view(-1, 1, 1, m.kv_lora_rank + m.qk_rope_head_dim),  # bnsd\n            m.kv_a_layernorm.weight,\n            cos,\n            sin,\n            forward_batch.out_cache_loc.to(torch.int64),\n            k_rope_cache,\n            ckv_cache,\n            k_rope_scale=None,\n            c_kv_scale=None,\n            k_rope_offset=None,\n            c_kv_offset=None,\n            epsilon=m.kv_a_layernorm.variance_epsilon,\n            cache_mode=\"PA_NZ\" if is_fia_nz() else \"PA_BNSD\",\n            is_output_kv=True,\n        )  # adapter NZ\n\n        k_pe = k_pe.reshape(B, -1, m.qk_rope_head_dim)\n    else:\n        kv_a = 
m.kv_a_layernorm(kv_a)\n        k_pe = latent_cache[:, :, m.kv_lora_rank :]\n        if m.rotary_emb is not None:\n            q_pe, k_pe = m.rotary_emb(positions, q_pe, k_pe)\n        # this is for model kimi-vl-a3B-instruct\n        forward_batch.token_to_kv_pool.set_kv_buffer(\n            m, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe\n        )\n\n    q[..., m.qk_nope_head_dim :] = q_pe\n\n    kv = m.kv_b_proj(kv_a)[0]\n    kv = kv.view(-1, m.num_local_heads, m.qk_nope_head_dim + m.v_head_dim)\n    k_nope = kv[..., : m.qk_nope_head_dim]\n    v = kv[..., m.qk_nope_head_dim :]\n\n    k = m._concat_and_cast_mha_k(k_nope, k_pe, forward_batch)\n    return q, k, v, forward_batch\n\n\ndef forward_mha_core_npu(\n    m: \"DeepseekV2AttentionMLA\",\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    forward_batch: \"ForwardBatch\",\n) -> torch.Tensor:\n    attn_output = m.attn_mha(q, k, v, forward_batch, save_kv_cache=False)\n    attn_output = attn_output.reshape(-1, m.num_local_heads * m.v_head_dim)\n    output, _ = m.o_proj(attn_output)\n    return output\n\n\n# endregion\n\n\n# region MLA\ndef forward_mla_prepare_npu(\n    m: \"DeepseekV2AttentionMLA\",\n    positions: torch.Tensor,\n    hidden_states: torch.Tensor,\n    forward_batch: \"ForwardBatch\",\n    zero_allocator: \"BumpAllocator\",\n    layer_scatter_modes,\n):\n    if is_mla_preprocess_enabled():\n        if not hasattr(m, \"mla_preprocess\"):\n            m.mla_preprocess = NPUFusedMLAPreprocess(\n                m.fused_qkv_a_proj_with_mqa,\n                m.q_a_layernorm,\n                m.kv_a_layernorm,\n                m.q_b_proj,\n                m.w_kc,\n                m.rotary_emb,\n                m.layer_id,\n                m.num_local_heads,\n                m.qk_nope_head_dim,\n                m.qk_rope_head_dim,\n                m.quant_config,\n            )\n        (\n            q_pe,\n            k_pe,\n            q_nope_out,\n            k_nope,\n       
     forward_batch,\n            zero_allocator,\n            positions,\n        ) = m.mla_preprocess.forward(\n            positions, hidden_states, forward_batch, zero_allocator\n        )\n        topk_indices = None\n    else:\n        q_lora = None\n        if m.q_lora_rank is not None:\n            q, latent_cache = (\n                get_attn_tp_context()\n                .fetch_qkv_latent()\n                .split(\n                    [m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim],\n                    dim=-1,\n                )\n            )\n            k_nope = latent_cache[..., : m.kv_lora_rank]\n\n            q = m.q_a_layernorm(q)\n            if (\n                _use_ag_after_qlora\n                and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED\n                and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL\n            ):\n                q = scattered_to_tp_attn_full(q, forward_batch)\n                latent_cache = scattered_to_tp_attn_full(latent_cache, forward_batch)\n            k_nope = m.kv_a_layernorm(k_nope)\n\n            # q_lora needed by indexer\n            if m.use_nsa:\n                q_lora = q\n\n            k_nope = k_nope.unsqueeze(1)\n            q = m.q_b_proj(q)[0].view(-1, m.num_local_heads, m.qk_head_dim)\n        else:\n            q = m.q_proj(hidden_states)[0].view(-1, m.num_local_heads, m.qk_head_dim)\n            latent_cache = m.kv_a_proj_with_mqa(hidden_states)[0]\n            k_nope = latent_cache[..., : m.kv_lora_rank]\n            k_nope = m.kv_a_layernorm(k_nope).unsqueeze(1)\n\n        q_nope, q_pe = q.split([m.qk_nope_head_dim, m.qk_rope_head_dim], dim=-1)\n        k_pe = latent_cache[..., m.kv_lora_rank :].unsqueeze(1)\n\n        q_nope_out = torch.bmm(q_nope.transpose(0, 1), m.w_kc)\n\n        q_nope_out = q_nope_out.transpose(0, 1)\n\n        q_pe, k_pe = m.rotary_emb(positions, q_pe, k_pe)\n\n        if nsa_use_prefill_cp(forward_batch):\n            # support 
allgather+rearrange\n            k_nope, k_pe = m.rebuild_cp_kv_cache(\n                latent_cache, forward_batch, k_nope, k_pe\n            )\n        topk_indices = None\n        if q_lora is not None:\n            topk_indices = m.indexer(\n                x=hidden_states,\n                q_lora=q_lora,\n                positions=positions,\n                forward_batch=forward_batch,\n                layer_id=m.layer_id,\n            )\n\n    return (\n        q_pe,\n        k_pe,\n        q_nope_out,\n        k_nope,\n        forward_batch,\n        zero_allocator,\n        positions,\n        topk_indices,\n    )\n\n\ndef forward_mla_core_npu(\n    m: \"DeepseekV2AttentionMLA\",\n    q_pe: torch.Tensor,\n    k_pe: torch.Tensor,\n    q_nope_out: torch.Tensor,\n    k_nope: torch.Tensor,\n    forward_batch: \"ForwardBatch\",\n    zero_allocator: \"BumpAllocator\",\n    positions: torch.Tensor,\n    topk_indices: torch.Tensor,\n) -> torch.Tensor:\n    attn_output = m.attn_mqa(\n        q_nope_out,\n        k_nope,\n        k_nope,\n        forward_batch,\n        q_rope=q_pe,\n        k_rope=k_pe,\n        **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),\n    )\n\n    attn_output = attn_output.view(-1, m.num_local_heads, m.kv_lora_rank)\n\n    attn_bmm_output = torch.empty(\n        (attn_output.shape[0], m.num_local_heads, m.v_head_dim),\n        dtype=attn_output.dtype,\n        device=attn_output.device,\n    )\n\n    attn_output = attn_output.contiguous()\n    torch.ops.npu.batch_matmul_transpose(attn_output, m.w_vc, attn_bmm_output)\n\n    attn_bmm_output = attn_bmm_output.reshape(-1, m.num_local_heads * m.v_head_dim)\n    output, _ = m.o_proj(attn_bmm_output)\n\n    return output\n\n\n# endregion\n\n\n# region DSA\ndef forward_dsa_prepare_npu(\n    m: \"DeepseekV2AttentionMLA\",\n    positions: torch.Tensor,\n    hidden_states: torch.Tensor,\n    forward_batch: \"ForwardBatch\",\n    zero_allocator: \"BumpAllocator\",\n    
layer_scatter_modes,\n):\n    dynamic_scale = None\n    if is_mla_preprocess_enabled() and forward_batch.forward_mode.is_decode():\n        (\n            q_pe,\n            k_pe,\n            q_nope_out,\n            k_nope,\n            q_lora,\n            forward_batch,\n            zero_allocator,\n            positions,\n            dynamic_scale,\n        ) = npu_mla_preprocess(\n            m,\n            hidden_states,\n            positions,\n            forward_batch,\n            zero_allocator,\n        )\n    else:\n        fused_qkv_a_proj_out = m.fused_qkv_a_proj_with_mqa(hidden_states)[0]\n        q, latent_cache = fused_qkv_a_proj_out.split(\n            [m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim], dim=-1\n        )\n\n        # overlap qk norm\n        q = m.q_a_layernorm(q)\n        if (\n            _use_ag_after_qlora\n            and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED\n            and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL\n        ):\n            q = scattered_to_tp_attn_full(q, forward_batch)\n            latent_cache = scattered_to_tp_attn_full(latent_cache, forward_batch)\n        q_lora = q.clone()  # required for topk_indices\n\n        q_event = None\n        if m.alt_stream is not None:\n            m.alt_stream.wait_stream(torch.npu.current_stream())\n            with torch.npu.stream(m.alt_stream):\n                q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim)\n                # record q to ensure memory space will not be released\n                q.record_stream(m.alt_stream)\n                q_event = m.alt_stream.record_event()\n        else:\n            q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim)\n\n        k_nope, k_pe = latent_cache.unsqueeze(1).split(\n            [m.kv_lora_rank, m.qk_rope_head_dim], dim=-1\n        )\n        k_nope = m.kv_a_layernorm(k_nope)\n        # main stream waits for the completion of the event 
on the alt stream to ensure data dependency is complete\n        if q_event is not None:\n            torch.npu.current_stream().wait_event(q_event)\n\n        q_nope, q_pe = q.split([m.qk_nope_head_dim, m.qk_rope_head_dim], dim=-1)\n\n        q_nope_out = torch.bmm(q_nope.transpose(0, 1), m.w_kc)\n\n        q_nope_out = q_nope_out.transpose(0, 1)\n\n        q_pe, k_pe = m.rotary_emb(positions, q_pe, k_pe)\n\n        if nsa_use_prefill_cp(forward_batch):\n            # support allgather+rearrange\n            k_nope, k_pe = m.rebuild_cp_kv_cache(\n                latent_cache, forward_batch, k_nope, k_pe\n            )\n\n    topk_indices = m.indexer(\n        hidden_states,\n        q_lora,\n        positions,\n        forward_batch,\n        m.layer_id,\n        layer_scatter_modes,\n        dynamic_scale,\n    )\n\n    return (\n        q_pe,\n        k_pe,\n        q_nope_out,\n        k_nope,\n        topk_indices,\n        forward_batch,\n        zero_allocator,\n        positions,\n    )\n\n\ndef forward_dsa_core_npu(\n    m: \"DeepseekV2AttentionMLA\",\n    q_pe: torch.Tensor,\n    k_pe: torch.Tensor,\n    q_nope_out: torch.Tensor,\n    k_nope: torch.Tensor,\n    topk_indices: torch.Tensor,\n    forward_batch: \"ForwardBatch\",\n    zero_allocator: \"BumpAllocator\",\n    positions: torch.Tensor,\n) -> torch.Tensor:\n    attn_output = m.attn_mqa(\n        q_nope_out.contiguous(),\n        k_nope.contiguous(),\n        k_nope.contiguous(),\n        forward_batch,\n        save_kv_cache=True,  # False if forward_batch.forward_mode.is_extend() else True,\n        q_rope=q_pe.contiguous(),\n        k_rope=k_pe.contiguous(),\n        topk_indices=topk_indices,\n    )\n    attn_output = attn_output.view(-1, m.num_local_heads, m.kv_lora_rank)\n\n    attn_bmm_output = torch.empty(\n        (attn_output.shape[0], m.num_local_heads, m.v_head_dim),\n        dtype=attn_output.dtype,\n        device=attn_output.device,\n    )\n\n    if (\n        
forward_batch.forward_mode.is_extend()\n        and not forward_batch.forward_mode.is_draft_extend(include_v2=True)\n        and not forward_batch.forward_mode.is_target_verify()\n    ):\n        attn_output = attn_output.transpose(0, 1)\n        torch.bmm(\n            attn_output,\n            m.w_vc,\n            out=attn_bmm_output.view(-1, m.num_local_heads, m.v_head_dim).transpose(\n                0, 1\n            ),\n        )\n    else:\n        attn_output = attn_output.contiguous()\n        torch.ops.npu.batch_matmul_transpose(attn_output, m.w_vc, attn_bmm_output)\n\n    attn_bmm_output = attn_bmm_output.reshape(-1, m.num_local_heads * m.v_head_dim)\n\n    output, _ = m.o_proj(attn_bmm_output)\n    return output\n\n\ndef npu_mla_preprocess(\n    m: \"DeepseekV2AttentionMLA\",\n    hidden_states: torch.Tensor,\n    positions: torch.Tensor,\n    forward_batch: \"ForwardBatch\",\n    zero_allocator: \"BumpAllocator\",\n):\n    dynamic_scale = None\n    if not hasattr(m, \"mla_preprocess\"):\n        m.mla_preprocess = NPUFusedMLAPreprocess(\n            m.fused_qkv_a_proj_with_mqa,\n            m.q_a_layernorm,\n            m.kv_a_layernorm,\n            m.q_b_proj,\n            m.w_kc,\n            m.rotary_emb,\n            m.layer_id,\n            m.num_local_heads,\n            m.qk_nope_head_dim,\n            m.qk_rope_head_dim,\n            m.v_head_dim,\n            m.quant_config,\n        )\n    # mlaprolog does not require additional calculation of q_lora\n    _is_mlaprolog = hasattr(m.quant_config, \"ignore\") and any(\n        re.fullmatch(r\".*kv_b_proj\", l) for l in m.quant_config.ignore\n    )\n    if _is_mlaprolog:\n        (\n            q_pe,\n            k_pe,\n            q_nope_out,\n            k_nope,\n            q_lora,\n            forward_batch,\n            positions,\n            dynamic_scale,\n        ) = m.mla_preprocess.forward(\n            positions, hidden_states, forward_batch, zero_allocator\n        )\n    else:\n    
    if m.alt_stream is not None:\n            mla_event = torch.npu.Event()\n            mla_event.record()\n            with torch.npu.stream(m.alt_stream):\n                # alt stream waits for the completion of the event on the main stream to ensure data dependency is complete\n                torch.npu.current_stream().wait_event(mla_event)\n                (\n                    q_pe,\n                    k_pe,\n                    q_nope_out,\n                    k_nope,\n                    forward_batch,\n                    zero_allocator,\n                    positions,\n                ) = m.mla_preprocess.forward(\n                    positions, hidden_states, forward_batch, zero_allocator\n                )\n\n            fused_qkv_a_proj_out = m.fused_qkv_a_proj_with_mqa(hidden_states)[0]\n            q, _ = fused_qkv_a_proj_out.split(\n                [m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim], dim=-1\n            )\n            q_lora = m.q_a_layernorm(q)\n            # main stream waits for all work queued on the alt stream (wait_event expects an Event, not a Stream)\n            torch.npu.current_stream().wait_stream(m.alt_stream)\n        else:\n            (\n                q_pe,\n                k_pe,\n                q_nope_out,\n                k_nope,\n                forward_batch,\n                zero_allocator,\n                positions,\n            ) = m.mla_preprocess.forward(\n                positions, hidden_states, forward_batch, zero_allocator\n            )\n            fused_qkv_a_proj_out = m.fused_qkv_a_proj_with_mqa(hidden_states)[0]\n            q, _ = fused_qkv_a_proj_out.split(\n                [m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim], dim=-1\n            )\n            q_lora = m.q_a_layernorm(q)\n\n    return (\n        q_pe,\n        k_pe,\n        q_nope_out,\n        k_nope,\n        q_lora,\n        forward_batch,\n        zero_allocator,\n        positions,\n        dynamic_scale,\n    )\n\n\n# endregion\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/modules/qwen_vl_processor.py",
    "content": "from typing import Optional\n\nimport torch\nimport torchvision.transforms.v2.functional as tvF\nfrom transformers.image_processing_utils import BatchFeature\nfrom transformers.image_processing_utils_fast import (\n    group_images_by_shape,\n    reorder_images,\n)\nfrom transformers.image_utils import SizeDict\nfrom transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize\nfrom transformers.utils import TensorType\n\nfrom sglang.srt.utils import apply_module_patch\n\n\n# Func refers to transformers.models.qwen2_vl.image_processing_qwen2_vl_fast.py\n# Qwen2VLImageProcessorFast._preprocess\ndef npu_wrapper_preprocess(func):\n\n    def _preprocess(\n        self,\n        images: list[\"torch.Tensor\"],\n        do_resize: bool,\n        size: SizeDict,\n        interpolation: Optional[\"tvF.InterpolationMode\"],\n        do_rescale: bool,\n        rescale_factor: float,\n        do_normalize: bool,\n        image_mean: float | list[float] | None,\n        image_std: float | list[float] | None,\n        patch_size: int,\n        temporal_patch_size: int,\n        merge_size: int,\n        disable_grouping: bool | None,\n        return_tensors: str | TensorType | None,\n        **kwargs,\n    ):\n        # Group images by size for batched resizing\n        grouped_images, grouped_images_index = group_images_by_shape(\n            images, disable_grouping=disable_grouping\n        )\n        resized_images_grouped = {}\n        for shape, stacked_images in grouped_images.items():\n            height, width = stacked_images.shape[-2:]\n            if do_resize:\n                resized_height, resized_width = smart_resize(\n                    height,\n                    width,\n                    factor=patch_size * merge_size,\n                    min_pixels=size[\"shortest_edge\"],\n                    max_pixels=size[\"longest_edge\"],\n                )\n                stacked_images = self.resize(\n                    
image=stacked_images,\n                    size=SizeDict(height=resized_height, width=resized_width),\n                    interpolation=interpolation,\n                )\n            resized_images_grouped[shape] = stacked_images\n        resized_images = reorder_images(resized_images_grouped, grouped_images_index)\n\n        # Group images by size for further processing\n        # Needed in case do_resize is False, or resize returns images with different sizes\n        grouped_images, grouped_images_index = group_images_by_shape(\n            resized_images, disable_grouping=disable_grouping\n        )\n        processed_images_grouped = {}\n        processed_grids = {}\n        for shape, stacked_images in grouped_images.items():\n            resized_height, resized_width = stacked_images.shape[-2:]\n            # Fused rescale and normalize\n            patches = self.rescale_and_normalize(\n                stacked_images,\n                do_rescale,\n                rescale_factor,\n                do_normalize,\n                image_mean,\n                image_std,\n            )\n            if patches.ndim == 4:\n                # add a temporal dimension if we have images\n                patches = patches.unsqueeze(1)\n            if patches.shape[1] % temporal_patch_size != 0:\n                repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)\n                patches = torch.cat([patches, repeats], dim=1)\n            batch_size, grid_t, channel = patches.shape[:3]\n            grid_t = grid_t // temporal_patch_size\n            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size\n\n            ######################################\n            # Start of modifications for sglang  #\n            ######################################\n            patches = patches.view(\n                batch_size * grid_t,\n                temporal_patch_size * channel,\n                grid_h // merge_size,\n               
 merge_size,\n                patch_size,\n                grid_w // merge_size,\n                merge_size,\n                patch_size,\n            )\n            patches = patches.permute(0, 1, 2, 5, 3, 6, 4, 7)\n            patches = patches.reshape(\n                batch_size,\n                grid_t,\n                temporal_patch_size,\n                channel,\n                grid_h * grid_w,\n                patch_size,\n                patch_size,\n            )\n            patches = patches.permute(0, 1, 4, 3, 2, 5, 6)\n            flatten_patches = patches.reshape(\n                batch_size,\n                grid_t * grid_h * grid_w,\n                -1,\n            )\n            ######################################\n            #  End of modifications for sglang   #\n            ######################################\n\n            processed_images_grouped[shape] = flatten_patches\n            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size\n\n        processed_images = reorder_images(\n            processed_images_grouped, grouped_images_index\n        )\n        processed_grids = reorder_images(processed_grids, grouped_images_index)\n        pixel_values = torch.cat(processed_images, dim=0)\n        image_grid_thw = torch.tensor(processed_grids)\n\n        return BatchFeature(\n            data={\"pixel_values\": pixel_values, \"image_grid_thw\": image_grid_thw},\n            tensor_type=return_tensors,\n        )\n\n    return _preprocess\n\n\n_npu_preprocess_patched = False\n\n\ndef npu_apply_qwen_image_preprocess_patch():\n    global _npu_preprocess_patched\n    if _npu_preprocess_patched:\n        return\n    apply_module_patch(\n        \"transformers.models.qwen2_vl.image_processing_qwen2_vl_fast.Qwen2VLImageProcessorFast\",\n        \"_preprocess\",\n        [npu_wrapper_preprocess],\n    )\n    _npu_preprocess_patched = True\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/moe/topk.py",
    "content": "from typing import TYPE_CHECKING, Optional\n\nimport torch\nfrom sgl_kernel_npu.norm.l1_norm import l1_norm\n\nfrom sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder\nfrom sglang.srt.eplb.expert_location_dispatch import topk_ids_logical_to_physical\nfrom sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer\nfrom sglang.srt.layers.moe.topk import StandardTopKOutput, select_experts\n\nif TYPE_CHECKING:\n    from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo\n    from sglang.srt.layers.moe.topk import TopKConfig, TopKOutput\n\n\ndef fused_topk_npu(\n    hidden_states: torch.Tensor,\n    router_logits: torch.Tensor,\n    topk_config: \"TopKConfig\",\n    num_token_non_padded: Optional[torch.Tensor] = None,\n    expert_location_dispatch_info: Optional[\"ExpertLocationDispatchInfo\"] = None,\n    layer_id: Optional[int] = None,\n) -> \"TopKOutput\":\n\n    use_grouped_topk = topk_config.use_grouped_topk\n    renormalize = topk_config.renormalize\n    correction_bias = topk_config.correction_bias\n\n    if not use_grouped_topk:\n        topk_weights, topk_ids, _ = torch.ops.npu.npu_moe_gating_top_k_softmax(\n            router_logits,\n            k=topk_config.top_k,\n        )\n\n        if renormalize:\n            topk_weights = l1_norm(\n                topk_weights\n                if topk_config.num_fused_shared_experts == 0\n                else topk_weights[:, :-1]\n            )\n        topk_weights = topk_weights.to(torch.float32)\n\n    elif use_grouped_topk and correction_bias is not None:\n        # Force set routed_scaling_factor = 1 to optimize renormalize\n        topk_weights, topk_ids, _ = torch.ops.npu.npu_moe_gating_top_k(\n            router_logits.to(torch.float32),\n            k=topk_config.top_k,\n            bias=correction_bias.to(torch.float32),\n            k_group=topk_config.topk_group,\n            
group_count=topk_config.num_expert_group,\n            group_select_mode=1,\n            renorm=0,\n            norm_type=1,\n            routed_scaling_factor=(\n                1 if renormalize else topk_config.routed_scaling_factor\n            ),\n            eps=float(1e-20),\n        )\n\n    else:\n        topk_config.torch_native = True\n        return select_experts(\n            hidden_states=hidden_states,\n            layer_id=layer_id,\n            router_logits=router_logits,\n            topk_config=topk_config,\n            num_token_non_padded=num_token_non_padded,\n            expert_location_dispatch_info=expert_location_dispatch_info,\n        )\n\n    if expert_location_dispatch_info is not None:\n        topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)\n    get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)\n    get_global_experts_capturer().capture(\n        layer_id=layer_id,\n        topk_ids=topk_ids,\n    )\n\n    return StandardTopKOutput(topk_weights, topk_ids, router_logits)\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py",
    "content": "from typing import TYPE_CHECKING, Optional\n\nimport numpy as np\nimport torch\n\nfrom sglang.srt.hardware_backend.npu.utils import npu_format_cast\nfrom sglang.srt.layers.quantization.base_config import FusedMoEMethodBase\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.moe.token_dispatcher import (\n        CombineInput,\n        StandardDispatchOutput,\n    )\n    from sglang.srt.layers.quantization.base_config import QuantizationConfig\n\n\ndef npu_fused_experts_w4a4(\n    hidden_states: torch.Tensor,\n    w13: torch.Tensor,\n    w13_scale: torch.Tensor,\n    w2: torch.Tensor,\n    w2_scale: torch.Tensor,\n    topk_weights: torch.Tensor,\n    topk_ids: torch.Tensor,\n    top_k: int,\n):\n    original_shape = hidden_states.shape\n    original_dtype = hidden_states.dtype\n    scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32\n    if len(original_shape) == 3:\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n    num_tokens = hidden_states.shape[0]\n    num_experts = w13.shape[0]\n\n    hidden_states, expanded_row_idx, expert_tokens, _ = (\n        torch.ops.npu.npu_moe_init_routing_v2(\n            hidden_states,\n            topk_ids,\n            active_num=num_tokens * top_k,\n            expert_num=num_experts,\n            expert_tokens_num_type=1,\n            expert_tokens_num_flag=True,\n            active_expert_range=[0, num_experts],\n            quant_mode=-1,\n        )\n    )\n    expert_tokens = expert_tokens.to(torch.int64)\n\n    # gmm1: gate_up_proj\n    hidden_states, pertoken_scale = torch.ops.npu.npu_dynamic_quant(\n        hidden_states, dst_type=torch.quint4x2\n    )\n    scale_args13 = {\n        \"scale\": [w13_scale],\n        \"per_token_scale\": [pertoken_scale],\n    }\n\n    hidden_states = torch.ops.npu.npu_grouped_matmul(\n        x=[hidden_states],\n        weight=[w13],\n        **scale_args13,\n        split_item=2,\n        group_list_type=1,\n        
group_type=0,\n        group_list=expert_tokens,\n        output_dtype=original_dtype,\n    )[0]\n    # act_fn: swiglu\n    hidden_states = torch.ops.npu.npu_swiglu(hidden_states)\n    hidden_states, pertoken_scale = torch.ops.npu.npu_dynamic_quant(hidden_states)\n\n    scale_args2 = {\n        \"scale\": [w2_scale.to(scale_dtype)],\n        \"per_token_scale\": [pertoken_scale],\n    }\n    # gmm2: down_proj\n    hidden_states = torch.ops.npu.npu_grouped_matmul(\n        x=[hidden_states],\n        weight=[w2],\n        **scale_args2,\n        split_item=2,\n        group_list_type=1,\n        group_type=0,\n        group_list=expert_tokens,\n        output_dtype=original_dtype,\n    )[0]\n\n    final_hidden_states = torch.ops.npu.npu_moe_finalize_routing(\n        hidden_states,\n        skip1=None,\n        skip2=None,\n        bias=None,\n        scales=topk_weights,\n        expanded_src_to_dst_row=expanded_row_idx,\n        export_for_source_row=topk_ids,\n        drop_pad_mode=2,\n    )\n    if len(original_shape) == 3:\n        final_hidden_states = final_hidden_states.view(original_shape)\n    return final_hidden_states\n\n\ndef npu_fused_experts(\n    hidden_states: torch.Tensor,\n    w13: torch.Tensor,\n    w13_scale: torch.Tensor,\n    w2: torch.Tensor,\n    w2_scale: torch.Tensor,\n    topk_weights: torch.Tensor,\n    topk_ids: torch.Tensor,\n    top_k: int,\n    **kwargs,\n):\n    w13_offset = kwargs.get(\"w13_offset\", None)\n    w2_offset = kwargs.get(\"w2_offset\", None)\n    use_wna16 = kwargs.get(\"use_wna16\", False)\n\n    original_shape = hidden_states.shape\n    original_dtype = hidden_states.dtype\n    scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32\n    if len(original_shape) == 3:\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n    num_tokens = hidden_states.shape[0]\n    num_experts = w13.shape[0]\n    row_idx_len = num_tokens * top_k\n    row_idx = (\n        torch.arange(0, 
row_idx_len, dtype=torch.int32, device=topk_weights.device)\n        .view(top_k, -1)\n        .permute(1, 0)\n        .contiguous()\n    )\n    hidden_states, expanded_row_idx, expanded_expert_idx = (\n        torch.ops.npu.npu_moe_init_routing(\n            hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens\n        )\n    )\n    expert_tokens = torch.ops.npu.npu_moe_compute_expert_tokens(\n        expanded_expert_idx, num_experts\n    )\n    expert_tokens = expert_tokens.to(torch.int64)\n    # gmm1: gate_up_proj\n    if not use_wna16:\n        hidden_states, pertoken_scale = torch.ops.npu.npu_dynamic_quant(hidden_states)\n        scale_args13 = {\n            \"scale\": [w13_scale.to(scale_dtype)],\n            \"per_token_scale\": [pertoken_scale],\n        }\n    else:\n        scale_args13 = {\n            \"antiquant_scale\": [w13_scale],\n            \"antiquant_offset\": [w13_offset],\n        }\n\n    hidden_states = torch.ops.npu.npu_grouped_matmul(\n        x=[hidden_states],\n        weight=[w13],\n        **scale_args13,\n        split_item=2,\n        group_list_type=0,\n        group_type=0,\n        group_list=expert_tokens,\n        output_dtype=original_dtype,\n    )[0]\n    # act_fn: swiglu\n    if not use_wna16:\n        hidden_states, pertoken_scale = torch.ops.npu.npu_dequant_swiglu_quant(\n            hidden_states,\n            activate_left=True,\n            quant_mode=1,\n        )\n\n        scale_args2 = {\n            \"scale\": [w2_scale.to(scale_dtype)],\n            \"per_token_scale\": [pertoken_scale],\n        }\n    else:\n        hidden_states = torch.ops.npu.npu_swiglu(hidden_states)\n        scale_args2 = {\"antiquant_scale\": [w2_scale], \"antiquant_offset\": [w2_offset]}\n    # gmm2: down_proj\n    hidden_states = torch.ops.npu.npu_grouped_matmul(\n        x=[hidden_states],\n        weight=[w2],\n        **scale_args2,\n        split_item=2,\n        group_list_type=0,\n        group_type=0,\n      
  group_list=expert_tokens,\n        output_dtype=original_dtype,\n    )[0]\n\n    final_hidden_states = torch.ops.npu.npu_moe_finalize_routing(\n        hidden_states,\n        skip1=None,\n        skip2=None,\n        bias=None,\n        scales=topk_weights,\n        expanded_src_to_dst_row=expanded_row_idx,\n        export_for_source_row=topk_ids,\n    )\n    if len(original_shape) == 3:\n        final_hidden_states = final_hidden_states.view(original_shape)\n    return final_hidden_states\n\n\ndef npu_fused_experts_w8a8_decode(\n    hidden_states: torch.Tensor,\n    w13: torch.Tensor,\n    w13_scale: torch.Tensor,\n    w2: torch.Tensor,\n    w2_scale: torch.Tensor,\n    topk_weights: torch.Tensor,\n    topk_ids: torch.Tensor,\n    top_k: int,\n    **kwargs,\n):\n    num_tokens = hidden_states.shape[:-1].numel()\n    first_expert_idx = 0\n    last_expert_idx = w13.shape[0]\n    global_num_experts = w13.shape[0]\n    original_shape = hidden_states.shape\n    group_list_type = 1\n\n    sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = (\n        torch.ops.npu.npu_moe_init_routing_v2(\n            hidden_states,\n            topk_ids,\n            active_num=num_tokens * top_k,\n            expert_num=global_num_experts,\n            expert_tokens_num_type=group_list_type,\n            expert_tokens_num_flag=True,\n            active_expert_range=[first_expert_idx, last_expert_idx],\n            quant_mode=1,\n        )\n    )\n\n    hidden_states = torch.ops.npu.npu_grouped_matmul(\n        x=[sorted_hidden_states],\n        weight=[w13],\n        scale=[w13_scale],\n        per_token_scale=[pertoken_scale],\n        group_list=expert_tokens,\n        split_item=2,\n        group_type=0,\n        group_list_type=group_list_type,\n        output_dtype=torch.bfloat16,\n    )[0]\n\n    # act_fn: swiglu\n    hidden_states, swiglu_out_scale = torch.ops.npu.npu_dequant_swiglu_quant(\n        hidden_states, quant_mode=1, activate_left=True\n    
)\n\n    output = torch.ops.npu.npu_grouped_matmul(\n        x=[hidden_states],\n        weight=[w2],\n        scale=[w2_scale],\n        per_token_scale=[swiglu_out_scale],\n        group_list=expert_tokens,\n        split_item=2,\n        group_type=0,\n        group_list_type=group_list_type,\n        output_dtype=torch.bfloat16,\n    )[0]\n\n    assert original_shape is not None\n    final_hidden_states = torch.ops.npu.npu_moe_token_unpermute(\n        permuted_tokens=output,\n        sorted_indices=torch.abs(expanded_row_idx),\n        probs=topk_weights,\n    )\n    if len(original_shape) == 3:\n        final_hidden_states = final_hidden_states.view(original_shape)\n\n    return final_hidden_states\n\n\ndef npu_fused_moe_without_routing_weights_bf16(\n    layer, hidden_states, group_list_type, group_list, output_dtype\n):\n    from sgl_kernel_npu.activation.swiglu_quant import swiglu_quant\n\n    # gmm1: gate_up_proj\n    hidden_states = torch.ops.npu.npu_grouped_matmul(\n        x=[hidden_states],\n        weight=[layer.w13_weight],\n        split_item=2,\n        group_list_type=group_list_type,\n        group_type=0,\n        group_list=group_list,\n        output_dtype=output_dtype,\n    )[0]\n    hidden_states, _ = swiglu_quant(\n        hidden_states, group_list, group_list_type, need_quant=False\n    )\n    # gmm2: down_proj\n    hidden_states = torch.ops.npu.npu_grouped_matmul(\n        x=[hidden_states],\n        weight=[layer.w2_weight],\n        split_item=2,\n        group_list_type=group_list_type,\n        group_type=0,\n        group_list=group_list,\n        output_dtype=output_dtype,\n    )[0]\n    return hidden_states\n\n\ndef fused_moe_npu(\n    x,\n    w1,\n    w2,\n    topk_output,\n    moe_runner_config,\n):\n    # TODO: reuse the codes of UnquantizedFusedMoEMethod-forward_npu\n    topk_weights, topk_ids, _ = topk_output\n    original_dtype = x.dtype\n    num_tokens = x.shape[0]\n    topk_weights = topk_weights.to(x.dtype)\n    topk_ids 
= topk_ids.to(torch.int32)\n    num_experts = w1.shape[0]\n    top_k = topk_weights.shape[-1]\n    row_idx_len = num_tokens * top_k\n    row_idx = (\n        torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device)\n        .view(top_k, -1)\n        .permute(1, 0)\n        .contiguous()\n    )\n\n    hidden_states, expanded_row_idx, expanded_expert_idx = (\n        torch.ops.npu.npu_moe_init_routing(\n            x, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens\n        )\n    )\n\n    expert_tokens = torch.ops.npu.npu_moe_compute_expert_tokens(\n        expanded_expert_idx, num_experts\n    )\n\n    expert_tokens = expert_tokens.to(torch.int64)\n\n    # gmm1: gate_up_proj\n    hidden_states = torch.ops.npu.npu_grouped_matmul(\n        x=[hidden_states],\n        weight=[w1.permute(0, 2, 1)],\n        bias=None,\n        split_item=2,\n        group_list_type=0,\n        group_type=0,\n        group_list=expert_tokens,\n        output_dtype=original_dtype,\n    )[0]\n\n    # act_fn:\n    if moe_runner_config.activation == \"silu\":\n        hidden_states = torch.ops.npu.npu_swiglu(hidden_states)\n    else:\n        from sglang.srt.layers.activation import GeluAndMul\n\n        hidden_states = GeluAndMul()(hidden_states)\n\n    # gmm2: down_proj\n    hidden_states = torch.ops.npu.npu_grouped_matmul(\n        x=[hidden_states],\n        weight=[w2.permute(0, 2, 1)],\n        bias=None,\n        split_item=2,\n        group_list_type=0,\n        group_type=0,\n        group_list=expert_tokens,\n        output_dtype=original_dtype,\n    )[0]\n\n    final_hidden_states = torch.ops.npu.npu_moe_finalize_routing(\n        hidden_states,\n        skip1=None,\n        skip2=None,\n        bias=None,\n        scales=topk_weights,\n        expanded_src_to_dst_row=expanded_row_idx,\n        export_for_source_row=topk_ids,\n    )\n    return final_hidden_states\n\n\nclass _NPUFusedMoEMethodBase(FusedMoEMethodBase):\n\n    def __init__(\n        
self,\n        quant_config: Optional[\"QuantizationConfig\"] = None,\n    ):\n        self.quant_config = quant_config\n\n\nclass NPUW4A4Int4DynamicMoEMethod(_NPUFusedMoEMethodBase):\n\n    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:\n        layer.w13_weight.data = npu_format_cast(layer.w13_weight.data.transpose(1, 2))\n        layer.w13_weight.data = self._pack_to_int32(\n            layer.w13_weight.data.to(torch.int32)\n        )\n\n        layer.w2_weight.data = npu_format_cast(layer.w2_weight.data.transpose(1, 2))\n\n        scale_np = layer.w13_weight_scale.data.cpu().numpy()\n        scale_np.dtype = np.uint32\n        scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu()\n\n        layer.w13_weight_scale = torch.nn.Parameter(\n            scale_uint64_tensor.squeeze(-1), requires_grad=False\n        )\n        layer.w2_weight_scale = torch.nn.Parameter(\n            layer.w2_weight_scale.data.squeeze(-1), requires_grad=False\n        )\n\n        # Compressed-tensors format doesn't have this field\n        if hasattr(layer, \"w13_weight_offset\"):\n            layer.w13_weight_offset = torch.nn.Parameter(\n                layer.w13_weight_offset.data.squeeze(-1),\n                requires_grad=False,\n            )\n        if hasattr(layer, \"w2_weight_offset\"):\n            layer.w2_weight_offset = torch.nn.Parameter(\n                layer.w2_weight_offset.data.squeeze(-1),\n                requires_grad=False,\n            )\n\n    def _pack_to_int32(self, weight: torch.Tensor):\n        # each int32 element holds one int4 value; pack 8 of them into one int32\n        assert (\n            weight.shape[-1] % 8 == 0\n        ), \"the last dim of weight must be divisible by 8\"\n        new_weight = torch.ops.npu.npu_convert_weight_to_int4pack(weight.flatten(0, 1))\n        new_weight = new_weight.view(weight.shape[0], weight.shape[1], -1)\n        return new_weight\n\n    def apply(\n        self,\n        layer,\n   
     dispatch_output: \"StandardDispatchOutput\",\n    ) -> \"CombineInput\":\n        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput\n\n        x = dispatch_output.hidden_states\n        topk_output = dispatch_output.topk_output\n\n        topk_weights, topk_ids, _ = topk_output\n        topk_ids = topk_ids.to(torch.int32)\n        topk_weights = topk_weights.to(x.dtype)\n        output = npu_fused_experts_w4a4(\n            hidden_states=x,\n            w13=layer.w13_weight,\n            w13_scale=layer.w13_weight_scale,\n            w2=layer.w2_weight,\n            w2_scale=layer.w2_weight_scale,\n            topk_weights=topk_weights,\n            topk_ids=topk_ids,\n            top_k=topk_ids.shape[1],\n        )\n        return StandardCombineInput(hidden_states=output)\n\n\nclass NPUW8A8Int8DynamicMoEMethod(_NPUFusedMoEMethodBase):\n\n    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:\n        layer.w13_weight.data = npu_format_cast(layer.w13_weight.data.transpose(1, 2))\n        layer.w2_weight.data = npu_format_cast(layer.w2_weight.data.transpose(1, 2))\n        layer.w13_weight_scale = torch.nn.Parameter(\n            layer.w13_weight_scale.data.squeeze(-1), requires_grad=False\n        )\n        layer.w2_weight_scale = torch.nn.Parameter(\n            layer.w2_weight_scale.data.squeeze(-1), requires_grad=False\n        )\n        layer.w13_weight_scale_bf16 = torch.nn.Parameter(\n            layer.w13_weight_scale.data.to(dtype=torch.bfloat16), requires_grad=False\n        )\n        layer.w2_weight_scale_bf16 = torch.nn.Parameter(\n            layer.w2_weight_scale.data.to(dtype=torch.bfloat16), requires_grad=False\n        )\n        # Compressed-tensors format doesn't have this field\n        if hasattr(layer, \"w13_weight_offset\"):\n            layer.w13_weight_offset = torch.nn.Parameter(\n                layer.w13_weight_offset.data.squeeze(-1),\n                requires_grad=False,\n            
)\n        if hasattr(layer, \"w2_weight_offset\"):\n            layer.w2_weight_offset = torch.nn.Parameter(\n                layer.w2_weight_offset.data.squeeze(-1),\n                requires_grad=False,\n            )\n\n    def apply(\n        self,\n        layer,\n        dispatch_output: \"StandardDispatchOutput\",\n    ) -> \"CombineInput\":\n        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput\n\n        # release fp32 scale to save memory\n        layer.w13_weight_scale = None\n        layer.w2_weight_scale = None\n\n        hidden_states = dispatch_output.hidden_states\n        topk_output = dispatch_output.topk_output\n\n        topk_weights, topk_ids, _ = topk_output\n        topk_ids = topk_ids.to(torch.int32)\n        topk_weights = topk_weights.to(hidden_states.dtype)\n\n        # prefill\n        if not torch.npu.is_current_stream_capturing():\n            output = npu_fused_experts(\n                hidden_states=hidden_states,\n                w13=layer.w13_weight,\n                w13_scale=layer.w13_weight_scale_bf16,\n                w2=layer.w2_weight,\n                w2_scale=layer.w2_weight_scale_bf16,\n                topk_weights=topk_weights,\n                topk_ids=topk_ids,\n                top_k=topk_ids.shape[1],\n            )\n        # decode\n        else:\n            output = npu_fused_experts_w8a8_decode(\n                hidden_states=hidden_states,\n                w13=layer.w13_weight,\n                w13_scale=layer.w13_weight_scale_bf16,\n                w2=layer.w2_weight,\n                w2_scale=layer.w2_weight_scale_bf16,\n                topk_weights=topk_weights,\n                topk_ids=topk_ids,\n                top_k=topk_ids.shape[1],\n            )\n\n        return StandardCombineInput(hidden_states=output)\n\n    def apply_without_routing_weights(\n        self,\n        layer,\n        hidden_states,\n        hidden_states_scale,\n        group_list_type,\n        
group_list,\n        output_dtype,\n    ):\n        # gmm1: gate_up_proj\n        hidden_states = torch.ops.npu.npu_grouped_matmul(\n            x=[hidden_states],\n            weight=[layer.w13_weight],\n            split_item=2,\n            group_list_type=group_list_type,\n            group_type=0,\n            group_list=group_list,\n            output_dtype=torch.int32,\n        )[0]\n\n        # act_fn: swiglu\n        hidden_states, swiglu_out_scale = torch.ops.npu.npu_dequant_swiglu_quant(\n            x=hidden_states,\n            weight_scale=layer.w13_weight_scale,\n            activation_scale=hidden_states_scale,\n            bias=None,\n            quant_scale=None,\n            quant_offset=None,\n            group_index=group_list,\n            activate_left=True,\n            quant_mode=1,\n        )\n\n        # gmm2: down_proj\n        hidden_states = torch.ops.npu.npu_grouped_matmul(\n            x=[hidden_states],\n            weight=[layer.w2_weight],\n            scale=[layer.w2_weight_scale.to(output_dtype)],\n            per_token_scale=[swiglu_out_scale],\n            split_item=2,\n            group_list_type=group_list_type,\n            group_type=0,\n            group_list=group_list,\n            output_dtype=output_dtype,\n        )[0]\n        return hidden_states\n\n\nclass NPUW4A8Int8DynamicMoEMethod(_NPUFusedMoEMethodBase):\n\n    def _process_scale(\n        self, weight: torch.Tensor, scale, per_group_scale, is_per_channel_weight\n    ):\n        scale = scale.transpose(1, 2).contiguous()\n\n        if is_per_channel_weight:\n            scale_np = scale.cpu().numpy()\n            scale_np.dtype = np.uint32\n            scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu()\n            return scale_uint64_tensor, None\n\n        per_group_scale = per_group_scale.transpose(1, 2).contiguous()\n        group_num, k, n = weight.shape\n        # the weight of the new version is reduced by half by pack n, so it 
needs to be restored\n        n = n * 2\n        per_group_scale = per_group_scale.reshape(group_num, -1, n)\n        group_num, quantgroup_num, n = per_group_scale.shape\n        bias = None\n\n        scale_fp32 = (scale * per_group_scale).to(torch.float16).to(torch.float32)\n        scale_fp32_np = scale_fp32.cpu().numpy()\n        scale_fp32_np.dtype = np.uint32\n        sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2), dtype=np.uint32)\n\n        sscale_uint64[..., ::2] = scale_fp32_np\n\n        sscale_uint64_buffer = np.frombuffer(\n            sscale_uint64.tobytes(), dtype=np.int64\n        ).copy()\n        sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(\n            group_num, quantgroup_num, n\n        )\n        sscale_uint64_tensor = sscale_uint64_tensor.npu()\n        return sscale_uint64_tensor, bias\n\n    def _update_bias(self, layer, w13_bias, w2_bias):\n        layer.w13_scale_bias.data = (\n            layer.w13_scale_bias.data.transpose(1, 2).contiguous().sum(axis=1)\n        )\n        layer.w2_scale_bias.data = (\n            layer.w2_scale_bias.data.transpose(1, 2).contiguous().sum(axis=1)\n        )\n\n    def _pack_to_int32(self, weight: torch.Tensor):\n        # each int8 element holds two int4 values; view 4 int8s as one int32,\n        # since PyTorch has no native int4 dtype\n        assert (\n            weight.shape[-1] % 4 == 0\n        ), \"the last dim of weight must be divisible by 4\"\n        return weight.view(torch.int32).contiguous()\n\n    def process_weights_after_loading(\n        self, layer: torch.nn.Module, is_per_channel_weight, activation_use_clip\n    ) -> None:\n        if not activation_use_clip:\n            self._process_weights_without_clip(layer, is_per_channel_weight)\n        else:\n            self._process_weights_with_clip(layer)\n\n        layer.w13_weight = torch.nn.Parameter(\n            layer.w13_weight.data.transpose(1, 2).contiguous(), requires_grad=False\n        )\n        layer.w2_weight 
= torch.nn.Parameter(\n            layer.w2_weight.data.transpose(1, 2).contiguous(), requires_grad=False\n        )\n\n        layer.w13_weight.data = npu_format_cast(layer.w13_weight.data)\n        layer.w2_weight.data = npu_format_cast(layer.w2_weight.data)\n\n        layer.w13_weight.data = self._pack_to_int32(layer.w13_weight.data)\n        layer.w2_weight.data = self._pack_to_int32(layer.w2_weight.data)\n\n    def _process_weights_without_clip(\n        self, layer: torch.nn.Module, is_per_channel_weight\n    ) -> None:\n        w13_weight_scale_second = (\n            layer.w13_weight_scale_second.data\n            if hasattr(layer, \"w13_weight_scale_second\")\n            else None\n        )\n        w2_weight_scale_second = (\n            layer.w2_weight_scale_second.data\n            if hasattr(layer, \"w2_weight_scale_second\")\n            else None\n        )\n        layer.w13_weight_scale.data, w13_bias = self._process_scale(\n            layer.w13_weight,\n            layer.w13_weight_scale.data,\n            w13_weight_scale_second,\n            is_per_channel_weight,\n        )\n        layer.w2_weight_scale.data, w2_bias = self._process_scale(\n            layer.w2_weight,\n            layer.w2_weight_scale.data,\n            w2_weight_scale_second,\n            is_per_channel_weight,\n        )\n        if hasattr(layer, \"w13_weight_scale_second\"):\n            # scale_second is no longer used, release this part of the memory\n            del layer.w13_weight_scale_second\n            del layer.w2_weight_scale_second\n            del layer.w13_weight_offset_second\n            del layer.w2_weight_offset_second\n\n        self._update_bias(layer, w13_bias, w2_bias)\n\n    def _process_weights_with_clip(self, layer: torch.nn.Module) -> None:\n        w13_weight_scale = (\n            layer.w13_weight_scale.data.squeeze(-1).contiguous().unsqueeze(1)\n        )\n        w2_weight_scale = (\n            
layer.w2_weight_scale.data.squeeze(-1).contiguous().unsqueeze(1)\n        )\n        layer.w13_weight_scale = torch.nn.Parameter(\n            w13_weight_scale, requires_grad=False\n        )\n        layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)\n        layer.w13_scale_bias = layer.w13_bias\n        layer.w2_scale_bias = layer.w2_bias\n\n    def apply(\n        self,\n        layer,\n        dispatch_output: \"StandardDispatchOutput\",\n    ) -> \"CombineInput\":\n        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput\n\n        hidden_states = dispatch_output.hidden_states\n        topk_output = dispatch_output.topk_output\n\n        topk_weights, topk_ids, _ = topk_output\n        top_k = topk_ids.shape[1]\n        group_list_type = 1\n        original_shape = hidden_states.shape\n        topk_weights = topk_weights\n\n        num_tokens = hidden_states.shape[:-1].numel()\n\n        first_expert_idx = 0\n        last_expert_idx = layer.num_experts\n        global_num_experts = layer.num_experts\n\n        sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = (\n            torch.ops.npu.npu_moe_init_routing_v2(\n                hidden_states,\n                topk_ids,\n                active_num=num_tokens * top_k,\n                expert_num=global_num_experts,\n                expert_tokens_num_type=1,\n                expert_tokens_num_flag=True,\n                active_expert_range=[first_expert_idx, last_expert_idx],\n                quant_mode=1,\n            )\n        )\n\n        expert_tokens = expert_tokens.to(torch.int64)\n\n        bias1 = [layer.w13_scale_bias]\n        bias2 = [layer.w2_scale_bias]\n        w1_scale = [layer.w13_weight_scale]\n        w2_scale = [layer.w2_weight_scale]\n        _output_dtype = torch.bfloat16\n\n        hidden_states = torch.ops.npu.npu_grouped_matmul(\n            x=[sorted_hidden_states],\n            weight=[layer.w13_weight],\n      
      scale=w1_scale,\n            bias=bias1,\n            per_token_scale=[pertoken_scale],\n            group_list=expert_tokens,\n            split_item=2,\n            group_type=0,\n            group_list_type=group_list_type,\n            output_dtype=_output_dtype,\n        )[0]\n\n        # act_fn: swiglu\n        hidden_states = torch.ops.npu.npu_swiglu(hidden_states)\n        hidden_states, swiglu_out_scale = torch.ops.npu.npu_dynamic_quant(hidden_states)\n\n        output = torch.ops.npu.npu_grouped_matmul(\n            x=[hidden_states],\n            weight=[layer.w2_weight],\n            scale=w2_scale,\n            bias=bias2,\n            per_token_scale=[swiglu_out_scale],\n            group_list=expert_tokens,\n            split_item=2,\n            group_type=0,\n            group_list_type=group_list_type,\n            output_dtype=_output_dtype,\n        )[0]\n\n        assert original_shape is not None\n        final_hidden_states = torch.ops.npu.npu_moe_token_unpermute(\n            permuted_tokens=output,\n            sorted_indices=torch.abs(expanded_row_idx),\n            probs=topk_weights,\n        )\n        if len(original_shape) == 3:\n            final_hidden_states = final_hidden_states.view(original_shape)\n\n        return StandardCombineInput(hidden_states=final_hidden_states)\n\n    def apply_without_routing_weights(\n        self,\n        layer,\n        hidden_states,\n        hidden_states_scale,\n        group_list_type,\n        group_list,\n        output_dtype,\n    ):\n        from sgl_kernel_npu.activation.swiglu_quant import swiglu_quant\n\n        hidden_states = torch.ops.npu.npu_grouped_matmul(\n            x=[hidden_states],\n            weight=[layer.w13_weight],\n            scale=[layer.w13_weight_scale],\n            bias=[layer.w13_scale_bias],\n            per_token_scale=[hidden_states_scale],\n            group_list=group_list,\n            split_item=2,\n            group_type=0,\n            
group_list_type=group_list_type,\n            output_dtype=output_dtype,\n        )[0]\n\n        hidden_states, swiglu_out_scale = swiglu_quant(\n            hidden_states, group_list, group_list_type\n        )\n\n        hidden_states = torch.ops.npu.npu_grouped_matmul(\n            x=[hidden_states],\n            weight=[layer.w2_weight],\n            scale=[layer.w2_weight_scale],\n            bias=[layer.w2_scale_bias],\n            per_token_scale=[swiglu_out_scale],\n            group_list=group_list,\n            split_item=2,\n            group_type=0,\n            group_list_type=group_list_type,\n            output_dtype=output_dtype,\n        )[0]\n\n        return hidden_states\n\n\nclass NPUW4A16Int4DynamicMoEMethod(_NPUFusedMoEMethodBase):\n\n    def _pack_to_int32(self, weight: torch.Tensor):\n        assert weight.dim() == 3\n        if weight.dtype == torch.int32:\n            # each int32 element holds one int4 value; pack 8 of them into one int32\n            assert (\n                weight.shape[-1] % 8 == 0\n            ), \"the last dim of weight must be divisible by 8\"\n            new_weight = torch.ops.npu.npu_convert_weight_to_int4pack(\n                weight.flatten(0, 1)\n            )\n            new_weight = new_weight.view(weight.shape[0], weight.shape[1], -1)\n        elif weight.dtype == torch.int8:\n            # each int8 element holds two int4 values; view 4 int8s as one int32,\n            # since PyTorch has no native int4 dtype\n            assert (\n                weight.shape[-1] % 4 == 0\n            ), \"the last dim of weight must be divisible by 4\"\n            new_weight = weight.view(torch.int32).contiguous()\n        else:\n            raise ValueError(f\"{weight.dtype=} is not supported\")\n        return new_weight\n\n    def _unpack_from_int32(\n        self,\n        value: torch.Tensor,\n        num_bits: int,\n        shape: torch.Size = None,\n        packed_dim=1,\n    ) -> torch.Tensor:\n        \"\"\"\n        Unpacks a tensor of packed 
int32 weights into individual int8s, maintaining the\n        original bit range.\n\n        Return tensors in int8\n\n        :param value: tensor to unpack\n        :param num_bits: number of bits to unpack each data point into\n        :param shape: shape to unpack into, used to remove padding\n        :returns: unpacked int8 tensor\n        \"\"\"\n        if value.dtype is not torch.int32:\n            raise ValueError(\n                f\"Expected {torch.int32} but got {value.dtype}, Aborting unpack.\"\n            )\n\n        if num_bits > 8:\n            raise ValueError(\"Unpacking is only supported for less than 8 bits\")\n\n        pack_factor = 32 // num_bits\n\n        # unpack\n        mask = (1 << num_bits) - 1\n\n        if packed_dim == 1:\n            unpacked = torch.zeros(\n                (value.shape[0], value.shape[1] * pack_factor),\n                device=value.device,\n                dtype=torch.int32,\n            )\n            for i in range(pack_factor):\n                unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask\n\n            # remove padding\n            if shape is not None:\n                original_row_size = int(shape[1])\n                unpacked = unpacked[:, :original_row_size]\n        else:\n            unpacked = torch.zeros(\n                (value.shape[0] * pack_factor, value.shape[1]),\n                device=value.device,\n                dtype=torch.int32,\n            )\n            for i in range(pack_factor):\n                unpacked[i::pack_factor, :] = (value >> (num_bits * i)) & mask\n\n            # remove padding\n            original_row_size = int(shape[0])\n            unpacked = unpacked[:original_row_size, :]\n\n        # bits are packed in unsigned format, reformat to signed\n        # update the value range from unsigned to signed\n        offset = pow(2, num_bits) // 2\n        unpacked = (unpacked - offset).to(torch.int8)\n\n        return unpacked\n\n    def 
process_weights_after_loading(self, layer: torch.nn.Module) -> None:\n        w13_weight_scale = layer.w13_weight_scale.data.transpose(-1, -2).contiguous()\n        w2_weight_scale = layer.w2_weight_scale.data.transpose(-1, -2).contiguous()\n        layer.w13_weight_scale = torch.nn.Parameter(\n            w13_weight_scale, requires_grad=False\n        )\n        layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)\n\n        layer.w13_weight_offset = torch.nn.Parameter(\n            layer.w13_weight_offset.data.transpose(-1, -2).contiguous(),\n            requires_grad=False,\n        )\n        layer.w2_weight_offset = torch.nn.Parameter(\n            layer.w2_weight_offset.data.transpose(-1, -2).contiguous(),\n            requires_grad=False,\n        )\n\n        # w = [n, k // 8]  --> [k, n // 8]\n        # w13_weight = layer.w13_weight.data.transpose(1, 2).contiguous()\n        # w2_weight = layer.w2_weight.data.transpose(1, 2).contiguous()\n        unpacked_w13_weight = (\n            self._unpack_from_int32(layer.w13_weight.data.flatten(0, 1), 4)\n            .view(layer.w13_weight.data.shape[0], layer.w13_weight.data.shape[1], -1)\n            .transpose(1, 2)\n            .contiguous()\n            .int()\n        )\n        unpacked_w2_weight = (\n            self._unpack_from_int32(layer.w2_weight.data.flatten(0, 1), 4)\n            .view(layer.w2_weight.data.shape[0], layer.w2_weight.data.shape[1], -1)\n            .transpose(1, 2)\n            .contiguous()\n            .int()\n        )\n\n        w13_weight = self._pack_to_int32(unpacked_w13_weight)\n        w2_weight = self._pack_to_int32(unpacked_w2_weight)\n\n        layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)\n        layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)\n\n    def apply(\n        self,\n        layer,\n        dispatch_output: \"StandardDispatchOutput\",\n    ) -> \"CombineInput\":\n        from 
sglang.srt.layers.moe.token_dispatcher import StandardCombineInput\n\n        x = dispatch_output.hidden_states\n        topk_output = dispatch_output.topk_output\n\n        topk_weights, topk_ids, _ = topk_output\n        topk_ids = topk_ids.to(torch.int32)\n        topk_weights = topk_weights.to(x.dtype)\n        output = npu_fused_experts(\n            hidden_states=x,\n            w13=layer.w13_weight,\n            w13_scale=layer.w13_weight_scale,\n            w13_offset=layer.w13_weight_offset,\n            w2=layer.w2_weight,\n            w2_scale=layer.w2_weight_scale,\n            w2_offset=layer.w2_weight_offset,\n            topk_weights=topk_weights,\n            topk_ids=topk_ids,\n            top_k=topk_ids.shape[1],\n            use_wna16=True,\n        )\n        return StandardCombineInput(hidden_states=output)\n\n    def apply_without_routing_weights(\n        self,\n        layer,\n        hidden_states,\n        hidden_states_scale,\n        group_list_type,\n        group_list,\n        output_dtype,\n    ):\n        if hidden_states_scale is None:\n            # gmm1: gate_up_proj\n            hidden_states = torch.ops.npu.npu_grouped_matmul(\n                x=[hidden_states],\n                weight=[layer.w13_weight],\n                antiquant_scale=[layer.w13_weight_scale],\n                antiquant_offset=[layer.w13_weight_offset],\n                split_item=2,\n                group_list_type=group_list_type,\n                group_type=0,\n                group_list=group_list,\n                output_dtype=output_dtype,\n            )[0]\n\n            # act_fn: swiglu\n            hidden_states = torch.ops.npu.npu_swiglu(hidden_states)\n\n            # gmm2: down_proj\n            out_hidden = torch.ops.npu.npu_grouped_matmul(\n                x=[hidden_states],\n                weight=[layer.w2_weight],\n                antiquant_scale=[layer.w2_weight_scale],\n                antiquant_offset=[layer.w2_weight_offset],\n           
     split_item=2,\n                group_list_type=group_list_type,\n                group_type=0,\n                group_list=group_list,\n                output_dtype=output_dtype,\n            )[0]\n        else:\n            raise ValueError(\n                \"When the weight is int4, hidden_states must not be quantized \"\n                \"(hidden_states_scale must be None).\"\n            )\n\n        return out_hidden\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py",
    "content": "from typing import TYPE_CHECKING, Optional\n\nimport torch\n\nfrom sglang.srt.hardware_backend.npu.utils import npu_format_cast\nfrom sglang.srt.layers.quantization.base_config import LinearMethodBase\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.quantization.base_config import QuantizationConfig\n\n\nclass _NPULinearMethodBase(LinearMethodBase):\n\n    def __init__(\n        self,\n        quant_config: Optional[\"QuantizationConfig\"] = None,\n    ):\n        self.quant_config = quant_config\n\n\nclass NPUW8A8Int8LinearMethod(_NPULinearMethodBase):\n\n    def process_weights_after_loading(self, layer: torch.nn.Module):\n        layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()\n        layer.weight.data = npu_format_cast(layer.weight.data)\n\n        layer.weight_scale.data = layer.weight_scale.data.flatten()\n        # Compressed-tensors format doesn't have this field\n        if hasattr(layer, \"weight_offset\"):\n            layer.weight_offset.data = layer.weight_offset.data.flatten()\n\n        expanding_factor = layer.weight.data.shape[0]\n        layer.aclnn_input_scale = torch.nn.Parameter(\n            layer.input_scale.data.repeat(expanding_factor).to(device=\"npu\"),\n            requires_grad=False,\n        )\n        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(\n            layer.input_scale.data.repeat(expanding_factor).to(device=\"npu\"),\n            requires_grad=False,\n        )\n        layer.aclnn_input_offset = torch.nn.Parameter(\n            layer.input_offset.data.repeat(expanding_factor).to(device=\"npu\"),\n            requires_grad=False,\n        )\n\n    def apply(\n        self,\n        layer: torch.nn.Module,\n        x: torch.Tensor,\n        bias: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        from sglang.srt.layers.linear import RowParallelLinear\n\n        original_dtype = x.dtype\n        if original_dtype != torch.int8:\n            x = 
torch.ops.npu.npu_quantize(\n                x,\n                layer.aclnn_input_scale_reciprocal,\n                layer.aclnn_input_offset,\n                torch.qint8,\n                -1,\n                False,\n            )\n        # Only fuse bias add into GEMM for rank 0 (this ensures that\n        # bias will not get added more than once in Attention TP>1 case)\n        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:\n            quant_bias = None\n        else:\n            quant_bias = layer.quant_bias\n        return torch.ops.npu.npu_quant_matmul(\n            x,\n            layer.weight,\n            layer.deq_scale,\n            bias=quant_bias,\n            output_dtype=original_dtype,\n        )\n\n\nclass NPUW8A8Int8DynamicLinearMethod(_NPULinearMethodBase):\n\n    def process_weights_after_loading(self, layer: torch.nn.Module):\n        layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()\n        layer.weight.data = npu_format_cast(layer.weight.data)\n\n        layer.weight_scale.data = layer.weight_scale.data.flatten()\n        # Compressed-tensors format doesn't have this field\n        if hasattr(layer, \"weight_offset\"):\n            layer.weight_offset.data = layer.weight_offset.data.flatten()\n\n    def apply(\n        self,\n        layer: torch.nn.Module,\n        x: torch.Tensor,\n        bias: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n\n        if isinstance(x, tuple):\n            # x is already quantized; dynamic_scale was computed upstream in the MLA prolog kernel\n            original_dtype = torch.bfloat16\n            quant_out, dynamic_scale = x\n        else:\n            original_dtype = x.dtype\n            quant_out, dynamic_scale = torch.ops.npu.npu_dynamic_quant(x)\n        return torch.ops.npu.npu_quant_matmul(\n            quant_out,\n            layer.weight,\n            layer.weight_scale,\n            pertoken_scale=dynamic_scale.flatten(),\n            bias=bias,\n            
output_dtype=original_dtype,\n        )\n\n\nclass NPU_W4A4DynamicLinearMethod(_NPULinearMethodBase):\n\n    def process_weights_after_loading(self, layer):\n        layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()\n        layer.weight_scale.data = layer.weight_scale.data.flatten()\n        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)\n        layer.weight_offset.data = layer.weight_offset.data.flatten()\n        layer.weight.data = torch.ops.npu.npu_convert_weight_to_int4pack(\n            layer.weight.data.to(torch.int32)\n        )\n\n    def apply(\n        self,\n        layer: torch.nn.Module,\n        x: torch.Tensor,\n        bias: Optional[torch.Tensor] = None,\n        tp_rank: Optional[int] = 0,\n    ) -> torch.Tensor:\n        original_dtype = x.dtype\n        quant_out, dynamic_scale = torch.ops.npu.npu_dynamic_quant(\n            x, dst_type=torch.quint4x2\n        )\n        return torch.ops.npu.npu_quant_matmul(\n            quant_out,\n            layer.weight,\n            layer.weight_scale,\n            pertoken_scale=dynamic_scale.flatten(),\n            bias=bias,\n            output_dtype=original_dtype,\n        )\n"
  },
  {
    "path": "python/sglang/srt/hardware_backend/npu/utils.py",
    "content": "import functools\nimport logging\nfrom enum import IntEnum\nfrom typing import TYPE_CHECKING, Callable\n\nimport torch\n\nfrom sglang.srt.environ import envs\nfrom sglang.srt.utils import get_npu_memory_capacity, is_npu\n\nif TYPE_CHECKING:\n    from sglang.srt.server_args import ServerArgs\n\nlogger = logging.getLogger(__name__)\n_is_npu = is_npu()\nindexer_weight_stream = None\n\n\nclass NPUACLFormat(IntEnum):\n    ACL_FORMAT_UNDEFINED = -1\n    ACL_FORMAT_ND = 2\n    ACL_FORMAT_FRACTAL_NZ = 29\n\n\nclass FusedMoEMode(IntEnum):\n    FUSED_DEEP_MOE = 1\n    DISPATCH_FFN_COMBINE = 2\n\n\ndef _call_once(fn: Callable):\n\n    @functools.wraps(fn)\n    def wrapper(*args, **kwargs):\n        if getattr(fn, \"_has_been_called\", False):\n            logger.debug(\"Function %s has already been called.\", fn.__name__)\n            return\n\n        fn._has_been_called = True\n        return fn(*args, **kwargs)\n\n    return wrapper\n\n\ndef set_default_server_args(args: \"ServerArgs\"):\n    \"\"\"\n    Set default server arguments for NPU backend.\n    \"\"\"\n\n    # NPU only works with \"ascend\" attention backend for now\n    args.attention_backend = \"ascend\"\n    args.prefill_attention_backend = \"ascend\"\n    args.decode_attention_backend = \"ascend\"\n    if args.page_size is None:\n        args.page_size = 128\n\n    # NPU memory settings\n    npu_mem = get_npu_memory_capacity()\n    if npu_mem <= 32 * 1024:\n        # Ascend 910B4,910B4_1\n        # (chunked_prefill_size 4k, cuda_graph_max_bs 16 if tp < 4 else 64)\n        if args.chunked_prefill_size is None:\n            args.chunked_prefill_size = 4 * 1024\n        if args.cuda_graph_max_bs is None:\n            if args.tp_size < 4:\n                args.cuda_graph_max_bs = 16\n            else:\n                args.cuda_graph_max_bs = 64\n    elif npu_mem <= 64 * 1024:\n        # Ascend 910B1,910B2,910B2C,910B3,910_9391,910_9392,910_9381,910_9382,910_9372,910_9362\n        # 
(chunked_prefill_size 8k, cuda_graph_max_bs 64 if tp < 4 else 256)\n        if args.chunked_prefill_size is None:\n            args.chunked_prefill_size = 8 * 1024\n        if args.cuda_graph_max_bs is None:\n            if args.tp_size < 4:\n                args.cuda_graph_max_bs = 64\n            else:\n                args.cuda_graph_max_bs = 256\n\n    # NPU does not support CustomAllReduce\n    args.disable_custom_all_reduce = True\n\n    # handles hierarchical cache configs\n    if args.enable_hierarchical_cache:\n        args.hicache_io_backend = \"kernel_ascend\"\n        if args.use_mla_backend():\n            args.hicache_mem_layout = \"page_first_kv_split\"\n        else:\n            args.hicache_mem_layout = \"page_first_direct\"\n\n\n@_call_once\ndef init_npu_backend():\n    \"\"\"\n    Initialize NPU backend. This function should be called only once.\n    \"\"\"\n\n    assert _is_npu, \"NPU backend initialization called on non-NPU device.\"\n\n    import sgl_kernel_npu  # noqa: F401\n    import torch_npu\n    from torch_npu.contrib import transfer_to_npu  # noqa: F401\n\n    # Re-mock torch.cuda.is_available because transfer_to_npu mocks it to return True\n    torch.cuda.is_available = lambda: False\n\n    torch_npu.npu.config.allow_internal_format = True\n    torch_npu.npu.set_compile_mode(jit_compile=False)\n\n\ndef npu_format_cast(\n    tensor: torch.Tensor,\n    acl_format: NPUACLFormat = NPUACLFormat.ACL_FORMAT_FRACTAL_NZ,\n) -> torch.Tensor:\n    \"\"\"\n    Cast a tensor to a specific NPU ACL format.\n\n    Args:\n        tensor (torch.Tensor): The input tensor.\n        acl_format (NPUACLFormat): The target NPU ACL format.\n\n    Returns:\n        torch.Tensor: The tensor cast to the specified NPU ACL format.\n    \"\"\"\n\n    if not _is_npu:\n        return tensor\n\n    if envs.SGLANG_NPU_DISABLE_ACL_FORMAT_WEIGHT.get():\n        return tensor\n\n    if tensor.device == torch.device(\"cpu\"):\n        logger.warning_once(\n            \"Warning: The 
conversion from 'ND' to 'NZ' does not work on the CPU. \"\n            \"Please disable offloading; otherwise, performance will be \"\n            \"significantly reduced.\"\n        )\n        return tensor\n    else:\n        return torch.ops.npu.npu_format_cast(tensor, acl_format.value)\n\n\ndef get_indexer_weight_stream():\n    global indexer_weight_stream\n    if indexer_weight_stream is None:\n        indexer_weight_stream = torch.npu.Stream()\n    return indexer_weight_stream\n"
  },
  {
    "path": "python/sglang/srt/layers/activation.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Fused operators for activation layers.\"\"\"\n\nimport logging\nimport math\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom transformers import PretrainedConfig\n\nfrom sglang.srt.distributed import (\n    divide,\n    get_tensor_model_parallel_rank,\n    get_tensor_model_parallel_world_size,\n)\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.quantization.base_config import QuantizationConfig\nfrom sglang.srt.layers.utils import MultiPlatformOp\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.utils import (\n    cpu_has_amx_support,\n    is_cpu,\n    is_cuda,\n    is_hip,\n    is_npu,\n    is_xpu,\n    set_weight_attrs,\n)\nfrom sglang.utils import resolve_obj_by_qualname\n\n_is_cuda = is_cuda()\n_is_npu = is_npu()\n_is_cpu_amx_available = cpu_has_amx_support()\n_is_cpu = is_cpu()\n_is_hip = is_hip()\n_is_xpu = is_xpu()\n\nif _is_cuda or _is_xpu:\n    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul\nelif _is_hip:\n    from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul\n\nif is_npu():\n    import torch_npu\n\nlogger = logging.getLogger(__name__)\n\n\nclass SiluAndMul(MultiPlatformOp):\n    def __init__(self, *args, **kwargs):\n        
super().__init__(*args, **kwargs)\n        if get_global_server_args().rl_on_policy_target is not None:\n            self._forward_method = self.forward_native\n\n    def forward_native(self, x: torch.Tensor) -> torch.Tensor:\n        d = x.shape[-1] // 2\n        return F.silu(x[..., :d]) * x[..., d:]\n\n    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:\n        d = x.shape[-1] // 2\n        output_shape = x.shape[:-1] + (d,)\n        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)\n        silu_and_mul(x, out)\n        return out\n\n    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:\n        if _is_cpu_amx_available:\n            out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)\n            return out\n        else:\n            return self.forward_native(x)\n\n    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:\n        out = torch_npu.npu_swiglu(x)\n        return out\n\n    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:\n        d = x.shape[-1] // 2\n        output_shape = x.shape[:-1] + (d,)\n        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)\n        silu_and_mul(x, out)\n        return out\n\n\nclass GeluAndMul(MultiPlatformOp):\n    def __init__(self, approximate=\"tanh\"):\n        super().__init__()\n        self.approximate = approximate\n\n    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:\n        d = x.shape[-1] // 2\n        output_shape = x.shape[:-1] + (d,)\n        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)\n        if self.approximate == \"tanh\":\n            gelu_tanh_and_mul(x, out)\n        elif self.approximate == \"none\":\n            gelu_and_mul(x, out)\n        else:\n            raise RuntimeError(\"GeluAndMul only support tanh or none\")\n        return out\n\n    def forward_native(self, x: torch.Tensor) -> torch.Tensor:\n        d = x.shape[-1] // 2\n        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]\n\n 
   def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:\n        if _is_cpu_amx_available and self.approximate == \"tanh\":\n            return torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)\n        elif _is_cpu_amx_available and self.approximate == \"none\":\n            return torch.ops.sgl_kernel.gelu_and_mul_cpu(x)\n        else:\n            return self.forward_native(x)\n\n    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:\n        return self._forward_impl(x)\n\n    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:\n        return self._forward_impl(x)\n\n    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:\n        if envs.SGLANG_NPU_FORWARD_NATIVE_GELUTANH.get():\n            return self.forward_native(x)\n        y_npu, gelu_npu = torch_npu.npu_geglu(\n            x,\n            dim=-1,\n            approximate=1 if self.approximate == \"tanh\" else 0,\n            activate_left=True,\n        )\n        return y_npu\n\n\nclass NewGELU(MultiPlatformOp):\n    def forward_native(self, x: torch.Tensor) -> torch.Tensor:\n        c = math.sqrt(2.0 / math.pi)\n        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))\n\n    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:\n        # TODO: Implement the CUDA kernel for NewGELU in sgl-kernel\n        return self.forward_native(x)\n\n\nclass ReLU2(nn.Module):\n    \"\"\"\n    Applies the squared Rectified Linear Unit function.\n    y = max(0, x)^2\n    \"\"\"\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = F.relu(x)\n        return x * x\n\n\nclass QuickGELU(MultiPlatformOp):\n    def forward_native(self, x: torch.Tensor) -> torch.Tensor:\n        return x * torch.sigmoid(1.702 * x)\n\n    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:\n        return self.forward_native(x)\n\n    def forward_hip(self, x: torch.Tensor) -> torch.Tensor:\n        out = torch.empty(x.shape, dtype=x.dtype, device=x.device)\n        
gelu_quick(x, out)\n        return out\n\n    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:\n        return torch_npu.npu_fast_gelu(x)\n\n\nclass XIELU(MultiPlatformOp):\n    \"\"\"\n    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010\n    If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA\n    Otherwise, we emit a single warning and use xIELU Python\n    \"\"\"\n\n    def __init__(\n        self,\n        alpha_p_init: float = 0.8,\n        alpha_n_init: float = 0.8,\n        beta: float = 0.5,\n        eps: float = -1e-6,\n        dtype: torch.dtype = torch.bfloat16,\n        with_vector_loads: bool = False,\n    ):\n        super().__init__()\n        self.alpha_p = nn.Parameter(\n            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(\n                0\n            )\n        )\n        self.alpha_n = nn.Parameter(\n            torch.log(\n                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1\n            ).unsqueeze(0)\n        )\n        self.register_buffer(\"beta\", torch.tensor(beta, dtype=dtype))\n        self.register_buffer(\"eps\", torch.tensor(eps, dtype=dtype))\n        self.with_vector_loads = with_vector_loads\n        # Temporary until xIELU CUDA fully implemented\n        self._beta_scalar = float(self.beta.detach().cpu().float().item())\n        self._eps_scalar = float(self.eps.detach().cpu().float().item())\n\n        self._xielu_cuda_obj = None\n        try:\n            import xielu.ops  # noqa: F401\n\n            self._xielu_cuda_obj = torch.classes.xielu.XIELU()\n            msg = \"Using experimental xIELU CUDA.\"\n            try:\n                from torch._dynamo import allow_in_graph\n\n                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)\n                msg += \" Enabled torch._dynamo for xIELU CUDA.\"\n            except Exception as err:\n                msg += (\n                    f\" 
Could not enable torch._dynamo for xIELU ({err}) - \"\n                    \"this may result in slower performance.\"\n                )\n                self._xielu_cuda_fn = self._xielu_cuda\n            logger.warning_once(msg)\n        except Exception as err:\n            pass\n            # logger.warning_once(\n            #     \"CUDA-fused xIELU not available (%s) –\"\n            #     \" falling back to a Python version.\\n\"\n            #     \"For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`\",\n            #     str(err),\n            # )\n\n    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:\n        alpha_p = nn.functional.softplus(self.alpha_p)\n        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)\n        return torch.where(\n            x > 0,\n            alpha_p * x * x + self.beta * x,\n            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,\n        )\n\n    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Firewall function to prevent torch.compile from seeing .item()\"\"\"\n        assert self._xielu_cuda_obj is not None, \"XIELU CUDA object must not be None\"\n        original_shape = x.shape\n        # CUDA kernel expects 3D tensors, reshape if needed\n        while x.dim() < 3:\n            x = x.unsqueeze(0)\n        if x.dim() > 3:\n            x = x.view(-1, 1, x.size(-1))\n        if original_shape != x.shape:\n            logger.warning_once(\n                \"Warning: xIELU input tensor expects 3 dimensions\"\n                \" but got (shape: %s). 
Reshaping to (shape: %s).\\n\"\n                \"Note: For SGLang this may be expected if sending\"\n                \"[B*S,D] instead of [B,S,D].\",\n                original_shape,\n                x.shape,\n            )\n        result = self._xielu_cuda_obj.forward(\n            x,\n            self.alpha_p,\n            self.alpha_n,\n            # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()\n            self._beta_scalar,\n            self._eps_scalar,\n            self.with_vector_loads,\n        )\n        return result.view(original_shape)\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        if self._xielu_cuda_obj is not None and input.is_cuda:\n            if not torch._dynamo.is_compiling():\n                return self._xielu_cuda_fn(input)\n            else:\n                logger.warning_once(\n                    \"torch._dynamo is compiling, using Python version of xIELU.\"\n                )\n        return self._xielu_python(input)\n\n\nclass ScaledActivation(nn.Module):\n    \"\"\"An activation function with post-scale parameters.\n\n    This is used for some quantization methods like AWQ.\n    \"\"\"\n\n    def __init__(\n        self,\n        act_module: nn.Module,\n        intermediate_size: int,\n        input_is_parallel: bool = True,\n        params_dtype: Optional[torch.dtype] = None,\n    ):\n        super().__init__()\n        self.act = act_module\n        self.input_is_parallel = input_is_parallel\n        if input_is_parallel:\n            tp_size = get_tensor_model_parallel_world_size()\n            intermediate_size_per_partition = divide(intermediate_size, tp_size)\n        else:\n            intermediate_size_per_partition = intermediate_size\n        if params_dtype is None:\n            params_dtype = torch.get_default_dtype()\n        self.scales = nn.Parameter(\n            torch.empty(intermediate_size_per_partition, dtype=params_dtype)\n        )\n        
set_weight_attrs(self.scales, {\"weight_loader\": self.weight_loader})\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.act(x) / self.scales\n\n    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):\n        param_data = param.data\n        if self.input_is_parallel:\n            tp_rank = get_tensor_model_parallel_rank()\n            shard_size = param_data.shape[0]\n            start_idx = tp_rank * shard_size\n            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)\n        assert param_data.shape == loaded_weight.shape\n        param_data.copy_(loaded_weight)\n\n\n_ACTIVATION_REGISTRY = {\n    \"gelu\": nn.GELU(),\n    \"gelu_pytorch_tanh\": nn.GELU(approximate=\"tanh\"),\n    \"gelu_new\": NewGELU(),\n    \"relu2\": ReLU2(),\n    \"xielu\": XIELU(),\n}\n\n\ndef get_act_fn(\n    act_fn_name: str,\n    quant_config: Optional[QuantizationConfig] = None,\n    intermediate_size: Optional[int] = None,\n    input_is_parallel: bool = True,\n    params_dtype: Optional[torch.dtype] = None,\n) -> nn.Module:\n    \"\"\"Get an activation function by name.\"\"\"\n    act_fn_name = act_fn_name.lower()\n    if act_fn_name not in _ACTIVATION_REGISTRY:\n        raise ValueError(f\"Activation function {act_fn_name!r} is not supported.\")\n\n    act_fn = _ACTIVATION_REGISTRY[act_fn_name]\n    if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names():\n        if intermediate_size is None:\n            raise ValueError(\n                \"intermediate_size must be specified for scaled \"\n                \"activation functions.\"\n            )\n        return ScaledActivation(\n            act_fn, intermediate_size, input_is_parallel, params_dtype\n        )\n    return act_fn\n\n\ndef get_cross_encoder_activation_function(config: PretrainedConfig):\n    if (\n        hasattr(config, \"sbert_ce_default_activation_function\")\n        and config.sbert_ce_default_activation_function 
is not None\n    ):\n\n        function_name = config.sbert_ce_default_activation_function\n        assert function_name.startswith(\"torch.nn.modules.\"), (\n            \"Loading of activation functions is restricted to \"\n            \"torch.nn.modules for security reasons\"\n        )\n        return resolve_obj_by_qualname(function_name)()\n    else:\n        # adapt bge-reranker\n        return nn.Identity()\n"
  },
  {
    "path": "python/sglang/srt/layers/amx_utils.py",
    "content": "import logging\n\nimport torch\n\nfrom sglang.srt.utils import cpu_has_amx_support\n\nlogger = logging.getLogger(__name__)\n\nfrom enum import IntEnum\n\n\nclass CPUQuantMethod(IntEnum):\n    UNQUANT = 0\n    INT8_W8A8 = 1\n    FP8_W8A16 = 2\n    INT4_W4A8 = 3\n\n\ndef amx_process_weight_after_loading(weight, is_conv=False):\n    if weight.device != torch.device(\"cpu\"):\n        return weight\n    if not cpu_has_amx_support():\n        return weight\n    if is_conv:\n        return torch.ops.sgl_kernel.causal_conv1d_weight_pack(\n            weight.view(-1, weight.size(-1))\n        )\n    else:\n        return torch.ops.sgl_kernel.convert_weight_packed(weight)\n\n\n# TODO: currently gemm kernel has the below requirements:\n# OC: OC % TILE_N == 0 or OC < TILE_N, where TILE_N = 16\n# IC: IC % TILE_K == 0, where TILE_K = 32\ndef dim_is_supported(weight):\n    TILE_N = 16\n    TILE_K = 32\n    ndim = weight.ndim\n    OC = weight.size(1) if ndim == 3 else weight.size(0)\n    IC = weight.size(2) if ndim == 3 else weight.size(1)\n    is_oc_support = OC < TILE_N or OC % TILE_N == 0\n    is_ic_support = IC % TILE_K == 0\n    return is_oc_support and is_ic_support\n\n\ndef dtype_is_supported(weight):\n    return weight.dtype in [\n        torch.float16,\n        torch.bfloat16,\n        torch.int8,\n        torch.float8_e4m3fn,\n    ]\n\n\ndef is_dim_conv_weight(weight):\n    return weight.dim() == 3 and weight.size(1) == 1\n\n\ndef _init_amx_conv_state(conv_state):\n    # CPU AMX layout for conv_state kernel optimization\n    conv_state_cpu = []\n    for conv_shape_t in conv_state:\n        conv_shape_new = conv_shape_t.as_strided_(\n            conv_shape_t.size(),\n            (\n                conv_shape_t.stride(0),\n                conv_shape_t.stride(1),\n                1,\n                conv_shape_t.size(2),\n            ),\n        )\n        conv_state_cpu.append(conv_shape_new)\n    return conv_state_cpu\n\n\ndef 
_amx_process_weight_after_loading(\n    module, weight_names, transpose_dims=None\n) -> None:\n    # Pack weights to get better performance on CPU\n    devices = {getattr(module, weight_name).device for weight_name in weight_names}\n    assert len(devices) == 1, \"Expects all weights to be on the same device\"\n    device = devices.pop()\n\n    if transpose_dims:\n        assert len(weight_names) == len(\n            transpose_dims\n        ), \"len(weight_names) should be equal to len(transpose_dims)\"\n\n    for i, weight_name in enumerate(weight_names):\n        weight_tensor = getattr(module, weight_name)\n\n        if transpose_dims and transpose_dims[i]:\n            weight_tensor = weight_tensor.transpose(*transpose_dims[i])\n        is_conv_weight = is_dim_conv_weight(weight_tensor)\n        # We don't pack weights or use the Intel AMX backend if any non-conv weight of this module has an unsupported dim or dtype.\n        if (\n            (not dim_is_supported(weight_tensor))\n            or not dtype_is_supported(weight_tensor)\n        ) and (not is_conv_weight):\n            logger.warning(\n                f\"Unsupported dimension or dtype for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} and dtype {weight_tensor.dtype} in {module}. \"\n                f\"The derived (OC, IC) dimensions must be divisible by (16, 32). 
\"\n            )\n            module.use_intel_amx_backend = False\n            return\n\n        packed_weight = torch.nn.Parameter(\n            amx_process_weight_after_loading(weight_tensor, is_conv_weight),\n            requires_grad=False,\n        )\n        packed_weight.__dict__ = weight_tensor.__dict__\n        setattr(module, weight_name, packed_weight)\n        if is_conv_weight:\n            # need to use inplace copy for conv weight amx packing,\n            # as its usage in radix_linear_attention will use the original conv weight.\n            weight_tensor = weight_tensor.view(-1, weight_tensor.size(-1))\n            weight_tensor.copy_(packed_weight)\n\n    module.use_intel_amx_backend = (\n        device == torch.device(\"cpu\") and cpu_has_amx_support()\n    )\n\n    if (\n        module.use_intel_amx_backend\n        and hasattr(module, \"bias\")\n        and module.bias is not None\n    ):\n        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)\n\n\nclass PackWeightMethod:\n    def __init__(self, weight_names, transpose_dims=None):\n        self.weight_names = weight_names\n        self.transpose_dims = transpose_dims\n\n    def process_weights_after_loading(self, module) -> None:\n        _amx_process_weight_after_loading(\n            module, self.weight_names, self.transpose_dims\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/aiter_backend.py",
    "content": "from __future__ import annotations\n\n\"\"\"\nend to end attention solution with aiter kernels\n\"\"\"\n\nimport logging\nfrom dataclasses import dataclass\nfrom enum import Enum, auto\nfrom typing import TYPE_CHECKING, Optional\n\nimport torch\nimport triton\n\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.attention.utils import (\n    create_flashinfer_kv_indices_triton,\n    create_flashmla_kv_indices_triton,\n)\nfrom sglang.srt.layers.dp_attention import (\n    get_attention_tp_size,\n    is_dp_attention_enabled,\n)\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.utils import is_gfx95_supported\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n    from sglang.srt.speculative.spec_info import SpecInput\n\ntry:\n    from aiter import (\n        flash_attn_varlen_func,\n        get_mla_metadata_info_v1,\n        get_mla_metadata_v1,\n        get_ps_metadata_info_v1,\n        get_ps_metadata_v1,\n        mha_batch_prefill_func,\n        mla_prefill_ps_asm_fwd,\n        mla_reduce_v1,\n        paged_attention_ragged,\n    )\n    from aiter.mla import mla_decode_fwd, mla_prefill_fwd\n    from aiter.ops.triton.attention.unified_attention import unified_attention\nexcept ImportError:\n    print(\n        \"aiter is AMD specific kernel library. 
Please make sure aiter is installed on your AMD device.\"\n    )\n\nfrom sglang.srt.configs.model_config import AttentionArch\nfrom sglang.srt.layers.attention.utils import (\n    launch_reshape_and_cache_flash,\n    pad_sequence_with_mask,\n)\nfrom sglang.srt.layers.quantization.fp8_kernel import fp8_dtype\nfrom sglang.srt.mem_cache.swa_memory_pool import SWAKVPool\nfrom sglang.srt.utils import get_bool_env_var\n\nlogger = logging.getLogger(__name__)\n\n# Use aiter mla persist design for fp8-kv cache\n_use_mla_ps_kernel = get_bool_env_var(\"SGLANG_AITER_MLA_PERSIST\", \"True\")\n\n# Use fp8 prefill only on gfx95\n_use_fp8_prefill_attn = (\n    get_bool_env_var(\"SGLANG_AITER_FP8_PREFILL_ATTN\", \"True\") and is_gfx95_supported()\n)\n\n# Persist\n# fast_mode=True if _use_mla_ps_kernel else False\n# intra_batch_mode=False if _use_mla_ps_kernel else True\n\n# fake non-ps, intra_batch_mode needs to be True for non-ps-mode\nfast_mode = False\nintra_batch_mode = True if _use_mla_ps_kernel else False\n\n\nclass WrapperDispatch(Enum):\n    SLIDING_WINDOW = auto()\n    CROSS_ATTENTION = auto()\n\n\n@dataclass\nclass ForwardMetadata:\n    kv_indptr: torch.Tensor\n    kv_indices: torch.Tensor\n    qo_indptr: torch.Tensor\n    kv_last_page_len: torch.Tensor\n    max_q_len: int\n    max_kv_len: Optional[int]\n    work_metadata: Optional[torch.Tensor] = None\n    work_info_set: Optional[torch.Tensor] = None\n    work_indptr: Optional[torch.Tensor] = None\n    reduce_indptr: Optional[torch.Tensor] = None\n    reduce_final_map: Optional[torch.Tensor] = None\n    reduce_partial_map: Optional[torch.Tensor] = None\n    num_kv_splits: Optional[int] = None\n    run_graph: Optional[bool] = True\n    custom_mask: Optional[torch.Tensor] = None\n    mask_indptr: Optional[torch.Tensor] = None\n    max_extend_len: Optional[int] = None\n    fp8_prefill_kv_indices: Optional[torch.Tensor] = None\n    swa_page_table: Optional[torch.Tensor] = None\n\n\nglobal_workspace_buffer = 
None\n\n_AITER_PARTITION_SIZE_ROCM = 256\n\n\nclass AiterAttnBackend(AttentionBackend):\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        skip_prefill: bool = False,\n        kv_indptr_buf: Optional[torch.Tensor] = None,\n        topk: int = 1,\n    ):\n        super().__init__()\n        # Lazy import to avoid the initialization of cuda context\n        from sglang.srt.layers.attention.triton_ops.extend_attention import (\n            extend_attention_fwd,\n        )\n\n        self.input_dtype = model_runner.model_config.dtype\n\n        self.page_size = model_runner.server_args.page_size\n\n        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)\n\n        self.device = model_runner.device\n        self.is_multimodal = model_runner.model_config.is_multimodal\n        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens\n        self.speculative_num_steps = model_runner.server_args.speculative_num_steps\n        self.topk = topk\n        self.num_head = (\n            model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.head_dim = model_runner.model_config.head_dim\n        self.num_kv_head = model_runner.model_config.get_num_kv_heads(\n            get_attention_tp_size()\n        )\n        self.kv_cache_dtype = model_runner.kv_cache_dtype\n\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n\n        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA\n\n        # Get v_head_dim based on model type\n        if self.use_mla:\n            # For MLA models, get v_head_dim from model config\n            self.v_head_dim = model_runner.model_config.v_head_dim\n        elif hasattr(model_runner.token_to_kv_pool, \"get_v_head_dim\"):\n            # For hybrid models (Mamba+attention, GDN, Kimi linear),\n            # layer_id=0 may not be a full attention layer\n            self.v_head_dim = 
model_runner.token_to_kv_pool.get_v_head_dim()\n        else:\n            self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[\n                -1\n            ]\n\n        # Parse constants\n        self.max_context_len = model_runner.model_config.context_len\n        self.skip_prefill = skip_prefill\n\n        max_bs = model_runner.req_to_token_pool.size\n\n        if kv_indptr_buf is None:\n            self.kv_indptr = torch.zeros(\n                (max_bs + 1,), dtype=torch.int32, device=model_runner.device\n            )\n        else:\n            self.kv_indptr = kv_indptr_buf\n\n        self.kv_last_page_len = torch.ones(\n            (max_bs,), dtype=torch.int32, device=model_runner.device\n        )\n        self.qo_indptr = torch.zeros(\n            (max_bs + 1,), dtype=torch.int32, device=model_runner.device\n        )\n        self.mask_indptr = torch.zeros(\n            (max_bs + 1,), dtype=torch.int64, device=model_runner.device\n        )\n        self._kv_indices_scratch: Optional[torch.Tensor] = None\n\n        # Create prefill indices updater\n        if not skip_prefill:\n            self.indices_updater_prefill = AiterIndicesUpdaterPrefill(\n                model_runner, self\n            )\n            if self.use_mla:\n                self.mla_indices_updater_prefill = AiterMlaIndicesUpdaterPrefill(\n                    model_runner, self\n                )\n\n        # sliding window attention\n        self.use_sliding_window_kv_pool = (\n            isinstance(model_runner.token_to_kv_pool, SWAKVPool)\n            and model_runner.token_to_kv_pool.swa_layer_nums > 0\n        )\n\n        if self.use_sliding_window_kv_pool:\n            self.token_to_kv_pool = model_runner.token_to_kv_pool\n            self.use_triton_unified_attention = True\n        else:\n            self.use_triton_unified_attention = False\n\n        # aiter kernel related initialization\n        self.max_num_partitions = (\n            
self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1\n        ) // _AITER_PARTITION_SIZE_ROCM\n\n        nbytes_per_qo_elem = torch.finfo(torch.float32).bits // 8\n\n        if not (self.use_mla or self.use_triton_unified_attention):\n            self.workspace_buffer = torch.empty(\n                (max_bs * self.num_head * self.max_num_partitions * self.head_dim)\n                * nbytes_per_qo_elem\n                + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,\n                dtype=torch.uint8,\n                device=self.device,\n            )\n\n        self.scale = float(1.0 / (self.head_dim**0.5))\n        self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(\n            self.device\n        )\n\n        self.logits_soft_cap = 0.0\n\n        self.forward_metadata: ForwardMetadata = None\n\n        if self.use_mla:\n            self.enable_dp_attention = is_dp_attention_enabled()\n            self.qo_indptr_ = torch.zeros(\n                (max_bs + 1,), dtype=torch.int32, device=model_runner.device\n            )\n            global _use_mla_ps_kernel, fast_mode, intra_batch_mode\n\n            # mla_decode_fwd currently supports fake-nps only when self.num_head == 16;\n            # other head counts do not use the qh16 kernel to simulate it, so they\n            # must not use fake-nps (fast_mode = False, intra_batch_mode = True),\n            # which would cause GPU faults or accuracy issues\n            if self.num_head == 32 or self.num_head == 128:\n                fast_mode = True\n                intra_batch_mode = False\n\n            # the persistent a16w16 mla_decode kernel currently does not support head_num = 128,\n            # so fall back to the non-persistent kernel there\n            # only use mla_ps_kernel with an fp8 kv_cache; for a non-fp8 kv_cache on tp8,\n            # use the non-persistent kernel to avoid performance degradation\n            # head_num=16 (tp8 perf issue), head_num=128 (unsupported, e.g. tp1 or --enable-dp-attention with tp8-dp8)\n  
          if (\n                self.num_head == 16 or self.num_head == 128\n            ) and self.kv_cache_dtype is not fp8_dtype:\n                _use_mla_ps_kernel = False\n                fast_mode = False\n                intra_batch_mode = False\n\n            self.max_split_per_batch = 32 if _use_mla_ps_kernel else None\n\n            if self.num_draft_tokens is None and _use_mla_ps_kernel:\n                self.max_split_per_batch = 64\n\n            self.fix_max_split_per_batch = self.max_split_per_batch\n\n    def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size):\n        nhead = self.num_head\n        dtype = self.kv_cache_dtype\n\n        if self.enable_dp_attention:\n            gpu = torch.cuda.current_device()\n            device_properties = torch.cuda.get_device_properties(gpu)\n            cu_num = device_properties.multi_processor_count\n            self.max_split_per_batch = min(\n                (cu_num + batch_size - 1) // batch_size, self.fix_max_split_per_batch\n            )\n\n        (\n            (work_meta_data_size, work_meta_data_type),\n            (work_indptr_size, work_indptr_type),\n            (work_info_set_size, work_info_set_type),\n            (reduce_indptr_size, reduce_indptr_type),\n            (reduce_final_map_size, reduce_final_map_type),\n            (reduce_partial_map_size, reduce_partial_map_type),\n        ) = get_mla_metadata_info_v1(\n            batch_size,\n            max_seqlen_qo,\n            nhead,\n            dtype,\n            dtype,\n            is_sparse=False,\n            fast_mode=fast_mode,\n            num_kv_splits=self.max_split_per_batch,\n            intra_batch_mode=intra_batch_mode,\n        )\n\n        # aiter implementation\n        # for the meaning of each tensor, refer to aiter/ops/attention.py\n        work_metadata = torch.empty(\n            work_meta_data_size, dtype=work_meta_data_type, device=\"cuda\"\n        )\n        work_indptr = torch.empty(\n            
work_indptr_size, dtype=work_indptr_type, device=\"cuda\"\n        )\n        work_info_set = torch.empty(\n            work_info_set_size,\n            dtype=work_info_set_type,\n            device=\"cuda\",\n        )\n        reduce_indptr = torch.empty(\n            reduce_indptr_size, dtype=reduce_indptr_type, device=\"cuda\"\n        )\n        reduce_final_map = torch.empty(\n            reduce_final_map_size, dtype=reduce_final_map_type, device=\"cuda\"\n        )\n        reduce_partial_map = torch.empty(\n            reduce_partial_map_size, dtype=reduce_partial_map_type, device=\"cuda\"\n        )\n\n        return (\n            work_metadata,\n            work_indptr,\n            work_info_set,\n            reduce_indptr,\n            reduce_final_map,\n            reduce_partial_map,\n        )\n\n    def make_mla_meta_data(\n        self,\n        qo_indptr,\n        kv_indptr,\n        kv_last_page_len,\n        work_metadata,\n        work_info_set,\n        work_indptr,\n        reduce_indptr,\n        reduce_final_map,\n        reduce_partial_map,\n        max_q_len,\n        fast_mode,\n        max_split_per_batch,\n        intra_batch_mode,\n    ):\n\n        nhead_kv = 1\n        page_size = self.page_size\n        dtype = self.kv_cache_dtype\n\n        meta = get_mla_metadata_v1(\n            qo_indptr,\n            kv_indptr,\n            kv_last_page_len,\n            self.num_head // nhead_kv,\n            nhead_kv,\n            False,\n            work_metadata,\n            work_info_set,\n            work_indptr,\n            reduce_indptr,\n            reduce_final_map,\n            reduce_partial_map,\n            kv_granularity=max(page_size, 16),\n            max_seqlen_qo=max_q_len,\n            uni_seqlen_qo=max_q_len,\n            fast_mode=fast_mode,\n            max_split_per_batch=max_split_per_batch,\n            intra_batch_mode=intra_batch_mode,\n            dtype_q=dtype,\n            dtype_kv=dtype,\n        )\n\n    def 
make_mla_prefill_ps_meta_data_buffer(\n        self, batch_size: int, max_qlen: int, qlen_granularity: int\n    ):\n        (\n            (work_meta_data_size, work_meta_data_type),\n            (work_indptr_size, work_indptr_type),\n            (work_info_size, work_info_type),\n            (reduce_indptr_size, reduce_indptr_type),\n            (reduce_final_map_size, reduce_final_map_type),\n            (reduce_partial_map_size, reduce_partial_map_type),\n        ) = get_ps_metadata_info_v1(\n            batch_size=batch_size,\n            num_head_k=self.num_kv_head,\n            max_qlen=max_qlen,\n            qlen_granularity=qlen_granularity,\n        )\n\n        device = self.device\n        work_metadata_ptrs = torch.empty(\n            work_meta_data_size, dtype=work_meta_data_type, device=device\n        )\n        work_indptr = torch.empty(\n            work_indptr_size, dtype=work_indptr_type, device=device\n        )\n        work_info = torch.empty(work_info_size, dtype=work_info_type, device=device)\n        reduce_indptr = torch.empty(\n            reduce_indptr_size, dtype=reduce_indptr_type, device=device\n        )\n        reduce_final_map = torch.empty(\n            reduce_final_map_size, dtype=reduce_final_map_type, device=device\n        )\n        reduce_partial_map = torch.empty(\n            reduce_partial_map_size, dtype=reduce_partial_map_type, device=device\n        )\n\n        return (\n            work_metadata_ptrs,\n            work_indptr,\n            work_info,\n            reduce_indptr,\n            reduce_final_map,\n            reduce_partial_map,\n        )\n\n    def make_mla_prefill_ps_meta_data(\n        self,\n        qo_indptr: torch.Tensor,\n        kv_indptr: torch.Tensor,\n        seq_lens: torch.Tensor,\n        work_metadata: torch.Tensor,\n        work_indptr: torch.Tensor,\n        work_info: torch.Tensor,\n        reduce_indptr: torch.Tensor,\n        reduce_final_map: torch.Tensor,\n        
reduce_partial_map: torch.Tensor,\n        is_causal: bool = True,\n    ):\n        gqa_ratio = self.num_head // self.num_kv_head\n        num_heads_k = self.num_kv_head\n        tile_q = 256\n        qhead_granularity = gqa_ratio\n        qlen_granularity = tile_q // qhead_granularity\n        kvlen_granularity = max(128, self.page_size)\n        block_size = self.page_size\n\n        qo_indptr_cpu = qo_indptr.to(\"cpu\", dtype=torch.int32)\n        kv_indptr_cpu = kv_indptr.to(\"cpu\", dtype=torch.int32)\n        seq_lens_cpu = seq_lens.to(\"cpu\", dtype=torch.int32)\n\n        get_ps_metadata_v1(\n            qo_indptr_cpu,\n            kv_indptr_cpu,\n            seq_lens_cpu,\n            gqa_ratio,\n            num_heads_k,\n            work_metadata,\n            work_indptr,\n            work_info,\n            reduce_indptr,\n            reduce_final_map,\n            reduce_partial_map,\n            qhead_granularity=qhead_granularity,\n            qlen_granularity=qlen_granularity,\n            kvlen_granularity=kvlen_granularity,\n            block_size=block_size,\n            is_causal=is_causal,\n        )\n\n    # for page_size > 1: convert a token-level page table into a page-level one\n    def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:\n        page_size = self.page_size\n        if page_size == 1:\n            return page_table\n        max_seqlen_k = page_table.shape[1]\n        strided_indices = torch.arange(\n            0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32\n        )\n        return page_table[:, strided_indices] // page_size\n\n    def _resolve_v2_num_draft_tokens(\n        self,\n        extend_seq_lens: Optional[torch.Tensor] = None,\n        extend_seq_lens_cpu: Optional[list[int]] = None,\n    ) -> int:\n        \"\"\"Resolve fixed per-request extend length for DRAFT_EXTEND_V2.\"\"\"\n        num_draft_tokens = self.num_draft_tokens\n        if num_draft_tokens is None:\n            if 
extend_seq_lens is not None and extend_seq_lens.numel() > 0:\n                # Avoid list scans in hot path when tensor lengths are already available.\n                num_draft_tokens = int(extend_seq_lens[0].item())\n            elif extend_seq_lens_cpu:\n                num_draft_tokens = max(extend_seq_lens_cpu)\n            else:\n                raise ValueError(\n                    \"DRAFT_EXTEND_V2 requires speculative_num_draft_tokens or \"\n                    \"non-empty extend_seq_lens/extend_seq_lens_cpu.\"\n                )\n\n        num_draft_tokens = int(num_draft_tokens)\n        if extend_seq_lens is not None and extend_seq_lens.numel() > 0:\n            if not torch.all(extend_seq_lens == num_draft_tokens):\n                raise ValueError(\n                    \"DRAFT_EXTEND_V2 expects fixed extend length per request; got \"\n                    f\"extend_seq_lens={extend_seq_lens}, expected all == {num_draft_tokens}.\"\n                )\n        if extend_seq_lens_cpu and any(\n            x != num_draft_tokens for x in extend_seq_lens_cpu\n        ):\n            raise ValueError(\n                \"DRAFT_EXTEND_V2 expects fixed extend length per request; got \"\n                f\"{extend_seq_lens_cpu}, expected all == {num_draft_tokens}.\"\n            )\n        return num_draft_tokens\n\n    def _get_kv_indices_scratch(\n        self, required_tokens: int, device: torch.device\n    ) -> torch.Tensor:\n        if (\n            self._kv_indices_scratch is None\n            or self._kv_indices_scratch.device != device\n            or self._kv_indices_scratch.numel() < required_tokens\n        ):\n            self._kv_indices_scratch = torch.empty(\n                required_tokens, dtype=torch.int32, device=device\n            )\n        return self._kv_indices_scratch[:required_tokens]\n\n    def _set_uniform_qo_indptr(\n        self, bs: int, tokens_per_req: int, device: torch.device\n    ) -> torch.Tensor:\n        qo_indptr = 
self.qo_indptr[: bs + 1]\n        qo_indptr[: bs + 1] = torch.arange(\n            0,\n            bs * tokens_per_req + 1,\n            step=tokens_per_req,\n            dtype=torch.int32,\n            device=device,\n        )\n        return qo_indptr\n\n    def _ensure_spec_v2_topk_supported(self):\n        if self.topk > 1:\n            raise NotImplementedError(\n                \"AiterAttnBackend SPEC_V2 path currently supports topk <= 1 only. \"\n                f\"Got topk={self.topk}.\"\n            )\n\n    def mla_fp8_prefill_attn(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n    ):\n        total_q = q.shape[0]\n        nhead = layer.tp_q_head_num\n        v_head_dim = layer.v_head_dim\n\n        if q.dtype != fp8_dtype:\n            q = q.to(fp8_dtype)\n        if k.dtype != fp8_dtype:\n            k = k.to(fp8_dtype)\n        if v.dtype != fp8_dtype:\n            v = v.to(fp8_dtype)\n        one_scale = torch.ones((), dtype=torch.float32, device=q.device)\n\n        tile_q = 256\n        reduce_indptr = self.forward_metadata.reduce_indptr\n        reduce_final_map = self.forward_metadata.reduce_final_map\n        reduce_partial_map = self.forward_metadata.reduce_partial_map\n\n        logits = torch.empty(\n            (reduce_partial_map.size(0) * tile_q, nhead, v_head_dim),\n            dtype=torch.float32,\n            device=q.device,\n        )\n        attn_lse = torch.empty(\n            (reduce_partial_map.size(0) * tile_q, nhead),\n            dtype=torch.float32,\n            device=q.device,\n        )\n        final_lse = torch.empty(\n            (total_q, nhead),\n            dtype=torch.float32,\n            device=q.device,\n        )\n        output = q.new_empty(\n            (total_q, nhead, v_head_dim),\n            dtype=self.input_dtype,\n        )\n\n        mla_prefill_ps_asm_fwd(\n            q,\n            k,\n            v,\n            
self.forward_metadata.qo_indptr,\n            self.forward_metadata.kv_indptr,\n            self.forward_metadata.fp8_prefill_kv_indices,\n            self.forward_metadata.work_indptr,\n            self.forward_metadata.work_info_set,\n            self.forward_metadata.max_q_len,\n            layer.scaling,\n            True,\n            logits,\n            attn_lse,\n            output,\n            one_scale,\n            one_scale,\n            one_scale,\n        )\n        mla_reduce_v1(\n            logits,\n            attn_lse,\n            reduce_indptr,\n            reduce_final_map,\n            reduce_partial_map,\n            tile_q,\n            output,\n            final_lse,\n        )\n        return output\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Init auxiliary variables for aiter attention backend.\"\"\"\n\n        bs = forward_batch.batch_size\n        kv_indptr = self.kv_indptr\n        spec_info = forward_batch.spec_info\n        qo_indptr = None\n        kv_last_page_len = None\n        max_q_len = None\n        max_kv_len = None\n\n        work_metadata = None\n        work_indptr = None\n        work_info_set = None\n        reduce_indptr = None\n        reduce_final_map = None\n        reduce_partial_map = None\n\n        num_kv_splits = None\n        swa_page_table = None\n\n        if forward_batch.forward_mode.is_decode_or_idle():\n            if spec_info is None or forward_batch.forward_mode.is_idle():\n                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)\n                kv_indptr = kv_indptr[: bs + 1]\n\n                if not self.use_triton_unified_attention:\n                    kv_indices = self._get_kv_indices_scratch(\n                        forward_batch.seq_lens_sum, forward_batch.seq_lens.device\n                    )\n                    create_flashinfer_kv_indices_triton[(bs,)](\n                        self.req_to_token,\n                      
  forward_batch.req_pool_indices,\n                        forward_batch.seq_lens,\n                        kv_indptr,\n                        None,\n                        kv_indices,\n                        self.req_to_token.stride(0),\n                    )\n                else:\n                    max_q_len = 1\n                    page_size = self.page_size\n                    max_kv_len = torch.max(forward_batch.seq_lens).item()\n                    max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size\n                    kv_indices = torch.zeros(\n                        bs, max_kv_len, dtype=torch.int32, device=self.device\n                    )\n\n                    create_flashmla_kv_indices_triton[(bs,)](\n                        self.req_to_token,\n                        forward_batch.req_pool_indices,\n                        forward_batch.seq_lens,\n                        None,\n                        kv_indices,\n                        self.req_to_token.stride(0),\n                        max_kv_len,\n                        1,\n                    )\n\n                    if self.use_sliding_window_kv_pool:\n                        swa_page_table = (\n                            self.token_to_kv_pool.translate_loc_from_full_to_swa(\n                                kv_indices\n                            )\n                        )\n\n                        kv_indices = self._transform_table_1_to_real(kv_indices)\n                        swa_page_table = self._transform_table_1_to_real(swa_page_table)\n\n                    qo_indptr = self.qo_indptr[: bs + 1]\n                    qo_indptr[1 : bs + 1] = torch.cumsum(\n                        self.kv_last_page_len[:bs], dim=0\n                    )\n\n            else:\n                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices\n                bs = kv_indptr.shape[0] - 1\n\n            if self.use_mla:\n                qo_indptr = self.qo_indptr_[: 
bs + 1]\n                qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)\n                kv_last_page_len = self.kv_last_page_len[:bs]\n                max_q_len = 1\n\n                if _use_mla_ps_kernel:\n                    (\n                        work_metadata,\n                        work_indptr,\n                        work_info_set,\n                        reduce_indptr,\n                        reduce_final_map,\n                        reduce_partial_map,\n                    ) = self.make_mla_decode_meta_data_buffer(max_q_len, bs)\n\n                    num_kv_splits = self.max_split_per_batch\n\n                    self.make_mla_meta_data(\n                        qo_indptr,\n                        kv_indptr,\n                        kv_last_page_len,\n                        work_metadata,\n                        work_info_set,\n                        work_indptr,\n                        reduce_indptr,\n                        reduce_final_map,\n                        reduce_partial_map,\n                        max_q_len,\n                        fast_mode=fast_mode,\n                        max_split_per_batch=num_kv_splits,\n                        intra_batch_mode=intra_batch_mode,\n                    )\n\n            self.forward_metadata = ForwardMetadata(\n                kv_indptr,\n                kv_indices,\n                qo_indptr,\n                kv_last_page_len,\n                max_q_len,\n                max_kv_len,\n                work_metadata=work_metadata,\n                work_info_set=work_info_set,\n                work_indptr=work_indptr,\n                reduce_indptr=reduce_indptr,\n                reduce_final_map=reduce_final_map,\n                reduce_partial_map=reduce_partial_map,\n                num_kv_splits=num_kv_splits,\n                run_graph=False,\n                swa_page_table=swa_page_table,\n            )\n\n        elif 
forward_batch.forward_mode.is_draft_extend_v2():\n            # EAGLE V2: DRAFT_EXTEND_V2 mode - extend draft KV cache with all predicted tokens\n            self._ensure_spec_v2_topk_supported()\n            if self.use_mla:\n                device = forward_batch.seq_lens.device\n                num_draft_tokens = self._resolve_v2_num_draft_tokens(\n                    extend_seq_lens=forward_batch.extend_seq_lens\n                )\n                qo_indptr = self._set_uniform_qo_indptr(bs, num_draft_tokens, device)\n\n                kv_indptr = self.kv_indptr[: bs + 1]\n                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)\n\n                kv_indices = self._get_kv_indices_scratch(\n                    forward_batch.seq_lens_sum, device\n                )\n\n                create_flashinfer_kv_indices_triton[(bs,)](\n                    self.req_to_token,\n                    forward_batch.req_pool_indices,\n                    forward_batch.seq_lens,\n                    kv_indptr,\n                    None,\n                    kv_indices,\n                    self.req_to_token.stride(0),\n                )\n\n                if _use_mla_ps_kernel:\n                    max_seqlen_qo = num_draft_tokens\n                    (\n                        work_metadata,\n                        work_indptr,\n                        work_info_set,\n                        reduce_indptr,\n                        reduce_final_map,\n                        reduce_partial_map,\n                    ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs)\n\n                    num_kv_splits = self.max_split_per_batch\n\n                    self.make_mla_meta_data(\n                        qo_indptr,\n                        kv_indptr,\n                        self.kv_last_page_len[:bs],\n                        work_metadata,\n                        work_info_set,\n                        work_indptr,\n                        
reduce_indptr,\n                        reduce_final_map,\n                        reduce_partial_map,\n                        max_seqlen_qo,\n                        fast_mode=fast_mode,\n                        max_split_per_batch=num_kv_splits,\n                        intra_batch_mode=intra_batch_mode,\n                    )\n\n                self.forward_metadata = ForwardMetadata(\n                    kv_indptr,\n                    kv_indices,\n                    qo_indptr,\n                    self.kv_last_page_len[:bs],\n                    num_draft_tokens,\n                    forward_batch.seq_lens_cpu.max().item(),\n                    work_metadata=work_metadata,\n                    work_info_set=work_info_set,\n                    work_indptr=work_indptr,\n                    reduce_indptr=reduce_indptr,\n                    reduce_final_map=reduce_final_map,\n                    reduce_partial_map=reduce_partial_map,\n                    num_kv_splits=num_kv_splits,\n                    run_graph=False,\n                )\n            else:\n                self.indices_updater_prefill.update(\n                    forward_batch.req_pool_indices,\n                    forward_batch.seq_lens,\n                    forward_batch.seq_lens_sum,\n                    prefix_lens=None,\n                    encoder_lens=forward_batch.encoder_lens,\n                    spec_info=forward_batch.spec_info,\n                )\n                self.forward_metadata = ForwardMetadata(\n                    self.indices_updater_prefill.kv_indptr,\n                    self.indices_updater_prefill.kv_indices,\n                    None,\n                    None,\n                    self.indices_updater_prefill.max_q_len,\n                    self.indices_updater_prefill.max_kv_len,\n                )\n        elif forward_batch.forward_mode.is_draft_extend():\n            # EAGLE V1: DRAFT_EXTEND mode - uses spec_info.accept_length\n            if self.use_mla:\n    
            kv_indices, kv_indptr, qo_indptr, custom_mask = (\n                    spec_info.generate_attn_arg_prefill(\n                        forward_batch.req_pool_indices,\n                        forward_batch.seq_lens,\n                        forward_batch.seq_lens_sum,\n                        self.req_to_token,\n                    )\n                )\n\n                if _use_mla_ps_kernel:\n                    max_seqlen_qo = max(forward_batch.extend_seq_lens_cpu)\n                    (\n                        work_metadata,\n                        work_indptr,\n                        work_info_set,\n                        reduce_indptr,\n                        reduce_final_map,\n                        reduce_partial_map,\n                    ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs)\n\n                    num_kv_splits = self.max_split_per_batch\n\n                    self.make_mla_meta_data(\n                        qo_indptr,\n                        kv_indptr,\n                        self.kv_last_page_len[:bs],\n                        work_metadata,\n                        work_info_set,\n                        work_indptr,\n                        reduce_indptr,\n                        reduce_final_map,\n                        reduce_partial_map,\n                        max_seqlen_qo,\n                        fast_mode=fast_mode,\n                        max_split_per_batch=num_kv_splits,\n                        intra_batch_mode=intra_batch_mode,\n                    )\n\n                self.forward_metadata = ForwardMetadata(\n                    kv_indptr,\n                    kv_indices,\n                    qo_indptr,\n                    # self.mla_indices_updater_prefill.kv_last_page_len,\n                    self.kv_last_page_len[:bs],\n                    max(forward_batch.extend_seq_lens_cpu),\n                    forward_batch.seq_lens_cpu.max().item(),\n                    
work_metadata=work_metadata,\n                    work_info_set=work_info_set,\n                    work_indptr=work_indptr,\n                    reduce_indptr=reduce_indptr,\n                    reduce_final_map=reduce_final_map,\n                    reduce_partial_map=reduce_partial_map,\n                    num_kv_splits=num_kv_splits,\n                    run_graph=False,\n                )\n            else:\n                # Non-MLA draft_extend: use triton extend kernel with causal masking\n                kv_indices, kv_indptr, qo_indptr, custom_mask = (\n                    spec_info.generate_attn_arg_prefill(\n                        forward_batch.req_pool_indices,\n                        forward_batch.seq_lens,\n                        forward_batch.seq_lens_sum,\n                        self.req_to_token,\n                    )\n                )\n                kv_indices = kv_indices.to(torch.int64)\n                draft_max_extend_len = torch.max(spec_info.accept_length).item()\n\n                self.forward_metadata = ForwardMetadata(\n                    kv_indptr,\n                    kv_indices,\n                    qo_indptr,\n                    None,\n                    draft_max_extend_len,\n                    None,\n                    custom_mask=custom_mask,\n                    mask_indptr=None,\n                    max_extend_len=draft_max_extend_len,\n                )\n        elif forward_batch.forward_mode.is_target_verify():\n            if self.use_mla:\n                draft_num = spec_info.draft_token_num\n                kv_lens = forward_batch.seq_lens + draft_num\n                kv_lens_sum = forward_batch.seq_lens_sum + draft_num * bs\n                device = forward_batch.seq_lens.device\n\n                qo_indptr = self.qo_indptr[: bs + 1]\n                qo_indptr[: bs + 1] = torch.arange(\n                    0,\n                    (1 + bs) * draft_num,\n                    step=draft_num,\n                   
 dtype=torch.int32,\n                    device=device,\n                )\n                kv_indptr = self.kv_indptr[: bs + 1]\n                kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)\n                kv_indices = self._get_kv_indices_scratch(\n                    kv_lens_sum,\n                    device,\n                )\n                create_flashinfer_kv_indices_triton[(bs,)](\n                    self.req_to_token,\n                    forward_batch.req_pool_indices,\n                    kv_lens,\n                    kv_indptr,\n                    None,\n                    kv_indices,\n                    self.req_to_token.stride(0),\n                )\n\n                # if self.kv_cache_dtype == fp8_dtype:\n                if _use_mla_ps_kernel:\n                    max_seqlen_qo = draft_num\n                    (\n                        work_metadata,\n                        work_indptr,\n                        work_info_set,\n                        reduce_indptr,\n                        reduce_final_map,\n                        reduce_partial_map,\n                    ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs)\n\n                    num_kv_splits = self.max_split_per_batch\n\n                    self.make_mla_meta_data(\n                        qo_indptr,\n                        kv_indptr,\n                        self.kv_last_page_len[:bs],\n                        work_metadata,\n                        work_info_set,\n                        work_indptr,\n                        reduce_indptr,\n                        reduce_final_map,\n                        reduce_partial_map,\n                        max_seqlen_qo,\n                        fast_mode=fast_mode,\n                        max_split_per_batch=num_kv_splits,\n                        intra_batch_mode=intra_batch_mode,\n                    )\n\n                self.forward_metadata = ForwardMetadata(\n                    kv_indptr,\n             
       kv_indices,\n                    qo_indptr,\n                    # self.mla_indices_updater_prefill.kv_last_page_len,\n                    self.kv_last_page_len[:bs],\n                    draft_num,\n                    None,\n                    work_metadata=work_metadata,\n                    work_info_set=work_info_set,\n                    work_indptr=work_indptr,\n                    reduce_indptr=reduce_indptr,\n                    reduce_final_map=reduce_final_map,\n                    reduce_partial_map=reduce_partial_map,\n                    num_kv_splits=num_kv_splits,\n                    run_graph=False,\n                )\n            else:\n                # Non-MLA target_verify: use triton extend kernel with custom mask\n                bs = len(forward_batch.req_pool_indices)\n                draft_num = spec_info.draft_token_num\n\n                qo_indptr = torch.arange(\n                    0,\n                    (1 + bs) * draft_num,\n                    step=draft_num,\n                    dtype=torch.int32,\n                    device=self.device,\n                )\n\n                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)\n                kv_indptr = kv_indptr[: bs + 1]\n\n                kv_indices = torch.empty(\n                    kv_indptr[-1], dtype=torch.int64, device=self.device\n                )\n                create_flashinfer_kv_indices_triton[(bs,)](\n                    self.req_to_token,\n                    forward_batch.req_pool_indices,\n                    forward_batch.seq_lens,\n                    kv_indptr,\n                    None,\n                    kv_indices,\n                    self.req_to_token.stride(0),\n                )\n\n                custom_mask = spec_info.custom_mask\n                seq_mask_len = draft_num * (forward_batch.seq_lens + draft_num)\n                mask_indptr = self.mask_indptr\n                mask_indptr[1 : bs + 1] = 
torch.cumsum(seq_mask_len[:bs], dim=0)\n                mask_indptr = mask_indptr[: bs + 1]\n\n                self.forward_metadata = ForwardMetadata(\n                    kv_indptr,\n                    kv_indices,\n                    qo_indptr,\n                    None,\n                    draft_num,\n                    None,\n                    custom_mask=custom_mask,\n                    mask_indptr=mask_indptr,\n                    max_extend_len=draft_num,\n                )\n        else:\n            prefix_lens = forward_batch.extend_prefix_lens\n\n            if self.is_multimodal:\n                extend_no_prefix = False\n            else:\n                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)\n            if self.use_mla:\n                self.mla_indices_updater_prefill.update(\n                    forward_batch.req_pool_indices,\n                    forward_batch.seq_lens,\n                    forward_batch.seq_lens_sum,\n                    forward_batch.extend_seq_lens,\n                    forward_batch.extend_seq_lens.max().item(),\n                    forward_batch.seq_lens.max().item(),\n                    spec_info=None,\n                )\n\n                max_q_len = self.mla_indices_updater_prefill.max_q_len\n                qo_indptr = self.mla_indices_updater_prefill.qo_indptr\n                kv_indptr = self.mla_indices_updater_prefill.kv_indptr\n\n                work_metadata = None\n                work_indptr = None\n                work_info_set = None\n                reduce_indptr = None\n                reduce_final_map = None\n                reduce_partial_map = None\n                fp8_prefill_kv_indices = None\n\n                if _use_fp8_prefill_attn:\n                    tile_q = 256\n                    qlen_granularity = tile_q // (self.num_head // self.num_kv_head)\n                    (\n                        work_metadata,\n                        work_indptr,\n                 
       work_info_set,\n                        reduce_indptr,\n                        reduce_final_map,\n                        reduce_partial_map,\n                    ) = self.make_mla_prefill_ps_meta_data_buffer(\n                        bs, max_q_len, qlen_granularity\n                    )\n\n                    self.make_mla_prefill_ps_meta_data(\n                        qo_indptr,\n                        kv_indptr,\n                        forward_batch.seq_lens,\n                        work_metadata,\n                        work_indptr,\n                        work_info_set,\n                        reduce_indptr,\n                        reduce_final_map,\n                        reduce_partial_map,\n                        is_causal=True,\n                    )\n\n                    total_s = forward_batch.seq_lens_sum\n                    fp8_prefill_kv_indices = torch.arange(\n                        total_s, device=self.device, dtype=torch.int32\n                    )\n\n                self.forward_metadata = ForwardMetadata(\n                    self.mla_indices_updater_prefill.kv_indptr,\n                    self.mla_indices_updater_prefill.kv_indices,\n                    qo_indptr,\n                    self.kv_last_page_len[:bs],\n                    max_q_len,\n                    self.mla_indices_updater_prefill.max_kv_len,\n                    work_metadata=work_metadata,\n                    work_info_set=work_info_set,\n                    work_indptr=work_indptr,\n                    reduce_indptr=reduce_indptr,\n                    reduce_final_map=reduce_final_map,\n                    reduce_partial_map=reduce_partial_map,\n                    fp8_prefill_kv_indices=fp8_prefill_kv_indices,\n                )\n            else:\n                self.indices_updater_prefill.update(\n                    forward_batch.req_pool_indices,\n                    forward_batch.seq_lens,\n                    forward_batch.seq_lens_sum,\n       
             prefix_lens,\n                    encoder_lens=forward_batch.encoder_lens,\n                    spec_info=None,\n                )\n\n                if self.use_sliding_window_kv_pool:\n                    swa_page_table = (\n                        self.token_to_kv_pool.translate_loc_from_full_to_swa(\n                            self.indices_updater_prefill.kv_indices\n                        )\n                    )\n\n                self.forward_metadata = ForwardMetadata(\n                    self.indices_updater_prefill.kv_indptr,\n                    self.indices_updater_prefill.kv_indices,\n                    None,\n                    None,\n                    self.indices_updater_prefill.max_q_len,\n                    self.indices_updater_prefill.max_kv_len,\n                    swa_page_table=swa_page_table,\n                )\n\n    def init_cuda_graph_state(\n        self,\n        max_bs: int,\n        max_num_tokens: int,\n        kv_indices_buf: Optional[torch.Tensor] = None,\n    ):\n        self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)\n        if kv_indices_buf is None:\n            max_num_blocks_per_seq = (\n                self.max_context_len + self.page_size - 1\n            ) // self.page_size\n            self.cuda_graph_kv_indices = torch.zeros(\n                (max_bs * max_num_blocks_per_seq),\n                dtype=torch.int32,\n                device=self.device,\n            )\n        else:\n            self.cuda_graph_kv_indices = kv_indices_buf\n\n        if not self.skip_prefill:\n            self.cuda_graph_custom_mask = torch.zeros(\n                (max_num_tokens * self.max_context_len),\n                dtype=torch.uint8,\n                device=self.device,\n            )\n\n        # if self.use_mla and (_use_mla_ps_kernel or self.kv_cache_dtype == fp8_dtype):\n        if self.use_mla and _use_mla_ps_kernel:\n            # for persistent mla_decode_fwd\n            max_seqlen_qo 
= (\n                1 if self.num_draft_tokens is None else self.num_draft_tokens\n            )\n\n            (\n                self.work_metadata,\n                self.work_indptr,\n                self.work_info_set,\n                self.reduce_indptr,\n                self.reduce_final_map,\n                self.reduce_partial_map,\n            ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, max_bs)\n\n        else:\n            self.work_metadata = None\n            self.work_indptr = None\n            self.work_info_set = None\n\n            self.reduce_indptr = None\n            self.reduce_final_map = None\n            self.reduce_partial_map = None\n\n        if self.use_sliding_window_kv_pool:\n            max_num_blocks_per_seq = (\n                self.max_context_len + self.page_size - 1\n            ) // self.page_size\n            self.cuda_graph_swa_page_table = torch.zeros(\n                (max_bs, max_num_blocks_per_seq),\n                dtype=torch.int32,\n                device=self.device,\n            )\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n\n        num_kv_splits = None\n        # num_kv_splits_indptr = None\n\n        work_metadata = None\n        work_info_set = None\n        work_indptr = None\n\n        reduce_indptr = None\n        reduce_final_map = None\n        reduce_partial_map = None\n\n        swa_page_table = None\n\n        max_kv_len = torch.max(seq_lens).item()\n\n        if forward_mode.is_decode_or_idle():\n            qo_indptr = None\n            kv_last_page_len = None\n            max_q_len = None\n\n            if spec_info is None:\n\n                if not self.use_triton_unified_attention:\n                    kv_indptr = 
self.kv_indptr\n                    kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n                    kv_indptr = kv_indptr[: bs + 1]\n                    kv_indices = self.cuda_graph_kv_indices\n                    create_flashinfer_kv_indices_triton[(bs,)](\n                        self.req_to_token,\n                        req_pool_indices,\n                        seq_lens,\n                        kv_indptr,\n                        None,\n                        kv_indices,\n                        self.req_to_token.stride(0),\n                    )\n                else:\n                    max_q_len = 1\n                    max_num_blocks_per_seq = (\n                        self.max_context_len + self.page_size - 1\n                    ) // self.page_size\n                    kv_indices = self.cuda_graph_kv_indices.view(\n                        -1, max_num_blocks_per_seq\n                    )\n\n                    page_indices = self.req_to_token[req_pool_indices[:bs], :max_kv_len]\n\n                    if self.use_sliding_window_kv_pool:\n                        swa_page_indices = (\n                            self.token_to_kv_pool.translate_loc_from_full_to_swa(\n                                page_indices\n                            )\n                        )\n\n                        page_indices = self._transform_table_1_to_real(page_indices)\n                        swa_page_indices = self._transform_table_1_to_real(\n                            swa_page_indices\n                        )\n\n                        new_rows = swa_page_indices.shape[0]\n                        new_cols = swa_page_indices.shape[1]\n\n                        kv_indices[:new_rows, :new_cols].copy_(page_indices)\n                        swa_page_table = self.cuda_graph_swa_page_table\n                        swa_page_table[:new_rows, :new_cols].copy_(swa_page_indices)\n\n                    qo_indptr = self.qo_indptr[: bs + 1]\n                    
qo_indptr[1 : bs + 1] = torch.cumsum(\n                        self.cuda_graph_kv_last_page_len[:bs], dim=0\n                    )\n\n                    kv_indptr = None\n            else:\n                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices\n\n            if self.use_mla:\n                qo_indptr = self.qo_indptr_[: bs + 1]\n                qo_indptr[1 : bs + 1] = torch.cumsum(\n                    self.cuda_graph_kv_last_page_len[:bs], dim=0\n                )\n                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]\n                max_q_len = 1\n\n                if _use_mla_ps_kernel:\n                    num_kv_splits = self.max_split_per_batch\n\n                    self.make_mla_meta_data(\n                        qo_indptr,\n                        kv_indptr,\n                        kv_last_page_len,\n                        self.work_metadata,\n                        self.work_info_set,\n                        self.work_indptr,\n                        self.reduce_indptr,\n                        self.reduce_final_map,\n                        self.reduce_partial_map,\n                        max_q_len,\n                        fast_mode=fast_mode,\n                        max_split_per_batch=num_kv_splits,\n                        intra_batch_mode=intra_batch_mode,\n                    )\n\n                    work_metadata = self.work_metadata\n                    work_info_set = self.work_info_set\n                    work_indptr = self.work_indptr\n\n                    reduce_indptr = self.reduce_indptr\n                    reduce_final_map = self.reduce_final_map\n                    reduce_partial_map = self.reduce_partial_map\n\n            self.forward_metadata = ForwardMetadata(\n                kv_indptr,\n                kv_indices,\n                qo_indptr,\n                kv_last_page_len,\n                max_q_len,\n                max_kv_len,\n                
work_metadata=work_metadata,\n                work_info_set=work_info_set,\n                work_indptr=work_indptr,\n                reduce_indptr=reduce_indptr,\n                reduce_final_map=reduce_final_map,\n                reduce_partial_map=reduce_partial_map,\n                num_kv_splits=num_kv_splits,\n                swa_page_table=swa_page_table,\n            )\n\n        elif forward_mode.is_target_verify():\n            qo_indptr = self.qo_indptr[: bs + 1]\n            qo_indptr[: bs + 1] = torch.arange(\n                0,\n                (1 + bs) * self.num_draft_tokens,\n                step=self.num_draft_tokens,\n                dtype=torch.int32,\n                device=self.device,\n            )\n            if self.use_mla:\n                kv_lens = seq_lens + self.num_draft_tokens\n            else:\n                kv_lens = seq_lens\n            kv_indptr = self.kv_indptr[: bs + 1]\n            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)\n            kv_indices = self.cuda_graph_kv_indices\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                kv_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]\n            max_q_len = self.num_draft_tokens\n\n            if self.use_mla:\n                if _use_mla_ps_kernel:\n\n                    num_kv_splits = self.max_split_per_batch\n\n                    self.make_mla_meta_data(\n                        qo_indptr,\n                        kv_indptr,\n                        kv_last_page_len,\n                        self.work_metadata,\n                        self.work_info_set,\n                        self.work_indptr,\n                        self.reduce_indptr,\n                        self.reduce_final_map,\n                
        self.reduce_partial_map,\n                        max_q_len,\n                        fast_mode=fast_mode,\n                        max_split_per_batch=num_kv_splits,\n                        intra_batch_mode=intra_batch_mode,\n                    )\n\n                    work_metadata = self.work_metadata\n                    work_info_set = self.work_info_set\n                    work_indptr = self.work_indptr\n\n                    reduce_indptr = self.reduce_indptr\n                    reduce_final_map = self.reduce_final_map\n                    reduce_partial_map = self.reduce_partial_map\n\n                self.forward_metadata = ForwardMetadata(\n                    kv_indptr,\n                    kv_indices,\n                    qo_indptr,\n                    kv_last_page_len,\n                    max_q_len,\n                    kv_indptr[-1].item(),\n                    work_metadata=work_metadata,\n                    work_info_set=work_info_set,\n                    work_indptr=work_indptr,\n                    reduce_indptr=reduce_indptr,\n                    reduce_final_map=reduce_final_map,\n                    reduce_partial_map=reduce_partial_map,\n                    num_kv_splits=num_kv_splits,\n                )\n            else:\n                custom_mask = self.cuda_graph_custom_mask\n                custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask\n                seq_mask_len = max_q_len * (seq_lens + max_q_len)\n                mask_indptr = self.mask_indptr\n                mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)\n                mask_indptr = mask_indptr[: bs + 1]\n\n                self.forward_metadata = ForwardMetadata(\n                    kv_indptr,\n                    kv_indices,\n                    qo_indptr,\n                    kv_last_page_len,\n                    max_q_len,\n                    kv_indptr[-1].item(),\n                    custom_mask=custom_mask,\n     
               mask_indptr=mask_indptr,\n                    max_extend_len=max_q_len,\n                )\n        elif forward_mode.is_draft_extend_v2():\n            # EAGLE V2: Uses fixed num_draft_tokens per batch\n            self._ensure_spec_v2_topk_supported()\n            num_tokens_per_bs = self._resolve_v2_num_draft_tokens()\n            qo_indptr = self._set_uniform_qo_indptr(bs, num_tokens_per_bs, self.device)\n            kv_indptr = self.kv_indptr[: bs + 1]\n            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n            kv_indices = self.cuda_graph_kv_indices\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                seq_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]\n            max_q_len = num_tokens_per_bs\n\n            if self.use_mla and _use_mla_ps_kernel:\n                num_kv_splits = self.max_split_per_batch\n\n                self.make_mla_meta_data(\n                    qo_indptr,\n                    kv_indptr,\n                    kv_last_page_len,\n                    self.work_metadata,\n                    self.work_info_set,\n                    self.work_indptr,\n                    self.reduce_indptr,\n                    self.reduce_final_map,\n                    self.reduce_partial_map,\n                    max_q_len,\n                    fast_mode=fast_mode,\n                    max_split_per_batch=num_kv_splits,\n                    intra_batch_mode=intra_batch_mode,\n                )\n\n                work_metadata = self.work_metadata\n                work_info_set = self.work_info_set\n                work_indptr = self.work_indptr\n\n                reduce_indptr = self.reduce_indptr\n                reduce_final_map = self.reduce_final_map\n          
      reduce_partial_map = self.reduce_partial_map\n\n            self.forward_metadata = ForwardMetadata(\n                kv_indptr,\n                kv_indices,\n                qo_indptr,\n                kv_last_page_len,\n                max_q_len,\n                kv_indptr[-1].item(),\n                work_metadata=work_metadata,\n                work_info_set=work_info_set,\n                work_indptr=work_indptr,\n                reduce_indptr=reduce_indptr,\n                reduce_final_map=reduce_final_map,\n                reduce_partial_map=reduce_partial_map,\n                num_kv_splits=num_kv_splits,\n            )\n        elif forward_mode.is_draft_extend():\n            # EAGLE V1: Uses speculative_num_steps + 1\n            num_tokens_per_bs = self.speculative_num_steps + 1\n            qo_indptr = self.qo_indptr[: bs + 1]\n            qo_indptr[: bs + 1] = torch.arange(\n                0,\n                bs * num_tokens_per_bs + 1,\n                step=num_tokens_per_bs,\n                dtype=torch.int32,\n                device=self.device,\n            )\n            kv_indptr = self.kv_indptr[: bs + 1]\n            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n            kv_indices = self.cuda_graph_kv_indices\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                seq_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n\n            if self.use_mla:\n                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]\n                max_q_len = num_tokens_per_bs\n\n                if _use_mla_ps_kernel:\n\n                    num_kv_splits = self.max_split_per_batch\n\n                    self.make_mla_meta_data(\n                        qo_indptr,\n                        kv_indptr,\n                        kv_last_page_len,\n         
               self.work_metadata,\n                        self.work_info_set,\n                        self.work_indptr,\n                        self.reduce_indptr,\n                        self.reduce_final_map,\n                        self.reduce_partial_map,\n                        max_q_len,\n                        fast_mode=fast_mode,\n                        max_split_per_batch=num_kv_splits,\n                        intra_batch_mode=intra_batch_mode,\n                    )\n\n                    work_metadata = self.work_metadata\n                    work_info_set = self.work_info_set\n                    work_indptr = self.work_indptr\n\n                    reduce_indptr = self.reduce_indptr\n                    reduce_final_map = self.reduce_final_map\n                    reduce_partial_map = self.reduce_partial_map\n\n                self.forward_metadata = ForwardMetadata(\n                    kv_indptr,\n                    kv_indices,\n                    qo_indptr,\n                    kv_last_page_len,\n                    max_q_len,\n                    kv_indptr[-1].item(),\n                    work_metadata=work_metadata,\n                    work_info_set=work_info_set,\n                    work_indptr=work_indptr,\n                    reduce_indptr=reduce_indptr,\n                    reduce_final_map=reduce_final_map,\n                    reduce_partial_map=reduce_partial_map,\n                    num_kv_splits=num_kv_splits,\n                )\n            else:\n                # Non-MLA draft_extend cuda graph: use triton extend kernel\n                self.forward_metadata = ForwardMetadata(\n                    kv_indptr,\n                    kv_indices,\n                    qo_indptr,\n                    None,\n                    num_tokens_per_bs,\n                    None,\n                    custom_mask=None,\n                    mask_indptr=None,\n                    max_extend_len=num_tokens_per_bs,\n                )\n       
 else:\n            raise ValueError(f\"Invalid mode: {forward_mode=}\")\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n\n        num_kv_splits = None\n        # num_kv_splits_indptr = None\n\n        work_metadata = None\n        work_info_set = None\n        work_indptr = None\n\n        reduce_indptr = None\n        reduce_final_map = None\n        reduce_partial_map = None\n\n        swa_page_table = None\n        max_kv_len = torch.max(seq_lens).item()\n\n        if forward_mode.is_decode_or_idle():\n            qo_indptr = None\n            kv_last_page_len = None\n            max_q_len = None\n\n            if spec_info is None:\n                if not self.use_triton_unified_attention:\n                    kv_indptr = self.kv_indptr\n                    kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n                    kv_indptr = kv_indptr[: bs + 1]\n                    kv_indices = self.cuda_graph_kv_indices\n                    create_flashinfer_kv_indices_triton[(bs,)](\n                        self.req_to_token,\n                        req_pool_indices,\n                        seq_lens,\n                        kv_indptr,\n                        None,\n                        kv_indices,\n                        self.req_to_token.stride(0),\n                    )\n                else:\n                    max_q_len = 1\n                    max_num_blocks_per_seq = (\n                        self.max_context_len + self.page_size - 1\n                    ) // self.page_size\n                    kv_indices = self.cuda_graph_kv_indices.view(\n                        -1, max_num_blocks_per_seq\n                    )\n\n             
       page_indices = self.req_to_token[req_pool_indices[:bs], :max_kv_len]\n\n                    if self.use_sliding_window_kv_pool:\n                        swa_page_indices = (\n                            self.token_to_kv_pool.translate_loc_from_full_to_swa(\n                                page_indices\n                            )\n                        )\n\n                        page_indices = self._transform_table_1_to_real(page_indices)\n                        swa_page_indices = self._transform_table_1_to_real(\n                            swa_page_indices\n                        )\n\n                        new_rows = swa_page_indices.shape[0]\n                        new_cols = swa_page_indices.shape[1]\n\n                        kv_indices[:new_rows, :new_cols].copy_(page_indices)\n                        swa_page_table = self.cuda_graph_swa_page_table\n                        swa_page_table[:new_rows, :new_cols].copy_(swa_page_indices)\n\n                    qo_indptr = self.qo_indptr[: bs + 1]\n                    qo_indptr[1 : bs + 1] = torch.cumsum(\n                        self.cuda_graph_kv_last_page_len[:bs], dim=0\n                    )\n\n                    kv_indptr = None\n            else:\n                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices\n\n            if self.use_mla:\n                qo_indptr = self.qo_indptr_[: bs + 1]\n                qo_indptr[1 : bs + 1] = torch.cumsum(\n                    self.cuda_graph_kv_last_page_len[:bs], dim=0\n                )\n                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]\n                max_q_len = 1\n\n                if _use_mla_ps_kernel:\n                    num_kv_splits = self.max_split_per_batch\n\n                    self.make_mla_meta_data(\n                        qo_indptr,\n                        kv_indptr,\n                        kv_last_page_len,\n                        self.work_metadata,\n                        
self.work_info_set,\n                        self.work_indptr,\n                        self.reduce_indptr,\n                        self.reduce_final_map,\n                        self.reduce_partial_map,\n                        max_q_len,\n                        fast_mode=fast_mode,\n                        max_split_per_batch=num_kv_splits,\n                        intra_batch_mode=intra_batch_mode,\n                    )\n\n                    work_metadata = self.work_metadata\n                    work_info_set = self.work_info_set\n                    work_indptr = self.work_indptr\n\n                    reduce_indptr = self.reduce_indptr\n                    reduce_final_map = self.reduce_final_map\n                    reduce_partial_map = self.reduce_partial_map\n\n            self.forward_metadata = ForwardMetadata(\n                kv_indptr,\n                kv_indices,\n                qo_indptr,\n                kv_last_page_len,\n                max_q_len,\n                max_kv_len,\n                work_metadata=work_metadata,\n                work_info_set=work_info_set,\n                work_indptr=work_indptr,\n                reduce_indptr=reduce_indptr,\n                reduce_final_map=reduce_final_map,\n                reduce_partial_map=reduce_partial_map,\n                num_kv_splits=num_kv_splits,\n                swa_page_table=swa_page_table,\n                # num_kv_splits_indptr=num_kv_splits_indptr,\n            )\n\n        elif forward_mode.is_target_verify():\n            bs = len(req_pool_indices)\n            qo_indptr = self.qo_indptr[: bs + 1]\n            qo_indptr[: bs + 1] = torch.arange(\n                0,\n                (1 + bs) * self.num_draft_tokens,\n                step=self.num_draft_tokens,\n                dtype=torch.int32,\n                device=self.device,\n            )\n            if self.use_mla:\n                kv_lens = seq_lens + self.num_draft_tokens\n            else:\n                
kv_lens = seq_lens\n            kv_indptr = self.kv_indptr[: bs + 1]\n            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)\n            kv_indices = self.cuda_graph_kv_indices\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                kv_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]\n            max_q_len = self.num_draft_tokens\n\n            if self.use_mla:\n                if _use_mla_ps_kernel:\n\n                    num_kv_splits = self.max_split_per_batch\n\n                    self.make_mla_meta_data(\n                        qo_indptr,\n                        kv_indptr,\n                        kv_last_page_len,\n                        self.work_metadata,\n                        self.work_info_set,\n                        self.work_indptr,\n                        self.reduce_indptr,\n                        self.reduce_final_map,\n                        self.reduce_partial_map,\n                        max_q_len,\n                        fast_mode=fast_mode,\n                        max_split_per_batch=num_kv_splits,\n                        intra_batch_mode=intra_batch_mode,\n                    )\n\n                    work_metadata = self.work_metadata\n                    work_info_set = self.work_info_set\n                    work_indptr = self.work_indptr\n\n                    reduce_indptr = self.reduce_indptr\n                    reduce_final_map = self.reduce_final_map\n                    reduce_partial_map = self.reduce_partial_map\n\n                self.forward_metadata = ForwardMetadata(\n                    kv_indptr,\n                    kv_indices,\n                    qo_indptr,\n                    kv_last_page_len,\n                    max_q_len,\n                   
 kv_indptr[-1].item(),\n                    work_metadata=work_metadata,\n                    work_info_set=work_info_set,\n                    work_indptr=work_indptr,\n                    reduce_indptr=reduce_indptr,\n                    reduce_final_map=reduce_final_map,\n                    reduce_partial_map=reduce_partial_map,\n                    num_kv_splits=num_kv_splits,\n                )\n            else:\n                custom_mask = self.cuda_graph_custom_mask\n                custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask\n                seq_mask_len = max_q_len * (seq_lens + max_q_len)\n                mask_indptr = self.mask_indptr[: bs + 1]\n                mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)\n\n                self.forward_metadata = ForwardMetadata(\n                    kv_indptr,\n                    kv_indices,\n                    qo_indptr,\n                    kv_last_page_len,\n                    max_q_len,\n                    kv_indptr[-1].item(),\n                    custom_mask=custom_mask,\n                    mask_indptr=mask_indptr,\n                    max_extend_len=max_q_len,\n                )\n        elif forward_mode.is_draft_extend_v2():\n            # EAGLE V2: Fixed num_draft_tokens per batch\n            self._ensure_spec_v2_topk_supported()\n            seq_lens = seq_lens[:bs]\n            num_tokens_per_bs = self._resolve_v2_num_draft_tokens()\n            extend_lens = torch.full(\n                (bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device\n            )\n\n            qo_indptr = self.qo_indptr[: bs + 1]\n            qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)\n            kv_indptr = self.kv_indptr[: bs + 1]\n            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n            kv_indices = self.cuda_graph_kv_indices\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n          
      req_pool_indices,\n                seq_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n\n            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]\n            max_q_len = num_tokens_per_bs\n\n            if self.use_mla and _use_mla_ps_kernel:\n\n                num_kv_splits = self.max_split_per_batch\n\n                self.make_mla_meta_data(\n                    qo_indptr,\n                    kv_indptr,\n                    kv_last_page_len,\n                    self.work_metadata,\n                    self.work_info_set,\n                    self.work_indptr,\n                    self.reduce_indptr,\n                    self.reduce_final_map,\n                    self.reduce_partial_map,\n                    max_q_len,\n                    fast_mode=fast_mode,\n                    max_split_per_batch=num_kv_splits,\n                    intra_batch_mode=intra_batch_mode,\n                )\n\n                work_metadata = self.work_metadata\n                work_info_set = self.work_info_set\n                work_indptr = self.work_indptr\n\n                reduce_indptr = self.reduce_indptr\n                reduce_final_map = self.reduce_final_map\n                reduce_partial_map = self.reduce_partial_map\n\n            self.forward_metadata = ForwardMetadata(\n                kv_indptr,\n                kv_indices,\n                qo_indptr,\n                kv_last_page_len,\n                max_q_len,\n                kv_indptr[-1].item(),\n                work_metadata=work_metadata,\n                work_info_set=work_info_set,\n                work_indptr=work_indptr,\n                reduce_indptr=reduce_indptr,\n                reduce_final_map=reduce_final_map,\n                reduce_partial_map=reduce_partial_map,\n                num_kv_splits=num_kv_splits,\n            )\n        elif forward_mode.is_draft_extend():\n       
     # EAGLE V1: Uses spec_info.accept_length\n            num_tokens_per_bs = self.speculative_num_steps + 1\n            seq_lens = seq_lens[:bs]\n            accept_lens = spec_info.accept_length[:bs]\n            qo_indptr = self.qo_indptr[: bs + 1]\n            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)\n            kv_indptr = self.kv_indptr[: bs + 1]\n            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n            kv_indices = self.cuda_graph_kv_indices\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                seq_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n\n            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]\n            max_q_len = num_tokens_per_bs\n\n            if self.use_mla and _use_mla_ps_kernel:\n\n                num_kv_splits = self.max_split_per_batch\n\n                self.make_mla_meta_data(\n                    qo_indptr,\n                    kv_indptr,\n                    kv_last_page_len,\n                    self.work_metadata,\n                    self.work_info_set,\n                    self.work_indptr,\n                    self.reduce_indptr,\n                    self.reduce_final_map,\n                    self.reduce_partial_map,\n                    max_q_len,\n                    fast_mode=fast_mode,\n                    max_split_per_batch=num_kv_splits,\n                    intra_batch_mode=intra_batch_mode,\n                )\n\n                work_metadata = self.work_metadata\n                work_info_set = self.work_info_set\n                work_indptr = self.work_indptr\n\n                reduce_indptr = self.reduce_indptr\n                reduce_final_map = self.reduce_final_map\n                reduce_partial_map = self.reduce_partial_map\n\n            self.forward_metadata = 
ForwardMetadata(\n                kv_indptr,\n                kv_indices,\n                qo_indptr,\n                kv_last_page_len,\n                max_q_len,\n                kv_indptr[-1].item(),\n                work_metadata=work_metadata,\n                work_info_set=work_info_set,\n                work_indptr=work_indptr,\n                reduce_indptr=reduce_indptr,\n                reduce_final_map=reduce_final_map,\n                reduce_partial_map=reduce_partial_map,\n                num_kv_splits=num_kv_splits,\n            )\n\n        else:\n            raise ValueError(\"Invalid forward mode\")\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        return 1 if self.num_draft_tokens is None else self.num_draft_tokens\n\n    def update_verify_buffers_to_fill_after_draft(\n        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]\n    ):\n        # AITER verify path does not require post-draft buffer patching currently.\n        # This override prevents overlap-plan stream mode from failing with the\n        # base class NotImplementedError.\n        pass\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n        sinks=None,\n    ):\n        self.logits_soft_cap = layer.logit_cap\n\n        cache_loc = (\n            forward_batch.out_cache_loc\n            if not layer.is_cross_attention\n            else forward_batch.encoder_out_cache_loc\n        )\n\n        if k is not None:\n            assert v is not None\n            if save_kv_cache:\n                if self.use_triton_unified_attention:\n                    token_to_kv_pool = forward_batch.token_to_kv_pool\n                    k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(\n                        layer.layer_id\n                    )\n                    slot_mapping_swa = 
token_to_kv_pool.full_to_swa_index_mapping\n\n                    launch_reshape_and_cache_flash(\n                        k.view(-1, layer.tp_k_head_num, layer.qk_head_dim),\n                        v.view(-1, layer.tp_v_head_num, layer.v_head_dim),\n                        k_cache.view(\n                            -1, self.page_size, layer.tp_k_head_num, layer.qk_head_dim\n                        ),\n                        v_cache.view(\n                            -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim\n                        ),\n                        cache_loc,\n                        (\n                            slot_mapping_swa.long()\n                            if layer.sliding_window_size > 0\n                            else None\n                        ),\n                    )\n\n                elif self.use_mla:\n                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)\n                else:\n                    forward_batch.token_to_kv_pool.set_kv_buffer(\n                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale\n                    )\n\n        if self.use_mla:\n            max_q_len = self.forward_metadata.max_q_len\n            max_kv_len = self.forward_metadata.max_kv_len\n            kv_indptr = self.forward_metadata.kv_indptr\n            kv_indices = self.forward_metadata.kv_indices\n            qo_indptr = self.forward_metadata.qo_indptr\n            K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n            V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)\n            kv_lora_rank = V_Buffer.shape[-1]\n            qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank\n            qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim\n            assert len(q.shape) == 3\n            assert len(k.shape) == 3\n            assert len(v.shape) == 3\n\n            if (\n                
forward_batch.forward_mode.is_extend()\n                and not forward_batch.forward_mode.is_target_verify()\n                and not forward_batch.forward_mode.is_draft_extend()\n                and not forward_batch.forward_mode.is_draft_extend_v2()\n            ):\n                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)\n                if kv_indices.shape[0] == 0 or extend_no_prefix:\n                    if _use_fp8_prefill_attn:\n                        output = self.mla_fp8_prefill_attn(\n                            q,\n                            k,\n                            v,\n                            layer,\n                        )\n                    else:\n                        output = flash_attn_varlen_func(\n                            q,\n                            k,\n                            v,\n                            qo_indptr,\n                            qo_indptr,\n                            max_q_len,\n                            max_q_len,\n                            softmax_scale=layer.scaling,\n                            causal=True,\n                        )\n                    return output\n                elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):\n                    K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)\n                    kvc, k_pe = torch.split(\n                        K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1\n                    )\n\n                    if self.kv_cache_dtype == fp8_dtype:\n                        dtype = q.dtype\n\n                        kvc = kvc.to(dtype)\n                        k_pe = k_pe.to(dtype)\n\n                    if (\n                        _use_fp8_prefill_attn\n                        and layer.kv_b_proj.weight.dtype == torch.uint8\n                    ):\n                        # MXFP4 weights + FP8 prefill: fuse GEMM, nope/v split, and k_pe cat\n                        # into a single kernel 
(fused_gemm_afp4wfp4_split_cat) that writes k and v\n                        # directly in FP8, avoiding a separate elementwise cast\n                        k, v = layer.kv_b_proj(\n                            (\n                                kvc.squeeze(1),\n                                k_pe.expand(-1, layer.tp_k_head_num, -1),\n                                qk_nope_head_dim,\n                                layer.v_head_dim,\n                                fp8_dtype,\n                            )\n                        )[0]\n                    else:\n                        kv = layer.kv_b_proj(kvc.contiguous())[0]\n\n                        kv = kv.view(\n                            -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim\n                        )\n                        k, v = torch.split(\n                            kv, [qk_nope_head_dim, layer.v_head_dim], dim=-1\n                        )\n                        k = torch.cat(\n                            [\n                                k,\n                                torch.broadcast_to(\n                                    k_pe,\n                                    (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),\n                                ),\n                            ],\n                            dim=-1,\n                        )\n\n                    assert (\n                        forward_batch.extend_prefix_lens.shape\n                        == forward_batch.extend_seq_lens.shape\n                    )\n\n                    if _use_fp8_prefill_attn:\n                        return self.mla_fp8_prefill_attn(q, k, v, layer)\n                    else:\n                        return flash_attn_varlen_func(\n                            q,\n                            k,\n                            v,\n                            qo_indptr,\n                            kv_indptr,\n                            max_q_len,\n                       
     max_kv_len,\n                            softmax_scale=layer.scaling,\n                            causal=True,\n                        )\n\n                else:\n                    if layer.qk_head_dim != layer.v_head_dim:\n                        o = q.new_empty(\n                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)\n                        )\n                    else:\n                        o = torch.empty_like(q)\n\n                    mla_prefill_fwd(\n                        q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n                        K_Buffer.view(-1, 1, 1, layer.qk_head_dim),\n                        o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n                        qo_indptr,\n                        kv_indptr,\n                        kv_indices,\n                        self.forward_metadata.kv_last_page_len,\n                        self.forward_metadata.max_q_len,\n                        layer.scaling,\n                        layer.logit_cap,\n                    )\n                    K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)\n                    return o\n            elif forward_batch.forward_mode.is_target_verify():\n                o = q.new_empty(\n                    (q.shape[0], layer.tp_q_head_num, layer.v_head_dim),\n                    dtype=self.input_dtype,\n                )\n\n                work_metadata = self.forward_metadata.work_metadata\n                work_indptr = self.forward_metadata.work_indptr\n                work_info_set = self.forward_metadata.work_info_set\n\n                reduce_indptr = self.forward_metadata.reduce_indptr\n                reduce_final_map = self.forward_metadata.reduce_final_map\n                reduce_partial_map = self.forward_metadata.reduce_partial_map\n\n                num_kv_splits = self.forward_metadata.num_kv_splits\n\n                mla_decode_fwd(\n                    q,\n                    
K_Buffer.view(-1, 1, 1, layer.qk_head_dim),\n                    o,\n                    self.forward_metadata.qo_indptr,\n                    self.forward_metadata.kv_indptr,\n                    self.forward_metadata.kv_indices,\n                    self.forward_metadata.kv_last_page_len,\n                    self.forward_metadata.max_q_len,\n                    sm_scale=layer.scaling,\n                    logit_cap=layer.logit_cap,\n                    work_meta_data=work_metadata,\n                    work_indptr=work_indptr,\n                    work_info_set=work_info_set,\n                    reduce_indptr=reduce_indptr,\n                    reduce_final_map=reduce_final_map,\n                    reduce_partial_map=reduce_partial_map,\n                    q_scale=(\n                        layer.k_scale if layer.k_scale is not None else self.k_scale\n                    ),\n                    kv_scale=(\n                        layer.k_scale if layer.k_scale is not None else self.k_scale\n                    ),\n                    intra_batch_mode=intra_batch_mode,\n                    num_kv_splits=num_kv_splits,\n                )\n                return o\n            elif (\n                forward_batch.forward_mode.is_draft_extend()\n                or forward_batch.forward_mode.is_draft_extend_v2()\n            ):\n\n                work_metadata = self.forward_metadata.work_metadata\n                work_indptr = self.forward_metadata.work_indptr\n                work_info_set = self.forward_metadata.work_info_set\n\n                reduce_indptr = self.forward_metadata.reduce_indptr\n                reduce_final_map = self.forward_metadata.reduce_final_map\n                reduce_partial_map = self.forward_metadata.reduce_partial_map\n\n                num_kv_splits = self.forward_metadata.num_kv_splits\n\n                if self.forward_metadata.run_graph is not True:\n\n                    bs, q_pad, q_mask = pad_sequence_with_mask(\n            
            q.view(q.shape[0], -1),\n                        qo_indptr[:-1],\n                        forward_batch.extend_seq_lens,\n                        self.forward_metadata.max_q_len,\n                    )\n                    o = q.new_empty(\n                        (\n                            bs * self.forward_metadata.max_q_len,\n                            layer.tp_q_head_num,\n                            layer.v_head_dim,\n                        ),\n                        dtype=self.input_dtype,\n                    )\n                    mla_decode_fwd(\n                        q_pad.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n                        K_Buffer.view(-1, 1, 1, layer.qk_head_dim),\n                        o,\n                        self.forward_metadata.qo_indptr,\n                        self.forward_metadata.kv_indptr,\n                        self.forward_metadata.kv_indices,\n                        self.forward_metadata.kv_last_page_len,\n                        self.forward_metadata.max_q_len,\n                        sm_scale=layer.scaling,\n                        logit_cap=layer.logit_cap,\n                        work_meta_data=work_metadata,\n                        work_indptr=work_indptr,\n                        work_info_set=work_info_set,\n                        reduce_indptr=reduce_indptr,\n                        reduce_final_map=reduce_final_map,\n                        reduce_partial_map=reduce_partial_map,\n                        q_scale=(\n                            layer.k_scale if layer.k_scale is not None else self.k_scale\n                        ),\n                        kv_scale=(\n                            layer.k_scale if layer.k_scale is not None else self.k_scale\n                        ),\n                        intra_batch_mode=intra_batch_mode,\n                        num_kv_splits=num_kv_splits,\n                    )\n\n                    total_valid_q = 
int(qo_indptr[-1].item())\n                    return o[:total_valid_q]\n                else:\n                    o = q.new_empty(\n                        (q.shape[0], layer.tp_q_head_num, layer.v_head_dim),\n                        dtype=self.input_dtype,\n                    )\n\n                    mla_decode_fwd(\n                        q,\n                        K_Buffer.view(-1, 1, 1, layer.qk_head_dim),\n                        o,\n                        self.forward_metadata.qo_indptr,\n                        self.forward_metadata.kv_indptr,\n                        self.forward_metadata.kv_indices,\n                        self.forward_metadata.kv_last_page_len,\n                        self.forward_metadata.max_q_len,\n                        sm_scale=layer.scaling,\n                        logit_cap=layer.logit_cap,\n                        work_meta_data=work_metadata,\n                        work_indptr=work_indptr,\n                        work_info_set=work_info_set,\n                        reduce_indptr=reduce_indptr,\n                        reduce_final_map=reduce_final_map,\n                        reduce_partial_map=reduce_partial_map,\n                        q_scale=(\n                            layer.k_scale if layer.k_scale is not None else self.k_scale\n                        ),\n                        kv_scale=(\n                            layer.k_scale if layer.k_scale is not None else self.k_scale\n                        ),\n                        intra_batch_mode=intra_batch_mode,\n                        num_kv_splits=num_kv_splits,\n                    )\n                    return o\n            else:\n                raise ValueError(\n                    f\"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}\"\n                )\n        else:\n            if (\n                forward_batch.forward_mode.is_target_verify()\n                or forward_batch.forward_mode.is_draft_extend()\n            
):\n                # Use triton extend kernel which supports custom masks and causal masking\n                if layer.qk_head_dim != layer.v_head_dim:\n                    o = q.new_empty(\n                        (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)\n                    )\n                else:\n                    o = torch.empty_like(q)\n\n                self.extend_attention_fwd(\n                    q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n                    k.contiguous(),\n                    v.contiguous(),\n                    o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n                    forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n                    forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n                    self.forward_metadata.qo_indptr,\n                    self.forward_metadata.kv_indptr,\n                    self.forward_metadata.kv_indices,\n                    self.forward_metadata.custom_mask,\n                    True,  # causal\n                    self.forward_metadata.mask_indptr,\n                    self.forward_metadata.max_extend_len,\n                    1.0,  # k_scale\n                    1.0,  # v_scale\n                    layer.scaling,\n                    logit_cap=layer.logit_cap,\n                )\n                return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n\n            k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(\n                layer.layer_id\n            )\n\n            bs0 = forward_batch.batch_size + 1\n\n            # TODO kkhuang-amd need to remove it when mha_batch_prefill_func support fp8-kv\n            if self.kv_cache_dtype == fp8_dtype:\n                dtype = q.dtype\n                k_cache = k_cache.to(dtype)\n                v_cache = v_cache.to(dtype)\n\n            window_size = (-1, -1)\n            page_table = self.forward_metadata.kv_indices\n\n            if 
layer.sliding_window_size is not None and layer.sliding_window_size > -1:\n                window_size = (layer.sliding_window_size, -1)\n                # page_table = self.token_to_kv_pool.translate_loc_from_full_to_swa(\n                #    page_table\n                # )\n                page_table = self.forward_metadata.swa_page_table\n\n            o = mha_batch_prefill_func(\n                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),\n                k_cache,\n                v_cache,\n                self.qo_indptr[:bs0],\n                self.forward_metadata.kv_indptr[:bs0],\n                page_table,\n                self.forward_metadata.max_q_len,\n                self.forward_metadata.max_kv_len,\n                causal=True,\n                logits_soft_cap=self.logits_soft_cap,\n                alibi_slopes=None,\n                return_lse=False,\n                return_attn_probs=False,\n                window_size=window_size,\n                sink_ptr=sinks,\n            )\n\n            return o.view(-1, layer.tp_q_head_num * layer.head_dim)\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n        sinks=None,\n    ):\n\n        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)\n\n        if layer.qk_head_dim != layer.v_head_dim:\n            o = q.new_empty(\n                (q.shape[0], layer.tp_q_head_num * layer.v_head_dim),\n                dtype=self.input_dtype,\n            )\n        else:\n            o = torch.empty_like(q, dtype=self.input_dtype)\n\n        if save_kv_cache:\n            if self.use_triton_unified_attention:\n                token_to_kv_pool = forward_batch.token_to_kv_pool\n                k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(\n                    layer.layer_id\n                )\n         
       slot_mapping_swa = token_to_kv_pool.full_to_swa_index_mapping\n\n                launch_reshape_and_cache_flash(\n                    k.view(-1, layer.tp_k_head_num, layer.qk_head_dim),\n                    v.view(-1, layer.tp_v_head_num, layer.v_head_dim),\n                    k_cache.view(\n                        -1, self.page_size, layer.tp_k_head_num, layer.qk_head_dim\n                    ),\n                    v_cache.view(\n                        -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim\n                    ),\n                    forward_batch.out_cache_loc,\n                    slot_mapping_swa.long() if layer.sliding_window_size > 0 else None,\n                )\n            else:\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer, forward_batch.out_cache_loc, k, v\n                )\n\n        if self.use_mla:\n            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n\n            work_metadata = self.forward_metadata.work_metadata\n            work_indptr = self.forward_metadata.work_indptr\n            work_info_set = self.forward_metadata.work_info_set\n\n            reduce_indptr = self.forward_metadata.reduce_indptr\n            reduce_final_map = self.forward_metadata.reduce_final_map\n            reduce_partial_map = self.forward_metadata.reduce_partial_map\n\n            num_kv_splits = self.forward_metadata.num_kv_splits\n\n            mla_decode_fwd(\n                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n                k_buffer.view(-1, 1, 1, layer.qk_head_dim),\n                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n                self.forward_metadata.qo_indptr,\n                self.forward_metadata.kv_indptr,\n                self.forward_metadata.kv_indices,\n                self.forward_metadata.kv_last_page_len,\n                self.forward_metadata.max_q_len,\n                sm_scale=layer.scaling,\n             
   logit_cap=layer.logit_cap,\n                work_meta_data=work_metadata,\n                work_indptr=work_indptr,\n                work_info_set=work_info_set,\n                reduce_indptr=reduce_indptr,\n                reduce_final_map=reduce_final_map,\n                reduce_partial_map=reduce_partial_map,\n                q_scale=layer.k_scale if layer.k_scale is not None else self.k_scale,\n                kv_scale=layer.k_scale if layer.k_scale is not None else self.k_scale,\n                intra_batch_mode=intra_batch_mode,\n                num_kv_splits=num_kv_splits,\n            )\n        else:\n            self.logits_soft_cap = layer.logit_cap\n\n            k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(\n                layer.layer_id\n            )\n\n            # TODO kkhuang-amd need to remove it when paged_attention_ragged support fp8-kv\n            if self.kv_cache_dtype == fp8_dtype:\n                dtype = q.dtype\n\n                k_cache = k_cache.to(dtype)\n                v_cache = v_cache.to(dtype)\n\n            if self.use_triton_unified_attention:\n\n                bs = forward_batch.batch_size\n                window_size = (-1, -1)\n                page_table = self.forward_metadata.kv_indices\n\n                if (\n                    layer.sliding_window_size is not None\n                    and layer.sliding_window_size > -1\n                ):\n                    window_size = (layer.sliding_window_size - 1, 0)\n                    page_table = self.forward_metadata.swa_page_table\n\n                o = torch.empty_like(q)\n\n                max_kv_len = page_table.shape[1]\n\n                unified_attention(\n                    q=q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n                    k=k_cache.view(\n                        -1, self.page_size, layer.tp_k_head_num, layer.qk_head_dim\n                    ),\n                    v=v_cache.view(\n                        -1, 
self.page_size, layer.tp_v_head_num, layer.v_head_dim\n                    ),\n                    out=o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n                    cu_seqlens_q=self.forward_metadata.qo_indptr,\n                    seqused_k=forward_batch.seq_lens,\n                    max_seqlen_q=self.forward_metadata.max_q_len,\n                    max_seqlen_k=max_kv_len,\n                    softmax_scale=self.scale,\n                    causal=True,\n                    window_size=window_size,\n                    block_table=page_table,\n                    softcap=0,\n                    q_descale=None,\n                    k_descale=None,\n                    v_descale=None,\n                    sinks=sinks,\n                )\n            else:\n                paged_attention_ragged(\n                    o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n                    self.workspace_buffer,\n                    q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n                    k_cache.view(-1, 1, layer.tp_k_head_num, layer.qk_head_dim),\n                    v_cache.view(-1, 1, layer.tp_v_head_num, layer.v_head_dim),\n                    self.scale,\n                    self.forward_metadata.kv_indptr,\n                    self.forward_metadata.kv_indices,\n                    self.kv_last_page_len,\n                    1,\n                    self.max_num_partitions,\n                    None,\n                    \"auto\",\n                    \"NHD\",\n                    self.logits_soft_cap,\n                    self.k_scale,\n                    self.v_scale,\n                    None,\n                    _AITER_PARTITION_SIZE_ROCM,\n                )\n\n        return o\n\n\nclass AiterIndicesUpdaterPrefill:\n    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):\n        # Parse Constants\n        self.num_qo_heads = (\n            model_runner.model_config.num_attention_heads // 
get_attention_tp_size()\n        )\n        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(\n            get_attention_tp_size()\n        )\n        self.head_dim = model_runner.model_config.head_dim\n        self.data_type = model_runner.kv_cache_dtype\n        self.q_data_type = model_runner.dtype\n        self.sliding_window_size = model_runner.sliding_window_size\n        self.attn_backend = attn_backend\n\n        # Buffers and wrappers\n        self.kv_indptr = attn_backend.kv_indptr\n        self.kv_last_page_len = attn_backend.kv_last_page_len\n        self.qo_indptr = attn_backend.qo_indptr\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.update = self.update_single_wrapper\n\n        self.kv_indices = None\n        self.max_q_len = 0\n        self.max_kv_len = 0\n\n    def update(\n        self,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        prefix_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        spec_info: Optional[SpecInput],\n    ):\n        # Keep the signature for type checking. 
It will be assigned during runtime.\n        raise NotImplementedError()\n\n    def update_single_wrapper(\n        self,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        prefix_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        spec_info: Optional[SpecInput],\n    ):\n\n        kv_start_idx = None\n        kv_indptr = self.kv_indptr\n        qo_indptr = self.qo_indptr\n        paged_kernel_lens = seq_lens\n        paged_kernel_lens_sum = seq_lens_sum\n\n        bs = len(req_pool_indices)\n        if spec_info is None:\n            # Normal extend\n            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)\n            kv_indptr = kv_indptr[: bs + 1]\n\n            # (TODO: Kk) WA - CI test_moe_eval_accuracy_large.py\n            # mha_batch_prefill reads 128 data to do computation\n            # if real data is not long enough then original padding value 0 is used\n            # but the 0 location will be made nan (noqa) in cuda graph capture mode\n            # this will cause the output tensor values to become nan\n            # WA is to ensure that the last index of the pool is not changed\n            kv_indices = torch.empty(\n                paged_kernel_lens_sum + 256,\n                dtype=torch.int32,\n                device=req_pool_indices.device,\n            )\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                paged_kernel_lens,\n                kv_indptr,\n                kv_start_idx,\n                kv_indices,\n                self.req_to_token.shape[1],\n            )\n\n            token_num = kv_indptr[-1]\n            kv_indices[token_num:] = kv_indices[0]\n\n            self.max_kv_len = torch.max(paged_kernel_lens).item()\n\n            extend_lens = seq_lens - prefix_lens\n            self.max_q_len = torch.max(extend_lens).item()\n\n            qo_indptr[1 : 
bs + 1] = torch.cumsum(extend_lens, dim=0)\n            qo_indptr = qo_indptr[: bs + 1]\n            custom_mask = None\n        else:\n            kv_indices, kv_indptr, qo_indptr, custom_mask = (\n                spec_info.generate_attn_arg_prefill(\n                    req_pool_indices,\n                    paged_kernel_lens,\n                    paged_kernel_lens_sum,\n                    self.req_to_token,\n                )\n            )\n\n        self.kv_indices = kv_indices\n\n\nclass AiterMlaIndicesUpdaterPrefill:\n    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):\n        # Parse Constants\n        self.attn_backend = attn_backend\n\n        # Buffers and wrappers\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.update = self.update_single_wrapper\n\n        self.kv_indptr = None\n        self.kv_indices = None\n        self.qo_indptr = None\n        self.kv_last_page_len = None\n        self.max_q_len = 0\n        self.max_kv_len = 0\n\n    def update(\n        self,\n        req_pool_indices: torch.Tensor,\n        kv_lens: torch.Tensor,\n        kv_lens_sum: int,\n        extend_lens: torch.Tensor,\n        max_q_len: int,\n        max_kv_len: int,\n        spec_info: Optional[SpecInput],\n    ):\n        # Keep the signature for type checking. 
It will be assigned during runtime.\n        raise NotImplementedError()\n\n    def update_single_wrapper(\n        self,\n        req_pool_indices: torch.Tensor,\n        kv_lens: torch.Tensor,\n        kv_lens_sum: int,\n        extend_lens: torch.Tensor,\n        max_q_len: int,\n        max_kv_len: int,\n        spec_info: Optional[SpecInput],\n    ):\n        bs = len(req_pool_indices)\n\n        kv_indptr = self.attn_backend.kv_indptr\n\n        if spec_info is None:\n            # Normal extend\n            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)\n            kv_indptr = kv_indptr[: bs + 1]\n            kv_indices = torch.empty(\n                kv_lens_sum,\n                dtype=torch.int32,\n                device=req_pool_indices.device,\n            )\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                kv_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n\n            qo_indptr = self.attn_backend.qo_indptr\n            qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)\n            qo_indptr = qo_indptr[: bs + 1]\n        else:\n            kv_indices, kv_indptr, qo_indptr, custom_mask = (\n                spec_info.generate_attn_arg_prefill(\n                    req_pool_indices,\n                    kv_lens,\n                    kv_lens_sum,\n                    self.req_to_token,\n                )\n            )\n\n        self.kv_indptr = kv_indptr\n        self.kv_indices = kv_indices\n        self.qo_indptr = qo_indptr\n        self.max_q_len = max_q_len\n        self.max_kv_len = max_kv_len\n\n\nclass AiterMultiStepDraftBackend:\n    \"\"\"\n    Wrap multiple triton attention backends as one for multiple consecutive\n    draft decoding steps.\n    \"\"\"\n\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        
topk: int,\n        speculative_num_steps: int,\n    ):\n        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices\n\n        self.topk = topk\n        self.speculative_num_steps = speculative_num_steps\n        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices\n        max_bs = model_runner.req_to_token_pool.size * self.topk\n        self.kv_indptr = torch.zeros(\n            (\n                self.speculative_num_steps,\n                max_bs + 1,\n            ),\n            dtype=torch.int32,\n            device=model_runner.device,\n        )\n        self.attn_backends = []\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends.append(\n                AiterAttnBackend(\n                    model_runner,\n                    skip_prefill=True,\n                    kv_indptr_buf=self.kv_indptr[i],\n                    topk=topk,\n                )\n            )\n        self.max_context_len = self.attn_backends[0].max_context_len\n        self.num_head = (\n            model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.device = model_runner.device\n        # Cached variables for generate_draft_decode_kv_indices\n        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]\n        self.page_size = model_runner.server_args.page_size\n\n    def common_template(\n        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int\n    ):\n        num_seqs = forward_batch.batch_size\n        bs = self.topk * num_seqs\n        seq_lens_sum = forward_batch.seq_lens_sum\n\n        self.generate_draft_decode_kv_indices[\n            (self.speculative_num_steps, num_seqs, self.topk)\n        ](\n            forward_batch.req_pool_indices,\n            forward_batch.req_to_token_pool.req_to_token,\n            forward_batch.seq_lens,\n            kv_indices_buffer,\n            self.kv_indptr,\n      
      forward_batch.positions,\n            self.pool_len,\n            kv_indices_buffer.shape[1],\n            self.kv_indptr.shape[1],\n            triton.next_power_of_2(num_seqs),\n            triton.next_power_of_2(self.speculative_num_steps),\n            triton.next_power_of_2(bs),\n            self.page_size,\n        )\n\n        for i in range(self.speculative_num_steps - 1):\n            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]\n            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][\n                : seq_lens_sum * self.topk + bs * (i + 1)\n            ]\n            call_fn(i, forward_batch)\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        kv_indices = torch.empty(\n            (\n                self.speculative_num_steps,\n                forward_batch.batch_size * self.topk * self.max_context_len,\n            ),\n            dtype=torch.int32,\n            device=self.device,\n        )\n\n        def call_fn(i, forward_batch):\n            forward_batch.spec_info.kv_indptr = (\n                forward_batch.spec_info.kv_indptr.clone()\n            )\n            forward_batch.spec_info.kv_indices = (\n                forward_batch.spec_info.kv_indices.clone()\n            )\n            self.attn_backends[i].init_forward_metadata(forward_batch)\n\n        self.common_template(forward_batch, kv_indices, call_fn)\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        self.cuda_graph_kv_indices = torch.zeros(\n            (self.speculative_num_steps, max_num_tokens * self.max_context_len),\n            dtype=torch.int32,\n            device=self.device,\n        )\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_cuda_graph_state(\n                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]\n            )\n\n    def init_forward_metadata_capture_cuda_graph(self, forward_batch: 
ForwardBatch):\n        def call_fn(i, forward_batch):\n            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(\n                forward_batch.batch_size,\n                forward_batch.batch_size * self.topk,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                encoder_lens=None,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n            )\n\n        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)\n\n    def init_forward_metadata_replay_cuda_graph(\n        self, forward_batch: ForwardBatch, bs: int\n    ):\n        def call_fn(i, forward_batch):\n            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(\n                bs,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                seq_lens_sum=-1,\n                encoder_lens=None,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n                seq_lens_cpu=None,\n            )\n\n        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/attention_registry.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nlogger = logging.getLogger(__name__)\n\n\nif TYPE_CHECKING:\n    # evade circular imports\n    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend\n    from sglang.srt.model_executor.model_runner import ModelRunner\n\nATTENTION_BACKENDS = {}\n\n\ndef register_attention_backend(name):\n    def decorator(fn):\n        ATTENTION_BACKENDS[name] = fn\n        return fn\n\n    return decorator\n\n\n@register_attention_backend(\"flashinfer\")\ndef create_flashinfer_backend(runner):\n    import torch\n\n    if not runner.use_mla_backend:\n        from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend\n\n        # Init streams\n        if runner.server_args.speculative_algorithm == \"EAGLE\":\n            if (\n                not hasattr(runner, \"plan_stream_for_flashinfer\")\n                or not runner.plan_stream_for_flashinfer\n            ):\n                runner.plan_stream_for_flashinfer = torch.cuda.Stream()\n        return FlashInferAttnBackend(\n            runner, init_new_workspace=runner.init_new_workspace\n        )\n    else:\n        from sglang.srt.layers.attention.flashinfer_mla_backend import (\n            FlashInferMLAAttnBackend,\n        )\n\n        return FlashInferMLAAttnBackend(runner)\n\n\n@register_attention_backend(\"trtllm_mla\")\ndef create_trtllm_mla_backend(runner):\n    if not runner.use_mla_backend:\n        raise ValueError(\"trtllm_mla backend can only be used with MLA models.\")\n    from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend\n\n    return TRTLLMMLABackend(runner)\n\n\n@register_attention_backend(\"aiter\")\ndef create_aiter_backend(runner):\n    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend\n\n    return AiterAttnBackend(runner)\n\n\n@register_attention_backend(\"wave\")\ndef create_wave_backend(runner):\n    from sglang.srt.layers.attention.wave_backend 
import WaveAttnBackend\n\n    return WaveAttnBackend(runner)\n\n\n@register_attention_backend(\"ascend\")\ndef create_ascend_backend(runner):\n    from sglang.srt.hardware_backend.npu.attention.ascend_backend import (\n        AscendAttnBackend,\n    )\n\n    return AscendAttnBackend(runner)\n\n\n@register_attention_backend(\"nsa\")\ndef create_nsa_backend(runner):\n    from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend\n\n    return NativeSparseAttnBackend(runner)\n\n\n@register_attention_backend(\"triton\")\ndef create_triton_backend(runner):\n    assert not runner.model_config.is_encoder_decoder, (\n        \"Cross attention is not supported in the triton attention backend. \"\n        \"Please use `--attention-backend flashinfer`.\"\n    )\n    if runner.server_args.enable_double_sparsity:\n        from sglang.srt.layers.attention.double_sparsity_backend import (\n            DoubleSparseAttnBackend,\n        )\n\n        return DoubleSparseAttnBackend(runner)\n    else:\n        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend\n\n        return TritonAttnBackend(runner)\n\n\n@register_attention_backend(\"torch_native\")\ndef create_torch_native_backend(runner):\n    from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend\n\n    return TorchNativeAttnBackend(runner)\n\n\n@register_attention_backend(\"flex_attention\")\ndef create_flex_attention_backend(runner):\n    from sglang.srt.layers.attention.torch_flex_backend import TorchFlexAttnBackend\n\n    return TorchFlexAttnBackend(runner)\n\n\n@register_attention_backend(\"flashmla\")\ndef create_flashmla_backend(runner):\n    from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend\n\n    return FlashMLABackend(runner)\n\n\n@register_attention_backend(\"fa3\")\ndef create_flashattention_v3_backend(runner):\n    import torch\n\n    assert (\n        torch.cuda.get_device_capability()[0] == 8 and not 
runner.use_mla_backend\n    ) or torch.cuda.get_device_capability()[0] == 9, (\n        \"FlashAttention v3 Backend requires SM>=80 and SM<=90. \"\n        \"Please use `--attention-backend flashinfer`.\"\n    )\n    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend\n\n    return FlashAttentionBackend(runner)\n\n\n@register_attention_backend(\"fa4\")\ndef create_flashattention_v4_backend(runner):\n    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend\n\n    return FlashAttentionBackend(runner, fa_impl_ver=4)\n\n\n@register_attention_backend(\"cutlass_mla\")\ndef create_cutlass_mla_backend(runner):\n    from sglang.srt.layers.attention.cutlass_mla_backend import CutlassMLABackend\n\n    return CutlassMLABackend(runner)\n\n\n@register_attention_backend(\"trtllm_mha\")\ndef create_trtllm_mha_backend(runner):\n    if runner.use_mla_backend:\n        raise ValueError(\"trtllm_mha backend can only be used with non-MLA models.\")\n    from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend\n\n    return TRTLLMHAAttnBackend(runner)\n\n\n@register_attention_backend(\"intel_amx\")\ndef create_intel_amx_backend(runner):\n    from sglang.srt.layers.attention.intel_amx_backend import IntelAMXAttnBackend\n\n    return IntelAMXAttnBackend(runner)\n\n\n@register_attention_backend(\"dual_chunk_flash_attn\")\ndef create_dual_chunk_flash_attn_backend(runner):\n    from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (\n        DualChunkFlashAttentionBackend,\n    )\n\n    return DualChunkFlashAttentionBackend(runner)\n\n\ndef attn_backend_wrapper(runner: \"ModelRunner\", full_attn_backend: \"AttentionBackend\"):\n    \"\"\"\n    Wrapper for special models like hybrid GDN, so we don't\n    need to change the code of the original attention backend.\n    \"\"\"\n    assert not (\n        runner.hybrid_gdn_config is not None and runner.use_mla_backend\n    ), \"hybrid_gdn can 
only be used with non-MLA models.\"\n\n    if cfg := runner.mambaish_config:\n        from sglang.srt.layers.attention.fla.utils import check_environments\n        from sglang.srt.layers.attention.hybrid_linear_attn_backend import (\n            HybridLinearAttnBackend,\n            Mamba2AttnBackend,\n        )\n        from sglang.srt.layers.attention.linear.gdn_backend import GDNAttnBackend\n        from sglang.srt.layers.attention.linear.kda_backend import KDAAttnBackend\n        from sglang.srt.layers.attention.linear.lightning_backend import (\n            LightningAttentionBackend,\n        )\n        from sglang.srt.layers.attention.linear.utils import (\n            initialize_linear_attn_config,\n        )\n        from sglang.srt.utils import is_blackwell, is_npu\n\n        check_environments()\n        initialize_linear_attn_config(runner.server_args)\n        if runner.hybrid_gdn_config is not None:\n            if is_blackwell():\n                assert (\n                    runner.server_args.attention_backend == \"triton\"\n                    or runner.server_args.attention_backend == \"trtllm_mha\"\n                    or runner.server_args.attention_backend == \"fa4\"\n                    or runner.server_args.attention_backend == \"flashinfer\"\n                ), \"triton, trtllm_mha, fa4, and flashinfer are the only supported backends on Blackwell GPUs for hybrid GDN models; use --attention-backend to specify the backend.\"\n            if is_npu():\n                assert (\n                    runner.server_args.attention_backend == \"ascend\"\n                ), \"ascend is the only supported backend on NPU for hybrid GDN models; use --attention-backend ascend to specify the backend.\"\n            logger.info(\"Using hybrid linear attention backend for hybrid GDN models.\")\n            linear_attn_backend = GDNAttnBackend(runner)\n        elif runner.mamba2_config is not None:\n            linear_attn_backend = 
Mamba2AttnBackend(runner)\n        elif runner.kimi_linear_config is not None:\n            linear_attn_backend = KDAAttnBackend(runner)\n        elif runner.hybrid_lightning_config is not None:\n            linear_attn_backend = LightningAttentionBackend(runner)\n        else:\n            raise ValueError(\n                \"Expected hybrid GDN or NemotronH models, but got unknown model.\"\n            )\n        full_attn_layers = cfg.full_attention_layer_ids\n        return HybridLinearAttnBackend(\n            full_attn_backend, linear_attn_backend, full_attn_layers\n        )\n\n    return full_attn_backend\n\n\n@register_attention_backend(\"intel_xpu\")\ndef create_intel_xpu_backend(runner):\n    from sglang.srt.layers.attention.xpu_backend import XPUAttentionBackend\n\n    return XPUAttentionBackend(runner)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/base_attn_backend.py",
    "content": "from __future__ import annotations\n\nfrom abc import ABC, abstractmethod\nfrom typing import TYPE_CHECKING, Optional\n\nimport torch\n\nfrom sglang.srt.utils.common import is_npu\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\n    from sglang.srt.speculative.spec_info import SpecInput\n\n\nclass AttentionBackend(ABC):\n    \"\"\"The base class of attention backends\"\"\"\n\n    @abstractmethod\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Init the metadata for a forward pass.\"\"\"\n        raise NotImplementedError()\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        \"\"\"Init the global shared states for cuda graph.\"\"\"\n        raise NotImplementedError()\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        \"\"\"Init the metadata for a forward pass for capturing a cuda graph.\"\"\"\n        raise NotImplementedError()\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        \"\"\"Init the metadata for a forward pass for replaying a cuda graph.\"\"\"\n        raise NotImplementedError()\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        \"\"\"Get the fill value for padded seq 
lens. Typically, it is 0 or 1.\"\"\"\n        raise NotImplementedError()\n\n    def get_verify_buffers_to_fill_after_draft(self):\n        \"\"\"\n        Return buffers of verify attention kernels that need to be filled after draft.\n\n        Typically, these are tree mask and position buffers.\n        \"\"\"\n        return [None, None]\n\n    def update_verify_buffers_to_fill_after_draft(\n        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]\n    ):\n        \"\"\"\n        Update the buffers returned by get_verify_buffers_to_fill_after_draft if needed.\n\n        Here, we need to redo the computation of all metadata of the attention backend\n        that depends on tree mask and position buffers.\n        \"\"\"\n        raise NotImplementedError()\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        **kwargs,\n    ):\n        \"\"\"Run forward on an attention layer.\"\"\"\n        if forward_batch.forward_mode.is_idle():\n            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)\n        elif forward_batch.forward_mode.is_decode():\n            return self.forward_decode(\n                q,\n                k,\n                v,\n                layer,\n                forward_batch,\n                save_kv_cache=save_kv_cache,\n                **kwargs,\n            )\n        elif forward_batch.forward_mode.is_mixed() and is_npu():\n            return self.forward_mixed(\n                q,\n                k,\n                v,\n                layer,\n                forward_batch,\n                save_kv_cache=save_kv_cache,\n                **kwargs,\n            )\n        else:\n            return self.forward_extend(\n                q,\n                k,\n                v,\n                layer,\n                forward_batch,\n     
           save_kv_cache=save_kv_cache,\n                **kwargs,\n            )\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n    ):\n        \"\"\"Run a forward for decode.\"\"\"\n        raise NotImplementedError()\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n    ):\n        \"\"\"Run a forward for extend.\"\"\"\n        raise NotImplementedError()\n\n    def forward_mixed(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n    ):\n        \"\"\"Run a forward for mixed mode.\"\"\"\n        raise NotImplementedError()\n\n    def support_triton(self):\n        \"\"\"Check if the current backend supports triton.\"\"\"\n        return True\n\n    def get_indexer_metadata(\n        self,\n        layer_id: int,\n        forward_batch: ForwardBatch,\n    ) -> Optional[BaseIndexerMetadata]:\n        \"\"\"Get the indexer metadata. Return None if the backend does not support an indexer.\"\"\"\n        return None\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/cutlass_mla_backend.py",
    "content": "from __future__ import annotations\n\n\"\"\"\nSupport attention backend for Cutlass MLA.\n\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Optional, Union\n\nimport torch\nimport triton\n\nfrom sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend\nfrom sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton\nfrom sglang.srt.layers.dp_attention import get_attention_tp_size\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.utils import is_cuda\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n    from sglang.srt.speculative.spec_info import SpecInput\n\n_is_cuda = is_cuda()\nif _is_cuda:\n    from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size\n\n\n# Cutlass MLA only supports pagesize=128\nPAGE_SIZE = 128\n\n\n@dataclass\nclass CutlassMLADecodeMetadata:\n    workspace: Optional[torch.Tensor] = None\n    block_kv_indices: Optional[torch.Tensor] = None\n\n    def __init__(\n        self,\n        workspace: Optional[torch.Tensor] = None,\n        block_kv_indices: Optional[torch.Tensor] = None,\n    ):\n        self.workspace = workspace\n        self.block_kv_indices = block_kv_indices\n\n\nclass CutlassMLABackend(FlashInferMLAAttnBackend):\n    \"\"\"Cutlass attention kernels.\"\"\"\n\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        skip_prefill: bool = False,\n        kv_indptr_buf: Optional[torch.Tensor] = None,\n        kv_last_page_len_buf: Optional[torch.Tensor] = None,\n    ):\n        super().__init__(\n            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf\n        )\n\n        self.num_q_heads = (\n            model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.num_kv_heads = 
model_runner.model_config.get_num_kv_heads(\n            get_attention_tp_size()\n        )\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.num_local_heads = (\n            model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.forward_metadata: Union[CutlassMLADecodeMetadata] = None\n        self.kv_lora_rank = model_runner.model_config.kv_lora_rank\n        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim\n        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim\n        self.v_head_dim = model_runner.model_config.v_head_dim\n        self.scaling = model_runner.model_config.scaling\n        self.data_type = model_runner.kv_cache_dtype\n        self.q_data_type = model_runner.dtype\n        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n\n        bs = forward_batch.batch_size\n        spec_info = forward_batch.spec_info\n        if forward_batch.forward_mode.is_decode_or_idle():\n            if spec_info is None:\n                max_seqlen_pad = triton.cdiv(\n                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE\n                )\n                block_kv_indices = torch.full(\n                    (bs, max_seqlen_pad),\n                    -1,\n                    dtype=torch.int32,\n                    device=forward_batch.seq_lens.device,\n                )\n                create_flashmla_kv_indices_triton[(bs,)](\n                    self.req_to_token,\n                    forward_batch.req_pool_indices,\n                    forward_batch.seq_lens,\n                    None,\n                    block_kv_indices,\n                    self.req_to_token.stride(0),\n                    max_seqlen_pad,\n                    PAGED_SIZE=PAGE_SIZE,\n                )\n                workspace_size = cutlass_mla_get_workspace_size(\n                  
  max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1\n                )\n                workspace = torch.empty(\n                    workspace_size, device=\"cuda\", dtype=torch.uint8\n                )\n                self.forward_metadata = CutlassMLADecodeMetadata(\n                    workspace,\n                    block_kv_indices,\n                )\n            else:\n                super().init_forward_metadata(forward_batch)\n        else:\n            super().init_forward_metadata(forward_batch)\n\n    def init_cuda_graph_state(\n        self,\n        max_bs: int,\n        max_num_tokens: int,\n        block_kv_indices: Optional[torch.Tensor] = None,\n    ):\n        if block_kv_indices is None:\n            cuda_graph_kv_indices = torch.full(\n                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),\n                1,\n                dtype=torch.int32,\n                device=\"cuda\",\n            )\n        else:\n            cuda_graph_kv_indices = block_kv_indices\n\n        workspace_size = cutlass_mla_get_workspace_size(\n            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs, num_kv_splits=1\n        )\n        self.cuda_graph_mla_workspace = torch.empty(\n            workspace_size, device=\"cuda\", dtype=torch.uint8\n        )\n        self.cuda_graph_kv_indices = cuda_graph_kv_indices\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        if forward_mode.is_decode_or_idle():\n            if spec_info is None:\n                max_seqlen_pad = self.cuda_graph_kv_indices.shape[1]\n\n                create_flashmla_kv_indices_triton[(bs,)](\n                    self.req_to_token,\n                    req_pool_indices,\n                    seq_lens,\n 
                   None,\n                    self.cuda_graph_kv_indices,\n                    self.req_to_token.stride(0),\n                    self.cuda_graph_kv_indices.stride(0),\n                    PAGED_SIZE=PAGE_SIZE,\n                )\n                self.forward_metadata = CutlassMLADecodeMetadata(\n                    self.cuda_graph_mla_workspace,\n                    self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],\n                )\n        else:\n            super().init_forward_metadata_capture_cuda_graph(\n                bs,\n                num_tokens,\n                req_pool_indices,\n                seq_lens,\n                encoder_lens,\n                forward_mode,\n                spec_info,\n            )\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n\n        if forward_mode.is_decode_or_idle():\n            assert seq_lens_cpu is not None\n            seq_lens = seq_lens[:bs]\n\n            create_flashmla_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices[:bs],\n                seq_lens,\n                None,\n                self.cuda_graph_kv_indices,\n                self.req_to_token.stride(0),\n                self.cuda_graph_kv_indices.stride(0),\n                PAGED_SIZE=PAGE_SIZE,\n            )\n        else:\n            super().init_forward_metadata_replay_cuda_graph(\n                bs,\n                req_pool_indices,\n                seq_lens,\n                seq_lens_sum,\n                encoder_lens,\n                forward_mode,\n                spec_info,\n                seq_lens_cpu,\n            )\n\n    def 
get_cuda_graph_seq_len_fill_value(self):\n        return 1\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        # For multi-head latent attention\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n    ):\n        cache_loc = forward_batch.out_cache_loc\n\n        if k is not None:\n            assert v is not None\n            if save_kv_cache:\n                if k_rope is not None:\n                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(\n                        layer,\n                        cache_loc,\n                        k,\n                        k_rope,\n                    )\n                else:\n                    forward_batch.token_to_kv_pool.set_kv_buffer(\n                        layer,\n                        cache_loc,\n                        k,\n                        v,\n                    )\n\n        # Reshape inputs\n        if q_rope is not None:\n            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n            q_rope = q_rope.view(\n                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim\n            )\n        else:\n            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)\n            q_nope = reshaped_q[:, :, : layer.v_head_dim]\n            q_rope = reshaped_q[:, :, layer.v_head_dim :]\n\n        q_nope = q_nope.to(self.q_data_type)\n        q_rope = q_rope.to(self.q_data_type)\n\n        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n\n        o = cutlass_mla_decode(\n            q_nope=q_nope,\n            q_pe=q_rope,\n            kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),\n            seq_lens=forward_batch.seq_lens.to(torch.int32),\n            
page_table=self.forward_metadata.block_kv_indices,\n            workspace=self.forward_metadata.workspace,\n            sm_scale=layer.scaling,\n            num_kv_splits=1,\n        )\n\n        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/double_sparsity_backend.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\nfrom sglang.srt.server_args import get_global_server_args\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n\n\nclass DoubleSparseAttnBackend(AttentionBackend):\n    def __init__(self, model_runner: ModelRunner):\n        # Lazy import to avoid the initialization of cuda context\n        from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (\n            extend_attention_fwd,\n            flash_decode_attention_fwd,\n            flash_decode_sparse_attention_fwd,\n        )\n\n        super().__init__()\n\n        self.decode_attention_fwd = flash_decode_attention_fwd\n        self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd\n        self.extend_attention_fwd = extend_attention_fwd\n        self.num_head = model_runner.model_config.num_attention_heads\n        self.head_dim = model_runner.model_config.hidden_size // self.num_head\n        self.heavy_token_num = model_runner.server_args.ds_heavy_token_num\n\n        self.sorted_channels = model_runner.sorted_channels\n        self.sparse_decode_threshold = (\n            model_runner.server_args.ds_sparse_decode_threshold\n        )\n        self.att_out_approx: torch.Tensor = None\n        self.mid_out: torch.Tensor = None\n        self.mid_o_logexpsum: torch.Tensor = None\n\n        # TODO: Change the hard-coded block_seq_num\n        self.BLOCK_SEQ = 128\n\n        if get_global_server_args().triton_attention_reduce_in_fp32:\n            self.reduce_dtype = torch.float32\n        else:\n            self.reduce_dtype = torch.float16\n\n        self.forward_metadata = None\n\n    def init_forward_metadata(self, 
forward_batch: ForwardBatch):\n        \"\"\"Init auxiliary variables for triton attention backend.\"\"\"\n\n        if forward_batch.forward_mode.is_decode():\n            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)\n            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)\n\n            total_num_tokens = torch.sum(forward_batch.seq_lens).item()\n            attn_logits = torch.empty(\n                (self.num_head, total_num_tokens),\n                dtype=self.reduce_dtype,\n                device=\"cuda\",\n            )\n\n            max_seq_len = torch.max(forward_batch.seq_lens).item()\n            min_seq_len = torch.min(forward_batch.seq_lens).item()\n            max_extend_len = None\n            # NOTE: Align sequence order with req_to_token order\n            ds_req_to_token = forward_batch.req_to_token_pool.req_to_token[\n                forward_batch.req_pool_indices\n            ]\n\n            bsz = forward_batch.seq_lens.shape[0]\n\n            att_out_approx = torch.empty(\n                [self.num_head, bsz, max_seq_len],\n                dtype=self.reduce_dtype,\n                device=\"cuda\",\n            )\n\n            block_seq_num = (\n                self.heavy_token_num + self.BLOCK_SEQ - 1\n            ) // self.BLOCK_SEQ\n\n            mid_out = torch.empty(\n                [bsz, self.num_head, block_seq_num, self.head_dim],\n                dtype=torch.float32,\n                device=\"cuda\",\n            )\n            mid_o_logexpsum = torch.empty(\n                [bsz, self.num_head, block_seq_num], dtype=torch.float32, device=\"cuda\"\n            )\n            self.att_out_approx = att_out_approx\n            self.mid_out = mid_out\n            self.mid_o_logexpsum = mid_o_logexpsum\n\n        else:\n            start_loc = attn_logits = max_seq_len = min_seq_len = None\n            prefix_lens = forward_batch.extend_prefix_lens\n            max_extend_len = 
torch.max(forward_batch.seq_lens - prefix_lens).item()\n            ds_req_to_token = None\n\n        self.forward_metadata = (\n            start_loc,\n            attn_logits,\n            max_seq_len,\n            min_seq_len,\n            max_extend_len,\n            ds_req_to_token,\n        )\n\n    def forward_extend(\n        self,\n        q,\n        k,\n        v,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n    ):\n        # TODO: reuse the buffer across layers\n        if layer.qk_head_dim != layer.v_head_dim:\n            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        k_label = torch.gather(\n            k,\n            2,\n            self.sorted_channels[layer.layer_id]\n            .unsqueeze(0)\n            .expand(k.shape[0], -1, -1),\n        )\n\n        if save_kv_cache:\n            forward_batch.token_to_kv_pool.set_kv_buffer(\n                layer, forward_batch.out_cache_loc, k, v, k_label\n            )\n\n        (\n            start_loc,\n            attn_logits,\n            max_seq_len,\n            min_seq_len,\n            max_extend_len,\n            ds_req_to_token,\n        ) = self.forward_metadata\n        self.extend_attention_fwd(\n            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n            k.contiguous(),\n            v.contiguous(),\n            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n            forward_batch.req_to_token_pool.req_to_token,\n            forward_batch.req_pool_indices,\n            forward_batch.seq_lens,\n            forward_batch.extend_seq_lens,\n            forward_batch.extend_start_loc,\n            max_extend_len,\n            layer.scaling,\n            layer.logit_cap,\n        
)\n        return o\n\n    def forward_decode(\n        self,\n        q,\n        k,\n        v,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n    ):\n        # During torch.compile, there is a bug in rotary_emb that causes the\n        # output value to have a 3D tensor shape. This reshapes the output correctly.\n        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)\n\n        # TODO: reuse the buffer across layers\n        if layer.qk_head_dim != layer.v_head_dim:\n            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        # TODO: Add min seqlen\n        (\n            start_loc,\n            attn_logits,\n            max_seq_len,\n            min_seq_len,\n            max_extend_len,\n            ds_req_to_token,\n        ) = self.forward_metadata\n\n        k_label = torch.gather(\n            k,\n            2,\n            self.sorted_channels[layer.layer_id]\n            .unsqueeze(0)\n            .expand(k.shape[0], -1, -1),\n        )\n\n        if save_kv_cache:\n            forward_batch.token_to_kv_pool.set_kv_buffer(\n                layer, forward_batch.out_cache_loc, k, v, k_label\n            )\n\n        # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num\n        #            and set a minimum value for sparse_decode\n        if (\n            min_seq_len < self.heavy_token_num\n            or max_seq_len < self.sparse_decode_threshold\n        ):\n            self.decode_attention_fwd(\n                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n                forward_batch.req_to_token_pool.req_to_token,\n                
forward_batch.req_pool_indices,\n                start_loc,\n                forward_batch.seq_lens,\n                attn_logits,\n                max_seq_len,\n                layer.scaling,\n                layer.logit_cap,\n            )\n        else:\n            # TODO(Andy): indexing with torch.gather or torch.index_select or customized kernel\n            q_label = torch.gather(\n                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n                2,\n                self.sorted_channels[layer.layer_id]\n                .unsqueeze(0)\n                .expand(q.shape[0], -1, -1),\n            )\n            self.decode_sparse_attention_fwd(\n                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n                q_label,\n                forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),\n                ds_req_to_token,\n                forward_batch.seq_lens,\n                max_seq_len,\n                layer.scaling,\n                layer.logit_cap,\n                self.heavy_token_num,\n                self.att_out_approx,\n                self.mid_out,\n                self.mid_o_logexpsum,\n                self.BLOCK_SEQ,\n            )\n\n        return o\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"Attention layer with Dual chunk flash attention and sparse attention.\"\"\"\n\nimport functools\nimport logging\nimport math\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache\nfrom sgl_kernel.sparse_flash_attn import (\n    convert_vertical_slash_indexes,\n    convert_vertical_slash_indexes_mergehead,\n    sparse_attn_func,\n)\n\nfrom sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.attention.flashattention_backend import FlashAttentionMetadata\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass DualChunkFlashAttentionMetadata:\n    \"\"\"Metadata for DualChunkFlashAttentionBackend.\n\n    NOTE: Any python object stored here is not updated when it is\n    cuda-graph replayed. If you have values that need to be changed\n    dynamically, they should be stored in tensors. The tensors have to be\n    updated from `CUDAGraphRunner.forward` API.\n    \"\"\"\n\n    # (batch_size,). The sequence length per sequence. Sequence length means\n    # the computed tokens + new tokens. None if it is a decoding.\n    seq_lens: Optional[List[int]] = None\n    # seq_lens stored as a tensor.\n    seq_lens_tensor: Optional[torch.Tensor] = None\n    # Maximum sequence length among prefill batch. 0 if there are decoding\n    # requests only.\n    max_seq_len: Optional[int] = None\n\n    # (batch_size,). 
The orig sequence length per sequence.\n    orig_seq_lens: Optional[List[int]] = None\n\n    # orig_seq_lens stored as a tensor.\n    orig_seq_lens_tensor: Optional[torch.Tensor] = None\n\n    # Block addresses per sequence. (Seq id -> list of physical block)\n    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks\n    # in the kv cache. Each block can contain up to block_size tokens.\n    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph\n    # captured.\n    block_tables: Optional[torch.Tensor] = None\n\n    # (batch_size + 1,). The cumulative subquery lengths of the sequences in\n    # the batch, used to index into subquery. E.g., if the subquery length\n    # is [4, 6], it is [0, 4, 10].\n    query_start_loc: Optional[torch.Tensor] = None\n    # (batch_size + 1,). The cumulative sequence lengths of the sequences in\n    # the batch, used to index into sequence. E.g., if the sequence length is\n    # [4, 6], it is [0, 4, 10].\n    seq_start_loc: Optional[torch.Tensor] = None\n\n    # Length scaling factor\n    scaling_factor: Optional[torch.Tensor] = None\n\n    # (batch_size,). Sequence lengths for intra attention.\n    seq_lens_intra: Optional[torch.Tensor] = None\n\n    # Max sequence length for intra attention.\n    max_seq_len_intra: Optional[int] = None\n\n    # (batch_size, num_blocks). Block table for intra attention.\n    block_tables_intra: Optional[torch.Tensor] = None\n\n    # (batch_size,). Sequence lengths for succ attention.\n    seq_lens_succ: Optional[torch.Tensor] = None\n\n    # Max sequence length for succ attention.\n    max_seq_len_succ: Optional[int] = None\n\n    # (batch_size, num_blocks). Block table for succ attention.\n    block_tables_succ: Optional[torch.Tensor] = None\n\n    # (batch_size,). 
Sequence lengths for inter attention.\n    seq_lens_inter: Optional[torch.Tensor] = None\n\n    # Max sequence length for inter attention.\n    max_seq_len_inter: Optional[int] = None\n\n\nclass DualChunkFlashAttentionBackend(AttentionBackend):\n    def __init__(\n        self,\n        model_runner: \"ModelRunner\",\n    ) -> None:\n        self.forward_metadata: FlashAttentionMetadata = None\n        self.device = model_runner.device\n        self.max_context_len = model_runner.model_config.context_len\n        self.num_heads = model_runner.model_config.get_num_attention_heads(\n            model_runner.server_args.tp_size\n        )\n        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(\n            model_runner.server_args.tp_size\n        )\n        self.head_size = model_runner.model_config.head_dim\n\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.kv_cache_dtype = model_runner.kv_cache_dtype\n        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype\n        self.page_size = model_runner.page_size\n\n        assert self.num_heads % self.num_kv_heads == 0\n        self.num_queries_per_kv = self.num_heads // self.num_kv_heads\n\n        dual_chunk_attention_config = getattr(\n            model_runner.model_config.hf_config, \"dual_chunk_attention_config\", None\n        )\n        assert dual_chunk_attention_config is not None\n        self.chunk_size = dual_chunk_attention_config.get(\"chunk_size\", 8192)\n        self.local_size = dual_chunk_attention_config.get(\"local_size\", 1024)\n        self.original_max_position_embeddings = dual_chunk_attention_config.get(\n            \"original_max_position_embeddings\", 0\n        )\n        self.sparse_attention_config = dual_chunk_attention_config.get(\n            \"sparse_attention_config\", None\n        )\n        if not self.sparse_attention_config:\n            logger.warning_once(\n                \"Sparse attention will not be 
enabled as \"\n                \"sparse attention config is not provided.\"\n            )\n        self.sparse_attention_enabled = dual_chunk_attention_config.get(\n            \"sparse_attention_enabled\", self.sparse_attention_config is not None\n        )\n        self.sparse_attention_threshold = dual_chunk_attention_config.get(\n            \"sparse_attention_threshold\", 32768\n        )\n        self.sparse_attention_last_q = dual_chunk_attention_config.get(\n            \"sparse_attention_last_q\", 64\n        )\n        self.dual_chunk_attention_config = dual_chunk_attention_config\n\n        if self.sparse_attention_enabled:\n            self.arange = torch.arange(self.sparse_attention_last_q, device=\"cuda\")\n            self.last_q_mask = (\n                self.arange[None, None, :, None] >= self.arange[None, None, None, :]\n            )\n\n    @functools.lru_cache()\n    def get_sparse_attention_config(self, layer_idx) -> List[Dict[str, Any]]:\n        layer_sparse_attention_config = {\n            int(i): j for i, j in self.sparse_attention_config[layer_idx].items()\n        }\n        start_head = self.num_heads * get_tensor_model_parallel_rank()\n        end_head = start_head + self.num_heads\n        return [layer_sparse_attention_config[i] for i in range(start_head, end_head)]\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Initialize forward metadata so that all layers in the forward pass can reuse it.\"\"\"\n\n        forward_mode: ForwardMode = forward_batch.forward_mode\n        assert forward_mode.is_prefill() or forward_mode.is_decode()\n        batch_size = forward_batch.batch_size\n\n        metadata = DualChunkFlashAttentionMetadata()\n        metadata.seq_lens_tensor = forward_batch.seq_lens.to(torch.int32)\n        metadata.seq_lens = forward_batch.seq_lens.tolist()\n        metadata.max_seq_len = forward_batch.seq_lens.max().item()\n\n        metadata.orig_seq_lens_tensor = 
forward_batch.orig_seq_lens\n        metadata.orig_seq_lens = forward_batch.orig_seq_lens.tolist()\n\n        metadata.block_tables = forward_batch.req_to_token_pool.req_to_token[\n            forward_batch.req_pool_indices, : metadata.max_seq_len\n        ]\n        # Convert the block table to a strided format.\n        if self.page_size > 1:\n            strided_indices = torch.arange(\n                0, metadata.block_tables.shape[1], self.page_size, device=self.device\n            )\n            metadata.block_tables = (\n                metadata.block_tables[:, strided_indices] // self.page_size\n            )\n\n        metadata.query_start_loc = torch.zeros(\n            batch_size + 1, dtype=torch.int32, device=metadata.seq_lens_tensor.device\n        )\n        if forward_mode.is_prefill():\n            metadata.query_start_loc[1:] = torch.cumsum(\n                forward_batch.extend_seq_lens.to(torch.int32), dim=0, dtype=torch.int32\n            )\n        else:\n            metadata.query_start_loc[1:] = torch.cumsum(\n                torch.arange(\n                    batch_size,\n                    dtype=metadata.query_start_loc.dtype,\n                    device=metadata.query_start_loc.device,\n                ),\n                dim=0,\n                dtype=torch.int32,\n            )\n        metadata.seq_start_loc = torch.zeros(\n            batch_size + 1, dtype=torch.int32, device=metadata.seq_lens_tensor.device\n        )\n        metadata.seq_start_loc[1:] = torch.cumsum(\n            metadata.seq_lens_tensor, dim=0, dtype=torch.int32\n        )\n\n        if self.original_max_position_embeddings > 0:\n            if forward_mode.is_prefill():\n                metadata.scaling_factor = (\n                    0.1\n                    * torch.log(\n                        metadata.orig_seq_lens_tensor\n                        / self.original_max_position_embeddings\n                    )\n                    + 1.0\n                
).clip(min=1)\n            else:\n                metadata.scaling_factor = (\n                    0.1\n                    * torch.log(\n                        metadata.orig_seq_lens_tensor\n                        / self.original_max_position_embeddings\n                    )\n                    + 1.0\n                ).clip(min=1)\n\n        if forward_mode.is_decode():\n            cache_seq_lens = metadata.orig_seq_lens_tensor\n\n            chunk_len = self.chunk_size - self.local_size\n            chunk_num_curr = (cache_seq_lens - 1) // chunk_len\n\n            seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len\n            max_seq_len_intra = seq_lens_intra.max().item()\n            metadata.seq_lens_intra = seq_lens_intra\n            metadata.max_seq_len_intra = max_seq_len_intra\n\n            block_tables_intra = torch.zeros(\n                batch_size,\n                (max_seq_len_intra - 1) // self.page_size + 1,\n                dtype=metadata.block_tables.dtype,\n                device=metadata.block_tables.device,\n            )\n            for i in range(batch_size):\n                st = chunk_num_curr[i] * chunk_len // self.page_size\n                ed = min(\n                    st + (max_seq_len_intra - 1) // self.page_size + 1,\n                    (cache_seq_lens[i] - 1) // self.page_size + 1,\n                )\n                block_tables_intra[i, : ed - st] = metadata.block_tables[i, st:ed]\n            metadata.block_tables_intra = block_tables_intra\n\n            metadata.seq_lens_succ = (\n                chunk_num_curr - (chunk_num_curr - 1).clip(min=0)\n            ) * chunk_len\n            metadata.max_seq_len_succ = metadata.seq_lens_succ.max().item()\n            if metadata.max_seq_len_succ:\n                block_tables_succ = torch.zeros(\n                    batch_size,\n                    (metadata.max_seq_len_succ - 1) // self.page_size + 1,\n                    dtype=metadata.block_tables.dtype,\n       
             device=metadata.block_tables.device,\n                )\n                for i in range(batch_size):\n                    start = (\n                        (chunk_num_curr[i] - 1).clip(min=0)\n                        * chunk_len\n                        // self.page_size\n                    )\n                    end = min(\n                        start + (metadata.max_seq_len_succ - 1) // self.page_size + 1,\n                        (cache_seq_lens[i] - 1) // self.page_size + 1,\n                    )\n                    block_tables_succ[i, : end - start] = metadata.block_tables[\n                        i, start:end\n                    ]\n                metadata.block_tables_succ = block_tables_succ\n\n            metadata.seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len\n            metadata.max_seq_len_inter = metadata.seq_lens_inter.max().item()\n\n        self.forward_metadata = metadata\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: \"RadixAttention\",\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n    ):\n        # Use precomputed metadata across all layers\n        metadata = self.forward_metadata\n\n        (\n            query,\n            query_succ,\n            query_inter,\n            query_succ_critical,\n            query_inter_critical,\n        ) = torch.split(q, q.shape[-1] // 5, dim=-1)\n\n        # Reshape the query, key, and value tensors.\n        query = query.view(-1, self.num_heads, self.head_size)\n        query_succ = query_succ.view(-1, self.num_heads, self.head_size)\n        query_inter = query_inter.view(-1, self.num_heads, self.head_size)\n        query_succ_critical = query_succ_critical.view(\n            -1, self.num_heads, self.head_size\n        )\n        query_inter_critical = query_inter_critical.view(\n            -1, self.num_heads, self.head_size\n        )\n        key = 
k.view(-1, self.num_kv_heads, self.head_size)\n        value = v.view(-1, self.num_kv_heads, self.head_size)\n\n        # apply DCA scaling\n        if self.original_max_position_embeddings > 0:\n            assert metadata.scaling_factor is not None\n            assert metadata.query_start_loc is not None\n            assert metadata.orig_seq_lens is not None\n            current_start = 0\n            query_start_loc_cpu = metadata.query_start_loc.cpu()\n            for i in range(len(metadata.orig_seq_lens)):\n                current_end = (\n                    current_start\n                    + (query_start_loc_cpu[i + 1] - query_start_loc_cpu[i]).item()\n                )\n                key[current_start:current_end].mul_(metadata.scaling_factor[i])\n                current_start = current_end\n            assert current_end <= self.max_context_len\n\n        # Do multi-head attention\n        key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(\n            layer.layer_id\n        )\n        key_cache = key_cache.view(\n            -1, self.page_size, layer.tp_k_head_num, layer.head_dim\n        )\n        value_cache = value_cache.view(\n            -1, self.page_size, layer.tp_v_head_num, layer.head_dim\n        )\n\n        if key is not None and value is not None:\n            if save_kv_cache:\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer,\n                    forward_batch.out_cache_loc,\n                    key,\n                    value,\n                    layer.k_scale,\n                    layer.v_scale,\n                )\n\n        if not save_kv_cache:\n            # profile run\n            o = flash_attn_varlen_func(\n                q=query,\n                k=key,\n                v=value,\n                cu_seqlens_q=metadata.seq_start_loc,\n                cu_seqlens_k=metadata.seq_start_loc,\n                max_seqlen_q=metadata.max_seq_len,\n                
max_seqlen_k=metadata.max_seq_len,\n                softmax_scale=layer.scaling,\n                causal=True,\n            )\n        else:\n            # prefill/chunked-prefill\n            # get per layer sparse attention config\n            if self.sparse_attention_enabled:\n                self.layer_sparse_attention_config = self.get_sparse_attention_config(\n                    layer.layer_id\n                )\n            assert metadata.orig_seq_lens is not None\n            o = self._dual_chunk_flash_attn_prefill(\n                q=query,\n                q_succ=query_succ,\n                q_inter=query_inter,\n                q_succ_critical=query_succ_critical,\n                q_inter_critical=query_inter_critical,\n                k=key_cache,\n                v=value_cache,\n                cu_seqlens_q=metadata.query_start_loc,\n                cu_seqlens_k=metadata.seq_start_loc,\n                orig_seq_lens=metadata.orig_seq_lens,\n                scaling_factor=metadata.scaling_factor,\n                softmax_scale=layer.scaling,\n                causal=True,\n                window_size=(-1, -1),\n                block_table=metadata.block_tables,\n                chunk_size=self.chunk_size,\n                local_size=self.local_size,\n            )\n        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: \"RadixAttention\",\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n    ) -> torch.Tensor:\n        # Use precomputed metadata across all layers\n        metadata = self.forward_metadata\n\n        (\n            query,\n            query_succ,\n            query_inter,\n            query_succ_critical,\n            query_inter_critical,\n        ) = torch.split(q, q.shape[-1] // 5, dim=-1)\n\n        # Reshape the query, key, and value tensors.\n        query = 
query.view(-1, self.num_heads, self.head_size)\n        query_succ = query_succ.view(-1, self.num_heads, self.head_size)\n        query_inter = query_inter.view(-1, self.num_heads, self.head_size)\n        query_succ_critical = query_succ_critical.view(\n            -1, self.num_heads, self.head_size\n        )\n        query_inter_critical = query_inter_critical.view(\n            -1, self.num_heads, self.head_size\n        )\n        key = k.view(-1, self.num_kv_heads, self.head_size)\n        value = v.view(-1, self.num_kv_heads, self.head_size)\n\n        key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(\n            layer.layer_id\n        )\n        key_cache = key_cache.view(\n            -1, self.page_size, layer.tp_k_head_num, layer.head_dim\n        )\n        value_cache = value_cache.view(\n            -1, self.page_size, layer.tp_v_head_num, layer.head_dim\n        )\n\n        if key is not None and value is not None:\n            if save_kv_cache:\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer,\n                    forward_batch.out_cache_loc,\n                    key,\n                    value,\n                    layer.k_scale,\n                    layer.v_scale,\n                )\n\n        # apply DCA scaling\n        if self.original_max_position_embeddings > 0:\n            assert metadata.scaling_factor is not None\n            scaling_factor = metadata.scaling_factor\n            key.mul_(scaling_factor.unsqueeze(-1).unsqueeze(-1))\n\n        o = self._dual_chunk_flash_attn_decoding(\n            query.unsqueeze(1),\n            query_succ.unsqueeze(1),\n            query_inter.unsqueeze(1),\n            key_cache,\n            value_cache,\n            block_table=metadata.block_tables,\n            cache_seqlens=metadata.seq_lens_tensor,\n            softmax_scale=layer.scaling,\n            causal=True,\n            chunk_size=self.chunk_size,\n            
local_size=self.local_size,\n            original_max_position_embeddings=self.original_max_position_embeddings,\n            decode_meta=metadata,\n        ).squeeze(1)\n        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        \"\"\"Initialize CUDA graph state for the attention backend.\n\n        Args:\n            max_bs (int): Maximum batch size to support in CUDA graphs\n\n        This creates fixed-size tensors that will be reused during CUDA graph replay\n        to avoid memory allocations.\n        \"\"\"\n        self.decode_metadata = {\n            \"seq_lens_tensor\": torch.zeros(\n                max_bs, dtype=torch.int32, device=self.device\n            ),\n            \"orig_seq_lens_tensor\": torch.zeros(\n                max_bs, dtype=torch.int32, device=self.device\n            ),\n            \"scaling_factor\": torch.zeros(\n                max_bs, dtype=torch.float32, device=self.device\n            ),\n            \"block_tables\": torch.zeros(\n                max_bs,\n                (self.max_context_len - 1) // self.page_size + 1,\n                dtype=torch.int32,\n                device=self.device,\n            ),\n            \"block_tables_intra\": torch.zeros(\n                max_bs,\n                (self.max_context_len - 1) // self.page_size + 1,\n                dtype=torch.int32,\n                device=self.device,\n            ),\n            \"seq_lens_intra\": torch.zeros(\n                max_bs, dtype=torch.int32, device=self.device\n            ),\n            \"block_tables_succ\": torch.zeros(\n                max_bs,\n                (self.max_context_len - 1) // self.page_size + 1,\n                dtype=torch.int32,\n                device=self.device,\n            ),\n            \"seq_lens_succ\": torch.zeros(max_bs, dtype=torch.int32, device=self.device),\n            \"seq_lens_inter\": torch.zeros(\n                
max_bs, dtype=torch.int32, device=self.device\n            ),\n        }\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[None],\n    ):\n        metadata = DualChunkFlashAttentionMetadata()\n\n        if forward_mode.is_decode_or_idle():\n            if self.original_max_position_embeddings > 0:\n                metadata.scaling_factor = self.decode_metadata[\"scaling_factor\"][:bs]\n\n            metadata.seq_lens_tensor = self.decode_metadata[\"seq_lens_tensor\"][:bs]\n            metadata.orig_seq_lens_tensor = self.decode_metadata[\n                \"orig_seq_lens_tensor\"\n            ][:bs]\n            metadata.max_seq_len = self.max_context_len\n            metadata.block_tables = self.decode_metadata[\"block_tables\"][\n                req_pool_indices, :\n            ]\n\n            # intra\n            metadata.max_seq_len_intra = self.max_context_len\n            metadata.seq_lens_intra = self.decode_metadata[\"seq_lens_intra\"][:bs]\n\n            metadata.block_tables_intra = self.decode_metadata[\"block_tables_intra\"][\n                :bs, :\n            ]\n\n            # succ\n            metadata.seq_lens_succ = self.decode_metadata[\"seq_lens_succ\"][:bs]\n            metadata.max_seq_len_succ = self.max_context_len\n\n            metadata.block_tables_succ = self.decode_metadata[\"block_tables_succ\"][\n                :bs, :\n            ]\n\n            metadata.seq_lens_inter = self.decode_metadata[\"seq_lens_inter\"][:bs]\n            metadata.max_seq_len_inter = self.max_context_len\n\n            self.decode_metadata[bs] = metadata\n\n        self.forward_metadata = metadata\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: 
torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[None],\n        seq_lens_cpu: Optional[torch.Tensor],\n        out_cache_loc: torch.Tensor = None,\n    ):\n        \"\"\"Initialize forward metadata for replaying CUDA graph.\"\"\"\n        assert forward_mode.is_decode()\n        seq_lens = seq_lens[:bs]\n        req_pool_indices = req_pool_indices[:bs]\n        metadata = self.decode_metadata[bs]\n\n        metadata.seq_lens_tensor.copy_(seq_lens.to(torch.int32))\n        metadata.seq_lens = seq_lens.tolist()\n        metadata.max_seq_len = seq_lens.max().item()\n\n        metadata.orig_seq_lens_tensor.copy_(seq_lens)\n        metadata.orig_seq_lens = seq_lens.tolist()\n\n        block_tables = self.req_to_token[req_pool_indices, : metadata.max_seq_len]\n        # Convert the block table to a strided format.\n        if self.page_size > 1:\n            strided_indices = torch.arange(\n                0, block_tables.shape[1], self.page_size, device=self.device\n            )\n            block_tables = block_tables[:, strided_indices] // self.page_size\n        metadata.block_tables.fill_(0)\n        metadata.block_tables[: block_tables.shape[0], : block_tables.shape[1]].copy_(\n            block_tables\n        )\n\n        if self.original_max_position_embeddings > 0:\n            scaling_factor = (\n                0.1\n                * torch.log(\n                    metadata.orig_seq_lens_tensor\n                    / self.original_max_position_embeddings\n                )\n                + 1.0\n            ).clip(min=1)\n            metadata.scaling_factor.copy_(scaling_factor)\n\n        cache_seq_lens = metadata.orig_seq_lens_tensor\n\n        chunk_len = self.chunk_size - self.local_size\n        chunk_num_curr = (cache_seq_lens - 1) // chunk_len\n\n        seq_lens_intra = cache_seq_lens - chunk_num_curr * 
chunk_len\n        max_seq_len_intra = seq_lens_intra.max().item()\n        metadata.seq_lens_intra.copy_(seq_lens_intra)\n        metadata.max_seq_len_intra = max_seq_len_intra\n\n        metadata.block_tables_intra.fill_(0)\n        for i in range(bs):\n            st = chunk_num_curr[i] * chunk_len // self.page_size\n            ed = min(\n                st + (max_seq_len_intra - 1) // self.page_size + 1,\n                (cache_seq_lens[i] - 1) // self.page_size + 1,\n            )\n            metadata.block_tables_intra[i, : ed - st] = metadata.block_tables[i, st:ed]\n\n        seq_lens_succ = (chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len\n        metadata.seq_lens_succ.copy_(seq_lens_succ)\n        metadata.max_seq_len_succ = metadata.seq_lens_succ.max().item()\n        if metadata.max_seq_len_succ:\n            metadata.block_tables_succ.fill_(0)\n            for i in range(bs):\n                start = (\n                    (chunk_num_curr[i] - 1).clip(min=0) * chunk_len // self.page_size\n                )\n                end = min(\n                    start + (metadata.max_seq_len_succ - 1) // self.page_size + 1,\n                    (cache_seq_lens[i] - 1) // self.page_size + 1,\n                )\n                metadata.block_tables_succ[i, : end - start] = metadata.block_tables[\n                    i, start:end\n                ]\n\n        seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len\n        metadata.seq_lens_inter.copy_(seq_lens_inter)\n        metadata.max_seq_len_inter = metadata.seq_lens_inter.max().item()\n\n        self.forward_metadata = metadata\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        \"\"\"Get the fill value for sequence length in CUDA graph.\"\"\"\n        return 1\n\n    def _dual_chunk_flash_attn_prefill(\n        self,\n        q,\n        q_succ,\n        q_inter,\n        q_succ_critical,\n        q_inter_critical,\n        k,\n        v,\n        cu_seqlens_q,\n        
cu_seqlens_k,\n        orig_seq_lens: List[int],\n        scaling_factor: torch.Tensor,\n        softmax_scale: float,\n        causal: Optional[bool] = True,\n        window_size: Tuple[int, int] = (-1, -1),\n        block_table: Optional[torch.Tensor] = None,\n        chunk_size: int = 8192,\n        local_size: int = 1024,\n    ):\n        if not causal:\n            raise ValueError(\"Dual Chunk Attention does not support causal=False\")\n        if window_size != (-1, -1):\n            raise ValueError(\"Dual Chunk Attention does not support window_size\")\n\n        cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist()\n        cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist()\n        all_outputs = []\n\n        for i in range(0, len(cu_seqlens_q_cpu) - 1):\n            qs = cu_seqlens_q_cpu[i]\n            qe = cu_seqlens_q_cpu[i : i + 2][-1]\n            ks = cu_seqlens_k_cpu[i]\n            ke = cu_seqlens_k_cpu[i : i + 2][-1]\n\n            current_q = q[qs:qe]\n            current_q_succ = q_succ[qs:qe]\n            current_q_inter = q_inter[qs:qe]\n            current_q_succ_critical = q_succ_critical[qs:qe]\n            current_q_inter_critical = q_inter_critical[qs:qe]\n\n            if block_table is None:\n                current_k = k[ks:ke]\n                current_v = v[ks:ke]\n                current_block_table = None\n                current_orig_seq_len = orig_seq_lens[i]\n            else:\n                current_block_table = block_table[i]\n                current_orig_seq_len = orig_seq_lens[i]\n                current_k = k\n                current_v = v\n            sparse_attn_enabled = (\n                self.sparse_attention_enabled\n                and current_orig_seq_len > self.sparse_attention_threshold\n            )\n\n            if current_q.shape[0] == 0:\n                continue\n\n            if current_k.shape[0] == 0:\n                all_outputs.append(\n                    torch.zeros(\n                        
(current_q.shape[0], current_q.shape[1], v.shape[2]),\n                        device=q.device,\n                        dtype=q.dtype,\n                    )\n                )\n                continue\n\n            current_output = torch.empty_like(current_q)\n            group_size = int(current_q.size(-2) / current_k.size(-2))\n\n            if sparse_attn_enabled:\n                num_device_q_heads = current_q.size(-2)\n                heads_vertical_size = torch.empty(\n                    size=(num_device_q_heads,), dtype=torch.int32\n                )\n                heads_slash_size = torch.empty(\n                    size=(num_device_q_heads,), dtype=torch.int32\n                )\n                for head_id in range(current_q.size(-2)):\n                    (\n                        ty,\n                        vertical_size,\n                        slash_size,\n                        _,\n                    ) = self.layer_sparse_attention_config[head_id]\n                    assert ty == \"vertical_and_slash\", \"expected vertical_and_slash\"\n\n                    if vertical_size == 30:\n                        vertical_size += 100\n                    heads_vertical_size[head_id] = vertical_size\n                    heads_slash_size[head_id] = slash_size\n\n                current_output = self._dual_chunk_flash_attn_prefill_func(\n                    current_q,  # all heads\n                    current_q_succ,\n                    current_q_inter,\n                    current_q_succ_critical,\n                    current_q_inter_critical,\n                    current_k,\n                    current_v,\n                    current_block_table,\n                    softmax_scale,\n                    chunk_size,\n                    local_size,\n                    scaling_factor[i].item(),\n                    ke - ks,\n                    sparse_attn_enabled=sparse_attn_enabled,\n                    heads_vertical_size=heads_vertical_size,\n      
              heads_slash_size=heads_slash_size,\n                    group_size=group_size,\n                )\n            else:\n                for head_id in range(current_q.size(-2)):\n                    # (seq_len, num_heads, head_size)\n                    current_q_head = current_q[:, head_id, :].unsqueeze(1)\n                    current_q_succ_head = current_q_succ[:, head_id, :].unsqueeze(1)\n                    current_q_inter_head = current_q_inter[:, head_id, :].unsqueeze(1)\n                    current_q_succ_head_critical = current_q_succ_critical[\n                        :, head_id, :\n                    ].unsqueeze(1)\n                    current_q_inter_head_critical = current_q_inter_critical[\n                        :, head_id, :\n                    ].unsqueeze(1)\n                    if block_table is not None:\n                        current_k_head = current_k[\n                            ..., head_id // group_size, :\n                        ].unsqueeze(2)\n                        current_v_head = current_v[\n                            ..., head_id // group_size, :\n                        ].unsqueeze(2)\n\n                    else:\n                        current_k_head = current_k[:, head_id, :].unsqueeze(1)\n                        current_v_head = current_v[:, head_id, :].unsqueeze(1)\n\n                    current_out = self._dual_chunk_flash_attn_prefill_func(\n                        current_q_head,\n                        current_q_succ_head,\n                        current_q_inter_head,\n                        current_q_succ_head_critical,\n                        current_q_inter_head_critical,\n                        current_k_head,\n                        current_v_head,\n                        current_block_table,\n                        softmax_scale,\n                        chunk_size,\n                        local_size,\n                        scaling_factor[i].item(),\n                        ke - ks,\n     
                   sparse_attn_enabled=sparse_attn_enabled,\n                    )\n                    current_output[:, head_id : head_id + 1, :] = current_out\n            all_outputs.append(current_output)\n        return torch.cat(all_outputs, dim=0)\n\n    def _dual_chunk_flash_attn_prefill_func(\n        self,\n        q,\n        q_succ,\n        q_inter,\n        q_succ_critical,\n        q_inter_critical,\n        k,\n        v,\n        block_table,\n        softmax_scale: float,\n        chunk_size: int,\n        local_size: int,\n        scaling_factor: float,\n        k_length: int,\n        sparse_attn_enabled: Optional[bool] = True,\n        heads_vertical_size=None,\n        heads_slash_size=None,\n        group_size=None,\n    ):\n        flash_results = []\n        chunk_len = chunk_size - local_size\n\n        if block_table is not None:\n            block_size = v.shape[1]\n            if chunk_len % block_size != 0:\n                raise ValueError(\"chunk_len must be divisible by block_size.\")\n        else:\n            block_size = 1\n\n        if self.original_max_position_embeddings > 0:\n            softmax_scale = softmax_scale * scaling_factor\n\n        begin = k_length - q.shape[0]\n        while begin < k_length:\n            flash_per_chunk = []\n\n            prev_chunk_end_pos = (begin // chunk_len) * chunk_len\n            next_chunk_end_pos = prev_chunk_end_pos + chunk_len\n            end = min(next_chunk_end_pos, k_length)\n            qbegin = begin - (k_length - q.shape[0])\n            qend = end - (k_length - q.shape[0])\n\n            qk_chunks = []\n            q_states_intra = q[qbegin:qend]\n            # choose critical token\n            if block_table is not None:\n                block_tables_intra = _get_block(\n                    block_table, block_size, prev_chunk_end_pos, end\n                )\n                k_states_intra = k[block_tables_intra].view(-1, *k.shape[-2:])[\n                    : (end - 
prev_chunk_end_pos)\n                ]\n                v_states_intra = v[block_tables_intra].view(-1, *v.shape[-2:])[\n                    : (end - prev_chunk_end_pos)\n                ]\n            else:\n                block_tables_intra = None\n                k_states_intra = k[prev_chunk_end_pos:end]\n                v_states_intra = v[prev_chunk_end_pos:end]\n\n            if sparse_attn_enabled:\n                last_q_size = min(qend - qbegin, self.sparse_attention_last_q)\n                _, num_device_k_heads, head_dim = k_states_intra.shape\n                k_states_intra = (\n                    k_states_intra.unsqueeze(2)\n                    .repeat(1, 1, group_size, 1)\n                    .reshape(-1, num_device_k_heads * group_size, head_dim)\n                )\n                v_states_intra = (\n                    v_states_intra.unsqueeze(2)\n                    .repeat(1, 1, group_size, 1)\n                    .reshape(-1, num_device_k_heads * group_size, head_dim)\n                )\n                qk_chunks.append(\n                    (q_states_intra.transpose(0, 1)[:, -last_q_size:] * softmax_scale)\n                    @ k_states_intra.permute(1, 2, 0)\n                )\n\n            if prev_chunk_end_pos - chunk_len >= 0:\n                q_states_succ = q_succ[qbegin:qend]\n                q_states_succ_critical = q_succ_critical[qbegin:qend]\n                if block_table is not None:\n                    block_tables_succ = _get_block(\n                        block_table,\n                        block_size,\n                        prev_chunk_end_pos - chunk_len,\n                        prev_chunk_end_pos,\n                    )\n                    k_states_succ = k[block_tables_succ].view(-1, *k.shape[-2:])[\n                        :chunk_len\n                    ]\n                    v_states_succ = v[block_tables_succ].view(-1, *v.shape[-2:])[\n                        :chunk_len\n                    ]\n                
else:\n                    k_states_succ = k[\n                        prev_chunk_end_pos - chunk_len : prev_chunk_end_pos\n                    ]\n                    v_states_succ = v[\n                        prev_chunk_end_pos - chunk_len : prev_chunk_end_pos\n                    ]\n\n                if sparse_attn_enabled:\n                    k_states_succ = (\n                        k_states_succ.unsqueeze(2)\n                        .repeat(1, 1, group_size, 1)\n                        .reshape(-1, num_device_k_heads * group_size, head_dim)\n                    )\n                    v_states_succ = (\n                        v_states_succ.unsqueeze(2)\n                        .repeat(1, 1, group_size, 1)\n                        .reshape(-1, num_device_k_heads * group_size, head_dim)\n                    )\n                    qk_chunks.append(\n                        (\n                            q_states_succ_critical.transpose(0, 1)[:, -last_q_size:]\n                            * softmax_scale\n                        )\n                        @ k_states_succ.permute(1, 2, 0)\n                    )\n\n            if prev_chunk_end_pos - chunk_len * 2 >= 0:\n                q_states_inter = q_inter[qbegin:qend]\n                q_states_inter_critical = q_inter_critical[qbegin:qend]\n                if block_table is not None:\n                    block_tables_inter = _get_block(\n                        block_table, block_size, 0, prev_chunk_end_pos - chunk_len\n                    )\n                    k_states_inter = k[block_tables_inter].view(-1, *k.shape[-2:])[\n                        : (prev_chunk_end_pos - chunk_len)\n                    ]\n                    v_states_inter = v[block_tables_inter].view(-1, *v.shape[-2:])[\n                        : (prev_chunk_end_pos - chunk_len)\n                    ]\n                else:\n                    k_states_inter = k[: prev_chunk_end_pos - chunk_len]\n                    v_states_inter = v[: 
prev_chunk_end_pos - chunk_len]\n\n                if sparse_attn_enabled:\n                    k_states_inter = (\n                        k_states_inter.unsqueeze(2)\n                        .repeat(1, 1, group_size, 1)\n                        .reshape(-1, num_device_k_heads * group_size, head_dim)\n                    )\n                    v_states_inter = (\n                        v_states_inter.unsqueeze(2)\n                        .repeat(1, 1, group_size, 1)\n                        .reshape(-1, num_device_k_heads * group_size, head_dim)\n                    )\n                    qk_chunks.append(\n                        (\n                            q_states_inter_critical.transpose(0, 1)[:, -last_q_size:]\n                            * softmax_scale\n                        )\n                        @ k_states_inter.permute(1, 2, 0)\n                    )\n\n            if sparse_attn_enabled:\n                reversed_qk = qk_chunks[::-1]\n                qk = torch.cat(reversed_qk, dim=-1)\n\n                qk[:, :, -last_q_size:] = torch.where(\n                    self.last_q_mask[..., -last_q_size:, -last_q_size:].to(qk.device),\n                    qk[:, :, -last_q_size:],\n                    -torch.inf,\n                )\n                qk = F.softmax(qk, dim=-1, dtype=torch.float32)\n\n                vertical = qk.sum(-2, keepdim=True)\n                vertical[..., :30] = torch.inf\n\n                # Avoid sorting by using the min/max ints to fill the indexer\n                # buffers.\n                int32_max = torch.iinfo(torch.int32).max\n                int32_min = torch.iinfo(torch.int32).min\n                n_heads = qk.size()[0]\n                max_slash_topk = torch.max(heads_slash_size).item()\n                max_vertical_topk = torch.max(heads_vertical_size).item()\n                # store each head's slash topk, vertical topk\n                vertical = vertical.reshape((n_heads, -1))\n                # prevent out 
of range when prompt size < max_vertical_topk\n                max_vertical_topk = min(vertical.shape[-1], max_vertical_topk)\n                vertical_topk_buffer = torch.topk(\n                    vertical, max_vertical_topk, -1\n                ).indices\n                slash_topk_buffer = torch.empty(\n                    size=(n_heads, max_slash_topk), dtype=torch.int64, device=qk.device\n                )\n                for head_i in range(n_heads):\n                    #  (nqheads=1, lastq, k_len)\n                    head_score = qk[head_i : head_i + 1, :, :]\n                    slash_scores = _sum_all_diagonal_matrix(head_score)\n                    if head_score.size(1) != 1:\n                        # drop right up corner\n                        slash_scores = slash_scores[..., : -last_q_size + 1]\n                    slash_scores[..., -100:] = torch.inf\n\n                    head_slash_size = heads_slash_size[head_i]\n                    head_slash_size = min(head_slash_size, vertical.size(-1))\n                    slash_topk = torch.topk(slash_scores, head_slash_size, -1).indices\n                    # （nheads, max_topk）\n                    slash_topk_buffer[head_i, :head_slash_size] = slash_topk\n\n                    # reset heads topk\n                    heads_slash_size[head_i] = head_slash_size\n                    heads_vertical_size[head_i] = min(\n                        heads_vertical_size[head_i], max_vertical_topk\n                    )\n\n                # store\n                vertical_buffer = torch.full(\n                    (n_heads, max_vertical_topk),\n                    int32_max,\n                    dtype=torch.int64,\n                    device=q.device,\n                )\n                slash_buffer = torch.full(\n                    (n_heads, max_slash_topk),\n                    int32_min,\n                    dtype=torch.int64,\n                    device=q.device,\n                )\n                
succ_vertical_buffer = torch.full(\n                    (n_heads, max_vertical_topk),\n                    int32_max,\n                    dtype=torch.int64,\n                    device=q.device,\n                )\n                succ_slash_buffer = torch.full(\n                    (n_heads, max_slash_topk),\n                    int32_min,\n                    dtype=torch.int64,\n                    device=q.device,\n                )\n                inter_vertical_buffer = torch.full(\n                    (n_heads, max_vertical_topk),\n                    int32_max,\n                    dtype=torch.int64,\n                    device=q.device,\n                )\n                inter_slash_buffer = torch.full(\n                    (n_heads, max_slash_topk),\n                    int32_min,\n                    dtype=torch.int64,\n                    device=q.device,\n                )\n\n                vertical_size_buffer = torch.empty(\n                    size=(n_heads,), dtype=torch.int32, device=q.device\n                )\n                slash_sizes_buffer = torch.empty(\n                    size=(n_heads,), dtype=torch.int32, device=q.device\n                )\n                succ_vertical_size_buffer = torch.empty(\n                    size=(n_heads,), dtype=torch.int32, device=q.device\n                )\n                succ_slash_sizes_buffer = torch.empty(\n                    size=(n_heads,), dtype=torch.int32, device=q.device\n                )\n                inter_vertical_size_buffer = torch.empty(\n                    size=(n_heads,), dtype=torch.int32, device=q.device\n                )\n                inter_slash_sizes_buffer = torch.empty(\n                    size=(n_heads,), dtype=torch.int32, device=q.device\n                )\n\n                for head_i in range(n_heads):\n                    vertical_topk = vertical_topk_buffer[\n                        head_i, : heads_vertical_size[head_i]\n                    ]\n                
    # intra\n                    intra_vertical_indices = (\n                        vertical_topk[vertical_topk >= prev_chunk_end_pos]\n                        - prev_chunk_end_pos\n                    )\n                    if intra_vertical_indices.nelement() == 0:\n                        intra_vertical_indices = torch.cat(\n                            [\n                                intra_vertical_indices,\n                                torch.arange(\n                                    0,\n                                    k_states_intra.size(0),\n                                    max(1, k_states_intra.size(0) / 5),\n                                    dtype=torch.int32,\n                                    device=intra_vertical_indices.device,\n                                ),\n                            ]\n                        )\n                    slash_topk = slash_topk_buffer[head_i, : heads_slash_size[head_i]]\n                    intra_slash_indices = (qk.size(-1) - 1) - slash_topk[\n                        slash_topk >= prev_chunk_end_pos\n                    ]\n                    # fill buffer\n                    v_count = intra_vertical_indices.nelement()\n                    s_count = intra_slash_indices.nelement()\n                    vertical_size_buffer[head_i] = v_count\n                    slash_sizes_buffer[head_i] = s_count\n                    vertical_buffer[head_i, :v_count].copy_(intra_vertical_indices)\n                    slash_buffer[head_i, :s_count].copy_(intra_slash_indices)\n                    # succ\n                    if prev_chunk_end_pos - chunk_len >= 0:\n                        succ_vertical_indices = vertical_topk[\n                            (vertical_topk < prev_chunk_end_pos)\n                            & (vertical_topk >= prev_chunk_end_pos - chunk_len)\n                        ] - (prev_chunk_end_pos - chunk_len)\n                        # TODO: support no vertical\n                        if 
succ_vertical_indices.nelement() == 0:\n                            succ_vertical_indices = torch.cat(\n                                [\n                                    succ_vertical_indices,\n                                    torch.arange(\n                                        0,\n                                        k_states_succ.size(0),\n                                        max(1, k_states_succ.size(0) / 5),\n                                        dtype=torch.int32,\n                                        device=intra_vertical_indices.device,\n                                    ),\n                                ]\n                            )\n                        succ_slash_indices = (\n                            prev_chunk_end_pos + (qend - qbegin) - 1\n                        ) - slash_topk[\n                            (\n                                (slash_topk >= (prev_chunk_end_pos - chunk_len))\n                                & (slash_topk < (prev_chunk_end_pos + (qend - qbegin)))\n                            )\n                        ]\n                        if succ_slash_indices.nelement() == 0:\n                            succ_slash_indices = torch.cat(\n                                [\n                                    succ_slash_indices,\n                                    torch.arange(\n                                        0,\n                                        k_states_succ.size(0),\n                                        max(1, k_states_succ.size(0) / 5),\n                                        dtype=torch.int32,\n                                        device=intra_vertical_indices.device,\n                                    ),\n                                ]\n                            )\n                        # fill buffer\n                        v_count = succ_vertical_indices.nelement()\n                        s_count = succ_slash_indices.nelement()\n                        
succ_vertical_size_buffer[head_i] = v_count\n                        succ_slash_sizes_buffer[head_i] = s_count\n                        succ_vertical_buffer[head_i, :v_count].copy_(\n                            succ_vertical_indices\n                        )\n                        succ_slash_buffer[head_i, :s_count].copy_(succ_slash_indices)\n\n                    if prev_chunk_end_pos - 2 * chunk_len >= 0:\n                        inter_vertical_indices = vertical_topk[\n                            vertical_topk < prev_chunk_end_pos - chunk_len\n                        ]\n\n                        if inter_vertical_indices.nelement() == 0:\n                            inter_vertical_indices = torch.cat(\n                                [\n                                    inter_vertical_indices,\n                                    torch.arange(\n                                        0,\n                                        k_states_inter.size(0),\n                                        max(1, k_states_inter.size(0) / 5),\n                                        dtype=torch.int32,\n                                        device=intra_vertical_indices.device,\n                                    ),\n                                ]\n                            )\n                        inter_slash_indices = (\n                            prev_chunk_end_pos - chunk_len + (qend - qbegin) - 1\n                        ) - slash_topk[\n                            slash_topk\n                            < (prev_chunk_end_pos - chunk_len + (qend - qbegin))\n                        ]\n                        if inter_slash_indices.nelement() == 0:\n                            inter_slash_indices = torch.cat(\n                                [\n                                    inter_slash_indices,\n                                    torch.arange(\n                                        0,\n                                        k_states_inter.size(0),\n   
                                     max(1, k_states_inter.size(0) / 5),\n                                        dtype=torch.int32,\n                                        device=intra_vertical_indices.device,\n                                    ),\n                                ]\n                            )\n                        # fill buffer\n                        v_count = inter_vertical_indices.nelement()\n                        s_count = inter_slash_indices.nelement()\n                        inter_vertical_size_buffer[head_i] = v_count\n                        inter_slash_sizes_buffer[head_i] = s_count\n                        inter_vertical_buffer[head_i, :v_count].copy_(\n                            inter_vertical_indices\n                        )\n                        inter_slash_buffer[head_i, :s_count].copy_(inter_slash_indices)\n            else:\n                intra_vertical_indices, intra_slash_indices = None, None\n                succ_vertical_indices, succ_slash_indices = None, None\n                inter_vertical_indices, inter_slash_indices = None, None\n\n            if sparse_attn_enabled:\n                flash_result = self._do_flash_attn(\n                    q_states_intra,\n                    k_states_intra,\n                    v_states_intra,\n                    softmax_scale=softmax_scale,\n                    causal=True,\n                    stage=\"intra\",\n                    vertical_indices=vertical_buffer,\n                    slash_indices=slash_buffer,\n                    vertical_indices_count=vertical_size_buffer,\n                    slash_indices_count=slash_sizes_buffer,\n                    mergehead_softmax_scale=softmax_scale,\n                    sparse_attn_enabled=sparse_attn_enabled,\n                )\n            else:\n                flash_result = self._do_flash_attn(\n                    q_states_intra,\n                    k_states_intra,\n                    v_states_intra,\n          
          softmax_scale=softmax_scale,\n                    causal=True,\n                    stage=\"intra\",\n                    vertical_indices=intra_vertical_indices,\n                    slash_indices=intra_slash_indices,\n                    sparse_attn_enabled=sparse_attn_enabled,\n                )\n            flash_per_chunk.append(flash_result)\n\n            if prev_chunk_end_pos - chunk_len >= 0:\n                if sparse_attn_enabled:\n                    flash_result = self._do_flash_attn(\n                        q_states_succ,\n                        k_states_succ,\n                        v_states_succ,\n                        softmax_scale=softmax_scale,\n                        causal=False,\n                        stage=\"succ\",\n                        vertical_indices=succ_vertical_buffer,\n                        slash_indices=succ_slash_buffer,\n                        vertical_indices_count=succ_vertical_size_buffer,\n                        slash_indices_count=succ_slash_sizes_buffer,\n                        mergehead_softmax_scale=softmax_scale,\n                        sparse_attn_enabled=sparse_attn_enabled,\n                    )\n                else:\n                    flash_result = self._do_flash_attn(\n                        q_states_succ,\n                        k_states_succ,\n                        v_states_succ,\n                        softmax_scale=softmax_scale,\n                        causal=False,\n                        stage=\"succ\",\n                        vertical_indices=succ_vertical_indices,\n                        slash_indices=succ_slash_indices,\n                        sparse_attn_enabled=sparse_attn_enabled,\n                    )\n                flash_per_chunk.append(flash_result)\n\n            if prev_chunk_end_pos - chunk_len * 2 >= 0:\n                if sparse_attn_enabled:\n                    flash_result = self._do_flash_attn(\n                        q_states_inter,\n             
           k_states_inter,\n                        v_states_inter,\n                        softmax_scale=softmax_scale,\n                        causal=False,\n                        stage=\"inter\",\n                        vertical_indices=inter_vertical_buffer,\n                        slash_indices=inter_slash_buffer,\n                        vertical_indices_count=inter_vertical_size_buffer,\n                        slash_indices_count=inter_slash_sizes_buffer,\n                        mergehead_softmax_scale=softmax_scale,\n                        sparse_attn_enabled=sparse_attn_enabled,\n                    )\n                else:\n                    flash_result = self._do_flash_attn(\n                        q_states_inter,\n                        k_states_inter,\n                        v_states_inter,\n                        softmax_scale=softmax_scale,\n                        causal=False,\n                        stage=\"inter\",\n                        vertical_indices=inter_vertical_indices,\n                        slash_indices=inter_slash_indices,\n                        sparse_attn_enabled=sparse_attn_enabled,\n                    )\n                flash_per_chunk.append(flash_result)\n\n            flash_results.append(flash_per_chunk)\n            begin = end\n\n        attn_output = self._merge_attn_outputs(flash_results)\n        del flash_results\n        return attn_output\n\n    def _do_flash_attn(\n        self,\n        query_states: torch.Tensor,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        softmax_scale: float,\n        causal: bool = True,\n        max_seqlen_k: Optional[int] = None,\n        stage: str = \"intra\",\n        vertical_indices: Optional[torch.Tensor] = None,\n        slash_indices: Optional[torch.Tensor] = None,\n        vertical_indices_count: Optional[torch.Tensor] = None,\n        slash_indices_count: Optional[torch.Tensor] = None,\n        mergehead_softmax_scale: 
Optional[float] = None,\n        sparse_attn_enabled: Optional[bool] = False,\n    ):\n        if max_seqlen_k is None:\n            max_seqlen_k = key_states.shape[0]\n\n        q_len = query_states.shape[0]\n        q_heads = query_states.shape[1]\n        h_dim = query_states.shape[-1]\n\n        if sparse_attn_enabled:\n            assert slash_indices is not None\n            if stage == \"intra\":\n                assert causal\n            else:\n                assert not causal\n\n            query_states = query_states.unsqueeze(0).transpose(1, 2)\n            key_states = key_states.unsqueeze(0).transpose(1, 2)\n            value_states = value_states.unsqueeze(0).transpose(1, 2)\n\n            q = query_states\n            k = key_states\n            v = value_states\n\n            if vertical_indices_count is not None and slash_indices_count is not None:\n                assert mergehead_softmax_scale is not None\n\n                res, s_lse = _vertical_slash_sparse_attention(\n                    q,\n                    k,\n                    v,\n                    vertical_indices,\n                    slash_indices,\n                    mergehead_softmax_scale,\n                    causal=causal,\n                    stage=stage,\n                    vertical_indices_count=vertical_indices_count,\n                    slash_indices_count=slash_indices_count,\n                )\n                res = res.view(q_heads, q_len, h_dim).transpose(\n                    0, 1\n                )  # (qlen,nhead,h_dim)\n                s_lse = (\n                    s_lse.view(q_heads, q_len, 1).squeeze(-1).unsqueeze(0).float()\n                )  # (1, nhead,qlen)\n            else:\n                res, s_lse = _vertical_slash_sparse_attention(\n                    q,\n                    k,\n                    v,\n                    vertical_indices,\n                    slash_indices,\n                    softmax_scale,\n                    
causal=causal,\n                    stage=stage,\n                )\n                res = res.view(q_len, q_heads, h_dim)\n                s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float()\n            return res, s_lse\n\n        output, softmax_lse, *rest = flash_attn_varlen_func(\n            q=query_states,\n            k=key_states,\n            v=value_states,\n            softmax_scale=softmax_scale,\n            cu_seqlens_q=torch.tensor(\n                [0, query_states.shape[0]],\n                dtype=torch.int32,\n                device=query_states.device,\n            ),\n            max_seqlen_q=query_states.shape[0],\n            cu_seqlens_k=torch.tensor(\n                [0, max_seqlen_k], dtype=torch.int32, device=query_states.device\n            ),\n            max_seqlen_k=max_seqlen_k,\n            causal=causal,\n            return_softmax_lse=True,\n        )\n        softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, 2).float()\n        return output, softmax_lse\n\n    def _merge_attn_outputs(\n        self,\n        flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]],\n        return_lse: Optional[bool] = False,\n    ) -> torch.Tensor:\n        attn_outputs_all = []\n        logits_all = []\n\n        for flash_per_chunk in flash_results:\n            if len(flash_per_chunk) == 1:\n                attn_outputs_all.append(flash_per_chunk[0][0])\n                if return_lse:\n                    logits_all.append(flash_per_chunk[0][1])\n                continue\n\n            attn_outputs = torch.stack(\n                [flash_attn_output[0] for flash_attn_output in flash_per_chunk]\n            )\n            logits = torch.stack(\n                [flash_attn_output[1] for flash_attn_output in flash_per_chunk]\n            )\n            logits = logits.to(torch.float32)\n\n            if return_lse:\n                max_val = torch.max(logits, dim=0).values\n                diff = torch.abs(logits[0] - 
logits[1])\n                log_sum_exp = max_val + torch.log1p(torch.exp(-diff))\n                logits_all.append(log_sum_exp)\n\n            max_logits = torch.max(logits, dim=0).values\n            stable_logits = logits - max_logits.unsqueeze(0)\n            lse_s = torch.exp(stable_logits).detach()\n            lse_sum = torch.sum(lse_s, dim=0)\n            lse_s /= lse_sum\n            attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1)\n            attn_outputs_all.append(attn_outputs.sum(dim=0))\n\n        if return_lse:\n            return (torch.cat(attn_outputs_all, dim=0), torch.cat(logits_all, dim=-1))\n        else:\n            return torch.cat(attn_outputs_all, dim=0)\n\n    def _dual_chunk_flash_attn_decoding(\n        self,\n        query: torch.Tensor,\n        query_succ: torch.Tensor,\n        query_inter: torch.Tensor,\n        key_cache: torch.Tensor,\n        value_cache: torch.Tensor,\n        block_table: torch.Tensor,\n        cache_seqlens: torch.Tensor,\n        softmax_scale: float,\n        causal: bool,\n        chunk_size: int,\n        local_size: int,\n        original_max_position_embeddings: int,\n        decode_meta: DualChunkFlashAttentionMetadata,\n    ):\n        if not causal:\n            raise ValueError(\"Dual Chunk Attention does not support causal=False\")\n\n        block_size = value_cache.shape[1]\n        chunk_len = chunk_size - local_size\n        if chunk_len % block_size != 0:\n            raise ValueError(\"chunk_len must be divisible by block_size.\")\n        if original_max_position_embeddings > 0:\n            assert decode_meta.scaling_factor is not None\n            scaling_factor = decode_meta.scaling_factor\n            query = (query * scaling_factor.view(-1, 1, 1, 1)).to(\n                query.dtype\n            )  # possible numerical issue; this scaling should be fused into the kernel\n            query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to(query.dtype)\n            query_inter 
= (query_inter * scaling_factor.view(-1, 1, 1, 1)).to(\n                query.dtype\n            )\n        outputs_list = []\n        softmax_lses_list = []\n\n        # intra-attention\n        intra_output, intra_softmax_lse = (\n            self._dual_chunk_flash_attn_decoding_with_exp_sums(\n                query,\n                key_cache,\n                value_cache,\n                decode_meta.block_tables_intra,\n                decode_meta.seq_lens_intra,\n                softmax_scale,\n                causal=False,\n            )\n        )\n        outputs_list.append(intra_output)\n        softmax_lses_list.append(intra_softmax_lse)\n\n        # succ-attention\n        if decode_meta.max_seq_len_succ:\n            succ_output, succ_softmax_lse = (\n                self._dual_chunk_flash_attn_decoding_with_exp_sums(\n                    query_succ,\n                    key_cache,\n                    value_cache,\n                    decode_meta.block_tables_succ,\n                    decode_meta.seq_lens_succ,\n                    softmax_scale,\n                    causal=False,\n                )\n            )\n            outputs_list.append(succ_output)\n            softmax_lses_list.append(succ_softmax_lse)\n\n        # inter-attention\n        if decode_meta.max_seq_len_inter:\n            inter_output, inter_softmax_lse = (\n                self._dual_chunk_flash_attn_decoding_with_exp_sums(\n                    query_inter,\n                    key_cache,\n                    value_cache,\n                    block_table,\n                    decode_meta.seq_lens_inter,\n                    softmax_scale,\n                    causal=False,\n                )\n            )\n            outputs_list.append(inter_output)\n            softmax_lses_list.append(inter_softmax_lse)\n        outputs = torch.stack(outputs_list, dim=0)\n        del outputs_list\n        softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32)\n        
del softmax_lses_list\n        max_logits = torch.max(softmax_lses, dim=0).values\n        stable_logits = softmax_lses - max_logits.unsqueeze(0)\n        lse_s = torch.exp(stable_logits).detach()\n        lse_sum = torch.sum(lse_s, dim=0)\n        lse_s /= lse_sum\n        outputs *= lse_s.unsqueeze(-1).transpose(2, 3)\n        return outputs.sum(0)\n\n    def _dual_chunk_flash_attn_decoding_with_exp_sums(\n        self,\n        query: torch.Tensor,\n        key_cache: torch.Tensor,\n        value_cache: torch.Tensor,\n        block_table: torch.Tensor,\n        cache_seqlens: torch.Tensor,\n        softmax_scale: float,\n        causal: bool,\n    ):\n        out, softmax_lse, *rest_expand = flash_attn_with_kvcache(\n            q=query,\n            k_cache=key_cache,\n            v_cache=value_cache,\n            page_table=block_table,\n            cache_seqlens=cache_seqlens,\n            softmax_scale=softmax_scale,\n            causal=causal,\n            return_softmax_lse=True,\n        )\n        mask = cache_seqlens == 0\n        out[mask] = 0\n        softmax_lse[mask] = -float(\"inf\")\n        return out, softmax_lse\n\n\ndef _vertical_slash_sparse_attention(\n    query: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]\n    key: torch.Tensor,  # [BATCH, N_HEADS, N_KV_CTX, D_HEAD]\n    value: torch.Tensor,  # [BATCH, N_HEADS, N_KV_CTX, D_HEAD]\n    v_idx: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]\n    s_idx: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]\n    softmax_scale: float,\n    causal: bool = True,\n    stage: str = \"intra\",\n    block_size_M: int = 64,\n    block_size_N: int = 64,\n    vertical_indices_count: torch.Tensor = None,  # [N_HEADS,]\n    slash_indices_count: torch.Tensor = None,\n):\n    if stage == \"intra\":\n        assert causal\n    else:\n        assert not causal\n\n    batch_size, num_heads, context_size, head_dim = query.shape\n    _, _, kv_seq_len, _ = key.shape\n\n    if head_dim not in [16, 32, 64, 128, 256, 512]:\n        
target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim\n        query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0])\n        key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0])\n        value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0])\n\n    v_idx = (\n        v_idx.to(torch.int32)\n        .reshape((batch_size, num_heads, -1))\n        .sort(dim=-1, descending=False)[0]\n    )\n    s_idx = (\n        s_idx.to(torch.int32)\n        .reshape((batch_size, num_heads, -1))\n        .sort(dim=-1, descending=True)[0]\n    )\n    q_seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device)\n    kv_seqlens = torch.tensor([kv_seq_len], dtype=torch.int32, device=query.device)\n\n    if vertical_indices_count is not None and slash_indices_count is not None:\n        (\n            block_count,\n            block_offset,\n            column_count,\n            column_index,\n        ) = convert_vertical_slash_indexes_mergehead(\n            q_seqlens,\n            kv_seqlens,\n            v_idx,\n            s_idx,\n            vertical_indices_count,\n            slash_indices_count,\n            context_size,\n            block_size_M,\n            block_size_N,\n            causal,\n        )\n    else:\n        (\n            block_count,\n            block_offset,\n            column_count,\n            column_index,\n        ) = convert_vertical_slash_indexes(\n            q_seqlens,\n            kv_seqlens,\n            v_idx,\n            s_idx,\n            context_size,\n            block_size_M,\n            block_size_N,\n            causal,\n        )\n\n    q = query.transpose(1, 2).contiguous()\n    k = key.transpose(1, 2).contiguous()\n    v = value.transpose(1, 2).contiguous()\n    out, lse = sparse_attn_func(\n        q,\n        k,\n        v,\n        block_count,\n        block_offset,\n        column_count,\n        column_index,\n        causal=causal,\n        softmax_scale=softmax_scale,\n        
return_softmax_lse=True,\n    )\n    out = out.transpose(1, 2).contiguous()\n    softmax_lse = lse.reshape(*lse.shape, 1)\n    return (out[..., :context_size, :head_dim], softmax_lse[..., :context_size, :])\n\n\ndef _sum_all_diagonal_matrix(mat: torch.Tensor):\n    h, n, m = mat.shape\n    # Zero matrix used for padding\n    zero_mat = torch.zeros((h, n, n), device=mat.device)\n    # Pad the matrix on the left and right\n    mat_padded = torch.cat((zero_mat, mat, zero_mat), -1)\n    # Re-stride so that each diagonal of the padded matrix becomes a column\n    mat_strided = mat_padded.as_strided(\n        (1, n, n + m), (n * (2 * n + m), 2 * n + m + 1, 1)\n    )\n    # Sum the resulting matrix's columns\n    sum_diags = torch.sum(mat_strided, 1)\n    return sum_diags[:, 1:]  # drop left bottom corner\n\n\ndef _get_block(block_table: torch.Tensor, block_size: int, begin: int, end: int):\n    begin_block = begin // block_size\n    end_block = (end - 1) // block_size + 1\n    return block_table[begin_block:end_block]\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/chunk.py",
    "content": "# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/chunk.py\n# -*- coding: utf-8 -*-\n# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang\n\nfrom typing import Optional\n\nimport torch\nfrom einops import rearrange\n\nfrom sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h\nfrom sglang.srt.layers.attention.fla.chunk_o import chunk_fwd_o\nfrom sglang.srt.layers.attention.fla.chunk_scaled_dot_kkt import (\n    chunk_scaled_dot_kkt_fwd,\n)\nfrom sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum\nfrom sglang.srt.layers.attention.fla.l2norm import l2norm_fwd\nfrom sglang.srt.layers.attention.fla.solve_tril import solve_tril\nfrom sglang.srt.layers.attention.fla.utils import (\n    SUPPRESS_LEVEL,\n    autocast_custom_fwd,\n    input_guard,\n)\nfrom sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd\n\n\ndef chunk_gated_delta_rule_fwd(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    g: torch.Tensor,\n    beta: torch.Tensor,\n    scale: float,\n    initial_state: torch.Tensor,\n    initial_state_indices: torch.Tensor,\n    cu_seqlens: Optional[torch.LongTensor] = None,\n):\n    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)\n    # obtain WY representation. 
u is actually the new v.\n    A = chunk_scaled_dot_kkt_fwd(\n        k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32\n    )\n    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)\n    w, u = recompute_w_u_fwd(\n        k=k,\n        v=v,\n        beta=beta,\n        A=A,\n        g_cumsum=g,\n        cu_seqlens=cu_seqlens,\n    )\n    h, v_new = chunk_gated_delta_rule_fwd_h(\n        k=k,\n        w=w,\n        u=u,\n        g=g,\n        initial_state=initial_state,\n        initial_state_indices=initial_state_indices,\n        cu_seqlens=cu_seqlens,\n    )\n    o = chunk_fwd_o(\n        q=q,\n        k=k,\n        v=v_new,\n        h=h,\n        g=g,\n        scale=scale,\n        cu_seqlens=cu_seqlens,\n    )\n    if SUPPRESS_LEVEL < 3:\n        return g, o, A, None, h, None\n    elif SUPPRESS_LEVEL >= 3:\n        return g, o, A, w, h, v_new\n\n\nclass ChunkGatedDeltaRuleFunction(torch.autograd.Function):\n\n    @staticmethod\n    @input_guard\n    @autocast_custom_fwd\n    def forward(\n        ctx,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        g: torch.Tensor,\n        beta: torch.Tensor,\n        scale: float,\n        initial_state: torch.Tensor,\n        initial_state_indices: torch.Tensor,\n        cu_seqlens: Optional[torch.LongTensor] = None,\n        use_qk_l2norm_in_kernel: bool = False,\n    ):\n        q_orig = q\n        k_orig = k\n\n        if use_qk_l2norm_in_kernel:\n            q = l2norm_fwd(q)\n            k = l2norm_fwd(k)\n\n        g, o, A, w, h, v_new = chunk_gated_delta_rule_fwd(\n            q=q,\n            k=k,\n            v=v,\n            g=g,\n            beta=beta,\n            scale=scale,\n            initial_state=initial_state,\n            initial_state_indices=initial_state_indices,\n            cu_seqlens=cu_seqlens,\n        )\n        return o.to(q.dtype), h\n\n\n@torch.compiler.disable\ndef chunk_gated_delta_rule(\n    q: 
torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    g: torch.Tensor,\n    beta: torch.Tensor,\n    scale: float = None,\n    initial_state: torch.Tensor = None,\n    initial_state_indices: torch.Tensor = None,\n    cu_seqlens: Optional[torch.LongTensor] = None,\n    head_first: bool = False,\n    use_qk_l2norm_in_kernel: bool = False,\n):\n    r\"\"\"\n    Args:\n        q (torch.Tensor):\n            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.\n        k (torch.Tensor):\n            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.\n        v (torch.Tensor):\n            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.\n        g (torch.Tensor):\n            (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.\n        beta (torch.Tensor):\n            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.\n        scale (Optional[float]):\n            Scale factor for the RetNet attention scores.\n            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.\n        initial_state (Optional[torch.Tensor]):\n            Initial state of shape `[N, H, V, K]` for `N` input sequences.\n            For equal-length input sequences, `N` equals the batch size `B`.\n            Default: `None`.\n        output_final_state (Optional[bool]):\n            Whether to output the final state of shape `[N, H, V, K]`. 
Default: `False`.\n        cu_seqlens (torch.LongTensor):\n            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,\n            consistent with the FlashAttention API.\n        head_first (Optional[bool]):\n            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.\n            Default: `False`.\n\n    Returns:\n        o (torch.Tensor):\n            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.\n        final_state (torch.Tensor):\n            Final state of shape `[N, H, V, K]` if `output_final_state=True` else `None`.\n\n    Examples::\n        >>> import torch\n        >>> import torch.nn.functional as F\n        >>> from einops import rearrange\n        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule\n        # inputs with equal lengths\n        >>> B, T, H, K, V = 4, 2048, 4, 512, 512\n        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')\n        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)\n        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')\n        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()\n        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))\n        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')\n        >>> o, ht = chunk_gated_delta_rule(\n            q, k, v, g, beta,\n            initial_state=h0,\n            output_final_state=True\n        )\n        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required\n        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... 
-> 1 (b t) ...'), (q, k, v, beta, g))\n        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected\n        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)\n        >>> o_var, ht_var = chunk_gated_delta_rule(\n            q, k, v, g, beta,\n            initial_state=h0,\n            output_final_state=True,\n            cu_seqlens=cu_seqlens\n        )\n    \"\"\"\n    assert q.dtype == k.dtype == v.dtype\n    assert (\n        q.dtype != torch.float32\n    ), \"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.\"\n    assert (\n        len(beta.shape) == 3\n    ), \"beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise.\"\n\n    if head_first:\n        raise DeprecationWarning(\n            \"head_first is deprecated and will be removed in a future version. \"\n            \"Please use head_first=False for now instead.\"\n        )\n        q, k, v, beta, g = map(\n            lambda x: rearrange(x, \"b h t ... -> b t h ...\"), (q, k, v, beta, g)\n        )\n    # if not head_first and q.shape[1] < q.shape[2]:\n    #     warnings.warn(\n    #         f\"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). \"\n    #         \"This may indicate the inputs were passed in head-first format [B, H, T, ...] \"\n    #         \"when head_first=False was specified. 
\"\n    #         \"Please verify your input tensor format matches the expected shape [B, T, H, ...].\"\n    #     )\n    if cu_seqlens is not None:\n        if q.shape[0] != 1:\n            raise ValueError(\n                f\"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. \"\n                f\"Please flatten variable-length inputs before processing.\"\n            )\n        if (\n            initial_state_indices is not None\n            and initial_state_indices.shape[0] != len(cu_seqlens) - 1\n        ):\n            raise ValueError(\n                f\"The number of initial states is expected to be equal to the number of input sequences, \"\n                f\"i.e., {len(cu_seqlens) - 1} rather than {initial_state_indices.shape[0]}.\"\n            )\n    if scale is None:\n        scale = k.shape[-1] ** -0.5\n    o, h = ChunkGatedDeltaRuleFunction.apply(\n        q,\n        k,\n        v,\n        g,\n        beta,\n        scale,\n        initial_state,\n        initial_state_indices,\n        cu_seqlens,\n        use_qk_l2norm_in_kernel,\n    )\n    if head_first:\n        o = rearrange(o, \"b t h ... -> b h t ...\")\n    return o, None, h\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/chunk_delta_h.py",
    "content": "# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_delta_h.py\n# -*- coding: utf-8 -*-\n# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang\n\nfrom typing import Optional, Tuple\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.attention.fla.index import (\n    prepare_chunk_indices,\n    prepare_chunk_offsets,\n)\nfrom sglang.srt.layers.attention.fla.op import exp, safe_exp\nfrom sglang.srt.layers.attention.fla.utils import is_nvidia_hopper\n\nNUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]\nCHUNK_SIZE = 64\n\n\n# @triton.autotune(\n#     configs=[\n#         triton.Config({\"BV\": BV}, num_warps=num_warps, num_stages=num_stages)\n#         for num_warps in [2, 4]\n#         for num_stages in [2, 3, 4]\n#         for BV in [32, 64]\n#     ],\n#     key=[\"H\", \"K\", \"V\", \"BT\", \"USE_G\"],\n#     use_cuda_graph=use_cuda_graph,\n# )\n@triton.jit(do_not_specialize=[\"T\"])\ndef chunk_gated_delta_rule_fwd_kernel_h_blockdim64(\n    k,\n    v,\n    w,\n    v_new,\n    g,\n    gk,\n    h,\n    initial_state,\n    initial_state_indices,\n    cu_seqlens,\n    chunk_offsets,\n    T,\n    H: tl.constexpr,\n    Hg: tl.constexpr,\n    K: tl.constexpr,\n    V: tl.constexpr,\n    BT: tl.constexpr,\n    BV: tl.constexpr,\n    USE_G: tl.constexpr,\n    USE_GK: tl.constexpr,\n    USE_INITIAL_STATE: tl.constexpr,\n    INPLACE_UPDATE: tl.constexpr,\n    SAVE_NEW_VALUE: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    i_v, i_nh = tl.program_id(0), tl.program_id(1)\n    i_n, i_h = i_nh // H, i_nh % H\n    if IS_VARLEN:\n        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(\n            cu_seqlens + i_n + 1\n        ).to(tl.int32)\n        T = eos - bos\n        NT = tl.cdiv(T, BT)\n        boh = tl.load(chunk_offsets + i_n).to(tl.int32)\n    else:\n        bos, eos = i_n * T, i_n * T + T\n        NT = tl.cdiv(T, BT)\n        boh = i_n * NT\n\n    # [BV, BK]\n 
   b_h1 = tl.zeros([BV, 64], dtype=tl.float32)\n    if K > 64:\n        b_h2 = tl.zeros([BV, 64], dtype=tl.float32)\n    if K > 128:\n        b_h3 = tl.zeros([BV, 64], dtype=tl.float32)\n    if K > 192:\n        b_h4 = tl.zeros([BV, 64], dtype=tl.float32)\n\n    # calculate offset\n    h += ((boh * H + i_h) * V * K).to(tl.int64)\n    v += ((bos * H + i_h) * V).to(tl.int64)\n    k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)\n    w += ((bos * H + i_h) * K).to(tl.int64)\n    if SAVE_NEW_VALUE:\n        v_new += ((bos * H + i_h) * V).to(tl.int64)\n    stride_v = H * V\n    stride_h = H * V * K\n    stride_k = Hg * K\n    stride_w = H * K\n\n    index = tl.load(initial_state_indices + i_n).to(tl.int32)\n    h0 = initial_state + index * stride_h\n    ht = initial_state + index * stride_h\n    if USE_INITIAL_STATE:\n        h0 = h0 + i_h * V * K\n    if INPLACE_UPDATE:\n        ht = ht + i_h * V * K\n\n    # load initial state\n    if USE_INITIAL_STATE:\n        p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))\n        b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)\n        if K > 64:\n            p_h0_2 = tl.make_block_ptr(\n                h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)\n            )\n            b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)\n        if K > 128:\n            p_h0_3 = tl.make_block_ptr(\n                h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)\n            )\n            b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)\n        if K > 192:\n            p_h0_4 = tl.make_block_ptr(\n                h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)\n            )\n            b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)\n\n    # main recurrence\n    for i_t in range(NT):\n        p_h1 = tl.make_block_ptr(\n            h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)\n        )\n        tl.store(p_h1, 
b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))\n        if K > 64:\n            p_h2 = tl.make_block_ptr(\n                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)\n            )\n            tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))\n        if K > 128:\n            p_h3 = tl.make_block_ptr(\n                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)\n            )\n            tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))\n        if K > 192:\n            p_h4 = tl.make_block_ptr(\n                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)\n            )\n            tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))\n\n        p_w = tl.make_block_ptr(\n            w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)\n        )\n        b_w = tl.load(p_w, boundary_check=(0, 1))\n        b_v = tl.dot(b_w, tl.trans(b_h1).to(b_w.dtype))\n        if K > 64:\n            p_w = tl.make_block_ptr(\n                w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)\n            )\n            b_w = tl.load(p_w, boundary_check=(0, 1))\n            b_v += tl.dot(b_w, tl.trans(b_h2).to(b_w.dtype))\n        if K > 128:\n            p_w = tl.make_block_ptr(\n                w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)\n            )\n            b_w = tl.load(p_w, boundary_check=(0, 1))\n            b_v += tl.dot(b_w, tl.trans(b_h3).to(b_w.dtype))\n        if K > 192:\n            p_w = tl.make_block_ptr(\n                w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)\n            )\n            b_w = tl.load(p_w, boundary_check=(0, 1))\n            b_v += tl.dot(b_w, tl.trans(b_h4).to(b_w.dtype))\n        p_v = tl.make_block_ptr(\n            v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)\n        )\n        b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v\n\n 
       if SAVE_NEW_VALUE:\n            p_v = tl.make_block_ptr(\n                v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)\n            )\n            tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))\n\n        last_idx = min((i_t + 1) * BT, T) - 1\n        if USE_G:\n            b_g_last = tl.load(g + bos * H + last_idx * H + i_h)\n            p_g = tl.make_block_ptr(\n                g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)\n            )\n            b_g = tl.load(p_g, boundary_check=(0,))\n            b_v = b_v * safe_exp(b_g_last - b_g)[:, None]\n            b_g_last = exp(b_g_last)\n            b_h1 = b_h1 * b_g_last\n            if K > 64:\n                b_h2 = b_h2 * b_g_last\n            if K > 128:\n                b_h3 = b_h3 * b_g_last\n            if K > 192:\n                b_h4 = b_h4 * b_g_last\n\n        if USE_GK:\n            o_k1 = tl.arange(0, 64)\n            b_gk_last1 = tl.load(\n                gk + (bos + last_idx) * H * K + i_h * K + o_k1,\n                mask=(o_k1 < K),\n                other=0.0,\n            )\n            b_h1 *= exp(b_gk_last1)[None, :]\n            if K > 64:\n                o_k2 = 64 + o_k1\n                b_gk_last2 = tl.load(\n                    gk + (bos + last_idx) * H * K + i_h * K + o_k2,\n                    mask=(o_k2 < K),\n                    other=0.0,\n                )\n                b_h2 *= exp(b_gk_last2)[None, :]\n            if K > 128:\n                o_k3 = 128 + o_k1\n                b_gk_last3 = tl.load(\n                    gk + (bos + last_idx) * H * K + i_h * K + o_k3,\n                    mask=(o_k3 < K),\n                    other=0.0,\n                )\n                b_h3 *= exp(b_gk_last3)[None, :]\n            if K > 192:\n                o_k4 = 192 + o_k1\n                b_gk_last4 = tl.load(\n                    gk + (bos + last_idx) * H * K + i_h * K + o_k4,\n                    mask=(o_k4 < K),\n              
      other=0.0,\n                )\n                b_h4 *= exp(b_gk_last4)[None, :]\n        b_v = b_v.to(k.dtype.element_ty)\n\n        p_k = tl.make_block_ptr(\n            k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)\n        )\n        b_k = tl.load(p_k, boundary_check=(0, 1))\n        b_h1 += tl.trans(tl.dot(b_k, b_v))\n        if K > 64:\n            p_k = tl.make_block_ptr(\n                k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)\n            )\n            b_k = tl.load(p_k, boundary_check=(0, 1))\n            b_h2 += tl.trans(tl.dot(b_k, b_v))\n        if K > 128:\n            p_k = tl.make_block_ptr(\n                k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)\n            )\n            b_k = tl.load(p_k, boundary_check=(0, 1))\n            b_h3 += tl.trans(tl.dot(b_k, b_v))\n        if K > 192:\n            p_k = tl.make_block_ptr(\n                k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)\n            )\n            b_k = tl.load(p_k, boundary_check=(0, 1))\n            b_h4 += tl.trans(tl.dot(b_k, b_v))\n\n    # epilogue\n    if INPLACE_UPDATE:\n        p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))\n        tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n        if K > 64:\n            p_ht = tl.make_block_ptr(\n                ht, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)\n            )\n            tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n        if K > 128:\n            p_ht = tl.make_block_ptr(\n                ht, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)\n            )\n            tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n        if K > 192:\n            p_ht = tl.make_block_ptr(\n                ht, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)\n            )\n            tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 
1))\n\n\ndef chunk_gated_delta_rule_fwd_h(\n    k: torch.Tensor,\n    w: torch.Tensor,\n    u: torch.Tensor,\n    g: Optional[torch.Tensor] = None,\n    gk: Optional[torch.Tensor] = None,\n    initial_state: Optional[torch.Tensor] = None,\n    initial_state_indices: Optional[torch.Tensor] = None,\n    save_new_value: bool = True,\n    cu_seqlens: Optional[torch.LongTensor] = None,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n    B, T, Hg, K, V = *k.shape, u.shape[-1]\n    H = u.shape[-2]\n    BT = CHUNK_SIZE\n\n    chunk_indices = (\n        prepare_chunk_indices(cu_seqlens, CHUNK_SIZE)\n        if cu_seqlens is not None\n        else None\n    )\n    # N: the actual number of sequences in the batch with either equal or variable lengths\n    if cu_seqlens is None:\n        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None\n    else:\n        N, NT, chunk_offsets = (\n            len(cu_seqlens) - 1,\n            len(chunk_indices),\n            prepare_chunk_offsets(cu_seqlens, BT),\n        )\n    assert K <= 256, \"current kernel does not support head dimension larger than 256.\"\n\n    h = k.new_empty(B, NT, H, V, K)\n\n    v_new = torch.empty_like(u) if save_new_value else None\n\n    def grid(meta):\n        return (triton.cdiv(V, meta[\"BV\"]), N * H)\n\n    chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](\n        k=k,\n        v=u,\n        w=w,\n        v_new=v_new,\n        g=g,\n        gk=gk,\n        h=h,\n        initial_state=initial_state,\n        initial_state_indices=initial_state_indices,\n        cu_seqlens=cu_seqlens,\n        chunk_offsets=chunk_offsets,\n        T=T,\n        H=H,\n        Hg=Hg,\n        K=K,\n        V=V,\n        BT=BT,\n        BV=32,\n        USE_G=g is not None,\n        USE_GK=gk is not None,\n        USE_INITIAL_STATE=initial_state is not None,\n        INPLACE_UPDATE=True,\n        SAVE_NEW_VALUE=v_new is not None,\n        IS_VARLEN=cu_seqlens is not None,\n        num_warps=4,\n        
num_stages=2,\n    )\n    return h, v_new\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/chunk_o.py",
    "content": "# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_o.py\n# -*- coding: utf-8 -*-\n# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang\n\nfrom typing import Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.attention.fla.index import prepare_chunk_indices\nfrom sglang.srt.layers.attention.fla.op import exp, safe_exp\nfrom sglang.srt.layers.attention.fla.utils import check_shared_mem, is_nvidia_hopper\n\nBKV_LIST = [64, 128] if check_shared_mem() else [32, 64]\nNUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]\n\n\n# @triton.autotune(\n#     configs=[\n#         triton.Config({\"BK\": BK, \"BV\": BV}, num_warps=num_warps, num_stages=num_stages)\n#         for BK in BKV_LIST\n#         for BV in BKV_LIST\n#         for num_warps in NUM_WARPS\n#         for num_stages in [2, 3, 4]\n#     ],\n#     key=[\"H\", \"K\", \"V\", \"BT\"],\n# )\n@triton.jit(do_not_specialize=[\"T\"])\ndef chunk_fwd_kernel_o(\n    q,\n    k,\n    v,\n    h,\n    g,\n    o,\n    cu_seqlens,\n    chunk_indices,\n    scale,\n    T,\n    H: tl.constexpr,\n    Hg: tl.constexpr,\n    K: tl.constexpr,\n    V: tl.constexpr,\n    BT: tl.constexpr,\n    BK: tl.constexpr,\n    BV: tl.constexpr,\n    USE_G: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n    i_b, i_h = i_bh // H, i_bh % H\n\n    if IS_VARLEN:\n        i_tg = i_t\n        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(\n            chunk_indices + i_t * 2 + 1\n        ).to(tl.int32)\n        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(\n            cu_seqlens + i_n + 1\n        ).to(tl.int32)\n        T = eos - bos\n        NT = tl.cdiv(T, BT)\n    else:\n        NT = tl.cdiv(T, BT)\n        i_tg = i_b * NT + i_t\n        bos, eos = i_b * T, i_b * T + T\n\n    # offset calculation\n    q += (bos * Hg + i_h // (H // Hg)) * K\n 
   k += (bos * Hg + i_h // (H // Hg)) * K\n    v += (bos * H + i_h) * V\n    o += (bos * H + i_h) * V\n    h += (i_tg * H + i_h).to(tl.int64) * V * K\n\n    b_o = tl.zeros([BT, BV], dtype=tl.float32)\n    b_A = tl.zeros([BT, BT], dtype=tl.float32)\n\n    for i_k in range(tl.cdiv(K, BK)):\n        p_q = tl.make_block_ptr(\n            q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)\n        )\n        p_k = tl.make_block_ptr(\n            k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)\n        )\n        p_h = tl.make_block_ptr(\n            h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)\n        )\n        # [BT, BK]\n        b_q = tl.load(p_q, boundary_check=(0, 1))\n        # [BK, BT]\n        b_k = tl.load(p_k, boundary_check=(0, 1))\n        # [BV, BK]\n        b_h = tl.load(p_h, boundary_check=(0, 1))\n\n        # [BT, BK] @ [BK, BV] -> [BT, BV]\n        b_o += tl.dot(b_q, tl.trans(b_h))\n        # [BT, BK] @ [BK, BT] -> [BT, BT]\n        b_A += tl.dot(b_q, b_k)\n\n    if USE_G:\n        g += bos * H + i_h\n        p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))\n        b_g = tl.load(p_g, boundary_check=(0,))\n        b_o = b_o * exp(b_g)[:, None]\n        b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])\n\n    o_i = tl.arange(0, BT)\n    m_A = o_i[:, None] >= o_i[None, :]\n    b_A = tl.where(m_A, b_A, 0)\n\n    p_v = tl.make_block_ptr(\n        v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)\n    )\n    p_o = tl.make_block_ptr(\n        o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)\n    )\n    b_v = tl.load(p_v, boundary_check=(0, 1))\n\n    # to fix mma -> mma layout conversion\n    # already solved by triton v3.2 or higher\n    b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale\n    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef chunk_fwd_o(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    h: torch.Tensor,\n   
 g: Optional[torch.Tensor] = None,  # cumsum of log decay\n    scale: Optional[float] = None,\n    cu_seqlens: Optional[torch.LongTensor] = None,\n    chunk_size: int = 64,\n) -> torch.Tensor:\n    B, T, Hg, K, V = *q.shape, v.shape[-1]\n    H = v.shape[-2]\n    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))\n    chunk_indices = (\n        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None\n    )\n    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)\n    if scale is None:\n        scale = k.shape[-1] ** -0.5\n\n    o = torch.zeros_like(v)\n\n    def grid(meta):\n        return (triton.cdiv(V, meta[\"BV\"]), NT, B * H)\n\n    chunk_fwd_kernel_o[grid](\n        q,\n        k,\n        v,\n        h,\n        g,\n        o,\n        cu_seqlens,\n        chunk_indices,\n        scale,\n        T=T,\n        H=H,\n        Hg=Hg,\n        K=K,\n        V=V,\n        BT=BT,\n        BK=128,\n        BV=64,\n        USE_G=g is not None,\n        IS_VARLEN=cu_seqlens is not None,\n        num_warps=4,\n        num_stages=2,\n    )\n    return o\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py",
    "content": "# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_scaled_dot_kkt.py\n# -*- coding: utf-8 -*-\n# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang\n\nfrom typing import Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.attention.fla.index import prepare_chunk_indices\nfrom sglang.srt.layers.attention.fla.op import safe_exp\n\n\n# @triton.autotune(\n#     configs=[\n#         triton.Config({\"BK\": BK}, num_warps=num_warps, num_stages=num_stages)\n#         for BK in [32, 64, 128]\n#         for num_warps in [2, 4, 8]\n#         for num_stages in [2, 3, 4]\n#     ],\n#     key=[\"H\", \"K\", \"BT\", \"IS_VARLEN\"],\n# )\n@triton.jit(do_not_specialize=[\"T\"])\ndef chunk_scaled_dot_kkt_fwd_kernel(\n    k,\n    beta,\n    g_cumsum,\n    A,\n    cu_seqlens,\n    chunk_indices,\n    T,\n    H: tl.constexpr,\n    Hg: tl.constexpr,\n    K: tl.constexpr,\n    BT: tl.constexpr,\n    BK: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n    USE_G: tl.constexpr,\n):\n    i_t, i_bh = tl.program_id(0), tl.program_id(1)\n    i_b, i_h = i_bh // H, i_bh % H\n    if IS_VARLEN:\n        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(\n            chunk_indices + i_t * 2 + 1\n        ).to(tl.int32)\n        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(\n            cu_seqlens + i_n + 1\n        ).to(tl.int32)\n        T = eos - bos\n    else:\n        bos, eos = i_b * T, i_b * T + T\n    o_t = tl.arange(0, BT)\n\n    p_beta = tl.make_block_ptr(\n        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)\n    )\n    b_beta = tl.load(p_beta, boundary_check=(0,))\n\n    b_A = tl.zeros([BT, BT], dtype=tl.float32)\n    for i_k in range(tl.cdiv(K, BK)):\n        p_k = tl.make_block_ptr(\n            k + (bos * Hg + i_h // (H // Hg)) * K,\n            (T, K),\n            (Hg * K, 1),\n            (i_t * BT, i_k * BK),\n            (BT, BK),\n          
  (1, 0),\n        )\n        b_k = tl.load(p_k, boundary_check=(0, 1))\n        b_A += tl.dot(b_k, tl.trans(b_k))\n\n    if USE_G:\n        p_g = tl.make_block_ptr(\n            g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)\n        )\n        b_g = tl.load(p_g, boundary_check=(0,))\n        b_g_diff = b_g[:, None] - b_g[None, :]\n        b_A = b_A * safe_exp(b_g_diff)\n\n    b_A *= b_beta[:, None]\n    b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)\n    p_A = tl.make_block_ptr(\n        A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)\n    )\n    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef chunk_scaled_dot_kkt_fwd(\n    k: torch.Tensor,\n    beta: torch.Tensor,\n    g_cumsum: Optional[torch.Tensor] = None,\n    cu_seqlens: Optional[torch.LongTensor] = None,\n    chunk_size: int = 64,\n    output_dtype: torch.dtype = torch.float32,\n) -> torch.Tensor:\n    r\"\"\"\n    Compute beta * K * K^T.\n\n    Args:\n        k (torch.Tensor):\n            The key tensor of shape `[B, T, H, K]`.\n        beta (torch.Tensor):\n            The beta tensor of shape `[B, T, H]`.\n        g_cumsum (torch.Tensor):\n            The cumulative sum of the gate tensor of shape `[B, T, H]`.\n            Default: None\n        cu_seqlens (torch.LongTensor):\n            The cumulative sequence lengths of the input tensor.\n            Default: None\n        chunk_size (int):\n            The chunk size. Default: 64.\n        output_dtype (torch.dtype):\n            The dtype of the output tensor. 
Default: `torch.float32`\n\n    Returns:\n        beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.\n    \"\"\"\n\n    B, T, Hg, K = k.shape\n\n    H = beta.shape[-1]\n    BT = chunk_size\n    chunk_indices = (\n        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None\n    )\n    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)\n    A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)\n    chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](\n        k=k,\n        beta=beta,\n        g_cumsum=g_cumsum,\n        A=A,\n        cu_seqlens=cu_seqlens,\n        chunk_indices=chunk_indices,\n        T=T,\n        H=H,\n        Hg=Hg,\n        K=K,\n        BT=BT,\n        BK=64,\n        IS_VARLEN=cu_seqlens is not None,\n        USE_G=g_cumsum is not None,\n        num_warps=8,\n        num_stages=3,\n    )\n    return A\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/cumsum.py",
    "content": "# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/cumsum.py\n# -*- coding: utf-8 -*-\n# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang\n\nfrom typing import Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.attention.fla.index import prepare_chunk_indices\nfrom sglang.srt.layers.attention.fla.utils import check_shared_mem, input_guard\n\nBS_LIST = [32, 64] if check_shared_mem() else [16, 32]\n\n\n# @triton.autotune(\n#     configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],\n#     key=[\"B\", \"H\", \"BT\", \"IS_VARLEN\", \"REVERSE\"],\n# )\n@triton.jit(do_not_specialize=[\"T\"])\ndef chunk_local_cumsum_scalar_kernel(\n    s,\n    o,\n    scale,\n    cu_seqlens,\n    chunk_indices,\n    T,\n    B: tl.constexpr,\n    H: tl.constexpr,\n    BT: tl.constexpr,\n    REVERSE: tl.constexpr,\n    HAS_SCALE: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n    HEAD_FIRST: tl.constexpr,\n):\n    i_t, i_bh = tl.program_id(0), tl.program_id(1)\n    i_b, i_h = i_bh // H, i_bh % H\n    if IS_VARLEN:\n        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(\n            chunk_indices + i_t * 2 + 1\n        ).to(tl.int32)\n        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(\n            cu_seqlens + i_n + 1\n        ).to(tl.int32)\n        T = eos - bos\n    else:\n        bos, eos = i_b * T, i_b * T + T\n\n    if HEAD_FIRST:\n        p_s = tl.make_block_ptr(\n            s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)\n        )\n        p_o = tl.make_block_ptr(\n            o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)\n        )\n    else:\n        p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))\n        p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))\n    # [BT]\n    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)\n    b_o = 
tl.cumsum(b_s, axis=0)\n    if REVERSE:\n        b_z = tl.sum(b_s, axis=0)\n        b_o = -b_o + b_z[None] + b_s\n    if HAS_SCALE:\n        b_o *= scale\n    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({\"BS\": BS}, num_warps=num_warps)\n        for BS in BS_LIST\n        for num_warps in [2, 4, 8]\n    ],\n    key=[\"B\", \"H\", \"S\", \"BT\", \"IS_VARLEN\", \"REVERSE\", \"HAS_SCALE\"],\n)\n@triton.jit(do_not_specialize=[\"T\"])\ndef chunk_local_cumsum_vector_kernel(\n    s,\n    o,\n    scale,\n    cu_seqlens,\n    chunk_indices,\n    T,\n    B: tl.constexpr,\n    H: tl.constexpr,\n    S: tl.constexpr,\n    BT: tl.constexpr,\n    BS: tl.constexpr,\n    REVERSE: tl.constexpr,\n    HAS_SCALE: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n    HEAD_FIRST: tl.constexpr,\n):\n    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n    i_b, i_h = i_bh // H, i_bh % H\n    if IS_VARLEN:\n        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(\n            chunk_indices + i_t * 2 + 1\n        ).to(tl.int32)\n        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(\n            cu_seqlens + i_n + 1\n        ).to(tl.int32)\n        T = eos - bos\n    else:\n        bos, eos = i_b * T, i_b * T + T\n\n    o_i = tl.arange(0, BT)\n    if REVERSE:\n        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0)\n    else:\n        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0)\n\n    if HEAD_FIRST:\n        p_s = tl.make_block_ptr(\n            s + (bos * H + i_h * T) * S,\n            (T, S),\n            (S, 1),\n            (i_t * BT, i_s * BS),\n            (BT, BS),\n            (1, 0),\n        )\n        p_o = tl.make_block_ptr(\n            o + (bos * H + i_h * T) * S,\n            (T, S),\n            (S, 1),\n            (i_t * BT, i_s * BS),\n            (BT, BS),\n            (1, 0),\n        )\n    else:\n        p_s = 
tl.make_block_ptr(\n            s + (bos * H + i_h) * S,\n            (T, S),\n            (H * S, 1),\n            (i_t * BT, i_s * BS),\n            (BT, BS),\n            (1, 0),\n        )\n        p_o = tl.make_block_ptr(\n            o + (bos * H + i_h) * S,\n            (T, S),\n            (H * S, 1),\n            (i_t * BT, i_s * BS),\n            (BT, BS),\n            (1, 0),\n        )\n    # [BT, BS]\n    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n    b_o = tl.dot(m_s, b_s, allow_tf32=False)\n    if HAS_SCALE:\n        b_o *= scale\n    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef chunk_local_cumsum_scalar(\n    g: torch.Tensor,\n    chunk_size: int,\n    reverse: bool = False,\n    scale: float = None,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    head_first: bool = False,\n    output_dtype: Optional[torch.dtype] = torch.float,\n) -> torch.Tensor:\n    if head_first:\n        B, H, T = g.shape\n    else:\n        B, T, H = g.shape\n    assert chunk_size == 2 ** (\n        chunk_size.bit_length() - 1\n    ), \"chunk_size must be a power of 2\"\n    BT = chunk_size\n    chunk_indices = (\n        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None\n    )\n    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)\n    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)\n    grid = (NT, B * H)\n    chunk_local_cumsum_scalar_kernel[grid](\n        s=g_org,\n        o=g,\n        scale=scale,\n        cu_seqlens=cu_seqlens,\n        chunk_indices=chunk_indices,\n        T=T,\n        B=B,\n        H=H,\n        BT=BT,\n        HEAD_FIRST=head_first,\n        REVERSE=reverse,\n        HAS_SCALE=scale is not None,\n        IS_VARLEN=cu_seqlens is not None,\n        num_warps=8,\n        num_stages=3,\n    )\n    return g\n\n\ndef chunk_local_cumsum_vector(\n    g: torch.Tensor,\n    chunk_size: int,\n    reverse: bool = False,\n    scale: float = None,\n    
cu_seqlens: Optional[torch.Tensor] = None,\n    head_first: bool = False,\n    output_dtype: Optional[torch.dtype] = torch.float,\n) -> torch.Tensor:\n    if head_first:\n        B, H, T, S = g.shape\n    else:\n        B, T, H, S = g.shape\n    BT = chunk_size\n    chunk_indices = (\n        prepare_chunk_indices(cu_seqlens, chunk_size)\n        if cu_seqlens is not None\n        else None\n    )\n    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)\n    assert chunk_size == 2 ** (\n        chunk_size.bit_length() - 1\n    ), \"chunk_size must be a power of 2\"\n\n    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)\n\n    def grid(meta):\n        return (triton.cdiv(meta[\"S\"], meta[\"BS\"]), NT, B * H)\n\n    # keep cumulative normalizer in fp32\n    # this kernel is equivalent to\n    # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)\n    chunk_local_cumsum_vector_kernel[grid](\n        s=g_org,\n        o=g,\n        scale=scale,\n        cu_seqlens=cu_seqlens,\n        chunk_indices=chunk_indices,\n        T=T,\n        B=B,\n        H=H,\n        S=S,\n        BT=BT,\n        HEAD_FIRST=head_first,\n        REVERSE=reverse,\n        HAS_SCALE=scale is not None,\n        IS_VARLEN=cu_seqlens is not None,\n    )\n    return g\n\n\n@input_guard\ndef chunk_local_cumsum(\n    g: torch.Tensor,\n    chunk_size: int,\n    reverse: bool = False,\n    scale: float = None,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    head_first: bool = False,\n    output_dtype: Optional[torch.dtype] = torch.float,\n    **kwargs,\n) -> torch.Tensor:\n    if cu_seqlens is not None:\n        assert (\n            g.shape[0] == 1\n        ), \"Only batch size 1 is supported when cu_seqlens are provided\"\n    if len(g.shape) == 3:\n        return chunk_local_cumsum_scalar(\n            g=g,\n            chunk_size=chunk_size,\n            reverse=reverse,\n            scale=scale,\n            cu_seqlens=cu_seqlens,\n            
head_first=head_first,\n            output_dtype=output_dtype,\n        )\n    elif len(g.shape) == 4:\n        return chunk_local_cumsum_vector(\n            g=g,\n            chunk_size=chunk_size,\n            reverse=reverse,\n            scale=scale,\n            cu_seqlens=cu_seqlens,\n            head_first=head_first,\n            output_dtype=output_dtype,\n        )\n    else:\n        raise ValueError(\n            f\"Unsupported input shape {g.shape}, \"\n            f\"which should be (B, T, H, D) if `head_first=False` \"\n            f\"or (B, H, T, D) otherwise\"\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/fused_gdn_gating.py",
    "content": "from typing import Tuple\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)\n# beta_output = b.sigmoid()\n@triton.jit\ndef fused_gdn_gating_kernel(\n    g,\n    beta_output,\n    A_log,\n    a,\n    b,\n    dt_bias,\n    seq_len,\n    NUM_HEADS: tl.constexpr,\n    beta: tl.constexpr,\n    threshold: tl.constexpr,\n    BLK_HEADS: tl.constexpr,\n):\n    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)\n    off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off\n    mask = head_off < NUM_HEADS\n    blk_A_log = tl.load(A_log + head_off, mask=mask)\n    blk_a = tl.load(a + off, mask=mask)\n    blk_b = tl.load(b + off, mask=mask)\n    blk_bias = tl.load(dt_bias + head_off, mask=mask)\n    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)\n    softplus_x = tl.where(\n        beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x\n    )\n    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x\n    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)\n    blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))\n    tl.store(beta_output + off, blk_beta_output.to(b.dtype.element_ty), mask=mask)\n\n\ndef fused_gdn_gating(\n    A_log: torch.Tensor,\n    a: torch.Tensor,\n    b: torch.Tensor,\n    dt_bias: torch.Tensor,\n    beta: float = 1.0,\n    threshold: float = 20.0,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    batch, num_heads = a.shape\n    seq_len = 1\n    grid = (batch, seq_len, triton.cdiv(num_heads, 8))\n    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)\n    beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)\n    fused_gdn_gating_kernel[grid](\n        g,\n        beta_output,\n        A_log,\n        a,\n        b,\n        dt_bias,\n        seq_len,\n        num_heads,\n        beta,\n        
threshold,\n        8,\n        num_warps=1,\n    )\n    return g, beta_output\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/fused_norm_gate.py",
    "content": "# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/fused_norm_gate.py\n# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang\n\n\nimport torch\nimport torch.nn as nn\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.utils import (\n    cdiv,\n    cpu_has_amx_support,\n    is_cpu,\n    is_npu,\n    next_power_of_2,\n)\n\n_is_npu = is_npu()\n_use_cpu = is_cpu() and cpu_has_amx_support()\n\n# Maximum rows per Triton block for layernorm gated kernel\nMAX_ROWS_PER_BLOCK = 4\n\n\n@triton.jit\ndef layer_norm_gated_fwd_kernel(\n    x,  # pointer to the input\n    g,  # pointer to the gate\n    y,  # pointer to the output\n    w,  # pointer to the weights\n    b,  # pointer to the biases\n    residual,  # pointer to the residual\n    residual_out,  # pointer to the residual\n    mean,  # pointer to the mean\n    rstd,  # pointer to the 1/std\n    eps,  # epsilon to avoid division by zero\n    T,  # number of rows in x\n    D: tl.constexpr,  # number of columns in x\n    BT: tl.constexpr,\n    BD: tl.constexpr,\n    ACTIVATION: tl.constexpr,\n    IS_RMS_NORM: tl.constexpr,\n    STORE_RESIDUAL_OUT: tl.constexpr,\n    HAS_RESIDUAL: tl.constexpr,\n    HAS_WEIGHT: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n):\n    i_t = tl.program_id(0)\n\n    o_d = tl.arange(0, BD)\n    m_d = o_d < D\n\n    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))\n    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)\n    if HAS_RESIDUAL:\n        p_res = tl.make_block_ptr(\n            residual, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)\n        )\n        b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32)\n    if STORE_RESIDUAL_OUT:\n        p_res_out = tl.make_block_ptr(\n            residual_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)\n        )\n        tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1))\n    if not IS_RMS_NORM:\n        b_mean = 
tl.sum(b_x, axis=1) / D\n        p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,))\n        tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,))\n        b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0)\n        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D\n    else:\n        b_xbar = tl.where(m_d[None, :], b_x, 0.0)\n        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D\n    b_rstd = 1 / tl.sqrt(b_var + eps)\n\n    p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,))\n    tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,))\n\n    if HAS_WEIGHT:\n        b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)\n    if HAS_BIAS:\n        b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)\n    b_x_hat = (\n        (b_x - b_mean[:, None]) * b_rstd[:, None]\n        if not IS_RMS_NORM\n        else b_x * b_rstd[:, None]\n    )\n    b_y = b_x_hat * b_w[None, :] if HAS_WEIGHT else b_x_hat\n    if HAS_BIAS:\n        b_y = b_y + b_b[None, :]\n\n    # swish/sigmoid output gate\n    p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))\n    b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)\n    if ACTIVATION == \"swish\" or ACTIVATION == \"silu\":\n        b_y = b_y * b_g * tl.sigmoid(b_g)\n    elif ACTIVATION == \"sigmoid\":\n        b_y = b_y * tl.sigmoid(b_g)\n\n    # Write output\n    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))\n    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef layer_norm_gated_fwd_kernel1(\n    x,  # pointer to the input\n    g,  # pointer to the gate\n    y,  # pointer to the output\n    w,  # pointer to the weights\n    b,  # pointer to the biases\n    residual,  # pointer to the residual\n    residual_out,  # pointer to the residual\n    mean,  # pointer to the mean\n    rstd,  # pointer to the 1/std\n    eps,  # epsilon to avoid division by zero\n    D: tl.constexpr,  # 
number of columns in x\n    BD: tl.constexpr,\n    ACTIVATION: tl.constexpr,\n    IS_RMS_NORM: tl.constexpr,\n    STORE_RESIDUAL_OUT: tl.constexpr,\n    HAS_RESIDUAL: tl.constexpr,\n    HAS_WEIGHT: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n):\n    i_t = tl.program_id(0)\n    x += i_t * D\n    y += i_t * D\n    g += i_t * D\n    if HAS_RESIDUAL:\n        residual += i_t * D\n    if STORE_RESIDUAL_OUT:\n        residual_out += i_t * D\n\n    o_d = tl.arange(0, BD)\n    m_d = o_d < D\n    b_x = tl.load(x + o_d, mask=m_d, other=0.0).to(tl.float32)\n    if HAS_RESIDUAL:\n        b_x += tl.load(residual + o_d, mask=m_d, other=0.0).to(tl.float32)\n    if STORE_RESIDUAL_OUT:\n        tl.store(residual_out + o_d, b_x, mask=m_d)\n    if not IS_RMS_NORM:\n        b_mean = tl.sum(b_x, axis=0) / D\n        tl.store(mean + i_t, b_mean)\n        b_xbar = tl.where(m_d, b_x - b_mean, 0.0)\n        b_var = tl.sum(b_xbar * b_xbar, axis=0) / D\n    else:\n        b_xbar = tl.where(m_d, b_x, 0.0)\n        b_var = tl.sum(b_xbar * b_xbar, axis=0) / D\n    b_rstd = 1 / tl.sqrt(b_var + eps)\n    tl.store(rstd + i_t, b_rstd)\n\n    if HAS_WEIGHT:\n        b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)\n    if HAS_BIAS:\n        b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)\n    b_x_hat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd\n    b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat\n    if HAS_BIAS:\n        b_y = b_y + b_b\n\n    # swish/sigmoid output gate\n    b_g = tl.load(g + o_d, mask=m_d, other=0.0).to(tl.float32)\n    if ACTIVATION == \"swish\" or ACTIVATION == \"silu\":\n        b_y = b_y * b_g * tl.sigmoid(b_g)\n    elif ACTIVATION == \"sigmoid\":\n        b_y = b_y * tl.sigmoid(b_g)\n\n    # Write output\n    tl.store(y + o_d, b_y, mask=m_d)\n\n\ndef layer_norm_gated_fwd(\n    x: torch.Tensor,\n    g: torch.Tensor,\n    weight: torch.Tensor,\n    bias: torch.Tensor,\n    activation: str = \"swish\",\n    eps: float = 1e-5,\n    residual: torch.Tensor = 
None,\n    out_dtype: torch.dtype = None,\n    residual_dtype: torch.dtype = None,\n    is_rms_norm: bool = False,\n):\n    if residual is not None:\n        residual_dtype = residual.dtype\n    T, D = x.shape\n    if residual is not None:\n        assert residual.shape == (T, D)\n    if weight is not None:\n        assert weight.shape == (D,)\n    if bias is not None:\n        assert bias.shape == (D,)\n    # allocate output\n    y = x if out_dtype is None else torch.empty_like(x, dtype=out_dtype)\n    if residual is not None or (\n        residual_dtype is not None and residual_dtype != x.dtype\n    ):\n        residual_out = torch.empty(T, D, device=x.device, dtype=residual_dtype)\n    else:\n        residual_out = None\n    mean = (\n        torch.empty((T,), dtype=torch.float, device=x.device)\n        if not is_rms_norm\n        else None\n    )\n    rstd = torch.empty((T,), dtype=torch.float, device=x.device)\n    # Less than 64KB per feature: enqueue fused kernel\n    MAX_FUSED_SIZE = 65536 // x.element_size()\n    BD = min(MAX_FUSED_SIZE, next_power_of_2(D))\n    if D > BD:\n        raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n    # heuristics for number of warps\n\n    if D <= 512:\n        BT = 32\n        layer_norm_gated_fwd_kernel[(cdiv(T, BT),)](\n            x=x,\n            g=g,\n            y=y,\n            w=weight,\n            b=bias,\n            residual=residual,\n            residual_out=residual_out,\n            mean=mean,\n            rstd=rstd,\n            eps=eps,\n            T=T,\n            D=D,\n            BD=BD,\n            BT=BT,\n            ACTIVATION=activation,\n            IS_RMS_NORM=is_rms_norm,\n            STORE_RESIDUAL_OUT=residual_out is not None,\n            HAS_RESIDUAL=residual is not None,\n            HAS_WEIGHT=weight is not None,\n            HAS_BIAS=bias is not None,\n            num_warps=4,\n        )\n    else:\n        layer_norm_gated_fwd_kernel1[(T,)](\n           
 x=x,\n            g=g,\n            y=y,\n            w=weight,\n            b=bias,\n            residual=residual,\n            residual_out=residual_out,\n            mean=mean,\n            rstd=rstd,\n            eps=eps,\n            D=D,\n            BD=BD,\n            ACTIVATION=activation,\n            IS_RMS_NORM=is_rms_norm,\n            STORE_RESIDUAL_OUT=residual_out is not None,\n            HAS_RESIDUAL=residual is not None,\n            HAS_WEIGHT=weight is not None,\n            HAS_BIAS=bias is not None,\n            num_warps=4,\n        )\n    # residual_out is None if residual is None and residual_dtype == input_dtype\n    return y, mean, rstd, residual_out if residual_out is not None else x\n\n\nclass LayerNormGatedFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        x: torch.Tensor,\n        g: torch.Tensor,\n        weight: torch.Tensor,\n        bias: torch.Tensor,\n        activation: str,\n        residual: torch.Tensor | None = None,\n        eps: float = 1e-6,\n        prenorm: bool = False,\n        residual_in_fp32: bool = False,\n        is_rms_norm: bool = False,\n    ):\n        x_shape_og = x.shape\n        g_shape_og = g.shape\n        # reshape input data into 2D tensor\n        x = x.reshape(-1, x.shape[-1])\n        g = g.reshape(-1, g.shape[-1])\n        if residual is not None:\n            assert residual.shape == x_shape_og\n            residual = residual.reshape(-1, residual.shape[-1])\n        residual_dtype = (\n            residual.dtype\n            if residual is not None\n            else (torch.float if residual_in_fp32 else None)\n        )\n        y, mean, rstd, residual_out = layer_norm_gated_fwd(\n            x=x,\n            g=g,\n            weight=weight,\n            bias=bias,\n            activation=activation,\n            eps=eps,\n            residual=residual,\n            residual_dtype=residual_dtype,\n            is_rms_norm=is_rms_norm,\n        )\n   
     ctx.save_for_backward(residual_out, g, weight, bias, mean, rstd)\n        ctx.x_shape_og = x_shape_og\n        ctx.g_shape_og = g_shape_og\n        ctx.activation = activation\n        ctx.eps = eps\n        ctx.is_rms_norm = is_rms_norm\n        ctx.has_residual = residual is not None\n        ctx.prenorm = prenorm\n        ctx.x_dtype = x.dtype\n        y = y.reshape(x_shape_og)\n        return y if not prenorm else (y, residual_out.reshape(x_shape_og))\n\n\ndef rms_norm_gated(\n    x: torch.Tensor,\n    g: torch.Tensor,\n    weight: torch.Tensor,\n    bias: torch.Tensor,\n    activation: str = \"swish\",\n    residual: torch.Tensor | None = None,\n    prenorm: bool = False,\n    residual_in_fp32: bool = False,\n    eps: float = 1e-6,\n):\n    return LayerNormGatedFunction.apply(\n        x,\n        g,\n        weight,\n        bias,\n        activation,\n        residual,\n        eps,\n        prenorm,\n        residual_in_fp32,\n        True,\n    )\n\n\nclass FusedRMSNormGated(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int,\n        elementwise_affine: bool = True,\n        eps: float = 1e-5,\n        activation: str = \"swish\",\n        device: torch.device | None = None,\n        dtype: torch.dtype | None = None,\n    ) -> None:\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n\n        self.hidden_size = hidden_size\n        self.elementwise_affine = elementwise_affine\n        self.eps = eps\n        self.activation = activation\n\n        if self.activation not in [\"swish\", \"silu\", \"sigmoid\"]:\n            raise ValueError(f\"Unsupported activation: {self.activation}\")\n\n        if elementwise_affine:\n            self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        else:\n            self.register_parameter(\"weight\", None)\n        self.register_parameter(\"bias\", None)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        
g: torch.Tensor,\n        residual: torch.Tensor | None = None,\n        prenorm: bool = False,\n        residual_in_fp32: bool = False,\n    ) -> torch.Tensor:\n        return rms_norm_gated(\n            x,\n            g,\n            self.weight,\n            self.bias,\n            self.activation,\n            residual=residual,\n            eps=self.eps,\n            prenorm=prenorm,\n            residual_in_fp32=residual_in_fp32,\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/fused_recurrent.py",
    "content": "# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/fused_recurrent.py\n# -*- coding: utf-8 -*-\n# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang\n\nfrom typing import Optional, Tuple\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.attention.fla.op import exp\nfrom sglang.srt.layers.attention.fla.utils import input_guard\n\n\n@triton.jit(do_not_specialize=[\"T\"])\ndef fused_recurrent_gated_delta_rule_fwd_kernel(\n    q,\n    k,\n    v,\n    g,\n    beta,\n    o,\n    h0,\n    ht,\n    cu_seqlens,\n    scale,\n    T,\n    B: tl.constexpr,\n    H: tl.constexpr,\n    HV: tl.constexpr,\n    K: tl.constexpr,\n    V: tl.constexpr,\n    BK: tl.constexpr,\n    BV: tl.constexpr,\n    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state\n    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state\n    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,\n    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n    IS_KDA: tl.constexpr,\n):\n    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n    i_n, i_hv = i_nh // HV, i_nh % HV\n    i_h = i_hv // (HV // H)\n    if IS_VARLEN:\n        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(\n            cu_seqlens + i_n + 1\n        ).to(tl.int64)\n        all = T\n        T = eos - bos\n    else:\n        bos, eos = i_n * T, i_n * T + T\n        all = B * T\n    o_k = i_k * BK + tl.arange(0, BK)\n    o_v = i_v * BV + tl.arange(0, BV)\n\n    p_q = q + (bos * H + i_h) * K + o_k\n    p_k = k + (bos * H + i_h) * K + o_k\n    p_v = v + (bos * HV + i_hv) * V + o_v\n    if IS_BETA_HEADWISE:\n        p_beta = beta + (bos * HV + i_hv) * V + o_v\n    else:\n        p_beta = beta + bos * HV + i_hv\n    if not IS_KDA:\n        p_g = g + bos * HV + i_hv\n    else:\n        p_gk = g + (bos * H + i_h) * K + o_k\n\n    p_o = o + ((i_k * all + bos) * 
HV + i_hv) * V + o_v\n\n    mask_k = o_k < K\n    mask_v = o_v < V\n    mask_h = mask_v[:, None] & mask_k[None, :]\n\n    b_h = tl.zeros([BV, BK], dtype=tl.float32)\n    if USE_INITIAL_STATE:\n        p_h0 = h0 + i_nh * V * K + o_v[:, None] * K + o_k[None, :]\n        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)\n\n    for _ in range(0, T):\n        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)\n        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)\n        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)\n\n        if USE_QK_L2NORM_IN_KERNEL:\n            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))\n            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))\n        b_q = b_q * scale\n        # [BV, BK]\n        if not IS_KDA:\n            b_g = tl.load(p_g).to(tl.float32)\n            b_h *= exp(b_g)\n        else:\n            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)\n            b_h *= exp(b_gk[None, :])\n        # [BV]\n        b_v -= tl.sum(b_h * b_k[None, :], 1)\n        if IS_BETA_HEADWISE:\n            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)\n        else:\n            b_beta = tl.load(p_beta).to(tl.float32)\n        b_v *= b_beta\n        # [BV, BK]\n        b_h += b_v[:, None] * b_k[None, :]\n        # [BV]\n        b_o = tl.sum(b_h * b_q[None, :], 1)\n        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)\n\n        p_q += H * K\n        p_k += H * K\n        p_o += HV * V\n        p_v += HV * V\n        if not IS_KDA:\n            p_g += HV\n        else:\n            p_gk += H * K\n        p_beta += HV * (V if IS_BETA_HEADWISE else 1)\n\n    if STORE_FINAL_STATE:\n        p_ht = ht + i_nh * V * K + o_v[:, None] * K + o_k[None, :]\n        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)\n\n\ndef fused_recurrent_gated_delta_rule_fwd(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    g: torch.Tensor,\n    beta: torch.Tensor,\n    
scale: float,\n    initial_state: torch.Tensor,\n    output_final_state: bool,\n    use_qk_l2norm_in_kernel: bool = False,\n    cu_seqlens: Optional[torch.LongTensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    B, T, H, K, V = *k.shape, v.shape[-1]\n    HV = v.shape[2]\n    N = B if cu_seqlens is None else len(cu_seqlens) - 1\n    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)\n    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n    assert NK == 1, \"NK > 1 is not supported yet\"\n    num_stages = 3\n    num_warps = 1\n\n    o = q.new_empty(NK, *v.shape)\n    if output_final_state:\n        final_state = q.new_empty(N, HV, V, K, dtype=torch.float32)\n    else:\n        final_state = None\n\n    grid = (NK, NV, N * HV)\n    fused_recurrent_gated_delta_rule_fwd_kernel[grid](\n        q=q,\n        k=k,\n        v=v,\n        g=g,\n        beta=beta,\n        o=o,\n        h0=initial_state,\n        ht=final_state,\n        cu_seqlens=cu_seqlens,\n        scale=scale,\n        T=T,\n        B=B,\n        H=H,\n        HV=HV,\n        K=K,\n        V=V,\n        BK=BK,\n        BV=BV,\n        USE_INITIAL_STATE=initial_state is not None,\n        STORE_FINAL_STATE=final_state is not None,\n        IS_BETA_HEADWISE=beta.ndim == v.ndim,\n        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,\n        IS_VARLEN=cu_seqlens is not None,\n        IS_KDA=False,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    o = o.squeeze(0)\n    return o, final_state\n\n\n# Adapted from vllm project.\n@triton.jit\ndef fused_recurrent_gated_delta_rule_packed_decode_kernel(\n    mixed_qkv,\n    a,\n    b,\n    A_log,\n    dt_bias,\n    o,\n    h0,\n    ht,\n    ssm_state_indices,\n    scale,\n    stride_mixed_qkv_tok: tl.constexpr,\n    stride_a_tok: tl.constexpr,\n    stride_b_tok: tl.constexpr,\n    stride_init_state_token: tl.constexpr,\n    stride_final_state_token: tl.constexpr,\n    stride_indices_seq: tl.constexpr,\n    H: 
tl.constexpr,\n    HV: tl.constexpr,\n    K: tl.constexpr,\n    V: tl.constexpr,\n    BK: tl.constexpr,\n    BV: tl.constexpr,\n    SOFTPLUS_THRESHOLD: tl.constexpr,\n    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,\n):\n    i_v, i_nh = tl.program_id(0), tl.program_id(1)\n    i_n, i_hv = i_nh // HV, i_nh % HV\n    i_h = i_hv // (HV // H)\n\n    o_k = tl.arange(0, BK)\n    o_v = i_v * BV + tl.arange(0, BV)\n    mask_k = o_k < K\n    mask_v = o_v < V\n    mask_h = mask_v[:, None] & mask_k[None, :]\n\n    state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)\n    p_o = o + (i_n * HV + i_hv) * V + o_v\n\n    if state_idx < 0:\n        zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)\n        tl.store(p_o, zero, mask=mask_v)\n        return\n\n    p_h0 = h0 + state_idx * stride_init_state_token\n    p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]\n    b_h = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)\n\n    p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok\n    q_off = i_h * K + o_k\n    k_off = (H * K) + i_h * K + o_k\n    v_off = (2 * H * K) + i_hv * V + o_v\n    b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32)\n    b_k = tl.load(p_mixed + k_off, mask=mask_k, other=0).to(tl.float32)\n    b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32)\n\n    if USE_QK_L2NORM_IN_KERNEL:\n        b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)\n        b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)\n    b_q = b_q * scale\n\n    a_val = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32)\n    b_val = tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32)\n    A_log_val = tl.load(A_log + i_hv).to(tl.float32)\n    dt_bias_val = tl.load(dt_bias + i_hv).to(tl.float32)\n    x = a_val + dt_bias_val\n    softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x)\n    g_val = -tl.exp(A_log_val) * softplus_x\n    beta_val = tl.sigmoid(b_val).to(b.dtype.element_ty).to(tl.float32)\n\n    b_h 
*= exp(g_val)\n    b_v -= tl.sum(b_h * b_k[None, :], 1)\n    b_v *= beta_val\n    b_h += b_v[:, None] * b_k[None, :]\n    b_o = tl.sum(b_h * b_q[None, :], 1)\n    tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)\n\n    p_ht = ht + state_idx * stride_final_state_token\n    p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]\n    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)\n\n\ndef fused_recurrent_gated_delta_rule_packed_decode(\n    mixed_qkv: torch.Tensor,\n    a: torch.Tensor,\n    b: torch.Tensor,\n    A_log: torch.Tensor,\n    dt_bias: torch.Tensor,\n    scale: float,\n    initial_state: torch.Tensor,\n    out: torch.Tensor,\n    ssm_state_indices: torch.Tensor,\n    use_qk_l2norm_in_kernel: bool = False,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    if mixed_qkv.ndim != 2:\n        raise ValueError(\n            f\"`mixed_qkv` must be a 2D tensor (got ndim={mixed_qkv.ndim}).\"\n        )\n    if mixed_qkv.stride(-1) != 1:\n        raise ValueError(\"`mixed_qkv` must be contiguous in the last dim.\")\n    if a.ndim != 2 or b.ndim != 2:\n        raise ValueError(\n            f\"`a` and `b` must be 2D tensors (got a.ndim={a.ndim}, b.ndim={b.ndim}).\"\n        )\n    if a.stride(-1) != 1 or b.stride(-1) != 1:\n        raise ValueError(\"`a`/`b` must be contiguous in the last dim.\")\n    if A_log.ndim != 1 or dt_bias.ndim != 1:\n        raise ValueError(\"`A_log`/`dt_bias` must be 1D tensors.\")\n    if A_log.stride(0) != 1 or dt_bias.stride(0) != 1:\n        raise ValueError(\"`A_log`/`dt_bias` must be contiguous.\")\n    if ssm_state_indices.ndim != 1:\n        raise ValueError(\n            f\"`ssm_state_indices` must be 1D for packed decode (got ndim={ssm_state_indices.ndim}).\"\n        )\n    if not out.is_contiguous():\n        raise ValueError(\"`out` must be contiguous.\")\n\n    dev = mixed_qkv.device\n    if any(\n        t.device != dev\n        for t in (a, b, A_log, dt_bias, initial_state, out, ssm_state_indices)\n    
):\n        raise ValueError(\"All inputs must be on the same device.\")\n\n    B = mixed_qkv.shape[0]\n    if a.shape[0] != B or b.shape[0] != B:\n        raise ValueError(\n            \"Mismatched batch sizes: \"\n            f\"mixed_qkv.shape[0]={B}, a.shape[0]={a.shape[0]}, b.shape[0]={b.shape[0]}.\"\n        )\n    if ssm_state_indices.shape[0] != B:\n        raise ValueError(\n            f\"`ssm_state_indices` must have shape [B] (got {tuple(ssm_state_indices.shape)}; expected ({B},)).\"\n        )\n\n    if initial_state.ndim != 4:\n        raise ValueError(\n            f\"`initial_state` must be a 4D tensor (got ndim={initial_state.ndim}).\"\n        )\n    if initial_state.stride(-1) != 1:\n        raise ValueError(\"`initial_state` must be contiguous in the last dim.\")\n    HV, V, K = initial_state.shape[-3:]\n    if a.shape[1] != HV or b.shape[1] != HV:\n        raise ValueError(\n            f\"`a`/`b` must have shape [B, HV] with HV={HV} (got a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)}).\"\n        )\n    if A_log.numel() != HV or dt_bias.numel() != HV:\n        raise ValueError(\n            f\"`A_log` and `dt_bias` must have {HV} elements (got A_log.numel()={A_log.numel()}, dt_bias.numel()={dt_bias.numel()}).\"\n        )\n    if out.shape != (B, 1, HV, V):\n        raise ValueError(\n            f\"`out` must have shape {(B, 1, HV, V)} (got out.shape={tuple(out.shape)}).\"\n        )\n\n    qkv_dim = mixed_qkv.shape[1]\n    qk_dim = qkv_dim - HV * V\n    if qk_dim <= 0 or qk_dim % 2 != 0:\n        raise ValueError(\n            f\"Invalid packed `mixed_qkv` last dim={qkv_dim} for HV={HV}, V={V}.\"\n        )\n    q_dim = qk_dim // 2\n    if q_dim % K != 0:\n        raise ValueError(f\"Invalid packed Q size {q_dim}: must be divisible by K={K}.\")\n    H = q_dim // K\n    if H <= 0 or HV % H != 0:\n        raise ValueError(\n            f\"Invalid head config inferred from mixed_qkv: H={H}, HV={HV}.\"\n        )\n\n    BK = 
triton.next_power_of_2(K)\n    if triton.cdiv(K, BK) != 1:\n        raise ValueError(\n            f\"Packed decode kernel only supports NK=1 (got K={K}, BK={BK}).\"\n        )\n    BV = min(triton.next_power_of_2(V), 32)\n    num_stages = 3\n    num_warps = 1\n\n    stride_mixed_qkv_tok = mixed_qkv.stride(0)\n    stride_a_tok = a.stride(0)\n    stride_b_tok = b.stride(0)\n    stride_init_state_token = initial_state.stride(0)\n    stride_final_state_token = initial_state.stride(0)\n    stride_indices_seq = ssm_state_indices.stride(0)\n\n    NV = triton.cdiv(V, BV)\n    grid = (NV, B * HV)\n    fused_recurrent_gated_delta_rule_packed_decode_kernel[grid](\n        mixed_qkv=mixed_qkv,\n        a=a,\n        b=b,\n        A_log=A_log,\n        dt_bias=dt_bias,\n        o=out,\n        h0=initial_state,\n        ht=initial_state,\n        ssm_state_indices=ssm_state_indices,\n        scale=scale,\n        stride_mixed_qkv_tok=stride_mixed_qkv_tok,\n        stride_a_tok=stride_a_tok,\n        stride_b_tok=stride_b_tok,\n        stride_init_state_token=stride_init_state_token,\n        stride_final_state_token=stride_final_state_token,\n        stride_indices_seq=stride_indices_seq,\n        H=H,\n        HV=HV,\n        K=K,\n        V=V,\n        BK=BK,\n        BV=BV,\n        SOFTPLUS_THRESHOLD=20.0,\n        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    return out, initial_state\n\n\nclass FusedRecurrentFunction(torch.autograd.Function):\n\n    @staticmethod\n    @input_guard\n    def forward(\n        ctx,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        g: torch.Tensor,\n        beta: torch.Tensor,\n        scale: float,\n        initial_state: torch.Tensor,\n        output_final_state: bool,\n        cu_seqlens: Optional[torch.LongTensor] = None,\n        use_qk_l2norm_in_kernel: bool = False,\n    ):\n        o, final_state = 
fused_recurrent_gated_delta_rule_fwd(\n            q=q,\n            k=k,\n            v=v,\n            g=g,\n            beta=beta,\n            scale=scale,\n            initial_state=initial_state,\n            output_final_state=output_final_state,\n            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,\n            cu_seqlens=cu_seqlens,\n        )\n\n        return o, final_state\n\n    @staticmethod\n    @input_guard\n    def backward(ctx, do, dht):\n        raise NotImplementedError(\n            \"Backward pass is not implemented yet and we do not have plans to implement it \"\n            \"because we haven't figured out how to compute dg without materializing the full \"\n            \"hidden states for all time steps.\"\n        )\n\n\ndef fused_recurrent_gated_delta_rule(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    g: torch.Tensor,\n    beta: torch.Tensor = None,\n    scale: float = None,\n    initial_state: torch.Tensor = None,\n    output_final_state: bool = False,\n    cu_seqlens: Optional[torch.LongTensor] = None,\n    use_qk_l2norm_in_kernel: bool = False,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    r\"\"\"\n    Args:\n        q (torch.Tensor):\n            queries of shape `[B, T, H, K]`.\n        k (torch.Tensor):\n            keys of shape `[B, T, H, K]`.\n        v (torch.Tensor):\n            values of shape `[B, T, HV, V]`.\n            GVA is applied if `HV > H`.\n        g (torch.Tensor):\n            g (decays) of shape `[B, T, HV]`.\n        beta (torch.Tensor):\n            betas of shape `[B, T, HV]`.\n        scale (Optional[float]):\n            Scale factor for the RetNet attention scores.\n            If not provided, it will default to `1 / sqrt(K)`. 
Default: `None`.\n        initial_state (Optional[torch.Tensor]):\n            Initial state of shape `[N, HV, V, K]` for `N` input sequences.\n            For equal-length input sequences, `N` equals the batch size `B`.\n            Default: `None`.\n        output_final_state (Optional[bool]):\n            Whether to output the final state of shape `[N, HV, V, K]`. Default: `False`.\n        cu_seqlens (Optional[torch.LongTensor]):\n            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,\n            consistent with the FlashAttention API.\n    Returns:\n        o (torch.Tensor):\n            Outputs of shape `[B, T, HV, V]`.\n        final_state (torch.Tensor):\n            Final state of shape `[N, HV, V, K]` if `output_final_state=True` else `None`.\n    Examples::\n        >>> import torch\n        >>> import torch.nn.functional as F\n        >>> from einops import rearrange\n        >>> from sglang.srt.layers.attention.fla.fused_recurrent import fused_recurrent_gated_delta_rule\n        # inputs with equal lengths\n        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512\n        >>> q = torch.randn(B, T, H, K, device='cuda')\n        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)\n        >>> v = torch.randn(B, T, HV, V, device='cuda')\n        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))\n        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()\n        >>> h0 = torch.randn(B, HV, V, K, device='cuda')\n        >>> o, ht = fused_recurrent_gated_delta_rule(\n            q, k, v, g, beta,\n            initial_state=h0,\n            output_final_state=True\n        )\n        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required\n        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... 
-> 1 (b t) ...'), (q, k, v, g, beta))\n        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions is expected\n        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)\n        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(\n            q, k, v, g, beta,\n            initial_state=h0,\n            output_final_state=True,\n            cu_seqlens=cu_seqlens\n        )\n    \"\"\"\n    if cu_seqlens is not None:\n        if q.shape[0] != 1:\n            raise ValueError(\n                f\"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. \"\n                \"Please flatten variable-length inputs before processing.\"\n            )\n        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:\n            raise ValueError(\n                f\"The number of initial states is expected to be equal to the number of input sequences, \"\n                f\"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.\"\n            )\n    if scale is None:\n        scale = k.shape[-1] ** -0.5\n    else:\n        assert scale > 0, \"scale must be positive\"\n    if beta is None:\n        beta = torch.ones_like(v[..., 0])\n    o, final_state = FusedRecurrentFunction.apply(\n        q,\n        k,\n        v,\n        g,\n        beta,\n        scale,\n        initial_state,\n        output_final_state,\n        cu_seqlens,\n        use_qk_l2norm_in_kernel,\n    )\n    return o, final_state\n\n\n# HAS_EAGLE_TREE_CUSTOM_ATTN_MASK is added to support eagle tree attention mask\n# retrieve_parent_token_ptr: [N, NP2_T], retrieve_next_sibling_ptr: [N, NP2_T]\n# e.g. 
for a sequence of length 4, the eagle tree attention structure is:\n# retrieve_next_token=[1, 3, -1, -1] -> retrieve_next_token[i]: the 1st child token of token i\n# retrieve_next_sibling=[-1, 2, -1, -1] -> retrieve_next_sibling[i]: the 1st tree sibling token of token i\n# retrieve_parent_token=[n/a, 0, 0, 1] -> retrieve_parent_token[i]: the parent token of token i\n# Tree:\n#    0\n#   / \\\n#  1   2\n# /\n# 3\n# When calculating token 3's attention, it should attend to token 1 (parent) and token 0 (grand-parent)\n# When calculating token 2's attention, it should attend to token 0 (parent)\n@triton.jit(do_not_specialize=[\"T\"])\ndef fused_recurrent_gated_delta_rule_update_fwd_kernel(\n    q,\n    k,\n    v,\n    g,\n    beta,\n    o,\n    h0_source,\n    h0_indices,\n    cu_seqlens,\n    scale,\n    intermediate_states_buffer,\n    intermediate_state_indices,\n    cache_steps,\n    retrieve_parent_token_ptr,\n    stride_retrieve_parent_token_seq: tl.constexpr,\n    stride_retrieve_parent_token_token: tl.constexpr,\n    T,\n    NP2_T: tl.constexpr,\n    B: tl.constexpr,\n    H: tl.constexpr,\n    HV: tl.constexpr,\n    K: tl.constexpr,\n    V: tl.constexpr,\n    BK: tl.constexpr,\n    BV: tl.constexpr,\n    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state\n    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,\n    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n    DISABLE_STATE_UPDATE: tl.constexpr,  # whether to disable final state update\n    DISABLE_OUTPUT_CALCULATION: tl.constexpr,  # whether to disable output calculation\n    CACHE_INTERMEDIATE_STATES: tl.constexpr,\n    HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr,\n):\n    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n    i_n, i_hv = i_nh // HV, i_nh % HV\n    i_h = i_hv // (HV // H)\n    if IS_VARLEN:\n        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(\n            cu_seqlens + i_n + 1\n        
).to(tl.int64)\n        all = T\n        T = eos - bos\n    else:\n        bos, eos = i_n * T, i_n * T + T\n        all = B * T\n    o_k = i_k * BK + tl.arange(0, BK)\n    o_v = i_v * BV + tl.arange(0, BV)\n\n    p_q = q + (bos * H + i_h) * K + o_k\n    p_k = k + (bos * H + i_h) * K + o_k\n    p_v = v + (bos * HV + i_hv) * V + o_v\n    if IS_BETA_HEADWISE:\n        p_beta = beta + (bos * HV + i_hv) * V + o_v\n    else:\n        p_beta = beta + bos * HV + i_hv\n    p_g = g + bos * HV + i_hv\n    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v\n\n    if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:\n        token_indices = tl.arange(0, NP2_T)\n        mask_retrieve = token_indices < T\n        retrieve_parent_token_base = (\n            retrieve_parent_token_ptr\n            + (i_n * stride_retrieve_parent_token_seq)\n            + token_indices * stride_retrieve_parent_token_token\n        )\n        parent_idx_tokens = tl.load(retrieve_parent_token_base, mask_retrieve)\n\n    mask_k = o_k < K\n    mask_v = o_v < V\n    mask_h = mask_v[:, None] & mask_k[None, :]\n\n    b_h = tl.zeros([BV, BK], dtype=tl.float32)\n    if USE_INITIAL_STATE:\n        idx = tl.load(h0_indices + i_n)\n        # Add bounds checking for idx\n        if idx >= 0:  # Assuming negative indices are invalid\n            p_h0 = (\n                h0_source\n                + idx * HV * K * V\n                + i_hv * K * V\n                + o_v[:, None] * K\n                + o_k[None, :]\n            )\n            b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)\n\n    # Prepare intermediate state cache variables if enabled\n    cache_idx = -1\n    if CACHE_INTERMEDIATE_STATES:\n        cache_idx = tl.load(intermediate_state_indices + i_n)\n\n    step_idx = 0\n    for _ in range(0, T):\n        if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:\n            # step_idx = 0 should use the b_h from USE_INITIAL_STATE\n            if step_idx != 0 and cache_idx >= 0:\n                # when calculating current 
step's attention, load the state from the parent token\n                parent_step_idx = tl.sum(\n                    tl.where(token_indices == step_idx, parent_idx_tokens, 0)\n                )\n                step_offset = parent_step_idx * HV * K * V\n                cache_ptr = (\n                    intermediate_states_buffer\n                    + cache_idx * cache_steps * HV * K * V\n                    + step_offset\n                    + i_hv * K * V\n                    + o_v[:, None] * K\n                    + o_k[None, :]\n                )\n                b_h = tl.load(cache_ptr, mask=mask_h, other=0).to(tl.float32)\n\n        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)\n        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)\n        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)\n        b_g = tl.load(p_g).to(tl.float32)\n\n        if USE_QK_L2NORM_IN_KERNEL:\n            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))\n            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))\n        b_q = b_q * scale\n        # [BV, BK]\n        b_h *= exp(b_g)\n        # [BV]\n        b_v -= tl.sum(b_h * b_k[None, :], 1)\n        if IS_BETA_HEADWISE:\n            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)\n        else:\n            b_beta = tl.load(p_beta).to(tl.float32)\n        b_v *= b_beta\n        # [BV, BK]\n        b_h += b_v[:, None] * b_k[None, :]\n        # [BV]\n        if not DISABLE_OUTPUT_CALCULATION:\n            b_o = tl.sum(b_h * b_q[None, :], 1)\n            # core attn output\n            tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)\n\n        # store intermediate states if enabled\n        if CACHE_INTERMEDIATE_STATES:\n            if cache_idx >= 0:\n                # Compute cache pointer for this step\n                step_offset = step_idx * HV * K * V\n                cache_ptr = (\n                    intermediate_states_buffer\n                    + cache_idx * 
cache_steps * HV * K * V\n                    + step_offset\n                    + i_hv * K * V\n                    + o_v[:, None] * K\n                    + o_k[None, :]\n                )\n                tl.store(cache_ptr, b_h.to(cache_ptr.dtype.element_ty), mask=mask_h)\n\n        step_idx += 1\n\n        p_q += H * K\n        p_k += H * K\n        p_o += HV * V\n        p_v += HV * V\n        p_g += HV\n        p_beta += HV * (V if IS_BETA_HEADWISE else 1)\n\n    # Store final state back to h0_source with bounds checking\n    # ssm states\n    if not DISABLE_STATE_UPDATE:\n        idx = tl.load(h0_indices + i_n)\n        if idx >= 0:  # Add bounds checking\n            p_h0 = (\n                h0_source\n                + idx * HV * K * V\n                + i_hv * K * V\n                + o_v[:, None] * K\n                + o_k[None, :]\n            )\n            tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)\n\n\ndef fused_recurrent_gated_delta_rule_update_fwd(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    g: torch.Tensor,\n    beta: torch.Tensor,\n    scale: float,\n    initial_state_source: torch.Tensor,\n    initial_state_indices: torch.Tensor,\n    use_qk_l2norm_in_kernel: bool = False,\n    cu_seqlens: Optional[torch.LongTensor] = None,\n    disable_state_update: bool = False,\n    disable_output_calculation: bool = False,\n    intermediate_states_buffer: Optional[torch.Tensor] = None,\n    intermediate_state_indices: Optional[torch.Tensor] = None,\n    cache_steps: Optional[int] = None,\n    retrieve_parent_token: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    B, T, H, K, V = *k.shape, v.shape[-1]\n    HV = v.shape[2]\n    N = B if cu_seqlens is None else len(cu_seqlens) - 1\n    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)\n    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n    assert NK == 1, \"NK > 1 is not supported yet\"\n    num_stages = 3\n    num_warps = 1\n\n    if 
disable_output_calculation:\n        # When output calculation is disabled, allocate minimal tensor\n        o = q.new_empty(NK, 1, 1, 1, 1)  # minimal allocation\n    else:\n        o = q.new_empty(NK, *v.shape)\n\n    grid = (NK, NV, N * HV)\n\n    # prepare retrieve parent token buffer strides if provided\n    if retrieve_parent_token is not None:\n        stride_retrieve_parent_token_seq, stride_retrieve_parent_token_token = (\n            retrieve_parent_token.stride(0),\n            retrieve_parent_token.stride(1),\n        )\n    else:\n        stride_retrieve_parent_token_seq = stride_retrieve_parent_token_token = 0\n\n    NP2_T = triton.next_power_of_2(T)\n    fused_recurrent_gated_delta_rule_update_fwd_kernel[grid](\n        q=q,\n        k=k,\n        v=v,\n        g=g,\n        beta=beta,\n        o=o,\n        h0_source=initial_state_source,\n        h0_indices=initial_state_indices,\n        cu_seqlens=cu_seqlens,\n        scale=scale,\n        intermediate_states_buffer=intermediate_states_buffer,\n        intermediate_state_indices=intermediate_state_indices,\n        cache_steps=0 if cache_steps is None else cache_steps,\n        retrieve_parent_token_ptr=retrieve_parent_token,\n        stride_retrieve_parent_token_seq=stride_retrieve_parent_token_seq,\n        stride_retrieve_parent_token_token=stride_retrieve_parent_token_token,\n        T=T,\n        NP2_T=NP2_T,\n        B=B,\n        H=H,\n        HV=HV,\n        K=K,\n        V=V,\n        BK=BK,\n        BV=BV,\n        USE_INITIAL_STATE=initial_state_source is not None,\n        IS_VARLEN=cu_seqlens is not None,\n        CACHE_INTERMEDIATE_STATES=intermediate_states_buffer is not None,\n        HAS_EAGLE_TREE_CUSTOM_ATTN_MASK=retrieve_parent_token is not None,\n        IS_BETA_HEADWISE=beta.ndim == v.ndim,\n        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,\n        DISABLE_STATE_UPDATE=disable_state_update,\n        DISABLE_OUTPUT_CALCULATION=disable_output_calculation,\n        
num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    o = o.squeeze(0)\n    return o\n\n\nclass FusedRecurrentUpdateFunction(torch.autograd.Function):\n\n    @staticmethod\n    @input_guard\n    def forward(\n        ctx,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        g: torch.Tensor,\n        beta: torch.Tensor,\n        scale: float,\n        initial_state_source: torch.Tensor,\n        initial_state_indices: torch.Tensor,\n        cu_seqlens: Optional[torch.LongTensor] = None,\n        use_qk_l2norm_in_kernel: bool = False,\n        disable_state_update: bool = False,\n        disable_output_calculation: bool = False,\n        intermediate_states_buffer: Optional[torch.Tensor] = None,\n        intermediate_state_indices: Optional[torch.Tensor] = None,\n        cache_steps: Optional[int] = None,\n        retrieve_parent_token: Optional[torch.Tensor] = None,\n    ):\n        o = fused_recurrent_gated_delta_rule_update_fwd(\n            q=q,\n            k=k,\n            v=v,\n            g=g,\n            beta=beta,\n            scale=scale,\n            initial_state_source=initial_state_source,\n            initial_state_indices=initial_state_indices,\n            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,\n            cu_seqlens=cu_seqlens,\n            disable_state_update=disable_state_update,\n            disable_output_calculation=disable_output_calculation,\n            intermediate_states_buffer=intermediate_states_buffer,\n            intermediate_state_indices=intermediate_state_indices,\n            cache_steps=cache_steps,\n            retrieve_parent_token=retrieve_parent_token,\n        )\n\n        return o\n\n    @staticmethod\n    @input_guard\n    def backward(ctx, do, dht):\n        raise NotImplementedError(\n            \"Backward pass is not implemented yet and we do not have plans to implement it \"\n            \"because we haven't figured out how to compute dg without materializing 
the full \"\n            \"hidden states for all time steps.\"\n        )\n\n\ndef fused_recurrent_gated_delta_rule_update(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    g: torch.Tensor,\n    beta: torch.Tensor = None,\n    scale: float = None,\n    initial_state_source: torch.Tensor = None,\n    initial_state_indices: torch.Tensor = None,\n    cu_seqlens: Optional[torch.LongTensor] = None,\n    use_qk_l2norm_in_kernel: bool = False,\n    disable_state_update: bool = False,\n    disable_output_calculation: bool = False,\n    intermediate_states_buffer: Optional[torch.Tensor] = None,\n    intermediate_state_indices: Optional[torch.Tensor] = None,\n    cache_steps: Optional[int] = None,\n    retrieve_parent_token: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    if cu_seqlens is not None:\n        if q.shape[0] != 1:\n            raise ValueError(\n                f\"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. \"\n                \"Please flatten variable-length inputs before processing.\"\n            )\n        if initial_state_source is not None:\n            if initial_state_indices.shape[0] != len(cu_seqlens) - 1:\n                raise ValueError(\n                    f\"The number of initial states is expected to be equal to the number of input sequences, \"\n                    f\"i.e., {len(cu_seqlens) - 1} rather than {initial_state_indices.shape[0]}.\"\n                )\n            if (\n                intermediate_state_indices is not None\n                and initial_state_indices.shape[0] != intermediate_state_indices.shape[0]\n            ):\n                raise ValueError(\n                    f\"The number of intermediate state indices is expected to be equal to the number of input sequences, \"\n                    f\"i.e., {initial_state_indices.shape[0]} != {intermediate_state_indices.shape[0]}.\"\n                )\n    if scale is None:\n        scale = k.shape[-1] ** -0.5\n    else:\n        assert scale > 0, \"scale must be positive\"\n    if beta is None:\n  
      beta = torch.ones_like(v[..., 0])\n    o = FusedRecurrentUpdateFunction.apply(\n        q,\n        k,\n        v,\n        g,\n        beta,\n        scale,\n        initial_state_source,\n        initial_state_indices,\n        cu_seqlens,\n        use_qk_l2norm_in_kernel,\n        disable_state_update,\n        disable_output_calculation,\n        intermediate_states_buffer,\n        intermediate_state_indices,\n        cache_steps,\n        retrieve_parent_token,\n    )\n    return o\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py",
    "content": "from typing import Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit(do_not_specialize=[\"T\"])\ndef fused_sigmoid_gating_delta_rule_update_kernel(\n    A_log,\n    a,\n    dt_bias,\n    softplus_beta,\n    softplus_threshold,\n    q,\n    k,\n    v,\n    b,\n    o,\n    h0_source,\n    h0_indices,\n    cu_seqlens,\n    # Parameters for target_verify support (unused for decode)\n    intermediate_states_buffer,\n    intermediate_state_indices,\n    cache_steps,\n    retrieve_parent_token_ptr,\n    stride_retrieve_parent_token_seq: tl.constexpr,\n    stride_retrieve_parent_token_token: tl.constexpr,\n    # ================================================\n    scale,\n    T,\n    stride_q,\n    stride_k,\n    stride_v,\n    stride_b,\n    NP2_T: tl.constexpr,\n    B: tl.constexpr,\n    H: tl.constexpr,\n    HV: tl.constexpr,\n    K: tl.constexpr,\n    V: tl.constexpr,\n    BK: tl.constexpr,\n    BV: tl.constexpr,\n    USE_INITIAL_STATE: tl.constexpr,\n    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n    IS_KDA: tl.constexpr,\n    # Optional flags for target_verify support (default False for decode)\n    DISABLE_STATE_UPDATE: tl.constexpr = False,\n    CACHE_INTERMEDIATE_STATES: tl.constexpr = False,\n    HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr = False,\n):\n    \"\"\"\n    Fused kernel that combines sigmoid gating computation with recurrent delta rule update.\n    \"\"\"\n    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n    i_n, i_hv = i_nh // HV, i_nh % HV\n    i_h = i_hv // (HV // H)\n\n    if IS_VARLEN:\n        bos, eos = (\n            tl.load(cu_seqlens + i_n).to(tl.int64),\n            tl.load(cu_seqlens + i_n + 1).to(tl.int64),\n        )\n        all = T\n        T = eos - bos\n    else:\n        bos, eos = i_n * T, i_n * T + T\n        all = B * T\n\n    o_k = i_k * BK + tl.arange(0, BK)\n    o_v = i_v * BV + tl.arange(0, BV)\n\n    p_q = q + bos * 
stride_q + i_h * K + o_k\n    p_k = k + bos * stride_k + i_h * K + o_k\n    p_v = v + bos * stride_v + i_hv * V + o_v\n    p_b = b + bos * stride_b + i_hv\n    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v\n\n    # Gating computation pointers\n    p_A_log = A_log + i_hv\n    if IS_KDA:\n        p_a = a + (bos * HV + i_hv) * K + o_k\n        p_dt_bias = dt_bias + i_hv * K + o_k\n    else:\n        p_a = a + bos * HV + i_hv\n        p_dt_bias = dt_bias + i_hv\n\n    mask_k = o_k < K\n    mask_v = o_v < V\n    mask_h = mask_k[:, None] & mask_v[None, :]\n\n    b_h = tl.zeros([BK, BV], dtype=tl.float32)\n    if USE_INITIAL_STATE:\n        idx = tl.load(h0_indices + i_n)\n        if idx >= 0:\n            p_h0 = (\n                h0_source\n                + idx * HV * K * V\n                + i_hv * K * V\n                + o_v[None, :] * K\n                + o_k[:, None]\n            )\n            b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)\n\n    # Preload tree attention data if needed\n    if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:\n        token_indices = tl.arange(0, NP2_T)\n        mask_retrieve = token_indices < T\n        retrieve_parent_token_base = (\n            retrieve_parent_token_ptr\n            + (i_n * stride_retrieve_parent_token_seq)\n            + token_indices * stride_retrieve_parent_token_token\n        )\n        parent_idx_tokens = tl.load(\n            retrieve_parent_token_base, mask=mask_retrieve, other=0\n        )\n\n    # Prepare intermediate state cache index if enabled\n    cache_idx = -1\n    if CACHE_INTERMEDIATE_STATES:\n        cache_idx = tl.load(intermediate_state_indices + i_n)\n\n    step_idx = 0\n    for _ in range(0, T):\n        # Tree attention: load parent's cached state\n        if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:\n            # step_idx == 0 uses b_h from USE_INITIAL_STATE\n            if step_idx != 0 and cache_idx >= 0:\n                parent_step_idx = tl.sum(\n                    tl.where(token_indices 
== step_idx, parent_idx_tokens, 0)\n                )\n                step_offset = parent_step_idx * HV * K * V\n                cache_ptr = (\n                    intermediate_states_buffer\n                    + cache_idx * cache_steps * HV * K * V\n                    + step_offset\n                    + i_hv * K * V\n                    + o_v[None, :] * K\n                    + o_k[:, None]\n                )\n                b_h = tl.load(cache_ptr, mask=mask_h, other=0).to(tl.float32)\n\n        # Load inputs\n        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)\n        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)\n        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)\n        b_b = tl.load(p_b).to(tl.float32)\n\n        # Compute sigmoid gating\n        # Load gating parameters\n        b_A_log = tl.load(p_A_log).to(tl.float32)\n        if IS_KDA:\n            b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)\n            b_dt_bias = tl.load(p_dt_bias, mask=mask_k, other=0).to(tl.float32)\n        else:\n            b_a = tl.load(p_a).to(tl.float32)\n            b_dt_bias = tl.load(p_dt_bias).to(tl.float32)\n\n        # Compute g = -exp(A_log) * softplus(a + dt_bias)\n        x = b_a + b_dt_bias\n        beta_x = softplus_beta * x\n        # Apply softplus with numerical stability\n        softplus_x = tl.where(\n            beta_x <= softplus_threshold,\n            (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),\n            x,\n        )\n        b_g = -tl.exp(b_A_log) * softplus_x\n\n        # Compute beta = sigmoid(b)\n        b_beta = 1.0 / (1.0 + tl.exp(-b_b))\n\n        # Apply L2 normalization if enabled\n        if USE_QK_L2NORM_IN_KERNEL:\n            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))\n            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))\n\n        b_q = b_q * scale\n\n        # Apply gating to hidden state: h *= exp(g)\n        if IS_KDA:\n            b_h *= tl.exp(b_g[:, 
None])\n        else:\n            b_h *= tl.exp(b_g)\n\n        # Delta rule: v -= sum(h * k, dim=0)\n        b_v -= tl.sum(b_h * b_k[:, None], 0)\n\n        # Apply beta gating: v *= beta\n        b_v *= b_beta\n\n        # Update hidden state: h += k[:, None] * v[None, :]\n        b_h += b_k[:, None] * b_v[None, :]\n\n        # Compute output: o = sum(h * q, dim=0)\n        b_o = tl.sum(b_h * b_q[:, None], 0)\n        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)\n\n        # Cache intermediate states if enabled\n        if CACHE_INTERMEDIATE_STATES:\n            if cache_idx >= 0:\n                step_offset = step_idx * HV * K * V\n                cache_ptr = (\n                    intermediate_states_buffer\n                    + cache_idx * cache_steps * HV * K * V\n                    + step_offset\n                    + i_hv * K * V\n                    + o_v[None, :] * K\n                    + o_k[:, None]\n                )\n                tl.store(cache_ptr, b_h.to(cache_ptr.dtype.element_ty), mask=mask_h)\n\n        step_idx += 1\n\n        # Update pointers for next timestep\n        p_q += stride_q\n        p_k += stride_k\n        p_v += stride_v\n        p_b += stride_b\n        p_o += HV * V\n        if IS_KDA:\n            p_a += HV * K\n        else:\n            p_a += HV\n\n    # Store final state back to h0_source with bounds checking\n    if not DISABLE_STATE_UPDATE:\n        if USE_INITIAL_STATE:\n            idx = tl.load(h0_indices + i_n)\n            if idx >= 0:\n                p_h0 = (\n                    h0_source\n                    + idx * HV * K * V\n                    + i_hv * K * V\n                    + o_v[None, :] * K\n                    + o_k[:, None]\n                )\n                tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)\n\n\ndef fused_sigmoid_gating_delta_rule_update(\n    A_log: torch.Tensor,\n    a: torch.Tensor,\n    dt_bias: torch.Tensor,\n    softplus_beta: float,\n    
softplus_threshold: float,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    b: torch.Tensor,\n    initial_state_source: torch.Tensor,\n    initial_state_indices: torch.Tensor,\n    scale: Optional[float] = None,\n    use_qk_l2norm_in_kernel: bool = False,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    is_kda: bool = False,\n    # Optional parameters for target_verify support\n    disable_state_update: bool = False,\n    intermediate_states_buffer: Optional[torch.Tensor] = None,\n    intermediate_state_indices: Optional[torch.Tensor] = None,\n    cache_steps: Optional[int] = None,\n    retrieve_parent_token: Optional[torch.Tensor] = None,\n):\n    \"\"\"\n    Fused triton implementation of sigmoid gating delta rule update.\n    This function uses a single fused kernel that combines both sigmoid gating computation\n    and the recurrent delta rule update for better performance.\n\n    Supports both decode and target_verify modes:\n    - decode: standard single-step update with state write-back\n    - target_verify: multi-step with intermediate state caching, optional tree attention,\n                     and optional state update disable\n    \"\"\"\n    B, T, H, K, V = *k.shape, v.shape[-1]\n    stride_q = q.stride()[1]\n    stride_k = k.stride()[1]\n    stride_v = v.stride()[1]\n    stride_b = b.stride()[-2]\n    HV = v.shape[2]\n    N = B if cu_seqlens is None else len(cu_seqlens) - 1\n    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)\n    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n    assert NK == 1, \"NK > 1 is not supported yet\"\n    num_stages = 3\n    num_warps = 1\n\n    if scale is None:\n        scale = k.shape[-1] ** -0.5\n    else:\n        assert scale > 0, \"scale must be positive\"\n\n    o = q.new_empty(NK, *v.shape)\n\n    # Prepare retrieve_parent_token strides\n    if retrieve_parent_token is not None:\n        stride_retrieve_parent_token_seq = retrieve_parent_token.stride(0)\n        
stride_retrieve_parent_token_token = retrieve_parent_token.stride(1)\n    else:\n        stride_retrieve_parent_token_seq = 0\n        stride_retrieve_parent_token_token = 0\n\n    NP2_T = triton.next_power_of_2(T)\n\n    grid = (NK, NV, N * HV)\n\n    fused_sigmoid_gating_delta_rule_update_kernel[grid](\n        A_log=A_log,\n        a=a,\n        dt_bias=dt_bias,\n        softplus_beta=softplus_beta,\n        softplus_threshold=softplus_threshold,\n        q=q,\n        k=k,\n        v=v,\n        b=b,\n        o=o,\n        h0_source=initial_state_source,\n        h0_indices=initial_state_indices,\n        cu_seqlens=cu_seqlens,\n        intermediate_states_buffer=intermediate_states_buffer,\n        intermediate_state_indices=intermediate_state_indices,\n        cache_steps=0 if cache_steps is None else cache_steps,\n        retrieve_parent_token_ptr=retrieve_parent_token,\n        stride_retrieve_parent_token_seq=stride_retrieve_parent_token_seq,\n        stride_retrieve_parent_token_token=stride_retrieve_parent_token_token,\n        scale=scale,\n        T=T,\n        stride_q=stride_q,\n        stride_k=stride_k,\n        stride_v=stride_v,\n        stride_b=stride_b,\n        NP2_T=NP2_T,\n        B=B,\n        H=H,\n        HV=HV,\n        K=K,\n        V=V,\n        BK=BK,\n        BV=BV,\n        USE_INITIAL_STATE=initial_state_source is not None,\n        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,\n        IS_VARLEN=cu_seqlens is not None,\n        IS_KDA=is_kda,\n        DISABLE_STATE_UPDATE=disable_state_update,\n        CACHE_INTERMEDIATE_STATES=intermediate_states_buffer is not None,\n        HAS_EAGLE_TREE_CUSTOM_ATTN_MASK=retrieve_parent_token is not None,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    o = o.squeeze(0)\n    return o\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/index.py",
    "content": "# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py\n# -*- coding: utf-8 -*-\n# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang\n\nimport torch\nimport triton\n\nfrom sglang.srt.layers.attention.fla.utils import tensor_cache\n\n\n@tensor_cache\ndef prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:\n    return cu_seqlens[1:] - cu_seqlens[:-1]\n\n\n@tensor_cache\ndef prepare_chunk_indices(\n    cu_seqlens: torch.LongTensor, chunk_size: int\n) -> torch.LongTensor:\n    indices = torch.cat(\n        [\n            torch.arange(n)\n            for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()\n        ]\n    )\n    return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)\n\n\n@tensor_cache\ndef prepare_chunk_offsets(\n    cu_seqlens: torch.LongTensor, chunk_size: int\n) -> torch.LongTensor:\n    return torch.cat(\n        [cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]\n    ).cumsum(-1)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/kda.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/model_executor/layers/fla/ops/kda.py\n# This file contains code copied from the flash-linear-attention project.\n# The original source code was licensed under the MIT license and included\n# the following copyright notice:\n# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h\nfrom sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum\nfrom sglang.srt.layers.attention.fla.fused_norm_gate import layer_norm_gated_fwd\nfrom sglang.srt.layers.attention.fla.fused_recurrent import (\n    fused_recurrent_gated_delta_rule_fwd_kernel,\n)\nfrom sglang.srt.layers.attention.fla.index import prepare_chunk_indices\nfrom sglang.srt.layers.attention.fla.l2norm import l2norm_fwd\nfrom sglang.srt.layers.attention.fla.op import exp, log\nfrom sglang.srt.layers.attention.fla.solve_tril import solve_tril\nfrom sglang.srt.layers.attention.fla.utils import is_amd\n\nBT_LIST_AUTOTUNE = [32, 64, 128]\nNUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]\n\n\ndef cdiv(a: int, b: int) -> int:\n    \"\"\"Ceiling division.\"\"\"\n    return -(a // -b)\n\n\ndef next_power_of_2(n: int) -> int:\n    \"\"\"The next power of 2 (inclusive)\"\"\"\n    if n < 1:\n        return 1\n    return 1 << (n - 1).bit_length()\n\n\ndef fused_recurrent_kda_fwd(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    g: torch.Tensor,\n    beta: torch.Tensor,\n    scale: float,\n    initial_state: torch.Tensor,\n    inplace_final_state: bool = True,\n    cu_seqlens: torch.LongTensor | None = None,\n    # ssm_state_indices: torch.Tensor | None = None,\n    num_accepted_tokens: torch.Tensor | None = None,\n    use_qk_l2norm_in_kernel: bool = False,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    B, T, H, K, V = 
*k.shape, v.shape[-1]\n    HV = v.shape[2]\n    N = B if cu_seqlens is None else len(cu_seqlens) - 1\n    BK, BV = next_power_of_2(K), min(next_power_of_2(V), 8)\n    NK, NV = cdiv(K, BK), cdiv(V, BV)\n    assert NK == 1, \"NK > 1 is not supported yet\"\n    num_stages = 3\n    num_warps = 1\n\n    o = q.new_empty(NK, *v.shape)\n    if inplace_final_state:\n        final_state = initial_state\n    else:\n        final_state = q.new_empty(N, HV, V, K, dtype=initial_state.dtype)\n\n    stride_init_state_token = initial_state.stride(0)\n    stride_final_state_token = final_state.stride(0)\n\n    # if ssm_state_indices is None:\n    #     stride_indices_seq, stride_indices_tok = 1, 1\n    # elif ssm_state_indices.ndim == 1:\n    #     stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1\n    # else:\n    #     stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()\n\n    grid = (NK, NV, N * HV)\n    fused_recurrent_gated_delta_rule_fwd_kernel[grid](\n        q=q,\n        k=k,\n        v=v,\n        g=g,\n        beta=beta,\n        o=o,\n        h0=initial_state,\n        ht=final_state,\n        cu_seqlens=cu_seqlens,\n        # ssm_state_indices=ssm_state_indices,\n        # num_accepted_tokens=num_accepted_tokens,\n        scale=scale,\n        # N=N,\n        T=T,\n        B=B,\n        H=H,\n        HV=HV,\n        K=K,\n        V=V,\n        BK=BK,\n        BV=BV,\n        # stride_init_state_token=stride_init_state_token,\n        # stride_final_state_token=stride_final_state_token,\n        # stride_indices_seq=stride_indices_seq,\n        # stride_indices_tok=stride_indices_tok,\n        USE_INITIAL_STATE=initial_state is not None,\n        STORE_FINAL_STATE=final_state is not None,\n        IS_BETA_HEADWISE=beta.ndim == v.ndim,\n        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,\n        IS_VARLEN=cu_seqlens is not None,\n        # INPLACE_FINAL_STATE=inplace_final_state,\n        IS_KDA=True,\n        
num_warps=num_warps,\n        num_stages=num_stages,\n    )\n\n    o = o.squeeze(0)\n    return o, final_state\n\n\ndef fused_recurrent_kda(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    g: torch.Tensor,\n    beta: torch.Tensor = None,\n    scale: float = None,\n    initial_state: torch.Tensor = None,\n    inplace_final_state: bool = True,\n    use_qk_l2norm_in_kernel: bool = True,\n    cu_seqlens: torch.LongTensor | None = None,\n    # ssm_state_indices: torch.LongTensor | None = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    if cu_seqlens is not None and q.shape[0] != 1:\n        raise ValueError(\n            f\"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. \"\n            f\"Please flatten variable-length inputs before processing.\"\n        )\n    if scale is None:\n        scale = k.shape[-1] ** -0.5\n\n    o, final_state = fused_recurrent_kda_fwd(\n        q=q.contiguous(),\n        k=k.contiguous(),\n        v=v.contiguous(),\n        g=g.contiguous(),\n        beta=beta.contiguous(),\n        scale=scale,\n        initial_state=initial_state,\n        inplace_final_state=inplace_final_state,\n        cu_seqlens=cu_seqlens,\n        # ssm_state_indices=ssm_state_indices,\n        num_accepted_tokens=None,\n        use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,\n    )\n    return o, final_state\n\n\ndef rms_norm_gated(\n    x: torch.Tensor,\n    g: torch.Tensor,\n    weight: torch.Tensor,\n    bias: torch.Tensor,\n    activation: str = \"swish\",\n    residual: torch.Tensor | None = None,\n    prenorm: bool = False,\n    residual_in_fp32: bool = False,\n    eps: float = 1e-6,\n):\n    x_shape_og = x.shape\n    # reshape input data into 2D tensor\n    x = x.contiguous().reshape(-1, x.shape[-1])\n    g = g.contiguous().reshape(-1, g.shape[-1])\n    if residual is not None:\n        assert residual.shape == x_shape_og\n        residual = residual.contiguous().reshape(-1, 
residual.shape[-1])\n    residual_dtype = (\n        residual.dtype\n        if residual is not None\n        else (torch.float if residual_in_fp32 else None)\n    )\n    y, _, _, residual_out = layer_norm_gated_fwd(\n        x=x,\n        g=g,\n        weight=weight,\n        bias=bias,\n        activation=activation,\n        eps=eps,\n        residual=residual,\n        residual_dtype=residual_dtype,\n        is_rms_norm=True,\n    )\n    y = y.reshape(x_shape_og)\n    return y if not prenorm else (y, residual_out.reshape(x_shape_og))\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({\"BK\": BK}, num_warps=num_warps, num_stages=num_stages)\n        for BK in [32, 64]\n        for num_warps in [1, 2, 4, 8]\n        for num_stages in [2, 3, 4]\n    ],\n    key=[\"BC\", \"IS_VARLEN\"],\n)\n@triton.jit(do_not_specialize=[\"T\"])\ndef chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter(\n    q,\n    k,\n    g,\n    beta,\n    A,\n    Aqk,\n    scale,\n    cu_seqlens,\n    chunk_indices,\n    T,\n    H: tl.constexpr,\n    K: tl.constexpr,\n    BT: tl.constexpr,\n    BC: tl.constexpr,\n    BK: tl.constexpr,\n    NC: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n    i_b, i_h = i_bh // H, i_bh % H\n    i_i, i_j = i_c // NC, i_c % NC\n    if IS_VARLEN:\n        i_n, i_t = (\n            tl.load(chunk_indices + i_t * 2).to(tl.int32),\n            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),\n        )\n        bos, eos = (\n            tl.load(cu_seqlens + i_n).to(tl.int32),\n            tl.load(cu_seqlens + i_n + 1).to(tl.int32),\n        )\n        T = eos - bos\n    else:\n        bos, eos = i_b * T, i_b * T + T\n\n    if i_t * BT + i_i * BC >= T:\n        return\n    if i_i <= i_j:\n        return\n\n    q += (bos * H + i_h) * K\n    k += (bos * H + i_h) * K\n    g += (bos * H + i_h) * K\n    A += (bos * H + i_h) * BT\n    Aqk += (bos * H + i_h) * BT\n\n    p_b = 
tl.make_block_ptr(\n        beta + bos * H + i_h, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,)\n    )\n    b_b = tl.load(p_b, boundary_check=(0,))\n\n    b_A = tl.zeros([BC, BC], dtype=tl.float32)\n    b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)\n    for i_k in range(tl.cdiv(K, BK)):\n        p_q = tl.make_block_ptr(\n            q, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)\n        )\n        p_k = tl.make_block_ptr(\n            k, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)\n        )\n        p_g = tl.make_block_ptr(\n            g, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)\n        )\n        b_kt = tl.make_block_ptr(\n            k, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)\n        )\n        p_gk = tl.make_block_ptr(\n            g, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)\n        )\n\n        o_k = i_k * BK + tl.arange(0, BK)\n        m_k = o_k < K\n        # [BK,]\n        b_gn = tl.load(g + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)\n        # [BC, BK]\n        b_g = tl.load(p_g, boundary_check=(0, 1))\n        b_k = tl.load(p_k, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :])\n        # [BK, BC]\n        b_gk = tl.load(p_gk, boundary_check=(0, 1))\n        b_kt = tl.load(b_kt, boundary_check=(0, 1))\n        # [BC, BC]\n        b_ktg = b_kt * exp(b_gn[:, None] - b_gk)\n        b_A += tl.dot(b_k, b_ktg)\n\n        b_q = tl.load(p_q, boundary_check=(0, 1))\n        b_qg = b_q * exp(b_g - b_gn[None, :]) * scale\n        b_Aqk += tl.dot(b_qg, b_ktg)\n\n    b_A *= b_b[:, None]\n\n    p_A = tl.make_block_ptr(\n        A, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)\n    )\n    tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))\n    p_Aqk = tl.make_block_ptr(\n        Aqk, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)\n    )\n    
tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],\n    key=[\"BK\", \"BT\", \"IS_VARLEN\"],\n)\n@triton.jit(do_not_specialize=[\"T\"])\ndef chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra(\n    q,\n    k,\n    g,\n    beta,\n    A,\n    Aqk,\n    scale,\n    cu_seqlens,\n    chunk_indices,\n    T,\n    H: tl.constexpr,\n    K: tl.constexpr,\n    BT: tl.constexpr,\n    BC: tl.constexpr,\n    BK: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n    i_b, i_h = i_bh // H, i_bh % H\n    if IS_VARLEN:\n        i_n, i_t = (\n            tl.load(chunk_indices + i_t * 2).to(tl.int32),\n            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),\n        )\n        bos, eos = (\n            tl.load(cu_seqlens + i_n).to(tl.int32),\n            tl.load(cu_seqlens + i_n + 1).to(tl.int32),\n        )\n        T = eos - bos\n    else:\n        bos, eos = i_b * T, i_b * T + T\n\n    if i_t * BT + i_i * BC >= T:\n        return\n\n    o_i = tl.arange(0, BC)\n    o_k = tl.arange(0, BK)\n    m_k = o_k < K\n    m_A = (i_t * BT + i_i * BC + o_i) < T\n    o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC\n\n    p_q = tl.make_block_ptr(\n        q + (bos * H + i_h) * K,\n        (T, K),\n        (H * K, 1),\n        (i_t * BT + i_i * BC, 0),\n        (BC, BK),\n        (1, 0),\n    )\n    p_k = tl.make_block_ptr(\n        k + (bos * H + i_h) * K,\n        (T, K),\n        (H * K, 1),\n        (i_t * BT + i_i * BC, 0),\n        (BC, BK),\n        (1, 0),\n    )\n    p_g = tl.make_block_ptr(\n        g + (bos * H + i_h) * K,\n        (T, K),\n        (H * K, 1),\n        (i_t * BT + i_i * BC, 0),\n        (BC, BK),\n        (1, 0),\n    )\n    b_q = tl.load(p_q, boundary_check=(0, 1))\n    b_k = tl.load(p_k, boundary_check=(0, 1))\n    b_g = tl.load(p_g, 
boundary_check=(0, 1))\n\n    p_b = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h\n    b_k = b_k * tl.load(p_b, mask=m_A, other=0)[:, None]\n\n    p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k\n    p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k\n\n    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):\n        b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)\n        b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)\n        b_ktg = b_kt[None, :] * exp(b_g - b_gk[None, :])\n        b_A = tl.sum(b_k * b_ktg, 1)\n        b_A = tl.where(o_i > j, b_A, 0.0)\n        b_Aqk = tl.sum(b_q * b_ktg, 1)\n        b_Aqk = tl.where(o_i >= j, b_Aqk * scale, 0.0)\n        tl.store(A + o_A + j, b_A, mask=m_A)\n        tl.store(Aqk + o_A + j, b_Aqk, mask=m_A)\n        p_kt += H * K\n        p_gk += H * K\n\n\ndef chunk_kda_scaled_dot_kkt_fwd(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    gk: torch.Tensor | None = None,\n    beta: torch.Tensor | None = None,\n    scale: float | None = None,\n    cu_seqlens: torch.LongTensor | None = None,\n    chunk_size: int = 64,\n    output_dtype: torch.dtype = torch.float32,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    r\"\"\"\n    Compute beta * K * K^T.\n\n    Args:\n        k (torch.Tensor):\n            The key tensor of shape `[B, T, H, K]`.\n        beta (torch.Tensor):\n            The beta tensor of shape `[B, T, H]`.\n        gk (torch.Tensor):\n            The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.\n        cu_seqlens (torch.LongTensor):\n            The cumulative sequence lengths of the input tensor.\n            Default: None\n        chunk_size (int):\n            The chunk size. Default: 64.\n        output_dtype (torch.dtype):\n            The dtype of the output tensor. 
Default: `torch.float32`\n\n    Returns:\n        beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.\n    \"\"\"\n    B, T, H, K = k.shape\n    assert K <= 256\n    BT = chunk_size\n    chunk_indices = (\n        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None\n    )\n    NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)\n\n    BC = min(16, BT)\n    NC = cdiv(BT, BC)\n    BK = max(next_power_of_2(K), 16)\n    A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)\n    Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)\n    grid = (NT, NC * NC, B * H)\n    chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](\n        q=q,\n        k=k,\n        g=gk,\n        beta=beta,\n        A=A,\n        Aqk=Aqk,\n        scale=scale,\n        cu_seqlens=cu_seqlens,\n        chunk_indices=chunk_indices,\n        T=T,\n        H=H,\n        K=K,\n        BT=BT,\n        BC=BC,\n        NC=NC,\n        IS_VARLEN=cu_seqlens is not None,\n    )\n\n    grid = (NT, NC, B * H)\n    chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](\n        q=q,\n        k=k,\n        g=gk,\n        beta=beta,\n        A=A,\n        Aqk=Aqk,\n        scale=scale,\n        cu_seqlens=cu_seqlens,\n        chunk_indices=chunk_indices,\n        T=T,\n        H=H,\n        K=K,\n        BT=BT,\n        BC=BC,\n        BK=BK,\n        IS_VARLEN=cu_seqlens is not None,\n    )\n    return A, Aqk\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_warps=num_warps, num_stages=num_stages)\n        for num_warps in [2, 4, 8]\n        for num_stages in [2, 3, 4]\n    ],\n    key=[\"H\", \"K\", \"V\", \"BT\", \"BK\", \"BV\", \"IS_VARLEN\"],\n)\n@triton.jit(do_not_specialize=[\"T\"])\ndef recompute_w_u_fwd_kernel(\n    q,\n    k,\n    qg,\n    kg,\n    v,\n    beta,\n    w,\n    u,\n    A,\n    gk,\n    cu_seqlens,\n    chunk_indices,\n    T,\n    H: tl.constexpr,\n    K: tl.constexpr,\n    V: 
tl.constexpr,\n    BT: tl.constexpr,\n    BK: tl.constexpr,\n    BV: tl.constexpr,\n    STORE_QG: tl.constexpr,\n    STORE_KG: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n    DOT_PRECISION: tl.constexpr,\n):\n    i_t, i_bh = tl.program_id(0), tl.program_id(1)\n    i_b, i_h = i_bh // H, i_bh % H\n    if IS_VARLEN:\n        i_n, i_t = (\n            tl.load(chunk_indices + i_t * 2).to(tl.int32),\n            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),\n        )\n        bos, eos = (\n            tl.load(cu_seqlens + i_n).to(tl.int32),\n            tl.load(cu_seqlens + i_n + 1).to(tl.int32),\n        )\n        T = eos - bos\n    else:\n        bos, eos = i_b * T, i_b * T + T\n    p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))\n    b_b = tl.load(p_b, boundary_check=(0,))\n\n    p_A = tl.make_block_ptr(\n        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)\n    )\n    b_A = tl.load(p_A, boundary_check=(0, 1))\n\n    for i_v in range(tl.cdiv(V, BV)):\n        p_v = tl.make_block_ptr(\n            v + (bos * H + i_h) * V,\n            (T, V),\n            (H * V, 1),\n            (i_t * BT, i_v * BV),\n            (BT, BV),\n            (1, 0),\n        )\n        p_u = tl.make_block_ptr(\n            u + (bos * H + i_h) * V,\n            (T, V),\n            (H * V, 1),\n            (i_t * BT, i_v * BV),\n            (BT, BV),\n            (1, 0),\n        )\n        b_v = tl.load(p_v, boundary_check=(0, 1))\n        b_vb = (b_v * b_b[:, None]).to(b_v.dtype)\n        b_u = tl.dot(b_A, b_vb, input_precision=DOT_PRECISION)\n        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))\n\n    for i_k in range(tl.cdiv(K, BK)):\n        p_w = tl.make_block_ptr(\n            w + (bos * H + i_h) * K,\n            (T, K),\n            (H * K, 1),\n            (i_t * BT, i_k * BK),\n            (BT, BK),\n            (1, 0),\n        )\n        p_k = tl.make_block_ptr(\n            k + 
(bos * H + i_h) * K,\n            (T, K),\n            (H * K, 1),\n            (i_t * BT, i_k * BK),\n            (BT, BK),\n            (1, 0),\n        )\n        b_k = tl.load(p_k, boundary_check=(0, 1))\n        b_kb = b_k * b_b[:, None]\n\n        p_gk = tl.make_block_ptr(\n            gk + (bos * H + i_h) * K,\n            (T, K),\n            (H * K, 1),\n            (i_t * BT, i_k * BK),\n            (BT, BK),\n            (1, 0),\n        )\n        b_gk = tl.load(p_gk, boundary_check=(0, 1))\n        b_kb *= exp(b_gk)\n        if STORE_QG:\n            p_q = tl.make_block_ptr(\n                q + (bos * H + i_h) * K,\n                (T, K),\n                (H * K, 1),\n                (i_t * BT, i_k * BK),\n                (BT, BK),\n                (1, 0),\n            )\n            p_qg = tl.make_block_ptr(\n                qg + (bos * H + i_h) * K,\n                (T, K),\n                (H * K, 1),\n                (i_t * BT, i_k * BK),\n                (BT, BK),\n                (1, 0),\n            )\n            b_q = tl.load(p_q, boundary_check=(0, 1))\n            b_qg = b_q * exp(b_gk)\n            tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1))\n        if STORE_KG:\n            last_idx = min(i_t * BT + BT, T) - 1\n\n            o_k = i_k * BK + tl.arange(0, BK)\n            m_k = o_k < K\n            b_gn = tl.load(\n                gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=m_k, other=0.0\n            )\n            b_kg = b_k * exp(b_gn - b_gk)\n\n            p_kg = tl.make_block_ptr(\n                kg + (bos * H + i_h) * K,\n                (T, K),\n                (H * K, 1),\n                (i_t * BT, i_k * BK),\n                (BT, BK),\n                (1, 0),\n            )\n            tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1))\n\n        b_w = tl.dot(b_A, b_kb.to(b_k.dtype))\n        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef 
recompute_w_u_fwd(\n    k: torch.Tensor,\n    v: torch.Tensor,\n    beta: torch.Tensor,\n    A: torch.Tensor,\n    q: torch.Tensor | None = None,\n    gk: torch.Tensor | None = None,\n    cu_seqlens: torch.LongTensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, None, torch.Tensor | None]:\n    B, T, H, K, V = *k.shape, v.shape[-1]\n    BT = A.shape[-1]\n    BK = 64\n    BV = 64\n\n    chunk_indices = (\n        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None\n    )\n    NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)\n\n    w = torch.empty_like(k)\n    u = torch.empty_like(v)\n    kg = torch.empty_like(k) if gk is not None else None\n    recompute_w_u_fwd_kernel[(NT, B * H)](\n        q=q,\n        k=k,\n        qg=None,\n        kg=kg,\n        v=v,\n        beta=beta,\n        w=w,\n        u=u,\n        A=A,\n        gk=gk,\n        cu_seqlens=cu_seqlens,\n        chunk_indices=chunk_indices,\n        T=T,\n        H=H,\n        K=K,\n        V=V,\n        BT=BT,\n        BK=BK,\n        BV=BV,\n        STORE_QG=False,\n        STORE_KG=kg is not None,\n        IS_VARLEN=cu_seqlens is not None,\n        DOT_PRECISION=\"ieee\",\n    )\n    return w, u, None, kg\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({\"BK\": BK, \"BV\": BV}, num_warps=num_warps, num_stages=num_stages)\n        for BK in [32, 64]\n        for BV in [64, 128]\n        for num_warps in [2, 4, 8]\n        for num_stages in [2, 3, 4]\n    ],\n    key=[\"BT\", \"IS_VARLEN\"],\n)\n@triton.jit(do_not_specialize=[\"T\"])\ndef chunk_gla_fwd_kernel_o(\n    q,\n    v,\n    g,\n    h,\n    o,\n    A,\n    cu_seqlens,\n    chunk_indices,\n    scale,\n    T,\n    H: tl.constexpr,\n    K: tl.constexpr,\n    V: tl.constexpr,\n    BT: tl.constexpr,\n    BK: tl.constexpr,\n    BV: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n    i_b, i_h = i_bh // H, i_bh % H\n    if IS_VARLEN:\n        i_tg 
= i_t\n        i_n, i_t = (\n            tl.load(chunk_indices + i_t * 2).to(tl.int32),\n            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),\n        )\n        bos, eos = (\n            tl.load(cu_seqlens + i_n).to(tl.int32),\n            tl.load(cu_seqlens + i_n + 1).to(tl.int32),\n        )\n        T = eos - bos\n        NT = tl.cdiv(T, BT)\n    else:\n        NT = tl.cdiv(T, BT)\n        i_tg = i_b * NT + i_t\n        bos, eos = i_b * T, i_b * T + T\n\n    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]\n\n    b_o = tl.zeros([BT, BV], dtype=tl.float32)\n    for i_k in range(tl.cdiv(K, BK)):\n        p_q = tl.make_block_ptr(\n            q + (bos * H + i_h) * K,\n            (T, K),\n            (H * K, 1),\n            (i_t * BT, i_k * BK),\n            (BT, BK),\n            (1, 0),\n        )\n        p_g = tl.make_block_ptr(\n            g + (bos * H + i_h) * K,\n            (T, K),\n            (H * K, 1),\n            (i_t * BT, i_k * BK),\n            (BT, BK),\n            (1, 0),\n        )\n        p_h = tl.make_block_ptr(\n            h + (i_tg * H + i_h) * V * K,\n            (V, K),\n            (K, 1),\n            (i_v * BV, i_k * BK),\n            (BV, BK),\n            (1, 0),\n        )\n\n        # [BT, BK]\n        b_q = tl.load(p_q, boundary_check=(0, 1))\n        b_q = (b_q * scale).to(b_q.dtype)\n        # [BT, BK]\n        b_g = tl.load(p_g, boundary_check=(0, 1))\n        # [BT, BK]\n        b_qg = (b_q * exp(b_g)).to(b_q.dtype)\n        # [BK, BV]\n        b_h = tl.load(p_h, boundary_check=(0, 1))\n        # works but dkw, owing to divine benevolence\n        # [BT, BV]\n        if i_k >= 0:\n            b_o += tl.dot(b_qg, tl.trans(b_h).to(b_qg.dtype))\n    p_v = tl.make_block_ptr(\n        v + (bos * H + i_h) * V,\n        (T, V),\n        (H * V, 1),\n        (i_t * BT, i_v * BV),\n        (BT, BV),\n        (1, 0),\n    )\n    p_o = tl.make_block_ptr(\n        o + (bos * H + i_h) * V,\n        (T, V),\n      
  (H * V, 1),\n        (i_t * BT, i_v * BV),\n        (BT, BV),\n        (1, 0),\n    )\n    p_A = tl.make_block_ptr(\n        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)\n    )\n    # [BT, BV]\n    b_v = tl.load(p_v, boundary_check=(0, 1))\n    # [BT, BT]\n    b_A = tl.load(p_A, boundary_check=(0, 1))\n    b_A = tl.where(m_s, b_A, 0.0).to(b_v.dtype)\n    b_o += tl.dot(b_A, b_v, allow_tf32=False)\n    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef chunk_gla_fwd_o_gk(\n    q: torch.Tensor,\n    v: torch.Tensor,\n    g: torch.Tensor,\n    A: torch.Tensor,\n    h: torch.Tensor,\n    o: torch.Tensor,\n    scale: float,\n    cu_seqlens: torch.LongTensor | None = None,\n    chunk_size: int = 64,\n):\n    B, T, H, K, V = *q.shape, v.shape[-1]\n    BT = chunk_size\n\n    chunk_indices = (\n        prepare_chunk_indices(cu_seqlens, chunk_size)\n        if cu_seqlens is not None\n        else None\n    )\n    NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)\n\n    def grid(meta):\n        return (cdiv(V, meta[\"BV\"]), NT, B * H)\n\n    chunk_gla_fwd_kernel_o[grid](\n        q=q,\n        v=v,\n        g=g,\n        h=h,\n        o=o,\n        A=A,\n        cu_seqlens=cu_seqlens,\n        chunk_indices=chunk_indices,\n        scale=scale,\n        T=T,\n        H=H,\n        K=K,\n        V=V,\n        BT=BT,\n        IS_VARLEN=cu_seqlens is not None,\n    )\n    return o\n\n\ndef chunk_kda_fwd(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    g: torch.Tensor,\n    beta: torch.Tensor,\n    scale: float,\n    initial_state: torch.Tensor,\n    initial_state_indices: torch.Tensor,\n    cu_seqlens: torch.LongTensor | None = None,\n):\n    chunk_size = 64\n    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)\n    # the intra Aqk is kept in fp32\n    # the computation has very marginal effect on the entire throughput\n    A, Aqk = 
chunk_kda_scaled_dot_kkt_fwd(\n        q=q,\n        k=k,\n        gk=g,\n        beta=beta,\n        scale=scale,\n        cu_seqlens=cu_seqlens,\n        output_dtype=torch.float32,\n    )\n    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)\n    w, u, _, kg = recompute_w_u_fwd(\n        k=k,\n        v=v,\n        beta=beta,\n        A=A,\n        gk=g,\n        cu_seqlens=cu_seqlens,\n    )\n    del A\n    h, v_new = chunk_gated_delta_rule_fwd_h(\n        k=kg,\n        w=w,\n        u=u,\n        gk=g,\n        initial_state=initial_state,\n        initial_state_indices=initial_state_indices,\n        cu_seqlens=cu_seqlens,\n    )\n    del w, u, kg\n    o = chunk_gla_fwd_o_gk(\n        q=q,\n        v=v_new,\n        g=g,\n        A=Aqk,\n        h=h,\n        o=v,\n        scale=scale,\n        cu_seqlens=cu_seqlens,\n        chunk_size=chunk_size,\n    )\n    del Aqk, v_new, h\n    return o\n\n\ndef chunk_kda(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    g: torch.Tensor,\n    beta: torch.Tensor,\n    scale: float = None,\n    initial_state: torch.Tensor = None,\n    initial_state_indices: torch.Tensor = None,\n    use_qk_l2norm_in_kernel: bool = False,\n    cu_seqlens: torch.LongTensor | None = None,\n    **kwargs,\n):\n    if scale is None:\n        scale = k.shape[-1] ** -0.5\n\n    if use_qk_l2norm_in_kernel:\n        q = l2norm_fwd(q.contiguous())\n        k = l2norm_fwd(k.contiguous())\n\n    o = chunk_kda_fwd(\n        q=q,\n        k=k,\n        v=v.contiguous(),\n        g=g.contiguous(),\n        beta=beta.contiguous(),\n        scale=scale,\n        initial_state=initial_state,\n        initial_state_indices=initial_state_indices,\n        cu_seqlens=cu_seqlens,\n    )\n    return o\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({\"BT\": bt}, num_warps=nw, num_stages=ns)\n        for bt in BT_LIST_AUTOTUNE\n        for nw in NUM_WARPS_AUTOTUNE\n        for ns in [2, 3]\n    ],\n    
key=[\"H\", \"D\"],\n)\n@triton.jit\ndef kda_gate_fwd_kernel(\n    g,\n    A,\n    y,\n    g_bias,\n    beta: tl.constexpr,\n    threshold: tl.constexpr,\n    T,\n    H,\n    D: tl.constexpr,\n    BT: tl.constexpr,\n    BD: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n):\n    i_t, i_h = tl.program_id(0), tl.program_id(1)\n    n_t = i_t * BT\n\n    b_a = tl.load(A + i_h).to(tl.float32)\n    b_a = -tl.exp(b_a)\n\n    stride_row = H * D\n    stride_col = 1\n\n    g_ptr = tl.make_block_ptr(\n        base=g + i_h * D,\n        shape=(T, D),\n        strides=(stride_row, stride_col),\n        offsets=(n_t, 0),\n        block_shape=(BT, BD),\n        order=(1, 0),\n    )\n\n    y_ptr = tl.make_block_ptr(\n        base=y + i_h * D,\n        shape=(T, D),\n        strides=(stride_row, stride_col),\n        offsets=(n_t, 0),\n        block_shape=(BT, BD),\n        order=(1, 0),\n    )\n\n    b_g = tl.load(g_ptr, boundary_check=(0, 1)).to(tl.float32)\n\n    if HAS_BIAS:\n        n_d = tl.arange(0, BD)\n        bias_mask = n_d < D\n        b_bias = tl.load(g_bias + i_h * D + n_d, mask=bias_mask, other=0.0).to(\n            tl.float32\n        )\n        b_g = b_g + b_bias[None, :]\n\n    # softplus(x, beta) = (1/beta) * log(1 + exp(beta * x))\n    # When beta * x > threshold, use linear approximation x\n    # Use threshold to switch to linear when beta*x > threshold\n    g_scaled = b_g * beta\n    use_linear = g_scaled > threshold\n    sp = tl.where(use_linear, b_g, (1.0 / beta) * log(1.0 + tl.exp(g_scaled)))\n    b_y = b_a * sp\n\n    tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fused_kda_gate(\n    g: torch.Tensor,\n    A: torch.Tensor,\n    head_k_dim: int,\n    g_bias: torch.Tensor | None = None,\n    beta: float = 1.0,\n    threshold: float = 20.0,\n) -> torch.Tensor:\n    \"\"\"\n    Forward pass for KDA gate:\n      input g: [..., H*D]\n      param A: [H] or [1, 1, H, 1]\n      beta: softplus beta parameter\n      threshold: softplus 
threshold parameter\n      return  : [..., H, D]\n    \"\"\"\n    orig_shape = g.shape[:-1]\n\n    g = g.view(-1, g.shape[-1])\n    T = g.shape[0]\n    HD = g.shape[1]\n    H = A.numel()\n    assert H * head_k_dim == HD\n\n    y = torch.empty_like(g, dtype=torch.float32)\n\n    def grid(meta):\n        return (cdiv(T, meta[\"BT\"]), H)\n\n    kda_gate_fwd_kernel[grid](\n        g,\n        A,\n        y,\n        g_bias,\n        beta,\n        threshold,\n        T,\n        H,\n        head_k_dim,\n        BD=next_power_of_2(head_k_dim),\n        HAS_BIAS=g_bias is not None,\n    )\n\n    y = y.view(*orig_shape, H, head_k_dim)\n    return y\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/l2norm.py",
    "content": "# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/l2norm.py\n# -*- coding: utf-8 -*-\n# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang\n\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.attention.fla.utils import input_guard\n\nBT_LIST = [8, 16, 32, 64, 128]\n\n\n# @triton.autotune(\n#     configs=[\n#         triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]\n#     ],\n#     key=[\"D\"],\n# )\n@triton.jit\ndef l2norm_fwd_kernel1(\n    x,\n    y,\n    D,\n    BD: tl.constexpr,\n    eps,\n):\n    i_t = tl.program_id(0)\n    x += i_t * D\n    y += i_t * D\n    # Compute mean and variance\n    cols = tl.arange(0, BD)\n    mask = cols < D\n    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)\n    b_var = tl.sum(b_x * b_x, axis=0)\n    b_rstd = 1 / tl.sqrt(b_var + eps)\n    # tl.store(Rstd + i_t, rstd)\n    # Normalize and apply linear transformation\n    b_y = b_x * b_rstd\n    tl.store(y + cols, b_y, mask=mask)\n\n\n# @triton.autotune(\n#     configs=[\n#         triton.Config({\"BT\": BT}, num_warps=num_warps)\n#         for num_warps in [1, 2, 4, 8, 16]\n#         for BT in BT_LIST\n#     ],\n#     key=[\"D\", \"NB\"],\n# )\n@triton.jit\ndef l2norm_fwd_kernel(\n    x,\n    y,\n    eps,\n    NB: tl.constexpr,\n    T: tl.constexpr,\n    D: tl.constexpr,\n    BT: tl.constexpr,\n    BD: tl.constexpr,\n):\n    i_t = tl.program_id(0)\n    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))\n    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)\n    b_var = tl.sum(b_x * b_x, axis=1)\n    b_y = b_x / tl.sqrt(b_var + eps)[:, None]\n    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))\n    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef l2norm_fwd(\n    x: torch.Tensor, eps: float = 1e-6, output_dtype: 
Optional[torch.dtype] = None\n):\n    x_shape_og = x.shape\n    x = x.view(-1, x.shape[-1])\n    # allocate output\n    if output_dtype is None:\n        y = torch.empty_like(x)\n    else:\n        y = torch.empty_like(x, dtype=output_dtype)\n    assert y.stride(-1) == 1\n    T, D = x.shape[0], x.shape[-1]\n    # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)\n    # Less than 64KB per feature: enqueue fused kernel\n    MAX_FUSED_SIZE = 65536 // x.element_size()\n    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))\n    if D > BD:\n        raise RuntimeError(\"This layer doesn't support feature dim >= 64KB.\")\n\n    if D <= 512:\n        NB = triton.cdiv(T, 2048)\n\n        def grid(meta):\n            return (triton.cdiv(T, meta[\"BT\"]),)\n\n        l2norm_fwd_kernel[grid](\n            x,\n            y,\n            eps,\n            NB=NB,\n            T=T,\n            D=D,\n            BD=BD,\n            BT=16,\n            num_warps=8,\n            num_stages=3,\n        )\n    else:\n        l2norm_fwd_kernel1[(T,)](\n            x,\n            y,\n            eps=eps,\n            D=D,\n            BD=BD,\n            num_warps=8,\n            num_stages=3,\n        )\n\n    return y.view(x_shape_og)\n\n\nclass L2NormFunction(torch.autograd.Function):\n\n    @staticmethod\n    @input_guard\n    def forward(ctx, x, eps=1e-6, output_dtype=None):\n        return l2norm_fwd(x, eps, output_dtype)\n\n\ndef l2norm(\n    x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None\n) -> torch.Tensor:\n    return L2NormFunction.apply(x, eps, output_dtype)\n\n\nl2_norm = l2norm\n\n\nclass L2Norm(nn.Module):\n\n    def __init__(self, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None):\n        super().__init__()\n        self.eps = eps\n        self.output_dtype = output_dtype\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return l2norm(x, self.eps, self.output_dtype)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/layernorm_gated.py",
    "content": "# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py\n# Copyright (c) 2024, Tri Dao.\n# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.\n# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.\n# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.\n\n\nfrom functools import lru_cache\n\nimport torch\nimport torch.nn.functional as F\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\n\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.utils import (\n    cdiv,\n    cpu_has_amx_support,\n    device_context,\n    is_cpu,\n    is_npu,\n    next_power_of_2,\n)\n\n_is_npu = is_npu()\n_use_cpu = is_cpu() and cpu_has_amx_support()\n\n# Maximum rows per Triton block for layernorm gated kernel\nMAX_ROWS_PER_BLOCK = 4\n\n\ndef rms_norm_ref(\n    x,\n    weight,\n    bias,\n    z=None,\n    eps=1e-6,\n    group_size=None,\n    norm_before_gate=True,\n    upcast=True,\n):\n    dtype = x.dtype\n    N = x.shape[-1]\n    weight = weight.float()\n    bias = bias.float() if bias is not None else None\n    if upcast:\n        x = x.float()\n        z = z.float() if z is not None else z\n    if z is not None and not norm_before_gate:\n        x = x * F.silu(z)\n    if group_size is None:\n        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)\n        out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)\n    else:\n        x_group = rearrange(x, \"... (g d) -> ... g d\", d=group_size)\n        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)\n        out = rearrange(x_group * rstd, \"... g d -> ... 
(g d)\") * weight\n        if bias is not None:\n            out = out + bias\n    if z is not None and norm_before_gate:\n        out *= F.silu(z)\n    return out.to(dtype)\n\n\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n    X,  # pointer to the input\n    Y,  # pointer to the output\n    W,  # pointer to the weights\n    B,  # pointer to the biases\n    Z,  # pointer to the other branch\n    Mean,  # pointer to the mean\n    Rstd,  # pointer to the 1/std\n    stride_x_row,  # how much to increase the pointer when moving by 1 row\n    stride_y_row,\n    stride_z_row,\n    M,  # number of rows in X\n    N: tl.constexpr,  # number of columns in X\n    eps,  # epsilon to avoid division by zero\n    BLOCK_N: tl.constexpr,\n    ROWS_PER_BLOCK: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n    HAS_Z: tl.constexpr,\n    NORM_BEFORE_GATE: tl.constexpr,\n    IS_RMS_NORM: tl.constexpr,\n    ACTIVATION: tl.constexpr,\n):\n    # Map the program id to the starting row of X and Y it should compute.\n    row_start = tl.program_id(0) * ROWS_PER_BLOCK\n    group = tl.program_id(1)\n\n    # Create 2D tile: [ROWS_PER_BLOCK, BLOCK_N]\n    rows = row_start + tl.arange(0, ROWS_PER_BLOCK)\n    cols = tl.arange(0, BLOCK_N)\n\n    # Compute offsets for 2D tile\n    row_offsets = rows[:, None] * stride_x_row\n    col_offsets = cols[None, :] + group * N\n\n    # Base pointers\n    X_base = X + row_offsets + col_offsets\n    Y_base = Y + rows[:, None] * stride_y_row + col_offsets\n\n    # Create mask for valid rows and columns\n    row_mask = rows[:, None] < M\n    col_mask = cols[None, :] < N\n    mask = row_mask & col_mask\n\n    # Load input data with 2D tile\n    x = tl.load(X_base, mask=mask, other=0.0).to(tl.float32)\n\n    if HAS_Z and not NORM_BEFORE_GATE:\n        Z_base = Z + rows[:, None] * stride_z_row + col_offsets\n        z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)\n        if ACTIVATION == \"swish\" or ACTIVATION == \"silu\":\n            x *= z * 
tl.sigmoid(z)\n        elif ACTIVATION == \"sigmoid\":\n            x *= tl.sigmoid(z)\n\n    # Compute mean and variance per row (reduce along axis 1)\n    if not IS_RMS_NORM:\n        mean = tl.sum(x, axis=1) / N  # Shape: [ROWS_PER_BLOCK]\n        # Store mean for each row\n        mean_offsets = group * M + rows\n        mean_mask = rows < M\n        tl.store(Mean + mean_offsets, mean, mask=mean_mask)\n        # Broadcast mean back to 2D for subtraction\n        xbar = tl.where(mask, x - mean[:, None], 0.0)\n        var = tl.sum(xbar * xbar, axis=1) / N  # Shape: [ROWS_PER_BLOCK]\n    else:\n        xbar = tl.where(mask, x, 0.0)\n        var = tl.sum(xbar * xbar, axis=1) / N  # Shape: [ROWS_PER_BLOCK]\n        mean = 0.0  # Placeholder for RMS norm\n\n    rstd = tl.rsqrt(var + eps)  # Shape: [ROWS_PER_BLOCK]\n\n    # Store rstd for each row\n    rstd_offsets = group * M + rows\n    rstd_mask = rows < M\n    tl.store(Rstd + rstd_offsets, rstd, mask=rstd_mask)\n\n    # Load weights and biases (broadcast across rows)\n    w_offsets = cols + group * N\n    w_mask = cols < N\n    w = tl.load(W + w_offsets, mask=w_mask, other=0.0).to(tl.float32)\n\n    if HAS_BIAS:\n        b = tl.load(B + w_offsets, mask=w_mask, other=0.0).to(tl.float32)\n\n    # Normalize and apply linear transformation\n    if not IS_RMS_NORM:\n        x_hat = (x - mean[:, None]) * rstd[:, None]\n    else:\n        x_hat = x * rstd[:, None]\n\n    y = x_hat * w[None, :] + b[None, :] if HAS_BIAS else x_hat * w[None, :]\n\n    if HAS_Z and NORM_BEFORE_GATE:\n        Z_base = Z + rows[:, None] * stride_z_row + col_offsets\n        z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)\n        if ACTIVATION == \"swish\" or ACTIVATION == \"silu\":\n            y *= z * tl.sigmoid(z)\n        elif ACTIVATION == \"sigmoid\":\n            y *= tl.sigmoid(z)\n\n    # Write output\n    tl.store(Y_base, y, mask=mask)\n\n\n@lru_cache\ndef _get_sm_count(device: torch.device) -> int:\n    \"\"\"Get and cache 
the SM count for a given device.\"\"\"\n    props = torch.cuda.get_device_properties(device)\n    return props.multi_processor_count\n\n\ndef calc_rows_per_block(M: int, device: torch.device) -> int:\n    # When piecewise cuda graph is enabled, use a constant value to avoid\n    # torch.compile creating guards on the dynamic batch dimension.\n    try:\n        if not get_global_server_args().disable_piecewise_cuda_graph:\n            return MAX_ROWS_PER_BLOCK\n    except ValueError:\n        # Global server args not initialized (e.g., in unit tests)\n        pass\n    sm_count = _get_sm_count(device)\n    rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count))\n    rows_per_block = min(rows_per_block, MAX_ROWS_PER_BLOCK)\n    return rows_per_block\n\n\ndef _layer_norm_fwd(\n    x,\n    weight,\n    bias,\n    eps,\n    z=None,\n    out=None,\n    group_size=None,\n    norm_before_gate=True,\n    is_rms_norm=False,\n    activation: str = \"swish\",\n):\n    M, N = x.shape\n    if group_size is None:\n        group_size = N\n    assert N % group_size == 0\n    ngroups = N // group_size\n    assert x.stride(-1) == 1\n    if z is not None:\n        assert z.stride(-1) == 1\n        assert z.shape == (M, N)\n    assert weight.shape == (N,)\n    assert weight.stride(-1) == 1\n    if bias is not None:\n        assert bias.stride(-1) == 1\n        assert bias.shape == (N,)\n    # allocate output\n    if out is not None:\n        assert out.shape == x.shape\n    else:\n        out = torch.empty_like(x)\n    assert out.stride(-1) == 1\n    mean = (\n        torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)\n        if not is_rms_norm\n        else None\n    )\n    rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)\n    # Less than 64KB per feature: enqueue fused kernel\n    MAX_FUSED_SIZE = 65536 // x.element_size()\n    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n    if group_size > BLOCK_N:\n        raise 
RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n    # heuristics for number of warps\n    num_warps = min(max(BLOCK_N // 256, 1), 8)\n    # Calculate rows per block based on SM count\n    rows_per_block = calc_rows_per_block(M, x.device)\n    # Update grid to use rows_per_block\n    grid = (cdiv(M, rows_per_block), ngroups)\n    with device_context(x.device):\n        _layer_norm_fwd_1pass_kernel[grid](\n            x,\n            out,\n            weight,\n            bias,\n            z,\n            mean,\n            rstd,\n            x.stride(0),\n            out.stride(0),\n            z.stride(0) if z is not None else 0,\n            M,\n            group_size,\n            eps,\n            BLOCK_N=BLOCK_N,\n            ROWS_PER_BLOCK=rows_per_block,\n            HAS_BIAS=bias is not None,\n            HAS_Z=z is not None,\n            NORM_BEFORE_GATE=norm_before_gate,\n            IS_RMS_NORM=is_rms_norm,\n            num_warps=num_warps,\n            ACTIVATION=activation,\n        )\n    return out, mean, rstd\n\n\nif _is_npu:\n    from sgl_kernel_npu.fla.layernorm_gated import layer_norm_fwd_npu as _layer_norm_fwd\n\n\ndef rms_norm_gated(\n    *,\n    x,\n    weight,\n    bias,\n    z=None,\n    eps=1e-6,\n    group_size=None,\n    norm_before_gate=True,\n    is_rms_norm=False,\n    activation: str = \"swish\",\n):\n    \"\"\"If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))\"\"\"\n\n    x_shape_og = x.shape\n    # reshape input data into 2D tensor\n    x = x.reshape(-1, x.shape[-1])\n    if x.stride(-1) != 1:\n        x = x.contiguous()\n    if z is not None:\n        assert z.shape == x_shape_og\n        z = z.reshape(-1, z.shape[-1])\n        if z.stride(-1) != 1:\n            z = z.contiguous()\n    weight = weight.contiguous()\n    if bias is not None:\n        bias = bias.contiguous()\n    if _is_npu:\n        assert activation == \"swish\", \"NPU only supports swish activation\"\n    
y, mean, rstd = _layer_norm_fwd(\n        x,\n        weight,\n        bias,\n        eps,\n        z=z,\n        group_size=group_size,\n        norm_before_gate=norm_before_gate,\n        is_rms_norm=is_rms_norm,\n        activation=activation,\n    )\n    return y.reshape(x_shape_og)\n\n\nclass LayerNormFn(torch.autograd.Function):\n\n    @staticmethod\n    def forward(\n        ctx,\n        x,\n        weight,\n        bias,\n        z=None,\n        eps=1e-6,\n        group_size=None,\n        norm_before_gate=True,\n        is_rms_norm=False,\n        activation: str = \"swish\",\n    ):\n        return rms_norm_gated(\n            x=x,\n            weight=weight,\n            bias=bias,\n            eps=eps,\n            z=z,\n            group_size=group_size,\n            norm_before_gate=norm_before_gate,\n            is_rms_norm=is_rms_norm,\n            activation=activation,\n        )\n\n\ndef layernorm_fn(\n    x,\n    weight,\n    bias,\n    z=None,\n    eps=1e-6,\n    group_size=None,\n    norm_before_gate=True,\n    is_rms_norm=False,\n    activation: str = \"swish\",\n):\n    return LayerNormFn.apply(\n        x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm, activation\n    )\n\n\nclass LayerNorm(torch.nn.Module):\n\n    def __init__(\n        self,\n        hidden_size,\n        eps=1e-5,\n        group_size=None,\n        norm_before_gate=True,\n        device=None,\n        dtype=None,\n    ):\n        \"\"\"If group_size is not None, we do GroupNorm with each group having group_size elements.\n        group_size=None is equivalent to group_size=hidden_size (i.e. 
there's only 1 group).\n        \"\"\"\n\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.eps = eps\n        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        self.group_size = group_size\n        self.norm_before_gate = norm_before_gate\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.ones_(self.weight)\n        torch.nn.init.zeros_(self.bias)\n\n    def forward(self, x, z=None):\n        \"\"\"If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))\"\"\"\n        return layernorm_fn(\n            x,\n            self.weight,\n            self.bias,\n            z=z,\n            group_size=self.group_size,\n            eps=self.eps,\n            norm_before_gate=self.norm_before_gate,\n            is_rms_norm=False,\n        )\n\n\nclass RMSNorm(torch.nn.Module):\n\n    def __init__(\n        self,\n        hidden_size,\n        eps=1e-5,\n        group_size=None,\n        norm_before_gate=True,\n        device=None,\n        dtype=None,\n        activation: str = \"swish\",\n    ):\n        \"\"\"If group_size is not None, we do GroupNorm with each group having group_size elements.\n        group_size=None is equivalent to group_size=hidden_size (i.e. 
there's only 1 group).\n        \"\"\"\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.eps = eps\n        self.activation = activation\n        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        self.register_parameter(\"bias\", None)\n        self.group_size = group_size\n        self.norm_before_gate = norm_before_gate\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.ones_(self.weight)\n\n    def forward(self, x, z=None):\n        \"\"\"If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))\"\"\"\n        if _use_cpu:\n            assert (\n                self.norm_before_gate\n                and self.group_size is None\n                and self.activation == \"swish\"\n            ), \"CPU rmsnorm_gated currently only supports norm_before_gate=True, group_size=None and the swish activation\"\n            return torch.ops.sgl_kernel.fused_rmsnorm_gated_cpu(\n                x, self.weight, z, self.eps\n            )\n        else:\n            return layernorm_fn(\n                x,\n                self.weight,\n                self.bias,\n                z=z,\n                eps=self.eps,\n                group_size=self.group_size,\n                norm_before_gate=self.norm_before_gate,\n                is_rms_norm=True,\n                activation=self.activation,\n            )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/op.py",
    "content": "# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/op.py\n# -*- coding: utf-8 -*-\n# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang\n\nimport os\n\nimport triton\nimport triton.language as tl\nimport triton.language.extra.libdevice as tldevice\n\nfrom sglang.srt.layers.attention.fla.utils import is_gather_supported\n\nif os.environ.get(\"FLA_USE_FAST_OPS\", \"0\") == \"1\":\n    exp = tldevice.fast_expf\n    exp2 = tldevice.exp2\n    log = tldevice.fast_logf\n    log2 = tldevice.fast_log2f\nelse:\n    exp = tl.exp\n    exp2 = tl.math.exp2\n    log = tl.log\n    log2 = tl.log2\n\n\n@triton.jit\ndef safe_exp(x):\n    return exp(tl.where(x <= 0, x, float(\"-inf\")))\n\n\nif not is_gather_supported:\n\n    @triton.jit\n    def gather(src, index, axis, _builder=None):\n        \"\"\"\n        Placeholder for tl.gather on Triton versions that do not support it.\n        This fallback does not gather anything; it returns None.\n        It exists only so that the Triton compiler can resolve the name.\n        \"\"\"\n        return None\n\nelse:\n    gather = tl.gather\n\n\nif hasattr(triton.language, \"_experimental_make_tensor_descriptor\"):\n    # For Triton 3.3.x\n    make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor\nelif hasattr(triton.language, \"make_tensor_descriptor\"):\n    # For Triton 3.4.x and later\n    make_tensor_descriptor = triton.language.make_tensor_descriptor\nelse:\n    \"\"\"\n    Fallback implementation when TMA is not supported.\n    Returns None to indicate TMA descriptors are unavailable.\n    It exists only so that the Triton compiler can resolve the name.\n    \"\"\"\n\n    @triton.jit\n    def make_tensor_descriptor(\n        base,\n        shape,\n        strides,\n        block_shape,\n        _builder=None,\n    ):\n        return None\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/solve_tril.py",
    "content": "# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/solve_tril.py\n# -*- coding: utf-8 -*-\n# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang\n\nfrom typing import Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.attention.fla.index import prepare_chunk_indices\nfrom sglang.srt.layers.attention.fla.utils import input_guard\n\n\n# @triton.autotune(\n#     configs=[\n#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)\n#         for num_warps in [1, 2, 4, 8]\n#         for num_stages in [2, 3, 4, 5]\n#     ],\n#     key=[\"BT\"],\n# )\n@triton.jit(do_not_specialize=[\"T\"])\ndef solve_tril_16x16_kernel(\n    A,\n    Ad,\n    cu_seqlens,\n    chunk_indices,\n    T,\n    H: tl.constexpr,\n    BT: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    i_t, i_bh = tl.program_id(0), tl.program_id(1)\n    i_b, i_h = i_bh // H, i_bh % H\n    if IS_VARLEN:\n        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(\n            chunk_indices + i_t * 2 + 1\n        ).to(tl.int32)\n        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(\n            cu_seqlens + i_n + 1\n        ).to(tl.int32)\n        T = eos - bos\n    else:\n        bos, eos = i_b * T, i_b * T + T\n\n    A = A + (bos * H + i_h) * BT\n    Ad = Ad + (bos * H + i_h) * 16\n\n    offset = (i_t * 16) % BT\n    p_A = tl.make_block_ptr(\n        A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)\n    )\n    p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0))\n    b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)\n    b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)\n\n    o_i = tl.arange(0, 16)\n    for i in range(1, min(16, T - i_t * 16)):\n        b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)\n        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)\n        mask = o_i == i\n        b_A 
= tl.where(mask[:, None], b_a, b_A)\n    b_A += o_i[:, None] == o_i[None, :]\n    tl.store(\n        p_Ai,\n        b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n\n\n# @triton.autotune(\n#     configs=[\n#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)\n#         for num_warps in [1, 2, 4, 8]\n#         for num_stages in [2, 3, 4, 5]\n#     ],\n#     key=[\"H\", \"BT\", \"IS_VARLEN\"],\n# )\n@triton.jit(do_not_specialize=[\"T\"])\ndef merge_16x16_to_32x32_inverse_kernel(\n    A,\n    Ad,\n    Ai,\n    cu_seqlens,\n    chunk_indices,\n    T,\n    H: tl.constexpr,\n    BT: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    i_t, i_bh = tl.program_id(0), tl.program_id(1)\n    i_b, i_h = i_bh // H, i_bh % H\n    if IS_VARLEN:\n        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(\n            chunk_indices + i_t * 2 + 1\n        ).to(tl.int32)\n        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(\n            cu_seqlens + i_n + 1\n        ).to(tl.int32)\n        T = eos - bos\n    else:\n        bos, eos = i_b * T, i_b * T + T\n\n    A += (bos * H + i_h) * 32\n    Ad += (bos * H + i_h) * 16\n    Ai += (bos * H + i_h) * 32\n\n    p_A_21 = tl.make_block_ptr(\n        A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)\n    )\n    p_Ad_11 = tl.make_block_ptr(\n        Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0)\n    )\n    p_Ad_22 = tl.make_block_ptr(\n        Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)\n    )\n    p_Ai_11 = tl.make_block_ptr(\n        Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0)\n    )\n    p_Ai_22 = tl.make_block_ptr(\n        Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)\n    )\n    p_Ai_21 = tl.make_block_ptr(\n        Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)\n    )\n\n    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)\n    Ai_11 
= tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)\n    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)\n    Ai_21 = -tl.dot(\n        tl.dot(Ai_22, A_21, input_precision=\"ieee\"), Ai_11, input_precision=\"ieee\"\n    )\n    tl.store(\n        p_Ai_11,\n        Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_22,\n        Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_21,\n        Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n\n\n# @triton.autotune(\n#     configs=[\n#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)\n#         for num_warps in [2, 4, 8]\n#         for num_stages in [2, 3, 4, 5]\n#     ],\n#     key=[\"H\", \"BT\", \"IS_VARLEN\"],\n# )\n@triton.jit(do_not_specialize=[\"T\"])\ndef merge_16x16_to_64x64_inverse_kernel(\n    A,\n    Ad,\n    Ai,\n    cu_seqlens,\n    chunk_indices,\n    T,\n    H: tl.constexpr,\n    BT: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    i_t, i_bh = tl.program_id(0), tl.program_id(1)\n    i_b, i_h = i_bh // H, i_bh % H\n    if IS_VARLEN:\n        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(\n            chunk_indices + i_t * 2 + 1\n        ).to(tl.int32)\n        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(\n            cu_seqlens + i_n + 1\n        ).to(tl.int32)\n        T = eos - bos\n    else:\n        bos, eos = i_b * T, i_b * T + T\n\n    A += (bos * H + i_h) * 64\n    Ad += (bos * H + i_h) * 16\n    Ai += (bos * H + i_h) * 64\n\n    p_A_21 = tl.make_block_ptr(\n        A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)\n    )\n    p_A_32 = tl.make_block_ptr(\n        A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)\n    )\n    p_A_31 = tl.make_block_ptr(\n        A, (T, 64), (H 
* 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)\n    )\n    p_A_43 = tl.make_block_ptr(\n        A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)\n    )\n    p_A_42 = tl.make_block_ptr(\n        A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)\n    )\n    p_A_41 = tl.make_block_ptr(\n        A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)\n    )\n    p_Ad_11 = tl.make_block_ptr(\n        Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0)\n    )\n    p_Ad_22 = tl.make_block_ptr(\n        Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)\n    )\n    p_Ad_33 = tl.make_block_ptr(\n        Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)\n    )\n    p_Ad_44 = tl.make_block_ptr(\n        Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)\n    )\n\n    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)\n    A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)\n    A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)\n    A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)\n    A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)\n    A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)\n\n    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)\n    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)\n    Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32)\n    Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32)\n\n    Ai_21 = -tl.dot(\n        tl.dot(Ai_22, A_21, input_precision=\"ieee\"), Ai_11, input_precision=\"ieee\"\n    )\n    Ai_32 = -tl.dot(\n        tl.dot(Ai_33, A_32, input_precision=\"ieee\"), Ai_22, input_precision=\"ieee\"\n    )\n    Ai_43 = -tl.dot(\n        tl.dot(Ai_44, A_43, input_precision=\"ieee\"), Ai_33, input_precision=\"ieee\"\n    )\n\n    Ai_31 = -tl.dot(\n        Ai_33,\n        tl.dot(A_31, Ai_11, input_precision=\"ieee\")\n        + tl.dot(A_32, Ai_21, 
input_precision=\"ieee\"),\n        input_precision=\"ieee\",\n    )\n    Ai_42 = -tl.dot(\n        Ai_44,\n        tl.dot(A_42, Ai_22, input_precision=\"ieee\")\n        + tl.dot(A_43, Ai_32, input_precision=\"ieee\"),\n        input_precision=\"ieee\",\n    )\n    Ai_41 = -tl.dot(\n        Ai_44,\n        tl.dot(A_41, Ai_11, input_precision=\"ieee\")\n        + tl.dot(A_42, Ai_21, input_precision=\"ieee\")\n        + tl.dot(A_43, Ai_31, input_precision=\"ieee\"),\n        input_precision=\"ieee\",\n    )\n\n    p_Ai_11 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0)\n    )\n    p_Ai_22 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0)\n    )\n    p_Ai_33 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0)\n    )\n    p_Ai_44 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0)\n    )\n    p_Ai_21 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)\n    )\n    p_Ai_31 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)\n    )\n    p_Ai_32 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)\n    )\n    p_Ai_41 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)\n    )\n    p_Ai_42 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)\n    )\n    p_Ai_43 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)\n    )\n    tl.store(\n        p_Ai_11,\n        Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_22,\n        Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_33,\n        
Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_44,\n        Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_21,\n        Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_31,\n        Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_32,\n        Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_41,\n        Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_42,\n        Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_43,\n        Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n\n    fill_zeros = tl.zeros((16, 16), dtype=tl.float32)\n    p_Ai_12 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0)\n    )\n    p_Ai_13 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0)\n    )\n    p_Ai_14 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0)\n    )\n    p_Ai_23 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0)\n    )\n    p_Ai_24 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0)\n    )\n    p_Ai_34 = tl.make_block_ptr(\n        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0)\n    )\n    tl.store(\n        p_Ai_12,\n        fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        
boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_13,\n        fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_14,\n        fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_23,\n        fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_24,\n        fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n    tl.store(\n        p_Ai_34,\n        fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding=\"rtne\"),\n        boundary_check=(0, 1),\n    )\n\n\n@input_guard\ndef solve_tril(\n    A: torch.Tensor,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    output_dtype: torch.dtype = torch.float,\n) -> torch.Tensor:\n    \"\"\"\n    Compute (I + A)^-1, the inverse of the unit lower triangular matrix I + A.\n    A should be strictly lower triangular, i.e., A.triu() == 0.\n\n    Args:\n        A (torch.Tensor):\n            [B, T, H, BT], where BT is 16, 32, or 64.\n        cu_seqlens (torch.Tensor):\n            The cumulative sequence lengths of the input tensor.\n            Default: None.\n        output_dtype (torch.dtype):\n            The dtype of the output tensor. 
Default: `torch.float`\n\n    Returns:\n        (I + A)^-1 with the same shape as A\n    \"\"\"\n    assert A.shape[-1] in [16, 32, 64]\n\n    B, T, H, BT = A.shape\n    Ad = torch.empty(\n        B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype\n    )\n\n    chunk_indices = (\n        prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None\n    )\n    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)\n    solve_tril_16x16_kernel[NT, B * H](\n        A=A,\n        Ad=Ad,\n        cu_seqlens=cu_seqlens,\n        chunk_indices=chunk_indices,\n        T=T,\n        H=H,\n        BT=BT,\n        IS_VARLEN=cu_seqlens is not None,\n        num_warps=1,\n        num_stages=4,\n    )\n    if BT == 16:\n        return Ad\n\n    Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)\n    merge_fn = (\n        merge_16x16_to_32x32_inverse_kernel\n        if BT == 32\n        else merge_16x16_to_64x64_inverse_kernel\n    )\n    chunk_indices = (\n        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None\n    )\n    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)\n    merge_fn[NT, B * H](\n        A=A,\n        Ad=Ad,\n        Ai=Ai,\n        cu_seqlens=cu_seqlens,\n        chunk_indices=chunk_indices,\n        T=T,\n        H=H,\n        BT=BT,\n        IS_VARLEN=cu_seqlens is not None,\n        num_warps=4,\n        num_stages=3,\n    )\n    return Ai\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/utils.py",
    "content": "# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py\n# -*- coding: utf-8 -*-\n\nimport contextlib\nimport functools\nimport logging\nimport os\nimport sys\nfrom enum import Enum\nfrom functools import lru_cache\nfrom typing import Any, Callable, Dict, Literal, Optional, Tuple\n\nimport torch\nimport triton\nfrom packaging import version\n\nfrom sglang.srt.utils.common import torch_release\n\nlogger = logging.getLogger(__name__)\n\nCOMPILER_MODE = os.getenv(\"FLA_COMPILER_MODE\") == \"1\"\nFLA_CI_ENV = os.getenv(\"FLA_CI_ENV\") == \"1\"\n\n\n@lru_cache(maxsize=1)\ndef check_environments():\n    \"\"\"\n    Checks the current operating system, Triton version, and Python version,\n    issuing warnings if they don't meet recommendations.\n    This function's body only runs once due to lru_cache.\n    \"\"\"\n    # Check Operating System\n    if sys.platform == \"win32\":\n        logger.warning(\n            \"Detected Windows operating system. Triton does not have an official Windows release, \"\n            \"thus FLA will not be adapted for Windows, and any potential errors will not be fixed. \"\n            \"Please consider using a Linux environment for compatibility.\"\n        )\n\n    triton_version = version.parse(triton.__version__)\n    required_triton_version = version.parse(\"3.2.0\")\n\n    if triton_version < required_triton_version:\n        logger.warning(\n            f\"Current Triton version {triton_version} is below the recommended 3.2.0 version. \"\n            \"Errors may occur and these issues will not be fixed. 
\"\n            \"Please consider upgrading Triton.\"\n        )\n\n    # Check Python version\n    py_version = version.parse(f\"{sys.version_info.major}.{sys.version_info.minor}\")\n    required_py_version = version.parse(\"3.11\")\n\n    if py_version < required_py_version:\n        logger.warning(\n            f\"Current Python version {py_version} is below the recommended 3.11 version. \"\n            \"It is recommended to upgrade to Python 3.11 or higher for the best experience.\"\n        )\n\n    return None\n\n\ndef get_abs_err(x, y):\n    return (x.detach() - y.detach()).flatten().abs().max().item()\n\n\ndef get_err_ratio(x, y):\n    err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()\n    base = (x.detach()).flatten().square().mean().sqrt().item()\n    return err / (base + 1e-8)\n\n\ndef assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):\n    abs_atol = get_abs_err(ref, tri)\n    msg = f\"{prefix} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}\"\n    logger.info(msg)\n    error_rate = get_err_ratio(ref, tri)\n    if abs_atol <= err_atol:\n        return\n    if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):\n        if error_rate > ratio:\n            import warnings\n\n            warnings.warn(msg)\n    else:\n        assert error_rate < ratio, msg\n\n\nSUPPRESS_LEVEL = int(os.getenv(\"GDN_RECOMPUTE_SUPPRESS_LEVEL\", \"0\"))\n\n\ndef tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:\n    \"\"\"\n    A decorator that caches the most recent results of a function with tensor inputs.\n    This decorator will store the output of the decorated function for the most recent set of input tensors.\n    The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.\n    Args:\n        fn (Callable[..., torch.Tensor]):\n            The function to be decorated. 
It should take tensor inputs and return tensor outputs.\n    Returns:\n        Callable[..., torch.Tensor]:\n            A wrapped version of the input function with an LRU cache holding the 4 most recent results.\n    \"\"\"\n\n    cache_entries: Tuple[Optional[Tuple], Optional[Dict], Any] = []\n    cache_size = 4\n\n    @functools.wraps(fn)\n    def wrapper(*args: Any, **kwargs: Any) -> Any:\n        nonlocal cache_entries, cache_size\n        for i, entry in enumerate(cache_entries):\n            last_args, last_kwargs, last_result = entry\n            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):\n                if all(a is b for a, b in zip(args, last_args)) and all(\n                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()\n                ):\n                    cache_entries = (\n                        cache_entries[:i]\n                        + cache_entries[i + 1 :]\n                        + [(args, kwargs, last_result)]\n                    )\n                    return last_result\n\n        result = fn(*args, **kwargs)\n\n        if len(cache_entries) >= cache_size:\n            cache_entries = cache_entries[1:]\n        cache_entries.append((args, kwargs, result))\n        return result\n\n    return wrapper\n\n\ndef input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:\n    \"\"\"\n    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.\n    \"\"\"\n\n    @functools.wraps(fn)\n    def wrapper(*args, **kwargs):\n        contiguous_args = (\n            i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args\n        )\n        contiguous_kwargs = {\n            k: (v if not isinstance(v, torch.Tensor) else v.contiguous())\n            for k, v in kwargs.items()\n        }\n\n        tensor = None\n        for arg in args:\n            if isinstance(arg, torch.Tensor):\n                tensor = arg\n                break\n        if 
tensor is None:\n            for value in kwargs.values():\n                if isinstance(value, torch.Tensor):\n                    tensor = value\n                    break\n\n        if tensor is not None:\n            ctx = custom_device_ctx(tensor.device.index)\n        else:\n            ctx = contextlib.nullcontext()\n\n        with ctx:\n            return fn(*contiguous_args, **contiguous_kwargs)\n\n    return wrapper\n\n\ncontiguous = input_guard\n\n\ndef require_version(version, hint):\n    \"\"\"\n    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.\n    \"\"\"\n\n    def decorator(fn):\n        @functools.wraps(fn)\n        def wrapper(ctx, *args, **kwargs):\n            from transformers.utils.versions import require_version\n\n            require_version(version, hint)\n            return fn(\n                ctx,\n                *(\n                    i if not isinstance(i, torch.Tensor) else i.contiguous()\n                    for i in args\n                ),\n                **{\n                    k: (v if not isinstance(v, torch.Tensor) else v.contiguous())\n                    for k, v in kwargs.items()\n                },\n            )\n\n        return wrapper\n\n    return decorator\n\n\ndef checkpoint(fn):\n    def wrapper(*args, **kwargs):\n        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)\n\n    return wrapper\n\n\ndef _cpu_device_warning():\n    import warnings\n\n    warnings.warn(\n        (\"Triton is not supported on current platform, roll back to CPU.\"), stacklevel=1\n    )\n\n\n@lru_cache(maxsize=None)\ndef get_multiprocessor_count(tensor_idx: int = 0) -> int:\n    try:\n        return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[\n            \"multiprocessor_count\"\n        ]\n    except BaseException:\n        _cpu_device_warning()\n        return -1\n\n\n@lru_cache(maxsize=None)\ndef get_available_device() -> str:\n    try:\n        
return triton.runtime.driver.active.get_current_target().backend\n    except BaseException:\n        _cpu_device_warning()\n        return \"cpu\"\n\n\n@lru_cache(maxsize=None)\ndef _check_platform() -> Literal[\"nvidia\", \"amd\", \"intel\", \"musa\"]:\n    device = get_available_device()\n    if device == \"cuda\":\n        return \"nvidia\"\n    elif device == \"hip\":\n        return \"amd\"\n    elif device == \"xpu\":\n        return \"intel\"\n    else:\n        return device\n\n\n# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.\n# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.\n# Therefore, we need to check the triton backend to determine the actual GPU vendor.\ndevice = get_available_device() if get_available_device() != \"hip\" else \"cuda\"\ndevice_torch_lib = getattr(torch, device)\ndevice_platform = _check_platform()\n\nis_amd = device_platform == \"amd\"\nis_intel = device_platform == \"intel\"\nis_nvidia = device_platform == \"nvidia\"\nis_intel_alchemist = is_intel and \"Intel(R) Arc(TM) A\" in torch.xpu.get_device_name(0)\nis_nvidia_hopper = is_nvidia and (\n    \"NVIDIA H\" in torch.cuda.get_device_name(0)\n    or torch.cuda.get_device_capability()[0] >= 9\n)\nuse_cuda_graph = is_nvidia and os.environ.get(\"FLA_USE_CUDA_GRAPH\", \"0\") == \"1\"\n\n# Nvidia Ampere or newer; AMD and Intel have not been checked yet.\nis_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8\nis_gather_supported = hasattr(triton.language, \"gather\")\n\n\ndef get_all_max_shared_mem():\n    try:\n        return [\n            triton.runtime.driver.active.utils.get_device_properties(i)[\n                \"max_shared_mem\"\n            ]\n            for i in range(device_torch_lib.device_count())\n        ]\n    except BaseException:\n        _cpu_device_warning()\n        return [-1]\n\n\nclass Backend(Enum):\n    ADA = 101376  # RTX 4090\n    AMPERE = 166912  # A100\n    HOPPER = 232448  
# H100\n    DEFAULT = 102400  # Default\n\n    @classmethod\n    def get_shared_memory(cls, arch: str) -> int:\n        try:\n            return cls[arch.upper()].value\n        except KeyError:\n            return cls.DEFAULT.value\n\n\n@lru_cache(maxsize=None)\ndef check_shared_mem(arch: str = \"none\", tensor_idx: int = 0) -> bool:\n    try:\n        device_shared_mem_list = get_all_max_shared_mem()\n        max_shared_memory = device_shared_mem_list[tensor_idx]\n        return max_shared_memory >= Backend.get_shared_memory(arch)\n    except Exception:\n        return False\n\n\nif torch_release >= (2, 4):\n    device = \"cuda\" if device == \"cpu\" else device\n    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)\n    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)\n\n    def custom_device_ctx(index: int):\n        return device_torch_lib.device(index)\n\nelse:\n    assert (\n        device == \"cuda\"\n    ), \"Only cuda device is supported for PyTorch version < 2.4.0.\"\n    autocast_custom_fwd = device_torch_lib.amp.custom_fwd\n    autocast_custom_bwd = device_torch_lib.amp.custom_bwd\n\n    def custom_device_ctx(index: int):\n        return torch.cuda.device(index)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/fla/wy_fast.py",
    "content": "# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/wy_fast.py\n# -*- coding: utf-8 -*-\n# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang\n\nfrom typing import Optional, Tuple\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.attention.fla.index import prepare_chunk_indices\n\n\n# @triton.autotune(\n#     configs=[\n#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)\n#         for num_warps in [2, 4, 8]\n#         for num_stages in [2, 3, 4]\n#     ],\n#     key=[\"H\", \"K\", \"V\", \"BT\", \"BK\", \"BV\", \"IS_VARLEN\"],\n# )\n@triton.jit(do_not_specialize=[\"T\"])\ndef recompute_w_u_fwd_kernel(\n    k,\n    v,\n    beta,\n    w,\n    u,\n    A,\n    g,\n    cu_seqlens,\n    chunk_indices,\n    T,\n    H: tl.constexpr,\n    Hg: tl.constexpr,\n    K: tl.constexpr,\n    V: tl.constexpr,\n    BT: tl.constexpr,\n    BK: tl.constexpr,\n    BV: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    i_t, i_bh = tl.program_id(0), tl.program_id(1)\n    i_b, i_h = i_bh // H, i_bh % H\n    if IS_VARLEN:\n        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(\n            chunk_indices + i_t * 2 + 1\n        ).to(tl.int32)\n        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(\n            cu_seqlens + i_n + 1\n        ).to(tl.int32)\n        T = eos - bos\n    else:\n        bos, eos = i_b * T, i_b * T + T\n    p_beta = tl.make_block_ptr(\n        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)\n    )\n    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))\n    p_A = tl.make_block_ptr(\n        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)\n    )\n    b_beta = tl.load(p_beta, boundary_check=(0,))\n    b_A = tl.load(p_A, boundary_check=(0, 1))\n    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))\n\n    for i_v in range(tl.cdiv(V, BV)):\n        p_v 
= tl.make_block_ptr(\n            v + (bos * H + i_h) * V,\n            (T, V),\n            (H * V, 1),\n            (i_t * BT, i_v * BV),\n            (BT, BV),\n            (1, 0),\n        )\n        p_u = tl.make_block_ptr(\n            u + (bos * H + i_h) * V,\n            (T, V),\n            (H * V, 1),\n            (i_t * BT, i_v * BV),\n            (BT, BV),\n            (1, 0),\n        )\n        b_v = tl.load(p_v, boundary_check=(0, 1))\n        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)\n        b_u = tl.dot(b_A, b_vb, allow_tf32=False)\n        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))\n\n    for i_k in range(tl.cdiv(K, BK)):\n        p_k = tl.make_block_ptr(\n            k + (bos * Hg + i_h // (H // Hg)) * K,\n            (T, K),\n            (Hg * K, 1),\n            (i_t * BT, i_k * BK),\n            (BT, BK),\n            (1, 0),\n        )\n        p_w = tl.make_block_ptr(\n            w + (bos * H + i_h) * K,\n            (T, K),\n            (H * K, 1),\n            (i_t * BT, i_k * BK),\n            (BT, BK),\n            (1, 0),\n        )\n        b_k = tl.load(p_k, boundary_check=(0, 1))\n        b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)\n        b_w = tl.dot(b_A, b_kb)\n        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef recompute_w_u_fwd(\n    k: torch.Tensor,\n    v: torch.Tensor,\n    beta: torch.Tensor,\n    g_cumsum: torch.Tensor,\n    A: torch.Tensor,\n    cu_seqlens: Optional[torch.LongTensor],\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    B, T, Hg, K, V = *k.shape, v.shape[-1]\n    H = v.shape[-2]\n    BT = A.shape[-1]\n\n    chunk_indices = (\n        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None\n    )\n    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)\n    BK = 64\n    BV = 64\n    u = torch.empty_like(v)\n    w = k.new_empty(B, T, H, K)\n    recompute_w_u_fwd_kernel[(NT, B * H)](\n        k=k,\n      
  v=v,\n        beta=beta,\n        w=w,\n        u=u,\n        A=A,\n        g=g_cumsum,\n        cu_seqlens=cu_seqlens,\n        chunk_indices=chunk_indices,\n        T=T,\n        H=H,\n        Hg=Hg,\n        K=K,\n        V=V,\n        BT=BT,\n        BK=BK,\n        BV=BV,\n        IS_VARLEN=cu_seqlens is not None,\n        num_warps=4,\n        num_stages=3,\n    )\n    return w, u\n\n\nfwd_recompute_w_u = recompute_w_u_fwd\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/flashattention_backend.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Optional\n\nimport numpy as np\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.configs.model_config import AttentionArch\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.radix_attention import AttentionType\nfrom sglang.srt.mem_cache.swa_memory_pool import SWAKVPool\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.speculative.spec_info import SpecInput\nfrom sglang.srt.utils import get_compiler_backend\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n\nfrom sgl_kernel import merge_state_v2\nfrom sgl_kernel.flash_attn import flash_attn_varlen_func as flash_attn_varlen_func_fa3\nfrom sgl_kernel.flash_attn import flash_attn_with_kvcache as flash_attn_with_kvcache_fa3\n\nflash_attn_varlen_func = flash_attn_varlen_func_fa3\nflash_attn_with_kvcache = flash_attn_with_kvcache_fa3\n\nfrom sglang.jit_kernel.flash_attention_v4 import (\n    flash_attn_varlen_func as flash_attn_varlen_func_fa4,\n)\nfrom sglang.jit_kernel.flash_attention_v4 import (\n    flash_attn_with_kvcache as flash_attn_with_kvcache_fa4,\n)\n\n\n@dataclass\nclass FlashAttentionMetadata:\n    \"\"\"Metadata to be initialized once per model forward pass;\n    each layer's forward pass can reuse the metadata.\n\n    Each init metadata function tries to set up the fields in the order below.\n    \"\"\"\n\n    # Sequence lengths for the forward batch\n    cache_seqlens_int32: torch.Tensor = None\n    # Maximum sequence length for query\n    max_seq_len_q: int = 1\n    # Maximum sequence length for key\n    max_seq_len_k: int = 0\n    # Cumulative sequence lengths for query\n    cu_seqlens_q: torch.Tensor 
= None\n    # Cumulative sequence lengths for key\n    cu_seqlens_k: torch.Tensor = None\n    # Window size (typically used by Gemma)\n    window_size: tuple = (-1, -1)\n    # Page table, the index of KV Cache Tables/Blocks\n    page_table: torch.Tensor = None\n    # Page table for Sliding Window Attention\n    swa_page_table: torch.Tensor = None\n\n    # Encoder metadata\n    # Cumulative sequence lengths for encoder key\n    encoder_cu_seqlens_k: torch.Tensor = None\n    # Maximum sequence length for encoder key\n    encoder_max_seq_len_k: int = 0\n    # Encoder sequence lengths for the forward batch\n    encoder_lens_int32: torch.Tensor = None\n    # Page table for the encoder\n    encoder_page_table: torch.Tensor = None\n\n    @dataclass\n    class LocalAttentionMetadata:\n        local_query_start_loc: torch.Tensor = None  # cu_seqlens_q for local attention\n        local_seqused_k: torch.Tensor = None  # sequence lengths for local attention\n        local_block_table: torch.Tensor = None  # block table for local attention\n        local_max_query_len: int = 0  # max query length for local attention\n        local_max_seq_len: int = 0  # max sequence length for local attention\n\n    local_attn_metadata: Optional[LocalAttentionMetadata] = None\n\n    # For sliding window attention topk>1 spec decoding\n    swa_spec_metadata: Optional[FlashAttentionMetadata] = None\n\n\n# Copied from:\n# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py\n#\n# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into\n# local attention blocks, where each block is passed to the attention kernel\n# as an independent local (\"virtual\") batch item.\n#\n# For example, if we are performing chunked prefill on a batch of 3 sequences:\n#   q_seqlens  = [4, 10, 5]\n#   kv_seqlens = [6, 17, 9]\n# Then normally for regular attention we would compute with an attention mask\n#  for batch idx 0 (q_seqlens = 4, 
kv_seqlens = 6) like:\n#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)\n#        k_toks >   0 1 2 3 4 5\n#        q_toks v  _____________\n#               0 | 1 1 1\n#               1 | 1 1 1 1\n#               2 | 1 1 1 1 1\n#               3 | 1 1 1 1 1 1\n#\n# for local attention (with attn_chunk_size = 4) we would compute with an\n#  attention mask like:\n#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)\n#        k_toks >   0 1 2 3 4 5\n#        q_toks v  _____________\n#               0 | 1 1 1\n#               1 | 1 1 1 1\n#               2 |         1\n#               3 |         1 1\n#\n# We can simulate this mask using standard flash-attention by breaking the\n#  sequences into local (\"virtual\") batches, where each local batch item is a\n#  local attention block, so in this case batch idx 0 would be broken up into:\n#\n#   local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4)  (batch 0)\n#        k_toks >   0 1 2 3\n#        q_toks v  _____________\n#               0 | 1 1 1\n#               1 | 1 1 1 1\n#   local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)\n#        k_toks >   4 5\n#        q_toks v  _____________\n#               2 | 1\n#               3 | 1 1\n#\n# e.g. 
if we have:\n#   attn_chunk_size = 4\n#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])\n# Then this function would return:\n#                           __b0__  ______b1______  __b2__ < orig batch indices\n#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]\n#   cu_seqlens_q_local = [0, 2,  4,  5,  9, 13, 14, 18, 19]\n#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]\n#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]\ndef make_local_attention_virtual_batches(\n    attn_chunk_size: int,\n    query_start_loc_np: np.ndarray,\n    seq_lens_np: np.ndarray,\n    block_table: torch.Tensor,\n    page_size: int = 0,\n) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:\n    \"\"\"\n    Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into\n    local attention blocks, where each block is passed to the attention kernel\n    as an independent local (\"virtual\") batch item.\n\n    Args:\n        attn_chunk_size: Size of local attention chunks\n        query_start_loc_np: Cumulative sum of query lengths (numpy array)\n        seq_lens_np: Sequence lengths (numpy array)\n        block_table: Block table for KV cache\n        page_size: Size of each page in the KV cache\n\n    Returns:\n        seqlens_q_local: Query sequence lengths for local attention\n        cu_seqlens_q_local: Cumulative sum of query sequence lengths for local attention\n        seqlens_k_local: Key sequence lengths for local attention\n        block_table_local: Block table for local attention\n    \"\"\"\n    # Adjust attention_chunk_size based on the actual sequence length\n    # to avoid index out of bounds errors\n    max_seq_len = seq_lens_np.max()\n    effective_chunk_size = min(attn_chunk_size, max_seq_len)\n    # Make sure effective_chunk_size is divisible by page_size\n    effective_chunk_size = (effective_chunk_size // page_size) * page_size\n    if effective_chunk_size < page_size:\n        
effective_chunk_size = page_size\n    attn_chunk_size = effective_chunk_size\n\n    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]\n    actual_batch_size = seq_lens_np.shape[0]\n\n    # Handle if we are starting in the middle of a local attention block,\n    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute\n    #  the number of tokens that are not in the first local attention block and\n    #  then we can simply use a cdiv for the rest.\n    # For example if we have:\n    #   attn_chunk_size = 4\n    #   q_seqlens = [4, 10, 5]\n    #   k_seqlens = [6, 17, 9]\n    # Then we would get:\n    #   new_tokens_in_first_block = [2, 1, 4]\n    #   local_blocks = [2, 4, 2]\n    q_tokens_in_first_block = np.minimum(\n        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens\n    ).astype(np.int32)\n    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)\n    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)\n\n    # Once we know the number of local blocks we can compute the request spans\n    #  for each batch idx, we can figure out the number of \"virtual\" requests we\n    #  have to make,\n    # For the above example we would get:\n    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]\n    #\n    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])\n    #   (TODO: max a utility to share this code with _prepare_inputs)\n    # arange step 1. [2, 4, 2] -> [2, 6, 8]\n    cu_num_blocks = np.cumsum(local_blocks)\n    virtual_batches = cu_num_blocks[-1]\n    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]\n    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)\n    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]\n    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets\n    # also compute reverse arange (i.e. 
[1, 0, 3, 2, 1, 0, 1, 0])\n    rarange = np.repeat(local_blocks, local_blocks) - arange - 1\n    # Then we can compute the seqlens_q_local, handling the fact that the\n    #  first and last blocks could be partial\n    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)\n    # set the first block since this may be a partial block\n    seqlens_q_local[arange == 0] = q_tokens_in_first_block\n    # set the remaining blocks\n    seqlens_q_local[arange > 0] = np.minimum(\n        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size\n    )[arange > 0]\n\n    # convert from q_seqlens to cu_seqlens_q\n    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)\n\n    # compute the seqlens_k_local:\n    #  a full local attention block for all but the last block in each\n    #  batch\n    # For our example this will be:\n    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]\n    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)\n    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block\n\n    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (\n        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)\n    )\n    # For the example, the local attention blocks start at:\n    #                           _b0_  _____b1_____  _b2_\n    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]\n    block_starts = k_seqstarts_absolute // page_size\n\n    assert attn_chunk_size % page_size == 0, (\n        f\"attn_chunk_size {attn_chunk_size} is not \"\n        f\"divisible by page_size {page_size}\"\n    )\n    pages_per_local_batch = attn_chunk_size // page_size\n\n    # Create a block_table for the local attention blocks\n    # For our example, if we have a block table like (assuming page_size=2):\n    #   block_table = [\n    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0\n    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1\n
   #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2\n    #   ]\n    # Then for the local batches we would want a block-table like\n    #   block_table_local = [\n    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])\n    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])\n    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])\n    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])\n    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])\n    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])\n    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])\n    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])\n    #   ]\n    block_indices = np.broadcast_to(\n        np.arange(pages_per_local_batch, dtype=np.int32),\n        (virtual_batches, pages_per_local_batch),\n    ) + np.expand_dims(block_starts, axis=1)\n    # Ensure block_indices doesn't exceed block_table dimensions\n    # This is a critical safety check that prevents index out of bounds errors\n    # when dealing with large sequences (>8192 tokens) or when the block_table\n    # dimensions are smaller than what would be needed for the full attention chunk size.\n    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)\n    batch_indices = np.repeat(\n        np.arange(actual_batch_size, dtype=np.int32),\n        local_blocks * pages_per_local_batch,\n    )\n\n    # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance\n    # regression when using numpy arrays (batch and block indices) to index into\n    # torch tensor (block_table). 
As a workaround, convert numpy arrays to torch\n    # tensor first, which recovers perf.\n    batch_indices_torch = torch.from_numpy(batch_indices)\n    block_indices_torch = torch.from_numpy(block_indices)\n    block_table_local = block_table[batch_indices_torch, block_indices_torch].view(\n        virtual_batches, -1\n    )\n\n    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local\n\n\ndef cdiv(a: int, b: int) -> int:\n    \"\"\"Ceiling division.\"\"\"\n    return -(a // -b)\n\n\n# TODO(hebiao064): remove this once we have a better way to handle the merge_state_v2 torch.compile issue\n@torch._dynamo.disable()\ndef merge_state_v2_wrapper(o, s_a, o_exp, s_b):\n    return merge_state_v2(o, s_a, o_exp, s_b)\n\n\nclass FlashAttentionBackend(AttentionBackend):\n    \"\"\"FlashAttention backend implementation.\n\n    Note about the init:\n    - If no spec decoding\n        - FlashAttentionBackend will be initialized once when the server starts.\n    - If spec decoding\n        - FlashAttentionBackend will be initialized once for the target worker\n        - FlashAttentionMultiStepBackend will be initialized once for the draft worker\n            - It will spawn num_steps FlashAttentionBackend instances for the draft worker\n\n    Note about CUDA Graph:\n    - We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify.\n    - We don't support CUDA Graph for Extend and Draft Extend.\n    - During server initialization, init_cuda_graph_state is called first, followed by init_cuda_graph_capture.\n    - For each forward batch, init_replay_cuda_graph is called first, and then the graph is replayed.\n    \"\"\"\n\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        skip_prefill: bool = False,\n        speculative_step_id=0,\n        topk=0,\n        speculative_num_steps=0,\n        fa_impl_ver=3,\n    ):\n        super().__init__()\n\n        assert not (\n            model_runner.sliding_window_size is not None\n            
and model_runner.model_config.is_encoder_decoder\n        ), \"Sliding window and cross attention are not supported together\"\n\n        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder\n        self.forward_metadata: FlashAttentionMetadata = None\n        # extra metadata for handling speculative decoding topk > 1, extended draft decode and verify\n        self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None\n        self.max_context_len = model_runner.model_config.context_len\n        self.device = model_runner.device\n        self.decode_cuda_graph_metadata = {}\n        self.target_verify_metadata = {}\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.kv_cache_dtype = model_runner.kv_cache_dtype\n        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype\n        self.page_size = model_runner.page_size\n        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA\n        self.skip_prefill = skip_prefill\n\n        self.use_sliding_window_kv_pool = (\n            isinstance(model_runner.token_to_kv_pool, SWAKVPool)\n            and model_runner.token_to_kv_pool.swa_layer_nums > 0\n        )\n\n        if self.use_sliding_window_kv_pool:\n            self.token_to_kv_pool = model_runner.token_to_kv_pool\n\n        self.topk = model_runner.server_args.speculative_eagle_topk or 0\n        self.speculative_num_steps = speculative_num_steps\n        self.speculative_num_draft_tokens = (\n            model_runner.server_args.speculative_num_draft_tokens\n        )\n        self.speculative_step_id = speculative_step_id\n\n        self.fa_impl_ver = fa_impl_ver\n\n        # Local attention settings\n        self.has_local_attention = model_runner.model_config.is_local_attention_model\n        if self.has_local_attention:\n            assert (\n                model_runner.attention_chunk_size is not None\n            ), \"Attention chunk size is 
required for local attention\"\n            self.attention_chunk_size = model_runner.attention_chunk_size\n\n        # For each layer, the sliding_window_size can be different. This is only used for preparing SWA metadata.\n        # We use `layer.sliding_window_size` to decide whether to use SWA for each layer.\n        self.sliding_window_size = model_runner.sliding_window_size\n        self.has_swa = (\n            self.sliding_window_size is not None and self.sliding_window_size > -1\n        )\n\n        # If num_splits == 0, we use a heuristic to automatically determine the number of splits.\n        # We set num_splits to 1 if deterministic inference is enabled.\n        # See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.\n        # Furthermore, FA4 does not support num_splits=0 with CUDA Graph, so we set num_splits to 1 if CUDA Graph is enabled.\n        self.num_splits = (\n            1\n            if model_runner.server_args.enable_deterministic_inference\n            or (\n                self.fa_impl_ver == 4\n                and not model_runner.server_args.disable_cuda_graph\n            )\n            else 0\n        )\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Initialize forward metadata so that all layers in the forward pass can reuse it.\"\"\"\n        metadata = FlashAttentionMetadata()\n        seqlens_in_batch = forward_batch.seq_lens\n        batch_size = forward_batch.batch_size\n        device = seqlens_in_batch.device\n\n        if forward_batch.forward_mode.is_decode_or_idle():\n            # Draft Decode\n            if forward_batch.spec_info is not None:\n                if self.topk <= 1:\n                    metadata.cache_seqlens_int32 = (\n                        seqlens_in_batch + (self.speculative_step_id + 1)\n                    ).to(torch.int32)\n                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (\n     
                   self.speculative_step_id + 1\n                    )\n                    metadata.cu_seqlens_q = torch.arange(\n                        0, batch_size + 1, dtype=torch.int32, device=device\n                    )\n                    metadata.cu_seqlens_k = torch.nn.functional.pad(\n                        torch.cumsum(\n                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32\n                        ),\n                        (1, 0),\n                    )\n                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                        forward_batch.req_pool_indices, : metadata.max_seq_len_k\n                    ]\n                else:\n                    metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)\n                    metadata.max_seq_len_q = self.topk\n                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()\n                    metadata.cu_seqlens_q = torch.arange(\n                        0,\n                        batch_size * self.topk + 1,\n                        step=self.topk,\n                        dtype=torch.int32,\n                        device=device,\n                    )\n                    metadata.cu_seqlens_k = torch.nn.functional.pad(\n                        torch.cumsum(\n                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32\n                        ),\n                        (1, 0),\n                    )\n                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                        forward_batch.req_pool_indices, : metadata.max_seq_len_k\n                    ]\n                    metadata_expand = FlashAttentionMetadata()\n                    decode_length = self.speculative_step_id + 1\n                    metadata_expand.cache_seqlens_int32 = torch.full(\n                        (seqlens_in_batch.numel() * self.topk,),\n                    
    decode_length,\n                        device=device,\n                        dtype=torch.int32,\n                    )\n                    metadata_expand.max_seq_len_q = 1\n                    metadata_expand.cu_seqlens_q = torch.arange(\n                        0,\n                        metadata_expand.cache_seqlens_int32.numel() + 1,\n                        dtype=torch.int32,\n                        device=device,\n                    )\n                    metadata_expand.cu_seqlens_k = torch.arange(\n                        0,\n                        metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,\n                        step=decode_length,\n                        dtype=torch.int32,\n                        device=device,\n                    )\n                    # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]\n                    cache_loc = forward_batch.out_cache_loc.view(\n                        -1, self.speculative_num_steps\n                    )\n                    metadata_expand.page_table = (\n                        cache_loc[:, :decode_length].contiguous().to(torch.int32)\n                    )\n                    self.forward_metadata_spec_decode_expand = metadata_expand\n            else:\n                # Normal Decode\n                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)\n                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()\n                metadata.cu_seqlens_q = torch.arange(\n                    0, batch_size + 1, dtype=torch.int32, device=device\n                )\n                metadata.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)\n                )\n                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                    forward_batch.req_pool_indices, : metadata.max_seq_len_k\n                ]\n            # TODO: we need 
to test this part for llama 4 eagle case\n            self._maybe_init_local_attn_metadata(forward_batch, metadata, device)\n        elif forward_batch.forward_mode.is_target_verify():\n            if self.topk <= 1:\n                metadata.cache_seqlens_int32 = (\n                    forward_batch.seq_lens + self.speculative_num_draft_tokens\n                ).to(torch.int32)\n                metadata.max_seq_len_q = self.speculative_num_draft_tokens\n                metadata.max_seq_len_k = (\n                    forward_batch.seq_lens_cpu.max().item()\n                    + self.speculative_num_draft_tokens\n                )\n                metadata.cu_seqlens_q = torch.arange(\n                    0,\n                    batch_size * self.speculative_num_draft_tokens + 1,\n                    self.speculative_num_draft_tokens,\n                    dtype=torch.int32,\n                    device=device,\n                )\n                metadata.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(\n                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32\n                    ),\n                    (1, 0),\n                )\n                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                    forward_batch.req_pool_indices, : metadata.max_seq_len_k\n                ]\n\n                self._maybe_init_local_attn_metadata(forward_batch, metadata, device)\n            else:\n                metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)\n                metadata.max_seq_len_q = self.speculative_num_draft_tokens\n                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()\n                metadata.cu_seqlens_q = torch.arange(\n                    0,\n                    batch_size * self.speculative_num_draft_tokens + 1,\n                    step=self.speculative_num_draft_tokens,\n                    dtype=torch.int32,\n                
    device=device,\n                )\n                metadata.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(\n                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32\n                    ),\n                    (1, 0),\n                )\n                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                    forward_batch.req_pool_indices, : metadata.max_seq_len_k\n                ]\n\n                metadata_expand = FlashAttentionMetadata()\n\n                metadata_expand.max_seq_len_q = 1\n                metadata_expand.cu_seqlens_q = torch.arange(\n                    0,\n                    forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens\n                    + 1,\n                    dtype=torch.int32,\n                    device=device,\n                )\n\n                # create expand page table\n                offsets = torch.arange(\n                    self.speculative_num_draft_tokens, device=device\n                ).unsqueeze(\n                    0\n                )  # shape: (1, self.speculative_num_draft_tokens)\n                cols = offsets.expand(\n                    forward_batch.seq_lens.numel(), -1\n                ) + forward_batch.seq_lens.unsqueeze(1)\n                cum_len = torch.nn.functional.pad(\n                    torch.cumsum(\n                        (\n                            forward_batch.seq_lens + self.speculative_num_draft_tokens\n                        ).repeat_interleave(self.speculative_num_draft_tokens),\n                        dim=0,\n                    ),\n                    (1, 0),\n                )[:-1]\n                mask_extraction_indices = (\n                    cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)\n                    + cum_len[:, None]\n                ).view(1, -1)\n                mask = forward_batch.spec_info.custom_mask[\n                    
mask_extraction_indices\n                ].view(\n                    -1, self.speculative_num_draft_tokens\n                )  # (bsz * draft_num, draft_num)\n\n                # shift table indices to avoid padding\n                # non_masked_page_table [[8, 9, 10],   mask (display with int format) [[1, 0, 0],\n                #                        [8, 9, 10],                                   [1, 1, 0],\n                #                        [8, 9, 10]]                                   [1, 0, 1]]\n                # if masked with padding [[8, 0, 0],   our mask without padding       [[8, 9, 10],\n                #                        [8, 9, 0],                                    [8, 9, 10],\n                #                        [8, 0, 10]]                                   [8, 10, 9]]\n                # note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row\n                col_indices = offsets.expand(\n                    mask.shape[0], self.speculative_num_draft_tokens\n                )\n                # Build keys: if an entry is valid (mask==True), keep its original index;\n                # if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.\n                keys = torch.where(\n                    mask, col_indices, col_indices + self.speculative_num_draft_tokens\n                )\n                _, sort_order = torch.sort(keys, dim=1)\n                non_masked_page_table = (\n                    forward_batch.req_to_token_pool.req_to_token[\n                        forward_batch.req_pool_indices, :\n                    ]\n                    .gather(1, cols)\n                    .repeat_interleave(self.speculative_num_draft_tokens, dim=0)\n                )  # (bsz, draft_num)\n                metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)\n                metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)\n          
      metadata_expand.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(\n                        metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32\n                    ),\n                    (1, 0),\n                )\n                self.forward_metadata_spec_decode_expand = metadata_expand\n\n                if self.has_swa:\n                    self._init_sliding_window_attn_spec_metadata(\n                        metadata, metadata_expand\n                    )\n\n        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed(\n            include_draft_extend_v2=True\n        ):\n            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)\n            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()\n            metadata.cu_seqlens_k = torch.nn.functional.pad(\n                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)\n            )\n            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                forward_batch.req_pool_indices, : metadata.max_seq_len_k\n            ]\n\n            if any(\n                forward_batch.extend_prefix_lens_cpu\n            ) or forward_batch.forward_mode.is_draft_extend(include_v2=True):\n                extend_seq_lens = forward_batch.extend_seq_lens\n                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)\n                metadata.cu_seqlens_q = torch.nn.functional.pad(\n                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)\n                )\n            else:\n                metadata.max_seq_len_q = metadata.max_seq_len_k\n                metadata.cu_seqlens_q = metadata.cu_seqlens_k\n\n            # Setup local attention if enabled\n            if forward_batch.forward_mode == ForwardMode.EXTEND:\n                self._maybe_init_local_attn_metadata(forward_batch, metadata, device)\n\n        # Encoder metadata for cross attention\n 
       if forward_batch.encoder_lens is not None:\n            assert (\n                forward_batch.encoder_lens.numel() == 1\n            ), \"Only encoder size 1 is supported for now\"\n\n            metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)\n            metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(\n                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),\n                (1, 0),\n            )\n            metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()\n            metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[\n                forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k\n            ]\n\n            # Currently only support forward_batch.encoder_lens.numel() == 1\n            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                forward_batch.req_pool_indices,\n                metadata.encoder_max_seq_len_k : (\n                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k\n                ),\n            ]\n\n        if self.use_sliding_window_kv_pool:\n            metadata.swa_page_table = (\n                self.token_to_kv_pool.translate_loc_from_full_to_swa(\n                    metadata.page_table\n                )\n            )\n\n        # Convert the page table to a strided format which is needed by FA3 API\n        if self.page_size > 1:\n            self.strided_indices = torch.arange(\n                0, metadata.page_table.shape[1], self.page_size, device=self.device\n            )\n\n            if self.use_sliding_window_kv_pool:\n                metadata.swa_page_table = (\n                    metadata.swa_page_table[:, self.strided_indices] // self.page_size\n                )\n\n            metadata.page_table = (\n                metadata.page_table[:, self.strided_indices] // self.page_size\n            )\n\n            if (\n                
self.topk > 1\n                and forward_batch.forward_mode.is_decode_or_idle()\n                and forward_batch.spec_info is not None\n            ):\n                # Modifies cache_seqlens_int32 and page_table(B, speculative_num_steps).\n                last_page_lens = forward_batch.seq_lens % self.page_size\n                # First attention handles prefix - last_page_len part.\n                metadata.cache_seqlens_int32 -= last_page_lens  # Both (B, )\n\n                # Second attention handles last_page_len + decode part.\n                expanded_last_page_lens = last_page_lens.repeat_interleave(self.topk)\n                self.forward_metadata_spec_decode_expand.cache_seqlens_int32 += (\n                    expanded_last_page_lens\n                )\n                # NOTE: the max decode length is speculative_num_steps - 1 (one token always generated by draft extend)\n                # and we leave one extra for last_page_len, which -> speculative_num_steps for the page table\n                expand_page_table = torch.zeros(\n                    forward_batch.batch_size * self.topk,\n                    self.speculative_num_steps,\n                    dtype=torch.int32,\n                    device=self.device,\n                )\n                # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]\n                cache_loc = forward_batch.out_cache_loc.view(\n                    -1, self.speculative_num_steps\n                )\n                draft_decode_set_expand_metadata(\n                    cache_seqlens_int32=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,\n                    page_table=expand_page_table,\n                    last_page_lens=last_page_lens,\n                    decode_length=decode_length,\n                    cache_loc=cache_loc,\n                    topk=self.topk,\n                    page_size=self.page_size,\n                )\n                self.forward_metadata_spec_decode_expand.page_table = 
expand_page_table\n\n        self.forward_metadata = metadata\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n        # For multi-head latent attention\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        sinks: Optional[torch.Tensor] = None,\n    ):\n        if k is not None:\n            assert v is not None\n            if save_kv_cache:\n                cache_loc = (\n                    forward_batch.out_cache_loc\n                    if not layer.is_cross_attention\n                    else forward_batch.encoder_out_cache_loc\n                )\n                if not self.use_mla:\n                    forward_batch.token_to_kv_pool.set_kv_buffer(\n                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale\n                    )\n                else:\n                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(\n                        layer,\n                        cache_loc,\n                        k,\n                        k_rope,\n                    )\n\n        # Use precomputed metadata across all layers\n        metadata = self.forward_metadata\n\n        # Calculate window size (can be moved to metadata if layer properties don't change).\n        # We don't do layer.sliding_window_size - 1 here because model.get_attention_sliding_window_size()\n        # has already subtracted 1; the window here is inclusive on both sides.\n        is_swa_layer = (\n            layer.sliding_window_size is not None and layer.sliding_window_size > -1\n        )\n        window_size = (layer.sliding_window_size, 0) if is_swa_layer else (-1, -1)\n        k_descale, v_descale = None, None\n        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention\n        # has a corresponding quantization method so that layer.k_scale is not None,\n        # 3) layer.head_dim <= 256, since the FA3 kernel requires fp16 or bf16 data types otherwise,\n        # 4) fa_impl_ver != 4, since FA4 does not currently support fp8 queries and keys.\n        if (\n            self.kv_cache_dtype_str != \"auto\"\n            and layer.head_dim <= 256\n            and self.fa_impl_ver != 4\n        ):\n            if layer.k_scale is not None:\n                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)\n                k_descale = layer.k_scale.expand(descale_shape)\n                v_descale = layer.v_scale.expand(descale_shape)\n            q = q.to(self.kv_cache_dtype)\n            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None\n            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None\n        causal = True\n        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:\n            causal = False\n\n        # Check if we should use local attention\n        use_local_attn = (\n            self.has_local_attention\n            and self.attention_chunk_size is not None\n            and metadata.local_attn_metadata is not None\n            and (hasattr(layer, \"use_irope\") and layer.use_irope)\n        )\n\n        # We do cascade attention for Target Verify with topk > 1\n        # We don't use cascade attention for Sliding Window Attention:\n        # - Different window sizes should be passed in for each q in the first stage of cascade attention, but the FA3 interface doesn't support passing in a list of window sizes.\n        # - The overhead of duplicated computation of the common prefix part is small for sliding window layers (seq_len <= window_size), so we can just expand it.\n        use_cascade_attn = (\n            forward_batch.forward_mode.is_target_verify()\n            and self.topk > 1\n            and not is_swa_layer\n        )\n\n        flash_attn_varlen_func_base = flash_attn_varlen_func_fa3\n        flash_attn_with_kvcache_base = flash_attn_with_kvcache_fa3\n\n        flash_attn_varlen_func = (\n            flash_attn_varlen_func_fa4\n            if self.fa_impl_ver == 4\n            else flash_attn_varlen_func_base\n        )\n        flash_attn_with_kvcache = (\n            flash_attn_with_kvcache_fa4\n            if self.fa_impl_ver == 4\n            else flash_attn_with_kvcache_base\n        )\n\n        kwargs = {}\n        if sinks is not None:\n            kwargs[\"sinks\"] = sinks\n\n        # Get the appropriate page table based on whether we're using local attention\n        if use_local_attn:\n            local_metadata = metadata.local_attn_metadata\n            page_table = local_metadata.local_block_table\n            cu_seqlens_q = local_metadata.local_query_start_loc\n            cache_seqlens = local_metadata.local_seqused_k\n            max_seqlen_q = local_metadata.local_max_query_len\n        elif is_swa_layer and metadata.swa_spec_metadata is not None:\n            swa_spec_metadata = metadata.swa_spec_metadata\n            page_table = swa_spec_metadata.page_table\n            cu_seqlens_q = swa_spec_metadata.cu_seqlens_q\n            cache_seqlens = swa_spec_metadata.cache_seqlens_int32\n            max_seqlen_q = swa_spec_metadata.max_seq_len_q\n            cu_seqlens_k = swa_spec_metadata.cu_seqlens_k\n        else:\n            page_table = metadata.page_table\n            if is_swa_layer and self.use_sliding_window_kv_pool:\n                if metadata.swa_page_table is not None:\n                    page_table = metadata.swa_page_table\n                else:\n                    page_table = self.token_to_kv_pool.translate_loc_from_full_to_swa(\n                        metadata.page_table\n                    )\n            cu_seqlens_q = metadata.cu_seqlens_q\n            cache_seqlens = metadata.cache_seqlens_int32\n            max_seqlen_q = metadata.max_seq_len_q\n            cu_seqlens_k = metadata.cu_seqlens_k\n\n        # Use 
Flash Attention for prefill\n        if not self.use_mla:\n            # Do multi-head attention\n            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(\n                layer.layer_id\n            )\n            key_cache = key_cache.view(\n                -1, self.page_size, layer.tp_k_head_num, layer.head_dim\n            )\n            value_cache = value_cache.view(\n                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim\n            )\n            if layer.is_cross_attention:\n                page_table = metadata.encoder_page_table\n                cache_seqlens = metadata.encoder_lens_int32\n                cu_seqlens_k = metadata.encoder_cu_seqlens_k\n                window_size = (-1, -1)\n\n            result = flash_attn_with_kvcache(\n                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),\n                k_cache=key_cache,\n                v_cache=value_cache,\n                page_table=page_table,\n                cache_seqlens=cache_seqlens,\n                cu_seqlens_q=cu_seqlens_q,\n                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,\n                max_seqlen_q=max_seqlen_q,\n                softmax_scale=layer.scaling,\n                causal=False if use_cascade_attn else causal,\n                window_size=window_size,\n                softcap=layer.logit_cap,\n                k_descale=k_descale,\n                v_descale=v_descale,\n                return_softmax_lse=use_cascade_attn,\n                num_splits=self.num_splits,\n                **kwargs,\n            )\n\n            if use_cascade_attn:\n                o, softmax_lse, *rest = result\n                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(\n                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),\n                    # Here metadata_expand.page_table is not divided by page_size.\n
                    # This is because page-level indexing would lose fine-grained control over which tokens to attend to\n                    # and would force attending to entire pages.\n                    k_cache=key_cache.view(-1, 1, layer.tp_k_head_num, layer.head_dim),\n                    v_cache=value_cache.view(\n                        -1, 1, layer.tp_v_head_num, layer.v_head_dim\n                    ),\n                    page_table=self.forward_metadata_spec_decode_expand.page_table,\n                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,\n                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,\n                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,\n                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,\n                    softmax_scale=layer.scaling,\n                    causal=False,\n                    window_size=window_size,\n                    softcap=layer.logit_cap,\n                    k_descale=k_descale,\n                    v_descale=v_descale,\n                    return_softmax_lse=True,\n                    num_splits=self.num_splits,\n                    **kwargs,\n                )\n                o, _ = merge_state_v2_wrapper(\n                    o,\n                    softmax_lse.T.contiguous(),\n                    o_expand,\n                    softmax_lse_expand.T.contiguous(),\n                )\n            else:\n                o = result\n        else:\n            if (\n                forward_batch.attn_attend_prefix_cache is not None\n                and not forward_batch.forward_mode.is_target_verify()\n                and not forward_batch.forward_mode.is_draft_extend(include_v2=True)\n            ):\n                # Do multi-head attention with chunked prefix cache\n                if forward_batch.attn_attend_prefix_cache:\n                    assert not get_global_server_args().disable_chunked_prefix_cache\n                    # 
MHA for chunked prefix kv cache when running model with MLA\n                    assert forward_batch.prefix_chunk_idx is not None\n                    assert forward_batch.prefix_chunk_cu_seq_lens is not None\n                    assert forward_batch.prefix_chunk_max_seq_lens is not None\n\n                    chunk_idx = forward_batch.prefix_chunk_idx\n                    assert chunk_idx >= 0\n\n                    assert forward_batch.mha_return_lse\n                    output = flash_attn_varlen_func(\n                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),\n                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),\n                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),\n                        cu_seqlens_q=metadata.cu_seqlens_q,\n                        cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],\n                        max_seqlen_q=metadata.max_seq_len_q,\n                        max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],\n                        softmax_scale=layer.scaling,\n                        causal=False,\n                        return_softmax_lse=True,\n                        **kwargs,\n                    )\n                else:\n                    # MHA for extend part of sequence without attending prefix kv cache\n                    cu_seqlens_k = (\n                        metadata.cu_seqlens_q\n                        if not forward_batch.mha_one_shot\n                        else metadata.cu_seqlens_k\n                    )\n                    max_seqlen_k = (\n                        metadata.max_seq_len_q\n                        if not forward_batch.mha_one_shot\n                        else metadata.max_seq_len_k\n                    )\n                    output = flash_attn_varlen_func(\n                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),\n                        k=k.view(-1, 
layer.tp_k_head_num, layer.head_dim).to(q.dtype),\n                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),\n                        cu_seqlens_q=metadata.cu_seqlens_q,\n                        cu_seqlens_k=cu_seqlens_k,\n                        max_seqlen_q=metadata.max_seq_len_q,\n                        max_seqlen_k=max_seqlen_k,\n                        softmax_scale=layer.scaling,\n                        causal=True,\n                        return_softmax_lse=forward_batch.mha_return_lse,\n                        **kwargs,\n                    )\n                if forward_batch.mha_return_lse:\n                    output, lse, *rest = output\n                    lse = torch.transpose(lse, 0, 1).contiguous()\n                    return output, lse\n                return output\n            else:\n                assert self.fa_impl_ver in [3], \"Only FA3 support here\"\n                # Do absorbed multi-latent attention\n                kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(\n                    layer.layer_id\n                ).to(q.dtype)\n                k_rope = kv_cache[:, :, layer.v_head_dim :]\n                c_kv = kv_cache[:, :, : layer.v_head_dim]\n                k_rope_cache = k_rope.view(\n                    -1,\n                    self.page_size,\n                    layer.tp_k_head_num,\n                    layer.head_dim - layer.v_head_dim,\n                )\n                c_kv_cache = c_kv.view(\n                    -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim\n                )\n                if q_rope is not None:\n                    q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n                    q_rope = q_rope.view(\n                        -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim\n                    )\n                else:\n                    q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)\n               
     q_nope = q_all[:, :, : layer.v_head_dim]\n                    q_rope = q_all[:, :, layer.v_head_dim :]\n\n                result = flash_attn_with_kvcache(\n                    q=q_rope,\n                    k_cache=k_rope_cache,\n                    v_cache=c_kv_cache,\n                    qv=q_nope,\n                    page_table=page_table,\n                    cache_seqlens=cache_seqlens,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,\n                    max_seqlen_q=max_seqlen_q,\n                    softmax_scale=layer.scaling,\n                    causal=False if use_cascade_attn else causal,\n                    softcap=layer.logit_cap,\n                    k_descale=k_descale,\n                    v_descale=v_descale,\n                    return_softmax_lse=use_cascade_attn,\n                    num_splits=self.num_splits,\n                )\n                if use_cascade_attn:\n                    o, softmax_lse, *rest = result\n                    o_expand, softmax_lse_expand, *rest_expand = (\n                        flash_attn_with_kvcache(\n                            q=q_rope,\n                            k_cache=k_rope_cache,\n                            v_cache=c_kv_cache,\n                            qv=q_nope,\n                            page_table=self.forward_metadata_spec_decode_expand.page_table,\n                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,\n                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,\n                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,\n                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,\n                            softmax_scale=layer.scaling,\n                            causal=False,\n                            window_size=window_size,\n        
                    softcap=layer.logit_cap,\n                            k_descale=k_descale,\n                            v_descale=v_descale,\n                            return_softmax_lse=True,\n                            num_splits=self.num_splits,\n                        )\n                    )\n                    o, _ = merge_state_v2_wrapper(\n                        o,\n                        softmax_lse.T.contiguous(),\n                        o_expand,\n                        softmax_lse_expand.T.contiguous(),\n                    )\n                else:\n                    o = result\n\n        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n        # For multi-head latent attention\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        sinks: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if k is not None:\n            assert v is not None\n            if save_kv_cache:\n                cache_loc = (\n                    forward_batch.out_cache_loc\n                    if not layer.is_cross_attention\n                    else forward_batch.encoder_out_cache_loc\n                )\n                if not self.use_mla:\n                    forward_batch.token_to_kv_pool.set_kv_buffer(\n                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale\n                    )\n                else:\n                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(\n                        layer,\n                        cache_loc,\n                        k,\n                        k_rope,\n                    )\n\n        # Use precomputed metadata across all layers\n        metadata = self.forward_metadata\n        local_attn_metadata = 
getattr(metadata, \"local_attn_metadata\", None)\n        use_local_attn = (\n            self.has_local_attention\n            and self.attention_chunk_size is not None\n            and local_attn_metadata is not None\n            and (hasattr(layer, \"use_irope\") and layer.use_irope)\n        )\n\n        # When speculative decoding is enabled, forward_decode is called in one of two modes:\n        # 1. DRAFT_DECODE: cascade attention is enabled when top_k > 1\n        # 2. IDLE: cascade attention is not needed; spec_info is None in this case\n        use_cascade_attn = forward_batch.spec_info is not None and self.topk > 1\n\n        # Calculate window size (can be moved to metadata if layer properties don't change).\n        # We don't subtract 1 from layer.sliding_window_size here because\n        # model.get_attention_sliding_window_size() already does; the window is inclusive on both sides.\n        is_swa_layer = (\n            layer.sliding_window_size is not None and layer.sliding_window_size > -1\n        )\n        window_size = (layer.sliding_window_size, 0) if is_swa_layer else (-1, -1)\n\n        causal = True\n        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:\n            causal = False\n\n        kwargs = {}\n        if sinks is not None:\n            kwargs[\"sinks\"] = sinks\n\n        flash_attn_with_kvcache_base = flash_attn_with_kvcache_fa3\n\n        flash_attn_with_kvcache = (\n            flash_attn_with_kvcache_fa4\n            if self.fa_impl_ver == 4\n            else flash_attn_with_kvcache_base\n        )\n\n        k_descale, v_descale = None, None\n        # Only use KV scaling if: 1) fp8 KV cache is explicitly enabled, 2) RadixAttention\n        # has a corresponding quantization method so that layer.k_scale is not None, and\n        # 3) layer.head_dim <= 256, since the fa3 kernel requires fp16 or bf16 inputs when head_dim > 256.\n        if self.kv_cache_dtype_str != \"auto\" and layer.head_dim <= 256:\n            if layer.k_scale is not None:\n
                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)\n                k_descale = layer.k_scale.expand(descale_shape)\n                v_descale = layer.v_scale.expand(descale_shape)\n            q = q.to(self.kv_cache_dtype)\n            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None\n            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None\n        if not self.use_mla:\n            # Do multi-head attention\n\n            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(\n                layer.layer_id\n            )\n            key_cache = key_cache.view(\n                -1, self.page_size, layer.tp_k_head_num, layer.head_dim\n            )\n            value_cache = value_cache.view(\n                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim\n            )\n\n            if layer.is_cross_attention:\n                # Always use non-chunked logic for cross-attention\n                o = flash_attn_with_kvcache(\n                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),\n                    k_cache=key_cache,\n                    v_cache=value_cache,\n                    page_table=metadata.encoder_page_table,\n                    cache_seqlens=metadata.encoder_lens_int32,\n                    cu_seqlens_q=metadata.cu_seqlens_q,\n                    cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,\n                    max_seqlen_q=1,\n                    softmax_scale=layer.scaling,\n                    causal=False,\n                    window_size=(-1, -1),\n                    softcap=layer.logit_cap,\n                    k_descale=k_descale,\n                    v_descale=v_descale,\n                    num_splits=self.num_splits,\n                    **kwargs,\n                )\n            elif use_local_attn:\n                # Use chunked (local) attention batching for self-attention\n                o = 
flash_attn_with_kvcache(\n                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),\n                    k_cache=key_cache,\n                    v_cache=value_cache,\n                    page_table=local_attn_metadata.local_block_table,\n                    cache_seqlens=local_attn_metadata.local_seqused_k,\n                    cu_seqlens_q=local_attn_metadata.local_query_start_loc,\n                    cu_seqlens_k_new=None,\n                    max_seqlen_q=local_attn_metadata.local_max_query_len,\n                    softmax_scale=layer.scaling,\n                    causal=True,\n                    window_size=(-1, -1),\n                    softcap=layer.logit_cap,\n                    k_descale=k_descale,\n                    v_descale=v_descale,\n                    num_splits=self.num_splits,\n                    **kwargs,\n                )\n            else:\n                page_table = metadata.page_table\n                if is_swa_layer and self.use_sliding_window_kv_pool:\n                    if metadata.swa_page_table is not None:\n                        page_table = metadata.swa_page_table\n                    else:\n                        page_table = (\n                            self.token_to_kv_pool.translate_loc_from_full_to_swa(\n                                metadata.page_table\n                            )\n                        )\n                cache_seqlens = metadata.cache_seqlens_int32\n                cu_seqlens_k = metadata.cu_seqlens_k\n                max_seqlen_q = metadata.max_seq_len_q\n                q_reshaped = q.contiguous().view(\n                    -1, layer.tp_q_head_num, layer.head_dim\n                )\n\n                # Default: single-token self-attention\n                result = flash_attn_with_kvcache(\n                    q=q_reshaped,\n                    k_cache=key_cache,\n                    v_cache=value_cache,\n                    page_table=page_table,\n                 
   cache_seqlens=cache_seqlens,\n                    cu_seqlens_q=metadata.cu_seqlens_q,\n                    max_seqlen_q=max_seqlen_q,\n                    softmax_scale=layer.scaling,\n                    causal=False if use_cascade_attn else causal,\n                    window_size=window_size,\n                    softcap=layer.logit_cap,\n                    k_descale=k_descale,\n                    v_descale=v_descale,\n                    return_softmax_lse=use_cascade_attn,\n                    num_splits=self.num_splits,\n                    **kwargs,\n                )\n                if use_cascade_attn:\n                    o, softmax_lse, *rest = result\n                    o_expand, softmax_lse_expand, *rest_expand = (\n                        flash_attn_with_kvcache(\n                            q=q_reshaped,\n                            k_cache=key_cache,\n                            v_cache=value_cache,\n                            page_table=self.forward_metadata_spec_decode_expand.page_table,\n                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,\n                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,\n                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,\n                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,\n                            softmax_scale=layer.scaling,\n                            causal=False,\n                            window_size=window_size,\n                            softcap=layer.logit_cap,\n                            k_descale=k_descale,\n                            v_descale=v_descale,\n                            return_softmax_lse=True,\n                            num_splits=self.num_splits,\n                            **kwargs,\n                        )\n                    )\n                    o, _ = merge_state_v2(\n                
        o,\n                        softmax_lse.T.contiguous(),\n                        o_expand,\n                        softmax_lse_expand.T.contiguous(),\n                    )\n                else:\n                    o = result\n        else:\n            # Do absorbed multi-latent attention\n            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(\n                q.dtype\n            )\n            k_rope = kv_cache[:, :, layer.v_head_dim :]\n            c_kv = kv_cache[:, :, : layer.v_head_dim]\n            k_rope_cache = k_rope.view(\n                -1,\n                self.page_size,\n                layer.tp_k_head_num,\n                layer.head_dim - layer.v_head_dim,\n            )\n            c_kv_cache = c_kv.view(\n                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim\n            )\n\n            if q_rope is not None:\n                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n                q_rope = q_rope.view(\n                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim\n                )\n            else:\n                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)\n                q_nope = q_all[:, :, : layer.v_head_dim]\n                q_rope = q_all[:, :, layer.v_head_dim :]\n            max_seqlen_q = metadata.max_seq_len_q\n\n            result = flash_attn_with_kvcache(\n                q=q_rope,\n                k_cache=k_rope_cache,\n                v_cache=c_kv_cache,\n                qv=q_nope,\n                page_table=metadata.page_table,\n                cache_seqlens=metadata.cache_seqlens_int32,\n                cu_seqlens_q=metadata.cu_seqlens_q,\n                cu_seqlens_k_new=metadata.cu_seqlens_k,\n                max_seqlen_q=max_seqlen_q,\n                softmax_scale=layer.scaling,\n                causal=False if use_cascade_attn else causal,\n                softcap=layer.logit_cap,\n         
       k_descale=k_descale,\n                v_descale=v_descale,\n                return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states\n                num_splits=self.num_splits,\n            )\n            if use_cascade_attn:\n                o, softmax_lse, *rest = result\n                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(\n                    q=q_rope,\n                    k_cache=k_rope_cache,\n                    v_cache=c_kv_cache,\n                    qv=q_nope,\n                    page_table=self.forward_metadata_spec_decode_expand.page_table,\n                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,\n                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,\n                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,\n                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,\n                    softmax_scale=layer.scaling,\n                    causal=False,\n                    window_size=window_size,\n                    softcap=layer.logit_cap,\n                    k_descale=k_descale,\n                    v_descale=v_descale,\n                    return_softmax_lse=True,\n                    num_splits=self.num_splits,\n                )\n                o, _ = merge_state_v2(\n                    o,\n                    softmax_lse.T.contiguous(),\n                    o_expand,\n                    softmax_lse_expand.T.contiguous(),\n                )\n            else:\n                o = result\n\n        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        \"\"\"Initialize CUDA graph state for the attention backend.\n\n        Args:\n            max_bs (int): Maximum batch size to support in CUDA graphs\n\n        This creates fixed-size tensors that will be reused 
during CUDA graph replay\n        to avoid memory allocations.\n        \"\"\"\n        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size\n\n        # This is being used by normal decode and draft decode when topk == 1\n        self.decode_cuda_graph_metadata = {\n            \"cache_seqlens\": torch.zeros(max_bs, dtype=torch.int32, device=self.device),\n            \"cu_seqlens_q\": torch.arange(\n                0, max_bs + 1, dtype=torch.int32, device=self.device\n            ),\n            \"cu_seqlens_k\": torch.zeros(\n                max_bs + 1, dtype=torch.int32, device=self.device\n            ),\n            \"page_table\": torch.zeros(\n                max_bs,\n                max_num_pages,\n                dtype=torch.int32,\n                device=self.device,\n            ),\n            \"strided_indices\": torch.arange(\n                0, self.max_context_len, self.page_size, device=self.device\n            ),\n        }\n        # Only allocate local attention buffers if local attention is enabled\n        # This prevents OOM errors when local attention is not being used\n        if self.has_local_attention:\n            # Estimate maximum sizes for local attention metadata\n            max_seq_len = self.max_context_len\n            page_size = self.page_size or 1\n            attn_chunk_size = self.attention_chunk_size\n            max_virtual_batches = max_bs * (\n                (max_seq_len + attn_chunk_size - 1) // attn_chunk_size\n            )\n            max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size\n\n            self.decode_cuda_graph_local_attn_metadata = {\n                \"local_query_start_loc\": torch.zeros(\n                    max_virtual_batches + 1, dtype=torch.int32, device=self.device\n                ),\n                \"local_seqused_k\": torch.zeros(\n                    max_virtual_batches, dtype=torch.int32, device=self.device\n                ),\n                
\"local_block_table\": torch.zeros(\n                    max_virtual_batches,\n                    max_pages_per_block,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n            }\n\n        if self.use_sliding_window_kv_pool:\n            self.decode_cuda_graph_metadata[\"swa_page_table\"] = torch.zeros(\n                max_bs,\n                max_num_pages,\n                dtype=torch.int32,\n                device=self.device,\n            )\n\n        # This is used by draft decode's first half of metadata when topk > 1\n        if self.topk > 1:\n            self.draft_decode_metadata_topk_normal = {\n                \"cache_seqlens\": torch.zeros(\n                    max_bs, dtype=torch.int32, device=self.device\n                ),\n                \"cu_seqlens_q\": torch.arange(\n                    0,\n                    max_bs * self.topk + 1,\n                    step=self.topk,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"cu_seqlens_k\": torch.zeros(\n                    max_bs + 1, dtype=torch.int32, device=self.device\n                ),\n                \"page_table\": torch.zeros(\n                    max_bs,\n                    self.max_context_len,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n            }\n\n            # This is used by draft decode's second half of metadata when topk > 1\n            decode_length = self.speculative_step_id + 1\n            self.draft_decode_metadata_topk_expand = {\n                \"cache_seqlens\": torch.full(\n                    (max_bs * self.topk,),\n                    decode_length,\n                    device=self.device,\n                    dtype=torch.int32,\n                ),\n                \"cu_seqlens_q\": torch.arange(\n                    0,\n                    max_bs * self.topk + 1,\n               
     dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"cu_seqlens_k\": torch.arange(\n                    0,\n                    max_bs * self.topk * decode_length + 1,\n                    step=decode_length,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"page_table\": torch.zeros(\n                    max_bs * self.topk,\n                    decode_length + 1,  # Additional page for last partial page\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n            }\n\n        if (\n            self.speculative_num_draft_tokens is not None\n            and self.speculative_num_draft_tokens > 0\n        ):\n            # \"page_table_draft_decode\" is allocated only when speculative decoding is enabled, to save memory\n            self.decode_cuda_graph_metadata[\"page_table_draft_decode\"] = torch.zeros(\n                max_bs,\n                max_num_pages,\n                dtype=torch.int32,\n                device=self.device,\n            )\n\n            self.target_verify_metadata = {\n                \"cache_seqlens\": torch.zeros(\n                    max_bs, dtype=torch.int32, device=self.device\n                ),\n                \"cu_seqlens_q\": torch.arange(\n                    0,\n                    max_bs * self.speculative_num_draft_tokens + 1,\n                    step=self.speculative_num_draft_tokens,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"cu_seqlens_k\": torch.zeros(\n                    max_bs + 1, dtype=torch.int32, device=self.device\n                ),\n                \"page_table\": torch.zeros(\n                    max_bs,\n                    max_num_pages,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"strided_indices\": 
torch.arange(\n                    0, self.max_context_len, self.page_size, device=self.device\n                ),\n            }\n\n            self.draft_extend_metadata = {\n                \"cache_seqlens\": torch.zeros(\n                    max_bs, dtype=torch.int32, device=self.device\n                ),\n                \"cu_seqlens_q\": torch.zeros(\n                    max_bs + 1,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"cu_seqlens_k\": torch.zeros(\n                    max_bs + 1, dtype=torch.int32, device=self.device\n                ),\n                \"page_table\": torch.zeros(\n                    max_bs,\n                    max_num_pages,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"strided_indices\": torch.arange(\n                    0, self.max_context_len, self.page_size, device=self.device\n                ),\n            }\n\n            if self.use_sliding_window_kv_pool:\n                self.target_verify_metadata[\"swa_page_table\"] = torch.zeros(\n                    max_bs,\n                    max_num_pages,\n                    dtype=torch.int32,\n                    device=self.device,\n                )\n                self.draft_extend_metadata[\"swa_page_table\"] = torch.zeros(\n                    max_bs,\n                    max_num_pages,\n                    dtype=torch.int32,\n                    device=self.device,\n                )\n\n        if self.topk > 1:\n            self.target_verify_metadata_topk_normal = {\n                \"cache_seqlens\": torch.zeros(\n                    max_bs, dtype=torch.int32, device=self.device\n                ),\n                \"cu_seqlens_q\": torch.arange(\n                    0,\n                    max_bs * self.speculative_num_draft_tokens + 1,\n                    step=self.speculative_num_draft_tokens,\n                    
dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"cu_seqlens_k\": torch.zeros(\n                    max_bs + 1, dtype=torch.int32, device=self.device\n                ),\n                \"page_table\": torch.zeros(\n                    max_bs,\n                    self.max_context_len,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n            }\n\n            self.target_verify_metadata_topk_expand = {\n                \"cache_seqlens\": torch.zeros(\n                    max_bs * self.speculative_num_draft_tokens,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"cu_seqlens_k\": torch.zeros(\n                    max_bs * self.speculative_num_draft_tokens + 1,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"cu_seqlens_q\": torch.arange(\n                    0,\n                    max_bs * self.speculative_num_draft_tokens + 1,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"page_table\": torch.zeros(\n                    max_bs * self.speculative_num_draft_tokens,\n                    self.speculative_num_draft_tokens,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n            }\n\n            if self.has_swa:\n                self.target_verify_metadata_topk_swa = {\n                    \"cache_seqlens\": torch.zeros(\n                        max_bs * self.speculative_num_draft_tokens,\n                        dtype=torch.int32,\n                        device=self.device,\n                    ),\n                    \"cu_seqlens_k\": torch.zeros(\n                        max_bs * self.speculative_num_draft_tokens + 1,\n                        dtype=torch.int32,\n                        
device=self.device,\n                    ),\n                    \"cu_seqlens_q\": torch.arange(\n                        0,\n                        max_bs * self.speculative_num_draft_tokens + 1,\n                        dtype=torch.int32,\n                        device=self.device,\n                    ),\n                    \"page_table\": torch.zeros(\n                        max_bs * self.speculative_num_draft_tokens,\n                        self.max_context_len,\n                        dtype=torch.int32,\n                        device=self.device,\n                    ),\n                }\n\n        # Only allocate encoder metadata for encoder-decoder models\n        if self.is_encoder_decoder:\n            self.encoder_metadata = {\n                \"encoder_page_table\": torch.zeros(\n                    max_bs,\n                    self.max_context_len,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"encoder_lens_int32\": torch.zeros(\n                    max_bs, dtype=torch.int32, device=self.device\n                ),\n                \"encoder_cu_seqlens_k\": torch.zeros(\n                    max_bs + 1, dtype=torch.int32, device=self.device\n                ),\n            }\n        else:\n            # For decoder-only models, skip encoder_metadata allocation\n            self.encoder_metadata = {}\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        \"\"\"Initialize forward metadata for capturing CUDA graph.\"\"\"\n        metadata = FlashAttentionMetadata()\n\n        # metadata_expand is needed for Spec Decoding when top k > 1\n        metadata_expand = FlashAttentionMetadata()\n\n        device = 
seq_lens.device\n        if forward_mode.is_decode_or_idle():\n            if spec_info is not None:\n                # Draft Decode\n                if self.topk <= 1:\n                    # When topk = 1, we use the normal decode metadata\n                    metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[\n                        \"cache_seqlens\"\n                    ][:bs]\n                    metadata.max_seq_len_k = seq_lens.max().item() + (\n                        self.speculative_step_id + 1\n                    )\n                    metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[\n                        \"cu_seqlens_q\"\n                    ][: bs + 1]\n                    metadata.cu_seqlens_k = torch.nn.functional.pad(\n                        torch.cumsum(\n                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32\n                        ),\n                        (1, 0),\n                    )\n                    metadata.page_table = self.decode_cuda_graph_metadata[\n                        \"page_table_draft_decode\"\n                    ][:bs, :]\n                    if self.use_sliding_window_kv_pool:\n                        metadata.swa_page_table = self.decode_cuda_graph_metadata[\n                            \"swa_page_table\"\n                        ][:bs, :]\n                    self.decode_cuda_graph_metadata[bs] = metadata\n                else:\n                    # When top k > 1, we need two specific draft decode metadata, and then merge states\n                    # 1. 
The first half of metadata for prefix tokens\n                    metadata.cache_seqlens_int32 = (\n                        self.draft_decode_metadata_topk_normal[\"cache_seqlens\"][:bs]\n                    )\n                    metadata.max_seq_len_q = self.topk\n                    metadata.max_seq_len_k = seq_lens.max().item()\n                    metadata.cu_seqlens_q = self.draft_decode_metadata_topk_normal[\n                        \"cu_seqlens_q\"\n                    ][: bs + 1]\n                    metadata.cu_seqlens_k = self.draft_decode_metadata_topk_normal[\n                        \"cu_seqlens_k\"\n                    ][: bs + 1]\n                    metadata.page_table = self.draft_decode_metadata_topk_normal[\n                        \"page_table\"\n                    ][:bs, :]\n\n                    # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)\n                    metadata_expand.cache_seqlens_int32 = (\n                        self.draft_decode_metadata_topk_expand[\"cache_seqlens\"][\n                            : bs * self.topk\n                        ]\n                    )\n                    metadata_expand.max_seq_len_q = 1\n                    metadata_expand.cu_seqlens_q = (\n                        self.draft_decode_metadata_topk_expand[\"cu_seqlens_q\"][\n                            : bs * self.topk + 1\n                        ]\n                    )\n                    metadata_expand.cu_seqlens_k = (\n                        self.draft_decode_metadata_topk_expand[\"cu_seqlens_k\"][\n                            : bs * self.topk + 1\n                        ]\n                    )\n                    metadata_expand.page_table = self.draft_decode_metadata_topk_expand[\n                        \"page_table\"\n                    ][: bs * self.topk]\n                    self.draft_decode_metadata_topk_normal[bs] = metadata\n                    self.draft_decode_metadata_topk_expand[bs] = 
metadata_expand\n            else:\n                # Normal Decode\n                # Get sequence information\n                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)\n                batch_size = len(seq_lens)\n                device = seq_lens.device\n                metadata.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)\n                )\n                # Precompute maximum sequence length\n                metadata.max_seq_len_k = seq_lens.max().item()\n                # Precompute page table\n                metadata.page_table = self.decode_cuda_graph_metadata[\"page_table\"][\n                    :bs, :\n                ]\n                if self.use_sliding_window_kv_pool:\n                    metadata.swa_page_table = self.decode_cuda_graph_metadata[\n                        \"swa_page_table\"\n                    ][:bs, :]\n                # Precompute cumulative sequence lengths\n                metadata.cu_seqlens_q = torch.arange(\n                    0, batch_size + 1, dtype=torch.int32, device=device\n                )\n                self.decode_cuda_graph_metadata[bs] = metadata\n\n                self._maybe_update_local_attn_metadata_for_capture(metadata, batch_size)\n\n        elif forward_mode.is_target_verify():\n            if self.topk <= 1:\n                metadata.cache_seqlens_int32 = self.target_verify_metadata[\n                    \"cache_seqlens\"\n                ][:bs]\n                metadata.cache_seqlens_int32.copy_(\n                    (seq_lens + self.speculative_num_draft_tokens)\n                )\n\n                metadata.max_seq_len_q = self.speculative_num_draft_tokens\n                metadata.max_seq_len_k = (\n                    seq_lens.max().item() + self.speculative_num_draft_tokens\n                )\n\n                metadata.cu_seqlens_q = torch.arange(\n                    0,\n                    bs * 
self.speculative_num_draft_tokens + 1,\n                    self.speculative_num_draft_tokens,\n                    dtype=torch.int32,\n                    device=device,\n                )\n\n                metadata.cu_seqlens_k = self.target_verify_metadata[\"cu_seqlens_k\"][\n                    : (bs + 1)\n                ]\n\n                metadata.page_table = self.target_verify_metadata[\"page_table\"][:bs, :]\n\n                if self.use_sliding_window_kv_pool:\n                    metadata.swa_page_table = self.target_verify_metadata[\n                        \"swa_page_table\"\n                    ][:bs, :]\n\n                self.target_verify_metadata[bs] = metadata\n            else:\n                # When topk > 1, we need two specific target verify metadata, and then merge states\n                # 1. The first half of metadata for prefix tokens\n                metadata.cache_seqlens_int32 = self.target_verify_metadata_topk_normal[\n                    \"cache_seqlens\"\n                ][:bs]\n                metadata.max_seq_len_q = self.speculative_num_draft_tokens\n                # metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item(), do this in replay\n                metadata.cu_seqlens_q = self.target_verify_metadata_topk_normal[\n                    \"cu_seqlens_q\"\n                ][: bs + 1]\n                metadata.cu_seqlens_k = self.target_verify_metadata_topk_normal[\n                    \"cu_seqlens_k\"\n                ][: bs + 1]\n                metadata.page_table = self.target_verify_metadata_topk_normal[\n                    \"page_table\"\n                ][:bs, :]\n\n                # 2. 
The second half of metadata for draft tokens (per_batch_num_tokens = speculative_num_draft_tokens)\n                metadata_expand.cache_seqlens_int32 = (\n                    self.target_verify_metadata_topk_expand[\"cache_seqlens\"][\n                        : bs * self.speculative_num_draft_tokens\n                    ]\n                )\n                metadata_expand.max_seq_len_q = 1\n                metadata_expand.cu_seqlens_q = self.target_verify_metadata_topk_expand[\n                    \"cu_seqlens_q\"\n                ][: bs * self.speculative_num_draft_tokens + 1]\n                metadata_expand.cu_seqlens_k = self.target_verify_metadata_topk_expand[\n                    \"cu_seqlens_k\"\n                ][: bs * self.speculative_num_draft_tokens + 1]\n\n                metadata_expand.page_table = self.target_verify_metadata_topk_expand[\n                    \"page_table\"\n                ][: bs * self.speculative_num_draft_tokens]\n\n                self.target_verify_metadata_topk_normal[bs] = metadata\n                self.target_verify_metadata_topk_expand[bs] = metadata_expand\n\n                if self.has_swa:\n                    metadata_swa = FlashAttentionMetadata()\n                    metadata_swa.cache_seqlens_int32 = (\n                        self.target_verify_metadata_topk_swa[\"cache_seqlens\"][\n                            : bs * self.speculative_num_draft_tokens\n                        ]\n                    )\n                    metadata_swa.max_seq_len_q = 1\n                    metadata_swa.cu_seqlens_q = self.target_verify_metadata_topk_swa[\n                        \"cu_seqlens_q\"\n                    ][: bs * self.speculative_num_draft_tokens + 1]\n                    metadata_swa.cu_seqlens_k = self.target_verify_metadata_topk_swa[\n                        \"cu_seqlens_k\"\n                    ][: bs * self.speculative_num_draft_tokens + 1]\n\n                    metadata_swa.page_table = self.target_verify_metadata_topk_swa[\n            
            \"page_table\"\n                    ][: bs * self.speculative_num_draft_tokens]\n                    self.target_verify_metadata_topk_swa[bs] = metadata_swa\n                    metadata.swa_spec_metadata = metadata_swa\n\n        elif forward_mode.is_draft_extend(include_v2=True):\n            metadata.cache_seqlens_int32 = self.draft_extend_metadata[\"cache_seqlens\"][\n                :bs\n            ]\n            metadata.cache_seqlens_int32.copy_(seq_lens)\n\n            num_tokens_per_bs = num_tokens // bs\n            metadata.max_seq_len_q = num_tokens_per_bs\n            metadata.max_seq_len_k = seq_lens.max().item()\n\n            metadata.cu_seqlens_q = torch.arange(\n                0,\n                bs * num_tokens_per_bs + 1,\n                num_tokens_per_bs,\n                dtype=torch.int32,\n                device=device,\n            )\n\n            metadata.cu_seqlens_k = self.draft_extend_metadata[\"cu_seqlens_k\"][\n                : (bs + 1)\n            ]\n            metadata.page_table = self.draft_extend_metadata[\"page_table\"][:bs, :]\n\n            if self.use_sliding_window_kv_pool:\n                metadata.swa_page_table = self.draft_extend_metadata[\"swa_page_table\"][\n                    :bs, :\n                ]\n\n            self.draft_extend_metadata[bs] = metadata\n\n        if encoder_lens is not None:\n            encoder_bs = encoder_lens.numel()\n            metadata.encoder_lens_int32 = self.encoder_metadata[\"encoder_lens_int32\"][\n                :encoder_bs\n            ]\n            metadata.encoder_cu_seqlens_k = self.encoder_metadata[\n                \"encoder_cu_seqlens_k\"\n            ][: (encoder_bs + 1)]\n\n            metadata.encoder_page_table = self.encoder_metadata[\"encoder_page_table\"][\n                :bs, :\n            ]\n\n        self.forward_metadata = metadata\n        self.forward_metadata_spec_decode_expand = metadata_expand\n\n    def 
init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n        out_cache_loc: Optional[torch.Tensor] = None,\n    ):\n        \"\"\"Initialize forward metadata for replaying CUDA graph.\"\"\"\n        seq_lens = seq_lens[:bs]\n        seq_lens_cpu = seq_lens_cpu[:bs]\n        req_pool_indices = req_pool_indices[:bs]\n        device = seq_lens.device\n        metadata = None\n        metadata_expand = None\n\n        if forward_mode.is_decode_or_idle():\n\n            if spec_info is not None:\n                # Draft Decode\n                if self.topk <= 1:\n                    # When topk = 1, we use the normal decode metadata\n                    metadata = self.decode_cuda_graph_metadata[bs]\n                    max_len = seq_lens_cpu.max().item()\n                    metadata.max_seq_len_k = max_len + self.speculative_step_id + 1\n                    max_seq_pages = (\n                        metadata.max_seq_len_k + self.page_size - 1\n                    ) // self.page_size\n\n                    normal_decode_set_metadata(\n                        metadata.cache_seqlens_int32,\n                        metadata.cu_seqlens_k,\n                        metadata.page_table,\n                        self.req_to_token,\n                        req_pool_indices,\n                        self.decode_cuda_graph_metadata[\"strided_indices\"],\n                        max_seq_pages,\n                        seq_lens,\n                        self.speculative_step_id + 1,\n                        self.page_size,\n                        metadata.swa_page_table,\n                        (\n                            self.token_to_kv_pool\n                            if 
self.use_sliding_window_kv_pool\n                            else None\n                        ),\n                    )\n\n                else:\n                    # When top k > 1, we need two specific draft decode metadata, and then merge states\n                    # 1. The first half of metadata for prefix tokens\n                    metadata = self.draft_decode_metadata_topk_normal[bs]\n                    if self.page_size > 1:\n                        # First attention handles seq_lens - last_page_lens if page size > 1.\n                        last_page_lens = seq_lens % self.page_size\n                        seq_lens = seq_lens - last_page_lens\n                    metadata.cache_seqlens_int32.copy_(seq_lens)\n                    # metadata.max_seq_len_q = self.topk, already set in capture\n                    # metadata.cu_seqlens_q already set in capture\n                    # metadata.cu_seqlens_k is not needed\n\n                    metadata.max_seq_len_k = seq_lens_cpu.max().item()\n                    max_seq_pages = (\n                        metadata.max_seq_len_k + self.page_size - 1\n                    ) // self.page_size\n                    strided_indices = self.decode_cuda_graph_metadata[\"strided_indices\"]\n                    strided_indices = strided_indices[:max_seq_pages]\n                    page_table = (\n                        self.req_to_token[\n                            req_pool_indices[:, None],  # shape [bs, 1]\n                            strided_indices[None, :],  # shape [1, max_seq_pages]\n                        ]\n                        // self.page_size\n                    )\n                    metadata.page_table[:, :max_seq_pages].copy_(page_table)\n                    # 2. 
The second half of metadata for draft tokens (per_batch_num_tokens = topk)\n                    metadata_expand = self.draft_decode_metadata_topk_expand[bs]\n                    decode_length = self.speculative_step_id + 1\n                    # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]\n                    cache_loc = out_cache_loc.view(-1, self.speculative_num_steps)\n                    if self.page_size > 1:\n                        draft_decode_set_expand_metadata(\n                            cache_seqlens_int32=metadata_expand.cache_seqlens_int32,\n                            page_table=metadata_expand.page_table,\n                            last_page_lens=last_page_lens,\n                            decode_length=decode_length,\n                            cache_loc=cache_loc,\n                            topk=self.topk,\n                            page_size=self.page_size,\n                        )\n                    else:\n                        num_seqs = cache_loc.shape[0]\n                        metadata_expand.page_table[:num_seqs, :decode_length].copy_(\n                            cache_loc[:, :decode_length]\n                        )\n                # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported\n            else:\n                # Normal Decode\n                metadata = self.decode_cuda_graph_metadata[bs]\n                max_len = seq_lens_cpu.max().item()\n                max_seq_pages = (max_len + self.page_size - 1) // self.page_size\n                metadata.max_seq_len_k = max_len\n\n                normal_decode_set_metadata(\n                    metadata.cache_seqlens_int32,\n                    metadata.cu_seqlens_k,\n                    metadata.page_table,\n                    self.req_to_token,\n                    req_pool_indices,\n                    self.decode_cuda_graph_metadata[\"strided_indices\"],\n                    max_seq_pages,\n                    
seq_lens,\n                    0,\n                    self.page_size,\n                    metadata.swa_page_table,\n                    self.token_to_kv_pool if self.use_sliding_window_kv_pool else None,\n                )\n\n                self._maybe_update_local_attn_metadata_for_replay(\n                    metadata,\n                    bs,\n                )\n        elif forward_mode.is_target_verify():\n            if self.topk <= 1:\n                metadata = self.target_verify_metadata[bs]\n                metadata.cache_seqlens_int32.copy_(\n                    (seq_lens + self.speculative_num_draft_tokens)\n                )\n\n                metadata.max_seq_len_k = (\n                    seq_lens_cpu.max().item() + self.speculative_num_draft_tokens\n                )\n                metadata.cu_seqlens_k[1:].copy_(\n                    torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)\n                )\n                max_seq_pages = (\n                    metadata.max_seq_len_k + self.page_size - 1\n                ) // self.page_size\n                page_indices = self.req_to_token[\n                    req_pool_indices[:, None],\n                    self.decode_cuda_graph_metadata[\"strided_indices\"][:max_seq_pages],\n                ]\n                if (\n                    self.use_sliding_window_kv_pool\n                    and metadata.swa_page_table is not None\n                ):\n                    swa_page_indices = (\n                        self.token_to_kv_pool.translate_loc_from_full_to_swa(\n                            page_indices\n                        )\n                    )\n                    metadata.swa_page_table[:, :max_seq_pages].copy_(\n                        swa_page_indices // self.page_size\n                    )\n                page_indices //= self.page_size\n                metadata.page_table[:, :max_seq_pages].copy_(page_indices)\n            else:\n                # When topk > 
1, we need two specific target verify metadata, and then merge states\n                # 1. The first half of metadata for prefix tokens\n                metadata = self.target_verify_metadata_topk_normal[bs]\n                metadata.cache_seqlens_int32.copy_(seq_lens)\n                # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture\n                metadata.max_seq_len_k = seq_lens_cpu.max().item()\n                # metadata.cu_seqlens_q already set in capture\n                metadata.cu_seqlens_k[1:].copy_(\n                    torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)\n                )\n                max_seq_pages = (\n                    metadata.max_seq_len_k + self.page_size - 1\n                ) // self.page_size\n                page_indices = self.req_to_token[\n                    req_pool_indices[:, None],\n                    self.decode_cuda_graph_metadata[\"strided_indices\"][:max_seq_pages],\n                ]\n                page_indices //= self.page_size\n                metadata.page_table[:, :max_seq_pages].copy_(page_indices)\n\n                # 2. 
The second half of metadata for draft tokens (per_batch_num_tokens = speculative_num_draft_tokens)\n                metadata_expand = self.target_verify_metadata_topk_expand[bs]\n\n                # metadata_expand.max_seq_len_q = 1, already set in capture\n                # metadata_expand.cu_seqlens_q already set in capture\n                offsets = torch.arange(\n                    self.speculative_num_draft_tokens, device=device\n                ).unsqueeze(\n                    0\n                )  # shape: (1, self.speculative_num_draft_tokens)\n\n                cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)\n                cum_len = torch.nn.functional.pad(\n                    torch.cumsum(\n                        (\n                            seq_lens + self.speculative_num_draft_tokens\n                        ).repeat_interleave(self.speculative_num_draft_tokens),\n                        dim=0,\n                    ),\n                    (1, 0),\n                )[:-1]\n                mask_extraction_indices = (\n                    cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)\n                    + cum_len[:, None]\n                ).view(1, -1)\n                # avoid extracting padded seq indices, which would be out of bounds\n                mask_extraction_indices[\n                    :,\n                    spec_info.positions.numel() * self.speculative_num_draft_tokens :,\n                ].fill_(0)\n                mask = spec_info.custom_mask[mask_extraction_indices].view(\n                    -1, self.speculative_num_draft_tokens\n                )  # (bsz * draft_num, draft_num)\n\n                col_indices = offsets.expand(\n                    mask.shape[0], self.speculative_num_draft_tokens\n                )\n                keys = torch.where(\n                    mask,\n                    col_indices,\n                    col_indices + self.speculative_num_draft_tokens,\n                )\n                
_, sort_order = torch.sort(keys, dim=1)\n\n                non_masked_page_table = (\n                    self.req_to_token[req_pool_indices, :]\n                    .gather(1, cols)\n                    .repeat_interleave(self.speculative_num_draft_tokens, dim=0)\n                )  # (bsz * draft_num, draft_num)\n\n                metadata_expand.page_table.copy_(\n                    non_masked_page_table.gather(1, sort_order)\n                )\n                metadata_expand.cache_seqlens_int32.copy_(mask.sum(dim=1))\n                metadata_expand.cu_seqlens_k[1:].copy_(\n                    torch.cumsum(\n                        metadata_expand.cache_seqlens_int32,\n                        dim=0,\n                        dtype=torch.int32,\n                    )\n                )\n                if self.has_swa:\n                    metadata_swa = self.target_verify_metadata_topk_swa[bs]\n                    self._init_sliding_window_attn_spec_metadata(\n                        metadata, metadata_expand, metadata_swa\n                    )\n\n        elif forward_mode.is_draft_extend():\n            metadata = self.draft_extend_metadata[bs]\n            metadata.cache_seqlens_int32.copy_(seq_lens)\n\n            metadata.max_seq_len_k = seq_lens_cpu.max().item()\n            metadata.cu_seqlens_k[1:].copy_(\n                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)\n            )\n            accept_length = spec_info.accept_length[:bs]\n            if spec_info.accept_length_cpu:\n                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1\n            else:\n                metadata.max_seq_len_q = 1\n\n            metadata.cu_seqlens_q[1:].copy_(\n                torch.cumsum(accept_length, dim=0, dtype=torch.int32)\n            )\n\n            max_seq_pages = (\n                metadata.max_seq_len_k + self.page_size - 1\n            ) // self.page_size\n            page_indices = self.req_to_token[\n                
req_pool_indices[:, None],\n                self.draft_extend_metadata[\"strided_indices\"][:max_seq_pages],\n            ]\n            if self.use_sliding_window_kv_pool and metadata.swa_page_table is not None:\n                swa_page_indices = self.token_to_kv_pool.translate_loc_from_full_to_swa(\n                    page_indices\n                )\n                metadata.swa_page_table[:, :max_seq_pages].copy_(\n                    swa_page_indices // self.page_size\n                )\n            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)\n\n        elif forward_mode.is_draft_extend_v2():\n            metadata = self.draft_extend_metadata[bs]\n            metadata.cache_seqlens_int32.copy_(seq_lens)\n\n            metadata.max_seq_len_k = seq_lens_cpu.max().item()\n            metadata.cu_seqlens_k[1:].copy_(\n                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)\n            )\n\n            extend_seq_lens_tensor = getattr(spec_info, \"extend_seq_lens_tensor\", None)\n            extend_seq_lens_cpu = getattr(spec_info, \"extend_seq_lens_cpu\", None)\n            if extend_seq_lens_tensor is not None:\n                extend_seq_lens = extend_seq_lens_tensor.to(torch.int32)\n            elif extend_seq_lens_cpu is not None:\n                extend_seq_lens = torch.as_tensor(\n                    extend_seq_lens_cpu,\n                    dtype=torch.int32,\n                    device=device,\n                )\n            else:\n                default_extend = getattr(\n                    spec_info, \"num_tokens_per_req\", self.speculative_num_steps + 1\n                )\n                extend_seq_lens = torch.full(\n                    (bs,), default_extend, dtype=torch.int32, device=device\n                )\n                extend_seq_lens_cpu = [default_extend] * bs\n\n            if extend_seq_lens_cpu:\n                metadata.max_seq_len_q = int(max(extend_seq_lens_cpu))\n       
     else:\n                metadata.max_seq_len_q = getattr(\n                    spec_info, \"num_tokens_per_req\", self.speculative_num_steps + 1\n                )\n\n            metadata.cu_seqlens_q[1:].copy_(\n                torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32)\n            )\n\n            max_seq_pages = (\n                metadata.max_seq_len_k + self.page_size - 1\n            ) // self.page_size\n            page_indices = self.req_to_token[\n                req_pool_indices[:, None],\n                self.draft_extend_metadata[\"strided_indices\"][:max_seq_pages],\n            ]\n            if self.use_sliding_window_kv_pool and metadata.swa_page_table is not None:\n                swa_page_indices = self.token_to_kv_pool.translate_loc_from_full_to_swa(\n                    page_indices\n                )\n                metadata.swa_page_table[:, :max_seq_pages].copy_(\n                    swa_page_indices // self.page_size\n                )\n            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)\n\n        if encoder_lens is not None:\n            # Only support encoder size 1 for now\n            metadata.encoder_max_seq_len_k = encoder_lens[0]\n            metadata.encoder_lens_int32.copy_(encoder_lens[:1])\n            metadata.encoder_cu_seqlens_k[1:].copy_(\n                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32)\n            )\n\n            metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(\n                self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]\n            )\n\n            # Update the regular page table\n            page_table = self.req_to_token[\n                req_pool_indices,\n                metadata.encoder_max_seq_len_k : (\n                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k\n                ),\n            ]\n            metadata.page_table[:, : 
metadata.max_seq_len_k].copy_(page_table)\n\n        self.forward_metadata = metadata\n        self.forward_metadata_spec_decode_expand = metadata_expand\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        \"\"\"Get the fill value for sequence length in CUDA graph.\"\"\"\n        return 1\n\n    def _maybe_init_local_attn_metadata(\n        self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device\n    ):\n        \"\"\"Centralized utility to initialize local_attn_metadata if chunked attention is enabled.\"\"\"\n        if not self.has_local_attention:\n            metadata.local_attn_metadata = None\n            return\n\n        cu_seqlens_q = metadata.cu_seqlens_q\n        cache_seqlens_int32 = metadata.cache_seqlens_int32\n        if self.use_sliding_window_kv_pool:\n            page_table = self.token_to_kv_pool.translate_loc_from_full_to_swa(\n                metadata.page_table\n            )\n        else:\n            page_table = metadata.page_table\n        if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:\n            metadata.local_attn_metadata = None\n            return\n\n        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()\n        seq_lens_np = cache_seqlens_int32.cpu().numpy()\n        (\n            seqlens_q_local_np,\n            cu_seqlens_q_local_np,\n            seqlens_k_local_np,\n            block_table_local,\n        ) = make_local_attention_virtual_batches(\n            self.attention_chunk_size,\n            cu_seqlens_q_np,\n            seq_lens_np,\n            page_table,\n            self.page_size,\n        )\n\n        local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(\n            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),\n            local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),\n            local_block_table=block_table_local.to(device),\n            local_max_query_len=int(seqlens_q_local_np.max()),\n        
    local_max_seq_len=int(seqlens_k_local_np.max()),\n        )\n        metadata.local_attn_metadata = local_metadata\n\n    def _maybe_update_local_attn_metadata_for_capture(\n        self, metadata: FlashAttentionMetadata, bs: int\n    ):\n        \"\"\"Update local attention metadata during CUDA graph capture phase.\n\n        This method calculates the exact buffer sizes needed for local attention metadata\n        during the CUDA graph capture phase, optimizing memory usage by creating views of\n        pre-allocated buffers with exactly the sizes needed.\n        \"\"\"\n        if not self.has_local_attention:\n            return\n\n        seq_lens_capture = metadata.cache_seqlens_int32\n        max_seq_len = int(seq_lens_capture.max().item())\n        page_table_capture = metadata.page_table\n\n        cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()\n        seqlens_np = seq_lens_capture.cpu().numpy()\n        (\n            seqlens_q_local_np,\n            cu_seqlens_q_local_np,\n            seqlens_k_local_np,\n            block_table_local_np,\n        ) = make_local_attention_virtual_batches(\n            self.attention_chunk_size,\n            cu_seqlens_q_np,\n            seqlens_np,\n            page_table_capture,\n            self.page_size,\n        )\n\n        # Get exact dimensions from the calculation\n        q_len = len(cu_seqlens_q_local_np)\n        k_len = len(seqlens_k_local_np)\n        b0 = block_table_local_np.shape[0] if block_table_local_np.shape[0] > 0 else bs\n        b1 = block_table_local_np.shape[1] if block_table_local_np.shape[1] > 0 else 1\n\n        # Create views of the pre-allocated buffers with exactly these sizes\n        # This is the key optimization - we only use the memory we actually need\n        local_query_start_loc = self.decode_cuda_graph_local_attn_metadata[\n            \"local_query_start_loc\"\n        ][:q_len]\n\n        local_seqused_k = 
self.decode_cuda_graph_local_attn_metadata[\"local_seqused_k\"][\n            :k_len\n        ]\n\n        local_block_table = self.decode_cuda_graph_local_attn_metadata[\n            \"local_block_table\"\n        ][:b0, :b1]\n\n        metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(\n            local_query_start_loc=local_query_start_loc,\n            local_seqused_k=local_seqused_k,\n            local_block_table=local_block_table,\n            local_max_query_len=1,\n            local_max_seq_len=max_seq_len,\n        )\n\n    def _maybe_update_local_attn_metadata_for_replay(\n        self,\n        metadata: FlashAttentionMetadata,\n        bs: int,\n    ):\n        \"\"\"Update preallocated local attention metadata in-place before CUDA graph replay.\"\"\"\n        if not self.has_local_attention:\n            return\n\n        # Access preallocated buffers\n        local_q_buf = self.decode_cuda_graph_local_attn_metadata[\n            \"local_query_start_loc\"\n        ]\n        local_k_buf = self.decode_cuda_graph_local_attn_metadata[\"local_seqused_k\"]\n        local_block_buf = self.decode_cuda_graph_local_attn_metadata[\n            \"local_block_table\"\n        ]\n        cu_seqlens_q = self.decode_cuda_graph_metadata[\"cu_seqlens_q\"]\n\n        # Create a modified version for local attention that only processes the last token\n        # This mimics the normal decode pattern\n        cu_seqlens_q = torch.arange(\n            bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype\n        )\n        seqlens = metadata.cache_seqlens_int32[:bs]\n        # Slice the page_table to match the batch size and actual sequence length\n        # This serves three important purposes:\n        # 1. Ensures we only process the actual batch size (bs) and not the maximum batch size\n        # 2. Limits the sequence length to prevent processing padding tokens or garbage values\n        # 3. 
Prevents zeros in the block table which can cause garbage output during replay\n        #\n        # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices\n        # beyond the actual sequence length, leading to incorrect attention calculations\n        max_seq_len = int(seqlens.max().item())\n        if self.use_sliding_window_kv_pool:\n            sliced_page_table = self.token_to_kv_pool.translate_loc_from_full_to_swa(\n                metadata.page_table[:bs, :max_seq_len]\n            )\n        else:\n            sliced_page_table = metadata.page_table[:bs, :max_seq_len]\n\n        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()\n        seqlens_np = seqlens.cpu().numpy()\n        (\n            seqlens_q_local_np,\n            cu_seqlens_q_local_np,\n            seqlens_k_local_np,\n            block_table_local,\n        ) = make_local_attention_virtual_batches(\n            self.attention_chunk_size,\n            cu_seqlens_q_np,\n            seqlens_np,\n            sliced_page_table,\n            self.page_size,\n        )\n\n        # Convert back to tensors\n        device = local_q_buf.device\n        cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)\n        seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)\n        block_table_local = block_table_local.to(device)\n        # Get sizes\n        q_len = cu_seqlens_q_local.shape[0]\n        k_len = seqlens_k_local.shape[0]\n        b0, b1 = block_table_local.shape\n\n        # In-place updates into preallocated tensors and zero out the unused space\n        local_q_buf[:q_len].copy_(cu_seqlens_q_local)\n        local_q_buf[q_len:].fill_(0)\n        local_k_buf[:k_len].copy_(seqlens_k_local)\n        local_k_buf[k_len:].fill_(0)\n        local_block_buf[:b0, :b1].copy_(block_table_local)\n        local_block_buf[b0:, :].fill_(0)\n        local_block_buf[:b0, b1:].fill_(0)\n\n        if metadata.local_attn_metadata is not None:\n         
   lam = metadata.local_attn_metadata\n            lam.local_max_query_len = int(seqlens_q_local_np.max())\n            lam.local_max_seq_len = int(seqlens_k_local_np.max())\n\n    def _init_sliding_window_attn_spec_metadata(\n        self,\n        metadata: FlashAttentionMetadata,\n        metadata_expand: FlashAttentionMetadata,\n        metadata_swa: Optional[FlashAttentionMetadata] = None,\n    ):\n        # TODO: support page_size > 1 for swa spec\n        assert (\n            self.page_size == 1\n        ), \"FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 sliding window attention\"\n\n        cache_seqlens_int32 = (\n            metadata.cache_seqlens_int32.repeat_interleave(\n                self.speculative_num_draft_tokens\n            )\n            + metadata_expand.cache_seqlens_int32\n        )\n        cu_seqlens_k = torch.nn.functional.pad(\n            torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)\n        )\n        bs = cache_seqlens_int32.shape[0]\n        page_table = (\n            metadata.page_table.new_zeros(\n                (bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])\n            )\n            if metadata_swa is None\n            else metadata_swa.page_table\n        )\n\n        prepare_swa_spec_page_table_triton(\n            page_table,\n            metadata.page_table,\n            metadata_expand.page_table,\n            metadata.cache_seqlens_int32,\n            metadata_expand.cache_seqlens_int32,\n            self.speculative_num_draft_tokens,\n        )\n\n        if metadata_swa is None:\n            metadata_swa = FlashAttentionMetadata()\n            metadata_swa.max_seq_len_q = 1\n            metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q\n            metadata_swa.cache_seqlens_int32 = cache_seqlens_int32\n            metadata_swa.cu_seqlens_k = cu_seqlens_k\n            metadata_swa.page_table = page_table\n        else:\n       
     metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)\n            metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)\n\n        metadata.swa_spec_metadata = metadata_swa\n\n\n@triton.jit\ndef _prepare_swa_spec_page_table_kernel(\n    dst_ptr,\n    src_a_ptr,\n    src_b_ptr,\n    seq_len_a_ptr,\n    seq_len_b_ptr,\n    dst_stride_m,\n    dst_stride_n,\n    a_stride_m,\n    a_stride_n,\n    b_stride_m,\n    b_stride_n,\n    LEN_A: tl.constexpr,\n    LEN_B: tl.constexpr,\n    REPEAT_STEP: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    pid_m = tl.program_id(0)\n    pid_n = tl.program_id(1)\n\n    idx_a = pid_m // REPEAT_STEP\n    idx_b = pid_m\n    seq_len_a = tl.load(seq_len_a_ptr + idx_a)\n    seq_len_b = tl.load(seq_len_b_ptr + idx_b)\n\n    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n    total_len = seq_len_a + seq_len_b\n\n    if pid_n * BLOCK_N >= total_len:\n        return\n\n    mask = offs_n < total_len\n    dst = dst_ptr + pid_m * dst_stride_m + offs_n * dst_stride_n\n\n    if (pid_n + 1) * BLOCK_N < seq_len_a:\n        a_ptr = src_a_ptr + idx_a * a_stride_m + offs_n * a_stride_n\n        a_mask = mask & (offs_n < LEN_A)\n        val = tl.load(a_ptr, mask=a_mask, other=0)\n        tl.store(dst, val, mask=mask)\n    elif pid_n * BLOCK_N >= seq_len_a:\n        offs_b = offs_n - seq_len_a\n        b_ptr = src_b_ptr + idx_b * b_stride_m + offs_b * b_stride_n\n        b_mask = mask & (offs_b < LEN_B)\n        val = tl.load(b_ptr, mask=b_mask, other=0)\n        tl.store(dst, val, mask=mask)\n    else:\n        # mixed part\n        a_offs = offs_n\n        a_mask = (a_offs < seq_len_a) & (a_offs < LEN_A)\n        a_ptr = src_a_ptr + idx_a * a_stride_m + a_offs * a_stride_n\n        a_val = tl.load(a_ptr, mask=a_mask, other=0)\n\n        b_offs = offs_n - seq_len_a\n        b_mask = (b_offs >= 0) & (b_offs < seq_len_b) & (b_offs < LEN_B)\n        b_ptr = src_b_ptr + idx_b * b_stride_m + b_offs * b_stride_n\n        b_val = tl.load(b_ptr, 
mask=b_mask, other=0)\n\n        result = tl.where(offs_n < seq_len_a, a_val, b_val)\n        tl.store(dst, result, mask=mask)\n\n\ndef prepare_swa_spec_page_table_triton(\n    page_table_dst: torch.Tensor,\n    page_table_a: torch.Tensor,\n    page_table_b: torch.Tensor,  # expand page table\n    seq_len_a: torch.Tensor,\n    seq_len_b: torch.Tensor,  # expand seq lens\n    speculative_num_draft_tokens: int,\n):\n    # concat page_table and expand page_table by kv seq length\n    bs = seq_len_a.numel()\n    bs_expand = seq_len_b.numel()\n    assert bs_expand == bs * speculative_num_draft_tokens\n\n    LEN_A = page_table_a.shape[1]\n    LEN_B = page_table_b.shape[1]\n    LEN_OUT = LEN_A + LEN_B\n    REPEAT_STEP = speculative_num_draft_tokens\n    BLOCK_N = 256\n\n    grid = (bs_expand, triton.cdiv(LEN_OUT, BLOCK_N))\n    _prepare_swa_spec_page_table_kernel[grid](\n        page_table_dst,\n        page_table_a,\n        page_table_b,\n        seq_len_a,\n        seq_len_b,\n        page_table_dst.stride(0),\n        page_table_dst.stride(1),\n        page_table_a.stride(0),\n        page_table_a.stride(1),\n        page_table_b.stride(0),\n        page_table_b.stride(1),\n        LEN_A=LEN_A,\n        LEN_B=LEN_B,\n        REPEAT_STEP=REPEAT_STEP,\n        BLOCK_N=BLOCK_N,\n        num_warps=4,\n    )\n\n\nclass FlashAttentionMultiStepBackend:\n\n    def __init__(\n        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int\n    ):\n        self.model_runner = model_runner\n        self.topk = topk\n        self.speculative_num_steps = speculative_num_steps\n        self.attn_backends = []\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends.append(\n                FlashAttentionBackend(\n                    model_runner,\n                    speculative_step_id=i,\n                    topk=self.topk,\n                    speculative_num_steps=self.speculative_num_steps,\n                )\n            )\n\n    
def init_forward_metadata(self, forward_batch: ForwardBatch):\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_forward_metadata(forward_batch)\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        forward_batch: ForwardBatch,\n    ):\n        assert forward_batch.spec_info is not None\n        assert forward_batch.spec_info.is_draft_input()\n\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(\n                forward_batch.batch_size,\n                forward_batch.batch_size * self.topk,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                encoder_lens=forward_batch.encoder_lens,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n            )\n\n    def init_forward_metadata_replay_cuda_graph(\n        self, forward_batch: ForwardBatch, bs: int\n    ):\n        assert forward_batch.spec_info is not None\n        assert forward_batch.spec_info.is_draft_input()\n\n        for i in range(self.speculative_num_steps - 1):\n            # TODO: incrementally update the metadata for the later steps,\n            # so that they do not need to recompute everything from scratch.\n            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(\n                bs,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                forward_batch.seq_lens_sum,\n                encoder_lens=forward_batch.encoder_lens,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n                
seq_lens_cpu=forward_batch.seq_lens_cpu,\n                out_cache_loc=forward_batch.out_cache_loc,\n            )\n\n\n# @torch.compile(dynamic=True, backend=get_compiler_backend())\n# TODO: fuse these kernels\n# NOTE: torch.compile makes it slower in speculative decoding\ndef normal_decode_set_metadata(\n    cache_seqlens_int32: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    page_table: torch.Tensor,\n    req_to_token: torch.Tensor,\n    req_pool_indices: torch.Tensor,\n    strided_indices: torch.Tensor,\n    max_seq_pages: torch.Tensor,\n    seq_lens: torch.Tensor,\n    seq_len_delta: int,\n    page_size: int,\n    swa_page_table: Optional[torch.Tensor] = None,\n    token_to_kv_pool: Optional[SWAKVPool] = None,\n):\n    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)\n    cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))\n    page_indices = req_to_token[\n        req_pool_indices[:, None],\n        strided_indices[:max_seq_pages][None, :],\n    ]\n    page_table[:, :max_seq_pages].copy_(page_indices // page_size)\n\n    if swa_page_table is not None and token_to_kv_pool is not None:\n        assert isinstance(token_to_kv_pool, SWAKVPool)\n        swa_page_indices = token_to_kv_pool.translate_loc_from_full_to_swa(page_indices)\n        swa_page_table[:, :max_seq_pages].copy_(swa_page_indices // page_size)\n\n\n@torch.compile(dynamic=True, backend=get_compiler_backend())\ndef draft_decode_set_expand_metadata(\n    cache_seqlens_int32: torch.Tensor,  # Modifies\n    page_table: torch.Tensor,  # Modifies\n    last_page_lens: torch.Tensor,\n    decode_length: int,\n    cache_loc: torch.Tensor,\n    topk: int,\n    page_size: int,\n):\n    expanded_last_page_lens = last_page_lens.repeat_interleave(topk)\n    cache_seqlens_int32.copy_(decode_length + expanded_last_page_lens)\n    cache_loc = (cache_loc // page_size).to(torch.int32)\n    if cache_loc.dim() == 1:\n        cache_loc = cache_loc.unsqueeze(0)\n    # Vectorized 
torch.unique_consecutive: track value change points then scatter\n    mask = torch.ones_like(cache_loc, dtype=torch.bool)\n    mask[:, 1:] = cache_loc[:, 1:] != cache_loc[:, :-1]\n    positions = mask.cumsum(dim=1) - 1\n    num_seqs = cache_loc.shape[0]\n    page_table[:num_seqs, :].scatter_(1, positions, cache_loc)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/flashinfer_backend.py",
    "content": "from __future__ import annotations\n\n\"\"\"\nSupport different attention backends.\nNow there are two backends: FlashInfer and Triton.\nFlashInfer is faster and Triton is easier to customize.\nEach backend supports two operators: extend (i.e. prefill with cached prefix) and decode.\n\"\"\"\n\nimport logging\nimport os\nfrom dataclasses import dataclass\nfrom enum import Enum, auto\nfrom functools import partial\nfrom typing import TYPE_CHECKING, Callable, List, Optional, Union\n\nimport torch\n\nfrom sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph\nfrom sglang.srt.dllm.config import DllmConfig\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton\nfrom sglang.srt.layers.dp_attention import get_attention_tp_size\nfrom sglang.srt.layers.radix_attention import AttentionType\nfrom sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.speculative.spec_info import SpecInput\nfrom sglang.srt.utils import (\n    get_int_env_var,\n    is_flashinfer_available,\n    is_sm100_supported,\n    next_power_of_2,\n)\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n\nlogger = logging.getLogger(__name__)\n\nif envs.SGLANG_ENABLE_TORCH_COMPILE.get():\n    torch._logging.set_logs(dynamo=logging.ERROR)\n    torch._dynamo.config.suppress_errors = True\n\n\nif is_flashinfer_available():\n    from flashinfer import (\n        BatchDecodeWithPagedKVCacheWrapper,\n        BatchPrefillWithPagedKVCacheWrapper,\n        BatchPrefillWithRaggedKVCacheWrapper,\n        fast_decode_plan,\n    )\n    from flashinfer.cascade import merge_state\n\n\nclass WrapperDispatch(Enum):\n    
SLIDING_WINDOW = auto()\n    CROSS_ATTENTION = auto()\n\n\n@dataclass\nclass MultiItemScoringParams:\n    \"\"\"Parameters for multi-item scoring in attention computation.\n\n    Used when processing sequences with multiple items separated by delimiters,\n    where each item needs specific attention patterns that respect item boundaries.\n\n    Attributes:\n        prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.\n                       The tensor size is equal to the batch size.\n        token_pos_in_items_ptr: A uint16 1D tensor indicating the token position of each item\n                               starting from 0 (delimiter) for each item. For batch size > 1,\n                               sequences are concatenated with zero padding to ensure same length.\n        token_pos_in_items_len: Zero padding length for token_pos_in_items_ptr to handle\n                               batch_size > 1 case. Defines the padded length for each sequence.\n        max_item_len_ptr: A uint16 tensor containing the max token length of all items\n                         for each prompt in the batch.\n\n    \"\"\"\n\n    prefix_len_ptr: Optional[torch.Tensor] = None\n    token_pos_in_items_ptr: Optional[torch.Tensor] = None\n    token_pos_in_items_len: int = 0\n    max_item_len_ptr: Optional[torch.Tensor] = None\n\n    def is_enabled(self) -> bool:\n        \"\"\"Check if multi-item scoring is enabled.\"\"\"\n        return self.prefix_len_ptr is not None\n\n\n@dataclass\nclass DecodeMetadata:\n    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]\n\n\n@dataclass\nclass PrefillMetadata:\n    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]\n    use_ragged: bool\n    extend_no_prefix: bool\n    multi_item_params: Optional[MultiItemScoringParams] = None\n\n\n# Reuse this workspace buffer across all flashinfer wrappers\nglobal_workspace_buffer = None\n\n# Use as a fast path to override the indptr in flashinfer's plan function\n# 
This is used to remove some host-to-device copy overhead.\nglobal_override_indptr_cpu = None\n\n\nclass FlashInferAttnBackend(AttentionBackend):\n    \"\"\"Flashinfer attention kernels.\"\"\"\n\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        skip_prefill: bool = False,\n        kv_indptr_buf: Optional[torch.Tensor] = None,\n        kv_last_page_len_buf: Optional[torch.Tensor] = None,\n        init_new_workspace: bool = False,\n    ):\n        super().__init__()\n        self.prefill_backend = \"fa2\"\n        self.decode_backend = \"fa2\"\n\n        # Store multi-item scoring delimiter for efficient access\n        self.multi_item_scoring_delimiter = (\n            model_runner.server_args.multi_item_scoring_delimiter\n        )\n\n        # FIXME: remove dllm workarounds from flashinfer\n        self.dllm_config = DllmConfig.from_server_args(model_runner.server_args)\n        self.is_dllm_model = self.dllm_config is not None\n\n        # Parse constants\n        self.decode_use_tensor_cores = should_use_tensor_core(\n            kv_cache_dtype=model_runner.kv_cache_dtype,\n            num_attention_heads=model_runner.model_config.num_attention_heads\n            // get_attention_tp_size(),\n            num_kv_heads=model_runner.model_config.get_num_kv_heads(\n                get_attention_tp_size()\n            ),\n        )\n        self.max_context_len = model_runner.model_config.context_len\n        self.skip_prefill = skip_prefill\n        self.is_multimodal = model_runner.model_config.is_multimodal\n\n        assert not (\n            model_runner.sliding_window_size is not None\n            and model_runner.model_config.is_encoder_decoder\n        ), \"Sliding window and cross attention are not supported together\"\n\n        if model_runner.sliding_window_size is not None:\n            self.num_wrappers = 2\n            self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW\n        elif 
model_runner.model_config.is_encoder_decoder:\n            self.num_wrappers = 2\n            self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION\n        else:\n            self.num_wrappers = 1\n            self.dispatch_reason = None\n\n        # Qwen2/Qwen3 models require higher flashinfer workspace size\n        if (\n            \"Qwen2ForCausalLM\" in model_runner.model_config.hf_config.architectures\n            or \"Qwen3ForCausalLM\" in model_runner.model_config.hf_config.architectures\n            or \"MiMoForCausalLM\" in model_runner.model_config.hf_config.architectures\n            or \"Qwen3VLForConditionalGeneration\"\n            in model_runner.model_config.hf_config.architectures\n            or \"Qwen3VLMoeForConditionalGeneration\"\n            in model_runner.model_config.hf_config.architectures\n        ):\n            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(512 * 1024 * 1024)\n\n        # When deterministic inference is enabled, tensor cores should be used for decode\n        # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph\n        # More information can be found here: https://github.com/flashinfer-ai/flashinfer/pull/1675\n        self.enable_deterministic = (\n            model_runner.server_args.enable_deterministic_inference\n        )\n        self.prefill_split_tile_size = None\n        self.decode_split_tile_size = None\n        self.disable_cuda_graph_kv_split = False\n        if self.enable_deterministic:\n            self.decode_use_tensor_cores = True\n            self.prefill_split_tile_size = get_int_env_var(\n                \"SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE\", 4096\n            )\n            self.decode_split_tile_size = get_int_env_var(\n                \"SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE\", 2048\n            )\n            self.disable_cuda_graph_kv_split = True\n            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(2048 * 1024 * 
1024)\n\n        # Allocate buffers\n        global global_workspace_buffer\n        if global_workspace_buffer is None:\n            # different from flashinfer zero_init_global_workspace_buffer\n            global_workspace_size = envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get()\n            global_workspace_buffer = torch.empty(\n                global_workspace_size,\n                dtype=torch.uint8,\n                device=model_runner.device,\n            )\n        if init_new_workspace:\n            self.workspace_buffer = torch.empty(\n                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),\n                dtype=torch.uint8,\n                device=model_runner.device,\n            )\n        else:\n            self.workspace_buffer = global_workspace_buffer\n        max_bs = model_runner.req_to_token_pool.size\n        if kv_indptr_buf is None:\n            self.kv_indptr = [\n                torch.zeros(\n                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device\n                )\n                for _ in range(self.num_wrappers)\n            ]\n        else:\n            assert self.num_wrappers == 1\n            self.kv_indptr = [kv_indptr_buf]\n\n        if kv_last_page_len_buf is None:\n            self.kv_last_page_len = torch.ones(\n                (max_bs,), dtype=torch.int32, device=model_runner.device\n            )\n        else:\n            assert self.num_wrappers == 1\n            self.kv_last_page_len = kv_last_page_len_buf\n\n        if not self.skip_prefill:\n            self.qo_indptr = [\n                torch.zeros(\n                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device\n                )\n                for _ in range(self.num_wrappers)\n            ]\n\n        fmha_backend = \"auto\"\n        if is_sm100_supported():\n            # Disable CUTLASS backend when piecewise cuda graph is enabled\n            # due to TMA descriptor initialization issues on B200\n            if not 
model_runner.server_args.disable_piecewise_cuda_graph:\n                logger.warning(\n                    \"CUTLASS backend is disabled when piecewise cuda graph is enabled \"\n                    \"due to TMA descriptor initialization issues on B200. \"\n                    \"Using auto backend instead for stability.\"\n                )\n            else:\n                fmha_backend = \"cutlass\"\n        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(\n            self.workspace_buffer, \"NHD\", backend=fmha_backend\n        )\n\n        # Two wrappers: one for sliding window attention and one for full attention.\n        # Using two wrappers is unnecessary in the current PR, but they are prepared for future PRs\n        self.prefill_wrappers_paged = []\n        self.prefill_wrappers_verify = []\n        self.decode_wrappers = []\n        for _ in range(self.num_wrappers):\n            if not skip_prefill:\n                self.prefill_wrappers_paged.append(\n                    BatchPrefillWithPagedKVCacheWrapper(\n                        self.workspace_buffer,\n                        \"NHD\",\n                        backend=self.prefill_backend,\n                    )\n                )\n                self.prefill_wrappers_verify.append(\n                    BatchPrefillWithPagedKVCacheWrapper(\n                        self.workspace_buffer,\n                        \"NHD\",\n                        backend=self.prefill_backend,\n                    )\n                )\n            self.decode_wrappers.append(\n                BatchDecodeWithPagedKVCacheWrapper(\n                    self.workspace_buffer,\n                    \"NHD\",\n                    backend=self.decode_backend,\n                    use_tensor_cores=self.decode_use_tensor_cores,\n                )\n            )\n\n        # Create indices updater\n        if not skip_prefill:\n            self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(\n           
     model_runner, self\n            )  # for verify\n        self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)\n\n        # Other metadata\n        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None\n\n        self.decode_cuda_graph_metadata = {}\n        self.prefill_cuda_graph_metadata = {}  # For verify\n        self.draft_extend_cuda_graph_metadata = {}  # For draft extend\n\n    def _process_multi_item_scoring(\n        self, forward_batch: ForwardBatch\n    ) -> MultiItemScoringParams:\n        \"\"\"Process multi-item scoring tensors for FlashInfer attention.\n\n        This method handles sequences containing multiple \"items\" separated by delimiter tokens,\n        where each item needs specific attention patterns that respect item boundaries.\n\n        The method produces four key tensors for FlashInfer:\n        - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch\n        - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters\n        - token_pos_in_items_len: padding length for batch processing\n        - max_item_len_ptr: uint16 tensor with max item length for each prompt\n\n        Args:\n            forward_batch: The forward batch containing input sequences and delimiter info\n\n        Returns:\n            MultiItemScoringParams: The processed multi-item scoring parameters\n\n        Examples:\n            Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively:\n            token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]\n\n            Case 1: Single sequence\n            Text: \"What is the capital of France? 
<delim> London <delim> Paris <delim> Berlin <delim>\"\n            Tokens: [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]\n            Indices: [ 0,   1,  2,   3,      4,  5,     6,   7,     8,      9,     10,    11,    12,     13]\n            - prefix_len_ptr: [7] (query length before first delimiter)\n            - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)\n            - token_pos_in_items_len: 7 (actual length)\n            - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)\n\n            Case 2: Batch processing (batch_size=2)\n            Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)\n            Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)\n            After padding both to length 10:\n            - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0,    0, 1, 0, 1, 2, 3, 0, 1, 2, 0]\n            - token_pos_in_items_len: 10 (padded length for batch processing)\n            - max_item_len_ptr: [2, 3] (max lengths per sequence)\n        \"\"\"\n\n        delimiter = self.multi_item_scoring_delimiter\n        if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:\n            return MultiItemScoringParams()\n\n        delimiter_mask = forward_batch.input_ids == delimiter\n        prefix_cache_lens = getattr(forward_batch, \"extend_prefix_lens\", None)\n        extend_seq_lens = getattr(forward_batch, \"extend_seq_lens\", None)\n        prefix_len_ptr, token_pos_in_items_ptr = [], []\n        token_pos_in_items_len = 0\n\n        # If no extend_seq_lens, treat whole batch as one sequence\n        if extend_seq_lens is None or len(extend_seq_lens) <= 1:\n            extend_seq_lens = [forward_batch.input_ids.size(0)]\n\n        seq_start = 0\n        for i, seq_len in enumerate(extend_seq_lens):\n            seq_end = 
seq_start + seq_len\n            mask = delimiter_mask[seq_start:seq_end]\n            pos = forward_batch.positions[seq_start:seq_end]\n            delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]\n\n            if len(delimiter_indices) > 0:\n                first_delim = delimiter_indices[0]\n                # Prefix length: store as scalar\n                prefix_len = first_delim + (\n                    prefix_cache_lens[i] if prefix_cache_lens is not None else 0\n                )\n                prefix_len_ptr.append(\n                    prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len\n                )\n\n                # Compute relative positions within items after delimiters\n                diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]\n                token_pos = (diff - pos[first_delim]).to(torch.uint16)\n                token_pos_in_items_ptr.append(token_pos)\n\n                # Update forward_batch positions in-place\n                pos[first_delim:] = diff - 1\n                forward_batch.positions[seq_start:seq_end] = pos\n\n            seq_start = seq_end\n\n        # Pad token_pos_in_items_ptr for batch processing\n        if token_pos_in_items_ptr:\n            token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)\n            device = forward_batch.input_ids.device\n            token_pos_in_items_ptr = [\n                torch.cat(\n                    [\n                        t,\n                        torch.zeros(\n                            token_pos_in_items_len - t.numel(),\n                            dtype=torch.uint16,\n                            device=device,\n                        ),\n                    ]\n                )\n                for t in token_pos_in_items_ptr\n            ]\n\n        if not prefix_len_ptr or not token_pos_in_items_ptr:\n            return MultiItemScoringParams()\n\n        # Build final params\n        device = 
forward_batch.input_ids.device\n        return MultiItemScoringParams(\n            prefix_len_ptr=torch.tensor(\n                prefix_len_ptr, dtype=torch.uint32, device=device\n            ),\n            token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),\n            token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,\n            max_item_len_ptr=torch.stack(\n                [\n                    t.to(torch.int32).max().to(torch.uint16)\n                    for t in token_pos_in_items_ptr\n                ],\n                dim=0,\n            ),\n        )\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        if forward_batch.forward_mode.is_decode_or_idle():\n            self.indices_updater_decode.update(\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                forward_batch.seq_lens_cpu,\n                forward_batch.seq_lens_sum,\n                decode_wrappers=self.decode_wrappers,\n                encoder_lens=forward_batch.encoder_lens,\n                spec_info=forward_batch.spec_info,\n                fixed_split_size=self.decode_split_tile_size,\n                disable_split_kv=False,\n            )\n            self.forward_metadata = DecodeMetadata(self.decode_wrappers)\n        elif forward_batch.forward_mode.is_draft_extend():\n            self.indices_updater_prefill.update(\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                forward_batch.seq_lens_cpu,\n                forward_batch.seq_lens_sum,\n                prefix_lens=None,\n                prefill_wrappers=self.prefill_wrappers_paged,\n                use_ragged=False,\n                encoder_lens=forward_batch.encoder_lens,\n                spec_info=forward_batch.spec_info,\n            )\n            self.forward_metadata = PrefillMetadata(\n                self.prefill_wrappers_paged, False, False\n            )\n        
elif forward_batch.forward_mode.is_target_verify():\n            self.indices_updater_prefill.update(\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                forward_batch.seq_lens_cpu,\n                forward_batch.seq_lens_sum,\n                prefix_lens=None,\n                prefill_wrappers=self.prefill_wrappers_verify,\n                use_ragged=False,\n                encoder_lens=forward_batch.encoder_lens,\n                spec_info=forward_batch.spec_info,\n            )\n            self.forward_metadata = PrefillMetadata(\n                self.prefill_wrappers_verify, False, False\n            )\n        else:\n            prefix_lens = forward_batch.extend_prefix_lens\n\n            # Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring\n            if self.is_multimodal or self.multi_item_scoring_delimiter is not None:\n                # use_ragged = False: Multi-item scoring requires the paged wrapper because:\n                # 1. Ragged wrapper doesn't support the specialized multi-item parameters\n                #    (prefix_len_ptr, token_pos_in_items_ptr, etc.)\n                # 2. Paged wrapper provides better control over attention masking needed\n                #    for respecting item boundaries in multi-item sequences\n                # 3. 
Custom masking logic conflicts with ragged wrapper's assumptions\n                use_ragged = False\n                extend_no_prefix = False\n            else:\n                use_ragged = (\n                    not self.enable_deterministic and not is_in_piecewise_cuda_graph()\n                )\n                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)\n\n            # Process multi-item scoring in attention backend instead of ForwardBatch\n            multi_item_params = MultiItemScoringParams()\n            if self.multi_item_scoring_delimiter is not None:\n                # Use new backend-specific implementation\n                multi_item_params = self._process_multi_item_scoring(forward_batch)\n\n            self.indices_updater_prefill.update(\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                forward_batch.seq_lens_cpu,\n                forward_batch.seq_lens_sum,\n                prefix_lens,\n                prefill_wrappers=self.prefill_wrappers_paged,\n                use_ragged=use_ragged,\n                encoder_lens=forward_batch.encoder_lens,\n                spec_info=None,\n                fixed_split_size=self.prefill_split_tile_size,\n                multi_item_params=multi_item_params,\n            )\n            self.forward_metadata = PrefillMetadata(\n                self.prefill_wrappers_paged,\n                use_ragged,\n                extend_no_prefix,\n                multi_item_params,\n            )\n\n    def init_cuda_graph_state(\n        self,\n        max_bs: int,\n        max_num_tokens: int,\n        kv_indices_buf: Optional[torch.Tensor] = None,\n    ):\n        if kv_indices_buf is None:\n            cuda_graph_kv_indices = torch.zeros(\n                (max_num_tokens * self.max_context_len,),\n                dtype=torch.int32,\n                device=\"cuda\",\n            )\n        else:\n            cuda_graph_kv_indices = 
kv_indices_buf\n\n        self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [\n            cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)\n        ]\n\n        # Ensure tensors are properly allocated\n        for i in range(self.num_wrappers):\n            # Force allocation by performing a small operation\n            if len(self.cuda_graph_kv_indices[i]) > 0:\n                self.cuda_graph_kv_indices[i][0] = 0\n\n        if not self.skip_prefill:\n            self.cuda_graph_custom_mask = torch.zeros(\n                (max_num_tokens * self.max_context_len),\n                dtype=torch.uint8,\n                device=\"cuda\",\n            )\n            self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]\n            self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        if forward_mode.is_decode_or_idle():\n            decode_wrappers = []\n            for i in range(self.num_wrappers):\n                decode_wrappers.append(\n                    BatchDecodeWithPagedKVCacheWrapper(\n                        self.workspace_buffer,\n                        \"NHD\",\n                        backend=self.decode_backend,\n                        use_cuda_graph=True,\n                        use_tensor_cores=self.decode_use_tensor_cores,\n                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],\n                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],\n                        paged_kv_last_page_len_buffer=self.kv_last_page_len[\n                            :num_tokens\n                        ],\n                    )\n                )\n            
seq_lens_sum = seq_lens.sum().item()\n            self.indices_updater_decode.update(\n                req_pool_indices,\n                seq_lens,\n                seq_lens.cpu(),  # may add a little overhead in capture stage\n                seq_lens_sum,\n                decode_wrappers=decode_wrappers,\n                encoder_lens=encoder_lens,\n                spec_info=spec_info,\n                fixed_split_size=None,\n                disable_split_kv=self.disable_cuda_graph_kv_split,\n            )\n            self.decode_cuda_graph_metadata[bs] = decode_wrappers\n            self.forward_metadata = DecodeMetadata(decode_wrappers)\n            for i in range(self.num_wrappers):\n                decode_wrappers[i].begin_forward = partial(\n                    fast_decode_plan, decode_wrappers[i]\n                )\n        elif forward_mode.is_target_verify():\n            prefill_wrappers = []\n            for i in range(self.num_wrappers):\n                prefill_wrappers.append(\n                    BatchPrefillWithPagedKVCacheWrapper(\n                        self.workspace_buffer,\n                        \"NHD\",\n                        use_cuda_graph=True,\n                        backend=self.prefill_backend,\n                        qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],\n                        paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],\n                        paged_kv_indices_buf=self.cuda_graph_kv_indices[i],\n                        paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],\n                        custom_mask_buf=self.cuda_graph_custom_mask,\n                        mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],\n                    )\n                )\n            seq_lens_sum = seq_lens.sum().item()\n            self.indices_updater_prefill.update(\n                req_pool_indices,\n                seq_lens,\n                seq_lens.cpu(),  # may add a little overhead in capture stage\n       
         seq_lens_sum,\n                prefix_lens=None,\n                prefill_wrappers=prefill_wrappers,\n                use_ragged=False,\n                encoder_lens=encoder_lens,\n                spec_info=spec_info,\n            )\n            self.prefill_cuda_graph_metadata[bs] = prefill_wrappers\n            self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)\n        elif forward_mode.is_draft_extend():\n            prefill_wrappers = []\n            for i in range(self.num_wrappers):\n                prefill_wrappers.append(\n                    BatchPrefillWithPagedKVCacheWrapper(\n                        self.workspace_buffer,\n                        \"NHD\",\n                        backend=self.prefill_backend,\n                        use_cuda_graph=True,\n                        qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],\n                        paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],\n                        paged_kv_indices_buf=self.cuda_graph_kv_indices[i],\n                        paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],\n                    )\n                )\n\n            seq_lens_sum = seq_lens.sum().item()\n            self.indices_updater_prefill.update(\n                req_pool_indices,\n                seq_lens,\n                seq_lens.cpu(),  # may add a little overhead in capture stage\n                seq_lens_sum,\n                prefix_lens=None,\n                prefill_wrappers=prefill_wrappers,\n                use_ragged=False,\n                encoder_lens=encoder_lens,\n                spec_info=spec_info,\n            )\n            self.prefill_cuda_graph_metadata[bs] = prefill_wrappers\n            self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)\n        elif forward_mode.is_dllm_extend():\n            prefill_wrappers = []\n            for i in range(self.num_wrappers):\n                prefill_wrappers.append(\n                    
BatchPrefillWithPagedKVCacheWrapper(\n                        self.workspace_buffer,\n                        \"NHD\",\n                        backend=self.prefill_backend,\n                        use_cuda_graph=True,\n                        qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],\n                        paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],\n                        paged_kv_indices_buf=self.cuda_graph_kv_indices[i],\n                        paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],\n                    )\n                )\n            seq_lens_sum = seq_lens.sum().item()\n            self.indices_updater_prefill.update(\n                req_pool_indices,\n                seq_lens,\n                seq_lens.cpu(),  # may add a little overhead in capture stage\n                seq_lens_sum,\n                prefix_lens=seq_lens - self.dllm_config.block_size,\n                prefill_wrappers=prefill_wrappers,\n                use_ragged=True,\n                encoder_lens=encoder_lens,\n                spec_info=None,\n            )\n            self.prefill_cuda_graph_metadata[bs] = prefill_wrappers\n            self.forward_metadata = PrefillMetadata(prefill_wrappers, True, False)\n        else:\n            raise ValueError(f\"Invalid mode: {forward_mode=}\")\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        if forward_mode.is_decode_or_idle():\n            self.indices_updater_decode.update(\n                req_pool_indices[:bs],\n                seq_lens[:bs],\n                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,\n                seq_lens_sum,\n                
decode_wrappers=self.decode_cuda_graph_metadata[bs],\n                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,\n                spec_info=spec_info,\n                fixed_split_size=None,\n                disable_split_kv=self.disable_cuda_graph_kv_split,\n            )\n        elif forward_mode.is_target_verify():\n            self.indices_updater_prefill.update(\n                req_pool_indices[:bs],\n                seq_lens[:bs],\n                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,\n                seq_lens_sum,\n                prefix_lens=None,\n                prefill_wrappers=self.prefill_cuda_graph_metadata[bs],\n                use_ragged=False,\n                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,\n                spec_info=spec_info,\n            )\n        elif forward_mode.is_draft_extend():\n            self.indices_updater_prefill.update(\n                req_pool_indices[:bs],\n                seq_lens[:bs],\n                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,\n                seq_lens_sum,\n                prefix_lens=None,\n                prefill_wrappers=self.prefill_cuda_graph_metadata[bs],\n                use_ragged=False,\n                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,\n                spec_info=spec_info,\n            )\n        elif forward_mode.is_dllm_extend():\n            self.indices_updater_prefill.update(\n                req_pool_indices[:bs],\n                seq_lens[:bs],\n                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,\n                seq_lens_sum,\n                prefix_lens=seq_lens - self.dllm_config.block_size,\n                prefill_wrappers=self.prefill_cuda_graph_metadata[bs],\n                use_ragged=True,\n                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,\n                spec_info=None,\n            )\n     
   else:\n            raise ValueError(\"Invalid forward mode\")\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        return 1\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n    ):\n        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[\n            self._get_wrapper_idx(layer)\n        ]\n        cache_loc = (\n            forward_batch.out_cache_loc\n            if not layer.is_cross_attention\n            else forward_batch.encoder_out_cache_loc\n        )\n\n        logits_soft_cap = layer.logit_cap\n\n        q = q.contiguous()\n        if not self.forward_metadata.use_ragged:\n            if k is not None:\n                assert v is not None\n                if save_kv_cache:\n                    forward_batch.token_to_kv_pool.set_kv_buffer(\n                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale\n                    )\n\n            o = prefill_wrapper_paged.forward(\n                q.view(-1, layer.tp_q_head_num, layer.head_dim),\n                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),\n                causal=not layer.is_cross_attention,\n                sm_scale=layer.scaling,\n                # Disable sliding window attention for multi-item scoring:\n                # - Sliding window could cut across item boundaries, breaking semantic coherence\n                # - Multi-item sequences need full attention to properly handle delimiter tokens\n                # - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)\n                #   provide more precise attention control than simple sliding windows\n                # - Item-aware masking takes precedence over window-based masking\n                window_left=(\n                    layer.sliding_window_size\n                    if not (\n   
                     self.forward_metadata.multi_item_params\n                        and self.forward_metadata.multi_item_params.is_enabled()\n                    )\n                    else -1\n                ),\n                logits_soft_cap=logits_soft_cap,\n                # Must use _float to avoid device-to-host copy that breaks cuda graph capture.\n                k_scale=layer.k_scale_float,\n                v_scale=layer.v_scale_float,\n            )\n        else:\n            # If `k`/`v` are not explicitly provided, fall back to the KV cache stored in\n            # `forward_batch.token_to_kv_pool` for this layer. This enables attention over\n            # previously cached context without re-materializing KV tensors (e.g., the\n            # IQuestLoopCoder path uses token_to_kv_pool as the KV source).\n            if k is None and v is None:\n                k = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)[0]\n                v = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)[1]\n            causal = True\n            if (\n                layer.is_cross_attention\n                or layer.attn_type == AttentionType.ENCODER_ONLY\n            ):\n                causal = False\n            if not self.is_dllm_model and layer.attn_type == AttentionType.ENCODER_ONLY:\n                save_kv_cache = False\n\n            if self.forward_metadata.extend_no_prefix:\n                # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions\n                # The FlashInfer head_dim limitation itself is tracked here:\n                # https://github.com/flashinfer-ai/flashinfer/issues/1048\n                o = self.prefill_wrapper_ragged.forward(\n                    q.view(-1, layer.tp_q_head_num, layer.head_dim),\n                    k.view(-1, layer.tp_k_head_num, layer.head_dim),\n                    v.view(-1, layer.tp_v_head_num, layer.head_dim),\n                    causal=causal,\n           
         sm_scale=layer.scaling,\n                    logits_soft_cap=logits_soft_cap,\n                )\n\n            else:\n                if not self.is_dllm_model:\n                    # TODO: design a better interface\n                    # For other models, use causal attention for the ragged part as previously\n                    causal = True\n\n                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(\n                    q.view(-1, layer.tp_q_head_num, layer.head_dim),\n                    k.view(-1, layer.tp_k_head_num, layer.head_dim),\n                    v.view(-1, layer.tp_v_head_num, layer.head_dim),\n                    causal=causal,\n                    sm_scale=layer.scaling,\n                    logits_soft_cap=logits_soft_cap,\n                )\n                o2, s2 = prefill_wrapper_paged.forward_return_lse(\n                    q.view(-1, layer.tp_q_head_num, layer.head_dim),\n                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),\n                    causal=False,\n                    sm_scale=layer.scaling,\n                    logits_soft_cap=logits_soft_cap,\n                )\n\n                o, _ = merge_state(o1, s1, o2, s2)\n\n            if save_kv_cache:\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale\n                )\n\n        return o.view(-1, layer.tp_q_head_num * layer.head_dim)\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n    ):\n        decode_wrapper = self.forward_metadata.decode_wrappers[\n            self._get_wrapper_idx(layer)\n        ]\n        cache_loc = (\n            forward_batch.out_cache_loc\n            if not layer.is_cross_attention\n            else forward_batch.encoder_out_cache_loc\n        
)\n\n        if k is not None:\n            assert v is not None\n            if save_kv_cache:\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale\n                )\n\n        # Call the wrapped function\n        o = decode_wrapper.forward(\n            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),\n            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),\n            sm_scale=layer.scaling,\n            logits_soft_cap=layer.logit_cap,\n            # Must use _float to avoid device-to-host copy that breaks cuda graph capture.\n            k_scale=layer.k_scale_float,\n            v_scale=layer.v_scale_float,\n        )\n\n        return o.view(-1, layer.tp_q_head_num * layer.head_dim)\n\n    def _get_wrapper_idx(self, layer: RadixAttention):\n        if self.num_wrappers == 1:\n            return 0\n\n        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:\n            return layer.sliding_window_size == -1\n        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:\n            return layer.is_cross_attention\n\n        raise ValueError(f\"Unknown dispatch reason: {self.dispatch_reason}\")\n\n\nclass FlashInferIndicesUpdaterDecode:\n    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):\n        # Parse Constants\n        self.num_qo_heads = (\n            model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(\n            get_attention_tp_size()\n        )\n        self.head_dim = model_runner.model_config.head_dim\n        self.data_type = model_runner.kv_cache_dtype\n        self.q_data_type = model_runner.dtype\n        self.sliding_window_size = model_runner.sliding_window_size\n        self.attn_backend = attn_backend\n\n        # Buffers and wrappers\n        self.kv_indptr = 
attn_backend.kv_indptr\n        self.kv_last_page_len = attn_backend.kv_last_page_len\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator\n\n        # Dispatch the update function\n        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:\n            self.update = self.update_sliding_window\n        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:\n            self.update = self.update_cross_attention\n        else:\n            assert self.attn_backend.num_wrappers == 1\n            self.update = self.update_single_wrapper\n\n    def update(\n        self,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: Optional[torch.Tensor],\n        seq_lens_sum: int,\n        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],\n        encoder_lens: Optional[torch.Tensor],\n        spec_info: Optional[SpecInput],\n        fixed_split_size: Optional[int] = None,\n        disable_split_kv: Optional[bool] = None,\n    ):\n        # Keep the signature for type checking. 
It will be assigned during runtime.\n        raise NotImplementedError()\n\n    def update_single_wrapper(\n        self,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: Optional[torch.Tensor],\n        seq_lens_sum: int,\n        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],\n        encoder_lens: Optional[torch.Tensor],\n        spec_info: Optional[SpecInput],\n        fixed_split_size: Optional[int] = None,\n        disable_split_kv: Optional[bool] = None,\n    ):\n        decode_wrappers = decode_wrappers or self.decode_wrappers\n        self.call_begin_forward(\n            decode_wrappers[0],\n            req_pool_indices,\n            seq_lens,\n            seq_lens_sum,\n            self.kv_indptr[0],\n            None,\n            spec_info,\n            seq_lens_cpu,\n            fixed_split_size=fixed_split_size,\n            disable_split_kv=disable_split_kv,\n        )\n\n    def update_sliding_window(\n        self,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: Optional[torch.Tensor],\n        seq_lens_sum: int,\n        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],\n        encoder_lens: Optional[torch.Tensor],\n        spec_info: Optional[SpecInput],\n        fixed_split_size: Optional[int] = None,\n        disable_split_kv: Optional[bool] = None,\n    ):\n        assert self.sliding_window_size is not None\n        for wrapper_id in range(2):\n            if wrapper_id == 0:\n                # Sliding window attention\n                paged_kernel_lens_tmp = torch.clamp(\n                    seq_lens, max=self.sliding_window_size + 1\n                )\n                if seq_lens_cpu is not None:\n                    seq_lens_cpu_tmp = torch.clamp(\n                        seq_lens_cpu, max=self.sliding_window_size + 1\n                    )\n                    paged_kernel_lens_sum_tmp = seq_lens_cpu_tmp.sum().item()
\n                else:\n                    paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()\n                    seq_lens_cpu_tmp = None\n                kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp\n            else:\n                # Full attention\n                paged_kernel_lens_tmp = seq_lens\n                paged_kernel_lens_sum_tmp = seq_lens_sum\n                seq_lens_cpu_tmp = seq_lens_cpu\n                kv_start_idx_tmp = None\n\n            use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(\n                self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator\n            )\n\n            self.call_begin_forward(\n                decode_wrappers[wrapper_id],\n                req_pool_indices,\n                paged_kernel_lens_tmp,\n                paged_kernel_lens_sum_tmp,\n                self.kv_indptr[wrapper_id],\n                kv_start_idx_tmp,\n                spec_info,\n                seq_lens_cpu=seq_lens_cpu_tmp,\n                use_sliding_window_kv_pool=use_sliding_window_kv_pool,\n            )\n\n    def update_cross_attention(\n        self,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: Optional[torch.Tensor],\n        seq_lens_sum: int,\n        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],\n        encoder_lens: Optional[torch.Tensor],\n        spec_info: Optional[SpecInput],\n        fixed_split_size: Optional[int] = None,\n        disable_split_kv: Optional[bool] = None,\n    ):\n        for wrapper_id in range(2):\n            if wrapper_id == 0:\n                # Normal attention\n                paged_kernel_lens = seq_lens\n                kv_start_idx = encoder_lens\n            else:\n                # Cross attention\n                paged_kernel_lens = encoder_lens\n                kv_start_idx = torch.zeros_like(encoder_lens)\n                seq_lens_sum = encoder_lens.sum().item()\n\n            self.call_begin_forward(\n                
decode_wrappers[wrapper_id],\n                req_pool_indices,\n                paged_kernel_lens,\n                seq_lens_sum,\n                self.kv_indptr[wrapper_id],\n                kv_start_idx,\n                spec_info,\n                seq_lens_cpu=seq_lens_cpu,\n            )\n\n    def call_begin_forward(\n        self,\n        wrapper: BatchDecodeWithPagedKVCacheWrapper,\n        req_pool_indices: torch.Tensor,\n        paged_kernel_lens: torch.Tensor,\n        paged_kernel_lens_sum: int,\n        kv_indptr: torch.Tensor,\n        kv_start_idx: torch.Tensor,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n        use_sliding_window_kv_pool: bool = False,\n        fixed_split_size: Optional[int] = None,\n        disable_split_kv: Optional[bool] = None,\n    ):\n        if spec_info is None:\n            bs = len(req_pool_indices)\n            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)\n            kv_indptr = kv_indptr[: bs + 1]\n\n            if wrapper.is_cuda_graph_enabled:\n                # Directly write to the cuda graph input buffer\n                kv_indices = wrapper._paged_kv_indices_buf\n            else:\n                kv_indices = torch.empty(\n                    paged_kernel_lens_sum, dtype=torch.int32, device=\"cuda\"\n                )\n\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                paged_kernel_lens,\n                kv_indptr,\n                kv_start_idx,\n                kv_indices,\n                self.req_to_token.shape[1],\n            )\n        else:\n            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices\n            bs = kv_indptr.shape[0] - 1\n\n        if use_sliding_window_kv_pool:\n            kv_last_index = kv_indptr[-1]\n            kv_indices[:kv_last_index] = (\n                
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(\n                    kv_indices[:kv_last_index]\n                )\n            )\n\n        global global_override_indptr_cpu\n        locally_override = False\n        if seq_lens_cpu is not None and global_override_indptr_cpu is None:\n            locally_override = True\n            global_override_indptr_cpu = torch.empty_like(kv_indptr, device=\"cpu\")\n            global_override_indptr_cpu[0] = 0\n            global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)\n\n        # Check if this specific wrapper's begin_forward has been replaced with fast_decode_plan\n        # by checking if it's a partial function with fast_decode_plan as the func\n        wrapper_uses_fast_decode_plan = (\n            hasattr(wrapper.begin_forward, \"func\")\n            and wrapper.begin_forward.func == fast_decode_plan\n        )\n\n        if wrapper_uses_fast_decode_plan:\n            # When begin_forward is replaced with fast_decode_plan, pass global_override_indptr_cpu\n            wrapper.begin_forward(\n                kv_indptr,\n                kv_indices,\n                self.kv_last_page_len[:bs],\n                self.num_qo_heads,\n                self.num_kv_heads,\n                self.head_dim,\n                1,\n                data_type=self.data_type,\n                q_data_type=self.q_data_type,\n                non_blocking=True,\n                fixed_split_size=fixed_split_size,\n                disable_split_kv=(\n                    disable_split_kv if disable_split_kv is not None else False\n                ),\n                global_override_indptr_cpu=global_override_indptr_cpu,\n            )\n        else:\n            # When using original begin_forward, don't pass global_override_indptr_cpu\n            wrapper.begin_forward(\n                kv_indptr,\n                kv_indices,\n                self.kv_last_page_len[:bs],\n                
self.num_qo_heads,\n                self.num_kv_heads,\n                self.head_dim,\n                1,\n                data_type=self.data_type,\n                q_data_type=self.q_data_type,\n                non_blocking=True,\n                fixed_split_size=fixed_split_size,\n                disable_split_kv=(\n                    disable_split_kv if disable_split_kv is not None else False\n                ),\n            )\n\n        if locally_override:\n            global_override_indptr_cpu = None\n\n\nclass FlashInferIndicesUpdaterPrefill:\n    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):\n        # Parse Constants\n        self.num_qo_heads = (\n            model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(\n            get_attention_tp_size()\n        )\n        self.head_dim = model_runner.model_config.head_dim\n        self.data_type = model_runner.kv_cache_dtype\n        self.q_data_type = model_runner.dtype\n        self.sliding_window_size = model_runner.sliding_window_size\n        self.attn_backend = attn_backend\n\n        # Buffers and wrappers\n        self.kv_indptr = attn_backend.kv_indptr\n        self.kv_last_page_len = attn_backend.kv_last_page_len\n        self.qo_indptr = attn_backend.qo_indptr\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator\n        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged\n\n        # Dispatch the update function\n        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:\n            self.update = self.update_sliding_window\n        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:\n            self.update = self.update_cross_attention\n        else:\n            assert 
self.attn_backend.num_wrappers == 1\n            self.update = self.update_single_wrapper\n\n    def update(\n        self,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: Optional[torch.Tensor],\n        seq_lens_sum: int,\n        prefix_lens: torch.Tensor,\n        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],\n        use_ragged: bool,\n        encoder_lens: Optional[torch.Tensor],\n        spec_info: Optional[SpecInput],\n        fixed_split_size: Optional[int] = None,\n    ):\n        # Keep the signature for type checking. It will be assigned during runtime.\n        raise NotImplementedError()\n\n    def update_single_wrapper(\n        self,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: Optional[torch.Tensor],\n        seq_lens_sum: int,\n        prefix_lens: torch.Tensor,\n        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],\n        use_ragged: bool,\n        encoder_lens: Optional[torch.Tensor],\n        spec_info: Optional[SpecInput],\n        fixed_split_size: Optional[int] = None,\n        multi_item_params: Optional[MultiItemScoringParams] = None,\n    ):\n        if use_ragged:\n            # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu\n            # and forward_batch.extend_seq_lens_cpu\n            paged_kernel_lens = prefix_lens\n            paged_kernel_lens_sum = paged_kernel_lens.sum().item()\n        else:\n            paged_kernel_lens = seq_lens\n            paged_kernel_lens_sum = seq_lens_sum\n\n        self.call_begin_forward(\n            self.prefill_wrapper_ragged,\n            prefill_wrappers[0],\n            req_pool_indices,\n            paged_kernel_lens,\n            paged_kernel_lens_sum,\n            seq_lens,\n            prefix_lens,\n            None,\n            self.kv_indptr[0],\n            self.qo_indptr[0],\n            use_ragged,\n            
spec_info,\n            fixed_split_size=fixed_split_size,\n            multi_item_params=multi_item_params,\n        )\n\n    def update_sliding_window(\n        self,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: Optional[torch.Tensor],\n        seq_lens_sum: int,\n        prefix_lens: torch.Tensor,\n        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],\n        use_ragged: bool,\n        encoder_lens: Optional[torch.Tensor],\n        spec_info: Optional[SpecInput],\n        fixed_split_size: Optional[int] = None,\n        multi_item_params: Optional[MultiItemScoringParams] = None,\n    ):\n        for wrapper_id in range(2):\n            if wrapper_id == 0:\n                # window attention uses the paged wrapper only\n                paged_kernel_lens = torch.minimum(\n                    seq_lens,\n                    torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,\n                )\n                paged_kernel_lens_sum = paged_kernel_lens.sum().item()\n            else:\n                # full attention\n                paged_kernel_lens = seq_lens\n                paged_kernel_lens_sum = seq_lens_sum\n\n            kv_start_idx = seq_lens - paged_kernel_lens\n            use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(\n                self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator\n            )\n\n            self.call_begin_forward(\n                self.prefill_wrapper_ragged,\n                prefill_wrappers[wrapper_id],\n                req_pool_indices,\n                paged_kernel_lens,\n                paged_kernel_lens_sum,\n                seq_lens,\n                prefix_lens,\n                kv_start_idx,\n                self.kv_indptr[wrapper_id],\n                self.qo_indptr[wrapper_id],\n                use_ragged,\n                spec_info,\n                use_sliding_window_kv_pool=use_sliding_window_kv_pool,\n                
multi_item_params=multi_item_params,\n            )\n\n    def update_cross_attention(\n        self,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: Optional[torch.Tensor],\n        seq_lens_sum: int,\n        prefix_lens: torch.Tensor,\n        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],\n        use_ragged: bool,\n        encoder_lens: Optional[torch.Tensor],\n        spec_info: Optional[SpecInput],\n        fixed_split_size: Optional[int] = None,\n        multi_item_params: Optional[MultiItemScoringParams] = None,\n    ):\n        for wrapper_id in range(2):\n            if wrapper_id == 0:\n                # normal attention\n                paged_kernel_lens = seq_lens\n                kv_start_idx = encoder_lens\n                paged_kernel_lens_sum = seq_lens_sum\n            else:\n                # cross attention\n                paged_kernel_lens = encoder_lens\n                kv_start_idx = torch.zeros_like(encoder_lens)\n                paged_kernel_lens_sum = paged_kernel_lens.sum().item()\n\n            self.call_begin_forward(\n                self.prefill_wrapper_ragged,\n                prefill_wrappers[wrapper_id],\n                req_pool_indices,\n                paged_kernel_lens,\n                paged_kernel_lens_sum,\n                seq_lens,\n                prefix_lens,\n                kv_start_idx,\n                self.kv_indptr[wrapper_id],\n                self.qo_indptr[wrapper_id],\n                use_ragged,\n                spec_info,\n                multi_item_params=multi_item_params,\n            )\n\n    def call_begin_forward(\n        self,\n        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,\n        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,\n        req_pool_indices: torch.Tensor,\n        paged_kernel_lens: torch.Tensor,\n        paged_kernel_lens_sum: int,\n        seq_lens: torch.Tensor,\n        prefix_lens: torch.Tensor,\n    
    kv_start_idx: torch.Tensor,\n        kv_indptr: torch.Tensor,\n        qo_indptr: torch.Tensor,\n        use_ragged: bool,\n        spec_info: Optional[SpecInput],\n        use_sliding_window_kv_pool: bool = False,\n        fixed_split_size: Optional[int] = None,\n        multi_item_params: Optional[MultiItemScoringParams] = None,\n    ):\n        bs = len(seq_lens)\n        if spec_info is None:\n            assert len(seq_lens) == len(req_pool_indices)\n            # Normal extend\n            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)\n            kv_indptr = kv_indptr[: bs + 1]\n            kv_indices = torch.empty(\n                paged_kernel_lens_sum + 256,\n                dtype=torch.int32,\n                device=req_pool_indices.device,\n            )\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                paged_kernel_lens,\n                kv_indptr,\n                kv_start_idx,\n                kv_indices,\n                self.req_to_token.shape[1],\n            )\n            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)\n            qo_indptr = qo_indptr[: bs + 1]\n            custom_mask = None\n        else:\n            assert isinstance(spec_info, SpecInput)\n            kv_indices, kv_indptr, qo_indptr, custom_mask = (\n                spec_info.generate_attn_arg_prefill(\n                    req_pool_indices,\n                    paged_kernel_lens,\n                    paged_kernel_lens_sum,\n                    self.req_to_token,\n                )\n            )\n\n        # extend part\n        if use_ragged:\n            wrapper_ragged.begin_forward(\n                qo_indptr,\n                qo_indptr,\n                self.num_qo_heads,\n                self.num_kv_heads,\n                self.head_dim,\n                q_data_type=self.q_data_type,\n            )\n\n        if 
use_sliding_window_kv_pool:\n            kv_last_index = kv_indptr[-1]\n            kv_indices[:kv_last_index] = (\n                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(\n                    kv_indices[:kv_last_index]\n                )\n            )\n\n        # cached part\n        # Conditionally set multi-item parameters\n        if multi_item_params is not None and multi_item_params.is_enabled():\n            # Multi-item scoring is active - use specialized parameters and disable generic custom_mask\n            use_custom_mask = None\n            prefix_len_ptr = multi_item_params.prefix_len_ptr\n            token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr\n            token_pos_in_items_len = multi_item_params.token_pos_in_items_len\n            max_item_len_ptr = multi_item_params.max_item_len_ptr\n        else:\n            # No multi-item scoring - use standard parameters\n            use_custom_mask = custom_mask\n            prefix_len_ptr = None\n            token_pos_in_items_ptr = None\n            token_pos_in_items_len = 0\n            max_item_len_ptr = None\n\n        wrapper_paged.begin_forward(\n            qo_indptr,\n            kv_indptr,\n            kv_indices,\n            self.kv_last_page_len[:bs],\n            self.num_qo_heads,\n            self.num_kv_heads,\n            self.head_dim,\n            1,\n            q_data_type=self.q_data_type,\n            kv_data_type=self.data_type,\n            custom_mask=use_custom_mask,\n            non_blocking=True,\n            fixed_split_size=fixed_split_size,\n            prefix_len_ptr=prefix_len_ptr,\n            token_pos_in_items_ptr=token_pos_in_items_ptr,\n            token_pos_in_items_len=token_pos_in_items_len,\n            max_item_len_ptr=max_item_len_ptr,\n        )\n\n\nclass FlashInferMultiStepDraftBackend:\n    \"\"\"\n    Wrap multiple flashinfer attention backends as one for multiple consecutive\n    draft decoding steps.\n    
\"\"\"\n\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        topk: int,\n        speculative_num_steps: int,\n    ):\n        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices\n\n        self.topk = topk\n        self.speculative_num_steps = speculative_num_steps\n        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices\n        self.page_size = model_runner.page_size\n\n        max_bs = model_runner.req_to_token_pool.size * self.topk\n        self.kv_indptr = torch.zeros(\n            (\n                self.speculative_num_steps,\n                max_bs + 1,\n            ),\n            dtype=torch.int32,\n            device=model_runner.device,\n        )\n        self.kv_last_page_len = torch.ones(\n            (max_bs,), dtype=torch.int32, device=model_runner.device\n        )\n        self.attn_backends: List[FlashInferAttnBackend] = []\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends.append(\n                FlashInferAttnBackend(\n                    model_runner,\n                    skip_prefill=True,\n                    kv_indptr_buf=self.kv_indptr[i],\n                    kv_last_page_len_buf=self.kv_last_page_len,\n                )\n            )\n\n        self.max_context_len = self.attn_backends[0].max_context_len\n\n        # Cached variables for generate_draft_decode_kv_indices\n        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]\n\n    def common_template(\n        self,\n        forward_batch: ForwardBatch,\n        kv_indices_buffer: torch.Tensor,\n        call_fn: Callable,\n    ):\n        num_seqs = forward_batch.batch_size\n        bs = self.topk * num_seqs\n        seq_lens_sum = forward_batch.seq_lens_sum\n\n        self.generate_draft_decode_kv_indices[\n            (self.speculative_num_steps, num_seqs, self.topk)\n        ](\n            forward_batch.req_pool_indices,\n            
forward_batch.req_to_token_pool.req_to_token,\n            forward_batch.seq_lens,\n            kv_indices_buffer,\n            self.kv_indptr,\n            forward_batch.positions,\n            self.pool_len,\n            kv_indices_buffer.shape[1],\n            self.kv_indptr.shape[1],\n            next_power_of_2(num_seqs),\n            next_power_of_2(self.speculative_num_steps),\n            next_power_of_2(bs),\n            self.page_size,\n        )\n\n        assert forward_batch.spec_info is not None\n        assert forward_batch.spec_info.is_draft_input()\n\n        # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.\n        indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()\n        global global_override_indptr_cpu\n\n        for i in range(self.speculative_num_steps - 1):\n            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]\n            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][\n                : seq_lens_sum * self.topk + bs * (i + 1)\n            ]\n            global_override_indptr_cpu = indptr_cpu_whole[i]\n            call_fn(i, forward_batch)\n\n        global_override_indptr_cpu = None\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        kv_indices = torch.empty(\n            (\n                self.speculative_num_steps,\n                forward_batch.batch_size * self.topk * self.max_context_len,\n            ),\n            dtype=torch.int32,\n            device=\"cuda\",\n        )\n\n        def call_fn(i, forward_batch):\n            forward_batch.spec_info.kv_indptr = (\n                forward_batch.spec_info.kv_indptr.clone()\n            )\n            forward_batch.spec_info.kv_indices = (\n                forward_batch.spec_info.kv_indices.clone()\n            )\n            self.attn_backends[i].init_forward_metadata(forward_batch)\n\n        self.common_template(forward_batch, kv_indices, call_fn)\n\n    def 
init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        self.cuda_graph_kv_indices = torch.zeros(\n            (self.speculative_num_steps, max_bs * self.max_context_len),\n            dtype=torch.int32,\n            device=\"cuda\",\n        )\n\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_cuda_graph_state(\n                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]\n            )\n\n    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):\n        def call_fn(i, forward_batch):\n            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(\n                forward_batch.batch_size,\n                forward_batch.batch_size * self.topk,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                encoder_lens=None,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n            )\n\n        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)\n\n    def init_forward_metadata_replay_cuda_graph(\n        self, forward_batch: ForwardBatch, bs: int\n    ):\n        def call_fn(i, forward_batch):\n            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(\n                bs,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                seq_lens_sum=-1,\n                encoder_lens=None,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n                seq_lens_cpu=forward_batch.seq_lens_cpu,\n            )\n\n        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)\n\n\ndef should_use_tensor_core(\n    kv_cache_dtype: torch.dtype,\n    num_attention_heads: int,\n    num_kv_heads: int,\n) -> bool:\n    \"\"\"\n    Determine whether to use tensor cores for attention computation.\n\n    
Args:\n        kv_cache_dtype: Data type of the KV cache\n        num_attention_heads: Number of attention heads\n        num_kv_heads: Number of key/value heads\n\n    Returns:\n        bool: Whether to use tensor cores\n    \"\"\"\n    # Try to use environment variable first\n    env_override = os.environ.get(\"SGLANG_FLASHINFER_USE_TENSOR_CORE\")\n    if env_override is not None:\n        return env_override.lower() == \"true\"\n\n    # Try to use _grouped_size_compiled_for_decode_kernels if available\n    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug\n    try:\n        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels\n\n        if not _grouped_size_compiled_for_decode_kernels(\n            num_attention_heads,\n            num_kv_heads,\n        ):\n            return True\n        else:\n            return False\n    except (ImportError, AttributeError):\n        pass\n\n    # Calculate GQA group size\n    gqa_group_size = num_attention_heads // num_kv_heads\n\n    # For Flashinfer, a GQA group size of at least 4 is needed to efficiently\n    # use Tensor Cores, as it fuses the head group with the token dimension in MMA.\n    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):\n        return True\n    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):\n        return gqa_group_size >= 4\n    else:\n        return False\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/flashinfer_mla_backend.py",
    "content": "from __future__ import annotations\n\n\"\"\"\nSupport attention backend for flashinfer MLA.\nThe flashinfer_mla_disable_ragged flag controls whether the ragged prefill wrapper is disabled and defaults to false.\nWhen it's set to true, all wrappers are BatchMLAPaged wrappers.\nWhen it's set to false, the backend uses BatchRagged and BatchMLAPaged wrappers for prefilling,\nand uses the BatchMLAPaged wrapper for decoding.\nMore details can be found in https://docs.flashinfer.ai/api/mla.html\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import TYPE_CHECKING, Callable, Optional, Union\n\nimport torch\n\nfrom sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.attention.flashinfer_backend import (\n    create_flashinfer_kv_indices_triton,\n)\nfrom sglang.srt.layers.dp_attention import get_attention_tp_size\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.speculative.spec_info import SpecInput\nfrom sglang.srt.utils import (\n    is_flashinfer_available,\n    is_sm100_supported,\n    next_power_of_2,\n)\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.attention.flashinfer_mla_backend import (\n        FlashInferMlaAttnBackend,\n    )\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n    from sglang.srt.speculative.spec_info import SpecInput\n\nif envs.SGLANG_ENABLE_TORCH_COMPILE.get():\n    import logging\n\n    torch._logging.set_logs(dynamo=logging.ERROR)\n    torch._dynamo.config.suppress_errors = True\n\nif is_flashinfer_available():\n    from flashinfer import (\n        BatchMLAPagedAttentionWrapper,\n        BatchPrefillWithRaggedKVCacheWrapper,\n    
)\n\n\n@dataclass\nclass DecodeMetadata:\n    decode_wrapper: BatchMLAPagedAttentionWrapper\n\n\n@dataclass\nclass PrefillMetadata:\n    prefill_wrapper: BatchMLAPagedAttentionWrapper\n    use_ragged: bool\n\n\n# Reuse this workspace buffer across all flashinfer wrappers\nglobal_workspace_buffer = None\n\n\nclass FlashInferMhaChunkKVRunner:\n    def __init__(\n        self, model_runner: ModelRunner, attn_backend: FlashInferMlaAttnBackend\n    ):\n        # Parse Constants\n        self.num_local_heads = (\n            model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim\n        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim\n        self.v_head_dim = model_runner.model_config.v_head_dim\n        self.data_type = model_runner.dtype\n        self.q_data_type = model_runner.dtype\n\n        # Buffers and wrappers\n        self.qo_indptr = attn_backend.qo_indptr\n        self.kv_indptr = attn_backend.kv_indptr\n        self.workspace_buffer = attn_backend.workspace_buffer\n        self.fmha_backend = attn_backend.fmha_backend\n\n        self.chunk_ragged_wrappers = []\n        self.ragged_wrapper = attn_backend.prefill_wrapper_ragged\n\n    def update_prefix_chunks(self, num_prefix_chunks: int):\n        while num_prefix_chunks > len(self.chunk_ragged_wrappers):\n            ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper(\n                self.workspace_buffer, \"NHD\", backend=self.fmha_backend\n            )\n            self.chunk_ragged_wrappers.append(ragged_wrapper)\n\n    def update_wrapper(\n        self,\n        forward_batch: ForwardBatch,\n        disable_flashinfer_ragged: bool = False,\n    ):\n        assert forward_batch.num_prefix_chunks is not None\n        num_prefix_chunks = forward_batch.num_prefix_chunks\n        self.update_prefix_chunks(num_prefix_chunks)\n\n        prefix_lens = 
forward_batch.extend_prefix_lens\n        seq_lens = forward_batch.seq_lens\n\n        bs = len(seq_lens)\n        qo_indptr = self.qo_indptr\n        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)\n        qo_indptr = qo_indptr[: bs + 1]\n\n        for chunk_idx in range(forward_batch.num_prefix_chunks):\n            # MHA for chunked prefix kv cache when running model with MLA\n            assert forward_batch.prefix_chunk_idx is not None\n            assert forward_batch.prefix_chunk_cu_seq_lens is not None\n            assert forward_batch.prefix_chunk_max_seq_lens is not None\n\n            kv_indptr = forward_batch.prefix_chunk_cu_seq_lens[chunk_idx]\n            wrapper = self.chunk_ragged_wrappers[chunk_idx]\n            wrapper.begin_forward(\n                qo_indptr=qo_indptr,\n                kv_indptr=kv_indptr,\n                num_qo_heads=self.num_local_heads,\n                num_kv_heads=self.num_local_heads,\n                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,\n                head_dim_vo=self.v_head_dim,\n                q_data_type=self.q_data_type,\n                causal=False,\n            )\n        # ragged prefill\n        if not disable_flashinfer_ragged:\n            kv_indptr = (\n                qo_indptr\n                if not forward_batch.mha_one_shot\n                else self.kv_indptr[: bs + 1]\n            )\n            self.ragged_wrapper.begin_forward(\n                qo_indptr=qo_indptr,\n                kv_indptr=kv_indptr,\n                num_qo_heads=self.num_local_heads,\n                num_kv_heads=self.num_local_heads,\n                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,\n                head_dim_vo=self.v_head_dim,\n                q_data_type=self.q_data_type,\n                causal=True,\n            )\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: 
RadixAttention,\n        forward_batch: ForwardBatch,\n    ):\n        logits_soft_cap = layer.logit_cap\n        if forward_batch.attn_attend_prefix_cache:\n            chunk_idx = forward_batch.prefix_chunk_idx\n            assert chunk_idx >= 0\n            wrapper = self.chunk_ragged_wrappers[chunk_idx]\n            o = wrapper.forward_return_lse(\n                q.view(-1, layer.tp_q_head_num, layer.head_dim),\n                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),\n                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),\n                causal=False,\n                sm_scale=layer.scaling,\n                logits_soft_cap=logits_soft_cap,\n            )\n        else:\n            forward = (\n                self.ragged_wrapper.forward_return_lse\n                if forward_batch.mha_return_lse\n                else self.ragged_wrapper.forward\n            )\n            o = forward(\n                q.view(-1, layer.tp_q_head_num, layer.head_dim),\n                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),\n                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),\n                causal=True,\n                sm_scale=layer.scaling,\n                logits_soft_cap=logits_soft_cap,\n            )\n        return o\n\n\nclass FlashInferMLAAttnBackend(AttentionBackend):\n    \"\"\"Flashinfer attention kernels.\"\"\"\n\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        skip_prefill: bool = False,\n        kv_indptr_buf: Optional[torch.Tensor] = None,\n        q_indptr_decode_buf: Optional[torch.Tensor] = None,\n    ):\n        super().__init__()\n\n        # Parse constants\n        self.max_context_len = model_runner.model_config.context_len\n        self.device = model_runner.device\n        self.skip_prefill = skip_prefill\n        self.enable_chunk_kv = (\n            not skip_prefill\n            and 
get_global_server_args().disaggregation_mode != \"decode\"\n            and not get_global_server_args().disable_chunked_prefix_cache\n            and not get_global_server_args().flashinfer_mla_disable_ragged\n        )\n        self.page_size = model_runner.page_size\n\n        # Allocate buffers\n        global global_workspace_buffer\n        if global_workspace_buffer is None:\n            # different from flashinfer zero_init_global_workspace_buffer\n            global_workspace_buffer = torch.empty(\n                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),\n                dtype=torch.uint8,\n                device=model_runner.device,\n            )\n        self.workspace_buffer = global_workspace_buffer\n\n        max_bs = model_runner.req_to_token_pool.size\n        if kv_indptr_buf is None:\n            self.kv_indptr = torch.zeros(\n                (max_bs + 1,), dtype=torch.int32, device=model_runner.device\n            )\n        else:\n            self.kv_indptr = kv_indptr_buf\n\n        if not self.skip_prefill:\n            self.qo_indptr = torch.zeros(\n                (max_bs + 1,), dtype=torch.int32, device=model_runner.device\n            )\n\n        if q_indptr_decode_buf is None:\n            self.q_indptr_decode = torch.arange(\n                0, max_bs + 1, dtype=torch.int32, device=model_runner.device\n            )\n        else:\n            self.q_indptr_decode = q_indptr_decode_buf\n\n        if is_sm100_supported():\n            self.fmha_backend = \"cutlass\"\n        else:\n            self.fmha_backend = \"auto\"\n\n        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(\n            self.workspace_buffer, \"NHD\", backend=self.fmha_backend\n        )\n\n        if not self.skip_prefill:\n            self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(\n                self.workspace_buffer,\n                backend=\"auto\",\n            )\n\n            # FlashinferMLA backend uses mla wrapper 
for target verify\n            self.prefill_wrapper_verify = BatchMLAPagedAttentionWrapper(\n                self.workspace_buffer,\n                backend=\"auto\",\n            )\n\n        self.decode_wrapper = BatchMLAPagedAttentionWrapper(\n            self.workspace_buffer, backend=\"auto\"\n        )\n\n        # Create indices updater\n        if not skip_prefill:\n            self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(\n                model_runner, self\n            )\n            if self.enable_chunk_kv:\n                self.mha_chunk_kv_cache = FlashInferMhaChunkKVRunner(model_runner, self)\n\n        self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(\n            model_runner, self\n        )\n\n        # Other metadata\n        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None\n        self.decode_cuda_graph_metadata = {}\n        self.prefill_cuda_graph_metadata = {}  # For verify\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        if forward_batch.forward_mode.is_decode_or_idle():\n            self.indices_updater_decode.update(\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                forward_batch.seq_lens_sum,\n                decode_wrapper=self.decode_wrapper,\n                init_metadata_replay=False,\n            )\n            self.forward_metadata = DecodeMetadata(self.decode_wrapper)\n        elif forward_batch.forward_mode.is_draft_extend():\n            self.indices_updater_prefill.update(\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                forward_batch.seq_lens_sum,\n                prefix_lens=None,\n                prefill_wrapper_paged=self.prefill_wrapper_paged,\n                use_ragged=False,\n                spec_info=forward_batch.spec_info,\n            )\n            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_paged, 
False)\n        elif forward_batch.forward_mode.is_target_verify():\n            self.indices_updater_prefill.update(\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                forward_batch.seq_lens_sum,\n                prefix_lens=None,\n                prefill_wrapper_paged=self.prefill_wrapper_verify,\n                use_ragged=False,\n                spec_info=forward_batch.spec_info,\n            )\n            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_verify, False)\n        else:\n            prefix_lens = forward_batch.extend_prefix_lens\n            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)\n            use_ragged = (\n                not get_global_server_args().flashinfer_mla_disable_ragged\n                and extend_no_prefix\n                # Piecewise cuda graph should use paged prefill to be compatible with prefix cache\n                and not is_in_piecewise_cuda_graph()\n            )\n\n            self.indices_updater_prefill.update(\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                forward_batch.seq_lens_sum,\n                prefix_lens,\n                prefill_wrapper_paged=self.prefill_wrapper_paged,\n                use_ragged=use_ragged,\n            )\n            self.forward_metadata = PrefillMetadata(\n                self.prefill_wrapper_paged, use_ragged\n            )\n\n    def init_cuda_graph_state(\n        self,\n        max_bs: int,\n        max_num_tokens: int,\n        kv_indices_buf: Optional[torch.Tensor] = None,\n    ):\n        if kv_indices_buf is None:\n            cuda_graph_kv_indices = torch.zeros(\n                (max_bs * self.max_context_len,),\n                dtype=torch.int32,\n                device=\"cuda\",\n            )\n        else:\n            cuda_graph_kv_indices = kv_indices_buf\n\n        self.cuda_graph_kv_indices = cuda_graph_kv_indices\n        
self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()\n        self.cuda_graph_kv_indptr = self.kv_indptr.clone()\n        self.cuda_graph_kv_lens = torch.ones(\n            (max_bs,), dtype=torch.int32, device=self.device\n        )\n\n        # For fast decode plan in graph replaying\n        self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to(\"cpu\")\n        self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to(\"cpu\")\n        self.fast_decode_kwargs = {\n            \"qo_indptr_cpu\": self.cuda_graph_qo_indptr_cpu,\n            \"kv_indptr_cpu\": self.cuda_graph_kv_indptr_cpu,\n            \"kv_indices\": self.cuda_graph_kv_indices,\n        }\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        if forward_mode.is_decode_or_idle():\n            decode_wrapper = BatchMLAPagedAttentionWrapper(\n                self.workspace_buffer,\n                use_cuda_graph=True,\n                qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1],\n                kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1],\n                kv_indices=self.cuda_graph_kv_indices,\n                kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],\n                backend=\"auto\",\n            )\n\n            seq_lens_sum = seq_lens.sum().item()\n            self.indices_updater_decode.update(\n                req_pool_indices,\n                seq_lens,\n                seq_lens_sum,\n                decode_wrapper=decode_wrapper,\n                init_metadata_replay=False,\n                spec_info=spec_info,\n            )\n            self.decode_cuda_graph_metadata[bs] = decode_wrapper\n            self.forward_metadata = DecodeMetadata(decode_wrapper)\n            
decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)\n        elif forward_mode.is_target_verify():\n            verify_wrapper = BatchMLAPagedAttentionWrapper(\n                self.workspace_buffer,\n                use_cuda_graph=True,\n                qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],\n                kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],\n                kv_indices=self.cuda_graph_kv_indices,\n                kv_len_arr=self.cuda_graph_kv_lens[:bs],\n                backend=\"auto\",\n            )\n            seq_lens_sum = seq_lens.sum().item()\n            self.indices_updater_prefill.update(\n                req_pool_indices,\n                seq_lens,\n                seq_lens_sum,\n                prefix_lens=None,\n                prefill_wrapper_paged=verify_wrapper,\n                use_ragged=False,\n                spec_info=spec_info,\n            )\n            self.prefill_cuda_graph_metadata[bs] = verify_wrapper\n            self.forward_metadata = PrefillMetadata(verify_wrapper, False)\n        elif forward_mode.is_draft_extend():\n            draft_extend_wrapper = BatchMLAPagedAttentionWrapper(\n                self.workspace_buffer,\n                use_cuda_graph=True,\n                qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],\n                kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],\n                kv_indices=self.cuda_graph_kv_indices,\n                kv_len_arr=self.cuda_graph_kv_lens[:bs],\n                backend=\"auto\",\n            )\n            seq_lens_sum = seq_lens.sum().item()\n            self.indices_updater_prefill.update(\n                req_pool_indices,\n                seq_lens,\n                seq_lens_sum,\n                prefix_lens=None,\n                prefill_wrapper_paged=draft_extend_wrapper,\n                use_ragged=False,\n                spec_info=spec_info,\n            )\n            self.prefill_cuda_graph_metadata[bs] = draft_extend_wrapper\n      
      self.forward_metadata = PrefillMetadata(draft_extend_wrapper, False)\n        else:\n            raise ValueError(f\"Invalid mode: {forward_mode=}\")\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        if forward_mode.is_decode_or_idle():\n            assert seq_lens_cpu is not None\n            kv_len_arr_cpu = seq_lens_cpu[:bs]\n            self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(\n                kv_len_arr_cpu, dim=0\n            )\n            self.fast_decode_kwargs.update(\n                {\n                    \"qo_indptr_cpu\": self.cuda_graph_qo_indptr_cpu[: bs + 1],\n                    \"kv_indptr_cpu\": self.cuda_graph_kv_indptr_cpu[: bs + 1],\n                    \"kv_len_arr_cpu\": kv_len_arr_cpu,\n                }\n            )\n\n            self.indices_updater_decode.update(\n                req_pool_indices[:bs],\n                seq_lens[:bs],\n                seq_lens_sum,\n                decode_wrapper=self.decode_cuda_graph_metadata[bs],\n                init_metadata_replay=True,\n                spec_info=spec_info,\n                **self.fast_decode_kwargs,\n            )\n        elif forward_mode.is_target_verify():\n            self.indices_updater_prefill.update(\n                req_pool_indices[:bs],\n                seq_lens[:bs],\n                seq_lens_sum,\n                prefix_lens=None,\n                prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],\n                use_ragged=False,\n                spec_info=spec_info,\n            )\n        elif forward_mode.is_draft_extend():\n            self.indices_updater_prefill.update(\n                
req_pool_indices[:bs],\n                seq_lens[:bs],\n                seq_lens_sum,\n                prefix_lens=None,\n                prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],\n                use_ragged=False,\n                spec_info=spec_info,\n            )\n        else:\n            raise ValueError(f\"Invalid forward mode: {forward_mode=}\")\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        return 1\n\n    def init_mha_chunk_metadata(\n        self, forward_batch: ForwardBatch, disable_flashinfer_ragged: bool = False\n    ):\n        \"\"\"Init the metadata for a forward pass.\"\"\"\n        self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n    ):\n        if forward_batch.attn_attend_prefix_cache is not None and any(\n            forward_batch.extend_prefix_lens_cpu\n        ):  # MHA Chunk\n            assert self.enable_chunk_kv\n            assert q_rope is None\n            assert k_rope is None\n            return self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)\n\n        cache_loc = forward_batch.out_cache_loc\n        logits_soft_cap = layer.logit_cap\n        prefill_wrapper_paged = self.forward_metadata.prefill_wrapper\n\n        # Save kv cache\n        if save_kv_cache and k is not None:\n            assert v is not None\n            if save_kv_cache:\n                if k_rope is not None:\n                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(\n                        layer, cache_loc, k, k_rope\n                    )\n                else:\n                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)\n   
     if q_rope is not None:\n            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n            q_rope = q_rope.view(\n                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim\n            )\n\n        if self.forward_metadata.use_ragged:\n            # ragged prefill\n            if q_rope is not None:\n                q = torch.cat([q, q_rope], dim=-1)\n            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)\n            if k_rope is not None:\n                k = torch.cat([k, k_rope], dim=-1)\n            o = self.prefill_wrapper_ragged.forward(\n                qall,\n                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),\n                v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),\n                causal=True,\n                sm_scale=layer.scaling,\n                logits_soft_cap=logits_soft_cap,\n            )\n        else:\n            # mla paged prefill\n            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(\n                q.dtype\n            )\n            if q_rope is None:\n                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)\n                q, q_rope = (\n                    qall[:, :, : layer.v_head_dim],\n                    qall[:, :, layer.v_head_dim :],\n                )\n            o = q.new_empty(q.shape)\n            o = prefill_wrapper_paged.run(\n                q,\n                q_rope,\n                k_buf[:, :, : layer.v_head_dim],\n                k_buf[:, :, layer.v_head_dim :],\n                out=o,\n            )\n\n        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        # For multi-head latent attention\n        q_rope: Optional[torch.Tensor] = 
None,\n        k_rope: Optional[torch.Tensor] = None,\n    ):\n        decode_wrapper = self.forward_metadata.decode_wrapper\n        cache_loc = forward_batch.out_cache_loc\n\n        if k is not None:\n            assert v is not None\n            if save_kv_cache:\n                if k_rope is not None:\n                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(\n                        layer,\n                        cache_loc,\n                        k,\n                        k_rope,\n                    )\n                else:\n                    forward_batch.token_to_kv_pool.set_kv_buffer(\n                        layer,\n                        cache_loc,\n                        k,\n                        v,\n                    )\n\n        # Reshape inputs\n        if q_rope is not None:\n            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n            q_rope = q_rope.view(\n                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim\n            )\n        else:\n            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)\n            q_nope = reshaped_q[:, :, : layer.v_head_dim]\n            q_rope = reshaped_q[:, :, layer.v_head_dim :]\n\n        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(\n            q.dtype\n        )\n\n        o = q_nope.new_empty(q_nope.shape)\n        # Direct call to run without the wrapper\n        o = decode_wrapper.run(\n            q_nope,\n            q_rope,\n            k_buffer[:, :, : layer.v_head_dim],\n            k_buffer[:, :, layer.v_head_dim :],\n            out=o,\n        )\n\n        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n\n\nclass FlashInferMLAIndicesUpdaterDecode:\n    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):\n        # Parse Constants\n        self.num_local_heads = (\n            model_runner.model_config.num_attention_heads // 
get_attention_tp_size()\n        )\n        self.kv_lora_rank = model_runner.model_config.kv_lora_rank\n        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim\n        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim\n        self.scaling = model_runner.model_config.scaling\n        self.data_type = model_runner.dtype\n        self.attn_backend = attn_backend\n\n        # Buffers and wrappers\n        self.kv_indptr = attn_backend.kv_indptr\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.q_indptr = attn_backend.q_indptr_decode\n\n    def update(\n        self,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        decode_wrapper: BatchMLAPagedAttentionWrapper,\n        init_metadata_replay: bool = False,\n        spec_info: Optional[SpecInput] = None,\n        **fast_decode_kwargs,\n    ):\n        decode_wrapper = decode_wrapper or self.decode_wrapper\n        self.call_begin_forward(\n            decode_wrapper,\n            req_pool_indices,\n            seq_lens,\n            seq_lens_sum,\n            self.q_indptr,\n            self.kv_indptr,\n            init_metadata_replay,\n            spec_info,\n            **fast_decode_kwargs,\n        )\n\n    def call_begin_forward(\n        self,\n        wrapper: BatchMLAPagedAttentionWrapper,\n        req_pool_indices: torch.Tensor,\n        paged_kernel_lens: torch.Tensor,\n        paged_kernel_lens_sum: int,\n        q_indptr: torch.Tensor,\n        kv_indptr: torch.Tensor,\n        init_metadata_replay: bool = False,\n        spec_info: Optional[SpecInput] = None,\n        **fast_decode_kwargs,\n    ):\n        bs = len(req_pool_indices)\n        q_indptr = q_indptr[: bs + 1]\n        kv_lens = paged_kernel_lens.to(torch.int32)\n        sm_scale = self.scaling\n        if spec_info is None:\n            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)\n            
kv_indptr = kv_indptr[: bs + 1]\n            kv_indices = (\n                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device=\"cuda\")\n                if not init_metadata_replay\n                else fast_decode_kwargs[\"kv_indices\"]\n            )\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                paged_kernel_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.shape[1],\n            )\n        else:\n            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices\n\n        if not init_metadata_replay:\n            wrapper.plan(\n                q_indptr,\n                kv_indptr,\n                kv_indices,\n                kv_lens,\n                self.num_local_heads,\n                self.kv_lora_rank,\n                self.qk_rope_head_dim,\n                1,\n                False,\n                sm_scale,\n                self.data_type,\n                self.data_type,\n            )\n        else:\n            wrapper.plan(\n                fast_decode_kwargs[\"qo_indptr_cpu\"],\n                fast_decode_kwargs[\"kv_indptr_cpu\"],\n                kv_indices,\n                fast_decode_kwargs[\"kv_len_arr_cpu\"],\n                self.num_local_heads,\n                self.kv_lora_rank,\n                self.qk_rope_head_dim,\n                1,\n                False,\n                sm_scale,\n                self.data_type,\n                self.data_type,\n            )\n\n\nclass FlashInferMLAIndicesUpdaterPrefill:\n    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):\n        # Parse Constants\n        self.num_local_heads = (\n            model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.kv_lora_rank = model_runner.model_config.kv_lora_rank\n        
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim\n        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim\n        self.v_head_dim = model_runner.model_config.v_head_dim\n        self.scaling = model_runner.model_config.scaling\n        self.data_type = model_runner.dtype\n        self.q_data_type = model_runner.dtype\n        self.attn_backend = attn_backend\n\n        # Buffers and wrappers\n        self.kv_indptr = attn_backend.kv_indptr\n        self.qo_indptr = attn_backend.qo_indptr\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged\n\n    def update(\n        self,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        prefix_lens: torch.Tensor,\n        prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,\n        use_ragged: bool,\n        spec_info: Optional[SpecInput] = None,\n    ):\n        if use_ragged:\n            paged_kernel_lens = prefix_lens\n            paged_kernel_lens_sum = paged_kernel_lens.sum().item()\n        else:\n            paged_kernel_lens = seq_lens\n            paged_kernel_lens_sum = seq_lens_sum\n\n        self.call_begin_forward(\n            self.prefill_wrapper_ragged,\n            prefill_wrapper_paged,\n            req_pool_indices,\n            paged_kernel_lens,\n            paged_kernel_lens_sum,\n            seq_lens,\n            prefix_lens,\n            self.kv_indptr,\n            self.qo_indptr,\n            use_ragged,\n            spec_info,\n        )\n\n    def call_begin_forward(\n        self,\n        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,\n        wrapper_paged: BatchMLAPagedAttentionWrapper,\n        req_pool_indices: torch.Tensor,\n        paged_kernel_lens: torch.Tensor,\n        paged_kernel_lens_sum: int,\n        seq_lens: torch.Tensor,\n        prefix_lens: torch.Tensor,\n        kv_indptr: 
torch.Tensor,\n        qo_indptr: torch.Tensor,\n        use_ragged: bool,\n        spec_info: Optional[SpecInput] = None,\n    ):\n        bs = len(seq_lens)\n        sm_scale = self.scaling\n\n        if spec_info is None:\n            assert len(seq_lens) == len(req_pool_indices)\n            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)\n            kv_indptr = kv_indptr[: bs + 1]\n            kv_indices = torch.empty(\n                paged_kernel_lens_sum,\n                dtype=torch.int32,\n                device=req_pool_indices.device,\n            )\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                paged_kernel_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.shape[1],\n            )\n            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)\n            qo_indptr = qo_indptr[: bs + 1]\n            custom_mask = None\n        else:\n            assert isinstance(spec_info, SpecInput)\n            # TODO: Support topk > 1 with custom mask\n            kv_indices, kv_indptr, qo_indptr, custom_mask = (\n                spec_info.generate_attn_arg_prefill(\n                    req_pool_indices,\n                    paged_kernel_lens,\n                    paged_kernel_lens_sum,\n                    self.req_to_token,\n                )\n            )\n\n        if use_ragged:\n            # ragged prefill\n            wrapper_ragged.begin_forward(\n                qo_indptr=qo_indptr,\n                kv_indptr=qo_indptr,\n                num_qo_heads=self.num_local_heads,\n                num_kv_heads=self.num_local_heads,\n                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,\n                head_dim_vo=self.v_head_dim,\n                q_data_type=self.q_data_type,\n                causal=True,\n            )\n        else:\n        
    # mla paged prefill\n            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]\n            wrapper_paged.plan(\n                qo_indptr,\n                kv_indptr,\n                kv_indices,\n                kv_len_arr,\n                self.num_local_heads,\n                self.kv_lora_rank,\n                self.qk_rope_head_dim,\n                1,\n                True,\n                sm_scale,\n                self.q_data_type,\n                self.data_type,\n            )\n\n\nclass FlashInferMLAMultiStepDraftBackend:\n    \"\"\"\n    Wrap multiple flashinfer mla attention backends as one for multiple consecutive\n    draft decoding steps.\n    \"\"\"\n\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        topk: int,\n        speculative_num_steps: int,\n    ):\n        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices\n\n        if topk > 1:\n            raise ValueError(\n                \"Currently Flashinfer MLA only supports topk=1 for speculative decoding\"\n            )\n        self.topk = topk\n        self.speculative_num_steps = speculative_num_steps\n        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices\n\n        max_bs = model_runner.req_to_token_pool.size * self.topk\n        self.kv_indptr = torch.zeros(\n            (\n                self.speculative_num_steps,\n                max_bs + 1,\n            ),\n            dtype=torch.int32,\n            device=model_runner.device,\n        )\n        self.q_indptr_decode = torch.arange(\n            0, max_bs + 1, dtype=torch.int32, device=model_runner.device\n        )\n\n        self.attn_backends = []\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends.append(\n                FlashInferMLAAttnBackend(\n                    model_runner,\n                    skip_prefill=True,\n                    kv_indptr_buf=self.kv_indptr[i],\n                    
q_indptr_decode_buf=self.q_indptr_decode,\n                )\n            )\n\n        self.max_context_len = self.attn_backends[0].max_context_len\n\n        # Cached variables for generate_draft_decode_kv_indices\n        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]\n        self.page_size = model_runner.server_args.page_size\n\n    def common_template(\n        self,\n        forward_batch: ForwardBatch,\n        kv_indices_buffer: torch.Tensor,\n        call_fn: Callable,\n    ):\n        num_seqs = forward_batch.batch_size\n        bs = self.topk * num_seqs\n        seq_lens_sum = forward_batch.seq_lens_sum\n\n        self.generate_draft_decode_kv_indices[\n            (self.speculative_num_steps, num_seqs, self.topk)\n        ](\n            forward_batch.req_pool_indices,\n            forward_batch.req_to_token_pool.req_to_token,\n            forward_batch.seq_lens,\n            kv_indices_buffer,\n            self.kv_indptr,\n            forward_batch.positions,\n            self.pool_len,\n            kv_indices_buffer.shape[1],\n            self.kv_indptr.shape[1],\n            next_power_of_2(num_seqs),\n            next_power_of_2(self.speculative_num_steps),\n            next_power_of_2(bs),\n            self.page_size,\n        )\n\n        assert forward_batch.spec_info is not None\n        assert forward_batch.spec_info.is_draft_input()\n\n        for i in range(self.speculative_num_steps - 1):\n            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]\n            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][\n                : seq_lens_sum * self.topk + bs * (i + 1)\n            ]\n            call_fn(i, forward_batch)\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        kv_indices = torch.zeros(\n            (\n                self.speculative_num_steps,\n                forward_batch.batch_size * self.topk * self.max_context_len,\n            ),\n            
dtype=torch.int32,\n            device=\"cuda\",\n        )\n\n        def call_fn(i, forward_batch):\n            forward_batch.spec_info.kv_indptr = (\n                forward_batch.spec_info.kv_indptr.clone()\n            )\n            forward_batch.spec_info.kv_indices = (\n                forward_batch.spec_info.kv_indices.clone()\n            )\n            self.attn_backends[i].init_forward_metadata(forward_batch)\n\n        self.common_template(forward_batch, kv_indices, call_fn)\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        self.cuda_graph_kv_indices = torch.zeros(\n            (self.speculative_num_steps, max_bs * self.max_context_len),\n            dtype=torch.int32,\n            device=\"cuda\",\n        )\n\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_cuda_graph_state(\n                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]\n            )\n\n    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):\n        def call_fn(i, forward_batch):\n            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(\n                forward_batch.batch_size,\n                forward_batch.batch_size * self.topk,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                encoder_lens=None,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n            )\n\n        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)\n\n    def init_forward_metadata_replay_cuda_graph(\n        self, forward_batch: ForwardBatch, bs: int\n    ):\n        def call_fn(i, forward_batch):\n            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(\n                bs,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                seq_lens_sum=-1,\n            
    encoder_lens=None,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n                seq_lens_cpu=forward_batch.seq_lens_cpu,\n            )\n\n        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)\n\n\ndef fast_mla_decode_plan(\n    self,\n    qo_indptr_cpu: torch.Tensor,\n    kv_indptr_cpu: torch.Tensor,\n    kv_indices: torch.Tensor,\n    kv_len_arr_cpu: torch.Tensor,\n    num_heads: int,\n    head_dim_ckv: int,\n    head_dim_kpe: int,\n    page_size: int,\n    causal: bool,\n    sm_scale: float,\n    q_data_type: torch.dtype,\n    kv_data_type: torch.dtype,\n) -> None:\n    \"\"\"A faster version of BatchMLAPagedAttentionWrapper::plan,\n    for skipping the stream synchronization in original plan function during\n    cuda graph replaying.\n    \"\"\"\n    self._causal = causal\n    self._page_size = page_size\n    self._sm_scale = sm_scale\n\n    try:\n        # Standard version with just the required arguments (no use_profiler)\n        self._cached_module.plan(\n            self._float_workspace_buffer,\n            self._int_workspace_buffer,\n            self._pin_memory_int_workspace_buffer,\n            qo_indptr_cpu,\n            kv_indptr_cpu,\n            kv_len_arr_cpu,\n            num_heads,\n            head_dim_ckv,\n            causal,\n        )\n    except Exception as e:\n        raise RuntimeError(f\"Error in alternate MLA plan: {e}\")\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/flashmla_backend.py",
    "content": "\"\"\"\nSupport attention backend for FlashMLA.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Callable, Optional, Tuple, Union\n\nimport torch\nimport triton\nfrom sgl_kernel.flash_mla import flash_mla_with_kvcache, get_mla_metadata\n\nfrom sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend\nfrom sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton\nfrom sglang.srt.layers.dp_attention import get_attention_tp_size\nfrom sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n    from sglang.srt.speculative.spec_info import SpecInput\n\n\nPAGE_SIZE = 64\n\n\n@dataclass\nclass FlashMLADecodeMetadata:\n    flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None\n    num_splits: Optional[torch.Tensor] = None\n    block_kv_indices: Optional[torch.Tensor] = None\n\n    def __init__(\n        self,\n        flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        num_splits: Optional[torch.Tensor] = None,\n        block_kv_indices: Optional[torch.Tensor] = None,\n    ):\n        self.flashmla_metadata = flashmla_metadata\n        self.num_splits = num_splits\n        self.block_kv_indices = block_kv_indices\n\n\nclass FlashMLABackend(FlashInferMLAAttnBackend):\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        skip_prefill: bool = False,\n        kv_indptr_buf: Optional[torch.Tensor] = None,\n        kv_last_page_len_buf: Optional[torch.Tensor] = None,\n    ):\n        super().__init__(\n            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf\n        )\n\n        self.num_q_heads = (\n            
model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.num_local_heads = (\n            model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.forward_metadata: Union[FlashMLADecodeMetadata] = None\n        self.kv_lora_rank = model_runner.model_config.kv_lora_rank\n        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim\n        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim\n        self.v_head_dim = model_runner.model_config.v_head_dim\n        self.scaling = model_runner.model_config.scaling\n        self.data_type = model_runner.kv_cache_dtype\n        self.q_data_type = model_runner.dtype\n        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim\n        self.is_fp8_kvcache = self.data_type in {\n            torch.float8_e4m3fn,\n            torch.float8_e5m2,\n        }\n\n        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens\n\n        self.cuda_graph_kv_indices = None\n        self.cuda_graph_mla_metadata = None\n        self.cuda_graph_num_splits = None\n        self.cuda_graph_mla_metadata_view = None\n        self.cuda_graph_num_splits_view = None\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        bs = forward_batch.batch_size\n        if forward_batch.forward_mode.is_decode_or_idle():\n            max_seqlen_pad = triton.cdiv(\n                forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE\n            )\n            block_kv_indices = torch.full(\n                (bs, max_seqlen_pad),\n                -1,\n                dtype=torch.int32,\n                device=forward_batch.seq_lens.device,\n            )\n            create_flashmla_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                forward_batch.req_pool_indices,\n                
forward_batch.seq_lens,\n                None,\n                block_kv_indices,\n                self.req_to_token.stride(0),\n                max_seqlen_pad,\n            )\n            mla_metadata, num_splits = get_mla_metadata(\n                forward_batch.seq_lens.to(torch.int32),\n                self.num_q_heads,\n                1,\n                is_fp8_kvcache=self.is_fp8_kvcache,\n            )\n            self.forward_metadata = FlashMLADecodeMetadata(\n                mla_metadata,\n                num_splits,\n                block_kv_indices,\n            )\n        elif forward_batch.forward_mode.is_target_verify():\n            seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens\n            seq_lens = forward_batch.seq_lens + self.num_draft_tokens\n\n            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)\n            block_kv_indices = torch.full(\n                (bs, max_seqlen_pad),\n                -1,\n                dtype=torch.int32,\n                device=seq_lens.device,\n            )\n            create_flashmla_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                forward_batch.req_pool_indices,\n                seq_lens,\n                None,\n                block_kv_indices,\n                self.req_to_token.stride(0),\n                max_seqlen_pad,\n            )\n            mla_metadata, num_splits = get_mla_metadata(\n                seq_lens.to(torch.int32),\n                self.num_draft_tokens * self.num_q_heads,\n                1,\n                is_fp8_kvcache=self.is_fp8_kvcache,\n            )\n            self.forward_metadata = FlashMLADecodeMetadata(\n                mla_metadata,\n                num_splits,\n                block_kv_indices,\n            )\n        else:\n            super().init_forward_metadata(forward_batch)\n\n    def init_cuda_graph_state(\n        self,\n        max_bs: int,\n        max_num_tokens: int,\n        
block_kv_indices: Optional[torch.Tensor] = None,\n    ):\n        if block_kv_indices is None:\n            self.cuda_graph_kv_indices = torch.full(\n                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),\n                1,\n                dtype=torch.int32,\n                device=\"cuda\",\n            )\n        else:\n            self.cuda_graph_kv_indices = block_kv_indices\n\n        device_props = torch.cuda.get_device_properties(self.req_to_token.device)\n        max_num_sm_parts = device_props.multi_processor_count\n\n        self.cuda_graph_mla_metadata = torch.empty(\n            (max_num_sm_parts, 8),\n            dtype=torch.int32,\n            device=\"cuda\",\n        )\n        self.cuda_graph_num_splits = torch.empty(\n            max_bs + 1,\n            dtype=torch.int32,\n            device=\"cuda\",\n        )\n\n        self.cuda_graph_mla_metadata_view = None\n        self.cuda_graph_num_splits_view = None\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        if forward_mode.is_decode_or_idle():\n            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)\n\n            create_flashmla_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                seq_lens,\n                None,\n                self.cuda_graph_kv_indices,\n                self.req_to_token.stride(0),\n                self.cuda_graph_kv_indices.stride(0),\n            )\n            num_q_heads = self.num_q_heads\n\n            mla_metadata, num_splits = get_mla_metadata(\n                seq_lens.to(torch.int32),\n                num_q_heads,\n                1,\n                is_fp8_kvcache=self.is_fp8_kvcache,\n     
       )\n\n            actual_num_sm_parts = mla_metadata.shape[0]\n            assert actual_num_sm_parts <= self.cuda_graph_mla_metadata.shape[0], (\n                f\"num_sm_parts {actual_num_sm_parts} exceeds preallocated max \"\n                f\"{self.cuda_graph_mla_metadata.shape[0]}\"\n            )\n\n            self.cuda_graph_mla_metadata[:actual_num_sm_parts].copy_(mla_metadata)\n            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)\n\n            self.cuda_graph_mla_metadata_view = self.cuda_graph_mla_metadata[\n                :actual_num_sm_parts\n            ]\n            self.cuda_graph_num_splits_view = self.cuda_graph_num_splits[: bs + 1]\n\n            self.forward_metadata = FlashMLADecodeMetadata(\n                self.cuda_graph_mla_metadata_view,\n                self.cuda_graph_num_splits_view,\n                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],\n            )\n\n        elif forward_mode.is_target_verify():\n            seq_lens = seq_lens + self.num_draft_tokens\n            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)\n\n            create_flashmla_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                seq_lens,\n                None,\n                self.cuda_graph_kv_indices,\n                self.req_to_token.stride(0),\n                self.cuda_graph_kv_indices.stride(0),\n            )\n\n            mla_metadata, num_splits = get_mla_metadata(\n                seq_lens.to(torch.int32),\n                self.num_draft_tokens * self.num_q_heads,\n                1,\n                is_fp8_kvcache=self.is_fp8_kvcache,\n            )\n\n            actual_num_sm_parts = mla_metadata.shape[0]\n            assert actual_num_sm_parts <= self.cuda_graph_mla_metadata.shape[0]\n\n            self.cuda_graph_mla_metadata[:actual_num_sm_parts].copy_(mla_metadata)\n            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)\n\n 
           self.cuda_graph_mla_metadata_view = self.cuda_graph_mla_metadata[\n                :actual_num_sm_parts\n            ]\n            self.cuda_graph_num_splits_view = self.cuda_graph_num_splits[: bs + 1]\n\n            self.forward_metadata = FlashMLADecodeMetadata(\n                self.cuda_graph_mla_metadata_view,\n                self.cuda_graph_num_splits_view,\n                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],\n            )\n        else:\n            super().init_forward_metadata_capture_cuda_graph(\n                bs,\n                num_tokens,\n                req_pool_indices,\n                seq_lens,\n                encoder_lens,\n                forward_mode,\n                spec_info,\n            )\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        if forward_mode.is_decode_or_idle():\n            assert seq_lens_cpu is not None\n            seq_lens = seq_lens[:bs]\n            seq_lens_cpu = seq_lens_cpu[:bs]\n            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)\n\n            create_flashmla_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices[:bs],\n                seq_lens,\n                None,\n                self.cuda_graph_kv_indices,\n                self.req_to_token.stride(0),\n                self.cuda_graph_kv_indices.stride(0),\n            )\n            num_q_heads = self.num_q_heads\n\n            mla_metadata, num_splits = get_mla_metadata(\n                seq_lens.to(torch.int32),\n                num_q_heads,\n                1,\n                is_fp8_kvcache=self.is_fp8_kvcache,\n            )\n\n           
 actual_num_sm_parts = mla_metadata.shape[0]\n\n            if actual_num_sm_parts != self.cuda_graph_mla_metadata_view.shape[0]:\n                import logging\n\n                logger = logging.getLogger(__name__)\n                logger.warning(\n                    f\"num_sm_parts mismatch in CUDA Graph replay: \"\n                    f\"capture={self.cuda_graph_mla_metadata_view.shape[0]}, \"\n                    f\"replay={actual_num_sm_parts}. \"\n                    f\"This may indicate batch size changed between capture and replay.\"\n                )\n                self.cuda_graph_mla_metadata_view = self.cuda_graph_mla_metadata[\n                    :actual_num_sm_parts\n                ]\n                self.cuda_graph_num_splits_view = self.cuda_graph_num_splits[: bs + 1]\n\n            self.cuda_graph_mla_metadata[:actual_num_sm_parts].copy_(mla_metadata)\n            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)\n\n            self.forward_metadata.mla_metadata = self.cuda_graph_mla_metadata_view\n            self.forward_metadata.num_splits = self.cuda_graph_num_splits_view\n            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[\n                :bs, :max_seqlen_pad\n            ]\n\n        elif forward_mode.is_target_verify():\n            seq_lens = seq_lens[:bs] + self.num_draft_tokens\n            seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens\n            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)\n\n            create_flashmla_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices[:bs],\n                seq_lens,\n                None,\n                self.cuda_graph_kv_indices,\n                self.req_to_token.stride(0),\n                self.cuda_graph_kv_indices.stride(0),\n            )\n\n            mla_metadata, num_splits = get_mla_metadata(\n                seq_lens.to(torch.int32),\n                
self.num_draft_tokens * self.num_q_heads,\n                1,\n                is_fp8_kvcache=self.is_fp8_kvcache,\n            )\n\n            actual_num_sm_parts = mla_metadata.shape[0]\n\n            if actual_num_sm_parts != self.cuda_graph_mla_metadata_view.shape[0]:\n                self.cuda_graph_mla_metadata_view = self.cuda_graph_mla_metadata[\n                    :actual_num_sm_parts\n                ]\n                self.cuda_graph_num_splits_view = self.cuda_graph_num_splits[: bs + 1]\n\n            self.cuda_graph_mla_metadata[:actual_num_sm_parts].copy_(mla_metadata)\n            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)\n\n            self.forward_metadata.mla_metadata = self.cuda_graph_mla_metadata_view\n            self.forward_metadata.num_splits = self.cuda_graph_num_splits_view\n            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[\n                :bs, :max_seqlen_pad\n            ]\n        else:\n            super().init_forward_metadata_replay_cuda_graph(\n                bs,\n                req_pool_indices,\n                seq_lens,\n                seq_lens_sum,\n                encoder_lens,\n                forward_mode,\n                spec_info,\n                seq_lens_cpu,\n            )\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        return 1\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n    ):\n        cache_loc = forward_batch.out_cache_loc\n\n        if k is not None:\n            assert v is not None\n            if save_kv_cache:\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer,\n                    cache_loc,\n                    k,\n                    v,\n                )\n        bs = forward_batch.batch_size\n        k_cache = 
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n\n        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)\n        if self.is_fp8_kvcache:\n            if layer.k_scale is not None:\n                q_scale = layer.k_scale\n                descale_q = layer.k_scale.reshape(1)\n                descale_k = layer.k_scale.reshape(1)\n            else:\n                q_scale = torch.ones((1,), dtype=torch.float32, device=reshape_q.device)\n                descale_q = torch.ones(\n                    (1,), dtype=torch.float32, device=reshape_q.device\n                )\n                descale_k = torch.ones(\n                    (1,), dtype=torch.float32, device=reshape_q.device\n                )\n\n            q_shape = reshape_q.shape\n            reshape_q_2d = reshape_q.reshape(-1, q_shape[-1])\n            reshape_q_fp8_2d, _ = scaled_fp8_quant(reshape_q_2d, q_scale)\n            reshape_q_fp8 = reshape_q_fp8_2d.reshape(q_shape)\n            o, _ = flash_mla_with_kvcache(\n                q=reshape_q_fp8,\n                k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),\n                block_table=self.forward_metadata.block_kv_indices[:bs],\n                cache_seqlens=forward_batch.seq_lens.to(torch.int32),\n                head_dim_v=self.kv_lora_rank,\n                tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,\n                num_splits=self.forward_metadata.num_splits,\n                softmax_scale=layer.scaling,\n                causal=True,\n                descale_q=descale_q,\n                descale_k=descale_k,\n            )\n\n            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n        else:\n            o, _ = flash_mla_with_kvcache(\n                q=reshape_q,\n                k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),\n                block_table=self.forward_metadata.block_kv_indices[:bs],\n                
cache_seqlens=forward_batch.seq_lens.to(torch.int32),\n                head_dim_v=self.kv_lora_rank,\n                tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,\n                num_splits=self.forward_metadata.num_splits,\n                softmax_scale=layer.scaling,\n                causal=True,\n            )\n\n            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n    ):\n        if (\n            forward_batch.forward_mode == ForwardMode.EXTEND\n            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND\n        ):\n            return super().forward_extend(q, k, v, layer, forward_batch, save_kv_cache)\n        else:\n            cache_loc = forward_batch.out_cache_loc\n\n            if k is not None:\n                assert v is not None\n                if save_kv_cache:\n                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)\n\n            bs = forward_batch.batch_size\n            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n\n            reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)\n            if self.is_fp8_kvcache:\n                if layer.k_scale is not None:\n                    q_scale = layer.k_scale\n                    descale_q = layer.k_scale.reshape(1)\n                    descale_k = layer.k_scale.reshape(1)\n                else:\n                    q_scale = torch.ones(\n                        (1,), dtype=torch.float32, device=reshape_q.device\n                    )\n                    descale_q = torch.ones(\n                        (1,), dtype=torch.float32, device=reshape_q.device\n                    )\n                    descale_k = torch.ones(\n                        (1,), 
dtype=torch.float32, device=reshape_q.device\n                    )\n\n                q_shape = reshape_q.shape\n                reshape_q_2d = reshape_q.reshape(-1, q_shape[-1])\n                reshape_q_fp8_2d, _ = scaled_fp8_quant(reshape_q_2d, q_scale)\n                reshape_q_fp8 = reshape_q_fp8_2d.reshape(q_shape)\n                o, _ = flash_mla_with_kvcache(\n                    q=reshape_q_fp8,\n                    k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),\n                    block_table=self.forward_metadata.block_kv_indices[:bs],\n                    cache_seqlens=forward_batch.seq_lens.to(torch.int32)\n                    + self.num_draft_tokens,\n                    head_dim_v=self.kv_lora_rank,\n                    tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,\n                    num_splits=self.forward_metadata.num_splits,\n                    softmax_scale=layer.scaling,\n                    causal=True,\n                    descale_q=descale_q,\n                    descale_k=descale_k,\n                )\n            else:\n                o, _ = flash_mla_with_kvcache(\n                    q=reshape_q,\n                    k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),\n                    block_table=self.forward_metadata.block_kv_indices[:bs],\n                    cache_seqlens=forward_batch.seq_lens.to(torch.int32)\n                    + self.num_draft_tokens,\n                    head_dim_v=self.kv_lora_rank,\n                    tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,\n                    num_splits=self.forward_metadata.num_splits,\n                    softmax_scale=layer.scaling,\n                    causal=True,\n                )\n            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n\n\nclass FlashMLAMultiStepDraftBackend:\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        topk: int,\n        speculative_num_steps: 
int,\n    ):\n        if topk > 1:\n            raise ValueError(\n                \"Currently FlashMLA only supports topk=1 for speculative decoding\"\n            )\n        self.topk = topk\n        self.speculative_num_steps = speculative_num_steps\n        max_bs = model_runner.req_to_token_pool.size * self.topk\n        self.kv_indptr = torch.zeros(\n            (\n                self.speculative_num_steps,\n                max_bs + 1,\n            ),\n            dtype=torch.int32,\n            device=model_runner.device,\n        )\n\n        self.attn_backends = []\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends.append(\n                FlashMLABackend(\n                    model_runner,\n                    skip_prefill=True,\n                    kv_indptr_buf=self.kv_indptr[i],\n                    kv_last_page_len_buf=None,\n                )\n            )\n\n    def common_template(\n        self,\n        forward_batch: ForwardBatch,\n        call_fn: Callable,\n    ):\n        assert forward_batch.spec_info is not None\n\n        for i in range(self.speculative_num_steps - 1):\n            call_fn(i, forward_batch)\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        def call_fn(i, forward_batch):\n            assert forward_batch.spec_info is not None\n            self.attn_backends[i].init_forward_metadata(forward_batch)\n\n        self.common_template(forward_batch, call_fn)\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_cuda_graph_state(\n                max_bs, max_num_tokens, block_kv_indices=None\n            )\n\n    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):\n        def call_fn(i, forward_batch):\n            # EAGLE draft worker uses DECODE mode for draft steps\n            from 
sglang.srt.model_executor.forward_batch_info import ForwardMode\n\n            # Create a dummy forward_mode for draft step\n            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(\n                forward_batch.batch_size,\n                forward_batch.batch_size * self.topk,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                encoder_lens=None,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n            )\n\n        self.common_template(forward_batch, call_fn)\n\n    def init_forward_metadata_replay_cuda_graph(\n        self, forward_batch: ForwardBatch, bs: int\n    ):\n        def call_fn(i, forward_batch):\n            from sglang.srt.model_executor.forward_batch_info import ForwardMode\n\n            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(\n                bs,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                seq_lens_sum=-1,\n                encoder_lens=None,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n                seq_lens_cpu=forward_batch.seq_lens_cpu,\n            )\n\n        self.common_template(forward_batch, call_fn)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/hybrid_attn_backend.py",
    "content": "from typing import Optional\n\nimport torch\n\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata\nfrom sglang.srt.layers.radix_attention import RadixAttention\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.model_executor.model_runner import ModelRunner\nfrom sglang.srt.speculative.spec_info import SpecInput\n\n\nclass HybridAttnBackend(AttentionBackend):\n    \"\"\"Support different backends for prefill and decode.\"\"\"\n\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        prefill_backend: AttentionBackend,\n        decode_backend: AttentionBackend,\n    ):\n        self.model_runner = model_runner\n        self.prefill_backend = prefill_backend\n        self.decode_backend = decode_backend\n        self.data_type = model_runner.kv_cache_dtype\n\n    def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:\n        \"\"\"\n        Select the appropriate attention backend based on the forward mode.\n\n        Args:\n            forward_mode: The current forward mode indicating the operation type\n\n        Returns:\n            The selected attention backend (prefill or decode)\n\n        Note:\n            - decode_or_idle: Always uses decode backend\n            - target_verify or draft_extend: Uses decode backend if speculative_attention_mode is \"decode\", otherwise prefill backend\n            - prefill: Always uses prefill backend\n        \"\"\"\n        if forward_mode.is_decode_or_idle():\n            return self.decode_backend\n        elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():\n            return (\n                self.decode_backend\n                if self.model_runner.server_args.speculative_attention_mode == \"decode\"\n                else self.prefill_backend\n            )\n        else:\n            
return self.prefill_backend\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        backend = self._select_backend(forward_batch.forward_mode)\n        backend.init_forward_metadata(forward_batch)\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)\n        if (\n            self.model_runner.server_args.speculative_algorithm is not None\n            and self.model_runner.server_args.speculative_attention_mode == \"prefill\"\n        ):\n            # When speculative decoding is enabled, we need to initialize the backend\n            # that will be used for target_verify.\n            self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        backend = self._select_backend(forward_mode)\n        backend.init_forward_metadata_capture_cuda_graph(\n            bs,\n            num_tokens,\n            req_pool_indices,\n            seq_lens,\n            encoder_lens,\n            forward_mode,\n            spec_info,\n        )\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        backend = self._select_backend(forward_mode)\n        backend.init_forward_metadata_replay_cuda_graph(\n            bs,\n            req_pool_indices,\n            seq_lens,\n            seq_lens_sum,\n            encoder_lens,\n 
           forward_mode,\n            spec_info,\n            seq_lens_cpu,\n        )\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        return self.decode_backend.get_cuda_graph_seq_len_fill_value()\n\n    def forward(\n        self,\n        q: Optional[torch.Tensor] = None,  # For full attention\n        k: Optional[torch.Tensor] = None,  # For full attention\n        v: Optional[torch.Tensor] = None,  # For full attention\n        layer: Optional[RadixAttention] = None,\n        forward_batch: Optional[ForwardBatch] = None,\n        save_kv_cache: bool = True,\n        *,\n        mixed_qkv: Optional[torch.Tensor] = None,  # For linear attention\n        a: Optional[torch.Tensor] = None,  # For linear attention\n        b: Optional[torch.Tensor] = None,  # For linear attention\n        **kwargs,\n    ):\n        \"\"\"Forward method that supports both regular attention (q, k, v) and linear attention (mixed_qkv, a, b).\"\"\"\n        backend = self._select_backend(forward_batch.forward_mode)\n        if mixed_qkv is not None:\n            return backend.forward(\n                layer=layer,\n                forward_batch=forward_batch,\n                save_kv_cache=save_kv_cache,\n                mixed_qkv=mixed_qkv,\n                a=a,\n                b=b,\n                **kwargs,\n            )\n        return backend.forward(q, k, v, layer, forward_batch, save_kv_cache, **kwargs)\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        **kwargs,\n    ):\n        return self.decode_backend.forward_decode(\n            q, k, v, layer, forward_batch, save_kv_cache, **kwargs\n        )\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: 
ForwardBatch,\n        save_kv_cache: bool = True,\n        **kwargs,\n    ):\n        backend = self._select_backend(forward_batch.forward_mode)\n        return backend.forward_extend(\n            q, k, v, layer, forward_batch, save_kv_cache, **kwargs\n        )\n\n    def get_indexer_metadata(\n        self, layer_id: int, forward_batch: ForwardBatch\n    ) -> Optional[BaseIndexerMetadata]:\n        backend = self._select_backend(forward_batch.forward_mode)\n        return backend.get_indexer_metadata(layer_id, forward_batch)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py",
    "content": "import logging\nfrom typing import Optional, Union\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.attention.mamba.causal_conv1d_triton import PAD_SLOT_ID\nfrom sglang.srt.layers.attention.mamba.mamba import MambaMixer2\nfrom sglang.srt.layers.attention.mamba.mamba2_metadata import (\n    ForwardMetadata,\n    Mamba2Metadata,\n)\nfrom sglang.srt.layers.attention.mamba.mamba_state_scatter_triton import (\n    fused_mamba_state_scatter_with_mask,\n)\nfrom sglang.srt.layers.radix_attention import RadixAttention\nfrom sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.model_executor.model_runner import ModelRunner\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput\nfrom sglang.srt.speculative.spec_info import SpecInput\nfrom sglang.srt.utils import is_cpu\n\nif not is_cpu():\n    from sglang.srt.layers.attention.fla.chunk_delta_h import (\n        CHUNK_SIZE as FLA_CHUNK_SIZE,\n    )\n\nlogger = logging.getLogger(__name__)\n\n\n# Kernel to track mamba states if needed based on track mask\n@triton.jit\ndef track_mamba_state_if_needed_kernel(\n    conv_states_ptr,\n    ssm_states_ptr,\n    cache_indices_ptr,\n    mamba_track_mask_ptr,\n    mamba_track_indices_ptr,\n    conv_state_stride_0,  # stride for first dimension (batch/pool index)\n    ssm_state_stride_0,  # stride for first dimension (batch/pool index)\n    conv_state_numel_per_row: tl.constexpr,  # total elements per row\n    ssm_state_numel_per_row: tl.constexpr,  # total elements per row\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    Track conv_states and ssm_states rows based on track mask.\n\n    This kernel replaces a Python loop that copies state tensors for mamba attention.\n 
   For each batch element, if the track mask is True, it copies the entire row from\n    the source index (cache_indices[i]) to the destination index (mamba_track_indices[i]).\n\n    Grid: (batch_size,)\n    Each block handles one batch element, using multiple threads to copy data in parallel.\n    \"\"\"\n    batch_idx = tl.program_id(0)\n\n    # Load the copy mask for this batch element\n    track_mask = tl.load(mamba_track_mask_ptr + batch_idx)\n\n    # Early exit if we don't need to track\n    if not track_mask:\n        return\n\n    # Load source and destination indices\n    src_idx = tl.load(cache_indices_ptr + batch_idx)\n    dst_idx = tl.load(mamba_track_indices_ptr + batch_idx)\n\n    # Copy conv_states\n    # Each thread handles BLOCK_SIZE elements\n    for offset in range(0, conv_state_numel_per_row, BLOCK_SIZE):\n        element_indices = offset + tl.arange(0, BLOCK_SIZE)\n        mask = element_indices < conv_state_numel_per_row\n\n        src_ptr = conv_states_ptr + src_idx * conv_state_stride_0 + element_indices\n        dst_ptr = conv_states_ptr + dst_idx * conv_state_stride_0 + element_indices\n\n        data = tl.load(src_ptr, mask=mask, other=0.0)\n        tl.store(dst_ptr, data, mask=mask)\n\n    # Copy ssm_states\n    for offset in range(0, ssm_state_numel_per_row, BLOCK_SIZE):\n        element_indices = offset + tl.arange(0, BLOCK_SIZE)\n        mask = element_indices < ssm_state_numel_per_row\n\n        src_ptr = ssm_states_ptr + src_idx * ssm_state_stride_0 + element_indices\n        dst_ptr = ssm_states_ptr + dst_idx * ssm_state_stride_0 + element_indices\n\n        data = tl.load(src_ptr, mask=mask, other=0.0)\n        tl.store(dst_ptr, data, mask=mask)\n\n\ndef track_mamba_states_if_needed(\n    conv_states: torch.Tensor,\n    ssm_states: torch.Tensor,\n    cache_indices: torch.Tensor,\n    mamba_track_mask: torch.Tensor,\n    mamba_track_indices: torch.Tensor,\n    batch_size: int,\n):\n    \"\"\"\n    Track mamba states using Triton 
kernel for better performance.\n\n    Args:\n        conv_states: Convolution states tensor [pool_size, ...]\n        ssm_states: SSM states tensor [pool_size, ...]\n        cache_indices: Source indices for each batch element [batch_size]\n        mamba_track_mask: Boolean mask indicating which elements to track [batch_size]\n        mamba_track_indices: Indices to track for each batch element [batch_size]\n        batch_size: Number of batch elements\n    \"\"\"\n    conv_state_numel_per_row = conv_states[0].numel()\n    ssm_state_numel_per_row = ssm_states[0].numel()\n\n    # Choose BLOCK_SIZE based on the size of the data\n    BLOCK_SIZE = 1024\n\n    # Launch kernel with batch_size blocks\n    grid = (batch_size,)\n    track_mamba_state_if_needed_kernel[grid](\n        conv_states,\n        ssm_states,\n        cache_indices,\n        mamba_track_mask,\n        mamba_track_indices,\n        conv_states.stride(0),\n        ssm_states.stride(0),\n        conv_state_numel_per_row,\n        ssm_state_numel_per_row,\n        BLOCK_SIZE,\n    )\n\n\nclass MambaAttnBackendBase(AttentionBackend):\n    def __init__(self, model_runner: ModelRunner):\n        super().__init__()\n        self.pad_slot_id = PAD_SLOT_ID\n        self.device = model_runner.device\n        self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool\n        self.forward_metadata: ForwardMetadata = None\n        self.state_indices_list = []\n        self.query_start_loc_list = []\n        self.retrieve_next_token_list = []\n        self.retrieve_next_sibling_list = []\n        self.retrieve_parent_token_list = []\n        self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None\n        self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None\n        self.conv_states_shape: tuple[int, int] = None\n\n    def _forward_metadata(self, forward_batch: ForwardBatch):\n        bs = forward_batch.batch_size\n\n        retrieve_next_token = None\n        
retrieve_next_sibling = None\n        retrieve_parent_token = None\n        track_conv_indices = None\n        track_ssm_h_src = None\n        track_ssm_h_dst = None\n        track_ssm_final_src = None\n        track_ssm_final_dst = None\n\n        mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(\n            forward_batch.req_pool_indices\n        )\n\n        if forward_batch.forward_mode.is_decode_or_idle():\n            query_start_loc = torch.arange(\n                0, bs + 1, dtype=torch.int32, device=self.device\n            )\n        elif forward_batch.forward_mode.is_extend(include_draft_extend_v2=True):\n            if forward_batch.forward_mode.is_draft_extend_v2():\n                # HybridLinearAttnBackend.init_forward_metadata calls all sub-backends\n                # unconditionally, but DRAFT_EXTEND_V2 only runs full-attn layers in\n                # the draft model, so mamba metadata can be skipped.\n                query_start_loc = None\n            elif forward_batch.forward_mode.is_target_verify():\n                query_start_loc = torch.arange(\n                    0,\n                    forward_batch.input_ids.shape[0] + 1,\n                    step=forward_batch.spec_info.draft_token_num,\n                    dtype=torch.int32,\n                    device=forward_batch.input_ids.device,\n                )\n\n                if forward_batch.spec_info.topk > 1:\n                    retrieve_next_token = forward_batch.spec_info.retrive_next_token\n                    retrieve_next_sibling = forward_batch.spec_info.retrive_next_sibling\n                    # retrieve_next_token is None during dummy run so skip tensor creation\n                    if retrieve_next_token is not None:\n                        retrieve_parent_token = torch.empty_like(retrieve_next_token)\n            else:\n                query_start_loc = torch.empty(\n                    (bs + 1,), dtype=torch.int32, device=self.device\n                )\n     
           query_start_loc[:bs] = forward_batch.extend_start_loc\n                query_start_loc[bs] = (\n                    forward_batch.extend_start_loc[-1]\n                    + forward_batch.extend_seq_lens[-1]\n                )\n                if (\n                    forward_batch.mamba_track_mask is not None\n                    and forward_batch.mamba_track_mask.any()\n                ):\n                    track_conv_indices = self._init_track_conv_indices(\n                        query_start_loc, forward_batch\n                    )\n\n                    (\n                        track_ssm_h_src,\n                        track_ssm_h_dst,\n                        track_ssm_final_src,\n                        track_ssm_final_dst,\n                    ) = self._init_track_ssm_indices(mamba_cache_indices, forward_batch)\n        else:\n            raise ValueError(f\"Invalid forward mode: {forward_batch.forward_mode=}\")\n\n        return ForwardMetadata(\n            query_start_loc=query_start_loc,\n            mamba_cache_indices=mamba_cache_indices,\n            retrieve_next_token=retrieve_next_token,\n            retrieve_next_sibling=retrieve_next_sibling,\n            retrieve_parent_token=retrieve_parent_token,\n            track_conv_indices=track_conv_indices,\n            track_ssm_h_src=track_ssm_h_src,\n            track_ssm_h_dst=track_ssm_h_dst,\n            track_ssm_final_src=track_ssm_final_src,\n            track_ssm_final_dst=track_ssm_final_dst,\n        )\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        self.forward_metadata = self._forward_metadata(forward_batch)\n\n    def _init_track_conv_indices(\n        self, query_start_loc: torch.Tensor, forward_batch: ForwardBatch\n    ):\n        \"\"\"\n        Compute indices for extracting conv states from the input sequence during extend.\n\n        In Mamba models, the conv layer maintains a sliding window of recent inputs.\n        After processing 
a prefill chunk, we need to save the last `conv_state_len` tokens\n        of the processed region for prefix caching.\n\n        The key insight is that FLA (Flash Linear Attention) processes sequences in chunks\n        of FLA_CHUNK_SIZE. We only track the conv state up to the last complete chunk boundary\n        (aligned_len).\n\n        start_indices holds, for each tracked request, the packed-batch position of the first\n        token of the conv window to save. indices expands each start index into conv_state_len\n        consecutive positions, one row per request that needs to be tracked (i.e. mamba_track_mask is True).\n\n        Returns:\n            indices: Tensor of shape [num_tracked_requests, conv_state_len] containing\n                     flattened positions into the packed input tensor.\n        \"\"\"\n        conv_state_len = self.conv_states_shape[-1]\n\n        # Calculate the end position of the last aligned chunk\n        lens_to_track = (\n            forward_batch.mamba_track_seqlens - forward_batch.extend_prefix_lens\n        )\n        mamba_cache_chunk_size = get_global_server_args().mamba_cache_chunk_size\n        aligned_len = (lens_to_track // mamba_cache_chunk_size) * mamba_cache_chunk_size\n        start_indices = query_start_loc[:-1] + aligned_len - conv_state_len\n        start_indices = start_indices[forward_batch.mamba_track_mask]\n\n        # Create indices: [num_tracked_requests, conv_state_len]\n        indices = start_indices.unsqueeze(-1) + torch.arange(\n            conv_state_len,\n            device=self.device,\n            dtype=start_indices.dtype,\n        )\n\n        return indices.clamp(0, query_start_loc[-1] - 1)\n\n    def _init_track_ssm_indices(\n        self, mamba_cache_indices: torch.Tensor, forward_batch: ForwardBatch\n    ):\n        \"\"\"\n        Compute source and destination indices for tracking SSM states for prefix caching.\n\n        After processing a prefill, we need to save the SSM recurrent state for prefix caching.\n        The FLA kernel 
outputs intermediate hidden states `h` at each chunk boundary,\n        plus a `last_recurrent_state` at the end of the chunked prefill.\n\n        The challenge is that sequences may or may not end on a chunk boundary:\n          - Aligned case (len % FLA_CHUNK_SIZE == 0): In this case, FLA will store the to-cache\n            state in the last_recurrent_state.\n          - Unaligned case (len % FLA_CHUNK_SIZE != 0): The last_recurrent_state includes the\n            unaligned position, but we only want state up to the last chunk boundary.\n            We must extract from the intermediate `h` tensor at the appropriate chunk index.\n\n        We compute the src and dst indices for all requests that need to be cached\n        (i.e. mamba_track_mask is True) based on the rule above.\n\n        For example, with FLA_CHUNK_SIZE = 64:\n        1. If chunked prefill length is < 64, then only the final state has a value. In this case we\n           cache the `final` state.\n        2. If chunked prefill length == 64, then only the final state has a value. In this case we\n           cache pos 64 from the `final` state.\n        3. If chunked prefill length is > 64 and < 128, then both h and the final state have values.\n           We cache pos 64 from the `h` state.\n        4. If chunked prefill length == 128, then both h and the final state have values. We cache\n           pos 128 from `final` state. 
Note `h` doesn't include the pos 128.\n\n        Returns:\n            track_ssm_h_src: Source indices into the packed `h` tensor (for unaligned seqs)\n            track_ssm_h_dst: Destination cache slot indices (for unaligned seqs)\n            track_ssm_final_src: Source indices into last_recurrent_state buffer (for aligned seqs)\n            track_ssm_final_dst: Destination cache slot indices (for aligned seqs)\n        \"\"\"\n        # Move to CPU to avoid kernel launches for masking operations\n        mamba_track_mask = forward_batch.mamba_track_mask.cpu()\n        extend_seq_lens = forward_batch.extend_seq_lens.cpu()\n        mamba_track_indices = forward_batch.mamba_track_indices.cpu()\n        mamba_cache_indices = mamba_cache_indices.cpu()\n        mamba_track_seqlens = forward_batch.mamba_track_seqlens.cpu()\n        prefix_lens = forward_batch.extend_prefix_lens.cpu()\n\n        # Calculate the number of hidden states per request\n        num_h_states = (extend_seq_lens - 1) // FLA_CHUNK_SIZE + 1\n\n        # Calculate the starting offset for each sequence in the packed batch\n        track_ssm_src_offset = torch.zeros_like(num_h_states)\n        track_ssm_src_offset[1:] = torch.cumsum(num_h_states[:-1], dim=0)\n\n        # Filter variables by track mask\n        lens_to_track = mamba_track_seqlens - prefix_lens\n        lens_masked = lens_to_track[mamba_track_mask]\n        offset_masked = track_ssm_src_offset[mamba_track_mask]\n        dst_masked = mamba_track_indices[mamba_track_mask]\n\n        # Determine if the sequence ends at a chunk boundary\n        is_aligned = (lens_masked % FLA_CHUNK_SIZE) == 0\n\n        # Case 1: Aligned. Use last_recurrent_state from ssm_states.\n        track_ssm_final_src = mamba_cache_indices[mamba_track_mask][is_aligned]\n        track_ssm_final_dst = dst_masked[is_aligned]\n\n        # Case 2: Unaligned. 
Use intermediate state from h.\n        # TODO: if support FLA_CHUNK_SIZE % page size != 0, then need to modify this\n        not_aligned = ~is_aligned\n        track_ssm_h_src = offset_masked[not_aligned] + (\n            lens_masked[not_aligned] // FLA_CHUNK_SIZE\n        )\n        track_ssm_h_dst = dst_masked[not_aligned]\n\n        # Move back to GPU\n        return (\n            track_ssm_h_src.to(self.device, non_blocking=True),\n            track_ssm_h_dst.to(self.device, non_blocking=True),\n            track_ssm_final_src.to(self.device, non_blocking=True),\n            track_ssm_final_dst.to(self.device, non_blocking=True),\n        )\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],\n    ):\n        self.forward_metadata = self._capture_metadata(\n            bs, req_pool_indices, forward_mode, spec_info\n        )\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        self.forward_metadata = self._replay_metadata(\n            bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu\n        )\n\n    def init_forward_metadata_capture_cpu_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[Union[EagleDraftInput, 
EagleVerifyInput]],\n    ):\n        self.forward_metadata = self._capture_metadata(\n            bs, req_pool_indices, forward_mode, spec_info\n        )\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        assert (\n            max_num_tokens % max_bs == 0\n        ), f\"max_num_tokens={max_num_tokens} must be divisible by max_bs={max_bs}\"\n        draft_token_num = max_num_tokens // max_bs\n        for i in range(max_bs):\n            self.state_indices_list.append(\n                torch.full(\n                    (i + 1,), self.pad_slot_id, dtype=torch.int32, device=self.device\n                )\n            )\n            self.query_start_loc_list.append(\n                torch.zeros((i + 2,), dtype=torch.int32, device=self.device)\n            )\n            self.retrieve_next_token_list.append(\n                torch.zeros(\n                    (i + 1, draft_token_num), dtype=torch.int32, device=self.device\n                )\n            )\n            self.retrieve_next_sibling_list.append(\n                torch.zeros(\n                    (i + 1, draft_token_num), dtype=torch.int32, device=self.device\n                )\n            )\n            self.retrieve_parent_token_list.append(\n                torch.zeros(\n                    (i + 1, draft_token_num), dtype=torch.int32, device=self.device\n                )\n            )\n        self.cached_cuda_graph_decode_query_start_loc = torch.arange(\n            0, max_bs + 1, dtype=torch.int32, device=self.device\n        )\n        self.cached_cuda_graph_verify_query_start_loc = torch.arange(\n            0,\n            max_bs * draft_token_num + 1,\n            step=draft_token_num,\n            dtype=torch.int32,\n            device=self.device,\n        )\n\n    def init_cpu_graph_state(self, max_bs: int, max_num_tokens: int):\n        assert (\n            max_num_tokens % max_bs == 0\n        ), f\"max_num_tokens={max_num_tokens} must be divisible by 
max_bs={max_bs}\"\n        for i in range(max_bs):\n            self.state_indices_list.append(\n                torch.full(\n                    (i + 1,), self.pad_slot_id, dtype=torch.int32, device=self.device\n                )\n            )\n            self.query_start_loc_list.append(\n                torch.empty((i + 2,), dtype=torch.int32, device=self.device)\n            )\n        self.cached_cuda_graph_decode_query_start_loc = torch.arange(\n            0, max_bs + 1, dtype=torch.int32, device=self.device\n        )\n\n    def _capture_metadata(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        forward_mode: ForwardMode,\n        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],\n    ):\n        if forward_mode.is_decode_or_idle():\n            self.query_start_loc_list[bs - 1].copy_(\n                self.cached_cuda_graph_decode_query_start_loc[: bs + 1]\n            )\n        elif forward_mode.is_target_verify():\n            self.query_start_loc_list[bs - 1].copy_(\n                self.cached_cuda_graph_verify_query_start_loc[: bs + 1]\n            )\n        else:\n            raise ValueError(f\"Invalid forward mode: {forward_mode=}\")\n        mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)\n        self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)\n\n        # If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask\n        if forward_mode.is_target_verify() and spec_info.topk > 1:\n            # They are None during cuda graph capture so skip the copy_...\n            # self.retrieve_next_token_list[bs - 1].copy_(spec_info.retrive_next_token)\n            # self.retrieve_next_sibling_list[bs - 1].copy_(spec_info.retrive_next_sibling)\n            return ForwardMetadata(\n                query_start_loc=self.query_start_loc_list[bs - 1],\n                
mamba_cache_indices=self.state_indices_list[bs - 1],\n                retrieve_next_token=self.retrieve_next_token_list[bs - 1],\n                retrieve_next_sibling=self.retrieve_next_sibling_list[bs - 1],\n                retrieve_parent_token=self.retrieve_parent_token_list[bs - 1],\n            )\n        else:\n            return ForwardMetadata(\n                query_start_loc=self.query_start_loc_list[bs - 1],\n                mamba_cache_indices=self.state_indices_list[bs - 1],\n            )\n\n    def _replay_metadata(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        num_padding = torch.count_nonzero(\n            seq_lens_cpu == self.get_cuda_graph_seq_len_fill_value()\n        )\n        # Make sure forward metadata is correctly handled for padding reqs\n        req_pool_indices[bs - num_padding :] = 0\n        mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)\n        mamba_indices[bs - num_padding :] = -1\n        self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)\n        if forward_mode.is_decode_or_idle():\n            if num_padding == 0:\n                self.query_start_loc_list[bs - 1].copy_(\n                    self.cached_cuda_graph_decode_query_start_loc[: bs + 1]\n                )\n            else:\n                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(\n                    self.cached_cuda_graph_decode_query_start_loc[: bs - num_padding]\n                )\n                self.query_start_loc_list[bs - 1][bs - num_padding :].fill_(\n                    bs - num_padding\n                )\n        elif forward_mode.is_target_verify():\n            if num_padding == 0:\n                self.query_start_loc_list[bs - 1].copy_(\n                    self.cached_cuda_graph_verify_query_start_loc[: bs + 
1]\n                )\n            else:\n                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(\n                    self.cached_cuda_graph_verify_query_start_loc[: bs - num_padding]\n                )\n                self.query_start_loc_list[bs - 1][bs - num_padding :].fill_(\n                    (bs - num_padding) * spec_info.draft_token_num\n                )\n        else:\n            raise ValueError(f\"Invalid forward mode: {forward_mode=}\")\n\n        # If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask\n        if forward_mode.is_target_verify() and spec_info.topk > 1:\n            bs_without_pad = spec_info.retrive_next_token.shape[0]\n            self.retrieve_next_token_list[bs - 1][:bs_without_pad].copy_(\n                spec_info.retrive_next_token\n            )\n            self.retrieve_next_sibling_list[bs - 1][:bs_without_pad].copy_(\n                spec_info.retrive_next_sibling\n            )\n            return ForwardMetadata(\n                query_start_loc=self.query_start_loc_list[bs - 1],\n                mamba_cache_indices=self.state_indices_list[bs - 1],\n                retrieve_next_token=self.retrieve_next_token_list[bs - 1],\n                retrieve_next_sibling=self.retrieve_next_sibling_list[bs - 1],\n                retrieve_parent_token=self.retrieve_parent_token_list[bs - 1],\n            )\n        else:\n            return ForwardMetadata(\n                query_start_loc=self.query_start_loc_list[bs - 1],\n                mamba_cache_indices=self.state_indices_list[bs - 1],\n            )\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        return 1  # Mamba attn does not use seq lens to index kv cache\n\n    def get_cpu_graph_seq_len_fill_value(self):\n        return 1\n\n    def _track_mamba_state_decode(\n        self,\n        forward_batch: ForwardBatch,\n        conv_states: torch.Tensor,\n        ssm_states: 
torch.Tensor,\n        cache_indices: torch.Tensor,\n    ):\n        \"\"\"\n        Track and copy Mamba conv/SSM states during decode for prefix caching.\n\n        During decode, each token update modifies conv_states and ssm_states in-place\n        at positions indexed by cache_indices (the working slots). For prefix caching,\n        we need to copy these updated states to persistent cache slots (mamba_track_indices)\n        so they can be prefix cached.\n\n        This delegates to `track_mamba_states_if_needed`, which performs:\n            conv_states[mamba_track_indices[i]] = conv_states[cache_indices[i]]\n            ssm_states[mamba_track_indices[i]] = ssm_states[cache_indices[i]]\n        for all requests where mamba_track_mask[i] is True.\n        \"\"\"\n        if forward_batch.mamba_track_mask is not None:\n            track_mamba_states_if_needed(\n                conv_states,\n                ssm_states,\n                cache_indices,\n                forward_batch.mamba_track_mask,\n                forward_batch.mamba_track_indices,\n                forward_batch.batch_size,\n            )\n\n    def _track_mamba_state_extend(\n        self,\n        forward_batch: ForwardBatch,\n        h: torch.Tensor,\n        ssm_states: torch.Tensor,\n        forward_metadata: ForwardMetadata,\n    ):\n        \"\"\"\n        Track and copy SSM states during extend for prefix caching.\n\n        After the FLA chunked prefill kernel runs, we need to save the SSM recurrent\n        state at the last chunk boundary so it can be reused for prefix caching.\n        The source of the state depends on whether the sequence length is aligned\n        to FLA_CHUNK_SIZE. 
See `_init_track_ssm_indices` for more details on how\n        the source and destination indices are computed.\n\n        Note: Conv state tracking for extend is handled separately via gather operations\n        using indices computed by `_init_track_conv_indices`.\n        \"\"\"\n        if (\n            forward_batch.mamba_track_mask is not None\n            and forward_batch.mamba_track_mask.any()\n        ):\n            h = h.squeeze(0)\n\n            if forward_metadata.track_ssm_h_src.numel() > 0:\n                ssm_states[forward_metadata.track_ssm_h_dst] = h[\n                    forward_metadata.track_ssm_h_src\n                ].to(ssm_states.dtype, copy=False)\n            if forward_metadata.track_ssm_final_src.numel() > 0:\n                ssm_states[forward_metadata.track_ssm_final_dst] = ssm_states[\n                    forward_metadata.track_ssm_final_src\n                ]\n\n\nclass Mamba2AttnBackend(MambaAttnBackendBase):\n    \"\"\"Attention backend wrapper for Mamba2Mixer kernels.\"\"\"\n\n    def __init__(self, model_runner: ModelRunner):\n        super().__init__(model_runner)\n        config = model_runner.mamba2_config\n        assert config is not None\n        self.mamba_chunk_size = config.mamba_chunk_size\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        metadata = self._forward_metadata(forward_batch)\n        self.forward_metadata = Mamba2Metadata.prepare_mixed(\n            metadata,\n            self.mamba_chunk_size,\n            forward_batch,\n        )\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],\n    ):\n        metadata = self._capture_metadata(bs, req_pool_indices, forward_mode, spec_info)\n        
draft_token_num = spec_info.draft_token_num if spec_info is not None else 1\n        self.forward_metadata = Mamba2Metadata.prepare_decode(\n            metadata,\n            seq_lens,\n            is_target_verify=forward_mode.is_target_verify(),\n            draft_token_num=draft_token_num,\n        )\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        metadata = self._replay_metadata(\n            bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu\n        )\n        draft_token_num = spec_info.draft_token_num if spec_info is not None else 1\n        self.forward_metadata = Mamba2Metadata.prepare_decode(\n            metadata,\n            seq_lens,\n            is_target_verify=forward_mode.is_target_verify(),\n            draft_token_num=draft_token_num,\n        )\n\n    def forward(\n        self,\n        mixer: MambaMixer2,\n        hidden_states: torch.Tensor,\n        output: torch.Tensor,\n        layer_id: int,\n        mup_vector: Optional[torch.Tensor] = None,\n        use_triton_causal_conv: bool = False,\n    ):\n        assert isinstance(self.forward_metadata, Mamba2Metadata)\n        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)\n        return mixer.forward(\n            hidden_states=hidden_states,\n            output=output,\n            layer_cache=layer_cache,\n            metadata=self.forward_metadata,\n            mup_vector=mup_vector,\n            use_triton_causal_conv=use_triton_causal_conv,\n        )\n\n    def forward_decode(self, *args, **kwargs):\n        raise NotImplementedError(\n            \"Mamba2AttnBackend's forward is called directly instead of 
through HybridLinearAttnBackend, as it supports mixed prefill and decode\"\n        )\n\n    def forward_extend(self, *args, **kwargs):\n        raise NotImplementedError(\n            \"Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode\"\n        )\n\n\nclass HybridLinearAttnBackend(AttentionBackend):\n    \"\"\"Manages a full and linear attention backend\"\"\"\n\n    def __init__(\n        self,\n        full_attn_backend: AttentionBackend,\n        linear_attn_backend: MambaAttnBackendBase,\n        full_attn_layers: list[int],\n    ):\n        self.full_attn_layers = full_attn_layers\n        self.full_attn_backend = full_attn_backend\n        self.linear_attn_backend = linear_attn_backend\n        self.attn_backend_list = [full_attn_backend, linear_attn_backend]\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        for attn_backend in self.attn_backend_list:\n            attn_backend.init_forward_metadata(forward_batch)\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        for attn_backend in self.attn_backend_list:\n            attn_backend.init_cuda_graph_state(max_bs, max_num_tokens)\n\n    def init_cpu_graph_state(self, max_bs: int, max_num_tokens: int):\n        for attn_backend in self.attn_backend_list:\n            attn_backend.init_cpu_graph_state(max_bs, max_num_tokens)\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        for attn_backend in self.attn_backend_list:\n            attn_backend.init_forward_metadata_capture_cuda_graph(\n                bs,\n                num_tokens,\n                req_pool_indices,\n                seq_lens,\n        
        encoder_lens,\n                forward_mode,\n                spec_info,\n            )\n\n    def init_forward_metadata_capture_cpu_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],\n    ):\n        for attn_backend in self.attn_backend_list:\n            attn_backend.init_forward_metadata_capture_cpu_graph(\n                bs,\n                num_tokens,\n                req_pool_indices,\n                seq_lens,\n                encoder_lens,\n                forward_mode,\n                spec_info,\n            )\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        for attn_backend in self.attn_backend_list:\n            attn_backend.init_forward_metadata_replay_cuda_graph(\n                bs,\n                req_pool_indices,\n                seq_lens,\n                seq_lens_sum,\n                encoder_lens,\n                forward_mode,\n                spec_info,\n                seq_lens_cpu,\n            )\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()\n\n    def get_cpu_graph_seq_len_fill_value(self):\n        return self.full_attn_backend.get_cpu_graph_seq_len_fill_value()\n\n    def forward_decode(\n        self,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        q: Optional[torch.Tensor] = None,  # For full attention\n        k: 
Optional[torch.Tensor] = None,  # For full attention\n        v: Optional[torch.Tensor] = None,  # For full attention\n        mixed_qkv: Optional[torch.Tensor] = None,  # For linear attention\n        a: Optional[torch.Tensor] = None,  # For GDN linear attention\n        b: Optional[torch.Tensor] = None,  # For GDN linear attention\n        **kwargs,\n    ):\n        layer_id = layer.layer_id if layer else kwargs[\"layer_id\"]\n        if layer_id in self.full_attn_layers:\n            return self.full_attn_backend.forward_decode(\n                q, k, v, layer, forward_batch, save_kv_cache, **kwargs\n            )\n        # Linear attention backend\n        return self.linear_attn_backend.forward_decode(\n            q=q,\n            k=k,\n            v=v,\n            layer=layer,\n            forward_batch=forward_batch,\n            save_kv_cache=save_kv_cache,\n            mixed_qkv=mixed_qkv,\n            a=a,\n            b=b,\n            **kwargs,\n        )\n\n    def forward_extend(\n        self,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        q: Optional[torch.Tensor] = None,  # For full attention\n        k: Optional[torch.Tensor] = None,  # For full attention\n        v: Optional[torch.Tensor] = None,  # For full attention\n        mixed_qkv: Optional[torch.Tensor] = None,  # For linear attention\n        a: Optional[torch.Tensor] = None,  # For GDN linear attention\n        b: Optional[torch.Tensor] = None,  # For GDN linear attention\n        **kwargs,\n    ):\n        layer_id = layer.layer_id if layer else kwargs[\"layer_id\"]\n        if layer_id in self.full_attn_layers:\n            return self.full_attn_backend.forward_extend(\n                q, k, v, layer, forward_batch, save_kv_cache, **kwargs\n            )\n        # Linear attention backend\n        return self.linear_attn_backend.forward_extend(\n            q=q,\n            k=k,\n            v=v,\n            
layer=layer,\n            forward_batch=forward_batch,\n            save_kv_cache=save_kv_cache,\n            mixed_qkv=mixed_qkv,\n            a=a,\n            b=b,\n            **kwargs,\n        )\n\n    def forward(\n        self,\n        q: Optional[torch.Tensor] = None,  # For full attention\n        k: Optional[torch.Tensor] = None,  # For full attention\n        v: Optional[torch.Tensor] = None,  # For full attention\n        layer: RadixAttention = None,\n        forward_batch: ForwardBatch = None,\n        save_kv_cache: bool = True,\n        mixed_qkv: Optional[torch.Tensor] = None,  # For linear attention\n        a: Optional[torch.Tensor] = None,  # For linear attention\n        b: Optional[torch.Tensor] = None,  # For linear attention\n        **kwargs,\n    ):\n        layer_id = layer.layer_id if layer else kwargs[\"layer_id\"]\n        is_linear_attn = layer_id not in self.full_attn_layers\n\n        if forward_batch.forward_mode.is_idle():\n            if is_linear_attn:\n                return mixed_qkv.new_empty(\n                    mixed_qkv.shape[0], layer.num_v_heads, layer.head_v_dim\n                )\n            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)\n        elif forward_batch.forward_mode.is_decode():\n            return self.forward_decode(\n                layer,\n                forward_batch,\n                save_kv_cache,\n                q,\n                k,\n                v,\n                mixed_qkv,\n                a,\n                b,\n                **kwargs,\n            )\n        else:\n            return self.forward_extend(\n                layer,\n                forward_batch,\n                save_kv_cache,\n                q,\n                k,\n                v,\n                mixed_qkv,\n                a,\n                b,\n                **kwargs,\n            )\n\n    def update_mamba_state_after_mtp_verify(\n        self,\n        accepted_steps: 
torch.Tensor,\n        mamba_track_indices: Optional[torch.Tensor],\n        mamba_steps_to_track: Optional[torch.Tensor],\n        model,\n    ):\n        \"\"\"\n        Update mamba states after MTP verify using fully fused Triton kernel.\n\n        This replaces the original advanced indexing operations with a single fused\n        gather-scatter kernel that also handles masking internally, avoiding:\n        - index_elementwise_kernel from tensor[bool_mask]\n        - index_select kernel launches\n        - nonzero kernel launches\n        \"\"\"\n        request_number = accepted_steps.shape[0]\n\n        state_indices_tensor = (\n            self.linear_attn_backend.forward_metadata.mamba_cache_indices[\n                :request_number\n            ]\n        )\n\n        mamba_caches = (\n            self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()\n        )\n\n        conv_states = mamba_caches.conv[0]\n        ssm_states = mamba_caches.temporal\n        intermediate_state_cache = mamba_caches.intermediate_ssm\n        intermediate_conv_window_cache = mamba_caches.intermediate_conv_window[0]\n\n        # Use fully fused kernel that handles masking internally\n        # This avoids separate nonzero() and index_select() calls\n        fused_mamba_state_scatter_with_mask(\n            ssm_states,\n            intermediate_state_cache,\n            state_indices_tensor,\n            accepted_steps,\n        )\n        fused_mamba_state_scatter_with_mask(\n            conv_states,\n            intermediate_conv_window_cache,\n            state_indices_tensor,\n            accepted_steps,\n        )\n\n        # Track indices used for tracking mamba states for prefix cache\n        if mamba_track_indices is not None:\n            assert mamba_steps_to_track is not None\n            # Use fully fused kernel for track scatter operations\n            fused_mamba_state_scatter_with_mask(\n                ssm_states,\n            
    intermediate_state_cache,\n                mamba_track_indices,\n                mamba_steps_to_track,\n            )\n            fused_mamba_state_scatter_with_mask(\n                conv_states,\n                intermediate_conv_window_cache,\n                mamba_track_indices,\n                mamba_steps_to_track,\n            )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/intel_amx_backend.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n\n\nclass IntelAMXAttnBackend(AttentionBackend):\n    def __init__(self, model_runner: ModelRunner):\n        import sgl_kernel  # noqa: F401\n\n        super().__init__()\n        self.forward_metadata = None\n        self.device = model_runner.device\n\n        self.num_head = (\n            model_runner.model_config.num_attention_heads // model_runner.tp_size\n        )\n\n        # [NB]: `layer_id` set to 0 for qwen3-next models, as not all attn layers require kv pool\n        # using \"full_attention_layer_id_mapping\" to map which layer needs kv pool\n        layer_id = 0\n        if hasattr(model_runner.token_to_kv_pool, \"full_attention_layer_id_mapping\"):\n            layer_id = [*model_runner.token_to_kv_pool.full_attention_layer_id_mapping][\n                0\n            ]\n        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(\n            layer_id\n        ).shape[-1]\n        self.decode_attention_fwd = torch.ops.sgl_kernel.decode_attention_cpu\n        self.extend_attention_fwd = torch.ops.sgl_kernel.extend_attention_cpu\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Init the metadata for a forward pass.\"\"\"\n\n        bs = forward_batch.batch_size\n        attn_logits = torch.zeros(\n            (\n                bs,\n                self.num_head,\n                8,  # self.num_kv_splits,\n                self.v_head_dim + 1,\n            ),\n            dtype=torch.float32,\n            device=self.device,\n        )\n        if forward_batch.forward_mode.is_decode_or_idle():\n         
   max_extend_len = None\n        else:\n            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()\n        self.forward_metadata = (attn_logits, max_extend_len)\n\n    def get_cpu_graph_seq_len_fill_value(self):\n        return 1\n\n    def init_forward_metadata_capture_cpu_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens,\n        forward_mode,\n        spec_info,\n    ):\n        attn_logits = torch.zeros(\n            (\n                bs,\n                self.num_head,\n                8,  # self.num_kv_splits,\n                self.v_head_dim + 1,\n            ),\n            dtype=torch.float32,\n            device=self.device,\n        )\n        max_extend_len = None\n        self.forward_metadata = (attn_logits, max_extend_len)\n\n    def init_cpu_graph_state(self, max_bs: int, max_num_tokens: int):\n        pass\n\n    def forward_extend(\n        self,\n        q,\n        k,\n        v,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n    ):\n        if layer.qk_head_dim != layer.v_head_dim:\n            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        if save_kv_cache:\n            forward_batch.token_to_kv_pool.set_kv_buffer(\n                layer, forward_batch.out_cache_loc, k, v\n            )\n\n        _, max_extend_len = self.forward_metadata\n\n        self.extend_attention_fwd(\n            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n            k,\n            v,\n            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n            forward_batch.req_to_token_pool.req_to_token,\n            
forward_batch.req_pool_indices,\n            forward_batch.seq_lens,\n            forward_batch.extend_seq_lens,\n            forward_batch.extend_start_loc,\n            max_extend_len,\n            layer.scaling,\n            layer.logit_cap,\n        )\n        return o\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n    ):\n        attn_logits, _ = self.forward_metadata\n\n        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)\n\n        if layer.qk_head_dim != layer.v_head_dim:\n            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        self.decode_attention_fwd(\n            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n            k,\n            v,\n            forward_batch.out_cache_loc,\n            attn_logits,\n            forward_batch.req_to_token_pool.req_to_token,\n            forward_batch.req_pool_indices,\n            forward_batch.seq_lens,\n            layer.scaling,\n            layer.logit_cap,\n        )\n\n        return o\n\n    def support_triton(self):\n        return False\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/gdn_backend.py",
    "content": "from typing import Optional, Tuple, Union\n\nimport torch\n\nfrom sglang.srt.layers.attention.fla.fused_gdn_gating import fused_gdn_gating\nfrom sglang.srt.layers.attention.hybrid_linear_attn_backend import MambaAttnBackendBase\nfrom sglang.srt.layers.attention.linear.kernels.gdn_triton import TritonGDNKernel\nfrom sglang.srt.layers.attention.linear.utils import (\n    LinearAttnKernelBackend,\n    get_linear_attn_decode_backend,\n    get_linear_attn_prefill_backend,\n)\nfrom sglang.srt.layers.attention.mamba.causal_conv1d_triton import (\n    causal_conv1d_fn,\n    causal_conv1d_update,\n)\nfrom sglang.srt.layers.radix_linear_attention import RadixLinearAttention\nfrom sglang.srt.mem_cache.memory_pool import MambaPool\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\nfrom sglang.srt.model_executor.model_runner import ModelRunner\nfrom sglang.srt.utils import is_cpu, is_cuda, is_npu\nfrom sglang.srt.utils.common import rank0_log\n\nif not is_cpu():\n    from sglang.srt.layers.attention.fla.chunk_delta_h import (\n        CHUNK_SIZE as FLA_CHUNK_SIZE,\n    )\n\nif is_cuda():\n    from sglang.srt.layers.attention.mamba.causal_conv1d import (\n        causal_conv1d_fn as causal_conv1d_fn_cuda,\n    )\n\n    causal_conv1d_fn = causal_conv1d_fn_cuda\nelif is_npu():\n    from sgl_kernel_npu.fla.fused_gdn_gating import fused_gdn_gating_npu\n    from sgl_kernel_npu.mamba.causal_conv1d import (\n        causal_conv1d_fn_npu,\n        causal_conv1d_update_npu,\n    )\n\n    fused_gdn_gating = fused_gdn_gating_npu\n    causal_conv1d_fn = causal_conv1d_fn_npu\n    causal_conv1d_update = causal_conv1d_update_npu\nelif is_cpu():\n    from sgl_kernel.mamba import causal_conv1d_fn_cpu, causal_conv1d_update_cpu\n\n    causal_conv1d_fn = causal_conv1d_fn_cpu\n    causal_conv1d_update = causal_conv1d_update_cpu\n    fused_gdn_gating = torch.ops.sgl_kernel.fused_gdn_gating_cpu\n\n\nclass GDNKernelDispatcher:\n    \"\"\"Dispatches GDN kernel calls 
to the appropriate backend per mode.\"\"\"\n\n    def __init__(\n        self,\n        decode_backend: LinearAttnKernelBackend,\n        prefill_backend: LinearAttnKernelBackend,\n    ):\n        triton_kernel = TritonGDNKernel()\n\n        if decode_backend.is_triton():\n            self.decode_kernel = triton_kernel\n        elif decode_backend.is_cutedsl():\n            if not is_cuda():\n                raise ValueError(\"CuTe DSL backend requires CUDA\")\n            from sglang.srt.layers.attention.linear.kernels.gdn_cutedsl import (\n                CuteDSLGDNKernel,\n            )\n\n            self.decode_kernel = CuteDSLGDNKernel()\n        elif decode_backend.is_flashinfer():\n            if not is_cuda():\n                raise ValueError(\"FlashInfer GDN backend requires CUDA\")\n            from sglang.srt.layers.attention.linear.kernels.gdn_flashinfer import (\n                FlashInferGDNKernel,\n            )\n\n            flashinfer_kernel = FlashInferGDNKernel()\n            self.decode_kernel = flashinfer_kernel\n        else:\n            raise ValueError(f\"Unsupported GDN decode backend: {decode_backend}\")\n\n        if prefill_backend.is_triton():\n            self.extend_kernel = triton_kernel\n        elif prefill_backend.is_cutedsl():\n            raise ValueError(\n                \"CuTe DSL backend only supports decode, not prefill. 
\"\n                \"Use --linear-attn-prefill-backend triton instead.\"\n            )\n        elif prefill_backend.is_flashinfer():\n            if not is_cuda():\n                raise ValueError(\"FlashInfer GDN backend requires CUDA\")\n            # Reuse the FlashInfer kernel if already created for decode\n            if decode_backend.is_flashinfer():\n                self.extend_kernel = flashinfer_kernel\n            else:\n                from sglang.srt.layers.attention.linear.kernels.gdn_flashinfer import (\n                    FlashInferGDNKernel,\n                )\n\n                flashinfer_kernel = FlashInferGDNKernel()\n                self.extend_kernel = flashinfer_kernel\n        else:\n            raise ValueError(f\"Unsupported GDN prefill backend: {prefill_backend}\")\n\n        # Verify kernel: use FlashInfer if either decode or prefill selected it\n        if decode_backend.is_flashinfer() or prefill_backend.is_flashinfer():\n            self.verify_kernel = flashinfer_kernel\n        else:\n            self.verify_kernel = triton_kernel\n\n        self.supports_packed_decode = getattr(\n            self.decode_kernel, \"supports_packed_decode\", False\n        )\n\n        rank0_log(\n            f\"GDN kernel dispatcher: decode={self.decode_kernel.__class__.__name__}, \"\n            f\"extend={self.extend_kernel.__class__.__name__}, \"\n            f\"verify={self.verify_kernel.__class__.__name__} \"\n            f\"packed_decode={self.supports_packed_decode}\"\n        )\n\n    def packed_decode(\n        self,\n        mixed_qkv: torch.Tensor,\n        a: torch.Tensor,\n        b: torch.Tensor,\n        *,\n        A_log: torch.Tensor,\n        dt_bias: torch.Tensor,\n        scale: float,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        num_v_heads: int,\n        head_v_dim: int,\n        **kwargs,\n    ) -> Optional[torch.Tensor]:\n        \"\"\"Attempt packed decode. 
Returns output tensor or None if\n        the decode kernel does not support packed decode.\"\"\"\n        if not self.supports_packed_decode:\n            return None\n        return self.decode_kernel.packed_decode(\n            mixed_qkv,\n            a,\n            b,\n            A_log=A_log,\n            dt_bias=dt_bias,\n            scale=scale,\n            ssm_states=ssm_states,\n            cache_indices=cache_indices,\n            num_v_heads=num_v_heads,\n            head_v_dim=head_v_dim,\n            **kwargs,\n        )\n\n    def decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        a: torch.Tensor,\n        b: torch.Tensor,\n        *,\n        A_log: torch.Tensor,\n        dt_bias: torch.Tensor,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> torch.Tensor:\n        return self.decode_kernel.decode(\n            q,\n            k,\n            v,\n            a,\n            b,\n            A_log=A_log,\n            dt_bias=dt_bias,\n            ssm_states=ssm_states,\n            cache_indices=cache_indices,\n            query_start_loc=query_start_loc,\n            **kwargs,\n        )\n\n    def extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        g: torch.Tensor,\n        beta: torch.Tensor,\n        *,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> tuple:\n        return self.extend_kernel.extend(\n            q,\n            k,\n            v,\n            g,\n            beta,\n            ssm_states=ssm_states,\n            cache_indices=cache_indices,\n            query_start_loc=query_start_loc,\n            **kwargs,\n        )\n\n    def target_verify(\n        self,\n        A_log: torch.Tensor,\n        dt_bias: torch.Tensor,\n        q: 
torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        a: torch.Tensor,\n        b: torch.Tensor,\n        *,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> torch.Tensor:\n        return self.verify_kernel.target_verify(\n            A_log=A_log,\n            dt_bias=dt_bias,\n            q=q,\n            k=k,\n            v=v,\n            a=a,\n            b=b,\n            ssm_states=ssm_states,\n            cache_indices=cache_indices,\n            query_start_loc=query_start_loc,\n            **kwargs,\n        )\n\n\nclass GDNAttnBackend(MambaAttnBackendBase):\n    \"\"\"Attention backend for GDN (Gated Delta Network) linear attention.\"\"\"\n\n    def __init__(self, model_runner: ModelRunner):\n        super().__init__(model_runner)\n        self.conv_states_shape = (\n            model_runner.req_to_token_pool.mamba_pool.mamba_cache.conv[0].shape\n        )\n        if not is_cpu() and not is_npu():\n            assert (\n                self.conv_states_shape[-1] < FLA_CHUNK_SIZE\n            ), f\"{self.conv_states_shape[-1]=} should be less than {FLA_CHUNK_SIZE}\"\n\n        decode_backend = get_linear_attn_decode_backend()\n        prefill_backend = get_linear_attn_prefill_backend()\n        self.kernel_dispatcher = GDNKernelDispatcher(decode_backend, prefill_backend)\n\n    def forward_decode(\n        self,\n        layer: RadixLinearAttention,\n        forward_batch: ForwardBatch,\n        mixed_qkv: Union[torch.Tensor, Tuple[torch.Tensor, ...]],\n        a: torch.Tensor,\n        b: torch.Tensor,\n        **kwargs,\n    ):\n        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id)\n        conv_states = layer_cache.conv[0]\n        ssm_states = layer_cache.temporal\n        query_start_loc = self.forward_metadata.query_start_loc\n        cache_indices = self.forward_metadata.mamba_cache_indices\n\n        assert 
isinstance(mixed_qkv, torch.Tensor)\n        mixed_qkv = causal_conv1d_update(\n            mixed_qkv,\n            conv_states,\n            layer.conv_weights,\n            layer.bias,\n            layer.activation,\n            conv_state_indices=cache_indices,\n        )\n\n        # Skip split + reshape + separate gating kernel by consuming\n        # the packed mixed_qkv directly in a single fused Triton kernel.\n        if self.kernel_dispatcher.supports_packed_decode:\n            core_attn_out = self.kernel_dispatcher.packed_decode(\n                mixed_qkv=mixed_qkv,\n                a=a,\n                b=b,\n                A_log=layer.A_log,\n                dt_bias=layer.dt_bias,\n                scale=layer.head_k_dim**-0.5,\n                ssm_states=ssm_states,\n                cache_indices=cache_indices,\n                num_v_heads=layer.num_v_heads,\n                head_v_dim=layer.head_v_dim,\n            )\n            self._track_mamba_state_decode(\n                forward_batch, conv_states, ssm_states, cache_indices\n            )\n            return core_attn_out\n\n        query, key, value = torch.split(\n            mixed_qkv,\n            [layer.q_dim, layer.k_dim, layer.v_dim],\n            dim=-1,\n        )\n        # Reshape from [bs, h*d] to [1, bs, h, d]\n        bs = forward_batch.batch_size\n        query = query.view(1, bs, layer.num_q_heads, layer.head_q_dim)\n        key = key.view(1, bs, layer.num_k_heads, layer.head_k_dim)\n        value = value.view(1, bs, layer.num_v_heads, layer.head_v_dim)\n\n        core_attn_out = self.kernel_dispatcher.decode(\n            q=query,\n            k=key,\n            v=value,\n            a=a,\n            b=b,\n            A_log=layer.A_log,\n            dt_bias=layer.dt_bias,\n            ssm_states=ssm_states,\n            cache_indices=cache_indices,\n            query_start_loc=query_start_loc,\n        )\n\n        self._track_mamba_state_decode(\n            
forward_batch, conv_states, ssm_states, cache_indices\n        )\n\n        return core_attn_out\n\n    def forward_extend(\n        self,\n        layer: RadixLinearAttention,\n        forward_batch: ForwardBatch,\n        mixed_qkv: Union[torch.Tensor, Tuple[torch.Tensor, ...]],\n        a: torch.Tensor,\n        b: torch.Tensor,\n        **kwargs,\n    ):\n        assert isinstance(mixed_qkv, torch.Tensor)\n        seq_len = mixed_qkv.shape[0]\n\n        is_target_verify = forward_batch.forward_mode.is_target_verify()\n        forward_metadata = self.forward_metadata\n\n        query_start_loc = forward_metadata.query_start_loc\n        cache_indices = forward_metadata.mamba_cache_indices\n        retrieve_next_token = forward_metadata.retrieve_next_token\n        retrieve_next_sibling = forward_metadata.retrieve_next_sibling\n        retrieve_parent_token = forward_metadata.retrieve_parent_token\n\n        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id)\n        conv_states = mamba_cache_params.conv[0]\n        ssm_states = mamba_cache_params.temporal\n        if is_target_verify:\n            assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)\n            intermediate_state_cache = mamba_cache_params.intermediate_ssm\n            intermediate_conv_window_cache = (\n                mamba_cache_params.intermediate_conv_window[0]\n            )\n            has_initial_states = torch.ones(\n                seq_len // forward_batch.spec_info.draft_token_num,\n                dtype=torch.bool,\n                device=forward_batch.input_ids.device,\n            )\n            intermediate_state_indices = torch.arange(\n                cache_indices.shape[0], dtype=torch.int32, device=cache_indices.device\n            )\n        else:\n            has_initial_states = forward_batch.extend_prefix_lens > 0\n\n        if is_target_verify:\n            batch_size = seq_len // forward_batch.spec_info.draft_token_num\n         
   draft_token_num = forward_batch.spec_info.draft_token_num\n            mixed_qkv_reshaped = mixed_qkv.view(\n                batch_size, draft_token_num, -1\n            ).transpose(1, 2)\n            mixed_qkv_processed = causal_conv1d_update(\n                mixed_qkv_reshaped,\n                conv_states,\n                layer.conv_weights,\n                layer.bias,\n                layer.activation,\n                conv_state_indices=cache_indices[:batch_size],\n                intermediate_conv_window=intermediate_conv_window_cache,\n                intermediate_state_indices=intermediate_state_indices[:batch_size],\n                retrieve_next_token=retrieve_next_token,\n                retrieve_next_sibling=retrieve_next_sibling,\n                retrieve_parent_token=retrieve_parent_token,\n            )\n            mixed_qkv = mixed_qkv_processed.transpose(1, 2).view(seq_len, -1)\n        else:\n            mixed_qkv = mixed_qkv.transpose(0, 1)\n            if (\n                forward_batch.mamba_track_mask is not None\n                and forward_batch.mamba_track_mask.any()\n            ):\n                conv_dst = forward_batch.mamba_track_indices\n                mixed_qkv_to_track = mixed_qkv[\n                    :, forward_metadata.track_conv_indices\n                ].transpose(0, 1)\n                mask_indices = forward_batch.mamba_track_mask.nonzero(as_tuple=True)[0]\n                conv_states[conv_dst[mask_indices]] = mixed_qkv_to_track\n\n            mixed_qkv = causal_conv1d_fn(\n                mixed_qkv,\n                layer.conv_weights,\n                layer.bias,\n                activation=layer.activation,\n                conv_states=conv_states,\n                has_initial_state=has_initial_states,\n                cache_indices=cache_indices,\n                query_start_loc=query_start_loc,\n                seq_lens_cpu=forward_batch.extend_seq_lens_cpu,\n            ).transpose(0, 1)[:seq_len]\n\n        
query, key, value = torch.split(\n            mixed_qkv,\n            [layer.q_dim, layer.k_dim, layer.v_dim],\n            dim=-1,\n        )\n\n        actual_seq_len = query.shape[0]\n        query = query.view(1, actual_seq_len, layer.num_q_heads, layer.head_q_dim)\n        key = key.view(1, actual_seq_len, layer.num_k_heads, layer.head_k_dim)\n        value = value.view(1, actual_seq_len, layer.num_v_heads, layer.head_v_dim)\n\n        if is_target_verify:\n            core_attn_out = self.kernel_dispatcher.target_verify(\n                A_log=layer.A_log,\n                dt_bias=layer.dt_bias,\n                q=query,\n                k=key,\n                v=value,\n                a=a,\n                b=b,\n                ssm_states=ssm_states,\n                cache_indices=cache_indices,\n                query_start_loc=query_start_loc,\n                intermediate_states_buffer=intermediate_state_cache,\n                intermediate_state_indices=intermediate_state_indices,\n                cache_steps=forward_batch.spec_info.draft_token_num,\n                retrieve_parent_token=retrieve_parent_token,\n            )\n        else:\n            g, beta = fused_gdn_gating(layer.A_log, a, b, layer.dt_bias)\n            core_attn_out, last_recurrent_state, h = self.kernel_dispatcher.extend(\n                q=query,\n                k=key,\n                v=value,\n                g=g,\n                beta=beta,\n                ssm_states=ssm_states,\n                cache_indices=cache_indices,\n                query_start_loc=query_start_loc,\n            )\n\n            if (is_npu() or is_cpu()) and last_recurrent_state is not None:\n                last_recurrent_state = last_recurrent_state.to(\n                    ssm_states.dtype, copy=False\n                )\n                ssm_states[cache_indices] = last_recurrent_state\n\n            if h is not None:\n                self._track_mamba_state_extend(\n                    
forward_batch, h, ssm_states, forward_metadata\n                )\n\n        return core_attn_out\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/kda_backend.py",
    "content": "from typing import Tuple, Union\n\nimport torch\n\nfrom sglang.srt.layers.attention.hybrid_linear_attn_backend import MambaAttnBackendBase\nfrom sglang.srt.layers.attention.linear.kernels.kda_triton import TritonKDAKernel\nfrom sglang.srt.layers.attention.linear.utils import (\n    LinearAttnKernelBackend,\n    get_linear_attn_decode_backend,\n    get_linear_attn_prefill_backend,\n)\nfrom sglang.srt.layers.attention.mamba.causal_conv1d_triton import (\n    causal_conv1d_fn,\n    causal_conv1d_update,\n)\nfrom sglang.srt.layers.radix_linear_attention import RadixLinearAttention\nfrom sglang.srt.utils import is_cpu, is_npu\nfrom sglang.srt.utils.common import rank0_log\n\n# KDA always uses the triton causal_conv1d_fn (no CUDA override).\n# Only causal_conv1d_update needs platform-specific overrides for decode.\nif is_npu():\n    from sgl_kernel_npu.mamba.causal_conv1d import causal_conv1d_update_npu\n\n    causal_conv1d_update = causal_conv1d_update_npu\nelif is_cpu():\n    from sgl_kernel.mamba import causal_conv1d_update_cpu\n\n    causal_conv1d_update = causal_conv1d_update_cpu\n\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\nfrom sglang.srt.model_executor.model_runner import ModelRunner\n\n\nclass KDAKernelDispatcher:\n    \"\"\"Dispatches KDA kernel calls to the appropriate backend per mode.\"\"\"\n\n    def __init__(\n        self,\n        decode_backend: LinearAttnKernelBackend,\n        prefill_backend: LinearAttnKernelBackend,\n    ):\n        triton_kernel = TritonKDAKernel()\n\n        if decode_backend.is_triton():\n            self.decode_kernel = triton_kernel\n        else:\n            raise ValueError(\n                f\"Unsupported KDA decode backend: {decode_backend}. 
\"\n                \"KDA currently only supports 'triton'.\"\n            )\n\n        if prefill_backend.is_triton():\n            self.extend_kernel = triton_kernel\n        else:\n            raise ValueError(\n                f\"Unsupported KDA prefill backend: {prefill_backend}. \"\n                \"KDA currently only supports 'triton'.\"\n            )\n\n        rank0_log(\n            f\"KDA kernel dispatcher: decode={self.decode_kernel.__class__.__name__}, \"\n            f\"extend={self.extend_kernel.__class__.__name__}\"\n        )\n\n    def decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        a: torch.Tensor,\n        b: torch.Tensor,\n        *,\n        A_log: torch.Tensor,\n        dt_bias: torch.Tensor,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> torch.Tensor:\n        return self.decode_kernel.decode(\n            q,\n            k,\n            v,\n            a,\n            b,\n            A_log=A_log,\n            dt_bias=dt_bias,\n            ssm_states=ssm_states,\n            cache_indices=cache_indices,\n            query_start_loc=query_start_loc,\n            **kwargs,\n        )\n\n    def extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        g: torch.Tensor,\n        beta: torch.Tensor,\n        *,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> torch.Tensor:\n        return self.extend_kernel.extend(\n            q,\n            k,\n            v,\n            g,\n            beta,\n            ssm_states=ssm_states,\n            cache_indices=cache_indices,\n            query_start_loc=query_start_loc,\n            **kwargs,\n        )\n\n\nclass KDAAttnBackend(MambaAttnBackendBase):\n    \"\"\"Attention backend for KDA (Kimi Delta 
Attention) linear attention.\"\"\"\n\n    def __init__(self, model_runner: ModelRunner):\n        super().__init__(model_runner)\n        decode_backend = get_linear_attn_decode_backend()\n        prefill_backend = get_linear_attn_prefill_backend()\n        self.kernel_dispatcher = KDAKernelDispatcher(decode_backend, prefill_backend)\n\n    def forward_decode(\n        self,\n        layer: RadixLinearAttention,\n        mixed_qkv: Union[torch.Tensor, Tuple[torch.Tensor, ...]],\n        a: torch.Tensor,\n        b: torch.Tensor,\n        **kwargs,\n    ):\n        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id)\n        conv_states = layer_cache.conv[0]\n        ssm_states = layer_cache.temporal\n        query_start_loc = self.forward_metadata.query_start_loc\n        cache_indices = self.forward_metadata.mamba_cache_indices\n\n        qkv = causal_conv1d_update(\n            mixed_qkv,\n            conv_states.transpose(-1, -2),\n            layer.conv_weights,\n            layer.bias,\n            activation=\"silu\",\n            conv_state_indices=cache_indices,\n        )\n        q, k, v = qkv.split([layer.q_dim, layer.k_dim, layer.v_dim], dim=-1)\n        q = q.unflatten(-1, (-1, layer.head_q_dim)).unsqueeze(0)  # n (h d) -> 1 n h d\n        k = k.unflatten(-1, (-1, layer.head_k_dim)).unsqueeze(0)  # n (h d) -> 1 n h d\n        v = v.unflatten(-1, (-1, layer.head_v_dim)).unsqueeze(0)  # n (h d) -> 1 n h d\n\n        return self.kernel_dispatcher.decode(\n            q=q,\n            k=k,\n            v=v,\n            a=a,\n            b=b,\n            A_log=layer.A_log,\n            dt_bias=layer.dt_bias,\n            ssm_states=ssm_states,\n            cache_indices=cache_indices,\n            query_start_loc=query_start_loc,\n        )\n\n    def forward_extend(\n        self,\n        layer: RadixLinearAttention,\n        forward_batch: ForwardBatch,\n        mixed_qkv: Union[torch.Tensor, Tuple[torch.Tensor, ...]],\n        a: 
torch.Tensor,\n        b: torch.Tensor,\n        **kwargs,\n    ):\n        query_start_loc = self.forward_metadata.query_start_loc\n        cache_indices = self.forward_metadata.mamba_cache_indices\n\n        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id)\n        conv_states = mamba_cache_params.conv[0].transpose(-1, -2)\n\n        ssm_states = mamba_cache_params.temporal\n\n        has_initial_state = forward_batch.extend_prefix_lens > 0\n\n        splits = [layer.q_dim, layer.k_dim, layer.v_dim]\n        q, k, v = mixed_qkv.transpose(0, 1).split(splits, dim=0)\n        q_conv_weight, k_conv_weight, v_conv_weight = layer.conv_weights.split(\n            splits, dim=0\n        )\n        q_conv_state, k_conv_state, v_conv_state = conv_states.split(splits, dim=-2)\n        if layer.bias is not None:\n            q_bias, k_bias, v_bias = layer.bias.split(splits, dim=0)\n        else:\n            q_bias, k_bias, v_bias = None, None, None\n\n        q = causal_conv1d_fn(\n            q,\n            q_conv_weight,\n            q_bias,\n            activation=\"silu\",\n            conv_states=q_conv_state,\n            has_initial_state=has_initial_state,\n            cache_indices=cache_indices,\n            query_start_loc=query_start_loc,\n            seq_lens_cpu=forward_batch.extend_seq_lens_cpu,\n        ).transpose(0, 1)\n        k = causal_conv1d_fn(\n            k,\n            k_conv_weight,\n            k_bias,\n            activation=\"silu\",\n            conv_states=k_conv_state,\n            has_initial_state=has_initial_state,\n            cache_indices=cache_indices,\n            query_start_loc=query_start_loc,\n            seq_lens_cpu=forward_batch.extend_seq_lens_cpu,\n        ).transpose(0, 1)\n        v = causal_conv1d_fn(\n            v,\n            v_conv_weight,\n            v_bias,\n            activation=\"silu\",\n            conv_states=v_conv_state,\n            has_initial_state=has_initial_state,\n    
        cache_indices=cache_indices,\n            query_start_loc=query_start_loc,\n            seq_lens_cpu=forward_batch.extend_seq_lens_cpu,\n        ).transpose(0, 1)\n\n        q = q.unflatten(-1, (-1, layer.head_q_dim)).unsqueeze(0)  # n (h d) -> 1 n h d\n        k = k.unflatten(-1, (-1, layer.head_k_dim)).unsqueeze(0)  # n (h d) -> 1 n h d\n        v = v.unflatten(-1, (-1, layer.head_v_dim)).unsqueeze(0)  # n (h d) -> 1 n h d\n\n        core_attn_out = self.kernel_dispatcher.extend(\n            q=q,\n            k=k,\n            v=v,\n            g=a,\n            beta=b,\n            ssm_states=ssm_states,\n            cache_indices=cache_indices,\n            query_start_loc=query_start_loc,\n        )\n\n        return core_attn_out\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/kernels/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/kernels/gdn_cutedsl.py",
    "content": "import torch\n\nfrom sglang.jit_kernel.cutedsl_gdn import cutedsl_fused_sigmoid_gating_delta_rule_update\nfrom sglang.srt.layers.attention.linear.kernels.kernel_backend import (\n    LinearAttnKernelBase,\n)\n\n\nclass CuteDSLGDNKernel(LinearAttnKernelBase):\n    \"\"\"CuTe DSL kernel for GDN decode (CUDA only).\"\"\"\n\n    def decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        a: torch.Tensor,\n        b: torch.Tensor,\n        *,\n        A_log: torch.Tensor,\n        dt_bias: torch.Tensor,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> torch.Tensor:\n        return cutedsl_fused_sigmoid_gating_delta_rule_update(\n            A_log=A_log,\n            dt_bias=dt_bias,\n            q=q,\n            k=k,\n            v=v,\n            a=a,\n            b=b,\n            initial_state_source=ssm_states,\n            initial_state_indices=cache_indices,\n            cu_seqlens=query_start_loc,\n            use_qk_l2norm_in_kernel=True,\n            softplus_beta=1.0,\n            softplus_threshold=20.0,\n        )\n\n    def extend(self, *args, **kwargs):\n        raise NotImplementedError(\"CuteDSLGDNKernel only supports decode\")\n\n    def target_verify(self, *args, **kwargs):\n        raise NotImplementedError(\"CuteDSLGDNKernel only supports decode\")\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py",
    "content": "\"\"\"FlashInfer-based kernels for GDN (Gated Delta Network) linear attention.\n\nBoth SM90 and SM100+ use the same pool layout: [pool, HV, V, K] (K-last).\n\nSM90 (Hopper): full support — decode, prefill, MTP.  State dtype: fp32.\nSM100+ (Blackwell+): decode-only with bf16 state.  More support on the way.\n\nRequires flashinfer >= 0.6.4 (SM90) or >= 0.6.5 (SM100+).\n\"\"\"\n\nimport logging\nimport os\nfrom typing import Optional\n\nimport torch\n\nfrom sglang.srt.layers.attention.linear.kernels.kernel_backend import (\n    LinearAttnKernelBase,\n)\n\nlogger = logging.getLogger(__name__)\n\n# ---------------------------------------------------------------------------\n# Lazy import for FlashInfer GDN kernels\n# ---------------------------------------------------------------------------\n_flashinfer_gdn_available: Optional[bool] = None\n_flashinfer_chunk_gated_delta_rule = None\n_flashinfer_gated_delta_rule_mtp = None\n_flashinfer_gated_delta_rule_decode = None\n\n\ndef _get_flashinfer_gdn_kernels():\n    \"\"\"Lazy import for FlashInfer GDN prefill, decode and verify (MTP) kernels.\n\n    Returns (available, prefill_fn, mtp_fn, decode_fn).\n    \"\"\"\n    global _flashinfer_gdn_available, _flashinfer_chunk_gated_delta_rule, _flashinfer_gated_delta_rule_mtp, _flashinfer_gated_delta_rule_decode\n    if _flashinfer_gdn_available is None:\n        try:\n            os.environ.setdefault(\"FLASHINFER_DISABLE_VERSION_CHECK\", \"1\")\n\n            from flashinfer.gdn_decode import (\n                gated_delta_rule_decode_pretranspose,\n                gated_delta_rule_mtp,\n            )\n            from flashinfer.gdn_prefill import chunk_gated_delta_rule\n\n            _flashinfer_chunk_gated_delta_rule = chunk_gated_delta_rule\n            _flashinfer_gated_delta_rule_mtp = gated_delta_rule_mtp\n            _flashinfer_gated_delta_rule_decode = gated_delta_rule_decode_pretranspose\n            _flashinfer_gdn_available = (\n                
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9\n            )\n            if _flashinfer_gdn_available:\n                logger.info(\"FlashInfer GDN kernels loaded successfully\")\n        except (ImportError, RuntimeError) as e:\n            logger.warning(f\"FlashInfer GDN kernels not available: {e}\")\n            _flashinfer_gdn_available = False\n            _flashinfer_gated_delta_rule_decode = None\n    return (\n        _flashinfer_gdn_available,\n        _flashinfer_chunk_gated_delta_rule,\n        _flashinfer_gated_delta_rule_mtp,\n        _flashinfer_gated_delta_rule_decode,\n    )\n\n\n# ---------------------------------------------------------------------------\n# Kernel implementation\n# ---------------------------------------------------------------------------\n\n\nclass FlashInferGDNKernel(LinearAttnKernelBase):\n    \"\"\"FlashInfer kernel for GDN with K-last SSM state layout.\n\n    SM90 (Hopper): decode uses gather/scatter; prefill and MTP verify supported.\n    SM100+ (Blackwell+): decode uses pool API (initial_state_indices); prefill\n    and MTP verify are not supported (use Triton backend for those).\n\n    Requires flashinfer >= 0.6.4 (SM90) or >= 0.6.5 (SM100+).\n    \"\"\"\n\n    def __init__(self):\n        (\n            available,\n            self._prefill_fn,\n            self._mtp_fn,\n            self._decode_fn,\n        ) = _get_flashinfer_gdn_kernels()\n\n        if not available:\n            raise RuntimeError(\n                \"FlashInfer GDN kernels are not available. 
\"\n                \"Requires SM90+ and FlashInfer with GDN kernel support.\"\n            )\n        if self._decode_fn is None:\n            raise RuntimeError(\"FlashInfer GDN decode kernel is unavailable.\")\n\n        sm_major = torch.cuda.get_device_capability()[0]\n        self.use_state_pool = sm_major != 9\n\n        if sm_major == 9:\n            if self._prefill_fn is None:\n                raise RuntimeError(\"FlashInfer GDN prefill kernel is unavailable.\")\n            if self._mtp_fn is None:\n                raise RuntimeError(\"FlashInfer GDN MTP (verify) kernel is unavailable.\")\n\n        logger.info(\"Using FlashInfer GDN kernels\")\n\n    # ---- decode ----\n\n    def decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        a: torch.Tensor,\n        b: torch.Tensor,\n        *,\n        A_log: torch.Tensor,\n        dt_bias: torch.Tensor,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> torch.Tensor:\n        batch_size = cache_indices.shape[0]\n        num_heads = q.shape[2]\n        head_k_dim = q.shape[3]\n        num_v_heads = v.shape[2]\n        head_v_dim = v.shape[3]\n\n        query_fi = q.view(batch_size, 1, num_heads, head_k_dim)\n        key_fi = k.view(batch_size, 1, num_heads, head_k_dim)\n        value_fi = v.view(batch_size, 1, num_v_heads, head_v_dim)\n        a_fi = a.view(batch_size, 1, num_v_heads)\n        b_fi = b.view(batch_size, 1, num_v_heads)\n\n        if self.use_state_pool:\n            output_fi, _ = self._decode_fn(\n                q=query_fi,\n                k=key_fi,\n                v=value_fi,\n                state=None,\n                A_log=A_log.detach().float(),\n                a=a_fi,\n                dt_bias=dt_bias.detach(),\n                b=b_fi,\n                use_qk_l2norm=True,\n                initial_state=ssm_states,\n                
initial_state_indices=cache_indices,\n            )\n        else:\n            # TODO: Once FlashInfer PR#2521 is merged for SM90, gather/scatter\n            # will no longer be needed here.\n            state_batch = ssm_states[cache_indices]\n            output_fi, new_state = self._decode_fn(\n                q=query_fi,\n                k=key_fi,\n                v=value_fi,\n                state=state_batch,\n                A_log=A_log.detach(),\n                a=a_fi,\n                dt_bias=dt_bias.detach(),\n                b=b_fi,\n                scale=None,\n                output=None,\n                use_qk_l2norm=True,\n            )\n            ssm_states[cache_indices] = new_state\n\n        return output_fi.view(1, batch_size, num_v_heads, head_v_dim)\n\n    # ---- extend (prefill) ----\n\n    def extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        g: torch.Tensor,\n        beta: torch.Tensor,\n        *,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> tuple:\n        if self.use_state_pool:\n            raise NotImplementedError(\n                \"FlashInfer GDN prefill is not supported on SM100+. 
\"\n                \"Use --linear-attn-prefill-backend triton.\"\n            )\n\n        # SM90: chunked prefill using FlashInfer GDN prefill kernel.\n        from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd\n\n        total_seq_len = q.shape[1]\n        num_v_heads = v.shape[2]\n        head_v_dim = v.shape[3]\n\n        q_fi = l2norm_fwd(q[0].contiguous())\n        k_fi = l2norm_fwd(k[0].contiguous())\n        v_fi = v[0].contiguous()\n\n        # g (alpha) and beta: [1, seq, HV] -> [seq, HV], float32 for FlashInfer\n        alpha_fi = torch.exp(g[0].to(torch.float32))\n        beta_fi = beta[0].to(torch.float32)\n\n        cu_seqlens_fi = query_start_loc.to(torch.int64)\n\n        # Remap negative padding indices to sentinel slot\n        ssm_cache_indices = torch.where(\n            cache_indices >= 0,\n            cache_indices,\n            ssm_states.shape[0] - 1,\n        ).to(torch.int64)\n\n        # FlashInfer requires float32 initial state, K-last layout [B, HV, V, K]\n        initial_state_fi = ssm_states[ssm_cache_indices].to(torch.float32)\n\n        output_fi, output_state_fi = self._prefill_fn(\n            q=q_fi,\n            k=k_fi,\n            v=v_fi,\n            g=alpha_fi,\n            beta=beta_fi,\n            scale=None,\n            initial_state=initial_state_fi,\n            output_final_state=True,\n            cu_seqlens=cu_seqlens_fi,\n            use_qk_l2norm_in_kernel=False,\n        )\n\n        # Write back state to pool\n        ssm_states.index_copy_(\n            0,\n            ssm_cache_indices,\n            output_state_fi.to(ssm_states.dtype),\n        )\n\n        # Output: [seq, HV, V] -> [1, seq, HV, V]\n        core_attn_out = output_fi.view(1, total_seq_len, num_v_heads, head_v_dim)\n\n        # Return (output, last_recurrent_state, h) to match Triton kernel interface.\n        # h=None since FlashInfer doesn't provide intermediate states.\n        return core_attn_out, None, None\n\n    # ---- 
target_verify (MTP) ----\n\n    def target_verify(\n        self,\n        A_log: torch.Tensor,\n        dt_bias: torch.Tensor,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        a: torch.Tensor,\n        b: torch.Tensor,\n        *,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        intermediate_states_buffer: torch.Tensor,\n        intermediate_state_indices: torch.Tensor,\n        cache_steps: int,\n        retrieve_parent_token: torch.Tensor,\n        **kwargs,\n    ) -> torch.Tensor:\n        if self.use_state_pool:\n            raise NotImplementedError(\n                \"FlashInfer GDN MTP verify is not yet supported on SM100+.\"\n            )\n\n        # SM90: MTP verify using FlashInfer gated_delta_rule_mtp kernel.\n        if retrieve_parent_token is not None:\n            raise RuntimeError(\n                \"FlashInfer GDN verify kernel only supports topk=1 \"\n                \"(retrieve_parent_token must be None).\"\n            )\n\n        seq_len = q.shape[1]\n        batch_size = query_start_loc.shape[0] - 1\n        draft_token_num = seq_len // batch_size\n\n        num_heads = q.shape[2]\n        head_k_dim = q.shape[3]\n        num_v_heads = v.shape[2]\n        head_v_dim = v.shape[3]\n\n        query_mtp = q.view(batch_size, draft_token_num, num_heads, head_k_dim)\n        key_mtp = k.view(batch_size, draft_token_num, num_heads, head_k_dim)\n        value_mtp = v.view(batch_size, draft_token_num, num_v_heads, head_v_dim)\n\n        if a is None or b is None or A_log is None or dt_bias is None:\n            raise RuntimeError(\n                \"FlashInfer GDN MTP kernel requires a, b, A_log, dt_bias.\"\n            )\n\n        a_mtp = a.view(batch_size, draft_token_num, num_v_heads)\n        b_mtp = b.view(batch_size, draft_token_num, num_v_heads)\n\n        output_fi, _ = self._mtp_fn(\n            q=query_mtp,\n            
k=key_mtp,\n            v=value_mtp,\n            initial_state=ssm_states,\n            initial_state_indices=cache_indices,\n            A_log=A_log.detach(),\n            a=a_mtp,\n            dt_bias=dt_bias.detach(),\n            b=b_mtp,\n            scale=None,\n            output=None,\n            intermediate_states_buffer=intermediate_states_buffer,\n            disable_state_update=True,\n            use_qk_l2norm=True,\n        )\n\n        return output_fi.view(1, seq_len, num_v_heads, head_v_dim)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/kernels/gdn_triton.py",
    "content": "import torch\n\nfrom sglang.srt.layers.attention.linear.kernels.kernel_backend import (\n    LinearAttnKernelBase,\n)\nfrom sglang.srt.utils import is_cpu, is_npu\n\nif not is_cpu():\n    from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule\n    from sglang.srt.layers.attention.fla.fused_recurrent import (\n        fused_recurrent_gated_delta_rule_packed_decode,\n    )\n    from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (\n        fused_sigmoid_gating_delta_rule_update,\n    )\n\nif is_npu():\n    from sgl_kernel_npu.fla.chunk import chunk_gated_delta_rule_npu\n    from sgl_kernel_npu.fla.fused_sigmoid_gating_recurrent import (\n        fused_sigmoid_gating_delta_rule_update_npu,\n    )\n\n    chunk_gated_delta_rule = chunk_gated_delta_rule_npu\n    fused_sigmoid_gating_delta_rule_update = fused_sigmoid_gating_delta_rule_update_npu\nelif is_cpu():\n    from sgl_kernel.mamba import chunk_gated_delta_rule_cpu\n\n    chunk_gated_delta_rule = chunk_gated_delta_rule_cpu\n    fused_sigmoid_gating_delta_rule_update = (\n        torch.ops.sgl_kernel.fused_sigmoid_gating_delta_rule_update_cpu\n    )\n\n\nclass TritonGDNKernel(LinearAttnKernelBase):\n    \"\"\"Triton-based kernel for GDN (Gated Delta Network) linear attention.\"\"\"\n\n    supports_packed_decode: bool = not is_cpu() and not is_npu()\n\n    def packed_decode(\n        self,\n        mixed_qkv: torch.Tensor,\n        a: torch.Tensor,\n        b: torch.Tensor,\n        *,\n        A_log: torch.Tensor,\n        dt_bias: torch.Tensor,\n        scale: float,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        num_v_heads: int,\n        head_v_dim: int,\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"Packed decode fast path: fuse QKV extraction + gating + recurrent\n        update into a single Triton kernel, eliminating intermediate tensors\n        and extra kernel launches.\n\n        Args:\n            
mixed_qkv: [B, qkv_dim] packed projection output after conv1d.\n            a, b: [B, HV] gating inputs.\n            A_log: [HV] log-space decay parameter.\n            dt_bias: [HV] time-step bias.\n            scale: attention scale factor (typically head_k_dim ** -0.5).\n            ssm_states: [num_slots, HV, V, K] full state pool.\n            cache_indices: [B] per-request state slot indices.\n            num_v_heads: number of value heads (after TP sharding).\n            head_v_dim: dimension per value head.\n\n        Returns:\n            output tensor of shape [1, B, HV, V] matching the existing\n            decode kernel output layout.\n        \"\"\"\n        B = mixed_qkv.shape[0]\n        # Packed kernel expects output shape [B, 1, HV, V]\n        out = mixed_qkv.new_empty(B, 1, num_v_heads, head_v_dim)\n\n        fused_recurrent_gated_delta_rule_packed_decode(\n            mixed_qkv=mixed_qkv,\n            a=a,\n            b=b,\n            A_log=A_log,\n            dt_bias=dt_bias,\n            scale=scale,\n            initial_state=ssm_states,\n            out=out,\n            ssm_state_indices=cache_indices,\n            use_qk_l2norm_in_kernel=True,\n        )\n\n        # Convert [B, 1, HV, V] → [1, B, HV, V] to match existing output\n        # layout. 
transpose() returns a view — zero cost.\n        return out.transpose(0, 1)\n\n    def decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        a: torch.Tensor,\n        b: torch.Tensor,\n        *,\n        A_log: torch.Tensor,\n        dt_bias: torch.Tensor,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> torch.Tensor:\n        return fused_sigmoid_gating_delta_rule_update(\n            A_log=A_log,\n            dt_bias=dt_bias,\n            q=q,\n            k=k,\n            v=v,\n            a=a,\n            b=b,\n            initial_state_source=ssm_states,\n            initial_state_indices=cache_indices,\n            cu_seqlens=query_start_loc,\n            use_qk_l2norm_in_kernel=True,\n            softplus_beta=1.0,\n            softplus_threshold=20.0,\n        )\n\n    def extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        g: torch.Tensor,\n        beta: torch.Tensor,\n        *,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> tuple:\n        recurrent_state = ssm_states\n        recurrent_state_indices_args = {\"initial_state_indices\": cache_indices}\n        if is_npu() or is_cpu():\n            recurrent_state = ssm_states[cache_indices]\n            recurrent_state_indices_args = {}\n        return chunk_gated_delta_rule(\n            q=q,\n            k=k,\n            v=v,\n            g=g,\n            beta=beta,\n            initial_state=recurrent_state,\n            cu_seqlens=query_start_loc,\n            head_first=False,\n            use_qk_l2norm_in_kernel=True,\n            **recurrent_state_indices_args,\n        )\n\n    def target_verify(\n        self,\n        A_log: torch.Tensor,\n        dt_bias: torch.Tensor,\n        q: 
torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        a: torch.Tensor,\n        b: torch.Tensor,\n        *,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        intermediate_states_buffer: torch.Tensor,\n        intermediate_state_indices: torch.Tensor,\n        cache_steps: int,\n        retrieve_parent_token: torch.Tensor,\n        **kwargs,\n    ) -> torch.Tensor:\n        return fused_sigmoid_gating_delta_rule_update(\n            A_log=A_log,\n            dt_bias=dt_bias,\n            q=q,\n            k=k,\n            v=v,\n            a=a,\n            b=b,\n            initial_state_source=ssm_states,\n            initial_state_indices=cache_indices,\n            cu_seqlens=query_start_loc,\n            use_qk_l2norm_in_kernel=True,\n            softplus_beta=1.0,\n            softplus_threshold=20.0,\n            is_kda=False,\n            # target_verify specific parameters\n            disable_state_update=True,\n            intermediate_states_buffer=intermediate_states_buffer,\n            intermediate_state_indices=intermediate_state_indices,\n            cache_steps=cache_steps,\n            retrieve_parent_token=retrieve_parent_token,\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/kernels/kda_triton.py",
    "content": "import torch\n\nfrom sglang.srt.layers.attention.linear.kernels.kernel_backend import (\n    LinearAttnKernelBase,\n)\nfrom sglang.srt.utils import is_cpu\n\nif not is_cpu():\n    from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (\n        fused_sigmoid_gating_delta_rule_update,\n    )\n    from sglang.srt.layers.attention.fla.kda import chunk_kda\n\n\nclass TritonKDAKernel(LinearAttnKernelBase):\n    \"\"\"Triton-based kernel for KDA (Kimi Delta Attention) linear attention.\"\"\"\n\n    def decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        a: torch.Tensor,\n        b: torch.Tensor,\n        *,\n        A_log: torch.Tensor,\n        dt_bias: torch.Tensor,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> torch.Tensor:\n        return fused_sigmoid_gating_delta_rule_update(\n            A_log=A_log,\n            dt_bias=dt_bias,\n            q=q,\n            k=k,\n            v=v,\n            a=a,\n            b=b,\n            initial_state_source=ssm_states,\n            initial_state_indices=cache_indices,\n            cu_seqlens=query_start_loc,\n            use_qk_l2norm_in_kernel=True,\n            softplus_beta=1.0,\n            softplus_threshold=20.0,\n            is_kda=True,\n        )\n\n    def extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        g: torch.Tensor,\n        beta: torch.Tensor,\n        *,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> torch.Tensor:\n        return chunk_kda(\n            q=q,\n            k=k,\n            v=v,\n            g=g,\n            beta=beta,\n            initial_state=ssm_states,\n            initial_state_indices=cache_indices,\n            
use_qk_l2norm_in_kernel=True,\n            cu_seqlens=query_start_loc,\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/kernels/kernel_backend.py",
    "content": "from abc import ABC, abstractmethod\n\nimport torch\n\n\nclass LinearAttnKernelBase(ABC):\n    \"\"\"Abstract base class for linear attention kernel implementations.\n\n    Each concrete implementation wraps a specific kernel (Triton, CuTe DSL, etc.)\n    and provides decode/extend/target_verify methods with a unified interface.\n    \"\"\"\n\n    @abstractmethod\n    def decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        a: torch.Tensor,\n        b: torch.Tensor,\n        *,\n        A_log: torch.Tensor,\n        dt_bias: torch.Tensor,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> torch.Tensor: ...\n\n    @abstractmethod\n    def extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        g: torch.Tensor,\n        beta: torch.Tensor,\n        *,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> tuple: ...\n\n    def target_verify(\n        self,\n        A_log: torch.Tensor,\n        dt_bias: torch.Tensor,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        a: torch.Tensor,\n        b: torch.Tensor,\n        *,\n        ssm_states: torch.Tensor,\n        cache_indices: torch.Tensor,\n        query_start_loc: torch.Tensor,\n        **kwargs,\n    ) -> torch.Tensor:\n        raise NotImplementedError(\n            f\"{self.__class__.__name__} does not support target_verify\"\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/lightning_attn.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/linear_attn.py\n\n# SPDX-License-Identifier: Apache-2.0\n# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\nimport torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\n\n\n@triton.jit\ndef _fwd_diag_kernel(\n    Q,\n    K,\n    V,\n    Out,\n    S,\n    b: tl.constexpr,\n    h: tl.constexpr,\n    n,\n    d: tl.constexpr,\n    e: tl.constexpr,\n    BLOCK: tl.constexpr,\n    NUM_BLOCK,\n    CBLOCK: tl.constexpr,\n):\n    # This kernel computes the diagonal blocks of the attention matrix\n    # Each diagonal block represents attention\n    # where queries attend to keys in the same block\n    off = tl.program_id(0)\n    off_bh = off // NUM_BLOCK  # batch-head index\n    off_block = off % NUM_BLOCK  # block index within the sequence\n    off_cblock = tl.program_id(1)  # sub-block index within a block\n\n    off_h = off_bh % h  # head index\n\n    # Calculate base offsets for the current batch and head\n    qk_offset = off_bh * n * d\n    v_offset = off_bh * n * e\n    o_offset = off_bh * n * e\n\n    # Calculate offsets for the current block\n    block_offset = off_block * BLOCK\n    qk_block_offset = block_offset * d\n    v_block_offset = block_offset * e\n    o_block_offset = block_offset * e\n\n    # Calculate offsets for the current sub-block\n    cblock_offset = off_cblock * CBLOCK\n    q_cblock_offset = cblock_offset * d\n    o_cblock_offset = cblock_offset * e\n\n    # Calculate pointers to the query, key, value, and output tensors\n    Q_block_ptr = (\n        Q\n        + qk_offset\n        + qk_block_offset\n        + q_cblock_offset\n        + tl.arange(0, CBLOCK)[:, None] * d\n        + tl.arange(0, d)[None, :]\n    )\n    K_trans_block_ptr = (\n        K\n        + qk_offset\n        + qk_block_offset\n        + tl.arange(0, CBLOCK)[None, :] * d\n        + tl.arange(0, d)[:, None]\n    )\n    
V_block_ptr = (\n        V\n        + v_offset\n        + v_block_offset\n        + tl.arange(0, CBLOCK)[:, None] * e\n        + tl.arange(0, e)[None, :]\n    )\n    O_block_ptr = (\n        Out\n        + o_offset\n        + o_block_offset\n        + o_cblock_offset\n        + tl.arange(0, CBLOCK)[:, None] * e\n        + tl.arange(0, e)[None, :]\n    )\n\n    # Load the decay rate for the current head\n    S_block_ptr = S + off_h\n    s = tl.load(S_block_ptr)\n\n    i = off_cblock\n    q_index = tl.arange(0, CBLOCK) + i * CBLOCK\n\n    # Load query values\n    q = tl.load(Q_block_ptr, mask=block_offset + q_index[:, None] < n, other=0.0).to(\n        tl.float32\n    )\n\n    # Initialize output accumulator\n    qkv = tl.zeros([CBLOCK, e], dtype=tl.float32)\n\n    # Process all sub-blocks up to and\n    # including the current one (causal attention)\n    for j in range(i + 1):\n        kv_index = tl.arange(0, CBLOCK) + j * CBLOCK\n        diff = q_index[:, None] - kv_index[None, :]\n        s_index = s * diff\n        # Apply causal mask: only attend to positions at or before the current one\n        s_index = tl.where(diff >= 0, -s_index, float(\"-inf\"))\n        decay = tl.exp(s_index)\n\n        # Load key and value\n        k_trans = tl.load(\n            K_trans_block_ptr,\n            mask=block_offset + kv_index[None, :] < n,\n            other=0.0,\n        ).to(tl.float32)\n        v = tl.load(\n            V_block_ptr,\n            mask=block_offset + kv_index[:, None] < n,\n            other=0.0,\n        ).to(tl.float32)\n\n        # Compute attention scores and apply decay\n        qk = tl.dot(q, k_trans) * decay\n\n        # Compute weighted values and accumulate\n        qkv += tl.dot(qk, v)\n\n        # Move to the next sub-block\n        K_trans_block_ptr += CBLOCK * d\n        V_block_ptr += CBLOCK * e\n\n    # Store the result\n    tl.store(\n        O_block_ptr,\n        qkv.to(O_block_ptr.dtype.element_ty),\n        mask=block_offset + q_index[:, 
None] < n,\n    )\n\n\n@triton.jit\ndef _fwd_kv_parallel(\n    K,\n    V,\n    K_decay,\n    KV,\n    b: tl.constexpr,\n    h: tl.constexpr,\n    n,\n    d: tl.constexpr,\n    e: tl.constexpr,\n    BLOCK: tl.constexpr,\n    NUM_BLOCK,\n    D_FBLOCK: tl.constexpr,\n    E_FBLOCK: tl.constexpr,\n    NUM_FBLOCK: tl.constexpr,\n    CBLOCK: tl.constexpr,\n    NUM_CBLOCK: tl.constexpr,\n):\n    # This kernel computes the key-value outer\n    # products for each block in parallel\n    off_bh = tl.program_id(0)  # batch-head index\n    off_block = tl.program_id(1)  # block index\n\n    off_h = off_bh % h  # head index\n\n    block_offset = off_block * BLOCK\n\n    # Calculate offsets for the current block\n    k_block_offset = block_offset * d\n    v_block_offset = block_offset * e\n    kv_block_offset = off_block * d * e\n\n    # Calculate base offsets for the current batch and head\n    k_offset = off_bh * n * d\n    v_offset = off_bh * n * e\n    kv_offset = off_bh * NUM_BLOCK * d * e\n\n    # Calculate pointers to the key, value, and key-value tensors\n    K_trans_block_ptr = (\n        K\n        + k_offset\n        + k_block_offset\n        + tl.arange(0, CBLOCK)[None, :] * d\n        + tl.arange(0, D_FBLOCK)[:, None]\n    )\n    V_block_ptr = (\n        V\n        + v_offset\n        + v_block_offset\n        + tl.arange(0, CBLOCK)[:, None] * e\n        + tl.arange(0, E_FBLOCK)[None, :]\n    )\n    KV_block_ptr = (\n        KV\n        + kv_offset\n        + kv_block_offset\n        + tl.arange(0, D_FBLOCK)[:, None] * e\n        + tl.arange(0, E_FBLOCK)[None, :]\n    )\n\n    # Load the decay factors for the current head and block\n    k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :]\n\n    kv_index = tl.arange(0, CBLOCK)\n\n    # Initialize the key-value outer product accumulator\n    kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)\n\n    # Handle the last block which might be smaller than BLOCK\n    if off_block == NUM_BLOCK - 1:\n        
split_n = n - (NUM_BLOCK - 1) * BLOCK\n    else:\n        split_n = BLOCK\n    left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n\n    num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK)\n    k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK\n\n    # Process all sub-blocks in the current block\n    for j in range(num_blocks):\n        left_bound = (1 - j) * left_shift\n        # Load key and value, handling boundary conditions\n        k_trans = tl.load(\n            K_trans_block_ptr - left_shift * d,\n            mask=kv_index[None, :] >= left_bound,\n            other=0.0,\n        )\n        v = tl.load(\n            V_block_ptr - left_shift * e,\n            mask=kv_index[:, None] >= left_bound,\n            other=0.0,\n        )\n\n        # Load decay factor and compute weighted key-value outer product\n        k_decay = tl.load(k_decay_ptr)\n        kv += tl.dot(k_trans * k_decay, v)\n\n        # Move to the next sub-block\n        K_trans_block_ptr += CBLOCK * d\n        V_block_ptr += CBLOCK * e\n        k_decay_ptr += CBLOCK\n\n    # Store the result\n    tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))\n\n\n@triton.jit\ndef _fwd_kv_reduce(\n    S,\n    KV,\n    KV_HISTORY,\n    b: tl.constexpr,\n    h: tl.constexpr,\n    n,\n    d: tl.constexpr,\n    e: tl.constexpr,\n    BLOCK: tl.constexpr,\n    NUM_BLOCK,\n    D_FBLOCK: tl.constexpr,\n    E_FBLOCK: tl.constexpr,\n):\n    # This kernel reduces the key-value outer products\n    # across blocks and updates the KV history\n    off_bh = tl.program_id(0)  # batch-head index\n    off_h = off_bh % h  # head index\n\n    kv_offset = off_bh * NUM_BLOCK * d * e\n\n    # Calculate pointer to the key-value tensor\n    KV_block_ptr = (\n        KV\n        + kv_offset\n        + tl.arange(0, D_FBLOCK)[:, None] * e\n        + tl.arange(0, E_FBLOCK)[None, :]\n    )\n\n    # Load the decay rate for the current head\n    s_ptrs = S + off_h\n    s = tl.load(s_ptrs)\n\n    # Calculate pointer to the 
key-value history tensor\n    kv_history_offset = off_bh * d * e\n    KV_HISTORY_block_ptr = (\n        KV_HISTORY\n        + kv_history_offset\n        + tl.arange(0, D_FBLOCK)[:, None] * e\n        + tl.arange(0, E_FBLOCK)[None, :]\n    )\n\n    # Load the previous key-value history\n    kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32)\n\n    # Process all blocks in forward order as an exclusive prefix scan: each block slot\n    # is overwritten with the decayed history accumulated from all preceding blocks\n    for i in range(NUM_BLOCK):\n        block_size = min(n - i * BLOCK, BLOCK)\n        # Compute decay factor for the current block\n        block_decay = tl.exp(-s.to(tl.float32) * block_size)\n\n        # Load the current key-value outer product\n        kv_cur = tl.load(KV_block_ptr).to(tl.float32)\n        # Store the previous key-value history to the current block\n        tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty))\n\n        # Update the key-value history with the current block\n        kv_pre = block_decay * kv_pre + kv_cur\n        KV_block_ptr += d * e\n\n    # Store the updated key-value history\n    tl.store(KV_HISTORY_block_ptr, kv_pre)\n\n\n@triton.jit\ndef _fwd_none_diag_kernel(\n    Q,\n    Out,\n    S,\n    KV,\n    b: tl.constexpr,\n    h: tl.constexpr,\n    n,\n    d: tl.constexpr,\n    e: tl.constexpr,\n    BLOCK: tl.constexpr,\n    NUM_BLOCK,\n    E_FBLOCK: tl.constexpr,\n    CBLOCK: tl.constexpr,\n    NUM_CBLOCK: tl.constexpr,\n):\n    # This kernel computes the non-diagonal blocks of the attention matrix\n    # Each non-diagonal block represents attention\n    # where queries attend to keys in different blocks\n    off_bh = tl.program_id(0)  # batch-head index\n    off_h = off_bh % h  # head index\n\n    off_nc = tl.program_id(1)\n    off_n = off_nc // NUM_CBLOCK  # block index\n    off_c = off_nc % NUM_CBLOCK  # sub-block index\n    off_e = tl.program_id(2)  # output feature block index\n\n    n_offset = off_n * BLOCK\n    c_offset = off_c * CBLOCK\n    e_offset = off_e * E_FBLOCK\n    block_offset = n_offset + 
c_offset\n\n    # Calculate offsets for the current batch, head, and block\n    q_offset = off_bh * n * d + (n_offset + c_offset) * d\n    o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset\n    kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset\n\n    # Calculate pointers to the query, output, and key-value tensors\n    Q_block_ptr = (\n        Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d + tl.arange(0, d)[None, :]\n    )\n    O_block_ptr = (\n        Out\n        + o_offset\n        + tl.arange(0, CBLOCK)[:, None] * e\n        + tl.arange(0, E_FBLOCK)[None, :]\n    )\n    KV_block_ptr = (\n        KV + kv_offset + tl.arange(0, d)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]\n    )\n\n    # Load the decay rate for the current head\n    S_block_ptr = S + off_h\n    s = tl.load(S_block_ptr)\n\n    c_array = tl.arange(0, CBLOCK)\n\n    # Load the key-value outer product for the current block\n    kv = tl.load(KV_block_ptr).to(tl.float32)\n    q_index = block_offset + tl.arange(0, CBLOCK)\n\n    # Load query values\n    q = tl.load(Q_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32)\n\n    # Compute decay factors for the current sub-block\n    q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None]))\n\n    # Compute non-diagonal attention output\n    qkv_none_diag = tl.dot(q, kv) * q_decay\n\n    # Load diagonal attention output (computed by _fwd_diag_kernel)\n    qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32)\n\n    # Combine diagonal and non-diagonal attention outputs\n    qkv = qkv_diag + qkv_none_diag\n\n    # Store the result\n    tl.store(\n        O_block_ptr, qkv.to(O_block_ptr.dtype.element_ty), mask=q_index[:, None] < n\n    )\n\n\nclass _attention(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, q, k, v, s, kv_history):\n        # Forward pass of the lightning attention algorithm\n        q = q.contiguous()\n        k = 
k.contiguous()\n        v = v.contiguous()\n        s = s.contiguous()\n\n        # Check CUDA compute capability (SM80+ is required)\n        capability = torch.cuda.get_device_capability()\n        if capability[0] < 8:\n            raise RuntimeError(\n                \"Lightning attention is currently only supported \"\n                \"for compute capability >= 80\"\n            )\n\n        # Get input dimensions\n        b, h, n, d = q.shape\n        e = v.shape[-1]\n\n        # Initialize output tensor\n        o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)\n\n        # Set block sizes\n        BLOCK = 256\n        NUM_BLOCK = triton.cdiv(n, BLOCK)\n\n        CBLOCK = 32\n        NUM_CBLOCK = BLOCK // CBLOCK\n        assert BLOCK % CBLOCK == 0, \"BLOCK must be a multiple of CBLOCK\"\n\n        # Compute decay factors for keys\n        array = torch.arange(0, BLOCK, device=q.device) + 1\n        k_decay = torch.exp(-s * (BLOCK - array.reshape(1, -1)))\n\n        # Step 1: Compute diagonal blocks of attention\n        grid = (b * h * NUM_BLOCK, NUM_CBLOCK)\n        _fwd_diag_kernel[grid](\n            q,\n            k,\n            v,\n            o,\n            s,\n            b,\n            h,\n            n,\n            d,\n            e,\n            BLOCK=BLOCK,\n            NUM_BLOCK=NUM_BLOCK,\n            CBLOCK=CBLOCK,\n        )\n\n        # Set feature block sizes\n        NUM_FBLOCK = 1\n        D_FBLOCK = d // NUM_FBLOCK\n        assert d % NUM_FBLOCK == 0\n        E_FBLOCK = e // NUM_FBLOCK\n        assert e % NUM_FBLOCK == 0\n\n        CBLOCK = 64\n        NUM_CBLOCK = BLOCK // CBLOCK\n        assert BLOCK % CBLOCK == 0, \"BLOCK must be a multiple of CBLOCK\"\n\n        # Step 2: Compute key-value outer products for each block in parallel\n        kv = torch.empty((b, h, NUM_BLOCK, d, e), dtype=torch.float32, device=q.device)\n        grid = (b * h, NUM_BLOCK)\n        _fwd_kv_parallel[grid](\n            k,\n            v,\n            k_decay,\n    
        kv,\n            b,\n            h,\n            n,\n            d,\n            e,\n            BLOCK=BLOCK,\n            NUM_BLOCK=NUM_BLOCK,\n            D_FBLOCK=D_FBLOCK,\n            E_FBLOCK=E_FBLOCK,\n            NUM_FBLOCK=NUM_FBLOCK,\n            CBLOCK=CBLOCK,\n            NUM_CBLOCK=NUM_CBLOCK,\n        )\n\n        # Step 3: Reduce key-value outer products\n        # across blocks and update KV history\n        grid = (b * h, NUM_FBLOCK)\n        _fwd_kv_reduce[grid](\n            s,\n            kv,\n            kv_history,\n            b,\n            h,\n            n,\n            d,\n            e,\n            BLOCK=BLOCK,\n            NUM_BLOCK=NUM_BLOCK,\n            D_FBLOCK=D_FBLOCK,\n            E_FBLOCK=E_FBLOCK,\n        )\n\n        # Step 4: Compute non-diagonal blocks of attention\n        grid = (b * h, NUM_BLOCK * NUM_CBLOCK)\n        _fwd_none_diag_kernel[grid](\n            q,\n            o,\n            s,\n            kv,\n            b,\n            h,\n            n,\n            d,\n            e,\n            BLOCK=BLOCK,\n            NUM_BLOCK=NUM_BLOCK,\n            E_FBLOCK=E_FBLOCK,\n            CBLOCK=CBLOCK,\n            NUM_CBLOCK=NUM_CBLOCK,\n        )\n\n        # Save tensors for backward pass\n        ctx.save_for_backward(q, k, v, s, kv)\n        ctx.BLOCK = BLOCK\n\n        return o, torch.cat([kv, kv_history.unsqueeze(2)], dim=2)\n\n\n# Apply the lightning attention function\nlightning_attention_ = _attention.apply\n\n\ndef lightning_attention(q, k, v, ed, block_size=256, kv_history=None):\n    \"\"\"\n    Apply lightning attention algorithm\n    to compute attention efficiently.\n\n    Args:\n        q: Query tensor of shape [batch, heads, seq_len, dim]\n        k: Key tensor of shape [batch, heads, seq_len, dim]\n        v: Value tensor of shape [batch, heads, seq_len, dim_v]\n        ed: Decay rate tensor of shape [heads]\n        block_size: Size of blocks for block-sparse attention\n        
kv_history: Optional key-value history from previous computations\n\n    Returns:\n        output: Attention output\n        kv: Updated key-value history\n    \"\"\"\n    d = q.shape[-1]\n    e = v.shape[-1]\n\n    if ed.dim() == 1:\n        ed = ed.view(1, -1, 1, 1)\n\n    # Split the computation into chunks for better parallelism\n    m = 128 if d >= 128 else 64\n    assert d % m == 0, f\"Dimension d ({d}) must be divisible by m ({m})\"\n    arr = [m * i for i in range(d // m + 1)]\n    if arr[-1] != d:\n        arr.append(d)\n    n = len(arr)\n    output = 0\n\n    # Initialize or clone key-value history\n    if kv_history is None:\n        kv_history = torch.zeros(\n            (q.shape[0], q.shape[1], d, e), dtype=torch.float32, device=q.device\n        )\n    else:\n        kv_history = kv_history.clone().contiguous()\n\n    # Process each chunk and accumulate results\n    for i in range(n - 1):\n        s = arr[i]\n        e = arr[i + 1]\n        q1 = q[..., s:e]\n        k1 = k[..., s:e]\n        o, kv = lightning_attention_(q1, k1, v, ed, kv_history)\n        output = output + o\n    return output, kv\n\n\n@triton.jit\ndef _linear_attn_decode_kernel(\n    q_ptr,\n    k_ptr,\n    v_ptr,\n    kv_cache_ptr,\n    slope_rate,\n    slot_idx,\n    output_ptr,\n    D: tl.constexpr,\n    qkv_b_stride,\n    qkv_h_stride,\n    cache_b_stride,\n    cache_h_stride,\n    cache_d0_stride,\n    cache_d1_stride,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    Kernel for linear attention decoding with KV cache.\n\n    This kernel computes attention for a single token using the KV cache.\n    \"\"\"\n    pid_b = tl.program_id(0)  # batch index\n    pid_h = tl.program_id(1)  # head index\n    pid_d = tl.program_id(2)  # dimension block index\n\n    # Load slot index for the current batch\n    slot_id = tl.load(slot_idx + pid_b)\n\n    # Skip if slot_id is -1 (padding)\n    if slot_id == -1:\n        return\n\n    batch_id = pid_b\n    head_id = pid_h\n\n    # Load decay 
rate for the current head\n    ratio = tl.load(slope_rate + pid_h)\n\n    # Calculate offsets for dimensions\n    qk_d_offsets = tl.arange(0, D)\n    v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE\n    cache_d_offsets = (\n        qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[None, :] * cache_d1_stride\n    )\n\n    # Calculate offsets for the current batch and head\n    q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride\n    k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride\n    v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride\n\n    cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride\n\n    # Create masks for loading tensors\n    qk_mask = qk_d_offsets < D\n    v_mask = v_d_offsets < D\n\n    # Load query, key, and value tensors\n    q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0)\n    k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0)\n    v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0)\n\n    # Compute key-value outer product\n    kv_outer = k[:, None] * v[None, :]\n    kv_mask = qk_mask[:, None] & v_mask[None, :]\n\n    # Apply decay to previous KV cache\n    ratio = tl.exp(-ratio)\n    kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets\n    kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0)\n    kv_outer = kv_outer + ratio * kv_cache_old\n\n    # Compute attention output\n    output = q[:, None].to(tl.float32) * kv_outer\n    output = tl.sum(output, axis=0)\n\n    # Update KV cache and store output\n    tl.store(kv_ptr, kv_outer, mask=kv_mask)\n    tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask)\n\n\ndef linear_decode_forward_triton(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    kv_caches: torch.Tensor,\n    slope_rate: torch.Tensor,\n    slot_idx: torch.Tensor,\n    BLOCK_SIZE: int = 32,\n) -> torch.Tensor:\n    \"\"\"\n    Perform linear attention decoding using Triton 
kernels.\n\n    Args:\n        q: Query tensor of shape [B, H, 1, D]\n        k: Key tensor of shape [B, H, 1, D]\n        v: Value tensor of shape [B, H, 1, D]\n        kv_caches: Key-value cache tensor\n        slope_rate: Decay rate tensor\n        slot_idx: Slot indices for batches\n        BLOCK_SIZE: Size of blocks for processing\n\n    Returns:\n        output: Attention output tensor\n    \"\"\"\n    B, H, _, D = q.shape\n    assert k.shape == (B, H, 1, D)\n    assert v.shape == (B, H, 1, D)\n\n    # Initialize output tensor\n    output = torch.empty_like(q)\n\n    # Set grid dimensions for the kernel\n    grid = (B, H, D // BLOCK_SIZE)\n\n    # Calculate strides for tensors\n    qkv_b_stride = q.stride(0)\n    qkv_h_stride = q.stride(1)\n\n    cache_b_stride = kv_caches.stride(0)\n    cache_h_stride = kv_caches.stride(1)\n    cache_d0_stride = kv_caches.stride(2)\n    cache_d1_stride = kv_caches.stride(3)\n\n    # Launch the kernel\n    _linear_attn_decode_kernel[grid](\n        q,\n        k,\n        v,\n        kv_caches,\n        slope_rate,\n        slot_idx,\n        output,\n        D,\n        qkv_b_stride,\n        qkv_h_stride,\n        cache_b_stride,\n        cache_h_stride,\n        cache_d0_stride,\n        cache_d1_stride,\n        BLOCK_SIZE=BLOCK_SIZE,\n    )\n\n    # Reshape output and return\n    output = rearrange(output, \"b h n d -> b n (h d)\")\n    return output.squeeze(1).contiguous()\n\n\nclass BailingLinearKernel:\n    \"\"\"\n    Linear attention kernel implementation for Bailing models.\n\n    This class is adapted from MiniMaxText01LinearKernel in vllm:\n    https://github.com/vllm-project/vllm/blob/a9138e85b14047e06300685b48e3485b995425fb/vllm/model_executor/models/minimax_text_01.py#L289\n\n    The implementation maintains the same functionality while being renamed to\n    match our Bailing model naming convention.\n    \"\"\"\n\n    @staticmethod\n    def jit_linear_forward_prefix(\n        q: torch.Tensor,\n        k: 
torch.Tensor,\n        v: torch.Tensor,\n        kv_caches: torch.Tensor,\n        slope_rate: torch.Tensor,\n        block_size: int,\n        layer_idx: int = None,\n        **kwargs,\n    ) -> torch.Tensor:\n\n        slope_rate = slope_rate.to(torch.float32)\n        should_pad_dim = q.dim() == 3\n        if should_pad_dim:\n            q = q.unsqueeze(0)\n            k = k.unsqueeze(0)\n            v = v.unsqueeze(0)\n        b, h, n, d = q.shape\n        e = d\n        kv_history = kv_caches.reshape(1, h, d, e).contiguous()\n        output, kv_history = lightning_attention(\n            q, k, v, slope_rate, block_size=block_size, kv_history=kv_history\n        )\n        kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))\n        assert output.shape[0] == 1, \"batch size must be 1\"\n        return output.squeeze(0).transpose(0, 1).reshape([n, h * d]).contiguous()\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/lightning_backend.py",
    "content": "import logging\nimport math\nfrom typing import Optional, Union\n\nimport torch\n\nfrom sglang.srt.layers.attention.hybrid_linear_attn_backend import MambaAttnBackendBase\nfrom sglang.srt.layers.attention.linear.lightning_attn import (\n    BailingLinearKernel,\n    linear_decode_forward_triton,\n)\nfrom sglang.srt.layers.attention.linear.linear_metadata import BailingLinearMetadata\nfrom sglang.srt.layers.attention.linear.seg_la import SegLaMeta, seg_la_fwd\nfrom sglang.srt.layers.radix_attention import RadixAttention\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.model_executor.model_runner import ModelRunner\nfrom sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput\n\nlogger = logging.getLogger(__name__)\n\n\nclass LightningAttentionBackend(MambaAttnBackendBase):\n    \"\"\"\n    Note about the init:\n    - If no spec decoding\n        - FlashAttentionBackend will be init once when the server starts.\n    - If spec decoding\n        - FlashAttentionBackend will be init once for the target worker\n        - FlashAttentionMultiStepBackend will be once for the draft worker\n            - It will spawn num_steps FlashAttentionBackend for the draft worker\n\n    Note about CUDA Graph:\n    - We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify.\n    - We don't support CUDA Graph for Extend and Draft Extend.\n    - When server init, init_cuda_graph_state will be called first and then init_cuda_graph_capture will be called.\n    - For each forward batch, init_replay_cuda_graph will be called first and then replay the graph.\n    \"\"\"\n\n    def __init__(self, model_runner: ModelRunner):\n        super().__init__(model_runner)\n\n        assert not (\n            model_runner.sliding_window_size is not None\n            and model_runner.model_config.is_encoder_decoder\n        ), \"Sliding window and cross attention are not supported 
together\"\n\n        # extra metadata for handling speculative decoding topk > 1, extended draft decode and verify\n        self.max_context_len = model_runner.model_config.context_len\n        self.device = model_runner.device\n        self.decode_cuda_graph_metadata = {}\n        self.kv_cache_dtype = model_runner.kv_cache_dtype\n        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype\n        self.BLOCK = (\n            model_runner.model_config.block\n            if hasattr(model_runner.model_config, \"block\")\n            else 256\n        )\n        total_num_heads = model_runner.model_config.hf_config.num_attention_heads\n        num_hidden_layers = model_runner.model_config.hf_config.num_hidden_layers\n        self.tp_slope = LightningAttentionBackend._build_slope_tensor(\n            total_num_heads, num_hidden_layers, self.device\n        )\n        self.linear_backend = getattr(\n            model_runner.model_config.hf_config, \"linear_backend\", \"seg_la\"\n        )\n        logger.info(\n            f\"linear_backend for linear attention in hybrid_linear_backend: {self.linear_backend}\"\n        )\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        metadata = self._forward_metadata(forward_batch)\n        self.forward_metadata = BailingLinearMetadata.prepare_mixed(\n            metadata.query_start_loc,\n            metadata.mamba_cache_indices,\n            forward_batch,\n        )\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],\n    ):\n        metadata = self._capture_metadata(bs, req_pool_indices, forward_mode, spec_info)\n        self.forward_metadata = BailingLinearMetadata.prepare_decode(\n            
metadata.query_start_loc, metadata.mamba_cache_indices, bs, seq_lens\n        )\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        metadata = self._replay_metadata(\n            bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu\n        )\n        self.forward_metadata = BailingLinearMetadata.prepare_decode(\n            metadata.query_start_loc, metadata.mamba_cache_indices, bs, seq_lens\n        )\n\n    @staticmethod\n    def _build_slope_tensor(\n        n_attention_heads: int, num_hidden_layers: int, device=\"cuda\"\n    ):\n        def get_slopes(n):\n            def get_slopes_power_of_2(n):\n                start = 2 ** (-(2 ** -(math.log2(n) - 3)))\n                ratio = start\n                return [start * ratio**i for i in range(n)]\n\n            if math.log2(n).is_integer():\n                return get_slopes_power_of_2(n)\n            else:\n                closest_power_of_2 = 2 ** math.floor(math.log2(n))\n                return (\n                    get_slopes_power_of_2(closest_power_of_2)\n                    + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]\n                )\n\n        slopes = torch.tensor(\n            get_slopes(n_attention_heads), dtype=torch.float32\n        ).reshape(n_attention_heads, 1, 1)\n        from sglang.srt.layers.dp_attention import (\n            get_attention_tp_rank,\n            get_attention_tp_size,\n        )\n\n        tp_heads = n_attention_heads // get_attention_tp_size()\n        tp_rank = get_attention_tp_rank()\n        if num_hidden_layers <= 1:\n            slope_rate_list = [slopes * (1 + 1e-5)]\n        else:\n    
        slope_rate_list = [\n                slopes * (1 - layer_id / (num_hidden_layers - 1) + 1e-5)\n                for layer_id in range(num_hidden_layers)\n            ]\n\n        tp_slope = [\n            slope_rate_list[layer_id][tp_rank * tp_heads : (tp_rank + 1) * tp_heads]\n            .contiguous()\n            .to(device)\n            for layer_id in range(num_hidden_layers)\n        ]\n\n        return tp_slope\n\n    def _prefill_and_mix_infer(\n        self,\n        q,\n        k,\n        v,\n        kv_cache,\n        state_indices_tensor,\n        forward_batch,\n        layer,\n        metadata,\n    ):\n        hidden = []\n        for _prefill_idx in range(metadata.num_prefills):\n            if _prefill_idx >= forward_batch.extend_start_loc.shape[0]:\n                break\n            if _prefill_idx >= state_indices_tensor.shape[0]:\n                break\n\n            _start = forward_batch.extend_start_loc[_prefill_idx]\n\n            if _prefill_idx + 1 < forward_batch.extend_start_loc.shape[0]:\n                _end = forward_batch.extend_start_loc[_prefill_idx + 1]\n            else:\n                if (\n                    forward_batch.extend_seq_lens is not None\n                    and _prefill_idx < forward_batch.extend_seq_lens.shape[0]\n                    and metadata.num_decodes > 0\n                ):\n                    seq_len = forward_batch.extend_seq_lens[_prefill_idx]\n                    _end = _start + seq_len\n                else:\n                    _end = q.shape[0]\n\n            slot_id = state_indices_tensor[_prefill_idx]\n            qs = q[_start:_end].transpose(0, 1).contiguous()\n            ks = k[_start:_end].transpose(0, 1).contiguous()\n            vs = v[_start:_end].transpose(0, 1).contiguous()\n            slice_layer_cache = kv_cache[slot_id, ...]\n            out_slice = BailingLinearKernel.jit_linear_forward_prefix(\n                qs,\n                ks,\n                vs,\n             
   slice_layer_cache,\n                self.tp_slope[layer.layer_id],\n                self.BLOCK,\n                layer_idx=layer.layer_id,\n            )\n            hidden.append(out_slice.contiguous())\n        if metadata.num_decodes > 0:\n            hidden.append(\n                self._decode_infer(\n                    q, k, v, kv_cache, state_indices_tensor, metadata, layer\n                )\n            )\n\n        if not hidden:\n            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)\n\n        hidden = torch.concat(hidden, dim=0).contiguous()\n        return hidden\n\n    def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, metadata, layer):\n        num_prefill_tokens = metadata.num_prefill_tokens\n        num_prefills = metadata.num_prefills\n        q = q[num_prefill_tokens:].unsqueeze(2).contiguous()\n        k = k[num_prefill_tokens:].unsqueeze(2).contiguous()\n        v = v[num_prefill_tokens:].unsqueeze(2).contiguous()\n        slot_id = state_indices_tensor[num_prefills:]\n\n        assert slot_id.shape[0] == q.shape[0], (\n            f\"slot_id length {slot_id.shape[0]} does not match decode batch size {q.shape[0]}. 
\"\n            \"This indicates a bug in the upstream logic that should be investigated.\"\n        )\n        hidden = linear_decode_forward_triton(\n            q, k, v, kv_cache, self.tp_slope[layer.layer_id], slot_id, 32\n        )\n        return hidden\n\n    def _linear_attention_entry(\n        self,\n        q,\n        k,\n        v,\n        kv_cache,\n        state_indices_tensor,\n        metadata,\n        layer,\n        mask=None,\n        temp_cache=None,\n        intermediate_state_indices=None,\n    ):\n        q_offsets = metadata.query_start_loc\n\n        seg_meta = SegLaMeta(\n            batch_size=metadata.batch_size,\n            q_offsets=metadata.query_start_loc,\n            s_offsets=state_indices_tensor,\n            q_lengths=q_offsets.diff(),\n            s_scales=metadata.has_initial_states,\n            max_q_length=None,\n            mask=mask,\n        )\n        hidden = seg_la_fwd(\n            q=q,\n            k=k,\n            v=v,\n            s=kv_cache,\n            decay_scales=self.tp_slope[layer.layer_id],\n            meta=seg_meta,\n            caches=temp_cache,\n            cache_indices=intermediate_state_indices,\n            decouple=True,\n        )\n        return hidden\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n        **kwargs,\n    ):\n        q_rope = kwargs[\"q_rope\"] if \"q_rope\" in kwargs else None\n        k_rope = kwargs[\"k_rope\"] if \"k_rope\" in kwargs else None\n        layer_id = layer.layer_id if layer else kwargs[\"layer_id\"]\n\n        metadata = self.forward_metadata\n\n        if self.kv_cache_dtype_str != \"auto\" and layer.k_scale is not None:\n            q = q.to(self.kv_cache_dtype)\n\n        query_start_loc = self.forward_metadata.query_start_loc\n        cache_indices = 
self.forward_metadata.mamba_cache_indices\n        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)\n        ssm_states = mamba_cache_params.temporal\n        if self.linear_backend == \"minimax\":\n            o = self._prefill_and_mix_infer(\n                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),\n                k,\n                v,\n                ssm_states,\n                cache_indices,\n                forward_batch,\n                layer,\n                metadata,\n            )\n        elif self.linear_backend == \"seg_la\":\n            intermediate_state_indices = (\n                torch.arange(\n                    cache_indices.shape[0],\n                    dtype=torch.int32,\n                    device=cache_indices.device,\n                )\n                if forward_batch.forward_mode.is_target_verify()\n                else None\n            )\n            o = self._linear_attention_entry(\n                q,\n                k,\n                v,\n                ssm_states,\n                cache_indices,\n                metadata,\n                layer,\n                temp_cache=(\n                    mamba_cache_params.intermediate_ssm\n                    if forward_batch.forward_mode.is_target_verify()\n                    else None\n                ),\n                intermediate_state_indices=intermediate_state_indices,\n            )\n        else:\n            raise ValueError(\n                f\"linear backend: {self.linear_backend} is not supported for now\"\n            )\n        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n        **kwargs,\n    ) -> torch.Tensor:\n        q_rope = kwargs[\"q_rope\"] if \"q_rope\" in kwargs else None\n        k_rope = kwargs[\"k_rope\"] if \"k_rope\" in kwargs else None\n        layer_id = layer.layer_id if layer else kwargs[\"layer_id\"]\n\n        # Use precomputed metadata across all layers\n        metadata = self.forward_metadata\n\n        if self.kv_cache_dtype_str != \"auto\":\n            q = q.to(self.kv_cache_dtype)\n\n        # Do linear attention\n        query_start_loc = self.forward_metadata.query_start_loc\n        cache_indices = self.forward_metadata.mamba_cache_indices\n        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)\n        ssm_states = mamba_cache_params.temporal\n        if self.linear_backend == \"minimax\":\n            o = self._decode_infer(q, k, v, ssm_states, cache_indices, metadata, layer)\n        elif self.linear_backend == \"seg_la\":\n            o = self._linear_attention_entry(\n                q, k, v, ssm_states, cache_indices, metadata, layer\n            )\n        else:\n            raise ValueError(\n                f\"linear backend: {self.linear_backend} is not supported for now\"\n            )\n        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/linear_metadata.py",
    "content": "from dataclasses import dataclass\n\nimport torch\n\nfrom sglang.srt.layers.attention.mamba.mamba2_metadata import ForwardMetadata\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\n\n\n@dataclass(kw_only=True)\nclass BailingLinearMetadata(ForwardMetadata):\n    num_prefills: int\n    num_prefill_tokens: int\n    num_decodes: int\n    batch_size: int\n    has_initial_states: torch.Tensor\n    q_lengths: torch.Tensor\n\n    @staticmethod\n    def prepare_decode(\n        query_start_loc: torch.Tensor,\n        mamba_cache_indices: torch.Tensor,\n        bs: int,\n        seq_lens: torch.Tensor,\n    ) -> \"BailingLinearMetadata\":\n        \"\"\"This path is run during CUDA graph capture, i.e. decode only, so `num_prefills` is 0\"\"\"\n        return BailingLinearMetadata(\n            batch_size=bs,\n            query_start_loc=query_start_loc,\n            mamba_cache_indices=mamba_cache_indices,\n            num_decodes=seq_lens.shape[0],\n            num_prefills=0,\n            num_prefill_tokens=0,\n            has_initial_states=torch.ones_like(seq_lens),\n            q_lengths=query_start_loc.diff(),\n        )\n\n    @classmethod\n    def prepare_mixed(\n        cls,\n        query_start_loc: torch.Tensor,\n        mamba_cache_indices: torch.Tensor,\n        forward_batch: ForwardBatch,\n    ) -> \"BailingLinearMetadata\":\n        \"\"\"This path cannot run with CUDA graph, as it contains extend requests.\"\"\"\n        if forward_batch.extend_num_tokens is None:\n            return cls.prepare_decode(\n                query_start_loc=query_start_loc,\n                mamba_cache_indices=mamba_cache_indices,\n                bs=forward_batch.batch_size,\n                seq_lens=forward_batch.seq_lens,\n            )\n        num_prefills = len(forward_batch.extend_seq_lens)\n        num_prefill_tokens = forward_batch.extend_num_tokens\n        num_decodes = len(forward_batch.seq_lens) - num_prefills\n        
context_lens_tensor = forward_batch.extend_prefix_lens\n        assert context_lens_tensor is not None\n        has_initial_states = context_lens_tensor > 0\n\n        query_start_loc = query_start_loc[: num_prefills + 1]\n\n        return BailingLinearMetadata(\n            batch_size=forward_batch.batch_size,\n            query_start_loc=query_start_loc,\n            mamba_cache_indices=mamba_cache_indices,\n            num_prefills=num_prefills,\n            num_prefill_tokens=num_prefill_tokens,\n            num_decodes=num_decodes,\n            has_initial_states=has_initial_states,\n            q_lengths=query_start_loc.diff(),\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/seg_la.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nCopyright (c) Ant Financial Service Group and its affiliates.\n\"\"\"\n\n# Copied from https://code.alipay.com/pia/PainlessInferenceAcceleration/blob/v0.0.6/flood/flood/ops/seg_la.py\n\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n# arg `meta` of `seg_la_fwd` is SegLaMeta\n@dataclass\nclass SegLaMeta:\n    batch_size: int  # batch size, num of requests\n    max_q_length: int  # max(seq_lens)\n    q_offsets: torch.Tensor  # [bs+1], query_start_locations,\n    s_offsets: torch.Tensor  # [bs], slot_ids\n    q_lengths: torch.Tensor  # [bs], query length\n    s_scales: torch.Tensor  # [bs], prefill = 0, decode = 1\n    s_offsets_stride: int = 0\n    q_offsets_stride: int = 0\n    s_scales_stride: int = 0\n    decay_scales_stride: int = 0\n    mask: Optional[torch.Tensor] = None  # Currently not supported\n\n\n# fused\n@triton.jit\ndef seg_la_kernel(\n    Q,\n    K,\n    V,\n    S,\n    Out,\n    softmax_scale,\n    stride_q,\n    stride_k,\n    stride_v,\n    stride_s,\n    stride_o,\n    s_offsets,\n    q_offsets,\n    q_lengths,\n    s_scales,\n    decay_scales,\n    HEAD_DIM: tl.constexpr,\n    SPLIT_DIM: tl.constexpr,\n    BLOCK: tl.constexpr,\n    EVEN: tl.constexpr,\n    DECOUPLE: tl.constexpr,\n):\n    bid = tl.program_id(0)\n    hid = tl.program_id(1)\n    sid = tl.program_id(2)\n\n    # s_scale is 0 (prefill) or 1 (decode)\n    s_scale = tl.load(s_scales + bid)\n    q_length = tl.load(q_lengths + bid)\n    q_offset = tl.load(q_offsets + bid)\n    s_offset = tl.load(s_offsets + bid)\n    decay_scale = -tl.load(decay_scales + hid)\n\n    offs_b = tl.arange(0, BLOCK)\n    offs_d = tl.arange(0, HEAD_DIM)\n    offs_s = tl.arange(0, SPLIT_DIM)\n\n    if s_offset == -1:\n        return\n\n    q_ptrs = (\n        Q\n        + q_offset * stride_q\n        + hid * HEAD_DIM\n        + (offs_b[:, None] * stride_q + offs_d[None, :])\n    )\n   
 k_ptrs = (\n        K\n        + q_offset * stride_k\n        + hid * HEAD_DIM\n        + (offs_b[:, None] * stride_k + offs_d[None, :])\n    )\n    v_ptrs = (\n        V\n        + q_offset * stride_v\n        + hid * HEAD_DIM\n        + sid * SPLIT_DIM\n        + (offs_b[:, None] * stride_v + offs_s[None, :])\n    )\n    out_ptrs = (\n        Out\n        + q_offset * stride_o\n        + hid * HEAD_DIM\n        + sid * SPLIT_DIM\n        + (offs_b[:, None] * stride_o + offs_s[None, :])\n    )\n    s_ptrs = (\n        S\n        + s_offset * stride_s\n        + hid * HEAD_DIM * HEAD_DIM\n        + sid * SPLIT_DIM\n        + (offs_d[:, None] * HEAD_DIM + offs_s[None, :])\n    )\n    state = tl.load(s_ptrs, mask=s_scale > 0).to(tl.float32)\n\n    if BLOCK > 1:\n        for n in range(0, q_length, BLOCK):\n            n = tl.multiple_of(n, BLOCK)\n\n            if EVEN:\n                q = tl.load(q_ptrs + n * stride_q).to(tl.float32)\n                k = tl.trans(tl.load(k_ptrs + n * stride_k)).to(tl.float32)\n                v = tl.load(v_ptrs + n * stride_v).to(tl.float32)\n            else:\n                q = tl.load(\n                    q_ptrs + n * stride_q,\n                    mask=(n + offs_b)[:, None] < q_length,\n                    other=0.0,\n                ).to(tl.float32)\n                k = tl.trans(\n                    tl.load(\n                        k_ptrs + n * stride_k,\n                        mask=(n + offs_b)[:, None] < q_length,\n                        other=0.0,\n                    )\n                ).to(tl.float32)\n                v = tl.load(\n                    v_ptrs + n * stride_v,\n                    mask=(n + offs_b)[:, None] < q_length,\n                    other=0.0,\n                ).to(tl.float32)\n\n            if DECOUPLE:\n                # only works with small scales\n                if EVEN:\n                    b = BLOCK\n                else:\n                    b = min(BLOCK, q_length - n)\n                
b_offs = b - 1 - offs_b\n\n                edb = tl.exp(decay_scale * b_offs)\n                decays = tl.where(b_offs >= 0, edb, 0)\n                inv_decays = tl.where(b_offs >= 0, 1 / edb, 0)\n\n                q = q * inv_decays[:, None]\n                k = k * decays[None, :]\n                qk = tl.dot(q, k) * softmax_scale\n                qk = tl.where(offs_b[None, :] <= offs_b[:, None], qk, 0.0)\n                o = tl.dot(qk, v)\n\n                block_decay = tl.exp(decay_scale * b)\n                block_decay_plus = block_decay * softmax_scale\n                o = tl.dot(q, state) * block_decay_plus + o\n\n                state = state * block_decay + tl.dot(k, v)\n            else:\n\n                qk = tl.dot(q, k) * softmax_scale\n                decays = tl.exp(decay_scale * (offs_b[:, None] - offs_b[None, :]))\n                decays = tl.where(offs_b[None, :] <= offs_b[:, None], decays, 0.0)\n                qk *= decays\n                o = tl.dot(qk, v)\n\n                decay_arr = tl.exp(decay_scale * (offs_b[:, None] + 1)) * softmax_scale\n                o = tl.dot(q * decay_arr, state, acc=o)\n\n                if EVEN:\n                    b = BLOCK\n                else:\n                    b = min(BLOCK, q_length - n)\n                b_offs = b - 1 - offs_b\n                b_offs = tl.where(b_offs >= 0, b_offs, 10000)\n                decays = tl.exp(decay_scale * b_offs)\n                block_decay = tl.exp(decay_scale * b)\n                state = state * block_decay + tl.dot(k * decays[None, :], v)\n\n            if EVEN:\n                tl.store(out_ptrs + n * stride_o, o.to(Out.dtype.element_ty))\n            else:\n                tl.store(\n                    out_ptrs + n * stride_o,\n                    o.to(Out.dtype.element_ty),\n                    mask=(n + offs_b)[:, None] < q_length,\n                )\n\n        tl.store(s_ptrs, state.to(S.dtype.element_ty))\n\n    else:\n        q = 
tl.trans(tl.load(q_ptrs)).to(tl.float32) * softmax_scale\n        k = tl.trans(tl.load(k_ptrs)).to(tl.float32)\n        v = tl.load(v_ptrs).to(tl.float32)\n        state = state * tl.exp(decay_scale) + k * v\n\n        o = tl.sum(q * state, axis=0, keep_dims=True)\n\n        tl.store(out_ptrs, o.to(Out.dtype.element_ty))\n\n        tl.store(s_ptrs, state.to(S.dtype.element_ty))\n\n\n# used for prefilling\n@triton.jit\ndef seg_la_p_kernel(\n    Q,\n    K,\n    V,\n    S,\n    Out,\n    softmax_scale,\n    stride_q,\n    stride_k,\n    stride_v,\n    stride_s,\n    stride_o,\n    s_offsets,\n    q_offsets,\n    q_lengths,\n    s_scales,\n    decay_scales,\n    HEAD_DIM: tl.constexpr,\n    K_SPLIT_DIM: tl.constexpr,\n    V_SPLIT_DIM: tl.constexpr,\n    BLOCK: tl.constexpr,\n    EVEN: tl.constexpr,\n):\n    bid = tl.program_id(0)\n    hid = tl.program_id(1)\n    kvid = tl.program_id(2)\n    N = HEAD_DIM // V_SPLIT_DIM\n    kid = kvid // N\n    vid = kvid % N\n    H = tl.num_programs(1)\n\n    # s_scale is 0 (first prefill chunk) or 1 (next prefill chunk)\n    s_scale = tl.load(s_scales + bid)\n    q_length = tl.load(q_lengths + bid)\n    q_offset = tl.load(q_offsets + bid)\n    s_offset = tl.load(s_offsets + bid)\n    decay_scale = -tl.load(decay_scales + hid)\n\n    offs_b = tl.arange(0, BLOCK)\n    offs_k = tl.arange(0, K_SPLIT_DIM)\n    offs_v = tl.arange(0, V_SPLIT_DIM)\n\n    if s_offset == -1:\n        return\n\n    q_ptrs = (\n        Q\n        + q_offset * stride_q\n        + hid * HEAD_DIM\n        + kid * K_SPLIT_DIM\n        + (offs_b[:, None] * stride_q + offs_k[None, :])\n    )\n    k_ptrs = (\n        K\n        + q_offset * stride_k\n        + hid * HEAD_DIM\n        + kid * K_SPLIT_DIM\n        + (offs_b[:, None] * stride_k + offs_k[None, :])\n    )\n    v_ptrs = (\n        V\n        + q_offset * stride_v\n        + hid * HEAD_DIM\n        + vid * V_SPLIT_DIM\n        + (offs_b[:, None] * stride_v + offs_v[None, :])\n    )\n    # (num_dim_block, 
length, qo_heads, d)\n    out_ptrs = (\n        Out\n        + kid * stride_o\n        + q_offset * HEAD_DIM * H\n        + hid * HEAD_DIM\n        + vid * V_SPLIT_DIM\n        + (offs_b[:, None] * H * HEAD_DIM + offs_v[None, :])\n    )\n    s_ptrs = (\n        S\n        + s_offset * stride_s\n        + hid * HEAD_DIM * HEAD_DIM\n        + kid * HEAD_DIM * K_SPLIT_DIM\n        + vid * V_SPLIT_DIM\n        + (offs_k[:, None] * HEAD_DIM + offs_v[None, :])\n    )\n    state = tl.load(s_ptrs, mask=s_scale > 0).to(tl.float32)\n\n    for n in range(0, q_length, BLOCK):\n        n = tl.multiple_of(n, BLOCK)\n\n        if EVEN:\n            q = tl.load(q_ptrs + n * stride_q).to(tl.float32)\n            k = tl.trans(tl.load(k_ptrs + n * stride_k)).to(tl.float32)\n            v = tl.load(v_ptrs + n * stride_v).to(tl.float32)\n            b = BLOCK\n            b_offs = b - 1 - offs_b\n            decays = tl.exp(decay_scale * b_offs)\n            inv_decays = 1 / decays\n        else:\n            q = tl.load(\n                q_ptrs + n * stride_q, mask=(n + offs_b)[:, None] < q_length, other=0.0\n            ).to(tl.float32)\n            k = tl.trans(\n                tl.load(\n                    k_ptrs + n * stride_k,\n                    mask=(n + offs_b)[:, None] < q_length,\n                    other=0.0,\n                )\n            ).to(tl.float32)\n            v = tl.load(\n                v_ptrs + n * stride_v, mask=(n + offs_b)[:, None] < q_length, other=0.0\n            ).to(tl.float32)\n            b = min(BLOCK, q_length - n)\n            b_offs = b - 1 - offs_b\n            block_decays = tl.exp(decay_scale * b_offs)\n            decays = tl.where(b_offs >= 0, block_decays, 0)\n            inv_decays = tl.where(b_offs >= 0, 1 / block_decays, 0)\n\n        q = q * inv_decays[:, None]\n        k = k * decays[None, :]\n        qk = tl.dot(q, k) * softmax_scale\n        qk = tl.where(offs_b[None, :] <= offs_b[:, None], qk, 0.0)\n        o = tl.dot(qk, v)\n\n  
      block_decay = tl.exp(decay_scale * b)\n        o = tl.dot(q, state) * block_decay * softmax_scale + o\n\n        state = state * block_decay + tl.dot(k, v)\n\n        if EVEN:\n            tl.store(out_ptrs + n * H * HEAD_DIM, o.to(Out.dtype.element_ty))\n        else:\n            tl.store(\n                out_ptrs + n * H * HEAD_DIM,\n                o.to(Out.dtype.element_ty),\n                mask=(n + offs_b)[:, None] < q_length,\n            )\n\n    tl.store(s_ptrs, state.to(S.dtype.element_ty))\n\n\n# used for speculative\n@triton.jit\ndef seg_la_s_kernel(\n    Q,\n    K,\n    V,\n    S,\n    Out,\n    Mask,\n    softmax_scale,\n    stride_q,\n    stride_k,\n    stride_v,\n    stride_s,\n    stride_o,\n    s_offsets,\n    q_offsets,\n    q_lengths,\n    s_scales,\n    decay_scales,\n    HEAD_DIM: tl.constexpr,\n    K_SPLIT_DIM: tl.constexpr,\n    V_SPLIT_DIM: tl.constexpr,\n    BLOCK: tl.constexpr,\n    EVEN: tl.constexpr,\n):\n    bid = tl.program_id(0)\n    hid = tl.program_id(1)\n    kvid = tl.program_id(2)\n    N = HEAD_DIM // V_SPLIT_DIM\n    kid = kvid // N\n    vid = kvid % N\n    H = tl.num_programs(1)\n\n    # s_scale is 0 (first prefill chunk) or 1 (next prefill chunk)\n    s_scale = tl.load(s_scales + bid)\n    q_length = tl.load(q_lengths + bid)\n    q_offset = tl.load(q_offsets + bid)\n    s_offset = tl.load(s_offsets + bid)\n    decay_scale = -tl.load(decay_scales + hid)\n\n    offs_b = tl.arange(0, BLOCK)\n    offs_k = tl.arange(0, K_SPLIT_DIM)\n    offs_v = tl.arange(0, V_SPLIT_DIM)\n\n    if s_offset == -1:\n        return\n\n    q_ptrs = (\n        Q\n        + q_offset * stride_q\n        + hid * HEAD_DIM\n        + kid * K_SPLIT_DIM\n        + (offs_b[:, None] * stride_q + offs_k[None, :])\n    )\n    k_ptrs = (\n        K\n        + q_offset * stride_k\n        + hid * HEAD_DIM\n        + kid * K_SPLIT_DIM\n        + (offs_b[:, None] * stride_k + offs_k[None, :])\n    )\n    v_ptrs = (\n        V\n        + q_offset * stride_v\n  
      + hid * HEAD_DIM\n        + vid * V_SPLIT_DIM\n        + (offs_b[:, None] * stride_v + offs_v[None, :])\n    )\n    # (num_dim_block, length, qo_heads, d)\n    out_ptrs = (\n        Out\n        + kid * stride_o\n        + q_offset * HEAD_DIM * H\n        + hid * HEAD_DIM\n        + vid * V_SPLIT_DIM\n        + (offs_b[:, None] * H * HEAD_DIM + offs_v[None, :])\n    )\n    s_ptrs = (\n        S\n        + s_offset * stride_s\n        + hid * HEAD_DIM * HEAD_DIM\n        + kid * HEAD_DIM * K_SPLIT_DIM\n        + vid * V_SPLIT_DIM\n        + (offs_k[:, None] * HEAD_DIM + offs_v[None, :])\n    )\n    state = tl.load(s_ptrs, mask=s_scale > 0).to(tl.float32)\n\n    if EVEN:\n        q = tl.load(q_ptrs).to(tl.float32)\n        k = tl.trans(tl.load(k_ptrs)).to(tl.float32)\n        v = tl.load(v_ptrs).to(tl.float32)\n        mask = tl.load(\n            Mask\n            + bid * BLOCK * BLOCK\n            + tl.arange(0, BLOCK)[:, None] * BLOCK\n            + tl.arange(0, BLOCK)[None, :]\n        ).to(tl.int32)\n        positions = tl.sum(mask, 1) - 1\n        max_pos = tl.max(positions)\n        b_offs = max_pos - positions\n    else:\n        q = tl.load(q_ptrs, mask=offs_b[:, None] < q_length).to(tl.float32)\n        k = tl.trans(tl.load(k_ptrs, mask=offs_b[:, None] < q_length)).to(tl.float32)\n        v = tl.load(v_ptrs, mask=offs_b[:, None] < q_length).to(tl.float32)\n        mask = tl.load(\n            Mask\n            + bid * q_length * q_length\n            + tl.arange(0, BLOCK)[:, None] * q_length\n            + tl.arange(0, BLOCK)[None, :],\n            mask=(tl.arange(0, BLOCK)[:, None] < q_length)\n            & (tl.arange(0, BLOCK)[None, :] < q_length),\n        ).to(tl.int32)\n        positions = tl.sum(mask, 1) - 1\n        max_pos = tl.max(positions)\n        b_offs = max_pos - positions\n\n    decays = tl.exp(decay_scale * b_offs)\n    inv_decays = 1 / decays\n\n    q = q * inv_decays[:, None]\n    k = k * decays[None, :]\n    qk = tl.dot(q, k) * 
softmax_scale\n    qk = qk * mask.to(tl.float32)\n    o = tl.dot(qk, v)\n\n    block_decay = tl.exp(decay_scale * (max_pos + 1))\n    o = tl.dot(q, state) * block_decay * softmax_scale + o\n\n    if EVEN:\n        tl.store(out_ptrs, o.to(Out.dtype.element_ty))\n    else:\n        tl.store(out_ptrs, o.to(Out.dtype.element_ty), mask=offs_b[:, None] < q_length)\n\n\n# used for decode\n@triton.jit\ndef seg_la_d_kernel(\n    Q,\n    K,\n    V,\n    S,\n    Out,\n    softmax_scale,\n    stride_q,\n    stride_k,\n    stride_v,\n    stride_s,\n    stride_o,\n    s_offsets,\n    decay_scales,\n    HEAD_DIM: tl.constexpr,\n    K_SPLIT_DIM: tl.constexpr,\n    V_SPLIT_DIM: tl.constexpr,\n):\n    bid = tl.program_id(0)\n    hid = tl.program_id(1)\n    kvid = tl.program_id(2)\n    N = HEAD_DIM // V_SPLIT_DIM\n    kid = kvid // N\n    vid = kvid % N\n    H = tl.num_programs(1)\n\n    # s_offset == -1 marks a padded slot, which is skipped\n    s_offset = tl.load(s_offsets + bid)\n    if s_offset == -1:\n        return\n\n    decay_scale = -tl.load(decay_scales + hid)\n\n    offs_k = tl.arange(0, K_SPLIT_DIM)\n    offs_v = tl.arange(0, V_SPLIT_DIM)\n\n    q_ptrs = Q + bid * stride_q + hid * HEAD_DIM + kid * K_SPLIT_DIM + (offs_k)\n    k_ptrs = K + bid * stride_k + hid * HEAD_DIM + kid * K_SPLIT_DIM + (offs_k)\n    v_ptrs = V + bid * stride_v + hid * HEAD_DIM + vid * V_SPLIT_DIM + (offs_v)\n    # (num_dim_block, length, qo_heads, d)\n    out_ptrs = (\n        Out\n        + kid * stride_o\n        + bid * H * HEAD_DIM\n        + hid * HEAD_DIM\n        + vid * V_SPLIT_DIM\n        + (offs_v)\n    )\n    s_ptrs = (\n        S\n        + s_offset * stride_s\n        + hid * HEAD_DIM * HEAD_DIM\n        + kid * HEAD_DIM * K_SPLIT_DIM\n        + vid * V_SPLIT_DIM\n        + (offs_k[:, None] * HEAD_DIM + offs_v[None, :])\n    )\n    state = tl.load(s_ptrs).to(tl.float32)\n\n    k = tl.load(k_ptrs).to(tl.float32)\n    v = tl.load(v_ptrs).to(tl.float32)\n    q = 
tl.load(q_ptrs).to(tl.float32) * softmax_scale\n\n    state = state * tl.exp(decay_scale) + k[:, None] * v\n    o = tl.sum(q[:, None] * state, axis=0)\n\n    tl.store(out_ptrs, o.to(Out.dtype.element_ty))\n    tl.store(s_ptrs, state.to(S.dtype.element_ty))\n\n\n# used for MTP with only spec-topk=1.\n@triton.jit\ndef seg_la_mtp_kernel(\n    Q,\n    K,\n    V,\n    S,\n    CACHES,\n    Out,\n    softmax_scale,\n    stride_q,\n    stride_k,\n    stride_v,\n    stride_s,\n    stride_c,\n    stride_o,\n    s_offsets,\n    cache_indices,\n    decay_scales,\n    step,\n    HEAD_DIM: tl.constexpr,\n    K_SPLIT_DIM: tl.constexpr,\n    V_SPLIT_DIM: tl.constexpr,\n):\n    bid = tl.program_id(0)\n    hid = tl.program_id(1)\n    kvid = tl.program_id(2)\n    N = HEAD_DIM // V_SPLIT_DIM\n    kid = kvid // N\n    vid = kvid % N\n    H = tl.num_programs(1)\n\n    s_offset = tl.load(s_offsets + bid)\n    if s_offset == -1:\n        return\n\n    decay_scale = tl.exp(-tl.load(decay_scales + hid))\n\n    offs_k = tl.arange(0, K_SPLIT_DIM)\n    offs_v = tl.arange(0, V_SPLIT_DIM)\n\n    # (length, qo_heads, d)\n    q_ptrs = Q + bid * step * stride_q + hid * HEAD_DIM + kid * K_SPLIT_DIM + (offs_k)\n    k_ptrs = K + bid * step * stride_k + hid * HEAD_DIM + kid * K_SPLIT_DIM + (offs_k)\n    v_ptrs = V + bid * step * stride_v + hid * HEAD_DIM + vid * V_SPLIT_DIM + (offs_v)\n    # (num_dim_block, length, qo_heads, d)\n    out_ptrs = (\n        Out\n        + kid * stride_o\n        + bid * step * H * HEAD_DIM\n        + hid * HEAD_DIM\n        + vid * V_SPLIT_DIM\n        + (offs_v)\n    )\n    # (bs, qo_heads, d, d)\n    s_ptrs = (\n        S\n        + s_offset * stride_s\n        + hid * HEAD_DIM * HEAD_DIM\n        + kid * HEAD_DIM * K_SPLIT_DIM\n        + vid * V_SPLIT_DIM\n        + (offs_k[:, None] * HEAD_DIM + offs_v[None, :])\n    )\n    state = tl.load(s_ptrs).to(tl.float32)\n    # (bs, step, kv_heads, d, d)\n    cache_indices = tl.load(cache_indices + bid)\n    c_ptrs = (\n        
CACHES\n        + cache_indices * stride_c\n        + hid * HEAD_DIM * HEAD_DIM\n        + kid * HEAD_DIM * K_SPLIT_DIM\n        + vid * V_SPLIT_DIM\n        + (offs_k[:, None] * HEAD_DIM + offs_v[None, :])\n    )\n\n    for i in range(step):\n        q = tl.load(q_ptrs).to(tl.float32) * softmax_scale\n        k = tl.load(k_ptrs).to(tl.float32)\n        v = tl.load(v_ptrs).to(tl.float32)\n\n        state = state * decay_scale + k[:, None] * v\n        o = tl.sum(q[:, None] * state, axis=0)\n\n        tl.store(out_ptrs, o.to(Out.dtype.element_ty))\n        tl.store(c_ptrs, state.to(CACHES.dtype.element_ty))\n        q_ptrs += stride_q\n        k_ptrs += stride_k\n        v_ptrs += stride_v\n        out_ptrs += H * HEAD_DIM\n        c_ptrs += H * HEAD_DIM * HEAD_DIM\n\n\n# (k_dim_block, length, qo_heads, d)\n@triton.jit\ndef seg_la_sum_kernel(T, O, DIM: tl.constexpr, NUM_BLOCK: tl.constexpr):\n    pid = tl.program_id(0)\n    length = tl.num_programs(0)\n    x = tl.zeros((DIM,), dtype=tl.float32)\n    for i in range(NUM_BLOCK):\n        x += tl.load(T + i * length * DIM + pid * DIM + tl.arange(0, DIM)).to(\n            tl.float32\n        )\n    tl.store(O + pid * DIM + tl.arange(0, DIM), x)\n\n\ndef seg_la_fwd(\n    q,\n    k,\n    v,\n    s,\n    decay_scales,\n    meta,\n    caches=None,\n    cache_indices=None,\n    softmax_scale=None,\n    decouple=False,\n):\n    length, qo_heads, HEAD_DIM = q.shape\n    _, kv_heads, _ = k.shape\n    bs = meta.batch_size\n    if softmax_scale is None:\n        softmax_scale = HEAD_DIM ** (-0.5)\n\n    # MAX_LENGTH = meta.max_q_length\n    MAX_LENGTH = triton.cdiv(length, bs)\n\n    assert qo_heads == kv_heads, \"seg_la does NOT support GQA currently\"\n\n    if MAX_LENGTH > 1:\n        # prefill with partitioning q/k/v\n        # BLOCK should be <= 64 with decouple\n        K_SPLIT_DIM = 32\n        V_SPLIT_DIM = 32 if bs <= 2 else 64\n\n        num_warps = 2  # 2\n        num_stages = 3  # 3\n\n        k_dim_block = HEAD_DIM // 
K_SPLIT_DIM\n        v_dim_block = HEAD_DIM // V_SPLIT_DIM\n        tmp = torch.empty(\n            (k_dim_block, length, qo_heads, HEAD_DIM), device=q.device, dtype=q.dtype\n        )\n        grid = (bs, kv_heads, k_dim_block * v_dim_block)\n\n        if caches is not None:\n            # mtp\n            EVEN = False\n            BLOCK = 32\n            step = length // bs\n\n            seg_la_mtp_kernel[grid](\n                q,\n                k,\n                v,\n                s,\n                caches,\n                tmp,\n                softmax_scale,\n                q.stride(0),\n                k.stride(0),\n                v.stride(0),\n                s.stride(0),\n                caches.stride(0),\n                tmp.stride(0),\n                meta.s_offsets,\n                cache_indices,\n                decay_scales,\n                step,\n                HEAD_DIM=HEAD_DIM,\n                K_SPLIT_DIM=K_SPLIT_DIM,\n                V_SPLIT_DIM=V_SPLIT_DIM,\n                num_warps=num_warps,\n                num_stages=num_stages,\n            )\n\n        elif meta.mask is not None:\n            # spec\n            ms = meta.mask.size(-1)\n            BLOCK = (ms + 15) // 16 * 16\n            EVEN = BLOCK == ms\n\n            seg_la_s_kernel[grid](\n                q,\n                k,\n                v,\n                s,\n                tmp,\n                meta.mask,\n                softmax_scale,\n                q.stride(0),\n                k.stride(0),\n                v.stride(0),\n                s.stride(0),\n                tmp.stride(0),\n                meta.s_offsets,\n                meta.q_offsets,\n                meta.q_lengths,\n                meta.s_scales,\n                decay_scales,\n                HEAD_DIM=HEAD_DIM,\n                K_SPLIT_DIM=K_SPLIT_DIM,\n                V_SPLIT_DIM=V_SPLIT_DIM,\n                BLOCK=BLOCK,\n                EVEN=EVEN,\n                num_warps=num_warps,\n  
              num_stages=num_stages,\n            )\n\n        else:\n            # prefill\n            BLOCK = 32\n            EVEN = MAX_LENGTH % BLOCK == 0 if bs == 1 else False\n\n            seg_la_p_kernel[grid](\n                q,\n                k,\n                v,\n                s,\n                tmp,\n                softmax_scale,\n                q.stride(0),\n                k.stride(0),\n                v.stride(0),\n                s.stride(0),\n                tmp.stride(0),\n                meta.s_offsets,\n                meta.q_offsets,\n                meta.q_lengths,\n                meta.s_scales,\n                decay_scales,\n                HEAD_DIM=HEAD_DIM,\n                K_SPLIT_DIM=K_SPLIT_DIM,\n                V_SPLIT_DIM=V_SPLIT_DIM,\n                BLOCK=BLOCK,\n                EVEN=EVEN,\n                num_warps=num_warps,\n                num_stages=num_stages,\n            )\n\n        if k_dim_block > 1:\n            if length < 2048:\n                o = tmp.sum(0)\n            else:\n                o = torch.empty(\n                    (length, qo_heads, HEAD_DIM), device=q.device, dtype=q.dtype\n                )\n                seg_la_sum_kernel[(length,)](\n                    tmp,\n                    o,\n                    DIM=qo_heads * HEAD_DIM,\n                    NUM_BLOCK=k_dim_block,\n                    num_warps=2,\n                    num_stages=3,\n                )\n        else:\n            o = tmp[0]\n\n    else:\n        # decode with partitioning q/k/v\n        if bs <= 128:\n            K_SPLIT_DIM = 128  # 128\n            V_SPLIT_DIM = 32  # 32\n            num_warps = 2  # 2\n            num_stages = 2  # 3\n        else:\n            K_SPLIT_DIM = 128  # 128\n            V_SPLIT_DIM = 64  # 32\n            num_warps = 2  # 2\n            num_stages = 3  # 3\n        k_dim_block = HEAD_DIM // K_SPLIT_DIM\n        v_dim_block = HEAD_DIM // V_SPLIT_DIM\n        tmp = torch.empty(\n     
       (k_dim_block, length, qo_heads, HEAD_DIM), device=q.device, dtype=q.dtype\n        )\n        grid = (bs, kv_heads, k_dim_block * v_dim_block)\n\n        seg_la_d_kernel[grid](\n            q,\n            k,\n            v,\n            s,\n            tmp,\n            softmax_scale,\n            q.stride(0),\n            k.stride(0),\n            v.stride(0),\n            s.stride(0),\n            tmp.stride(0),\n            meta.s_offsets,\n            decay_scales,\n            HEAD_DIM=HEAD_DIM,\n            K_SPLIT_DIM=K_SPLIT_DIM,\n            V_SPLIT_DIM=V_SPLIT_DIM,\n            num_warps=num_warps,\n            num_stages=num_stages,\n        )\n        if k_dim_block > 1:\n            o = tmp.sum(0)\n        else:\n            o = tmp[0]\n\n    # if fallback:\n    #     # prefill/decode with partitioning v only\n    #     o = torch.empty(q.shape, device=q.device, dtype=q.dtype)\n    #     if MAX_LENGTH == 1:\n    #         # decode\n    #         BLOCK = 1\n    #         EVEN = False\n    #         SPLIT_DIM = 32\n    #         num_warps = 8\n    #         num_stages = 2\n    #         num_dim_block = HEAD_DIM // SPLIT_DIM\n    #         grid = (batch, kv_heads, num_dim_block)\n    #     else:\n    #         # prefill\n    #         if decouple:\n    #             BLOCK = 64\n    #             SPLIT_DIM = 16\n    #         else:\n    #             BLOCK = HEAD_DIM\n    #             SPLIT_DIM = 32\n    #         # EVEN = all([x % BLOCK == 0 for x in meta.qls])\n    #         EVEN = False\n    #         num_warps = 8\n    #         num_stages = 2\n    #         # prop = torch.cuda.get_device_properties(q.device.index)\n    #         # arch = prop.major * 10 + prop.minor\n    #         # if arch not in (80, 90):\n    #         #     num_stages = 1\n\n    #         num_dim_block = HEAD_DIM // SPLIT_DIM\n    #         grid = (batch, kv_heads, num_dim_block)\n\n    #     seg_la_kernel[grid](\n    #         q,\n    #         k,\n    #         v,\n    # 
        s,\n    #         o,\n    #         softmax_scale,\n    #         q.stride(0),\n    #         k.stride(0),\n    #         v.stride(0),\n    #         s.stride(0),\n    #         o.stride(0),\n    #         meta.s_offsets,\n    #         meta.q_offsets,\n    #         meta.q_lengths,\n    #         meta.s_scales,\n    #         decay_scales,\n    #         HEAD_DIM=HEAD_DIM,\n    #         SPLIT_DIM=SPLIT_DIM,\n    #         BLOCK=BLOCK,\n    #         EVEN=EVEN,\n    #         DECOUPLE=decouple,\n    #         num_warps=num_warps,\n    #         num_stages=num_stages\n    #     )\n    return o\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/linear/utils.py",
    "content": "from __future__ import annotations\n\nimport logging\nfrom enum import Enum\nfrom typing import TYPE_CHECKING, Optional\n\nfrom sglang.srt.utils.common import rank0_log\n\nif TYPE_CHECKING:\n    from sglang.srt.server_args import ServerArgs\n\nlogger = logging.getLogger(__name__)\n\n\nclass LinearAttnKernelBackend(Enum):\n    TRITON = \"triton\"\n    CUTEDSL = \"cutedsl\"\n    FLASHINFER = \"flashinfer\"\n\n    def is_triton(self):\n        return self == LinearAttnKernelBackend.TRITON\n\n    def is_cutedsl(self):\n        return self == LinearAttnKernelBackend.CUTEDSL\n\n    def is_flashinfer(self):\n        return self == LinearAttnKernelBackend.FLASHINFER\n\n\nLINEAR_ATTN_DECODE_BACKEND: Optional[LinearAttnKernelBackend] = None\nLINEAR_ATTN_PREFILL_BACKEND: Optional[LinearAttnKernelBackend] = None\n\n\ndef initialize_linear_attn_config(server_args: ServerArgs):\n    global LINEAR_ATTN_DECODE_BACKEND\n    global LINEAR_ATTN_PREFILL_BACKEND\n\n    base = server_args.linear_attn_backend\n    decode = server_args.linear_attn_decode_backend or base\n    prefill = server_args.linear_attn_prefill_backend or base\n\n    LINEAR_ATTN_DECODE_BACKEND = LinearAttnKernelBackend(decode)\n    LINEAR_ATTN_PREFILL_BACKEND = LinearAttnKernelBackend(prefill)\n    rank0_log(\n        f\"Linear attention kernel backend: \"\n        f\"decode={LINEAR_ATTN_DECODE_BACKEND.value}, \"\n        f\"prefill={LINEAR_ATTN_PREFILL_BACKEND.value}\"\n    )\n\n\ndef get_linear_attn_decode_backend() -> LinearAttnKernelBackend:\n    global LINEAR_ATTN_DECODE_BACKEND\n    if LINEAR_ATTN_DECODE_BACKEND is None:\n        logger.warning(\n            \"LINEAR_ATTN_DECODE_BACKEND is not initialized, using triton backend\"\n        )\n        LINEAR_ATTN_DECODE_BACKEND = LinearAttnKernelBackend.TRITON\n    return LINEAR_ATTN_DECODE_BACKEND\n\n\ndef get_linear_attn_prefill_backend() -> LinearAttnKernelBackend:\n    global LINEAR_ATTN_PREFILL_BACKEND\n    if LINEAR_ATTN_PREFILL_BACKEND is 
None:\n        logger.warning(\n            \"LINEAR_ATTN_PREFILL_BACKEND is not initialized, using triton backend\"\n        )\n        LINEAR_ATTN_PREFILL_BACKEND = LinearAttnKernelBackend.TRITON\n    return LINEAR_ATTN_PREFILL_BACKEND\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/causal_conv1d.py",
    "content": "# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py\n# SPDX-License-Identifier: Apache-2.0\n\n# Copyright (c) 2024, Tri Dao.\n# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py\n\nfrom typing import Optional\n\nimport torch\n\nfrom .causal_conv1d_triton import PAD_SLOT_ID\nfrom .causal_conv1d_triton import causal_conv1d_fn as _causal_conv1d_fn_triton\nfrom .causal_conv1d_triton import causal_conv1d_update as _causal_conv1d_update_triton\n\ntry:\n    from sgl_kernel import causal_conv1d_fwd\n    from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel\n\n    torch.ops.sgl_kernel.causal_conv1d_update\n    _HAS_SGL_KERNEL = True\nexcept (ImportError, AttributeError):\n    _HAS_SGL_KERNEL = False\n\n\ndef _get_seq_lens_cpu(query_start_loc, x):\n    if query_start_loc is not None:\n        return (query_start_loc[1:] - query_start_loc[:-1]).cpu().tolist()\n    return [x.shape[-1]]\n\n\ndef causal_conv1d_fn(\n    x: torch.Tensor,\n    weight: torch.Tensor,\n    bias: Optional[torch.Tensor] = None,\n    query_start_loc: Optional[torch.Tensor] = None,\n    cache_indices: Optional[torch.Tensor] = None,\n    has_initial_state: Optional[torch.Tensor] = None,\n    conv_states: Optional[torch.Tensor] = None,\n    activation: Optional[str] = \"silu\",\n    pad_slot_id: int = PAD_SLOT_ID,\n    **kwargs,\n):\n    \"\"\"\n    x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen\n        sequences are concatenated from left to right for varlen\n    weight: (dim, width)\n    bias: (dim,)\n    query_start_loc: (batch + 1) int32\n        The cumulative sequence lengths of the sequences in\n        the batch, used to index into sequence. 
Prepended by 0.\n        for example: query_start_loc = torch.Tensor([0,10,16,17]),\n        x.shape=(dim,17)\n    cache_indices: (batch) int32\n        indicates the corresponding state index,\n        like so: conv_state = conv_states[cache_indices[batch_id]]\n    has_initial_state: (batch) bool\n        indicates whether the kernel should take the current state as initial\n        state for the calculations\n    conv_states: (...,dim,width - 1) itype\n        updated in place if provided\n    activation: either None or \"silu\" or \"swish\"\n    pad_slot_id: int\n            if cache_indices is passed, lets the kernel identify padded\n            entries that will not be processed,\n            for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]\n            in this case, the kernel will not process entries at\n            indices 0 and 3\n\n\n    out: (batch, dim, seqlen)\n    \"\"\"\n    # Use Triton when: (1) sgl_kernel not available, or (2) input is\n    # non-contiguous and seq_lens_cpu is already pre-computed by caller.\n    # The Triton kernel accepts arbitrary strides, avoiding a .contiguous()\n    # copy that can cost >0.6 ms/layer on large prefill batches.\n    use_triton = not _HAS_SGL_KERNEL or (x.stride(-1) != 1 and \"seq_lens_cpu\" in kwargs)\n    if use_triton:\n        if \"seq_lens_cpu\" not in kwargs:\n            kwargs[\"seq_lens_cpu\"] = _get_seq_lens_cpu(query_start_loc, x)\n        return _causal_conv1d_fn_triton(\n            x,\n            weight,\n            bias,\n            conv_states=conv_states,\n            query_start_loc=query_start_loc,\n            cache_indices=cache_indices,\n            has_initial_state=has_initial_state,\n            activation=activation,\n            pad_slot_id=pad_slot_id,\n            **kwargs,\n        )\n    if activation not in [None, \"silu\", \"swish\"]:\n        raise NotImplementedError(\"activation must be None, silu, or swish\")\n    if x.stride(-1) != 1:\n        x = 
x.contiguous()\n    bias = bias.contiguous() if bias is not None else None\n\n    causal_conv1d_fwd(\n        x,\n        weight,\n        bias,\n        conv_states,\n        query_start_loc,\n        cache_indices,\n        has_initial_state,\n        activation in [\"silu\", \"swish\"],\n        pad_slot_id,\n    )\n    return x\n\n\ndef causal_conv1d_update(\n    x: torch.Tensor,\n    conv_state: torch.Tensor,\n    weight: torch.Tensor,\n    bias: Optional[torch.Tensor] = None,\n    activation: Optional[str] = None,\n    cache_seqlens: Optional[torch.Tensor] = None,\n    conv_state_indices: Optional[torch.Tensor] = None,\n    pad_slot_id: int = PAD_SLOT_ID,\n):\n    \"\"\"\n    x: (batch, dim) or (batch, dim, seqlen)\n    conv_state: (batch, dim, state_len), where state_len >= width - 1\n    weight: (dim, width)\n    bias: (dim,)\n    cache_seqlens: (batch,), dtype int32.\n        If not None, the conv_state is treated as a circular buffer.\n        The conv_state will be updated by copying x to the conv_state\n        starting at the index\n        @cache_seqlens % state_len.\n    conv_state_indices: (batch,), dtype int32\n        If not None, the conv_state is a larger tensor along the batch dim,\n        and we are selecting the batch coords specified by conv_state_indices.\n        Useful for a continuous batching scenario.\n    pad_slot_id: int\n            if conv_state_indices is passed, lets the kernel identify padded\n            entries that will not be processed,\n            for example: conv_state_indices = [pad_slot_id, 1, 20, pad_slot_id]\n            in this case, the kernel will not process entries at\n            indices 0 and 3\n    out: (batch, dim) or (batch, dim, seqlen)\n    \"\"\"\n    use_triton = not _HAS_SGL_KERNEL\n    if use_triton:\n        return _causal_conv1d_update_triton(\n            x,\n            conv_state,\n            weight,\n            bias=bias,\n            activation=activation,\n            cache_seqlens=cache_seqlens,\n   
         conv_state_indices=conv_state_indices,\n            pad_slot_id=pad_slot_id,\n        )\n    if activation not in [None, \"silu\", \"swish\"]:\n        raise NotImplementedError(\n            f\"activation must be None, silu, or swish, actual: {activation}\"\n        )\n    activation_val = activation in [\"silu\", \"swish\"]\n    unsqueeze = x.dim() == 2\n    if unsqueeze:\n        x = x.unsqueeze(-1)\n    causal_conv1d_update_kernel(\n        x,\n        conv_state,\n        weight,\n        bias,\n        activation_val,\n        cache_seqlens,\n        conv_state_indices,\n        pad_slot_id,\n    )\n    if unsqueeze:\n        x = x.squeeze(-1)\n    return x\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py",
    "content": "# Copyright (c) 2024, Tri Dao.\n# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py\n# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py\n\nfrom typing import List, Optional, Union\n\nimport torch\nimport triton\nimport triton.language as tl\n\nPAD_SLOT_ID = -1\n\n\n@triton.jit()\ndef _causal_conv1d_fwd_kernel(  # continuous batching\n    # Pointers to matrices\n    x_ptr,  # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences\n    w_ptr,  # (dim, width)\n    bias_ptr,\n    initial_states_ptr,  # conv_states_ptr\n    cache_indices_ptr,  # conv_state_indices_ptr\n    has_initial_states_ptr,\n    query_start_loc_ptr,\n    o_ptr,  # (dim, seqlen) - actually pointing to x_ptr\n    # Matrix dimensions\n    dim: tl.constexpr,\n    seqlen: tl.int32,  # cu_seqlen\n    num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines\n    # Strides\n    stride_x_seq: tl.constexpr,  # stride to get to next sequence,\n    stride_x_dim: tl.constexpr,  # stride to get to next feature-value,\n    stride_x_token: tl.constexpr,  # stride to get to next token (same feature-index, same sequence-index)\n    stride_w_dim: tl.constexpr,  # stride to get to next dim-axis value\n    stride_w_width: tl.constexpr,  # stride to get to next width-axis value\n    stride_istate_seq: tl.constexpr,\n    stride_istate_dim: tl.constexpr,\n    stride_istate_token: tl.constexpr,\n    stride_o_seq: tl.constexpr,\n    stride_o_dim: tl.constexpr,\n    stride_o_token: tl.constexpr,\n    # others\n    pad_slot_id: tl.constexpr,\n    # Meta-parameters\n    HAS_BIAS: tl.constexpr,\n    KERNEL_WIDTH: tl.constexpr,\n    SILU_ACTIVATION: tl.constexpr,\n    HAS_INITIAL_STATES: tl.constexpr,\n    HAS_CACHE: tl.constexpr,\n    IS_CONTINUOUS_BATCHING: tl.constexpr,\n    USE_PAD_SLOT: tl.constexpr,\n    NP2_STATELEN: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n 
   BLOCK_N: tl.constexpr,\n):\n    conv_states_ptr = initial_states_ptr\n    conv_state_indices_ptr = cache_indices_ptr\n    stride_conv_state_seq = stride_istate_seq\n    stride_conv_state_dim = stride_istate_dim\n    stride_conv_state_tok = stride_istate_token\n    state_len = (\n        KERNEL_WIDTH - 1\n    )  # can be passed via argument if it's not the same as this value\n\n    # one program handles one chunk in a single sequence\n    # rather than mixing sequences, so that initial_states can be updated efficiently across sequences\n\n    # single-sequence id\n    idx_seq = tl.program_id(0)\n    chunk_offset = tl.program_id(1)\n\n    # BLOCK_N elements along the feature-dimension (channel)\n    idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N)\n\n    if idx_seq == pad_slot_id:\n        return\n\n    sequence_start_index = tl.load(query_start_loc_ptr + idx_seq)\n    sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1)\n    # find the actual sequence length\n    seqlen = sequence_end_index - sequence_start_index\n\n    token_offset = BLOCK_M * chunk_offset\n    segment_len = min(BLOCK_M, seqlen - token_offset)\n\n    if segment_len <= 0:\n        return\n\n    # base of the sequence\n    x_base = (\n        x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim\n    )  # [BLOCK_N,]\n\n    if IS_CONTINUOUS_BATCHING:\n        # cache_idx\n        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64)\n    else:\n        # cache_idx\n        conv_state_batch_coord = idx_seq\n    if USE_PAD_SLOT:  # noqa\n        if conv_state_batch_coord == pad_slot_id:\n            # skip: this is a padded entry, not an actual sequence\n            return\n    conv_states_base = (\n        conv_states_ptr\n        + (conv_state_batch_coord * stride_conv_state_seq)\n        + (idx_feats * stride_conv_state_dim)\n    )  # [BLOCK_N,]\n\n    w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]\n\n    # Does 2 things:\n   
 # 1. READ prior-block init-state data - [done by every Triton program]\n    # 2. update conv_state with new data [only by the Triton program that handles chunk_offset=0]\n    if chunk_offset == 0:\n        # read from conv_states\n        load_init_state = False\n        if HAS_INITIAL_STATES:  # the new HAS_INITIAL_STATES\n            load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1)\n        if load_init_state:\n            # load from conv_states\n            prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok\n            mask_w = idx_feats < dim\n            if KERNEL_WIDTH == 2:\n                conv_states_ptrs = prior_tokens  # [BLOCK_N]\n                col0 = tl.load(conv_states_ptrs, mask_w, 0.0)\n            if KERNEL_WIDTH == 3:\n                conv_states_ptrs = prior_tokens  # [BLOCK_N]\n                col1 = tl.load(conv_states_ptrs, mask_w, 0.0)\n                conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok  # [BLOCK_N]\n                col0 = tl.load(conv_states_ptrs, mask_w, 0.0)\n            if KERNEL_WIDTH == 4:\n                conv_states_ptrs = prior_tokens  # [BLOCK_N]\n                col2 = tl.load(conv_states_ptrs, mask_w, 0.0)\n                conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok  # [BLOCK_N]\n                col1 = tl.load(conv_states_ptrs, mask_w, 0.0)\n                conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok  # [BLOCK_N]\n                col0 = tl.load(conv_states_ptrs, mask_w, 0.0)\n            if KERNEL_WIDTH == 5:\n                conv_states_ptrs = prior_tokens  # [BLOCK_N]\n                col3 = tl.load(conv_states_ptrs, mask_w, 0.0)\n                conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok  # [BLOCK_N]\n                col2 = tl.load(conv_states_ptrs, mask_w, 0.0)\n                conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok  # [BLOCK_N]\n                col1 = tl.load(conv_states_ptrs, mask_w, 
0.0)\n                conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok  # [BLOCK_N]\n                col0 = tl.load(conv_states_ptrs, mask_w, 0.0)\n        else:\n            # prior-tokens are zeros\n            if KERNEL_WIDTH >= 2:  # STRATEGY1\n                # first chunk and does not have prior-token, so just set to 0\n                col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)\n            if KERNEL_WIDTH >= 3:  # STRATEGY1\n                col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)\n            if KERNEL_WIDTH >= 4:  # STRATEGY1\n                col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)\n            if KERNEL_WIDTH >= 5:  # STRATEGY1\n                col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)\n\n        # STEP 2:\n        # here prepare data for updating conv_state\n        if (\n            state_len <= seqlen\n        ):  # SMALL_CACHE=True (only move part of 'x' into conv_state cache)\n            # just read from 'x'\n            # copy 'x' data to conv_state\n            # load only 'x' data (and set 0 before 'x' if seqlen < state_len)\n            idx_tokens_last = (seqlen - state_len) + tl.arange(\n                0, NP2_STATELEN\n            )  # [BLOCK_M]\n            x_ptrs = (\n                x_ptr\n                + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None]\n                + (idx_feats * stride_x_dim)[None, :]\n            )  # [BLOCK_M,BLOCK_N,]\n            mask_x = (\n                (idx_tokens_last >= 0)[:, None]\n                & (idx_tokens_last < seqlen)[:, None]\n                & (idx_feats < dim)[None, :]\n            )  # token-index  # token-index  # feature-index\n            loaded_x = tl.load(x_ptrs, mask_x, 0.0)\n            new_conv_state = tl.load(x_ptrs, mask_x, 0.0)\n            idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]\n            conv_states_ptrs_target = (\n                conv_states_base[None, :]\n      
          + (idx_tokens_conv * stride_conv_state_tok)[:, None]\n            )\n\n            mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :]\n            tl.debug_barrier()  #  NOTE: use this due to bug in Triton compiler\n            tl.store(conv_states_ptrs_target, new_conv_state, mask)\n\n        else:\n            if load_init_state:\n                # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x'\n                idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]\n\n                conv_states_ptrs_source = (\n                    conv_states_ptr\n                    + (conv_state_batch_coord * stride_conv_state_seq)\n                    + (idx_feats * stride_conv_state_dim)[None, :]\n                    + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None]\n                )  # [BLOCK_M, BLOCK_N]\n                mask = (\n                    (conv_state_batch_coord < num_cache_lines)\n                    & ((idx_tokens_conv + seqlen) < state_len)[:, None]\n                    & (idx_feats < dim)[None, :]\n                )\n                conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)\n\n                VAL = state_len - seqlen\n\n                x_ptrs = (\n                    x_base[None, :]\n                    + ((idx_tokens_conv - VAL) * stride_x_token)[:, None]\n                )  # [BLOCK_M, BLOCK_N]\n\n                mask_x = (\n                    (idx_tokens_conv - VAL >= 0)[:, None]\n                    & (idx_tokens_conv - VAL < seqlen)[:, None]\n                    & (idx_feats < dim)[None, :]\n                )  # token-index  # token-index  # feature-index\n                loaded_x = tl.load(x_ptrs, mask_x, 0.0)\n\n                tl.debug_barrier()  # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load\n                new_conv_state = tl.where(\n                    mask, conv_state, 
loaded_x\n                )  # BUG in 'tl.where'  which requires a barrier before this\n                conv_states_ptrs_target = (\n                    conv_states_base\n                    + (idx_tokens_conv * stride_conv_state_tok)[:, None]\n                )  # [BLOCK_M, BLOCK_N]\n                mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[\n                    None, :\n                ]\n                tl.store(conv_states_ptrs_target, new_conv_state, mask)\n            else:  # load_init_state == False\n                # update conv_state by shifting left, BUT\n                # set cols prior to 'x' as zeros + cols from 'x'\n                idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]\n\n                VAL = state_len - seqlen\n\n                x_ptrs = (\n                    x_base[None, :]\n                    + ((idx_tokens_conv - VAL) * stride_x_token)[:, None]\n                )  # [BLOCK_M, BLOCK_N]\n\n                mask_x = (\n                    (idx_tokens_conv - VAL >= 0)[:, None]\n                    & (idx_tokens_conv - VAL < seqlen)[:, None]\n                    & (idx_feats < dim)[None, :]\n                )  # token-index  # token-index  # feature-index\n                new_conv_state = tl.load(x_ptrs, mask_x, 0.0)\n\n                conv_states_ptrs_target = (\n                    conv_states_base\n                    + (idx_tokens_conv * stride_conv_state_tok)[:, None]\n                )  # [BLOCK_M, BLOCK_N]\n                mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[\n                    None, :\n                ]\n                tl.store(conv_states_ptrs_target, new_conv_state, mask)\n\n    else:  # chunk_offset > 0\n        # read prior-token data from `x`\n        load_init_state = True\n        prior_tokens = x_base + (token_offset - 1) * stride_x_token\n        mask_w = idx_feats < dim\n        if KERNEL_WIDTH == 2:\n            conv_states_ptrs = prior_tokens  # 
[BLOCK_N]\n            col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=\".ca\")\n        if KERNEL_WIDTH == 3:\n            conv_states_ptrs = prior_tokens  # [BLOCK_N]\n            col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=\".ca\")\n            conv_states_ptrs = prior_tokens - 1 * stride_x_token  # [BLOCK_N]\n            col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=\".ca\")\n        if KERNEL_WIDTH == 4:\n            conv_states_ptrs = prior_tokens  # [BLOCK_N]\n            col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=\".ca\")\n            conv_states_ptrs = prior_tokens - 1 * stride_x_token  # [BLOCK_N]\n            col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=\".ca\")\n            conv_states_ptrs = prior_tokens - 2 * stride_x_token  # [BLOCK_N]\n            col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=\".ca\")\n        if KERNEL_WIDTH == 5:\n            # ruff: noqa: F841\n            conv_states_ptrs = prior_tokens  # [BLOCK_N]\n            col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=\".ca\")\n            conv_states_ptrs = prior_tokens - 1 * stride_x_token  # [BLOCK_N]\n            col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=\".ca\")\n            conv_states_ptrs = prior_tokens - 2 * stride_x_token  # [BLOCK_N]\n            col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=\".ca\")\n            conv_states_ptrs = prior_tokens - 3 * stride_x_token  # [BLOCK_N]\n            col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=\".ca\")\n\n    if HAS_BIAS:\n        bias = bias_ptr + idx_feats\n        mask_bias = idx_feats < dim\n        acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(\n            tl.float32\n        )  # [BLOCK_N]\n    else:\n        acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)\n\n    x_base_1d = x_base + token_offset * stride_x_token  # starting of chunk\n\n    # PRE-LOAD 
WEIGHTS\n    mask_w = idx_feats < dim\n    if KERNEL_WIDTH >= 2:\n        w_ptrs = w_base + (0 * stride_w_width)  # [BLOCK_N] tensor\n        w_col0 = tl.load(w_ptrs, mask_w, other=0.0)\n        w_ptrs = w_base + (1 * stride_w_width)  # [BLOCK_N] tensor\n        w_col1 = tl.load(w_ptrs, mask_w, other=0.0)\n    if KERNEL_WIDTH >= 3:\n        w_ptrs = w_base + (2 * stride_w_width)  # [BLOCK_N] tensor\n        w_col2 = tl.load(w_ptrs, mask_w, other=0.0)\n    if KERNEL_WIDTH >= 4:\n        w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor\n        w_col3 = tl.load(w_ptrs, mask_w, other=0.0)\n    mask_x_1d = idx_feats < dim\n    for idx_token in range(segment_len):\n        acc = acc_preload\n\n        matrix_w = w_col0\n        matrix_x = col0\n        for j in tl.static_range(KERNEL_WIDTH):\n\n            if KERNEL_WIDTH == 2:\n                if j == 1:  # KERNEL_WIDTH-1:\n                    matrix_w = w_col1\n                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]\n                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)\n            elif KERNEL_WIDTH == 3:\n                if j == 1:\n                    matrix_w = w_col1\n                    matrix_x = col1\n                elif j == 2:\n                    matrix_w = w_col2\n                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]\n                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)\n            elif KERNEL_WIDTH == 4:\n                if j == 1:\n                    matrix_w = w_col1\n                    matrix_x = col1\n                elif j == 2:\n                    matrix_w = w_col2\n                    matrix_x = col2\n                elif j == 3:\n                    matrix_w = w_col3\n                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]\n                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)\n\n            acc += matrix_x * matrix_w  # [BLOCK_N]\n\n        if KERNEL_WIDTH == 
2:\n            col0 = matrix_x\n        elif KERNEL_WIDTH == 3:\n            col0 = col1\n            col1 = matrix_x\n        elif KERNEL_WIDTH == 4:\n            col0 = col1\n            col1 = col2\n            col2 = matrix_x\n\n        if SILU_ACTIVATION:\n            acc = acc / (1 + tl.exp(-acc))\n        mask_1d = (idx_token < segment_len) & (\n            idx_feats < dim\n        )  # token-index  # feature-index\n        o_ptrs = (\n            o_ptr\n            + (sequence_start_index + token_offset + idx_token) * stride_o_token\n            + (idx_feats * stride_o_dim)\n        )\n\n        tl.store(o_ptrs, acc, mask=mask_1d)\n\n\ndef causal_conv1d_fn(\n    x: torch.Tensor,\n    weight: torch.Tensor,\n    bias: Union[torch.Tensor, None],\n    conv_states: torch.Tensor,\n    query_start_loc: torch.Tensor,\n    seq_lens_cpu: List[int],\n    cache_indices: Optional[torch.Tensor] = None,\n    has_initial_state: Optional[torch.Tensor] = None,\n    activation: Optional[str] = \"silu\",\n    pad_slot_id: int = PAD_SLOT_ID,\n    validate_data=False,\n    **kwargs,\n):\n    \"\"\"Support varlen + continuous batching when x is a 2D tensor.\n\n    x: (dim,cu_seq_len)\n        cu_seq_len = total tokens of all seqs in that batch\n        sequences are concatenated from left to right for varlen\n    weight: (dim, width)\n    conv_states: (...,dim,width - 1) itype\n        updated in place if provided\n        [it uses `cache_indices` to get the index into the conv_state cache for each sequence:\n\n        conv_state[cache_indices[i]] for seq-i is used as the initial_state when has_initial_state[i] = True,\n             and afterwards conv_state[cache_indices[i]] is shifted left and updated with values from 'x'\n        ]\n    query_start_loc: (batch + 1) int32\n        The cumulative sequence lengths of the sequences in\n        the batch, used to index into sequence. 
prepended by 0.\n        if\n        x = [5, 1, 1, 1] <- continuous batching (batch=4)\n        then\n        query_start_loc = [0, 5, 6, 7, 8] <- the starting index of each sequence, while the last value is\n           the ending index of the last sequence\n        [length(query_start_loc)-1 == batch]\n        for example: query_start_loc = torch.Tensor([0,10,16,17]),\n        x.shape=(dim,17)\n    seq_lens_cpu: (batch) list of int\n        The sequence lengths of the sequences in the batch\n    cache_indices: (batch)  int32\n        indicates the corresponding state index,\n        like so: conv_state = conv_states[cache_indices[batch_id]]\n    has_initial_state: (batch) bool\n        indicates whether the kernel should take the current state as the initial\n        state for the calculations\n        [single boolean for each sequence in the batch: True or False]\n    bias: (dim,)\n    activation: either None or \"silu\" or \"swish\" or True\n    pad_slot_id: int\n        if cache_indices is passed, lets the kernel identify padded\n        entries that will not be processed,\n        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]\n        in this case, the kernel will not process entries at\n        indices 0 and 3\n\n    out: same shape as `x`\n    \"\"\"\n    if isinstance(activation, bool) and activation:\n        activation = \"silu\"\n\n    out = torch.empty_like(x)\n\n    is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)\n    dim, cu_seqlen = x.shape\n    _, width = weight.shape\n    state_len = width - 1\n    np2_statelen = triton.next_power_of_2(state_len)\n\n    stride_x_seq = 0\n    stride_x_dim = x.stride(0)\n    stride_x_token = x.stride(1)\n    stride_w_dim = weight.stride(0)\n    stride_w_width = weight.stride(1)\n    stride_istate_seq = 0\n    stride_istate_dim = 0\n    stride_istate_token = 0\n    num_cache_lines = 0\n    if conv_states is not None:\n        # extensions to support vLLM:\n        # 1. 
conv_states is used to replace initial_states\n        # 2. conv_states serves as a cache whose number of cache lines can be larger than the batch size\n        # 3. sequence x[idx] maps to the cache line at the index given by cache_indices[idx]\n        # 4. computation can be skipped if cache_indices[idx] == pad_slot_id\n        num_cache_lines = conv_states.size(0)\n        assert (\n            num_cache_lines == conv_states.shape[0]\n            and dim == conv_states.shape[1]\n            and width - 1 <= conv_states.shape[2]\n        )\n        stride_istate_seq = conv_states.stride(0)\n        stride_istate_dim = conv_states.stride(1)\n        stride_istate_token = conv_states.stride(2)\n        # assert stride_istate_dim == 1\n    if out.dim() == 2:\n        stride_o_seq = 0\n        stride_o_dim = out.stride(0)\n        stride_o_token = out.stride(1)\n    else:\n        stride_o_seq = out.stride(0)\n        stride_o_dim = out.stride(1)\n        stride_o_token = out.stride(2)\n\n    if validate_data:\n        assert x.dim() == 2\n        assert query_start_loc is not None\n        assert query_start_loc.dim() == 1\n        assert x.stride(0) == 1 or x.stride(1) == 1\n        padded_batch = query_start_loc.size(0) - 1\n        if bias is not None:\n            assert bias.dim() == 1\n            assert dim == bias.size(0)\n        if cache_indices is not None:\n            assert cache_indices.dim() == 1\n            assert padded_batch == cache_indices.size(0)\n        if has_initial_state is not None:\n            assert has_initial_state.size() == (padded_batch,)\n            assert (\n                conv_states is not None\n            ), \"ERROR: `has_initial_state` is used, which also requires `conv_states`\"\n        assert weight.stride(1) == 1\n        assert (dim, width) == weight.shape\n        assert is_channel_last, \"Need to run in channel-last layout\"\n\n    def grid(META):\n        max_seq_len = max(seq_lens_cpu)\n        return (\n            
len(seq_lens_cpu),  # batch_size\n            (max_seq_len + META[\"BLOCK_M\"] - 1) // META[\"BLOCK_M\"],\n            triton.cdiv(dim, META[\"BLOCK_N\"]),\n        )\n\n    _causal_conv1d_fwd_kernel[grid](\n        # Pointers to matrices\n        x,\n        weight,\n        bias,\n        conv_states,\n        cache_indices,\n        has_initial_state,\n        query_start_loc,\n        out,\n        # Matrix dimensions\n        dim,\n        cu_seqlen,\n        num_cache_lines,\n        # stride\n        stride_x_seq,\n        stride_x_dim,\n        stride_x_token,\n        stride_w_dim,\n        stride_w_width,\n        stride_istate_seq,\n        stride_istate_dim,\n        stride_istate_token,\n        stride_o_seq,\n        stride_o_dim,\n        stride_o_token,\n        # others\n        pad_slot_id,\n        # META\n        HAS_BIAS=bias is not None,\n        KERNEL_WIDTH=width,\n        SILU_ACTIVATION=activation in [\"silu\", \"swish\"],\n        HAS_INITIAL_STATES=has_initial_state is not None,\n        HAS_CACHE=conv_states is not None,\n        IS_CONTINUOUS_BATCHING=cache_indices is not None,\n        USE_PAD_SLOT=pad_slot_id is not None,\n        NP2_STATELEN=np2_statelen,\n        # launch_cooperative_grid=True\n        BLOCK_M=8,\n        BLOCK_N=256,\n        num_stages=2,\n    )\n    return out\n\n\n# HAS_EAGLE_TREE_CUSTOM_ATTN_MASK is added to support eagle tree attention mask\n# retrieve_next_token_ptr: [N, NP2_T], retrieve_next_sibling_ptr: [N, NP2_T]\n# e.g. 
for a sequence of length 4, the eagle tree attention structure is:\n# retrieve_next_token=[1, 3, -1, -1] -> retrieve_next_token[i]: the 1st child token of token i\n# retrieve_next_sibling=[-1, 2, -1, -1] -> retrieve_next_sibling[i]: the 1st tree sibling token of token i\n# retrieve_parent_token=[n/a, 0, 0, 1] -> retrieve_parent_token[i]: the parent token of token i\n# Tree:\n#    0\n#   / \\\n#  1   2\n# /\n# 3\n# When calculating token 3's convolution, it should conv to token 1 (parent) and token 0 (grand-parent)\n# When calculating token 2's convolution, it should conv to token 0 (parent)\n# This kernel is a fused kernel which will also produce retrieve_parent_token based on retrieve_next_token & retrieve_next_sibling\n@triton.jit()\ndef _causal_conv1d_update_kernel(\n    # Pointers to matrices\n    x_ptr,  # (batch, dim, seqlen)\n    w_ptr,  # (dim, width)\n    bias_ptr,\n    conv_state_ptr,\n    cache_seqlens_ptr,  # circular buffer\n    conv_state_indices_ptr,\n    num_accepted_tokens_ptr,\n    intermediate_conv_window_ptr,\n    intermediate_state_indices_ptr,\n    retrieve_next_token_ptr,\n    retrieve_next_sibling_ptr,\n    retrieve_parent_token_ptr,\n    o_ptr,  # (batch, dim, seqlen)\n    # Matrix dimensions\n    batch: int,\n    dim: tl.constexpr,\n    seqlen: tl.constexpr,\n    state_len: tl.constexpr,\n    num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines\n    # Strides\n    stride_x_seq: tl.constexpr,\n    stride_x_dim: tl.constexpr,\n    stride_x_token: tl.constexpr,\n    stride_w_dim: tl.constexpr,\n    stride_w_width: tl.constexpr,\n    stride_conv_state_seq: tl.constexpr,\n    stride_conv_state_dim: tl.constexpr,\n    stride_conv_state_tok: tl.constexpr,\n    stride_state_indices: tl.constexpr,\n    stride_inter_seq: tl.constexpr,\n    stride_inter_step: tl.constexpr,\n    stride_inter_dim: tl.constexpr,\n    stride_inter_win: tl.constexpr,\n    stride_intermediate_state_indices: tl.constexpr,\n    
stride_retrieve_next_token_seq: tl.constexpr,\n    stride_retrieve_next_token_token: tl.constexpr,\n    stride_retrieve_next_sibling_seq: tl.constexpr,\n    stride_retrieve_next_sibling_token: tl.constexpr,\n    stride_retrieve_parent_token_seq: tl.constexpr,\n    stride_retrieve_parent_token_token: tl.constexpr,\n    stride_o_seq: tl.constexpr,\n    stride_o_dim: tl.constexpr,\n    stride_o_token: tl.constexpr,\n    # others\n    pad_slot_id: tl.constexpr,\n    # Meta-parameters\n    HAS_BIAS: tl.constexpr,\n    KERNEL_WIDTH: tl.constexpr,\n    SILU_ACTIVATION: tl.constexpr,\n    IS_CONTINUOUS_BATCHING: tl.constexpr,\n    IS_SPEC_DECODING: tl.constexpr,\n    NP2_STATELEN: tl.constexpr,\n    NP2_SEQLEN: tl.constexpr,\n    USE_PAD_SLOT: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    SAVE_INTERMEDIATE: tl.constexpr,\n    HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr,\n):\n    # ruff: noqa: E501\n    idx_seq = tl.program_id(0)\n    if idx_seq >= batch:\n        return\n\n    # [BLOCK_N,] elements along the feature-dimension (channel)\n    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)\n\n    if IS_CONTINUOUS_BATCHING:\n        # mask = idx_seq < batch\n        conv_state_batch_coord = tl.load(\n            conv_state_indices_ptr + idx_seq * stride_state_indices\n        ).to(tl.int64)\n        if SAVE_INTERMEDIATE:\n            intermediate_state_batch_coord = tl.load(\n                intermediate_state_indices_ptr\n                + idx_seq * stride_intermediate_state_indices\n            ).to(tl.int64)\n    else:\n        conv_state_batch_coord = idx_seq\n    if USE_PAD_SLOT:  # noqa\n        if conv_state_batch_coord == pad_slot_id:\n            # not processing as this is not the actual sequence\n            return\n\n    if IS_SPEC_DECODING:\n        # The rolling of conv state:\n        #\n        # Before forward, the conv_state is:\n        # [history1, history2, ..., historyM].\n        #\n        # After forward, the conv_state becomes:\n    
    # [history2, ..., historyM, draft1, draft2, ..., draftN].\n        #\n        # After acceptance, it becomes:\n        #\n        # - accept 1 token: [history2, ..., historyM, draft1]\n        # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]\n        # - and so on.\n        conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1\n    else:\n        conv_state_token_offset = 0\n\n    # STEP 1: READ init_state data\n    conv_states_base = (\n        conv_state_ptr\n        + (conv_state_batch_coord * stride_conv_state_seq)\n        + (idx_feats * stride_conv_state_dim)\n    )\n    mask_w = idx_feats < dim\n\n    prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok\n    if KERNEL_WIDTH >= 2:\n        conv_states_ptrs = prior_tokens  # [BLOCK_N]\n        col0 = tl.load(conv_states_ptrs, mask_w, 0.0)\n    if KERNEL_WIDTH >= 3:\n        conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok  # [BLOCK_N]\n        col1 = tl.load(conv_states_ptrs, mask_w, 0.0)\n    if KERNEL_WIDTH >= 4:\n        conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok  # [BLOCK_N]\n        col2 = tl.load(conv_states_ptrs, mask_w, 0.0)\n    if KERNEL_WIDTH == 5:\n        conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok  # [BLOCK_N]\n        col3 = tl.load(conv_states_ptrs, mask_w, 0.0)\n\n    # STEP 2: assume state_len > seqlen\n    idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]\n\n    # The conv_state is updated as a sliding window: the retained history is\n    # read starting at idx_tokens + seqlen (or, when IS_SPEC_DECODING, at\n    # idx_tokens + 1 past conv_state_token_offset).\n    conv_state_ptrs_source = (\n        conv_state_ptr\n        + (conv_state_batch_coord * stride_conv_state_seq)\n        + conv_state_token_offset * stride_conv_state_tok\n        + (idx_feats * stride_conv_state_dim)[None, :]\n        + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[\n            :, None\n        ]\n    )  # [BLOCK_M, 
BLOCK_N]\n    mask = (\n        (conv_state_batch_coord < num_cache_lines)\n        & ((idx_tokens + seqlen) < state_len)[:, None]\n        & (idx_feats < dim)[None, :]\n    )\n    conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)\n\n    VAL = state_len - seqlen\n    x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim)  # [BLOCK_N]\n\n    x_ptrs = (\n        x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]\n    )  # [BLOCK_M, BLOCK_N]\n\n    mask_x = (\n        (idx_tokens - VAL >= 0)[:, None]\n        & (idx_tokens - VAL < seqlen)[:, None]\n        & (idx_feats < dim)[None, :]\n    )  # token-index  # token-index  # feature-index\n    loaded_x = tl.load(x_ptrs, mask_x, 0.0)\n    tl.debug_barrier()\n\n    new_conv_state = tl.where(mask, conv_state, loaded_x)\n\n    conv_state_base = (\n        conv_state_ptr\n        + (conv_state_batch_coord * stride_conv_state_seq)\n        + (idx_feats * stride_conv_state_dim)\n    )  # [BLOCK_N,]\n    conv_state_ptrs_target = (\n        conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None]\n    )  # [BLOCK_M, BLOCK_N]\n    mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]\n    tl.store(conv_state_ptrs_target, new_conv_state, mask)\n\n    # STEP 3: init accumulator\n    if HAS_BIAS:\n        bias = bias_ptr + idx_feats\n        mask_bias = idx_feats < dim\n        acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(\n            tl.float32\n        )  # [BLOCK_N]\n    else:\n        acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)\n\n    # STEP 4:\n    # PRE-LOAD WEIGHTS\n    # first kernel column, configured for weights to handle BLOCK_N features in range\n    if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:\n        idx_tokens = tl.arange(0, NP2_SEQLEN)  # [BLOCK_M]\n        # Update parent mapping for all tokens at once using vectorized operations\n        mask_retrieve = idx_tokens < seqlen\n        retrieve_next_token_base = (\n            
retrieve_next_token_ptr\n            + (idx_seq * stride_retrieve_next_token_seq)\n            + idx_tokens * stride_retrieve_next_token_token\n        )\n        retrieve_next_tokens = tl.load(retrieve_next_token_base, mask_retrieve)\n        retrieve_next_sibling_base = (\n            retrieve_next_sibling_ptr\n            + (idx_seq * stride_retrieve_next_sibling_seq)\n            + idx_tokens * stride_retrieve_next_sibling_token\n        )\n        retrieve_next_siblings = tl.load(retrieve_next_sibling_base, mask_retrieve)\n        parent_idx_tokens = tl.zeros((NP2_SEQLEN,), dtype=tl.int32)\n\n    w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]\n    mask_w = idx_feats < dim\n    if KERNEL_WIDTH >= 2:\n        w_ptrs = w_base + (0 * stride_w_width)  # [BLOCK_N] tensor\n        w_col0 = tl.load(w_ptrs, mask_w, other=0.0)\n        w_ptrs = w_base + (1 * stride_w_width)  # [BLOCK_N] tensor\n        w_col1 = tl.load(w_ptrs, mask_w, other=0.0)\n    if KERNEL_WIDTH >= 3:\n        w_ptrs = w_base + (2 * stride_w_width)  # [BLOCK_N] tensor\n        w_col2 = tl.load(w_ptrs, mask_w, other=0.0)\n    if KERNEL_WIDTH >= 4:\n        w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor\n        w_col3 = tl.load(w_ptrs, mask_w, other=0.0)\n\n    x_base_1d = x_base  # starting of chunk [BLOCK_N]\n    mask_x_1d = idx_feats < dim\n\n    # STEP 5: compute each token\n    for idx_token in tl.static_range(seqlen):\n        acc = acc_preload\n\n        if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:\n            # set the parent index of the next token in the eagle tree\n            # next token's parent is the current token\n            retrieve_next_token_idx = tl.sum(\n                tl.where(idx_tokens == idx_token, retrieve_next_tokens, 0)\n            )\n            if retrieve_next_token_idx != -1:  # pad slot id\n                parent_idx_tokens = tl.where(\n                    idx_tokens == retrieve_next_token_idx,\n                    idx_token,\n                    
parent_idx_tokens,\n                )\n            # next token's parent is the parent of the current token\n            retrieve_sibling_token_idx = tl.sum(\n                tl.where(idx_tokens == idx_token, retrieve_next_siblings, 0)\n            )\n            if retrieve_sibling_token_idx != -1:  # pad slot id\n                parent_idx_token = tl.sum(\n                    tl.where(idx_tokens == idx_token, parent_idx_tokens, 0)\n                )\n                parent_idx_tokens = tl.where(\n                    idx_tokens == retrieve_sibling_token_idx,\n                    parent_idx_token,\n                    parent_idx_tokens,\n                )\n            # tl.device_print(\"am\", parent_idx_tokens)\n\n            _idx_token = idx_token\n            x_ptrs_1d = x_base_1d + _idx_token * stride_x_token  # [BLOCK_N]\n            matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)\n            # convolution operation: itself * wcol[-1] + parent * wcol[-2] + grand-parent * wcol[-3] + ...\n            for j in tl.static_range(KERNEL_WIDTH):\n                if KERNEL_WIDTH == 2:\n                    if j == 0:\n                        matrix_w = w_col1\n                    else:\n                        matrix_w = w_col0\n                elif KERNEL_WIDTH == 3:\n                    if j == 0:\n                        matrix_w = w_col2\n                    elif j == 1:\n                        matrix_w = w_col1\n                    else:\n                        matrix_w = w_col0\n                elif KERNEL_WIDTH == 4:\n                    if j == 0:\n                        matrix_w = w_col3\n                    elif j == 1:\n                        matrix_w = w_col2\n                    elif j == 2:\n                        matrix_w = w_col1\n                    else:\n                        matrix_w = w_col0\n\n                if SAVE_INTERMEDIATE:\n                    # Save the window state after consuming this token\n                    # Layout: [seq(cache 
line), step, dim, win(K-1)]\n                    base_ptr = (\n                        intermediate_conv_window_ptr\n                        + intermediate_state_batch_coord * stride_inter_seq\n                        + idx_token * stride_inter_step\n                        + idx_feats * stride_inter_dim\n                    )\n\n                    # store itself in KERNEL_WIDTH-2 slot, parent in KERNEL_WIDTH-3 slot, grand-parent in KERNEL_WIDTH-4 slot, ...\n                    if KERNEL_WIDTH - j - 2 >= 0:\n                        tl.store(\n                            base_ptr + (KERNEL_WIDTH - j - 2) * stride_inter_win,\n                            matrix_x,\n                            mask=mask_w,\n                        )\n\n                acc += matrix_x * matrix_w\n\n                # move to parent for next iteration\n                if _idx_token > 0:\n                    _idx_token = tl.sum(\n                        tl.where(idx_tokens == _idx_token, parent_idx_tokens, 0)\n                    )\n                    x_ptrs_1d = x_base_1d + _idx_token * stride_x_token  # [BLOCK_N]\n                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)\n                else:\n                    # no parent within the current chunk, load from prev conv state: col[-1] (idx 0's parent), col[-2] (idx 0's grand parent), ...\n                    if KERNEL_WIDTH == 2:\n                        if _idx_token == 0:\n                            matrix_x = col0\n                    elif KERNEL_WIDTH == 3:\n                        if _idx_token == 0:\n                            matrix_x = col1\n                        else:\n                            matrix_x = col0\n                    elif KERNEL_WIDTH == 4:\n                        if _idx_token == 0:\n                            matrix_x = col2\n                        elif _idx_token == -1:\n                            matrix_x = col1\n                        else:\n                            matrix_x = col0\n   
                 _idx_token = _idx_token - 1\n        else:\n            matrix_w = w_col0\n            matrix_x = col0\n\n            for j in tl.static_range(KERNEL_WIDTH):\n                if KERNEL_WIDTH == 2:\n                    if j == 1:  # KERNEL_WIDTH-1:\n                        matrix_w = w_col1\n                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]\n                        matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)\n                elif KERNEL_WIDTH == 3:\n                    if j == 1:\n                        matrix_w = w_col1\n                        matrix_x = col1\n                    elif j == 2:\n                        matrix_w = w_col2\n                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]\n                        matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)\n                elif KERNEL_WIDTH == 4:\n                    if j == 1:\n                        matrix_w = w_col1\n                        matrix_x = col1\n                    elif j == 2:\n                        matrix_w = w_col2\n                        matrix_x = col2\n                    elif j == 3:\n                        matrix_w = w_col3\n                        x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]\n                        matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)\n\n                acc += matrix_x * matrix_w  # [BLOCK_N]\n\n            if KERNEL_WIDTH == 2:\n                col0 = matrix_x\n            elif KERNEL_WIDTH == 3:\n                col0 = col1\n                col1 = matrix_x\n            elif KERNEL_WIDTH == 4:\n                col0 = col1\n                col1 = col2\n                col2 = matrix_x\n\n            if SAVE_INTERMEDIATE:\n                # Save the window state after consuming this token\n                # Layout: [seq(cache line), step, dim, win(K-1)]\n                base_ptr = (\n                    intermediate_conv_window_ptr\n                  
  + intermediate_state_batch_coord * stride_inter_seq\n                    + idx_token * stride_inter_step\n                    + idx_feats * stride_inter_dim\n                )\n                if KERNEL_WIDTH >= 2:\n                    tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)\n                if KERNEL_WIDTH >= 3:\n                    tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)\n                if KERNEL_WIDTH >= 4:\n                    tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)\n\n        if SILU_ACTIVATION:\n            acc = acc / (1 + tl.exp(-acc))\n        mask_1d = (idx_token < seqlen) & (\n            idx_feats < dim\n        )  # token-index  # feature-index\n        o_ptrs = (\n            o_ptr\n            + (idx_seq) * stride_o_seq\n            + idx_token * stride_o_token\n            + (idx_feats * stride_o_dim)\n        )\n\n        tl.store(o_ptrs, acc, mask=mask_1d)\n\n        # fuse: store calculated retrieve_parent_token to tensor\n        if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:\n            tl.store(\n                retrieve_parent_token_ptr\n                + idx_seq * stride_retrieve_parent_token_seq\n                + idx_tokens * stride_retrieve_parent_token_token,\n                parent_idx_tokens,\n                mask=mask_retrieve,\n            )\n\n\ndef causal_conv1d_update(\n    x: torch.Tensor,\n    conv_state: torch.Tensor,\n    weight: torch.Tensor,\n    bias: Optional[torch.Tensor] = None,\n    activation: Union[bool, str, None] = None,\n    cache_seqlens: Optional[torch.Tensor] = None,\n    conv_state_indices: Optional[torch.Tensor] = None,\n    num_accepted_tokens: Optional[torch.Tensor] = None,\n    intermediate_conv_window: Optional[torch.Tensor] = None,\n    intermediate_state_indices: Optional[torch.Tensor] = None,\n    retrieve_next_token: Optional[torch.Tensor] = None,\n    retrieve_next_sibling: Optional[torch.Tensor] = None,\n    retrieve_parent_token: 
Optional[torch.Tensor] = None,\n    pad_slot_id: int = PAD_SLOT_ID,\n    metadata=None,\n    validate_data=False,\n):\n    \"\"\"\n    x: (batch, dim) or (batch, dim, seqlen)\n        [shape=2: single token prediction]\n        [shape=3: single or multiple tokens prediction]\n    conv_state: (..., dim, state_len), where state_len >= width - 1\n    weight: (dim, width)\n    bias: (dim,)\n    cache_seqlens: (batch,), dtype int32.\n        If not None, the conv_state is treated as a circular buffer.\n        The conv_state will be updated by copying x to the conv_state\n        starting at the index\n        @cache_seqlens % state_len.\n    conv_state_indices: (batch,), dtype int32\n        If not None, the conv_state is a larger tensor along the batch dim,\n        and we are selecting the batch coords specified by conv_state_indices.\n        Useful for a continuous batching scenario.\n    pad_slot_id: int\n            if cache_indices is passed, lets the kernel identify padded\n            entries that will not be processed,\n            for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]\n            in this case, the kernel will not process entries at\n            indices 0 and 3\n    out: (batch, dim) or (batch, dim, seqlen)\n    \"\"\"\n    if validate_data:\n        assert cache_seqlens is None  # not implemented yet - ok for vLLM\n        assert pad_slot_id is not None\n        assert x.stride(1) == 1\n    if isinstance(activation, bool):\n        activation = \"silu\" if activation is True else None\n    elif activation is not None:\n        assert activation in [\"silu\", \"swish\"]\n    unsqueeze = x.dim() == 2\n    if unsqueeze:\n        # make it (batch, dim, seqlen) with seqlen == 1\n        x = x.unsqueeze(-1)\n    batch, dim, seqlen = x.shape\n    _, width = weight.shape\n    # conv_state: (..., dim, state_len), where state_len >= width - 1\n    num_cache_lines, _, state_len = conv_state.size()\n\n    if validate_data:\n        assert dim 
== weight.size(0)\n        assert (\n            conv_state.stride(-2) == 1\n        ), f\"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})\"\n        assert state_len >= width - 1\n        # when above happens, we don't shift-left to keep any records in conv_state\n        assert dim == conv_state.size(1)\n        if conv_state_indices is None:\n            assert conv_state.size(0) >= batch\n        else:\n            assert (batch,) == conv_state_indices.shape\n            assert intermediate_state_indices is not None\n            assert (batch,) == intermediate_state_indices.shape\n\n        assert num_cache_lines >= batch\n        assert weight.stride(1) == 1  # Need this\n        assert cache_seqlens is None  # not needed for vLLM - circular buffer\n\n    # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'\n    out = torch.empty_like(x)\n    stride_w_dim, stride_w_width = weight.stride()\n\n    stride_x_seq, stride_x_dim, stride_x_token = x.stride()  # X (batch, dim, seqlen)\n\n    stride_o_seq, stride_o_dim, stride_o_token = out.stride()\n    stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride()\n    stride_state_indices = (\n        conv_state_indices.stride(0) if conv_state_indices is not None else 0\n    )\n    stride_intermediate_state_indices = (\n        intermediate_state_indices.stride(0)\n        if intermediate_state_indices is not None\n        else 0\n    )\n    if num_accepted_tokens is not None:\n        state_len = width - 1 + (seqlen - 1)  # effective state_len needed\n    else:\n        state_len = width - 1\n    np2_statelen = triton.next_power_of_2(state_len)\n    np2_seqlen = triton.next_power_of_2(seqlen)\n\n    def grid(META):\n        return (\n            batch,\n            triton.cdiv(dim, META[\"BLOCK_N\"]),\n        )\n\n    # prepare intermediate buffer strides if provided\n    if intermediate_conv_window is not 
None:\n        stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = (\n            intermediate_conv_window.stride(0),\n            intermediate_conv_window.stride(1),\n            intermediate_conv_window.stride(2),\n            intermediate_conv_window.stride(3),\n        )\n    else:\n        stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0\n\n    # prepare retrieve next token buffer strides if provided\n    if retrieve_next_token is not None:\n        stride_retrieve_next_token_seq, stride_retrieve_next_token_token = (\n            retrieve_next_token.stride(0),\n            retrieve_next_token.stride(1),\n        )\n    else:\n        stride_retrieve_next_token_seq = stride_retrieve_next_token_token = 0\n\n    # prepare retrieve next sibling buffer strides if provided\n    if retrieve_next_sibling is not None:\n        stride_retrieve_next_sibling_seq, stride_retrieve_next_sibling_token = (\n            retrieve_next_sibling.stride(0),\n            retrieve_next_sibling.stride(1),\n        )\n    else:\n        stride_retrieve_next_sibling_seq = stride_retrieve_next_sibling_token = 0\n\n    # prepare retrieve parent token buffer strides if provided\n    if retrieve_parent_token is not None:\n        stride_retrieve_parent_token_seq, stride_retrieve_parent_token_token = (\n            retrieve_parent_token.stride(0),\n            retrieve_parent_token.stride(1),\n        )\n    else:\n        stride_retrieve_parent_token_seq = stride_retrieve_parent_token_token = 0\n\n    _causal_conv1d_update_kernel[grid](\n        # Pointers to matrices\n        x,\n        weight,\n        bias,\n        conv_state,\n        cache_seqlens,\n        conv_state_indices,\n        num_accepted_tokens,\n        intermediate_conv_window if intermediate_conv_window is not None else x,\n        intermediate_state_indices,\n        retrieve_next_token,\n        retrieve_next_sibling,\n        retrieve_parent_token,\n        out,\n    
    # Matrix dimensions\n        batch,\n        dim,\n        seqlen,\n        state_len,\n        num_cache_lines,\n        # stride\n        stride_x_seq,\n        stride_x_dim,\n        stride_x_token,\n        stride_w_dim,\n        stride_w_width,\n        stride_istate_seq,\n        stride_istate_dim,\n        stride_istate_token,\n        stride_state_indices,\n        stride_inter_seq,\n        stride_inter_step,\n        stride_inter_dim,\n        stride_inter_win,\n        stride_intermediate_state_indices,\n        stride_retrieve_next_token_seq,\n        stride_retrieve_next_token_token,\n        stride_retrieve_next_sibling_seq,\n        stride_retrieve_next_sibling_token,\n        stride_retrieve_parent_token_seq,\n        stride_retrieve_parent_token_token,\n        stride_o_seq,\n        stride_o_dim,\n        stride_o_token,\n        # others\n        pad_slot_id,\n        # META\n        HAS_BIAS=bias is not None,\n        KERNEL_WIDTH=width,\n        SILU_ACTIVATION=activation in [\"silu\", \"swish\"],\n        IS_CONTINUOUS_BATCHING=conv_state_indices is not None,\n        IS_SPEC_DECODING=num_accepted_tokens is not None,\n        NP2_STATELEN=np2_statelen,\n        NP2_SEQLEN=np2_seqlen,\n        USE_PAD_SLOT=pad_slot_id is not None,\n        BLOCK_N=256,\n        SAVE_INTERMEDIATE=intermediate_conv_window is not None,\n        HAS_EAGLE_TREE_CUSTOM_ATTN_MASK=retrieve_next_token is not None,\n    )\n    if unsqueeze:\n        out = out.squeeze(-1)\n    return out\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/mamba.py",
    "content": "from typing import Callable, List, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\n\nfrom sglang.srt.configs.mamba_utils import (\n    Mamba2CacheParams,\n    extra_groups_for_head_shards,\n)\nfrom sglang.srt.distributed import (\n    divide,\n    get_tensor_model_parallel_rank,\n    get_tensor_model_parallel_world_size,\n)\nfrom sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata\nfrom sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated\nfrom sglang.srt.layers.attention.mamba.ops import (\n    mamba_chunk_scan_combined,\n    selective_state_update,\n)\nfrom sglang.srt.layers.linear import (\n    ColumnParallelLinear,\n    MergedColumnParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.srt.layers.quantization.base_config import QuantizationConfig\nfrom sglang.srt.mem_cache.memory_pool import MambaPool\nfrom sglang.srt.model_loader.weight_utils import (\n    composed_weight_loader,\n    sharded_weight_loader,\n)\nfrom sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs\n\nif is_cuda():\n    from sglang.srt.layers.attention.mamba.causal_conv1d import (\n        causal_conv1d_fn,\n        causal_conv1d_update,\n    )\n    from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (\n        causal_conv1d_fn as causal_conv1d_fn_triton,\n    )\n    from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (\n        causal_conv1d_update as causal_conv1d_update_triton,\n    )\nelif is_npu():\n    from sgl_kernel_npu.mamba.causal_conv1d import (\n        causal_conv1d_fn_npu as causal_conv1d_fn,\n    )\n    from sgl_kernel_npu.mamba.causal_conv1d import (\n        causal_conv1d_update_npu as causal_conv1d_update,\n    )\n\nLoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]\n\n\ndef mamba_v2_sharded_weight_loader(\n    shard_spec: List[Tuple[int, int, float]],\n    tp_size: int,\n    tp_rank: int,\n) -> LoaderFunction:\n    \"\"\"Create a weight loader 
This is useful when there is replication of\n            #   groups to accompany head shards.\n\n            # - size of the loaded shard\n            shard_size = full_dim // tp_size\n\n            # - compute the rank into the loaded shard.\n            # - if there is replication, different TP shards will\n            #   take from the same rank.\n            # NOTE: currently we only support duplication\n            # in the case where num_groups == 1\n            rank = 0 if duplicate_groups else tp_rank\n\n            # - leftmost boundary index into loaded weight.\n            loaded_skip = rank * shard_size\n            loaded_start_idx = loaded_boundary + loaded_skip\n\n            # - take these many dims from the loaded weight.\n            take = min(shard_size, full_dim - extra - loaded_skip)\n\n            # CPU logic of padding size for qwen3-next\n            # TODO : make this common for all mamba.\n            if is_cpu() and loaded_weight.size(0) % tp_size != 0:\n                import copy\n\n                loaded_weight_ = copy.deepcopy(loaded_weight)\n                q, k, v = torch.split(\n                    loaded_weight_,\n                    weight_full_dim_list,\n                    dim=0,\n                )\n                pad_qk = torch.zeros(\n                    full_dim_list[0] - weight_full_dim_list[0],\n                    loaded_weight.size(1),\n                    loaded_weight.size(2),\n                ).to(loaded_weight.dtype)\n                pad_v = torch.zeros(\n                    full_dim_list[2] - weight_full_dim_list[2],\n                    loaded_weight.size(1),\n                    loaded_weight.size(2),\n                ).to(loaded_weight.dtype)\n                q = torch.cat((q, pad_qk), dim=0)\n                k = torch.cat((k, pad_qk), dim=0)\n                v = torch.cat((v, pad_v), dim=0)\n                loaded_weight_qk = torch.cat((q, k), dim=0)\n                loaded_weight = 
torch.cat((loaded_weight_qk, v), dim=0)\n\n            # - always shard on dim 0\n            # - the ignore is for a mundane mypy error as it does not\n            #   seem to handle slices well.\n            # https://github.com/python/mypy/issues/2410\n            param.data[\n                boundary : (boundary + take), ...  # type: ignore[misc]\n            ] = loaded_weight[\n                loaded_start_idx : (loaded_start_idx + take)  # type: ignore[misc]\n            ]  # type: ignore[misc]\n\n            # move indexing boundaries\n            boundary += shard_size\n            loaded_boundary += full_dim - extra\n\n    return loader\n\n\nclass MambaMixer2(torch.nn.Module):\n    \"\"\"\n    Compute ∆, A, B, C, and D the state space parameters and compute\n    the `contextualized_states`. A, D are input independent\n    (see Mamba paper [1] Section 3.5.2 \"Interpretation of A\"\n    for why A isn't selective) ∆, B, C are input-dependent\n    (this is a key difference between Mamba and the linear time\n    invariant S4, and is why Mamba is called\n    **selective** state spaces)\n    \"\"\"\n\n    def __init__(\n        self,\n        cache_params: Mamba2CacheParams,\n        hidden_size: int,\n        use_conv_bias: bool,\n        use_bias: bool,\n        n_groups: int = 1,\n        rms_norm_eps: float = 1e-5,\n        activation: str = \"silu\",\n        use_rms_norm: bool = True,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n\n        # For TP, the sharding plan is as follows:\n        # - for the conv modules, since\n        #   conv_dim = intermediate_size * 2 * n_groups * ssm_state_size,\n        #   we shard intermediate_size and n_groups\n        # - since intermediate_size = n_heads * head_dim, sharding on\n        #   intermediate_size is achieved by sharding on n_heads.\n        # - IF, world_size divides groups, then sharding\n        #   (n_groups / world_size, 
n_heads / world_size)\n        #   also maintains the invariant n_heads % n_groups == 0\n        # - HOWEVER IF, world_size DOES NOT divide groups, then we need\n        #   to allocate extra space in the shard, such that groups\n        #   may be replicated to follow the head shard.\n        # - NOTE: currently for the world size DOES NOT divide groups\n        #   case, we only support the case when n_groups == 1\n        self.tp_size = get_tensor_model_parallel_world_size()\n        self.tp_rank = get_tensor_model_parallel_rank()\n\n        self.num_heads = num_heads = cache_params.shape.num_heads\n        self.head_dim = cache_params.shape.head_dim\n\n        assert (\n            num_heads % self.tp_size == 0\n        ), \"Tensor parallel world size must divide num heads.\"\n\n        assert (n_groups % self.tp_size) == 0 or n_groups == 1, (\n            \"If tensor parallel world size does not divide num_groups, \"\n            \"then num_groups must equal 1.\"\n        )\n\n        assert (\n            (n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None\n        ), (\n            \"Tensor parallel currently supported for quantized models only \"\n            \"if tensor parallel world size divides num groups.\"\n        )\n\n        self.ssm_state_size = cache_params.shape.ssm_state_size\n        self.activation = activation\n\n        conv_kernel_size = cache_params.shape.conv_kernel\n        self.intermediate_size = intermediate_size = (\n            cache_params.shape.intermediate_size\n        )\n        self.n_groups = n_groups\n        if n_groups % self.tp_size != 0:\n            # - for TP we shard conv_dim by sharding on n_groups,\n            # - but if n_groups cannot divide tp_size, we need to\n            #   extend some extra groups\n            groups = extra_groups_for_head_shards(n_groups, self.tp_size)\n            self.n_groups = n_groups + groups\n        self.groups_ssm_state_size = self.n_groups * 
self.ssm_state_size\n        self.conv_dim = cache_params.shape.conv_dim\n\n        if n_groups % self.tp_size == 0:\n            self.conv1d = MergedColumnParallelLinear(\n                input_size=conv_kernel_size,\n                output_sizes=[\n                    intermediate_size,\n                    self.groups_ssm_state_size,\n                    self.groups_ssm_state_size,\n                ],\n                bias=use_conv_bias,\n                quant_config=None,\n                prefix=f\"{prefix}.conv1d\",\n            )\n\n            self.in_proj = MergedColumnParallelLinear(\n                input_size=hidden_size,\n                output_sizes=[\n                    intermediate_size,\n                    intermediate_size,\n                    self.groups_ssm_state_size,\n                    self.groups_ssm_state_size,\n                    self.num_heads,\n                ],\n                bias=use_bias,\n                quant_config=quant_config,\n                prefix=f\"{prefix}.in_proj\",\n            )\n        else:\n            # This is the n_groups == 1 case,\n            # where we need to duplicate groups if TP>1.\n\n            self.conv1d = ColumnParallelLinear(\n                input_size=conv_kernel_size,\n                output_size=self.conv_dim,\n                bias=use_conv_bias,\n                quant_config=None,\n                prefix=f\"{prefix}.conv1d\",\n            )\n\n            self.in_proj = ColumnParallelLinear(\n                input_size=hidden_size,\n                output_size=intermediate_size + self.conv_dim + self.num_heads,\n                bias=use_bias,\n                quant_config=quant_config,\n                prefix=f\"{prefix}.in_proj\",\n            )\n\n            # - because in_proj is a concatenation of 3 weights, we\n            #   need to interleave them before sharding\n            # - use the custom weight loader mamba_v2_sharded_weight_loader\n            #   for conv1d.bias, 
loader\n                delattr(self.in_proj.weight, \"weight_loader\")\n                set_weight_attrs(\n                    self.in_proj.weight,\n                    {\n                        \"weight_loader\": mamba_v2_sharded_weight_loader(\n                            [\n                                intermediate_settings,  # for gate\n                                intermediate_settings,\n                                group_shard_settings,\n                                group_shard_settings,\n                                head_settings,  # for dt\n                            ],\n                            self.tp_size,\n                            self.tp_rank,\n                        )\n                    },\n                )\n\n        # unsqueeze to fit conv1d weights shape into the linear weights shape.\n        # Can't do this in `weight_loader` since it already exists in\n        # `ColumnParallelLinear` and `MergedColumnParallelLinear`,\n        # and `set_weight_attrs` doesn't allow to override it\n        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)\n\n        # - these are TPed by heads to reduce the size of the\n        #   temporal shape\n        self.A = nn.Parameter(\n            torch.empty(\n                divide(num_heads, self.tp_size),\n                dtype=torch.float32,\n            )\n        )\n        self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))\n        self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))\n        self.use_rms_norm = use_rms_norm\n\n        set_weight_attrs(self.D, {\"weight_loader\": sharded_weight_loader(0)})\n        a_weight_loader = composed_weight_loader(\n            sharded_weight_loader(0), lambda x: -torch.exp(x.float())\n        )\n        set_weight_attrs(self.A, {\"weight_loader\": a_weight_loader})\n        set_weight_attrs(self.dt_bias, {\"weight_loader\": sharded_weight_loader(0)})\n\n        self.out_proj = RowParallelLinear(\n          
  intermediate_size,\n            hidden_size,\n            bias=use_bias,\n            input_is_parallel=True,\n            quant_config=quant_config,\n            prefix=f\"{prefix}.out_proj\",\n        )\n\n        self.norm = Mixer2RMSNormGated(\n            intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps\n        )\n\n        self.prefix = prefix\n\n    def forward(\n        self,\n        *,\n        hidden_states: torch.Tensor,\n        output: torch.Tensor,\n        layer_cache: MambaPool.State,\n        metadata: Mamba2Metadata,\n        mup_vector: Optional[torch.Tensor] = None,\n        use_triton_causal_conv: bool = False,\n    ):\n        # metadata contains metadata necessary for the mamba2 triton\n        # kernels to operate in continuous batching and in chunked prefill\n        # modes; they are computed at top-level model forward since they\n        # stay the same and reused for all mamba layers in the same iteration\n        state_indices_tensor = metadata.mamba_cache_indices\n        conv_state = layer_cache.conv[0]\n        ssm_state = layer_cache.temporal\n\n        query_start_loc = metadata.query_start_loc\n\n        # 1. 
Gated MLP's linear projection\n        projected_states, _ = self.in_proj(hidden_states)\n\n        if mup_vector is not None:\n            projected_states = projected_states * mup_vector\n\n        gate, hidden_states_B_C, dt = torch.split(\n            projected_states,\n            [\n                self.intermediate_size // self.tp_size,\n                self.conv_dim // self.tp_size,\n                self.num_heads // self.tp_size,\n            ],\n            dim=-1,\n        )\n        conv_weights = self.conv1d.weight.view(\n            self.conv1d.weight.size(0), self.conv1d.weight.size(2)\n        )\n\n        # - get hidden_states, B and C after depthwise convolution.\n        split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(\n            hidden_states_B_C,\n            [\n                self.intermediate_size // self.tp_size,\n                self.groups_ssm_state_size // self.tp_size,\n                self.groups_ssm_state_size // self.tp_size,\n            ],\n            dim=-1,\n        )\n\n        num_prefills = metadata.num_prefills  # request count\n        num_decodes = metadata.num_decodes  # token count (=request)\n        num_decode_tokens = (\n            num_decodes * metadata.draft_token_num\n            if metadata.is_target_verify\n            else num_decodes\n        )\n        num_prefill_tokens = metadata.num_prefill_tokens  # token count\n        has_prefill = num_prefills > 0\n        has_decode = num_decodes > 0\n        num_actual_tokens = num_prefill_tokens + num_decode_tokens\n        assert num_actual_tokens == projected_states.shape[0]\n\n        # NOTE: V0 put prefill before decode\n        # Separate prefill and decode by splitting varlen input\n        # Split along token dimension\n        hidden_states_B_C_p, hidden_states_B_C_d = torch.split(\n            hidden_states_B_C,\n            [num_prefill_tokens, num_decode_tokens],\n            dim=0,\n        )\n        dt_p, dt_d = torch.split(\n      
      dt,\n            [num_prefill_tokens, num_decode_tokens],\n            dim=0,\n        )\n        # Split along batch dimension\n        state_indices_tensor_p, state_indices_tensor_d = torch.split(\n            state_indices_tensor,\n            [num_prefills, num_decodes],\n            dim=0,\n        )\n        query_start_loc_p = query_start_loc[: num_prefills + 1] if has_prefill else None\n\n        # Preallocate output tensor to avoid memcpy cost for merging prefill\n        # and decode outputs\n\n        preallocated_ssm_out = torch.empty(\n            [\n                projected_states.shape[0],\n                (self.num_heads * self.head_dim) // self.tp_size,\n            ],\n            dtype=hidden_states.dtype,\n            device=hidden_states.device,\n        )\n        preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(\n            preallocated_ssm_out,\n            [num_prefill_tokens, num_decode_tokens],\n            dim=0,\n        )\n\n        # Process prefill requests\n        if has_prefill:\n            mixed_metadata = metadata.mixed_metadata\n            assert mixed_metadata is not None\n            # 2. 
Convolution sequence transformation\n            # - \"cache_indices\" updates the conv_state cache in positions\n            #   pointed to by \"state_indices_tensor\"\n            has_initial_states_p = mixed_metadata.has_initial_states\n            prep_initial_states = mixed_metadata.prep_initial_states\n            cache_indices = state_indices_tensor_p\n            x = hidden_states_B_C_p.transpose(\n                0, 1\n            )  # this is the form that causal-conv see\n            ccfn = (\n                causal_conv1d_fn\n                if not use_triton_causal_conv\n                else causal_conv1d_fn_triton\n            )\n            hidden_states_B_C_p = ccfn(\n                x,\n                conv_weights,\n                self.conv1d.bias,\n                activation=self.activation,\n                conv_states=conv_state,\n                has_initial_state=has_initial_states_p,\n                cache_indices=cache_indices,\n                query_start_loc=query_start_loc_p,\n                seq_lens_cpu=mixed_metadata.extend_seq_lens_cpu,\n            ).transpose(0, 1)[:num_prefill_tokens]\n\n            hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)\n\n            # 3. 
State Space Model sequence transformation\n            initial_states = None\n            if has_initial_states_p is not None and prep_initial_states:\n                initial_states = torch.where(\n                    has_initial_states_p[:, None, None, None],\n                    ssm_state[state_indices_tensor_p],\n                    0,\n                )\n\n            # NOTE: final output is an in-place update of out tensor\n            varlen_state = mamba_chunk_scan_combined(\n                hidden_states_p.view(\n                    1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim\n                ),\n                dt_p.unsqueeze(0),\n                self.A,\n                B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),\n                C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),\n                chunk_size=mixed_metadata.chunk_size,\n                D=self.D,\n                z=None,\n                dt_bias=self.dt_bias,\n                seq_idx=mixed_metadata.seq_idx,\n                chunk_indices=mixed_metadata.chunk_indices,\n                chunk_offsets=mixed_metadata.chunk_offsets,\n                cu_seqlens=query_start_loc_p,\n                initial_states=initial_states,\n                return_varlen_states=True,\n                return_final_states=False,\n                dt_softplus=True,\n                dt_limit=(0.0, float(\"inf\")),\n                out=preallocated_ssm_out_p.view(\n                    1, num_prefill_tokens, -1, self.head_dim\n                ),\n                state_dtype=ssm_state.dtype,\n            )\n\n            # update ssm states\n            # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor\n            ssm_state[state_indices_tensor_p] = varlen_state\n\n        # Process decode requests\n        if has_decode:\n            is_target_verify = metadata.is_target_verify\n\n            # 2. 
Convolution sequence transformation\n            if is_target_verify:\n                assert (\n                    use_triton_causal_conv\n                ), \"Speculative decoding requires use_triton_causal_conv=True for intermediate state support\"\n                assert isinstance(\n                    layer_cache, MambaPool.SpeculativeState\n                ), \"layer_cache must be SpeculativeState for speculative decoding\"\n                draft_token_num = metadata.draft_token_num\n                self.intermediate_state_indices = torch.arange(\n                    num_decodes, dtype=torch.int32, device=state_indices_tensor_d.device\n                )\n\n                # Reshape for batch processing\n                hidden_states_B_C_d_reshaped = hidden_states_B_C_d.view(\n                    num_decodes, draft_token_num, -1\n                ).transpose(1, 2)\n\n                hidden_states_B_C_d_processed = causal_conv1d_update_triton(\n                    hidden_states_B_C_d_reshaped,\n                    conv_state,\n                    conv_weights,\n                    self.conv1d.bias,\n                    self.activation,\n                    conv_state_indices=state_indices_tensor_d[:num_decodes],\n                    intermediate_conv_window=layer_cache.intermediate_conv_window[0],\n                    intermediate_state_indices=self.intermediate_state_indices,\n                    retrieve_next_token=metadata.retrieve_next_token,\n                    retrieve_next_sibling=metadata.retrieve_next_sibling,\n                    retrieve_parent_token=metadata.retrieve_parent_token,\n                )\n                hidden_states_B_C_d = hidden_states_B_C_d_processed.transpose(\n                    1, 2\n                ).view(num_decode_tokens, -1)\n            else:\n                ccu = (\n                    causal_conv1d_update\n                    if not use_triton_causal_conv\n                    else causal_conv1d_update_triton\n          
      )\n                hidden_states_B_C_d = ccu(\n                    hidden_states_B_C_d,\n                    conv_state,\n                    conv_weights,\n                    self.conv1d.bias,\n                    self.activation,\n                    conv_state_indices=state_indices_tensor_d,\n                )\n\n            hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)\n\n            # 3. State Space Model sequence transformation\n            n_groups = self.n_groups // self.tp_size\n            A_d = (\n                self.A[:, None, ...][:, :, None]\n                .expand(-1, self.head_dim, self.ssm_state_size)\n                .to(dtype=torch.float32)\n            )\n            dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)\n            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)\n            D_d = self.D[:, None, ...].expand(-1, self.head_dim)\n            B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)\n            C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)\n            hidden_states_d = hidden_states_d.view(\n                -1, self.num_heads // self.tp_size, self.head_dim\n            )\n\n            if is_target_verify:\n                selective_state_update(\n                    ssm_state,\n                    hidden_states_d.view(\n                        num_decodes,\n                        draft_token_num,\n                        self.num_heads // self.tp_size,\n                        self.head_dim,\n                    ),\n                    dt_d.view(\n                        num_decodes,\n                        draft_token_num,\n                        self.num_heads // self.tp_size,\n                        self.head_dim,\n                    ),\n                    A_d,\n                    B_d.view(num_decodes, draft_token_num, n_groups, -1),\n                    C_d.view(num_decodes, draft_token_num, n_groups, -1),\n                    D_d,\n         
           z=None,\n                    dt_bias=dt_bias,\n                    dt_softplus=True,\n                    state_batch_indices=state_indices_tensor_d[:num_decodes],\n                    out=preallocated_ssm_out_d.view(\n                        num_decodes,\n                        draft_token_num,\n                        self.num_heads // self.tp_size,\n                        self.head_dim,\n                    ),\n                    disable_state_update=True,\n                    intermediate_states_buffer=layer_cache.intermediate_ssm,\n                    cache_steps=draft_token_num,\n                    retrieve_parent_token=metadata.retrieve_parent_token,\n                    intermediate_state_indices=self.intermediate_state_indices,\n                )\n            else:\n                selective_state_update(\n                    ssm_state,\n                    hidden_states_d,\n                    dt_d,\n                    A_d,\n                    B_d,\n                    C_d,\n                    D_d,\n                    z=None,\n                    dt_bias=dt_bias,\n                    dt_softplus=True,\n                    state_batch_indices=state_indices_tensor_d,\n                    out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),\n                )\n\n        # 4. gated MLP\n        # GatedRMSNorm internally applying SiLU to the gate\n        # SiLU is applied internally before normalization, unlike standard\n        # norm usage\n        hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])\n\n        # 5. Final linear projection\n        output[:num_actual_tokens], _ = self.out_proj(hidden_states)\n\n    @property\n    def mamba_type(self) -> str:\n        return \"mamba2\"\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/mamba2_metadata.py",
    "content": "# Copyright 2025 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/vllm/model_executor/layers/mamba/mamba2_metadata.py\n\n\nimport math\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\n\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\n\n\n@dataclass(kw_only=True)\nclass ForwardMetadata:\n    query_start_loc: torch.Tensor\n    mamba_cache_indices: torch.Tensor\n    # For topk > 1 eagle\n    retrieve_next_token: Optional[torch.Tensor] = None\n    retrieve_next_sibling: Optional[torch.Tensor] = None\n    retrieve_parent_token: Optional[torch.Tensor] = None\n    # For prefill radix cache\n    track_conv_indices: Optional[torch.Tensor] = None\n    track_ssm_h_src: Optional[torch.Tensor] = None\n    track_ssm_h_dst: Optional[torch.Tensor] = None\n    track_ssm_final_src: Optional[torch.Tensor] = None\n    track_ssm_final_dst: Optional[torch.Tensor] = None\n\n    is_target_verify: bool = False\n    draft_token_num: int = 1\n\n\n@dataclass(kw_only=True)\nclass Mamba2Metadata(ForwardMetadata):\n    \"\"\"stable metadata across all mamba2 layers in the forward pass\"\"\"\n\n    num_prefills: int\n    num_prefill_tokens: int\n    num_decodes: int\n\n    @dataclass(kw_only=True, frozen=True)\n    class MixedMetadata:\n 
       has_initial_states: torch.Tensor\n        prep_initial_states: bool\n\n        chunk_size: int\n        seq_idx: torch.Tensor\n        chunk_indices: torch.Tensor\n        chunk_offsets: torch.Tensor\n\n        extend_seq_lens_cpu: list[int]\n\n    mixed_metadata: MixedMetadata | None = None\n    \"\"\"`mixed_metadata` is used for extend/mixed requests\"\"\"\n\n    @staticmethod\n    def _query_start_loc_to_chunk_indices_offsets(\n        query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Args:\n            query_start_loc (torch.Tensor): 1D tensor of cumulative sequence\n                lengths, shape (num_seqs + 1,).\n                The first element should be 0. Each entry represents the starting\n                index of a sequence in the flattened token array.\n            chunk_size (int): The size of each physical mamba chunk\n                (number of tokens per chunk).\n            total_seqlens (int): The total number of tokens in the batch.\n\n        Returns:\n            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:\n                - chunk_indices (torch.Tensor): 1D tensor of indices\n                    indicating the physical chunk for each logical chunk.\n                - chunk_offsets (torch.Tensor): 1D tensor of offsets\n                    indicating the starting index of each logical chunk within\n                    its physical chunk.\n\n        This function computes the chunk indices and offsets for the given\n        query_start_loc and chunk_size. 
Both are tensors of integers with length N,\n        where N is the number of logical (pseudo) chunks.\n        A logical chunk is a sequence of tokens that are all part of the same\n        sequence and are all in the same physical mamba chunk.\n        In other words, a logical chunk changes every time we cross a sequence\n        boundary or a physical mamba chunk boundary.\n        Logical chunks are needed to handle batched requests with initial states\n        (see _state_passing_fwd and _chunk_scan_fwd).\n        The chunk_indices tensor contains the index of the physical chunk for each\n        logical chunk.\n        The chunk_offsets tensor contains the offset (AKA starting index) of the\n        logical chunk in the physical chunk.\n\n        Example:\n        query_start_loc = [0, 5, 10]\n        chunk_size = 8\n        total_seqlens = 10\n        -> chunk_indices = [0, 0, 1]\n        -> chunk_offsets = [0, 5, 0]\n\n        In this example, we have 2 sequences, each with 5 tokens. 
The physical\n        chunk size is 8 tokens.\n        We have three logical chunks:\n        - the first logical chunk starts at token 0 in the first physical chunk\n            and contains all 5 tokens from the first sequence\n        - the second logical chunk starts at token 5 in the first physical chunk\n            and contains first 3 tokens from the second sequence\n        - the third logical chunk starts at token 0 in the second physical chunk\n            and contains the remaining 2 tokens from the second sequence\n        \"\"\"\n\n        cu_seqlens = query_start_loc[1:]  # remove prepended 0\n\n        # outputs will have length expansion of chunks that do not divide\n        # chunk_size\n        N = (\n            math.ceil(total_seqlens / chunk_size)\n            + (cu_seqlens[:-1] % chunk_size > 0).sum()\n        )\n        chunk_indices = torch.arange(N, dtype=torch.int, device=query_start_loc.device)\n        chunk_offsets = torch.zeros(\n            (N,), dtype=torch.int, device=query_start_loc.device\n        )\n\n        p = 0  # num of insertions\n        for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):\n\n            # if does not divide chunk_size, then there is one chunk insertion\n            p += s % chunk_size > 0\n\n            # get the dimensions\n            # - the + 1 for _e is to shift the boundary by one chunk\n            # - this shifting is not needed if chunk_size divides e\n            _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0)\n\n            # adjust indices and offsets\n            chunk_indices[_s:_e] -= p\n            chunk_offsets[_s] = s % chunk_size\n\n        return chunk_indices, chunk_offsets\n\n    @staticmethod\n    def prepare_decode(\n        forward_metadata: ForwardMetadata,\n        seq_lens: torch.Tensor,\n        *,\n        is_target_verify: bool,\n        draft_token_num: int,\n    ) -> \"Mamba2Metadata\":\n        \"\"\"This path is run during CUDA graph capture, i.e. 
decode only, so `num_prefills` is 0\"\"\"\n        return Mamba2Metadata(\n            query_start_loc=forward_metadata.query_start_loc,\n            mamba_cache_indices=forward_metadata.mamba_cache_indices,\n            retrieve_next_token=forward_metadata.retrieve_next_token,\n            retrieve_next_sibling=forward_metadata.retrieve_next_sibling,\n            retrieve_parent_token=forward_metadata.retrieve_parent_token,\n            num_decodes=len(seq_lens),\n            num_prefills=0,\n            num_prefill_tokens=0,\n            is_target_verify=is_target_verify,\n            draft_token_num=draft_token_num,\n        )\n\n    @classmethod\n    def prepare_mixed(\n        cls,\n        forward_metadata: ForwardMetadata,\n        chunk_size: int,\n        forward_batch: ForwardBatch,\n    ) -> \"Mamba2Metadata\":\n        \"\"\"This path cannot run with CUDA graph, as it contains extend requests.\"\"\"\n        if forward_batch.extend_num_tokens is None:\n            draft_token_num = (\n                forward_batch.spec_info.draft_token_num\n                if forward_batch.spec_info is not None\n                else 1\n            )\n            return cls.prepare_decode(\n                forward_metadata,\n                forward_batch.seq_lens,\n                is_target_verify=forward_batch.forward_mode.is_target_verify(),\n                draft_token_num=draft_token_num,\n            )\n        num_prefills = len(forward_batch.extend_seq_lens)\n        num_prefill_tokens = forward_batch.extend_num_tokens\n        num_decodes = len(forward_batch.seq_lens) - num_prefills\n        context_lens_tensor = forward_batch.extend_prefix_lens\n        assert context_lens_tensor is not None\n        # precompute flag to avoid device syncs later\n        has_initial_states = context_lens_tensor > 0\n        prep_initial_states = torch.any(has_initial_states[:num_prefills]).item()\n\n        query_start_loc = forward_metadata.query_start_loc[: num_prefills + 1]\n 
       seq_idx = torch.repeat_interleave(\n            torch.arange(\n                num_prefills, dtype=torch.int32, device=query_start_loc.device\n            ),\n            query_start_loc.diff(),\n            output_size=num_prefill_tokens,\n        )\n        seq_idx.unsqueeze_(0)\n\n        # We compute metadata for chunked prefill once at the top level model\n        # forward and reuse them in mamba layers. If not needed, they will be\n        # ignored inside mamba kernels.\n        chunk_offsets, chunk_indices = None, None\n        if prep_initial_states:\n            chunk_indices, chunk_offsets = (\n                cls._query_start_loc_to_chunk_indices_offsets(\n                    query_start_loc, chunk_size, num_prefill_tokens\n                )\n            )\n\n        draft_token_num = (\n            getattr(forward_batch.spec_info, \"draft_token_num\", 1)\n            if forward_batch.spec_info is not None\n            else 1\n        )\n        return Mamba2Metadata(\n            query_start_loc=query_start_loc,\n            mamba_cache_indices=forward_metadata.mamba_cache_indices,\n            retrieve_next_token=forward_metadata.retrieve_next_token,\n            retrieve_next_sibling=forward_metadata.retrieve_next_sibling,\n            retrieve_parent_token=forward_metadata.retrieve_parent_token,\n            num_prefills=num_prefills,\n            num_prefill_tokens=num_prefill_tokens,\n            num_decodes=num_decodes,\n            is_target_verify=forward_batch.forward_mode.is_target_verify(),\n            draft_token_num=draft_token_num,\n            mixed_metadata=cls.MixedMetadata(\n                has_initial_states=has_initial_states,\n                prep_initial_states=prep_initial_states,\n                chunk_size=chunk_size,\n                seq_idx=seq_idx,\n                chunk_indices=chunk_indices,\n                chunk_offsets=chunk_offsets,\n                extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,\n      
      ),\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py",
    "content": "\"\"\"\nFused Triton kernel for Mamba state scatter operations.\n\nThis kernel replaces the expensive advanced indexing operations in\n`update_mamba_state_after_mtp_verify` with a single fused gather-scatter kernel,\navoiding multiple `index_elementwise_kernel` launches.\n\"\"\"\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fused_mamba_state_scatter_with_mask_kernel(\n    src_ptr,\n    dst_ptr,\n    # Raw index arrays (before index_select)\n    dst_indices_raw_ptr,  # [total_requests] - state_indices_tensor\n    step_indices_raw_ptr,  # [total_requests] - accepted_steps or mamba_steps_to_track\n    # Total number of requests\n    total_requests,\n    elem_per_entry: tl.constexpr,\n    src_layer_stride,\n    src_req_stride,\n    src_step_stride,\n    dst_layer_stride,\n    dst_req_stride,\n    src_req_size,\n    src_step_size,\n    dst_req_size,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    Fused gather-scatter kernel with built-in masking.\n\n    This kernel fuses the index_select operations by:\n    1. Iterating over all requests (pid_req from 0 to total_requests-1)\n    2. Checking if step_indices_raw[pid_req] >= 0 (valid mask)\n    3. 
If valid, performing the scatter:\n       dst[l, dst_indices_raw[pid_req], :] = src[l, pid_req, step_indices_raw[pid_req], :]\n\n    Grid: (total_requests, num_layers, ceil(elem_per_entry / BLOCK_SIZE))\n    \"\"\"\n    pid_req = tl.program_id(0)\n    pid_layer = tl.program_id(1).to(tl.int64)\n    pid_block = tl.program_id(2).to(tl.int64)\n\n    # Load step index to check validity (step >= 0 means valid)\n    step_idx = tl.load(step_indices_raw_ptr + pid_req).to(tl.int64)\n\n    # Early exit if this request is not valid (step < 0)\n    if step_idx < 0:\n        return\n\n    # Load destination index\n    dst_idx = tl.load(dst_indices_raw_ptr + pid_req).to(tl.int64)\n\n    # Source index is just the request index itself\n    src_idx = pid_req\n\n    # Bounds check to avoid illegal memory access\n    if not (\n        (dst_idx >= 0)\n        & (dst_idx < dst_req_size)\n        & (src_idx >= 0)\n        & (src_idx < src_req_size)\n        & (step_idx < src_step_size)\n    ):\n        return\n\n    # Compute base offsets\n    src_offset = (\n        pid_layer * src_layer_stride\n        + src_idx * src_req_stride\n        + step_idx * src_step_stride\n    )\n    dst_offset = pid_layer * dst_layer_stride + dst_idx * dst_req_stride\n\n    # Compute element range for this block\n    start = pid_block * BLOCK_SIZE\n    offsets = start + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < elem_per_entry\n\n    # Load from source and store to destination\n    data = tl.load(src_ptr + src_offset + offsets, mask=mask)\n    tl.store(dst_ptr + dst_offset + offsets, data, mask=mask)\n\n\ndef fused_mamba_state_scatter_with_mask(\n    dst: torch.Tensor,  # [num_layers, cache_size, *state_shape]\n    src: torch.Tensor,  # [num_layers, spec_size, draft_tokens, *state_shape]\n    dst_indices_raw: torch.Tensor,  # [total_requests] - raw indices (e.g., state_indices_tensor)\n    step_indices_raw: torch.Tensor,  # [total_requests] - raw step indices (step >= 0 means valid)\n):\n    \"\"\"\n   
 Fully fused gather-scatter with built-in masking for mamba state updates.\n\n    This function fuses the following operations into a single kernel:\n    1. valid_mask = step_indices_raw >= 0\n    2. valid_indices = valid_mask.nonzero()\n    3. dst_indices = dst_indices_raw[valid_indices]  (index_select)\n    4. step_indices = step_indices_raw[valid_indices]  (index_select)\n    5. for each valid i: dst[:, dst_indices[i], :] = src[:, i, step_indices[i], :]\n\n    Args:\n        dst: Destination tensor [num_layers, cache_size, *state_shape]\n        src: Source tensor [num_layers, spec_size, draft_tokens, *state_shape]\n        dst_indices_raw: Raw destination indices for all requests [total_requests]\n        step_indices_raw: Raw step indices; entry >= 0 means valid [total_requests]\n    \"\"\"\n    total_requests = step_indices_raw.shape[0]\n    if total_requests == 0:\n        return\n\n    if dst.device != src.device:\n        raise ValueError(\n            f\"dst and src must be on the same device. 
{dst.device=} {src.device=}\"\n        )\n    if not dst.is_cuda or not src.is_cuda:\n        raise ValueError(\n            \"fused_mamba_state_scatter_with_mask only supports CUDA tensors.\"\n        )\n    if dst.ndim < 2 or src.ndim < 3:\n        raise ValueError(f\"Unexpected tensor ranks: {dst.ndim=} {src.ndim=}\")\n    if dst.shape[0] != src.shape[0]:\n        raise ValueError(\n            f\"Layer dimension mismatch: {dst.shape[0]=} vs {src.shape[0]=}\"\n        )\n    if dst.shape[2:] != src.shape[3:]:\n        raise ValueError(\n            f\"Trailing dims mismatch: {dst.shape[2:]=} vs {src.shape[3:]=}\"\n        )\n    if dst_indices_raw.ndim != 1 or step_indices_raw.ndim != 1:\n        raise ValueError(\n            f\"indices must be 1D: {dst_indices_raw.shape=} {step_indices_raw.shape=}\"\n        )\n    if dst_indices_raw.shape[0] != step_indices_raw.shape[0]:\n        raise ValueError(\n            f\"indices length mismatch: {dst_indices_raw.shape[0]=} vs {step_indices_raw.shape[0]=}\"\n        )\n\n    num_layers = dst.shape[0]\n    src_req_size = src.shape[1]\n    src_step_size = src.shape[2]\n    dst_req_size = dst.shape[1]\n\n    # Flatten trailing dimensions: number of elements per (layer, cache_line) entry.\n    elem_per_entry = dst.numel() // (dst.shape[0] * dst.shape[1])\n\n    # Get strides (in elements, not bytes)\n    src_layer_stride = src.stride(0)\n    src_req_stride = src.stride(1)\n    src_step_stride = src.stride(2)\n    dst_layer_stride = dst.stride(0)\n    dst_req_stride = dst.stride(1)\n\n    # Ensure indices are int32 and contiguous\n    dst_indices_raw = dst_indices_raw.to(torch.int32).contiguous()\n    step_indices_raw = step_indices_raw.to(torch.int32).contiguous()\n\n    # Ensure tensors are contiguous\n    if not dst.is_contiguous():\n        raise ValueError(\"dst tensor must be contiguous\")\n    if not src.is_contiguous():\n        raise ValueError(\"src tensor must be contiguous\")\n\n    # Block size for copying 
elements\n    BLOCK_SIZE = 1024\n\n    # Grid over all requests - invalid ones will early-exit in the kernel\n    grid = (total_requests, num_layers, triton.cdiv(elem_per_entry, BLOCK_SIZE))\n\n    _fused_mamba_state_scatter_with_mask_kernel[grid](\n        src,\n        dst,\n        dst_indices_raw,\n        step_indices_raw,\n        total_requests,\n        elem_per_entry,\n        src_layer_stride,\n        src_req_stride,\n        src_step_stride,\n        dst_layer_stride,\n        dst_req_stride,\n        src_req_size,\n        src_step_size,\n        dst_req_size,\n        BLOCK_SIZE=BLOCK_SIZE,\n    )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py",
    "content": "from typing import Union\n\nimport torch\n\nfrom sglang.srt.distributed.communication_op import (\n    tensor_model_parallel_all_gather,\n    tensor_model_parallel_all_reduce,\n)\nfrom sglang.srt.distributed.parallel_state import (\n    get_tensor_model_parallel_rank,\n    get_tensor_model_parallel_world_size,\n)\nfrom sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated\nfrom sglang.srt.layers.utils import MultiPlatformOp\nfrom sglang.srt.model_loader.weight_utils import sharded_weight_loader\nfrom sglang.srt.utils.common import set_weight_attrs\n\n\nclass Mixer2RMSNormGated(MultiPlatformOp):\n    def __init__(\n        self,\n        full_hidden_size: int,\n        full_n_groups: int,\n        use_rms_norm: bool = True,\n        eps: float = 1e-6,\n    ):\n        super().__init__()\n        self.tp_size = get_tensor_model_parallel_world_size()\n        self.tp_rank = get_tensor_model_parallel_rank()\n        self.full_hidden_size = full_hidden_size\n        self.group_size = full_hidden_size // full_n_groups\n        self.per_rank_hidden_size = full_hidden_size // self.tp_size\n        self.n_groups = full_hidden_size // self.group_size\n\n        self.variance_epsilon = eps\n        self.use_rms_norm = use_rms_norm\n        if self.use_rms_norm:\n            # Register norm weight only if we're actually applying RMSNorm\n            self.weight = torch.nn.Parameter(torch.ones(self.per_rank_hidden_size))\n            set_weight_attrs(self.weight, {\"weight_loader\": sharded_weight_loader(0)})\n        else:\n            # Avoid checkpoint mismatch by skipping unused parameter\n            self.register_parameter(\"weight\", None)\n        assert (\n            self.full_hidden_size % self.tp_size == 0\n        ), \"Tensor parallel world size must divide hidden size.\"\n\n    def forward_native(\n        self,\n        x: torch.Tensor,\n        gate: torch.Tensor,\n    ):\n        # Three tensor-parallel cases:\n        #   1. 
n_groups is 1\n        #      In this case we parallelize along the reduction dim.\n        #      Each rank computes a local sum of squares followed by AllReduce\n        #   2. tp_size divides n_groups\n        #      Each rank only reduces within its local group(s).\n        #      No collective ops necessary.\n        #   3. The general case can be pretty complicated so we AllGather\n        #      the input and then redundantly compute the RMSNorm.\n        input_dtype = x.dtype\n        x = x * torch.nn.functional.silu(gate.to(torch.float32))\n        if not self.use_rms_norm:\n            return x.to(input_dtype)\n\n        if self.n_groups == 1:\n            if self.tp_size > 1:\n                # Compute local sum and then reduce to obtain global sum\n                local_sums = x.pow(2).sum(dim=-1, keepdim=True)\n                global_sums = tensor_model_parallel_all_reduce(local_sums)\n                # Calculate the variance\n                count = self.tp_size * x.shape[-1]\n                variance = global_sums / count\n\n            else:\n                variance = x.pow(2).mean(-1, keepdim=True)\n            x = x * torch.rsqrt(variance + self.variance_epsilon)\n        else:\n            redundant_tp: bool = self.n_groups % self.tp_size != 0\n            if redundant_tp:\n                # To handle the general case, redundantly apply the variance\n                x = tensor_model_parallel_all_gather(x, -1)\n\n            *prefix_dims, hidden_dim = x.shape\n            group_count = hidden_dim // self.group_size\n            x_grouped = x.view(*prefix_dims, group_count, self.group_size)\n            variance = x_grouped.pow(2).mean(-1, keepdim=True)\n            x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)\n            x = x_grouped.view(*prefix_dims, hidden_dim)\n\n            if redundant_tp:\n                start = self.per_rank_hidden_size * self.tp_rank\n                end = start + self.per_rank_hidden_size\n   
             x = x[..., start:end]\n\n        return self.weight * x.to(input_dtype)\n\n    def forward_cuda(\n        self,\n        x: torch.Tensor,\n        gate: torch.Tensor,\n    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:\n        input_dtype = x.dtype\n        if not self.use_rms_norm:\n            # Keep gate in float32 for numerical stability during silu\n            return x * torch.nn.functional.silu(gate.to(torch.float32)).to(input_dtype)\n\n        if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:\n            return self.forward_native(x, gate)\n\n        return rms_norm_gated(\n            x=x,\n            weight=self.weight.data,\n            bias=None,\n            z=gate,\n            eps=self.variance_epsilon,\n            norm_before_gate=False,\n            is_rms_norm=True,\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/ops/__init__.py",
    "content": "from .mamba_ssm import PAD_SLOT_ID\nfrom .ssd_combined import mamba_chunk_scan_combined\nfrom .ssu_dispatch import (\n    initialize_mamba_selective_state_update_backend,\n    selective_state_update,\n)\n\n__all__ = [\n    \"PAD_SLOT_ID\",\n    \"selective_state_update\",\n    \"mamba_chunk_scan_combined\",\n    \"initialize_mamba_selective_state_update_backend\",\n]\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/ops/layernorm_gated.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n# Copyright (c) 2024, Tri Dao.\n# Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n    X,  # pointer to the input\n    Y,  # pointer to the output\n    W,  # pointer to the weights\n    B,  # pointer to the biases\n    Z,  # pointer to the other branch\n    Mean,  # pointer to the mean\n    Rstd,  # pointer to the 1/std\n    stride_x_row: tl.int64,\n    stride_y_row: tl.int64,\n    stride_z_row: tl.int64,\n    M: tl.int64,  # number of rows in X\n    N: tl.int64,  # number of columns in X\n    eps,  # epsilon to avoid division by zero\n    BLOCK_N: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n    HAS_Z: tl.constexpr,\n    NORM_BEFORE_GATE: tl.constexpr,\n    IS_RMS_NORM: tl.constexpr,\n):\n    # Map the program id to the row of X and Y it should compute.\n    row = tl.program_id(0)\n    group = tl.program_id(1)\n    X += row * stride_x_row + group * N\n    Y += row * stride_y_row + group * N\n    if HAS_Z:\n        Z += row * stride_z_row + group * N\n    if not IS_RMS_NORM:\n        Mean += group * M\n    Rstd += group * M\n    W += group * N\n    if HAS_BIAS:\n        B += group * N\n    # Compute mean and variance\n    cols = tl.arange(0, BLOCK_N)\n    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n    if HAS_Z and not NORM_BEFORE_GATE:\n        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)\n        x *= z * tl.sigmoid(z)\n    if not IS_RMS_NORM:\n        mean = tl.sum(x, axis=0) / N\n        tl.store(Mean + row, mean)\n        xbar = tl.where(cols < N, x - mean, 0.0)\n        
var = tl.sum(xbar * xbar, axis=0) / N\n    else:\n        xbar = tl.where(cols < N, x, 0.0)\n        var = tl.sum(xbar * xbar, axis=0) / N\n    rstd = 1 / tl.sqrt(var + eps)\n    tl.store(Rstd + row, rstd)\n    # Normalize and apply linear transformation\n    mask = cols < N\n    w = tl.load(W + cols, mask=mask).to(tl.float32)\n    if HAS_BIAS:\n        b = tl.load(B + cols, mask=mask).to(tl.float32)\n    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n    y = x_hat * w + b if HAS_BIAS else x_hat * w\n    if HAS_Z and NORM_BEFORE_GATE:\n        z = tl.load(Z + cols, mask=mask).to(tl.float32)\n        y *= z * tl.sigmoid(z)\n    # Write output\n    tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(\n    x,\n    weight,\n    bias,\n    eps,\n    z=None,\n    out=None,\n    group_size=None,\n    norm_before_gate=True,\n    is_rms_norm=False,\n):\n    M, N = x.shape\n    if group_size is None:\n        group_size = N\n    assert N % group_size == 0\n    ngroups = N // group_size\n    assert x.stride(-1) == 1\n    if z is not None:\n        assert z.stride(-1) == 1\n        assert z.shape == (M, N)\n    assert weight.shape == (N,)\n    assert weight.stride(-1) == 1\n    if bias is not None:\n        assert bias.stride(-1) == 1\n        assert bias.shape == (N,)\n    # allocate output\n    if out is not None:\n        assert out.shape == x.shape\n    else:\n        out = torch.empty_like(x)\n    assert out.stride(-1) == 1\n    mean = (\n        torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)\n        if not is_rms_norm\n        else None\n    )\n    rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)\n    # Less than 64KB per feature: enqueue fused kernel\n    MAX_FUSED_SIZE = 65536 // x.element_size()\n    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n    if group_size > BLOCK_N:\n        raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n    # heuristics for 
number of warps\n    num_warps = min(max(BLOCK_N // 256, 1), 8)\n    grid = (M, ngroups)\n    with torch.cuda.device(x.device.index):\n        _layer_norm_fwd_1pass_kernel[grid](\n            x,\n            out,\n            weight,\n            bias,\n            z,\n            mean,\n            rstd,\n            x.stride(0),\n            out.stride(0),\n            z.stride(0) if z is not None else 0,\n            M,\n            group_size,\n            eps,\n            BLOCK_N=BLOCK_N,\n            NORM_BEFORE_GATE=norm_before_gate,\n            IS_RMS_NORM=is_rms_norm,\n            num_warps=num_warps,\n        )\n    return out, mean, rstd\n\n\ndef rms_norm_gated(\n    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True\n):\n    x_shape_og = x.shape\n    # reshape input data into 2D tensor\n    x = x.reshape(-1, x.shape[-1])\n    if x.stride(-1) != 1:\n        x = x.contiguous()\n    if z is not None:\n        assert z.shape == x_shape_og\n        z = z.reshape(-1, z.shape[-1])\n        if z.stride(-1) != 1:\n            z = z.contiguous()\n    weight = weight.contiguous()\n    if bias is not None:\n        bias = bias.contiguous()\n    y, _, _ = _layer_norm_fwd(\n        x,\n        weight,\n        bias,\n        eps,\n        z=z,\n        group_size=group_size,\n        norm_before_gate=norm_before_gate,\n        is_rms_norm=True,\n    )\n\n    return y.reshape(x_shape_og)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/ops/mamba_ssm.py",
    "content": "# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/mamba_ssm.py\n\n# SPDX-License-Identifier: Apache-2.0\n# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n\n# Copyright (c) 2024, Tri Dao, Albert Gu.\n# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\n\nPAD_SLOT_ID = -1\n\nTRITON3 = version.parse(triton.__version__) >= version.parse(\"3.0.0\")\n\nif TRITON3:\n\n    @triton.jit\n    def softplus(dt):\n        dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)\n        return dt\n\nelse:\n\n    @triton.jit\n    def softplus(dt):\n        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n        return dt\n\n\n@triton.heuristics({\"HAS_DT_BIAS\": lambda args: args[\"dt_bias_ptr\"] is not None})\n@triton.heuristics({\"HAS_D\": lambda args: args[\"D_ptr\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"z_ptr\"] is not None})\n@triton.heuristics(\n    {\n        \"HAS_STATE_BATCH_INDICES\": lambda args: args[\"state_batch_indices_ptr\"]\n        is not None\n    }\n)\n@triton.heuristics(\n    {\"BLOCK_SIZE_DSTATE\": lambda args: triton.next_power_of_2(args[\"dstate\"])}\n)\n@triton.heuristics(\n    {\n        \"CACHE_INTERMEDIATE_STATES\": lambda args: args[\"intermediate_states_buffer\"]\n        is not None\n    }\n)\n@triton.heuristics(\n    {\n        \"HAS_EAGLE_TREE_CUSTOM_ATTN_MASK\": lambda args: args[\n            \"retrieve_parent_token_ptr\"\n        ]\n        is not None\n    }\n)\n@triton.heuristics(\n    {\n        \"HAS_INTERMEDIATE_STATE_INDICES\": lambda args: args[\n            \"intermediate_state_indices_ptr\"\n        ]\n        is not None\n    }\n)\n@triton.jit(do_not_specialize=[\"T\"])\ndef _selective_scan_update_kernel(\n    # Pointers to matrices\n    state_ptr,\n    
x_ptr,\n    dt_ptr,\n    dt_bias_ptr,\n    A_ptr,\n    B_ptr,\n    C_ptr,\n    D_ptr,\n    z_ptr,\n    out_ptr,\n    state_batch_indices_ptr,\n    pad_slot_id,\n    intermediate_states_buffer,\n    cache_steps,\n    retrieve_parent_token_ptr,\n    intermediate_state_indices_ptr,\n    # Matrix dimensions\n    batch,\n    T,\n    nheads,\n    dim,\n    dstate,\n    nheads_ngroups_ratio,\n    # Strides\n    stride_state_batch,\n    stride_state_head,\n    stride_state_dim,\n    stride_state_dstate,\n    stride_x_batch,\n    stride_x_T,\n    stride_x_head,\n    stride_x_dim,\n    stride_dt_batch,\n    stride_dt_T,\n    stride_dt_head,\n    stride_dt_dim,\n    stride_dt_bias_head,\n    stride_dt_bias_dim,\n    stride_A_head,\n    stride_A_dim,\n    stride_A_dstate,\n    stride_B_batch,\n    stride_B_T,\n    stride_B_group,\n    stride_B_dstate,\n    stride_C_batch,\n    stride_C_T,\n    stride_C_group,\n    stride_C_dstate,\n    stride_D_head,\n    stride_D_dim,\n    stride_z_batch,\n    stride_z_T,\n    stride_z_head,\n    stride_z_dim,\n    stride_out_batch,\n    stride_out_T,\n    stride_out_head,\n    stride_out_dim,\n    stride_retrieve_parent_token_batch,\n    stride_retrieve_parent_token_T,\n    # Meta-parameters\n    DT_SOFTPLUS: tl.constexpr,\n    TIE_HDIM: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr,\n    HAS_DT_BIAS: tl.constexpr,\n    HAS_D: tl.constexpr,\n    HAS_Z: tl.constexpr,\n    HAS_STATE_BATCH_INDICES: tl.constexpr,\n    DISABLE_STATE_UPDATE: tl.constexpr,\n    CACHE_INTERMEDIATE_STATES: tl.constexpr,\n    HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr,\n    HAS_INTERMEDIATE_STATE_INDICES: tl.constexpr,\n    BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n    pid_m = tl.program_id(axis=0)\n    pid_b = tl.program_id(axis=1)\n    pid_h = tl.program_id(axis=2)\n\n    # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate\n    # is taken from the state_batch_indices_ptr Otherwise, the state coordinate\n    # is the same as the batch id.\n    
if HAS_STATE_BATCH_INDICES:\n        state_batch_indices_ptr += pid_b\n        state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)\n        state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head\n    else:\n        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head\n\n    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head\n    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head\n    if HAS_DT_BIAS:\n        dt_bias_ptr += pid_h * stride_dt_bias_head\n    A_ptr += pid_h * stride_A_head\n    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group\n    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group\n    if HAS_Z:\n        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head\n    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n    state_ptrs = state_ptr + (\n        offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate\n    )\n\n    mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)\n    if HAS_STATE_BATCH_INDICES:\n        mask &= state_batch_idx != pad_slot_id\n    state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32)\n\n    if HAS_DT_BIAS:\n        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n    if HAS_D:\n        D_ptr += pid_h * stride_D_head\n        D_ptrs = D_ptr + offs_m * stride_D_dim\n    A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate\n\n    cache_idx = -1\n    if CACHE_INTERMEDIATE_STATES:\n        if HAS_INTERMEDIATE_STATE_INDICES:\n            intermediate_state_idx = tl.load(intermediate_state_indices_ptr + pid_b).to(\n                tl.int64\n            )\n            cache_idx = intermediate_state_idx\n        elif HAS_STATE_BATCH_INDICES:\n            cache_idx = state_batch_idx\n        else:\n            
cache_idx = pid_b\n\n    current_step_idx = 0\n    for _ in range(T):\n        if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:\n            if current_step_idx != 0 and cache_idx >= 0:\n                parent_ptr = (\n                    retrieve_parent_token_ptr\n                    + pid_b * stride_retrieve_parent_token_batch\n                    + current_step_idx * stride_retrieve_parent_token_T\n                )\n                parent_step_idx = tl.load(parent_ptr).to(tl.int32)\n\n                if parent_step_idx >= 0 and parent_step_idx < T:\n                    step_offset = parent_step_idx * nheads * dim * dstate\n                    cache_ptr = (\n                        intermediate_states_buffer\n                        + cache_idx * cache_steps * nheads * dim * dstate\n                        + step_offset\n                        + pid_h * dim * dstate\n                        + offs_m[:, None] * dstate\n                        + offs_n[None, :]\n                    )\n                    state = tl.load(cache_ptr, mask=mask, other=0.0).to(tl.float32)\n\n        x_ptrs = x_ptr + offs_m * stride_x_dim\n        dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n        B_ptrs = B_ptr + offs_n * stride_B_dstate\n        C_ptrs = C_ptr + offs_n * stride_C_dstate\n        if HAS_Z:\n            z_ptrs = z_ptr + offs_m * stride_z_dim\n        out_ptrs = out_ptr + offs_m * stride_out_dim\n\n        x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n        if not TIE_HDIM:\n            dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n            if HAS_DT_BIAS:\n                dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n            if DT_SOFTPLUS:\n                dt = softplus(dt)\n            A = tl.load(\n                A_ptrs,\n                mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),\n                other=0.0,\n            ).to(tl.float32)\n            dA = tl.exp(A * dt[:, None])\n      
  else:\n            dt = tl.load(dt_ptr).to(tl.float32)\n            if HAS_DT_BIAS:\n                dt += tl.load(dt_bias_ptr).to(tl.float32)\n            if DT_SOFTPLUS:\n                dt = softplus(dt)\n            A = tl.load(A_ptr).to(tl.float32)\n            dA = tl.exp(A * dt)  # scalar, not a matrix\n\n        B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n        C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n        if HAS_D:\n            D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n        if HAS_Z:\n            z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n        dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt\n        state = state * dA + dB * x[:, None]\n\n        if CACHE_INTERMEDIATE_STATES:\n            if HAS_STATE_BATCH_INDICES:\n                if state_batch_idx != pad_slot_id:\n                    cache_ptr_base = (\n                        intermediate_states_buffer\n                        + cache_idx * cache_steps * nheads * dim * dstate\n                        + current_step_idx * nheads * dim * dstate\n                        + pid_h * dim * dstate\n                    )\n                    cache_ptrs = cache_ptr_base + (\n                        offs_m[:, None] * dstate + offs_n[None, :]\n                    )\n                    tl.store(\n                        cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask\n                    )\n\n        out = tl.sum(state * C[None, :], axis=1)\n        if HAS_D:\n            out += x * D\n        if HAS_Z:\n            out *= z * tl.sigmoid(z)\n        tl.store(out_ptrs, out, mask=offs_m < dim)\n\n        current_step_idx += 1\n\n        x_ptr += stride_x_T\n        dt_ptr += stride_dt_T\n        B_ptr += stride_B_T\n        C_ptr += stride_C_T\n        out_ptr += stride_out_T\n        if HAS_Z:\n            z_ptr += stride_z_T\n\n    if not DISABLE_STATE_UPDATE:\n        
tl.store(state_ptrs, state.to(state_ptrs.dtype.element_ty), mask=mask)\n\n\ndef selective_state_update(\n    state,\n    x,\n    dt,\n    A,\n    B,\n    C,\n    D=None,\n    z=None,\n    dt_bias=None,\n    dt_softplus=False,\n    state_batch_indices=None,\n    pad_slot_id=PAD_SLOT_ID,\n    out=None,\n    disable_state_update=False,\n    intermediate_states_buffer=None,\n    cache_steps=None,\n    retrieve_parent_token=None,\n    intermediate_state_indices=None,\n):\n    \"\"\"\n    Argument:\n        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)\n        x: (batch, dim) or (batch, nheads, dim) for single-token or (batch, T, nheads, dim) for multi-token\n        dt: (batch, dim) or (batch, nheads, dim)\n        A: (dim, dstate) or (nheads, dim, dstate)\n        B: (batch, dstate) or (batch, ngroups, dstate) for single-token or (batch, T, ngroups, dstate) for multi-token\n        C: (batch, dstate) or (batch, ngroups, dstate)\n        D: (dim,) or (nheads, dim)\n        z: (batch, dim) or (batch, nheads, dim)\n        dt_bias: (dim,) or (nheads, dim)\n        pad_slot_id: int\n            if state_batch_indices is passed, lets the kernel identify padded\n            entries that will not be processed,\n            for example: state_batch_indices = [pad_slot_id, 1, 20, pad_slot_id]\n            in this case, the kernel will not process entries at\n            indices 0 and 3\n        out: Preallocated ssm output tensor. 
Assume same shape as x.\n             In-place updated.\n        disable_state_update: If True, don't write back to state (for speculative verify)\n        intermediate_states_buffer: Buffer to cache intermediate states\n        cache_steps: Total number of steps in the buffer\n        retrieve_parent_token: (batch, T) tensor of parent token indices for EAGLE tree attention\n        intermediate_state_indices: (batch,) tensor of indices for intermediate_states_buffer operations.\n            If provided, uses these indices instead of state_batch_indices for the buffer.\n    \"\"\"\n    if state.dim() == 3:\n        state = state.unsqueeze(1)\n    if x.dim() == 2:\n        x = x.unsqueeze(1)\n    if x.dim() == 3:\n        x = x.unsqueeze(1)\n    if dt.dim() == 2:\n        dt = dt.unsqueeze(1)\n    if dt.dim() == 3:\n        dt = dt.unsqueeze(1)\n    if A.dim() == 2:\n        A = A.unsqueeze(0)\n    if B.dim() == 2:\n        B = B.unsqueeze(1)\n    if B.dim() == 3:\n        B = B.unsqueeze(1)\n    if C.dim() == 2:\n        C = C.unsqueeze(1)\n    if C.dim() == 3:\n        C = C.unsqueeze(1)\n    if D is not None and D.dim() == 1:\n        D = D.unsqueeze(0)\n    if z is not None:\n        if z.dim() == 2:\n            z = z.unsqueeze(1)\n        if z.dim() == 3:\n            z = z.unsqueeze(1)\n    if dt_bias is not None and dt_bias.dim() == 1:\n        dt_bias = dt_bias.unsqueeze(0)\n    if out.dim() == 2:\n        out = out.unsqueeze(1)\n    if out.dim() == 3:\n        out = out.unsqueeze(1)\n\n    _, nheads, dim, dstate = state.shape\n    batch, T, _, _ = x.shape\n\n    assert x.shape == (batch, T, nheads, dim)\n    assert dt.shape == x.shape\n    assert A.shape == (nheads, dim, dstate)\n    ngroups = B.shape[2]\n    assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n    assert B.shape == (batch, T, ngroups, dstate)\n    assert C.shape == B.shape\n    if D is not None:\n        assert D.shape == (nheads, dim)\n    if z is not None:\n        
assert z.shape == x.shape\n    if dt_bias is not None:\n        assert dt_bias.shape == (nheads, dim)\n    if state_batch_indices is not None:\n        assert state_batch_indices.shape == (batch,)\n    assert out.shape == x.shape\n\n    grid = lambda META: (triton.cdiv(dim, META[\"BLOCK_SIZE_M\"]), batch, nheads)\n    z_strides = (\n        (z.stride(0), z.stride(1), z.stride(2), z.stride(3))\n        if z is not None\n        else (0, 0, 0, 0)\n    )\n    # We don't want autotune since it will overwrite the state\n    # We instead tune by hand.\n    BLOCK_SIZE_M, num_warps = (\n        (32, 4)\n        if dstate <= 16\n        else (\n            (16, 4)\n            if dstate <= 32\n            else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))\n        )\n    )\n    # dt_bias is optional; only require a broadcast (zero) stride when it is present\n    tie_hdim = (\n        A.stride(-1) == 0\n        and A.stride(-2) == 0\n        and dt.stride(-1) == 0\n        and (dt_bias is None or dt_bias.stride(-1) == 0)\n    )\n\n    retrieve_parent_token_strides = (\n        (retrieve_parent_token.stride(0), retrieve_parent_token.stride(1))\n        if retrieve_parent_token is not None\n        else (0, 0)\n    )\n\n    with torch.cuda.device(x.device.index):\n        _selective_scan_update_kernel[grid](\n            state,\n            x,\n            dt,\n            dt_bias,\n            A,\n            B,\n            C,\n            D,\n            z,\n            out,\n            state_batch_indices,\n            pad_slot_id,\n            intermediate_states_buffer,\n            cache_steps if cache_steps is not None else 0,\n            retrieve_parent_token,\n            intermediate_state_indices,\n            batch,\n            T,\n            nheads,\n            dim,\n            dstate,\n            nheads // ngroups,\n            state.stride(0),\n            state.stride(1),\n            state.stride(2),\n            state.stride(3),\n            x.stride(0),\n            x.stride(1),\n            x.stride(2),\n            x.stride(3),\n 
           dt.stride(0),\n            dt.stride(1),\n            dt.stride(2),\n            dt.stride(3),\n            # Always pass two stride values so the positional arguments stay aligned\n            *((dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0)),\n            A.stride(0),\n            A.stride(1),\n            A.stride(2),\n            B.stride(0),\n            B.stride(1),\n            B.stride(2),\n            B.stride(3),\n            C.stride(0),\n            C.stride(1),\n            C.stride(2),\n            C.stride(3),\n            *((D.stride(0), D.stride(1)) if D is not None else (0, 0)),\n            z_strides[0],\n            z_strides[1],\n            z_strides[2],\n            z_strides[3],\n            out.stride(0),\n            out.stride(1),\n            out.stride(2),\n            out.stride(3),\n            retrieve_parent_token_strides[0],\n            retrieve_parent_token_strides[1],\n            dt_softplus,\n            tie_hdim,\n            BLOCK_SIZE_M,\n            DISABLE_STATE_UPDATE=disable_state_update,\n            num_warps=num_warps,\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/ops/ssd_bmm.py",
    "content": "# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_bmm.py\n\n# SPDX-License-Identifier: Apache-2.0\n# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n\n# Copyright (c) 2024, Tri Dao, Albert Gu.\n# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py\n\n# ruff: noqa: E501,SIM102\n\nimport math\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _bmm_chunk_fwd_kernel(\n    # Pointers to matrices\n    a_ptr,\n    b_ptr,\n    out_ptr,\n    seq_idx_ptr,\n    # Matrix dimensions\n    seqlen,\n    chunk_size,\n    K,\n    ngroups,\n    stride_a_batch,\n    stride_a_seqlen,\n    stride_a_head,\n    stride_ak,\n    stride_b_batch,\n    stride_b_seqlen,\n    stride_b_head,\n    stride_bk,\n    stride_out_batch,\n    stride_out_chunk,\n    stride_out_head,\n    stride_outm,\n    stride_outn,\n    stride_seq_idx_batch,\n    stride_seq_idx_seqlen,\n    # Meta-parameters\n    IS_CAUSAL: tl.constexpr,\n    dot_dtype: tl.constexpr,\n    HAS_SEQ_IDX: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr = 16,\n    BLOCK_SIZE_N: tl.constexpr = 16,\n    BLOCK_SIZE_K: tl.constexpr = 16,\n):\n    pid_b = tl.program_id(axis=1)\n    pid_ch = tl.program_id(axis=2).to(tl.int64)\n    pid_c = pid_ch // ngroups\n    pid_h = pid_ch - pid_c * ngroups\n    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    if IS_CAUSAL:\n        if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:\n            return\n    a_ptr += (\n        pid_b * stride_a_batch\n        + pid_c * chunk_size * stride_a_seqlen\n        + pid_h * stride_a_head\n    )\n    b_ptr += (\n        pid_b * stride_b_batch\n        + pid_c * chunk_size * stride_b_seqlen\n        + pid_h * stride_b_head\n    )\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += (\n            pid_b * 
stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n        )\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)\n    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n        a = tl.load(\n            a_ptrs,\n            mask=(offs_m[:, None] < chunk_size_limit)\n            & (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n            other=0.0,\n        ).to(dot_dtype)\n        b = tl.load(\n            b_ptrs,\n            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K)\n            & (offs_n[None, :] < chunk_size_limit),\n            other=0.0,\n        ).to(dot_dtype)\n        acc += tl.dot(a, b)\n        a_ptrs += BLOCK_SIZE_K * stride_ak\n        b_ptrs += BLOCK_SIZE_K * stride_bk\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    if HAS_SEQ_IDX:\n        chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n        seq_idx_m = tl.load(\n            seq_idx_ptr + offs_m * stride_seq_idx_seqlen,\n            mask=offs_m < chunk_size_limit,\n            other=-1,\n        )\n        seq_idx_n = tl.load(\n            seq_idx_ptr + offs_n * stride_seq_idx_seqlen,\n            mask=offs_n < chunk_size_limit,\n            other=-2,\n        )\n        acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)\n    out = acc.to(out_ptr.dtype.element_ty)\n\n    out_ptr += (\n        pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head\n    )\n    out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + 
offs_n[None, :] * stride_outn)\n    tl.store(\n        out_ptrs,\n        out,\n        mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size),\n    )\n\n\ndef _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):\n    \"\"\"\n    Argument:\n        a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)\n        b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)\n        seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.\n        causal: if True, then out[i, j] for i < j will be arbitrary, only out[i, j] for i >= j are\n            guaranteed to be correct.\n    Return:\n        out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)\n    \"\"\"\n    # Check constraints.\n    has_groups = a.dim() == 4\n    if not has_groups:\n        batch, seqlen, k = a.shape\n    else:\n        batch, seqlen, ngroups, k = a.shape\n    assert b.shape == a.shape\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    if a.stride(-1) != 1 and a.stride(1) != 1:\n        a = a.contiguous()\n    if b.stride(-1) != 1 and b.stride(1) != 1:\n        b = b.contiguous()\n    nchunks = math.ceil(seqlen / chunk_size)\n    # Allocates output.\n    out_dtype = a.dtype if output_dtype is None else output_dtype\n    out = torch.empty(\n        (\n            (batch, nchunks, chunk_size, chunk_size)\n            if not has_groups\n            else (batch, nchunks, ngroups, chunk_size, chunk_size)\n        ),\n        device=a.device,\n        dtype=out_dtype,\n    )\n    dot_dtype = (\n        tl.bfloat16\n        if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16\n        else (\n            tl.float16\n            if a.dtype == torch.float16 or b.dtype == torch.float16\n            else tl.float32\n        )\n    )\n    grid = lambda META: (\n        triton.cdiv(chunk_size, META[\"BLOCK_SIZE_M\"])\n        * triton.cdiv(chunk_size, 
META[\"BLOCK_SIZE_N\"]),\n        batch,\n        nchunks if not has_groups else nchunks * ngroups,\n    )\n    with torch.cuda.device(a.device.index):\n        _bmm_chunk_fwd_kernel[grid](\n            a,\n            b,\n            out,\n            seq_idx,\n            seqlen,\n            chunk_size,\n            k,\n            ngroups if has_groups else 1,\n            a.stride(0),\n            a.stride(1),\n            0 if not has_groups else a.stride(2),\n            a.stride(-1),\n            b.stride(0),\n            b.stride(1),\n            0 if not has_groups else b.stride(2),\n            b.stride(-1),\n            out.stride(0),\n            out.stride(1),\n            0 if not has_groups else out.stride(2),\n            out.stride(-2),\n            out.stride(-1),\n            *(\n                (seq_idx.stride(0), seq_idx.stride(1))\n                if seq_idx is not None\n                else (0, 0)\n            ),\n            causal,\n            dot_dtype,\n            HAS_SEQ_IDX=seq_idx is not None,\n        )\n    return out\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py",
    "content": "# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py\n\n# SPDX-License-Identifier: Apache-2.0\n# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n\n# Copyright (c) 2024, Tri Dao, Albert Gu.\n# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py\n\n# ruff: noqa: E501,SIM102\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\n\nTRITON_22 = version.parse(triton.__version__) >= version.parse(\"2.2.0\")\n\n\n@triton.jit\ndef _chunk_scan_fwd_kernel(\n    # Pointers to matrices\n    cb_ptr,\n    x_ptr,\n    z_ptr,\n    out_ptr,\n    out_x_ptr,\n    dt_ptr,\n    dA_cumsum_ptr,\n    seq_idx_ptr,\n    C_ptr,\n    states_ptr,\n    D_ptr,\n    initstates_ptr,\n    chunk_indices_ptr,\n    chunk_offsets_ptr,\n    chunk_meta_num,\n    # Matrix dimensions\n    chunk_size,\n    hdim,\n    dstate,\n    batch,\n    seqlen,\n    nheads_ngroups_ratio,\n    # Strides\n    stride_cb_batch,\n    stride_cb_chunk,\n    stride_cb_head,\n    stride_cb_csize_m,\n    stride_cb_csize_k,\n    stride_x_batch,\n    stride_x_seqlen,\n    stride_x_head,\n    stride_x_hdim,\n    stride_z_batch,\n    stride_z_seqlen,\n    stride_z_head,\n    stride_z_hdim,\n    stride_out_batch,\n    stride_out_seqlen,\n    stride_out_head,\n    stride_out_hdim,\n    stride_dt_batch,\n    stride_dt_chunk,\n    stride_dt_head,\n    stride_dt_csize,\n    stride_dA_cs_batch,\n    stride_dA_cs_chunk,\n    stride_dA_cs_head,\n    stride_dA_cs_csize,\n    stride_seq_idx_batch,\n    stride_seq_idx_seqlen,\n    stride_C_batch,\n    stride_C_seqlen,\n    stride_C_head,\n    stride_C_dstate,\n    stride_states_batch,\n    stride_states_chunk,\n    stride_states_head,\n    stride_states_hdim,\n    stride_states_dstate,\n    stride_init_states_batch,\n    stride_init_states_head,\n    stride_init_states_hdim,\n    stride_init_states_dstate,\n  
  stride_D_head,\n    # Meta-parameters\n    IS_CAUSAL: tl.constexpr,\n    HAS_D: tl.constexpr,\n    D_HAS_HDIM: tl.constexpr,\n    HAS_Z: tl.constexpr,\n    HAS_SEQ_IDX: tl.constexpr,\n    BLOCK_SIZE_DSTATE: tl.constexpr,\n    IS_TRITON_22: tl.constexpr,\n    HAS_INITSTATES: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr = 16,\n    BLOCK_SIZE_N: tl.constexpr = 16,\n    BLOCK_SIZE_K: tl.constexpr = 16,\n):\n    pid_bc = tl.program_id(axis=1).to(tl.int64)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    if not HAS_INITSTATES:\n        c_idx = pid_c\n        c_off = 0\n    else:\n        c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0)\n        c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0)\n\n    pid_h = tl.program_id(axis=2)\n    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    cb_ptr += (\n        pid_b * stride_cb_batch\n        + c_idx * stride_cb_chunk\n        + (pid_h // nheads_ngroups_ratio) * stride_cb_head\n    )\n    x_ptr += (\n        pid_b * stride_x_batch\n        + c_idx * chunk_size * stride_x_seqlen\n        + pid_h * stride_x_head\n    )\n    dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head\n    dA_cumsum_ptr += (\n        pid_b * stride_dA_cs_batch\n        + c_idx * stride_dA_cs_chunk\n        + pid_h * stride_dA_cs_head\n    )\n    C_ptr += (\n        pid_b * stride_C_batch\n        + c_idx * chunk_size * stride_C_seqlen\n        + (pid_h // nheads_ngroups_ratio) * stride_C_head\n    )\n\n    # M-block offsets and prev states\n    #  - logic in next block may override these if there is an active offset\n    offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)\n    prev_states_ptr = (\n        states_ptr\n        + pid_b * stride_states_batch\n        + c_idx * stride_states_chunk\n        + pid_h * stride_states_head\n    )\n    prev_states_hdim 
= stride_states_hdim\n    prev_states_dstate = stride_states_dstate\n\n    chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += (\n            pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen\n        )\n\n        # - we only need seq_idx_prev to be aligned to the chunk boundary\n        seq_idx_prev = tl.load(\n            seq_idx_ptr - stride_seq_idx_seqlen, mask=c_idx >= 1, other=0\n        )\n\n        if HAS_INITSTATES:\n            # if there are init states, we only need seq_idx_m to point\n            # to the current seq_idx\n\n            # get current seq idx\n            if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:\n                seq_idx_m = tl.load(\n                    seq_idx_ptr\n                    + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen,\n                )\n\n                # - recall that in ssd_state_passing, for the case c_off == 0\n                # i.e., the very first sequence, we made states_ptr hold its initial state\n                # so this edge case is taken care of\n                if (\n                    (c_off == 0)\n                    and (\n                        seq_idx_prev != seq_idx_m\n                    )  # if a seq changes exactly on the boundary\n                    or (c_off > 0)  # implies a new example (pseudo chunk)\n                ):\n\n                    # - replace prev_states_ptr with init_states\n                    prev_states_ptr = (\n                        initstates_ptr\n                        + seq_idx_m * stride_init_states_batch\n                        + pid_h * stride_init_states_head\n                    )\n                    prev_states_hdim = stride_init_states_hdim  # override strides\n                    prev_states_dstate = stride_init_states_dstate\n\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    dA_cs_m = tl.load(\n        dA_cumsum_ptr + offs_m * 
stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0\n    ).to(tl.float32)\n\n    # - handle chunk state limit\n    if HAS_INITSTATES:\n\n        # have to split this if otherwise compilation will have problems\n        dA_cs_m_boundary = 0.0\n\n        # get the c_idx for the next (logical) chunk\n        c_idx_n = tl.load(\n            chunk_indices_ptr + (pid_c + 1),\n            mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,\n            other=-1,  # to trigger different chunk\n        )\n\n        # - there are two things to consider\n        # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct\n        #    contribution of past states\n        # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to\n        #    encroach into the next sequence, where c_off_n is the offset of the next\n        #    (logical) chunk.\n        # An equivalent check for B is c_idx == c_idx_n, where there is repetition in\n        # (logical) chunk indices.\n\n        if (c_idx == c_idx_n) or c_off > 0:\n\n            # get the next offset\n            c_off_n = tl.load(\n                chunk_offsets_ptr + (pid_c + 1),\n                mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,\n                other=chunk_size,\n            )\n\n            # in this case, adjust down the chunk_size_limit\n            if c_idx == c_idx_n:\n                chunk_size_limit = min(c_off_n, chunk_size_limit)\n\n            # get the cs at the offset boundary\n            # - c_off == 0 is a passthrough\n            # - We need dA_cs at the boundary, defined by c_off - no need\n            #   to increase pointer by pid_m (it is a constant offset,\n            #   i.e. 
the same for all blocks)\n            dA_cs_m_boundary = tl.load(\n                dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,\n                mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),\n                other=0.0,\n            ).to(tl.float32)\n\n    if HAS_SEQ_IDX:\n        # - handle seq idx when HAS_INITSTATES==False\n        if not HAS_INITSTATES:\n            seq_idx_m = tl.load(\n                seq_idx_ptr + offs_m * stride_seq_idx_seqlen,\n                mask=offs_m < chunk_size_limit,\n                other=-1,\n            )\n\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n    # Without the if (c_idx > -1), with Triton 2.1.0, I get\n    # Assertion `!(srcMmaLayout && dstMmaLayout) && \"Unexpected mma -> mma layout conversion\"' failed.\n    # With Triton 2.2.0, this works\n    if IS_TRITON_22 or c_idx > -1:\n        # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128\n        offs_k_dstate = tl.arange(\n            0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K\n        )\n        C_ptrs = C_ptr + (\n            offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate\n        )\n\n        prev_states_ptrs = prev_states_ptr + (\n            offs_n[None, :] * prev_states_hdim\n            + offs_k_dstate[:, None] * prev_states_dstate\n        )\n        if HAS_SEQ_IDX:\n\n            if not HAS_INITSTATES:\n                # - this is for continuous batching where there are no init states\n                scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)\n            else:\n                # - if there are initstates, we will rely on prev_states, no zeroing\n                #   required.\n                scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)\n        else:\n            scale_m = tl.exp(dA_cs_m)\n        if BLOCK_SIZE_DSTATE <= 128:\n            C = tl.load(\n                C_ptrs,\n                mask=(offs_m[:, None] < 
chunk_size_limit)\n                & (offs_k_dstate[None, :] < dstate),\n                other=0.0,\n            )\n\n            prev_states = tl.load(\n                prev_states_ptrs,\n                mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),\n                other=0.0,\n            )\n            prev_states = prev_states.to(C_ptr.dtype.element_ty)\n            acc = tl.dot(C, prev_states) * scale_m[:, None]\n        else:\n            for k in range(0, dstate, BLOCK_SIZE_K):\n                C = tl.load(\n                    C_ptrs,\n                    mask=(offs_m[:, None] < chunk_size_limit)\n                    & (offs_k_dstate[None, :] < dstate - k),\n                    other=0.0,\n                )\n                # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)\n                prev_states = tl.load(\n                    prev_states_ptrs,\n                    mask=(offs_k_dstate[:, None] < dstate - k)\n                    & (offs_n[None, :] < hdim),\n                    other=0.0,\n                )\n                prev_states = prev_states.to(C_ptr.dtype.element_ty)\n                acc += tl.dot(C, prev_states)\n                C_ptrs += BLOCK_SIZE_K\n                prev_states_ptrs += BLOCK_SIZE_K\n            acc *= scale_m[:, None]\n\n    offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off\n    cb_ptrs = cb_ptr + (\n        offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k\n    )\n    x_ptrs = x_ptr + (\n        offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim\n    )\n    dt_ptrs = dt_ptr + offs_k * stride_dt_csize\n    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n    K_MAX = (\n        chunk_size_limit\n        if not IS_CAUSAL\n        else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)\n    )\n    for k in range(0, K_MAX, BLOCK_SIZE_K):\n        cb = tl.load(\n            cb_ptrs,\n            mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < 
chunk_size - k),\n            other=0.0,\n        ).to(tl.float32)\n        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(\n            tl.float32\n        )\n        # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].\n        # So we don't need masking wrt seq_idx here.\n        cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])\n        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)\n        cb *= dt_k\n        if IS_CAUSAL:\n            mask = offs_m[:, None] >= k + offs_k[None, :]\n            cb = tl.where(mask, cb, 0.0)\n        cb = cb.to(x_ptr.dtype.element_ty)\n        x = tl.load(\n            x_ptrs,\n            mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),\n            other=0.0,\n        )\n        acc += tl.dot(cb, x)\n        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k\n        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen\n        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize\n        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n\n    offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)\n    offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n    if HAS_D:\n        if D_HAS_HDIM:\n            D = tl.load(\n                D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0\n            ).to(tl.float32)\n        else:\n            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n        x_residual = tl.load(\n            x_ptr\n            + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),\n            mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),\n            other=0.0,\n        ).to(tl.float32)\n        acc += x_residual * D\n\n    if HAS_Z:\n        out_x_ptr += (\n            pid_b * stride_out_batch\n            + c_idx * chunk_size * stride_out_seqlen\n            + pid_h * stride_out_head\n        )\n        out_x_ptrs = 
out_x_ptr + (\n            stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]\n        )\n        tl.store(\n            out_x_ptrs,\n            acc,\n            mask=(offs_out_m[:, None] < chunk_size_limit)\n            & (offs_out_n[None, :] < hdim),\n        )\n\n        z_ptr += (\n            pid_b * stride_z_batch\n            + c_idx * chunk_size * stride_z_seqlen\n            + pid_h * stride_z_head\n        )\n        z_ptrs = z_ptr + (\n            stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]\n        )\n        z = tl.load(\n            z_ptrs,\n            mask=(offs_out_m[:, None] < chunk_size_limit)\n            & (offs_out_n[None, :] < hdim),\n            other=0.0,\n        ).to(tl.float32)\n        acc *= z * tl.sigmoid(z)\n\n    out_ptr += (\n        pid_b * stride_out_batch\n        + c_idx * chunk_size * stride_out_seqlen\n        + pid_h * stride_out_head\n    )\n    out_ptrs = out_ptr + (\n        stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim\n    )\n    tl.store(\n        out_ptrs,\n        acc,\n        mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),\n    )\n\n\ndef _chunk_scan_fwd(\n    cb,\n    x,\n    dt,\n    dA_cumsum,\n    C,\n    states,\n    D=None,\n    z=None,\n    seq_idx=None,\n    chunk_indices=None,\n    chunk_offsets=None,\n    initial_states=None,\n    out=None,\n):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    _, _, ngroups, dstate = C.shape\n    assert nheads % ngroups == 0\n    assert C.shape == (batch, seqlen, ngroups, dstate)\n    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n    if z is not None:\n        assert z.shape == x.shape\n    if D is not None:\n        assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert dA_cumsum.shape == (batch, nheads, nchunks, 
chunk_size)\n    assert states.shape == (batch, nchunks, nheads, headdim, dstate)\n\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n\n        if initial_states is not None:\n            # with initial states, we need to take care of how\n            # seq_idx crosses the boundaries\n            assert batch == 1, \"chunk scan only supports initial states with batch 1\"\n            assert (\n                chunk_indices is not None and chunk_offsets is not None\n            ), \"chunk_indices and chunk_offsets should have been set\"\n        else:\n            chunk_indices, chunk_offsets = None, None\n    else:\n        chunk_indices, chunk_offsets = None, None\n\n    assert out.shape == x.shape\n\n    if z is not None:\n        out_x = torch.empty_like(x)\n        assert out_x.stride() == out.stride()\n    else:\n        out_x = None\n\n    grid = lambda META: (\n        triton.cdiv(chunk_size, META[\"BLOCK_SIZE_M\"])\n        * triton.cdiv(headdim, META[\"BLOCK_SIZE_N\"]),\n        batch * nchunks if chunk_offsets is None else len(chunk_offsets),\n        nheads,\n    )\n    z_strides = (\n        (z.stride(0), z.stride(1), z.stride(2), z.stride(3))\n        if z is not None\n        else (0, 0, 0, 0)\n    )\n    _chunk_scan_fwd_kernel[grid](\n        cb,\n        x,\n        z,\n        out,\n        out_x,\n        dt,\n        dA_cumsum,\n        seq_idx,\n        C,\n        states,\n        D,\n        initial_states,\n        chunk_indices,\n        chunk_offsets,\n        len(chunk_indices) if chunk_indices is not None else 0,\n        chunk_size,\n        headdim,\n        dstate,\n        batch,\n        seqlen,\n        nheads // ngroups,\n        cb.stride(0),\n        cb.stride(1),\n        cb.stride(2),\n        cb.stride(3),\n        cb.stride(4),\n        x.stride(0),\n        x.stride(1),\n        x.stride(2),\n        x.stride(3),\n        z_strides[0],\n        z_strides[1],\n        z_strides[2],\n        
z_strides[3],\n        out.stride(0),\n        out.stride(1),\n        out.stride(2),\n        out.stride(3),\n        dt.stride(0),\n        dt.stride(2),\n        dt.stride(1),\n        dt.stride(3),\n        dA_cumsum.stride(0),\n        dA_cumsum.stride(2),\n        dA_cumsum.stride(1),\n        dA_cumsum.stride(3),\n        *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n        C.stride(0),\n        C.stride(1),\n        C.stride(2),\n        C.stride(3),\n        states.stride(0),\n        states.stride(1),\n        states.stride(2),\n        states.stride(3),\n        states.stride(4),\n        *(\n            (\n                initial_states.stride(0),\n                initial_states.stride(1),\n                initial_states.stride(2),\n                initial_states.stride(3),\n            )\n            if initial_states is not None\n            else (0, 0, 0, 0)\n        ),\n        D.stride(0) if D is not None else 0,\n        True,\n        D is not None,\n        D.dim() == 2 if D is not None else True,\n        BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n        HAS_Z=z is not None,\n        HAS_SEQ_IDX=seq_idx is not None,\n        IS_TRITON_22=TRITON_22,\n        HAS_INITSTATES=initial_states is not None,\n    )\n    return out_x\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py",
    "content": "# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py\n\n# SPDX-License-Identifier: Apache-2.0\n# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n\n# Copyright (c) 2024, Tri Dao, Albert Gu.\n# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py\n\n# ruff: noqa: E501\n\nimport math\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom .mamba_ssm import softplus\n\n\n@triton.jit\ndef _chunk_cumsum_fwd_kernel(\n    # Pointers to matrices\n    dt_ptr,\n    A_ptr,\n    dt_bias_ptr,\n    dt_out_ptr,\n    dA_cumsum_ptr,\n    # Matrix dimension\n    batch,\n    seqlen,\n    nheads,\n    chunk_size,\n    dt_min,\n    dt_max,\n    # Strides\n    stride_dt_batch,\n    stride_dt_seqlen,\n    stride_dt_head,\n    stride_A_head,\n    stride_dt_bias_head,\n    stride_dt_out_batch,\n    stride_dt_out_chunk,\n    stride_dt_out_head,\n    stride_dt_out_csize,\n    stride_dA_cs_batch,\n    stride_dA_cs_chunk,\n    stride_dA_cs_head,\n    stride_dA_cs_csize,\n    # Meta-parameters\n    DT_SOFTPLUS: tl.constexpr,\n    HAS_DT_BIAS: tl.constexpr,\n    BLOCK_SIZE_CHUNK: tl.constexpr,\n    BLOCK_SIZE_H: tl.constexpr = 16,\n):\n    pid_b = tl.program_id(axis=0)\n\n    # if dt is long, may cause problems, so use 64 bit\n    # https://github.com/triton-lang/triton/issues/1058\n    pid_c = tl.program_id(axis=1).to(tl.int64)\n    pid_h = tl.program_id(axis=2)\n    dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n    dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk\n\n    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n    offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n    dt_ptrs = dt_ptr + (\n        offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen\n    )\n    A_ptrs = A_ptr + 
offs_h * stride_A_head\n    dt_out_ptrs = dt_out_ptr + (\n        offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize\n    )\n    dA_cs_ptrs = dA_cumsum_ptr + (\n        offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize\n    )\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n    dt = tl.load(\n        dt_ptrs,\n        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),\n        other=0.0,\n    ).to(tl.float32)\n    if HAS_DT_BIAS:\n        dt_bias = tl.load(\n            dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0\n        ).to(tl.float32)\n        dt += dt_bias[:, None]\n    if DT_SOFTPLUS:\n        dt = tl.where(dt <= 20.0, softplus(dt), dt)\n    # As of Triton 2.2.0, tl.clamp is not available yet\n    # dt = tl.clamp(dt, dt_min, dt_max)\n    dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n    dt = tl.where(\n        (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0\n    )\n    tl.store(\n        dt_out_ptrs,\n        dt,\n        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),\n    )\n    A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n    dA = dt * A[:, None]\n    dA_cs = tl.cumsum(dA, axis=1)\n    tl.store(\n        dA_cs_ptrs,\n        dA_cs,\n        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),\n    )\n\n\n@triton.jit\ndef _chunk_state_fwd_kernel(\n    # Pointers to matrices\n    x_ptr,\n    b_ptr,\n    states_ptr,\n    dt_ptr,\n    dA_cumsum_ptr,\n    seq_idx_ptr,\n    # Matrix dimensions\n    hdim,\n    dstate,\n    chunk_size,\n    batch,\n    seqlen,\n    nheads_ngroups_ratio,\n    # Strides\n    stride_x_batch,\n    stride_x_seqlen,\n    stride_x_head,\n    stride_x_hdim,\n    stride_b_batch,\n    stride_b_seqlen,\n    stride_b_head,\n    stride_b_dstate,\n    stride_states_batch,\n    stride_states_chunk,\n    stride_states_head,\n    
stride_states_hdim,\n    stride_states_dstate,\n    stride_dt_batch,\n    stride_dt_chunk,\n    stride_dt_head,\n    stride_dt_csize,\n    stride_dA_cs_batch,\n    stride_dA_cs_chunk,\n    stride_dA_cs_head,\n    stride_dA_cs_csize,\n    stride_seq_idx_batch,\n    stride_seq_idx_seqlen,\n    # Meta-parameters\n    HAS_SEQ_IDX: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr = 16,\n    BLOCK_SIZE_N: tl.constexpr = 16,\n    BLOCK_SIZE_K: tl.constexpr = 16,\n):\n    pid_bc = tl.program_id(axis=1).to(tl.int64)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    b_ptr += (\n        pid_b * stride_b_batch\n        + pid_c * chunk_size * stride_b_seqlen\n        + (pid_h // nheads_ngroups_ratio) * stride_b_head\n    )\n    x_ptr += (\n        pid_b * stride_x_batch\n        + pid_c * chunk_size * stride_x_seqlen\n        + pid_h * stride_x_head\n    )\n    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n    dA_cumsum_ptr += (\n        pid_b * stride_dA_cs_batch\n        + pid_c * stride_dA_cs_chunk\n        + pid_h * stride_dA_cs_head\n    )\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += (\n            pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n        )\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    x_ptrs = x_ptr + (\n        offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen\n    )\n    b_ptrs = b_ptr + (\n        offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen\n    )\n    dt_ptrs = dt_ptr + offs_k * stride_dt_csize\n    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(\n        tl.float32\n    )\n    dA_cumsum_ptrs = 
dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n    if HAS_SEQ_IDX:\n        seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    if HAS_SEQ_IDX:\n        seq_idx_last = tl.load(\n            seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen\n        )\n\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):\n        x = tl.load(\n            x_ptrs,\n            mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),\n            other=0.0,\n        )\n        b = tl.load(\n            b_ptrs,\n            mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),\n            other=0.0,\n        ).to(tl.float32)\n        dA_cs_k = tl.load(\n            dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0\n        ).to(tl.float32)\n        if HAS_SEQ_IDX:\n            seq_idx_k = tl.load(\n                seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1\n            )\n        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(\n            tl.float32\n        )\n        if not HAS_SEQ_IDX:\n            scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k\n        else:\n            scale = tl.where(\n                seq_idx_k == seq_idx_last, tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0\n            )\n        b *= scale[:, None]\n        b = b.to(x_ptr.dtype.element_ty)\n        acc += tl.dot(x, b)\n        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen\n        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen\n        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize\n        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n        if HAS_SEQ_IDX:\n            seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen\n    states = acc.to(states_ptr.dtype.element_ty)\n\n    states_ptr += (\n        pid_b * stride_states_batch\n        + pid_c * 
stride_states_chunk\n        + pid_h * stride_states_head\n    )\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    states_ptrs = states_ptr + (\n        offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate\n    )\n    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)\n    tl.store(states_ptrs, states, mask=c_mask)\n\n\n@triton.jit\ndef _chunk_state_varlen_kernel(\n    # Pointers to matrices\n    x_ptr,\n    b_ptr,\n    dt_ptr,\n    dA_cumsum_ptr,\n    chunk_states_ptr,\n    cu_seqlens_ptr,\n    states_ptr,\n    initstates_ptr,\n    # Matrix dimensions\n    hdim,\n    dstate,\n    chunk_size,\n    seqlen,\n    nheads_ngroups_ratio,\n    # Strides\n    stride_x_seqlen,\n    stride_x_head,\n    stride_x_hdim,\n    stride_b_seqlen,\n    stride_b_head,\n    stride_b_dstate,\n    stride_dt_chunk,\n    stride_dt_head,\n    stride_dt_csize,\n    stride_dA_cs_chunk,\n    stride_dA_cs_head,\n    stride_dA_cs_csize,\n    stride_chunk_states_chunk,\n    stride_chunk_states_head,\n    stride_chunk_states_hdim,\n    stride_chunk_states_dstate,\n    stride_states_batch,\n    stride_states_head,\n    stride_states_hdim,\n    stride_states_dstate,\n    stride_init_states_batch,\n    stride_init_states_head,\n    stride_init_states_hdim,\n    stride_init_states_dstate,\n    # Meta-parameters\n    HAS_INITSTATES: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr = 16,\n    BLOCK_SIZE_N: tl.constexpr = 16,\n    BLOCK_SIZE_K: tl.constexpr = 16,\n):\n    pid_b = tl.program_id(axis=1)\n    pid_h = tl.program_id(axis=2)\n    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)\n    pid_c = (end_idx - 1) // chunk_size\n    b_ptr += (\n        pid_c * chunk_size * stride_b_seqlen\n        + (pid_h // nheads_ngroups_ratio) * stride_b_head\n    
)\n    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head\n    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n    chunk_states_ptr += (\n        pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head\n    )\n\n    if HAS_INITSTATES:\n        # if there are init states provided, we differentiate between states (which\n        # are boundary conditions at a chunk boundary) and initstates (which are boundary\n        # conditions when a new example in a cont batch starts)\n        initstates_ptr += pid_h * stride_init_states_head\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    x_ptrs = x_ptr + (\n        offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen\n    )\n    b_ptrs = b_ptr + (\n        offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen\n    )\n    dt_ptrs = dt_ptr + offs_k * stride_dt_csize\n    dA_cs_last = tl.load(\n        dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize\n    ).to(tl.float32)\n    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n\n    chunk_size_limit = end_idx - pid_c * chunk_size\n    start_idx = tl.load(cu_seqlens_ptr + pid_b)\n    start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)\n\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):\n        x = tl.load(\n            x_ptrs,\n            mask=(offs_m[:, None] < hdim)\n            & (offs_k[None, :] < chunk_size_limit - k)\n            & (offs_k[None, :] >= start_idx_cur - k),\n            other=0.0,\n        )\n        b = tl.load(\n            b_ptrs,\n            mask=(offs_k[:, None] < chunk_size_limit - k)\n            & (offs_n[None, :] < dstate)\n            & (offs_k[:, None] >= 
start_idx_cur - k),\n            other=0.0,\n        ).to(tl.float32)\n        dA_cs_k = tl.load(\n            dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0\n        ).to(tl.float32)\n        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(\n            tl.float32\n        )\n        scale = tl.where(\n            (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),\n            tl.exp(dA_cs_last - dA_cs_k) * dt_k,\n            0.0,\n        )\n        b *= scale[:, None]\n        b = b.to(x_ptr.dtype.element_ty)\n        acc += tl.dot(x, b)\n        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen\n        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen\n        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize\n        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n\n    # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk\n    # If HAS_INITSTATES==True, we need to consider two possibilities\n    # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs\n    # - if start_idx >= pid_c * chunk_size, then we need to insert initstates\n    if (start_idx < pid_c * chunk_size) or (HAS_INITSTATES):  # first chunk\n\n        dA_cs_boundary = 0.0  # default\n\n        if not HAS_INITSTATES:\n            past_states_ptrs = chunk_states_ptr + (\n                offs_m[:, None] * stride_chunk_states_hdim\n                + offs_n[None, :] * stride_chunk_states_dstate\n            )\n        else:\n\n            # - this seems repetitive, but it's to help the compiler\n            if start_idx < pid_c * chunk_size:\n                past_states_ptrs = chunk_states_ptr + (\n                    offs_m[:, None] * stride_chunk_states_hdim\n                    + offs_n[None, :] * stride_chunk_states_dstate\n                )\n            else:\n                past_states_ptrs = initstates_ptr + (\n                    pid_b * stride_init_states_batch\n                    + 
offs_m[:, None] * stride_init_states_hdim\n                    + offs_n[None, :] * stride_init_states_dstate\n                )\n\n                # need to adjust the boundary\n                if start_idx > pid_c * chunk_size:\n                    dA_cs_boundary = tl.load(\n                        dA_cumsum_ptr\n                        + (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize\n                    ).to(tl.float32)\n\n        past_states = tl.load(\n            past_states_ptrs,\n            mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),\n            other=0.0,\n        ).to(tl.float32)\n\n        scale = tl.exp(dA_cs_last - dA_cs_boundary)\n        acc += past_states * scale\n\n    states = acc.to(states_ptr.dtype.element_ty)\n\n    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    states_ptrs = states_ptr + (\n        offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate\n    )\n    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)\n    tl.store(states_ptrs, states, mask=c_mask)\n\n\ndef _chunk_cumsum_fwd(\n    dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\"))\n):\n    batch, seqlen, nheads = dt.shape\n    assert A.shape == (nheads,)\n    if dt_bias is not None:\n        assert dt_bias.shape == (nheads,)\n    nchunks = math.ceil(seqlen / chunk_size)\n    dt_out = torch.empty(\n        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32\n    )\n    dA_cumsum = torch.empty(\n        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32\n    )\n    grid_chunk_cs = lambda META: (\n        batch,\n        nchunks,\n        triton.cdiv(nheads, META[\"BLOCK_SIZE_H\"]),\n    )\n    with torch.cuda.device(dt.device.index):\n        _chunk_cumsum_fwd_kernel[grid_chunk_cs](\n       
     dt,\n            A,\n            dt_bias,\n            dt_out,\n            dA_cumsum,\n            batch,\n            seqlen,\n            nheads,\n            chunk_size,\n            dt_limit[0],\n            dt_limit[1],\n            dt.stride(0),\n            dt.stride(1),\n            dt.stride(2),\n            A.stride(0),\n            dt_bias.stride(0) if dt_bias is not None else 0,\n            dt_out.stride(0),\n            dt_out.stride(2),\n            dt_out.stride(1),\n            dt_out.stride(3),\n            dA_cumsum.stride(0),\n            dA_cumsum.stride(2),\n            dA_cumsum.stride(1),\n            dA_cumsum.stride(3),\n            dt_softplus,\n            HAS_DT_BIAS=dt_bias is not None,\n            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n        )\n    return dA_cumsum, dt_out\n\n\ndef _chunk_state_fwd(\n    B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True\n):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    _, _, ngroups, dstate = B.shape\n    assert nheads % ngroups == 0\n    assert B.shape == (batch, seqlen, ngroups, dstate)\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert dA_cumsum.shape == dt.shape\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    if states is not None:\n        assert states.shape == (batch, nchunks, nheads, headdim, dstate)\n    else:\n        states_dtype = torch.float32 if states_in_fp32 else B.dtype\n        states = torch.empty(\n            (batch, nchunks, nheads, headdim, dstate),\n            device=x.device,\n            dtype=states_dtype,\n        )\n    grid = lambda META: (\n        triton.cdiv(headdim, META[\"BLOCK_SIZE_M\"])\n        * triton.cdiv(dstate, META[\"BLOCK_SIZE_N\"]),\n        batch * nchunks,\n        nheads,\n    )\n    with torch.cuda.device(x.device.index):\n        _chunk_state_fwd_kernel[grid](\n            x,\n            B,\n            
states,\n            dt,\n            dA_cumsum,\n            seq_idx,\n            headdim,\n            dstate,\n            chunk_size,\n            batch,\n            seqlen,\n            nheads // ngroups,\n            x.stride(0),\n            x.stride(1),\n            x.stride(2),\n            x.stride(3),\n            B.stride(0),\n            B.stride(1),\n            B.stride(2),\n            B.stride(-1),\n            states.stride(0),\n            states.stride(1),\n            states.stride(2),\n            states.stride(3),\n            states.stride(4),\n            dt.stride(0),\n            dt.stride(2),\n            dt.stride(1),\n            dt.stride(3),\n            dA_cumsum.stride(0),\n            dA_cumsum.stride(2),\n            dA_cumsum.stride(1),\n            dA_cumsum.stride(3),\n            *(\n                (seq_idx.stride(0), seq_idx.stride(1))\n                if seq_idx is not None\n                else (0, 0)\n            ),\n            HAS_SEQ_IDX=seq_idx is not None,\n        )\n    return states\n\n\ndef chunk_state_varlen(\n    B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None\n):\n    total_seqlen, nheads, headdim = x.shape\n    _, nchunks, chunk_size = dt.shape\n    _, ngroups, dstate = B.shape\n    batch = cu_seqlens.shape[0] - 1\n    cu_seqlens = cu_seqlens.contiguous()\n    assert nheads % ngroups == 0\n    assert B.shape == (total_seqlen, ngroups, dstate)\n    assert dt.shape == (nheads, nchunks, chunk_size)\n    assert dA_cumsum.shape == dt.shape\n    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)\n\n    if initial_states is not None:\n        assert initial_states.shape == (batch, nheads, headdim, dstate)\n\n    states = torch.empty(\n        batch,\n        nheads,\n        headdim,\n        dstate,\n        dtype=chunk_states.dtype,\n        device=chunk_states.device,\n    )\n    grid = lambda META: (\n        triton.cdiv(headdim, META[\"BLOCK_SIZE_M\"])\n        * 
triton.cdiv(dstate, META[\"BLOCK_SIZE_N\"]),\n        batch,\n        nheads,\n    )\n    with torch.cuda.device(x.device.index):\n        _chunk_state_varlen_kernel[grid](\n            x,\n            B,\n            dt,\n            dA_cumsum,\n            chunk_states,\n            cu_seqlens,\n            states,\n            initial_states,\n            headdim,\n            dstate,\n            chunk_size,\n            total_seqlen,\n            nheads // ngroups,\n            x.stride(0),\n            x.stride(1),\n            x.stride(2),\n            B.stride(0),\n            B.stride(1),\n            B.stride(2),\n            dt.stride(1),\n            dt.stride(0),\n            dt.stride(2),\n            dA_cumsum.stride(1),\n            dA_cumsum.stride(0),\n            dA_cumsum.stride(2),\n            chunk_states.stride(0),\n            chunk_states.stride(1),\n            chunk_states.stride(2),\n            chunk_states.stride(3),\n            states.stride(0),\n            states.stride(1),\n            states.stride(2),\n            states.stride(3),\n            *(\n                (\n                    initial_states.stride(0),\n                    initial_states.stride(1),\n                    initial_states.stride(2),\n                    initial_states.stride(3),\n                )\n                if initial_states is not None\n                else (0, 0, 0, 0)\n            ),\n            HAS_INITSTATES=initial_states is not None,\n        )\n    return states\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/ops/ssd_combined.py",
    "content": "# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_combined.py\n\n# SPDX-License-Identifier: Apache-2.0\n# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n\n# Copyright (c) 2024, Tri Dao, Albert Gu.\n# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py\n\n# ruff: noqa: E501\n\nimport torch\nimport triton\nfrom einops import rearrange\nfrom packaging import version\n\nfrom .ssd_bmm import _bmm_chunk_fwd\nfrom .ssd_chunk_scan import _chunk_scan_fwd\nfrom .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd, chunk_state_varlen\nfrom .ssd_state_passing import _state_passing_fwd\n\nTRITON_22 = version.parse(triton.__version__) >= version.parse(\"2.2.0\")\n\n\ndef is_int_pow_2(n):\n    return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0\n\n\ndef _mamba_chunk_scan_combined_fwd(\n    x,\n    dt,\n    A,\n    B,\n    C,\n    chunk_size,\n    D=None,\n    z=None,\n    dt_bias=None,\n    initial_states=None,\n    seq_idx=None,\n    chunk_indices=None,\n    chunk_offsets=None,\n    cu_seqlens=None,\n    dt_softplus=False,\n    dt_limit=(0.0, float(\"inf\")),\n    state_dtype=None,\n    out=None,\n):\n    assert is_int_pow_2(chunk_size), \"chunk_size must be integer power of 2\"\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, ngroups, dstate = B.shape\n    assert nheads % ngroups == 0\n    assert B.shape == (batch, seqlen, ngroups, dstate)\n    assert dt.shape == (batch, seqlen, nheads)\n    assert A.shape == (nheads,)\n    assert C.shape == B.shape\n    if z is not None:\n        assert z.shape == x.shape\n    if D is not None:\n        assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    if B.stride(-1) != 1:\n        B = B.contiguous()\n    if C.stride(-1) != 1:\n        C = C.contiguous()\n    if (\n        x.stride(-1) != 1 
and x.stride(1) != 1\n    ):  # Either M or K dimension should be contiguous\n        x = x.contiguous()\n    if (\n        z is not None and z.stride(-1) != 1 and z.stride(1) != 1\n    ):  # Either M or K dimension should be contiguous\n        z = z.contiguous()\n    if D is not None and D.stride(-1) != 1:\n        D = D.contiguous()\n    if initial_states is not None:\n        if cu_seqlens is None:\n            assert initial_states.shape == (batch, nheads, headdim, dstate)\n        else:\n            assert initial_states.shape == (\n                len(cu_seqlens) - 1,\n                nheads,\n                headdim,\n                dstate,\n            )\n\n    # This function executes 5 sub-functions for computing mamba\n    # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/\n    #   which has a minimal implementation to understand the below operations\n    # - as explained by the blog, mamba is a special case of causal attention\n    # - the idea is to chunk the attention matrix and compute each\n    #   submatrix separately using different optimizations.\n    # - see the blog and paper for a visualization of the submatrices\n    #   which we refer to in the comments below\n\n    # 1. Compute chunked cumsum of A * dt\n    # - here dt may go through a softplus activation\n    dA_cumsum, dt = _chunk_cumsum_fwd(\n        dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit\n    )\n\n    # 2. Compute the state for each intra-chunk\n    # (right term of low-rank factorization of off-diagonal blocks; B terms)\n    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)\n\n    # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries\n    # (middle term of factorization of off-diag blocks; A terms)\n    # - for handling chunked prefill, this requires i) initial_states,\n    #   ii) seq_idx, iii) is_cont_batched and iv) chunk_offsets to all be specified.\n    # - When a new seq_idx is detected, we will stop passing the prev_state\n    #   and switch accordingly to the init_state corresponding to the new seq_idx.\n    # - We will also make sure that the dA_cumsum is taken only from the start of the\n    #   sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)\n    # - this will ensure that states will be updated with the rightmost flushed seq_idx\n    #   of the previous chunk. This implies that the first chunk of states is either 0\n    #   or equal to init_states of the first example.\n    states, final_states = _state_passing_fwd(\n        rearrange(states, \"... p n -> ... (p n)\"),\n        dA_cumsum,\n        initial_states=(\n            rearrange(initial_states, \"... p n -> ... (p n)\")\n            if initial_states is not None\n            else None\n        ),\n        seq_idx=seq_idx,\n        chunk_size=chunk_size,\n        out_dtype=state_dtype if state_dtype is not None else C.dtype,\n        is_cont_batched=cu_seqlens is not None,\n        chunk_offsets=chunk_offsets,\n    )\n    states, final_states = (\n        rearrange(t, \"... (p n) -> ... p n\", n=dstate) for t in [states, final_states]\n    )\n\n    # 4. Compute batched matrix multiply for C_j^T B_i terms\n    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)\n\n    # 5. 
Scan and compute the diagonal blocks, taking into\n    #    account past causal states.\n    # - if initial states are provided, then states information will be\n    #   augmented with initial_states.\n    # - to do this properly, we need to account for example changes in\n    #   the continuous batch, therefore we introduce pseudo chunks, which is\n    #   a chunk that is split up each time an example changes.\n    # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had\n    #   a seq_idx change, in which case we take states information from\n    #   init_states.\n    out_x = _chunk_scan_fwd(\n        CB,\n        x,\n        dt,\n        dA_cumsum,\n        C,\n        states,\n        D=D,\n        z=z,\n        seq_idx=seq_idx,\n        chunk_indices=chunk_indices,\n        chunk_offsets=chunk_offsets,\n        initial_states=initial_states,\n        out=out,\n    )\n    if cu_seqlens is None:\n        return out_x, dt, dA_cumsum, states, final_states\n    else:\n        assert (\n            batch == 1\n        ), \"passing cu_seqlens to get the varlen states is only supported if batch dimension is 1\"\n        varlen_states = chunk_state_varlen(\n            B.squeeze(0),\n            x.squeeze(0),\n            dt.squeeze(0),\n            dA_cumsum.squeeze(0),\n            cu_seqlens,\n            states.squeeze(0),\n            initial_states=initial_states,\n        )\n        return out_x, dt, dA_cumsum, states, final_states, varlen_states\n\n\ndef mamba_chunk_scan_combined(\n    x,\n    dt,\n    A,\n    B,\n    C,\n    chunk_size,\n    D=None,\n    z=None,\n    dt_bias=None,\n    initial_states=None,\n    seq_idx=None,\n    chunk_indices=None,\n    chunk_offsets=None,\n    cu_seqlens=None,\n    dt_softplus=False,\n    dt_limit=(0.0, float(\"inf\")),\n    out=None,\n    return_final_states=False,\n    return_varlen_states=False,\n    return_intermediate_states=False,\n    state_dtype=None,\n):\n    \"\"\"\n    Argument:\n        x: (batch, 
seqlen, nheads, headdim)\n        dt: (batch, seqlen, nheads)\n        A: (nheads)\n        B: (batch, seqlen, ngroups, dstate)\n        C: (batch, seqlen, ngroups, dstate)\n        chunk_size: int\n        D: (nheads, headdim) or (nheads,)\n        z: (batch, seqlen, nheads, headdim)\n        dt_bias: (nheads,)\n        initial_states: (batch, nheads, headdim, dstate)\n        seq_idx: (batch, seqlen)\n        cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True\n        dt_softplus: Whether to apply softplus to dt\n        out: Preallocated output tensor\n        state_dtype: The data type of the ssm state\n    \"\"\"\n\n    if not return_varlen_states:\n        cu_seqlens = None\n    else:\n        assert (\n            cu_seqlens is not None\n        ), \"cu_seqlens must be provided if return_varlen_states is True\"\n    out_x, dt_out, dA_cumsum, states, final_states, *rest = (\n        _mamba_chunk_scan_combined_fwd(\n            x,\n            dt,\n            A,\n            B,\n            C,\n            chunk_size,\n            D=D,\n            z=z,\n            dt_bias=dt_bias,\n            initial_states=initial_states,\n            seq_idx=seq_idx,\n            chunk_indices=chunk_indices,\n            chunk_offsets=chunk_offsets,\n            cu_seqlens=cu_seqlens,\n            dt_softplus=dt_softplus,\n            dt_limit=dt_limit,\n            out=out,\n            state_dtype=state_dtype,\n        )\n    )\n    if return_intermediate_states:\n        if return_varlen_states:\n            varlen_states = rest[0]\n            if return_final_states:\n                return states, final_states, varlen_states\n            else:\n                return states, varlen_states\n        else:\n            if return_final_states:\n                return states, final_states\n            else:\n                return states\n\n    if not return_varlen_states:\n        if not return_final_states:\n            return\n        
else:\n            return final_states\n    else:\n        varlen_states = rest[0]\n        return (\n            (varlen_states)\n            if not return_final_states\n            else (final_states, varlen_states)\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py",
    "content": "# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py\n\n# SPDX-License-Identifier: Apache-2.0\n# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n\n# Copyright (c) 2024, Tri Dao, Albert Gu.\n# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py\n\n# ruff: noqa: E501\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _state_passing_fwd_kernel(\n    # Pointers to matrices\n    states_ptr,\n    out_ptr,\n    final_states_ptr,\n    dA_cs_ptr,\n    initstates_ptr,\n    seq_idx_ptr,\n    chunk_offsets_ptr,\n    chunk_meta_num,\n    # Matrix dimensions\n    dim,\n    nchunks,\n    seqlen,\n    chunk_size,\n    # Strides\n    stride_states_batch,\n    stride_states_chunk,\n    stride_states_head,\n    stride_states_dim,\n    stride_out_batch,\n    stride_out_chunk,\n    stride_out_head,\n    stride_out_dim,\n    stride_final_states_batch,\n    stride_final_states_head,\n    stride_final_states_dim,\n    stride_dA_cs_batch,\n    stride_dA_cs_chunk,\n    stride_dA_cs_head,\n    stride_dA_cs_csize,\n    stride_initstates_batch,\n    stride_initstates_head,\n    stride_initstates_dim,\n    stride_seq_idx_batch,\n    stride_seq_idx_seqlen,\n    # Meta-parameters\n    HAS_INITSTATES: tl.constexpr,\n    HAS_SEQ_IDX: tl.constexpr,\n    IS_CONT_BATCHED: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr = 16,\n):\n    pid_b = tl.program_id(axis=1)\n    pid_h = tl.program_id(axis=2)\n    pid_m = tl.program_id(axis=0)\n    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head\n    dA_cs_ptr += (\n        pid_b * stride_dA_cs_batch\n        + pid_h * stride_dA_cs_head\n        + (chunk_size - 1) * stride_dA_cs_csize\n    )\n    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n    final_states_ptr += (\n        pid_b * stride_final_states_batch + pid_h * 
stride_final_states_head\n    )\n    if HAS_INITSTATES:\n        initstates_ptr += pid_h * stride_initstates_head\n        if not IS_CONT_BATCHED:\n            initstates_ptr += pid_b * stride_initstates_batch\n\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    states_ptrs = states_ptr + offs_m * stride_states_dim\n    out_ptrs = out_ptr + offs_m * stride_out_dim\n    final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim\n\n    # - states will be the past state of the sequence that continues on the current chunk\n    if not HAS_INITSTATES:\n        states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)\n    else:\n        initstates_ptr += offs_m * stride_initstates_dim\n        initstates_ptrs = initstates_ptr\n        # - for cont batches, for the first chunk this will be the first sequence's\n        #   init state\n        states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n    tl.store(out_ptrs, states, mask=offs_m < dim)\n    out_ptrs += stride_out_chunk\n    prev_seq_idx_chunk_end = 0\n    logical_chunk_idx = 0\n    for c in range(nchunks):\n        new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n        scale_mask = True\n        if HAS_SEQ_IDX:\n            # - the seq to pass forward is the one that is flushed to the right\n            #   boundary.\n            # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.\n            seq_idx_chunk_end = tl.load(\n                seq_idx_ptr\n                + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen\n            )\n            if HAS_INITSTATES:\n                if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:\n                    # this means in the current chunk the rightmost flushed seq\n                    # has 
changed.\n                    # - so we do not propagate the state from previous chunk\n                    # - but rather we load that sequence's init state\n                    initstates_ptrs = (\n                        initstates_ptr + seq_idx_chunk_end * stride_initstates_batch\n                    )\n\n                    # - update state with seq_idx_new's init state\n                    states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(\n                        tl.float32\n                    )\n\n                    # - we need to consider the cumsum only of the last sequence in the chunk\n                    # - find its starting position (given by c_off of the logical chunk index)\n                    # - and subtract the cumsum just before that position from the total cumsum\n                    # - first, update the logical chunk index (add the number of sequences in the current physical chunk):\n                    # sequence index at the start of the current chunk\n                    seq_idx_chunk_start = tl.load(\n                        seq_idx_ptr\n                        + min(c * chunk_size, seqlen) * stride_seq_idx_seqlen\n                    )\n                    logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start\n                    # - load the chunk offset:\n                    c_off = tl.load(\n                        chunk_offsets_ptr + logical_chunk_idx,\n                        mask=logical_chunk_idx < chunk_meta_num,\n                        other=0,\n                    )\n                    # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything\n                    if c_off > 0:\n                        # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset\n                        dA_cs_boundary = tl.load(\n                            dA_cs_ptr\n                            - 
(chunk_size - 1) * stride_dA_cs_csize\n                            + (c_off - 1) * stride_dA_cs_csize,\n                            mask=(c_off - 1) > -1 and c_off < chunk_size,\n                            other=0.0,\n                        )\n                        dA_cs -= dA_cs_boundary\n\n                # - increment logical chunk index for every physical chunk\n                logical_chunk_idx += 1\n            else:\n                scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end\n            prev_seq_idx_chunk_end = seq_idx_chunk_end\n\n        scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)\n        states = scale * states + new_states\n        if c < nchunks - 1:\n            tl.store(out_ptrs, states, mask=offs_m < dim)\n        else:\n            tl.store(final_states_ptrs, states, mask=offs_m < dim)\n        states_ptrs += stride_states_chunk\n        dA_cs_ptr += stride_dA_cs_chunk\n        out_ptrs += stride_out_chunk\n\n\ndef _state_passing_fwd(\n    states,\n    dA_cumsum,\n    initial_states=None,\n    seq_idx=None,\n    chunk_size=None,\n    out_dtype=None,\n    is_cont_batched=False,\n    chunk_offsets=None,\n):\n    batch, nchunks, nheads, dim = states.shape\n    if chunk_size is None:\n        chunk_size = dA_cumsum.shape[-1]\n    else:\n        assert chunk_size == dA_cumsum.shape[-1]\n    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n    if initial_states is not None:\n        if is_cont_batched:\n            # - if cu_seqlens is provided, then the initial states\n            #   are used for continuous batching. 
In which case we\n            #   require seq_idx to be provided\n            assert (\n                seq_idx is not None\n            ), \"seq_idx must be provided for continuous batching\"\n            # - we also need chunk_offsets to be provided, to account\n            #   for computation of dA_cumsum from the start of the\n            #   sequence\n            assert (\n                chunk_offsets is not None\n            ), \"chunk_offsets must be provided for continuous batching\"\n        else:\n            # - this is the regular batching case, where initial\n            #   states are used for each example of the batch.\n            assert initial_states.shape == (batch, nheads, dim)\n\n    if seq_idx is not None:\n        seqlen = seq_idx.shape[-1]\n        assert seq_idx.shape == (batch, seqlen)\n    out_dtype = states.dtype if out_dtype is None else out_dtype\n    out = torch.empty(\n        (batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype\n    )\n    final_states = torch.empty(\n        (batch, nheads, dim), device=states.device, dtype=torch.float32\n    )\n    grid = lambda META: (triton.cdiv(dim, META[\"BLOCK_SIZE\"]), batch, nheads)\n    with torch.cuda.device(states.device.index):\n        _state_passing_fwd_kernel[grid](\n            states,\n            out,\n            final_states,\n            dA_cumsum,\n            initial_states,\n            seq_idx,\n            chunk_offsets,\n            len(chunk_offsets) if chunk_offsets is not None else 0,\n            dim,\n            nchunks,\n            seqlen if seq_idx is not None else 0,\n            chunk_size,\n            states.stride(0),\n            states.stride(1),\n            states.stride(2),\n            states.stride(3),\n            out.stride(0),\n            out.stride(1),\n            out.stride(2),\n            out.stride(3),\n            final_states.stride(0),\n            final_states.stride(1),\n            final_states.stride(2),\n         
   dA_cumsum.stride(0),\n            dA_cumsum.stride(2),\n            dA_cumsum.stride(1),\n            dA_cumsum.stride(3),\n            *(\n                (\n                    initial_states.stride(0),\n                    initial_states.stride(1),\n                    initial_states.stride(2),\n                )\n                if initial_states is not None\n                else (0, 0, 0)\n            ),\n            *(\n                (seq_idx.stride(0), seq_idx.stride(1))\n                if seq_idx is not None\n                else (0, 0)\n            ),\n            HAS_INITSTATES=initial_states is not None,\n            HAS_SEQ_IDX=seq_idx is not None,\n            IS_CONT_BATCHED=is_cont_batched,\n        )\n    return out, final_states\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/mamba/ops/ssu_dispatch.py",
    "content": "from __future__ import annotations\n\nimport logging\nfrom abc import ABC, abstractmethod\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nif TYPE_CHECKING:\n    from sglang.srt.server_args import ServerArgs\n\nlogger = logging.getLogger(__name__)\n\n\nclass MambaSSUBackend(ABC):\n    @property\n    @abstractmethod\n    def name(self) -> str:\n        \"\"\"Human-readable name used for logging.\"\"\"\n\n    @abstractmethod\n    def __call__(\n        self,\n        state: torch.Tensor,\n        x: torch.Tensor,\n        dt: torch.Tensor,\n        A: torch.Tensor,\n        B: torch.Tensor,\n        C: torch.Tensor,\n        D: torch.Tensor | None = None,\n        z: torch.Tensor | None = None,\n        dt_bias: torch.Tensor | None = None,\n        dt_softplus: bool = False,\n        state_batch_indices: torch.Tensor | None = None,\n        pad_slot_id: int = -1,\n        out: torch.Tensor | None = None,\n        disable_state_update: bool = False,\n        intermediate_states_buffer: torch.Tensor | None = None,\n        cache_steps: int | None = None,\n        retrieve_parent_token: torch.Tensor | None = None,\n        intermediate_state_indices: torch.Tensor | None = None,\n    ) -> None: ...\n\n\nclass TritonSSUBackend(MambaSSUBackend):\n    \"\"\"Triton-based selective-state-update backend.\"\"\"\n\n    def __init__(self) -> None:\n        from sglang.srt.layers.attention.mamba.ops.mamba_ssm import (\n            selective_state_update,\n        )\n\n        self._kernel = selective_state_update\n\n    @property\n    def name(self) -> str:\n        return \"triton\"\n\n    def __call__(\n        self,\n        state: torch.Tensor,\n        x: torch.Tensor,\n        dt: torch.Tensor,\n        A: torch.Tensor,\n        B: torch.Tensor,\n        C: torch.Tensor,\n        D: torch.Tensor | None = None,\n        z: torch.Tensor | None = None,\n        dt_bias: torch.Tensor | None = None,\n        dt_softplus: bool = False,\n        
state_batch_indices: torch.Tensor | None = None,\n        pad_slot_id: int = -1,\n        out: torch.Tensor | None = None,\n        disable_state_update: bool = False,\n        intermediate_states_buffer: torch.Tensor | None = None,\n        cache_steps: int | None = None,\n        retrieve_parent_token: torch.Tensor | None = None,\n        intermediate_state_indices: torch.Tensor | None = None,\n    ) -> None:\n        self._kernel(\n            state,\n            x,\n            dt,\n            A,\n            B,\n            C,\n            D=D,\n            z=z,\n            dt_bias=dt_bias,\n            dt_softplus=dt_softplus,\n            state_batch_indices=state_batch_indices,\n            pad_slot_id=pad_slot_id,\n            out=out,\n            disable_state_update=disable_state_update,\n            intermediate_states_buffer=intermediate_states_buffer,\n            cache_steps=cache_steps,\n            retrieve_parent_token=retrieve_parent_token,\n            intermediate_state_indices=intermediate_state_indices,\n        )\n\n\nclass FlashInferSSUBackend(MambaSSUBackend):\n    \"\"\"FlashInfer-based selective-state-update backend.\"\"\"\n\n    def __init__(self) -> None:\n        from flashinfer.mamba import selective_state_update\n\n        self._kernel = selective_state_update\n\n    @property\n    def name(self) -> str:\n        return \"flashinfer\"\n\n    def __call__(\n        self,\n        state: torch.Tensor,\n        x: torch.Tensor,\n        dt: torch.Tensor,\n        A: torch.Tensor,\n        B: torch.Tensor,\n        C: torch.Tensor,\n        D: torch.Tensor | None = None,\n        z: torch.Tensor | None = None,\n        dt_bias: torch.Tensor | None = None,\n        dt_softplus: bool = False,\n        state_batch_indices: torch.Tensor | None = None,\n        pad_slot_id: int = -1,\n        out: torch.Tensor | None = None,\n        disable_state_update: bool = False,\n        intermediate_states_buffer: torch.Tensor | None = None,\n     
   cache_steps: int | None = None,\n        retrieve_parent_token: torch.Tensor | None = None,\n        intermediate_state_indices: torch.Tensor | None = None,\n    ) -> None:\n        if retrieve_parent_token is not None:\n            raise ValueError(\n                \"FlashInfer backend does not support retrieve_parent_token. \"\n                \"Use --mamba-backend triton for EAGLE tree attention.\"\n            )\n        # FlashInfer expects cache_steps as an int (0 when unused).\n        self._kernel(\n            state,\n            x,\n            dt,\n            A,\n            B,\n            C,\n            D=D,\n            z=z,\n            dt_bias=dt_bias,\n            dt_softplus=dt_softplus,\n            state_batch_indices=state_batch_indices,\n            pad_slot_id=pad_slot_id,\n            out=out,\n            disable_state_update=disable_state_update,\n            intermediate_states_buffer=intermediate_states_buffer,\n            cache_steps=0 if cache_steps is None else cache_steps,\n            intermediate_state_indices=intermediate_state_indices,\n        )\n\n\n_BACKEND_REGISTRY: dict[str, type[MambaSSUBackend]] = {\n    \"triton\": TritonSSUBackend,\n    \"flashinfer\": FlashInferSSUBackend,\n}\n\n_mamba_ssu_backend: MambaSSUBackend | None = None\n\n\ndef initialize_mamba_selective_state_update_backend(server_args: ServerArgs) -> None:\n    \"\"\"Instantiate the selective-state-update backend from server config.\n\n    This should be called once during scheduler initialization.\n\n    Args:\n        server_args: Server arguments containing ``mamba_backend`` setting.\n\n    Raises:\n        ValueError: If the requested backend is unavailable or cannot be imported.\n    \"\"\"\n    global _mamba_ssu_backend\n\n    requested = server_args.mamba_backend or \"triton\"\n\n    backend_cls = _BACKEND_REGISTRY.get(requested)\n    if backend_cls is None:\n        raise ValueError(\n            f\"Unknown mamba backend '{requested}'. 
\"\n            f\"Available backends: {list(_BACKEND_REGISTRY.keys())}\"\n        )\n\n    try:\n        _mamba_ssu_backend = backend_cls()\n    except ImportError:\n        raise ValueError(\n            f\"Mamba backend '{requested}' requested but its dependencies are not \"\n            f\"available. Install the required package or use a different \"\n            f\"--mamba-backend value.\"\n        )\n\n    logger.debug(\n        \"Mamba selective_state_update backend initialized: %s\",\n        _mamba_ssu_backend.name,\n    )\n\n\ndef selective_state_update(\n    state: torch.Tensor,\n    x: torch.Tensor,\n    dt: torch.Tensor,\n    A: torch.Tensor,\n    B: torch.Tensor,\n    C: torch.Tensor,\n    D: torch.Tensor | None = None,\n    z: torch.Tensor | None = None,\n    dt_bias: torch.Tensor | None = None,\n    dt_softplus: bool = False,\n    state_batch_indices: torch.Tensor | None = None,\n    pad_slot_id: int = -1,\n    out: torch.Tensor | None = None,\n    disable_state_update: bool = False,\n    intermediate_states_buffer: torch.Tensor | None = None,\n    cache_steps: int | None = None,\n    retrieve_parent_token: torch.Tensor | None = None,\n    intermediate_state_indices: torch.Tensor | None = None,\n) -> None:\n    \"\"\"Dispatch selective-state-update to the configured backend.\n\n    This function provides a unified interface regardless of the underlying\n    backend. 
Backend-specific argument adaptation is handled inside each\n    :class:`MambaSSUBackend` subclass.\n\n    Args:\n        state: SSM state tensor (batch, nheads, dim, dstate)\n        x: Input tensor\n        dt: Delta time tensor\n        A: A matrix\n        B: B matrix\n        C: C matrix\n        D: Optional D vector\n        z: Optional z tensor for gating\n        dt_bias: Optional dt bias\n        dt_softplus: Whether to apply softplus to dt\n        state_batch_indices: Optional batch indices for state\n        out: Preallocated output tensor (in-place updated)\n        disable_state_update: If True, don't write back to state (for speculative verify)\n        intermediate_states_buffer: Buffer to cache intermediate states\n        cache_steps: Total number of steps in the buffer\n        retrieve_parent_token: (batch, T) tensor of parent token indices for EAGLE tree attention\n        intermediate_state_indices: (batch,) tensor of indices for intermediate_states_buffer operations.\n            If provided, uses these indices instead of state_batch_indices for the buffer.\n    \"\"\"\n    assert _mamba_ssu_backend is not None, (\n        \"Mamba selective_state_update backend not initialized. \"\n        \"Call initialize_mamba_selective_state_update_backend() first.\"\n    )\n\n    _mamba_ssu_backend(\n        state,\n        x,\n        dt,\n        A,\n        B,\n        C,\n        D=D,\n        z=z,\n        dt_bias=dt_bias,\n        dt_softplus=dt_softplus,\n        state_batch_indices=state_batch_indices,\n        pad_slot_id=pad_slot_id,\n        out=out,\n        disable_state_update=disable_state_update,\n        intermediate_states_buffer=intermediate_states_buffer,\n        cache_steps=cache_steps,\n        retrieve_parent_token=retrieve_parent_token,\n        intermediate_state_indices=intermediate_state_indices,\n    )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/merge_state.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom sgl_kernel import merge_state_v2\n\nfrom sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton\nfrom sglang.srt.utils import is_cuda\n\n_is_cuda = is_cuda()\n\n\n# Automatically fall back to the Triton kernel in some cases\n# (e.g., on non-CUDA platforms such as AMD GPUs, when the head dimension is\n# not a multiple of 4 for FP32 or 8 for FP16/BF16, and for FP8 inputs)\ndef _supported_dtypes(o: torch.Tensor) -> bool:\n    return o.dtype in [torch.float32, torch.half, torch.bfloat16]\n\n\ndef _supported_headdim(o: torch.Tensor) -> bool:\n    headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]\n    if o.dtype == torch.float32:\n        return headdim % 4 == 0\n    return headdim % 8 == 0\n\n\ndef merge_state(\n    prefix_output: torch.Tensor,\n    prefix_lse: torch.Tensor,\n    suffix_output: torch.Tensor,\n    suffix_lse: torch.Tensor,\n    output: Optional[torch.Tensor] = None,\n    output_lse: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n    if (\n        _is_cuda\n        and _supported_dtypes(prefix_output)\n        and _supported_headdim(prefix_output)\n    ):\n        return merge_state_v2(\n            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse\n        )\n    else:\n        # Fall back to the Triton kernel\n        return merge_state_triton(\n            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/nsa/dequant_k_cache.py",
    "content": "import torch\nimport triton\nimport triton.language as tl\n\n\ndef dequantize_k_cache(quant_k_cache):\n    return _dequantize_k_cache_fast_wrapped(quant_k_cache)\n\n\ndef _dequantize_k_cache_ref(\n    quant_k_cache: torch.Tensor,  # (num_blocks, block_size, 1, bytes_per_token)\n    dv: int = 512,\n    tile_size: int = 128,\n    d: int = 576,\n) -> torch.Tensor:\n    \"\"\"\n    De-quantize the k-cache\n    \"\"\"\n    assert dv % tile_size == 0\n    original_ndim = quant_k_cache.ndim\n    if original_ndim == 3:\n        # set block_size = 1\n        quant_k_cache = quant_k_cache.unsqueeze(1)\n    num_tiles = dv // tile_size\n    num_blocks, block_size, h_k, _ = quant_k_cache.shape\n    assert h_k == 1\n    result = torch.empty(\n        (num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device\n    )\n\n    quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)\n\n    input_nope = quant_k_cache[..., :dv]\n    input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)\n    input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)\n    result[..., dv:] = input_rope\n\n    for tile_idx in range(0, num_tiles):\n        cur_nope = input_nope[\n            ..., tile_idx * tile_size : (tile_idx + 1) * tile_size\n        ].to(torch.float32)\n        cur_scales = input_scale[..., tile_idx].unsqueeze(-1)\n        result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (\n            cur_nope * cur_scales\n        )\n\n    if original_ndim == 3:\n        return result.view(num_blocks, 1, -1)\n    else:\n        return result.view(num_blocks, block_size, 1, -1)\n\n\ndef _dequantize_k_cache_fast_wrapped(\n    quant_k_cache: torch.Tensor,\n    dv: int = 512,\n    tile_size: int = 128,\n) -> torch.Tensor:\n    original_ndim = quant_k_cache.ndim\n    if original_ndim == 3:\n        # set block_size = 1\n        quant_k_cache = quant_k_cache.unsqueeze(1)\n    num_blocks, block_size, _, 
dim_quant = quant_k_cache.shape\n    assert dv == 512\n    assert dim_quant == 656\n    assert tile_size == 128\n    quant_k_cache = quant_k_cache.view((-1, dim_quant))\n\n    output = _dequantize_k_cache_fast(quant_k_cache)\n\n    if original_ndim == 3:\n        return output.view(num_blocks, 1, -1)\n    else:\n        return output.view(num_blocks, block_size, 1, -1)\n\n\ndef _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):\n    num_tokens, dim_quant = quant_k_cache.shape\n\n    assert quant_k_cache.dtype == torch.float8_e4m3fn\n    dim_nope = 512\n    dim_rope = 64\n    num_tiles = dim_nope // group_size\n    assert dim_quant == 656\n\n    output = torch.empty(\n        (num_tokens, dim_nope + dim_rope),\n        dtype=torch.bfloat16,\n        device=quant_k_cache.device,\n    )\n\n    num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)\n    assert num_blocks_per_token == 5\n\n    assert dim_nope % group_size == 0\n\n    input_nope_q = quant_k_cache[:, :dim_nope]\n    input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(\n        torch.float32\n    )\n    input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)\n\n    _dequantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](\n        output,\n        input_nope_q,\n        input_nope_s,\n        input_rope,\n        output.stride(0),\n        input_nope_q.stride(0),\n        input_nope_s.stride(0),\n        input_rope.stride(0),\n        NUM_NOPE_BLOCKS=num_tiles,\n        GROUP_SIZE=group_size,\n        DIM_NOPE=dim_nope,\n        DIM_ROPE=dim_rope,\n    )\n\n    return output\n\n\n@triton.jit\ndef _dequantize_k_cache_fast_kernel(\n    output_ptr,\n    input_nope_q_ptr,\n    input_nope_s_ptr,\n    input_rope_ptr,\n    output_stride_0: int,\n    input_nope_q_stride_0: int,\n    input_nope_s_stride_0: int,\n    input_rope_stride_0: int,\n    NUM_NOPE_BLOCKS: tl.constexpr,\n    GROUP_SIZE: tl.constexpr,\n    DIM_NOPE: 
tl.constexpr,\n    DIM_ROPE: tl.constexpr,\n):\n    token_id = tl.program_id(0)\n    raw_block_id = tl.program_id(1)\n\n    if raw_block_id < NUM_NOPE_BLOCKS:\n        # a. dequant nope\n        effective_block_id = raw_block_id\n\n        offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)\n        mask = offs_q < DIM_NOPE\n        ptr_q = input_nope_q_ptr + token_id * input_nope_q_stride_0 + offs_q\n        ptr_s = input_nope_s_ptr + token_id * input_nope_s_stride_0 + effective_block_id\n\n        y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)\n        y_s = tl.load(ptr_s)\n\n        y = (y_q * y_s).to(output_ptr.dtype.element_ty)\n\n        dst_ptr = output_ptr + token_id * output_stride_0 + offs_q\n        tl.store(dst_ptr, y, mask=mask)\n    else:\n        # b. copy rope\n        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS\n\n        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)\n        mask = offs < DIM_ROPE\n\n        src_ptr = input_rope_ptr + token_id * input_rope_stride_0 + offs\n        dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs\n\n        data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)\n        tl.store(dst_ptr, data, mask=mask)\n\n\ndef dequantize_k_cache_paged(\n    quant_k_cache: torch.Tensor,\n    page_table_1_flattened: torch.Tensor,\n    group_size: int = 128,\n) -> torch.Tensor:\n    \"\"\"\n    De-quantize the k-cache with paged layout\n    Args:\n        quant_k_cache: [total_num_tokens, 1, dim_quant] or [num_blocks, block_size, 1, dim_quant], the quantized k-cache in paged layout\n        page_table_1_flattened: [num_tokens], the flattened page_table_1, with the page indices of all requests concatenated together\n    Returns:\n        output: [num_tokens, 1, dim_nope + dim_rope], the de-quantized k-cache\n    \"\"\"\n    dim_quant = quant_k_cache.shape[-1]\n    assert (\n        dim_quant == 656\n    ), f\"dim_quant: {dim_quant} != 656 detected in 
dequantize_k_cache_paged\"\n    quant_k_cache = quant_k_cache.view((-1, dim_quant))\n\n    # num_tokens can exceed kv_cache_size due to prefix sharing (multiple seqs share same KV slots)\n    # Index bounds validated in nsa_backend.init_forward_metadata\n    num_tokens = page_table_1_flattened.shape[0]\n    assert quant_k_cache.dtype == torch.float8_e4m3fn\n    dim_nope = 512\n    dim_rope = 64\n    num_tiles = dim_nope // group_size  # 512 // 128 = 4\n\n    output = torch.empty(\n        (num_tokens, 1, dim_nope + dim_rope),\n        dtype=torch.bfloat16,\n        device=quant_k_cache.device,\n    )\n\n    # cdiv(512 + 64, 128) = 5\n    num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)\n    assert num_blocks_per_token == 5\n\n    assert dim_nope % group_size == 0\n\n    input_nope_q = quant_k_cache[:, :dim_nope]\n    # [:, 512:512+4*4] = [:, 512:528]\n    input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(\n        torch.float32\n    )\n    # [:, 528:]\n    input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)\n\n    _dequantize_k_cache_paged_kernel[(num_tokens, num_blocks_per_token)](\n        output,\n        input_nope_q,\n        input_nope_s,\n        input_rope,\n        page_table_1_flattened,\n        output.stride(0),\n        input_nope_q.stride(0),\n        input_nope_s.stride(0),\n        input_rope.stride(0),\n        NUM_NOPE_BLOCKS=num_tiles,\n        GROUP_SIZE=group_size,\n        DIM_NOPE=dim_nope,\n        DIM_ROPE=dim_rope,\n    )\n\n    return output\n\n\n@triton.jit\ndef _dequantize_k_cache_paged_kernel(\n    output_ptr,\n    input_nope_q_ptr,\n    input_nope_s_ptr,\n    input_rope_ptr,\n    page_table_1_ptr,\n    output_stride_0: int,\n    input_nope_q_stride_0: int,\n    input_nope_s_stride_0: int,\n    input_rope_stride_0: int,\n    NUM_NOPE_BLOCKS: tl.constexpr,\n    GROUP_SIZE: tl.constexpr,\n    DIM_NOPE: tl.constexpr,\n    DIM_ROPE: tl.constexpr,\n):\n    token_id = 
tl.program_id(0)\n    token_id_paged = tl.load(page_table_1_ptr + token_id).to(tl.int32)\n    raw_block_id = tl.program_id(1)\n\n    if raw_block_id < NUM_NOPE_BLOCKS:\n        # a. dequant nope\n        effective_block_id = raw_block_id\n\n        offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)\n        mask = offs_q < DIM_NOPE\n        ptr_q = input_nope_q_ptr + token_id_paged * input_nope_q_stride_0 + offs_q\n        ptr_s = (\n            input_nope_s_ptr\n            + token_id_paged * input_nope_s_stride_0\n            + effective_block_id\n        )\n\n        y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)\n        y_s = tl.load(ptr_s)\n\n        y = (y_q * y_s).to(output_ptr.dtype.element_ty)\n\n        dst_ptr = output_ptr + token_id * output_stride_0 + offs_q\n        tl.store(dst_ptr, y, mask=mask)\n    else:\n        # b. copy rope\n        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS\n\n        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)\n        mask = offs < DIM_ROPE\n\n        src_ptr = input_rope_ptr + token_id_paged * input_rope_stride_0 + offs\n        dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs\n\n        data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)\n        tl.store(dst_ptr, data, mask=mask)\n\n\nif __name__ == \"__main__\":\n    raise Exception(\"UT is in quant_k_cache.py\")\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/nsa/index_buf_accessor.py",
    "content": "from typing import TYPE_CHECKING\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz\nfrom sglang.srt.utils import is_hip\n\n_is_hip = is_hip()\n_is_fp8_fnuz = is_fp8_fnuz()\n\nif TYPE_CHECKING:\n    from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool\n\n\"\"\"\nk: data, 128 item per token, fp8\ns: scale, 1 item per token, fp32\n\"\"\"\n\n\nclass GetK:\n    @classmethod\n    def execute(cls, *args, **kwargs):\n        return cls.triton(*args, **kwargs)\n\n    @classmethod\n    def slow(\n        cls, pool: \"NSATokenToKVPool\", buf, seq_len: int, page_indices: torch.Tensor\n    ):\n        num_pages = (seq_len + pool.page_size - 1) // pool.page_size\n        seq_len_ = num_pages * pool.page_size\n        index_k_fp8 = torch.empty(\n            (seq_len_, pool.index_head_dim),\n            dtype=torch.uint8,\n            device=pool.device,\n        )\n        for i in range(num_pages):\n            page_index = page_indices[i]\n            index_k_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[\n                page_index\n            ][: pool.page_size * pool.index_head_dim].view(-1, pool.index_head_dim)\n\n        return index_k_fp8[:seq_len]\n\n    @classmethod\n    def torch_fast(\n        cls, pool: \"NSATokenToKVPool\", buf, seq_len: int, page_indices: torch.Tensor\n    ):\n        \"\"\"\n        :param page_indices: (num_pages,), int32\n        :return: (seq_len, index_head_dim), uint8\n        \"\"\"\n\n        # can handle per 128B instead of per element\n\n        # page_indices: (num_pages,), element := a page index\n        buf_numel_per_page = buf.shape[1]\n\n        num_k_bytes_per_page = pool.page_size * pool.index_head_dim\n        num_k_bytes_per_token = pool.index_head_dim\n\n        # buf: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4), uint8\n        # flat_buf: (whatever,), uint8\n        flat_buf = 
buf.flatten()\n\n        # flat_indices: (num_pages, num_k_bytes_per_page), int32, element := an index into flat_buf that we want to access\n        flat_indices = (page_indices * buf_numel_per_page)[:, None] + torch.arange(\n            num_k_bytes_per_page, dtype=torch.int32, device=\"cuda\"\n        )[None, :]\n        flat_indices = flat_indices.flatten()[: seq_len * num_k_bytes_per_token]\n\n        out = flat_buf[flat_indices]\n        return out.view(-1, 128)\n\n    @classmethod\n    def triton(\n        cls, pool: \"NSATokenToKVPool\", buf, seq_len: int, page_indices: torch.Tensor\n    ):\n        \"\"\"\n        Triton implementation for gathering K data from paged buffer.\n        :param page_indices: (num_pages,), int32/int64\n        :return: (seq_len, index_head_dim), uint8\n        \"\"\"\n        return _get_k_triton(\n            buf=buf,\n            page_indices=page_indices,\n            seq_len=seq_len,\n            page_size=pool.page_size,\n            index_head_dim=pool.index_head_dim,\n        )\n\n\nclass GetS:\n    @classmethod\n    def execute(cls, *args, **kwargs):\n        return cls.triton(*args, **kwargs)\n\n    @classmethod\n    def slow(\n        cls, pool: \"NSATokenToKVPool\", buf, seq_len: int, page_indices: torch.Tensor\n    ):\n        num_pages = (seq_len + pool.page_size - 1) // pool.page_size\n        seq_len_ = num_pages * pool.page_size\n        assert pool.index_head_dim // pool.quant_block_size == 1\n        index_k_scale_fp8 = torch.empty(\n            (seq_len_, 4),\n            dtype=torch.uint8,\n            device=pool.device,\n        )\n        for i in range(num_pages):\n            page_index = page_indices[i]\n            index_k_scale_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[\n                page_index\n            ][pool.page_size * pool.index_head_dim :].view(-1, 4)\n        return index_k_scale_fp8[:seq_len]\n\n    @classmethod\n    def torch_fast(\n        cls, pool: \"NSATokenToKVPool\", 
buf, seq_len: int, page_indices: torch.Tensor\n    ):\n        \"\"\"\n        :param page_indices: (num_pages,), int32\n        :return: (seq_len, index_head_dim // quant_block_size), uint8\n        \"\"\"\n        buf_numel_per_page = buf.shape[1]\n\n        num_s_bytes_per_page = buf.shape[1] - pool.page_size * pool.index_head_dim\n        num_s_bytes_per_token = pool.index_head_dim // pool.quant_block_size * 4\n        s_offset_in_page = pool.page_size * pool.index_head_dim\n\n        flat_buf = buf.flatten()\n        flat_indices = (\n            (page_indices * buf_numel_per_page)[:, None]\n            + torch.arange(num_s_bytes_per_page, dtype=torch.int32, device=\"cuda\")[\n                None, :\n            ]\n            + s_offset_in_page\n        )\n        flat_indices = flat_indices.flatten()[: seq_len * num_s_bytes_per_token]\n\n        out = flat_buf[flat_indices]\n        return out.view(-1, 4)\n\n    @classmethod\n    def triton(\n        cls, pool: \"NSATokenToKVPool\", buf, seq_len: int, page_indices: torch.Tensor\n    ):\n        \"\"\"\n        Triton implementation for gathering S (scale) data from paged buffer.\n        :param page_indices: (num_pages,), int32/int64\n        :return: (seq_len, 4), uint8\n        \"\"\"\n        return _get_s_triton(\n            buf=buf,\n            page_indices=page_indices,\n            seq_len=seq_len,\n            page_size=pool.page_size,\n            index_head_dim=pool.index_head_dim,\n        )\n\n\nclass GetKAndS:\n    @classmethod\n    def execute(cls, *args, **kwargs):\n        return cls.triton(*args, **kwargs)\n\n    @classmethod\n    def triton(\n        cls,\n        pool: \"NSATokenToKVPool\",\n        buf: torch.Tensor,\n        page_indices: torch.Tensor,\n        seq_len_tensor: torch.Tensor,\n        seq_len_sum: int,\n        max_seq_len: int,\n    ):\n        \"\"\"\n        Triton implementation for gathering both K and S data from paged buffer in a single call.\n        :param 
page_indices: (num_pages,), int32/int64\n        :param seq_len_tensor: (num_pages,), int32/int64\n        :param seq_len_sum: sum of all sequence lengths, int32\n        :param max_seq_len: max of all sequence lengths, int32\n        :return: tuple of (k_fp8, k_scale) where\n                 k_fp8: (seq_len, index_head_dim), uint8\n                 k_scale: (seq_len, 4), uint8\n        \"\"\"\n        return _get_k_and_s_triton(\n            buf=buf,\n            page_indices=page_indices,\n            seq_lens=seq_len_tensor,\n            seq_len_sum=seq_len_sum,\n            max_seq_len=max_seq_len,\n            page_size=pool.page_size,\n            index_head_dim=pool.index_head_dim,\n        )\n\n\nclass SetK:\n    @classmethod\n    def execute(cls, *args, buf, **kwargs):\n        return cls.torch_fast(*args, **kwargs, buf=buf)\n\n    @classmethod\n    def slow(\n        cls,\n        pool: \"NSATokenToKVPool\",\n        buf: torch.Tensor,\n        loc: torch.Tensor,\n        index_k: torch.Tensor,\n    ):\n        for i in range(len(loc)):\n            page_index = loc[i] // pool.page_size\n            offset = loc[i] % pool.page_size\n            buf[\n                page_index,\n                offset * pool.index_head_dim : (offset + 1) * pool.index_head_dim,\n            ] = index_k[i].view(torch.uint8)\n\n    @classmethod\n    def torch_fast(\n        cls,\n        pool: \"NSATokenToKVPool\",\n        buf: torch.Tensor,\n        loc: torch.Tensor,\n        index_k: torch.Tensor,\n    ):\n        (num_tokens_to_write,) = loc.shape\n        buf_numel_per_page = buf.shape[1]\n        num_k_bytes_per_token = pool.index_head_dim\n\n        # loc: (num_tokens_to_write,), int32, element := the token index to write to\n        loc_page_index = loc // pool.page_size\n        loc_token_offset_in_page = loc % pool.page_size\n\n        flat_buf = buf.flatten()\n        flat_indices = (\n            (loc_page_index * buf_numel_per_page)[:, None]\n            + 
(loc_token_offset_in_page * num_k_bytes_per_token)[:, None]\n            + torch.arange(num_k_bytes_per_token, dtype=torch.int32, device=\"cuda\")[\n                None, :\n            ]\n        )\n        num_k_bytes_total = num_tokens_to_write * num_k_bytes_per_token\n        flat_indices = flat_indices.flatten()[:num_k_bytes_total]\n        flat_buf[flat_indices] = index_k.view(torch.uint8).flatten()\n\n\nclass SetS:\n    @classmethod\n    def execute(cls, *args, buf, **kwargs):\n        return cls.torch_fast(*args, **kwargs, buf=buf)\n\n    @classmethod\n    def slow(\n        cls,\n        pool: \"NSATokenToKVPool\",\n        buf: torch.Tensor,\n        loc: torch.Tensor,\n        index_k_scale: torch.Tensor,\n    ):\n        for i in range(len(loc)):\n            page_index = loc[i] // pool.page_size\n            offset = loc[i] % pool.page_size\n            start = pool.page_size * pool.index_head_dim\n            buf[page_index, start + offset * 4 : start + (offset + 1) * 4] = (\n                index_k_scale[i].view(torch.uint8)\n            )\n\n    @classmethod\n    def torch_fast(\n        cls,\n        pool: \"NSATokenToKVPool\",\n        buf: torch.Tensor,\n        loc: torch.Tensor,\n        index_k_scale: torch.Tensor,\n    ):\n        (num_tokens_to_write,) = loc.shape\n        buf_numel_per_page = buf.shape[1]\n        num_s_bytes_per_token = 4\n        s_offset_in_page = pool.page_size * pool.index_head_dim\n\n        # loc: (num_tokens_to_write,), int32, element := the token index to write to\n        loc_page_index = loc // pool.page_size\n        loc_token_offset_in_page = loc % pool.page_size\n\n        flat_buf = buf.flatten()\n        flat_indices = (\n            (loc_page_index * buf_numel_per_page)[:, None]\n            + s_offset_in_page\n            + (loc_token_offset_in_page * num_s_bytes_per_token)[:, None]\n            + torch.arange(num_s_bytes_per_token, dtype=torch.int32, device=\"cuda\")[\n                None, :\n            
]\n        )\n        number_s_bytes_total = num_tokens_to_write * num_s_bytes_per_token\n        flat_indices = flat_indices.flatten()[:number_s_bytes_total]\n        flat_buf[flat_indices] = index_k_scale.view(torch.uint8).flatten()\n\n\nclass SetKAndS:\n    @classmethod\n    def execute(cls, *args, buf, **kwargs):\n        if 0:\n            # print(\"SetK, SetS comparison test\")\n            buf_cloned = buf.clone()\n            cls.vanilla(*args, **kwargs, buf=buf)\n            cls.triton(*args, **kwargs, buf=buf_cloned)\n\n            def _clear_token_0(target):\n                target[0, :128] = target[0, 64 * 128 : 64 * 128 + 4] = 0\n\n            _clear_token_0(buf)\n            _clear_token_0(buf_cloned)\n\n            assert torch.all(\n                buf == buf_cloned\n            ), f\"{buf=} {buf_cloned=} {kwargs['loc'].tolist()=}\"\n            return\n\n        cls.triton(*args, **kwargs, buf=buf)\n\n    @classmethod\n    def vanilla(cls, pool, buf, loc, index_k, index_k_scale):\n        SetK.execute(pool=pool, buf=buf, loc=loc, index_k=index_k)\n        SetS.execute(pool=pool, buf=buf, loc=loc, index_k_scale=index_k_scale)\n\n    @classmethod\n    def triton(cls, pool, buf, loc, index_k, index_k_scale):\n        _set_k_and_s_triton(\n            buf=buf,\n            loc=loc,\n            index_k=index_k,\n            index_k_scale=index_k_scale,\n            page_size=pool.page_size,\n        )\n\n\ndef _set_k_and_s_triton(\n    buf: torch.Tensor,\n    loc: torch.Tensor,\n    index_k: torch.Tensor,\n    index_k_scale: torch.Tensor,\n    page_size: int,\n):\n    \"\"\"\n    :param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8\n    :param loc: (num_tokens_to_write,), int, element := the token index to write to\n    :param index_k: (num_tokens_to_write, 128 elem), fp8\n    :param index_k_scale: (num_tokens_to_write, 1 elem), fp32\n    :return:\n    \"\"\"\n    num_pages, buf_numel_per_page = buf.shape\n    (num_tokens_to_write,) = 
loc.shape\n    num_tokens_to_write_, index_head_dim = index_k.shape\n\n    # Handle both 1D (num_tokens,) and 2D (num_tokens, 1) shapes for index_k_scale\n    if index_k_scale.ndim == 1:\n        num_tokens_to_write__ = index_k_scale.shape[0]\n        scale_dim = 1\n    elif index_k_scale.ndim == 2:\n        num_tokens_to_write__, scale_dim = index_k_scale.shape\n    else:\n        raise ValueError(\n            f\"index_k_scale must be 1D or 2D, got shape {index_k_scale.shape}\"\n        )\n    if _is_hip:\n        assert buf_numel_per_page == 1 * (128 + 4)\n    else:\n        assert buf_numel_per_page == 64 * (128 + 4)\n    assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__\n    assert index_head_dim == 128\n    assert scale_dim == 1\n    if _is_hip:\n        assert page_size == 1\n    else:\n        assert page_size == 64\n\n    assert buf.dtype == torch.uint8\n    assert loc.dtype == torch.int64, f\"{loc.dtype=}\"  # can be int32\n    if _is_fp8_fnuz:\n        assert index_k.dtype == torch.float8_e4m3fnuz\n    else:\n        assert index_k.dtype == torch.float8_e4m3fn\n    assert index_k_scale.dtype == torch.float32\n\n    assert buf.is_contiguous()\n    assert loc.is_contiguous()\n    assert index_k.is_contiguous()\n    assert index_k_scale.is_contiguous()\n\n    if _is_fp8_fnuz:\n        buf_fp8 = buf.view(torch.float8_e4m3fnuz)\n    else:\n        buf_fp8 = buf.view(torch.float8_e4m3fn)\n    buf_fp32 = buf.view(torch.float32)\n\n    _set_k_and_s_triton_kernel[(num_tokens_to_write,)](\n        buf_fp8,\n        buf_fp32,\n        loc,\n        index_k,\n        index_k_scale,\n        index_k.stride(0),\n        PAGE_SIZE=page_size,\n        BUF_NUMEL_PER_PAGE=buf_numel_per_page,\n        NUM_K_ELEMS_PER_TOKEN=index_head_dim,\n        S_OFFSET_NBYTES_IN_PAGE=page_size * index_head_dim,\n    )\n\n\n@triton.jit\ndef _set_k_and_s_triton_kernel(\n    buf_fp8_ptr,\n    buf_fp32_ptr,\n    loc_ptr,\n    index_k_ptr,\n    index_k_scale_ptr,\n 
   index_k_ptr_stride_0,\n    PAGE_SIZE: tl.constexpr,\n    BUF_NUMEL_PER_PAGE: tl.constexpr,\n    NUM_K_ELEMS_PER_TOKEN: tl.constexpr,\n    S_OFFSET_NBYTES_IN_PAGE: tl.constexpr,\n):\n    token_id = tl.program_id(0)\n\n    loc = tl.load(loc_ptr + token_id)\n\n    in_k_offsets = token_id * index_k_ptr_stride_0 + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)\n\n    # no need for `mask`, since we read 128B for k and 4B for scale, both pow of 2\n    k = tl.load(index_k_ptr + in_k_offsets)\n    k_scale = tl.load(index_k_scale_ptr + token_id)\n\n    loc_page_index = loc // PAGE_SIZE\n    loc_token_offset_in_page = loc % PAGE_SIZE\n\n    out_k_offsets = (\n        loc_page_index * BUF_NUMEL_PER_PAGE\n        + loc_token_offset_in_page * NUM_K_ELEMS_PER_TOKEN\n        + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)\n    )\n\n    # \"//4\" b/c it is fp32 instead of uint8\n    out_s_offset = (\n        loc_page_index * BUF_NUMEL_PER_PAGE // 4\n        + S_OFFSET_NBYTES_IN_PAGE // 4\n        + loc_token_offset_in_page\n    )\n\n    tl.store(buf_fp8_ptr + out_k_offsets, k)\n    tl.store(buf_fp32_ptr + out_s_offset, k_scale)\n\n\ndef _get_k_triton(\n    buf: torch.Tensor,\n    page_indices: torch.Tensor,\n    seq_len: int,\n    page_size: int,\n    index_head_dim: int,\n):\n    \"\"\"\n    Gather K (key) data from paged buffer using Triton.\n\n    :param buf: (num_pages, page_size * 128 + page_size * 4), uint8\n    :param page_indices: (num_pages,), int32/int64\n    :param seq_len: int, number of tokens to gather\n    :param page_size: int, typically 64\n    :param index_head_dim: int, typically 128\n    :return: (seq_len, index_head_dim), uint8\n    \"\"\"\n    num_pages, buf_numel_per_page = buf.shape\n\n    # Allocate output\n    out = torch.empty((seq_len, index_head_dim), dtype=torch.uint8, device=buf.device)\n\n    # Launch kernel with one thread per token\n    grid = (seq_len,)\n    _get_k_triton_kernel[grid](\n        buf,\n        page_indices,\n        out,\n        seq_len,\n        
page_size,\n        buf_numel_per_page,\n        index_head_dim,\n        BLOCK_SIZE=128,\n    )\n\n    return out\n\n\n@triton.jit\ndef _get_k_triton_kernel(\n    buf_ptr,\n    page_indices_ptr,\n    out_ptr,\n    seq_len: tl.constexpr,\n    page_size: tl.constexpr,\n    buf_numel_per_page: tl.constexpr,\n    index_head_dim: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    Each program handles one token (seq_len tokens total).\n    Loads 128 bytes from the appropriate page.\n    \"\"\"\n    token_id = tl.program_id(0)\n\n    # Calculate which page and offset within page\n    page_idx = token_id // page_size\n    token_offset_in_page = token_id % page_size\n\n    # Load the page index from page_indices\n    page_index = tl.load(page_indices_ptr + page_idx)\n\n    # Calculate source offset in buf\n    # buf[page_index, token_offset_in_page * index_head_dim : ...]\n    src_base_offset = (\n        page_index * buf_numel_per_page + token_offset_in_page * index_head_dim\n    )\n\n    # Load 128 bytes (index_head_dim elements)\n    offsets = tl.arange(0, BLOCK_SIZE)\n    mask = offsets < index_head_dim\n    data = tl.load(buf_ptr + src_base_offset + offsets, mask=mask)\n\n    # Store to output\n    dst_offset = token_id * index_head_dim\n    tl.store(out_ptr + dst_offset + offsets, data, mask=mask)\n\n\ndef _get_s_triton(\n    buf: torch.Tensor,\n    page_indices: torch.Tensor,\n    seq_len: int,\n    page_size: int,\n    index_head_dim: int,\n):\n    \"\"\"\n    Gather S (scale) data from paged buffer using Triton.\n\n    :param buf: (num_pages, page_size * 128 + page_size * 4), uint8\n    :param page_indices: (num_pages,), int32/int64\n    :param seq_len: int, number of tokens to gather\n    :param page_size: int, typically 64\n    :param index_head_dim: int, typically 128\n    :return: (seq_len, 4), uint8 (representing fp32 scale)\n    \"\"\"\n    num_pages, buf_numel_per_page = buf.shape\n    s_offset_in_page = page_size * index_head_dim  # Scales 
start after K data\n\n    # Allocate output\n    out = torch.empty((seq_len, 4), dtype=torch.uint8, device=buf.device)\n\n    # Launch kernel with one thread per token\n    grid = (seq_len,)\n    _get_s_triton_kernel[grid](\n        buf,\n        page_indices,\n        out,\n        seq_len,\n        page_size,\n        buf_numel_per_page,\n        s_offset_in_page,\n    )\n\n    return out\n\n\n@triton.jit\ndef _get_s_triton_kernel(\n    buf_ptr,\n    page_indices_ptr,\n    out_ptr,\n    seq_len: tl.constexpr,\n    page_size: tl.constexpr,\n    buf_numel_per_page: tl.constexpr,\n    s_offset_in_page: tl.constexpr,\n):\n    \"\"\"\n    Each program handles one token (seq_len tokens total).\n    Loads 4 bytes (fp32 scale) from the appropriate page.\n    \"\"\"\n    token_id = tl.program_id(0)\n\n    # Calculate which page and offset within page\n    page_idx = token_id // page_size\n    token_offset_in_page = token_id % page_size\n\n    # Load the page index from page_indices\n    page_index = tl.load(page_indices_ptr + page_idx)\n\n    # Calculate source offset in buf\n    # Scales are stored after K data: page_size * index_head_dim offset\n    # buf[page_index, s_offset_in_page + token_offset_in_page * 4 : ...]\n    src_base_offset = (\n        page_index * buf_numel_per_page + s_offset_in_page + token_offset_in_page * 4\n    )\n\n    # Load 4 bytes (fp32 scale)\n    offsets = tl.arange(0, 4)\n    data = tl.load(buf_ptr + src_base_offset + offsets)\n\n    # Store to output\n    dst_offset = token_id * 4\n    tl.store(out_ptr + dst_offset + offsets, data)\n\n\ndef _get_k_and_s_triton(\n    buf: torch.Tensor,\n    page_indices: torch.Tensor,\n    seq_lens: torch.Tensor,\n    seq_len_sum: int,\n    max_seq_len: int,\n    page_size: int,\n    index_head_dim: int,\n):\n    \"\"\"\n    Fused gather of both K (key) and S (scale) data from paged buffer using Triton.\n    This is more efficient than calling GetK and GetS separately.\n\n    :param buf: (num_pages, page_size 
* 128 + page_size * 4), uint8\n    :param page_indices: (batch_size, max_num_pages), int32/int64, per-sequence page table\n    :param seq_lens: (batch_size,), int64, number of tokens in each sequence\n    :param seq_len_sum: int, total number of tokens across all sequences\n    :param max_seq_len: int, maximum sequence length in the batch\n    :param page_size: int, typically 64\n    :param index_head_dim: int, typically 128\n    :return: tuple of (k_out, s_out) where\n             k_out: (seq_len_sum, index_head_dim), uint8\n             s_out: (seq_len_sum, 4), uint8\n    \"\"\"\n    # Allocate outputs\n    k_out = torch.empty(\n        (seq_len_sum, index_head_dim), dtype=torch.uint8, device=buf.device\n    )\n    s_out = torch.empty((seq_len_sum, 4), dtype=torch.uint8, device=buf.device)\n\n    _, buf_numel_per_page = buf.shape\n    _, page_indice_batch_offset = page_indices.shape\n    s_offset_in_page = page_size * index_head_dim\n\n    # Each program copies a tile of BLOCK_SIZE tokens x BLOCK_SIZE_K K-bytes of one sequence\n    BLOCK_SIZE = 256\n    BLOCK_SIZE_K = 128\n\n    num_token_blocks = (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE\n    num_k_threads = (index_head_dim + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K\n\n    seq_num = seq_lens.shape[0]\n    grid = (seq_num, num_token_blocks, num_k_threads)\n    # tl.arange needs a power-of-2 extent; the kernel uses it to sum the lengths of earlier sequences\n    seq_num_pow2 = 1\n    while seq_num_pow2 < seq_num:\n        seq_num_pow2 *= 2\n\n    _get_k_and_s_triton_kernel[grid](\n        buf_ptr=buf,\n        page_indices_ptr=page_indices,\n        k_out_ptr=k_out,\n        s_out_ptr=s_out,\n        seq_len_ptr=seq_lens,\n        seq_len_num_pow=seq_num_pow2,\n        page_size=page_size,\n        buf_numel_per_page=buf_numel_per_page,\n        index_head_dim=index_head_dim,\n        s_offset_in_page=s_offset_in_page,\n        page_indice_batch_offset=page_indice_batch_offset,\n        BLOCK_SIZE=BLOCK_SIZE,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n    )\n\n    return k_out, s_out\n\n\n@triton.jit\ndef _get_k_and_s_triton_kernel(\n    buf_ptr,\n    page_indices_ptr,\n    k_out_ptr,\n    s_out_ptr,\n    seq_len_ptr,\n    seq_len_num_pow: tl.constexpr,\n    page_size: 
tl.constexpr,\n    buf_numel_per_page: tl.constexpr,\n    index_head_dim: tl.constexpr,\n    s_offset_in_page: tl.constexpr,\n    page_indice_batch_offset: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n):\n    \"\"\"\n    Fused kernel that gathers both K and S data in a single pass.\n    Each program handles a tile of up to BLOCK_SIZE tokens of one sequence\n    (program_id 0 selects the sequence, program_id 1 the token tile) and a\n    BLOCK_SIZE_K-wide slice of the K dimension (program_id 2).\n    For each valid token it loads index_head_dim bytes of K (typically 128)\n    and the 4-byte fp32 scale (S) from the appropriate page.\n    \"\"\"\n    batch_id = tl.program_id(0)\n    block_token_start = tl.program_id(1) * BLOCK_SIZE\n    thread_idx = tl.program_id(2)\n\n    # Define the token range within the block and the K dimension range handled by this program.\n    token_ids_in_block = tl.arange(0, BLOCK_SIZE)\n    token_ids = block_token_start + token_ids_in_block\n    k_offsets = thread_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n\n    seq_len = tl.load(seq_len_ptr + batch_id)\n    token_valid_mask = token_ids < seq_len\n\n    # Offset of this sequence in the packed output = sum of the lengths of all earlier sequences.\n    # Masked-out lanes must load 0 so they do not contribute to the sum.\n    pre_batch_idx = tl.arange(0, seq_len_num_pow)\n    mask_pre_batch_idx = pre_batch_idx < batch_id\n    prev_seq_lens = tl.load(\n        seq_len_ptr + pre_batch_idx, mask=mask_pre_batch_idx, other=0\n    )\n    batch_token_offset = tl.sum(prev_seq_lens)\n\n    # Compute the page index and in-page offset of every token in the tile.\n    page_idx = token_ids // page_size\n    token_offset_in_page = token_ids % page_size\n    page_indices_base = batch_id * page_indice_batch_offset\n    page_idx_valid_mask = page_idx < page_indice_batch_offset\n    page_index = tl.load(\n        page_indices_ptr + page_idx + page_indices_base,\n        mask=token_valid_mask & page_idx_valid_mask,\n    )\n\n    # ===== Load K data =====\n    # The address calculation logic for K: page_index * total number of elements in a single page + K offset of the token within the page.\n    k_src_token_offset = token_offset_in_page * index_head_dim\n    k_src_base_offset = page_index * buf_numel_per_page + k_src_token_offset\n\n    k_load_addr = buf_ptr + k_src_base_offset[:, None] + 
k_offsets[None, :]\n    k_dim_mask = k_offsets[None, :] < index_head_dim\n    k_mask = token_valid_mask[:, None] & k_dim_mask\n\n    k_data = tl.load(k_load_addr, mask=k_mask, other=0)\n\n    # Store K to output\n    k_dst_token_offset = batch_token_offset + token_ids\n    k_dst_base_offset = k_dst_token_offset * index_head_dim\n    k_store_addr = k_out_ptr + k_dst_base_offset[:, None] + k_offsets[None, :]\n    tl.store(k_store_addr, k_data, mask=k_mask)\n\n    # ===== Load S data =====\n    # The address calculation logic for S: page_index * total number of elements in a single page + starting offset of S within the page + offset of token within S in the page\n    s_src_token_offset = s_offset_in_page + token_offset_in_page * 4\n    s_src_base_offset = page_index * buf_numel_per_page + s_src_token_offset\n\n    s_offsets = tl.arange(0, 4)\n    s_load_addr = buf_ptr + s_src_base_offset[:, None] + s_offsets[None, :]\n    s_mask = token_valid_mask[:, None] & (s_offsets[None, :] < 4)\n    s_data = tl.load(s_load_addr, mask=s_mask, other=0)\n\n    # Store S to output\n    s_dst_token_offset = batch_token_offset + token_ids\n    s_dst_base_offset = s_dst_token_offset * 4\n    s_store_addr = s_out_ptr + s_dst_base_offset[:, None] + s_offsets[None, :]\n    tl.store(s_store_addr, s_data, mask=s_mask)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/nsa/nsa_backend_mtp_precompute.py",
    "content": "\"\"\"Multi-step precompute utilities for Native Sparse Attention backend.\n\nThis module provides optimization utilities for multi-step speculative decoding\nby precomputing shared metadata once and copying it to multiple backend instances.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Optional\n\nimport torch\n\nfrom sglang.srt.layers.attention.nsa.utils import compute_nsa_seqlens\n\nif TYPE_CHECKING:\n    from sglang.srt.model_executor.forward_batch_info import ForwardMode\n    from sglang.srt.speculative.spec_info import SpecInput\n\n\n@dataclass\nclass PrecomputedMetadata:\n    \"\"\"Precomputed metadata shared across multiple backend instances.\n\n    Used for multi-step speculative decoding where multiple backends\n    need identical metadata. Precomputing once and copying N times\n    is much faster than computing N times.\n\n    \"\"\"\n\n    # Basic seqlens\n    cache_seqlens: torch.Tensor  # int32, [bs]\n    cu_seqlens_k: torch.Tensor  # int32, [bs+1]\n\n    # Page table\n    page_indices: torch.Tensor  # int32, [bs, max_len] or [expanded_bs, max_len]\n    real_page_table: Optional[torch.Tensor]  # int32, transformed version\n\n    # NSA seqlens\n    seqlens_expanded: torch.Tensor  # int32, [expanded_size]\n    nsa_cache_seqlens: torch.Tensor  # int32, [expanded_size]\n    nsa_cu_seqlens_k: torch.Tensor  # int32, [expanded_size+1]\n    seqlens_expanded_size: int\n\n    # Dimensions\n    max_len: int  # for decode/draft_extend\n    max_seqlen_k: int  # for target_verify\n\n    # FlashMLA (optional)\n    flashmla_metadata: Optional[torch.Tensor] = None\n\n\ndef compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:\n    \"\"\"Compute cumulative sequence lengths with padding.\"\"\"\n    assert seqlens.dtype == torch.int32\n    return torch.nn.functional.pad(\n        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)\n    )\n\n\nclass 
NativeSparseAttnBackendMTPPrecomputeMixin:\n    \"\"\"Mixin class providing metadata precomputation for multi-step speculative decoding.\n\n    This mixin provides the _precompute_replay_metadata method and its helpers,\n    which are used to optimize CUDA graph replay in multi-step scenarios.\n    \"\"\"\n\n    def _precompute_replay_metadata(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: torch.Tensor,\n        forward_mode: \"ForwardMode\",\n        spec_info: Optional[\"SpecInput\"],\n    ) -> PrecomputedMetadata:\n        \"\"\"Precompute all shared metadata for multi-step backends.\n\n        This function extracts and computes all operations that are\n        identical across different backend instances in multi-step\n        speculative decoding.\n\n        Args:\n            bs: Batch size\n            req_pool_indices: Request pool indices [bs]\n            seq_lens: Sequence lengths [bs]\n            seq_lens_cpu: Sequence lengths on CPU [bs]\n            forward_mode: Forward mode (decode/target_verify/draft_extend)\n            spec_info: Speculative decoding info (for draft_extend mode)\n\n        Returns:\n            PrecomputedMetadata containing all shared intermediate results\n        \"\"\"\n        # Slice inputs to batch size\n        seq_lens = seq_lens[:bs]\n        seq_lens_cpu = seq_lens_cpu[:bs]\n        req_pool_indices = req_pool_indices[:bs]\n\n        # Dispatch to mode-specific precomputation\n        if forward_mode.is_decode_or_idle():\n            return self._precompute_decode_mode(\n                bs, req_pool_indices, seq_lens, seq_lens_cpu\n            )\n        elif forward_mode.is_target_verify():\n            return self._precompute_target_verify_mode(\n                bs, req_pool_indices, seq_lens, seq_lens_cpu\n            )\n        elif forward_mode.is_draft_extend():\n            return self._precompute_draft_extend_mode(\n            
    bs, req_pool_indices, seq_lens, seq_lens_cpu, spec_info\n            )\n        else:\n            raise ValueError(f\"Unsupported forward mode: {forward_mode}\")\n\n    def _precompute_decode_mode(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: torch.Tensor,\n    ) -> PrecomputedMetadata:\n        \"\"\"Precompute metadata for normal decode mode.\"\"\"\n        max_len = int(seq_lens_cpu.max().item())\n\n        # Convert to int32 and compute cumsum\n        cache_seqlens = seq_lens.to(torch.int32)\n        cu_seqlens_k = compute_cu_seqlens(cache_seqlens)\n\n        # Get page indices from cache\n        page_indices = self.req_to_token[req_pool_indices, :max_len].contiguous()\n\n        # Compute NSA seqlens\n        nsa_cache_seqlens = compute_nsa_seqlens(\n            cache_seqlens, nsa_index_topk=self.nsa_index_topk\n        )\n        seqlens_expanded = cache_seqlens\n        seqlens_expanded_size = seqlens_expanded.shape[0]\n\n        # Compute NSA cumsum\n        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens)\n\n        # Transform page table if needed\n        if self.real_page_size > 1:\n            real_page_table = self._transform_table_1_to_real(page_indices)\n        else:\n            real_page_table = None  # Will use page_indices directly\n\n        # Compute FlashMLA metadata if needed\n        flashmla_metadata = None\n        if self.nsa_decode_impl == \"flashmla_kv\":\n            flashmla_metadata = self._compute_flashmla_metadata(\n                cache_seqlens=nsa_cache_seqlens,\n                seq_len_q=1,\n            )\n\n        return PrecomputedMetadata(\n            cache_seqlens=cache_seqlens,\n            cu_seqlens_k=cu_seqlens_k,\n            page_indices=page_indices,\n            real_page_table=real_page_table,\n            seqlens_expanded=seqlens_expanded,\n            nsa_cache_seqlens=nsa_cache_seqlens,\n            
nsa_cu_seqlens_k=nsa_cu_seqlens_k,\n            seqlens_expanded_size=seqlens_expanded_size,\n            max_len=max_len,\n            max_seqlen_k=max_len,\n            flashmla_metadata=flashmla_metadata,\n        )\n\n    def _precompute_target_verify_mode(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: torch.Tensor,\n    ) -> PrecomputedMetadata:\n        \"\"\"Precompute metadata for target verify mode.\"\"\"\n        max_seqlen_k = int(\n            seq_lens_cpu.max().item() + self.speculative_num_draft_tokens\n        )\n\n        # Cache seqlens with draft tokens\n        cache_seqlens = (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)\n        cu_seqlens_k = compute_cu_seqlens(cache_seqlens)\n\n        # Page indices (repeated for each draft token)\n        page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]\n        page_indices = torch.repeat_interleave(\n            page_indices, repeats=self.speculative_num_draft_tokens, dim=0\n        ).contiguous()\n\n        # Generate expanded seqlens\n        extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs\n        seqlens_int32_cpu = [\n            self.speculative_num_draft_tokens + kv_len\n            for kv_len in seq_lens_cpu.tolist()\n        ]\n        seqlens_expanded = torch.cat(\n            [\n                torch.arange(\n                    kv_len - qo_len + 1,\n                    kv_len + 1,\n                    dtype=torch.int32,\n                    device=self.device,\n                )\n                for qo_len, kv_len in zip(\n                    extend_seq_lens_cpu,\n                    seqlens_int32_cpu,\n                    strict=True,\n                )\n            ]\n        )\n\n        # Compute NSA seqlens\n        nsa_cache_seqlens = compute_nsa_seqlens(seqlens_expanded, self.nsa_index_topk)\n        seqlens_expanded_size = seqlens_expanded.shape[0]\n\n    
    # NSA cumsum\n        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens)\n\n        # Transform page table\n        if self.real_page_size > 1:\n            real_page_table = self._transform_table_1_to_real(page_indices)\n        else:\n            real_page_table = None\n\n        # FlashMLA metadata\n        flashmla_metadata = None\n        if self.nsa_decode_impl == \"flashmla_kv\":\n            flashmla_metadata = self._compute_flashmla_metadata(\n                cache_seqlens=nsa_cache_seqlens,\n                seq_len_q=1,\n            )\n\n        return PrecomputedMetadata(\n            cache_seqlens=cache_seqlens,\n            cu_seqlens_k=cu_seqlens_k,\n            page_indices=page_indices,\n            real_page_table=real_page_table,\n            seqlens_expanded=seqlens_expanded,\n            nsa_cache_seqlens=nsa_cache_seqlens,\n            nsa_cu_seqlens_k=nsa_cu_seqlens_k,\n            seqlens_expanded_size=seqlens_expanded_size,\n            max_len=-1,  # Not used in this mode\n            max_seqlen_k=max_seqlen_k,\n            flashmla_metadata=flashmla_metadata,\n        )\n\n    def _precompute_draft_extend_mode(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_cpu: torch.Tensor,\n        spec_info: \"SpecInput\",\n    ) -> PrecomputedMetadata:\n        \"\"\"Precompute metadata for draft extend mode.\"\"\"\n        max_seqlen_k = int(seq_lens_cpu.max().item())\n\n        # Cache seqlens\n        cache_seqlens = seq_lens.to(torch.int32)\n        cu_seqlens_k = compute_cu_seqlens(cache_seqlens)\n\n        # Extend seqlens from spec_info\n        extend_seq_lens = spec_info.accept_length[:bs]\n        extend_seq_lens_cpu = extend_seq_lens.tolist()\n\n        # Page indices (repeated per accept length)\n        page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]\n        page_indices = torch.repeat_interleave(\n            page_indices, 
repeats=extend_seq_lens, dim=0\n        ).contiguous()\n\n        # Generate expanded seqlens\n        seqlens_expanded = torch.cat(\n            [\n                torch.arange(\n                    kv_len - qo_len + 1,\n                    kv_len + 1,\n                    dtype=torch.int32,\n                    device=self.device,\n                )\n                for qo_len, kv_len in zip(\n                    extend_seq_lens_cpu,\n                    seq_lens_cpu.tolist(),\n                    strict=True,\n                )\n            ]\n        )\n\n        # Compute NSA seqlens\n        nsa_cache_seqlens = compute_nsa_seqlens(seqlens_expanded, self.nsa_index_topk)\n        seqlens_expanded_size = seqlens_expanded.shape[0]\n\n        # NSA cumsum\n        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens)\n\n        # Transform page table\n        if self.real_page_size > 1:\n            real_page_table = self._transform_table_1_to_real(page_indices)\n        else:\n            real_page_table = None\n\n        # FlashMLA metadata\n        flashmla_metadata = None\n        if self.nsa_decode_impl == \"flashmla_kv\":\n            flashmla_metadata = self._compute_flashmla_metadata(\n                cache_seqlens=nsa_cache_seqlens,\n                seq_len_q=1,\n            )\n\n        return PrecomputedMetadata(\n            cache_seqlens=cache_seqlens,\n            cu_seqlens_k=cu_seqlens_k,\n            page_indices=page_indices,\n            real_page_table=real_page_table,\n            seqlens_expanded=seqlens_expanded,\n            nsa_cache_seqlens=nsa_cache_seqlens,\n            nsa_cu_seqlens_k=nsa_cu_seqlens_k,\n            seqlens_expanded_size=seqlens_expanded_size,\n            max_len=max_seqlen_k,\n            max_seqlen_k=max_seqlen_k,\n            flashmla_metadata=flashmla_metadata,\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/nsa/nsa_indexer.py",
    "content": "from __future__ import annotations\n\nimport contextlib\nfrom abc import ABC, abstractmethod\nfrom typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple\n\nimport torch\nfrom einops import rearrange\n\nfrom sglang.jit_kernel.fused_store_index_cache import (\n    can_use_nsa_fused_store,\n    fused_store_index_k_cache,\n)\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.dp_attention import attn_tp_all_gather_into_tensor\nfrom sglang.srt.layers.layernorm import LayerNorm\nfrom sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz\nfrom sglang.srt.layers.utils import MultiPlatformOp\nfrom sglang.srt.utils import add_prefix, ceil_align, is_cuda, is_hip, is_npu\n\nglobal _use_multi_stream\n_is_cuda = is_cuda()\n_is_hip = is_hip()\n_is_npu = is_npu()\n_is_fp8_fnuz = is_fp8_fnuz()\nif _is_cuda:\n    try:\n        import deep_gemm\n    except ImportError as e:\n        deep_gemm = e\n\nif is_npu():\n    import torch_npu\n    from sglang.srt.hardware_backend.npu.utils import get_indexer_weight_stream\n\nfrom sglang.srt.distributed import (\n    get_attn_context_model_parallel_rank,\n    get_attn_context_model_parallel_world_size,\n)\nfrom sglang.srt.distributed.parallel_state import get_pp_group\nfrom sglang.srt.layers import deep_gemm_wrapper\nfrom sglang.srt.layers.attention.nsa.utils import (\n    cp_all_gather_rerange_output,\n    is_nsa_enable_prefill_cp,\n    is_nsa_prefill_cp_in_seq_split,\n)\nfrom sglang.srt.layers.communicator import ScatterMode\nfrom sglang.srt.layers.linear import ReplicatedLinear\nfrom sglang.srt.layers.quantization.base_config import QuantizationConfig\nfrom sglang.srt.layers.rotary_embedding import get_rope_wrapper\nfrom sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\nfrom sglang.srt.server_args import get_global_server_args\n\n_use_ag_after_qlora = envs.SGLANG_USE_AG_AFTER_QLORA.get()\nif TYPE_CHECKING:\n   
 from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool\n\n\nDUAL_STREAM_TOKEN_THRESHOLD = 1024 if _is_cuda else 0\n\n\nclass BaseIndexerMetadata(ABC):\n    @abstractmethod\n    def get_seqlens_int32(self) -> torch.Tensor:\n        \"\"\"\n        Return: (batch_size,) int32 tensor\n        \"\"\"\n\n    @abstractmethod\n    def get_page_table_64(self) -> torch.Tensor:\n        \"\"\"\n        Return: (batch_size, num_blocks) int32, page table.\n                The page size of the table is 64.\n        \"\"\"\n\n    @abstractmethod\n    def get_page_table_1(self) -> torch.Tensor:\n        \"\"\"\n        Return: (batch_size, num_blocks) int32, page table.\n                The page size of the table is 1.\n        \"\"\"\n\n    @abstractmethod\n    def get_seqlens_expanded(self) -> torch.Tensor:\n        \"\"\"\n        Return: (sum_extend_seq_len,) int32 tensor\n        \"\"\"\n\n    def get_indexer_kvcache_range(self) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Return: (tokens, ), (tokens, ) int32, k_start and k_end in kv cache(token,xxx) for each token.\n        \"\"\"\n\n    def get_indexer_seq_len_cpu(self) -> torch.Tensor:\n        \"\"\"\n        Return: seq lens for each batch.\n        \"\"\"\n\n    def get_indexer_seq_len(self) -> torch.Tensor:\n        \"\"\"\n        Return: seq lens for each batch.\n        \"\"\"\n\n    def get_nsa_extend_len_cpu(self) -> List[int]:\n        \"\"\"\n        Return: extend seq lens for each batch.\n        \"\"\"\n\n    def get_token_to_batch_idx(self) -> torch.Tensor:\n        \"\"\"\n        Return: batch idx for each token.\n        \"\"\"\n\n    @abstractmethod\n    def topk_transform(\n        self,\n        logits: torch.Tensor,\n        topk: int,\n    ) -> torch.Tensor:\n        \"\"\"\n        Perform topk selection on the logits and possibly transform the result.\n\n        NOTE that attention backend may override this function to do some\n        transformation, which means the 
result of this topk_transform may not\n        be the topk indices of the input logits.\n\n        Return: Anything, since it will be passed to the attention backend\n                for further processing on sparse attention computation.\n                Don't assume it is the topk indices of the input logits.\n        \"\"\"\n\n\ndef rotate_activation(x: torch.Tensor) -> torch.Tensor:\n    assert x.dtype == torch.bfloat16\n    # from sgl_kernel import hadamard_transform\n    if _is_hip:\n        from fast_hadamard_transform import hadamard_transform\n    else:\n        from sglang.jit_kernel.hadamard import hadamard_transform\n\n    hidden_size = x.size(-1)\n    assert (\n        hidden_size & (hidden_size - 1)\n    ) == 0, \"Hidden size must be a power of 2 for Hadamard transform.\"\n    return hadamard_transform(x, scale=hidden_size**-0.5)\n\n\nclass Indexer(MultiPlatformOp):\n    def __init__(\n        self,\n        hidden_size: int,\n        index_n_heads: int,\n        index_head_dim: int,\n        rope_head_dim: int,\n        index_topk: int,\n        q_lora_rank: int,\n        max_position_embeddings: int,\n        rope_theta: float,\n        layer_id: int,\n        scale_fmt: Optional[str],\n        block_size: int = 128,\n        rope_scaling: Optional[Dict[str, Any]] = None,\n        is_neox_style: bool = True,\n        prefix: str = \"\",\n        quant_config: Optional[QuantizationConfig] = None,\n        alt_stream: Optional[torch.cuda.Stream] = None,\n    ):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.n_heads = index_n_heads\n        self.head_dim = index_head_dim\n        self.rope_head_dim = rope_head_dim\n        self.index_topk = index_topk\n        self.q_lora_rank = q_lora_rank\n        self.layer_id = layer_id\n        self.alt_stream = alt_stream\n        self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp()\n        if self.nsa_enable_prefill_cp:\n            self.cp_size = 
get_attn_context_model_parallel_world_size()\n            self.cp_rank = get_attn_context_model_parallel_rank()\n        else:\n            self.cp_size = None\n            self.cp_rank = None\n        if _is_cuda:\n            self.sm_count = deep_gemm.get_num_sms()\n            self.half_device_sm_count = ceil_align(self.sm_count // 2, 8)\n            pp_size = get_global_server_args().pp_size\n            self.logits_with_pp_recv = pp_size > 1 and not get_pp_group().is_last_rank\n        else:\n            self.logits_with_pp_recv = False\n\n        self.wq_b = ReplicatedLinear(\n            self.q_lora_rank,\n            self.n_heads * self.head_dim,\n            bias=False,\n            quant_config=quant_config,\n            prefix=add_prefix(\"wq_b\", prefix),\n        )\n\n        self.wk = ReplicatedLinear(\n            self.hidden_size,\n            self.head_dim,\n            bias=False,\n            quant_config=quant_config,\n            prefix=add_prefix(\"wk\", prefix),\n        )\n        self.weights_proj = ReplicatedLinear(\n            self.hidden_size,\n            self.n_heads,\n            bias=False,\n            params_dtype=torch.bfloat16 if _is_cuda else torch.float32,\n            prefix=add_prefix(\"weights_proj\", prefix),\n        )\n        self.k_norm = LayerNorm(self.head_dim, dtype=torch.float32)\n        self.rotary_emb = get_rope_wrapper(\n            rope_head_dim,\n            rotary_dim=rope_head_dim,\n            max_position=max_position_embeddings,\n            base=rope_theta,  # type: ignore\n            rope_scaling=rope_scaling,\n            is_neox_style=is_neox_style,\n            device=get_global_server_args().device,\n        )\n        self.block_size = block_size\n        self.scale_fmt = scale_fmt\n        self.softmax_scale = self.head_dim**-0.5\n\n    @contextlib.contextmanager\n    def _with_real_sm_count(self):\n        # When pipeline parallelism is enabled, each PP rank initiates a recv operation after the 
_pp_launch_batch\n        # request to receive the PP proxy tensor or output from the previous stage, occupying one SM resource.\n        # Model execution runs in parallel with the recv operation, so the SMs available to the indexer must be reduced\n        # by 1. Currently, the last rank starts the send result + recv request only after waiting for execution results.\n        if self.logits_with_pp_recv:\n            pp_recv_sm_count = 1\n            with deep_gemm_wrapper.configure_deep_gemm_num_sms(\n                self.sm_count - pp_recv_sm_count\n            ):\n                yield\n        else:\n            yield\n\n    def _weights_proj_bf16_in_fp32_out(self, x: torch.Tensor) -> torch.Tensor:\n        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:\n            weight = self.weights_proj.weight\n            out = torch.empty(\n                (x.shape[0], weight.shape[0]),\n                dtype=torch.float32,\n                device=x.device,\n            )\n            deep_gemm_wrapper.gemm_nt_bf16bf16f32(x, weight, out)\n            return out\n\n        if _is_hip:\n            x = x.to(self.weights_proj.weight.dtype)\n        weights, _ = self.weights_proj(x)\n        return weights.float()\n\n    @torch.compile(dynamic=True) if not _is_hip else lambda f: f\n    def _project_and_scale_head_gates(self, x: torch.Tensor):\n        weights = self._weights_proj_bf16_in_fp32_out(x)\n        weights = weights * self.n_heads**-0.5\n        return weights\n\n    @torch.compile(dynamic=True) if not _is_hip else lambda f: f\n    def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):\n        weights = self._weights_proj_bf16_in_fp32_out(x)\n        weights = weights * self.n_heads**-0.5\n        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale\n        return weights\n\n    def _get_q_k_bf16(\n        self,\n        q_lora: torch.Tensor,\n        x: torch.Tensor,\n        positions: torch.Tensor,\n        enable_dual_stream: bool,\n 
       forward_batch: ForwardBatch,\n    ):\n        if enable_dual_stream:\n            current_stream = torch.cuda.current_stream()\n            self.alt_stream.wait_stream(current_stream)\n\n            with deep_gemm_wrapper.configure_deep_gemm_num_sms(\n                self.half_device_sm_count\n            ):\n                query, _ = self.wq_b(q_lora)\n                query = rearrange(query, \"l (h d) -> l h d\", d=self.head_dim)\n                q_rope, _ = torch.split(\n                    query,\n                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],\n                    dim=-1,\n                )\n            with torch.cuda.stream(self.alt_stream):\n                # TODO we should also put DeepGEMM half SM here?\n                key, _ = self.wk(x)\n                key = self.k_norm(key)\n\n                k_rope, _ = torch.split(\n                    key,\n                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],\n                    dim=-1,\n                )\n\n            current_stream.wait_stream(self.alt_stream)\n        else:\n            query, _ = self.wq_b(q_lora)\n            query = rearrange(query, \"l (h d) -> l h d\", d=self.head_dim)\n            q_rope, _ = torch.split(\n                query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1\n            )\n            key, _ = self.wk(x)\n            key = self.k_norm(key)\n            k_rope, _ = torch.split(\n                key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1\n            )\n\n        q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)\n\n        query[..., : self.rope_head_dim] = q_rope.clone()\n        key[..., : self.rope_head_dim] = k_rope.clone()\n\n        if enable_dual_stream:\n            current_stream = torch.cuda.current_stream()\n            self.alt_stream.wait_stream(current_stream)\n            query = rotate_activation(query)\n\n            with 
torch.cuda.stream(self.alt_stream):\n                key = rotate_activation(key)\n            current_stream.wait_stream(self.alt_stream)\n        else:\n            query = rotate_activation(query)\n            key = rotate_activation(key)\n\n        # all-gather + rearrange\n        if forward_batch.nsa_cp_metadata is not None and self.nsa_enable_prefill_cp:\n            key = cp_all_gather_rerange_output(\n                key.contiguous(),\n                self.cp_size,\n                forward_batch,\n                torch.cuda.current_stream(),\n            )\n        return query, key\n\n    def _get_k_bf16(\n        self,\n        x: torch.Tensor,\n        positions: torch.Tensor,\n        enable_dual_stream: bool,\n    ):\n        # Compute only key, skip query\n        key, _ = self.wk(x)\n        key = self.k_norm(key)\n        k_rope, _ = torch.split(\n            key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1\n        )\n\n        _, k_rope = self.rotary_emb(positions, k_rope, k_rope)\n        key[..., : self.rope_head_dim] = k_rope.clone()\n        key = rotate_activation(key)\n\n        return key\n\n    def _get_topk_paged(\n        self,\n        forward_batch: ForwardBatch,\n        layer_id: int,\n        q_fp8: torch.Tensor,\n        weights: torch.Tensor,\n        metadata: BaseIndexerMetadata,\n    ) -> torch.Tensor:\n        if TYPE_CHECKING:\n            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)\n\n        page_size = forward_batch.token_to_kv_pool.page_size\n        # NOTE(dark): blocksize = 64 is hardcoded in deep_gemm\n        if _is_hip:\n            assert page_size == 1, \"only support page size 1\"\n            block_tables = metadata.get_page_table_1()\n        else:\n            assert page_size == 64, \"only support page size 64\"\n            # NOTE(dark): this supports extend/decode/decode+graph\n            block_tables = metadata.get_page_table_64()\n\n        max_seq_len = 
block_tables.shape[1] * page_size\n        kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(\n            layer_id=layer_id\n        )\n\n        blocksize = page_size\n        if (\n            forward_batch.forward_mode.is_target_verify()\n            or forward_batch.forward_mode.is_draft_extend(include_v2=True)\n        ):\n            seqlens_32 = metadata.get_seqlens_expanded()\n        else:\n            seqlens_32 = metadata.get_seqlens_int32()\n        # Reuse pre-computed schedule metadata if available (from init_forward_metadata),\n        # otherwise fall back to computing it here.\n        schedule_metadata = getattr(metadata, \"paged_mqa_schedule_metadata\", None)\n        if _is_cuda:\n            if schedule_metadata is None:\n                schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(\n                    seqlens_32, blocksize, self.sm_count\n                )\n\n        assert len(q_fp8.shape) == 3\n        q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now\n        assert len(kv_cache_fp8.shape) == 2\n        block_kv = 1 if _is_hip else 64\n        num_heads_kv = 1\n        head_dim_with_sf = 132\n        if _is_hip:\n            kv_cache_fp8 = kv_cache_fp8.view(\n                -1, block_kv, num_heads_kv, head_dim_with_sf\n            )\n        else:\n            kv_cache_fp8 = kv_cache_fp8.view(\n                kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf\n            )\n        assert len(weights.shape) == 3\n        weights = weights.squeeze(2)\n\n        # When attn_tp_size > 1 or in the MAX_LEN padding mode, padding may exist in the hidden states,\n        # and it is necessary to extract the actual q length.\n        q_offset = sum(metadata.get_nsa_extend_len_cpu())\n        if _is_hip:\n            from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits\n\n            batch_size, next_n, heads, _ = q_fp8.shape\n            logits = torch.full(\n         
       (batch_size * next_n, max_seq_len),\n                float(\"-inf\"),\n                device=q_fp8.device,\n                dtype=torch.float32,\n            )\n            deepgemm_fp8_paged_mqa_logits(\n                q_fp8,\n                kv_cache_fp8,\n                weights,\n                logits,\n                seqlens_32,\n                block_tables,\n                max_seq_len,\n                Preshuffle=False,\n                KVBlockSize=block_kv,\n                ChunkK=128,\n                TotalCuCount=256,\n                WavePerEU=5,\n            )\n        else:\n            logits = deep_gemm.fp8_paged_mqa_logits(\n                q_fp8[:q_offset],\n                kv_cache_fp8,\n                weights[:q_offset],\n                seqlens_32,\n                block_tables,\n                schedule_metadata,\n                max_seq_len,\n                clean_logits=False,\n            )\n\n        # NOTE(dark): logits should be cleaned in topk_transform\n        topk_result = metadata.topk_transform(logits, self.index_topk)\n        # Restore any padding that exists in the hidden states.\n        if not _is_hip and q_offset < q_fp8.shape[0]:\n            pad_len = q_fp8.shape[0] - q_offset\n            padding = torch.full(\n                (pad_len, topk_result.shape[1]),\n                -1,\n                dtype=topk_result.dtype,\n                device=topk_result.device,\n            )\n            topk_result = torch.cat([topk_result, padding], dim=0)\n        return topk_result\n\n    def _should_chunk_mqa_logits(\n        self, num_q: int, num_k: int, device: torch.device\n    ) -> Tuple[bool, int]:\n        \"\"\"\n        Detect whether we need to chunk the MQA logits computation to avoid OOM\n        Return: (need_chunk, free_mem)\n        \"\"\"\n        # Quick static check for normal batches\n        if num_q * num_k < 8_000_000:  # 8M elements ≈ 32MB logits\n            return False, 0\n\n        free_mem, 
total_mem = torch.cuda.mem_get_info(device)\n        bytes_per_elem = 4  # float32\n        logits_bytes = num_q * num_k * bytes_per_elem\n\n        # Logits should not exceed 50% of free memory or 30% of total memory\n        need_chunk = (logits_bytes * 2 > free_mem) or (logits_bytes > total_mem * 0.3)\n        return need_chunk, free_mem\n\n    def _get_topk_ragged(\n        self,\n        enable_dual_stream: bool,\n        forward_batch: ForwardBatch,\n        layer_id: int,\n        q_fp8: torch.Tensor,\n        weights: torch.Tensor,\n        metadata: BaseIndexerMetadata,\n    ) -> torch.Tensor:\n        if TYPE_CHECKING:\n            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)\n\n        assert forward_batch.forward_mode.is_extend_without_speculative()\n\n        page_size = forward_batch.token_to_kv_pool.page_size\n        if _is_hip:\n            assert page_size == 1, \"only support page size 1\"\n        else:\n            assert page_size == 64, \"only support page size 64\"\n\n        assert len(weights.shape) == 3\n        assert (\n            forward_batch.seq_lens_cpu is not None\n            and forward_batch.extend_seq_lens_cpu is not None\n        )\n        weights = weights.squeeze(-1)\n\n        if _is_hip:\n            block_tables = metadata.get_page_table_1()\n        else:\n            block_tables = metadata.get_page_table_64()\n\n        assert (\n            forward_batch.seq_lens_cpu is not None\n            and forward_batch.extend_seq_lens_cpu is not None\n        )\n\n        batch_size = len(block_tables)\n        token_nums, _, _ = q_fp8.shape\n        device = q_fp8.device\n\n        topk_result = torch.full(\n            (token_nums, self.index_topk), -1, device=device, dtype=torch.int32\n        )\n        if batch_size == 0:\n            return topk_result\n\n        ks, ke = metadata.get_indexer_kvcache_range()\n\n        indexer_seq_lens_cpu = metadata.get_indexer_seq_len_cpu()\n        seq_len_sum 
= torch.sum(indexer_seq_lens_cpu).item()\n        max_seq_len = torch.max(indexer_seq_lens_cpu).item()\n        k_fp8, k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_buffer(\n            layer_id,\n            metadata.get_indexer_seq_len(),\n            block_tables,\n            seq_len_sum,\n            max_seq_len,\n        )\n        if _is_fp8_fnuz:\n            k_fp8 = k_fp8.view(torch.float8_e4m3fnuz)\n        else:\n            k_fp8 = k_fp8.view(torch.float8_e4m3fn)\n\n        k_scale = k_scale.view(torch.float32).squeeze(-1)\n        kv_fp8 = (k_fp8, k_scale)\n\n        # Check if we need to chunk to avoid OOM\n        seq_lens_expanded = metadata.get_seqlens_expanded()\n        token_to_batch_idx = metadata.get_token_to_batch_idx()\n        q_offset = ks.shape[0]\n        k_offset = k_fp8.shape[0]\n        need_chunk, free_mem = self._should_chunk_mqa_logits(q_offset, k_offset, device)\n\n        if not need_chunk:\n            assert q_fp8[:q_offset].shape[0] != 0\n            with self._with_real_sm_count():\n                if _is_hip:\n                    from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits\n\n                    kv, scale = kv_fp8\n                    logits = fp8_mqa_logits(\n                        q_fp8[:q_offset], kv, scale, weights[:q_offset], ks, ke\n                    )\n                else:\n                    logits = deep_gemm.fp8_mqa_logits(\n                        q_fp8[:q_offset],\n                        kv_fp8,\n                        weights[:q_offset],\n                        ks,\n                        ke,\n                        clean_logits=False,\n                    )\n            assert logits.shape[0] == len(seq_lens_expanded)\n            assert logits.shape[1] == k_offset\n\n            raw_topk_result = metadata.topk_transform(logits, self.index_topk, ks=ks)\n            topk_result[:q_offset] = raw_topk_result\n            return topk_result\n\n        # Chunk path\n        
bytes_per_elem = 4  # float32\n        bytes_per_row = k_offset * bytes_per_elem\n        # Reserve 50% of free memory for logits\n        max_rows = max(1, int((free_mem * 0.5) // max(bytes_per_row, 1)))\n        max_rows = min(max_rows, q_offset)\n\n        global_topk_offset = metadata.attn_metadata.topk_indices_offset\n\n        assert (\n            seq_lens_expanded.shape[0] == q_offset\n        ), f\"seq_lens_expanded length mismatch: {seq_lens_expanded.shape[0]} != {q_offset}\"\n        if global_topk_offset is not None:\n            assert (\n                global_topk_offset.shape[0] >= q_offset\n            ), f\"topk_indices_offset too short: {global_topk_offset.shape[0]} < {q_offset}\"\n\n        start = 0\n        while start < q_offset:\n            end = min(start + max_rows, q_offset)\n\n            with self._with_real_sm_count():\n                if _is_hip:\n                    from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits\n\n                    kv, scale = kv_fp8\n                    logits_chunk = fp8_mqa_logits(\n                        q_fp8[start:end],\n                        kv,\n                        scale,\n                        weights[start:end],\n                        ks[start:end],\n                        ke[start:end],\n                    )\n                else:\n                    logits_chunk = deep_gemm.fp8_mqa_logits(\n                        q_fp8[start:end],\n                        kv_fp8,\n                        weights[start:end],\n                        ks[start:end],\n                        ke[start:end],\n                        clean_logits=False,\n                    )\n\n            lengths_chunk = seq_lens_expanded[start:end]\n\n            # RAGGED: use global offset; PAGED: construct local cu_seqlens_q per chunk\n            if global_topk_offset is not None:\n                # RAGGED path\n                topk_offset_chunk = global_topk_offset[start:end]\n                
cu_seqlens_q_chunk = None\n                batch_idx_chunk = None\n            else:\n                # PAGED path: treat each token as a length-1 sequence\n                topk_offset_chunk = None\n                B_chunk = logits_chunk.shape[0]\n                cu_seqlens_q_chunk = torch.ones(\n                    B_chunk, dtype=torch.int32, device=device\n                )\n                batch_idx_chunk = token_to_batch_idx[start:end]\n\n            raw_topk_chunk = metadata.topk_transform(\n                logits_chunk,\n                self.index_topk,\n                ks=ks[start:end],\n                cu_seqlens_q=cu_seqlens_q_chunk,\n                ke_offset=lengths_chunk,\n                batch_idx_list=batch_idx_chunk,\n                topk_indices_offset_override=topk_offset_chunk,\n            )\n            topk_result[start:end] = raw_topk_chunk\n            start = end\n\n        return topk_result\n\n    def _forward_cuda_k_only(\n        self,\n        x: torch.Tensor,\n        positions: torch.Tensor,\n        forward_batch: ForwardBatch,\n        layer_id: int,\n        act_quant,\n        enable_dual_stream: bool,\n        metadata: BaseIndexerMetadata,\n        return_indices: bool = True,\n    ) -> Optional[torch.Tensor]:\n        assert forward_batch.forward_mode.is_extend_without_speculative()\n        x_meta = x[0] if isinstance(x, tuple) else x\n\n        # Fast path: only compute and store k cache, skip all q and weights ops\n        key = self._get_k_bf16(x, positions, enable_dual_stream)\n\n        if not forward_batch.out_cache_loc.is_contiguous():\n            forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()\n\n        self._store_index_k_cache(\n            forward_batch=forward_batch,\n            layer_id=layer_id,\n            key=key,\n            act_quant=act_quant,\n        )\n\n        # MHA doesn't need topk_indices\n        if not return_indices:\n            return None\n\n        # MLA: use dummy 
logits with topk kernel's fast path to generate indices\n        # When length <= 2048, naive_topk_cuda directly generates [0,1,...,length-1,-1,...]\n        seq_lens_expanded = metadata.get_seqlens_expanded()\n        dummy_logits = torch.zeros(\n            seq_lens_expanded.shape[0],\n            self.index_topk,\n            dtype=torch.float32,\n            device=x_meta.device,\n        )\n        return metadata.topk_transform(dummy_logits, self.index_topk)\n\n    def _get_topk_ragged_with_cp(\n        self,\n        forward_batch: ForwardBatch,\n        layer_id: int,\n        q_fp8: torch.Tensor,\n        weights: torch.Tensor,\n        metadata: BaseIndexerMetadata,\n        kv_len: int,\n        actual_seq_q: int,\n        cp_index: List[Tuple[int, int, int]] = None,\n    ) -> torch.Tensor:\n        if TYPE_CHECKING:\n            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)\n\n        page_size = forward_batch.token_to_kv_pool.page_size\n        assert page_size == 64, \"only support page size 64\"\n        assert len(weights.shape) == 3\n        weights = weights.squeeze(-1)\n        k_fp8_list = []\n        k_scale_list = []\n        ks_list = []\n        ke_offset_list = []\n        offset = 0\n        actual_seq_q_list = []\n        batch_idx_list = []\n\n        block_tables = metadata.get_page_table_64()\n\n        assert (\n            forward_batch.seq_lens_cpu is not None\n            and forward_batch.extend_seq_lens_cpu is not None\n        )\n        if cp_index is not None:\n            # TODO Multi-batch support has accuracy issues\n            for batch_idx, start_seq_position, end_seq_position in cp_index:\n                pre_chunk_offset = (\n                    forward_batch.seq_lens_cpu[batch_idx].item()\n                    - forward_batch.extend_seq_lens_cpu[batch_idx]\n                )\n                start_seq_position += pre_chunk_offset\n                end_seq_position += pre_chunk_offset\n              
  if offset == 0 and batch_idx != 0:\n                    offset += forward_batch.extend_seq_lens_cpu[batch_idx - 1]\n                k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(\n                    layer_id,\n                    end_seq_position,\n                    block_tables[batch_idx],\n                )\n                k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(\n                    layer_id,\n                    end_seq_position,\n                    block_tables[batch_idx],\n                )\n\n                extend_seq_len = end_seq_position - start_seq_position\n                ks = torch.full(\n                    (extend_seq_len,), offset, dtype=torch.int32, device=\"cuda\"\n                )\n                k_fp8_list.append(k_fp8)\n                k_scale_list.append(k_scale)\n                ks_list.append(ks)\n                ke_offset = torch.arange(\n                    start_seq_position + 1,\n                    end_seq_position + 1,\n                    dtype=torch.int32,\n                    device=\"cuda\",\n                )\n                ke_offset_list.append(ke_offset)\n                actual_seq_q = torch.tensor(\n                    [extend_seq_len], dtype=torch.int32, device=\"cuda\"\n                )\n                actual_seq_q_list.append(actual_seq_q)\n                batch_idx_list.append(batch_idx)\n\n            k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)\n            k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)\n            kv_fp8 = (k_fp8, k_scale)\n            ks = torch.cat(ks_list, dim=0)\n            ke_offset = torch.cat(ke_offset_list, dim=0)\n            ke = ks + ke_offset\n            actual_seq_q = torch.cat(actual_seq_q_list, dim=0)\n            with self._with_real_sm_count():\n                logits = deep_gemm.fp8_mqa_logits(\n                    q_fp8,\n                    kv_fp8,\n                    
weights,\n                    ks,\n                    ke,\n                    clean_logits=False,\n                )\n            topk_result = metadata.topk_transform(\n                logits,\n                self.index_topk,\n                ks=ks,\n                cu_seqlens_q=actual_seq_q,\n                ke_offset=ke_offset,\n                batch_idx_list=batch_idx_list,\n            )\n        else:\n            kv_len = (\n                forward_batch.seq_lens_cpu[0].item()\n                - forward_batch.extend_seq_lens_cpu[0]\n                + kv_len\n            )\n            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(\n                layer_id,\n                kv_len,\n                block_tables[0],\n            )\n            k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(\n                layer_id,\n                kv_len,\n                block_tables[0],\n            )\n\n            k_fp8 = k_fp8.view(torch.float8_e4m3fn)\n            k_scale = k_scale.view(torch.float32).squeeze(-1)\n            kv_fp8 = (k_fp8, k_scale)\n            ks = torch.full((actual_seq_q,), offset, dtype=torch.int32, device=\"cuda\")\n            ke_offset = torch.arange(\n                (kv_len - actual_seq_q) + 1,\n                kv_len + 1,\n                dtype=torch.int32,\n                device=\"cuda\",\n            )\n            ke = ks + ke_offset\n\n            with self._with_real_sm_count():\n                logits = deep_gemm.fp8_mqa_logits(\n                    q_fp8,\n                    kv_fp8,\n                    weights,\n                    ks,\n                    ke,\n                    clean_logits=False,\n                )\n            actual_seq_q = torch.tensor([actual_seq_q], dtype=torch.int32).to(\n                device=\"cuda\", non_blocking=True\n            )\n            topk_result = metadata.topk_transform(\n                logits,\n                self.index_topk,\n              
  ks=ks,\n                cu_seqlens_q=actual_seq_q,\n                ke_offset=ke_offset,\n            )\n\n        return topk_result\n\n    def forward_indexer(\n        self,\n        q_fp8: torch.Tensor,\n        weights: torch.Tensor,\n        forward_batch: ForwardBatch,\n        topk: int,\n        layer_id: int,\n    ) -> Optional[torch.Tensor]:\n        if not _is_npu:\n            from sglang.srt.layers.attention.nsa.tilelang_kernel import fp8_index\n\n        page_size = forward_batch.token_to_kv_pool.page_size\n        assert page_size == 64, \"only support page size 64\"\n\n        assert len(weights.shape) == 3\n        weights = weights.squeeze(-1)\n\n        # logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)\n        k_fp8_list = []\n        k_scale_list = []\n\n        topk_indices_list = []\n\n        block_tables = forward_batch.req_to_token_pool.req_to_token[\n            forward_batch.req_pool_indices, :\n        ]\n        strided_indices = torch.arange(\n            0, block_tables.shape[-1], page_size, device=\"cuda\"\n        )\n        block_tables = block_tables[:, strided_indices] // page_size\n\n        q_len_start = 0\n\n        for i in range(forward_batch.batch_size):\n            seq_len = forward_batch.seq_lens[i].item()\n            q_len = (\n                forward_batch.extend_seq_lens_cpu[i]\n                if forward_batch.forward_mode.is_extend()\n                else 1\n            )\n            q_len_end = q_len_start + q_len\n\n            q_fp8_partial = q_fp8[q_len_start:q_len_end]\n            q_fp8_partial = q_fp8_partial.unsqueeze(0).contiguous()\n\n            weights_partial = weights[q_len_start:q_len_end]\n            weights_partial = weights_partial.squeeze(-1).unsqueeze(0).contiguous()\n\n            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(\n                layer_id,\n                seq_len,\n                block_tables[i],\n            )\n            k_scale = 
forward_batch.token_to_kv_pool.get_index_k_scale_continuous(\n                layer_id,\n                seq_len,\n                block_tables[i],\n            )\n\n            k_fp8 = k_fp8.view(torch.float8_e4m3fn).unsqueeze(0).contiguous()\n            k_scale = k_scale.view(torch.float32).squeeze(-1).unsqueeze(0).contiguous()\n\n            index_score = fp8_index(\n                q_fp8_partial,\n                weights_partial,\n                k_fp8,\n                k_scale,\n            )\n            end_pos = seq_len\n            topk_indices = index_score.topk(min(topk, end_pos), dim=-1)[1].squeeze(0)\n\n            pad_len = ceil_align(topk_indices.shape[-1], 2048) - topk_indices.shape[-1]\n            topk_indices = torch.nn.functional.pad(\n                topk_indices, (0, pad_len), \"constant\", -1\n            )\n\n            topk_indices_list.append(topk_indices)\n\n            q_len_start = q_len_end\n\n        topk_indices = torch.cat(topk_indices_list, dim=0)\n        return topk_indices\n\n    def _store_index_k_cache(\n        self,\n        forward_batch: ForwardBatch,\n        layer_id: int,\n        key: torch.Tensor,\n        *,\n        act_quant=None,  # fallback only\n    ) -> None:\n        \"\"\"\n        Store NSA indexer K cache for current step.\n\n        Preferred: fused_store_index_k_cache(key, cache, out_cache_loc, page_size)\n        Fallback : act_quant(key) + token_to_kv_pool.set_index_k_scale_buffer(...)\n        \"\"\"\n\n        # Fast path: JIT fused store (CUDA, page_size=64, non-fnuz)\n        if (\n            _is_cuda\n            and (not _is_fp8_fnuz)\n            and can_use_nsa_fused_store(\n                key.dtype,\n                forward_batch.out_cache_loc.dtype,\n                forward_batch.token_to_kv_pool.page_size,\n            )\n        ):\n            # NOTE: wrapper already normalizes shape/contiguity and asserts dtypes.\n            buf = 
forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(\n                layer_id=layer_id\n            )\n            fused_store_index_k_cache(\n                key,\n                buf,\n                forward_batch.out_cache_loc,\n                forward_batch.token_to_kv_pool.page_size,\n            )\n            return\n\n        # Fallback: original path\n        assert act_quant is not None\n        k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)\n\n        out_loc = forward_batch.out_cache_loc\n        if not out_loc.is_contiguous():\n            out_loc = out_loc.contiguous()\n\n        forward_batch.token_to_kv_pool.set_index_k_scale_buffer(\n            layer_id=layer_id,\n            loc=out_loc,\n            index_k=k_fp8,\n            index_k_scale=k_scale,\n        )\n\n    def forward_cuda(\n        self,\n        x: torch.Tensor,\n        q_lora: torch.Tensor,\n        positions: torch.Tensor,\n        forward_batch: ForwardBatch,\n        layer_id: int,\n        return_indices: bool = True,\n    ) -> Optional[torch.Tensor]:\n        if _is_hip:\n            from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant\n        elif not _is_npu:\n            from sglang.srt.layers.attention.nsa.triton_kernel import act_quant\n\n        if TYPE_CHECKING:\n            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)\n\n        # When upstream uses fused FP8 RMSNorm+quant, activations may be passed as\n        # a tuple like (x_fp8, x_scale[, y]). 
Use `x_meta` for shape/device queries.\n        x_meta = x[0] if isinstance(x, tuple) else x\n\n        metadata = forward_batch.attn_backend.get_indexer_metadata(\n            layer_id, forward_batch\n        )\n\n        enable_dual_stream = (\n            self.alt_stream is not None\n            and get_is_capture_mode()\n            and q_lora.shape[0] > 0\n            and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD\n        )\n\n        # Skip NSA if the attention backend chooses to skip this batch\n        if metadata is None:\n            return None\n\n        # Determine whether to skip topk based on sequence length\n        # We can only skip the logits computation if CUDA graph is not involved\n        skip_logits_computation = False\n        if forward_batch.forward_mode.is_extend_without_speculative():\n            if forward_batch.seq_lens_cpu is not None:\n                max_kv_len = forward_batch.seq_lens_cpu.max().item()\n                skip_logits_computation = max_kv_len <= self.index_topk\n\n        # Optimization: fast path when skipping topk computation\n        if skip_logits_computation and (not self.nsa_enable_prefill_cp):\n            return self._forward_cuda_k_only(\n                x,\n                positions,\n                forward_batch,\n                layer_id,\n                act_quant,\n                enable_dual_stream,\n                metadata,\n                return_indices,\n            )\n\n        if enable_dual_stream and forward_batch.forward_mode.is_decode_or_idle():\n            current_stream = torch.cuda.current_stream()\n            self.alt_stream.wait_stream(current_stream)\n            weights = self._project_and_scale_head_gates(x)\n            query, key = self._get_q_k_bf16(\n                q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch\n            )\n            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)\n            with 
torch.cuda.stream(self.alt_stream):\n                self._store_index_k_cache(\n                    forward_batch=forward_batch,\n                    layer_id=layer_id,\n                    key=key,\n                    act_quant=act_quant,\n                )\n            current_stream.wait_stream(self.alt_stream)\n            weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale\n        else:\n            query, key = self._get_q_k_bf16(\n                q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch\n            )\n\n            if enable_dual_stream:\n                current_stream = torch.cuda.current_stream()\n                self.alt_stream.wait_stream(current_stream)\n\n                q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)\n                with torch.cuda.stream(self.alt_stream):\n                    self._store_index_k_cache(\n                        forward_batch=forward_batch,\n                        layer_id=layer_id,\n                        key=key,\n                        act_quant=act_quant,\n                    )\n                current_stream.wait_stream(self.alt_stream)\n            else:\n                q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)\n                self._store_index_k_cache(\n                    forward_batch=forward_batch,\n                    layer_id=layer_id,\n                    key=key,\n                    act_quant=act_quant,\n                )\n\n            # `_get_logits_head_gate` expects a Tensor. 
For tuple activations, dequantize\n            # to a float tensor here (callsite), keeping `_get_logits_head_gate` backend-agnostic.\n            if isinstance(x, tuple):\n                assert len(x) in (\n                    2,\n                    3,\n                ), \"For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted\"\n                x_q, x_s = x[0], x[1]\n                if (\n                    x_s is not None\n                    and x_q.dim() == 2\n                    and x_s.dim() == 2\n                    and x_q.shape[0] == x_s.shape[0]\n                ):\n                    m, n = x_q.shape\n                    ng = x_s.shape[1]\n                    if ng > 0 and n % ng == 0:\n                        group = n // ng\n                        x_for_gate = (\n                            x_q.to(torch.float32)\n                            .view(m, ng, group)\n                            .mul_(x_s.to(torch.float32).unsqueeze(-1))\n                            .view(m, n)\n                            .to(torch.bfloat16)\n                        )\n                    else:\n                        x_for_gate = x_q.to(torch.bfloat16)\n                else:\n                    x_for_gate = x_q.to(torch.bfloat16)\n            else:\n                x_for_gate = x\n\n            weights = self._get_logits_head_gate(x_for_gate, q_scale)\n\n        if _is_cuda or _is_hip:\n            assert forward_batch.seq_lens_cpu is not None\n            if len(forward_batch.seq_lens_cpu) == 0:\n                # this seems b/c max-pad, no worries?\n                # if x.shape[0] != 0:\n                #     print(\n                #         \"HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result\"\n                #     )\n                return torch.full(\n                    (x_meta.shape[0], self.index_topk),\n                    -1,\n                    dtype=torch.int,\n                    device=x_meta.device,\n         
       )\n\n            if (\n                forward_batch.forward_mode.is_decode_or_idle()\n                or forward_batch.forward_mode.is_target_verify()\n                or forward_batch.forward_mode.is_draft_extend(include_v2=True)\n            ):\n                topk_result = self._get_topk_paged(\n                    forward_batch, layer_id, q_fp8, weights, metadata\n                )\n            else:\n                if (\n                    forward_batch.nsa_cp_metadata is not None\n                    and is_nsa_prefill_cp_in_seq_split()\n                ):\n                    kv_len_prev = forward_batch.nsa_cp_metadata.kv_len_prev\n                    kv_len_next = forward_batch.nsa_cp_metadata.kv_len_next\n                    actual_seq_q_prev = forward_batch.nsa_cp_metadata.actual_seq_q_prev\n                    actual_seq_q_next = forward_batch.nsa_cp_metadata.actual_seq_q_next\n\n                    # TODO: support multi-batch\n                    # cp_batch_seq_index_prev = forward_batch.nsa_cp_metadata[\"cp_batch_seq_index_prev\"]\n                    # cp_batch_seq_index_next = forward_batch.nsa_cp_metadata[\"cp_batch_seq_index_next\"]\n                    # TODO: combine the prev and next calls into a single call\n                    q_fp8_prev, q_fp8_next = torch.split(\n                        q_fp8, (q_fp8.shape[0] + 1) // 2, dim=0\n                    )\n                    weights_prev, weights_next = torch.split(\n                        weights, (weights.shape[0] + 1) // 2, dim=0\n                    )\n                    topk_result_prev = self._get_topk_ragged_with_cp(\n                        forward_batch,\n                        layer_id,\n                        q_fp8_prev,\n                        weights_prev,\n                        metadata,\n                        kv_len_prev,\n                        actual_seq_q_prev,\n                    )\n\n                    topk_result_next = self._get_topk_ragged_with_cp(\n                
        forward_batch,\n                        layer_id,\n                        q_fp8_next,\n                        weights_next,\n                        metadata,\n                        kv_len_next,\n                        actual_seq_q_next,\n                    )\n                    return torch.cat([topk_result_prev, topk_result_next], dim=0)\n                else:\n                    topk_result = self._get_topk_ragged(\n                        enable_dual_stream,\n                        forward_batch,\n                        layer_id,\n                        q_fp8,\n                        weights,\n                        metadata,\n                    )\n        else:\n            topk_result = self.forward_indexer(\n                q_fp8.contiguous(),\n                weights,\n                forward_batch,\n                topk=self.index_topk,\n                layer_id=layer_id,\n            )\n        return topk_result\n\n    def forward_npu(\n        self,\n        x: torch.Tensor,\n        q_lora: torch.Tensor,\n        positions: torch.Tensor,\n        forward_batch: ForwardBatch,\n        layer_id: int,\n        layer_scatter_modes=None,\n        dynamic_scale: torch.Tensor = None,\n    ) -> torch.Tensor:\n        if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None:\n            actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens\n        else:\n            actual_seq_lengths_kv = (\n                forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int\n            )\n        is_prefill = (\n            forward_batch.forward_mode.is_extend()\n            and not forward_batch.forward_mode.is_draft_extend_v2()\n            and not forward_batch.forward_mode.is_target_verify()\n            and not forward_batch.forward_mode.is_draft_extend()\n        )\n\n        cos_sin = self.rotary_emb.cos_sin_cache[positions]\n        cos, sin = cos_sin.chunk(2, dim=-1)\n        cos = cos.repeat(1, 
2).view(-1, 1, 1, self.rope_head_dim)\n        sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)\n\n        bs = q_lora.shape[0]\n        if self.alt_stream is not None:\n            self.alt_stream.wait_stream(torch.npu.current_stream())\n            with torch.npu.stream(self.alt_stream):\n                q_lora = (\n                    (q_lora, dynamic_scale) if dynamic_scale is not None else q_lora\n                )\n                q = self.wq_b(q_lora)[\n                    0\n                ]  # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]\n                wq_b_event = self.alt_stream.record_event()\n                q = q.view(bs, self.n_heads, self.head_dim)  # [bs, 64, 128]\n                q_pe, q_nope = torch.split(\n                    q,\n                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],\n                    dim=-1,\n                )  # [bs, 64, 64 + 64]\n                q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)\n                q_pe = torch_npu.npu_rotary_mul(q_pe, cos, sin).view(\n                    bs, self.n_heads, self.rope_head_dim\n                )  # [bs, n, d]\n                q = torch.cat([q_pe, q_nope], dim=-1)\n                q.record_stream(self.alt_stream)\n                q_rope_event = self.alt_stream.record_event()\n        else:\n            q_lora = (q_lora, dynamic_scale) if dynamic_scale is not None else q_lora\n            q = self.wq_b(q_lora)[0]  # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]\n            q = q.view(bs, self.n_heads, self.head_dim)  # [bs, 64, 128]\n            q_pe, q_nope = torch.split(\n                q,\n                [self.rope_head_dim, self.head_dim - self.rope_head_dim],\n                dim=-1,\n            )  # [bs, 64, 64 + 64]\n            q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)\n            q_pe = torch_npu.npu_rotary_mul(q_pe, cos, sin).view(\n                bs, self.n_heads, self.rope_head_dim\n            )  # 
[bs, n, d]\n            q = torch.cat([q_pe, q_nope], dim=-1)\n\n        if envs.SGLANG_NPU_USE_MULTI_STREAM.get():\n            indexer_weight_stream = get_indexer_weight_stream()\n            indexer_weight_stream.wait_stream(torch.npu.current_stream())\n            with torch.npu.stream(indexer_weight_stream):\n                x = x.view(-1, self.hidden_size)\n                weights = self.weights_proj(x.float())[0].to(torch.bfloat16)\n                weights.record_stream(indexer_weight_stream)\n                weights_event = indexer_weight_stream.record_event()\n        else:\n            x = x.view(-1, self.hidden_size)\n            weights = self.weights_proj(x.float())[0].to(torch.bfloat16)\n\n        k_proj = self.wk(x)[0]  # [b, s, 7168] @ [7168, 128] = [b, s, 128]\n        k = self.k_norm(k_proj)\n        if (\n            _use_ag_after_qlora\n            and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED\n            and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL\n        ):\n            k = scattered_to_tp_attn_full(k, forward_batch)\n        k_pe, k_nope = torch.split(\n            k,\n            [self.rope_head_dim, self.head_dim - self.rope_head_dim],\n            dim=-1,\n        )  # [bs, 64 + 64]\n\n        k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)\n        k_pe = torch.ops.npu.npu_rotary_mul(k_pe, cos, sin).view(\n            bs, 1, self.rope_head_dim\n        )  # [bs, 1, d]\n        k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1)  # [bs, 1, 128]\n\n        if (\n            is_prefill\n            and self.nsa_enable_prefill_cp\n            and forward_batch.nsa_cp_metadata is not None\n        ):\n            k = cp_all_gather_rerange_output(\n                k.contiguous().view(-1, self.head_dim),\n                self.cp_size,\n                forward_batch,\n                torch.npu.current_stream(),\n            )\n\n        forward_batch.token_to_kv_pool.set_index_k_buffer(\n            layer_id, 
forward_batch.out_cache_loc, k\n        )\n        if is_prefill:\n            if self.nsa_enable_prefill_cp and forward_batch.nsa_cp_metadata is not None:\n                forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q = (\n                    forward_batch.nsa_cp_metadata.actual_seq_q_prev_tensor,\n                    forward_batch.nsa_cp_metadata.actual_seq_q_next_tensor,\n                )\n                forward_batch.attn_backend.forward_metadata.actual_seq_lengths_kv = (\n                    forward_batch.nsa_cp_metadata.kv_len_prev_tensor,\n                    forward_batch.nsa_cp_metadata.kv_len_next_tensor,\n                )\n                actual_seq_lengths_q = (\n                    forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q\n                )\n                actual_seq_lengths_kv = (\n                    forward_batch.attn_backend.forward_metadata.actual_seq_lengths_kv\n                )\n            else:\n                actual_seq_lengths_kv = forward_batch.seq_lens\n                actual_seq_lengths_q = forward_batch.extend_seq_lens.cumsum(dim=0)\n        else:\n            if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:\n                if (\n                    forward_batch.forward_mode.is_draft_extend_v2()\n                    or forward_batch.forward_mode.is_target_verify()\n                    or forward_batch.forward_mode.is_draft_extend()\n                ):\n                    num_draft_tokens = (\n                        forward_batch.attn_backend.speculative_num_draft_tokens\n                    )\n                    actual_seq_lengths_q = torch.arange(\n                        num_draft_tokens,\n                        num_draft_tokens + bs,\n                        num_draft_tokens,\n                        dtype=torch.int32,\n                        device=k.device,\n                    )\n                else:\n                    actual_seq_lengths_q = 
torch.tensor(\n                        [1 + i * 1 for i in range(bs)],\n                        dtype=torch.int32,\n                        device=k.device,\n                    )\n            else:\n                actual_seq_lengths_q = (\n                    forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q\n                )\n\n        past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)\n\n        if self.alt_stream is not None:\n            torch.npu.current_stream().wait_event(q_rope_event)\n        if envs.SGLANG_NPU_USE_MULTI_STREAM.get():\n            torch.npu.current_stream().wait_event(weights_event)\n        if (\n            _use_ag_after_qlora\n            and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED\n            and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL\n        ):\n            weights = scattered_to_tp_attn_full(weights, forward_batch)\n        block_table = forward_batch.attn_backend.forward_metadata.block_tables\n        if (\n            is_prefill\n            and self.nsa_enable_prefill_cp\n            and forward_batch.nsa_cp_metadata is not None\n        ):\n            block_table = block_table[: actual_seq_lengths_q[0].numel()]\n            topk_indices = self.do_npu_cp_balance_indexer(\n                q.view(-1, self.n_heads, self.head_dim),\n                past_key_states,\n                weights,\n                actual_seq_lengths_q,\n                actual_seq_lengths_kv,\n                block_table,\n            )\n            return topk_indices\n        else:\n            block_table = (\n                block_table[: actual_seq_lengths_q.size()[0]]\n                if is_prefill\n                else block_table\n            )\n\n            topk_indices = torch_npu.npu_lightning_indexer(\n                query=q.view(-1, self.n_heads, self.head_dim),\n                key=past_key_states,\n                weights=weights,\n                
actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32),\n                actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(\n                    torch.int32\n                ),\n                block_table=block_table,\n                layout_query=\"TND\",\n                layout_key=\"PA_BSND\",\n                sparse_count=self.index_topk,\n                sparse_mode=3,\n            )\n            return topk_indices[0]\n\n    def do_npu_cp_balance_indexer(\n        self,\n        q,\n        past_key_states,\n        indexer_weights,\n        actual_seq_lengths_q,\n        actual_seq_lengths_kv,\n        block_table,\n    ):\n        q_prev, q_next = torch.split(q, (q.size(0) + 1) // 2, dim=0)\n        weights_prev, weights_next = None, None\n        if indexer_weights is not None:\n            weights_prev, weights_next = torch.split(\n                indexer_weights, (indexer_weights.size(0) + 1) // 2, dim=0\n            )\n            weights_prev = weights_prev.contiguous().view(-1, weights_prev.shape[-1])\n            weights_next = weights_next.contiguous().view(-1, weights_next.shape[-1])\n\n        actual_seq_lengths_q_prev, actual_seq_lengths_q_next = actual_seq_lengths_q\n        actual_seq_lengths_kv_prev, actual_seq_lengths_kv_next = actual_seq_lengths_kv\n\n        topk_indices_prev = torch_npu.npu_lightning_indexer(\n            query=q_prev,\n            key=past_key_states,\n            weights=weights_prev,\n            actual_seq_lengths_query=actual_seq_lengths_q_prev.to(\n                device=q.device, dtype=torch.int32\n            ),\n            actual_seq_lengths_key=actual_seq_lengths_kv_prev.to(\n                device=q.device, dtype=torch.int32\n            ),\n            block_table=block_table,\n            layout_query=\"TND\",\n            layout_key=\"PA_BSND\",\n            sparse_count=self.index_topk,\n            sparse_mode=3,\n        )\n        topk_indices_next = torch_npu.npu_lightning_indexer(\n  
          query=q_next,\n            key=past_key_states,\n            weights=weights_next,\n            actual_seq_lengths_query=actual_seq_lengths_q_next.to(\n                device=q.device, dtype=torch.int32\n            ),\n            actual_seq_lengths_key=actual_seq_lengths_kv_next.to(\n                device=q.device, dtype=torch.int32\n            ),\n            block_table=block_table,\n            layout_query=\"TND\",\n            layout_key=\"PA_BSND\",\n            sparse_count=self.index_topk,\n            sparse_mode=3,\n        )\n        return topk_indices_prev[0], topk_indices_next[0]\n\n\ndef scattered_to_tp_attn_full(\n    hidden_states: torch.Tensor,\n    forward_batch,\n) -> torch.Tensor:\n    hidden_states, local_hidden_states = (\n        torch.empty(\n            (forward_batch.input_ids.shape[0], hidden_states.shape[1]),\n            dtype=hidden_states.dtype,\n            device=hidden_states.device,\n        ),\n        hidden_states,\n    )\n    attn_tp_all_gather_into_tensor(hidden_states, local_hidden_states.contiguous())\n    return hidden_states\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/nsa/nsa_mtp_verification.py",
    "content": "\"\"\"\nVerification utilities for NSA backend fused metadata copy operations.\n\nThis module contains verification code to ensure that fused metadata copy kernels\nproduce the same results as individual copy operations.\n\"\"\"\n\nimport torch\n\n\ndef verify_single_backend_fused_metadata_copy(\n    metadata,\n    precomputed,\n    forward_mode,\n    bs,\n    flashmla_num_splits_src=None,\n    flashmla_metadata_src=None,\n    flashmla_num_splits_dst=None,\n    flashmla_metadata_dst=None,\n):\n    \"\"\"\n    Verify that the fused metadata copy kernel produces the same results as individual copies.\n\n    Args:\n        metadata: The NSA metadata object containing destination tensors\n        precomputed: The precomputed metadata containing source tensors\n        forward_mode: The forward mode (decode, target_verify, or draft_extend)\n        bs: Batch size\n        flashmla_num_splits_src: Source FlashMLA num_splits tensor (optional)\n        flashmla_metadata_src: Source FlashMLA metadata tensor (optional)\n        flashmla_num_splits_dst: Destination FlashMLA num_splits tensor (optional)\n        flashmla_metadata_dst: Destination FlashMLA metadata tensor (optional)\n\n    Raises:\n        RuntimeError: If verification fails (tensors don't match)\n    \"\"\"\n    # Clone destination tensors to preserve fused kernel results\n    fused_cache_seqlens = metadata.cache_seqlens_int32.clone()\n    fused_cu_seqlens_k = metadata.cu_seqlens_k.clone()\n    fused_page_table_1 = metadata.page_table_1.clone()\n    fused_nsa_cache_seqlens = metadata.nsa_cache_seqlens_int32.clone()\n    fused_nsa_seqlens_expanded = metadata.nsa_seqlens_expanded.clone()\n    fused_nsa_cu_seqlens_k = metadata.nsa_cu_seqlens_k.clone()\n    fused_real_page_table = (\n        metadata.real_page_table.clone()\n        if precomputed.real_page_table is not None\n        else None\n    )\n    fused_flashmla_num_splits = None\n    fused_flashmla_metadata = None\n    if 
precomputed.flashmla_metadata is not None:\n        fused_flashmla_num_splits = flashmla_num_splits_dst.clone()\n        fused_flashmla_metadata = flashmla_metadata_dst.clone()\n\n    # Create reference tensors (zeroed out)\n    ref_cache_seqlens = torch.zeros_like(metadata.cache_seqlens_int32)\n    ref_cu_seqlens_k = torch.zeros_like(metadata.cu_seqlens_k)\n    ref_page_table_1 = torch.zeros_like(metadata.page_table_1)\n    ref_nsa_cache_seqlens = torch.zeros_like(metadata.nsa_cache_seqlens_int32)\n    ref_nsa_seqlens_expanded = torch.zeros_like(metadata.nsa_seqlens_expanded)\n    ref_nsa_cu_seqlens_k = torch.zeros_like(metadata.nsa_cu_seqlens_k)\n    ref_real_page_table = (\n        torch.zeros_like(metadata.real_page_table)\n        if precomputed.real_page_table is not None\n        else None\n    )\n    ref_flashmla_num_splits = None\n    ref_flashmla_metadata = None\n    if precomputed.flashmla_metadata is not None:\n        ref_flashmla_num_splits = torch.zeros_like(flashmla_num_splits_dst)\n        ref_flashmla_metadata = torch.zeros_like(flashmla_metadata_dst)\n\n    # Run individual copy operations (reference implementation)\n    ref_cache_seqlens.copy_(precomputed.cache_seqlens)\n    ref_cu_seqlens_k[1:].copy_(precomputed.cu_seqlens_k[1:])\n\n    if forward_mode.is_decode_or_idle():\n        # Decode mode\n        ref_page_table_1[:, : precomputed.max_len].copy_(precomputed.page_indices)\n        ref_nsa_cache_seqlens.copy_(precomputed.nsa_cache_seqlens)\n    elif forward_mode.is_target_verify():\n        # Target verify mode\n        ref_page_table_1[:, : precomputed.max_seqlen_k].copy_(precomputed.page_indices)\n        ref_nsa_seqlens_expanded.copy_(precomputed.seqlens_expanded)\n        ref_nsa_cache_seqlens.copy_(precomputed.nsa_cache_seqlens)\n    elif forward_mode.is_draft_extend():\n        # Draft extend mode\n        rows = precomputed.page_indices.shape[0]\n        cols = precomputed.max_seqlen_k\n        ref_page_table_1[:rows, 
:cols].copy_(precomputed.page_indices)\n        size = precomputed.seqlens_expanded_size\n        ref_nsa_seqlens_expanded[:size].copy_(precomputed.seqlens_expanded)\n        ref_nsa_cache_seqlens[:size].copy_(precomputed.nsa_cache_seqlens)\n\n    # Copy NSA cu_seqlens\n    size = precomputed.seqlens_expanded_size\n    ref_nsa_cu_seqlens_k[1 : 1 + size].copy_(precomputed.nsa_cu_seqlens_k[1 : 1 + size])\n\n    # Copy real page table\n    if precomputed.real_page_table is not None:\n        rows, cols = precomputed.real_page_table.shape\n        ref_real_page_table[:rows, :cols].copy_(precomputed.real_page_table)\n\n    # Copy FlashMLA metadata\n    if precomputed.flashmla_metadata is not None:\n        size = precomputed.seqlens_expanded_size\n        ref_flashmla_num_splits[: size + 1].copy_(flashmla_num_splits_src[: size + 1])\n        ref_flashmla_metadata.copy_(flashmla_metadata_src)\n\n    # Compare results and crash if inconsistent\n    def check_tensor_equal(name, fused, ref):\n        if not torch.equal(fused, ref):\n            max_diff = (fused.float() - ref.float()).abs().max().item()\n            mismatched_elements = (fused != ref).sum().item()\n            total_elements = fused.numel()\n            raise RuntimeError(\n                f\"FUSED METADATA COPY VERIFICATION FAILED!\\n\"\n                f\"Tensor: {name}\\n\"\n                f\"Max difference: {max_diff}\\n\"\n                f\"Mismatched elements: {mismatched_elements}/{total_elements}\\n\"\n                f\"Fused shape: {fused.shape}, Ref shape: {ref.shape}\\n\"\n                f\"Forward mode: {forward_mode}, bs={bs}\\n\"\n                f\"The fused kernel produces different results than individual copies.\\n\"\n                f\"This indicates a bug in the fused metadata copy kernel.\"\n            )\n\n    # Verify all tensors (only compare the slices that were actually updated)\n    check_tensor_equal(\"cache_seqlens\", fused_cache_seqlens, ref_cache_seqlens)\n    
check_tensor_equal(\"cu_seqlens_k\", fused_cu_seqlens_k, ref_cu_seqlens_k)\n\n    # Compare page_table_1 only for the region that was updated\n    if forward_mode.is_decode_or_idle():\n        check_tensor_equal(\n            \"page_table_1\",\n            fused_page_table_1[:, : precomputed.max_len],\n            ref_page_table_1[:, : precomputed.max_len],\n        )\n    elif forward_mode.is_target_verify():\n        check_tensor_equal(\n            \"page_table_1\",\n            fused_page_table_1[:, : precomputed.max_seqlen_k],\n            ref_page_table_1[:, : precomputed.max_seqlen_k],\n        )\n    elif forward_mode.is_draft_extend():\n        rows = precomputed.page_indices.shape[0]\n        cols = precomputed.max_seqlen_k\n        check_tensor_equal(\n            \"page_table_1\",\n            fused_page_table_1[:rows, :cols],\n            ref_page_table_1[:rows, :cols],\n        )\n\n    # Compare nsa_cache_seqlens only for the region that was updated\n    if forward_mode.is_decode_or_idle():\n        check_tensor_equal(\n            \"nsa_cache_seqlens\",\n            fused_nsa_cache_seqlens,\n            ref_nsa_cache_seqlens,\n        )\n    else:  # TARGET_VERIFY or DRAFT_EXTEND\n        size = precomputed.seqlens_expanded_size\n        check_tensor_equal(\n            \"nsa_cache_seqlens\",\n            fused_nsa_cache_seqlens[:size],\n            ref_nsa_cache_seqlens[:size],\n        )\n\n    # Compare nsa_seqlens_expanded only for TARGET_VERIFY and DRAFT_EXTEND\n    if forward_mode.is_target_verify() or forward_mode.is_draft_extend():\n        size = precomputed.seqlens_expanded_size\n        check_tensor_equal(\n            \"nsa_seqlens_expanded\",\n            fused_nsa_seqlens_expanded[:size],\n            ref_nsa_seqlens_expanded[:size],\n        )\n\n    # Compare nsa_cu_seqlens_k only for the region that was updated\n    size = precomputed.seqlens_expanded_size\n    check_tensor_equal(\n        \"nsa_cu_seqlens_k\",\n        
fused_nsa_cu_seqlens_k[: 1 + size],\n        ref_nsa_cu_seqlens_k[: 1 + size],\n    )\n\n    if precomputed.real_page_table is not None:\n        rows, cols = precomputed.real_page_table.shape\n        check_tensor_equal(\n            \"real_page_table\",\n            fused_real_page_table[:rows, :cols],\n            ref_real_page_table[:rows, :cols],\n        )\n\n    if precomputed.flashmla_metadata is not None:\n        size = precomputed.seqlens_expanded_size\n        check_tensor_equal(\n            \"flashmla_num_splits\",\n            fused_flashmla_num_splits[: size + 1],\n            ref_flashmla_num_splits[: size + 1],\n        )\n        check_tensor_equal(\n            \"flashmla_metadata\",\n            fused_flashmla_metadata,\n            ref_flashmla_metadata,\n        )\n\n\ndef verify_multi_backend_fused_metadata_copy(\n    metadata0,\n    metadata1,\n    metadata2,\n    precomputed,\n    bs,\n    flashmla_num_splits_src=None,\n    flashmla_metadata_src=None,\n):\n    \"\"\"\n    Verify that the multi-backend fused metadata copy kernel produces the same results\n    as individual copies for all three backends.\n\n    Args:\n        metadata0: The NSA metadata object for backend 0\n        metadata1: The NSA metadata object for backend 1\n        metadata2: The NSA metadata object for backend 2\n        precomputed: The precomputed metadata containing source tensors\n        bs: Batch size\n        flashmla_num_splits_src: Source FlashMLA num_splits tensor (optional)\n        flashmla_metadata_src: Source FlashMLA metadata tensor (optional)\n\n    Raises:\n        RuntimeError: If verification fails (tensors don't match)\n    \"\"\"\n    # Clone destination tensors to preserve fused kernel results\n    fused_results = []\n    for idx, metadata in enumerate([metadata0, metadata1, metadata2]):\n        fused_cache_seqlens = metadata.cache_seqlens_int32.clone()\n        fused_cu_seqlens_k = metadata.cu_seqlens_k.clone()\n        fused_page_table_1 = 
metadata.page_table_1.clone()\n        fused_nsa_cache_seqlens = metadata.nsa_cache_seqlens_int32.clone()\n        fused_nsa_cu_seqlens_k = metadata.nsa_cu_seqlens_k.clone()\n        fused_real_page_table = (\n            metadata.real_page_table.clone()\n            if precomputed.real_page_table is not None\n            else None\n        )\n        fused_flashmla_num_splits = None\n        fused_flashmla_metadata = None\n        if precomputed.flashmla_metadata is not None:\n            fused_flashmla_num_splits = metadata.flashmla_metadata.num_splits.clone()\n            fused_flashmla_metadata = (\n                metadata.flashmla_metadata.flashmla_metadata.clone()\n            )\n\n        fused_results.append(\n            {\n                \"cache_seqlens\": fused_cache_seqlens,\n                \"cu_seqlens_k\": fused_cu_seqlens_k,\n                \"page_table_1\": fused_page_table_1,\n                \"nsa_cache_seqlens\": fused_nsa_cache_seqlens,\n                \"nsa_cu_seqlens_k\": fused_nsa_cu_seqlens_k,\n                \"real_page_table\": fused_real_page_table,\n                \"flashmla_num_splits\": fused_flashmla_num_splits,\n                \"flashmla_metadata\": fused_flashmla_metadata,\n            }\n        )\n\n    # Run individual copy operations for each backend (reference implementation)\n    ref_results = []\n    for idx in range(3):\n        metadata = [metadata0, metadata1, metadata2][idx]\n\n        # Create reference tensors (zeroed out)\n        ref_cache_seqlens = torch.zeros_like(metadata.cache_seqlens_int32)\n        ref_cu_seqlens_k = torch.zeros_like(metadata.cu_seqlens_k)\n        ref_page_table_1 = torch.zeros_like(metadata.page_table_1)\n        ref_nsa_cache_seqlens = torch.zeros_like(metadata.nsa_cache_seqlens_int32)\n        ref_nsa_cu_seqlens_k = torch.zeros_like(metadata.nsa_cu_seqlens_k)\n        ref_real_page_table = (\n            torch.zeros_like(metadata.real_page_table)\n            if 
precomputed.real_page_table is not None\n            else None\n        )\n        ref_flashmla_num_splits = None\n        ref_flashmla_metadata = None\n        if precomputed.flashmla_metadata is not None:\n            ref_flashmla_num_splits = torch.zeros_like(\n                metadata.flashmla_metadata.num_splits\n            )\n            ref_flashmla_metadata = torch.zeros_like(\n                metadata.flashmla_metadata.flashmla_metadata\n            )\n\n        # Copy operations (decode mode)\n        ref_cache_seqlens.copy_(precomputed.cache_seqlens)\n        ref_cu_seqlens_k[1:].copy_(precomputed.cu_seqlens_k[1:])\n        ref_page_table_1[:, : precomputed.max_len].copy_(precomputed.page_indices)\n        ref_nsa_cache_seqlens.copy_(precomputed.nsa_cache_seqlens)\n\n        # Copy NSA cu_seqlens\n        size = precomputed.seqlens_expanded_size\n        ref_nsa_cu_seqlens_k[1 : 1 + size].copy_(\n            precomputed.nsa_cu_seqlens_k[1 : 1 + size]\n        )\n\n        # Copy real page table\n        if precomputed.real_page_table is not None:\n            rows, cols = precomputed.real_page_table.shape\n            ref_real_page_table[:rows, :cols].copy_(precomputed.real_page_table)\n\n        # Copy FlashMLA metadata\n        if precomputed.flashmla_metadata is not None:\n            ref_flashmla_num_splits[: size + 1].copy_(\n                flashmla_num_splits_src[: size + 1]\n            )\n            ref_flashmla_metadata.copy_(flashmla_metadata_src)\n\n        ref_results.append(\n            {\n                \"cache_seqlens\": ref_cache_seqlens,\n                \"cu_seqlens_k\": ref_cu_seqlens_k,\n                \"page_table_1\": ref_page_table_1,\n                \"nsa_cache_seqlens\": ref_nsa_cache_seqlens,\n                \"nsa_cu_seqlens_k\": ref_nsa_cu_seqlens_k,\n                \"real_page_table\": ref_real_page_table,\n                \"flashmla_num_splits\": ref_flashmla_num_splits,\n                \"flashmla_metadata\": 
ref_flashmla_metadata,\n            }\n        )\n\n    # Compare results for all 3 backends\n    def check_tensor_equal(backend_idx, name, fused, ref):\n        if not torch.equal(fused, ref):\n            max_diff = (fused.float() - ref.float()).abs().max().item()\n            mismatched_elements = (fused != ref).sum().item()\n            total_elements = fused.numel()\n            raise RuntimeError(\n                f\"MULTI-BACKEND FUSED METADATA COPY VERIFICATION FAILED!\\n\"\n                f\"Backend: {backend_idx}\\n\"\n                f\"Tensor: {name}\\n\"\n                f\"Max difference: {max_diff}\\n\"\n                f\"Mismatched elements: {mismatched_elements}/{total_elements}\\n\"\n                f\"Fused shape: {fused.shape}, Ref shape: {ref.shape}\\n\"\n                f\"Batch size: {bs}\\n\"\n                f\"The multi-backend fused kernel produces different results than individual copies.\\n\"\n                f\"This indicates a bug in the fused metadata copy kernel.\"\n            )\n\n    # Verify all tensors for all 3 backends (multi-backend is DECODE mode only)\n    for idx in range(3):\n        fused = fused_results[idx]\n        ref = ref_results[idx]\n\n        check_tensor_equal(\n            idx,\n            \"cache_seqlens\",\n            fused[\"cache_seqlens\"],\n            ref[\"cache_seqlens\"],\n        )\n        check_tensor_equal(\n            idx,\n            \"cu_seqlens_k\",\n            fused[\"cu_seqlens_k\"],\n            ref[\"cu_seqlens_k\"],\n        )\n        # Multi-backend is DECODE mode only, so compare only [:, :max_len]\n        check_tensor_equal(\n            idx,\n            \"page_table_1\",\n            fused[\"page_table_1\"][:, : precomputed.max_len],\n            ref[\"page_table_1\"][:, : precomputed.max_len],\n        )\n        check_tensor_equal(\n            idx,\n            \"nsa_cache_seqlens\",\n            fused[\"nsa_cache_seqlens\"],\n            ref[\"nsa_cache_seqlens\"],\n   
     )\n        # DECODE mode uses bs for nsa_cu_seqlens_k size\n        check_tensor_equal(\n            idx,\n            \"nsa_cu_seqlens_k\",\n            fused[\"nsa_cu_seqlens_k\"][: bs + 1],\n            ref[\"nsa_cu_seqlens_k\"][: bs + 1],\n        )\n\n        if precomputed.real_page_table is not None:\n            rows, cols = precomputed.real_page_table.shape\n            check_tensor_equal(\n                idx,\n                \"real_page_table\",\n                fused[\"real_page_table\"][:rows, :cols],\n                ref[\"real_page_table\"][:rows, :cols],\n            )\n\n        if precomputed.flashmla_metadata is not None:\n            # DECODE mode uses bs + 1 for flashmla_num_splits\n            check_tensor_equal(\n                idx,\n                \"flashmla_num_splits\",\n                fused[\"flashmla_num_splits\"][: bs + 1],\n                ref[\"flashmla_num_splits\"][: bs + 1],\n            )\n            check_tensor_equal(\n                idx,\n                \"flashmla_metadata\",\n                fused[\"flashmla_metadata\"],\n                ref[\"flashmla_metadata\"],\n            )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/nsa/quant_k_cache.py",
    "content": "import torch\nimport triton\nimport triton.language as tl\n\n\ndef quantize_k_cache(cache_k):\n    return _quantize_k_cache_fast_wrapped(cache_k)\n\n\ndef quantize_k_cache_separate(\n    k_nope: torch.Tensor,\n    k_rope: torch.Tensor,\n    tile_size: int = 128,\n):\n    \"\"\"\n    Quantize k_nope and k_rope separately without concat, returns two tensors.\n\n    This avoids the concat operation and enables direct reuse of set_mla_kv_buffer_triton\n    by returning two separate byte tensors for the nope and rope parts.\n\n    Args:\n        k_nope: (num_tokens, dim_nope) or (num_tokens, 1, dim_nope)\n                Must have dim_nope=512 for FP8 MLA quantization\n        k_rope: (num_tokens, dim_rope) or (num_tokens, 1, dim_rope)\n                Must have dim_rope=64 for FP8 MLA quantization\n        tile_size: quantization tile size (default 128)\n\n    Returns:\n        Tuple of (nope_part, rope_part) where:\n        - nope_part: (num_tokens, 1, 528) as uint8 view, contains [nope_fp8(512) | scales(16)]\n        - rope_part: (num_tokens, 1, 128) as uint8 view, contains [rope_bf16_bytes(128)]\n\n        These two tensors can be directly passed to set_mla_kv_buffer_triton(kv_buffer, loc, nope_part, rope_part)\n    \"\"\"\n    # Squeeze middle dimension if present\n    k_nope_2d = k_nope.squeeze(1) if k_nope.ndim == 3 else k_nope\n    k_rope_2d = k_rope.squeeze(1) if k_rope.ndim == 3 else k_rope\n\n    num_tokens = k_nope_2d.shape[0]\n    dim_nope = k_nope_2d.shape[1]\n    dim_rope = k_rope_2d.shape[1]\n\n    # Validate dimensions for FP8 MLA\n    if dim_nope != 512:\n        raise ValueError(f\"Expected dim_nope=512 for FP8 MLA, got {dim_nope}\")\n    if dim_rope != 64:\n        raise ValueError(f\"Expected dim_rope=64 for FP8 MLA, got {dim_rope}\")\n    if k_rope_2d.shape[0] != num_tokens:\n        raise ValueError(\n            f\"k_nope and k_rope must have same num_tokens, got {num_tokens} vs {k_rope_2d.shape[0]}\"\n        )\n\n    return 
_quantize_k_cache_fast_separate(\n        k_nope=k_nope_2d, k_rope=k_rope_2d, group_size=tile_size\n    )\n\n\n# Copied from original\ndef _quantize_k_cache_ref(\n    input_k_cache: torch.Tensor,  # (num_blocks, block_size, h_k, d)\n    dv: int = 512,\n    tile_size: int = 128,\n) -> torch.Tensor:\n    \"\"\"\n    Quantize the k-cache\n    Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()\n    For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md\n    \"\"\"\n    assert dv % tile_size == 0\n    num_tiles = dv // tile_size\n    num_blocks, block_size, h_k, d = input_k_cache.shape\n    assert h_k == 1\n    input_k_cache = input_k_cache.squeeze(2)  # [num_blocks, block_size, d]\n    input_elem_size = input_k_cache.element_size()\n\n    result = torch.empty(\n        (num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),\n        dtype=torch.float8_e4m3fn,\n        device=input_k_cache.device,\n    )\n    result_k_nope_part = result[..., :dv]\n    result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)\n    result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)\n    result_k_rope_part[:] = input_k_cache[..., dv:]\n\n    for tile_idx in range(0, num_tiles):\n        cur_scale_factors_inv = (\n            torch.abs(\n                input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]\n            )\n            .max(dim=-1)\n            .values\n            / 448.0\n        )  # [num_blocks, block_size]\n        result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv\n\n        cur_scale_factors_inv.unsqueeze_(-1)  # [num_blocks, block_size, 1]\n        cur_quantized_nope = (\n            input_k_cache[\n                ..., tile_idx * tile_size : (tile_idx + 1) * tile_size\n            ].float()\n            / 
cur_scale_factors_inv.float()\n        ).to(torch.float8_e4m3fn)\n        result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (\n            cur_quantized_nope\n        )\n\n    result = result.view(num_blocks, block_size, 1, -1)\n    return result\n\n\ndef _quantize_k_cache_fast_wrapped(\n    input_k_cache: torch.Tensor,\n    dv: int = 512,\n    tile_size: int = 128,\n) -> torch.Tensor:\n    # TODO the final API may be 2D instead of 4D, thus we convert them here\n    num_blocks, block_size, _, dim_nope_and_rope = input_k_cache.shape\n    assert dv == 512\n    assert dim_nope_and_rope == 512 + 64\n    assert tile_size == 128\n    input_k_cache = input_k_cache.view((-1, dim_nope_and_rope))\n\n    # TODO deliberately split into two tensors, then upstream can provide the two tensors instead of concat into one\n    k_nope = input_k_cache[:, :dv]\n    k_rope = input_k_cache[:, dv:]\n\n    output = _quantize_k_cache_fast(k_nope=k_nope, k_rope=k_rope)\n\n    return output.view(num_blocks, block_size, 1, -1)\n\n\ndef _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128):\n    \"\"\"\n    :param k_nope: (num_tokens, dim_nope 512)\n    :param k_rope: (num_tokens, dim_rope 64)\n    \"\"\"\n\n    assert k_nope.dtype == torch.bfloat16\n    assert k_rope.dtype == torch.bfloat16\n\n    num_tokens, dim_nope = k_nope.shape\n    num_tokens_, dim_rope = k_rope.shape\n    assert num_tokens == num_tokens_\n    assert dim_nope == 512\n    assert dim_rope == 64\n    assert k_nope.dtype == k_rope.dtype\n    num_tiles = dim_nope // group_size\n\n    assert k_nope.stride(1) == 1\n    assert k_rope.stride(1) == 1\n\n    output = torch.empty(\n        (num_tokens, dim_nope + num_tiles * 4 + k_rope.element_size() * dim_rope),\n        dtype=torch.float8_e4m3fn,\n        device=k_nope.device,\n    )\n    output_nope_q = output[..., :dim_nope]\n    output_nope_s = output[..., dim_nope : dim_nope + num_tiles * 4].view(torch.float32)\n    output_rope = output[..., 
dim_nope + num_tiles * 4 :].view(torch.bfloat16)\n\n    num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)\n    assert num_blocks_per_token == 5\n\n    assert dim_nope % group_size == 0\n    NUM_NOPE_BLOCKS = dim_nope // group_size\n\n    _quantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](\n        output_nope_q,\n        output_nope_s,\n        output_rope,\n        k_nope,\n        k_rope,\n        output_nope_q.stride(0),\n        output_nope_s.stride(0),\n        output_rope.stride(0),\n        k_nope.stride(0),\n        k_rope.stride(0),\n        NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,\n        GROUP_SIZE=group_size,\n        DIM_NOPE=dim_nope,\n        DIM_ROPE=dim_rope,\n        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,\n        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,\n    )\n\n    return output\n\n\ndef _quantize_k_cache_fast_separate(k_nope, k_rope, group_size: int = 128):\n    \"\"\"\n    Quantize k_nope and k_rope in a single Triton kernel, directly outputting two separate tensors.\n\n    This avoids packing/unpacking and enables direct use with set_mla_kv_buffer_triton.\n\n    :param k_nope: (num_tokens, dim_nope 512) bfloat16\n    :param k_rope: (num_tokens, dim_rope 64) bfloat16\n    :param group_size: quantization tile size (default 128, kernel is tuned for this value)\n    :return: Tuple of (nope_part_u8, rope_part_u8)\n        - nope_part_u8: (num_tokens, 1, nope_part_bytes) uint8, layout [nope_fp8(dim_nope) | scales(num_tiles*4)]\n        - rope_part_u8: (num_tokens, 1, rope_part_bytes) uint8, layout [rope_bf16_bytes(dim_rope*2)]\n    \"\"\"\n    num_tokens, dim_nope = k_nope.shape\n    num_tokens_, dim_rope = k_rope.shape\n\n    assert num_tokens == num_tokens_, f\"k_nope and k_rope must have same num_tokens\"\n\n    # Ensure contiguous tensors for kernel\n    k_nope = k_nope.contiguous()\n    k_rope = k_rope.contiguous()\n\n    num_tiles = dim_nope // group_size\n\n    # Calculate byte sizes based on validated 
dimensions\n    # nope_part: [FP8 quantized data (dim_nope bytes)] + [FP32 scales (num_tiles * 4 bytes)]\n    # rope_part: [BF16 raw data (dim_rope * 2 bytes)]\n    nope_part_bytes = (\n        dim_nope + num_tiles * 4\n    )  # e.g., 512 + 4*4 = 528 for dim_nope=512, group_size=128\n    rope_part_bytes = (\n        dim_rope * k_rope.element_size()\n    )  # e.g., 64 * 2 = 128 for dim_rope=64, BF16\n\n    # Allocate two separate output buffers (as uint8 for direct byte-level access)\n    nope_part_u8 = torch.empty(\n        (num_tokens, nope_part_bytes), dtype=torch.uint8, device=k_nope.device\n    )\n    rope_part_u8 = torch.empty(\n        (num_tokens, rope_part_bytes), dtype=torch.uint8, device=k_rope.device\n    )\n\n    # Create typed views for the kernel to write into\n    # Fixed byte layout for nope_part: [nope_fp8 (dim_nope bytes) | scales_fp32 (num_tiles*4 bytes)]\n    # Fixed byte layout for rope_part: [rope_bf16 (dim_rope*2 bytes)]\n    nope_q_view = nope_part_u8[:, :dim_nope].view(torch.float8_e4m3fn)\n    nope_s_view = nope_part_u8[:, dim_nope:].view(torch.float32)\n    rope_view = rope_part_u8.view(torch.bfloat16)\n\n    # Kernel launch parameters\n    num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)\n    NUM_NOPE_BLOCKS = dim_nope // group_size\n\n    # Use the same kernel as _quantize_k_cache_fast (reuse existing implementation)\n    _quantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](\n        nope_q_view,\n        nope_s_view,\n        rope_view,\n        k_nope,\n        k_rope,\n        nope_q_view.stride(0),\n        nope_s_view.stride(0),\n        rope_view.stride(0),\n        k_nope.stride(0),\n        k_rope.stride(0),\n        NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,\n        GROUP_SIZE=group_size,\n        DIM_NOPE=dim_nope,\n        DIM_ROPE=dim_rope,\n        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,\n        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,\n    )\n\n    # Add middle dimension for compatibility 
with set_mla_kv_buffer_triton\n    return nope_part_u8.unsqueeze(1), rope_part_u8.unsqueeze(1)\n\n\n@triton.jit\ndef _quantize_k_cache_fast_kernel(\n    output_nope_q_ptr,\n    output_nope_s_ptr,\n    output_rope_ptr,\n    k_nope_ptr,\n    k_rope_ptr,\n    output_nope_q_stride_0: int,\n    output_nope_s_stride_0: int,\n    output_rope_stride_0: int,\n    k_nope_stride_0: int,\n    k_rope_stride_0: int,\n    NUM_NOPE_BLOCKS: tl.constexpr,\n    GROUP_SIZE: tl.constexpr,\n    DIM_NOPE: tl.constexpr,\n    DIM_ROPE: tl.constexpr,\n    FP8_MIN: tl.constexpr,\n    FP8_MAX: tl.constexpr,\n):\n    token_id = tl.program_id(0)\n    raw_block_id = tl.program_id(1)\n\n    if raw_block_id < NUM_NOPE_BLOCKS:\n        # a. quant nope\n        effective_block_id = raw_block_id\n\n        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)\n        mask = offs < DIM_NOPE\n        ptr = k_nope_ptr + token_id * k_nope_stride_0 + offs\n\n        y = tl.load(ptr, mask=mask, other=0.0).to(tl.float32)\n\n        # the ref impl does not have a `tl.maximum(... eps)`, so we omit it here to match\n        y_s = tl.max(tl.abs(y)) / FP8_MAX\n        y_s_inv = 1.0 / y_s\n        y_q = tl.clamp(y * y_s_inv, FP8_MIN, FP8_MAX).to(\n            output_nope_q_ptr.dtype.element_ty\n        )\n\n        dst_q_ptr = output_nope_q_ptr + token_id * output_nope_q_stride_0 + offs\n        dst_s_ptr = (\n            output_nope_s_ptr + token_id * output_nope_s_stride_0 + effective_block_id\n        )\n\n        tl.store(dst_q_ptr, y_q, mask=mask)\n        tl.store(dst_s_ptr, y_s)\n    else:\n        # b. 
copy rope\n        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS\n\n        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)\n        mask = offs < DIM_ROPE\n\n        src_ptr = k_rope_ptr + token_id * k_rope_stride_0 + offs\n        dst_ptr = output_rope_ptr + token_id * output_rope_stride_0 + offs\n\n        data = tl.load(src_ptr, mask=mask)\n        tl.store(dst_ptr, data, mask=mask)\n\n\nif __name__ == \"__main__\":\n    import dequant_k_cache\n\n    for num_blocks, block_size in [\n        (1, 1),\n        (10, 64),\n    ]:\n        dim_nope_and_rope = 512 + 64\n\n        input_k_cache = torch.randn(\n            (num_blocks, block_size, 1, dim_nope_and_rope),\n            dtype=torch.bfloat16,\n            device=\"cuda\",\n        )\n\n        ref_quant = _quantize_k_cache_ref(input_k_cache)\n        actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)\n\n        ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)\n        ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)\n        actual_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(\n            actual_quant\n        )\n\n        print(f\"{ref_ref_dequant=}\")\n        print(f\"{actual_actual_dequant=}\")\n        print(f\"{actual_actual_dequant - ref_ref_dequant=}\")\n        print(f\"{torch.mean(ref_ref_dequant - actual_actual_dequant)=}\")\n\n        # TODO too different?\n        torch.testing.assert_close(\n            ref_ref_dequant, ref_actual_dequant, atol=0.2, rtol=0.2\n        )\n        torch.testing.assert_close(\n            ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2\n        )\n\n        # test dequant_k_cache_paged\n        page_table_1 = torch.arange(\n            num_blocks * block_size, dtype=torch.int32, device=\"cuda\"\n        )\n        actual_dequant_paged = dequant_k_cache.dequantize_k_cache_paged(\n            actual_quant, page_table_1\n        
).reshape(actual_actual_dequant.shape)\n        print(f\"{torch.mean(actual_actual_dequant - actual_dequant_paged)=}\")\n        torch.testing.assert_close(\n            ref_ref_dequant, actual_dequant_paged, atol=0.2, rtol=0.2\n        )\n\n    print(\"Passed\")\n\n    # Test quantize_k_cache_separate: verify output matches concat path\n    print(\"\\nTesting quantize_k_cache_separate...\")\n    for num_tokens in [64, 100]:\n        dim_nope = 512\n        dim_rope = 64\n\n        k_nope = torch.randn(\n            num_tokens, 1, dim_nope, dtype=torch.bfloat16, device=\"cuda\"\n        )\n        k_rope = torch.randn(\n            num_tokens, 1, dim_rope, dtype=torch.bfloat16, device=\"cuda\"\n        )\n\n        # Old path: concat then quantize\n        k_concat = torch.cat([k_nope, k_rope], dim=-1).squeeze(1)  # (num_tokens, 576)\n        old_output = quantize_k_cache(k_concat.unsqueeze(1).unsqueeze(1))  # 4D input\n        old_output = old_output.squeeze(1).squeeze(1)  # Back to (num_tokens, 656)\n\n        # New path: quantize separately\n        nope_part, rope_part = quantize_k_cache_separate(k_nope, k_rope)\n        new_bytes = torch.cat([nope_part.squeeze(1), rope_part.squeeze(1)], dim=-1)\n\n        # Compare byte-level equality\n        old_bytes = old_output.view(torch.uint8)\n\n        if old_bytes.shape != new_bytes.shape:\n            raise RuntimeError(\n                f\"Shape mismatch: {old_bytes.shape} vs {new_bytes.shape}\"\n            )\n\n        diff_bytes = (old_bytes != new_bytes).sum().item()\n        if diff_bytes > 0:\n            max_diff = (old_bytes.float() - new_bytes.float()).abs().max().item()\n            raise RuntimeError(\n                f\"quantize_k_cache_separate output doesn't match concat path: \"\n                f\"{diff_bytes} differing bytes, max_diff={max_diff}\"\n            )\n\n        print(f\"  num_tokens={num_tokens}: PASSED (outputs match byte-wise)\")\n\n    print(\"quantize_k_cache_separate tests 
passed!\")\n\n    print(\"\\nDo benchmark...\")\n\n    for num_blocks, block_size in [\n        (1, 64),\n        (64, 64),\n        (128, 64),\n        (256, 64),\n        (512, 64),\n        (1024, 64),\n        (2048, 64),\n    ]:\n        dim_nope_and_rope = 512 + 64\n\n        input_k_cache = torch.randn(\n            (num_blocks, block_size, 1, dim_nope_and_rope),\n            dtype=torch.bfloat16,\n            device=\"cuda\",\n        )\n\n        actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)\n\n        page_table_1 = torch.arange(\n            num_blocks * block_size, dtype=torch.int32, device=\"cuda\"\n        )\n\n        def run_ans():\n            return dequant_k_cache.dequantize_k_cache_paged(actual_quant, page_table_1)\n\n        ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000  # type: ignore\n        print(f\"seq_kv: {num_blocks * block_size}, time: {ans_time * 1e6: 4.0f} us\")\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/nsa/tilelang_kernel.py",
    "content": "from typing import Optional, Tuple\n\nimport tilelang\nimport tilelang.language as T\nimport torch\n\nfrom sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz\nfrom sglang.srt.utils import is_gfx95_supported, is_hip\n\ntilelang.set_log_level(\"WARNING\")\n\npass_configs = {\n    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,\n    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,\n}\n# TL_DISABLE_FAST_MATH was deprecated in tilelang v0.1.7.post1; fall back to TL_ENABLE_FAST_MATH=False\nif hasattr(tilelang.PassConfigKey, \"TL_DISABLE_FAST_MATH\"):\n    pass_configs[tilelang.PassConfigKey.TL_DISABLE_FAST_MATH] = True\nelif hasattr(tilelang.PassConfigKey, \"TL_ENABLE_FAST_MATH\"):\n    pass_configs[tilelang.PassConfigKey.TL_ENABLE_FAST_MATH] = False\n\n_is_hip = is_hip()\n_is_gfx95_supported = is_gfx95_supported()\n_is_fp8_fnuz = is_fp8_fnuz()\n\nBF16 = \"bfloat16\"\nFP8 = \"float8_e4m3fnuz\" if _is_fp8_fnuz else \"float8_e4m3\"\nFP32 = \"float32\"\n\n\ndef fast_log2_ceil(x):\n    bits_x = T.reinterpret(\"uint32\", x)\n    exp_x = (bits_x >> 23) & 0xFF\n    man_bits = bits_x & ((1 << 23) - 1)\n    return T.Cast(\"int32\", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))\n\n\ndef fast_pow2(x):\n    bits_x = (x + 127) << 23\n    return T.reinterpret(\"float32\", bits_x)\n\n\ndef fast_round_scale(amax, fp8_max_inv):\n    return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))\n\n\n@tilelang.jit(pass_configs=pass_configs)\ndef act_quant_kernel(\n    N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False\n):\n    M = T.symbolic(\"M\")\n    fp8_min = -224.0 if _is_fp8_fnuz else -448.0\n    fp8_max = 224.0 if _is_fp8_fnuz else 448.0\n    fp8_max_inv = 1 / fp8_max\n    num_stages = 0 if round_scale else 2\n    blk_m = 32\n    group_size = 128\n\n    @T.prim_func\n    def act_quant_kernel_(\n        X: T.Tensor[(M, N), in_dtype],\n        Y: T.Tensor[(M, N), out_dtype],\n        S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],\n    ):\n        with 
T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (\n            pid_m,\n            pid_n,\n        ):\n            x_shared = T.alloc_shared((blk_m, group_size), in_dtype)\n            x_local = T.alloc_fragment((blk_m, group_size), in_dtype)\n            amax_local = T.alloc_fragment((blk_m,), scale_dtype)\n            s_local = T.alloc_fragment((blk_m,), scale_dtype)\n            y_local = T.alloc_fragment((blk_m, group_size), out_dtype)\n            y_shared = T.alloc_shared((blk_m, group_size), out_dtype)\n\n            for _ in T.Pipelined(1, num_stages=num_stages):\n                T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)\n                T.copy(x_shared, x_local)\n                T.reduce_absmax(x_local, amax_local, dim=1)\n                for i in T.Parallel(blk_m):\n                    amax_local[i] = T.max(amax_local[i], 1e-4)\n                    if round_scale:\n                        s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)\n                    else:\n                        s_local[i] = amax_local[i] * fp8_max_inv\n                for i, j in T.Parallel(blk_m, group_size):\n                    y_local[i, j] = T.clamp(\n                        x_local[i, j] / s_local[i], fp8_min, fp8_max\n                    )\n                for i in T.Parallel(blk_m):\n                    S[pid_m * blk_m + i, pid_n] = s_local[i]\n                T.copy(y_local, y_shared)\n                T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])\n\n    return act_quant_kernel_\n\n\ndef act_quant(\n    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Quantizes the input tensor `x` using block-wise quantization.\n\n    Args:\n        x (torch.Tensor): The input tensor to be quantized. 
Must be contiguous and its last dimension size must be divisible by `block_size`.\n        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.\n        scale_fmt (Optional[str], optional): The format of the scale. Default is None.\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:\n            - The quantized tensor with dtype `torch.float8_e4m3fn`.\n            - A tensor of scaling factors with dtype `torch.float32`.\n    \"\"\"\n    assert x.is_contiguous(), \"Input tensor must be contiguous\"\n    assert (\n        x.size(-1) % block_size == 0\n    ), f\"Last dimension size must be divisible by block_size (block_size={block_size})\"\n    N = x.size(-1)\n    if _is_fp8_fnuz:\n        y = torch.empty_like(x, dtype=torch.float8_e4m3fnuz)\n    else:\n        y = torch.empty_like(x, dtype=torch.float8_e4m3fn)\n    s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)\n    kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)\n    kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))\n    return y, s\n\n\n@tilelang.jit(out_idx=[4], pass_configs=pass_configs)\ndef fp8_index_kernel(h: int, d: int, clear_accum=True):\n    b = T.symbolic(\"b\")\n    m = T.symbolic(\"m\")\n    n = T.symbolic(\"n\")\n\n    blk_n1 = 512\n    blk_n2 = 128\n\n    @T.prim_func\n    def fp8_index_kernel_(\n        q: T.Tensor[(b, m, h, d), FP8],\n        q_s: T.Tensor[(b, m, h), FP32],\n        k: T.Tensor[(b, n, d), FP8],\n        k_s: T.Tensor[(b, n), FP32],\n        o: T.Tensor[(b, m, n), FP32],\n    ) -> None:\n        with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):\n            q_smem = T.alloc_shared((h, d), FP8)\n            T.copy(q[i_b, i_m, 0, 0], q_smem)\n\n            q_s_frag = T.alloc_fragment(h, FP32)\n            T.copy(q_s[i_b, i_m, 0], q_s_frag)\n\n            for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):\n                k_smem = 
T.alloc_shared((blk_n2, d), FP8)\n                T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)\n\n                k_s_frag = T.alloc_fragment(blk_n2, FP32)\n                T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)\n\n                logits = T.alloc_fragment((blk_n2, h), FP32)\n                if not clear_accum:\n                    T.fill(logits, 0)\n                T.gemm(\n                    k_smem,\n                    q_smem,\n                    logits,\n                    transpose_A=False,\n                    transpose_B=True,\n                    clear_accum=clear_accum,\n                )\n\n                for i_h, i3_n in T.Parallel(h, blk_n2):\n                    logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]\n\n                logits_sum = T.alloc_fragment(blk_n2, FP32)\n                T.reduce_sum(logits, logits_sum, dim=1)\n\n                for i3_n in T.Parallel(blk_n2):\n                    logits_sum[i3_n] *= k_s_frag[i3_n]\n\n                T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])\n\n    return fp8_index_kernel_\n\n\ndef fp8_index(\n    q: torch.Tensor,\n    q_s: torch.Tensor,\n    k: torch.Tensor,\n    k_s: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"\n    Perform index score using FP8 precision.\n\n    Args:\n        q (torch.Tensor): The Q tensor, must be contiguous.\n        q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.\n        k (torch.Tensor): The K tensor, must be contiguous.\n        k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.\n\n        fp8 q @ fp8 k -> fp32 logits\n        relu(fp32 logits) * q_s (weights) -> fp32 logits\n        fp32 logits -> fp32 logits_sum\n        fp32 logits_sum * k_s (e8m0) -> fp32 index_score\n    \"\"\"\n    if _is_hip:\n        return fp8_index_kernel(q.shape[2], q.shape[3], False)(q, q_s, k, k_s)\n    else:\n        return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, 
k, k_s)\n\n\n@tilelang.jit(\n    out_idx=[-1],\n    pass_configs={\n        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,\n        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,\n    },\n)\ndef sparse_attention_fwd_kernel_v1(\n    num_heads,\n    dim,\n    tail_dim,\n    topk,\n    *,\n    kv_group=1,\n    sm_scale=None,\n    is_causal=True,\n    block_I=64,\n    num_stages=2,\n    threads=256,\n):\n    assert dim == tilelang.math.next_power_of_2(\n        dim\n    ), f\"haven't checked padding correctness yet, dim={dim}\"\n    assert tail_dim == tilelang.math.next_power_of_2(\n        tail_dim\n    ), f\"haven't checked padding correctness yet, tail_dim={tail_dim}\"\n    assert is_causal == True, \"non-causal is not supported\"\n    assert (\n        topk % block_I == 0\n    ), \"otherwise will load some index=0 thus causing wrong kv to be loaded\"\n    if sm_scale is None:\n        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504  # log2(e)\n    else:\n        sm_scale = sm_scale * 1.44269504  # log2(e)\n\n    batch = T.symbolic(\"batch\")\n    seq_len = T.symbolic(\"seq_len\")\n    seq_len_kv = T.symbolic(\"seq_len_kv\")\n\n    head_kv = num_heads // kv_group\n    q_shape = [batch, seq_len, num_heads, dim + tail_dim]\n    kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim]\n    o_shape = [batch, seq_len, num_heads, dim]\n    indices_shape = [batch, seq_len, kv_group, topk]\n    indices_dtype = \"int32\"\n    dtype = \"bfloat16\"\n    accum_dtype = \"float\"\n\n    H = head_kv\n    padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)\n    if padded_H != H:\n        assert kv_group == 1\n    BI = block_I\n    NI = tilelang.cdiv(topk, block_I)\n    D = dim\n    D_tail = tail_dim\n\n    if head_kv > 64:\n        assert head_kv % 64 == 0, \"head_kv should be a multiple of 64\"\n        REPLICATE_H = head_kv // 64\n    else:\n        REPLICATE_H = 1\n\n    H_per_block = padded_H if REPLICATE_H == 1 else 64\n\n    @T.prim_func\n    def main(\n  
      Q: T.Tensor(q_shape, dtype),  # type: ignore\n        KV: T.Tensor(kv_shape, dtype),  # type: ignore\n        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore\n        Output: T.Tensor(o_shape, dtype),  # type: ignore\n    ):\n        with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (\n            bx,\n            by,\n            bz,\n        ):\n            Q_shared = T.alloc_shared([H_per_block, D], dtype)\n            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)\n            KV_shared = T.alloc_shared([BI, D], dtype)\n            K_tail_shared = T.alloc_shared([BI, D_tail], dtype)\n            O_shared = T.alloc_shared([H_per_block, D], dtype)\n            mask = T.alloc_fragment([BI], \"bool\")\n\n            acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)\n            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)\n            S_shared = T.alloc_shared([H_per_block, BI], dtype)\n            sumexp = T.alloc_fragment([H_per_block], accum_dtype)\n            sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)\n            alpha = T.alloc_fragment([H_per_block], accum_dtype)\n            m_i = T.alloc_fragment([H_per_block], accum_dtype)\n            m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)\n\n            T.fill(acc_o, 0)\n            T.fill(sumexp, 0)\n            T.fill(m_i, -(2**30))  # avoid -inf - inf to cause nan\n\n            b_i, g_i = by, bz\n            s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)\n            q_i = s_i\n            max_kv_i = q_i\n\n            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)\n            H1 = H0 + H_per_block\n\n            T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared)\n            T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)\n\n            for i_i in T.Pipelined(NI, num_stages=num_stages):\n\n                for bi_i in T.Parallel(BI):\n                    mask[bi_i] = Indices[b_i, s_i, 
g_i, i_i * BI + bi_i] >= 0\n\n                for bi_i, d_i in T.Parallel(BI, D):\n                    KV_shared[bi_i, d_i] = KV[\n                        b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i\n                    ]\n                for bi_i, d_i in T.Parallel(BI, D_tail):\n                    K_tail_shared[bi_i, d_i] = KV[\n                        b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i\n                    ]\n\n                for h_i, bi_i in T.Parallel(H_per_block, BI):\n                    acc_s[h_i, bi_i] = T.if_then_else(\n                        mask[bi_i], 0, -T.infinity(acc_s.dtype)\n                    )\n                T.gemm(\n                    Q_shared,\n                    KV_shared,\n                    acc_s,\n                    transpose_B=True,\n                    policy=T.GemmWarpPolicy.FullCol,\n                )\n                T.gemm(\n                    Q_tail_shared,\n                    K_tail_shared,\n                    acc_s,\n                    transpose_B=True,\n                    policy=T.GemmWarpPolicy.FullCol,\n                )\n                T.copy(m_i, m_i_prev)\n                T.reduce_max(acc_s, m_i, dim=1, clear=False)\n                for h_i in T.Parallel(H_per_block):\n                    alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)\n                for h_i, bi_i in T.Parallel(H_per_block, BI):\n                    acc_s[h_i, bi_i] = T.exp2(\n                        acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale\n                    )\n                T.reduce_sum(acc_s, sumexp_i, dim=1)  # is this an accumulate operator?\n                for h_i in T.Parallel(H_per_block):\n                    sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]\n                for h_i, d_i in T.Parallel(H_per_block, D):\n                    acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]\n\n                T.copy(acc_s, S_shared)\n                T.gemm(S_shared, 
KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)\n\n            # Rescale\n            for h_i, d_i in T.Parallel(H_per_block, D):\n                acc_o[h_i, d_i] /= sumexp[h_i]\n            for h_i in T.Parallel(H_per_block):\n                sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale\n\n            T.copy(acc_o, O_shared)\n            T.copy(acc_o, Output[b_i, s_i, H0:H1, :])\n\n    return main\n\n\n@tilelang.jit(\n    out_idx=[-1],\n    compile_flags=[\n        \"-O3\",\n        \"-Wno-deprecated-declarations\",\n        \"-U__CUDA_NO_HALF_OPERATORS__\",\n        \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n        \"-U__CUDA_NO_HALF2_OPERATORS__\",\n        \"-U__CUDA_NO_BFLOAT16_CONVERSIONS__\",\n        \"--expt-relaxed-constexpr\",\n        \"--expt-extended-lambda\",\n        \"--ptxas-options=-v,--register-usage-level=10\",\n        \"-DNDEBUG\",\n    ],\n)  # type: ignore\ndef sparse_attention_fwd_kernel_v2(\n    num_heads: int,\n    dim: int,\n    tail_dim: int,\n    topk: int,\n    *,\n    kv_group: int = 1,\n    sm_scale: Optional[float] = None,\n    block_I: int = 64,\n):\n    assert dim == tilelang.math.next_power_of_2(\n        dim\n    ), f\"haven't checked padding correctness yet, dim={dim}\"\n    assert tail_dim == tilelang.math.next_power_of_2(\n        tail_dim\n    ), f\"haven't checked padding correctness yet, tail_dim={tail_dim}\"\n    assert (\n        topk % block_I == 0\n    ), \"otherwise will load some index=0 thus causing wrong kv to be loaded\"\n    if sm_scale is None:\n        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504  # log2(e)\n    else:\n        sm_scale = sm_scale * 1.44269504  # log2(e)\n    threads = 384\n\n    batch = T.symbolic(\"batch\")\n    qo_len = T.symbolic(\"seq_len\")\n    num_pages = T.symbolic(\"num_pages\")\n\n    q_shape = [batch, qo_len, num_heads, dim + tail_dim]\n    kv_shape = [batch, num_pages, kv_group, dim + tail_dim]\n    o_shape = [batch, qo_len, num_heads, dim]\n    indices_shape = [batch, 
qo_len, kv_group, topk]\n\n    indices_dtype = \"int32\"\n    dtype = \"bfloat16\"\n    accum_dtype = \"float\"\n\n    H = num_heads\n    padded_H = max(tilelang.math.next_power_of_2(num_heads), 16)\n    if padded_H != H:\n        assert kv_group == 1\n    BI = block_I\n    NI = tilelang.cdiv(topk, block_I)\n    assert NI % 2 == 0, \"NI should be a multiple of 2\"\n    D = dim\n    D_tail = tail_dim\n    if num_heads > 64:\n        assert num_heads % 64 == 0, \"head_kv should be a multiple of 64\"\n        REPLICATE_H = num_heads // 64\n    else:\n        REPLICATE_H = 1\n\n    H_per_block = padded_H if REPLICATE_H == 1 else 64\n\n    @T.prim_func\n    def main(\n        Q: T.Tensor(q_shape, dtype),  # type: ignore\n        KV: T.Tensor(kv_shape, dtype),  # type: ignore\n        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore\n        Output: T.Tensor(o_shape, dtype),  # type: ignore\n    ):\n        \"\"\"\n        Q: [b, qo_len, H, D + D_tail] (bfloat16)\n        KV: [b, num_pages, kv_group, D + D_tail] (bfloat16)\n        Indices: [b, qo_len, kv_group, topk] (int32)\n        \"\"\"\n\n        with T.Kernel(qo_len * REPLICATE_H, batch, 1, threads=threads) as (bx, by, bz):  # type: ignore\n            Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype)\n            Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype)\n            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)\n            KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype)\n            KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype)\n            KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype)\n            KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype)\n            K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype)\n            K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype)\n            O_shared_l = Q_shared_l\n            O_shared_r = Q_shared_r\n            is_kv_valid_0 = T.alloc_shared([BI], \"bool\", scope=\"shared\")\n           
 is_kv_valid_1 = T.alloc_shared([BI], \"bool\", scope=\"shared\")\n\n            acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype)\n            acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype)\n            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)\n            S_shared = T.alloc_shared([H_per_block, BI], dtype)\n            sumexp = T.alloc_fragment([H_per_block], accum_dtype)\n            sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype)\n            sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)\n            alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope=\"shared\")\n            alpha_local = T.alloc_fragment([H_per_block], accum_dtype)\n            m_i = T.alloc_fragment([H_per_block], accum_dtype)\n            m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)\n            indices_local = T.alloc_local([1], indices_dtype)\n            indices_tmp = T.alloc_local([1], indices_dtype)\n\n            bar_q = T.alloc_barrier(arrive_count=384)\n            bar_k_0_ready = T.alloc_barrier(arrive_count=128)\n            bar_k_1_ready = T.alloc_barrier(arrive_count=128)\n            bar_k_0_free = T.alloc_barrier(arrive_count=256)\n            bar_k_1_free = T.alloc_barrier(arrive_count=256)\n            bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256)\n            bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)\n\n            bar_0_128 = T.alloc_barrier(arrive_count=128)\n            bar_1_128 = T.alloc_barrier(arrive_count=128)\n            bar_2_128 = T.alloc_barrier(arrive_count=128)\n            bar_final = T.alloc_barrier(arrive_count=128)\n\n            b_i, g_i = by, bz\n            s_i = bx if REPLICATE_H == 1 else bx // REPLICATE_H\n\n            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)\n            H1 = H0 + H_per_block\n\n            tx = T.get_thread_binding()\n\n            T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], 
Q_shared_l)\n            T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r)\n            T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)\n            T.barrier_arrive(bar_q)\n\n            if tx < 128:\n                T.set_max_nreg(240, 1)\n                T.fill(sumexp, 0)\n                T.fill(m_i, -(2**30))  # avoid -inf - inf to cause nan\n                T.fill(acc_o_l, 0)\n                T.barrier_wait(bar_q, 0)\n\n                for i_i in T.serial(T.ceildiv(NI, 2)):\n                    # Buffer 0\n                    # with sync_at(bar_0_128, 0):\n                    T.barrier_wait(bar_k_0_ready[0], (i_i & 1))\n                    T.barrier_arrive(bar_0_128)\n                    T.barrier_wait(bar_0_128, 0)\n\n                    for h_i, bi_i in T.Parallel(H_per_block, BI):\n                        acc_s[h_i, bi_i] = T.if_then_else(\n                            is_kv_valid_0[bi_i], 0, -T.infinity(acc_s.dtype)\n                        )\n                    T.gemm(\n                        Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1\n                    )\n                    T.gemm(\n                        Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1\n                    )\n                    T.gemm(\n                        Q_tail_shared,\n                        K_tail_shared_0,\n                        acc_s,\n                        transpose_B=True,\n                        wg_wait=-1,\n                    )\n\n                    T.wait_wgmma(0)\n\n                    if i_i != 0:\n                        T.barrier_arrive(bar_sScale_and_sS_free)\n                        T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)\n\n                    T.copy(m_i, m_i_prev)\n                    T.reduce_max(acc_s, m_i, dim=1, clear=False)\n                    for h_i in T.Parallel(H_per_block):\n                        alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)\n                    for 
h_i, bi_i in T.Parallel(H_per_block, BI):\n                        acc_s[h_i, bi_i] = T.exp2(\n                            acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale\n                        )\n                    T.reduce_sum(\n                        acc_s, sumexp_i, dim=1\n                    )  # is this a accumulate operator?\n                    for h_i in T.Parallel(H_per_block):\n                        sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]\n                    for h_i, d_i in T.Parallel(H_per_block, D // 2):\n                        acc_o_l[h_i, d_i] *= alpha_local[h_i]\n                    T.copy(alpha_local, alpha_shared)\n\n                    T.copy(acc_s, S_shared)\n                    T.gemm(S_shared, KV_shared_0_l, acc_o_l)\n\n                    T.barrier_arrive(bar_sScale_and_sS_ready)\n                    T.barrier_arrive(bar_k_0_free[0])\n\n                    # Buffer 1\n                    T.barrier_wait(bar_k_1_ready[0], (i_i & 1))\n                    T.barrier_arrive(bar_0_128)\n                    T.barrier_wait(bar_0_128, 1)\n\n                    for h_i, bi_i in T.Parallel(H_per_block, BI):\n                        acc_s[h_i, bi_i] = T.if_then_else(\n                            is_kv_valid_1[bi_i], 0, -T.infinity(acc_s.dtype)\n                        )\n                    T.gemm(\n                        Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1\n                    )\n                    T.gemm(\n                        Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1\n                    )\n                    T.gemm(\n                        Q_tail_shared,\n                        K_tail_shared_1,\n                        acc_s,\n                        transpose_B=True,\n                        wg_wait=-1,\n                    )\n\n                    T.wait_wgmma(0)\n\n                    T.barrier_arrive(bar_sScale_and_sS_free)\n                    
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1)\n\n                    T.copy(m_i, m_i_prev)\n                    T.reduce_max(acc_s, m_i, dim=1, clear=False)\n                    for h_i in T.Parallel(H_per_block):\n                        alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)\n                    for h_i, bi_i in T.Parallel(H_per_block, BI):\n                        acc_s[h_i, bi_i] = T.exp2(\n                            acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale\n                        )\n                    T.reduce_sum(\n                        acc_s, sumexp_i, dim=1\n                    )  # is this a accumulate operator?\n                    for h_i in T.Parallel(H_per_block):\n                        sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]\n                    for h_i, d_i in T.Parallel(H_per_block, D // 2):\n                        acc_o_l[h_i, d_i] *= alpha_local[h_i]\n                    T.copy(alpha_local, alpha_shared)\n\n                    T.copy(acc_s, S_shared)\n                    T.gemm(S_shared, KV_shared_1_l, acc_o_l)\n\n                    T.barrier_arrive(bar_sScale_and_sS_ready)\n                    T.barrier_arrive(bar_k_1_free[0])\n\n                # Rescale\n                for h_i in T.Parallel(H_per_block):\n                    sum_exp_shared[h_i] = sumexp[h_i]\n                T.barrier_arrive(bar_final)\n                for h_i, d_i in T.Parallel(H_per_block, D // 2):\n                    acc_o_l[h_i, d_i] /= sumexp[h_i]\n                for h_i in T.Parallel(H_per_block):\n                    sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale\n                T.copy(acc_o_l, O_shared_l)\n                T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2])\n            elif tx >= 128 and tx < 256:\n                # T.set_max_nreg(168, 1)\n                T.fill(acc_o_r, 0)\n                for i_i in T.serial(T.ceildiv(NI, 2)):\n                    # 
Buffer 0\n                    T.barrier_arrive(bar_sScale_and_sS_ready)\n                    T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1))\n                    T.barrier_arrive(bar_1_128)\n                    T.barrier_wait(bar_1_128, 0)\n                    for h_i, d_i in T.Parallel(H_per_block, D // 2):\n                        acc_o_r[h_i, d_i] *= alpha_shared[h_i]\n                    T.gemm(S_shared, KV_shared_0_r, acc_o_r)\n                    T.barrier_arrive(bar_k_0_free[0])\n                    T.barrier_arrive(bar_sScale_and_sS_free)\n\n                    # Buffer 1\n                    T.barrier_arrive(bar_sScale_and_sS_ready)\n                    T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1))\n                    T.barrier_arrive(bar_1_128)\n                    T.barrier_wait(bar_1_128, 1)\n                    for h_i, d_i in T.Parallel(H_per_block, D // 2):\n                        acc_o_r[h_i, d_i] *= alpha_shared[h_i]\n                    T.gemm(S_shared, KV_shared_1_r, acc_o_r)\n                    T.barrier_arrive(bar_k_1_free[0])\n                    if i_i != T.ceildiv(NI, 2) - 1:\n                        T.barrier_arrive(bar_sScale_and_sS_free)\n\n                # Rescale\n                T.barrier_wait(bar_final, 0)\n                for h_i, d_i in T.Parallel(H_per_block, D // 2):\n                    acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]\n\n                T.copy(acc_o_r, O_shared_r)\n                T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D])\n            elif tx >= 256:\n                # producer\n                T.set_max_nreg(80, 0)\n                indices_local[0] = 0\n                for i_i in T.serial(T.ceildiv(NI, 2)):\n                    # Buffer 0\n                    T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))\n                    T.barrier_arrive(bar_2_128)\n                    T.barrier_wait(bar_2_128, 0)\n\n                    for r in T.serial(4):\n                        
indices_tmp[0] = Indices[\n                            b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8\n                        ]\n                        is_kv_valid_0[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0\n                        if is_kv_valid_0[r * 16 + (tx - 256) // 8]:\n                            indices_local[0] = indices_tmp[0]\n\n                        with T.attr(\"default\", \"async_scope\", 1):  # type: ignore\n                            for u in T.serial(4):\n                                for v in T.vectorized(8):\n                                    KV_shared_0_l[\n                                        r * 16 + (tx - 256) // 8,\n                                        64 * u + (tx - 256) % 8 * 8 + v,\n                                    ] = KV[\n                                        b_i,\n                                        indices_local[0],\n                                        g_i,\n                                        64 * u + (tx - 256) % 8 * 8 + v,\n                                    ]\n                                    KV_shared_0_r[\n                                        r * 16 + (tx - 256) // 8,\n                                        64 * u + (tx - 256) % 8 * 8 + v,\n                                    ] = KV[\n                                        b_i,\n                                        indices_local[0],\n                                        g_i,\n                                        D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,\n                                    ]\n                        with T.attr(\"default\", \"async_scope\", 1):  # type: ignore\n                            for v in T.vectorized(8):\n                                K_tail_shared_0[\n                                    r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v\n                                ] = KV[\n                                    b_i,\n                                    indices_local[0],\n             
                       g_i,\n                                    D + (tx - 256) % 8 * 8 + v,\n                                ]\n\n                    T.cp_async_barrier_noinc(bar_k_0_ready[0])\n\n                    # Buffer 1\n                    T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))\n                    T.barrier_arrive(bar_2_128)\n                    T.barrier_wait(bar_2_128, 1)\n\n                    for r in T.serial(4):\n                        indices_tmp[0] = Indices[\n                            b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8\n                        ]\n                        is_kv_valid_1[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0\n                        if is_kv_valid_1[r * 16 + (tx - 256) // 8]:\n                            indices_local[0] = indices_tmp[0]\n\n                        with T.attr(\"default\", \"async_scope\", 1):  # type: ignore\n                            for u in T.serial(4):\n                                for v in T.vectorized(8):\n                                    KV_shared_1_l[\n                                        r * 16 + (tx - 256) // 8,\n                                        64 * u + (tx - 256) % 8 * 8 + v,\n                                    ] = KV[\n                                        b_i,\n                                        indices_local[0],\n                                        g_i,\n                                        64 * u + (tx - 256) % 8 * 8 + v,\n                                    ]\n                                    KV_shared_1_r[\n                                        r * 16 + (tx - 256) // 8,\n                                        64 * u + (tx - 256) % 8 * 8 + v,\n                                    ] = KV[\n                                        b_i,\n                                        indices_local[0],\n                                        g_i,\n                                        D // 2 + 64 * u + (tx - 256) % 8 * 8 + 
v,\n                                    ]\n                        with T.attr(\"default\", \"async_scope\", 1):  # type: ignore\n                            for v in T.vectorized(8):\n                                K_tail_shared_1[\n                                    r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v\n                                ] = KV[\n                                    b_i,\n                                    indices_local[0],\n                                    g_i,\n                                    D + (tx - 256) % 8 * 8 + v,\n                                ]\n\n                    T.cp_async_barrier_noinc(bar_k_1_ready[0])\n\n    return main\n\n\n@tilelang.jit(\n    out_idx=[-2, -1],\n    pass_configs={\n        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,\n        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,\n    },\n)\ndef sparse_mla_fwd_decode_partial(\n    heads,\n    dim,\n    tail_dim,\n    topk,\n    *,\n    kv_group=1,\n    sm_scale=None,\n    is_causal=True,\n    block_I=64,\n    threads=256,\n):\n    \"\"\"\n    grid: (seq_len * REPLICATE_H, top_k_blocks).\n    Each block does one topk block, writes partial_o, partial_lse.\n    \"\"\"\n\n    assert is_causal == True, \"non-causal is not supported\"\n    assert kv_group == 1\n    assert topk % block_I == 0\n\n    # log2(e) = 1.44269504\n    if sm_scale is None:\n        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504\n    else:\n        sm_scale = sm_scale * 1.44269504\n\n    batch = 1\n    seq_len = T.dynamic(\"seq_len\")\n    seq_len_kv = T.dynamic(\"seq_len_kv\")\n\n    head_kv = heads // kv_group\n    padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)\n    REPLICATE_H = (head_kv // 64) if head_kv > 64 else 1\n    H_per_block = padded_H if REPLICATE_H == 1 else 64\n    BI = block_I\n    NI = topk // block_I\n    D = dim\n    D_tail = tail_dim\n\n    q_shape = [batch, seq_len, heads, dim + tail_dim]\n    kv_shape = [batch, seq_len_kv, 
kv_group, dim + tail_dim]\n    indices_shape = [batch, seq_len, kv_group, topk]\n    partial_o_shape = [batch, seq_len, NI, heads, dim]\n    partial_lse_shape = [batch, seq_len, NI, heads]\n    indices_dtype = T.int32\n    dtype = T.bfloat16\n    accum_dtype = T.float32\n\n    @T.prim_func\n    def main(\n        Q: T.Tensor(q_shape, dtype),\n        KV: T.Tensor(kv_shape, dtype),\n        Indices: T.Tensor(indices_shape, indices_dtype),\n        Partial_O: T.Tensor(partial_o_shape, dtype),\n        Partial_Lse: T.Tensor(partial_lse_shape, accum_dtype),\n    ):\n        with T.Kernel(seq_len * REPLICATE_H, NI, threads=threads) as (bx, by):\n            Q_shared = T.alloc_shared([H_per_block, D], dtype)\n            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)\n            KV_shared = T.alloc_shared([BI, D], dtype)\n            K_tail_shared = T.alloc_shared([BI, D_tail], dtype)\n            mask = T.alloc_fragment([BI], T.bool)\n\n            acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)\n            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)\n            S_shared = T.alloc_shared([H_per_block, BI], dtype)\n            sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)\n            m_i = T.alloc_fragment([H_per_block], accum_dtype)\n\n            T.fill(acc_o, 0)\n\n            b_i, g_i = 0, 0\n            s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)\n            topk_block_i = by\n            q_i = s_i\n\n            H0 = 0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64\n            H1 = H0 + H_per_block\n\n            T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared)\n            T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)\n\n            for bi_i in T.Parallel(BI):\n                mask[bi_i] = Indices[b_i, s_i, g_i, topk_block_i * BI + bi_i] >= 0\n            for bi_i, d_i in T.Parallel(BI, D):\n                KV_shared[bi_i, d_i] = KV[\n                    b_i, Indices[b_i, s_i, g_i, topk_block_i * BI + bi_i], 
g_i, d_i\n                ]\n            for bi_i, d_i in T.Parallel(BI, D_tail):\n                K_tail_shared[bi_i, d_i] = KV[\n                    b_i, Indices[b_i, s_i, g_i, topk_block_i * BI + bi_i], g_i, D + d_i\n                ]\n            for h_i, bi_i in T.Parallel(H_per_block, BI):\n                acc_s[h_i, bi_i] = T.if_then_else(\n                    mask[bi_i], 0, -T.infinity(acc_s.dtype)\n                )\n            T.gemm(\n                Q_shared,\n                KV_shared,\n                acc_s,\n                transpose_B=True,\n                policy=T.GemmWarpPolicy.FullCol,\n            )\n            T.gemm(\n                Q_tail_shared,\n                K_tail_shared,\n                acc_s,\n                transpose_B=True,\n                policy=T.GemmWarpPolicy.FullCol,\n            )\n\n            T.reduce_max(acc_s, m_i, dim=1, clear=True)\n            for h_i in T.Parallel(H_per_block):\n                m_i[h_i] = T.max(m_i[h_i], -(2**30))\n            for h_i, bi_i in T.Parallel(H_per_block, BI):\n                acc_s[h_i, bi_i] = T.exp2(\n                    acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale\n                )\n\n            T.reduce_sum(acc_s, sumexp_i, dim=1)\n            T.copy(acc_s, S_shared)\n            T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)\n\n            # sumexp_i==0 (all masked), divide by 1 to get 0 and avoid nan\n            for h_i, d_i in T.Parallel(H_per_block, D):\n                acc_o[h_i, d_i] = acc_o[h_i, d_i] / T.if_then_else(\n                    sumexp_i[h_i] == 0.0, 1.0, sumexp_i[h_i]\n                )\n            # sumexp_i==0 (all masked), use large negative so combine ignores this split\n            for h_i in T.Parallel(H_per_block):\n                sumexp_i[h_i] = T.if_then_else(\n                    sumexp_i[h_i] == 0.0,\n                    -(2**30),\n                    T.log2(sumexp_i[h_i]) + m_i[h_i] * sm_scale,\n                )\n\n  
          T.copy(acc_o, Partial_O[b_i, s_i, topk_block_i, H0:H1, :])\n            T.copy(sumexp_i, Partial_Lse[b_i, s_i, topk_block_i, H0:H1])\n\n    return main\n\n\n@tilelang.jit(\n    out_idx=[-1],\n    pass_configs={\n        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,\n        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,\n    },\n)\ndef sparse_mla_fwd_decode_combine(\n    heads,\n    dim,\n    topk,\n    head_per_block,\n    *,\n    block_I=64,\n    threads=256,\n):\n    \"\"\"\n    grid: (seq_len * REPLICATE_H). batch=1, kv_group=1.\n    Each block does one tile of heads (e.g. 4 or 8 for decode).\n    \"\"\"\n\n    assert heads % head_per_block == 0, f\"head_per_block must divide heads\"\n\n    batch = 1\n    seq_len = T.dynamic(\"seq_len\")\n\n    NI = topk // block_I\n    H_per_block = head_per_block\n    REPLICATE_H = heads // H_per_block\n\n    partial_o_shape = [batch, seq_len, NI, heads, dim]\n    partial_lse_shape = [batch, seq_len, NI, heads]\n    o_shape = [batch, seq_len, heads, dim]\n    dtype = T.bfloat16\n    accum_dtype = T.float32\n\n    @T.prim_func\n    def main(\n        Partial_O: T.Tensor(partial_o_shape, dtype),\n        Partial_Lse: T.Tensor(partial_lse_shape, accum_dtype),\n        Output: T.Tensor(o_shape, dtype),\n    ):\n        with T.Kernel(seq_len * REPLICATE_H, threads=threads) as (bx,):\n            shared_lse = T.alloc_shared([NI, H_per_block], accum_dtype)\n\n            lse_max = T.alloc_fragment([H_per_block], accum_dtype)\n            lse_sum = T.alloc_fragment([H_per_block], accum_dtype)\n            scale = T.alloc_fragment([H_per_block, NI], accum_dtype)\n            acc_o = T.alloc_fragment([H_per_block, dim], accum_dtype)\n\n            b_i = 0\n            s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)\n            H0 = 0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * H_per_block\n            H1 = H0 + H_per_block\n\n            for k in T.serial(NI):\n                T.copy(Partial_Lse[b_i, 
s_i, k, H0:H1], shared_lse[k, :])\n\n            T.fill(lse_max, -(2**30))\n            for k in T.serial(NI):\n                for h_i in T.Parallel(H_per_block):\n                    lse_max[h_i] = T.max(lse_max[h_i], shared_lse[k, h_i])\n            T.fill(lse_sum, 0)\n            for k in T.serial(NI):\n                for h_i in T.Parallel(H_per_block):\n                    lse_sum[h_i] = lse_sum[h_i] + T.exp2(\n                        shared_lse[k, h_i] - lse_max[h_i]\n                    )\n            for k in T.serial(NI):\n                for h_i in T.Parallel(H_per_block):\n                    scale[h_i, k] = T.exp2(\n                        shared_lse[k, h_i] - lse_max[h_i] - T.log2(lse_sum[h_i])\n                    )\n\n            T.fill(acc_o, 0)\n            for k in T.serial(NI):\n                for h_i, d_i in T.Parallel(H_per_block, dim):\n                    acc_o[h_i, d_i] = acc_o[h_i, d_i] + scale[h_i, k] * Partial_O[\n                        b_i, s_i, k, H0 + h_i, d_i\n                    ].astype(accum_dtype)\n\n            T.copy(acc_o, Output[b_i, s_i, H0:H1, :])\n\n    return main\n\n\ndef tilelang_sparse_fwd(\n    q: torch.Tensor,\n    kv: torch.Tensor,\n    indices: torch.Tensor,\n    sm_scale: float,\n    d_v: int = 512,\n) -> torch.Tensor:\n    assert q.dim() == 3 and kv.dim() == 3 and indices.dim() == 3\n    num_heads = q.shape[1]\n    dim = q.shape[2]\n    tail_dim = dim - d_v\n    topk = indices.shape[-1]\n    assert topk == 2048\n    if _is_hip:\n        if _is_gfx95_supported:\n            # decode kernel\n            if q.shape[0] <= 64:\n                kernel_partial = sparse_mla_fwd_decode_partial(\n                    num_heads,\n                    d_v,\n                    tail_dim,\n                    topk,\n                    sm_scale=sm_scale,\n                    block_I=64,\n                    threads=256,\n                )\n                kernel_combine = sparse_mla_fwd_decode_combine(\n                    
num_heads, d_v, topk, head_per_block=4, block_I=64, threads=256\n                )\n                partial_o, partial_lse = kernel_partial(\n                    q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0)\n                )\n                out = kernel_combine(partial_o, partial_lse)\n                return out\n\n            # prefill kernel\n            kernel = sparse_attention_fwd_kernel_v1(\n                num_heads, d_v, tail_dim, topk, sm_scale=sm_scale, num_stages=1\n            )\n        else:  # reduce LDS usage on gfx942 target\n            kernel = sparse_attention_fwd_kernel_v1(\n                num_heads,\n                d_v,\n                tail_dim,\n                topk,\n                sm_scale=sm_scale,\n                block_I=32,\n                num_stages=1,\n                threads=128,\n            )\n    else:\n        kernel = sparse_attention_fwd_kernel_v2(\n            num_heads, d_v, tail_dim, topk, sm_scale=sm_scale\n        )\n    return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0))  # type: ignore\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/nsa/transform_index.py",
    "content": "from typing import List, Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\ndef transform_index_page_table_prefill(**kwargs):\n    return transform_index_page_table_prefill_ref(**kwargs)\n\n\ndef transform_index_page_table_decode(**kwargs):\n    return transform_index_page_table_decode_ref(**kwargs)\n\n\n@triton.jit\ndef transform_index_page_table_decode_kernel(\n    page_table_ptr: torch.Tensor,\n    topk_indices_ptr: torch.Tensor,\n    result_ptr: torch.Tensor,\n    page_size: tl.constexpr,\n    max_seqlen_k: tl.constexpr,\n):\n    TOPK: tl.constexpr = 2048\n    req_id = tl.program_id(0)\n    page_table_ptr = page_table_ptr + req_id * max_seqlen_k\n    topk_indices_ptr = topk_indices_ptr + req_id * TOPK\n    result_ptr = result_ptr + req_id * TOPK\n\n    offset = tl.arange(0, TOPK)  # topk should be 2048\n    loaded_topk_indices = tl.load(topk_indices_ptr + offset)\n    mask = loaded_topk_indices >= 0\n    loaded_kv_indices = tl.load(page_table_ptr + loaded_topk_indices, mask=mask)\n    tl.store(result_ptr + offset, loaded_kv_indices, mask=mask)\n    tl.store(result_ptr + offset, -1, mask=~mask)\n\n\ndef transform_index_page_table_decode_fast(\n    page_table: torch.Tensor,\n    topk_indices: torch.Tensor,\n    result: Optional[torch.Tensor] = None,\n    page_size: int = 1,\n) -> torch.Tensor:\n    \"\"\"\n    Transform the page table according to topk indices for sparse topk attention.\n    Args:\n        page_table: [qo_len, max_seqlen_k], the original page table\n        topk_indices: [qo_len, topk], the topk indices for each query position\n    Returns:\n        transformed_page_table: [qo_len, topk], the transformed page table\n        For out-of-bound indices in topk_indices, this should be filled with -1.\n    \"\"\"\n    assert page_size == 1\n    assert page_table.shape[0] == topk_indices.shape[0]\n    assert topk_indices.shape[1] == 2048\n    qo_len = topk_indices.shape[0]\n    # The kernel advances page_table_ptr by req_id * max_seqlen_k, so pass the row\n    # stride: it equals shape[1] for a contiguous table and is 0 for rows broadcast\n    # with expand(), as in transform_index_page_table_prefill_fast.\n    max_seqlen_k = page_table.stride(0)\n    if result is None:\n        result = torch.empty_like(topk_indices, dtype=torch.int32)\n    # Launch triton kernel\n    grid = (qo_len,)\n    transform_index_page_table_decode_kernel[grid](\n        page_table,\n        topk_indices,\n        result,\n        page_size,\n        max_seqlen_k=max_seqlen_k,\n    )\n    return result\n\n\ndef transform_index_page_table_prefill_fast(\n    page_table: torch.Tensor,\n    topk_indices: torch.Tensor,\n    extend_lens_cpu: List[int],\n    page_size: int = 1,\n) -> torch.Tensor:\n    # TODO(baizhou): can be implemented with another triton kernel\n    assert page_size == 1\n    result = torch.empty_like(topk_indices, dtype=torch.int32)\n    assert len(extend_lens_cpu) == page_table.shape[0]\n    offset = 0\n    for i, l in enumerate(extend_lens_cpu):\n        transform_index_page_table_decode_fast(\n            page_table[i].unsqueeze(0).expand(l, -1),\n            topk_indices[offset : offset + l],\n            result=result[offset : offset + l],\n        )\n        offset += l\n    assert offset == topk_indices.shape[0]\n    return result\n\n\ndef transform_index_page_table_decode_ref(\n    page_table: torch.Tensor,\n    topk_indices: torch.Tensor,\n    result: Optional[torch.Tensor] = None,\n    page_size: int = 1,\n) -> torch.Tensor:\n    assert page_size == 1\n    assert page_table.shape[0] == topk_indices.shape[0]\n    if result is None:\n        result = torch.empty_like(topk_indices, dtype=torch.int32)\n    assert result.shape == topk_indices.shape\n    torch.gather(\n        page_table.to(result.dtype),\n        dim=1,\n        index=topk_indices.clamp(min=0),\n        out=result,\n    )\n    result[topk_indices < 0] = -1\n    return result\n\n\ndef transform_index_page_table_prefill_ref(\n    page_table: torch.Tensor,\n    topk_indices: torch.Tensor,\n    extend_lens_cpu: List[int],\n    page_size: int = 1,\n) -> torch.Tensor:\n    assert page_size == 1\n    result = torch.empty_like(topk_indices, dtype=torch.int32)\n    assert len(extend_lens_cpu) == page_table.shape[0]\n    offset = 0\n    for i, l in enumerate(extend_lens_cpu):\n        transform_index_page_table_decode_ref(\n            page_table[i].unsqueeze(0).expand(l, -1),\n            topk_indices[offset : offset + l],\n            result=result[offset : offset + l],\n        )\n        offset += l\n    assert offset == topk_indices.shape[0]\n    return result\n\n\nif __name__ == \"__main__\":\n    bs, topk, max_seqlen = 10, 2048, 3000\n    page_table = torch.randint(0, 100, (bs, max_seqlen), device=\"cuda\")\n    topk_indices = torch.full((bs, topk), -1, device=\"cuda\")\n    topk_indices[:, :1600] = torch.arange(1600).unsqueeze(0).repeat(bs, 1)\n    ref_result = transform_index_page_table_decode_ref(page_table, topk_indices)\n    result = transform_index_page_table_decode_fast(page_table, topk_indices)\n    assert torch.all(result == ref_result)\n    print(\"Passed\")\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/nsa/triton_kernel.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n# Triton implementation\n@triton.jit\ndef _act_quant_kernel(\n    X_ptr,\n    Y_ptr,\n    S_ptr,\n    M,\n    N,\n    group_size: tl.constexpr,\n    round_scale: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    \"\"\"\n    Triton kernel for activation quantization.\n\n    Each block processes BLOCK_M rows and group_size columns.\n    \"\"\"\n    # Get block IDs\n    pid_m = tl.program_id(0)\n    pid_n = tl.program_id(1)\n\n    # FP8 constants\n    fp8_min = -448.0\n    fp8_max = 448.0\n    fp8_max_inv = 1.0 / fp8_max\n\n    # Calculate row and column offsets\n    row_start = pid_m * BLOCK_M\n    col_start = pid_n * group_size\n\n    # Create offset arrays\n    rows = row_start + tl.arange(0, BLOCK_M)\n    cols = col_start + tl.arange(0, BLOCK_N)\n\n    # Mask for valid rows and columns\n    row_mask = rows < M\n    col_mask = cols < N\n    mask = row_mask[:, None] & col_mask[None, :]\n\n    # Load input data\n    x_ptrs = X_ptr + rows[:, None] * N + cols[None, :]\n    x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)\n\n    # Compute absolute max along columns (group_size dimension) for each row\n    x_abs = tl.abs(x)\n    amax = tl.max(x_abs, axis=1)  # Shape: (BLOCK_M,)\n\n    # Clamp amax to avoid division by zero\n    amax = tl.maximum(amax, 1e-4)\n\n    # Compute scale\n    if round_scale:\n        # Fast round scale using bit manipulation approximation\n        # This is a simplified version - the exact bit manipulation is harder in Triton\n        # Using log2 + ceil + pow2 as approximation\n        log_val = tl.log2(amax * fp8_max_inv)\n        log_ceil = tl.ceil(log_val)\n        scale = tl.exp2(log_ceil)\n    else:\n        scale = amax * fp8_max_inv\n\n    # Quantize: y = clamp(x / scale, fp8_min, fp8_max)\n    scale_broadcast = scale[:, None]\n    y = x / scale_broadcast\n    y = 
tl.minimum(tl.maximum(y, fp8_min), fp8_max)\n\n    # Store quantized output\n    y_ptrs = Y_ptr + rows[:, None] * N + cols[None, :]\n    tl.store(y_ptrs, y, mask=mask)\n\n    # Store scales\n    s_cols = pid_n\n    s_ptrs = S_ptr + rows * (N // group_size) + s_cols\n    s_mask = row_mask\n    tl.store(s_ptrs, scale, mask=s_mask)\n\n\ndef act_quant(\n    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Quantizes the input tensor `x` using block-wise quantization with Triton.\n\n    Args:\n        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.\n        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.\n        scale_fmt (Optional[str], optional): The format of the scale. Default is None.\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:\n            - The quantized tensor with dtype `torch.float8_e4m3fn`.\n            - A tensor of scaling factors with dtype `torch.float32`.\n    \"\"\"\n    assert x.is_contiguous(), \"Input tensor must be contiguous\"\n    assert (\n        x.size(-1) % block_size == 0\n    ), f\"Last dimension size must be divisible by block_size (block_size={block_size})\"\n\n    # Flatten all dims except last\n    N = x.size(-1)\n    x_flat = x.view(-1, N)\n    M = x_flat.size(0)\n\n    # Allocate output tensors\n    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)\n    y_flat = y.view(-1, N)\n    s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)\n    s_flat = s.view(-1, N // block_size)\n\n    # Launch kernel\n    BLOCK_M = 32\n    BLOCK_N = block_size\n    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, block_size))\n    round_scale = scale_fmt is not None\n\n    _act_quant_kernel[grid](\n        x_flat,\n        y_flat,\n        s_flat,\n        M,\n        N,\n        
group_size=block_size,\n        round_scale=round_scale,\n        BLOCK_M=BLOCK_M,\n        BLOCK_N=BLOCK_N,\n        num_stages=0 if round_scale else 2,\n    )\n\n    return y, s\n\n\n@triton.jit\ndef _get_valid_kv_indices_kernel(\n    page_table_ptr,  # [bs, topk]\n    kv_indptr_ptr,  # [bs + 1]\n    kv_indices_ptr,  # [bs * topk] output buffer\n    bs: tl.constexpr,\n    topk: tl.constexpr,\n):\n    \"\"\"\n    Extract valid indices (non -1) from page_table into kv_indices.\n    Each program handles one batch.\n    \"\"\"\n    batch_id = tl.program_id(0)\n\n    # Get the start position for this batch in kv_indices\n    dst_start = tl.load(kv_indptr_ptr + batch_id)\n\n    # Load all topk indices for this batch\n    src_offset = batch_id * topk\n    offsets = tl.arange(0, topk)\n    indices = tl.load(page_table_ptr + src_offset + offsets)\n\n    # Count valid indices and compact them\n    mask = indices != -1\n\n    # Use prefix sum to compute destination positions for valid elements\n    # For each position, count how many valid elements are before it\n    prefix_sum = tl.cumsum(mask.to(tl.int32), axis=0) - 1\n\n    # Store valid indices to their compacted positions\n    dst_positions = dst_start + prefix_sum\n    tl.store(kv_indices_ptr + dst_positions, indices, mask=mask)\n\n\ndef get_valid_kv_indices(\n    page_table_1: torch.Tensor,\n    kv_indptr: torch.Tensor,\n    kv_indices: torch.Tensor,\n    bs: int,\n):\n    \"\"\"\n    Extract valid indices from page_table_1 into kv_indices buffer.\n\n    Args:\n        page_table_1: [bs, topk] page table with -1 as invalid\n        kv_indptr: [bs + 1] cumulative count of valid indices per batch\n        kv_indices: [bs * topk] pre-allocated output buffer\n        bs: batch size\n    \"\"\"\n    topk = page_table_1.shape[1]\n    grid = (bs,)\n    _get_valid_kv_indices_kernel[grid](\n        page_table_1,\n        kv_indptr,\n        kv_indices,\n        bs,\n        topk,\n    )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/nsa/utils.py",
    "content": "# temp NSA debugging environ\nfrom dataclasses import dataclass\nfrom itertools import accumulate\nfrom typing import TYPE_CHECKING, List, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.distributed.device_communicators.pynccl_allocator import (\n    use_symmetric_memory,\n)\nfrom sglang.srt.layers.dp_attention import (\n    DpPaddingMode,\n    attn_cp_all_gather_into_tensor,\n    get_attention_cp_group,\n    get_attention_cp_rank,\n    get_attention_cp_size,\n    get_attention_dp_rank,\n    is_allocation_symmetric,\n)\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.utils.common import ceil_align, ceil_div\n\nif TYPE_CHECKING:\n    from sglang.srt.model_executor.forward_batch_info import ForwardBatch\n\n\ndef compute_nsa_seqlens(original_seq_lens, nsa_index_topk: int):\n    return original_seq_lens.clamp(max=nsa_index_topk)\n\n\ndef is_nsa_enable_prefill_cp():\n    return get_global_server_args().enable_nsa_prefill_context_parallel\n\n\ndef is_nsa_prefill_cp_in_seq_split():\n    return (\n        is_nsa_enable_prefill_cp()\n        and get_global_server_args().nsa_prefill_cp_mode == \"in-seq-split\"\n    )\n\n\ndef is_nsa_prefill_cp_round_robin_split():\n    return (\n        is_nsa_enable_prefill_cp()\n        and get_global_server_args().nsa_prefill_cp_mode == \"round-robin-split\"\n    )\n\n\ndef can_nsa_prefill_cp_round_robin_split(forward_batch: \"ForwardBatch\"):\n    if not forward_batch.forward_mode.is_context_parallel_extend():\n        return False\n    cp_size = get_attention_cp_size()\n    seq_len = sum(forward_batch.extend_seq_lens_cpu)\n    return (\n        is_nsa_prefill_cp_round_robin_split()\n        and seq_len > 0\n        and seq_len >= cp_size\n        and cp_size > 1\n    )\n\n\ndef nsa_cp_round_robin_split_data(input_: Union[torch.Tensor, List]):\n    \"\"\"\n    # for round-robin-split, split the tokens evenly according 
to the rule of token_idx % cp_size.\n    |   +-----------before split------------+|\n    | token0, token1, token2, token3, token4, token5, token6, token7, ...\n    |\n    |   +--------------result-------------------+\n    | dp_atten_tp0: token0, token4, token8, token12, token16, ... |\n    | dp_atten_tp1: token1, token5, token9, token13, token17, ... |\n    | dp_atten_tp2: token2, token6, token10, token14, token18, ... |\n    | dp_atten_tp3: token3, token7, token11, token15, token19, ... |\n    |   +-------------------------+\n    \"\"\"\n    cp_size = get_attention_cp_size()\n    cp_rank = get_attention_cp_rank()\n    if isinstance(input_, (tuple, list)):\n        indices = range(cp_rank, len(input_), cp_size)\n        return input_[indices]\n\n    tokens = len(input_)\n    if tokens % cp_size != 0:\n        cur_len = tokens // cp_size + (tokens % cp_size > cp_rank)\n        if cur_len == 0:\n            return input_.new_empty(0, *input_.shape[1:])\n        indices = torch.arange(cp_rank, tokens, cp_size, device=input_.device)\n        return input_[indices]\n\n    # for torch device tensor\n    return input_.view(-1, cp_size, *input_.shape[1:])[:, cp_rank].contiguous()\n\n\ndef cal_padded_tokens(forward_batch: \"ForwardBatch\"):\n    # Consistent with the padding calculation logic in ForwardBatch.prepare_mlp_sync_batch,\n    # calculate the actual token length after padding when attn_tp_size > 1 or in the MAX_LEN padding mode.\n    global_num_tokens = forward_batch.global_num_tokens_cpu.copy()\n    sync_group_size = len(global_num_tokens)\n    attn_cp_size = get_attention_cp_size()\n    for i in range(sync_group_size):\n        global_num_tokens[i] = ceil_align(global_num_tokens[i], attn_cp_size)\n    dp_padding_mode = DpPaddingMode.get_dp_padding_mode(\n        forward_batch.is_extend_in_batch, global_num_tokens\n    )\n    if dp_padding_mode.is_max_len():\n        tokens = max(global_num_tokens)\n    elif len(global_num_tokens) > 1:\n        tokens = 
global_num_tokens[get_attention_dp_rank()]\n    else:\n        tokens = global_num_tokens[0]\n    if can_nsa_prefill_cp_round_robin_split(forward_batch):\n        tokens = ceil_div(tokens, attn_cp_size)\n    return tokens\n\n\ndef pad_nsa_cache_seqlens(forward_batch: \"ForwardBatch\", nsa_cache_seqlens):\n    attn_cp_size = get_attention_cp_size()\n    needs_cp_pad = attn_cp_size > 1 and can_nsa_prefill_cp_round_robin_split(\n        forward_batch\n    )\n    needs_dp_pad = forward_batch.global_num_tokens_cpu is not None\n    if not needs_cp_pad and not needs_dp_pad:\n        return nsa_cache_seqlens\n    tokens = cal_padded_tokens(forward_batch)\n    pad_len = tokens - nsa_cache_seqlens.shape[0]\n    if pad_len > 0:\n        nsa_cache_seqlens = torch.cat(\n            [\n                nsa_cache_seqlens,\n                nsa_cache_seqlens.new_zeros(pad_len, *nsa_cache_seqlens.shape[1:]),\n            ]\n        )\n    return nsa_cache_seqlens\n\n\n@dataclass\nclass NSAContextParallelMetadata:\n\n    split_list: List[int] = None\n    max_rank_len: List[int] = None\n    zigzag_index: List[int] = None\n    per_rank_actual_token: List[int] = None\n    reverse_split_len: List[int] = None\n    cp_reverse_index: List[int] = None\n    kv_len_prev: int = -1\n    kv_len_next: int = -1\n    actual_seq_q_prev: int = -1\n    actual_seq_q_next: int = -1\n    kv_len_prev_tensor: torch.Tensor = None\n    kv_len_next_tensor: torch.Tensor = None\n    actual_seq_q_prev_tensor: torch.Tensor = None\n    actual_seq_q_next_tensor: torch.Tensor = None\n    total_seq_lens: torch.Tensor = None\n\n\ndef can_cp_split(seq_len: int, cp_size: int, use_nsa: bool, forward_batch):\n    if is_nsa_prefill_cp_round_robin_split():\n        cur_cp_seq_len = seq_len // cp_size\n        assert (\n            seq_len % cp_size == 0\n        ), f\"seq_len {seq_len} is not divisible by cp_size {cp_size} when nsa_prefill_cp_mode is round-robin-split\"\n    else:\n        # TODO current just support prefill 
batch=1 and len(input_ids) > self.cp_size * 2\n        # Note: to balance the attention load across ranks, the sequence is split into\n        # cp_size * 2 chunks and recombined (two chunks per rank).\n        cur_cp_seq_len = seq_len // (cp_size * 2)\n    if (\n        cur_cp_seq_len != 0\n        and cp_size > 1\n        and use_nsa\n        and forward_batch.forward_mode.is_context_parallel_extend()\n        and is_nsa_enable_prefill_cp()\n        and sum(forward_batch.extend_seq_lens_cpu) >= cp_size\n    ):\n        return True\n    else:\n        return False\n\n\ndef cp_split_and_rebuild_data(forward_batch, input_: torch.Tensor):\n    if is_nsa_prefill_cp_round_robin_split():\n        cp_size = get_attention_cp_size()\n        assert (\n            input_.shape[0] % cp_size == 0\n        ), f\"Expected input_.shape[0] to be divisible by cp_size, but got input shape {input_.shape}, cp size {cp_size}\"\n        return nsa_cp_round_robin_split_data(input_)\n\n    input_list = list(\n        torch.split(input_, forward_batch.nsa_cp_metadata.split_list, dim=0)\n    )\n    result = torch.cat(\n        [input_list[i] for i in forward_batch.nsa_cp_metadata.zigzag_index], dim=0\n    ).view(-1, input_.shape[-1])\n    return result\n\n\ndef cp_split_and_rebuild_position(forward_batch, positions: torch.Tensor):\n    if is_nsa_prefill_cp_round_robin_split():\n        cp_size = get_attention_cp_size()\n        assert positions.shape[0] % cp_size == 0, (\n            f\"Expected positions.shape[0] to be divisible by cp_size, but got positions shape {positions.shape}, \"\n            f\"cp size {cp_size}\"\n        )\n        return nsa_cp_round_robin_split_data(positions)\n\n    position_id_list = list(\n        torch.split(positions, forward_batch.nsa_cp_metadata.split_list, dim=-1)\n    )\n    positions = torch.cat(\n        [position_id_list[i] for i in forward_batch.nsa_cp_metadata.zigzag_index],\n        dim=-1,\n    )\n    return positions\n\n\n@triton.jit\ndef 
nsa_cp_round_robin_split_q_seqs_kernel(\n    in_seqs_ptr,\n    out_seqs_ptr,\n    bs_idx_ptr,\n    tokens: tl.constexpr,\n    cp_size: tl.constexpr,\n    cp_rank: tl.constexpr,\n):\n    extra_seq = 0\n    bs_idx = 0\n    for bs in range(tokens):\n        cur_len = tl.load(in_seqs_ptr + bs)\n        cur_len += extra_seq\n        cur_seq = cur_len // cp_size + (cur_len % cp_size > cp_rank)\n        if cur_seq > 0:\n            tl.store(bs_idx_ptr + bs_idx, bs)\n            tl.store(out_seqs_ptr + bs_idx, cur_seq)\n            bs_idx += 1\n        extra_seq = cur_len - cur_seq * cp_size\n\n\ndef nsa_cp_round_robin_split_q_seqs_cpu(extend_seqs):\n    cp_size = get_attention_cp_size()\n    cp_rank = get_attention_cp_rank()\n    extra_seq = 0\n    q_seqs = []\n    for bs, cur_len in enumerate(extend_seqs):\n        cur_len += extra_seq\n        cur_seq = cur_len // cp_size + int(cur_len % cp_size > cp_rank)\n        q_seqs.append(cur_seq)\n        extra_seq = cur_len - cur_seq * cp_size\n    bs_idx = list([i for i, x in enumerate(q_seqs) if x > 0])\n    q_seqs = [q_len for q_len in q_seqs if q_len > 0]\n    return q_seqs, bs_idx\n\n\ndef nsa_cp_round_robin_split_q_seqs(\n    extend_seqs_cpu, extend_seqs\n) -> Tuple[List, torch.Tensor, List, torch.Tensor]:\n    \"\"\"\n    round-robin-split distributes tokens across ranks based on token_idx % cp_size.\n\n    Return:\n    ret_q_lens_cpu(List) and ret_q_lens(torch.Tensor): the partitioned length (excluding zeros) on the current cp rank\n        for each sequence after distribution across cp ranks.\n    bs_idx_cpu(List) and bs_idx(torch.Tensor): marks which sequences are ultimately selected,\n        i.e., those with a partitioned length greater than zero.\n    \"\"\"\n    cp_size = get_attention_cp_size()\n    cp_rank = get_attention_cp_rank()\n    # len(ret_q_lens_cpu) == len(bs_idx_cpu)\n    ret_q_lens_cpu, bs_idx_cpu = nsa_cp_round_robin_split_q_seqs_cpu(extend_seqs_cpu)\n    ret_q_lens = torch.empty(\n        
(len(bs_idx_cpu),), device=extend_seqs.device, dtype=extend_seqs.dtype\n    )\n    bs_idx = torch.empty(\n        (len(bs_idx_cpu),), device=extend_seqs.device, dtype=torch.int32\n    )\n    grid = (1,)\n    nsa_cp_round_robin_split_q_seqs_kernel[grid](\n        extend_seqs, ret_q_lens, bs_idx, len(extend_seqs), cp_size, cp_rank\n    )\n    return ret_q_lens_cpu, ret_q_lens, bs_idx_cpu, bs_idx\n\n\ndef nsa_use_prefill_cp(forward_batch, nsa_enable_prefill_cp=None):\n    if nsa_enable_prefill_cp is None:\n        nsa_enable_prefill_cp = is_nsa_enable_prefill_cp()\n    if (\n        forward_batch.nsa_cp_metadata is not None\n        and nsa_enable_prefill_cp\n        and forward_batch.forward_mode.is_context_parallel_extend()\n    ):\n        return True\n    else:\n        return False\n\n\ndef cp_attn_tp_all_gather_reorganazied_into_tensor(\n    input_: torch.Tensor, total_len, attn_tp_size, forward_batch, stream_op\n):\n    \"\"\"\n    Allgather communication for context_parallel(kv_cache, index_k, hidden_states).\n    This implementation mainly consists of three parts:\n    Step 1, padding the input shape to unify the shape for allgather communication (the shape must be the same).\n    Step 2, allgather communication(async).\n    Step 3, removing the padding and reassembling the data according to the actual tokens.\n    \"\"\"\n    # step1\n    max_len = (total_len + attn_tp_size - 1) // attn_tp_size\n    pad_size = max_len - input_.shape[0]\n    if pad_size > 0:\n        input_ = F.pad(input_, (0, 0, 0, pad_size), mode=\"constant\", value=0)\n    with use_symmetric_memory(\n        get_attention_cp_group(), disabled=not is_allocation_symmetric()\n    ):\n        input_tensor_all = torch.empty(\n            max_len * attn_tp_size,\n            input_.shape[1],\n            device=input_.device,\n            dtype=input_.dtype,\n        )\n    # step2\n    get_attention_cp_group().cp_all_gather_into_tensor_async(\n        input_tensor_all, input_, stream_op\n    
)\n    # step3\n    outputs_list_max = list(\n        torch.split(input_tensor_all, forward_batch.nsa_cp_metadata.max_rank_len, dim=0)\n    )\n    outputs = torch.cat(\n        [\n            outputs_list_max[index][:per_rank_len]\n            for index, per_rank_len in enumerate(\n                forward_batch.nsa_cp_metadata.per_rank_actual_token\n            )\n        ],\n        dim=0,\n    )\n    return outputs\n\n\ndef cp_all_gather_rerange_output(input_tensor, cp_size, forward_batch, stream):\n    \"\"\"\n    # for in-seq-split\n    |   +-----------before allgather------------+|\n    |   | dp_atten_tp0: block0, block7 |\n    |   | dp_atten_tp1: block1, block6 |\n    |   | dp_atten_tp2: block2, block5 |\n    |   | dp_atten_tp3: block3, block4 |\n    |\n    |   +----------before rerange---------------+|\n    | block0 | block7 | block1 | block6 | block2 | block5 | block3 | block4 |\n    |\n    |   +--------------result-------------------+\n    | block0 | block1 | block2 | block3 | block4 | block5 | block6 | block7 |\n    |   +-------------------------+\n\n    # for round-robin-split\n    |   +-----------before allgather------------+|\n    | dp_atten_tp0: token0, token4, token8, token12, token16, ... |\n    | dp_atten_tp1: token1, token5, token9, token13, token17, ... |\n    | dp_atten_tp2: token2, token6, token10, token14, token18, ... |\n    | dp_atten_tp3: token3, token7, token11, token15, token19, ... 
|\n    |\n    |   +--------------result-------------------+\n    | token0, token1, token2, token3, token4, token5, token6, token7, ...\n    |   +-------------------------+\n    \"\"\"\n    if is_nsa_prefill_cp_round_robin_split():\n        with use_symmetric_memory(\n            get_attention_cp_group(), disabled=not is_allocation_symmetric()\n        ):\n            output_tensor = input_tensor.new_empty(\n                (input_tensor.shape[0] * cp_size, *input_tensor.shape[1:]),\n            )\n        attn_cp_all_gather_into_tensor(\n            output_tensor,\n            input_tensor,\n        )\n        out_shape = output_tensor.shape\n        output_tensor = (\n            output_tensor.view(cp_size, -1, *out_shape[1:])\n            .transpose(0, 1)\n            .reshape(out_shape)\n        )\n        return output_tensor\n\n    bs_seq_len, hidden_size = input_tensor.shape\n    output_tensor = cp_attn_tp_all_gather_reorganazied_into_tensor(\n        input_tensor,\n        forward_batch.nsa_cp_metadata.total_seq_lens,\n        cp_size,\n        forward_batch,\n        stream,\n    )\n    outputs_list = list(\n        torch.split(\n            output_tensor, forward_batch.nsa_cp_metadata.reverse_split_len, dim=0\n        )\n    )\n    output_tensor = torch.cat(\n        [outputs_list[i] for i in forward_batch.nsa_cp_metadata.cp_reverse_index], dim=0\n    )\n    output_tensor = output_tensor.view(-1, hidden_size)\n    return output_tensor\n\n\ndef calculate_cp_seq_idx(cp_chunks_len, seqs_len):\n    \"\"\"Used to obtain the index of the seq corresponding\n    to each cp block in the forwardbatch, and the starting\n    and ending positions of the corresponding seq in the cp block\"\"\"\n    j = 0\n    tuple_len = []  # Only keep this result list\n    cumulative = {}  # Used to track cumulative values for each index\n\n    for i in range(len(cp_chunks_len)):\n        current_dict = {}\n        current_tuples = []\n        c_val = cp_chunks_len[i]\n\n        while 
j < len(seqs_len):\n            s_val = seqs_len[j]\n            if s_val == c_val:\n                idx = j\n                current_dict[idx] = s_val\n                # Update cumulative value for this index\n                cumulative[idx] = cumulative.get(idx, 0) + s_val\n                j += 1\n                break\n            elif s_val > c_val:\n                idx = j\n                current_dict[idx] = c_val\n                # Update cumulative value for this index\n                cumulative[idx] = cumulative.get(idx, 0) + c_val\n                seqs_len[j] = s_val - c_val\n                break\n            else:  # s_val < c_val\n                idx = j\n                current_dict[idx] = s_val\n                # Update cumulative value for this index\n                cumulative[idx] = cumulative.get(idx, 0) + s_val\n                c_val -= s_val\n                j += 1\n\n        # Build tuple: (index, historical cumulative, historical+current)\n        for idx, val in current_dict.items():\n            # Subtract current value to get historical cumulative\n            prev_cum = cumulative.get(idx, 0) - val\n            current_cum = prev_cum + val\n            current_tuples.append((idx, prev_cum, current_cum))\n\n        tuple_len.append(current_tuples)\n    return tuple_len\n\n\ndef prepare_input_dp_with_cp_dsa(\n    kv_len,\n    cp_rank,\n    cp_size,\n    seqs_len,\n):\n    if is_nsa_prefill_cp_round_robin_split():\n        return True\n    \"\"\"prepare_input_dp_with_cp_dsa-zigzag index\n    Example (DP_ATTENT_TP == CP_SIZE == 4):\n    Description:\n    1. Start with a full-length request.\n    2. Split the request into multiple blocks (block0 to block7).\n    3. Rearrange these blocks to balance computational\n        load across different DP ranks.\n    4. 
Assign the rearranged blocks to different DP attention\n        ranks (dp_atten_tp0 to dp_atten_tp3).\n    +---------------------------------+\n    |        cp_split_tokens         |\n    +---------------------------------+\n    |                                 |\n    |   request_with_full_length     |\n    |             | split (cp_size * 2) |\n    |   +-------------------------+  |\n    |   | block0 | block1 | block2 | block3 | block4 | block5 | block6 | block7 |\n    |   +-------------------------+  |\n    |             | rearrange        |\n    |   +---------------------------------+\n    |   | block0 | block7 | block1 | block6 | block2 | block5 | block3 | block4 |\n    |   +---------------------------------+\n    |             |\n    |   +-------------------------+\n    |   | dp_atten_tp0: block0, block7 |\n    |   | dp_atten_tp1: block1, block6 |\n    |   | dp_atten_tp2: block2, block5 |\n    |   | dp_atten_tp3: block3, block4 |\n    |   +-------------------------+\n\n    Why zigzag rearrange?\n    - Attention calculations must follow causal attention principles.\n    - Simply slicing by rank order can lead to computational load imbalance:\n        * First rank may focus on fewer historical key-value tokens (less computation)\n        * Last rank may focus on more tokens (more computation)\n    - To mitigate uneven load, the input hidden_states need to be sliced into cp_size*2 blocks and rearranged.\n    \"\"\"\n    # only batch size 1 is supported\n    kv_len = torch.tensor(kv_len)\n    bs_per_cp_group = 1\n    kv_len_origin = kv_len\n    # get zigzag index\n    cp_segment_num = cp_size * 2\n    seq_per_batch = kv_len // cp_segment_num  # seq_len for each batch and segment\n    split_list = seq_per_batch.repeat_interleave(cp_segment_num).int().tolist()\n    remainder = kv_len % (cp_segment_num)\n    if remainder > 0:\n        split_list[:remainder] = [x + 1 for x in split_list[:remainder]]\n\n    seq_max_rank_len = (kv_len + cp_size - 1) // cp_size\n    max_rank_len = 
seq_max_rank_len.repeat_interleave(cp_size).int().tolist()\n    zigzag_index = list(\n        range(cp_rank, cp_rank + bs_per_cp_group * cp_segment_num, cp_segment_num)\n    ) + list(\n        range(\n            cp_segment_num - cp_rank - 1,\n            bs_per_cp_group * cp_segment_num,\n            cp_segment_num,\n        )\n    )\n\n    per_rank_actual_token = list(\n        split_list[i] + split_list[cp_size * 2 - i - 1] for i in range(cp_size)\n    )\n    reverse_split_len = [\n        element\n        for i in range(cp_size)\n        for element in (split_list[i], split_list[cp_size * 2 - i - 1])\n    ]\n    # get zigzag reverse index\n    cp_reverse_index = []\n    for batch_id in range(bs_per_cp_group):\n        cp_reverse_index.extend(\n            list(range(batch_id, cp_segment_num * bs_per_cp_group, 2 * bs_per_cp_group))\n            + list(\n                range(\n                    (cp_segment_num - 1) * bs_per_cp_group + batch_id,\n                    0,\n                    -2 * bs_per_cp_group,\n                )\n            )\n        )\n    prefix_sum_list = list(accumulate(split_list))\n\n    # TODO Support multi-batch-cp-split, multi-batch-cp support has accuracy issues\n    # cp_seq_index = calculate_cp_seq_idx(split_list[:], seqs_len[:])\n    kv_len_prev = prefix_sum_list[cp_rank]\n    kv_len_next = prefix_sum_list[cp_size * 2 - cp_rank - 1]\n    actual_seq_q_prev = split_list[cp_rank]\n    actual_seq_q_next = split_list[cp_size * 2 - cp_rank - 1]\n    kv_len_prev_tensor = torch.tensor(kv_len_prev).to(device=\"cuda\", dtype=torch.int32)\n    kv_len_next_tensor = torch.tensor(kv_len_next).to(device=\"cuda\", dtype=torch.int32)\n    actual_seq_q_prev_tensor = torch.tensor(actual_seq_q_prev).to(\n        device=\"cuda\", dtype=torch.int32\n    )\n    actual_seq_q_next_tensor = torch.tensor(actual_seq_q_next).to(\n        device=\"cuda\", dtype=torch.int32\n    )\n\n    nsa_cp_metadata = NSAContextParallelMetadata(\n        
split_list=split_list,\n        max_rank_len=max_rank_len,\n        zigzag_index=zigzag_index,\n        per_rank_actual_token=per_rank_actual_token,\n        reverse_split_len=reverse_split_len,\n        cp_reverse_index=cp_reverse_index,\n        kv_len_prev=kv_len_prev,\n        kv_len_next=kv_len_next,\n        actual_seq_q_prev=actual_seq_q_prev,\n        actual_seq_q_next=actual_seq_q_next,\n        kv_len_prev_tensor=kv_len_prev_tensor,\n        kv_len_next_tensor=kv_len_next_tensor,\n        actual_seq_q_prev_tensor=actual_seq_q_prev_tensor,\n        actual_seq_q_next_tensor=actual_seq_q_next_tensor,\n        total_seq_lens=kv_len_origin,\n    )\n    return nsa_cp_metadata\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/nsa_backend.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom enum import IntEnum, auto\nfrom typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, TypeAlias\n\nimport torch\n\nfrom sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache_paged\nfrom sglang.srt.layers.attention.nsa.nsa_backend_mtp_precompute import (\n    NativeSparseAttnBackendMTPPrecomputeMixin,\n    PrecomputedMetadata,\n    compute_cu_seqlens,\n)\nfrom sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata\nfrom sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache\nfrom sglang.srt.layers.attention.nsa.transform_index import (\n    transform_index_page_table_decode,\n    transform_index_page_table_prefill,\n)\nfrom sglang.srt.layers.attention.nsa.utils import (\n    can_nsa_prefill_cp_round_robin_split,\n    compute_nsa_seqlens,\n    is_nsa_enable_prefill_cp,\n    nsa_cp_round_robin_split_data,\n    nsa_cp_round_robin_split_q_seqs,\n    pad_nsa_cache_seqlens,\n)\nfrom sglang.srt.layers.attention.utils import (\n    concat_mla_absorb_q_general,\n    mla_quantize_and_rope_for_fp8,\n    seqlens_expand_triton,\n)\nfrom sglang.srt.layers.dp_attention import get_attention_tp_size\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.utils import is_cuda, is_hip\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n    from sglang.srt.speculative.spec_info import SpecInput\n\n\n_is_hip = is_hip()\n\nif _is_hip:\n    from sglang.srt.layers.attention.nsa.triton_kernel import get_valid_kv_indices\n\n    try:\n        from aiter import (  # noqa: F401\n            
flash_attn_varlen_func,\n            mha_batch_prefill_func,\n            paged_attention_ragged,\n        )\n        from aiter.mla import mla_decode_fwd, mla_prefill_fwd  # noqa: F401\n    except ImportError:\n        print(\n            \"aiter is an AMD-specific kernel library. Please make sure aiter is installed on your AMD device.\"\n        )\nelse:\n    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache\n\n\n# Reuse this workspace buffer across all NSA backend instances\nglobal_workspace_buffer = None\n\n# Control whether to use fused metadata copy kernel for cuda graph replay (default: enabled)\n# Set SGLANG_USE_FUSED_METADATA_COPY=0 or false to disable\n_USE_FUSED_METADATA_COPY = envs.SGLANG_USE_FUSED_METADATA_COPY.get() and not _is_hip\n\n\n@dataclass(frozen=True)\nclass NSAFlashMLAMetadata:\n    \"\"\"Metadata only needed by FlashMLA\"\"\"\n\n    flashmla_metadata: torch.Tensor\n    num_splits: torch.Tensor\n\n    def slice(self, sli):\n        return NSAFlashMLAMetadata(\n            flashmla_metadata=self.flashmla_metadata,\n            num_splits=self.num_splits[sli],\n        )\n\n    def copy_(self, other: \"NSAFlashMLAMetadata\"):\n        self.flashmla_metadata.copy_(other.flashmla_metadata)\n        self.num_splits.copy_(other.num_splits)\n\n\n@dataclass(frozen=True)\nclass NSAMetadata:\n    page_size: int\n\n    # Sequence lengths for the forward batch\n    cache_seqlens_int32: torch.Tensor\n    # Maximum sequence length for query\n    max_seq_len_q: int\n    # Maximum sequence length for key\n    max_seq_len_k: int\n    # Cumulative sequence lengths for query\n    cu_seqlens_q: torch.Tensor\n    # Cumulative sequence lengths for key\n    cu_seqlens_k: torch.Tensor\n    # Page table, the index of KV Cache Tables/Blocks\n    # this table is always with page_size = 1\n    page_table_1: torch.Tensor\n\n    # NOTE(dark): This property will be used in:\n    # 1. 
dense decode/prefill, which uses paged flash attention and needs real_page_table\n    # 2. sparse decode/prefill, where the indexer needs real_page_table to compute the score\n    real_page_table: torch.Tensor\n\n    # NSA metadata (expanded for NSA prefill)\n    nsa_cache_seqlens_int32: torch.Tensor  # these seqlens are clipped to `topk`\n    nsa_cu_seqlens_q: torch.Tensor  # must be arange(0, len(nsa_cu_seqlens_k))\n    nsa_cu_seqlens_k: torch.Tensor  # cumsum of `nsa_cache_seqlens_int32`\n    nsa_extend_seq_lens_list: List[int]\n    nsa_seqlens_expanded: torch.Tensor  # expanded, unclipped `seqlens`\n    nsa_max_seqlen_q: Literal[1] = 1  # always 1 for decode, variable for extend\n\n    flashmla_metadata: Optional[NSAFlashMLAMetadata] = None\n    # DeepGEMM schedule metadata for paged MQA logits (decode/target_verify/draft_extend only).\n    # Precomputed once per forward batch and reused across layers.\n    paged_mqa_schedule_metadata: Optional[torch.Tensor] = None\n    # The sum of sequence lengths for key, prefill only\n    seq_lens_sum: Optional[int] = None\n    # The flattened 1D page table with shape (seq_lens_sum,), prefill only\n    # this table is always with page_size = 1\n    page_table_1_flattened: Optional[torch.Tensor] = None\n    # The offset of topk indices in ragged kv, prefill only\n    # shape: (seq_lens_sum,)\n    topk_indices_offset: Optional[torch.Tensor] = None\n\n    # k_start and k_end in kv cache for each token.\n    indexer_k_start_end: Optional[Tuple[torch.Tensor, torch.Tensor]] = None\n    # seq lens for each batch (CPU copy).\n    indexer_seq_lens_cpu: Optional[torch.Tensor] = None\n    # seq lens for each batch.\n    indexer_seq_lens: Optional[torch.Tensor] = None\n    # batch index for each token.\n    token_to_batch_idx: Optional[torch.Tensor] = None\n\n\nclass TopkTransformMethod(IntEnum):\n    # Transform topk indices into indices into the page table (page_size = 1)\n    PAGED = auto()\n    # Transform topk indices into indices into ragged kv (non-paged)\n    RAGGED 
= auto()\n\n\n@torch.compile\ndef _compiled_cat(tensors: list[torch.Tensor], dim: int = -1) -> torch.Tensor:\n    return torch.cat(tensors, dim=dim)\n\n\ndef _cat(tensors: list[torch.Tensor], dim: int = -1) -> torch.Tensor:\n    \"\"\"\n    Concatenate two tensors along the last dimension.\n    Use this function to concatenate q_nope and q_rope or k_nope and k_rope.\n    \"\"\"\n    assert len(tensors) == 2\n\n    qk_nope, qk_rope = tensors\n    assert qk_nope.ndim == 3 and qk_rope.ndim == 3\n\n    torch._dynamo.mark_dynamic(qk_nope, 0)\n    torch._dynamo.mark_dynamic(qk_rope, 0)\n\n    return _compiled_cat([qk_nope, qk_rope], dim=dim)\n\n\n@dataclass(frozen=True)\nclass NSAIndexerMetadata(BaseIndexerMetadata):\n    attn_metadata: NSAMetadata\n    topk_transform_method: TopkTransformMethod\n    paged_mqa_schedule_metadata: Optional[torch.Tensor] = None\n\n    def get_seqlens_int32(self) -> torch.Tensor:\n        return self.attn_metadata.cache_seqlens_int32\n\n    def get_page_table_64(self) -> torch.Tensor:\n        return self.attn_metadata.real_page_table\n\n    def get_page_table_1(self) -> torch.Tensor:\n        return self.attn_metadata.page_table_1\n\n    def get_seqlens_expanded(self) -> torch.Tensor:\n        return self.attn_metadata.nsa_seqlens_expanded\n\n    def get_cu_seqlens_k(self) -> torch.Tensor:\n        return self.attn_metadata.cu_seqlens_k\n\n    def get_indexer_kvcache_range(self) -> Tuple[torch.Tensor, torch.Tensor]:\n        return self.attn_metadata.indexer_k_start_end\n\n    def get_indexer_seq_len(self) -> torch.Tensor:\n        return self.attn_metadata.indexer_seq_lens\n\n    def get_indexer_seq_len_cpu(self) -> torch.Tensor:\n        return self.attn_metadata.indexer_seq_lens_cpu\n\n    def get_nsa_extend_len_cpu(self) -> List[int]:\n        return self.attn_metadata.nsa_extend_seq_lens_list\n\n    def get_token_to_batch_idx(self) -> torch.Tensor:\n        return self.attn_metadata.token_to_batch_idx\n\n    def topk_transform(\n       
 self,\n        logits: torch.Tensor,\n        topk: int,\n        ks: Optional[torch.Tensor] = None,\n        cu_seqlens_q: torch.Tensor = None,\n        ke_offset: torch.Tensor = None,\n        batch_idx_list: List[int] = None,\n        topk_indices_offset_override: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        from sgl_kernel import (\n            fast_topk_transform_fused,\n            fast_topk_transform_ragged_fused,\n            fast_topk_v2,\n        )\n\n        if topk_indices_offset_override is not None:\n            cu_topk_indices_offset = topk_indices_offset_override\n            cu_seqlens_q_topk = None\n        elif cu_seqlens_q is not None:\n            cu_seqlens_q = cu_seqlens_q.to(torch.int32)\n            cu_seqlens_q_topk = compute_cu_seqlens(cu_seqlens_q)\n            cu_topk_indices_offset = torch.repeat_interleave(\n                cu_seqlens_q_topk[:-1],\n                cu_seqlens_q,\n            )\n        else:\n            cu_seqlens_q_topk = self.attn_metadata.cu_seqlens_q\n            cu_topk_indices_offset = self.attn_metadata.topk_indices_offset\n        if ke_offset is not None:\n            seq_lens_topk = ke_offset\n        else:\n            seq_lens_topk = self.get_seqlens_expanded()\n        if batch_idx_list is not None:\n            page_table_size_1 = self.attn_metadata.page_table_1[batch_idx_list]\n        else:\n            page_table_size_1 = self.attn_metadata.page_table_1\n\n        if not envs.SGLANG_NSA_FUSE_TOPK.get():\n            return fast_topk_v2(logits, seq_lens_topk, topk, row_starts=ks)\n        elif self.topk_transform_method == TopkTransformMethod.PAGED:\n            # NOTE(dark): if fused, we return a transformed page table directly\n            return fast_topk_transform_fused(\n                score=logits,\n                lengths=seq_lens_topk,\n                page_table_size_1=page_table_size_1,\n                cu_seqlens_q=cu_seqlens_q_topk,\n                topk=topk,\n         
       row_starts=ks,\n            )\n        elif self.topk_transform_method == TopkTransformMethod.RAGGED:\n            return fast_topk_transform_ragged_fused(\n                score=logits,\n                lengths=seq_lens_topk,\n                topk_indices_offset=cu_topk_indices_offset,\n                topk=topk,\n                row_starts=ks,\n            )\n        else:\n            assert False, f\"Unsupported {self.topk_transform_method = }\"\n\n\n_NSA_IMPL_T: TypeAlias = Literal[\n    \"flashmla_sparse\", \"flashmla_kv\", \"fa3\", \"tilelang\", \"trtllm\"\n]\n\n\nclass NativeSparseAttnBackend(\n    NativeSparseAttnBackendMTPPrecomputeMixin, AttentionBackend\n):\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        skip_prefill: bool = False,\n        speculative_step_id=0,\n        topk=0,\n        speculative_num_steps=0,\n    ):\n        super().__init__()\n        self.forward_metadata: NSAMetadata\n        self.device = model_runner.device\n        assert isinstance(model_runner.page_size, int)\n        self.real_page_size = model_runner.page_size\n        self.num_splits = (\n            1 if model_runner.server_args.enable_deterministic_inference else 0\n        )\n        self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)\n        assert self.use_nsa, \"NSA backend only supports DeepSeek NSA\"\n        self.nsa_kv_cache_store_fp8 = (\n            model_runner.token_to_kv_pool.nsa_kv_cache_store_fp8\n        )\n        self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)\n        self.max_context_len = model_runner.model_config.context_len\n        self.num_q_heads = (\n            model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim\n        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim\n        self.kv_lora_rank = model_runner.model_config.kv_lora_rank\n  
      self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim\n\n        assert model_runner.req_to_token_pool is not None\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n\n        self.use_mha: bool = False\n        self.nsa_prefill_impl: _NSA_IMPL_T = (\n            model_runner.server_args.nsa_prefill_backend\n        )\n        self.nsa_decode_impl: _NSA_IMPL_T = model_runner.server_args.nsa_decode_backend\n        self.enable_auto_select_prefill_impl = self.nsa_prefill_impl == \"flashmla_auto\"\n\n        self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)\n\n        if _is_hip:\n            max_bs = model_runner.req_to_token_pool.size\n\n            self.kv_indptr = torch.zeros(\n                (max_bs + 1,), dtype=torch.int32, device=model_runner.device\n            )\n\n            self.kv_indices = torch.zeros(\n                max_bs * self.nsa_index_topk,\n                dtype=torch.int32,\n                device=self.device,\n            )\n\n        # Speculative decoding\n        self.topk = model_runner.server_args.speculative_eagle_topk or 0\n        self.speculative_num_steps = speculative_num_steps\n        self.speculative_num_draft_tokens = (\n            model_runner.server_args.speculative_num_draft_tokens\n        )\n        self.speculative_step_id = speculative_step_id\n\n        self.device_capability = torch.cuda.get_device_capability()\n        self.device_sm_major = self.device_capability[0]\n        self.kv_cache_dtype = model_runner.kv_cache_dtype\n\n        # Allocate global workspace buffer for TRT-LLM kernels (ragged attention on SM100/B200, or trtllm decode)\n        if self.device_sm_major >= 10 or self.nsa_decode_impl == \"trtllm\":\n            global global_workspace_buffer\n            if global_workspace_buffer is None:\n                global_workspace_buffer = torch.empty(\n                    envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),\n                    
dtype=torch.uint8,\n                    device=model_runner.device,\n                )\n            self.workspace_buffer = global_workspace_buffer\n        else:\n            self.workspace_buffer = None\n\n    def get_device_int32_arange(self, l: int) -> torch.Tensor:\n        if l > len(self._arange_buf):\n            next_pow_of_2 = 1 << (l - 1).bit_length()\n            self._arange_buf = torch.arange(\n                next_pow_of_2, device=self.device, dtype=torch.int32\n            )\n        return self._arange_buf[:l]\n\n    def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:\n        page_size = self.real_page_size\n        if page_size == 1:\n            return page_table\n        max_seqlen_k = page_table.shape[1]\n        strided_indices = torch.arange(\n            0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32\n        )\n        return page_table[:, strided_indices] // page_size\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Init the metadata for a forward pass.\"\"\"\n        batch_size = forward_batch.batch_size\n        device = forward_batch.seq_lens.device\n\n        if forward_batch.forward_mode.is_target_verify():\n            draft_token_num = self.speculative_num_draft_tokens\n        else:\n            draft_token_num = 0\n\n        cache_seqlens_int32 = (forward_batch.seq_lens + draft_token_num).to(torch.int32)\n        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)\n        assert forward_batch.seq_lens_cpu is not None\n        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item() + draft_token_num)\n        # [b, max_seqlen_k]\n        page_table = forward_batch.req_to_token_pool.req_to_token[\n            forward_batch.req_pool_indices, :max_seqlen_k\n        ]\n\n        page_table_1_flattened = None\n        topk_indices_offset = None\n\n        # Centralized dispatch: decide all strategies for this batch\n        
self.set_nsa_prefill_impl(forward_batch)\n        topk_transform_method = self.get_topk_transform_method()\n        # Batch indices selected when cp enabled: After splitting multiple sequences,\n        # a certain cp rank may not have some of these sequences.\n        # We use bs_idx_cpu to mark which sequences are finally selected by the current cp rank,\n        # a default value of None indicates that all sequences are selected.\n        bs_idx_cpu = None\n        # seq_len_cpu of selected sequences\n        indexer_seq_lens_cpu = forward_batch.seq_lens_cpu\n        indexer_seq_lens = forward_batch.seq_lens\n\n        if forward_batch.forward_mode.is_decode_or_idle():\n            extend_seq_lens_cpu = [1] * batch_size\n            max_seqlen_q = 1\n            cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)\n            seqlens_expanded = cache_seqlens_int32\n        elif forward_batch.forward_mode.is_target_verify():\n            max_seqlen_q = 1\n            cu_seqlens_q = torch.arange(\n                0,\n                batch_size * self.speculative_num_draft_tokens + 1,\n                1,\n                dtype=torch.int32,\n                device=device,\n            )\n            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * batch_size\n            forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu\n\n            seqlens_expanded = seqlens_expand_triton(\n                torch.tensor(extend_seq_lens_cpu, dtype=torch.int32, device=device),\n                cache_seqlens_int32,\n                self.speculative_num_draft_tokens * batch_size,\n                self.speculative_num_draft_tokens,\n            )\n            page_table = torch.repeat_interleave(\n                page_table, repeats=self.speculative_num_draft_tokens, dim=0\n            )\n        elif forward_batch.forward_mode.is_draft_extend(include_v2=True):\n            assert (\n                forward_batch.extend_seq_lens_cpu is not None\n                
and forward_batch.extend_seq_lens is not None\n                and forward_batch.extend_prefix_lens_cpu is not None\n            ), \"All of them must not be None\"\n\n            extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu\n            assert forward_batch.extend_seq_lens is not None\n\n            max_seqlen_q = 1\n            cu_seqlens_q = torch.arange(\n                0,\n                forward_batch.extend_num_tokens + 1,\n                1,\n                dtype=torch.int32,\n                device=device,\n            )\n\n            seqlens_expanded = seqlens_expand_triton(\n                forward_batch.extend_seq_lens,\n                cache_seqlens_int32,\n                sum(extend_seq_lens_cpu),\n                self.speculative_num_draft_tokens,\n            )\n            if forward_batch.forward_mode.is_draft_extend_v2():\n                # DRAFT_EXTEND_V2: V2 worker pre-fills draft KV cache with ALL speculated\n                # tokens upfront. All requests extend by the same fixed\n                # (speculative_num_draft_tokens). Use scalar to avoid GPU sync.\n                page_table = torch.repeat_interleave(\n                    page_table, repeats=self.speculative_num_draft_tokens, dim=0\n                )\n            else:\n                # DRAFT_EXTEND (v1): V1 worker extends by (accept_length + 1) per request\n                # after verification. 
Lengths vary per request based on how many tokens\n                # were accepted.\n                page_table = torch.repeat_interleave(\n                    page_table, repeats=forward_batch.extend_seq_lens, dim=0\n                )\n        elif forward_batch.forward_mode.is_extend():\n            assert (\n                forward_batch.extend_seq_lens_cpu is not None\n                and forward_batch.extend_seq_lens is not None\n                and forward_batch.extend_prefix_lens_cpu is not None\n            ), \"All of them must not be None\"\n            extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu\n            assert forward_batch.extend_seq_lens is not None\n            extend_seq_lens = forward_batch.extend_seq_lens\n\n            seqlens_expanded = torch.cat(\n                [\n                    torch.arange(\n                        kv_len - qo_len + 1,\n                        kv_len + 1,\n                        dtype=torch.int32,\n                        device=device,\n                    )\n                    for qo_len, kv_len in zip(\n                        forward_batch.extend_seq_lens_cpu,\n                        forward_batch.seq_lens_cpu.tolist(),\n                        strict=True,\n                    )\n                ]\n            )\n\n            if can_nsa_prefill_cp_round_robin_split(forward_batch):\n                seqlens_expanded = nsa_cp_round_robin_split_data(seqlens_expanded)\n                extend_seq_lens_cpu, extend_seq_lens, bs_idx_cpu, bs_idx = (\n                    nsa_cp_round_robin_split_q_seqs(\n                        extend_seq_lens_cpu, extend_seq_lens\n                    )\n                )\n                indexer_seq_lens_cpu = indexer_seq_lens_cpu[bs_idx_cpu]\n                indexer_seq_lens = indexer_seq_lens[bs_idx]\n                cache_seqlens_int32 = cache_seqlens_int32[bs_idx]\n                cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)\n                max_seqlen_k = 
(\n                    int(indexer_seq_lens_cpu.max().item() + draft_token_num)\n                    if len(indexer_seq_lens_cpu) != 0\n                    else 0\n                )\n                page_table = page_table[bs_idx, :max_seqlen_k]\n\n            if (\n                any(forward_batch.extend_prefix_lens_cpu)\n                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND\n                or bs_idx_cpu is not None\n            ):\n                max_seqlen_q = (\n                    max(extend_seq_lens_cpu) if len(extend_seq_lens_cpu) != 0 else 1\n                )\n                cu_seqlens_q = compute_cu_seqlens(extend_seq_lens.to(torch.int32))\n            else:\n                max_seqlen_q = max_seqlen_k\n                cu_seqlens_q = cu_seqlens_k\n\n            # Check if MHA FP8 dequantization is needed\n            mha_dequantize_needed = (\n                self.use_mha\n                and forward_batch.token_to_kv_pool.dtype == torch.float8_e4m3fn\n            )\n            forward_batch.using_mha_one_shot_fp8_dequant = mha_dequantize_needed\n\n            # page_table_1_flattened is only used when prefix sharing is enabled:\n            has_prefix_sharing = any(forward_batch.extend_prefix_lens_cpu)\n            if has_prefix_sharing and (\n                topk_transform_method == TopkTransformMethod.RAGGED\n                or mha_dequantize_needed\n            ):\n                page_table_1_flattened = torch.cat(\n                    [\n                        page_table[i, :kv_len]\n                        for i, kv_len in enumerate(\n                            indexer_seq_lens_cpu.tolist(),\n                        )\n                    ]\n                )\n                assert page_table_1_flattened.shape[0] == sum(\n                    indexer_seq_lens_cpu\n                ), f\"{page_table_1_flattened.shape[0] = } must be the same as {sum(indexer_seq_lens_cpu) = }\"\n\n                # Validate indices when 
logical tokens exceed physical capacity\n                # This is likely to be triggered by PP with high kv reuse & parallelism\n                kv_cache_capacity = (\n                    forward_batch.token_to_kv_pool.size\n                    + forward_batch.token_to_kv_pool.page_size\n                )\n                if forward_batch.seq_lens_sum > kv_cache_capacity:\n                    max_idx = page_table_1_flattened.max().item()\n                    assert max_idx < kv_cache_capacity, (\n                        f\"Invalid page table index: max={max_idx}, \"\n                        f\"kv_cache_capacity={kv_cache_capacity}\"\n                    )\n\n            if topk_transform_method == TopkTransformMethod.RAGGED:\n                topk_indices_offset = torch.repeat_interleave(\n                    cu_seqlens_k[:-1],\n                    extend_seq_lens,\n                )\n        else:\n            assert False, f\"Unsupported {forward_batch.forward_mode = }\"\n\n        indexer_k_start_end, token_to_batch_idx = self._cal_indexer_k_start_end(\n            forward_batch, bs_idx_cpu\n        )\n        # 1D, expanded seqlens (1D means cheap to compute, so always compute it)\n        nsa_cache_seqlens_int32 = compute_nsa_seqlens(\n            original_seq_lens=seqlens_expanded,\n            nsa_index_topk=self.nsa_index_topk,\n        )\n        nsa_cache_seqlens_int32 = pad_nsa_cache_seqlens(\n            forward_batch, nsa_cache_seqlens_int32\n        )\n        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)\n        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))\n\n        paged_mqa_schedule_metadata = None\n        # DeepGEMM paged MQA logits path needs a schedule metadata tensor.\n        # Compute it once per forward batch and reuse it across layers.\n        if is_cuda() and (\n            forward_batch.forward_mode.is_decode_or_idle()\n            or forward_batch.forward_mode.is_target_verify()\n            
or forward_batch.forward_mode.is_draft_extend()\n        ):\n            try:\n                import deep_gemm\n\n                # NOTE: DeepGEMM paged path uses block_size=64.\n                seqlens_32 = (\n                    seqlens_expanded\n                    if (\n                        forward_batch.forward_mode.is_target_verify()\n                        or forward_batch.forward_mode.is_draft_extend()\n                    )\n                    else cache_seqlens_int32\n                )\n                paged_mqa_schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(\n                    seqlens_32, 64, deep_gemm.get_num_sms()\n                )\n            except (ImportError, ModuleNotFoundError):\n                paged_mqa_schedule_metadata = None\n\n        metadata = NSAMetadata(\n            page_size=self.real_page_size,\n            cache_seqlens_int32=cache_seqlens_int32,\n            max_seq_len_q=max_seqlen_q,\n            max_seq_len_k=max_seqlen_k,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n            seq_lens_sum=forward_batch.seq_lens_sum,\n            page_table_1=page_table,\n            page_table_1_flattened=page_table_1_flattened,\n            flashmla_metadata=(\n                self._compute_flashmla_metadata(\n                    cache_seqlens=nsa_cache_seqlens_int32,\n                    seq_len_q=1,\n                )\n                if self.nsa_decode_impl == \"flashmla_kv\"\n                else None\n            ),\n            paged_mqa_schedule_metadata=paged_mqa_schedule_metadata,\n            nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,\n            nsa_cu_seqlens_q=nsa_cu_seqlens_q,\n            nsa_cu_seqlens_k=nsa_cu_seqlens_k,\n            nsa_seqlens_expanded=seqlens_expanded,\n            nsa_extend_seq_lens_list=extend_seq_lens_cpu,\n            real_page_table=self._transform_table_1_to_real(page_table),\n            nsa_max_seqlen_q=1,\n            
topk_indices_offset=topk_indices_offset,\n            indexer_k_start_end=indexer_k_start_end,\n            indexer_seq_lens_cpu=indexer_seq_lens_cpu,\n            indexer_seq_lens=indexer_seq_lens,\n            token_to_batch_idx=token_to_batch_idx,\n        )\n        self.forward_metadata = metadata\n\n    def _cal_indexer_k_start_end(\n        self,\n        forward_batch: ForwardBatch,\n        bs_idx: Optional[List[int]] = None,\n    ):\n        if not forward_batch.forward_mode.is_extend_without_speculative():\n            return None, None\n        if forward_batch.batch_size == 0 or (bs_idx is not None and len(bs_idx) == 0):\n            empty_t = torch.empty(0, dtype=torch.int32, device=self.device)\n            return (empty_t, empty_t), empty_t\n\n        # Suppose there are two requests, with extend_seq_len = [3, 2]\n        # and seq_lens = [10, 4]\n        # The logits matrix looks like this, with * representing the valid logits\n        # and - representing the invalid logits:\n        #\n        #  ********--|----\n        #  *********-|----\n        #  **********|----\n        #  ----------|***-\n        #  ----------|****\n        #\n        # ks = [0, 0, 0, 10, 10]\n        # ke = [8, 9, 10, 13, 14]\n        ks_list = []\n        ke_list = []\n        token_to_batch_idx = []\n\n        q_offset = 0\n        k_offset = 0\n\n        assert (\n            forward_batch.seq_lens_cpu is not None\n            and forward_batch.extend_seq_lens_cpu is not None\n        )\n        for i in range(forward_batch.batch_size):\n            seq_len = forward_batch.seq_lens_cpu[i].item()\n            assert isinstance(seq_len, int)\n            extend_seq_len = forward_batch.extend_seq_lens_cpu[i]\n            ks = torch.full(\n                (extend_seq_len,), k_offset, dtype=torch.int32, device=self.device\n            )\n            kv_len = seq_len\n            if forward_batch.forward_mode.is_target_verify():\n                kv_len += 
self.speculative_num_draft_tokens\n            seq_lens_expanded = torch.arange(\n                kv_len - extend_seq_len + 1,\n                kv_len + 1,\n                dtype=torch.int32,\n                device=self.device,\n            )\n            ke = ks + seq_lens_expanded\n            ks_list.append(ks)\n            ke_list.append(ke)\n\n            # bi: The index within the selected batch bs_idx. Entries that were not selected are ignored.\n            bi = bs_idx.index(i) if (bs_idx is not None and i in bs_idx) else i\n            tb = torch.full(\n                (extend_seq_len,), bi, dtype=torch.int32, device=self.device\n            )\n            token_to_batch_idx.append(tb)\n\n            if bs_idx is None or i in bs_idx:  # skip batch not included in bs_idx\n                q_offset += extend_seq_len\n                k_offset += seq_len\n\n        ks = torch.cat(ks_list, dim=0)\n        ke = torch.cat(ke_list, dim=0)\n        token_to_batch_idx = torch.cat(token_to_batch_idx, dim=0)\n        if bs_idx is not None:\n            assert can_nsa_prefill_cp_round_robin_split(forward_batch)\n            ks = nsa_cp_round_robin_split_data(ks)\n            ke = nsa_cp_round_robin_split_data(ke)\n            token_to_batch_idx = nsa_cp_round_robin_split_data(token_to_batch_idx)\n        return (ks, ke), token_to_batch_idx\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        \"\"\"Initialize CUDA graph state for the attention backend.\n\n        Args:\n            max_bs (int): Maximum batch size to support in CUDA graphs\n            max_num_tokens (int): Maximum number of tokens to support in CUDA graphs;\n                sizes the per-token buffers (cache_seqlens, page_table, flashmla_metadata)\n\n        This creates fixed-size tensors that will be reused during CUDA graph replay\n        to avoid memory allocations.\n        \"\"\"\n        self.decode_cuda_graph_metadata: Dict = {\n            \"cache_seqlens\": torch.ones(\n                max_num_tokens, dtype=torch.int32, device=self.device\n            ),\n            \"cu_seqlens_q\": torch.arange(\n                0, max_bs + 1, 
dtype=torch.int32, device=self.device\n            ),\n            \"cu_seqlens_k\": torch.zeros(\n                max_bs + 1, dtype=torch.int32, device=self.device\n            ),\n            # fake page_table for sparse_prefill\n            # Add extra columns for speculative draft tokens to avoid\n            # overflow during target_verify when max_seqlen_k = seq_len + num_draft_tokens\n            \"page_table\": torch.zeros(\n                max_num_tokens,\n                self.max_context_len + (self.speculative_num_draft_tokens or 0),\n                dtype=torch.int32,\n                device=self.device,\n            ),\n            \"flashmla_metadata\": (\n                self._compute_flashmla_metadata(\n                    cache_seqlens=torch.ones(\n                        max_num_tokens, dtype=torch.int32, device=self.device\n                    ),\n                    seq_len_q=1,\n                )\n                if self.nsa_decode_impl == \"flashmla_kv\"\n                else None\n            ),\n        }\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        \"\"\"Initialize forward metadata for capturing CUDA graph.\"\"\"\n        self.set_nsa_prefill_impl(forward_batch=None)\n\n        if forward_mode.is_decode_or_idle():\n            # Normal Decode\n            # Get sequence information\n            cache_seqlens_int32 = seq_lens.to(torch.int32)\n            cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)\n\n            # Use max context length for seq_len_k\n            page_table_1 = self.decode_cuda_graph_metadata[\"page_table\"][:bs, :]\n            max_seqlen_q = 1\n            max_seqlen_k = page_table_1.shape[1]\n\n    
        # Precompute cumulative sequence lengths\n\n            # NOTE(dark): this is always arange, since we are decoding\n            cu_seqlens_q = self.decode_cuda_graph_metadata[\"cu_seqlens_q\"][: bs + 1]\n            nsa_cache_seqlens_int32 = compute_nsa_seqlens(\n                cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk\n            )\n\n            seqlens_expanded = cache_seqlens_int32\n            nsa_extend_seq_lens_list = [1] * num_tokens\n            if self.nsa_decode_impl == \"flashmla_kv\":\n                flashmla_metadata = self.decode_cuda_graph_metadata[\n                    \"flashmla_metadata\"\n                ].slice(slice(0, num_tokens + 1))\n                flashmla_metadata.copy_(\n                    self._compute_flashmla_metadata(\n                        cache_seqlens=nsa_cache_seqlens_int32,\n                        seq_len_q=1,\n                    )\n                )\n            else:\n                flashmla_metadata = None\n        elif forward_mode.is_target_verify() or forward_mode.is_draft_extend(\n            include_v2=True\n        ):\n            cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(\n                torch.int32\n            )\n            cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)\n            max_seqlen_q = 1\n            page_table_1 = self.decode_cuda_graph_metadata[\"page_table\"][\n                : bs * self.speculative_num_draft_tokens, :\n            ]\n            max_seqlen_k = page_table_1.shape[1]\n\n            cu_seqlens_q = torch.arange(\n                0,\n                bs * self.speculative_num_draft_tokens + 1,\n                1,\n                dtype=torch.int32,\n                device=self.device,\n            )\n\n            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs\n\n            seqlens_int32_cpu = [\n                self.speculative_num_draft_tokens + kv_len\n                for kv_len in 
seq_lens.tolist()\n            ]\n            seqlens_expanded = torch.cat(\n                [\n                    torch.arange(\n                        kv_len - qo_len + 1,\n                        kv_len + 1,\n                        dtype=torch.int32,\n                        device=self.device,\n                    )\n                    for qo_len, kv_len in zip(\n                        extend_seq_lens_cpu,\n                        seqlens_int32_cpu,\n                        strict=True,\n                    )\n                ]\n            )\n            nsa_cache_seqlens_int32 = compute_nsa_seqlens(\n                seqlens_expanded, nsa_index_topk=self.nsa_index_topk\n            )\n            nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens\n\n            if self.nsa_decode_impl == \"flashmla_kv\":\n                flashmla_metadata = self.decode_cuda_graph_metadata[\n                    \"flashmla_metadata\"\n                ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))\n\n                flashmla_metadata.copy_(\n                    self._compute_flashmla_metadata(\n                        cache_seqlens=nsa_cache_seqlens_int32,\n                        seq_len_q=1,\n                    )\n                )\n            else:\n                flashmla_metadata = None\n\n        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)\n        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))\n        real_page_table = self._transform_table_1_to_real(page_table_1)\n\n        paged_mqa_schedule_metadata = None\n        if is_cuda() and (\n            forward_mode.is_decode_or_idle()\n            or forward_mode.is_target_verify()\n            or forward_mode.is_draft_extend()\n        ):\n            try:\n                import deep_gemm\n\n                seqlens_32 = (\n                    seqlens_expanded\n                    if (\n                        
forward_mode.is_target_verify()\n                        or forward_mode.is_draft_extend()\n                    )\n                    else cache_seqlens_int32\n                )\n                paged_mqa_schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(\n                    seqlens_32, 64, deep_gemm.get_num_sms()\n                )\n            except (ImportError, ModuleNotFoundError):\n                paged_mqa_schedule_metadata = None\n\n        metadata = NSAMetadata(\n            page_size=self.real_page_size,\n            cache_seqlens_int32=cache_seqlens_int32,\n            max_seq_len_q=max_seqlen_q,\n            max_seq_len_k=max_seqlen_k,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n            page_table_1=page_table_1,\n            flashmla_metadata=flashmla_metadata,\n            paged_mqa_schedule_metadata=paged_mqa_schedule_metadata,\n            nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,\n            nsa_cu_seqlens_q=nsa_cu_seqlens_q,\n            nsa_cu_seqlens_k=nsa_cu_seqlens_k,\n            nsa_seqlens_expanded=seqlens_expanded,\n            real_page_table=real_page_table,\n            nsa_extend_seq_lens_list=nsa_extend_seq_lens_list,\n        )\n        self.decode_cuda_graph_metadata[bs] = metadata\n        self.forward_metadata = metadata\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n        out_cache_loc: Optional[torch.Tensor] = None,\n    ):\n        \"\"\"Initialize forward metadata for replaying CUDA graph.\"\"\"\n        assert seq_lens_cpu is not None\n\n        self.set_nsa_prefill_impl(forward_batch=None)\n\n        seq_lens = seq_lens[:bs]\n        seq_lens_cpu = 
seq_lens_cpu[:bs]\n        req_pool_indices = req_pool_indices[:bs]\n\n        # Normal Decode\n        metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]\n        if forward_mode.is_decode_or_idle():\n            # Normal Decode\n            max_len = int(seq_lens_cpu.max().item())\n\n            cache_seqlens = seq_lens.to(torch.int32)\n            metadata.cache_seqlens_int32.copy_(cache_seqlens)\n            metadata.cu_seqlens_k[1:].copy_(\n                torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)\n            )\n            page_indices = self.req_to_token[req_pool_indices, :max_len]\n            metadata.page_table_1[:, :max_len].copy_(page_indices)\n            nsa_cache_seqlens = compute_nsa_seqlens(\n                cache_seqlens, nsa_index_topk=self.nsa_index_topk\n            )\n            metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)\n            seqlens_expanded = cache_seqlens\n        elif forward_mode.is_target_verify():\n            max_seqlen_k = int(\n                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens\n            )\n\n            cache_seqlens = (seq_lens + self.speculative_num_draft_tokens).to(\n                torch.int32\n            )\n            metadata.cache_seqlens_int32.copy_(cache_seqlens)\n            metadata.cu_seqlens_k[1:].copy_(\n                torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)\n            )\n            page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]\n            page_indices = torch.repeat_interleave(\n                page_indices, repeats=self.speculative_num_draft_tokens, dim=0\n            )\n            metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)\n            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs\n\n            seqlens_expanded = seqlens_expand_triton(\n                torch.tensor(\n                    extend_seq_lens_cpu, dtype=torch.int32, device=self.device\n                ),\n 
               cache_seqlens,\n                self.speculative_num_draft_tokens * bs,\n                self.speculative_num_draft_tokens,\n            )\n            metadata.nsa_seqlens_expanded.copy_(seqlens_expanded)\n            nsa_cache_seqlens = compute_nsa_seqlens(\n                seqlens_expanded, self.nsa_index_topk\n            )\n            metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)\n        elif forward_mode.is_draft_extend(include_v2=True):\n            max_seqlen_k = int(seq_lens_cpu.max().item())\n            cache_seqlens = seq_lens.to(torch.int32)\n            metadata.cache_seqlens_int32.copy_(cache_seqlens)\n            metadata.cu_seqlens_k[1:].copy_(\n                torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)\n            )\n\n            extend_seq_lens = spec_info.accept_length[:bs]\n            extend_seq_lens_cpu = extend_seq_lens.tolist()\n\n            page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]\n            page_indices = torch.repeat_interleave(\n                page_indices, repeats=extend_seq_lens, dim=0\n            )\n            metadata.page_table_1[: page_indices.shape[0], :max_seqlen_k].copy_(\n                page_indices\n            )\n\n            seqlens_expanded = seqlens_expand_triton(\n                extend_seq_lens,\n                cache_seqlens,\n                sum(extend_seq_lens_cpu),\n                self.speculative_num_draft_tokens,\n            )\n            metadata.nsa_seqlens_expanded[: seqlens_expanded.shape[0]].copy_(\n                seqlens_expanded\n            )\n            nsa_cache_seqlens = compute_nsa_seqlens(\n                seqlens_expanded, self.nsa_index_topk\n            )\n            metadata.nsa_cache_seqlens_int32[: seqlens_expanded.shape[0]].copy_(\n                nsa_cache_seqlens\n            )\n\n        # Update DeepGEMM paged MQA schedule metadata outside the captured graph.\n        if is_cuda() and (\n            
forward_mode.is_decode_or_idle()\n            or forward_mode.is_target_verify()\n            or forward_mode.is_draft_extend()\n        ):\n            try:\n                import deep_gemm\n\n                seqlens_32 = (\n                    seqlens_expanded\n                    if (\n                        forward_mode.is_target_verify()\n                        or forward_mode.is_draft_extend()\n                    )\n                    else metadata.cache_seqlens_int32\n                )\n                new_schedule = deep_gemm.get_paged_mqa_logits_metadata(\n                    seqlens_32, 64, deep_gemm.get_num_sms()\n                )\n                if metadata.paged_mqa_schedule_metadata is None:\n                    metadata.paged_mqa_schedule_metadata = new_schedule\n                else:\n                    metadata.paged_mqa_schedule_metadata.copy_(new_schedule)\n            except (ImportError, ModuleNotFoundError):\n                metadata.paged_mqa_schedule_metadata = None\n        seqlens_expanded_size = seqlens_expanded.shape[0]\n        assert (\n            metadata.nsa_cache_seqlens_int32 is not None\n            and metadata.nsa_cu_seqlens_k is not None\n            and self.nsa_index_topk is not None\n        )\n\n        metadata.nsa_cu_seqlens_k[1 : 1 + seqlens_expanded_size].copy_(\n            torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)\n        )\n        # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy\n\n        assert self.real_page_size == metadata.page_size\n        if self.real_page_size > 1:\n            real_table = self._transform_table_1_to_real(page_indices)\n            new_rows = real_table.shape[0]\n            new_cols = real_table.shape[1]\n            metadata.real_page_table[:new_rows, :new_cols].copy_(real_table)\n        else:\n            assert metadata.real_page_table is metadata.page_table_1\n\n        if self.nsa_decode_impl == \"flashmla_kv\":\n            
flashmla_metadata = metadata.flashmla_metadata.slice(\n                slice(0, seqlens_expanded_size + 1)\n            )\n            flashmla_metadata.copy_(\n                self._compute_flashmla_metadata(\n                    cache_seqlens=nsa_cache_seqlens,\n                    seq_len_q=1,\n                )\n            )\n\n        self.forward_metadata = metadata\n\n    def init_forward_metadata_replay_cuda_graph_from_precomputed(\n        self,\n        bs: int,\n        precomputed: PrecomputedMetadata,\n        forward_mode: ForwardMode,\n    ):\n        \"\"\"Fast path: copy precomputed metadata to this backend's metadata.\n\n        This function only performs copy operations, no computation.\n\n        Args:\n            bs: Batch size\n            precomputed: Precomputed metadata to copy from\n            forward_mode: Forward mode\n        \"\"\"\n        self.set_nsa_prefill_impl(forward_batch=None)\n\n        metadata = self.decode_cuda_graph_metadata[bs]\n\n        # Track whether fused kernel succeeded\n        fused_kernel_succeeded = False\n\n        # Use fused CUDA kernel for all copy operations\n        if _USE_FUSED_METADATA_COPY:\n            try:\n                from sglang.jit_kernel.fused_metadata_copy import (\n                    fused_metadata_copy_cuda,\n                )\n\n                # Map forward_mode to integer enum\n                if forward_mode.is_decode_or_idle():\n                    mode_int = 0  # DECODE\n                elif forward_mode.is_target_verify():\n                    mode_int = 1  # TARGET_VERIFY\n                elif forward_mode.is_draft_extend():\n                    mode_int = 2  # DRAFT_EXTEND\n                else:\n                    raise ValueError(f\"Unsupported forward_mode: {forward_mode}\")\n\n                # Prepare FlashMLA tensors if needed\n                flashmla_num_splits_src = None\n                flashmla_num_splits_dst = None\n                flashmla_metadata_src = 
None\n                flashmla_metadata_dst = None\n                if precomputed.flashmla_metadata is not None:\n                    flashmla_num_splits_src = precomputed.flashmla_metadata.num_splits\n                    flashmla_num_splits_dst = metadata.flashmla_metadata.num_splits\n                    flashmla_metadata_src = (\n                        precomputed.flashmla_metadata.flashmla_metadata\n                    )\n                    flashmla_metadata_dst = metadata.flashmla_metadata.flashmla_metadata\n\n                # Call fused kernel\n                fused_metadata_copy_cuda(\n                    # Source tensors\n                    precomputed.cache_seqlens,\n                    precomputed.cu_seqlens_k,\n                    precomputed.page_indices,\n                    precomputed.nsa_cache_seqlens,\n                    precomputed.seqlens_expanded,\n                    precomputed.nsa_cu_seqlens_k,\n                    precomputed.real_page_table,\n                    flashmla_num_splits_src,\n                    flashmla_metadata_src,\n                    # Destination tensors\n                    metadata.cache_seqlens_int32,\n                    metadata.cu_seqlens_k,\n                    metadata.page_table_1,\n                    metadata.nsa_cache_seqlens_int32,\n                    metadata.nsa_seqlens_expanded,\n                    metadata.nsa_cu_seqlens_k,\n                    (\n                        metadata.real_page_table\n                        if precomputed.real_page_table is not None\n                        else None\n                    ),\n                    flashmla_num_splits_dst,\n                    flashmla_metadata_dst,\n                    # Parameters\n                    mode_int,\n                    bs,\n                    precomputed.max_len,\n                    precomputed.max_seqlen_k,\n                    precomputed.seqlens_expanded_size,\n                )\n\n                # Successfully used 
fused kernel\n                fused_kernel_succeeded = True\n\n            except ImportError:\n                print(\n                    \"Warning: Fused metadata copy kernel not available, falling back to individual copies.\"\n                )\n            except Exception as e:\n                print(\n                    f\"Warning: Fused metadata copy kernel failed with error: {e}, falling back to individual copies.\"\n                )\n\n        # Fallback to individual copy operations if fused kernel disabled or failed\n        if not fused_kernel_succeeded:\n            # Copy basic seqlens\n            metadata.cache_seqlens_int32.copy_(precomputed.cache_seqlens)\n            metadata.cu_seqlens_k[1:].copy_(precomputed.cu_seqlens_k[1:])\n\n            # Mode-specific copy logic\n            if forward_mode.is_decode_or_idle():\n                # Decode mode\n                metadata.page_table_1[:, : precomputed.max_len].copy_(\n                    precomputed.page_indices\n                )\n                metadata.nsa_cache_seqlens_int32.copy_(precomputed.nsa_cache_seqlens)\n                # seqlens_expanded is same as cache_seqlens (already copied)\n\n            elif forward_mode.is_target_verify():\n                # Target verify mode\n                metadata.page_table_1[:, : precomputed.max_seqlen_k].copy_(\n                    precomputed.page_indices\n                )\n                metadata.nsa_seqlens_expanded.copy_(precomputed.seqlens_expanded)\n                metadata.nsa_cache_seqlens_int32.copy_(precomputed.nsa_cache_seqlens)\n\n            elif forward_mode.is_draft_extend():\n                # Draft extend mode\n                rows = precomputed.page_indices.shape[0]\n                cols = precomputed.max_seqlen_k\n                metadata.page_table_1[:rows, :cols].copy_(precomputed.page_indices)\n\n                size = precomputed.seqlens_expanded_size\n                
metadata.nsa_seqlens_expanded[:size].copy_(precomputed.seqlens_expanded)\n                metadata.nsa_cache_seqlens_int32[:size].copy_(\n                    precomputed.nsa_cache_seqlens\n                )\n\n            # Copy NSA cu_seqlens\n            size = precomputed.seqlens_expanded_size\n            metadata.nsa_cu_seqlens_k[1 : 1 + size].copy_(\n                precomputed.nsa_cu_seqlens_k[1 : 1 + size]\n            )\n\n            # Copy real page table\n            if precomputed.real_page_table is not None:\n                rows, cols = precomputed.real_page_table.shape\n                metadata.real_page_table[:rows, :cols].copy_(\n                    precomputed.real_page_table\n                )\n\n            # Copy FlashMLA metadata in fallback path\n            if precomputed.flashmla_metadata is not None:\n                size = precomputed.seqlens_expanded_size\n                flashmla_metadata = metadata.flashmla_metadata.slice(slice(0, size + 1))\n                flashmla_metadata.copy_(precomputed.flashmla_metadata)\n\n        self.forward_metadata = metadata\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n        # For multi-head latent attention\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        topk_indices: Optional[torch.Tensor] = None,\n        cos_sin_cache: Optional[torch.Tensor] = None,\n        is_neox: Optional[bool] = False,\n        llama_4_scaling: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n\n        causal = not layer.is_cross_attention\n        metadata = self.forward_metadata\n        assert causal, \"NSA is causal only\"\n\n        nsa_impl = (\n            self.nsa_decode_impl\n            if (\n                forward_batch.forward_mode.is_target_verify()\n                or 
forward_batch.forward_mode.is_draft_extend(include_v2=True)\n            )\n            else self.nsa_prefill_impl\n        )\n\n        if nsa_impl == \"trtllm\" and not self.use_mha:\n            return self._forward_trtllm(\n                q,\n                k,\n                v,\n                layer,\n                forward_batch,\n                metadata.nsa_cache_seqlens_int32,\n                save_kv_cache,\n                q_rope,\n                k_rope,\n                topk_indices,\n                cos_sin_cache,\n                is_neox,\n                llama_4_scaling,\n            )\n\n        if k is not None:\n            assert v is not None\n            if save_kv_cache:\n                cache_loc = (\n                    forward_batch.out_cache_loc\n                    if not layer.is_cross_attention\n                    else forward_batch.encoder_out_cache_loc\n                )\n                forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore\n                    layer,\n                    cache_loc,\n                    k,\n                    k_rope,\n                )\n\n        # Use MHA kernel if in MHA_ONE_SHOT mode\n        if self.use_mha:\n            assert k is not None and v is not None\n            assert q_rope is None, \"MHA_ONE_SHOT path should not pass q_rope\"\n            assert (\n                layer.tp_k_head_num == layer.tp_q_head_num > 1\n            ), \"MHA_ONE_SHOT requires dense multi-head config\"\n            return self._forward_standard_mha(\n                q=q,\n                k=k,\n                v=v,\n                layer=layer,\n                forward_batch=forward_batch,\n                metadata=metadata,\n            )\n\n        # Do absorbed multi-latent attention (MLA path)\n        assert q_rope is not None\n        kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n\n        if q_rope is not None:\n            q_nope = q.view(-1, 
layer.tp_q_head_num, layer.v_head_dim)\n            q_rope = q_rope.view(\n                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim\n            )\n        else:\n            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)\n            q_nope = q_all[:, :, : layer.v_head_dim]\n            q_rope = q_all[:, :, layer.v_head_dim :]\n\n        # Align topk_indices with q dimensions\n        # This handles cases where q is padded (TP + partial DP attention)\n        if topk_indices is not None:\n            topk_indices = self._pad_topk_indices(topk_indices, q_nope.shape[0])\n\n        # NOTE(dark): here, we use page size = 1\n        topk_transform_method = self.get_topk_transform_method()\n        if envs.SGLANG_NSA_FUSE_TOPK.get():\n            page_table_1 = topk_indices\n        else:\n            if topk_transform_method == TopkTransformMethod.RAGGED:\n                topk_indices_offset = metadata.topk_indices_offset\n                assert topk_indices_offset is not None\n                mask = topk_indices != -1\n                topk_indices_offset = (\n                    topk_indices_offset.unsqueeze(1)\n                    if topk_indices_offset.ndim == 1\n                    else topk_indices_offset\n                )\n                topk_indices = torch.where(\n                    mask, topk_indices + topk_indices_offset, topk_indices\n                )\n            elif topk_transform_method == TopkTransformMethod.PAGED:\n                assert metadata.nsa_extend_seq_lens_list is not None\n                page_table_1 = transform_index_page_table_prefill(\n                    page_table=metadata.page_table_1,\n                    topk_indices=topk_indices,\n                    extend_lens_cpu=metadata.nsa_extend_seq_lens_list,\n                    page_size=1,\n                )\n\n        if nsa_impl == \"tilelang\":\n            if q_rope is not None:\n                q_all = concat_mla_absorb_q_general(q_nope, 
q_rope)\n            return self._forward_tilelang(\n                q_all=q_all,\n                kv_cache=kv_cache,\n                page_table_1=page_table_1,\n                sm_scale=layer.scaling,\n                v_head_dim=layer.v_head_dim,\n            )\n        elif nsa_impl == \"flashmla_sparse\":\n            if q_rope is not None:\n                q_all = concat_mla_absorb_q_general(q_nope, q_rope)\n\n            if topk_transform_method == TopkTransformMethod.RAGGED:\n                if any(forward_batch.extend_prefix_lens_cpu):\n                    page_table_1_flattened = (\n                        self.forward_metadata.page_table_1_flattened\n                    )\n                    assert page_table_1_flattened is not None\n                    kv_cache = dequantize_k_cache_paged(\n                        kv_cache, page_table_1_flattened\n                    )\n                else:\n                    kv_cache = _cat([k, k_rope], dim=-1)\n                page_table_1 = topk_indices\n\n            return self._forward_flashmla_sparse(\n                q_all=q_all,\n                kv_cache=kv_cache,\n                page_table_1=page_table_1,\n                sm_scale=layer.scaling,\n                v_head_dim=layer.v_head_dim,\n            )\n        elif nsa_impl == \"flashmla_kv\":\n            if q_rope is not None:\n                q_all = concat_mla_absorb_q_general(q_nope, q_rope)\n            return self._forward_flashmla_kv(\n                q_all=q_all,\n                kv_cache=kv_cache,\n                sm_scale=layer.scaling,\n                v_head_dim=layer.v_head_dim,\n                # TODO optimize args\n                layer=layer,\n                metadata=metadata,\n                page_table_1=page_table_1,\n            )\n        elif nsa_impl == \"fa3\":\n            return self._forward_fa3(\n                q_rope=q_rope,\n                kv_cache=kv_cache,\n                v_head_dim=layer.v_head_dim,\n                
q_nope=q_nope,\n                page_table=page_table_1,\n                cache_seqlens=metadata.nsa_cache_seqlens_int32,\n                cu_seqlens_q=metadata.nsa_cu_seqlens_q,\n                cu_seqlens_k=metadata.nsa_cu_seqlens_k,\n                max_seqlen_q=metadata.nsa_max_seqlen_q,\n                sm_scale=layer.scaling,\n                logit_cap=layer.logit_cap,\n                page_size=1,\n            )\n        elif nsa_impl == \"aiter\":\n            if q_rope is not None:\n                q_all = torch.cat([q_nope, q_rope], dim=-1)\n            return self._forward_aiter_extend(\n                q_all=q_all,\n                kv_cache=kv_cache,\n                page_table_1=page_table_1,\n                layer=layer,\n            )\n        else:\n            raise ValueError(\n                f\"Unsupported {nsa_impl = } for forward_extend. Consider using another attention backend.\"\n            )\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n        # For multi-head latent attention\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        topk_indices: Optional[torch.Tensor] = None,\n        cos_sin_cache: Optional[torch.Tensor] = None,\n        is_neox: Optional[bool] = False,\n        llama_4_scaling: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n\n        causal = not layer.is_cross_attention\n        metadata = self.forward_metadata\n        assert causal, \"NSA is causal only\"\n\n        if self.nsa_decode_impl == \"trtllm\":\n            return self._forward_trtllm(\n                q,\n                k,\n                v,\n                layer,\n                forward_batch,\n                metadata.cache_seqlens_int32,\n                save_kv_cache,\n                q_rope,\n                k_rope,\n  
              topk_indices,\n                cos_sin_cache,\n                is_neox,\n                llama_4_scaling,\n            )\n\n        if k is not None:\n            assert v is not None\n            if save_kv_cache:\n                cache_loc = (\n                    forward_batch.out_cache_loc\n                    if not layer.is_cross_attention\n                    else forward_batch.encoder_out_cache_loc\n                )\n                forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore\n                    layer,\n                    cache_loc,\n                    k,\n                    k_rope,\n                )\n\n        # Do absorbed multi-latent attention\n        kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n        if q_rope is not None:\n            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n            q_rope = q_rope.view(\n                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim\n            )\n        else:\n            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)\n            q_nope = q_all[:, :, : layer.v_head_dim]\n            q_rope = q_all[:, :, layer.v_head_dim :]\n\n        # Align topk_indices with q dimensions\n        if topk_indices is not None:\n            topk_indices = self._pad_topk_indices(topk_indices, q_nope.shape[0])\n\n        if envs.SGLANG_NSA_FUSE_TOPK.get():\n            page_table_1 = topk_indices\n        else:\n            page_table_1 = transform_index_page_table_decode(\n                page_table=metadata.page_table_1,\n                topk_indices=topk_indices,\n                page_size=1,\n            )\n\n        if self.nsa_decode_impl == \"flashmla_sparse\":\n            if q_rope is not None:\n                q_all = concat_mla_absorb_q_general(q_nope, q_rope)\n            return self._forward_flashmla_sparse(\n                q_all=q_all,\n                kv_cache=kv_cache,\n                
page_table_1=page_table_1,\n                sm_scale=layer.scaling,\n                v_head_dim=layer.v_head_dim,\n            )\n        elif self.nsa_decode_impl == \"flashmla_kv\":\n            if q_rope is not None:\n                q_all = concat_mla_absorb_q_general(q_nope, q_rope)\n            return self._forward_flashmla_kv(\n                q_all=q_all,\n                kv_cache=kv_cache,\n                sm_scale=layer.scaling,\n                v_head_dim=layer.v_head_dim,\n                # TODO optimize args\n                layer=layer,\n                metadata=metadata,\n                page_table_1=page_table_1,\n            )\n        elif self.nsa_decode_impl == \"tilelang\":\n            if q_rope is not None:\n                q_all = concat_mla_absorb_q_general(q_nope, q_rope)\n            return self._forward_tilelang(\n                q_all=q_all,\n                kv_cache=kv_cache,\n                page_table_1=page_table_1,\n                sm_scale=layer.scaling,\n                v_head_dim=layer.v_head_dim,\n            )\n        elif self.nsa_decode_impl == \"fa3\":\n            return self._forward_fa3(\n                q_rope=q_rope,\n                kv_cache=kv_cache,\n                v_head_dim=layer.v_head_dim,\n                q_nope=q_nope,\n                page_table=page_table_1,\n                cache_seqlens=metadata.nsa_cache_seqlens_int32,\n                cu_seqlens_q=metadata.nsa_cu_seqlens_q,\n                cu_seqlens_k=metadata.nsa_cu_seqlens_k,\n                max_seqlen_q=metadata.nsa_max_seqlen_q,\n                sm_scale=layer.scaling,\n                logit_cap=layer.logit_cap,\n                page_size=1,\n            )\n        elif self.nsa_decode_impl == \"aiter\":\n            if q_rope is not None:\n                q_all = torch.cat([q_nope, q_rope], dim=-1)\n            return self._forward_aiter(\n                q_all=q_all,\n                kv_cache=kv_cache,\n                
page_table_1=page_table_1,\n                layer=layer,\n                metadata=metadata,\n                bs=forward_batch.batch_size,\n            )\n\n        else:\n            assert False, f\"Unsupported {self.nsa_decode_impl = }\"\n\n    def _forward_fa3(\n        self,\n        q_rope: torch.Tensor,\n        kv_cache: torch.Tensor,\n        v_head_dim: int,\n        q_nope: torch.Tensor,\n        page_table: torch.Tensor,\n        cache_seqlens: torch.Tensor,\n        cu_seqlens_q: torch.Tensor,\n        cu_seqlens_k: torch.Tensor,\n        max_seqlen_q: int,\n        sm_scale: float,\n        logit_cap: float,\n        page_size: int,\n    ) -> torch.Tensor:\n        k_rope_cache = kv_cache[:, :, v_head_dim:]\n        c_kv_cache = kv_cache[:, :, :v_head_dim]\n        qk_rope_dim = k_rope_cache.shape[-1]\n        k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim)\n        c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim)\n        o = flash_attn_with_kvcache(\n            q=q_rope,\n            k_cache=k_rope_cache,\n            v_cache=c_kv_cache,\n            qv=q_nope,\n            page_table=page_table,\n            cache_seqlens=cache_seqlens,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k_new=cu_seqlens_k,\n            max_seqlen_q=max_seqlen_q,\n            softmax_scale=sm_scale,\n            causal=True,\n            softcap=logit_cap,\n            return_softmax_lse=False,\n            num_splits=self.num_splits,\n        )\n        return o  # type: ignore\n\n    def _forward_flashmla_sparse(\n        self,\n        q_all: torch.Tensor,\n        kv_cache: torch.Tensor,\n        v_head_dim: int,\n        page_table_1: torch.Tensor,\n        sm_scale: float,\n    ) -> torch.Tensor:\n        from sgl_kernel.flash_mla import flash_mla_sparse_fwd\n\n        # FlashMLA sparse kernel requires num_heads to be a multiple of 64 (Hopper) or 128 (Blackwell)\n        # When using TP, num_heads might be smaller (e.g., 
256//8=32)\n        num_tokens, num_heads, head_dim = q_all.shape\n\n        # Determine required padding based on GPU architecture (use cached value)\n        required_padding = 128 if self.device_sm_major >= 10 else 64\n\n        need_padding = num_heads % required_padding != 0\n\n        if need_padding:\n            assert required_padding % num_heads == 0, (\n                f\"num_heads {num_heads} cannot be padded to {required_padding}. \"\n                f\"TP size may be too large for this model.\"\n            )\n\n            # Pad q to required size\n            q_padded = q_all.new_zeros((num_tokens, required_padding, head_dim))\n            q_padded[:, :num_heads, :] = q_all\n            q_input = q_padded\n        else:\n            q_input = q_all\n\n        # indices shape must be (s_q, h_kv=1, topk), keep h_kv=1 unchanged\n        indices_input = page_table_1.unsqueeze(1)\n\n        o, _, _ = flash_mla_sparse_fwd(\n            q=q_input,\n            kv=kv_cache,\n            indices=indices_input,\n            sm_scale=sm_scale,\n            d_v=v_head_dim,\n        )\n\n        # Trim output back to original num_heads if we padded\n        if need_padding:\n            o = o[:, :num_heads, :]\n\n        return o\n\n    def _forward_flashmla_kv(\n        self,\n        q_all: torch.Tensor,\n        kv_cache: torch.Tensor,\n        v_head_dim: int,\n        sm_scale: float,\n        layer,\n        metadata: NSAMetadata,\n        page_table_1,\n    ) -> torch.Tensor:\n        from sgl_kernel.flash_mla import flash_mla_with_kvcache\n\n        cache_seqlens = metadata.nsa_cache_seqlens_int32\n\n        # TODO the 2nd dim is seq_len_q, need to be >1 when MTP\n        q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)\n        kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)\n        assert self.real_page_size == 64, \"only page size 64 is supported\"\n\n        if not self.nsa_kv_cache_store_fp8:\n            # 
inefficiently quantize the whole cache\n            kv_cache = quantize_k_cache(kv_cache)\n\n        indices = page_table_1.unsqueeze(1)\n        assert (\n            indices.shape[-1] == self.nsa_index_topk\n        )  # requirement of FlashMLA decode kernel\n\n        o, _ = flash_mla_with_kvcache(\n            q=q_all,\n            k_cache=kv_cache,\n            cache_seqlens=cache_seqlens,\n            head_dim_v=v_head_dim,\n            tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata,\n            num_splits=metadata.flashmla_metadata.num_splits,\n            softmax_scale=sm_scale,\n            indices=indices,\n            # doc says it is not used, but if pass in None then error\n            block_table=torch.empty(\n                (q_all.shape[0], 0), dtype=torch.int32, device=q_all.device\n            ),\n            is_fp8_kvcache=True,\n        )\n        return o\n\n    def _forward_standard_mha(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        metadata: NSAMetadata,\n    ) -> torch.Tensor:\n        \"\"\"Standard MHA using FlashAttention varlen for MHA_ONE_SHOT mode.\"\"\"\n        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)\n        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)\n        v = v.view(-1, layer.tp_v_head_num, layer.v_head_dim)\n\n        # MHA_ONE_SHOT: k/v include all tokens (prefix + current)\n        cu_seqlens_q = metadata.cu_seqlens_q\n        cu_seqlens_k = metadata.cu_seqlens_k\n        max_seqlen_k = metadata.max_seq_len_k\n        causal = True\n\n        # Verify batch sizes match (length of cu_seqlens should be batch_size + 1)\n        assert len(cu_seqlens_q) == len(cu_seqlens_k), (\n            f\"batch_size mismatch: cu_seqlens_q has {len(cu_seqlens_q)-1} requests, \"\n            f\"cu_seqlens_k has {len(cu_seqlens_k)-1} requests\"\n        )\n\n        # Use TRTLLm 
ragged attention for SM100 (Blackwell/B200) to avoid FA4 accuracy issues\n        if self.device_sm_major >= 10:\n            import flashinfer\n\n            seq_lens = metadata.cache_seqlens_int32\n            return flashinfer.prefill.trtllm_ragged_attention_deepseek(\n                query=q,\n                key=k,\n                value=v,\n                workspace_buffer=self.workspace_buffer,\n                seq_lens=seq_lens,\n                max_q_len=metadata.max_seq_len_q,\n                max_kv_len=max_seqlen_k,\n                bmm1_scale=layer.scaling,\n                bmm2_scale=1.0,\n                o_sf_scale=1.0,\n                batch_size=forward_batch.batch_size,\n                window_left=-1,\n                cum_seq_lens_q=cu_seqlens_q,\n                cum_seq_lens_kv=cu_seqlens_k,\n                enable_pdl=False,\n                is_causal=causal,\n                return_lse=False,\n            )\n\n        # Use FA3 for SM90 (Hopper/H200)\n        return flash_attn_varlen_func(\n            q=q,\n            k=k,\n            v=v,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n            max_seqlen_q=metadata.max_seq_len_q,\n            max_seqlen_k=max_seqlen_k,\n            softmax_scale=layer.scaling,\n            causal=causal,\n        )\n\n    def _forward_tilelang(\n        self,\n        q_all: torch.Tensor,\n        kv_cache: torch.Tensor,\n        v_head_dim: int,\n        page_table_1: torch.Tensor,\n        sm_scale: float,\n    ) -> torch.Tensor:\n        from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd\n\n        return tilelang_sparse_fwd(\n            q=q_all,\n            kv=kv_cache,\n            indices=page_table_1.unsqueeze(1),\n            sm_scale=sm_scale,\n            d_v=v_head_dim,\n        )\n\n    def _forward_aiter(\n        self,\n        q_all: torch.Tensor,\n        kv_cache: torch.Tensor,\n        page_table_1: torch.Tensor,\n        
layer: RadixAttention,\n        metadata: NSAMetadata,\n        bs: int,\n    ) -> torch.Tensor:\n        q = q_all.reshape(-1, layer.tp_q_head_num * layer.head_dim)\n\n        if layer.head_dim != layer.v_head_dim:\n            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        kv_indptr = self.kv_indptr\n\n        non_minus1_mask = page_table_1 != -1\n        non_minus1_counts = non_minus1_mask.sum(dim=1)\n        kv_indptr[1 : bs + 1] = torch.cumsum(non_minus1_counts, dim=0)\n\n        kv_indices = self.kv_indices\n        get_valid_kv_indices(page_table_1, kv_indptr, kv_indices, bs)\n\n        mla_decode_fwd(\n            q.view(-1, layer.tp_q_head_num, layer.head_dim),\n            kv_cache.view(-1, 1, 1, layer.head_dim),\n            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n            metadata.cu_seqlens_q,\n            kv_indptr,\n            kv_indices,\n            metadata.cu_seqlens_q,\n            metadata.max_seq_len_q,\n            sm_scale=layer.scaling,\n            logit_cap=layer.logit_cap,\n        )\n        # kv_cache = kv_cache.view(-1, 1, layer.head_dim)\n        return o\n\n    def _forward_aiter_extend(\n        self,\n        q_all: torch.Tensor,\n        kv_cache: torch.Tensor,\n        page_table_1: torch.Tensor,\n        layer: RadixAttention,\n    ) -> torch.Tensor:\n        num_tokens = q_all.shape[0]\n        q = q_all.reshape(-1, layer.tp_q_head_num * layer.head_dim)\n\n        if layer.head_dim != layer.v_head_dim:\n            o = q.new_empty((num_tokens, layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        non_minus1_mask = page_table_1 != -1\n        non_minus1_counts = non_minus1_mask.sum(dim=1)\n\n        kv_indptr = torch.zeros(num_tokens + 1, dtype=torch.int32, device=self.device)\n        kv_indptr[1:] = torch.cumsum(non_minus1_counts, dim=0)\n\n        # Allocate kv_indices with 
upper-bound size (num_tokens * topk)\n        topk = page_table_1.shape[1]\n        kv_indices = torch.zeros(\n            num_tokens * topk, dtype=torch.int32, device=self.device\n        )\n\n        # Use get_valid_kv_indices kernel to extract valid indices\n        get_valid_kv_indices(page_table_1, kv_indptr, kv_indices, num_tokens)\n\n        # Build cu_seqlens_q for extend: each token is treated as seq_len_q=1\n        cu_seqlens_q = torch.arange(\n            0, num_tokens + 1, dtype=torch.int32, device=self.device\n        )\n        # TODO support more forward_mode\n        mla_decode_fwd(\n            q.view(-1, layer.tp_q_head_num, layer.head_dim),\n            kv_cache.view(-1, 1, 1, layer.head_dim),\n            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n            cu_seqlens_q,\n            kv_indptr,\n            kv_indices,\n            cu_seqlens_q,\n            1,  # max_seq_len_q = 1 for per-token attention\n            sm_scale=layer.scaling,\n            logit_cap=layer.logit_cap,\n        )\n        return o\n\n    def _forward_trtllm(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        seq_lens: torch.Tensor,\n        save_kv_cache=True,\n        # For multi-head latent attention\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        topk_indices: Optional[torch.Tensor] = None,\n        cos_sin_cache: Optional[torch.Tensor] = None,\n        is_neox: Optional[bool] = False,\n        llama_4_scaling: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward using TRT-LLM sparse MLA kernel.\"\"\"\n        import flashinfer.decode\n\n        metadata = self.forward_metadata\n\n        merge_query = q_rope is not None\n        if self.kv_cache_dtype == torch.float8_e4m3fn:\n            # For FP8 path, we quantize the query and rope parts and merge 
them into a single tensor\n            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend\n            assert q_rope is not None, \"For FP8 path q_rope should not be None.\"\n            assert k_rope is not None, \"For FP8 path k_rope should not be None.\"\n            assert (\n                cos_sin_cache is not None\n            ), \"For FP8 path cos_sin_cache should not be None.\"\n\n            q, k, k_rope = mla_quantize_and_rope_for_fp8(\n                q,\n                q_rope,\n                k.squeeze(1),\n                k_rope.squeeze(1),\n                forward_batch.positions,\n                cos_sin_cache,\n                is_neox,\n                self.kv_lora_rank,\n                self.qk_rope_head_dim,\n            )\n            merge_query = False\n\n        # Save KV cache if requested\n        if save_kv_cache:\n            assert (\n                k is not None and k_rope is not None\n            ), \"For populating trtllm_mla kv cache, both k_nope and k_rope should not be None.\"\n            cache_loc = (\n                forward_batch.out_cache_loc\n                if not layer.is_cross_attention\n                else forward_batch.encoder_out_cache_loc\n            )\n            forward_batch.token_to_kv_pool.set_mla_kv_buffer(\n                layer, cache_loc, k, k_rope\n            )\n\n        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n        kv_cache = k_cache.view(-1, self.real_page_size, self.kv_cache_dim).unsqueeze(1)\n\n        if merge_query:\n            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n            q_rope_reshaped = q_rope.view(\n                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim\n            )\n            q_all = concat_mla_absorb_q_general(q_nope, q_rope_reshaped)\n        else:\n            q_all = q.view(-1, layer.tp_q_head_num, layer.head_dim)\n\n        
# Align topk_indices with q dimensions\n        if topk_indices is not None:\n            topk_indices = self._pad_topk_indices(topk_indices, q.shape[0])\n\n        if envs.SGLANG_NSA_FUSE_TOPK.get():\n            page_table_1 = topk_indices\n        else:\n            page_table_1 = transform_index_page_table_decode(\n                page_table=metadata.page_table_1,\n                topk_indices=topk_indices,\n                page_size=1,\n            )\n\n        q_scale = 1.0\n        k_scale = (\n            layer.k_scale_float\n            if getattr(layer, \"k_scale_float\", None) is not None\n            else 1.0\n        )\n        bmm1_scale = q_scale * k_scale * layer.scaling\n\n        batch_size = page_table_1.shape[0]\n        _, num_heads, head_dim = q_all.shape\n\n        q = q_all.view(batch_size, 1, num_heads, head_dim)\n        kv = kv_cache.view(-1, 1, self.real_page_size, self.kv_cache_dim)\n        block_tables = page_table_1.unsqueeze(1)\n        seq_lens = metadata.cache_seqlens_int32 if seq_lens is None else seq_lens\n\n        out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(\n            query=q,\n            kv_cache=kv,\n            workspace_buffer=self.workspace_buffer,\n            qk_nope_head_dim=self.qk_nope_head_dim,\n            kv_lora_rank=self.kv_lora_rank,\n            qk_rope_head_dim=self.qk_rope_head_dim,\n            block_tables=block_tables,\n            seq_lens=seq_lens,\n            max_seq_len=metadata.max_seq_len_k,\n            sparse_mla_top_k=self.nsa_index_topk,\n            bmm1_scale=bmm1_scale,\n            backend=\"trtllm-gen\",\n        )\n        # Output: [batch, q_len=1, heads, v_dim] -> [batch, heads, v_dim]\n        return out.squeeze(1)\n\n    def _pad_topk_indices(\n        self, topk_indices: torch.Tensor, num_tokens: int\n    ) -> torch.Tensor:\n        current_tokens = topk_indices.shape[0]\n        if current_tokens == num_tokens:\n            return topk_indices\n\n        assert 
current_tokens <= num_tokens, (\n            f\"topk_indices rows ({current_tokens}) > num_tokens ({num_tokens}); \"\n            \"this indicates a mismatch between indexer output and q layout.\"\n        )\n\n        pad_size = num_tokens - current_tokens\n        padding = torch.full(\n            (pad_size, topk_indices.shape[1]),\n            -1,\n            dtype=topk_indices.dtype,\n            device=topk_indices.device,\n        )\n        return torch.cat([topk_indices, padding], dim=0)\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        \"\"\"Get the fill value for sequence length in CUDA graph.\"\"\"\n        return 1\n\n    def set_nsa_prefill_impl(self, forward_batch: Optional[ForwardBatch] = None):\n        \"\"\"\n        Decide all attention prefill dispatch strategies for this batch.\n        \"\"\"\n        from sglang.srt.utils import get_device_sm, is_blackwell\n\n        # Decide MHA vs MLA\n        if forward_batch and forward_batch.forward_mode.is_extend_without_speculative():\n            # Check if sequence meets criteria for MHA_ONE_SHOT\n            assert forward_batch.seq_lens_cpu is not None\n            max_kv_len = forward_batch.seq_lens_cpu.max().item()\n            sum_seq_lens = sum(forward_batch.seq_lens_cpu)\n            device_sm = get_device_sm()\n\n            # Requirements: H200/B200, short sequences, supported dtype, fits in chunk\n            self.use_mha = (\n                (\n                    device_sm == 90 or (device_sm >= 100 and device_sm < 110)\n                )  # SM90/SM100 only\n                and max_kv_len\n                <= envs.SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD.get()  # Short enough for MHA\n                and forward_batch.token_to_kv_pool.dtype\n                in [torch.bfloat16, torch.float8_e4m3fn]\n                and sum_seq_lens\n                <= forward_batch.get_max_chunk_capacity()  # Fits in chunk\n                and (not is_nsa_enable_prefill_cp())  # CP not 
enabled\n            )\n        else:\n            self.use_mha = False  # Decode/verify always use MLA\n\n        # Set MLA implementation only if not using MHA\n        if not self.use_mha and self.enable_auto_select_prefill_impl:\n            if self.nsa_kv_cache_store_fp8:\n                if (\n                    is_blackwell()\n                    and forward_batch is not None\n                    and forward_batch.forward_mode == ForwardMode.EXTEND\n                ):\n                    total_kv_tokens = forward_batch.seq_lens_sum\n                    total_q_tokens = forward_batch.extend_num_tokens\n                    # Heuristic based on benchmarking flashmla_kv vs flashmla_sparse + dequantize_k_cache_paged\n                    if total_kv_tokens < total_q_tokens * 512:\n                        self.nsa_prefill_impl = \"flashmla_sparse\"\n                        return\n                self.nsa_prefill_impl = \"flashmla_kv\"\n            else:\n                # bf16 kv cache\n                self.nsa_prefill_impl = \"flashmla_sparse\"\n\n    def get_topk_transform_method(self) -> TopkTransformMethod:\n        \"\"\"\n        SGLANG_NSA_FUSE_TOPK controls whether to fuse the topk transform into the topk kernel.\n        This method is used to select the topk transform method which can be fused or unfused.\n        \"\"\"\n        if (\n            # disable for MTP\n            self.nsa_kv_cache_store_fp8\n            and self.nsa_prefill_impl == \"flashmla_sparse\"\n        ):\n            topk_transform_method = TopkTransformMethod.RAGGED\n        else:\n            topk_transform_method = TopkTransformMethod.PAGED\n        return topk_transform_method\n\n    def get_indexer_metadata(\n        self, layer_id: int, forward_batch: ForwardBatch\n    ) -> NSAIndexerMetadata:\n        return NSAIndexerMetadata(\n            attn_metadata=self.forward_metadata,\n            topk_transform_method=self.get_topk_transform_method(),\n            
paged_mqa_schedule_metadata=self.forward_metadata.paged_mqa_schedule_metadata,\n        )\n\n    def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):\n        from sgl_kernel.flash_mla import get_mla_metadata\n\n        flashmla_metadata, num_splits = get_mla_metadata(\n            cache_seqlens=cache_seqlens,\n            # TODO doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`\n            #      but the name looks like need seq_len_q?\n            num_q_tokens_per_head_k=seq_len_q * self.num_q_heads // 1,\n            num_heads_k=1,\n            num_heads_q=self.num_q_heads,\n            is_fp8_kvcache=True,\n            topk=self.nsa_index_topk,\n        )\n\n        return NSAFlashMLAMetadata(\n            flashmla_metadata=flashmla_metadata,\n            num_splits=num_splits,\n        )\n\n\nclass NativeSparseAttnMultiStepBackend:\n\n    def __init__(\n        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int\n    ):\n        self.model_runner = model_runner\n        self.topk = topk\n        self.speculative_num_steps = speculative_num_steps\n        self.attn_backends = []\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends.append(\n                NativeSparseAttnBackend(\n                    model_runner,\n                    speculative_step_id=i,\n                    topk=self.topk,\n                    speculative_num_steps=self.speculative_num_steps,\n                )\n            )\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_forward_metadata(forward_batch)\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)\n\n    def init_forward_metadata_capture_cuda_graph(self, 
forward_batch: ForwardBatch):\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(\n                forward_batch.batch_size,\n                forward_batch.batch_size * self.topk,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                encoder_lens=None,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n            )\n\n    def init_forward_metadata_replay_cuda_graph(\n        self, forward_batch: ForwardBatch, bs: int\n    ):\n        if envs.SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA.get():\n            # Precompute metadata once (shared across all backends)\n            precomputed = self.attn_backends[0]._precompute_replay_metadata(\n                bs=bs,\n                req_pool_indices=forward_batch.req_pool_indices,\n                seq_lens=forward_batch.seq_lens,\n                seq_lens_cpu=forward_batch.seq_lens_cpu,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n            )\n\n            # Use multi-backend fused copy when we have 3 or more backends\n            # This is 3x faster than calling the single-backend copy 3 times\n            if self.speculative_num_steps > 3:\n                try:\n                    from sglang.jit_kernel.fused_metadata_copy import (\n                        fused_metadata_copy_multi_cuda,\n                    )\n\n                    metadata0 = self.attn_backends[0].decode_cuda_graph_metadata[bs]\n                    metadata1 = self.attn_backends[1].decode_cuda_graph_metadata[bs]\n                    metadata2 = self.attn_backends[2].decode_cuda_graph_metadata[bs]\n\n                    # Set nsa_prefill_impl for first 3 backends (required by the method)\n                    for i in range(3):\n                        
self.attn_backends[i].set_nsa_prefill_impl(forward_batch=None)\n\n                    # Prepare FlashMLA tensors if needed\n                    flashmla_num_splits_src = None\n                    flashmla_metadata_src = None\n                    flashmla_num_splits_dst0 = None\n                    flashmla_num_splits_dst1 = None\n                    flashmla_num_splits_dst2 = None\n                    flashmla_metadata_dst0 = None\n                    flashmla_metadata_dst1 = None\n                    flashmla_metadata_dst2 = None\n\n                    if precomputed.flashmla_metadata is not None:\n                        flashmla_num_splits_src = (\n                            precomputed.flashmla_metadata.num_splits\n                        )\n                        flashmla_metadata_src = (\n                            precomputed.flashmla_metadata.flashmla_metadata\n                        )\n                        flashmla_num_splits_dst0 = (\n                            metadata0.flashmla_metadata.num_splits\n                        )\n                        flashmla_num_splits_dst1 = (\n                            metadata1.flashmla_metadata.num_splits\n                        )\n                        flashmla_num_splits_dst2 = (\n                            metadata2.flashmla_metadata.num_splits\n                        )\n                        flashmla_metadata_dst0 = (\n                            metadata0.flashmla_metadata.flashmla_metadata\n                        )\n                        flashmla_metadata_dst1 = (\n                            metadata1.flashmla_metadata.flashmla_metadata\n                        )\n                        flashmla_metadata_dst2 = (\n                            metadata2.flashmla_metadata.flashmla_metadata\n                        )\n\n                    # Call the multi-backend fused kernel for first 3 backends\n                    fused_metadata_copy_multi_cuda(\n                        # Source tensors\n  
                      precomputed.cache_seqlens,\n                        precomputed.cu_seqlens_k,\n                        precomputed.page_indices,\n                        precomputed.nsa_cache_seqlens,\n                        precomputed.nsa_cu_seqlens_k,\n                        precomputed.real_page_table,\n                        flashmla_num_splits_src,\n                        flashmla_metadata_src,\n                        # Destination tensors for backend 0\n                        metadata0.cache_seqlens_int32,\n                        metadata0.cu_seqlens_k,\n                        metadata0.page_table_1,\n                        metadata0.nsa_cache_seqlens_int32,\n                        metadata0.nsa_cu_seqlens_k,\n                        (\n                            metadata0.real_page_table\n                            if precomputed.real_page_table is not None\n                            else None\n                        ),\n                        flashmla_num_splits_dst0,\n                        flashmla_metadata_dst0,\n                        # Destination tensors for backend 1\n                        metadata1.cache_seqlens_int32,\n                        metadata1.cu_seqlens_k,\n                        metadata1.page_table_1,\n                        metadata1.nsa_cache_seqlens_int32,\n                        metadata1.nsa_cu_seqlens_k,\n                        (\n                            metadata1.real_page_table\n                            if precomputed.real_page_table is not None\n                            else None\n                        ),\n                        flashmla_num_splits_dst1,\n                        flashmla_metadata_dst1,\n                        # Destination tensors for backend 2\n                        metadata2.cache_seqlens_int32,\n                        metadata2.cu_seqlens_k,\n                        metadata2.page_table_1,\n                        metadata2.nsa_cache_seqlens_int32,\n            
            metadata2.nsa_cu_seqlens_k,\n                        (\n                            metadata2.real_page_table\n                            if precomputed.real_page_table is not None\n                            else None\n                        ),\n                        flashmla_num_splits_dst2,\n                        flashmla_metadata_dst2,\n                        # Parameters\n                        bs,\n                        precomputed.max_len,\n                        precomputed.seqlens_expanded_size,\n                    )\n\n                    # Copy remaining backends one by one (if > 3 backends)\n                    for i in range(3, self.speculative_num_steps - 1):\n                        self.attn_backends[\n                            i\n                        ].init_forward_metadata_replay_cuda_graph_from_precomputed(\n                            bs=bs,\n                            precomputed=precomputed,\n                            forward_mode=ForwardMode.DECODE,\n                        )\n                except Exception as e:\n                    # Fallback to loop if multi-backend kernel not available or fails\n                    if isinstance(e, ImportError):\n                        print(\n                            \"Warning: Multi-backend fused metadata copy kernel not available, falling back to loop.\"\n                        )\n                    else:\n                        print(\n                            f\"Warning: Multi-backend fused metadata copy kernel failed with error: {e}, falling back to loop.\"\n                        )\n                    for i in range(self.speculative_num_steps - 1):\n                        self.attn_backends[\n                            i\n                        ].init_forward_metadata_replay_cuda_graph_from_precomputed(\n                            bs=bs,\n                            precomputed=precomputed,\n                            
forward_mode=ForwardMode.DECODE,\n                        )\n            else:\n                # Fewer than 3 backends: copy to each backend individually\n                for i in range(self.speculative_num_steps - 1):\n                    self.attn_backends[\n                        i\n                    ].init_forward_metadata_replay_cuda_graph_from_precomputed(\n                        bs=bs,\n                        precomputed=precomputed,\n                        forward_mode=ForwardMode.DECODE,\n                    )\n        else:\n            # Fallback: compute metadata separately for each backend\n            for i in range(self.speculative_num_steps - 1):\n                self.attn_backends[i].init_forward_metadata_replay_cuda_graph(\n                    bs=bs,\n                    req_pool_indices=forward_batch.req_pool_indices,\n                    seq_lens=forward_batch.seq_lens,\n                    seq_lens_sum=forward_batch.seq_lens_sum,\n                    encoder_lens=None,\n                    forward_mode=ForwardMode.DECODE,\n                    spec_info=forward_batch.spec_info,\n                    seq_lens_cpu=forward_batch.seq_lens_cpu,\n                    out_cache_loc=None,\n                )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/tbo_backend.py",
    "content": "from typing import TYPE_CHECKING, Callable, List, Optional\n\nimport torch\n\nfrom sglang.srt.batch_overlap import two_batch_overlap\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.speculative.spec_info import SpecInput\n\nif TYPE_CHECKING:\n    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\n\n\nclass TboAttnBackend(AttentionBackend):\n    def __init__(self, primary: AttentionBackend, children: List[AttentionBackend]):\n        super().__init__()\n        self.primary = primary\n        self.children = children\n\n    @classmethod\n    def init_new(cls, creator: Callable[[], AttentionBackend]):\n        return cls(\n            primary=creator(),\n            children=[creator() for _ in range(2)],\n        )\n\n    def init_forward_metadata(self, forward_batch: \"ForwardBatch\"):\n        self.primary.init_forward_metadata(forward_batch=forward_batch)\n        if forward_batch.tbo_children is not None:\n            for child, forward_batch_child in zip(\n                self.children, forward_batch.tbo_children, strict=True\n            ):\n                if forward_batch_child.batch_size > 0:\n                    child.init_forward_metadata(forward_batch=forward_batch_child)\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)\n        for item in self.children:\n            # TODO for children, maybe can provide *smaller* max_bs to optimize\n            item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: \"ForwardMode\",\n        spec_info: Optional[SpecInput],\n    ):\n        
self.primary.init_forward_metadata_capture_cuda_graph(\n            bs=bs,\n            num_tokens=num_tokens,\n            req_pool_indices=req_pool_indices,\n            seq_lens=seq_lens,\n            encoder_lens=encoder_lens,\n            forward_mode=forward_mode,\n            spec_info=spec_info,\n        )\n\n        self._init_forward_metadata_cuda_graph_children(\n            fn_name=\"init_forward_metadata_capture_cuda_graph\",\n            bs=bs,\n            req_pool_indices=req_pool_indices,\n            seq_lens=seq_lens,\n            encoder_lens=encoder_lens,\n            forward_mode=forward_mode,\n            spec_info=spec_info,\n            capture_num_tokens=num_tokens,\n        )\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: \"ForwardMode\",\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        self.primary.init_forward_metadata_replay_cuda_graph(\n            bs=bs,\n            req_pool_indices=req_pool_indices,\n            seq_lens=seq_lens,\n            seq_lens_sum=seq_lens_sum,\n            encoder_lens=encoder_lens,\n            forward_mode=forward_mode,\n            spec_info=spec_info,\n            seq_lens_cpu=seq_lens_cpu,\n        )\n\n        self._init_forward_metadata_cuda_graph_children(\n            fn_name=\"init_forward_metadata_replay_cuda_graph\",\n            bs=bs,\n            req_pool_indices=req_pool_indices,\n            seq_lens=seq_lens,\n            encoder_lens=encoder_lens,\n            forward_mode=forward_mode,\n            spec_info=spec_info,\n            replay_seq_lens_sum=seq_lens_sum,\n            replay_seq_lens_cpu=seq_lens_cpu,\n        )\n\n    def _init_forward_metadata_cuda_graph_children(\n        self,\n        fn_name: str,\n        # 
common args\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: \"ForwardMode\",\n        spec_info: Optional[SpecInput],\n        # capture args\n        capture_num_tokens: int = None,\n        # replay args\n        replay_seq_lens_sum: int = None,\n        replay_seq_lens_cpu: Optional[torch.Tensor] = None,\n    ):\n        token_num_per_seq = two_batch_overlap.get_token_num_per_seq(\n            forward_mode=forward_mode, spec_info=spec_info\n        )\n        if fn_name == \"init_forward_metadata_capture_cuda_graph\":\n            assert (\n                capture_num_tokens == bs * token_num_per_seq\n            ), \"For target-verify or decode mode, num_tokens should be equal to token_num_per_seq * bs\"\n        num_tokens = bs * token_num_per_seq\n\n        tbo_split_seq_index, tbo_split_token_index = (\n            two_batch_overlap.compute_split_indices_for_cuda_graph_replay(\n                forward_mode=forward_mode,\n                cuda_graph_num_tokens=num_tokens,\n                spec_info=spec_info,\n            )\n        )\n\n        num_tokens_child_left = tbo_split_token_index\n        num_tokens_child_right = num_tokens - tbo_split_token_index\n        bs_child_left = tbo_split_seq_index\n        bs_child_right = bs - bs_child_left\n\n        assert (\n            num_tokens_child_left > 0 and num_tokens_child_right > 0\n        ), f\"{num_tokens_child_left=} {num_tokens_child_right=} {forward_mode=} {num_tokens=}\"\n\n        common_pre_split_args = dict(\n            fn_name=fn_name,\n            bs=bs,\n            req_pool_indices=req_pool_indices,\n            seq_lens=seq_lens,\n            encoder_lens=encoder_lens,\n            forward_mode=forward_mode,\n            spec_info=spec_info,\n            capture_num_tokens=capture_num_tokens,\n            replay_seq_lens_sum=replay_seq_lens_sum,\n            
replay_seq_lens_cpu=replay_seq_lens_cpu,\n        )\n\n        args_left = _init_forward_metadata_cuda_graph_split(\n            output_bs=bs_child_left,\n            seq_slice=slice(None, tbo_split_seq_index),\n            **common_pre_split_args,\n        )\n        args_right = _init_forward_metadata_cuda_graph_split(\n            output_bs=bs_child_right,\n            seq_slice=slice(tbo_split_seq_index, None),\n            **common_pre_split_args,\n        )\n\n        child_left, child_right = self.children\n        getattr(child_left, fn_name)(**args_left)\n        getattr(child_right, fn_name)(**args_right)\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        ans = self.primary.get_cuda_graph_seq_len_fill_value()\n        for child in self.children:\n            assert ans == child.get_cuda_graph_seq_len_fill_value()\n        return ans\n\n    def forward_extend(self, *args, **kwargs):\n        return self.primary.forward_extend(*args, **kwargs)\n\n    def forward_decode(self, *args, **kwargs):\n        return self.primary.forward_decode(*args, **kwargs)\n\n    def get_indexer_metadata(self, layer_id: int, forward_batch: \"ForwardBatch\"):\n        return self.primary.get_indexer_metadata(layer_id, forward_batch)\n\n\ndef _init_forward_metadata_cuda_graph_split(\n    fn_name: str,\n    seq_slice: slice,\n    output_bs: int,\n    # common args\n    bs: int,\n    req_pool_indices: torch.Tensor,\n    seq_lens: torch.Tensor,\n    encoder_lens: Optional[torch.Tensor],\n    forward_mode: \"ForwardMode\",\n    spec_info: Optional[SpecInput],\n    # capture args\n    capture_num_tokens: int = None,\n    # replay args\n    replay_seq_lens_sum: int = None,\n    replay_seq_lens_cpu: Optional[torch.Tensor] = None,\n):\n    token_num_per_seq = two_batch_overlap.get_token_num_per_seq(\n        forward_mode=forward_mode, spec_info=spec_info\n    )\n    assert encoder_lens is None, \"encoder_lens is not supported yet\"\n    if spec_info is not None:\n        
output_spec_info = two_batch_overlap.split_spec_info(\n            spec_info=spec_info,\n            start_seq_index=seq_slice.start if seq_slice.start is not None else 0,\n            end_seq_index=seq_slice.stop if seq_slice.stop is not None else bs,\n            start_token_index=(\n                seq_slice.start * token_num_per_seq\n                if seq_slice.start is not None\n                else 0\n            ),\n            end_token_index=(\n                seq_slice.stop * token_num_per_seq\n                if seq_slice.stop is not None\n                else bs * token_num_per_seq\n            ),\n        )\n\n    else:\n        output_spec_info = None\n    ans = dict(\n        bs=output_bs,\n        req_pool_indices=req_pool_indices[seq_slice],\n        seq_lens=seq_lens[seq_slice],\n        # directly forward\n        forward_mode=forward_mode,\n        # ignore\n        encoder_lens=None,\n        spec_info=output_spec_info,\n    )\n\n    if fn_name == \"init_forward_metadata_capture_cuda_graph\":\n        assert (\n            capture_num_tokens == bs * token_num_per_seq\n        ), \"Only support num_tokens==bs * token_num_per_seq for target-verify or decode mode\"\n        ans.update(\n            dict(\n                num_tokens=output_bs * token_num_per_seq,\n            )\n        )\n    elif fn_name == \"init_forward_metadata_replay_cuda_graph\":\n        output_seq_lens_cpu = replay_seq_lens_cpu[seq_slice]\n        ans.update(\n            dict(\n                seq_lens_sum=output_seq_lens_cpu.sum().item(),\n                seq_lens_cpu=output_seq_lens_cpu,\n            )\n        )\n    else:\n        raise NotImplementedError\n\n    return ans\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/torch_flex_backend.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\nfrom torch.nn.attention.flex_attention import create_block_mask, flex_attention\n\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.radix_attention import AttentionType\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n\n\nclass TorchFlexAttnBackend(AttentionBackend):\n    def __init__(self, model_runner: ModelRunner):\n        super().__init__()\n        self.forward_metadata = None\n        self.device = model_runner.device\n        self.flex_attention = torch.compile(flex_attention, dynamic=True)\n        torch._dynamo.config.cache_size_limit = 1024\n        torch._dynamo.config.accumulated_cache_size_limit = 1024\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Init the metadata for a forward pass.\"\"\"\n        # TODO: find a more elegant way to save memory\n        # Currently keeps memory usage the same as torch_native_backend\n        torch.cuda.empty_cache()\n\n        # Keep two lists of block masks (extend and decode), one mask per seq_idx, for lower latency; per-layer mask generation will be supported later\n        self.extend_block_masks = []\n        self.decode_block_masks = []\n\n        if forward_batch.forward_mode.is_extend():\n            for seq_idx in range(forward_batch.seq_lens.shape[0]):\n                seq_len_kv = forward_batch.seq_lens[seq_idx]\n                seq_len_q = seq_len_kv\n                self.extend_block_masks.append(\n                    create_block_mask(\n                        self._causal_mask,\n                        None,\n                        None,\n                        seq_len_q,\n                        seq_len_kv,\n                        device=self.device,\n                  
      _compile=False,\n                    )\n                )\n\n        elif forward_batch.forward_mode.is_decode():\n            for seq_idx in range(forward_batch.seq_lens.shape[0]):\n                seq_len_q = 1\n                seq_len_kv = forward_batch.seq_lens[seq_idx]\n\n                self.decode_block_masks.append(\n                    create_block_mask(\n                        self._decode_mask,\n                        None,\n                        None,\n                        seq_len_q,\n                        seq_len_kv,\n                        device=self.device,\n                        _compile=False,\n                    )\n                )\n\n    def _causal_mask(self, b, h, q_idx, kv_idx):\n        return q_idx >= kv_idx\n\n    def _decode_mask(self, b, h, q_idx, kv_idx):\n        return q_idx <= kv_idx\n\n    def _run_flex_forward_extend(\n        self,\n        query: torch.Tensor,\n        output: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        req_to_token: torch.Tensor,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        extend_prefix_lens: torch.Tensor,\n        extend_seq_lens: torch.Tensor,\n        scaling=None,\n        enable_gqa=False,\n        causal=False,\n    ):\n        \"\"\"Run the extend forward by using torch flex attention op.\n\n        Args:\n            query: [num_tokens, num_heads, head_size]\n            output: [num_tokens, num_heads, head_size]\n            k_cache: [max_total_num_tokens, num_heads, head_size]\n            v_cache: [max_total_num_tokens, num_heads, head_size]\n            req_to_token: [max_num_reqs, max_context_len]\n            req_pool_indices: [num_seqs]\n            seq_lens: [num_seqs]\n            extend_prefix_lens: [num_seqs]\n            extend_seq_lens: [num_seqs]\n            scaling: float or None\n            enable_gqa: bool\n            causal: bool\n\n        Returns:\n            output: [num_tokens, 
num_heads, head_size]\n        \"\"\"\n\n        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]\n        assert seq_lens.shape[0] == extend_seq_lens.shape[0]\n\n        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]\n        query = query.movedim(0, query.dim() - 2)\n\n        start_q, start_kv = 0, 0\n\n        for seq_idx in range(seq_lens.shape[0]):\n            # TODO: this loop processes one sequence per iteration, which is inefficient.\n            # Optimize the performance later.\n            extend_seq_len_q = extend_seq_lens[seq_idx]\n            prefill_seq_len_q = extend_prefix_lens[seq_idx]\n\n            seq_len_kv = seq_lens[seq_idx]\n            end_q = start_q + extend_seq_len_q\n            end_kv = start_kv + seq_len_kv\n\n            per_req_query = query[:, start_q:end_q, :]\n            per_req_query_redundant = torch.empty(\n                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),\n                dtype=per_req_query.dtype,\n                device=per_req_query.device,\n            )\n\n            per_req_query_redundant[:, prefill_seq_len_q:, :] = per_req_query\n\n            # get key and value from cache. 
per_req_tokens contains the kv cache\n            # index for each token in the sequence.\n            req_pool_idx = req_pool_indices[seq_idx]\n            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]\n            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)\n            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)\n\n            if not causal:\n                raise NotImplementedError(\"Non-causal mode is not yet implemented.\")\n\n            per_req_out_redundant = (\n                self.flex_attention(\n                    per_req_query_redundant.unsqueeze(0),\n                    per_req_key.unsqueeze(0),\n                    per_req_value.unsqueeze(0),\n                    block_mask=self.extend_block_masks[seq_idx],\n                    scale=scaling,\n                    enable_gqa=enable_gqa,\n                )\n                .squeeze(0)\n                .movedim(query.dim() - 2, 0)\n            )\n            output[start_q:end_q, :, :] = per_req_out_redundant[\n                prefill_seq_len_q:, :, :\n            ]\n            start_q, start_kv = end_q, end_kv\n        return output\n\n    def _run_flex_forward_decode(\n        self,\n        query: torch.Tensor,\n        output: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        req_to_token: torch.Tensor,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        scaling=None,\n        enable_gqa=False,\n        causal=False,\n    ):\n        \"\"\"Run the decode forward by using torch flex attention op.\n\n        Args:\n            query: [num_tokens, num_heads, head_size]\n            output: [num_tokens, num_heads, head_size]\n            k_cache: [max_total_num_tokens, num_heads, head_size]\n            v_cache: [max_total_num_tokens, num_heads, head_size]\n            req_to_token: [max_num_reqs, max_context_len]\n            req_pool_indices: [num_seqs]\n            
seq_lens: [num_seqs]\n            scaling: float or None\n            enable_gqa: bool\n            causal: bool\n\n        Returns:\n            output: [num_tokens, num_heads, head_size]\n        \"\"\"\n\n        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]\n        query = query.movedim(0, query.dim() - 2)\n\n        start_q, start_kv = 0, 0\n        for seq_idx in range(seq_lens.shape[0]):\n            # TODO: this loop processes one sequence per iteration, which is inefficient.\n            # Optimize the performance later.\n\n            seq_len_q = 1\n            seq_len_kv = seq_lens[seq_idx]\n            end_q = start_q + seq_len_q\n            end_kv = start_kv + seq_len_kv\n\n            per_req_query = query[:, start_q:end_q, :]\n\n            # get key and value from cache. per_req_tokens contains the kv cache\n            # index for each token in the sequence.\n            req_pool_idx = req_pool_indices[seq_idx]\n            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]\n            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)\n            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)\n\n            per_req_out = (\n                self.flex_attention(\n                    per_req_query.unsqueeze(0),\n                    per_req_key.unsqueeze(0),\n                    per_req_value.unsqueeze(0),\n                    block_mask=self.decode_block_masks[seq_idx],\n                    scale=scaling,\n                    enable_gqa=enable_gqa,\n                )\n                .squeeze(0)\n                .movedim(query.dim() - 2, 0)\n            )\n\n            output[start_q:end_q, :, :] = per_req_out\n            start_q, start_kv = end_q, end_kv\n\n        return output\n\n    def forward_extend(\n        self,\n        q,\n        k,\n        v,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n    ):\n        if 
layer.qk_head_dim != layer.v_head_dim:\n            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        if save_kv_cache:\n            forward_batch.token_to_kv_pool.set_kv_buffer(\n                layer, forward_batch.out_cache_loc, k, v\n            )\n\n        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num\n\n        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)\n        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n\n        causal = True\n        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:\n            raise NotImplementedError(\n                \"TorchFlexAttnBackend does not support non-causal attention for now.\"\n            )\n\n        self._run_flex_forward_extend(\n            q_,\n            o_,\n            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n            forward_batch.req_to_token_pool.req_to_token,\n            forward_batch.req_pool_indices,\n            forward_batch.seq_lens,\n            forward_batch.extend_prefix_lens,\n            forward_batch.extend_seq_lens,\n            scaling=layer.scaling,\n            enable_gqa=use_gqa,\n            causal=causal,\n        )\n        return o\n\n    def forward_decode(\n        self,\n        q,\n        k,\n        v,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n    ):\n        # During torch.compile, there is a bug in rotary_emb that causes the\n        # output value to have a 3D tensor shape. 
This reshapes the output correctly.\n        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)\n\n        if layer.qk_head_dim != layer.v_head_dim:\n            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        if save_kv_cache:\n            forward_batch.token_to_kv_pool.set_kv_buffer(\n                layer, forward_batch.out_cache_loc, k, v\n            )\n\n        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num\n        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)\n        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n\n        self._run_flex_forward_decode(\n            q_,\n            o_,\n            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n            forward_batch.req_to_token_pool.req_to_token,\n            forward_batch.req_pool_indices,\n            forward_batch.seq_lens,\n            scaling=layer.scaling,\n            enable_gqa=use_gqa,\n            causal=False,\n        )\n\n        return o\n\n    def support_triton(self):\n        return False\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/torch_native_backend.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\nfrom torch.nn.functional import scaled_dot_product_attention\n\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.radix_attention import AttentionType\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n\n\nclass TorchNativeAttnBackend(AttentionBackend):\n    def __init__(self, model_runner: ModelRunner):\n        super().__init__()\n        self.forward_metadata = None\n        self.device = model_runner.device\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Init the metadata for a forward pass.\"\"\"\n        pass\n\n    def _run_sdpa_forward_extend(\n        self,\n        query: torch.Tensor,\n        output: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        req_to_token: torch.Tensor,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        extend_prefix_lens: torch.Tensor,\n        extend_seq_lens: torch.Tensor,\n        scaling=None,\n        enable_gqa=False,\n        causal=False,\n    ):\n        \"\"\"Run the extend forward by using torch native sdpa op.\n\n        Args:\n            query: [num_tokens, num_heads, head_size]\n            output: [num_tokens, num_heads, head_size]\n            k_cache: [max_total_num_tokens, num_heads, head_size]\n            v_cache: [max_total_num_tokens, num_heads, head_size]\n            req_to_token: [max_num_reqs, max_context_len]\n            req_pool_indices: [num_seqs]\n            seq_lens: [num_seqs]\n            extend_prefix_lens: [num_seqs]\n            extend_seq_lens: [num_seqs]\n            scaling: float or None\n            enable_gqa: bool\n            causal: bool\n\n        
Returns:\n            output: [num_tokens, num_heads, head_size]\n        \"\"\"\n\n        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]\n        assert seq_lens.shape[0] == extend_seq_lens.shape[0]\n\n        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]\n        query = query.movedim(0, query.dim() - 2)\n\n        start_q, start_kv = 0, 0\n        for seq_idx in range(seq_lens.shape[0]):\n            # TODO: this loop processes one sequence per iteration, which is inefficient.\n            # Optimize the performance later.\n\n            extend_seq_len_q = extend_seq_lens[seq_idx]\n            prefill_seq_len_q = extend_prefix_lens[seq_idx]\n\n            seq_len_kv = seq_lens[seq_idx]\n            end_q = start_q + extend_seq_len_q\n            end_kv = start_kv + seq_len_kv\n\n            per_req_query = query[:, start_q:end_q, :]\n            per_req_query_redudant = torch.empty(\n                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),\n                dtype=per_req_query.dtype,\n                device=per_req_query.device,\n            )\n\n            per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query\n\n            # get key and value from cache. 
per_req_tokens contains the kv cache\n            # index for each token in the sequence.\n            req_pool_idx = req_pool_indices[seq_idx]\n            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]\n            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)\n            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)\n\n            if not (per_req_query.dtype == per_req_key.dtype == per_req_value.dtype):\n                # scaled_dot_product_attention() expects query, key, and value to have the same dtype\n                per_req_key = per_req_key.to(per_req_query.dtype)\n                per_req_value = per_req_value.to(per_req_query.dtype)\n\n            per_req_out_redudant = (\n                scaled_dot_product_attention(\n                    per_req_query_redudant.unsqueeze(0),\n                    per_req_key.unsqueeze(0),\n                    per_req_value.unsqueeze(0),\n                    enable_gqa=enable_gqa,\n                    scale=scaling,\n                    is_causal=causal,\n                )\n                .squeeze(0)\n                .movedim(query.dim() - 2, 0)\n            )\n            output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :]\n            start_q, start_kv = end_q, end_kv\n        return output\n\n    def _run_sdpa_forward_decode(\n        self,\n        query: torch.Tensor,\n        output: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        req_to_token: torch.Tensor,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        scaling=None,\n        enable_gqa=False,\n        causal=False,\n    ):\n        \"\"\"Run the decode forward by using torch native sdpa op.\n\n        Args:\n            query: [num_tokens, num_heads, head_size]\n            output: [num_tokens, num_heads, head_size]\n            k_cache: [max_total_num_tokens, num_heads, head_size]\n            v_cache: 
[max_total_num_tokens, num_heads, head_size]\n            req_to_token: [max_num_reqs, max_context_len]\n            req_pool_indices: [num_seqs]\n            seq_lens: [num_seqs]\n            scaling: float or None\n            enable_gqa: bool\n            causal: bool\n\n        Returns:\n            output: [num_tokens, num_heads, head_size]\n        \"\"\"\n\n        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]\n        query = query.movedim(0, query.dim() - 2)\n\n        start_q, start_kv = 0, 0\n        for seq_idx in range(seq_lens.shape[0]):\n            # TODO: this loop processes one sequence per iteration, which is inefficient.\n            # Optimize the performance later.\n\n            seq_len_q = 1\n            seq_len_kv = seq_lens[seq_idx]\n            end_q = start_q + seq_len_q\n            end_kv = start_kv + seq_len_kv\n\n            per_req_query = query[:, start_q:end_q, :]\n\n            # get key and value from cache. per_req_tokens contains the kv cache\n            # index for each token in the sequence.\n            req_pool_idx = req_pool_indices[seq_idx]\n            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]\n            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)\n            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)\n\n            if not (per_req_query.dtype == per_req_key.dtype == per_req_value.dtype):\n                # scaled_dot_product_attention() expects query, key, and value to have the same dtype\n                per_req_key = per_req_key.to(per_req_query.dtype)\n                per_req_value = per_req_value.to(per_req_query.dtype)\n\n            per_req_out = (\n                scaled_dot_product_attention(\n                    per_req_query.unsqueeze(0),\n                    per_req_key.unsqueeze(0),\n                    per_req_value.unsqueeze(0),\n                    enable_gqa=enable_gqa,\n                    scale=scaling,\n         
           is_causal=causal,\n                )\n                .squeeze(0)\n                .movedim(query.dim() - 2, 0)\n            )\n            output[start_q:end_q, :, :] = per_req_out\n            start_q, start_kv = end_q, end_kv\n\n        return output\n\n    def forward_extend(\n        self,\n        q,\n        k,\n        v,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n    ):\n        if layer.qk_head_dim != layer.v_head_dim:\n            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        if layer.is_cross_attention:\n            cache_loc = forward_batch.encoder_out_cache_loc\n        else:\n            cache_loc = forward_batch.out_cache_loc\n\n        if save_kv_cache:\n            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)\n\n        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num\n\n        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)\n        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n\n        causal = True\n        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:\n            causal = False\n\n        self._run_sdpa_forward_extend(\n            q_,\n            o_,\n            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n            forward_batch.req_to_token_pool.req_to_token,\n            forward_batch.req_pool_indices,\n            forward_batch.seq_lens,\n            forward_batch.extend_prefix_lens,\n            forward_batch.extend_seq_lens,\n            scaling=layer.scaling,\n            enable_gqa=use_gqa,\n            causal=causal,\n        )\n        return o\n\n    def forward_decode(\n        self,\n        q,\n        k,\n        v,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        
save_kv_cache=True,\n    ):\n        # During torch.compile, there is a bug in rotary_emb that causes the\n        # output value to have a 3D tensor shape. This reshapes the output correctly.\n        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)\n\n        if layer.qk_head_dim != layer.v_head_dim:\n            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        if layer.is_cross_attention:\n            cache_loc = forward_batch.encoder_out_cache_loc\n        else:\n            cache_loc = forward_batch.out_cache_loc\n\n        if save_kv_cache:\n            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)\n\n        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num\n\n        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)\n        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n\n        self._run_sdpa_forward_decode(\n            q_,\n            o_,\n            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n            forward_batch.req_to_token_pool.req_to_token,\n            forward_batch.req_pool_indices,\n            forward_batch.seq_lens,\n            scaling=layer.scaling,\n            enable_gqa=use_gqa,\n            causal=False,\n        )\n\n        return o\n\n    def support_triton(self):\n        return False\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/triton_backend.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, List, Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.configs.model_config import AttentionArch\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton\nfrom sglang.srt.layers.dp_attention import get_attention_tp_size\nfrom sglang.srt.layers.radix_attention import AttentionType\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices\nfrom sglang.srt.utils import (\n    get_bool_env_var,\n    get_device_core_count,\n    get_int_env_var,\n    next_power_of_2,\n)\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n    from sglang.srt.speculative.spec_info import SpecInput\n\n\ndef logit_capping_mod(logit_capping_method, logit_cap):\n    # positive logit_cap -> tanh cap\n    if logit_capping_method == \"tanh\":\n        return logit_cap\n    else:\n        raise ValueError()\n\n\n@dataclass\nclass ForwardMetadata:\n    attn_logits: torch.Tensor\n    attn_lse: torch.Tensor\n    max_extend_len: int\n    num_kv_splits: torch.Tensor\n    kv_indptr: torch.Tensor\n    kv_indices: torch.Tensor\n    qo_indptr: torch.Tensor\n    custom_mask: torch.Tensor\n    mask_indptr: torch.Tensor\n    # Sliding window\n    window_kv_indptr: torch.Tensor\n    window_kv_indices: torch.Tensor\n    window_num_kv_splits: torch.Tensor\n    window_kv_offsets: torch.Tensor\n\n\nclass TritonAttnBackend(AttentionBackend):\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        skip_prefill: bool = False,\n        kv_indptr_buf: Optional[torch.Tensor] = None,\n    ):\n        # Lazy import to avoid the 
initialization of cuda context\n        from sglang.srt.layers.attention.triton_ops.decode_attention import (\n            decode_attention_fwd,\n        )\n        from sglang.srt.layers.attention.triton_ops.extend_attention import (\n            build_unified_kv_indices,\n            extend_attention_fwd,\n            extend_attention_fwd_unified,\n        )\n\n        super().__init__()\n\n        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)\n        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)\n        self.extend_attention_fwd_unified = torch.compiler.disable(\n            extend_attention_fwd_unified\n        )\n        self.build_unified_kv_indices = torch.compiler.disable(build_unified_kv_indices)\n\n        # Parse args\n        self.skip_prefill = skip_prefill\n        max_bs = model_runner.req_to_token_pool.size\n        self.sliding_window_size = model_runner.sliding_window_size\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator\n        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens\n        self.speculative_num_steps = model_runner.server_args.speculative_num_steps\n        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA\n        self.num_head = (\n            model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.num_kv_head = model_runner.model_config.get_num_kv_heads(\n            get_attention_tp_size()\n        )\n        if (\n            model_runner.hybrid_gdn_config is not None\n            or model_runner.kimi_linear_config is not None\n        ):\n            # For hybrid linear models, layer_id = 0 may not be full attention\n            self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()\n        else:\n            self.v_head_dim = 
model_runner.token_to_kv_pool.get_value_buffer(0).shape[\n                -1\n            ]\n        self.max_context_len = model_runner.model_config.context_len\n        self.device = model_runner.device\n        self.device_core_count = get_device_core_count(model_runner.gpu_id)\n        self.static_kv_splits = get_bool_env_var(\n            \"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS\", \"false\"\n        )\n        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits\n\n        self.allow_bidirectional_attention_in_extend = (\n            model_runner.server_args.disable_cuda_graph\n            and (model_runner.server_args.chunked_prefill_size == -1)\n        )\n\n        # Decide whether to enable deterministic inference with batch-invariant operations\n        self.enable_deterministic = (\n            model_runner.server_args.enable_deterministic_inference\n        )\n\n        # Configure deterministic inference settings\n        if self.enable_deterministic:\n            # Use fixed split tile size for batch invariance\n            self.split_tile_size = get_int_env_var(\n                \"SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE\", 256\n            )\n            # Set static_kv_splits to False to use deterministic logic instead\n            self.static_kv_splits = False\n        else:\n            self.split_tile_size = (\n                model_runner.server_args.triton_attention_split_tile_size\n            )\n\n        if self.split_tile_size is not None:\n            self.max_kv_splits = (\n                self.max_context_len + self.split_tile_size - 1\n            ) // self.split_tile_size\n\n        # Check arguments\n        assert not (\n            model_runner.sliding_window_size is not None\n            and model_runner.model_config.is_encoder_decoder\n        ), \"Sliding window and cross attention are not supported together\"\n\n        # Initialize buffers\n        # TODO(Jianan Ji): Make sure it behaves as expected when 
kv_indptr_buf is provided and sliding window is enabled\n        if kv_indptr_buf is None:\n            self.kv_indptr = torch.zeros(\n                (max_bs + 1,), dtype=torch.int32, device=model_runner.device\n            )\n        else:\n            self.kv_indptr = kv_indptr_buf\n\n        # If sliding window is enabled, we might need two sets of buffers\n        # because of interleaved attention types (e.g. for Gemma3)\n        self.window_kv_indptr = None\n        if self.sliding_window_size is not None and self.sliding_window_size > 0:\n            if kv_indptr_buf is None:\n                self.window_kv_indptr = torch.zeros(\n                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device\n                )\n            else:\n                # When a buffer is provided, allocate a separate zeroed buffer of the same shape\n                self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)\n\n        if not self.skip_prefill:\n            self.qo_indptr = torch.zeros(\n                (max_bs + 1,), dtype=torch.int64, device=model_runner.device\n            )\n\n            self.mask_indptr = torch.zeros(\n                (max_bs + 1,), dtype=torch.int64, device=model_runner.device\n            )\n\n        # Initialize forward metadata\n        self.forward_metadata: ForwardMetadata = None\n\n        self.cuda_graph_custom_mask = None\n\n    def get_num_kv_splits(\n        self,\n        num_kv_splits: torch.Tensor,\n        seq_lens: torch.Tensor,\n    ):\n        num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]\n        # NOTE(alcanderian): With speculative decoding,\n        # num_kv_splits.shape[0] will be topk * real_num_token,\n        # and real_num_token equals num_seq in the decoding phase.\n        num_group = num_token // num_seq\n\n        assert (\n            num_group * num_seq == num_token\n        ), f\"num_seq({num_seq}), num_token({num_token}), something went wrong!\"\n\n        # Legacy dynamic splitting logic 
(non-deterministic)\n        if (\n            self.static_kv_splits or self.device_core_count <= 0\n        ) and not self.enable_deterministic:\n            num_kv_splits.fill_(self.max_kv_splits)\n            return\n\n        # deterministic\n        if self.split_tile_size is not None and self.enable_deterministic:\n            # expand seq_lens to match num_token\n            if num_group > 1:\n                expanded_seq_lens = seq_lens.repeat_interleave(num_group)\n            else:\n                expanded_seq_lens = seq_lens\n\n            num_kv_splits[:] = (\n                expanded_seq_lens + self.split_tile_size - 1\n            ) // self.split_tile_size\n            return\n\n        if num_seq < 256:\n            SCHEDULE_SEQ = 256\n        else:\n            SCHEDULE_SEQ = triton.next_power_of_2(num_seq)\n\n        get_num_kv_splits_triton[(1,)](\n            num_kv_splits,\n            seq_lens,\n            num_seq,\n            num_group,\n            self.num_head,\n            self.num_kv_head,\n            self.max_kv_splits,\n            self.device_core_count,\n            MAX_NUM_SEQ=SCHEDULE_SEQ,\n        )\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Init auxiliary variables for triton attention backend.\"\"\"\n\n        bs = forward_batch.batch_size\n        kv_indptr = self.kv_indptr\n        window_kv_indptr = self.window_kv_indptr\n        window_kv_indices = None\n        window_num_kv_splits = None\n        window_kv_offsets = None\n        spec_info = forward_batch.spec_info\n\n        if forward_batch.forward_mode.is_decode_or_idle():\n            if spec_info is None:\n                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)\n                kv_indptr = kv_indptr[: bs + 1]\n                kv_indices = torch.empty(\n                    forward_batch.seq_lens_sum, dtype=torch.int64, device=self.device\n                )\n                
create_flashinfer_kv_indices_triton[(bs,)](\n                    self.req_to_token,\n                    forward_batch.req_pool_indices,\n                    forward_batch.seq_lens,\n                    kv_indptr,\n                    None,\n                    kv_indices,\n                    self.req_to_token.stride(0),\n                )\n                # Sliding window\n                if (\n                    self.sliding_window_size is not None\n                    and self.sliding_window_size > 0\n                ):\n                    window_kv_indptr, window_kv_indices, window_kv_lens, _ = (\n                        update_sliding_window_buffer(\n                            self.window_kv_indptr,\n                            self.req_to_token,\n                            self.sliding_window_size,\n                            forward_batch.seq_lens,\n                            forward_batch.req_pool_indices,\n                            bs,\n                            self.device,\n                            self.token_to_kv_pool_allocator,\n                        )\n                    )\n                    window_num_kv_splits = torch.empty(\n                        (bs,), dtype=torch.int32, device=self.device\n                    )\n                    self.get_num_kv_splits(window_num_kv_splits, window_kv_lens)\n            else:\n                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices\n                bs = kv_indptr.shape[0] - 1\n\n            attn_logits = torch.empty(\n                (bs, self.num_head, self.max_kv_splits, self.v_head_dim),\n                dtype=torch.float32,\n                device=self.device,\n            )\n            attn_lse = torch.empty(\n                (bs, self.num_head, self.max_kv_splits),\n                dtype=torch.float32,\n                device=self.device,\n            )\n            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)\n            
self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)\n\n            qo_indptr = None\n            custom_mask = None\n            mask_indptr = None\n            max_extend_len = None\n        elif forward_batch.forward_mode.is_target_verify():\n            bs = len(forward_batch.req_pool_indices)\n            qo_indptr = torch.arange(\n                0,\n                (1 + bs) * self.num_draft_tokens,\n                step=self.num_draft_tokens,\n                dtype=torch.int32,\n                device=self.device,\n            )\n            # Differs from the flashinfer kv_indptr and kv_indices construction\n            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)\n            kv_indptr = kv_indptr[: bs + 1]\n            kv_indices = torch.empty(\n                kv_indptr[-1], dtype=torch.int64, device=self.device\n            )\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n\n            if self.sliding_window_size is not None and self.sliding_window_size > 0:\n                # window_kv_offsets is used to calculate the start position in custom mask\n                (\n                    window_kv_indptr,\n                    window_kv_indices,\n                    window_kv_lens,\n                    window_kv_offsets,\n                ) = update_sliding_window_buffer(\n                    self.window_kv_indptr,\n                    self.req_to_token,\n                    self.sliding_window_size,\n                    forward_batch.seq_lens,\n                    forward_batch.req_pool_indices,\n                    bs,\n                    self.device,\n                    self.token_to_kv_pool_allocator,\n                )\n\n            custom_mask 
= spec_info.custom_mask\n            seq_mask_len = self.num_draft_tokens * (\n                forward_batch.seq_lens + self.num_draft_tokens\n            )\n            mask_indptr = self.mask_indptr\n            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)\n            mask_indptr = mask_indptr[: bs + 1]\n            max_extend_len = self.num_draft_tokens\n            num_kv_splits = None\n            attn_logits = None\n            attn_lse = None\n\n        elif forward_batch.forward_mode.is_draft_extend():\n            kv_indices, kv_indptr, qo_indptr, custom_mask = (\n                spec_info.generate_attn_arg_prefill(\n                    forward_batch.req_pool_indices,\n                    forward_batch.seq_lens,\n                    None,\n                    self.req_to_token,\n                )\n            )\n            kv_indices = kv_indices.to(torch.int64)\n            mask_indptr = None\n            # TODO(FIXME): This will trigger an invalid Eagle tree when using\n            # `max(spec_info.accept_length_cpu)`.\n            # It might have been forgotten to update somewhere.\n            max_extend_len = torch.max(spec_info.accept_length).item()\n            num_kv_splits = None\n            attn_logits = None\n            attn_lse = None\n        else:\n            kv_indptr[1 : bs + 1] = torch.cumsum(\n                forward_batch.extend_prefix_lens, dim=0\n            )\n            kv_indptr = kv_indptr[: bs + 1]\n            kv_indices = torch.empty(\n                sum(forward_batch.extend_prefix_lens_cpu),\n                dtype=torch.int64,\n                device=self.device,\n            )\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                forward_batch.req_pool_indices,\n                forward_batch.extend_prefix_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n          
  )\n            # Sliding window\n            if self.sliding_window_size is not None and self.sliding_window_size > 0:\n                (\n                    window_kv_indptr,\n                    window_kv_indices,\n                    window_kv_lens,\n                    window_kv_offsets,\n                ) = update_sliding_window_buffer(\n                    self.window_kv_indptr,\n                    self.req_to_token,\n                    self.sliding_window_size,\n                    forward_batch.extend_prefix_lens,\n                    forward_batch.req_pool_indices,\n                    bs,\n                    self.device,\n                    self.token_to_kv_pool_allocator,\n                )\n\n            qo_indptr = self.qo_indptr\n            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)\n            qo_indptr = qo_indptr[: bs + 1]\n            custom_mask = None\n            mask_indptr = None\n            attn_logits = None\n            attn_lse = None\n            max_extend_len = max(forward_batch.extend_seq_lens_cpu)\n            num_kv_splits = None\n\n        self.forward_metadata = ForwardMetadata(\n            attn_logits,\n            attn_lse,\n            max_extend_len,\n            num_kv_splits,\n            kv_indptr,\n            kv_indices,\n            qo_indptr,\n            custom_mask,\n            mask_indptr,\n            window_kv_indptr,\n            window_kv_indices,\n            window_num_kv_splits,\n            window_kv_offsets,\n        )\n\n    def init_cuda_graph_state(\n        self,\n        max_bs: int,\n        max_num_tokens: int,\n        kv_indices_buf: Optional[torch.Tensor] = None,\n        cuda_graph_num_kv_splits_buf: Optional[torch.Tensor] = None,\n    ):\n        self.cuda_graph_attn_logits = torch.zeros(\n            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),\n            dtype=torch.float32,\n            device=self.device,\n        )\n      
  self.cuda_graph_attn_lse = torch.zeros(\n            (max_num_tokens, self.num_head, self.max_kv_splits),\n            dtype=torch.float32,\n            device=self.device,\n        )\n\n        if cuda_graph_num_kv_splits_buf is None:\n            self.cuda_graph_num_kv_splits = torch.full(\n                (max_num_tokens,),\n                self.max_kv_splits,\n                dtype=torch.int32,\n                device=self.device,\n            )\n        else:\n            self.cuda_graph_num_kv_splits = cuda_graph_num_kv_splits_buf\n\n        if kv_indices_buf is None:\n            self.cuda_graph_kv_indices = torch.zeros(\n                (max_num_tokens * self.max_context_len),\n                dtype=torch.int64,\n                device=self.device,\n            )\n        else:\n            self.cuda_graph_kv_indices = kv_indices_buf\n\n        if not self.skip_prefill:\n            self.cuda_graph_custom_mask = torch.zeros(\n                (max_num_tokens * self.max_context_len),\n                dtype=torch.uint8,\n                device=self.device,\n            )\n\n        if self.sliding_window_size is not None and self.sliding_window_size > 0:\n            if kv_indices_buf is None:\n                self.cuda_graph_window_kv_indices = torch.zeros(\n                    (max_num_tokens * self.sliding_window_size),\n                    dtype=torch.int64,\n                    device=self.device,\n                )\n            else:\n                self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)\n\n            self.cuda_graph_window_num_kv_splits = torch.full(\n                (max_num_tokens,),\n                self.max_kv_splits,\n                dtype=torch.int32,\n                device=self.device,\n            )\n\n            self.cuda_graph_window_kv_offsets = torch.zeros(\n                (max_bs,),\n                dtype=torch.int32,\n                device=self.device,\n            )\n\n    def 
init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        assert encoder_lens is None, \"Not supported\"\n        window_kv_indptr = self.window_kv_indptr\n        window_kv_indices = None\n        window_num_kv_splits = None\n        window_kv_offsets = None\n\n        if forward_mode.is_decode_or_idle():\n            if spec_info is None:\n                kv_indptr = self.kv_indptr\n                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n                kv_indptr = kv_indptr[: bs + 1]\n                kv_indices = self.cuda_graph_kv_indices\n                create_flashinfer_kv_indices_triton[(bs,)](\n                    self.req_to_token,\n                    req_pool_indices,\n                    seq_lens,\n                    kv_indptr,\n                    None,\n                    kv_indices,\n                    self.req_to_token.stride(0),\n                )\n                if (\n                    self.sliding_window_size is not None\n                    and self.sliding_window_size > 0\n                ):\n                    window_kv_indices = self.cuda_graph_window_kv_indices\n                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits\n                    window_kv_indptr, window_kv_indices, _, _ = (\n                        update_sliding_window_buffer_cuda_graph(\n                            self.window_kv_indptr,\n                            window_kv_indices,\n                            self.req_to_token,\n                            self.sliding_window_size,\n                            seq_lens[:bs],\n                            req_pool_indices,\n                            bs,\n                            self.token_to_kv_pool_allocator,\n             
           )\n                    )\n            else:\n                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices\n\n            attn_logits = self.cuda_graph_attn_logits\n            attn_lse = self.cuda_graph_attn_lse\n            max_extend_len = None\n            num_kv_splits = self.cuda_graph_num_kv_splits\n            qo_indptr = None\n            custom_mask = None\n            mask_indptr = None\n        elif forward_mode.is_target_verify():\n            qo_indptr = self.qo_indptr[: bs + 1]\n            qo_indptr[: bs + 1] = torch.arange(\n                0,\n                (1 + bs) * self.num_draft_tokens,\n                step=self.num_draft_tokens,\n                dtype=torch.int32,\n                device=self.device,\n            )\n            kv_indptr = self.kv_indptr[: bs + 1]\n            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n            kv_indices = self.cuda_graph_kv_indices\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                seq_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n\n            if self.sliding_window_size is not None and self.sliding_window_size > 0:\n                window_kv_indices = self.cuda_graph_window_kv_indices\n                window_num_kv_splits = self.cuda_graph_window_num_kv_splits\n                window_kv_offsets = self.cuda_graph_window_kv_offsets\n                window_kv_indptr, window_kv_indices, _, window_kv_offsets[:bs] = (\n                    update_sliding_window_buffer_cuda_graph(\n                        self.window_kv_indptr,\n                        window_kv_indices,\n                        self.req_to_token,\n                        self.sliding_window_size,\n                        seq_lens[:bs],\n                        req_pool_indices,\n                        
bs,\n                        self.token_to_kv_pool_allocator,\n                    )\n                )\n\n            custom_mask = self.cuda_graph_custom_mask\n            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask\n            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)\n            mask_indptr = self.mask_indptr[: bs + 1]\n            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)\n            max_extend_len = self.num_draft_tokens\n            num_kv_splits = None\n            attn_logits = None\n            attn_lse = None\n        elif forward_mode.is_draft_extend(include_v2=True):\n            num_tokens_per_bs = self.speculative_num_steps + 1\n            qo_indptr = self.qo_indptr[: bs + 1]\n            qo_indptr[: bs + 1] = torch.arange(\n                0,\n                bs * num_tokens_per_bs + 1,\n                step=num_tokens_per_bs,\n                dtype=torch.int32,\n                device=self.device,\n            )\n            kv_indptr = self.kv_indptr[: bs + 1]\n            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n            kv_indices = self.cuda_graph_kv_indices\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                seq_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n            custom_mask = None\n            mask_indptr = None\n            max_extend_len = num_tokens_per_bs\n            num_kv_splits = None\n            attn_logits = None\n            attn_lse = None\n        else:\n            raise ValueError(\n                f\"Invalid forward mode: {forward_mode=} for CUDA Graph capture.\"\n            )\n\n        self.forward_metadata = ForwardMetadata(\n            attn_logits,\n            attn_lse,\n            max_extend_len,\n            
num_kv_splits,\n            kv_indptr,\n            kv_indices,\n            qo_indptr,\n            custom_mask,\n            mask_indptr,\n            window_kv_indptr,\n            window_kv_indices,\n            window_num_kv_splits,\n            window_kv_offsets,\n        )\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        # NOTE: encoder_lens expected to be zeros or None\n        if forward_mode.is_decode_or_idle():\n            # Update kv_indptr, kv_indices\n            kv_indptr = self.kv_indptr\n            kv_indices = self.cuda_graph_kv_indices\n            num_kv_splits = self.cuda_graph_num_kv_splits\n            if spec_info is None:\n                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)\n                kv_indptr = kv_indptr[: bs + 1]\n                create_flashinfer_kv_indices_triton[(bs,)](\n                    self.req_to_token,\n                    req_pool_indices[:bs],\n                    seq_lens[:bs],\n                    kv_indptr,\n                    None,\n                    kv_indices,\n                    self.req_to_token.stride(0),\n                )\n                num_token = bs\n                if (\n                    self.sliding_window_size is not None\n                    and self.sliding_window_size > 0\n                ):\n                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits\n                    window_kv_indices = self.cuda_graph_window_kv_indices\n                    _, _, window_kv_lens, _ = update_sliding_window_buffer_cuda_graph(\n                        self.window_kv_indptr,\n                        window_kv_indices,\n                        
self.req_to_token,\n                        self.sliding_window_size,\n                        seq_lens[:bs],\n                        req_pool_indices[:bs],\n                        bs,\n                        self.token_to_kv_pool_allocator,\n                    )\n                    self.get_num_kv_splits(\n                        window_num_kv_splits[:num_token], window_kv_lens[:bs]\n                    )\n\n            else:\n                assert False, \"Multi-step cuda graph init is not done here.\"\n            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])\n\n        elif forward_mode.is_target_verify():\n            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr\n            bs = len(req_pool_indices)\n            qo_indptr = self.qo_indptr[: bs + 1]\n            qo_indptr[: bs + 1] = torch.arange(\n                0,\n                (1 + bs) * self.num_draft_tokens,\n                step=self.num_draft_tokens,\n                dtype=torch.int32,\n                device=self.device,\n            )\n            kv_indptr = self.kv_indptr[: bs + 1]\n            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n            kv_indices = self.cuda_graph_kv_indices\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                seq_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n            if self.sliding_window_size is not None and self.sliding_window_size > 0:\n                window_num_kv_splits = self.cuda_graph_window_num_kv_splits\n                window_kv_indices = self.cuda_graph_window_kv_indices\n                window_kv_offsets = self.cuda_graph_window_kv_offsets\n                _, _, window_kv_lens, window_kv_offsets[:bs] = (\n                    update_sliding_window_buffer_cuda_graph(\n                        
self.window_kv_indptr,\n                        window_kv_indices,\n                        self.req_to_token,\n                        self.sliding_window_size,\n                        seq_lens[:bs],\n                        req_pool_indices,\n                        bs,\n                        self.token_to_kv_pool_allocator,\n                    )\n                )\n            custom_mask = self.cuda_graph_custom_mask\n            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask\n            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)\n            mask_indptr = self.mask_indptr[: bs + 1]\n            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)\n        elif forward_mode.is_draft_extend(include_v2=True):\n            seq_lens = seq_lens[:bs]\n            num_tokens_per_bs = self.speculative_num_steps + 1\n            qo_indptr = self.qo_indptr[: bs + 1]\n            qo_indptr[: bs + 1] = torch.arange(\n                0,\n                bs * num_tokens_per_bs + 1,\n                step=num_tokens_per_bs,\n                dtype=torch.int32,\n                device=self.device,\n            )\n            kv_indptr = self.kv_indptr[: bs + 1]\n            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n            kv_indices = self.cuda_graph_kv_indices\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                seq_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n        else:\n            raise ValueError(\n                f\"Invalid forward mode: {forward_mode=} for CUDA Graph replay.\"\n            )\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        return 1\n\n    def get_verify_buffers_to_fill_after_draft(self):\n        \"\"\"\n        Return buffers for verify attention kernels that need 
to be filled after draft.\n\n        Typically, these are tree mask and position buffers.\n        \"\"\"\n        return [self.cuda_graph_custom_mask, None]\n\n    def update_verify_buffers_to_fill_after_draft(\n        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]\n    ):\n        pass\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n        sinks=None,\n    ):\n        # TODO: reuse the buffer across layers\n        if layer.qk_head_dim != layer.v_head_dim:\n            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        # Save KV cache first (must do this before unified kernel)\n        if save_kv_cache:\n            if (\n                self.use_mla or layer.k_scale is None\n            ):  # Triton MLA currently doesn't support quantized kv cache\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer,\n                    forward_batch.out_cache_loc,\n                    k,\n                    v,\n                )\n            else:\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer,\n                    forward_batch.out_cache_loc,\n                    k.clone(),  # cloned to protect k,v from in-place mutation in set_kv_buffer\n                    v.clone(),\n                    layer.k_scale,\n                    layer.v_scale,\n                )\n\n        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)\n\n        causal = True\n        if (\n            layer.is_cross_attention\n            or layer.attn_type == AttentionType.ENCODER_ONLY\n            or (\n                layer.attn_type == AttentionType.DECODER_BIDIRECTIONAL\n                and 
self.allow_bidirectional_attention_in_extend\n            )\n        ):\n            causal = False\n\n        # Deterministic mode: use unified 1-stage kernel\n        if self.enable_deterministic:\n            return self._forward_extend_unified(\n                q, o, layer, forward_batch, causal, logits_soft_cap, sinks\n            )\n\n        # Normal mode: use original 2-stage kernel\n        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:\n            sliding_window_size = (\n                layer.sliding_window_size\n            )  # Needed for sliding window mask\n            kv_indptr = self.forward_metadata.window_kv_indptr\n            kv_indices = self.forward_metadata.window_kv_indices\n            window_kv_offsets = self.forward_metadata.window_kv_offsets\n        else:\n            sliding_window_size = -1\n            kv_indptr = self.forward_metadata.kv_indptr\n            kv_indices = self.forward_metadata.kv_indices\n            window_kv_offsets = None\n\n        if layer.k_scale is not None and layer.v_scale is not None:\n            k_descale = layer.k_scale_float\n            v_descale = layer.v_scale_float\n        else:\n            k_descale = 1.0\n            v_descale = 1.0\n\n        self.extend_attention_fwd(\n            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n            k.contiguous(),\n            v.contiguous(),\n            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n            self.forward_metadata.qo_indptr,\n            kv_indptr,\n            kv_indices,\n            self.forward_metadata.custom_mask,\n            causal,\n            self.forward_metadata.mask_indptr,\n            self.forward_metadata.max_extend_len,\n            k_descale,\n            v_descale,\n            layer.scaling,\n            
logit_cap=logits_soft_cap,\n            sliding_window_size=sliding_window_size,\n            sinks=sinks,\n            window_kv_offsets=window_kv_offsets,\n            xai_temperature_len=layer.xai_temperature_len,\n        )\n        return o\n\n    def _forward_extend_unified(\n        self,\n        q: torch.Tensor,\n        o: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        causal: bool,\n        logits_soft_cap: float,\n        sinks: Optional[torch.Tensor],\n    ):\n        \"\"\"\n        Unified 1-stage extend attention for deterministic inference.\n        Both prefix and extend KV are accessed through unified kv_indices.\n        \"\"\"\n        bs = forward_batch.batch_size\n\n        # Determine sliding window settings\n        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:\n            sliding_window_size = layer.sliding_window_size\n            # Note: for unified kernel, we use full kv_indptr (not window)\n            prefix_kv_indptr = self.forward_metadata.window_kv_indptr\n            prefix_kv_indices = self.forward_metadata.window_kv_indices\n            # Compute window start positions (absolute position of first key in window)\n            # window_start_pos = seq_len - window_len\n            window_kv_lens = prefix_kv_indptr[1 : bs + 1] - prefix_kv_indptr[:bs]\n            # Handle TARGET_VERIFY mode where extend_prefix_lens might not be set\n            if forward_batch.extend_prefix_lens is not None:\n                window_start_pos = (\n                    forward_batch.extend_prefix_lens[:bs] - window_kv_lens\n                )\n            else:\n                # Infer from spec_info: prefix_len = seq_len - draft_token_num\n                if forward_batch.spec_info is not None and hasattr(\n                    forward_batch.spec_info, \"draft_token_num\"\n                ):\n                    extend_prefix_lens = (\n                        
forward_batch.seq_lens[:bs]\n                        - forward_batch.spec_info.draft_token_num\n                    )\n                    window_start_pos = extend_prefix_lens - window_kv_lens\n                else:\n                    window_start_pos = None\n        else:\n            sliding_window_size = -1\n            prefix_kv_indptr = self.forward_metadata.kv_indptr\n            prefix_kv_indices = self.forward_metadata.kv_indices\n            window_start_pos = None\n\n        # Build unified kv_indices using fused Triton kernel\n        extend_kv_indices = forward_batch.out_cache_loc\n\n        # Handle cases where extend_seq_lens or extend_start_loc might not be set\n        # In speculative decoding, we can infer these from spec_info or compute them\n        if forward_batch.extend_seq_lens is None:\n            # TARGET_VERIFY mode: infer extend_seq_lens from spec_info\n            if forward_batch.spec_info is not None and hasattr(\n                forward_batch.spec_info, \"draft_token_num\"\n            ):\n                draft_token_num = forward_batch.spec_info.draft_token_num\n                extend_seq_lens = torch.full(\n                    (bs,), draft_token_num, dtype=torch.int32, device=self.device\n                )\n            else:\n                raise RuntimeError(\n                    \"extend_seq_lens is None but cannot infer from spec_info. 
\"\n                    \"This should not happen in TARGET_VERIFY mode.\"\n                )\n        else:\n            extend_seq_lens = forward_batch.extend_seq_lens\n\n        # Check extend_start_loc separately - it might be None even when extend_seq_lens is set\n        if forward_batch.extend_start_loc is None:\n            # Compute extend_start_loc from extend_seq_lens\n            # extend_start_loc[i] = sum(extend_seq_lens[0:i])\n            extend_start_loc = torch.cat(\n                [\n                    torch.zeros(1, dtype=torch.int32, device=self.device),\n                    torch.cumsum(extend_seq_lens[:-1], dim=0),\n                ]\n            )\n        else:\n            extend_start_loc = forward_batch.extend_start_loc\n\n        unified_kv_indptr, unified_kv_indices, prefix_lens = (\n            self.build_unified_kv_indices(\n                prefix_kv_indptr,\n                prefix_kv_indices,\n                extend_start_loc,\n                extend_seq_lens,\n                extend_kv_indices,\n                bs,\n            )\n        )\n\n        # Convert prefix_lens to int32 for the kernel\n        prefix_lens = prefix_lens.to(torch.int32)\n\n        if layer.k_scale is not None and layer.v_scale is not None:\n            k_descale = layer.k_scale_float\n            v_descale = layer.v_scale_float\n        else:\n            k_descale = 1.0\n            v_descale = 1.0\n\n        # Call unified kernel\n        self.extend_attention_fwd_unified(\n            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n            k_descale,\n            v_descale,\n            self.forward_metadata.qo_indptr,\n            unified_kv_indptr,\n            unified_kv_indices,\n            prefix_lens,\n            
self.forward_metadata.max_extend_len,\n            custom_mask=self.forward_metadata.custom_mask,\n            mask_indptr=self.forward_metadata.mask_indptr,\n            sm_scale=layer.scaling,\n            logit_cap=logits_soft_cap,\n            is_causal=causal,\n            sliding_window_size=sliding_window_size,\n            sinks=sinks,\n            window_start_pos=window_start_pos,\n            xai_temperature_len=layer.xai_temperature_len,\n        )\n\n        return o\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n        sinks=None,\n    ):\n        # During torch.compile, there is a bug in rotary_emb that causes the\n        # output value to have a 3D tensor shape. This reshapes the output correctly.\n        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)\n\n        # TODO: reuse the buffer across layers\n        if layer.qk_head_dim != layer.v_head_dim:\n            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)\n\n        if save_kv_cache:\n            if self.use_mla:  # Triton MLA currently doesn't support quantized kv cache\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer,\n                    forward_batch.out_cache_loc,\n                    k,\n                    v,\n                )\n            else:\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer,\n                    forward_batch.out_cache_loc,\n                    k,\n                    v,\n                    layer.k_scale,\n                    layer.v_scale,\n                )\n\n        if layer.sliding_window_size is not None and 
layer.sliding_window_size > -1:\n            kv_indptr = self.forward_metadata.window_kv_indptr\n            kv_indices = self.forward_metadata.window_kv_indices\n        else:\n            kv_indptr = self.forward_metadata.kv_indptr\n            kv_indices = self.forward_metadata.kv_indices\n\n        if layer.k_scale is not None and layer.v_scale is not None:\n            k_descale = layer.k_scale_float\n            v_descale = layer.v_scale_float\n        else:\n            k_descale = 1.0\n            v_descale = 1.0\n\n        self.decode_attention_fwd(\n            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n            kv_indptr,\n            kv_indices,\n            self.forward_metadata.attn_logits,\n            self.forward_metadata.attn_lse,\n            self.forward_metadata.num_kv_splits,\n            self.max_kv_splits,\n            layer.scaling,\n            k_descale,\n            v_descale,\n            logit_cap=logits_soft_cap,\n            sinks=sinks,\n            xai_temperature_len=layer.xai_temperature_len,\n        )\n        return o\n\n\nclass TritonMultiStepDraftBackend:\n    \"\"\"\n    Wrap multiple triton attention backends as one for multiple consecutive\n    draft decoding steps.\n    \"\"\"\n\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        topk: int,\n        speculative_num_steps: int,\n    ):\n        self.topk = topk\n        self.speculative_num_steps = speculative_num_steps\n        max_bs = model_runner.req_to_token_pool.size * self.topk\n        self.kv_indptr = torch.zeros(\n            (\n                self.speculative_num_steps,\n                max_bs + 1,\n            ),\n            dtype=torch.int32,\n            device=model_runner.device,\n        )\n        
self.attn_backends: List[TritonAttnBackend] = []\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends.append(\n                TritonAttnBackend(\n                    model_runner,\n                    skip_prefill=True,\n                    kv_indptr_buf=self.kv_indptr[i],\n                )\n            )\n        self.max_context_len = self.attn_backends[0].max_context_len\n        self.num_head = (\n            model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.device = model_runner.device\n        # Cached variables for generate_draft_decode_kv_indices\n        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]\n        self.page_size = model_runner.server_args.page_size\n\n    def common_template(\n        self,\n        forward_batch: ForwardBatch,\n        kv_indices_buffer: Optional[torch.Tensor],\n        call_fn: int,\n    ):\n        if kv_indices_buffer is None:\n            kv_indices_buffer = self.cuda_graph_kv_indices\n\n        num_seqs = forward_batch.batch_size\n        bs = self.topk * num_seqs\n        seq_lens_sum = forward_batch.seq_lens_sum\n\n        generate_draft_decode_kv_indices[\n            (self.speculative_num_steps, num_seqs, self.topk)\n        ](\n            forward_batch.req_pool_indices,\n            forward_batch.req_to_token_pool.req_to_token,\n            forward_batch.seq_lens,\n            kv_indices_buffer,\n            self.kv_indptr,\n            forward_batch.positions,\n            self.pool_len,\n            kv_indices_buffer.shape[1],\n            self.kv_indptr.shape[1],\n            next_power_of_2(num_seqs),\n            next_power_of_2(self.speculative_num_steps),\n            next_power_of_2(bs),\n            self.page_size,\n        )\n\n        if call_fn is None:\n            return\n\n        for i in range(self.speculative_num_steps - 1):\n            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : 
bs + 1]\n            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][\n                : seq_lens_sum * self.topk + bs * (i + 1)\n            ]\n            call_fn(i, forward_batch)\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        kv_indices = torch.empty(\n            (\n                self.speculative_num_steps,\n                forward_batch.batch_size * self.topk * self.max_context_len,\n            ),\n            dtype=torch.int64,\n            device=self.device,\n        )\n\n        def call_fn(i, forward_batch):\n            forward_batch.spec_info.kv_indptr = (\n                forward_batch.spec_info.kv_indptr.clone()\n            )\n            forward_batch.spec_info.kv_indices = (\n                forward_batch.spec_info.kv_indices.clone()\n            )\n            self.attn_backends[i].init_forward_metadata(forward_batch)\n\n        self.common_template(forward_batch, kv_indices, call_fn)\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        self.cuda_graph_kv_indices = torch.zeros(\n            (self.speculative_num_steps, max_num_tokens * self.max_context_len),\n            dtype=torch.int64,\n            device=self.device,\n        )\n        self.cuda_graph_num_kv_splits = torch.full(\n            (max_num_tokens,),\n            self.attn_backends[0].max_kv_splits,\n            dtype=torch.int32,\n            device=self.device,\n        )\n\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_cuda_graph_state(\n                max_bs,\n                max_num_tokens,\n                kv_indices_buf=self.cuda_graph_kv_indices[i],\n                cuda_graph_num_kv_splits_buf=self.cuda_graph_num_kv_splits,\n            )\n\n    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):\n        def call_fn(i, forward_batch):\n            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(\n         
       forward_batch.batch_size,\n                forward_batch.batch_size * self.topk,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                encoder_lens=None,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n            )\n\n        self.common_template(forward_batch, None, call_fn)\n\n    def init_forward_metadata_replay_cuda_graph(\n        self, forward_batch: ForwardBatch, bs: int\n    ):\n        self.common_template(forward_batch, None, None)\n\n        # NOTE: Multi-step's attention backends use the slice of\n        # - kv_indptr buffer (cuda graph and non-cuda graph)\n        # - kv_indices buffer (cuda graph only)\n        # So we don't need to assign the KV indices inside the attention backend.\n\n        # Compute num_kv_splits only once\n        num_token = forward_batch.batch_size * self.topk\n        self.attn_backends[-1].get_num_kv_splits(\n            self.attn_backends[-1].cuda_graph_num_kv_splits[:num_token],\n            forward_batch.seq_lens[:bs],\n        )\n\n\n@triton.jit\ndef get_num_kv_splits_triton(\n    num_kv_splits_ptr,\n    seq_lens_ptr,\n    num_seq,\n    num_group,\n    num_head,\n    num_kv_head,\n    max_kv_splits,\n    device_core_count,\n    MAX_NUM_SEQ: tl.constexpr,\n):\n    # TODO: this method is tunable, we need more online serving data to tune it\n    offs_seq = tl.arange(0, MAX_NUM_SEQ)\n    mask_seq = offs_seq < num_seq\n\n    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)\n    max_seq_len = tl.max(seq_lens)\n    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)\n    min_seq_len = tl.min(seq_lens)\n    if max_seq_len * 8 < min_seq_len * 10:\n        min_seq_len = max_seq_len\n    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)\n    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)\n\n    # NOTE: this is a hack to let 
num_kv_split grows up with seqlen gradually\n    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0\n    ext_device_core_count = tl.cast(\n        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32\n    )\n    block_h, num_kv_group = 16, num_head // num_kv_head\n    if num_kv_group == 1:\n        token_grid = num_seq * num_group * num_head\n    else:\n        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd\n        block_h = tl.minimum(block_h, num_kv_group)\n        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)\n    max_kv_splits_2 = tl.minimum(\n        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits\n    )\n    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)\n\n    num_kv_splits = tl.maximum(\n        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)\n    )\n\n    offs_token = offs_seq * num_group\n    mask_token = offs_token < num_seq * num_group\n    for i in range(0, num_group):\n        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)\n\n\ndef update_sliding_window_buffer(\n    window_kv_indptr,\n    req_to_token,\n    sliding_window_size,\n    seq_lens,\n    req_pool_indices,\n    bs,\n    device,\n    token_to_kv_pool_allocator=None,\n):\n    window_kv_lens = torch.minimum(\n        seq_lens,\n        torch.tensor(sliding_window_size),\n    )\n    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)\n    window_kv_indptr = window_kv_indptr[: bs + 1]\n    window_kv_indices = torch.empty(\n        window_kv_indptr[-1], dtype=torch.int64, device=device\n    )\n    window_kv_start_idx = seq_lens - window_kv_lens\n    create_flashinfer_kv_indices_triton[(bs,)](\n        req_to_token,\n        req_pool_indices,\n        window_kv_lens,\n        window_kv_indptr,\n        window_kv_start_idx,\n        window_kv_indices,\n        req_to_token.stride(0),\n    )\n    # full to swa index mapping\n    if 
hasattr(token_to_kv_pool_allocator, \"translate_loc_from_full_to_swa\"):\n        kv_last_index = window_kv_indptr[-1]\n        window_kv_indices[:kv_last_index] = (\n            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(\n                window_kv_indices[:kv_last_index]\n            )\n        )\n    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx\n\n\ndef update_sliding_window_buffer_cuda_graph(\n    window_kv_indptr,\n    window_kv_indices,\n    req_to_token,\n    sliding_window_size,\n    seq_lens,\n    req_pool_indices,\n    bs,\n    token_to_kv_pool_allocator=None,\n):\n    window_kv_lens = torch.minimum(\n        seq_lens,\n        torch.tensor(sliding_window_size),\n    )\n    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)\n    window_kv_indptr = window_kv_indptr[: bs + 1]\n    window_kv_start_idx = seq_lens - window_kv_lens\n    create_flashinfer_kv_indices_triton[(bs,)](\n        req_to_token,\n        req_pool_indices,\n        window_kv_lens,\n        window_kv_indptr,\n        window_kv_start_idx,\n        window_kv_indices,\n        req_to_token.stride(0),\n    )\n    # full to swa index mapping\n    if hasattr(token_to_kv_pool_allocator, \"translate_loc_from_full_to_swa\"):\n        kv_last_index = window_kv_indptr[-1]\n        window_kv_indices[:kv_last_index] = (\n            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(\n                window_kv_indices[:kv_last_index]\n            )\n        )\n    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/triton_ops/decode_attention.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"\nMemory-efficient attention for decoding.\nIt supports page size = 1.\n\"\"\"\n\n# Adapted from\n# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py\n# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py\n\nimport logging\n\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.utils import is_hip\n\n_is_hip = is_hip()\n\nlogger = logging.getLogger(__name__)\n\n\n_MIN_BLOCK_KV = 32\n\n\n@triton.jit\ndef tanh(x):\n    # Tanh is just a scaled sigmoid\n    return 2 * tl.sigmoid(2 * x) - 1\n\n\n@triton.jit\ndef _fwd_kernel_stage1(\n    Q,\n    K_Buffer,\n    V_Buffer,\n    sm_scale_withk,\n    kv_indptr,\n    kv_indices,\n    Att_Out,\n    Att_Lse,\n    num_kv_splits,\n    stride_qbs,\n    stride_qh,\n    stride_buf_kbs,\n    stride_buf_kh,\n    stride_buf_vbs,\n    stride_buf_vh,\n    stride_mid_ob,\n    stride_mid_oh,\n    stride_mid_os,\n    kv_group_num: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    MIN_BLOCK_KV: tl.constexpr,\n    logit_cap: tl.constexpr,\n    Lk: tl.constexpr,\n    Lv: tl.constexpr,\n    
xai_temperature_len: tl.constexpr,\n):\n    cur_batch = tl.program_id(0)\n    cur_head = tl.program_id(1)\n    split_kv_id = tl.program_id(2)\n\n    cur_kv_head = cur_head // kv_group_num\n\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    offs_dv = tl.arange(0, BLOCK_DV)\n    mask_d = offs_d < Lk\n    mask_dv = offs_dv < Lv\n\n    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)\n    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx\n    kv_splits = tl.load(num_kv_splits + cur_batch)\n\n    if xai_temperature_len > 0:\n        offs_qidx = cur_batch_seq_len - 1\n        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))\n        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale\n        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)\n\n    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d\n\n    kv_len_per_split = (\n        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV\n    )\n    split_kv_start = kv_len_per_split * split_kv_id\n    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)\n\n    e_max = -float(\"inf\")\n    e_sum = 0.0\n    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)\n\n    if split_kv_end > split_kv_start:\n        q = tl.load(Q + off_q, mask=mask_d, other=0.0)\n        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):\n            offs_n = start_n + tl.arange(0, BLOCK_N)\n            kv_loc = tl.load(\n                kv_indices + cur_batch_kv_start_idx + offs_n,\n                mask=offs_n < split_kv_end,\n                other=0,\n            )\n            offs_buf_k = (\n                kv_loc[:, None] * stride_buf_kbs\n                + cur_kv_head * stride_buf_kh\n                + offs_d[None, :]\n            )\n            k = tl.load(\n                K_Buffer + offs_buf_k,\n                mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),\n                
other=0.0,\n            )\n            qk = tl.sum(q[None, :] * k, 1)\n            qk *= sm_scale_withk\n\n            if logit_cap > 0:\n                qk = logit_cap * tanh(qk / logit_cap)\n\n            if xai_temperature_len > 0:\n                qk *= xai_temperature_reg\n\n            qk = tl.where(offs_n < split_kv_end, qk, float(\"-inf\"))\n\n            offs_buf_v = (\n                kv_loc[:, None] * stride_buf_vbs\n                + cur_kv_head * stride_buf_vh\n                + offs_dv[None, :]\n            )\n            v = tl.load(\n                V_Buffer + offs_buf_v,\n                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),\n                other=0.0,\n            )\n\n            n_e_max = tl.maximum(tl.max(qk, 0), e_max)\n            re_scale = tl.exp(e_max - n_e_max)\n            p = tl.exp(qk - n_e_max)\n            acc *= re_scale\n            acc += tl.sum(p[:, None] * v, 0)\n\n            e_sum = e_sum * re_scale + tl.sum(p, 0)\n            e_max = n_e_max\n\n        offs_mid_o = (\n            cur_batch * stride_mid_ob\n            + cur_head * stride_mid_oh\n            + split_kv_id * stride_mid_os\n            + offs_dv\n        )\n\n        tl.store(\n            Att_Out + offs_mid_o,\n            acc / e_sum,\n            mask=(mask_dv),\n        )\n\n        offs_mid_o_1 = (\n            cur_batch * stride_mid_ob\n            + cur_head * stride_mid_oh\n            + split_kv_id * stride_mid_os\n        ) // Lv\n\n        tl.store(\n            Att_Lse + offs_mid_o_1,\n            e_max + tl.log(e_sum),\n        )\n\n\ndef _decode_att_m_fwd(\n    q,\n    k_buffer,\n    v_buffer,\n    att_out,\n    att_lse,\n    kv_indptr,\n    kv_indices,\n    num_kv_splits,\n    max_kv_splits,\n    sm_scale_withk,\n    logit_cap,\n    xai_temperature_len=-1,\n):\n    BLOCK = 64\n    # [TODO] work around SGPR limit on MI3xx\n    if _is_hip:\n        BLOCK = 8\n    MAX_KV_SPLITS = max_kv_splits\n    Lk = k_buffer.shape[-1]\n    
Lv = v_buffer.shape[-1]\n\n    batch, head_num = q.shape[0], q.shape[1]\n\n    grid = (batch, head_num, MAX_KV_SPLITS)\n    kv_group_num = q.shape[1] // k_buffer.shape[1]\n\n    if kv_group_num == 1:\n        num_warps = 4\n    else:\n        num_warps = 2\n        if _is_hip:\n            num_warps = 1\n\n    BLOCK_DMODEL = triton.next_power_of_2(Lk)\n    BLOCK_DV = triton.next_power_of_2(Lv)\n\n    _fwd_kernel_stage1[grid](\n        q,\n        k_buffer,\n        v_buffer,\n        sm_scale_withk,\n        kv_indptr,\n        kv_indices,\n        att_out,\n        att_lse,\n        num_kv_splits,\n        q.stride(0),\n        q.stride(1),\n        k_buffer.stride(0),\n        k_buffer.stride(1),\n        v_buffer.stride(0),\n        v_buffer.stride(1),\n        att_out.stride(0),\n        att_out.stride(1),\n        att_out.stride(2),\n        kv_group_num=kv_group_num,\n        BLOCK_DMODEL=BLOCK_DMODEL,\n        BLOCK_DV=BLOCK_DV,\n        BLOCK_N=BLOCK,\n        MIN_BLOCK_KV=_MIN_BLOCK_KV,\n        logit_cap=logit_cap,\n        xai_temperature_len=xai_temperature_len,\n        num_warps=num_warps,\n        num_stages=2,\n        Lk=Lk,\n        Lv=Lv,\n    )\n\n\n@triton.jit\ndef _fwd_grouped_kernel_stage1(\n    Q,\n    K_Buffer,\n    V_Buffer,\n    sm_scale_withk,\n    kv_indptr,\n    kv_indices,\n    Att_Out,\n    Att_Lse,\n    num_kv_splits,\n    stride_qbs,\n    stride_qh,\n    stride_buf_kbs,\n    stride_buf_kh,\n    stride_buf_vbs,\n    stride_buf_vh,\n    stride_mid_ob,\n    stride_mid_oh,\n    stride_mid_os,\n    kv_group_num: tl.constexpr,\n    q_head_num: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_DPE: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_H: tl.constexpr,\n    MIN_BLOCK_KV: tl.constexpr,\n    logit_cap: tl.constexpr,\n    xai_temperature_len: tl.constexpr,\n    Lk: tl.constexpr,\n    Lv: tl.constexpr,\n):\n    cur_batch = tl.program_id(0)\n    cur_head_id = tl.program_id(1)\n    cur_kv_head 
= cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)\n    split_kv_id = tl.program_id(2)\n\n    if BLOCK_H < kv_group_num:\n        VALID_BLOCK_H: tl.constexpr = BLOCK_H\n    else:\n        VALID_BLOCK_H: tl.constexpr = kv_group_num\n    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)\n    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H\n    mask_h = mask_h & (cur_head < q_head_num)\n\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    offs_dv = tl.arange(0, BLOCK_DV)\n    mask_d = offs_d < Lk\n    mask_dv = offs_dv < Lv\n\n    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)\n    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx\n    kv_splits = tl.load(num_kv_splits + cur_batch)\n\n    if xai_temperature_len > 0:\n        offs_qidx = cur_batch_seq_len - 1\n        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))\n        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale\n        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)\n\n    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]\n\n    if BLOCK_DPE > 0:\n        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)\n        mask_dpe = offs_dpe < Lk\n        off_qpe = (\n            cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]\n        )\n\n    kv_len_per_split = (\n        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV\n    )\n    split_kv_start = kv_len_per_split * split_kv_id\n    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)\n\n    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float(\"inf\")\n    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)\n    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)\n\n    if split_kv_end > split_kv_start:\n        q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)\n        if BLOCK_DPE > 0:\n            qpe = 
tl.load(\n                Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0\n            )\n        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):\n            offs_n = start_n + tl.arange(0, BLOCK_N)\n            kv_loc = tl.load(\n                kv_indices + cur_batch_kv_start_idx + offs_n,\n                mask=offs_n < split_kv_end,\n                other=0,\n            )\n            offs_buf_k = (\n                kv_loc[None, :] * stride_buf_kbs\n                + cur_kv_head * stride_buf_kh\n                + offs_d[:, None]\n            )\n            k = tl.load(\n                K_Buffer + offs_buf_k,\n                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),\n                other=0.0,\n            )\n            qk = tl.dot(q, k.to(q.dtype))\n            if BLOCK_DPE > 0:\n                offs_buf_kpe = (\n                    kv_loc[None, :] * stride_buf_kbs\n                    + cur_kv_head * stride_buf_kh\n                    + offs_dpe[:, None]\n                )\n                kpe = tl.load(\n                    K_Buffer + offs_buf_kpe,\n                    mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),\n                    other=0.0,\n                )\n                qk += tl.dot(qpe, kpe.to(qpe.dtype))\n            qk *= sm_scale_withk\n\n            if logit_cap > 0:\n                qk = logit_cap * tanh(qk / logit_cap)\n\n            if xai_temperature_len > 0:\n                qk *= xai_temperature_reg[:, None]\n\n            qk = tl.where(\n                mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float(\"-inf\")\n            )\n\n            offs_buf_v = (\n                kv_loc[:, None] * stride_buf_vbs\n                + cur_kv_head * stride_buf_vh\n                + offs_dv[None, :]\n            )\n            v = tl.load(\n                V_Buffer + offs_buf_v,\n                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),\n               
 other=0.0,\n            )\n\n            n_e_max = tl.maximum(tl.max(qk, 1), e_max)\n            re_scale = tl.exp(e_max - n_e_max)\n            p = tl.exp(qk - n_e_max[:, None])\n            acc *= re_scale[:, None]\n            acc += tl.dot(p.to(v.dtype), v)\n\n            e_sum = e_sum * re_scale + tl.sum(p, 1)\n            e_max = n_e_max\n\n        offs_mid_o = (\n            cur_batch * stride_mid_ob\n            + cur_head[:, None] * stride_mid_oh\n            + split_kv_id * stride_mid_os\n            + offs_dv[None, :]\n        )\n\n        tl.store(\n            Att_Out + offs_mid_o,\n            acc / e_sum[:, None],\n            mask=(mask_h[:, None]) & (mask_dv[None, :]),\n        )\n\n        offs_mid_o_1 = (\n            cur_batch * stride_mid_ob\n            + cur_head * stride_mid_oh\n            + split_kv_id * stride_mid_os\n        ) // Lv\n\n        tl.store(\n            Att_Lse + offs_mid_o_1,\n            e_max + tl.log(e_sum),\n            mask=mask_h,\n        )\n\n\ndef _decode_grouped_att_m_fwd(\n    q,\n    k_buffer,\n    v_buffer,\n    att_out,\n    att_lse,\n    kv_indptr,\n    kv_indices,\n    num_kv_splits,\n    max_kv_splits,\n    sm_scale_withk,\n    logit_cap,\n    xai_temperature_len=-1,\n):\n    BLOCK = 32\n    Lk = k_buffer.shape[-1]\n    Lv = v_buffer.shape[-1]\n\n    # [TODO] work around shmem limit on MI3xx\n    if _is_hip and Lk >= 576:\n        BLOCK = 16\n\n    if Lk == 576:\n        BLOCK_DMODEL = 512\n        BLOCK_DPE = 64\n    elif Lk == 288:\n        BLOCK_DMODEL = 256\n        BLOCK_DPE = 32\n    else:\n        BLOCK_DMODEL = triton.next_power_of_2(Lk)\n        BLOCK_DPE = 0\n    BLOCK_DV = triton.next_power_of_2(Lv)\n\n    batch, head_num = q.shape[0], q.shape[1]\n    kv_group_num = q.shape[1] // k_buffer.shape[1]\n\n    BLOCK_H = 16\n    MAX_KV_SPLITS = max_kv_splits\n    grid = (\n        batch,\n        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),\n        MAX_KV_SPLITS,\n    )\n\n    extra_kargs = {}\n 
   num_stages = 2\n    if _is_hip:\n        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html\n        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py\n        extra_kargs = {\"waves_per_eu\": 1, \"matrix_instr_nonkdim\": 16, \"kpack\": 2}\n        num_stages = 1\n\n    _fwd_grouped_kernel_stage1[grid](\n        q,\n        k_buffer,\n        v_buffer,\n        sm_scale_withk,\n        kv_indptr,\n        kv_indices,\n        att_out,\n        att_lse,\n        num_kv_splits,\n        q.stride(0),\n        q.stride(1),\n        k_buffer.stride(0),\n        k_buffer.stride(1),\n        v_buffer.stride(0),\n        v_buffer.stride(1),\n        att_out.stride(0),\n        att_out.stride(1),\n        att_out.stride(2),\n        kv_group_num=kv_group_num,\n        q_head_num=head_num,\n        BLOCK_DMODEL=BLOCK_DMODEL,\n        BLOCK_DPE=BLOCK_DPE,\n        BLOCK_DV=BLOCK_DV,\n        BLOCK_N=BLOCK,\n        BLOCK_H=BLOCK_H,\n        MIN_BLOCK_KV=_MIN_BLOCK_KV,\n        logit_cap=logit_cap,\n        xai_temperature_len=xai_temperature_len,\n        num_warps=4,\n        num_stages=num_stages,\n        Lk=Lk,\n        Lv=Lv,\n        **extra_kargs,\n    )\n\n\n@triton.jit\ndef _fwd_kernel_stage2(\n    Mid_O,\n    Mid_O_1,\n    O,\n    v_scale,\n    kv_indptr,\n    num_kv_splits,\n    sink_ptr,\n    stride_mid_ob,\n    stride_mid_oh,\n    stride_mid_os,\n    stride_obs,\n    stride_oh,\n    MAX_KV_SPLITS: tl.constexpr,\n    MIN_BLOCK_KV: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n    Lv: tl.constexpr,\n    HAS_SINK: tl.constexpr,\n):\n    cur_batch = tl.program_id(0)\n    cur_head = tl.program_id(1)\n\n    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(\n        kv_indptr + cur_batch\n    )\n    kv_splits = tl.load(num_kv_splits + cur_batch)\n\n    offs_d = tl.arange(0, BLOCK_DV)\n    mask_d = offs_d < Lv\n\n    e_sum = 0.0\n    e_max = -float(\"inf\")\n    
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)\n\n    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d\n    offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv\n    kv_len_per_split = (\n        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV\n    )\n\n    for split_kv_id in range(0, MAX_KV_SPLITS):\n        split_kv_start = kv_len_per_split * split_kv_id\n        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)\n\n        if split_kv_end > split_kv_start:\n            tv = tl.load(\n                Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0\n            )\n            tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv)\n            n_e_max = tl.maximum(tlogic, e_max)\n\n            old_scale = tl.exp(e_max - n_e_max)\n            acc *= old_scale\n            exp_logic = tl.exp(tlogic - n_e_max)\n            acc += exp_logic * tv\n\n            e_sum = e_sum * old_scale + exp_logic\n            e_max = n_e_max\n\n    if HAS_SINK:\n        cur_sink = tl.load(sink_ptr + cur_head)\n        e_sum += tl.exp(cur_sink - e_max)\n\n    tl.store(\n        O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,\n        acc / e_sum * v_scale,\n        mask=mask_d,\n    )\n\n\ndef _decode_softmax_reducev_fwd(\n    logits,\n    lse,\n    q,\n    o,\n    v_scale,\n    v_buffer,\n    kv_indptr,\n    num_kv_splits,\n    max_kv_splits,\n    sinks=None,\n):\n    batch, head_num = q.shape[0], q.shape[1]\n    Lv = v_buffer.shape[-1]\n    BLOCK_DV = triton.next_power_of_2(Lv)\n\n    MAX_KV_SPLITS = max_kv_splits\n    HAS_SINK = sinks is not None\n\n    extra_kargs = {}\n    if _is_hip:\n        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html\n        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py\n        extra_kargs = 
{\"waves_per_eu\": 4, \"matrix_instr_nonkdim\": 16, \"kpack\": 2}\n\n    grid = (batch, head_num)\n    _fwd_kernel_stage2[grid](\n        logits,\n        lse,\n        o,\n        v_scale,\n        kv_indptr,\n        num_kv_splits,\n        sinks,\n        logits.stride(0),\n        logits.stride(1),\n        logits.stride(2),\n        o.stride(0),\n        o.stride(1),\n        MAX_KV_SPLITS=MAX_KV_SPLITS,\n        MIN_BLOCK_KV=_MIN_BLOCK_KV,\n        BLOCK_DV=BLOCK_DV,\n        Lv=Lv,\n        HAS_SINK=HAS_SINK,\n        num_warps=4,\n        num_stages=2,\n        **extra_kargs,\n    )\n\n\ndef decode_attention_fwd_normal(\n    q,\n    k_buffer,\n    v_buffer,\n    o,\n    kv_indptr,\n    kv_indices,\n    attn_logits,\n    attn_lse,\n    num_kv_splits,\n    max_kv_splits,\n    sm_scale_withk,\n    v_scale,\n    logit_cap=0.0,\n    sinks=None,\n    xai_temperature_len=-1,\n):\n    _decode_att_m_fwd(\n        q,\n        k_buffer,\n        v_buffer,\n        attn_logits,\n        attn_lse,\n        kv_indptr,\n        kv_indices,\n        num_kv_splits,\n        max_kv_splits,\n        sm_scale_withk,\n        logit_cap,\n        xai_temperature_len,\n    )\n    _decode_softmax_reducev_fwd(\n        attn_logits,\n        attn_lse,\n        q,\n        o,\n        v_scale,\n        v_buffer,\n        kv_indptr,\n        num_kv_splits,\n        max_kv_splits,\n        sinks,\n    )\n\n\ndef decode_attention_fwd_grouped(\n    q,\n    k_buffer,\n    v_buffer,\n    o,\n    kv_indptr,\n    kv_indices,\n    attn_logits,\n    attn_lse,\n    num_kv_splits,\n    max_kv_splits,\n    sm_scale_withk,\n    v_scale,\n    logit_cap=0.0,\n    sinks=None,\n    xai_temperature_len=-1,\n):\n    _decode_grouped_att_m_fwd(\n        q,\n        k_buffer,\n        v_buffer,\n        attn_logits,\n        attn_lse,\n        kv_indptr,\n        kv_indices,\n        num_kv_splits,\n        max_kv_splits,\n        sm_scale_withk,\n        logit_cap,\n        xai_temperature_len,\n    )\n   
 _decode_softmax_reducev_fwd(\n        attn_logits,\n        attn_lse,\n        q,\n        o,\n        v_scale,\n        v_buffer,\n        kv_indptr,\n        num_kv_splits,\n        max_kv_splits,\n        sinks,\n    )\n\n\ndef decode_attention_fwd(\n    q,\n    k_buffer,\n    v_buffer,\n    o,\n    kv_indptr,\n    kv_indices,\n    attn_logits,\n    attn_lse,\n    num_kv_splits,\n    max_kv_splits,\n    sm_scale,\n    k_scale,\n    v_scale,\n    logit_cap=0.0,\n    sinks=None,\n    xai_temperature_len=-1,\n):\n    assert max_kv_splits == attn_logits.shape[2]\n    assert q.shape[0] <= kv_indptr.shape[0] - 1\n    assert q.shape[0] <= attn_logits.shape[0]\n\n    kv_group_num = q.shape[1] // v_buffer.shape[1]\n\n    if kv_group_num == 1:\n        # MHA\n        decode_attention_fwd_normal(\n            q,\n            k_buffer,\n            v_buffer,\n            o,\n            kv_indptr,\n            kv_indices,\n            attn_logits,\n            attn_lse,\n            num_kv_splits,\n            max_kv_splits,\n            sm_scale * k_scale,\n            v_scale,\n            logit_cap=logit_cap,\n            sinks=sinks,\n            xai_temperature_len=xai_temperature_len,\n        )\n    else:\n        # GQA/MQA/MLA\n        decode_attention_fwd_grouped(\n            q,\n            k_buffer,\n            v_buffer,\n            o,\n            kv_indptr,\n            kv_indices,\n            attn_logits,\n            attn_lse,\n            num_kv_splits,\n            max_kv_splits,\n            sm_scale * k_scale,\n            v_scale,\n            logit_cap=logit_cap,\n            sinks=sinks,\n            xai_temperature_len=xai_temperature_len,\n        )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py",
    "content": "import torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.utils import is_cuda, is_hip\n\n_is_cuda = is_cuda()\nif _is_cuda:\n    CUDA_CAPABILITY = torch.cuda.get_device_capability()\n\n_is_hip = is_hip()\n\nif get_global_server_args().triton_attention_reduce_in_fp32:\n    REDUCE_TRITON_TYPE = tl.float32\n    REDUCE_TORCH_TYPE = torch.float32\nelse:\n    REDUCE_TRITON_TYPE = tl.float16\n    REDUCE_TORCH_TYPE = torch.float16\n\n\n@triton.jit\ndef tanh(x):\n    # Tanh is just a scaled sigmoid\n    return 2 * tl.sigmoid(2 * x) - 1\n\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage1(\n    Q,\n    K,\n    V,\n    sm_scale,\n    Req_to_tokens,\n    B_req_idx,\n    B_Seqlen,\n    Mid_O,  # [batch, head, seq_block_num, head_dim]\n    Mid_O_LogExpSum,  # [batch, head, seq_block_num]\n    stride_req_to_tokens_b,\n    stride_req_to_tokens_s,\n    stride_qbs,\n    stride_qh,\n    stride_qd,\n    stride_kbs,\n    stride_kh,\n    stride_kd,\n    stride_vbs,\n    stride_vh,\n    stride_vd,\n    stride_mid_ob,\n    stride_mid_oh,\n    stride_mid_os,\n    stride_mid_od,\n    stride_mid_o_eb,\n    stride_mid_o_eh,\n    stride_mid_o_es,\n    gqa_group_size,\n    BLOCK_SEQ: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    cur_batch = tl.program_id(0)\n    cur_head = tl.program_id(1)\n    seq_start_block = tl.program_id(2)\n    cur_kv_head = cur_head // gqa_group_size\n\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n    cur_batch_start_index = seq_start_block * BLOCK_SEQ\n    cur_batch_end_index = tl.minimum(\n        cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ\n    )\n\n    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d\n\n    block_n_size = (\n        tl.where(\n            cur_batch_end_index - cur_batch_start_index <= 0,\n   
         0,\n            cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1,\n        )\n        // BLOCK_N\n    )\n\n    offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)\n\n    q = tl.load(Q + off_q)\n\n    sum_exp = 0.0\n    max_logic = -float(\"inf\")\n    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n    for start_n in range(0, block_n_size, 1):\n        offs_n_new = start_n * BLOCK_N + offs_n\n        k_loc = tl.load(\n            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,\n            mask=offs_n_new < cur_batch_end_index,\n            other=0,\n        )\n        off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :]\n        k = tl.load(\n            K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0\n        )\n        att_value = tl.sum(q[None, :] * k, 1)\n        att_value *= sm_scale\n        att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float(\"-inf\"))\n        v = tl.load(\n            V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0\n        )\n\n        cur_max_logic = tl.max(att_value, axis=0)\n        new_max_logic = tl.maximum(cur_max_logic, max_logic)\n\n        exp_logic = tl.exp(att_value - new_max_logic)\n        logic_scale = tl.exp(max_logic - new_max_logic)\n        acc *= logic_scale\n        acc += tl.sum(exp_logic[:, None] * v, axis=0)\n\n        sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0)\n        max_logic = new_max_logic\n\n    need_store = tl.where(block_n_size == 0, 0, 1)\n    for _ in range(0, need_store, 1):\n        off_mid_o = (\n            cur_batch * stride_mid_ob\n            + cur_head * stride_mid_oh\n            + seq_start_block * stride_mid_os\n            + offs_d\n        )\n        off_mid_o_logexpsum = (\n            cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block\n        )\n        tl.store(Mid_O + off_mid_o, acc / sum_exp)\n        
tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp))\n    return\n\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage2(\n    B_Seqlen,\n    Mid_O,  # [batch, head, seq_block_num, head_dim]\n    Mid_O_LogExpSum,  # [batch, head, seq_block_num]\n    O,  # [batch, head, head_dim]\n    stride_mid_ob,\n    stride_mid_oh,\n    stride_mid_os,\n    stride_mid_od,\n    stride_mid_o_eb,\n    stride_mid_o_eh,\n    stride_mid_o_es,\n    stride_obs,\n    stride_oh,\n    stride_od,\n    BLOCK_SEQ: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n):\n    cur_batch = tl.program_id(0)\n    cur_head = tl.program_id(1)\n\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n\n    block_n_size = (\n        tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1)\n        // BLOCK_SEQ\n    )\n\n    sum_exp = 0.0\n    max_logic = -float(\"inf\")\n    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d\n    offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh\n    for block_seq_n in range(0, block_n_size, 1):\n        tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)\n        tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)\n        new_max_logic = tl.maximum(tlogic, max_logic)\n\n        old_scale = tl.exp(max_logic - new_max_logic)\n        acc *= old_scale\n        exp_logic = tl.exp(tlogic - new_max_logic)\n        acc += exp_logic * tv\n        sum_exp = sum_exp * old_scale + exp_logic\n        max_logic = new_max_logic\n\n    tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp)\n    return\n\n\n@torch.no_grad()\ndef flash_decode_stage1(\n    q,\n    k,\n    v,\n    Req_to_tokens,\n    B_req_idx,\n    B_Seqlen,\n    max_len_in_batch,\n    mid_out,\n    mid_out_logsumexp,\n    block_seq,\n):\n    BLOCK_SEQ = block_seq\n    BLOCK_N = 16\n    assert BLOCK_SEQ % 
BLOCK_N == 0\n    # shape constraints\n    Lq, Lk = q.shape[-1], k.shape[-1]\n    assert Lq == Lk\n    assert Lk in {16, 32, 64, 128}\n    sm_scale = 1.0 / (Lk**0.5)\n    batch, head_num = B_req_idx.shape[0], q.shape[1]\n    grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ))\n    gqa_group_size = q.shape[1] // k.shape[1]\n\n    _fwd_kernel_flash_decode_stage1[grid](\n        q,\n        k,\n        v,\n        sm_scale,\n        Req_to_tokens,\n        B_req_idx,\n        B_Seqlen,\n        mid_out,\n        mid_out_logsumexp,\n        Req_to_tokens.stride(0),\n        Req_to_tokens.stride(1),\n        q.stride(0),\n        q.stride(1),\n        q.stride(2),\n        k.stride(0),\n        k.stride(1),\n        k.stride(2),\n        v.stride(0),\n        v.stride(1),\n        v.stride(2),\n        mid_out.stride(0),\n        mid_out.stride(1),\n        mid_out.stride(2),\n        mid_out.stride(3),\n        mid_out_logsumexp.stride(0),\n        mid_out_logsumexp.stride(1),\n        mid_out_logsumexp.stride(2),\n        gqa_group_size,\n        BLOCK_SEQ=BLOCK_SEQ,\n        BLOCK_DMODEL=Lk,\n        BLOCK_N=BLOCK_N,\n        num_warps=1,\n        num_stages=2,\n    )\n    return\n\n\n@torch.no_grad()\ndef flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):\n    Lk = mid_out.shape[-1]\n    assert Lk in {16, 32, 64, 128}\n    batch, head_num = mid_out.shape[0], mid_out.shape[1]\n    grid = (batch, head_num)\n\n    _fwd_kernel_flash_decode_stage2[grid](\n        B_Seqlen,\n        mid_out,\n        mid_out_logexpsum,\n        O,\n        mid_out.stride(0),\n        mid_out.stride(1),\n        mid_out.stride(2),\n        mid_out.stride(3),\n        mid_out_logexpsum.stride(0),\n        mid_out_logexpsum.stride(1),\n        mid_out_logexpsum.stride(2),\n        O.stride(0),\n        O.stride(1),\n        O.stride(2),\n        BLOCK_SEQ=block_seq,\n        BLOCK_DMODEL=Lk,\n        num_warps=4,\n        num_stages=2,\n    )\n    
return\n\n\ndef flash_decode_attention_fwd(\n    q,\n    k_buffer,\n    v_buffer,\n    o,\n    req_to_token,\n    b_req_idx,\n    b_start_loc,\n    b_seq_len,\n    attn_logits,\n    max_len_in_batch,\n    sm_scale,\n    logit_cap=0.0,\n):\n    BLOCK_SEQ = 256\n    kv_group_num = q.shape[1] // v_buffer.shape[1]\n    # batch_size = q.shape[0]\n\n    block_seq_num = (max_len_in_batch + BLOCK_SEQ - 1) // BLOCK_SEQ\n\n    mid_o = torch.empty(\n        [q.shape[0], q.shape[1], block_seq_num, q.shape[-1]],\n        dtype=torch.float32,\n        device=q.device,\n    )\n    mid_o_logexpsum = torch.empty(\n        [q.shape[0], q.shape[1], block_seq_num], dtype=torch.float32, device=q.device\n    )\n\n    flash_decode_stage1(\n        q,\n        k_buffer,\n        v_buffer,\n        req_to_token,\n        b_req_idx,\n        b_seq_len,\n        max_len_in_batch,\n        mid_o,\n        mid_o_logexpsum,\n        BLOCK_SEQ,\n    )\n    flash_decode_stage2(mid_o, mid_o_logexpsum, b_seq_len, o, BLOCK_SEQ)\n\n\n@triton.jit\ndef _sparse_fwd_kernel_flash_decode_stage1(  # Double Sparsity's approximate attention\n    Q_Label,\n    K_Label_Buffer,\n    sm_scale,\n    Req_to_tokens,  # shape: [B, S]\n    B_Seqlen,\n    Att_Out,  # shape: [H, B, S] easier for topk\n    stride_req_to_tokens_b,\n    stride_qbs,\n    stride_qh,\n    stride_buf_kbs,\n    stride_buf_kh,\n    att_stride_h,\n    att_stride_b,\n    kv_group_num: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    logit_cap: tl.constexpr,\n):\n    cur_batch = tl.program_id(0)\n    cur_head = tl.program_id(1)\n    start_n = tl.program_id(2)\n\n    cur_kv_head = cur_head // kv_group_num\n\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n\n    cur_batch_start_index = 0\n    cur_batch_end_index = cur_batch_seq_len\n\n    min_val = -float(\"inf\")\n    att_value = tl.full([BLOCK_N], min_val, dtype=tl.float32)\n\n    off_q = cur_batch * stride_qbs + 
cur_head * stride_qh + offs_d\n\n    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n    block_index = start_n * BLOCK_N\n    block_mask = tl.where(block_index < cur_batch_seq_len, 1, 0)\n\n    for start_mark in range(0, block_mask, 1):\n        q = tl.load(Q_Label + off_q + start_mark).to(REDUCE_TRITON_TYPE)\n        offs_n_new = cur_batch_start_index + offs_n\n        k_loc = tl.load(\n            Req_to_tokens + stride_req_to_tokens_b * cur_batch + offs_n_new,\n            mask=offs_n_new < cur_batch_end_index,\n            other=0,\n        )\n        offs_buf_k = (\n            k_loc[:, None] * stride_buf_kbs\n            + cur_kv_head * stride_buf_kh\n            + offs_d[None, :]\n        )\n        k = tl.load(\n            K_Label_Buffer + offs_buf_k,\n            mask=offs_n_new[:, None] < cur_batch_end_index,\n            other=0.0,\n        ).to(REDUCE_TRITON_TYPE)\n\n        att_value = tl.sum(q[None, :] * k, 1)\n        att_value *= sm_scale\n\n        if logit_cap > 0:\n            att_value = logit_cap * tanh(att_value / logit_cap)\n\n    att_value = tl.where(offs_n < cur_batch_end_index, att_value, min_val)\n    off_o = cur_head * att_stride_h + (cur_batch * att_stride_b + offs_n)\n    tl.store(Att_Out + off_o, att_value)\n\n\n@triton.jit\ndef _sparse_fwd_kernel_flash_decode_stage2(\n    Q,\n    K,\n    V,\n    sm_scale,\n    Req_to_tokens,  # shape: [B, S]\n    Topk_token_indices,  # shape: [H, B, k]\n    Mid_O,  # [batch, head, seq_block_num, head_dim]\n    Mid_O_LogExpSum,  # [batch, head, seq_block_num]\n    Heavy_token_num,  # NOTE: This can be used as constexpr but we may support dynamic heavy token number in the future\n    stride_req_to_tokens_b,\n    stride_topk_token_indices_h,\n    stride_topk_token_indices_b,\n    stride_qbs,\n    stride_qh,\n    stride_kbs,\n    stride_kh,\n    stride_vbs,\n    stride_vh,\n    stride_mid_ob,\n    stride_mid_oh,\n    stride_mid_os,\n    stride_mid_o_eb,\n    stride_mid_o_eh,\n    gqa_group_size,\n 
   BLOCK_SEQ: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    cur_batch = tl.program_id(0)\n    cur_head = tl.program_id(1)\n    seq_start_block = tl.program_id(2)\n    cur_kv_head = cur_head // gqa_group_size\n\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    cur_batch_start_index = seq_start_block * BLOCK_SEQ\n    cur_batch_end_index = tl.minimum(Heavy_token_num, cur_batch_start_index + BLOCK_SEQ)\n\n    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d\n\n    block_n_size = (\n        tl.where(\n            cur_batch_end_index - cur_batch_start_index <= 0,\n            0,\n            cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1,\n        )\n        // BLOCK_N\n    )\n\n    # offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)\n    offs_n = tl.arange(0, BLOCK_N)\n\n    q = tl.load(Q + off_q)\n\n    sum_exp = 0.0\n    max_logic = -float(\"inf\")\n    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n    for start_n in range(cur_batch_start_index, cur_batch_end_index, BLOCK_N):\n        # for start_n in range(0, block_n_size, 1):\n        # offs_n_new = start_n * BLOCK_N + offs_n\n        offs_n_new = start_n + offs_n\n        # offs_n_new = cur_batch_start_index + start_n * BLOCK_N + offs_n\n        topk_token_indices = tl.load(\n            Topk_token_indices\n            + stride_topk_token_indices_h * cur_head\n            + stride_topk_token_indices_b * cur_batch\n            + offs_n_new,\n            mask=offs_n_new < cur_batch_end_index,\n            other=0,\n        )\n        k_loc = tl.load(\n            Req_to_tokens + stride_req_to_tokens_b * cur_batch + topk_token_indices,\n            mask=offs_n_new < cur_batch_end_index,\n            other=0,\n        )\n        off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :]\n        k = tl.load(\n            K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0\n        )\n        att_value = 
tl.sum(q[None, :] * k, 1)\n        att_value *= sm_scale\n        att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float(\"-inf\"))\n        v = tl.load(\n            V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0\n        )\n\n        cur_max_logic = tl.max(att_value, axis=0)\n        new_max_logic = tl.maximum(cur_max_logic, max_logic)\n\n        exp_logic = tl.exp(att_value - new_max_logic)\n        logic_scale = tl.exp(max_logic - new_max_logic)\n        acc *= logic_scale\n        acc += tl.sum(exp_logic[:, None] * v, axis=0)\n\n        sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0)\n        max_logic = new_max_logic\n\n    # need_store = tl.where(block_n_size == 0, 0, 1)\n    need_store = 1\n    for _ in range(0, need_store, 1):\n        off_mid_o = (\n            cur_batch * stride_mid_ob\n            + cur_head * stride_mid_oh\n            + seq_start_block * stride_mid_os\n            + offs_d\n        )\n        off_mid_o_logexpsum = (\n            cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block\n        )\n        tl.store(Mid_O + off_mid_o, acc / sum_exp)\n        tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp))\n    return\n\n\n@triton.jit\ndef _sparse_fwd_kernel_flash_decode_stage3(\n    Mid_O,  # [batch, head, seq_block_num, head_dim]\n    Mid_O_LogExpSum,  # [batch, head, seq_block_num]\n    O,  # [batch, head, head_dim]\n    seq_len,  # NOTE: This can be used as constexpr but we may support dynamic heavy token number in the future\n    stride_mid_ob,\n    stride_mid_oh,\n    stride_mid_os,\n    stride_mid_o_eb,\n    stride_mid_o_eh,\n    stride_obs,\n    stride_oh,\n    BLOCK_SEQ: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n):\n    cur_batch = tl.program_id(0)\n    cur_head = tl.program_id(1)\n\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n\n    block_n_size = tl.where(seq_len <= 0, 0, seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ\n\n    sum_exp = 
0.0\n    max_logic = -float(\"inf\")\n    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d\n    offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh\n    for block_seq_n in range(0, block_n_size, 1):\n        tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)\n        tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)\n        new_max_logic = tl.maximum(tlogic, max_logic)\n\n        old_scale = tl.exp(max_logic - new_max_logic)\n        acc *= old_scale\n        exp_logic = tl.exp(tlogic - new_max_logic)\n        acc += exp_logic * tv\n        sum_exp = sum_exp * old_scale + exp_logic\n        max_logic = new_max_logic\n\n    tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp)\n    return\n\n\ndef sparse_flash_decode_stage1(\n    q_label,\n    k_label_buffer,\n    att_out,\n    Req_to_tokens,\n    B_Seqlen,\n    max_len_in_batch,\n    sm_scale,\n    logit_cap,\n):\n    BLOCK = 32\n    # shape constraints\n    Lq, Lk = q_label.shape[-1], k_label_buffer.shape[-1]\n    assert Lq == Lk\n    assert Lk in {16, 32, 64, 128, 256, 576}\n\n    BLOCK_DMODEL = Lk\n\n    batch, head_num = q_label.shape[0], q_label.shape[1]\n\n    grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))\n    kv_group_num = q_label.shape[1] // k_label_buffer.shape[1]\n\n    if kv_group_num == 1:\n        num_warps = 4\n    else:\n        num_warps = 2\n\n    _sparse_fwd_kernel_flash_decode_stage1[grid](\n        q_label,\n        k_label_buffer,\n        sm_scale,\n        Req_to_tokens,\n        B_Seqlen,\n        att_out,\n        Req_to_tokens.stride(0),\n        q_label.stride(0),\n        q_label.stride(1),\n        k_label_buffer.stride(0),\n        k_label_buffer.stride(1),\n        att_out.stride(0),\n        att_out.stride(1),\n        kv_group_num,\n        BLOCK_DMODEL,\n        BLOCK,\n        logit_cap,\n        num_warps=num_warps,\n  
      num_stages=1,\n    )\n\n\n@torch.no_grad()\ndef sparse_flash_decode_stage2(\n    q,\n    k,\n    v,\n    Req_to_tokens,\n    Topk_token_indices,\n    heavy_token_num,\n    mid_out,\n    mid_out_logsumexp,\n    block_seq,\n    sm_scale,\n):\n    BLOCK_SEQ = block_seq\n    BLOCK_N = 16\n    assert BLOCK_SEQ % BLOCK_N == 0\n    # shape constraints\n    Lq, Lk = q.shape[-1], k.shape[-1]\n    assert Lq == Lk\n    assert Lk in {16, 32, 64, 128}\n    assert heavy_token_num == Topk_token_indices.shape[-1]\n    # sm_scale = 1.0 / (Lk ** 0.5)\n    batch, head_num = q.shape[0], q.shape[1]\n    grid = (batch, head_num, triton.cdiv(heavy_token_num, BLOCK_SEQ))\n\n    gqa_group_size = q.shape[1] // k.shape[1]\n\n    _sparse_fwd_kernel_flash_decode_stage2[grid](\n        q,\n        k,\n        v,\n        sm_scale,\n        Req_to_tokens,\n        Topk_token_indices,\n        mid_out,\n        mid_out_logsumexp,\n        heavy_token_num,\n        Req_to_tokens.stride(0),\n        Topk_token_indices.stride(0),\n        Topk_token_indices.stride(1),\n        q.stride(0),\n        q.stride(1),\n        k.stride(0),\n        k.stride(1),\n        v.stride(0),\n        v.stride(1),\n        mid_out.stride(0),\n        mid_out.stride(1),\n        mid_out.stride(2),\n        mid_out_logsumexp.stride(0),\n        mid_out_logsumexp.stride(1),\n        gqa_group_size,\n        BLOCK_SEQ=BLOCK_SEQ,\n        BLOCK_DMODEL=Lk,\n        BLOCK_N=BLOCK_N,\n        num_warps=1,\n        num_stages=2,\n    )\n    return\n\n\n@torch.no_grad()\ndef sparse_flash_decode_stage3(Seqlen, mid_out, mid_out_logexpsum, O, block_seq):\n    Lk = mid_out.shape[-1]\n    assert Lk in {16, 32, 64, 128}\n    batch, head_num = mid_out.shape[0], mid_out.shape[1]\n    grid = (batch, head_num)\n\n    _sparse_fwd_kernel_flash_decode_stage3[grid](\n        mid_out,\n        mid_out_logexpsum,\n        O,\n        Seqlen,\n        mid_out.stride(0),\n        mid_out.stride(1),\n        mid_out.stride(2),\n        
mid_out_logexpsum.stride(0),\n        mid_out_logexpsum.stride(1),\n        O.stride(0),\n        O.stride(1),\n        BLOCK_SEQ=block_seq,\n        BLOCK_DMODEL=Lk,\n        num_warps=4,\n        num_stages=2,\n    )\n    return\n\n\ndef flash_decode_sparse_attention_fwd(\n    q,\n    k_buffer,\n    v_buffer,\n    o,\n    q_label,\n    k_label_buffer,\n    req_to_token,\n    b_seq_len,\n    max_len_in_batch,\n    sm_scale,\n    logit_cap,\n    heavy_token_num=32,\n    att_out_approx=None,\n    mid_out=None,\n    mid_o_logexpsum=None,\n    BLOCK_SEQ=256,\n):\n    # TODO(Andy): Tune BLOCK_SEQ & BLOCK_D\n    kv_group_num = q.shape[1] // v_buffer.shape[1]\n    # batch_size = q.shape[0]\n\n    # Step 1: BGEMV approximate attention (page implementation)\n\n    if att_out_approx is None:\n        att_out_approx = torch.empty(\n            [q.shape[1], q.shape[0], max_len_in_batch],\n            dtype=REDUCE_TORCH_TYPE,\n            device=q.device,\n        )\n\n    if mid_out is None:\n        block_seq_num = (heavy_token_num + BLOCK_SEQ - 1) // BLOCK_SEQ\n\n        mid_out = torch.empty(\n            [q.shape[0], q.shape[1], block_seq_num, q.shape[-1]],\n            dtype=torch.float32,\n            device=q.device,\n        )\n        mid_o_logexpsum = torch.empty(\n            [q.shape[0], q.shape[1], block_seq_num],\n            dtype=torch.float32,\n            device=q.device,\n        )\n\n    sparse_flash_decode_stage1(\n        q_label,\n        k_label_buffer,\n        att_out_approx,\n        req_to_token,\n        b_seq_len,\n        max_len_in_batch,\n        sm_scale,\n        logit_cap,\n    )\n\n    # Step 2: TopK token selection\n    # NOTE(Andy): Apply sparse decoding only when min seq len > heavy_token_num and max seq len > sparse decoding threshold\n    # TODO(Andy): Switch to a faster topk implementation\n    topk_token_indices = torch.topk(att_out_approx, heavy_token_num, dim=-1).indices\n    # topk_token_indices: [H, B, k], Req_to_tokens: [B, S]\n    # 
topk_token_indices = torch.arange(0, heavy_token_num, device=q.device).unsqueeze(0).unsqueeze(0).expand(q.shape[1], q.shape[0], -1)\n\n    sparse_flash_decode_stage2(\n        q,\n        k_buffer,\n        v_buffer,\n        req_to_token,\n        topk_token_indices,\n        heavy_token_num,\n        mid_out,\n        mid_o_logexpsum,\n        BLOCK_SEQ,\n        sm_scale,\n    )\n\n    sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)\n\n\n# Extend attention kernel for Double Sparsity\n# Moved from https://github.com/sgl-project/sglang/blob/v0.4.2.post1/python/sglang/srt/layers/attention/triton_ops/extend_attention.py\n@triton.jit\ndef _fwd_kernel(\n    Q_Extend,\n    K_Extend,\n    V_Extend,\n    O_Extend,\n    K_Buffer,\n    V_Buffer,\n    Req_to_tokens,\n    B_req_idx,\n    B_Seq_Len,\n    B_Start_Loc_Extend,\n    B_Seq_Len_Extend,\n    sm_scale,\n    kv_group_num,\n    stride_qbs,\n    stride_qh,\n    stride_kbs,\n    stride_kh,\n    stride_vbs,\n    stride_vh,\n    stride_obs,\n    stride_oh,\n    stride_buf_kbs,\n    stride_buf_kh,\n    stride_buf_vbs,\n    stride_buf_vh,\n    stride_req_to_tokens_b,\n    logit_cap: tl.constexpr,\n    Lq: tl.constexpr,\n    Lv: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_DPE: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    cur_seq = tl.program_id(0)\n    cur_head = tl.program_id(1)\n    cur_block_m = tl.program_id(2)\n    cur_kv_head = cur_head // kv_group_num\n\n    cur_seq_len = tl.load(B_Seq_Len + cur_seq)\n    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)\n    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend\n\n    cur_seq_prefix_start_in_loc = 0\n    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)\n    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)\n\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    offs_dv = tl.arange(0, BLOCK_DV)\n    offs_m = tl.arange(0, BLOCK_M)\n    
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend\n\n    mask_d = offs_d < Lq\n    mask_dv = offs_dv < Lv\n\n    offs_q = (\n        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])\n        * stride_qbs\n        + cur_head * stride_qh\n        + offs_d[None, :]\n    )\n    q = tl.load(\n        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0\n    )\n\n    if BLOCK_DPE > 0:\n        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)\n        offs_qpe = (\n            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])\n            * stride_qbs\n            + cur_head * stride_qh\n            + offs_dpe[None, :]\n        )\n        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)\n\n    # stage 1: compute scores with prefix\n    offs_n = tl.arange(0, BLOCK_N)\n\n    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)\n    deno = tl.zeros([BLOCK_M], dtype=tl.float32)\n    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n\n    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        mask_n = (start_n + offs_n) < cur_seq_len_prefix\n        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (\n            cur_seq_prefix_start_in_loc + start_n + offs_n\n        )\n        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)\n\n        # load k in transposed way\n        offs_buf_k = (\n            offs_kv_loc[None, :] * stride_buf_kbs\n            + cur_kv_head * stride_buf_kh\n            + offs_d[:, None]\n        )\n        k = tl.load(\n            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0\n        )\n\n        qk = tl.dot(q.to(k.dtype), k)\n        if BLOCK_DPE > 0:\n            offs_kpe = (\n                offs_kv_loc[None, :] * stride_buf_kbs\n                + cur_kv_head * stride_buf_kh\n                + 
offs_dpe[:, None]\n            )\n            kpe = tl.load(\n                K_Buffer + offs_kpe,\n                mask=mask_n[None, :],\n                other=0.0,\n            )\n            qk += tl.dot(qpe.to(kpe.dtype), kpe)\n        qk *= sm_scale\n\n        if logit_cap > 0:\n            qk = logit_cap * tanh(qk / logit_cap)\n\n        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float(\"-inf\"))\n\n        n_e_max = tl.maximum(tl.max(qk, 1), e_max)\n        re_scale = tl.exp(e_max - n_e_max)\n        p = tl.exp(qk - n_e_max[:, None])\n        deno = deno * re_scale + tl.sum(p, 1)\n\n        offs_buf_v = (\n            offs_kv_loc[:, None] * stride_buf_vbs\n            + cur_kv_head * stride_buf_vh\n            + offs_dv[None, :]\n        )\n        v = tl.load(\n            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0\n        )\n        p = p.to(v.dtype)\n        acc = acc * re_scale[:, None] + tl.dot(p, v)\n\n        e_max = n_e_max\n\n    # stage 2: compute the triangle part\n\n    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)\n    for start_n in range(0, cur_block_m_end, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        mask_n = (start_n + offs_n) < cur_block_m_end\n\n        # load k in transposed way\n        offs_k = (\n            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs\n            + cur_kv_head * stride_kh\n            + offs_d[:, None]\n        )\n        k = tl.load(\n            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0\n        )\n\n        qk = tl.dot(q, k, out_dtype=tl.float32)\n        if BLOCK_DPE > 0:\n            offs_kpe = (\n                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])\n                * stride_kbs\n                + cur_kv_head * stride_kh\n                + offs_dpe[:, None]\n            )\n            kpe = tl.load(\n                K_Extend + 
offs_kpe,\n                mask=mask_n[None, :],\n                other=0.0,\n            )\n            qk += tl.dot(qpe, kpe)\n\n        qk *= sm_scale\n\n        if logit_cap > 0:\n            qk = logit_cap * tanh(qk / logit_cap)\n\n        mask_causal = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (\n            start_n + offs_n[None, :]\n        )\n        mask_causal &= mask_m[:, None] & mask_n[None, :]\n        qk = tl.where(mask_causal, qk, float(\"-inf\"))\n\n        n_e_max = tl.maximum(tl.max(qk, 1), e_max)\n        re_scale = tl.exp(e_max - n_e_max)\n        p = tl.exp(qk - n_e_max[:, None])\n        deno = deno * re_scale + tl.sum(p, 1)\n\n        offs_v = (\n            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs\n            + cur_kv_head * stride_vh\n            + offs_dv[None, :]\n        )\n        v = tl.load(\n            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0\n        )\n        p = p.to(v.dtype)\n        acc = acc * re_scale[:, None] + tl.dot(p, v)\n\n        e_max = n_e_max\n\n    offs_o = (\n        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])\n        * stride_obs\n        + cur_head * stride_oh\n        + offs_dv[None, :]\n    )\n    tl.store(\n        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]\n    )\n\n\ndef extend_attention_fwd(\n    q_extend,\n    k_extend,\n    v_extend,\n    o_extend,\n    k_buffer,\n    v_buffer,\n    req_to_tokens,\n    b_req_idx,\n    b_seq_len,\n    b_seq_len_extend,\n    b_start_loc_extend,\n    max_len_extend,\n    sm_scale=None,\n    logit_cap=0.0,\n):\n    \"\"\"\n    q_extend, k_extend, v_extend, o_extend: contiguous tensors\n\n    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager\n    \"\"\"\n    Lq, Lk, Lv = (\n        q_extend.shape[-1],\n        k_extend.shape[-1],\n        v_extend.shape[-1],\n    )\n\n    if Lq == 576:\n        BLOCK_DMODEL = 512\n        
BLOCK_DPE = 64\n    elif Lq == 288:\n        BLOCK_DMODEL = 256\n        BLOCK_DPE = 32\n    elif Lq == 192:\n        BLOCK_DMODEL = 128\n        BLOCK_DPE = 64\n    else:\n        BLOCK_DMODEL = triton.next_power_of_2(Lq)\n        BLOCK_DPE = 0\n    BLOCK_DV = triton.next_power_of_2(Lv)\n\n    if _is_hip:\n        BLOCK_M, BLOCK_N = (64, 64)\n        num_warps = 4\n\n    else:\n        if _is_cuda and CUDA_CAPABILITY[0] >= 9:\n            if Lq <= 256:\n                BLOCK_M, BLOCK_N = (128, 64)\n            else:\n                BLOCK_M, BLOCK_N = (32, 64)\n        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:\n            if Lq <= 128:\n                BLOCK_M, BLOCK_N = (128, 128)\n            elif Lq <= 256:\n                BLOCK_M, BLOCK_N = (64, 64)\n            else:\n                BLOCK_M, BLOCK_N = (32, 64)\n        else:\n            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)\n\n        num_warps = 4 if Lk <= 64 else 8\n\n    sm_scale = sm_scale or 1.0 / (Lq**0.5)\n    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]\n    kv_group_num = q_extend.shape[1] // k_extend.shape[1]\n\n    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))\n    num_stages = 1\n\n    extra_kargs = {}\n    if _is_hip:\n        extra_kargs = {\"waves_per_eu\": 4, \"matrix_instr_nonkdim\": 16, \"kpack\": 2}\n\n    _fwd_kernel[grid](\n        q_extend,\n        k_extend,\n        v_extend,\n        o_extend,\n        k_buffer,\n        v_buffer,\n        req_to_tokens,\n        b_req_idx,\n        b_seq_len,\n        b_start_loc_extend,\n        b_seq_len_extend,\n        sm_scale,\n        kv_group_num,\n        q_extend.stride(0),\n        q_extend.stride(1),\n        k_extend.stride(0),\n        k_extend.stride(1),\n        v_extend.stride(0),\n        v_extend.stride(1),\n        o_extend.stride(0),\n        o_extend.stride(1),\n        k_buffer.stride(0),\n        k_buffer.stride(1),\n        v_buffer.stride(0),\n        
v_buffer.stride(1),\n        req_to_tokens.stride(0),\n        logit_cap=logit_cap,\n        BLOCK_DMODEL=BLOCK_DMODEL,\n        BLOCK_DPE=BLOCK_DPE,\n        BLOCK_DV=BLOCK_DV,\n        BLOCK_M=BLOCK_M,\n        BLOCK_N=BLOCK_N,\n        Lq=Lq,\n        Lv=Lv,\n        num_warps=num_warps,\n        num_stages=num_stages,\n        **extra_kargs,\n    )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/triton_ops/extend_attention.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"\nMemory-efficient attention for prefill.\nIt supports page size = 1 and prefill with KV cache (i.e. extend).\n\"\"\"\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.attention.triton_ops.prefill_attention import (\n    context_attention_fwd,\n)\nfrom sglang.srt.utils import is_cuda, is_hip\n\n_is_cuda = is_cuda()\nif _is_cuda:\n    CUDA_CAPABILITY = torch.cuda.get_device_capability()\n\n_is_hip = is_hip()\n\n\ndef _get_block_sizes_for_extend_attention(Lq: int, Lv: int):\n    \"\"\"\n    Get block sizes and configuration for extend attention kernels.\n\n    Args:\n        Lq: Query head dimension\n        Lv: Value head dimension\n\n    Returns:\n        tuple: (BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps)\n    \"\"\"\n    # Determine BLOCK_DMODEL and BLOCK_DPE based on head dimension\n    if Lq == 576:\n        BLOCK_DMODEL = 512\n        BLOCK_DPE = 64\n    elif Lq == 288:\n        BLOCK_DMODEL = 256\n        BLOCK_DPE = 32\n    elif Lq == 192:\n        BLOCK_DMODEL = 128\n        BLOCK_DPE = 64\n    else:\n        BLOCK_DMODEL = triton.next_power_of_2(Lq)\n        BLOCK_DPE = 0\n\n    BLOCK_DV = triton.next_power_of_2(Lv)\n\n    # Determine BLOCK_M, BLOCK_N, and num_warps based on hardware\n    if _is_hip:\n        
BLOCK_M, BLOCK_N = (64, 64)\n        num_warps = 4\n    else:\n        if _is_cuda and CUDA_CAPABILITY[0] == 12:\n            # sm120 workstation Blackwell architecture (RTX Pro 6000) has a much smaller shared memory size (100K)\n            if Lq <= 128:\n                BLOCK_M, BLOCK_N = (64, 128)\n            elif Lq <= 256:\n                BLOCK_M, BLOCK_N = (64, 64)\n            else:\n                BLOCK_M, BLOCK_N = (32, 32)\n        elif _is_cuda and CUDA_CAPABILITY[0] >= 9:\n            # Hopper architecture (H100, etc.)\n            if Lq <= 256:\n                BLOCK_M, BLOCK_N = (128, 64)\n            else:\n                BLOCK_M, BLOCK_N = (32, 64)\n        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:\n            # Ampere architecture (A100, etc.)\n            # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)\n            if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:\n                if Lq <= 128:\n                    BLOCK_M, BLOCK_N = (64, 128)\n                elif Lq <= 256:\n                    BLOCK_M, BLOCK_N = (64, 64)\n                else:\n                    BLOCK_M, BLOCK_N = (32, 32)\n            else:\n                if Lq <= 128:\n                    BLOCK_M, BLOCK_N = (128, 128)\n                elif Lq <= 256:\n                    BLOCK_M, BLOCK_N = (64, 64)\n                else:\n                    BLOCK_M, BLOCK_N = (32, 64)\n        else:\n            # Older architectures\n            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)\n\n        num_warps = 4 if Lq <= 64 else 8\n\n    return BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps\n\n\n@triton.jit\ndef tanh(x):\n    # Tanh is just a scaled sigmoid\n    return 2 * tl.sigmoid(2 * x) - 1\n\n\n@triton.jit\ndef _copy_unified_indices_kernel(\n    # Input buffers\n    prefix_kv_indptr,\n    prefix_kv_indices,\n    extend_start_loc,\n    extend_seq_lens,\n    extend_kv_indices,\n    unified_kv_indptr,\n    # Output 
buffer\n    unified_kv_indices,\n    # Size\n    bs,\n):\n    \"\"\"\n    Triton kernel to copy indices to unified buffer (parallel per sequence).\n    Each thread block processes one sequence with vectorized loads/stores.\n    \"\"\"\n    pid = tl.program_id(0)\n\n    if pid >= bs:\n        return\n\n    # Load sequence info\n    prefix_start = tl.load(prefix_kv_indptr + pid)\n    prefix_end = tl.load(prefix_kv_indptr + pid + 1)\n    extend_start = tl.load(extend_start_loc + pid)\n    extend_len = tl.load(extend_seq_lens + pid)\n\n    prefix_len = prefix_end - prefix_start\n    unified_start = tl.load(unified_kv_indptr + pid)\n\n    # Copy indices in vectorized chunks\n    BLOCK_SIZE: tl.constexpr = 128\n\n    # Process prefix indices\n    for block_start in range(0, prefix_len, BLOCK_SIZE):\n        offs = block_start + tl.arange(0, BLOCK_SIZE)\n        mask = offs < prefix_len\n\n        src_idx = prefix_start + offs\n        dst_idx = unified_start + offs\n\n        vals = tl.load(prefix_kv_indices + src_idx, mask=mask, other=0)\n        tl.store(unified_kv_indices + dst_idx, vals, mask=mask)\n\n    # Process extend indices\n    for block_start in range(0, extend_len, BLOCK_SIZE):\n        offs = block_start + tl.arange(0, BLOCK_SIZE)\n        mask = offs < extend_len\n\n        src_idx = extend_start + offs\n        dst_idx = unified_start + prefix_len + offs\n\n        vals = tl.load(extend_kv_indices + src_idx, mask=mask, other=0)\n        tl.store(unified_kv_indices + dst_idx, vals, mask=mask)\n\n\ndef build_unified_kv_indices(\n    prefix_kv_indptr: torch.Tensor,\n    prefix_kv_indices: torch.Tensor,\n    extend_start_loc: torch.Tensor,\n    extend_seq_lens: torch.Tensor,\n    extend_kv_indices: torch.Tensor,\n    bs: int,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Build unified KV indices efficiently:\n    - Use PyTorch's optimized cumsum (NVIDIA CUB) for indptr\n    - Use Triton kernel for parallel index copying\n\n    
Returns:\n        (unified_kv_indptr, unified_kv_indices, prefix_lens)\n    \"\"\"\n    device = prefix_kv_indptr.device\n\n    prefix_lens = prefix_kv_indptr[1 : bs + 1] - prefix_kv_indptr[:bs]\n\n    # Create unified_kv_indptr avoiding direct assignment (for CUDA graph compatibility)\n    unified_lens = prefix_lens + extend_seq_lens[:bs]\n    unified_kv_indptr = torch.cat(\n        [\n            torch.zeros(1, dtype=torch.int32, device=device),\n            torch.cumsum(unified_lens, dim=0),\n        ]\n    )\n\n    max_unified_len = len(prefix_kv_indices) + len(extend_kv_indices)\n\n    unified_kv_indices = torch.empty(max_unified_len, dtype=torch.int64, device=device)\n\n    # Launch Triton kernel for parallel index copying\n    _copy_unified_indices_kernel[(bs,)](\n        prefix_kv_indptr,\n        prefix_kv_indices,\n        extend_start_loc,\n        extend_seq_lens,\n        extend_kv_indices,\n        unified_kv_indptr,\n        unified_kv_indices,\n        bs,\n    )\n\n    return unified_kv_indptr, unified_kv_indices, prefix_lens\n\n\n@triton.jit\ndef _fwd_kernel(\n    Q_Extend,\n    K_Extend,\n    V_Extend,\n    O_Extend,\n    K_Buffer,\n    V_Buffer,\n    qo_indptr,\n    kv_indptr,\n    kv_indices,\n    mask_ptr,\n    mask_indptr,\n    sink_ptr,\n    window_kv_offset_ptr,\n    sm_scale,\n    k_scale,\n    v_scale,\n    kv_group_num,\n    stride_qbs,\n    stride_qh,\n    stride_kbs,\n    stride_kh,\n    stride_vbs,\n    stride_vh,\n    stride_obs,\n    stride_oh,\n    stride_buf_kbs,\n    stride_buf_kh,\n    stride_buf_vbs,\n    stride_buf_vh,\n    SLIDING_WINDOW_SIZE: tl.constexpr,\n    logit_cap: tl.constexpr,\n    xai_temperature_len: tl.constexpr,\n    Lq: tl.constexpr,\n    Lv: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_DPE: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    USE_CUSTOM_MASK: tl.constexpr,\n    IS_CAUSAL: tl.constexpr,\n    SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,\n 
   STORE_TRANSPOSE: tl.constexpr,\n    HAS_SINK: tl.constexpr,\n):\n    cur_seq = tl.program_id(0)\n    cur_head = tl.program_id(1)\n    cur_block_m = tl.program_id(2)\n    cur_kv_head = cur_head // kv_group_num\n\n    cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)\n    cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx\n    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)\n    cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx\n    cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend\n\n    if USE_CUSTOM_MASK:\n        cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)\n\n    # For SWA, we should only load the mask in the sliding window\n    window_kv_offset = 0\n    if USE_CUSTOM_MASK and SLIDING_WINDOW_SIZE > 0:\n        window_kv_offset = tl.load(window_kv_offset_ptr + cur_seq)\n\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    offs_dv = tl.arange(0, BLOCK_DV)\n    offs_m = tl.arange(0, BLOCK_M)\n    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend\n\n    mask_d = offs_d < Lq\n    mask_dv = offs_dv < Lv\n\n    if xai_temperature_len > 0:\n        offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m\n        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))\n        xai_temperature_reg = tl.where(\n            offs_qidx > xai_temperature_len,\n            tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,\n            1.0,\n        )\n\n    offs_q = (\n        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])\n        * stride_qbs\n        + cur_head * stride_qh\n        + offs_d[None, :]\n    )\n    q = tl.load(\n        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0\n    )\n\n    if BLOCK_DPE > 0:\n        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)\n        offs_qpe = (\n            (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])\n            * stride_qbs\n            + 
cur_head * stride_qh\n            + offs_dpe[None, :]\n        )\n        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)\n\n    # stage 1: compute scores with prefix\n    offs_n = tl.arange(0, BLOCK_N)\n\n    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)\n    deno = tl.zeros([BLOCK_M], dtype=tl.float32)\n    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n\n    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        mask_n = (start_n + offs_n) < cur_seq_len_prefix\n\n        final_mask = mask_m[:, None] & mask_n[None, :]\n        if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:\n            custom_mask = tl.load(\n                mask_ptr\n                + cur_seq_mask_start_idx\n                + (cur_block_m * BLOCK_M + offs_m[:, None])\n                * (cur_seq_len + window_kv_offset)\n                + window_kv_offset\n                + start_n\n                + offs_n[None, :],\n                mask=(mask_m[:, None] & mask_n[None, :]),\n                other=0,\n            )\n            final_mask &= custom_mask\n        if SLIDING_WINDOW_SIZE > 0:\n            # Add mask where q_id <= kv_id + sliding_window_size\n            # q_id = prefix_len + cur_m, kv_id = cur_n\n            window_mask = (\n                cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]\n            ) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)\n            final_mask &= window_mask\n\n        SKIP_TILE = False\n        if (USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK) or SLIDING_WINDOW_SIZE > 0:\n            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0\n\n        if not SKIP_TILE:\n            offs_kv_loc = tl.load(\n                kv_indices + cur_seq_kv_start_idx + start_n + offs_n,\n                mask=mask_n,\n                other=0,\n            )\n\n            # load k in transposed way\n            offs_buf_k = 
(\n                offs_kv_loc[None, :] * stride_buf_kbs\n                + cur_kv_head * stride_buf_kh\n                + offs_d[:, None]\n            )\n            k = tl.load(\n                K_Buffer + offs_buf_k,\n                mask=(mask_n[None, :]) & (mask_d[:, None]),\n                other=0.0,\n            )\n\n            qk = tl.dot(q.to(k.dtype), k)\n            if BLOCK_DPE > 0:\n                offs_kpe = (\n                    offs_kv_loc[None, :] * stride_buf_kbs\n                    + cur_kv_head * stride_buf_kh\n                    + offs_dpe[:, None]\n                )\n                kpe = tl.load(\n                    K_Buffer + offs_kpe,\n                    mask=mask_n[None, :],\n                    other=0.0,\n                )\n                qk += tl.dot(qpe.to(kpe.dtype), kpe)\n            qk *= sm_scale * k_scale\n\n            if logit_cap > 0:\n                qk = logit_cap * tanh(qk / logit_cap)\n\n            if xai_temperature_len > 0:\n                qk *= xai_temperature_reg[:, None]\n\n            qk = tl.where(final_mask, qk, float(\"-inf\"))\n\n            row_max = tl.max(qk, 1)\n            row_max_fixed = tl.where(row_max == float(\"-inf\"), -1e20, row_max)\n            n_e_max = tl.maximum(row_max_fixed, e_max)\n\n            re_scale = tl.exp(e_max - n_e_max)\n            p = tl.exp(qk - n_e_max[:, None])\n            deno = deno * re_scale + tl.sum(p, 1)\n\n            offs_buf_v = (\n                offs_kv_loc[:, None] * stride_buf_vbs\n                + cur_kv_head * stride_buf_vh\n                + offs_dv[None, :]\n            )\n            v = tl.load(\n                V_Buffer + offs_buf_v,\n                mask=mask_n[:, None] & mask_dv[None, :],\n                other=0.0,\n            )\n            p = p.to(v.dtype)\n            acc = acc * re_scale[:, None] + tl.dot(p, v) * v_scale\n\n            e_max = n_e_max\n\n    # stage 2: compute the triangle part\n\n    cur_block_m_end = (\n        
cur_seq_len_extend\n        if not IS_CAUSAL\n        else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)\n    )\n    for start_n in range(0, cur_block_m_end, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        mask_n = (start_n + offs_n) < cur_block_m_end\n\n        final_mask = mask_m[:, None] & mask_n[None, :]\n        if USE_CUSTOM_MASK:\n            custom_mask = tl.load(\n                mask_ptr\n                + cur_seq_mask_start_idx\n                + (cur_block_m * BLOCK_M + offs_m[:, None])\n                * (cur_seq_len + window_kv_offset)\n                + window_kv_offset\n                + cur_seq_len_prefix\n                + start_n\n                + offs_n[None, :],\n                mask=(mask_m[:, None] & mask_n[None, :]),\n                other=0,\n            )\n            custom_mask &= mask_m[:, None] & mask_n[None, :]\n            final_mask &= custom_mask\n        elif IS_CAUSAL:\n            mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (\n                start_n + offs_n[None, :]\n            )\n            mask_causual &= mask_m[:, None] & mask_n[None, :]\n            final_mask &= mask_causual\n        else:\n            mask_non_causal = mask_m[:, None] & mask_n[None, :]\n            final_mask &= mask_non_causal\n\n        if SLIDING_WINDOW_SIZE > 0:\n            # Add mask where q_id <= kv_id + sliding_window_size\n            window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (\n                start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE\n            )\n            final_mask &= window_mask\n\n        SKIP_TILE = False\n        if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:\n            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0\n\n        if not SKIP_TILE:\n            # load k in transposed way\n            offs_k = (\n                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs\n                + cur_kv_head * 
stride_kh\n                + offs_d[:, None]\n            )\n            k = tl.load(\n                K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0\n            )\n\n            qk = tl.dot(q, k, out_dtype=tl.float32)\n            if BLOCK_DPE > 0:\n                offs_kpe = (\n                    (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs\n                    + cur_kv_head * stride_kh\n                    + offs_dpe[:, None]\n                )\n                kpe = tl.load(\n                    K_Extend + offs_kpe,\n                    mask=mask_n[None, :],\n                    other=0.0,\n                )\n                qk += tl.dot(qpe, kpe)\n\n            qk *= sm_scale\n\n            if logit_cap > 0:\n                qk = logit_cap * tanh(qk / logit_cap)\n\n            if xai_temperature_len > 0:\n                qk *= xai_temperature_reg[:, None]\n\n            qk = tl.where(final_mask, qk, float(\"-inf\"))\n\n            row_max = tl.max(qk, 1)\n            row_max_fixed = tl.where(row_max == float(\"-inf\"), -1e20, row_max)\n            n_e_max = tl.maximum(row_max_fixed, e_max)\n\n            re_scale = tl.exp(e_max - n_e_max)\n            p = tl.exp(qk - n_e_max[:, None])\n            deno = deno * re_scale + tl.sum(p, 1)\n\n            offs_v = (\n                (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs\n                + cur_kv_head * stride_vh\n                + offs_dv[None, :]\n            )\n            v = tl.load(\n                V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0\n            )\n            p = p.to(v.dtype)\n            acc = acc * re_scale[:, None] + tl.dot(p, v)\n\n            e_max = n_e_max\n\n    if HAS_SINK:\n        cur_sink = tl.load(sink_ptr + cur_head)\n        deno += tl.exp(cur_sink - e_max)\n\n    offs_o = (\n        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])\n        * stride_obs\n      
  + cur_head * stride_oh\n        + offs_dv[None, :]\n    )\n    if STORE_TRANSPOSE:\n        tl.store(\n            O_Extend + offs_o.T,\n            (acc / deno[:, None]).T,\n            mask=(mask_m[:, None] & mask_dv[None, :]).T,\n        )\n    else:\n        tl.store(\n            O_Extend + offs_o,\n            acc / deno[:, None],\n            mask=mask_m[:, None] & mask_dv[None, :],\n        )\n\n\ndef extend_attention_fwd(\n    q_extend,\n    k_extend,\n    v_extend,\n    o_extend,\n    k_buffer,\n    v_buffer,\n    qo_indptr,\n    kv_indptr,\n    kv_indices,\n    custom_mask,\n    is_causal,\n    mask_indptr,\n    max_len_extend,\n    k_scale,\n    v_scale,\n    sm_scale=None,\n    logit_cap=0.0,\n    skip_prefix_custom_mask=True,\n    sliding_window_size=-1,\n    sinks=None,\n    window_kv_offsets=None,\n    xai_temperature_len=-1,\n):\n    \"\"\"\n    q_extend, k_extend, v_extend, o_extend: contiguous tensors\n\n    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager\n    \"\"\"\n    Lq, Lk, Lv = (\n        q_extend.shape[-1],\n        k_extend.shape[-1],\n        v_extend.shape[-1],\n    )\n\n    # Get block sizes and configuration\n    BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (\n        _get_block_sizes_for_extend_attention(Lq, Lv)\n    )\n\n    sm_scale = sm_scale or 1.0 / (Lq**0.5)\n    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]\n    kv_group_num = q_extend.shape[1] // k_extend.shape[1]\n\n    USE_CUSTOM_MASK = custom_mask is not None\n    # Skip custom mask for prefix part\n    SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask\n\n    HAS_SINK = sinks is not None\n\n    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))\n    num_stages = 1\n\n    extra_kargs = {}\n    if _is_hip:\n        extra_kargs = {\"waves_per_eu\": 1, \"matrix_instr_nonkdim\": 16, \"kpack\": 2}\n\n    _fwd_kernel[grid](\n        q_extend,\n        k_extend,\n        v_extend,\n        o_extend,\n        
k_buffer,\n        v_buffer,\n        qo_indptr,\n        kv_indptr,\n        kv_indices,\n        custom_mask,\n        mask_indptr,\n        sinks,\n        window_kv_offsets,\n        sm_scale,\n        k_scale,\n        v_scale,\n        kv_group_num,\n        q_extend.stride(0),\n        q_extend.stride(1),\n        k_extend.stride(0),\n        k_extend.stride(1),\n        v_extend.stride(0),\n        v_extend.stride(1),\n        o_extend.stride(0),\n        o_extend.stride(1),\n        k_buffer.stride(0),\n        k_buffer.stride(1),\n        v_buffer.stride(0),\n        v_buffer.stride(1),\n        SLIDING_WINDOW_SIZE=sliding_window_size,\n        logit_cap=logit_cap,\n        xai_temperature_len=xai_temperature_len,\n        BLOCK_DMODEL=BLOCK_DMODEL,\n        BLOCK_DPE=BLOCK_DPE,\n        BLOCK_DV=BLOCK_DV,\n        BLOCK_M=BLOCK_M,\n        BLOCK_N=BLOCK_N,\n        Lq=Lq,\n        Lv=Lv,\n        USE_CUSTOM_MASK=USE_CUSTOM_MASK,\n        IS_CAUSAL=is_causal,\n        SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,\n        HAS_SINK=HAS_SINK,\n        STORE_TRANSPOSE=_is_hip,\n        num_warps=num_warps,\n        num_stages=num_stages,\n        **extra_kargs,\n    )\n\n\ndef redundant_attention(\n    q_extend,\n    o_extend,\n    k_buffer,\n    v_buffer,\n    b_req_idx,\n    b_start_loc,\n    b_seq_len,\n    b_seq_len_prefix,\n    max_len_in_batch,\n):\n    total_token_num = k_buffer.shape[0]\n    B, H_Q, D = b_req_idx.shape[0], q_extend.shape[-2], q_extend.shape[-1]\n    q_buffer = torch.empty(\n        (total_token_num, H_Q, D), dtype=q_extend.dtype, device=q_extend.device\n    )\n\n    pt = 0\n    for i in range(B):\n        cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i]\n        pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]\n        q_buffer[pl:pr] = q_extend[pt : pt + cur_seq_len_extend]\n        pt += cur_seq_len_extend\n\n    o_buffer = torch.empty_like(q_buffer)\n    context_attention_fwd(\n        
q_buffer, k_buffer, v_buffer, o_buffer, b_start_loc, b_seq_len, max_len_in_batch\n    )\n\n    pt = 0\n    for i in range(B):\n        cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i]\n        pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]\n        o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]\n        pt += cur_seq_len_extend\n\n\n@triton.jit\ndef _fwd_kernel_unified(\n    Q,\n    O,\n    K_Buffer,\n    V_Buffer,\n    qo_indptr,\n    kv_indptr,\n    kv_indices,\n    prefix_lens,\n    mask_ptr,\n    mask_indptr,\n    sink_ptr,\n    window_start_pos,\n    sm_scale_withk,\n    v_scale,\n    kv_group_num,\n    stride_qbs,\n    stride_qh,\n    stride_obs,\n    stride_oh,\n    stride_buf_kbs,\n    stride_buf_kh,\n    stride_buf_vbs,\n    stride_buf_vh,\n    SLIDING_WINDOW_SIZE: tl.constexpr,\n    logit_cap: tl.constexpr,\n    xai_temperature_len: tl.constexpr,\n    Lq: tl.constexpr,\n    Lv: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_DPE: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    IS_CAUSAL: tl.constexpr,\n    USE_CUSTOM_MASK: tl.constexpr,\n    HAS_SINK: tl.constexpr,\n):\n    \"\"\"\n    Unified 1-stage kernel for deterministic extend attention.\n    Both prefix and extend KV are accessed through the unified kv_indices.\n    \"\"\"\n    cur_seq = tl.program_id(0)\n    cur_head = tl.program_id(1)\n    cur_block_m = tl.program_id(2)\n    cur_kv_head = cur_head // kv_group_num\n\n    # Load sequence information\n    cur_seq_q_start_idx = tl.load(qo_indptr + cur_seq)\n    cur_seq_q_len = tl.load(qo_indptr + cur_seq + 1) - cur_seq_q_start_idx\n    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)\n    cur_seq_kv_len = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx\n    cur_seq_prefix_len = tl.load(prefix_lens + cur_seq)\n\n    # Load window start position for sliding window attention\n    # This is the absolute position of the first key in 
the window (0 if no sliding window)\n    cur_window_start = 0\n    if SLIDING_WINDOW_SIZE > 0:\n        cur_window_start = tl.load(window_start_pos + cur_seq)\n\n    # Load custom mask start index if using custom mask (for speculative decoding)\n    if USE_CUSTOM_MASK:\n        cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)\n\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    offs_dv = tl.arange(0, BLOCK_DV)\n    offs_m = tl.arange(0, BLOCK_M)\n    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_q_len\n    mask_d = offs_d < Lq\n    mask_dv = offs_dv < Lv\n\n    # XAI temperature handling\n    if xai_temperature_len > 0:\n        offs_qidx = cur_seq_prefix_len + cur_block_m * BLOCK_M + offs_m\n        xai_temperature_reg = tl.where(\n            offs_qidx < xai_temperature_len,\n            1.0,\n            xai_temperature_len / (offs_qidx + 1.0),\n        )\n\n    # Load Q\n    offs_q = (\n        (cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs\n        + cur_head * stride_qh\n        + offs_d[None, :]\n    )\n    q = tl.load(Q + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0)\n\n    if BLOCK_DPE > 0:\n        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)\n        offs_qpe = (\n            (cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs\n            + cur_head * stride_qh\n            + offs_dpe[None, :]\n        )\n        qpe = tl.load(Q + offs_qpe, mask=mask_m[:, None], other=0.0)\n\n    # Initialize accumulators\n    offs_n = tl.arange(0, BLOCK_N)\n    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)\n    deno = tl.zeros([BLOCK_M], dtype=tl.float32)\n    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n\n    # Unified loop: process all KV tokens (prefix + extend)\n    for start_n in range(0, cur_seq_kv_len, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        mask_n = (start_n + offs_n) < cur_seq_kv_len\n\n        # Compute mask\n        
final_mask = mask_m[:, None] & mask_n[None, :]\n\n        # Apply custom mask if provided\n        if USE_CUSTOM_MASK:\n            custom_mask = tl.load(\n                mask_ptr\n                + cur_seq_mask_start_idx\n                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_kv_len\n                + start_n\n                + offs_n[None, :],\n                mask=(mask_m[:, None] & mask_n[None, :]),\n                other=0,\n            )\n            final_mask &= custom_mask\n\n        # Apply causal mask for extend part\n        if IS_CAUSAL and not USE_CUSTOM_MASK:\n            # Determine if current KV block is in extend region\n            # Only apply causal mask when both Q and K are in extend region\n            q_idx = cur_block_m * BLOCK_M + offs_m[:, None]\n            k_idx_in_total = start_n + offs_n[None, :]\n\n            # Causal mask: q_idx >= (k_idx - prefix_len) when k_idx >= prefix_len\n            # For prefix region (k_idx < prefix_len), no causal mask\n            k_is_extend = k_idx_in_total >= cur_seq_prefix_len\n            k_idx_in_extend = k_idx_in_total - cur_seq_prefix_len\n            causal_mask = tl.where(\n                k_is_extend,\n                q_idx >= k_idx_in_extend,\n                True,  # No causal mask for prefix\n            )\n            final_mask &= causal_mask\n\n        if SLIDING_WINDOW_SIZE > 0:\n            # Sliding window mask with correct absolute positions\n            # Q absolute position: window_start + prefix_len + q_position_in_extend\n            q_abs_pos = (\n                cur_window_start\n                + cur_seq_prefix_len\n                + cur_block_m * BLOCK_M\n                + offs_m[:, None]\n            )\n\n            # K absolute position: window_start + k_index_in_unified_array\n            k_abs_pos = cur_window_start + start_n + offs_n[None, :]\n\n            # Sliding window: query can attend to keys within window_size\n            window_mask = 
q_abs_pos <= (k_abs_pos + SLIDING_WINDOW_SIZE)\n            final_mask &= window_mask\n\n        # Check if we can skip this tile\n        SKIP_TILE = False\n        if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:\n            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0\n\n        if not SKIP_TILE:\n            # Load KV indices\n            offs_kv_loc = tl.load(\n                kv_indices + cur_seq_kv_start_idx + start_n + offs_n,\n                mask=mask_n,\n                other=0,\n            )\n\n            # Load K\n            offs_buf_k = (\n                offs_kv_loc[None, :] * stride_buf_kbs\n                + cur_kv_head * stride_buf_kh\n                + offs_d[:, None]\n            )\n            k = tl.load(\n                K_Buffer + offs_buf_k,\n                mask=(mask_n[None, :]) & (mask_d[:, None]),\n                other=0.0,\n            )\n\n            # Compute QK\n            qk = tl.dot(q.to(k.dtype), k)\n            if BLOCK_DPE > 0:\n                offs_kpe = (\n                    offs_kv_loc[None, :] * stride_buf_kbs\n                    + cur_kv_head * stride_buf_kh\n                    + offs_dpe[:, None]\n                )\n                kpe = tl.load(\n                    K_Buffer + offs_kpe,\n                    mask=mask_n[None, :],\n                    other=0.0,\n                )\n                qk += tl.dot(qpe.to(kpe.dtype), kpe)\n\n            qk *= sm_scale_withk\n\n            if logit_cap > 0:\n                qk = logit_cap * tanh(qk / logit_cap)\n\n            if xai_temperature_len > 0:\n                qk *= xai_temperature_reg[:, None]\n\n            qk = tl.where(final_mask, qk, float(\"-inf\"))\n\n            # Online softmax\n            row_max = tl.max(qk, 1)\n            row_max_fixed = tl.where(row_max == float(\"-inf\"), -1e20, row_max)\n            n_e_max = tl.maximum(row_max_fixed, e_max)\n\n            re_scale = tl.exp(e_max - n_e_max)\n            p = tl.exp(qk 
- n_e_max[:, None])\n            deno = deno * re_scale + tl.sum(p, 1)\n\n            # Load V\n            offs_buf_v = (\n                offs_kv_loc[:, None] * stride_buf_vbs\n                + cur_kv_head * stride_buf_vh\n                + offs_dv[None, :]\n            )\n            v = tl.load(\n                V_Buffer + offs_buf_v,\n                mask=mask_n[:, None] & mask_dv[None, :],\n                other=0.0,\n            )\n            p = p.to(v.dtype)\n            acc = acc * re_scale[:, None] + tl.dot(p, v)\n\n            e_max = n_e_max\n\n    # Handle sink tokens\n    if HAS_SINK:\n        cur_sink = tl.load(sink_ptr + cur_head)\n        deno += tl.exp(cur_sink - e_max)\n\n    # Store output\n    offs_o = (\n        (cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_obs\n        + cur_head * stride_oh\n        + offs_dv[None, :]\n    )\n    tl.store(\n        O + offs_o,\n        acc / deno[:, None] * v_scale,\n        mask=mask_m[:, None] & mask_dv[None, :],\n    )\n\n\ndef extend_attention_fwd_unified(\n    q,\n    o,\n    k_buffer,\n    v_buffer,\n    k_scale,\n    v_scale,\n    qo_indptr,\n    kv_indptr,\n    kv_indices,\n    prefix_lens,\n    max_len_extend,\n    custom_mask=None,\n    mask_indptr=None,\n    sm_scale=None,\n    logit_cap=0.0,\n    is_causal=True,\n    sliding_window_size=-1,\n    sinks=None,\n    window_start_pos=None,\n    xai_temperature_len=-1,\n):\n    \"\"\"\n    Unified 1-stage extend attention for deterministic inference.\n\n    Args:\n        q: Query tensor [num_tokens, num_heads, head_dim]\n        o: Output tensor [num_tokens, num_heads, head_dim]\n        k_buffer: Key cache buffer\n        v_buffer: Value cache buffer\n        qo_indptr: Query offsets [batch_size + 1]\n        kv_indptr: KV offsets [batch_size + 1] (includes both prefix and extend)\n        kv_indices: Unified KV indices (both prefix and extend)\n        prefix_lens: Prefix length for each sequence [batch_size]\n        
max_len_extend: Maximum extend length\n        custom_mask: Custom attention mask (for speculative decoding tree attention)\n        mask_indptr: Mask offsets [batch_size + 1]\n        sm_scale: Softmax scale\n        logit_cap: Logit capping value\n        is_causal: Whether to apply causal mask\n        sliding_window_size: Sliding window size (-1 for no sliding window)\n        sinks: Sink tokens\n        window_start_pos: Absolute position of first key in sliding window [batch_size]\n                         (None if sliding window not used)\n        xai_temperature_len: XAI temperature length\n    \"\"\"\n    Lq, Lv = q.shape[-1], v_buffer.shape[-1]\n\n    # Get block sizes and configuration\n    BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (\n        _get_block_sizes_for_extend_attention(Lq, Lv)\n    )\n\n    sm_scale = sm_scale or 1.0 / (Lq**0.5)\n    batch_size, head_num = qo_indptr.shape[0] - 1, q.shape[1]\n    kv_group_num = q.shape[1] // k_buffer.shape[1]\n\n    USE_CUSTOM_MASK = custom_mask is not None\n    HAS_SINK = sinks is not None\n\n    # For sliding window attention, window_start_pos tracks the absolute position\n    # of the first key in each sequence's window\n    if sliding_window_size > 0 and window_start_pos is None:\n        # If not provided, assume window starts at position 0\n        window_start_pos = torch.zeros(batch_size, dtype=torch.int32, device=q.device)\n\n    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))\n    num_stages = 1\n\n    extra_kargs = {}\n    if _is_hip:\n        extra_kargs = {\"waves_per_eu\": 1, \"matrix_instr_nonkdim\": 16, \"kpack\": 2}\n\n    _fwd_kernel_unified[grid](\n        q,\n        o,\n        k_buffer,\n        v_buffer,\n        qo_indptr,\n        kv_indptr,\n        kv_indices,\n        prefix_lens,\n        custom_mask,\n        mask_indptr,\n        sinks,\n        window_start_pos,\n        sm_scale * k_scale,\n        v_scale,\n        kv_group_num,\n      
  q.stride(0),\n        q.stride(1),\n        o.stride(0),\n        o.stride(1),\n        k_buffer.stride(0),\n        k_buffer.stride(1),\n        v_buffer.stride(0),\n        v_buffer.stride(1),\n        SLIDING_WINDOW_SIZE=sliding_window_size,\n        logit_cap=logit_cap,\n        xai_temperature_len=xai_temperature_len,\n        BLOCK_DMODEL=BLOCK_DMODEL,\n        BLOCK_DPE=BLOCK_DPE,\n        BLOCK_DV=BLOCK_DV,\n        BLOCK_M=BLOCK_M,\n        BLOCK_N=BLOCK_N,\n        Lq=Lq,\n        Lv=Lv,\n        IS_CAUSAL=is_causal,\n        USE_CUSTOM_MASK=USE_CUSTOM_MASK,\n        HAS_SINK=HAS_SINK,\n        num_warps=num_warps,\n        num_stages=num_stages,\n        **extra_kargs,\n    )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/triton_ops/merge_state.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef merge_state_kernel(\n    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged\n    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged\n    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a\n    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a\n    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b\n    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b\n    HEAD_SIZE: tl.constexpr,\n    PADDED_HEAD_SIZE: tl.constexpr,\n    OUTPUT_LSE: tl.constexpr,\n):\n    token_idx = tl.program_id(0)\n    num_tokens = tl.num_programs(0)\n    head_idx = tl.program_id(1)\n    num_heads = tl.num_programs(1)\n\n    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)\n    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)\n    p_lse = float(\"-inf\") if p_lse == float(\"inf\") else p_lse\n    s_lse = float(\"-inf\") if s_lse == float(\"inf\") else s_lse\n\n    max_lse = tl.maximum(p_lse, s_lse)\n    p_lse = p_lse - max_lse\n    s_lse = s_lse - max_lse\n    out_se = tl.exp(p_lse) + tl.exp(s_lse)\n\n    if OUTPUT_LSE:\n        out_lse = tl.log(out_se) + max_lse\n        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)\n\n    head_arange = tl.arange(0, PADDED_HEAD_SIZE)\n    head_mask = head_arange < HEAD_SIZE\n    p_out = tl.load(\n        prefix_output\n        + token_idx * num_heads * HEAD_SIZE\n        + head_idx * HEAD_SIZE\n        + head_arange,\n        mask=head_mask,\n    )\n    s_out = tl.load(\n        suffix_output\n        + token_idx * num_heads * HEAD_SIZE\n        + head_idx * HEAD_SIZE\n        + head_arange,\n        mask=head_mask,\n    )\n\n    p_scale = tl.exp(p_lse) / out_se\n    s_scale = tl.exp(s_lse) / out_se\n    out = p_out * p_scale + s_out * s_scale\n    tl.store(\n        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,\n        out,\n        
mask=head_mask,\n    )\n\n\ndef merge_state_triton(\n    prefix_output: torch.Tensor,\n    prefix_lse: torch.Tensor,\n    suffix_output: torch.Tensor,\n    suffix_lse: torch.Tensor,\n    output: Optional[torch.Tensor] = None,\n    output_lse: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n    # Avoid creating new tensors if they are already provided\n    if output is None:\n        output = torch.empty_like(prefix_output)\n    if output_lse is None:\n        output_lse = torch.empty_like(prefix_lse)\n\n    num_tokens = output.shape[0]\n    num_query_heads = output.shape[1]\n    head_size = output.shape[2]\n    padded_head_size = triton.next_power_of_2(head_size)\n\n    merge_state_kernel[(num_tokens, num_query_heads)](\n        output,\n        output_lse,\n        prefix_output,\n        prefix_lse,\n        suffix_output,\n        suffix_lse,\n        head_size,\n        padded_head_size,\n        output_lse is not None,\n    )\n    return output, output_lse\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/triton_ops/prefill_attention.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"\nMemory-efficient attention for prefill.\nIt supports page size = 1.\n\"\"\"\n\n# Adapted from\n# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.utils import is_cuda, is_hip\n\n_is_cuda = is_cuda()\n_is_hip = is_hip()\n\nif _is_cuda or _is_hip:\n    CUDA_CAPABILITY = torch.cuda.get_device_capability()\n\n\n@triton.jit\ndef _fwd_kernel(\n    Q,\n    K,\n    V,\n    sm_scale,\n    B_Start_Loc,\n    B_Seqlen,\n    Out,\n    stride_qbs,\n    stride_qh,\n    stride_kbs,\n    stride_kh,\n    stride_vbs,\n    stride_vh,\n    stride_obs,\n    stride_oh,\n    kv_group_num: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    IS_CAUSAL: tl.constexpr,\n    Lk: tl.constexpr,\n):\n    cur_batch = tl.program_id(0)\n    cur_head = tl.program_id(1)\n    start_m = tl.program_id(2)\n\n    cur_kv_head = cur_head // kv_group_num\n\n    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n    block_start_loc = BLOCK_M * start_m\n\n    # initialize offsets\n    offs_n = tl.arange(0, BLOCK_N)\n    
offs_d = tl.arange(0, BLOCK_DMODEL)\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    off_q = (\n        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n        + cur_head * stride_qh\n        + offs_d[None, :]\n    )\n    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]\n    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]\n\n    mask_d = offs_d < Lk\n\n    q = tl.load(\n        Q + off_q,\n        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),\n        other=0.0,\n    )\n\n    k_ptrs = K + off_k\n    v_ptrs = V + off_v\n\n    # initialize pointer to m and l\n    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n    end_n = (\n        cur_batch_seq_len\n        if not IS_CAUSAL\n        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)\n    )\n    for start_n in range(0, block_mask * end_n, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        # -- compute qk ----\n        k = tl.load(\n            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),\n            other=0.0,\n        )\n        # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)\n\n        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n        qk += tl.dot(q, k)\n        qk *= sm_scale\n\n        if IS_CAUSAL:\n            qk += tl.where(\n                (start_n + offs_n[None, :] < cur_batch_seq_len)\n                & (offs_m[:, None] >= (start_n + offs_n[None, :])),\n                0,\n                float(\"-inf\"),\n            )\n        else:\n            qk += tl.where(\n                (start_n + offs_n[None, :]) < 
cur_batch_seq_len, 0, float(\"-inf\")\n            )\n\n        # -- compute m_ij, p, l_ij\n        m_ij = tl.max(qk, 1)\n        p = tl.exp(qk - m_ij[:, None])\n        l_ij = tl.sum(p, 1)\n        # -- update m_i and l_i\n        m_i_new = tl.maximum(m_i, m_ij)\n        alpha = tl.exp(m_i - m_i_new)\n        beta = tl.exp(m_ij - m_i_new)\n        l_i_new = alpha * l_i + beta * l_ij\n        # -- update output accumulator --\n        # scale p\n        p_scale = beta / l_i_new\n        p = p * p_scale[:, None]\n        # scale acc\n        acc_scale = l_i / l_i_new * alpha\n        acc = acc * acc_scale[:, None]\n        # update acc\n        v = tl.load(\n            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),\n            other=0.0,\n        )\n\n        p = p.to(v.dtype)\n        acc += tl.dot(p, v)\n        # update m_i and l_i\n        l_i = l_i_new\n        m_i = m_i_new\n    # initialize pointers to output\n    off_o = (\n        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n        + cur_head * stride_oh\n        + offs_d[None, :]\n    )\n    out_ptrs = Out + off_o\n    tl.store(\n        out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])\n    )\n\n\ndef context_attention_fwd(\n    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True\n):\n    \"\"\"\n    q, k, v: [b * s, head, head_dim]\n    b_start_loc: [b]\n    b_seq_len: [b]\n    out: [b * s, head, head_dim]\n    \"\"\"\n    if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:\n        BLOCK = 128\n    else:\n        BLOCK = 64\n\n    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n\n    sm_scale = 1.0 / (Lq**0.5)\n    batch, head = b_seq_len.shape[0], q.shape[1]\n    kv_group_num = q.shape[1] // k.shape[1]\n\n    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n    num_warps = 4 if Lk <= 64 else 8\n\n    _fwd_kernel[grid](\n    
    q,\n        k,\n        v,\n        sm_scale,\n        b_start_loc,\n        b_seq_len,\n        o,\n        q.stride(0),\n        q.stride(1),\n        k.stride(0),\n        k.stride(1),\n        v.stride(0),\n        v.stride(1),\n        o.stride(0),\n        o.stride(1),\n        kv_group_num=kv_group_num,\n        BLOCK_M=BLOCK,\n        BLOCK_DMODEL=triton.next_power_of_2(Lk),\n        BLOCK_N=BLOCK,\n        IS_CAUSAL=is_causal,\n        num_warps=num_warps,\n        num_stages=1,\n        Lk=Lk,\n    )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"\nMemory-efficient attention for decoding.\nIt supports page size = 1.\n\"\"\"\n\n# Adapted from\n# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py\n# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py\n\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.attention.triton_ops.decode_attention import (\n    _decode_softmax_reducev_fwd,\n)\n\n\ndef is_hip():\n    return triton.runtime.driver.active.get_current_target().backend == \"hip\"\n\n\n_is_hip = is_hip()\n\n\n@triton.jit\ndef tanh(x):\n    # Tanh is just a scaled sigmoid\n    return 2 * tl.sigmoid(2 * x) - 1\n\n\n@triton.jit\ndef _fwd_grouped_kernel_stage1_rope(\n    Q,  # Holds [Q_NOPE; Q_PE], b x h x (d+r)\n    K_Buffer,  # Holds [KV; K_PE], b*s x (c+r)\n    V_buffer,  # Holds [KV], b*s x (c)\n    cos_sin_cache,  # max_seq_len x (rotary_dim * 2)\n    positions,  # sequence positions\n    sm_scale,\n    kv_indptr,\n    kv_indices,\n    Att_Out,  # b x h x NUM_KV_SPLITS x (kv_lora_rank + 1)\n    k_pe_t_out,\n    stride_qb,\n    stride_qh,\n    stride_buf_kbs,\n    stride_buf_vbs,\n    stride_mid_ob,\n    
stride_mid_oh,\n    stride_mid_os,\n    stride_kpe_tokens_out_b,\n    stride_cos_sin_cache_s,\n    stride_positions_b,\n    rotary_dim: tl.constexpr,\n    kv_lora_rank: tl.constexpr,\n    qk_rope_head_dim: tl.constexpr,\n    kv_group_num: tl.constexpr,\n    q_head_num: tl.constexpr,\n    BLOCK_C: tl.constexpr,\n    BLOCK_R: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_H: tl.constexpr,\n    NUM_KV_SPLITS: tl.constexpr,\n    logit_cap: tl.constexpr,\n    USE_ROPE: tl.constexpr,\n    IS_NEOX_STYLE: tl.constexpr,\n):\n\n    cur_batch = tl.program_id(0)\n    cur_head_id = tl.program_id(1)\n    split_kv_id = tl.program_id(2)\n\n    if BLOCK_H < kv_group_num:\n        VALID_BLOCK_H: tl.constexpr = BLOCK_H\n    else:\n        VALID_BLOCK_H: tl.constexpr = kv_group_num\n    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)\n    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H\n    mask_h = mask_h & (cur_head < q_head_num)\n\n    offs_c = tl.arange(0, BLOCK_C)\n    offs_qk_r = tl.arange(kv_lora_rank, kv_lora_rank + BLOCK_R)  # to get the k_pe\n\n    off_q_pe = (\n        cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_qk_r[None, :]\n    )\n    offs_q = cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_c[None, :]\n\n    mask_c = offs_c < kv_lora_rank\n    mask_qk_r = offs_qk_r < (kv_lora_rank + qk_rope_head_dim)\n\n    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)\n    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx\n\n    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_c[None, :]), other=0.0)\n    q_pe = tl.load(\n        Q + off_q_pe, mask=(mask_h[:, None]) & (mask_qk_r[None, :]), other=0.0\n    )\n\n    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)\n    split_kv_start = kv_len_per_split * split_kv_id\n    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)\n\n    # apply rotary embedding for q_pe, and k_pe (last token per batch of 
K_PE)\n    LAST_SPLIT = split_kv_end == cur_batch_seq_len\n    k_pe_last_token = tl.zeros([BLOCK_R], dtype=q.dtype)\n\n    if USE_ROPE:\n        if IS_NEOX_STYLE:\n            # [BLOCK_ROTARY // 2, BLOCK_ROTARY // 2 + 1, BLOCK_ROTARY // 2 + 2, ..., 0, 1, 2, ..., BLOCK_ROTARY // 2 - 1, pass:]\n            offs_qk_rot_r = kv_lora_rank + (\n                (tl.arange(0, BLOCK_R) + (rotary_dim // 2)) % rotary_dim\n            )\n            # Which elements to flip\n            mask_rotate = tl.arange(0, BLOCK_R) < (rotary_dim // 2)\n            # [0 , 1, 2, ..., rotary_dim // 2 - 1, 0 , 1, 2, ..., rotary_dim // 2 - 1]\n            offs_rotary = tl.arange(0, BLOCK_R) % (rotary_dim // 2)\n        else:\n            # [1, 0, 3, 2, 5, 4, ..., BLOCK_R, BLOCK_R - 1]\n            offs_qk_rot_r = (\n                kv_lora_rank\n                + (((tl.arange(0, BLOCK_R) + 1) % 2) * 2)\n                - 1\n                + tl.arange(0, BLOCK_R)\n            )\n            mask_rotate = tl.arange(0, BLOCK_R) % 2 < 1\n            # [0, 0, 1, 1, ..., rotary_dim // 2 - 1, rotary_dim // 2 - 1]\n            offs_rotary = tl.arange(0, BLOCK_R) // 2\n\n        if qk_rope_head_dim > rotary_dim:\n            offs_qk_rot_r = tl.where(\n                tl.arange(0, BLOCK_R) < rotary_dim, offs_qk_rot_r, tl.arange(0, BLOCK_R)\n            )\n            offs_rotary = tl.where(\n                tl.arange(0, BLOCK_R) < rotary_dim, offs_rotary, tl.arange(0, BLOCK_R)\n            )\n\n        mask_rotary = tl.arange(0, BLOCK_R) < rotary_dim\n\n        pos = tl.load(positions + cur_batch * stride_positions_b)\n        cos = tl.load(\n            cos_sin_cache + pos * stride_cos_sin_cache_s + offs_rotary,\n            mask=mask_rotary,\n            other=1.0,\n        )\n        sin = tl.load(\n            cos_sin_cache\n            + pos * stride_cos_sin_cache_s\n            + offs_rotary\n            + rotary_dim // 2,\n            mask_rotary,\n            other=0.0,\n        )\n\n        
off_q_pe_rot = (\n            cur_batch * stride_qb\n            + cur_head[:, None] * stride_qh\n            + offs_qk_rot_r[None, :]\n        )\n        mask_qk_rot_r = offs_qk_rot_r < (kv_lora_rank + qk_rope_head_dim)\n\n        # 0, 2, 4,.... 1, 3, 5...\n        q_pe_rot = tl.load(\n            Q + off_q_pe_rot,\n            mask=(mask_h[:, None]) & (mask_qk_rot_r[None, :]),\n            other=0.0,\n        )\n        q_pe_rot = tl.where(mask_rotate[None, :], -q_pe_rot, q_pe_rot)\n\n        q_pe = q_pe * cos + q_pe_rot * sin\n\n        # we only apply to the last token in the K_PE\n        if LAST_SPLIT:\n            # debug assert\n            if (cur_batch == 0 and cur_head == 0) and split_kv_id < NUM_KV_SPLITS - 1:\n                tl.device_assert(False, \"Only last split should compute k_pe\")\n\n            kv_loc = tl.load(\n                kv_indices + cur_batch_kv_start_idx + cur_batch_seq_len - 1\n            )\n            offs_buf_k_pe_last_token = kv_loc * stride_buf_kbs + offs_qk_r\n            offs_buf_k_pe_rot_last_token = kv_loc * stride_buf_kbs + offs_qk_rot_r\n            k_pe_last_token = tl.load(K_Buffer + offs_buf_k_pe_last_token)\n\n            k_pe_rot_last_token = tl.load(K_Buffer + offs_buf_k_pe_rot_last_token)\n            k_pe_rot_last_token = tl.where(\n                mask_rotate, -k_pe_rot_last_token, k_pe_rot_last_token\n            )\n\n            k_pe_last_token = k_pe_last_token * cos + k_pe_rot_last_token * sin\n\n    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float(\"inf\")\n    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)\n    acc = tl.zeros([BLOCK_H, BLOCK_C], dtype=tl.float32)\n\n    if split_kv_end > split_kv_start:\n        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):\n            offs_n = start_n + tl.arange(0, BLOCK_N)\n            kv_loc = tl.load(\n                kv_indices + cur_batch_kv_start_idx + offs_n,\n                mask=offs_n < split_kv_end,\n                other=0,\n            
)\n\n            offs_buf_kv = kv_loc[None, :] * stride_buf_kbs + offs_c[:, None]\n            offs_buf_k_pe = kv_loc[None, :] * stride_buf_kbs + offs_qk_r[:, None]\n\n            k_pe = tl.load(\n                K_Buffer + offs_buf_k_pe,\n                mask=(offs_n[None, :] < split_kv_end) & (mask_qk_r[:, None]),\n                other=0.0,\n            )  # positional embedding part of keys\n\n            if (USE_ROPE and LAST_SPLIT) and start_n >= cur_batch_seq_len - BLOCK_N:\n                k_pe = tl.where(\n                    offs_n[None, :] != (split_kv_end - 1),\n                    k_pe,\n                    k_pe_last_token[:, None],\n                )\n\n            # (16, 64) x (64, 32)\n            # dot product of rope parts\n            qk = tl.dot(q_pe, k_pe.to(q_pe.dtype))\n\n            kv = tl.load(\n                K_Buffer + offs_buf_kv,\n                mask=(offs_n[None, :] < split_kv_end) & (mask_c[:, None]),\n                other=0.0,\n            )  # the shared latent tensor for keys and values\n\n            # (16, 512) x (512, 32)\n            # dot product of nope parts\n            qk += tl.dot(q, kv)\n\n            qk *= sm_scale\n\n            if logit_cap > 0:\n                qk = logit_cap * tanh(qk / logit_cap)\n\n            qk = tl.where(\n                mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float(\"-inf\")\n            )\n\n            offs_buf_v = kv_loc[:, None] * stride_buf_vbs + offs_c[None, :]\n            v = tl.load(\n                V_buffer + offs_buf_v,\n                mask=(offs_n[:, None] < split_kv_end) & (mask_c[None, :]),\n                other=0.0,\n            )\n\n            n_e_max = tl.maximum(tl.max(qk, 1), e_max)\n            re_scale = tl.exp(e_max - n_e_max)\n            p = tl.exp(qk - n_e_max[:, None])\n            acc *= re_scale[:, None]\n            # (16, 32) x (32, 512)\n            acc += tl.dot(p.to(v.dtype), v)\n\n            e_sum = e_sum * re_scale + tl.sum(p, 1)\n   
         e_max = n_e_max\n\n        offs_mid_o = (\n            cur_batch * stride_mid_ob\n            + cur_head[:, None] * stride_mid_oh\n            + split_kv_id * stride_mid_os\n            + offs_c[None, :]\n        )\n\n        if USE_ROPE:\n            if LAST_SPLIT:\n                k_pe_last_token_ptrs = (\n                    k_pe_t_out\n                    + cur_batch * stride_kpe_tokens_out_b\n                    + tl.arange(0, BLOCK_R)\n                )\n                tl.store(k_pe_last_token_ptrs, k_pe_last_token, mask=mask_qk_r)\n\n        tl.store(\n            Att_Out + offs_mid_o,\n            acc / e_sum[:, None],\n            mask=(mask_h[:, None]) & (mask_c[None, :]),\n        )\n\n        offs_mid_o_1 = (\n            cur_batch * stride_mid_ob\n            + cur_head * stride_mid_oh\n            + split_kv_id * stride_mid_os\n            + kv_lora_rank\n        )\n\n        tl.store(\n            Att_Out + offs_mid_o_1,\n            e_max + tl.log(e_sum),\n            mask=mask_h,\n        )\n\n\n# TODO rope offset\ndef _decode_grouped_att_m_fwd_rope(\n    q,\n    k_buffer,\n    v_buffer,\n    att_out,\n    k_pe_tokens_out,\n    kv_lora_rank,  # c\n    cos_sin_cache,\n    positions,\n    rotary_dim,\n    kv_indptr,\n    kv_indices,\n    num_kv_splits,\n    sm_scale,\n    logit_cap,\n    use_rope,\n    is_neox_style=True,\n):\n    if use_rope:\n        assert (\n            k_pe_tokens_out is not None\n        ), \"We must output the k_pe tokens with rope applied if rope fusion enabled.\"\n\n    BLOCK = 32\n\n    # # [TODO] work around shmem limit on MI3xx\n    # if _is_hip and kv_lora_rank >= 576:\n    #     BLOCK = 16\n\n    qk_rope_head_dim = k_buffer.shape[-1] - kv_lora_rank\n    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]\n    kv_group_num = q.shape[1] // k_buffer.shape[1]\n\n    BLOCK_C = triton.next_power_of_2(kv_lora_rank)\n    BLOCK_R = triton.next_power_of_2(qk_rope_head_dim)\n\n    BLOCK_H = 16\n    NUM_KV_SPLITS = 
num_kv_splits\n    grid = (\n        batch,\n        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),\n        NUM_KV_SPLITS,\n    )\n\n    extra_kargs = {}\n    num_stages = 2\n    if _is_hip:\n        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html\n        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py\n        extra_kargs = {\"waves_per_eu\": 1, \"matrix_instr_nonkdim\": 16, \"kpack\": 2}\n        num_stages = 1\n\n    _fwd_grouped_kernel_stage1_rope[grid](\n        q,\n        k_buffer,\n        v_buffer,\n        cos_sin_cache,\n        positions,\n        sm_scale,\n        kv_indptr,\n        kv_indices,\n        att_out,\n        k_pe_tokens_out,\n        q.stride(0),\n        q.stride(1),\n        k_buffer.stride(0),\n        v_buffer.stride(0),\n        att_out.stride(0),\n        att_out.stride(1),\n        att_out.stride(2),\n        k_pe_tokens_out.stride(0) if use_rope else 0,\n        cos_sin_cache.stride(0) if use_rope else 0,\n        positions.stride(0) if use_rope else 0,\n        rotary_dim,\n        kv_lora_rank,\n        qk_rope_head_dim,\n        kv_group_num=kv_group_num,\n        q_head_num=head_num,\n        BLOCK_C=BLOCK_C,\n        BLOCK_R=BLOCK_R,\n        BLOCK_N=BLOCK,\n        BLOCK_H=BLOCK_H,\n        NUM_KV_SPLITS=NUM_KV_SPLITS,\n        logit_cap=logit_cap,\n        USE_ROPE=use_rope,\n        IS_NEOX_STYLE=is_neox_style,\n        num_warps=4,\n        num_stages=num_stages,\n        **extra_kargs\n    )\n\n\ndef decode_attention_fwd_grouped_rope(\n    q,\n    k_buffer,\n    v_buffer,\n    o,\n    kv_indptr,\n    kv_indices,\n    k_pe_tokens,\n    kv_lora_rank,\n    rotary_dim,\n    cos_sin_cache,\n    positions,\n    attn_logits,\n    num_kv_splits,\n    sm_scale,\n    logit_cap=0.0,\n    use_rope=False,\n    is_neox_style=False,\n):\n    _decode_grouped_att_m_fwd_rope(\n        q,\n        k_buffer,\n        v_buffer,\n        
attn_logits,\n        k_pe_tokens,\n        kv_lora_rank,\n        cos_sin_cache,\n        positions,\n        rotary_dim,\n        kv_indptr,\n        kv_indices,\n        num_kv_splits,\n        sm_scale,\n        logit_cap,\n        use_rope,\n        is_neox_style,\n    )\n    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/triton_ops/trtllm_fp8_kv_kernel.py",
    "content": "\"\"\"\nFused FP8 quantization + paged KV cache write kernel for TRTLLM MHA backend.\n\nThis kernel fuses the following operations:\n1. FP8 quantization of K and V tensors (from BF16/FP16 to FP8)\n2. Per-token or per-page scale computation\n3. Writing quantized K/V to paged KV cache layout\n\nPerformance benefits:\n- Eliminates intermediate FP8 tensors in memory\n- Reduces kernel launch overhead\n- Better memory bandwidth utilization\n\"\"\"\n\nimport logging\nfrom typing import Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\nlogger = logging.getLogger(__name__)\n\n\n@triton.jit\ndef _process_kv_tensor(\n    token_id,\n    head_block_id,\n    page_id,\n    page_offset,\n    input_ptr,\n    cache_ptr,\n    inv_scale,\n    use_provided_scale: tl.constexpr,\n    num_kv_heads: tl.constexpr,\n    head_dim: tl.constexpr,\n    input_stride_token: tl.constexpr,\n    input_stride_head: tl.constexpr,\n    input_stride_dim: tl.constexpr,\n    cache_stride_page: tl.constexpr,\n    cache_stride_offset: tl.constexpr,\n    cache_stride_head: tl.constexpr,\n    cache_stride_dim: tl.constexpr,\n    BLOCK_HEAD: tl.constexpr,\n    BLOCK_DIM: tl.constexpr,\n):\n    \"\"\"Process a block of heads for a single K or V tensor.\"\"\"\n    head_idx = head_block_id * BLOCK_HEAD\n    num_heads_in_block = min(BLOCK_HEAD, num_kv_heads - head_idx)\n\n    for dim_idx in range(0, head_dim, BLOCK_DIM):\n        num_dims_in_block = min(BLOCK_DIM, head_dim - dim_idx)\n\n        head_offsets = head_idx + tl.arange(0, BLOCK_HEAD)\n        dim_offsets = dim_idx + tl.arange(0, BLOCK_DIM)\n\n        head_mask = head_offsets < (head_idx + num_heads_in_block)\n        dim_mask = dim_offsets < (dim_idx + num_dims_in_block)\n\n        # Load from input using 3D strides\n        input_offsets = (\n            token_id * input_stride_token\n            + head_offsets[:, None] * input_stride_head\n            + dim_offsets[None, :] * input_stride_dim\n        )\n        
mask = head_mask[:, None] & dim_mask[None, :]\n\n        block = tl.load(input_ptr + input_offsets, mask=mask, other=0.0)\n\n        # Quantize to FP8\n        if use_provided_scale:\n            block_fp8 = (block * inv_scale).to(tl.float8e4nv)\n        else:\n            block_fp8 = block.to(tl.float8e4nv)\n\n        # Write to cache at [page_id, page_offset, head, dim]\n        cache_offsets = (\n            page_id * cache_stride_page\n            + page_offset * cache_stride_offset\n            + head_offsets[:, None] * cache_stride_head\n            + dim_offsets[None, :] * cache_stride_dim\n        )\n\n        tl.store(cache_ptr + cache_offsets, block_fp8, mask=mask)\n\n\n@triton.jit\ndef _fused_fp8_set_kv_buffer_kernel(\n    # Input tensors (post-RoPE K and V in FP16/BF16)\n    k_ptr,  # [num_tokens, num_kv_heads, head_dim]\n    v_ptr,  # [num_tokens, num_kv_heads, head_dim]\n    # Output KV cache buffers (FP8 paged layout)\n    k_cache_ptr,  # [total_slots, num_kv_heads, head_dim]\n    v_cache_ptr,  # [total_slots, num_kv_heads, head_dim]\n    # Cache location indices\n    cache_loc_ptr,  # [num_tokens] -> token to cache location mapping\n    # Pointers to scalar inverse scales (computed on GPU in wrapper)\n    inv_k_scale_ptr,  # pointer to 0-D tensor on GPU\n    inv_v_scale_ptr,  # pointer to 0-D tensor on GPU\n    use_provided_scale: tl.constexpr,  # whether to use provided scale\n    # Tensor dimensions\n    num_kv_heads: tl.constexpr,\n    head_dim: tl.constexpr,\n    page_size: tl.constexpr,\n    # Strides for K input [num_tokens, num_kv_heads, head_dim]\n    k_stride_token: tl.constexpr,\n    k_stride_head: tl.constexpr,\n    k_stride_dim: tl.constexpr,\n    # Strides for K cache [total_slots, num_kv_heads, head_dim] (logically paged)\n    k_cache_stride_page: tl.constexpr,\n    k_cache_stride_offset: tl.constexpr,\n    k_cache_stride_head: tl.constexpr,\n    k_cache_stride_dim: tl.constexpr,\n    # Strides for V input [num_tokens, num_kv_heads, 
head_dim]\n    v_stride_token: tl.constexpr,\n    v_stride_head: tl.constexpr,\n    v_stride_dim: tl.constexpr,\n    # Strides for V cache [total_slots, num_kv_heads, head_dim] (logically paged)\n    v_cache_stride_page: tl.constexpr,\n    v_cache_stride_offset: tl.constexpr,\n    v_cache_stride_head: tl.constexpr,\n    v_cache_stride_dim: tl.constexpr,\n    # Block sizes\n    BLOCK_HEAD: tl.constexpr,  # Number of heads per block\n    BLOCK_DIM: tl.constexpr,  # Head dimension block size\n):\n    \"\"\"\n    Fused FP8 quantization + paged KV cache write kernel.\n\n    Each program processes one token-head_block-kv combination, quantizing and writing\n    to the appropriate page in the KV cache.\n\n    Grid: (num_tokens, num_head_blocks, 2) where dim2: 0=K, 1=V\n    \"\"\"\n    # Get program IDs\n    token_id = tl.program_id(0)\n    head_block_id = tl.program_id(1)\n    kv_idx = tl.program_id(2)  # 0 for K, 1 for V\n\n    # Get cache location for this token\n    cache_loc = tl.load(cache_loc_ptr + token_id)\n\n    # Compute page_id and offset within page\n    page_id = cache_loc // page_size\n    page_offset = cache_loc % page_size\n\n    # Select K or V based on kv_idx\n    if kv_idx == 0:\n        # Process K tensor\n        if use_provided_scale:\n            inv_scale = tl.load(inv_k_scale_ptr)\n        else:\n            inv_scale = 1.0\n        _process_kv_tensor(\n            token_id,\n            head_block_id,\n            page_id,\n            page_offset,\n            k_ptr,\n            k_cache_ptr,\n            inv_scale,\n            use_provided_scale,\n            num_kv_heads,\n            head_dim,\n            k_stride_token,\n            k_stride_head,\n            k_stride_dim,\n            k_cache_stride_page,\n            k_cache_stride_offset,\n            k_cache_stride_head,\n            k_cache_stride_dim,\n            BLOCK_HEAD,\n            BLOCK_DIM,\n        )\n    else:\n        # Process V tensor\n        if use_provided_scale:\n  
          inv_scale = tl.load(inv_v_scale_ptr)\n        else:\n            inv_scale = 1.0\n        _process_kv_tensor(\n            token_id,\n            head_block_id,\n            page_id,\n            page_offset,\n            v_ptr,\n            v_cache_ptr,\n            inv_scale,\n            use_provided_scale,\n            num_kv_heads,\n            head_dim,\n            v_stride_token,\n            v_stride_head,\n            v_stride_dim,\n            v_cache_stride_page,\n            v_cache_stride_offset,\n            v_cache_stride_head,\n            v_cache_stride_dim,\n            BLOCK_HEAD,\n            BLOCK_DIM,\n        )\n\n\ndef fused_fp8_set_kv_buffer(\n    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim] or [num_tokens, num_kv_heads * head_dim]\n    v: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim] or [num_tokens, num_kv_heads * head_dim]\n    k_cache: torch.Tensor,  # [total_slots, num_kv_heads, head_dim] or [num_pages, page_size, num_kv_heads, head_dim]\n    v_cache: torch.Tensor,  # [total_slots, num_kv_heads, head_dim] or [num_pages, page_size, num_kv_heads, head_dim]\n    cache_loc: torch.Tensor,  # [num_tokens], dtype=int32\n    k_scale: Optional[\n        float\n    ] = None,  # Scalar scale (matching original set_kv_buffer signature)\n    v_scale: Optional[float] = None,\n    page_size: int = 16,\n    use_triton: bool = True,  # Whether to use Triton kernel (set to False to force naive fallback)\n) -> None:\n    \"\"\"\n    Python wrapper for the fused FP8 quantization + paged KV cache write kernel.\n\n    This function replicates the exact behavior of the original set_kv_buffer but with\n    a fused kernel that combines FP8 quantization and cache write.\n\n    Args:\n        k: Key tensor after RoPE, can be 2D or 3D\n        v: Value tensor, can be 2D or 3D\n        k_cache: Paged K cache buffer in FP8\n        v_cache: Paged V cache buffer in FP8\n        cache_loc: Cache location for each token, shape 
[num_tokens]\n        k_scale: Optional scalar scale for K (matching original set_kv_buffer)\n        v_scale: Optional scalar scale for V (matching original set_kv_buffer)\n        page_size: Number of tokens per page\n        use_triton: Whether to use optimized Triton kernel\n    \"\"\"\n    num_tokens = k.shape[0]\n\n    # Step 1: Infer num_kv_heads and head_dim from cache shape\n    if k_cache.ndim == 3:\n        # 3D cache layout: [total_slots, num_kv_heads, head_dim]\n        total_slots, num_kv_heads, head_dim = k_cache.shape\n        assert (\n            total_slots % page_size == 0\n        ), f\"total_slots ({total_slots}) must be divisible by page_size ({page_size})\"\n        num_pages = total_slots // page_size\n    elif k_cache.ndim == 4:\n        # 4D cache layout: [num_pages, page_size, num_kv_heads, head_dim]\n        num_pages, ps, num_kv_heads, head_dim = k_cache.shape\n        assert (\n            ps == page_size\n        ), f\"page_size mismatch: cache has {ps}, expected {page_size}\"\n        total_slots = num_pages * page_size\n    else:\n        raise ValueError(f\"Unsupported k_cache.ndim={k_cache.ndim}, expected 3 or 4\")\n\n    # Step 2: Validate k, v shapes and normalize\n    # Store original 3D shape for Triton path\n    k_3d = None\n    v_3d = None\n\n    if k.ndim == 3:\n        # Input is [num_tokens, num_kv_heads, head_dim]\n        assert (\n            k.shape[1] == num_kv_heads\n        ), f\"num_kv_heads mismatch: k.shape[1]={k.shape[1]} vs cache={num_kv_heads}\"\n        assert (\n            k.shape[2] == head_dim\n        ), f\"head_dim mismatch: k.shape[2]={k.shape[2]} vs cache={head_dim}\"\n        assert v.shape[1] == num_kv_heads and v.shape[2] == head_dim, \"v shape mismatch\"\n\n        # Keep 3D for Triton kernel\n        k_3d = k\n        v_3d = v\n        # Create 2D view for naive fallback (will be used only if use_triton=False)\n        k_2d = k.reshape(num_tokens, num_kv_heads * head_dim)\n        v_2d = 
v.reshape(num_tokens, num_kv_heads * head_dim)\n    elif k.ndim == 2:\n        # Input is already [num_tokens, num_kv_heads * head_dim]\n        assert (\n            k.shape[1] == num_kv_heads * head_dim\n        ), f\"k.shape[1]={k.shape[1]} != {num_kv_heads * head_dim}\"\n        assert (\n            v.shape[1] == num_kv_heads * head_dim\n        ), f\"v.shape[1]={v.shape[1]} != {num_kv_heads * head_dim}\"\n\n        # Create 3D view for Triton kernel\n        k_3d = k.view(num_tokens, num_kv_heads, head_dim)\n        v_3d = v.view(num_tokens, num_kv_heads, head_dim)\n        # Keep 2D for naive\n        k_2d = k\n        v_2d = v\n    else:\n        raise ValueError(f\"Unsupported k.ndim={k.ndim}, expected 2 or 3\")\n\n    # Step 3: Compute cache strides based on layout\n    if k_cache.ndim == 3:\n        # 3D cache: [total_slots, num_kv_heads, head_dim]\n        stride_slot = k_cache.stride(0)\n        stride_head = k_cache.stride(1)\n        stride_dim = k_cache.stride(2)\n\n        k_cache_stride_page = stride_slot * page_size\n        k_cache_stride_offset = stride_slot\n        k_cache_stride_head = stride_head\n        k_cache_stride_dim = stride_dim\n\n        v_stride_slot = v_cache.stride(0)\n        v_stride_head = v_cache.stride(1)\n        v_stride_dim = v_cache.stride(2)\n\n        v_cache_stride_page = v_stride_slot * page_size\n        v_cache_stride_offset = v_stride_slot\n        v_cache_stride_head = v_stride_head\n        v_cache_stride_dim = v_stride_dim\n    else:\n        # 4D cache: [num_pages, page_size, num_kv_heads, head_dim]\n        k_cache_stride_page = k_cache.stride(0)\n        k_cache_stride_offset = k_cache.stride(1)\n        k_cache_stride_head = k_cache.stride(2)\n        k_cache_stride_dim = k_cache.stride(3)\n\n        v_cache_stride_page = v_cache.stride(0)\n        v_cache_stride_offset = v_cache.stride(1)\n        v_cache_stride_head = v_cache.stride(2)\n        v_cache_stride_dim = v_cache.stride(3)\n\n    # Decide 
whether to use provided scale\n    use_provided_scale = k_scale is not None and v_scale is not None\n\n    if use_triton and num_tokens > 0:\n        # Use optimized Triton kernel\n        # Compute input strides for 3D k, v: [num_tokens, num_kv_heads, head_dim]\n        k_stride_token = k_3d.stride(0)\n        k_stride_head = k_3d.stride(1)\n        k_stride_dim = k_3d.stride(2)\n\n        v_stride_token = v_3d.stride(0)\n        v_stride_head = v_3d.stride(1)\n        v_stride_dim = v_3d.stride(2)\n\n        # Block sizes for tiling (tunable)\n        BLOCK_HEAD = min(num_kv_heads, 8)  # Process up to 8 heads at once\n        BLOCK_DIM = min(head_dim, 128)  # Process up to 128 dims at once\n\n        # Compute number of head blocks\n        num_head_blocks = (num_kv_heads + BLOCK_HEAD - 1) // BLOCK_HEAD\n\n        # Grid: (num_tokens, num_head_blocks, 2)\n        # - dim 0: tokens\n        # - dim 1: head blocks\n        # - dim 2: K/V (0=K, 1=V)\n        grid = (num_tokens, num_head_blocks, 2)\n\n        device = k_3d.device\n\n        def _to_tensor_scale(scale):\n            \"\"\"Convert scale to 0-D CUDA tensor (accepts Python float or Tensor).\"\"\"\n            if isinstance(scale, torch.Tensor):\n                return scale.to(device=device, dtype=torch.float32)\n            else:\n                # Python float / np scalar\n                return torch.tensor(float(scale), device=device, dtype=torch.float32)\n\n        # Compute inverse scales on GPU to avoid GPU→CPU sync in CUDA graph capture.\n        # Previously we used float(k_scale) which triggers synchronization and fails\n        # during CUDA graph capture with cudaErrorStreamCaptureUnsupported.\n        if use_provided_scale:\n            k_scale_tensor = _to_tensor_scale(k_scale)\n            v_scale_tensor = _to_tensor_scale(v_scale)\n\n            # Pure GPU scalar operation, safe for CUDA graph\n            inv_k_scale = (1.0 / k_scale_tensor).to(device=device, dtype=torch.float32)\n       
     inv_v_scale = (1.0 / v_scale_tensor).to(device=device, dtype=torch.float32)\n\n            inv_k_scale_ptr = inv_k_scale\n            inv_v_scale_ptr = inv_v_scale\n        else:\n            # When use_provided_scale=False, kernel uses constant 1.0 for inv_scale.\n            # Triton will optimize away the tl.load() calls via constant folding.\n            # We pass dummy pointers (k_3d) which won't be accessed in the kernel.\n            # This avoids creating new GPU tensors during CUDA graph capture.\n            inv_k_scale_ptr = k_3d\n            inv_v_scale_ptr = k_3d\n\n        # Launch Triton kernel\n        _fused_fp8_set_kv_buffer_kernel[grid](\n            k_3d,\n            v_3d,\n            k_cache,\n            v_cache,\n            cache_loc,\n            inv_k_scale_ptr,\n            inv_v_scale_ptr,\n            use_provided_scale,\n            num_kv_heads,\n            head_dim,\n            page_size,\n            k_stride_token,\n            k_stride_head,\n            k_stride_dim,\n            k_cache_stride_page,\n            k_cache_stride_offset,\n            k_cache_stride_head,\n            k_cache_stride_dim,\n            v_stride_token,\n            v_stride_head,\n            v_stride_dim,\n            v_cache_stride_page,\n            v_cache_stride_offset,\n            v_cache_stride_head,\n            v_cache_stride_dim,\n            BLOCK_HEAD=BLOCK_HEAD,\n            BLOCK_DIM=BLOCK_DIM,\n        )\n    else:\n        # Fallback to naive implementation\n        _naive_fp8_set_kv_buffer(\n            k_2d, v_2d, k_cache, v_cache, cache_loc, k_scale, v_scale, page_size\n        )\n\n\ndef _naive_fp8_set_kv_buffer(\n    k: torch.Tensor,\n    v: torch.Tensor,\n    k_cache: torch.Tensor,\n    v_cache: torch.Tensor,\n    cache_loc: torch.Tensor,\n    k_scale: Optional[float],\n    v_scale: Optional[float],\n    page_size: int,\n) -> None:\n    \"\"\"\n    Naive fallback implementation that mimics the original set_kv_buffer 
logic.\n\n    This directly replicates the behavior of MHATokenToKVPool.set_kv_buffer:\n    1. Apply scale (if k.dtype != cache.dtype and scale is provided)\n    2. Convert to FP8\n    3. Write to cache at cache_loc\n\n    Args:\n        k: [num_tokens, num_kv_heads * head_dim], already reshaped to 2D\n        v: [num_tokens, num_kv_heads * head_dim], already reshaped to 2D\n        k_cache: [total_slots, num_kv_heads, head_dim] or [num_pages, page_size, num_kv_heads, head_dim]\n        v_cache: Same shape as k_cache\n        cache_loc: [num_tokens]\n        k_scale: Optional scale for K\n        v_scale: Optional scale for V\n        page_size: Tokens per page\n    \"\"\"\n    num_tokens = k.shape[0]\n\n    # Infer dimensions from cache\n    if k_cache.ndim == 3:\n        num_kv_heads = k_cache.shape[1]\n        head_dim = k_cache.shape[2]\n    elif k_cache.ndim == 4:\n        num_kv_heads = k_cache.shape[2]\n        head_dim = k_cache.shape[3]\n    else:\n        raise ValueError(f\"Unsupported k_cache.ndim={k_cache.ndim}\")\n\n    # Determine target dtype and storage dtype\n    # See: python/sglang/srt/mem_cache/memory_pool.py:445-449\n    store_dtype = k_cache.dtype\n    if store_dtype == torch.uint8:\n        # Cache is stored as uint8 for FP8 (due to index_put limitation)\n        dtype = torch.float8_e4m3fn  # Logical dtype\n    else:\n        dtype = store_dtype  # Cache dtype is the logical dtype\n\n    # Replicate the original set_kv_buffer behavior\n    # See: python/sglang/srt/mem_cache/memory_pool.py:777-799\n    if k.dtype != dtype:\n        # Need quantization - clone first to avoid modifying input\n        k = k.clone()\n        v = v.clone()\n\n        if k_scale is not None:\n            k.div_(k_scale)  # In-place division\n        if v_scale is not None:\n            v.div_(v_scale)  # In-place division\n\n        k = k.to(dtype)\n        v = v.to(dtype)\n\n    # View FP8 as uint8 if needed (for index_put compatibility)\n    if store_dtype == 
torch.uint8 and dtype in (torch.float8_e5m2, torch.float8_e4m3fn):\n        k = k.view(torch.uint8)\n        v = v.view(torch.uint8)\n\n    # Reshape from [T, H*D] to [T, H, D]\n    k = k.view(num_tokens, num_kv_heads, head_dim)\n    v = v.view(num_tokens, num_kv_heads, head_dim)\n\n    # Write to cache using advanced indexing (same as original)\n    if k_cache.ndim == 3:\n        # 3D cache: [total_slots, H, D]\n        k_cache[cache_loc] = k\n        v_cache[cache_loc] = v\n    else:\n        # 4D cache: [num_pages, page_size, H, D]\n        # Decompose loc into page_id and page_offset (vectorized)\n        page_ids = cache_loc // page_size\n        page_offsets = cache_loc % page_size\n        k_cache[page_ids, page_offsets] = k\n        v_cache[page_ids, page_offsets] = v\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/trtllm_mha_backend.py",
    "content": "from __future__ import annotations\n\n\"\"\"\nSupport attention backend for TRTLLM MHA kernels from flashinfer.\nThe kernel supports sm100 only, with sliding window and attention sink features.\n\"\"\"\n\nimport logging\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Optional\n\nimport torch\n\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.attention.flashinfer_backend import (\n    FlashInferAttnBackend,\n    FlashInferMultiStepDraftBackend,\n)\nfrom sglang.srt.layers.attention.triton_ops.trtllm_fp8_kv_kernel import (\n    fused_fp8_set_kv_buffer,\n)\nfrom sglang.srt.layers.attention.utils import canonicalize_stride\nfrom sglang.srt.mem_cache.swa_memory_pool import SWAKVPool, SWATokenToKVPoolAllocator\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.utils import is_flashinfer_available\nfrom sglang.srt.utils.common import is_sm90_supported, is_sm120_supported\n\nlogger = logging.getLogger(__name__)\n\nif is_flashinfer_available():\n    import flashinfer\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n    from sglang.srt.speculative.spec_info import SpecInput\n\n# Constants\n# Default workspace size in MB for TRTLLM MHA\n# Can be configured via SGLANG_FLASHINFER_WORKSPACE_SIZE environment variable\nDEFAULT_WORKSPACE_SIZE_MB = 512\n\n# Reuse this workspace buffer across all TRTLLM MHA wrappers\nglobal_zero_init_workspace_buffer = None\n\n\n@dataclass\nclass TRTLLMMHAMetadata:\n    # Sequence lengths for the forward batch\n    cache_seqlens_int32: torch.Tensor = None\n    # Maximum sequence length for query\n    max_seq_len_q: int = 1\n    # Maximum sequence length for key\n    max_seq_len_k: int = 0\n    # Cumulative sequence lengths for query\n    cu_seqlens_q: torch.Tensor = None\n    # Cumulative sequence lengths for key\n    cu_seqlens_k: torch.Tensor = 
None\n    # Page table, the index of KV Cache Tables/Blocks\n    page_table: torch.Tensor = None\n    # Page table for SWA layers (translated from full pool indices to SWA pool indices)\n    swa_page_table: torch.Tensor = None\n\n\nclass TRTLLMHAAttnBackend(FlashInferAttnBackend):\n    \"\"\"TRTLLM MHA attention kernel from flashinfer.\"\"\"\n\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        skip_prefill: bool = False,\n        kv_indptr_buf: Optional[torch.Tensor] = None,\n        kv_last_page_len_buf: Optional[torch.Tensor] = None,\n        speculative_step_id: int = 0,\n    ):\n        # Capture workspace size before super().__init__() to preserve user's\n        # SGLANG_FLASHINFER_WORKSPACE_SIZE setting (may be overridden by parent)\n        env_var = envs.SGLANG_FLASHINFER_WORKSPACE_SIZE\n        workspace_size_bytes = (\n            env_var.get()\n            if env_var.is_set()\n            else DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024\n        )\n\n        super().__init__(\n            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf\n        )\n\n        config = model_runner.model_config\n\n        # MHA-specific dimensions\n        self.max_context_len = model_runner.model_config.context_len\n        self.hidden_size = config.hidden_size\n\n        # Runtime parameters\n        self.data_type = model_runner.kv_cache_dtype\n        self.q_data_type = model_runner.dtype\n        self.page_size = model_runner.page_size\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.device = model_runner.device\n\n        # Workspace allocation\n        self.workspace_size = workspace_size_bytes\n        # Allocate buffers\n        global global_zero_init_workspace_buffer\n        if global_zero_init_workspace_buffer is None:\n            global_zero_init_workspace_buffer = torch.zeros(\n                self.workspace_size,\n                dtype=torch.uint8,\n                
device=model_runner.device,\n            )\n        self.workspace_buffer = global_zero_init_workspace_buffer\n\n        # CUDA graph state\n        self.decode_cuda_graph_metadata = {}\n\n        # Speculative decoding\n        # Only support topk <= 1 for now.\n        self.topk = model_runner.server_args.speculative_eagle_topk or 0\n        self.speculative_step_id = speculative_step_id\n        self.target_verify_metadata = {}\n\n        self.speculative_num_draft_tokens = (\n            model_runner.server_args.speculative_num_draft_tokens\n        )\n\n        # Sliding Window Attention(SWA) hybrid model support.\n        # For hybrid SWA models, the KV cache is split into two pools (full and SWA)\n        # with separate index spaces. We maintain a translated page_table for SWA\n        # layers so the trtllm kernel reads from the correct pool.\n        allocator = model_runner.token_to_kv_pool_allocator\n        self.use_sliding_window_kv_pool = isinstance(\n            allocator, SWATokenToKVPoolAllocator\n        )\n        self._swa_kv_pool: Optional[SWAKVPool] = (\n            allocator.get_kvcache() if self.use_sliding_window_kv_pool else None\n        )\n\n        # Forward metadata\n        self.forward_metadata: Optional[TRTLLMMHAMetadata] = None\n\n        # Init backend (XQA or TRTLLM-GEN)\n        # We need to specify q_type and out_type for different backend\n        # XQA: (q_type must be bf16)\n        #   KV bf16: q_type = bf16, out_type=model_runner.dtype\n        #   KV fp8: q_type = bf16, out_type=model_runner.dtype\n        # TRTLLM-GEN:\n        #   KV bf16: q_type = bf16, out_type=model_runner.dtype\n        #   KV fp8: q_type = fp8, out_type=model_runner.dtype\n        self.is_xqa_impl = is_sm90_supported() or is_sm120_supported()\n\n    def _maybe_translate_swa(\n        self, token_indices: torch.Tensor\n    ) -> Optional[torch.Tensor]:\n        \"\"\"Translate full-pool token indices to SWA-pool indices, or return None.\"\"\"\n      
  if not self.use_sliding_window_kv_pool:\n            return None\n        shape = token_indices.shape\n        return self._swa_kv_pool.translate_loc_from_full_to_swa(\n            token_indices.reshape(-1)\n        ).reshape(shape)\n\n    def _alloc_swa_page_table(\n        self, max_bs: int, max_num_pages: int\n    ) -> Optional[torch.Tensor]:\n        \"\"\"Allocate a SWA page_table buffer, or return None for non-SWA models.\"\"\"\n        if not self.use_sliding_window_kv_pool:\n            return None\n        return torch.zeros(max_bs, max_num_pages, dtype=torch.int32, device=self.device)\n\n    def _copy_swa_page_table(\n        self,\n        metadata: TRTLLMMHAMetadata,\n        page_indices: torch.Tensor,\n        num_pages: int,\n    ):\n        \"\"\"Translate and copy SWA page indices into metadata. No-op for non-SWA.\"\"\"\n        if metadata.swa_page_table is None:\n            return\n        swa_indices = self._maybe_translate_swa(page_indices)\n        metadata.swa_page_table[:, :num_pages].copy_(swa_indices // self.page_size)\n\n    def _get_layer_cache_loc(\n        self,\n        layer: RadixAttention,\n        cache_loc: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Return cache locations in the correct index space for the given layer.\"\"\"\n        if self.use_sliding_window_kv_pool:\n            _, is_swa = self._swa_kv_pool.layers_mapping[layer.layer_id]\n            if is_swa:\n                return self._swa_kv_pool.translate_loc_from_full_to_swa(cache_loc)\n        return cache_loc\n\n    def _bind_swa_page_table(\n        self, metadata: TRTLLMMHAMetadata, source: dict, key: str, bs: int\n    ):\n        \"\"\"Bind a pre-allocated SWA page_table slice to metadata for CUDA graph.\"\"\"\n        buf = source.get(key)\n        if buf is not None:\n            metadata.swa_page_table = buf[:bs, :]\n\n    def _get_layer_page_table(\n        self, layer: RadixAttention, forward_batch: ForwardBatch\n    ) -> torch.Tensor:\n        
\"\"\"Return the correct page_table for the given layer (SWA or full).\"\"\"\n        swa_pt = self.forward_metadata.swa_page_table\n        if swa_pt is not None:\n            _, is_swa = self._swa_kv_pool.layers_mapping[layer.layer_id]\n            if is_swa:\n                return swa_pt\n        return self.forward_metadata.page_table\n\n    def init_cuda_graph_state(\n        self,\n        max_bs: int,\n        max_num_tokens: int,\n        kv_indices_buf: Optional[torch.Tensor] = None,\n    ):\n        \"\"\"Initialize CUDA graph state for TRTLLM MHA.\"\"\"\n        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size\n        self.decode_cuda_graph_metadata = {\n            \"cache_seqlens\": torch.zeros(max_bs, dtype=torch.int32, device=self.device),\n            \"page_table\": torch.zeros(\n                max_bs,\n                max_num_pages,\n                dtype=torch.int32,\n                device=self.device,\n            ),\n            \"swa_page_table\": self._alloc_swa_page_table(max_bs, max_num_pages),\n            \"strided_indices\": torch.arange(\n                0, self.max_context_len, self.page_size, device=self.device\n            ),\n        }\n\n        if (\n            self.speculative_num_draft_tokens is not None\n            and self.speculative_num_draft_tokens > 0\n        ):\n            self.decode_cuda_graph_metadata[\"cu_seqlens_q\"] = torch.arange(\n                0, max_bs + 1, dtype=torch.int32, device=self.device\n            )\n            self.decode_cuda_graph_metadata[\"cu_seqlens_k\"] = torch.zeros(\n                max_bs + 1, dtype=torch.int32, device=self.device\n            )\n            self.decode_cuda_graph_metadata[\"page_table_draft_decode\"] = torch.zeros(\n                max_bs,\n                max_num_pages,\n                dtype=torch.int32,\n                device=self.device,\n            )\n            self.decode_cuda_graph_metadata[\"swa_page_table_draft_decode\"] = 
(\n                self._alloc_swa_page_table(max_bs, max_num_pages)\n            )\n\n            self.target_verify_metadata = {\n                \"cache_seqlens\": torch.zeros(\n                    max_bs, dtype=torch.int32, device=self.device\n                ),\n                \"cu_seqlens_q\": torch.arange(\n                    0,\n                    max_bs * self.speculative_num_draft_tokens + 1,\n                    step=self.speculative_num_draft_tokens,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"cu_seqlens_k\": torch.zeros(\n                    max_bs + 1, dtype=torch.int32, device=self.device\n                ),\n                \"page_table\": torch.zeros(\n                    max_bs,\n                    max_num_pages,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"swa_page_table\": self._alloc_swa_page_table(max_bs, max_num_pages),\n                \"strided_indices\": torch.arange(\n                    0, self.max_context_len, self.page_size, device=self.device\n                ),\n            }\n\n            self.draft_extend_metadata = {\n                \"cache_seqlens\": torch.zeros(\n                    max_bs, dtype=torch.int32, device=self.device\n                ),\n                \"cu_seqlens_q\": torch.zeros(\n                    max_bs + 1,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"cu_seqlens_k\": torch.zeros(\n                    max_bs + 1, dtype=torch.int32, device=self.device\n                ),\n                \"page_table\": torch.zeros(\n                    max_bs,\n                    max_num_pages,\n                    dtype=torch.int32,\n                    device=self.device,\n                ),\n                \"swa_page_table\": self._alloc_swa_page_table(max_bs, max_num_pages),\n                
\"strided_indices\": torch.arange(\n                    0, self.max_context_len, self.page_size, device=self.device\n                ),\n            }\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        \"\"\"Initialize metadata for CUDA graph capture.\"\"\"\n        metadata = TRTLLMMHAMetadata()\n        device = seq_lens.device\n\n        if forward_mode.is_decode_or_idle():\n            if spec_info is not None:\n                # Draft Decode\n                # Here we only support topk = 1 for now.\n                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[\n                    \"cache_seqlens\"\n                ][:bs]\n                metadata.max_seq_len_k = seq_lens.max().item() + (\n                    self.speculative_step_id + 1\n                )\n                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[\"cu_seqlens_q\"][\n                    : bs + 1\n                ]\n                metadata.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(\n                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32\n                    ),\n                    (1, 0),\n                )\n                metadata.page_table = self.decode_cuda_graph_metadata[\n                    \"page_table_draft_decode\"\n                ][:bs, :]\n                self._bind_swa_page_table(\n                    metadata,\n                    self.decode_cuda_graph_metadata,\n                    \"swa_page_table_draft_decode\",\n                    bs,\n                )\n                self.decode_cuda_graph_metadata[bs] = metadata\n            else:\n                # Normal Decode\n                # Get sequence information\n    
            metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)\n                batch_size = len(seq_lens)\n                metadata.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)\n                )\n\n                # Precompute maximum sequence length\n                metadata.max_seq_len_k = seq_lens.max().item()\n                # Precompute cumulative sequence lengths\n                metadata.cu_seqlens_q = torch.arange(\n                    0, batch_size + 1, dtype=torch.int32, device=device\n                )\n                # Precompute page table\n                metadata.page_table = self.decode_cuda_graph_metadata[\"page_table\"][\n                    :bs, :\n                ]\n                self._bind_swa_page_table(\n                    metadata,\n                    self.decode_cuda_graph_metadata,\n                    \"swa_page_table\",\n                    bs,\n                )\n                self.decode_cuda_graph_metadata[bs] = metadata\n        elif forward_mode.is_target_verify():\n            # Target Verify\n            # Here we only support topk = 1 for now.\n            metadata.cache_seqlens_int32 = self.target_verify_metadata[\"cache_seqlens\"][\n                :bs\n            ]\n            metadata.cache_seqlens_int32.copy_(\n                (seq_lens + self.speculative_num_draft_tokens)\n            )\n\n            metadata.cu_seqlens_q = torch.arange(\n                0,\n                bs * self.speculative_num_draft_tokens + 1,\n                self.speculative_num_draft_tokens,\n                dtype=torch.int32,\n                device=device,\n            )\n\n            metadata.cu_seqlens_k = self.target_verify_metadata[\"cu_seqlens_k\"][\n                : (bs + 1)\n            ]\n\n            metadata.max_seq_len_q = self.speculative_num_draft_tokens\n            metadata.max_seq_len_k = (\n                seq_lens.max().item() + 
self.speculative_num_draft_tokens\n            )\n\n            metadata.page_table = self.target_verify_metadata[\"page_table\"][:bs, :]\n            self._bind_swa_page_table(\n                metadata,\n                self.target_verify_metadata,\n                \"swa_page_table\",\n                bs,\n            )\n\n            self.target_verify_metadata[bs] = metadata\n        elif forward_mode.is_draft_extend():\n            metadata.cache_seqlens_int32 = self.draft_extend_metadata[\"cache_seqlens\"][\n                :bs\n            ]\n            metadata.cache_seqlens_int32.copy_(seq_lens)\n            num_tokens_per_bs = num_tokens // bs\n            metadata.cu_seqlens_q = torch.arange(\n                0,\n                bs * num_tokens_per_bs + 1,\n                num_tokens_per_bs,\n                dtype=torch.int32,\n                device=device,\n            )\n\n            metadata.cu_seqlens_k = self.draft_extend_metadata[\"cu_seqlens_k\"][\n                : (bs + 1)\n            ]\n            num_tokens_per_bs = num_tokens // bs\n            metadata.max_seq_len_q = num_tokens_per_bs\n            metadata.max_seq_len_k = seq_lens.max().item()\n\n            metadata.page_table = self.draft_extend_metadata[\"page_table\"][:bs, :]\n            self._bind_swa_page_table(\n                metadata,\n                self.draft_extend_metadata,\n                \"swa_page_table\",\n                bs,\n            )\n\n            self.draft_extend_metadata[bs] = metadata\n        self.forward_metadata = metadata\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        \"\"\"Replay CUDA graph with new inputs.\"\"\"\n        
seq_lens = seq_lens[:bs]\n        seq_lens_cpu = seq_lens_cpu[:bs]\n        req_pool_indices = req_pool_indices[:bs]\n        metadata = None\n        if forward_mode.is_decode_or_idle():\n            if spec_info is not None:\n                # Draft Decode\n                # Here we only support topk = 1 for now.\n                metadata = self.decode_cuda_graph_metadata[bs]\n                max_len = seq_lens_cpu.max().item()\n                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1\n\n                max_seq_pages = (\n                    metadata.max_seq_len_k + self.page_size - 1\n                ) // self.page_size\n\n                metadata.cache_seqlens_int32.copy_(\n                    seq_lens + self.speculative_step_id + 1\n                )\n            else:\n                # Normal Decode\n                metadata = self.decode_cuda_graph_metadata[bs]\n                max_len = seq_lens_cpu.max().item()\n                max_seq_pages = (max_len + self.page_size - 1) // self.page_size\n                metadata.max_seq_len_k = max_len\n\n                metadata.cache_seqlens_int32.copy_(seq_lens)\n\n            metadata.cu_seqlens_k[1:].copy_(\n                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)\n            )\n            page_indices = self.req_to_token[\n                req_pool_indices[:, None],\n                self.decode_cuda_graph_metadata[\"strided_indices\"][:max_seq_pages][\n                    None, :\n                ],\n            ]\n            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)\n            self._copy_swa_page_table(metadata, page_indices, max_seq_pages)\n        elif forward_mode.is_target_verify():\n            # Here we only support topk = 1 for now.\n            metadata = self.target_verify_metadata[bs]\n            metadata.cache_seqlens_int32.copy_(\n                (seq_lens + self.speculative_num_draft_tokens)\n            
)\n\n            metadata.max_seq_len_k = (\n                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens\n            )\n            metadata.cu_seqlens_k[1:].copy_(\n                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)\n            )\n            max_seq_pages = (\n                metadata.max_seq_len_k + self.page_size - 1\n            ) // self.page_size\n            page_indices = self.req_to_token[\n                req_pool_indices[:, None],\n                self.decode_cuda_graph_metadata[\"strided_indices\"][:max_seq_pages],\n            ]\n            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)\n            self._copy_swa_page_table(metadata, page_indices, max_seq_pages)\n            metadata.max_seq_len_q = self.speculative_num_draft_tokens\n        elif forward_mode.is_draft_extend():\n            metadata = self.draft_extend_metadata[bs]\n            metadata.cache_seqlens_int32.copy_(seq_lens)\n\n            metadata.max_seq_len_k = seq_lens_cpu.max().item()\n            metadata.cu_seqlens_k[1:].copy_(\n                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)\n            )\n            accept_length = spec_info.accept_length[:bs]\n            if spec_info.accept_length_cpu:\n                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1\n            else:\n                metadata.max_seq_len_q = 1\n\n            metadata.cu_seqlens_q[1:].copy_(\n                torch.cumsum(accept_length, dim=0, dtype=torch.int32)\n            )\n\n            max_seq_pages = (\n                metadata.max_seq_len_k + self.page_size - 1\n            ) // self.page_size\n            page_indices = self.req_to_token[\n                req_pool_indices[:, None],\n                self.draft_extend_metadata[\"strided_indices\"][:max_seq_pages],\n       
     ]\n            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)\n            self._copy_swa_page_table(metadata, page_indices, max_seq_pages)\n        self.forward_metadata = metadata\n\n    def get_cuda_graph_seq_len_fill_value(self) -> int:\n        \"\"\"Get the fill value for sequence lengths in CUDA graph.\"\"\"\n        return 1\n\n    def _should_use_fused_fp8_path(self, save_kv_cache: bool, k: torch.Tensor) -> bool:\n        \"\"\"Check if we should use the fused FP8 KV cache write path.\"\"\"\n        return save_kv_cache and k is not None and self.data_type == torch.float8_e4m3fn\n\n    def _fused_fp8_set_kv_buffer(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        **kwargs,\n    ):\n        \"\"\"Fused FP8 quantization and KV cache write.\"\"\"\n        cache_loc = self._get_layer_cache_loc(layer, forward_batch.out_cache_loc)\n\n        # Get K/V cache buffers from token_to_kv_pool\n        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)\n\n        fused_fp8_set_kv_buffer(\n            k=k,\n            v=v,\n            k_cache=k_cache,\n            v_cache=v_cache,\n            cache_loc=cache_loc,\n            k_scale=layer.k_scale,  # May be None\n            v_scale=layer.v_scale,  # May be None\n            page_size=self.page_size,\n        )\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Initialize the metadata for a forward pass.\"\"\"\n\n        metadata = TRTLLMMHAMetadata()\n        seqlens_in_batch = forward_batch.seq_lens\n        batch_size = forward_batch.batch_size\n        device = seqlens_in_batch.device\n\n        if forward_batch.forward_mode.is_decode_or_idle():\n            if forward_batch.spec_info is not None:\n                # Draft Decode\n                # Here we only support topk = 1 for now.\n          
      metadata.cache_seqlens_int32 = (\n                    seqlens_in_batch + (self.speculative_step_id + 1)\n                ).to(torch.int32)\n                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (\n                    self.speculative_step_id + 1\n                )\n                metadata.cu_seqlens_q = torch.arange(\n                    0, batch_size + 1, dtype=torch.int32, device=device\n                )\n                metadata.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(\n                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32\n                    ),\n                    (1, 0),\n                )\n                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                    forward_batch.req_pool_indices, : metadata.max_seq_len_k\n                ]\n            else:\n                # Normal Decode\n                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)\n                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()\n                metadata.cu_seqlens_q = torch.arange(\n                    0, batch_size + 1, dtype=torch.int32, device=device\n                )\n                metadata.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)\n                )\n                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                    forward_batch.req_pool_indices, : metadata.max_seq_len_k\n                ]\n        elif forward_batch.forward_mode.is_target_verify():\n            # Only support topk = 1 for now.\n            metadata.cache_seqlens_int32 = (\n                forward_batch.seq_lens + self.speculative_num_draft_tokens\n            ).to(torch.int32)\n            metadata.max_seq_len_q = self.speculative_num_draft_tokens\n            metadata.max_seq_len_k = (\n                
forward_batch.seq_lens_cpu.max().item()\n                + self.speculative_num_draft_tokens\n            )\n            metadata.cu_seqlens_q = torch.arange(\n                0,\n                batch_size * self.speculative_num_draft_tokens + 1,\n                self.speculative_num_draft_tokens,\n                dtype=torch.int32,\n                device=device,\n            )\n            metadata.cu_seqlens_k = torch.nn.functional.pad(\n                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),\n                (1, 0),\n            )\n            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                forward_batch.req_pool_indices, : metadata.max_seq_len_k\n            ]\n\n        else:\n            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)\n            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()\n            metadata.cu_seqlens_k = torch.nn.functional.pad(\n                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)\n            )\n            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                forward_batch.req_pool_indices, : metadata.max_seq_len_k\n            ]\n\n            if any(\n                forward_batch.extend_prefix_lens_cpu\n            ) or forward_batch.forward_mode.is_draft_extend(include_v2=True):\n                extend_seq_lens = forward_batch.extend_seq_lens\n                # NOTE: in piecewise CUDA graph warmup, extend_seq_lens_cpu is a torch.Tensor;\n                # Python's max() returns a 0-d tensor, but flashinfer expects an int.\n                max_q = max(forward_batch.extend_seq_lens_cpu)\n                metadata.max_seq_len_q = (\n                    int(max_q.item()) if isinstance(max_q, torch.Tensor) else int(max_q)\n                )\n                metadata.cu_seqlens_q = torch.nn.functional.pad(\n                    torch.cumsum(extend_seq_lens, dim=0, 
dtype=torch.int32), (1, 0)\n                )\n            else:\n                metadata.max_seq_len_q = metadata.max_seq_len_k\n                metadata.cu_seqlens_q = metadata.cu_seqlens_k\n\n        # Compute SWA page table (None for non-SWA models)\n        metadata.swa_page_table = self._maybe_translate_swa(metadata.page_table)\n\n        # Convert the page tables to a strided format\n        if self.page_size > 1:\n            self.strided_indices = torch.arange(\n                0, metadata.page_table.shape[1], self.page_size, device=self.device\n            )\n            metadata.page_table = (\n                metadata.page_table[:, self.strided_indices] // self.page_size\n            )\n            if metadata.swa_page_table is not None:\n                metadata.swa_page_table = (\n                    metadata.swa_page_table[:, self.strided_indices] // self.page_size\n                )\n\n        self.forward_metadata = metadata\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"Run forward for decode using TRTLLM MHA kernel.\"\"\"\n        cache_loc = forward_batch.out_cache_loc\n\n        use_fused_fp8_path = self._should_use_fused_fp8_path(save_kv_cache, k)\n\n        if use_fused_fp8_path:\n            # Use fused FP8 quantization + KV cache write path\n            self._fused_fp8_set_kv_buffer(\n                q=q,\n                k=k,\n                v=v,\n                layer=layer,\n                forward_batch=forward_batch,\n            )\n            k = None\n            v = None\n        else:\n            # Use original set_kv_buffer path\n            if save_kv_cache and k is not None:\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer, cache_loc, k, v, 
layer.k_scale, layer.v_scale\n                )\n\n        # For XQA, q_dtype should be bf16\n        if self.data_type == torch.float8_e4m3fn and (not self.is_xqa_impl):\n            q = q.to(torch.float8_e4m3fn)\n        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)\n        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)\n        # shape conversion:\n        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]\n        k_cache = k_cache.view(\n            -1, self.page_size, layer.tp_k_head_num, layer.head_dim\n        ).permute(0, 2, 1, 3)\n        v_cache = v_cache.view(\n            -1, self.page_size, layer.tp_v_head_num, layer.head_dim\n        ).permute(0, 2, 1, 3)\n\n        if layer.tp_k_head_num == 1:\n            k_cache = canonicalize_stride(k_cache)\n        if layer.tp_v_head_num == 1:\n            v_cache = canonicalize_stride(v_cache)\n\n        kv_cache = (k_cache, v_cache)\n\n        # TODO: add support for quantization\n        q_scale = 1.0\n        k_scale = (\n            layer.k_scale_float\n            if getattr(layer, \"k_scale_float\", None) is not None\n            else 1.0\n        )\n        bmm1_scale = q_scale * k_scale * layer.scaling\n        bmm2_scale = 1.0\n        # sink: additional value per head in the denominator of the softmax.\n        attention_sink = kwargs.get(\"sinks\", None)\n\n        page_table = self._get_layer_page_table(layer, forward_batch)\n\n        # Call TRT-LLM kernel\n        # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype\n        o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(\n            query=q,\n            kv_cache=kv_cache,\n            workspace_buffer=self.workspace_buffer,\n            block_tables=page_table,\n            seq_lens=self.forward_metadata.cache_seqlens_int32,\n            max_seq_len=self.max_context_len,\n            bmm1_scale=bmm1_scale,\n  
          bmm2_scale=bmm2_scale,\n            window_left=layer.sliding_window_size,\n            sinks=attention_sink,\n            out_dtype=self.q_data_type,  # model_runner.dtype\n        )\n\n        return o.view(-1, layer.tp_q_head_num * layer.head_dim)\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n        **kwargs,\n    ):\n        cache_loc = forward_batch.out_cache_loc\n\n        use_fused_fp8_path = self._should_use_fused_fp8_path(save_kv_cache, k)\n\n        if use_fused_fp8_path:\n            # Use fused FP8 quantization + KV cache write path\n            self._fused_fp8_set_kv_buffer(\n                q=q,\n                k=k,\n                v=v,\n                layer=layer,\n                forward_batch=forward_batch,\n            )\n            k = None\n            v = None\n        else:\n            # Use original set_kv_buffer path\n            if save_kv_cache and k is not None:\n                forward_batch.token_to_kv_pool.set_kv_buffer(\n                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale\n                )\n\n        if self.data_type == torch.float8_e4m3fn:\n            q = q.to(torch.float8_e4m3fn)\n        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)\n        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]\n        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)\n        k_cache = k_cache.view(\n            -1, self.page_size, layer.tp_k_head_num, layer.head_dim\n        ).permute(0, 2, 1, 3)\n        v_cache = v_cache.view(\n            -1, self.page_size, layer.tp_v_head_num, layer.head_dim\n        ).permute(0, 2, 1, 3)\n\n        if layer.tp_k_head_num == 1:\n            k_cache = canonicalize_stride(k_cache)\n        if 
layer.tp_v_head_num == 1:\n            v_cache = canonicalize_stride(v_cache)\n\n        kv_cache = (k_cache, v_cache)\n\n        # sink: additional value per head in the denominator of the softmax.\n        attention_sink = kwargs.get(\"sinks\", None)\n        # TODO: add support for quantization\n        q_scale = 1.0\n        k_scale = (\n            layer.k_scale_float\n            if getattr(layer, \"k_scale_float\", None) is not None\n            else 1.0\n        )\n        bmm1_scale = q_scale * k_scale * layer.scaling\n        bmm2_scale = 1.0\n\n        page_table = self._get_layer_page_table(layer, forward_batch)\n\n        if forward_batch.forward_mode.is_target_verify():\n            o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(\n                query=q,\n                kv_cache=kv_cache,\n                workspace_buffer=self.workspace_buffer,\n                block_tables=page_table,\n                seq_lens=self.forward_metadata.cache_seqlens_int32,\n                max_seq_len=self.max_context_len,\n                bmm1_scale=bmm1_scale,\n                bmm2_scale=bmm2_scale,\n                window_left=layer.sliding_window_size,\n                sinks=attention_sink,\n                out_dtype=self.q_data_type,  # model_runner.dtype\n                q_len_per_req=self.forward_metadata.max_seq_len_q,\n            )\n        else:\n            o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(\n                query=q,\n                kv_cache=kv_cache,\n                workspace_buffer=self.workspace_buffer,\n                block_tables=page_table,\n                seq_lens=self.forward_metadata.cache_seqlens_int32,\n                max_q_len=self.forward_metadata.max_seq_len_q,\n                max_kv_len=self.max_context_len,\n                bmm1_scale=bmm1_scale,\n                bmm2_scale=bmm2_scale,\n                batch_size=forward_batch.batch_size,\n                
cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,\n                cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,\n                window_left=layer.sliding_window_size,\n                sinks=attention_sink,\n                out_dtype=self.q_data_type,  # model_runner.dtype\n            )\n\n        return o.view(-1, layer.tp_q_head_num * layer.head_dim)\n\n\nclass TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):\n    \"\"\"Multi-step TRTLLM MHA attention kernel used by EAGLE.\"\"\"\n\n    def __init__(\n        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int\n    ):\n        super().__init__(model_runner, topk, speculative_num_steps)\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i] = TRTLLMHAAttnBackend(\n                model_runner,\n                skip_prefill=True,\n                kv_indptr_buf=self.kv_indptr[i],\n                kv_last_page_len_buf=self.kv_last_page_len,\n                speculative_step_id=i,\n            )\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_forward_metadata(forward_batch)\n\n    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        forward_batch: ForwardBatch,\n    ):\n        assert forward_batch.spec_info is not None\n        assert forward_batch.spec_info.is_draft_input()\n\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(\n                forward_batch.batch_size,\n                forward_batch.batch_size * self.topk,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n  
              encoder_lens=forward_batch.encoder_lens,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n            )\n\n    def init_forward_metadata_replay_cuda_graph(\n        self, forward_batch: ForwardBatch, bs: int\n    ):\n        assert forward_batch.spec_info is not None\n        assert forward_batch.spec_info.is_draft_input()\n\n        for i in range(self.speculative_num_steps - 1):\n\n            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(\n                bs,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                forward_batch.seq_lens_sum,\n                encoder_lens=forward_batch.encoder_lens,\n                forward_mode=ForwardMode.DECODE,\n                spec_info=forward_batch.spec_info,\n                seq_lens_cpu=forward_batch.seq_lens_cpu,\n            )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/trtllm_mla_backend.py",
    "content": "from __future__ import annotations\n\n\"\"\"\nSupport attention backend for TRTLLM MLA kernels from flashinfer.\n\"\"\"\n\nimport logging\nimport math\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Optional, Union\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph\nfrom sglang.srt.layers.attention.flashinfer_mla_backend import (\n    FlashInferMLAAttnBackend,\n    FlashInferMLAMultiStepDraftBackend,\n)\nfrom sglang.srt.layers.attention.utils import (\n    concat_mla_absorb_q_general,\n    create_flashmla_kv_indices_triton,\n    get_num_page_per_block_flashmla,\n    mla_quantize_and_rope_for_fp8,\n)\nfrom sglang.srt.layers.dp_attention import get_attention_tp_size\nfrom sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.utils import is_flashinfer_available, is_float4_e2m1fn_x2\n\nif is_flashinfer_available():\n    import flashinfer\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n    from sglang.srt.speculative.spec_info import SpecInput\n\nlogger = logging.getLogger(__name__)\n\n# Constants\nDEFAULT_WORKSPACE_SIZE_MB = 150  # Memory workspace size in MB\n\n# Block constraint from flashinfer requirements\n# From flashinfer.decode._check_trtllm_gen_mla_shape:\n#   block_num % (128 / block_size) == 0\n# This imposes that the total number of blocks must be divisible by\n# (128 / block_size). 
We capture the 128 constant here so we can\n# compute the LCM with other padding constraints.\nTRTLLM_BLOCK_CONSTRAINT = 128\n\n\n@triton.jit\ndef pad_draft_extend_query_kernel(\n    q_ptr,  # Input query tensor [total_seq_len, num_heads, head_dim]\n    padded_q_ptr,  # Output padded query tensor [batch_size, max_seq_len, num_heads, head_dim]\n    seq_lens_q_ptr,  # Sequence lengths for each sequence [batch_size]\n    cumsum_ptr,  # Cumulative sum of accept lengths [batch_size + 1]\n    batch_size,\n    max_seq_len,\n    num_heads,\n    head_dim,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"Triton kernel for padding draft extended query tensor with parallelized head and dim processing.\"\"\"\n    # Use 3D program IDs: (batch_seq, head_block, dim_block)\n    batch_seq_pid = tl.program_id(0)\n    head_pid = tl.program_id(1)\n    dim_pid = tl.program_id(2)\n\n    batch_id = batch_seq_pid // max_seq_len\n    seq_pos = batch_seq_pid % max_seq_len\n\n    if batch_id >= batch_size:\n        return\n\n    # Load accept length for this batch\n    seq_len = tl.load(seq_lens_q_ptr + batch_id)\n\n    if seq_pos >= seq_len:\n        return\n\n    # Load cumulative sum to get start position in input tensor\n    input_start = tl.load(cumsum_ptr + batch_id)\n    input_pos = input_start + seq_pos\n\n    # Calculate head and dim block ranges\n    head_start = head_pid * BLOCK_SIZE\n    head_end = tl.minimum(head_start + BLOCK_SIZE, num_heads)\n    head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)\n\n    dim_start = dim_pid * BLOCK_SIZE\n    dim_end = tl.minimum(dim_start + BLOCK_SIZE, head_dim)\n    dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)\n\n    # Calculate input offset\n    input_offset = (\n        input_pos * num_heads * head_dim\n        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim\n        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]\n    )\n\n    # Load data\n    data = tl.load(\n        q_ptr + input_offset,\n        
mask=head_mask[:, None] & dim_mask[None, :],\n        other=0.0,\n    )\n\n    # Calculate output offset\n    output_offset = (\n        batch_id * max_seq_len * num_heads * head_dim\n        + seq_pos * num_heads * head_dim\n        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim\n        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]\n    )\n\n    # Store data\n    tl.store(\n        padded_q_ptr + output_offset,\n        data,\n        mask=head_mask[:, None] & dim_mask[None, :],\n    )\n\n\n@triton.jit\ndef unpad_draft_extend_output_kernel(\n    raw_out_ptr,  # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)\n    output_ptr,  # Output tensor (-1, tp_q_head_num, v_head_dim)\n    accept_length_ptr,  # Accept lengths for each sequence [batch_size]\n    cumsum_ptr,  # Cumulative sum of accept lengths [batch_size + 1]\n    batch_size,\n    token_per_batch,\n    tp_q_head_num,\n    v_head_dim,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"Triton kernel for unpadding draft extended output tensor with parallelized head and dim processing.\"\"\"\n    batch_seq_pid = tl.program_id(0)\n    head_pid = tl.program_id(1)\n    dim_pid = tl.program_id(2)\n\n    batch_id = batch_seq_pid // token_per_batch\n    seq_pos = batch_seq_pid % token_per_batch\n\n    if batch_id >= batch_size:\n        return\n\n    # Load accept length for this batch\n    accept_len = tl.load(accept_length_ptr + batch_id)\n\n    if seq_pos >= accept_len:\n        return\n\n    # Load cumulative sum to get start position in output tensor\n    output_start = tl.load(cumsum_ptr + batch_id)\n    output_pos = output_start + seq_pos\n\n    # Calculate head and dim block ranges\n    head_start = head_pid * BLOCK_SIZE\n    head_end = tl.minimum(head_start + BLOCK_SIZE, tp_q_head_num)\n    head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)\n\n    dim_start = dim_pid * BLOCK_SIZE\n    dim_end = tl.minimum(dim_start + BLOCK_SIZE, v_head_dim)\n    
dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)\n\n    # Calculate input offset: (batch_id, seq_pos, head_id, dim_id)\n    input_offset = (\n        batch_id * token_per_batch * tp_q_head_num * v_head_dim\n        + seq_pos * tp_q_head_num * v_head_dim\n        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim\n        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]\n    )\n\n    # Load data\n    data = tl.load(\n        raw_out_ptr + input_offset,\n        mask=head_mask[:, None] & dim_mask[None, :],\n        other=0.0,\n    )\n\n    output_offset = (\n        output_pos * tp_q_head_num * v_head_dim\n        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim\n        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]\n    )\n\n    # Store data\n    tl.store(\n        output_ptr + output_offset,\n        data,\n        mask=head_mask[:, None] & dim_mask[None, :],\n    )\n\n\ndef _quantize_fp8_qkv(q, k, v, layer):\n    q = q.to(torch.float8_e4m3fn)\n\n    k_scale = getattr(layer, \"k_scale_float\", None)\n    if k_scale is None:\n        k_scale = 1.0\n    if k_scale != 1.0:\n        assert hasattr(layer, \"k_scale\"), \"k_scale is not set\"\n        k_2d, _ = scaled_fp8_quant(\n            k.reshape(-1, k.shape[-1]).contiguous(), layer.k_scale\n        )\n        k = k_2d.reshape(k.shape)\n    else:\n        k = k.to(torch.float8_e4m3fn)\n\n    v_scale = getattr(layer, \"v_scale_float\", None)\n    if v_scale is None:\n        v_scale = 1.0\n    if v_scale != 1.0:\n        assert hasattr(layer, \"v_scale\"), \"v_scale is not set\"\n        v_2d, _ = scaled_fp8_quant(\n            v.reshape(-1, v.shape[-1]).contiguous(), layer.v_scale\n        )\n        v = v_2d.reshape(v.shape)\n    else:\n        v = v.to(torch.float8_e4m3fn)\n\n    return q, k, v, k_scale, v_scale\n\n\nglobal_zero_init_workspace_buffer = None\n\n\n@dataclass\nclass TRTLLMMLAPrefillMetadata:\n    \"\"\"Metadata for TRTLLM MLA prefill operations.\"\"\"\n\n   
 max_seq_len: int\n    cum_seq_lens: torch.Tensor\n    seq_lens: torch.Tensor\n    fallback_to_flashinfer_impl: bool = False\n\n\n@dataclass\nclass TRTLLMMLADecodeMetadata:\n    \"\"\"Metadata for TRTLLM MLA decode operations.\"\"\"\n\n    block_kv_indices: Optional[torch.Tensor] = None\n    max_seq_len_k: Optional[int] = None\n    max_seq_len_q: Optional[int] = None\n    sum_seq_lens_q: Optional[int] = None\n    cu_seqlens_q: Optional[torch.Tensor] = None\n    seq_lens_q: Optional[torch.Tensor] = None\n    seq_lens_k: Optional[torch.Tensor] = None\n\n\nclass TRTLLMMLABackend(FlashInferMLAAttnBackend):\n    \"\"\"TRTLLM MLA attention kernel from flashinfer.\"\"\"\n\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        skip_prefill: bool = False,\n        kv_indptr_buf: Optional[torch.Tensor] = None,\n        q_indptr_decode_buf: Optional[torch.Tensor] = None,\n    ):\n        super().__init__(\n            model_runner,\n            skip_prefill,\n            kv_indptr_buf,\n            q_indptr_decode_buf,\n        )\n\n        config = model_runner.model_config\n\n        # Model parameters\n        self.num_q_heads = config.num_attention_heads // get_attention_tp_size()\n        self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())\n        self.num_local_heads = config.num_attention_heads // get_attention_tp_size()\n\n        # MLA-specific dimensions\n        self.kv_lora_rank = config.kv_lora_rank\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.v_head_dim = config.v_head_dim\n        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim\n\n        # Runtime parameters\n        self.scaling = config.scaling\n        self.data_type = model_runner.kv_cache_dtype\n        self.q_data_type = model_runner.dtype\n        self.page_size = model_runner.page_size\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n\n      
  # Workspace allocation\n        self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024\n        global global_zero_init_workspace_buffer\n        if global_zero_init_workspace_buffer is None:\n            global_zero_init_workspace_buffer = torch.zeros(\n                self.workspace_size,\n                dtype=torch.uint8,\n                device=model_runner.device,\n            )\n        self.workspace_buffer = global_zero_init_workspace_buffer\n\n        # CUDA graph state\n        self.decode_cuda_graph_metadata = {}\n        self.decode_cuda_graph_kv_indices = None\n        self.padded_q_buffer = None\n        self.unpad_output_buffer = None\n        self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None\n        self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None\n\n        self.disable_chunked_prefix_cache = (\n            get_global_server_args().disable_chunked_prefix_cache\n        )\n\n        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens\n\n    def _calc_padded_blocks(self, max_seq_len: int) -> int:\n        \"\"\"\n        Calculate padded block count that satisfies both TRT-LLM and Triton constraints.\n\n        Args:\n            max_seq_len: Maximum sequence length in tokens\n\n        Returns:\n            Number of blocks padded to satisfy all constraints\n        \"\"\"\n        blocks = triton.cdiv(max_seq_len, self.page_size)\n\n        # Apply dual constraints (take LCM to satisfy both):\n        # 1. TRT-LLM: block_num % (128 / page_size) == 0\n        # 2. 
Triton: number of pages per block\n        trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size\n        triton_constraint = get_num_page_per_block_flashmla(self.page_size)\n        constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)\n\n        if blocks % constraint_lcm != 0:\n            blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm\n        return blocks\n\n    def _create_block_kv_indices(\n        self,\n        batch_size: int,\n        max_blocks: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        device: torch.device,\n    ) -> torch.Tensor:\n        \"\"\"\n        Create block KV indices tensor using Triton kernel.\n\n        Args:\n            batch_size: Batch size\n            max_blocks: Maximum number of blocks per sequence\n            req_pool_indices: Request pool indices\n            seq_lens: Sequence lengths\n            device: Target device\n\n        Returns:\n            Block KV indices tensor\n        \"\"\"\n        block_kv_indices = torch.full(\n            (batch_size, max_blocks), -1, dtype=torch.int32, device=device\n        )\n\n        create_flashmla_kv_indices_triton[(batch_size,)](\n            self.req_to_token,\n            req_pool_indices,\n            seq_lens,\n            None,\n            block_kv_indices,\n            self.req_to_token.stride(0),\n            max_blocks,\n            PAGED_SIZE=self.page_size,\n        )\n\n        return block_kv_indices\n\n    def init_cuda_graph_state(\n        self,\n        max_bs: int,\n        max_num_tokens: int,\n        kv_indices_buf: Optional[torch.Tensor] = None,\n    ):\n        \"\"\"Initialize CUDA graph state for TRTLLM MLA.\"\"\"\n\n        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)\n\n        self.decode_cuda_graph_kv_indices = torch.full(\n            (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device\n        )\n        num_tokens_per_bs = 
max_num_tokens // max_bs\n\n        if is_float4_e2m1fn_x2(self.data_type):\n            # Buffer for padded query: (max_bs, max_draft_tokens, num_q_heads, v_head_dim)\n            self.store_dtype = torch.uint8\n            self.padded_q_buffer = torch.zeros(\n                (max_bs, num_tokens_per_bs // 2, self.num_q_heads, self.kv_cache_dim),\n                dtype=self.store_dtype,\n                device=self.device,\n            )\n\n            # Buffer for unpadded output: (max_num_tokens, num_q_heads, v_head_dim)\n            self.unpad_output_buffer = torch.zeros(\n                (max_num_tokens // 2, self.num_q_heads, 512),\n                dtype=self.store_dtype,\n                device=self.device,\n            )\n        else:\n            # Buffer for padded query: (max_bs, max_draft_tokens, num_q_heads, v_head_dim)\n            self.padded_q_buffer = torch.zeros(\n                (max_bs, num_tokens_per_bs, self.num_q_heads, self.kv_cache_dim),\n                dtype=self.data_type,\n                device=self.device,\n            )\n\n            # Buffer for unpadded output: (max_num_tokens, num_q_heads, v_head_dim)\n            self.unpad_output_buffer = torch.zeros(\n                (max_num_tokens, self.num_q_heads, 512),\n                dtype=self.data_type,\n                device=self.device,\n            )\n\n        super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        \"\"\"Initialize metadata for CUDA graph capture.\"\"\"\n\n        # Delegate to parent for non-decode modes.\n        if (\n            not forward_mode.is_decode_or_idle()\n            and not forward_mode.is_target_verify()\n       
     and not forward_mode.is_draft_extend(include_v2=True)\n        ):\n            return super().init_forward_metadata_capture_cuda_graph(\n                bs,\n                num_tokens,\n                req_pool_indices,\n                seq_lens,\n                encoder_lens,\n                forward_mode,\n                spec_info,\n            )\n\n        metadata = TRTLLMMLADecodeMetadata()\n\n        if forward_mode.is_target_verify():\n            seq_lens = seq_lens + self.num_draft_tokens\n            metadata.seq_lens_k = torch.zeros(\n                (bs,), dtype=torch.int32, device=seq_lens.device\n            )\n            metadata.seq_lens_k.copy_(seq_lens.to(dtype=torch.int32))\n        elif forward_mode.is_draft_extend(include_v2=True):\n            num_tokens_per_bs = num_tokens // bs\n            metadata.max_seq_len_q = num_tokens_per_bs\n            metadata.sum_seq_lens_q = num_tokens_per_bs * bs\n            metadata.cu_seqlens_q = torch.arange(\n                0,\n                bs * num_tokens_per_bs + 1,\n                num_tokens_per_bs,\n                dtype=torch.int32,\n                device=seq_lens.device,\n            )\n            metadata.seq_lens_q = torch.full(\n                (bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device\n            )\n            # NOTE(draft_extend seq_len handling):\n            # forward_batch.seq_lens is the seq_lens of the prev_context + verified tokens.\n            # To account for pad_draft_extend_query, we need seq_lens = prev_context + max_draft_tokens.\n            # This will ensure queries align with kvs correctly when calling\n            # flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla.\n            seq_lens = seq_lens - metadata.seq_lens_q + metadata.max_seq_len_q\n            metadata.seq_lens_k = torch.zeros(\n                (bs,), dtype=torch.int32, device=seq_lens.device\n            )\n            
metadata.seq_lens_k.copy_(seq_lens.to(dtype=torch.int32))\n\n        # Custom fast-path for decode/idle.\n        # Capture with full width so future longer sequences are safe during replay\n        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)\n        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_blocks_per_seq]\n\n        create_flashmla_kv_indices_triton[(bs,)](\n            self.req_to_token,\n            req_pool_indices,\n            seq_lens,\n            None,\n            block_kv_indices,\n            self.req_to_token.stride(0),\n            max_blocks_per_seq,\n            PAGED_SIZE=self.page_size,\n        )\n\n        metadata.block_kv_indices = block_kv_indices\n        metadata.max_seq_len_k = self.max_context_len\n\n        self.decode_cuda_graph_metadata[bs] = metadata\n        self.forward_decode_metadata = metadata\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        \"\"\"Replay CUDA graph with new inputs.\"\"\"\n        # Delegate to parent for non-decode modes.\n        if (\n            not forward_mode.is_decode_or_idle()\n            and not forward_mode.is_target_verify()\n            and not forward_mode.is_draft_extend(include_v2=True)\n        ):\n            return super().init_forward_metadata_replay_cuda_graph(\n                bs,\n                req_pool_indices,\n                seq_lens,\n                seq_lens_sum,\n                encoder_lens,\n                forward_mode,\n                spec_info,\n                seq_lens_cpu,\n            )\n\n        metadata = self.decode_cuda_graph_metadata[bs]\n\n        if forward_mode.is_target_verify():\n            
seq_lens = seq_lens[:bs] + self.num_draft_tokens\n            metadata.seq_lens_k.copy_(seq_lens.to(dtype=torch.int32))\n            del seq_lens_sum  # not handle \"num_draft_tokens\" but we do not need it\n        elif forward_mode.is_draft_extend(include_v2=True):\n            accept_length = spec_info.accept_length[:bs]\n            if spec_info.accept_length_cpu:\n                metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs]) + 1\n                metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs]) + bs\n            else:\n                metadata.max_seq_len_q = 1\n                metadata.sum_seq_lens_q = bs\n            # draft_extend uses (accept_length + 1) query tokens per sequence\n            extend_seq_lens = accept_length + 1\n            metadata.cu_seqlens_q[1:].copy_(\n                torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32)\n            )\n            metadata.seq_lens_q.copy_(extend_seq_lens)\n            # see NOTE(draft_extend seq_len handling)\n            seq_lens = seq_lens[:bs] - metadata.seq_lens_q + metadata.max_seq_len_q\n            metadata.seq_lens_k.copy_(seq_lens.to(torch.int32))\n\n        # Update block indices for new sequences.\n        create_flashmla_kv_indices_triton[(bs,)](\n            self.req_to_token,\n            req_pool_indices[:bs],\n            seq_lens,\n            None,\n            metadata.block_kv_indices,\n            self.req_to_token.stride(0),\n            metadata.block_kv_indices.shape[1],\n            PAGED_SIZE=self.page_size,\n        )\n\n    def get_cuda_graph_seq_len_fill_value(self) -> int:\n        \"\"\"Get the fill value for sequence lengths in CUDA graph.\"\"\"\n        return 1\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Initialize the metadata for a forward pass.\"\"\"\n        # Delegate to parent for non-decode modes.\n        if (\n            forward_batch.forward_mode.is_extend()\n            and not 
forward_batch.forward_mode.is_target_verify()\n            and not forward_batch.forward_mode.is_draft_extend(include_v2=True)\n        ):\n            # For extend batch with prefix length > 0, fallback to ragged kernel implemented in flashinfer MLA backend\n            # when chunked prefix cache is disabled.\n            # Also fallback to flashinfer MLA backend when in piecewise cuda graph, since it only supports MLA forward mode.\n            has_prefix = any(forward_batch.extend_prefix_lens_cpu)\n            fallback_to_flashinfer_impl = (\n                self.disable_chunked_prefix_cache and has_prefix\n            ) or is_in_piecewise_cuda_graph()\n            if fallback_to_flashinfer_impl:\n                super().init_forward_metadata(forward_batch)\n\n            seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens\n            cum_seq_lens_q = torch.cat(\n                (\n                    torch.zeros(\n                        1, dtype=torch.int32, device=forward_batch.seq_lens.device\n                    ),\n                    torch.cumsum(seq_lens, dim=0),\n                )\n            ).int()\n            max_seq_len = max(forward_batch.extend_seq_lens_cpu)\n            self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(\n                max_seq_len,\n                cum_seq_lens_q,\n                seq_lens,\n                fallback_to_flashinfer_impl,\n            )\n        elif (\n            forward_batch.forward_mode.is_decode_or_idle()\n            or forward_batch.forward_mode.is_target_verify()\n            or forward_batch.forward_mode.is_draft_extend(include_v2=True)\n        ):\n            bs = forward_batch.batch_size\n            self.forward_decode_metadata = TRTLLMMLADecodeMetadata()\n            # This is necessary because the backend instance persists across forward passes,\n            # and forward_prefill_metadata from a previous regular extend call could still be set.\n            if (\n          
      forward_batch.forward_mode.is_target_verify()\n                or forward_batch.forward_mode.is_draft_extend(include_v2=True)\n            ):\n                self.forward_prefill_metadata = None\n            # Get maximum sequence length.\n            if getattr(forward_batch, \"seq_lens_cpu\", None) is not None:\n                max_seq = forward_batch.seq_lens_cpu.max().item()\n            else:\n                max_seq = forward_batch.seq_lens.max().item()\n\n            seq_lens = forward_batch.seq_lens\n\n            if forward_batch.forward_mode.is_target_verify():\n                max_seq = max_seq + self.num_draft_tokens\n                seq_lens = seq_lens + self.num_draft_tokens\n                self.forward_decode_metadata.seq_lens_k = seq_lens.to(torch.int32)\n            elif forward_batch.forward_mode.is_draft_extend(include_v2=True):\n                sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)\n                max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)\n                cu_seqlens_q = torch.nn.functional.pad(\n                    torch.cumsum(\n                        forward_batch.extend_seq_lens, dim=0, dtype=torch.int32\n                    ),\n                    (1, 0),\n                )\n                # see NOTE(draft_extend seq_len handling)\n                seq_lens = seq_lens - forward_batch.extend_seq_lens + max_seq_len_q\n\n                self.forward_decode_metadata.max_seq_len_q = max_seq_len_q\n                self.forward_decode_metadata.sum_seq_lens_q = sum_seq_lens_q\n                self.forward_decode_metadata.cu_seqlens_q = cu_seqlens_q\n                self.forward_decode_metadata.seq_lens_q = forward_batch.extend_seq_lens\n                self.forward_decode_metadata.seq_lens_k = seq_lens.to(torch.int32)\n\n            max_seqlen_pad = self._calc_padded_blocks(max_seq)\n            block_kv_indices = self._create_block_kv_indices(\n                bs,\n                max_seqlen_pad,\n           
     forward_batch.req_pool_indices,\n                seq_lens,\n                seq_lens.device,\n            )\n\n            self.forward_decode_metadata.block_kv_indices = block_kv_indices\n            self.forward_decode_metadata.max_seq_len_k = int(max_seq)\n            self.forward_decode_metadata.batch_size = bs\n\n            forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata\n        else:\n            return super().init_forward_metadata(forward_batch)\n\n    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):\n        super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)\n\n    def pad_draft_extend_query(\n        self,\n        q: torch.Tensor,\n        padded_q: torch.Tensor,\n        seq_lens_q: torch.Tensor,\n        cu_seqlens_q: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Pad draft extended query using Triton kernel.\"\"\"\n        batch_size = cu_seqlens_q.shape[0] - 1\n        max_seq_len_q = padded_q.shape[1]\n        num_heads = padded_q.shape[2]\n        head_dim = padded_q.shape[3]\n\n        # Launch Triton kernel with 3D grid for parallelized head and dim processing\n        BLOCK_SIZE = 64\n        num_head_blocks = triton.cdiv(num_heads, BLOCK_SIZE)\n        num_dim_blocks = triton.cdiv(head_dim, BLOCK_SIZE)\n        grid = (batch_size * max_seq_len_q, num_head_blocks, num_dim_blocks)\n\n        pad_draft_extend_query_kernel[grid](\n            q_ptr=q,\n            padded_q_ptr=padded_q,\n            seq_lens_q_ptr=seq_lens_q,\n            cumsum_ptr=cu_seqlens_q,\n            batch_size=batch_size,\n            max_seq_len=max_seq_len_q,\n            num_heads=num_heads,\n            head_dim=head_dim,\n            BLOCK_SIZE=BLOCK_SIZE,\n        )\n        return padded_q\n\n    def unpad_draft_extend_output(\n        self,\n        raw_out: torch.Tensor,\n        cu_seqlens_q: torch.Tensor,\n        seq_lens_q: torch.Tensor,\n        sum_seq_lens_q: int,\n    ) -> 
torch.Tensor:\n        \"\"\"Unpad draft extended output using Triton kernel.\"\"\"\n        # raw_out: (batch_size, token_per_batch, layer.tp_q_head_num, layer.v_head_dim)\n        batch_size = seq_lens_q.shape[0]\n        token_per_batch = raw_out.shape[1]  # max_seq_len\n        tp_q_head_num = raw_out.shape[2]  # num_heads\n        v_head_dim = raw_out.shape[3]  # head_dim\n        total_tokens = sum_seq_lens_q\n\n        # Check if we're in CUDA graph mode (buffers are pre-allocated)\n        if self.unpad_output_buffer is not None:\n            # Use pre-allocated buffer for CUDA graph compatibility\n            output = self.unpad_output_buffer[:total_tokens, :, :].to(\n                dtype=raw_out.dtype\n            )\n        else:\n            # Dynamic allocation for non-CUDA graph mode\n            output = torch.empty(\n                (total_tokens, tp_q_head_num, v_head_dim),\n                dtype=raw_out.dtype,\n                device=raw_out.device,\n            )\n\n        # Launch Triton kernel with 3D grid for parallelized head and dim processing\n        BLOCK_SIZE = 64\n        num_head_blocks = triton.cdiv(tp_q_head_num, BLOCK_SIZE)\n        num_dim_blocks = triton.cdiv(v_head_dim, BLOCK_SIZE)\n        grid = (batch_size * token_per_batch, num_head_blocks, num_dim_blocks)\n\n        unpad_draft_extend_output_kernel[grid](\n            raw_out_ptr=raw_out,\n            output_ptr=output,\n            accept_length_ptr=seq_lens_q,\n            cumsum_ptr=cu_seqlens_q,\n            batch_size=batch_size,\n            token_per_batch=token_per_batch,\n            tp_q_head_num=tp_q_head_num,\n            v_head_dim=v_head_dim,\n            BLOCK_SIZE=BLOCK_SIZE,\n        )\n        return output[:total_tokens, :, :]\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,  # q_nope\n        k: torch.Tensor,  # k_nope\n        v: torch.Tensor,  # not used in this backend\n        layer: RadixAttention,\n        forward_batch: 
ForwardBatch,\n        save_kv_cache: bool = True,\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        cos_sin_cache: Optional[torch.Tensor] = None,\n        is_neox: Optional[bool] = False,\n        llama_4_scaling: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        \"\"\"Run forward for decode using TRTLLM MLA kernel.\"\"\"\n        merge_query = q_rope is not None\n        if self.data_type == torch.float8_e4m3fn:\n            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor\n            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend\n            assert all(\n                x is not None for x in [q_rope, k_rope, cos_sin_cache]\n            ), \"For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None.\"\n            q, k, k_rope = mla_quantize_and_rope_for_fp8(\n                q,\n                q_rope,\n                k.squeeze(1),\n                k_rope.squeeze(1),\n                forward_batch.positions,\n                cos_sin_cache,\n                is_neox,\n                self.kv_lora_rank,\n                self.qk_rope_head_dim,\n            )\n            merge_query = False\n\n        # Save KV cache if requested\n        if save_kv_cache:\n            assert (\n                k is not None and k_rope is not None\n            ), \"For populating trtllm_mla kv cache, both k_nope and k_rope should be not None.\"\n            forward_batch.token_to_kv_pool.set_mla_kv_buffer(\n                layer, forward_batch.out_cache_loc, k, k_rope\n            )\n\n        # Prepare query tensor inline\n        if merge_query:\n            # For FP16 path, we merge the query and rope parts into a single tensor\n            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n            q_rope_reshaped = 
q_rope.view(\n                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim\n            )\n            query = concat_mla_absorb_q_general(q_nope, q_rope_reshaped)\n        else:\n            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function\n            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)\n\n        # Apply llama 4 scaling if provided\n        if llama_4_scaling is not None:\n            query = query.to(self.q_data_type) * llama_4_scaling\n            query = query.to(self.data_type)\n\n        # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1\n        if query.dim() == 3:\n            query = query.unsqueeze(1)\n\n        # Prepare KV cache inline\n        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n        kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)\n\n        # Get metadata\n        metadata = (\n            getattr(forward_batch, \"decode_trtllm_mla_metadata\", None)\n            or self.forward_decode_metadata\n        )\n\n        # Ensure batch_size is sufficient, the batch size increase due to the padding from the forward batch\n        # FIXME(@rainj-me), refactor the skip_attn_backend_init, init_forward_metadata for attn backends\n        # and padding logic in prepare_mlp_sync_batch to avoid this\n        batch_size = getattr(metadata, \"batch_size\", None)\n        if batch_size is not None and batch_size < forward_batch.batch_size:\n            self.init_forward_metadata(forward_batch)\n            metadata = forward_batch.decode_trtllm_mla_metadata\n\n        # Scale computation for TRTLLM MLA kernel BMM1 operation:\n        # The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale\n        # Scale components:\n        # - q_scale: Query scaling factor (set to 1.0 for both FP16/FP8 paths)\n        # - k_scale: Key scaling factor from model 
checkpoint. Only applied when KV cache\n        #   stores FP8-quantized values, to compensate for the quantization scaling.\n        #   For BF16/FP16 KV cache, k_scale must be 1.0 since values are unscaled.\n        # - softmax_scale: Attention softmax scaling = 1/sqrt(head_dim), pre-computed as layer.scaling\n        q_scale = 1.0\n        if self.data_type == torch.float8_e4m3fn:\n            k_scale = (\n                layer.k_scale_float\n                if getattr(layer, \"k_scale_float\", None) is not None\n                else 1.0\n            )\n        else:\n            if getattr(layer, \"k_scale_float\", None) is not None:\n                logger.warning_once(\n                    \"Checkpoint has k_scale but KV cache dtype is not FP8. \"\n                    \"Ignoring k_scale for BMM1 (k_scale=%.4f, kv_dtype=%s).\",\n                    layer.k_scale_float,\n                    self.data_type,\n                )\n            k_scale = 1.0\n\n        bmm1_scale = q_scale * k_scale * layer.scaling\n\n        # Call TRT-LLM kernel\n        raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(\n            query=query,\n            kv_cache=kv_cache,\n            workspace_buffer=self.workspace_buffer,\n            qk_nope_head_dim=self.qk_nope_head_dim,\n            kv_lora_rank=self.kv_lora_rank,\n            qk_rope_head_dim=self.qk_rope_head_dim,\n            block_tables=metadata.block_kv_indices,\n            seq_lens=forward_batch.seq_lens.to(torch.int32),\n            max_seq_len=metadata.max_seq_len_k,\n            bmm1_scale=bmm1_scale,\n        )\n\n        # Reshape output directly without slicing\n        output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n        return output\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache: bool = True,\n        
q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        cos_sin_cache: Optional[torch.Tensor] = None,\n        is_neox: Optional[bool] = False,\n        llama_4_scaling: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n\n        if (\n            self.forward_prefill_metadata is not None\n            and self.forward_prefill_metadata.fallback_to_flashinfer_impl\n        ):\n            return super().forward_extend(\n                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope\n            )\n\n        # TODO refactor to avoid code duplication\n        merge_query = q_rope is not None\n        if (\n            self.data_type == torch.float8_e4m3fn\n        ) and forward_batch.forward_mode.is_target_verify():\n            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor\n            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend\n            assert all(\n                x is not None for x in [q_rope, k_rope, cos_sin_cache]\n            ), \"For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None.\"\n            q, k, k_rope = mla_quantize_and_rope_for_fp8(\n                q,\n                q_rope,\n                k.squeeze(1),\n                k_rope.squeeze(1),\n                forward_batch.positions,\n                cos_sin_cache,\n                is_neox,\n                self.kv_lora_rank,\n                self.qk_rope_head_dim,\n            )\n            merge_query = False\n\n        # Save KV cache if requested\n        if save_kv_cache:\n            assert (\n                k is not None and k_rope is not None\n            ), \"For populating trtllm_mla kv cache, both k_nope and k_rope should be not None.\"\n            forward_batch.token_to_kv_pool.set_mla_kv_buffer(\n                layer, 
forward_batch.out_cache_loc, k, k_rope\n            )\n\n        # TODO refactor to avoid code duplication\n        # Prepare query tensor inline\n        if merge_query:\n            # For FP16 path, we merge the query and rope parts into a single tensor\n            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n            q_rope_reshaped = q_rope.view(\n                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim\n            )\n            q = concat_mla_absorb_q_general(q_nope, q_rope_reshaped)\n\n        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)\n\n        # Apply llama 4 scaling if provided\n        if llama_4_scaling is not None:\n            q = q.to(self.q_data_type) * llama_4_scaling\n            q = q.to(self.data_type)\n\n        if (\n            forward_batch.forward_mode.is_target_verify()\n            or forward_batch.forward_mode.is_draft_extend(include_v2=True)\n        ):\n            metadata = (\n                getattr(forward_batch, \"decode_trtllm_mla_metadata\", None)\n                or self.forward_decode_metadata\n            )\n\n            # Ensure batch_size is sufficient, the batch size increase due to the padding from the forward batch\n            # FIXME(@rainj-me), refactor the skip_attn_backend_init, init_forward_metadata for attn backends\n            # and padding logic in prepare_mlp_sync_batch to avoid this\n            batch_size = getattr(metadata, \"batch_size\", None)\n            if batch_size is not None and batch_size < forward_batch.batch_size:\n                self.init_forward_metadata(forward_batch)\n                metadata = forward_batch.decode_trtllm_mla_metadata\n\n            # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]\n            bs = forward_batch.batch_size\n\n            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)\n            kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)\n\n      
      q_scale = 1.0\n            if self.data_type == torch.float8_e4m3fn:\n                k_scale = (\n                    layer.k_scale_float\n                    if getattr(layer, \"k_scale_float\", None) is not None\n                    else 1.0\n                )\n            else:\n                if getattr(layer, \"k_scale_float\", None) is not None:\n                    logger.warning_once(\n                        \"Checkpoint has k_scale but KV cache dtype is not FP8. \"\n                        \"Ignoring k_scale for BMM1 (k_scale=%.4f, kv_dtype=%s).\",\n                        layer.k_scale_float,\n                        self.data_type,\n                    )\n                k_scale = 1.0\n            q = q.to(self.data_type)\n\n            bmm1_scale = q_scale * k_scale * layer.scaling\n            if forward_batch.forward_mode.is_target_verify():\n                max_seq_len = (\n                    metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num\n                )\n                # For target_verify, all sequences have the same number of draft tokens\n                q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)\n                needs_unpad = False\n            else:\n                # draft_extend: handle varying accept_lengths. 
If total_tokens % bs == 0,\n                # we can directly reshape q; otherwise, pad to max_seq_len_q.\n                total_tokens = q.shape[0]\n                tokens_per_seq = total_tokens // bs if bs > 0 else 0\n                can_direct_view = bs > 0 and (total_tokens % bs == 0)\n\n                if can_direct_view:\n                    max_seq_len = metadata.max_seq_len_k + tokens_per_seq\n                    q = q.view(bs, tokens_per_seq, layer.tp_q_head_num, layer.head_dim)\n                    needs_unpad = False\n                else:\n                    # Varying lengths: pad q to (bs, max_seq_len_q, ...)\n                    actual_seq_lens_q = forward_batch.extend_seq_lens\n                    actual_max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)\n                    max_seq_len = metadata.max_seq_len_k + actual_max_seq_len_q\n\n                    actual_cu_seqlens_q = torch.nn.functional.pad(\n                        torch.cumsum(actual_seq_lens_q, dim=0, dtype=torch.int32),\n                        (1, 0),\n                    )\n\n                    if self.padded_q_buffer is not None:\n                        padded_q = self.padded_q_buffer[\n                            :bs, :actual_max_seq_len_q, :, :\n                        ].to(dtype=q.dtype)\n                        padded_q.zero_()\n                    else:\n                        padded_q = torch.zeros(\n                            (\n                                bs,\n                                actual_max_seq_len_q,\n                                layer.tp_q_head_num,\n                                layer.head_dim,\n                            ),\n                            dtype=q.dtype,\n                            device=q.device,\n                        )\n\n                    q = self.pad_draft_extend_query(\n                        q, padded_q, actual_seq_lens_q, actual_cu_seqlens_q\n                    )\n                    needs_unpad = True\n         
           unpad_seq_lens_q = actual_seq_lens_q\n                    unpad_cu_seqlens_q = actual_cu_seqlens_q\n                    unpad_sum_seq_lens_q = total_tokens\n\n            assert kv_cache.dtype == self.data_type\n\n            raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(\n                query=q,\n                kv_cache=kv_cache,\n                workspace_buffer=self.workspace_buffer,\n                qk_nope_head_dim=self.qk_nope_head_dim,\n                kv_lora_rank=self.kv_lora_rank,\n                qk_rope_head_dim=self.qk_rope_head_dim,\n                block_tables=metadata.block_kv_indices,\n                seq_lens=metadata.seq_lens_k,\n                max_seq_len=max_seq_len,\n                bmm1_scale=bmm1_scale,\n            )\n\n            if needs_unpad:\n                # Unpad the output for draft_extend mode with varying lengths\n                # Use the actual values computed during padding, not from metadata\n                output = self.unpad_draft_extend_output(\n                    raw_out,\n                    unpad_cu_seqlens_q,\n                    unpad_seq_lens_q,\n                    unpad_sum_seq_lens_q,\n                )\n                output = output.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n            else:\n                output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n            return output\n\n        if k_rope is not None:\n            k = torch.cat([k, k_rope], dim=-1)\n        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)\n        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)\n\n        q_scale = k_scale = v_scale = 1.0\n        if self.data_type == torch.float8_e4m3fn:\n            q, k, v, k_scale, v_scale = _quantize_fp8_qkv(q, k, v, layer)\n\n        common_trtllm_args = {\n            \"query\": q,\n            \"key\": k,\n            \"value\": v,\n            \"workspace_buffer\": self.workspace_buffer,\n            \"batch_size\": 
forward_batch.batch_size,\n            \"window_left\": -1,\n            \"enable_pdl\": False,\n            \"max_q_len\": self.forward_prefill_metadata.max_seq_len,\n            \"bmm1_scale\": q_scale * k_scale * layer.scaling,\n            \"bmm2_scale\": v_scale,\n            \"cum_seq_lens_q\": self.forward_prefill_metadata.cum_seq_lens,\n        }\n\n        # When chunked prefix cache is enabled, dispatch to different path for ragged attention.\n        if forward_batch.attn_attend_prefix_cache:\n            # MHA for chunked prefix kv cache when running model with MLA\n            assert forward_batch.prefix_chunk_idx is not None\n            assert forward_batch.prefix_chunk_cu_seq_lens is not None\n            assert q_rope is None\n            assert k_rope is None\n            chunk_idx = forward_batch.prefix_chunk_idx\n\n            out = torch.zeros(\n                q.shape[0],\n                layer.tp_q_head_num,\n                layer.v_head_dim,\n                dtype=self.q_data_type,\n                device=q.device,\n            )\n            return flashinfer.prefill.trtllm_ragged_attention_deepseek(\n                **common_trtllm_args,\n                seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],\n                max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],\n                o_sf_scale=-1.0,\n                cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],\n                is_causal=False,\n                return_lse=True,\n                out=out,\n            )\n        else:\n            out = torch.zeros(\n                q.shape[0],\n                q.shape[1],\n                v.shape[2],\n                device=q.device,\n                dtype=self.q_data_type,\n            )\n            return flashinfer.prefill.trtllm_ragged_attention_deepseek(\n                **common_trtllm_args,\n                seq_lens=self.forward_prefill_metadata.seq_lens,\n                
max_kv_len=self.forward_prefill_metadata.max_seq_len,\n                o_sf_scale=1.0,\n                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,\n                is_causal=True,\n                return_lse=forward_batch.mha_return_lse,\n                out=out,\n            )\n\n\nclass TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):\n    \"\"\"Multi-step draft backend for TRT-LLM MLA used by EAGLE.\"\"\"\n\n    def __init__(\n        self, model_runner: \"ModelRunner\", topk: int, speculative_num_steps: int\n    ):\n        super().__init__(model_runner, topk, speculative_num_steps)\n\n        for i in range(self.speculative_num_steps - 1):\n            self.attn_backends[i] = TRTLLMMLABackend(\n                model_runner,\n                skip_prefill=True,\n                kv_indptr_buf=self.kv_indptr[i],\n                q_indptr_decode_buf=self.q_indptr_decode,\n            )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/utils.py",
    "content": "import torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.utils import is_cuda\n\n_FLASHMLA_CREATE_KV_BLOCK_SIZE = 4096\nFLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON = tl.constexpr(_FLASHMLA_CREATE_KV_BLOCK_SIZE)\n\n_is_cuda = is_cuda()\n\nif _is_cuda:\n    from sgl_kernel import concat_mla_absorb_q\n\n\n@triton.jit\ndef create_flashinfer_kv_indices_triton(\n    req_to_token_ptr,  # [max_batch, max_context_len]\n    req_pool_indices_ptr,\n    page_kernel_lens_ptr,\n    kv_indptr,\n    kv_start_idx,\n    kv_indices_ptr,\n    req_to_token_ptr_stride: tl.constexpr,\n):\n    BLOCK_SIZE: tl.constexpr = 512\n    pid = tl.program_id(axis=0)\n\n    # find the req pool idx, this is for batch to token\n    req_pool_index = tl.load(req_pool_indices_ptr + pid)\n    kv_indices_offset = tl.load(kv_indptr + pid)\n\n    kv_start = 0\n    kv_end = 0\n    if kv_start_idx:\n        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)\n        kv_end = kv_start\n    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)\n\n    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)\n    for i in range(num_loop):\n        # index into req_to_token_ptr needs to be int64\n        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE\n        mask = offset < kv_end - kv_start\n        data = tl.load(\n            req_to_token_ptr\n            + req_pool_index * req_to_token_ptr_stride\n            + kv_start\n            + offset,\n            mask=mask,\n        )\n        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)\n\n\ndef get_num_page_per_block_flashmla(page_size: int = 64) -> int:\n    num_page_per_block = _FLASHMLA_CREATE_KV_BLOCK_SIZE // page_size\n    return num_page_per_block\n\n\n@triton.jit\ndef create_flashmla_kv_indices_triton(\n    req_to_token_ptr,  # [max_batch, max_context_len]\n    req_pool_indices_ptr,\n    page_kernel_lens_ptr,\n    kv_start_idx,\n    kv_indices_ptr,\n    req_to_token_ptr_stride: 
tl.constexpr,\n    kv_indices_ptr_stride: tl.constexpr,\n    PAGED_SIZE: tl.constexpr = 64,\n):\n    NUM_PAGE_PER_BLOCK: tl.constexpr = (\n        FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON // PAGED_SIZE\n    )\n    pid = tl.program_id(axis=0)\n\n    # find the req pool idx, this is for batch to token\n    req_pool_index = tl.load(req_pool_indices_ptr + pid)\n\n    kv_start = 0\n    kv_end = 0\n    if kv_start_idx:\n        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)\n        kv_end = kv_start\n\n    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)\n\n    num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE)\n    num_pages_loop = tl.cdiv(kv_end - kv_start, FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON)\n\n    for i in range(num_pages_loop):\n        # index into req_to_token_ptr needs to be int64\n        paged_offset = (\n            tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK\n        ) * PAGED_SIZE\n        paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK\n\n        mask = paged_offset < num_paged * PAGED_SIZE\n        mask_out = paged_offset_out < num_paged\n\n        data = tl.load(\n            req_to_token_ptr\n            + req_pool_index * req_to_token_ptr_stride\n            + kv_start\n            + paged_offset,\n            mask=mask,\n        )\n        tl.store(\n            kv_indices_ptr + pid * kv_indices_ptr_stride + paged_offset_out,\n            data // PAGED_SIZE,\n            mask=mask_out,\n        )\n\n\n@triton.jit\ndef concat_and_cast_mha_k_kernel(\n    k_ptr,\n    k_nope_ptr,\n    k_rope_ptr,\n    head_cnt: tl.constexpr,\n    k_stride0: tl.constexpr,\n    k_stride1: tl.constexpr,\n    nope_stride0: tl.constexpr,\n    nope_stride1: tl.constexpr,\n    rope_stride0: tl.constexpr,\n    nope_dim: tl.constexpr,\n    rope_dim: tl.constexpr,\n):\n    pid_loc = tl.program_id(0)\n    head_range = tl.arange(0, head_cnt)\n\n    k_head_ptr = k_ptr + pid_loc * k_stride0 + head_range[:, None] * 
k_stride1\n\n    nope_offs = tl.arange(0, nope_dim)\n\n    src_nope_ptr = (\n        k_nope_ptr\n        + pid_loc * nope_stride0\n        + head_range[:, None] * nope_stride1\n        + nope_offs[None, :]\n    )\n    dst_nope_ptr = k_head_ptr + nope_offs[None, :]\n\n    src_nope = tl.load(src_nope_ptr)\n    tl.store(dst_nope_ptr, src_nope)\n\n    rope_offs = tl.arange(0, rope_dim)\n    src_rope_ptr = k_rope_ptr + pid_loc * rope_stride0 + rope_offs[None, :]\n    dst_rope_ptr = k_head_ptr + nope_dim + rope_offs[None, :]\n    src_rope = tl.load(src_rope_ptr)\n    tl.store(dst_rope_ptr, src_rope)\n\n\ndef concat_and_cast_mha_k_triton(\n    k: torch.Tensor,\n    k_nope: torch.Tensor,\n    k_rope: torch.Tensor,\n):\n    # The source data type will be implicitly converted to the target data type.\n    assert (\n        len(k.shape) == 3 and len(k_nope.shape) == 3 and len(k_rope.shape) == 3\n    ), f\"shape should be 3d, but got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}\"\n    assert (\n        k.shape[0] == k_nope.shape[0] and k.shape[0] == k_rope.shape[0]\n    ), f\"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}\"\n    assert (\n        k.shape[1] == k_nope.shape[1] and 1 == k_rope.shape[1]\n    ), f\"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}\"\n    assert (\n        k.shape[-1] == k_nope.shape[-1] + k_rope.shape[-1]\n    ), f\"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}\"\n\n    nope_dim = k_nope.shape[-1]\n    rope_dim = k_rope.shape[-1]\n    grid = (k.shape[0],)\n\n    concat_and_cast_mha_k_kernel[grid](\n        k,\n        k_nope,\n        k_rope,\n        k.shape[1],\n        k.stride(0),\n        k.stride(1),\n        k_nope.stride(0),\n        k_nope.stride(1),\n        k_rope.stride(0),\n        nope_dim,\n        rope_dim,\n    )\n\n\n@triton.jit\ndef pad_sequence_with_mask_kernel(\n    input_ptr,  # (total_tokens, hidden)\n    offsets_ptr,  # (B,)\n    lengths_ptr,  # (B,)\n    output_ptr,  # (B, 
max_len, hidden)\n    mask_ptr,  # (B, max_len)\n    max_len,\n    hidden_dim,\n    BLOCK_M: tl.constexpr,  # seq block\n    BLOCK_D: tl.constexpr,  # hidden block\n):\n    b = tl.program_id(0)  # batch index\n    m = tl.program_id(1)  # seq block index\n\n    offset = tl.load(offsets_ptr + b)\n    length = tl.load(lengths_ptr + b)\n\n    seq_ids = m * BLOCK_M + tl.arange(0, BLOCK_M)\n    hid_ids = tl.arange(0, BLOCK_D)\n\n    seq_mask = seq_ids < max_len\n    valid_token = seq_ids < length\n\n    # input index\n    in_token = offset + seq_ids\n    in_ptr = input_ptr + in_token[:, None] * hidden_dim + hid_ids[None, :]\n\n    # output index\n    out_ptr = (\n        output_ptr\n        + b * max_len * hidden_dim\n        + seq_ids[:, None] * hidden_dim\n        + hid_ids[None, :]\n    )\n\n    values = tl.load(\n        in_ptr,\n        mask=valid_token[:, None] & (hid_ids[None, :] < hidden_dim),\n        other=0.0,\n    )\n\n    tl.store(\n        out_ptr,\n        values,\n        mask=seq_mask[:, None] & (hid_ids[None, :] < hidden_dim),\n    )\n\n    # attention mask\n    if tl.program_id(2) == 0:\n        mask_out_ptr = mask_ptr + b * max_len + seq_ids\n        tl.store(mask_out_ptr, valid_token, mask=seq_mask)\n\n\ndef pad_sequence_with_mask(\n    input_emb,  # (total_tokens, hidden)\n    offsets,  # (B,)\n    lengths,  # (B,)\n    max_len,\n):\n    B = offsets.shape[0]\n    hidden_dim = input_emb.shape[1]\n\n    output = torch.zeros(\n        (B, max_len, hidden_dim),\n        device=input_emb.device,\n        dtype=input_emb.dtype,\n    )\n    attn_mask = torch.empty(\n        (B * max_len),\n        device=input_emb.device,\n        dtype=torch.bool,\n    )\n\n    BLOCK_D = triton.next_power_of_2(hidden_dim)\n    BLOCK_M = triton.next_power_of_2(max_len)\n\n    grid = (\n        B,\n        triton.cdiv(max_len, BLOCK_M),\n        1,\n    )\n\n    pad_sequence_with_mask_kernel[grid](\n        input_emb,\n        offsets,\n        lengths,\n        output,\n   
     attn_mask,\n        max_len,\n        hidden_dim,\n        BLOCK_M=BLOCK_M,\n        BLOCK_D=BLOCK_D,\n    )\n\n    return B, output, attn_mask\n\n\n@triton.jit\ndef seqlens_expand_kernel(\n    extend_seq_lens_ptr,  # [N]\n    seq_lens_ptr,  # [N]\n    offsets_ptr,  # [N+1]\n    output_ptr,  # [sum(extend_seq_lens)]\n    N,\n    BLOCK: tl.constexpr,\n):\n    pid = tl.program_id(0)\n\n    if pid >= N:\n        return\n\n    qo_len = tl.load(extend_seq_lens_ptr + pid)\n    kv_len = tl.load(seq_lens_ptr + pid)\n\n    start = kv_len - qo_len + 1\n    out_offset = tl.load(offsets_ptr + pid)\n\n    offs = tl.arange(0, BLOCK)\n    mask = offs < qo_len\n\n    values = start + offs\n    tl.store(output_ptr + out_offset + offs, values, mask=mask)\n\n\ndef seqlens_expand_triton(\n    extend_seq_lens: torch.Tensor,\n    seq_lens: torch.Tensor,\n    total_len: int,\n    max_q_len: int,\n):\n    \"\"\"\n    extend_seq_lens: [N], int32, CUDA\n    seq_lens:        [N], int32, CUDA\n    \"\"\"\n    assert extend_seq_lens.is_cuda\n    assert seq_lens.is_cuda\n\n    N = extend_seq_lens.numel()\n\n    offsets = torch.zeros(N + 1, device=extend_seq_lens.device, dtype=torch.int32)\n    offsets[1:] = torch.cumsum(extend_seq_lens, dim=0)\n    output = torch.empty(total_len, device=extend_seq_lens.device, dtype=torch.int32)\n\n    BLOCK = triton.next_power_of_2(max_q_len)\n    grid = (N,)\n\n    seqlens_expand_kernel[grid](\n        extend_seq_lens,\n        seq_lens,\n        offsets,\n        output,\n        N,\n        BLOCK=BLOCK,\n    )\n\n    return output\n\n\n# When num_kv_heads=1, we have tensors with degenerate strides.\n# For example, as below, where we have stride[-3] == stride[-2]:\n# - shape: [num_pages, 1, 64, 128]\n# - stride: [8192, 128, 128, 1]\n# This causes TMA descriptor validation to fail in flashinfer (trtllm-mha backend).\n#\n# See: https://github.com/flashinfer-ai/flashinfer/issues/2232\ndef canonicalize_stride(tensor: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n  
  Adjust degenerate strides of a tensor to make it canonical.\n    \"\"\"\n    sizes = tensor.size()\n    strides = tensor.stride()\n    ndim = tensor.dim()\n\n    need_fix = any(\n        sizes[i] == 1 and strides[i] == strides[i + 1] for i in range(ndim - 1)\n    )\n\n    if not need_fix:\n        return tensor\n\n    # canonicalize the stride\n    # Example:\n    # - shape: [num_pages, 1, 64, 128]\n    # - stride: [8192, 128, 128, 1] (wrong!)\n    # Gives new stride: [8192, 8192, 128, 1] (correct!)\n    new_strides = [0] * ndim\n    new_strides[-1] = 1\n    for i in range(ndim - 2, -1, -1):\n        new_strides[i] = new_strides[i + 1] * sizes[i + 1]\n\n    return tensor.as_strided(sizes, new_strides)\n\n\ndef mla_quantize_and_rope_for_fp8(\n    q_nope: torch.Tensor,\n    q_rope: torch.Tensor,\n    k_nope: torch.Tensor,\n    k_rope: torch.Tensor,\n    pos_ids: torch.Tensor,\n    cos_sin_cache: torch.Tensor,\n    is_neox: bool,\n    kv_lora_rank: int,\n    qk_rope_head_dim: int,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"Quantize and apply RoPE for FP8 attention path.\n\n        This function handles the FP8 quantization and RoPE application for MLA attention.\n        It takes separate query/key nope and rope components, applies RoPE to the rope parts,\n        quantizes all components to FP8, and merges the query components into a single tensor.\n\n        Args:\n            q_nope: Query no-position-encoding component [seq_len, num_heads, kv_lora_rank]\n                - expected dtype: torch.bfloat16\n            q_rope: Query RoPE component [seq_len, num_heads, qk_rope_head_dim]\n                - expected dtype: torch.bfloat16\n            k_nope: Key no-position-encoding component [seq_len, num_heads, kv_lora_rank]\n                - expected dtype: torch.bfloat16\n            k_rope: Key RoPE component [seq_len, num_heads, qk_rope_head_dim]\n                - expected dtype: torch.bfloat16\n            
pos_ids: Position indices for each token\n                - expected dtype: torch.int64 or torch.int32\n            cos_sin_cache: Precomputed cosine/sine cache for RoPE\n                - expected dtype: matches q_/k_ input dtype (torch.bfloat16)\n            is_neox: Whether to use NeoX-style RoPE (rotate first/second halves) or GPT-J-style RoPE (interleaved even/odd dimensions)\n            kv_lora_rank: Dimension of the no-position-encoding component\n            qk_rope_head_dim: Dimension of the RoPE component\n\n        Returns:\n            tuple: (merged_q_out, k_nope_out, k_rope_out) quantized to FP8\n                - merged_q_out: [seq_len, num_heads, kv_lora_rank + qk_rope_head_dim], dtype=torch.float8_e4m3fn\n                - k_nope_out:   [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn\n                - k_rope_out:   [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn\n        \"\"\"\n    import flashinfer.rope\n\n    attn_dtype = torch.float8_e4m3fn\n    q_len, num_heads = q_rope.shape[0], q_rope.shape[1]\n\n    # Allocate output tensors with FP8 dtype\n    # Query output will contain merged nope + rope components\n    q_out = q_rope.new_empty(\n        q_len,\n        num_heads,\n        kv_lora_rank + qk_rope_head_dim,\n        dtype=attn_dtype,\n    )\n\n    # Key outputs maintain original shapes but with FP8 dtype\n    k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype)\n    k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype)\n\n    # Apply RoPE and quantize all components in a single fused kernel call\n    # This kernel handles:\n    # 1. RoPE application to q_rope and k_rope using cos_sin_cache and positions\n    # 2. Quantization of all components to FP8 format\n    # 3. 
Output placement into pre-allocated tensors\n    flashinfer.rope.mla_rope_quantize_fp8(\n        q_rope=q_rope,\n        k_rope=k_rope,\n        q_nope=q_nope,\n        k_nope=k_nope,\n        cos_sin_cache=cos_sin_cache,\n        pos_ids=pos_ids,\n        is_neox=is_neox,\n        quantize_dtype=attn_dtype,\n        # Output tensor slicing: q_out contains [nope_part, rope_part]\n        q_rope_out=q_out[..., kv_lora_rank:],  # RoPE part goes to end\n        k_rope_out=k_rope_out,\n        q_nope_out=q_out[..., :kv_lora_rank],  # Nope part goes to beginning\n        k_nope_out=k_nope_out,\n        # Quantization scales (set to 1.0 for no additional scaling)\n        quant_scale_q=1.0,\n        quant_scale_kv=1.0,\n    )\n\n    return q_out, k_nope_out, k_rope_out\n\n\ndef concat_mla_absorb_q_general(q_nope, q_rope):\n    if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:\n        return concat_mla_absorb_q(q_nope, q_rope)\n    else:\n        return torch.cat([q_nope, q_rope], dim=-1)\n\n\n@triton.jit\ndef reshape_and_cache_flash(\n    key_ptr,\n    value_ptr,\n    key_cache_ptr,\n    value_cache_ptr,\n    slot_mapping_ptr,\n    swa_slot_mapping_ptr,\n    k_scale_ptr,\n    v_scale_ptr,\n    block_stride,\n    key_stride,\n    value_stride,\n    num_heads,\n    head_size,\n    block_size,\n    HEAD_BLOCK: tl.constexpr,\n    BLOCK_D: tl.constexpr,\n    HAS_SWA: tl.constexpr,\n    USE_SCALE: tl.constexpr,\n):\n    \"\"\"\n    Triton kernel for reshaping per-token K/V tensors into paged KV cache layout.\n\n    Source layout:\n        key/value: [num_tokens, num_heads, head_size]\n\n    Target cache layout:\n        cache: [num_blocks, block_size, num_heads, head_size]\n\n    Each Triton program instance handles:\n        - one token (program_id(0))\n        - one block of heads (program_id(1))\n\n    Features:\n        - optional SWA slot remapping\n        - optional scaling (division by k_scale / v_scale) before the cache write, as used for FP8 quantization\n\n    Args:\n        key_ptr: Pointer to 
source key tensor.\n        value_ptr: Pointer to source value tensor.\n        key_cache_ptr: Pointer to destination key cache tensor.\n        value_cache_ptr: Pointer to destination value cache tensor.\n        slot_mapping_ptr: Maps token -> cache slot.\n        swa_slot_mapping_ptr: Optional second-stage slot remap for SWA mode.\n        k_scale_ptr: Optional key scaling factor pointer.\n        v_scale_ptr: Optional value scaling factor pointer.\n        block_stride: Stride between cache blocks.\n        key_stride: Stride between source key tokens.\n        value_stride: Stride between source value tokens.\n        num_heads: Number of attention heads.\n        head_size: Hidden dimension per head.\n        block_size: Number of slots per cache block.\n        HEAD_BLOCK: Number of heads processed per program.\n        BLOCK_D: Vectorized dimension size (power-of-2 padded).\n        HAS_SWA: Enable SWA remapping.\n        USE_SCALE: Enable scale division before storing.\n    \"\"\"\n\n    # ----------------------------------\n    # program ids\n    # pid0 = token\n    # pid1 = head block\n    # ----------------------------------\n    token_idx = tl.program_id(0)\n    head_block_idx = tl.program_id(1)\n\n    # ----------------------------------\n    # slot mapping\n    # ----------------------------------\n    slot_idx = tl.load(slot_mapping_ptr + token_idx)\n\n    if HAS_SWA:\n        slot_idx = tl.load(swa_slot_mapping_ptr + slot_idx)\n\n    if slot_idx < 0:\n        return\n\n    block_idx = slot_idx // block_size\n    block_offset = slot_idx % block_size\n\n    # ----------------------------------\n    # head range\n    # ----------------------------------\n    head_idx = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK)\n\n    head_mask = head_idx < num_heads\n\n    dim_idx = tl.arange(0, BLOCK_D)\n\n    # shape = [HEAD_BLOCK, BLOCK_D]\n    offs = head_idx[:, None] * head_size + dim_idx[None, :]\n\n    mask = head_mask[:, None] & (dim_idx[None, :] 
< head_size)\n\n    # ----------------------------------\n    # source load\n    # ----------------------------------\n    src_key = token_idx * key_stride + offs\n    src_value = token_idx * value_stride + offs\n\n    k = tl.load(key_ptr + src_key, mask=mask)\n    v = tl.load(value_ptr + src_value, mask=mask)\n\n    # ----------------------------------\n    # optional scale\n    # ----------------------------------\n    if USE_SCALE:\n        k_scale = tl.load(k_scale_ptr)\n        v_scale = tl.load(v_scale_ptr)\n\n        k = k / k_scale\n        v = v / v_scale\n\n    # ----------------------------------\n    # target layout\n    # [block_idx, block_offset, head, dim]\n    # ----------------------------------\n    tgt = block_idx * block_stride + block_offset * num_heads * head_size + offs\n\n    tl.store(key_cache_ptr + tgt, k, mask=mask)\n    tl.store(value_cache_ptr + tgt, v, mask=mask)\n\n\ndef launch_reshape_and_cache_flash(\n    key,\n    value,\n    key_cache,\n    value_cache,\n    slot_mapping,\n    swa_slot_mapping=None,\n    k_scale=None,\n    v_scale=None,\n):\n    \"\"\"\n    Launch wrapper for reshape_and_cache_flash Triton kernel.\n\n    This wrapper prepares launch configuration and dispatches the Triton kernel\n    that writes token-major K/V tensors into paged KV cache layout.\n\n    Args:\n        key: Source key tensor [num_tokens, num_heads, head_size]\n        value: Source value tensor [num_tokens, num_heads, head_size]\n        key_cache: Destination key cache [num_blocks, block_size, num_heads, head_size]\n        value_cache: Destination value cache [num_blocks, block_size, num_heads, head_size]\n        slot_mapping: Token-to-cache slot mapping\n        swa_slot_mapping: Optional SWA remapping table\n        k_scale: Optional key scaling factor\n        v_scale: Optional value scaling factor\n    \"\"\"\n\n    num_tokens = key.shape[0]\n    num_heads = key.shape[1]\n    head_size = key.shape[2]\n\n    HEAD_BLOCK = 4\n\n    BLOCK_D = 
triton.next_power_of_2(head_size)\n\n    grid = (\n        num_tokens,\n        triton.cdiv(num_heads, HEAD_BLOCK),\n    )\n\n    reshape_and_cache_flash[grid](\n        key,\n        value,\n        key_cache,\n        value_cache,\n        slot_mapping,\n        swa_slot_mapping if swa_slot_mapping is not None else key,\n        k_scale if k_scale is not None else key,\n        v_scale if v_scale is not None else key,\n        key_cache.stride(0),\n        key.stride(0),\n        value.stride(0),\n        num_heads,\n        head_size,\n        key_cache.shape[1],\n        HEAD_BLOCK=HEAD_BLOCK,\n        BLOCK_D=BLOCK_D,\n        HAS_SWA=(swa_slot_mapping is not None),\n        USE_SCALE=(k_scale is not None),\n    )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/vision.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nimport functools\nimport math\nfrom functools import lru_cache, partial\nfrom typing import Any, Callable, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\n\nfrom sglang.jit_kernel.norm import can_use_fused_inplace_qknorm as can_use_jit_qk_norm\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size\nfrom sglang.srt.models.utils import apply_qk_norm\nfrom sglang.srt.utils import (\n    get_bool_env_var,\n    get_device_capability,\n    is_blackwell_supported,\n    is_cuda,\n    is_hip,\n    is_npu,\n    is_xpu,\n    print_info_once,\n)\nfrom sglang.srt.utils.multi_stream_utils import (\n    maybe_execute_in_parallel,\n    with_multi_stream,\n)\n\n_is_cuda = is_cuda()\n_is_npu = is_npu()\n_is_hip = is_hip()\n_is_xpu = is_xpu()\n\nif _is_cuda:\n    from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache\n\n    try:\n        from sgl_kernel.flash_attn import flash_attn_varlen_func\n\n        from sglang.jit_kernel.flash_attention_v4 import (\n            flash_attn_varlen_func as flash_attn_varlen_func_fa4,\n        )\n\n        def flash_attn_func(*args, ver: int = 3, **kwargs):\n            if ver == 4:\n                return flash_attn_varlen_func_fa4(*args, **kwargs)\n            return flash_attn_varlen_func(*args, **kwargs)\n\n    except ImportError as e:\n        raise e\n\n\nif _is_npu:\n    import torch_npu\n\n_use_aiter = get_bool_env_var(\"SGLANG_USE_AITER\") and _is_hip\n\nfrom sglang.srt.distributed import (\n    split_tensor_along_last_dim,\n    tensor_model_parallel_all_gather,\n)\nfrom sglang.srt.distributed import utils as dist_utils\nfrom sglang.srt.layers.attention.triton_ops.prefill_attention import (\n    context_attention_fwd,\n)\nfrom sglang.srt.layers.layernorm import RMSNorm\nfrom sglang.srt.layers.linear import (\n    
ColumnParallelLinear,\n    QKVParallelLinear,\n    RowParallelLinear,\n)\nfrom sglang.srt.layers.quantization import QuantizationConfig\nfrom sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.utils import add_prefix, get_bool_env_var\n\nROTARY_EMBED_CLASSES = {\n    \"normal\": apply_rotary_pos_emb,\n}\n\n# === Vision Encoder === #\nFLASHINFER_WORKSPACE_SIZE_BYTES = 128 * 1024 * 1024\n\n# Batch buckets for cuDNN graph caching - graphs are cached per bucket size\n# This avoids creating a new graph for each unique batch size at runtime\nBATCH_BUCKETS = [8, 16, 32, 64]\n\n# Bucketized max seqlens to reduce cuDNN recompilation frequency while\n# preserving a tighter upper bound than a single fixed max seqlen.\nFLASHINFER_MAX_SEQLEN_BUCKETS = [\n    4 * 1024,\n    8 * 1024,\n    16 * 1024,\n    32 * 1024,\n    64 * 1024,\n    128 * 1024,\n]\n\n\n@dataclasses.dataclass\nclass SingletonCache:\n    data: Any = None\n\n    def set_data(self, value: Any) -> None:\n        self.data = value\n\n    def get_data(self) -> Optional[Any]:\n        return self.data\n\n    def empty(self) -> bool:\n        return self.get_data() is None\n\n\n# TODO: requires real seqlens from images\n@functools.lru_cache(maxsize=128)\ndef _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Tensor:\n    \"\"\"\n    Generates cumulative sequence lengths (cu_seqlens) for a given batch_size, seqlen, and device.\n    Caches the result based on these parameters.\n    \"\"\"\n    cu_seqlens = torch.arange(\n        0,\n        (batch_size + 1) * seqlen,\n        step=seqlen,\n        dtype=torch.int32,\n        device=device,\n    )\n    return cu_seqlens\n\n\ndef resolve_seqlens(\n    cu_seqlens: torch.Tensor | SingletonCache | None,\n    bsz: int,\n    seq_len: int,\n    *,\n    device: torch.device,\n) -> torch.Tensor:\n    if cu_seqlens is None:\n        resolved_seqlens = 
_get_cu_seqlens_for_shape(bsz, seq_len, device=device)\n    elif isinstance(cu_seqlens, SingletonCache):\n        if cu_seqlens.empty():\n            cu_seqlens.set_data(_get_cu_seqlens_for_shape(bsz, seq_len, device=device))\n        resolved_seqlens = cu_seqlens.get_data()\n    else:\n        resolved_seqlens = cu_seqlens\n    assert isinstance(\n        resolved_seqlens, torch.Tensor\n    ), \"cu_seqlens must be a torch.Tensor\"\n    return resolved_seqlens\n\n\nclass VisionSdpaAttention(nn.Module):\n    r\"\"\"\n    Scaled Dot Product Attention inner product\n\n    \"\"\"\n\n    def __init__(\n        self,\n        head_dim: int,\n        num_heads: int,\n        num_kv_heads: int,\n        dropout: float = 0.0,\n        flatten_batch: bool = False,\n        softmax_in_single_precision: bool = False,\n        **kwargs,\n    ):\n        super().__init__()\n        self.head_size = head_dim\n        self.num_heads = num_heads\n        self.num_kv_heads = num_kv_heads\n        self.flatten_batch = flatten_batch\n        self.softmax_in_single_precision = softmax_in_single_precision\n        self.dropout = dropout\n        self.scale = 1.0 / math.sqrt(self.head_size)\n\n    @staticmethod\n    @lru_cache(maxsize=128)\n    def _generate_mask_cache(\n        s: int, flatten_batch: bool, cu_seqlens: tuple\n    ) -> torch.BoolTensor:\n        \"\"\"\n        Generate a boolean attention mask with caching mechanism.\n        Args:\n            s: sequence length\n            flatten_batch: whether to flatten batch dimension\n            cu_seqlens: tuple of cumulative sequence lengths\n        Returns:\n            attention mask tensor of shape [b, 1, s, s] or [1, s, s]\n        \"\"\"\n        if flatten_batch:\n            mask = torch.zeros([1, s, s], dtype=torch.bool)\n            for i in range(1, len(cu_seqlens)):\n                start = cu_seqlens[i - 1]\n                end = cu_seqlens[i]\n                mask[..., start:end, start:end] = True\n        
else:\n            # [1, 1, 1, s]\n            row_indices = torch.arange(s).view(1, 1, 1, s)\n            # [1, 1, s, 1]\n            col_indices = torch.arange(s).view(1, 1, s, 1)\n            # [b, 1, 1, 1]\n            seq_lens = torch.tensor(\n                [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])],\n            ).view(-1, 1, 1, 1)\n\n            mask = (row_indices < seq_lens) & (col_indices < seq_lens)\n\n        return mask\n\n    def generate_patch_attention_mask(\n        self,\n        s: int,\n        cu_seqlens: Optional[torch.Tensor],\n        flatten_batch: bool = False,\n    ) -> Optional[torch.Tensor]:\n        r\"\"\"\n        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.\n        Args:\n            s: sequence length\n            cu_seqlens: cumulative sequence lengths tensor. If None, no mask is built and None is returned\n            flatten_batch: whether to flatten batch dimension\n        Returns:\n            attention mask tensor or None\n        \"\"\"\n        if cu_seqlens is None:\n            return None\n\n        cu_seqlens_tuple = tuple(cu_seqlens.cpu().tolist())\n\n        return self._generate_mask_cache(s, flatten_batch, cu_seqlens_tuple)\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        bsz: int,\n        cu_seqlens: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Args:\n            cu_seqlens: [b]\n        Returns:\n             [b * s, h, head_size]\n        \"\"\"\n        if self.flatten_batch:\n            assert bsz == 1, \"flatten_batch is True, bsz must be 1\"\n\n        assert q.dim() == 3, q.shape\n\n        s = q.shape[0] // bsz\n\n        # [b, 1, s, s]\n        if attention_mask is None:\n            attention_mask = self.generate_patch_attention_mask(\n                s, cu_seqlens, 
flatten_batch=self.flatten_batch\n            )\n\n        if attention_mask is None:\n            if self.softmax_in_single_precision:\n                raise RuntimeError(\"Empty attention mask\")\n        else:\n            attention_mask = attention_mask.to(device=q.device)\n\n        q, k, v = [rearrange(x, \"(b s) h d -> b h s d\", b=bsz) for x in [q, k, v]]\n\n        if self.softmax_in_single_precision:\n            k = rearrange(k, \"b h s d -> b h d s\")\n            attn_weights = torch.matmul(q, k) * self.scale\n            del k\n            # masking\n            attention_mask = (~attention_mask) * torch.finfo(q.dtype).min\n            attn_weights = attn_weights + attention_mask\n            del attention_mask\n            # full-precision\n            attn_weights = nn.functional.softmax(\n                attn_weights, dim=-1, dtype=torch.float32\n            ).to(q.dtype)\n            attn_weights = nn.functional.dropout(\n                attn_weights, p=self.dropout, training=False\n            )\n            output = torch.matmul(attn_weights, v)\n            del attn_weights, v\n        else:\n            # SDPA\n            # [b, h, s, head_size]\n            output = F.scaled_dot_product_attention(\n                q,\n                k,\n                v,\n                attn_mask=attention_mask,\n                dropout_p=self.dropout,\n                is_causal=False,\n            )\n\n        # [b, h, s, head_size] --> [b * s, h, head_size]\n        output = rearrange(output, \"b h s d -> (b s) h d\")\n\n        return output\n\n\nclass VisionTritonAttention(nn.Module):\n    \"\"\"\n    Triton-implemented attention without a causal mask\n    \"\"\"\n\n    def __init__(\n        self,\n        **kwargs,\n    ):\n        super().__init__()\n        use_data_parallel = (\n            kwargs[\"use_data_parallel\"] if \"use_data_parallel\" in kwargs else False\n        )\n        self.tp_size = 1 if use_data_parallel else 
get_attention_tp_size()\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        cu_seqlens: torch.Tensor | SingletonCache | None,\n        bsz: int,\n        seq_len: int,\n        **kwargs,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Args:\n            cu_seqlens: [b]\n        Returns:\n             [b * s, h, head_size]\n        \"\"\"\n        if envs.SGLANG_VIT_ENABLE_CUDA_GRAPH.get():\n            if \"output_ws\" not in kwargs:\n                raise RuntimeError(\"output_ws should be prepared for cuda-graph mode\")\n\n            if not isinstance(cu_seqlens, list):\n                raise RuntimeError(\"cuda-graph mode cu_seqlens should be a list\")\n\n            output = kwargs[\"output_ws\"]\n            context_attention_fwd(\n                q,\n                k,\n                v,\n                output,\n                cu_seqlens[0],\n                cu_seqlens[1],\n                cu_seqlens[2],\n                is_causal=False,\n            )\n        else:\n            cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device)\n\n            # [b * s, head, head_size]\n            output = torch.empty_like(q)\n\n            seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]\n            max_seqlen = seq_lens.max().item()\n            context_attention_fwd(\n                q,\n                k,\n                v,\n                output,\n                cu_seqlens.to(q.device),\n                seq_lens.to(q.device),\n                max_seqlen,\n                is_causal=False,\n            )\n\n        return output\n\n\nclass VisionFlash3Attention(nn.Module):\n    def __init__(\n        self,\n        **kwargs,\n    ):\n        if not _is_cuda:\n            raise Exception(\"VisionFlash3Attention is only available for cuda\")\n        super().__init__()\n        use_data_parallel = (\n            kwargs[\"use_data_parallel\"] if \"use_data_parallel\" in kwargs 
else False\n        )\n        self.tp_size = 1 if use_data_parallel else get_attention_tp_size()\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        cu_seqlens: torch.Tensor | SingletonCache | None,\n        bsz: int,\n        seq_len: int,\n        **kwargs,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Args:\n            cu_seqlens: [b]\n        Returns:\n             [b * s, h, head_size]\n        \"\"\"\n        if envs.SGLANG_VIT_ENABLE_CUDA_GRAPH.get():\n            max_seqlen = cu_seqlens[1]\n            output = flash_attn_func(\n                q,\n                k,\n                v,\n                cu_seqlens_q=cu_seqlens[0],\n                cu_seqlens_k=cu_seqlens[0],\n                max_seqlen_q=max_seqlen,\n                max_seqlen_k=max_seqlen,\n            )\n        else:\n            cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device)\n            cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)\n            seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]\n            max_seqlen = seq_lens.max().item()\n\n            output = flash_attn_func(\n                q,\n                k,\n                v,\n                cu_seqlens_q=cu_seqlens,\n                cu_seqlens_k=cu_seqlens,\n                max_seqlen_q=max_seqlen,\n                max_seqlen_k=max_seqlen,\n            )\n\n        return output\n\n\nclass VisionFlash4Attention(nn.Module):\n    def __init__(\n        self,\n        **kwargs,\n    ):\n        if not _is_cuda:\n            raise Exception(\"VisionFlash4Attention is only available for cuda\")\n        super().__init__()\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        cu_seqlens: torch.Tensor | SingletonCache | None,\n        bsz: int,\n        seq_len: int,\n        **kwargs,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Args:\n            
cu_seqlens: [b]\n        Returns:\n             [b * s, h, head_size]\n        \"\"\"\n        if cu_seqlens is None:\n            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)\n        elif isinstance(cu_seqlens, SingletonCache):\n            if cu_seqlens.empty():\n                cu_seqlens.set_data(\n                    _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)\n                )\n            cu_seqlens = cu_seqlens.get_data()\n\n        cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)\n        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]\n        max_seqlen = seq_lens.max().item()\n\n        output = flash_attn_func(\n            q,\n            k,\n            v,\n            cu_seqlens_q=cu_seqlens,\n            cu_seqlens_k=cu_seqlens,\n            max_seqlen_q=max_seqlen,\n            max_seqlen_k=max_seqlen,\n            ver=4,\n        )\n\n        return output\n\n\nclass VisionFlashInferAttention(nn.Module):\n    def __init__(\n        self,\n        **kwargs,\n    ):\n        if not _is_cuda:\n            raise Exception(\"VisionFlashInferAttention is only available for cuda\")\n        super().__init__()\n        self.workspace_buffer = (\n            kwargs[\"workspace_buffer\"] if \"workspace_buffer\" in kwargs else None\n        )\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        cu_seqlens: torch.Tensor | SingletonCache | None,\n        bsz: int,\n        seq_len: int,\n        **kwargs,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Args:\n            cu_seqlens: [b]\n        Returns:\n             [b * s, h, head_size]\n        \"\"\"\n        if \"sequence_lengths\" not in kwargs:\n            raise RuntimeError(\n                \"sequence_lengths should be prepared for vision flashinfer_cudnn attention backend\"\n            )\n        if \"max_seqlen\" not in kwargs:\n            raise RuntimeError(\n                
\"max_seqlen should be prepared for vision flashinfer_cudnn attention backend\"\n            )\n\n        sequence_lengths = kwargs[\"sequence_lengths\"]  # (B_padded,) or (B_padded,1,1,1)\n        max_seqlen = kwargs[\"max_seqlen\"]\n\n        # max_seqlen must be python int\n        if isinstance(max_seqlen, torch.Tensor):\n            if max_seqlen.is_cuda:\n                max_seqlen = int(max_seqlen.detach().cpu().item())\n            else:\n                max_seqlen = int(max_seqlen.item())\n        else:\n            max_seqlen = int(max_seqlen)\n\n        # flatten if caller gives (b, s, h, d)\n        is_reshaped = q.dim() == 4\n        if is_reshaped:\n            reshape_batch_size = q.shape[0]\n            q, k, v = (rearrange(x, \"b s ... -> (b s) ...\") for x in [q, k, v])\n\n        if not isinstance(cu_seqlens, torch.Tensor):\n            raise RuntimeError(\n                \"flashinfer_cudnn expects packed indptrs as a torch.Tensor\"\n            )\n\n        # sequence_lengths -> (B,)\n        if not isinstance(sequence_lengths, torch.Tensor):\n            raise RuntimeError(\"sequence_lengths must be a torch.Tensor\")\n        seq_lens_1d = sequence_lengths.view(-1).to(device=q.device, dtype=torch.int32)\n        B = int(seq_lens_1d.numel())\n\n        # cu_seqlens contains packed *element indptrs*:\n        # [qk_indptr(B+1), v_indptr(B+1), o_indptr(B+1)] => total 3*(B+1)\n        cu_seqlens_1d = cu_seqlens.view(-1).to(device=q.device, dtype=torch.int32)\n        expected = 3 * (B + 1)\n        if int(cu_seqlens_1d.numel()) != expected:\n            raise RuntimeError(\n                f\"packed indptr numel mismatch: got {cu_seqlens_1d.numel()}, expected {expected} (= 3*(B+1))\"\n            )\n\n        split = B + 1\n        indptr_qk = cu_seqlens_1d[:split].view(split, 1, 1, 1)\n        indptr_v = cu_seqlens_1d[split : 2 * split].view(split, 1, 1, 1)\n        indptr_o = cu_seqlens_1d[2 * split :].view(split, 1, 1, 1)\n\n        # cuDNN 
style: (B,1,1,1)\n        seq_lens_4d = seq_lens_1d.view(B, 1, 1, 1)\n\n        # indptr are in ELEMENT offsets (not token offsets)\n        token_width_q = int(q.shape[1] * q.shape[2])  # heads * head_dim on this rank\n        total_elems_q = int(q.numel())\n\n        # check each real sequence fits\n        # (skip padded tail where seq_len==0)\n        start_elems = indptr_qk.view(-1)[:-1]  # (B,)\n        end_elems = start_elems + seq_lens_1d * token_width_q\n        if (end_elems > total_elems_q).any():\n            raise RuntimeError(\"offset + len out of bounds; packed indptr is wrong\")\n\n        _, _, head_size = q.shape\n        scale = head_size**-0.5\n\n        output, _ = cudnn_batch_prefill_with_kv_cache(\n            q,\n            k,\n            v,\n            scale,\n            self.workspace_buffer,\n            max_token_per_sequence=max_seqlen,\n            max_sequence_kv=max_seqlen,\n            actual_seq_lens_q=seq_lens_4d,\n            actual_seq_lens_kv=seq_lens_4d,\n            causal=False,\n            return_lse=True,\n            batch_offsets_q=indptr_qk,\n            batch_offsets_k=indptr_qk,\n            batch_offsets_v=indptr_v,\n            batch_offsets_o=indptr_o,\n            is_cuda_graph_compatible=True,\n        )\n\n        if is_reshaped:\n            output = rearrange(output, \"(b s) h d -> b s h d\", b=reshape_batch_size)\n\n        return output\n\n\nclass VisionAiterAttention(nn.Module):\n    def __init__(\n        self,\n        **kwargs,\n    ):\n        if not _is_hip:\n            raise Exception(\"aiter_attn is only available for AMD\")\n        try:\n            from aiter import flash_attn_varlen_func as aiter_flash_attn_varlen_func\n        except ImportError as e:\n            raise ImportError(\n                \"aiter is AMD specific kernel library. 
Please make sure aiter is installed on your AMD device.\"\n            ) from e\n\n        self.flash_attn_varlen_func = aiter_flash_attn_varlen_func\n        super().__init__()\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        cu_seqlens: torch.Tensor | SingletonCache | None,\n        bsz: int,\n        seq_len: int,\n        **kwargs,\n    ) -> torch.Tensor:\n        cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device)\n\n        cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)\n        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]\n        max_seqlen = seq_lens.max().item()\n\n        return self.flash_attn_varlen_func(\n            q=q,\n            k=k,\n            v=v,\n            cu_seqlens_q=cu_seqlens,\n            cu_seqlens_k=cu_seqlens,\n            max_seqlen_q=max_seqlen,\n            max_seqlen_k=max_seqlen,\n        )\n\n\nclass VisionAscendAttention(nn.Module):\n\n    def __init__(\n        self,\n        **kwargs,\n    ):\n        if not _is_npu:\n            raise Exception(\"VisionAscendAttention is only available for ascend npu\")\n        super().__init__()\n\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        cu_seqlens: torch.Tensor | SingletonCache | None,\n        bsz: int,\n        seq_len: int,\n        **kwargs,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Args:\n            cu_seqlens: [b]\n        Returns:\n             [b * s, h, head_size]\n        \"\"\"\n        if envs.SGLANG_VIT_ENABLE_CUDA_GRAPH.get():\n            if \"output_ws\" not in kwargs:\n                raise RuntimeError(\"output_ws should be prepared for npu-graph mode\")\n            output = kwargs[\"output_ws\"]\n            # graph mode: runner already passes seq_lens (int32 on CPU)\n            seq_len_arg = cu_seqlens\n        else:\n            cu_seqlens = resolve_seqlens(cu_seqlens, 
bsz, seq_len, device=\"cpu\")\n            seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]\n            if seq_lens.is_npu:\n                seq_lens = seq_lens.to(\"cpu\")\n            output = torch.empty_like(q)\n            seq_len_arg = seq_lens.to(torch.int32)\n\n        _, num_heads, head_size = q.shape\n        num_kv_heads = k.shape[1]\n\n        torch_npu._npu_flash_attention_unpad(\n            query=q,\n            key=k,\n            value=v,\n            seq_len=seq_len_arg,\n            scale_value=head_size**-0.5,\n            num_heads=num_heads,\n            num_kv_heads=num_kv_heads,\n            out=output,\n        )\n        return output\n\n\nQKV_BACKEND_IMPL = {\n    \"triton_attn\": VisionTritonAttention,\n    \"sdpa\": VisionSdpaAttention,\n    \"fa3\": VisionFlash3Attention,\n    \"fa4\": VisionFlash4Attention,\n    \"flashinfer_cudnn\": VisionFlashInferAttention,\n    \"ascend_attn\": VisionAscendAttention,\n    \"aiter_attn\": VisionAiterAttention,\n}\n\n\nclass VisionAttention(nn.Module):\n    r\"\"\"\n        Multi-headed attention without any cache, mostly used for multimodal transformers.\n\n\n    Args:\n        use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.\n        softmax_in_single_precision (bool, default to False):\n            if ``True``, the softmax will be performed in single-precision\n            Otherwise, it will be performed in half-precision\n\n    \"\"\"\n\n    def __init__(\n        self,\n        embed_dim: int,\n        num_heads: int,\n        projection_size: int,\n        use_qkv_parallel: bool,\n        qkv_backend: Optional[str] = None,\n        quant_config: Optional[QuantizationConfig] = None,\n        dropout: float = 0.0,\n        softmax_in_single_precision: bool = False,\n        flatten_batch: bool = False,\n        prefix: str = \"\",\n        proj_bias: bool = True,\n        num_dummy_heads: int = 0,\n        qkv_bias: bool = True,\n        qk_normalization: bool = False,\n       
 qk_normalization_by_head_size: bool = False,\n        layer_norm_eps: float = 1e-06,\n        customized_position_embedding_applier: Callable[\n            [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]\n        ] = None,\n        use_data_parallel: bool = False,\n        use_dp_attention_reduce: bool = False,\n        aux_stream: Optional[torch.cuda.Stream] = None,\n        workspace_buffer: Optional[torch.Tensor] = None,\n        **kwargs,\n    ):\n        super().__init__()\n        self.tp_size = 1 if use_data_parallel else get_attention_tp_size()\n        self.tp_rank = 0 if use_data_parallel else get_attention_tp_rank()\n        self.dropout = dropout\n        self.head_size = embed_dim // num_heads\n        self.hidden_size_per_attention_head = dist_utils.divide(\n            projection_size, num_heads\n        )\n        self.num_attention_heads_per_partition = dist_utils.divide(\n            num_dummy_heads + num_heads, self.tp_size\n        )\n        self.num_attention_kv_heads_per_partition = dist_utils.divide(\n            num_dummy_heads + num_heads, self.tp_size\n        )\n\n        self.q_size = self.num_attention_heads_per_partition * self.head_size\n        self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size\n\n        self.qk_normalization = qk_normalization\n        self.qk_normalization_by_head_size = qk_normalization_by_head_size\n\n        # Additional dummy heads are used to enable TP for common GPU counts.\n        self.dummy_dim = (num_dummy_heads + num_heads) * self.head_size\n\n        if self.qk_normalization:\n            self.q_norm, self.k_norm = self._init_qk_norm(\n                self.dummy_dim, layer_norm_eps, embed_dim\n            )\n\n        elif self.qk_normalization_by_head_size:\n            self.q_norm, self.k_norm = self._init_qk_norm(\n                self.head_size, layer_norm_eps\n            )\n\n        # Select attention backend via a unified method\n        
_passed_backend = qkv_backend\n        qkv_backend = self._determine_attention_backend(_passed_backend)\n        if (\n            get_global_server_args().mm_attention_backend is None\n            and _passed_backend is None\n        ):\n            print_info_once(f\"Multimodal attention backend not set. Use {qkv_backend}.\")\n        print_info_once(f\"Using {qkv_backend} as multimodal attention backend.\")\n\n        self.customized_position_embedding_applier = (\n            customized_position_embedding_applier\n        )\n        self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](\n            head_dim=self.head_size,\n            num_heads=self.num_attention_heads_per_partition,\n            num_kv_heads=self.num_attention_kv_heads_per_partition,\n            dropout=dropout,\n            flatten_batch=flatten_batch,\n            softmax_in_single_precision=softmax_in_single_precision,\n            use_data_parallel=use_data_parallel,\n            workspace_buffer=workspace_buffer,\n        )\n\n        self.use_qkv_parallel = use_qkv_parallel\n        if use_qkv_parallel:\n            self.qkv_proj = QKVParallelLinear(\n                hidden_size=embed_dim,\n                head_size=self.head_size,\n                total_num_heads=num_dummy_heads + num_heads,\n                total_num_kv_heads=num_dummy_heads + num_heads,\n                bias=qkv_bias,\n                quant_config=quant_config,\n                tp_rank=self.tp_rank,\n                tp_size=self.tp_size,\n                prefix=add_prefix(\"qkv_proj\", prefix),\n            )\n        else:\n            self.qkv_proj = ColumnParallelLinear(\n                input_size=embed_dim,\n                output_size=3 * self.dummy_dim,\n                bias=qkv_bias,\n                quant_config=quant_config,\n                tp_rank=self.tp_rank,\n                tp_size=self.tp_size,\n                prefix=add_prefix(\"qkv_proj\", prefix),\n            )\n        self.proj = 
RowParallelLinear(\n            input_size=self.dummy_dim,\n            output_size=embed_dim,\n            bias=proj_bias,\n            quant_config=quant_config,\n            tp_rank=self.tp_rank,\n            tp_size=self.tp_size,\n            prefix=add_prefix(\"proj\", prefix),\n            use_dp_attention_reduce=use_dp_attention_reduce,\n        )\n\n        self.workspace_buffer = workspace_buffer\n        self.aux_stream = aux_stream\n        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] if aux_stream else []\n\n    def _init_qk_norm(\n        self, norm_dim: int, eps: float, var_hidden_size: Optional[int] = None\n    ):\n        norm_kwargs = (\n            dict(\n                weight_dtype=torch.float32,\n                cast_x_before_out_mul=True,\n            )\n            if get_global_server_args().rl_on_policy_target is not None\n            else {}\n        )\n        q_norm = RMSNorm(\n            norm_dim,\n            eps=eps,\n            var_hidden_size=var_hidden_size,\n            **norm_kwargs,\n        )\n        k_norm = RMSNorm(\n            norm_dim,\n            eps=eps,\n            var_hidden_size=var_hidden_size,\n            **norm_kwargs,\n        )\n        return q_norm, k_norm\n\n    def _determine_attention_backend(self, passed_backend: Optional[str]) -> str:\n        \"\"\"Decide the multimodal attention backend string.\n\n        Priority: server args override > constructor arg > platform default.\n\n        Platform defaults:\n        - CUDA: \"fa3\" on SM90 (compute capability 9.x), otherwise \"triton_attn\"\n        - HIP: \"aiter_attn\" on capability >= (9, 4) with aiter enabled,\n          otherwise \"triton_attn\"\n        - XPU: \"triton_attn\"\n        - Other platforms: \"sdpa\"\n\n        Raises ValueError if \"fa3\" is selected on a Blackwell GPU.\n        \"\"\"\n        override_backend = get_global_server_args().mm_attention_backend\n        if override_backend is not None:\n            backend = override_backend\n        elif passed_backend is not None:\n            backend = passed_backend\n        elif is_cuda():\n            major, minor = get_device_capability()\n            if major == 9:\n                backend = \"fa3\"\n            else:\n                
backend = \"triton_attn\"\n        elif _is_hip:\n            if get_device_capability() >= (9, 4) and _use_aiter:\n                backend = \"aiter_attn\"\n            else:\n                backend = \"triton_attn\"\n        elif _is_xpu:\n            backend = \"triton_attn\"\n        else:\n            backend = \"sdpa\"\n        if backend == \"fa3\" and is_blackwell_supported():\n            raise ValueError(\"The 'fa3' backend is not supported on Blackwell GPUs\")\n\n        return backend\n\n    def _apply_qk_norm_head_size(self, q: torch.Tensor, k: torch.Tensor):\n        \"\"\"apply qk norm for GLM-OCR vit attn\"\"\"\n        q_by_head = q.reshape(-1, self.head_size)\n        q_by_head = self.q_norm(q_by_head)\n        k_by_head = k.reshape(-1, self.head_size)\n        k_by_head = self.k_norm(k_by_head)\n        q = q_by_head.view(q.shape)\n        k = k_by_head.view(k.shape)\n        return q, k\n\n    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):\n        \"\"\"apply qk norm for internvl vit attn\"\"\"\n\n        def q_l2norm():\n            q_ = q.flatten(1, 2)\n            if self.tp_size > 1:\n                q_ = tensor_model_parallel_all_gather(q_.contiguous())\n            q_ = self.q_norm(q_)\n            if self.tp_size > 1:\n                splitter = partial(\n                    split_tensor_along_last_dim, num_partitions=self.tp_size\n                )\n                q_ = splitter(q_)[self.tp_rank]\n            q_ = q_.unflatten(-1, (-1, self.head_size))\n            return q_\n\n        def k_l2norm():\n            k_ = k.flatten(1, 2)\n            if self.tp_size > 1:\n                k_ = tensor_model_parallel_all_gather(k_.contiguous())\n            k_ = self.k_norm(k_)\n            if self.tp_size > 1:\n                splitter = partial(\n                    split_tensor_along_last_dim, num_partitions=self.tp_size\n                )\n                k_ = splitter(k_)[self.tp_rank]\n            k_ = k_.unflatten(-1, 
(-1, self.head_size))\n            return k_\n\n        with with_multi_stream(True):\n            q, k = maybe_execute_in_parallel(\n                q_l2norm,\n                k_l2norm,\n                self.ln_events,\n                self.aux_stream,\n            )\n        return q, k\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        cu_seqlens: Optional[torch.Tensor] = None,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        rotary_pos_emb_cos: Optional[torch.Tensor] = None,\n        rotary_pos_emb_sin: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Args:\n            x: [b, s, embed_dim]\n            cu_seqlens: [b]\n        Returns:\n             [b, s, embed_dim]\n        \"\"\"\n        if x.dim() == 2:\n            x = x.unsqueeze(0)\n        assert x.dim() == 3, x.shape\n        if (\n            get_global_server_args().rl_on_policy_target is not None\n            and position_embeddings is not None\n        ):\n            assert isinstance(position_embeddings, tuple), (\n                \"expected position_embeddings to be a tuple of two tensors,\\n\"\n                f\"but got {type(position_embeddings)}, change if needed\"\n            )\n            position_embeddings = tuple(p.to(x.dtype) for p in position_embeddings)\n        x_shape = x.shape\n        bsz, s, _ = x_shape\n        head = self.num_attention_heads_per_partition\n        kv_head = self.num_attention_kv_heads_per_partition\n\n        attn_output_ws = kwargs[\"output_ws\"] if \"output_ws\" in kwargs else None\n        max_seqlen = kwargs[\"max_seqlen\"] if \"max_seqlen\" in kwargs else None\n        sequence_lengths = (\n            kwargs[\"sequence_lengths\"] if \"sequence_lengths\" in kwargs else None\n        )\n        if self.use_qkv_parallel:\n            # [b, s, embed_dim] --> [b, s, embed_dim]\n      
      qkv, _ = self.qkv_proj(x)\n            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)\n\n            # [b, s, embed_dim] --> [b * s, head, head_size]\n            q = q.reshape(bsz * s, head, -1).contiguous()\n            k = k.reshape(bsz * s, kv_head, -1).contiguous()\n            v = v.reshape(bsz * s, kv_head, -1).contiguous()\n            if self.qk_normalization_by_head_size:\n                q, k = self._apply_qk_norm_head_size(q, k)\n        else:\n            # [b, s, embed_dim] --> [s, b, embed_dim]\n            x = rearrange(x, \"b s ... -> s b ...\")\n            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]\n            qkv, _ = self.qkv_proj(x)\n\n            # [s, b, head, head_dim_sum]\n            new_x_shape = qkv.size()[:-1] + (\n                head,\n                self.q_size + 2 * self.kv_size,\n            )\n            qkv = qkv.view(*new_x_shape)\n\n            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]\n            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)\n\n            # [s, b, head, head_size] --> [b, s, head, head_size]\n            q, k, v = [\n                rearrange(x, \"s b ... 
-> b s ...\").contiguous() for x in (q, k, v)\n            ]\n\n            if self.qk_normalization_by_head_size:\n                q, k = self._apply_qk_norm_head_size(q, k)\n\n        cos = None\n        sin = None\n\n        if position_embeddings is not None:\n            if self.customized_position_embedding_applier is not None:\n                q, k = self.customized_position_embedding_applier(\n                    q, k, position_embeddings, x_shape\n                )\n            else:\n                cos, sin = position_embeddings\n        elif rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:\n            cos = rotary_pos_emb_cos\n            sin = rotary_pos_emb_sin\n\n        if cos is not None and sin is not None:\n            original_shape = q.shape\n\n            # [total_tokens, head, head_size]\n            q = q.view(-1, head, self.head_size)\n            k = k.view(-1, head, self.head_size)\n\n            if cos.size(-1) * 2 == self.head_size:\n                cos = torch.cat([cos, cos], dim=-1)\n                sin = torch.cat([sin, sin], dim=-1)\n\n            q, k = apply_rotary_pos_emb(q, k, cos, sin)\n            q = q.view(original_shape)\n            k = k.view(original_shape)\n\n        if q.dim() == 4:\n            # [b, s, head, head_size] --> [b * s, head, head_size]\n            q = rearrange(q, \"b s ... -> (b s) ...\")\n        if k.dim() == 4:\n            # [b, s, head, head_size] --> [b * s, head, head_size]\n            k = rearrange(k, \"b s ... -> (b s) ...\")\n        if v.dim() == 4:\n            # [b, s, head, head_size] --> [b * s, head, head_size]\n            v = rearrange(v, \"b s ... 
-> (b s) ...\")\n\n        assert q.dim() == 3, q.dim()\n        assert k.dim() == 3, k.dim()\n        assert v.dim() == 3, v.dim()\n\n        # internvl\n        if self.qk_normalization and not self.qk_normalization_by_head_size:\n            # jit kernel\n            if can_use_jit_qk_norm(self.head_size, q.dtype):\n\n                # q: [tokens, head, head_size]  ->  [tokens, embed_dim]\n                head_dim_for_norm = head * self.head_size\n\n                q, k = apply_qk_norm(\n                    q=q,\n                    k=k,\n                    q_norm=self.q_norm,\n                    k_norm=self.k_norm,\n                    head_dim=head_dim_for_norm,\n                    alt_stream=self.aux_stream,\n                )\n\n            else:\n                q, k = self._apply_qk_norm(q, k)\n\n        output = self.qkv_backend.forward(\n            q=q,\n            k=k,\n            v=v,\n            bsz=bsz,\n            seq_len=s,\n            cu_seqlens=cu_seqlens,\n            attention_mask=attention_mask,\n            sequence_lengths=sequence_lengths,\n            max_seqlen=max_seqlen,\n            output_ws=attn_output_ws,\n        )\n\n        assert output.dim() == 3, output.shape\n\n        if self.use_qkv_parallel:\n            # [b * s, h, head_size] --> [b, s, h * head_size]\n            output = rearrange(output, \"(b s) ... h d -> b s ... (h d)\", b=bsz)\n\n            # [b, s, h * head_size] --> [b, s, h * head_size]\n            output, _ = self.proj(output)\n        else:\n            # [b * s, h, head_size] --> [s, b, h * head_size]\n            context_layer = rearrange(\n                output, \"(b s) h d -> s b (h d)\", b=bsz, s=s\n            ).contiguous()\n\n            # [s, b, h * head_size] --> [s, b, h * head_size]\n            output, _ = self.proj(context_layer)\n\n            # [s, b, h * head_size] --> [b, s, h * head_size]\n            output = output.view(bsz, s, -1)\n\n        return output\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/vision_utils.py",
    "content": "\"\"\"Utility functions for vision attention layers.\"\"\"\n\nimport torch\n\nfrom sglang.srt.layers.dp_attention import get_attention_tp_size\n\n\ndef update_vit_attn_dummy_heads_config(config):\n    \"\"\"Update HF config to ensure vision attention num_attention_heads is divisible by tp_size\"\"\"\n    tp_size = get_attention_tp_size()\n    num_heads = getattr(\n        config.vision_config,\n        \"num_heads\",\n        getattr(config.vision_config, \"num_attention_heads\", None),\n    )\n    head_dim = config.vision_config.hidden_size // num_heads\n    num_dummy_heads = 0\n\n    if num_heads % tp_size != 0:\n        num_dummy_heads = ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads\n\n    setattr(config.vision_config, \"head_dim\", head_dim)\n    setattr(config.vision_config, \"num_dummy_heads\", num_dummy_heads)\n\n\ndef pad_vit_attn_dummy_heads(config, name: str, loaded_weight: torch.Tensor):\n    \"\"\"Pad attention qkv weights for dummy heads\"\"\"\n    num_dummy_heads = config.vision_config.num_dummy_heads\n    if num_dummy_heads == 0:\n        return loaded_weight\n    head_dim = config.vision_config.head_dim\n\n    if \"attn.qkv_proj\" in name:\n        wq, wk, wv = loaded_weight.chunk(3, dim=0)\n        if name.endswith(\".weight\"):\n            dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]\n        elif name.endswith(\".bias\"):\n            dummy_shape = [num_dummy_heads, head_dim]\n        else:\n            raise RuntimeError(f\"Unsupported weight with name={name}\")\n        pad_func = lambda x: torch.cat(\n            [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0\n        ).flatten(0, 1)\n        wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)\n        loaded_weight = torch.cat([wq, wk, wv], dim=0)\n    elif any([_ in name for _ in [\"attn.q_proj\", \"attn.k_proj\", \"attn.v_proj\"]]):\n        if name.endswith(\".weight\"):\n            dummy_shape = [num_dummy_heads, head_dim, 
loaded_weight.shape[-1]]\n        elif name.endswith(\".bias\"):\n            dummy_shape = [num_dummy_heads, head_dim]\n        else:\n            raise RuntimeError(f\"Unsupported weight with name={name}\")\n        padded_weight = loaded_weight.new_zeros(dummy_shape)\n        loaded_weight = torch.cat(\n            [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0\n        ).flatten(0, 1)\n    elif \"attn.proj.weight\" in name:\n        padded_weight = loaded_weight.new_zeros(\n            loaded_weight.shape[0], head_dim * num_dummy_heads\n        )\n        loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)\n    elif \"attn.q_norm.weight\" in name or \"attn.k_norm.weight\" in name:\n        padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)\n        loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)\n    return loaded_weight\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/wave_backend.py",
    "content": "from __future__ import annotations\n\nimport logging\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton\nfrom sglang.srt.layers.dp_attention import get_attention_tp_size\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\nfrom sglang.srt.utils import get_bool_env_var, get_device_core_count\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n    from sglang.srt.speculative.spec_info import SpecInput\n\nlogger = logging.getLogger(__name__)\n\n\n@triton.jit\ndef get_num_kv_splits_triton(\n    num_kv_splits_ptr,\n    seq_lens_ptr,\n    num_seq,\n    num_group,\n    num_head,\n    num_kv_head,\n    max_kv_splits,\n    device_core_count,\n    MAX_NUM_SEQ: tl.constexpr,\n):\n    # TODO: this method is tunable, we need more online serving data to tune it\n    offs_seq = tl.arange(0, MAX_NUM_SEQ)\n    mask_seq = offs_seq < num_seq\n\n    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)\n    max_seq_len = tl.max(seq_lens)\n    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)\n    min_seq_len = tl.min(seq_lens)\n    if max_seq_len * 8 < min_seq_len * 10:\n        min_seq_len = max_seq_len\n    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)\n    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)\n\n    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually\n    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0\n    ext_device_core_count = tl.cast(\n        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32\n    )\n    block_h, num_kv_group = 16, num_head // 
num_kv_head\n    if num_kv_group == 1:\n        token_grid = num_seq * num_group * num_head\n    else:\n        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd\n        block_h = tl.minimum(block_h, num_kv_group)\n        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)\n    max_kv_splits_2 = tl.minimum(\n        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits\n    )\n    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)\n\n    num_kv_splits = tl.maximum(\n        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)\n    )\n\n    offs_token = offs_seq * num_group\n    mask_token = offs_token < num_seq * num_group\n    for i in range(0, num_group):\n        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)\n\n\n@dataclass\nclass ForwardMetadata:\n    attn_logits: torch.Tensor\n    attn_lse: torch.Tensor\n    max_extend_len: int\n    num_kv_splits: torch.Tensor\n    kv_indptr: torch.Tensor\n    kv_indices: torch.Tensor\n    qo_indptr: torch.Tensor\n    custom_mask: torch.Tensor\n    mask_indptr: torch.Tensor\n\n\nclass WaveAttnBackend(AttentionBackend):\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        skip_prefill: bool = False,\n        kv_indptr_buf: Optional[torch.Tensor] = None,\n    ):\n        # Lazy import to avoid the initialization of cuda context\n        from sglang.srt.layers.attention.wave_ops.decode_attention import (\n            decode_attention_fwd,\n        )\n        from sglang.srt.layers.attention.wave_ops.extend_attention import (\n            extend_attention_wave,\n        )\n\n        super().__init__()\n\n        # Set unique cache dir for each process to avoid cache write races\n        import wave_lang.kernel.wave.cache as cache\n\n        base_cache_dir = cache.CACHE_BASE_DIR\n        new_dir = base_cache_dir / f\"worker_{model_runner.tp_rank}\"\n        logger.info(f\"Setting Wave cache dir: {new_dir}\")\n        
cache.CACHE_BASE_DIR = new_dir\n\n        self.decode_attention_fwd = decode_attention_fwd\n        self.extend_attention_fwd = extend_attention_wave\n\n        self.skip_prefill = skip_prefill\n\n        max_bs = model_runner.req_to_token_pool.size\n\n        if kv_indptr_buf is None:\n            self.kv_indptr = torch.zeros(\n                (max_bs + 1,), dtype=torch.int32, device=model_runner.device\n            )\n        else:\n            self.kv_indptr = kv_indptr_buf\n\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n\n        if not self.skip_prefill:\n            self.qo_indptr = torch.zeros(\n                (max_bs + 1,), dtype=torch.int32, device=model_runner.device\n            )\n\n            self.mask_indptr = torch.zeros(\n                (max_bs + 1,), dtype=torch.int64, device=model_runner.device\n            )\n\n        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens\n\n        self.num_head = (\n            model_runner.model_config.num_attention_heads // get_attention_tp_size()\n        )\n        self.num_kv_head = model_runner.model_config.get_num_kv_heads(\n            get_attention_tp_size()\n        )\n\n        self.static_kv_splits = get_bool_env_var(\n            \"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS\", \"false\"\n        )\n        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits\n        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]\n\n        self.forward_metadata: ForwardMetadata = None\n\n        self.max_context_len = model_runner.model_config.context_len\n\n        self.device = model_runner.device\n        self.device_core_count = get_device_core_count(model_runner.gpu_id)\n\n    def get_num_kv_splits(\n        self,\n        num_kv_splits: torch.Tensor,\n        seq_lens: torch.Tensor,\n    ):\n        num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]\n        num_group = num_token // 
num_seq\n\n        assert (\n            num_group * num_seq == num_token\n        ), f\"num_seq({num_seq}), num_token({num_token}), something goes wrong!\"\n\n        if self.static_kv_splits or self.device_core_count <= 0:\n            num_kv_splits.fill_(self.max_kv_splits)\n            return\n\n        if num_seq < 256:\n            SCHEDULE_SEQ = 256\n        else:\n            SCHEDULE_SEQ = triton.next_power_of_2(num_seq)\n\n        get_num_kv_splits_triton[(1,)](\n            num_kv_splits,\n            seq_lens,\n            num_seq,\n            num_group,\n            self.num_head,\n            self.num_kv_head,\n            self.max_kv_splits,\n            self.device_core_count,\n            MAX_NUM_SEQ=SCHEDULE_SEQ,\n        )\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Init auxiliary variables for wave attention backend.\"\"\"\n\n        bs = forward_batch.batch_size\n        kv_indptr = self.kv_indptr\n        spec_info = forward_batch.spec_info\n\n        if forward_batch.forward_mode.is_decode_or_idle():\n            if spec_info is None:\n                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)\n                kv_indptr = kv_indptr[: bs + 1]\n                kv_indices = torch.empty(\n                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device\n                )\n                create_flashinfer_kv_indices_triton[(bs,)](\n                    self.req_to_token,\n                    forward_batch.req_pool_indices,\n                    forward_batch.seq_lens,\n                    kv_indptr,\n                    None,\n                    kv_indices,\n                    self.req_to_token.stride(0),\n                )\n            else:\n                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices\n                bs = kv_indptr.shape[0] - 1\n\n            from sglang.srt.layers.attention.wave_ops.decode_attention import (\n           
     decode_attention_intermediate_arrays_shapes,\n            )\n\n            attn_logits_shape, attn_logits_max_shape = (\n                decode_attention_intermediate_arrays_shapes(\n                    bs, self.v_head_dim, self.num_head, self.max_kv_splits\n                )\n            )\n            attn_logits = torch.empty(\n                attn_logits_shape,\n                dtype=torch.float32,\n                device=self.device,\n            )\n            attn_lse = torch.empty(\n                attn_logits_max_shape,\n                dtype=torch.float32,\n                device=self.device,\n            )\n            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)\n\n            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)\n\n            qo_indptr = None\n            custom_mask = None\n            mask_indptr = None\n            max_extend_len = None\n        elif forward_batch.forward_mode.is_target_verify():\n            bs = len(forward_batch.req_pool_indices)\n            qo_indptr = torch.arange(\n                0,\n                (1 + bs) * self.num_draft_tokens,\n                step=self.num_draft_tokens,\n                dtype=torch.int32,\n                device=self.device,\n            )\n            # Different with flashinfer kv_indptr and kv_indices construction\n            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)\n            kv_indptr = kv_indptr[: bs + 1]\n            kv_indices = torch.empty(\n                kv_indptr[-1], dtype=torch.int32, device=self.device\n            )\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                forward_batch.req_pool_indices,\n                forward_batch.seq_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n\n            custom_mask = spec_info.custom_mask\n            
seq_mask_len = self.num_draft_tokens * (\n                forward_batch.seq_lens + self.num_draft_tokens\n            )\n            mask_indptr = self.mask_indptr\n            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)\n            mask_indptr = mask_indptr[: bs + 1]\n            max_extend_len = self.num_draft_tokens\n            num_kv_splits = None\n            attn_logits = None\n            attn_lse = None\n        elif forward_batch.forward_mode.is_draft_extend():\n            kv_indices, kv_indptr, qo_indptr, custom_mask = (\n                spec_info.generate_attn_arg_prefill(\n                    forward_batch.req_pool_indices,\n                    forward_batch.seq_lens,\n                    None,\n                    self.req_to_token,\n                )\n            )\n            mask_indptr = None\n            # TODO(FIXME): This will trigger an invalid Eagle tree when using\n            # `max(spec_info.accept_length_cpu)`.\n            # It might have been forgotten to update somewhere.\n            max_extend_len = torch.max(spec_info.accept_length).item()\n            num_kv_splits = None\n            attn_logits = None\n            attn_lse = None\n        else:\n            kv_indptr[1 : bs + 1] = torch.cumsum(\n                forward_batch.extend_prefix_lens, dim=0\n            )\n            kv_indptr = kv_indptr[: bs + 1]\n            kv_indices = torch.empty(\n                forward_batch.extend_prefix_lens.sum().item(),\n                dtype=torch.int32,\n                device=self.device,\n            )\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                forward_batch.req_pool_indices,\n                forward_batch.extend_prefix_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n\n            qo_indptr = self.qo_indptr\n            qo_indptr[1 : bs + 1] = 
torch.cumsum(forward_batch.extend_seq_lens, dim=0)\n            qo_indptr = qo_indptr[: bs + 1]\n            custom_mask = None\n            mask_indptr = None\n            attn_logits = None\n            attn_lse = None\n            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()\n            num_kv_splits = None\n\n        self.forward_metadata = ForwardMetadata(\n            attn_logits,\n            attn_lse,\n            max_extend_len,\n            num_kv_splits,\n            kv_indptr,\n            kv_indices,\n            qo_indptr,\n            custom_mask,\n            mask_indptr,\n        )\n\n    def init_cuda_graph_state(\n        self,\n        max_bs: int,\n        max_num_tokens: int,\n        kv_indices_buf: Optional[torch.Tensor] = None,\n    ):\n        from sglang.srt.layers.attention.wave_ops.decode_attention import (\n            decode_attention_intermediate_arrays_shapes,\n        )\n\n        attn_logits_shape, attn_logits_max_shape = (\n            decode_attention_intermediate_arrays_shapes(\n                max_bs, self.v_head_dim, self.num_head, self.max_kv_splits\n            )\n        )\n        self.cuda_graph_attn_logits = torch.zeros(\n            attn_logits_shape,\n            dtype=torch.float32,\n            device=self.device,\n        )\n        self.cuda_graph_attn_lse = torch.zeros(\n            attn_logits_max_shape,\n            dtype=torch.float32,\n            device=self.device,\n        )\n        self.cuda_graph_num_kv_splits = torch.full(\n            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device\n        )\n        if kv_indices_buf is None:\n            self.cuda_graph_kv_indices = torch.zeros(\n                (max_bs * self.max_context_len),\n                dtype=torch.int32,\n                device=self.device,\n            )\n        else:\n            self.cuda_graph_kv_indices = kv_indices_buf\n\n        if not self.skip_prefill:\n            
self.cuda_graph_custom_mask = torch.zeros(\n                (max_bs * self.max_context_len),\n                dtype=torch.uint8,\n                device=self.device,\n            )\n\n    def init_forward_metadata_capture_cuda_graph(\n        self,\n        bs: int,\n        num_tokens: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n    ):\n        assert encoder_lens is None, \"Not supported\"\n\n        if forward_mode.is_decode_or_idle():\n            if spec_info is None:\n                kv_indptr = self.kv_indptr\n                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n                kv_indptr = kv_indptr[: bs + 1]\n                kv_indices = self.cuda_graph_kv_indices\n                create_flashinfer_kv_indices_triton[(bs,)](\n                    self.req_to_token,\n                    req_pool_indices,\n                    seq_lens,\n                    kv_indptr,\n                    None,\n                    kv_indices,\n                    self.req_to_token.stride(0),\n                )\n            else:\n                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices\n\n            attn_logits = self.cuda_graph_attn_logits\n            attn_lse = self.cuda_graph_attn_lse\n            max_extend_len = None\n            num_kv_splits = self.cuda_graph_num_kv_splits\n            qo_indptr = None\n            custom_mask = None\n            mask_indptr = None\n        elif forward_mode.is_target_verify():\n            qo_indptr = self.qo_indptr[: bs + 1]\n            qo_indptr[: bs + 1] = torch.arange(\n                0,\n                (1 + bs) * self.num_draft_tokens,\n                step=self.num_draft_tokens,\n                dtype=torch.int32,\n                device=self.device,\n            )\n            kv_indptr = self.kv_indptr[: bs + 1]\n            
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n            kv_indices = self.cuda_graph_kv_indices\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                seq_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n\n            custom_mask = self.cuda_graph_custom_mask\n            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)\n            mask_indptr = self.mask_indptr[: bs + 1]\n            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)\n            max_extend_len = self.num_draft_tokens\n            num_kv_splits = None\n            attn_logits = None\n            attn_lse = None\n        else:\n            raise ValueError(\n                f\"Invalid forward mode: {forward_mode=} for CUDA Graph capture.\"\n            )\n\n        self.forward_metadata = ForwardMetadata(\n            attn_logits,\n            attn_lse,\n            max_extend_len,\n            num_kv_splits,\n            kv_indptr,\n            kv_indices,\n            qo_indptr,\n            custom_mask,\n            mask_indptr,\n        )\n\n    def init_forward_metadata_replay_cuda_graph(\n        self,\n        bs: int,\n        req_pool_indices: torch.Tensor,\n        seq_lens: torch.Tensor,\n        seq_lens_sum: int,\n        encoder_lens: Optional[torch.Tensor],\n        forward_mode: ForwardMode,\n        spec_info: Optional[SpecInput],\n        seq_lens_cpu: Optional[torch.Tensor],\n    ):\n        # NOTE: encoder_lens expected to be zeros or None\n        if forward_mode.is_decode_or_idle():\n            # Update kv_indptr, kv_indices\n            kv_indptr = self.kv_indptr\n            kv_indices = self.cuda_graph_kv_indices\n            num_kv_splits = self.cuda_graph_num_kv_splits\n            if spec_info is None:\n                kv_indptr[1 : bs + 
1] = torch.cumsum(seq_lens[:bs], dim=0)\n                kv_indptr = kv_indptr[: bs + 1]\n                create_flashinfer_kv_indices_triton[(bs,)](\n                    self.req_to_token,\n                    req_pool_indices[:bs],\n                    seq_lens[:bs],\n                    kv_indptr,\n                    None,\n                    kv_indices,\n                    self.req_to_token.stride(0),\n                )\n                num_token = bs\n            else:\n                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr\n                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices\n                num_token = spec_info.kv_indptr.shape[0] - 1\n            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])\n        elif forward_mode.is_target_verify():\n            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr\n            bs = len(req_pool_indices)\n            qo_indptr = self.qo_indptr[: bs + 1]\n            qo_indptr[: bs + 1] = torch.arange(\n                0,\n                (1 + bs) * self.num_draft_tokens,\n                step=self.num_draft_tokens,\n                dtype=torch.int32,\n                device=self.device,\n            )\n            kv_indptr = self.kv_indptr[: bs + 1]\n            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)\n            kv_indices = self.cuda_graph_kv_indices\n            create_flashinfer_kv_indices_triton[(bs,)](\n                self.req_to_token,\n                req_pool_indices,\n                seq_lens,\n                kv_indptr,\n                None,\n                kv_indices,\n                self.req_to_token.stride(0),\n            )\n            custom_mask = self.cuda_graph_custom_mask\n            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask\n            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)\n            mask_indptr = self.mask_indptr[: bs + 
1]\n            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)\n        else:\n            raise ValueError(\n                f\"Invalid forward mode: {forward_mode=} for CUDA Graph replay.\"\n            )\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        return 1\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n    ):\n        # TODO: reuse the buffer across layers\n        if layer.qk_head_dim != layer.v_head_dim:\n            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        if save_kv_cache:\n            forward_batch.token_to_kv_pool.set_kv_buffer(\n                layer, forward_batch.out_cache_loc, k, v\n            )\n\n        max_extend_len = self.forward_metadata.max_extend_len\n        computed_max_ext_seq_len = torch.max(forward_batch.extend_seq_lens)\n        if computed_max_ext_seq_len != max_extend_len:\n            assert len(forward_batch.extend_seq_lens) == 1\n            forward_batch.extend_seq_lens[0] = max_extend_len\n            forward_batch.seq_lens = max_extend_len\n\n        self.extend_attention_fwd(\n            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n            k.contiguous(),\n            v.contiguous(),\n            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n            self.forward_metadata.qo_indptr,\n            self.forward_metadata.kv_indptr,\n            self.forward_metadata.kv_indices,\n            self.forward_metadata.custom_mask,\n            self.forward_metadata.mask_indptr,\n            self.forward_metadata.max_extend_len,\n            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n            is_causal=True,\n            
layer_scaling=layer.scaling,\n            logit_cap=layer.logit_cap,\n        )\n        return o\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n    ):\n        # During torch.compile, there is a bug in rotary_emb that causes the\n        # output value to have a 3D tensor shape. This reshapes the output correctly.\n        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)\n\n        # TODO: reuse the buffer across layers\n        if layer.qk_head_dim != layer.v_head_dim:\n            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))\n        else:\n            o = torch.empty_like(q)\n\n        if save_kv_cache:\n            forward_batch.token_to_kv_pool.set_kv_buffer(\n                layer, forward_batch.out_cache_loc, k, v\n            )\n\n        self.decode_attention_fwd(\n            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),\n            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),\n            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),\n            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),\n            self.forward_metadata.kv_indptr,\n            self.forward_metadata.kv_indices,\n            self.forward_metadata.attn_logits,\n            self.forward_metadata.attn_lse,\n            self.forward_metadata.num_kv_splits,\n            self.max_kv_splits,\n            layer.scaling,\n            layer.logit_cap,\n        )\n        return o\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/wave_ops/decode_attention.py",
    "content": "\"\"\"\nMemory-efficient attention for decoding.\nIt supports page size = 1.\n\"\"\"\n\nimport functools\nimport logging\n\nfrom wave_lang.kernel.lang.global_symbols import *\nfrom wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile\nfrom wave_lang.kernel.wave.constraints import GenericDot, MMAOperand, MMAType\nfrom wave_lang.kernel.wave.templates.paged_decode_attention import (\n    get_paged_decode_attention_kernels,\n    get_paged_decode_intermediate_arrays_shapes,\n    paged_decode_attention_shape,\n)\nfrom wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params\nfrom wave_lang.kernel.wave.utils.run_utils import set_default_run_config\n\nlogger = logging.getLogger(__name__)\nimport os\n\ndump_generated_mlir = int(os.environ.get(\"WAVE_DUMP_MLIR\", 0))\n\n\n@functools.lru_cache(maxsize=4096)\ndef get_wave_kernel(\n    shape: paged_decode_attention_shape,\n    max_kv_splits,\n    input_dtype,\n    output_dtype,\n    logit_cap,\n):\n    mha = (shape.num_query_heads // shape.num_kv_heads) == 1\n\n    # Get the kernels (either compile or load from cache).\n    if mha:\n        mfma_variant = (\n            GenericDot(along_dim=MMAOperand.M, k_vec_size=4, k_mult=1),\n            GenericDot(along_dim=MMAOperand.M, k_vec_size=1, k_mult=64),\n        )\n    else:\n        mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)\n\n    (\n        phase_0,\n        phase_1,\n        hyperparams_0,\n        hyperparams_1,\n        dynamic_symbols_0,\n        dynamic_symbols_1,\n    ) = get_paged_decode_attention_kernels(\n        shape,\n        mfma_variant,\n        max_kv_splits,\n        input_dtype=input_dtype,\n        output_dtype=output_dtype,\n        logit_cap=logit_cap,\n    )\n    hyperparams_0.update(get_default_scheduling_params())\n    hyperparams_1.update(get_default_scheduling_params())\n\n    options = WaveCompileOptions(\n        subs=hyperparams_0,\n        canonicalize=True,\n        
run_bench=False,\n        use_buffer_ops=True,\n        waves_per_eu=2,\n        dynamic_symbols=dynamic_symbols_0,\n        wave_runtime=True,\n    )\n    options = set_default_run_config(options)\n    phase_0 = wave_compile(options, phase_0)\n\n    options = WaveCompileOptions(\n        subs=hyperparams_1,\n        canonicalize=True,\n        run_bench=False,\n        use_buffer_ops=False,\n        waves_per_eu=4,\n        dynamic_symbols=dynamic_symbols_1,\n        wave_runtime=True,\n    )\n    options = set_default_run_config(options)\n    phase_1 = wave_compile(options, phase_1)\n\n    return phase_0, phase_1\n\n\ndef decode_attention_intermediate_arrays_shapes(\n    num_seqs, head_size_kv, num_query_heads, max_kv_splits\n):\n    # Not all fields are used, but we need to pass them to the function\n    shape = paged_decode_attention_shape(\n        num_query_heads=num_query_heads,\n        num_kv_heads=0,\n        head_size=0,\n        head_size_kv=head_size_kv,\n        block_size=0,\n        num_seqs=num_seqs,\n    )\n    return get_paged_decode_intermediate_arrays_shapes(shape, max_kv_splits)\n\n\ndef decode_attention_wave(\n    q,\n    k_buffer,\n    v_buffer,\n    o,\n    b_req_idx,\n    req_to_token,\n    attn_logits,\n    attn_logits_max,\n    num_kv_splits,\n    max_kv_splits,\n    sm_scale,\n    logit_cap,\n):\n    num_seqs, num_query_heads, head_size = q.shape\n    _, num_kv_heads, _ = k_buffer.shape\n    _, _, head_size_kv = v_buffer.shape\n    block_size = 32\n    shape = paged_decode_attention_shape(\n        num_query_heads,\n        num_kv_heads,\n        head_size,\n        head_size_kv,\n        block_size,\n        num_seqs,\n    )\n\n    phase_0, phase_1 = get_wave_kernel(\n        shape, max_kv_splits, q.dtype, o.dtype, logit_cap\n    )\n\n    mb_qk = phase_0(\n        q,\n        k_buffer,\n        v_buffer,\n        b_req_idx,\n        req_to_token,\n        attn_logits,\n        attn_logits_max,\n    )\n    if dump_generated_mlir:\n      
  filename = f\"wave_decode_attention_phase0_{'x'.join(map(str, shape))}.mlir\"\n        with open(filename, \"w\") as f:\n            f.write(mb_qk.module_op.get_asm())\n\n    mb_sv = phase_1(attn_logits, attn_logits_max, b_req_idx, o)\n    if dump_generated_mlir:\n        filename = f\"wave_decode_attention_phase1_{'x'.join(map(str, shape))}.mlir\"\n        with open(filename, \"w\") as f:\n            f.write(mb_sv.module_op.get_asm())\n\n\ndef decode_attention_fwd(\n    q,\n    k_buffer,\n    v_buffer,\n    o,\n    b_req_idx,\n    req_to_token,\n    attn_logits,\n    attn_logits_max,\n    num_kv_splits,\n    max_kv_splits,\n    sm_scale,\n    logit_cap=0.0,\n):\n    decode_attention_wave(\n        q,\n        k_buffer,\n        v_buffer,\n        o,\n        b_req_idx,\n        req_to_token,\n        attn_logits,\n        attn_logits_max,\n        num_kv_splits,\n        max_kv_splits,\n        sm_scale,\n        logit_cap,\n    )\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/wave_ops/extend_attention.py",
    "content": "\"\"\"\nMemory-efficient attention for prefill.\nIt supports page size = 1.\n\"\"\"\n\nimport functools\nimport os\n\nimport torch\nfrom wave_lang.kernel.lang.global_symbols import *\nfrom wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile\nfrom wave_lang.kernel.wave.constraints import MMAType\nfrom wave_lang.kernel.wave.scheduling.schedule import SchedulingType\nfrom wave_lang.kernel.wave.templates.attention_common import AttentionShape\nfrom wave_lang.kernel.wave.templates.extend_attention import get_extend_attention_kernel\nfrom wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params\nfrom wave_lang.kernel.wave.utils.run_utils import set_default_run_config\n\ndump_generated_mlir = int(os.environ.get(\"WAVE_DUMP_MLIR\", 0))\n\n\n@functools.lru_cache\ndef get_wave_kernel(\n    shape: AttentionShape,\n    q_shape: tuple[int],\n    k_shape: tuple[int],\n    v_shape: tuple[int],\n    k_cache_shape: tuple[int],\n    v_cache_shape: tuple[int],\n    o_shape: tuple[int],\n    input_dtype: torch.dtype,\n    output_dtype: torch.dtype,\n    size_dtype: torch.dtype,\n    is_causal: bool,\n    logit_cap: float,\n    layer_scaling: float,\n):\n    assert shape.num_query_heads % shape.num_kv_heads == 0\n\n    mfma_variant = (MMAType.F32_16x16x32_K8_F16, MMAType.F32_16x16x16_F16)\n    (\n        extend_attention,\n        hyperparams,\n        dynamic_symbols,\n    ) = get_extend_attention_kernel(\n        shape,\n        mfma_variant,\n        q_shape,\n        k_shape,\n        v_shape,\n        k_cache_shape,\n        v_cache_shape,\n        o_shape,\n        input_dtype=input_dtype,\n        output_dtype=output_dtype,\n        size_dtype=size_dtype,\n        is_causal=is_causal,\n        layer_scaling=layer_scaling,\n        logit_cap=logit_cap,\n    )\n\n    hyperparams.update(get_default_scheduling_params())\n    options = WaveCompileOptions(\n        subs=hyperparams,\n        canonicalize=True,\n        
run_bench=False,\n        schedule=SchedulingType.NONE,\n        use_scheduling_barriers=False,\n        dynamic_symbols=dynamic_symbols,\n        use_buffer_ops=True,\n        waves_per_eu=2,\n        denorm_fp_math_f32=\"preserve-sign\",\n        wave_runtime=True,\n    )\n    options = set_default_run_config(options)\n    extend_attention = wave_compile(options, extend_attention)\n\n    return extend_attention\n\n\ndef extend_attention_wave(\n    q_extend,\n    k_extend,\n    v_extend,\n    k_buffer,\n    v_buffer,\n    qo_indptr,\n    kv_indptr,\n    kv_indices,\n    custom_mask,\n    mask_indptr,\n    max_seq_len,\n    output,\n    is_causal=True,\n    layer_scaling=None,\n    logit_cap=0,\n):\n    shape = AttentionShape(\n        num_query_heads=q_extend.shape[1],\n        num_kv_heads=k_extend.shape[1],\n        head_size=q_extend.shape[2],\n        head_size_kv=k_extend.shape[2],\n        num_seqs=kv_indptr.shape[0] - 1,\n        max_seq_len=max_seq_len,\n    )\n\n    # Run the wave kernel.\n    extend_attention = get_wave_kernel(\n        shape,\n        q_extend.shape,\n        k_extend.shape,\n        v_extend.shape,\n        k_buffer.shape,\n        v_buffer.shape,\n        output.shape,\n        input_dtype=q_extend.dtype,\n        output_dtype=output.dtype,\n        size_dtype=qo_indptr.dtype,\n        is_causal=is_causal,\n        layer_scaling=layer_scaling,\n        logit_cap=logit_cap,\n    )\n\n    mb = extend_attention(\n        q_extend,\n        k_extend,\n        v_extend,\n        k_buffer,\n        v_buffer,\n        qo_indptr,\n        kv_indptr,\n        kv_indices,\n        max_seq_len,\n        output,\n    )\n\n    if dump_generated_mlir:\n        shape_list = [\n            q_extend.shape[0],\n            q_extend.shape[1],\n            k_extend.shape[1],\n            q_extend.shape[2],\n            k_extend.shape[2],\n        ]\n        filename = f\"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir\"\n        with 
open(filename, \"w\") as f:\n            f.write(mb.module_op.get_asm())\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/wave_ops/prefill_attention.py",
    "content": "\"\"\"\nMemory-efficient attention for prefill.\nIt supports page size = 1.\n\"\"\"\n\nimport math\nimport os\n\nfrom wave_lang.kernel.lang.global_symbols import *\nfrom wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile\nfrom wave_lang.kernel.wave.constraints import MMAType\nfrom wave_lang.kernel.wave.templates.attention_common import AttentionShape\nfrom wave_lang.kernel.wave.templates.prefill_attention import (\n    get_prefill_attention_kernel,\n)\nfrom wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params\nfrom wave_lang.kernel.wave.utils.run_utils import set_default_run_config\n\ndump_generated_mlir = int(os.environ.get(\"WAVE_DUMP_MLIR\", 0))\n\n\ndef prefill_attention_wave(\n    q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=True\n):\n\n    shape = AttentionShape(\n        num_query_heads=q.shape[1],\n        num_kv_heads=k.shape[1],\n        head_size=q.shape[2],\n        head_size_kv=k.shape[2],\n        num_seqs=b_seq_len.shape[0],\n        max_seq_len=max_seq_len,\n        total_seq_len=q.shape[0],\n    )\n\n    assert shape.num_query_heads % shape.num_kv_heads == 0\n\n    output_shape = (shape.total_seq_len, shape.num_query_heads, shape.head_size_kv)\n    # Run the wave kernel.\n    mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)\n    prefill, hyperparams = get_prefill_attention_kernel(\n        shape,\n        mfma_variant,\n        q.shape,\n        k.shape,\n        v.shape,\n        output_shape,\n        input_dtype=q.dtype,\n        output_dtype=o.dtype,\n        size_dtype=b_seq_len.dtype,\n    )\n\n    hyperparams.update(get_default_scheduling_params())\n\n    log2e = 1.44269504089\n    dk_sqrt = math.sqrt(1.0 / shape.head_size)\n\n    options = WaveCompileOptions(\n        subs=hyperparams,\n        canonicalize=True,\n        run_bench=False,\n        use_scheduling_barriers=False,\n    )\n    options = set_default_run_config(options)\n    prefill = 
wave_compile(options, prefill)\n\n    mb = prefill(\n        q * dk_sqrt * log2e,\n        k,\n        v,\n        b_start_loc,\n        b_seq_len,\n        o,\n    )\n    if dump_generated_mlir:\n        shape_list = [q.shape[0], q.shape[1], k.shape[1], q.shape[2], k.shape[2]]\n        filename = f\"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir\"\n        with open(filename, \"w\") as f:\n            f.write(mb.module_op.get_asm())\n"
  },
  {
    "path": "python/sglang/srt/layers/attention/xpu_backend.py",
    "content": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Optional\n\nimport torch\n\nfrom sglang.srt.configs.model_config import AttentionArch\nfrom sglang.srt.layers.attention.base_attn_backend import AttentionBackend\nfrom sglang.srt.layers.attention.flashattention_backend import (\n    FlashAttentionMetadata,\n    make_local_attention_virtual_batches,\n    merge_state_v2_wrapper,\n    prepare_swa_spec_page_table_triton,\n)\nfrom sglang.srt.managers.schedule_batch import get_global_server_args\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.radix_attention import RadixAttention\n    from sglang.srt.model_executor.model_runner import ModelRunner\n\nfrom sgl_kernel import merge_state_v2\nfrom sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache\n\n\nclass XPUAttentionBackend(AttentionBackend):\n    \"\"\"XPU FlashAttention backend, currently based on FlashAttentionBackend, will be refactored later.\n\n    TODO:\n    - Prefill and Decode disaggregation, currently only chunked prefill is supported\n    - Speculative Decoding support\n    - XPU Graph support, see https://github.com/pytorch/pytorch/issues/162143\n    - MLA support\n    \"\"\"\n\n    def __init__(\n        self,\n        model_runner: ModelRunner,\n        skip_prefill: bool = False,\n        speculative_step_id=0,\n        topk=0,\n        speculative_num_steps=0,\n    ):\n        super().__init__()\n\n        assert not (\n            model_runner.sliding_window_size is not None\n            and model_runner.model_config.is_encoder_decoder\n        ), \"Sliding window and cross attention are not supported together\"\n\n        self.forward_metadata: FlashAttentionMetadata = None\n        # extra metadata for handling speculative decoding topk > 1, extended draft decode and verify\n        self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None\n    
    self.max_context_len = model_runner.model_config.context_len\n        self.device = model_runner.device\n        self.decode_cuda_graph_metadata = {}\n        self.target_verify_metadata = {}\n        self.req_to_token = model_runner.req_to_token_pool.req_to_token\n        self.kv_cache_dtype = model_runner.kv_cache_dtype\n        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype\n        self.page_size = model_runner.page_size\n        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA\n        assert (\n            self.use_mla is False\n        ), \"XPUAttentionBackend doesn't support MLA yet, please use --attention-backend triton instead.\"\n        self.skip_prefill = skip_prefill\n        self.is_hybrid_swa = model_runner.is_hybrid_swa\n        if self.is_hybrid_swa:\n            self.full_to_swa_index_mapping = (\n                model_runner.token_to_kv_pool.full_to_swa_index_mapping\n            )\n        self.topk = model_runner.server_args.speculative_eagle_topk or 0\n        self.speculative_num_steps = speculative_num_steps\n        self.speculative_num_draft_tokens = (\n            model_runner.server_args.speculative_num_draft_tokens\n        )\n        self.speculative_step_id = speculative_step_id\n\n        # Local attention settings\n        self.attention_chunk_size = (\n            model_runner.attention_chunk_size\n            if hasattr(model_runner, \"attention_chunk_size\")\n            else None\n        )\n\n        # For each layer, the sliding_window_size can be different. 
This is only used for preparing SWA metadata.\n        # We use `layer.sliding_window_size` to decide whether to use SWA for each layer.\n        self.sliding_window_size = model_runner.sliding_window_size\n        self.has_swa = (\n            self.sliding_window_size is not None and self.sliding_window_size > -1\n        )\n\n    def init_forward_metadata(self, forward_batch: ForwardBatch):\n        \"\"\"Initialize forward metadata so that all layers in the forward pass can reuse it.\"\"\"\n        metadata = FlashAttentionMetadata()\n        seqlens_in_batch = forward_batch.seq_lens\n        batch_size = forward_batch.batch_size\n        device = seqlens_in_batch.device\n\n        if forward_batch.forward_mode.is_decode_or_idle():\n            # Draft Decode\n            if forward_batch.spec_info is not None:\n                assert (\n                    False\n                ), \"XPUAttentionBackend doesn't support speculative decoding yet, please use --attention-backend triton instead.\"\n                if self.topk <= 1:\n                    metadata.cache_seqlens_int32 = (\n                        seqlens_in_batch + (self.speculative_step_id + 1)\n                    ).to(torch.int32)\n                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (\n                        self.speculative_step_id + 1\n                    )\n                    metadata.cu_seqlens_q = torch.arange(\n                        0, batch_size + 1, dtype=torch.int32, device=device\n                    )\n                    metadata.cu_seqlens_k = torch.nn.functional.pad(\n                        torch.cumsum(\n                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32\n                        ),\n                        (1, 0),\n                    )\n                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                        forward_batch.req_pool_indices, : metadata.max_seq_len_k\n           
         ]\n                else:\n                    metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)\n                    metadata.max_seq_len_q = self.topk\n                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()\n                    metadata.cu_seqlens_q = torch.arange(\n                        0,\n                        batch_size * self.topk + 1,\n                        step=self.topk,\n                        dtype=torch.int32,\n                        device=device,\n                    )\n                    metadata.cu_seqlens_k = torch.nn.functional.pad(\n                        torch.cumsum(\n                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32\n                        ),\n                        (1, 0),\n                    )\n                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                        forward_batch.req_pool_indices, : metadata.max_seq_len_k\n                    ]\n\n                    metadata_expand = FlashAttentionMetadata()\n                    decode_length = self.speculative_step_id + 1\n                    metadata_expand.cache_seqlens_int32 = torch.full(\n                        (seqlens_in_batch.numel() * self.topk,),\n                        decode_length,\n                        device=device,\n                        dtype=torch.int32,\n                    )\n                    metadata_expand.max_seq_len_q = 1\n                    metadata_expand.cu_seqlens_q = torch.arange(\n                        0,\n                        metadata_expand.cache_seqlens_int32.numel() + 1,\n                        dtype=torch.int32,\n                        device=device,\n                    )\n                    metadata_expand.cu_seqlens_k = torch.arange(\n                        0,\n                        metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,\n                        step=decode_length,\n  
                      dtype=torch.int32,\n                        device=device,\n                    )\n                    # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]\n                    cache_loc = forward_batch.out_cache_loc.view(\n                        -1, self.speculative_num_steps\n                    )\n                    metadata_expand.page_table = (\n                        cache_loc[:, :decode_length].contiguous().to(torch.int32)\n                    )\n                    self.forward_metadata_spec_decode_expand = metadata_expand\n            else:\n                # Normal Decode\n                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)\n                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()\n                metadata.cu_seqlens_q = torch.arange(\n                    0, batch_size + 1, dtype=torch.int32, device=device\n                )\n                metadata.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)\n                )\n                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                    forward_batch.req_pool_indices, : metadata.max_seq_len_k\n                ]\n            # TODO: we need to test this part for llama 4 eagle case\n            self._init_local_attn_metadata(forward_batch, metadata, device)\n        elif forward_batch.forward_mode.is_target_verify():\n            if self.topk <= 1:\n                metadata.cache_seqlens_int32 = (\n                    forward_batch.seq_lens + self.speculative_num_draft_tokens\n                ).to(torch.int32)\n                metadata.max_seq_len_q = self.speculative_num_draft_tokens\n                metadata.max_seq_len_k = (\n                    forward_batch.seq_lens_cpu.max().item()\n                    + self.speculative_num_draft_tokens\n                )\n                metadata.cu_seqlens_q = torch.arange(\n 
                   0,\n                    batch_size * self.speculative_num_draft_tokens + 1,\n                    self.speculative_num_draft_tokens,\n                    dtype=torch.int32,\n                    device=device,\n                )\n                metadata.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(\n                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32\n                    ),\n                    (1, 0),\n                )\n                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                    forward_batch.req_pool_indices, : metadata.max_seq_len_k\n                ]\n\n                self._init_local_attn_metadata(forward_batch, metadata, device)\n            else:\n                metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)\n                metadata.max_seq_len_q = self.speculative_num_draft_tokens\n                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()\n                metadata.cu_seqlens_q = torch.arange(\n                    0,\n                    batch_size * self.speculative_num_draft_tokens + 1,\n                    step=self.speculative_num_draft_tokens,\n                    dtype=torch.int32,\n                    device=device,\n                )\n                metadata.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(\n                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32\n                    ),\n                    (1, 0),\n                )\n                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                    forward_batch.req_pool_indices, : metadata.max_seq_len_k\n                ]\n\n                metadata_expand = FlashAttentionMetadata()\n\n                metadata_expand.max_seq_len_q = 1\n                metadata_expand.cu_seqlens_q = torch.arange(\n                    0,\n                    
forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens\n                    + 1,\n                    dtype=torch.int32,\n                    device=device,\n                )\n\n                # create expand page table\n                offsets = torch.arange(\n                    self.speculative_num_draft_tokens, device=device\n                ).unsqueeze(\n                    0\n                )  # shape: (1, self.speculative_num_draft_tokens)\n                cols = offsets.expand(\n                    forward_batch.seq_lens.numel(), -1\n                ) + forward_batch.seq_lens.unsqueeze(1)\n                cum_len = torch.nn.functional.pad(\n                    torch.cumsum(\n                        (\n                            forward_batch.seq_lens + self.speculative_num_draft_tokens\n                        ).repeat_interleave(self.speculative_num_draft_tokens),\n                        dim=0,\n                    ),\n                    (1, 0),\n                )[:-1]\n                mask_extraction_indices = (\n                    cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)\n                    + cum_len[:, None]\n                ).view(1, -1)\n                mask = forward_batch.spec_info.custom_mask[\n                    mask_extraction_indices\n                ].view(\n                    -1, self.speculative_num_draft_tokens\n                )  # (bsz * draft_num, draft_num)\n\n                # shift table indices to avoid padding\n                # non_masked_page_table [[8, 9, 10],   mask (display with int format) [[1, 0, 0],\n                #                        [8, 9, 10],                                   [1, 1, 0],\n                #                        [8, 9, 10]]                                   [1, 0, 1]]\n                # if masked with padding [[8, 0, 0],   our mask without padding       [[8, 9, 10],\n                #                        [8, 9, 0],                                    
[8, 9, 10],\n                #                        [8, 0, 10]]                                   [8, 10, 9]]\n                # note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row\n                col_indices = offsets.expand(\n                    mask.shape[0], self.speculative_num_draft_tokens\n                )\n                # Build keys: if an entry is valid (mask==True), keep its original index;\n                # if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.\n                keys = torch.where(\n                    mask, col_indices, col_indices + self.speculative_num_draft_tokens\n                )\n                _, sort_order = torch.sort(keys, dim=1)\n                non_masked_page_table = (\n                    forward_batch.req_to_token_pool.req_to_token[\n                        forward_batch.req_pool_indices, :\n                    ]\n                    .gather(1, cols)\n                    .repeat_interleave(self.speculative_num_draft_tokens, dim=0)\n                )  # (bsz, draft_num)\n                metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)\n                metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)\n                metadata_expand.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(\n                        metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32\n                    ),\n                    (1, 0),\n                )\n                self.forward_metadata_spec_decode_expand = metadata_expand\n\n                if self.has_swa:\n                    self._init_sliding_window_attn_spec_metadata(\n                        metadata, metadata_expand\n                    )\n\n        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():\n            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)\n            
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()\n            metadata.cu_seqlens_k = torch.nn.functional.pad(\n                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)\n            )\n            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                forward_batch.req_pool_indices, : metadata.max_seq_len_k\n            ]\n\n            if (\n                any(forward_batch.extend_prefix_lens_cpu)\n                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND\n            ):\n                extend_seq_lens = forward_batch.extend_seq_lens\n                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)\n                metadata.cu_seqlens_q = torch.nn.functional.pad(\n                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)\n                )\n            else:\n                metadata.max_seq_len_q = metadata.max_seq_len_k\n                metadata.cu_seqlens_q = metadata.cu_seqlens_k\n\n            # Setup local attention if enabled\n            if forward_batch.forward_mode == ForwardMode.EXTEND:\n                self._init_local_attn_metadata(forward_batch, metadata, device)\n\n        # Encoder metadata for cross attention\n        if forward_batch.encoder_lens is not None:\n            assert (\n                forward_batch.encoder_lens.numel() == 1\n            ), \"Only encoder size 1 is supported for now\"\n\n            metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)\n            metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(\n                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),\n                (1, 0),\n            )\n            metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()\n            metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[\n                forward_batch.req_pool_indices, : 
metadata.encoder_max_seq_len_k\n            ]\n\n            # Currently only forward_batch.encoder_lens.numel() == 1 is supported\n            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[\n                forward_batch.req_pool_indices,\n                metadata.encoder_max_seq_len_k : (\n                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k\n                ),\n            ]\n\n        # Convert the page table to the strided format required by the FA3 API\n        if self.page_size > 1:\n            self.strided_indices = torch.arange(\n                0, metadata.page_table.shape[1], self.page_size, device=self.device\n            )\n            metadata.page_table = (\n                metadata.page_table[:, self.strided_indices] // self.page_size\n            )\n\n        self.forward_metadata = metadata\n\n    def forward_extend(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n        # For multi-head latent attention\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        sinks: Optional[torch.Tensor] = None,\n    ):\n        if k is not None:\n            assert v is not None\n            if save_kv_cache:\n                cache_loc = (\n                    forward_batch.out_cache_loc\n                    if not layer.is_cross_attention\n                    else forward_batch.encoder_out_cache_loc\n                )\n                if not self.use_mla:\n                    forward_batch.token_to_kv_pool.set_kv_buffer(\n                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale\n                    )\n                else:\n                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(\n                        layer,\n                        cache_loc,\n                        k,\n                       
 k_rope,\n                    )\n\n        # Use precomputed metadata across all layers\n        metadata = self.forward_metadata\n\n        # Calculate window size (can be moved to metadata if layer properties don't change)\n        # we don't subtract 1 from layer.sliding_window_size because model.get_attention_sliding_window_size() has already done so\n        # the window here is inclusive on both sides\n        is_hybrid_swa = (\n            layer.sliding_window_size is not None and layer.sliding_window_size > -1\n        )\n        window_size = (layer.sliding_window_size, 0) if is_hybrid_swa else (-1, -1)\n\n        # FP8 KV cache is not supported yet\n        k_descale, v_descale = None, None\n        # # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention\n        # # has corresponding quantization method so that layer.k_scale is not None,\n        # # 3) layer.head_dim <= 256 since the fa3 kernel requires fp16 or bf16 data types in this case.\n        # if self.kv_cache_dtype_str != \"auto\" and layer.head_dim <= 256:\n        #     if layer.k_scale is not None:\n        #         descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)\n        #         k_descale = layer.k_scale.expand(descale_shape)\n        #         v_descale = layer.v_scale.expand(descale_shape)\n        #     q = q.to(self.kv_cache_dtype)\n        #     q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None\n        #     k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None\n        causal = not layer.is_cross_attention\n\n        # Check if we should use local attention\n        use_local_attn = (\n            self.attention_chunk_size is not None\n            and metadata.local_attn_metadata is not None\n            and (hasattr(layer, \"use_irope\") and layer.use_irope)\n        )\n\n        # We do cascade attention for Target Verify with topk > 1\n        # We don't use cascade attention for Sliding Window Attention:\n        # - Different 
window sizes should be passed in for each q in the first stage of cascade attention, but the FA3 interface doesn't support passing in a list of window sizes.\n        # - The overhead of duplicated computation of the common prefix part is small for sliding window layers (seq_len <= window_size), so we can just expand it.\n        use_cascade_attn = (\n            forward_batch.forward_mode.is_target_verify()\n            and self.topk > 1\n            and not is_hybrid_swa\n        )\n\n        # For fa3 interface version compatibility, we put new fields into conditional keyword args\n        kwargs = {}\n        if sinks is not None:\n            kwargs[\"sinks\"] = sinks\n\n        # Get the appropriate page table based on whether we're using local attention\n        if use_local_attn:\n            local_metadata = metadata.local_attn_metadata\n            page_table = local_metadata.local_block_table\n            cu_seqlens_q = local_metadata.local_query_start_loc\n            cache_seqlens = local_metadata.local_seqused_k\n            max_seqlen_q = local_metadata.local_max_query_len\n        elif is_hybrid_swa and metadata.swa_spec_metadata is not None:\n            swa_spec_metadata = metadata.swa_spec_metadata\n            page_table = swa_spec_metadata.page_table\n            cu_seqlens_q = swa_spec_metadata.cu_seqlens_q\n            cache_seqlens = swa_spec_metadata.cache_seqlens_int32\n            max_seqlen_q = swa_spec_metadata.max_seq_len_q\n            cu_seqlens_k = swa_spec_metadata.cu_seqlens_k\n        else:\n            page_table = metadata.page_table\n            cu_seqlens_q = metadata.cu_seqlens_q\n            cache_seqlens = metadata.cache_seqlens_int32\n            max_seqlen_q = metadata.max_seq_len_q\n            cu_seqlens_k = metadata.cu_seqlens_k\n\n        # Use Flash Attention for prefill\n        if not self.use_mla:\n            # Do multi-head attention\n            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(\n 
               layer.layer_id\n            )\n            key_cache = key_cache.view(\n                -1, self.page_size, layer.tp_k_head_num, layer.head_dim\n            )\n            value_cache = value_cache.view(\n                -1, self.page_size, layer.tp_v_head_num, layer.head_dim\n            )\n            if layer.is_cross_attention:\n                page_table = metadata.encoder_page_table\n                cache_seqlens = metadata.encoder_lens_int32\n                cu_seqlens_k = metadata.encoder_cu_seqlens_k\n                window_size = (-1, -1)\n\n            result = flash_attn_with_kvcache(\n                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),\n                k_cache=key_cache,\n                v_cache=value_cache,\n                page_table=page_table,\n                cache_seqlens=cache_seqlens,\n                cu_seqlens_q=cu_seqlens_q,\n                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,\n                max_seqlen_q=max_seqlen_q,\n                softmax_scale=layer.scaling,\n                causal=False if use_cascade_attn else causal,\n                window_size=window_size,\n                softcap=layer.logit_cap,\n                k_descale=k_descale,\n                v_descale=v_descale,\n                return_softmax_lse=use_cascade_attn,\n                **kwargs,\n            )\n\n            if use_cascade_attn:\n                o, softmax_lse, *rest = result\n                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(\n                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),\n                    k_cache=key_cache,\n                    v_cache=value_cache,\n                    page_table=self.forward_metadata_spec_decode_expand.page_table,\n                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,\n                    
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,\n                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,\n                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,\n                    softmax_scale=layer.scaling,\n                    causal=False,\n                    window_size=window_size,\n                    softcap=layer.logit_cap,\n                    k_descale=k_descale,\n                    v_descale=v_descale,\n                    return_softmax_lse=True,\n                    **kwargs,\n                )\n                o, _ = merge_state_v2_wrapper(\n                    o,\n                    softmax_lse.T.contiguous(),\n                    o_expand,\n                    softmax_lse_expand.T.contiguous(),\n                )\n            else:\n                o = result\n        else:\n            if (\n                forward_batch.attn_attend_prefix_cache is not None\n                and not forward_batch.forward_mode.is_target_verify()\n                and not forward_batch.forward_mode.is_draft_extend()\n            ):\n                # Do multi-head attention with chunked prefix cache\n                if forward_batch.attn_attend_prefix_cache:\n                    assert not get_global_server_args().disable_chunked_prefix_cache\n                    # MHA for chunked prefix kv cache when running model with MLA\n                    assert forward_batch.prefix_chunk_idx is not None\n                    assert forward_batch.prefix_chunk_cu_seq_lens is not None\n                    assert forward_batch.prefix_chunk_max_seq_lens is not None\n\n                    chunk_idx = forward_batch.prefix_chunk_idx\n                    assert chunk_idx >= 0\n\n                    assert forward_batch.mha_return_lse\n                    output = flash_attn_varlen_func(\n                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),\n                        
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),\n                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),\n                        cu_seqlens_q=metadata.cu_seqlens_q,\n                        cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],\n                        max_seqlen_q=metadata.max_seq_len_q,\n                        max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],\n                        softmax_scale=layer.scaling,\n                        causal=False,\n                        return_softmax_lse=True,\n                    )\n                else:\n                    # MHA for extend part of sequence without attending prefix kv cache\n                    output = flash_attn_varlen_func(\n                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),\n                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),\n                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),\n                        cu_seqlens_q=metadata.cu_seqlens_q,\n                        cu_seqlens_k=metadata.cu_seqlens_q,\n                        max_seqlen_q=metadata.max_seq_len_q,\n                        max_seqlen_k=metadata.max_seq_len_q,\n                        softmax_scale=layer.scaling,\n                        causal=True,\n                        return_softmax_lse=forward_batch.mha_return_lse,\n                    )\n                if forward_batch.mha_return_lse:\n                    output, lse, *rest = output\n                    lse = torch.transpose(lse, 0, 1).contiguous()\n                    return output, lse\n                return output\n            else:\n                # Do absorbed multi-latent attention\n                kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(\n                    layer.layer_id\n                ).to(q.dtype)\n                k_rope = kv_cache[:, :, layer.v_head_dim :]\n          
      c_kv = kv_cache[:, :, : layer.v_head_dim]\n                k_rope_cache = k_rope.view(\n                    -1,\n                    self.page_size,\n                    layer.tp_k_head_num,\n                    layer.head_dim - layer.v_head_dim,\n                )\n                c_kv_cache = c_kv.view(\n                    -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim\n                )\n                if q_rope is not None:\n                    q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n                    q_rope = q_rope.view(\n                        -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim\n                    )\n                else:\n                    q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)\n                    q_nope = q_all[:, :, : layer.v_head_dim]\n                    q_rope = q_all[:, :, layer.v_head_dim :]\n\n                result = flash_attn_with_kvcache(\n                    q=q_rope,\n                    k_cache=k_rope_cache,\n                    v_cache=c_kv_cache,\n                    qv=q_nope,\n                    page_table=page_table,\n                    cache_seqlens=cache_seqlens,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,\n                    max_seqlen_q=max_seqlen_q,\n                    softmax_scale=layer.scaling,\n                    causal=False if use_cascade_attn else causal,\n                    softcap=layer.logit_cap,\n                    k_descale=k_descale,\n                    v_descale=v_descale,\n                    return_softmax_lse=use_cascade_attn,\n                )\n                if use_cascade_attn:\n                    o, softmax_lse, *rest = result\n                    o_expand, softmax_lse_expand, *rest_expand = (\n                        flash_attn_with_kvcache(\n                            q=q_rope,\n                            
k_cache=k_rope_cache,\n                            v_cache=c_kv_cache,\n                            qv=q_nope,\n                            page_table=self.forward_metadata_spec_decode_expand.page_table,\n                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,\n                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,\n                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,\n                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,\n                            softmax_scale=layer.scaling,\n                            causal=False,\n                            window_size=window_size,\n                            softcap=layer.logit_cap,\n                            k_descale=k_descale,\n                            v_descale=v_descale,\n                            return_softmax_lse=True,\n                        )\n                    )\n                    o, _ = merge_state_v2_wrapper(\n                        o,\n                        softmax_lse.T.contiguous(),\n                        o_expand,\n                        softmax_lse_expand.T.contiguous(),\n                    )\n                else:\n                    o = result\n\n        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n\n    def forward_decode(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        layer: RadixAttention,\n        forward_batch: ForwardBatch,\n        save_kv_cache=True,\n        # For multi-head latent attention\n        q_rope: Optional[torch.Tensor] = None,\n        k_rope: Optional[torch.Tensor] = None,\n        sinks: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if k is not None:\n            assert v is not None\n            if save_kv_cache:\n                cache_loc = (\n                    forward_batch.out_cache_loc\n     
               if not layer.is_cross_attention\n                    else forward_batch.encoder_out_cache_loc\n                )\n                if not self.use_mla:\n                    forward_batch.token_to_kv_pool.set_kv_buffer(\n                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale\n                    )\n                else:\n                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(\n                        layer,\n                        cache_loc,\n                        k,\n                        k_rope,\n                    )\n\n        # Use precomputed metadata across all layers\n        metadata = self.forward_metadata\n        local_attn_metadata = getattr(metadata, \"local_attn_metadata\", None)\n        use_local_attn = (\n            self.attention_chunk_size is not None\n            and local_attn_metadata is not None\n            and (hasattr(layer, \"use_irope\") and layer.use_irope)\n        )\n\n        # When speculative decoding is enabled, forward_decode is called in one of two modes:\n        # 1. DRAFT_DECODE: we enable cascade attention when top_k > 1\n        # 2. 
IDLE: we don't need cascade attention; spec_info is None in this case\n        use_cascade_attn = forward_batch.spec_info is not None and self.topk > 1\n\n        # Calculate window size (can be moved to metadata if layer properties don't change)\n        # we don't subtract 1 from layer.sliding_window_size here because model.get_attention_sliding_window_size() already does\n        # the window here is inclusive on both sides\n        window_size = (\n            (layer.sliding_window_size, 0)\n            if layer.sliding_window_size is not None and layer.sliding_window_size > -1\n            else (-1, -1)\n        )\n        causal = not layer.is_cross_attention\n\n        # For fa3 interface version compatibility, we put new fields into conditional keyword args\n        kwargs = {}\n        if sinks is not None:\n            kwargs[\"sinks\"] = sinks\n\n        k_descale, v_descale = None, None\n        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention\n        # has a corresponding quantization method so that layer.k_scale is not None,\n        # 3) layer.head_dim <= 256, since the fa3 kernel requires fp16 or bf16 data types for larger head dims.\n        if self.kv_cache_dtype_str != \"auto\" and layer.head_dim <= 256:\n            if layer.k_scale is not None:\n                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)\n                k_descale = layer.k_scale.expand(descale_shape)\n                v_descale = layer.v_scale.expand(descale_shape)\n            q = q.to(self.kv_cache_dtype)\n            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None\n            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None\n        if not self.use_mla:\n            # Do multi-head attention\n\n            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(\n                layer.layer_id\n            )\n            key_cache = key_cache.view(\n                -1, self.page_size, 
layer.tp_k_head_num, layer.head_dim\n            )\n            value_cache = value_cache.view(\n                -1, self.page_size, layer.tp_v_head_num, layer.head_dim\n            )\n\n            if layer.is_cross_attention:\n                # Always use non-chunked logic for cross-attention\n                o = flash_attn_with_kvcache(\n                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),\n                    k_cache=key_cache,\n                    v_cache=value_cache,\n                    page_table=metadata.encoder_page_table,\n                    cache_seqlens=metadata.encoder_lens_int32,\n                    cu_seqlens_q=metadata.cu_seqlens_q,\n                    cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,\n                    max_seqlen_q=1,\n                    softmax_scale=layer.scaling,\n                    causal=False,\n                    window_size=(-1, -1),\n                    softcap=layer.logit_cap,\n                    k_descale=k_descale,\n                    v_descale=v_descale,\n                    **kwargs,\n                )\n            elif use_local_attn:\n                # Use chunked (local) attention batching for self-attention\n                o = flash_attn_with_kvcache(\n                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),\n                    k_cache=key_cache,\n                    v_cache=value_cache,\n                    page_table=local_attn_metadata.local_block_table,\n                    cache_seqlens=local_attn_metadata.local_seqused_k,\n                    cu_seqlens_q=local_attn_metadata.local_query_start_loc,\n                    cu_seqlens_k_new=None,\n                    max_seqlen_q=local_attn_metadata.local_max_query_len,\n                    softmax_scale=layer.scaling,\n                    causal=True,\n                    window_size=(-1, -1),\n                    softcap=layer.logit_cap,\n                    k_descale=k_descale,\n                
    v_descale=v_descale,\n                    **kwargs,\n                )\n            else:\n                page_table = metadata.page_table\n                cache_seqlens = metadata.cache_seqlens_int32\n                cu_seqlens_k = metadata.cu_seqlens_k\n                max_seqlen_q = metadata.max_seq_len_q\n                q_reshaped = q.contiguous().view(\n                    -1, layer.tp_q_head_num, layer.head_dim\n                )\n\n                # Default: single-token self-attention\n                result = flash_attn_with_kvcache(\n                    q=q_reshaped,\n                    k_cache=key_cache,\n                    v_cache=value_cache,\n                    page_table=page_table,\n                    cache_seqlens=cache_seqlens,\n                    cu_seqlens_q=metadata.cu_seqlens_q,\n                    cu_seqlens_k_new=cu_seqlens_k,\n                    max_seqlen_q=max_seqlen_q,\n                    softmax_scale=layer.scaling,\n                    causal=False if use_cascade_attn else causal,\n                    window_size=window_size,\n                    softcap=layer.logit_cap,\n                    k_descale=k_descale,\n                    v_descale=v_descale,\n                    return_softmax_lse=use_cascade_attn,\n                    **kwargs,\n                )\n                if use_cascade_attn:\n                    o, softmax_lse, *rest = result\n                    o_expand, softmax_lse_expand, *rest_expand = (\n                        flash_attn_with_kvcache(\n                            q=q_reshaped,\n                            k_cache=key_cache,\n                            v_cache=value_cache,\n                            page_table=self.forward_metadata_spec_decode_expand.page_table,\n                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,\n                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,\n                            
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,\n                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,\n                            softmax_scale=layer.scaling,\n                            causal=False,\n                            window_size=window_size,\n                            softcap=layer.logit_cap,\n                            k_descale=k_descale,\n                            v_descale=v_descale,\n                            return_softmax_lse=True,\n                            **kwargs,\n                        )\n                    )\n                    o, _ = merge_state_v2(\n                        o,\n                        softmax_lse.T.contiguous(),\n                        o_expand,\n                        softmax_lse_expand.T.contiguous(),\n                    )\n                else:\n                    o = result\n        else:\n            # Do absorbed multi-latent attention\n            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(\n                q.dtype\n            )\n            k_rope = kv_cache[:, :, layer.v_head_dim :]\n            c_kv = kv_cache[:, :, : layer.v_head_dim]\n            k_rope_cache = k_rope.view(\n                -1,\n                self.page_size,\n                layer.tp_k_head_num,\n                layer.head_dim - layer.v_head_dim,\n            )\n            c_kv_cache = c_kv.view(\n                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim\n            )\n\n            if q_rope is not None:\n                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)\n                q_rope = q_rope.view(\n                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim\n                )\n            else:\n                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)\n                q_nope = q_all[:, :, : layer.v_head_dim]\n                q_rope = 
q_all[:, :, layer.v_head_dim :]\n            max_seqlen_q = metadata.max_seq_len_q\n\n            result = flash_attn_with_kvcache(\n                q=q_rope,\n                k_cache=k_rope_cache,\n                v_cache=c_kv_cache,\n                qv=q_nope,\n                page_table=metadata.page_table,\n                cache_seqlens=metadata.cache_seqlens_int32,\n                cu_seqlens_q=metadata.cu_seqlens_q,\n                cu_seqlens_k_new=metadata.cu_seqlens_k,\n                max_seqlen_q=max_seqlen_q,\n                softmax_scale=layer.scaling,\n                causal=False if use_cascade_attn else causal,\n                softcap=layer.logit_cap,\n                k_descale=k_descale,\n                v_descale=v_descale,\n                return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states\n            )\n            if use_cascade_attn:\n                o, softmax_lse, *rest = result\n                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(\n                    q=q_rope,\n                    k_cache=k_rope_cache,\n                    v_cache=c_kv_cache,\n                    qv=q_nope,\n                    page_table=self.forward_metadata_spec_decode_expand.page_table,\n                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,\n                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,\n                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,\n                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,\n                    softmax_scale=layer.scaling,\n                    causal=False,\n                    window_size=window_size,\n                    softcap=layer.logit_cap,\n                    k_descale=k_descale,\n                    v_descale=v_descale,\n                    return_softmax_lse=True,\n                )\n                o, _ = 
merge_state_v2(\n                    o,\n                    softmax_lse.T.contiguous(),\n                    o_expand,\n                    softmax_lse_expand.T.contiguous(),\n                )\n            else:\n                o = result\n\n        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)\n\n    def get_cuda_graph_seq_len_fill_value(self):\n        \"\"\"Get the fill value for sequence length in CUDA graph.\"\"\"\n        return 1\n\n    def _init_local_attn_metadata(\n        self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device\n    ):\n        \"\"\"Centralized utility to initialize local_attn_metadata if chunked attention is enabled.\"\"\"\n        if self.attention_chunk_size is None:\n            metadata.local_attn_metadata = None\n            return\n\n        cu_seqlens_q = metadata.cu_seqlens_q\n        cache_seqlens_int32 = metadata.cache_seqlens_int32\n        if self.is_hybrid_swa:\n            page_table = self.full_to_swa_index_mapping[metadata.page_table].to(\n                torch.int32\n            )\n        else:\n            page_table = metadata.page_table\n        if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:\n            metadata.local_attn_metadata = None\n            return\n\n        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()\n        seq_lens_np = cache_seqlens_int32.cpu().numpy()\n        (\n            seqlens_q_local_np,\n            cu_seqlens_q_local_np,\n            seqlens_k_local_np,\n            block_table_local,\n        ) = make_local_attention_virtual_batches(\n            self.attention_chunk_size,\n            cu_seqlens_q_np,\n            seq_lens_np,\n            page_table,\n            self.page_size,\n        )\n\n        local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(\n            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),\n            
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),\n            local_block_table=block_table_local.to(device),\n            local_max_query_len=int(seqlens_q_local_np.max()),\n            local_max_seq_len=int(seqlens_k_local_np.max()),\n        )\n        metadata.local_attn_metadata = local_metadata\n\n    def _init_sliding_window_attn_spec_metadata(\n        self,\n        metadata: FlashAttentionMetadata,\n        metadata_expand: FlashAttentionMetadata,\n        metadata_swa: Optional[FlashAttentionMetadata] = None,\n    ):\n        # TODO: support page_size > 1 for swa spec\n        assert (\n            self.page_size == 1\n        ), \"FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 sliding window attention\"\n\n        cache_seqlens_int32 = (\n            metadata.cache_seqlens_int32.repeat_interleave(\n                self.speculative_num_draft_tokens\n            )\n            + metadata_expand.cache_seqlens_int32\n        )\n        cu_seqlens_k = torch.nn.functional.pad(\n            torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)\n        )\n        bs = cache_seqlens_int32.shape[0]\n        page_table = (\n            metadata.page_table.new_zeros(\n                (bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])\n            )\n            if metadata_swa is None\n            else metadata_swa.page_table\n        )\n\n        prepare_swa_spec_page_table_triton(\n            page_table,\n            metadata.page_table,\n            metadata_expand.page_table,\n            metadata.cache_seqlens_int32,\n            metadata_expand.cache_seqlens_int32,\n            self.speculative_num_draft_tokens,\n        )\n\n        if metadata_swa is None:\n            metadata_swa = FlashAttentionMetadata()\n            metadata_swa.max_seq_len_q = 1\n            metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q\n            metadata_swa.cache_seqlens_int32 = 
cache_seqlens_int32\n            metadata_swa.cu_seqlens_k = cu_seqlens_k\n            metadata_swa.page_table = page_table\n        else:\n            metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)\n            metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)\n\n        metadata.swa_spec_metadata = metadata_swa\n"
  },
  {
    "path": "python/sglang/srt/layers/communicator.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\nimport logging\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass\nfrom enum import Enum, auto\nfrom functools import partial\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\n\nfrom sglang.srt.distributed import (\n    get_tensor_model_parallel_rank,\n    get_tensor_model_parallel_world_size,\n    get_tp_group,\n    tensor_model_parallel_all_reduce,\n)\nfrom sglang.srt.distributed.device_communicators.pynccl_allocator import (\n    use_symmetric_memory,\n)\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.attention.nsa.utils import (\n    is_nsa_enable_prefill_cp,\n    nsa_use_prefill_cp,\n)\nfrom sglang.srt.layers.dp_attention import (\n    attn_tp_all_gather_into_tensor,\n    attn_tp_reduce_scatter_tensor,\n    dp_gather_partial,\n    dp_reduce_scatter_tensor,\n    dp_scatter,\n    get_attention_cp_rank,\n    get_attention_cp_size,\n    get_attention_dp_size,\n    get_attention_tp_rank,\n    get_attention_tp_size,\n    get_global_dp_buffer,\n    get_local_dp_buffer,\n    is_allocation_symmetric,\n    is_dp_attention_enabled,\n)\nfrom sglang.srt.layers.flashinfer_comm_fusion import is_flashinfer_allreduce_unavailable\nfrom sglang.srt.layers.moe import (\n    get_moe_a2a_backend,\n    
should_use_flashinfer_cutlass_moe_fp4_allgather,\n)\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.speculative.spec_info import SpeculativeAlgorithm\nfrom sglang.srt.utils import (\n    get_bool_env_var,\n    is_cuda,\n    is_flashinfer_available,\n    is_gfx95_supported,\n    is_hip,\n    is_npu,\n    is_sm90_supported,\n    is_sm100_supported,\n)\n\n_is_cuda = is_cuda()\n_is_flashinfer_available = is_flashinfer_available()\n_is_sm90_supported = _is_cuda and is_sm90_supported()\n_is_sm100_supported = _is_cuda and is_sm100_supported()\n_use_aiter = get_bool_env_var(\"SGLANG_USE_AITER\") and is_hip()\n_is_gfx95_supported = is_gfx95_supported()\n_is_npu = is_npu()\n_use_ag_after_qlora = envs.SGLANG_USE_AG_AFTER_QLORA.get()\n\nif _use_aiter and _is_gfx95_supported:\n    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant\n\n    from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant\nelif _is_npu:\n    from sglang.srt.hardware_backend.npu.cmo import prepare_weight_cache\n\n\n# TODO: According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465\n# We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).\nFUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048\n\n\ndef apply_flashinfer_allreduce_fusion(batch_size: int):\n    return (\n        # NOTE: flashinfer 0.6.1 caused performance regression on sm100 for allreduce fusion\n        # Ref: https://github.com/sgl-project/sglang/issues/17237\n        (_is_sm90_supported or _is_sm100_supported)\n        and _is_flashinfer_available\n        and batch_size > 0\n        and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE\n        and not is_dp_attention_enabled()\n        and get_global_server_args().enable_flashinfer_allreduce_fusion\n        and not is_flashinfer_allreduce_unavailable()\n    )\n\n\ndef 
apply_aiter_all_reduce_fusion(input_tensor: torch.Tensor):\n    n = input_tensor.shape[-1]\n    total_bytes = input_tensor.numel() * input_tensor.element_size()\n    return (\n        _use_aiter\n        and total_bytes > 0\n        and n <= 16384\n        and total_bytes < 8 * 1024 * 8192\n        and get_tensor_model_parallel_world_size() != 6\n        and not is_dp_attention_enabled()\n        and get_global_server_args().enable_aiter_allreduce_fusion\n    )\n\n\nclass ScatterMode(Enum):\n    \"\"\"\n    Suppose we have TP=4, DP=2, enable-dp-attention, and the system handles seq a,b,c,d\n    Model input/output: [ab, ab, cd, cd] for four ranks respectively\n    SCATTERED: [a, b, c, d]\n    TP_ATTN_FULL: [ab, ab, cd, cd], i.e. all ranks inside a TP attn group have full data of the group\n    FULL: [abcd, abcd, abcd, abcd]\n    \"\"\"\n\n    SCATTERED = auto()\n    TP_ATTN_FULL = auto()\n    FULL = auto()\n\n    @staticmethod\n    def model_input_output():\n        \"\"\"The scatter mode for model forward pass input and output data\"\"\"\n        if is_nsa_enable_prefill_cp():\n            return ScatterMode.SCATTERED\n        return ScatterMode.TP_ATTN_FULL\n\n\nclass AttentionInputs:\n\n    def __init__(\n        self,\n        hidden_states: torch.Tensor,\n        forward_batch: ForwardBatch,\n        qkv_latent_func: Callable,\n    ):\n        self.hidden_states_local = hidden_states\n        self.forward_batch = forward_batch\n        self.qkv_latent_func = qkv_latent_func\n        self.hidden_states_ = None\n        self.qkv_latent_ = None\n\n    def tp_all_gather_hidden_states(self, hidden_states, forward_batch):\n        total_tokens = forward_batch.input_ids.shape[0]\n        output = hidden_states.new_empty((total_tokens, hidden_states.shape[-1]))\n        get_tp_group().all_gather_into_tensor(output, hidden_states)\n        return output\n\n    def fetch_qkv_latent(self):\n        if self.qkv_latent_ is not None:\n            return self.qkv_latent_\n    
    assert self.qkv_latent_func is not None\n        self.qkv_latent_ = self.qkv_latent_func(\n            self.hidden_states_local, self.forward_batch\n        )\n        if get_attn_tp_context().input_scattered:\n            self.qkv_latent_ = self.tp_all_gather_hidden_states(\n                self.qkv_latent_, self.forward_batch\n            )\n        return self.qkv_latent_\n\n    def fetch_hidden_states(self):\n        if self.hidden_states_ is not None:\n            return self.hidden_states_\n        self.hidden_states_ = self.hidden_states_local\n        if get_attn_tp_context().input_scattered:\n            self.hidden_states_ = self.tp_all_gather_hidden_states(\n                self.hidden_states_, self.forward_batch\n            )\n        return self.hidden_states_\n\n\nclass AttnTpContext:\n    def __init__(self):\n        self.allow_input_scattered = False\n        self.input_scattered_ = False\n        self.attn_inputs_: Optional[AttentionInputs] = None\n\n    def init_context(self, q_lora_rank, is_nsa):\n        self.allow_input_scattered = (\n            get_global_server_args().enable_attn_tp_input_scattered\n            and (_is_cuda or _is_npu)\n            and q_lora_rank is not None\n            and not is_nsa\n            and get_tensor_model_parallel_world_size() > 1\n            and not is_dp_attention_enabled()\n            and get_moe_a2a_backend().is_none()\n            and not enable_moe_dense_fully_dp()\n            and get_global_server_args().disable_piecewise_cuda_graph\n            and get_global_server_args().speculative_algorithm != \"EAGLE3\"\n        )\n        if get_global_server_args().enable_attn_tp_input_scattered:\n            if not self.allow_input_scattered:\n                logging.info(\n                    \"attn_tp_input_scattered is not enabled because other required conditions are not met\"\n                )\n            else:\n                logging.info(\"attn_tp_input_scattered is enabled\")\n\n    def 
use_input_scattered(self, forward_batch: ForwardBatch):\n        return (\n            self.allow_input_scattered\n            and forward_batch.forward_mode.is_extend()\n            and not forward_batch.forward_mode.is_target_verify()\n            and not forward_batch.forward_mode.is_draft_extend()\n            and forward_batch.input_ids is not None\n            and not forward_batch.can_run_tbo\n        )\n\n    @property\n    def input_scattered(self):\n        return self.input_scattered_\n\n    def set_attn_inputs(self, attn_inputs: AttentionInputs):\n        self.attn_inputs_ = attn_inputs\n\n    def fetch_qkv_latent(self):\n        assert self.attn_inputs_ is not None\n        return self.attn_inputs_.fetch_qkv_latent()\n\n    def fetch_hidden_states(self):\n        assert self.attn_inputs_ is not None\n        return self.attn_inputs_.fetch_hidden_states()\n\n    @contextmanager\n    def maybe_input_scattered(self, forward_batch: ForwardBatch):\n        flag = self.use_input_scattered(forward_batch)\n        old_flag = self.input_scattered\n        self.input_scattered_ = flag\n        yield\n        self.input_scattered_ = old_flag\n        self.attn_inputs_ = None\n\n\nATTN_TP_CONTEXT = AttnTpContext()\n\n\ndef get_attn_tp_context():\n    return ATTN_TP_CONTEXT\n\n\n@dataclass\nclass _LayerModeComputationContext:\n    num_layers: int\n    layer_id: int\n    is_layer_sparse: bool\n    is_previous_layer_sparse: Optional[bool]\n    is_next_layer_sparse: Optional[bool]\n\n    def previous_layer(self):\n        assert self.is_previous_layer_sparse is not None\n        return _LayerModeComputationContext(\n            num_layers=self.num_layers,\n            layer_id=self.layer_id - 1,\n            is_layer_sparse=self.is_previous_layer_sparse,\n            is_previous_layer_sparse=None,\n            is_next_layer_sparse=self.is_layer_sparse,\n        )\n\n\n@dataclass\nclass LayerScatterModes:\n    layer_input_mode: ScatterMode\n    attn_mode: ScatterMode\n 
   # Can be further split into e.g. mlp_input_mode and mlp_output_mode if needed\n    mlp_mode: ScatterMode\n    middle_residual_mode: ScatterMode\n    layer_output_mode: ScatterMode\n\n    @classmethod\n    def init_new(cls, **kwargs):\n        context = _LayerModeComputationContext(**kwargs)\n        return cls(\n            layer_input_mode=cls._compute_layer_input_mode(context),\n            attn_mode=ScatterMode.TP_ATTN_FULL,\n            mlp_mode=cls._compute_mlp_mode(context),\n            middle_residual_mode=cls._compute_middle_residual_mode(context),\n            layer_output_mode=cls._compute_layer_output_mode(context),\n        )\n\n    @classmethod\n    def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):\n        if context.layer_id == 0:\n            return ScatterMode.model_input_output()\n        return cls._compute_layer_output_mode(context.previous_layer())\n\n    @classmethod\n    def _compute_mlp_mode(cls, context: _LayerModeComputationContext):\n        if context.is_layer_sparse:\n            return (\n                ScatterMode.SCATTERED\n                if (\n                    # Token dispatch/combine will be handled outside of LayerCommunicator for these modes.\n                    not get_moe_a2a_backend().is_none()\n                    or should_use_flashinfer_cutlass_moe_fp4_allgather()\n                )\n                else ScatterMode.FULL\n            )\n        else:\n            return (\n                ScatterMode.SCATTERED\n                if enable_moe_dense_fully_dp()\n                else ScatterMode.FULL\n            )\n\n    @classmethod\n    def _should_gather_for_tbo(cls, context: _LayerModeComputationContext):\n        return (\n            not context.is_layer_sparse\n            and context.is_next_layer_sparse\n            and enable_moe_dense_fully_dp()\n            and get_global_server_args().enable_two_batch_overlap\n        )\n\n    @classmethod\n    def 
_compute_middle_residual_mode(cls, context: _LayerModeComputationContext):\n        mlp_mode = cls._compute_mlp_mode(context)\n        if mlp_mode == ScatterMode.SCATTERED:\n            return ScatterMode.SCATTERED\n        if mlp_mode == ScatterMode.FULL:\n            return ScatterMode.TP_ATTN_FULL\n        raise NotImplementedError\n\n    @classmethod\n    def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):\n        mlp_mode = cls._compute_mlp_mode(context)\n        if context.layer_id == context.num_layers - 1:\n            return ScatterMode.model_input_output()\n        if mlp_mode == ScatterMode.SCATTERED:\n            if cls._should_gather_for_tbo(context):\n                return ScatterMode.TP_ATTN_FULL\n            return ScatterMode.SCATTERED\n        if mlp_mode == ScatterMode.FULL:\n            return ScatterMode.TP_ATTN_FULL\n        raise NotImplementedError\n\n\ndef enable_moe_dense_fully_dp():\n    return get_global_server_args().moe_dense_tp_size == 1\n\n\nclass LayerCommunicator:\n    def __init__(\n        self,\n        layer_scatter_modes: LayerScatterModes,\n        input_layernorm: torch.nn.Module,\n        post_attention_layernorm: torch.nn.Module,\n        # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. 
Remove flag once done for all models that use LayerCommunicator.\n        allow_reduce_scatter: bool = False,\n        is_last_layer: bool = False,\n        qkv_latent_func: Optional[Callable] = None,\n    ):\n        self.layer_scatter_modes = layer_scatter_modes\n        self.input_layernorm = input_layernorm\n        self.post_attention_layernorm = post_attention_layernorm\n        self.allow_reduce_scatter = allow_reduce_scatter\n        self.is_last_layer = is_last_layer\n        self.qkv_latent_func = qkv_latent_func\n\n        self._context = CommunicateContext.init_new()\n        self._post_init_communicate()\n        self._speculative_algo = SpeculativeAlgorithm.from_string(\n            get_global_server_args().speculative_algorithm\n        )\n\n    def _post_init_communicate(self):\n        self._communicate_simple_fn = CommunicateSimpleFn.get_fn(\n            input_mode=self.layer_scatter_modes.layer_input_mode,\n            output_mode=self.layer_scatter_modes.attn_mode,\n            context=self._context,\n        )\n        self._communicate_with_all_reduce_and_layer_norm_fn = (\n            CommunicateWithAllReduceAndLayerNormFn.get_fn(\n                hidden_states_input_mode=self.layer_scatter_modes.attn_mode,\n                residual_input_mode=self.layer_scatter_modes.layer_input_mode,\n                hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,\n                residual_output_mode=self.layer_scatter_modes.middle_residual_mode,\n                context=self._context,\n            )\n        )\n        self._communicate_summable_tensor_pair_fn = (\n            CommunicateSummableTensorPairFn.get_fn(\n                hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,\n                residual_input_mode=self.layer_scatter_modes.middle_residual_mode,\n                output_mode=self.layer_scatter_modes.layer_output_mode,\n                context=self._context,\n            )\n        )\n\n    def 
prepare_attn_and_capture_last_layer_outputs(\n        self,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        forward_batch: ForwardBatch,\n        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ):\n        hidden_states, residual = self.prepare_attn(\n            hidden_states,\n            residual,\n            forward_batch,\n            post_residual_addition=post_residual_addition,\n        )\n        if captured_last_layer_outputs is not None:\n            gathered_last_layer_output = self._communicate_simple_fn(\n                hidden_states=residual,\n                forward_batch=forward_batch,\n                context=self._context,\n            )\n            if gathered_last_layer_output is residual:\n                # Clone to avoid modifying the original residual by Custom RMSNorm inplace operation\n                gathered_last_layer_output = residual.clone()\n            captured_last_layer_outputs.append(gathered_last_layer_output)\n        return hidden_states, residual\n\n    def prepare_attn(\n        self,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        forward_batch: ForwardBatch,\n        quant_format: str = \"\",\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ):\n        if get_attn_tp_context().input_scattered:\n            hidden_states, residual = self._tp_reduce_scatter(\n                hidden_states,\n                residual,\n            )\n        if hidden_states.shape[0] == 0:\n            residual = hidden_states\n        else:\n            if (\n                residual is not None\n                and hasattr(hidden_states, \"_sglang_needs_allreduce_fusion\")\n                and hidden_states._sglang_needs_allreduce_fusion\n            ):\n                if (\n                    apply_aiter_all_reduce_fusion(hidden_states)\n                    or 
apply_flashinfer_allreduce_fusion(hidden_states.shape[0])\n                ) and hasattr(self.input_layernorm, \"forward_with_allreduce_fusion\"):\n                    hidden_states, residual = (\n                        self.input_layernorm.forward_with_allreduce_fusion(\n                            hidden_states, residual\n                        )\n                    )\n                else:\n                    hidden_states = tensor_model_parallel_all_reduce(hidden_states)\n                    hidden_states, residual = self.input_layernorm(\n                        hidden_states, residual\n                    )\n            else:\n                if residual is None:\n                    residual = hidden_states\n\n                    if _use_aiter and _is_gfx95_supported and (\"mxfp4\" in quant_format):\n                        hidden_states, *_, _ = fused_rms_mxfp4_quant(\n                            hidden_states,\n                            self.input_layernorm.weight,\n                            self.input_layernorm.variance_epsilon,\n                            None,\n                            None,\n                            None,\n                            None,\n                        )\n                    elif _use_aiter and _is_gfx95_supported and (\"fp8\" in quant_format):\n\n                        hidden_states, _, _, _res = fused_rms_fp8_group_quant(\n                            hidden_states,\n                            self.input_layernorm.weight,\n                            self.input_layernorm.variance_epsilon,\n                            inp2=None,\n                            inp2_weight=None,\n                            inp2_epsilon=None,\n                            group_size=128,\n                            dtype_quant=torch.float8_e4m3fn,\n                            res1=None,\n                            output_unquantized_inp1=False,\n                        )\n\n                    else:\n                        
hidden_states = self.input_layernorm(hidden_states)\n                else:\n\n                    if _use_aiter and _is_gfx95_supported and (\"mxfp4\" in quant_format):\n                        hidden_states, *_, residual = fused_rms_mxfp4_quant(\n                            hidden_states,\n                            self.input_layernorm.weight,\n                            self.input_layernorm.variance_epsilon,\n                            None,\n                            None,\n                            None,\n                            residual,\n                        )\n                    elif _use_aiter and _is_gfx95_supported and (\"fp8\" in quant_format):\n                        # RMSNorm + FP8 per-group quant\n                        # returned hidden_states:\n                        #   out_fp8  : FP8 activation → a8w8 GEMM\n                        #   out_bs   : block-scale → gemm_a8w8_blockscale.x_scale\n                        hidden_states, _, _, residual = fused_rms_fp8_group_quant(\n                            hidden_states,\n                            self.input_layernorm.weight,\n                            self.input_layernorm.variance_epsilon,\n                            inp2=None,\n                            inp2_weight=None,\n                            inp2_epsilon=None,\n                            group_size=128,\n                            dtype_quant=torch.float8_e4m3fn,\n                            res1=residual,\n                            output_unquantized_inp1=False,\n                        )\n                    else:\n                        hidden_states, residual = self.input_layernorm(\n                            hidden_states,\n                            residual,\n                            post_residual_addition,\n                        )\n\n        hidden_states = self._communicate_simple_fn(\n            hidden_states=hidden_states,\n            forward_batch=forward_batch,\n            
context=self._context,\n        )\n        if self.qkv_latent_func is not None:\n            attn_inputs = AttentionInputs(\n                hidden_states, forward_batch, self.qkv_latent_func\n            )\n            get_attn_tp_context().set_attn_inputs(attn_inputs)\n        return hidden_states, residual\n\n    def _tp_reduce_scatter(\n        self,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        if hidden_states.shape[0] == 0:\n            return hidden_states, hidden_states\n        assert (\n            hidden_states.shape[0] % self._context.tp_size == 0\n        ), f\"Expected total tokens {hidden_states.shape[0]} % tp_size {self._context.tp_size} to be 0\"\n        local_tokens = hidden_states.shape[0] // self._context.tp_size\n        output = hidden_states.new_empty(local_tokens, *hidden_states.shape[1:])\n        get_tp_group().reduce_scatter_tensor(output, hidden_states)\n        if residual is not None:\n            residual = residual.tensor_split(self._context.tp_size)[\n                self._context.tp_rank\n            ]\n        return output, residual\n\n    def prepare_mlp(\n        self,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        forward_batch: ForwardBatch,\n        cache=None,\n    ):\n        if cache is not None:\n            self._context.cache = cache\n\n        return self._communicate_with_all_reduce_and_layer_norm_fn(\n            hidden_states=hidden_states,\n            residual=residual,\n            forward_batch=forward_batch,\n            layernorm=self.post_attention_layernorm,\n            context=self._context,\n        )\n\n    def postprocess_layer(\n        self,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        forward_batch: ForwardBatch,\n    ):\n        return self._communicate_summable_tensor_pair_fn(\n            hidden_states=hidden_states,\n            residual=residual,\n 
           forward_batch=forward_batch,\n            context=self._context,\n            allow_reduce_scatter=self.allow_reduce_scatter,\n        )\n\n    def should_use_reduce_scatter(self, forward_batch: ForwardBatch):\n        if not self.allow_reduce_scatter:\n            return False\n        if (\n            self._communicate_summable_tensor_pair_fn\n            is CommunicateSummableTensorPairFn._scatter_hidden_states\n            and forward_batch.dp_padding_mode.is_max_len()\n        ):\n            return True\n        if nsa_use_prefill_cp(forward_batch):\n            return True\n        if get_attn_tp_context().input_scattered and not self.is_last_layer:\n            return True\n        return False\n\n    # NOTE: This function will cause torch recompilation\n    def should_fuse_mlp_allreduce_with_next_layer(\n        self, forward_batch: ForwardBatch\n    ) -> bool:\n        if (\n            is_dp_attention_enabled()\n            and self._speculative_algo is not None\n            and self._speculative_algo.is_eagle()\n        ):\n            return False\n\n        if get_attn_tp_context().input_scattered:\n            return False\n\n        batch_size = (\n            forward_batch.input_ids.shape[0]\n            if hasattr(forward_batch, \"input_ids\")\n            else 0\n        )\n\n        return (\n            (\n                apply_flashinfer_allreduce_fusion(batch_size)\n                or (\n                    _use_aiter\n                    and batch_size > 0\n                    and get_tensor_model_parallel_world_size() != 6\n                    and get_global_server_args().enable_aiter_allreduce_fusion\n                )\n            )\n            and (not self.is_last_layer)\n            and (self._context.tp_size > 1)\n        )\n\n\n@dataclass\nclass CommunicateContext:\n    process_group_sizes: Dict[ScatterMode, int]\n    attn_tp_rank: int\n    attn_tp_size: int\n    attn_dp_size: int\n    attn_cp_rank: int\n    
attn_cp_size: int\n    tp_size: int\n    cache = None\n    tp_rank: int\n\n    def is_same_group_size(self, a: ScatterMode, b: ScatterMode):\n        return self.process_group_sizes[a] == self.process_group_sizes[b]\n\n    @classmethod\n    def init_new(cls):\n        attn_tp_rank = get_attention_tp_rank()\n        attn_tp_size = get_attention_tp_size()\n        attn_dp_size = get_attention_dp_size()\n        attn_cp_size = get_attention_cp_size()\n        attn_cp_rank = get_attention_cp_rank()\n        tp_size = get_tensor_model_parallel_world_size()\n        tp_rank = get_tensor_model_parallel_rank()\n        process_group_sizes = {\n            ScatterMode.SCATTERED: 1,\n            ScatterMode.TP_ATTN_FULL: attn_tp_size,\n            # TODO: support --moe-dense-tp-size > 1\n            ScatterMode.FULL: tp_size,\n        }\n        return cls(\n            process_group_sizes=process_group_sizes,\n            attn_tp_rank=attn_tp_rank,\n            attn_tp_size=attn_tp_size,\n            attn_dp_size=attn_dp_size,\n            attn_cp_rank=attn_cp_rank,\n            attn_cp_size=attn_cp_size,\n            tp_size=tp_size,\n            tp_rank=tp_rank,\n        )\n\n\nclass CommunicateSimpleFn:\n    @staticmethod\n    def get_fn(\n        input_mode: ScatterMode,\n        output_mode: ScatterMode,\n        context: CommunicateContext,\n    ):\n        if context.is_same_group_size(input_mode, output_mode):\n            return CommunicateSimpleFn._trivial\n\n        if (input_mode == ScatterMode.SCATTERED) and (\n            output_mode == ScatterMode.TP_ATTN_FULL\n        ):\n            if _use_ag_after_qlora:\n                return CommunicateSimpleFn._trivial\n            return CommunicateSimpleFn._scattered_to_tp_attn_full\n\n        raise NotImplementedError(f\"{input_mode=} {output_mode=}\")\n\n    @staticmethod\n    def _trivial(\n        hidden_states: torch.Tensor,\n        forward_batch: ForwardBatch,\n        context: CommunicateContext,\n    ) -> 
torch.Tensor:\n        return hidden_states\n\n    @staticmethod\n    def _scattered_to_tp_attn_full(\n        hidden_states: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],\n        forward_batch: ForwardBatch,\n        context: CommunicateContext,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        if isinstance(hidden_states, tuple):\n            gathered_hidden_states = []\n            for local_hidden_states in hidden_states:\n                with use_symmetric_memory(\n                    get_tp_group(),\n                    disabled=not is_allocation_symmetric(),\n                ):\n                    output = torch.empty(\n                        (\n                            local_hidden_states.shape[0] * context.attn_tp_size,\n                            *local_hidden_states.shape[1:],\n                        ),\n                        dtype=local_hidden_states.dtype,\n                        device=local_hidden_states.device,\n                    )\n                attn_tp_all_gather_into_tensor(\n                    output,\n                    local_hidden_states,\n                )\n                gathered_hidden_states.append(output)\n            return tuple(gathered_hidden_states)\n\n        hidden_states, local_hidden_states = (\n            get_local_dp_buffer(),\n            hidden_states,\n        )\n        attn_tp_all_gather_into_tensor(\n            hidden_states,\n            local_hidden_states,\n        )\n        return hidden_states\n\n\nclass CommunicateWithAllReduceAndLayerNormFn:\n    \"\"\"Besides communication, needs to\n    1. All reduce in tp_attn_group on hidden_states\n    2. 
Apply layer norm\n    \"\"\"\n\n    @staticmethod\n    def get_fn(\n        hidden_states_input_mode: ScatterMode,\n        residual_input_mode: ScatterMode,\n        hidden_states_output_mode: ScatterMode,\n        residual_output_mode: ScatterMode,\n        context: CommunicateContext,\n    ):\n\n        if (\n            context.is_same_group_size(\n                hidden_states_input_mode, hidden_states_output_mode\n            )\n            and context.is_same_group_size(residual_input_mode, residual_output_mode)\n            and context.attn_tp_size == 1\n        ):\n            return CommunicateWithAllReduceAndLayerNormFn._simple\n\n        if (\n            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)\n            and (\n                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]\n            )\n            and (hidden_states_output_mode == ScatterMode.FULL)\n            and (residual_output_mode == ScatterMode.TP_ATTN_FULL)\n        ):\n            return partial(\n                CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual,\n                residual_input_mode=residual_input_mode,\n            )\n\n        if (\n            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)\n            and (\n                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]\n            )\n            and (hidden_states_output_mode == ScatterMode.SCATTERED)\n            and (residual_output_mode == ScatterMode.SCATTERED)\n        ):\n            return partial(\n                CommunicateWithAllReduceAndLayerNormFn._scatter_hidden_states_and_residual,\n                residual_input_mode=residual_input_mode,\n            )\n\n        raise NotImplementedError(\n            f\"{hidden_states_input_mode=} {residual_input_mode=} {hidden_states_output_mode=} {residual_output_mode=}\"\n        )\n\n    @staticmethod\n    def _simple(\n        hidden_states: torch.Tensor,\n       
 residual: torch.Tensor,\n        forward_batch: ForwardBatch,\n        layernorm: torch.nn.Module,\n        context: CommunicateContext,\n    ):\n        # TODO move these `if shape != 0` into LayerNorm itself\n        if hidden_states.shape[0] != 0:\n            hidden_states, residual = layernorm(hidden_states, residual)\n        return hidden_states, residual\n\n    @staticmethod\n    def _gather_hidden_states_and_residual(\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        forward_batch: ForwardBatch,\n        layernorm: torch.nn.Module,\n        context: CommunicateContext,\n        *,\n        residual_input_mode,\n    ):\n        if get_attn_tp_context().input_scattered:\n            return CommunicateWithAllReduceAndLayerNormFn._tp_all_reduce_with_scattered_residual(\n                hidden_states,\n                residual,\n                layernorm,\n                context,\n            )\n\n        if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:\n            residual, local_residual = (\n                get_local_dp_buffer(),\n                residual,\n            )\n            attn_tp_all_gather_into_tensor(residual, local_residual)\n        if context.attn_dp_size != 1:\n            # Perform layernorm on smaller data before comm. 
Only valid when attn_tp_size is 1 (tp_size == dp_size)\n            use_layer_norm_before_gather = context.attn_tp_size == 1\n            if use_layer_norm_before_gather and hidden_states.shape[0] != 0:\n                with use_symmetric_memory(\n                    get_tp_group(),\n                    disabled=not is_allocation_symmetric(),\n                ):\n                    hidden_states, residual = layernorm(hidden_states, residual)\n            elif context.attn_tp_rank == 0:\n                hidden_states += residual\n\n            hidden_states, local_hidden_states = (\n                get_global_dp_buffer(),\n                hidden_states,\n            )\n            dp_gather_partial(hidden_states, local_hidden_states, forward_batch)\n\n            if not use_layer_norm_before_gather:\n                dp_scatter(residual, hidden_states, forward_batch)\n                if hidden_states.shape[0] != 0:\n                    hidden_states = layernorm(hidden_states)\n        else:\n            handled = False\n            if (\n                apply_aiter_all_reduce_fusion(hidden_states)\n                or apply_flashinfer_allreduce_fusion(hidden_states.shape[0])\n            ) and hasattr(layernorm, \"forward_with_allreduce_fusion\"):\n                hidden_states, residual = layernorm.forward_with_allreduce_fusion(\n                    hidden_states, residual\n                )\n                handled = True\n\n            if not handled:\n                hidden_states = tensor_model_parallel_all_reduce(hidden_states)\n                if _is_npu and context.cache is not None:\n                    _ = prepare_weight_cache(hidden_states, context.cache)\n                hidden_states, residual = layernorm(hidden_states, residual)\n        return hidden_states, residual\n\n    @staticmethod\n    def _scatter_hidden_states_and_residual(\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        forward_batch: ForwardBatch,\n        
layernorm: torch.nn.Module,\n        context: CommunicateContext,\n        *,\n        residual_input_mode,\n    ):\n        input_hidden_states = hidden_states\n        hidden_states = hidden_states.tensor_split(context.attn_tp_size)[\n            context.attn_tp_rank\n        ]\n        attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)\n        if residual_input_mode == ScatterMode.TP_ATTN_FULL:\n            residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]\n        if hidden_states.shape[0] != 0:\n            hidden_states, residual = layernorm(hidden_states, residual)\n        return hidden_states, residual\n\n    @staticmethod\n    def _tp_all_reduce_with_scattered_residual(\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        layernorm: torch.nn.Module,\n        context: CommunicateContext,\n    ):\n        if hidden_states.shape[0] == 0:\n            return hidden_states, hidden_states\n\n        scattered_states = hidden_states.tensor_split(context.tp_size)[context.tp_rank]\n        scattered_states += residual\n        residual = tensor_model_parallel_all_reduce(hidden_states)\n        hidden_states = layernorm(residual)\n        return hidden_states, residual\n\n\nclass CommunicateSummableTensorPairFn:\n    \"\"\"It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed.\"\"\"\n\n    @classmethod\n    def execute(\n        cls,\n        hidden_states_input_mode,\n        residual_input_mode,\n        output_mode,\n        context,\n        **kwargs,\n    ):\n        return cls.get_fn(\n            hidden_states_input_mode=hidden_states_input_mode,\n            residual_input_mode=residual_input_mode,\n            output_mode=output_mode,\n            context=context,\n        )(context=context, **kwargs)\n\n    @staticmethod\n    def get_fn(\n        hidden_states_input_mode: ScatterMode,\n        residual_input_mode: ScatterMode,\n        
output_mode: ScatterMode,\n        context: CommunicateContext,\n    ):\n        if context.is_same_group_size(\n            hidden_states_input_mode, output_mode\n        ) and context.is_same_group_size(residual_input_mode, output_mode):\n            return CommunicateSummableTensorPairFn._trivial\n\n        if (\n            (hidden_states_input_mode == ScatterMode.FULL)\n            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)\n            and (output_mode == ScatterMode.TP_ATTN_FULL)\n        ):\n            return CommunicateSummableTensorPairFn._scatter_hidden_states\n\n        if (\n            (hidden_states_input_mode == ScatterMode.SCATTERED)\n            and (residual_input_mode == ScatterMode.SCATTERED)\n            and (output_mode == ScatterMode.TP_ATTN_FULL)\n        ):\n            return CommunicateSummableTensorPairFn._gather\n\n        if (\n            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)\n            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)\n            and (output_mode == ScatterMode.SCATTERED)\n        ):\n            return CommunicateSummableTensorPairFn._scatter\n\n        raise NotImplementedError(\n            f\"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}\"\n        )\n\n    @staticmethod\n    def _trivial(\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        forward_batch: ForwardBatch,\n        context: CommunicateContext,\n        **kwargs,\n    ):\n        return hidden_states, residual\n\n    @staticmethod\n    def _scatter_hidden_states(\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        forward_batch: ForwardBatch,\n        context: CommunicateContext,\n        allow_reduce_scatter: bool = False,\n    ):\n        hidden_states, global_hidden_states = (\n            get_local_dp_buffer(),\n            hidden_states,\n        )\n        if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():\n       
     # When using padding, the all-reduce after the MLP/MoE is skipped and a reduce-scatter is used here instead.\n            dp_reduce_scatter_tensor(hidden_states, global_hidden_states)\n        else:\n            dp_scatter(hidden_states, global_hidden_states, forward_batch)\n        return hidden_states, residual\n\n    @staticmethod\n    def _gather(\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        forward_batch: ForwardBatch,\n        context: CommunicateContext,\n        **kwargs,\n    ):\n        hidden_states += residual\n        residual = None\n        hidden_states, local_hidden_states = (\n            get_local_dp_buffer(),\n            hidden_states,\n        )\n        attn_tp_all_gather_into_tensor(\n            hidden_states,\n            local_hidden_states,\n        )\n        return hidden_states, residual\n\n    @staticmethod\n    def _scatter(\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        forward_batch: ForwardBatch,\n        context: CommunicateContext,\n        **kwargs,\n    ):\n        assert residual is None, \"not yet handled residual!=None\"\n        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))\n        hidden_states = tensor_list[context.attn_tp_rank]\n        return hidden_states, residual\n"
  },
  {
    "path": "python/sglang/srt/layers/communicator_nsa_cp.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\n\nfrom functools import partial\nfrom typing import Callable, Optional\n\nimport torch\n\nfrom sglang.srt.layers.attention.nsa.utils import (\n    is_nsa_enable_prefill_cp,\n    nsa_use_prefill_cp,\n)\nfrom sglang.srt.layers.communicator import (\n    CommunicateContext,\n    CommunicateSimpleFn,\n    CommunicateSummableTensorPairFn,\n    CommunicateWithAllReduceAndLayerNormFn,\n    LayerCommunicator,\n    LayerScatterModes,\n    ScatterMode,\n)\nfrom sglang.srt.layers.dp_attention import (\n    attn_cp_all_gather_into_tensor,\n    attn_cp_reduce_scatter_tensor,\n    get_local_dp_buffer,\n)\nfrom sglang.srt.model_executor.forward_batch_info import ForwardBatch\n\n\ndef nsa_enable_prefill_cp():\n    # After using cp, the communication mode of this part changes.\n    # The three parts of prepare_attn, prepare_mlp, and postprocess_layer\n    # no longer require additional communication for reduce, scatter, etc.\n    return is_nsa_enable_prefill_cp()\n\n\nclass NSACPLayerCommunicator(LayerCommunicator):\n    def __init__(\n        self,\n        layer_scatter_modes: LayerScatterModes,\n        input_layernorm: torch.nn.Module,\n        post_attention_layernorm: torch.nn.Module,\n        # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models 
which have that implemented. Remove flag once done for all models that use LayerCommunicator.\n        allow_reduce_scatter: bool = False,\n        is_last_layer: bool = False,\n        qkv_latent_func: Optional[Callable] = None,\n    ):\n        super().__init__(\n            layer_scatter_modes,\n            input_layernorm,\n            post_attention_layernorm,\n            allow_reduce_scatter,\n            is_last_layer,\n            qkv_latent_func,\n        )\n\n    def _post_init_communicate(self):\n        # SCATTERED in attn tp is different from SCATTERED in global tp when dp_size > 1\n        if self.layer_scatter_modes.mlp_mode != ScatterMode.SCATTERED:\n            assert (\n                self._context.attn_dp_size == 1\n            ), f\"dp_size should be 1 when moe_runner_backend is none\"\n        self._communicate_simple_fn = NSACPCommunicateSimpleFn.get_fn(\n            input_mode=ScatterMode.SCATTERED,\n            output_mode=ScatterMode.SCATTERED,\n            context=self._context,\n        )\n        self._communicate_with_all_reduce_and_layer_norm_fn = NSACPCommunicateWithAllReduceAndLayerNormFn.get_fn(\n            hidden_states_input_mode=ScatterMode.SCATTERED,\n            residual_input_mode=ScatterMode.SCATTERED,\n            hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,  # SCATTERED, FULL\n            residual_output_mode=ScatterMode.SCATTERED,\n            context=self._context,\n        )\n        self._communicate_summable_tensor_pair_fn = NSACPCommunicateSummableTensorPairFn.get_fn(\n            hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,  # SCATTERED, FULL\n            residual_input_mode=ScatterMode.SCATTERED,\n            output_mode=ScatterMode.SCATTERED,\n            context=self._context,\n        )\n\n\nclass NSACPCommunicateSimpleFn(CommunicateSimpleFn):\n    @staticmethod\n    def get_fn(\n        input_mode: ScatterMode,\n        output_mode: ScatterMode,\n        context: 
CommunicateContext,\n    ):\n        if context.is_same_group_size(input_mode, output_mode):\n            return NSACPCommunicateSimpleFn._trivial\n\n        raise NotImplementedError(f\"{input_mode=} {output_mode=}\")\n\n\nclass NSACPCommunicateWithAllReduceAndLayerNormFn(\n    CommunicateWithAllReduceAndLayerNormFn\n):\n    \"\"\"Besides communication, needs to\n    1. All reduce in tp_attn_group on hidden_states\n    2. Apply layer norm\n    \"\"\"\n\n    @staticmethod\n    def get_fn(\n        hidden_states_input_mode: ScatterMode,\n        residual_input_mode: ScatterMode,\n        hidden_states_output_mode: ScatterMode,\n        residual_output_mode: ScatterMode,\n        context: CommunicateContext,\n    ):\n        assert hidden_states_input_mode == ScatterMode.SCATTERED\n        assert residual_input_mode == ScatterMode.SCATTERED\n        assert residual_output_mode == ScatterMode.SCATTERED\n        if hidden_states_output_mode == ScatterMode.SCATTERED:\n            return NSACPCommunicateWithAllReduceAndLayerNormFn._simple\n\n        if hidden_states_output_mode == ScatterMode.FULL:\n            return partial(\n                NSACPCommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual,\n                residual_input_mode=residual_input_mode,\n            )\n\n        raise NotImplementedError(\n            f\"{hidden_states_input_mode=} {residual_input_mode=} {hidden_states_output_mode=} {residual_output_mode=}\"\n        )\n\n    @staticmethod\n    def _gather_hidden_states_and_residual(\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        forward_batch: ForwardBatch,\n        layernorm: torch.nn.Module,\n        context: CommunicateContext,\n        *,\n        residual_input_mode,\n    ):\n        if hidden_states.shape[0] != 0:\n            hidden_states, residual = layernorm(hidden_states, residual)\n        # for prefill: attn tp scattered -> full\n        # for decode: attn tp full -> full\n        
if nsa_use_prefill_cp(forward_batch):\n            assert context.attn_dp_size == 1\n            hidden_states, local_hidden_states = (\n                get_local_dp_buffer(),\n                hidden_states,\n            )\n            attn_cp_all_gather_into_tensor(\n                hidden_states,\n                local_hidden_states,\n            )\n        return hidden_states, residual\n\n\nclass NSACPCommunicateSummableTensorPairFn(CommunicateSummableTensorPairFn):\n    \"\"\"It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed.\"\"\"\n\n    @staticmethod\n    def get_fn(\n        hidden_states_input_mode: ScatterMode,\n        residual_input_mode: ScatterMode,\n        output_mode: ScatterMode,\n        context: CommunicateContext,\n    ):\n        if context.is_same_group_size(\n            hidden_states_input_mode, output_mode\n        ) and context.is_same_group_size(residual_input_mode, output_mode):\n            return NSACPCommunicateSummableTensorPairFn._trivial\n\n        if (\n            (hidden_states_input_mode == ScatterMode.FULL)\n            and (residual_input_mode == ScatterMode.SCATTERED)\n            and (output_mode == ScatterMode.SCATTERED)\n        ):\n            return NSACPCommunicateSummableTensorPairFn._scatter_hidden_states\n\n        raise NotImplementedError(\n            f\"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}\"\n        )\n\n    @staticmethod\n    def _scatter_hidden_states(\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        forward_batch: ForwardBatch,\n        context: CommunicateContext,\n        allow_reduce_scatter: bool = False,\n    ):\n        # for prefill: full -> attn tp scattered\n        # for decode: full -> attn tp full\n        if nsa_use_prefill_cp(forward_batch):\n            assert context.attn_dp_size == 1\n            input_hidden_states = hidden_states\n            hidden_states = 
hidden_states.tensor_split(context.attn_cp_size)[\n                context.attn_cp_rank\n            ]\n            attn_cp_reduce_scatter_tensor(hidden_states, input_hidden_states)\n        return hidden_states, residual\n"
  },
  {
    "path": "python/sglang/srt/layers/conv.py",
    "content": "\"\"\"\nConv2d/Conv3d layers with unfold+linear optimization for patch embeddings.\n\nWhen kernel_size == stride, padding == 0, dilation == 1, groups == 1, the conv\nis equivalent to unfold + F.linear, which is significantly faster on CUDA and\nalso avoids the PyTorch 2.9.1 + CuDNN < 9.15 Conv3d bug\n(https://github.com/pytorch/pytorch/issues/168167).\n\"\"\"\n\nimport math\nfrom typing import Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sglang.srt.layers.utils.multi_platform import MultiPlatformOp\n\n_VALID_PADDING_STRINGS = {\"same\", \"valid\"}\n_VALID_PADDING_MODES = {\"zeros\", \"reflect\", \"replicate\", \"circular\"}\n\n\ndef _tuplify(val, n: int) -> tuple:\n    if isinstance(val, (list, tuple)):\n        assert len(val) == n\n        return tuple(val)\n    return (val,) * n\n\n\ndef _check_enable_linear(\n    kernel_size: tuple,\n    stride: tuple,\n    padding: tuple,\n    dilation: tuple,\n    groups: int,\n) -> bool:\n    \"\"\"Check if conv can be replaced with unfold + F.linear.\"\"\"\n    return (\n        kernel_size == stride\n        and all(p == 0 for p in padding)\n        and all(d == 1 for d in dilation)\n        and groups == 1\n    )\n\n\ndef _reverse_repeat_tuple(t: tuple) -> tuple:\n    \"\"\"(1, 2, 3) -> (3, 3, 2, 2, 1, 1). 
Used for F.pad with non-zeros padding_mode.\"\"\"\n    return tuple(x for x in reversed(t) for _ in range(2))\n\n\ndef _compute_same_padding_for_pad(kernel_size: tuple, dilation: tuple) -> tuple:\n    \"\"\"Compute _reversed_padding_repeated_twice for padding='same'.\n\n    This mirrors PyTorch's nn.Conv*d behavior: pre-compute the exact pad\n    amounts so that F.pad can be called before F.conv*d(padding=0).\n    \"\"\"\n    pad = []\n    for k, d in zip(reversed(kernel_size), reversed(dilation)):\n        total = d * (k - 1)\n        pad.append(total // 2)\n        pad.append(total - total // 2)\n    return tuple(pad)\n\n\ndef _validate_conv_args(\n    in_channels: int,\n    out_channels: int,\n    groups: int,\n    padding,\n    padding_mode: str,\n    stride: tuple,\n) -> None:\n    if in_channels % groups != 0:\n        raise ValueError(\n            f\"in_channels ({in_channels}) must be divisible by groups ({groups})\"\n        )\n    if out_channels % groups != 0:\n        raise ValueError(\n            f\"out_channels ({out_channels}) must be divisible by groups ({groups})\"\n        )\n    if padding_mode not in _VALID_PADDING_MODES:\n        raise ValueError(\n            f\"padding_mode must be one of {_VALID_PADDING_MODES}, got '{padding_mode}'\"\n        )\n    if isinstance(padding, str):\n        if padding not in _VALID_PADDING_STRINGS:\n            raise ValueError(\n                f\"padding must be one of {_VALID_PADDING_STRINGS}, got '{padding}'\"\n            )\n        if padding == \"same\" and any(s != 1 for s in stride):\n            raise ValueError(\"padding='same' is not supported for strided convolutions\")\n\n\nclass Conv2dLayer(MultiPlatformOp):\n    \"\"\"Drop-in replacement for nn.Conv2d. 
Linear optimization disabled by default.\"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, Tuple[int, int]],\n        stride: Union[int, Tuple[int, int]] = 1,\n        padding: Union[int, Tuple[int, int], str] = 0,\n        dilation: Union[int, Tuple[int, int]] = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = \"zeros\",\n        disable_linear: bool = True,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.kernel_size = _tuplify(kernel_size, 2)\n        self.stride = _tuplify(stride, 2)\n        self.dilation = _tuplify(dilation, 2)\n        self.groups = groups\n        self.padding_mode = padding_mode\n\n        _validate_conv_args(\n            in_channels, out_channels, groups, padding, padding_mode, self.stride\n        )\n\n        if isinstance(padding, str):\n            self.padding = (0, 0) if padding == \"valid\" else padding\n        else:\n            self.padding = _tuplify(padding, 2)\n\n        # Pre-compute pad tuple for padding_mode != \"zeros\" (mirrors nn.Conv2d).\n        # When padding=\"same\", we need numeric values for F.pad;\n        # when padding is already numeric, _reverse_repeat_tuple handles it.\n        if isinstance(self.padding, str):\n            self._reversed_padding_repeated_twice = _compute_same_padding_for_pad(\n                self.kernel_size, self.dilation\n            )\n        else:\n            self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding)\n\n        padding_tuple = self.padding if isinstance(self.padding, tuple) else (1, 1)\n        self.enable_linear = not disable_linear and _check_enable_linear(\n            self.kernel_size, self.stride, padding_tuple, self.dilation, groups\n        )\n\n        self.weight = nn.Parameter(\n            torch.empty(out_channels, in_channels // groups, 
*self.kernel_size)\n        )\n        if bias:\n            self.bias = nn.Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter(\"bias\", None)\n\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n        if self.bias is not None:\n            fan_in = nn.init._calculate_correct_fan(self.weight, \"fan_in\")\n            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0\n            nn.init.uniform_(self.bias, -bound, bound)\n\n    def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:\n        K1, K2 = self.kernel_size\n        x = x.unfold(2, K1, K1).unfold(3, K2, K2)\n        N, _, Hp, Wp = x.shape[:4]\n        x = x.permute(0, 2, 3, 1, 4, 5).reshape(N, Hp, Wp, -1)\n        x = F.linear(x, self.weight.reshape(self.out_channels, -1), self.bias)\n        return x.permute(0, 3, 1, 2)\n\n    def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:\n        if self.padding_mode != \"zeros\":\n            return F.conv2d(\n                F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode),\n                self.weight,\n                self.bias,\n                self.stride,\n                (0, 0),\n                self.dilation,\n                self.groups,\n            )\n        return F.conv2d(\n            x,\n            self.weight,\n            self.bias,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.groups,\n        )\n\n    def forward_native(self, x: torch.Tensor) -> torch.Tensor:\n        if self.enable_linear:\n            return self._forward_mulmat(x)\n        return self._forward_conv(x)\n\n    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:\n        if self.enable_linear:\n            return self._forward_mulmat(x)\n        return self._forward_conv(x)\n\n\nclass Conv3dLayer(MultiPlatformOp):\n    \"\"\"Drop-in replacement for nn.Conv3d with 
automatic linear optimization.\"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, Tuple[int, int, int]],\n        stride: Union[int, Tuple[int, int, int]] = 1,\n        padding: Union[int, Tuple[int, int, int], str] = 0,\n        dilation: Union[int, Tuple[int, int, int]] = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = \"zeros\",\n        disable_linear: bool = False,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.kernel_size = _tuplify(kernel_size, 3)\n        self.stride = _tuplify(stride, 3)\n        self.dilation = _tuplify(dilation, 3)\n        self.groups = groups\n        self.padding_mode = padding_mode\n\n        _validate_conv_args(\n            in_channels, out_channels, groups, padding, padding_mode, self.stride\n        )\n\n        if isinstance(padding, str):\n            self.padding = (0, 0, 0) if padding == \"valid\" else padding\n        else:\n            self.padding = _tuplify(padding, 3)\n\n        if isinstance(self.padding, str):\n            self._reversed_padding_repeated_twice = _compute_same_padding_for_pad(\n                self.kernel_size, self.dilation\n            )\n        else:\n            self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding)\n\n        padding_tuple = self.padding if isinstance(self.padding, tuple) else (1, 1, 1)\n        self.enable_linear = not disable_linear and _check_enable_linear(\n            self.kernel_size, self.stride, padding_tuple, self.dilation, groups\n        )\n\n        self.weight = nn.Parameter(\n            torch.empty(out_channels, in_channels // groups, *self.kernel_size)\n        )\n        if bias:\n            self.bias = nn.Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter(\"bias\", None)\n\n        
self._reset_parameters()\n\n    def _reset_parameters(self):\n        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n        if self.bias is not None:\n            fan_in = nn.init._calculate_correct_fan(self.weight, \"fan_in\")\n            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0\n            nn.init.uniform_(self.bias, -bound, bound)\n\n    def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:\n        K1, K2, K3 = self.kernel_size\n        x = x.unfold(2, K1, K1).unfold(3, K2, K2).unfold(4, K3, K3)\n        N, Dp, Hp, Wp = x.shape[0], x.shape[2], x.shape[3], x.shape[4]\n        x = x.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(N, Dp, Hp, Wp, -1)\n        x = F.linear(x, self.weight.reshape(self.out_channels, -1), self.bias)\n        return x.permute(0, 4, 1, 2, 3)\n\n    def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:\n        if self.padding_mode != \"zeros\":\n            return F.conv3d(\n                F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode),\n                self.weight,\n                self.bias,\n                self.stride,\n                (0, 0, 0),\n                self.dilation,\n                self.groups,\n            )\n        return F.conv3d(\n            x,\n            self.weight,\n            self.bias,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.groups,\n        )\n\n    def forward_native(self, x: torch.Tensor) -> torch.Tensor:\n        if self.enable_linear:\n            return self._forward_mulmat(x)\n        return self._forward_conv(x)\n\n    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:\n        if self.enable_linear:\n            return self._forward_mulmat(x)\n        return self._forward_conv(x)\n"
  },
  {
    "path": "python/sglang/srt/layers/deep_gemm_wrapper/__init__.py",
    "content": "from .entrypoint import *\n"
  },
  {
    "path": "python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py",
    "content": "import logging\nimport os\nfrom contextlib import contextmanager\nfrom enum import IntEnum, auto\nfrom typing import Dict, List, Tuple\n\nimport torch\nfrom tqdm import tqdm\n\nfrom sglang.srt.distributed.device_communicators.pynccl_allocator import (\n    disable_symmetric_memory_context,\n    restore_symmetric_memory_context,\n)\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils import ceil_div, get_available_gpu_memory\n\nlogger = logging.getLogger(__name__)\n\nif ENABLE_JIT_DEEPGEMM:\n    import deep_gemm\n\n\n_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))\n_ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get()\n_DO_COMPILE_ALL = True\n_IS_FIRST_RANK_ON_NODE = envs.SGLANG_IS_FIRST_RANK_ON_NODE.get()\n_IN_PRECOMPILE_STAGE = envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.get()\n_FAST_WARMUP = envs.SGLANG_JIT_DEEPGEMM_FAST_WARMUP.get()\n\n# Force redirect deep_gemm cache_dir\nos.environ[\"DG_JIT_CACHE_DIR\"] = os.getenv(\n    \"SGLANG_DG_CACHE_DIR\", os.path.join(os.path.expanduser(\"~\"), \".cache\", \"deep_gemm\")\n)\n\n# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f\n# NVRTC may have performance loss with some cases.\n# And NVCC JIT speed is also 9x faster in the ref commit\nos.environ[\"DG_JIT_USE_NVRTC\"] = os.getenv(\"SGL_DG_USE_NVRTC\", \"0\")\n\n\ndef update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):\n    global _BUILTIN_M_LIST\n    global _DO_COMPILE_ALL\n    global _IS_FIRST_RANK_ON_NODE\n\n    _BUILTIN_M_LIST = []\n\n    if _FAST_WARMUP:\n        # In fast warmup mode, only compile a small set of typical Ms\n\n        # First cover all the small bs to ensure decode performance\n        _BUILTIN_M_LIST += list(range(1, 1025))\n\n        # Then cover larger batch sizes with gradually increasing steps\n        # For example, 
when the chunked prefill size is 16384\n        # The sampled Ms would be:\n        #   1024, 1026, ... 2046 (step 2)\n        #   2048, 2052, ... 4092 (step 4)\n        #   4096, 4104, ... 8184 (step 8)\n        #   8192, 8208, ... 16384 (step 16)\n        # In total: 1024 + 1024 / 2 + 2048 / 4 + 4096 / 8 + 8192 / 16 = 3072 kernels\n        next_m, sample_step = 1024, 2\n        max_prefill_bs = (\n            min(server_args.chunked_prefill_size, 32 * 1024)\n            if server_args.chunked_prefill_size >= 1\n            else 16 * 1024\n        )\n        while next_m < max_prefill_bs:\n            _BUILTIN_M_LIST += list(range(next_m, 2 * next_m, sample_step))\n            next_m = next_m * 2\n            sample_step = sample_step * 2\n        _BUILTIN_M_LIST.append(max_prefill_bs)\n        _BUILTIN_M_LIST = sorted(list(set(_BUILTIN_M_LIST)))\n    else:\n        # When fast warmup isn't enabled, generate m_max and compile all the covered Ms.\n        m_max = 1024 * 16\n        if server_args.chunked_prefill_size < 1:\n            m_max = 1024 * 64\n        elif server_args.chunked_prefill_size > 8192:\n            m_max = server_args.chunked_prefill_size * 2\n        m_max = min(1024 * 128, m_max)\n        _BUILTIN_M_LIST += list(range(1, m_max + 1))\n\n    _IS_FIRST_RANK_ON_NODE = server_args.base_gpu_id == gpu_id\n\n    # Only the first rank on the node compiles all Ms, so that all symbols\n    # are loaded during launch rather than during serving.\n    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE\n\n\nclass DeepGemmKernelType(IntEnum):\n    GROUPED_GEMM_NT_F8F8BF16_MASKED = auto()\n    GROUPED_GEMM_NT_F8F8BF16_CONTIG = auto()\n    GEMM_NT_F8F8BF16 = auto()\n    GEMM_NT_BF16BF16F32 = auto()\n\n\n_INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()\n\n\n# TODO improve code\ndef _maybe_compile_deep_gemm_one_type_all(\n    kernel_type: DeepGemmKernelType,\n    n: int,\n    k: int,\n   
 num_groups: int,\n) -> None:\n    global _INITIALIZATION_DICT\n    global _BUILTIN_M_LIST\n\n    query_key = (kernel_type, n, k, num_groups)\n    if (\n        _ENABLE_JIT_DEEPGEMM_PRECOMPILE\n        and _DO_COMPILE_ALL\n        and _INITIALIZATION_DICT.get(query_key) is None\n    ):\n        _INITIALIZATION_DICT[query_key] = True\n\n        # TODO maybe improve logs\n        if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:\n            logger.warning(\n                \"Entering DeepGEMM JIT Pre-Compile session. \"\n                \"It may take a long time (typically 10-20 mins) \"\n                \"if you have not run `sglang.compile_deep_gemm`. \"\n                \"It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`\"\n                \" for pre-compilation to reduce the overhead if you have not run it before. \"\n                \"For example: \"\n                \"`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`\"\n            )\n\n        logger.info(\n            f\"Try DeepGEMM JIT Compiling for \"\n            f\"<{kernel_type.name}> N={n}, K={k}, num_groups={num_groups} with all Ms.\"\n            f\"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. 
' if not _IN_PRECOMPILE_STAGE else ''}\"\n        )\n\n        _compile_deep_gemm_one_type_all(\n            kernel_type=kernel_type,\n            n=n,\n            k=k,\n            num_groups=num_groups,\n            m_list=_BUILTIN_M_LIST,\n        )\n\n\n# NOTE(alcanderian): get_num_sms should be changed when 2-batch-overlap is introduced\ndef _compile_deep_gemm_one_type_all(\n    kernel_type: DeepGemmKernelType,\n    n: int,\n    k: int,\n    num_groups: int,\n    m_list: List[int],\n) -> None:\n    # Symmetric memory allocation performs a collective operation across all the GPUs.\n    # Temporarily disable symmetric memory during compilation since it only runs on the first rank.\n    saved_context = disable_symmetric_memory_context()\n    try:\n        if kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:\n            m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()\n            m_list = sorted(list(set(m for m in m_list if m % m_alignment == 0)))\n\n        # Here the precompilation is only run on the first rank, so gpu_id should be 0\n        memory_budget = get_available_gpu_memory(device=\"cuda\", gpu_id=0)\n\n        # If the memory budget is less than the memory requirement, reduce max_m to avoid running out of memory, which could otherwise cause a hang during warmup\n        max_m = max(m_list)\n        required_memory = _BaseWarmupExecutor.get_memory_requirement(\n            kernel_type, max_m=max_m, n=n, k=k, num_groups=num_groups\n        )\n        logger.info(\n            f\"Required memory for warmup: {required_memory}GB, Available memory: {memory_budget}GB\"\n        )\n        if memory_budget < required_memory:\n            # TODO: Maybe compute the max_m based on the memory budget\n            while (\n                _BaseWarmupExecutor.get_memory_requirement(\n                    kernel_type, max_m=max_m, n=n, k=k, num_groups=num_groups\n                )\n                > memory_budget\n                and max_m > 
4096\n            ):\n                max_m = max_m // 2\n            logger.warning(\n                f\"Available memory {memory_budget}GB is less than required memory {required_memory}GB for warmup, reducing max_m to {max_m} to avoid out of memory\"\n            )\n            m_list = [m for m in m_list if m <= max_m]\n\n        # Need some methods to estimate needed memory for warmup\n        executor = _BaseWarmupExecutor.create(\n            kernel_type, max_m=max_m, n=n, k=k, num_groups=num_groups\n        )\n\n        old_compile_mode = deep_gemm.get_compile_mode()\n        deep_gemm.set_compile_mode(1)\n        # TODO can use multi thread\n        for m in tqdm(m_list, desc=f\"DeepGEMM warmup\"):\n            executor.execute(m=m)\n        deep_gemm.set_compile_mode(old_compile_mode)\n\n        # clean up input buffers\n        torch.cuda.current_stream().synchronize()\n        del executor\n        torch.cuda.empty_cache()\n    finally:\n        # Restore symmetric memory context\n        restore_symmetric_memory_context(saved_context)\n\n\nclass _BaseWarmupExecutor:\n    @staticmethod\n    def create(kernel_type: DeepGemmKernelType, **kwargs):\n        return {\n            DeepGemmKernelType.GEMM_NT_F8F8BF16: _NormalWarmupExecutor,\n            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: _GroupedContWarmupExecutor,\n            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor,\n            DeepGemmKernelType.GEMM_NT_BF16BF16F32: _BF16F32WarmupExecutor,\n        }[kernel_type](**kwargs)\n\n    @staticmethod\n    def get_memory_requirement(\n        kernel_type: DeepGemmKernelType, max_m: int, n: int, k: int, num_groups: int\n    ) -> int:\n        # Return the required memory space in GB for warmup executor\n        _GB = 1 << 30\n        if kernel_type == DeepGemmKernelType.GEMM_NT_F8F8BF16:\n            return (max_m * k + n * k + max_m * n * 2) / _GB\n        elif kernel_type == 
DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:\n            return (max_m * k + num_groups * n * k + max_m * 4 + max_m * n * 2) / _GB\n        elif kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED:\n            return (\n                num_groups * max_m * k\n                + num_groups * n * k\n                + num_groups * 4\n                + num_groups * max_m * n * 2\n            ) / _GB\n        elif kernel_type == DeepGemmKernelType.GEMM_NT_BF16BF16F32:\n            # bf16 lhs + bf16 rhs + fp32 out\n            return (max_m * k * 2 + n * k * 2 + max_m * n * 4) / _GB\n        else:\n            raise ValueError(f\"Invalid kernel type: {kernel_type}\")\n\n    def execute(self, m):\n        raise NotImplementedError\n\n\ndef _empty_token_fp8(size):\n    *dims, k = size\n    return (\n        torch.empty(size, device=\"cuda\", dtype=torch.float8_e4m3fn),\n        torch.empty(\n            (*dims, ceil_div(k, _BLOCK_SIZE)), device=\"cuda\", dtype=torch.float32\n        ),\n    )\n\n\ndef _empty_block_fp8(size):\n    *dims, n, k = size\n    return (\n        torch.empty(size, device=\"cuda\", dtype=torch.float8_e4m3fn),\n        torch.empty(\n            (*dims, ceil_div(n, _BLOCK_SIZE), ceil_div(k, _BLOCK_SIZE)),\n            device=\"cuda\",\n            dtype=torch.float32,\n        ),\n    )\n\n\n_BLOCK_SIZE = 128\n\n\nclass _NormalWarmupExecutor(_BaseWarmupExecutor):\n    def __init__(self, max_m: int, n: int, k: int, num_groups: int):\n        self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))\n        self.rhs_q, self.rhs_s = _empty_block_fp8((n, k))\n        self.out = torch.empty((max_m, n), device=\"cuda\", dtype=torch.bfloat16)\n\n    def execute(self, m):\n        deep_gemm.fp8_gemm_nt(\n            (self.lhs_q[:m], self.lhs_s[:m]),\n            (self.rhs_q, self.rhs_s),\n            self.out[:m],\n        )\n\n\nclass _GroupedContWarmupExecutor(_BaseWarmupExecutor):\n    def __init__(self, max_m: int, n: int, k: int, 
num_groups: int):\n        self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))\n        self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))\n        self.m_indices = torch.zeros((max_m,), device=\"cuda\", dtype=torch.int32)\n        self.out = torch.empty((max_m, n), device=\"cuda\", dtype=torch.bfloat16)\n\n    def execute(self, m):\n        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(\n            (self.lhs_q[:m], self.lhs_s[:m]),\n            (self.rhs_q, self.rhs_s),\n            self.out[:m],\n            m_indices=self.m_indices[:m],\n        )\n\n\nclass _GroupedMaskedWarmupExecutor(_BaseWarmupExecutor):\n    def __init__(self, max_m: int, n: int, k: int, num_groups: int):\n        self.lhs_q, self.lhs_s = _empty_token_fp8((num_groups, max_m, k))\n        self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))\n        self.masked_m = torch.zeros((num_groups,), device=\"cuda\", dtype=torch.int32)\n        self.out = torch.empty(\n            (num_groups, max_m, n), device=\"cuda\", dtype=torch.bfloat16\n        )\n\n    def execute(self, m):\n        deep_gemm.fp8_m_grouped_gemm_nt_masked(\n            (self.lhs_q, self.lhs_s),\n            (self.rhs_q, self.rhs_s),\n            self.out,\n            masked_m=self.masked_m,\n            # DeepGEMM uses `expected_m` instead of input shape for `get_best_config`\n            expected_m=m,\n        )\n\n\nclass _BF16F32WarmupExecutor(_BaseWarmupExecutor):\n    def __init__(self, max_m: int, n: int, k: int, num_groups: int):\n        self.lhs = torch.empty((max_m, k), device=\"cuda\", dtype=torch.bfloat16)\n        self.rhs = torch.empty((n, k), device=\"cuda\", dtype=torch.bfloat16)\n        self.out = torch.empty((max_m, n), device=\"cuda\", dtype=torch.float32)\n\n    def execute(self, m):\n        deep_gemm.bf16_gemm_nt(self.lhs[:m], self.rhs, self.out[:m])\n\n\n@contextmanager\ndef deep_gemm_execution_hook(\n    m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType\n):\n    if m > 0:\n        
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)\n    yield\n"
  },
  {
    "path": "python/sglang/srt/layers/deep_gemm_wrapper/configurer.py",
    "content": "import logging\n\nfrom sglang.srt.environ import envs\nfrom sglang.srt.utils import get_device_sm, is_blackwell_supported\n\nlogger = logging.getLogger(__name__)\n\n\ndef _compute_enable_deep_gemm():\n    sm_version = get_device_sm()\n    if sm_version < 90:\n        return False\n\n    try:\n        import deep_gemm  # noqa: F401\n    except ImportError:\n        return False\n\n    return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get()\n\n\nENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()\n\nDEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and is_blackwell_supported()\nDEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL\n"
  },
  {
    "path": "python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py",
    "content": "import logging\nfrom contextlib import contextmanager\nfrom typing import Any, Optional, Tuple\n\nimport torch\n\nfrom sglang.srt.layers.deep_gemm_wrapper import compile_utils\nfrom sglang.srt.layers.deep_gemm_wrapper.configurer import (  # noqa: F401\n    DEEPGEMM_BLACKWELL,\n    DEEPGEMM_SCALE_UE8M0,\n    ENABLE_JIT_DEEPGEMM,\n)\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils import get_bool_env_var\n\nlogger = logging.getLogger(__name__)\n\nif ENABLE_JIT_DEEPGEMM:\n    import deep_gemm\n    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor  # noqa: F401\n\n_SANITY_CHECK = get_bool_env_var(\"SGLANG_DEEPGEMM_SANITY_CHECK\")\n\n\n# TODO maybe rename these functions\ndef grouped_gemm_nt_f8f8bf16_masked(\n    lhs: Tuple[torch.Tensor, torch.Tensor],\n    rhs: Tuple[torch.Tensor, torch.Tensor],\n    out: torch.Tensor,\n    masked_m: torch.Tensor,\n    expected_m: int,\n    overlap_args: Optional[Any] = None,\n    max_block_n: int = 256,\n):\n    num_groups, _, k = lhs[0].shape\n    _, n, _ = rhs[0].shape\n    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED\n\n    _sanity_check_input(lhs)\n    _sanity_check_input(rhs)\n\n    with compile_utils.deep_gemm_execution_hook(\n        expected_m, n, k, num_groups, kernel_type\n    ):\n        with configure_deep_gemm_num_sms(\n            overlap_args.num_sms if overlap_args is not None else None\n        ):\n\n            return deep_gemm.fp8_m_grouped_gemm_nt_masked(\n                lhs,\n                rhs,\n                out,\n                masked_m,\n                expected_m,\n                **(\n                    dict(\n                        enable_overlap=True,\n                        max_block_n=max_block_n,\n                        signal=overlap_args.signal,\n                    )\n                    if overlap_args is not None\n                    else {}\n                ),\n            )\n\n\ndef 
grouped_gemm_nt_f8f8bf16_contig(\n    lhs: Tuple[torch.Tensor, torch.Tensor],\n    rhs: Tuple[torch.Tensor, torch.Tensor],\n    out: torch.Tensor,\n    m_indices: torch.Tensor,\n):\n    m, k = lhs[0].shape\n    num_groups, n, _ = rhs[0].shape\n    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG\n\n    _sanity_check_input(lhs)\n    _sanity_check_input(rhs)\n\n    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):\n        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)\n\n\ndef gemm_nt_f8f8bf16(\n    lhs: Tuple[torch.Tensor, torch.Tensor],\n    rhs: Tuple[torch.Tensor, torch.Tensor],\n    out: torch.Tensor,\n):\n    m, k = lhs[0].shape\n    n, _ = rhs[0].shape\n    num_groups = 1\n    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16\n\n    _sanity_check_input(lhs)\n    _sanity_check_input(rhs)\n\n    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):\n        deep_gemm.fp8_gemm_nt(\n            lhs,\n            rhs,\n            out,\n        )\n\n\ndef gemm_nt_bf16bf16f32(\n    lhs: torch.Tensor,\n    rhs: torch.Tensor,\n    out: torch.Tensor,\n):\n    m, k = lhs.shape\n    n, _ = rhs.shape\n    num_groups = 1\n    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_BF16BF16F32\n\n    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):\n        deep_gemm.bf16_gemm_nt(lhs, rhs, out)\n\n\ndef update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):\n    compile_utils.update_deep_gemm_config(gpu_id, server_args)\n\n\n@contextmanager\ndef configure_deep_gemm_num_sms(num_sms):\n    if num_sms is None or not ENABLE_JIT_DEEPGEMM:\n        yield\n    else:\n        original_num_sms = deep_gemm.get_num_sms()\n        deep_gemm.set_num_sms(num_sms)\n        try:\n            yield\n        finally:\n            deep_gemm.set_num_sms(original_num_sms)\n\n\ndef _sanity_check_input(x_fp8: Tuple[torch.Tensor, 
torch.Tensor]):\n    if not _SANITY_CHECK:\n        return\n\n    x, x_scale = x_fp8\n\n    if x_scale.dtype == torch.int:\n        return\n\n    from sglang.srt.layers.quantization.fp8_utils import ceil_to_ue8m0\n\n    x_scale_ceil = ceil_to_ue8m0(x_scale)\n    assert torch.all(x_scale == x_scale_ceil), f\"{x_scale=} {x_scale_ceil=}\"\n"
  },
  {
    "path": "python/sglang/srt/layers/dp_attention.py",
    "content": "from __future__ import annotations\n\nimport functools\nimport logging\nfrom contextlib import contextmanager\nfrom enum import IntEnum, auto\nfrom typing import TYPE_CHECKING, List, Optional, Tuple\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.distributed import (\n    GroupCoordinator,\n    get_attn_context_model_parallel_rank,\n    get_attn_context_model_parallel_world_size,\n    get_attn_cp_group,\n    get_attn_tensor_model_parallel_rank,\n    get_attn_tensor_model_parallel_world_size,\n    get_attn_tp_group,\n    get_tensor_model_parallel_rank,\n    get_tensor_model_parallel_world_size,\n    get_tp_group,\n    tensor_model_parallel_all_reduce,\n)\nfrom sglang.srt.distributed.device_communicators.pynccl_allocator import (\n    use_symmetric_memory,\n)\nfrom sglang.srt.utils import get_bool_env_var, is_hip\n\nif TYPE_CHECKING:\n    from sglang.srt.configs.model_config import ModelConfig\n    from sglang.srt.server_args import ServerArgs\n\nlogger = logging.getLogger(__name__)\n\nif TYPE_CHECKING:\n    from sglang.srt.model_executor.forward_batch_info import ForwardBatch\n\n_ATTN_DP_RANK: Optional[int] = None\n_ATTN_DP_SIZE: Optional[int] = None\n_LOCAL_ATTN_DP_SIZE: Optional[int] = None\n_LOCAL_ATTN_DP_RANK: Optional[int] = None\n_ENABLE_DP_ATTENTION_FLAG: bool = False\n\n_is_hip = is_hip()\n_USE_ROCM700A_WA = _is_hip and get_bool_env_var(\"SGLANG_USE_ROCM700A\")\n\n\nclass DpPaddingMode(IntEnum):\n\n    # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`\n    MAX_LEN = auto()\n    # Padding tokens to sum length and then gather tokens using `all_reduce`\n    SUM_LEN = auto()\n\n    def is_max_len(self):\n        return self == DpPaddingMode.MAX_LEN\n\n    def is_sum_len(self):\n        return self == DpPaddingMode.SUM_LEN\n\n    @classmethod\n    def get_dp_padding_mode(\n        cls, is_extend_in_batch, global_num_tokens: List[int]\n    ) -> DpPaddingMode:\n        dp_size = 
get_attention_dp_size()\n\n        # When is_extend_in_batch and dp_size > 1, use SUM_LEN to avoid padding\n        # overhead from uneven token distribution.\n        # For dp_size=1, max_len equals sum_len, so prefer MAX_LEN mode\n        # to enable symmetric memory optimization (needed for NSA CP, etc.).\n        if is_extend_in_batch and dp_size > 1:\n            return DpPaddingMode.SUM_LEN\n\n        # we choose the mode that minimizes the communication cost\n        # prefer MAX_LEN when communication cost is equal to enable symmetric memory\n        max_len = max(global_num_tokens)\n        sum_len = sum(global_num_tokens)\n        if sum_len * 2 >= max_len * dp_size:\n            return cls.MAX_LEN\n        else:\n            return cls.SUM_LEN\n\n    @classmethod\n    def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:\n        # TODO(kkhuang-amd): noqa, temporary work-around for rocm 7.0.0 alpha\n        # it can be safely removed later, once RCCL fixed\n        if _USE_ROCM700A_WA:\n            return cls.SUM_LEN\n        else:\n            return cls.MAX_LEN\n\n\nclass _DpGatheredBufferWrapper:\n\n    _hidden_size: int\n    _dtype: torch.dtype\n    _device: torch.device\n    _global_dp_buffer_len: int\n    _local_dp_buffer_len: int\n    _dp_max_padding: bool\n    _global_num_tokens: Optional[List[int]]\n    _is_extend_in_batch: bool\n\n    @classmethod\n    def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):\n        cls._hidden_size = hidden_size\n        cls._dtype = dtype\n        cls._device = device\n\n    @classmethod\n    def set_dp_buffer_len(\n        cls,\n        global_dp_buffer_len: int,\n        local_dp_buffer_len: int,\n        dp_max_padding: bool,\n        global_num_tokens: Optional[List[int]] = None,\n    ):\n        cls._global_dp_buffer_len = global_dp_buffer_len\n        cls._local_dp_buffer_len = local_dp_buffer_len\n        cls._dp_max_padding = dp_max_padding\n        
cls._global_num_tokens = global_num_tokens\n\n    @classmethod\n    def get_global_dp_buffer(cls) -> torch.Tensor:\n        with use_symmetric_memory(get_tp_group(), disabled=not cls._dp_max_padding):\n            buffer = torch.empty(\n                (cls._global_dp_buffer_len, cls._hidden_size),\n                dtype=cls._dtype,\n                device=cls._device,\n            )\n        return buffer\n\n    @classmethod\n    def get_local_dp_buffer(cls) -> torch.Tensor:\n        with use_symmetric_memory(get_tp_group(), disabled=not cls._dp_max_padding):\n            buffer = torch.empty(\n                (cls._local_dp_buffer_len, cls._hidden_size),\n                dtype=cls._dtype,\n                device=cls._device,\n            )\n        return buffer\n\n    @classmethod\n    def get_global_dp_buffer_len(cls) -> int:\n        return cls._global_dp_buffer_len\n\n    @classmethod\n    def get_local_dp_buffer_len(cls) -> int:\n        return cls._local_dp_buffer_len\n\n    @classmethod\n    def get_dp_global_num_tokens(cls) -> List[int]:\n        return cls._global_num_tokens\n\n    @classmethod\n    def get_dp_hidden_size(cls) -> int:\n        return cls._hidden_size\n\n    @classmethod\n    def get_dp_dtype(cls) -> torch.dtype:\n        return cls._dtype\n\n    @classmethod\n    def get_dp_device(cls) -> torch.device:\n        return cls._device\n\n    @classmethod\n    def set_is_extend_in_batch(cls, is_extend_in_batch: bool):\n        cls._is_extend_in_batch = is_extend_in_batch\n\n    @classmethod\n    def get_is_extend_in_batch(cls) -> bool:\n        return cls._is_extend_in_batch\n\n    @classmethod\n    def is_dp_max_padding(cls) -> bool:\n        return cls._dp_max_padding\n\n\ndef set_dp_buffer_len(\n    global_dp_buffer_len: int,\n    local_dp_buffer_len: int,\n    dp_max_padding: bool,\n    global_num_tokens: Optional[List[int]] = None,\n):\n    _DpGatheredBufferWrapper.set_dp_buffer_len(\n        global_dp_buffer_len, local_dp_buffer_len, 
dp_max_padding, global_num_tokens\n    )\n\n\ndef get_global_dp_buffer() -> torch.Tensor:\n    return _DpGatheredBufferWrapper.get_global_dp_buffer()\n\n\ndef get_local_dp_buffer() -> torch.Tensor:\n    return _DpGatheredBufferWrapper.get_local_dp_buffer()\n\n\ndef get_global_dp_buffer_len() -> int:\n    return _DpGatheredBufferWrapper.get_global_dp_buffer_len()\n\n\ndef get_local_dp_buffer_len() -> int:\n    return _DpGatheredBufferWrapper.get_local_dp_buffer_len()\n\n\ndef get_dp_global_num_tokens() -> List[int]:\n    return _DpGatheredBufferWrapper.get_dp_global_num_tokens()\n\n\ndef get_dp_hidden_size() -> int:\n    return _DpGatheredBufferWrapper.get_dp_hidden_size()\n\n\ndef get_dp_dtype() -> torch.dtype:\n    return _DpGatheredBufferWrapper.get_dp_dtype()\n\n\ndef get_dp_device() -> torch.device:\n    return _DpGatheredBufferWrapper.get_dp_device()\n\n\ndef set_is_extend_in_batch(is_extend_in_batch: bool):\n    _DpGatheredBufferWrapper.set_is_extend_in_batch(is_extend_in_batch)\n\n\ndef get_is_extend_in_batch() -> bool:\n    return _DpGatheredBufferWrapper.get_is_extend_in_batch()\n\n\ndef is_dp_max_padding() -> bool:\n    return _DpGatheredBufferWrapper.is_dp_max_padding()\n\n\ndef compute_dp_attention_world_info(\n    enable_dp_attention, tp_rank, tp_size, dp_size, attn_cp_size: int = 1\n):\n    attn_dp_size = dp_size if enable_dp_attention else 1\n    attn_tp_size = tp_size // attn_dp_size // attn_cp_size\n    attn_tp_rank = tp_rank % attn_tp_size\n\n    if not enable_dp_attention:\n        attn_dp_rank = 0\n    else:\n        # Rank layout is (dp, cp, tp) where tp is the fastest-changing dim:\n        # tp_rank = (attn_dp_rank * attn_cp_size + attn_cp_rank) * attn_tp_size + attn_tp_rank\n        attn_dp_rank = tp_rank // (attn_tp_size * attn_cp_size)\n\n    return attn_tp_rank, attn_tp_size, attn_dp_rank\n\n\ndef compute_dp_attention_local_info(\n    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size\n):\n    if not enable_dp_attention:\n  
      return tp_rank, tp_size, 0\n\n    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size\n    local_tp_rank = tp_rank % local_tp_size\n    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))\n\n    local_attn_tp_size = local_tp_size // local_dp_size\n    local_attn_dp_rank = local_tp_rank // local_attn_tp_size\n    local_attn_tp_rank = local_tp_rank % local_attn_tp_size\n\n    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank\n\n\ndef initialize_dp_attention(\n    server_args: ServerArgs,\n    model_config: ModelConfig,\n):\n    global _ATTN_DP_RANK, _ATTN_DP_SIZE\n    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG\n    enable_dp_attention = server_args.enable_dp_attention\n    dp_size = server_args.dp_size\n    moe_dense_tp_size = server_args.moe_dense_tp_size\n    attn_cp_size = server_args.attn_cp_size\n\n    _ENABLE_DP_ATTENTION_FLAG = enable_dp_attention\n\n    tp_rank = get_tensor_model_parallel_rank()\n    tp_size = get_tensor_model_parallel_world_size()\n\n    _, _, _ATTN_DP_RANK = compute_dp_attention_world_info(\n        enable_dp_attention, tp_rank, tp_size, dp_size, attn_cp_size\n    )\n    _, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(\n        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size\n    )\n\n    if enable_dp_attention:\n        _ATTN_DP_SIZE = dp_size\n        if moe_dense_tp_size is None:\n            _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE\n        else:\n            _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))\n    else:\n        _ATTN_DP_SIZE = 1\n        _LOCAL_ATTN_DP_SIZE = 1\n\n    _DpGatheredBufferWrapper.set_metadata(\n        hidden_size=model_config.hidden_size,\n        dtype=model_config.dtype,\n        device=torch.device(server_args.device),\n    )\n\n\ndef is_dp_attention_enabled() -> bool:\n    return _ENABLE_DP_ATTENTION_FLAG\n\n\ndef is_allocation_symmetric() -> bool:\n    return not 
is_dp_attention_enabled() or is_dp_max_padding()\n\n\ndef get_attention_tp_group() -> GroupCoordinator:\n    return get_attn_tp_group()\n\n\ndef get_attention_tp_rank() -> int:\n    return get_attn_tensor_model_parallel_rank()\n\n\ndef get_attention_tp_size() -> int:\n    return get_attn_tensor_model_parallel_world_size()\n\n\ndef get_attention_cp_group() -> GroupCoordinator:\n    return get_attn_cp_group()\n\n\ndef get_attention_cp_rank() -> int:\n    return get_attn_context_model_parallel_rank()\n\n\ndef get_attention_cp_size() -> int:\n    return get_attn_context_model_parallel_world_size()\n\n\ndef get_attention_dp_rank() -> int:\n    assert _ATTN_DP_RANK is not None, \"dp attention not initialized!\"\n    return _ATTN_DP_RANK\n\n\ndef get_attention_dp_size() -> int:\n    assert _ATTN_DP_SIZE is not None, \"dp attention not initialized!\"\n    return _ATTN_DP_SIZE\n\n\ndef get_local_attention_dp_rank() -> int:\n    assert _LOCAL_ATTN_DP_RANK is not None, \"dp attention not initialized!\"\n    return _LOCAL_ATTN_DP_RANK\n\n\ndef get_local_attention_dp_size() -> int:\n    assert _LOCAL_ATTN_DP_SIZE is not None, \"dp attention not initialized!\"\n    return _LOCAL_ATTN_DP_SIZE\n\n\n@contextmanager\ndef disable_dp_size():\n    \"\"\"Temporarily set the attention DP size to 1 until the context exits.\n\n    This is used by draft workers of speculative decoding to run the draft model\n    with a different tp degree from that of the target model workers.\n    \"\"\"\n    global _ATTN_DP_SIZE\n    assert _ATTN_DP_SIZE is not None, \"dp attention not initialized!\"\n\n    old_dp_size = _ATTN_DP_SIZE\n    _ATTN_DP_SIZE = 1\n    try:\n        yield\n    finally:\n        _ATTN_DP_SIZE = old_dp_size\n\n\ndef get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]:\n    # `get_dp_local_info` is only called in global DP gather and scatter. 
We use global DP rank here.\n    dp_rank = get_attention_dp_rank()\n\n    if forward_batch.dp_local_start_pos is None:\n        cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)\n        if dp_rank == 0:\n            local_start_pos = torch.zeros_like(cumtokens[0])\n        else:\n            local_start_pos = cumtokens[dp_rank - 1]\n        local_num_tokens = forward_batch.global_num_tokens_gpu[dp_rank]\n\n        forward_batch.dp_local_start_pos = local_start_pos\n        forward_batch.dp_local_num_tokens = local_num_tokens\n\n    return forward_batch.dp_local_start_pos, forward_batch.dp_local_num_tokens\n\n\n@triton.jit\ndef memcpy_triton_kernel(\n    dst_ptr,\n    src_ptr,\n    offset_ptr,\n    sz_ptr,\n    offset_src: tl.constexpr,\n    chunk_size,  # multiplied for offset and sz\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid = tl.program_id(axis=0).to(tl.int64)\n    offset = tl.load(offset_ptr).to(tl.int64) * chunk_size\n    sz = tl.load(sz_ptr).to(tl.int64) * chunk_size\n\n    start_index = pid * BLOCK_SIZE\n    offs = tl.arange(0, BLOCK_SIZE)\n    mask = start_index + offs < sz\n\n    if offset_src:\n        data = tl.load(src_ptr + offset + start_index + offs, mask=mask)\n        tl.store(dst_ptr + start_index + offs, data, mask=mask)\n    else:\n        data = tl.load(src_ptr + start_index + offs, mask=mask)\n        tl.store(dst_ptr + offset + start_index + offs, data, mask=mask)\n\n\ndef prod(x):\n    return functools.reduce(lambda a, b: a * b, x, 1)\n\n\ndef memcpy_triton(dst, src, dim, offset, sz, offset_src):\n    max_size = min(src.numel(), dst.numel())\n    assert dim == 0, \"dim != 0 unsupported\"\n    assert src.shape[1:] == dst.shape[1:], \"src and dst must have same shape\"\n    chunk_size = prod(src.shape[1:])\n    BLOCK_SIZE = 8192\n    grid = (triton.cdiv(max_size, BLOCK_SIZE),)\n\n    memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)\n\n\ndef _dp_gather_via_all_reduce(\n    global_tokens: 
torch.Tensor,\n    local_tokens: torch.Tensor,\n    forward_batch: ForwardBatch,\n    is_partial: bool,\n):\n    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)\n\n    global_tokens.fill_(0)\n    assert local_tokens.is_contiguous()\n    assert global_tokens.is_contiguous()\n\n    if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):\n        assert (\n            local_tokens.untyped_storage() is not global_tokens.untyped_storage()\n        ), \"aliasing between global_tokens and local_tokens not allowed\"\n\n        memcpy_triton(\n            global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False\n        )\n\n    # Input IDs are in int 32. We should use inplace_all_reduce for local case because of custom all reduce.\n    NUM_GPUS_PER_NODE = 8\n    if (\n        not local_tokens.dtype.is_floating_point\n        and get_tensor_model_parallel_world_size() <= NUM_GPUS_PER_NODE\n    ):\n        from sglang.srt.distributed.parallel_state import inplace_all_reduce\n\n        inplace_all_reduce(global_tokens, group_name=get_tp_group().unique_name)\n\n    else:\n        global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)\n\n\ndef _dp_gather_via_all_gather(\n    global_tokens: torch.Tensor,\n    local_tokens: torch.Tensor,\n    forward_batch: ForwardBatch,\n    is_partial: bool,\n):\n    if get_attention_tp_size() == 1:\n        get_tp_group().all_gather_into_tensor(global_tokens, local_tokens)\n        return\n\n    if not is_partial:\n        if get_attention_tp_rank() != 0:\n            local_tokens.fill_(0)\n    scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[\n        get_attention_tp_rank()\n    ]\n    get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)\n    get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)\n\n\ndef _dp_gather(\n    global_tokens: torch.Tensor,\n    local_tokens: torch.Tensor,\n    
forward_batch: ForwardBatch,\n    is_partial: bool,\n):\n    if forward_batch.dp_padding_mode.is_max_len():\n        _dp_gather_via_all_gather(\n            global_tokens, local_tokens, forward_batch, is_partial\n        )\n    else:\n        _dp_gather_via_all_reduce(\n            global_tokens, local_tokens, forward_batch, is_partial\n        )\n\n\ndef dp_gather_partial(\n    global_tokens: torch.Tensor,\n    local_tokens: torch.Tensor,\n    forward_batch: ForwardBatch,\n):\n    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True)\n\n\ndef dp_gather_replicate(\n    global_tokens: torch.Tensor,\n    local_tokens: torch.Tensor,\n    forward_batch: ForwardBatch,\n):\n    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=False)\n\n\ndef dp_scatter(\n    local_tokens: torch.Tensor,  # output\n    global_tokens: torch.Tensor,  # input\n    forward_batch: ForwardBatch,\n):\n    # local_num_tokens is not necessarily the same as local_tokens.shape[0],\n    # since local_tokens may be padded for cuda graph\n    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)\n\n    local_tokens.fill_(0)\n    assert local_tokens.is_contiguous()\n    assert global_tokens.is_contiguous()\n    if local_tokens.shape[0] > 0:\n        assert (\n            local_tokens.untyped_storage() is not global_tokens.untyped_storage()\n        ), \"aliasing between local_tokens and global_tokens not allowed\"\n\n        memcpy_triton(\n            local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True\n        )\n\n\ndef dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):\n    if get_tensor_model_parallel_world_size() == get_attention_dp_size():\n        get_tp_group().reduce_scatter_tensor(output, input)\n    else:\n        scattered_local_tokens = input.tensor_split(\n            get_tensor_model_parallel_world_size()\n        )[get_tensor_model_parallel_rank()]\n        
get_tp_group().reduce_scatter_tensor(scattered_local_tokens, input)\n        get_attention_tp_group().all_gather_into_tensor(output, scattered_local_tokens)\n\n\ndef attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):\n    return get_attention_tp_group().reduce_scatter_tensor(output, input)\n\n\ndef attn_cp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):\n    return get_attention_cp_group().reduce_scatter_tensor(output, input)\n\n\ndef attn_tp_all_reduce(input: torch.Tensor):\n    return get_attention_tp_group().all_reduce(input)\n\n\ndef attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):\n    return get_attention_tp_group().all_gather_into_tensor(output, input)\n\n\ndef attn_cp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):\n    return get_attention_cp_group().all_gather_into_tensor(output, input)\n\n\ndef attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):\n    return get_attention_tp_group().all_gather(input, output_tensor_list=output_list)\n"
  },
  {
    "path": "python/sglang/srt/layers/elementwise.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom sglang.srt.utils import is_hip\nfrom sglang.srt.utils.custom_op import register_custom_op\n\n_is_hip = is_hip()\n\n\nfused_softcap_autotune = triton.autotune(\n    configs=[\n        triton.Config(kwargs={\"BLOCK_SIZE\": 128}, num_warps=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 128}, num_warps=8),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 128}, num_warps=16),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 256}, num_warps=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 256}, num_warps=8),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 512}, num_warps=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 512}, num_warps=8),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 512}, num_warps=16),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=8),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=16),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=32),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 2048}, num_warps=32),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 4096}, num_warps=32),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 8192}, num_warps=32),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 16384}, num_warps=32),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 32768}, num_warps=32),\n    ],\n    key=[\"n_ele\"],\n)\n\n\n@triton.jit\ndef fused_softcap_kernel(\n    output_ptr,\n    input_ptr,\n    n_ele,\n    softcap_const: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n    block_start = pid * BLOCK_SIZE\n    offsets = block_start + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_ele\n    x = tl.load(input_ptr + offsets, mask=mask)\n    fx = x.to(tl.float32)\n    fxs = fx / softcap_const\n    exped = tl.exp(2 * fxs)\n    top = exped - 1\n    bottom = exped + 1\n    output = top / 
bottom * softcap_const\n    tl.store(output_ptr + offsets, output, mask=mask)\n\n\nfused_softcap_kernel_autotuned = fused_softcap_autotune(fused_softcap_kernel)\n\n\ndef fused_softcap(x, softcap_const, autotune=False):\n    output = torch.empty_like(x, dtype=torch.float32)\n    n_elements = output.numel()\n    if autotune:\n        grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n        fused_softcap_kernel_autotuned[grid](output, x, n_elements, softcap_const)\n    else:\n        fused_softcap_kernel[(triton.cdiv(n_elements, 128),)](\n            output, x, n_elements, softcap_const, BLOCK_SIZE=128, num_warps=8\n        )\n    return output\n\n\n# cast to float + softcap\nclass Softcap:\n    def __init__(self, softcap_const: float):\n        self.softcap_const = softcap_const\n\n    def __call__(self, *args, **kwargs):\n        return self.forward(*args, **kwargs)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if x.is_cuda:\n            return self.forward_cuda(x)\n        else:\n            return self.forward_native(x)\n\n    def forward_native(self, x: torch.Tensor) -> torch.Tensor:\n        return torch.tanh(x.float() / self.softcap_const) * self.softcap_const\n\n    def forward_cuda(self, x: torch.Tensor, autotune=False) -> torch.Tensor:\n        return fused_softcap(x, self.softcap_const, autotune=autotune)\n\n\nrmsnorm_autotune = triton.autotune(\n    configs=[\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=4, num_stages=1),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=8, num_stages=1),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=16, num_stages=1),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=8),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=16),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=4, num_stages=4),\n        
triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=8, num_stages=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=16, num_stages=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=8, num_stages=8),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 1024}, num_warps=16, num_stages=8),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 2048}, num_warps=8),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 2048}, num_warps=16),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 2048}, num_warps=8, num_stages=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 2048}, num_warps=16, num_stages=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 4096}, num_warps=8),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 4096}, num_warps=16),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 8192}, num_warps=8),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 8192}, num_warps=16),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 8192}, num_warps=32),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 8192}, num_warps=8, num_stages=1),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 8192}, num_warps=16, num_stages=1),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 8192}, num_warps=32, num_stages=1),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 8192}, num_warps=8, num_stages=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 8192}, num_warps=16, num_stages=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 8192}, num_warps=32, num_stages=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 16384}, num_warps=8),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 16384}, num_warps=16),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 16384}, num_warps=32),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 16384}, num_warps=8, num_stages=1),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 16384}, num_warps=16, num_stages=1),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 16384}, num_warps=32, num_stages=1),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 16384}, num_warps=8, 
num_stages=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 16384}, num_warps=16, num_stages=4),\n        triton.Config(kwargs={\"BLOCK_SIZE\": 16384}, num_warps=32, num_stages=4),\n    ],\n    key=[\"hidden_dim\"],\n)\n\n\n@triton.jit\ndef fused_dual_residual_rmsnorm_kernel(\n    output_ptr,\n    mid_ptr,\n    activ_ptr,\n    residual_ptr,\n    weight1_ptr,\n    weight2_ptr,\n    eps: tl.constexpr,\n    hidden_dim: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n    input_start = pid * hidden_dim\n\n    offsets = tl.arange(0, BLOCK_SIZE)\n    mask = offsets < hidden_dim\n\n    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)\n    a = a_.to(tl.float32)\n    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)\n\n    r = tl.load(residual_ptr + input_start + offsets, mask=mask, other=0.0)\n    w1_ = tl.load(weight1_ptr + offsets, mask=mask, other=0.0)\n    w1 = w1_.to(tl.float32)\n\n    a2r = r + (a / rms * w1).to(r.dtype)\n    tl.store(\n        mid_ptr + input_start + offsets,\n        a2r,\n        mask=mask,\n    )\n\n    a2r = a2r.to(tl.float32)\n    rms2 = tl.sqrt(tl.sum(a2r * a2r, axis=0) / hidden_dim + eps)\n\n    w2_ = tl.load(weight2_ptr + offsets, mask=mask, other=0.0)\n    w2 = w2_.to(tl.float32)\n\n    tl.store(\n        output_ptr + input_start + offsets,\n        a2r / rms2 * w2,  # implicitly casts to output dtype here\n        mask=mask,\n    )\n\n\nfused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(\n    fused_dual_residual_rmsnorm_kernel\n)\n\n\ndef fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):\n    assert len(x.shape) == 2\n    assert (\n        x.shape == residual.shape and x.dtype == residual.dtype\n    ), f\"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}\"\n    output, mid = torch.empty_like(x), torch.empty_like(x)\n    bs, hidden_dim = x.shape\n    if autotune:\n        fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](\n     
       output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim\n        )\n    else:\n        max_warps = 16 if _is_hip else 32\n        config = {\n            \"BLOCK_SIZE\": triton.next_power_of_2(hidden_dim),\n            \"num_warps\": max(\n                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4\n            ),\n        }\n\n        fused_dual_residual_rmsnorm_kernel[(bs,)](\n            output,\n            mid,\n            x,\n            residual,\n            weight1,\n            weight2,\n            eps=eps,\n            hidden_dim=hidden_dim,\n            **config,\n        )\n\n    return output, mid\n\n\n@triton.jit\ndef fused_rmsnorm_kernel(\n    output_ptr,\n    activ_ptr,\n    weight_ptr,\n    eps: tl.constexpr,\n    hidden_dim: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid = tl.program_id(axis=0).to(tl.int64)\n    input_start = pid * hidden_dim\n\n    offsets = tl.arange(0, BLOCK_SIZE)\n    mask = offsets < hidden_dim\n\n    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)\n    a = a_.to(tl.float32)\n    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)\n\n    w1_ = tl.load(weight_ptr + offsets, mask=mask, other=0.0)\n    w1 = w1_.to(tl.float32)\n\n    a_rms = a / rms * w1\n\n    tl.store(\n        output_ptr + input_start + offsets,\n        a_rms,  # implicitly casts to output dtype here\n        mask=mask,\n    )\n\n\ndef fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):\n    assert len(x.shape) == 2\n    if inplace:\n        output = x\n    else:\n        output = torch.empty_like(x)\n    bs, hidden_dim = x.shape\n    max_warps = 16 if _is_hip else 32\n    config = {\n        \"BLOCK_SIZE\": triton.next_power_of_2(hidden_dim),\n        \"num_warps\": max(\n            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4\n        ),\n    }\n\n    fused_rmsnorm_kernel[(bs,)](\n        output, x, weight, eps=eps, 
hidden_dim=hidden_dim, **config\n    )\n    return output\n\n\nclass FusedDualResidualRMSNorm:\n    \"\"\"\n    Fused implementation of\n    y = RMSNorm2(RMSNorm1(x) + residual)\n    \"\"\"\n\n    def __init__(self, rmsnorm1, rmsnorm2) -> None:  # rmsnorm2 is applied after rmsnorm1\n        self.rmsnorm1 = rmsnorm1\n        self.rmsnorm2 = rmsnorm2\n        self.variance_epsilon = self.rmsnorm1.variance_epsilon\n        assert self.rmsnorm1.variance_epsilon == self.rmsnorm2.variance_epsilon\n        assert self.rmsnorm1.weight.shape == self.rmsnorm2.weight.shape\n\n    def __call__(self, *args, **kwargs):\n        return self.forward(*args, **kwargs)\n\n    def forward(\n        self, x: torch.Tensor, residual: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        if x.is_cuda:\n            return self.forward_cuda(x, residual)\n        else:\n            return self.forward_flashinfer(x, residual)\n\n    def forward_cuda(\n        self, x: torch.Tensor, residual: torch.Tensor, autotune=False\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        return fused_dual_residual_rmsnorm(\n            x,\n            residual,\n            self.rmsnorm1.weight,\n            self.rmsnorm2.weight,\n            self.variance_epsilon,\n            autotune=autotune,\n        )\n\n    def forward_flashinfer(\n        self,\n        x: torch.Tensor,\n        residual: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        normed1 = self.rmsnorm1(x)\n        residual = normed1 + residual\n        return self.rmsnorm2(residual), residual\n\n    def forward_native(\n        self,\n        x: torch.Tensor,\n        residual: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        normed1 = self.rmsnorm1.forward_native(x)\n        residual = normed1 + residual\n        return self.rmsnorm2.forward_native(residual), residual\n\n\n@triton.jit\ndef experts_combine_kernel(\n    out_hidden_states,\n    moe_hidden_states,\n    mlp_hidden_states,\n    combine_k: 
tl.constexpr,\n    hidden_dim: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid = tl.program_id(0)\n    start_index_mlp = pid * hidden_dim\n    start_index_rmoe = pid * hidden_dim * combine_k\n    offsets = tl.arange(0, BLOCK_SIZE)\n    mask = offsets < hidden_dim\n    combine_k_offsets = tl.arange(0, combine_k)\n\n    moe_x = tl.load(\n        moe_hidden_states\n        + start_index_rmoe\n        + combine_k_offsets[:, None] * hidden_dim\n        + offsets[None, :],\n        mask=mask[None, :],\n        other=0.0,\n    )\n    moe_x = tl.sum(moe_x, axis=0)\n    mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)\n    combined_x = (moe_x + mlp_x) / 1.4142135623730951\n\n    tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)\n\n\n@register_custom_op(out_shape=\"mlp_hidden_states\")\ndef experts_combine_triton(\n    moe_hidden_states: torch.Tensor,\n    mlp_hidden_states: torch.Tensor,\n    output_buffer: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    assert moe_hidden_states.is_contiguous()\n    assert mlp_hidden_states.is_contiguous()\n\n    if len(moe_hidden_states.shape) == 2:\n        combine_k = 1  # pre-combined\n    else:\n        combine_k = moe_hidden_states.shape[1]\n\n    if output_buffer is None:\n        out_hidden_states = torch.empty_like(mlp_hidden_states)\n    else:\n        flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)\n        assert flat_output_buffer.numel() >= mlp_hidden_states.numel()\n        out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(\n            mlp_hidden_states.shape\n        )\n\n    bs, hidden_dim = mlp_hidden_states.shape\n\n    config = {\n        \"BLOCK_SIZE\": triton.next_power_of_2(hidden_dim),\n        \"num_warps\": max(\n            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4\n        ),\n    }\n\n    experts_combine_kernel[(bs,)](\n        out_hidden_states,\n      
  moe_hidden_states,\n        mlp_hidden_states,\n        combine_k,\n        hidden_dim,\n        **config,\n    )\n\n    return out_hidden_states\n\n\n# gelu on first half of vector\n@triton.jit\ndef gelu_and_mul_kernel(\n    out_hidden_states_ptr,  # (bs, hidden_dim)\n    out_scales_ptr,  # (bs,)\n    hidden_states_ptr,  # (bs, hidden_dim * 2)\n    quant_max: tl.constexpr,\n    static_scale: tl.constexpr,\n    hidden_dim: tl.constexpr,  # the output hidden_dim\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n\n    input_start = pid * hidden_dim * 2\n    output_start = pid * hidden_dim\n\n    input1_offs = tl.arange(0, BLOCK_SIZE)\n    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output\n    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)\n    output_offs = tl.arange(0, BLOCK_SIZE)\n\n    x1 = tl.load(\n        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0\n    ).to(tl.float32)\n    x3 = tl.load(\n        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0\n    ).to(tl.float32)\n\n    # gelu\n    # cast down before mul to better match training?\n    gelu_x1 = 0.5 * (1.0 + tl.erf(x1 * 0.7071067811865475)) * x1\n    out = x3 * gelu_x1.to(hidden_states_ptr.dtype.element_ty)\n\n    if quant_max is not None:\n        raise NotImplementedError()\n\n    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)\n\n\ndef gelu_and_mul_triton(\n    hidden_states,\n    scales=None,\n    quantize=None,  # dtype to quantize to\n    out=None,\n):\n    bs, in_hidden_dim = hidden_states.shape\n    hidden_dim = in_hidden_dim // 2\n\n    if out is None:\n        out_hidden_states = torch.empty(\n            (bs, hidden_dim),\n            dtype=quantize or hidden_states.dtype,\n            device=hidden_states.device,\n        )\n    else:\n        assert out.shape == (bs, hidden_dim)\n        assert out.dtype == (quantize or hidden_states.dtype)\n        out_hidden_states 
= out\n    out_scales = None\n    static_scale = False\n    if quantize is not None:\n        if scales is None:\n            out_scales = torch.empty(\n                (bs,), dtype=torch.float32, device=hidden_states.device\n            )\n        else:\n            out_scales = scales\n            static_scale = True\n\n    max_warps = 16 if _is_hip else 32\n    config = {\n        # 8 ele per thread (not tuned)\n        \"num_warps\": max(\n            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4\n        ),\n    }\n\n    gelu_and_mul_kernel[(bs,)](\n        out_hidden_states,\n        out_scales,\n        hidden_states,\n        quant_max=torch.finfo(quantize).max if quantize is not None else None,\n        static_scale=static_scale,\n        hidden_dim=hidden_dim,\n        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),\n        **config,\n    )\n\n    if quantize is not None:\n        return out_hidden_states, out_scales\n    else:\n        return out_hidden_states, None\n\n\n# silu on first half of vector\n@triton.jit\ndef silu_and_mul_kernel(\n    out_hidden_states_ptr,  # (bs, hidden_dim)\n    out_scales_ptr,  # (bs,)\n    hidden_states_ptr,  # (bs, hidden_dim * 2)\n    quant_max: tl.constexpr,\n    static_scale: tl.constexpr,\n    hidden_dim: tl.constexpr,  # the output hidden_dim\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n\n    input_start = pid * hidden_dim * 2\n    output_start = pid * hidden_dim\n\n    input1_offs = tl.arange(0, BLOCK_SIZE)\n    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output\n    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)\n    output_offs = tl.arange(0, BLOCK_SIZE)\n\n    x1 = tl.load(\n        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0\n    ).to(tl.float32)\n    x3 = tl.load(\n        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0\n    ).to(tl.float32)\n\n    # silu\n    # cast down before 
mul to better match training?\n    silu_x1 = x1 * tl.sigmoid(x1)\n    out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty)\n\n    if quant_max is not None:\n        raise NotImplementedError()\n\n    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)\n\n\ndef silu_and_mul_triton(\n    hidden_states,\n    scales=None,\n    quantize=None,  # dtype to quantize to\n    out=None,\n):\n    bs, in_hidden_dim = hidden_states.shape\n    hidden_dim = in_hidden_dim // 2\n\n    if out is None:\n        out_hidden_states = torch.empty(\n            (bs, hidden_dim),\n            dtype=quantize or hidden_states.dtype,\n            device=hidden_states.device,\n        )\n    else:\n        assert out.shape == (bs, hidden_dim)\n        assert out.dtype == (quantize or hidden_states.dtype)\n        out_hidden_states = out\n    out_scales = None\n    static_scale = False\n    if quantize is not None:\n        if scales is None:\n            out_scales = torch.empty(\n                (bs,), dtype=torch.float32, device=hidden_states.device\n            )\n        else:\n            out_scales = scales\n            static_scale = True\n\n    max_warps = 16 if _is_hip else 32\n    config = {\n        # 8 ele per thread (not tuned)\n        \"num_warps\": max(\n            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4\n        ),\n    }\n\n    silu_and_mul_kernel[(bs,)](\n        out_hidden_states,\n        out_scales,\n        hidden_states,\n        quant_max=torch.finfo(quantize).max if quantize is not None else None,\n        static_scale=static_scale,\n        hidden_dim=hidden_dim,\n        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),\n        **config,\n    )\n\n    if quantize is not None:\n        return out_hidden_states, out_scales\n    else:\n        return out_hidden_states, None\n"
  },
  {
    "path": "python/sglang/srt/layers/flashinfer_comm_fusion.py",
    "content": "import logging\nfrom typing import Optional, Tuple\n\nimport torch\n\nfrom sglang.srt.distributed import (\n    get_tensor_model_parallel_rank,\n    get_tensor_model_parallel_world_size,\n)\nfrom sglang.srt.utils import is_flashinfer_available\nfrom sglang.srt.utils.custom_op import register_custom_op\n\nlogger = logging.getLogger(__name__)\n\n_flashinfer_comm = None\n_workspace_manager = None\n_flashinfer_allreduce_unavailable = False\n\nif is_flashinfer_available():\n    try:\n        import flashinfer.comm as comm\n\n        if hasattr(comm, \"allreduce_fusion\") and hasattr(\n            comm, \"create_allreduce_fusion_workspace\"\n        ):\n            _flashinfer_comm = comm\n        else:\n            _flashinfer_allreduce_unavailable = True\n            logger.warning(\n                \"flashinfer.comm unified allreduce_fusion API is not available, \"\n                \"falling back to standard implementation\"\n            )\n    except ImportError:\n        _flashinfer_allreduce_unavailable = True\n        logger.warning(\n            \"flashinfer.comm is not available, falling back to standard \"\n            \"implementation\"\n        )\n\n\ndef is_flashinfer_allreduce_unavailable() -> bool:\n    return _flashinfer_allreduce_unavailable\n\n\nclass FlashInferWorkspaceManager:\n    def __init__(self):\n        self.workspace = None\n        self.world_size = None\n        self.rank = None\n        self.max_token_num = None\n        self.hidden_dim = None\n        self.dtype = None\n        self.initialized = False\n\n    def initialize(\n        self,\n        world_size: int,\n        rank: int,\n        max_token_num: int,\n        hidden_dim: int,\n        dtype: torch.dtype,\n        use_oneshot: Optional[bool] = None,\n    ):\n        \"\"\"Initialize workspace\"\"\"\n        if _flashinfer_comm is None:\n            logger.warning(\n                \"FlashInfer comm not available, skipping workspace initialization\"\n            
)\n            return\n\n        self.cleanup()\n        try:\n            self.workspace = _flashinfer_comm.create_allreduce_fusion_workspace(\n                backend=\"trtllm\",\n                world_size=world_size,\n                rank=rank,\n                max_token_num=max_token_num,\n                hidden_dim=hidden_dim,\n                dtype=dtype,\n                force_oneshot_support=bool(use_oneshot),\n            )\n        except Exception as e:\n            global _flashinfer_allreduce_unavailable\n            _flashinfer_allreduce_unavailable = True\n            logger.warning(\n                f\"Failed to initialize FlashInfer workspace: {e}. \"\n                \"Disabling flashinfer allreduce fusion permanently.\"\n            )\n            self.workspace = None\n            self.initialized = False\n            return\n\n        self.world_size = world_size\n        self.rank = rank\n        self.max_token_num = max_token_num\n        self.hidden_dim = hidden_dim\n        self.dtype = dtype\n        self.initialized = True\n\n        backend = getattr(self.workspace, \"backend\", \"unknown\")\n        logger.info(\n            f\"FlashInfer workspace initialized for rank {rank}, \"\n            f\"world_size {world_size}, backend {backend}\"\n        )\n\n    def is_buffer_size_sufficient(\n        self,\n        token_num: int,\n        hidden_dim: int,\n        dtype: torch.dtype,\n        use_oneshot: Optional[bool] = None,\n    ) -> bool:\n        if not self.initialized or self.workspace is None:\n            return False\n        try:\n            return self.workspace.is_buffer_size_sufficient(\n                tp_size=self.world_size,\n                num_tokens=token_num,\n                hidden_dim=hidden_dim,\n                dtype=dtype,\n                use_oneshot=use_oneshot,\n            )\n        except Exception as e:\n            logger.debug(f\"FlashInfer workspace size check failed: {e}\")\n            return 
False\n\n    def cleanup(self):\n        \"\"\"Clean up workspace\"\"\"\n        if self.workspace is not None:\n            try:\n                self.workspace.destroy()\n            except Exception as e:\n                logger.warning(f\"Failed to cleanup FlashInfer workspace: {e}\")\n            finally:\n                self.workspace = None\n                self.initialized = False\n                self.world_size = None\n                self.rank = None\n                self.max_token_num = None\n                self.hidden_dim = None\n                self.dtype = None\n\n\n_workspace_manager = FlashInferWorkspaceManager()\n\n\ndef ensure_workspace_initialized(\n    max_token_num: int = 2048,\n    hidden_dim: int = 4096,\n    dtype: torch.dtype = torch.float16,\n    token_num: Optional[int] = None,\n    use_oneshot: Optional[bool] = None,\n):\n    \"\"\"Ensure workspace is initialized\"\"\"\n    if _flashinfer_allreduce_unavailable:\n        return False\n\n    if not is_flashinfer_available() or _flashinfer_comm is None:\n        return False\n\n    world_size = get_tensor_model_parallel_world_size()\n    if world_size <= 1:\n        return False\n\n    rank = get_tensor_model_parallel_rank()\n    token_num = token_num or max_token_num\n\n    if (\n        not _workspace_manager.initialized\n        or _workspace_manager.world_size != world_size\n        or _workspace_manager.rank != rank\n        or not _workspace_manager.is_buffer_size_sufficient(\n            token_num=token_num,\n            hidden_dim=hidden_dim,\n            dtype=dtype,\n            use_oneshot=use_oneshot,\n        )\n    ):\n        _workspace_manager.initialize(\n            world_size=world_size,\n            rank=rank,\n            max_token_num=max_token_num,\n            hidden_dim=hidden_dim,\n            dtype=dtype,\n            use_oneshot=use_oneshot,\n        )\n\n    return _workspace_manager.initialized\n\n\ndef fake_flashinfer_allreduce_residual_rmsnorm(\n    
input_tensor: torch.Tensor,\n    residual: torch.Tensor,\n    weight: torch.Tensor,\n    eps: float = 1e-6,\n    max_token_num: int = 16384,\n    use_oneshot: Optional[bool] = None,\n    trigger_completion_at_end: bool = False,\n    fp32_acc: bool = False,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    residual_out = torch.empty_like(residual)\n    norm_out = torch.empty_like(input_tensor)\n    return norm_out, residual_out\n\n\n@register_custom_op(\n    mutates_args=[\"input_tensor\", \"residual\", \"weight\"],\n    fake_impl=fake_flashinfer_allreduce_residual_rmsnorm,\n)\ndef flashinfer_allreduce_residual_rmsnorm(\n    input_tensor: torch.Tensor,\n    residual: torch.Tensor,\n    weight: torch.Tensor,\n    eps: float = 1e-6,\n    max_token_num: int = 2048,\n    use_oneshot: Optional[bool] = None,\n    trigger_completion_at_end: bool = False,\n    fp32_acc: bool = False,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Run FlashInfer's fused allreduce + residual add + RMSNorm operation.\n\n    Args:\n        input_tensor: Input tensor to be all-reduced across the TP group\n        residual: Residual tensor added to the all-reduced input\n        weight: RMSNorm weight (gamma)\n        eps: RMSNorm epsilon\n        max_token_num: Maximum number of tokens the workspace is sized for\n        use_oneshot: Whether to use the one-shot allreduce algorithm\n        trigger_completion_at_end: Whether to trigger completion at end (currently not forwarded to FlashInfer)\n        fp32_acc: Whether to accumulate in fp32\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output), or\n        (None, None) when the fused path is unavailable and the caller should\n        fall back to the standard implementation\n    \"\"\"\n    if not is_flashinfer_available() or _flashinfer_comm is None:\n        logger.debug(\n            \"FlashInfer not available, falling back to standard implementation\"\n        )\n        return None, None\n\n    world_size = get_tensor_model_parallel_world_size()\n    if world_size <= 1:\n        logger.debug(\"Single GPU, no need for allreduce fusion\")\n        return None, None\n\n    assert input_tensor.shape[0] <= max_token_num\n    if (\n        not input_tensor.is_contiguous()\n        or not residual.is_contiguous()\n        or not weight.is_contiguous()\n    ):\n        logger.debug(\"Non-contiguous tensors, skipping FlashInfer allreduce fusion\")\n        return None, None\n\n    if not ensure_workspace_initialized(\n        max_token_num=max_token_num,\n        hidden_dim=input_tensor.shape[-1],\n        dtype=input_tensor.dtype,\n        token_num=input_tensor.shape[0],\n        use_oneshot=use_oneshot,\n    ):\n        logger.debug(\"FlashInfer workspace not available\")\n        return None, None\n\n    residual_out = torch.empty_like(residual)\n    norm_out = torch.empty_like(input_tensor)\n\n    _flashinfer_comm.allreduce_fusion(\n        input=input_tensor,\n        workspace=_workspace_manager.workspace,\n        pattern=_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,\n        launch_with_pdl=True,\n        residual_out=residual_out,\n        norm_out=norm_out,\n        residual_in=residual,\n        rms_gamma=weight,\n        rms_eps=eps,\n        use_oneshot=use_oneshot,\n        fp32_acc=fp32_acc,\n    )\n\n    return norm_out, residual_out\n\n\ndef cleanup_flashinfer_workspace():\n    global _workspace_manager\n    if _workspace_manager is not None:\n        _workspace_manager.cleanup()\n"
  },
  {
    "path": "python/sglang/srt/layers/int4fp8_utils.py",
    "content": "\"\"\"\nCommon utilities for quark.\n\"\"\"\n\nimport logging\nfrom typing import Tuple\n\nimport torch\n\nlogger = logging.getLogger(__name__)\n\n\ndef quantize_fp8_scale_tensorwise(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n    FP8_MAX = 448.0\n    scale = w.abs().amax().float() / FP8_MAX\n    scaled = (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)\n    return scaled, scale\n\n\ndef quantize_int4_scale_columnwise(\n    w: torch.Tensor,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    S4_MAX = 7\n    w_flat = w.reshape(-1, w.shape[-1]).float()\n    scale = w_flat.abs().amax(axis=-1) / S4_MAX\n    scaled = torch.round(w_flat / scale[:, None]).to(torch.int8).clamp(-S4_MAX, S4_MAX)\n    return scaled.reshape(w.shape), scale.reshape(w.shape[:-1])\n\n\ndef pack_int4_to_int32(to_pack: torch.Tensor, reorder: bool = True) -> torch.Tensor:\n    if to_pack.ndim > 2:\n        raise ValueError(\n            \"Pack: Only supports tensors with dimensions not greater than 2.\"\n        )\n\n    if reorder:\n        order_map = [0, 2, 4, 6, 1, 3, 5, 7]\n    else:\n        order_map = [0, 1, 2, 3, 4, 5, 6, 7]\n    pack_num = 8\n    if to_pack.ndim == 2:\n        packed = torch.zeros(\n            to_pack.shape[0],\n            to_pack.shape[1] // pack_num,\n            dtype=torch.int32,\n            device=to_pack.device,\n        )\n        new_c = to_pack.shape[1] // pack_num\n        for c in range(new_c):\n            for i in range(pack_num):\n                # Cast to int32 and keep only the low 4 bits: a negative value such as -3 is 0b11111101 in int8, and its set high bits would corrupt the bitwise OR\n                packed_col = to_pack[:, c * pack_num + order_map[i]].to(torch.int32)\n                packed_col = packed_col & 0x0F\n                packed[:, c] = torch.bitwise_or(\n                    packed[:, c], torch.bitwise_left_shift(packed_col, i * 4)\n                )\n    elif to_pack.ndim == 0:\n        packed = to_pack.to(torch.int32)\n    else:\n        packed = torch.zeros(\n            to_pack.shape[0] // pack_num, dtype=torch.int32, device=to_pack.device\n        )\n        new_c = to_pack.shape[0] // pack_num\n        for c in range(new_c):\n            for i in range(pack_num):\n                # Cast to int32 and keep only the low 4 bits: a negative value such as -3 is 0b11111101 in int8, and its set high bits would corrupt the bitwise OR\n                packed_col = to_pack[c * pack_num + order_map[i]].to(torch.int32)\n                packed_col = packed_col & 0x0F\n                packed[c] = torch.bitwise_or(\n                    packed[c], torch.bitwise_left_shift(packed_col, i * 4)\n                )\n\n    return packed.view(torch.uint32)\n"
  },
  {
    "path": "python/sglang/srt/layers/layernorm.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Fused operators for normalization layers.\"\"\"\n\nimport logging\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sglang.srt.batch_invariant_ops import (\n    is_batch_invariant_mode_enabled,\n    rms_norm_batch_invariant,\n)\nfrom sglang.srt.environ import envs\nfrom sglang.srt.layers.utils import MultiPlatformOp\nfrom sglang.srt.server_args import get_global_server_args\nfrom sglang.srt.utils import (\n    cpu_has_amx_support,\n    get_bool_env_var,\n    is_cpu,\n    is_cuda,\n    is_flashinfer_available,\n    is_hip,\n    is_npu,\n    is_xpu,\n)\n\n_is_cuda = is_cuda()\n_is_flashinfer_available = is_flashinfer_available()\n_is_hip = is_hip()\n_is_npu = is_npu()\n_use_aiter = get_bool_env_var(\"SGLANG_USE_AITER\") and _is_hip\n_is_cpu_amx_available = cpu_has_amx_support()\n_is_cpu = is_cpu()\n_is_xpu = is_xpu()\n_flashinfer_layernorm_available = False\n\nif _is_cuda or _is_xpu:\n    if _is_flashinfer_available:\n        try:\n            from flashinfer.norm import layernorm\n\n            _flashinfer_layernorm_available = True\n        except (ImportError, AttributeError):\n            _flashinfer_layernorm_available = False\n    else:\n        _flashinfer_layernorm_available = False\n\n    from 
sgl_kernel import (\n        fused_add_rmsnorm,\n        gemma_fused_add_rmsnorm,\n        gemma_rmsnorm,\n        rmsnorm,\n    )\n_has_vllm_rms_norm = False\nif _use_aiter:\n    from aiter import rmsnorm2d_fwd as rms_norm\n    from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm\n\n    _has_vllm_rms_norm = True  # aiter provides the rms_norm functions\nelif _is_hip:\n    try:\n        from vllm._custom_ops import fused_add_rms_norm, rms_norm\n\n        _has_vllm_rms_norm = True\n    except ImportError:\n        # Fallback: vllm not available, will use forward_native\n        _has_vllm_rms_norm = False\n\nlogger = logging.getLogger(__name__)\n\nif _is_npu:\n    import torch_npu\n\n\ndef _forward_with_allreduce_fusion(\n    norm_module,\n    x: torch.Tensor,\n    residual: Optional[torch.Tensor],\n    post_residual_addition: Optional[torch.Tensor],\n    weight: torch.Tensor,\n) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n    \"\"\"Shared allreduce-fused RMSNorm logic usable by any norm.\"\"\"\n    if residual is not None:\n        from sglang.srt.distributed import (\n            get_tensor_model_parallel_world_size,\n            tensor_model_parallel_all_reduce,\n            tensor_model_parallel_fused_allreduce_rmsnorm,\n        )\n        from sglang.srt.layers.flashinfer_comm_fusion import (\n            flashinfer_allreduce_residual_rmsnorm,\n        )\n\n        if get_tensor_model_parallel_world_size() > 1:\n            if post_residual_addition is not None:\n                residual = residual + post_residual_addition\n\n            # Prefer AITER fused AR+RMSNorm when enabled on AMD.\n            if _use_aiter:\n                fused_result = tensor_model_parallel_fused_allreduce_rmsnorm(\n                    x, residual, weight, norm_module.variance_epsilon\n                )\n                if fused_result is not None:\n                    return fused_result\n            else:\n                fused_result = 
flashinfer_allreduce_residual_rmsnorm(\n                    input_tensor=x,\n                    residual=residual,\n                    weight=weight,\n                    eps=norm_module.variance_epsilon,\n                )\n                if fused_result[0] is not None:\n                    return fused_result\n\n            # For AITER route, preserve correctness when fused path is unavailable.\n            if _use_aiter and get_global_server_args().enable_aiter_allreduce_fusion:\n                x = tensor_model_parallel_all_reduce(x)\n                return norm_module.forward(x, residual, None)\n\n    return norm_module.forward(x, residual, post_residual_addition)\n\n\nclass RMSNorm(MultiPlatformOp):\n    def __init__(\n        self,\n        hidden_size: int,\n        eps: float = 1e-6,\n        var_hidden_size: Optional[int] = None,\n        cast_x_before_out_mul: bool = False,\n        fp32_residual: bool = False,\n        has_weight: bool = True,\n        weight_dtype: Optional = None,\n        override_orig_dtype: Optional = None,\n    ) -> None:\n        super().__init__()\n        self.has_weight = has_weight\n        self.cast_x_before_out_mul = cast_x_before_out_mul\n        self.fp32_residual = fp32_residual\n        self.override_orig_dtype = override_orig_dtype\n        if self.has_weight:\n            self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype))\n        else:\n            self.weight = torch.ones(hidden_size, dtype=weight_dtype)\n        self.variance_epsilon = eps\n        self.hidden_size = hidden_size\n        self.variance_size_override = (\n            None if var_hidden_size == hidden_size else var_hidden_size\n        )\n        if _use_aiter:\n            self._forward_method = self.forward_aiter\n\n    def forward_cuda(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, 
Tuple[torch.Tensor, torch.Tensor]]:\n        if x.numel() == 0:\n            return x\n        if self.variance_size_override is not None:\n            return self.forward_native(x, residual, post_residual_addition)\n        if is_batch_invariant_mode_enabled():\n            if (\n                residual is not None\n                or get_global_server_args().rl_on_policy_target == \"fsdp\"\n            ):\n                return self.forward_native(x, residual, post_residual_addition)\n            return rms_norm_batch_invariant(\n                x,\n                self.weight.data,\n                self.variance_epsilon,\n            )\n        if residual is not None:\n            # TODO: Ideally we want to have (hidden_states+residual)+post_residual_addition.\n            # but right now we can only have hidden_states+(residual+post_residual_addition).\n            # (hidden_states+residual)+post_residual_addition != hidden_states+(residual+post_residual_addition),\n            # we probably need to add another parameter to fused_add_rmsnorm\n            if post_residual_addition is not None:\n                residual = residual + post_residual_addition\n            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)\n            return x, residual\n        out = rmsnorm(x, self.weight.data, self.variance_epsilon)\n        return out\n\n    def forward_npu(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        if residual is not None:\n            if post_residual_addition is not None:\n                residual = residual + post_residual_addition\n            out, _, residual_out = torch_npu.npu_add_rms_norm(\n                residual, x, self.weight.data, self.variance_epsilon\n            )\n            return out, residual_out\n        return 
torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]\n\n    def forward_aiter(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        if residual is not None:\n            residual_out = torch.empty_like(x)\n            output = torch.empty_like(x)\n            if post_residual_addition is not None:\n                residual = residual + post_residual_addition\n            fused_add_rms_norm(\n                output,\n                x,\n                residual,\n                residual_out,\n                self.weight.data,\n                self.variance_epsilon,\n            )\n            return output, residual_out\n        return rms_norm(x, self.weight.data, self.variance_epsilon)\n\n    def forward_hip(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        # Fallback to native implementation if vllm is not available\n        if not _has_vllm_rms_norm:\n            return self.forward_native(x, residual, post_residual_addition)\n\n        if not x.is_contiguous():\n            # NOTE: Remove this if aiter kernel supports discontinuous input\n            x = x.contiguous()\n        if residual is not None:\n            out = torch.empty_like(x)\n            residual_out = torch.empty_like(x)\n            if post_residual_addition is not None:\n                residual = residual + post_residual_addition\n            fused_add_rms_norm(\n                out, x, residual_out, residual, self.weight.data, self.variance_epsilon\n            )\n            return out, residual_out\n        out = torch.empty_like(x)\n        rms_norm(out, x, self.weight.data, self.variance_epsilon)\n        return 
out\n\n    def forward_native(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        if not x.is_contiguous():\n            x = x.contiguous()\n        orig_dtype = self.override_orig_dtype or x.dtype\n        x = x.to(torch.float32)\n        if residual is not None:\n            x = x + residual.to(torch.float32)\n            if post_residual_addition is not None:\n                x = x + post_residual_addition.to(torch.float32)\n            if self.fp32_residual:\n                residual = x.clone()\n            else:\n                residual = x.to(orig_dtype)\n\n        hidden_size = x.shape[-1]\n        if hidden_size != self.hidden_size:\n            raise ValueError(\n                \"Expected hidden_size to be \"\n                f\"{self.hidden_size}, but found: {hidden_size}\"\n            )\n\n        if self.variance_size_override is None:\n            x_var = x\n        else:\n            if hidden_size < self.variance_size_override:\n                raise ValueError(\n                    \"Expected hidden_size to be at least \"\n                    f\"{self.variance_size_override}, but found: {hidden_size}\"\n                )\n\n            x_var = x[..., : self.variance_size_override]\n\n        variance = x_var.pow(2).mean(dim=-1, keepdim=True)\n        x = x * torch.rsqrt(variance + self.variance_epsilon)\n\n        if self.cast_x_before_out_mul:\n            x = self.weight * x.to(orig_dtype)\n        else:\n            x = (x * self.weight).to(orig_dtype)\n\n        if residual is None:\n            return x\n        else:\n            return x, residual\n\n    def forward_cpu(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, 
Tuple[torch.Tensor, torch.Tensor]]:\n        if _is_cpu_amx_available:\n            if residual is not None:\n                if post_residual_addition is not None:\n                    residual = residual + post_residual_addition\n                torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(\n                    x, residual, self.weight.data, self.variance_epsilon\n                )\n                return x, residual\n            return torch.ops.sgl_kernel.rmsnorm_cpu(\n                x, self.weight.data, self.variance_epsilon\n            )\n        else:\n            return self.forward_native(x, residual, post_residual_addition)\n\n    def forward_xpu(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        if self.variance_size_override is not None:\n            return self.forward_native(x, residual, post_residual_addition)\n        if residual is not None:\n            if post_residual_addition is not None:\n                residual = residual + post_residual_addition\n            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)\n            return x, residual\n        out = rmsnorm(x, self.weight.data, self.variance_epsilon)\n        return out\n\n    def forward_with_allreduce_fusion(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"Forward with allreduce fusion, prioritizing flashinfer fused operations.\"\"\"\n        return _forward_with_allreduce_fusion(\n            self, x, residual, post_residual_addition, self.weight\n        )\n\n\nclass LayerNorm(MultiPlatformOp):\n    def __init__(\n        self,\n        hidden_size: int,\n        eps: float = 1e-6,\n        
elementwise_affine: bool = True,\n        bias: bool = True,\n        dtype: torch.dtype = torch.float32,\n    ) -> None:\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.variance_epsilon = eps\n        self.elementwise_affine = elementwise_affine\n        self.use_bias = bias\n        self.dtype = dtype\n\n        self.bias = nn.Parameter(torch.zeros(hidden_size, dtype=self.dtype))\n        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=self.dtype))\n\n    def forward_cuda(\n        self,\n        x: torch.Tensor,\n    ) -> torch.Tensor:\n        if (\n            _flashinfer_layernorm_available\n            and x.dtype == torch.bfloat16\n            and self.dtype == torch.float32\n        ):\n            return layernorm(x, self.weight, self.bias, self.variance_epsilon)\n        else:\n            return self.forward_native(x)\n\n    def forward_native(\n        self,\n        x: torch.Tensor,\n    ) -> torch.Tensor:\n        weight = self.weight if self.elementwise_affine else None\n        bias = self.bias if self.use_bias else None\n        orig_dtype = x.dtype\n        x = x.to(self.dtype)\n        return F.layer_norm(\n            x,\n            (self.hidden_size,),\n            weight=weight,\n            bias=bias,\n            eps=self.variance_epsilon,\n        ).to(orig_dtype)\n\n    def forward_hip(\n        self,\n        x: torch.Tensor,\n    ) -> torch.Tensor:\n        return self.forward_native(x)\n\n    def forward_npu(\n        self,\n        x: torch.Tensor,\n    ) -> torch.Tensor:\n        return self.forward_native(x)\n\n    def forward_cpu(\n        self,\n        x: torch.Tensor,\n    ) -> torch.Tensor:\n        if _is_cpu_amx_available:\n            bias_data = self.bias.data if self.use_bias else None\n            return torch.ops.sgl_kernel.layernorm_cpu(\n                x, self.weight.data, bias_data, self.variance_epsilon\n            )\n        else:\n            return 
self.forward_native(x)\n\n\nclass GemmaRMSNorm(MultiPlatformOp):\n    def __init__(\n        self,\n        hidden_size: int,\n        eps: float = 1e-6,\n    ) -> None:\n        super().__init__()\n        self.weight = nn.Parameter(torch.zeros(hidden_size))\n        self.variance_epsilon = eps\n        # Re-dispatch\n        if _is_hip:\n            self._forward_method = self.forward_native\n\n    def _forward_impl(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        if residual is not None:\n            if post_residual_addition is not None:\n                residual = residual + post_residual_addition\n            gemma_fused_add_rmsnorm(\n                x, residual, self.weight.data, self.variance_epsilon\n            )\n            return x, residual\n        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)\n        return out\n\n    def forward_native(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        orig_dtype = x.dtype\n        if residual is not None:\n            if post_residual_addition is not None:\n                residual = residual + post_residual_addition\n            x = x + residual\n            residual = x\n\n        x = x.float()\n        variance = x.pow(2).mean(dim=-1, keepdim=True)\n        x = x * torch.rsqrt(variance + self.variance_epsilon)\n        x = x * (1.0 + self.weight.float())\n        x = x.to(orig_dtype)\n        return x if residual is None else (x, residual)\n\n    def forward_cuda(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> 
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        return self._forward_impl(x, residual, post_residual_addition)\n\n    def forward_cpu(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        if _is_cpu_amx_available:\n            if residual is not None:\n                if post_residual_addition is not None:\n                    residual = residual + post_residual_addition\n                torch.ops.sgl_kernel.gemma_fused_add_rmsnorm_cpu(\n                    x, residual, self.weight.data, self.variance_epsilon\n                )\n                return x, residual\n            return torch.ops.sgl_kernel.gemma_rmsnorm_cpu(\n                x, self.weight.data, self.variance_epsilon\n            )\n        return self.forward_native(x, residual, post_residual_addition)\n\n    def forward_npu(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        if envs.SGLANG_NPU_FORWARD_NATIVE_GEMMA_RMS_NORM.get():\n            return self.forward_native(x, residual)\n        if residual is not None:\n            if post_residual_addition is not None:\n                residual = residual + post_residual_addition\n            x = x + residual\n            residual = x\n\n        x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)\n        return x if residual is None else (x, residual)\n\n    def forward_xpu(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        return self._forward_impl(x, residual, 
post_residual_addition)\n\n    def forward_with_allreduce_fusion(\n        self,\n        x: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        post_residual_addition: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"Forward with allreduce fusion; uses 1 + weight for fused kernels.\"\"\"\n        # TODO(brayden): we can see if TRTLLM allreduce fusion can provide gemma-style norm\n        return _forward_with_allreduce_fusion(\n            self, x, residual, post_residual_addition, self.weight + 1.0\n        )\n\n\nclass Gemma3RMSNorm(MultiPlatformOp):\n    def __init__(self, dim: int, eps: float = 1e-6):\n        super().__init__()\n        self.eps = eps\n        self.weight = nn.Parameter(torch.zeros(dim))\n        # Re-dispatch\n\n    def _norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward_native(self, x):\n        output = self._norm(x.float())\n        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)\n        # See https://github.com/huggingface/transformers/pull/29402\n        output = output * (1.0 + self.weight.float())\n        return output.type_as(x)\n\n    def forward_cpu(self, x):\n        if _is_cpu_amx_available and x.stride(-1) == 1:\n            return torch.ops.sgl_kernel.gemma3_rmsnorm_cpu(x, self.weight, self.eps)\n        return self.forward_native(x)\n\n    def forward_cuda(self, x):\n        return self.forward_native(x)\n\n    def forward_npu(self, x):\n        output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)\n        return output\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.eps}\"\n"
  },
  {
    "path": "python/sglang/srt/layers/linear.py",
    "content": "\"\"\"Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py\"\"\"\n\nfrom __future__ import annotations\n\nimport itertools\nimport logging\nfrom typing import TYPE_CHECKING, Dict, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom torch.nn.parameter import Parameter, UninitializedParameter\n\nfrom sglang.srt.distributed import (\n    divide,\n    get_tensor_model_parallel_rank,\n    get_tensor_model_parallel_world_size,\n    get_tp_group,\n    split_tensor_along_last_dim,\n    tensor_model_parallel_all_gather,\n    tensor_model_parallel_all_reduce,\n)\nfrom sglang.srt.distributed.device_communicators.pynccl_allocator import (\n    use_symmetric_memory,\n)\nfrom sglang.srt.layers.dp_attention import (\n    get_attention_tp_group,\n    is_allocation_symmetric,\n)\nfrom sglang.srt.layers.parameter import (\n    BasevLLMParameter,\n    BlockQuantScaleParameter,\n    PackedColumnParameter,\n    PackedvLLMParameter,\n    PerTensorScaleParameter,\n    RowvLLMParameter,\n    _ColumnvLLMParameter,\n)\nfrom sglang.srt.layers.utils import pad_or_narrow_weight\nfrom sglang.srt.utils import get_bool_env_var, is_cpu, is_hip, is_npu, set_weight_attrs\n\nif TYPE_CHECKING:\n    from sglang.srt.layers.quantization.base_config import (\n        QuantizationConfig,\n        QuantizeMethodBase,\n    )\n\n_is_hip = is_hip()\n_disable_hip_linear_quant = _is_hip and get_bool_env_var(\n    \"SGLANG_ROCM_DISABLE_LINEARQUANT\"\n)\n\nlogger = logging.getLogger(__name__)\n\nWEIGHT_LOADER_V2_SUPPORTED = [\n    \"CompressedTensorsLinearMethod\",\n    \"AWQMarlinLinearMethod\",\n    \"AWQLinearMethod\",\n    \"AWQLinearAscendMethod\",\n    \"GPTQMarlinLinearMethod\",\n    \"Fp8LinearMethod\",\n    \"BlockInt8LinearMethod\",\n    \"MarlinLinearMethod\",\n    \"QQQLinearMethod\",\n    \"GPTQMarlin24LinearMethod\",\n    \"TPUInt8LinearMethod\",\n    \"GPTQLinearMethod\",\n    \"FBGEMMFp8LinearMethod\",\n    
\"GPTQLinearAscendMethod\",\n    \"GPTQMoEAscendMethod\",\n    \"ModelOptFp8LinearMethod\",\n    \"ModelOptFp4LinearMethod\",\n    \"IPEXAWQLinearMethod\",\n    \"PetitNvFp4LinearMethod\",\n    \"QuarkInt4Fp8LinearMethod\",\n]\n\n_is_cpu = is_cpu()\n_is_npu = is_npu()\n\n\ndef adjust_marlin_shard(param, shard_size, shard_offset):\n    marlin_tile_size = getattr(param, \"marlin_tile_size\", None)\n    if marlin_tile_size is None:\n        return shard_size, shard_offset\n\n    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size\n\n\ndef adjust_bitsandbytes_4bit_shard(\n    param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str\n) -> Tuple[int, int]:\n    \"\"\"Adjust the quantization offsets and sizes for BitsAndBytes sharding.\"\"\"\n\n    total, _ = shard_offsets[\"total\"]\n    orig_offset, orig_size = shard_offsets[loaded_shard_id]\n\n    quantized_total = param.data.shape[0]\n    quantized_offset = orig_offset * quantized_total // total\n    quantized_size = orig_size * quantized_total // total\n\n    return quantized_size, quantized_offset\n\n\ndef adjust_scalar_to_fused_array(param, loaded_weight, shard_id):\n    \"\"\"For fused modules (QKV and MLP) we have an array of length\n    N that holds 1 scale for each \"logical\" matrix. So the param\n    is an array of length N. The loaded_weight corresponds to\n    one of the shards on disk. 
Here, we slice the param based on\n    the shard_id for loading.\n    \"\"\"\n    qkv_idxs = {\"q\": 0, \"k\": 1, \"v\": 2}\n\n    if isinstance(shard_id, str):\n        shard_id = qkv_idxs[shard_id]\n    elif not isinstance(shard_id, int):\n        raise ValueError(f\"Unknown Shard Id {shard_id}\")\n\n    # AutoFP8 scales do not have a shape\n    # compressed-tensors scales do have a shape\n    if len(loaded_weight.shape) != 0:\n        assert loaded_weight.shape[0] == 1\n        loaded_weight = loaded_weight[0]\n\n    return param[shard_id], loaded_weight\n\n\ndef adjust_shard_offsets(shard_offsets, loaded_weight, dim):\n    actual_weight_size = loaded_weight.size(dim)\n    target_weight_size = shard_offsets[-1][-1] + shard_offsets[-1][-2]\n    if actual_weight_size != target_weight_size:\n        new_shard_offsets = []\n        new_offset = 0\n        for shard_id, shard_offset, shard_size in shard_offsets:\n            actual_shard_size = actual_weight_size * shard_size // target_weight_size\n            new_shard_offsets.append((shard_id, new_offset, actual_shard_size))\n            new_offset += actual_shard_size\n        return new_shard_offsets\n    return shard_offsets\n\n\nclass LinearBase(torch.nn.Module):\n    \"\"\"Base linear layer.\n\n    Args:\n        input_size: input dimension of the linear layer.\n        output_size: output dimension of the linear layer.\n        bias: If true, add bias.\n        skip_bias_add: If true, skip adding bias but instead return it.\n        params_dtype: Data type for the parameters.\n        quant_config: Quantization configure.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        output_size: int,\n        skip_bias_add: bool = False,\n        params_dtype: Optional[torch.dtype] = None,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ):\n        super().__init__()\n\n        # Keep input parameters\n        self.input_size = input_size\n    
    self.output_size = output_size\n        self.skip_bias_add = skip_bias_add\n        if params_dtype is None:\n            params_dtype = torch.get_default_dtype()\n        self.params_dtype = params_dtype\n        self.quant_config = quant_config\n        if quant_config is None:\n            from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod\n\n            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()\n        else:\n            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        raise NotImplementedError\n\n\nclass ReplicatedLinear(LinearBase):\n    \"\"\"Replicated linear layer.\n\n    Args:\n        input_size: input dimension of the linear layer.\n        output_size: output dimension of the linear layer.\n        bias: If true, add bias.\n        skip_bias_add: If true, skip adding bias but instead return it.\n        params_dtype: Data type for the parameters.\n        quant_config: Quantization configure.\n        prefix: The name of the layer in the state dict, including all parents\n                        (e.g. 
model.layers.0.qkv_proj)\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        output_size: int,\n        bias: bool = True,\n        skip_bias_add: bool = False,\n        params_dtype: Optional[torch.dtype] = None,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ):\n        super().__init__(\n            input_size,\n            output_size,\n            skip_bias_add,\n            params_dtype,\n            quant_config,\n            prefix=prefix,\n        )\n\n        # All the linear layer supports quant method.\n        assert self.quant_method is not None\n        self.quant_method.create_weights(\n            self,\n            self.input_size,\n            [self.output_size],\n            self.input_size,\n            self.output_size,\n            self.params_dtype,\n            weight_loader=self.weight_loader,\n        )\n\n        if bias:\n            self.bias = Parameter(\n                torch.empty(self.output_size, dtype=self.params_dtype)\n            )\n            set_weight_attrs(\n                self.bias,\n                {\n                    \"output_dim\": 0,\n                    \"weight_loader\": self.weight_loader,\n                },\n            )\n        else:\n            self.register_parameter(\"bias\", None)\n\n    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):\n        # If the weight on disk does not have a shape, give it one\n        # (such scales for AutoFp8).\n        if len(loaded_weight.shape) == 0:\n            loaded_weight = loaded_weight.reshape(1)\n\n        # The per-tensor quant-scale must be 1 dimension\n        if _is_npu:\n            if param.size() != loaded_weight.size() and param.size(0) == 1:\n                if torch.allclose(loaded_weight, loaded_weight[0]):\n                    loaded_weight = loaded_weight[:1]\n                else:\n                    raise ValueError(f\"{loaded_weight} are not all 
equal\")\n\n            if param.dtype == torch.int8 or loaded_weight.dtype == torch.int8:\n                assert (\n                    param.dtype == loaded_weight.dtype\n                ), \"init para dtype and loaded weight dtype should be the same\"\n\n        assert param.size() == loaded_weight.size()\n        param.data.copy_(loaded_weight)\n\n    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        bias = self.bias if not self.skip_bias_add else None\n        assert self.quant_method is not None\n        output = self.quant_method.apply(self, x, bias)\n        output_bias = self.bias if self.skip_bias_add else None\n        return output, output_bias\n\n    def extra_repr(self) -> str:\n        s = f\"in_features={self.input_size}\"\n        s += f\", output_features={self.output_size}\"\n        s += f\", bias={self.bias is not None}\"\n        return s\n\n\nclass ColumnParallelLinear(LinearBase):\n    \"\"\"Linear layer with column parallelism.\n\n    The linear layer is defined as Y = XA + b. A is parallelized along\n    its second dimension as A = [A_1, ..., A_p].\n\n    Args:\n        input_size: first dimension of matrix A.\n        output_size: second dimension of matrix A.\n        bias: If true, add bias.\n        gather_output: If true, call all-gather on output and make Y available\n                       to all GPUs, otherwise, every GPU will have its output\n                       which is Y_i = XA_i\n        skip_bias_add: This was added to enable performance optimizations where\n                       bias can be fused with other element-wise operations. 
we\n                       skip adding bias but instead return it.\n        params_dtype: Data type for the parameters.\n        quant_config: Quantization configure.\n        output_sizes: list of output sizes packed into one output, like for QKV\n                       the list would be size 3.\n        prefix: The name of the layer in the state dict, including all parents\n                        (e.g. model.layers.0.qkv_proj)\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        output_size: int,\n        bias: bool = True,\n        gather_output: bool = False,\n        skip_bias_add: bool = False,\n        params_dtype: Optional[torch.dtype] = None,\n        quant_config: Optional[QuantizationConfig] = None,\n        output_sizes: Optional[List[int]] = None,\n        prefix: str = \"\",\n        tp_rank: Optional[int] = None,\n        tp_size: Optional[int] = None,\n        use_presharded_weights: bool = False,\n        skip_block_quant_check: bool = False,\n    ):\n        super().__init__(\n            input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix\n        )\n\n        self.gather_output = gather_output\n        self.use_presharded_weights = use_presharded_weights\n\n        # Divide the weight matrix along the last dimension.\n        if tp_rank is None:\n            tp_rank = get_tensor_model_parallel_rank()\n        if tp_size is None:\n            tp_size = get_tensor_model_parallel_world_size()\n        self.tp_rank, self.tp_size = tp_rank, tp_size\n        assert self.quant_method is not None\n        self.output_size_per_partition = divide(self.output_size, tp_size)\n        self.output_partition_sizes = [self.output_size_per_partition]\n        # If QKV or MergedColumn, use output size of each partition.\n        if hasattr(self, \"output_sizes\"):\n            self.output_partition_sizes = [\n                divide(output_size, tp_size) for output_size in self.output_sizes\n            ]\n\n  
      if output_sizes is None:\n            output_sizes = [output_size]\n\n        self.quant_method.create_weights(\n            layer=self,\n            input_size_per_partition=self.input_size,\n            output_partition_sizes=self.output_partition_sizes,\n            input_size=self.input_size,\n            output_size=self.output_size,\n            params_dtype=self.params_dtype,\n            skip_block_quant_check=skip_block_quant_check,\n            weight_loader=(\n                self.weight_loader_v2\n                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED\n                else self.weight_loader\n            ),\n        )\n        if bias:\n            self.bias = Parameter(\n                torch.zeros(self.output_size_per_partition, dtype=params_dtype)\n            )\n            set_weight_attrs(\n                self.bias,\n                {\n                    \"output_dim\": 0,\n                    \"weight_loader\": self.weight_loader,\n                },\n            )\n        else:\n            self.register_parameter(\"bias\", None)\n\n    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):\n        output_dim = getattr(param, \"output_dim\", None)\n        param_data = param.data\n\n        # Special case for GGUF\n        is_gguf_weight = getattr(param, \"is_gguf_weight\", False)\n        is_gguf_weight_type = getattr(param, \"is_gguf_weight_type\", False)\n        if is_gguf_weight_type:\n            param.weight_type = loaded_weight.item()\n\n        # Materialize GGUF UninitializedParameter\n        if is_gguf_weight and isinstance(param, UninitializedParameter):\n            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)\n\n        # bitsandbytes loads the weights of the specific portion\n        # no need to narrow here\n        use_bitsandbytes_4bit = getattr(param, \"use_bitsandbytes_4bit\", False)\n        if output_dim is not None and not 
use_bitsandbytes_4bit:\n            shard_size = param_data.shape[output_dim]\n            start_idx = self.tp_rank * shard_size\n\n            if _is_cpu:\n                from sglang.srt.model_loader.weight_utils import (\n                    narrow_padded_param_and_loaded_weight,\n                )\n\n                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(\n                    param_data,\n                    loaded_weight,\n                    0,  # param_data_start\n                    start_idx,\n                    output_dim,\n                    shard_size,\n                    not self.use_presharded_weights,\n                )\n            else:\n                if not self.use_presharded_weights:\n                    loaded_weight = loaded_weight.narrow(\n                        output_dim, start_idx, shard_size\n                    )\n\n        # Special case for loading scales off disk, which often do not\n        # have a shape (such as in the case of AutoFP8).\n        if len(loaded_weight.shape) == 0:\n            loaded_weight = loaded_weight.reshape(1)\n\n        assert param_data.shape == loaded_weight.shape\n        param_data.copy_(loaded_weight)\n\n    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):\n        # Special case for loading scales off disk, which often do not\n        # have a shape (such as in the case of AutoFP8).\n        if len(loaded_weight.shape) == 0:\n            assert loaded_weight.numel() == 1\n            loaded_weight = loaded_weight.reshape(1)\n\n        if isinstance(param, _ColumnvLLMParameter):\n            param.load_column_parallel_weight(\n                loaded_weight,\n                tp_rank=self.tp_rank,\n                use_presharded_weights=self.use_presharded_weights,\n            )\n        else:\n            # FIXME: This branch is needed to load deepseek v3 awq.\n            # However, we should fix this and avoid the branching here.\n            # 
After QuantizedRL reload, params might still need tp_rank\n            try:\n                param.load_column_parallel_weight(\n                    loaded_weight,\n                    tp_rank=self.tp_rank,\n                    use_presharded_weights=self.use_presharded_weights,\n                )\n            except TypeError:\n                # Fallback for parameters that don't accept additional args\n                param.load_column_parallel_weight(loaded_weight)\n\n    def forward(self, input_):\n        bias = self.bias if not self.skip_bias_add else None\n\n        # Matrix multiply.\n        assert self.quant_method is not None\n        output_parallel = self.quant_method.apply(self, input_, bias)\n        if self.gather_output:\n            # All-gather across the partitions.\n            output = tensor_model_parallel_all_gather(output_parallel)\n        else:\n            output = output_parallel\n        output_bias = self.bias if self.skip_bias_add else None\n        return output, output_bias\n\n    def extra_repr(self) -> str:\n        s = f\"in_features={self.input_size}\"\n        s += f\", output_features={self.output_size_per_partition}\"\n        s += f\", bias={self.bias is not None}\"\n        s += f\", tp_size={self.tp_size}\"\n        s += f\", gather_output={self.gather_output}\"\n        return s\n\n\nclass MergedColumnParallelLinear(ColumnParallelLinear):\n    \"\"\"Packed linear layers with column parallelism.\n\n    Similar to ColumnParallelLinear, but the weight matrix is concatenated\n    along the output dimension. 
When the weight matrix is loaded, the\n    different partitions are sharded separately.\n\n    Args:\n        input_size: input dimension of the linear layer.\n        output_sizes: list of output dimensions of the linear layer.\n        bias: If true, add bias.\n        gather_output: If true, call all-gather on output and make the output\n                       available to all GPUs, otherwise, every GPU will have\n                       its own output.\n        skip_bias_add: This was added to enable performance optimizations where\n                       bias can be fused with other element-wise operations. we\n                       skip adding bias but instead return it.\n        params_dtype: Data type for the parameters.\n        quant_config: Quantization configure.\n        prefix: The name of the layer in the state dict, including all parents\n                        (e.g. model.layers.0.qkv_proj)\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        output_sizes: List[int],\n        bias: bool = True,\n        gather_output: bool = False,\n        skip_bias_add: bool = False,\n        params_dtype: Optional[torch.dtype] = None,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n        tp_rank: Optional[int] = None,\n        tp_size: Optional[int] = None,\n        use_presharded_weights: bool = False,\n    ):\n        self.output_sizes = output_sizes\n        if tp_rank is None:\n            tp_rank = get_tensor_model_parallel_rank()\n        if tp_size is None:\n            tp_size = get_tensor_model_parallel_world_size()\n        self.tp_rank, self.tp_size = tp_rank, tp_size\n        assert all(output_size % tp_size == 0 for output_size in output_sizes)\n        self.use_presharded_weights = use_presharded_weights\n        super().__init__(\n            input_size=input_size,\n            output_size=sum(output_sizes),\n            bias=bias,\n            
gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            params_dtype=params_dtype,\n            quant_config=quant_config,\n            prefix=prefix,\n            tp_rank=tp_rank,\n            tp_size=tp_size,\n            use_presharded_weights=use_presharded_weights,\n        )\n        self.prefix = prefix\n\n    def weight_loader(\n        self,\n        param: Parameter,\n        loaded_weight: torch.Tensor,\n        loaded_shard_id: tuple[int, ...] | int | None = None,\n    ):\n        if isinstance(loaded_shard_id, tuple):\n            if hasattr(param, \"load_merged_column_weight\"):\n                return self.weight_loader_v2(param, loaded_weight, loaded_shard_id)\n            raise NotImplementedError(\n                \"Shard id with multiple indices is not supported in weight_loader, \"\n                \"please use weight_loader_v2 instead.\"\n            )\n\n        # Special case for GGUF\n        # initialize GGUF param after we know the quantize type\n        is_gguf_weight = getattr(param, \"is_gguf_weight\", False)\n        is_gguf_weight_type = getattr(param, \"is_gguf_weight_type\", False)\n        if is_gguf_weight_type:\n            param.data[loaded_shard_id].copy_(loaded_weight)\n            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()\n            return\n\n        if is_gguf_weight:\n            output_dim = getattr(param, \"output_dim\", None)\n            shard_size = loaded_weight.size(output_dim) // self.tp_size\n            start_idx = self.tp_rank * shard_size\n\n            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)\n\n            param.shard_id.append(loaded_shard_id)\n            param.shard_id_map[loaded_shard_id] = len(param.data_container)\n            param.data_container.append(loaded_weight)\n            return\n\n        param_data = param.data\n        output_dim = getattr(param, \"output_dim\", None)\n        # Special case for AQLM 
codebooks.\n        is_metadata = getattr(param, \"is_metadata\", False)\n        # Special case for per-tensor scale to load scalar into fused array.\n        needs_scalar_to_array = getattr(param, \"needs_scalar_to_array\", False)\n\n        if loaded_shard_id is None:\n            # Loaded weight is already fused on disk (qkv/mlp).\n            if output_dim is None:\n                if needs_scalar_to_array:\n                    param_data, loaded_weight = adjust_scalar_to_fused_array(\n                        param_data, loaded_weight, 0\n                    )\n\n                assert param_data.shape == loaded_weight.shape\n                param_data.copy_(loaded_weight)\n                return\n            current_shard_offset = 0\n            shard_offsets: List[Tuple[int, int, int]] = []\n            for i, output_size in enumerate(self.output_sizes):\n                effective_size = (\n                    output_size // self.tp_size\n                    if self.use_presharded_weights\n                    else output_size\n                )\n                shard_offsets.append((i, current_shard_offset, effective_size))\n                current_shard_offset += effective_size\n            packed_dim = getattr(param, \"packed_dim\", None)\n\n            use_bitsandbytes_4bit = getattr(param, \"use_bitsandbytes_4bit\", False)\n            if _is_cpu:\n                shard_offsets = adjust_shard_offsets(\n                    shard_offsets, loaded_weight, output_dim\n                )\n\n            for shard_id, shard_offset, shard_size in shard_offsets:\n                # Special case for Quantization.\n                # If quantized, we need to adjust the offset and size to account\n                # for the packing.\n                if packed_dim == output_dim:\n                    shard_size = shard_size // param.pack_factor\n                    shard_offset = shard_offset // param.pack_factor\n                    # Special case for Marlin.\n            
        shard_size, shard_offset = adjust_marlin_shard(\n                        param, shard_size, shard_offset\n                    )\n\n                if use_bitsandbytes_4bit:\n                    index = list(itertools.accumulate([0] + self.output_sizes))\n                    orig_offsets = {\n                        str(i): (index[i], size)\n                        for i, size in enumerate(self.output_sizes)\n                    }\n                    orig_offsets[\"total\"] = (self.output_size, 0)\n                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(\n                        param, orig_offsets, str(shard_id)\n                    )\n\n                loaded_weight_shard = loaded_weight.narrow(\n                    output_dim, shard_offset, shard_size\n                )\n                self.weight_loader(param, loaded_weight_shard, shard_id)\n            return\n\n        assert loaded_shard_id < len(self.output_sizes)\n        if output_dim is not None:\n            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size\n            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size\n            # Special case for quantization.\n            # If quantized, we need to adjust the offset and size to account\n            # for the packing.\n            packed_dim = getattr(param, \"packed_dim\", None)\n            if packed_dim == output_dim:\n                shard_size = shard_size // param.pack_factor\n                shard_offset = shard_offset // param.pack_factor\n                # Special case for Marlin.\n                shard_size, shard_offset = adjust_marlin_shard(\n                    param, shard_size, shard_offset\n                )\n\n            use_bitsandbytes_4bit = getattr(param, \"use_bitsandbytes_4bit\", False)\n            if use_bitsandbytes_4bit:\n                shard_size = loaded_weight.shape[output_dim]\n                shard_offset = loaded_weight.shape[output_dim] * 
loaded_shard_id\n\n            param_data = param_data.narrow(output_dim, shard_offset, shard_size)\n            start_idx = self.tp_rank * shard_size\n\n            if _is_cpu:\n                from sglang.srt.model_loader.weight_utils import (\n                    narrow_padded_param_and_loaded_weight,\n                )\n\n                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(\n                    param_data,\n                    loaded_weight,\n                    0,  # param_data_start\n                    start_idx,\n                    output_dim,\n                    shard_size,\n                    not use_bitsandbytes_4bit and not self.use_presharded_weights,\n                )\n            else:\n                # bitsandbytes loads the weights of the specific portion\n                # no need to narrow here\n                if not use_bitsandbytes_4bit and not self.use_presharded_weights:\n                    # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned\n                    end_idx = start_idx + shard_size\n                    if end_idx > loaded_weight.shape[output_dim]:\n                        loaded_weight = pad_or_narrow_weight(\n                            loaded_weight, output_dim, start_idx, shard_size\n                        )\n                    else:\n                        loaded_weight = loaded_weight.narrow(\n                            output_dim, start_idx, shard_size\n                        )\n\n        # Special case for AQLM codebooks.\n        elif is_metadata:\n            # metadata indicates fixed size concatenated along dim 0\n            shard_size = loaded_weight.shape[0]\n            shard_offset = loaded_shard_id * shard_size\n            param_data = param_data.narrow(0, shard_offset, shard_size)\n\n        # Special case for per-tensor scales in fused case.\n        elif needs_scalar_to_array:\n            param_data, loaded_weight = adjust_scalar_to_fused_array(\n 
               param_data, loaded_weight, loaded_shard_id\n            )\n\n        else:\n            ignore_warning = getattr(param, \"ignore_warning\", False)\n            if not ignore_warning:\n                logger.warning(\n                    \"Loading a weight without `output_dim` attribute in \"\n                    \"MergedColumnParallelLinear, assume the weight is \"\n                    \"the same for all partitions.\"\n                )\n\n        assert param_data.shape == loaded_weight.shape\n        param_data.copy_(loaded_weight)\n\n    def _load_fused_module_from_checkpoint(\n        self,\n        param: BasevLLMParameter,\n        loaded_weight: torch.Tensor,\n        output_sizes: list[int] | None = None,\n    ):\n        \"\"\"\n        Handle special case for models where MLP layers are already\n        fused on disk. In this case, we have no shard id. This function\n        determines the shard id by splitting these layers and then calls\n        the weight loader using the shard id.\n\n        An example of a model with these fused layers:\n        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct\n        \"\"\"\n\n        current_shard_offset = 0\n        shard_offsets: List[Tuple[int, int, int]] = []\n        output_sizes = output_sizes or self.output_sizes\n        for i, output_size in enumerate(output_sizes):\n            shard_offsets.append((i, current_shard_offset, output_size))\n            current_shard_offset += output_size\n\n        for shard_id, shard_offset, shard_size in shard_offsets:\n            # Special case for Quantization.\n            # If quantized, we need to adjust the offset and size to account\n            # for the packing.\n            if (\n                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))\n                and param.packed_dim == param.output_dim\n            ):\n                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(\n                    
shard_size=shard_size, shard_offset=shard_offset\n                )\n\n            loaded_weight_shard = loaded_weight.narrow(\n                param.output_dim, shard_offset, shard_size\n            )\n            self.weight_loader_v2(param, loaded_weight_shard, shard_id)\n\n    def _load_merged_block_scale(\n        self, param: BasevLLMParameter, loaded_weight: torch.Tensor\n    ):\n        \"\"\"\n        Handle block-wise scale loading for MergedColumnParallelLinear.\n        Similar to QKVParallelLinear._load_qkv_block_scale, but for merged column layers.\n        \"\"\"\n        weight_block_size = self.quant_method.quant_config.weight_block_size\n        block_n, _ = weight_block_size[0], weight_block_size[1]\n        block_n = 1 if getattr(param, \"format_ue8m0\", False) else block_n\n\n        # Calculate block sizes for each shard\n        shard_block_sizes = []\n        shard_block_offsets = []\n        current_block_offset = 0\n        for output_size in self.output_sizes:\n            shard_block_size = (output_size + block_n - 1) // block_n\n            shard_block_sizes.append(shard_block_size)\n            shard_block_offsets.append(current_block_offset)\n            current_block_offset += shard_block_size\n\n        # Load each shard\n        for shard_id, (shard_block_offset, shard_block_size) in enumerate(\n            zip(shard_block_offsets, shard_block_sizes)\n        ):\n            # Extract the shard from loaded_weight\n            loaded_weight_shard = loaded_weight.narrow(\n                param.output_dim, shard_block_offset, shard_block_size\n            )\n\n            # Calculate per-rank offset and size (considering TP)\n            rank_shard_offset = shard_block_offset // self.tp_size\n            rank_shard_size = shard_block_size // self.tp_size\n\n            # Load into the parameter\n            param.load_merged_column_weight(\n                loaded_weight=loaded_weight_shard,\n                shard_id=shard_id,\n        
        shard_offset=rank_shard_offset,\n                shard_size=rank_shard_size,\n                tp_rank=self.tp_rank,\n                tp_size=self.tp_size,\n                use_presharded_weights=self.use_presharded_weights,\n            )\n\n    def weight_loader_v2(\n        self,\n        param: BasevLLMParameter,\n        loaded_weight: torch.Tensor,\n        loaded_shard_id: tuple[int, ...] | int | None = None,\n    ):\n        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):\n            if isinstance(param, PerTensorScaleParameter):\n                param.load_merged_column_weight(\n                    loaded_weight=loaded_weight,\n                    shard_id=0,\n                    tp_rank=self.tp_rank,\n                    tp_size=self.tp_size,\n                )\n                return\n            elif isinstance(param, BlockQuantScaleParameter):\n                self._load_merged_block_scale(param, loaded_weight)\n                return\n            elif type(param) in (RowvLLMParameter, BasevLLMParameter):\n                param.load_merged_column_weight(\n                    loaded_weight=loaded_weight,\n                    tp_rank=self.tp_rank,\n                    tp_size=self.tp_size,\n                )\n                return\n            output_sizes = (\n                [self.output_sizes[idx] for idx in loaded_shard_id]\n                if loaded_shard_id\n                else None\n            )\n            # TODO: @dsikka - move to parameter.py\n            self._load_fused_module_from_checkpoint(\n                param, loaded_weight, output_sizes=output_sizes\n            )\n            return\n\n        assert loaded_shard_id < len(self.output_sizes)\n\n        if isinstance(param, BlockQuantScaleParameter):\n            weight_block_size = self.quant_method.quant_config.weight_block_size\n            raw_block_n, _ = weight_block_size[0], weight_block_size[1]\n            block_n = 1 if getattr(param, 
\"format_ue8m0\", False) else raw_block_n\n            shard_offset = (\n                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n\n            ) // self.tp_size\n            shard_size = (\n                (self.output_sizes[loaded_shard_id] + block_n - 1)\n                // block_n\n                // self.tp_size\n            )\n        else:\n            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size\n            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size\n\n        param.load_merged_column_weight(\n            loaded_weight=loaded_weight,\n            shard_id=loaded_shard_id,\n            shard_offset=shard_offset,\n            shard_size=shard_size,\n            use_presharded_weights=self.use_presharded_weights,\n            tp_rank=self.tp_rank,\n            tp_size=self.tp_size,\n        )\n\n\nclass QKVParallelLinear(ColumnParallelLinear):\n    \"\"\"Linear layers for the attention's QKV transformation.\n\n    Linear layers for the linear transformation of the query, key, and value\n    vectors in the attention layer. The weight matrix is concatenated along\n    the output dimension. The layer is parallelized along the head dimension.\n    When the number of key/value heads is smaller than the number of query\n    heads (e.g., multi-query/grouped-query attention), the key/value head may\n    be replicated while the query heads are partitioned.\n\n    Args:\n        hidden_size: input hidden state size of the transformer.\n        head_size: size of each attention head.\n        total_num_heads: total number of attention query heads.\n        total_num_kv_heads: total number of attention key/value heads. If\n                            None, assume total_num_kv_heads = total_num_heads.\n        bias: If true, add bias.\n        skip_bias_add: This was added to enable performance optimizations where\n                       bias can be fused with other element-wise operations. 
We\n                       skip adding bias but instead return it.\n        params_dtype: Data type for the parameters.\n        quant_config: Quantization configuration.\n        prefix: The name of the layer in the state dict, including all parents\n                        (e.g. model.layers.0.qkv_proj)\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        head_size: int,\n        total_num_heads: int,\n        total_num_kv_heads: Optional[int] = None,\n        bias: bool = True,\n        skip_bias_add: bool = False,\n        params_dtype: Optional[torch.dtype] = None,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n        tp_rank: Optional[int] = None,\n        tp_size: Optional[int] = None,\n        load_presharded_attn: bool = False,\n        v_head_size: Optional[int] = None,\n        skip_block_quant_check: bool = False,\n    ):\n        self.hidden_size = hidden_size\n        self.head_size = head_size\n        self.v_head_size = v_head_size if v_head_size is not None else head_size\n        self.total_num_heads = total_num_heads\n        if total_num_kv_heads is None:\n            total_num_kv_heads = total_num_heads\n        self.total_num_kv_heads = total_num_kv_heads\n        # Divide the weight matrix along the last dimension.\n        if tp_rank is None:\n            tp_rank = get_tensor_model_parallel_rank()\n        if tp_size is None:\n            tp_size = get_tensor_model_parallel_world_size()\n        self.tp_rank, self.tp_size = tp_rank, tp_size\n        self.num_heads = divide(self.total_num_heads, tp_size)\n        if tp_size >= self.total_num_kv_heads:\n            self.num_kv_heads = 1\n            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)\n        else:\n            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)\n            self.num_kv_head_replicas = 1\n        self.q_proj_shard_size = self.num_heads * self.head_size\n        
self.kv_proj_shard_size = self.num_kv_heads * self.head_size\n        self.v_proj_shard_size = self.num_kv_heads * self.v_head_size\n        input_size = self.hidden_size\n        output_size = (\n            self.num_heads * self.head_size\n            + self.num_kv_heads * self.head_size\n            + self.num_kv_heads * self.v_head_size\n        ) * tp_size\n        self.output_sizes = [\n            self.num_heads * self.head_size * tp_size,  # q_proj\n            self.num_kv_heads * self.head_size * tp_size,  # k_proj\n            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj\n        ]\n        self.use_presharded_weights = load_presharded_attn\n        quant_config = None if _disable_hip_linear_quant else quant_config\n\n        super().__init__(\n            input_size=input_size,\n            output_size=output_size,\n            bias=bias,\n            gather_output=False,\n            skip_bias_add=skip_bias_add,\n            params_dtype=params_dtype,\n            quant_config=quant_config,\n            prefix=prefix,\n            tp_rank=tp_rank,\n            tp_size=tp_size,\n            use_presharded_weights=self.use_presharded_weights,\n            skip_block_quant_check=skip_block_quant_check,\n        )\n\n    def _get_shard_offset_mapping(self, loaded_shard_id: str):\n        shard_offset_mapping = {\n            \"q\": 0,\n            \"k\": self.num_heads * self.head_size,\n            \"v\": (self.num_heads + self.num_kv_heads) * self.head_size,\n            \"total\": (self.num_heads + self.num_kv_heads) * self.head_size\n            + self.num_kv_heads * self.v_head_size,\n        }\n        return shard_offset_mapping.get(loaded_shard_id)\n\n    def _get_shard_size_mapping(self, loaded_shard_id: str):\n        shard_size_mapping = {\n            \"q\": self.num_heads * self.head_size,\n            \"k\": self.num_kv_heads * self.head_size,\n            \"v\": self.num_kv_heads * self.v_head_size,\n        }\n        return 
shard_size_mapping.get(loaded_shard_id)\n\n    def _load_fused_module_from_checkpoint(\n        self, param: BasevLLMParameter, loaded_weight: torch.Tensor\n    ):\n        \"\"\"\n        Handle special case for models where QKV layers are already\n        fused on disk. In this case, we have no shard id. This function\n        determines the shard id by splitting these layers and then calls\n        the weight loader using the shard id.\n\n        An example of a model with these fused layers:\n        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct\n        \"\"\"\n        shard_offsets = [\n            # (shard_id, shard_offset, shard_size)\n            (\"q\", 0, self.total_num_heads * self.head_size),\n            (\n                \"k\",\n                self.total_num_heads * self.head_size,\n                self.total_num_kv_heads * self.head_size,\n            ),\n            (\n                \"v\",\n                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,\n                self.total_num_kv_heads * self.v_head_size,\n            ),\n        ]\n\n        for shard_id, shard_offset, shard_size in shard_offsets:\n            # Special case for Quantization.\n            # If quantized, we need to adjust the offset and size to account\n            # for the packing.\n            if (\n                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))\n                and param.packed_dim == param.output_dim\n            ):\n                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(\n                    shard_size=shard_size, shard_offset=shard_offset\n                )\n\n            if not self.use_presharded_weights:\n                loaded_weight_shard = loaded_weight.narrow(\n                    param.output_dim, shard_offset, shard_size\n                )\n            self.weight_loader_v2(param, loaded_weight_shard, shard_id)\n\n    def _load_qkv_block_scale(\n        self, 
param: BasevLLMParameter, loaded_weight: torch.Tensor\n    ):\n        block_n, _ = self.quant_method.quant_config.weight_block_size\n        q_size = self.total_num_heads * self.head_size // block_n\n        k_size = self.total_num_kv_heads * self.head_size // block_n\n        v_size = self.total_num_kv_heads * self.v_head_size // block_n\n        shard_offsets = [\n            # (shard_id, shard_offset, shard_size)\n            (\"q\", 0, q_size),\n            (\"k\", q_size, k_size),\n            (\"v\", q_size + k_size, v_size),\n        ]\n        for shard_id, shard_offset, shard_size in shard_offsets:\n            loaded_weight_shard = loaded_weight.narrow(\n                param.output_dim, shard_offset, shard_size\n            )\n            rank_shard_offset = self._get_shard_offset_mapping(shard_id) // block_n\n            rank_shard_size = self._get_shard_size_mapping(shard_id) // block_n\n            param.load_qkv_weight(\n                loaded_weight=loaded_weight_shard,\n                num_heads=self.num_kv_head_replicas,\n                shard_id=shard_id,\n                shard_offset=rank_shard_offset,\n                shard_size=rank_shard_size,\n                tp_rank=self.tp_rank,\n                use_presharded_weights=self.use_presharded_weights,\n            )\n\n    def weight_loader_v2(\n        self,\n        param: BasevLLMParameter,\n        loaded_weight: torch.Tensor,\n        loaded_shard_id: Optional[str] = None,\n    ):\n        if loaded_shard_id is None:  # special case for certain models\n            if isinstance(param, PerTensorScaleParameter):\n                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)\n                return\n            elif type(param) in (RowvLLMParameter, BasevLLMParameter):\n                param.load_qkv_weight(loaded_weight=loaded_weight)\n                return\n            elif isinstance(param, BlockQuantScaleParameter):\n                self._load_qkv_block_scale(param, 
loaded_weight)\n                return\n            # TODO: @dsikka - move to parameter.py\n            self._load_fused_module_from_checkpoint(param, loaded_weight)\n            return\n\n        assert loaded_shard_id in [\"q\", \"k\", \"v\"]\n\n        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)\n        shard_size = self._get_shard_size_mapping(loaded_shard_id)\n\n        if isinstance(param, BlockQuantScaleParameter):\n            weight_block_size = self.quant_method.quant_config.weight_block_size\n            raw_block_n, _ = weight_block_size[0], weight_block_size[1]\n            block_n = 1 if getattr(param, \"format_ue8m0\", False) else raw_block_n\n            shard_offset = (shard_offset + block_n - 1) // block_n\n            shard_size = (shard_size + block_n - 1) // block_n\n\n        param.load_qkv_weight(\n            loaded_weight=loaded_weight,\n            num_heads=self.num_kv_head_replicas,\n            shard_id=loaded_shard_id,\n            shard_offset=shard_offset,\n            shard_size=shard_size,\n            tp_rank=self.tp_rank,\n            use_presharded_weights=self.use_presharded_weights,\n        )\n\n    def weight_loader(\n        self,\n        param: Parameter,\n        loaded_weight: torch.Tensor,\n        loaded_shard_id: Optional[str] = None,\n    ):\n\n        # Special case for GGUF\n        # initialize GGUF param after we know the quantize type\n        is_gguf_weight = getattr(param, \"is_gguf_weight\", False)\n        is_gguf_weight_type = getattr(param, \"is_gguf_weight_type\", False)\n        if is_gguf_weight_type and loaded_shard_id is not None:\n            idx_map = {\"q\": 0, \"k\": 1, \"v\": 2}\n            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)\n            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()\n            return\n\n        if is_gguf_weight:\n            output_dim = getattr(param, \"output_dim\", None)\n            shard_size = 
loaded_weight.size(output_dim) // self.tp_size\n            start_idx = self.tp_rank * shard_size\n\n            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)\n\n            param.shard_id.append(loaded_shard_id)\n            param.shard_id_map[loaded_shard_id] = len(param.data_container)\n            param.data_container.append(loaded_weight)\n            return\n\n        param_data = param.data\n        output_dim = getattr(param, \"output_dim\", None)\n        # Special case for AQLM codebooks.\n        is_metadata = getattr(param, \"is_metadata\", False)\n\n        # Special case for per-tensor scales in fused case.\n        needs_scalar_to_array = getattr(param, \"needs_scalar_to_array\", False)\n\n        if loaded_shard_id is None:\n            # Loaded weight is already fused on disk (qkv/mlp).\n            if output_dim is None:\n                if needs_scalar_to_array:\n                    param_data, loaded_weight = adjust_scalar_to_fused_array(\n                        param_data, loaded_weight, 0\n                    )\n\n                assert param_data.shape == loaded_weight.shape\n                param_data.copy_(loaded_weight)\n                return\n            shard_offsets = [\n                # (shard_id, shard_offset, shard_size)\n                (\"q\", 0, self.total_num_heads * self.head_size),\n                (\n                    \"k\",\n                    self.total_num_heads * self.head_size,\n                    self.total_num_kv_heads * self.head_size,\n                ),\n                (\n                    \"v\",\n                    (self.total_num_heads + self.total_num_kv_heads) * self.head_size,\n                    self.total_num_kv_heads * self.v_head_size,\n                ),\n            ]\n            use_bitsandbytes_4bit = getattr(param, \"use_bitsandbytes_4bit\", False)\n\n            packed_dim = getattr(param, \"packed_dim\", None)\n            if _is_cpu:\n                
shard_offsets = adjust_shard_offsets(\n                    shard_offsets, loaded_weight, output_dim\n                )\n\n            for shard_id, shard_offset, shard_size in shard_offsets:\n                # Special case for Quantized Weights.\n                # If quantized, we need to adjust the offset and size to account\n                # for the packing.\n                if packed_dim == output_dim:\n                    shard_size = shard_size // param.pack_factor\n                    shard_offset = shard_offset // param.pack_factor\n\n                    # Special case for Marlin.\n                    shard_size, shard_offset = adjust_marlin_shard(\n                        param, shard_size, shard_offset\n                    )\n\n                if use_bitsandbytes_4bit:\n                    orig_qkv_offsets = {\n                        \"q\": (0, self.total_num_heads * self.head_size),\n                        \"k\": (\n                            self.total_num_heads * self.head_size,\n                            self.total_num_kv_heads * self.head_size,\n                        ),\n                        \"v\": (\n                            (self.total_num_heads + self.total_num_kv_heads)\n                            * self.head_size,\n                            self.total_num_kv_heads * self.v_head_size,\n                        ),\n                        \"total\": (\n                            (self.total_num_heads + self.total_num_kv_heads)\n                            * self.head_size\n                            + self.total_num_kv_heads * self.v_head_size,\n                            0,\n                        ),\n                    }\n\n                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(\n                        param, orig_qkv_offsets, shard_id\n                    )\n\n                if not self.use_presharded_weights:\n                    loaded_weight_shard = loaded_weight.narrow(\n                        
output_dim, shard_offset, shard_size\n                    )\n                self.weight_loader(param, loaded_weight_shard, shard_id)\n            return\n\n        assert loaded_shard_id in [\"q\", \"k\", \"v\"]\n\n        # If output dim is defined, use the default loading process.\n        if output_dim is not None:\n            if loaded_shard_id == \"q\":\n                shard_offset = 0\n                shard_size = self.num_heads * self.head_size\n            elif loaded_shard_id == \"k\":\n                shard_offset = self.num_heads * self.head_size\n                shard_size = self.num_kv_heads * self.head_size\n            elif loaded_shard_id == \"v\":\n                shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size\n                shard_size = self.num_kv_heads * self.v_head_size\n            # Special case for Quantized Weights.\n            # If quantized, we need to adjust the offset and size to account\n            # for the packing.\n            packed_dim = getattr(param, \"packed_dim\", None)\n            if packed_dim == output_dim:\n                shard_size = shard_size // param.pack_factor\n                shard_offset = shard_offset // param.pack_factor\n\n                # Special case for Marlin.\n                shard_size, shard_offset = adjust_marlin_shard(\n                    param, shard_size, shard_offset\n                )\n\n            use_bitsandbytes_4bit = getattr(param, \"use_bitsandbytes_4bit\", False)\n            if use_bitsandbytes_4bit:\n                orig_qkv_offsets = {\n                    \"q\": (0, self.num_heads * self.head_size),\n                    \"k\": (\n                        self.num_heads * self.head_size,\n                        self.num_kv_heads * self.head_size,\n                    ),\n                    \"v\": (\n                        (self.num_heads + self.num_kv_heads) * self.head_size,\n                        self.num_kv_heads * self.v_head_size,\n                
    ),\n                    \"total\": (\n                        (self.num_heads + self.num_kv_heads) * self.head_size\n                        + self.num_kv_heads * self.v_head_size,\n                        0,\n                    ),\n                }\n                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(\n                    param, orig_qkv_offsets, loaded_shard_id\n                )\n\n            param_data = param_data.narrow(output_dim, shard_offset, shard_size)\n            if loaded_shard_id == \"q\":\n                shard_id = self.tp_rank\n            else:\n                shard_id = self.tp_rank // self.num_kv_head_replicas\n            start_idx = shard_id * shard_size\n\n            if _is_cpu:\n                from sglang.srt.model_loader.weight_utils import (\n                    narrow_padded_param_and_loaded_weight,\n                )\n\n                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(\n                    param_data,\n                    loaded_weight,\n                    0,  # param_data_start\n                    start_idx,\n                    output_dim,\n                    shard_size,\n                    not use_bitsandbytes_4bit and not self.use_presharded_weights,\n                )\n            else:\n                # bitsandbytes loads the weights of the specific portion\n                # no need to narrow here\n                if not use_bitsandbytes_4bit and not self.use_presharded_weights:\n                    loaded_weight = loaded_weight.narrow(\n                        output_dim, start_idx, shard_size\n                    )\n\n        # Special case for AQLM codebooks.\n        elif is_metadata:\n            # metadata indicates fixed size concatenated along dim 0\n            shard_size = loaded_weight.shape[0]\n            shard_index = [\"q\", \"k\", \"v\"].index(loaded_shard_id)\n            param_data = param_data.narrow(0, shard_index * shard_size, 
shard_size)\n        # Special case for per-tensor scales in fused case.\n        elif needs_scalar_to_array:\n            param_data, loaded_weight = adjust_scalar_to_fused_array(\n                param_data, loaded_weight, loaded_shard_id\n            )\n        else:\n            ignore_warning = getattr(param, \"ignore_warning\", False)\n            if not ignore_warning:\n                logger.warning(\n                    \"Loading a weight without `output_dim` attribute in \"\n                    \"QKVParallelLinear, assume the weight is the same \"\n                    \"for all partitions.\"\n                )\n\n        assert (\n            param_data.shape == loaded_weight.shape\n        ), f\"{param_data.shape=} {loaded_weight.shape=}\"\n        param_data.copy_(loaded_weight)\n\n\nclass RowParallelLinear(LinearBase):\n    \"\"\"Linear layer with row parallelism.\n\n    The linear layer is defined as Y = XA + b. A is parallelized along\n    its first dimension and X along its second dimension as:\n               -   -\n              | A_1 |\n              | .   |\n          A = | .   |        X = [X_1, ..., X_p]\n              | .   |\n              | A_p |\n               -   -\n    Arguments:\n        input_size: first dimension of matrix A.\n        output_size: second dimension of matrix A.\n        bias: If true, add bias. 
Note that bias is not parallelized.\n        input_is_parallel: If true, we assume that the input is already\n                           split across the GPUs and we do not split\n                           again.\n        skip_bias_add: This was added to enable performance optimization where\n                       bias can be fused with other element-wise operations.\n                       We skip adding bias but instead return it.\n        params_dtype: Data type for the parameters.\n        quant_config: Quantization configuration.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        output_size: int,\n        bias: bool = True,\n        input_is_parallel: bool = True,\n        skip_bias_add: bool = False,\n        params_dtype: Optional[torch.dtype] = None,\n        reduce_results: bool = True,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n        tp_rank: Optional[int] = None,\n        tp_size: Optional[int] = None,\n        use_presharded_weights: bool = False,\n        use_dp_attention_reduce: bool = False,\n    ):\n        quant_config = None if _disable_hip_linear_quant else quant_config\n        super().__init__(\n            input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix\n        )\n\n        self.input_is_parallel = input_is_parallel\n        self.reduce_results = reduce_results\n        self.use_dp_attention_reduce = use_dp_attention_reduce\n\n        # Divide the weight matrix along the last dimension.\n        if tp_rank is None:\n            tp_rank = get_tensor_model_parallel_rank()\n        if tp_size is None:\n            tp_size = get_tensor_model_parallel_world_size()\n        self.tp_rank, self.tp_size = tp_rank, tp_size\n        self.input_size_per_partition = divide(input_size, self.tp_size)\n        assert self.quant_method is not None\n        self.use_presharded_weights = use_presharded_weights\n\n        self.quant_method.create_weights(\n 
           layer=self,\n            input_size_per_partition=self.input_size_per_partition,\n            output_partition_sizes=[self.output_size],\n            input_size=self.input_size,\n            output_size=self.output_size,\n            params_dtype=self.params_dtype,\n            weight_loader=(\n                self.weight_loader_v2\n                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED\n                else self.weight_loader\n            ),\n        )\n\n        if bias:\n            self.bias = Parameter(torch.zeros(self.output_size, dtype=params_dtype))\n            set_weight_attrs(\n                self.bias,\n                {\n                    \"output_dim\": 0,\n                    \"weight_loader\": self.weight_loader,\n                },\n            )\n        else:\n            self.register_parameter(\"bias\", None)\n\n    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):\n        input_dim = getattr(param, \"input_dim\", None)\n        use_bitsandbytes_4bit = getattr(param, \"use_bitsandbytes_4bit\", False)\n\n        # Special case for GGUF\n        is_gguf_weight = getattr(param, \"is_gguf_weight\", False)\n        is_gguf_weight_type = getattr(param, \"is_gguf_weight_type\", False)\n        if is_gguf_weight_type:\n            param.weight_type = loaded_weight.item()\n\n        # Materialize GGUF UninitializedParameter\n        if is_gguf_weight and isinstance(param, UninitializedParameter):\n            weight_shape = list(loaded_weight.shape)\n            if input_dim:\n                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size\n            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)\n\n        param_data = param.data\n        # bitsandbytes loads the weights of the specific portion\n        # no need to narrow here\n        if (\n            input_dim is not None\n            and not use_bitsandbytes_4bit\n            and not 
self.use_presharded_weights\n        ):\n            shard_size = param_data.shape[input_dim]\n            start_idx = self.tp_rank * shard_size\n\n            if _is_cpu:\n                from sglang.srt.model_loader.weight_utils import (\n                    narrow_padded_param_and_loaded_weight,\n                )\n\n                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(\n                    param_data,\n                    loaded_weight,\n                    0,  # param_data_start\n                    start_idx,\n                    input_dim,\n                    shard_size,\n                )\n            else:\n                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned\n                end_idx = start_idx + shard_size\n                if end_idx > loaded_weight.shape[input_dim]:\n                    loaded_weight = pad_or_narrow_weight(\n                        loaded_weight, input_dim, start_idx, shard_size\n                    )\n                else:\n                    loaded_weight = loaded_weight.narrow(\n                        input_dim, start_idx, shard_size\n                    )\n\n        # Special case for loading scales off disk, which often do not\n        # have a shape (such as in the case of AutoFP8).\n        if len(loaded_weight.shape) == 0:\n            loaded_weight = loaded_weight.reshape(1)\n\n        assert (\n            param_data.shape == loaded_weight.shape\n        ), f\"{param_data.shape=} {loaded_weight.shape=}\"\n        param_data.copy_(loaded_weight)\n\n    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):\n\n        # Special case for loading scales off disk, which often do not\n        # have a shape (such as in the case of AutoFP8).\n        if len(loaded_weight.shape) == 0:\n            assert loaded_weight.numel() == 1\n            loaded_weight = loaded_weight.reshape(1)\n\n        if isinstance(param, RowvLLMParameter):\n       
     # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,\n            # It supports additional parameters like tp_rank and use_presharded_weights.\n            param.load_row_parallel_weight(\n                loaded_weight,\n                tp_rank=self.tp_rank,\n                use_presharded_weights=self.use_presharded_weights,\n            )\n        else:\n            # `params` is defined in `vllm/model_executor/parameter.py`,\n            # It does not support additional parameters.\n            # However, after QuantizedRL reload, params might still need tp_rank\n            try:\n                param.load_row_parallel_weight(\n                    loaded_weight,\n                    tp_rank=self.tp_rank,\n                    use_presharded_weights=self.use_presharded_weights,\n                )\n            except TypeError:\n                # Fallback for parameters that don't accept additional args\n                param.load_row_parallel_weight(loaded_weight)\n\n    def forward(self, input_, skip_all_reduce=False):\n        if self.input_is_parallel:\n            input_parallel = input_\n        else:\n            splitted_input = split_tensor_along_last_dim(\n                input_, num_partitions=self.tp_size\n            )\n            input_parallel = splitted_input[self.tp_rank].contiguous()\n\n        # Matrix multiply.\n        assert self.quant_method is not None\n        # Only fuse bias add into GEMM for rank 0 (this ensures that\n        # bias will not get added more than once in TP>1 case)\n        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias\n        with use_symmetric_memory(\n            get_tp_group(), disabled=not is_allocation_symmetric()\n        ):\n            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)\n\n        if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:\n            if self.use_dp_attention_reduce:\n                output = 
get_attention_tp_group().all_reduce(output_parallel)\n            else:\n                output = tensor_model_parallel_all_reduce(output_parallel)\n        else:\n            output = output_parallel\n\n        output_bias = self.bias if self.skip_bias_add else None\n\n        return output, output_bias\n\n    def extra_repr(self) -> str:\n        s = f\"input_features={self.input_size_per_partition}\"\n        s += f\", output_features={self.output_size}\"\n        s += f\", bias={self.bias is not None}\"\n        s += f\", tp_size={self.tp_size}\"\n        s += f\", reduce_results={self.reduce_results}\"\n        return s\n\n\nclass MergedColumnParallelRepeatedLinear(LinearBase):\n    \"\"\"Merged column parallel linear and repeated linear layer.\n\n    TODO: quantization is not supported yet.\n    Args:\n        input_size: input dimension of the linear layer.\n        column_output_sizes: output dimension of the column linear layers.\n        repeated_output_sizes: output dimension of the repeated linear layers.\n        skip_bias_add: If true, skip adding bias but instead return it.\n        params_dtype: Data type for the parameters.\n        quant_config: Quantization configuration.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        column_output_sizes: List[int],\n        repeated_output_sizes: List[int],\n        skip_bias_add: bool = False,\n        params_dtype: Optional[torch.dtype] = None,\n        quant_config: Optional[QuantizationConfig] = None,\n        prefix: str = \"\",\n    ):\n        output_size = sum(column_output_sizes) + sum(repeated_output_sizes)\n        super().__init__(\n            input_size=input_size,\n            output_size=output_size,\n            skip_bias_add=skip_bias_add,\n            params_dtype=params_dtype,\n            quant_config=quant_config,\n            prefix=prefix,\n        )\n        self.num_column_parallel = len(column_output_sizes)\n        self.tp_rank = 
get_tensor_model_parallel_rank()\n        self.tp_size = get_tensor_model_parallel_world_size()\n\n        self.output_partition_sizes = [\n            divide(x, self.tp_size) for x in column_output_sizes\n        ] + repeated_output_sizes\n        self.quant_method.create_weights(\n            layer=self,\n            input_size_per_partition=self.input_size,\n            output_partition_sizes=self.output_partition_sizes,\n            input_size=self.input_size,\n            output_size=self.output_size,\n            params_dtype=self.params_dtype,\n            skip_block_quant_check=True,\n            weight_loader=self.weight_loader,\n        )\n\n        self.prefix = prefix\n\n    def forward(self, input_: torch.Tensor) -> torch.Tensor:\n        return self.quant_method.apply(self, input_)\n\n    def weight_loader(\n        self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int\n    ) -> torch.Tensor:\n        output_dim = param.output_dim\n        shard_offset = sum(self.output_partition_sizes[:loaded_shard_id])\n        shard_size = self.output_partition_sizes[loaded_shard_id]\n        param_data = param.data.narrow(output_dim, shard_offset, shard_size)\n\n        if loaded_shard_id < self.num_column_parallel:\n            start_idx = self.tp_rank * shard_size\n            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)\n\n        param_data.copy_(loaded_weight)\n\n\nclass ColumnParallelBatchedLinear(nn.Module):\n    \"\"\"Column parallel batched linear layer.\n\n    TODO: quantization is not supported yet.\n    Args:\n        batch: batch dimension of the linear layer.\n        input_size: input dimension of the linear layer.\n        output_size: output dimension of the linear layer.\n        dtype: Data type for the parameters.\n    \"\"\"\n\n    def __init__(\n        self, batch: int, input_size: int, output_size: int, dtype: torch.dtype\n    ):\n        super().__init__()\n        self.tp_rank = 
get_tensor_model_parallel_rank()\n        self.tp_size = get_tensor_model_parallel_world_size()\n        self.weight = nn.Parameter(\n            torch.empty(batch, output_size // self.tp_size, input_size, dtype=dtype),\n            requires_grad=False,\n        )\n        setattr(self.weight, \"weight_loader\", self.weight_loader)\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return torch.bmm(input, self.weight.transpose(-1, -2))\n\n    def weight_loader(\n        self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int\n    ) -> torch.Tensor:\n        shard_size = self.weight.shape[-2]\n        start_idx = self.tp_rank * shard_size\n        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)\n        param.data[loaded_shard_id].copy_(loaded_weight)\n"
  },
  {
    "path": "python/sglang/srt/layers/model_parallel.py",
    "content": "\"\"\"\nCommon utilities for torch model parallelism.\n\"\"\"\n\nfrom typing import Optional, Sequence\n\nimport torch\nimport torch.nn as nn\nfrom torch.distributed.device_mesh import DeviceMesh\n\ntry:\n    import torch.distributed.tensor as dt\nexcept ImportError:\n    # torch 2.4 or older\n    import torch.distributed._tensor as dt\n\nfrom torch.distributed.tensor.parallel import (\n    ColwiseParallel,\n    RowwiseParallel,\n    parallelize_module,\n)\n\n\ndef _shard_tensor(\n    full_tensor: torch.Tensor,\n    device_mesh: DeviceMesh,\n    placements: Sequence[dt.Shard],\n) -> \"dt.DTensor\":\n    \"\"\"\n    Locally shards a full tensor based on indicated sharding arrangement, and\n    returns a DTensor containing the local shard.\n\n    .. warning:: This is a private API that is subject to change. It skips the\n        communication otherwise required by `distribute_tensor`. It is only\n        applicable to cases where all ranks have the same `full_tensor`. For\n        example, in distributed inference all ranks load from the same\n        checkpoint. This API will not check for data equality between ranks, it\n        is thus user's responsibility to ensure the `full_tensor` is the same\n        across ranks.\n\n    Args:\n        full_tensor (torch.Tensor): the full tensor to be sharded.\n        device_mesh (:class:`DeviceMesh`): DeviceMesh to place the\n            DTensor.  
Must have same dimension as the number of placements.\n        placements (Sequence[:class:`Shard`]): the placements that\n            describes how to place the local tensor on DeviceMesh.\n\n    Returns:\n        A :class:`DTensor` object with the shard as its local tensor.\n\n    Examples:\n        >>> # xdoctest: +SKIP(\"need world_size and rank\")\n        >>> device_mesh = dist.init_device_mesh(\"cuda\", (world_size,))\n        >>> full_tensor = torch.arange(world_size, device=f\"cuda:{rank}\")\n        >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(1)])\n    \"\"\"\n    shape, offset = dt._utils.compute_local_shape_and_global_offset(\n        full_tensor.shape, device_mesh, placements\n    )\n    slices = [\n        slice(cur_offset, cur_offset + cur_shape)\n        for cur_shape, cur_offset in zip(shape, offset)\n    ]\n    local_tensor = full_tensor[slices]\n    return dt.DTensor.from_local(local_tensor, device_mesh, placements)\n\n\nclass ColwiseParallelSharded(ColwiseParallel):\n    \"\"\"\n    A version of ColwiseParallel where the local weight has been already\n    sharded.  
This is used for the fused wqkv case, where during loading, we\n    have already sharded wq, wk, wv before fusing them.\n    \"\"\"\n\n    # Override the _partition_linear_fn in ColwiseParallel\n    def _partition_linear_fn(self, name, module, device_mesh):\n        # Colwise: shard weight/bias along dim 0 (Shard(0)). Since nn.Linear\n        # computes input * weight^T + bias, sharding weight on dim 0 makes\n        # weight^T Shard(1), i.e. column-wise.\n        for name, param in module.named_parameters():\n            dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])\n            dist_param = torch.nn.Parameter(dtensor, requires_grad=False)\n            module.register_parameter(name, dist_param)\n\n\nclass RowwiseParallelMaybeWait(RowwiseParallel):\n    \"\"\"\n    A version of RowwiseParallel that waits for the output (establishing a dependency\n    between the comm stream and the compute stream in the CUDA sense) before going\n    into the next op. This is needed to work around the current interaction between\n    AsyncCollectiveTensor and multi-platform ops, such as `RMSNorm`.\n    \"\"\"\n\n    def _partition_linear_fn(self, name, module, device_mesh):\n        # Rowwise: shard weight along dim 1 (Shard(1)) and replicate bias. Since\n        # nn.Linear computes input * weight^T + bias, sharding weight on dim 1\n        # makes weight^T Shard(0), i.e. row-wise.\n        module.register_parameter(\n            \"weight\",\n            nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),\n        )\n        if getattr(module, \"bias\", None) is not None:\n            # The Linear module has bias\n            module.register_parameter(\n                \"bias\",\n                nn.Parameter(\n                    dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])\n                ),\n            )\n\n    @staticmethod\n    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):\n        outputs = super(\n            RowwiseParallelMaybeWait, 
RowwiseParallelMaybeWait\n        )._prepare_output_fn(\n            output_layouts, use_local_output, mod, outputs, device_mesh\n        )\n        return torch.distributed._functional_collectives.wait_tensor(outputs)\n\n\ndef tensor_parallel(\n    module: torch.nn.Module,\n    device_mesh: Optional[DeviceMesh] = None,\n):\n    \"\"\"\n    Tensor parallelize the model across the given device mesh.\n    Args:\n        module (`torch.nn.Module`):\n            The module to tensor parallelize.\n        device_mesh (`torch.distributed.DeviceMesh`):\n            The device mesh to use for tensor parallelism.\n    \"\"\"\n\n    # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.\n    # No op if `_tp_plan` attribute does not exist under the module.\n    # This is a helper function to be used with `model.apply` to recursively\n    # parallelize a model.\n    def tplize(mod: torch.nn.Module) -> None:\n        tp_plan = getattr(mod, \"_tp_plan\", None)\n        if tp_plan is None:\n            return\n        for child_name, tp_style in tp_plan.items():\n            submod = mod.get_submodule(child_name)\n            if tp_style == \"Colwise\":\n                parallelize_module(submod, device_mesh, ColwiseParallel())\n            elif tp_style == \"Rowwise\":\n                parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait())\n            elif tp_style == \"Colwise_Sharded\":\n                parallelize_module(submod, device_mesh, ColwiseParallelSharded())\n            else:\n                raise ValueError(f\"Unknown TP style {tp_style}\")\n\n    # `apply` is a native method of `nn.Module` that recursively applies a\n    # function to every submodule.\n    module.apply(tplize)\n"
  },
  {
    "path": "python/sglang/srt/layers/moe/cutlass_moe.py",
    "content": "\"\"\"CUTLASS based Fused MoE kernels.\"\"\"\n\nfrom typing import Optional, Tuple\n\nimport torch\n\nfrom sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams\nfrom sglang.srt.utils import is_cuda, is_sm90_supported, is_sm100_supported\n\n_is_cuda = is_cuda()\nif _is_cuda:\n    from sgl_kernel import (\n        apply_shuffle_mul_sum,\n        es_fp8_blockwise_scaled_grouped_mm,\n        es_sm100_mxfp8_blockscaled_grouped_mm,\n        es_sm100_mxfp8_blockscaled_grouped_quant,\n        fp8_blockwise_scaled_grouped_mm,\n        prepare_moe_input,\n        shuffle_rows,\n        silu_and_mul,\n    )\n\n    from sglang.jit_kernel.nvfp4 import (\n        cutlass_fp4_group_mm,\n        scaled_fp4_experts_quant,\n    )\n\n\ndef cutlass_fused_experts_fp8(\n    a: torch.Tensor,\n    w1_q: torch.Tensor,\n    w2_q: torch.Tensor,\n    w1_scale: torch.Tensor,\n    w2_scale: torch.Tensor,\n    topk_weights: torch.Tensor,\n    topk_ids: torch.Tensor,\n    a1_strides: torch.Tensor,\n    c1_strides: torch.Tensor,\n    a2_strides: torch.Tensor,\n    c2_strides: torch.Tensor,\n    workspace: torch.Tensor,\n    a_ptrs: torch.Tensor,\n    b_ptrs: torch.Tensor,\n    out_ptrs: torch.Tensor,\n    a_scales_ptrs: torch.Tensor,\n    b_scales_ptrs: torch.Tensor,\n    expert_offsets: torch.Tensor,\n    problem_sizes1: torch.Tensor,\n    problem_sizes2: torch.Tensor,\n    use_fp8_blockscale: bool = True,\n    use_mxfp8: bool = False,\n    output: Optional[torch.Tensor] = None,\n    enable_es: Tuple[bool, bool] = (False, False),\n) -> torch.Tensor:\n    \"\"\"Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.\n\n    This function implements a Mixture of Experts (MoE) layer with a SwiGLU/SiLU\n    activation, leveraging custom kernels likely derived from CUTLASS principles\n    for grouped matrix multiplication (`fp8_blockwise_scaled_grouped_mm`) and\n    data preparation (`prepare_moe_input`, `silu_and_mul`).\n\n    It handles 
per-token routing, quantizes input activations to FP8 with\n    per-token scales, performs the expert computations using FP8 GEMMs with\n    pre-quantized FP8 weights (per-block scales), applies the SiLU activation,\n    and combines the results weighted by the router scores.\n\n    Args:\n        a (torch.Tensor): Input activations. Shape: `(m, k)`, where `m` is the total\n            number of tokens and `k` is the hidden size. Expected dtype: `torch.half`\n            or `torch.bfloat16`.\n        w1_q (torch.Tensor): Pre-quantized FP8 weight tensor for the first GEMM\n            (up-projection part of SwiGLU). Expected shape: `(E, k, n*2)`, where\n            `E` is the number of experts, `k` is the hidden size, and `n*2` is the\n            intermediate size (`I`). Expected dtype: `torch.float8_e4m3fn`.\n            Note: This shape implies weights are stored as (num_experts, hidden_size, intermediate_size).\n        w2_q (torch.Tensor): Pre-quantized FP8 weight tensor for the second GEMM\n            (down-projection). Expected shape: `(E, n, k)`, where `n` is half the\n            intermediate size (`I // 2`). Expected dtype: `torch.float8_e4m3fn`.\n            Note: This shape implies weights are stored as (num_experts, intermediate_size // 2, hidden_size).\n        w1_scale (torch.Tensor): Scales corresponding to `w1_q` (per-block scales).\n            Shape: `(E, num_blocks_n, num_blocks_k)`. Dtype: `torch.float32`.\n        w2_scale (torch.Tensor): Scales corresponding to `w2_q` (per-block scales).\n             Shape: `(E, num_blocks_k, num_blocks_n)`. Dtype: `torch.float32`.\n        topk_weights (torch.Tensor): Router weights for the selected top-k experts\n            for each token. Shape: `(m, topk)`. Dtype should ideally match `a`.\n        topk_ids (torch.Tensor): Indices of the selected top-k experts for each token.\n            Shape: `(m, topk)`. 
Dtype: `torch.int32`.\n        a1_strides (torch.Tensor): Stride information for the first GEMM's 'a' input.\n            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.\n            Note: its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification\n            as it's passed as both a_stride and b_stride in the first call.\n        c1_strides (torch.Tensor): Stride information for the first GEMM's 'c' output.\n            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.\n        a2_strides (torch.Tensor): Stride information for the second GEMM's 'a' input.\n            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.\n            Note: its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification\n            as it's passed as both a_stride and b_stride in the second call.\n        c2_strides (torch.Tensor): Stride information for the second GEMM's 'c' output.\n            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.\n        workspace (torch.Tensor): Reusable workspace for the underlying kernel.\n        a_ptrs (torch.Tensor): Pointers container for calculating offsets of the input activations for each expert.\n        b_ptrs (torch.Tensor): Pointers container for calculating offsets of the input weights for each expert.\n        out_ptrs (torch.Tensor): Pointers container for calculating offsets of the output activations for each expert.\n        a_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the activation scales for each expert.\n        b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the weight scales for each expert.\n        use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with\n            block scaling. Currently, only `True` is supported. 
Defaults to `True`.\n        use_mxfp8 (bool, optional): Flag indicating usage of MXFP8 (UE8M0 scales)\n            with SM100 expert-specialization kernels. Defaults to `False`.\n        output (torch.Tensor, optional): Output tensor. If not provided, a new tensor will be created.\n        enable_es (tuple(bool, bool)): Flags indicating whether to use the expert-specialization kernel\n            for the (up-projection, down-projection) GEMMs. Defaults to `(False, False)`.\n    Returns:\n        torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`.\n\n    Raises:\n        AssertionError: If input shapes, dtypes, or flags are inconsistent or unsupported.\n        NotImplementedError: If CUDA is not available or `sgl_kernel` is not properly installed.\n    \"\"\"\n    assert use_fp8_blockscale, \"Only FP8 blockscale is supported for now\"\n    assert topk_weights.shape == topk_ids.shape, \"topk shape mismatch\"\n    assert w1_q.dtype == torch.float8_e4m3fn\n    assert w2_q.dtype == torch.float8_e4m3fn\n    assert a.shape[1] == w1_q.shape[1], \"Hidden size mismatch w1\"\n    assert w1_q.shape[2] == w2_q.shape[1] * 2, \"Intermediate size mismatch w2\"\n    assert w1_q.shape[0] == w2_q.shape[0], \"Expert number mismatch\"\n    assert w1_q.shape[0] == w1_scale.shape[0], \"w1 scales expert number mismatch\"\n    assert w1_q.shape[0] == w2_scale.shape[0], \"w2 scales expert number mismatch\"\n    assert a.dtype in [torch.half, torch.bfloat16], \"Invalid output dtype\"\n\n    if _is_cuda:\n        from sglang.srt.layers.quantization.fp8_kernel import (\n            sglang_per_token_group_quant_fp8,\n        )\n    es_up, es_down = enable_es\n    out_dtype = a.dtype\n    num_experts = w1_q.size(0)\n    m = a.size(0)\n    k = w1_q.size(1)\n    n = w2_q.size(1)\n\n    topk = topk_ids.size(1)\n    device = a.device\n\n    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)\n    c_map = torch.empty((topk_ids.numel()), 
dtype=torch.int32, device=device)\n\n    if use_mxfp8:\n        assert es_up and es_down, \"MXFP8 requires expert-specialization for both GEMMs\"\n        assert is_sm100_supported(), \"MXFP8 requires SM100\"\n        assert k % 32 == 0, \"MXFP8 requires hidden size to be divisible by 32\"\n        assert n % 32 == 0, \"MXFP8 requires intermediate size to be divisible by 32\"\n        assert w1_scale.dtype == torch.uint8, \"MXFP8 w1_scale must be uint8\"\n        assert w2_scale.dtype == torch.uint8, \"MXFP8 w2_scale must be uint8\"\n        expected_w1_scale_shape = (\n            num_experts,\n            w1_q.shape[1] // 32,\n            w1_q.shape[2],\n        )\n        expected_w2_scale_shape = (\n            num_experts,\n            w2_q.shape[1] // 32,\n            w2_q.shape[2],\n        )\n        assert (\n            w1_scale.shape == expected_w1_scale_shape\n        ), f\"MXFP8 w1_scale must be {expected_w1_scale_shape}, got {w1_scale.shape}\"\n        assert (\n            w2_scale.shape == expected_w2_scale_shape\n        ), f\"MXFP8 w2_scale must be {expected_w2_scale_shape}, got {w2_scale.shape}\"\n\n        mxfp8_blockscale_align = 128\n        total_tokens = m * topk\n        nonzero_experts = min(num_experts, total_tokens)\n        max_total = total_tokens + (mxfp8_blockscale_align - 1) * nonzero_experts\n        max_blockscale = (\n            (max_total + mxfp8_blockscale_align - 1) // mxfp8_blockscale_align\n        ) * mxfp8_blockscale_align\n\n    blockscale_offsets = None\n    if use_mxfp8 and (es_up or es_down):\n        blockscale_offsets = torch.empty(\n            (num_experts + 1,), dtype=torch.int32, device=device\n        )\n\n    prepare_moe_input(\n        topk_ids,\n        expert_offsets,\n        problem_sizes1,\n        problem_sizes2,\n        a_map,\n        c_map,\n        num_experts,\n        n,\n        k,\n        blockscale_offsets,\n    )\n\n    if use_mxfp8 and es_up:\n        rep_a = shuffle_rows(a, a_map, (m * 
topk, k))\n        rep_a_q = torch.empty_like(rep_a, dtype=torch.float8_e4m3fn)\n        rep_a1_scales = torch.empty(\n            (max_blockscale, k // 32), dtype=torch.uint8, device=device\n        )\n        es_sm100_mxfp8_blockscaled_grouped_quant(\n            rep_a,\n            problem_sizes1,\n            expert_offsets[:-1],\n            blockscale_offsets[:-1],\n            rep_a_q,\n            rep_a1_scales,\n        )\n    else:\n        a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)\n        rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))\n        rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))\n\n    c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)\n    c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)\n\n    a_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)\n    w_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)\n\n    if is_sm90_supported() and es_up:\n        es_fp8_blockwise_scaled_grouped_mm(\n            c1,\n            rep_a_q,\n            w1_q,\n            rep_a1_scales,\n            w1_scale,\n            a1_strides,\n            a1_strides,\n            c1_strides,\n            problem_sizes1,\n            expert_offsets[:-1],\n            workspace,\n        )\n    elif use_mxfp8 and es_up:\n        es_sm100_mxfp8_blockscaled_grouped_mm(\n            c1,\n            rep_a_q,\n            w1_q,\n            rep_a1_scales,\n            w1_scale,\n            problem_sizes1,\n            expert_offsets[:-1],\n            blockscale_offsets[:-1],\n        )\n    else:\n        fp8_blockwise_scaled_grouped_mm(\n            c1,\n            a_ptrs,\n            b_ptrs,\n            out_ptrs,\n            a_scales_ptrs,\n            b_scales_ptrs,\n            rep_a_q,\n            w1_q,\n            rep_a1_scales,\n            w1_scale,\n            a1_strides,\n            a1_strides,\n            c1_strides,\n            
a_sf_layout,\n            w_sf_layout,\n            problem_sizes1,\n            expert_offsets[:-1],\n            workspace,\n        )\n\n    intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)\n    silu_and_mul(c1, intermediate)\n\n    if use_mxfp8 and es_down:\n        intemediate_q = torch.empty_like(intermediate, dtype=torch.float8_e4m3fn)\n        a2_scale = torch.empty(\n            (max_blockscale, n // 32), dtype=torch.uint8, device=device\n        )\n        es_sm100_mxfp8_blockscaled_grouped_quant(\n            intermediate,\n            problem_sizes2,\n            expert_offsets[:-1],\n            blockscale_offsets[:-1],\n            intemediate_q,\n            a2_scale,\n        )\n    else:\n        intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)\n\n    if is_sm90_supported() and es_down:\n        es_fp8_blockwise_scaled_grouped_mm(\n            c2,\n            intemediate_q,\n            w2_q,\n            a2_scale,\n            w2_scale,\n            a2_strides,\n            a2_strides,\n            c2_strides,\n            problem_sizes2,\n            expert_offsets[:-1],\n            workspace,\n        )\n    elif use_mxfp8 and es_down:\n        es_sm100_mxfp8_blockscaled_grouped_mm(\n            c2,\n            intemediate_q,\n            w2_q,\n            a2_scale,\n            w2_scale,\n            problem_sizes2,\n            expert_offsets[:-1],\n            blockscale_offsets[:-1],\n        )\n    else:\n        fp8_blockwise_scaled_grouped_mm(\n            c2,\n            a_ptrs,\n            b_ptrs,\n            out_ptrs,\n            a_scales_ptrs,\n            b_scales_ptrs,\n            intemediate_q,\n            w2_q,\n            a2_scale,\n            w2_scale,\n            a2_strides,\n            a2_strides,\n            c2_strides,\n            a_sf_layout,\n            w_sf_layout,\n            problem_sizes2,\n            expert_offsets[:-1],\n            workspace,\n 
       )\n\n    if output is None:\n        output = torch.empty((m, k), device=device, dtype=out_dtype)\n\n    apply_shuffle_mul_sum(c2, output, c_map, topk_weights.to(out_dtype))\n    return output\n\n\nFLOAT4_E2M1_MAX = 6.0\nFLOAT8_E4M3_MAX = 448.0\n\n\ndef cutlass_moe_fp4(\n    a: torch.Tensor,\n    a1_gscale: torch.Tensor,\n    w1_fp4: torch.Tensor,\n    w1_blockscale: torch.Tensor,\n    w1_alphas: torch.Tensor,\n    a2_gscale: torch.Tensor,\n    w2_fp4: torch.Tensor,\n    w2_blockscale: torch.Tensor,\n    w2_alphas: torch.Tensor,\n    topk_weights: torch.Tensor,\n    topk_ids: torch.Tensor,\n    params: CutlassMoEParams,\n    apply_router_weight_on_input: bool = False,\n):\n    \"\"\"\n    MoE implementation for FP4 Inputs\n\n    # Gemm 1\n    a: Input tensor: [m, k] (half/bfloat16)\n    a1_gscale: Activation scale per expert: [e]  (float32)\n    w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]\n    w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)\n    (Note: `n` is the up projection output dim, `k` is the input dim in\n     full precision)\n    w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)\n                   (Block size = 16 for NVFP4)\n\n    # Gemm 2\n    a2_gscale: Activation scale per expert: [e]\n    w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]\n    w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)\n    w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3\n\n    Strides for activations, weights and output in logical number of elements.\n    The activations & output stride is the number of elements to the next row.\n    The weights stride is the number of elements to the next row per expert.\n    For example, if the weight is [e, n, k], then the b_stride is a tensor of\n    shape [e] with each element being k. 
Similarly for activations, if the\n    shape is [m, k], then the a_stride has shape [e] with each value k.\n    Similarly for output, if the output is [m, n], then the c_stride is a\n    tensor of shape [e] with each element being n.\n\n    Note: cutlass_fp4_group_mm expects the activation and weight strides to be\n    identical, so they are passed in as a single tensor.\n    ab_strides_13: [e] dtype: int64 [Gemm 1: Activation / Weight strides]\n    ab_strides_2: [e] dtype: int64 [Gemm 2: Activation / Weight strides]\n    c_strides_13: [e] dtype: int64 [Gemm 1: Output Strides]\n    c_strides_2: [e] dtype: int64 [Gemm 2: Output Strides]\n\n    topk_weights: [m, topk], floating-point dtype\n    topk_ids: [m, topk] dtype: int32\n\n    m, n, k: Unquantized weight shapes, dtype: int\n    e: number of experts for the current rank, dtype: int\n    Assumes that topk < k < n to satisfy the up/down projection expectations.\n    \"\"\"\n    assert topk_weights.shape == topk_ids.shape, \"topk shape mismatch\"\n    assert w1_fp4.dtype == torch.uint8, \"weight 1 must be uint8\"\n    assert w2_fp4.dtype == torch.uint8, \"weight 2 must be uint8\"\n    assert (\n        w1_fp4.ndim == 3\n        and w2_fp4.ndim == 3\n        and w1_blockscale.ndim == 3\n        and w2_blockscale.ndim == 3\n    ), \"All Weights must be of rank 3 for cutlass_moe_fp4\"\n    m_a, k_a = a.shape\n    e_w1, nx2_w1, half_k_w1 = w1_fp4.shape\n    e_w2, k_w2, half_n_w2 = w2_fp4.shape\n\n    assert (\n        e_w1 == e_w2 and e_w1 == params.num_experts\n    ), \"Number of experts must match between weights.\"\n    assert (\n        k_a // 2 == half_k_w1 and params.hidden_size == k_w2\n    ), \"Hidden size mismatch between a, w1 and w2\"\n    assert (\n        nx2_w1 == params.intermediate_size_per_partition * 2\n        and half_n_w2 == params.intermediate_size_per_partition // 2\n    ), \"Mismatch in expected `n`\"\n    assert 2 * half_k_w1 == k_w2, \"Hidden size mismatch w2 
and w1\"\n    assert a.dtype in [torch.half, torch.bfloat16], \"Invalid input dtype\"\n\n    out_dtype = a.dtype\n    num_topk = topk_ids.shape[1]\n    device = a.device\n    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)\n    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)\n    prepare_moe_input(\n        topk_ids,\n        params.expert_offsets,\n        params.problem_sizes1,\n        params.problem_sizes2,\n        a_map,\n        c_map,\n        params.num_experts,\n        params.intermediate_size_per_partition,\n        params.hidden_size,\n        params.blockscale_offsets,\n    )\n\n    rep_a_fp4, rep_a_blockscale = scaled_fp4_experts_quant(\n        a,\n        a1_gscale,\n        params.expert_offsets,\n        params.blockscale_offsets,\n        num_topk,\n        expert_map=a_map,\n    )\n    c1 = cutlass_fp4_group_mm(\n        rep_a_fp4,\n        w1_fp4,\n        rep_a_blockscale,\n        w1_blockscale,\n        w1_alphas,\n        out_dtype,\n        params.to_gemm1_args(),\n    )\n    del rep_a_fp4, rep_a_blockscale\n\n    # silu_and_mul halves the gate/up dimension (2 * n -> n).\n    intermediate = torch.empty(\n        (m_a * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype\n    )\n    silu_and_mul(c1, intermediate)\n\n    int_fp4, int_blockscale = scaled_fp4_experts_quant(\n        intermediate,\n        a2_gscale,\n        params.expert_offsets,\n        params.blockscale_offsets,\n        num_topk,\n    )\n    c2 = cutlass_fp4_group_mm(\n        int_fp4,\n        w2_fp4,\n        int_blockscale,\n        w2_blockscale,\n        w2_alphas,\n        out_dtype,\n        params.to_gemm2_args(),\n    )\n    del int_fp4, int_blockscale\n    c2 = shuffle_rows(c2, c_map, (m_a * num_topk, params.hidden_size))\n    c2 = c2.view(m_a, num_topk, params.hidden_size)\n    if not apply_router_weight_on_input:\n        c2 = c2 * topk_weights.view(m_a, num_topk, 1).to(out_dtype)\n    
return c2.sum(dim=1).to(out_dtype)\n"
  },
  {
    "path": "python/sglang/srt/layers/moe/cutlass_moe_params.py",
    "content": "from dataclasses import dataclass\nfrom enum import Enum, auto\nfrom typing import Optional\n\nimport torch\n\n\nclass CutlassMoEType(Enum):\n    \"\"\"\n    Enum for the different types of cutlass moe operations\n    that are currently supported in SGLang.\n    \"\"\"\n\n    BlockscaledFP8 = auto()\n    BlockscaledFP4 = auto()\n\n\n@dataclass\nclass CutlassMoEParams:\n    \"\"\"\n    Parameters for the cutlass moe operation.\n    \"\"\"\n\n    # Type as defined above\n    cutlass_moe_type: CutlassMoEType\n\n    # Strides for activations, weights and output in logical number of elements.\n    # The activations & output stride is the number of elements to the next row.\n    # The weights stride is the number of elements to the next row per expert.\n    # For example, if the weight is [e, n, k], then the b_stride is a tensor of\n    # shape [e] with each element being k. Similarly for activations, if the\n    # shape is [m, k], then the a_stride has shape [e] with each value k.\n    # Similarly for output, if the output is [m, n], then the c_stride is a\n    # tensor of shape [e] with each element being n.\n\n    # Note: cutlass_fp4_group_mm expects the activation and weight strides to be\n    # identical, so they are passed in as a single tensor.\n    # ab_strides_13: [e] dtype: int64 [Gemm 1: Activation / Weight strides]\n    # ab_strides_2: [e] dtype: int64 [Gemm 2: Activation / Weight strides]\n    # c_strides_13: [e] dtype: int64 [Gemm 1: Output Strides]\n    # c_strides_2: [e] dtype: int64 [Gemm 2: Output Strides]\n    ab_strides_13: torch.Tensor\n    ab_strides_2: torch.Tensor\n    c_strides_13: torch.Tensor\n    c_strides_2: torch.Tensor\n\n    # m: Total number of tokens\n    # n: intermediate size per partition\n    # k: hidden size per expert\n    # e: Number of experts\n    # device: Device to run computation on and store tensors\n    m: int\n    intermediate_size_per_partition: int\n    hidden_size: int\n    
num_experts: int\n    device: torch.device\n\n    # Pointers container for calculating offsets of the input activations for each expert\n    # a_ptrs: [e] dtype: int64\n    a_ptrs: torch.Tensor\n\n    # Pointers container for calculating offsets of the input weights for each expert\n    # b_ptrs: [e] dtype: int64\n    b_ptrs: torch.Tensor\n\n    # Pointers container for calculating offsets of the output activations for each expert\n    # out_ptrs: [e] dtype: int64\n    out_ptrs: torch.Tensor\n    # Pointers container for calculating offsets of the input scales for each expert\n    # a_scales_ptrs: [e] dtype: int64\n    # b_scales_ptrs: [e] dtype: int64\n    a_scales_ptrs: torch.Tensor\n    b_scales_ptrs: torch.Tensor\n    # Pointers for per-expert alpha values\n    alpha_ptrs: torch.Tensor\n    # CUTLASS blockscale layouts for A and B operands\n    layout_sfa: torch.Tensor\n    layout_sfb: torch.Tensor\n\n    # Offsets that mark at which token index each expert begins its computation\n    # The number of tokens computed with expert E is expert_offsets[E + 1] - expert_offsets[E]\n    # expert_offsets: [e+1] dtype: int32\n    expert_offsets: torch.Tensor\n\n    # Problem size: (num_experts, (m,2n,k)) for first GEMM\n    # problem_sizes1: [e, 3] dtype: int32\n    # Problem size: (num_experts, (m,n,k)) for second GEMM\n    # problem_sizes2: [e, 3] dtype: int32\n    problem_sizes1: torch.Tensor\n    problem_sizes2: torch.Tensor\n    # Similar to expert_offsets, but for blockscales for FP4 blockscaled Group GEMM\n    blockscale_offsets: Optional[torch.Tensor] = None\n\n    def __init__(\n        self,\n        cutlass_moe_type: CutlassMoEType,\n        device: torch.device,\n        num_experts: int,\n        intermediate_size_per_partition: int,\n        hidden_size: int,\n    ):\n        self.cutlass_moe_type = cutlass_moe_type\n        self.device = device\n        self.num_experts = num_experts\n        self.intermediate_size_per_partition = 
intermediate_size_per_partition\n        self.hidden_size = hidden_size\n        self.n = self.intermediate_size_per_partition\n        self.k = self.hidden_size\n        self.e = self.num_experts\n        self.ab_strides_13 = torch.full(\n            (self.e,), self.k, dtype=torch.int64, device=self.device\n        )\n        self.ab_strides_2 = torch.full(\n            (self.e,), self.n, dtype=torch.int64, device=self.device\n        )\n        self.c_strides_13 = torch.full(\n            (self.e,), 2 * self.n, dtype=torch.int64, device=self.device\n        )\n        self.c_strides_2 = torch.full(\n            (self.e,), self.k, dtype=torch.int64, device=self.device\n        )\n        self.expert_offsets = torch.empty(\n            (self.e + 1,), dtype=torch.int32, device=self.device\n        )\n        self.problem_sizes1 = torch.empty(\n            (self.e, 3), dtype=torch.int32, device=self.device\n        )\n        self.problem_sizes2 = torch.empty(\n            (self.e, 3), dtype=torch.int32, device=self.device\n        )\n        if self.cutlass_moe_type == CutlassMoEType.BlockscaledFP4:\n            self.blockscale_offsets = torch.empty(\n                (self.e + 1,), dtype=torch.int32, device=self.device\n            )\n        else:\n            self.blockscale_offsets = None\n        self.a_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device)\n        self.b_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device)\n        self.out_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device)\n        self.a_scales_ptrs = torch.empty(\n            (self.e,), dtype=torch.int64, device=self.device\n        )\n        self.b_scales_ptrs = torch.empty(\n            (self.e,), dtype=torch.int64, device=self.device\n        )\n        self.alpha_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device)\n        self.layout_sfa = torch.empty(\n            (self.e, 5), dtype=torch.int64, device=self.device\n     
   )\n        self.layout_sfb = torch.empty(\n            (self.e, 5), dtype=torch.int64, device=self.device\n        )\n\n    def to_gemm1_args(self) -> dict:\n        return {\n            \"ab_strides\": self.ab_strides_13,\n            \"c_strides\": self.c_strides_13,\n            \"problem_sizes\": self.problem_sizes1,\n            \"expert_offsets\": self.expert_offsets[:-1],\n            \"blockscale_offsets\": self.blockscale_offsets[:-1],\n            \"a_ptrs\": self.a_ptrs,\n            \"b_ptrs\": self.b_ptrs,\n            \"out_ptrs\": self.out_ptrs,\n            \"a_scales_ptrs\": self.a_scales_ptrs,\n            \"b_scales_ptrs\": self.b_scales_ptrs,\n            \"alpha_ptrs\": self.alpha_ptrs,\n            \"layout_sfa\": self.layout_sfa,\n            \"layout_sfb\": self.layout_sfb,\n        }\n\n    def to_gemm2_args(self) -> dict:\n        return {\n            \"ab_strides\": self.ab_strides_2,\n            \"c_strides\": self.c_strides_2,\n            \"problem_sizes\": self.problem_sizes2,\n            \"expert_offsets\": self.expert_offsets[:-1],\n            \"blockscale_offsets\": self.blockscale_offsets[:-1],\n            \"a_ptrs\": self.a_ptrs,\n            \"b_ptrs\": self.b_ptrs,\n            \"out_ptrs\": self.out_ptrs,\n            \"a_scales_ptrs\": self.a_scales_ptrs,\n            \"b_scales_ptrs\": self.b_scales_ptrs,\n            \"alpha_ptrs\": self.alpha_ptrs,\n            \"layout_sfa\": self.layout_sfa,\n            \"layout_sfb\": self.layout_sfb,\n        }\n"
  },
  {
    "path": "python/sglang/srt/layers/moe/cutlass_w4a8_moe.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n\"\"\"Cutlass W4A8 MoE kernel.\"\"\"\n\nfrom typing import Optional\n\nimport torch\n\nfrom sglang.srt.utils import is_cuda_alike\n\n_is_cuda_alike = is_cuda_alike()\n\nif _is_cuda_alike:\n    from sgl_kernel import (\n        cutlass_w4a8_moe_mm,\n        get_cutlass_w4a8_moe_mm_data,\n    )\n\nfrom sgl_kernel import silu_and_mul\n\nfrom sglang.jit_kernel.per_tensor_quant_fp8 import per_tensor_quant_fp8\nfrom sglang.srt.distributed import get_moe_expert_parallel_world_size\nfrom sglang.srt.layers.moe.ep_moe.kernels import (\n    cutlass_w4_run_moe_ep_preproess,\n    deepep_ll_get_cutlass_w4a8_moe_mm_data,\n    deepep_permute_triton_kernel,\n    deepep_post_reorder_triton_kernel,\n    deepep_run_moe_deep_preprocess,\n    post_reorder_for_cutlass_moe,\n    pre_reorder_for_cutlass_moe,\n    silu_and_mul_masked_post_per_tensor_quant_fwd,\n    silu_mul_static_tensorwise_quant_for_cutlass_moe,\n)\n\n\ndef cutlass_w4a8_moe(\n    a: torch.Tensor,\n    w1_q: torch.Tensor,\n    w2_q: torch.Tensor,\n    w1_scale: torch.Tensor,\n    w2_scale: torch.Tensor,\n    topk_weights: torch.Tensor,\n    topk_ids: torch.Tensor,\n    a_strides1: torch.Tensor,\n    b_strides1: torch.Tensor,\n    c_strides1: torch.Tensor,\n    a_strides2: torch.Tensor,\n    b_strides2: torch.Tensor,\n    c_strides2: torch.Tensor,\n    s_strides13: torch.Tensor,\n    s_strides2: torch.Tensor,\n    expert_offsets: torch.Tensor,\n    problem_sizes1: torch.Tensor,\n    problem_sizes2: torch.Tensor,\n    a1_scale: Optional[torch.Tensor] = None,\n    a2_scale: Optional[torch.Tensor] = None,\n    apply_router_weight_on_input: bool = False,\n    routed_scaling_factor: float = 1.0,\n) -> torch.Tensor:\n    \"\"\"\n    This function computes a w4a8-quantized Mixture of Experts (MoE) layer\n    using two sets of quantized weights, w1_q and w2_q, and top-k gating\n    mechanism. 
The matrix multiplications are implemented with CUTLASS\n    grouped gemm.\n\n    Parameters:\n    - a (torch.Tensor): The input tensor to the MoE layer.\n        Shape: [M, K]\n    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.\n        Shape: [num_experts, N * 2,  K // 2]\n        (the weights are passed transposed and int4-packed)\n    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.\n        Shape: [num_experts, K, N // 2]\n        (the weights are passed transposed and int4-packed)\n    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.\n        Shape: [num_experts, K // 512, N * 8]\n    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.\n        Shape: [num_experts, N // 512, K * 4]\n    - topk_weights (torch.Tensor): The weights of each token->expert mapping.\n    - topk_ids (torch.Tensor): The ids of each token->expert mapping.\n    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.\n    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.\n    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.\n    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.\n    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.\n    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.\n    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.\n    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.\n    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.\n        Shape: scalar or [1, K]\n    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to\n        quantize the intermediate result between the gemms.\n        Shape: scalar or [1, N]\n    - apply_router_weight_on_input (bool): When true, the topk weights are\n        applied directly on the inputs. 
This is only applicable when topk is 1.\n\n    Returns:\n    - torch.Tensor: The output tensor after applying the MoE layer (same shape and dtype as a).\n    \"\"\"\n    assert topk_weights.shape == topk_ids.shape, \"topk shape mismatch\"\n    assert w1_q.dtype == torch.int8\n    assert w2_q.dtype == torch.int8\n    assert a.shape[1] // 2 == w1_q.shape[2], \"Hidden size mismatch w1\"\n    assert w1_q.shape[2] * 2 == w2_q.shape[1], \"Hidden size mismatch w2\"\n    assert w1_q.shape[0] == w2_q.shape[0], \"Expert number mismatch\"\n    assert w1_q.shape[0] == w1_scale.shape[0], \"w1 scales expert number mismatch\"\n    assert w1_q.shape[0] == w2_scale.shape[0], \"w2 scales expert number mismatch\"\n\n    assert a_strides1.shape[0] == w1_q.shape[0], \"A Strides 1 expert number mismatch\"\n    assert b_strides1.shape[0] == w1_q.shape[0], \"B Strides 1 expert number mismatch\"\n    assert a_strides2.shape[0] == w2_q.shape[0], \"A Strides 2 expert number mismatch\"\n    assert b_strides2.shape[0] == w2_q.shape[0], \"B Strides 2 expert number mismatch\"\n    num_local_experts = w1_q.size(0)\n    m = a.size(0)\n    k = w1_q.size(2) * 2  # w1_q is transposed and packed\n    n = w2_q.size(2) * 2  # w2_q is transposed and packed\n    topk = topk_ids.size(1)\n\n    if apply_router_weight_on_input:\n        assert topk == 1, \"apply_router_weight_on_input is only implemented for topk=1\"\n\n    device = a.device\n    if get_moe_expert_parallel_world_size() > 1:\n        topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)\n\n    src2dst = cutlass_w4_run_moe_ep_preproess(\n        topk_ids,\n    )\n\n    gateup_input = torch.empty(\n        (m * topk, k),\n        device=device,\n        dtype=torch.float8_e4m3fn,\n    )\n\n    pre_reorder_for_cutlass_moe(\n        a,\n        gateup_input,\n        src2dst,\n        topk_ids,\n        a1_scale,\n        num_local_experts,\n        topk,\n        m,\n        k,\n    )\n\n    # NOTE: a_map and c_map are not used in the 
get_cutlass_w4a8_moe_mm_data kernel,\n    # they are kept to allow for a quick switch of the permutation logic\n    # from the current triton kernel implementation to the cutlass-based one if needed.\n    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)\n    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)\n    get_cutlass_w4a8_moe_mm_data(\n        topk_ids,\n        expert_offsets,\n        problem_sizes1,\n        problem_sizes2,\n        a_map,\n        c_map,\n        num_local_experts,\n        n,\n        k,\n    )\n\n    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)\n    c2 = torch.empty((m * topk, k), device=device, dtype=torch.bfloat16)\n\n    cutlass_w4a8_moe_mm(\n        c1,\n        gateup_input,\n        w1_q,\n        a1_scale.float(),\n        w1_scale,\n        expert_offsets[:-1],\n        problem_sizes1,\n        a_strides1,\n        b_strides1,\n        c_strides1,\n        s_strides13,\n        128,\n        topk,\n    )\n\n    intermediate_q = torch.empty(\n        (m * topk, n), dtype=torch.float8_e4m3fn, device=device\n    )\n    silu_mul_static_tensorwise_quant_for_cutlass_moe(\n        c1, intermediate_q, a2_scale.float(), expert_offsets[-1:], m * topk, n\n    )\n\n    cutlass_w4a8_moe_mm(\n        c2,\n        intermediate_q,\n        w2_q,\n        a2_scale.float(),\n        w2_scale,\n        expert_offsets[:-1],\n        problem_sizes2,\n        a_strides2,\n        b_strides2,\n        c_strides2,\n        s_strides2,\n        128,\n        topk,\n    )\n\n    output = torch.empty_like(a)\n\n    post_reorder_for_cutlass_moe(\n        c2,\n        output,\n        src2dst,\n        topk_ids,\n        topk_weights,\n        num_local_experts,\n        topk,\n        m,\n        k,\n        routed_scaling_factor,\n    )\n    return output\n\n\ndef cutlass_w4a8_moe_deepep_normal(\n    a: torch.Tensor,\n    w1_q: torch.Tensor,\n    w2_q: torch.Tensor,\n    
w1_scale: torch.Tensor,\n    w2_scale: torch.Tensor,\n    topk_weights: torch.Tensor,\n    topk_ids_: torch.Tensor,\n    a_strides1: torch.Tensor,\n    b_strides1: torch.Tensor,\n    c_strides1: torch.Tensor,\n    a_strides2: torch.Tensor,\n    b_strides2: torch.Tensor,\n    c_strides2: torch.Tensor,\n    s_strides13: torch.Tensor,\n    s_strides2: torch.Tensor,\n    expert_offsets: torch.Tensor,\n    problem_sizes1: torch.Tensor,\n    problem_sizes2: torch.Tensor,\n    a1_scale: Optional[torch.Tensor] = None,\n    a2_scale: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"\n    This function computes a w4a8-quantized Mixture of Experts (MoE) layer\n    using two sets of quantized weights, w1_q and w2_q, and top-k gating\n    mechanism. The matrix multiplications are implemented with CUTLASS\n    grouped gemm.\n\n    Parameters:\n    - a (torch.Tensor): The input tensor to the MoE layer.\n        Shape: [M, K]\n    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.\n        Shape: [num_experts, N * 2,  K // 2]\n        (the weights are passed transposed and int4-packed)\n    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.\n        Shape: [num_experts, K, N // 2]\n        (the weights are passed transposed and int4-packed)\n    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.\n        Shape: [num_experts, K // 512, N * 8]\n    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.\n        Shape: [num_experts, N // 512, K * 4]\n    - topk_weights (torch.Tensor): The weights of each token->expert mapping.\n    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.\n    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.\n    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.\n    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.\n    - b_strides2 (torch.Tensor): The weights strides of the second 
grouped gemm.\n    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.\n    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.\n    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.\n    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.\n        Shape: scalar or [1, K]\n    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to\n        quantize the intermediate result between the gemms.\n        Shape: scalar or [1, N]\n\n    Returns:\n    - torch.Tensor: The bf16 output tensor after applying the MoE layer.\n    \"\"\"\n    assert topk_weights.shape == topk_ids_.shape, \"topk shape mismatch\"\n    assert w1_q.dtype == torch.int8\n    assert w2_q.dtype == torch.int8\n    assert a.shape[1] // 2 == w1_q.shape[2], \"Hidden size mismatch w1\"\n    assert w1_q.shape[2] * 2 == w2_q.shape[1], \"Hidden size mismatch w2\"\n    assert w1_q.shape[0] == w2_q.shape[0], \"Expert number mismatch\"\n    assert w1_q.shape[0] == w1_scale.shape[0], \"w1 scales expert number mismatch\"\n    assert w1_q.shape[0] == w2_scale.shape[0], \"w2 scales expert number mismatch\"\n\n    assert a_strides1.shape[0] == w1_q.shape[0], \"A Strides 1 expert number mismatch\"\n    assert b_strides1.shape[0] == w1_q.shape[0], \"B Strides 1 expert number mismatch\"\n    assert a_strides2.shape[0] == w2_q.shape[0], \"A Strides 2 expert number mismatch\"\n    assert b_strides2.shape[0] == w2_q.shape[0], \"B Strides 2 expert number mismatch\"\n    num_experts = w1_q.size(0)\n    m = a.size(0)\n    k = w1_q.size(2) * 2  # w1_q is transposed and packed\n    n = w2_q.size(2) * 
2\n    topk = topk_ids_.size(1)\n    device = a.device\n\n    reorder_topk_ids, src2dst, _ = deepep_run_moe_deep_preprocess(\n        topk_ids_, num_experts\n    )\n    num_total_tokens = reorder_topk_ids.numel()\n    gateup_input_pre_reorder = torch.empty(\n        (int(num_total_tokens), a.shape[1]),\n        device=device,\n        dtype=a.dtype,\n    )\n    deepep_permute_triton_kernel[(a.shape[0],)](\n        a,\n        gateup_input_pre_reorder,\n        src2dst,\n        topk_ids_.to(torch.int64),\n        None,\n        topk,\n        a.shape[1],\n        BLOCK_SIZE=512,\n    )\n    gateup_input = torch.empty(\n        gateup_input_pre_reorder.shape, dtype=torch.float8_e4m3fn, device=device\n    )\n    per_tensor_quant_fp8(gateup_input_pre_reorder, gateup_input, a1_scale.float(), True)\n    del gateup_input_pre_reorder\n    local_topk_ids = topk_ids_\n    local_topk_ids = (\n        torch.where(local_topk_ids == -1, num_experts, topk_ids_).to(torch.int32)\n    ).contiguous()\n\n    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)\n    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)\n    get_cutlass_w4a8_moe_mm_data(\n        local_topk_ids,\n        expert_offsets,\n        problem_sizes1,\n        problem_sizes2,\n        a_map,\n        c_map,\n        num_experts,\n        n,\n        k,\n    )\n    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)\n    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)\n\n    cutlass_w4a8_moe_mm(\n        c1,\n        gateup_input,\n        w1_q,\n        a1_scale.float(),\n        w1_scale,\n        expert_offsets[:-1],\n        problem_sizes1,\n        a_strides1,\n        b_strides1,\n        c_strides1,\n        s_strides13,\n        128,\n        topk,\n    )\n    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)\n    silu_and_mul(c1, intermediate)\n\n    intermediate_q = 
torch.empty(\n        intermediate.shape, dtype=torch.float8_e4m3fn, device=device\n    )\n    per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)\n\n    cutlass_w4a8_moe_mm(\n        c2,\n        intermediate_q,\n        w2_q,\n        a2_scale.float(),\n        w2_scale,\n        expert_offsets[:-1],\n        problem_sizes2,\n        a_strides2,\n        b_strides2,\n        c_strides2,\n        s_strides2,\n        128,\n        topk,\n    )\n    num_tokens = src2dst.shape[0] // topk\n    output = torch.empty(\n        (num_tokens, c2.shape[1]),\n        device=c2.device,\n        dtype=torch.bfloat16,\n    )\n    deepep_post_reorder_triton_kernel[(num_tokens,)](\n        c2,\n        output,\n        src2dst,\n        topk_ids_,\n        topk_weights,\n        topk,\n        c2.shape[1],\n        BLOCK_SIZE=512,\n    )\n\n    return output\n\n\ndef cutlass_w4a8_moe_deepep_ll(\n    a: torch.Tensor,\n    w1_q: torch.Tensor,\n    w2_q: torch.Tensor,\n    w1_scale: torch.Tensor,\n    w2_scale: torch.Tensor,\n    topk_ids_: torch.Tensor,\n    masked_m: torch.Tensor,\n    a_strides1: torch.Tensor,\n    b_strides1: torch.Tensor,\n    c_strides1: torch.Tensor,\n    a_strides2: torch.Tensor,\n    b_strides2: torch.Tensor,\n    c_strides2: torch.Tensor,\n    s_strides13: torch.Tensor,\n    s_strides2: torch.Tensor,\n    expert_offsets: torch.Tensor,\n    problem_sizes1: torch.Tensor,\n    problem_sizes2: torch.Tensor,\n    a1_scale: Optional[torch.Tensor] = None,\n    a2_scale: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"\n    This function computes a w4a8-quantized Mixture of Experts (MoE) layer\n    using two sets of quantized weights, w1_q and w2_q, and top-k gating\n    mechanism. 
The matrix multiplications are implemented with CUTLASS\n    grouped gemm.\n\n    Parameters:\n    - a (torch.Tensor): The input tensor to the MoE layer.\n        Shape: [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, K]\n    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.\n        Shape: [num_experts, N * 2,  K // 2]\n        (the weights are passed transposed and int4-packed)\n    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.\n        Shape: [num_experts, K, N // 2]\n        (the weights are passed transposed and int4-packed)\n    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.\n        Shape: [num_experts, K // 512, N * 8]\n    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.\n        Shape: [num_experts, N // 512, K * 4]\n    - masked_m (torch.Tensor): The number of valid tokens for each local expert.\n        Shape: [num_experts]\n    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.\n    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.\n    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.\n    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.\n    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.\n    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.\n    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.\n    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.\n    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.\n        Shape: scalar or [1, K]\n    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to\n        quantize the intermediate result between the gemms.\n        Shape: scalar or [1, N]\n\n    Returns:\n    - torch.Tensor: The bf16 output tensor after applying the MoE layer.\n    \"\"\"\n    assert w1_q.dtype == torch.int8\n    assert w2_q.dtype == torch.int8\n    assert a.shape[2] // 2 == w1_q.shape[2], \"Hidden size mismatch w1\"\n    assert w1_q.shape[2] * 2 == w2_q.shape[1], \"Hidden size mismatch w2\"\n    assert w1_q.shape[0] == w2_q.shape[0], \"Expert number mismatch\"\n    assert w1_q.shape[0] == w1_scale.shape[0], \"w1 scales expert number mismatch\"\n    assert w1_q.shape[0] == w2_scale.shape[0], \"w2 scales expert number mismatch\"\n\n    assert a_strides1.shape[0] == w1_q.shape[0], \"A Strides 1 expert number mismatch\"\n    assert b_strides1.shape[0] == w1_q.shape[0], \"B Strides 1 expert number mismatch\"\n    assert a_strides2.shape[0] == w2_q.shape[0], \"A Strides 2 expert number mismatch\"\n    assert b_strides2.shape[0] == w2_q.shape[0], \"B Strides 2 expert number mismatch\"\n    num_experts = w1_q.size(0)\n    m = a.size(1)\n    k = w1_q.size(2) * 2  # w1_q is transposed and packed\n    n = w2_q.size(2) * 2  # w2_q is transposed and packed\n    topk = topk_ids_.size(1)\n\n    device = a.device\n\n    problem_sizes1, problem_sizes2 = deepep_ll_get_cutlass_w4a8_moe_mm_data(\n        masked_m,\n        problem_sizes1,\n        problem_sizes2,\n        num_experts,\n        n,\n        k,\n    )\n\n    gateup_input = torch.empty(a.shape, dtype=torch.float8_e4m3fn, device=device)\n    per_tensor_quant_fp8(a, gateup_input, a1_scale.float(), True)\n    c1 = torch.empty((num_experts, m, n * 2), device=device, dtype=torch.bfloat16)\n    c2 = torch.empty((num_experts, m, k), device=device, dtype=torch.bfloat16)\n\n    cutlass_w4a8_moe_mm(\n        c1,\n        gateup_input,\n        w1_q,\n        a1_scale.float(),\n        w1_scale,\n        expert_offsets[:-1],\n        problem_sizes1,\n        a_strides1,\n        b_strides1,\n        c_strides1,\n        s_strides13,\n        128,\n        
topk,\n    )\n\n    intermediate_q = torch.empty(\n        (num_experts, m, n), device=a.device, dtype=torch.float8_e4m3fn\n    )\n    silu_and_mul_masked_post_per_tensor_quant_fwd(\n        c1, intermediate_q, masked_m, a2_scale\n    )\n    cutlass_w4a8_moe_mm(\n        c2,\n        intermediate_q,\n        w2_q,\n        a2_scale.float(),\n        w2_scale,\n        expert_offsets[:-1],\n        problem_sizes2,\n        a_strides2,\n        b_strides2,\n        c_strides2,\n        s_strides2,\n        128,\n        topk,\n    )\n\n    return c2\n"
  },
  {
    "path": "python/sglang/srt/layers/moe/ep_moe/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/layers/moe/ep_moe/kernels.py",
    "content": "import logging\n\nimport torch\nimport triton\n\nfrom sglang.srt.utils import ceil_div, is_cuda\n\nlogger = logging.getLogger(__name__)\n\n_is_cuda = is_cuda()\nif _is_cuda:\n    from sglang.srt.layers.quantization.fp8_kernel import (\n        sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,\n    )\n\nimport triton.language as tl\n\n\ndef _get_launch_config_1d(device, numel):\n    MAX_THREADS_PER_BLOCK = 1024\n    MIN_THREADS_PER_BLOCK = 512\n    MAX_WAVES = 8  # empirical numbers\n\n    props = torch.cuda.get_device_properties(device)\n    sm_count = props.multi_processor_count\n    max_threads_per_sm = props.max_threads_per_multi_processor\n    max_num_blocks = sm_count * max_threads_per_sm // MAX_THREADS_PER_BLOCK\n\n    block_dim = MAX_THREADS_PER_BLOCK\n\n    def get_num_blocks(block_dim):\n        return triton.cdiv(numel, block_dim)\n\n    while (\n        block_dim > MIN_THREADS_PER_BLOCK\n        and get_num_blocks(block_dim // 2) <= max_num_blocks\n    ):\n        block_dim = block_dim // 2\n\n    num_blocks = get_num_blocks(block_dim)\n    grid_dim = min(num_blocks, max_num_blocks * MAX_WAVES)\n\n    return (grid_dim,), block_dim\n\n\ndef _get_launch_config_2d(device, m, n):\n    MAX_THREADS_PER_BLOCK = 1024\n    MIN_THREADS_PER_BLOCK = 512\n    MAX_WAVES = 8  # empirical numbers\n\n    props = torch.cuda.get_device_properties(device)\n    sm_count = props.multi_processor_count\n    max_threads_per_sm = props.max_threads_per_multi_processor\n    max_num_blocks = sm_count * max_threads_per_sm // MAX_THREADS_PER_BLOCK\n\n    block_dim = MAX_THREADS_PER_BLOCK\n\n    def get_num_blocks(block_dim):\n        return m * triton.cdiv(n, block_dim)\n\n    while (\n        block_dim > MIN_THREADS_PER_BLOCK\n        and get_num_blocks(block_dim // 2) <= max_num_blocks\n    ):\n        block_dim = block_dim // 2\n\n    grid_dim_x = triton.cdiv(n, block_dim)\n    grid_dim_y = max(min(m, max_num_blocks * MAX_WAVES // grid_dim_x), 1)\n\n    
return (grid_dim_y, grid_dim_x), block_dim\n\n\n@triton.jit\ndef deepep_permute_triton_kernel(\n    input_ptr,\n    gateup_input_ptr,\n    src2dst_ptr,\n    topk_ids_ptr,\n    a1_scales_ptr,\n    topk,\n    hidden_size,\n    BLOCK_SIZE: tl.constexpr,\n):\n    OutDtype = gateup_input_ptr.dtype.element_ty\n\n    src_idx = tl.program_id(0)\n    src2dst_ptr = src2dst_ptr + src_idx * topk\n    topk_ids_ptr = topk_ids_ptr + src_idx * topk\n\n    src_ptr = input_ptr + src_idx * hidden_size\n\n    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):\n        offset = start_offset + tl.arange(0, BLOCK_SIZE)\n        mask = offset < hidden_size\n        in_data = tl.load(src_ptr + offset, mask=mask).to(OutDtype)\n\n        for idx in range(topk):\n            dst_idx = tl.load(src2dst_ptr + idx)\n            if dst_idx >= 0:\n                dst_ptr = gateup_input_ptr + dst_idx * hidden_size\n                tl.store(dst_ptr + offset, in_data, mask=mask)\n\n\n@triton.jit\ndef deepep_post_reorder_triton_kernel(\n    down_output_ptr,\n    output_ptr,\n    src2dst_ptr,\n    topk_ids_ptr,\n    topk_weights_ptr,\n    topk,\n    hidden_size,\n    BLOCK_SIZE: tl.constexpr,\n):\n    InDtype = down_output_ptr.dtype.element_ty\n\n    src_idx = tl.program_id(0)\n    src2dst_ptr = src2dst_ptr + src_idx * topk\n    topk_ids_ptr = topk_ids_ptr + src_idx * topk\n    topk_weights_ptr = topk_weights_ptr + src_idx * topk\n\n    store_ptr = output_ptr + src_idx * hidden_size\n    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):\n        offset = start_offset + tl.arange(0, BLOCK_SIZE)\n        mask = offset < hidden_size\n        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)\n        for idx in range(topk):\n            dst_idx = tl.load(src2dst_ptr + idx)\n            if dst_idx >= 0:\n                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)\n                load_ptr = down_output_ptr + dst_idx * hidden_size\n                in_data = tl.load(load_ptr + 
offset, mask=mask)\n                sum_vec += in_data * weigh_scale\n        tl.store(store_ptr + offset, sum_vec, mask=mask)\n\n\n@triton.jit\ndef compute_src2dst_triton_kernel(\n    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr\n):\n    pid = tl.program_id(axis=0)\n    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = dst_id < num_toks\n    src_id = tl.load(reorder_ids + dst_id, mask=mask)\n    tl.store(src2dst + src_id, dst_id, mask=mask)\n\n\n@triton.jit\ndef deepep_compute_src2dst_triton_kernel(\n    reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr\n):\n    pid = tl.program_id(axis=0)\n    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = dst_id < num_toks\n    src_id = tl.load(reorder_ids + dst_id, mask=mask)\n    num_invalid = tl.load(num_minus_one)\n    tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)\n\n\ndef deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):\n    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)\n    seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)\n    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)\n\n    # Find offset\n    expert_ids = torch.arange(\n        num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype\n    )\n    torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)\n    num_minus_one = seg_indptr[0]\n    seg_indptr = seg_indptr - num_minus_one\n\n    BLOCK_SIZE = 512\n    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)\n    deepep_compute_src2dst_triton_kernel[grid](\n        reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE\n    )\n    reorder_topk_ids = reorder_topk_ids[num_minus_one:]\n    return reorder_topk_ids, src2dst, seg_indptr\n\n\n@triton.jit\ndef compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):\n    expert_id_minus_1 = tl.program_id(0) - 1\n    low = 
0\n    high = num_toks - 1\n    target_location = -1\n    while low <= high:\n        mid = (low + high) // 2\n\n        if tl.load(reorder_topk_ids + mid) > expert_id_minus_1:\n            high = mid - 1\n        else:\n            low = mid + 1\n            target_location = mid\n    tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1)\n\n\ndef cutlass_w4_run_moe_ep_preproess(topk_ids: torch.Tensor):\n    _, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)\n\n    BLOCK_SIZE = 512\n    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)\n    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)\n    compute_src2dst_triton_kernel[grid](\n        reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE\n    )\n\n    return src2dst\n\n\n@triton.jit\ndef pre_reorder_triton_kernel_for_cutlass_moe(\n    input_ptr,\n    gateup_input_ptr,\n    src2dst_ptr,\n    topk_ids_ptr,\n    a1_scales_ptr,\n    num_local_experts,\n    topk,\n    num_tokens,\n    hidden_size,\n    BLOCK_SIZE: tl.constexpr,\n    NUM_STAGES: tl.constexpr,\n):\n    OutDtype = gateup_input_ptr.dtype.element_ty\n\n    if a1_scales_ptr is not None:\n        a1_scale = 1.0 / tl.load(a1_scales_ptr)\n    else:\n        a1_scale = 1.0\n\n    offset = BLOCK_SIZE * tl.program_id(1) + tl.arange(0, BLOCK_SIZE)\n    mask = offset < hidden_size\n\n    start_src_idx = tl.program_id(0)\n    step = tl.num_programs(0)\n\n    for src_idx_int32 in tl.range(\n        start_src_idx, num_tokens, step, num_stages=NUM_STAGES\n    ):\n        src_idx = src_idx_int32.to(tl.int64)\n        token_src2dst_ptr = src2dst_ptr + src_idx * topk\n        token_topk_ids_ptr = topk_ids_ptr + src_idx * topk\n\n        src_ptr_offs = input_ptr + src_idx * hidden_size + offset\n        dst_ptr_offs = gateup_input_ptr + offset\n        in_data = tl.load(src_ptr_offs, mask=mask).to(tl.float32)\n        out_data = (in_data * a1_scale).to(OutDtype)\n        for idx in range(topk):\n            expert_id 
= tl.load(token_topk_ids_ptr + idx)\n            if expert_id != num_local_experts:\n                dst_idx = tl.load(token_src2dst_ptr + idx)\n                tl.store(dst_ptr_offs + dst_idx * hidden_size, out_data, mask=mask)\n\n\ndef pre_reorder_for_cutlass_moe(\n    input,\n    gateup_input,\n    src2dst,\n    topk_ids,\n    a1_scales,\n    num_local_experts,\n    topk,\n    num_tokens,\n    hidden_size,\n):\n    grid, block_dim = _get_launch_config_2d(input.device, num_tokens, hidden_size)\n\n    pre_reorder_triton_kernel_for_cutlass_moe[grid](\n        input_ptr=input,\n        gateup_input_ptr=gateup_input,\n        src2dst_ptr=src2dst,\n        topk_ids_ptr=topk_ids,\n        a1_scales_ptr=a1_scales,\n        num_local_experts=num_local_experts,\n        topk=topk,\n        num_tokens=num_tokens,\n        hidden_size=hidden_size,\n        BLOCK_SIZE=block_dim,\n        NUM_STAGES=3,\n    )\n\n\n# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py\n@triton.jit\ndef _silu_and_mul_post_quant_kernel(\n    input_ptr,\n    stride_input_0,\n    stride_input_1,\n    stride_input_2,\n    output_ptr,\n    stride_output_0,\n    stride_output_1,\n    stride_output_2,\n    output_scale_ptr,\n    stride_output_scale_0,\n    stride_output_scale_1,\n    stride_output_scale_2,\n    masked_m_ptr,\n    size_n,\n    fp8_max,\n    fp8_min,\n    BLOCK_N: tl.constexpr,\n    NUM_STAGE: tl.constexpr,\n    SCALE_UE8M0: tl.constexpr,\n):\n    expert_id = tl.program_id(2)\n    token_id = tl.program_id(1)\n    hidden_dim_block_index = tl.program_id(0)\n\n    block_num_per_expert = tl.num_programs(1)\n\n    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)\n\n    stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)\n    stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)\n    stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)\n    stride_output_1 = 
tl.cast(stride_output_1, dtype=tl.int64)\n\n    offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)\n    input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d\n    output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d\n    output_scale_offs = (\n        output_scale_ptr\n        + expert_id * stride_output_scale_0\n        + hidden_dim_block_index * stride_output_scale_2\n    )\n\n    for token_index in tl.range(\n        token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE\n    ):\n        gate = tl.load(\n            input_ptr_offs + token_index * stride_input_1,\n            mask=offs_in_d < size_n,\n            other=0.0,\n        ).to(tl.float32)\n        up = tl.load(\n            input_ptr_offs + token_index * stride_input_1 + size_n,\n            mask=offs_in_d < size_n,\n            other=0.0,\n        )\n        gate = gate / (1 + tl.exp(-gate))\n        gate = gate.to(input_ptr.dtype.element_ty)\n        gate_up = up * gate\n        _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)\n        output_s = _absmax / fp8_max\n        if SCALE_UE8M0:\n            output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))\n        output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(\n            output_ptr.dtype.element_ty\n        )\n        tl.store(\n            output_ptr_offs + token_index * stride_output_1,\n            output_q,\n            mask=offs_in_d < size_n,\n        )\n        tl.store(\n            output_scale_offs + token_index * stride_output_scale_1,\n            output_s,\n        )\n\n\ndef silu_and_mul_masked_post_quant_fwd(\n    input: torch.Tensor,\n    output: torch.Tensor,\n    output_scale: torch.Tensor,\n    quant_group_size: int,\n    masked_m: torch.Tensor,\n    scale_ue8m0: bool = False,\n):\n    \"\"\"\n    input shape [expert_num, token_num_padded, hidden_dim]\n    output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8\n    
output_scale shape [expert_num, token_num_padded, hidden_dim // 2 // quant_group_size], dtype float32\n    quant_group_size: int\n    masked_m shape [expert_num]\n    \"\"\"\n\n    assert input.is_contiguous()\n    assert output.dtype == torch.float8_e4m3fn\n    assert output.is_contiguous()\n    assert len(input.shape) == 3\n    assert input.shape[0] == masked_m.shape[0]\n    assert input.shape[-1] % 2 == 0\n\n    size_n = input.shape[-1] // 2\n    assert size_n % quant_group_size == 0\n\n    expert_num = len(masked_m)\n\n    if expert_num < 4:\n        BLOCK_NUM_PER_EXPERT = 64\n    else:\n        BLOCK_NUM_PER_EXPERT = 32\n\n    BLOCK_N = quant_group_size\n    num_warps = 1\n    NUM_STAGES = 6\n    hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)\n    assert BLOCK_N % quant_group_size == 0\n\n    grid = (\n        hidden_dim_split_block_num,\n        BLOCK_NUM_PER_EXPERT,\n        expert_num,\n    )\n\n    finfo = torch.finfo(torch.float8_e4m3fn)\n    fp8_max = finfo.max\n    fp8_min = -fp8_max\n\n    _silu_and_mul_post_quant_kernel[grid](\n        input,\n        *input.stride(),\n        output,\n        *output.stride(),\n        output_scale,\n        *output_scale.stride(),\n        masked_m,\n        size_n,\n        fp8_max,\n        fp8_min,\n        BLOCK_N=BLOCK_N,\n        NUM_STAGE=NUM_STAGES,\n        num_warps=num_warps,\n        SCALE_UE8M0=scale_ue8m0,\n    )\n    return\n\n\n@triton.jit\ndef silu_mul_static_tensorwise_quant_triton_kernel_for_cutlass_moe(\n    input_ptr,\n    output_ptr,\n    scale_ptr,\n    num_tokens_tensor_ptr,\n    intermediate_size,\n    BLOCK_SIZE: tl.constexpr,\n    NUM_STAGES: tl.constexpr,\n):\n    OutDtype = output_ptr.dtype.element_ty\n\n    num_tokens = tl.load(num_tokens_tensor_ptr)\n    numel = num_tokens * intermediate_size\n    gate_ptr = input_ptr\n    up_ptr = input_ptr + intermediate_size\n    scale = 1.0 / tl.load(scale_ptr)\n\n    start_idx = tl.program_id(0) * BLOCK_SIZE\n    step = tl.num_programs(0) * 
BLOCK_SIZE\n\n    for id in tl.range(start_idx, numel, step, num_stages=NUM_STAGES):\n        ids = id + tl.arange(0, BLOCK_SIZE)\n        token_ids = ids // intermediate_size\n        mask = ids < numel\n\n        offs = ids + token_ids * intermediate_size\n        gate = tl.load(gate_ptr + offs, mask=mask, other=0.0).to(tl.float32)\n        up = tl.load(up_ptr + offs, mask=mask, other=0.0).to(tl.float32)\n        output = gate / (1 + tl.exp(-gate)) * up * scale\n        tl.store(output_ptr + ids, output.to(OutDtype), mask=mask)\n\n\ndef silu_mul_static_tensorwise_quant_for_cutlass_moe(\n    input: torch.Tensor,\n    output: torch.Tensor,\n    scale: torch.Tensor,\n    num_tokens_tensor: torch.Tensor,\n    expected_num_tokens: int,\n    intermediate_size: int,\n):\n    grid, block_dim = _get_launch_config_1d(\n        input.device, expected_num_tokens * intermediate_size\n    )\n\n    silu_mul_static_tensorwise_quant_triton_kernel_for_cutlass_moe[grid](\n        input_ptr=input,\n        output_ptr=output,\n        scale_ptr=scale,\n        num_tokens_tensor_ptr=num_tokens_tensor,\n        intermediate_size=intermediate_size,\n        BLOCK_SIZE=block_dim,\n        NUM_STAGES=3,\n    )\n\n\n@triton.jit\ndef post_reorder_triton_kernel_for_cutlass_moe(\n    down_output_ptr,\n    output_ptr,\n    src2dst_ptr,\n    topk_ids_ptr,\n    topk_weights_ptr,\n    num_local_experts,\n    topk,\n    num_tokens,\n    hidden_size,\n    routed_scaling_factor: float,\n    BLOCK_SIZE: tl.constexpr,\n    NUM_STAGES: tl.constexpr,\n):\n    OutDtype = output_ptr.dtype.element_ty\n\n    offset = BLOCK_SIZE * tl.program_id(1) + tl.arange(0, BLOCK_SIZE)\n    mask = offset < hidden_size\n\n    down_output_ptr_offs = down_output_ptr + offset\n    output_ptr_offs = output_ptr + offset\n\n    start_src_idx = tl.program_id(0)\n    step = tl.num_programs(0)\n\n    for src_idx_int32 in tl.range(\n        start_src_idx, num_tokens, step, num_stages=NUM_STAGES\n    ):\n        src_idx = 
src_idx_int32.to(tl.int64)\n        token_src2dst_ptr = src2dst_ptr + src_idx * topk\n        token_topk_ids_ptr = topk_ids_ptr + src_idx * topk\n        token_topk_weights_ptr = topk_weights_ptr + src_idx * topk\n\n        sum_vec = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n        for idx in range(topk):\n            expert_id = tl.load(token_topk_ids_ptr + idx)\n            if expert_id != num_local_experts:\n                dst_idx_int32 = tl.load(token_src2dst_ptr + idx)\n                dst_idx = dst_idx_int32.to(tl.int64)\n                weight_scale = tl.load(token_topk_weights_ptr + idx).to(tl.float32)\n                load_ptr_offs = down_output_ptr_offs + dst_idx * hidden_size\n                in_data = tl.load(load_ptr_offs, mask=mask).to(tl.float32)\n                sum_vec += in_data * weight_scale\n        sum_vec *= routed_scaling_factor\n        store_ptr_offs = output_ptr_offs + src_idx * hidden_size\n        tl.store(store_ptr_offs, sum_vec.to(OutDtype), mask=mask)\n\n\ndef post_reorder_for_cutlass_moe(\n    down_output,\n    output,\n    src2dst,\n    topk_ids,\n    topk_weights,\n    num_local_experts,\n    topk,\n    num_tokens,\n    hidden_size,\n    routed_scaling_factor: float,\n):\n    grid, block_dim = _get_launch_config_2d(down_output.device, num_tokens, hidden_size)\n\n    post_reorder_triton_kernel_for_cutlass_moe[grid](\n        down_output_ptr=down_output,\n        output_ptr=output,\n        src2dst_ptr=src2dst,\n        topk_ids_ptr=topk_ids,\n        topk_weights_ptr=topk_weights,\n        num_local_experts=num_local_experts,\n        topk=topk,\n        num_tokens=num_tokens,\n        hidden_size=hidden_size,\n        routed_scaling_factor=routed_scaling_factor,\n        BLOCK_SIZE=block_dim,\n        NUM_STAGES=3,\n    )\n\n\n@triton.jit\ndef post_reorder_triton_kernel(\n    down_output_ptr,\n    output_ptr,\n    src2dst_ptr,\n    topk_ids_ptr,\n    topk_weights_ptr,\n    topk,\n    
hidden_size,\n    BLOCK_SIZE: tl.constexpr,\n):\n    InDtype = down_output_ptr.dtype.element_ty\n\n    src_idx_int32 = tl.program_id(0)\n    src_idx = src_idx_int32.to(tl.int64)\n    src2dst_ptr = src2dst_ptr + src_idx * topk\n    topk_ids_ptr = topk_ids_ptr + src_idx * topk\n    topk_weights_ptr = topk_weights_ptr + src_idx * topk\n\n    store_ptr = output_ptr + src_idx * hidden_size\n\n    vec = tl.arange(0, BLOCK_SIZE)\n\n    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):\n        offset = start_offset + vec\n        mask = offset < hidden_size\n\n        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)\n        for idx in range(topk):\n            expert_id = tl.load(topk_ids_ptr + idx)\n            if expert_id > 0:\n                dst_idx_int32 = tl.load(src2dst_ptr + idx)\n                dst_idx = dst_idx_int32.to(tl.int64)\n                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)\n                load_ptr = down_output_ptr + dst_idx * hidden_size\n                in_data = tl.load(load_ptr + offset, mask=mask)\n                sum_vec += in_data * weigh_scale\n        tl.store(store_ptr + offset, sum_vec, mask=mask)\n\n\n@triton.jit\ndef _fwd_kernel_ep_scatter_1(\n    num_recv_tokens_per_expert,\n    expert_start_loc,\n    m_indices,\n    num_experts: tl.constexpr,\n    BLOCK_E: tl.constexpr,\n    BLOCK_EXPERT_NUM: tl.constexpr,\n):\n    cur_expert = tl.program_id(0)\n\n    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)\n    tokens_per_expert = tl.load(\n        num_recv_tokens_per_expert + offset_cumsum,\n        mask=offset_cumsum < num_experts,\n        other=0,\n    )\n    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert\n    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)\n\n    cur_expert_start = tl.load(expert_start_loc + cur_expert)\n    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)\n\n    m_indices_start_ptr = m_indices + cur_expert_start\n    
off_expert = tl.arange(0, BLOCK_E)\n\n    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):\n        tl.store(\n            m_indices_start_ptr + start_m + off_expert,\n            cur_expert,\n        )\n\n\n@triton.jit\ndef _fwd_kernel_ep_scatter_2(\n    total_token_num,\n    expert_start_loc,\n    recv_x,\n    recv_x_stride0,\n    recv_x_stride1,\n    recv_x_scale,\n    recv_x_scale_stride0,\n    recv_x_scale_stride1,\n    recv_topk,\n    recv_topk_stride0,\n    recv_topk_stride1,\n    output_tensor,\n    output_tensor_stride0,\n    output_tensor_stride1,\n    output_tensor_scale,\n    output_tensor_scale_stride0,\n    output_tensor_scale_stride1,\n    output_index,\n    output_index_stride0,\n    output_index_stride1,\n    topk_num: tl.constexpr,\n    HIDDEN_SIZE: tl.constexpr,\n    HIDDEN_SIZE_PAD: tl.constexpr,\n    SCALE_HIDDEN_SIZE: tl.constexpr,\n    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,\n):\n    start_token_id = tl.program_id(0)\n    grid_num = tl.num_programs(0)\n\n    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)\n    mask = offset_in < HIDDEN_SIZE\n\n    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)\n    mask_s = index_in_s < SCALE_HIDDEN_SIZE\n\n    for token_id_int32 in range(start_token_id, total_token_num, grid_num):\n        token_id = token_id_int32.to(tl.int64)\n        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)\n        to_copy_s = tl.load(\n            recv_x_scale\n            + token_id * recv_x_scale_stride0\n            + index_in_s * recv_x_scale_stride1,\n            mask=mask_s,\n        )\n\n        for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):\n            topk_index = topk_idx_int32.to(tl.int64)\n            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)\n            if expert_id >= 0:\n                dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)\n                dest_token_index = 
dest_token_index_int32.to(tl.int64)\n\n                tl.store(\n                    output_index + token_id * output_index_stride0 + topk_index,\n                    dest_token_index_int32,\n                )\n                output_tensor_ptr = (\n                    output_tensor + dest_token_index * output_tensor_stride0\n                )\n                output_tensor_scale_ptr = (\n                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0\n                )\n                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)\n                tl.store(\n                    output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,\n                    to_copy_s,\n                    mask=mask_s,\n                )\n\n\n# Copied from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py\n@torch.no_grad()\ndef ep_scatter(\n    recv_x: torch.Tensor,\n    recv_x_scale: torch.Tensor,\n    recv_topk: torch.Tensor,\n    num_recv_tokens_per_expert: torch.Tensor,\n    expert_start_loc: torch.Tensor,\n    output_tensor: torch.Tensor,\n    output_tensor_scale: torch.Tensor,\n    m_indices: torch.Tensor,\n    output_index: torch.Tensor,\n    scale_ue8m0: bool = False,\n):\n    BLOCK_E = 128  # per-expert token count is aligned to 128\n    BLOCK_D = 128  # block size of quantization\n    num_warps = 8\n    num_experts = num_recv_tokens_per_expert.shape[0]\n    hidden_size = recv_x.shape[1]\n    grid = num_experts\n\n    scale_hidden_size = hidden_size // BLOCK_D\n    if scale_ue8m0:\n        # ue8m0 scales are packed here (4 scales per int32),\n        # hence the effective size of this dimension is divided by 4.\n        scale_hidden_size = ceil_div(scale_hidden_size, 4)\n\n    assert m_indices.shape[0] % BLOCK_E == 0\n    assert (\n        recv_x_scale.dtype == output_tensor_scale.dtype\n    ), f\"recv_x_scale.dtype: 
{recv_x_scale.dtype}, output_tensor_scale.dtype: {output_tensor_scale.dtype}\"\n    assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size\n\n    _fwd_kernel_ep_scatter_1[(grid,)](\n        num_recv_tokens_per_expert,\n        expert_start_loc,\n        m_indices,\n        num_experts=num_experts,\n        num_warps=num_warps,\n        BLOCK_E=BLOCK_E,\n        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),\n    )\n\n    grid = min(recv_topk.shape[0], 1024 * 8)\n\n    _fwd_kernel_ep_scatter_2[(grid,)](\n        recv_topk.shape[0],\n        expert_start_loc,\n        recv_x,\n        recv_x.stride(0),\n        recv_x.stride(1),\n        recv_x_scale,\n        recv_x_scale.stride(0),\n        recv_x_scale.stride(1),\n        recv_topk,\n        recv_topk.stride(0),\n        recv_topk.stride(1),\n        output_tensor,\n        output_tensor.stride(0),\n        output_tensor.stride(1),\n        output_tensor_scale,\n        output_tensor_scale.stride(0),\n        output_tensor_scale.stride(1),\n        output_index,\n        output_index.stride(0),\n        output_index.stride(1),\n        topk_num=recv_topk.shape[1],\n        num_warps=num_warps,\n        HIDDEN_SIZE=hidden_size,\n        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),\n        SCALE_HIDDEN_SIZE=scale_hidden_size,\n        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),\n    )\n    return\n\n\n@triton.jit\ndef _fwd_kernel_ep_gather(\n    total_token_num,\n    input_tensor,\n    input_tensor_stride0,\n    input_tensor_stride1,\n    recv_topk_ids,\n    recv_topk_ids_stride0,\n    recv_topk_ids_stride1,\n    recv_topk_weight,\n    recv_topk_weight_stride0,\n    recv_topk_weight_stride1,\n    input_index,\n    input_index_stride0,\n    input_index_stride1,\n    output_tensor,\n    output_tensor_stride0,\n    output_tensor_stride1,\n    topk_num: tl.constexpr,\n    BLOCK_D: tl.constexpr,\n):\n    cur_block_int32 = tl.program_id(0)\n    cur_block = 
cur_block_int32.to(tl.int64)\n\n    start_cur_token_int32 = tl.program_id(1)\n\n    grid_num = tl.num_programs(1)\n\n    for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):\n        cur_token = cur_token_int32.to(tl.int64)\n\n        off_d = tl.arange(0, BLOCK_D)\n        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)\n\n        for topk_index_int32 in range(0, topk_num):\n            topk_index = topk_index_int32.to(tl.int64)\n\n            expert_id = tl.load(\n                recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index\n            )\n            if expert_id >= 0:\n                source_token_index_int32 = tl.load(\n                    input_index + cur_token * input_index_stride0 + topk_index\n                )\n                source_token_index = source_token_index_int32.to(tl.int64)\n\n                acc_weight = tl.load(\n                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index\n                )\n                tmp = tl.load(\n                    input_tensor\n                    + source_token_index * input_tensor_stride0\n                    + cur_block * BLOCK_D\n                    + off_d\n                )\n                accumulator += tmp.to(tl.float32) * acc_weight\n\n        tl.store(\n            output_tensor\n            + cur_token * output_tensor_stride0\n            + cur_block * BLOCK_D\n            + off_d,\n            accumulator.to(output_tensor.dtype.element_ty),\n        )\n\n\n@torch.no_grad()\ndef ep_gather(\n    input_tensor: torch.Tensor,\n    recv_topk_ids: torch.Tensor,\n    recv_topk_weight: torch.Tensor,\n    input_index: torch.Tensor,\n    output_tensor: torch.Tensor,\n):\n    num_warps = 2\n    num_tokens = output_tensor.shape[0]\n    hidden_size = input_tensor.shape[1]\n    BLOCK_D = 128 if hidden_size % 1024 != 0 else 1024  # block size of quantization\n    assert hidden_size % BLOCK_D == 0\n    grid = (triton.cdiv(hidden_size, 
BLOCK_D), min(num_tokens, 1024))\n    _fwd_kernel_ep_gather[grid](\n        num_tokens,\n        input_tensor,\n        input_tensor.stride(0),\n        input_tensor.stride(1),\n        recv_topk_ids,\n        recv_topk_ids.stride(0),\n        recv_topk_ids.stride(1),\n        recv_topk_weight,\n        recv_topk_weight.stride(0),\n        recv_topk_weight.stride(1),\n        input_index,\n        input_index.stride(0),\n        input_index.stride(1),\n        output_tensor,\n        output_tensor.stride(0),\n        output_tensor.stride(1),\n        topk_num=recv_topk_ids.shape[1],\n        num_warps=num_warps,\n        BLOCK_D=BLOCK_D,\n    )\n    return\n\n\n# Copied from\n# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58\ndef get_tma_aligned_size(x: int, element_size: int) -> int:\n    \"\"\"\n    TMA global memory addresses must be 16-byte aligned.\n    Since we use column-major layout for the LHS scaling tensor,\n        the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.\n\n    Arguments:\n        x: original M-axis shape of the LHS scaling tensor.\n        element_size: element size of the LHS scaling tensor, in bytes.\n\n    Returns:\n        M-axis shape of the LHS scaling tensor after padding.\n    \"\"\"\n    tma_alignment_bytes = 16\n    assert tma_alignment_bytes % element_size == 0\n    alignment = tma_alignment_bytes // element_size\n    return ceil_div(x, alignment) * alignment\n\n\n@triton.jit\ndef _tma_align_input_scale_kernel(\n    input_scale_ptr,\n    output_ptr,\n    m,\n    k_div_block_size,\n    input_scale_stride_m,\n    input_scale_stride_k,\n    output_stride_m,\n    output_stride_k,\n    BLOCK_SIZE_K: tl.constexpr,\n):\n    pid_m = tl.program_id(axis=0)\n    grid_m = tl.num_programs(0)\n    k_offsets = tl.arange(0, BLOCK_SIZE_K)\n\n    for m_base in range(pid_m, m, grid_m):\n        input_offset = (\n            input_scale_ptr\n            + 
m_base * input_scale_stride_m\n            + k_offsets * input_scale_stride_k\n        )\n        input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size)\n\n        output_offset = (\n            output_ptr + k_offsets * output_stride_k + m_base * output_stride_m\n        )\n        tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size)\n\n\n# Copied from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py\ndef tma_align_input_scale(input_scale: torch.Tensor):\n    assert input_scale.dim() == 2\n    m, k_div_block_size = input_scale.shape\n    padd_m = get_tma_aligned_size(m, input_scale.element_size())\n    output = torch.empty(\n        (k_div_block_size, padd_m), dtype=input_scale.dtype, device=input_scale.device\n    )\n\n    grid_m = min(m, 8192)\n    BLOCK_SIZE_K = triton.next_power_of_2(k_div_block_size)\n\n    _tma_align_input_scale_kernel[(grid_m,)](\n        input_scale_ptr=input_scale,\n        output_ptr=output,\n        m=m,\n        k_div_block_size=k_div_block_size,\n        input_scale_stride_m=input_scale.stride(0),\n        input_scale_stride_k=input_scale.stride(1),\n        output_stride_m=output.stride(1),  # Note: these are swapped\n        output_stride_k=output.stride(0),  # for column-major\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n    )\n    return output.t()[:m]\n\n\n@triton.jit\ndef compute_masked_m_triton_kernel(seg_indptr, masked_m):\n    expert_id = tl.program_id(0)\n    start = tl.load(seg_indptr + expert_id)\n    end = tl.load(seg_indptr + expert_id + 1)\n    tl.store(masked_m + expert_id, (end - start))\n\n\n@triton.jit\ndef deepgemm_compute_src2dst_triton_kernel(\n    topk_ids,\n    reorder_ids,\n    seg_indptr,\n    src2dst,\n    m_max,\n    num_toks,\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = dst_id < num_toks\n    src_id = tl.load(reorder_ids + 
dst_id, mask=mask)\n    expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))\n    expert_dst_start = tl.load(seg_indptr + expert_id, mask=(expert_id >= 0))\n    expert_dst_offset = dst_id - expert_dst_start\n    dst_id = expert_id * m_max + expert_dst_offset\n    tl.store(src2dst + src_id, dst_id, mask=mask)\n\n\n@triton.jit\ndef fill_gateup_input_triton_kernel(\n    input_ptr,\n    scale_ptr,\n    gateup_input_ptr,\n    gateup_input_scale_ptr,\n    src2dst_ptr,\n    topk_ids_ptr,\n    topk,\n    hidden_size,\n    scale_size,\n    BLOCK_SIZE: tl.constexpr,\n):\n\n    src_idx_int32 = tl.program_id(0)\n    src_idx = src_idx_int32.to(tl.int64)\n    src2dst_ptr = src2dst_ptr + src_idx * topk\n    topk_ids_ptr = topk_ids_ptr + src_idx * topk\n    src_ptr = input_ptr + src_idx * hidden_size\n    scale_src_ptr = scale_ptr + src_idx * scale_size\n\n    vec = tl.arange(0, BLOCK_SIZE)\n    for idx in range(topk):\n        expert_id = tl.load(topk_ids_ptr + idx)\n        if expert_id >= 0:\n            dst_idx_int32 = tl.load(src2dst_ptr + idx)\n            dst_idx = dst_idx_int32.to(tl.int64)\n            dst_ptr = gateup_input_ptr + dst_idx * hidden_size\n            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):\n                offset = start_offset + vec\n                mask = offset < hidden_size\n                in_data = tl.load(src_ptr + offset, mask=mask)\n                tl.store(dst_ptr + offset, in_data, mask=mask)\n            scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size\n            for start_offset in tl.range(0, scale_size, BLOCK_SIZE):\n                offset = start_offset + vec\n                mask = offset < scale_size\n                in_scale = tl.load(scale_src_ptr + offset, mask=mask)\n                tl.store(scale_dst_ptr + offset, in_scale, mask=mask)\n\n\ndef moe_ep_deepgemm_preprocess(\n    topk_ids: torch.Tensor,\n    num_local_experts: int,\n    hidden_states: torch.Tensor,\n    top_k: int,\n    
block_shape,\n    output_dtype: torch.dtype = torch.float8_e4m3fn,\n):\n    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)\n    seg_indptr = torch.zeros(\n        num_local_experts + 1, device=topk_ids.device, dtype=torch.int64\n    )\n    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)\n    masked_m = torch.empty(num_local_experts, device=topk_ids.device, dtype=torch.int32)\n\n    compute_seg_indptr_triton_kernel[(num_local_experts + 1,)](\n        reorder_topk_ids, seg_indptr, topk_ids.numel()\n    )\n\n    grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta[\"BLOCK_SIZE\"]),)\n    compute_masked_m_triton_kernel[(num_local_experts,)](seg_indptr, masked_m)\n\n    # For masked grouped GEMM, M must be a multiple of the GEMM block M, so it is padded to a multiple of 256 here.\n    # See https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165\n    m_max = (hidden_states.size(0) // 256 + 1) * 256\n    expected_m = (topk_ids.numel() - 1) // num_local_experts + 1\n    gateup_input = torch.empty(\n        (num_local_experts, m_max, hidden_states.size(1)),\n        device=hidden_states.device,\n        dtype=output_dtype,\n    )\n\n    deepgemm_compute_src2dst_triton_kernel[grid](\n        topk_ids,\n        reorder_ids,\n        seg_indptr,\n        src2dst,\n        m_max,\n        topk_ids.numel(),\n        BLOCK_SIZE=256,\n    )\n\n    if block_shape is None:\n        block_shape = [128, 128]\n    assert len(block_shape) == 2\n    block_n, block_k = block_shape[0], block_shape[1]\n\n    # TODO: fuse this with the preprocess\n    hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)\n\n    gateup_input_scale = torch.empty(\n        (gateup_input.size(0), gateup_input.size(1), scale.size(1)),\n        device=hidden_states.device,\n        dtype=scale.dtype,\n    )\n\n    fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](\n        hidden_states,\n        scale,\n   
     gateup_input,\n        gateup_input_scale,\n        src2dst,\n        topk_ids,\n        top_k,\n        hidden_states.size(1),\n        scale.size(1),\n        BLOCK_SIZE=1024,\n    )\n\n    return (\n        masked_m,\n        expected_m,\n        src2dst,\n        gateup_input,\n        gateup_input_scale,\n    )\n\n\n@triton.jit\ndef compute_identity_kernel(\n    top_k,\n    hidden_states_ptr,\n    expert_scales_ptr,\n    num_tokens,\n    output_ptr,\n    hidden_dim,\n    scales_stride,\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid = tl.program_id(0)\n\n    batch_id = pid // (hidden_dim // BLOCK_SIZE)\n    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE\n\n    if batch_id >= num_tokens or dim_offset >= hidden_dim:\n        return\n\n    h = tl.load(\n        hidden_states_ptr\n        + batch_id * hidden_dim\n        + dim_offset\n        + tl.arange(0, BLOCK_SIZE),\n        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,\n    )\n\n    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n    for i in range(top_k):\n        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)\n        result += h * scale\n\n    tl.store(\n        output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE),\n        result,\n        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,\n    )\n\n\ndef zero_experts_compute_triton(\n    expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states\n):\n    N = expert_indices.numel()\n    top_k = expert_indices.size(-1)\n    grid = lambda meta: (triton.cdiv(N, meta[\"BLOCK_SIZE\"]),)\n\n    if zero_expert_type == \"identity\":\n        zero_expert_mask = expert_indices < num_experts\n        zero_expert_scales = expert_scales.clone()\n        zero_expert_scales[zero_expert_mask] = 0.0\n\n    normal_expert_mask = expert_indices >= num_experts\n    expert_indices[normal_expert_mask] = -1\n    expert_scales[normal_expert_mask] = 0.0\n\n    output = 
torch.zeros_like(hidden_states).to(hidden_states.device)\n    hidden_dim = hidden_states.size(-1)\n    num_tokens = hidden_states.size(0)\n\n    grid = lambda meta: (num_tokens * (hidden_dim // meta[\"BLOCK_SIZE\"]),)\n    compute_identity_kernel[grid](\n        top_k,\n        hidden_states,\n        zero_expert_scales,\n        num_tokens,\n        output,\n        hidden_dim,\n        zero_expert_scales.stride(0),\n        BLOCK_SIZE=256,\n    )\n\n    return output\n\n\n@triton.jit\ndef compute_problem_sizes_w4a8_kernel(\n    masked_m_ptr,\n    problem_sizes1_ptr,\n    problem_sizes2_ptr,\n    n,\n    k,\n    num_experts,\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = pid < num_experts\n    final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)\n\n    ps1_idx_0 = pid * 3\n    ps1_idx_1 = ps1_idx_0 + 1\n    ps1_idx_2 = ps1_idx_0 + 2\n\n    ps2_idx_0 = pid * 3\n    ps2_idx_1 = ps2_idx_0 + 1\n    ps2_idx_2 = ps2_idx_0 + 2\n\n    ps1_mask_0 = ps1_idx_0 < num_experts * 3\n    ps1_mask_1 = ps1_idx_1 < num_experts * 3\n    ps1_mask_2 = ps1_idx_2 < num_experts * 3\n    ps2_mask_0 = ps2_idx_0 < num_experts * 3\n    ps2_mask_1 = ps2_idx_1 < num_experts * 3\n    ps2_mask_2 = ps2_idx_2 < num_experts * 3\n\n    tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0)\n    tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1)\n    tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2)\n\n    tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0)\n    tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1)\n    tl.store(problem_sizes2_ptr + ps2_idx_2, n, mask=ps2_mask_2)\n\n\ndef compute_problem_sizes_w4a8(\n    masked_m, problem_sizes1, problem_sizes2, n, k, num_experts\n):\n    BLOCK_SIZE = 256\n    grid = lambda meta: (triton.cdiv(num_experts, meta[\"BLOCK_SIZE\"]),)\n    compute_problem_sizes_w4a8_kernel[grid](\n        
masked_m,\n        problem_sizes1,\n        problem_sizes2,\n        n,\n        k,\n        num_experts,\n        BLOCK_SIZE=BLOCK_SIZE,\n    )\n    return problem_sizes1, problem_sizes2\n\n\ndef deepep_ll_get_cutlass_w4a8_moe_mm_data(\n    masked_m,\n    problem_sizes1,\n    problem_sizes2,\n    num_experts,\n    n,\n    k,\n):\n    problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(\n        masked_m, problem_sizes1, problem_sizes2, n, k, num_experts\n    )\n    return (\n        problem_sizes1.to(torch.int32),\n        problem_sizes2.to(torch.int32),\n    )\n\n\n@triton.jit\ndef _silu_and_mul_post_per_tensor_quant_kernel(\n    input_ptr,\n    stride_input_expert,\n    stride_input_token,\n    stride_input_dim,\n    output_ptr,\n    stride_output_expert,\n    stride_output_token,\n    stride_output_dim,\n    scale_ptr,\n    masked_m_ptr,\n    inner_dim,\n    fp8_max,\n    fp8_min,\n    BLOCK_N: tl.constexpr,\n    NUM_STAGE: tl.constexpr,\n):\n    \"\"\"\n    Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.\n\n    Shape:\n        input:  [E, T_padded, 2*D]  -> gate: [:, :, :D], up: [:, :, D:]\n        output: [E, T_padded, D], dtype=float8_e4m3fn\n    \"\"\"\n    expert_id = tl.program_id(2)\n    block_id_token = tl.program_id(1)\n    block_id_dim = tl.program_id(0)\n\n    num_token_blocks = tl.num_programs(1)\n\n    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)\n\n    scale = 1.0 / tl.load(scale_ptr).to(tl.float32)\n\n    stride_input_expert = tl.cast(stride_input_expert, tl.int32)\n    stride_output_expert = tl.cast(stride_output_expert, tl.int32)\n    stride_input_token = tl.cast(stride_input_token, tl.int32)\n    stride_output_token = tl.cast(stride_output_token, tl.int32)\n\n    offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)\n    mask_d = offset_d < inner_dim\n\n    # base pointers for current expert and dim block\n    input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d\n    output_base_offs = 
output_ptr + expert_id * stride_output_expert + offset_d\n\n    for token_idx in tl.range(\n        block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE\n    ):\n        gate_ptr = input_base_offs + token_idx * stride_input_token\n        up_ptr = gate_ptr + inner_dim\n        gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)\n        up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)\n\n        # SiLU: x * sigmoid(x)\n        gate = gate / (1 + tl.exp(-gate))\n        gate = gate.to(input_ptr.dtype.element_ty)\n        gate_up = up * gate\n\n        scaled = gate_up * scale\n        output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)\n        out_ptr = output_base_offs + token_idx * stride_output_token\n        tl.store(out_ptr, output_q, mask=mask_d)\n\n\ndef silu_and_mul_masked_post_per_tensor_quant_fwd(\n    input: torch.Tensor,\n    output: torch.Tensor,\n    masked_m: torch.Tensor,\n    scale: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"\n    Fused SiLU + Mul + Per-Tensor Quantization to FP8.\n\n    Args:\n        input: [expert_num, token_num_padded, 2 * inner_dim]\n        output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn\n        masked_m: [expert_num], actual token count for each expert\n        scale: [1] or [expert_num], quantization scale; the kernel reads only the first element, so it is always applied per-tensor\n\n    Returns:\n        output tensor\n    \"\"\"\n    assert input.is_contiguous()\n    assert output.is_contiguous()\n    assert output.dtype == torch.float8_e4m3fn\n    assert input.ndim == 3\n    assert input.shape[0] == masked_m.shape[0]\n    assert input.shape[-1] % 2 == 0\n    assert scale.numel() == 1 or scale.shape[0] == input.shape[0]\n\n    expert_num = input.shape[0]\n    inner_dim = input.shape[-1] // 2\n\n    BLOCK_N = 256\n    BLOCK_M = 64 if expert_num < 4 else 32\n    NUM_STAGES = 3\n    hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)\n\n    grid = 
(hidden_dim_split_block_num, BLOCK_M, expert_num)\n    finfo = torch.finfo(torch.float8_e4m3fn)\n    fp8_max = finfo.max\n    fp8_min = -fp8_max\n\n    _silu_and_mul_post_per_tensor_quant_kernel[grid](\n        input,\n        *input.stride(),\n        output,\n        *output.stride(),\n        scale,\n        masked_m,\n        inner_dim,\n        fp8_max,\n        fp8_min,\n        BLOCK_N=BLOCK_N,\n        NUM_STAGE=NUM_STAGES,\n    )\n    return output\n"
  },
  {
    "path": "python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py",
    "content": "from typing import Optional\n\nimport torch\nfrom flashinfer import (\n    scaled_fp4_grouped_quantize,\n    silu_and_mul_scaled_nvfp4_experts_quantize,\n)\nfrom flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked\n\n\ndef get_cute_dtype(input: torch.Tensor) -> str:\n    if input.dtype == torch.bfloat16:\n        return \"bfloat16\"\n    elif input.dtype == torch.float16:\n        return \"float16\"\n    elif input.dtype == torch.float32:\n        return \"float32\"\n    else:\n        raise ValueError(f\"Unsupported cute dtype {input.dtype}\")\n\n\ndef flashinfer_cutedsl_moe_masked(\n    hidden_states: tuple[torch.Tensor, Optional[torch.Tensor]],\n    input_global_scale: torch.Tensor,\n    w1: torch.Tensor,\n    w1_blockscale: torch.Tensor,\n    w1_alpha,\n    w2: torch.Tensor,\n    a2_global_scale: torch.Tensor,\n    w2_blockscale: torch.Tensor,\n    w2_alpha,\n    masked_m: torch.Tensor,\n    down_sm_count: Optional[int] = None,\n    down_signals: Optional[torch.Tensor] = None,\n    down_start_event: Optional[torch.cuda.Event] = None,\n):\n    \"\"\"\n    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL\n    kernels.\n\n    Args:\n        hidden_states: One of the following cases\n            * tuple[torch.Tensor, None]: [num_experts, m, k], bf16, None means no quant\n            * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], uint8, [num_experts, m, k // 16], float8_e4m3fn\n        input_global_scale (torch.Tensor): (l,)\n        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8\n        w1_blockscale (torch.Tensor): blockscale factors, float8_e4m3fn\n        w1_alpha (torch.Tensor): (l,)\n        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8\n        a2_global_scale (torch.Tensor): (l,)\n        w2_blockscale (torch.Tensor): blockscale factors, float8_e4m3fn\n        w2_alpha (torch.Tensor): (l,)\n        masked_m (torch.Tensor): (l,), number of valid tokens for each expert\n\n    Notes:\n        - Assumes 
max(masked_m) == m.\n    \"\"\"\n\n    # === Assertions on dtypes ===\n    assert w1.dtype == torch.uint8, f\"w1 must be uint8 (fp4 packed), got {w1.dtype}\"\n    assert (\n        w1_blockscale.dtype == torch.float8_e4m3fn\n    ), f\"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}\"\n    assert (\n        w1_alpha.dtype == torch.float32\n    ), f\"w1_alpha must be float32, got {w1_alpha.dtype}\"\n    assert w2.dtype == torch.uint8, f\"w2 must be uint8 (fp4 packed), got {w2.dtype}\"\n    assert (\n        a2_global_scale.dtype == torch.float32\n    ), f\"a2_global_scale must be float32, got {a2_global_scale.dtype}\"\n    assert (\n        w2_blockscale.dtype == torch.float8_e4m3fn\n    ), f\"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}\"\n    assert (\n        w2_alpha.dtype == torch.float32\n    ), f\"w2_alpha must be float32, got {w2_alpha.dtype}\"\n    assert (\n        len(hidden_states) == 2\n    ), f\"hidden_states must be a tuple of length 2, got {len(hidden_states)}\"\n\n    # === Assertions on shapes ===\n    n = w2.shape[-1] * 2  # intermediate dimension\n\n    if hidden_states[1] is not None:\n\n        a_q = hidden_states[0].view(torch.uint8)\n        a_q_sf = hidden_states[1].view(torch.float8_e4m3fn)\n        m, k_by_2, num_experts = a_q.shape\n        k = k_by_2 * 2\n    else:\n        num_experts, m, k = hidden_states[0].shape\n\n        assert (\n            input_global_scale.dtype == torch.float32\n        ), f\"input_global_scale must be float32, got {input_global_scale.dtype}\"\n        assert input_global_scale.shape == (\n            num_experts,\n        ), f\"input_global_scale must be (l,), got {input_global_scale.shape}\"\n\n        a_q, a_q_sf = scaled_fp4_grouped_quantize(\n            hidden_states[0],\n            masked_m,\n            input_global_scale,\n        )\n\n    assert w1.shape[-2] == 2 * n, f\"w1 last-2 dim must be 2*n, got {w1.shape}\"\n    assert (\n        w1.shape[-1] * 2 == k\n    ), 
f\"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}\"\n    assert w2.shape[-2:] == (\n        k,\n        n // 2,\n    ), f\"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}\"\n    assert w1_alpha.shape == (\n        num_experts,\n    ), f\"w1_alpha must be (l,), got {w1_alpha.shape}\"\n    assert a2_global_scale.shape == (\n        num_experts,\n    ), f\"a2_global_scale must be (l,), got {a2_global_scale.shape}\"\n    assert w2_alpha.shape == (\n        num_experts,\n    ), f\"w2_alpha must be (l,), got {w2_alpha.shape}\"\n\n    # TODO(kaixih@nvidia): dtype should be based on inputs.\n    gateup_output = torch.empty(\n        (num_experts, m, n * 2), dtype=torch.bfloat16, device=a_q.device\n    )\n    gateup_output = gateup_output.permute(1, 2, 0)  # requirement of kernel\n    sf_vec_size = 16\n    assert a_q_sf.dtype == torch.float8_e4m3fn\n    assert a_q.dtype == torch.uint8\n    ab_dtype = \"float4_e2m1fn\"\n    sf_dtype = \"float8_e4m3fn\"\n    c_dtype = \"bfloat16\"\n\n    # Gemm1\n    grouped_gemm_nt_masked(\n        (a_q, a_q_sf),\n        (w1.permute(1, 2, 0), w1_blockscale),\n        gateup_output,\n        masked_m,\n        ab_dtype=ab_dtype,\n        sf_dtype=sf_dtype,\n        c_dtype=c_dtype,\n        sf_vec_size=sf_vec_size,\n        alpha=w1_alpha.view(1, 1, num_experts),\n        alpha_dtype=get_cute_dtype(w1_alpha),\n    )  # in logical [m, n, l]\n\n    # SILU and quantization\n    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(\n        gateup_output.permute(2, 0, 1),\n        masked_m,\n        a2_global_scale,\n    )\n\n    if down_start_event is not None:\n        down_start_event.record()\n\n    # Gemm2\n    out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)\n    out = out.permute(1, 2, 0)  # requirement of kernel\n    grouped_gemm_nt_masked(\n        (diq, diq_sf),\n        (w2.permute(1, 2, 0), w2_blockscale),\n        out,\n        masked_m,\n        ab_dtype=ab_dtype,\n        
sf_dtype=sf_dtype,\n        c_dtype=c_dtype,\n        sf_vec_size=sf_vec_size,\n        alpha=w2_alpha.view(1, 1, num_experts),\n        alpha_dtype=get_cute_dtype(w2_alpha),\n        **(\n            dict(\n                sm_count=down_sm_count,\n                dst_signals=down_signals,\n            )\n            if down_sm_count is not None or down_signals is not None\n            else {}\n        ),\n    )  # in logical [m, k, l]\n    return out.permute(2, 0, 1)\n"
  },
  {
    "path": "python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py",
    "content": "from typing import Optional\n\nimport torch\n\nfrom sglang.srt.utils.custom_op import register_custom_op\n\n\ndef _fake_fp8_block_scale_moe(\n    routing_logits: torch.Tensor,\n    routing_bias: Optional[torch.Tensor],\n    hidden_states: torch.Tensor,\n    hidden_states_scale: torch.Tensor,\n    gemm1_weights: torch.Tensor,\n    gemm1_weights_scale: torch.Tensor,\n    gemm2_weights: torch.Tensor,\n    gemm2_weights_scale: torch.Tensor,\n    num_experts: int,\n    top_k: int,\n    n_group: Optional[int],\n    topk_group: Optional[int],\n    intermediate_size: int,\n    local_expert_offset: int,\n    local_num_experts: int,\n    routed_scaling_factor: Optional[float],\n    routing_method_type: int = 0,\n    use_shuffled_weight: bool = False,\n    weight_layout: int = 0,\n    enable_pdl: Optional[bool] = None,\n    tune_max_num_tokens: int = 8192,\n    fp8_quantization_type: Optional[int] = None,\n) -> torch.Tensor:\n    return torch.empty(\n        hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device\n    )\n\n\n@register_custom_op(fake_impl=_fake_fp8_block_scale_moe)\ndef trtllm_fp8_block_scale_moe_wrapper(\n    routing_logits: torch.Tensor,\n    routing_bias: Optional[torch.Tensor],\n    hidden_states: torch.Tensor,\n    hidden_states_scale: torch.Tensor,\n    gemm1_weights: torch.Tensor,\n    gemm1_weights_scale: torch.Tensor,\n    gemm2_weights: torch.Tensor,\n    gemm2_weights_scale: torch.Tensor,\n    num_experts: int,\n    top_k: int,\n    n_group: Optional[int],\n    topk_group: Optional[int],\n    intermediate_size: int,\n    local_expert_offset: int,\n    local_num_experts: int,\n    routed_scaling_factor: Optional[float],\n    routing_method_type: int = 0,\n    use_shuffled_weight: bool = False,\n    weight_layout: int = 0,\n    enable_pdl: Optional[bool] = None,\n    tune_max_num_tokens: int = 8192,\n    fp8_quantization_type: Optional[int] = None,\n) -> torch.Tensor:\n    try:\n        from flashinfer.fused_moe import 
trtllm_fp8_block_scale_moe\n    except ImportError as e:\n        raise ImportError(\n            \"Can't import trtllm_fp8_block_scale_moe from flashinfer. \"\n            \"Please check flashinfer version.\"\n        ) from e\n    kwargs = {\n        \"routing_logits\": routing_logits,\n        \"routing_bias\": routing_bias,\n        \"hidden_states\": hidden_states,\n        \"hidden_states_scale\": hidden_states_scale,\n        \"gemm1_weights\": gemm1_weights,\n        \"gemm1_weights_scale\": gemm1_weights_scale,\n        \"gemm2_weights\": gemm2_weights,\n        \"gemm2_weights_scale\": gemm2_weights_scale,\n        \"num_experts\": num_experts,\n        \"top_k\": top_k,\n        \"n_group\": n_group,\n        \"topk_group\": topk_group,\n        \"intermediate_size\": intermediate_size,\n        \"local_expert_offset\": local_expert_offset,\n        \"local_num_experts\": local_num_experts,\n        \"routed_scaling_factor\": routed_scaling_factor,\n        \"routing_method_type\": routing_method_type,\n        \"use_shuffled_weight\": use_shuffled_weight,\n        \"weight_layout\": weight_layout,\n        \"enable_pdl\": enable_pdl,\n        \"tune_max_num_tokens\": tune_max_num_tokens,\n    }\n    if fp8_quantization_type is not None:\n        from flashinfer.fused_moe import Fp8QuantizationType\n\n        kwargs[\"fp8_quantization_type\"] = Fp8QuantizationType(fp8_quantization_type)\n\n    return trtllm_fp8_block_scale_moe(**kwargs)\n\n\ndef _fake_fp8_block_scale_routed_moe(\n    topk_ids: torch.Tensor,\n    routing_bias: Optional[torch.Tensor],\n    hidden_states: torch.Tensor,\n    hidden_states_scale: torch.Tensor,\n    gemm1_weights: torch.Tensor,\n    gemm1_weights_scale: torch.Tensor,\n    gemm2_weights: torch.Tensor,\n    gemm2_weights_scale: torch.Tensor,\n    num_experts: int,\n    top_k: int,\n    n_group: Optional[int],\n    topk_group: Optional[int],\n    intermediate_size: int,\n    local_expert_offset: int,\n    local_num_experts: 
int,\n    routed_scaling_factor: Optional[float],\n    routing_method_type: int = 0,\n    use_shuffled_weight: bool = False,\n    weight_layout: int = 0,\n    enable_pdl: Optional[bool] = None,\n    tune_max_num_tokens: int = 8192,\n    fp8_quantization_type: Optional[int] = None,\n) -> torch.Tensor:\n    return torch.empty(\n        hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device\n    )\n\n\n@register_custom_op(fake_impl=_fake_fp8_block_scale_routed_moe)\ndef trtllm_fp8_block_scale_routed_moe_wrapper(\n    topk_ids: torch.Tensor,\n    routing_bias: Optional[torch.Tensor],\n    hidden_states: torch.Tensor,\n    hidden_states_scale: torch.Tensor,\n    gemm1_weights: torch.Tensor,\n    gemm1_weights_scale: torch.Tensor,\n    gemm2_weights: torch.Tensor,\n    gemm2_weights_scale: torch.Tensor,\n    num_experts: int,\n    top_k: int,\n    n_group: Optional[int],\n    topk_group: Optional[int],\n    intermediate_size: int,\n    local_expert_offset: int,\n    local_num_experts: int,\n    routed_scaling_factor: Optional[float],\n    routing_method_type: int = 0,\n    use_shuffled_weight: bool = False,\n    weight_layout: int = 0,\n    enable_pdl: Optional[bool] = None,\n    tune_max_num_tokens: int = 8192,\n    fp8_quantization_type: Optional[int] = None,\n) -> torch.Tensor:\n    try:\n        from flashinfer.fused_moe import trtllm_fp8_block_scale_routed_moe\n    except ImportError as e:\n        raise ImportError(\n            \"Can't import trtllm_fp8_block_scale_routed_moe from flashinfer. 
\"\n            \"Please check flashinfer version.\"\n        ) from e\n    kwargs = {\n        \"topk_ids\": topk_ids,\n        \"routing_bias\": routing_bias,\n        \"hidden_states\": hidden_states,\n        \"hidden_states_scale\": hidden_states_scale,\n        \"gemm1_weights\": gemm1_weights,\n        \"gemm1_weights_scale\": gemm1_weights_scale,\n        \"gemm2_weights\": gemm2_weights,\n        \"gemm2_weights_scale\": gemm2_weights_scale,\n        \"num_experts\": num_experts,\n        \"top_k\": top_k,\n        \"n_group\": n_group,\n        \"topk_group\": topk_group,\n        \"intermediate_size\": intermediate_size,\n        \"local_expert_offset\": local_expert_offset,\n        \"local_num_experts\": local_num_experts,\n        \"routed_scaling_factor\": routed_scaling_factor,\n        \"routing_method_type\": routing_method_type,\n        \"use_shuffled_weight\": use_shuffled_weight,\n        \"weight_layout\": weight_layout,\n        \"enable_pdl\": enable_pdl,\n        \"tune_max_num_tokens\": tune_max_num_tokens,\n    }\n    if fp8_quantization_type is not None:\n        from flashinfer.fused_moe import Fp8QuantizationType\n\n        kwargs[\"fp8_quantization_type\"] = Fp8QuantizationType(fp8_quantization_type)\n\n    return trtllm_fp8_block_scale_routed_moe(**kwargs)\n\n\ndef _fake_fp8_per_tensor_scale_moe(\n    routing_logits: torch.Tensor,\n    routing_bias: Optional[torch.Tensor],\n    hidden_states: torch.Tensor,\n    gemm1_weights: torch.Tensor,\n    output1_scales_scalar: torch.Tensor,\n    output1_scales_gate_scalar: torch.Tensor,\n    gemm2_weights: torch.Tensor,\n    output2_scales_scalar: torch.Tensor,\n    num_experts: int,\n    top_k: int,\n    n_group: Optional[int],\n    topk_group: Optional[int],\n    intermediate_size: int,\n    local_expert_offset: int,\n    local_num_experts: int,\n    routed_scaling_factor: Optional[float],\n    use_routing_scales_on_input: bool,\n    routing_method_type: int = 0,\n    enable_pdl: 
Optional[bool] = None,\n    tune_max_num_tokens: int = 8192,\n) -> torch.Tensor:\n    return torch.empty(\n        hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device\n    )\n\n\n@register_custom_op(fake_impl=_fake_fp8_per_tensor_scale_moe)\ndef trtllm_fp8_per_tensor_scale_moe_wrapper(\n    routing_logits: torch.Tensor,\n    routing_bias: Optional[torch.Tensor],\n    hidden_states: torch.Tensor,\n    gemm1_weights: torch.Tensor,\n    output1_scales_scalar: torch.Tensor,\n    output1_scales_gate_scalar: torch.Tensor,\n    gemm2_weights: torch.Tensor,\n    output2_scales_scalar: torch.Tensor,\n    num_experts: int,\n    top_k: int,\n    n_group: Optional[int],\n    topk_group: Optional[int],\n    intermediate_size: int,\n    local_expert_offset: int,\n    local_num_experts: int,\n    routed_scaling_factor: Optional[float],\n    use_routing_scales_on_input: bool,\n    routing_method_type: int = 0,\n    enable_pdl: Optional[bool] = None,\n    tune_max_num_tokens: int = 8192,\n) -> torch.Tensor:\n    # lazy import\n    try:\n        from flashinfer.fused_moe import trtllm_fp8_per_tensor_scale_moe\n    except ImportError as e:\n        raise ImportError(\n            \"Can't import trtllm_fp8_per_tensor_scale_moe from flashinfer. 
\"\n            \"Please check flashinfer version.\"\n        ) from e\n\n    kwargs = {\n        \"routing_logits\": routing_logits,\n        \"routing_bias\": routing_bias,\n        \"hidden_states\": hidden_states,\n        \"gemm1_weights\": gemm1_weights,\n        \"output1_scales_scalar\": output1_scales_scalar,\n        \"output1_scales_gate_scalar\": output1_scales_gate_scalar,\n        \"gemm2_weights\": gemm2_weights,\n        \"output2_scales_scalar\": output2_scales_scalar,\n        \"num_experts\": num_experts,\n        \"top_k\": top_k,\n        \"n_group\": n_group,\n        \"topk_group\": topk_group,\n        \"intermediate_size\": intermediate_size,\n        \"local_expert_offset\": local_expert_offset,\n        \"local_num_experts\": local_num_experts,\n        \"routed_scaling_factor\": routed_scaling_factor,\n        \"use_routing_scales_on_input\": use_routing_scales_on_input,\n        \"routing_method_type\": routing_method_type,\n        \"enable_pdl\": enable_pdl,\n        \"tune_max_num_tokens\": tune_max_num_tokens,\n    }\n\n    return trtllm_fp8_per_tensor_scale_moe(**kwargs)\n"
  },
  {
    "path": "python/sglang/srt/layers/moe/fused_moe_native.py",
    "content": "\"\"\"\nTorch-native implementation for FusedMoE. This is used for torch.compile.\nIt is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204\n\"\"\"\n\nimport torch\nfrom torch.nn import functional as F\n\nfrom sglang.srt.layers.activation import GeluAndMul, SiluAndMul\nfrom sglang.srt.layers.moe.moe_runner import MoeRunnerConfig\nfrom sglang.srt.layers.moe.token_dispatcher import (\n    StandardCombineInput,\n    StandardDispatchOutput,\n)\nfrom sglang.srt.layers.moe.topk import StandardTopKOutput\n\n\ndef fused_moe_forward_native(\n    layer: torch.nn.Module,\n    dispatch_output: StandardDispatchOutput,\n) -> StandardCombineInput:\n\n    x, x_scale, topk_output = dispatch_output\n    moe_runner_config = layer.moe_runner_config\n\n    if moe_runner_config.apply_router_weight_on_input:\n        raise NotImplementedError()\n\n    topk_weights, topk_ids, _ = topk_output\n\n    w13_weights = layer.w13_weight[topk_ids]\n    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)\n    w2_weights = layer.w2_weight[topk_ids]\n    x1 = torch.einsum(\"ti,taoi -> tao\", x, w1_weights)\n    if moe_runner_config.activation == \"silu\":\n        x1 = F.silu(x1)\n    elif moe_runner_config.activation == \"gelu\":\n        x1 = F.gelu(x1)\n    else:\n        raise ValueError(f\"Unsupported activation: {moe_runner_config.activation=}\")\n    x3 = torch.einsum(\"ti, taoi -> tao\", x, w3_weights)\n    expert_outs = torch.einsum(\"tao, taio -> tai\", (x1 * x3), w2_weights)\n    expert_outs = torch.einsum(\n        \"tai,ta -> ti\", expert_outs, topk_weights.to(expert_outs.dtype)\n    )\n    return StandardCombineInput(hidden_states=expert_outs)\n\n\ndef moe_forward_native(\n    layer: torch.nn.Module,\n    x: torch.Tensor,\n    topk_output: StandardTopKOutput,\n    moe_runner_config: MoeRunnerConfig,\n) -> torch.Tensor:\n\n    if moe_runner_config.apply_router_weight_on_input:\n        raise 
NotImplementedError()\n\n    topk_weights, topk_ids, _ = topk_output\n\n    # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589\n    len_experts = layer.num_experts\n\n    cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))\n    cnts.scatter_(1, topk_ids.to(torch.int64), 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = topk_ids.view(-1).argsort()\n\n    sorted_tokens = x[idxs // topk_ids.shape[1]]\n    tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n    if moe_runner_config.activation == \"silu\":\n        act = SiluAndMul()\n    elif moe_runner_config.activation == \"gelu\":\n        act = GeluAndMul()\n    else:\n        raise ValueError(f\"Unsupported activation: {moe_runner_config.activation=}\")\n\n    outputs = []\n    start_idx = 0\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n\n        layer_w13_weight = layer.w13_weight[i]\n        layer_w2_weight = layer.w2_weight[i]\n\n        gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)\n        gate_up = act(gate_up)\n        expert_out = F.linear(gate_up, layer_w2_weight)\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n    new_x = torch.empty_like(outs)\n\n    new_x[idxs] = outs\n    final_out = (\n        new_x.view(*topk_ids.shape, -1)\n        .type(topk_weights.dtype)\n        .mul_(topk_weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return final_out\n"
  },
  {
    "path": "python/sglang/srt/layers/moe/fused_moe_triton/__init__.py",
    "content": "from contextlib import contextmanager\nfrom typing import Any, Dict, Optional\n\nfrom sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts\nfrom sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (\n    get_config_file_name,\n    try_get_optimal_moe_config,\n)\nfrom sglang.srt.layers.moe.fused_moe_triton.layer import (\n    FusedMoE,\n    FusedMoeWeightScaleSupported,\n)\nfrom sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (\n    moe_align_block_size,\n)\n\n_config: Optional[Dict[str, Any]] = None\n\n\n@contextmanager\ndef override_config(config):\n    global _config\n    old_config = _config\n    _config = config\n    try:\n        yield\n    finally:\n        # Restore the previous config even if the body of the `with` block raises.\n        _config = old_config\n\n\ndef get_config() -> Optional[Dict[str, Any]]:\n    return _config\n\n\n__all__ = [\n    \"FusedMoE\",\n    \"FusedMoeWeightScaleSupported\",\n    \"override_config\",\n    \"get_config\",\n    \"fused_experts\",\n    \"get_config_file_name\",\n    \"moe_align_block_size\",\n    \"try_get_optimal_moe_config\",\n]\n"
  },
  {
    "path": "python/sglang/srt/layers/moe/fused_moe_triton/configs/README.md",
    "content": "# Fused MoE Triton Kernel Configurations\n\nThis directory contains tuned configurations for different settings of the fused_moe kernel.\n\n## Configuration Parameters\n\nEach configuration file is generated based on the following parameters:\n\n- **E** (number of experts): Total number of experts in the MoE layer\n- **N** (intermediate size): The intermediate size of each expert's FFN\n  - For Tensor Parallelism (TP): `N = original_intermediate_size / tp_size`\n  - Example: Mixtral has an intermediate size of 14336, so N = 7168 for TP=2 and N = 3584 for TP=4\n- **device_name**: GPU device name from `torch.cuda.get_device_name()`, with spaces replaced by underscores\n  - Examples: `NVIDIA_H100_80GB_HBM3`, `NVIDIA_A100-SXM4-80GB`, `NVIDIA_GeForce_RTX_4090`\n- **dtype**: Data type for computation\n  - Supported types: `fp8_w8a8`, `int8_w8a8`, `int8_w8a16`, `int4_w4a16`, etc.\n  - Determines precision and quantization scheme for weights and activations\n- **block_shape**: Block quantization shape (for DeepSeek V3/R1 models)\n  - Defines granularity for block-wise quantization, specified as `[block_n, block_k]`\n  - Example: DeepSeek V3 commonly uses `[128, 128]` for efficient block-wise FP8 quantization\n- **tp_size**: Tensor Parallelism size (affects N parameter)\n- **ep_size**: Expert Parallelism size (affects E parameter when EP is enabled)\n- **per_channel_quant**: Whether per-channel quantization is used\n\n## Configuration File Format\n\nEach JSON file maps **M** (the batch size, i.e. the number of tokens) to the optimal kernel configuration for that batch size. 
The configuration includes parameters such as `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, `BLOCK_SIZE_K`, `GROUP_SIZE_M`, `num_warps` (number of warps), and `num_stages` (pipeline stages).\n\n**Filename Format**:\n```\nE={E},N={N},device_name={device_name}[,dtype={dtype}][,block_shape={block_shape}][,per_channel_quant={bool}].json\n```\n\nThe `dtype` field is omitted when no quantization is used (for example `E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json`).\n\n## Generating Configuration Files\n\nTo generate new configuration files for your specific hardware and model settings, use the tuning tools:\n\n**📖 Full Documentation**: [Tuning Triton MoE Kernels](https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton)\n\nAfter tuning, move the generated JSON files into the subdirectory matching your Triton version (for example `triton_3_1_0/`) so that SGLang can use them.\n"
  },
  {
    "path": "python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json",
    "content": "{\n  \"1\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 32,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"2\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 32,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 64,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"4\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 32,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"8\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 32,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"16\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 32,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"24\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"32\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 32,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 2\n  },\n  \"48\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"64\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"96\": {\n    \"BLOCK_SIZE_M\": 32,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"128\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"256\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 64,\n    
\"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"512\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"1024\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 64,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"1536\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 64,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"2048\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"3072\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"4096\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  }\n}\n"
  },
  {
    "path": "python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json",
    "content": "{\n  \"1\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 32,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"2\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 32,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"4\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 32,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"8\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 32,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 64,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"16\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 32,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"24\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 32,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"32\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 32,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"48\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 8,\n    \"num_stages\": 3\n  },\n  \"64\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"96\": {\n    \"BLOCK_SIZE_M\": 32,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"128\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 8,\n    \"num_stages\": 3\n  },\n  \"256\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 64,\n    
\"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 64,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"512\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"1024\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"1536\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"2048\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"3072\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 64,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"4096\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 64,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"5120\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"9216\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"13312\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"17408\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"25600\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    
\"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"33792\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"41984\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"50176\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"58368\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  }\n}\n"
  },
  {
    "path": "python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json",
    "content": "{\n  \"1\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"2\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"4\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"8\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"16\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"24\": {\n    \"BLOCK_SIZE_M\": 32,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"32\": {\n    \"BLOCK_SIZE_M\": 32,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"48\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"64\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"96\": {\n    \"BLOCK_SIZE_M\": 32,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"128\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"256\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    
\"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 8,\n    \"num_stages\": 3\n  },\n  \"512\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 8,\n    \"num_stages\": 3\n  },\n  \"1024\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"1536\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"2048\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 64,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"3072\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 64,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"4096\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"5120\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"9216\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"13312\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"17408\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"25600\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    
\"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"33792\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"41984\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"50176\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"58368\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  }\n}\n"
  },
  {
    "path": "python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json",
    "content": "{\n  \"1\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"2\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"4\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"8\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 3\n  },\n  \"16\": {\n    \"BLOCK_SIZE_M\": 16,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"24\": {\n    \"BLOCK_SIZE_M\": 32,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 256,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 8,\n    \"num_stages\": 4\n  },\n  \"32\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"48\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"64\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 64,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 1,\n    \"num_warps\": 4,\n    \"num_stages\": 5\n  },\n  \"96\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"128\": {\n    \"BLOCK_SIZE_M\": 64,\n    \"BLOCK_SIZE_N\": 128,\n    \"BLOCK_SIZE_K\": 128,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 4,\n    \"num_stages\": 4\n  },\n  \"256\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 128,\n    
\"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 64,\n    \"num_warps\": 8,\n    \"num_stages\": 5\n  },\n  \"512\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 8,\n    \"num_stages\": 4\n  },\n  \"1024\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 8,\n    \"num_stages\": 4\n  },\n  \"1536\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 8,\n    \"num_stages\": 4\n  },\n  \"2048\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 8,\n    \"num_stages\": 3\n  },\n  \"3072\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 64,\n    \"num_warps\": 8,\n    \"num_stages\": 4\n  },\n  \"4096\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 32,\n    \"num_warps\": 8,\n    \"num_stages\": 4\n  },\n  \"5120\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 8,\n    \"num_stages\": 4\n  },\n  \"9216\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 8,\n    \"num_stages\": 3\n  },\n  \"13312\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 8,\n    \"num_stages\": 3\n  },\n  \"17408\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 8,\n    \"num_stages\": 3\n  },\n  \"25600\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n 
   \"num_warps\": 8,\n    \"num_stages\": 4\n  },\n  \"33792\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 8,\n    \"num_stages\": 3\n  },\n  \"41984\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 8,\n    \"num_stages\": 3\n  },\n  \"50176\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 8,\n    \"num_stages\": 3\n  },\n  \"58368\": {\n    \"BLOCK_SIZE_M\": 128,\n    \"BLOCK_SIZE_N\": 256,\n    \"BLOCK_SIZE_K\": 64,\n    \"GROUP_SIZE_M\": 16,\n    \"num_warps\": 8,\n    \"num_stages\": 3\n  }\n}\n"
  },
  {
    "path": "python/sglang/srt/layers/quantization/quark/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/srt/models/deepseek_common/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/test/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/test/ascend/__init__.py",
    "content": ""
  },
  {
    "path": "python/sglang/test/attention/__init__.py",
    "content": ""
  },
  {
    "path": "sgl-kernel/python/sgl_kernel/testing/__init__.py",
    "content": ""
  },
  {
    "path": "sgl-model-gateway/e2e_test/benchmarks/__init__.py",
    "content": ""
  },
  {
    "path": "sgl-model-gateway/e2e_test/chat_completions/__init__.py",
    "content": ""
  },
  {
    "path": "sgl-model-gateway/e2e_test/embeddings/__init__.py",
    "content": ""
  },
  {
    "path": "sgl-model-gateway/e2e_test/router/__init__.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/__init__.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/aligner/__init__.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/aligner/conftest.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/aligner/entrypoint/__init__.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/aligner/entrypoint/conftest.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/aligner/reorderer/__init__.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/aligner/reorderer/conftest.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/aligner/token_aligner/__init__.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/aligner/token_aligner/conftest.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/aligner/unsharder/__init__.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/aligner/unsharder/conftest.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/dims_spec/__init__.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/tensor_comparator/__init__.py",
    "content": ""
  },
  {
    "path": "test/registered/debug_utils/comparator/tensor_comparator/conftest.py",
    "content": ""
  }
]